reuse buffer when reading messages from websocket.

This commit is contained in:
fiatjaf
2023-07-30 17:12:30 -03:00
parent cfcd19568b
commit 35faff858a
2 changed files with 14 additions and 12 deletions

View File

@@ -121,23 +121,23 @@ func (c *Connection) WriteMessage(data []byte) error {
return nil return nil
} }
func (c *Connection) ReadMessage(ctx context.Context) ([]byte, error) { func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, errors.New("context canceled") return errors.New("context canceled")
default: default:
} }
h, err := c.reader.NextFrame() h, err := c.reader.NextFrame()
if err != nil { if err != nil {
c.conn.Close() c.conn.Close()
return nil, fmt.Errorf("failed to advance frame: %w", err) return fmt.Errorf("failed to advance frame: %w", err)
} }
if h.OpCode.IsControl() { if h.OpCode.IsControl() {
if err := c.controlHandler(h, c.reader); err != nil { if err := c.controlHandler(h, c.reader); err != nil {
return nil, fmt.Errorf("failed to handle control frame: %w", err) return fmt.Errorf("failed to handle control frame: %w", err)
} }
} else if h.OpCode == ws.OpBinary || } else if h.OpCode == ws.OpBinary ||
h.OpCode == ws.OpText { h.OpCode == ws.OpText {
@@ -145,23 +145,22 @@ func (c *Connection) ReadMessage(ctx context.Context) ([]byte, error) {
} }
if err := c.reader.Discard(); err != nil { if err := c.reader.Discard(); err != nil {
return nil, fmt.Errorf("failed to discard: %w", err) return fmt.Errorf("failed to discard: %w", err)
} }
} }
buf := new(bytes.Buffer)
if c.msgState.IsCompressed() && c.enableCompression { if c.msgState.IsCompressed() && c.enableCompression {
c.flateReader.Reset(c.reader) c.flateReader.Reset(c.reader)
if _, err := io.Copy(buf, c.flateReader); err != nil { if _, err := io.Copy(buf, c.flateReader); err != nil {
return nil, fmt.Errorf("failed to read message: %w", err) return fmt.Errorf("failed to read message: %w", err)
} }
} else { } else {
if _, err := io.Copy(buf, c.reader); err != nil { if _, err := io.Copy(buf, c.reader); err != nil {
return nil, fmt.Errorf("failed to read message: %w", err) return fmt.Errorf("failed to read message: %w", err)
} }
} }
return buf.Bytes(), nil return nil
} }
func (c *Connection) Close() error { func (c *Connection) Close() error {

View File

@@ -1,6 +1,7 @@
package nostr package nostr
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"log" "log"
@@ -232,16 +233,18 @@ func (r *Relay) Connect(ctx context.Context) error {
// general message reader loop // general message reader loop
go func() { go func() {
buf := new(bytes.Buffer)
for { for {
message, err := conn.ReadMessage(r.connectionContext) buf.Reset()
if err != nil { if err := conn.ReadMessage(r.connectionContext, buf); err != nil {
r.ConnectionError = err r.ConnectionError = err
r.Close() r.Close()
break break
} }
message := buf.Bytes()
debugLogf("{%s} %v\n", r.URL, message) debugLogf("{%s} %v\n", r.URL, message)
envelope := ParseMessage(message) envelope := ParseMessage(message)
if envelope == nil { if envelope == nil {
continue continue