diff --git a/connection.go b/connection.go index 98dd850..141c565 100644 --- a/connection.go +++ b/connection.go @@ -3,10 +3,11 @@ package nostr import ( "bytes" "context" - "crypto/tls" "errors" "fmt" "io" + "net/http" + "net/textproto" "sync/atomic" "time" @@ -20,7 +21,7 @@ type writeRequest struct { answer chan error } -func (r *Relay) newConnection(ctx context.Context, tlsConfig *tls.Config) error { +func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) error { debugLogf("{%s} connecting!\n", r.URL) dialCtx := ctx @@ -29,7 +30,18 @@ func (r *Relay) newConnection(ctx context.Context, tlsConfig *tls.Config) error dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long")) } - c, _, err := ws.Dial(dialCtx, r.URL, getConnectionOptions(r.requestHeader, tlsConfig)) + dialOpts := &ws.DialOptions{ + HTTPHeader: http.Header{ + textproto.CanonicalMIMEHeaderKey("User-Agent"): {"fiatjaf.com/nostr"}, + }, + CompressionMode: ws.CompressionContextTakeover, + HTTPClient: httpClient, + } + for k, v := range r.requestHeader { + dialOpts.HTTPHeader[k] = v + } + + c, _, err := ws.Dial(dialCtx, r.URL, dialOpts) if err != nil { return err } diff --git a/connection_options.go b/connection_options.go deleted file mode 100644 index c14b933..0000000 --- a/connection_options.go +++ /dev/null @@ -1,34 +0,0 @@ -//go:build !js - -package nostr - -import ( - "crypto/tls" - "net/http" - "net/textproto" - - ws "github.com/coder/websocket" -) - -var defaultConnectionOptions = &ws.DialOptions{ - CompressionMode: ws.CompressionContextTakeover, - HTTPHeader: http.Header{ - textproto.CanonicalMIMEHeaderKey("User-Agent"): {"fiatjaf.com/nostr"}, - }, -} - -func getConnectionOptions(requestHeader http.Header, tlsConfig *tls.Config) *ws.DialOptions { - if requestHeader == nil && tlsConfig == nil { - return defaultConnectionOptions - } - - return &ws.DialOptions{ - HTTPHeader: requestHeader, - CompressionMode: ws.CompressionContextTakeover, - HTTPClient: &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: tlsConfig, - }, - }, - } -} diff --git a/connection_options_js.go b/connection_options_js.go deleted file mode 100644 index 3e80025..0000000 --- a/connection_options_js.go +++ /dev/null @@ -1,15 +0,0 @@ -package nostr - -import ( - "crypto/tls" - "net/http" - - ws "github.com/coder/websocket" -) - -var emptyOptions = ws.DialOptions{} - -func getConnectionOptions(_ http.Header, _ *tls.Config) *ws.DialOptions { - // on javascript we ignore everything because there is nothing else we can do - return &emptyOptions -} diff --git a/relay.go b/relay.go index d81c279..c7e9383 100644 --- a/relay.go +++ b/relay.go @@ -114,11 +114,20 @@ func (r *Relay) IsConnected() bool { return !r.closed.Load() } // The given context here is only used during the connection phase. The long-living // relay connection will be based on the context given to NewRelay(). func (r *Relay) Connect(ctx context.Context) error { - return r.ConnectWithTLS(ctx, nil) + return r.ConnectWithClient(ctx, nil) } // ConnectWithTLS is like Connect(), but takes a special tls.Config if you need that. func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error { + return r.ConnectWithClient(ctx, &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + }) +} + +// ConnectWithClient is like Connect(), but takes a special *http.Client if you need that. +func (r *Relay) ConnectWithClient(ctx context.Context, client *http.Client) error { if r.connectionContext == nil || r.Subscriptions == nil { return fmt.Errorf("relay must be initialized with a call to NewRelay()") } @@ -127,7 +136,7 @@ func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error return fmt.Errorf("invalid relay URL '%s'", r.URL) } - if err := r.newConnection(ctx, tlsConfig); err != nil { + if err := r.newConnection(ctx, client); err != nil { return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err) }