124 lines
2.7 KiB
Go
124 lines
2.7 KiB
Go
package packets
|
|
|
|
import (
|
|
"badat.dev/maeqtt/v2/mqtt/properties"
|
|
"badat.dev/maeqtt/v2/mqtt/types"
|
|
"bufio"
|
|
"errors"
|
|
)
|
|
|
|
type Will struct {
|
|
retain bool
|
|
properties properties.WillProperties
|
|
}
|
|
|
|
type ConnectPacket struct {
|
|
ClientId *string
|
|
Username *string
|
|
Password *[]byte
|
|
CleanStart bool
|
|
KeepAliveInterval uint16
|
|
|
|
Will *Will // Optional
|
|
Properties properties.ConnectPacketProperties
|
|
}
|
|
|
|
func (c ConnectPacket) visit(visitor PacketVisitor) {
|
|
visitor.visitConnect(c)
|
|
}
|
|
|
|
func parseConnectPacket(control controlPacket) (ConnectPacket, error) {
|
|
packet := ConnectPacket{}
|
|
|
|
if control.packetType != PacketTypeConnect {
|
|
panic("Wrong packet type for parseConnectPacket")
|
|
}
|
|
if control.flags != 0 {
|
|
return packet, errors.New("Malformed connect packet")
|
|
}
|
|
|
|
r := bufio.NewReader(control.reader)
|
|
|
|
protocolName, err := types.DecodeUTF8String(r)
|
|
if err != nil {
|
|
return packet, err
|
|
}
|
|
if protocolName != "MQTT" {
|
|
return ConnectPacket{}, errors.New("Malformed connect packet, invalid protocol name")
|
|
}
|
|
|
|
protocolVersion, err := r.ReadByte()
|
|
if err != nil {
|
|
return ConnectPacket{}, err
|
|
}
|
|
if protocolVersion != 5 {
|
|
return ConnectPacket{}, errors.New("Malformed connect packet, unsupported protocol version")
|
|
}
|
|
|
|
connectFlags, err := types.DecodeBits(r)
|
|
if err != nil {
|
|
return packet, err
|
|
}
|
|
userNameFlag := connectFlags[7]
|
|
passwordFlag := connectFlags[6]
|
|
willRetainFlag := connectFlags[5]
|
|
willFlag := connectFlags[2]
|
|
packet.CleanStart = connectFlags[1]
|
|
reserved := connectFlags[0]
|
|
|
|
if reserved {
|
|
return ConnectPacket{}, errors.New("Malformed connect packet, reserved connect flag set")
|
|
}
|
|
|
|
QOSLevel := types.BoolToUint(connectFlags[4])*2 + types.BoolToUint(connectFlags[3])
|
|
if QOSLevel > 3 {
|
|
return ConnectPacket{}, errors.New("Malformed connect packet, invalid QOS Level")
|
|
}
|
|
|
|
keepAlive, err := types.DecodeUint16(r)
|
|
if err != nil {
|
|
return packet, err
|
|
}
|
|
packet.KeepAliveInterval = keepAlive
|
|
|
|
err = properties.ParseProperties(r, packet.Properties.ArrayOf())
|
|
if err != nil {
|
|
return packet, err
|
|
}
|
|
|
|
// Parse payload(3.1.3)
|
|
clientId, err := types.DecodeUTF8String(r)
|
|
if err != nil {
|
|
return packet, err
|
|
}
|
|
packet.ClientId = &clientId
|
|
|
|
if willFlag {
|
|
packet.Will = &Will{}
|
|
err = properties.ParseProperties(r, packet.Will.properties.ArrayOf())
|
|
if err != nil {
|
|
return packet, err
|
|
}
|
|
packet.Will.retain = willRetainFlag
|
|
}
|
|
|
|
var username string
|
|
if userNameFlag {
|
|
username, err = types.DecodeUTF8String(r)
|
|
if err != nil {
|
|
return packet, err
|
|
}
|
|
packet.Username = &username
|
|
}
|
|
|
|
var password []byte
|
|
if passwordFlag {
|
|
password, err = types.DecodeBinaryData(r)
|
|
if err != nil {
|
|
return packet, err
|
|
}
|
|
packet.Password = &password
|
|
}
|
|
|
|
return packet, nil
|
|
}
|