diff --git a/sdk/event_relays.go b/sdk/event_relays.go new file mode 100644 index 0000000..4716eaa --- /dev/null +++ b/sdk/event_relays.go @@ -0,0 +1,115 @@ +package sdk + +import ( + "encoding/binary" + "encoding/hex" + "fmt" +) + +const eventRelayPrefix = byte('r') + +func makeEventRelayKey(eventID []byte) []byte { + // format: 'r' + first 8 bytes of event ID + key := make([]byte, 9) + key[0] = eventRelayPrefix + copy(key[1:], eventID[:8]) + return key +} + +func encodeRelayList(relays []string) []byte { + totalSize := 0 + for _, relay := range relays { + totalSize += 2 + len(relay) // 2 bytes for length prefix + } + + buf := make([]byte, totalSize) + offset := 0 + + for _, relay := range relays { + binary.LittleEndian.PutUint16(buf[offset:], uint16(len(relay))) + offset += 2 + copy(buf[offset:], relay) + offset += len(relay) + } + + return buf +} + +func decodeRelayList(data []byte) []string { + relays := make([]string, 0) + offset := 0 + + for offset < len(data) { + if offset+2 > len(data) { + return nil // malformed + } + + length := int(binary.LittleEndian.Uint16(data[offset:])) + offset += 2 + + if offset+length > len(data) { + return nil // malformed + } + + relay := string(data[offset : offset+length]) + relays = append(relays, relay) + offset += length + } + + return relays +} + +func (sys *System) trackEventRelayCommon(eventID string, relay string) { + // decode the event ID hex into bytes + idBytes, err := hex.DecodeString(eventID) + if err != nil || len(idBytes) < 8 { + return + } + + // get the key for this event + key := makeEventRelayKey(idBytes) + + // update the relay list atomically + sys.KVStore.Update(key, func(data []byte) ([]byte, error) { + var relays []string + if data != nil { + relays = decodeRelayList(data) + } else { + relays = make([]string, 0, 1) + } + + // check if relay is already in list + for _, r := range relays { + if r == relay { + return data, nil // no change needed + } + } + + // append new relay + relays = append(relays, relay) + return encodeRelayList(relays), nil + }) +} + +// GetEventRelays returns all known relay URLs that have been seen to carry the given event. +func (sys *System) GetEventRelays(eventID string) ([]string, error) { + // decode the event ID hex into bytes + idBytes, err := hex.DecodeString(eventID) + if err != nil || len(idBytes) < 8 { + return nil, fmt.Errorf("invalid event id") + } + + // get the key for this event + key := makeEventRelayKey(idBytes) + + // get stored relay list + data, err := sys.KVStore.Get(key) + if err != nil { + return nil, err + } + if data == nil { + return nil, nil + } + + return decodeRelayList(data), nil +} diff --git a/sdk/kvstore/badger/store.go b/sdk/kvstore/badger/store.go index db957ea..57b55c4 100644 --- a/sdk/kvstore/badger/store.go +++ b/sdk/kvstore/badger/store.go @@ -58,27 +58,31 @@ func (s *Store) Close() error { return s.db.Close() } -func (s *Store) Scan(prefix []byte, fn func(key []byte, value []byte) bool) error { - return s.db.View(func(txn *badger.Txn) error { - it := txn.NewIterator(badger.DefaultIteratorOptions) - defer it.Close() - - for it.Seek(prefix); it.ValidForPrefix(prefix); it.Next() { - item := it.Item() - err := item.Value(func(v []byte) error { - k := item.Key() - if !fn(k, v) { - return badger.ErrStopIteration - } +func (s *Store) Update(key []byte, f func([]byte) ([]byte, error)) error { + return s.db.Update(func(txn *badger.Txn) error { + var val []byte + item, err := txn.Get(key) + if err == nil { + err = item.Value(func(v []byte) error { + val = make([]byte, len(v)) + copy(val, v) return nil }) - if err == badger.ErrStopIteration { - break - } if err != nil { return err } + } else if err != badger.ErrKeyNotFound { + return err } - return nil + + newVal, err := f(val) + if err != nil { + return err + } + + if newVal == nil { + return txn.Delete(key) + } + return txn.Set(key, newVal) }) } diff --git a/sdk/kvstore/interface.go b/sdk/kvstore/interface.go index c963a4c..c16b50a 100644 --- a/sdk/kvstore/interface.go +++ b/sdk/kvstore/interface.go @@ -14,7 +14,9 @@ type KVStore interface { // Close releases any resources held by the store Close() error - // Scan iterates through all keys with the given prefix. - // For each key-value pair, fn is called. If fn returns false, iteration stops. - Scan(prefix []byte, fn func(key []byte, value []byte) bool) error + // Update atomically modifies a value for a given key. + // The function f receives the current value (nil if not found) + // and returns the new value to be set. + // If f returns nil, the key is deleted. + Update(key []byte, f func([]byte) ([]byte, error)) error } diff --git a/sdk/kvstore/lmdb/store.go b/sdk/kvstore/lmdb/store.go index f8374d6..1be0e21 100644 --- a/sdk/kvstore/lmdb/store.go +++ b/sdk/kvstore/lmdb/store.go @@ -91,22 +91,26 @@ func (s *Store) Close() error { return nil } -func (s *Store) Scan(prefix []byte, fn func(key []byte, value []byte) bool) error { - return s.env.View(func(txn *lmdb.Txn) error { - cursor, err := txn.OpenCursor(s.dbi) +func (s *Store) Update(key []byte, f func([]byte) ([]byte, error)) error { + return s.env.Update(func(txn *lmdb.Txn) error { + var val []byte + v, err := txn.Get(s.dbi, key) + if err == nil { + // make a copy since v is only valid during the transaction + val = make([]byte, len(v)) + copy(val, v) + } else if !lmdb.IsNotFound(err) { + return err + } + + newVal, err := f(val) if err != nil { return err } - defer cursor.Close() - for k, v, err := cursor.Get(prefix, nil, lmdb.SetRange); err == nil; k, v, err = cursor.Get(nil, nil, lmdb.Next) { - if !bytes.HasPrefix(k, prefix) { - break - } - if !fn(k, v) { - break - } + if newVal == nil { + return txn.Del(s.dbi, key, nil) } - return nil + return txn.Put(s.dbi, key, newVal, 0) }) } diff --git a/sdk/kvstore/memory/store.go b/sdk/kvstore/memory/store.go index 524083a..ff8a807 100644 --- a/sdk/kvstore/memory/store.go +++ b/sdk/kvstore/memory/store.go @@ -57,17 +57,29 @@ func (s *Store) Close() error { return nil } -func (s *Store) Scan(prefix []byte, fn func(key []byte, value []byte) bool) error { - s.RLock() - defer s.RUnlock() +func (s *Store) Update(key []byte, f func([]byte) ([]byte, error)) error { + s.Lock() + defer s.Unlock() - prefixStr := string(prefix) - for k, v := range s.data { - if strings.HasPrefix(k, prefixStr) { - if !fn([]byte(k), v) { - break - } - } + var val []byte + if v, ok := s.data[string(key)]; ok { + // Return a copy to prevent modification of stored data + val = make([]byte, len(v)) + copy(val, v) + } + + newVal, err := f(val) + if err != nil { + return err + } + + if newVal == nil { + delete(s.data, string(key)) + } else { + // Store a copy to prevent modification of stored data + cp := make([]byte, len(newVal)) + copy(cp, newVal) + s.data[string(key)] = cp } return nil } diff --git a/sdk/system.go b/sdk/system.go index 7976db5..541b191 100644 --- a/sdk/system.go +++ b/sdk/system.go @@ -139,33 +139,6 @@ func (sys *System) Close() { } } -// GetEventRelays returns all known relay URLs that have been seen to carry the given event. -func (sys *System) GetEventRelays(eventID string) ([]string, error) { - // decode the event ID hex into bytes - idBytes, err := hex.DecodeString(eventID) - if err != nil || len(idBytes) < 8 { - return nil, fmt.Errorf("invalid event id") - } - - // create prefix for scanning: 'r' + first 8 bytes of event ID - prefix := make([]byte, 9) - prefix[0] = eventRelayPrefix - copy(prefix[1:], idBytes[:8]) - - relays := make([]string, 0) - err = sys.KVStore.Scan(prefix, func(key []byte, value []byte) bool { - // extract relay URL from key (everything after prefix) - relay := string(key[9:]) - relays = append(relays, relay) - return true - }) - if err != nil { - return nil, err - } - - return relays, nil -} - func WithHintsDB(hdb hints.HintsDB) SystemModifier { return func(sys *System) { sys.Hints = hdb diff --git a/sdk/tracker.go b/sdk/tracker.go index ce32047..78362b0 100644 --- a/sdk/tracker.go +++ b/sdk/tracker.go @@ -1,7 +1,6 @@ package sdk import ( - "encoding/hex" "net/url" "github.com/nbd-wtf/go-nostr" @@ -110,37 +109,10 @@ func (sys *System) TrackEventHints(ie nostr.RelayEvent) { } } -const eventRelayPrefix = byte('r') - -func makeEventRelayKey(eventID []byte, relay string) []byte { - // Format: 'r' + first 8 bytes of event ID + relay URL - key := make([]byte, 1+8+len(relay)) - key[0] = eventRelayPrefix - copy(key[1:], eventID[:8]) - copy(key[9:], relay) - return key -} - func (sys *System) TrackEventRelays(ie nostr.RelayEvent) { - // decode the event ID hex into bytes - idBytes, err := hex.DecodeString(ie.ID) - if err != nil || len(idBytes) < 8 { - return - } - - // store with prefix + eventid + relay format - key := makeEventRelayKey(idBytes, ie.Relay.URL) - sys.KVStore.Set(key, nil) // value is not needed since relay is in key + sys.trackEventRelayCommon(ie.ID, ie.Relay.URL) } func (sys *System) TrackEventRelaysD(relay, id string) { - // decode the event ID hex into bytes - idBytes, err := hex.DecodeString(id) - if err != nil || len(idBytes) < 8 { - return - } - - // store with prefix + eventid + relay format - key := makeEventRelayKey(idBytes, relay) - sys.KVStore.Set(key, nil) // value is not needed since relay is in key + sys.trackEventRelayCommon(id, relay) }