diff --git a/app/src/main/java/com/beemdevelopment/aegis/crypto/CryptoUtils.java b/app/src/main/java/com/beemdevelopment/aegis/crypto/CryptoUtils.java index d5720b54..0497d8ed 100644 --- a/app/src/main/java/com/beemdevelopment/aegis/crypto/CryptoUtils.java +++ b/app/src/main/java/com/beemdevelopment/aegis/crypto/CryptoUtils.java @@ -35,10 +35,14 @@ public class CryptoUtils { public static final int CRYPTO_SCRYPT_r = 8; public static final int CRYPTO_SCRYPT_p = 1; + public static SecretKey deriveKey(byte[] input, SCryptParameters params) { + byte[] keyBytes = SCrypt.generate(input, params.getSalt(), params.getN(), params.getR(), params.getP(), CRYPTO_AEAD_KEY_SIZE); + return new SecretKeySpec(keyBytes, 0, keyBytes.length, "AES"); + } + public static SecretKey deriveKey(char[] password, SCryptParameters params) { byte[] bytes = toBytes(password); - byte[] keyBytes = SCrypt.generate(bytes, params.getSalt(), params.getN(), params.getR(), params.getP(), CRYPTO_AEAD_KEY_SIZE); - return new SecretKeySpec(keyBytes, 0, keyBytes.length, "AES"); + return deriveKey(bytes, params); } public static Cipher createEncryptCipher(SecretKey key) @@ -123,6 +127,8 @@ public class CryptoUtils { private static byte[] toBytes(char[] chars) { CharBuffer charBuf = CharBuffer.wrap(chars); ByteBuffer byteBuf = StandardCharsets.UTF_8.encode(charBuf); - return byteBuf.array(); + byte[] bytes = new byte[byteBuf.limit()]; + byteBuf.get(bytes); + return bytes; } } diff --git a/app/src/test/java/com/beemdevelopment/aegis/SCryptTest.java b/app/src/test/java/com/beemdevelopment/aegis/SCryptTest.java new file mode 100644 index 00000000..eacad2e2 --- /dev/null +++ b/app/src/test/java/com/beemdevelopment/aegis/SCryptTest.java @@ -0,0 +1,40 @@ +package com.beemdevelopment.aegis; + +import com.beemdevelopment.aegis.crypto.CryptoUtils; +import com.beemdevelopment.aegis.crypto.SCryptParameters; +import com.beemdevelopment.aegis.encoding.Hex; +import com.beemdevelopment.aegis.encoding.HexException; + +import org.junit.Test; + +import javax.crypto.SecretKey; + +import static org.junit.Assert.*; + +public class SCryptTest { + @Test + public void testTrailingNullCollision() throws HexException { + byte[] salt = new byte[0]; + SCryptParameters params = new SCryptParameters( + CryptoUtils.CRYPTO_SCRYPT_N, + CryptoUtils.CRYPTO_SCRYPT_p, + CryptoUtils.CRYPTO_SCRYPT_r, + salt + ); + + byte[] expectedKey = Hex.decode("41cd8110d0c66ede16f97ce84fd8e2bd2269c9318532a01437789dfbadd1392e"); + byte[][] inputs = new byte[][]{ + new byte[]{'t', 'e', 's', 't'}, + new byte[]{'t', 'e', 's', 't', '\0'}, + new byte[]{'t', 'e', 's', 't', '\0', '\0'}, + new byte[]{'t', 'e', 's', 't', '\0', '\0', '\0'}, + new byte[]{'t', 'e', 's', 't', '\0', '\0', '\0', '\0'}, + new byte[]{'t', 'e', 's', 't', '\0', '\0', '\0', '\0', '\0'}, + }; + + for (byte[] input : inputs) { + SecretKey key = CryptoUtils.deriveKey(input, params); + assertArrayEquals(expectedKey, key.getEncoded()); + } + } +}