allow using a custom http client.

fixes nostr:nevent1qvzqqqqx25pzqm8ksn7p6aak225sed38vlzngtuwl50tf0e8ahzuzkhpmuahzgzdqyd8wumn8ghj7cmpvd5x2v3wwpexjmtpdshxuet59amrzqg4waehxw309aex2mrp0yhxgctdw4eju6t09uqzq8r9r4par63whq6px0af5uxtkkx0psydtamq6rdcva248l27l2szensns3
This commit is contained in:
fiatjaf
2025-12-18 12:01:30 -03:00
parent 4d7f7ce25d
commit 97424e363a
4 changed files with 26 additions and 54 deletions

View File

@@ -3,10 +3,11 @@ package nostr
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net/http"
"net/textproto"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -20,7 +21,7 @@ type writeRequest struct {
answer chan error answer chan error
} }
func (r *Relay) newConnection(ctx context.Context, tlsConfig *tls.Config) error { func (r *Relay) newConnection(ctx context.Context, httpClient *http.Client) error {
debugLogf("{%s} connecting!\n", r.URL) debugLogf("{%s} connecting!\n", r.URL)
dialCtx := ctx dialCtx := ctx
@@ -29,7 +30,18 @@ func (r *Relay) newConnection(ctx context.Context, tlsConfig *tls.Config) error
dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long")) dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long"))
} }
c, _, err := ws.Dial(dialCtx, r.URL, getConnectionOptions(r.requestHeader, tlsConfig)) dialOpts := &ws.DialOptions{
HTTPHeader: http.Header{
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"fiatjaf.com/nostr"},
},
CompressionMode: ws.CompressionContextTakeover,
HTTPClient: httpClient,
}
for k, v := range r.requestHeader {
dialOpts.HTTPHeader[k] = v
}
c, _, err := ws.Dial(dialCtx, r.URL, dialOpts)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -1,34 +0,0 @@
//go:build !js
package nostr
import (
"crypto/tls"
"net/http"
"net/textproto"
ws "github.com/coder/websocket"
)
var defaultConnectionOptions = &ws.DialOptions{
CompressionMode: ws.CompressionContextTakeover,
HTTPHeader: http.Header{
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"fiatjaf.com/nostr"},
},
}
func getConnectionOptions(requestHeader http.Header, tlsConfig *tls.Config) *ws.DialOptions {
if requestHeader == nil && tlsConfig == nil {
return defaultConnectionOptions
}
return &ws.DialOptions{
HTTPHeader: requestHeader,
CompressionMode: ws.CompressionContextTakeover,
HTTPClient: &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
},
}
}

View File

@@ -1,15 +0,0 @@
package nostr
import (
"crypto/tls"
"net/http"
ws "github.com/coder/websocket"
)
var emptyOptions = ws.DialOptions{}
func getConnectionOptions(_ http.Header, _ *tls.Config) *ws.DialOptions {
// on javascript we ignore everything because there is nothing else we can do
return &emptyOptions
}

View File

@@ -114,11 +114,20 @@ func (r *Relay) IsConnected() bool { return !r.closed.Load() }
// The given context here is only used during the connection phase. The long-living // The given context here is only used during the connection phase. The long-living
// relay connection will be based on the context given to NewRelay(). // relay connection will be based on the context given to NewRelay().
func (r *Relay) Connect(ctx context.Context) error { func (r *Relay) Connect(ctx context.Context) error {
return r.ConnectWithTLS(ctx, nil) return r.ConnectWithClient(ctx, nil)
} }
// ConnectWithTLS is like Connect(), but takes a special tls.Config if you need that. // ConnectWithTLS is like Connect(), but takes a special tls.Config if you need that.
func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error { func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error {
return r.ConnectWithClient(ctx, &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
})
}
// ConnectWithClient is like Connect(), but takes a special *http.Client if you need that.
func (r *Relay) ConnectWithClient(ctx context.Context, client *http.Client) error {
if r.connectionContext == nil || r.Subscriptions == nil { if r.connectionContext == nil || r.Subscriptions == nil {
return fmt.Errorf("relay must be initialized with a call to NewRelay()") return fmt.Errorf("relay must be initialized with a call to NewRelay()")
} }
@@ -127,7 +136,7 @@ func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error
return fmt.Errorf("invalid relay URL '%s'", r.URL) return fmt.Errorf("invalid relay URL '%s'", r.URL)
} }
if err := r.newConnection(ctx, tlsConfig); err != nil { if err := r.newConnection(ctx, client); err != nil {
return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err) return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err)
} }