Improve folder strutcute

This commit is contained in:
bad 2021-10-01 22:18:48 +02:00
parent e9612e2430
commit 35879183be
15 changed files with 606 additions and 461 deletions

26
main.go
View file

@ -7,6 +7,7 @@ import (
"runtime/debug" "runtime/debug"
"badat.dev/maeqtt/v2/mqtt/packets" "badat.dev/maeqtt/v2/mqtt/packets"
"badat.dev/maeqtt/v2/session"
) )
func main() { func main() {
@ -27,7 +28,7 @@ func main() {
} }
func handleConnection(con net.Conn) { func handleConnection(con net.Conn) {
defer handlePanic() defer handlePanic(con)
reader := bufio.NewReader(con) reader := bufio.NewReader(con)
@ -38,20 +39,33 @@ func handleConnection(con net.Conn) {
} }
connect, isConn := (*packet).(packets.ConnectPacket) connect, isConn := (*packet).(packets.ConnectPacket)
if !isConn { if !isConn {
log.Println("Didn't recieve a connet packet") log.Println("Didn't recieve a connect packet")
panic("TODO: Send a disconnect packet") err := packets.DisconnectPacket{
ReasonCode: packets.DisconnectReasonCodeProtocolError,
}.Write(con)
if err != nil {
log.Println("Failed to disconnect after not recieving a connect packet", err)
}
return
} }
conn := NewConnection(connect, con) conn := session.NewConnection(connect, con)
sess := session.NewSession(&conn, connect)
sess := NewSession(&conn, connect)
sess.HandlerLoop() sess.HandlerLoop()
} }
func handlePanic() { func handlePanic(con net.Conn) {
if r := recover(); r != nil { if r := recover(); r != nil {
log.Println("Recovering from panic:", r) log.Println("Recovering from panic:", r)
log.Println("Stack Trace:") log.Println("Stack Trace:")
debug.PrintStack() debug.PrintStack()
err := packets.DisconnectPacket{
ReasonCode: packets.DisconnectReasonCodeImplErorr,
}.Write(con)
if err != nil {
log.Println("Failed to send a disconnect packet after recovering from panic", err)
}
} }
} }

View file

