From baf746a2540560da35ea261f3b0149e457ebe0d8 Mon Sep 17 00:00:00 2001 From: bad Date: Sat, 16 Oct 2021 23:38:23 +0200 Subject: [PATCH] Refactoring, session expiry --- main.go | 33 +++-- session/Session.go | 98 ++++++++++++++ session/connectionRequest.go | 8 ++ session/expiry.go | 49 +++++++ session/packetVisitors.go | 93 +++++++++++++ session/session.go | 245 ----------------------------------- session/utils.go | 44 +++++++ subscription/subscription.go | 2 - 8 files changed, 316 insertions(+), 256 deletions(-) create mode 100644 session/Session.go create mode 100644 session/connectionRequest.go create mode 100644 session/expiry.go create mode 100644 session/packetVisitors.go delete mode 100644 session/session.go create mode 100644 session/utils.go diff --git a/main.go b/main.go index ab0eadc..1a89c79 100644 --- a/main.go +++ b/main.go @@ -14,27 +14,42 @@ func main() { listener, err := net.Listen("tcp", listen_addr) if err != nil { - log.Fatal(err) + log.Fatalf("Coulde't start a listener on tcp %v. Error: %e", listen_addr, err) } var sessions map[string]*session.Session = make(map[string]*session.Session) + removeSessChan := make(session.RemoveSessionChannel) + connChan := make(chan net.Conn) + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + log.Println("Failed accepting connection ", err) + } else { + connChan <- conn + } + } + + }() for { - conn, err := listener.Accept() - if err != nil { - log.Println("Failed accepting connection ", err) + select { + case con := <- connChan: + handleConnection(con, sessions, removeSessChan) + case sesId := <- removeSessChan: + delete(sessions, sesId) } - handleConnection(conn, sessions) } } -func handleConnection(con net.Conn, sessions map[string]*session.Session) { +func handleConnection(con net.Conn, sessions map[string]*session.Session, rmSessChan session.RemoveSessionChannel) { defer handlePanic(con) conReq, err := session.NewConnection(con) if err != nil { - // TODO - panic(err) + log.Println("Failed to create connection ", err) + return } var sess *session.Session @@ -46,7 +61,7 @@ func handleConnection(con net.Conn, sessions map[string]*session.Session) { } if sess == nil { - newSess := session.NewSession(conReq) + newSess := session.NewSession(conReq, rmSessChan) sess = &newSess go func() { defer handlePanic(con) diff --git a/session/Session.go b/session/Session.go new file mode 100644 index 0000000..9f84d76 --- /dev/null +++ b/session/Session.go @@ -0,0 +1,98 @@ +package session + +import ( + "fmt" + "log" + + "badat.dev/maeqtt/v2/mqtt/packets" +) + +type Session struct { + ClientID *string + + // Nullable + Connection *Connection + SubscriptionChannel chan packets.PublishPacket + ConnecionChannel chan ConnectionRequest + + + freePacketID uint16 + + Expiry +} + +func NewSession(req ConnectionRequest, rmSessChan RemoveSessionChannel) Session { + sess := Session{} + sess.SubscriptionChannel = make(chan packets.PublishPacket,) + sess.Expiry = NewExpiry(rmSessChan) + + sess.Connect(req) + return sess +} + +func (s *Session) Connect(req ConnectionRequest) { + if s.Connection != nil { + s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver) + } + connAck := packets.ConnackPacket{} + + + s.SetExpireTimer(req.ConnectPakcet.Properties.SessionExpiryInterval.Value) + s.expireTimer.Stop() + + if req.ConnectPakcet.ClientId == nil { + if s.ClientID == nil { + s.ClientID = genClientID() + } + connAck.Properties.AssignedClientIdentifier.Value = s.ClientID + } else if s.ClientID != nil && s.ClientID != req.ConnectPakcet.ClientId { + panic(fmt.Errorf("Session %s connect called with a connect packet with an ID: %s", *s.ClientID, *req.ConnectPakcet.ClientId)) + } else { + s.ClientID = req.ConnectPakcet.ClientId + } + + true := byte(1) + false := byte(0) + connAck.Properties.WildcardSubscriptionAvailable.Value = &true + + connAck.Properties.RetainAvailable.Value = &false + connAck.Properties.SharedSubscriptionAvailable.Value = &false + + s.Connection = req.Connection + s.Connection.sendPacket(connAck) +} + +// Starts a loop the recieves and responds to packets +func (s *Session) HandlerLoop() { + go s.Connection.PacketReadLoop() + for s.Connection != nil { + select { + case packet := <-s.Connection.PacketChannel: + packet.Visit(s) + case _ = <-s.Connection.ClientDisconnectedChan: + s.onDisconnect() + case c := <-s.ConnecionChannel: + s.Connect(c) + case subMessage := <-s.SubscriptionChannel: + // TODO implement other qos levels + subMessage.QOSLevel = 0 + subMessage.Dup = false + s.Connection.sendPacket(subMessage) + } + } + + select { + case c := <-s.ConnecionChannel: + s.Connect(c) + // Tail recursion baybeeee + s.HandlerLoop() + case _ = <- s.expireTimer.C: + s.expireSession() + } +} + +func (s *Session) onDisconnect() { + s.Connection = nil + s.resetExpireTimer() + log.Printf("Client disconnected, id: %s", *s.ClientID) +} diff --git a/session/connectionRequest.go b/session/connectionRequest.go new file mode 100644 index 0000000..c9e8d68 --- /dev/null +++ b/session/connectionRequest.go @@ -0,0 +1,8 @@ +package session + +import "badat.dev/maeqtt/v2/mqtt/packets" + +type ConnectionRequest struct { + Connection *Connection + ConnectPakcet packets.ConnectPacket +} diff --git a/session/expiry.go b/session/expiry.go new file mode 100644 index 0000000..23ded86 --- /dev/null +++ b/session/expiry.go @@ -0,0 +1,49 @@ +package session + +import ( + "time" + + "badat.dev/maeqtt/v2/subscription" +) + +type Expiry struct { + ExpiryInterval time.Duration + expireTimer time.Timer + RemoveSessionChannel +} + +func NewExpiry(channel RemoveSessionChannel) Expiry { + expiry := Expiry {} + + expiry.RemoveSessionChannel = channel + expiry.expireTimer = *time.NewTimer(time.Hour*9999) + expiry.expireTimer.Stop() + return expiry +} + +// Channel for removing a session from the global state +type RemoveSessionChannel chan string + +func (s *Session) expireSession() { + subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel) + s.RemoveSessionChannel <- *s.ClientID +} + +// newTime is nullable +func (s *Session) SetExpireTimer(newTime *uint32) { + var expiry = uint32(0) + if newTime != nil { + expiry = *newTime + } else { + expiry = uint32(0) + } + s.ExpiryInterval = time.Duration(expiry) * time.Second +} + +func (s *Session) resetExpireTimer() { + if s.ExpiryInterval == 0 { + s.expireTimer.Stop() + } else { + s.expireTimer.Reset(s.ExpiryInterval) + } +} diff --git a/session/packetVisitors.go b/session/packetVisitors.go new file mode 100644 index 0000000..620a89b --- /dev/null +++ b/session/packetVisitors.go @@ -0,0 +1,93 @@ +package session + +import ( + "io" + "log" + + "badat.dev/maeqtt/v2/mqtt/packets" + "badat.dev/maeqtt/v2/subscription" +) + +func (s *Session) VisitConnect(_ packets.ConnectPacket) { + // Disconnect, we handle the connect packet in Connect, + // this means that we have an estabilished connection already + log.Println("WARN: Got a connect packet on an already estabilished connection") + s.Disconnect(packets.DisconnectReasonCodeProtocolError) +} + +func (s *Session) VisitPublish(p packets.PublishPacket) { + subNodes := subscription.Subscriptions.GetSubscriptions(p.TopicName) + for _, subNode := range subNodes { + subNode.NodeLock.RLock() + defer subNode.NodeLock.RUnlock() + if p.QOSLevel == 0 { + if p.PacketId != nil { + log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID) + return + } + } else if p.QOSLevel == 1 { + var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess + ack := packets.PubackPacket{ + PacketID: *p.PacketId, + Reason: reason, + } + s.Connection.sendPacket(ack) + } else if p.QOSLevel == 2 { + panic("UNIMPLEMENTED QOS level 2") + } + + for _, sub := range subNode.Subscriptions { + if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) { + go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub) + } + } + } +} + +func (s *Session) VisitDisconnect(p packets.DisconnectPacket) { + err := s.Connection.close() + if err != nil && err != io.ErrClosedPipe { + log.Println("Error closing connection", err) + } + s.onDisconnect() +} + +func (s *Session) VisitSubscribe(p packets.SubscribePacket) { + for _, filter := range p.TopicFilters { + subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel) + } + s.Connection.sendPacket(packets.SubAckPacket{ + PacketID: p.PacketId, + Reason: packets.SubackReasonGrantedQoSTwo, + }) +} + +func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) { + for _, topic := range p.Topics { + subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel) + } + s.Connection.sendPacket(packets.UnsubAckPacket{ + PacketID: p.PacketID, + Reason: packets.UnsubackReasonSuccess, + }) +} + +func (s *Session) VisitPing(p packets.PingreqPacket) { + s.Connection.sendPacket(packets.PingrespPacket{}) +} + +func (s *Session) VisitPubackPacket(_ packets.PubackPacket) { + panic("not implemented") // TODO: Implement +} + +func (s *Session) VisitPubrecPacket(_ packets.PubrecPacket) { + panic("not implemented") // TODO: Implement +} + +func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) { + panic("not implemented") // TODO: Implement +} + +func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) { + panic("not implemented") // TODO: Implement +} diff --git a/session/session.go b/session/session.go deleted file mode 100644 index fe826a4..0000000 --- a/session/session.go +++ /dev/null @@ -1,245 +0,0 @@ -package session - -import ( - "encoding/base64" - "fmt" - "io" - "log" - "math/rand" - "time" - - "badat.dev/maeqtt/v2/mqtt/packets" - "badat.dev/maeqtt/v2/subscription" -) - -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func Auth(username string, password []byte) bool { - return true -} - -type Session struct { - ClientID *string - - // Nullable - Connection *Connection - SubscriptionChannel chan packets.PublishPacket - ConnecionChannel chan ConnectionRequest - - ExpiryInterval time.Duration // TODO - expireTimer time.Timer // TODO - - freePacketID uint16 -} - -type ConnectionRequest struct { - Connection *Connection - ConnectPakcet packets.ConnectPacket -} - -func NewSession(req ConnectionRequest) Session { - sess := Session{} - sess.SubscriptionChannel = make(chan packets.PublishPacket) - - sess.Connect(req) - return sess -} - -func (s *Session) Connect(req ConnectionRequest) { - if s.Connection != nil { - s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver) - } - connAck := packets.ConnackPacket{} - - s.updateExpireTimer(req.ConnectPakcet.Properties.SessionExpiryInterval.Value) - - if req.ConnectPakcet.ClientId == nil { - if s.ClientID == nil { - s.ClientID = genClientID() - } - connAck.Properties.AssignedClientIdentifier.Value = s.ClientID - } else if s.ClientID != nil && s.ClientID != req.ConnectPakcet.ClientId { - panic(fmt.Errorf("Session %s connect called with a connect packet with an ID: %s", *s.ClientID, *req.ConnectPakcet.ClientId)) - } else { - s.ClientID = req.ConnectPakcet.ClientId - } - - true := byte(1) - false := byte(0) - connAck.Properties.WildcardSubscriptionAvailable.Value = &true - - connAck.Properties.RetainAvailable.Value = &false - connAck.Properties.SharedSubscriptionAvailable.Value = &false - - s.Connection = req.Connection - err := s.Connection.sendPacket(connAck) - if err != nil { - // TODO - panic(err) - } -} - -// Starts a loop the recieves and responds to packets -func (s *Session) HandlerLoop() { - go s.Connection.PacketReadLoop() - for s.Connection != nil { - select { - case packet := <-s.Connection.PacketChannel: - packet.Visit(s) - case _ = <-s.Connection.ClientDisconnectedChan: - s.onDisconnect() - case c := <-s.ConnecionChannel: - s.Connect(c) - case subMessage := <-s.SubscriptionChannel: - subMessage.QOSLevel = 0 - subMessage.Dup = false - s.Connection.sendPacket(subMessage) - } - } - c := <-s.ConnecionChannel - s.Connect(c) - s.HandlerLoop() -} - -func (s *Session) Disconnect(code packets.DisconnectReasonCode) error { - s.Connection.sendPacket(packets.DisconnectPacket{ - ReasonCode: code, - }) - - err := s.Connection.close() - if err != nil { - return err - } - s.onDisconnect() - return nil -} - -func (s *Session) onDisconnect() { - s.Connection = nil - s.resetExpireTimer() - log.Printf("Client disconnected, id: %s", *s.ClientID) -} - -func (s *Session) expireSession() { - subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel) -} - -// newTime is nullable -func (s *Session) updateExpireTimer(newTime *uint32) { - var expiry = uint32(0) - if newTime != nil { - expiry = *newTime - } else { - expiry = uint32(0) - } - s.ExpiryInterval = time.Duration(expiry) * time.Second - - if s.Connection == nil { - s.resetExpireTimer() - } -} -func (s *Session) resetExpireTimer() { - //s.expireTimer.Reset(s.ExpiryInterval) -} - -func genClientID() *string { - buf := make([]byte, 32) - _, err := rand.Read(buf) - if err != nil { - // I don't think this can actually happen but just in case panic - panic(fmt.Errorf("Failed to generate a client id, %e", err)) - } - id := "Client_rand_" + base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(buf) - return &id -} - -func (s *Session) VisitConnect(_ packets.ConnectPacket) { - // Disconnect, we handle the connect packet in Connect, - // this means that we have an estabilished connection already - log.Println("WARN: Got a connect packet on an already estabilished connection") - s.Disconnect(packets.DisconnectReasonCodeProtocolError) -} - -func (s *Session) VisitPublish(p packets.PublishPacket) { - subNodes := subscription.Subscriptions.GetSubscriptions(p.TopicName) - for _, subNode := range subNodes { - subNode.NodeLock.RLock() - defer subNode.NodeLock.RUnlock() - if p.QOSLevel == 0 { - if p.PacketId != nil { - log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID) - return - } - } else if p.QOSLevel == 1 { - var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess - ack := packets.PubackPacket{ - PacketID: *p.PacketId, - Reason: reason, - } - s.Connection.sendPacket(ack) - } else if p.QOSLevel == 2 { - panic("UNIMPLEMENTED QOS level 2") - } - - for _, sub := range subNode.Subscriptions { - if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) { - go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub) - } - } - } -} - -func (s *Session) VisitDisconnect(p packets.DisconnectPacket) { - err := s.Connection.close() - if err != nil && err != io.ErrClosedPipe { - log.Println("Error closing connection", err) - } - s.onDisconnect() -} - -func (s *Session) VisitSubscribe(p packets.SubscribePacket) { - for _, filter := range p.TopicFilters { - subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel) - } - s.Connection.sendPacket(packets.SubAckPacket{ - PacketID: p.PacketId, - Reason: packets.SubackReasonGrantedQoSTwo, - }) -} - -func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) { - for _, topic := range p.Topics { - subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel) - } - s.Connection.sendPacket(packets.UnsubAckPacket{ - PacketID: p.PacketID, - Reason: packets.UnsubackReasonSuccess, - }) -} - -func (s *Session) VisitPing(p packets.PingreqPacket) { - s.Connection.sendPacket(packets.PingrespPacket{}) -} - -func (s *Session) VisitPubackPacket(_ packets.PubackPacket) { - panic("not implemented") // TODO: Implement -} - -func (s *Session) VisitPubrecPacket(_ packets.PubrecPacket) { - panic("not implemented") // TODO: Implement -} - -func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) { - panic("not implemented") // TODO: Implement -} - -func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) { - panic("not implemented") // TODO: Implement -} - -func (s *Session) getFreePacketId() uint16 { - s.freePacketID += 1 - return s.freePacketID -} diff --git a/session/utils.go b/session/utils.go new file mode 100644 index 0000000..ea707d6 --- /dev/null +++ b/session/utils.go @@ -0,0 +1,44 @@ +package session + +import ( + "encoding/base64" + "fmt" + "math/rand" + "time" + + "badat.dev/maeqtt/v2/mqtt/packets" +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func genClientID() *string { + buf := make([]byte, 32) + _, err := rand.Read(buf) + if err != nil { + // I don't think this can actually happen but just in case panic + panic(fmt.Errorf("Failed to generate a client id, %e", err)) + } + id := "Client_rand_" + base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(buf) + return &id +} + +func (s *Session) Disconnect(code packets.DisconnectReasonCode) error { + s.Connection.sendPacket(packets.DisconnectPacket{ + ReasonCode: code, + }) + + err := s.Connection.close() + if err != nil { + return err + } + s.onDisconnect() + return nil +} + + +func (s *Session) getFreePacketId() uint16 { + s.freePacketID += 1 + return s.freePacketID +} diff --git a/subscription/subscription.go b/subscription/subscription.go index aaeb205..8b059ba 100644 --- a/subscription/subscription.go +++ b/subscription/subscription.go @@ -1,7 +1,5 @@ package subscription -//TODO WILDCARD SUBSCRIPTIONS - import ( "strings" "sync"