From d77f4adf04f2878853d0919f908d1b110f3c94f2 Mon Sep 17 00:00:00 2001
From: Daniel Schadt
Date: Thu, 10 Apr 2025 19:34:07 +0200
Subject: implement aes4 and aes10 with native instructions

Even though aes::hazmat::cipher_round uses aes-ni instructions under
the hood, simply loading the data (and the keys!) takes a significant
amount of time. Sadly, the aes crate does not expose a way to re-use
the "loaded" keys.

By implementing aes4/aes10 directly with _mm_aesenc, we can load the
keys once and keep them properly aligned.

We still keep the software backend as a fallback, using the software
implementation of the aes crate.

This gives a ~70% speedup.
---
 src/aesround.rs | 133 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 133 insertions(+)
 create mode 100644 src/aesround.rs

diff --git a/src/aesround.rs b/src/aesround.rs
new file mode 100644
index 0000000..0a06192
--- /dev/null
+++ b/src/aesround.rs
@@ -0,0 +1,133 @@
+use super::block::Block;
+
+#[cfg(target_arch = "x86_64")]
+pub type AesImpl = x86_64::AesNi;
+
+#[cfg(not(target_arch = "x86_64"))]
+pub type AesImpl = AesSoft;
+
+pub trait AesRound {
+    fn new(key_i: Block, key_j: Block, key_l: Block) -> Self;
+    fn aes4(&self, value: Block) -> Block;
+    fn aes10(&self, value: Block) -> Block;
+}
+
+/// Implementation of aes4 and aes10 in software.
+///
+/// Always available.
+///
+/// Uses the `aes` crate under the hood.
+pub struct AesSoft {
+    key_i: aes::Block,
+    key_j: aes::Block,
+    key_l: aes::Block,
+}
+
+impl AesRound for AesSoft {
+    fn new(key_i: Block, key_j: Block, key_l: Block) -> Self {
+        Self {
+            key_i: key_i.0.into(),
+            key_j: key_j.0.into(),
+            key_l: key_l.0.into(),
+        }
+    }
+
+    fn aes4(&self, value: Block) -> Block {
+        let mut block: aes::Block = value.0.into();
+        ::aes::hazmat::cipher_round(&mut block, &self.key_j);
+        ::aes::hazmat::cipher_round(&mut block, &self.key_i);
+        ::aes::hazmat::cipher_round(&mut block, &self.key_l);
+        ::aes::hazmat::cipher_round(&mut block, &Block::NULL.0.into());
+        Block(block.into())
+    }
+
+    fn aes10(&self, value: Block) -> Block {
+        let mut block: aes::Block = value.0.into();
+        ::aes::hazmat::cipher_round(&mut block, &self.key_i);
+        ::aes::hazmat::cipher_round(&mut block, &self.key_j);
+        ::aes::hazmat::cipher_round(&mut block, &self.key_l);
+        ::aes::hazmat::cipher_round(&mut block, &self.key_i);
+        ::aes::hazmat::cipher_round(&mut block, &self.key_j);
+        ::aes::hazmat::cipher_round(&mut block, &self.key_l);
+        ::aes::hazmat::cipher_round(&mut block, &self.key_i);
+        ::aes::hazmat::cipher_round(&mut block, &self.key_j);
+        ::aes::hazmat::cipher_round(&mut block, &self.key_l);
+        ::aes::hazmat::cipher_round(&mut block, &self.key_i);
+        Block(block.into())
+    }
+}
+
+#[cfg(target_arch = "x86_64")]
+pub mod x86_64 {
+    use super::*;
+    use core::arch::x86_64::*;
+
+    cpufeatures::new!(cpuid_aes, "aes");
+
+    pub struct AesNi {
+        support: cpuid_aes::InitToken,
+        fallback: AesSoft,
+        key_i: __m128i,
+        key_j: __m128i,
+        key_l: __m128i,
+        null: __m128i,
+    }
+
+    impl AesRound for AesNi {
+        fn new(key_i: Block, key_j: Block, key_l: Block) -> Self {
+            // SAFETY: loadu can load from unaligned memory
+            unsafe {
+                Self {
+                    support: cpuid_aes::init(),
+                    fallback: AesSoft::new(key_i, key_j, key_l),
+                    key_i: _mm_loadu_si128(key_i.0.as_ptr() as *const _),
+                    key_j: _mm_loadu_si128(key_j.0.as_ptr() as *const _),
+                    key_l: _mm_loadu_si128(key_l.0.as_ptr() as *const _),
+                    null: _mm_loadu_si128(Block::NULL.0.as_ptr() as *const _),
+                }
+            }
+        }
+
+        fn aes4(&self, value: Block) -> Block {
+            if !self.support.get() {
+                return self.fallback.aes4(value);
+            }
+
+            // SAFETY: loadu can load from unaligned memory
+            unsafe {
+                let mut block = _mm_loadu_si128(value.0.as_ptr() as *const _);
+                block = _mm_aesenc_si128(block, self.key_j);
+                block = _mm_aesenc_si128(block, self.key_i);
+                block = _mm_aesenc_si128(block, self.key_l);
+                block = _mm_aesenc_si128(block, self.null);
+                let mut result = Block::default();
+                _mm_storeu_si128(result.0.as_mut_ptr() as *mut _, block);
+                result
+            }
+        }
+
+        fn aes10(&self, value: Block) -> Block {
+            if !self.support.get() {
+                return self.fallback.aes10(value);
+            }
+
+            // SAFETY: loadu can load from unaligned memory
+            unsafe {
+                let mut block = _mm_loadu_si128(value.0.as_ptr() as *const _);
+                block = _mm_aesenc_si128(block, self.key_i);
+                block = _mm_aesenc_si128(block, self.key_j);
+                block = _mm_aesenc_si128(block, self.key_l);
+                block = _mm_aesenc_si128(block, self.key_i);
+                block = _mm_aesenc_si128(block, self.key_j);
+                block = _mm_aesenc_si128(block, self.key_l);
+                block = _mm_aesenc_si128(block, self.key_i);
+                block = _mm_aesenc_si128(block, self.key_j);
+                block = _mm_aesenc_si128(block, self.key_l);
+                block = _mm_aesenc_si128(block, self.key_i);
+                let mut result = Block::default();
+                _mm_storeu_si128(result.0.as_mut_ptr() as *mut _, block);
+                result
+            }
+        }
+    }
+}
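For orientation, a minimal sketch of how callers are expected to drive the
new AesRound trait: construct the implementation once per set of keys, then
call aes4/aes10 per block. The sketch assumes the module is wired up as
crate::aesround, that Block is a tuple struct over a public [u8; 16] (as the
.0.as_ptr() calls above suggest), and uses placeholder key bytes; none of
these names or values are part of this patch.

    use crate::aesround::{AesImpl, AesRound};
    use crate::block::Block;

    fn sketch() -> (Block, Block) {
        // Placeholder key material for illustration; the real key blocks
        // are derived elsewhere in the crate.
        let key_i = Block([0x01; 16]);
        let key_j = Block([0x02; 16]);
        let key_l = Block([0x03; 16]);

        // Load the round keys once. On x86_64 with AES-NI support this is
        // AesNi and the keys stay resident as __m128i values; otherwise
        // the AesSoft fallback is used.
        let aes = AesImpl::new(key_i, key_j, key_l);

        // Per-block calls then only pay for loading/storing the 16 data bytes.
        (aes.aes4(Block([0u8; 16])), aes.aes10(Block([0u8; 16])))
    }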