Initial session handling

This commit is contained in:
bad 2021-09-28 12:30:32 +02:00
parent a927a5ac1a
commit e9612e2430
12 changed files with 512 additions and 56 deletions

92
connection.go Normal file
View file

@ -0,0 +1,92 @@
package main
import (
"bufio"
"fmt"
"io"
"time"
"badat.dev/maeqtt/v2/mqtt/packets"
)
type Connection struct {
MaxPacketSize *uint32
RecvMax uint16
TopicAliasMax uint16
WantsRespInf bool
WantsProblemInf bool
Will packets.Will
KeepAliveInterval time.Duration
keepAliveTicker time.Ticker
PacketChannel chan packets.ClientPacket
ClientDisconnectedChan chan bool
rw io.ReadWriteCloser
}
func (c *Connection) resetKeepAlive() {
if c.KeepAliveInterval != 0 {
// TODO IMPLEMENT THIS
//s.keepAliveTicker.Reset(s.KeepAliveInterval)
}
}
func (c *Connection) readPacket() (*packets.ClientPacket, error) {
return packets.ReadPacket(bufio.NewReader(c.rw))
}
func (c *Connection) sendPacket(p packets.ServerPacket) error {
c.resetKeepAlive()
return p.Write(c.rw)
}
func (c *Connection) close() error {
close(c.PacketChannel)
return c.rw.Close()
}
func (c *Connection) packetReadLoop() {
for {
pack, err := c.readPacket()
if err == io.EOF {
c.ClientDisconnectedChan <- true
} else if err != nil {
panic(fmt.Errorf("Unimplemented error handling, %e", err).Error())
} else {
c.PacketChannel <- *pack
}
}
}
func NewConnection(p packets.ConnectPacket, rw io.ReadWriteCloser) Connection {
conn := Connection{}
conn.rw = rw
if p.Properties.ReceiveMaximum.Value != nil {
conn.RecvMax = *p.Properties.ReceiveMaximum.Value
} else {
conn.RecvMax = 65535
}
conn.MaxPacketSize = p.Properties.MaximumPacketSize.Value
if p.Properties.TopicAliasMaximum.Value != nil {
conn.TopicAliasMax = *p.Properties.TopicAliasMaximum.Value
} else {
conn.TopicAliasMax = 0
}
if p.Properties.RequestProblemInformation.Value != nil {
conn.WantsRespInf = *p.Properties.RequestProblemInformation.Value != 0
} else {
conn.WantsRespInf = false
}
conn.KeepAliveInterval = time.Duration(p.KeepAliveInterval) * time.Second
conn.PacketChannel = make(chan packets.ClientPacket)
go conn.packetReadLoop()
return conn
}

8
go.mod
View file

@ -2,4 +2,10 @@ module badat.dev/maeqtt/v2
go 1.16
require github.com/gdexlab/go-render v1.0.1 // indirect
require (
github.com/gdexlab/go-render v1.0.1 // indirect
github.com/josharian/impl v1.1.0 // indirect
golang.org/x/mod v0.5.0 // indirect
golang.org/x/sys v0.0.0-20210915083310-ed5796bab164 // indirect
golang.org/x/tools v0.1.5 // indirect
)

37
go.sum
View file

@ -1,2 +1,39 @@
github.com/gdexlab/go-render v1.0.1 h1:rxqB3vo5s4n1kF0ySmoNeSPRYkEsyHgln4jFIQY7v0U=
github.com/gdexlab/go-render v1.0.1/go.mod h1:wRi5nW2qfjiGj4mPukH4UV0IknS1cHD4VgFTmJX5JzM=
github.com/josharian/impl v1.1.0 h1:gafhg1OFVMq46ifdkBa8wp4hlGogjktjjA5h/2j4+2k=
github.com/josharian/impl v1.1.0/go.mod h1:SQ6aJMP6xsJpGSD/36IIqrUdigLCYe9bz/9o5AKm6Aw=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.5.0 h1:UG21uOlmZabA4fW5i7ZX6bjw1xELEGg/ZLgZq9auk/Q=
golang.org/x/mod v0.5.0/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210915083310-ed5796bab164 h1:7ZDGnxgHAMw7thfC5bEos0RDAccZKxioiWBhfIe+tvw=
golang.org/x/sys v0.0.0-20210915083310-ed5796bab164/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200522201501-cb1345f3a375/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.1.5 h1:ouewzE6p+/VEB31YYnTbEJdi8pFqKp4P4n85vwo3DHA=
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

