Finish basic subscriptions

This commit is contained in:
bad 2021-10-07 22:01:52 +02:00
parent 35879183be
commit 01df3272b5
4 changed files with 113 additions and 48 deletions

42
main.go
View file

@ -1,7 +1,6 @@
package main
import (
"bufio"
"log"
"net"
"runtime/debug"
@ -18,41 +17,42 @@ func main() {
log.Fatal(err)
}
var sessions map[string]*session.Session = make(map[string]*session.Session)
for {
conn, err := listener.Accept()
if err != nil {
log.Println("Failed accepting connection ", err)
}
go handleConnection(conn)
handleConnection(conn, sessions)
}
}
func handleConnection(con net.Conn) {
func handleConnection(con net.Conn, sessions map[string]*session.Session) {
defer handlePanic(con)
reader := bufio.NewReader(con)
packet, err := packets.ReadPacket(reader)
conReq, err := session.NewConnection(con)
if err != nil {
log.Println("Error reading packet ", err)
return
// TODO
panic(err)
}
connect, isConn := (*packet).(packets.ConnectPacket)
if !isConn {
log.Println("Didn't recieve a connect packet")
err := packets.DisconnectPacket{
ReasonCode: packets.DisconnectReasonCodeProtocolError,
}.Write(con)
if err != nil {
log.Println("Failed to disconnect after not recieving a connect packet", err)
var sess *session.Session
if(conReq.ConnectPakcet.ClientId != nil) {
sess, exists := sessions[*conReq.ConnectPakcet.ClientId]
if exists {
sess.ConnecionChannel <- conReq
}
return
}
conn := session.NewConnection(connect, con)
sess := session.NewSession(&conn, connect)
sess.HandlerLoop()
if sess == nil {
newSess := session.NewSession(conReq)
sess = &newSess
go func() {
defer handlePanic(con)
sess.HandlerLoop()
}()
}
}
func handlePanic(con net.Conn) {

View file

@ -2,6 +2,7 @@ package packets
import (
"bufio"
"bytes"
"errors"
"io"
@ -60,3 +61,38 @@ func parsePublishPacket(control controlPacket) (PublishPacket, error) {
return packet, nil
}
func (p PublishPacket) Write(w io.Writer) error {
buf := bytes.NewBuffer([]byte{})
err := types.WriteUTF8String(buf, p.TopicName)
if err != nil {
return err
}
if p.PacketId != nil {
err := types.WriteUint16(buf, *p.PacketId)
if err != nil {
return err
}
}
err = properties.WriteProps(buf, p.Properties.ArrayOf())
if err != nil {
return err
}
buf.Write(p.Payload)
flags := types.BoolToUint(p.Retain)
flags += uint(p.QOSLevel) << 1
flags += types.BoolToUint(p.Dup) << 3
conPack := controlPacket{
packetType: PacketTypePublish,
flags: flags,
reader: buf,
}
return conPack.write(w)
}

View file

@ -2,8 +2,9 @@ package session
import (
"bufio"
"fmt"
"errors"
"io"
"log"
"time"
"badat.dev/maeqtt/v2/mqtt/packets"
@ -47,46 +48,63 @@ func (c *Connection) close() error {
return c.rw.Close()
}
func (c *Connection) packetReadLoop() {
func (c *Connection) PacketReadLoop() {
for {
pack, err := c.readPacket()
if err == io.EOF {
if err != nil {
c.ClientDisconnectedChan <- true
} else if err != nil {
panic(fmt.Errorf("Unimplemented error handling, %e", err).Error())
c.close()
} else {
c.PacketChannel <- *pack
}
}
}
func NewConnection(p packets.ConnectPacket, rw io.ReadWriteCloser) Connection {
conn := Connection{}
conn.rw = rw
var FirstPackNotConnect error = errors.New("Failed to connect, first packet is not connect")
if p.Properties.ReceiveMaximum.Value != nil {
conn.RecvMax = *p.Properties.ReceiveMaximum.Value
func NewConnection(rw io.ReadWriteCloser) (ConnectionRequest, error) {
connReq := ConnectionRequest{}
conn := Connection{}
connReq.Connection = &conn
conn.rw = rw
packet, err := conn.readPacket()
conPack, isConn := (*packet).(packets.ConnectPacket)
if !isConn {
log.Println("Didn't recieve a connect packet")
err := packets.DisconnectPacket{
ReasonCode: packets.DisconnectReasonCodeProtocolError,
}.Write(rw)
if err != nil {
log.Println("Failed to disconnect after not recieving a connect packet", err)
}
return connReq, FirstPackNotConnect
}
connReq.ConnectPakcet = conPack
if conPack.Properties.ReceiveMaximum.Value != nil {
conn.RecvMax = *conPack.Properties.ReceiveMaximum.Value
} else {
conn.RecvMax = 65535
}
conn.MaxPacketSize = p.Properties.MaximumPacketSize.Value
conn.MaxPacketSize = conPack.Properties.MaximumPacketSize.Value
if p.Properties.TopicAliasMaximum.Value != nil {
conn.TopicAliasMax = *p.Properties.TopicAliasMaximum.Value
if conPack.Properties.TopicAliasMaximum.Value != nil {
conn.TopicAliasMax = *conPack.Properties.TopicAliasMaximum.Value
} else {
conn.TopicAliasMax = 0
}
if p.Properties.RequestProblemInformation.Value != nil {
conn.WantsRespInf = *p.Properties.RequestProblemInformation.Value != 0
if conPack.Properties.RequestProblemInformation.Value != nil {
conn.WantsRespInf = *conPack.Properties.RequestProblemInformation.Value != 0
} else {
conn.WantsRespInf = false
}
conn.KeepAliveInterval = time.Duration(p.KeepAliveInterval) * time.Second
conn.KeepAliveInterval = time.Duration(conPack.KeepAliveInterval) * time.Second
conn.PacketChannel = make(chan packets.ClientPacket)
conn.PacketChannel = make(chan packets.ClientPacket, 1)
go conn.packetReadLoop()
return conn
return connReq, err
}

View file

@ -26,6 +26,7 @@ type Session struct {
// Nullable
Connection *Connection
SubscriptionChannel chan packets.PublishPacket
ConnecionChannel chan ConnectionRequest
ExpiryInterval time.Duration // TODO
expireTimer time.Timer // TODO
@ -33,23 +34,28 @@ type Session struct {
freePacketID uint16
}
func NewSession(conn *Connection, p packets.ConnectPacket) Session {
type ConnectionRequest struct {
Connection *Connection
ConnectPakcet packets.ConnectPacket
}
func NewSession(req ConnectionRequest) Session {
sess := Session{}
sess.SubscriptionChannel = make(chan packets.PublishPacket)
sess.Connect(conn, p)
sess.Connect(req)
return sess
}
func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) {
func (s *Session) Connect(req ConnectionRequest) {
if s.Connection != nil {
s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver)
}
connAck := packets.ConnackPacket{}
s.updateExpireTimer(p.Properties.SessionExpiryInterval.Value)
s.updateExpireTimer(req.ConnectPakcet.Properties.SessionExpiryInterval.Value)
if p.ClientId != nil {
if req.ConnectPakcet.ClientId != nil {
if s.ClientID == nil {
s.ClientID = genClientID()
}
@ -63,24 +69,29 @@ func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) {
connAck.Properties.RetainAvailable.Value = &false
connAck.Properties.SharedSubscriptionAvailable.Value = &false
s.Connection = conn
s.Connection = req.Connection
err := s.Connection.sendPacket(connAck)
if err != nil {
// TODO
panic(err)
}
}
// Starts a loop the recieves and responds to packets
func (s *Session) HandlerLoop() {
go s.Connection.PacketReadLoop()
for s.Connection != nil {
select {
case packet := <-s.Connection.PacketChannel:
packet.Visit(s)
case _ = <-s.Connection.ClientDisconnectedChan:
s.onDisconnect()
case c := <-s.ConnecionChannel:
s.Connect(c)
case subMessage := <-s.SubscriptionChannel:
//TODO, log for now
log.Printf("Recieved subscription message, handling UNIMPLEMENTED, message: %v", subMessage)
subMessage.QOSLevel = 0
subMessage.Dup = false
s.Connection.sendPacket(subMessage)
}
}
}