diff --git a/nip04/nip04.go b/nip04/nip04.go index c2d990e..96b53a0 100644 --- a/nip04/nip04.go +++ b/nip04/nip04.go @@ -51,10 +51,12 @@ func Encrypt(message string, key []byte) (string, error) { } mode := cipher.NewCBCEncrypter(block, iv) - // PKCS5 padding - padding := block.BlockSize() - len([]byte(message))%block.BlockSize() - padtext := bytes.Repeat([]byte{byte(padding)}, padding) - paddedMsgBytes := append([]byte(message), padtext...) + plaintext := []byte(message) + + // add padding + padding := block.BlockSize() - len(plaintext)%block.BlockSize() // will be a number between 1 and 16 (inc), never 0 + padtext := bytes.Repeat([]byte{byte(padding)}, padding) // encode the padding in all the padding bytes + paddedMsgBytes := append(plaintext, padtext...) ciphertext := make([]byte, len(paddedMsgBytes)) mode.CryptBlocks(ciphertext, paddedMsgBytes) @@ -87,5 +89,9 @@ func Decrypt(content string, key []byte) (string, error) { plaintext := make([]byte, len(ciphertext)) mode.CryptBlocks(plaintext, ciphertext) - return string(plaintext[:]), nil + // remove padding + padding := int(plaintext[len(plaintext)-1]) // the padding amount is encoded in the padding bytes themselves + message := string(plaintext[0 : len(plaintext)-padding]) + + return message, nil } diff --git a/nip04/nip04_test.go b/nip04/nip04_test.go new file mode 100644 index 0000000..9f5f973 --- /dev/null +++ b/nip04/nip04_test.go @@ -0,0 +1,47 @@ +package nip04 + +import ( + "strings" + "testing" +) + +func TestEncryptionAndDecryption(t *testing.T) { + sharedSecret := make([]byte, 32) + message := "hello hellow" + + ciphertext, err := Encrypt(message, sharedSecret) + if err != nil { + t.Errorf("failed to encrypt: %s", err.Error()) + } + + plaintext, err := Decrypt(ciphertext, sharedSecret) + if err != nil { + t.Errorf("failed to decrypt: %s", err.Error()) + } + + if message != plaintext { + t.Errorf("original '%s' and decrypted '%s' messages differ", message, plaintext) + } +} + +func TestEncryptionAndDecryptionWithMultipleLengths(t *testing.T) { + sharedSecret := make([]byte, 32) + + for i := 0; i < 150; i++ { + message := strings.Repeat("a", i) + + ciphertext, err := Encrypt(message, sharedSecret) + if err != nil { + t.Errorf("failed to encrypt: %s", err.Error()) + } + + plaintext, err := Decrypt(ciphertext, sharedSecret) + if err != nil { + t.Errorf("failed to decrypt: %s", err.Error()) + } + + if message != plaintext { + t.Errorf("original '%s' and decrypted '%s' messages differ", message, plaintext) + } + } +}