diff --git a/connection.go b/connection.go index a6b2291..646789b 100644 --- a/connection.go +++ b/connection.go @@ -34,7 +34,13 @@ func NewConnection( requestHeader http.Header, tlsConfig *tls.Config, ) (*Connection, error) { - c, _, err := ws.Dial(ctx, url, getConnectionOptions(requestHeader, tlsConfig)) + dialCtx := ctx + if _, ok := dialCtx.Deadline(); !ok { + // if no timeout is set, force it to 7 seconds + dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long")) + } + + c, _, err := ws.Dial(dialCtx, url, getConnectionOptions(requestHeader, tlsConfig)) if err != nil { return nil, err } diff --git a/pool.go b/pool.go index 138f24f..764faa8 100644 --- a/pool.go +++ b/pool.go @@ -145,17 +145,10 @@ func (pool *Pool) EnsureRelay(url string) (*Relay, error) { return relay, nil } + relay = NewRelay(pool.Context, url, pool.relayOptions) // try to connect // we use this ctx here so when the pool dies everything dies - ctx, cancel := context.WithTimeoutCause( - pool.Context, - time.Second*7, - errors.New("connecting to the relay took too long"), - ) - defer cancel() - - relay = NewRelay(pool.Context, url, pool.relayOptions) - if err := relay.Connect(ctx); err != nil { + if err := relay.Connect(pool.Context); err != nil { if pool.penaltyBox != nil { // putting relay in penalty box pool.penaltyBoxMu.Lock() @@ -469,7 +462,7 @@ func (pool *Pool) subMany( subscribe: sub, err = relay.Subscribe(ctx, filter, opts) if err != nil { - debugLogf("%s reconnecting because subscription died\n", nm) + debugLogf("%s reconnecting because subscription died: %s\n", nm, err) goto reconnect } diff --git a/relay.go b/relay.go index 34b232a..9e12e42 100644 --- a/relay.go +++ b/relay.go @@ -117,11 +117,6 @@ func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error return fmt.Errorf("invalid relay URL '%s'", r.URL) } - if _, ok := ctx.Deadline(); !ok { - // if no timeout is set, force it to 7 seconds - ctx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long")) - } - conn, err := NewConnection(ctx, r.URL, r.handleMessage, r.requestHeader, tlsConfig) if err != nil { return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err) @@ -230,10 +225,10 @@ func (r *Relay) WriteWithError(msg []byte) error { ch := make(chan error) select { case r.Connection.writeQueue <- writeRequest{msg: msg, answer: ch}: - case <-r.Connection.closedNotify: - return fmt.Errorf("failed to write to %s: ", r.URL) case <-r.connectionContext.Done(): return fmt.Errorf("failed to write to %s: %w", r.URL, context.Cause(r.connectionContext)) + case <-r.Connection.closedNotify: + return fmt.Errorf("failed to write to %s: ", r.URL) } return <-ch }