Improve folder strutcute
This commit is contained in:
parent
e9612e2430
commit
35879183be
15 changed files with 606 additions and 461 deletions
26
main.go
26
main.go
|
@ -7,6 +7,7 @@ import (
|
|||
"runtime/debug"
|
||||
|
||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||
"badat.dev/maeqtt/v2/session"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
@ -27,7 +28,7 @@ func main() {
|
|||
}
|
||||
|
||||
func handleConnection(con net.Conn) {
|
||||
defer handlePanic()
|
||||
defer handlePanic(con)
|
||||
|
||||
reader := bufio.NewReader(con)
|
||||
|
||||
|
@ -38,20 +39,33 @@ func handleConnection(con net.Conn) {
|
|||
}
|
||||
connect, isConn := (*packet).(packets.ConnectPacket)
|
||||
if !isConn {
|
||||
log.Println("Didn't recieve a connet packet")
|
||||
panic("TODO: Send a disconnect packet")
|
||||
log.Println("Didn't recieve a connect packet")
|
||||
err := packets.DisconnectPacket{
|
||||
ReasonCode: packets.DisconnectReasonCodeProtocolError,
|
||||
}.Write(con)
|
||||
if err != nil {
|
||||
log.Println("Failed to disconnect after not recieving a connect packet", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
conn := NewConnection(connect, con)
|
||||
conn := session.NewConnection(connect, con)
|
||||
sess := session.NewSession(&conn, connect)
|
||||
|
||||
sess := NewSession(&conn, connect)
|
||||
sess.HandlerLoop()
|
||||
}
|
||||
|
||||
func handlePanic() {
|
||||
func handlePanic(con net.Conn) {
|
||||
if r := recover(); r != nil {
|
||||
log.Println("Recovering from panic:", r)
|
||||
log.Println("Stack Trace:")
|
||||
debug.PrintStack()
|
||||
|
||||
err := packets.DisconnectPacket{
|
||||
ReasonCode: packets.DisconnectReasonCodeImplErorr,
|
||||
}.Write(con)
|
||||
if err != nil {
|
||||
log.Println("Failed to send a disconnect packet after recovering from panic", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
package mqtt
|
||||
// This code has been generated with the genProps.py script. Do not modify
|
||||
|
||||
// This code has been generated with the genProps.py script. Do not modify
|
||||
|
||||
import "bufio"
|
||||
|
||||
|
@ -21,7 +21,6 @@ func (p *PayloadFormatIndicator) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type MessageExpiryInterval struct {
|
||||
value *uint32
|
||||
}
|
||||
|
@ -39,7 +38,6 @@ func (p *MessageExpiryInterval) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type ContentType struct {
|
||||
value *string
|
||||
}
|
||||
|
@ -57,7 +55,6 @@ func (p *ContentType) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type ResponseTopic struct {
|
||||
value *string
|
||||
}
|
||||
|
@ -75,7 +72,6 @@ func (p *ResponseTopic) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type CorrelationData struct {
|
||||
value *[]byte
|
||||
}
|
||||
|
@ -93,7 +89,6 @@ func (p *CorrelationData) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type SubscriptionIdentifier struct {
|
||||
value *int
|
||||
}
|
||||
|
@ -111,7 +106,6 @@ func (p *SubscriptionIdentifier) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type SessionExpiryInterval struct {
|
||||
value *uint32
|
||||
}
|
||||
|
@ -129,7 +123,6 @@ func (p *SessionExpiryInterval) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type AssignedClientIdentifier struct {
|
||||
value *string
|
||||
}
|
||||
|
@ -147,7 +140,6 @@ func (p *AssignedClientIdentifier) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type ServerKeepAlive struct {
|
||||
value *uint16
|
||||
}
|
||||
|
@ -165,7 +157,6 @@ func (p *ServerKeepAlive) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type AuthenticationMethod struct {
|
||||
value *string
|
||||
}
|
||||
|
@ -183,7 +174,6 @@ func (p *AuthenticationMethod) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type AuthenticationData struct {
|
||||
value *[]byte
|
||||
}
|
||||
|
@ -201,7 +191,6 @@ func (p *AuthenticationData) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type RequestProblemInformation struct {
|
||||
value *byte
|
||||
}
|
||||
|
@ -219,7 +208,6 @@ func (p *RequestProblemInformation) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type WillDelayInterval struct {
|
||||
value *uint32
|
||||
}
|
||||
|
@ -237,7 +225,6 @@ func (p *WillDelayInterval) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type RequestResponseInformation struct {
|
||||
value *byte
|
||||
}
|
||||
|
@ -255,7 +242,6 @@ func (p *RequestResponseInformation) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type ResponseInformation struct {
|
||||
value *string
|
||||
}
|
||||
|
@ -273,7 +259,6 @@ func (p *ResponseInformation) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type ServerReference struct {
|
||||
value *string
|
||||
}
|
||||
|
@ -291,7 +276,6 @@ func (p *ServerReference) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type ReasonString struct {
|
||||
value *string
|
||||
}
|
||||
|
@ -309,7 +293,6 @@ func (p *ReasonString) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type ReceiveMaximum struct {
|
||||
value *uint16
|
||||
}
|
||||
|
@ -327,7 +310,6 @@ func (p *ReceiveMaximum) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type TopicAliasMaximum struct {
|
||||
value *uint16
|
||||
}
|
||||
|
@ -345,7 +327,6 @@ func (p *TopicAliasMaximum) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type TopicAlias struct {
|
||||
value *uint16
|
||||
}
|
||||
|
@ -363,7 +344,6 @@ func (p *TopicAlias) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type MaximumQoS struct {
|
||||
value *byte
|
||||
}
|
||||
|
@ -381,7 +361,6 @@ func (p *MaximumQoS) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type RetainAvailable struct {
|
||||
value *byte
|
||||
}
|
||||
|
@ -399,7 +378,6 @@ func (p *RetainAvailable) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type MaximumPacketSize struct {
|
||||
value *uint32
|
||||
}
|
||||
|
@ -417,7 +395,6 @@ func (p *MaximumPacketSize) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type WildcardSubscriptionAvailable struct {
|
||||
value *byte
|
||||
}
|
||||
|
@ -435,7 +412,6 @@ func (p *WildcardSubscriptionAvailable) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type SubscriptionIdentifierAvailable struct {
|
||||
value *byte
|
||||
}
|
||||
|
@ -453,7 +429,6 @@ func (p *SubscriptionIdentifierAvailable) parse(r *bufio.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
|
||||
type SharedSubscriptionAvailable struct {
|
||||
value *byte
|
||||
}
|
||||
|
@ -481,6 +456,7 @@ SubscriptionIdentifier SubscriptionIdentifier
|
|||
TopicAlias TopicAlias
|
||||
UserProperty UserProperty
|
||||
}
|
||||
|
||||
func (p *PublishPacketProperties) arrayOf() []Property {
|
||||
return []Property{
|
||||
&p.PayloadFormatIndicator,
|
||||
|
@ -493,6 +469,7 @@ return []Property {
|
|||
&p.UserProperty,
|
||||
}
|
||||
}
|
||||
|
||||
type WillProperties struct {
|
||||
PayloadFormatIndicator PayloadFormatIndicator
|
||||
MessageExpiryInterval MessageExpiryInterval
|
||||
|
@ -502,6 +479,7 @@ CorrelationData CorrelationData
|
|||
WillDelayInterval WillDelayInterval
|
||||
UserProperty UserProperty
|
||||
}
|
||||
|
||||
func (p *WillProperties) arrayOf() []Property {
|
||||
return []Property{
|
||||
&p.PayloadFormatIndicator,
|
||||
|
@ -513,16 +491,19 @@ return []Property {
|
|||
&p.UserProperty,
|
||||
}
|
||||
}
|
||||
|
||||
type SubscribePacketProperties struct {
|
||||
SubscriptionIdentifier SubscriptionIdentifier
|
||||
UserProperty UserProperty
|
||||
}
|
||||
|
||||
func (p *SubscribePacketProperties) arrayOf() []Property {
|
||||
return []Property{
|
||||
&p.SubscriptionIdentifier,
|
||||
&p.UserProperty,
|
||||
}
|
||||
}
|
||||
|
||||
type ConnectPacketProperties struct {
|
||||
SessionExpiryInterval SessionExpiryInterval
|
||||
AuthenticationMethod AuthenticationMethod
|
||||
|
@ -534,6 +515,7 @@ TopicAliasMaximum TopicAliasMaximum
|
|||
UserProperty UserProperty
|
||||
MaximumPacketSize MaximumPacketSize
|
||||
}
|
||||
|
||||
func (p *ConnectPacketProperties) arrayOf() []Property {
|
||||
return []Property{
|
||||
&p.SessionExpiryInterval,
|
||||
|
@ -547,6 +529,7 @@ return []Property {
|
|||
&p.MaximumPacketSize,
|
||||
}
|
||||
}
|
||||
|
||||
type ConnackPacketProperties struct {
|
||||
SessionExpiryInterval SessionExpiryInterval
|
||||
AssignedClientIdentifier AssignedClientIdentifier
|
||||
|
@ -566,6 +549,7 @@ WildcardSubscriptionAvailable WildcardSubscriptionAvailable
|
|||
SubscriptionIdentifierAvailable SubscriptionIdentifierAvailable
|
||||
SharedSubscriptionAvailable SharedSubscriptionAvailable
|
||||
}
|
||||
|
||||
func (p *ConnackPacketProperties) arrayOf() []Property {
|
||||
return []Property{
|
||||
&p.SessionExpiryInterval,
|
||||
|
@ -587,12 +571,14 @@ return []Property {
|
|||
&p.SharedSubscriptionAvailable,
|
||||
}
|
||||
}
|
||||
|
||||
type DisconnectPacketProperties struct {
|
||||
SessionExpiryInterval SessionExpiryInterval
|
||||
ServerReference ServerReference
|
||||
ReasonString ReasonString
|
||||
UserProperty UserProperty
|
||||
}
|
||||
|
||||
func (p *DisconnectPacketProperties) arrayOf() []Property {
|
||||
return []Property{
|
||||
&p.SessionExpiryInterval,
|
||||
|
@ -601,12 +587,14 @@ return []Property {
|
|||
&p.UserProperty,
|
||||
}
|
||||
}
|
||||
|
||||
type AuthPacketProperties struct {
|
||||
AuthenticationMethod AuthenticationMethod
|
||||
AuthenticationData AuthenticationData
|
||||
ReasonString ReasonString
|
||||
UserProperty UserProperty
|
||||
}
|
||||
|
||||
func (p *AuthPacketProperties) arrayOf() []Property {
|
||||
return []Property{
|
||||
&p.AuthenticationMethod,
|
||||
|
@ -615,69 +603,83 @@ return []Property {
|
|||
&p.UserProperty,
|
||||
}
|
||||
}
|
||||
|
||||
type PubackPacketProperties struct {
|
||||
ReasonString ReasonString
|
||||
UserProperty UserProperty
|
||||
}
|
||||
|
||||
func (p *PubackPacketProperties) arrayOf() []Property {
|
||||
return []Property{
|
||||
&p.ReasonString,
|
||||
&p.UserProperty,
|
||||
}
|
||||
}
|
||||
|
||||
type PubrecPacketProperties struct {
|
||||
ReasonString ReasonString
|
||||
UserProperty UserProperty
|
||||
}
|
||||
|
||||
func (p *PubrecPacketProperties) arrayOf() []Property {
|
||||
return []Property{
|
||||
&p.ReasonString,
|
||||
&p.UserProperty,
|
||||
}
|
||||
}
|
||||
|
||||
type PubrelPacketProperties struct {
|
||||
ReasonString ReasonString
|
||||
UserProperty UserProperty
|
||||
}
|
||||
|
||||
func (p *PubrelPacketProperties) arrayOf() []Property {
|
||||
return []Property{
|
||||
&p.ReasonString,
|
||||
&p.UserProperty,
|
||||
}
|
||||
}
|
||||
|
||||
type PubcompPacketProperties struct {
|
||||
ReasonString ReasonString
|
||||
UserProperty UserProperty
|
||||
}
|
||||
|
||||
func (p *PubcompPacketProperties) arrayOf() []Property {
|
||||
return []Property{
|
||||
&p.ReasonString,
|
||||
&p.UserProperty,
|
||||
}
|
||||
}
|
||||
|
||||
type SubackPacketProperties struct {
|
||||
ReasonString ReasonString
|
||||
UserProperty UserProperty
|
||||
}
|
||||
|
||||
func (p *SubackPacketProperties) arrayOf() []Property {
|
||||
return []Property{
|
||||
&p.ReasonString,
|
||||
&p.UserProperty,
|
||||
}
|
||||
}
|
||||
|
||||
type UnsubackPacketProperties struct {
|
||||
ReasonString ReasonString
|
||||
UserProperty UserProperty
|
||||
}
|
||||
|
||||
func (p *UnsubackPacketProperties) arrayOf() []Property {
|
||||
return []Property{
|
||||
&p.ReasonString,
|
||||
&p.UserProperty,
|
||||
}
|
||||
}
|
||||
|
||||
type UnsubscribePacketProperties struct {
|
||||
UserProperty UserProperty
|
||||
}
|
||||
|
||||
func (p *UnsubscribePacketProperties) arrayOf() []Property {
|
||||
return []Property{
|
||||
&p.UserProperty,
|
||||
|
|
|
@ -60,7 +60,6 @@ func parseDisconnectPacket(control controlPacket) (DisconnectPacket, error) {
|
|||
|
||||
r := bufio.NewReader(control.reader)
|
||||
|
||||
|
||||
// If there is less then a byte in the reader assume the reason code == 0
|
||||
reason, err := r.ReadByte()
|
||||
if err == io.EOF {
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"io"
|
||||
)
|
||||
|
||||
|
||||
type PingreqPacket struct{}
|
||||
|
||||
func parsePingreq(control controlPacket) (PingreqPacket, error) {
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
package packets
|
||||
|
||||
import (
|
||||
"io"
|
||||
"badat.dev/maeqtt/v2/mqtt/properties"
|
||||
"io"
|
||||
)
|
||||
|
||||
|
||||
type PubackReasonCode byte
|
||||
|
||||
const (
|
||||
|
|
|
@ -2,6 +2,7 @@ package packets
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
|
@ -15,7 +16,8 @@ type Topic struct {
|
|||
}
|
||||
|
||||
var multiLevelWildcardNotLast = errors.New("Multi level wildcard isn't the field in a topic")
|
||||
func parseTopic(topic_name string) (Topic, error) {
|
||||
|
||||
func ParseTopic(topic_name string) (Topic, error) {
|
||||
topic := Topic{}
|
||||
fields := strings.Split(topic_name, "/")
|
||||
for i, field := range fields {
|
||||
|
@ -45,7 +47,7 @@ func parseTopicFilter(r *bufio.Reader) (TopicFilter, error) {
|
|||
return filter, err
|
||||
}
|
||||
|
||||
filter.Topic, err = parseTopic(topic_str)
|
||||
filter.Topic, err = ParseTopic(topic_str)
|
||||
if err != nil {
|
||||
return filter, err
|
||||
}
|
||||
|
@ -61,27 +63,32 @@ func parseTopicFilter(r *bufio.Reader) (TopicFilter, error) {
|
|||
return filter, nil
|
||||
}
|
||||
|
||||
// Both sub and unsubscribe packets are identitcal so we can reuse the parsing logic
|
||||
type SubscriptionPacket struct {
|
||||
type SubscribePacket struct {
|
||||
PacketId uint16
|
||||
TopicFilters []TopicFilter
|
||||
Properties properties.SubscribePacketProperties
|
||||
}
|
||||
|
||||
func parseSubscriptionPacket(control controlPacket, props []properties.Property) (SubscriptionPacket, error) {
|
||||
var err error
|
||||
func parseSubscribePacket(control controlPacket) (SubscribePacket, error) {
|
||||
if control.packetType != PacketTypeSubscribe {
|
||||
panic("Wrong packet type for parseSubscribePacket")
|
||||
}
|
||||
|
||||
packet := SubscribePacket{}
|
||||
|
||||
r := bufio.NewReader(control.reader)
|
||||
packet := SubscriptionPacket{}
|
||||
|
||||
if control.flags != 2 {
|
||||
return packet, errors.New("Malformed subscription packet")
|
||||
}
|
||||
|
||||
var err error
|
||||
packet.PacketId, err = types.DecodeUint16(r)
|
||||
if err != nil {
|
||||
return packet, err
|
||||
}
|
||||
|
||||
err = properties.ParseProperties(r, props)
|
||||
err = properties.ParseProperties(r, packet.Properties.ArrayOf())
|
||||
if err != nil {
|
||||
return packet, err
|
||||
}
|
||||
|
@ -100,33 +107,10 @@ func parseSubscriptionPacket(control controlPacket, props []properties.Property)
|
|||
return packet, nil
|
||||
}
|
||||
}
|
||||
println("A")
|
||||
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
type SubscribePacket struct {
|
||||
*SubscriptionPacket
|
||||
props properties.SubscribePacketProperties
|
||||
}
|
||||
|
||||
/// CURRENTLY BROKEN
|
||||
|
||||
// TODO FIXME AAAAA
|
||||
func parseSubscribePacket(control controlPacket) (SubscribePacket, error) {
|
||||
if control.packetType != PacketTypeSubscribe {
|
||||
panic("Wrong packet type for parseSubscribePacket")
|
||||
}
|
||||
|
||||
pack := SubscribePacket{}
|
||||
subscriptionPack, err := parseSubscriptionPacket(control, pack.props.ArrayOf())
|
||||
if err != nil {
|
||||
return pack, err
|
||||
}
|
||||
pack.SubscriptionPacket = &subscriptionPack
|
||||
return pack, nil
|
||||
}
|
||||
|
||||
func (p SubscribePacket) Visit(v PacketVisitor) {
|
||||
v.VisitSubscribe(p)
|
||||
}
|
||||
|
@ -143,9 +127,9 @@ const (
|
|||
SubackReasonTopicFilterInvalid = 143
|
||||
SubackReasonPacketIDInUse = 145
|
||||
SubackReasonQuotaExceeded = 151
|
||||
SubackReasonSharedSubNotSupported = 151
|
||||
SubackReasonSubIDUnsupported = 151
|
||||
SubackReasonWildcardSubUnsupported = 151
|
||||
SubackReasonSharedSubNotSupported = 158
|
||||
SubackReasonSubIDUnsupported = 161
|
||||
SubackReasonWildcardSubUnsupported = 162
|
||||
)
|
||||
|
||||
type SubAckPacket struct {
|
||||
|
@ -154,20 +138,35 @@ type SubAckPacket struct {
|
|||
Reason SubackReasonCode
|
||||
}
|
||||
|
||||
|
||||
func (p SubAckPacket) Write(w io.Writer) error {
|
||||
resp := pubRespPacket{
|
||||
PacketType: PacketTypeSuback,
|
||||
PacketID: p.PacketID,
|
||||
Properties: p.Properties.ArrayOf(),
|
||||
Reason: byte(p.Reason),
|
||||
buf := bytes.NewBuffer([]byte{})
|
||||
err := types.WriteUint16(buf, p.PacketID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return resp.Write(w)
|
||||
|
||||
err = properties.WriteProps(buf, p.Properties.ArrayOf())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = buf.WriteByte(byte(p.Reason))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conPack := controlPacket{
|
||||
packetType: PacketTypeSuback,
|
||||
flags: 0,
|
||||
reader: buf,
|
||||
}
|
||||
return conPack.write(w)
|
||||
}
|
||||
|
||||
type UnsubscribePacket struct {
|
||||
*SubscriptionPacket
|
||||
props properties.UnsubscribePacketProperties
|
||||
PacketID uint16
|
||||
Topics []Topic
|
||||
Properties properties.UnsubscribePacketProperties
|
||||
}
|
||||
|
||||
func parseUnsubscribePacket(control controlPacket) (UnsubscribePacket, error) {
|
||||
|
@ -175,14 +174,41 @@ func parseUnsubscribePacket(control controlPacket) (UnsubscribePacket, error) {
|
|||
panic("Wrong packet type for parseSubscribePacket")
|
||||
}
|
||||
|
||||
pack := UnsubscribePacket{}
|
||||
subscriptionPack, err := parseSubscriptionPacket(control, pack.props.ArrayOf())
|
||||
if err != nil {
|
||||
return pack, err
|
||||
packet := UnsubscribePacket{}
|
||||
r := bufio.NewReader(control.reader)
|
||||
|
||||
if control.flags != 2 {
|
||||
return packet, errors.New("Malformed subscription packet")
|
||||
}
|
||||
pack.PacketId = subscriptionPack.PacketId
|
||||
pack.TopicFilters = subscriptionPack.TopicFilters
|
||||
return pack, nil
|
||||
|
||||
var err error
|
||||
packet.PacketID, err = types.DecodeUint16(r)
|
||||
if err != nil {
|
||||
return packet, err
|
||||
}
|
||||
|
||||
err = properties.ParseProperties(r, packet.Properties.ArrayOf())
|
||||
if err != nil {
|
||||
return packet, err
|
||||
}
|
||||
|
||||
for err != io.EOF {
|
||||
topic_str, err := types.DecodeUTF8String(r)
|
||||
if err != nil && err != io.EOF {
|
||||
return packet, err
|
||||
} else if err == io.EOF {
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
filter, err := ParseTopic(topic_str)
|
||||
if err != nil {
|
||||
return packet, err
|
||||
}
|
||||
|
||||
packet.Topics = append(packet.Topics, filter)
|
||||
}
|
||||
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
func (p UnsubscribePacket) Visit(v PacketVisitor) {
|
||||
|
@ -192,7 +218,7 @@ func (p UnsubscribePacket) Visit(v PacketVisitor) {
|
|||
type UnsubackReasonCode byte
|
||||
|
||||
const (
|
||||
UnsubackReasonSuccess PubackReasonCode = 0
|
||||
UnsubackReasonSuccess UnsubackReasonCode = 0
|
||||
UnSubackReasonUnspecified = 128
|
||||
UnSubackReasonImplSpecificError = 131
|
||||
UnSubackReasonNotAuthorized = 135
|
||||
|
@ -206,13 +232,27 @@ type UnsubAckPacket struct {
|
|||
Reason UnsubackReasonCode
|
||||
}
|
||||
|
||||
|
||||
func (p UnsubAckPacket) Write(w io.Writer) error {
|
||||
resp := pubRespPacket{
|
||||
PacketType: PacketTypeUnsuback,
|
||||
PacketID: p.PacketID,
|
||||
Properties: p.Properties.ArrayOf(),
|
||||
Reason: byte(p.Reason),
|
||||
buf := bytes.NewBuffer([]byte{})
|
||||
err := types.WriteUint16(buf, p.PacketID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return resp.Write(w)
|
||||
|
||||
err = properties.WriteProps(buf, p.Properties.ArrayOf())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = buf.WriteByte(byte(p.Reason))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conPack := controlPacket{
|
||||
packetType: PacketTypeUnsuback,
|
||||
flags: 0,
|
||||
reader: buf,
|
||||
}
|
||||
return conPack.write(w)
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ func (p pubRespPacket) Write(w io.Writer) error {
|
|||
}
|
||||
|
||||
conPack := controlPacket{
|
||||
packetType: PacketTypePuback,
|
||||
packetType: p.PacketType,
|
||||
flags: 0,
|
||||
reader: buf,
|
||||
}
|
||||
|
|
|
@ -57,7 +57,6 @@ func DecodeBinaryData(r *bufio.Reader) ([]byte, error) {
|
|||
return buffer, err
|
||||
}
|
||||
|
||||
|
||||
func DecodeUTF8String(r *bufio.Reader) (string, error) {
|
||||
binary, err := DecodeBinaryData(r)
|
||||
return string(binary[:]), err
|
||||
|
|
|
@ -29,8 +29,8 @@ func WriteUint32(w io.Writer, v uint32) error {
|
|||
return err
|
||||
}
|
||||
|
||||
|
||||
const uint32Max uint32 = ^uint32(0)
|
||||
|
||||
func WriteDataWithVarIntLen(w io.Writer, data []byte) error {
|
||||
if len(data) > int(uint32Max) {
|
||||
return errors.New("Tried to write more data than max varint size")
|
||||
|
@ -46,6 +46,7 @@ func WriteDataWithVarIntLen(w io.Writer, data []byte) error {
|
|||
}
|
||||
|
||||
const uint16Max uint16 = ^uint16(0)
|
||||
|
||||
func WriteBinaryData(w io.Writer, data []byte) error {
|
||||
if len(data) > int(uint16Max) {
|
||||
return errors.New("Tried to write more data than max uint16 size")
|
||||
|
@ -64,7 +65,6 @@ func WriteUTF8String(w io.Writer, str string) error {
|
|||
return WriteBinaryData(w, []byte(str))
|
||||
}
|
||||
|
||||
|
||||
func WriteVariableByteInt(w io.Writer, v uint32) error {
|
||||
for {
|
||||
encodedByte := byte(v % 128)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package main
|
||||
package session
|
||||
|
||||
import (
|
||||
"bufio"
|
|
@ -1,13 +1,15 @@
|
|||
package main
|
||||
package session
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||
"badat.dev/maeqtt/v2/subscription"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -25,8 +27,10 @@ type Session struct {
|
|||
Connection *Connection
|
||||
SubscriptionChannel chan packets.PublishPacket
|
||||
|
||||
ExpiryInterval time.Duration
|
||||
ExpiryInterval time.Duration // TODO
|
||||
expireTimer time.Timer // TODO
|
||||
|
||||
freePacketID uint16
|
||||
}
|
||||
|
||||
func NewSession(conn *Connection, p packets.ConnectPacket) Session {
|
||||
|
@ -39,8 +43,7 @@ func NewSession(conn *Connection, p packets.ConnectPacket) Session {
|
|||
|
||||
func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) {
|
||||
if s.Connection != nil {
|
||||
//TODO
|
||||
panic("Disconnect if already have a connection, unimplemented")
|
||||
s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver)
|
||||
}
|
||||
connAck := packets.ConnackPacket{}
|
||||
|
||||
|
@ -60,7 +63,6 @@ func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) {
|
|||
connAck.Properties.RetainAvailable.Value = &false
|
||||
connAck.Properties.SharedSubscriptionAvailable.Value = &false
|
||||
|
||||
|
||||
s.Connection = conn
|
||||
err := s.Connection.sendPacket(connAck)
|
||||
if err != nil {
|
||||
|
@ -75,7 +77,7 @@ func (s *Session) HandlerLoop() {
|
|||
case packet := <-s.Connection.PacketChannel:
|
||||
packet.Visit(s)
|
||||
case _ = <-s.Connection.ClientDisconnectedChan:
|
||||
s.OnDisconnect()
|
||||
s.onDisconnect()
|
||||
case subMessage := <-s.SubscriptionChannel:
|
||||
//TODO, log for now
|
||||
log.Printf("Recieved subscription message, handling UNIMPLEMENTED, message: %v", subMessage)
|
||||
|
@ -83,22 +85,29 @@ func (s *Session) HandlerLoop() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Session) Disconnect() error {
|
||||
panic("Disconnection unimplemented")
|
||||
func (s *Session) Disconnect(code packets.DisconnectReasonCode) error {
|
||||
s.Connection.sendPacket(packets.DisconnectPacket{
|
||||
ReasonCode: code,
|
||||
})
|
||||
|
||||
err := s.Connection.close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.OnDisconnect()
|
||||
s.onDisconnect()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) OnDisconnect() {
|
||||
func (s *Session) onDisconnect() {
|
||||
s.Connection = nil
|
||||
s.resetExpireTimer()
|
||||
log.Printf("Client disconnected, id: %s", *s.ClientID)
|
||||
}
|
||||
|
||||
func (s *Session) expireSession() {
|
||||
subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel)
|
||||
}
|
||||
|
||||
// newTime is nullable
|
||||
func (s *Session) updateExpireTimer(newTime *uint32) {
|
||||
var expiry = uint32(0)
|
||||
|
@ -129,35 +138,67 @@ func genClientID() *string {
|
|||
}
|
||||
|
||||
func (s *Session) VisitConnect(_ packets.ConnectPacket) {
|
||||
// ERROR CANNOT RECIEVE CONNECT ON AN ALREADY OPEN CONNECTION
|
||||
s.Disconnect()
|
||||
// Disconnect, we handle the connect packet in Connect,
|
||||
// this means that we have an estabilished connection already
|
||||
log.Println("WARN: Got a connect packet on an already estabilished connection")
|
||||
s.Disconnect(packets.DisconnectReasonCodeProtocolError)
|
||||
}
|
||||
|
||||
func (s *Session) VisitPublish(p packets.PublishPacket) {
|
||||
println("UNIMPLEMENTED, Publishing packet, message:", string(p.Payload))
|
||||
subs, lock := Subscriptions.GetSubscriptions(p.TopicName)
|
||||
subs, lock := subscription.Subscriptions.GetSubscriptions(p.TopicName)
|
||||
defer lock.Unlock()
|
||||
if p.QOSLevel == 0 {
|
||||
if p.PacketId != nil {
|
||||
log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID)
|
||||
return
|
||||
}
|
||||
} else if p.QOSLevel == 1 {
|
||||
var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess
|
||||
if len(subs) == 0 {
|
||||
reason = packets.PubackReasonCodeNoMatchingSubscribers
|
||||
}
|
||||
ack := packets.PubackPacket{
|
||||
PacketID: *p.PacketId,
|
||||
Reason: reason,
|
||||
}
|
||||
s.Connection.sendPacket(ack)
|
||||
} else if p.QOSLevel == 2 {
|
||||
panic("UNIMPLEMENTED QOS level 2")
|
||||
}
|
||||
|
||||
for _, sub := range subs {
|
||||
go func(sub Subscription) {sub <- p}(sub)
|
||||
if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) {
|
||||
go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) VisitDisconnect(p packets.DisconnectPacket) {
|
||||
//TODO FINISH
|
||||
// HANDLE CLIENT DISCONNECTING
|
||||
s.OnDisconnect()
|
||||
err := s.Connection.close()
|
||||
if err != nil && err != io.ErrClosedPipe {
|
||||
log.Println("Error closing connection", err)
|
||||
}
|
||||
s.onDisconnect()
|
||||
}
|
||||
|
||||
func (s *Session) VisitSubscribe(p packets.SubscribePacket) {
|
||||
//TODO FINISH
|
||||
for _, filter := range p.TopicFilters {
|
||||
Subscriptions.Subscribe(filter.Topic, s.SubscriptionChannel)
|
||||
subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel)
|
||||
}
|
||||
s.Connection.sendPacket(packets.SubAckPacket{
|
||||
PacketID: p.PacketId,
|
||||
Reason: packets.SubackReasonGrantedQoSTwo,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) VisitUnsubscribe(_ packets.UnsubscribePacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) {
|
||||
for _, topic := range p.Topics {
|
||||
subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel)
|
||||
}
|
||||
s.Connection.sendPacket(packets.UnsubAckPacket{
|
||||
PacketID: p.PacketID,
|
||||
Reason: packets.UnsubackReasonSuccess,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) VisitPing(p packets.PingreqPacket) {
|
||||
|
@ -179,3 +220,8 @@ func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) {
|
|||
func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
||||
|
||||
func (s *Session) getFreePacketId() uint16 {
|
||||
s.freePacketID += 1
|
||||
return s.freePacketID
|
||||
}
|
|
@ -1,86 +0,0 @@
|
|||
package main
|
||||
|
||||
//TODO FULLY IMPLEMENT SUBSCRIPTIONS INSTEAD OF JUST THE TOPIC FILTERS
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||
)
|
||||
|
||||
var Subscriptions SubscriptionTreeNode = *NewSubscriptionTreeNode()
|
||||
|
||||
type Subscription chan packets.PublishPacket
|
||||
|
||||
type SubscriptionTreeNode struct {
|
||||
subscriptions []Subscription
|
||||
children map[string]*SubscriptionTreeNode
|
||||
nodeLock sync.RWMutex
|
||||
}
|
||||
func NewSubscriptionTreeNode() *SubscriptionTreeNode {
|
||||
s := SubscriptionTreeNode{}
|
||||
s.children = make(map[string]*SubscriptionTreeNode)
|
||||
return &s
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode {
|
||||
if len(fields) == 0 {
|
||||
return s
|
||||
}
|
||||
|
||||
field := fields[0]
|
||||
|
||||
s.nodeLock.RLock()
|
||||
_, exists := s.children[field]
|
||||
// Insert a value into the map if one doesn't exist yet
|
||||
if !exists {
|
||||
// Can't upgrade a read lock so we need to unlock and
|
||||
// check again, this time with a write lock
|
||||
s.nodeLock.RUnlock()
|
||||
s.nodeLock.Lock()
|
||||
|
||||
_, exists = s.children[field]
|
||||
if !exists {
|
||||
s.children[field] = NewSubscriptionTreeNode()
|
||||
}
|
||||
s.nodeLock.Unlock()
|
||||
s.nodeLock.RLock()
|
||||
}
|
||||
|
||||
child, _ := s.children[field]
|
||||
s.nodeLock.RUnlock()
|
||||
return child.findNode(fields[1:])
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) Subscribe(topic packets.Topic, sub Subscription) {
|
||||
node := s.findNode(topic.Fields)
|
||||
node.nodeLock.Lock()
|
||||
node.subscriptions = append(node.subscriptions, sub)
|
||||
node.nodeLock.Unlock()
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) GetSubscriptions(topic string) ([]Subscription, sync.Locker) {
|
||||
fields := strings.Split(topic,"/")
|
||||
|
||||
child := s.findNode(fields)
|
||||
locker := child.nodeLock.RLocker()
|
||||
locker.Lock()
|
||||
return child.subscriptions, locker
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) findMatchingRec(topic []string) ([]Subscription, sync.Locker) {
|
||||
locker := s.nodeLock.RLocker()
|
||||
s.nodeLock.RLock()
|
||||
if len(topic) == 0 {
|
||||
return s.subscriptions,locker
|
||||
}
|
||||
defer s.nodeLock.RUnlock()
|
||||
|
||||
child, exists := s.children[topic[0]]
|
||||
if exists {
|
||||
return child.findMatchingRec(topic[1:])
|
||||
} else {
|
||||
return []Subscription{},locker
|
||||
}
|
||||
}
|
106
subscription/subscription.go
Normal file
106
subscription/subscription.go
Normal file
|
@ -0,0 +1,106 @@
|
|||
package subscription
|
||||
|
||||
//TODO WILDCARD SUBSCRIPTIONS
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||
)
|
||||
|
||||
var Subscriptions SubscriptionTreeNode = *newSubscriptionTreeNode()
|
||||
|
||||
type SubscriptionChannel chan packets.PublishPacket
|
||||
|
||||
type Subscription struct {
|
||||
SubscriptionChannel
|
||||
packets.TopicFilter
|
||||
}
|
||||
|
||||
type SubscriptionTreeNode struct {
|
||||
subscriptions []Subscription
|
||||
children map[string]*SubscriptionTreeNode
|
||||
nodeLock sync.RWMutex
|
||||
}
|
||||
|
||||
func newSubscriptionTreeNode() *SubscriptionTreeNode {
|
||||
s := SubscriptionTreeNode{}
|
||||
s.children = make(map[string]*SubscriptionTreeNode)
|
||||
return &s
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode {
|
||||
if len(fields) == 0 {
|
||||
return s
|
||||
}
|
||||
|
||||
field := fields[0]
|
||||
|
||||
s.nodeLock.RLock()
|
||||
_, exists := s.children[field]
|
||||
// Insert a value into the map if one doesn't exist yet
|
||||
if !exists {
|
||||
// Can't upgrade a read lock so we need to unlock and
|
||||
// check again, this time with a write lock
|
||||
s.nodeLock.RUnlock()
|
||||
s.nodeLock.Lock()
|
||||
|
||||
_, exists = s.children[field]
|
||||
if !exists {
|
||||
s.children[field] = newSubscriptionTreeNode()
|
||||
}
|
||||
s.nodeLock.Unlock()
|
||||
s.nodeLock.RLock()
|
||||
}
|
||||
|
||||
child, _ := s.children[field]
|
||||
s.nodeLock.RUnlock()
|
||||
return child.findNode(fields[1:])
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) removeSubscription(subChan SubscriptionChannel) {
|
||||
for i, sub := range s.subscriptions {
|
||||
if sub.SubscriptionChannel == subChan {
|
||||
lst := len(s.subscriptions) - 1
|
||||
s.subscriptions[i] = s.subscriptions[lst]
|
||||
s.subscriptions = s.subscriptions[:lst]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) Subscribe(topicFilter packets.TopicFilter, subChan SubscriptionChannel) {
|
||||
sub := Subscription{subChan, topicFilter}
|
||||
|
||||
node := s.findNode(topicFilter.Topic.Fields)
|
||||
node.nodeLock.Lock()
|
||||
node.subscriptions = append(node.subscriptions, sub)
|
||||
node.nodeLock.Unlock()
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) Unsubscribe(topic packets.Topic, subChan SubscriptionChannel) {
|
||||
node := s.findNode(topic.Fields)
|
||||
|
||||
node.nodeLock.Lock()
|
||||
node.removeSubscription(subChan)
|
||||
node.nodeLock.Unlock()
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) RemoveSubsForChannel(subChan SubscriptionChannel) {
|
||||
for _, node := range s.children {
|
||||
node.nodeLock.Lock()
|
||||
node.removeSubscription(subChan)
|
||||
node.nodeLock.Unlock()
|
||||
|
||||
node.RemoveSubsForChannel(subChan)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SubscriptionTreeNode) GetSubscriptions(topicName string) ([]Subscription, sync.Locker) {
|
||||
fields := strings.Split(topicName, "/")
|
||||
|
||||
child := s.findNode(fields)
|
||||
locker := child.nodeLock.RLocker()
|
||||
locker.Lock()
|
||||
return child.subscriptions, locker
|
||||
}
|
27
subscription/subscription_test.go
Normal file
27
subscription/subscription_test.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package subscription
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||
)
|
||||
|
||||
func TestSubscribe(t *testing.T) {
|
||||
tree := newSubscriptionTreeNode()
|
||||
topic, _ := packets.ParseTopic("a/b/c")
|
||||
channel := make(SubscriptionChannel)
|
||||
topicFilter := packets.TopicFilter{
|
||||
Topic: topic,
|
||||
MaxQoS: 1,
|
||||
}
|
||||
tree.Subscribe(topicFilter, channel)
|
||||
subs, lock := tree.GetSubscriptions("a/b/c")
|
||||
defer lock.Unlock()
|
||||
|
||||
if len(subs) != 1 {
|
||||
t.Errorf("Error storing subscriptions, expected to len(subs) to be 1, got: %v \n", len(subs))
|
||||
}
|
||||
if subs[0].MaxQoS != topicFilter.MaxQoS || subs[0].SubscriptionChannel != channel {
|
||||
t.Error("Error with data stored in a subscription")
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue