aboutsummaryrefslogtreecommitdiff
path: root/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib.rs')
-rw-r--r--src/lib.rs901
1 files changed, 901 insertions, 0 deletions
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..6f4d93f
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,901 @@
+use std::iter;
+
+#[cfg(test)]
+mod testvectors;
+
+type Block = [u8; 16];
+type Key = [u8; 48];
+type Tweak<'a> = &'a [&'a [u8]];
+
+static NULL: Block = [0; 16];
+static ONE: Block = [128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
+
+pub struct Aez {
+ key: Key,
+}
+
+impl Aez {
+ pub fn new(key: &[u8]) -> Self {
+ Aez { key: extract(key) }
+ }
+
+ pub fn encrypt(
+ &self,
+ nonce: &[u8],
+ associated_data: &[&[u8]],
+ tau: u32,
+ data: &[u8],
+ ) -> Vec<u8> {
+ encrypt(&self.key, nonce, associated_data, tau, data)
+ }
+
+ pub fn decrypt(
+ &self,
+ nonce: &[u8],
+ associated_data: &[&[u8]],
+ tau: u32,
+ data: &[u8],
+ ) -> Option<Vec<u8>> {
+ decrypt(&self.key, nonce, associated_data, tau, data)
+ }
+}
+
+fn xor(lhs: &Block, rhs: &Block) -> Block {
+ let mut result = [0; 16];
+ for ((a, b), r) in lhs.iter().zip(rhs.iter()).zip(result.iter_mut()) {
+ *r = a ^ b;
+ }
+ result
+}
+
+fn and(lhs: &Block, rhs: &Block) -> Block {
+ let mut result = [0; 16];
+ for ((a, b), r) in lhs.iter().zip(rhs.iter()).zip(result.iter_mut()) {
+ *r = a & b;
+ }
+ result
+}
+
+fn or(lhs: &Block, rhs: &Block) -> Block {
+ let mut result = [0; 16];
+ for ((a, b), r) in lhs.iter().zip(rhs.iter()).zip(result.iter_mut()) {
+ *r = a | b;
+ }
+ result
+}
+
+fn lshift(block: &Block, times: u32) -> Block {
+ let mut block = block.clone();
+ for _ in 0..times {
+ let mut result = [0; 16];
+ for (b, r) in block.iter().zip(result.iter_mut()) {
+ *r = b << 1;
+ }
+ for (b, r) in block[1..].iter().zip(result.iter_mut()) {
+ *r = *r | ((b & 0x80) >> 7);
+ }
+ block = result;
+ }
+ block
+}
+
+fn times(lhs: u32, block: &Block) -> Block {
+ match lhs {
+ 0 => NULL,
+ 1 => *block,
+ 2 => {
+ let mut result = lshift(block, 1);
+ if block[0] & 0x80 != 0 {
+ result[15] ^= 135;
+ }
+ result
+ }
+ _ if lhs % 2 == 0 => times(2, &times(lhs / 2, block)),
+ _ => xor(&times(lhs - 1, block), block),
+ }
+}
+
+fn aesenc(mut block: Block, key: &Block) -> Block {
+ aes::hazmat::cipher_round((&mut block).into(), key.into());
+ block
+}
+
+fn aes4(keys: &[&Block; 5], block: &Block) -> Block {
+ aesenc(
+ aesenc(
+ aesenc(aesenc(xor(block, keys[0]), keys[1]), keys[2]),
+ keys[3],
+ ),
+ keys[4],
+ )
+}
+
+fn aes10(keys: &[&Block; 11], block: &Block) -> Block {
+ aesenc(
+ aesenc(
+ aesenc(
+ aesenc(
+ aesenc(
+ aesenc(
+ aesenc(
+ aesenc(
+ aesenc(aesenc(xor(block, keys[0]), keys[1]), keys[2]),
+ keys[3],
+ ),
+ keys[4],
+ ),
+ keys[5],
+ ),
+ keys[6],
+ ),
+ keys[7],
+ ),
+ keys[8],
+ ),
+ keys[9],
+ ),
+ keys[10],
+ )
+}
+
+fn extract(key: &[u8]) -> [u8; 48] {
+ if key.len() == 48 {
+ key.try_into().unwrap()
+ } else {
+ use blake2::Digest;
+ type Blake2b384 = blake2::Blake2b<blake2::digest::consts::U48>;
+ let mut hasher = Blake2b384::new();
+ hasher.update(key);
+ hasher.finalize().into()
+ }
+}
+
+fn clip_to_bits(block: &mut Block, mut bits: usize) {
+ for byte in block {
+ if bits == 0 {
+ *byte = 0;
+ } else if bits < 8 {
+ *byte &= 0xff << (8 - bits);
+ }
+ bits = bits.saturating_sub(8);
+ }
+}
+
+fn full_block(data: &[u8]) -> Block {
+ let mut result = [0; 16];
+ result[..data.len()].copy_from_slice(data);
+ result
+}
+
+fn pad_block(block: &Block, mut bits: usize) -> Block {
+ let mut block = *block;
+ for byte in &mut block {
+ if bits < 8 {
+ *byte |= 0x80 >> bits;
+ break;
+ }
+ bits = bits.saturating_sub(8);
+ }
+ block
+}
+
+fn encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, message: &[u8]) -> Vec<u8> {
+ let auth_message = message
+ .iter()
+ .copied()
+ .chain(iter::repeat_n(0, tau as usize))
+ .collect::<Vec<_>>();
+ // We treat tau as bytes, but according to the spec, tau is actually in bits.
+ let tau_block = tau_to_block(tau * 8);
+ let mut tweaks = vec![&tau_block, nonce];
+ tweaks.extend(ad);
+ if message.is_empty() {
+ aez_prf(key, &tweaks, tau)
+ } else {
+ encipher(key, &tweaks, &auth_message)
+ }
+}
+
+fn decrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, ciphertext: &[u8]) -> Option<Vec<u8>> {
+ if ciphertext.len() < tau as usize {
+ return None;
+ }
+
+ let tau_block = tau_to_block(tau * 8);
+ let mut tweaks = vec![&tau_block, nonce];
+ tweaks.extend(ad);
+
+ if ciphertext.len() == tau as usize {
+ if ciphertext == aez_prf(key, &tweaks, tau) {
+ return Some(Vec::new());
+ } else {
+ return None;
+ }
+ }
+
+ let x = decipher(key, &tweaks, ciphertext);
+ let (m, auth) = x.split_at(ciphertext.len() - tau as usize);
+ assert!(auth.len() == tau as usize);
+ if auth.iter().all(|x| *x == 0) {
+ Some(Vec::from(m))
+ } else {
+ None
+ }
+}
+
+fn encipher(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
+ if message.len() < 256 / 8 {
+ encipher_aez_tiny(key, tweaks, message)
+ } else {
+ encipher_aez_core(key, tweaks, message)
+ }
+}
+
+fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
+ let mu = message.len() * 8;
+ assert!(mu < 256);
+ let n = mu / 2;
+ let delta = aez_hash(key, tweaks);
+ let round_count = match mu {
+ 8 => 24,
+ 16 => 16,
+ _ if mu < 128 => 10,
+ _ => 8,
+ };
+
+ let (mut left, mut right);
+ // We might end up having to split at a nibble, so manually adjust for that
+ if n % 8 == 0 {
+ left = full_block(&message[..n / 8]);
+ right = full_block(&message[n / 8..]);
+ } else {
+ left = full_block(&message[..n / 8 + 1]);
+ clip_to_bits(&mut left, n);
+ right = full_block(&message[n / 8..]);
+ right = lshift(&right, 4);
+ };
+ let i = if mu >= 128 { 6 } else { 7 };
+ for j in 0..round_count {
+ let mut right_ = xor(
+ &left,
+ &e(
+ 0,
+ i,
+ key,
+ &xor(
+ &xor(&delta, &pad_block(&right, n)),
+ &(j as u128).to_be_bytes(),
+ ),
+ ),
+ );
+ clip_to_bits(&mut right_, n);
+ (left, right) = (right, right_);
+ }
+ let mut ciphertext = Vec::new();
+ if n % 8 == 0 {
+ ciphertext.extend_from_slice(&right[..n / 8]);
+ ciphertext.extend_from_slice(&left[..n / 8]);
+ } else {
+ ciphertext.extend_from_slice(&right[..n / 8 + 1]);
+ for byte in &left[..n / 8 + 1] {
+ *ciphertext.last_mut().unwrap() |= byte >> 4;
+ ciphertext.push((byte & 0x0f) << 4);
+ }
+ ciphertext.pop();
+ }
+ if mu < 128 {
+ let mut c = Block::default();
+ c[..ciphertext.len()].copy_from_slice(&ciphertext);
+ c = xor(&c, &and(&e(0, 3, key, &xor(&delta, &or(&c, &ONE))), &ONE));
+ ciphertext = Vec::from(&c[..mu / 8]);
+ }
+ assert!(ciphertext.len() == message.len());
+ ciphertext
+}
+
+fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
+ assert!(message.len() >= 32);
+ let delta = aez_hash(key, tweaks);
+ let (block_pairs, m_u, m_v, m_x, m_y, d) = split_blocks(message);
+ let len_v = d.saturating_sub(128);
+
+ let mut ws = Vec::new();
+ let mut xs = Vec::new();
+
+ for (i, (mi, mi_)) in block_pairs.iter().enumerate() {
+ let i = (i + 1) as i32;
+ let w = xor(mi, &e(1, i, key, mi_));
+ let x = xor(mi_, &e(0, 0, key, &w));
+ ws.push(w);
+ xs.push(x);
+ }
+
+ let mut x = NULL;
+ for xi in &xs {
+ x = xor(&x, xi);
+ }
+
+ match d {
+ 0 => (),
+ _ if d <= 127 => {
+ x = xor(&x, &e(0, 4, key, &pad_block(&m_u, d.into())));
+ }
+ _ => {
+ x = xor(&x, &e(0, 4, key, &m_u));
+ x = xor(&x, &e(0, 5, key, &pad_block(&m_v, len_v.into())));
+ }
+ }
+
+ let s_x = xor(&m_x, &xor(&delta, &xor(&x, &e(0, 1, key, &m_y))));
+ let s_y = xor(&m_y, &e(-1, 1, key, &s_x));
+ let s = xor(&s_x, &s_y);
+
+ let mut cipher_pairs = Vec::new();
+ let mut y = NULL;
+ for (i, (wi, xi)) in ws.iter().zip(xs.iter()).enumerate() {
+ let i = (i + 1) as i32;
+ let s_ = e(2, i, key, &s);
+ let yi = xor(wi, &s_);
+ let zi = xor(xi, &s_);
+ let ci_ = xor(&yi, &e(0, 0, key, &zi));
+ let ci = xor(&zi, &e(1, i, key, &ci_));
+
+ cipher_pairs.push((ci, ci_));
+ y = xor(&y, &yi);
+ }
+
+ let mut c_u = [0; 16];
+ let mut c_v = [0; 16];
+
+ match d {
+ 0 => (),
+ _ if d <= 127 => {
+ c_u = xor(&m_u, &e(-1, 4, key, &s));
+ clip_to_bits(&mut c_u, d.into());
+ y = xor(&y, &e(0, 4, key, &pad_block(&c_u, d.into())));
+ }
+ _ => {
+ c_u = xor(&m_u, &e(-1, 4, key, &s));
+ c_v = xor(&m_v, &e(-1, 5, key, &s));
+ clip_to_bits(&mut c_v, len_v.into());
+ y = xor(&y, &e(0, 4, key, &c_u));
+ y = xor(&y, &e(0, 5, key, &pad_block(&c_v, len_v.into())));
+ }
+ }
+
+ let c_y = xor(&s_x, &e(-1, 2, key, &s_y));
+ let c_x = xor(&s_y, &xor(&delta, &xor(&y, &e(0, 2, key, &c_y))));
+
+ let mut ciphertext = Vec::new();
+ for (ci, ci_) in cipher_pairs {
+ ciphertext.extend_from_slice(&ci);
+ ciphertext.extend_from_slice(&ci_);
+ }
+ ciphertext.extend_from_slice(&c_u[..128.min(d) as usize / 8]);
+ ciphertext.extend_from_slice(&c_v[..len_v as usize / 8]);
+ ciphertext.extend_from_slice(&c_x);
+ ciphertext.extend_from_slice(&c_y);
+ ciphertext
+}
+
+fn decipher(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
+ if message.len() < 256 / 8 {
+ decipher_aez_tiny(key, tweaks, message)
+ } else {
+ decipher_aez_core(key, tweaks, message)
+ }
+}
+
+fn decipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
+ let mu = message.len() * 8;
+ assert!(mu < 256);
+ let n = mu / 2;
+ let delta = aez_hash(key, tweaks);
+ let round_count = match mu {
+ 8 => 24,
+ 16 => 16,
+ _ if mu < 128 => 10,
+ _ => 8,
+ };
+
+ let mut message = Vec::from(message);
+ if mu < 128 {
+ let mut c = Block::default();
+ c[..message.len()].copy_from_slice(&message);
+ c = xor(&c, &and(&e(0, 3, key, &xor(&delta, &or(&c, &ONE))), &ONE));
+ message.clear();
+ message.extend(&c[..mu / 8]);
+ }
+
+ let (mut left, mut right);
+ // We might end up having to split at a nibble, so manually adjust for that
+ if n % 8 == 0 {
+ left = full_block(&message[..n / 8]);
+ right = full_block(&message[n / 8..]);
+ } else {
+ left = full_block(&message[..n / 8 + 1]);
+ clip_to_bits(&mut left, n);
+ right = full_block(&message[n / 8..]);
+ right = lshift(&right, 4);
+ };
+ let i = if mu >= 128 { 6 } else { 7 };
+ for j in (0..round_count).rev() {
+ let mut right_ = xor(
+ &left,
+ &e(
+ 0,
+ i,
+ key,
+ &xor(
+ &xor(&delta, &pad_block(&right, n)),
+ &(j as u128).to_be_bytes(),
+ ),
+ ),
+ );
+ clip_to_bits(&mut right_, n);
+ (left, right) = (right, right_);
+ }
+ let mut ciphertext = Vec::new();
+ if n % 8 == 0 {
+ ciphertext.extend_from_slice(&right[..n / 8]);
+ ciphertext.extend_from_slice(&left[..n / 8]);
+ } else {
+ ciphertext.extend_from_slice(&right[..n / 8 + 1]);
+ for byte in &left[..n / 8 + 1] {
+ *ciphertext.last_mut().unwrap() |= byte >> 4;
+ ciphertext.push((byte & 0x0f) << 4);
+ }
+ ciphertext.pop();
+ }
+ assert!(ciphertext.len() == message.len());
+ ciphertext
+}
+
+fn decipher_aez_core(key: &Key, tweaks: Tweak, cipher: &[u8]) -> Vec<u8> {
+ assert!(cipher.len() >= 32);
+ let delta = aez_hash(key, tweaks);
+ let (block_pairs, c_u, c_v, c_x, c_y, d) = split_blocks(cipher);
+ let len_v = d.saturating_sub(128);
+
+ let mut ws = Vec::new();
+ let mut ys = Vec::new();
+
+ for (i, (ci, ci_)) in block_pairs.iter().enumerate() {
+ let i = (i + 1) as i32;
+ let w = xor(ci, &e(1, i, key, ci_));
+ let y = xor(ci_, &e(0, 0, key, &w));
+ ws.push(w);
+ ys.push(y);
+ }
+
+ let mut y = NULL;
+ for yi in &ys {
+ y = xor(&y, yi);
+ }
+
+ match d {
+ 0 => (),
+ _ if d <= 127 => {
+ y = xor(&y, &e(0, 4, key, &pad_block(&c_u, d.into())));
+ }
+ _ => {
+ y = xor(&y, &e(0, 4, key, &c_u));
+ y = xor(&y, &e(0, 5, key, &pad_block(&c_v, len_v.into())));
+ }
+ }
+
+ let s_x = xor(&c_x, &xor(&delta, &xor(&y, &e(0, 2, key, &c_y))));
+ let s_y = xor(&c_y, &e(-1, 2, key, &s_x));
+ let s = xor(&s_x, &s_y);
+
+ let mut plain_pairs = Vec::new();
+ let mut x = NULL;
+ for (i, (wi, yi)) in ws.iter().zip(ys.iter()).enumerate() {
+ let i = (i + 1) as i32;
+ let s_ = e(2, i, key, &s);
+ let xi = xor(wi, &s_);
+ let zi = xor(yi, &s_);
+ let mi_ = xor(&xi, &e(0, 0, key, &zi));
+ let mi = xor(&zi, &e(1, i, key, &mi_));
+
+ plain_pairs.push((mi, mi_));
+ x = xor(&x, &xi);
+ }
+
+ let mut m_u = [0; 16];
+ let mut m_v = [0; 16];
+
+ match d {
+ 0 => (),
+ _ if d <= 127 => {
+ m_u = xor(&c_u, &e(-1, 4, key, &s));
+ clip_to_bits(&mut m_u, d.into());
+ x = xor(&x, &e(0, 4, key, &pad_block(&m_u, d.into())));
+ }
+ _ => {
+ m_u = xor(&c_u, &e(-1, 4, key, &s));
+ m_v = xor(&c_v, &e(-1, 5, key, &s));
+ clip_to_bits(&mut m_v, len_v.into());
+ x = xor(&x, &e(0, 4, key, &m_u));
+ x = xor(&x, &e(0, 5, key, &pad_block(&m_v, len_v.into())));
+ }
+ }
+
+ let m_y = xor(&s_x, &e(-1, 1, key, &s_y));
+ let m_x = xor(&s_y, &xor(&delta, &xor(&x, &e(0, 1, key, &m_y))));
+
+ let mut message = Vec::new();
+ for (mi, mi_) in plain_pairs {
+ message.extend_from_slice(&mi);
+ message.extend_from_slice(&mi_);
+ }
+ message.extend_from_slice(&m_u[..128.min(d) as usize / 8]);
+ message.extend_from_slice(&m_v[..len_v as usize / 8]);
+ message.extend_from_slice(&m_x);
+ message.extend_from_slice(&m_y);
+ message
+}
+
+fn split_blocks(mut message: &[u8]) -> (Vec<(Block, Block)>, Block, Block, Block, Block, u8) {
+ let num_blocks = (message.len() - 16 - 16) / 32;
+ let mut blocks = Vec::new();
+ for _ in 0..num_blocks {
+ let (mut a, mut b) = ([0; 16], [0; 16]);
+ a.copy_from_slice(&message[..16]);
+ b.copy_from_slice(&message[16..32]);
+ blocks.push((a, b));
+ message = &message[32..];
+ }
+ let m_uv = &message[..message.len() - 32];
+ let d = m_uv.len() * 8;
+ assert!(d < 256);
+ message = &message[m_uv.len()..];
+ assert!(message.len() == 32);
+
+ let mut m_u = [0; 16];
+ let mut m_v = [0; 16];
+ if d <= 127 {
+ m_u[..m_uv.len()].copy_from_slice(m_uv);
+ } else {
+ m_u.copy_from_slice(&m_uv[..16]);
+ m_v[..m_uv.len() - 16].copy_from_slice(&m_uv[16..]);
+ }
+ let mut m_x = [0; 16];
+ m_x.copy_from_slice(&message[..16]);
+ let mut m_y = [0; 16];
+ m_y.copy_from_slice(&message[16..]);
+ (blocks, m_u, m_v, m_x, m_y, d as u8)
+}
+
+fn pad_to_blocks(value: &[u8]) -> Vec<Block> {
+ let mut blocks = Vec::new();
+ for chunk in value.chunks(16) {
+ if chunk.len() == 16 {
+ blocks.push(chunk.try_into().expect("we made sure the length fits"));
+ } else {
+ let mut block = Block::default();
+ for (b, v) in block.iter_mut().zip(
+ chunk
+ .iter()
+ .chain(iter::once(&0x80))
+ .chain(iter::repeat(&0)),
+ ) {
+ *b = *v;
+ }
+ blocks.push(block)
+ }
+ }
+ blocks
+}
+
+fn tau_to_block(tau: u32) -> Block {
+ (tau as u128).to_be_bytes()
+}
+
+fn aez_hash(key: &Key, tweaks: Tweak) -> Block {
+ let mut hash = NULL;
+ for (i, tweak) in tweaks.iter().enumerate() {
+ // Adjust for zero-based vs one-based indexing
+ let j = i + 2 + 1;
+ // This is somewhat implicit in the AEZ spec, but basically for an empty string we still
+ // set l = 1 and then xor E_K^{j, 0}(10*). We could modify the last if branch to cover this
+ // as well, but then we need to fiddle with getting an empty chunk from an empty iterator.
+ if tweak.is_empty() {
+ hash = xor(
+ &hash,
+ &e(
+ j.try_into().unwrap(),
+ 0,
+ key,
+ &[128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ ),
+ );
+ } else if tweak.len() % 16 == 0 {
+ for (l, chunk) in tweak.chunks(16).enumerate() {
+ hash = xor(
+ &hash,
+ &e(
+ j.try_into().unwrap(),
+ (l + 1).try_into().unwrap(),
+ key,
+ chunk.try_into().expect("we made sure the length fits"),
+ ),
+ );
+ }
+ } else {
+ let blocks = pad_to_blocks(tweak);
+ for (l, chunk) in blocks.iter().enumerate() {
+ hash = xor(
+ &hash,
+ &e(
+ j.try_into().unwrap(),
+ if l == blocks.len() - 1 {
+ 0
+ } else {
+ (l + 1).try_into().unwrap()
+ },
+ key,
+ chunk,
+ ),
+ );
+ }
+ }
+ }
+ hash
+}
+
+fn aez_prf(key: &Key, tweaks: Tweak, tau: u32) -> Vec<u8> {
+ let mut result = Vec::new();
+ let mut index = 0u128;
+ let delta = aez_hash(key, tweaks);
+ while result.len() < tau as usize {
+ let block = e(-1, 3, key, &xor(&delta, &index.to_be_bytes()));
+ result.extend_from_slice(&block[..16.min(tau as usize - result.len())]);
+ index += 1;
+ }
+ result
+}
+
+fn e(j: i32, i: i32, key: &Key, block: &Block) -> Block {
+ let (key_i, key_j, key_l) = split_key(key);
+ if j == -1 {
+ let k = [
+ &NULL, key_i, key_j, key_l, key_i, key_j, key_l, key_i, key_j, key_l, key_i,
+ ];
+ let delta = times(i.try_into().expect("i was negative"), key_l);
+ aes10(&k, &xor(block, &delta))
+ } else {
+ let k = [&NULL, key_j, key_i, key_l, &NULL];
+ let exponent = if i % 8 == 0 { i / 8 } else { i / 8 + 1 };
+ let delta = xor(
+ &xor(
+ &times(j.try_into().expect("j was negative"), key_j),
+ &times(1 << exponent, key_i),
+ ),
+ &times((i % 8).try_into().expect("i was negative"), key_l),
+ );
+ aes4(&k, &xor(block, &delta))
+ }
+}
+
+fn split_key(key: &Key) -> (&Block, &Block, &Block) {
+ let (i, jl) = key.split_at(16);
+ let (j, l) = jl.split_at(16);
+ (
+ i.try_into().unwrap(),
+ j.try_into().unwrap(),
+ l.try_into().unwrap(),
+ )
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_xor() {
+ assert_eq!(xor(&[1; 16], &[2; 16]), [3; 16]);
+ }
+
+ #[test]
+ fn test_times() {
+ assert_eq!(
+ times(0, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]),
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ );
+ assert_eq!(
+ times(1, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]),
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
+ );
+ assert_eq!(
+ times(2, &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]),
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]
+ );
+ assert_eq!(
+ times(2, &[128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]),
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 133]
+ );
+ assert_eq!(
+ times(2, &[129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]),
+ [2, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 133]
+ );
+ assert_eq!(
+ times(3, &[129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]),
+ [131, 0, 0, 0, 1, 128, 0, 0, 0, 3, 0, 0, 0, 0, 0, 132]
+ );
+ assert_eq!(
+ times(4, &[129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]),
+ [4, 0, 0, 0, 2, 0, 0, 0, 0, 4, 0, 0, 0, 0, 1, 10]
+ );
+ }
+
+ #[test]
+ fn test_lshift() {
+ assert_eq!(
+ lshift(&[0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 1),
+ [0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ );
+ assert_eq!(
+ lshift(&[0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 4),
+ [0x10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ );
+ assert_eq!(
+ lshift(&[0x0A, 0xB0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 4),
+ [0xAB, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ );
+ assert_eq!(
+ lshift(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 8),
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]
+ );
+ }
+
+ #[test]
+ fn test_pad_block() {
+ assert_eq!(
+ pad_block(&[0; 16], 0),
+ [0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ );
+ assert_eq!(
+ pad_block(&[0; 16], 1),
+ [0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ );
+ assert_eq!(
+ pad_block(&[0; 16], 8),
+ [0, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ );
+ }
+
+ #[test]
+ fn test_clip_to_bits() {
+ let mut block;
+
+ block = [0xFF; 16];
+ clip_to_bits(&mut block, 0);
+ assert_eq!(block, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
+
+ block = [0xFF; 16];
+ clip_to_bits(&mut block, 4);
+ assert_eq!(block, [0xF0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
+
+ block = [0xFF; 16];
+ clip_to_bits(&mut block, 8);
+ assert_eq!(block, [0xFF, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
+
+ block = [0xFF; 16];
+ clip_to_bits(&mut block, 9);
+ assert_eq!(
+ block,
+ [0xFF, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ );
+ }
+
+ #[test]
+ fn test_extract() {
+ for (a, b) in testvectors::EXTRACT_VECTORS {
+ let a = hex::decode(a).unwrap();
+ let b = hex::decode(b).unwrap();
+ assert_eq!(extract(&a), b.as_slice());
+ }
+ }
+
+ #[test]
+ fn test_e() {
+ for (k, j, i, a, b) in testvectors::E_VECTORS {
+ let name = format!("e({j}, {i}, {k}, {a})");
+ let k = hex::decode(k).unwrap();
+ let k = k.as_slice().try_into().unwrap();
+ let a = hex::decode(a).unwrap();
+ let a = a.as_slice().try_into().unwrap();
+ let b = hex::decode(b).unwrap();
+ assert_eq!(&e(*j, *i, k, a), b.as_slice(), "{name}");
+ }
+ }
+
+ #[test]
+ fn test_aez_hash() {
+ for (k, tau, tw, v) in testvectors::HASH_VECTORS {
+ let name = format!("aez_hash({k}, {tau}, {tw:?})");
+ let k = hex::decode(k).unwrap();
+ let k = k.as_slice().try_into().unwrap();
+ let v = hex::decode(v).unwrap();
+
+ let mut tweaks = vec![Vec::from(tau_to_block(*tau))];
+ for t in *tw {
+ tweaks.push(hex::decode(t).unwrap());
+ }
+ let tweaks = tweaks.iter().map(Vec::as_slice).collect::<Vec<_>>();
+
+ assert_eq!(&aez_hash(&k, &tweaks), v.as_slice(), "{name}");
+ }
+ }
+
+ #[test]
+ fn test_encrypt() {
+ let mut failed = 0;
+ let mut succ = 0;
+ for (k, n, ads, tau, m, c) in testvectors::ENCRYPT_VECTORS {
+ let name = format!("encrypt({k}, {n}, {ads:?}, {tau}, {m})");
+ let k = hex::decode(k).unwrap();
+ let k = k.as_slice().try_into().unwrap();
+ let n = hex::decode(n).unwrap();
+
+ let mut ad = Vec::new();
+ for i in *ads {
+ ad.push(hex::decode(i).unwrap());
+ }
+ let ad = ad.iter().map(Vec::as_slice).collect::<Vec<_>>();
+
+ let m = hex::decode(m).unwrap();
+ let c = hex::decode(c).unwrap();
+
+ if &encrypt(&k, &n, &ad, *tau, &m) == &c {
+ println!("+ {name}");
+ succ += 1;
+ } else {
+ println!("- {name}");
+ failed += 1;
+ }
+ }
+ println!("{succ} succeeded, {failed} failed");
+ assert_eq!(failed, 0);
+ }
+
+ #[test]
+ fn test_decrypt() {
+ let mut failed = 0;
+ let mut succ = 0;
+ for (k, n, ads, tau, m, c) in testvectors::ENCRYPT_VECTORS {
+ let name = format!("decrypt({k}, {n}, {ads:?}, {tau}, {c})");
+ let k = hex::decode(k).unwrap();
+ let k = k.as_slice().try_into().unwrap();
+ let n = hex::decode(n).unwrap();
+
+ let mut ad = Vec::new();
+ for i in *ads {
+ ad.push(hex::decode(i).unwrap());
+ }
+ let ad = ad.iter().map(Vec::as_slice).collect::<Vec<_>>();
+
+ let m = hex::decode(m).unwrap();
+ let c = hex::decode(c).unwrap();
+
+ if decrypt(&k, &n, &ad, *tau, &c) == Some(m) {
+ println!("+ {name}");
+ succ += 1;
+ } else {
+ println!("- {name}");
+ failed += 1;
+ }
+ }
+ println!("{succ} succeeded, {failed} failed");
+ assert_eq!(failed, 0);
+ }
+
+ #[test]
+ fn test_encrypt_decrypt() {
+ let aez = Aez::new(b"foobar");
+ let cipher = aez.encrypt(&[0], &[b"foobar"], 16, b"hi");
+ let plain = aez.decrypt(&[0], &[b"foobar"], 16, &cipher).unwrap();
+ assert_eq!(plain, b"hi");
+ }
+}