diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/lib.rs | 227 |
1 files changed, 118 insertions, 109 deletions
@@ -102,7 +102,11 @@ static ZEROES: [u8; 1024] = [0; 1024]; /// AEZ encryption scheme. pub struct Aez { - key: Key, + key_i: Block, + key_j: Block, + key_l: Block, + key_l_multiples: [Block; 8], + aes: aesround::AesImpl, } impl Aez { @@ -113,7 +117,26 @@ impl Aez { /// If you provide a key of the correct length (48 bytes), no expansion is done and the key is /// taken as-is. pub fn new(key: &[u8]) -> Self { - Aez { key: extract(key) } + let key = extract(key); + let (key_i, key_j, key_l) = split_key(&key); + let aes = aesround::AesImpl::new(key_i, key_j, key_l); + let key_l_multiples = [ + key_l * 0, + key_l * 1, + key_l * 2, + key_l * 3, + key_l * 4, + key_l * 5, + key_l * 6, + key_l * 7, + ]; + Aez { + key_i, + key_j, + key_l, + key_l_multiples, + aes, + } } /// Encrypt the given data. @@ -164,7 +187,7 @@ impl Aez { data: &mut Vec<u8>, ) { data.resize(data.len() + tau as usize, 0); - encrypt(&self.key, nonce, associated_data, tau, data); + encrypt(&self, nonce, associated_data, tau, data); } /// Encrypts the data inplace. @@ -183,7 +206,7 @@ impl Aez { assert!(buffer.len() >= tau as usize); let data_len = buffer.len() - tau as usize; append_auth(data_len, buffer); - encrypt(&self.key, nonce, associated_data, tau as u32, buffer); + encrypt(&self, nonce, associated_data, tau as u32, buffer); } /// Encrypts the data in the given buffer, writing the output to the given output buffer. @@ -203,7 +226,7 @@ impl Aez { let tau = output.len() - input.len(); output[..input.len()].copy_from_slice(input); append_auth(input.len(), output); - encrypt(&self.key, nonce, associated_data, tau as u32, output); + encrypt(&self, nonce, associated_data, tau as u32, output); } /// Decrypts the given ciphertext. @@ -226,7 +249,7 @@ impl Aez { data: &[u8], ) -> Option<Vec<u8>> { let mut buffer = Vec::from(data); - let len = match decrypt(&self.key, nonce, associated_data, tau, &mut buffer) { + let len = match decrypt(&self, nonce, associated_data, tau, &mut buffer) { None => return None, Some(m) => m.len(), }; @@ -246,7 +269,7 @@ impl Aez { tau: u32, data: &'a mut [u8], ) -> Option<&'a [u8]> { - decrypt(&self.key, nonce, associated_data, tau, data) + decrypt(&self, nonce, associated_data, tau, data) } } @@ -271,7 +294,7 @@ fn append_auth(data_len: usize, buffer: &mut [u8]) { } } -fn encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, buffer: &mut [u8]) { +fn encrypt(aez: &Aez, nonce: &[u8], ad: &[&[u8]], tau: u32, buffer: &mut [u8]) { // We treat tau as bytes, but according to the spec, tau is actually in bits. let tau_block = Block::from_int(tau as u128 * 8); let mut tweaks = vec![&tau_block.0, nonce]; @@ -280,14 +303,14 @@ fn encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, buffer: &mut [u8]) { if buffer.len() == tau as usize { // As aez_prf only xor's the input in, we have to clear the buffer first buffer.fill(0); - aez_prf(key, &tweaks, buffer); + aez_prf(aez, &tweaks, buffer); } else { - encipher(key, &tweaks, buffer); + encipher(aez, &tweaks, buffer); } } fn decrypt<'a>( - key: &Key, + aez: &Aez, nonce: &[u8], ad: &[&[u8]], tau: u32, @@ -302,7 +325,7 @@ fn decrypt<'a>( tweaks.extend(ad); if ciphertext.len() == tau as usize { - aez_prf(key, &tweaks, ciphertext); + aez_prf(aez, &tweaks, ciphertext); if is_zeroes(&ciphertext) { return Some(&[]); } else { @@ -310,7 +333,7 @@ fn decrypt<'a>( } } - decipher(key, &tweaks, ciphertext); + decipher(aez, &tweaks, ciphertext); let (m, auth) = ciphertext.split_at(ciphertext.len() - tau as usize); assert!(auth.len() == tau as usize); @@ -328,19 +351,19 @@ fn is_zeroes(data: &[u8]) -> bool { constant_time_eq(data, comparator) } -fn encipher(key: &Key, tweaks: Tweak, message: &mut [u8]) { +fn encipher(aez: &Aez, tweaks: Tweak, message: &mut [u8]) { if message.len() < 256 / 8 { - encipher_aez_tiny(key, tweaks, message) + encipher_aez_tiny(aez, tweaks, message) } else { - encipher_aez_core(key, tweaks, message) + encipher_aez_core(aez, tweaks, message) } } -fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &mut [u8]) { +fn encipher_aez_tiny(aez: &Aez, tweaks: Tweak, message: &mut [u8]) { let mu = message.len() * 8; assert!(mu < 256); let n = mu / 2; - let delta = aez_hash(key, tweaks); + let delta = aez_hash(aez, tweaks); let round_count = match mu { 8 => 24u32, 16 => 16, @@ -360,7 +383,7 @@ fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &mut [u8]) { }; let i = if mu >= 128 { 6 } else { 7 }; for j in 0..round_count { - let right_ = (left ^ e(0, i, key, delta ^ right.pad(n) ^ Block::from_int(j))).clip(n); + let right_ = (left ^ e(0, i, aez, delta ^ right.pad(n) ^ Block::from_int(j))).clip(n); (left, right) = (right, right_); } if n % 8 == 0 { @@ -379,14 +402,14 @@ fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &mut [u8]) { } if mu < 128 { let mut c = Block::from_slice(&message); - c = c ^ (e(0, 3, key, delta ^ (c | Block::ONE)) & Block::ONE); + c = c ^ (e(0, 3, aez, delta ^ (c | Block::ONE)) & Block::ONE); message.copy_from_slice(&c.0[..mu / 8]); } } -fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &mut [u8]) { +fn encipher_aez_core(aez: &Aez, tweaks: Tweak, message: &mut [u8]) { assert!(message.len() >= 32); - let delta = aez_hash(key, tweaks); + let delta = aez_hash(aez, tweaks); let mut blocks = BlockAccessor::new(message); let (m_u, m_v, m_x, m_y, d) = ( blocks.m_u(), @@ -398,8 +421,8 @@ fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &mut [u8]) { let len_v = d.saturating_sub(128); let mut x = Block::NULL; - let mut e1_eval = E::new(1, 0, key); - let e0_eval = E::new(0, 0, key); + let mut e1_eval = E::new(1, 0, aez); + let e0_eval = E::new(0, 0, aez); for (raw_mi, raw_mi_) in blocks.pairs_mut() { e1_eval.advance(); @@ -417,22 +440,22 @@ fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &mut [u8]) { match d { 0 => (), _ if d <= 127 => { - x = x ^ e(0, 4, key, m_u.pad(d.into())); + x = x ^ e(0, 4, aez, m_u.pad(d.into())); } _ => { - x = x ^ e(0, 4, key, m_u); - x = x ^ e(0, 5, key, m_v.pad(len_v.into())); + x = x ^ e(0, 4, aez, m_u); + x = x ^ e(0, 5, aez, m_v.pad(len_v.into())); } } - let s_x = m_x ^ delta ^ x ^ e(0, 1, key, m_y); - let s_y = m_y ^ e(-1, 1, key, s_x); + let s_x = m_x ^ delta ^ x ^ e(0, 1, aez, m_y); + let s_y = m_y ^ e(-1, 1, aez, s_x); let s = s_x ^ s_y; let mut y = Block::NULL; - let mut e2_eval = E::new(2, 0, key); - let mut e1_eval = E::new(1, 0, key); - let e0_eval = E::new(0, 0, key); + let mut e2_eval = E::new(2, 0, aez); + let mut e1_eval = E::new(1, 0, aez); + let e0_eval = E::new(0, 0, aez); for (raw_wi, raw_xi) in blocks.pairs_mut() { e2_eval.advance(); e1_eval.advance(); @@ -455,19 +478,19 @@ fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &mut [u8]) { match d { 0 => (), _ if d <= 127 => { - c_u = (m_u ^ e(-1, 4, key, s)).clip(d.into()); - y = y ^ e(0, 4, key, c_u.pad(d.into())); + c_u = (m_u ^ e(-1, 4, aez, s)).clip(d.into()); + y = y ^ e(0, 4, aez, c_u.pad(d.into())); } _ => { - c_u = m_u ^ e(-1, 4, key, s); - c_v = (m_v ^ e(-1, 5, key, s)).clip(len_v.into()); - y = y ^ e(0, 4, key, c_u); - y = y ^ e(0, 5, key, c_v.pad(len_v.into())); + c_u = m_u ^ e(-1, 4, aez, s); + c_v = (m_v ^ e(-1, 5, aez, s)).clip(len_v.into()); + y = y ^ e(0, 4, aez, c_u); + y = y ^ e(0, 5, aez, c_v.pad(len_v.into())); } } - let c_y = s_x ^ e(-1, 2, key, s_y); - let c_x = s_y ^ delta ^ y ^ e(0, 2, key, c_y); + let c_y = s_x ^ e(-1, 2, aez, s_y); + let c_x = s_y ^ delta ^ y ^ e(0, 2, aez, c_y); blocks.set_m_u(c_u); blocks.set_m_v(c_v); @@ -475,19 +498,19 @@ fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &mut [u8]) { blocks.set_m_y(c_y); } -fn decipher(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { +fn decipher(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) { if buffer.len() < 256 / 8 { - decipher_aez_tiny(key, tweaks, buffer); + decipher_aez_tiny(aez, tweaks, buffer); } else { - decipher_aez_core(key, tweaks, buffer); + decipher_aez_core(aez, tweaks, buffer); } } -fn decipher_aez_tiny(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { +fn decipher_aez_tiny(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) { let mu = buffer.len() * 8; assert!(mu < 256); let n = mu / 2; - let delta = aez_hash(key, tweaks); + let delta = aez_hash(aez, tweaks); let round_count = match mu { 8 => 24u32, 16 => 16, @@ -497,7 +520,7 @@ fn decipher_aez_tiny(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { if mu < 128 { let mut c = Block::from_slice(buffer); - c = c ^ (e(0, 3, key, delta ^ (c | Block::ONE)) & Block::ONE); + c = c ^ (e(0, 3, aez, delta ^ (c | Block::ONE)) & Block::ONE); buffer.copy_from_slice(&c.0[..mu / 8]); } @@ -512,7 +535,7 @@ fn decipher_aez_tiny(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { }; let i = if mu >= 128 { 6 } else { 7 }; for j in (0..round_count).rev() { - let right_ = (left ^ e(0, i, key, delta ^ right.pad(n) ^ Block::from_int(j))).clip(n); + let right_ = (left ^ e(0, i, aez, delta ^ right.pad(n) ^ Block::from_int(j))).clip(n); (left, right) = (right, right_); } @@ -532,9 +555,9 @@ fn decipher_aez_tiny(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { } } -fn decipher_aez_core(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { +fn decipher_aez_core(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) { assert!(buffer.len() >= 32); - let delta = aez_hash(key, tweaks); + let delta = aez_hash(aez, tweaks); let mut blocks = BlockAccessor::new(buffer); let (c_u, c_v, c_x, c_y, d) = ( blocks.m_u(), @@ -546,8 +569,8 @@ fn decipher_aez_core(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { let len_v = d.saturating_sub(128); let mut y = Block::NULL; - let mut e1_eval = E::new(1, 0, key); - let e0_eval = E::new(0, 0, key); + let mut e1_eval = E::new(1, 0, aez); + let e0_eval = E::new(0, 0, aez); for (raw_ci, raw_ci_) in blocks.pairs_mut() { e1_eval.advance(); let ci = Block::from(*raw_ci); @@ -564,22 +587,22 @@ fn decipher_aez_core(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { match d { 0 => (), _ if d <= 127 => { - y = y ^ e(0, 4, key, c_u.pad(d.into())); + y = y ^ e(0, 4, aez, c_u.pad(d.into())); } _ => { - y = y ^ e(0, 4, key, c_u); - y = y ^ e(0, 5, key, c_v.pad(len_v.into())); + y = y ^ e(0, 4, aez, c_u); + y = y ^ e(0, 5, aez, c_v.pad(len_v.into())); } } - let s_x = c_x ^ delta ^ y ^ e(0, 2, key, c_y); - let s_y = c_y ^ e(-1, 2, key, s_x); + let s_x = c_x ^ delta ^ y ^ e(0, 2, aez, c_y); + let s_y = c_y ^ e(-1, 2, aez, s_x); let s = s_x ^ s_y; let mut x = Block::NULL; - let mut e2_eval = E::new(2, 0, key); - let mut e1_eval = E::new(1, 0, key); - let e0_eval = E::new(0, 0, key); + let mut e2_eval = E::new(2, 0, aez); + let mut e1_eval = E::new(1, 0, aez); + let e0_eval = E::new(0, 0, aez); for (raw_wi, raw_yi) in blocks.pairs_mut() { e2_eval.advance(); e1_eval.advance(); @@ -603,19 +626,19 @@ fn decipher_aez_core(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { match d { 0 => (), _ if d <= 127 => { - m_u = (c_u ^ e(-1, 4, key, s)).clip(d.into()); - x = x ^ e(0, 4, key, m_u.pad(d.into())); + m_u = (c_u ^ e(-1, 4, aez, s)).clip(d.into()); + x = x ^ e(0, 4, aez, m_u.pad(d.into())); } _ => { - m_u = c_u ^ e(-1, 4, key, s); - m_v = (c_v ^ e(-1, 5, key, s)).clip(len_v.into()); - x = x ^ e(0, 4, key, m_u); - x = x ^ e(0, 5, key, m_v.pad(len_v.into())); + m_u = c_u ^ e(-1, 4, aez, s); + m_v = (c_v ^ e(-1, 5, aez, s)).clip(len_v.into()); + x = x ^ e(0, 4, aez, m_u); + x = x ^ e(0, 5, aez, m_v.pad(len_v.into())); } } - let m_y = s_x ^ e(-1, 1, key, s_y); - let m_x = s_y ^ delta ^ x ^ e(0, 1, key, m_y); + let m_y = s_x ^ e(-1, 1, aez, s_y); + let m_x = s_y ^ delta ^ x ^ e(0, 1, aez, m_y); blocks.set_m_u(m_u); blocks.set_m_v(m_v); @@ -635,7 +658,7 @@ fn pad_to_blocks(value: &[u8]) -> Vec<Block> { blocks } -fn aez_hash(key: &Key, tweaks: Tweak) -> Block { +fn aez_hash(aez: &Aez, tweaks: Tweak) -> Block { let mut hash = Block::NULL; for (i, tweak) in tweaks.iter().enumerate() { // Adjust for zero-based vs one-based indexing @@ -644,14 +667,14 @@ fn aez_hash(key: &Key, tweaks: Tweak) -> Block { // set l = 1 and then xor E_K^{j, 0}(10*). We could modify the last if branch to cover this // as well, but then we need to fiddle with getting an empty chunk from an empty iterator. if tweak.is_empty() { - hash = hash ^ e(j.try_into().unwrap(), 0, key, Block::ONE); + hash = hash ^ e(j.try_into().unwrap(), 0, aez, Block::ONE); } else if tweak.len() % 16 == 0 { for (l, chunk) in tweak.chunks(16).enumerate() { hash = hash ^ e( j.try_into().unwrap(), (l + 1).try_into().unwrap(), - key, + aez, Block::from_slice(chunk), ); } @@ -666,7 +689,7 @@ fn aez_hash(key: &Key, tweaks: Tweak) -> Block { } else { (l + 1).try_into().unwrap() }, - key, + aez, *chunk, ); } @@ -676,11 +699,11 @@ fn aez_hash(key: &Key, tweaks: Tweak) -> Block { } /// XOR's the result of aez_prf into the given buffer -fn aez_prf(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { +fn aez_prf(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) { let mut index = 0u128; - let delta = aez_hash(key, tweaks); + let delta = aez_hash(aez, tweaks); for chunk in buffer.chunks_mut(16) { - let block = e(-1, 3, key, delta ^ Block::from_int(index)); + let block = e(-1, 3, aez, delta ^ Block::from_int(index)); for (a, b) in chunk.iter_mut().zip(block.0.iter()) { *a ^= b; } @@ -692,11 +715,9 @@ fn aez_prf(key: &Key, tweaks: Tweak, buffer: &mut [u8]) { /// /// As we usually need multiple values with a fixed j and ascending i, this struct saves the /// temporary values and makes it much faster to compute E_K^{j, i+1}, E_K^{j, i+2}, ... -struct E { - key_l: Block, - key_ls: [Block; 8], +struct E<'a> { + aez: &'a Aez, state: Estate, - aes: aesround::AesImpl, } #[derive(Clone, Debug)] @@ -711,11 +732,9 @@ enum Estate { }, } -impl E { +impl<'a> E<'a> { /// Create a new "suspended" computation of E_K^{j,i}. - fn new(j: i32, i: u32, key: &Key) -> Self { - let (key_i, key_j, key_l) = split_key(key); - let aes = aesround::AesImpl::new(key_i, key_j, key_l); + fn new(j: i32, i: u32, aez: &'a Aez) -> Self { let state = if j == -1 { Estate::Neg { i } } else { @@ -723,25 +742,13 @@ impl E { let exponent = if i % 8 == 0 { i / 8 } else { i / 8 + 1 }; Estate::Pos { i, - kj_t_j: key_j * j, - ki_p_i: key_i.exp(exponent), + kj_t_j: aez.key_j * j, + ki_p_i: aez.key_i.exp(exponent), } }; - let key_ls = [ - key_l * 0, - key_l * 1, - key_l * 2, - key_l * 3, - key_l * 4, - key_l * 5, - key_l * 6, - key_l * 7, - ]; E { - key_l, - key_ls, + aez, state, - aes, } } @@ -749,12 +756,12 @@ impl E { fn eval(&self, block: Block) -> Block { match self.state { Estate::Neg { i } => { - let delta = self.key_l * i; - self.aes.aes10(block ^ delta) + let delta = self.aez.key_l * i; + self.aez.aes.aes10(block ^ delta) } Estate::Pos { i, kj_t_j, ki_p_i } => { - let delta = kj_t_j ^ ki_p_i ^ self.key_ls[i as usize % 8]; - self.aes.aes4(block ^ delta) + let delta = kj_t_j ^ ki_p_i ^ self.aez.key_l_multiples[i as usize % 8]; + self.aez.aes.aes4(block ^ delta) } } } @@ -781,8 +788,8 @@ impl E { } /// Shorthand to get E_K^{j,i}(block) -fn e(j: i32, i: u32, key: &Key, block: Block) -> Block { - E::new(j, i, key).eval(block) +fn e(j: i32, i: u32, aez: &Aez, block: Block) -> Block { + E::new(j, i, aez).eval(block) } fn split_key(key: &Key) -> (Block, Block, Block) { @@ -813,11 +820,11 @@ mod test { for (k, j, i, a, b) in testvectors::E_VECTORS { let name = format!("e({j}, {i}, {k}, {a})"); let k = hex::decode(k).unwrap(); - let k = k.as_slice().try_into().unwrap(); + let aez = Aez::new(k.as_slice()); let a = hex::decode(a).unwrap(); let a = Block::from_slice(&a); let b = hex::decode(b).unwrap(); - assert_eq!(&e(*j, *i, k, a).0, b.as_slice(), "{name}"); + assert_eq!(&e(*j, *i, &aez, a).0, b.as_slice(), "{name}"); } } @@ -826,7 +833,7 @@ mod test { for (k, tau, tw, v) in testvectors::HASH_VECTORS { let name = format!("aez_hash({k}, {tau}, {tw:?})"); let k = hex::decode(k).unwrap(); - let k = k.as_slice().try_into().unwrap(); + let aez = Aez::new(k.as_slice()); let v = hex::decode(v).unwrap(); let mut tweaks = vec![Vec::from(Block::from_int(*tau).0)]; @@ -835,14 +842,15 @@ mod test { } let tweaks = tweaks.iter().map(Vec::as_slice).collect::<Vec<_>>(); - assert_eq!(&aez_hash(&k, &tweaks).0, v.as_slice(), "{name}"); + assert_eq!(&aez_hash(&aez, &tweaks).0, v.as_slice(), "{name}"); } } fn vec_encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, message: &[u8]) -> Vec<u8> { + let aez = Aez::new(key); let mut v = vec![0; message.len() + tau as usize]; v[..message.len()].copy_from_slice(message); - encrypt(key, nonce, ad, tau, &mut v); + encrypt(&aez, nonce, ad, tau, &mut v); v } @@ -853,8 +861,9 @@ mod test { tau: u32, ciphertext: &[u8], ) -> Option<Vec<u8>> { + let aez = Aez::new(key); let mut v = Vec::from(ciphertext); - let len = match decrypt(key, nonce, ad, tau, &mut v) { + let len = match decrypt(&aez, nonce, ad, tau, &mut v) { None => return None, Some(m) => m.len(), }; |