diff --git a/connection.go b/connection.go new file mode 100644 index 0000000..1df95d4 --- /dev/null +++ b/connection.go @@ -0,0 +1,92 @@ +package main + +import ( + "bufio" + "fmt" + "io" + "time" + + "badat.dev/maeqtt/v2/mqtt/packets" +) + +type Connection struct { + MaxPacketSize *uint32 + RecvMax uint16 + TopicAliasMax uint16 + WantsRespInf bool + WantsProblemInf bool + Will packets.Will + + KeepAliveInterval time.Duration + keepAliveTicker time.Ticker + + PacketChannel chan packets.ClientPacket + ClientDisconnectedChan chan bool + + rw io.ReadWriteCloser +} + +func (c *Connection) resetKeepAlive() { + if c.KeepAliveInterval != 0 { + // TODO IMPLEMENT THIS + //s.keepAliveTicker.Reset(s.KeepAliveInterval) + } +} + +func (c *Connection) readPacket() (*packets.ClientPacket, error) { + return packets.ReadPacket(bufio.NewReader(c.rw)) +} + +func (c *Connection) sendPacket(p packets.ServerPacket) error { + c.resetKeepAlive() + return p.Write(c.rw) +} + +func (c *Connection) close() error { + close(c.PacketChannel) + return c.rw.Close() +} + +func (c *Connection) packetReadLoop() { + for { + pack, err := c.readPacket() + if err == io.EOF { + c.ClientDisconnectedChan <- true + } else if err != nil { + panic(fmt.Errorf("Unimplemented error handling, %e", err).Error()) + } else { + c.PacketChannel <- *pack + } + } +} + +func NewConnection(p packets.ConnectPacket, rw io.ReadWriteCloser) Connection { + conn := Connection{} + conn.rw = rw + + if p.Properties.ReceiveMaximum.Value != nil { + conn.RecvMax = *p.Properties.ReceiveMaximum.Value + } else { + conn.RecvMax = 65535 + } + conn.MaxPacketSize = p.Properties.MaximumPacketSize.Value + + if p.Properties.TopicAliasMaximum.Value != nil { + conn.TopicAliasMax = *p.Properties.TopicAliasMaximum.Value + } else { + conn.TopicAliasMax = 0 + } + + if p.Properties.RequestProblemInformation.Value != nil { + conn.WantsRespInf = *p.Properties.RequestProblemInformation.Value != 0 + } else { + conn.WantsRespInf = false + } + + conn.KeepAliveInterval = time.Duration(p.KeepAliveInterval) * time.Second + + conn.PacketChannel = make(chan packets.ClientPacket) + + go conn.packetReadLoop() + return conn +} diff --git a/go.mod b/go.mod index c48e2ce..fb19e8b 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,10 @@ module badat.dev/maeqtt/v2 go 1.16 -require github.com/gdexlab/go-render v1.0.1 // indirect +require ( + github.com/gdexlab/go-render v1.0.1 // indirect + github.com/josharian/impl v1.1.0 // indirect + golang.org/x/mod v0.5.0 // indirect + golang.org/x/sys v0.0.0-20210915083310-ed5796bab164 // indirect + golang.org/x/tools v0.1.5 // indirect +) diff --git a/go.sum b/go.sum index 349e9bd..8834db0 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,39 @@ github.com/gdexlab/go-render v1.0.1 h1:rxqB3vo5s4n1kF0ySmoNeSPRYkEsyHgln4jFIQY7v0U= github.com/gdexlab/go-render v1.0.1/go.mod h1:wRi5nW2qfjiGj4mPukH4UV0IknS1cHD4VgFTmJX5JzM= +github.com/josharian/impl v1.1.0 h1:gafhg1OFVMq46ifdkBa8wp4hlGogjktjjA5h/2j4+2k= +github.com/josharian/impl v1.1.0/go.mod h1:SQ6aJMP6xsJpGSD/36IIqrUdigLCYe9bz/9o5AKm6Aw= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.5.0 h1:UG21uOlmZabA4fW5i7ZX6bjw1xELEGg/ZLgZq9auk/Q= +golang.org/x/mod v0.5.0/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210915083310-ed5796bab164 h1:7ZDGnxgHAMw7thfC5bEos0RDAccZKxioiWBhfIe+tvw= +golang.org/x/sys v0.0.0-20210915083310-ed5796bab164/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200522201501-cb1345f3a375/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.1.5 h1:ouewzE6p+/VEB31YYnTbEJdi8pFqKp4P4n85vwo3DHA= +golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/main.go b/main.go index 8ccef38..12afe21 100644 --- a/main.go +++ b/main.go @@ -4,10 +4,9 @@ import ( "bufio" "log" "net" + "runtime/debug" "badat.dev/maeqtt/v2/mqtt/packets" - "badat.dev/maeqtt/v2/mqtt/properties" - "github.com/gdexlab/go-render/render" // For testing ) func main() { @@ -28,37 +27,31 @@ func main() { } func handleConnection(con net.Conn) { - defer closeConnection(con) + defer handlePanic() - for { - reader := bufio.NewReader(con) - packet, err := packets.ReadPacket(reader) - if err != nil { - log.Println("Error reading packet ", err) - break - } - log.Println(render.AsCode(packet)) - clientId := "aa" - resp := packets.ConnackPacket{ - ResonCode: packets.ConnectReasonCodeSuccess, - SessionPresent: false, - Properties: properties.ConnackPacketProperties{ - AssignedClientIdentifier: properties.AssignedClientIdentifier{Value: &clientId}, - }, - } - err = resp.Write(con) - log.Println("Wrote response") - if err != nil { - log.Println("Error writing response ", err) - break - } - } + reader := bufio.NewReader(con) -} - -func closeConnection(con net.Conn) { - err := con.Close() + packet, err := packets.ReadPacket(reader) if err != nil { - log.Println("Failed to close connection", err) + log.Println("Error reading packet ", err) + return + } + connect, isConn := (*packet).(packets.ConnectPacket) + if !isConn { + log.Println("Didn't recieve a connet packet") + panic("TODO: Send a disconnect packet") + } + + conn := NewConnection(connect, con) + + sess := NewSession(&conn, connect) + sess.HandlerLoop() +} + +func handlePanic() { + if r := recover(); r != nil { + log.Println("Recovering from panic:", r) + log.Println("Stack Trace:") + debug.PrintStack() } } diff --git a/mqtt/packets/Connect.go b/mqtt/packets/Connect.go index f7560e8..dba8ed1 100644 --- a/mqtt/packets/Connect.go +++ b/mqtt/packets/Connect.go @@ -8,8 +8,8 @@ import ( ) type Will struct { - retain bool - properties properties.WillProperties + Retain bool + Properties properties.WillProperties } type ConnectPacket struct { @@ -23,8 +23,8 @@ type ConnectPacket struct { Properties properties.ConnectPacketProperties } -func (c ConnectPacket) visit(visitor PacketVisitor) { - visitor.visitConnect(c) +func (c ConnectPacket) Visit(Visitor PacketVisitor) { + Visitor.VisitConnect(c) } func parseConnectPacket(control controlPacket) (ConnectPacket, error) { @@ -95,11 +95,11 @@ func parseConnectPacket(control controlPacket) (ConnectPacket, error) { if willFlag { packet.Will = &Will{} - err = properties.ParseProperties(r, packet.Will.properties.ArrayOf()) + err = properties.ParseProperties(r, packet.Will.Properties.ArrayOf()) if err != nil { return packet, err } - packet.Will.retain = willRetainFlag + packet.Will.Retain = willRetainFlag } var username string diff --git a/mqtt/packets/Disconnect.go b/mqtt/packets/Disconnect.go index 47637e9..3e0be31 100644 --- a/mqtt/packets/Disconnect.go +++ b/mqtt/packets/Disconnect.go @@ -99,6 +99,6 @@ func (p DisconnectPacket) Write(w io.Writer) error { return control.write(w) } -func (p DisconnectPacket) visit(v PacketVisitor) { - v.visitDisconnect(p) +func (p DisconnectPacket) Visit(v PacketVisitor) { + v.VisitDisconnect(p) } diff --git a/mqtt/packets/Ping.go b/mqtt/packets/Ping.go index 89f4b0c..d6dd1c3 100644 --- a/mqtt/packets/Ping.go +++ b/mqtt/packets/Ping.go @@ -9,7 +9,7 @@ import ( type PingreqPacket struct {} -func ParesPingreq(control controlPacket) (PingreqPacket, error) { +func parsePingreq(control controlPacket) (PingreqPacket, error) { packet := PingreqPacket{} if control.packetType != PacketTypePingreq { @@ -22,6 +22,10 @@ func ParesPingreq(control controlPacket) (PingreqPacket, error) { return packet, nil } +func (r PingreqPacket) Visit(p PacketVisitor) { + p.VisitPing(r) +} + type PingrespPacket struct {} func (p PingrespPacket) Write(w io.Writer) error { diff --git a/mqtt/packets/Publish.go b/mqtt/packets/Publish.go index 5d551c2..d33f1e1 100644 --- a/mqtt/packets/Publish.go +++ b/mqtt/packets/Publish.go @@ -19,8 +19,8 @@ type PublishPacket struct { Properties properties.PublishPacketProperties } -func (p PublishPacket) visit(v PacketVisitor) { - v.visitPublish(p) +func (p PublishPacket) Visit(v PacketVisitor) { + v.VisitPublish(p) } func parsePublishPacket(control controlPacket) (PublishPacket, error) { diff --git a/mqtt/packets/Subscriptions.go b/mqtt/packets/Subscriptions.go index 64c587a..dc91052 100644 --- a/mqtt/packets/Subscriptions.go +++ b/mqtt/packets/Subscriptions.go @@ -4,13 +4,32 @@ import ( "bufio" "errors" "io" + "strings" "badat.dev/maeqtt/v2/mqtt/properties" "badat.dev/maeqtt/v2/mqtt/types" ) +type Topic struct { + Fields []string +} + +var multiLevelWildcardNotLast = errors.New("Multi level wildcard isn't the field in a topic") +func parseTopic(topic_name string) (Topic, error) { + topic := Topic{} + fields := strings.Split(topic_name, "/") + for i, field := range fields { + if field == "#" && len(fields) > i+1 { + return topic, multiLevelWildcardNotLast + } + } + topic.Fields = fields + + return topic, nil +} + type TopicFilter struct { - Topic string + Topic Topic MaxQoS uint NoLocal bool RetainAsPublished bool @@ -21,7 +40,12 @@ func parseTopicFilter(r *bufio.Reader) (TopicFilter, error) { filter := TopicFilter{} var err error - filter.Topic, err = types.DecodeUTF8String(r) + topic_str, err := types.DecodeUTF8String(r) + if err != nil { + return filter, err + } + + filter.Topic, err = parseTopic(topic_str) if err != nil { return filter, err } @@ -69,10 +93,14 @@ func parseSubscriptionPacket(control controlPacket, props []properties.Property) return packet, err } _, err = r.Peek(1) - if err != nil || err != io.EOF { + if err != nil && err != io.EOF { return packet, err } + if err == io.EOF { + return packet, nil + } } + println("A") return packet, nil } @@ -82,6 +110,9 @@ type SubscribePacket struct { props properties.SubscribePacketProperties } +/// CURRENTLY BROKEN + +// TODO FIXME AAAAA func parseSubscribePacket(control controlPacket) (SubscribePacket, error) { if control.packetType != PacketTypeSubscribe { panic("Wrong packet type for parseSubscribePacket") @@ -92,11 +123,14 @@ func parseSubscribePacket(control controlPacket) (SubscribePacket, error) { if err != nil { return pack, err } - pack.PacketId = subscriptionPack.PacketId - pack.TopicFilters = subscriptionPack.TopicFilters + pack.SubscriptionPacket = &subscriptionPack return pack, nil } +func (p SubscribePacket) Visit(v PacketVisitor) { + v.VisitSubscribe(p) +} + type SubackReasonCode byte const ( @@ -151,6 +185,10 @@ func parseUnsubscribePacket(control controlPacket) (UnsubscribePacket, error) { return pack, nil } +func (p UnsubscribePacket) Visit(v PacketVisitor) { + v.VisitUnsubscribe(p) +} + type UnsubackReasonCode byte const ( diff --git a/mqtt/packets/packets.go b/mqtt/packets/packets.go index 3210264..455c855 100644 --- a/mqtt/packets/packets.go +++ b/mqtt/packets/packets.go @@ -2,7 +2,6 @@ package packets import ( "bufio" - "bytes" "fmt" "io" @@ -10,17 +9,24 @@ import ( ) type PacketVisitor interface { - visitConnect(ConnectPacket) - visitPublish(PublishPacket) - visitDisconnect(DisconnectPacket) + VisitConnect(ConnectPacket) + VisitPublish(PublishPacket) + VisitDisconnect(DisconnectPacket) + VisitSubscribe(SubscribePacket) + VisitUnsubscribe(UnsubscribePacket) + VisitPing(PingreqPacket) + VisitPubackPacket(PubackPacket) + VisitPubrecPacket(PubrecPacket) + VisitPubrelPacket(PubrelPacket) + VisitPubcompPacket(PubcompPacket) } type ClientPacket interface { - visit(PacketVisitor) + Visit(PacketVisitor) } type ServerPacket interface { - Encode() (bytes.Buffer, error) + Write(w io.Writer) error } type PacketType byte @@ -45,12 +51,10 @@ const ( ) func ReadPacket(r *bufio.Reader) (*ClientPacket, error) { - println("AAAA") fixedHeader, err := r.ReadByte() if err != nil { return nil, err } - println("BBB") highestFourBits := uint((fixedHeader >> 4) & 0b1111) lowerFourBits := uint(fixedHeader & 0b1111) @@ -60,6 +64,7 @@ func ReadPacket(r *bufio.Reader) (*ClientPacket, error) { return nil, err } reader := io.LimitReader(r, int64(dataLength)) + control := controlPacket{ packetType: PacketType(highestFourBits), flags: lowerFourBits, @@ -74,6 +79,20 @@ func ReadPacket(r *bufio.Reader) (*ClientPacket, error) { packet, err = parsePublishPacket(control) case PacketTypeDisconnect: packet, err = parseDisconnectPacket(control) + case PacketTypeSubscribe: + packet, err = parseSubscribePacket(control) + case PacketTypeUnsubscribe: + packet, err = parseUnsubscribePacket(control) + case PacketTypePingreq: + packet, err = parsePingreq(control) + case PacketTypePuback: + panic("Puback packet parsing unimplemented") + case PacketTypePubrec: + panic("Pubrec packet parsing unimplemented") + case PacketTypePubrel: + panic("Pubrel packet parsing unimplemented") + case PacketTypePubcomp: + panic("Pubcomp packet parsing unimplemented") default: return nil, fmt.Errorf("Unknown packet type %v", control.packetType) } diff --git a/session.go b/session.go new file mode 100644 index 0000000..f73c545 --- /dev/null +++ b/session.go @@ -0,0 +1,181 @@ +package main + +import ( + "encoding/base64" + "fmt" + "log" + "math/rand" + "time" + + "badat.dev/maeqtt/v2/mqtt/packets" +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func Auth(username string, password []byte) bool { + return true +} + +type Session struct { + ClientID *string + + // Nullable + Connection *Connection + SubscriptionChannel chan packets.PublishPacket + + ExpiryInterval time.Duration + expireTimer time.Timer // TODO +} + +func NewSession(conn *Connection, p packets.ConnectPacket) Session { + sess := Session{} + sess.SubscriptionChannel = make(chan packets.PublishPacket) + + sess.Connect(conn, p) + return sess +} + +func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) { + if s.Connection != nil { + //TODO + panic("Disconnect if already have a connection, unimplemented") + } + connAck := packets.ConnackPacket{} + + s.updateExpireTimer(p.Properties.SessionExpiryInterval.Value) + + if p.ClientId != nil { + if s.ClientID == nil { + s.ClientID = genClientID() + } + connAck.Properties.AssignedClientIdentifier.Value = s.ClientID + } + + true := byte(1) + false := byte(0) + connAck.Properties.WildcardSubscriptionAvailable.Value = &true + + connAck.Properties.RetainAvailable.Value = &false + connAck.Properties.SharedSubscriptionAvailable.Value = &false + + + s.Connection = conn + err := s.Connection.sendPacket(connAck) + if err != nil { + panic(err) + } +} + +// Starts a loop the recieves and responds to packets +func (s *Session) HandlerLoop() { + for s.Connection != nil { + select { + case packet := <-s.Connection.PacketChannel: + packet.Visit(s) + case _ = <-s.Connection.ClientDisconnectedChan: + s.OnDisconnect() + case subMessage := <-s.SubscriptionChannel: + //TODO, log for now + log.Printf("Recieved subscription message, handling UNIMPLEMENTED, message: %v", subMessage) + } + } +} + +func (s *Session) Disconnect() error { + panic("Disconnection unimplemented") + err := s.Connection.close() + if err != nil { + return err + } + s.OnDisconnect() + return nil +} + +func (s *Session) OnDisconnect() { + s.Connection = nil + s.resetExpireTimer() + log.Printf("Client disconnected, id: %s", *s.ClientID) +} + +// newTime is nullable +func (s *Session) updateExpireTimer(newTime *uint32) { + var expiry = uint32(0) + if newTime != nil { + expiry = *newTime + } else { + expiry = uint32(0) + } + s.ExpiryInterval = time.Duration(expiry) * time.Second + + if s.Connection == nil { + s.resetExpireTimer() + } +} +func (s *Session) resetExpireTimer() { + //s.expireTimer.Reset(s.ExpiryInterval) +} + +func genClientID() *string { + buf := make([]byte, 32) + _, err := rand.Read(buf) + if err != nil { + // I don't think this can actually happen but just in case panic + panic(fmt.Errorf("Failed to generate a client id, %e", err)) + } + id := "Client_rand_" + base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(buf) + return &id +} + +func (s *Session) VisitConnect(_ packets.ConnectPacket) { + // ERROR CANNOT RECIEVE CONNECT ON AN ALREADY OPEN CONNECTION + s.Disconnect() +} + +func (s *Session) VisitPublish(p packets.PublishPacket) { + println("UNIMPLEMENTED, Publishing packet, message:", string(p.Payload)) + subs, lock := Subscriptions.GetSubscriptions(p.TopicName) + defer lock.Unlock() + + for _, sub := range subs { + go func(sub Subscription) {sub <- p}(sub) + } +} + +func (s *Session) VisitDisconnect(p packets.DisconnectPacket) { + //TODO FINISH + // HANDLE CLIENT DISCONNECTING + s.OnDisconnect() +} + +func (s *Session) VisitSubscribe(p packets.SubscribePacket) { + //TODO FINISH + for _, filter := range p.TopicFilters { + Subscriptions.Subscribe(filter.Topic, s.SubscriptionChannel) + } +} + +func (s *Session) VisitUnsubscribe(_ packets.UnsubscribePacket) { + panic("not implemented") // TODO: Implement +} + +func (s *Session) VisitPing(p packets.PingreqPacket) { + s.Connection.sendPacket(packets.PingrespPacket{}) +} + +func (s *Session) VisitPubackPacket(_ packets.PubackPacket) { + panic("not implemented") // TODO: Implement +} + +func (s *Session) VisitPubrecPacket(_ packets.PubrecPacket) { + panic("not implemented") // TODO: Implement +} + +func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) { + panic("not implemented") // TODO: Implement +} + +func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) { + panic("not implemented") // TODO: Implement +} diff --git a/subscription.go b/subscription.go new file mode 100644 index 0000000..63fee8b --- /dev/null +++ b/subscription.go @@ -0,0 +1,86 @@ +package main + +//TODO FULLY IMPLEMENT SUBSCRIPTIONS INSTEAD OF JUST THE TOPIC FILTERS + +import ( + "strings" + "sync" + + "badat.dev/maeqtt/v2/mqtt/packets" +) + +var Subscriptions SubscriptionTreeNode = *NewSubscriptionTreeNode() + +type Subscription chan packets.PublishPacket + +type SubscriptionTreeNode struct { + subscriptions []Subscription + children map[string]*SubscriptionTreeNode + nodeLock sync.RWMutex +} +func NewSubscriptionTreeNode() *SubscriptionTreeNode { + s := SubscriptionTreeNode{} + s.children = make(map[string]*SubscriptionTreeNode) + return &s +} + +func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode { + if len(fields) == 0 { + return s + } + + field := fields[0] + + s.nodeLock.RLock() + _, exists := s.children[field] + // Insert a value into the map if one doesn't exist yet + if !exists { + // Can't upgrade a read lock so we need to unlock and + // check again, this time with a write lock + s.nodeLock.RUnlock() + s.nodeLock.Lock() + + _, exists = s.children[field] + if !exists { + s.children[field] = NewSubscriptionTreeNode() + } + s.nodeLock.Unlock() + s.nodeLock.RLock() + } + + child, _ := s.children[field] + s.nodeLock.RUnlock() + return child.findNode(fields[1:]) +} + +func (s *SubscriptionTreeNode) Subscribe(topic packets.Topic, sub Subscription) { + node := s.findNode(topic.Fields) + node.nodeLock.Lock() + node.subscriptions = append(node.subscriptions, sub) + node.nodeLock.Unlock() +} + +func (s *SubscriptionTreeNode) GetSubscriptions(topic string) ([]Subscription, sync.Locker) { + fields := strings.Split(topic,"/") + + child := s.findNode(fields) + locker := child.nodeLock.RLocker() + locker.Lock() + return child.subscriptions, locker +} + +func (s *SubscriptionTreeNode) findMatchingRec(topic []string) ([]Subscription, sync.Locker) { + locker := s.nodeLock.RLocker() + s.nodeLock.RLock() + if len(topic) == 0 { + return s.subscriptions,locker + } + defer s.nodeLock.RUnlock() + + child, exists := s.children[topic[0]] + if exists { + return child.findMatchingRec(topic[1:]) + } else { + return []Subscription{},locker + } +}