Initial session handling
This commit is contained in:
parent
a927a5ac1a
commit
e9612e2430
12 changed files with 512 additions and 56 deletions
92
connection.go
Normal file
92
connection.go
Normal file
|
@ -0,0 +1,92 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||
)
|
||||
|
||||
type Connection struct {
|
||||
MaxPacketSize *uint32
|
||||
RecvMax uint16
|
||||
TopicAliasMax uint16
|
||||
WantsRespInf bool
|
||||
WantsProblemInf bool
|
||||
Will packets.Will
|
||||
|
||||
KeepAliveInterval time.Duration
|
||||
keepAliveTicker time.Ticker
|
||||
|
||||
PacketChannel chan packets.ClientPacket
|
||||
ClientDisconnectedChan chan bool
|
||||
|
||||
rw io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func (c *Connection) resetKeepAlive() {
|
||||
if c.KeepAliveInterval != 0 {
|
||||
// TODO IMPLEMENT THIS
|
||||
//s.keepAliveTicker.Reset(s.KeepAliveInterval)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) readPacket() (*packets.ClientPacket, error) {
|
||||
return packets.ReadPacket(bufio.NewReader(c.rw))
|
||||
}
|
||||
|
||||
func (c *Connection) sendPacket(p packets.ServerPacket) error {
|
||||
c.resetKeepAlive()
|
||||
return p.Write(c.rw)
|
||||
}
|
||||
|
||||
func (c *Connection) close() error {
|
||||
close(c.PacketChannel)
|
||||
return c.rw.Close()
|
||||
}
|
||||
|
||||
func (c *Connection) packetReadLoop() {
|
||||
for {
|
||||
pack, err := c.readPacket()
|
||||
if err == io.EOF {
|
||||
c.ClientDisconnectedChan <- true
|
||||
} else if err != nil {
|
||||
panic(fmt.Errorf("Unimplemented error handling, %e", err).Error())
|
||||
} else {
|
||||
c.PacketChannel <- *pack
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewConnection(p packets.ConnectPacket, rw io.ReadWriteCloser) Connection {
|
||||
conn := Connection{}
|
||||
conn.rw = rw
|
||||
|
||||
if p.Properties.ReceiveMaximum.Value != nil {
|
||||
conn.RecvMax = *p.Properties.ReceiveMaximum.Value
|
||||
} else {
|
||||
conn.RecvMax = 65535
|
||||
}
|
||||
conn.MaxPacketSize = p.Properties.MaximumPacketSize.Value
|
||||
|
||||
if p.Properties.TopicAliasMaximum.Value != nil {
|
||||
conn.TopicAliasMax = *p.Properties.TopicAliasMaximum.Value
|
||||
} else {
|
||||
conn.TopicAliasMax = 0
|
||||
}
|
||||
|
||||
if p.Properties.RequestProblemInformation.Value != nil {
|
||||
conn.WantsRespInf = *p.Properties.RequestProblemInformation.Value != 0
|
||||
} else {
|
||||
conn.WantsRespInf = false
|
||||
}
|
||||
|
||||
conn.KeepAliveInterval = time.Duration(p.KeepAliveInterval) * time.Second
|
||||
|
||||
conn.PacketChannel = make(chan packets.ClientPacket)
|
||||
|
||||
go conn.packetReadLoop()
|
||||
return conn
|
||||
}
|
8
go.mod
8
go.mod
|
@ -2,4 +2,10 @@ module badat.dev/maeqtt/v2
|
|||
|
||||
go 1.16
|
||||
|
||||
require github.com/gdexlab/go-render v1.0.1 // indirect
|
||||
require (
|
||||
github.com/gdexlab/go-render v1.0.1 // indirect
|
||||
github.com/josharian/impl v1.1.0 // indirect
|
||||
golang.org/x/mod v0.5.0 // indirect
|
||||
golang.org/x/sys v0.0.0-20210915083310-ed5796bab164 // indirect
|
||||
golang.org/x/tools v0.1.5 // indirect
|
||||
)
|
||||
|
|
37
go.sum
37
go.sum
|
@ -1,2 +1,39 @@
|
|||
github.com/gdexlab/go-render v1.0.1 h1:rxqB3vo5s4n1kF0ySmoNeSPRYkEsyHgln4jFIQY7v0U=
|
||||
github.com/gdexlab/go-render v1.0.1/go.mod h1:wRi5nW2qfjiGj4mPukH4UV0IknS1cHD4VgFTmJX5JzM=
|
||||
github.com/josharian/impl v1.1.0 h1:gafhg1OFVMq46ifdkBa8wp4hlGogjktjjA5h/2j4+2k=
|
||||
github.com/josharian/impl v1.1.0/go.mod h1:SQ6aJMP6xsJpGSD/36IIqrUdigLCYe9bz/9o5AKm6Aw=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.5.0 h1:UG21uOlmZabA4fW5i7ZX6bjw1xELEGg/ZLgZq9auk/Q=
|
||||
golang.org/x/mod v0.5.0/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210915083310-ed5796bab164 h1:7ZDGnxgHAMw7thfC5bEos0RDAccZKxioiWBhfIe+tvw=
|
||||
golang.org/x/sys v0.0.0-20210915083310-ed5796bab164/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20200522201501-cb1345f3a375/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.1.5 h1:ouewzE6p+/VEB31YYnTbEJdi8pFqKp4P4n85vwo3DHA=
|
||||
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
|
55
main.go
55
main.go
|
@ -4,10 +4,9 @@ import (
|
|||
"bufio"
|
||||
"log"
|
||||
"net"
|
||||
"runtime/debug"
|
||||
|
||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||
"badat.dev/maeqtt/v2/mqtt/properties"
|
||||
"github.com/gdexlab/go-render/render" // For testing
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
@ -28,37 +27,31 @@ func main() {
|
|||
}
|
||||
|
||||
func handleConnection(con net.Conn) {
|
||||
defer closeConnection(con)
|
||||
defer handlePanic()
|
||||
|
||||
for {
|
||||
reader := bufio.NewReader(con)
|
||||
packet, err := packets.ReadPacket(reader)
|
||||
if err != nil {
|
||||
log.Println("Error reading packet ", err)
|
||||
break
|
||||
}
|
||||
log.Println(render.AsCode(packet))
|
||||
clientId := "aa"
|
||||
resp := packets.ConnackPacket{
|
||||
ResonCode: packets.ConnectReasonCodeSuccess,
|
||||
SessionPresent: false,
|
||||
Properties: properties.ConnackPacketProperties{
|
||||
AssignedClientIdentifier: properties.AssignedClientIdentifier{Value: &clientId},
|
||||
},
|
||||
}
|
||||
err = resp.Write(con)
|
||||
log.Println("Wrote response")
|
||||
if err != nil {
|
||||
log.Println("Error writing response ", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
reader := bufio.NewReader(con)
|
||||
|
||||
}
|
||||
|
||||
func closeConnection(con net.Conn) {
|
||||
err := con.Close()
|
||||
packet, err := packets.ReadPacket(reader)
|
||||
if err != nil {
|
||||
log.Println("Failed to close connection", err)
|
||||
log.Println("Error reading packet ", err)
|
||||
return
|
||||
}
|
||||
connect, isConn := (*packet).(packets.ConnectPacket)
|
||||
if !isConn {
|
||||
log.Println("Didn't recieve a connet packet")
|
||||
panic("TODO: Send a disconnect packet")
|
||||
}
|
||||
|
||||
conn := NewConnection(connect, con)
|
||||
|
||||
sess := NewSession(&conn, connect)
|
||||
sess.HandlerLoop()
|
||||
}
|
||||
|
||||
func handlePanic() {
|
||||
if r := recover(); r != nil {
|
||||
log.Println("Recovering from panic:", r)
|
||||
log.Println("Stack Trace:")
|
||||
debug.PrintStack()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,8 +8,8 @@ import (
|
|||
)
|
||||
|
||||
type Will struct {
|
||||
retain bool
|
||||
properties properties.WillProperties
|
||||
Retain bool
|
||||
Properties properties.WillProperties
|
||||
}
|
||||
|
||||
type ConnectPacket struct {
|
||||
|
@ -23,8 +23,8 @@ type ConnectPacket struct {
|
|||
Properties properties.ConnectPacketProperties
|
||||
}
|
||||
|
||||
func (c ConnectPacket) visit(visitor PacketVisitor) {
|
||||
visitor.visitConnect(c)
|
||||
func (c ConnectPacket) Visit(Visitor PacketVisitor) {
|
||||
Visitor.VisitConnect(c)
|
||||
}
|
||||
|
||||
func parseConnectPacket(control controlPacket) (ConnectPacket, error) {
|
||||
|
@ -95,11 +95,11 @@ func parseConnectPacket(control controlPacket) (ConnectPacket, error) {
|
|||
|
||||
if willFlag {
|
||||
packet.Will = &Will{}
|
||||
err = properties.ParseProperties(r, packet.Will.properties.ArrayOf())
|
||||
err = properties.ParseProperties(r, packet.Will.Properties.ArrayOf())
|
||||
if err != nil {
|
||||
return packet, err
|
||||
}
|
||||
packet.Will.retain = willRetainFlag
|
||||
packet.Will.Retain = willRetainFlag
|
||||
}
|
||||
|
||||
var username string
|
||||
|
|
|
@ -99,6 +99,6 @@ func (p DisconnectPacket) Write(w io.Writer) error {
|
|||
return control.write(w)
|
||||
}
|
||||
|
||||
func (p DisconnectPacket) visit(v PacketVisitor) {
|
||||
v.visitDisconnect(p)
|
||||
func (p DisconnectPacket) Visit(v PacketVisitor) {
|
||||
v.VisitDisconnect(p)
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
|
||||
type PingreqPacket struct {}
|
||||
|
||||
func ParesPingreq(control controlPacket) (PingreqPacket, error) {
|
||||
func parsePingreq(control controlPacket) (PingreqPacket, error) {
|
||||
packet := PingreqPacket{}
|
||||
|
||||
if control.packetType != PacketTypePingreq {
|
||||
|
@ -22,6 +22,10 @@ func ParesPingreq(control controlPacket) (PingreqPacket, error) {
|
|||
return packet, nil
|
||||
}
|
||||
|
||||
func (r PingreqPacket) Visit(p PacketVisitor) {
|
||||
p.VisitPing(r)
|
||||
}
|
||||
|
||||
type PingrespPacket struct {}
|
||||
|
||||
func (p PingrespPacket) Write(w io.Writer) error {
|
||||
|
|
|
@ -19,8 +19,8 @@ type PublishPacket struct {
|
|||
Properties properties.PublishPacketProperties
|
||||
}
|
||||
|
||||
func (p PublishPacket) visit(v PacketVisitor) {
|
||||
v.visitPublish(p)
|
||||
func (p PublishPacket) Visit(v PacketVisitor) {
|
||||
v.VisitPublish(p)
|
||||
}
|
||||
|
||||
func parsePublishPacket(control controlPacket) (PublishPacket, error) {
|
||||
|
|
|
@ -4,13 +4,32 @@ import (
|
|||
"bufio"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"badat.dev/maeqtt/v2/mqtt/properties"
|
||||
"badat.dev/maeqtt/v2/mqtt/types"
|
||||
)
|
||||
|
||||
type Topic struct {
|
||||
Fields []string
|
||||
}
|
||||
|
||||
var multiLevelWildcardNotLast = errors.New("Multi level wildcard isn't the field in a topic")
|
||||
func parseTopic(topic_name string) (Topic, error) {
|
||||
topic := Topic{}
|
||||
fields := strings.Split(topic_name, "/")
|
||||
for i, field := range fields {
|
||||
if field == "#" && len(fields) > i+1 {
|
||||
return topic, multiLevelWildcardNotLast
|
||||
}
|
||||
}
|
||||
topic.Fields = fields
|
||||
|
||||
return topic, nil
|
||||
}
|
||||
|
||||
type TopicFilter struct {
|
||||
Topic string
|
||||
Topic Topic
|
||||
MaxQoS uint
|
||||
NoLocal bool
|
||||
RetainAsPublished bool
|
||||
|
@ -21,7 +40,12 @@ func parseTopicFilter(r *bufio.Reader) (TopicFilter, error) {
|
|||
filter := TopicFilter{}
|
||||
var err error
|
||||
|
||||
filter.Topic, err = types.DecodeUTF8String(r)
|
||||
topic_str, err := types.DecodeUTF8String(r)
|
||||
if err != nil {
|
||||
return filter, err
|
||||
}
|
||||
|
||||
filter.Topic, err = parseTopic(topic_str)
|
||||
if err != nil {
|
||||
return filter, err
|
||||
}
|
||||
|
@ -69,10 +93,14 @@ func parseSubscriptionPacket(control controlPacket, props []properties.Property)
|
|||
return packet, err
|
||||
}
|
||||
_, err = r.Peek(1)
|
||||
if err != nil || err != io.EOF {
|
||||
if err != nil && err != io.EOF {
|
||||
return packet, err
|
||||
}
|
||||
if err == io.EOF {
|
||||
return packet, nil
|
||||
}
|
||||
}
|
||||
println("A")
|
||||
|
||||
return packet, nil
|
||||
}
|
||||
|
@ -82,6 +110,9 @@ type SubscribePacket struct {
|
|||
props properties.SubscribePacketProperties
|
||||
}
|
||||
|
||||
/// CURRENTLY BROKEN
|
||||
|
||||
// TODO FIXME AAAAA
|
||||
func parseSubscribePacket(control controlPacket) (SubscribePacket, error) {
|
||||
if control.packetType != PacketTypeSubscribe {
|
||||
panic("Wrong packet type for parseSubscribePacket")
|
||||
|
@ -92,11 +123,14 @@ func parseSubscribePacket(control controlPacket) (SubscribePacket, error) {
|
|||
if err != nil {
|
||||
return pack, err
|
||||
}
|
||||
pack.PacketId = subscriptionPack.PacketId
|
||||
pack.TopicFilters = subscriptionPack.TopicFilters
|
||||
pack.SubscriptionPacket = &subscriptionPack
|
||||
return pack, nil
|
||||
}
|
||||
|
||||
func (p SubscribePacket) Visit(v PacketVisitor) {
|
||||
v.VisitSubscribe(p)
|
||||
}
|
||||
|
||||
type SubackReasonCode byte
|
||||
|
||||
const (
|
||||
|
@ -151,6 +185,10 @@ func parseUnsubscribePacket(control controlPacket) (UnsubscribePacket, error) {
|
|||
return pack, nil
|
||||
}
|
||||
|
||||
func (p UnsubscribePacket) Visit(v PacketVisitor) {
|
||||
v.VisitUnsubscribe(p)
|
||||
}
|
||||
|
||||
type UnsubackReasonCode byte
|
||||
|
||||
const (
|
||||
|
|
|
@ -2,7 +2,6 @@ package packets
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
|
@ -10,17 +9,24 @@ import (
|
|||
)
|
||||
|
||||
type PacketVisitor interface {
|
||||
visitConnect(ConnectPacket)
|
||||
visitPublish(PublishPacket)
|
||||
visitDisconnect(DisconnectPacket)
|
||||
VisitConnect(ConnectPacket)
|
||||
VisitPublish(PublishPacket)
|
||||
VisitDisconnect(DisconnectPacket)
|
||||
VisitSubscribe(SubscribePacket)
|
||||
VisitUnsubscribe(UnsubscribePacket)
|
||||
VisitPing(PingreqPacket)
|
||||
VisitPubackPacket(PubackPacket)
|
||||
VisitPubrecPacket(PubrecPacket)
|
||||
VisitPubrelPacket(PubrelPacket)
|
||||
VisitPubcompPacket(PubcompPacket)
|
||||
}
|
||||
|
||||
type ClientPacket interface {
|
||||
visit(PacketVisitor)
|
||||
Visit(PacketVisitor)
|
||||
}
|
||||
|
||||
type ServerPacket interface {
|
||||
Encode() (bytes.Buffer, error)
|
||||
Write(w io.Writer) error
|
||||
}
|
||||
|
||||
type PacketType byte
|
||||
|
@ -45,12 +51,10 @@ const (
|
|||
)
|
||||
|
||||
func ReadPacket(r *bufio.Reader) (*ClientPacket, error) {
|
||||
println("AAAA")
|
||||
fixedHeader, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
println("BBB")
|
||||
|
||||
highestFourBits := uint((fixedHeader >> 4) & 0b1111)
|
||||
lowerFourBits := uint(fixedHeader & 0b1111)
|
||||
|
@ -60,6 +64,7 @@ func ReadPacket(r *bufio.Reader) (*ClientPacket, error) {
|
|||
return nil, err
|
||||
}
|
||||
reader := io.LimitReader(r, int64(dataLength))
|
||||
|
||||
control := controlPacket{
|
||||
packetType: PacketType(highestFourBits),
|
||||
flags: lowerFourBits,
|
||||
|
@ -74,6 +79,20 @@ func ReadPacket(r *bufio.Reader) (*ClientPacket, error) {
|
|||
packet, err = parsePublishPacket(control)
|
||||
case PacketTypeDisconnect:
|
||||
packet, err = parseDisconnectPacket(control)
|
||||
case PacketTypeSubscribe:
|
||||
packet, err = parseSubscribePacket(control)
|
||||
case PacketTypeUnsubscribe:
|
||||
packet, err = parseUnsubscribePacket(control)
|
||||
case PacketTypePingreq:
|
||||
packet, err = parsePingreq(control)
|
||||
case PacketTypePuback:
|
||||
panic("Puback packet parsing unimplemented")
|
||||
case PacketTypePubrec:
|
||||
panic("Pubrec packet parsing unimplemented")
|
||||
case PacketTypePubrel:
|
||||
panic("Pubrel packet parsing unimplemented")
|
||||
case PacketTypePubcomp:
|
||||
panic("Pubcomp packet parsing unimplemented")
|
||||
default:
|
||||
return nil, fmt.Errorf("Unknown packet type %v", control.packetType)
|
||||
}
|
||||
|
|
181
session.go
Normal file
181
session.go
Normal file
|
@ -0,0 +1,181 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||
)
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func Auth(username string, password []byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
ClientID *string
|
||||
|
||||
// Nullable
|
||||
Connection *Connection
|
||||
SubscriptionChannel chan packets.PublishPacket
|
||||
|
||||
ExpiryInterval time.Duration
|
||||
expireTimer time.Timer // TODO
|
||||
}
|
||||
|
||||
func NewSession(conn *Connection, p packets.ConnectPacket) Session {
|
||||
sess := Session{}
|
||||
sess.SubscriptionChannel = make(chan packets.PublishPacket)
|
||||
|
||||
sess.Connect(conn, p)
|
||||
return sess
|
||||
}
|
||||
|
||||
func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) {
|
||||
if s.Connection != nil {
|
||||
//TODO
|
||||
panic("Disconnect if already have a connection, unimplemented")
|
||||
}
|
||||
connAck := packets.ConnackPacket{}
|
||||
|
||||
s.updateExpireTimer(p.Properties.SessionExpiryInterval.Value)
|
||||
|
||||
if p.ClientId != nil {
|
||||
if s.ClientID == nil {
|
||||
s.ClientID = genClientID()
|
||||
}
|
||||
connAck.Properties.AssignedClientIdentifier.Value = s.ClientID
|
||||
}
|
||||
|
||||
true := byte(1)
|
||||
false := byte(0)
|
||||
connAck.Properties.WildcardSubscriptionAvailable.Value = &true
|
||||
|
||||
connAck.Properties.RetainAvailable.Value = &false
|
||||
connAck.Properties.SharedSubscriptionAvailable.Value = &false
|
||||
|
||||
|
||||
s.Connection = conn
|
||||
err := s.Connection.sendPacket(connAck)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Starts a loop the recieves and responds to packets
|
||||
func (s *Session) HandlerLoop() {
|
||||
for s.Connection != nil {
|
||||
select {
|
||||
case packet := <-s.Connection.PacketChannel:
|
||||
packet.Visit(s)
|
||||
case _ = <-s.Connection.ClientDisconnectedChan:
|
||||
s.OnDisconnect()
|
||||
case subMessage := <-s.SubscriptionChannel:
|
||||
//TODO, log for now
|
||||
log.Printf("Recieved subscription message, handling UNIMPLEMENTED, message: %v", subMessage)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) Disconnect() error {
|
||||
panic("Disconnection unimplemented")
|
||||
err := s.Connection.close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.OnDisconnect()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) OnDisconnect() {
|
||||
s.Connection = nil
|
||||
s.resetExpireTimer()
|
||||
log.Printf("Client disconnected, id: %s", *s.ClientID)
|
||||
}
|
||||
|
||||
// newTime is nullable
|
||||
func (s *Session) updateExpireTimer(newTime *uint32) {
|
||||
var expiry = uint32(0)
|
||||
if newTime != nil {
|
||||
expiry = *newTime
|
||||
} else {
|
||||
expiry = uint32(0)
|
||||
}
|
||||
s.ExpiryInterval = time.Duration(expiry) * time.Second
|
||||
|
||||
if s.Connection == nil {
|
||||
s.resetExpireTimer()
|
||||
}
|
||||
}
|
||||
func (s *Session) resetExpireTimer() {
|
||||
//s.expireTimer.Reset(s.ExpiryInterval)
|
||||
}
|
||||
|
||||
func genClientID() *string {
|
||||
buf := make([]byte, 32)
|
||||
_, err := rand.Read(buf)
|
||||
if err != nil {
|
||||
// I don't think this can actually happen but just in case panic
|
||||
panic(fmt.Errorf("Failed to generate a client id, %e", err))
|
||||
}
|
||||
id := "Client_rand_" + base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(buf)
|
||||
return &id
|
||||
}
|
||||
|
||||
func (s *Session) VisitConnect(_ packets.ConnectPacket) {
|
||||
// ERROR CANNOT RECIEVE CONNECT ON AN ALREADY OPEN CONNECTION
|
||||
s.Disconnect()
|
||||
}
|
||||
|
||||
func (s *Session) VisitPublish(p packets.PublishPacket) {
|
||||
println("UNIMPLEMENTED, Publishing packet, message:", string(p.Payload))
|
||||
subs, lock := Subscriptions.GetSubscriptions(p.TopicName)
|
||||
defer lock.Unlock()
|
||||
|
||||
for _, sub := range subs {
|
||||
go func(sub Subscription) {sub <- p}(sub)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) VisitDisconnect(p packets.DisconnectPacket) {
|
||||
//TODO FINISH
|
||||
// HANDLE CLIENT DISCONNECTING
|
||||
s.OnDisconnect()
|
||||
}
|
||||
|
||||
func (s *Session) VisitSubscribe(p packets.SubscribePacket) {
|
||||
//TODO FINISH
|
||||
for _, filter := range p.TopicFilters {
|
||||
Subscriptions.Subscribe(filter.Topic, s.SubscriptionChannel)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) VisitUnsubscribe(_ packets.UnsubscribePacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
||||
|
||||
func (s *Session) VisitPing(p packets.PingreqPacket) {
|
||||
s.Connection.sendPacket(packets.PingrespPacket{})
|
||||
}
|
||||
|
||||
func (s *Session) VisitPubackPacket(_ packets.PubackPacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
||||
|
||||
func (s *Session) VisitPubrecPacket(_ packets.PubrecPacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
||||
|
||||
func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
||||
|
||||
func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
86
subscription.go
Normal file
86
subscription.go
Normal file
|
@ -0,0 +1,86 @@
|
|||
package main
|
||||
|
||||
//TODO FULLY IMPLEMENT SUBSCRIPTIONS INSTEAD OF JUST THE TOPIC FILTERS
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||
)
|
||||
|
||||
var Subscriptions SubscriptionTreeNode = *NewSubscriptionTreeNode()
|
||||
|
||||
type Subscription chan packets.PublishPacket
|
||||
|
||||
type SubscriptionTreeNode struct {
|
||||
subscriptions []Subscription
|
||||
children map[string]*SubscriptionTreeNode
|
||||
nodeLock sync.RWMutex
|
||||
}
|
||||
func NewSubscriptionTreeNode() *SubscriptionTreeNode {
|
||||
s := SubscriptionTreeNode{}
|
||||
s.children = make(map[string]*SubscriptionTreeNode)
|
||||
return &s
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode {
|
||||
if len(fields) == 0 {
|
||||
return s
|
||||
}
|
||||
|
||||
field := fields[0]
|
||||
|
||||
s.nodeLock.RLock()
|
||||
_, exists := s.children[field]
|
||||
// Insert a value into the map if one doesn't exist yet
|
||||
if !exists {
|
||||
// Can't upgrade a read lock so we need to unlock and
|
||||
// check again, this time with a write lock
|
||||
s.nodeLock.RUnlock()
|
||||
s.nodeLock.Lock()
|
||||
|
||||
_, exists = s.children[field]
|
||||
if !exists {
|
||||
s.children[field] = NewSubscriptionTreeNode()
|
||||
}
|
||||
s.nodeLock.Unlock()
|
||||
s.nodeLock.RLock()
|
||||
}
|
||||
|
||||
child, _ := s.children[field]
|
||||
s.nodeLock.RUnlock()
|
||||
return child.findNode(fields[1:])
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) Subscribe(topic packets.Topic, sub Subscription) {
|
||||
node := s.findNode(topic.Fields)
|
||||
node.nodeLock.Lock()
|
||||
node.subscriptions = append(node.subscriptions, sub)
|
||||
node.nodeLock.Unlock()
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) GetSubscriptions(topic string) ([]Subscription, sync.Locker) {
|
||||
fields := strings.Split(topic,"/")
|
||||
|
||||
child := s.findNode(fields)
|
||||
locker := child.nodeLock.RLocker()
|
||||
locker.Lock()
|
||||
return child.subscriptions, locker
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) findMatchingRec(topic []string) ([]Subscription, sync.Locker) {
|
||||
locker := s.nodeLock.RLocker()
|
||||
s.nodeLock.RLock()
|
||||
if len(topic) == 0 {
|
||||
return s.subscriptions,locker
|
||||
}
|
||||
defer s.nodeLock.RUnlock()
|
||||
|
||||
child, exists := s.children[topic[0]]
|
||||
if exists {
|
||||
return child.findMatchingRec(topic[1:])
|
||||
} else {
|
||||
return []Subscription{},locker
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue