From 286040c4ce8e54a6e54ab7aeb9220ce873d25177 Mon Sep 17 00:00:00 2001 From: fiatjaf Date: Sat, 14 Sep 2024 16:28:19 -0300 Subject: [PATCH] negentropy: do the algorithm entirely in hex. plus: - nicer iterators - some optimizations here and there. - something else I forgot. --- nip77/negentropy/encoding.go | 64 +++++---- nip77/negentropy/hex.go | 96 +++++++++++++ nip77/negentropy/negentropy.go | 223 ++++++++++++++---------------- nip77/negentropy/types.go | 45 +++--- nip77/negentropy/vector.go | 26 ++-- nip77/negentropy/whatever_test.go | 33 ++--- 6 files changed, 293 insertions(+), 194 deletions(-) create mode 100644 nip77/negentropy/hex.go diff --git a/nip77/negentropy/encoding.go b/nip77/negentropy/encoding.go index 08da9f9..317b1e1 100644 --- a/nip77/negentropy/encoding.go +++ b/nip77/negentropy/encoding.go @@ -1,13 +1,12 @@ package negentropy import ( - "bytes" - "encoding/hex" + "fmt" "github.com/nbd-wtf/go-nostr" ) -func (n *Negentropy) DecodeTimestampIn(reader *bytes.Reader) (nostr.Timestamp, error) { +func (n *Negentropy) DecodeTimestampIn(reader *StringHexReader) (nostr.Timestamp, error) { t, err := decodeVarInt(reader) if err != nil { return 0, err @@ -28,47 +27,42 @@ func (n *Negentropy) DecodeTimestampIn(reader *bytes.Reader) (nostr.Timestamp, e return timestamp, nil } -func (n *Negentropy) DecodeBound(reader *bytes.Reader) (Bound, error) { +func (n *Negentropy) DecodeBound(reader *StringHexReader) (Bound, error) { timestamp, err := n.DecodeTimestampIn(reader) if err != nil { - return Bound{}, err + return Bound{}, fmt.Errorf("failed to decode bound timestamp: %w", err) } length, err := decodeVarInt(reader) if err != nil { - return Bound{}, err + return Bound{}, fmt.Errorf("failed to decode bound length: %w", err) } - id := make([]byte, length) - if _, err = reader.Read(id); err != nil { - return Bound{}, err + id, err := reader.ReadString(length * 2) + if err != nil { + return Bound{}, fmt.Errorf("failed to read bound id: %w", err) } - return Bound{Item{timestamp, hex.EncodeToString(id)}}, nil + return Bound{Item{timestamp, id}}, nil } -func (n *Negentropy) encodeTimestampOut(timestamp nostr.Timestamp) []byte { +func (n *Negentropy) encodeTimestampOut(w *StringHexWriter, timestamp nostr.Timestamp) { if timestamp == maxTimestamp { n.lastTimestampOut = maxTimestamp - return encodeVarInt(0) + encodeVarIntToHex(w, 0) + return } temp := timestamp timestamp -= n.lastTimestampOut n.lastTimestampOut = temp - return encodeVarInt(int(timestamp + 1)) + encodeVarIntToHex(w, int(timestamp+1)) + return } -func (n *Negentropy) encodeBound(bound Bound) []byte { - var output []byte - - t := n.encodeTimestampOut(bound.Timestamp) - idlen := encodeVarInt(len(bound.ID) / 2) - output = append(output, t...) - output = append(output, idlen...) - id, _ := hex.DecodeString(bound.Item.ID) - - output = append(output, id...) - return output +func (n *Negentropy) encodeBound(w *StringHexWriter, bound Bound) { + n.encodeTimestampOut(w, bound.Timestamp) + encodeVarIntToHex(w, len(bound.ID)/2) + w.WriteHex(bound.Item.ID) } func getMinimalBound(prev, curr Item) Bound { @@ -89,11 +83,11 @@ func getMinimalBound(prev, curr Item) Bound { return Bound{Item{curr.Timestamp, curr.ID[:(sharedPrefixBytes+1)*2]}} } -func decodeVarInt(reader *bytes.Reader) (int, error) { +func decodeVarInt(reader *StringHexReader) (int, error) { var res int = 0 for { - b, err := reader.ReadByte() + b, err := reader.ReadHexByte() if err != nil { return 0, err } @@ -124,3 +118,21 @@ func encodeVarInt(n int) []byte { return o } + +func encodeVarIntToHex(w *StringHexWriter, n int) { + if n == 0 { + w.WriteByte(0) + } + + var o []byte + for n != 0 { + o = append([]byte{byte(n & 0x7F)}, o...) + n >>= 7 + } + + for i := 0; i < len(o)-1; i++ { + o[i] |= 0x80 + } + + w.WriteBytes(o) +} diff --git a/nip77/negentropy/hex.go b/nip77/negentropy/hex.go new file mode 100644 index 0000000..36f482e --- /dev/null +++ b/nip77/negentropy/hex.go @@ -0,0 +1,96 @@ +package negentropy + +import ( + "encoding/hex" + "io" +) + +func NewStringHexReader(source string) *StringHexReader { + return &StringHexReader{source, 0, make([]byte, 1)} +} + +type StringHexReader struct { + source string + idx int + + tmp []byte +} + +func (r *StringHexReader) Len() int { + return len(r.source) - r.idx +} + +func (r *StringHexReader) ReadHexBytes(buf []byte) error { + n := len(buf) * 2 + r.idx += n + if len(r.source) < r.idx { + return io.EOF + } + _, err := hex.Decode(buf, []byte(r.source[r.idx-n:r.idx])) + return err +} + +func (r *StringHexReader) ReadHexByte() (byte, error) { + err := r.ReadHexBytes(r.tmp) + return r.tmp[0], err +} + +func (r *StringHexReader) ReadString(size int) (string, error) { + if size == 0 { + return "", nil + } + r.idx += size + if len(r.source) < r.idx { + return "", io.EOF + } + return r.source[r.idx-size : r.idx], nil +} + +func NewStringHexWriter(buf []byte) *StringHexWriter { + return &StringHexWriter{buf, make([]byte, 2)} +} + +type StringHexWriter struct { + hexbuf []byte + + tmp []byte +} + +func (r *StringHexWriter) Len() int { + return len(r.hexbuf) +} + +func (r *StringHexWriter) Hex() string { + return string(r.hexbuf) +} + +func (r *StringHexWriter) Reset() { + r.hexbuf = r.hexbuf[:0] +} + +func (r *StringHexWriter) WriteHex(hexString string) { + r.hexbuf = append(r.hexbuf, hexString...) + return +} + +func (r *StringHexWriter) WriteByte(b byte) error { + hex.Encode(r.tmp, []byte{b}) + r.hexbuf = append(r.hexbuf, r.tmp...) + return nil +} + +func (r *StringHexWriter) WriteBytes(in []byte) { + r.hexbuf = hex.AppendEncode(r.hexbuf, in) + + // curr := len(r.hexbuf) + // next := curr + len(in)*2 + // for cap(r.hexbuf) < next { + // r.hexbuf = append(r.hexbuf, in...) + // } + // r.hexbuf = r.hexbuf[0:next] + // dst := r.hexbuf[curr:next] + + // hex.Encode(dst, in) + + return +} diff --git a/nip77/negentropy/negentropy.go b/nip77/negentropy/negentropy.go index e382d3f..3060354 100644 --- a/nip77/negentropy/negentropy.go +++ b/nip77/negentropy/negentropy.go @@ -1,11 +1,10 @@ package negentropy import ( - "bytes" - "encoding/hex" "fmt" "math" - "os" + "slices" + "strings" "unsafe" "github.com/nbd-wtf/go-nostr" @@ -22,7 +21,7 @@ type Negentropy struct { storage Storage sealed bool frameSizeLimit int - isInitiator bool + isClient bool lastTimestampIn nostr.Timestamp lastTimestampOut nostr.Timestamp @@ -37,6 +36,17 @@ func NewNegentropy(storage Storage, frameSizeLimit int) *Negentropy { } } +func (n *Negentropy) String() string { + label := "unsealed" + if n.sealed { + label = "server" + if n.isClient { + label = "client" + } + } + return fmt.Sprintf("", label, n.storage.Size()) +} + func (n *Negentropy) Insert(evt *nostr.Event) { err := n.storage.Insert(evt.CreatedAt, evt.ID) if err != nil { @@ -51,83 +61,76 @@ func (n *Negentropy) seal() { n.sealed = true } -func (n *Negentropy) Initiate() []byte { +func (n *Negentropy) Initiate() string { n.seal() - n.isInitiator = true + n.isClient = true n.Haves = make(chan string, n.storage.Size()/2) n.HaveNots = make(chan string, n.storage.Size()/2) - output := bytes.NewBuffer(make([]byte, 0, 1+n.storage.Size()*32)) + output := NewStringHexWriter(make([]byte, 0, 1+n.storage.Size()*64)) output.WriteByte(protocolVersion) n.SplitRange(0, n.storage.Size(), infiniteBound, output) - return output.Bytes() + return output.Hex() } -func (n *Negentropy) Reconcile(msg []byte) (output []byte, err error) { +func (n *Negentropy) Reconcile(msg string) (output string, err error) { n.seal() - reader := bytes.NewReader(msg) + reader := NewStringHexReader(msg) output, err = n.reconcileAux(reader) if err != nil { - return nil, err + return "", err } - if len(output) == 1 && n.isInitiator { + if len(output) == 2 && n.isClient { close(n.Haves) close(n.HaveNots) - return nil, nil + return "", nil } return output, nil } -func (n *Negentropy) reconcileAux(reader *bytes.Reader) ([]byte, error) { +func (n *Negentropy) reconcileAux(reader *StringHexReader) (string, error) { n.lastTimestampIn, n.lastTimestampOut = 0, 0 // reset for each message - fullOutput := bytes.NewBuffer(make([]byte, 0, 5000)) + fullOutput := NewStringHexWriter(make([]byte, 0, 5000)) fullOutput.WriteByte(protocolVersion) - pv, err := reader.ReadByte() + pv, err := reader.ReadHexByte() if err != nil { - return nil, err - } - - if pv < 0x60 || pv > 0x6f { - return nil, fmt.Errorf("invalid protocol version byte") + return "", fmt.Errorf("failed to read pv: %w", err) } if pv != protocolVersion { - if n.isInitiator { - return nil, fmt.Errorf("unsupported negentropy protocol version requested") - } - return fullOutput.Bytes(), nil + return "", fmt.Errorf("unsupported negentropy protocol version %v", pv) } var prevBound Bound prevIndex := 0 - skip := false + skipping := false // this means we are currently coalescing ranges into skip - partialOutput := bytes.NewBuffer(make([]byte, 0, 100)) + partialOutput := NewStringHexWriter(make([]byte, 0, 100)) for reader.Len() > 0 { partialOutput.Reset() - doSkip := func() { - if skip { - skip = false - encodedBound := n.encodeBound(prevBound) - partialOutput.Write(encodedBound) - partialOutput.WriteByte(SkipMode) + finishSkip := func() { + // end skip range, if necessary, so we can start a new bound that isn't a skip + if skipping { + skipping = false + n.encodeBound(partialOutput, prevBound) + partialOutput.WriteByte(byte(SkipMode)) } } currBound, err := n.DecodeBound(reader) if err != nil { - return nil, err + return "", fmt.Errorf("failed to decode bound: %w", err) } modeVal, err := decodeVarInt(reader) if err != nil { - return nil, err + return "", fmt.Errorf("failed to decode mode: %w", err) } mode := Mode(modeVal) @@ -136,134 +139,129 @@ func (n *Negentropy) reconcileAux(reader *bytes.Reader) ([]byte, error) { switch mode { case SkipMode: - skip = true + skipping = true case FingerprintMode: var theirFingerprint [FingerprintSize]byte - _, err := reader.Read(theirFingerprint[:]) - if err != nil { - return nil, err - } - ourFingerprint, err := n.storage.Fingerprint(lower, upper) - if err != nil { - return nil, err + if err := reader.ReadHexBytes(theirFingerprint[:]); err != nil { + return "", fmt.Errorf("failed to read fingerprint: %w", err) } + ourFingerprint := n.storage.Fingerprint(lower, upper) if theirFingerprint == ourFingerprint { - skip = true + skipping = true } else { - doSkip() + finishSkip() n.SplitRange(lower, upper, currBound, partialOutput) } case IdListMode: numIds, err := decodeVarInt(reader) if err != nil { - return nil, err + return "", fmt.Errorf("failed to decode number of ids: %w", err) } - theirElems := make(map[string]struct{}) - var idb [32]byte - + // what they have + theirItems := make([]string, 0, numIds) for i := 0; i < numIds; i++ { - _, err := reader.Read(idb[:]) - if err != nil { - return nil, err + if id, err := reader.ReadString(64); err != nil { + return "", fmt.Errorf("failed to read id (#%d/%d) in list: %w", i, numIds, err) + } else { + theirItems = append(theirItems, id) } - id := hex.EncodeToString(idb[:]) - theirElems[id] = struct{}{} } - n.storage.Iterate(lower, upper, func(item Item, _ int) bool { + // what we have + for _, item := range n.storage.Range(lower, upper) { id := item.ID - if _, exists := theirElems[id]; !exists { - if n.isInitiator { + + if idx, theyHave := slices.BinarySearch(theirItems, id); theyHave { + // if we have and they have, ignore + theirItems[idx] = "" + } else { + // if we have and they don't, notify client + if n.isClient { n.Haves <- id } - } else { - delete(theirElems, id) } - return true - }) + } - if n.isInitiator { - skip = true - for id := range theirElems { - n.HaveNots <- id + if n.isClient { + // notify client of what they have and we don't + for _, id := range theirItems { + if id != "" { + n.HaveNots <- id + } } + + // client got list of ids, it's done, skip + skipping = true } else { - doSkip() + // server got list of ids, reply with their own ids for the same range + finishSkip() + + responseIds := strings.Builder{} + responseIds.Grow(64 * 100) + responses := 0 - responseIds := make([]byte, 0, 32*n.storage.Size()) endBound := currBound - n.storage.Iterate(lower, upper, func(item Item, index int) bool { - if n.frameSizeLimit-200 < fullOutput.Len()+len(responseIds) { + for index, item := range n.storage.Range(lower, upper) { + if n.frameSizeLimit-200 < fullOutput.Len()+1+8+responseIds.Len() { endBound = Bound{item} upper = index - return false + break } + responseIds.WriteString(item.ID) + responses++ + } - id, _ := hex.DecodeString(item.ID) - responseIds = append(responseIds, id...) - return true - }) + n.encodeBound(partialOutput, endBound) + partialOutput.WriteByte(byte(IdListMode)) + encodeVarIntToHex(partialOutput, responses) + partialOutput.WriteHex(responseIds.String()) - encodedBound := n.encodeBound(endBound) - - partialOutput.Write(encodedBound) - partialOutput.WriteByte(IdListMode) - partialOutput.Write(encodeVarInt(len(responseIds) / 32)) - partialOutput.Write(responseIds) - - partialOutput.WriteTo(fullOutput) + fullOutput.WriteHex(partialOutput.Hex()) partialOutput.Reset() } default: - return nil, fmt.Errorf("unexpected mode %d", mode) + return "", fmt.Errorf("unexpected mode %d", mode) } if n.frameSizeLimit-200 < fullOutput.Len()+partialOutput.Len() { // frame size limit exceeded, handle by encoding a boundary and fingerprint for the remaining range - remainingFingerprint, err := n.storage.Fingerprint(upper, n.storage.Size()) - if err != nil { - panic(err) - } - - fullOutput.Write(n.encodeBound(infiniteBound)) - fullOutput.WriteByte(FingerprintMode) - fullOutput.Write(remainingFingerprint[:]) + remainingFingerprint := n.storage.Fingerprint(upper, n.storage.Size()) + n.encodeBound(fullOutput, infiniteBound) + fullOutput.WriteByte(byte(FingerprintMode)) + fullOutput.WriteBytes(remainingFingerprint[:]) break // stop processing further } else { // append the constructed output for this iteration - partialOutput.WriteTo(fullOutput) + fullOutput.WriteHex(partialOutput.Hex()) } prevIndex = upper prevBound = currBound } - return fullOutput.Bytes(), nil + return fullOutput.Hex(), nil } -func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *bytes.Buffer) { +func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *StringHexWriter) { numElems := upper - lower const buckets = 16 if numElems < buckets*2 { // we just send the full ids here - boundEncoded := n.encodeBound(upperBound) - output.Write(boundEncoded) - output.WriteByte(IdListMode) - output.Write(encodeVarInt(numElems)) + n.encodeBound(output, upperBound) + output.WriteByte(byte(IdListMode)) + encodeVarIntToHex(output, numElems) - n.storage.Iterate(lower, upper, func(item Item, _ int) bool { - id, _ := hex.DecodeString(item.ID) - output.Write(id) - return true - }) + for _, item := range n.storage.Range(lower, upper) { + output.WriteHex(item.ID) + } } else { itemsPerBucket := numElems / buckets bucketsWithExtra := numElems % buckets @@ -274,12 +272,7 @@ func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *byte if i < bucketsWithExtra { bucketSize++ } - ourFingerprint, err := n.storage.Fingerprint(curr, curr+bucketSize) - if err != nil { - fmt.Fprintln(os.Stderr, err) - panic(err) - } - + ourFingerprint := n.storage.Fingerprint(curr, curr+bucketSize) curr += bucketSize var nextBound Bound @@ -288,23 +281,21 @@ func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *byte } else { var prevItem, currItem Item - n.storage.Iterate(curr-1, curr+1, func(item Item, index int) bool { + for index, item := range n.storage.Range(curr-1, curr+1) { if index == curr-1 { prevItem = item } else { currItem = item } - return true - }) + } minBound := getMinimalBound(prevItem, currItem) nextBound = minBound } - boundEncoded := n.encodeBound(nextBound) - output.Write(boundEncoded) - output.WriteByte(FingerprintMode) - output.Write(ourFingerprint[:]) + n.encodeBound(output, nextBound) + output.WriteByte(byte(FingerprintMode)) + output.WriteBytes(ourFingerprint[:]) } } } diff --git a/nip77/negentropy/types.go b/nip77/negentropy/types.go index a552677..c6e581e 100644 --- a/nip77/negentropy/types.go +++ b/nip77/negentropy/types.go @@ -1,10 +1,11 @@ package negentropy import ( + "cmp" "crypto/sha256" "encoding/binary" - "encoding/hex" "fmt" + "iter" "strings" "github.com/nbd-wtf/go-nostr" @@ -12,22 +13,35 @@ import ( const FingerprintSize = 16 -type Mode int +type Mode uint8 const ( - SkipMode = 0 - FingerprintMode = 1 - IdListMode = 2 + SkipMode Mode = 0 + FingerprintMode Mode = 1 + IdListMode Mode = 2 ) +func (v Mode) String() string { + switch v { + case SkipMode: + return "SKIP" + case FingerprintMode: + return "FINGERPRINT" + case IdListMode: + return "IDLIST" + default: + return "" + } +} + type Storage interface { Insert(nostr.Timestamp, string) error Seal() Size() int - Iterate(begin, end int, cb func(item Item, i int) bool) error + Range(begin, end int) iter.Seq2[int, Item] FindLowerBound(begin, end int, value Bound) int GetBound(idx int) Bound - Fingerprint(begin, end int) ([FingerprintSize]byte, error) + Fingerprint(begin, end int) [FingerprintSize]byte } type Item struct { @@ -36,10 +50,10 @@ type Item struct { } func itemCompare(a, b Item) int { - if a.Timestamp != b.Timestamp { - return int(a.Timestamp - b.Timestamp) + if a.Timestamp == b.Timestamp { + return strings.Compare(a.ID, b.ID) } - return strings.Compare(a.ID, b.ID) + return cmp.Compare(a.Timestamp, b.Timestamp) } func (i Item) String() string { return fmt.Sprintf("Item<%d:%s>", i.Timestamp, i.ID) } @@ -61,11 +75,6 @@ func (acc *Accumulator) SetToZero() { acc.Buf = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} } -func (acc *Accumulator) Add(id string) { - b, _ := hex.DecodeString(id) - acc.AddBytes(b) -} - func (acc *Accumulator) AddAccumulator(other Accumulator) { acc.AddBytes(other.Buf) } @@ -95,12 +104,8 @@ func (acc *Accumulator) AddBytes(other []byte) { } } -func (acc *Accumulator) SV() []byte { - return acc.Buf[:] -} - func (acc *Accumulator) GetFingerprint(n int) [FingerprintSize]byte { - input := acc.SV() + input := acc.Buf[:] input = append(input, encodeVarInt(n)...) hash := sha256.Sum256(input) diff --git a/nip77/negentropy/vector.go b/nip77/negentropy/vector.go index 46aaff5..854927d 100644 --- a/nip77/negentropy/vector.go +++ b/nip77/negentropy/vector.go @@ -1,7 +1,9 @@ package negentropy import ( + "encoding/hex" "fmt" + "iter" "slices" "github.com/nbd-wtf/go-nostr" @@ -45,13 +47,14 @@ func (v *Vector) GetBound(idx int) Bound { return infiniteBound } -func (v *Vector) Iterate(begin, end int, cb func(Item, int) bool) error { - for i := begin; i < end; i++ { - if !cb(v.items[i], i) { - break +func (v *Vector) Range(begin, end int) iter.Seq2[int, Item] { + return func(yield func(int, Item) bool) { + for i := begin; i < end; i++ { + if !yield(i, v.items[i]) { + break + } } } - return nil } func (v *Vector) FindLowerBound(begin, end int, bound Bound) int { @@ -59,16 +62,15 @@ func (v *Vector) FindLowerBound(begin, end int, bound Bound) int { return begin + idx } -func (v *Vector) Fingerprint(begin, end int) ([FingerprintSize]byte, error) { +func (v *Vector) Fingerprint(begin, end int) [FingerprintSize]byte { var out Accumulator out.SetToZero() - if err := v.Iterate(begin, end, func(item Item, _ int) bool { - out.Add(item.ID) - return true - }); err != nil { - return [FingerprintSize]byte{}, err + tmp := make([]byte, 32) + for _, item := range v.Range(begin, end) { + hex.Decode(tmp, []byte(item.ID)) + out.AddBytes(tmp) } - return out.GetFingerprint(end - begin), nil + return out.GetFingerprint(end - begin) } diff --git a/nip77/negentropy/whatever_test.go b/nip77/negentropy/whatever_test.go index 0a6efb7..e41a6bf 100644 --- a/nip77/negentropy/whatever_test.go +++ b/nip77/negentropy/whatever_test.go @@ -1,10 +1,9 @@ package negentropy import ( - "encoding/hex" "fmt" + "log" "slices" - "strings" "sync" "testing" @@ -60,7 +59,7 @@ func runTestWith(t *testing.T, expectedN1NeedRanges [][]int, expectedN1HaveRanges [][]int, ) { var err error - var q []byte + var q string var n1 *Negentropy var n2 *Negentropy @@ -109,18 +108,21 @@ func runTestWith(t *testing.T, wg := sync.WaitGroup{} wg.Add(3) + var fatal error + go func() { - wg.Done() - for n := n1; q != nil; n = invert[n] { + defer wg.Done() + for n := n1; q != ""; n = invert[n] { i++ + fmt.Println("processing reconcile", n) q, err = n.Reconcile(q) if err != nil { - t.Fatal(err) + fatal = err return } - if q == nil { + if q == "" { return } } @@ -141,6 +143,7 @@ func runTestWith(t *testing.T, } haves = append(haves, item) } + slices.Sort(haves) require.ElementsMatch(t, expectedHave, haves, "wrong have") }() @@ -159,22 +162,12 @@ func runTestWith(t *testing.T, } havenots = append(havenots, item) } + slices.Sort(havenots) require.ElementsMatch(t, expectedNeed, havenots, "wrong need") }() wg.Wait() -} - -func hexedBytes(o []byte) string { - s := strings.Builder{} - s.Grow(2 + 1 + len(o)*5) - s.WriteString("[ ") - for _, b := range o { - x := hex.EncodeToString([]byte{b}) - s.WriteString("0x") - s.WriteString(x) - s.WriteString(" ") + if fatal != nil { + log.Fatal(fatal) } - s.WriteString("]") - return s.String() }