Initial session handling
This commit is contained in:
parent
a927a5ac1a
commit
e9612e2430
12 changed files with 512 additions and 56 deletions
92
connection.go
Normal file
92
connection.go
Normal file
|
@ -0,0 +1,92 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Connection struct {
|
||||||
|
MaxPacketSize *uint32
|
||||||
|
RecvMax uint16
|
||||||
|
TopicAliasMax uint16
|
||||||
|
WantsRespInf bool
|
||||||
|
WantsProblemInf bool
|
||||||
|
Will packets.Will
|
||||||
|
|
||||||
|
KeepAliveInterval time.Duration
|
||||||
|
keepAliveTicker time.Ticker
|
||||||
|
|
||||||
|
PacketChannel chan packets.ClientPacket
|
||||||
|
ClientDisconnectedChan chan bool
|
||||||
|
|
||||||
|
rw io.ReadWriteCloser
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) resetKeepAlive() {
|
||||||
|
if c.KeepAliveInterval != 0 {
|
||||||
|
// TODO IMPLEMENT THIS
|
||||||
|
//s.keepAliveTicker.Reset(s.KeepAliveInterval)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) readPacket() (*packets.ClientPacket, error) {
|
||||||
|
return packets.ReadPacket(bufio.NewReader(c.rw))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) sendPacket(p packets.ServerPacket) error {
|
||||||
|
c.resetKeepAlive()
|
||||||
|
return p.Write(c.rw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) close() error {
|
||||||
|
close(c.PacketChannel)
|
||||||
|
return c.rw.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) packetReadLoop() {
|
||||||
|
for {
|
||||||
|
pack, err := c.readPacket()
|
||||||
|
if err == io.EOF {
|
||||||
|
c.ClientDisconnectedChan <- true
|
||||||
|
} else if err != nil {
|
||||||
|
panic(fmt.Errorf("Unimplemented error handling, %e", err).Error())
|
||||||
|
} else {
|
||||||
|
c.PacketChannel <- *pack
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnection(p packets.ConnectPacket, rw io.ReadWriteCloser) Connection {
|
||||||
|
conn := Connection{}
|
||||||
|
conn.rw = rw
|
||||||
|
|
||||||
|
if p.Properties.ReceiveMaximum.Value != nil {
|
||||||
|
conn.RecvMax = *p.Properties.ReceiveMaximum.Value
|
||||||
|
} else {
|
||||||
|
conn.RecvMax = 65535
|
||||||
|
}
|
||||||
|
conn.MaxPacketSize = p.Properties.MaximumPacketSize.Value
|
||||||
|
|
||||||
|
if p.Properties.TopicAliasMaximum.Value != nil {
|
||||||
|
conn.TopicAliasMax = *p.Properties.TopicAliasMaximum.Value
|
||||||
|
} else {
|
||||||
|
conn.TopicAliasMax = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.Properties.RequestProblemInformation.Value != nil {
|
||||||
|
conn.WantsRespInf = *p.Properties.RequestProblemInformation.Value != 0
|
||||||
|
} else {
|
||||||
|
conn.WantsRespInf = false
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.KeepAliveInterval = time.Duration(p.KeepAliveInterval) * time.Second
|
||||||
|
|
||||||
|
conn.PacketChannel = make(chan packets.ClientPacket)
|
||||||
|
|
||||||
|
go conn.packetReadLoop()
|
||||||
|
return conn
|
||||||
|
}
|
8
go.mod
8
go.mod
|
@ -2,4 +2,10 @@ module badat.dev/maeqtt/v2
|
||||||
|
|
||||||
go 1.16
|
go 1.16
|
||||||
|
|
||||||
require github.com/gdexlab/go-render v1.0.1 // indirect
|
require (
|
||||||
|
github.com/gdexlab/go-render v1.0.1 // indirect
|
||||||
|
github.com/josharian/impl v1.1.0 // indirect
|
||||||
|
golang.org/x/mod v0.5.0 // indirect
|
||||||
|
golang.org/x/sys v0.0.0-20210915083310-ed5796bab164 // indirect
|
||||||
|
golang.org/x/tools v0.1.5 // indirect
|
||||||
|
)
|
||||||
|
|
37
go.sum
37
go.sum
|
@ -1,2 +1,39 @@
|
||||||
github.com/gdexlab/go-render v1.0.1 h1:rxqB3vo5s4n1kF0ySmoNeSPRYkEsyHgln4jFIQY7v0U=
|
github.com/gdexlab/go-render v1.0.1 h1:rxqB3vo5s4n1kF0ySmoNeSPRYkEsyHgln4jFIQY7v0U=
|
||||||
github.com/gdexlab/go-render v1.0.1/go.mod h1:wRi5nW2qfjiGj4mPukH4UV0IknS1cHD4VgFTmJX5JzM=
|
github.com/gdexlab/go-render v1.0.1/go.mod h1:wRi5nW2qfjiGj4mPukH4UV0IknS1cHD4VgFTmJX5JzM=
|
||||||
|
github.com/josharian/impl v1.1.0 h1:gafhg1OFVMq46ifdkBa8wp4hlGogjktjjA5h/2j4+2k=
|
||||||
|
github.com/josharian/impl v1.1.0/go.mod h1:SQ6aJMP6xsJpGSD/36IIqrUdigLCYe9bz/9o5AKm6Aw=
|
||||||
|
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
|
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||||
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
|
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
|
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
|
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
|
golang.org/x/mod v0.5.0 h1:UG21uOlmZabA4fW5i7ZX6bjw1xELEGg/ZLgZq9auk/Q=
|
||||||
|
golang.org/x/mod v0.5.0/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro=
|
||||||
|
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
|
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
|
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
|
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||||
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.0.0-20210915083310-ed5796bab164 h1:7ZDGnxgHAMw7thfC5bEos0RDAccZKxioiWBhfIe+tvw=
|
||||||
|
golang.org/x/sys v0.0.0-20210915083310-ed5796bab164/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
|
golang.org/x/tools v0.0.0-20200522201501-cb1345f3a375/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||||
|
golang.org/x/tools v0.1.5 h1:ouewzE6p+/VEB31YYnTbEJdi8pFqKp4P4n85vwo3DHA=
|
||||||
|
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||||
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
||||||
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
|
55
main.go
55
main.go
|
@ -4,10 +4,9 @@ import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"runtime/debug"
|
||||||
|
|
||||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||||
"badat.dev/maeqtt/v2/mqtt/properties"
|
|
||||||
"github.com/gdexlab/go-render/render" // For testing
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
@ -28,37 +27,31 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleConnection(con net.Conn) {
|
func handleConnection(con net.Conn) {
|
||||||
defer closeConnection(con)
|
defer handlePanic()
|
||||||
|
|
||||||
for {
|
reader := bufio.NewReader(con)
|
||||||
reader := bufio.NewReader(con)
|
|
||||||
packet, err := packets.ReadPacket(reader)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("Error reading packet ", err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
log.Println(render.AsCode(packet))
|
|
||||||
clientId := "aa"
|
|
||||||
resp := packets.ConnackPacket{
|
|
||||||
ResonCode: packets.ConnectReasonCodeSuccess,
|
|
||||||
SessionPresent: false,
|
|
||||||
Properties: properties.ConnackPacketProperties{
|
|
||||||
AssignedClientIdentifier: properties.AssignedClientIdentifier{Value: &clientId},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
err = resp.Write(con)
|
|
||||||
log.Println("Wrote response")
|
|
||||||
if err != nil {
|
|
||||||
log.Println("Error writing response ", err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
packet, err := packets.ReadPacket(reader)
|
||||||
|
|
||||||
func closeConnection(con net.Conn) {
|
|
||||||
err := con.Close()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Failed to close connection", err)
|
log.Println("Error reading packet ", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
connect, isConn := (*packet).(packets.ConnectPacket)
|
||||||
|
if !isConn {
|
||||||
|
log.Println("Didn't recieve a connet packet")
|
||||||
|
panic("TODO: Send a disconnect packet")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := NewConnection(connect, con)
|
||||||
|
|
||||||
|
sess := NewSession(&conn, connect)
|
||||||
|
sess.HandlerLoop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func handlePanic() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Println("Recovering from panic:", r)
|
||||||
|
log.Println("Stack Trace:")
|
||||||
|
debug.PrintStack()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,8 +8,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Will struct {
|
type Will struct {
|
||||||
retain bool
|
Retain bool
|
||||||
properties properties.WillProperties
|
Properties properties.WillProperties
|
||||||
}
|
}
|
||||||
|
|
||||||
type ConnectPacket struct {
|
type ConnectPacket struct {
|
||||||
|
@ -23,8 +23,8 @@ type ConnectPacket struct {
|
||||||
Properties properties.ConnectPacketProperties
|
Properties properties.ConnectPacketProperties
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c ConnectPacket) visit(visitor PacketVisitor) {
|
func (c ConnectPacket) Visit(Visitor PacketVisitor) {
|
||||||
visitor.visitConnect(c)
|
Visitor.VisitConnect(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseConnectPacket(control controlPacket) (ConnectPacket, error) {
|
func parseConnectPacket(control controlPacket) (ConnectPacket, error) {
|
||||||
|
@ -95,11 +95,11 @@ func parseConnectPacket(control controlPacket) (ConnectPacket, error) {
|
||||||
|
|
||||||
if willFlag {
|
if willFlag {
|
||||||
packet.Will = &Will{}
|
packet.Will = &Will{}
|
||||||
err = properties.ParseProperties(r, packet.Will.properties.ArrayOf())
|
err = properties.ParseProperties(r, packet.Will.Properties.ArrayOf())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return packet, err
|
return packet, err
|
||||||
}
|
}
|
||||||
packet.Will.retain = willRetainFlag
|
packet.Will.Retain = willRetainFlag
|
||||||
}
|
}
|
||||||
|
|
||||||
var username string
|
var username string
|
||||||
|
|
|
@ -99,6 +99,6 @@ func (p DisconnectPacket) Write(w io.Writer) error {
|
||||||
return control.write(w)
|
return control.write(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p DisconnectPacket) visit(v PacketVisitor) {
|
func (p DisconnectPacket) Visit(v PacketVisitor) {
|
||||||
v.visitDisconnect(p)
|
v.VisitDisconnect(p)
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
|
|
||||||
type PingreqPacket struct {}
|
type PingreqPacket struct {}
|
||||||
|
|
||||||
func ParesPingreq(control controlPacket) (PingreqPacket, error) {
|
func parsePingreq(control controlPacket) (PingreqPacket, error) {
|
||||||
packet := PingreqPacket{}
|
packet := PingreqPacket{}
|
||||||
|
|
||||||
if control.packetType != PacketTypePingreq {
|
if control.packetType != PacketTypePingreq {
|
||||||
|
@ -22,6 +22,10 @@ func ParesPingreq(control controlPacket) (PingreqPacket, error) {
|
||||||
return packet, nil
|
return packet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r PingreqPacket) Visit(p PacketVisitor) {
|
||||||
|
p.VisitPing(r)
|
||||||
|
}
|
||||||
|
|
||||||
type PingrespPacket struct {}
|
type PingrespPacket struct {}
|
||||||
|
|
||||||
func (p PingrespPacket) Write(w io.Writer) error {
|
func (p PingrespPacket) Write(w io.Writer) error {
|
||||||
|
|
|
@ -19,8 +19,8 @@ type PublishPacket struct {
|
||||||
Properties properties.PublishPacketProperties
|
Properties properties.PublishPacketProperties
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p PublishPacket) visit(v PacketVisitor) {
|
func (p PublishPacket) Visit(v PacketVisitor) {
|
||||||
v.visitPublish(p)
|
v.VisitPublish(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func parsePublishPacket(control controlPacket) (PublishPacket, error) {
|
func parsePublishPacket(control controlPacket) (PublishPacket, error) {
|
||||||
|
|
|
@ -4,13 +4,32 @@ import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"badat.dev/maeqtt/v2/mqtt/properties"
|
"badat.dev/maeqtt/v2/mqtt/properties"
|
||||||
"badat.dev/maeqtt/v2/mqtt/types"
|
"badat.dev/maeqtt/v2/mqtt/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Topic struct {
|
||||||
|
Fields []string
|
||||||
|
}
|
||||||
|
|
||||||
|
var multiLevelWildcardNotLast = errors.New("Multi level wildcard isn't the field in a topic")
|
||||||
|
func parseTopic(topic_name string) (Topic, error) {
|
||||||
|
topic := Topic{}
|
||||||
|
fields := strings.Split(topic_name, "/")
|
||||||
|
for i, field := range fields {
|
||||||
|
if field == "#" && len(fields) > i+1 {
|
||||||
|
return topic, multiLevelWildcardNotLast
|
||||||
|
}
|
||||||
|
}
|
||||||
|
topic.Fields = fields
|
||||||
|
|
||||||
|
return topic, nil
|
||||||
|
}
|
||||||
|
|
||||||
type TopicFilter struct {
|
type TopicFilter struct {
|
||||||
Topic string
|
Topic Topic
|
||||||
MaxQoS uint
|
MaxQoS uint
|
||||||
NoLocal bool
|
NoLocal bool
|
||||||
RetainAsPublished bool
|
RetainAsPublished bool
|
||||||
|
@ -21,7 +40,12 @@ func parseTopicFilter(r *bufio.Reader) (TopicFilter, error) {
|
||||||
filter := TopicFilter{}
|
filter := TopicFilter{}
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
filter.Topic, err = types.DecodeUTF8String(r)
|
topic_str, err := types.DecodeUTF8String(r)
|
||||||
|
if err != nil {
|
||||||
|
return filter, err
|
||||||
|
}
|
||||||
|
|
||||||
|
filter.Topic, err = parseTopic(topic_str)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return filter, err
|
return filter, err
|
||||||
}
|
}
|
||||||
|
@ -69,10 +93,14 @@ func parseSubscriptionPacket(control controlPacket, props []properties.Property)
|
||||||
return packet, err
|
return packet, err
|
||||||
}
|
}
|
||||||
_, err = r.Peek(1)
|
_, err = r.Peek(1)
|
||||||
if err != nil || err != io.EOF {
|
if err != nil && err != io.EOF {
|
||||||
return packet, err
|
return packet, err
|
||||||
}
|
}
|
||||||
|
if err == io.EOF {
|
||||||
|
return packet, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
println("A")
|
||||||
|
|
||||||
return packet, nil
|
return packet, nil
|
||||||
}
|
}
|
||||||
|
@ -82,6 +110,9 @@ type SubscribePacket struct {
|
||||||
props properties.SubscribePacketProperties
|
props properties.SubscribePacketProperties
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// CURRENTLY BROKEN
|
||||||
|
|
||||||
|
// TODO FIXME AAAAA
|
||||||
func parseSubscribePacket(control controlPacket) (SubscribePacket, error) {
|
func parseSubscribePacket(control controlPacket) (SubscribePacket, error) {
|
||||||
if control.packetType != PacketTypeSubscribe {
|
if control.packetType != PacketTypeSubscribe {
|
||||||
panic("Wrong packet type for parseSubscribePacket")
|
panic("Wrong packet type for parseSubscribePacket")
|
||||||
|
@ -92,11 +123,14 @@ func parseSubscribePacket(control controlPacket) (SubscribePacket, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return pack, err
|
return pack, err
|
||||||
}
|
}
|
||||||
pack.PacketId = subscriptionPack.PacketId
|
pack.SubscriptionPacket = &subscriptionPack
|
||||||
pack.TopicFilters = subscriptionPack.TopicFilters
|
|
||||||
return pack, nil
|
return pack, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p SubscribePacket) Visit(v PacketVisitor) {
|
||||||
|
v.VisitSubscribe(p)
|
||||||
|
}
|
||||||
|
|
||||||
type SubackReasonCode byte
|
type SubackReasonCode byte
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -151,6 +185,10 @@ func parseUnsubscribePacket(control controlPacket) (UnsubscribePacket, error) {
|
||||||
return pack, nil
|
return pack, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p UnsubscribePacket) Visit(v PacketVisitor) {
|
||||||
|
v.VisitUnsubscribe(p)
|
||||||
|
}
|
||||||
|
|
||||||
type UnsubackReasonCode byte
|
type UnsubackReasonCode byte
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
|
@ -2,7 +2,6 @@ package packets
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
@ -10,17 +9,24 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type PacketVisitor interface {
|
type PacketVisitor interface {
|
||||||
visitConnect(ConnectPacket)
|
VisitConnect(ConnectPacket)
|
||||||
visitPublish(PublishPacket)
|
VisitPublish(PublishPacket)
|
||||||
visitDisconnect(DisconnectPacket)
|
VisitDisconnect(DisconnectPacket)
|
||||||
|
VisitSubscribe(SubscribePacket)
|
||||||
|
VisitUnsubscribe(UnsubscribePacket)
|
||||||
|
VisitPing(PingreqPacket)
|
||||||
|
VisitPubackPacket(PubackPacket)
|
||||||
|
VisitPubrecPacket(PubrecPacket)
|
||||||
|
VisitPubrelPacket(PubrelPacket)
|
||||||
|
VisitPubcompPacket(PubcompPacket)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientPacket interface {
|
type ClientPacket interface {
|
||||||
visit(PacketVisitor)
|
Visit(PacketVisitor)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ServerPacket interface {
|
type ServerPacket interface {
|
||||||
Encode() (bytes.Buffer, error)
|
Write(w io.Writer) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type PacketType byte
|
type PacketType byte
|
||||||
|
@ -45,12 +51,10 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
func ReadPacket(r *bufio.Reader) (*ClientPacket, error) {
|
func ReadPacket(r *bufio.Reader) (*ClientPacket, error) {
|
||||||
println("AAAA")
|
|
||||||
fixedHeader, err := r.ReadByte()
|
fixedHeader, err := r.ReadByte()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
println("BBB")
|
|
||||||
|
|
||||||
highestFourBits := uint((fixedHeader >> 4) & 0b1111)
|
highestFourBits := uint((fixedHeader >> 4) & 0b1111)
|
||||||
lowerFourBits := uint(fixedHeader & 0b1111)
|
lowerFourBits := uint(fixedHeader & 0b1111)
|
||||||
|
@ -60,6 +64,7 @@ func ReadPacket(r *bufio.Reader) (*ClientPacket, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
reader := io.LimitReader(r, int64(dataLength))
|
reader := io.LimitReader(r, int64(dataLength))
|
||||||
|
|
||||||
control := controlPacket{
|
control := controlPacket{
|
||||||
packetType: PacketType(highestFourBits),
|
packetType: PacketType(highestFourBits),
|
||||||
flags: lowerFourBits,
|
flags: lowerFourBits,
|
||||||
|
@ -74,6 +79,20 @@ func ReadPacket(r *bufio.Reader) (*ClientPacket, error) {
|
||||||
packet, err = parsePublishPacket(control)
|
packet, err = parsePublishPacket(control)
|
||||||
case PacketTypeDisconnect:
|
case PacketTypeDisconnect:
|
||||||
packet, err = parseDisconnectPacket(control)
|
packet, err = parseDisconnectPacket(control)
|
||||||
|
case PacketTypeSubscribe:
|
||||||
|
packet, err = parseSubscribePacket(control)
|
||||||
|
case PacketTypeUnsubscribe:
|
||||||
|
packet, err = parseUnsubscribePacket(control)
|
||||||
|
case PacketTypePingreq:
|
||||||
|
packet, err = parsePingreq(control)
|
||||||
|
case PacketTypePuback:
|
||||||
|
panic("Puback packet parsing unimplemented")
|
||||||
|
case PacketTypePubrec:
|
||||||
|
panic("Pubrec packet parsing unimplemented")
|
||||||
|
case PacketTypePubrel:
|
||||||
|
panic("Pubrel packet parsing unimplemented")
|
||||||
|
case PacketTypePubcomp:
|
||||||
|
panic("Pubcomp packet parsing unimplemented")
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("Unknown packet type %v", control.packetType)
|
return nil, fmt.Errorf("Unknown packet type %v", control.packetType)
|
||||||
}
|
}
|
||||||
|
|
181
session.go
Normal file
181
session.go
Normal file
|
@ -0,0 +1,181 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"math/rand"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
func Auth(username string, password []byte) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
ClientID *string
|
||||||
|
|
||||||
|
// Nullable
|
||||||
|
Connection *Connection
|
||||||
|
SubscriptionChannel chan packets.PublishPacket
|
||||||
|
|
||||||
|
ExpiryInterval time.Duration
|
||||||
|
expireTimer time.Timer // TODO
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSession(conn *Connection, p packets.ConnectPacket) Session {
|
||||||
|
sess := Session{}
|
||||||
|
sess.SubscriptionChannel = make(chan packets.PublishPacket)
|
||||||
|
|
||||||
|
sess.Connect(conn, p)
|
||||||
|
return sess
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) {
|
||||||
|
if s.Connection != nil {
|
||||||
|
//TODO
|
||||||
|
panic("Disconnect if already have a connection, unimplemented")
|
||||||
|
}
|
||||||
|
connAck := packets.ConnackPacket{}
|
||||||
|
|
||||||
|
s.updateExpireTimer(p.Properties.SessionExpiryInterval.Value)
|
||||||
|
|
||||||
|
if p.ClientId != nil {
|
||||||
|
if s.ClientID == nil {
|
||||||
|
s.ClientID = genClientID()
|
||||||
|
}
|
||||||
|
connAck.Properties.AssignedClientIdentifier.Value = s.ClientID
|
||||||
|
}
|
||||||
|
|
||||||
|
true := byte(1)
|
||||||
|
false := byte(0)
|
||||||
|
connAck.Properties.WildcardSubscriptionAvailable.Value = &true
|
||||||
|
|
||||||
|
connAck.Properties.RetainAvailable.Value = &false
|
||||||
|
connAck.Properties.SharedSubscriptionAvailable.Value = &false
|
||||||
|
|
||||||
|
|
||||||
|
s.Connection = conn
|
||||||
|
err := s.Connection.sendPacket(connAck)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Starts a loop the recieves and responds to packets
|
||||||
|
func (s *Session) HandlerLoop() {
|
||||||
|
for s.Connection != nil {
|
||||||
|
select {
|
||||||
|
case packet := <-s.Connection.PacketChannel:
|
||||||
|
packet.Visit(s)
|
||||||
|
case _ = <-s.Connection.ClientDisconnectedChan:
|
||||||
|
s.OnDisconnect()
|
||||||
|
case subMessage := <-s.SubscriptionChannel:
|
||||||
|
//TODO, log for now
|
||||||
|
log.Printf("Recieved subscription message, handling UNIMPLEMENTED, message: %v", subMessage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) Disconnect() error {
|
||||||
|
panic("Disconnection unimplemented")
|
||||||
|
err := s.Connection.close()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.OnDisconnect()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) OnDisconnect() {
|
||||||
|
s.Connection = nil
|
||||||
|
s.resetExpireTimer()
|
||||||
|
log.Printf("Client disconnected, id: %s", *s.ClientID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTime is nullable
|
||||||
|
func (s *Session) updateExpireTimer(newTime *uint32) {
|
||||||
|
var expiry = uint32(0)
|
||||||
|
if newTime != nil {
|
||||||
|
expiry = *newTime
|
||||||
|
} else {
|
||||||
|
expiry = uint32(0)
|
||||||
|
}
|
||||||
|
s.ExpiryInterval = time.Duration(expiry) * time.Second
|
||||||
|
|
||||||
|
if s.Connection == nil {
|
||||||
|
s.resetExpireTimer()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (s *Session) resetExpireTimer() {
|
||||||
|
//s.expireTimer.Reset(s.ExpiryInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
func genClientID() *string {
|
||||||
|
buf := make([]byte, 32)
|
||||||
|
_, err := rand.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
// I don't think this can actually happen but just in case panic
|
||||||
|
panic(fmt.Errorf("Failed to generate a client id, %e", err))
|
||||||
|
}
|
||||||
|
id := "Client_rand_" + base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(buf)
|
||||||
|
return &id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) VisitConnect(_ packets.ConnectPacket) {
|
||||||
|
// ERROR CANNOT RECIEVE CONNECT ON AN ALREADY OPEN CONNECTION
|
||||||
|
s.Disconnect()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) VisitPublish(p packets.PublishPacket) {
|
||||||
|
println("UNIMPLEMENTED, Publishing packet, message:", string(p.Payload))
|
||||||
|
subs, lock := Subscriptions.GetSubscriptions(p.TopicName)
|
||||||
|
defer lock.Unlock()
|
||||||
|
|
||||||
|
for _, sub := range subs {
|
||||||
|
go func(sub Subscription) {sub <- p}(sub)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) VisitDisconnect(p packets.DisconnectPacket) {
|
||||||
|
//TODO FINISH
|
||||||
|
// HANDLE CLIENT DISCONNECTING
|
||||||
|
s.OnDisconnect()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) VisitSubscribe(p packets.SubscribePacket) {
|
||||||
|
//TODO FINISH
|
||||||
|
for _, filter := range p.TopicFilters {
|
||||||
|
Subscriptions.Subscribe(filter.Topic, s.SubscriptionChannel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) VisitUnsubscribe(_ packets.UnsubscribePacket) {
|
||||||
|
panic("not implemented") // TODO: Implement
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) VisitPing(p packets.PingreqPacket) {
|
||||||
|
s.Connection.sendPacket(packets.PingrespPacket{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) VisitPubackPacket(_ packets.PubackPacket) {
|
||||||
|
panic("not implemented") // TODO: Implement
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) VisitPubrecPacket(_ packets.PubrecPacket) {
|
||||||
|
panic("not implemented") // TODO: Implement
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) {
|
||||||
|
panic("not implemented") // TODO: Implement
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) {
|
||||||
|
panic("not implemented") // TODO: Implement
|
||||||
|
}
|
86
subscription.go
Normal file
86
subscription.go
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
//TODO FULLY IMPLEMENT SUBSCRIPTIONS INSTEAD OF JUST THE TOPIC FILTERS
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||||
|
)
|
||||||
|
|
||||||
|
var Subscriptions SubscriptionTreeNode = *NewSubscriptionTreeNode()
|
||||||
|
|
||||||
|
type Subscription chan packets.PublishPacket
|
||||||
|
|
||||||
|
type SubscriptionTreeNode struct {
|
||||||
|
subscriptions []Subscription
|
||||||
|
children map[string]*SubscriptionTreeNode
|
||||||
|
nodeLock sync.RWMutex
|
||||||
|
}
|
||||||
|
func NewSubscriptionTreeNode() *SubscriptionTreeNode {
|
||||||
|
s := SubscriptionTreeNode{}
|
||||||
|
s.children = make(map[string]*SubscriptionTreeNode)
|
||||||
|
return &s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode {
|
||||||
|
if len(fields) == 0 {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
field := fields[0]
|
||||||
|
|
||||||
|
s.nodeLock.RLock()
|
||||||
|
_, exists := s.children[field]
|
||||||
|
// Insert a value into the map if one doesn't exist yet
|
||||||
|
if !exists {
|
||||||
|
// Can't upgrade a read lock so we need to unlock and
|
||||||
|
// check again, this time with a write lock
|
||||||
|
s.nodeLock.RUnlock()
|
||||||
|
s.nodeLock.Lock()
|
||||||
|
|
||||||
|
_, exists = s.children[field]
|
||||||
|
if !exists {
|
||||||
|
s.children[field] = NewSubscriptionTreeNode()
|
||||||
|
}
|
||||||
|
s.nodeLock.Unlock()
|
||||||
|
s.nodeLock.RLock()
|
||||||
|
}
|
||||||
|
|
||||||
|
child, _ := s.children[field]
|
||||||
|
s.nodeLock.RUnlock()
|
||||||
|
return child.findNode(fields[1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SubscriptionTreeNode) Subscribe(topic packets.Topic, sub Subscription) {
|
||||||
|
node := s.findNode(topic.Fields)
|
||||||
|
node.nodeLock.Lock()
|
||||||
|
node.subscriptions = append(node.subscriptions, sub)
|
||||||
|
node.nodeLock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SubscriptionTreeNode) GetSubscriptions(topic string) ([]Subscription, sync.Locker) {
|
||||||
|
fields := strings.Split(topic,"/")
|
||||||
|
|
||||||
|
child := s.findNode(fields)
|
||||||
|
locker := child.nodeLock.RLocker()
|
||||||
|
locker.Lock()
|
||||||
|
return child.subscriptions, locker
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SubscriptionTreeNode) findMatchingRec(topic []string) ([]Subscription, sync.Locker) {
|
||||||
|
locker := s.nodeLock.RLocker()
|
||||||
|
s.nodeLock.RLock()
|
||||||
|
if len(topic) == 0 {
|
||||||
|
return s.subscriptions,locker
|
||||||
|
}
|
||||||
|
defer s.nodeLock.RUnlock()
|
||||||
|
|
||||||
|
child, exists := s.children[topic[0]]
|
||||||
|
if exists {
|
||||||
|
return child.findMatchingRec(topic[1:])
|
||||||
|
} else {
|
||||||
|
return []Subscription{},locker
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue