nip77: convert to dealing with bytes instead of hex strings.

it was cool but this should be faster and less confusing.
This commit is contained in:
fiatjaf
2025-05-12 05:54:39 -03:00
parent bbffe45824
commit 94d29f1230
8 changed files with 66 additions and 157 deletions

View File

@@ -1,12 +1,13 @@
package negentropy package negentropy
import ( import (
"bytes"
"fmt" "fmt"
"fiatjaf.com/nostr" "fiatjaf.com/nostr"
) )
func (n *Negentropy) readTimestamp(reader *StringHexReader) (nostr.Timestamp, error) { func (n *Negentropy) readTimestamp(reader *bytes.Reader) (nostr.Timestamp, error) {
delta, err := readVarInt(reader) delta, err := readVarInt(reader)
if err != nil { if err != nil {
return 0, err return 0, err
@@ -31,7 +32,7 @@ func (n *Negentropy) readTimestamp(reader *StringHexReader) (nostr.Timestamp, er
return timestamp, nil return timestamp, nil
} }
func (n *Negentropy) readBound(reader *StringHexReader) (Bound, error) { func (n *Negentropy) readBound(reader *bytes.Reader) (Bound, error) {
timestamp, err := n.readTimestamp(reader) timestamp, err := n.readTimestamp(reader)
if err != nil { if err != nil {
return Bound{}, fmt.Errorf("failed to decode bound timestamp: %w", err) return Bound{}, fmt.Errorf("failed to decode bound timestamp: %w", err)
@@ -43,14 +44,14 @@ func (n *Negentropy) readBound(reader *StringHexReader) (Bound, error) {
} }
pfb := make([]byte, length) pfb := make([]byte, length)
if err := reader.ReadHexBytes(pfb); err != nil { if _, err := reader.Read(pfb); err != nil {
return Bound{}, fmt.Errorf("failed to read bound id: %w", err) return Bound{}, fmt.Errorf("failed to read bound id: %w", err)
} }
return Bound{timestamp, pfb}, nil return Bound{timestamp, pfb}, nil
} }
func (n *Negentropy) writeTimestamp(w *StringHexWriter, timestamp nostr.Timestamp) { func (n *Negentropy) writeTimestamp(w *bytes.Buffer, timestamp nostr.Timestamp) {
if timestamp == maxTimestamp { if timestamp == maxTimestamp {
// zeroes are infinite // zeroes are infinite
n.lastTimestampOut = maxTimestamp // cache this (see below) n.lastTimestampOut = maxTimestamp // cache this (see below)
@@ -69,10 +70,10 @@ func (n *Negentropy) writeTimestamp(w *StringHexWriter, timestamp nostr.Timestam
return return
} }
func (n *Negentropy) writeBound(w *StringHexWriter, bound Bound) { func (n *Negentropy) writeBound(w *bytes.Buffer, bound Bound) {
n.writeTimestamp(w, bound.Timestamp) n.writeTimestamp(w, bound.Timestamp)
writeVarInt(w, len(bound.IDPrefix)) writeVarInt(w, len(bound.IDPrefix))
w.WriteBytes(bound.IDPrefix) w.Write(bound.IDPrefix)
} }
func getMinimalBound(prev, curr Item) Bound { func getMinimalBound(prev, curr Item) Bound {
@@ -92,11 +93,11 @@ func getMinimalBound(prev, curr Item) Bound {
return Bound{curr.Timestamp, curr.ID[:(sharedPrefixBytes + 1)]} return Bound{curr.Timestamp, curr.ID[:(sharedPrefixBytes + 1)]}
} }
func readVarInt(reader *StringHexReader) (int, error) { func readVarInt(reader *bytes.Reader) (int, error) {
var res int = 0 var res int = 0
for { for {
b, err := reader.ReadHexByte() b, err := reader.ReadByte()
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -110,13 +111,13 @@ func readVarInt(reader *StringHexReader) (int, error) {
return res, nil return res, nil
} }
func writeVarInt(w *StringHexWriter, n int) { func writeVarInt(w *bytes.Buffer, n int) {
if n == 0 { if n == 0 {
w.WriteByte(0) w.WriteByte(0)
return return
} }
w.WriteBytes(EncodeVarInt(n)) w.Write(EncodeVarInt(n))
} }
func EncodeVarInt(n int) []byte { func EncodeVarInt(n int) []byte {

View File

@@ -1,93 +0,0 @@
package negentropy
import (
"encoding/hex"
"io"
)
func NewStringHexReader(source string) *StringHexReader {
return &StringHexReader{source, 0, make([]byte, 1)}
}
type StringHexReader struct {
source string
idx int
tmp []byte
}
func (r *StringHexReader) Len() int {
return len(r.source) - r.idx
}
func (r *StringHexReader) ReadHexBytes(buf []byte) error {
n := len(buf) * 2
r.idx += n
if len(r.source) < r.idx {
return io.EOF
}
_, err := hex.Decode(buf, []byte(r.source[r.idx-n:r.idx]))
return err
}
func (r *StringHexReader) ReadHexByte() (byte, error) {
err := r.ReadHexBytes(r.tmp)
return r.tmp[0], err
}
func (r *StringHexReader) ReadString(size int) (string, error) {
r.idx += size
if len(r.source) < r.idx {
return "", io.EOF
}
return r.source[r.idx-size : r.idx], nil
}
func NewStringHexWriter(buf []byte) *StringHexWriter {
return &StringHexWriter{buf, make([]byte, 2)}
}
type StringHexWriter struct {
hexbuf []byte
tmp []byte
}
func (r *StringHexWriter) Len() int {
return len(r.hexbuf)
}
func (r *StringHexWriter) Hex() string {
return string(r.hexbuf)
}
func (r *StringHexWriter) Reset() {
r.hexbuf = r.hexbuf[:0]
}
func (r *StringHexWriter) WriteHex(hexString string) {
r.hexbuf = append(r.hexbuf, hexString...)
return
}
func (r *StringHexWriter) WriteByte(b byte) error {
hex.Encode(r.tmp, []byte{b})
r.hexbuf = append(r.hexbuf, r.tmp...)
return nil
}
func (r *StringHexWriter) WriteBytes(in []byte) {
r.hexbuf = hex.AppendEncode(r.hexbuf, in)
// curr := len(r.hexbuf)
// next := curr + len(in)*2
// for cap(r.hexbuf) < next {
// r.hexbuf = append(r.hexbuf, in...)
// }
// r.hexbuf = r.hexbuf[0:next]
// dst := r.hexbuf[curr:next]
// hex.Encode(dst, in)
return
}

View File

@@ -1,10 +1,11 @@
package negentropy package negentropy
import ( import (
"bytes"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io"
"math" "math"
"strings"
"unsafe" "unsafe"
"fiatjaf.com/nostr" "fiatjaf.com/nostr"
@@ -60,55 +61,60 @@ func (n *Negentropy) Start() string {
n.initialized = true n.initialized = true
n.isClient = true n.isClient = true
output := NewStringHexWriter(make([]byte, 0, 1+n.storage.Size()*64)) output := bytes.NewBuffer(make([]byte, 0, 1+n.storage.Size()*64))
output.WriteByte(protocolVersion) output.WriteByte(protocolVersion)
n.SplitRange(0, n.storage.Size(), InfiniteBound, output) n.SplitRange(0, n.storage.Size(), InfiniteBound, output)
return output.Hex() return hex.EncodeToString(output.Bytes())
} }
func (n *Negentropy) Reconcile(msg string) (output string, err error) { func (n *Negentropy) Reconcile(msg string) (string, error) {
n.initialized = true n.initialized = true
reader := NewStringHexReader(msg) msgb, err := hex.DecodeString(msg)
output, err = n.reconcileAux(reader)
if err != nil { if err != nil {
return "", err return "", err
} }
if len(output) == 2 && n.isClient { reader := bytes.NewReader(msgb)
output, err := n.reconcileAux(reader)
if err != nil {
return "", err
}
if len(output) == 1 && n.isClient {
close(n.Haves) close(n.Haves)
close(n.HaveNots) close(n.HaveNots)
return "", nil return "", nil
} }
return output, nil return hex.EncodeToString(output), nil
} }
func (n *Negentropy) reconcileAux(reader *StringHexReader) (string, error) { func (n *Negentropy) reconcileAux(reader *bytes.Reader) ([]byte, error) {
n.lastTimestampIn, n.lastTimestampOut = 0, 0 // reset for each message n.lastTimestampIn, n.lastTimestampOut = 0, 0 // reset for each message
fullOutput := NewStringHexWriter(make([]byte, 0, 5000)) fullOutput := bytes.NewBuffer(make([]byte, 0, 5000))
fullOutput.WriteByte(protocolVersion) fullOutput.WriteByte(protocolVersion)
pv, err := reader.ReadHexByte() pv, err := reader.ReadByte()
if err != nil { if err != nil {
return "", fmt.Errorf("failed to read pv: %w", err) return nil, fmt.Errorf("failed to read pv: %w", err)
} }
if pv != protocolVersion { if pv != protocolVersion {
if n.isClient { if n.isClient {
return "", fmt.Errorf("unsupported negentropy protocol version %v", pv) return nil, fmt.Errorf("unsupported negentropy protocol version %v", pv)
} }
// if we're a server we just return our protocol version // if we're a server we just return our protocol version
return fullOutput.Hex(), nil return fullOutput.Bytes(), nil
} }
var prevBound Bound var prevBound Bound
prevIndex := 0 prevIndex := 0
skipping := false // this means we are currently coalescing ranges into skip skipping := false // this means we are currently coalescing ranges into skip
partialOutput := NewStringHexWriter(make([]byte, 0, 100)) partialOutput := bytes.NewBuffer(make([]byte, 0, 100))
for reader.Len() > 0 { for reader.Len() > 0 {
partialOutput.Reset() partialOutput.Reset()
@@ -123,11 +129,11 @@ func (n *Negentropy) reconcileAux(reader *StringHexReader) (string, error) {
currBound, err := n.readBound(reader) currBound, err := n.readBound(reader)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to decode bound: %w", err) return nil, fmt.Errorf("failed to decode bound: %w", err)
} }
modeVal, err := readVarInt(reader) modeVal, err := readVarInt(reader)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to decode mode: %w", err) return nil, fmt.Errorf("failed to decode mode: %w", err)
} }
mode := Mode(modeVal) mode := Mode(modeVal)
@@ -139,9 +145,10 @@ func (n *Negentropy) reconcileAux(reader *StringHexReader) (string, error) {
skipping = true skipping = true
case FingerprintMode: case FingerprintMode:
theirFingerprint, err := reader.ReadString(FingerprintSize * 2) theirFingerprint := [FingerprintSize]byte{}
_, err := reader.Read(theirFingerprint[:])
if err != nil { if err != nil {
return "", fmt.Errorf("failed to read fingerprint: %w", err) return nil, fmt.Errorf("failed to read fingerprint: %w", err)
} }
ourFingerprint := n.storage.Fingerprint(lower, upper) ourFingerprint := n.storage.Fingerprint(lower, upper)
@@ -155,15 +162,15 @@ func (n *Negentropy) reconcileAux(reader *StringHexReader) (string, error) {
case IdListMode: case IdListMode:
numIds, err := readVarInt(reader) numIds, err := readVarInt(reader)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to decode number of ids: %w", err) return nil, fmt.Errorf("failed to decode number of ids: %w", err)
} }
// what they have // what they have
theirItems := make(map[nostr.ID]struct{}, numIds) theirItems := make(map[nostr.ID]struct{}, numIds)
for i := 0; i < numIds; i++ { for i := 0; i < numIds; i++ {
var id [32]byte var id [32]byte
if err := reader.ReadHexBytes(id[:]); err != nil { if _, err := reader.Read(id[:]); err != nil {
return "", fmt.Errorf("failed to read id (#%d/%d) in list: %w", i, numIds, err) return nil, fmt.Errorf("failed to read id (#%d/%d) in list: %w", i, numIds, err)
} else { } else {
theirItems[id] = struct{}{} theirItems[id] = struct{}{}
} }
@@ -197,33 +204,32 @@ func (n *Negentropy) reconcileAux(reader *StringHexReader) (string, error) {
// server got list of ids, reply with their own ids for the same range // server got list of ids, reply with their own ids for the same range
finishSkip() finishSkip()
responseIds := strings.Builder{} responseIds := make([]byte, 0, 32*100)
responseIds.Grow(64 * 100)
responses := 0 responses := 0
endBound := currBound endBound := currBound
for index, item := range n.storage.Range(lower, upper) { for index, item := range n.storage.Range(lower, upper) {
if n.frameSizeLimit-200 < fullOutput.Len()/2+responseIds.Len()/2 { if n.frameSizeLimit-200 < fullOutput.Len()/2+len(responseIds)/2 {
endBound = Bound{item.Timestamp, item.ID[:]} endBound = Bound{item.Timestamp, item.ID[:]}
upper = index upper = index
break break
} }
responseIds.WriteString(hex.EncodeToString(item.ID[:])) responseIds = append(responseIds, item.ID[:]...)
responses++ responses++
} }
n.writeBound(partialOutput, endBound) n.writeBound(partialOutput, endBound)
partialOutput.WriteByte(byte(IdListMode)) partialOutput.WriteByte(byte(IdListMode))
writeVarInt(partialOutput, responses) writeVarInt(partialOutput, responses)
partialOutput.WriteHex(responseIds.String()) partialOutput.Write(responseIds)
fullOutput.WriteHex(partialOutput.Hex()) io.Copy(fullOutput, partialOutput)
partialOutput.Reset() partialOutput.Reset()
} }
default: default:
return "", fmt.Errorf("unexpected mode %d", mode) return nil, fmt.Errorf("unexpected mode %d", mode)
} }
if n.frameSizeLimit-200 < fullOutput.Len()/2+partialOutput.Len()/2 { if n.frameSizeLimit-200 < fullOutput.Len()/2+partialOutput.Len()/2 {
@@ -231,22 +237,22 @@ func (n *Negentropy) reconcileAux(reader *StringHexReader) (string, error) {
remainingFingerprint := n.storage.Fingerprint(upper, n.storage.Size()) remainingFingerprint := n.storage.Fingerprint(upper, n.storage.Size())
n.writeBound(fullOutput, InfiniteBound) n.writeBound(fullOutput, InfiniteBound)
fullOutput.WriteByte(byte(FingerprintMode)) fullOutput.WriteByte(byte(FingerprintMode))
fullOutput.WriteHex(remainingFingerprint) fullOutput.Write(remainingFingerprint[:])
break // stop processing further break // stop processing further
} else { } else {
// append the constructed output for this iteration // append the constructed output for this iteration
fullOutput.WriteHex(partialOutput.Hex()) io.Copy(fullOutput, partialOutput)
} }
prevIndex = upper prevIndex = upper
prevBound = currBound prevBound = currBound
} }
return fullOutput.Hex(), nil return fullOutput.Bytes(), nil
} }
func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *StringHexWriter) { func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *bytes.Buffer) {
numElems := upper - lower numElems := upper - lower
if numElems < buckets*2 { if numElems < buckets*2 {
@@ -256,7 +262,7 @@ func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *Stri
writeVarInt(output, numElems) writeVarInt(output, numElems)
for _, item := range n.storage.Range(lower, upper) { for _, item := range n.storage.Range(lower, upper) {
output.WriteBytes(item.ID[:]) output.Write(item.ID[:])
} }
} else { } else {
itemsPerBucket := numElems / buckets itemsPerBucket := numElems / buckets
@@ -291,7 +297,7 @@ func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *Stri
n.writeBound(output, nextBound) n.writeBound(output, nextBound)
output.WriteByte(byte(FingerprintMode)) output.WriteByte(byte(FingerprintMode))
output.WriteHex(ourFingerprint) output.Write(ourFingerprint[:])
} }
} }
} }

View File

@@ -1,13 +0,0 @@
package negentropy
import (
"iter"
)
type Storage interface {
Size() int
Range(begin, end int) iter.Seq2[int, Item]
FindLowerBound(begin, end int, value Bound) int
GetBound(idx int) Bound
Fingerprint(begin, end int) string
}

View File

@@ -3,7 +3,6 @@ package storage
import ( import (
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"encoding/hex"
"fiatjaf.com/nostr/nip77/negentropy" "fiatjaf.com/nostr/nip77/negentropy"
) )
@@ -41,9 +40,9 @@ func (acc *Accumulator) AddBytes(other []byte) {
} }
} }
func (acc *Accumulator) GetFingerprint(n int) string { func (acc *Accumulator) GetFingerprint(n int) [negentropy.FingerprintSize]byte {
input := acc.Buf[:32] input := acc.Buf[:32]
input = append(input, negentropy.EncodeVarInt(n)...) input = append(input, negentropy.EncodeVarInt(n)...)
hash := sha256.Sum256(input) hash := sha256.Sum256(input)
return hex.EncodeToString(hash[:negentropy.FingerprintSize]) return [negentropy.FingerprintSize]byte(hash[:negentropy.FingerprintSize])
} }

View File

@@ -23,6 +23,6 @@ func (Empty) GetBound(idx int) negentropy.Bound {
return negentropy.InfiniteBound return negentropy.InfiniteBound
} }
func (Empty) Fingerprint(begin, end int) string { func (Empty) Fingerprint(begin, end int) [negentropy.FingerprintSize]byte {
return acc.GetFingerprint(end - begin) return acc.GetFingerprint(end - begin)
} }

View File

@@ -59,7 +59,7 @@ func (v *Vector) FindLowerBound(begin, end int, bound negentropy.Bound) int {
return begin + idx return begin + idx
} }
func (v *Vector) Fingerprint(begin, end int) string { func (v *Vector) Fingerprint(begin, end int) [negentropy.FingerprintSize]byte {
v.acc.Reset() v.acc.Reset()
for _, item := range v.Range(begin, end) { for _, item := range v.Range(begin, end) {

View File

@@ -2,6 +2,7 @@ package negentropy
import ( import (
"fmt" "fmt"
"iter"
"fiatjaf.com/nostr" "fiatjaf.com/nostr"
) )
@@ -47,3 +48,11 @@ func (b Bound) String() string {
} }
return fmt.Sprintf("Bound<%d:%x>", b.Timestamp, b.IDPrefix) return fmt.Sprintf("Bound<%d:%x>", b.Timestamp, b.IDPrefix)
} }
type Storage interface {
Size() int
Range(begin, end int) iter.Seq2[int, Item]
FindLowerBound(begin, end int, value Bound) int
GetBound(idx int) Bound
Fingerprint(begin, end int) [FingerprintSize]byte
}