package subscription import ( "strings" "sync" "badat.dev/maeqtt/v2/mqtt/packets" ) var Subscriptions SubscriptionTreeNode = *newSubscriptionTreeNode() type SubscriptionChannel chan packets.PublishPacket type Subscription struct { SubscriptionChannel packets.TopicFilter } type SubscriptionTreeNode struct { children map[string]*SubscriptionTreeNode Subscriptions []Subscription NodeLock sync.RWMutex } func newSubscriptionTreeNode() *SubscriptionTreeNode { s := SubscriptionTreeNode{} s.children = make(map[string]*SubscriptionTreeNode) return &s } func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode { if len(fields) == 0 { return s } field := fields[0] s.NodeLock.RLock() _, exists := s.children[field] // Insert a value into the map if one doesn't exist yet if !exists { // Can't upgrade a read lock so we need to unlock and // check again, this time with a write lock s.NodeLock.RUnlock() s.NodeLock.Lock() _, exists = s.children[field] if !exists { s.children[field] = newSubscriptionTreeNode() } s.NodeLock.Unlock() s.NodeLock.RLock() } child, _ := s.children[field] s.NodeLock.RUnlock() return child.findNode(fields[1:]) } func (s *SubscriptionTreeNode) removeSubscription(subChan SubscriptionChannel) { for i, sub := range s.Subscriptions { if sub.SubscriptionChannel == subChan { lst := len(s.Subscriptions) - 1 s.Subscriptions[i] = s.Subscriptions[lst] s.Subscriptions = s.Subscriptions[:lst] } } } func (s *SubscriptionTreeNode) Subscribe(topicFilter packets.TopicFilter, subChan SubscriptionChannel) { sub := Subscription{subChan, topicFilter} node := s.findNode(topicFilter.Topic.Fields) node.NodeLock.Lock() node.Subscriptions = append(node.Subscriptions, sub) node.NodeLock.Unlock() } func (s *SubscriptionTreeNode) Unsubscribe(topic packets.Topic, subChan SubscriptionChannel) { node := s.findNode(topic.Fields) node.NodeLock.Lock() node.removeSubscription(subChan) node.NodeLock.Unlock() } func (s *SubscriptionTreeNode) RemoveSubsForChannel(subChan SubscriptionChannel) { for _, node := range s.children { node.NodeLock.Lock() node.removeSubscription(subChan) node.NodeLock.Unlock() node.RemoveSubsForChannel(subChan) } } // Returns the subscriptions whose filters match the given topic name func (s *SubscriptionTreeNode) GetSubscriptions(topicName string) []*SubscriptionTreeNode { fields := strings.Split(topicName, "/") return s.matchSubscriptions(fields) } // Returns nodes with subscriptions that match the given topic func (s *SubscriptionTreeNode) matchSubscriptions(fields []string) []*SubscriptionTreeNode { if len(fields) == 0 { return []*SubscriptionTreeNode{s} } sub := make([]*SubscriptionTreeNode, 0) s.NodeLock.RLock() // Single level wildcard(+) if SlWildcard, exists := s.children["+"]; exists { sub = append(sub, SlWildcard.matchSubscriptions(fields[1:])...) } // Multi level wildcard(#) if MlWildcard, exists := s.children["#"]; exists { sub = append(sub, MlWildcard) } field := fields[0] // this goes against the spec but I'm lazy so let's just be sane but not really correct if child, exists := s.children[field]; exists && field != "+" && field != "#" { sub = append(sub, child.matchSubscriptions(fields[1:])...) } s.NodeLock.RUnlock() return sub }