Initial session handling
parent
a927a5ac1a
commit
e9612e2430
@ -0,0 +1,92 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||
)
|
||||
|
||||
type Connection struct {
|
||||
MaxPacketSize *uint32
|
||||
RecvMax uint16
|
||||
TopicAliasMax uint16
|
||||
WantsRespInf bool
|
||||
WantsProblemInf bool
|
||||
Will packets.Will
|
||||
|
||||
KeepAliveInterval time.Duration
|
||||
keepAliveTicker time.Ticker
|
||||
|
||||
PacketChannel chan packets.ClientPacket
|
||||
ClientDisconnectedChan chan bool
|
||||
|
||||
rw io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func (c *Connection) resetKeepAlive() {
|
||||
if c.KeepAliveInterval != 0 {
|
||||
// TODO IMPLEMENT THIS
|
||||
//s.keepAliveTicker.Reset(s.KeepAliveInterval)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) readPacket() (*packets.ClientPacket, error) {
|
||||
return packets.ReadPacket(bufio.NewReader(c.rw))
|
||||
}
|
||||
|
||||
func (c *Connection) sendPacket(p packets.ServerPacket) error {
|
||||
c.resetKeepAlive()
|
||||
return p.Write(c.rw)
|
||||
}
|
||||
|
||||
func (c *Connection) close() error {
|
||||
close(c.PacketChannel)
|
||||
return c.rw.Close()
|
||||
}
|
||||
|
||||
func (c *Connection) packetReadLoop() {
|
||||
for {
|
||||
pack, err := c.readPacket()
|
||||
if err == io.EOF {
|
||||
c.ClientDisconnectedChan <- true
|
||||
} else if err != nil {
|
||||
panic(fmt.Errorf("Unimplemented error handling, %e", err).Error())
|
||||
} else {
|
||||
c.PacketChannel <- *pack
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewConnection(p packets.ConnectPacket, rw io.ReadWriteCloser) Connection {
|
||||
conn := Connection{}
|
||||
conn.rw = rw
|
||||
|
||||
if p.Properties.ReceiveMaximum.Value != nil {
|
||||
conn.RecvMax = *p.Properties.ReceiveMaximum.Value
|
||||
} else {
|
||||
conn.RecvMax = 65535
|
||||
}
|
||||
conn.MaxPacketSize = p.Properties.MaximumPacketSize.Value
|
||||
|
||||
if p.Properties.TopicAliasMaximum.Value != nil {
|
||||
conn.TopicAliasMax = *p.Properties.TopicAliasMaximum.Value
|
||||
} else {
|
||||
conn.TopicAliasMax = 0
|
||||
}
|
||||
|
||||
if p.Properties.RequestProblemInformation.Value != nil {
|
||||
conn.WantsRespInf = *p.Properties.RequestProblemInformation.Value != 0
|
||||
} else {
|
||||
conn.WantsRespInf = false
|
||||
}
|
||||
|
||||
conn.KeepAliveInterval = time.Duration(p.KeepAliveInterval) * time.Second
|
||||
|
||||
conn.PacketChannel = make(chan packets.ClientPacket)
|
||||
|
||||
go conn.packetReadLoop()
|
||||
return conn
|
||||
}
|
@ -1,2 +1,39 @@
|
||||
github.com/gdexlab/go-render v1.0.1 h1:rxqB3vo5s4n1kF0ySmoNeSPRYkEsyHgln4jFIQY7v0U=
|
||||
github.com/gdexlab/go-render v1.0.1/go.mod h1:wRi5nW2qfjiGj4mPukH4UV0IknS1cHD4VgFTmJX5JzM=
|
||||
github.com/josharian/impl v1.1.0 h1:gafhg1OFVMq46ifdkBa8wp4hlGogjktjjA5h/2j4+2k=
|
||||
github.com/josharian/impl v1.1.0/go.mod h1:SQ6aJMP6xsJpGSD/36IIqrUdigLCYe9bz/9o5AKm6Aw=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.5.0 h1:UG21uOlmZabA4fW5i7ZX6bjw1xELEGg/ZLgZq9auk/Q=
|
||||
golang.org/x/mod v0.5.0/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210915083310-ed5796bab164 h1:7ZDGnxgHAMw7thfC5bEos0RDAccZKxioiWBhfIe+tvw=
|
||||
golang.org/x/sys v0.0.0-20210915083310-ed5796bab164/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20200522201501-cb1345f3a375/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.1.5 h1:ouewzE6p+/VEB31YYnTbEJdi8pFqKp4P4n85vwo3DHA=
|
||||
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
@ -0,0 +1,181 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||
)
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func Auth(username string, password []byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
ClientID *string
|
||||
|
||||
// Nullable
|
||||
Connection *Connection
|
||||
SubscriptionChannel chan packets.PublishPacket
|
||||
|
||||
ExpiryInterval time.Duration
|
||||
expireTimer time.Timer // TODO
|
||||
}
|
||||
|
||||
func NewSession(conn *Connection, p packets.ConnectPacket) Session {
|
||||
sess := Session{}
|
||||
sess.SubscriptionChannel = make(chan packets.PublishPacket)
|
||||
|
||||
sess.Connect(conn, p)
|
||||
return sess
|
||||
}
|
||||
|
||||
func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) {
|
||||
if s.Connection != nil {
|
||||
//TODO
|
||||
panic("Disconnect if already have a connection, unimplemented")
|
||||
}
|
||||
connAck := packets.ConnackPacket{}
|
||||
|
||||
s.updateExpireTimer(p.Properties.SessionExpiryInterval.Value)
|
||||
|
||||
if p.ClientId != nil {
|
||||
if s.ClientID == nil {
|
||||
s.ClientID = genClientID()
|
||||
}
|
||||
connAck.Properties.AssignedClientIdentifier.Value = s.ClientID
|
||||
}
|
||||
|
||||
true := byte(1)
|
||||
false := byte(0)
|
||||
connAck.Properties.WildcardSubscriptionAvailable.Value = &true
|
||||
|
||||
connAck.Properties.RetainAvailable.Value = &false
|
||||
connAck.Properties.SharedSubscriptionAvailable.Value = &false
|
||||
|
||||
|
||||
s.Connection = conn
|
||||
err := s.Connection.sendPacket(connAck)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Starts a loop the recieves and responds to packets
|
||||
func (s *Session) HandlerLoop() {
|
||||
for s.Connection != nil {
|
||||
select {
|
||||
case packet := <-s.Connection.PacketChannel:
|
||||
packet.Visit(s)
|
||||
case _ = <-s.Connection.ClientDisconnectedChan:
|
||||
s.OnDisconnect()
|
||||
case subMessage := <-s.SubscriptionChannel:
|
||||
//TODO, log for now
|
||||
log.Printf("Recieved subscription message, handling UNIMPLEMENTED, message: %v", subMessage)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) Disconnect() error {
|
||||
panic("Disconnection unimplemented")
|
||||
err := s.Connection.close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.OnDisconnect()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) OnDisconnect() {
|
||||
s.Connection = nil
|
||||
s.resetExpireTimer()
|
||||
log.Printf("Client disconnected, id: %s", *s.ClientID)
|
||||
}
|
||||
|
||||
// newTime is nullable
|
||||
func (s *Session) updateExpireTimer(newTime *uint32) {
|
||||
var expiry = uint32(0)
|
||||
if newTime != nil {
|
||||
expiry = *newTime
|
||||
} else {
|
||||
expiry = uint32(0)
|
||||
}
|
||||
s.ExpiryInterval = time.Duration(expiry) * time.Second
|
||||
|
||||
if s.Connection == nil {
|
||||
s.resetExpireTimer()
|
||||
}
|
||||
}
|
||||
func (s *Session) resetExpireTimer() {
|
||||
//s.expireTimer.Reset(s.ExpiryInterval)
|
||||
}
|
||||
|
||||
func genClientID() *string {
|
||||
buf := make([]byte, 32)
|
||||
_, err := rand.Read(buf)
|
||||
if err != nil {
|
||||
// I don't think this can actually happen but just in case panic
|
||||
panic(fmt.Errorf("Failed to generate a client id, %e", err))
|
||||
}
|
||||
id := "Client_rand_" + base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(buf)
|
||||
return &id
|
||||
}
|
||||
|
||||
func (s *Session) VisitConnect(_ packets.ConnectPacket) {
|
||||
// ERROR CANNOT RECIEVE CONNECT ON AN ALREADY OPEN CONNECTION
|
||||
s.Disconnect()
|
||||
}
|
||||
|
||||
func (s *Session) VisitPublish(p packets.PublishPacket) {
|
||||
println("UNIMPLEMENTED, Publishing packet, message:", string(p.Payload))
|
||||
subs, lock := Subscriptions.GetSubscriptions(p.TopicName)
|
||||
defer lock.Unlock()
|
||||
|
||||
for _, sub := range subs {
|
||||
go func(sub Subscription) {sub <- p}(sub)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) VisitDisconnect(p packets.DisconnectPacket) {
|
||||
//TODO FINISH
|
||||
// HANDLE CLIENT DISCONNECTING
|
||||
s.OnDisconnect()
|
||||
}
|
||||
|
||||
func (s *Session) VisitSubscribe(p packets.SubscribePacket) {
|
||||
//TODO FINISH
|
||||
for _, filter := range p.TopicFilters {
|
||||
Subscriptions.Subscribe(filter.Topic, s.SubscriptionChannel)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) VisitUnsubscribe(_ packets.UnsubscribePacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
||||
|
||||
func (s *Session) VisitPing(p packets.PingreqPacket) {
|
||||
s.Connection.sendPacket(packets.PingrespPacket{})
|
||||
}
|
||||
|
||||
func (s *Session) VisitPubackPacket(_ packets.PubackPacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
||||
|
||||
func (s *Session) VisitPubrecPacket(_ packets.PubrecPacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
||||
|
||||
func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
||||
|
||||
func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
@ -0,0 +1,86 @@
|
||||
package main
|
||||
|
||||
//TODO FULLY IMPLEMENT SUBSCRIPTIONS INSTEAD OF JUST THE TOPIC FILTERS
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||
)
|
||||
|
||||
var Subscriptions SubscriptionTreeNode = *NewSubscriptionTreeNode()
|
||||
|
||||
type Subscription chan packets.PublishPacket
|
||||
|
||||
type SubscriptionTreeNode struct {
|
||||
subscriptions []Subscription
|
||||
children map[string]*SubscriptionTreeNode
|
||||
nodeLock sync.RWMutex
|
||||
}
|
||||
func NewSubscriptionTreeNode() *SubscriptionTreeNode {
|
||||
s := SubscriptionTreeNode{}
|
||||
s.children = make(map[string]*SubscriptionTreeNode)
|
||||
return &s
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode {
|
||||
if len(fields) == 0 {
|
||||
return s
|
||||
}
|
||||
|
||||
field := fields[0]
|
||||
|
||||
s.nodeLock.RLock()
|
||||
_, exists := s.children[field]
|
||||
// Insert a value into the map if one doesn't exist yet
|
||||
if !exists {
|
||||
// Can't upgrade a read lock so we need to unlock and
|
||||
// check again, this time with a write lock
|
||||
s.nodeLock.RUnlock()
|
||||
s.nodeLock.Lock()
|
||||
|
||||
_, exists = s.children[field]
|
||||
if !exists {
|
||||
s.children[field] = NewSubscriptionTreeNode()
|
||||
}
|
||||
s.nodeLock.Unlock()
|
||||
s.nodeLock.RLock()
|
||||
}
|
||||
|
||||
child, _ := s.children[field]
|
||||
s.nodeLock.RUnlock()
|
||||
return child.findNode(fields[1:])
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) Subscribe(topic packets.Topic, sub Subscription) {
|
||||
node := s.findNode(topic.Fields)
|
||||
node.nodeLock.Lock()
|
||||
node.subscriptions = append(node.subscriptions, sub)
|
||||
node.nodeLock.Unlock()
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) GetSubscriptions(topic string) ([]Subscription, sync.Locker) {
|
||||
fields := strings.Split(topic,"/")
|
||||
|
||||
child := s.findNode(fields)
|
||||
locker := child.nodeLock.RLocker()
|
||||
locker.Lock()
|
||||
return child.subscriptions, locker
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) findMatchingRec(topic []string) ([]Subscription, sync.Locker) {
|
||||
locker := s.nodeLock.RLocker()
|
||||
s.nodeLock.RLock()
|
||||
if len(topic) == 0 {
|
||||
return s.subscriptions,locker
|
||||
}
|
||||
defer s.nodeLock.RUnlock()
|
||||
|
||||
child, exists := s.children[topic[0]]
|
||||
if exists {
|
||||
return child.findMatchingRec(topic[1:])
|
||||
} else {
|
||||
return []Subscription{},locker
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue