diff options
-rw-r--r-- | src/lib.rs | 69 |
1 files changed, 35 insertions, 34 deletions
@@ -102,6 +102,8 @@ //! | +simd, target-cpu=native | 3.3272 GiB/s | +592.01% | //! | `aez` crate | 4.8996 GiB/s | | +use std::iter; + use constant_time_eq::constant_time_eq; mod accessor; @@ -115,7 +117,6 @@ use accessor::BlockAccessor; use aesround::AesRound; use block::Block; type Key = [u8; 48]; -type Tweak<'a> = &'a [&'a [u8]]; static ZEROES: [u8; 1024] = [0; 1024]; @@ -322,23 +323,19 @@ fn append_auth(data_len: usize, buffer: &mut [u8]) { } } -fn encrypt(aez: &Aez, nonce: &[u8], ad: &[&[u8]], tau: u32, buffer: &mut [u8]) { +fn encrypt<'t, A: AsRef<[u8]> + 't, T: IntoIterator<Item = &'t A>>( + aez: &Aez, + nonce: &[u8], + ad: T, + tau: u32, + buffer: &mut [u8], +) { // We treat tau as bytes, but according to the spec, tau is actually in bits. let tau_block = Block::from_int(tau as u128 * 8); let tau_bytes = tau_block.bytes(); - let mut tweaks_vec; - // We optimize for the common case of having no associated data, or having one item of - // associated data (which is all the reference implementation supports anyway). If there's more - // associated data, we cave in and allocate a vec. - let tweaks = match ad.len() { - 0 => &[&tau_bytes, nonce] as &[&[u8]], - 1 => &[&tau_bytes, nonce, ad[0]], - _ => { - tweaks_vec = vec![&tau_bytes, nonce]; - tweaks_vec.extend(ad); - &tweaks_vec - } - }; + let tweaks = iter::once(&tau_bytes as &[_]) + .chain(iter::once(nonce)) + .chain(ad.into_iter().map(|r| r.as_ref())); assert!(buffer.len() >= tau as usize); if buffer.len() == tau as usize { // As aez_prf only xor's the input in, we have to clear the buffer first @@ -349,10 +346,10 @@ fn encrypt(aez: &Aez, nonce: &[u8], ad: &[&[u8]], tau: u32, buffer: &mut [u8]) { } } -fn decrypt<'a>( +fn decrypt<'a, 't, A: AsRef<[u8]> + 't, T: IntoIterator<Item = &'t A>>( aez: &Aez, nonce: &[u8], - ad: &[&[u8]], + ad: T, tau: u32, ciphertext: &'a mut [u8], ) -> Option<&'a [u8]> { @@ -362,16 +359,9 @@ fn decrypt<'a>( let tau_block = Block::from_int(tau * 8); let tau_bytes = tau_block.bytes(); - let mut tweaks_vec; - let tweaks = match ad.len() { - 0 => &[&tau_bytes, nonce] as &[&[u8]], - 1 => &[&tau_bytes, nonce, ad[0]], - _ => { - tweaks_vec = vec![&tau_bytes, nonce]; - tweaks_vec.extend(ad); - &tweaks_vec - } - }; + let tweaks = iter::once(&tau_bytes as &[_]) + .chain(iter::once(nonce)) + .chain(ad.into_iter().map(|x| x.as_ref())); if ciphertext.len() == tau as usize { aez_prf(aez, tweaks, ciphertext); @@ -400,7 +390,7 @@ fn is_zeroes(data: &[u8]) -> bool { constant_time_eq(data, comparator) } -fn encipher(aez: &Aez, tweaks: Tweak, message: &mut [u8]) { +fn encipher<A: AsRef<[u8]>, T: IntoIterator<Item = A>>(aez: &Aez, tweaks: T, message: &mut [u8]) { if message.len() < 256 / 8 { cipher_aez_tiny(Mode::Encipher, aez, tweaks, message) } else { @@ -408,7 +398,7 @@ fn encipher(aez: &Aez, tweaks: Tweak, message: &mut [u8]) { } } -fn decipher(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) { +fn decipher<A: AsRef<[u8]>, T: IntoIterator<Item = A>>(aez: &Aez, tweaks: T, buffer: &mut [u8]) { if buffer.len() < 256 / 8 { cipher_aez_tiny(Mode::Decipher, aez, tweaks, buffer); } else { @@ -416,7 +406,12 @@ fn decipher(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) { } } -fn cipher_aez_tiny(mode: Mode, aez: &Aez, tweaks: Tweak, message: &mut [u8]) { +fn cipher_aez_tiny<A: AsRef<[u8]>, T: IntoIterator<Item = A>>( + mode: Mode, + aez: &Aez, + tweaks: T, + message: &mut [u8], +) { let mu = message.len() * 8; assert!(mu < 256); let n = mu / 2; @@ -586,7 +581,12 @@ fn pass_two(aez: &Aez, blocks: &mut BlockAccessor, s: Block) -> Block { y } -fn cipher_aez_core(mode: Mode, aez: &Aez, tweaks: Tweak, message: &mut [u8]) { +fn cipher_aez_core<A: AsRef<[u8]>, T: IntoIterator<Item = A>>( + mode: Mode, + aez: &Aez, + tweaks: T, + message: &mut [u8], +) { assert!(message.len() >= 32); let delta = aez_hash(aez, tweaks); let mut blocks = BlockAccessor::new(message); @@ -672,9 +672,10 @@ fn pad_to_blocks(value: &[u8]) -> impl Iterator<Item = Block> { }) } -fn aez_hash(aez: &Aez, tweaks: Tweak) -> Block { +fn aez_hash<A: AsRef<[u8]>, T: IntoIterator<Item = A>>(aez: &Aez, tweaks: T) -> Block { let mut hash = Block::null(); - for (i, tweak) in tweaks.iter().enumerate() { + for (i, tweak) in tweaks.into_iter().enumerate() { + let tweak = tweak.as_ref(); // Adjust for zero-based vs one-based indexing let j = i + 2 + 1; let mut ej = E::new(j.try_into().unwrap(), 0, aez); @@ -704,7 +705,7 @@ fn aez_hash(aez: &Aez, tweaks: Tweak) -> Block { } /// XOR's the result of aez_prf into the given buffer -fn aez_prf(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) { +fn aez_prf<A: AsRef<[u8]>, T: IntoIterator<Item = A>>(aez: &Aez, tweaks: T, buffer: &mut [u8]) { let mut index = Block::null(); let delta = aez_hash(aez, tweaks); for chunk in buffer.chunks_exact_mut(16) { |