Finish basic subscriptions

This commit is contained in:
bad 2021-10-07 22:01:52 +02:00
parent 35879183be
commit 01df3272b5
4 changed files with 113 additions and 48 deletions

42
main.go
View file

@ -1,7 +1,6 @@
package main package main
import ( import (
"bufio"
"log" "log"
"net" "net"
"runtime/debug" "runtime/debug"
@ -18,41 +17,42 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
var sessions map[string]*session.Session = make(map[string]*session.Session)
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
log.Println("Failed accepting connection ", err) log.Println("Failed accepting connection ", err)
} }
go handleConnection(conn) handleConnection(conn, sessions)
} }
} }
func handleConnection(con net.Conn) { func handleConnection(con net.Conn, sessions map[string]*session.Session) {
defer handlePanic(con) defer handlePanic(con)
reader := bufio.NewReader(con) conReq, err := session.NewConnection(con)
packet, err := packets.ReadPacket(reader)
if err != nil { if err != nil {
log.Println("Error reading packet ", err) // TODO
return panic(err)
}
connect, isConn := (*packet).(packets.ConnectPacket)
if !isConn {
log.Println("Didn't recieve a connect packet")
err := packets.DisconnectPacket{
ReasonCode: packets.DisconnectReasonCodeProtocolError,
}.Write(con)
if err != nil {
log.Println("Failed to disconnect after not recieving a connect packet", err)
}
return
} }
conn := session.NewConnection(connect, con) var sess *session.Session
sess := session.NewSession(&conn, connect) if(conReq.ConnectPakcet.ClientId != nil) {
sess, exists := sessions[*conReq.ConnectPakcet.ClientId]
if exists {
sess.ConnecionChannel <- conReq
}
}
if sess == nil {
newSess := session.NewSession(conReq)
sess = &newSess
go func() {
defer handlePanic(con)
sess.HandlerLoop() sess.HandlerLoop()
}()
}
} }
func handlePanic(con net.Conn) { func handlePanic(con net.Conn) {

View file

@ -2,6 +2,7 @@ package packets
import ( import (
"bufio" "bufio"
"bytes"
"errors" "errors"
"io" "io"
@ -60,3 +61,38 @@ func parsePublishPacket(control controlPacket) (PublishPacket, error) {
return packet, nil return packet, nil
} }
func (p PublishPacket) Write(w io.Writer) error {
buf := bytes.NewBuffer([]byte{})
err := types.WriteUTF8String(buf, p.TopicName)
if err != nil {
return err
}
if p.PacketId != nil {
err := types.WriteUint16(buf, *p.PacketId)
if err != nil {
return err
}
}
err = properties.WriteProps(buf, p.Properties.ArrayOf())
if err != nil {
return err
}
buf.Write(p.Payload)
flags := types.BoolToUint(p.Retain)
flags += uint(p.QOSLevel) << 1
flags += types.BoolToUint(p.Dup) << 3
conPack := controlPacket{
packetType: PacketTypePublish,
flags: flags,
reader: buf,
}
return conPack.write(w)
}

View file

@ -2,8 +2,9 @@ package session
import ( import (
"bufio" "bufio"
"fmt" "errors"
"io" "io"
"log"
"time" "time"
"badat.dev/maeqtt/v2/mqtt/packets" "badat.dev/maeqtt/v2/mqtt/packets"
@ -47,46 +48,63 @@ func (c *Connection) close() error {
return c.rw.Close() return c.rw.Close()
} }
func (c *Connection) packetReadLoop() { func (c *Connection) PacketReadLoop() {
for { for {
pack, err := c.readPacket() pack, err := c.readPacket()
if err == io.EOF { if err != nil {
c.ClientDisconnectedChan <- true c.ClientDisconnectedChan <- true
} else if err != nil { c.close()
panic(fmt.Errorf("Unimplemented error handling, %e", err).Error())
} else { } else {
c.PacketChannel <- *pack c.PacketChannel <- *pack
} }
} }
} }
func NewConnection(p packets.ConnectPacket, rw io.ReadWriteCloser) Connection { var FirstPackNotConnect error = errors.New("Failed to connect, first packet is not connect")
conn := Connection{}
conn.rw = rw
if p.Properties.ReceiveMaximum.Value != nil { func NewConnection(rw io.ReadWriteCloser) (ConnectionRequest, error) {
conn.RecvMax = *p.Properties.ReceiveMaximum.Value connReq := ConnectionRequest{}
conn := Connection{}
connReq.Connection = &conn
conn.rw = rw
packet, err := conn.readPacket()
conPack, isConn := (*packet).(packets.ConnectPacket)
if !isConn {
log.Println("Didn't recieve a connect packet")
err := packets.DisconnectPacket{
ReasonCode: packets.DisconnectReasonCodeProtocolError,
}.Write(rw)
if err != nil {
log.Println("Failed to disconnect after not recieving a connect packet", err)
}
return connReq, FirstPackNotConnect
}
connReq.ConnectPakcet = conPack
if conPack.Properties.ReceiveMaximum.Value != nil {
conn.RecvMax = *conPack.Properties.ReceiveMaximum.Value
} else { } else {
conn.RecvMax = 65535 conn.RecvMax = 65535
} }
conn.MaxPacketSize = p.Properties.MaximumPacketSize.Value conn.MaxPacketSize = conPack.Properties.MaximumPacketSize.Value
if p.Properties.TopicAliasMaximum.Value != nil { if conPack.Properties.TopicAliasMaximum.Value != nil {
conn.TopicAliasMax = *p.Properties.TopicAliasMaximum.Value conn.TopicAliasMax = *conPack.Properties.TopicAliasMaximum.Value
} else { } else {
conn.TopicAliasMax = 0 conn.TopicAliasMax = 0
} }
if p.Properties.RequestProblemInformation.Value != nil { if conPack.Properties.RequestProblemInformation.Value != nil {
conn.WantsRespInf = *p.Properties.RequestProblemInformation.Value != 0 conn.WantsRespInf = *conPack.Properties.RequestProblemInformation.Value != 0
} else { } else {
conn.WantsRespInf = false conn.WantsRespInf = false
} }
conn.KeepAliveInterval = time.Duration(p.KeepAliveInterval) * time.Second conn.KeepAliveInterval = time.Duration(conPack.KeepAliveInterval) * time.Second
conn.PacketChannel = make(chan packets.ClientPacket) conn.PacketChannel = make(chan packets.ClientPacket, 1)
go conn.packetReadLoop() return connReq, err
return conn
} }

View file

@ -26,6 +26,7 @@ type Session struct {
// Nullable // Nullable
Connection *Connection Connection *Connection
SubscriptionChannel chan packets.PublishPacket SubscriptionChannel chan packets.PublishPacket
ConnecionChannel chan ConnectionRequest
ExpiryInterval time.Duration // TODO ExpiryInterval time.Duration // TODO
expireTimer time.Timer // TODO expireTimer time.Timer // TODO
@ -33,23 +34,28 @@ type Session struct {
freePacketID uint16 freePacketID uint16
} }
func NewSession(conn *Connection, p packets.ConnectPacket) Session { type ConnectionRequest struct {
Connection *Connection
ConnectPakcet packets.ConnectPacket
}
func NewSession(req ConnectionRequest) Session {
sess := Session{} sess := Session{}
sess.SubscriptionChannel = make(chan packets.PublishPacket) sess.SubscriptionChannel = make(chan packets.PublishPacket)
sess.Connect(conn, p) sess.Connect(req)
return sess return sess
} }
func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) { func (s *Session) Connect(req ConnectionRequest) {
if s.Connection != nil { if s.Connection != nil {
s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver) s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver)
} }
connAck := packets.ConnackPacket{} connAck := packets.ConnackPacket{}
s.updateExpireTimer(p.Properties.SessionExpiryInterval.Value) s.updateExpireTimer(req.ConnectPakcet.Properties.SessionExpiryInterval.Value)
if p.ClientId != nil { if req.ConnectPakcet.ClientId != nil {
if s.ClientID == nil { if s.ClientID == nil {
s.ClientID = genClientID() s.ClientID = genClientID()
} }
@ -63,24 +69,29 @@ func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) {
connAck.Properties.RetainAvailable.Value = &false connAck.Properties.RetainAvailable.Value = &false
connAck.Properties.SharedSubscriptionAvailable.Value = &false connAck.Properties.SharedSubscriptionAvailable.Value = &false
s.Connection = conn s.Connection = req.Connection
err := s.Connection.sendPacket(connAck) err := s.Connection.sendPacket(connAck)
if err != nil { if err != nil {
// TODO
panic(err) panic(err)
} }
} }
// Starts a loop the recieves and responds to packets // Starts a loop the recieves and responds to packets
func (s *Session) HandlerLoop() { func (s *Session) HandlerLoop() {
go s.Connection.PacketReadLoop()
for s.Connection != nil { for s.Connection != nil {
select { select {
case packet := <-s.Connection.PacketChannel: case packet := <-s.Connection.PacketChannel:
packet.Visit(s) packet.Visit(s)
case _ = <-s.Connection.ClientDisconnectedChan: case _ = <-s.Connection.ClientDisconnectedChan:
s.onDisconnect() s.onDisconnect()
case c := <-s.ConnecionChannel:
s.Connect(c)
case subMessage := <-s.SubscriptionChannel: case subMessage := <-s.SubscriptionChannel:
//TODO, log for now subMessage.QOSLevel = 0
log.Printf("Recieved subscription message, handling UNIMPLEMENTED, message: %v", subMessage) subMessage.Dup = false
s.Connection.sendPacket(subMessage)
} }
} }
} }