Wildcard subscriptions

This commit is contained in:
bad 2021-10-08 23:53:37 +02:00
parent a4eaf5f8f7
commit 7cab972065
3 changed files with 117 additions and 55 deletions

View file

@ -159,8 +159,10 @@ func (s *Session) VisitConnect(_ packets.ConnectPacket) {
} }
func (s *Session) VisitPublish(p packets.PublishPacket) { func (s *Session) VisitPublish(p packets.PublishPacket) {
subs, lock := subscription.Subscriptions.GetSubscriptions(p.TopicName) subNodes := subscription.Subscriptions.GetSubscriptions(p.TopicName)
defer lock.Unlock() for _, subNode := range subNodes {
subNode.NodeLock.RLock()
defer subNode.NodeLock.RUnlock()
if p.QOSLevel == 0 { if p.QOSLevel == 0 {
if p.PacketId != nil { if p.PacketId != nil {
log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID) log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID)
@ -168,9 +170,6 @@ func (s *Session) VisitPublish(p packets.PublishPacket) {
} }
} else if p.QOSLevel == 1 { } else if p.QOSLevel == 1 {
var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess
if len(subs) == 0 {
reason = packets.PubackReasonCodeNoMatchingSubscribers
}
ack := packets.PubackPacket{ ack := packets.PubackPacket{
PacketID: *p.PacketId, PacketID: *p.PacketId,
Reason: reason, Reason: reason,
@ -180,12 +179,13 @@ func (s *Session) VisitPublish(p packets.PublishPacket) {
panic("UNIMPLEMENTED QOS level 2") panic("UNIMPLEMENTED QOS level 2")
} }
for _, sub := range subs { for _, sub := range subNode.Subscriptions {
if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) { if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) {
go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub) go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub)
} }
} }
} }
}
func (s *Session) VisitDisconnect(p packets.DisconnectPacket) { func (s *Session) VisitDisconnect(p packets.DisconnectPacket) {
err := s.Connection.close() err := s.Connection.close()

View file

@ -19,9 +19,9 @@ type Subscription struct {
} }
type SubscriptionTreeNode struct { type SubscriptionTreeNode struct {
subscriptions []Subscription
children map[string]*SubscriptionTreeNode children map[string]*SubscriptionTreeNode
nodeLock sync.RWMutex Subscriptions []Subscription
NodeLock sync.RWMutex
} }
func newSubscriptionTreeNode() *SubscriptionTreeNode { func newSubscriptionTreeNode() *SubscriptionTreeNode {
@ -37,34 +37,36 @@ func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode {
field := fields[0] field := fields[0]
s.nodeLock.RLock() s.NodeLock.RLock()
_, exists := s.children[field] _, exists := s.children[field]
// Insert a value into the map if one doesn't exist yet // Insert a value into the map if one doesn't exist yet
if !exists { if !exists {
// Can't upgrade a read lock so we need to unlock and // Can't upgrade a read lock so we need to unlock and
// check again, this time with a write lock // check again, this time with a write lock
s.nodeLock.RUnlock() s.NodeLock.RUnlock()
s.nodeLock.Lock() s.NodeLock.Lock()
_, exists = s.children[field] _, exists = s.children[field]
if !exists { if !exists {
s.children[field] = newSubscriptionTreeNode() s.children[field] = newSubscriptionTreeNode()
} }
s.nodeLock.Unlock() s.NodeLock.Unlock()
s.nodeLock.RLock() s.NodeLock.RLock()
} }
child, _ := s.children[field] child, _ := s.children[field]
s.nodeLock.RUnlock() s.NodeLock.RUnlock()
return child.findNode(fields[1:]) return child.findNode(fields[1:])
} }
func (s *SubscriptionTreeNode) removeSubscription(subChan SubscriptionChannel) { func (s *SubscriptionTreeNode) removeSubscription(subChan SubscriptionChannel) {
for i, sub := range s.subscriptions { for i, sub := range s.Subscriptions {
if sub.SubscriptionChannel == subChan { if sub.SubscriptionChannel == subChan {
lst := len(s.subscriptions) - 1 lst := len(s.Subscriptions) - 1
s.subscriptions[i] = s.subscriptions[lst] s.Subscriptions[i] = s.Subscriptions[lst]
s.subscriptions = s.subscriptions[:lst] s.Subscriptions = s.Subscriptions[:lst]
} }
} }
} }
@ -73,34 +75,64 @@ func (s *SubscriptionTreeNode) Subscribe(topicFilter packets.TopicFilter, subCha
sub := Subscription{subChan, topicFilter} sub := Subscription{subChan, topicFilter}
node := s.findNode(topicFilter.Topic.Fields) node := s.findNode(topicFilter.Topic.Fields)
node.nodeLock.Lock() node.NodeLock.Lock()
node.subscriptions = append(node.subscriptions, sub) node.Subscriptions = append(node.Subscriptions, sub)
node.nodeLock.Unlock() node.NodeLock.Unlock()
} }
func (s *SubscriptionTreeNode) Unsubscribe(topic packets.Topic, subChan SubscriptionChannel) { func (s *SubscriptionTreeNode) Unsubscribe(topic packets.Topic, subChan SubscriptionChannel) {
node := s.findNode(topic.Fields) node := s.findNode(topic.Fields)
node.nodeLock.Lock() node.NodeLock.Lock()
node.removeSubscription(subChan) node.removeSubscription(subChan)
node.nodeLock.Unlock() node.NodeLock.Unlock()
} }
func (s *SubscriptionTreeNode) RemoveSubsForChannel(subChan SubscriptionChannel) { func (s *SubscriptionTreeNode) RemoveSubsForChannel(subChan SubscriptionChannel) {
for _, node := range s.children { for _, node := range s.children {
node.nodeLock.Lock() node.NodeLock.Lock()
node.removeSubscription(subChan) node.removeSubscription(subChan)
node.nodeLock.Unlock() node.NodeLock.Unlock()
node.RemoveSubsForChannel(subChan) node.RemoveSubsForChannel(subChan)
} }
} }
func (s *SubscriptionTreeNode) GetSubscriptions(topicName string) ([]Subscription, sync.Locker) { // Returns the subscriptions whose filters match the given topic name
func (s *SubscriptionTreeNode) GetSubscriptions(topicName string) []*SubscriptionTreeNode {
fields := strings.Split(topicName, "/") fields := strings.Split(topicName, "/")
return s.matchSubscriptions(fields)
child := s.findNode(fields) }
locker := child.nodeLock.RLocker()
locker.Lock() // Returns nodes with subscriptions that match the given topic
return child.subscriptions, locker func (s *SubscriptionTreeNode) matchSubscriptions(fields []string) []*SubscriptionTreeNode {
if len(fields) == 0 {
return []*SubscriptionTreeNode{s}
}
sub := make([]*SubscriptionTreeNode, 0)
s.NodeLock.RLock()
// Single level wildcard(+)
if SlWildcard, exists := s.children["+"]; exists {
sub = append(sub, SlWildcard.matchSubscriptions(fields[1:])...)
}
// Multi level wildcard(#)
if MlWildcard, exists := s.children["#"]; exists {
sub = append(sub, MlWildcard)
}
field := fields[0]
if field == "#" || field == "+" {
// TODO handle gracefully
panic("Wildcard in topic")
}
if child, exists := s.children[field]; exists {
sub = append(sub, child.matchSubscriptions(fields[1:])...)
}
s.NodeLock.RUnlock()
return sub
} }

View file

@ -6,22 +6,52 @@ import (
"badat.dev/maeqtt/v2/mqtt/packets" "badat.dev/maeqtt/v2/mqtt/packets"
) )
func TestSubscribe(t *testing.T) { func assertMatches(topic packets.Topic, topicName string, shouldMatch bool, t *testing.T) {
tree := newSubscriptionTreeNode() tree := newSubscriptionTreeNode()
topic, _ := packets.ParseTopic("a/b/c")
channel := make(SubscriptionChannel) channel := make(SubscriptionChannel)
topicFilter := packets.TopicFilter{ topicFilter := packets.TopicFilter{
Topic: topic, Topic: topic,
MaxQoS: 1, MaxQoS: 1,
} }
tree.Subscribe(topicFilter, channel) tree.Subscribe(topicFilter, channel)
subs, lock := tree.GetSubscriptions("a/b/c") subs := tree.GetSubscriptions(topicName)
defer lock.Unlock()
if len(subs) != 1 { if (len(subs) != 1) && shouldMatch {
t.Errorf("Error storing subscriptions, expected to len(subs) to be 1, got: %v \n", len(subs)) t.Errorf("Topic %v did not match %v", topic, topicName)
} }
if subs[0].MaxQoS != topicFilter.MaxQoS || subs[0].SubscriptionChannel != channel { if (len(subs) == 1) && !shouldMatch {
t.Error("Error with data stored in a subscription") t.Errorf("Topic %v matched %v (it wasn't supposed to)", topic, topicName )
} }
} }
func TestSubscribe(t *testing.T) {
topic, _ := packets.ParseTopic("a/b/c")
assertMatches(topic, "a/b/c", true, t)
assertMatches(topic, "a/c/c", false, t)
assertMatches(topic, "b/b/c", false, t)
assertMatches(topic, "aaa/c/a", false, t)
}
func TestSingleLevelWildcard(t *testing.T) {
topic, _ := packets.ParseTopic("a/+/c")
assertMatches(topic, "a/b/c", true, t)
assertMatches(topic, "a/c/c", true, t)
assertMatches(topic, "a/b/d", false, t)
assertMatches(topic, "aaa/c/a", false, t)
topic, _ = packets.ParseTopic("+/+/+")
assertMatches(topic, "a/b/c", true, t)
assertMatches(topic, "a/c/c", true, t)
assertMatches(topic, "a/b/d/e", false, t)
assertMatches(topic, "c/a", true, t)
}
func TestMultiLevelWildcard(t *testing.T) {
topic, _ := packets.ParseTopic("a/b/c/#")
assertMatches(topic, "a/b/c", true, t)
assertMatches(topic, "a/b/c/a", true, t)
assertMatches(topic, "a/b/c/d", true, t)
assertMatches(topic, "a/b/c/f", true, t)
assertMatches(topic, "a/b/d/a", false, t)
}