package subscription

import (
	"testing"

	"badat.dev/maeqtt/v2/mqtt/packets"
)

// mustParseTopic parses s with packets.ParseTopic and aborts the calling test
// if parsing fails. Without this, a malformed fixture would silently yield a
// zero-value topic and the match assertions below would test nothing useful.
func mustParseTopic(t *testing.T, s string) packets.Topic {
	t.Helper()
	topic, err := packets.ParseTopic(s)
	if err != nil {
		t.Fatalf("ParseTopic(%q) failed: %v", s, err)
	}
	return topic
}

// assertMatches builds a fresh subscription tree, subscribes a new channel to
// topic with MaxQoS 1, and then looks up topicName.
//
// When shouldMatch is true, exactly one subscription must be returned. When it
// is false, none may be returned. Any other count is reported as a test error.
func assertMatches(topic packets.Topic, topicName string, shouldMatch bool, t *testing.T) {
	t.Helper() // attribute failures to the calling test line

	tree := newSubscriptionTreeNode()
	channel := make(SubscriptionChannel)
	topicFilter := packets.TopicFilter{
		Topic:  topic,
		MaxQoS: 1,
	}
	tree.Subscribe(topicFilter, channel)

	subs := tree.GetSubscriptions(topicName)

	// Only one subscription exists in the tree, so the only valid counts are
	// 0 (no match) and 1 (match).
	want := 0
	if shouldMatch {
		want = 1
	}
	if got := len(subs); got != want {
		if shouldMatch {
			t.Errorf("Topic %v did not match %q (got %d subscriptions, want 1)", topic, topicName, got)
		} else {
			t.Errorf("Topic %v matched %q (it wasn't supposed to; got %d subscriptions)", topic, topicName, got)
		}
	}
}

func TestSubscribe(t *testing.T) {
	topic := mustParseTopic(t, "a/b/c")
	assertMatches(topic, "a/b/c", true, t)
	assertMatches(topic, "a/c/c", false, t)
	assertMatches(topic, "b/b/c", false, t)
	assertMatches(topic, "aaa/c/a", false, t)
}

func TestSingleLevelWildcard(t *testing.T) {
	topic := mustParseTopic(t, "a/+/c")
	assertMatches(topic, "a/b/c", true, t)
	assertMatches(topic, "a/c/c", true, t)
	assertMatches(topic, "a/b/d", false, t)
	assertMatches(topic, "aaa/c/a", false, t)

	topic = mustParseTopic(t, "+/+/+")
	assertMatches(topic, "a/b/c", true, t)
	assertMatches(topic, "a/c/c", true, t)
	assertMatches(topic, "a/b/d/e", false, t)
	// Each '+' matches exactly one topic level, so a three-level filter must
	// not match a two-level name (MQTT 3.1.1 §4.7.1.3).
	assertMatches(topic, "c/a", false, t)
}

func TestMultiLevelWildcard(t *testing.T) {
	topic := mustParseTopic(t, "a/b/c/#")
	// '#' also matches the parent level itself (MQTT 3.1.1 §4.7.1.2).
	assertMatches(topic, "a/b/c", true, t)
	assertMatches(topic, "a/b/c/a", true, t)
	assertMatches(topic, "a/b/c/d", true, t)
	assertMatches(topic, "a/b/c/f", true, t)
	assertMatches(topic, "a/b/d/a", false, t)
}