136 lines
3.3 KiB
Go
136 lines
3.3 KiB
Go
package subscription
|
|
|
|
import (
|
|
"strings"
|
|
"sync"
|
|
|
|
"badat.dev/maeqtt/v2/mqtt/packets"
|
|
)
|
|
|
|
// Subscriptions is the package-level root of the subscription tree.
//
// It is built with a composite literal instead of dereferencing
// newSubscriptionTreeNode(): SubscriptionTreeNode contains a sync.RWMutex, and
// copying a value that holds a lock is reported by go vet's copylocks check.
// Keep the initialised fields in sync with newSubscriptionTreeNode.
var Subscriptions = SubscriptionTreeNode{children: make(map[string]*SubscriptionTreeNode)}
|
|
|
|
type SubscriptionChannel chan packets.PublishPacket
|
|
|
|
type Subscription struct {
|
|
SubscriptionChannel
|
|
packets.TopicFilter
|
|
}
|
|
|
|
type SubscriptionTreeNode struct {
|
|
children map[string]*SubscriptionTreeNode
|
|
Subscriptions []Subscription
|
|
NodeLock sync.RWMutex
|
|
}
|
|
|
|
func newSubscriptionTreeNode() *SubscriptionTreeNode {
|
|
s := SubscriptionTreeNode{}
|
|
s.children = make(map[string]*SubscriptionTreeNode)
|
|
return &s
|
|
}
|
|
|
|
func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode {
|
|
if len(fields) == 0 {
|
|
return s
|
|
}
|
|
|
|
field := fields[0]
|
|
|
|
s.NodeLock.RLock()
|
|
_, exists := s.children[field]
|
|
// Insert a value into the map if one doesn't exist yet
|
|
if !exists {
|
|
// Can't upgrade a read lock so we need to unlock and
|
|
// check again, this time with a write lock
|
|
s.NodeLock.RUnlock()
|
|
s.NodeLock.Lock()
|
|
|
|
_, exists = s.children[field]
|
|
if !exists {
|
|
s.children[field] = newSubscriptionTreeNode()
|
|
}
|
|
s.NodeLock.Unlock()
|
|
s.NodeLock.RLock()
|
|
}
|
|
|
|
child, _ := s.children[field]
|
|
s.NodeLock.RUnlock()
|
|
return child.findNode(fields[1:])
|
|
}
|
|
|
|
|
|
|
|
func (s *SubscriptionTreeNode) removeSubscription(subChan SubscriptionChannel) {
|
|
for i, sub := range s.Subscriptions {
|
|
if sub.SubscriptionChannel == subChan {
|
|
lst := len(s.Subscriptions) - 1
|
|
s.Subscriptions[i] = s.Subscriptions[lst]
|
|
s.Subscriptions = s.Subscriptions[:lst]
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *SubscriptionTreeNode) Subscribe(topicFilter packets.TopicFilter, subChan SubscriptionChannel) {
|
|
sub := Subscription{subChan, topicFilter}
|
|
|
|
node := s.findNode(topicFilter.Topic.Fields)
|
|
node.NodeLock.Lock()
|
|
node.Subscriptions = append(node.Subscriptions, sub)
|
|
node.NodeLock.Unlock()
|
|
}
|
|
|
|
func (s *SubscriptionTreeNode) Unsubscribe(topic packets.Topic, subChan SubscriptionChannel) {
|
|
node := s.findNode(topic.Fields)
|
|
|
|
node.NodeLock.Lock()
|
|
node.removeSubscription(subChan)
|
|
node.NodeLock.Unlock()
|
|
}
|
|
|
|
func (s *SubscriptionTreeNode) RemoveSubsForChannel(subChan SubscriptionChannel) {
|
|
for _, node := range s.children {
|
|
node.NodeLock.Lock()
|
|
node.removeSubscription(subChan)
|
|
node.NodeLock.Unlock()
|
|
|
|
node.RemoveSubsForChannel(subChan)
|
|
}
|
|
}
|
|
|
|
// Returns the subscriptions whose filters match the given topic name
|
|
func (s *SubscriptionTreeNode) GetSubscriptions(topicName string) []*SubscriptionTreeNode {
|
|
fields := strings.Split(topicName, "/")
|
|
return s.matchSubscriptions(fields)
|
|
}
|
|
|
|
// Returns nodes with subscriptions that match the given topic
|
|
func (s *SubscriptionTreeNode) matchSubscriptions(fields []string) []*SubscriptionTreeNode {
|
|
if len(fields) == 0 {
|
|
return []*SubscriptionTreeNode{s}
|
|
}
|
|
|
|
sub := make([]*SubscriptionTreeNode, 0)
|
|
|
|
s.NodeLock.RLock()
|
|
|
|
// Single level wildcard(+)
|
|
if SlWildcard, exists := s.children["+"]; exists {
|
|
sub = append(sub, SlWildcard.matchSubscriptions(fields[1:])...)
|
|
}
|
|
|
|
// Multi level wildcard(#)
|
|
if MlWildcard, exists := s.children["#"]; exists {
|
|
sub = append(sub, MlWildcard)
|
|
}
|
|
|
|
field := fields[0]
|
|
if field == "#" || field == "+" {
|
|
// TODO handle gracefully
|
|
panic("Wildcard in topic")
|
|
}
|
|
|
|
if child, exists := s.children[field]; exists {
|
|
sub = append(sub, child.matchSubscriptions(fields[1:])...)
|
|
}
|
|
s.NodeLock.RUnlock()
|
|
return sub
|
|
}
|