Refactoring, session expiry
This commit is contained in:
parent
9226c0abda
commit
baf746a254
8 changed files with 316 additions and 256 deletions
27
main.go
27
main.go
|
@ -14,27 +14,42 @@ func main() {
|
|||
listener, err := net.Listen("tcp", listen_addr)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
log.Fatalf("Coulde't start a listener on tcp %v. Error: %e", listen_addr, err)
|
||||
}
|
||||
|
||||
var sessions map[string]*session.Session = make(map[string]*session.Session)
|
||||
removeSessChan := make(session.RemoveSessionChannel)
|
||||
connChan := make(chan net.Conn)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Println("Failed accepting connection ", err)
|
||||
} else {
|
||||
connChan <- conn
|
||||
}
|
||||
}
|
||||
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case con := <- connChan:
|
||||
handleConnection(con, sessions, removeSessChan)
|
||||
case sesId := <- removeSessChan:
|
||||
delete(sessions, sesId)
|
||||
}
|
||||
handleConnection(conn, sessions)
|
||||
}
|
||||
}
|
||||
|
||||
func handleConnection(con net.Conn, sessions map[string]*session.Session) {
|
||||
func handleConnection(con net.Conn, sessions map[string]*session.Session, rmSessChan session.RemoveSessionChannel) {
|
||||
defer handlePanic(con)
|
||||
|
||||
conReq, err := session.NewConnection(con)
|
||||
if err != nil {
|
||||
// TODO
|
||||
panic(err)
|
||||
log.Println("Failed to create connection ", err)
|
||||
return
|
||||
}
|
||||
|
||||
var sess *session.Session
|
||||
|
@ -46,7 +61,7 @@ func handleConnection(con net.Conn, sessions map[string]*session.Session) {
|
|||
}
|
||||
|
||||
if sess == nil {
|
||||
newSess := session.NewSession(conReq)
|
||||
newSess := session.NewSession(conReq, rmSessChan)
|
||||
sess = &newSess
|
||||
go func() {
|
||||
defer handlePanic(con)
|
||||
|
|
98
session/Session.go
Normal file
98
session/Session.go
Normal file
|
@ -0,0 +1,98 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
ClientID *string
|
||||
|
||||
// Nullable
|
||||
Connection *Connection
|
||||
SubscriptionChannel chan packets.PublishPacket
|
||||
ConnecionChannel chan ConnectionRequest
|
||||
|
||||
|
||||
freePacketID uint16
|
||||
|
||||
Expiry
|
||||
}
|
||||
|
||||
func NewSession(req ConnectionRequest, rmSessChan RemoveSessionChannel) Session {
|
||||
sess := Session{}
|
||||
sess.SubscriptionChannel = make(chan packets.PublishPacket,)
|
||||
sess.Expiry = NewExpiry(rmSessChan)
|
||||
|
||||
sess.Connect(req)
|
||||
return sess
|
||||
}
|
||||
|
||||
func (s *Session) Connect(req ConnectionRequest) {
|
||||
if s.Connection != nil {
|
||||
s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver)
|
||||
}
|
||||
connAck := packets.ConnackPacket{}
|
||||
|
||||
|
||||
s.SetExpireTimer(req.ConnectPakcet.Properties.SessionExpiryInterval.Value)
|
||||
s.expireTimer.Stop()
|
||||
|
||||
if req.ConnectPakcet.ClientId == nil {
|
||||
if s.ClientID == nil {
|
||||
s.ClientID = genClientID()
|
||||
}
|
||||
connAck.Properties.AssignedClientIdentifier.Value = s.ClientID
|
||||
} else if s.ClientID != nil && s.ClientID != req.ConnectPakcet.ClientId {
|
||||
panic(fmt.Errorf("Session %s connect called with a connect packet with an ID: %s", *s.ClientID, *req.ConnectPakcet.ClientId))
|
||||
} else {
|
||||
s.ClientID = req.ConnectPakcet.ClientId
|
||||
}
|
||||
|
||||
true := byte(1)
|
||||
false := byte(0)
|
||||
connAck.Properties.WildcardSubscriptionAvailable.Value = &true
|
||||
|
||||
connAck.Properties.RetainAvailable.Value = &false
|
||||
connAck.Properties.SharedSubscriptionAvailable.Value = &false
|
||||
|
||||
s.Connection = req.Connection
|
||||
s.Connection.sendPacket(connAck)
|
||||
}
|
||||
|
||||
// Starts a loop the recieves and responds to packets
|
||||
func (s *Session) HandlerLoop() {
|
||||
go s.Connection.PacketReadLoop()
|
||||
for s.Connection != nil {
|
||||
select {
|
||||
case packet := <-s.Connection.PacketChannel:
|
||||
packet.Visit(s)
|
||||
case _ = <-s.Connection.ClientDisconnectedChan:
|
||||
s.onDisconnect()
|
||||
case c := <-s.ConnecionChannel:
|
||||
s.Connect(c)
|
||||
case subMessage := <-s.SubscriptionChannel:
|
||||
// TODO implement other qos levels
|
||||
subMessage.QOSLevel = 0
|
||||
subMessage.Dup = false
|
||||
s.Connection.sendPacket(subMessage)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case c := <-s.ConnecionChannel:
|
||||
s.Connect(c)
|
||||
// Tail recursion baybeeee
|
||||
s.HandlerLoop()
|
||||
case _ = <- s.expireTimer.C:
|
||||
s.expireSession()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) onDisconnect() {
|
||||
s.Connection = nil
|
||||
s.resetExpireTimer()
|
||||
log.Printf("Client disconnected, id: %s", *s.ClientID)
|
||||
}
|
8
session/connectionRequest.go
Normal file
8
session/connectionRequest.go
Normal file
|
@ -0,0 +1,8 @@
|
|||
package session
|
||||
|
||||
import "badat.dev/maeqtt/v2/mqtt/packets"
|
||||
|
||||
type ConnectionRequest struct {
|
||||
Connection *Connection
|
||||
ConnectPakcet packets.ConnectPacket
|
||||
}
|
49
session/expiry.go
Normal file
49
session/expiry.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"badat.dev/maeqtt/v2/subscription"
|
||||
)
|
||||
|
||||
type Expiry struct {
|
||||
ExpiryInterval time.Duration
|
||||
expireTimer time.Timer
|
||||
RemoveSessionChannel
|
||||
}
|
||||
|
||||
func NewExpiry(channel RemoveSessionChannel) Expiry {
|
||||
expiry := Expiry {}
|
||||
|
||||
expiry.RemoveSessionChannel = channel
|
||||
expiry.expireTimer = *time.NewTimer(time.Hour*9999)
|
||||
expiry.expireTimer.Stop()
|
||||
return expiry
|
||||
}
|
||||
|
||||
// Channel for removing a session from the global state
|
||||
type RemoveSessionChannel chan string
|
||||
|
||||
func (s *Session) expireSession() {
|
||||
subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel)
|
||||
s.RemoveSessionChannel <- *s.ClientID
|
||||
}
|
||||
|
||||
// newTime is nullable
|
||||
func (s *Session) SetExpireTimer(newTime *uint32) {
|
||||
var expiry = uint32(0)
|
||||
if newTime != nil {
|
||||
expiry = *newTime
|
||||
} else {
|
||||
expiry = uint32(0)
|
||||
}
|
||||
s.ExpiryInterval = time.Duration(expiry) * time.Second
|
||||
}
|
||||
|
||||
func (s *Session) resetExpireTimer() {
|
||||
if s.ExpiryInterval == 0 {
|
||||
s.expireTimer.Stop()
|
||||
} else {
|
||||
s.expireTimer.Reset(s.ExpiryInterval)
|
||||
}
|
||||
}
|
93
session/packetVisitors.go
Normal file
93
session/packetVisitors.go
Normal file
|
@ -0,0 +1,93 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
|
||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||
"badat.dev/maeqtt/v2/subscription"
|
||||
)
|
||||
|
||||
func (s *Session) VisitConnect(_ packets.ConnectPacket) {
|
||||
// Disconnect, we handle the connect packet in Connect,
|
||||
// this means that we have an estabilished connection already
|
||||
log.Println("WARN: Got a connect packet on an already estabilished connection")
|
||||
s.Disconnect(packets.DisconnectReasonCodeProtocolError)
|
||||
}
|
||||
|
||||
func (s *Session) VisitPublish(p packets.PublishPacket) {
|
||||
subNodes := subscription.Subscriptions.GetSubscriptions(p.TopicName)
|
||||
for _, subNode := range subNodes {
|
||||
subNode.NodeLock.RLock()
|
||||
defer subNode.NodeLock.RUnlock()
|
||||
if p.QOSLevel == 0 {
|
||||
if p.PacketId != nil {
|
||||
log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID)
|
||||
return
|
||||
}
|
||||
} else if p.QOSLevel == 1 {
|
||||
var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess
|
||||
ack := packets.PubackPacket{
|
||||
PacketID: *p.PacketId,
|
||||
Reason: reason,
|
||||
}
|
||||
s.Connection.sendPacket(ack)
|
||||
} else if p.QOSLevel == 2 {
|
||||
panic("UNIMPLEMENTED QOS level 2")
|
||||
}
|
||||
|
||||
for _, sub := range subNode.Subscriptions {
|
||||
if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) {
|
||||
go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) VisitDisconnect(p packets.DisconnectPacket) {
|
||||
err := s.Connection.close()
|
||||
if err != nil && err != io.ErrClosedPipe {
|
||||
log.Println("Error closing connection", err)
|
||||
}
|
||||
s.onDisconnect()
|
||||
}
|
||||
|
||||
func (s *Session) VisitSubscribe(p packets.SubscribePacket) {
|
||||
for _, filter := range p.TopicFilters {
|
||||
subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel)
|
||||
}
|
||||
s.Connection.sendPacket(packets.SubAckPacket{
|
||||
PacketID: p.PacketId,
|
||||
Reason: packets.SubackReasonGrantedQoSTwo,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) {
|
||||
for _, topic := range p.Topics {
|
||||
subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel)
|
||||
}
|
||||
s.Connection.sendPacket(packets.UnsubAckPacket{
|
||||
PacketID: p.PacketID,
|
||||
Reason: packets.UnsubackReasonSuccess,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) VisitPing(p packets.PingreqPacket) {
|
||||
s.Connection.sendPacket(packets.PingrespPacket{})
|
||||
}
|
||||
|
||||
func (s *Session) VisitPubackPacket(_ packets.PubackPacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
||||
|
||||
func (s *Session) VisitPubrecPacket(_ packets.PubrecPacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
||||
|
||||
func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
||||
|
||||
func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
|
@ -1,245 +0,0 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||
"badat.dev/maeqtt/v2/subscription"
|
||||
)
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func Auth(username string, password []byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
ClientID *string
|
||||
|
||||
// Nullable
|
||||
Connection *Connection
|
||||
SubscriptionChannel chan packets.PublishPacket
|
||||
ConnecionChannel chan ConnectionRequest
|
||||
|
||||
ExpiryInterval time.Duration // TODO
|
||||
expireTimer time.Timer // TODO
|
||||
|
||||
freePacketID uint16
|
||||
}
|
||||
|
||||
type ConnectionRequest struct {
|
||||
Connection *Connection
|
||||
ConnectPakcet packets.ConnectPacket
|
||||
}
|
||||
|
||||
func NewSession(req ConnectionRequest) Session {
|
||||
sess := Session{}
|
||||
sess.SubscriptionChannel = make(chan packets.PublishPacket)
|
||||
|
||||
sess.Connect(req)
|
||||
return sess
|
||||
}
|
||||
|
||||
func (s *Session) Connect(req ConnectionRequest) {
|
||||
if s.Connection != nil {
|
||||
s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver)
|
||||
}
|
||||
connAck := packets.ConnackPacket{}
|
||||
|
||||
s.updateExpireTimer(req.ConnectPakcet.Properties.SessionExpiryInterval.Value)
|
||||
|
||||
if req.ConnectPakcet.ClientId == nil {
|
||||
if s.ClientID == nil {
|
||||
s.ClientID = genClientID()
|
||||
}
|
||||
connAck.Properties.AssignedClientIdentifier.Value = s.ClientID
|
||||
} else if s.ClientID != nil && s.ClientID != req.ConnectPakcet.ClientId {
|
||||
panic(fmt.Errorf("Session %s connect called with a connect packet with an ID: %s", *s.ClientID, *req.ConnectPakcet.ClientId))
|
||||
} else {
|
||||
s.ClientID = req.ConnectPakcet.ClientId
|
||||
}
|
||||
|
||||
true := byte(1)
|
||||
false := byte(0)
|
||||
connAck.Properties.WildcardSubscriptionAvailable.Value = &true
|
||||
|
||||
connAck.Properties.RetainAvailable.Value = &false
|
||||
connAck.Properties.SharedSubscriptionAvailable.Value = &false
|
||||
|
||||
s.Connection = req.Connection
|
||||
err := s.Connection.sendPacket(connAck)
|
||||
if err != nil {
|
||||
// TODO
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Starts a loop the recieves and responds to packets
|
||||
func (s *Session) HandlerLoop() {
|
||||
go s.Connection.PacketReadLoop()
|
||||
for s.Connection != nil {
|
||||
select {
|
||||
case packet := <-s.Connection.PacketChannel:
|
||||
packet.Visit(s)
|
||||
case _ = <-s.Connection.ClientDisconnectedChan:
|
||||
s.onDisconnect()
|
||||
case c := <-s.ConnecionChannel:
|
||||
s.Connect(c)
|
||||
case subMessage := <-s.SubscriptionChannel:
|
||||
subMessage.QOSLevel = 0
|
||||
subMessage.Dup = false
|
||||
s.Connection.sendPacket(subMessage)
|
||||
}
|
||||
}
|
||||
c := <-s.ConnecionChannel
|
||||
s.Connect(c)
|
||||
s.HandlerLoop()
|
||||
}
|
||||
|
||||
func (s *Session) Disconnect(code packets.DisconnectReasonCode) error {
|
||||
s.Connection.sendPacket(packets.DisconnectPacket{
|
||||
ReasonCode: code,
|
||||
})
|
||||
|
||||
err := s.Connection.close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.onDisconnect()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) onDisconnect() {
|
||||
s.Connection = nil
|
||||
s.resetExpireTimer()
|
||||
log.Printf("Client disconnected, id: %s", *s.ClientID)
|
||||
}
|
||||
|
||||
func (s *Session) expireSession() {
|
||||
subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel)
|
||||
}
|
||||
|
||||
// newTime is nullable
|
||||
func (s *Session) updateExpireTimer(newTime *uint32) {
|
||||
var expiry = uint32(0)
|
||||
if newTime != nil {
|
||||
expiry = *newTime
|
||||
} else {
|
||||
expiry = uint32(0)
|
||||
}
|
||||
s.ExpiryInterval = time.Duration(expiry) * time.Second
|
||||
|
||||
if s.Connection == nil {
|
||||
s.resetExpireTimer()
|
||||
}
|
||||
}
|
||||
func (s *Session) resetExpireTimer() {
|
||||
//s.expireTimer.Reset(s.ExpiryInterval)
|
||||
}
|
||||
|
||||
func genClientID() *string {
|
||||
buf := make([]byte, 32)
|
||||
_, err := rand.Read(buf)
|
||||
if err != nil {
|
||||
// I don't think this can actually happen but just in case panic
|
||||
panic(fmt.Errorf("Failed to generate a client id, %e", err))
|
||||
}
|
||||
id := "Client_rand_" + base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(buf)
|
||||
return &id
|
||||
}
|
||||
|
||||
func (s *Session) VisitConnect(_ packets.ConnectPacket) {
|
||||
// Disconnect, we handle the connect packet in Connect,
|
||||
// this means that we have an estabilished connection already
|
||||
log.Println("WARN: Got a connect packet on an already estabilished connection")
|
||||
s.Disconnect(packets.DisconnectReasonCodeProtocolError)
|
||||
}
|
||||
|
||||
func (s *Session) VisitPublish(p packets.PublishPacket) {
|
||||
subNodes := subscription.Subscriptions.GetSubscriptions(p.TopicName)
|
||||
for _, subNode := range subNodes {
|
||||
subNode.NodeLock.RLock()
|
||||
defer subNode.NodeLock.RUnlock()
|
||||
if p.QOSLevel == 0 {
|
||||
if p.PacketId != nil {
|
||||
log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID)
|
||||
return
|
||||
}
|
||||
} else if p.QOSLevel == 1 {
|
||||
var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess
|
||||
ack := packets.PubackPacket{
|
||||
PacketID: *p.PacketId,
|
||||
Reason: reason,
|
||||
}
|
||||
s.Connection.sendPacket(ack)
|
||||
} else if p.QOSLevel == 2 {
|
||||
panic("UNIMPLEMENTED QOS level 2")
|
||||
}
|
||||
|
||||
for _, sub := range subNode.Subscriptions {
|
||||
if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) {
|
||||
go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) VisitDisconnect(p packets.DisconnectPacket) {
|
||||
err := s.Connection.close()
|
||||
if err != nil && err != io.ErrClosedPipe {
|
||||
log.Println("Error closing connection", err)
|
||||
}
|
||||
s.onDisconnect()
|
||||
}
|
||||
|
||||
func (s *Session) VisitSubscribe(p packets.SubscribePacket) {
|
||||
for _, filter := range p.TopicFilters {
|
||||
subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel)
|
||||
}
|
||||
s.Connection.sendPacket(packets.SubAckPacket{
|
||||
PacketID: p.PacketId,
|
||||
Reason: packets.SubackReasonGrantedQoSTwo,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) {
|
||||
for _, topic := range p.Topics {
|
||||
subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel)
|
||||
}
|
||||
s.Connection.sendPacket(packets.UnsubAckPacket{
|
||||
PacketID: p.PacketID,
|
||||
Reason: packets.UnsubackReasonSuccess,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) VisitPing(p packets.PingreqPacket) {
|
||||
s.Connection.sendPacket(packets.PingrespPacket{})
|
||||
}
|
||||
|
||||
func (s *Session) VisitPubackPacket(_ packets.PubackPacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
||||
|
||||
func (s *Session) VisitPubrecPacket(_ packets.PubrecPacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
||||
|
||||
func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
||||
|
||||
func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) {
|
||||
panic("not implemented") // TODO: Implement
|
||||
}
|
||||
|
||||
func (s *Session) getFreePacketId() uint16 {
|
||||
s.freePacketID += 1
|
||||
return s.freePacketID
|
||||
}
|
44
session/utils.go
Normal file
44
session/utils.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||
)
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func genClientID() *string {
|
||||
buf := make([]byte, 32)
|
||||
_, err := rand.Read(buf)
|
||||
if err != nil {
|
||||
// I don't think this can actually happen but just in case panic
|
||||
panic(fmt.Errorf("Failed to generate a client id, %e", err))
|
||||
}
|
||||
id := "Client_rand_" + base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(buf)
|
||||
return &id
|
||||
}
|
||||
|
||||
func (s *Session) Disconnect(code packets.DisconnectReasonCode) error {
|
||||
s.Connection.sendPacket(packets.DisconnectPacket{
|
||||
ReasonCode: code,
|
||||
})
|
||||
|
||||
err := s.Connection.close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.onDisconnect()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
func (s *Session) getFreePacketId() uint16 {
|
||||
s.freePacketID += 1
|
||||
return s.freePacketID
|
||||
}
|
|
@ -1,7 +1,5 @@
|
|||
package subscription
|
||||
|
||||
//TODO WILDCARD SUBSCRIPTIONS
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
|
Loading…
Reference in a new issue