Improve folder strutcute

This commit is contained in:
bad 2021-10-01 22:18:48 +02:00
parent e9612e2430
commit 35879183be
15 changed files with 606 additions and 461 deletions

26
main.go
View file

@ -7,6 +7,7 @@ import (
"runtime/debug"
"badat.dev/maeqtt/v2/mqtt/packets"
"badat.dev/maeqtt/v2/session"
)
func main() {
@ -27,7 +28,7 @@ func main() {
}
func handleConnection(con net.Conn) {
defer handlePanic()
defer handlePanic(con)
reader := bufio.NewReader(con)
@ -38,20 +39,33 @@ func handleConnection(con net.Conn) {
}
connect, isConn := (*packet).(packets.ConnectPacket)
if !isConn {
log.Println("Didn't recieve a connet packet")
panic("TODO: Send a disconnect packet")
log.Println("Didn't recieve a connect packet")
err := packets.DisconnectPacket{
ReasonCode: packets.DisconnectReasonCodeProtocolError,
}.Write(con)
if err != nil {
log.Println("Failed to disconnect after not recieving a connect packet", err)
}
return
}
conn := NewConnection(connect, con)
conn := session.NewConnection(connect, con)
sess := session.NewSession(&conn, connect)
sess := NewSession(&conn, connect)
sess.HandlerLoop()
}
func handlePanic() {
func handlePanic(con net.Conn) {
if r := recover(); r != nil {
log.Println("Recovering from panic:", r)
log.Println("Stack Trace:")
debug.PrintStack()
err := packets.DisconnectPacket{
ReasonCode: packets.DisconnectReasonCodeImplErorr,
}.Write(con)
if err != nil {
log.Println("Failed to send a disconnect packet after recovering from panic", err)
}
}
}

View file

@ -1,6 +1,6 @@
package mqtt
// This code has been generated with the genProps.py script. Do not modify
// This code has been generated with the genProps.py script. Do not modify
import "bufio"
@ -21,7 +21,6 @@ func (p *PayloadFormatIndicator) parse(r *bufio.Reader) error {
return nil
}
type MessageExpiryInterval struct {
value *uint32
}
@ -39,7 +38,6 @@ func (p *MessageExpiryInterval) parse(r *bufio.Reader) error {
return nil
}
type ContentType struct {
value *string
}
@ -57,7 +55,6 @@ func (p *ContentType) parse(r *bufio.Reader) error {
return nil
}
type ResponseTopic struct {
value *string
}
@ -75,7 +72,6 @@ func (p *ResponseTopic) parse(r *bufio.Reader) error {
return nil
}
type CorrelationData struct {
value *[]byte
}
@ -93,7 +89,6 @@ func (p *CorrelationData) parse(r *bufio.Reader) error {
return nil
}
type SubscriptionIdentifier struct {
value *int
}
@ -111,7 +106,6 @@ func (p *SubscriptionIdentifier) parse(r *bufio.Reader) error {
return nil
}
type SessionExpiryInterval struct {
value *uint32
}
@ -129,7 +123,6 @@ func (p *SessionExpiryInterval) parse(r *bufio.Reader) error {
return nil
}
type AssignedClientIdentifier struct {
value *string
}
@ -147,7 +140,6 @@ func (p *AssignedClientIdentifier) parse(r *bufio.Reader) error {
return nil
}
type ServerKeepAlive struct {
value *uint16
}
@ -165,7 +157,6 @@ func (p *ServerKeepAlive) parse(r *bufio.Reader) error {
return nil
}
type AuthenticationMethod struct {
value *string
}
@ -183,7 +174,6 @@ func (p *AuthenticationMethod) parse(r *bufio.Reader) error {
return nil
}
type AuthenticationData struct {
value *[]byte
}
@ -201,7 +191,6 @@ func (p *AuthenticationData) parse(r *bufio.Reader) error {
return nil
}
type RequestProblemInformation struct {
value *byte
}
@ -219,7 +208,6 @@ func (p *RequestProblemInformation) parse(r *bufio.Reader) error {
return nil
}
type WillDelayInterval struct {
value *uint32
}
@ -237,7 +225,6 @@ func (p *WillDelayInterval) parse(r *bufio.Reader) error {
return nil
}
type RequestResponseInformation struct {
value *byte
}
@ -255,7 +242,6 @@ func (p *RequestResponseInformation) parse(r *bufio.Reader) error {
return nil
}
type ResponseInformation struct {
value *string
}
@ -273,7 +259,6 @@ func (p *ResponseInformation) parse(r *bufio.Reader) error {
return nil
}
type ServerReference struct {
value *string
}
@ -291,7 +276,6 @@ func (p *ServerReference) parse(r *bufio.Reader) error {
return nil
}
type ReasonString struct {
value *string
}
@ -309,7 +293,6 @@ func (p *ReasonString) parse(r *bufio.Reader) error {
return nil
}
type ReceiveMaximum struct {
value *uint16
}
@ -327,7 +310,6 @@ func (p *ReceiveMaximum) parse(r *bufio.Reader) error {
return nil
}
type TopicAliasMaximum struct {
value *uint16
}
@ -345,7 +327,6 @@ func (p *TopicAliasMaximum) parse(r *bufio.Reader) error {
return nil
}
type TopicAlias struct {
value *uint16
}
@ -363,7 +344,6 @@ func (p *TopicAlias) parse(r *bufio.Reader) error {
return nil
}
type MaximumQoS struct {
value *byte
}
@ -381,7 +361,6 @@ func (p *MaximumQoS) parse(r *bufio.Reader) error {
return nil
}
type RetainAvailable struct {
value *byte
}
@ -399,7 +378,6 @@ func (p *RetainAvailable) parse(r *bufio.Reader) error {
return nil
}
type MaximumPacketSize struct {
value *uint32
}
@ -417,7 +395,6 @@ func (p *MaximumPacketSize) parse(r *bufio.Reader) error {
return nil
}
type WildcardSubscriptionAvailable struct {
value *byte
}
@ -435,7 +412,6 @@ func (p *WildcardSubscriptionAvailable) parse(r *bufio.Reader) error {
return nil
}
type SubscriptionIdentifierAvailable struct {
value *byte
}
@ -453,7 +429,6 @@ func (p *SubscriptionIdentifierAvailable) parse(r *bufio.Reader) error {
return nil
}
type SharedSubscriptionAvailable struct {
value *byte
}
@ -481,6 +456,7 @@ SubscriptionIdentifier SubscriptionIdentifier
TopicAlias TopicAlias
UserProperty UserProperty
}
func (p *PublishPacketProperties) arrayOf() []Property {
return []Property{
&p.PayloadFormatIndicator,
@ -493,6 +469,7 @@ return []Property {
&p.UserProperty,
}
}
type WillProperties struct {
PayloadFormatIndicator PayloadFormatIndicator
MessageExpiryInterval MessageExpiryInterval
@ -502,6 +479,7 @@ CorrelationData CorrelationData
WillDelayInterval WillDelayInterval
UserProperty UserProperty
}
func (p *WillProperties) arrayOf() []Property {
return []Property{
&p.PayloadFormatIndicator,
@ -513,16 +491,19 @@ return []Property {
&p.UserProperty,
}
}
type SubscribePacketProperties struct {
SubscriptionIdentifier SubscriptionIdentifier
UserProperty UserProperty
}
func (p *SubscribePacketProperties) arrayOf() []Property {
return []Property{
&p.SubscriptionIdentifier,
&p.UserProperty,
}
}
type ConnectPacketProperties struct {
SessionExpiryInterval SessionExpiryInterval
AuthenticationMethod AuthenticationMethod
@ -534,6 +515,7 @@ TopicAliasMaximum TopicAliasMaximum
UserProperty UserProperty
MaximumPacketSize MaximumPacketSize
}
func (p *ConnectPacketProperties) arrayOf() []Property {
return []Property{
&p.SessionExpiryInterval,
@ -547,6 +529,7 @@ return []Property {
&p.MaximumPacketSize,
}
}
type ConnackPacketProperties struct {
SessionExpiryInterval SessionExpiryInterval
AssignedClientIdentifier AssignedClientIdentifier
@ -566,6 +549,7 @@ WildcardSubscriptionAvailable WildcardSubscriptionAvailable
SubscriptionIdentifierAvailable SubscriptionIdentifierAvailable
SharedSubscriptionAvailable SharedSubscriptionAvailable
}
func (p *ConnackPacketProperties) arrayOf() []Property {
return []Property{
&p.SessionExpiryInterval,
@ -587,12 +571,14 @@ return []Property {
&p.SharedSubscriptionAvailable,
}
}
type DisconnectPacketProperties struct {
SessionExpiryInterval SessionExpiryInterval
ServerReference ServerReference
ReasonString ReasonString
UserProperty UserProperty
}
func (p *DisconnectPacketProperties) arrayOf() []Property {
return []Property{
&p.SessionExpiryInterval,
@ -601,12 +587,14 @@ return []Property {
&p.UserProperty,
}
}
type AuthPacketProperties struct {
AuthenticationMethod AuthenticationMethod
AuthenticationData AuthenticationData
ReasonString ReasonString
UserProperty UserProperty
}
func (p *AuthPacketProperties) arrayOf() []Property {
return []Property{
&p.AuthenticationMethod,
@ -615,69 +603,83 @@ return []Property {
&p.UserProperty,
}
}
type PubackPacketProperties struct {
ReasonString ReasonString
UserProperty UserProperty
}
func (p *PubackPacketProperties) arrayOf() []Property {
return []Property{
&p.ReasonString,
&p.UserProperty,
}
}
type PubrecPacketProperties struct {
ReasonString ReasonString
UserProperty UserProperty
}
func (p *PubrecPacketProperties) arrayOf() []Property {
return []Property{
&p.ReasonString,
&p.UserProperty,
}
}
type PubrelPacketProperties struct {
ReasonString ReasonString
UserProperty UserProperty
}
func (p *PubrelPacketProperties) arrayOf() []Property {
return []Property{
&p.ReasonString,
&p.UserProperty,
}
}
type PubcompPacketProperties struct {
ReasonString ReasonString
UserProperty UserProperty
}
func (p *PubcompPacketProperties) arrayOf() []Property {
return []Property{
&p.ReasonString,
&p.UserProperty,
}
}
type SubackPacketProperties struct {
ReasonString ReasonString
UserProperty UserProperty
}
func (p *SubackPacketProperties) arrayOf() []Property {
return []Property{
&p.ReasonString,
&p.UserProperty,
}
}
type UnsubackPacketProperties struct {
ReasonString ReasonString
UserProperty UserProperty
}
func (p *UnsubackPacketProperties) arrayOf() []Property {
return []Property{
&p.ReasonString,
&p.UserProperty,
}
}
type UnsubscribePacketProperties struct {
UserProperty UserProperty
}
func (p *UnsubscribePacketProperties) arrayOf() []Property {
return []Property{
&p.UserProperty,

View file

@ -60,7 +60,6 @@ func parseDisconnectPacket(control controlPacket) (DisconnectPacket, error) {
r := bufio.NewReader(control.reader)
// If there is less then a byte in the reader assume the reason code == 0
reason, err := r.ReadByte()
if err == io.EOF {

View file

@ -6,7 +6,6 @@ import (
"io"
)
type PingreqPacket struct{}
func parsePingreq(control controlPacket) (PingreqPacket, error) {

View file

@ -1,11 +1,10 @@
package packets
import (
"io"
"badat.dev/maeqtt/v2/mqtt/properties"
"io"
)
type PubackReasonCode byte
const (

View file

@ -2,6 +2,7 @@ package packets
import (
"bufio"
"bytes"
"errors"
"io"
"strings"
@ -15,7 +16,8 @@ type Topic struct {
}
var multiLevelWildcardNotLast = errors.New("Multi level wildcard isn't the field in a topic")
func parseTopic(topic_name string) (Topic, error) {
func ParseTopic(topic_name string) (Topic, error) {
topic := Topic{}
fields := strings.Split(topic_name, "/")
for i, field := range fields {
@ -45,7 +47,7 @@ func parseTopicFilter(r *bufio.Reader) (TopicFilter, error) {
return filter, err
}
filter.Topic, err = parseTopic(topic_str)
filter.Topic, err = ParseTopic(topic_str)
if err != nil {
return filter, err
}
@ -61,27 +63,32 @@ func parseTopicFilter(r *bufio.Reader) (TopicFilter, error) {
return filter, nil
}
// Both sub and unsubscribe packets are identitcal so we can reuse the parsing logic
type SubscriptionPacket struct {
type SubscribePacket struct {
PacketId uint16
TopicFilters []TopicFilter
Properties properties.SubscribePacketProperties
}
func parseSubscriptionPacket(control controlPacket, props []properties.Property) (SubscriptionPacket, error) {
var err error
func parseSubscribePacket(control controlPacket) (SubscribePacket, error) {
if control.packetType != PacketTypeSubscribe {
panic("Wrong packet type for parseSubscribePacket")
}
packet := SubscribePacket{}
r := bufio.NewReader(control.reader)
packet := SubscriptionPacket{}
if control.flags != 2 {
return packet, errors.New("Malformed subscription packet")
}
var err error
packet.PacketId, err = types.DecodeUint16(r)
if err != nil {
return packet, err
}
err = properties.ParseProperties(r, props)
err = properties.ParseProperties(r, packet.Properties.ArrayOf())
if err != nil {
return packet, err
}
@ -100,33 +107,10 @@ func parseSubscriptionPacket(control controlPacket, props []properties.Property)
return packet, nil
}
}
println("A")
return packet, nil
}
type SubscribePacket struct {
*SubscriptionPacket
props properties.SubscribePacketProperties
}
/// CURRENTLY BROKEN
// TODO FIXME AAAAA
func parseSubscribePacket(control controlPacket) (SubscribePacket, error) {
if control.packetType != PacketTypeSubscribe {
panic("Wrong packet type for parseSubscribePacket")
}
pack := SubscribePacket{}
subscriptionPack, err := parseSubscriptionPacket(control, pack.props.ArrayOf())
if err != nil {
return pack, err
}
pack.SubscriptionPacket = &subscriptionPack
return pack, nil
}
func (p SubscribePacket) Visit(v PacketVisitor) {
v.VisitSubscribe(p)
}
@ -143,9 +127,9 @@ const (
SubackReasonTopicFilterInvalid = 143
SubackReasonPacketIDInUse = 145
SubackReasonQuotaExceeded = 151
SubackReasonSharedSubNotSupported = 151
SubackReasonSubIDUnsupported = 151
SubackReasonWildcardSubUnsupported = 151
SubackReasonSharedSubNotSupported = 158
SubackReasonSubIDUnsupported = 161
SubackReasonWildcardSubUnsupported = 162
)
type SubAckPacket struct {
@ -154,20 +138,35 @@ type SubAckPacket struct {
Reason SubackReasonCode
}
func (p SubAckPacket) Write(w io.Writer) error {
resp := pubRespPacket{
PacketType: PacketTypeSuback,
PacketID: p.PacketID,
Properties: p.Properties.ArrayOf(),
Reason: byte(p.Reason),
buf := bytes.NewBuffer([]byte{})
err := types.WriteUint16(buf, p.PacketID)
if err != nil {
return err
}
return resp.Write(w)
err = properties.WriteProps(buf, p.Properties.ArrayOf())
if err != nil {
return err
}
err = buf.WriteByte(byte(p.Reason))
if err != nil {
return err
}
conPack := controlPacket{
packetType: PacketTypeSuback,
flags: 0,
reader: buf,
}
return conPack.write(w)
}
type UnsubscribePacket struct {
*SubscriptionPacket
props properties.UnsubscribePacketProperties
PacketID uint16
Topics []Topic
Properties properties.UnsubscribePacketProperties
}
func parseUnsubscribePacket(control controlPacket) (UnsubscribePacket, error) {
@ -175,14 +174,41 @@ func parseUnsubscribePacket(control controlPacket) (UnsubscribePacket, error) {
panic("Wrong packet type for parseSubscribePacket")
}
pack := UnsubscribePacket{}
subscriptionPack, err := parseSubscriptionPacket(control, pack.props.ArrayOf())
if err != nil {
return pack, err
packet := UnsubscribePacket{}
r := bufio.NewReader(control.reader)
if control.flags != 2 {
return packet, errors.New("Malformed subscription packet")
}
pack.PacketId = subscriptionPack.PacketId
pack.TopicFilters = subscriptionPack.TopicFilters
return pack, nil
var err error
packet.PacketID, err = types.DecodeUint16(r)
if err != nil {
return packet, err
}
err = properties.ParseProperties(r, packet.Properties.ArrayOf())
if err != nil {
return packet, err
}
for err != io.EOF {
topic_str, err := types.DecodeUTF8String(r)
if err != nil && err != io.EOF {
return packet, err
} else if err == io.EOF {
return packet, nil
}
filter, err := ParseTopic(topic_str)
if err != nil {
return packet, err
}
packet.Topics = append(packet.Topics, filter)
}
return packet, nil
}
func (p UnsubscribePacket) Visit(v PacketVisitor) {
@ -192,7 +218,7 @@ func (p UnsubscribePacket) Visit(v PacketVisitor) {
type UnsubackReasonCode byte
const (
UnsubackReasonSuccess PubackReasonCode = 0
UnsubackReasonSuccess UnsubackReasonCode = 0
UnSubackReasonUnspecified = 128
UnSubackReasonImplSpecificError = 131
UnSubackReasonNotAuthorized = 135
@ -206,13 +232,27 @@ type UnsubAckPacket struct {
Reason UnsubackReasonCode
}
func (p UnsubAckPacket) Write(w io.Writer) error {
resp := pubRespPacket{
PacketType: PacketTypeUnsuback,
PacketID: p.PacketID,
Properties: p.Properties.ArrayOf(),
Reason: byte(p.Reason),
buf := bytes.NewBuffer([]byte{})
err := types.WriteUint16(buf, p.PacketID)
if err != nil {
return err
}
return resp.Write(w)
err = properties.WriteProps(buf, p.Properties.ArrayOf())
if err != nil {
return err
}
err = buf.WriteByte(byte(p.Reason))
if err != nil {
return err
}
conPack := controlPacket{
packetType: PacketTypeUnsuback,
flags: 0,
reader: buf,
}
return conPack.write(w)
}

View file

@ -34,7 +34,7 @@ func (p pubRespPacket) Write(w io.Writer) error {
}
conPack := controlPacket{
packetType: PacketTypePuback,
packetType: p.PacketType,
flags: 0,
reader: buf,
}

View file

@ -57,7 +57,6 @@ func DecodeBinaryData(r *bufio.Reader) ([]byte, error) {
return buffer, err
}
func DecodeUTF8String(r *bufio.Reader) (string, error) {
binary, err := DecodeBinaryData(r)
return string(binary[:]), err

View file

@ -29,8 +29,8 @@ func WriteUint32(w io.Writer, v uint32) error {
return err
}
const uint32Max uint32 = ^uint32(0)
func WriteDataWithVarIntLen(w io.Writer, data []byte) error {
if len(data) > int(uint32Max) {
return errors.New("Tried to write more data than max varint size")
@ -46,6 +46,7 @@ func WriteDataWithVarIntLen(w io.Writer, data []byte) error {
}
const uint16Max uint16 = ^uint16(0)
func WriteBinaryData(w io.Writer, data []byte) error {
if len(data) > int(uint16Max) {
return errors.New("Tried to write more data than max uint16 size")
@ -64,7 +65,6 @@ func WriteUTF8String(w io.Writer, str string) error {
return WriteBinaryData(w, []byte(str))
}
func WriteVariableByteInt(w io.Writer, v uint32) error {
for {
encodedByte := byte(v % 128)

View file

@ -1,4 +1,4 @@
package main
package session
import (
"bufio"

View file

@ -1,13 +1,15 @@
package main
package session
import (
"encoding/base64"
"fmt"
"io"
"log"
"math/rand"
"time"
"badat.dev/maeqtt/v2/mqtt/packets"
"badat.dev/maeqtt/v2/subscription"
)
func init() {
@ -25,8 +27,10 @@ type Session struct {
Connection *Connection
SubscriptionChannel chan packets.PublishPacket
ExpiryInterval time.Duration
ExpiryInterval time.Duration // TODO
expireTimer time.Timer // TODO
freePacketID uint16
}
func NewSession(conn *Connection, p packets.ConnectPacket) Session {
@ -39,8 +43,7 @@ func NewSession(conn *Connection, p packets.ConnectPacket) Session {
func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) {
if s.Connection != nil {
//TODO
panic("Disconnect if already have a connection, unimplemented")
s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver)
}
connAck := packets.ConnackPacket{}
@ -60,7 +63,6 @@ func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) {
connAck.Properties.RetainAvailable.Value = &false
connAck.Properties.SharedSubscriptionAvailable.Value = &false
s.Connection = conn
err := s.Connection.sendPacket(connAck)
if err != nil {
@ -75,7 +77,7 @@ func (s *Session) HandlerLoop() {
case packet := <-s.Connection.PacketChannel:
packet.Visit(s)
case _ = <-s.Connection.ClientDisconnectedChan:
s.OnDisconnect()
s.onDisconnect()
case subMessage := <-s.SubscriptionChannel:
//TODO, log for now
log.Printf("Recieved subscription message, handling UNIMPLEMENTED, message: %v", subMessage)
@ -83,22 +85,29 @@ func (s *Session) HandlerLoop() {
}
}
func (s *Session) Disconnect() error {
panic("Disconnection unimplemented")
func (s *Session) Disconnect(code packets.DisconnectReasonCode) error {
s.Connection.sendPacket(packets.DisconnectPacket{
ReasonCode: code,
})
err := s.Connection.close()
if err != nil {
return err
}
s.OnDisconnect()
s.onDisconnect()
return nil
}
func (s *Session) OnDisconnect() {
func (s *Session) onDisconnect() {
s.Connection = nil
s.resetExpireTimer()
log.Printf("Client disconnected, id: %s", *s.ClientID)
}
func (s *Session) expireSession() {
subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel)
}
// newTime is nullable
func (s *Session) updateExpireTimer(newTime *uint32) {
var expiry = uint32(0)
@ -129,35 +138,67 @@ func genClientID() *string {
}
func (s *Session) VisitConnect(_ packets.ConnectPacket) {
// ERROR CANNOT RECIEVE CONNECT ON AN ALREADY OPEN CONNECTION
s.Disconnect()
// Disconnect, we handle the connect packet in Connect,
// this means that we have an estabilished connection already
log.Println("WARN: Got a connect packet on an already estabilished connection")
s.Disconnect(packets.DisconnectReasonCodeProtocolError)
}
func (s *Session) VisitPublish(p packets.PublishPacket) {
println("UNIMPLEMENTED, Publishing packet, message:", string(p.Payload))
subs, lock := Subscriptions.GetSubscriptions(p.TopicName)
subs, lock := subscription.Subscriptions.GetSubscriptions(p.TopicName)
defer lock.Unlock()
if p.QOSLevel == 0 {
if p.PacketId != nil {
log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID)
return
}
} else if p.QOSLevel == 1 {
var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess
if len(subs) == 0 {
reason = packets.PubackReasonCodeNoMatchingSubscribers
}
ack := packets.PubackPacket{
PacketID: *p.PacketId,
Reason: reason,
}
s.Connection.sendPacket(ack)
} else if p.QOSLevel == 2 {
panic("UNIMPLEMENTED QOS level 2")
}
for _, sub := range subs {
go func(sub Subscription) {sub <- p}(sub)
if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) {
go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub)
}
}
}
func (s *Session) VisitDisconnect(p packets.DisconnectPacket) {
//TODO FINISH
// HANDLE CLIENT DISCONNECTING
s.OnDisconnect()
err := s.Connection.close()
if err != nil && err != io.ErrClosedPipe {
log.Println("Error closing connection", err)
}
s.onDisconnect()
}
func (s *Session) VisitSubscribe(p packets.SubscribePacket) {
//TODO FINISH
for _, filter := range p.TopicFilters {
Subscriptions.Subscribe(filter.Topic, s.SubscriptionChannel)
subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel)
}
s.Connection.sendPacket(packets.SubAckPacket{
PacketID: p.PacketId,
Reason: packets.SubackReasonGrantedQoSTwo,
})
}
func (s *Session) VisitUnsubscribe(_ packets.UnsubscribePacket) {
panic("not implemented") // TODO: Implement
func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) {
for _, topic := range p.Topics {
subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel)
}
s.Connection.sendPacket(packets.UnsubAckPacket{
PacketID: p.PacketID,
Reason: packets.UnsubackReasonSuccess,
})
}
func (s *Session) VisitPing(p packets.PingreqPacket) {
@ -179,3 +220,8 @@ func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) {
func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) {
panic("not implemented") // TODO: Implement
}
func (s *Session) getFreePacketId() uint16 {
s.freePacketID += 1
return s.freePacketID
}

View file

@ -1,86 +0,0 @@
package main
//TODO FULLY IMPLEMENT SUBSCRIPTIONS INSTEAD OF JUST THE TOPIC FILTERS
import (
"strings"
"sync"
"badat.dev/maeqtt/v2/mqtt/packets"
)
var Subscriptions SubscriptionTreeNode = *NewSubscriptionTreeNode()
type Subscription chan packets.PublishPacket
type SubscriptionTreeNode struct {
subscriptions []Subscription
children map[string]*SubscriptionTreeNode
nodeLock sync.RWMutex
}
func NewSubscriptionTreeNode() *SubscriptionTreeNode {
s := SubscriptionTreeNode{}
s.children = make(map[string]*SubscriptionTreeNode)
return &s
}
func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode {
if len(fields) == 0 {
return s
}
field := fields[0]
s.nodeLock.RLock()
_, exists := s.children[field]
// Insert a value into the map if one doesn't exist yet
if !exists {
// Can't upgrade a read lock so we need to unlock and
// check again, this time with a write lock
s.nodeLock.RUnlock()
s.nodeLock.Lock()
_, exists = s.children[field]
if !exists {
s.children[field] = NewSubscriptionTreeNode()
}
s.nodeLock.Unlock()
s.nodeLock.RLock()
}
child, _ := s.children[field]
s.nodeLock.RUnlock()
return child.findNode(fields[1:])
}
func (s *SubscriptionTreeNode) Subscribe(topic packets.Topic, sub Subscription) {
node := s.findNode(topic.Fields)
node.nodeLock.Lock()
node.subscriptions = append(node.subscriptions, sub)
node.nodeLock.Unlock()
}
func (s *SubscriptionTreeNode) GetSubscriptions(topic string) ([]Subscription, sync.Locker) {
fields := strings.Split(topic,"/")
child := s.findNode(fields)
locker := child.nodeLock.RLocker()
locker.Lock()
return child.subscriptions, locker
}
func (s *SubscriptionTreeNode) findMatchingRec(topic []string) ([]Subscription, sync.Locker) {
locker := s.nodeLock.RLocker()
s.nodeLock.RLock()
if len(topic) == 0 {
return s.subscriptions,locker
}
defer s.nodeLock.RUnlock()
child, exists := s.children[topic[0]]
if exists {
return child.findMatchingRec(topic[1:])
} else {
return []Subscription{},locker
}
}

View file

@ -0,0 +1,106 @@
package subscription
//TODO WILDCARD SUBSCRIPTIONS
import (
"strings"
"sync"
"badat.dev/maeqtt/v2/mqtt/packets"
)
var Subscriptions SubscriptionTreeNode = *newSubscriptionTreeNode()
type SubscriptionChannel chan packets.PublishPacket
type Subscription struct {
SubscriptionChannel
packets.TopicFilter
}
type SubscriptionTreeNode struct {
subscriptions []Subscription
children map[string]*SubscriptionTreeNode
nodeLock sync.RWMutex
}
func newSubscriptionTreeNode() *SubscriptionTreeNode {
s := SubscriptionTreeNode{}
s.children = make(map[string]*SubscriptionTreeNode)
return &s
}
func (s *SubscriptionTreeNode) findNode(fields []string) *SubscriptionTreeNode {
if len(fields) == 0 {
return s
}
field := fields[0]
s.nodeLock.RLock()
_, exists := s.children[field]
// Insert a value into the map if one doesn't exist yet
if !exists {
// Can't upgrade a read lock so we need to unlock and
// check again, this time with a write lock
s.nodeLock.RUnlock()
s.nodeLock.Lock()
_, exists = s.children[field]
if !exists {
s.children[field] = newSubscriptionTreeNode()
}
s.nodeLock.Unlock()
s.nodeLock.RLock()
}
child, _ := s.children[field]
s.nodeLock.RUnlock()
return child.findNode(fields[1:])
}
func (s *SubscriptionTreeNode) removeSubscription(subChan SubscriptionChannel) {
for i, sub := range s.subscriptions {
if sub.SubscriptionChannel == subChan {
lst := len(s.subscriptions) - 1
s.subscriptions[i] = s.subscriptions[lst]
s.subscriptions = s.subscriptions[:lst]
}
}
}
func (s *SubscriptionTreeNode) Subscribe(topicFilter packets.TopicFilter, subChan SubscriptionChannel) {
sub := Subscription{subChan, topicFilter}
node := s.findNode(topicFilter.Topic.Fields)
node.nodeLock.Lock()
node.subscriptions = append(node.subscriptions, sub)
node.nodeLock.Unlock()
}
func (s *SubscriptionTreeNode) Unsubscribe(topic packets.Topic, subChan SubscriptionChannel) {
node := s.findNode(topic.Fields)
node.nodeLock.Lock()
node.removeSubscription(subChan)
node.nodeLock.Unlock()
}
func (s *SubscriptionTreeNode) RemoveSubsForChannel(subChan SubscriptionChannel) {
for _, node := range s.children {
node.nodeLock.Lock()
node.removeSubscription(subChan)
node.nodeLock.Unlock()
node.RemoveSubsForChannel(subChan)
}
}
func (s *SubscriptionTreeNode) GetSubscriptions(topicName string) ([]Subscription, sync.Locker) {
fields := strings.Split(topicName, "/")
child := s.findNode(fields)
locker := child.nodeLock.RLocker()
locker.Lock()
return child.subscriptions, locker
}

View file

@ -0,0 +1,27 @@
package subscription
import (
"testing"
"badat.dev/maeqtt/v2/mqtt/packets"
)
func TestSubscribe(t *testing.T) {
tree := newSubscriptionTreeNode()
topic, _ := packets.ParseTopic("a/b/c")
channel := make(SubscriptionChannel)
topicFilter := packets.TopicFilter{
Topic: topic,
MaxQoS: 1,
}
tree.Subscribe(topicFilter, channel)
subs, lock := tree.GetSubscriptions("a/b/c")
defer lock.Unlock()
if len(subs) != 1 {
t.Errorf("Error storing subscriptions, expected to len(subs) to be 1, got: %v \n", len(subs))
}
if subs[0].MaxQoS != topicFilter.MaxQoS || subs[0].SubscriptionChannel != channel {
t.Error("Error with data stored in a subscription")
}
}