maeqtt/subscription/subscription.go

134 lines
3.3 KiB
Go
Raw Normal View History

2021-10-01 22:18:48 +02:00
package subscription
import (
"strings"
"sync"
"badat.dev/maeqtt/v2/mqtt/packets"
)
// Subscriptions is the package-level root of the subscription tree.
//
// It is initialized with a composite literal rather than by dereferencing
// newSubscriptionTreeNode(): the node embeds a sync.RWMutex, and copying a
// value that contains a lock is flagged by `go vet` (copylocks).
var Subscriptions = SubscriptionTreeNode{children: make(map[string]*SubscriptionTreeNode)}
type (
	// SubscriptionChannel is a channel of publish packets belonging to one
	// subscriber. Subscriptions are matched and removed by comparing this
	// channel value.
	SubscriptionChannel chan packets.PublishPacket

	// Subscription pairs a subscriber's channel with the topic filter it
	// subscribed with. Both are embedded, so their fields and methods are
	// promoted onto Subscription.
	Subscription struct {
		SubscriptionChannel
		packets.TopicFilter
	}
)
type SubscriptionTreeNode struct {
children map[string]*SubscriptionTreeNode
2021-10-08 23:53:37 +02:00
Subscriptions []Subscription
NodeLock sync.RWMutex
2021-10-01 22:18:48 +02:00
}
// newSubscriptionTreeNode allocates an empty node with a ready-to-use
// children map (writing to a nil map would panic).
func newSubscriptionTreeNode() *SubscriptionTreeNode {
	return &SubscriptionTreeNode{
		children: map[string]*SubscriptionTreeNode{},
	}
}
func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode {
if len(fields) == 0 {
return s
}
field := fields[0]
2021-10-08 23:53:37 +02:00
s.NodeLock.RLock()
2021-10-01 22:18:48 +02:00
_, exists := s.children[field]
// Insert a value into the map if one doesn't exist yet
if !exists {
// Can't upgrade a read lock so we need to unlock and
// check again, this time with a write lock
2021-10-08 23:53:37 +02:00
s.NodeLock.RUnlock()
s.NodeLock.Lock()
2021-10-01 22:18:48 +02:00
_, exists = s.children[field]
if !exists {
s.children[field] = newSubscriptionTreeNode()
}
2021-10-08 23:53:37 +02:00
s.NodeLock.Unlock()
s.NodeLock.RLock()
2021-10-01 22:18:48 +02:00
}
2021-10-17 20:58:16 +02:00
child := s.children[field]
2021-10-08 23:53:37 +02:00
s.NodeLock.RUnlock()
2021-10-01 22:18:48 +02:00
return child.findNode(fields[1:])
}
2021-10-08 23:53:37 +02:00
2021-10-01 22:18:48 +02:00
func (s *SubscriptionTreeNode) removeSubscription(subChan SubscriptionChannel) {
2021-10-08 23:53:37 +02:00
for i, sub := range s.Subscriptions {
2021-10-01 22:18:48 +02:00
if sub.SubscriptionChannel == subChan {
2021-10-08 23:53:37 +02:00
lst := len(s.Subscriptions) - 1
s.Subscriptions[i] = s.Subscriptions[lst]
s.Subscriptions = s.Subscriptions[:lst]
2021-10-01 22:18:48 +02:00
}
}
}
func (s *SubscriptionTreeNode) Subscribe(topicFilter packets.TopicFilter, subChan SubscriptionChannel) {
sub := Subscription{subChan, topicFilter}
node := s.findNode(topicFilter.Topic.Fields)
2021-10-08 23:53:37 +02:00
node.NodeLock.Lock()
node.Subscriptions = append(node.Subscriptions, sub)
node.NodeLock.Unlock()
2021-10-01 22:18:48 +02:00
}
func (s *SubscriptionTreeNode) Unsubscribe(topic packets.Topic, subChan SubscriptionChannel) {
node := s.findNode(topic.Fields)
2021-10-08 23:53:37 +02:00
node.NodeLock.Lock()
2021-10-01 22:18:48 +02:00
node.removeSubscription(subChan)
2021-10-08 23:53:37 +02:00
node.NodeLock.Unlock()
2021-10-01 22:18:48 +02:00
}
func (s *SubscriptionTreeNode) RemoveSubsForChannel(subChan SubscriptionChannel) {
for _, node := range s.children {
2021-10-08 23:53:37 +02:00
node.NodeLock.Lock()
2021-10-01 22:18:48 +02:00
node.removeSubscription(subChan)
2021-10-08 23:53:37 +02:00
node.NodeLock.Unlock()
2021-10-01 22:18:48 +02:00
node.RemoveSubsForChannel(subChan)
}
}
2021-10-08 23:53:37 +02:00
// Returns the subscriptions whose filters match the given topic name
func (s *SubscriptionTreeNode) GetSubscriptions(topicName string) []*SubscriptionTreeNode {
2021-10-01 22:18:48 +02:00
fields := strings.Split(topicName, "/")
2021-10-08 23:53:37 +02:00
return s.matchSubscriptions(fields)
}
// Returns nodes with subscriptions that match the given topic
func (s *SubscriptionTreeNode) matchSubscriptions(fields []string) []*SubscriptionTreeNode {
if len(fields) == 0 {
return []*SubscriptionTreeNode{s}
}
sub := make([]*SubscriptionTreeNode, 0)
s.NodeLock.RLock()
2021-10-01 22:18:48 +02:00
2021-10-08 23:53:37 +02:00
// Single level wildcard(+)
if SlWildcard, exists := s.children["+"]; exists {
sub = append(sub, SlWildcard.matchSubscriptions(fields[1:])...)
}
// Multi level wildcard(#)
if MlWildcard, exists := s.children["#"]; exists {
sub = append(sub, MlWildcard)
}
field := fields[0]
// this goes against the spec but I'm lazy so let's just be sane but not really correct
if child, exists := s.children[field]; exists && field != "+" && field != "#" {
2021-10-08 23:53:37 +02:00
sub = append(sub, child.matchSubscriptions(fields[1:])...)
}
s.NodeLock.RUnlock()
return sub
2021-10-01 22:18:48 +02:00
}