From 3b44ab9381b6ef9cdf3223265106aa03f2c2aaad Mon Sep 17 00:00:00 2001 From: fiatjaf Date: Sun, 25 Jun 2023 00:17:57 -0300 Subject: [PATCH] refactor these tests to ensure nested subscriptions are not blocking each other. --- subscription_test.go | 50 ++++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/subscription_test.go b/subscription_test.go index 20319c8..fb518bd 100644 --- a/subscription_test.go +++ b/subscription_test.go @@ -2,11 +2,13 @@ package nostr import ( "context" + "fmt" + "sync/atomic" "testing" "time" ) -const RELAY = "wss://relay.damus.io" +const RELAY = "wss://relay.nostr.band" // test if we can fetch a couple of random events func TestSubscribe(t *testing.T) { @@ -51,19 +53,19 @@ func TestNestedSubscriptions(t *testing.T) { rl := mustRelayConnect(RELAY) defer rl.Close() - // fetch any note - sub, err := rl.Subscribe(context.Background(), Filters{{Kinds: []int{1}, Limit: 1}}) + n := atomic.Uint32{} + + // fetch 2 replies to a note + sub, err := rl.Subscribe(context.Background(), Filters{{Kinds: []int{1}, Tags: TagMap{"e": []string{"0e34a74f8547e3b95d52a2543719b109fd0312aba144e2ef95cba043f42fe8c5"}}, Limit: 3}}) if err != nil { t.Errorf("subscription 1 failed: %v", err) return } - timeout := time.After(5 * time.Second) - for { select { case event := <-sub.Events: - // now fetch author of this event + // now fetch author of this sub, err := rl.Subscribe(context.Background(), Filters{{Kinds: []int{0}, Authors: []string{event.PubKey}, Limit: 1}}) if err != nil { t.Errorf("subscription 2 failed: %v", err) @@ -73,30 +75,28 @@ func TestNestedSubscriptions(t *testing.T) { for { select { case <-sub.Events: - // now mentions of this person - sub, err := rl.Subscribe(context.Background(), Filters{{Kinds: []int{1}, Tags: TagMap{"p": []string{event.PubKey}}, Limit: 1}}) - if err != nil { - t.Errorf("subscription 3 failed: %v", err) + // do another subscription here in "sync" mode, just so we're sure things are not blocking + rl.QuerySync(context.Background(), Filter{Limit: 1}) + + n.Add(1) + if n.Load() == 3 { + // if we get here it means the test passed return } - - for { - select { - case <-sub.Events: - // if we get here safely we won - return - case <-timeout: - t.Errorf("timeout 3") - } - } - case <-timeout: - t.Errorf("timeout 2") + case <-sub.Context.Done(): + goto end + case <-sub.EndOfStoredEvents: + sub.Unsub() } } - case <-rl.Context().Done(): + end: + fmt.Println("") + case <-sub.EndOfStoredEvents: + sub.Unsub() + return + case <-sub.Context.Done(): t.Errorf("connection closed: %v", rl.Context().Err()) - case <-timeout: - t.Errorf("timeout 1") + return } } }