Diffstat
-rw-r--r--  .woodpecker/tests.yaml   13
-rw-r--r--  CHANGELOG                13
-rw-r--r--  Cargo.lock                2
-rw-r--r--  Cargo.toml               14
-rw-r--r--  benches/primitives.rs    60
-rw-r--r--  fuzz/Cargo.lock           2
-rw-r--r--  src/accessor.rs          62
-rw-r--r--  src/aesround.rs           8
-rw-r--r--  src/block.rs             76
-rw-r--r--  src/lib.rs              513
10 files changed, 604 insertions, 159 deletions
diff --git a/.woodpecker/tests.yaml b/.woodpecker/tests.yaml
new file mode 100644
index 0000000..45662d9
--- /dev/null
+++ b/.woodpecker/tests.yaml
@@ -0,0 +1,13 @@
+when:
+  - event: push
+    branch: master
+
+steps:
+  - name: test
+    image: rust
+    commands:
+      - cargo test
+  - name: test(simd)
+    image: rustlang/rust:nightly
+    commands:
+      - cargo test --features=simd
diff --git a/CHANGELOG b/CHANGELOG
new file mode 100644
index 0000000..f9e6626
--- /dev/null
+++ b/CHANGELOG
@@ -0,0 +1,13 @@
+v0.2.1:
+
+- Documentation fixes
+- Small speed improvements
+
+v0.2.0:
+
+- Added the simd feature
+- Speed improvements
+
+v0.1.0:
+
+- Initial release
diff --git a/Cargo.lock b/Cargo.lock
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -718,7 +718,7 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
 
 [[package]]
 name = "zears"
-version = "0.1.0"
+version = "0.2.1"
 dependencies = [
  "aes",
  "blake2",
diff --git a/Cargo.toml b/Cargo.toml
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "zears"
-version = "0.1.0"
+version = "0.2.1"
 edition = "2024"
 license = "MIT"
 description = "Rust implementation of the AEZ v5 cipher."
@@ -10,12 +10,15 @@ keywords = ["aez", "aezv5", "aead", "encryption", "cipher"]
 categories = ["algorithms", "cryptography"]
 
 [features]
+default = ["std"]
 simd = []
+std = ["constant_time_eq/std", "blake2/std"]
+primitives = []
 
 [dependencies]
 aes = { version = "0.8.4", features = ["hazmat"] }
-blake2 = "0.10.6"
-constant_time_eq = "0.4.2"
+blake2 = { version = "0.10.6", default-features = false }
+constant_time_eq = { version = "0.4.2", default-features = false }
 cpufeatures = "0.2.17"
 
 [dev-dependencies]
@@ -25,3 +28,8 @@ criterion = "0.5.1"
 [[bench]]
 name = "zears"
 harness = false
+
+[[bench]]
+name = "primitives"
+required-features = ["primitives"]
+harness = false
diff --git a/benches/primitives.rs b/benches/primitives.rs
new file mode 100644
index 0000000..09f926a
--- /dev/null
+++ b/benches/primitives.rs
@@ -0,0 +1,60 @@
+use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main};
+
+use zears::{Aez, primitives};
+
+fn bench(c: &mut Criterion) {
+    let mut group = c.benchmark_group("primitives");
+
+    const KB: usize = 1024;
+    let aez = Aez::new(&[0u8; 48]);
+
+    for size in [0, 16, 32, 64, KB].into_iter() {
+        group.throughput(Throughput::Bytes(size as u64));
+
+        group.bench_function(BenchmarkId::new("aez_hash", size), |b| {
+            let tweak = vec![0u8; size];
+            b.iter(|| primitives::aez_hash(&aez, [&tweak]))
+        });
+
+        // Make sure we also hit the path for tweaks that are not exactly block sized
+        let size = size + 8;
+
+        group.throughput(Throughput::Bytes(size as u64));
+
+        group.bench_function(BenchmarkId::new("aez_hash", size), |b| {
+            let tweak = vec![0u8; size];
+            b.iter(|| primitives::aez_hash(&aez, [&tweak]))
+        });
+    }
+
+    for size in [KB, 2 * KB, 4 * KB, 8 * KB, 16 * KB].into_iter() {
+        group.throughput(Throughput::Bytes(size as u64));
+
+        group.bench_function(BenchmarkId::new("aez_prf", size), |b| {
+            let mut buffer = vec![0u8; size];
+            let tweak: [&[u8]; 0] = [];
+            b.iter(|| primitives::aez_prf(&aez, tweak, &mut buffer))
+        });
+
+        group.bench_function(BenchmarkId::new("encipher", size), |b| {
+            let mut buffer = vec![0u8; size];
+            let tweak: [&[u8]; 0] = [];
+            b.iter(|| primitives::encipher(&aez, tweak, &mut buffer))
+        });
+
+        group.bench_function(BenchmarkId::new("decipher", size), |b| {
+            let mut buffer = vec![0u8; size];
+            let tweak: [&[u8]; 0] = [];
+            b.iter(|| primitives::decipher(&aez, tweak, &mut buffer))
+        });
+    }
+
+    group.finish();
+
+    c.bench_function("primitives/e", |b| {
+        b.iter(|| primitives::e(1, 1, &aez, [0; 16]))
+    });
+}
+
+criterion_group!(benches, bench);
+criterion_main!(benches);
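The new benchmark is kept out of default builds by the `required-features = ["primitives"]` entry in Cargo.toml above. Presumably it is invoked along the lines of (command shape assumed, not shown in the patch):

    cargo bench --features primitives --bench primitives
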
diff --git a/fuzz/Cargo.lock b/fuzz/Cargo.lock
index 1fce2f7..b532992 100644
--- a/fuzz/Cargo.lock
+++ b/fuzz/Cargo.lock
@@ -269,7 +269,7 @@ dependencies = [
 
 [[package]]
 name = "zears"
-version = "0.1.0"
+version = "0.2.1"
 dependencies = [
  "aes",
  "blake2",
diff --git a/src/accessor.rs b/src/accessor.rs
index 24905af..e7e6de6 100644
--- a/src/accessor.rs
+++ b/src/accessor.rs
@@ -8,6 +8,8 @@ pub struct BlockAccessor<'a> {
     num_block_pairs: usize,
 }
 
+const BIG_CHUNK_SIZE: usize = 8 * 32;
+
 impl<'a> BlockAccessor<'a> {
     pub fn new(message: &'a mut [u8]) -> Self {
         let num_block_pairs = (message.len() - 16 - 16) / 32;
@@ -71,13 +73,67 @@ impl<'a> BlockAccessor<'a> {
         self.data[start + 16..start + 32].copy_from_slice(&m_y.bytes());
     }
 
-    pub fn pairs_mut<'b>(
-        &'b mut self,
-    ) -> impl Iterator<Item = (&'b mut [u8; 16], &'b mut [u8; 16])> {
+    pub fn pairs_mut(&mut self) -> impl Iterator<Item = (&mut [u8; 16], &mut [u8; 16])> {
         let stop = self.suffix_start();
         self.data[..stop]
             .chunks_exact_mut(32)
             .map(move |x| x.split_at_mut(16))
             .map(move |(x, y)| (x.try_into().unwrap(), y.try_into().unwrap()))
     }
+
+    pub fn suffix_8_mut(&mut self) -> impl Iterator<Item = (&mut [u8; 16], &mut [u8; 16])> {
+        let start = self.suffix_start() / BIG_CHUNK_SIZE * BIG_CHUNK_SIZE;
+        let stop = self.suffix_start();
+        self.data[start..stop]
+            .chunks_exact_mut(32)
+            .map(move |x| x.split_at_mut(16))
+            .map(move |(x, y)| (x.try_into().unwrap(), y.try_into().unwrap()))
+    }
+
+    pub fn pairs_8_mut(
+        &mut self,
+    ) -> impl Iterator<Item = ([&mut [u8; 16]; 8], [&mut [u8; 16]; 8])> {
+        let stop = self.suffix_start() / BIG_CHUNK_SIZE * BIG_CHUNK_SIZE;
+        self.data[..stop]
+            .chunks_exact_mut(BIG_CHUNK_SIZE)
+            .map(move |x| {
+                let (b0, b1) = x.split_at_mut(BIG_CHUNK_SIZE / 2);
+                let (b00, b01) = b0.split_at_mut(BIG_CHUNK_SIZE / 4);
+                let (b10, b11) = b1.split_at_mut(BIG_CHUNK_SIZE / 4);
+                let (b000, b001) = b00.split_at_mut(BIG_CHUNK_SIZE / 8);
+                let (b010, b011) = b01.split_at_mut(BIG_CHUNK_SIZE / 8);
+                let (b100, b101) = b10.split_at_mut(BIG_CHUNK_SIZE / 8);
+                let (b110, b111) = b11.split_at_mut(BIG_CHUNK_SIZE / 8);
+                let (b0000, b0001) = b000.split_at_mut(16);
+                let (b0010, b0011) = b001.split_at_mut(16);
+                let (b0100, b0101) = b010.split_at_mut(16);
+                let (b0110, b0111) = b011.split_at_mut(16);
+                let (b1000, b1001) = b100.split_at_mut(16);
+                let (b1010, b1011) = b101.split_at_mut(16);
+                let (b1100, b1101) = b110.split_at_mut(16);
+                let (b1110, b1111) = b111.split_at_mut(16);
+                (
+                    [
+                        b0000.try_into().unwrap(),
+                        b0010.try_into().unwrap(),
+                        b0100.try_into().unwrap(),
+                        b0110.try_into().unwrap(),
+                        b1000.try_into().unwrap(),
+                        b1010.try_into().unwrap(),
+                        b1100.try_into().unwrap(),
+                        b1110.try_into().unwrap(),
+                    ],
+                    [
+                        b0001.try_into().unwrap(),
+                        b0011.try_into().unwrap(),
+                        b0101.try_into().unwrap(),
+                        b0111.try_into().unwrap(),
+                        b1001.try_into().unwrap(),
+                        b1011.try_into().unwrap(),
+                        b1101.try_into().unwrap(),
+                        b1111.try_into().unwrap(),
+                    ],
+                )
+            })
+    }
 }
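The cascade of `split_at_mut` calls in `pairs_8_mut` is how safe Rust obtains eight disjoint pairs of mutable 16-byte views into one buffer. A minimal self-contained sketch of the same borrow-splitting idea (helper name hypothetical, not part of the patch):

    fn split_pair(chunk: &mut [u8]) -> (&mut [u8; 16], &mut [u8; 16]) {
        // split_at_mut proves to the borrow checker that the two halves are disjoint
        let (left, right) = chunk.split_at_mut(16);
        (left.try_into().unwrap(), right.try_into().unwrap())
    }

    fn main() {
        let mut buf = [0u8; 32];
        let (l, r) = split_pair(&mut buf);
        l[0] = 1;
        r[0] = 2;
        assert_eq!((buf[0], buf[16]), (1, 2));
    }
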
diff --git a/src/aesround.rs b/src/aesround.rs
index 6f63243..4ae3f6f 100644
--- a/src/aesround.rs
+++ b/src/aesround.rs
@@ -61,9 +61,11 @@ impl AesRound for AesSoft {
 // under the hood), but there is a big benefit here:
 // First, we can save time by only loading the keys once as a __m128i, which makes the whole thing
 // a bit faster.
-// More importantly though, when using target-cpu=native, we get nicely vectorized AES instructions
-// (VAESENC), which we don't get if we go through aes::hazmat::cipher_round. This is a *huge*
-// speedup, which we don't want to miss.
+// More importantly though, the compiler does not inline the call to cipher_round, even when using
+// target-cpu=native. I guess this is because it crosses a crate boundary (and cross-crate inlining
+// only happens with LTO). In fact, compiling with lto=true does inline the call, but we don't want
+// to force that on all library users. Anyway, by re-implementing the AES instruction here, we get
+// nice inlining without relying on LTO and therefore a huge speedup, as AES is called a lot.
 #[cfg(target_arch = "x86_64")]
 pub mod x86_64 {
     use super::*;
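As the rewritten comment notes, downstream users could get the same cross-crate inlining by enabling link-time optimization in their own profile; the crate simply avoids forcing that choice. For reference, the standard Cargo knob (user-side configuration, not part of this patch) is:

    [profile.release]
    lto = true
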
diff --git a/src/block.rs b/src/block.rs
index bde60fc..8d6517d 100644
--- a/src/block.rs
+++ b/src/block.rs
@@ -1,4 +1,9 @@
+#[cfg(feature = "std")]
 use std::ops::{BitAnd, BitOr, BitXor, Index, IndexMut, Mul, Shl, Shr};
+#[cfg(not(feature = "std"))]
+use core::ops::{BitAnd, BitOr, BitXor, Index, IndexMut, Mul, Shl, Shr};
+
+
 
 #[cfg(feature = "simd")]
 use std::simd::prelude::*;
@@ -12,6 +17,18 @@ pub struct Block(u8x16);
 #[cfg(not(feature = "simd"))]
 pub struct Block([u8; 16]);
 
+macro_rules! add_ladder {
+    ($ar:expr, $lit:literal) => {
+        $ar[$lit] = $ar[$lit].wrapping_add(1);
+    };
+    ($ar:expr, $lit:literal $($rest:literal) +) => {
+        $ar[$lit] = $ar[$lit].wrapping_add(1);
+        if $ar[$lit] == 0 {
+            add_ladder!($ar, $($rest) +);
+        }
+    };
+}
+
 impl Block {
     pub fn null() -> Block {
         Block([0; 16].into())
@@ -61,7 +78,7 @@
         Block(value.into().to_be_bytes().into())
     }
 
-    pub fn to_int(&self) -> u128 {
+    pub fn to_int(self) -> u128 {
         u128::from_be_bytes(self.0.into())
     }
 
@@ -72,7 +89,19 @@
     /// This corresponds to X10* in the paper.
     pub fn pad(&self, length: usize) -> Block {
         assert!(length <= 127);
-        Block::from_int(self.to_int() | (1 << (127 - length)))
+        let mut result = *self;
+        result[length / 8] |= 1 << (7 - length % 8);
+        result
+    }
+
+    /// Pad the block to full length.
+    ///
+    /// Unlike [`pad`], this function takes the length in bytes.
+    pub fn pad_bytes(&self, length: u8) -> Block {
+        assert!(length <= 15);
+        let mut result = *self;
+        result[length as usize] = 0x80;
+        result
     }
 
     /// Clip the block by setting all bits beyond the given length to 0.
@@ -86,13 +115,17 @@
     /// Computes self * 2^exponent
     ///
     /// Ensures that there's no overflow in computing 2^exponent.
-    pub fn exp(&self, exponent: u32) -> Block {
+    pub fn exp(&self, exponent: usize) -> Block {
         match exponent {
             _ if exponent < 32 => *self * (1 << exponent),
             _ if exponent % 2 == 0 => self.exp(exponent / 2).exp(exponent / 2),
             _ => (*self * 2).exp(exponent - 1),
         }
     }
+
+    pub fn count_up(&mut self) {
+        add_ladder!(self, 15 14 13 12 11 10 9 8 7 6 5 4 3 2 1 0);
+    }
 }
 
 impl From<[u8; 16]> for Block {
@@ -148,6 +181,37 @@ impl BitXor<Block> for Block {
 impl Shl<u32> for Block {
     type Output = Block;
     fn shl(self, rhs: u32) -> Block {
+        // We often use a shift by one, for example in the multiplication. We therefore optimize
+        // for this special case.
+        #[cfg(feature = "simd")]
+        {
+            if rhs == 1 {
+                return Block((self.0 << 1) | (self.0.shift_elements_left::<1>(0) >> 7));
+            }
+        }
+        #[cfg(not(feature = "simd"))]
+        {
+            if rhs == 1 {
+                return Block([
+                    (self.0[0] << 1) | (self.0[1] >> 7),
+                    (self.0[1] << 1) | (self.0[2] >> 7),
+                    (self.0[2] << 1) | (self.0[3] >> 7),
+                    (self.0[3] << 1) | (self.0[4] >> 7),
+                    (self.0[4] << 1) | (self.0[5] >> 7),
+                    (self.0[5] << 1) | (self.0[6] >> 7),
+                    (self.0[6] << 1) | (self.0[7] >> 7),
+                    (self.0[7] << 1) | (self.0[8] >> 7),
+                    (self.0[8] << 1) | (self.0[9] >> 7),
+                    (self.0[9] << 1) | (self.0[10] >> 7),
+                    (self.0[10] << 1) | (self.0[11] >> 7),
+                    (self.0[11] << 1) | (self.0[12] >> 7),
+                    (self.0[12] << 1) | (self.0[13] >> 7),
+                    (self.0[13] << 1) | (self.0[14] >> 7),
+                    (self.0[14] << 1) | (self.0[15] >> 7),
+                    (self.0[15] << 1),
+                ]);
+            }
+        }
         Block::from(self.to_int() << rhs)
     }
 }
@@ -202,9 +266,9 @@ impl IndexMut<usize> for Block {
     }
 }
 
-impl Mul<u32> for Block {
+impl Mul<usize> for Block {
     type Output = Block;
-    fn mul(self, rhs: u32) -> Block {
+    fn mul(self, rhs: usize) -> Block {
         match rhs {
             0 => Block::null(),
             1 => self,
@@ -216,7 +280,7 @@
             result
         }
         _ if rhs % 2 == 0 => self * 2 * (rhs / 2),
-        _ => self * (rhs - 1) ^ self,
+        _ => (self * (rhs - 1)) ^ self,
     }
 }
 }
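The unrolled shift-by-one treats the 16 bytes as one big-endian 128-bit value: each byte takes the top bit of its right-hand neighbour as carry. A loop-based sketch of the same carry propagation (illustrative only, not the patch's code):

    fn shl1(b: [u8; 16]) -> [u8; 16] {
        let mut out = [0u8; 16];
        for i in 0..16 {
            // the top bit of byte i+1 carries into byte i
            let carry = if i + 1 < 16 { b[i + 1] >> 7 } else { 0 };
            out[i] = (b[i] << 1) | carry;
        }
        out
    }

    fn main() {
        // agrees with the u128 shift that the special case replaces
        let x = u128::from_be_bytes([0x80; 16]);
        assert_eq!(shl1([0x80; 16]), (x << 1).to_be_bytes());
    }
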
diff --git a/src/lib.rs b/src/lib.rs
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,4 +1,5 @@
 #![cfg_attr(feature = "simd", feature(portable_simd))]
+#![cfg_attr(not(feature = "std"), no_std)]
 //! AEZ *\[sic!\]* v5 encryption implemented in Rust.
 //!
 //! # ☣️ Cryptographic hazmat ☣️
@@ -22,27 +23,27 @@
 //! # AEZ encryption (for laypeople)
 //!
 //! The security property of encryption schemes says that an adversary without key must not learn
-//! the content of a message, but the adversary might still be able to modify the message. For
-//! example, in AES-CTR, flipping a bit in the ciphertext means that the same bit will be flipped
-//! in the plaintext once the message is decrypted.
+//! the content of a message. However, the adversary might still be able to modify the message. For
+//! example, in AES-CTR (or other stream ciphers), flipping a bit in the ciphertext means that the
+//! same bit will be flipped in the plaintext once the message is decrypted. This allows for
+//! "planned" changes.
 //!
 //! Authenticated encryption solves this problem by including a mechanism to detect changes. This
 //! can be done for example by including a MAC, or using a mode like GCM (Galois counter mode). In
-//! many cases, not only the integrity of the ciphertext can be verified, but additional data can
-//! be provided during encryption and decryption which will also be included in the integrity
-//! check. This results in an *authenticated encryption with associated data* scheme, AEAD for
+//! many cases, not only the integrity of the ciphertext can be verified, but the user can provide
+//! additional data during encryption and decryption which will also have its integrity be
+//! verified. This is called an *authenticated encryption with associated data* scheme, AEAD for
 //! short.
 //!
 //! AEZ employs a nifty technique in order to realize an AEAD scheme: The core of AEZ is an
-//! enciphering scheme, which in addition to "hiding" its input is also very "unpredictable",
-//! similar to a hash function. That means that if a ciphertext is changed slightly (by flipping a
-//! bit), the resulting plaintext will be unpredictably and completely different.
+//! enciphering scheme, which in addition to "hiding" its input is also very "unpredictable" when
+//! bits are flipped. Similar to a hash function, if the ciphertext is changed slightly (by
+//! flipping a bit), the resulting plaintext will be unpredictably and completely different.
 //!
 //! With this property, authenticated encryption can be realized implicitly: The message is padded
 //! with a known string before enciphering it. If, after deciphering, this known string is not
-//! present, the message has been tampered with. Since the enciphering scheme is parametrized by
-//! the key, a nonce and arbitrary additional data, we can verify the integrity of associated data
-//! as well.
+//! present, the message has been tampered with. Since the enciphering is parametrized by the key,
+//! a nonce and arbitrary additional data, we can verify the integrity of associated data as well.
 //!
 //! # Other implementations
@@ -67,6 +68,8 @@
 //!
 //! ```
 //! # use zears::*;
+//! # #[cfg(feature = "std")]
+//! # fn main() {
 //! let aez = Aez::new(b"my secret key!");
 //! let cipher = aez.encrypt(b"nonce", &[b"associated data"], 16, b"message");
 //! let plaintext = aez.decrypt(b"nonce", &[b"associated data"], 16, &cipher);
@@ -82,6 +85,9 @@
 //! let cipher = aez.encrypt(b"nonce", &[b"foo"], 16, b"message");
 //! let plaintext = aez.decrypt(b"nonce", &[b"bar"], 16, &cipher);
 //! assert!(plaintext.is_none());
+//! # }
+//! # #[cfg(not(feature = "std"))]
+//! # fn main() {}
 //! ```
 //!
 //! # Feature flags & compilation hints
@@ -89,7 +95,7 @@
 //! * Enable feature `simd` (requires nightly due to the `portable_simd` Rust feature) to speed up
 //!   encryption and decryption by using SIMD instructions (if available).
 //! * Use `target-cpu=native` (e.g. by setting `RUSTFLAGS=-Ctarget-cpu=native`) to make the
-//!   compiler emit vectorized AES instructions (if available). This can speed up
+//!   compiler emit hardware AES instructions (if available). This can speed up
 //!   encryption/decryption at the cost of producing less portable code.
 //!
 //! On my machine, this produces the following results (for the `encrypt_inplace/2048` benchmark):
@@ -102,6 +108,11 @@
 //! | +simd, target-cpu=native | 3.3272 GiB/s | +592.01% |
 //! | `aez` crate | 4.8996 GiB/s | |
 
+#[cfg(not(feature = "std"))]
+use core::iter;
+#[cfg(feature = "std")]
+use std::iter;
+
 use constant_time_eq::constant_time_eq;
 
 mod accessor;
@@ -115,7 +126,6 @@
 use accessor::BlockAccessor;
 use aesround::AesRound;
 use block::Block;
 type Key = [u8; 48];
-type Tweak<'a> = &'a [&'a [u8]];
 
 static ZEROES: [u8; 1024] = [0; 1024];
@@ -125,7 +135,23 @@ enum Mode {
     Decipher,
 }
 
+enum SliceOrAsRef<'a, A> {
+    Slice(&'a [u8]),
+    AsRef(A),
+}
+
+impl<'a, A: AsRef<[u8]>> AsRef<[u8]> for SliceOrAsRef<'a, A> {
+    fn as_ref(&self) -> &[u8] {
+        match self {
+            SliceOrAsRef::Slice(x) => *x,
+            SliceOrAsRef::AsRef(x) => x.as_ref(),
+        }
+    }
+}
+
 /// AEZ encryption scheme.
+///
+/// See the [module level documentation](index.html) for more information.
 pub struct Aez {
     key_i: Block,
     key_j: Block,
@@ -145,6 +171,7 @@
         let key = extract(key);
         let (key_i, key_j, key_l) = split_key(&key);
         let aes = aesround::AesImpl::new(key_i, key_j, key_l);
+        #[allow(clippy::erasing_op)]
        let key_l_multiples = [
            key_l * 0,
            key_l * 1,
@@ -172,7 +199,7 @@
     /// Parameters:
     ///
     /// * `nonce` -- the nonce to use. Each nonce should only be used once, as re-using the nonce
-    ///   (without chaning the key) will lead to the same ciphertext being produced, potentially
+    ///   (without changing the key) will lead to the same ciphertext being produced, potentially
     ///   making it re-identifiable.
     /// * `associated_data` -- additional data to be included in the integrity check. Note that
     ///   this data will *not* be contained in the ciphertext, but it must be provided on
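The implicit authentication sketched in the rewritten docs is observable from the public API: flipping a single ciphertext bit scrambles the deciphered padding, so decryption reports tampering. A sketch mirroring the doctest above (std feature assumed):

    use zears::Aez;

    fn main() {
        let aez = Aez::new(b"my secret key!");
        let mut cipher = aez.encrypt(b"nonce", &[b"associated data"], 16, b"message");
        cipher[0] ^= 1; // flip one bit
        assert!(aez.decrypt(b"nonce", &[b"associated data"], 16, &cipher).is_none());
    }
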
@@ -181,9 +208,10 @@
     ///   16` gives 128 bits of security. Passing a value of 0 is valid and leads to no integrity
     ///   checking.
     /// * `data` -- actual data to encrypt. Can be empty, in which case the returned ciphertext
-    ///   provides a "hash" that verifies the integrity of the associated data will be
+    ///   provides a "hash" that verifies the integrity of the associated data.
     ///
     /// Returns the ciphertext, which will be of length `data.len() + tau`.
+    #[cfg(feature = "std")]
     pub fn encrypt(
         &self,
         nonce: &[u8],
@@ -204,6 +232,7 @@
     /// If `tau == 0`, the vector will not be expanded.
     ///
     /// The parameters are the same as for [`Aez::encrypt`].
+    #[cfg(feature = "std")]
     pub fn encrypt_vec(
         &self,
         nonce: &[u8],
@@ -212,7 +241,7 @@
         data: &mut Vec<u8>,
     ) {
         data.resize(data.len() + tau as usize, 0);
-        encrypt(&self, nonce, associated_data, tau, data);
+        encrypt(self, nonce, associated_data, tau, data);
     }
 
     /// Encrypts the data inplace.
@@ -231,7 +260,7 @@
         assert!(buffer.len() >= tau as usize);
         let data_len = buffer.len() - tau as usize;
         append_auth(data_len, buffer);
-        encrypt(&self, nonce, associated_data, tau as u32, buffer);
+        encrypt(self, nonce, associated_data, tau, buffer);
     }
 
     /// Encrypts the data in the given buffer, writing the output to the given output buffer.
@@ -251,7 +280,7 @@
         let tau = output.len() - input.len();
         output[..input.len()].copy_from_slice(input);
         append_auth(input.len(), output);
-        encrypt(&self, nonce, associated_data, tau as u32, output);
+        encrypt(self, nonce, associated_data, tau as u32, output);
     }
 
     /// Decrypts the given ciphertext.
@@ -266,6 +295,7 @@
     ///
     /// Returns the decrypted content. If the integrity check fails, returns `None` instead. The
     /// returned vector has length `data.len() - tau`.
+    #[cfg(feature = "std")]
     pub fn decrypt(
         &self,
         nonce: &[u8],
@@ -274,7 +304,7 @@
         data: &[u8],
     ) -> Option<Vec<u8>> {
         let mut buffer = Vec::from(data);
-        let len = match decrypt(&self, nonce, associated_data, tau, &mut buffer) {
+        let len = match decrypt(self, nonce, associated_data, tau, &mut buffer) {
             None => return None,
             Some(m) => m.len(),
         };
@@ -294,7 +324,7 @@
         tau: u32,
         data: &'a mut [u8],
     ) -> Option<&'a [u8]> {
-        decrypt(&self, nonce, associated_data, tau, data)
+        decrypt(self, nonce, associated_data, tau, data)
     }
 }
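Of the methods gated behind `std`, `encrypt_vec` extends the buffer by `tau` bytes and encrypts in place, saving one allocation over `encrypt`. A usage sketch (assuming the associated-data parameter keeps the `&[&[u8]]` shape of `encrypt`, which this hunk does not show):

    use zears::Aez;

    fn main() {
        let aez = Aez::new(b"my secret key!");
        let mut data = b"message".to_vec();
        aez.encrypt_vec(b"nonce", &[b"associated data"], 16, &mut data);
        assert_eq!(data.len(), b"message".len() + 16);
    }
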
@@ -319,37 +349,33 @@ fn append_auth(data_len: usize, buffer: &mut [u8]) {
     }
 }
 
-fn encrypt(aez: &Aez, nonce: &[u8], ad: &[&[u8]], tau: u32, buffer: &mut [u8]) {
+fn encrypt<A: AsRef<[u8]>, T: IntoIterator<Item = A>>(
+    aez: &Aez,
+    nonce: &[u8],
+    ad: T,
+    tau: u32,
+    buffer: &mut [u8],
+) {
     // We treat tau as bytes, but according to the spec, tau is actually in bits.
     let tau_block = Block::from_int(tau as u128 * 8);
     let tau_bytes = tau_block.bytes();
-    let mut tweaks_vec;
-    // We optimize for the common case of having no associated data, or having one item of
-    // associated data (which is all the reference implementation supports anyway). If there's more
-    // associated data, we cave in and allocate a vec.
-    let tweaks = match ad.len() {
-        0 => &[&tau_bytes, nonce] as &[&[u8]],
-        1 => &[&tau_bytes, nonce, ad[0]],
-        _ => {
-            tweaks_vec = vec![&tau_bytes, nonce];
-            tweaks_vec.extend(ad);
-            &tweaks_vec
-        }
-    };
+    let tweaks = iter::once(SliceOrAsRef::Slice(&tau_bytes))
+        .chain(iter::once(SliceOrAsRef::Slice(nonce)))
+        .chain(ad.into_iter().map(SliceOrAsRef::AsRef));
     assert!(buffer.len() >= tau as usize);
     if buffer.len() == tau as usize {
         // As aez_prf only xor's the input in, we have to clear the buffer first
         buffer.fill(0);
-        aez_prf(aez, &tweaks, buffer);
+        aez_prf(aez, tweaks, buffer);
     } else {
-        encipher(aez, &tweaks, buffer);
+        encipher(aez, tweaks, buffer);
     }
 }
 
-fn decrypt<'a>(
+fn decrypt<'a, 't, A: AsRef<[u8]> + 't, T: IntoIterator<Item = &'t A>>(
     aez: &Aez,
     nonce: &[u8],
-    ad: &[&[u8]],
+    ad: T,
     tau: u32,
     ciphertext: &'a mut [u8],
 ) -> Option<&'a [u8]> {
@@ -359,37 +385,38 @@
     let tau_block = Block::from_int(tau * 8);
     let tau_bytes = tau_block.bytes();
 
-    let mut tweaks = vec![&tau_bytes, nonce];
-    tweaks.extend(ad);
+    let tweaks = iter::once(&tau_bytes as &[_])
+        .chain(iter::once(nonce))
+        .chain(ad.into_iter().map(|x| x.as_ref()));
 
     if ciphertext.len() == tau as usize {
-        aez_prf(aez, &tweaks, ciphertext);
-        if is_zeroes(&ciphertext) {
+        aez_prf(aez, tweaks, ciphertext);
+        if is_zeroes(ciphertext) {
             return Some(&[]);
         } else {
             return None;
         }
     }
 
-    decipher(aez, &tweaks, ciphertext);
+    decipher(aez, tweaks, ciphertext);
 
     let (m, auth) = ciphertext.split_at(ciphertext.len() - tau as usize);
     assert!(auth.len() == tau as usize);
-    if is_zeroes(&auth) { Some(m) } else { None }
+    if is_zeroes(auth) { Some(m) } else { None }
 }
 
 fn is_zeroes(data: &[u8]) -> bool {
-    let comparator = if data.len() <= ZEROES.len() {
-        &ZEROES[..data.len()]
-    } else {
-        // We should find a way to do this without allocating a separate buffer full of zeroes, but
-        // I don't want to hand-roll my constant-time-is-zeroes yet.
-        &vec![0; data.len()]
-    };
-    constant_time_eq(data, comparator)
+    const PATTERN: u32 = 0xDEADBABE;
+    let mut accum = 0u32;
+    for chunk in data.chunks(ZEROES.len()) {
+        // this is accum = accum | !(chunk == 0...0)
+        // basically, accum will say if one chunk was not zeroes
+        accum |= (1 - (constant_time_eq(chunk, &ZEROES[..chunk.len()]) as u32)) * PATTERN;
+    }
+    accum == 0
 }
 
-fn encipher(aez: &Aez, tweaks: Tweak, message: &mut [u8]) {
+fn encipher<A: AsRef<[u8]>, T: IntoIterator<Item = A>>(aez: &Aez, tweaks: T, message: &mut [u8]) {
     if message.len() < 256 / 8 {
         cipher_aez_tiny(Mode::Encipher, aez, tweaks, message)
     } else {
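The new `is_zeroes` visits every chunk unconditionally and ORs the comparison results into `accum`, so the running time no longer depends on where the first nonzero byte sits, and the old `vec!` fallback allocation is gone. The accumulation idea, reduced to single bytes (a sketch of the principle, not the audited constant-time code):

    fn all_zero(data: &[u8]) -> bool {
        let mut accum = 0u8;
        for &byte in data {
            // no early exit: every byte is folded in
            accum |= byte;
        }
        accum == 0
    }
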
@@ -397,7 +424,7 @@
     }
 }
 
-fn decipher(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) {
+fn decipher<A: AsRef<[u8]>, T: IntoIterator<Item = A>>(aez: &Aez, tweaks: T, buffer: &mut [u8]) {
     if buffer.len() < 256 / 8 {
         cipher_aez_tiny(Mode::Decipher, aez, tweaks, buffer);
     } else {
@@ -405,7 +432,12 @@
     }
 }
 
-fn cipher_aez_tiny(mode: Mode, aez: &Aez, tweaks: Tweak, message: &mut [u8]) {
+fn cipher_aez_tiny<A: AsRef<[u8]>, T: IntoIterator<Item = A>>(
+    mode: Mode,
+    aez: &Aez,
+    tweaks: T,
+    message: &mut [u8],
+) {
     let mu = message.len() * 8;
     assert!(mu < 256);
     let n = mu / 2;
@@ -464,25 +496,61 @@
     }
 
     if mode == Mode::Encipher && mu < 128 {
-        let mut c = Block::from_slice(&message);
+        let mut c = Block::from_slice(message);
         c = c ^ (e(0, 3, aez, delta ^ (c | Block::one())) & Block::one());
         message.copy_from_slice(&c.bytes()[..mu / 8]);
     }
 }
 
-fn cipher_aez_core(mode: Mode, aez: &Aez, tweaks: Tweak, message: &mut [u8]) {
-    assert!(message.len() >= 32);
-    let delta = aez_hash(aez, tweaks);
-    let mut blocks = BlockAccessor::new(message);
-    let (m_u, m_v, m_x, m_y, d) = (
-        blocks.m_u(),
-        blocks.m_v(),
-        blocks.m_x(),
-        blocks.m_y(),
-        blocks.m_uv_len(),
-    );
-    let len_v = d.saturating_sub(128);
+macro_rules! unroll_pairs {
+    (
+        $accessor:expr;
+        setup_unrolled => $setup_unrolled:block;
+        setup_single => $setup_single:block;
+        roll ($left:ident, $right:ident) => $roll:block;
+    ) => {
+        for (left, right) in $accessor.pairs_8_mut() {
+            $setup_unrolled;
+
+            let [l0, l1, l2, l3, l4, l5, l6, l7] = left;
+            let [r0, r1, r2, r3, r4, r5, r6, r7] = right;
+
+            let $left = l0;
+            let $right = r0;
+            $roll;
+            let $left = l1;
+            let $right = r1;
+            $roll;
+            let $left = l2;
+            let $right = r2;
+            $roll;
+            let $left = l3;
+            let $right = r3;
+            $roll;
+            let $left = l4;
+            let $right = r4;
+            $roll;
+            let $left = l5;
+            let $right = r5;
+            $roll;
+            let $left = l6;
+            let $right = r6;
+            $roll;
+            let $left = l7;
+            let $right = r7;
+            $roll;
+        }
+
+        for (left, right) in $accessor.suffix_8_mut() {
+            let $left = left;
+            let $right = right;
+            $setup_single;
+            $roll;
+        }
+    };
+}
+
+fn pass_one(aez: &Aez, blocks: &mut BlockAccessor) -> Block {
     let mut x = Block::null();
     let mut e1_eval = E::new(1, 0, aez);
     let e0_eval = E::new(0, 0, aez);
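`unroll_pairs!` stamps out eight copies of the loop body for each full chunk of eight block pairs and falls back to a plain loop for the tail, the classic unroll-plus-remainder shape that `pairs_8_mut`/`suffix_8_mut` feed. The same pattern on a trivial computation (sketch):

    fn sum_unrolled(data: &[u32]) -> u32 {
        let mut acc = 0u32;
        let mut chunks = data.chunks_exact(8);
        for c in &mut chunks {
            // fixed-size body: the trip count is known, so the compiler can flatten it
            acc = acc
                .wrapping_add(c[0])
                .wrapping_add(c[1])
                .wrapping_add(c[2])
                .wrapping_add(c[3])
                .wrapping_add(c[4])
                .wrapping_add(c[5])
                .wrapping_add(c[6])
                .wrapping_add(c[7]);
        }
        chunks.remainder().iter().fold(acc, |a, &x| a.wrapping_add(x))
    }
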
@@ -500,14 +568,73 @@
         x = x ^ xi;
     }
 
+    x
+}
+
+fn pass_two(aez: &Aez, blocks: &mut BlockAccessor, s: Block) -> Block {
+    let mut y = Block::null();
+    let e2_eval = E::new(2, 0, aez);
+    let mut e1_eval = E::new(1, 0, aez);
+    let e0_eval = E::new(0, 0, aez);
+    let mut evals_for_s = e2_eval.evals_for(s);
+
+    unroll_pairs! { blocks;
+        setup_unrolled => {
+            evals_for_s.refill();
+        };
+        setup_single => {
+            if evals_for_s.len == 0 {
+                evals_for_s.refill();
+            }
+        };
+        roll (raw_wi, raw_xi) => {
+            e1_eval.advance();
+            let wi = Block::from(*raw_wi);
+            let xi = Block::from(*raw_xi);
+            let yi = wi ^ evals_for_s.blocks[8 - evals_for_s.len];
+            let zi = xi ^ evals_for_s.blocks[8 - evals_for_s.len];
+            let ci_ = yi ^ e0_eval.eval(zi);
+            let ci = zi ^ e1_eval.eval(ci_);
+
+            ci.write_to(raw_wi);
+            ci_.write_to(raw_xi);
+
+            y = y ^ yi;
+            evals_for_s.len -= 1;
+        };
+    }
+
+    y
+}
+
+fn cipher_aez_core<A: AsRef<[u8]>, T: IntoIterator<Item = A>>(
+    mode: Mode,
+    aez: &Aez,
+    tweaks: T,
+    message: &mut [u8],
+) {
+    assert!(message.len() >= 32);
+    let delta = aez_hash(aez, tweaks);
+    let mut blocks = BlockAccessor::new(message);
+    let (m_u, m_v, m_x, m_y, d) = (
+        blocks.m_u(),
+        blocks.m_v(),
+        blocks.m_x(),
+        blocks.m_y(),
+        blocks.m_uv_len(),
+    );
+    let len_v = d.saturating_sub(128);
+
+    let mut x = pass_one(aez, &mut blocks);
+
     match d {
         0 => (),
         _ if d <= 127 => {
-            x = x ^ e(0, 4, aez, m_u.pad(d.into()));
+            x = x ^ e(0, 4, aez, m_u.pad(d));
         }
         _ => {
             x = x ^ e(0, 4, aez, m_u);
-            x = x ^ e(0, 5, aez, m_v.pad(len_v.into()));
+            x = x ^ e(0, 5, aez, m_v.pad(len_v));
         }
     }
@@ -515,34 +642,16 @@
     let (s_x, s_y);
     match mode {
         Mode::Encipher => {
             s_x = m_x ^ delta ^ x ^ e(0, 1, aez, m_y);
-            s_y = m_y ^ e(-1, 1, aez, s_x);
+            s_y = m_y ^ e_neg(1, aez, s_x);
         }
         Mode::Decipher => {
             s_x = m_x ^ delta ^ x ^ e(0, 2, aez, m_y);
-            s_y = m_y ^ e(-1, 2, aez, s_x);
+            s_y = m_y ^ e_neg(2, aez, s_x);
         }
     }
     let s = s_x ^ s_y;
 
-    let mut y = Block::null();
-    let mut e2_eval = E::new(2, 0, aez);
-    let mut e1_eval = E::new(1, 0, aez);
-    for (raw_wi, raw_xi) in blocks.pairs_mut() {
-        e2_eval.advance();
-        e1_eval.advance();
-        let wi = Block::from(*raw_wi);
-        let xi = Block::from(*raw_xi);
-        let s_ = e2_eval.eval(s);
-        let yi = wi ^ s_;
-        let zi = xi ^ s_;
-        let ci_ = yi ^ e0_eval.eval(zi);
-        let ci = zi ^ e1_eval.eval(ci_);
-
-        ci.write_to(raw_wi);
-        ci_.write_to(raw_xi);
-
-        y = y ^ yi;
-    }
+    let mut y = pass_two(aez, &mut blocks, s);
 
     let mut c_u = Block::default();
     let mut c_v = Block::default();
@@ -550,25 +659,25 @@
     match d {
         0 => (),
         _ if d <= 127 => {
-            c_u = (m_u ^ e(-1, 4, aez, s)).clip(d.into());
-            y = y ^ e(0, 4, aez, c_u.pad(d.into()));
+            c_u = (m_u ^ e_neg(4, aez, s)).clip(d);
+            y = y ^ e(0, 4, aez, c_u.pad(d));
         }
         _ => {
-            c_u = m_u ^ e(-1, 4, aez, s);
-            c_v = (m_v ^ e(-1, 5, aez, s)).clip(len_v.into());
+            c_u = m_u ^ e_neg(4, aez, s);
+            c_v = (m_v ^ e_neg(5, aez, s)).clip(len_v);
             y = y ^ e(0, 4, aez, c_u);
-            y = y ^ e(0, 5, aez, c_v.pad(len_v.into()));
+            y = y ^ e(0, 5, aez, c_v.pad(len_v));
         }
     }
 
     let (c_x, c_y);
     match mode {
         Mode::Encipher => {
-            c_y = s_x ^ e(-1, 2, aez, s_y);
+            c_y = s_x ^ e_neg(2, aez, s_y);
             c_x = s_y ^ delta ^ y ^ e(0, 2, aez, c_y);
         }
         Mode::Decipher => {
-            c_y = s_x ^ e(-1, 1, aez, s_y);
+            c_y = s_x ^ e_neg(1, aez, s_y);
             c_x = s_y ^ delta ^ y ^ e(0, 1, aez, c_y);
         }
     }
@@ -579,21 +688,10 @@
     blocks.set_m_y(c_y);
 }
 
-fn pad_to_blocks(value: &[u8]) -> Vec<Block> {
-    let mut blocks = Vec::new();
-    for chunk in value.chunks(16) {
-        if chunk.len() == 16 {
-            blocks.push(Block::from_slice(chunk));
-        } else {
-            blocks.push(Block::from_slice(chunk).pad(chunk.len() * 8));
-        }
-    }
-    blocks
-}
-
-fn aez_hash(aez: &Aez, tweaks: Tweak) -> Block {
+fn aez_hash<A: AsRef<[u8]>, T: IntoIterator<Item = A>>(aez: &Aez, tweaks: T) -> Block {
     let mut hash = Block::null();
-    for (i, tweak) in tweaks.iter().enumerate() {
+    for (i, tweak) in tweaks.into_iter().enumerate() {
+        let tweak = tweak.as_ref();
         // Adjust for zero-based vs one-based indexing
         let j = i + 2 + 1;
         let mut ej = E::new(j.try_into().unwrap(), 0, aez);
@@ -608,33 +706,35 @@
                 hash = hash ^ ej.eval(Block::from_slice(chunk));
             }
         } else {
-            let blocks = pad_to_blocks(tweak);
-            for (l, chunk) in blocks.iter().enumerate() {
+            let blocks = tweak.chunks_exact(16);
+            let remainder = blocks.remainder();
+
+            for chunk in blocks {
                 ej.advance();
-                if l == blocks.len() - 1 {
-                    hash = hash ^ e(j.try_into().unwrap(), 0, aez, *chunk);
-                } else {
-                    hash = hash ^ ej.eval(*chunk);
-                }
+                hash = hash ^ ej.eval(Block::from_slice(chunk));
             }
+
+            ej.advance();
+            let chunk = Block::from_slice(remainder).pad_bytes(remainder.len() as u8);
+            hash = hash ^ e(j.try_into().unwrap(), 0, aez, chunk);
         }
     }
 
     hash
 }
 
 /// XOR's the result of aez_prf into the given buffer
-fn aez_prf(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) {
-    let mut index = 0u128;
+fn aez_prf<A: AsRef<[u8]>, T: IntoIterator<Item = A>>(aez: &Aez, tweaks: T, buffer: &mut [u8]) {
+    let mut index = Block::null();
     let delta = aez_hash(aez, tweaks);
     for chunk in buffer.chunks_exact_mut(16) {
         let chunk: &mut [u8; 16] = chunk.try_into().unwrap();
-        let block = e(-1, 3, aez, delta ^ Block::from_int(index));
+        let block = e_neg(3, aez, delta ^ index);
         (block ^ Block::from(*chunk)).write_to(chunk);
-        index += 1;
+        index.count_up();
     }
     let suffix_start = buffer.len() - buffer.len() % 16;
     let chunk = &mut buffer[suffix_start..];
-    let block = e(-1, 3, aez, delta ^ Block::from_int(index));
+    let block = e_neg(3, aez, delta ^ index);
     for (a, b) in chunk.iter_mut().zip(block.bytes().iter()) {
         *a ^= *b;
     }
 }
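`aez_hash` now splits each tweak with `chunks_exact(16)` and pads the remainder through the new `pad_bytes`, which places the 10* marker at a byte boundary: for byte-aligned input the entire padding is the single byte 0x80. A standalone sketch of that padding rule (mirroring `pad_bytes` from the block.rs hunk):

    fn pad_bytes(mut block: [u8; 16], len: usize) -> [u8; 16] {
        assert!(len <= 15);
        block[len] = 0x80; // 1000_0000: the "1" bit followed by zeroes
        block
    }

    fn main() {
        assert_eq!(pad_bytes([0u8; 16], 5)[5], 0x80);
    }
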
@@ -646,16 +746,14 @@
 /// temporary values and makes it much faster to compute E_K^{j, i+1}, E_K^{j, i+2}, ...
 struct E<'a> {
     aez: &'a Aez,
-    i: u32,
+    i: usize,
     kj_t_j: Block,
     ki_p_i: Block,
 }
 
 impl<'a> E<'a> {
     /// Create a new "suspended" computation of E_K^{j,i}.
-    fn new(j: i32, i: u32, aez: &'a Aez) -> Self {
-        assert!(j >= 0);
-        let j: u32 = j.try_into().expect("j was negative");
+    fn new(j: usize, i: usize, aez: &'a Aez) -> Self {
         let exponent = if i % 8 == 0 { i / 8 } else { i / 8 + 1 };
         E {
             aez,
@@ -667,10 +765,14 @@
 
     /// Complete this computation to evaluate E_K^{j,i}(block).
     fn eval(&self, block: Block) -> Block {
-        let delta = self.kj_t_j ^ self.ki_p_i ^ self.aez.key_l_multiples[self.i as usize % 8];
+        let delta = self.kj_t_j ^ self.ki_p_i ^ self.aez.key_l_multiples[self.i % 8];
         self.aez.aes.aes4(block ^ delta)
     }
 
+    fn evals_for(self, block: Block) -> Eiter<'a> {
+        Eiter::new(self, block)
+    }
+
     /// Advance this computation by going from i to i+1.
     ///
     /// Afterwards, this computation will represent E_K^{j, i+1}
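`evals_for` hands the suspended computation to an iterator that, as the next hunk shows, batches eight `aes4` evaluations per refill and serves them through a `blocks[8 - len]` cursor. The hand-out mechanics in isolation (sketch with a counter standing in for the AES step):

    struct Batched {
        buf: [u32; 8],
        len: usize,
        seed: u32,
    }

    impl Batched {
        fn refill(&mut self) {
            for (i, slot) in self.buf.iter_mut().enumerate() {
                *slot = self.seed.wrapping_add(i as u32); // stand-in for aes4(...)
            }
            self.seed = self.seed.wrapping_add(8);
            self.len = 8;
        }

        fn next_value(&mut self) -> u32 {
            if self.len == 0 {
                self.refill();
            }
            let value = self.buf[8 - self.len];
            self.len -= 1;
            value
        }
    }

    fn main() {
        let mut b = Batched { buf: [0; 8], len: 0, seed: 0 };
        let first_ten: Vec<u32> = (0..10).map(|_| b.next_value()).collect();
        assert_eq!(first_ten, (0..10).collect::<Vec<u32>>());
    }
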
+#[cfg(feature = "primitives")] +pub mod primitives { + use super::{Aez, Block}; + + pub fn encipher<A: AsRef<[u8]>, T: IntoIterator<Item = A>>( + aez: &Aez, + tweaks: T, + message: &mut [u8], + ) { + super::encipher(aez, tweaks, message) + } + + pub fn decipher<A: AsRef<[u8]>, T: IntoIterator<Item = A>>( + aez: &Aez, + tweaks: T, + message: &mut [u8], + ) { + super::decipher(aez, tweaks, message) + } + + pub fn aez_hash<A: AsRef<[u8]>, T: IntoIterator<Item = A>>(aez: &Aez, tweaks: T) -> [u8; 16] { + super::aez_hash(aez, tweaks).bytes() + } + + pub fn aez_prf<A: AsRef<[u8]>, T: IntoIterator<Item = A>>( + aez: &Aez, + tweaks: T, + buffer: &mut [u8], + ) { + super::aez_prf(aez, tweaks, buffer) + } + + pub fn e(j: usize, i: usize, aez: &Aez, block: [u8; 16]) -> [u8; 16] { + super::e(j, i, aez, Block::from(block)).bytes() + } +} + +#[cfg(all(test, feature = "std"))] mod test { use super::*; @@ -731,7 +948,19 @@ mod test { let a = hex::decode(a).unwrap(); let a = Block::from_slice(&a); let b = hex::decode(b).unwrap(); - assert_eq!(&e(*j, *i, &aez, a).bytes(), b.as_slice(), "{name}"); + if *j >= 0 { + assert_eq!( + &e((*j).try_into().unwrap(), (*i).try_into().unwrap(), &aez, a).bytes(), + b.as_slice(), + "{name}" + ); + } else { + assert_eq!( + &e_neg((*i).try_into().unwrap(), &aez, a).bytes(), + b.as_slice(), + "{name}" + ); + } } } |