55
main.go
View file

@ -4,10 +4,9 @@ import (
"bufio"
"log"
"net"
"runtime/debug"
"badat.dev/maeqtt/v2/mqtt/packets"
"badat.dev/maeqtt/v2/mqtt/properties"
"github.com/gdexlab/go-render/render" // For testing
)
func main() {
@ -28,37 +27,31 @@ func main() {
}
func handleConnection(con net.Conn) {
defer closeConnection(con)
defer handlePanic()
for {
reader := bufio.NewReader(con)
packet, err := packets.ReadPacket(reader)
if err != nil {
log.Println("Error reading packet ", err)
break
}
log.Println(render.AsCode(packet))
clientId := "aa"
resp := packets.ConnackPacket{
ResonCode: packets.ConnectReasonCodeSuccess,
SessionPresent: false,
Properties: properties.ConnackPacketProperties{
AssignedClientIdentifier: properties.AssignedClientIdentifier{Value: &clientId},
},
}
err = resp.Write(con)
log.Println("Wrote response")
if err != nil {
log.Println("Error writing response ", err)
break
}
}
reader := bufio.NewReader(con)
}
func closeConnection(con net.Conn) {
err := con.Close()
packet, err := packets.ReadPacket(reader)
if err != nil {
log.Println("Failed to close connection", err)
log.Println("Error reading packet ", err)
return
}
connect, isConn := (*packet).(packets.ConnectPacket)
if !isConn {
log.Println("Didn't recieve a connet packet")
panic("TODO: Send a disconnect packet")
}
conn := NewConnection(connect, con)
sess := NewSession(&conn, connect)
sess.HandlerLoop()
}
func handlePanic() {
if r := recover(); r != nil {
log.Println("Recovering from panic:", r)
log.Println("Stack Trace:")
debug.PrintStack()
}
}

View file

@ -8,8 +8,8 @@ import (
)
type Will struct {
retain bool
properties properties.WillProperties
Retain bool
Properties properties.WillProperties
}
type ConnectPacket struct {
@ -23,8 +23,8 @@ type ConnectPacket struct {
Properties properties.ConnectPacketProperties
}
func (c ConnectPacket) visit(visitor PacketVisitor) {
visitor.visitConnect(c)
func (c ConnectPacket) Visit(Visitor PacketVisitor) {
Visitor.VisitConnect(c)
}
func parseConnectPacket(control controlPacket) (ConnectPacket, error) {
@ -95,11 +95,11 @@ func parseConnectPacket(control controlPacket) (ConnectPacket, error) {
if willFlag {
packet.Will = &Will{}
err = properties.ParseProperties(r, packet.Will.properties.ArrayOf())
err = properties.ParseProperties(r, packet.Will.Properties.ArrayOf())
if err != nil {
return packet, err
}
packet.Will.retain = willRetainFlag
packet.Will.Retain = willRetainFlag
}
var username string

View file

@ -99,6 +99,6 @@ func (p DisconnectPacket) Write(w io.Writer) error {
return control.write(w)
}
func (p DisconnectPacket) visit(v PacketVisitor) {
v.visitDisconnect(p)
func (p DisconnectPacket) Visit(v PacketVisitor) {
v.VisitDisconnect(p)
}

View file

@ -9,7 +9,7 @@ import (
type PingreqPacket struct {}
func ParesPingreq(control controlPacket) (PingreqPacket, error) {
func parsePingreq(control controlPacket) (PingreqPacket, error) {
packet := PingreqPacket{}
if control.packetType != PacketTypePingreq {
@ -22,6 +22,10 @@ func ParesPingreq(control controlPacket) (PingreqPacket, error) {
return packet, nil
}
func (r PingreqPacket) Visit(p PacketVisitor) {
p.VisitPing(r)
}
type PingrespPacket struct {}
func (p PingrespPacket) Write(w io.Writer) error {

View file

@ -19,8 +19,8 @@ type PublishPacket struct {
Properties properties.PublishPacketProperties
}
func (p PublishPacket) visit(v PacketVisitor) {
v.visitPublish(p)
func (p PublishPacket) Visit(v PacketVisitor) {
v.VisitPublish(p)
}
func parsePublishPacket(control controlPacket) (PublishPacket, error) {

View file

@ -4,13 +4,32 @@ import (
"bufio"
"errors"
"io"
"strings"
"badat.dev/maeqtt/v2/mqtt/properties"
"badat.dev/maeqtt/v2/mqtt/types"
)
type Topic struct {
Fields []string
}
var multiLevelWildcardNotLast = errors.New("Multi level wildcard isn't the field in a topic")
func parseTopic(topic_name string) (Topic, error) {
topic := Topic{}
fields := strings.Split(topic_name, "/")
for i, field := range fields {
if field == "#" && len(fields) > i+1 {
return topic, multiLevelWildcardNotLast
}
}
topic.Fields = fields
return topic, nil
}
type TopicFilter struct {
Topic string
Topic Topic
MaxQoS uint
NoLocal bool
RetainAsPublished bool
@ -21,7 +40,12 @@ func parseTopicFilter(r *bufio.Reader) (TopicFilter, error) {
filter := TopicFilter{}
var err error
filter.Topic, err = types.DecodeUTF8String(r)
topic_str, err := types.DecodeUTF8String(r)
if err != nil {
return filter, err
}
filter.Topic, err = parseTopic(topic_str)
if err != nil {
return filter, err
}
@ -69,10 +93,14 @@ func parseSubscriptionPacket(control controlPacket, props []properties.Property)
return packet, err
}
_, err = r.Peek(1)
if err != nil || err != io.EOF {
if err != nil && err != io.EOF {
return packet, err
}
if err == io.EOF {
return packet, nil
}
}
println("A")
return packet, nil
}
@ -82,6 +110,9 @@ type SubscribePacket struct {
props properties.SubscribePacketProperties
}
/// CURRENTLY BROKEN
// TODO FIXME AAAAA
func parseSubscribePacket(control controlPacket) (SubscribePacket, error) {
if control.packetType != PacketTypeSubscribe {
panic("Wrong packet type for parseSubscribePacket")
@ -92,11 +123,14 @@ func parseSubscribePacket(control controlPacket) (SubscribePacket, error) {
if err != nil {
return pack, err
}
pack.PacketId = subscriptionPack.PacketId
pack.TopicFilters = subscriptionPack.TopicFilters
pack.SubscriptionPacket = &subscriptionPack
return pack, nil
}
func (p SubscribePacket) Visit(v PacketVisitor) {
v.VisitSubscribe(p)
}
type SubackReasonCode byte
const (
@ -151,6 +185,10 @@ func parseUnsubscribePacket(control controlPacket) (UnsubscribePacket, error) {
return pack, nil
}
func (p UnsubscribePacket) Visit(v PacketVisitor) {
v.VisitUnsubscribe(p)
}
type UnsubackReasonCode byte
const (

View file

@ -2,7 +2,6 @@ package packets
import (
"bufio"
"bytes"
"fmt"
"io"
@ -10,17 +9,24 @@ import (
)
type PacketVisitor interface {
visitConnect(ConnectPacket)
visitPublish(PublishPacket)
visitDisconnect(DisconnectPacket)
VisitConnect(ConnectPacket)
VisitPublish(PublishPacket)
VisitDisconnect(DisconnectPacket)
VisitSubscribe(SubscribePacket)
VisitUnsubscribe(UnsubscribePacket)
VisitPing(PingreqPacket)
VisitPubackPacket(PubackPacket)
VisitPubrecPacket(PubrecPacket)
VisitPubrelPacket(PubrelPacket)
VisitPubcompPacket(PubcompPacket)
}
type ClientPacket interface {
visit(PacketVisitor)
Visit(PacketVisitor)
}
type ServerPacket interface {
Encode() (bytes.Buffer, error)
Write(w io.Writer) error
}
type PacketType byte
@ -45,12 +51,10 @@ const (
)
func ReadPacket(r *bufio.Reader) (*ClientPacket, error) {
println("AAAA")
fixedHeader, err := r.ReadByte()
if err != nil {
return nil, err
}
println("BBB")
highestFourBits := uint((fixedHeader >> 4) & 0b1111)
lowerFourBits := uint(fixedHeader & 0b1111)
@ -60,6 +64,7 @@ func ReadPacket(r *bufio.Reader) (*ClientPacket, error) {
return nil, err
}
reader := io.LimitReader(r, int64(dataLength))
control := controlPacket{
packetType: PacketType(highestFourBits),
flags: lowerFourBits,
@ -74,6 +79,20 @@ func ReadPacket(r *bufio.Reader) (*ClientPacket, error) {
packet, err = parsePublishPacket(control)
case PacketTypeDisconnect:
packet, err = parseDisconnectPacket(control)
case PacketTypeSubscribe:
packet, err = parseSubscribePacket(control)
case PacketTypeUnsubscribe:
packet, err = parseUnsubscribePacket(control)
case PacketTypePingreq:
packet, err = parsePingreq(control)
case PacketTypePuback:
panic("Puback packet parsing unimplemented")
case PacketTypePubrec:
panic("Pubrec packet parsing unimplemented")
case PacketTypePubrel:
panic("Pubrel packet parsing unimplemented")
case PacketTypePubcomp:
panic("Pubcomp packet parsing unimplemented")
default:
return nil, fmt.Errorf("Unknown packet type %v", control.packetType)
}

181
session.go Normal file
View file

@ -0,0 +1,181 @@
package main
import (
"encoding/base64"
"fmt"
"log"
"math/rand"
"time"
"badat.dev/maeqtt/v2/mqtt/packets"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
func Auth(username string, password []byte) bool {
return true
}
type Session struct {
ClientID *string
// Nullable
Connection *Connection
SubscriptionChannel chan packets.PublishPacket
ExpiryInterval time.Duration
expireTimer time.Timer // TODO
}
func NewSession(conn *Connection, p packets.ConnectPacket) Session {
sess := Session{}
sess.SubscriptionChannel = make(chan packets.PublishPacket)
sess.Connect(conn, p)
return sess
}
func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) {
if s.Connection != nil {
//TODO
panic("Disconnect if already have a connection, unimplemented")
}
connAck := packets.ConnackPacket{}
s.updateExpireTimer(p.Properties.SessionExpiryInterval.Value)
if p.ClientId != nil {
if s.ClientID == nil {
s.ClientID = genClientID()
}
connAck.Properties.AssignedClientIdentifier.Value = s.ClientID
}
true := byte(1)
false := byte(0)
connAck.Properties.WildcardSubscriptionAvailable.Value = &true
connAck.Properties.RetainAvailable.Value = &false
connAck.Properties.SharedSubscriptionAvailable.Value = &false
s.Connection = conn
err := s.Connection.sendPacket(connAck)
if err != nil {
panic(err)
}
}
// Starts a loop the recieves and responds to packets
func (s *Session) HandlerLoop() {
for s.Connection != nil {
select {
case packet := <-s.Connection.PacketChannel:
packet.Visit(s)
case _ = <-s.Connection.ClientDisconnectedChan:
s.OnDisconnect()
case subMessage := <-s.SubscriptionChannel:
//TODO, log for now
log.Printf("Recieved subscription message, handling UNIMPLEMENTED, message: %v", subMessage)
}
}
}
func (s *Session) Disconnect() error {
panic("Disconnection unimplemented")
err := s.Connection.close()
if err != nil {
return err
}
s.OnDisconnect()
return nil
}
func (s *Session) OnDisconnect() {
s.Connection = nil
s.resetExpireTimer()
log.Printf("Client disconnected, id: %s", *s.ClientID)
}
// newTime is nullable
func (s *Session) updateExpireTimer(newTime *uint32) {
var expiry = uint32(0)
if newTime != nil {
expiry = *newTime
} else {
expiry = uint32(0)
}
s.ExpiryInterval = time.Duration(expiry) * time.Second
if s.Connection == nil {
s.resetExpireTimer()
}
}
func (s *Session) resetExpireTimer() {
//s.expireTimer.Reset(s.ExpiryInterval)
}
func genClientID() *string {
buf := make([]byte, 32)
_, err := rand.Read(buf)
if err != nil {
// I don't think this can actually happen but just in case panic
panic(fmt.Errorf("Failed to generate a client id, %e", err))
}
id := "Client_rand_" + base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(buf)
return &id
}
func (s *Session) VisitConnect(_ packets.ConnectPacket) {
// ERROR CANNOT RECIEVE CONNECT ON AN ALREADY OPEN CONNECTION
s.Disconnect()
}
func (s *Session) VisitPublish(p packets.PublishPacket) {
println("UNIMPLEMENTED, Publishing packet, message:", string(p.Payload))
subs, lock := Subscriptions.GetSubscriptions(p.TopicName)
defer lock.Unlock()
for _, sub := range subs {
go func(sub Subscription) {sub <- p}(sub)
}
}
func (s *Session) VisitDisconnect(p packets.DisconnectPacket) {
//TODO FINISH
// HANDLE CLIENT DISCONNECTING
s.OnDisconnect()
}
func (s *Session) VisitSubscribe(p packets.SubscribePacket) {
//TODO FINISH
for _, filter := range p.TopicFilters {
Subscriptions.Subscribe(filter.Topic, s.SubscriptionChannel)
}
}
func (s *Session) VisitUnsubscribe(_ packets.UnsubscribePacket) {
panic("not implemented") // TODO: Implement
}
func (s *Session) VisitPing(p packets.PingreqPacket) {
s.Connection.sendPacket(packets.PingrespPacket{})
}
func (s *Session) VisitPubackPacket(_ packets.PubackPacket) {
panic("not implemented") // TODO: Implement
}
func (s *Session) VisitPubrecPacket(_ packets.PubrecPacket) {
panic("not implemented") // TODO: Implement
}
func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) {
panic("not implemented") // TODO: Implement
}
func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) {
panic("not implemented") // TODO: Implement
}

86
subscription.go Normal file
View file

@ -0,0 +1,86 @@
package main
//TODO FULLY IMPLEMENT SUBSCRIPTIONS INSTEAD OF JUST THE TOPIC FILTERS
import (
"strings"
"sync"
"badat.dev/maeqtt/v2/mqtt/packets"
)
var Subscriptions SubscriptionTreeNode = *NewSubscriptionTreeNode()
type Subscription chan packets.PublishPacket
type SubscriptionTreeNode struct {
subscriptions []Subscription
children map[string]*SubscriptionTreeNode
nodeLock sync.RWMutex
}
func NewSubscriptionTreeNode() *SubscriptionTreeNode {
s := SubscriptionTreeNode{}
s.children = make(map[string]*SubscriptionTreeNode)
return &s
}
func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode {
if len(fields) == 0 {
return s
}
field := fields[0]
s.nodeLock.RLock()
_, exists := s.children[field]
// Insert a value into the map if one doesn't exist yet
if !exists {
// Can't upgrade a read lock so we need to unlock and
// check again, this time with a write lock
s.nodeLock.RUnlock()
s.nodeLock.Lock()
_, exists = s.children[field]
if !exists {
s.children[field] = NewSubscriptionTreeNode()
}
s.nodeLock.Unlock()
s.nodeLock.RLock()
}
child, _ := s.children[field]
s.nodeLock.RUnlock()
return child.findNode(fields[1:])
}
func (s *SubscriptionTreeNode) Subscribe(topic packets.Topic, sub Subscription) {
node := s.findNode(topic.Fields)
node.nodeLock.Lock()
node.subscriptions = append(node.subscriptions, sub)
node.nodeLock.Unlock()
}
func (s *SubscriptionTreeNode) GetSubscriptions(topic string) ([]Subscription, sync.Locker) {
fields := strings.Split(topic,"/")
child := s.findNode(fields)
locker := child.nodeLock.RLocker()
locker.Lock()
return child.subscriptions, locker
}
func (s *SubscriptionTreeNode) findMatchingRec(topic []string) ([]Subscription, sync.Locker) {
locker := s.nodeLock.RLocker()
s.nodeLock.RLock()
if len(topic) == 0 {
return s.subscriptions,locker
}
defer s.nodeLock.RUnlock()
child, exists := s.children[topic[0]]
if exists {
return child.findMatchingRec(topic[1:])
} else {
return []Subscription{},locker
}
}