Compare commits
No commits in common. "7cab9720650c9836a8b7cf4c8e7591b066a91523" and "01df3272b5c571f1c544bca1a4aced5f86088604" have entirely different histories.
7cab972065
...
01df3272b5
5 changed files with 63 additions and 127 deletions
2
main.go
2
main.go
|
@ -38,7 +38,7 @@ func handleConnection(con net.Conn, sessions map[string]*session.Session) {
|
|||
}
|
||||
|
||||
var sess *session.Session
|
||||
if conReq.ConnectPakcet.ClientId != nil {
|
||||
if(conReq.ConnectPakcet.ClientId != nil) {
|
||||
sess, exists := sessions[*conReq.ConnectPakcet.ClientId]
|
||||
if exists {
|
||||
sess.ConnecionChannel <- conReq
|
||||
|
|
|
@ -64,7 +64,7 @@ func parsePublishPacket(control controlPacket) (PublishPacket, error) {
|
|||
|
||||
func (p PublishPacket) Write(w io.Writer) error {
|
||||
buf := bytes.NewBuffer([]byte{})
|
||||
|
||||
|
||||
err := types.WriteUTF8String(buf, p.TopicName)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -77,6 +77,7 @@ func (p PublishPacket) Write(w io.Writer) error {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
err = properties.WriteProps(buf, p.Properties.ArrayOf())
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -26,7 +26,7 @@ type Session struct {
|
|||
// Nullable
|
||||
Connection *Connection
|
||||
SubscriptionChannel chan packets.PublishPacket
|
||||
ConnecionChannel chan ConnectionRequest
|
||||
ConnecionChannel chan ConnectionRequest
|
||||
|
||||
ExpiryInterval time.Duration // TODO
|
||||
expireTimer time.Timer // TODO
|
||||
|
@ -35,7 +35,7 @@ type Session struct {
|
|||
}
|
||||
|
||||
type ConnectionRequest struct {
|
||||
Connection *Connection
|
||||
Connection *Connection
|
||||
ConnectPakcet packets.ConnectPacket
|
||||
}
|
||||
|
||||
|
@ -89,14 +89,11 @@ func (s *Session) HandlerLoop() {
|
|||
case c := <-s.ConnecionChannel:
|
||||
s.Connect(c)
|
||||
case subMessage := <-s.SubscriptionChannel:
|
||||
subMessage.QOSLevel = 0
|
||||
subMessage.Dup = false
|
||||
s.Connection.sendPacket(subMessage)
|
||||
subMessage.QOSLevel = 0
|
||||
subMessage.Dup = false
|
||||
s.Connection.sendPacket(subMessage)
|
||||
}
|
||||
}
|
||||
c := <-s.ConnecionChannel
|
||||
s.Connect(c)
|
||||
s.HandlerLoop()
|
||||
}
|
||||
|
||||
func (s *Session) Disconnect(code packets.DisconnectReasonCode) error {
|
||||
|
@ -159,30 +156,30 @@ func (s *Session) VisitConnect(_ packets.ConnectPacket) {
|
|||
}
|
||||
|
||||
func (s *Session) VisitPublish(p packets.PublishPacket) {
|
||||
subNodes := subscription.Subscriptions.GetSubscriptions(p.TopicName)
|
||||
for _, subNode := range subNodes {
|
||||
subNode.NodeLock.RLock()
|
||||
defer subNode.NodeLock.RUnlock()
|
||||
if p.QOSLevel == 0 {
|
||||
if p.PacketId != nil {
|
||||
log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID)
|
||||
return
|
||||
}
|
||||
} else if p.QOSLevel == 1 {
|
||||
var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess
|
||||
ack := packets.PubackPacket{
|
||||
PacketID: *p.PacketId,
|
||||
Reason: reason,
|
||||
}
|
||||
s.Connection.sendPacket(ack)
|
||||
} else if p.QOSLevel == 2 {
|
||||
panic("UNIMPLEMENTED QOS level 2")
|
||||
subs, lock := subscription.Subscriptions.GetSubscriptions(p.TopicName)
|
||||
defer lock.Unlock()
|
||||
if p.QOSLevel == 0 {
|
||||
if p.PacketId != nil {
|
||||
log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID)
|
||||
return
|
||||
}
|
||||
} else if p.QOSLevel == 1 {
|
||||
var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess
|
||||
if len(subs) == 0 {
|
||||
reason = packets.PubackReasonCodeNoMatchingSubscribers
|
||||
}
|
||||
ack := packets.PubackPacket{
|
||||
PacketID: *p.PacketId,
|
||||
Reason: reason,
|
||||
}
|
||||
s.Connection.sendPacket(ack)
|
||||
} else if p.QOSLevel == 2 {
|
||||
panic("UNIMPLEMENTED QOS level 2")
|
||||
}
|
||||
|
||||
for _, sub := range subNode.Subscriptions {
|
||||
if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) {
|
||||
go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub)
|
||||
}
|
||||
for _, sub := range subs {
|
||||
if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) {
|
||||
go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,9 +19,9 @@ type Subscription struct {
|
|||
}
|
||||
|
||||
type SubscriptionTreeNode struct {
|
||||
subscriptions []Subscription
|
||||
children map[string]*SubscriptionTreeNode
|
||||
Subscriptions []Subscription
|
||||
NodeLock sync.RWMutex
|
||||
nodeLock sync.RWMutex
|
||||
}
|
||||
|
||||
func newSubscriptionTreeNode() *SubscriptionTreeNode {
|
||||
|
@ -37,36 +37,34 @@ func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode {
|
|||
|
||||
field := fields[0]
|
||||
|
||||
s.NodeLock.RLock()
|
||||
s.nodeLock.RLock()
|
||||
_, exists := s.children[field]
|
||||
// Insert a value into the map if one doesn't exist yet
|
||||
if !exists {
|
||||
// Can't upgrade a read lock so we need to unlock and
|
||||
// check again, this time with a write lock
|
||||
s.NodeLock.RUnlock()
|
||||
s.NodeLock.Lock()
|
||||
s.nodeLock.RUnlock()
|
||||
s.nodeLock.Lock()
|
||||
|
||||
_, exists = s.children[field]
|
||||
if !exists {
|
||||
s.children[field] = newSubscriptionTreeNode()
|
||||
}
|
||||
s.NodeLock.Unlock()
|
||||
s.NodeLock.RLock()
|
||||
s.nodeLock.Unlock()
|
||||
s.nodeLock.RLock()
|
||||
}
|
||||
|
||||
child, _ := s.children[field]
|
||||
s.NodeLock.RUnlock()
|
||||
s.nodeLock.RUnlock()
|
||||
return child.findNode(fields[1:])
|
||||
}
|
||||
|
||||
|
||||
|
||||
func (s *SubscriptionTreeNode) removeSubscription(subChan SubscriptionChannel) {
|
||||
for i, sub := range s.Subscriptions {
|
||||
for i, sub := range s.subscriptions {
|
||||
if sub.SubscriptionChannel == subChan {
|
||||
lst := len(s.Subscriptions) - 1
|
||||
s.Subscriptions[i] = s.Subscriptions[lst]
|
||||
s.Subscriptions = s.Subscriptions[:lst]
|
||||
lst := len(s.subscriptions) - 1
|
||||
s.subscriptions[i] = s.subscriptions[lst]
|
||||
s.subscriptions = s.subscriptions[:lst]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -75,64 +73,34 @@ func (s *SubscriptionTreeNode) Subscribe(topicFilter packets.TopicFilter, subCha
|
|||
sub := Subscription{subChan, topicFilter}
|
||||
|
||||
node := s.findNode(topicFilter.Topic.Fields)
|
||||
node.NodeLock.Lock()
|
||||
node.Subscriptions = append(node.Subscriptions, sub)
|
||||
node.NodeLock.Unlock()
|
||||
node.nodeLock.Lock()
|
||||
node.subscriptions = append(node.subscriptions, sub)
|
||||
node.nodeLock.Unlock()
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) Unsubscribe(topic packets.Topic, subChan SubscriptionChannel) {
|
||||
node := s.findNode(topic.Fields)
|
||||
|
||||
node.NodeLock.Lock()
|
||||
node.nodeLock.Lock()
|
||||
node.removeSubscription(subChan)
|
||||
node.NodeLock.Unlock()
|
||||
node.nodeLock.Unlock()
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) RemoveSubsForChannel(subChan SubscriptionChannel) {
|
||||
for _, node := range s.children {
|
||||
node.NodeLock.Lock()
|
||||
node.nodeLock.Lock()
|
||||
node.removeSubscription(subChan)
|
||||
node.NodeLock.Unlock()
|
||||
node.nodeLock.Unlock()
|
||||
|
||||
node.RemoveSubsForChannel(subChan)
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the subscriptions whose filters match the given topic name
|
||||
func (s *SubscriptionTreeNode) GetSubscriptions(topicName string) []*SubscriptionTreeNode {
|
||||
func (s *SubscriptionTreeNode) GetSubscriptions(topicName string) ([]Subscription, sync.Locker) {
|
||||
fields := strings.Split(topicName, "/")
|
||||
return s.matchSubscriptions(fields)
|
||||
}
|
||||
|
||||
// Returns nodes with subscriptions that match the given topic
|
||||
func (s *SubscriptionTreeNode) matchSubscriptions(fields []string) []*SubscriptionTreeNode {
|
||||
if len(fields) == 0 {
|
||||
return []*SubscriptionTreeNode{s}
|
||||
}
|
||||
|
||||
sub := make([]*SubscriptionTreeNode, 0)
|
||||
|
||||
s.NodeLock.RLock()
|
||||
|
||||
// Single level wildcard(+)
|
||||
if SlWildcard, exists := s.children["+"]; exists {
|
||||
sub = append(sub, SlWildcard.matchSubscriptions(fields[1:])...)
|
||||
}
|
||||
|
||||
// Multi level wildcard(#)
|
||||
if MlWildcard, exists := s.children["#"]; exists {
|
||||
sub = append(sub, MlWildcard)
|
||||
}
|
||||
|
||||
field := fields[0]
|
||||
if field == "#" || field == "+" {
|
||||
// TODO handle gracefully
|
||||
panic("Wildcard in topic")
|
||||
}
|
||||
|
||||
if child, exists := s.children[field]; exists {
|
||||
sub = append(sub, child.matchSubscriptions(fields[1:])...)
|
||||
}
|
||||
s.NodeLock.RUnlock()
|
||||
return sub
|
||||
|
||||
child := s.findNode(fields)
|
||||
locker := child.nodeLock.RLocker()
|
||||
locker.Lock()
|
||||
return child.subscriptions, locker
|
||||
}
|
||||
|
|
|
@ -6,52 +6,22 @@ import (
|
|||
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||
)
|
||||
|
||||
func assertMatches(topic packets.Topic, topicName string, shouldMatch bool, t *testing.T) {
|
||||
func TestSubscribe(t *testing.T) {
|
||||
tree := newSubscriptionTreeNode()
|
||||
topic, _ := packets.ParseTopic("a/b/c")
|
||||
channel := make(SubscriptionChannel)
|
||||
topicFilter := packets.TopicFilter{
|
||||
Topic: topic,
|
||||
MaxQoS: 1,
|
||||
}
|
||||
tree.Subscribe(topicFilter, channel)
|
||||
subs := tree.GetSubscriptions(topicName)
|
||||
subs, lock := tree.GetSubscriptions("a/b/c")
|
||||
defer lock.Unlock()
|
||||
|
||||
if (len(subs) != 1) && shouldMatch {
|
||||
t.Errorf("Topic %v did not match %v", topic, topicName)
|
||||
if len(subs) != 1 {
|
||||
t.Errorf("Error storing subscriptions, expected to len(subs) to be 1, got: %v \n", len(subs))
|
||||
}
|
||||
if (len(subs) == 1) && !shouldMatch {
|
||||
t.Errorf("Topic %v matched %v (it wasn't supposed to)", topic, topicName )
|
||||
if subs[0].MaxQoS != topicFilter.MaxQoS || subs[0].SubscriptionChannel != channel {
|
||||
t.Error("Error with data stored in a subscription")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscribe(t *testing.T) {
|
||||
topic, _ := packets.ParseTopic("a/b/c")
|
||||
assertMatches(topic, "a/b/c", true, t)
|
||||
assertMatches(topic, "a/c/c", false, t)
|
||||
assertMatches(topic, "b/b/c", false, t)
|
||||
assertMatches(topic, "aaa/c/a", false, t)
|
||||
}
|
||||
|
||||
func TestSingleLevelWildcard(t *testing.T) {
|
||||
topic, _ := packets.ParseTopic("a/+/c")
|
||||
assertMatches(topic, "a/b/c", true, t)
|
||||
assertMatches(topic, "a/c/c", true, t)
|
||||
assertMatches(topic, "a/b/d", false, t)
|
||||
assertMatches(topic, "aaa/c/a", false, t)
|
||||
|
||||
topic, _ = packets.ParseTopic("+/+/+")
|
||||
assertMatches(topic, "a/b/c", true, t)
|
||||
assertMatches(topic, "a/c/c", true, t)
|
||||
assertMatches(topic, "a/b/d/e", false, t)
|
||||
assertMatches(topic, "c/a", true, t)
|
||||
}
|
||||
|
||||
func TestMultiLevelWildcard(t *testing.T) {
|
||||
topic, _ := packets.ParseTopic("a/b/c/#")
|
||||
assertMatches(topic, "a/b/c", true, t)
|
||||
assertMatches(topic, "a/b/c/a", true, t)
|
||||
assertMatches(topic, "a/b/c/d", true, t)
|
||||
assertMatches(topic, "a/b/c/f", true, t)
|
||||
|
||||
assertMatches(topic, "a/b/d/a", false, t)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue