diff options
author | Daniel Schadt <kingdread@gmx.de> | 2025-04-10 23:08:10 +0200 |
---|---|---|
committer | Daniel Schadt <kingdread@gmx.de> | 2025-04-10 23:08:10 +0200 |
commit | eecd8fb23edb86223f8e3c6ac18df7c1dc596151 (patch) | |
tree | 170fad5455e097195da2fa52bb9cec09c443dc77 | |
parent | 0009a24bfb76fe425844c99769148d66c23a7225 (diff) | |
download | zears-eecd8fb23edb86223f8e3c6ac18df7c1dc596151.tar.gz zears-eecd8fb23edb86223f8e3c6ac18df7c1dc596151.tar.bz2 zears-eecd8fb23edb86223f8e3c6ac18df7c1dc596151.zip |
only have a single AesImpl instance
When I first wrote the aesenc/aes4/aes10 functions, I didn't know yet
how they were going to be used, so I sticked to the spec as much as
possible. As it turns out, they are always used with the same keys, so
it's enough to "initialize" the AES once, and then re-use for multiple E
computations.
It's also beginning a lot to look like all of those functions should
actually be methods, which is something we can fix in the future (and
unite decipher/encipher).
Anyway, the speedup here is around 38% for the 1KiB benchmark, and 4%
for the 16KiB benchmark.
-rw-r--r-- | src/lib.rs | 227 |
1 files changed, 118 insertions, 109 deletions
@@ -102,7 +102,11 @@ static ZEROES: [u8; 1024] = [0; 1024]; /// AEZ encryption scheme. pub struct Aez { - key: Key, + key_i: Block, + key_j: Block, + key_l: Block, + key_l_multiples: [Block; 8], + aes: aesround::AesImpl, } impl Aez { @@ -113,7 +117,26 @@ impl Aez { /// If you provide a key of the correct length (48 bytes), no expansion is done and the key is /// taken as-is. pub fn new(key: &[u8]) -> Self { - Aez { key: extract(key) } + let key = extract(key); + let (key_i, key_j, key_l) = split_key(&key); + let aes = aesround::AesImpl::new(key_i, key_j, key_l); + let key_l_multiples = [ + key_l * 0, + key_l * 1, + key_l * 2, + key_l * 3, + key_l * 4, + key_l * 5, + key_l * 6, + key_l * 7, + ]; + Aez { + key_i, + key_j, + key_l, + key_l_multiples, + aes, + } } /// Encrypt the given data. @@ -164,7 +187,7 @@ impl Aez { data: &mut Vec<u8>, ) { data.resize(data.len() + tau as usize, 0); - encrypt(&self.key, nonce, associated_data, tau, data); + encrypt(&self, nonce, associated_data, tau, data); } /// Encrypts the data inplace. @@ -183,7 +206,7 @@ impl Aez { assert!(buffer.len() >= tau as usize); let data_len = buffer.len() - tau as usize; append_auth(data_len, buffer); - encrypt(&self.key, nonce, associated_data, tau as u32, buffer); + encrypt(&self, nonce, associated_data, tau as u32, buffer); } /// Encrypts the data in the given buffer, writing the output to the given output buffer. @@ -203,7 +226,7 @@ impl Aez { let tau = output.len() - input.len(); output[..input.len()].copy_from_slice(input); append_auth(input.len(), output); - encrypt(&self.key, nonce, associated_data, tau as u32, output); + encrypt(&self, nonce, associated_data, tau as u32, output); } /// Decrypts the given ciphertext. @@ -226,7 +249,7 @@ impl Aez { data: &[u8], ) -> Option<Vec<u8>> { let mut buffer = Vec::from(data); - let len = match decrypt(&self.key, nonce, associated_data, tau, &mut buffer) { + let len = match decrypt(&self, nonce, associated_data, tau, &mut buffer) { None => return None, Some(m) => m.len(), }; @@ -246,7 +269,7 @@ impl Aez { tau: u32, data: &'a mut [u8], ) -> Option<&'a [u8]> { - decrypt(&self.key, nonce, associated_data, tau, data) + decrypt(&self, nonce, associated_data, tau, data) } } @@ -271,7 +294,7 @@ fn append_auth(data_len: usize, buffer: &mut [u8]) { } } -fn encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, buffer: &mut [u8]) { +fn encrypt(aez: &Aez, nonce: &[u8], ad: &[&[u8]], tau: u32, buffer: &mut [u8]) { // We treat tau as bytes, but according to the spec, tau is actually in bits. let tau_block = Block::from_int(tau as u128 * 8); let mut tweaks = vec![&tau_block.0, nonce]; @@ -280,14 +303,14 @@ fn encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, buffer: &mut [u8]) { if buffer.len() == tau as usize { // As aez_prf only xor's the input in, we have to clear the buffer first buffer.fill(0); - aez_prf(key, &tweaks, buffer); + aez_prf(aez, &tweaks, buffer); } else { - encipher(key, &tweaks, buffer); + encipher(aez, &tweaks, buffer); } } fn decrypt<'a>( - key: &Key, + aez: &Aez, nonce: &[u8], ad: &[&[u8]], tau: u32, @@ -302,7 +325,7 @@ fn decrypt<'a>( tweaks.extend(ad); if ciphertext.len() == tau as usize { - aez_prf(key, &tweaks, ciphertext); + aez_prf(aez, &tweaks, ciphertext); if is_zeroes(&ciphertext) { return Some(&[]); } else { @@ -310,7 +333,7 @@ fn decrypt<'a>( } } - decipher(key, &tweaks, ciphertext); + decipher(aez, &tweaks, ciphertext); let (m, auth) = ciphertext.split_at(ciphertext.len() - tau as usize); assert!(auth.len() == tau as usize); @@ -328,19 +351,19 @@ fn is_zeroes(data: &[u8]) -> bool { constant_time_eq(data, comparator) } -fn encipher(key: &Key, tweaks: Tweak, message: &mut [u8]) { +fn encipher(aez: &Aez, tweaks: Tweak, message: &mut [u8]) { if message.len() < 256 / 8 { - encipher_aez_tiny(key, tweaks, message) + encipher_aez_tiny(aez, tweaks, message) } else { - encipher_aez_core(key, tweaks, message) + encipher_aez_core(aez, tweaks, message) } } -fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &mut [u8]) { +fn encipher_aez_tiny(aez: &Aez, tweaks: Tweak, message: &mut [u8]) { let mu = message.len() * 8; assert!(mu < 256); let n = mu / 2; - let delta = aez_hash(key, tweaks); + let delta = aez_hash(aez, tweaks); let round_count = match mu { 8 => 24u32, 16 => 16, @@ -360,7 +383,7 @@ fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &mut [u8]) { }; let i = if mu >= 128 { 6 } else { 7 }; for j in 0..round_count { - let right_ = (left ^ e(0, i, key, delta ^ right.pad(n) ^ Block::from_int(j))).clip(n); + let right_ = (left ^ e(0, i, aez, delta ^ right.pad(n) ^ Block::from_int(j))).clip(n); (left, right) = (right, right_); } if n % 8 == 0 { @@ -379,14 +402,14 @@ fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &mut [u8]) { } if mu < 128 { let mut c = Block::from_slice(&message); - c = c ^ (e(0, 3, key, delta ^ (c | Block::ONE)) & Block::ONE); + c = c ^ (e(0, 3, aez, delta ^ (c | Block::ONE)) & Block::ONE); message.copy_from_slice(&c.0[..mu / 8]); } } -fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &mut [u8]) { +fn encipher_aez_core(aez: &Aez, tweaks: Tweak, message: &mut [u8]) { assert!(message.len() >= 32); - let delta = aez_hash(key, tweaks); + let delta = aez_hash(aez, tweaks); let mut blocks = BlockAccessor::new(message); let (m_u, m_v, m_x, m_y, d) = ( blocks.m_u(), @@ -398,8 +421,8 @@ fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &mut [u8]) { let len_v = d.saturating_sub(128); let mut x = Block::NULL; - let mut e1_eval = E::new(1, 0, key); - let e0_eval = E::new(0, 0, key); + let mut e1_eval = E::new(1, 0, aez); + let e0_eval = E::new(0, 0, aez); for (raw_mi, raw_mi_) in blocks.pairs_mut() { e1_eval.advance(); @@ -417,22 +440,22 @@ fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &mut [u8]) { match d { 0 => (), _ if d <= 127 => { - x = x ^ e(0, 4, key, m_u.pad(d.into())); + x = x ^ e(0, 4, aez, m_u.pad(d.into())); } _ => { - x = x ^ e(0, 4, key, m_u); - x = x ^ e(0, 5, key, m_v.pad(len_v.into())); + x = x ^ e(0, 4, aez, m_u); + x = x ^ e(0, 5, aez, m_v.pad(len_v.into())); } } - let s_x = m_x ^ delta ^ x ^ e(0, 1, key, m_y); - let s_y = m_y ^ e(-1, 1, key, s_x); + let s_x = m_x ^ delta ^ x ^ e(0, 1, aez, m_y); + let s_y = m_y ^ e(-1, 1, aez, s_x); let s = s_x ^ s_y; let mut y = Block::NULL; - let mut e2_eval = E::new(2, 0, key); - let mut e1_eval = E::new(1, 0, key); - let e0_eval = E::new(0, 0, key); + let mut e2_eval = E::new(2, 0, aez); + let mut e1_eval = E::new(1, 0, aez); + let e0_eval = E::new(0, 0, aez); for (raw_wi, raw_xi) in blocks.pairs_mut() { e2_eval.advance(); e1_eval.advance(); @@ -455,19 +478,19 @@ fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &mut [u8]) { match d { 0 => (), _ if d <= 127 => { - c_u = (m_u ^ e(-1, 4, key, s)).clip(d.into()); - y = y ^ e(0, 4, key, c_u.pad(d.into())); + c_u = (m_u ^ e(-1, 4, aez, s)).clip(d.into()); + y = y ^ e(0, 4, aez, c_u.pad(d.into())); } _ => { - c_u = m_u ^ e(-1, 4, key, s); - c_v = (m_v ^ e(-1, 5, key, s)).clip(len_v.into()); - y = y ^ e(0, 4, key, c_u); - y = y ^ e(0, 5, key, c_v.pad(len_v.into())); + c_u = m_u ^ e(-1, 4, aez, s); + c_v = (m_v ^ e(-1, 5, aez, s)).clip(len_v.into()); + y = y ^ e(0, 4, aez, c_u); + y = y ^ e(0, 5, aez, c_v.pad(len_v.into())); } } - let c_y = s_x ^ e(-1, 2, key, s_y); - let c_x = s_y ^ delta ^ y ^ e(0, 2, key, c_y); + let c_y = s_x ^ e(-1, 2, aez, s_y); + let c_x = s_y ^ delta ^ y ^ e(0, 2, aez, c_y); blocks.set_m_u(c_u); blocks.set_m_v(c_v); @@ -475,19 +498,19 @@ fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &mut [u8]) { blocks.set_m_y(c_y); } -fn decipher(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { +fn decipher(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) { if buffer.len() < 256 / 8 { - decipher_aez_tiny(key, tweaks, buffer); + decipher_aez_tiny(aez, tweaks, buffer); } else { - decipher_aez_core(key, tweaks, buffer); + decipher_aez_core(aez, tweaks, buffer); } } -fn decipher_aez_tiny(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { +fn decipher_aez_tiny(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) { let mu = buffer.len() * 8; assert!(mu < 256); let n = mu / 2; - let delta = aez_hash(key, tweaks); + let delta = aez_hash(aez, tweaks); let round_count = match mu { 8 => 24u32, 16 => 16, @@ -497,7 +520,7 @@ fn decipher_aez_tiny(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { if mu < 128 { let mut c = Block::from_slice(buffer); - c = c ^ (e(0, 3, key, delta ^ (c | Block::ONE)) & Block::ONE); + c = c ^ (e(0, 3, aez, delta ^ (c | Block::ONE)) & Block::ONE); buffer.copy_from_slice(&c.0[..mu / 8]); } @@ -512,7 +535,7 @@ fn decipher_aez_tiny(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { }; let i = if mu >= 128 { 6 } else { 7 }; for j in (0..round_count).rev() { - let right_ = (left ^ e(0, i, key, delta ^ right.pad(n) ^ Block::from_int(j))).clip(n); + let right_ = (left ^ e(0, i, aez, delta ^ right.pad(n) ^ Block::from_int(j))).clip(n); (left, right) = (right, right_); } @@ -532,9 +555,9 @@ fn decipher_aez_tiny(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { } } -fn decipher_aez_core(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { +fn decipher_aez_core(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) { assert!(buffer.len() >= 32); - let delta = aez_hash(key, tweaks); + let delta = aez_hash(aez, tweaks); let mut blocks = BlockAccessor::new(buffer); let (c_u, c_v, c_x, c_y, d) = ( blocks.m_u(), @@ -546,8 +569,8 @@ fn decipher_aez_core(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { let len_v = d.saturating_sub(128); let mut y = Block::NULL; - let mut e1_eval = E::new(1, 0, key); - let e0_eval = E::new(0, 0, key); + let mut e1_eval = E::new(1, 0, aez); + let e0_eval = E::new(0, 0, aez); for (raw_ci, raw_ci_) in blocks.pairs_mut() { e1_eval.advance(); let ci = Block::from(*raw_ci); @@ -564,22 +587,22 @@ fn decipher_aez_core(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { match d { 0 => (), _ if d <= 127 => { - y = y ^ e(0, 4, key, c_u.pad(d.into())); + y = y ^ e(0, 4, aez, c_u.pad(d.into())); } _ => { - y = y ^ e(0, 4, key, c_u); - y = y ^ e(0, 5, key, c_v.pad(len_v.into())); + y = y ^ e(0, 4, aez, c_u); + y = y ^ e(0, 5, aez, c_v.pad(len_v.into())); } } - let s_x = c_x ^ delta ^ y ^ e(0, 2, key, c_y); - let s_y = c_y ^ e(-1, 2, key, s_x); + let s_x = c_x ^ delta ^ y ^ e(0, 2, aez, c_y); + let s_y = c_y ^ e(-1, 2, aez, s_x); let s = s_x ^ s_y; let mut x = Block::NULL; - let mut e2_eval = E::new(2, 0, key); - let mut e1_eval = E::new(1, 0, key); - let e0_eval = E::new(0, 0, key); + let mut e2_eval = E::new(2, 0, aez); + let mut e1_eval = E::new(1, 0, aez); + let e0_eval = E::new(0, 0, aez); for (raw_wi, raw_yi) in blocks.pairs_mut() { e2_eval.advance(); e1_eval.advance(); @@ -603,19 +626,19 @@ fn decipher_aez_core(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { match d { 0 => (), _ if d <= 127 => { - m_u = (c_u ^ e(-1, 4, key, s)).clip(d.into()); - x = x ^ e(0, 4, key, m_u.pad(d.into())); + m_u = (c_u ^ e(-1, 4, aez, s)).clip(d.into()); + x = x ^ e(0, 4, aez, m_u.pad(d.into())); } _ => { - m_u = c_u ^ e(-1, 4, key, s); - m_v = (c_v ^ e(-1, 5, key, s)).clip(len_v.into()); - x = x ^ e(0, 4, key, m_u); - x = x ^ e(0, 5, key, m_v.pad(len_v.into())); + m_u = c_u ^ e(-1, 4, aez, s); + m_v = (c_v ^ e(-1, 5, aez, s)).clip(len_v.into()); + x = x ^ e(0, 4, aez, m_u); + x = x ^ e(0, 5, aez, m_v.pad(len_v.into())); } } - let m_y = s_x ^ e(-1, 1, key, s_y); - let m_x = s_y ^ delta ^ x ^ e(0, 1, key, m_y); + let m_y = s_x ^ e(-1, 1, aez, s_y); + let m_x = s_y ^ delta ^ x ^ e(0, 1, aez, m_y); blocks.set_m_u(m_u); blocks.set_m_v(m_v); @@ -635,7 +658,7 @@ fn pad_to_blocks(value: &[u8]) -> Vec<Block> { blocks } -fn aez_hash(key: &Key, tweaks: Tweak) -> Block { +fn aez_hash(aez: &Aez, tweaks: Tweak) -> Block { let mut hash = Block::NULL; for (i, tweak) in tweaks.iter().enumerate() { // Adjust for zero-based vs one-based indexing @@ -644,14 +667,14 @@ fn aez_hash(key: &Key, tweaks: Tweak) -> Block { // set l = 1 and then xor E_K^{j, 0}(10*). We could modify the last if branch to cover this // as well, but then we need to fiddle with getting an empty chunk from an empty iterator. if tweak.is_empty() { - hash = hash ^ e(j.try_into().unwrap(), 0, key, Block::ONE); + hash = hash ^ e(j.try_into().unwrap(), 0, aez, Block::ONE); } else if tweak.len() % 16 == 0 { for (l, chunk) in tweak.chunks(16).enumerate() { hash = hash ^ e( j.try_into().unwrap(), (l + 1).try_into().unwrap(), - key, + aez, Block::from_slice(chunk), ); } @@ -666,7 +689,7 @@ fn aez_hash(key: &Key, tweaks: Tweak) -> Block { } else { (l + 1).try_into().unwrap() }, - key, + aez, *chunk, ); } @@ -676,11 +699,11 @@ fn aez_hash(key: &Key, tweaks: Tweak) -> Block { } /// XOR's the result of aez_prf into the given buffer -fn aez_prf(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { +fn aez_prf(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) { let mut index = 0u128; - let delta = aez_hash(key, tweaks); + let delta = aez_hash(aez, tweaks); for chunk in buffer.chunks_mut(16) { - let block = e(-1, 3, key, delta ^ Block::from_int(index)); + let block = e(-1, 3, aez, delta ^ Block::from_int(index)); for (a, b) in chunk.iter_mut().zip(block.0.iter()) { *a ^= b; } @@ -692,11 +715,9 @@ fn aez_prf(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { /// /// As we usually need multiple values with a fixed j and ascending i, this struct saves the /// temporary values and makes it much faster to compute E_K^{j, i+1}, E_K^{j, i+2}, ... -struct E { - key_l: Block, - key_ls: [Block; 8], +struct E<'a> { + aez: &'a Aez, state: Estate, - aes: aesround::AesImpl, } #[derive(Clone, Debug)] @@ -711,11 +732,9 @@ enum Estate { }, } -impl E { +impl<'a> E<'a> { /// Create a new "suspended" computation of E_K^{j,i}. - fn new(j: i32, i: u32, key: &Key) -> Self { - let (key_i, key_j, key_l) = split_key(key); - let aes = aesround::AesImpl::new(key_i, key_j, key_l); + fn new(j: i32, i: u32, aez: &'a Aez) -> Self { let state = if j == -1 { Estate::Neg { i } } else { @@ -723,25 +742,13 @@ impl E { let exponent = if i % 8 == 0 { i / 8 } else { i / 8 + 1 }; Estate::Pos { i, - kj_t_j: key_j * j, - ki_p_i: key_i.exp(exponent), + kj_t_j: aez.key_j * j, + ki_p_i: aez.key_i.exp(exponent), } }; - let key_ls = [ - key_l * 0, - key_l * 1, - key_l * 2, - key_l * 3, - key_l * 4, - key_l * 5, - key_l * 6, - key_l * 7, - ]; E { - key_l, - key_ls, + aez, state, - aes, } } @@ -749,12 +756,12 @@ impl E { fn eval(&self, block: Block) -> Block { match self.state { Estate::Neg { i } => { - let delta = self.key_l * i; - self.aes.aes10(block ^ delta) + let delta = self.aez.key_l * i; + self.aez.aes.aes10(block ^ delta) } Estate::Pos { i, kj_t_j, ki_p_i } => { - let delta = kj_t_j ^ ki_p_i ^ self.key_ls[i as usize % 8]; - self.aes.aes4(block ^ delta) + let delta = kj_t_j ^ ki_p_i ^ self.aez.key_l_multiples[i as usize % 8]; + self.aez.aes.aes4(block ^ delta) } } } @@ -781,8 +788,8 @@ impl E { } /// Shorthand to get E_K^{j,i}(block) -fn e(j: i32, i: u32, key: &Key, block: Block) -> Block { - E::new(j, i, key).eval(block) +fn e(j: i32, i: u32, aez: &Aez, block: Block) -> Block { + E::new(j, i, aez).eval(block) } fn split_key(key: &Key) -> (Block, Block, Block) { @@ -813,11 +820,11 @@ mod test { for (k, j, i, a, b) in testvectors::E_VECTORS { let name = format!("e({j}, {i}, {k}, {a})"); let k = hex::decode(k).unwrap(); - let k = k.as_slice().try_into().unwrap(); + let aez = Aez::new(k.as_slice()); let a = hex::decode(a).unwrap(); let a = Block::from_slice(&a); let b = hex::decode(b).unwrap(); - assert_eq!(&e(*j, *i, k, a).0, b.as_slice(), "{name}"); + assert_eq!(&e(*j, *i, &aez, a).0, b.as_slice(), "{name}"); } } @@ -826,7 +833,7 @@ mod test { for (k, tau, tw, v) in testvectors::HASH_VECTORS { let name = format!("aez_hash({k}, {tau}, {tw:?})"); let k = hex::decode(k).unwrap(); - let k = k.as_slice().try_into().unwrap(); + let aez = Aez::new(k.as_slice()); let v = hex::decode(v).unwrap(); let mut tweaks = vec![Vec::from(Block::from_int(*tau).0)]; @@ -835,14 +842,15 @@ mod test { } let tweaks = tweaks.iter().map(Vec::as_slice).collect::<Vec<_>>(); - assert_eq!(&aez_hash(&k, &tweaks).0, v.as_slice(), "{name}"); + assert_eq!(&aez_hash(&aez, &tweaks).0, v.as_slice(), "{name}"); } } fn vec_encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, message: &[u8]) -> Vec<u8> { + let aez = Aez::new(key); let mut v = vec![0; message.len() + tau as usize]; v[..message.len()].copy_from_slice(message); - encrypt(key, nonce, ad, tau, &mut v); + encrypt(&aez, nonce, ad, tau, &mut v); v } @@ -853,8 +861,9 @@ mod test { tau: u32, ciphertext: &[u8], ) -> Option<Vec<u8>> { + let aez = Aez::new(key); let mut v = Vec::from(ciphertext); - let len = match decrypt(key, nonce, ad, tau, &mut v) { + let len = match decrypt(&aez, nonce, ad, tau, &mut v) { None => return None, Some(m) => m.len(), }; |