address closeMutex deadlock by canceling the relay connection context on doClose().

This commit is contained in:
fiatjaf
2025-08-23 09:54:36 -03:00
parent c2635c1f20
commit 69c0981b51
2 changed files with 8 additions and 3 deletions

View File

@@ -5,6 +5,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt"
"io" "io"
"net/http" "net/http"
"sync" "sync"
@@ -19,6 +20,7 @@ var ErrDisconnected = errors.New("<disconnected>")
// Connection represents a websocket connection to a Nostr relay. // Connection represents a websocket connection to a Nostr relay.
type connection struct { type connection struct {
conn *ws.Conn conn *ws.Conn
cancel context.CancelCauseFunc
writeQueue chan writeRequest writeQueue chan writeRequest
closed *atomic.Bool closed *atomic.Bool
closedNotify chan struct{} closedNotify chan struct{}
@@ -32,6 +34,7 @@ type writeRequest struct {
func newConnection( func newConnection(
ctx context.Context, ctx context.Context,
cancel context.CancelCauseFunc,
url string, url string,
handleMessage func(string), handleMessage func(string),
requestHeader http.Header, requestHeader http.Header,
@@ -62,6 +65,7 @@ func newConnection(
conn := &connection{ conn := &connection{
conn: c, conn: c,
cancel: cancel,
writeQueue: writeQueue, writeQueue: writeQueue,
closed: &atomic.Bool{}, closed: &atomic.Bool{},
closedNotify: make(chan struct{}), closedNotify: make(chan struct{}),
@@ -81,8 +85,8 @@ func newConnection(
err := c.Ping(ctx) err := c.Ping(ctx)
cancel() cancel()
if err != nil { if err != nil {
conn.doClose(ws.StatusAbnormalClosure, "ping took too long")
debugLogf("{%s} closing!, ping failed: '%s'\n", url, err) debugLogf("{%s} closing!, ping failed: '%s'\n", url, err)
conn.doClose(ws.StatusAbnormalClosure, "ping took too long")
return return
} }
case wr := <-writeQueue: case wr := <-writeQueue:
@@ -91,11 +95,11 @@ func newConnection(
err := c.Write(ctx, ws.MessageText, wr.msg) err := c.Write(ctx, ws.MessageText, wr.msg)
cancel() cancel()
if err != nil { if err != nil {
debugLogf("{%s} closing!, write failed: '%s'\n", url, err)
conn.doClose(ws.StatusAbnormalClosure, "write failed") conn.doClose(ws.StatusAbnormalClosure, "write failed")
if wr.answer != nil { if wr.answer != nil {
wr.answer <- err wr.answer <- err
} }
debugLogf("{%s} closing!, write failed: '%s'\n", url, err)
return return
} }
if wr.answer != nil { if wr.answer != nil {
@@ -138,6 +142,7 @@ func (c *connection) doClose(code ws.StatusCode, reason string) {
wasClosed := c.closed.Swap(true) wasClosed := c.closed.Swap(true)
if !wasClosed { if !wasClosed {
c.conn.Close(code, reason) c.conn.Close(code, reason)
c.cancel(fmt.Errorf("doClose(): %s", reason))
c.closeMutex.Lock() c.closeMutex.Lock()
close(c.closedNotify) close(c.closedNotify)
close(c.writeQueue) close(c.writeQueue)

View File

@@ -118,7 +118,7 @@ func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error
return fmt.Errorf("invalid relay URL '%s'", r.URL) return fmt.Errorf("invalid relay URL '%s'", r.URL)
} }
conn, err := newConnection(ctx, r.URL, r.handleMessage, r.requestHeader, tlsConfig) conn, err := newConnection(ctx, r.connectionContextCancel, r.URL, r.handleMessage, r.requestHeader, tlsConfig)
if err != nil { if err != nil {
return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err) return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err)
} }