Finish basic subscriptions
This commit is contained in:
parent
35879183be
commit
01df3272b5
4 changed files with 113 additions and 48 deletions
42
main.go
42
main.go
|
@ -1,7 +1,6 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
@ -18,41 +17,42 @@ func main() {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var sessions map[string]*session.Session = make(map[string]*session.Session)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
conn, err := listener.Accept()
|
conn, err := listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Failed accepting connection ", err)
|
log.Println("Failed accepting connection ", err)
|
||||||
}
|
}
|
||||||
go handleConnection(conn)
|
handleConnection(conn, sessions)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleConnection(con net.Conn) {
|
func handleConnection(con net.Conn, sessions map[string]*session.Session) {
|
||||||
defer handlePanic(con)
|
defer handlePanic(con)
|
||||||
|
|
||||||
reader := bufio.NewReader(con)
|
conReq, err := session.NewConnection(con)
|
||||||
|
|
||||||
packet, err := packets.ReadPacket(reader)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Error reading packet ", err)
|
// TODO
|
||||||
return
|
panic(err)
|
||||||
}
|
|
||||||
connect, isConn := (*packet).(packets.ConnectPacket)
|
|
||||||
if !isConn {
|
|
||||||
log.Println("Didn't recieve a connect packet")
|
|
||||||
err := packets.DisconnectPacket{
|
|
||||||
ReasonCode: packets.DisconnectReasonCodeProtocolError,
|
|
||||||
}.Write(con)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("Failed to disconnect after not recieving a connect packet", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
conn := session.NewConnection(connect, con)
|
var sess *session.Session
|
||||||
sess := session.NewSession(&conn, connect)
|
if(conReq.ConnectPakcet.ClientId != nil) {
|
||||||
|
sess, exists := sessions[*conReq.ConnectPakcet.ClientId]
|
||||||
|
if exists {
|
||||||
|
sess.ConnecionChannel <- conReq
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sess == nil {
|
||||||
|
newSess := session.NewSession(conReq)
|
||||||
|
sess = &newSess
|
||||||
|
go func() {
|
||||||
|
defer handlePanic(con)
|
||||||
sess.HandlerLoop()
|
sess.HandlerLoop()
|
||||||
|
}()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handlePanic(con net.Conn) {
|
func handlePanic(con net.Conn) {
|
||||||
|
|
|
@ -2,6 +2,7 @@ package packets
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
@ -60,3 +61,38 @@ func parsePublishPacket(control controlPacket) (PublishPacket, error) {
|
||||||
|
|
||||||
return packet, nil
|
return packet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p PublishPacket) Write(w io.Writer) error {
|
||||||
|
buf := bytes.NewBuffer([]byte{})
|
||||||
|
|
||||||
|
err := types.WriteUTF8String(buf, p.TopicName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.PacketId != nil {
|
||||||
|
err := types.WriteUint16(buf, *p.PacketId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
err = properties.WriteProps(buf, p.Properties.ArrayOf())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.Write(p.Payload)
|
||||||
|
|
||||||
|
flags := types.BoolToUint(p.Retain)
|
||||||
|
flags += uint(p.QOSLevel) << 1
|
||||||
|
flags += types.BoolToUint(p.Dup) << 3
|
||||||
|
conPack := controlPacket{
|
||||||
|
packetType: PacketTypePublish,
|
||||||
|
flags: flags,
|
||||||
|
reader: buf,
|
||||||
|
}
|
||||||
|
|
||||||
|
return conPack.write(w)
|
||||||
|
}
|
||||||
|
|
|
@ -2,8 +2,9 @@ package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"badat.dev/maeqtt/v2/mqtt/packets"
|
"badat.dev/maeqtt/v2/mqtt/packets"
|
||||||
|
@ -47,46 +48,63 @@ func (c *Connection) close() error {
|
||||||
return c.rw.Close()
|
return c.rw.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) packetReadLoop() {
|
func (c *Connection) PacketReadLoop() {
|
||||||
for {
|
for {
|
||||||
pack, err := c.readPacket()
|
pack, err := c.readPacket()
|
||||||
if err == io.EOF {
|
if err != nil {
|
||||||
c.ClientDisconnectedChan <- true
|
c.ClientDisconnectedChan <- true
|
||||||
} else if err != nil {
|
c.close()
|
||||||
panic(fmt.Errorf("Unimplemented error handling, %e", err).Error())
|
|
||||||
} else {
|
} else {
|
||||||
c.PacketChannel <- *pack
|
c.PacketChannel <- *pack
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnection(p packets.ConnectPacket, rw io.ReadWriteCloser) Connection {
|
var FirstPackNotConnect error = errors.New("Failed to connect, first packet is not connect")
|
||||||
conn := Connection{}
|
|
||||||
conn.rw = rw
|
|
||||||
|
|
||||||
if p.Properties.ReceiveMaximum.Value != nil {
|
func NewConnection(rw io.ReadWriteCloser) (ConnectionRequest, error) {
|
||||||
conn.RecvMax = *p.Properties.ReceiveMaximum.Value
|
connReq := ConnectionRequest{}
|
||||||
|
|
||||||
|
conn := Connection{}
|
||||||
|
connReq.Connection = &conn
|
||||||
|
|
||||||
|
conn.rw = rw
|
||||||
|
packet, err := conn.readPacket()
|
||||||
|
conPack, isConn := (*packet).(packets.ConnectPacket)
|
||||||
|
if !isConn {
|
||||||
|
log.Println("Didn't recieve a connect packet")
|
||||||
|
err := packets.DisconnectPacket{
|
||||||
|
ReasonCode: packets.DisconnectReasonCodeProtocolError,
|
||||||
|
}.Write(rw)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to disconnect after not recieving a connect packet", err)
|
||||||
|
}
|
||||||
|
return connReq, FirstPackNotConnect
|
||||||
|
}
|
||||||
|
connReq.ConnectPakcet = conPack
|
||||||
|
|
||||||
|
if conPack.Properties.ReceiveMaximum.Value != nil {
|
||||||
|
conn.RecvMax = *conPack.Properties.ReceiveMaximum.Value
|
||||||
} else {
|
} else {
|
||||||
conn.RecvMax = 65535
|
conn.RecvMax = 65535
|
||||||
}
|
}
|
||||||
conn.MaxPacketSize = p.Properties.MaximumPacketSize.Value
|
conn.MaxPacketSize = conPack.Properties.MaximumPacketSize.Value
|
||||||
|
|
||||||
if p.Properties.TopicAliasMaximum.Value != nil {
|
if conPack.Properties.TopicAliasMaximum.Value != nil {
|
||||||
conn.TopicAliasMax = *p.Properties.TopicAliasMaximum.Value
|
conn.TopicAliasMax = *conPack.Properties.TopicAliasMaximum.Value
|
||||||
} else {
|
} else {
|
||||||
conn.TopicAliasMax = 0
|
conn.TopicAliasMax = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.Properties.RequestProblemInformation.Value != nil {
|
if conPack.Properties.RequestProblemInformation.Value != nil {
|
||||||
conn.WantsRespInf = *p.Properties.RequestProblemInformation.Value != 0
|
conn.WantsRespInf = *conPack.Properties.RequestProblemInformation.Value != 0
|
||||||
} else {
|
} else {
|
||||||
conn.WantsRespInf = false
|
conn.WantsRespInf = false
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.KeepAliveInterval = time.Duration(p.KeepAliveInterval) * time.Second
|
conn.KeepAliveInterval = time.Duration(conPack.KeepAliveInterval) * time.Second
|
||||||
|
|
||||||
conn.PacketChannel = make(chan packets.ClientPacket)
|
conn.PacketChannel = make(chan packets.ClientPacket, 1)
|
||||||
|
|
||||||
go conn.packetReadLoop()
|
return connReq, err
|
||||||
return conn
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,7 @@ type Session struct {
|
||||||
// Nullable
|
// Nullable
|
||||||
Connection *Connection
|
Connection *Connection
|
||||||
SubscriptionChannel chan packets.PublishPacket
|
SubscriptionChannel chan packets.PublishPacket
|
||||||
|
ConnecionChannel chan ConnectionRequest
|
||||||
|
|
||||||
ExpiryInterval time.Duration // TODO
|
ExpiryInterval time.Duration // TODO
|
||||||
expireTimer time.Timer // TODO
|
expireTimer time.Timer // TODO
|
||||||
|
@ -33,23 +34,28 @@ type Session struct {
|
||||||
freePacketID uint16
|
freePacketID uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSession(conn *Connection, p packets.ConnectPacket) Session {
|
type ConnectionRequest struct {
|
||||||
|
Connection *Connection
|
||||||
|
ConnectPakcet packets.ConnectPacket
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSession(req ConnectionRequest) Session {
|
||||||
sess := Session{}
|
sess := Session{}
|
||||||
sess.SubscriptionChannel = make(chan packets.PublishPacket)
|
sess.SubscriptionChannel = make(chan packets.PublishPacket)
|
||||||
|
|
||||||
sess.Connect(conn, p)
|
sess.Connect(req)
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) {
|
func (s *Session) Connect(req ConnectionRequest) {
|
||||||
if s.Connection != nil {
|
if s.Connection != nil {
|
||||||
s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver)
|
s.Disconnect(packets.DisconnectReasonCodeSessionTakenOver)
|
||||||
}
|
}
|
||||||
connAck := packets.ConnackPacket{}
|
connAck := packets.ConnackPacket{}
|
||||||
|
|
||||||
s.updateExpireTimer(p.Properties.SessionExpiryInterval.Value)
|
s.updateExpireTimer(req.ConnectPakcet.Properties.SessionExpiryInterval.Value)
|
||||||
|
|
||||||
if p.ClientId != nil {
|
if req.ConnectPakcet.ClientId != nil {
|
||||||
if s.ClientID == nil {
|
if s.ClientID == nil {
|
||||||
s.ClientID = genClientID()
|
s.ClientID = genClientID()
|
||||||
}
|
}
|
||||||
|
@ -63,24 +69,29 @@ func (s *Session) Connect(conn *Connection, p packets.ConnectPacket) {
|
||||||
connAck.Properties.RetainAvailable.Value = &false
|
connAck.Properties.RetainAvailable.Value = &false
|
||||||
connAck.Properties.SharedSubscriptionAvailable.Value = &false
|
connAck.Properties.SharedSubscriptionAvailable.Value = &false
|
||||||
|
|
||||||
s.Connection = conn
|
s.Connection = req.Connection
|
||||||
err := s.Connection.sendPacket(connAck)
|
err := s.Connection.sendPacket(connAck)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// TODO
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Starts a loop the recieves and responds to packets
|
// Starts a loop the recieves and responds to packets
|
||||||
func (s *Session) HandlerLoop() {
|
func (s *Session) HandlerLoop() {
|
||||||
|
go s.Connection.PacketReadLoop()
|
||||||
for s.Connection != nil {
|
for s.Connection != nil {
|
||||||
select {
|
select {
|
||||||
case packet := <-s.Connection.PacketChannel:
|
case packet := <-s.Connection.PacketChannel:
|
||||||
packet.Visit(s)
|
packet.Visit(s)
|
||||||
case _ = <-s.Connection.ClientDisconnectedChan:
|
case _ = <-s.Connection.ClientDisconnectedChan:
|
||||||
s.onDisconnect()
|
s.onDisconnect()
|
||||||
|
case c := <-s.ConnecionChannel:
|
||||||
|
s.Connect(c)
|
||||||
case subMessage := <-s.SubscriptionChannel:
|
case subMessage := <-s.SubscriptionChannel:
|
||||||
//TODO, log for now
|
subMessage.QOSLevel = 0
|
||||||
log.Printf("Recieved subscription message, handling UNIMPLEMENTED, message: %v", subMessage)
|
subMessage.Dup = false
|
||||||
|
s.Connection.sendPacket(subMessage)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue