package session import ( "encoding/base64" "fmt" "io" "log" "math/rand" "time" "badat.dev/maeqtt/v2/mqtt/packets" "badat.dev/maeqtt/v2/subscription" ) func init() { rand.Seed(time.Now().UnixNano()) } func Auth(username string, password []byte) bool { return true } type Session struct { ClientID *string // Nullable Connection *Connection SubscriptionChannel chan packets.PublishPacket ConnecionChannel chan ConnectionRequest ExpiryInterval time.Duration // TODO expireTimer time.Timer // TODO freePacketID uint16 } type ConnectionRequest struct { Connection *Connection ConnectPakcet packets.ConnectPacket } func NewSession(req ConnectionRequest) Session { sess := Session{} sess.SubscriptionChannel = make(chan packets.PublishPacket) sess.Connect(req) return sess } func (s *Session) Connect(req ConnectionRequest) { if s.Connection != nil { s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver) } connAck := packets.ConnackPacket{} s.updateExpireTimer(req.ConnectPakcet.Properties.SessionExpiryInterval.Value) if req.ConnectPakcet.ClientId == nil { if s.ClientID == nil { s.ClientID = genClientID() } connAck.Properties.AssignedClientIdentifier.Value = s.ClientID } else if s.ClientID != nil && s.ClientID != req.ConnectPakcet.ClientId { panic(fmt.Errorf("Session %s connect called with a connect packet with an ID: %s", *s.ClientID, *req.ConnectPakcet.ClientId)) } else { s.ClientID = req.ConnectPakcet.ClientId } true := byte(1) false := byte(0) connAck.Properties.WildcardSubscriptionAvailable.Value = &true connAck.Properties.RetainAvailable.Value = &false connAck.Properties.SharedSubscriptionAvailable.Value = &false s.Connection = req.Connection err := s.Connection.sendPacket(connAck) if err != nil { // TODO panic(err) } } // Starts a loop the recieves and responds to packets func (s *Session) HandlerLoop() { go s.Connection.PacketReadLoop() for s.Connection != nil { select { case packet := <-s.Connection.PacketChannel: packet.Visit(s) case _ = <-s.Connection.ClientDisconnectedChan: s.onDisconnect() case c := <-s.ConnecionChannel: s.Connect(c) case subMessage := <-s.SubscriptionChannel: subMessage.QOSLevel = 0 subMessage.Dup = false s.Connection.sendPacket(subMessage) } } c := <-s.ConnecionChannel s.Connect(c) s.HandlerLoop() } func (s *Session) Disconnect(code packets.DisconnectReasonCode) error { s.Connection.sendPacket(packets.DisconnectPacket{ ReasonCode: code, }) err := s.Connection.close() if err != nil { return err } s.onDisconnect() return nil } func (s *Session) onDisconnect() { s.Connection = nil s.resetExpireTimer() log.Printf("Client disconnected, id: %s", *s.ClientID) } func (s *Session) expireSession() { subscription.Subscriptions.RemoveSubsForChannel(s.SubscriptionChannel) } // newTime is nullable func (s *Session) updateExpireTimer(newTime *uint32) { var expiry = uint32(0) if newTime != nil { expiry = *newTime } else { expiry = uint32(0) } s.ExpiryInterval = time.Duration(expiry) * time.Second if s.Connection == nil { s.resetExpireTimer() } } func (s *Session) resetExpireTimer() { //s.expireTimer.Reset(s.ExpiryInterval) } func genClientID() *string { buf := make([]byte, 32) _, err := rand.Read(buf) if err != nil { // I don't think this can actually happen but just in case panic panic(fmt.Errorf("Failed to generate a client id, %e", err)) } id := "Client_rand_" + base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(buf) return &id } func (s *Session) VisitConnect(_ packets.ConnectPacket) { // Disconnect, we handle the connect packet in Connect, // this means that we have an estabilished connection already log.Println("WARN: Got a connect packet on an already estabilished connection") s.Disconnect(packets.DisconnectReasonCodeProtocolError) } func (s *Session) VisitPublish(p packets.PublishPacket) { subNodes := subscription.Subscriptions.GetSubscriptions(p.TopicName) for _, subNode := range subNodes { subNode.NodeLock.RLock() defer subNode.NodeLock.RUnlock() if p.QOSLevel == 0 { if p.PacketId != nil { log.Printf("Client: %v, Got publish with qos 0 and a packet id, ignoring\n", s.ClientID) return } } else if p.QOSLevel == 1 { var reason packets.PubackReasonCode = packets.PubackReasonCodeSuccess ack := packets.PubackPacket{ PacketID: *p.PacketId, Reason: reason, } s.Connection.sendPacket(ack) } else if p.QOSLevel == 2 { panic("UNIMPLEMENTED QOS level 2") } for _, sub := range subNode.Subscriptions { if !(sub.NoLocal && sub.SubscriptionChannel == s.SubscriptionChannel) { go func(sub subscription.Subscription) { sub.SubscriptionChannel <- p }(sub) } } } } func (s *Session) VisitDisconnect(p packets.DisconnectPacket) { err := s.Connection.close() if err != nil && err != io.ErrClosedPipe { log.Println("Error closing connection", err) } s.onDisconnect() } func (s *Session) VisitSubscribe(p packets.SubscribePacket) { for _, filter := range p.TopicFilters { subscription.Subscriptions.Subscribe(filter, s.SubscriptionChannel) } s.Connection.sendPacket(packets.SubAckPacket{ PacketID: p.PacketId, Reason: packets.SubackReasonGrantedQoSTwo, }) } func (s *Session) VisitUnsubscribe(p packets.UnsubscribePacket) { for _, topic := range p.Topics { subscription.Subscriptions.Unsubscribe(topic, s.SubscriptionChannel) } s.Connection.sendPacket(packets.UnsubAckPacket{ PacketID: p.PacketID, Reason: packets.UnsubackReasonSuccess, }) } func (s *Session) VisitPing(p packets.PingreqPacket) { s.Connection.sendPacket(packets.PingrespPacket{}) } func (s *Session) VisitPubackPacket(_ packets.PubackPacket) { panic("not implemented") // TODO: Implement } func (s *Session) VisitPubrecPacket(_ packets.PubrecPacket) { panic("not implemented") // TODO: Implement } func (s *Session) VisitPubrelPacket(_ packets.PubrelPacket) { panic("not implemented") // TODO: Implement } func (s *Session) VisitPubcompPacket(_ packets.PubcompPacket) { panic("not implemented") // TODO: Implement } func (s *Session) getFreePacketId() uint16 { s.freePacketID += 1 return s.freePacketID }