diff --git a/main.go b/main.go index 12afe21..791a6b0 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "runtime/debug" "badat.dev/maeqtt/v2/mqtt/packets" + "badat.dev/maeqtt/v2/session" ) func main() { @@ -27,7 +28,7 @@ func main() { } func handleConnection(con net.Conn) { - defer handlePanic() + defer handlePanic(con) reader := bufio.NewReader(con) @@ -38,20 +39,33 @@ func handleConnection(con net.Conn) { } connect, isConn := (*packet).(packets.ConnectPacket) if !isConn { - log.Println("Didn't recieve a connet packet") - panic("TODO: Send a disconnect packet") + log.Println("Didn't recieve a connect packet") + err := packets.DisconnectPacket{ + ReasonCode: packets.DisconnectReasonCodeProtocolError, + }.Write(con) + if err != nil { + log.Println("Failed to disconnect after not recieving a connect packet", err) + } + return } - conn := NewConnection(connect, con) + conn := session.NewConnection(connect, con) + sess := session.NewSession(&conn, connect) - sess := NewSession(&conn, connect) sess.HandlerLoop() } -func handlePanic() { +func handlePanic(con net.Conn) { if r := recover(); r != nil { log.Println("Recovering from panic:", r) log.Println("Stack Trace:") debug.PrintStack() + + err := packets.DisconnectPacket{ + ReasonCode: packets.DisconnectReasonCodeImplErorr, + }.Write(con) + if err != nil { + log.Println("Failed to send a disconnect packet after recovering from panic", err) + } } } diff --git a/mqtt/GeneratedProperties.go b/mqtt/GeneratedProperties.go index 7a27990..c459093 100644 --- a/mqtt/GeneratedProperties.go +++ b/mqtt/GeneratedProperties.go @@ -1,6 +1,6 @@ package mqtt -// This code has been generated with the genProps.py script. Do not modify +// This code has been generated with the genProps.py script. Do not modify import "bufio" @@ -13,14 +13,13 @@ func (p PayloadFormatIndicator) id() int { } func (p *PayloadFormatIndicator) parse(r *bufio.Reader) error { - val, err := r.ReadByte() + val, err := r.ReadByte() if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type MessageExpiryInterval struct { value *uint32 @@ -31,14 +30,13 @@ func (p MessageExpiryInterval) id() int { } func (p *MessageExpiryInterval) parse(r *bufio.Reader) error { - val, err := decodeUint32(r) + val, err := decodeUint32(r) if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type ContentType struct { value *string @@ -49,14 +47,13 @@ func (p ContentType) id() int { } func (p *ContentType) parse(r *bufio.Reader) error { - val, err := decodeUTF8String(r) + val, err := decodeUTF8String(r) if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type ResponseTopic struct { value *string @@ -67,14 +64,13 @@ func (p ResponseTopic) id() int { } func (p *ResponseTopic) parse(r *bufio.Reader) error { - val, err := decodeUTF8String(r) + val, err := decodeUTF8String(r) if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type CorrelationData struct { value *[]byte @@ -85,14 +81,13 @@ func (p CorrelationData) id() int { } func (p *CorrelationData) parse(r *bufio.Reader) error { - val, err := decodeBinaryData(r) + val, err := decodeBinaryData(r) if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type SubscriptionIdentifier struct { value *int @@ -103,14 +98,13 @@ func (p SubscriptionIdentifier) id() int { } func (p *SubscriptionIdentifier) parse(r *bufio.Reader) error { - val, err := decodeVariableByteInt(r) + val, err := decodeVariableByteInt(r) if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type SessionExpiryInterval struct { value *uint32 @@ -121,14 +115,13 @@ func (p SessionExpiryInterval) id() int { } func (p *SessionExpiryInterval) parse(r *bufio.Reader) error { - val, err := decodeUint32(r) + val, err := decodeUint32(r) if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type AssignedClientIdentifier struct { value *string @@ -139,14 +132,13 @@ func (p AssignedClientIdentifier) id() int { } func (p *AssignedClientIdentifier) parse(r *bufio.Reader) error { - val, err := decodeUTF8String(r) + val, err := decodeUTF8String(r) if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type ServerKeepAlive struct { value *uint16 @@ -157,14 +149,13 @@ func (p ServerKeepAlive) id() int { } func (p *ServerKeepAlive) parse(r *bufio.Reader) error { - val, err := decodeUint16(r) + val, err := decodeUint16(r) if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type AuthenticationMethod struct { value *string @@ -175,14 +166,13 @@ func (p AuthenticationMethod) id() int { } func (p *AuthenticationMethod) parse(r *bufio.Reader) error { - val, err := decodeUTF8String(r) + val, err := decodeUTF8String(r) if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type AuthenticationData struct { value *[]byte @@ -193,14 +183,13 @@ func (p AuthenticationData) id() int { } func (p *AuthenticationData) parse(r *bufio.Reader) error { - val, err := decodeBinaryData(r) + val, err := decodeBinaryData(r) if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type RequestProblemInformation struct { value *byte @@ -211,14 +200,13 @@ func (p RequestProblemInformation) id() int { } func (p *RequestProblemInformation) parse(r *bufio.Reader) error { - val, err := r.ReadByte() + val, err := r.ReadByte() if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type WillDelayInterval struct { value *uint32 @@ -229,14 +217,13 @@ func (p WillDelayInterval) id() int { } func (p *WillDelayInterval) parse(r *bufio.Reader) error { - val, err := decodeUint32(r) + val, err := decodeUint32(r) if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type RequestResponseInformation struct { value *byte @@ -247,14 +234,13 @@ func (p RequestResponseInformation) id() int { } func (p *RequestResponseInformation) parse(r *bufio.Reader) error { - val, err := r.ReadByte() + val, err := r.ReadByte() if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type ResponseInformation struct { value *string @@ -265,14 +251,13 @@ func (p ResponseInformation) id() int { } func (p *ResponseInformation) parse(r *bufio.Reader) error { - val, err := decodeUTF8String(r) + val, err := decodeUTF8String(r) if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type ServerReference struct { value *string @@ -283,14 +268,13 @@ func (p ServerReference) id() int { } func (p *ServerReference) parse(r *bufio.Reader) error { - val, err := decodeUTF8String(r) + val, err := decodeUTF8String(r) if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type ReasonString struct { value *string @@ -301,14 +285,13 @@ func (p ReasonString) id() int { } func (p *ReasonString) parse(r *bufio.Reader) error { - val, err := decodeUTF8String(r) + val, err := decodeUTF8String(r) if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type ReceiveMaximum struct { value *uint16 @@ -319,14 +302,13 @@ func (p ReceiveMaximum) id() int { } func (p *ReceiveMaximum) parse(r *bufio.Reader) error { - val, err := decodeUint16(r) + val, err := decodeUint16(r) if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type TopicAliasMaximum struct { value *uint16 @@ -337,14 +319,13 @@ func (p TopicAliasMaximum) id() int { } func (p *TopicAliasMaximum) parse(r *bufio.Reader) error { - val, err := decodeUint16(r) + val, err := decodeUint16(r) if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type TopicAlias struct { value *uint16 @@ -355,14 +336,13 @@ func (p TopicAlias) id() int { } func (p *TopicAlias) parse(r *bufio.Reader) error { - val, err := decodeUint16(r) + val, err := decodeUint16(r) if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type MaximumQoS struct { value *byte @@ -373,14 +353,13 @@ func (p MaximumQoS) id() int { } func (p *MaximumQoS) parse(r *bufio.Reader) error { - val, err := r.ReadByte() + val, err := r.ReadByte() if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type RetainAvailable struct { value *byte @@ -391,14 +370,13 @@ func (p RetainAvailable) id() int { } func (p *RetainAvailable) parse(r *bufio.Reader) error { - val, err := r.ReadByte() + val, err := r.ReadByte() if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type MaximumPacketSize struct { value *uint32 @@ -409,14 +387,13 @@ func (p MaximumPacketSize) id() int { } func (p *MaximumPacketSize) parse(r *bufio.Reader) error { - val, err := decodeUint32(r) + val, err := decodeUint32(r) if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type WildcardSubscriptionAvailable struct { value *byte @@ -427,14 +404,13 @@ func (p WildcardSubscriptionAvailable) id() int { } func (p *WildcardSubscriptionAvailable) parse(r *bufio.Reader) error { - val, err := r.ReadByte() + val, err := r.ReadByte() if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type SubscriptionIdentifierAvailable struct { value *byte @@ -445,14 +421,13 @@ func (p SubscriptionIdentifierAvailable) id() int { } func (p *SubscriptionIdentifierAvailable) parse(r *bufio.Reader) error { - val, err := r.ReadByte() + val, err := r.ReadByte() if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - type SharedSubscriptionAvailable struct { value *byte @@ -463,223 +438,250 @@ func (p SharedSubscriptionAvailable) id() int { } func (p *SharedSubscriptionAvailable) parse(r *bufio.Reader) error { - val, err := r.ReadByte() + val, err := r.ReadByte() if err != nil { return err } - p.value = &val - return nil + p.value = &val + return nil } - + type PublishPacketProperties struct { -PayloadFormatIndicator PayloadFormatIndicator -MessageExpiryInterval MessageExpiryInterval -ContentType ContentType -ResponseTopic ResponseTopic -CorrelationData CorrelationData -SubscriptionIdentifier SubscriptionIdentifier -TopicAlias TopicAlias -UserProperty UserProperty + PayloadFormatIndicator PayloadFormatIndicator + MessageExpiryInterval MessageExpiryInterval + ContentType ContentType + ResponseTopic ResponseTopic + CorrelationData CorrelationData + SubscriptionIdentifier SubscriptionIdentifier + TopicAlias TopicAlias + UserProperty UserProperty } + func (p *PublishPacketProperties) arrayOf() []Property { -return []Property { -&p.PayloadFormatIndicator, -&p.MessageExpiryInterval, -&p.ContentType, -&p.ResponseTopic, -&p.CorrelationData, -&p.SubscriptionIdentifier, -&p.TopicAlias, -&p.UserProperty, -} + return []Property{ + &p.PayloadFormatIndicator, + &p.MessageExpiryInterval, + &p.ContentType, + &p.ResponseTopic, + &p.CorrelationData, + &p.SubscriptionIdentifier, + &p.TopicAlias, + &p.UserProperty, + } } + type WillProperties struct { -PayloadFormatIndicator PayloadFormatIndicator -MessageExpiryInterval MessageExpiryInterval -ContentType ContentType -ResponseTopic ResponseTopic -CorrelationData CorrelationData -WillDelayInterval WillDelayInterval -UserProperty UserProperty + PayloadFormatIndicator PayloadFormatIndicator + MessageExpiryInterval MessageExpiryInterval + ContentType ContentType + ResponseTopic ResponseTopic + CorrelationData CorrelationData + WillDelayInterval WillDelayInterval + UserProperty UserProperty } + func (p *WillProperties) arrayOf() []Property { -return []Property { -&p.PayloadFormatIndicator, -&p.MessageExpiryInterval, -&p.ContentType, -&p.ResponseTopic, -&p.CorrelationData, -&p.WillDelayInterval, -&p.UserProperty, -} + return []Property{ + &p.PayloadFormatIndicator, + &p.MessageExpiryInterval, + &p.ContentType, + &p.ResponseTopic, + &p.CorrelationData, + &p.WillDelayInterval, + &p.UserProperty, + } } + type SubscribePacketProperties struct { -SubscriptionIdentifier SubscriptionIdentifier -UserProperty UserProperty + SubscriptionIdentifier SubscriptionIdentifier + UserProperty UserProperty } + func (p *SubscribePacketProperties) arrayOf() []Property { -return []Property { -&p.SubscriptionIdentifier, -&p.UserProperty, -} + return []Property{ + &p.SubscriptionIdentifier, + &p.UserProperty, + } } + type ConnectPacketProperties struct { -SessionExpiryInterval SessionExpiryInterval -AuthenticationMethod AuthenticationMethod -AuthenticationData AuthenticationData -RequestProblemInformation RequestProblemInformation -RequestResponseInformation RequestResponseInformation -ReceiveMaximum ReceiveMaximum -TopicAliasMaximum TopicAliasMaximum -UserProperty UserProperty -MaximumPacketSize MaximumPacketSize + SessionExpiryInterval SessionExpiryInterval + AuthenticationMethod AuthenticationMethod + AuthenticationData AuthenticationData + RequestProblemInformation RequestProblemInformation + RequestResponseInformation RequestResponseInformation + ReceiveMaximum ReceiveMaximum + TopicAliasMaximum TopicAliasMaximum + UserProperty UserProperty + MaximumPacketSize MaximumPacketSize } + func (p *ConnectPacketProperties) arrayOf() []Property { -return []Property { -&p.SessionExpiryInterval, -&p.AuthenticationMethod, -&p.AuthenticationData, -&p.RequestProblemInformation, -&p.RequestResponseInformation, -&p.ReceiveMaximum, -&p.TopicAliasMaximum, -&p.UserProperty, -&p.MaximumPacketSize, -} + return []Property{ + &p.SessionExpiryInterval, + &p.AuthenticationMethod, + &p.AuthenticationData, + &p.RequestProblemInformation, + &p.RequestResponseInformation, + &p.ReceiveMaximum, + &p.TopicAliasMaximum, + &p.UserProperty, + &p.MaximumPacketSize, + } } + type ConnackPacketProperties struct { -SessionExpiryInterval SessionExpiryInterval -AssignedClientIdentifier AssignedClientIdentifier -ServerKeepAlive ServerKeepAlive -AuthenticationMethod AuthenticationMethod -AuthenticationData AuthenticationData -ResponseInformation ResponseInformation -ServerReference ServerReference -ReasonString ReasonString -ReceiveMaximum ReceiveMaximum -TopicAliasMaximum TopicAliasMaximum -MaximumQoS MaximumQoS -RetainAvailable RetainAvailable -UserProperty UserProperty -MaximumPacketSize MaximumPacketSize -WildcardSubscriptionAvailable WildcardSubscriptionAvailable -SubscriptionIdentifierAvailable SubscriptionIdentifierAvailable -SharedSubscriptionAvailable SharedSubscriptionAvailable + SessionExpiryInterval SessionExpiryInterval + AssignedClientIdentifier AssignedClientIdentifier + ServerKeepAlive ServerKeepAlive + AuthenticationMethod AuthenticationMethod + AuthenticationData AuthenticationData + ResponseInformation ResponseInformation + ServerReference ServerReference + ReasonString ReasonString + ReceiveMaximum ReceiveMaximum + TopicAliasMaximum TopicAliasMaximum + MaximumQoS MaximumQoS + RetainAvailable RetainAvailable + UserProperty UserProperty + MaximumPacketSize MaximumPacketSize + WildcardSubscriptionAvailable WildcardSubscriptionAvailable + SubscriptionIdentifierAvailable SubscriptionIdentifierAvailable + SharedSubscriptionAvailable SharedSubscriptionAvailable } + func (p *ConnackPacketProperties) arrayOf() []Property { -return []Property { -&p.SessionExpiryInterval, -&p.AssignedClientIdentifier, -&p.ServerKeepAlive, -&p.AuthenticationMethod, -&p.AuthenticationData, -&p.ResponseInformation, -&p.ServerReference, -&p.ReasonString, -&p.ReceiveMaximum, -&p.TopicAliasMaximum, -&p.MaximumQoS, -&p.RetainAvailable, -&p.UserProperty, -&p.MaximumPacketSize, -&p.WildcardSubscriptionAvailable, -&p.SubscriptionIdentifierAvailable, -&p.SharedSubscriptionAvailable, -} + return []Property{ + &p.SessionExpiryInterval, + &p.AssignedClientIdentifier, + &p.ServerKeepAlive, + &p.AuthenticationMethod, + &p.AuthenticationData, + &p.ResponseInformation, + &p.ServerReference, + &p.ReasonString, + &p.ReceiveMaximum, + &p.TopicAliasMaximum, + &p.MaximumQoS, + &p.RetainAvailable, + &p.UserProperty, + &p.MaximumPacketSize, + &p.WildcardSubscriptionAvailable, + &p.SubscriptionIdentifierAvailable, + &p.SharedSubscriptionAvailable, + } } + type DisconnectPacketProperties struct { -SessionExpiryInterval SessionExpiryInterval -ServerReference ServerReference -ReasonString ReasonString -UserProperty UserProperty + SessionExpiryInterval SessionExpiryInterval + ServerReference ServerReference + ReasonString ReasonString + UserProperty UserProperty } + func (p *DisconnectPacketProperties) arrayOf() []Property { -return []Property { -&p.SessionExpiryInterval, -&p.ServerReference, -&p.ReasonString, -&p.UserProperty, -} + return []Property{ + &p.SessionExpiryInterval, + &p.ServerReference, + &p.ReasonString, + &p.UserProperty, + } } + type AuthPacketProperties struct { -AuthenticationMethod AuthenticationMethod -AuthenticationData AuthenticationData -ReasonString ReasonString -UserProperty UserProperty + AuthenticationMethod AuthenticationMethod + AuthenticationData AuthenticationData + ReasonString ReasonString + UserProperty UserProperty } + func (p *AuthPacketProperties) arrayOf() []Property { -return []Property { -&p.AuthenticationMethod, -&p.AuthenticationData, -&p.ReasonString, -&p.UserProperty, -} + return []Property{ + &p.AuthenticationMethod, + &p.AuthenticationData, + &p.ReasonString, + &p.UserProperty, + } } + type PubackPacketProperties struct { -ReasonString ReasonString -UserProperty UserProperty + ReasonString ReasonString + UserProperty UserProperty } + func (p *PubackPacketProperties) arrayOf() []Property { -return []Property { -&p.ReasonString, -&p.UserProperty, -} + return []Property{ + &p.ReasonString, + &p.UserProperty, + } } + type PubrecPacketProperties struct { -ReasonString ReasonString -UserProperty UserProperty + ReasonString ReasonString + UserProperty UserProperty } + func (p *PubrecPacketProperties) arrayOf() []Property { -return []Property { -&p.ReasonString, -&p.UserProperty, -} + return []Property{ + &p.ReasonString, + &p.UserProperty, + } } + type PubrelPacketProperties struct { -ReasonString ReasonString -UserProperty UserProperty + ReasonString ReasonString + UserProperty UserProperty } + func (p *PubrelPacketProperties) arrayOf() []Property { -return []Property { -&p.ReasonString, -&p.UserProperty, -} + return []Property{ + &p.ReasonString, + &p.UserProperty, + } } + type PubcompPacketProperties struct { -ReasonString ReasonString -UserProperty UserProperty + ReasonString ReasonString + UserProperty UserProperty } + func (p *PubcompPacketProperties) arrayOf() []Property { -return []Property { -&p.ReasonString, -&p.UserProperty, -} + return []Property{ + &p.ReasonString, + &p.UserProperty, + } } + type SubackPacketProperties struct { -ReasonString ReasonString -UserProperty UserProperty + ReasonString ReasonString + UserProperty UserProperty } + func (p *SubackPacketProperties) arrayOf() []Property { -return []Property { -&p.ReasonString, -&p.UserProperty, -} + return []Property{ + &p.ReasonString, + &p.UserProperty, + } } + type UnsubackPacketProperties struct { -ReasonString ReasonString -UserProperty UserProperty + ReasonString ReasonString + UserProperty UserProperty } + func (p *UnsubackPacketProperties) arrayOf() []Property { -return []Property { -&p.ReasonString, -&p.UserProperty, -} + return []Property{ + &p.ReasonString, + &p.UserProperty, + } } + type UnsubscribePacketProperties struct { -UserProperty UserProperty + UserProperty UserProperty } + func (p *UnsubscribePacketProperties) arrayOf() []Property { -return []Property { -&p.UserProperty, -} + return []Property{ + &p.UserProperty, + } } diff --git a/mqtt/packets/Disconnect.go b/mqtt/packets/Disconnect.go index 3e0be31..d5c63ec 100644 --- a/mqtt/packets/Disconnect.go +++ b/mqtt/packets/Disconnect.go @@ -60,9 +60,8 @@ func parseDisconnectPacket(control controlPacket) (DisconnectPacket, error) { r := bufio.NewReader(control.reader) - // If there is less then a byte in the reader assume the reason code == 0 - reason,err := r.ReadByte() + reason, err := r.ReadByte() if err == io.EOF { reason = 0 } else if err != nil { @@ -73,7 +72,7 @@ func parseDisconnectPacket(control controlPacket) (DisconnectPacket, error) { // If there are less than 2 bytes remaining in the reader assume that the packet has no properties _, err = r.Peek(2) if err == nil { - err = properties.ParseProperties(r,packet.Properties.ArrayOf()) + err = properties.ParseProperties(r, packet.Properties.ArrayOf()) } else if err != io.EOF { return packet, err } else if err == io.EOF { @@ -91,10 +90,10 @@ func (p DisconnectPacket) Write(w io.Writer) error { return err } - control := controlPacket { + control := controlPacket{ packetType: PacketTypeDisconnect, - flags: 0, - reader: buf, + flags: 0, + reader: buf, } return control.write(w) } diff --git a/mqtt/packets/Ping.go b/mqtt/packets/Ping.go index d6dd1c3..a794e31 100644 --- a/mqtt/packets/Ping.go +++ b/mqtt/packets/Ping.go @@ -6,8 +6,7 @@ import ( "io" ) - -type PingreqPacket struct {} +type PingreqPacket struct{} func parsePingreq(control controlPacket) (PingreqPacket, error) { packet := PingreqPacket{} @@ -18,7 +17,7 @@ func parsePingreq(control controlPacket) (PingreqPacket, error) { if control.flags != 0 { return packet, errors.New("Malformed connect packet") } - + return packet, nil } @@ -26,13 +25,13 @@ func (r PingreqPacket) Visit(p PacketVisitor) { p.VisitPing(r) } -type PingrespPacket struct {} +type PingrespPacket struct{} func (p PingrespPacket) Write(w io.Writer) error { - control := controlPacket { + control := controlPacket{ packetType: PacketTypePingresp, - flags: 0, - reader: bytes.NewReader([]byte{}), + flags: 0, + reader: bytes.NewReader([]byte{}), } return control.write(w) diff --git a/mqtt/packets/PubAckRecRel.go b/mqtt/packets/PubAckRecRel.go index 15355ab..eaf851d 100644 --- a/mqtt/packets/PubAckRecRel.go +++ b/mqtt/packets/PubAckRecRel.go @@ -1,11 +1,10 @@ package packets import ( - "io" "badat.dev/maeqtt/v2/mqtt/properties" + "io" ) - type PubackReasonCode byte const ( diff --git a/mqtt/packets/Subscriptions.go b/mqtt/packets/Subscriptions.go index dc91052..19610c9 100644 --- a/mqtt/packets/Subscriptions.go +++ b/mqtt/packets/Subscriptions.go @@ -2,6 +2,7 @@ package packets import ( "bufio" + "bytes" "errors" "io" "strings" @@ -15,7 +16,8 @@ type Topic struct { } var multiLevelWildcardNotLast = errors.New("Multi level wildcard isn't the field in a topic") -func parseTopic(topic_name string) (Topic, error) { + +func ParseTopic(topic_name string) (Topic, error) { topic := Topic{} fields := strings.Split(topic_name, "/") for i, field := range fields { @@ -45,7 +47,7 @@ func parseTopicFilter(r *bufio.Reader) (TopicFilter, error) { return filter, err } - filter.Topic, err = parseTopic(topic_str) + filter.Topic, err = ParseTopic(topic_str) if err != nil { return filter, err } @@ -61,27 +63,32 @@ func parseTopicFilter(r *bufio.Reader) (TopicFilter, error) { return filter, nil } -// Both sub and unsubscribe packets are identitcal so we can reuse the parsing logic -type SubscriptionPacket struct { +type SubscribePacket struct { PacketId uint16 TopicFilters []TopicFilter + Properties properties.SubscribePacketProperties } -func parseSubscriptionPacket(control controlPacket, props []properties.Property) (SubscriptionPacket, error) { - var err error +func parseSubscribePacket(control controlPacket) (SubscribePacket, error) { + if control.packetType != PacketTypeSubscribe { + panic("Wrong packet type for parseSubscribePacket") + } + + packet := SubscribePacket{} + r := bufio.NewReader(control.reader) - packet := SubscriptionPacket{} if control.flags != 2 { return packet, errors.New("Malformed subscription packet") } + var err error packet.PacketId, err = types.DecodeUint16(r) if err != nil { return packet, err } - err = properties.ParseProperties(r, props) + err = properties.ParseProperties(r, packet.Properties.ArrayOf()) if err != nil { return packet, err } @@ -100,33 +107,10 @@ func parseSubscriptionPacket(control controlPacket, props []properties.Property) return packet, nil } } - println("A") return packet, nil } -type SubscribePacket struct { - *SubscriptionPacket - props properties.SubscribePacketProperties -} - -/// CURRENTLY BROKEN - -// TODO FIXME AAAAA -func parseSubscribePacket(control controlPacket) (SubscribePacket, error) { - if control.packetType != PacketTypeSubscribe { - panic("Wrong packet type for parseSubscribePacket") - } - - pack := SubscribePacket{} - subscriptionPack, err := parseSubscriptionPacket(control, pack.props.ArrayOf()) - if err != nil { - return pack, err - } - pack.SubscriptionPacket = &subscriptionPack - return pack, nil -} - func (p SubscribePacket) Visit(v PacketVisitor) { v.VisitSubscribe(p) } @@ -143,31 +127,46 @@ const ( SubackReasonTopicFilterInvalid = 143 SubackReasonPacketIDInUse = 145 SubackReasonQuotaExceeded = 151 - SubackReasonSharedSubNotSupported = 151 - SubackReasonSubIDUnsupported = 151 - SubackReasonWildcardSubUnsupported = 151 + SubackReasonSharedSubNotSupported = 158 + SubackReasonSubIDUnsupported = 161 + SubackReasonWildcardSubUnsupported = 162 ) type SubAckPacket struct { PacketID uint16 Properties properties.SubackPacketProperties - Reason SubackReasonCode + Reason SubackReasonCode } - func (p SubAckPacket) Write(w io.Writer) error { - resp := pubRespPacket{ - PacketType: PacketTypeSuback, - PacketID: p.PacketID, - Properties: p.Properties.ArrayOf(), - Reason: byte(p.Reason), + buf := bytes.NewBuffer([]byte{}) + err := types.WriteUint16(buf, p.PacketID) + if err != nil { + return err } - return resp.Write(w) + + err = properties.WriteProps(buf, p.Properties.ArrayOf()) + if err != nil { + return err + } + + err = buf.WriteByte(byte(p.Reason)) + if err != nil { + return err + } + + conPack := controlPacket{ + packetType: PacketTypeSuback, + flags: 0, + reader: buf, + } + return conPack.write(w) } type UnsubscribePacket struct { - *SubscriptionPacket - props properties.UnsubscribePacketProperties + PacketID uint16 + Topics []Topic + Properties properties.UnsubscribePacketProperties } func parseUnsubscribePacket(control controlPacket) (UnsubscribePacket, error) { @@ -175,14 +174,41 @@ func parseUnsubscribePacket(control controlPacket) (UnsubscribePacket, error) { panic("Wrong packet type for parseSubscribePacket") } - pack := UnsubscribePacket{} - subscriptionPack, err := parseSubscriptionPacket(control, pack.props.ArrayOf()) - if err != nil { - return pack, err + packet := UnsubscribePacket{} + r := bufio.NewReader(control.reader) + + if control.flags != 2 { + return packet, errors.New("Malformed subscription packet") } - pack.PacketId = subscriptionPack.PacketId - pack.TopicFilters = subscriptionPack.TopicFilters - return pack, nil + + var err error + packet.PacketID, err = types.DecodeUint16(r) + if err != nil { + return packet, err + } + + err = properties.ParseProperties(r, packet.Properties.ArrayOf()) + if err != nil { + return packet, err + } + + for err != io.EOF { + topic_str, err := types.DecodeUTF8String(r) + if err != nil && err != io.EOF { + return packet, err + } else if err == io.EOF { + return packet, nil + } + + filter, err := ParseTopic(topic_str) + if err != nil { + return packet, err + } + + packet.Topics = append(packet.Topics, filter) + } + + return packet, nil } func (p UnsubscribePacket) Visit(v PacketVisitor) { @@ -192,12 +218,12 @@ func (p UnsubscribePacket) Visit(v PacketVisitor) { type UnsubackReasonCode byte const ( - UnsubackReasonSuccess PubackReasonCode = 0 - UnSubackReasonUnspecified = 128 - UnSubackReasonImplSpecificError = 131 - UnSubackReasonNotAuthorized = 135 - UnSubackReasonTopicFilterInvalid = 143 - UnSubackReasonPacketIDInUse = 145 + UnsubackReasonSuccess UnsubackReasonCode = 0 + UnSubackReasonUnspecified = 128 + UnSubackReasonImplSpecificError = 131 + UnSubackReasonNotAuthorized = 135 + UnSubackReasonTopicFilterInvalid = 143 + UnSubackReasonPacketIDInUse = 145 ) type UnsubAckPacket struct { @@ -206,13 +232,27 @@ type UnsubAckPacket struct { Reason UnsubackReasonCode } - func (p UnsubAckPacket) Write(w io.Writer) error { - resp := pubRespPacket{ - PacketType: PacketTypeUnsuback, - PacketID: p.PacketID, - Properties: p.Properties.ArrayOf(), - Reason: byte(p.Reason), + buf := bytes.NewBuffer([]byte{}) + err := types.WriteUint16(buf, p.PacketID) + if err != nil { + return err } - return resp.Write(w) + + err = properties.WriteProps(buf, p.Properties.ArrayOf()) + if err != nil { + return err + } + + err = buf.WriteByte(byte(p.Reason)) + if err != nil { + return err + } + + conPack := controlPacket{ + packetType: PacketTypeUnsuback, + flags: 0, + reader: buf, + } + return conPack.write(w) } diff --git a/mqtt/packets/ack.go b/mqtt/packets/ack.go index 61533a3..e84bb16 100644 --- a/mqtt/packets/ack.go +++ b/mqtt/packets/ack.go @@ -34,7 +34,7 @@ func (p pubRespPacket) Write(w io.Writer) error { } conPack := controlPacket{ - packetType: PacketTypePuback, + packetType: p.PacketType, flags: 0, reader: buf, } diff --git a/mqtt/packets/packets.go b/mqtt/packets/packets.go index 455c855..3a0ac72 100644 --- a/mqtt/packets/packets.go +++ b/mqtt/packets/packets.go @@ -64,7 +64,7 @@ func ReadPacket(r *bufio.Reader) (*ClientPacket, error) { return nil, err } reader := io.LimitReader(r, int64(dataLength)) - + control := controlPacket{ packetType: PacketType(highestFourBits), flags: lowerFourBits, diff --git a/mqtt/types/Decoding.go b/mqtt/types/Decoding.go index edd20e0..50111b8 100644 --- a/mqtt/types/Decoding.go +++ b/mqtt/types/Decoding.go @@ -57,7 +57,6 @@ func DecodeBinaryData(r *bufio.Reader) ([]byte, error) { return buffer, err } - func DecodeUTF8String(r *bufio.Reader) (string, error) { binary, err := DecodeBinaryData(r) return string(binary[:]), err diff --git a/mqtt/types/Encoding.go b/mqtt/types/Encoding.go index dd8dae1..a07baee 100644 --- a/mqtt/types/Encoding.go +++ b/mqtt/types/Encoding.go @@ -29,8 +29,8 @@ func WriteUint32(w io.Writer, v uint32) error { return err } - const uint32Max uint32 = ^uint32(0) + func WriteDataWithVarIntLen(w io.Writer, data []byte) error { if len(data) > int(uint32Max) { return errors.New("Tried to write more data than max varint size") @@ -46,6 +46,7 @@ func WriteDataWithVarIntLen(w io.Writer, data []byte) error { } const uint16Max uint16 = ^uint16(0) + func WriteBinaryData(w io.Writer, data []byte) error { if len(data) > int(uint16Max) { return errors.New("Tried to write more data than max uint16 size") @@ -64,7 +65,6 @@ func WriteUTF8String(w io.Writer, str string) error { return WriteBinaryData(w, []byte(str)) } - func WriteVariableByteInt(w io.Writer, v uint32) error { for { encodedByte := byte(v % 128) diff --git a/connection.go b/session/connection.go similarity index 99% rename from connection.go rename to session/connection.go index 1df95d4..d9eff1e 100644 --- a/connection.go +++ b/session/connection.go @@ -1,4 +1,4 @@ -package main +package session import ( "bufio" diff --git a/session.go b/session/session.go similarity index 60% rename from session.go rename to session/session.go index f73c545..65dcfab 100644 --- a/session.go +++ b/session/session.go @@ -1,13 +1,15 @@ -package main +package session import ( "encoding/base64" "fmt" + "io" "log" "math/rand" "time" "badat.dev/maeqtt/v2/mqtt/packets" + "badat.dev/maeqtt/v2/subscription" ) func init() { @@ -25,8 +27,10 @@ type Session struct { Connection *Connection SubscriptionChannel chan packets.PublishPacket - ExpiryInterval time.Duration - expireTimer time.Timer // TODO + ExpiryInterval time.Duration // TODO + expireTimer time.Timer // TODO + + freePacketID uint16 } func NewSession(conn *Connection, p packets.ConnectPacket) Session { @@ -39,8 +43,7 @@ func NewSession(conn *Connection, p packets.ConnectPacket) Session { func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) { if s.Connection != nil { - //TODO - panic("Disconnect if already have a connection, unimplemented") + s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver) } connAck := packets.ConnackPacket{} @@ -60,7 +63,6 @@ func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) { connAck.Properties.RetainAvailable.Value = &false connAck.Properties.SharedSubscriptionAvailable.Value = &false - s.Connection = conn err := s.Connection.sendPacket(connAck) if err != nil { @@ -75,7 +77,7 @@ func (s *Session) HandlerLoop() { case packet := <-s.Connection.PacketChannel: packet.Visit(s) case _ = <-s.Connection.ClientDisconnectedChan: - s.OnDisconnect() + s.onDisconnect() case subMessage := <-s.SubscriptionChannel: //TODO, log for now log.Printf("Recieved subscription message, handling UNIMPLEMENTED, message: %v", subMessage) @@ -83,22 +85,29 @@ func (s *Session) HandlerLoop() { } } -func (s *Session) Disconnect() error { - panic("Disconnection unimplemented") +func (s *Session) Disconnect(code packets.DisconnectReasonCode) error { + s.Connection.sendPacket(packets.DisconnectPacket{ + ReasonCode: code, + }) + err := s.Connection.close() if err != nil { return err } - s.OnDisconnect() + s.onDisconnect() return nil } -func (s *Session) OnDisconnect() { +func (s *Session) onDisconnect() { s.Connection = nil s.resetExpireTimer() log.Printf("Client disconnected, id: %s", *s.ClientID) } +func (s *Session) expireSession() { + subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel) +} + // newTime is nullable func (s *Session) updateExpireTimer(newTime *uint32) { var expiry = uint32(0) @@ -129,35 +138,67 @@ func genClientID() *string { } func (s *Session) VisitConnect(_ packets.ConnectPacket) { - // ERROR CANNOT RECIEVE CONNECT ON AN ALREADY OPEN CONNECTION - s.Disconnect() + // Disconnect, we handle the connect packet in Connect, + // this means that we have an estabilished connection already + log.Println("WARN: Got a connect packet on an already estabilished connection") + s.Disconnect(packets.DisconnectReasonCodeProtocolError) } func (s *Session) VisitPublish(p packets.PublishPacket) { - println("UNIMPLEMENTED, Publishing packet, message:", string(p.Payload)) - subs, lock := Subscriptions.GetSubscriptions(p.TopicName) + subs, lock := subscription.Subscriptions.GetSubscriptions(p.TopicName) defer lock.Unlock() + if p.QOSLevel == 0 { + if p.PacketId != nil { + log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID) + return + } + } else if p.QOSLevel == 1 { + var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess + if len(subs) == 0 { + reason = packets.PubackReasonCodeNoMatchingSubscribers + } + ack := packets.PubackPacket{ + PacketID: *p.PacketId, + Reason: reason, + } + s.Connection.sendPacket(ack) + } else if p.QOSLevel == 2 { + panic("UNIMPLEMENTED QOS level 2") + } for _, sub := range subs { - go func(sub Subscription) {sub <- p}(sub) + if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) { + go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub) + } } } func (s *Session) VisitDisconnect(p packets.DisconnectPacket) { - //TODO FINISH - // HANDLE CLIENT DISCONNECTING - s.OnDisconnect() + err := s.Connection.close() + if err != nil && err != io.ErrClosedPipe { + log.Println("Error closing connection", err) + } + s.onDisconnect() } func (s *Session) VisitSubscribe(p packets.SubscribePacket) { - //TODO FINISH for _, filter := range p.TopicFilters { - Subscriptions.Subscribe(filter.Topic, s.SubscriptionChannel) + subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel) } + s.Connection.sendPacket(packets.SubAckPacket{ + PacketID: p.PacketId, + Reason: packets.SubackReasonGrantedQoSTwo, + }) } -func (s *Session) VisitUnsubscribe(_ packets.UnsubscribePacket) { - panic("not implemented") // TODO: Implement +func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) { + for _, topic := range p.Topics { + subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel) + } + s.Connection.sendPacket(packets.UnsubAckPacket{ + PacketID: p.PacketID, + Reason: packets.UnsubackReasonSuccess, + }) } func (s *Session) VisitPing(p packets.PingreqPacket) { @@ -179,3 +220,8 @@ func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) { func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) { panic("not implemented") // TODO: Implement } + +func (s *Session) getFreePacketId() uint16 { + s.freePacketID += 1 + return s.freePacketID +} diff --git a/subscription.go b/subscription.go deleted file mode 100644 index 63fee8b..0000000 --- a/subscription.go +++ /dev/null @@ -1,86 +0,0 @@ -package main - -//TODO FULLY IMPLEMENT SUBSCRIPTIONS INSTEAD OF JUST THE TOPIC FILTERS - -import ( - "strings" - "sync" - - "badat.dev/maeqtt/v2/mqtt/packets" -) - -var Subscriptions SubscriptionTreeNode = *NewSubscriptionTreeNode() - -type Subscription chan packets.PublishPacket - -type SubscriptionTreeNode struct { - subscriptions []Subscription - children map[string]*SubscriptionTreeNode - nodeLock sync.RWMutex -} -func NewSubscriptionTreeNode() *SubscriptionTreeNode { - s := SubscriptionTreeNode{} - s.children = make(map[string]*SubscriptionTreeNode) - return &s -} - -func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode { - if len(fields) == 0 { - return s - } - - field := fields[0] - - s.nodeLock.RLock() - _, exists := s.children[field] - // Insert a value into the map if one doesn't exist yet - if !exists { - // Can't upgrade a read lock so we need to unlock and - // check again, this time with a write lock - s.nodeLock.RUnlock() - s.nodeLock.Lock() - - _, exists = s.children[field] - if !exists { - s.children[field] = NewSubscriptionTreeNode() - } - s.nodeLock.Unlock() - s.nodeLock.RLock() - } - - child, _ := s.children[field] - s.nodeLock.RUnlock() - return child.findNode(fields[1:]) -} - -func (s *SubscriptionTreeNode) Subscribe(topic packets.Topic, sub Subscription) { - node := s.findNode(topic.Fields) - node.nodeLock.Lock() - node.subscriptions = append(node.subscriptions, sub) - node.nodeLock.Unlock() -} - -func (s *SubscriptionTreeNode) GetSubscriptions(topic string) ([]Subscription, sync.Locker) { - fields := strings.Split(topic,"/") - - child := s.findNode(fields) - locker := child.nodeLock.RLocker() - locker.Lock() - return child.subscriptions, locker -} - -func (s *SubscriptionTreeNode) findMatchingRec(topic []string) ([]Subscription, sync.Locker) { - locker := s.nodeLock.RLocker() - s.nodeLock.RLock() - if len(topic) == 0 { - return s.subscriptions,locker - } - defer s.nodeLock.RUnlock() - - child, exists := s.children[topic[0]] - if exists { - return child.findMatchingRec(topic[1:]) - } else { - return []Subscription{},locker - } -} diff --git a/subscription/subscription.go b/subscription/subscription.go new file mode 100644 index 0000000..191f859 --- /dev/null +++ b/subscription/subscription.go @@ -0,0 +1,106 @@ +package subscription + +//TODO WILDCARD SUBSCRIPTIONS + +import ( + "strings" + "sync" + + "badat.dev/maeqtt/v2/mqtt/packets" +) + +var Subscriptions SubscriptionTreeNode = *newSubscriptionTreeNode() + +type SubscriptionChannel chan packets.PublishPacket + +type Subscription struct { + SubscriptionChannel + packets.TopicFilter +} + +type SubscriptionTreeNode struct { + subscriptions []Subscription + children map[string]*SubscriptionTreeNode + nodeLock sync.RWMutex +} + +func newSubscriptionTreeNode() *SubscriptionTreeNode { + s := SubscriptionTreeNode{} + s.children = make(map[string]*SubscriptionTreeNode) + return &s +} + +func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode { + if len(fields) == 0 { + return s + } + + field := fields[0] + + s.nodeLock.RLock() + _, exists := s.children[field] + // Insert a value into the map if one doesn't exist yet + if !exists { + // Can't upgrade a read lock so we need to unlock and + // check again, this time with a write lock + s.nodeLock.RUnlock() + s.nodeLock.Lock() + + _, exists = s.children[field] + if !exists { + s.children[field] = newSubscriptionTreeNode() + } + s.nodeLock.Unlock() + s.nodeLock.RLock() + } + + child, _ := s.children[field] + s.nodeLock.RUnlock() + return child.findNode(fields[1:]) +} + +func (s *SubscriptionTreeNode) removeSubscription(subChan SubscriptionChannel) { + for i, sub := range s.subscriptions { + if sub.SubscriptionChannel == subChan { + lst := len(s.subscriptions) - 1 + s.subscriptions[i] = s.subscriptions[lst] + s.subscriptions = s.subscriptions[:lst] + } + } +} + +func (s *SubscriptionTreeNode) Subscribe(topicFilter packets.TopicFilter, subChan SubscriptionChannel) { + sub := Subscription{subChan, topicFilter} + + node := s.findNode(topicFilter.Topic.Fields) + node.nodeLock.Lock() + node.subscriptions = append(node.subscriptions, sub) + node.nodeLock.Unlock() +} + +func (s *SubscriptionTreeNode) Unsubscribe(topic packets.Topic, subChan SubscriptionChannel) { + node := s.findNode(topic.Fields) + + node.nodeLock.Lock() + node.removeSubscription(subChan) + node.nodeLock.Unlock() +} + +func (s *SubscriptionTreeNode) RemoveSubsForChannel(subChan SubscriptionChannel) { + for _, node := range s.children { + node.nodeLock.Lock() + node.removeSubscription(subChan) + node.nodeLock.Unlock() + + node.RemoveSubsForChannel(subChan) + } +} + +func (s *SubscriptionTreeNode) GetSubscriptions(topicName string) ([]Subscription, sync.Locker) { + fields := strings.Split(topicName, "/") + + child := s.findNode(fields) + locker := child.nodeLock.RLocker() + locker.Lock() + return child.subscriptions, locker +} diff --git a/subscription/subscription_test.go b/subscription/subscription_test.go new file mode 100644 index 0000000..dfc442b --- /dev/null +++ b/subscription/subscription_test.go @@ -0,0 +1,27 @@ +package subscription + +import ( + "testing" + + "badat.dev/maeqtt/v2/mqtt/packets" +) + +func TestSubscribe(t *testing.T) { + tree := newSubscriptionTreeNode() + topic, _ := packets.ParseTopic("a/b/c") + channel := make(SubscriptionChannel) + topicFilter := packets.TopicFilter{ + Topic: topic, + MaxQoS: 1, + } + tree.Subscribe(topicFilter, channel) + subs, lock := tree.GetSubscriptions("a/b/c") + defer lock.Unlock() + + if len(subs) != 1 { + t.Errorf("Error storing subscriptions, expected to len(subs) to be 1, got: %v \n", len(subs)) + } + if subs[0].MaxQoS != topicFilter.MaxQoS || subs[0].SubscriptionChannel != channel { + t.Error("Error with data stored in a subscription") + } +}