From 7cab9720650c9836a8b7cf4c8e7591b066a91523 Mon Sep 17 00:00:00 2001 From: bad Date: Fri, 8 Oct 2021 23:53:37 +0200 Subject: [PATCH] Wildcard subscriptions --- session/session.go | 44 ++++++++--------- subscription/subscription.go | 82 +++++++++++++++++++++---------- subscription/subscription_test.go | 46 ++++++++++++++--- 3 files changed, 117 insertions(+), 55 deletions(-) diff --git a/session/session.go b/session/session.go index a441774..d4d4f8b 100644 --- a/session/session.go +++ b/session/session.go @@ -159,30 +159,30 @@ func (s *Session) VisitConnect(_ packets.ConnectPacket) { } func (s *Session) VisitPublish(p packets.PublishPacket) { - subs, lock := subscription.Subscriptions.GetSubscriptions(p.TopicName) - defer lock.Unlock() - if p.QOSLevel == 0 { - if p.PacketId != nil { - log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID) - return + subNodes := subscription.Subscriptions.GetSubscriptions(p.TopicName) + for _, subNode := range subNodes { + subNode.NodeLock.RLock() + defer subNode.NodeLock.RUnlock() + if p.QOSLevel == 0 { + if p.PacketId != nil { + log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID) + return + } + } else if p.QOSLevel == 1 { + var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess + ack := packets.PubackPacket{ + PacketID: *p.PacketId, + Reason: reason, + } + s.Connection.sendPacket(ack) + } else if p.QOSLevel == 2 { + panic("UNIMPLEMENTED QOS level 2") } - } else if p.QOSLevel == 1 { - var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess - if len(subs) == 0 { - reason = packets.PubackReasonCodeNoMatchingSubscribers - } - ack := packets.PubackPacket{ - PacketID: *p.PacketId, - Reason: reason, - } - s.Connection.sendPacket(ack) - } else if p.QOSLevel == 2 { - panic("UNIMPLEMENTED QOS level 2") - } - for _, sub := range subs { - if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) { - go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub) + for _, sub := range subNode.Subscriptions { + if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) { + go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub) + } } } } diff --git a/subscription/subscription.go b/subscription/subscription.go index 191f859..aaeb205 100644 --- a/subscription/subscription.go +++ b/subscription/subscription.go @@ -19,9 +19,9 @@ type Subscription struct { } type SubscriptionTreeNode struct { - subscriptions []Subscription children map[string]*SubscriptionTreeNode - nodeLock sync.RWMutex + Subscriptions []Subscription + NodeLock sync.RWMutex } func newSubscriptionTreeNode() *SubscriptionTreeNode { @@ -37,34 +37,36 @@ func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode { field := fields[0] - s.nodeLock.RLock() + s.NodeLock.RLock() _, exists := s.children[field] // Insert a value into the map if one doesn't exist yet if !exists { // Can't upgrade a read lock so we need to unlock and // check again, this time with a write lock - s.nodeLock.RUnlock() - s.nodeLock.Lock() + s.NodeLock.RUnlock() + s.NodeLock.Lock() _, exists = s.children[field] if !exists { s.children[field] = newSubscriptionTreeNode() } - s.nodeLock.Unlock() - s.nodeLock.RLock() + s.NodeLock.Unlock() + s.NodeLock.RLock() } child, _ := s.children[field] - s.nodeLock.RUnlock() + s.NodeLock.RUnlock() return child.findNode(fields[1:]) } + + func (s *SubscriptionTreeNode) removeSubscription(subChan SubscriptionChannel) { - for i, sub := range s.subscriptions { + for i, sub := range s.Subscriptions { if sub.SubscriptionChannel == subChan { - lst := len(s.subscriptions) - 1 - s.subscriptions[i] = s.subscriptions[lst] - s.subscriptions = s.subscriptions[:lst] + lst := len(s.Subscriptions) - 1 + s.Subscriptions[i] = s.Subscriptions[lst] + s.Subscriptions = s.Subscriptions[:lst] } } } @@ -73,34 +75,64 @@ func (s *SubscriptionTreeNode) Subscribe(topicFilter packets.TopicFilter, subCha sub := Subscription{subChan, topicFilter} node := s.findNode(topicFilter.Topic.Fields) - node.nodeLock.Lock() - node.subscriptions = append(node.subscriptions, sub) - node.nodeLock.Unlock() + node.NodeLock.Lock() + node.Subscriptions = append(node.Subscriptions, sub) + node.NodeLock.Unlock() } func (s *SubscriptionTreeNode) Unsubscribe(topic packets.Topic, subChan SubscriptionChannel) { node := s.findNode(topic.Fields) - node.nodeLock.Lock() + node.NodeLock.Lock() node.removeSubscription(subChan) - node.nodeLock.Unlock() + node.NodeLock.Unlock() } func (s *SubscriptionTreeNode) RemoveSubsForChannel(subChan SubscriptionChannel) { for _, node := range s.children { - node.nodeLock.Lock() + node.NodeLock.Lock() node.removeSubscription(subChan) - node.nodeLock.Unlock() + node.NodeLock.Unlock() node.RemoveSubsForChannel(subChan) } } -func (s *SubscriptionTreeNode) GetSubscriptions(topicName string) ([]Subscription, sync.Locker) { +// Returns the subscriptions whose filters match the given topic name +func (s *SubscriptionTreeNode) GetSubscriptions(topicName string) []*SubscriptionTreeNode { fields := strings.Split(topicName, "/") - - child := s.findNode(fields) - locker := child.nodeLock.RLocker() - locker.Lock() - return child.subscriptions, locker + return s.matchSubscriptions(fields) +} + +// Returns nodes with subscriptions that match the given topic +func (s *SubscriptionTreeNode) matchSubscriptions(fields []string) []*SubscriptionTreeNode { + if len(fields) == 0 { + return []*SubscriptionTreeNode{s} + } + + sub := make([]*SubscriptionTreeNode, 0) + + s.NodeLock.RLock() + + // Single level wildcard(+) + if SlWildcard, exists := s.children["+"]; exists { + sub = append(sub, SlWildcard.matchSubscriptions(fields[1:])...) + } + + // Multi level wildcard(#) + if MlWildcard, exists := s.children["#"]; exists { + sub = append(sub, MlWildcard) + } + + field := fields[0] + if field == "#" || field == "+" { + // TODO handle gracefully + panic("Wildcard in topic") + } + + if child, exists := s.children[field]; exists { + sub = append(sub, child.matchSubscriptions(fields[1:])...) + } + s.NodeLock.RUnlock() + return sub } diff --git a/subscription/subscription_test.go b/subscription/subscription_test.go index dfc442b..9849b82 100644 --- a/subscription/subscription_test.go +++ b/subscription/subscription_test.go @@ -6,22 +6,52 @@ import ( "badat.dev/maeqtt/v2/mqtt/packets" ) -func TestSubscribe(t *testing.T) { +func assertMatches(topic packets.Topic, topicName string, shouldMatch bool, t *testing.T) { tree := newSubscriptionTreeNode() - topic, _ := packets.ParseTopic("a/b/c") channel := make(SubscriptionChannel) topicFilter := packets.TopicFilter{ Topic: topic, MaxQoS: 1, } tree.Subscribe(topicFilter, channel) - subs, lock := tree.GetSubscriptions("a/b/c") - defer lock.Unlock() + subs := tree.GetSubscriptions(topicName) - if len(subs) != 1 { - t.Errorf("Error storing subscriptions, expected to len(subs) to be 1, got: %v \n", len(subs)) + if (len(subs) != 1) && shouldMatch { + t.Errorf("Topic %v did not match %v", topic, topicName) } - if subs[0].MaxQoS != topicFilter.MaxQoS || subs[0].SubscriptionChannel != channel { - t.Error("Error with data stored in a subscription") + if (len(subs) == 1) && !shouldMatch { + t.Errorf("Topic %v matched %v (it wasn't supposed to)", topic, topicName ) } } + +func TestSubscribe(t *testing.T) { + topic, _ := packets.ParseTopic("a/b/c") + assertMatches(topic, "a/b/c", true, t) + assertMatches(topic, "a/c/c", false, t) + assertMatches(topic, "b/b/c", false, t) + assertMatches(topic, "aaa/c/a", false, t) +} + +func TestSingleLevelWildcard(t *testing.T) { + topic, _ := packets.ParseTopic("a/+/c") + assertMatches(topic, "a/b/c", true, t) + assertMatches(topic, "a/c/c", true, t) + assertMatches(topic, "a/b/d", false, t) + assertMatches(topic, "aaa/c/a", false, t) + + topic, _ = packets.ParseTopic("+/+/+") + assertMatches(topic, "a/b/c", true, t) + assertMatches(topic, "a/c/c", true, t) + assertMatches(topic, "a/b/d/e", false, t) + assertMatches(topic, "c/a", true, t) +} + +func TestMultiLevelWildcard(t *testing.T) { + topic, _ := packets.ParseTopic("a/b/c/#") + assertMatches(topic, "a/b/c", true, t) + assertMatches(topic, "a/b/c/a", true, t) + assertMatches(topic, "a/b/c/d", true, t) + assertMatches(topic, "a/b/c/f", true, t) + + assertMatches(topic, "a/b/d/a", false, t) +}