From 09ac734b2db14cbf8bb7a90d9cd277acb5e17077 Mon Sep 17 00:00:00 2001 From: bad Date: Tue, 19 Oct 2021 12:00:53 +0200 Subject: [PATCH] Refactor connection and add support for resuming sessions --- main.go | 6 +++- session/Session.go | 63 ++++++++++++++++++++------------------- session/connection.go | 33 +++++++------------- session/expiry.go | 33 ++++++++++++++------ session/packetVisitors.go | 21 +++---------- session/utils.go | 2 +- 6 files changed, 76 insertions(+), 82 deletions(-) diff --git a/main.go b/main.go index 125f09e..a524c51 100644 --- a/main.go +++ b/main.go @@ -54,8 +54,10 @@ func handleConnection(con net.Conn, sessions map[string]*session.Session, rmSess var sess *session.Session if conReq.ConnectPakcet.ClientId != nil { - sess, exists := sessions[*conReq.ConnectPakcet.ClientId] + exists := false + sess, exists = sessions[*conReq.ConnectPakcet.ClientId] if exists { + log.Printf("Resuming session %v", *sess.ClientID) sess.ConnecionChannel <- conReq } } @@ -63,6 +65,8 @@ func handleConnection(con net.Conn, sessions map[string]*session.Session, rmSess if sess == nil { newSess := session.NewSession(conReq, rmSessChan) sess = &newSess + sessions[*sess.ClientID] = sess + log.Printf("New session %v", *sess.ClientID) go func() { defer handlePanic(con) sess.HandlerLoop() diff --git a/session/Session.go b/session/Session.go index 40404a6..4c06eac 100644 --- a/session/Session.go +++ b/session/Session.go @@ -15,7 +15,6 @@ type Session struct { SubscriptionChannel chan packets.PublishPacket ConnecionChannel chan ConnectionRequest - freePacketID uint16 Expiry @@ -23,7 +22,9 @@ type Session struct { func NewSession(req ConnectionRequest, rmSessChan RemoveSessionChannel) Session { sess := Session{} - sess.SubscriptionChannel = make(chan packets.PublishPacket,) + sess.SubscriptionChannel = make(chan packets.PublishPacket) + sess.ConnecionChannel = make(chan ConnectionRequest) + sess.Expiry = NewExpiry(rmSessChan) sess.Connect(req) @@ -31,21 +32,21 @@ func NewSession(req ConnectionRequest, rmSessChan RemoveSessionChannel) Session } func (s *Session) Connect(req ConnectionRequest) { + s.stopExpireTimer() + if s.Connection != nil { s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver) } connAck := packets.ConnackPacket{} - - s.SetExpireTimer(req.ConnectPakcet.Properties.SessionExpiryInterval.Value) - s.expireTimer.Stop() + s.SetExpireTimerDuration(req.ConnectPakcet.Properties.SessionExpiryInterval.Value) if req.ConnectPakcet.ClientId == nil { if s.ClientID == nil { s.ClientID = genClientID() } connAck.Properties.AssignedClientIdentifier.Value = s.ClientID - } else if s.ClientID != nil && s.ClientID != req.ConnectPakcet.ClientId { + } else if s.ClientID != nil && *s.ClientID != *req.ConnectPakcet.ClientId { panic(fmt.Errorf("Session %s connect called with a connect packet with an ID: %s", *s.ClientID, *req.ConnectPakcet.ClientId)) } else { s.ClientID = req.ConnectPakcet.ClientId @@ -59,46 +60,46 @@ func (s *Session) Connect(req ConnectionRequest) { connAck.Properties.SharedSubscriptionAvailable.Value = &false s.Connection = req.Connection - err := s.Connection.sendPacket(connAck) - if err != nil { - panic("TODO, handle this") - } + _ = s.Connection.sendPacket(connAck) + go s.Connection.PacketReadLoop() } // Starts a loop the receives and responds to packets func (s *Session) HandlerLoop() { - go s.Connection.PacketReadLoop() - for s.Connection != nil { + for { + var packetChan chan packets.ClientPacket + if s.Connection != nil { + packetChan = s.Connection.PacketChannel + } select { - case packet := <-s.Connection.PacketChannel: - packet.Visit(s) - case <-s.Connection.ClientDisconnectedChan: - s.onDisconnect() + case packet, more := <-packetChan: + if more { + packet.Visit(s) + } else { + s.onDisconnect() + } case c := <-s.ConnecionChannel: s.Connect(c) + case <-s.expiryChannel(): + s.expireSession() + break case subMessage := <-s.SubscriptionChannel: - // TODO implement other qos levels - subMessage.QOSLevel = 0 - subMessage.Dup = false - err := s.Connection.sendPacket(subMessage) - if err != nil { - panic("TOOO handle this") + if s.Connection != nil { + // TODO implement other qos levels + subMessage.QOSLevel = 0 + subMessage.Dup = false + err := s.Connection.sendPacket(subMessage) + if err != nil { + panic("TOOO handle this") + } } } } - select { - case c := <-s.ConnecionChannel: - s.Connect(c) - // Tail recursion baybeeee - s.HandlerLoop() - case <- s.expireTimer.C: - s.expireSession() - } } func (s *Session) onDisconnect() { s.Connection = nil - s.resetExpireTimer() + s.startExpireTimer() log.Printf("Client disconnected, id: %s", *s.ClientID) } diff --git a/session/connection.go b/session/connection.go index ce5f15b..55645e7 100644 --- a/session/connection.go +++ b/session/connection.go @@ -5,7 +5,6 @@ import ( "errors" "io" "log" - "time" "badat.dev/maeqtt/v2/mqtt/packets" ) @@ -18,47 +17,37 @@ type Connection struct { WantsProblemInf bool Will packets.Will - KeepAliveInterval time.Duration - keepAliveTicker time.Ticker + // TODO + //KeepAliveInterval time.Duration + //keepAliveTimer time.Timer + // Gets closed whenever the client disconnects PacketChannel chan packets.ClientPacket - ClientDisconnectedChan chan bool rw io.ReadWriteCloser } -func (c *Connection) resetKeepAlive() { - if c.KeepAliveInterval != 0 { - panic("TODO") - // TODO IMPLEMENT THIS - //s.keepAliveTicker.Reset(s.KeepAliveInterval) - } -} - func (c *Connection) readPacket() (*packets.ClientPacket, error) { return packets.ReadPacket(bufio.NewReader(c.rw)) } func (c *Connection) sendPacket(p packets.ServerPacket) error { - c.resetKeepAlive() return p.Write(c.rw) } -func (c *Connection) close() error { - close(c.PacketChannel) - return c.rw.Close() +func (c *Connection) close() { + _ = c.rw.Close() } func (c *Connection) PacketReadLoop() { for { - pack, err := c.readPacket() + pack, err := c.readPacket() if err != nil { - c.ClientDisconnectedChan <- true - c.close() - } else { - c.PacketChannel <- *pack + break } + c.PacketChannel <- *pack } + close(c.PacketChannel) } var FirstPackNotConnect error = errors.New("Failed to connect, first packet is not connect") @@ -103,8 +92,6 @@ func NewConnection(rw io.ReadWriteCloser) (ConnectionRequest, error) { conn.WantsRespInf = false } - conn.KeepAliveInterval = time.Duration(conPack.KeepAliveInterval) * time.Second - conn.PacketChannel = make(chan packets.ClientPacket, 1) return connReq, err diff --git a/session/expiry.go b/session/expiry.go index a1e1872..51e61e0 100644 --- a/session/expiry.go +++ b/session/expiry.go @@ -1,23 +1,22 @@ package session import ( + "log" "time" "badat.dev/maeqtt/v2/subscription" ) type Expiry struct { - ExpiryInterval time.Duration - expireTimer time.Timer + ExpiryInterval time.Duration + expireTimer *time.Timer RemoveSessionChannel } func NewExpiry(channel RemoveSessionChannel) Expiry { - expiry := Expiry {} + expiry := Expiry{} expiry.RemoveSessionChannel = channel - expiry.expireTimer = *time.NewTimer(time.Hour*9999) - expiry.expireTimer.Stop() return expiry } @@ -25,21 +24,37 @@ func NewExpiry(channel RemoveSessionChannel) Expiry { type RemoveSessionChannel chan string func (s *Session) expireSession() { + log.Printf("Session: %v expired", *s.ClientID) subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel) s.RemoveSessionChannel <- *s.ClientID } // newTime is nullable -func (s *Session) SetExpireTimer(newTime *uint32) { +func (e *Expiry) SetExpireTimerDuration(newTime *uint32) { expiry := uint32(0) if newTime != nil { expiry = *newTime } else { expiry = uint32(0) } - s.ExpiryInterval = time.Duration(expiry) * time.Second + e.ExpiryInterval = time.Second * time.Duration(expiry) } -func (s *Session) resetExpireTimer() { - s.expireTimer.Reset(s.ExpiryInterval) +func (e *Expiry) startExpireTimer() { + e.stopExpireTimer() + e.expireTimer = time.NewTimer(e.ExpiryInterval) +} + +func (e *Expiry) stopExpireTimer() { + if e.expireTimer != nil { + e.expireTimer.Stop() + } +} + +func (e *Expiry) expiryChannel() <- chan time.Time { + if e.expireTimer != nil { + return e.expireTimer.C + } else { + return nil + } } diff --git a/session/packetVisitors.go b/session/packetVisitors.go index f27b951..2bf2a01 100644 --- a/session/packetVisitors.go +++ b/session/packetVisitors.go @@ -1,7 +1,6 @@ package session import ( - "io" "log" "badat.dev/maeqtt/v2/mqtt/packets" @@ -31,10 +30,7 @@ func (s *Session) VisitPublish(p packets.PublishPacket) { PacketID: *p.PacketId, Reason: reason, } - err := s.Connection.sendPacket(ack) - if err != nil { - panic("TODO") - } + _ = s.Connection.sendPacket(ack) } else if p.QOSLevel == 2 { panic("UNIMPLEMENTED QOS level 2") } @@ -48,10 +44,7 @@ func (s *Session) VisitPublish(p packets.PublishPacket) { } func (s *Session) VisitDisconnect(p packets.DisconnectPacket) { - err := s.Connection.close() - if err != nil && err != io.ErrClosedPipe { - log.Println("Error closing connection", err) - } + s.Connection.close() s.onDisconnect() } @@ -59,26 +52,20 @@ func (s *Session) VisitSubscribe(p packets.SubscribePacket) { for _, filter := range p.TopicFilters { subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel) } - err := s.Connection.sendPacket(packets.SubAckPacket{ + _ = s.Connection.sendPacket(packets.SubAckPacket{ PacketID: p.PacketId, Reason: packets.SubackReasonGrantedQoSTwo, }) - if err != nil { - panic("TODO") - } } func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) { for _, topic := range p.Topics { subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel) } - err := s.Connection.sendPacket(packets.UnsubAckPacket{ + _ = s.Connection.sendPacket(packets.UnsubAckPacket{ PacketID: p.PacketID, Reason: packets.UnsubackReasonSuccess, }) - if err != nil { - panic("TODO") - } } func (s *Session) VisitPing(p packets.PingreqPacket) { diff --git a/session/utils.go b/session/utils.go index 0b3f7f0..52e5a6e 100644 --- a/session/utils.go +++ b/session/utils.go @@ -29,7 +29,7 @@ func (s *Session) Disconnect(code packets.DisconnectReasonCode) { _ = s.Connection.sendPacket(packets.DisconnectPacket{ ReasonCode: code, }) - _ = s.Connection.close() + s.Connection.close() s.onDisconnect() }