Refactor connection and add support for resuming sessions
This commit is contained in:
parent
036552bac1
commit
09ac734b2d
6 changed files with 76 additions and 82 deletions
6
main.go
6
main.go
|
@ -54,8 +54,10 @@ func handleConnection(con net.Conn, sessions map[string]*session.Session, rmSess
|
||||||
|
|
||||||
var sess *session.Session
|
var sess *session.Session
|
||||||
if conReq.ConnectPakcet.ClientId != nil {
|
if conReq.ConnectPakcet.ClientId != nil {
|
||||||
sess, exists := sessions[*conReq.ConnectPakcet.ClientId]
|
exists := false
|
||||||
|
sess, exists = sessions[*conReq.ConnectPakcet.ClientId]
|
||||||
if exists {
|
if exists {
|
||||||
|
log.Printf("Resuming session %v", *sess.ClientID)
|
||||||
sess.ConnecionChannel <- conReq
|
sess.ConnecionChannel <- conReq
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -63,6 +65,8 @@ func handleConnection(con net.Conn, sessions map[string]*session.Session, rmSess
|
||||||
if sess == nil {
|
if sess == nil {
|
||||||
newSess := session.NewSession(conReq, rmSessChan)
|
newSess := session.NewSession(conReq, rmSessChan)
|
||||||
sess = &newSess
|
sess = &newSess
|
||||||
|
sessions[*sess.ClientID] = sess
|
||||||
|
log.Printf("New session %v", *sess.ClientID)
|
||||||
go func() {
|
go func() {
|
||||||
defer handlePanic(con)
|
defer handlePanic(con)
|
||||||
sess.HandlerLoop()
|
sess.HandlerLoop()
|
||||||
|
|
|
@ -15,7 +15,6 @@ type Session struct {
|
||||||
SubscriptionChannel chan packets.PublishPacket
|
SubscriptionChannel chan packets.PublishPacket
|
||||||
ConnecionChannel chan ConnectionRequest
|
ConnecionChannel chan ConnectionRequest
|
||||||
|
|
||||||
|
|
||||||
freePacketID uint16
|
freePacketID uint16
|
||||||
|
|
||||||
Expiry
|
Expiry
|
||||||
|
@ -23,7 +22,9 @@ type Session struct {
|
||||||
|
|
||||||
func NewSession(req ConnectionRequest, rmSessChan RemoveSessionChannel) Session {
|
func NewSession(req ConnectionRequest, rmSessChan RemoveSessionChannel) Session {
|
||||||
sess := Session{}
|
sess := Session{}
|
||||||
sess.SubscriptionChannel = make(chan packets.PublishPacket,)
|
sess.SubscriptionChannel = make(chan packets.PublishPacket)
|
||||||
|
sess.ConnecionChannel = make(chan ConnectionRequest)
|
||||||
|
|
||||||
sess.Expiry = NewExpiry(rmSessChan)
|
sess.Expiry = NewExpiry(rmSessChan)
|
||||||
|
|
||||||
sess.Connect(req)
|
sess.Connect(req)
|
||||||
|
@ -31,21 +32,21 @@ func NewSession(req ConnectionRequest, rmSessChan RemoveSessionChannel) Session
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) Connect(req ConnectionRequest) {
|
func (s *Session) Connect(req ConnectionRequest) {
|
||||||
|
s.stopExpireTimer()
|
||||||
|
|
||||||
if s.Connection != nil {
|
if s.Connection != nil {
|
||||||
s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver)
|
s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver)
|
||||||
}
|
}
|
||||||
connAck := packets.ConnackPacket{}
|
connAck := packets.ConnackPacket{}
|
||||||
|
|
||||||
|
s.SetExpireTimerDuration(req.ConnectPakcet.Properties.SessionExpiryInterval.Value)
|
||||||
s.SetExpireTimer(req.ConnectPakcet.Properties.SessionExpiryInterval.Value)
|
|
||||||
s.expireTimer.Stop()
|
|
||||||
|
|
||||||
if req.ConnectPakcet.ClientId == nil {
|
if req.ConnectPakcet.ClientId == nil {
|
||||||
if s.ClientID == nil {
|
if s.ClientID == nil {
|
||||||
s.ClientID = genClientID()
|
s.ClientID = genClientID()
|
||||||
}
|
}
|
||||||
connAck.Properties.AssignedClientIdentifier.Value = s.ClientID
|
connAck.Properties.AssignedClientIdentifier.Value = s.ClientID
|
||||||
} else if s.ClientID != nil && s.ClientID != req.ConnectPakcet.ClientId {
|
} else if s.ClientID != nil && *s.ClientID != *req.ConnectPakcet.ClientId {
|
||||||
panic(fmt.Errorf("Session %s connect called with a connect packet with an ID: %s", *s.ClientID, *req.ConnectPakcet.ClientId))
|
panic(fmt.Errorf("Session %s connect called with a connect packet with an ID: %s", *s.ClientID, *req.ConnectPakcet.ClientId))
|
||||||
} else {
|
} else {
|
||||||
s.ClientID = req.ConnectPakcet.ClientId
|
s.ClientID = req.ConnectPakcet.ClientId
|
||||||
|
@ -59,46 +60,46 @@ func (s *Session) Connect(req ConnectionRequest) {
|
||||||
connAck.Properties.SharedSubscriptionAvailable.Value = &false
|
connAck.Properties.SharedSubscriptionAvailable.Value = &false
|
||||||
|
|
||||||
s.Connection = req.Connection
|
s.Connection = req.Connection
|
||||||
err := s.Connection.sendPacket(connAck)
|
_ = s.Connection.sendPacket(connAck)
|
||||||
if err != nil {
|
go s.Connection.PacketReadLoop()
|
||||||
panic("TODO, handle this")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Starts a loop the receives and responds to packets
|
// Starts a loop the receives and responds to packets
|
||||||
func (s *Session) HandlerLoop() {
|
func (s *Session) HandlerLoop() {
|
||||||
go s.Connection.PacketReadLoop()
|
for {
|
||||||
for s.Connection != nil {
|
var packetChan chan packets.ClientPacket
|
||||||
|
if s.Connection != nil {
|
||||||
|
packetChan = s.Connection.PacketChannel
|
||||||
|
}
|
||||||
select {
|
select {
|
||||||
case packet := <-s.Connection.PacketChannel:
|
case packet, more := <-packetChan:
|
||||||
packet.Visit(s)
|
if more {
|
||||||
case <-s.Connection.ClientDisconnectedChan:
|
packet.Visit(s)
|
||||||
s.onDisconnect()
|
} else {
|
||||||
|
s.onDisconnect()
|
||||||
|
}
|
||||||
case c := <-s.ConnecionChannel:
|
case c := <-s.ConnecionChannel:
|
||||||
s.Connect(c)
|
s.Connect(c)
|
||||||
|
case <-s.expiryChannel():
|
||||||
|
s.expireSession()
|
||||||
|
break
|
||||||
case subMessage := <-s.SubscriptionChannel:
|
case subMessage := <-s.SubscriptionChannel:
|
||||||
// TODO implement other qos levels
|
if s.Connection != nil {
|
||||||
subMessage.QOSLevel = 0
|
// TODO implement other qos levels
|
||||||
subMessage.Dup = false
|
subMessage.QOSLevel = 0
|
||||||
err := s.Connection.sendPacket(subMessage)
|
subMessage.Dup = false
|
||||||
if err != nil {
|
err := s.Connection.sendPacket(subMessage)
|
||||||
panic("TOOO handle this")
|
if err != nil {
|
||||||
|
panic("TOOO handle this")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
|
||||||
case c := <-s.ConnecionChannel:
|
|
||||||
s.Connect(c)
|
|
||||||
// Tail recursion baybeeee
|
|
||||||
s.HandlerLoop()
|
|
||||||
case <- s.expireTimer.C:
|
|
||||||
s.expireSession()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) onDisconnect() {
|
func (s *Session) onDisconnect() {
|
||||||
s.Connection = nil
|
s.Connection = nil
|
||||||
s.resetExpireTimer()
|
s.startExpireTimer()
|
||||||
log.Printf("Client disconnected, id: %s", *s.ClientID)
|
log.Printf("Client disconnected, id: %s", *s.ClientID)
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"time"
|
|
||||||
|
|
||||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||||
)
|
)
|
||||||
|
@ -18,47 +17,37 @@ type Connection struct {
|
||||||
WantsProblemInf bool
|
WantsProblemInf bool
|
||||||
Will packets.Will
|
Will packets.Will
|
||||||
|
|
||||||
KeepAliveInterval time.Duration
|
// TODO
|
||||||
keepAliveTicker time.Ticker
|
//KeepAliveInterval time.Duration
|
||||||
|
//keepAliveTimer time.Timer
|
||||||
|
|
||||||
|
// Gets closed whenever the client disconnects
|
||||||
PacketChannel chan packets.ClientPacket
|
PacketChannel chan packets.ClientPacket
|
||||||
ClientDisconnectedChan chan bool
|
|
||||||
|
|
||||||
rw io.ReadWriteCloser
|
rw io.ReadWriteCloser
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) resetKeepAlive() {
|
|
||||||
if c.KeepAliveInterval != 0 {
|
|
||||||
panic("TODO")
|
|
||||||
// TODO IMPLEMENT THIS
|
|
||||||
//s.keepAliveTicker.Reset(s.KeepAliveInterval)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Connection) readPacket() (*packets.ClientPacket, error) {
|
func (c *Connection) readPacket() (*packets.ClientPacket, error) {
|
||||||
return packets.ReadPacket(bufio.NewReader(c.rw))
|
return packets.ReadPacket(bufio.NewReader(c.rw))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) sendPacket(p packets.ServerPacket) error {
|
func (c *Connection) sendPacket(p packets.ServerPacket) error {
|
||||||
c.resetKeepAlive()
|
|
||||||
return p.Write(c.rw)
|
return p.Write(c.rw)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) close() error {
|
func (c *Connection) close() {
|
||||||
close(c.PacketChannel)
|
_ = c.rw.Close()
|
||||||
return c.rw.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) PacketReadLoop() {
|
func (c *Connection) PacketReadLoop() {
|
||||||
for {
|
for {
|
||||||
pack, err := c.readPacket()
|
pack, err := c.readPacket()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.ClientDisconnectedChan <- true
|
break
|
||||||
c.close()
|
|
||||||
} else {
|
|
||||||
c.PacketChannel <- *pack
|
|
||||||
}
|
}
|
||||||
|
c.PacketChannel <- *pack
|
||||||
}
|
}
|
||||||
|
close(c.PacketChannel)
|
||||||
}
|
}
|
||||||
|
|
||||||
var FirstPackNotConnect error = errors.New("Failed to connect, first packet is not connect")
|
var FirstPackNotConnect error = errors.New("Failed to connect, first packet is not connect")
|
||||||
|
@ -103,8 +92,6 @@ func NewConnection(rw io.ReadWriteCloser) (ConnectionRequest, error) {
|
||||||
conn.WantsRespInf = false
|
conn.WantsRespInf = false
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.KeepAliveInterval = time.Duration(conPack.KeepAliveInterval) * time.Second
|
|
||||||
|
|
||||||
conn.PacketChannel = make(chan packets.ClientPacket, 1)
|
conn.PacketChannel = make(chan packets.ClientPacket, 1)
|
||||||
|
|
||||||
return connReq, err
|
return connReq, err
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package session
|
package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"badat.dev/maeqtt/v2/subscription"
|
"badat.dev/maeqtt/v2/subscription"
|
||||||
|
@ -8,16 +9,14 @@ import (
|
||||||
|
|
||||||
type Expiry struct {
|
type Expiry struct {
|
||||||
ExpiryInterval time.Duration
|
ExpiryInterval time.Duration
|
||||||
expireTimer time.Timer
|
expireTimer *time.Timer
|
||||||
RemoveSessionChannel
|
RemoveSessionChannel
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewExpiry(channel RemoveSessionChannel) Expiry {
|
func NewExpiry(channel RemoveSessionChannel) Expiry {
|
||||||
expiry := Expiry {}
|
expiry := Expiry{}
|
||||||
|
|
||||||
expiry.RemoveSessionChannel = channel
|
expiry.RemoveSessionChannel = channel
|
||||||
expiry.expireTimer = *time.NewTimer(time.Hour*9999)
|
|
||||||
expiry.expireTimer.Stop()
|
|
||||||
return expiry
|
return expiry
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,21 +24,37 @@ func NewExpiry(channel RemoveSessionChannel) Expiry {
|
||||||
type RemoveSessionChannel chan string
|
type RemoveSessionChannel chan string
|
||||||
|
|
||||||
func (s *Session) expireSession() {
|
func (s *Session) expireSession() {
|
||||||
|
log.Printf("Session: %v expired", *s.ClientID)
|
||||||
subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel)
|
subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel)
|
||||||
s.RemoveSessionChannel <- *s.ClientID
|
s.RemoveSessionChannel <- *s.ClientID
|
||||||
}
|
}
|
||||||
|
|
||||||
// newTime is nullable
|
// newTime is nullable
|
||||||
func (s *Session) SetExpireTimer(newTime *uint32) {
|
func (e *Expiry) SetExpireTimerDuration(newTime *uint32) {
|
||||||
expiry := uint32(0)
|
expiry := uint32(0)
|
||||||
if newTime != nil {
|
if newTime != nil {
|
||||||
expiry = *newTime
|
expiry = *newTime
|
||||||
} else {
|
} else {
|
||||||
expiry = uint32(0)
|
expiry = uint32(0)
|
||||||
}
|
}
|
||||||
s.ExpiryInterval = time.Duration(expiry) * time.Second
|
e.ExpiryInterval = time.Second * time.Duration(expiry)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) resetExpireTimer() {
|
func (e *Expiry) startExpireTimer() {
|
||||||
s.expireTimer.Reset(s.ExpiryInterval)
|
e.stopExpireTimer()
|
||||||
|
e.expireTimer = time.NewTimer(e.ExpiryInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Expiry) stopExpireTimer() {
|
||||||
|
if e.expireTimer != nil {
|
||||||
|
e.expireTimer.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Expiry) expiryChannel() <- chan time.Time {
|
||||||
|
if e.expireTimer != nil {
|
||||||
|
return e.expireTimer.C
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package session
|
package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||||
|
@ -31,10 +30,7 @@ func (s *Session) VisitPublish(p packets.PublishPacket) {
|
||||||
PacketID: *p.PacketId,
|
PacketID: *p.PacketId,
|
||||||
Reason: reason,
|
Reason: reason,
|
||||||
}
|
}
|
||||||
err := s.Connection.sendPacket(ack)
|
_ = s.Connection.sendPacket(ack)
|
||||||
if err != nil {
|
|
||||||
panic("TODO")
|
|
||||||
}
|
|
||||||
} else if p.QOSLevel == 2 {
|
} else if p.QOSLevel == 2 {
|
||||||
panic("UNIMPLEMENTED QOS level 2")
|
panic("UNIMPLEMENTED QOS level 2")
|
||||||
}
|
}
|
||||||
|
@ -48,10 +44,7 @@ func (s *Session) VisitPublish(p packets.PublishPacket) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) VisitDisconnect(p packets.DisconnectPacket) {
|
func (s *Session) VisitDisconnect(p packets.DisconnectPacket) {
|
||||||
err := s.Connection.close()
|
s.Connection.close()
|
||||||
if err != nil && err != io.ErrClosedPipe {
|
|
||||||
log.Println("Error closing connection", err)
|
|
||||||
}
|
|
||||||
s.onDisconnect()
|
s.onDisconnect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -59,26 +52,20 @@ func (s *Session) VisitSubscribe(p packets.SubscribePacket) {
|
||||||
for _, filter := range p.TopicFilters {
|
for _, filter := range p.TopicFilters {
|
||||||
subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel)
|
subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel)
|
||||||
}
|
}
|
||||||
err := s.Connection.sendPacket(packets.SubAckPacket{
|
_ = s.Connection.sendPacket(packets.SubAckPacket{
|
||||||
PacketID: p.PacketId,
|
PacketID: p.PacketId,
|
||||||
Reason: packets.SubackReasonGrantedQoSTwo,
|
Reason: packets.SubackReasonGrantedQoSTwo,
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
panic("TODO")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) {
|
func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) {
|
||||||
for _, topic := range p.Topics {
|
for _, topic := range p.Topics {
|
||||||
subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel)
|
subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel)
|
||||||
}
|
}
|
||||||
err := s.Connection.sendPacket(packets.UnsubAckPacket{
|
_ = s.Connection.sendPacket(packets.UnsubAckPacket{
|
||||||
PacketID: p.PacketID,
|
PacketID: p.PacketID,
|
||||||
Reason: packets.UnsubackReasonSuccess,
|
Reason: packets.UnsubackReasonSuccess,
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
panic("TODO")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) VisitPing(p packets.PingreqPacket) {
|
func (s *Session) VisitPing(p packets.PingreqPacket) {
|
||||||
|
|
|
@ -29,7 +29,7 @@ func (s *Session) Disconnect(code packets.DisconnectReasonCode) {
|
||||||
_ = s.Connection.sendPacket(packets.DisconnectPacket{
|
_ = s.Connection.sendPacket(packets.DisconnectPacket{
|
||||||
ReasonCode: code,
|
ReasonCode: code,
|
||||||
})
|
})
|
||||||
_ = s.Connection.close()
|
s.Connection.close()
|
||||||
|
|
||||||
s.onDisconnect()
|
s.onDisconnect()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue