Improve folder strutcute

This commit is contained in:
bad 2021-10-01 22:18:48 +02:00
parent e9612e2430
commit 35879183be
15 changed files with 606 additions and 461 deletions

26
main.go
View file

@ -7,6 +7,7 @@ import (
"runtime/debug" "runtime/debug"
"badat.dev/maeqtt/v2/mqtt/packets" "badat.dev/maeqtt/v2/mqtt/packets"
"badat.dev/maeqtt/v2/session"
) )
func main() { func main() {
@ -27,7 +28,7 @@ func main() {
} }
func handleConnection(con net.Conn) { func handleConnection(con net.Conn) {
defer handlePanic() defer handlePanic(con)
reader := bufio.NewReader(con) reader := bufio.NewReader(con)
@ -38,20 +39,33 @@ func handleConnection(con net.Conn) {
} }
connect, isConn := (*packet).(packets.ConnectPacket) connect, isConn := (*packet).(packets.ConnectPacket)
if !isConn { if !isConn {
log.Println("Didn't recieve a connet packet") log.Println("Didn't recieve a connect packet")
panic("TODO: Send a disconnect packet") err := packets.DisconnectPacket{
ReasonCode: packets.DisconnectReasonCodeProtocolError,
}.Write(con)
if err != nil {
log.Println("Failed to disconnect after not recieving a connect packet", err)
}
return
} }
conn := NewConnection(connect, con) conn := session.NewConnection(connect, con)
sess := session.NewSession(&conn, connect)
sess := NewSession(&conn, connect)
sess.HandlerLoop() sess.HandlerLoop()
} }
func handlePanic() { func handlePanic(con net.Conn) {
if r := recover(); r != nil { if r := recover(); r != nil {
log.Println("Recovering from panic:", r) log.Println("Recovering from panic:", r)
log.Println("Stack Trace:") log.Println("Stack Trace:")
debug.PrintStack() debug.PrintStack()
err := packets.DisconnectPacket{
ReasonCode: packets.DisconnectReasonCodeImplErorr,
}.Write(con)
if err != nil {
log.Println("Failed to send a disconnect packet after recovering from panic", err)
}
} }
} }

View file

