diff options
Diffstat (limited to 'src/lib.rs')
-rw-r--r-- | src/lib.rs | 901 |
1 files changed, 901 insertions, 0 deletions
diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..6f4d93f --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,901 @@ +use std::iter; + +#[cfg(test)] +mod testvectors; + +type Block = [u8; 16]; +type Key = [u8; 48]; +type Tweak<'a> = &'a [&'a [u8]]; + +static NULL: Block = [0; 16]; +static ONE: Block = [128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + +pub struct Aez { + key: Key, +} + +impl Aez { + pub fn new(key: &[u8]) -> Self { + Aez { key: extract(key) } + } + + pub fn encrypt( + &self, + nonce: &[u8], + associated_data: &[&[u8]], + tau: u32, + data: &[u8], + ) -> Vec<u8> { + encrypt(&self.key, nonce, associated_data, tau, data) + } + + pub fn decrypt( + &self, + nonce: &[u8], + associated_data: &[&[u8]], + tau: u32, + data: &[u8], + ) -> Option<Vec<u8>> { + decrypt(&self.key, nonce, associated_data, tau, data) + } +} + +fn xor(lhs: &Block, rhs: &Block) -> Block { + let mut result = [0; 16]; + for ((a, b), r) in lhs.iter().zip(rhs.iter()).zip(result.iter_mut()) { + *r = a ^ b; + } + result +} + +fn and(lhs: &Block, rhs: &Block) -> Block { + let mut result = [0; 16]; + for ((a, b), r) in lhs.iter().zip(rhs.iter()).zip(result.iter_mut()) { + *r = a & b; + } + result +} + +fn or(lhs: &Block, rhs: &Block) -> Block { + let mut result = [0; 16]; + for ((a, b), r) in lhs.iter().zip(rhs.iter()).zip(result.iter_mut()) { + *r = a | b; + } + result +} + +fn lshift(block: &Block, times: u32) -> Block { + let mut block = block.clone(); + for _ in 0..times { + let mut result = [0; 16]; + for (b, r) in block.iter().zip(result.iter_mut()) { + *r = b << 1; + } + for (b, r) in block[1..].iter().zip(result.iter_mut()) { + *r = *r | ((b & 0x80) >> 7); + } + block = result; + } + block +} + +fn times(lhs: u32, block: &Block) -> Block { + match lhs { + 0 => NULL, + 1 => *block, + 2 => { + let mut result = lshift(block, 1); + if block[0] & 0x80 != 0 { + result[15] ^= 135; + } + result + } + _ if lhs % 2 == 0 => times(2, ×(lhs / 2, block)), + _ => xor(×(lhs - 1, block), block), + } +} + +fn aesenc(mut block: Block, key: &Block) -> Block { + aes::hazmat::cipher_round((&mut block).into(), key.into()); + block +} + +fn aes4(keys: &[&Block; 5], block: &Block) -> Block { + aesenc( + aesenc( + aesenc(aesenc(xor(block, keys[0]), keys[1]), keys[2]), + keys[3], + ), + keys[4], + ) +} + +fn aes10(keys: &[&Block; 11], block: &Block) -> Block { + aesenc( + aesenc( + aesenc( + aesenc( + aesenc( + aesenc( + aesenc( + aesenc( + aesenc(aesenc(xor(block, keys[0]), keys[1]), keys[2]), + keys[3], + ), + keys[4], + ), + keys[5], + ), + keys[6], + ), + keys[7], + ), + keys[8], + ), + keys[9], + ), + keys[10], + ) +} + +fn extract(key: &[u8]) -> [u8; 48] { + if key.len() == 48 { + key.try_into().unwrap() + } else { + use blake2::Digest; + type Blake2b384 = blake2::Blake2b<blake2::digest::consts::U48>; + let mut hasher = Blake2b384::new(); + hasher.update(key); + hasher.finalize().into() + } +} + +fn clip_to_bits(block: &mut Block, mut bits: usize) { + for byte in block { + if bits == 0 { + *byte = 0; + } else if bits < 8 { + *byte &= 0xff << (8 - bits); + } + bits = bits.saturating_sub(8); + } +} + +fn full_block(data: &[u8]) -> Block { + let mut result = [0; 16]; + result[..data.len()].copy_from_slice(data); + result +} + +fn pad_block(block: &Block, mut bits: usize) -> Block { + let mut block = *block; + for byte in &mut block { + if bits < 8 { + *byte |= 0x80 >> bits; + break; + } + bits = bits.saturating_sub(8); + } + block +} + +fn encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, message: &[u8]) -> Vec<u8> { + let auth_message = message + .iter() + .copied() + .chain(iter::repeat_n(0, tau as usize)) + .collect::<Vec<_>>(); + // We treat tau as bytes, but according to the spec, tau is actually in bits. + let tau_block = tau_to_block(tau * 8); + let mut tweaks = vec![&tau_block, nonce]; + tweaks.extend(ad); + if message.is_empty() { + aez_prf(key, &tweaks, tau) + } else { + encipher(key, &tweaks, &auth_message) + } +} + +fn decrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, ciphertext: &[u8]) -> Option<Vec<u8>> { + if ciphertext.len() < tau as usize { + return None; + } + + let tau_block = tau_to_block(tau * 8); + let mut tweaks = vec![&tau_block, nonce]; + tweaks.extend(ad); + + if ciphertext.len() == tau as usize { + if ciphertext == aez_prf(key, &tweaks, tau) { + return Some(Vec::new()); + } else { + return None; + } + } + + let x = decipher(key, &tweaks, ciphertext); + let (m, auth) = x.split_at(ciphertext.len() - tau as usize); + assert!(auth.len() == tau as usize); + if auth.iter().all(|x| *x == 0) { + Some(Vec::from(m)) + } else { + None + } +} + +fn encipher(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> { + if message.len() < 256 / 8 { + encipher_aez_tiny(key, tweaks, message) + } else { + encipher_aez_core(key, tweaks, message) + } +} + +fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> { + let mu = message.len() * 8; + assert!(mu < 256); + let n = mu / 2; + let delta = aez_hash(key, tweaks); + let round_count = match mu { + 8 => 24, + 16 => 16, + _ if mu < 128 => 10, + _ => 8, + }; + + let (mut left, mut right); + // We might end up having to split at a nibble, so manually adjust for that + if n % 8 == 0 { + left = full_block(&message[..n / 8]); + right = full_block(&message[n / 8..]); + } else { + left = full_block(&message[..n / 8 + 1]); + clip_to_bits(&mut left, n); + right = full_block(&message[n / 8..]); + right = lshift(&right, 4); + }; + let i = if mu >= 128 { 6 } else { 7 }; + for j in 0..round_count { + let mut right_ = xor( + &left, + &e( + 0, + i, + key, + &xor( + &xor(&delta, &pad_block(&right, n)), + &(j as u128).to_be_bytes(), + ), + ), + ); + clip_to_bits(&mut right_, n); + (left, right) = (right, right_); + } + let mut ciphertext = Vec::new(); + if n % 8 == 0 { + ciphertext.extend_from_slice(&right[..n / 8]); + ciphertext.extend_from_slice(&left[..n / 8]); + } else { + ciphertext.extend_from_slice(&right[..n / 8 + 1]); + for byte in &left[..n / 8 + 1] { + *ciphertext.last_mut().unwrap() |= byte >> 4; + ciphertext.push((byte & 0x0f) << 4); + } + ciphertext.pop(); + } + if mu < 128 { + let mut c = Block::default(); + c[..ciphertext.len()].copy_from_slice(&ciphertext); + c = xor(&c, &and(&e(0, 3, key, &xor(&delta, &or(&c, &ONE))), &ONE)); + ciphertext = Vec::from(&c[..mu / 8]); + } + assert!(ciphertext.len() == message.len()); + ciphertext +} + +fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> { + assert!(message.len() >= 32); + let delta = aez_hash(key, tweaks); + let (block_pairs, m_u, m_v, m_x, m_y, d) = split_blocks(message); + let len_v = d.saturating_sub(128); + + let mut ws = Vec::new(); + let mut xs = Vec::new(); + + for (i, (mi, mi_)) in block_pairs.iter().enumerate() { + let i = (i + 1) as i32; + let w = xor(mi, &e(1, i, key, mi_)); + let x = xor(mi_, &e(0, 0, key, &w)); + ws.push(w); + xs.push(x); + } + + let mut x = NULL; + for xi in &xs { + x = xor(&x, xi); + } + + match d { + 0 => (), + _ if d <= 127 => { + x = xor(&x, &e(0, 4, key, &pad_block(&m_u, d.into()))); + } + _ => { + x = xor(&x, &e(0, 4, key, &m_u)); + x = xor(&x, &e(0, 5, key, &pad_block(&m_v, len_v.into()))); + } + } + + let s_x = xor(&m_x, &xor(&delta, &xor(&x, &e(0, 1, key, &m_y)))); + let s_y = xor(&m_y, &e(-1, 1, key, &s_x)); + let s = xor(&s_x, &s_y); + + let mut cipher_pairs = Vec::new(); + let mut y = NULL; + for (i, (wi, xi)) in ws.iter().zip(xs.iter()).enumerate() { + let i = (i + 1) as i32; + let s_ = e(2, i, key, &s); + let yi = xor(wi, &s_); + let zi = xor(xi, &s_); + let ci_ = xor(&yi, &e(0, 0, key, &zi)); + let ci = xor(&zi, &e(1, i, key, &ci_)); + + cipher_pairs.push((ci, ci_)); + y = xor(&y, &yi); + } + + let mut c_u = [0; 16]; + let mut c_v = [0; 16]; + + match d { + 0 => (), + _ if d <= 127 => { + c_u = xor(&m_u, &e(-1, 4, key, &s)); + clip_to_bits(&mut c_u, d.into()); + y = xor(&y, &e(0, 4, key, &pad_block(&c_u, d.into()))); + } + _ => { + c_u = xor(&m_u, &e(-1, 4, key, &s)); + c_v = xor(&m_v, &e(-1, 5, key, &s)); + clip_to_bits(&mut c_v, len_v.into()); + y = xor(&y, &e(0, 4, key, &c_u)); + y = xor(&y, &e(0, 5, key, &pad_block(&c_v, len_v.into()))); + } + } + + let c_y = xor(&s_x, &e(-1, 2, key, &s_y)); + let c_x = xor(&s_y, &xor(&delta, &xor(&y, &e(0, 2, key, &c_y)))); + + let mut ciphertext = Vec::new(); + for (ci, ci_) in cipher_pairs { + ciphertext.extend_from_slice(&ci); + ciphertext.extend_from_slice(&ci_); + } + ciphertext.extend_from_slice(&c_u[..128.min(d) as usize / 8]); + ciphertext.extend_from_slice(&c_v[..len_v as usize / 8]); + ciphertext.extend_from_slice(&c_x); + ciphertext.extend_from_slice(&c_y); + ciphertext +} + +fn decipher(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> { + if message.len() < 256 / 8 { + decipher_aez_tiny(key, tweaks, message) + } else { + decipher_aez_core(key, tweaks, message) + } +} + +fn decipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> { + let mu = message.len() * 8; + assert!(mu < 256); + let n = mu / 2; + let delta = aez_hash(key, tweaks); + let round_count = match mu { + 8 => 24, + 16 => 16, + _ if mu < 128 => 10, + _ => 8, + }; + + let mut message = Vec::from(message); + if mu < 128 { + let mut c = Block::default(); + c[..message.len()].copy_from_slice(&message); + c = xor(&c, &and(&e(0, 3, key, &xor(&delta, &or(&c, &ONE))), &ONE)); + message.clear(); + message.extend(&c[..mu / 8]); + } + + let (mut left, mut right); + // We might end up having to split at a nibble, so manually adjust for that + if n % 8 == 0 { + left = full_block(&message[..n / 8]); + right = full_block(&message[n / 8..]); + } else { + left = full_block(&message[..n / 8 + 1]); + clip_to_bits(&mut left, n); + right = full_block(&message[n / 8..]); + right = lshift(&right, 4); + }; + let i = if mu >= 128 { 6 } else { 7 }; + for j in (0..round_count).rev() { + let mut right_ = xor( + &left, + &e( + 0, + i, + key, + &xor( + &xor(&delta, &pad_block(&right, n)), + &(j as u128).to_be_bytes(), + ), + ), + ); + clip_to_bits(&mut right_, n); + (left, right) = (right, right_); + } + let mut ciphertext = Vec::new(); + if n % 8 == 0 { + ciphertext.extend_from_slice(&right[..n / 8]); + ciphertext.extend_from_slice(&left[..n / 8]); + } else { + ciphertext.extend_from_slice(&right[..n / 8 + 1]); + for byte in &left[..n / 8 + 1] { + *ciphertext.last_mut().unwrap() |= byte >> 4; + ciphertext.push((byte & 0x0f) << 4); + } + ciphertext.pop(); + } + assert!(ciphertext.len() == message.len()); + ciphertext +} + +fn decipher_aez_core(key: &Key, tweaks: Tweak, cipher: &[u8]) -> Vec<u8> { + assert!(cipher.len() >= 32); + let delta = aez_hash(key, tweaks); + let (block_pairs, c_u, c_v, c_x, c_y, d) = split_blocks(cipher); + let len_v = d.saturating_sub(128); + + let mut ws = Vec::new(); + let mut ys = Vec::new(); + + for (i, (ci, ci_)) in block_pairs.iter().enumerate() { + let i = (i + 1) as i32; + let w = xor(ci, &e(1, i, key, ci_)); + let y = xor(ci_, &e(0, 0, key, &w)); + ws.push(w); + ys.push(y); + } + + let mut y = NULL; + for yi in &ys { + y = xor(&y, yi); + } + + match d { + 0 => (), + _ if d <= 127 => { + y = xor(&y, &e(0, 4, key, &pad_block(&c_u, d.into()))); + } + _ => { + y = xor(&y, &e(0, 4, key, &c_u)); + y = xor(&y, &e(0, 5, key, &pad_block(&c_v, len_v.into()))); + } + } + + let s_x = xor(&c_x, &xor(&delta, &xor(&y, &e(0, 2, key, &c_y)))); + let s_y = xor(&c_y, &e(-1, 2, key, &s_x)); + let s = xor(&s_x, &s_y); + + let mut plain_pairs = Vec::new(); + let mut x = NULL; + for (i, (wi, yi)) in ws.iter().zip(ys.iter()).enumerate() { + let i = (i + 1) as i32; + let s_ = e(2, i, key, &s); + let xi = xor(wi, &s_); + let zi = xor(yi, &s_); + let mi_ = xor(&xi, &e(0, 0, key, &zi)); + let mi = xor(&zi, &e(1, i, key, &mi_)); + + plain_pairs.push((mi, mi_)); + x = xor(&x, &xi); + } + + let mut m_u = [0; 16]; + let mut m_v = [0; 16]; + + match d { + 0 => (), + _ if d <= 127 => { + m_u = xor(&c_u, &e(-1, 4, key, &s)); + clip_to_bits(&mut m_u, d.into()); + x = xor(&x, &e(0, 4, key, &pad_block(&m_u, d.into()))); + } + _ => { + m_u = xor(&c_u, &e(-1, 4, key, &s)); + m_v = xor(&c_v, &e(-1, 5, key, &s)); + clip_to_bits(&mut m_v, len_v.into()); + x = xor(&x, &e(0, 4, key, &m_u)); + x = xor(&x, &e(0, 5, key, &pad_block(&m_v, len_v.into()))); + } + } + + let m_y = xor(&s_x, &e(-1, 1, key, &s_y)); + let m_x = xor(&s_y, &xor(&delta, &xor(&x, &e(0, 1, key, &m_y)))); + + let mut message = Vec::new(); + for (mi, mi_) in plain_pairs { + message.extend_from_slice(&mi); + message.extend_from_slice(&mi_); + } + message.extend_from_slice(&m_u[..128.min(d) as usize / 8]); + message.extend_from_slice(&m_v[..len_v as usize / 8]); + message.extend_from_slice(&m_x); + message.extend_from_slice(&m_y); + message +} + +fn split_blocks(mut message: &[u8]) -> (Vec<(Block, Block)>, Block, Block, Block, Block, u8) { + let num_blocks = (message.len() - 16 - 16) / 32; + let mut blocks = Vec::new(); + for _ in 0..num_blocks { + let (mut a, mut b) = ([0; 16], [0; 16]); + a.copy_from_slice(&message[..16]); + b.copy_from_slice(&message[16..32]); + blocks.push((a, b)); + message = &message[32..]; + } + let m_uv = &message[..message.len() - 32]; + let d = m_uv.len() * 8; + assert!(d < 256); + message = &message[m_uv.len()..]; + assert!(message.len() == 32); + + let mut m_u = [0; 16]; + let mut m_v = [0; 16]; + if d <= 127 { + m_u[..m_uv.len()].copy_from_slice(m_uv); + } else { + m_u.copy_from_slice(&m_uv[..16]); + m_v[..m_uv.len() - 16].copy_from_slice(&m_uv[16..]); + } + let mut m_x = [0; 16]; + m_x.copy_from_slice(&message[..16]); + let mut m_y = [0; 16]; + m_y.copy_from_slice(&message[16..]); + (blocks, m_u, m_v, m_x, m_y, d as u8) +} + +fn pad_to_blocks(value: &[u8]) -> Vec<Block> { + let mut blocks = Vec::new(); + for chunk in value.chunks(16) { + if chunk.len() == 16 { + blocks.push(chunk.try_into().expect("we made sure the length fits")); + } else { + let mut block = Block::default(); + for (b, v) in block.iter_mut().zip( + chunk + .iter() + .chain(iter::once(&0x80)) + .chain(iter::repeat(&0)), + ) { + *b = *v; + } + blocks.push(block) + } + } + blocks +} + +fn tau_to_block(tau: u32) -> Block { + (tau as u128).to_be_bytes() +} + +fn aez_hash(key: &Key, tweaks: Tweak) -> Block { + let mut hash = NULL; + for (i, tweak) in tweaks.iter().enumerate() { + // Adjust for zero-based vs one-based indexing + let j = i + 2 + 1; + // This is somewhat implicit in the AEZ spec, but basically for an empty string we still + // set l = 1 and then xor E_K^{j, 0}(10*). We could modify the last if branch to cover this + // as well, but then we need to fiddle with getting an empty chunk from an empty iterator. + if tweak.is_empty() { + hash = xor( + &hash, + &e( + j.try_into().unwrap(), + 0, + key, + &[128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ), + ); + } else if tweak.len() % 16 == 0 { + for (l, chunk) in tweak.chunks(16).enumerate() { + hash = xor( + &hash, + &e( + j.try_into().unwrap(), + (l + 1).try_into().unwrap(), + key, + chunk.try_into().expect("we made sure the length fits"), + ), + ); + } + } else { + let blocks = pad_to_blocks(tweak); + for (l, chunk) in blocks.iter().enumerate() { + hash = xor( + &hash, + &e( + j.try_into().unwrap(), + if l == blocks.len() - 1 { + 0 + } else { + (l + 1).try_into().unwrap() + }, + key, + chunk, + ), + ); + } + } + } + hash +} + +fn aez_prf(key: &Key, tweaks: Tweak, tau: u32) -> Vec<u8> { + let mut result = Vec::new(); + let mut index = 0u128; + let delta = aez_hash(key, tweaks); + while result.len() < tau as usize { + let block = e(-1, 3, key, &xor(&delta, &index.to_be_bytes())); + result.extend_from_slice(&block[..16.min(tau as usize - result.len())]); + index += 1; + } + result +} + +fn e(j: i32, i: i32, key: &Key, block: &Block) -> Block { + let (key_i, key_j, key_l) = split_key(key); + if j == -1 { + let k = [ + &NULL, key_i, key_j, key_l, key_i, key_j, key_l, key_i, key_j, key_l, key_i, + ]; + let delta = times(i.try_into().expect("i was negative"), key_l); + aes10(&k, &xor(block, &delta)) + } else { + let k = [&NULL, key_j, key_i, key_l, &NULL]; + let exponent = if i % 8 == 0 { i / 8 } else { i / 8 + 1 }; + let delta = xor( + &xor( + ×(j.try_into().expect("j was negative"), key_j), + ×(1 << exponent, key_i), + ), + ×((i % 8).try_into().expect("i was negative"), key_l), + ); + aes4(&k, &xor(block, &delta)) + } +} + +fn split_key(key: &Key) -> (&Block, &Block, &Block) { + let (i, jl) = key.split_at(16); + let (j, l) = jl.split_at(16); + ( + i.try_into().unwrap(), + j.try_into().unwrap(), + l.try_into().unwrap(), + ) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_xor() { + assert_eq!(xor(&[1; 16], &[2; 16]), [3; 16]); + } + + #[test] + fn test_times() { + assert_eq!( + times(0, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ); + assert_eq!( + times(1, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] + ); + assert_eq!( + times(2, &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]), + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2] + ); + assert_eq!( + times(2, &[128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]), + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 133] + ); + assert_eq!( + times(2, &[129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]), + [2, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 133] + ); + assert_eq!( + times(3, &[129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]), + [131, 0, 0, 0, 1, 128, 0, 0, 0, 3, 0, 0, 0, 0, 0, 132] + ); + assert_eq!( + times(4, &[129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]), + [4, 0, 0, 0, 2, 0, 0, 0, 0, 4, 0, 0, 0, 0, 1, 10] + ); + } + + #[test] + fn test_lshift() { + assert_eq!( + lshift(&[0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 1), + [0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ); + assert_eq!( + lshift(&[0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 4), + [0x10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ); + assert_eq!( + lshift(&[0x0A, 0xB0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 4), + [0xAB, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ); + assert_eq!( + lshift(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 8), + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0] + ); + } + + #[test] + fn test_pad_block() { + assert_eq!( + pad_block(&[0; 16], 0), + [0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ); + assert_eq!( + pad_block(&[0; 16], 1), + [0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ); + assert_eq!( + pad_block(&[0; 16], 8), + [0, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ); + } + + #[test] + fn test_clip_to_bits() { + let mut block; + + block = [0xFF; 16]; + clip_to_bits(&mut block, 0); + assert_eq!(block, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); + + block = [0xFF; 16]; + clip_to_bits(&mut block, 4); + assert_eq!(block, [0xF0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); + + block = [0xFF; 16]; + clip_to_bits(&mut block, 8); + assert_eq!(block, [0xFF, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); + + block = [0xFF; 16]; + clip_to_bits(&mut block, 9); + assert_eq!( + block, + [0xFF, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ); + } + + #[test] + fn test_extract() { + for (a, b) in testvectors::EXTRACT_VECTORS { + let a = hex::decode(a).unwrap(); + let b = hex::decode(b).unwrap(); + assert_eq!(extract(&a), b.as_slice()); + } + } + + #[test] + fn test_e() { + for (k, j, i, a, b) in testvectors::E_VECTORS { + let name = format!("e({j}, {i}, {k}, {a})"); + let k = hex::decode(k).unwrap(); + let k = k.as_slice().try_into().unwrap(); + let a = hex::decode(a).unwrap(); + let a = a.as_slice().try_into().unwrap(); + let b = hex::decode(b).unwrap(); + assert_eq!(&e(*j, *i, k, a), b.as_slice(), "{name}"); + } + } + + #[test] + fn test_aez_hash() { + for (k, tau, tw, v) in testvectors::HASH_VECTORS { + let name = format!("aez_hash({k}, {tau}, {tw:?})"); + let k = hex::decode(k).unwrap(); + let k = k.as_slice().try_into().unwrap(); + let v = hex::decode(v).unwrap(); + + let mut tweaks = vec![Vec::from(tau_to_block(*tau))]; + for t in *tw { + tweaks.push(hex::decode(t).unwrap()); + } + let tweaks = tweaks.iter().map(Vec::as_slice).collect::<Vec<_>>(); + + assert_eq!(&aez_hash(&k, &tweaks), v.as_slice(), "{name}"); + } + } + + #[test] + fn test_encrypt() { + let mut failed = 0; + let mut succ = 0; + for (k, n, ads, tau, m, c) in testvectors::ENCRYPT_VECTORS { + let name = format!("encrypt({k}, {n}, {ads:?}, {tau}, {m})"); + let k = hex::decode(k).unwrap(); + let k = k.as_slice().try_into().unwrap(); + let n = hex::decode(n).unwrap(); + + let mut ad = Vec::new(); + for i in *ads { + ad.push(hex::decode(i).unwrap()); + } + let ad = ad.iter().map(Vec::as_slice).collect::<Vec<_>>(); + + let m = hex::decode(m).unwrap(); + let c = hex::decode(c).unwrap(); + + if &encrypt(&k, &n, &ad, *tau, &m) == &c { + println!("+ {name}"); + succ += 1; + } else { + println!("- {name}"); + failed += 1; + } + } + println!("{succ} succeeded, {failed} failed"); + assert_eq!(failed, 0); + } + + #[test] + fn test_decrypt() { + let mut failed = 0; + let mut succ = 0; + for (k, n, ads, tau, m, c) in testvectors::ENCRYPT_VECTORS { + let name = format!("decrypt({k}, {n}, {ads:?}, {tau}, {c})"); + let k = hex::decode(k).unwrap(); + let k = k.as_slice().try_into().unwrap(); + let n = hex::decode(n).unwrap(); + + let mut ad = Vec::new(); + for i in *ads { + ad.push(hex::decode(i).unwrap()); + } + let ad = ad.iter().map(Vec::as_slice).collect::<Vec<_>>(); + + let m = hex::decode(m).unwrap(); + let c = hex::decode(c).unwrap(); + + if decrypt(&k, &n, &ad, *tau, &c) == Some(m) { + println!("+ {name}"); + succ += 1; + } else { + println!("- {name}"); + failed += 1; + } + } + println!("{succ} succeeded, {failed} failed"); + assert_eq!(failed, 0); + } + + #[test] + fn test_encrypt_decrypt() { + let aez = Aez::new(b"foobar"); + let cipher = aez.encrypt(&[0], &[b"foobar"], 16, b"hi"); + let plain = aez.decrypt(&[0], &[b"foobar"], 16, &cipher).unwrap(); + assert_eq!(plain, b"hi"); + } +} |