Refactoring, session expiry

This commit is contained in:
bad 2021-10-16 23:38:23 +02:00
parent 9226c0abda
commit baf746a254
8 changed files with 316 additions and 256 deletions

27
main.go
View file

@ -14,27 +14,42 @@ func main() {
listener, err := net.Listen("tcp", listen_addr)
if err != nil {
log.Fatal(err)
log.Fatalf("Coulde't start a listener on tcp %v. Error: %e", listen_addr, err)
}
var sessions map[string]*session.Session = make(map[string]*session.Session)
removeSessChan := make(session.RemoveSessionChannel)
connChan := make(chan net.Conn)
go func() {
for {
conn, err := listener.Accept()
if err != nil {
log.Println("Failed accepting connection ", err)
} else {
connChan <- conn
}
}
}()
for {
select {
case con := <- connChan:
handleConnection(con, sessions, removeSessChan)
case sesId := <- removeSessChan:
delete(sessions, sesId)
}
handleConnection(conn, sessions)
}
}
func handleConnection(con net.Conn, sessions map[string]*session.Session) {
func handleConnection(con net.Conn, sessions map[string]*session.Session, rmSessChan session.RemoveSessionChannel) {
defer handlePanic(con)
conReq, err := session.NewConnection(con)
if err != nil {
// TODO
panic(err)
log.Println("Failed to create connection ", err)
return
}
var sess *session.Session
@ -46,7 +61,7 @@ func handleConnection(con net.Conn, sessions map[string]*session.Session) {
}
if sess == nil {
newSess := session.NewSession(conReq)
newSess := session.NewSession(conReq, rmSessChan)
sess = &newSess
go func() {
defer handlePanic(con)

98
session/Session.go Normal file
View file

@ -0,0 +1,98 @@
package session
import (
"fmt"
"log"
"badat.dev/maeqtt/v2/mqtt/packets"
)
type Session struct {
ClientID *string
// Nullable
Connection *Connection
SubscriptionChannel chan packets.PublishPacket
ConnecionChannel chan ConnectionRequest
freePacketID uint16
Expiry
}
func NewSession(req ConnectionRequest, rmSessChan RemoveSessionChannel) Session {
sess := Session{}
sess.SubscriptionChannel = make(chan packets.PublishPacket,)
sess.Expiry = NewExpiry(rmSessChan)
sess.Connect(req)
return sess
}
func (s *Session) Connect(req ConnectionRequest) {
if s.Connection != nil {
s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver)
}
connAck := packets.ConnackPacket{}
s.SetExpireTimer(req.ConnectPakcet.Properties.SessionExpiryInterval.Value)
s.expireTimer.Stop()
if req.ConnectPakcet.ClientId == nil {
if s.ClientID == nil {
s.ClientID = genClientID()
}
connAck.Properties.AssignedClientIdentifier.Value = s.ClientID
} else if s.ClientID != nil && s.ClientID != req.ConnectPakcet.ClientId {
panic(fmt.Errorf("Session %s connect called with a connect packet with an ID: %s", *s.ClientID, *req.ConnectPakcet.ClientId))
} else {
s.ClientID = req.ConnectPakcet.ClientId
}
true := byte(1)
false := byte(0)
connAck.Properties.WildcardSubscriptionAvailable.Value = &true
connAck.Properties.RetainAvailable.Value = &false
connAck.Properties.SharedSubscriptionAvailable.Value = &false
s.Connection = req.Connection
s.Connection.sendPacket(connAck)
}
// Starts a loop the recieves and responds to packets
func (s *Session) HandlerLoop() {
go s.Connection.PacketReadLoop()
for s.Connection != nil {
select {
case packet := <-s.Connection.PacketChannel:
packet.Visit(s)
case _ = <-s.Connection.ClientDisconnectedChan:
s.onDisconnect()
case c := <-s.ConnecionChannel:
s.Connect(c)
case subMessage := <-s.SubscriptionChannel:
// TODO implement other qos levels
subMessage.QOSLevel = 0
subMessage.Dup = false
s.Connection.sendPacket(subMessage)
}
}
select {
case c := <-s.ConnecionChannel:
s.Connect(c)
// Tail recursion baybeeee
s.HandlerLoop()
case _ = <- s.expireTimer.C:
s.expireSession()
}
}
func (s *Session) onDisconnect() {
s.Connection = nil
s.resetExpireTimer()
log.Printf("Client disconnected, id: %s", *s.ClientID)
}

View file

@ -0,0 +1,8 @@
package session
import "badat.dev/maeqtt/v2/mqtt/packets"
type ConnectionRequest struct {
Connection *Connection
ConnectPakcet packets.ConnectPacket
}

49
session/expiry.go Normal file
View file

@ -0,0 +1,49 @@
package session
import (
"time"
"badat.dev/maeqtt/v2/subscription"
)
type Expiry struct {
ExpiryInterval time.Duration
expireTimer time.Timer
RemoveSessionChannel
}
func NewExpiry(channel RemoveSessionChannel) Expiry {
expiry := Expiry {}
expiry.RemoveSessionChannel = channel
expiry.expireTimer = *time.NewTimer(time.Hour*9999)
expiry.expireTimer.Stop()
return expiry
}
// Channel for removing a session from the global state
type RemoveSessionChannel chan string
func (s *Session) expireSession() {
subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel)
s.RemoveSessionChannel <- *s.ClientID
}
// newTime is nullable
func (s *Session) SetExpireTimer(newTime *uint32) {
var expiry = uint32(0)
if newTime != nil {
expiry = *newTime
} else {
expiry = uint32(0)
}
s.ExpiryInterval = time.Duration(expiry) * time.Second
}
func (s *Session) resetExpireTimer() {
if s.ExpiryInterval == 0 {
s.expireTimer.Stop()
} else {
s.expireTimer.Reset(s.ExpiryInterval)
}
}

93
session/packetVisitors.go Normal file
View file

@ -0,0 +1,93 @@
package session
import (
"io"
"log"
"badat.dev/maeqtt/v2/mqtt/packets"
"badat.dev/maeqtt/v2/subscription"
)
func (s *Session) VisitConnect(_ packets.ConnectPacket) {
// Disconnect, we handle the connect packet in Connect,
// this means that we have an estabilished connection already
log.Println("WARN: Got a connect packet on an already estabilished connection")
s.Disconnect(packets.DisconnectReasonCodeProtocolError)
}
func (s *Session) VisitPublish(p packets.PublishPacket) {
subNodes := subscription.Subscriptions.GetSubscriptions(p.TopicName)
for _, subNode := range subNodes {
subNode.NodeLock.RLock()
defer subNode.NodeLock.RUnlock()
if p.QOSLevel == 0 {
if p.PacketId != nil {
log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID)
return
}
} else if p.QOSLevel == 1 {
var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess
ack := packets.PubackPacket{
PacketID: *p.PacketId,
Reason: reason,
}
s.Connection.sendPacket(ack)
} else if p.QOSLevel == 2 {
panic("UNIMPLEMENTED QOS level 2")
}
for _, sub := range subNode.Subscriptions {
if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) {
go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub)
}
}
}
}
func (s *Session) VisitDisconnect(p packets.DisconnectPacket) {
err := s.Connection.close()
if err != nil && err != io.ErrClosedPipe {
log.Println("Error closing connection", err)
}
s.onDisconnect()
}
func (s *Session) VisitSubscribe(p packets.SubscribePacket) {
for _, filter := range p.TopicFilters {
subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel)
}
s.Connection.sendPacket(packets.SubAckPacket{
PacketID: p.PacketId,
Reason: packets.SubackReasonGrantedQoSTwo,
})
}
func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) {
for _, topic := range p.Topics {
subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel)
}
s.Connection.sendPacket(packets.UnsubAckPacket{
PacketID: p.PacketID,
Reason: packets.UnsubackReasonSuccess,
})
}
func (s *Session) VisitPing(p packets.PingreqPacket) {
s.Connection.sendPacket(packets.PingrespPacket{})
}
func (s *Session) VisitPubackPacket(_ packets.PubackPacket) {
panic("not implemented") // TODO: Implement
}
func (s *Session) VisitPubrecPacket(_ packets.PubrecPacket) {
panic("not implemented") // TODO: Implement
}
func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) {
panic("not implemented") // TODO: Implement
}
func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) {
panic("not implemented") // TODO: Implement
}

