241 lines
5.8 KiB
Go
241 lines
5.8 KiB
Go
package session
|
|
|
|
import (
	crand "crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"log"
	"math/rand"
	"time"

	"badat.dev/maeqtt/v2/mqtt/packets"
	"badat.dev/maeqtt/v2/subscription"
)
|
|
|
|
// init seeds the package-global math/rand source from the current wall-clock
// time so that values drawn from it differ between process runs.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}
|
|
|
|
// Auth reports whether a client presenting username and password may connect.
// It currently accepts every credential pair: neither argument is inspected.
func Auth(username string, password []byte) bool {
	const acceptAll = true
	return acceptAll
}
|
|
|
|
// Session holds the broker-side state for one MQTT client. It outlives any
// single network connection: Connection is replaced when a new connection is
// attached and set to nil when the client disconnects.
type Session struct {
	// ClientID is the client's identifier. It is a nullable pointer and may be
	// nil if no identifier has been assigned yet.
	ClientID *string

	// Connection is nullable: it is nil while no client connection is attached.
	Connection *Connection
	// SubscriptionChannel is the channel registered with the subscription
	// registry; PUBLISH packets matching this session's subscriptions arrive
	// here and HandlerLoop forwards them to the client.
	SubscriptionChannel chan packets.PublishPacket
	// ConnecionChannel delivers further connection requests for this session
	// (the misspelt name is kept for compatibility); HandlerLoop reads it.
	ConnecionChannel chan ConnectionRequest

	// ExpiryInterval is the session expiry interval requested by the client.
	// Expiry is not enforced yet.
	ExpiryInterval time.Duration // TODO
	expireTimer    time.Timer    // TODO

	// freePacketID is the last packet identifier handed out by getFreePacketId.
	freePacketID uint16
}
|
|
|
|
// ConnectionRequest pairs a network connection with the CONNECT packet that
// accompanies it. It is the input to NewSession and Session.Connect.
type ConnectionRequest struct {
	Connection *Connection
	// ConnectPakcet is the client's CONNECT packet (the field name's spelling
	// is kept as-is for compatibility).
	ConnectPakcet packets.ConnectPacket
}
|
|
|
|
func NewSession(req ConnectionRequest) Session {
|
|
sess := Session{}
|
|
sess.SubscriptionChannel = make(chan packets.PublishPacket)
|
|
|
|
sess.Connect(req)
|
|
return sess
|
|
}
|
|
|
|
func (s *Session) Connect(req ConnectionRequest) {
|
|
if s.Connection != nil {
|
|
s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver)
|
|
}
|
|
connAck := packets.ConnackPacket{}
|
|
|
|
s.updateExpireTimer(req.ConnectPakcet.Properties.SessionExpiryInterval.Value)
|
|
|
|
if req.ConnectPakcet.ClientId != nil {
|
|
if s.ClientID == nil {
|
|
s.ClientID = genClientID()
|
|
}
|
|
connAck.Properties.AssignedClientIdentifier.Value = s.ClientID
|
|
}
|
|
|
|
true := byte(1)
|
|
false := byte(0)
|
|
connAck.Properties.WildcardSubscriptionAvailable.Value = &true
|
|
|
|
connAck.Properties.RetainAvailable.Value = &false
|
|
connAck.Properties.SharedSubscriptionAvailable.Value = &false
|
|
|
|
s.Connection = req.Connection
|
|
err := s.Connection.sendPacket(connAck)
|
|
if err != nil {
|
|
// TODO
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
// Starts a loop the recieves and responds to packets
|
|
func (s *Session) HandlerLoop() {
|
|
go s.Connection.PacketReadLoop()
|
|
for s.Connection != nil {
|
|
select {
|
|
case packet := <-s.Connection.PacketChannel:
|
|
packet.Visit(s)
|
|
case _ = <-s.Connection.ClientDisconnectedChan:
|
|
s.onDisconnect()
|
|
case c := <-s.ConnecionChannel:
|
|
s.Connect(c)
|
|
case subMessage := <-s.SubscriptionChannel:
|
|
subMessage.QOSLevel = 0
|
|
subMessage.Dup = false
|
|
s.Connection.sendPacket(subMessage)
|
|
}
|
|
}
|
|
c := <-s.ConnecionChannel
|
|
s.Connect(c)
|
|
s.HandlerLoop()
|
|
}
|
|
|
|
func (s *Session) Disconnect(code packets.DisconnectReasonCode) error {
|
|
s.Connection.sendPacket(packets.DisconnectPacket{
|
|
ReasonCode: code,
|
|
})
|
|
|
|
err := s.Connection.close()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s.onDisconnect()
|
|
return nil
|
|
}
|
|
|
|
func (s *Session) onDisconnect() {
|
|
s.Connection = nil
|
|
s.resetExpireTimer()
|
|
log.Printf("Client disconnected, id: %s", *s.ClientID)
|
|
}
|
|
|
|
func (s *Session) expireSession() {
|
|
subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel)
|
|
}
|
|
|
|
// newTime is nullable
|
|
func (s *Session) updateExpireTimer(newTime *uint32) {
|
|
var expiry = uint32(0)
|
|
if newTime != nil {
|
|
expiry = *newTime
|
|
} else {
|
|
expiry = uint32(0)
|
|
}
|
|
s.ExpiryInterval = time.Duration(expiry) * time.Second
|
|
|
|
if s.Connection == nil {
|
|
s.resetExpireTimer()
|
|
}
|
|
}
|
|
// resetExpireTimer is meant to restart the session expiry countdown from
// ExpiryInterval. It is currently a no-op because the call below is disabled,
// so session expiry is not enforced from here.
// NOTE(review): expireTimer is a zero-value time.Timer, and calling Reset on
// an uninitialised Timer panics. That is presumably why the call is commented
// out. The timer would first need to be created, e.g. with time.AfterFunc.
func (s *Session) resetExpireTimer() {
	//s.expireTimer.Reset(s.ExpiryInterval)
}
|
|
|
|
// genClientID returns a new random client identifier: "Client_rand_" followed
// by 32 random bytes encoded as unpadded URL-safe base64.
//
// The bytes come from crypto/rand rather than the time-seeded math/rand
// source. Connecting with an existing session's identifier presumably takes
// that session over, so assigned identifiers must not be predictable.
func genClientID() *string {
	buf := make([]byte, 32)
	if _, err := crand.Read(buf); err != nil {
		// crypto/rand only fails if the OS randomness source is unusable.
		panic(fmt.Errorf("generating client id: %w", err))
	}

	id := "Client_rand_" + base64.RawURLEncoding.EncodeToString(buf)
	return &id
}
|
|
|
|
func (s *Session) VisitConnect(_ packets.ConnectPacket) {
|
|
// Disconnect, we handle the connect packet in Connect,
|
|
// this means that we have an estabilished connection already
|
|
log.Println("WARN: Got a connect packet on an already estabilished connection")
|
|
s.Disconnect(packets.DisconnectReasonCodeProtocolError)
|
|
}
|
|
|
|
func (s *Session) VisitPublish(p packets.PublishPacket) {
|
|
subs, lock := subscription.Subscriptions.GetSubscriptions(p.TopicName)
|
|
defer lock.Unlock()
|
|
if p.QOSLevel == 0 {
|
|
if p.PacketId != nil {
|
|
log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID)
|
|
return
|
|
}
|
|
} else if p.QOSLevel == 1 {
|
|
var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess
|
|
if len(subs) == 0 {
|
|
reason = packets.PubackReasonCodeNoMatchingSubscribers
|
|
}
|
|
ack := packets.PubackPacket{
|
|
PacketID: *p.PacketId,
|
|
Reason: reason,
|
|
}
|
|
s.Connection.sendPacket(ack)
|
|
} else if p.QOSLevel == 2 {
|
|
panic("UNIMPLEMENTED QOS level 2")
|
|
}
|
|
|
|
for _, sub := range subs {
|
|
if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) {
|
|
go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Session) VisitDisconnect(p packets.DisconnectPacket) {
|
|
err := s.Connection.close()
|
|
if err != nil && err != io.ErrClosedPipe {
|
|
log.Println("Error closing connection", err)
|
|
}
|
|
s.onDisconnect()
|
|
}
|
|
|
|
func (s *Session) VisitSubscribe(p packets.SubscribePacket) {
|
|
for _, filter := range p.TopicFilters {
|
|
subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel)
|
|
}
|
|
s.Connection.sendPacket(packets.SubAckPacket{
|
|
PacketID: p.PacketId,
|
|
Reason: packets.SubackReasonGrantedQoSTwo,
|
|
})
|
|
}
|
|
|
|
func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) {
|
|
for _, topic := range p.Topics {
|
|
subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel)
|
|
}
|
|
s.Connection.sendPacket(packets.UnsubAckPacket{
|
|
PacketID: p.PacketID,
|
|
Reason: packets.UnsubackReasonSuccess,
|
|
})
|
|
}
|
|
|
|
func (s *Session) VisitPing(p packets.PingreqPacket) {
|
|
s.Connection.sendPacket(packets.PingrespPacket{})
|
|
}
|
|
|
|
func (s *Session) VisitPubackPacket(_ packets.PubackPacket) {
|
|
panic("not implemented") // TODO: Implement
|
|
}
|
|
|
|
func (s *Session) VisitPubrecPacket(_ packets.PubrecPacket) {
|
|
panic("not implemented") // TODO: Implement
|
|
}
|
|
|
|
func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) {
|
|
panic("not implemented") // TODO: Implement
|
|
}
|
|
|
|
func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) {
|
|
panic("not implemented") // TODO: Implement
|
|
}
|
|
|
|
func (s *Session) getFreePacketId() uint16 {
|
|
s.freePacketID += 1
|
|
return s.freePacketID
|
|
}
|