From 995e4d9e8f5c4db5cee959cb8a45640773d34ce5 Mon Sep 17 00:00:00 2001 From: Daniel Schadt Date: Wed, 9 Apr 2025 18:59:20 +0200 Subject: rewrite algorithm to work in-place --- src/lib.rs | 306 ++++++++++++++++++++++++++++++++----------------------------- 1 file changed, 160 insertions(+), 146 deletions(-) (limited to 'src/lib.rs') diff --git a/src/lib.rs b/src/lib.rs index 08f0570..58e488e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -85,10 +85,12 @@ use constant_time_eq::constant_time_eq; +mod accessor; mod block; #[cfg(test)] mod testvectors; +use accessor::BlockAccessor; use block::Block; type Key = [u8; 48]; type Tweak<'a> = &'a [&'a [u8]]; @@ -135,7 +137,10 @@ impl Aez { tau: u32, data: &[u8], ) -> Vec { - encrypt(&self.key, nonce, associated_data, tau, data) + let mut buffer = vec![0; data.len() + tau as usize]; + buffer[..data.len()].copy_from_slice(data); + encrypt(&self.key, nonce, associated_data, tau, &mut buffer); + buffer } /// Decrypts the given ciphertext. @@ -154,7 +159,13 @@ impl Aez { tau: u32, data: &[u8], ) -> Option> { - decrypt(&self.key, nonce, associated_data, tau, data) + let mut buffer = Vec::from(data); + let len = match decrypt(&self.key, nonce, associated_data, tau, &mut buffer) { + None => return None, + Some(m) => m.len(), + }; + buffer.truncate(len); + Some(buffer) } } @@ -210,28 +221,35 @@ fn extract(key: &[u8]) -> [u8; 48] { } } -fn encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, message: &[u8]) -> Vec { - let mut auth_message = Vec::with_capacity(message.len() + tau as usize); - auth_message.extend_from_slice(&message); - while auth_message.len() < message.len() + tau as usize { - auth_message.extend_from_slice( - &ZEROES[..ZEROES - .len() - .min(tau as usize - (auth_message.len() - message.len()))], - ); +fn append_auth(data_len: usize, buffer: &mut [u8]) { + let mut total_len = data_len; + while total_len < buffer.len() { + let block_size = ZEROES.len().min(buffer.len() - total_len); + buffer[total_len..total_len + block_size].copy_from_slice(&ZEROES[..block_size]); + total_len += block_size; } +} + +fn encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, buffer: &mut [u8]) { // We treat tau as bytes, but according to the spec, tau is actually in bits. let tau_block = Block::from_int(tau as u128 * 8); let mut tweaks = vec![&tau_block.0, nonce]; tweaks.extend(ad); - if message.is_empty() { - aez_prf(key, &tweaks, tau) + assert!(buffer.len() >= tau as usize); + if buffer.len() == tau as usize { + buffer.copy_from_slice(&aez_prf(key, &tweaks, tau)); } else { - encipher(key, &tweaks, &auth_message) + encipher(key, &tweaks, buffer); } } -fn decrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, ciphertext: &[u8]) -> Option> { +fn decrypt<'a>( + key: &Key, + nonce: &[u8], + ad: &[&[u8]], + tau: u32, + ciphertext: &'a mut [u8], +) -> Option<&'a [u8]> { if ciphertext.len() < tau as usize { return None; } @@ -242,14 +260,14 @@ fn decrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, ciphertext: &[u8]) - if ciphertext.len() == tau as usize { if constant_time_eq(&ciphertext, &aez_prf(key, &tweaks, tau)) { - return Some(Vec::new()); + return Some(&[]); } else { return None; } } - let x = decipher(key, &tweaks, ciphertext); - let (m, auth) = x.split_at(ciphertext.len() - tau as usize); + decipher(key, &tweaks, ciphertext); + let (m, auth) = ciphertext.split_at(ciphertext.len() - tau as usize); assert!(auth.len() == tau as usize); let comparator = if tau as usize <= ZEROES.len() { &ZEROES[..tau as usize] @@ -257,13 +275,13 @@ fn decrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, ciphertext: &[u8]) - &vec![0; tau as usize] }; if constant_time_eq(&auth, comparator) { - Some(Vec::from(m)) + Some(m) } else { None } } -fn encipher(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec { +fn encipher(key: &Key, tweaks: Tweak, message: &mut [u8]) { if message.len() < 256 / 8 { encipher_aez_tiny(key, tweaks, message) } else { @@ -271,7 +289,7 @@ fn encipher(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec { } } -fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec { +fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &mut [u8]) { let mu = message.len() * 8; assert!(mu < 256); let n = mu / 2; @@ -298,48 +316,54 @@ fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec { let right_ = (left ^ e(0, i, key, delta ^ right.pad(n) ^ Block::from_int(j))).clip(n); (left, right) = (right, right_); } - let mut ciphertext = Vec::new(); if n % 8 == 0 { - ciphertext.extend_from_slice(&right.0[..n / 8]); - ciphertext.extend_from_slice(&left.0[..n / 8]); + message[..n / 8].copy_from_slice(&right.0[..n / 8]); + message[n / 8..].copy_from_slice(&left.0[..n / 8]); } else { - ciphertext.extend_from_slice(&right.0[..n / 8 + 1]); + let mut index = n / 8; + message[..index + 1].copy_from_slice(&right.0[..index + 1]); for byte in &left.0[..n / 8 + 1] { - *ciphertext.last_mut().unwrap() |= byte >> 4; - ciphertext.push((byte & 0x0f) << 4); + message[index] |= byte >> 4; + if index < message.len() - 1 { + message[index + 1] = (byte & 0x0f) << 4; + } + index += 1; } - ciphertext.pop(); } if mu < 128 { - let mut c = Block::from_slice(&ciphertext); + let mut c = Block::from_slice(&message); c = c ^ (e(0, 3, key, delta ^ (c | Block::ONE)) & Block::ONE); - ciphertext = Vec::from(&c.0[..mu / 8]); + message.copy_from_slice(&c.0[..mu / 8]); } - assert!(ciphertext.len() == message.len()); - ciphertext } -fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec { +fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &mut [u8]) { assert!(message.len() >= 32); let delta = aez_hash(key, tweaks); - let (block_pairs, m_u, m_v, m_x, m_y, d) = split_blocks(message); + let mut blocks = BlockAccessor::new(message); + let (m_u, m_v, m_x, m_y, d) = ( + blocks.m_u(), + blocks.m_v(), + blocks.m_x(), + blocks.m_y(), + blocks.m_uv_len(), + ); let len_v = d.saturating_sub(128); - let mut ws = Vec::new(); - let mut xs = Vec::new(); - + let mut x = Block::NULL; let mut e1_eval = E::new(1, 0, key); - for (mi, mi_) in block_pairs.iter() { + + for (raw_mi, raw_mi_) in blocks.pairs_mut() { e1_eval.advance(); - let w = *mi ^ e1_eval.eval(*mi_); - let x = *mi_ ^ e(0, 0, key, w); - ws.push(w); - xs.push(x); - } + let mi = Block::from(*raw_mi); + let mi_ = Block::from(*raw_mi_); + let wi = mi ^ e1_eval.eval(mi_); + let xi = mi_ ^ e(0, 0, key, wi); - let mut x = Block::NULL; - for xi in &xs { - x = x ^ *xi; + *raw_mi = wi.0; + *raw_mi_ = xi.0; + + x = x ^ xi; } match d { @@ -357,20 +381,22 @@ fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec { let s_y = m_y ^ e(-1, 1, key, s_x); let s = s_x ^ s_y; - let mut cipher_pairs = Vec::new(); let mut y = Block::NULL; let mut e2_eval = E::new(2, 0, key); let mut e1_eval = E::new(1, 0, key); - for (wi, xi) in ws.iter().zip(xs.iter()) { + for (raw_wi, raw_xi) in blocks.pairs_mut() { e2_eval.advance(); e1_eval.advance(); + let wi = Block::from(*raw_wi); + let xi = Block::from(*raw_xi); let s_ = e2_eval.eval(s); - let yi = *wi ^ s_; - let zi = *xi ^ s_; + let yi = wi ^ s_; + let zi = xi ^ s_; let ci_ = yi ^ e(0, 0, key, zi); let ci = zi ^ e1_eval.eval(ci_); - cipher_pairs.push((ci, ci_)); + *raw_wi = ci.0; + *raw_xi = ci_.0; y = y ^ yi; } @@ -394,29 +420,22 @@ fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec { let c_y = s_x ^ e(-1, 2, key, s_y); let c_x = s_y ^ delta ^ y ^ e(0, 2, key, c_y); - let mut ciphertext = Vec::new(); - for (ci, ci_) in cipher_pairs { - ciphertext.extend_from_slice(&ci.0); - ciphertext.extend_from_slice(&ci_.0); - } - ciphertext.extend_from_slice(&c_u.0[..128.min(d) as usize / 8]); - ciphertext.extend_from_slice(&c_v.0[..len_v as usize / 8]); - ciphertext.extend_from_slice(&c_x.0); - ciphertext.extend_from_slice(&c_y.0); - assert!(ciphertext.len() == message.len()); - ciphertext + blocks.set_m_u(c_u); + blocks.set_m_v(c_v); + blocks.set_m_x(c_x); + blocks.set_m_y(c_y); } -fn decipher(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec { - if message.len() < 256 / 8 { - decipher_aez_tiny(key, tweaks, message) +fn decipher(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { + if buffer.len() < 256 / 8 { + decipher_aez_tiny(key, tweaks, buffer); } else { - decipher_aez_core(key, tweaks, message) + decipher_aez_core(key, tweaks, buffer); } } -fn decipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec { - let mu = message.len() * 8; +fn decipher_aez_tiny(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { + let mu = buffer.len() * 8; assert!(mu < 256); let n = mu / 2; let delta = aez_hash(key, tweaks); @@ -427,65 +446,69 @@ fn decipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec { _ => 8, }; - let mut message = Vec::from(message); if mu < 128 { - let mut c = Block::from_slice(&message); + let mut c = Block::from_slice(buffer); c = c ^ (e(0, 3, key, delta ^ (c | Block::ONE)) & Block::ONE); - message.clear(); - message.extend(&c.0[..mu / 8]); + buffer.copy_from_slice(&c.0[..mu / 8]); } let (mut left, mut right); // We might end up having to split at a nibble, so manually adjust for that if n % 8 == 0 { - left = Block::from_slice(&message[..n / 8]); - right = Block::from_slice(&message[n / 8..]); + left = Block::from_slice(&buffer[..n / 8]); + right = Block::from_slice(&buffer[n / 8..]); } else { - left = Block::from_slice(&message[..n / 8 + 1]).clip(n); - right = Block::from_slice(&message[n / 8..]) << 4; + left = Block::from_slice(&buffer[..n / 8 + 1]).clip(n); + right = Block::from_slice(&buffer[n / 8..]) << 4; }; let i = if mu >= 128 { 6 } else { 7 }; for j in (0..round_count).rev() { let right_ = (left ^ e(0, i, key, delta ^ right.pad(n) ^ Block::from_int(j))).clip(n); (left, right) = (right, right_); } - let mut ciphertext = Vec::new(); + if n % 8 == 0 { - ciphertext.extend_from_slice(&right.0[..n / 8]); - ciphertext.extend_from_slice(&left.0[..n / 8]); + buffer[..n / 8].copy_from_slice(&right.0[..n / 8]); + buffer[n / 8..].copy_from_slice(&left.0[..n / 8]); } else { - ciphertext.extend_from_slice(&right.0[..n / 8 + 1]); + let mut index = n / 8; + buffer[..index + 1].copy_from_slice(&right.0[..index + 1]); for byte in &left.0[..n / 8 + 1] { - *ciphertext.last_mut().unwrap() |= byte >> 4; - ciphertext.push((byte & 0x0f) << 4); + buffer[index] |= byte >> 4; + if index < buffer.len() - 1 { + buffer[index + 1] = (byte & 0x0f) << 4; + } + index += 1; } - ciphertext.pop(); } - assert!(ciphertext.len() == message.len()); - ciphertext } -fn decipher_aez_core(key: &Key, tweaks: Tweak, cipher: &[u8]) -> Vec { - assert!(cipher.len() >= 32); +fn decipher_aez_core(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { + assert!(buffer.len() >= 32); let delta = aez_hash(key, tweaks); - let (block_pairs, c_u, c_v, c_x, c_y, d) = split_blocks(cipher); + let mut blocks = BlockAccessor::new(buffer); + let (c_u, c_v, c_x, c_y, d) = ( + blocks.m_u(), + blocks.m_v(), + blocks.m_x(), + blocks.m_y(), + blocks.m_uv_len(), + ); let len_v = d.saturating_sub(128); - let mut ws = Vec::new(); - let mut ys = Vec::new(); - + let mut y = Block::NULL; let mut e1_eval = E::new(1, 0, key); - for (ci, ci_) in block_pairs.iter() { + for (raw_ci, raw_ci_) in blocks.pairs_mut() { e1_eval.advance(); - let w = *ci ^ e1_eval.eval(*ci_); - let y = *ci_ ^ e(0, 0, key, w); - ws.push(w); - ys.push(y); - } + let ci = Block::from(*raw_ci); + let ci_ = Block::from(*raw_ci_); + let wi = ci ^ e1_eval.eval(ci_); + let yi = ci_ ^ e(0, 0, key, wi); - let mut y = Block::NULL; - for yi in &ys { - y = y ^ *yi; + *raw_ci = wi.0; + *raw_ci_ = yi.0; + + y = y ^ yi; } match d { @@ -503,20 +526,23 @@ fn decipher_aez_core(key: &Key, tweaks: Tweak, cipher: &[u8]) -> Vec { let s_y = c_y ^ e(-1, 2, key, s_x); let s = s_x ^ s_y; - let mut plain_pairs = Vec::new(); let mut x = Block::NULL; let mut e2_eval = E::new(2, 0, key); let mut e1_eval = E::new(1, 0, key); - for (wi, yi) in ws.iter().zip(ys.iter()) { + for (raw_wi, raw_yi) in blocks.pairs_mut() { e2_eval.advance(); e1_eval.advance(); + let wi = Block::from(*raw_wi); + let yi = Block::from(*raw_yi); let s_ = e2_eval.eval(s); - let xi = *wi ^ s_; - let zi = *yi ^ s_; + let xi = wi ^ s_; + let zi = yi ^ s_; let mi_ = xi ^ e(0, 0, key, zi); let mi = zi ^ e1_eval.eval(mi_); - plain_pairs.push((mi, mi_)); + *raw_wi = mi.0; + *raw_yi = mi_.0; + x = x ^ xi; } @@ -540,45 +566,10 @@ fn decipher_aez_core(key: &Key, tweaks: Tweak, cipher: &[u8]) -> Vec { let m_y = s_x ^ e(-1, 1, key, s_y); let m_x = s_y ^ delta ^ x ^ e(0, 1, key, m_y); - let mut message = Vec::new(); - for (mi, mi_) in plain_pairs { - message.extend_from_slice(&mi.0); - message.extend_from_slice(&mi_.0); - } - message.extend_from_slice(&m_u.0[..128.min(d) as usize / 8]); - message.extend_from_slice(&m_v.0[..len_v as usize / 8]); - message.extend_from_slice(&m_x.0); - message.extend_from_slice(&m_y.0); - message -} - -fn split_blocks(mut message: &[u8]) -> (Vec<(Block, Block)>, Block, Block, Block, Block, u8) { - let num_blocks = (message.len() - 16 - 16) / 32; - let mut blocks = Vec::new(); - for _ in 0..num_blocks { - let a = Block::from_slice(&message[..16]); - let b = Block::from_slice(&message[16..32]); - blocks.push((a, b)); - message = &message[32..]; - } - let m_uv = &message[..message.len() - 32]; - let d = m_uv.len() * 8; - assert!(d < 256); - message = &message[m_uv.len()..]; - assert!(message.len() == 32); - - let m_u; - let m_v; - if d <= 127 { - m_u = Block::from_slice(m_uv); - m_v = Block::default(); - } else { - m_u = Block::from_slice(&m_uv[..16]); - m_v = Block::from_slice(&m_uv[16..]); - } - let m_x = Block::from_slice(&message[..16]); - let m_y = Block::from_slice(&message[16..]); - (blocks, m_u, m_v, m_x, m_y, d as u8) + blocks.set_m_u(m_u); + blocks.set_m_v(m_v); + blocks.set_m_x(m_x); + blocks.set_m_y(m_y); } fn pad_to_blocks(value: &[u8]) -> Vec { @@ -803,6 +794,29 @@ mod test { } } + fn vec_encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, message: &[u8]) -> Vec { + let mut v = vec![0; message.len() + tau as usize]; + v[..message.len()].copy_from_slice(message); + encrypt(key, nonce, ad, tau, &mut v); + v + } + + fn vec_decrypt( + key: &Key, + nonce: &[u8], + ad: &[&[u8]], + tau: u32, + ciphertext: &[u8], + ) -> Option> { + let mut v = Vec::from(ciphertext); + let len = match decrypt(key, nonce, ad, tau, &mut v) { + None => return None, + Some(m) => m.len(), + }; + v.truncate(len); + Some(v) + } + #[test] fn test_encrypt() { let mut failed = 0; @@ -822,7 +836,7 @@ mod test { let m = hex::decode(m).unwrap(); let c = hex::decode(c).unwrap(); - if &encrypt(&k, &n, &ad, *tau, &m) == &c { + if &vec_encrypt(&k, &n, &ad, *tau, &m) == &c { println!("+ {name}"); succ += 1; } else { @@ -853,7 +867,7 @@ mod test { let m = hex::decode(m).unwrap(); let c = hex::decode(c).unwrap(); - if decrypt(&k, &n, &ad, *tau, &c) == Some(m) { + if vec_decrypt(&k, &n, &ad, *tau, &c) == Some(m) { println!("+ {name}"); succ += 1; } else { -- cgit v1.2.3