Refactoring, session expiry
This commit is contained in:
parent
9226c0abda
commit
baf746a254
8 changed files with 316 additions and 256 deletions
33
main.go
33
main.go
|
@ -14,27 +14,42 @@ func main() {
|
||||||
listener, err := net.Listen("tcp", listen_addr)
|
listener, err := net.Listen("tcp", listen_addr)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatalf("Coulde't start a listener on tcp %v. Error: %e", listen_addr, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var sessions map[string]*session.Session = make(map[string]*session.Session)
|
var sessions map[string]*session.Session = make(map[string]*session.Session)
|
||||||
|
removeSessChan := make(session.RemoveSessionChannel)
|
||||||
|
connChan := make(chan net.Conn)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed accepting connection ", err)
|
||||||
|
} else {
|
||||||
|
connChan <- conn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
conn, err := listener.Accept()
|
select {
|
||||||
if err != nil {
|
case con := <- connChan:
|
||||||
log.Println("Failed accepting connection ", err)
|
handleConnection(con, sessions, removeSessChan)
|
||||||
|
case sesId := <- removeSessChan:
|
||||||
|
delete(sessions, sesId)
|
||||||
}
|
}
|
||||||
handleConnection(conn, sessions)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleConnection(con net.Conn, sessions map[string]*session.Session) {
|
func handleConnection(con net.Conn, sessions map[string]*session.Session, rmSessChan session.RemoveSessionChannel) {
|
||||||
defer handlePanic(con)
|
defer handlePanic(con)
|
||||||
|
|
||||||
conReq, err := session.NewConnection(con)
|
conReq, err := session.NewConnection(con)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO
|
log.Println("Failed to create connection ", err)
|
||||||
panic(err)
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var sess *session.Session
|
var sess *session.Session
|
||||||
|
@ -46,7 +61,7 @@ func handleConnection(con net.Conn, sessions map[string]*session.Session) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if sess == nil {
|
if sess == nil {
|
||||||
newSess := session.NewSession(conReq)
|
newSess := session.NewSession(conReq, rmSessChan)
|
||||||
sess = &newSess
|
sess = &newSess
|
||||||
go func() {
|
go func() {
|
||||||
defer handlePanic(con)
|
defer handlePanic(con)
|
||||||
|
|
98
session/Session.go
Normal file
98
session/Session.go
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
ClientID *string
|
||||||
|
|
||||||
|
// Nullable
|
||||||
|
Connection *Connection
|
||||||
|
SubscriptionChannel chan packets.PublishPacket
|
||||||
|
ConnecionChannel chan ConnectionRequest
|
||||||
|
|
||||||
|
|
||||||
|
freePacketID uint16
|
||||||
|
|
||||||
|
Expiry
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSession(req ConnectionRequest, rmSessChan RemoveSessionChannel) Session {
|
||||||
|
sess := Session{}
|
||||||
|
sess.SubscriptionChannel = make(chan packets.PublishPacket,)
|
||||||
|
sess.Expiry = NewExpiry(rmSessChan)
|
||||||
|
|
||||||
|
sess.Connect(req)
|
||||||
|
return sess
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) Connect(req ConnectionRequest) {
|
||||||
|
if s.Connection != nil {
|
||||||
|
s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver)
|
||||||
|
}
|
||||||
|
connAck := packets.ConnackPacket{}
|
||||||
|
|
||||||
|
|
||||||
|
s.SetExpireTimer(req.ConnectPakcet.Properties.SessionExpiryInterval.Value)
|
||||||
|
s.expireTimer.Stop()
|
||||||
|
|
||||||
|
if req.ConnectPakcet.ClientId == nil {
|
||||||
|
if s.ClientID == nil {
|
||||||
|
s.ClientID = genClientID()
|
||||||
|
}
|
||||||
|
connAck.Properties.AssignedClientIdentifier.Value = s.ClientID
|
||||||
|
} else if s.ClientID != nil && s.ClientID != req.ConnectPakcet.ClientId {
|
||||||
|
panic(fmt.Errorf("Session %s connect called with a connect packet with an ID: %s", *s.ClientID, *req.ConnectPakcet.ClientId))
|
||||||
|
} else {
|
||||||
|
s.ClientID = req.ConnectPakcet.ClientId
|
||||||
|
}
|
||||||
|
|
||||||
|
true := byte(1)
|
||||||
|
false := byte(0)
|
||||||
|
connAck.Properties.WildcardSubscriptionAvailable.Value = &true
|
||||||
|
|
||||||
|
connAck.Properties.RetainAvailable.Value = &false
|
||||||
|
connAck.Properties.SharedSubscriptionAvailable.Value = &false
|
||||||
|
|
||||||
|
s.Connection = req.Connection
|
||||||
|
s.Connection.sendPacket(connAck)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Starts a loop the recieves and responds to packets
|
||||||
|
func (s *Session) HandlerLoop() {
|
||||||
|
go s.Connection.PacketReadLoop()
|
||||||
|
for s.Connection != nil {
|
||||||
|
select {
|
||||||
|
case packet := <-s.Connection.PacketChannel:
|
||||||
|
packet.Visit(s)
|
||||||
|
case _ = <-s.Connection.ClientDisconnectedChan:
|
||||||
|
s.onDisconnect()
|
||||||
|
case c := <-s.ConnecionChannel:
|
||||||
|
s.Connect(c)
|
||||||
|
case subMessage := <-s.SubscriptionChannel:
|
||||||
|
// TODO implement other qos levels
|
||||||
|
subMessage.QOSLevel = 0
|
||||||
|
subMessage.Dup = false
|
||||||
|
s.Connection.sendPacket(subMessage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case c := <-s.ConnecionChannel:
|
||||||
|
s.Connect(c)
|
||||||
|
// Tail recursion baybeeee
|
||||||
|
s.HandlerLoop()
|
||||||
|
case _ = <- s.expireTimer.C:
|
||||||
|
s.expireSession()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) onDisconnect() {
|
||||||
|
s.Connection = nil
|
||||||
|
s.resetExpireTimer()
|
||||||
|
log.Printf("Client disconnected, id: %s", *s.ClientID)
|
||||||
|
}
|
8
session/connectionRequest.go
Normal file
8
session/connectionRequest.go
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
package session
|
||||||
|
|
||||||
|
import "badat.dev/maeqtt/v2/mqtt/packets"
|
||||||
|
|
||||||
|
type ConnectionRequest struct {
|
||||||
|
Connection *Connection
|
||||||
|
ConnectPakcet packets.ConnectPacket
|
||||||
|
}
|
49
session/expiry.go
Normal file
49
session/expiry.go
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"badat.dev/maeqtt/v2/subscription"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Expiry struct {
|
||||||
|
ExpiryInterval time.Duration
|
||||||
|
expireTimer time.Timer
|
||||||
|
RemoveSessionChannel
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewExpiry(channel RemoveSessionChannel) Expiry {
|
||||||
|
expiry := Expiry {}
|
||||||
|
|
||||||
|
expiry.RemoveSessionChannel = channel
|
||||||
|
expiry.expireTimer = *time.NewTimer(time.Hour*9999)
|
||||||
|
expiry.expireTimer.Stop()
|
||||||
|
return expiry
|
||||||
|
}
|
||||||
|
|
||||||
|
// Channel for removing a session from the global state
|
||||||
|
type RemoveSessionChannel chan string
|
||||||
|
|
||||||
|
func (s *Session) expireSession() {
|
||||||
|
subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel)
|
||||||
|
s.RemoveSessionChannel <- *s.ClientID
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTime is nullable
|
||||||
|
func (s *Session) SetExpireTimer(newTime *uint32) {
|
||||||
|
var expiry = uint32(0)
|
||||||
|
if newTime != nil {
|
||||||
|
expiry = *newTime
|
||||||
|
} else {
|
||||||
|
expiry = uint32(0)
|
||||||
|
}
|
||||||
|
s.ExpiryInterval = time.Duration(expiry) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) resetExpireTimer() {
|
||||||
|
if s.ExpiryInterval == 0 {
|
||||||
|
s.expireTimer.Stop()
|
||||||
|
} else {
|
||||||
|
s.expireTimer.Reset(s.ExpiryInterval)
|
||||||
|
}
|
||||||
|
}
|
93
session/packetVisitors.go
Normal file
93
session/packetVisitors.go
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||||
|
"badat.dev/maeqtt/v2/subscription"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Session) VisitConnect(_ packets.ConnectPacket) {
|
||||||
|
// Disconnect, we handle the connect packet in Connect,
|
||||||
|
// this means that we have an estabilished connection already
|
||||||
|
log.Println("WARN: Got a connect packet on an already estabilished connection")
|
||||||
|
s.Disconnect(packets.DisconnectReasonCodeProtocolError)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) VisitPublish(p packets.PublishPacket) {
|
||||||
|
subNodes := subscription.Subscriptions.GetSubscriptions(p.TopicName)
|
||||||
|
for _, subNode := range subNodes {
|
||||||
|
subNode.NodeLock.RLock()
|
||||||
|
defer subNode.NodeLock.RUnlock()
|
||||||
|
if p.QOSLevel == 0 {
|
||||||
|
if p.PacketId != nil {
|
||||||
|
log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else if p.QOSLevel == 1 {
|
||||||
|
var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess
|
||||||
|
ack := packets.PubackPacket{
|
||||||
|
PacketID: *p.PacketId,
|
||||||
|
Reason: reason,
|
||||||
|
}
|
||||||
|
s.Connection.sendPacket(ack)
|
||||||
|
} else if p.QOSLevel == 2 {
|
||||||
|
panic("UNIMPLEMENTED QOS level 2")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sub := range subNode.Subscriptions {
|
||||||
|
if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) {
|
||||||
|
go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) VisitDisconnect(p packets.DisconnectPacket) {
|
||||||
|
err := s.Connection.close()
|
||||||
|
if err != nil && err != io.ErrClosedPipe {
|
||||||
|
log.Println("Error closing connection", err)
|
||||||
|
}
|
||||||
|
s.onDisconnect()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) VisitSubscribe(p packets.SubscribePacket) {
|
||||||
|
for _, filter := range p.TopicFilters {
|
||||||
|
subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel)
|
||||||
|
}
|
||||||
|
s.Connection.sendPacket(packets.SubAckPacket{
|
||||||
|
PacketID: p.PacketId,
|
||||||
|
Reason: packets.SubackReasonGrantedQoSTwo,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) {
|
||||||
|
for _, topic := range p.Topics {
|
||||||
|
subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel)
|
||||||
|
}
|
||||||
|
s.Connection.sendPacket(packets.UnsubAckPacket{
|
||||||
|
PacketID: p.PacketID,
|
||||||
|
Reason: packets.UnsubackReasonSuccess,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) VisitPing(p packets.PingreqPacket) {
|
||||||
|
s.Connection.sendPacket(packets.PingrespPacket{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) VisitPubackPacket(_ packets.PubackPacket) {
|
||||||
|
panic("not implemented") // TODO: Implement
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) VisitPubrecPacket(_ packets.PubrecPacket) {
|
||||||
|
panic("not implemented") // TODO: Implement
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) {
|
||||||
|
panic("not implemented") // TODO: Implement
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) {
|
||||||
|
panic("not implemented") // TODO: Implement
|
||||||
|
}
|
|
@ -1,245 +0,0 @@
|
||||||
package session
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"math/rand"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
|
||||||
"badat.dev/maeqtt/v2/subscription"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
|
||||||
}
|
|
||||||
|
|
||||||
func Auth(username string, password []byte) bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
type Session struct {
|
|
||||||
ClientID *string
|
|
||||||
|
|
||||||
// Nullable
|
|
||||||
Connection *Connection
|
|
||||||
SubscriptionChannel chan packets.PublishPacket
|
|
||||||
ConnecionChannel chan ConnectionRequest
|
|
||||||
|
|
||||||
ExpiryInterval time.Duration // TODO
|
|
||||||
expireTimer time.Timer // TODO
|
|
||||||
|
|
||||||
freePacketID uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
type ConnectionRequest struct {
|
|
||||||
Connection *Connection
|
|
||||||
ConnectPakcet packets.ConnectPacket
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSession(req ConnectionRequest) Session {
|
|
||||||
sess := Session{}
|
|
||||||
sess.SubscriptionChannel = make(chan packets.PublishPacket)
|
|
||||||
|
|
||||||
sess.Connect(req)
|
|
||||||
return sess
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) Connect(req ConnectionRequest) {
|
|
||||||
if s.Connection != nil {
|
|
||||||
s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver)
|
|
||||||
}
|
|
||||||
connAck := packets.ConnackPacket{}
|
|
||||||
|
|
||||||
s.updateExpireTimer(req.ConnectPakcet.Properties.SessionExpiryInterval.Value)
|
|
||||||
|
|
||||||
if req.ConnectPakcet.ClientId == nil {
|
|
||||||
if s.ClientID == nil {
|
|
||||||
s.ClientID = genClientID()
|
|
||||||
}
|
|
||||||
connAck.Properties.AssignedClientIdentifier.Value = s.ClientID
|
|
||||||
} else if s.ClientID != nil && s.ClientID != req.ConnectPakcet.ClientId {
|
|
||||||
panic(fmt.Errorf("Session %s connect called with a connect packet with an ID: %s", *s.ClientID, *req.ConnectPakcet.ClientId))
|
|
||||||
} else {
|
|
||||||
s.ClientID = req.ConnectPakcet.ClientId
|
|
||||||
}
|
|
||||||
|
|
||||||
true := byte(1)
|
|
||||||
false := byte(0)
|
|
||||||
connAck.Properties.WildcardSubscriptionAvailable.Value = &true
|
|
||||||
|
|
||||||
connAck.Properties.RetainAvailable.Value = &false
|
|
||||||
connAck.Properties.SharedSubscriptionAvailable.Value = &false
|
|
||||||
|
|
||||||
s.Connection = req.Connection
|
|
||||||
err := s.Connection.sendPacket(connAck)
|
|
||||||
if err != nil {
|
|
||||||
// TODO
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Starts a loop the recieves and responds to packets
|
|
||||||
func (s *Session) HandlerLoop() {
|
|
||||||
go s.Connection.PacketReadLoop()
|
|
||||||
for s.Connection != nil {
|
|
||||||
select {
|
|
||||||
case packet := <-s.Connection.PacketChannel:
|
|
||||||
packet.Visit(s)
|
|
||||||
case _ = <-s.Connection.ClientDisconnectedChan:
|
|
||||||
s.onDisconnect()
|
|
||||||
case c := <-s.ConnecionChannel:
|
|
||||||
s.Connect(c)
|
|
||||||
case subMessage := <-s.SubscriptionChannel:
|
|
||||||
subMessage.QOSLevel = 0
|
|
||||||
subMessage.Dup = false
|
|
||||||
s.Connection.sendPacket(subMessage)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c := <-s.ConnecionChannel
|
|
||||||
s.Connect(c)
|
|
||||||
s.HandlerLoop()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) Disconnect(code packets.DisconnectReasonCode) error {
|
|
||||||
s.Connection.sendPacket(packets.DisconnectPacket{
|
|
||||||
ReasonCode: code,
|
|
||||||
})
|
|
||||||
|
|
||||||
err := s.Connection.close()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.onDisconnect()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) onDisconnect() {
|
|
||||||
s.Connection = nil
|
|
||||||
s.resetExpireTimer()
|
|
||||||
log.Printf("Client disconnected, id: %s", *s.ClientID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) expireSession() {
|
|
||||||
subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel)
|
|
||||||
}
|
|
||||||
|
|
||||||
// newTime is nullable
|
|
||||||
func (s *Session) updateExpireTimer(newTime *uint32) {
|
|
||||||
var expiry = uint32(0)
|
|
||||||
if newTime != nil {
|
|
||||||
expiry = *newTime
|
|
||||||
} else {
|
|
||||||
expiry = uint32(0)
|
|
||||||
}
|
|
||||||
s.ExpiryInterval = time.Duration(expiry) * time.Second
|
|
||||||
|
|
||||||
if s.Connection == nil {
|
|
||||||
s.resetExpireTimer()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
func (s *Session) resetExpireTimer() {
|
|
||||||
//s.expireTimer.Reset(s.ExpiryInterval)
|
|
||||||
}
|
|
||||||
|
|
||||||
func genClientID() *string {
|
|
||||||
buf := make([]byte, 32)
|
|
||||||
_, err := rand.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
// I don't think this can actually happen but just in case panic
|
|
||||||
panic(fmt.Errorf("Failed to generate a client id, %e", err))
|
|
||||||
}
|
|
||||||
id := "Client_rand_" + base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(buf)
|
|
||||||
return &id
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) VisitConnect(_ packets.ConnectPacket) {
|
|
||||||
// Disconnect, we handle the connect packet in Connect,
|
|
||||||
// this means that we have an estabilished connection already
|
|
||||||
log.Println("WARN: Got a connect packet on an already estabilished connection")
|
|
||||||
s.Disconnect(packets.DisconnectReasonCodeProtocolError)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) VisitPublish(p packets.PublishPacket) {
|
|
||||||
subNodes := subscription.Subscriptions.GetSubscriptions(p.TopicName)
|
|
||||||
for _, subNode := range subNodes {
|
|
||||||
subNode.NodeLock.RLock()
|
|
||||||
defer subNode.NodeLock.RUnlock()
|
|
||||||
if p.QOSLevel == 0 {
|
|
||||||
if p.PacketId != nil {
|
|
||||||
log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else if p.QOSLevel == 1 {
|
|
||||||
var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess
|
|
||||||
ack := packets.PubackPacket{
|
|
||||||
PacketID: *p.PacketId,
|
|
||||||
Reason: reason,
|
|
||||||
}
|
|
||||||
s.Connection.sendPacket(ack)
|
|
||||||
} else if p.QOSLevel == 2 {
|
|
||||||
panic("UNIMPLEMENTED QOS level 2")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, sub := range subNode.Subscriptions {
|
|
||||||
if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) {
|
|
||||||
go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) VisitDisconnect(p packets.DisconnectPacket) {
|
|
||||||
err := s.Connection.close()
|
|
||||||
if err != nil && err != io.ErrClosedPipe {
|
|
||||||
log.Println("Error closing connection", err)
|
|
||||||
}
|
|
||||||
s.onDisconnect()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) VisitSubscribe(p packets.SubscribePacket) {
|
|
||||||
for _, filter := range p.TopicFilters {
|
|
||||||
subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel)
|
|
||||||
}
|
|
||||||
s.Connection.sendPacket(packets.SubAckPacket{
|
|
||||||
PacketID: p.PacketId,
|
|
||||||
Reason: packets.SubackReasonGrantedQoSTwo,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) {
|
|
||||||
for _, topic := range p.Topics {
|
|
||||||
subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel)
|
|
||||||
}
|
|
||||||
s.Connection.sendPacket(packets.UnsubAckPacket{
|
|
||||||
PacketID: p.PacketID,
|
|
||||||
Reason: packets.UnsubackReasonSuccess,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) VisitPing(p packets.PingreqPacket) {
|
|
||||||
s.Connection.sendPacket(packets.PingrespPacket{})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) VisitPubackPacket(_ packets.PubackPacket) {
|
|
||||||
panic("not implemented") // TODO: Implement
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) VisitPubrecPacket(_ packets.PubrecPacket) {
|
|
||||||
panic("not implemented") // TODO: Implement
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) {
|
|
||||||
panic("not implemented") // TODO: Implement
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) {
|
|
||||||
panic("not implemented") // TODO: Implement
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) getFreePacketId() uint16 {
|
|
||||||
s.freePacketID += 1
|
|
||||||
return s.freePacketID
|
|
||||||
}
|
|
44
session/utils.go
Normal file
44
session/utils.go
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
func genClientID() *string {
|
||||||
|
buf := make([]byte, 32)
|
||||||
|
_, err := rand.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
// I don't think this can actually happen but just in case panic
|
||||||
|
panic(fmt.Errorf("Failed to generate a client id, %e", err))
|
||||||
|
}
|
||||||
|
id := "Client_rand_" + base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(buf)
|
||||||
|
return &id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) Disconnect(code packets.DisconnectReasonCode) error {
|
||||||
|
s.Connection.sendPacket(packets.DisconnectPacket{
|
||||||
|
ReasonCode: code,
|
||||||
|
})
|
||||||
|
|
||||||
|
err := s.Connection.close()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.onDisconnect()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func (s *Session) getFreePacketId() uint16 {
|
||||||
|
s.freePacketID += 1
|
||||||
|
return s.freePacketID
|
||||||
|
}
|
|
@ -1,7 +1,5 @@
|
||||||
package subscription
|
package subscription
|
||||||
|
|
||||||
//TODO WILDCARD SUBSCRIPTIONS
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
Loading…
Reference in a new issue