maeqtt/session/connection.go

99 lines
2.2 KiB
Go

package session
import (
"bufio"
"errors"
"io"
"log"
"badat.dev/maeqtt/v2/mqtt/packets"
)
// Connection is a single client's transport together with the limits and
// preferences the client announced in its CONNECT packet.
type Connection struct {
	// MaxPacketSize is the client's Maximum Packet Size property, or nil if
	// the client did not send one.
	MaxPacketSize *uint32
	// RecvMax is the client's Receive Maximum (NewConnection uses 65535 when
	// the property is absent).
	RecvMax uint16
	// TopicAliasMax is the client's Topic Alias Maximum (0 when absent).
	TopicAliasMax uint16
	// WantsRespInf reports whether the client requested response information.
	// NOTE(review): NewConnection derives this from the Request Problem
	// Information property; confirm whether Request Response Information was
	// intended.
	WantsRespInf bool
	// WantsProblemInf reports whether the client requested problem information.
	WantsProblemInf bool
	// Will holds the client's will message; NewConnection does not set it.
	Will packets.Will
	// TODO
	//KeepAliveInterval time.Duration
	//keepAliveTimer time.Timer

	// PacketChannel receives every packet read by PacketReadLoop and is
	// closed once reading from the client fails (e.g. the client disconnects).
	PacketChannel chan packets.ClientPacket

	rw io.ReadWriteCloser
	// reader buffers rw. It must persist across readPacket calls: a
	// bufio.Reader may read ahead past the end of the current packet, and
	// those bytes would be lost if a new reader were created per packet.
	reader *bufio.Reader
}

// readPacket decodes the next packet sent by the client. All reads share one
// buffered reader (created on first use) so that bytes read ahead of the
// current packet are kept for the next call instead of being discarded.
func (c *Connection) readPacket() (*packets.ClientPacket, error) {
	if c.reader == nil {
		c.reader = bufio.NewReader(c.rw)
	}
	return packets.ReadPacket(c.reader)
}
// sendPacket serialises p directly onto the client's transport and reports
// any error produced while writing it.
func (c *Connection) sendPacket(p packets.ServerPacket) error {
	if err := p.Write(c.rw); err != nil {
		return err
	}
	return nil
}
// close tears down the client's underlying transport. The error returned by
// Close is intentionally discarded: teardown here is best-effort.
func (c *Connection) close() {
	var closer io.Closer = c.rw
	_ = closer.Close()
}
// PacketReadLoop reads packets from the client and forwards each one on
// PacketChannel until a read fails. It then closes PacketChannel so
// consumers can observe that the client is gone. The call blocks until that
// happens.
//
// io.EOF (the client closing the stream) is treated as a normal end and is
// not logged; any other read error is logged before the loop stops.
func (c *Connection) PacketReadLoop() {
	defer close(c.PacketChannel)
	for {
		pack, err := c.readPacket()
		if err != nil {
			if !errors.Is(err, io.EOF) {
				log.Println("Stopped reading packets from client:", err)
			}
			return
		}
		c.PacketChannel <- *pack
	}
}
// ErrFirstPackNotConnect is returned by NewConnection when the first packet
// received from a client is not a CONNECT packet.
var ErrFirstPackNotConnect = errors.New("Failed to connect, first packet is not connect")

// FirstPackNotConnect is the original name of ErrFirstPackNotConnect. Both
// names refer to the same error value, so errors.Is matches either one.
//
// Deprecated: use ErrFirstPackNotConnect instead.
var FirstPackNotConnect = ErrFirstPackNotConnect
// NewConnection reads the first packet from rw, which must be a CONNECT
// packet, and builds a ConnectionRequest wrapping a new Connection on rw.
//
// The client's Receive Maximum, Maximum Packet Size, Topic Alias Maximum and
// Request Problem Information properties are copied onto the Connection, with
// defaults applied for properties the client omitted.
//
// It returns:
//   - the error from reading the first packet, if that read fails;
//   - FirstPackNotConnect if the first packet is not CONNECT, after a
//     best-effort DISCONNECT with a protocol-error reason has been written.
//
// On error the returned ConnectionRequest is only partially initialised.
// NewConnection never closes rw, even on error.
func NewConnection(rw io.ReadWriteCloser) (ConnectionRequest, error) {
	connReq := ConnectionRequest{}
	conn := Connection{}
	connReq.Connection = &conn
	conn.rw = rw

	packet, err := conn.readPacket()
	if err != nil {
		// Check err before touching packet: a failed read may yield a nil packet.
		return connReq, err
	}

	var conPack packets.ConnectPacket
	isConn := false
	if packet != nil {
		conPack, isConn = (*packet).(packets.ConnectPacket)
	}
	if !isConn {
		log.Println("Didn't receive a connect packet")
		disconnect := packets.DisconnectPacket{
			ReasonCode: packets.DisconnectReasonCodeProtocolError,
		}
		if writeErr := disconnect.Write(rw); writeErr != nil {
			log.Println("Failed to disconnect after not receiving a connect packet", writeErr)
		}
		return connReq, FirstPackNotConnect
	}
	connReq.ConnectPakcet = conPack

	props := conPack.Properties

	// Receive Maximum defaults to 65535 when absent (MQTT 5 §3.1.2.11.3).
	conn.RecvMax = 65535
	if v := props.ReceiveMaximum.Value; v != nil {
		conn.RecvMax = *v
	}

	// nil means the client imposed no maximum packet size.
	conn.MaxPacketSize = props.MaximumPacketSize.Value

	conn.TopicAliasMax = 0
	if v := props.TopicAliasMaximum.Value; v != nil {
		conn.TopicAliasMax = *v
	}

	// Request Problem Information defaults to 1 (enabled) when absent
	// (MQTT 5 §3.1.2.11.7).
	// NOTE(review): WantsRespInf keeps its original source (this property,
	// defaulting to false); it likely should come from Request Response
	// Information instead — confirm against packets.ConnectPacket.
	if v := props.RequestProblemInformation.Value; v != nil {
		conn.WantsProblemInf = *v != 0
		conn.WantsRespInf = *v != 0
	} else {
		conn.WantsProblemInf = true
		conn.WantsRespInf = false
	}

	conn.PacketChannel = make(chan packets.ClientPacket, 1)
	return connReq, nil
}