From 71cdf50525f0cbb70673477510050669206df7f2 Mon Sep 17 00:00:00 2001 From: Daniel Schadt Date: Sat, 5 Apr 2025 19:03:10 +0200 Subject: use proper Block struct and operator overloading --- src/lib.rs | 565 ++++++++++++++++++------------------------------------------- 1 file changed, 166 insertions(+), 399 deletions(-) (limited to 'src/lib.rs') diff --git a/src/lib.rs b/src/lib.rs index 6f4d93f..f5dd3ab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,15 +1,13 @@ use std::iter; +mod block; #[cfg(test)] mod testvectors; -type Block = [u8; 16]; +use block::Block; type Key = [u8; 48]; type Tweak<'a> = &'a [&'a [u8]]; -static NULL: Block = [0; 16]; -static ONE: Block = [128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; - pub struct Aez { key: Key, } @@ -40,72 +38,14 @@ impl Aez { } } -fn xor(lhs: &Block, rhs: &Block) -> Block { - let mut result = [0; 16]; - for ((a, b), r) in lhs.iter().zip(rhs.iter()).zip(result.iter_mut()) { - *r = a ^ b; - } - result -} - -fn and(lhs: &Block, rhs: &Block) -> Block { - let mut result = [0; 16]; - for ((a, b), r) in lhs.iter().zip(rhs.iter()).zip(result.iter_mut()) { - *r = a & b; - } - result -} - -fn or(lhs: &Block, rhs: &Block) -> Block { - let mut result = [0; 16]; - for ((a, b), r) in lhs.iter().zip(rhs.iter()).zip(result.iter_mut()) { - *r = a | b; - } - result -} - -fn lshift(block: &Block, times: u32) -> Block { - let mut block = block.clone(); - for _ in 0..times { - let mut result = [0; 16]; - for (b, r) in block.iter().zip(result.iter_mut()) { - *r = b << 1; - } - for (b, r) in block[1..].iter().zip(result.iter_mut()) { - *r = *r | ((b & 0x80) >> 7); - } - block = result; - } - block -} - -fn times(lhs: u32, block: &Block) -> Block { - match lhs { - 0 => NULL, - 1 => *block, - 2 => { - let mut result = lshift(block, 1); - if block[0] & 0x80 != 0 { - result[15] ^= 135; - } - result - } - _ if lhs % 2 == 0 => times(2, ×(lhs / 2, block)), - _ => xor(×(lhs - 1, block), block), - } -} - -fn aesenc(mut block: Block, key: &Block) -> Block { - aes::hazmat::cipher_round((&mut block).into(), key.into()); +fn aesenc(mut block: Block, key: &Block) -> block::Block { + aes::hazmat::cipher_round((&mut block.0).into(), &key.0.into()); block } fn aes4(keys: &[&Block; 5], block: &Block) -> Block { aesenc( - aesenc( - aesenc(aesenc(xor(block, keys[0]), keys[1]), keys[2]), - keys[3], - ), + aesenc(aesenc(aesenc(*block ^ *keys[0], keys[1]), keys[2]), keys[3]), keys[4], ) } @@ -119,7 +59,7 @@ fn aes10(keys: &[&Block; 11], block: &Block) -> Block { aesenc( aesenc( aesenc( - aesenc(aesenc(xor(block, keys[0]), keys[1]), keys[2]), + aesenc(aesenc(*block ^ *keys[0], keys[1]), keys[2]), keys[3], ), keys[4], @@ -150,35 +90,6 @@ fn extract(key: &[u8]) -> [u8; 48] { } } -fn clip_to_bits(block: &mut Block, mut bits: usize) { - for byte in block { - if bits == 0 { - *byte = 0; - } else if bits < 8 { - *byte &= 0xff << (8 - bits); - } - bits = bits.saturating_sub(8); - } -} - -fn full_block(data: &[u8]) -> Block { - let mut result = [0; 16]; - result[..data.len()].copy_from_slice(data); - result -} - -fn pad_block(block: &Block, mut bits: usize) -> Block { - let mut block = *block; - for byte in &mut block { - if bits < 8 { - *byte |= 0x80 >> bits; - break; - } - bits = bits.saturating_sub(8); - } - block -} - fn encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, message: &[u8]) -> Vec { let auth_message = message .iter() @@ -186,8 +97,8 @@ fn encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, message: &[u8]) -> V .chain(iter::repeat_n(0, tau as usize)) .collect::>(); // We treat tau as bytes, but according to the spec, tau is actually in bits. - let tau_block = tau_to_block(tau * 8); - let mut tweaks = vec![&tau_block, nonce]; + let tau_block = Block::from_int(tau * 8); + let mut tweaks = vec![&tau_block.0, nonce]; tweaks.extend(ad); if message.is_empty() { aez_prf(key, &tweaks, tau) @@ -201,8 +112,8 @@ fn decrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, ciphertext: &[u8]) - return None; } - let tau_block = tau_to_block(tau * 8); - let mut tweaks = vec![&tau_block, nonce]; + let tau_block = Block::from_int(tau * 8); + let mut tweaks = vec![&tau_block.0, nonce]; tweaks.extend(ad); if ciphertext.len() == tau as usize { @@ -237,7 +148,7 @@ fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec { let n = mu / 2; let delta = aez_hash(key, tweaks); let round_count = match mu { - 8 => 24, + 8 => 24u32, 16 => 16, _ if mu < 128 => 10, _ => 8, @@ -246,48 +157,34 @@ fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec { let (mut left, mut right); // We might end up having to split at a nibble, so manually adjust for that if n % 8 == 0 { - left = full_block(&message[..n / 8]); - right = full_block(&message[n / 8..]); + left = Block::from_slice(&message[..n / 8]); + right = Block::from_slice(&message[n / 8..]); } else { - left = full_block(&message[..n / 8 + 1]); - clip_to_bits(&mut left, n); - right = full_block(&message[n / 8..]); - right = lshift(&right, 4); + assert!(n % 8 == 4); + left = Block::from_slice(&message[..n / 8 + 1]).clip(n); + right = Block::from_slice(&message[n / 8..]) << 4; }; let i = if mu >= 128 { 6 } else { 7 }; for j in 0..round_count { - let mut right_ = xor( - &left, - &e( - 0, - i, - key, - &xor( - &xor(&delta, &pad_block(&right, n)), - &(j as u128).to_be_bytes(), - ), - ), - ); - clip_to_bits(&mut right_, n); + let right_ = (left ^ e(0, i, key, delta ^ right.pad(n) ^ Block::from_int(j))).clip(n); (left, right) = (right, right_); } let mut ciphertext = Vec::new(); if n % 8 == 0 { - ciphertext.extend_from_slice(&right[..n / 8]); - ciphertext.extend_from_slice(&left[..n / 8]); + ciphertext.extend_from_slice(&right.0[..n / 8]); + ciphertext.extend_from_slice(&left.0[..n / 8]); } else { - ciphertext.extend_from_slice(&right[..n / 8 + 1]); - for byte in &left[..n / 8 + 1] { + ciphertext.extend_from_slice(&right.0[..n / 8 + 1]); + for byte in &left.0[..n / 8 + 1] { *ciphertext.last_mut().unwrap() |= byte >> 4; ciphertext.push((byte & 0x0f) << 4); } ciphertext.pop(); } if mu < 128 { - let mut c = Block::default(); - c[..ciphertext.len()].copy_from_slice(&ciphertext); - c = xor(&c, &and(&e(0, 3, key, &xor(&delta, &or(&c, &ONE))), &ONE)); - ciphertext = Vec::from(&c[..mu / 8]); + let mut c = Block::from_slice(&ciphertext); + c = c ^ (e(0, 3, key, delta ^ (c | Block::ONE)) & Block::ONE); + ciphertext = Vec::from(&c.0[..mu / 8]); } assert!(ciphertext.len() == message.len()); ciphertext @@ -304,77 +201,76 @@ fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec { for (i, (mi, mi_)) in block_pairs.iter().enumerate() { let i = (i + 1) as i32; - let w = xor(mi, &e(1, i, key, mi_)); - let x = xor(mi_, &e(0, 0, key, &w)); + let w = *mi ^ e(1, i, key, *mi_); + let x = *mi_ ^ e(0, 0, key, w); ws.push(w); xs.push(x); } - let mut x = NULL; + let mut x = Block::NULL; for xi in &xs { - x = xor(&x, xi); + x = x ^ *xi; } match d { 0 => (), _ if d <= 127 => { - x = xor(&x, &e(0, 4, key, &pad_block(&m_u, d.into()))); + x = x ^ e(0, 4, key, m_u.pad(d.into())); } _ => { - x = xor(&x, &e(0, 4, key, &m_u)); - x = xor(&x, &e(0, 5, key, &pad_block(&m_v, len_v.into()))); + x = x ^ e(0, 4, key, m_u); + x = x ^ e(0, 5, key, m_v.pad(len_v.into())); } } - let s_x = xor(&m_x, &xor(&delta, &xor(&x, &e(0, 1, key, &m_y)))); - let s_y = xor(&m_y, &e(-1, 1, key, &s_x)); - let s = xor(&s_x, &s_y); + let s_x = m_x ^ delta ^ x ^ e(0, 1, key, m_y); + let s_y = m_y ^ e(-1, 1, key, s_x); + let s = s_x ^ s_y; let mut cipher_pairs = Vec::new(); - let mut y = NULL; + let mut y = Block::NULL; for (i, (wi, xi)) in ws.iter().zip(xs.iter()).enumerate() { let i = (i + 1) as i32; - let s_ = e(2, i, key, &s); - let yi = xor(wi, &s_); - let zi = xor(xi, &s_); - let ci_ = xor(&yi, &e(0, 0, key, &zi)); - let ci = xor(&zi, &e(1, i, key, &ci_)); + let s_ = e(2, i, key, s); + let yi = *wi ^ s_; + let zi = *xi ^ s_; + let ci_ = yi ^ e(0, 0, key, zi); + let ci = zi ^ e(1, i, key, ci_); cipher_pairs.push((ci, ci_)); - y = xor(&y, &yi); + y = y ^ yi; } - let mut c_u = [0; 16]; - let mut c_v = [0; 16]; + let mut c_u = Block::default(); + let mut c_v = Block::default(); match d { 0 => (), _ if d <= 127 => { - c_u = xor(&m_u, &e(-1, 4, key, &s)); - clip_to_bits(&mut c_u, d.into()); - y = xor(&y, &e(0, 4, key, &pad_block(&c_u, d.into()))); + c_u = (m_u ^ e(-1, 4, key, s)).clip(d.into()); + y = y ^ e(0, 4, key, c_u.pad(d.into())); } _ => { - c_u = xor(&m_u, &e(-1, 4, key, &s)); - c_v = xor(&m_v, &e(-1, 5, key, &s)); - clip_to_bits(&mut c_v, len_v.into()); - y = xor(&y, &e(0, 4, key, &c_u)); - y = xor(&y, &e(0, 5, key, &pad_block(&c_v, len_v.into()))); + c_u = m_u ^ e(-1, 4, key, s); + c_v = (m_v ^ e(-1, 5, key, s)).clip(len_v.into()); + y = y ^ e(0, 4, key, c_u); + y = y ^ e(0, 5, key, c_v.pad(len_v.into())); } } - let c_y = xor(&s_x, &e(-1, 2, key, &s_y)); - let c_x = xor(&s_y, &xor(&delta, &xor(&y, &e(0, 2, key, &c_y)))); + let c_y = s_x ^ e(-1, 2, key, s_y); + let c_x = s_y ^ delta ^ y ^ e(0, 2, key, c_y); let mut ciphertext = Vec::new(); for (ci, ci_) in cipher_pairs { - ciphertext.extend_from_slice(&ci); - ciphertext.extend_from_slice(&ci_); + ciphertext.extend_from_slice(&ci.0); + ciphertext.extend_from_slice(&ci_.0); } - ciphertext.extend_from_slice(&c_u[..128.min(d) as usize / 8]); - ciphertext.extend_from_slice(&c_v[..len_v as usize / 8]); - ciphertext.extend_from_slice(&c_x); - ciphertext.extend_from_slice(&c_y); + ciphertext.extend_from_slice(&c_u.0[..128.min(d) as usize / 8]); + ciphertext.extend_from_slice(&c_v.0[..len_v as usize / 8]); + ciphertext.extend_from_slice(&c_x.0); + ciphertext.extend_from_slice(&c_y.0); + assert!(ciphertext.len() == message.len()); ciphertext } @@ -392,7 +288,7 @@ fn decipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec { let n = mu / 2; let delta = aez_hash(key, tweaks); let round_count = match mu { - 8 => 24, + 8 => 24u32, 16 => 16, _ if mu < 128 => 10, _ => 8, @@ -400,48 +296,33 @@ fn decipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec { let mut message = Vec::from(message); if mu < 128 { - let mut c = Block::default(); - c[..message.len()].copy_from_slice(&message); - c = xor(&c, &and(&e(0, 3, key, &xor(&delta, &or(&c, &ONE))), &ONE)); + let mut c = Block::from_slice(&message); + c = c ^ (e(0, 3, key, delta ^ (c | Block::ONE)) & Block::ONE); message.clear(); - message.extend(&c[..mu / 8]); + message.extend(&c.0[..mu / 8]); } let (mut left, mut right); // We might end up having to split at a nibble, so manually adjust for that if n % 8 == 0 { - left = full_block(&message[..n / 8]); - right = full_block(&message[n / 8..]); + left = Block::from_slice(&message[..n / 8]); + right = Block::from_slice(&message[n / 8..]); } else { - left = full_block(&message[..n / 8 + 1]); - clip_to_bits(&mut left, n); - right = full_block(&message[n / 8..]); - right = lshift(&right, 4); + left = Block::from_slice(&message[..n / 8 + 1]).clip(n); + right = Block::from_slice(&message[n / 8..]) << 4; }; let i = if mu >= 128 { 6 } else { 7 }; for j in (0..round_count).rev() { - let mut right_ = xor( - &left, - &e( - 0, - i, - key, - &xor( - &xor(&delta, &pad_block(&right, n)), - &(j as u128).to_be_bytes(), - ), - ), - ); - clip_to_bits(&mut right_, n); + let right_ = (left ^ e(0, i, key, delta ^ right.pad(n) ^ Block::from_int(j))).clip(n); (left, right) = (right, right_); } let mut ciphertext = Vec::new(); if n % 8 == 0 { - ciphertext.extend_from_slice(&right[..n / 8]); - ciphertext.extend_from_slice(&left[..n / 8]); + ciphertext.extend_from_slice(&right.0[..n / 8]); + ciphertext.extend_from_slice(&left.0[..n / 8]); } else { - ciphertext.extend_from_slice(&right[..n / 8 + 1]); - for byte in &left[..n / 8 + 1] { + ciphertext.extend_from_slice(&right.0[..n / 8 + 1]); + for byte in &left.0[..n / 8 + 1] { *ciphertext.last_mut().unwrap() |= byte >> 4; ciphertext.push((byte & 0x0f) << 4); } @@ -462,77 +343,75 @@ fn decipher_aez_core(key: &Key, tweaks: Tweak, cipher: &[u8]) -> Vec { for (i, (ci, ci_)) in block_pairs.iter().enumerate() { let i = (i + 1) as i32; - let w = xor(ci, &e(1, i, key, ci_)); - let y = xor(ci_, &e(0, 0, key, &w)); + let w = *ci ^ e(1, i, key, *ci_); + let y = *ci_ ^ e(0, 0, key, w); ws.push(w); ys.push(y); } - let mut y = NULL; + let mut y = Block::NULL; for yi in &ys { - y = xor(&y, yi); + y = y ^ *yi; } match d { 0 => (), _ if d <= 127 => { - y = xor(&y, &e(0, 4, key, &pad_block(&c_u, d.into()))); + y = y ^ e(0, 4, key, c_u.pad(d.into())); } _ => { - y = xor(&y, &e(0, 4, key, &c_u)); - y = xor(&y, &e(0, 5, key, &pad_block(&c_v, len_v.into()))); + y = y ^ e(0, 4, key, c_u); + y = y ^ e(0, 5, key, c_v.pad(len_v.into())); } } - let s_x = xor(&c_x, &xor(&delta, &xor(&y, &e(0, 2, key, &c_y)))); - let s_y = xor(&c_y, &e(-1, 2, key, &s_x)); - let s = xor(&s_x, &s_y); + let s_x = c_x ^ delta ^ y ^ e(0, 2, key, c_y); + let s_y = c_y ^ e(-1, 2, key, s_x); + let s = s_x ^ s_y; let mut plain_pairs = Vec::new(); - let mut x = NULL; + let mut x = Block::NULL; for (i, (wi, yi)) in ws.iter().zip(ys.iter()).enumerate() { let i = (i + 1) as i32; - let s_ = e(2, i, key, &s); - let xi = xor(wi, &s_); - let zi = xor(yi, &s_); - let mi_ = xor(&xi, &e(0, 0, key, &zi)); - let mi = xor(&zi, &e(1, i, key, &mi_)); + let s_ = e(2, i, key, s); + let xi = *wi ^ s_; + let zi = *yi ^ s_; + let mi_ = xi ^ e(0, 0, key, zi); + let mi = zi ^ e(1, i, key, mi_); plain_pairs.push((mi, mi_)); - x = xor(&x, &xi); + x = x ^ xi; } - let mut m_u = [0; 16]; - let mut m_v = [0; 16]; + let mut m_u = Block::default(); + let mut m_v = Block::default(); match d { 0 => (), _ if d <= 127 => { - m_u = xor(&c_u, &e(-1, 4, key, &s)); - clip_to_bits(&mut m_u, d.into()); - x = xor(&x, &e(0, 4, key, &pad_block(&m_u, d.into()))); + m_u = (c_u ^ e(-1, 4, key, s)).clip(d.into()); + x = x ^ e(0, 4, key, m_u.pad(d.into())); } _ => { - m_u = xor(&c_u, &e(-1, 4, key, &s)); - m_v = xor(&c_v, &e(-1, 5, key, &s)); - clip_to_bits(&mut m_v, len_v.into()); - x = xor(&x, &e(0, 4, key, &m_u)); - x = xor(&x, &e(0, 5, key, &pad_block(&m_v, len_v.into()))); + m_u = c_u ^ e(-1, 4, key, s); + m_v = (c_v ^ e(-1, 5, key, s)).clip(len_v.into()); + x = x ^ e(0, 4, key, m_u); + x = x ^ e(0, 5, key, m_v.pad(len_v.into())); } } - let m_y = xor(&s_x, &e(-1, 1, key, &s_y)); - let m_x = xor(&s_y, &xor(&delta, &xor(&x, &e(0, 1, key, &m_y)))); + let m_y = s_x ^ e(-1, 1, key, s_y); + let m_x = s_y ^ delta ^ x ^ e(0, 1, key, m_y); let mut message = Vec::new(); for (mi, mi_) in plain_pairs { - message.extend_from_slice(&mi); - message.extend_from_slice(&mi_); + message.extend_from_slice(&mi.0); + message.extend_from_slice(&mi_.0); } - message.extend_from_slice(&m_u[..128.min(d) as usize / 8]); - message.extend_from_slice(&m_v[..len_v as usize / 8]); - message.extend_from_slice(&m_x); - message.extend_from_slice(&m_y); + message.extend_from_slice(&m_u.0[..128.min(d) as usize / 8]); + message.extend_from_slice(&m_v.0[..len_v as usize / 8]); + message.extend_from_slice(&m_x.0); + message.extend_from_slice(&m_y.0); message } @@ -540,9 +419,8 @@ fn split_blocks(mut message: &[u8]) -> (Vec<(Block, Block)>, Block, Block, Block let num_blocks = (message.len() - 16 - 16) / 32; let mut blocks = Vec::new(); for _ in 0..num_blocks { - let (mut a, mut b) = ([0; 16], [0; 16]); - a.copy_from_slice(&message[..16]); - b.copy_from_slice(&message[16..32]); + let a = Block::from_slice(&message[..16]); + let b = Block::from_slice(&message[16..32]); blocks.push((a, b)); message = &message[32..]; } @@ -552,18 +430,17 @@ fn split_blocks(mut message: &[u8]) -> (Vec<(Block, Block)>, Block, Block, Block message = &message[m_uv.len()..]; assert!(message.len() == 32); - let mut m_u = [0; 16]; - let mut m_v = [0; 16]; + let m_u; + let m_v; if d <= 127 { - m_u[..m_uv.len()].copy_from_slice(m_uv); + m_u = Block::from_slice(m_uv); + m_v = Block::default(); } else { - m_u.copy_from_slice(&m_uv[..16]); - m_v[..m_uv.len() - 16].copy_from_slice(&m_uv[16..]); + m_u = Block::from_slice(&m_uv[..16]); + m_v = Block::from_slice(&m_uv[16..]); } - let mut m_x = [0; 16]; - m_x.copy_from_slice(&message[..16]); - let mut m_y = [0; 16]; - m_y.copy_from_slice(&message[16..]); + let m_x = Block::from_slice(&message[..16]); + let m_y = Block::from_slice(&message[16..]); (blocks, m_u, m_v, m_x, m_y, d as u8) } @@ -571,29 +448,16 @@ fn pad_to_blocks(value: &[u8]) -> Vec { let mut blocks = Vec::new(); for chunk in value.chunks(16) { if chunk.len() == 16 { - blocks.push(chunk.try_into().expect("we made sure the length fits")); + blocks.push(Block::from_slice(chunk)); } else { - let mut block = Block::default(); - for (b, v) in block.iter_mut().zip( - chunk - .iter() - .chain(iter::once(&0x80)) - .chain(iter::repeat(&0)), - ) { - *b = *v; - } - blocks.push(block) + blocks.push(Block::from_slice(chunk).pad(chunk.len() * 8)); } } blocks } -fn tau_to_block(tau: u32) -> Block { - (tau as u128).to_be_bytes() -} - fn aez_hash(key: &Key, tweaks: Tweak) -> Block { - let mut hash = NULL; + let mut hash = Block::NULL; for (i, tweak) in tweaks.iter().enumerate() { // Adjust for zero-based vs one-based indexing let j = i + 2 + 1; @@ -601,33 +465,22 @@ fn aez_hash(key: &Key, tweaks: Tweak) -> Block { // set l = 1 and then xor E_K^{j, 0}(10*). We could modify the last if branch to cover this // as well, but then we need to fiddle with getting an empty chunk from an empty iterator. if tweak.is_empty() { - hash = xor( - &hash, - &e( - j.try_into().unwrap(), - 0, - key, - &[128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ), - ); + hash = hash ^ e(j.try_into().unwrap(), 0, key, Block::ONE); } else if tweak.len() % 16 == 0 { for (l, chunk) in tweak.chunks(16).enumerate() { - hash = xor( - &hash, - &e( + hash = hash + ^ e( j.try_into().unwrap(), (l + 1).try_into().unwrap(), key, - chunk.try_into().expect("we made sure the length fits"), - ), - ); + Block::from_slice(chunk), + ); } } else { let blocks = pad_to_blocks(tweak); for (l, chunk) in blocks.iter().enumerate() { - hash = xor( - &hash, - &e( + hash = hash + ^ e( j.try_into().unwrap(), if l == blocks.len() - 1 { 0 @@ -635,9 +488,8 @@ fn aez_hash(key: &Key, tweaks: Tweak) -> Block { (l + 1).try_into().unwrap() }, key, - chunk, - ), - ); + *chunk, + ); } } } @@ -649,146 +501,52 @@ fn aez_prf(key: &Key, tweaks: Tweak, tau: u32) -> Vec { let mut index = 0u128; let delta = aez_hash(key, tweaks); while result.len() < tau as usize { - let block = e(-1, 3, key, &xor(&delta, &index.to_be_bytes())); - result.extend_from_slice(&block[..16.min(tau as usize - result.len())]); + let block = e(-1, 3, key, delta ^ Block::from_int(index)); + result.extend_from_slice(&block.0[..16.min(tau as usize - result.len())]); index += 1; } result } -fn e(j: i32, i: i32, key: &Key, block: &Block) -> Block { +fn e(j: i32, i: i32, key: &Key, block: Block) -> Block { let (key_i, key_j, key_l) = split_key(key); if j == -1 { let k = [ - &NULL, key_i, key_j, key_l, key_i, key_j, key_l, key_i, key_j, key_l, key_i, + &Block::NULL, + &key_i, + &key_j, + &key_l, + &key_i, + &key_j, + &key_l, + &key_i, + &key_j, + &key_l, + &key_i, ]; - let delta = times(i.try_into().expect("i was negative"), key_l); - aes10(&k, &xor(block, &delta)) + let delta = key_l * i.try_into().expect("i was negative"); + aes10(&k, &(block ^ delta)) } else { - let k = [&NULL, key_j, key_i, key_l, &NULL]; + let k = [&Block::NULL, &key_j, &key_i, &key_l, &Block::NULL]; let exponent = if i % 8 == 0 { i / 8 } else { i / 8 + 1 }; - let delta = xor( - &xor( - ×(j.try_into().expect("j was negative"), key_j), - ×(1 << exponent, key_i), - ), - ×((i % 8).try_into().expect("i was negative"), key_l), - ); - aes4(&k, &xor(block, &delta)) + let j: u32 = j.try_into().expect("j was negative"); + let i: u32 = i.try_into().expect("i was negative"); + let delta = (key_j * j) ^ (key_i * (1 << exponent)) ^ (key_l * (i % 8)); + aes4(&k, &(block ^ delta)) } } -fn split_key(key: &Key) -> (&Block, &Block, &Block) { - let (i, jl) = key.split_at(16); - let (j, l) = jl.split_at(16); +fn split_key(key: &Key) -> (Block, Block, Block) { ( - i.try_into().unwrap(), - j.try_into().unwrap(), - l.try_into().unwrap(), + Block::from_slice(&key[..16]), + Block::from_slice(&key[16..32]), + Block::from_slice(&key[32..]), ) } #[cfg(test)] mod test { use super::*; - - #[test] - fn test_xor() { - assert_eq!(xor(&[1; 16], &[2; 16]), [3; 16]); - } - - #[test] - fn test_times() { - assert_eq!( - times(0, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - ); - assert_eq!( - times(1, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] - ); - assert_eq!( - times(2, &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]), - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2] - ); - assert_eq!( - times(2, &[128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]), - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 133] - ); - assert_eq!( - times(2, &[129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]), - [2, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 133] - ); - assert_eq!( - times(3, &[129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]), - [131, 0, 0, 0, 1, 128, 0, 0, 0, 3, 0, 0, 0, 0, 0, 132] - ); - assert_eq!( - times(4, &[129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]), - [4, 0, 0, 0, 2, 0, 0, 0, 0, 4, 0, 0, 0, 0, 1, 10] - ); - } - - #[test] - fn test_lshift() { - assert_eq!( - lshift(&[0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 1), - [0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - ); - assert_eq!( - lshift(&[0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 4), - [0x10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - ); - assert_eq!( - lshift(&[0x0A, 0xB0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 4), - [0xAB, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - ); - assert_eq!( - lshift(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 8), - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0] - ); - } - - #[test] - fn test_pad_block() { - assert_eq!( - pad_block(&[0; 16], 0), - [0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - ); - assert_eq!( - pad_block(&[0; 16], 1), - [0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - ); - assert_eq!( - pad_block(&[0; 16], 8), - [0, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - ); - } - - #[test] - fn test_clip_to_bits() { - let mut block; - - block = [0xFF; 16]; - clip_to_bits(&mut block, 0); - assert_eq!(block, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); - - block = [0xFF; 16]; - clip_to_bits(&mut block, 4); - assert_eq!(block, [0xF0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); - - block = [0xFF; 16]; - clip_to_bits(&mut block, 8); - assert_eq!(block, [0xFF, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); - - block = [0xFF; 16]; - clip_to_bits(&mut block, 9); - assert_eq!( - block, - [0xFF, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - ); - } - #[test] fn test_extract() { for (a, b) in testvectors::EXTRACT_VECTORS { @@ -805,9 +563,9 @@ mod test { let k = hex::decode(k).unwrap(); let k = k.as_slice().try_into().unwrap(); let a = hex::decode(a).unwrap(); - let a = a.as_slice().try_into().unwrap(); + let a = Block::from_slice(&a); let b = hex::decode(b).unwrap(); - assert_eq!(&e(*j, *i, k, a), b.as_slice(), "{name}"); + assert_eq!(&e(*j, *i, k, a).0, b.as_slice(), "{name}"); } } @@ -819,13 +577,13 @@ mod test { let k = k.as_slice().try_into().unwrap(); let v = hex::decode(v).unwrap(); - let mut tweaks = vec![Vec::from(tau_to_block(*tau))]; + let mut tweaks = vec![Vec::from(Block::from_int(*tau).0)]; for t in *tw { tweaks.push(hex::decode(t).unwrap()); } let tweaks = tweaks.iter().map(Vec::as_slice).collect::>(); - assert_eq!(&aez_hash(&k, &tweaks), v.as_slice(), "{name}"); + assert_eq!(&aez_hash(&k, &tweaks).0, v.as_slice(), "{name}"); } } @@ -849,10 +607,10 @@ mod test { let c = hex::decode(c).unwrap(); if &encrypt(&k, &n, &ad, *tau, &m) == &c { - println!("+ {name}"); + //println!("+ {name}"); succ += 1; } else { - println!("- {name}"); + println!("- {}", c.len()); failed += 1; } } @@ -898,4 +656,13 @@ mod test { let plain = aez.decrypt(&[0], &[b"foobar"], 16, &cipher).unwrap(); assert_eq!(plain, b"hi"); } + + #[test] + fn test_encrypt_decrypt_long() { + let message = b"ene mene miste es rappelt in der kiste ene mene meck und du bist weg"; + let aez = Aez::new(b"foobar"); + let cipher = aez.encrypt(&[0], &[b"foobar"], 16, message); + let plain = aez.decrypt(&[0], &[b"foobar"], 16, &cipher).unwrap(); + assert_eq!(plain, message); + } } -- cgit v1.2.3