@ -1,6 +1,6 @@
package mqtt package mqtt
// This code has been generated with the genProps.py script. Do not modify
// This code has been generated with the genProps.py script. Do not modify
import "bufio" import "bufio"
@ -13,14 +13,13 @@ func (p PayloadFormatIndicator) id() int {
} }
func (p *PayloadFormatIndicator) parse(r *bufio.Reader) error { func (p *PayloadFormatIndicator) parse(r *bufio.Reader) error {
val, err := r.ReadByte() val, err := r.ReadByte()
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type MessageExpiryInterval struct { type MessageExpiryInterval struct {
value *uint32 value *uint32
@ -31,14 +30,13 @@ func (p MessageExpiryInterval) id() int {
} }
func (p *MessageExpiryInterval) parse(r *bufio.Reader) error { func (p *MessageExpiryInterval) parse(r *bufio.Reader) error {
val, err := decodeUint32(r) val, err := decodeUint32(r)
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type ContentType struct { type ContentType struct {
value *string value *string
@ -49,14 +47,13 @@ func (p ContentType) id() int {
} }
func (p *ContentType) parse(r *bufio.Reader) error { func (p *ContentType) parse(r *bufio.Reader) error {
val, err := decodeUTF8String(r) val, err := decodeUTF8String(r)
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type ResponseTopic struct { type ResponseTopic struct {
value *string value *string
@ -67,14 +64,13 @@ func (p ResponseTopic) id() int {
} }
func (p *ResponseTopic) parse(r *bufio.Reader) error { func (p *ResponseTopic) parse(r *bufio.Reader) error {
val, err := decodeUTF8String(r) val, err := decodeUTF8String(r)
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type CorrelationData struct { type CorrelationData struct {
value *[]byte value *[]byte
@ -85,14 +81,13 @@ func (p CorrelationData) id() int {
} }
func (p *CorrelationData) parse(r *bufio.Reader) error { func (p *CorrelationData) parse(r *bufio.Reader) error {
val, err := decodeBinaryData(r) val, err := decodeBinaryData(r)
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type SubscriptionIdentifier struct { type SubscriptionIdentifier struct {
value *int value *int
@ -103,14 +98,13 @@ func (p SubscriptionIdentifier) id() int {
} }
func (p *SubscriptionIdentifier) parse(r *bufio.Reader) error { func (p *SubscriptionIdentifier) parse(r *bufio.Reader) error {
val, err := decodeVariableByteInt(r) val, err := decodeVariableByteInt(r)
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type SessionExpiryInterval struct { type SessionExpiryInterval struct {
value *uint32 value *uint32
@ -121,14 +115,13 @@ func (p SessionExpiryInterval) id() int {
} }
func (p *SessionExpiryInterval) parse(r *bufio.Reader) error { func (p *SessionExpiryInterval) parse(r *bufio.Reader) error {
val, err := decodeUint32(r) val, err := decodeUint32(r)
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type AssignedClientIdentifier struct { type AssignedClientIdentifier struct {
value *string value *string
@ -139,14 +132,13 @@ func (p AssignedClientIdentifier) id() int {
} }
func (p *AssignedClientIdentifier) parse(r *bufio.Reader) error { func (p *AssignedClientIdentifier) parse(r *bufio.Reader) error {
val, err := decodeUTF8String(r) val, err := decodeUTF8String(r)
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type ServerKeepAlive struct { type ServerKeepAlive struct {
value *uint16 value *uint16
@ -157,14 +149,13 @@ func (p ServerKeepAlive) id() int {
} }
func (p *ServerKeepAlive) parse(r *bufio.Reader) error { func (p *ServerKeepAlive) parse(r *bufio.Reader) error {
val, err := decodeUint16(r) val, err := decodeUint16(r)
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type AuthenticationMethod struct { type AuthenticationMethod struct {
value *string value *string
@ -175,14 +166,13 @@ func (p AuthenticationMethod) id() int {
} }
func (p *AuthenticationMethod) parse(r *bufio.Reader) error { func (p *AuthenticationMethod) parse(r *bufio.Reader) error {
val, err := decodeUTF8String(r) val, err := decodeUTF8String(r)
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type AuthenticationData struct { type AuthenticationData struct {
value *[]byte value *[]byte
@ -193,14 +183,13 @@ func (p AuthenticationData) id() int {
} }
func (p *AuthenticationData) parse(r *bufio.Reader) error { func (p *AuthenticationData) parse(r *bufio.Reader) error {
val, err := decodeBinaryData(r) val, err := decodeBinaryData(r)
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type RequestProblemInformation struct { type RequestProblemInformation struct {
value *byte value *byte
@ -211,14 +200,13 @@ func (p RequestProblemInformation) id() int {
} }
func (p *RequestProblemInformation) parse(r *bufio.Reader) error { func (p *RequestProblemInformation) parse(r *bufio.Reader) error {
val, err := r.ReadByte() val, err := r.ReadByte()
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type WillDelayInterval struct { type WillDelayInterval struct {
value *uint32 value *uint32
@ -229,14 +217,13 @@ func (p WillDelayInterval) id() int {
} }
func (p *WillDelayInterval) parse(r *bufio.Reader) error { func (p *WillDelayInterval) parse(r *bufio.Reader) error {
val, err := decodeUint32(r) val, err := decodeUint32(r)
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type RequestResponseInformation struct { type RequestResponseInformation struct {
value *byte value *byte
@ -247,14 +234,13 @@ func (p RequestResponseInformation) id() int {
} }
func (p *RequestResponseInformation) parse(r *bufio.Reader) error { func (p *RequestResponseInformation) parse(r *bufio.Reader) error {
val, err := r.ReadByte() val, err := r.ReadByte()
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type ResponseInformation struct { type ResponseInformation struct {
value *string value *string
@ -265,14 +251,13 @@ func (p ResponseInformation) id() int {
} }
func (p *ResponseInformation) parse(r *bufio.Reader) error { func (p *ResponseInformation) parse(r *bufio.Reader) error {
val, err := decodeUTF8String(r) val, err := decodeUTF8String(r)
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type ServerReference struct { type ServerReference struct {
value *string value *string
@ -283,14 +268,13 @@ func (p ServerReference) id() int {
} }
func (p *ServerReference) parse(r *bufio.Reader) error { func (p *ServerReference) parse(r *bufio.Reader) error {
val, err := decodeUTF8String(r) val, err := decodeUTF8String(r)
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type ReasonString struct { type ReasonString struct {
value *string value *string
@ -301,14 +285,13 @@ func (p ReasonString) id() int {
} }
func (p *ReasonString) parse(r *bufio.Reader) error { func (p *ReasonString) parse(r *bufio.Reader) error {
val, err := decodeUTF8String(r) val, err := decodeUTF8String(r)
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type ReceiveMaximum struct { type ReceiveMaximum struct {
value *uint16 value *uint16
@ -319,14 +302,13 @@ func (p ReceiveMaximum) id() int {
} }
func (p *ReceiveMaximum) parse(r *bufio.Reader) error { func (p *ReceiveMaximum) parse(r *bufio.Reader) error {
val, err := decodeUint16(r) val, err := decodeUint16(r)
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type TopicAliasMaximum struct { type TopicAliasMaximum struct {
value *uint16 value *uint16
@ -337,14 +319,13 @@ func (p TopicAliasMaximum) id() int {
} }
func (p *TopicAliasMaximum) parse(r *bufio.Reader) error { func (p *TopicAliasMaximum) parse(r *bufio.Reader) error {
val, err := decodeUint16(r) val, err := decodeUint16(r)
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type TopicAlias struct { type TopicAlias struct {
value *uint16 value *uint16
@ -355,14 +336,13 @@ func (p TopicAlias) id() int {
} }
func (p *TopicAlias) parse(r *bufio.Reader) error { func (p *TopicAlias) parse(r *bufio.Reader) error {
val, err := decodeUint16(r) val, err := decodeUint16(r)
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type MaximumQoS struct { type MaximumQoS struct {
value *byte value *byte
@ -373,14 +353,13 @@ func (p MaximumQoS) id() int {
} }
func (p *MaximumQoS) parse(r *bufio.Reader) error { func (p *MaximumQoS) parse(r *bufio.Reader) error {
val, err := r.ReadByte() val, err := r.ReadByte()
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type RetainAvailable struct { type RetainAvailable struct {
value *byte value *byte
@ -391,14 +370,13 @@ func (p RetainAvailable) id() int {
} }
func (p *RetainAvailable) parse(r *bufio.Reader) error { func (p *RetainAvailable) parse(r *bufio.Reader) error {
val, err := r.ReadByte() val, err := r.ReadByte()
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type MaximumPacketSize struct { type MaximumPacketSize struct {
value *uint32 value *uint32
@ -409,14 +387,13 @@ func (p MaximumPacketSize) id() int {
} }
func (p *MaximumPacketSize) parse(r *bufio.Reader) error { func (p *MaximumPacketSize) parse(r *bufio.Reader) error {
val, err := decodeUint32(r) val, err := decodeUint32(r)
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type WildcardSubscriptionAvailable struct { type WildcardSubscriptionAvailable struct {
value *byte value *byte
@ -427,14 +404,13 @@ func (p WildcardSubscriptionAvailable) id() int {
} }
func (p *WildcardSubscriptionAvailable) parse(r *bufio.Reader) error { func (p *WildcardSubscriptionAvailable) parse(r *bufio.Reader) error {
val, err := r.ReadByte() val, err := r.ReadByte()
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type SubscriptionIdentifierAvailable struct { type SubscriptionIdentifierAvailable struct {
value *byte value *byte
@ -445,14 +421,13 @@ func (p SubscriptionIdentifierAvailable) id() int {
} }
func (p *SubscriptionIdentifierAvailable) parse(r *bufio.Reader) error { func (p *SubscriptionIdentifierAvailable) parse(r *bufio.Reader) error {
val, err := r.ReadByte() val, err := r.ReadByte()
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type SharedSubscriptionAvailable struct { type SharedSubscriptionAvailable struct {
value *byte value *byte
@ -463,223 +438,250 @@ func (p SharedSubscriptionAvailable) id() int {
} }
func (p *SharedSubscriptionAvailable) parse(r *bufio.Reader) error { func (p *SharedSubscriptionAvailable) parse(r *bufio.Reader) error {
val, err := r.ReadByte() val, err := r.ReadByte()
if err != nil { if err != nil {
return err return err
} }
p.value = &val p.value = &val
return nil return nil
} }
type PublishPacketProperties struct { type PublishPacketProperties struct {
PayloadFormatIndicator PayloadFormatIndicator PayloadFormatIndicator PayloadFormatIndicator
MessageExpiryInterval MessageExpiryInterval MessageExpiryInterval MessageExpiryInterval
ContentType ContentType ContentType ContentType
ResponseTopic ResponseTopic ResponseTopic ResponseTopic
CorrelationData CorrelationData CorrelationData CorrelationData
SubscriptionIdentifier SubscriptionIdentifier SubscriptionIdentifier SubscriptionIdentifier
TopicAlias TopicAlias TopicAlias TopicAlias
UserProperty UserProperty UserProperty UserProperty
} }
func (p *PublishPacketProperties) arrayOf() []Property { func (p *PublishPacketProperties) arrayOf() []Property {
return []Property { return []Property{
&p.PayloadFormatIndicator, &p.PayloadFormatIndicator,
&p.MessageExpiryInterval, &p.MessageExpiryInterval,
&p.ContentType, &p.ContentType,
&p.ResponseTopic, &p.ResponseTopic,
&p.CorrelationData, &p.CorrelationData,
&p.SubscriptionIdentifier, &p.SubscriptionIdentifier,
&p.TopicAlias, &p.TopicAlias,
&p.UserProperty, &p.UserProperty,
} }
} }
type WillProperties struct { type WillProperties struct {
PayloadFormatIndicator PayloadFormatIndicator PayloadFormatIndicator PayloadFormatIndicator
MessageExpiryInterval MessageExpiryInterval MessageExpiryInterval MessageExpiryInterval
ContentType ContentType ContentType ContentType
ResponseTopic ResponseTopic ResponseTopic ResponseTopic
CorrelationData CorrelationData CorrelationData CorrelationData
WillDelayInterval WillDelayInterval WillDelayInterval WillDelayInterval
UserProperty UserProperty UserProperty UserProperty
} }
func (p *WillProperties) arrayOf() []Property { func (p *WillProperties) arrayOf() []Property {
return []Property { return []Property{
&p.PayloadFormatIndicator, &p.PayloadFormatIndicator,
&p.MessageExpiryInterval, &p.MessageExpiryInterval,
&p.ContentType, &p.ContentType,
&p.ResponseTopic, &p.ResponseTopic,
&p.CorrelationData, &p.CorrelationData,
&p.WillDelayInterval, &p.WillDelayInterval,
&p.UserProperty, &p.UserProperty,
} }
} }
type SubscribePacketProperties struct { type SubscribePacketProperties struct {
SubscriptionIdentifier SubscriptionIdentifier SubscriptionIdentifier SubscriptionIdentifier
UserProperty UserProperty UserProperty UserProperty
} }
func (p *SubscribePacketProperties) arrayOf() []Property { func (p *SubscribePacketProperties) arrayOf() []Property {
return []Property { return []Property{
&p.SubscriptionIdentifier, &p.SubscriptionIdentifier,
&p.UserProperty, &p.UserProperty,
} }
} }
type ConnectPacketProperties struct { type ConnectPacketProperties struct {
SessionExpiryInterval SessionExpiryInterval SessionExpiryInterval SessionExpiryInterval
AuthenticationMethod AuthenticationMethod AuthenticationMethod AuthenticationMethod
AuthenticationData AuthenticationData AuthenticationData AuthenticationData
RequestProblemInformation RequestProblemInformation RequestProblemInformation RequestProblemInformation
RequestResponseInformation RequestResponseInformation RequestResponseInformation RequestResponseInformation
ReceiveMaximum ReceiveMaximum ReceiveMaximum ReceiveMaximum
TopicAliasMaximum TopicAliasMaximum TopicAliasMaximum TopicAliasMaximum
UserProperty UserProperty UserProperty UserProperty
MaximumPacketSize MaximumPacketSize MaximumPacketSize MaximumPacketSize
} }
func (p *ConnectPacketProperties) arrayOf() []Property { func (p *ConnectPacketProperties) arrayOf() []Property {
return []Property { return []Property{
&p.SessionExpiryInterval, &p.SessionExpiryInterval,
&p.AuthenticationMethod, &p.AuthenticationMethod,
&p.AuthenticationData, &p.AuthenticationData,
&p.RequestProblemInformation, &p.RequestProblemInformation,
&p.RequestResponseInformation, &p.RequestResponseInformation,
&p.ReceiveMaximum, &p.ReceiveMaximum,
&p.TopicAliasMaximum, &p.TopicAliasMaximum,
&p.UserProperty, &p.UserProperty,
&p.MaximumPacketSize, &p.MaximumPacketSize,
} }
} }
type ConnackPacketProperties struct { type ConnackPacketProperties struct {
SessionExpiryInterval SessionExpiryInterval SessionExpiryInterval SessionExpiryInterval
AssignedClientIdentifier AssignedClientIdentifier AssignedClientIdentifier AssignedClientIdentifier
ServerKeepAlive ServerKeepAlive ServerKeepAlive ServerKeepAlive
AuthenticationMethod AuthenticationMethod AuthenticationMethod AuthenticationMethod
AuthenticationData AuthenticationData AuthenticationData AuthenticationData
ResponseInformation ResponseInformation ResponseInformation ResponseInformation
ServerReference ServerReference ServerReference ServerReference
ReasonString ReasonString ReasonString ReasonString
ReceiveMaximum ReceiveMaximum ReceiveMaximum ReceiveMaximum
TopicAliasMaximum TopicAliasMaximum TopicAliasMaximum TopicAliasMaximum
MaximumQoS MaximumQoS MaximumQoS MaximumQoS
RetainAvailable RetainAvailable RetainAvailable RetainAvailable
UserProperty UserProperty UserProperty UserProperty
MaximumPacketSize MaximumPacketSize MaximumPacketSize MaximumPacketSize
WildcardSubscriptionAvailable WildcardSubscriptionAvailable WildcardSubscriptionAvailable WildcardSubscriptionAvailable
SubscriptionIdentifierAvailable SubscriptionIdentifierAvailable SubscriptionIdentifierAvailable SubscriptionIdentifierAvailable
SharedSubscriptionAvailable SharedSubscriptionAvailable SharedSubscriptionAvailable SharedSubscriptionAvailable
} }
func (p *ConnackPacketProperties) arrayOf() []Property { func (p *ConnackPacketProperties) arrayOf() []Property {
return []Property { return []Property{
&p.SessionExpiryInterval, &p.SessionExpiryInterval,
&p.AssignedClientIdentifier, &p.AssignedClientIdentifier,
&p.ServerKeepAlive, &p.ServerKeepAlive,
&p.AuthenticationMethod, &p.AuthenticationMethod,
&p.AuthenticationData, &p.AuthenticationData,
&p.ResponseInformation, &p.ResponseInformation,
&p.ServerReference, &p.ServerReference,
&p.ReasonString, &p.ReasonString,
&p.ReceiveMaximum, &p.ReceiveMaximum,
&p.TopicAliasMaximum, &p.TopicAliasMaximum,
&p.MaximumQoS, &p.MaximumQoS,
&p.RetainAvailable, &p.RetainAvailable,
&p.UserProperty, &p.UserProperty,
&p.MaximumPacketSize, &p.MaximumPacketSize,
&p.WildcardSubscriptionAvailable, &p.WildcardSubscriptionAvailable,
&p.SubscriptionIdentifierAvailable, &p.SubscriptionIdentifierAvailable,
&p.SharedSubscriptionAvailable, &p.SharedSubscriptionAvailable,
} }
} }
type DisconnectPacketProperties struct { type DisconnectPacketProperties struct {
SessionExpiryInterval SessionExpiryInterval SessionExpiryInterval SessionExpiryInterval
ServerReference ServerReference ServerReference ServerReference
ReasonString ReasonString ReasonString ReasonString
UserProperty UserProperty UserProperty UserProperty
} }
func (p *DisconnectPacketProperties) arrayOf() []Property { func (p *DisconnectPacketProperties) arrayOf() []Property {
return []Property { return []Property{
&p.SessionExpiryInterval, &p.SessionExpiryInterval,
&p.ServerReference, &p.ServerReference,
&p.ReasonString, &p.ReasonString,
&p.UserProperty, &p.UserProperty,
} }
} }
type AuthPacketProperties struct { type AuthPacketProperties struct {
AuthenticationMethod AuthenticationMethod AuthenticationMethod AuthenticationMethod
AuthenticationData AuthenticationData AuthenticationData AuthenticationData
ReasonString ReasonString ReasonString ReasonString
UserProperty UserProperty UserProperty UserProperty
} }
func (p *AuthPacketProperties) arrayOf() []Property { func (p *AuthPacketProperties) arrayOf() []Property {
return []Property { return []Property{
&p.AuthenticationMethod, &p.AuthenticationMethod,
&p.AuthenticationData, &p.AuthenticationData,
&p.ReasonString, &p.ReasonString,
&p.UserProperty, &p.UserProperty,
} }
} }
type PubackPacketProperties struct { type PubackPacketProperties struct {
ReasonString ReasonString ReasonString ReasonString
UserProperty UserProperty UserProperty UserProperty
} }
func (p *PubackPacketProperties) arrayOf() []Property { func (p *PubackPacketProperties) arrayOf() []Property {
return []Property { return []Property{
&p.ReasonString, &p.ReasonString,
&p.UserProperty, &p.UserProperty,
} }
} }
type PubrecPacketProperties struct { type PubrecPacketProperties struct {
ReasonString ReasonString ReasonString ReasonString
UserProperty UserProperty UserProperty UserProperty
} }
func (p *PubrecPacketProperties) arrayOf() []Property { func (p *PubrecPacketProperties) arrayOf() []Property {
return []Property { return []Property{
&p.ReasonString, &p.ReasonString,
&p.UserProperty, &p.UserProperty,
} }
} }
type PubrelPacketProperties struct { type PubrelPacketProperties struct {
ReasonString ReasonString ReasonString ReasonString
UserProperty UserProperty UserProperty UserProperty
} }
func (p *PubrelPacketProperties) arrayOf() []Property { func (p *PubrelPacketProperties) arrayOf() []Property {
return []Property { return []Property{
&p.ReasonString, &p.ReasonString,
&p.UserProperty, &p.UserProperty,
} }
} }
type PubcompPacketProperties struct { type PubcompPacketProperties struct {
ReasonString ReasonString ReasonString ReasonString
UserProperty UserProperty UserProperty UserProperty
} }
func (p *PubcompPacketProperties) arrayOf() []Property { func (p *PubcompPacketProperties) arrayOf() []Property {
return []Property { return []Property{
&p.ReasonString, &p.ReasonString,
&p.UserProperty, &p.UserProperty,
} }
} }
type SubackPacketProperties struct { type SubackPacketProperties struct {
ReasonString ReasonString ReasonString ReasonString
UserProperty UserProperty UserProperty UserProperty
} }
func (p *SubackPacketProperties) arrayOf() []Property { func (p *SubackPacketProperties) arrayOf() []Property {
return []Property { return []Property{
&p.ReasonString, &p.ReasonString,
&p.UserProperty, &p.UserProperty,
} }
} }
type UnsubackPacketProperties struct { type UnsubackPacketProperties struct {
ReasonString ReasonString ReasonString ReasonString
UserProperty UserProperty UserProperty UserProperty
} }
func (p *UnsubackPacketProperties) arrayOf() []Property { func (p *UnsubackPacketProperties) arrayOf() []Property {
return []Property { return []Property{
&p.ReasonString, &p.ReasonString,
&p.UserProperty, &p.UserProperty,
} }
} }
type UnsubscribePacketProperties struct { type UnsubscribePacketProperties struct {
UserProperty UserProperty UserProperty UserProperty
} }
func (p *UnsubscribePacketProperties) arrayOf() []Property { func (p *UnsubscribePacketProperties) arrayOf() []Property {
return []Property { return []Property{
&p.UserProperty, &p.UserProperty,
} }
} }

View file

@ -60,9 +60,8 @@ func parseDisconnectPacket(control controlPacket) (DisconnectPacket, error) {
r := bufio.NewReader(control.reader) r := bufio.NewReader(control.reader)
// If there is less then a byte in the reader assume the reason code == 0 // If there is less then a byte in the reader assume the reason code == 0
reason,err := r.ReadByte() reason, err := r.ReadByte()
if err == io.EOF { if err == io.EOF {
reason = 0 reason = 0
} else if err != nil { } else if err != nil {
@ -73,7 +72,7 @@ func parseDisconnectPacket(control controlPacket) (DisconnectPacket, error) {
// If there are less than 2 bytes remaining in the reader assume that the packet has no properties // If there are less than 2 bytes remaining in the reader assume that the packet has no properties
_, err = r.Peek(2) _, err = r.Peek(2)
if err == nil { if err == nil {
err = properties.ParseProperties(r,packet.Properties.ArrayOf()) err = properties.ParseProperties(r, packet.Properties.ArrayOf())
} else if err != io.EOF { } else if err != io.EOF {
return packet, err return packet, err
} else if err == io.EOF { } else if err == io.EOF {
@ -91,10 +90,10 @@ func (p DisconnectPacket) Write(w io.Writer) error {
return err return err
} }
control := controlPacket { control := controlPacket{
packetType: PacketTypeDisconnect, packetType: PacketTypeDisconnect,
flags: 0, flags: 0,
reader: buf, reader: buf,
} }
return control.write(w) return control.write(w)
} }

View file

@ -6,8 +6,7 @@ import (
"io" "io"
) )
type PingreqPacket struct{}
type PingreqPacket struct {}
func parsePingreq(control controlPacket) (PingreqPacket, error) { func parsePingreq(control controlPacket) (PingreqPacket, error) {
packet := PingreqPacket{} packet := PingreqPacket{}
@ -18,7 +17,7 @@ func parsePingreq(control controlPacket) (PingreqPacket, error) {
if control.flags != 0 { if control.flags != 0 {
return packet, errors.New("Malformed connect packet") return packet, errors.New("Malformed connect packet")
} }
return packet, nil return packet, nil
} }
@ -26,13 +25,13 @@ func (r PingreqPacket) Visit(p PacketVisitor) {
p.VisitPing(r) p.VisitPing(r)
} }
type PingrespPacket struct {} type PingrespPacket struct{}
func (p PingrespPacket) Write(w io.Writer) error { func (p PingrespPacket) Write(w io.Writer) error {
control := controlPacket { control := controlPacket{
packetType: PacketTypePingresp, packetType: PacketTypePingresp,
flags: 0, flags: 0,
reader: bytes.NewReader([]byte{}), reader: bytes.NewReader([]byte{}),
} }
return control.write(w) return control.write(w)

View file

@ -1,11 +1,10 @@
package packets package packets
import ( import (
"io"
"badat.dev/maeqtt/v2/mqtt/properties" "badat.dev/maeqtt/v2/mqtt/properties"
"io"
) )
type PubackReasonCode byte type PubackReasonCode byte
const ( const (

View file

@ -2,6 +2,7 @@ package packets
import ( import (
"bufio" "bufio"
"bytes"
"errors" "errors"
"io" "io"
"strings" "strings"
@ -15,7 +16,8 @@ type Topic struct {
} }
var multiLevelWildcardNotLast = errors.New("Multi level wildcard isn't the field in a topic") var multiLevelWildcardNotLast = errors.New("Multi level wildcard isn't the field in a topic")
func parseTopic(topic_name string) (Topic, error) {
func ParseTopic(topic_name string) (Topic, error) {
topic := Topic{} topic := Topic{}
fields := strings.Split(topic_name, "/") fields := strings.Split(topic_name, "/")
for i, field := range fields { for i, field := range fields {
@ -45,7 +47,7 @@ func parseTopicFilter(r *bufio.Reader) (TopicFilter, error) {
return filter, err return filter, err
} }
filter.Topic, err = parseTopic(topic_str) filter.Topic, err = ParseTopic(topic_str)
if err != nil { if err != nil {
return filter, err return filter, err
} }
@ -61,27 +63,32 @@ func parseTopicFilter(r *bufio.Reader) (TopicFilter, error) {
return filter, nil return filter, nil
} }
// Both sub and unsubscribe packets are identitcal so we can reuse the parsing logic type SubscribePacket struct {
type SubscriptionPacket struct {
PacketId uint16 PacketId uint16
TopicFilters []TopicFilter TopicFilters []TopicFilter
Properties properties.SubscribePacketProperties
} }
func parseSubscriptionPacket(control controlPacket, props []properties.Property) (SubscriptionPacket, error) { func parseSubscribePacket(control controlPacket) (SubscribePacket, error) {
var err error if control.packetType != PacketTypeSubscribe {
panic("Wrong packet type for parseSubscribePacket")
}
packet := SubscribePacket{}
r := bufio.NewReader(control.reader) r := bufio.NewReader(control.reader)
packet := SubscriptionPacket{}
if control.flags != 2 { if control.flags != 2 {
return packet, errors.New("Malformed subscription packet") return packet, errors.New("Malformed subscription packet")
} }
var err error
packet.PacketId, err = types.DecodeUint16(r) packet.PacketId, err = types.DecodeUint16(r)
if err != nil { if err != nil {
return packet, err return packet, err
} }
err = properties.ParseProperties(r, props) err = properties.ParseProperties(r, packet.Properties.ArrayOf())
if err != nil { if err != nil {
return packet, err return packet, err
} }
@ -100,33 +107,10 @@ func parseSubscriptionPacket(control controlPacket, props []properties.Property)
return packet, nil return packet, nil
} }
} }
println("A")
return packet, nil return packet, nil
} }
type SubscribePacket struct {
*SubscriptionPacket
props properties.SubscribePacketProperties
}
/// CURRENTLY BROKEN
// TODO FIXME AAAAA
func parseSubscribePacket(control controlPacket) (SubscribePacket, error) {
if control.packetType != PacketTypeSubscribe {
panic("Wrong packet type for parseSubscribePacket")
}
pack := SubscribePacket{}
subscriptionPack, err := parseSubscriptionPacket(control, pack.props.ArrayOf())
if err != nil {
return pack, err
}
pack.SubscriptionPacket = &subscriptionPack
return pack, nil
}
func (p SubscribePacket) Visit(v PacketVisitor) { func (p SubscribePacket) Visit(v PacketVisitor) {
v.VisitSubscribe(p) v.VisitSubscribe(p)
} }
@ -143,31 +127,46 @@ const (
SubackReasonTopicFilterInvalid = 143 SubackReasonTopicFilterInvalid = 143
SubackReasonPacketIDInUse = 145 SubackReasonPacketIDInUse = 145
SubackReasonQuotaExceeded = 151 SubackReasonQuotaExceeded = 151
SubackReasonSharedSubNotSupported = 151 SubackReasonSharedSubNotSupported = 158
SubackReasonSubIDUnsupported = 151 SubackReasonSubIDUnsupported = 161
SubackReasonWildcardSubUnsupported = 151 SubackReasonWildcardSubUnsupported = 162
) )
type SubAckPacket struct { type SubAckPacket struct {
PacketID uint16 PacketID uint16
Properties properties.SubackPacketProperties Properties properties.SubackPacketProperties
Reason SubackReasonCode Reason SubackReasonCode
} }
func (p SubAckPacket) Write(w io.Writer) error { func (p SubAckPacket) Write(w io.Writer) error {
resp := pubRespPacket{ buf := bytes.NewBuffer([]byte{})
PacketType: PacketTypeSuback, err := types.WriteUint16(buf, p.PacketID)
PacketID: p.PacketID, if err != nil {
Properties: p.Properties.ArrayOf(), return err
Reason: byte(p.Reason),
} }
return resp.Write(w)
err = properties.WriteProps(buf, p.Properties.ArrayOf())
if err != nil {
return err
}
err = buf.WriteByte(byte(p.Reason))
if err != nil {
return err
}
conPack := controlPacket{
packetType: PacketTypeSuback,
flags: 0,
reader: buf,
}
return conPack.write(w)
} }
type UnsubscribePacket struct { type UnsubscribePacket struct {
*SubscriptionPacket PacketID uint16
props properties.UnsubscribePacketProperties Topics []Topic
Properties properties.UnsubscribePacketProperties
} }
func parseUnsubscribePacket(control controlPacket) (UnsubscribePacket, error) { func parseUnsubscribePacket(control controlPacket) (UnsubscribePacket, error) {
@ -175,14 +174,41 @@ func parseUnsubscribePacket(control controlPacket) (UnsubscribePacket, error) {
panic("Wrong packet type for parseSubscribePacket") panic("Wrong packet type for parseSubscribePacket")
} }
pack := UnsubscribePacket{} packet := UnsubscribePacket{}
subscriptionPack, err := parseSubscriptionPacket(control, pack.props.ArrayOf()) r := bufio.NewReader(control.reader)
if err != nil {
return pack, err if control.flags != 2 {
return packet, errors.New("Malformed subscription packet")
} }
pack.PacketId = subscriptionPack.PacketId
pack.TopicFilters = subscriptionPack.TopicFilters var err error
return pack, nil packet.PacketID, err = types.DecodeUint16(r)
if err != nil {
return packet, err
}
err = properties.ParseProperties(r, packet.Properties.ArrayOf())
if err != nil {
return packet, err
}
for err != io.EOF {
topic_str, err := types.DecodeUTF8String(r)
if err != nil && err != io.EOF {
return packet, err
} else if err == io.EOF {
return packet, nil
}
filter, err := ParseTopic(topic_str)
if err != nil {
return packet, err
}
packet.Topics = append(packet.Topics, filter)
}
return packet, nil
} }
func (p UnsubscribePacket) Visit(v PacketVisitor) { func (p UnsubscribePacket) Visit(v PacketVisitor) {
@ -192,12 +218,12 @@ func (p UnsubscribePacket) Visit(v PacketVisitor) {
type UnsubackReasonCode byte type UnsubackReasonCode byte
const ( const (
UnsubackReasonSuccess PubackReasonCode = 0 UnsubackReasonSuccess UnsubackReasonCode = 0
UnSubackReasonUnspecified = 128 UnSubackReasonUnspecified = 128
UnSubackReasonImplSpecificError = 131 UnSubackReasonImplSpecificError = 131
UnSubackReasonNotAuthorized = 135 UnSubackReasonNotAuthorized = 135
UnSubackReasonTopicFilterInvalid = 143 UnSubackReasonTopicFilterInvalid = 143
UnSubackReasonPacketIDInUse = 145 UnSubackReasonPacketIDInUse = 145
) )
type UnsubAckPacket struct { type UnsubAckPacket struct {
@ -206,13 +232,27 @@ type UnsubAckPacket struct {
Reason UnsubackReasonCode Reason UnsubackReasonCode
} }
func (p UnsubAckPacket) Write(w io.Writer) error { func (p UnsubAckPacket) Write(w io.Writer) error {
resp := pubRespPacket{ buf := bytes.NewBuffer([]byte{})
PacketType: PacketTypeUnsuback, err := types.WriteUint16(buf, p.PacketID)
PacketID: p.PacketID, if err != nil {
Properties: p.Properties.ArrayOf(), return err
Reason: byte(p.Reason),
} }
return resp.Write(w)
err = properties.WriteProps(buf, p.Properties.ArrayOf())
if err != nil {
return err
}
err = buf.WriteByte(byte(p.Reason))
if err != nil {
return err
}
conPack := controlPacket{
packetType: PacketTypeUnsuback,
flags: 0,
reader: buf,
}
return conPack.write(w)
} }

View file

@ -34,7 +34,7 @@ func (p pubRespPacket) Write(w io.Writer) error {
} }
conPack := controlPacket{ conPack := controlPacket{
packetType: PacketTypePuback, packetType: p.PacketType,
flags: 0, flags: 0,
reader: buf, reader: buf,
} }

View file

@ -64,7 +64,7 @@ func ReadPacket(r *bufio.Reader) (*ClientPacket, error) {
return nil, err return nil, err
} }
reader := io.LimitReader(r, int64(dataLength)) reader := io.LimitReader(r, int64(dataLength))
control := controlPacket{ control := controlPacket{
packetType: PacketType(highestFourBits), packetType: PacketType(highestFourBits),
flags: lowerFourBits, flags: lowerFourBits,

View file

@ -57,7 +57,6 @@ func DecodeBinaryData(r *bufio.Reader) ([]byte, error) {
return buffer, err return buffer, err
} }
func DecodeUTF8String(r *bufio.Reader) (string, error) { func DecodeUTF8String(r *bufio.Reader) (string, error) {
binary, err := DecodeBinaryData(r) binary, err := DecodeBinaryData(r)
return string(binary[:]), err return string(binary[:]), err

View file

@ -29,8 +29,8 @@ func WriteUint32(w io.Writer, v uint32) error {
return err return err
} }
const uint32Max uint32 = ^uint32(0) const uint32Max uint32 = ^uint32(0)
func WriteDataWithVarIntLen(w io.Writer, data []byte) error { func WriteDataWithVarIntLen(w io.Writer, data []byte) error {
if len(data) > int(uint32Max) { if len(data) > int(uint32Max) {
return errors.New("Tried to write more data than max varint size") return errors.New("Tried to write more data than max varint size")
@ -46,6 +46,7 @@ func WriteDataWithVarIntLen(w io.Writer, data []byte) error {
} }
const uint16Max uint16 = ^uint16(0) const uint16Max uint16 = ^uint16(0)
func WriteBinaryData(w io.Writer, data []byte) error { func WriteBinaryData(w io.Writer, data []byte) error {
if len(data) > int(uint16Max) { if len(data) > int(uint16Max) {
return errors.New("Tried to write more data than max uint16 size") return errors.New("Tried to write more data than max uint16 size")
@ -64,7 +65,6 @@ func WriteUTF8String(w io.Writer, str string) error {
return WriteBinaryData(w, []byte(str)) return WriteBinaryData(w, []byte(str))
} }
func WriteVariableByteInt(w io.Writer, v uint32) error { func WriteVariableByteInt(w io.Writer, v uint32) error {
for { for {
encodedByte := byte(v % 128) encodedByte := byte(v % 128)

View file

@ -1,4 +1,4 @@
package main package session
import ( import (
"bufio" "bufio"

View file

@ -1,13 +1,15 @@
package main package session
import ( import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io"
"log" "log"
"math/rand" "math/rand"
"time" "time"
"badat.dev/maeqtt/v2/mqtt/packets" "badat.dev/maeqtt/v2/mqtt/packets"
"badat.dev/maeqtt/v2/subscription"
) )
func init() { func init() {
@ -25,8 +27,10 @@ type Session struct {
Connection *Connection Connection *Connection
SubscriptionChannel chan packets.PublishPacket SubscriptionChannel chan packets.PublishPacket
ExpiryInterval time.Duration ExpiryInterval time.Duration // TODO
expireTimer time.Timer // TODO expireTimer time.Timer // TODO
freePacketID uint16
} }
func NewSession(conn *Connection, p packets.ConnectPacket) Session { func NewSession(conn *Connection, p packets.ConnectPacket) Session {
@ -39,8 +43,7 @@ func NewSession(conn *Connection, p packets.ConnectPacket) Session {
func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) { func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) {
if s.Connection != nil { if s.Connection != nil {
//TODO s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver)
panic("Disconnect if already have a connection, unimplemented")
} }
connAck := packets.ConnackPacket{} connAck := packets.ConnackPacket{}
@ -60,7 +63,6 @@ func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) {
connAck.Properties.RetainAvailable.Value = &false connAck.Properties.RetainAvailable.Value = &false
connAck.Properties.SharedSubscriptionAvailable.Value = &false connAck.Properties.SharedSubscriptionAvailable.Value = &false
s.Connection = conn s.Connection = conn
err := s.Connection.sendPacket(connAck) err := s.Connection.sendPacket(connAck)
if err != nil { if err != nil {
@ -75,7 +77,7 @@ func (s *Session) HandlerLoop() {
case packet := <-s.Connection.PacketChannel: case packet := <-s.Connection.PacketChannel:
packet.Visit(s) packet.Visit(s)
case _ = <-s.Connection.ClientDisconnectedChan: case _ = <-s.Connection.ClientDisconnectedChan:
s.OnDisconnect() s.onDisconnect()
case subMessage := <-s.SubscriptionChannel: case subMessage := <-s.SubscriptionChannel:
//TODO, log for now //TODO, log for now
log.Printf("Recieved subscription message, handling UNIMPLEMENTED, message: %v", subMessage) log.Printf("Recieved subscription message, handling UNIMPLEMENTED, message: %v", subMessage)
@ -83,22 +85,29 @@ func (s *Session) HandlerLoop() {
} }
} }
func (s *Session) Disconnect() error { func (s *Session) Disconnect(code packets.DisconnectReasonCode) error {
panic("Disconnection unimplemented") s.Connection.sendPacket(packets.DisconnectPacket{
ReasonCode: code,
})
err := s.Connection.close() err := s.Connection.close()
if err != nil { if err != nil {
return err return err
} }
s.OnDisconnect() s.onDisconnect()
return nil return nil
} }
func (s *Session) OnDisconnect() { func (s *Session) onDisconnect() {
s.Connection = nil s.Connection = nil
s.resetExpireTimer() s.resetExpireTimer()
log.Printf("Client disconnected, id: %s", *s.ClientID) log.Printf("Client disconnected, id: %s", *s.ClientID)
} }
func (s *Session) expireSession() {
subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel)
}
// newTime is nullable // newTime is nullable
func (s *Session) updateExpireTimer(newTime *uint32) { func (s *Session) updateExpireTimer(newTime *uint32) {
var expiry = uint32(0) var expiry = uint32(0)
@ -129,35 +138,67 @@ func genClientID() *string {
} }
func (s *Session) VisitConnect(_ packets.ConnectPacket) { func (s *Session) VisitConnect(_ packets.ConnectPacket) {
// ERROR CANNOT RECIEVE CONNECT ON AN ALREADY OPEN CONNECTION // Disconnect, we handle the connect packet in Connect,
s.Disconnect() // this means that we have an estabilished connection already
log.Println("WARN: Got a connect packet on an already estabilished connection")
s.Disconnect(packets.DisconnectReasonCodeProtocolError)
} }
func (s *Session) VisitPublish(p packets.PublishPacket) { func (s *Session) VisitPublish(p packets.PublishPacket) {
println("UNIMPLEMENTED, Publishing packet, message:", string(p.Payload)) subs, lock := subscription.Subscriptions.GetSubscriptions(p.TopicName)
subs, lock := Subscriptions.GetSubscriptions(p.TopicName)
defer lock.Unlock() defer lock.Unlock()
if p.QOSLevel == 0 {
if p.PacketId != nil {
log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID)
return
}
} else if p.QOSLevel == 1 {
var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess
if len(subs) == 0 {
reason = packets.PubackReasonCodeNoMatchingSubscribers
}
ack := packets.PubackPacket{
PacketID: *p.PacketId,
Reason: reason,
}
s.Connection.sendPacket(ack)
} else if p.QOSLevel == 2 {
panic("UNIMPLEMENTED QOS level 2")
}
for _, sub := range subs { for _, sub := range subs {
go func(sub Subscription) {sub <- p}(sub) if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) {
go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub)
}
} }
} }
func (s *Session) VisitDisconnect(p packets.DisconnectPacket) { func (s *Session) VisitDisconnect(p packets.DisconnectPacket) {
//TODO FINISH err := s.Connection.close()
// HANDLE CLIENT DISCONNECTING if err != nil && err != io.ErrClosedPipe {
s.OnDisconnect() log.Println("Error closing connection", err)
}
s.onDisconnect()
} }
func (s *Session) VisitSubscribe(p packets.SubscribePacket) { func (s *Session) VisitSubscribe(p packets.SubscribePacket) {
//TODO FINISH
for _, filter := range p.TopicFilters { for _, filter := range p.TopicFilters {
Subscriptions.Subscribe(filter.Topic, s.SubscriptionChannel) subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel)
} }
s.Connection.sendPacket(packets.SubAckPacket{
PacketID: p.PacketId,
Reason: packets.SubackReasonGrantedQoSTwo,
})
} }
func (s *Session) VisitUnsubscribe(_ packets.UnsubscribePacket) { func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) {
panic("not implemented") // TODO: Implement for _, topic := range p.Topics {
subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel)
}
s.Connection.sendPacket(packets.UnsubAckPacket{
PacketID: p.PacketID,
Reason: packets.UnsubackReasonSuccess,
})
} }
func (s *Session) VisitPing(p packets.PingreqPacket) { func (s *Session) VisitPing(p packets.PingreqPacket) {
@ -179,3 +220,8 @@ func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) {
func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) { func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) {
panic("not implemented") // TODO: Implement panic("not implemented") // TODO: Implement
} }
func (s *Session) getFreePacketId() uint16 {
s.freePacketID += 1
return s.freePacketID
}

View file

@ -1,86 +0,0 @@
package main
//TODO FULLY IMPLEMENT SUBSCRIPTIONS INSTEAD OF JUST THE TOPIC FILTERS
import (
"strings"
"sync"
"badat.dev/maeqtt/v2/mqtt/packets"
)
var Subscriptions SubscriptionTreeNode = *NewSubscriptionTreeNode()
type Subscription chan packets.PublishPacket
type SubscriptionTreeNode struct {
subscriptions []Subscription
children map[string]*SubscriptionTreeNode
nodeLock sync.RWMutex
}
func NewSubscriptionTreeNode() *SubscriptionTreeNode {
s := SubscriptionTreeNode{}
s.children = make(map[string]*SubscriptionTreeNode)
return &s
}
func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode {
if len(fields) == 0 {
return s
}
field := fields[0]
s.nodeLock.RLock()
_, exists := s.children[field]
// Insert a value into the map if one doesn't exist yet
if !exists {
// Can't upgrade a read lock so we need to unlock and
// check again, this time with a write lock
s.nodeLock.RUnlock()
s.nodeLock.Lock()
_, exists = s.children[field]
if !exists {
s.children[field] = NewSubscriptionTreeNode()
}
s.nodeLock.Unlock()
s.nodeLock.RLock()
}
child, _ := s.children[field]
s.nodeLock.RUnlock()
return child.findNode(fields[1:])
}
func (s *SubscriptionTreeNode) Subscribe(topic packets.Topic, sub Subscription) {
node := s.findNode(topic.Fields)
node.nodeLock.Lock()
node.subscriptions = append(node.subscriptions, sub)
node.nodeLock.Unlock()
}
func (s *SubscriptionTreeNode) GetSubscriptions(topic string) ([]Subscription, sync.Locker) {
fields := strings.Split(topic,"/")
child := s.findNode(fields)
locker := child.nodeLock.RLocker()
locker.Lock()
return child.subscriptions, locker
}
func (s *SubscriptionTreeNode) findMatchingRec(topic []string) ([]Subscription, sync.Locker) {
locker := s.nodeLock.RLocker()
s.nodeLock.RLock()
if len(topic) == 0 {
return s.subscriptions,locker
}
defer s.nodeLock.RUnlock()
child, exists := s.children[topic[0]]
if exists {
return child.findMatchingRec(topic[1:])
} else {
return []Subscription{},locker
}
}

View file

@ -0,0 +1,106 @@
package subscription
//TODO WILDCARD SUBSCRIPTIONS
import (
"strings"
"sync"
"badat.dev/maeqtt/v2/mqtt/packets"
)
var Subscriptions SubscriptionTreeNode = *newSubscriptionTreeNode()
type SubscriptionChannel chan packets.PublishPacket
type Subscription struct {
SubscriptionChannel
packets.TopicFilter
}
type SubscriptionTreeNode struct {
subscriptions []Subscription
children map[string]*SubscriptionTreeNode
nodeLock sync.RWMutex
}
func newSubscriptionTreeNode() *SubscriptionTreeNode {
s := SubscriptionTreeNode{}
s.children = make(map[string]*SubscriptionTreeNode)
return &s
}
func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode {
if len(fields) == 0 {
return s
}
field := fields[0]
s.nodeLock.RLock()
_, exists := s.children[field]
// Insert a value into the map if one doesn't exist yet
if !exists {
// Can't upgrade a read lock so we need to unlock and
// check again, this time with a write lock
s.nodeLock.RUnlock()
s.nodeLock.Lock()
_, exists = s.children[field]
if !exists {
s.children[field] = newSubscriptionTreeNode()
}
s.nodeLock.Unlock()
s.nodeLock.RLock()
}
child, _ := s.children[field]
s.nodeLock.RUnlock()
return child.findNode(fields[1:])
}
func (s *SubscriptionTreeNode) removeSubscription(subChan SubscriptionChannel) {
for i, sub := range s.subscriptions {
if sub.SubscriptionChannel == subChan {
lst := len(s.subscriptions) - 1
s.subscriptions[i] = s.subscriptions[lst]
s.subscriptions = s.subscriptions[:lst]
}
}
}
func (s *SubscriptionTreeNode) Subscribe(topicFilter packets.TopicFilter, subChan SubscriptionChannel) {
sub := Subscription{subChan, topicFilter}
node := s.findNode(topicFilter.Topic.Fields)
node.nodeLock.Lock()
node.subscriptions = append(node.subscriptions, sub)
node.nodeLock.Unlock()
}
func (s *SubscriptionTreeNode) Unsubscribe(topic packets.Topic, subChan SubscriptionChannel) {
node := s.findNode(topic.Fields)
node.nodeLock.Lock()
node.removeSubscription(subChan)
node.nodeLock.Unlock()
}
func (s *SubscriptionTreeNode) RemoveSubsForChannel(subChan SubscriptionChannel) {
for _, node := range s.children {
node.nodeLock.Lock()
node.removeSubscription(subChan)
node.nodeLock.Unlock()
node.RemoveSubsForChannel(subChan)
}
}
func (s *SubscriptionTreeNode) GetSubscriptions(topicName string) ([]Subscription, sync.Locker) {
fields := strings.Split(topicName, "/")
child := s.findNode(fields)
locker := child.nodeLock.RLocker()
locker.Lock()
return child.subscriptions, locker
}

View file

@ -0,0 +1,27 @@
package subscription
import (
"testing"
"badat.dev/maeqtt/v2/mqtt/packets"
)
func TestSubscribe(t *testing.T) {
tree := newSubscriptionTreeNode()
topic, _ := packets.ParseTopic("a/b/c")
channel := make(SubscriptionChannel)
topicFilter := packets.TopicFilter{
Topic: topic,
MaxQoS: 1,
}
tree.Subscribe(topicFilter, channel)
subs, lock := tree.GetSubscriptions("a/b/c")
defer lock.Unlock()
if len(subs) != 1 {
t.Errorf("Error storing subscriptions, expected to len(subs) to be 1, got: %v \n", len(subs))
}
if subs[0].MaxQoS != topicFilter.MaxQoS || subs[0].SubscriptionChannel != channel {
t.Error("Error with data stored in a subscription")
}
}