Finish basic subscriptions
This commit is contained in:
parent
35879183be
commit
01df3272b5
4 changed files with 113 additions and 48 deletions
42
main.go
42
main.go
|
@ -1,7 +1,6 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"log"
|
||||
"net"
|
||||
"runtime/debug"
|
||||
|
@ -18,41 +17,42 @@ func main() {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
var sessions map[string]*session.Session = make(map[string]*session.Session)
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Println("Failed accepting connection ", err)
|
||||
}
|
||||
go handleConnection(conn)
|
||||
handleConnection(conn, sessions)
|
||||
}
|
||||
}
|
||||
|
||||
func handleConnection(con net.Conn) {
|
||||
func handleConnection(con net.Conn, sessions map[string]*session.Session) {
|
||||
defer handlePanic(con)
|
||||
|
||||
reader := bufio.NewReader(con)
|
||||
|
||||
packet, err := packets.ReadPacket(reader)
|
||||
conReq, err := session.NewConnection(con)
|
||||
if err != nil {
|
||||
log.Println("Error reading packet ", err)
|
||||
return
|
||||
// TODO
|
||||
panic(err)
|
||||
}
|
||||
connect, isConn := (*packet).(packets.ConnectPacket)
|
||||
if !isConn {
|
||||
log.Println("Didn't recieve a connect packet")
|
||||
err := packets.DisconnectPacket{
|
||||
ReasonCode: packets.DisconnectReasonCodeProtocolError,
|
||||
}.Write(con)
|
||||
if err != nil {
|
||||
log.Println("Failed to disconnect after not recieving a connect packet", err)
|
||||
|
||||
var sess *session.Session
|
||||
if(conReq.ConnectPakcet.ClientId != nil) {
|
||||
sess, exists := sessions[*conReq.ConnectPakcet.ClientId]
|
||||
if exists {
|
||||
sess.ConnecionChannel <- conReq
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
conn := session.NewConnection(connect, con)
|
||||
sess := session.NewSession(&conn, connect)
|
||||
|
||||
sess.HandlerLoop()
|
||||
if sess == nil {
|
||||
newSess := session.NewSession(conReq)
|
||||
sess = &newSess
|
||||
go func() {
|
||||
defer handlePanic(con)
|
||||
sess.HandlerLoop()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func handlePanic(con net.Conn) {
|
||||
|
|
|
@ -2,6 +2,7 @@ package packets
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
|
@ -60,3 +61,38 @@ func parsePublishPacket(control controlPacket) (PublishPacket, error) {
|
|||
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
func (p PublishPacket) Write(w io.Writer) error {
|
||||
buf := bytes.NewBuffer([]byte{})
|
||||
|
||||
err := types.WriteUTF8String(buf, p.TopicName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if p.PacketId != nil {
|
||||
err := types.WriteUint16(buf, *p.PacketId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
err = properties.WriteProps(buf, p.Properties.ArrayOf())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf.Write(p.Payload)
|
||||
|
||||
flags := types.BoolToUint(p.Retain)
|
||||
flags += uint(p.QOSLevel) << 1
|
||||
flags += types.BoolToUint(p.Dup) << 3
|
||||
conPack := controlPacket{
|
||||
packetType: PacketTypePublish,
|
||||
flags: flags,
|
||||
reader: buf,
|
||||
}
|
||||
|
||||
return conPack.write(w)
|
||||
}
|
||||
|
|
|
@ -2,8 +2,9 @@ package session
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||
|
@ -47,46 +48,63 @@ func (c *Connection) close() error {
|
|||
return c.rw.Close()
|
||||
}
|
||||
|
||||
func (c *Connection) packetReadLoop() {
|
||||
func (c *Connection) PacketReadLoop() {
|
||||
for {
|
||||
pack, err := c.readPacket()
|
||||
if err == io.EOF {
|
||||
if err != nil {
|
||||
c.ClientDisconnectedChan <- true
|
||||
} else if err != nil {
|
||||
panic(fmt.Errorf("Unimplemented error handling, %e", err).Error())
|
||||
c.close()
|
||||
} else {
|
||||
c.PacketChannel <- *pack
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewConnection(p packets.ConnectPacket, rw io.ReadWriteCloser) Connection {
|
||||
conn := Connection{}
|
||||
conn.rw = rw
|
||||
var FirstPackNotConnect error = errors.New("Failed to connect, first packet is not connect")
|
||||
|
||||
if p.Properties.ReceiveMaximum.Value != nil {
|
||||
conn.RecvMax = *p.Properties.ReceiveMaximum.Value
|
||||
func NewConnection(rw io.ReadWriteCloser) (ConnectionRequest, error) {
|
||||
connReq := ConnectionRequest{}
|
||||
|
||||
conn := Connection{}
|
||||
connReq.Connection = &conn
|
||||
|
||||
conn.rw = rw
|
||||
packet, err := conn.readPacket()
|
||||
conPack, isConn := (*packet).(packets.ConnectPacket)
|
||||
if !isConn {
|
||||
log.Println("Didn't recieve a connect packet")
|
||||
err := packets.DisconnectPacket{
|
||||
ReasonCode: packets.DisconnectReasonCodeProtocolError,
|
||||
}.Write(rw)
|
||||
if err != nil {
|
||||
log.Println("Failed to disconnect after not recieving a connect packet", err)
|
||||
}
|
||||
return connReq, FirstPackNotConnect
|
||||
}
|
||||
connReq.ConnectPakcet = conPack
|
||||
|
||||
if conPack.Properties.ReceiveMaximum.Value != nil {
|
||||
conn.RecvMax = *conPack.Properties.ReceiveMaximum.Value
|
||||
} else {
|
||||
conn.RecvMax = 65535
|
||||
}
|
||||
conn.MaxPacketSize = p.Properties.MaximumPacketSize.Value
|
||||
conn.MaxPacketSize = conPack.Properties.MaximumPacketSize.Value
|
||||
|
||||
if p.Properties.TopicAliasMaximum.Value != nil {
|
||||
conn.TopicAliasMax = *p.Properties.TopicAliasMaximum.Value
|
||||
if conPack.Properties.TopicAliasMaximum.Value != nil {
|
||||
conn.TopicAliasMax = *conPack.Properties.TopicAliasMaximum.Value
|
||||
} else {
|
||||
conn.TopicAliasMax = 0
|
||||
}
|
||||
|
||||
if p.Properties.RequestProblemInformation.Value != nil {
|
||||
conn.WantsRespInf = *p.Properties.RequestProblemInformation.Value != 0
|
||||
if conPack.Properties.RequestProblemInformation.Value != nil {
|
||||
conn.WantsRespInf = *conPack.Properties.RequestProblemInformation.Value != 0
|
||||
} else {
|
||||
conn.WantsRespInf = false
|
||||
}
|
||||
|
||||
conn.KeepAliveInterval = time.Duration(p.KeepAliveInterval) * time.Second
|
||||
conn.KeepAliveInterval = time.Duration(conPack.KeepAliveInterval) * time.Second
|
||||
|
||||
conn.PacketChannel = make(chan packets.ClientPacket)
|
||||
conn.PacketChannel = make(chan packets.ClientPacket, 1)
|
||||
|
||||
go conn.packetReadLoop()
|
||||
return conn
|
||||
return connReq, err
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@ type Session struct {
|
|||
// Nullable
|
||||
Connection *Connection
|
||||
SubscriptionChannel chan packets.PublishPacket
|
||||
ConnecionChannel chan ConnectionRequest
|
||||
|
||||
ExpiryInterval time.Duration // TODO
|
||||
expireTimer time.Timer // TODO
|
||||
|
@ -33,23 +34,28 @@ type Session struct {
|
|||
freePacketID uint16
|
||||
}
|
||||
|
||||
func NewSession(conn *Connection, p packets.ConnectPacket) Session {
|
||||
type ConnectionRequest struct {
|
||||
Connection *Connection
|
||||
ConnectPakcet packets.ConnectPacket
|
||||
}
|
||||
|
||||
func NewSession(req ConnectionRequest) Session {
|
||||
sess := Session{}
|
||||
sess.SubscriptionChannel = make(chan packets.PublishPacket)
|
||||
|
||||
sess.Connect(conn, p)
|
||||
sess.Connect(req)
|
||||
return sess
|
||||
}
|
||||
|
||||
func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) {
|
||||
func (s *Session) Connect(req ConnectionRequest) {
|
||||
if s.Connection != nil {
|
||||
s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver)
|
||||
}
|
||||
connAck := packets.ConnackPacket{}
|
||||
|
||||
s.updateExpireTimer(p.Properties.SessionExpiryInterval.Value)
|
||||
s.updateExpireTimer(req.ConnectPakcet.Properties.SessionExpiryInterval.Value)
|
||||
|
||||
if p.ClientId != nil {
|
||||
if req.ConnectPakcet.ClientId != nil {
|
||||
if s.ClientID == nil {
|
||||
s.ClientID = genClientID()
|
||||
}
|
||||
|
@ -63,24 +69,29 @@ func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) {
|
|||
connAck.Properties.RetainAvailable.Value = &false
|
||||
connAck.Properties.SharedSubscriptionAvailable.Value = &false
|
||||
|
||||
s.Connection = conn
|
||||
s.Connection = req.Connection
|
||||
err := s.Connection.sendPacket(connAck)
|
||||
if err != nil {
|
||||
// TODO
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Starts a loop the recieves and responds to packets
|
||||
func (s *Session) HandlerLoop() {
|
||||
go s.Connection.PacketReadLoop()
|
||||
for s.Connection != nil {
|
||||
select {
|
||||
case packet := <-s.Connection.PacketChannel:
|
||||
packet.Visit(s)
|
||||
case _ = <-s.Connection.ClientDisconnectedChan:
|
||||
s.onDisconnect()
|
||||
case c := <-s.ConnecionChannel:
|
||||
s.Connect(c)
|
||||
case subMessage := <-s.SubscriptionChannel:
|
||||
//TODO, log for now
|
||||
log.Printf("Recieved subscription message, handling UNIMPLEMENTED, message: %v", subMessage)
|
||||
subMessage.QOSLevel = 0
|
||||
subMessage.Dup = false
|
||||
s.Connection.sendPacket(subMessage)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue