diff --git a/relay.go b/relay.go index 3c88bd6..b8932fc 100644 --- a/relay.go +++ b/relay.go @@ -61,99 +61,104 @@ func (r *Relay) Connect() error { } conn := NewConnection(socket) + r.Connection = conn - for { - typ, message, err := conn.socket.ReadMessage() - if err != nil { - return fmt.Errorf("read error: %w", err) - } - if typ == websocket.PingMessage { - conn.WriteMessage(websocket.PongMessage, nil) - continue - } - - if typ != websocket.TextMessage || len(message) == 0 || message[0] != '[' { - continue - } - - var jsonMessage []json.RawMessage - err = json.Unmarshal(message, &jsonMessage) - if err != nil { - continue - } - - if len(jsonMessage) < 2 { - continue - } - - var label string - json.Unmarshal(jsonMessage[0], &label) - - switch label { - case "NOTICE": - var content string - json.Unmarshal(jsonMessage[1], &content) - r.Notices <- content - case "EVENT": - if len(jsonMessage) < 3 { + go func() { + for { + typ, message, err := conn.socket.ReadMessage() + if err != nil { + continue + } + if typ == websocket.PingMessage { + conn.WriteMessage(websocket.PongMessage, nil) continue } - var channel string - json.Unmarshal(jsonMessage[1], &channel) - if subscription, ok := r.subscriptions.Load(channel); ok { - var event Event - json.Unmarshal(jsonMessage[2], &event) - - // check signature of all received events, ignore invalid - ok, err := event.CheckSignature() - if !ok { - errmsg := "" - if err != nil { - errmsg = err.Error() - } - log.Printf("bad signature: %s", errmsg) - continue - } - - // check if the event matches the desired filter, ignore otherwise - if !subscription.filters.Match(&event) { - continue - } - - if !subscription.stopped { - subscription.Events <- event - } + if typ != websocket.TextMessage || len(message) == 0 || message[0] != '[' { + continue } - case "EOSE": + + var jsonMessage []json.RawMessage + err = json.Unmarshal(message, &jsonMessage) + if err != nil { + continue + } + if len(jsonMessage) < 2 { continue } - var channel string - json.Unmarshal(jsonMessage[1], &channel) - if subscription, ok := r.subscriptions.Load(channel); ok { - subscription.EndOfStoredEvents <- struct{}{} - } - case "OK": - if len(jsonMessage) < 3 { - continue - } - var ( - eventId string - ok bool - ) - json.Unmarshal(jsonMessage[1], &eventId) - json.Unmarshal(jsonMessage[2], &ok) - if statusChan, ok := r.statusChans.Load(eventId); ok { - if ok { - statusChan <- PublishStatusSucceeded - } else { - statusChan <- PublishStatusFailed + var label string + json.Unmarshal(jsonMessage[0], &label) + + switch label { + case "NOTICE": + var content string + json.Unmarshal(jsonMessage[1], &content) + r.Notices <- content + case "EVENT": + if len(jsonMessage) < 3 { + continue + } + + var channel string + json.Unmarshal(jsonMessage[1], &channel) + if subscription, ok := r.subscriptions.Load(channel); ok { + var event Event + json.Unmarshal(jsonMessage[2], &event) + + // check signature of all received events, ignore invalid + ok, err := event.CheckSignature() + if !ok { + errmsg := "" + if err != nil { + errmsg = err.Error() + } + log.Printf("bad signature: %s", errmsg) + continue + } + + // check if the event matches the desired filter, ignore otherwise + if !subscription.filters.Match(&event) { + continue + } + + if !subscription.stopped { + subscription.Events <- event + } + } + case "EOSE": + if len(jsonMessage) < 2 { + continue + } + var channel string + json.Unmarshal(jsonMessage[1], &channel) + if subscription, ok := r.subscriptions.Load(channel); ok { + subscription.EndOfStoredEvents <- struct{}{} + } + case "OK": + if len(jsonMessage) < 3 { + continue + } + var ( + eventId string + ok bool + ) + json.Unmarshal(jsonMessage[1], &eventId) + json.Unmarshal(jsonMessage[2], &ok) + + if statusChan, ok := r.statusChans.Load(eventId); ok { + if ok { + statusChan <- PublishStatusSucceeded + } else { + statusChan <- PublishStatusFailed + } } } } - } + }() + + return nil } func (r Relay) Publish(event Event) chan Status { @@ -194,6 +199,10 @@ func (r Relay) Publish(event Event) chan Status { } func (r *Relay) Subscribe(filters Filters) *Subscription { + if r.Connection == nil { + panic(fmt.Errorf("must call .Connect() first before calling .Subscribe()")) + } + random := make([]byte, 7) rand.Read(random) id := hex.EncodeToString(random) @@ -201,10 +210,13 @@ func (r *Relay) Subscribe(filters Filters) *Subscription { } func (r *Relay) subscribe(id string, filters Filters) *Subscription { - sub := Subscription{} - sub.id = id + sub := Subscription{ + conn: r.Connection, + id: id, + Events: make(chan Event), + EndOfStoredEvents: make(chan struct{}), + } - sub.Events = make(chan Event) r.subscriptions.Store(sub.id, &sub) sub.Sub(filters)