View file

@ -1,245 +0,0 @@
package session
import (
"encoding/base64"
"fmt"
"io"
"log"
"math/rand"
"time"
"badat.dev/maeqtt/v2/mqtt/packets"
"badat.dev/maeqtt/v2/subscription"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
func Auth(username string, password []byte) bool {
return true
}
type Session struct {
ClientID *string
// Nullable
Connection *Connection
SubscriptionChannel chan packets.PublishPacket
ConnecionChannel chan ConnectionRequest
ExpiryInterval time.Duration // TODO
expireTimer time.Timer // TODO
freePacketID uint16
}
type ConnectionRequest struct {
Connection *Connection
ConnectPakcet packets.ConnectPacket
}
func NewSession(req ConnectionRequest) Session {
sess := Session{}
sess.SubscriptionChannel = make(chan packets.PublishPacket)
sess.Connect(req)
return sess
}
func (s *Session) Connect(req ConnectionRequest) {
if s.Connection != nil {
s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver)
}
connAck := packets.ConnackPacket{}
s.updateExpireTimer(req.ConnectPakcet.Properties.SessionExpiryInterval.Value)
if req.ConnectPakcet.ClientId == nil {
if s.ClientID == nil {
s.ClientID = genClientID()
}
connAck.Properties.AssignedClientIdentifier.Value = s.ClientID
} else if s.ClientID != nil && s.ClientID != req.ConnectPakcet.ClientId {
panic(fmt.Errorf("Session %s connect called with a connect packet with an ID: %s", *s.ClientID, *req.ConnectPakcet.ClientId))
} else {
s.ClientID = req.ConnectPakcet.ClientId
}
true := byte(1)
false := byte(0)
connAck.Properties.WildcardSubscriptionAvailable.Value = &true
connAck.Properties.RetainAvailable.Value = &false
connAck.Properties.SharedSubscriptionAvailable.Value = &false
s.Connection = req.Connection
err := s.Connection.sendPacket(connAck)
if err != nil {
// TODO
panic(err)
}
}
// Starts a loop the recieves and responds to packets
func (s *Session) HandlerLoop() {
go s.Connection.PacketReadLoop()
for s.Connection != nil {
select {
case packet := <-s.Connection.PacketChannel:
packet.Visit(s)
case _ = <-s.Connection.ClientDisconnectedChan:
s.onDisconnect()
case c := <-s.ConnecionChannel:
s.Connect(c)
case subMessage := <-s.SubscriptionChannel:
subMessage.QOSLevel = 0
subMessage.Dup = false
s.Connection.sendPacket(subMessage)
}
}
c := <-s.ConnecionChannel
s.Connect(c)
s.HandlerLoop()
}
func (s *Session) Disconnect(code packets.DisconnectReasonCode) error {
s.Connection.sendPacket(packets.DisconnectPacket{
ReasonCode: code,
})
err := s.Connection.close()
if err != nil {
return err
}
s.onDisconnect()
return nil
}
func (s *Session) onDisconnect() {
s.Connection = nil
s.resetExpireTimer()
log.Printf("Client disconnected, id: %s", *s.ClientID)
}
func (s *Session) expireSession() {
subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel)
}
// newTime is nullable
func (s *Session) updateExpireTimer(newTime *uint32) {
var expiry = uint32(0)
if newTime != nil {
expiry = *newTime
} else {
expiry = uint32(0)
}
s.ExpiryInterval = time.Duration(expiry) * time.Second
if s.Connection == nil {
s.resetExpireTimer()
}
}
func (s *Session) resetExpireTimer() {
//s.expireTimer.Reset(s.ExpiryInterval)
}
func genClientID() *string {
buf := make([]byte, 32)
_, err := rand.Read(buf)
if err != nil {
// I don't think this can actually happen but just in case panic
panic(fmt.Errorf("Failed to generate a client id, %e", err))
}
id := "Client_rand_" + base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(buf)
return &id
}
func (s *Session) VisitConnect(_ packets.ConnectPacket) {
// Disconnect, we handle the connect packet in Connect,
// this means that we have an estabilished connection already
log.Println("WARN: Got a connect packet on an already estabilished connection")
s.Disconnect(packets.DisconnectReasonCodeProtocolError)
}
func (s *Session) VisitPublish(p packets.PublishPacket) {
subNodes := subscription.Subscriptions.GetSubscriptions(p.TopicName)
for _, subNode := range subNodes {
subNode.NodeLock.RLock()
defer subNode.NodeLock.RUnlock()
if p.QOSLevel == 0 {
if p.PacketId != nil {
log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID)
return
}
} else if p.QOSLevel == 1 {
var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess
ack := packets.PubackPacket{
PacketID: *p.PacketId,
Reason: reason,
}
s.Connection.sendPacket(ack)
} else if p.QOSLevel == 2 {
panic("UNIMPLEMENTED QOS level 2")
}
for _, sub := range subNode.Subscriptions {
if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) {
go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub)
}
}
}
}
func (s *Session) VisitDisconnect(p packets.DisconnectPacket) {
err := s.Connection.close()
if err != nil && err != io.ErrClosedPipe {
log.Println("Error closing connection", err)
}
s.onDisconnect()
}
func (s *Session) VisitSubscribe(p packets.SubscribePacket) {
for _, filter := range p.TopicFilters {
subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel)
}
s.Connection.sendPacket(packets.SubAckPacket{
PacketID: p.PacketId,
Reason: packets.SubackReasonGrantedQoSTwo,
})
}
func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) {
for _, topic := range p.Topics {
subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel)
}
s.Connection.sendPacket(packets.UnsubAckPacket{
PacketID: p.PacketID,
Reason: packets.UnsubackReasonSuccess,
})
}
func (s *Session) VisitPing(p packets.PingreqPacket) {
s.Connection.sendPacket(packets.PingrespPacket{})
}
func (s *Session) VisitPubackPacket(_ packets.PubackPacket) {
panic("not implemented") // TODO: Implement
}
func (s *Session) VisitPubrecPacket(_ packets.PubrecPacket) {
panic("not implemented") // TODO: Implement
}
func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) {
panic("not implemented") // TODO: Implement
}
func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) {
panic("not implemented") // TODO: Implement
}
func (s *Session) getFreePacketId() uint16 {
s.freePacketID += 1
return s.freePacketID
}

44
session/utils.go Normal file
View file

@ -0,0 +1,44 @@
package session
import (
"encoding/base64"
"fmt"
"math/rand"
"time"
"badat.dev/maeqtt/v2/mqtt/packets"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
func genClientID() *string {
buf := make([]byte, 32)
_, err := rand.Read(buf)
if err != nil {
// I don't think this can actually happen but just in case panic
panic(fmt.Errorf("Failed to generate a client id, %e", err))
}
id := "Client_rand_" + base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(buf)
return &id
}
func (s *Session) Disconnect(code packets.DisconnectReasonCode) error {
s.Connection.sendPacket(packets.DisconnectPacket{
ReasonCode: code,
})
err := s.Connection.close()
if err != nil {
return err
}
s.onDisconnect()
return nil
}
func (s *Session) getFreePacketId() uint16 {
s.freePacketID += 1
return s.freePacketID
}

View file

@ -1,7 +1,5 @@
package subscription
//TODO WILDCARD SUBSCRIPTIONS
import (
"strings"
"sync"