@ -1,6 +1,6 @@
package mqtt package mqtt
// This code has been generated with the genProps.py script. Do not modify
// This code has been generated with the genProps.py script. Do not modify
import "bufio" import "bufio"
@ -21,7 +21,6 @@ func (p *PayloadFormatIndicator) parse(r *bufio.Reader) error {
return nil return nil
} }
type MessageExpiryInterval struct { type MessageExpiryInterval struct {
value *uint32 value *uint32
} }
@ -39,7 +38,6 @@ func (p *MessageExpiryInterval) parse(r *bufio.Reader) error {
return nil return nil
} }
type ContentType struct { type ContentType struct {
value *string value *string
} }
@ -57,7 +55,6 @@ func (p *ContentType) parse(r *bufio.Reader) error {
return nil return nil
} }
type ResponseTopic struct { type ResponseTopic struct {
value *string value *string
} }
@ -75,7 +72,6 @@ func (p *ResponseTopic) parse(r *bufio.Reader) error {
return nil return nil
} }
type CorrelationData struct { type CorrelationData struct {
value *[]byte value *[]byte
} }
@ -93,7 +89,6 @@ func (p *CorrelationData) parse(r *bufio.Reader) error {
return nil return nil
} }
type SubscriptionIdentifier struct { type SubscriptionIdentifier struct {
value *int value *int
} }
@ -111,7 +106,6 @@ func (p *SubscriptionIdentifier) parse(r *bufio.Reader) error {
return nil return nil
} }
type SessionExpiryInterval struct { type SessionExpiryInterval struct {
value *uint32 value *uint32
} }
@ -129,7 +123,6 @@ func (p *SessionExpiryInterval) parse(r *bufio.Reader) error {
return nil return nil
} }
type AssignedClientIdentifier struct { type AssignedClientIdentifier struct {
value *string value *string
} }
@ -147,7 +140,6 @@ func (p *AssignedClientIdentifier) parse(r *bufio.Reader) error {
return nil return nil
} }
type ServerKeepAlive struct { type ServerKeepAlive struct {
value *uint16 value *uint16
} }
@ -165,7 +157,6 @@ func (p *ServerKeepAlive) parse(r *bufio.Reader) error {
return nil return nil
} }
type AuthenticationMethod struct { type AuthenticationMethod struct {
value *string value *string
} }
@ -183,7 +174,6 @@ func (p *AuthenticationMethod) parse(r *bufio.Reader) error {
return nil return nil
} }
type AuthenticationData struct { type AuthenticationData struct {
value *[]byte value *[]byte
} }
@ -201,7 +191,6 @@ func (p *AuthenticationData) parse(r *bufio.Reader) error {
return nil return nil
} }
type RequestProblemInformation struct { type RequestProblemInformation struct {
value *byte value *byte
} }
@ -219,7 +208,6 @@ func (p *RequestProblemInformation) parse(r *bufio.Reader) error {
return nil return nil
} }
type WillDelayInterval struct { type WillDelayInterval struct {
value *uint32 value *uint32
} }
@ -237,7 +225,6 @@ func (p *WillDelayInterval) parse(r *bufio.Reader) error {
return nil return nil
} }
type RequestResponseInformation struct { type RequestResponseInformation struct {
value *byte value *byte
} }
@ -255,7 +242,6 @@ func (p *RequestResponseInformation) parse(r *bufio.Reader) error {
return nil return nil
} }
type ResponseInformation struct { type ResponseInformation struct {
value *string value *string
} }
@ -273,7 +259,6 @@ func (p *ResponseInformation) parse(r *bufio.Reader) error {
return nil return nil
} }
type ServerReference struct { type ServerReference struct {
value *string value *string
} }
@ -291,7 +276,6 @@ func (p *ServerReference) parse(r *bufio.Reader) error {
return nil return nil
} }
type ReasonString struct { type ReasonString struct {
value *string value *string
} }
@ -309,7 +293,6 @@ func (p *ReasonString) parse(r *bufio.Reader) error {
return nil return nil
} }
type ReceiveMaximum struct { type ReceiveMaximum struct {
value *uint16 value *uint16
} }
@ -327,7 +310,6 @@ func (p *ReceiveMaximum) parse(r *bufio.Reader) error {
return nil return nil
} }
type TopicAliasMaximum struct { type TopicAliasMaximum struct {
value *uint16 value *uint16
} }
@ -345,7 +327,6 @@ func (p *TopicAliasMaximum) parse(r *bufio.Reader) error {
return nil return nil
} }
type TopicAlias struct { type TopicAlias struct {
value *uint16 value *uint16
} }
@ -363,7 +344,6 @@ func (p *TopicAlias) parse(r *bufio.Reader) error {
return nil return nil
} }
type MaximumQoS struct { type MaximumQoS struct {
value *byte value *byte
} }
@ -381,7 +361,6 @@ func (p *MaximumQoS) parse(r *bufio.Reader) error {
return nil return nil
} }
type RetainAvailable struct { type RetainAvailable struct {
value *byte value *byte
} }
@ -399,7 +378,6 @@ func (p *RetainAvailable) parse(r *bufio.Reader) error {
return nil return nil
} }
type MaximumPacketSize struct { type MaximumPacketSize struct {
value *uint32 value *uint32
} }
@ -417,7 +395,6 @@ func (p *MaximumPacketSize) parse(r *bufio.Reader) error {
return nil return nil
} }
type WildcardSubscriptionAvailable struct { type WildcardSubscriptionAvailable struct {
value *byte value *byte
} }
@ -435,7 +412,6 @@ func (p *WildcardSubscriptionAvailable) parse(r *bufio.Reader) error {
return nil return nil
} }
type SubscriptionIdentifierAvailable struct { type SubscriptionIdentifierAvailable struct {
value *byte value *byte
} }
@ -453,7 +429,6 @@ func (p *SubscriptionIdentifierAvailable) parse(r *bufio.Reader) error {
return nil return nil
} }
type SharedSubscriptionAvailable struct { type SharedSubscriptionAvailable struct {
value *byte value *byte
} }
@ -481,6 +456,7 @@ SubscriptionIdentifier SubscriptionIdentifier
TopicAlias TopicAlias TopicAlias TopicAlias
UserProperty UserProperty UserProperty UserProperty
} }
func (p *PublishPacketProperties) arrayOf() []Property { func (p *PublishPacketProperties) arrayOf() []Property {
return []Property{ return []Property{
&p.PayloadFormatIndicator, &p.PayloadFormatIndicator,
@ -493,6 +469,7 @@ return []Property {
&p.UserProperty, &p.UserProperty,
} }
} }
type WillProperties struct { type WillProperties struct {
PayloadFormatIndicator PayloadFormatIndicator PayloadFormatIndicator PayloadFormatIndicator
MessageExpiryInterval MessageExpiryInterval MessageExpiryInterval MessageExpiryInterval
@ -502,6 +479,7 @@ CorrelationData CorrelationData
WillDelayInterval WillDelayInterval WillDelayInterval WillDelayInterval
UserProperty UserProperty UserProperty UserProperty
} }
func (p *WillProperties) arrayOf() []Property { func (p *WillProperties) arrayOf() []Property {
return []Property{ return []Property{
&p.PayloadFormatIndicator, &p.PayloadFormatIndicator,
@ -513,16 +491,19 @@ return []Property {
&p.UserProperty, &p.UserProperty,
} }
} }
type SubscribePacketProperties struct { type SubscribePacketProperties struct {
SubscriptionIdentifier SubscriptionIdentifier SubscriptionIdentifier SubscriptionIdentifier
UserProperty UserProperty UserProperty UserProperty
} }
func (p *SubscribePacketProperties) arrayOf() []Property { func (p *SubscribePacketProperties) arrayOf() []Property {
return []Property{ return []Property{
&p.SubscriptionIdentifier, &p.SubscriptionIdentifier,
&p.UserProperty, &p.UserProperty,
} }
} }
type ConnectPacketProperties struct { type ConnectPacketProperties struct {
SessionExpiryInterval SessionExpiryInterval SessionExpiryInterval SessionExpiryInterval
AuthenticationMethod AuthenticationMethod AuthenticationMethod AuthenticationMethod
@ -534,6 +515,7 @@ TopicAliasMaximum TopicAliasMaximum
UserProperty UserProperty UserProperty UserProperty
MaximumPacketSize MaximumPacketSize MaximumPacketSize MaximumPacketSize
} }
func (p *ConnectPacketProperties) arrayOf() []Property { func (p *ConnectPacketProperties) arrayOf() []Property {
return []Property{ return []Property{
&p.SessionExpiryInterval, &p.SessionExpiryInterval,
@ -547,6 +529,7 @@ return []Property {
&p.MaximumPacketSize, &p.MaximumPacketSize,
} }
} }
type ConnackPacketProperties struct { type ConnackPacketProperties struct {
SessionExpiryInterval SessionExpiryInterval SessionExpiryInterval SessionExpiryInterval
AssignedClientIdentifier AssignedClientIdentifier AssignedClientIdentifier AssignedClientIdentifier
@ -566,6 +549,7 @@ WildcardSubscriptionAvailable WildcardSubscriptionAvailable
SubscriptionIdentifierAvailable SubscriptionIdentifierAvailable SubscriptionIdentifierAvailable SubscriptionIdentifierAvailable
SharedSubscriptionAvailable SharedSubscriptionAvailable SharedSubscriptionAvailable SharedSubscriptionAvailable
} }
func (p *ConnackPacketProperties) arrayOf() []Property { func (p *ConnackPacketProperties) arrayOf() []Property {
return []Property{ return []Property{
&p.SessionExpiryInterval, &p.SessionExpiryInterval,
@ -587,12 +571,14 @@ return []Property {
&p.SharedSubscriptionAvailable, &p.SharedSubscriptionAvailable,
} }
} }
type DisconnectPacketProperties struct { type DisconnectPacketProperties struct {
SessionExpiryInterval SessionExpiryInterval SessionExpiryInterval SessionExpiryInterval
ServerReference ServerReference ServerReference ServerReference
ReasonString ReasonString ReasonString ReasonString
UserProperty UserProperty UserProperty UserProperty
} }
func (p *DisconnectPacketProperties) arrayOf() []Property { func (p *DisconnectPacketProperties) arrayOf() []Property {
return []Property{ return []Property{
&p.SessionExpiryInterval, &p.SessionExpiryInterval,
@ -601,12 +587,14 @@ return []Property {
&p.UserProperty, &p.UserProperty,
} }
} }
type AuthPacketProperties struct { type AuthPacketProperties struct {
AuthenticationMethod AuthenticationMethod AuthenticationMethod AuthenticationMethod
AuthenticationData AuthenticationData AuthenticationData AuthenticationData
ReasonString ReasonString ReasonString ReasonString
UserProperty UserProperty UserProperty UserProperty
} }
func (p *AuthPacketProperties) arrayOf() []Property { func (p *AuthPacketProperties) arrayOf() []Property {
return []Property{ return []Property{
&p.AuthenticationMethod, &p.AuthenticationMethod,
@ -615,69 +603,83 @@ return []Property {
&p.UserProperty, &p.UserProperty,
} }
} }
type PubackPacketProperties struct { type PubackPacketProperties struct {
ReasonString ReasonString ReasonString ReasonString
UserProperty UserProperty UserProperty UserProperty
} }
func (p *PubackPacketProperties) arrayOf() []Property { func (p *PubackPacketProperties) arrayOf() []Property {
return []Property{ return []Property{
&p.ReasonString, &p.ReasonString,
&p.UserProperty, &p.UserProperty,
} }
} }
type PubrecPacketProperties struct { type PubrecPacketProperties struct {
ReasonString ReasonString ReasonString ReasonString
UserProperty UserProperty UserProperty UserProperty
} }
func (p *PubrecPacketProperties) arrayOf() []Property { func (p *PubrecPacketProperties) arrayOf() []Property {
return []Property{ return []Property{
&p.ReasonString, &p.ReasonString,
&p.UserProperty, &p.UserProperty,
} }
} }
type PubrelPacketProperties struct { type PubrelPacketProperties struct {
ReasonString ReasonString ReasonString ReasonString
UserProperty UserProperty UserProperty UserProperty
} }
func (p *PubrelPacketProperties) arrayOf() []Property { func (p *PubrelPacketProperties) arrayOf() []Property {
return []Property{ return []Property{
&p.ReasonString, &p.ReasonString,
&p.UserProperty, &p.UserProperty,
} }
} }
type PubcompPacketProperties struct { type PubcompPacketProperties struct {
ReasonString ReasonString ReasonString ReasonString
UserProperty UserProperty UserProperty UserProperty
} }
func (p *PubcompPacketProperties) arrayOf() []Property { func (p *PubcompPacketProperties) arrayOf() []Property {
return []Property{ return []Property{
&p.ReasonString, &p.ReasonString,
&p.UserProperty, &p.UserProperty,
} }
} }
type SubackPacketProperties struct { type SubackPacketProperties struct {
ReasonString ReasonString ReasonString ReasonString
UserProperty UserProperty UserProperty UserProperty
} }
func (p *SubackPacketProperties) arrayOf() []Property { func (p *SubackPacketProperties) arrayOf() []Property {
return []Property{ return []Property{
&p.ReasonString, &p.ReasonString,
&p.UserProperty, &p.UserProperty,
} }
} }
type UnsubackPacketProperties struct { type UnsubackPacketProperties struct {
ReasonString ReasonString ReasonString ReasonString
UserProperty UserProperty UserProperty UserProperty
} }
func (p *UnsubackPacketProperties) arrayOf() []Property { func (p *UnsubackPacketProperties) arrayOf() []Property {
return []Property{ return []Property{
&p.ReasonString, &p.ReasonString,
&p.UserProperty, &p.UserProperty,
} }
} }
type UnsubscribePacketProperties struct { type UnsubscribePacketProperties struct {
UserProperty UserProperty UserProperty UserProperty
} }
func (p *UnsubscribePacketProperties) arrayOf() []Property { func (p *UnsubscribePacketProperties) arrayOf() []Property {
return []Property{ return []Property{
&p.UserProperty, &p.UserProperty,

View file

@ -60,7 +60,6 @@ func parseDisconnectPacket(control controlPacket) (DisconnectPacket, error) {
r := bufio.NewReader(control.reader) r := bufio.NewReader(control.reader)
// If there is less then a byte in the reader assume the reason code == 0 // If there is less then a byte in the reader assume the reason code == 0
reason, err := r.ReadByte() reason, err := r.ReadByte()
if err == io.EOF { if err == io.EOF {

View file

@ -6,7 +6,6 @@ import (
"io" "io"
) )
type PingreqPacket struct{} type PingreqPacket struct{}
func parsePingreq(control controlPacket) (PingreqPacket, error) { func parsePingreq(control controlPacket) (PingreqPacket, error) {

View file

@ -1,11 +1,10 @@
package packets package packets
import ( import (
"io"
"badat.dev/maeqtt/v2/mqtt/properties" "badat.dev/maeqtt/v2/mqtt/properties"
"io"
) )
type PubackReasonCode byte type PubackReasonCode byte
const ( const (

View file

@ -2,6 +2,7 @@ package packets
import ( import (
"bufio" "bufio"
"bytes"
"errors" "errors"
"io" "io"
"strings" "strings"
@ -15,7 +16,8 @@ type Topic struct {
} }
var multiLevelWildcardNotLast = errors.New("Multi level wildcard isn't the field in a topic") var multiLevelWildcardNotLast = errors.New("Multi level wildcard isn't the field in a topic")
func parseTopic(topic_name string) (Topic, error) {
func ParseTopic(topic_name string) (Topic, error) {
topic := Topic{} topic := Topic{}
fields := strings.Split(topic_name, "/") fields := strings.Split(topic_name, "/")
for i, field := range fields { for i, field := range fields {
@ -45,7 +47,7 @@ func parseTopicFilter(r *bufio.Reader) (TopicFilter, error) {
return filter, err return filter, err
} }
filter.Topic, err = parseTopic(topic_str) filter.Topic, err = ParseTopic(topic_str)
if err != nil { if err != nil {
return filter, err return filter, err
} }
@ -61,27 +63,32 @@ func parseTopicFilter(r *bufio.Reader) (TopicFilter, error) {
return filter, nil return filter, nil
} }
// Both sub and unsubscribe packets are identitcal so we can reuse the parsing logic type SubscribePacket struct {
type SubscriptionPacket struct {
PacketId uint16 PacketId uint16
TopicFilters []TopicFilter TopicFilters []TopicFilter
Properties properties.SubscribePacketProperties
} }
func parseSubscriptionPacket(control controlPacket, props []properties.Property) (SubscriptionPacket, error) { func parseSubscribePacket(control controlPacket) (SubscribePacket, error) {
var err error if control.packetType != PacketTypeSubscribe {
panic("Wrong packet type for parseSubscribePacket")
}
packet := SubscribePacket{}
r := bufio.NewReader(control.reader) r := bufio.NewReader(control.reader)
packet := SubscriptionPacket{}
if control.flags != 2 { if control.flags != 2 {
return packet, errors.New("Malformed subscription packet") return packet, errors.New("Malformed subscription packet")
} }
var err error
packet.PacketId, err = types.DecodeUint16(r) packet.PacketId, err = types.DecodeUint16(r)
if err != nil { if err != nil {
return packet, err return packet, err
} }
err = properties.ParseProperties(r, props) err = properties.ParseProperties(r, packet.Properties.ArrayOf())
if err != nil { if err != nil {
return packet, err return packet, err
} }
@ -100,33 +107,10 @@ func parseSubscriptionPacket(control controlPacket, props []properties.Property)
return packet, nil return packet, nil
} }
} }
println("A")
return packet, nil return packet, nil
} }
type SubscribePacket struct {
*SubscriptionPacket
props properties.SubscribePacketProperties
}
/// CURRENTLY BROKEN
// TODO FIXME AAAAA
func parseSubscribePacket(control controlPacket) (SubscribePacket, error) {
if control.packetType != PacketTypeSubscribe {
panic("Wrong packet type for parseSubscribePacket")
}
pack := SubscribePacket{}
subscriptionPack, err := parseSubscriptionPacket(control, pack.props.ArrayOf())
if err != nil {
return pack, err
}
pack.SubscriptionPacket = &subscriptionPack
return pack, nil
}
func (p SubscribePacket) Visit(v PacketVisitor) { func (p SubscribePacket) Visit(v PacketVisitor) {
v.VisitSubscribe(p) v.VisitSubscribe(p)
} }
@ -143,9 +127,9 @@ const (
SubackReasonTopicFilterInvalid = 143 SubackReasonTopicFilterInvalid = 143
SubackReasonPacketIDInUse = 145 SubackReasonPacketIDInUse = 145
SubackReasonQuotaExceeded = 151 SubackReasonQuotaExceeded = 151
SubackReasonSharedSubNotSupported = 151 SubackReasonSharedSubNotSupported = 158
SubackReasonSubIDUnsupported = 151 SubackReasonSubIDUnsupported = 161
SubackReasonWildcardSubUnsupported = 151 SubackReasonWildcardSubUnsupported = 162
) )
type SubAckPacket struct { type SubAckPacket struct {
@ -154,20 +138,35 @@ type SubAckPacket struct {
Reason SubackReasonCode Reason SubackReasonCode
} }
func (p SubAckPacket) Write(w io.Writer) error { func (p SubAckPacket) Write(w io.Writer) error {
resp := pubRespPacket{ buf := bytes.NewBuffer([]byte{})
PacketType: PacketTypeSuback, err := types.WriteUint16(buf, p.PacketID)
PacketID: p.PacketID, if err != nil {
Properties: p.Properties.ArrayOf(), return err
Reason: byte(p.Reason),
} }
return resp.Write(w)
err = properties.WriteProps(buf, p.Properties.ArrayOf())
if err != nil {
return err
}
err = buf.WriteByte(byte(p.Reason))
if err != nil {
return err
}
conPack := controlPacket{
packetType: PacketTypeSuback,
flags: 0,
reader: buf,
}
return conPack.write(w)
} }
type UnsubscribePacket struct { type UnsubscribePacket struct {
*SubscriptionPacket PacketID uint16
props properties.UnsubscribePacketProperties Topics []Topic
Properties properties.UnsubscribePacketProperties
} }
func parseUnsubscribePacket(control controlPacket) (UnsubscribePacket, error) { func parseUnsubscribePacket(control controlPacket) (UnsubscribePacket, error) {
@ -175,14 +174,41 @@ func parseUnsubscribePacket(control controlPacket) (UnsubscribePacket, error) {
panic("Wrong packet type for parseSubscribePacket") panic("Wrong packet type for parseSubscribePacket")
} }
pack := UnsubscribePacket{} packet := UnsubscribePacket{}
subscriptionPack, err := parseSubscriptionPacket(control, pack.props.ArrayOf()) r := bufio.NewReader(control.reader)
if err != nil {
return pack, err if control.flags != 2 {
return packet, errors.New("Malformed subscription packet")
} }
pack.PacketId = subscriptionPack.PacketId
pack.TopicFilters = subscriptionPack.TopicFilters var err error
return pack, nil packet.PacketID, err = types.DecodeUint16(r)
if err != nil {
return packet, err
}
err = properties.ParseProperties(r, packet.Properties.ArrayOf())
if err != nil {
return packet, err
}
for err != io.EOF {
topic_str, err := types.DecodeUTF8String(r)
if err != nil && err != io.EOF {
return packet, err
} else if err == io.EOF {
return packet, nil
}
filter, err := ParseTopic(topic_str)
if err != nil {
return packet, err
}
packet.Topics = append(packet.Topics, filter)
}
return packet, nil
} }
func (p UnsubscribePacket) Visit(v PacketVisitor) { func (p UnsubscribePacket) Visit(v PacketVisitor) {
@ -192,7 +218,7 @@ func (p UnsubscribePacket) Visit(v PacketVisitor) {
type UnsubackReasonCode byte type UnsubackReasonCode byte
const ( const (
UnsubackReasonSuccess PubackReasonCode = 0 UnsubackReasonSuccess UnsubackReasonCode = 0
UnSubackReasonUnspecified = 128 UnSubackReasonUnspecified = 128
UnSubackReasonImplSpecificError = 131 UnSubackReasonImplSpecificError = 131
UnSubackReasonNotAuthorized = 135 UnSubackReasonNotAuthorized = 135
@ -206,13 +232,27 @@ type UnsubAckPacket struct {
Reason UnsubackReasonCode Reason UnsubackReasonCode
} }
func (p UnsubAckPacket) Write(w io.Writer) error { func (p UnsubAckPacket) Write(w io.Writer) error {
resp := pubRespPacket{ buf := bytes.NewBuffer([]byte{})
PacketType: PacketTypeUnsuback, err := types.WriteUint16(buf, p.PacketID)
PacketID: p.PacketID, if err != nil {
Properties: p.Properties.ArrayOf(), return err
Reason: byte(p.Reason),
} }
return resp.Write(w)
err = properties.WriteProps(buf, p.Properties.ArrayOf())
if err != nil {
return err
}
err = buf.WriteByte(byte(p.Reason))
if err != nil {
return err
}
conPack := controlPacket{
packetType: PacketTypeUnsuback,
flags: 0,
reader: buf,
}
return conPack.write(w)
} }

View file

@ -34,7 +34,7 @@ func (p pubRespPacket) Write(w io.Writer) error {
} }
conPack := controlPacket{ conPack := controlPacket{
packetType: PacketTypePuback, packetType: p.PacketType,
flags: 0, flags: 0,
reader: buf, reader: buf,
} }

View file

@ -57,7 +57,6 @@ func DecodeBinaryData(r *bufio.Reader) ([]byte, error) {
return buffer, err return buffer, err
} }
func DecodeUTF8String(r *bufio.Reader) (string, error) { func DecodeUTF8String(r *bufio.Reader) (string, error) {
binary, err := DecodeBinaryData(r) binary, err := DecodeBinaryData(r)
return string(binary[:]), err return string(binary[:]), err

View file

@ -29,8 +29,8 @@ func WriteUint32(w io.Writer, v uint32) error {
return err return err
} }
const uint32Max uint32 = ^uint32(0) const uint32Max uint32 = ^uint32(0)
func WriteDataWithVarIntLen(w io.Writer, data []byte) error { func WriteDataWithVarIntLen(w io.Writer, data []byte) error {
if len(data) > int(uint32Max) { if len(data) > int(uint32Max) {
return errors.New("Tried to write more data than max varint size") return errors.New("Tried to write more data than max varint size")
@ -46,6 +46,7 @@ func WriteDataWithVarIntLen(w io.Writer, data []byte) error {
} }
const uint16Max uint16 = ^uint16(0) const uint16Max uint16 = ^uint16(0)
func WriteBinaryData(w io.Writer, data []byte) error { func WriteBinaryData(w io.Writer, data []byte) error {
if len(data) > int(uint16Max) { if len(data) > int(uint16Max) {
return errors.New("Tried to write more data than max uint16 size") return errors.New("Tried to write more data than max uint16 size")
@ -64,7 +65,6 @@ func WriteUTF8String(w io.Writer, str string) error {
return WriteBinaryData(w, []byte(str)) return WriteBinaryData(w, []byte(str))
} }
func WriteVariableByteInt(w io.Writer, v uint32) error { func WriteVariableByteInt(w io.Writer, v uint32) error {
for { for {
encodedByte := byte(v % 128) encodedByte := byte(v % 128)

View file

@ -1,4 +1,4 @@
package main package session
import ( import (
"bufio" "bufio"

View file

@ -1,13 +1,15 @@
package main package session
import ( import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io"
"log" "log"
"math/rand" "math/rand"
"time" "time"
"badat.dev/maeqtt/v2/mqtt/packets" "badat.dev/maeqtt/v2/mqtt/packets"
"badat.dev/maeqtt/v2/subscription"
) )
func init() { func init() {
@ -25,8 +27,10 @@ type Session struct {
Connection *Connection Connection *Connection
SubscriptionChannel chan packets.PublishPacket SubscriptionChannel chan packets.PublishPacket
ExpiryInterval time.Duration ExpiryInterval time.Duration // TODO
expireTimer time.Timer // TODO expireTimer time.Timer // TODO
freePacketID uint16
} }
func NewSession(conn *Connection, p packets.ConnectPacket) Session { func NewSession(conn *Connection, p packets.ConnectPacket) Session {
@ -39,8 +43,7 @@ func NewSession(conn *Connection, p packets.ConnectPacket) Session {
func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) { func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) {
if s.Connection != nil { if s.Connection != nil {
//TODO s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver)
panic("Disconnect if already have a connection, unimplemented")
} }
connAck := packets.ConnackPacket{} connAck := packets.ConnackPacket{}
@ -60,7 +63,6 @@ func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) {
connAck.Properties.RetainAvailable.Value = &false connAck.Properties.RetainAvailable.Value = &false
connAck.Properties.SharedSubscriptionAvailable.Value = &false connAck.Properties.SharedSubscriptionAvailable.Value = &false
s.Connection = conn s.Connection = conn
err := s.Connection.sendPacket(connAck) err := s.Connection.sendPacket(connAck)
if err != nil { if err != nil {
@ -75,7 +77,7 @@ func (s *Session) HandlerLoop() {
case packet := <-s.Connection.PacketChannel: case packet := <-s.Connection.PacketChannel:
packet.Visit(s) packet.Visit(s)
case _ = <-s.Connection.ClientDisconnectedChan: case _ = <-s.Connection.ClientDisconnectedChan:
s.OnDisconnect() s.onDisconnect()
case subMessage := <-s.SubscriptionChannel: case subMessage := <-s.SubscriptionChannel:
//TODO, log for now //TODO, log for now
log.Printf("Recieved subscription message, handling UNIMPLEMENTED, message: %v", subMessage) log.Printf("Recieved subscription message, handling UNIMPLEMENTED, message: %v", subMessage)
@ -83,22 +85,29 @@ func (s *Session) HandlerLoop() {
} }
} }
func (s *Session) Disconnect() error { func (s *Session) Disconnect(code packets.DisconnectReasonCode) error {
panic("Disconnection unimplemented") s.Connection.sendPacket(packets.DisconnectPacket{
ReasonCode: code,
})
err := s.Connection.close() err := s.Connection.close()
if err != nil { if err != nil {
return err return err
} }
s.OnDisconnect() s.onDisconnect()
return nil return nil
} }
func (s *Session) OnDisconnect() { func (s *Session) onDisconnect() {
s.Connection = nil s.Connection = nil
s.resetExpireTimer() s.resetExpireTimer()
log.Printf("Client disconnected, id: %s", *s.ClientID) log.Printf("Client disconnected, id: %s", *s.ClientID)
} }
func (s *Session) expireSession() {
subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel)
}
// newTime is nullable // newTime is nullable
func (s *Session) updateExpireTimer(newTime *uint32) { func (s *Session) updateExpireTimer(newTime *uint32) {
var expiry = uint32(0) var expiry = uint32(0)
@ -129,35 +138,67 @@ func genClientID() *string {
} }
func (s *Session) VisitConnect(_ packets.ConnectPacket) { func (s *Session) VisitConnect(_ packets.ConnectPacket) {
// ERROR CANNOT RECIEVE CONNECT ON AN ALREADY OPEN CONNECTION // Disconnect, we handle the connect packet in Connect,
s.Disconnect() // this means that we have an estabilished connection already
log.Println("WARN: Got a connect packet on an already estabilished connection")
s.Disconnect(packets.DisconnectReasonCodeProtocolError)
} }
func (s *Session) VisitPublish(p packets.PublishPacket) { func (s *Session) VisitPublish(p packets.PublishPacket) {
println("UNIMPLEMENTED, Publishing packet, message:", string(p.Payload)) subs, lock := subscription.Subscriptions.GetSubscriptions(p.TopicName)
subs, lock := Subscriptions.GetSubscriptions(p.TopicName)
defer lock.Unlock() defer lock.Unlock()
if p.QOSLevel == 0 {
if p.PacketId != nil {
log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID)
return
}
} else if p.QOSLevel == 1 {
var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess
if len(subs) == 0 {
reason = packets.PubackReasonCodeNoMatchingSubscribers
}
ack := packets.PubackPacket{
PacketID: *p.PacketId,
Reason: reason,
}
s.Connection.sendPacket(ack)
} else if p.QOSLevel == 2 {
panic("UNIMPLEMENTED QOS level 2")
}
for _, sub := range subs { for _, sub := range subs {
go func(sub Subscription) {sub <- p}(sub) if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) {
go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub)
}
} }
} }
func (s *Session) VisitDisconnect(p packets.DisconnectPacket) { func (s *Session) VisitDisconnect(p packets.DisconnectPacket) {
//TODO FINISH err := s.Connection.close()
// HANDLE CLIENT DISCONNECTING if err != nil && err != io.ErrClosedPipe {
s.OnDisconnect() log.Println("Error closing connection", err)
}
s.onDisconnect()
} }
func (s *Session) VisitSubscribe(p packets.SubscribePacket) { func (s *Session) VisitSubscribe(p packets.SubscribePacket) {
//TODO FINISH
for _, filter := range p.TopicFilters { for _, filter := range p.TopicFilters {
Subscriptions.Subscribe(filter.Topic, s.SubscriptionChannel) subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel)
} }
s.Connection.sendPacket(packets.SubAckPacket{
PacketID: p.PacketId,
Reason: packets.SubackReasonGrantedQoSTwo,
})
} }
func (s *Session) VisitUnsubscribe(_ packets.UnsubscribePacket) { func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) {
panic("not implemented") // TODO: Implement for _, topic := range p.Topics {
subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel)
}
s.Connection.sendPacket(packets.UnsubAckPacket{
PacketID: p.PacketID,
Reason: packets.UnsubackReasonSuccess,
})
} }
func (s *Session) VisitPing(p packets.PingreqPacket) { func (s *Session) VisitPing(p packets.PingreqPacket) {
@ -179,3 +220,8 @@ func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) {
func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) { func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) {
panic("not implemented") // TODO: Implement panic("not implemented") // TODO: Implement
} }
func (s *Session) getFreePacketId() uint16 {
s.freePacketID += 1
return s.freePacketID
}

View file

@ -1,86 +0,0 @@
package main
//TODO FULLY IMPLEMENT SUBSCRIPTIONS INSTEAD OF JUST THE TOPIC FILTERS
import (
"strings"
"sync"
"badat.dev/maeqtt/v2/mqtt/packets"
)
var Subscriptions SubscriptionTreeNode = *NewSubscriptionTreeNode()
type Subscription chan packets.PublishPacket
type SubscriptionTreeNode struct {
subscriptions []Subscription
children map[string]*SubscriptionTreeNode
nodeLock sync.RWMutex
}
func NewSubscriptionTreeNode() *SubscriptionTreeNode {
s := SubscriptionTreeNode{}
s.children = make(map[string]*SubscriptionTreeNode)
return &s
}
func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode {
if len(fields) == 0 {
return s
}
field := fields[0]
s.nodeLock.RLock()
_, exists := s.children[field]
// Insert a value into the map if one doesn't exist yet
if !exists {
// Can't upgrade a read lock so we need to unlock and
// check again, this time with a write lock
s.nodeLock.RUnlock()
s.nodeLock.Lock()
_, exists = s.children[field]
if !exists {
s.children[field] = NewSubscriptionTreeNode()
}
s.nodeLock.Unlock()
s.nodeLock.RLock()
}
child, _ := s.children[field]
s.nodeLock.RUnlock()
return child.findNode(fields[1:])
}
func (s *SubscriptionTreeNode) Subscribe(topic packets.Topic, sub Subscription) {
node := s.findNode(topic.Fields)
node.nodeLock.Lock()
node.subscriptions = append(node.subscriptions, sub)
node.nodeLock.Unlock()
}
func (s *SubscriptionTreeNode) GetSubscriptions(topic string) ([]Subscription, sync.Locker) {
fields := strings.Split(topic,"/")
child := s.findNode(fields)
locker := child.nodeLock.RLocker()
locker.Lock()
return child.subscriptions, locker
}
func (s *SubscriptionTreeNode) findMatchingRec(topic []string) ([]Subscription, sync.Locker) {
locker := s.nodeLock.RLocker()
s.nodeLock.RLock()
if len(topic) == 0 {
return s.subscriptions,locker
}
defer s.nodeLock.RUnlock()
child, exists := s.children[topic[0]]
if exists {
return child.findMatchingRec(topic[1:])
} else {
return []Subscription{},locker
}
}

View file

@ -0,0 +1,106 @@
package subscription
//TODO WILDCARD SUBSCRIPTIONS
import (
"strings"
"sync"
"badat.dev/maeqtt/v2/mqtt/packets"
)
var Subscriptions SubscriptionTreeNode = *newSubscriptionTreeNode()
type SubscriptionChannel chan packets.PublishPacket
type Subscription struct {
SubscriptionChannel
packets.TopicFilter
}
type SubscriptionTreeNode struct {
subscriptions []Subscription
children map[string]*SubscriptionTreeNode
nodeLock sync.RWMutex
}
func newSubscriptionTreeNode() *SubscriptionTreeNode {
s := SubscriptionTreeNode{}
s.children = make(map[string]*SubscriptionTreeNode)
return &s
}
func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode {
if len(fields) == 0 {
return s
}
field := fields[0]
s.nodeLock.RLock()
_, exists := s.children[field]
// Insert a value into the map if one doesn't exist yet
if !exists {
// Can't upgrade a read lock so we need to unlock and
// check again, this time with a write lock
s.nodeLock.RUnlock()
s.nodeLock.Lock()
_, exists = s.children[field]
if !exists {
s.children[field] = newSubscriptionTreeNode()
}
s.nodeLock.Unlock()
s.nodeLock.RLock()
}
child, _ := s.children[field]
s.nodeLock.RUnlock()
return child.findNode(fields[1:])
}
func (s *SubscriptionTreeNode) removeSubscription(subChan SubscriptionChannel) {
for i, sub := range s.subscriptions {
if sub.SubscriptionChannel == subChan {
lst := len(s.subscriptions) - 1
s.subscriptions[i] = s.subscriptions[lst]
s.subscriptions = s.subscriptions[:lst]
}
}
}
func (s *SubscriptionTreeNode) Subscribe(topicFilter packets.TopicFilter, subChan SubscriptionChannel) {
sub := Subscription{subChan, topicFilter}
node := s.findNode(topicFilter.Topic.Fields)
node.nodeLock.Lock()
node.subscriptions = append(node.subscriptions, sub)
node.nodeLock.Unlock()
}
func (s *SubscriptionTreeNode) Unsubscribe(topic packets.Topic, subChan SubscriptionChannel) {
node := s.findNode(topic.Fields)
node.nodeLock.Lock()
node.removeSubscription(subChan)
node.nodeLock.Unlock()
}
func (s *SubscriptionTreeNode) RemoveSubsForChannel(subChan SubscriptionChannel) {
for _, node := range s.children {
node.nodeLock.Lock()
node.removeSubscription(subChan)
node.nodeLock.Unlock()
node.RemoveSubsForChannel(subChan)
}
}
func (s *SubscriptionTreeNode) GetSubscriptions(topicName string) ([]Subscription, sync.Locker) {
fields := strings.Split(topicName, "/")
child := s.findNode(fields)
locker := child.nodeLock.RLocker()
locker.Lock()
return child.subscriptions, locker
}

View file

@ -0,0 +1,27 @@
package subscription
import (
"testing"
"badat.dev/maeqtt/v2/mqtt/packets"
)
func TestSubscribe(t *testing.T) {
tree := newSubscriptionTreeNode()
topic, _ := packets.ParseTopic("a/b/c")
channel := make(SubscriptionChannel)
topicFilter := packets.TopicFilter{
Topic: topic,
MaxQoS: 1,
}
tree.Subscribe(topicFilter, channel)
subs, lock := tree.GetSubscriptions("a/b/c")
defer lock.Unlock()
if len(subs) != 1 {
t.Errorf("Error storing subscriptions, expected to len(subs) to be 1, got: %v \n", len(subs))
}
if subs[0].MaxQoS != topicFilter.MaxQoS || subs[0].SubscriptionChannel != channel {
t.Error("Error with data stored in a subscription")
}
}