// Package subscription tracks which subscriber channels are registered for
// which topics. Subscriptions are stored in a tree with one level per topic
// field.
package subscription

// TODO: wildcard subscriptions are not implemented yet.

import (
	"strings"
	"sync"

	"badat.dev/maeqtt/v2/mqtt/packets"
)

// Subscriptions is the package-level root of the subscription tree.
//
// It is built with a composite literal instead of dereferencing the result of
// newSubscriptionTreeNode so that the RWMutex inside the node is never copied.
var Subscriptions = SubscriptionTreeNode{
	children: make(map[string]*SubscriptionTreeNode),
}

// SubscriptionChannel is a channel of publish packets. The channel value also
// identifies a subscriber when subscriptions are removed.
type SubscriptionChannel chan packets.PublishPacket

// Subscription pairs a subscriber's channel with the topic filter it was
// registered with.
type Subscription struct {
	SubscriptionChannel
	packets.TopicFilter
}

// SubscriptionTreeNode is a single node of the subscription tree. It holds the
// subscriptions registered for the topic ending at this node and the child
// nodes for the next topic field.
//
// A node must not be copied after first use because it contains a mutex.
type SubscriptionTreeNode struct {
	subscriptions []Subscription
	children      map[string]*SubscriptionTreeNode
	// nodeLock guards both subscriptions and children of this node only;
	// every node in the tree has its own lock.
	nodeLock sync.RWMutex
}

// newSubscriptionTreeNode returns an empty node with an initialized child map.
func newSubscriptionTreeNode() *SubscriptionTreeNode {
	return &SubscriptionTreeNode{
		children: make(map[string]*SubscriptionTreeNode),
	}
}

// findNode walks down the tree along fields, starting at s, and returns the
// node reached. Nodes missing along the way are created, so the result is
// never nil. An empty fields slice yields s itself.
func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode {
	node := s
	for _, field := range fields {
		node = node.childOrCreate(field)
	}
	return node
}

// childOrCreate returns the child of s stored under field, inserting a new
// empty node first when there is none.
func (s *SubscriptionTreeNode) childOrCreate(field string) *SubscriptionTreeNode {
	// Fast path: the child usually exists already, a read lock is enough.
	s.nodeLock.RLock()
	child, ok := s.children[field]
	s.nodeLock.RUnlock()
	if ok {
		return child
	}

	// A read lock can't be upgraded, so take the write lock and check again:
	// another goroutine may have inserted the child in between.
	s.nodeLock.Lock()
	defer s.nodeLock.Unlock()
	if s.children == nil {
		// Keeps a zero-value node usable instead of panicking on the write below.
		s.children = make(map[string]*SubscriptionTreeNode)
	}
	if child, ok = s.children[field]; !ok {
		child = newSubscriptionTreeNode()
		s.children[field] = child
	}
	return child
}

// removeSubscription drops every subscription of this node that belongs to
// subChan. The relative order of the remaining subscriptions is preserved.
//
// The caller must hold s.nodeLock for writing.
func (s *SubscriptionTreeNode) removeSubscription(subChan SubscriptionChannel) {
	// Filter in place. Removing elements while ranging over the same slice
	// (as a swap-remove would) can index past the shortened slice when the
	// channel is registered more than once.
	kept := s.subscriptions[:0]
	for _, sub := range s.subscriptions {
		if sub.SubscriptionChannel != subChan {
			kept = append(kept, sub)
		}
	}
	// Clear the vacated tail so the backing array does not keep the removed
	// channels and filters reachable.
	for i := len(kept); i < len(s.subscriptions); i++ {
		s.subscriptions[i] = Subscription{}
	}
	s.subscriptions = kept
}

// Subscribe registers subChan for topicFilter. The subscription is stored on
// the node addressed by topicFilter.Topic.Fields, relative to s; missing nodes
// are created.
func (s *SubscriptionTreeNode) Subscribe(topicFilter packets.TopicFilter, subChan SubscriptionChannel) {
	sub := Subscription{subChan, topicFilter}
	node := s.findNode(topicFilter.Topic.Fields)

	node.nodeLock.Lock()
	defer node.nodeLock.Unlock()
	node.subscriptions = append(node.subscriptions, sub)
}

// Unsubscribe removes all subscriptions of subChan from the node addressed by
// topic.Fields, relative to s.
//
// NOTE: the lookup goes through findNode, so unsubscribing from a topic that
// was never subscribed to still creates the (empty) nodes for it.
func (s *SubscriptionTreeNode) Unsubscribe(topic packets.Topic, subChan SubscriptionChannel) {
	node := s.findNode(topic.Fields)

	node.nodeLock.Lock()
	defer node.nodeLock.Unlock()
	node.removeSubscription(subChan)
}

// RemoveSubsForChannel removes every subscription belonging to subChan from s
// and from all nodes below it.
func (s *SubscriptionTreeNode) RemoveSubsForChannel(subChan SubscriptionChannel) {
	// Clean this node and snapshot its children under the lock: the children
	// map may be written concurrently by findNode, and ranging over it
	// unlocked is a data race.
	s.nodeLock.Lock()
	s.removeSubscription(subChan)
	children := make([]*SubscriptionTreeNode, 0, len(s.children))
	for _, child := range s.children {
		children = append(children, child)
	}
	s.nodeLock.Unlock()

	// Recurse without holding this node's lock, so only one node is locked
	// at a time.
	for _, child := range children {
		child.RemoveSubsForChannel(subChan)
	}
}

// GetSubscriptions returns the subscriptions stored for topicName, which is
// split into fields on "/", together with a locker for the matching node.
//
// The node's read lock is already held when GetSubscriptions returns, and the
// returned slice is the node's own slice, not a copy. The caller must call
// Unlock on the returned locker once it is done with the slice.
//
// NOTE: the lookup goes through findNode, so asking for a topic nobody
// subscribed to creates the (empty) nodes for it.
func (s *SubscriptionTreeNode) GetSubscriptions(topicName string) ([]Subscription, sync.Locker) {
	fields := strings.Split(topicName, "/")
	node := s.findNode(fields)

	locker := node.nodeLock.RLocker()
	locker.Lock()
	return node.subscriptions, locker
}