*connection to be an integral part of *Relay.

This commit is contained in:
fiatjaf
2025-11-12 06:54:41 -03:00
parent 1c43f0d666
commit c2ab9d082c
2 changed files with 54 additions and 71 deletions

View File

@@ -7,8 +7,6 @@ import (
"errors"
"fmt"
"io"
"net/http"
"sync"
"sync/atomic"
"time"
@@ -17,30 +15,13 @@ import (
var ErrDisconnected = errors.New("<disconnected>")
// Connection represents a websocket connection to a Nostr relay.
type connection struct {
conn *ws.Conn
cancel context.CancelCauseFunc
writeQueue chan writeRequest
closed *atomic.Bool
closedNotify chan struct{}
closeMutex sync.Mutex
}
type writeRequest struct {
msg []byte
answer chan error
}
func newConnection(
ctx context.Context,
cancel context.CancelCauseFunc,
url string,
handleMessage func(string),
requestHeader http.Header,
tlsConfig *tls.Config,
) (*connection, error) {
debugLogf("{%s} connecting!\n", url)
func (r *Relay) newConnection(ctx context.Context, tlsConfig *tls.Config) error {
debugLogf("{%s} connecting!\n", r.URL)
dialCtx := ctx
if _, ok := dialCtx.Deadline(); !ok {
@@ -48,9 +29,9 @@ func newConnection(
dialCtx, _ = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long"))
}
c, _, err := ws.Dial(dialCtx, url, getConnectionOptions(requestHeader, tlsConfig))
c, _, err := ws.Dial(dialCtx, r.URL, getConnectionOptions(r.requestHeader, tlsConfig))
if err != nil {
return nil, err
return err
}
c.SetReadLimit(2 << 24) // 33MB
@@ -63,40 +44,37 @@ func newConnection(
writeQueue := make(chan writeRequest)
readQueue := make(chan string)
conn := &connection{
conn: c,
cancel: cancel,
writeQueue: writeQueue,
closed: &atomic.Bool{},
closedNotify: make(chan struct{}),
}
r.conn = c
r.writeQueue = writeQueue
r.closed = &atomic.Bool{}
r.closedNotify = make(chan struct{})
go func() {
for {
select {
case <-ctx.Done():
conn.doClose(ws.StatusNormalClosure, "")
debugLogf("{%s} closing!, context done: '%s'\n", url, context.Cause(ctx))
r.closeConnection(ws.StatusNormalClosure, "")
debugLogf("{%s} closing!, context done: '%s'\n", r.URL, context.Cause(ctx))
return
case <-conn.closedNotify:
case <-r.closedNotify:
return
case <-ticker.C:
ctx, cancel := context.WithTimeoutCause(ctx, time.Millisecond*800, errors.New("ping took too long"))
err := c.Ping(ctx)
cancel()
if err != nil {
debugLogf("{%s} closing!, ping failed: '%s'\n", url, err)
conn.doClose(ws.StatusAbnormalClosure, "ping took too long")
debugLogf("{%s} closing!, ping failed: '%s'\n", r.URL, err)
r.closeConnection(ws.StatusAbnormalClosure, "ping took too long")
return
}
case wr := <-writeQueue:
debugLogf("{%s} sending '%v'\n", url, string(wr.msg))
debugLogf("{%s} sending '%v'\n", r.URL, string(wr.msg))
ctx, cancel := context.WithTimeoutCause(ctx, time.Second*10, errors.New("write took too long"))
err := c.Write(ctx, ws.MessageText, wr.msg)
cancel()
if err != nil {
debugLogf("{%s} closing!, write failed: '%s'\n", url, err)
conn.doClose(ws.StatusAbnormalClosure, "write failed")
debugLogf("{%s} closing!, write failed: '%s'\n", r.URL, err)
r.closeConnection(ws.StatusAbnormalClosure, "write failed")
if wr.answer != nil {
wr.answer <- err
}
@@ -106,8 +84,8 @@ func newConnection(
close(wr.answer)
}
case msg := <-readQueue:
debugLogf("{%s} received %v\n", url, msg)
handleMessage(msg)
debugLogf("{%s} received %v\n", r.URL, msg)
r.handleMessage(msg)
}
}
}()
@@ -121,13 +99,13 @@ func newConnection(
_, reader, err := c.Reader(ctx)
if err != nil {
debugLogf("{%s} closing!, reader failure: '%s'\n", url, err)
conn.doClose(ws.StatusAbnormalClosure, "failed to get reader")
debugLogf("{%s} closing!, reader failure: '%s'\n", r.URL, err)
r.closeConnection(ws.StatusAbnormalClosure, "failed to get reader")
return
}
if _, err := io.Copy(buf, reader); err != nil {
debugLogf("{%s} closing!, read failure: '%s'\n", url, err)
conn.doClose(ws.StatusAbnormalClosure, "failed to read")
debugLogf("{%s} closing!, read failure: '%s'\n", r.URL, err)
r.closeConnection(ws.StatusAbnormalClosure, "failed to read")
return
}
@@ -135,17 +113,18 @@ func newConnection(
}
}()
return conn, nil
return nil
}
func (c *connection) doClose(code ws.StatusCode, reason string) {
wasClosed := c.closed.Swap(true)
func (r *Relay) closeConnection(code ws.StatusCode, reason string) {
wasClosed := r.closed.Swap(true)
if !wasClosed {
c.conn.Close(code, reason)
c.cancel(fmt.Errorf("doClose(): %s", reason))
c.closeMutex.Lock()
close(c.closedNotify)
close(c.writeQueue)
c.closeMutex.Unlock()
r.conn.Close(code, reason)
r.connectionContextCancel(fmt.Errorf("doClose(): %s", reason))
r.closeMutex.Lock()
close(r.closedNotify)
close(r.writeQueue)
r.conn = nil
r.closeMutex.Unlock()
}
}