From 5bd298ed568aca12a54f014a7b13f943379a5eb9 Mon Sep 17 00:00:00 2001
From: Daniel Schadt
Date: Fri, 11 Apr 2025 12:48:18 +0200
Subject: use simd instructions (requires nightly compiler)

---
 src/accessor.rs |  9 ++++----
 src/aesround.rs | 49 +++++++++++++++++----------------------
 src/block.rs    | 65 ++++++++++++++++++++++++++++------------------------
 src/lib.rs      | 71 ++++++++++++++++++++++++++++---------------------------
 4 files changed, 99 insertions(+), 95 deletions(-)

diff --git a/src/accessor.rs b/src/accessor.rs
index 89f5251..24905af 100644
--- a/src/accessor.rs
+++ b/src/accessor.rs
@@ -36,7 +36,8 @@ impl<'a> BlockAccessor<'a> {
 
     pub fn set_m_u(&mut self, m_u: Block) {
         let start = self.suffix_start();
-        self.data[start..start + self.m_u_len / 8].copy_from_slice(&m_u.0[..self.m_u_len / 8]);
+        self.data[start..start + self.m_u_len / 8]
+            .copy_from_slice(&m_u.bytes()[..self.m_u_len / 8]);
     }
 
     pub fn m_v(&self) -> Block {
@@ -47,7 +48,7 @@ impl<'a> BlockAccessor<'a> {
     pub fn set_m_v(&mut self, m_v: Block) {
         let start = self.suffix_start();
         self.data[start + self.m_u_len / 8..start + self.m_uv_len / 8]
-            .copy_from_slice(&m_v.0[..self.m_v_len / 8]);
+            .copy_from_slice(&m_v.bytes()[..self.m_v_len / 8]);
     }
 
     pub fn m_x(&self) -> Block {
@@ -57,7 +58,7 @@ impl<'a> BlockAccessor<'a> {
 
     pub fn set_m_x(&mut self, m_x: Block) {
         let start = self.suffix_start() + self.m_uv_len / 8;
-        self.data[start..start + 16].copy_from_slice(&m_x.0);
+        self.data[start..start + 16].copy_from_slice(&m_x.bytes());
     }
 
     pub fn m_y(&self) -> Block {
@@ -67,7 +68,7 @@ impl<'a> BlockAccessor<'a> {
 
     pub fn set_m_y(&mut self, m_y: Block) {
         let start = self.suffix_start() + self.m_uv_len / 8;
-        self.data[start + 16..start + 32].copy_from_slice(&m_y.0);
+        self.data[start + 16..start + 32].copy_from_slice(&m_y.bytes());
     }
 
     pub fn pairs_mut<'b>(
diff --git a/src/aesround.rs b/src/aesround.rs
index 0a06192..d04ac9b 100644
--- a/src/aesround.rs
+++ b/src/aesround.rs
@@ -26,23 +26,23 @@ pub struct AesSoft {
 impl AesRound for AesSoft {
     fn new(key_i: Block, key_j: Block, key_l: Block) -> Self {
         Self {
-            key_i: key_i.0.into(),
-            key_j: key_j.0.into(),
-            key_l: key_l.0.into(),
+            key_i: key_i.bytes().into(),
+            key_j: key_j.bytes().into(),
+            key_l: key_l.bytes().into(),
         }
     }
 
     fn aes4(&self, value: Block) -> Block {
-        let mut block: aes::Block = value.0.into();
+        let mut block: aes::Block = value.bytes().into();
         ::aes::hazmat::cipher_round(&mut block, &self.key_j);
         ::aes::hazmat::cipher_round(&mut block, &self.key_i);
         ::aes::hazmat::cipher_round(&mut block, &self.key_l);
-        ::aes::hazmat::cipher_round(&mut block, &Block::NULL.0.into());
-        Block(block.into())
+        ::aes::hazmat::cipher_round(&mut block, &Block::null().bytes().into());
+        <Block as From<[u8; 16]>>::from(block.into())
     }
 
     fn aes10(&self, value: Block) -> Block {
-        let mut block: aes::Block = value.0.into();
+        let mut block: aes::Block = value.bytes().into();
         ::aes::hazmat::cipher_round(&mut block, &self.key_i);
         ::aes::hazmat::cipher_round(&mut block, &self.key_j);
         ::aes::hazmat::cipher_round(&mut block, &self.key_l);
@@ -53,7 +53,7 @@ impl AesRound for AesSoft {
         ::aes::hazmat::cipher_round(&mut block, &self.key_j);
         ::aes::hazmat::cipher_round(&mut block, &self.key_l);
         ::aes::hazmat::cipher_round(&mut block, &self.key_i);
-        Block(block.into())
+        <Block as From<[u8; 16]>>::from(block.into())
     }
 }
 
@@ -75,16 +75,13 @@ pub mod x86_64 {
 
     impl AesRound for AesNi {
         fn new(key_i: Block, key_j: Block, key_l: Block) -> Self {
-            // SAFETY: loadu can load from unaligned memory
-            unsafe {
-                Self {
-                    support: cpuid_aes::init(),
-                    fallback: AesSoft::new(key_i, key_j, key_l),
-                    key_i: _mm_loadu_si128(key_i.0.as_ptr() as *const _),
-                    key_j: _mm_loadu_si128(key_j.0.as_ptr() as *const _),
-                    key_l: _mm_loadu_si128(key_l.0.as_ptr() as *const _),
-                    null: _mm_loadu_si128(Block::NULL.0.as_ptr() as *const _),
-                }
+            Self {
+                support: cpuid_aes::init(),
+                fallback: AesSoft::new(key_i, key_j, key_l),
+                key_i: key_i.simd().into(),
+                key_j: key_j.simd().into(),
+                key_l: key_l.simd().into(),
+                null: Block::null().simd().into(),
             }
         }
 
         fn aes4(&self, value: Block) -> Block {
@@ -93,16 +90,14 @@
                 return self.fallback.aes4(value);
             }
 
-            // SAFETY: loadu can load from unaligned memory
+            // SAFETY: _mm_aesenc_si128 requires AES-NI, which we checked for above
             unsafe {
-                let mut block = _mm_loadu_si128(value.0.as_ptr() as *const _);
+                let mut block = value.simd().into();
                 block = _mm_aesenc_si128(block, self.key_j);
                 block = _mm_aesenc_si128(block, self.key_i);
                 block = _mm_aesenc_si128(block, self.key_l);
                 block = _mm_aesenc_si128(block, self.null);
-                let mut result = Block::default();
-                _mm_storeu_si128(result.0.as_mut_ptr() as *mut _, block);
-                result
+                Block::from_simd(block.into())
            }
        }
 
        fn aes10(&self, value: Block) -> Block {
@@ -111,9 +106,9 @@
             return self.fallback.aes10(value);
         }
 
-            // SAFETY: loadu can load from unaligned memory
+            // SAFETY: _mm_aesenc_si128 requires AES-NI, which we checked for above
             unsafe {
-                let mut block = _mm_loadu_si128(value.0.as_ptr() as *const _);
+                let mut block = value.simd().into();
                 block = _mm_aesenc_si128(block, self.key_i);
                 block = _mm_aesenc_si128(block, self.key_j);
                 block = _mm_aesenc_si128(block, self.key_l);
@@ -124,9 +119,7 @@
                 block = _mm_aesenc_si128(block, self.key_j);
                 block = _mm_aesenc_si128(block, self.key_l);
                 block = _mm_aesenc_si128(block, self.key_i);
-                let mut result = Block::default();
-                _mm_storeu_si128(result.0.as_mut_ptr() as *mut _, block);
-                result
+                Block::from_simd(block.into())
             }
         }
     }
diff --git a/src/block.rs b/src/block.rs
index c294aab..b485b17 100644
--- a/src/block.rs
+++ b/src/block.rs
@@ -1,12 +1,34 @@
 use std::ops::{BitAnd, BitOr, BitXor, Index, IndexMut, Mul, Shl, Shr};
+use std::simd::prelude::*;
 
 /// A block, the unit of work that AEZ divides the message into.
 #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
-pub struct Block(pub [u8; 16]);
+pub struct Block(u8x16);
 
 impl Block {
-    pub const NULL: Block = Block([0; 16]);
-    pub const ONE: Block = Block([0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
+    pub fn null() -> Block {
+        Block([0; 16].into())
+    }
+
+    pub fn one() -> Block {
+        Block([0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0].into())
+    }
+
+    pub fn bytes(&self) -> [u8; 16] {
+        self.0.into()
+    }
+
+    pub fn write_to(&self, output: &mut [u8; 16]) {
+        self.0.copy_to_slice(output);
+    }
+
+    pub(crate) fn simd(&self) -> u8x16 {
+        self.0
+    }
+
+    pub(crate) fn from_simd(value: u8x16) -> Self {
+        Block(value)
+    }
 
     /// Create a block from a slice.
     ///
@@ -16,18 +38,18 @@ impl Block {
         let len = value.len().min(16);
         let mut array = [0; 16];
         array[..len].copy_from_slice(&value[..len]);
-        Block(array)
+        Block(array.into())
     }
 
     /// Constructs a block representing the given integer.
     ///
     /// This corresponds to [x]_128 in the paper.
     pub fn from_int<I: Into<u128>>(value: I) -> Self {
-        Block(value.into().to_be_bytes())
+        Block(value.into().to_be_bytes().into())
     }
 
     pub fn to_int(&self) -> u128 {
-        u128::from_be_bytes(self.0)
+        u128::from_be_bytes(self.0.into())
     }
 
     /// Pad the block to full length.
@@ -62,43 +84,26 @@ impl Block {
 
 impl From<[u8; 16]> for Block {
     fn from(value: [u8; 16]) -> Block {
-        Block(value)
+        Block(value.into())
     }
 }
 
 impl From<&[u8; 16]> for Block {
     fn from(value: &[u8; 16]) -> Block {
-        Block(*value)
+        Block((*value).into())
     }
 }
 
 impl From<u128> for Block {
     fn from(value: u128) -> Block {
-        Block(value.to_be_bytes())
+        Block(value.to_be_bytes().into())
     }
 }
 
 impl BitXor for Block {
     type Output = Block;
     fn bitxor(self, rhs: Block) -> Block {
-        Block([
-            self.0[0] ^ rhs.0[0],
-            self.0[1] ^ rhs.0[1],
-            self.0[2] ^ rhs.0[2],
-            self.0[3] ^ rhs.0[3],
-            self.0[4] ^ rhs.0[4],
-            self.0[5] ^ rhs.0[5],
-            self.0[6] ^ rhs.0[6],
-            self.0[7] ^ rhs.0[7],
-            self.0[8] ^ rhs.0[8],
-            self.0[9] ^ rhs.0[9],
-            self.0[10] ^ rhs.0[10],
-            self.0[11] ^ rhs.0[11],
-            self.0[12] ^ rhs.0[12],
-            self.0[13] ^ rhs.0[13],
-            self.0[14] ^ rhs.0[14],
-            self.0[15] ^ rhs.0[15],
-        ])
+        Block(self.0 ^ rhs.0)
     }
 }
 
@@ -119,14 +124,14 @@ impl Shr<u32> for Block {
 impl BitAnd for Block {
     type Output = Block;
     fn bitand(self, rhs: Block) -> Block {
-        Block::from(self.to_int() & rhs.to_int())
+        Block(self.0 & rhs.0)
     }
 }
 
 impl BitOr for Block {
     type Output = Block;
     fn bitor(self, rhs: Block) -> Block {
-        Block::from(self.to_int() | rhs.to_int())
+        Block(self.0 | rhs.0)
     }
 }
 
@@ -147,7 +152,7 @@ impl Mul<u32> for Block {
     type Output = Block;
     fn mul(self, rhs: u32) -> Block {
         match rhs {
-            0 => Block::NULL,
+            0 => Block::null(),
             1 => self,
             2 => {
                 let mut result = self << 1;
diff --git a/src/lib.rs b/src/lib.rs
index 7f2e5c3..6e411a0 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,3 +1,4 @@
+#![feature(portable_simd)]
 //! AEZ *\[sic!\]* v5 encryption implemented in Rust.
 //!
 //! # ☣️ Cryptographic hazmat ☣️
@@ -297,7 +298,8 @@ fn append_auth(data_len: usize, buffer: &mut [u8]) {
 fn encrypt(aez: &Aez, nonce: &[u8], ad: &[&[u8]], tau: u32, buffer: &mut [u8]) {
     // We treat tau as bytes, but according to the spec, tau is actually in bits.
     let tau_block = Block::from_int(tau as u128 * 8);
-    let mut tweaks = vec![&tau_block.0, nonce];
+    let tau_bytes = tau_block.bytes();
+    let mut tweaks = vec![&tau_bytes, nonce];
     tweaks.extend(ad);
     assert!(buffer.len() >= tau as usize);
     if buffer.len() == tau as usize {
@@ -321,7 +323,8 @@ fn decrypt<'a>(
     }
 
     let tau_block = Block::from_int(tau * 8);
-    let mut tweaks = vec![&tau_block.0, nonce];
+    let tau_bytes = tau_block.bytes();
+    let mut tweaks = vec![&tau_bytes, nonce];
     tweaks.extend(ad);
 
     if ciphertext.len() == tau as usize {
@@ -387,12 +390,12 @@ fn encipher_aez_tiny(aez: &Aez, tweaks: Tweak, message: &mut [u8]) {
         (left, right) = (right, right_);
     }
     if n % 8 == 0 {
-        message[..n / 8].copy_from_slice(&right.0[..n / 8]);
-        message[n / 8..].copy_from_slice(&left.0[..n / 8]);
+        message[..n / 8].copy_from_slice(&right.bytes()[..n / 8]);
+        message[n / 8..].copy_from_slice(&left.bytes()[..n / 8]);
     } else {
         let mut index = n / 8;
-        message[..index + 1].copy_from_slice(&right.0[..index + 1]);
-        for byte in &left.0[..n / 8 + 1] {
+        message[..index + 1].copy_from_slice(&right.bytes()[..index + 1]);
+        for byte in &left.bytes()[..n / 8 + 1] {
             message[index] |= byte >> 4;
             if index < message.len() - 1 {
                 message[index + 1] = (byte & 0x0f) << 4;
@@ -402,8 +405,8 @@ fn encipher_aez_tiny(aez: &Aez, tweaks: Tweak, message: &mut [u8]) {
     }
     if mu < 128 {
         let mut c = Block::from_slice(&message);
-        c = c ^ (e(0, 3, aez, delta ^ (c | Block::ONE)) & Block::ONE);
-        message.copy_from_slice(&c.0[..mu / 8]);
+        c = c ^ (e(0, 3, aez, delta ^ (c | Block::one())) & Block::one());
+        message.copy_from_slice(&c.bytes()[..mu / 8]);
     }
 }
 
@@ -420,7 +423,7 @@ fn encipher_aez_core(aez: &Aez, tweaks: Tweak, message: &mut [u8]) {
     );
     let len_v = d.saturating_sub(128);
 
-    let mut x = Block::NULL;
+    let mut x = Block::null();
     let mut e1_eval = E::new(1, 0, aez);
     let e0_eval = E::new(0, 0, aez);
 
@@ -431,8 +434,8 @@ fn encipher_aez_core(aez: &Aez, tweaks: Tweak, message: &mut [u8]) {
         let wi = mi ^ e1_eval.eval(mi_);
         let xi = mi_ ^ e0_eval.eval(wi);
 
-        *raw_mi = wi.0;
-        *raw_mi_ = xi.0;
+        wi.write_to(raw_mi);
+        xi.write_to(raw_mi_);
         x = x ^ xi;
     }
 
@@ -452,7 +455,7 @@ fn encipher_aez_core(aez: &Aez, tweaks: Tweak, message: &mut [u8]) {
     let s_y = m_y ^ e(-1, 1, aez, s_x);
     let s = s_x ^ s_y;
 
-    let mut y = Block::NULL;
+    let mut y = Block::null();
     let mut e2_eval = E::new(2, 0, aez);
     let mut e1_eval = E::new(1, 0, aez);
     let e0_eval = E::new(0, 0, aez);
@@ -467,8 +470,8 @@ fn encipher_aez_core(aez: &Aez, tweaks: Tweak, message: &mut [u8]) {
         let ci_ = yi ^ e0_eval.eval(zi);
         let ci = zi ^ e1_eval.eval(ci_);
 
-        *raw_wi = ci.0;
-        *raw_xi = ci_.0;
+        ci.write_to(raw_wi);
+        ci_.write_to(raw_xi);
         y = y ^ yi;
     }
 
@@ -520,8 +523,8 @@ fn decipher_aez_tiny(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) {
 
     if mu < 128 {
         let mut c = Block::from_slice(buffer);
-        c = c ^ (e(0, 3, aez, delta ^ (c | Block::ONE)) & Block::ONE);
-        buffer.copy_from_slice(&c.0[..mu / 8]);
+        c = c ^ (e(0, 3, aez, delta ^ (c | Block::one())) & Block::one());
+        buffer.copy_from_slice(&c.bytes()[..mu / 8]);
     }
 
     let (mut left, mut right);
@@ -540,12 +543,12 @@ fn decipher_aez_tiny(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) {
     }
 
     if n % 8 == 0 {
-        buffer[..n / 8].copy_from_slice(&right.0[..n / 8]);
-        buffer[n / 8..].copy_from_slice(&left.0[..n / 8]);
+        buffer[..n / 8].copy_from_slice(&right.bytes()[..n / 8]);
+        buffer[n / 8..].copy_from_slice(&left.bytes()[..n / 8]);
     } else {
         let mut index = n / 8;
-        buffer[..index + 1].copy_from_slice(&right.0[..index + 1]);
-        for byte in &left.0[..n / 8 + 1] {
+        buffer[..index + 1].copy_from_slice(&right.bytes()[..index + 1]);
+        for byte in &left.bytes()[..n / 8 + 1] {
             buffer[index] |= byte >> 4;
             if index < buffer.len() - 1 {
                 buffer[index + 1] = (byte & 0x0f) << 4;
@@ -568,7 +571,7 @@ fn decipher_aez_core(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) {
     );
     let len_v = d.saturating_sub(128);
 
-    let mut y = Block::NULL;
+    let mut y = Block::null();
     let mut e1_eval = E::new(1, 0, aez);
     let e0_eval = E::new(0, 0, aez);
     for (raw_ci, raw_ci_) in blocks.pairs_mut() {
@@ -578,8 +581,8 @@ fn decipher_aez_core(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) {
         let wi = ci ^ e1_eval.eval(ci_);
         let yi = ci_ ^ e0_eval.eval(wi);
 
-        *raw_ci = wi.0;
-        *raw_ci_ = yi.0;
+        *raw_ci = wi.bytes();
+        *raw_ci_ = yi.bytes();
         y = y ^ yi;
     }
 
@@ -599,7 +602,7 @@ fn decipher_aez_core(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) {
     let s_y = c_y ^ e(-1, 2, aez, s_x);
     let s = s_x ^ s_y;
 
-    let mut x = Block::NULL;
+    let mut x = Block::null();
     let mut e2_eval = E::new(2, 0, aez);
     let mut e1_eval = E::new(1, 0, aez);
     let e0_eval = E::new(0, 0, aez);
@@ -614,8 +617,8 @@ fn decipher_aez_core(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) {
         let mi_ = xi ^ e0_eval.eval(zi);
         let mi = zi ^ e1_eval.eval(mi_);
 
-        *raw_wi = mi.0;
-        *raw_yi = mi_.0;
+        *raw_wi = mi.bytes();
+        *raw_yi = mi_.bytes();
         x = x ^ xi;
     }
 
@@ -659,7 +662,7 @@ fn pad_to_blocks(value: &[u8]) -> Vec<Block> {
 }
 
 fn aez_hash(aez: &Aez, tweaks: Tweak) -> Block {
-    let mut hash = Block::NULL;
+    let mut hash = Block::null();
     for (i, tweak) in tweaks.iter().enumerate() {
         // Adjust for zero-based vs one-based indexing
         let j = i + 2 + 1;
@@ -667,7 +670,7 @@ fn aez_hash(aez: &Aez, tweaks: Tweak) -> Block {
         // set l = 1 and then xor E_K^{j, 0}(10*). We could modify the last if branch to cover this
         // as well, but then we need to fiddle with getting an empty chunk from an empty iterator.
         if tweak.is_empty() {
-            hash = hash ^ e(j.try_into().unwrap(), 0, aez, Block::ONE);
+            hash = hash ^ e(j.try_into().unwrap(), 0, aez, Block::one());
         } else if tweak.len() % 16 == 0 {
             for (l, chunk) in tweak.chunks(16).enumerate() {
                 hash = hash
@@ -704,7 +707,7 @@ fn aez_prf(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) {
     let delta = aez_hash(aez, tweaks);
     for chunk in buffer.chunks_mut(16) {
         let block = e(-1, 3, aez, delta ^ Block::from_int(index));
-        for (a, b) in chunk.iter_mut().zip(block.0.iter()) {
+        for (a, b) in chunk.iter_mut().zip(block.bytes().iter()) {
             *a ^= b;
         }
         index += 1;
@@ -749,7 +752,9 @@ impl<'a> E<'a> {
        // We need to advance ki_p_i if exponent = old_exponent + 1
        // This happens exactly when the old exponent was just a multiple of 8, because the
        // next exponent is then not a multiple anymore and will be rounded *up*.
-        if self.i % 8 == 0 { self.ki_p_i = self.ki_p_i * 2 };
+        if self.i % 8 == 0 {
+            self.ki_p_i = self.ki_p_i * 2
+        };
         self.i += 1;
     }
 }
@@ -796,7 +801,7 @@ mod test {
             let a = hex::decode(a).unwrap();
             let a = Block::from_slice(&a);
             let b = hex::decode(b).unwrap();
-            assert_eq!(&e(*j, *i, &aez, a).0, b.as_slice(), "{name}");
+            assert_eq!(&e(*j, *i, &aez, a).bytes(), b.as_slice(), "{name}");
         }
     }
 
@@ -808,13 +813,13 @@ mod test {
             let aez = Aez::new(k.as_slice());
             let v = hex::decode(v).unwrap();
 
-            let mut tweaks = vec![Vec::from(Block::from_int(*tau).0)];
+            let mut tweaks = vec![Vec::from(Block::from_int(*tau).bytes())];
             for t in *tw {
                 tweaks.push(hex::decode(t).unwrap());
             }
             let tweaks = tweaks.iter().map(Vec::as_slice).collect::<Vec<_>>();
 
-            assert_eq!(&aez_hash(&aez, &tweaks).0, v.as_slice(), "{name}");
+            assert_eq!(&aez_hash(&aez, &tweaks).bytes(), v.as_slice(), "{name}");
         }
     }
-- 
cgit v1.2.3
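
A note for readers who have not used std::simd before: the heart of this patch
is changing Block's representation from [u8; 16] to the portable SIMD vector
u8x16, so that the blockwise XOR/AND/OR used throughout AEZ compile to single
vector instructions on any target, without unsafe code. Below is a minimal,
self-contained sketch of that design; DemoBlock and from_bytes are illustrative
names, not the crate's API, and like the patch it needs a nightly compiler:

    #![feature(portable_simd)]
    use std::simd::prelude::*;

    /// Illustrative 16-byte block backed by a portable-SIMD vector.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    struct DemoBlock(u8x16);

    impl DemoBlock {
        fn from_bytes(bytes: [u8; 16]) -> Self {
            // Simd<u8, 16> converts losslessly from [u8; 16] and back.
            DemoBlock(u8x16::from_array(bytes))
        }

        fn bytes(self) -> [u8; 16] {
            self.0.to_array()
        }
    }

    impl std::ops::BitXor for DemoBlock {
        type Output = Self;
        fn bitxor(self, rhs: Self) -> Self {
            // Lane-wise XOR; on x86_64 this lowers to a single PXOR.
            DemoBlock(self.0 ^ rhs.0)
        }
    }

    fn main() {
        let a = DemoBlock::from_bytes([0xff; 16]);
        let b = DemoBlock::from_bytes([0x0f; 16]);
        assert_eq!((a ^ b).bytes(), [0xf0; 16]);
    }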
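
On the AES-NI path, the patch additionally leans on nightly's From conversions
between Simd<u8, 16> and the x86_64 vector type __m128i; that is what the
value.simd().into() and Block::from_simd(block.into()) pairs above do, replacing
the explicit _mm_loadu_si128/_mm_storeu_si128 round trips of the old code. A
hedged, x86_64-only sketch of one hardware AES round using the same conversions
(aes_round is a made-up helper, and the runtime check stands in for the crate's
cpuid_aes module):

    #![feature(portable_simd)]
    use std::arch::x86_64::{__m128i, _mm_aesenc_si128};
    use std::simd::u8x16;

    /// One AES encryption round (AESENC) on a portable-SIMD vector.
    ///
    /// # Safety
    /// The CPU must support the `aes` target feature.
    #[target_feature(enable = "aes")]
    unsafe fn aes_round(state: u8x16, round_key: u8x16) -> u8x16 {
        // Simd<u8, 16> converts to and from __m128i on nightly.
        let out: __m128i = _mm_aesenc_si128(state.into(), round_key.into());
        u8x16::from(out)
    }

    fn main() {
        if is_x86_feature_detected!("aes") {
            // SAFETY: we just checked that AES-NI is available.
            let out = unsafe { aes_round(u8x16::splat(0x01), u8x16::splat(0x02)) };
            println!("{:?}", out.to_array());
        } else {
            println!("no AES-NI; a real implementation falls back to software rounds");
        }
    }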