author    Daniel Schadt <kingdread@gmx.de>  2025-04-11 12:48:18 +0200
committer Daniel Schadt <kingdread@gmx.de>  2025-04-11 12:48:18 +0200
commit    5bd298ed568aca12a54f014a7b13f943379a5eb9
tree      911afd45baafe196517455f33ab89bc8a9f09355
parent    34ed0189281fcca1921d4e3d762e6d9183d5230f
use SIMD instructions
(requires a nightly compiler)
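[Note] The heart of this change is that Block's backing store moves from a plain [u8; 16] to std::simd::u8x16, so the bitwise operators compile down to single vector instructions. A minimal, self-contained sketch of the pattern (hypothetical names; nightly-only, since portable_simd is an unstable feature):

#![feature(portable_simd)]
use std::simd::prelude::*;

#[derive(Clone, Copy, PartialEq, Debug)]
struct Block(u8x16);

impl std::ops::BitXor for Block {
    type Output = Block;
    fn bitxor(self, rhs: Block) -> Block {
        // One vector XOR instead of sixteen byte-wise XORs.
        Block(self.0 ^ rhs.0)
    }
}

fn main() {
    let a = Block(u8x16::from_array([0xff; 16]));
    let b = Block(u8x16::from_array([0x0f; 16]));
    assert_eq!(a ^ b, Block(u8x16::from_array([0xf0; 16])));
}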
Diffstat (limited to 'src')
-rw-r--r--  src/accessor.rs |  9
-rw-r--r--  src/aesround.rs | 49
-rw-r--r--  src/block.rs    | 65
-rw-r--r--  src/lib.rs      | 71
4 files changed, 99 insertions(+), 95 deletions(-)
diff --git a/src/accessor.rs b/src/accessor.rs
index 89f5251..24905af 100644
--- a/src/accessor.rs
+++ b/src/accessor.rs
@@ -36,7 +36,8 @@ impl<'a> BlockAccessor<'a> {
pub fn set_m_u(&mut self, m_u: Block) {
let start = self.suffix_start();
- self.data[start..start + self.m_u_len / 8].copy_from_slice(&m_u.0[..self.m_u_len / 8]);
+ self.data[start..start + self.m_u_len / 8]
+ .copy_from_slice(&m_u.bytes()[..self.m_u_len / 8]);
}
pub fn m_v(&self) -> Block {
@@ -47,7 +48,7 @@ impl<'a> BlockAccessor<'a> {
pub fn set_m_v(&mut self, m_v: Block) {
let start = self.suffix_start();
self.data[start + self.m_u_len / 8..start + self.m_uv_len / 8]
- .copy_from_slice(&m_v.0[..self.m_v_len / 8]);
+ .copy_from_slice(&m_v.bytes()[..self.m_v_len / 8]);
}
pub fn m_x(&self) -> Block {
@@ -57,7 +58,7 @@ impl<'a> BlockAccessor<'a> {
pub fn set_m_x(&mut self, m_x: Block) {
let start = self.suffix_start() + self.m_uv_len / 8;
- self.data[start..start + 16].copy_from_slice(&m_x.0);
+ self.data[start..start + 16].copy_from_slice(&m_x.bytes());
}
pub fn m_y(&self) -> Block {
@@ -67,7 +68,7 @@ impl<'a> BlockAccessor<'a> {
pub fn set_m_y(&mut self, m_y: Block) {
let start = self.suffix_start() + self.m_uv_len / 8;
- self.data[start + 16..start + 32].copy_from_slice(&m_y.0);
+ self.data[start + 16..start + 32].copy_from_slice(&m_y.bytes());
}
pub fn pairs_mut<'b>(
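[Note] All accessor.rs hunks are the same mechanical substitution: the inner field is now a u8x16, so call sites go through bytes(), which yields an owned [u8; 16]. The borrow can stay inline because the temporary array lives until the end of the statement; a sketch of the pattern with made-up buffer names:

let mut data = [0u8; 32];
let m_x = Block::from_int(42u128);
// bytes() returns a temporary [u8; 16], but it outlives this call,
// so no extra `let` binding is needed at these call sites.
data[..16].copy_from_slice(&m_x.bytes());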
diff --git a/src/aesround.rs b/src/aesround.rs
index 0a06192..d04ac9b 100644
--- a/src/aesround.rs
+++ b/src/aesround.rs
@@ -26,23 +26,23 @@ pub struct AesSoft {
impl AesRound for AesSoft {
fn new(key_i: Block, key_j: Block, key_l: Block) -> Self {
Self {
- key_i: key_i.0.into(),
- key_j: key_j.0.into(),
- key_l: key_l.0.into(),
+ key_i: key_i.bytes().into(),
+ key_j: key_j.bytes().into(),
+ key_l: key_l.bytes().into(),
}
}
fn aes4(&self, value: Block) -> Block {
- let mut block: aes::Block = value.0.into();
+ let mut block: aes::Block = value.bytes().into();
::aes::hazmat::cipher_round(&mut block, &self.key_j);
::aes::hazmat::cipher_round(&mut block, &self.key_i);
::aes::hazmat::cipher_round(&mut block, &self.key_l);
- ::aes::hazmat::cipher_round(&mut block, &Block::NULL.0.into());
- Block(block.into())
+ ::aes::hazmat::cipher_round(&mut block, &Block::null().bytes().into());
+ <Block as From<[u8; 16]>>::from(block.into())
}
fn aes10(&self, value: Block) -> Block {
- let mut block: aes::Block = value.0.into();
+ let mut block: aes::Block = value.bytes().into();
::aes::hazmat::cipher_round(&mut block, &self.key_i);
::aes::hazmat::cipher_round(&mut block, &self.key_j);
::aes::hazmat::cipher_round(&mut block, &self.key_l);
@@ -53,7 +53,7 @@ impl AesRound for AesSoft {
::aes::hazmat::cipher_round(&mut block, &self.key_j);
::aes::hazmat::cipher_round(&mut block, &self.key_l);
::aes::hazmat::cipher_round(&mut block, &self.key_i);
- Block(block.into())
+ <Block as From<[u8; 16]>>::from(block.into())
}
}
@@ -75,16 +75,13 @@ pub mod x86_64 {
impl AesRound for AesNi {
fn new(key_i: Block, key_j: Block, key_l: Block) -> Self {
- // SAFETY: loadu can load from unaligned memory
- unsafe {
- Self {
- support: cpuid_aes::init(),
- fallback: AesSoft::new(key_i, key_j, key_l),
- key_i: _mm_loadu_si128(key_i.0.as_ptr() as *const _),
- key_j: _mm_loadu_si128(key_j.0.as_ptr() as *const _),
- key_l: _mm_loadu_si128(key_l.0.as_ptr() as *const _),
- null: _mm_loadu_si128(Block::NULL.0.as_ptr() as *const _),
- }
+ Self {
+ support: cpuid_aes::init(),
+ fallback: AesSoft::new(key_i, key_j, key_l),
+ key_i: key_i.simd().into(),
+ key_j: key_j.simd().into(),
+ key_l: key_l.simd().into(),
+ null: Block::null().simd().into(),
}
}
@@ -93,16 +90,14 @@ pub mod x86_64 {
return self.fallback.aes4(value);
}
- // SAFETY: loadu can load from unaligned memory
+ // SAFETY: AES-NI support was verified above, so AESENC is available
unsafe {
- let mut block = _mm_loadu_si128(value.0.as_ptr() as *const _);
+ let mut block = value.simd().into();
block = _mm_aesenc_si128(block, self.key_j);
block = _mm_aesenc_si128(block, self.key_i);
block = _mm_aesenc_si128(block, self.key_l);
block = _mm_aesenc_si128(block, self.null);
- let mut result = Block::default();
- _mm_storeu_si128(result.0.as_mut_ptr() as *mut _, block);
- result
+ Block::from_simd(block.into())
}
}
@@ -111,9 +106,9 @@ pub mod x86_64 {
return self.fallback.aes10(value);
}
- // SAFETY: loadu can load from unaligned memory
+ // SAFETY: AES-NI support was verified above, so AESENC is available
unsafe {
- let mut block = _mm_loadu_si128(value.0.as_ptr() as *const _);
+ let mut block = value.simd().into();
block = _mm_aesenc_si128(block, self.key_i);
block = _mm_aesenc_si128(block, self.key_j);
block = _mm_aesenc_si128(block, self.key_l);
@@ -124,9 +119,7 @@ pub mod x86_64 {
block = _mm_aesenc_si128(block, self.key_j);
block = _mm_aesenc_si128(block, self.key_l);
block = _mm_aesenc_si128(block, self.key_i);
- let mut result = Block::default();
- _mm_storeu_si128(result.0.as_mut_ptr() as *mut _, block);
- result
+ Block::from_simd(block.into())
}
}
}
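[Note] What lets the loadu/storeu pointer casts go away: on nightly, std::simd vectors convert to and from the x86_64 vendor types by value via From/Into. A minimal sketch of one AES round under that assumption (function name hypothetical; the caller must have verified AES-NI support, mirroring the `support` check above):

#![feature(portable_simd)]
use std::arch::x86_64::{__m128i, _mm_aesenc_si128};
use std::simd::prelude::*;

fn aes_round(state: u8x16, round_key: u8x16) -> u8x16 {
    // SAFETY: only call this after checking AES-NI support at runtime.
    unsafe {
        let out: __m128i = _mm_aesenc_si128(state.into(), round_key.into());
        out.into()
    }
}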
diff --git a/src/block.rs b/src/block.rs
index c294aab..b485b17 100644
--- a/src/block.rs
+++ b/src/block.rs
@@ -1,12 +1,34 @@
use std::ops::{BitAnd, BitOr, BitXor, Index, IndexMut, Mul, Shl, Shr};
+use std::simd::prelude::*;
/// A block, the unit of work that AEZ divides the message into.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
-pub struct Block(pub [u8; 16]);
+pub struct Block(u8x16);
impl Block {
- pub const NULL: Block = Block([0; 16]);
- pub const ONE: Block = Block([0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
+ pub fn null() -> Block {
+ Block([0; 16].into())
+ }
+
+ pub fn one() -> Block {
+ Block([0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0].into())
+ }
+
+ pub fn bytes(&self) -> [u8; 16] {
+ self.0.into()
+ }
+
+ pub fn write_to(&self, output: &mut [u8; 16]) {
+ self.0.copy_to_slice(output);
+ }
+
+ pub(crate) fn simd(&self) -> u8x16 {
+ self.0
+ }
+
+ pub(crate) fn from_simd(value: u8x16) -> Self {
+ Block(value)
+ }
/// Create a block from a slice.
///
@@ -16,18 +38,18 @@ impl Block {
let len = value.len().min(16);
let mut array = [0; 16];
array[..len].copy_from_slice(&value[..len]);
- Block(array)
+ Block(array.into())
}
/// Constructs a block representing the given integer.
///
/// This corresponds to [x]_128 in the paper.
pub fn from_int<I: Into<u128>>(value: I) -> Self {
- Block(value.into().to_be_bytes())
+ Block(value.into().to_be_bytes().into())
}
pub fn to_int(&self) -> u128 {
- u128::from_be_bytes(self.0)
+ u128::from_be_bytes(self.0.into())
}
/// Pad the block to full length.
@@ -62,43 +84,26 @@ impl Block {
impl From<[u8; 16]> for Block {
fn from(value: [u8; 16]) -> Block {
- Block(value)
+ Block(value.into())
}
}
impl From<&[u8; 16]> for Block {
fn from(value: &[u8; 16]) -> Block {
- Block(*value)
+ Block((*value).into())
}
}
impl From<u128> for Block {
fn from(value: u128) -> Block {
- Block(value.to_be_bytes())
+ Block(value.to_be_bytes().into())
}
}
impl BitXor<Block> for Block {
type Output = Block;
fn bitxor(self, rhs: Block) -> Block {
- Block([
- self.0[0] ^ rhs.0[0],
- self.0[1] ^ rhs.0[1],
- self.0[2] ^ rhs.0[2],
- self.0[3] ^ rhs.0[3],
- self.0[4] ^ rhs.0[4],
- self.0[5] ^ rhs.0[5],
- self.0[6] ^ rhs.0[6],
- self.0[7] ^ rhs.0[7],
- self.0[8] ^ rhs.0[8],
- self.0[9] ^ rhs.0[9],
- self.0[10] ^ rhs.0[10],
- self.0[11] ^ rhs.0[11],
- self.0[12] ^ rhs.0[12],
- self.0[13] ^ rhs.0[13],
- self.0[14] ^ rhs.0[14],
- self.0[15] ^ rhs.0[15],
- ])
+ Block(self.0 ^ rhs.0)
}
}
@@ -119,14 +124,14 @@ impl Shr<u32> for Block {
impl BitAnd<Block> for Block {
type Output = Block;
fn bitand(self, rhs: Block) -> Block {
- Block::from(self.to_int() & rhs.to_int())
+ Block(self.0 & rhs.0)
}
}
impl BitOr<Block> for Block {
type Output = Block;
fn bitor(self, rhs: Block) -> Block {
- Block::from(self.to_int() | rhs.to_int())
+ Block(self.0 | rhs.0)
}
}
@@ -147,7 +152,7 @@ impl Mul<u32> for Block {
type Output = Block;
fn mul(self, rhs: u32) -> Block {
match rhs {
- 0 => Block::NULL,
+ 0 => Block::null(),
1 => self,
2 => {
let mut result = self << 1;
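[Note] The new Block surface replaces direct field access with conversions; a sketch of how the pieces fit together, written as a test that would live inside this crate (illustrative values):

#[test]
fn block_simd_accessors() {
    let b = Block::from_int(1u128);
    assert_eq!(b.to_int(), 1);                    // round-trips through u8x16
    assert_eq!(Block::null().bytes(), [0u8; 16]);

    // write_to copies straight into a caller-owned buffer:
    let mut out = [0u8; 16];
    b.write_to(&mut out);
    assert_eq!(out, b.bytes());

    // Operators now map to single vector instructions:
    assert_eq!(b ^ b, Block::null());
    assert_eq!(b & b, b);
}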
diff --git a/src/lib.rs b/src/lib.rs
index 7f2e5c3..6e411a0 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,3 +1,4 @@
+#![feature(portable_simd)]
//! AEZ *\[sic!\]* v5 encryption implemented in Rust.
//!
//! # ☣️ Cryptographic hazmat ☣️
@@ -297,7 +298,8 @@ fn append_auth(data_len: usize, buffer: &mut [u8]) {
fn encrypt(aez: &Aez, nonce: &[u8], ad: &[&[u8]], tau: u32, buffer: &mut [u8]) {
// We treat tau as bytes, but according to the spec, tau is actually in bits.
let tau_block = Block::from_int(tau as u128 * 8);
- let mut tweaks = vec![&tau_block.0, nonce];
+ let tau_bytes = tau_block.bytes();
+ let mut tweaks = vec![&tau_bytes, nonce];
tweaks.extend(ad);
assert!(buffer.len() >= tau as usize);
if buffer.len() == tau as usize {
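[Note] Unlike the accessor.rs call sites, the borrow here must outlive the statement: tweaks holds the reference, so the [u8; 16] returned by bytes() has to be bound to a local first (with the old field access tau_block.0, the borrow pointed into tau_block itself and no binding was needed). A minimal reproduction of the constraint, with hypothetical names:

fn demo(nonce: &[u8]) {
    let tau_block = Block::from_int(128u128);
    let tau_bytes = tau_block.bytes();                // bind the temporary...
    let tweaks: Vec<&[u8]> = vec![&tau_bytes, nonce]; // ...then borrow it
    // `vec![&tau_block.bytes(), nonce]` would be rejected: the array would
    // be dropped at the end of the statement while `tweaks` still borrows it.
    assert_eq!(tweaks.len(), 2);
}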
@@ -321,7 +323,8 @@ fn decrypt<'a>(
}
let tau_block = Block::from_int(tau * 8);
- let mut tweaks = vec![&tau_block.0, nonce];
+ let tau_bytes = tau_block.bytes();
+ let mut tweaks = vec![&tau_bytes, nonce];
tweaks.extend(ad);
if ciphertext.len() == tau as usize {
@@ -387,12 +390,12 @@ fn encipher_aez_tiny(aez: &Aez, tweaks: Tweak, message: &mut [u8]) {
(left, right) = (right, right_);
}
if n % 8 == 0 {
- message[..n / 8].copy_from_slice(&right.0[..n / 8]);
- message[n / 8..].copy_from_slice(&left.0[..n / 8]);
+ message[..n / 8].copy_from_slice(&right.bytes()[..n / 8]);
+ message[n / 8..].copy_from_slice(&left.bytes()[..n / 8]);
} else {
let mut index = n / 8;
- message[..index + 1].copy_from_slice(&right.0[..index + 1]);
- for byte in &left.0[..n / 8 + 1] {
+ message[..index + 1].copy_from_slice(&right.bytes()[..index + 1]);
+ for byte in &left.bytes()[..n / 8 + 1] {
message[index] |= byte >> 4;
if index < message.len() - 1 {
message[index + 1] = (byte & 0x0f) << 4;
@@ -402,8 +405,8 @@ fn encipher_aez_tiny(aez: &Aez, tweaks: Tweak, message: &mut [u8]) {
}
if mu < 128 {
let mut c = Block::from_slice(&message);
- c = c ^ (e(0, 3, aez, delta ^ (c | Block::ONE)) & Block::ONE);
- message.copy_from_slice(&c.0[..mu / 8]);
+ c = c ^ (e(0, 3, aez, delta ^ (c | Block::one())) & Block::one());
+ message.copy_from_slice(&c.bytes()[..mu / 8]);
}
}
@@ -420,7 +423,7 @@ fn encipher_aez_core(aez: &Aez, tweaks: Tweak, message: &mut [u8]) {
);
let len_v = d.saturating_sub(128);
- let mut x = Block::NULL;
+ let mut x = Block::null();
let mut e1_eval = E::new(1, 0, aez);
let e0_eval = E::new(0, 0, aez);
@@ -431,8 +434,8 @@ fn encipher_aez_core(aez: &Aez, tweaks: Tweak, message: &mut [u8]) {
let wi = mi ^ e1_eval.eval(mi_);
let xi = mi_ ^ e0_eval.eval(wi);
- *raw_mi = wi.0;
- *raw_mi_ = xi.0;
+ wi.write_to(raw_mi);
+ xi.write_to(raw_mi_);
x = x ^ xi;
}
@@ -452,7 +455,7 @@ fn encipher_aez_core(aez: &Aez, tweaks: Tweak, message: &mut [u8]) {
let s_y = m_y ^ e(-1, 1, aez, s_x);
let s = s_x ^ s_y;
- let mut y = Block::NULL;
+ let mut y = Block::null();
let mut e2_eval = E::new(2, 0, aez);
let mut e1_eval = E::new(1, 0, aez);
let e0_eval = E::new(0, 0, aez);
@@ -467,8 +470,8 @@ fn encipher_aez_core(aez: &Aez, tweaks: Tweak, message: &mut [u8]) {
let ci_ = yi ^ e0_eval.eval(zi);
let ci = zi ^ e1_eval.eval(ci_);
- *raw_wi = ci.0;
- *raw_xi = ci_.0;
+ ci.write_to(raw_wi);
+ ci_.write_to(raw_xi);
y = y ^ yi;
}
@@ -520,8 +523,8 @@ fn decipher_aez_tiny(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) {
if mu < 128 {
let mut c = Block::from_slice(buffer);
- c = c ^ (e(0, 3, aez, delta ^ (c | Block::ONE)) & Block::ONE);
- buffer.copy_from_slice(&c.0[..mu / 8]);
+ c = c ^ (e(0, 3, aez, delta ^ (c | Block::one())) & Block::one());
+ buffer.copy_from_slice(&c.bytes()[..mu / 8]);
}
let (mut left, mut right);
@@ -540,12 +543,12 @@ fn decipher_aez_tiny(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) {
}
if n % 8 == 0 {
- buffer[..n / 8].copy_from_slice(&right.0[..n / 8]);
- buffer[n / 8..].copy_from_slice(&left.0[..n / 8]);
+ buffer[..n / 8].copy_from_slice(&right.bytes()[..n / 8]);
+ buffer[n / 8..].copy_from_slice(&left.bytes()[..n / 8]);
} else {
let mut index = n / 8;
- buffer[..index + 1].copy_from_slice(&right.0[..index + 1]);
- for byte in &left.0[..n / 8 + 1] {
+ buffer[..index + 1].copy_from_slice(&right.bytes()[..index + 1]);
+ for byte in &left.bytes()[..n / 8 + 1] {
buffer[index] |= byte >> 4;
if index < buffer.len() - 1 {
buffer[index + 1] = (byte & 0x0f) << 4;
@@ -568,7 +571,7 @@ fn decipher_aez_core(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) {
);
let len_v = d.saturating_sub(128);
- let mut y = Block::NULL;
+ let mut y = Block::null();
let mut e1_eval = E::new(1, 0, aez);
let e0_eval = E::new(0, 0, aez);
for (raw_ci, raw_ci_) in blocks.pairs_mut() {
@@ -578,8 +581,8 @@ fn decipher_aez_core(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) {
let wi = ci ^ e1_eval.eval(ci_);
let yi = ci_ ^ e0_eval.eval(wi);
- *raw_ci = wi.0;
- *raw_ci_ = yi.0;
+ *raw_ci = wi.bytes();
+ *raw_ci_ = yi.bytes();
y = y ^ yi;
}
@@ -599,7 +602,7 @@ fn decipher_aez_core(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) {
let s_y = c_y ^ e(-1, 2, aez, s_x);
let s = s_x ^ s_y;
- let mut x = Block::NULL;
+ let mut x = Block::null();
let mut e2_eval = E::new(2, 0, aez);
let mut e1_eval = E::new(1, 0, aez);
let e0_eval = E::new(0, 0, aez);
@@ -614,8 +617,8 @@ fn decipher_aez_core(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) {
let mi_ = xi ^ e0_eval.eval(zi);
let mi = zi ^ e1_eval.eval(mi_);
- *raw_wi = mi.0;
- *raw_yi = mi_.0;
+ *raw_wi = mi.bytes();
+ *raw_yi = mi_.bytes();
x = x ^ xi;
}
@@ -659,7 +662,7 @@ fn pad_to_blocks(value: &[u8]) -> Vec<Block> {
}
fn aez_hash(aez: &Aez, tweaks: Tweak) -> Block {
- let mut hash = Block::NULL;
+ let mut hash = Block::null();
for (i, tweak) in tweaks.iter().enumerate() {
// Adjust for zero-based vs one-based indexing
let j = i + 2 + 1;
@@ -667,7 +670,7 @@ fn aez_hash(aez: &Aez, tweaks: Tweak) -> Block {
// set l = 1 and then xor E_K^{j, 0}(10*). We could modify the last if branch to cover this
// as well, but then we need to fiddle with getting an empty chunk from an empty iterator.
if tweak.is_empty() {
- hash = hash ^ e(j.try_into().unwrap(), 0, aez, Block::ONE);
+ hash = hash ^ e(j.try_into().unwrap(), 0, aez, Block::one());
} else if tweak.len() % 16 == 0 {
for (l, chunk) in tweak.chunks(16).enumerate() {
hash = hash
@@ -704,7 +707,7 @@ fn aez_prf(aez: &Aez, tweaks: Tweak, buffer: &mut [u8]) {
let delta = aez_hash(aez, tweaks);
for chunk in buffer.chunks_mut(16) {
let block = e(-1, 3, aez, delta ^ Block::from_int(index));
- for (a, b) in chunk.iter_mut().zip(block.0.iter()) {
+ for (a, b) in chunk.iter_mut().zip(block.bytes().iter()) {
*a ^= b;
}
index += 1;
@@ -749,7 +752,9 @@ impl<'a> E<'a> {
// We need to advance ki_p_i if exponent = old_exponent + 1
// This happens exactly when the old exponent was just a multiple of 8, because the
// next exponent is then not a multiple anymore and will be rounded *up*.
- if self.i % 8 == 0 { self.ki_p_i = self.ki_p_i * 2 };
+ if self.i % 8 == 0 {
+ self.ki_p_i = self.ki_p_i * 2
+ };
self.i += 1;
}
}
@@ -796,7 +801,7 @@ mod test {
let a = hex::decode(a).unwrap();
let a = Block::from_slice(&a);
let b = hex::decode(b).unwrap();
- assert_eq!(&e(*j, *i, &aez, a).0, b.as_slice(), "{name}");
+ assert_eq!(&e(*j, *i, &aez, a).bytes(), b.as_slice(), "{name}");
}
}
@@ -808,13 +813,13 @@ mod test {
let aez = Aez::new(k.as_slice());
let v = hex::decode(v).unwrap();
- let mut tweaks = vec![Vec::from(Block::from_int(*tau).0)];
+ let mut tweaks = vec![Vec::from(Block::from_int(*tau).bytes())];
for t in *tw {
tweaks.push(hex::decode(t).unwrap());
}
let tweaks = tweaks.iter().map(Vec::as_slice).collect::<Vec<_>>();
- assert_eq!(&aez_hash(&aez, &tweaks).0, v.as_slice(), "{name}");
+ assert_eq!(&aez_hash(&aez, &tweaks).bytes(), v.as_slice(), "{name}");
}
}