*connection to be an integral part of *Relay.
This commit is contained in:
@@ -7,8 +7,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -17,30 +15,13 @@ import (
|
|||||||
|
|
||||||
var ErrDisconnected = errors.New("<disconnected>")
|
var ErrDisconnected = errors.New("<disconnected>")
|
||||||
|
|
||||||
// Connection represents a websocket connection to a Nostr relay.
|
|
||||||
type connection struct {
|
|
||||||
conn *ws.Conn
|
|
||||||
cancel context.CancelCauseFunc
|
|
||||||
writeQueue chan writeRequest
|
|
||||||
closed *atomic.Bool
|
|
||||||
closedNotify chan struct{}
|
|
||||||
closeMutex sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
type writeRequest struct {
|
type writeRequest struct {
|
||||||
msg []byte
|
msg []byte
|
||||||
answer chan error
|
answer chan error
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConnection(
|
func (r *Relay) newConnection(ctx context.Context, tlsConfig *tls.Config) error {
|
||||||
ctx context.Context,
|
debugLogf("{%s} connecting!\n", r.URL)
|
||||||
cancel context.CancelCauseFunc,
|
|
||||||
url string,
|
|
||||||
handleMessage func(string),
|
|
||||||
requestHeader http.Header,
|
|
||||||
tlsConfig *tls.Config,
|
|
||||||
) (*connection, error) {
|
|
||||||
debugLogf("{%s} connecting!\n", url)
|
|
||||||
|
|
||||||
dialCtx := ctx
|
dialCtx := ctx
|
||||||
if _, ok := dialCtx.Deadline(); !ok {
|
if _, ok := dialCtx.Deadline(); !ok {
|
||||||
@@ -48,9 +29,9 @@ func newConnection(
|
|||||||
dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long"))
|
dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long"))
|
||||||
}
|
}
|
||||||
|
|
||||||
c, _, err := ws.Dial(dialCtx, url, getConnectionOptions(requestHeader, tlsConfig))
|
c, _, err := ws.Dial(dialCtx, r.URL, getConnectionOptions(r.requestHeader, tlsConfig))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
c.SetReadLimit(2 << 24) // 33MB
|
c.SetReadLimit(2 << 24) // 33MB
|
||||||
|
|
||||||
@@ -63,40 +44,37 @@ func newConnection(
|
|||||||
writeQueue := make(chan writeRequest)
|
writeQueue := make(chan writeRequest)
|
||||||
readQueue := make(chan string)
|
readQueue := make(chan string)
|
||||||
|
|
||||||
conn := &connection{
|
r.conn = c
|
||||||
conn: c,
|
r.writeQueue = writeQueue
|
||||||
cancel: cancel,
|
r.closed = &atomic.Bool{}
|
||||||
writeQueue: writeQueue,
|
r.closedNotify = make(chan struct{})
|
||||||
closed: &atomic.Bool{},
|
|
||||||
closedNotify: make(chan struct{}),
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
conn.doClose(ws.StatusNormalClosure, "")
|
r.closeConnection(ws.StatusNormalClosure, "")
|
||||||
debugLogf("{%s} closing!, context done: '%s'\n", url, context.Cause(ctx))
|
debugLogf("{%s} closing!, context done: '%s'\n", r.URL, context.Cause(ctx))
|
||||||
return
|
return
|
||||||
case <-conn.closedNotify:
|
case <-r.closedNotify:
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
ctx, cancel := context.WithTimeoutCause(ctx, time.Millisecond*800, errors.New("ping took too long"))
|
ctx, cancel := context.WithTimeoutCause(ctx, time.Millisecond*800, errors.New("ping took too long"))
|
||||||
err := c.Ping(ctx)
|
err := c.Ping(ctx)
|
||||||
cancel()
|
cancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
debugLogf("{%s} closing!, ping failed: '%s'\n", url, err)
|
debugLogf("{%s} closing!, ping failed: '%s'\n", r.URL, err)
|
||||||
conn.doClose(ws.StatusAbnormalClosure, "ping took too long")
|
r.closeConnection(ws.StatusAbnormalClosure, "ping took too long")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case wr := <-writeQueue:
|
case wr := <-writeQueue:
|
||||||
debugLogf("{%s} sending '%v'\n", url, string(wr.msg))
|
debugLogf("{%s} sending '%v'\n", r.URL, string(wr.msg))
|
||||||
ctx, cancel := context.WithTimeoutCause(ctx, time.Second*10, errors.New("write took too long"))
|
ctx, cancel := context.WithTimeoutCause(ctx, time.Second*10, errors.New("write took too long"))
|
||||||
err := c.Write(ctx, ws.MessageText, wr.msg)
|
err := c.Write(ctx, ws.MessageText, wr.msg)
|
||||||
cancel()
|
cancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
debugLogf("{%s} closing!, write failed: '%s'\n", url, err)
|
debugLogf("{%s} closing!, write failed: '%s'\n", r.URL, err)
|
||||||
conn.doClose(ws.StatusAbnormalClosure, "write failed")
|
r.closeConnection(ws.StatusAbnormalClosure, "write failed")
|
||||||
if wr.answer != nil {
|
if wr.answer != nil {
|
||||||
wr.answer <- err
|
wr.answer <- err
|
||||||
}
|
}
|
||||||
@@ -106,8 +84,8 @@ func newConnection(
|
|||||||
close(wr.answer)
|
close(wr.answer)
|
||||||
}
|
}
|
||||||
case msg := <-readQueue:
|
case msg := <-readQueue:
|
||||||
debugLogf("{%s} received %v\n", url, msg)
|
debugLogf("{%s} received %v\n", r.URL, msg)
|
||||||
handleMessage(msg)
|
r.handleMessage(msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -121,13 +99,13 @@ func newConnection(
|
|||||||
|
|
||||||
_, reader, err := c.Reader(ctx)
|
_, reader, err := c.Reader(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
debugLogf("{%s} closing!, reader failure: '%s'\n", url, err)
|
debugLogf("{%s} closing!, reader failure: '%s'\n", r.URL, err)
|
||||||
conn.doClose(ws.StatusAbnormalClosure, "failed to get reader")
|
r.closeConnection(ws.StatusAbnormalClosure, "failed to get reader")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, err := io.Copy(buf, reader); err != nil {
|
if _, err := io.Copy(buf, reader); err != nil {
|
||||||
debugLogf("{%s} closing!, read failure: '%s'\n", url, err)
|
debugLogf("{%s} closing!, read failure: '%s'\n", r.URL, err)
|
||||||
conn.doClose(ws.StatusAbnormalClosure, "failed to read")
|
r.closeConnection(ws.StatusAbnormalClosure, "failed to read")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,17 +113,18 @@ func newConnection(
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return conn, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *connection) doClose(code ws.StatusCode, reason string) {
|
func (r *Relay) closeConnection(code ws.StatusCode, reason string) {
|
||||||
wasClosed := c.closed.Swap(true)
|
wasClosed := r.closed.Swap(true)
|
||||||
if !wasClosed {
|
if !wasClosed {
|
||||||
c.conn.Close(code, reason)
|
r.conn.Close(code, reason)
|
||||||
c.cancel(fmt.Errorf("doClose(): %s", reason))
|
r.connectionContextCancel(fmt.Errorf("doClose(): %s", reason))
|
||||||
c.closeMutex.Lock()
|
r.closeMutex.Lock()
|
||||||
close(c.closedNotify)
|
close(r.closedNotify)
|
||||||
close(c.writeQueue)
|
close(r.writeQueue)
|
||||||
c.closeMutex.Unlock()
|
r.conn = nil
|
||||||
|
r.closeMutex.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
40
relay.go
40
relay.go
@@ -14,6 +14,7 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
ws "github.com/coder/websocket"
|
||||||
"github.com/puzpuzpuz/xsync/v3"
|
"github.com/puzpuzpuz/xsync/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,7 +27,12 @@ type Relay struct {
|
|||||||
URL string
|
URL string
|
||||||
requestHeader http.Header // e.g. for origin header
|
requestHeader http.Header // e.g. for origin header
|
||||||
|
|
||||||
connection *connection
|
// websocket connection
|
||||||
|
conn *ws.Conn
|
||||||
|
writeQueue chan writeRequest
|
||||||
|
closed *atomic.Bool
|
||||||
|
closedNotify chan struct{}
|
||||||
|
|
||||||
Subscriptions *xsync.MapOf[int64, *Subscription]
|
Subscriptions *xsync.MapOf[int64, *Subscription]
|
||||||
|
|
||||||
ConnectionError error
|
ConnectionError error
|
||||||
@@ -98,7 +104,7 @@ func (r *Relay) String() string {
|
|||||||
func (r *Relay) Context() context.Context { return r.connectionContext }
|
func (r *Relay) Context() context.Context { return r.connectionContext }
|
||||||
|
|
||||||
// IsConnected returns true if the connection to this relay seems to be active.
|
// IsConnected returns true if the connection to this relay seems to be active.
|
||||||
func (r *Relay) IsConnected() bool { return !r.connection.closed.Load() }
|
func (r *Relay) IsConnected() bool { return !r.closed.Load() }
|
||||||
|
|
||||||
// Connect tries to establish a websocket connection to r.URL.
|
// Connect tries to establish a websocket connection to r.URL.
|
||||||
// If the context expires before the connection is complete, an error is returned.
|
// If the context expires before the connection is complete, an error is returned.
|
||||||
@@ -121,11 +127,9 @@ func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error
|
|||||||
return fmt.Errorf("invalid relay URL '%s'", r.URL)
|
return fmt.Errorf("invalid relay URL '%s'", r.URL)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := newConnection(ctx, r.connectionContextCancel, r.URL, r.handleMessage, r.requestHeader, tlsConfig)
|
if err := r.newConnection(ctx, tlsConfig); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err)
|
return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err)
|
||||||
}
|
}
|
||||||
r.connection = conn
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -219,33 +223,33 @@ func (r *Relay) handleMessage(message string) {
|
|||||||
|
|
||||||
// Write queues an arbitrary message to be sent to the relay.
|
// Write queues an arbitrary message to be sent to the relay.
|
||||||
func (r *Relay) Write(msg []byte) {
|
func (r *Relay) Write(msg []byte) {
|
||||||
r.connection.closeMutex.Lock()
|
r.closeMutex.Lock()
|
||||||
defer r.connection.closeMutex.Unlock()
|
defer r.closeMutex.Unlock()
|
||||||
select {
|
select {
|
||||||
case <-r.connection.closedNotify:
|
case <-r.closedNotify:
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case <-r.connectionContext.Done():
|
case <-r.connectionContext.Done():
|
||||||
case r.connection.writeQueue <- writeRequest{msg: msg, answer: nil}:
|
case r.writeQueue <- writeRequest{msg: msg, answer: nil}:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteWithError is like Write, but returns an error if the write fails (and the connection gets closed).
|
// WriteWithError is like Write, but returns an error if the write fails (and the connection gets closed).
|
||||||
func (r *Relay) WriteWithError(msg []byte) error {
|
func (r *Relay) WriteWithError(msg []byte) error {
|
||||||
ch := make(chan error)
|
ch := make(chan error)
|
||||||
r.connection.closeMutex.Lock()
|
r.closeMutex.Lock()
|
||||||
defer r.connection.closeMutex.Unlock()
|
defer r.closeMutex.Unlock()
|
||||||
select {
|
select {
|
||||||
case <-r.connection.closedNotify:
|
case <-r.closedNotify:
|
||||||
return fmt.Errorf("failed to write to %s: <closed>", r.URL)
|
return fmt.Errorf("failed to write to %s: <closed>", r.URL)
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case <-r.connectionContext.Done():
|
case <-r.connectionContext.Done():
|
||||||
return fmt.Errorf("failed to write to %s: %w", r.URL, context.Cause(r.connectionContext))
|
return fmt.Errorf("failed to write to %s: %w", r.URL, context.Cause(r.connectionContext))
|
||||||
case r.connection.writeQueue <- writeRequest{msg: msg, answer: ch}:
|
case r.writeQueue <- writeRequest{msg: msg, answer: ch}:
|
||||||
}
|
}
|
||||||
return <-ch
|
return <-ch
|
||||||
}
|
}
|
||||||
@@ -357,7 +361,7 @@ func (r *Relay) publish(ctx context.Context, id ID, env Envelope) error {
|
|||||||
func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionOptions) (*Subscription, error) {
|
func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionOptions) (*Subscription, error) {
|
||||||
sub := r.PrepareSubscription(ctx, filter, opts)
|
sub := r.PrepareSubscription(ctx, filter, opts)
|
||||||
|
|
||||||
if r.connection == nil {
|
if r.conn == nil {
|
||||||
return nil, fmt.Errorf("not connected to %s", r.URL)
|
return nil, fmt.Errorf("not connected to %s", r.URL)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -367,7 +371,7 @@ func (r *Relay) Subscribe(ctx context.Context, filter Filter, opts SubscriptionO
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
select {
|
select {
|
||||||
case <-r.connection.closedNotify:
|
case <-r.closedNotify:
|
||||||
sub.unsub(ErrDisconnected)
|
sub.unsub(ErrDisconnected)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
}
|
}
|
||||||
@@ -510,13 +514,13 @@ func (r *Relay) close(reason error) error {
|
|||||||
if r.connectionContextCancel == nil {
|
if r.connectionContextCancel == nil {
|
||||||
return fmt.Errorf("relay already closed")
|
return fmt.Errorf("relay already closed")
|
||||||
}
|
}
|
||||||
r.connectionContextCancel(reason)
|
|
||||||
r.connectionContextCancel = nil
|
|
||||||
|
|
||||||
if r.connection == nil {
|
if r.conn == nil {
|
||||||
return fmt.Errorf("relay not connected")
|
return fmt.Errorf("relay not connected")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.connectionContextCancel(reason)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user