diff --git a/relay.go b/relay.go index 0ce5a3c..a2e4781 100644 --- a/relay.go +++ b/relay.go @@ -97,13 +97,21 @@ func (r *Relay) Connect(ctx context.Context) error { r.Challenges = make(chan string) r.Notices = make(chan string) + // close these channels when the connection is dropped + go func() { + <-r.ConnectionContext.Done() + close(r.Challenges) + close(r.Notices) + }() + conn := NewConnection(socket) r.Connection = conn // ping every 29 seconds - ticker := time.NewTicker(29 * time.Second) - defer ticker.Stop() go func() { + ticker := time.NewTicker(29 * time.Second) + defer ticker.Stop() + defer cancel() for { select { case <-ticker.C: @@ -118,6 +126,7 @@ func (r *Relay) Connect(ctx context.Context) error { // handling received messages go func() { + defer cancel() for { typ, message, err := conn.socket.ReadMessage() if err != nil { @@ -152,13 +161,17 @@ func (r *Relay) Connect(ctx context.Context) error { var content string json.Unmarshal(jsonMessage[1], &content) go func() { - r.Notices <- content + if r.ConnectionContext.Err() == nil { + r.Notices <- content + } }() case "AUTH": var challenge string json.Unmarshal(jsonMessage[1], &challenge) go func() { - r.Challenges <- challenge + if r.ConnectionContext.Err() == nil { + r.Challenges <- challenge + } }() case "EVENT": if len(jsonMessage) < 3 { @@ -235,8 +248,6 @@ func (r *Relay) Connect(ctx context.Context) error { } } } - - cancel() }() return nil