diff --git a/main.go b/main.go index 791a6b0..1045206 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - "bufio" "log" "net" "runtime/debug" @@ -18,41 +17,42 @@ func main() { log.Fatal(err) } + var sessions map[string]*session.Session = make(map[string]*session.Session) + for { conn, err := listener.Accept() if err != nil { log.Println("Failed accepting connection ", err) } - go handleConnection(conn) + handleConnection(conn, sessions) } } -func handleConnection(con net.Conn) { +func handleConnection(con net.Conn, sessions map[string]*session.Session) { defer handlePanic(con) - reader := bufio.NewReader(con) - - packet, err := packets.ReadPacket(reader) + conReq, err := session.NewConnection(con) if err != nil { - log.Println("Error reading packet ", err) - return + // TODO + panic(err) } - connect, isConn := (*packet).(packets.ConnectPacket) - if !isConn { - log.Println("Didn't recieve a connect packet") - err := packets.DisconnectPacket{ - ReasonCode: packets.DisconnectReasonCodeProtocolError, - }.Write(con) - if err != nil { - log.Println("Failed to disconnect after not recieving a connect packet", err) + + var sess *session.Session + if(conReq.ConnectPakcet.ClientId != nil) { + sess, exists := sessions[*conReq.ConnectPakcet.ClientId] + if exists { + sess.ConnecionChannel <- conReq } - return } - conn := session.NewConnection(connect, con) - sess := session.NewSession(&conn, connect) - - sess.HandlerLoop() + if sess == nil { + newSess := session.NewSession(conReq) + sess = &newSess + go func() { + defer handlePanic(con) + sess.HandlerLoop() + }() + } } func handlePanic(con net.Conn) { diff --git a/mqtt/packets/Publish.go b/mqtt/packets/Publish.go index d33f1e1..63f7a0e 100644 --- a/mqtt/packets/Publish.go +++ b/mqtt/packets/Publish.go @@ -2,6 +2,7 @@ package packets import ( "bufio" + "bytes" "errors" "io" @@ -60,3 +61,38 @@ func parsePublishPacket(control controlPacket) (PublishPacket, error) { return packet, nil } + +func (p PublishPacket) Write(w io.Writer) error { + buf := bytes.NewBuffer([]byte{}) + + err := types.WriteUTF8String(buf, p.TopicName) + if err != nil { + return err + } + + if p.PacketId != nil { + err := types.WriteUint16(buf, *p.PacketId) + if err != nil { + return err + } + } + + + err = properties.WriteProps(buf, p.Properties.ArrayOf()) + if err != nil { + return err + } + + buf.Write(p.Payload) + + flags := types.BoolToUint(p.Retain) + flags += uint(p.QOSLevel) << 1 + flags += types.BoolToUint(p.Dup) << 3 + conPack := controlPacket{ + packetType: PacketTypePublish, + flags: flags, + reader: buf, + } + + return conPack.write(w) +} diff --git a/session/connection.go b/session/connection.go index d9eff1e..afdc69a 100644 --- a/session/connection.go +++ b/session/connection.go @@ -2,8 +2,9 @@ package session import ( "bufio" - "fmt" + "errors" "io" + "log" "time" "badat.dev/maeqtt/v2/mqtt/packets" @@ -47,46 +48,63 @@ func (c *Connection) close() error { return c.rw.Close() } -func (c *Connection) packetReadLoop() { +func (c *Connection) PacketReadLoop() { for { pack, err := c.readPacket() - if err == io.EOF { + if err != nil { c.ClientDisconnectedChan <- true - } else if err != nil { - panic(fmt.Errorf("Unimplemented error handling, %e", err).Error()) + c.close() } else { c.PacketChannel <- *pack } } } -func NewConnection(p packets.ConnectPacket, rw io.ReadWriteCloser) Connection { - conn := Connection{} - conn.rw = rw +var FirstPackNotConnect error = errors.New("Failed to connect, first packet is not connect") - if p.Properties.ReceiveMaximum.Value != nil { - conn.RecvMax = *p.Properties.ReceiveMaximum.Value +func NewConnection(rw io.ReadWriteCloser) (ConnectionRequest, error) { + connReq := ConnectionRequest{} + + conn := Connection{} + connReq.Connection = &conn + + conn.rw = rw + packet, err := conn.readPacket() + conPack, isConn := (*packet).(packets.ConnectPacket) + if !isConn { + log.Println("Didn't recieve a connect packet") + err := packets.DisconnectPacket{ + ReasonCode: packets.DisconnectReasonCodeProtocolError, + }.Write(rw) + if err != nil { + log.Println("Failed to disconnect after not recieving a connect packet", err) + } + return connReq, FirstPackNotConnect + } + connReq.ConnectPakcet = conPack + + if conPack.Properties.ReceiveMaximum.Value != nil { + conn.RecvMax = *conPack.Properties.ReceiveMaximum.Value } else { conn.RecvMax = 65535 } - conn.MaxPacketSize = p.Properties.MaximumPacketSize.Value + conn.MaxPacketSize = conPack.Properties.MaximumPacketSize.Value - if p.Properties.TopicAliasMaximum.Value != nil { - conn.TopicAliasMax = *p.Properties.TopicAliasMaximum.Value + if conPack.Properties.TopicAliasMaximum.Value != nil { + conn.TopicAliasMax = *conPack.Properties.TopicAliasMaximum.Value } else { conn.TopicAliasMax = 0 } - if p.Properties.RequestProblemInformation.Value != nil { - conn.WantsRespInf = *p.Properties.RequestProblemInformation.Value != 0 + if conPack.Properties.RequestProblemInformation.Value != nil { + conn.WantsRespInf = *conPack.Properties.RequestProblemInformation.Value != 0 } else { conn.WantsRespInf = false } - conn.KeepAliveInterval = time.Duration(p.KeepAliveInterval) * time.Second + conn.KeepAliveInterval = time.Duration(conPack.KeepAliveInterval) * time.Second - conn.PacketChannel = make(chan packets.ClientPacket) + conn.PacketChannel = make(chan packets.ClientPacket, 1) - go conn.packetReadLoop() - return conn + return connReq, err } diff --git a/session/session.go b/session/session.go index 65dcfab..30dd068 100644 --- a/session/session.go +++ b/session/session.go @@ -26,6 +26,7 @@ type Session struct { // Nullable Connection *Connection SubscriptionChannel chan packets.PublishPacket + ConnecionChannel chan ConnectionRequest ExpiryInterval time.Duration // TODO expireTimer time.Timer // TODO @@ -33,23 +34,28 @@ type Session struct { freePacketID uint16 } -func NewSession(conn *Connection, p packets.ConnectPacket) Session { +type ConnectionRequest struct { + Connection *Connection + ConnectPakcet packets.ConnectPacket +} + +func NewSession(req ConnectionRequest) Session { sess := Session{} sess.SubscriptionChannel = make(chan packets.PublishPacket) - sess.Connect(conn, p) + sess.Connect(req) return sess } -func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) { +func (s *Session) Connect(req ConnectionRequest) { if s.Connection != nil { s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver) } connAck := packets.ConnackPacket{} - s.updateExpireTimer(p.Properties.SessionExpiryInterval.Value) + s.updateExpireTimer(req.ConnectPakcet.Properties.SessionExpiryInterval.Value) - if p.ClientId != nil { + if req.ConnectPakcet.ClientId != nil { if s.ClientID == nil { s.ClientID = genClientID() } @@ -63,24 +69,29 @@ func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) { connAck.Properties.RetainAvailable.Value = &false connAck.Properties.SharedSubscriptionAvailable.Value = &false - s.Connection = conn + s.Connection = req.Connection err := s.Connection.sendPacket(connAck) if err != nil { + // TODO panic(err) } } // Starts a loop the recieves and responds to packets func (s *Session) HandlerLoop() { + go s.Connection.PacketReadLoop() for s.Connection != nil { select { case packet := <-s.Connection.PacketChannel: packet.Visit(s) case _ = <-s.Connection.ClientDisconnectedChan: s.onDisconnect() + case c := <-s.ConnecionChannel: + s.Connect(c) case subMessage := <-s.SubscriptionChannel: - //TODO, log for now - log.Printf("Recieved subscription message, handling UNIMPLEMENTED, message: %v", subMessage) + subMessage.QOSLevel = 0 + subMessage.Dup = false + s.Connection.sendPacket(subMessage) } } }