maeqtt/mqtt/packets/Connect.go

125 lines
2.7 KiB
Go

package packets
import (
"badat.dev/maeqtt/v2/mqtt/properties"
"badat.dev/maeqtt/v2/mqtt/types"
"bufio"
"errors"
)
// Will holds the Will Message carried by a CONNECT packet (MQTT 5.0
// sections 3.1.2.5-3.1.2.7 and 3.1.3.2-3.1.3.4). It is only present when the
// Will Flag of the Connect Flags byte is set.
type Will struct {
	// Retain is the Will Retain bit of the Connect Flags byte.
	Retain bool
	// Properties holds the Will Properties decoded from the payload.
	Properties properties.WillProperties
	// QoS is the Will QoS decoded from the Connect Flags byte (0, 1 or 2).
	QoS byte
	// Topic is the Will Topic decoded from the payload.
	Topic string
	// Payload is the Will Payload decoded from the payload.
	Payload []byte
}

// ConnectPacket is a decoded MQTT 5.0 CONNECT packet.
type ConnectPacket struct {
	// ClientId is the Client Identifier from the payload. parseConnectPacket
	// always sets it, possibly to a pointer to an empty string.
	ClientId *string
	// Username is nil unless the User Name Flag was set.
	Username *string
	// Password is nil unless the Password Flag was set.
	Password *[]byte
	// CleanStart is the Clean Start bit of the Connect Flags byte.
	CleanStart bool
	// KeepAliveInterval is the Keep Alive value from the variable header.
	KeepAliveInterval uint16
	Will              *Will // Optional; nil unless the Will Flag was set.
	// Properties holds the CONNECT properties from the variable header.
	Properties properties.ConnectPacketProperties
}

// Visit dispatches the packet to Visitor.VisitConnect.
func (c ConnectPacket) Visit(Visitor PacketVisitor) {
	Visitor.VisitConnect(c)
}

// parseConnectPacket decodes the variable header and payload of a CONNECT
// packet (MQTT 5.0 section 3.1) from control.reader.
//
// It returns an error when the fixed-header flags are non-zero, the protocol
// name is not "MQTT", the protocol version is not 5, the Connect Flags are
// invalid, or any field fails to decode. On error the returned packet is
// always the zero value. It panics when control is not a CONNECT packet,
// because that means the caller dispatched the wrong parser.
func parseConnectPacket(control controlPacket) (ConnectPacket, error) {
	if control.packetType != PacketTypeConnect {
		panic("Wrong packet type for parseConnectPacket")
	}
	if control.flags != 0 {
		return ConnectPacket{}, errors.New("Malformed connect packet")
	}

	var packet ConnectPacket
	r := bufio.NewReader(control.reader)

	// Variable header (3.1.2): protocol name and protocol version.
	protocolName, err := types.DecodeUTF8String(r)
	if err != nil {
		return ConnectPacket{}, err
	}
	if protocolName != "MQTT" {
		return ConnectPacket{}, errors.New("Malformed connect packet, invalid protocol name")
	}
	protocolVersion, err := r.ReadByte()
	if err != nil {
		return ConnectPacket{}, err
	}
	if protocolVersion != 5 {
		return ConnectPacket{}, errors.New("Malformed connect packet, unsupported protocol version")
	}

	// Connect Flags (3.1.2.3). Index i is used as bit i of the flags byte.
	connectFlags, err := types.DecodeBits(r)
	if err != nil {
		return ConnectPacket{}, err
	}
	userNameFlag := connectFlags[7]
	passwordFlag := connectFlags[6]
	willRetainFlag := connectFlags[5]
	willFlag := connectFlags[2]
	if connectFlags[0] {
		return ConnectPacket{}, errors.New("Malformed connect packet, reserved connect flag set")
	}
	willQoS := types.BoolToUint(connectFlags[4])*2 + types.BoolToUint(connectFlags[3])
	// Two bits can encode 3, but a Will QoS of 3 is a Malformed Packet (3.1.2.6).
	if willQoS >= 3 {
		return ConnectPacket{}, errors.New("Malformed connect packet, invalid QOS Level")
	}
	// Will QoS and Will Retain must be 0 when there is no will
	// ([MQTT-3.1.2-11], [MQTT-3.1.2-13]).
	if !willFlag && (willQoS != 0 || willRetainFlag) {
		return ConnectPacket{}, errors.New("Malformed connect packet, will QOS or retain set without will flag")
	}
	packet.CleanStart = connectFlags[1]

	keepAlive, err := types.DecodeUint16(r)
	if err != nil {
		return ConnectPacket{}, err
	}
	packet.KeepAliveInterval = keepAlive

	if err = properties.ParseProperties(r, packet.Properties.ArrayOf()); err != nil {
		return ConnectPacket{}, err
	}

	// Payload (3.1.3). The field order is fixed: Client Identifier, Will
	// Properties, Will Topic, Will Payload, User Name, Password.
	clientId, err := types.DecodeUTF8String(r)
	if err != nil {
		return ConnectPacket{}, err
	}
	packet.ClientId = &clientId

	if willFlag {
		will := &Will{Retain: willRetainFlag, QoS: byte(willQoS)}
		if err = properties.ParseProperties(r, will.Properties.ArrayOf()); err != nil {
			return ConnectPacket{}, err
		}
		// The Will Topic and Will Payload must be consumed here. Skipping them
		// would make the username and password reads below see will data.
		if will.Topic, err = types.DecodeUTF8String(r); err != nil {
			return ConnectPacket{}, err
		}
		if will.Payload, err = types.DecodeBinaryData(r); err != nil {
			return ConnectPacket{}, err
		}
		packet.Will = will
	}

	if userNameFlag {
		username, err := types.DecodeUTF8String(r)
		if err != nil {
			return ConnectPacket{}, err
		}
		packet.Username = &username
	}

	// MQTT 5 allows a password without a user name, so this is independent of
	// userNameFlag.
	if passwordFlag {
		password, err := types.DecodeBinaryData(r)
		if err != nil {
			return ConnectPacket{}, err
		}
		packet.Password = &password
	}

	return packet, nil
}