Refactor connection and add support for resuming sessions

This commit is contained in:
bad 2021-10-19 12:00:53 +02:00
parent 036552bac1
commit 09ac734b2d
6 changed files with 76 additions and 82 deletions

View File

@ -54,8 +54,10 @@ func handleConnection(con net.Conn, sessions map[string]*session.Session, rmSess
var sess *session.Session var sess *session.Session
if conReq.ConnectPakcet.ClientId != nil { if conReq.ConnectPakcet.ClientId != nil {
sess, exists := sessions[*conReq.ConnectPakcet.ClientId] exists := false
sess, exists = sessions[*conReq.ConnectPakcet.ClientId]
if exists { if exists {
log.Printf("Resuming session %v", *sess.ClientID)
sess.ConnecionChannel <- conReq sess.ConnecionChannel <- conReq
} }
} }
@ -63,6 +65,8 @@ func handleConnection(con net.Conn, sessions map[string]*session.Session, rmSess
if sess == nil { if sess == nil {
newSess := session.NewSession(conReq, rmSessChan) newSess := session.NewSession(conReq, rmSessChan)
sess = &newSess sess = &newSess
sessions[*sess.ClientID] = sess
log.Printf("New session %v", *sess.ClientID)
go func() { go func() {
defer handlePanic(con) defer handlePanic(con)
sess.HandlerLoop() sess.HandlerLoop()

View File

@ -15,7 +15,6 @@ type Session struct {
SubscriptionChannel chan packets.PublishPacket SubscriptionChannel chan packets.PublishPacket
ConnecionChannel chan ConnectionRequest ConnecionChannel chan ConnectionRequest
freePacketID uint16 freePacketID uint16
Expiry Expiry
@ -23,7 +22,9 @@ type Session struct {
func NewSession(req ConnectionRequest, rmSessChan RemoveSessionChannel) Session { func NewSession(req ConnectionRequest, rmSessChan RemoveSessionChannel) Session {
sess := Session{} sess := Session{}
sess.SubscriptionChannel = make(chan packets.PublishPacket,) sess.SubscriptionChannel = make(chan packets.PublishPacket)
sess.ConnecionChannel = make(chan ConnectionRequest)
sess.Expiry = NewExpiry(rmSessChan) sess.Expiry = NewExpiry(rmSessChan)
sess.Connect(req) sess.Connect(req)
@ -31,21 +32,21 @@ func NewSession(req ConnectionRequest, rmSessChan RemoveSessionChannel) Session
} }
func (s *Session) Connect(req ConnectionRequest) { func (s *Session) Connect(req ConnectionRequest) {
s.stopExpireTimer()
if s.Connection != nil { if s.Connection != nil {
s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver) s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver)
} }
connAck := packets.ConnackPacket{} connAck := packets.ConnackPacket{}
s.SetExpireTimerDuration(req.ConnectPakcet.Properties.SessionExpiryInterval.Value)
s.SetExpireTimer(req.ConnectPakcet.Properties.SessionExpiryInterval.Value)
s.expireTimer.Stop()
if req.ConnectPakcet.ClientId == nil { if req.ConnectPakcet.ClientId == nil {
if s.ClientID == nil { if s.ClientID == nil {
s.ClientID = genClientID() s.ClientID = genClientID()
} }
connAck.Properties.AssignedClientIdentifier.Value = s.ClientID connAck.Properties.AssignedClientIdentifier.Value = s.ClientID
} else if s.ClientID != nil && s.ClientID != req.ConnectPakcet.ClientId { } else if s.ClientID != nil && *s.ClientID != *req.ConnectPakcet.ClientId {
panic(fmt.Errorf("Session %s connect called with a connect packet with an ID: %s", *s.ClientID, *req.ConnectPakcet.ClientId)) panic(fmt.Errorf("Session %s connect called with a connect packet with an ID: %s", *s.ClientID, *req.ConnectPakcet.ClientId))
} else { } else {
s.ClientID = req.ConnectPakcet.ClientId s.ClientID = req.ConnectPakcet.ClientId
@ -59,46 +60,46 @@ func (s *Session) Connect(req ConnectionRequest) {
connAck.Properties.SharedSubscriptionAvailable.Value = &false connAck.Properties.SharedSubscriptionAvailable.Value = &false
s.Connection = req.Connection s.Connection = req.Connection
err := s.Connection.sendPacket(connAck) _ = s.Connection.sendPacket(connAck)
if err != nil { go s.Connection.PacketReadLoop()
panic("TODO, handle this")
}
} }
// Starts a loop the receives and responds to packets // Starts a loop the receives and responds to packets
func (s *Session) HandlerLoop() { func (s *Session) HandlerLoop() {
go s.Connection.PacketReadLoop() for {
for s.Connection != nil { var packetChan chan packets.ClientPacket
if s.Connection != nil {
packetChan = s.Connection.PacketChannel
}
select { select {
case packet := <-s.Connection.PacketChannel: case packet, more := <-packetChan:
packet.Visit(s) if more {
case <-s.Connection.ClientDisconnectedChan: packet.Visit(s)
s.onDisconnect() } else {
s.onDisconnect()
}
case c := <-s.ConnecionChannel: case c := <-s.ConnecionChannel:
s.Connect(c) s.Connect(c)
case <-s.expiryChannel():
s.expireSession()
break
case subMessage := <-s.SubscriptionChannel: case subMessage := <-s.SubscriptionChannel:
// TODO implement other qos levels if s.Connection != nil {
subMessage.QOSLevel = 0 // TODO implement other qos levels
subMessage.Dup = false subMessage.QOSLevel = 0
err := s.Connection.sendPacket(subMessage) subMessage.Dup = false
if err != nil { err := s.Connection.sendPacket(subMessage)
panic("TOOO handle this") if err != nil {
panic("TOOO handle this")
}
} }
} }
} }
select {
case c := <-s.ConnecionChannel:
s.Connect(c)
// Tail recursion baybeeee
s.HandlerLoop()
case <- s.expireTimer.C:
s.expireSession()
}
} }
func (s *Session) onDisconnect() { func (s *Session) onDisconnect() {
s.Connection = nil s.Connection = nil
s.resetExpireTimer() s.startExpireTimer()
log.Printf("Client disconnected, id: %s", *s.ClientID) log.Printf("Client disconnected, id: %s", *s.ClientID)
} }

View File

@ -5,7 +5,6 @@ import (
"errors" "errors"
"io" "io"
"log" "log"
"time"
"badat.dev/maeqtt/v2/mqtt/packets" "badat.dev/maeqtt/v2/mqtt/packets"
) )
@ -18,47 +17,37 @@ type Connection struct {
WantsProblemInf bool WantsProblemInf bool
Will packets.Will Will packets.Will
KeepAliveInterval time.Duration // TODO
keepAliveTicker time.Ticker //KeepAliveInterval time.Duration
//keepAliveTimer time.Timer
// Gets closed whenever the client disconnects
PacketChannel chan packets.ClientPacket PacketChannel chan packets.ClientPacket
ClientDisconnectedChan chan bool
rw io.ReadWriteCloser rw io.ReadWriteCloser
} }
func (c *Connection) resetKeepAlive() {
if c.KeepAliveInterval != 0 {
panic("TODO")
// TODO IMPLEMENT THIS
//s.keepAliveTicker.Reset(s.KeepAliveInterval)
}
}
func (c *Connection) readPacket() (*packets.ClientPacket, error) { func (c *Connection) readPacket() (*packets.ClientPacket, error) {
return packets.ReadPacket(bufio.NewReader(c.rw)) return packets.ReadPacket(bufio.NewReader(c.rw))
} }
func (c *Connection) sendPacket(p packets.ServerPacket) error { func (c *Connection) sendPacket(p packets.ServerPacket) error {
c.resetKeepAlive()
return p.Write(c.rw) return p.Write(c.rw)
} }
func (c *Connection) close() error { func (c *Connection) close() {
close(c.PacketChannel) _ = c.rw.Close()
return c.rw.Close()
} }
func (c *Connection) PacketReadLoop() { func (c *Connection) PacketReadLoop() {
for { for {
pack, err := c.readPacket() pack, err := c.readPacket()
if err != nil { if err != nil {
c.ClientDisconnectedChan <- true break
c.close()
} else {
c.PacketChannel <- *pack
} }
c.PacketChannel <- *pack
} }
close(c.PacketChannel)
} }
var FirstPackNotConnect error = errors.New("Failed to connect, first packet is not connect") var FirstPackNotConnect error = errors.New("Failed to connect, first packet is not connect")
@ -103,8 +92,6 @@ func NewConnection(rw io.ReadWriteCloser) (ConnectionRequest, error) {
conn.WantsRespInf = false conn.WantsRespInf = false
} }
conn.KeepAliveInterval = time.Duration(conPack.KeepAliveInterval) * time.Second
conn.PacketChannel = make(chan packets.ClientPacket, 1) conn.PacketChannel = make(chan packets.ClientPacket, 1)
return connReq, err return connReq, err

View File

@ -1,23 +1,22 @@
package session package session
import ( import (
"log"
"time" "time"
"badat.dev/maeqtt/v2/subscription" "badat.dev/maeqtt/v2/subscription"
) )
type Expiry struct { type Expiry struct {
ExpiryInterval time.Duration ExpiryInterval time.Duration
expireTimer time.Timer expireTimer *time.Timer
RemoveSessionChannel RemoveSessionChannel
} }
func NewExpiry(channel RemoveSessionChannel) Expiry { func NewExpiry(channel RemoveSessionChannel) Expiry {
expiry := Expiry {} expiry := Expiry{}
expiry.RemoveSessionChannel = channel expiry.RemoveSessionChannel = channel
expiry.expireTimer = *time.NewTimer(time.Hour*9999)
expiry.expireTimer.Stop()
return expiry return expiry
} }
@ -25,21 +24,37 @@ func NewExpiry(channel RemoveSessionChannel) Expiry {
type RemoveSessionChannel chan string type RemoveSessionChannel chan string
func (s *Session) expireSession() { func (s *Session) expireSession() {
log.Printf("Session: %v expired", *s.ClientID)
subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel) subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel)
s.RemoveSessionChannel <- *s.ClientID s.RemoveSessionChannel <- *s.ClientID
} }
// newTime is nullable // newTime is nullable
func (s *Session) SetExpireTimer(newTime *uint32) { func (e *Expiry) SetExpireTimerDuration(newTime *uint32) {
expiry := uint32(0) expiry := uint32(0)
if newTime != nil { if newTime != nil {
expiry = *newTime expiry = *newTime
} else { } else {
expiry = uint32(0) expiry = uint32(0)
} }
s.ExpiryInterval = time.Duration(expiry) * time.Second e.ExpiryInterval = time.Second * time.Duration(expiry)
} }
func (s *Session) resetExpireTimer() { func (e *Expiry) startExpireTimer() {
s.expireTimer.Reset(s.ExpiryInterval) e.stopExpireTimer()
e.expireTimer = time.NewTimer(e.ExpiryInterval)
}
func (e *Expiry) stopExpireTimer() {
if e.expireTimer != nil {
e.expireTimer.Stop()
}
}
func (e *Expiry) expiryChannel() <- chan time.Time {
if e.expireTimer != nil {
return e.expireTimer.C
} else {
return nil
}
} }

View File

@ -1,7 +1,6 @@
package session package session
import ( import (
"io"
"log" "log"
"badat.dev/maeqtt/v2/mqtt/packets" "badat.dev/maeqtt/v2/mqtt/packets"
@ -31,10 +30,7 @@ func (s *Session) VisitPublish(p packets.PublishPacket) {
PacketID: *p.PacketId, PacketID: *p.PacketId,
Reason: reason, Reason: reason,
} }
err := s.Connection.sendPacket(ack) _ = s.Connection.sendPacket(ack)
if err != nil {
panic("TODO")
}
} else if p.QOSLevel == 2 { } else if p.QOSLevel == 2 {
panic("UNIMPLEMENTED QOS level 2") panic("UNIMPLEMENTED QOS level 2")
} }
@ -48,10 +44,7 @@ func (s *Session) VisitPublish(p packets.PublishPacket) {
} }
func (s *Session) VisitDisconnect(p packets.DisconnectPacket) { func (s *Session) VisitDisconnect(p packets.DisconnectPacket) {
err := s.Connection.close() s.Connection.close()
if err != nil && err != io.ErrClosedPipe {
log.Println("Error closing connection", err)
}
s.onDisconnect() s.onDisconnect()
} }
@ -59,26 +52,20 @@ func (s *Session) VisitSubscribe(p packets.SubscribePacket) {
for _, filter := range p.TopicFilters { for _, filter := range p.TopicFilters {
subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel) subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel)
} }
err := s.Connection.sendPacket(packets.SubAckPacket{ _ = s.Connection.sendPacket(packets.SubAckPacket{
PacketID: p.PacketId, PacketID: p.PacketId,
Reason: packets.SubackReasonGrantedQoSTwo, Reason: packets.SubackReasonGrantedQoSTwo,
}) })
if err != nil {
panic("TODO")
}
} }
func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) { func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) {
for _, topic := range p.Topics { for _, topic := range p.Topics {
subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel) subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel)
} }
err := s.Connection.sendPacket(packets.UnsubAckPacket{ _ = s.Connection.sendPacket(packets.UnsubAckPacket{
PacketID: p.PacketID, PacketID: p.PacketID,
Reason: packets.UnsubackReasonSuccess, Reason: packets.UnsubackReasonSuccess,
}) })
if err != nil {
panic("TODO")
}
} }
func (s *Session) VisitPing(p packets.PingreqPacket) { func (s *Session) VisitPing(p packets.PingreqPacket) {

View File

@ -29,7 +29,7 @@ func (s *Session) Disconnect(code packets.DisconnectReasonCode) {
_ = s.Connection.sendPacket(packets.DisconnectPacket{ _ = s.Connection.sendPacket(packets.DisconnectPacket{
ReasonCode: code, ReasonCode: code,
}) })
_ = s.Connection.close() s.Connection.close()
s.onDisconnect() s.onDisconnect()
} }