aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Schadt <kingdread@gmx.de>2025-04-05 19:03:10 +0200
committerDaniel Schadt <kingdread@gmx.de>2025-04-05 19:03:10 +0200
commit71cdf50525f0cbb70673477510050669206df7f2 (patch)
tree41e58ce93318dfaaf8f2c4f4dd91b879ead378af
parent5cd9e4a71f0561d599ce5c7d498828ef5b8db2bb (diff)
downloadzears-71cdf50525f0cbb70673477510050669206df7f2.tar.gz
zears-71cdf50525f0cbb70673477510050669206df7f2.tar.bz2
zears-71cdf50525f0cbb70673477510050669206df7f2.zip
use proper Block struct and operator overloading
-rw-r--r--src/block.rs221
-rw-r--r--src/lib.rs565
2 files changed, 387 insertions, 399 deletions
diff --git a/src/block.rs b/src/block.rs
new file mode 100644
index 0000000..e63062e
--- /dev/null
+++ b/src/block.rs
@@ -0,0 +1,221 @@
+use std::ops::{BitAnd, BitOr, BitXor, Index, IndexMut, Mul, Shl, Shr};
+
+/// A block, the unit of work that AEZ divides the message into.
+#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub struct Block(pub [u8; 16]);
+
+impl Block {
+ pub const NULL: Block = Block([0; 16]);
+ pub const ONE: Block = Block([0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
+
+ /// Create a block from a slice.
+ ///
+ /// If the slice is too long, it will be truncated. If the slice is too short, the remaining
+ /// items are set to 0.
+ pub fn from_slice(value: &[u8]) -> Self {
+ let len = value.len().min(16);
+ let mut array = [0; 16];
+ array[..len].copy_from_slice(&value[..len]);
+ Block(array)
+ }
+
+ /// Constructs a block representing the given integer.
+ ///
+ /// This corresponds to [x]_128 in the paper.
+ pub fn from_int<I: Into<u128>>(value: I) -> Self {
+ Block(value.into().to_be_bytes())
+ }
+
+ pub fn to_int(&self) -> u128 {
+ u128::from_be_bytes(self.0)
+ }
+
+ /// Pad the block to full length.
+ ///
+ /// The given length is the current length.
+ ///
+ /// This corresponds to X10* in the paper.
+ pub fn pad(&self, length: usize) -> Block {
+ assert!(length <= 127);
+ Block::from_int(self.to_int() | (1 << (127 - length)))
+ }
+
+ /// Clip the block by setting all bits beyond the given length to 0.
+ pub fn clip(&self, mut length: usize) -> Block {
+ let mut block = self.0;
+ for byte in &mut block {
+ if length == 0 {
+ *byte = 0;
+ } else if length < 8 {
+ *byte &= 0xff << (8 - length);
+ }
+ length = length.saturating_sub(8);
+ }
+ Block(block)
+ }
+}
+
+impl From<[u8; 16]> for Block {
+ fn from(value: [u8; 16]) -> Block {
+ Block(value)
+ }
+}
+
+impl From<&[u8; 16]> for Block {
+ fn from(value: &[u8; 16]) -> Block {
+ Block(*value)
+ }
+}
+
+impl From<u128> for Block {
+ fn from(value: u128) -> Block {
+ Block(value.to_be_bytes())
+ }
+}
+
+impl BitXor<Block> for Block {
+ type Output = Block;
+ fn bitxor(self, rhs: Block) -> Block {
+ Block::from(self.to_int() ^ rhs.to_int())
+ }
+}
+
+impl Shl<u32> for Block {
+ type Output = Block;
+ fn shl(self, rhs: u32) -> Block {
+ Block::from(self.to_int() << rhs)
+ }
+}
+
+impl Shr<u32> for Block {
+ type Output = Block;
+ fn shr(self, rhs: u32) -> Block {
+ Block::from(self.to_int() >> rhs)
+ }
+}
+
+impl BitAnd<Block> for Block {
+ type Output = Block;
+ fn bitand(self, rhs: Block) -> Block {
+ Block::from(self.to_int() & rhs.to_int())
+ }
+}
+
+impl BitOr<Block> for Block {
+ type Output = Block;
+ fn bitor(self, rhs: Block) -> Block {
+ Block::from(self.to_int() | rhs.to_int())
+ }
+}
+
+impl Index<usize> for Block {
+ type Output = u8;
+ fn index(&self, index: usize) -> &u8 {
+ &self.0[index]
+ }
+}
+
+impl IndexMut<usize> for Block {
+ fn index_mut(&mut self, index: usize) -> &mut u8 {
+ &mut self.0[index]
+ }
+}
+
+impl Mul<u32> for Block {
+ type Output = Block;
+ fn mul(self, rhs: u32) -> Block {
+ match rhs {
+ 0 => Block::NULL,
+ 1 => self,
+ 2 => {
+ let mut result = self << 1;
+ if self[0] & 0x80 != 0 {
+ result[15] ^= 135;
+ }
+ result
+ }
+ _ if rhs % 2 == 0 => self * 2 * (rhs / 2),
+ _ => self * (rhs - 1) ^ self,
+ }
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+ #[test]
+ fn test_xor() {
+ assert_eq!(
+ Block::from([1; 16]) ^ Block::from([2; 16]),
+ Block::from([3; 16])
+ );
+ }
+
+ #[test]
+ fn test_pad() {
+ assert_eq!(
+ Block::from([0; 16]).pad(0),
+ Block::from([0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
+ );
+ assert_eq!(
+ Block::from([0; 16]).pad(1),
+ Block::from([0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
+ );
+ assert_eq!(
+ Block::from([0; 16]).pad(8),
+ Block::from([0, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
+ );
+ }
+
+ #[test]
+ fn test_shl() {
+ assert_eq!(
+ Block::from([0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) << 1,
+ Block::from([0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
+ );
+ assert_eq!(
+ Block::from([0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) << 4,
+ Block::from([0x10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
+ );
+ assert_eq!(
+ Block::from([0x0A, 0xB0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) << 4,
+ Block::from([0xAB, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
+ );
+ assert_eq!(
+ Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]) << 8,
+ Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]),
+ );
+ }
+
+ #[test]
+ fn test_times() {
+ assert_eq!(
+ Block::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]) * 0,
+ Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
+ );
+ assert_eq!(
+ Block::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]) * 1,
+ Block::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]),
+ );
+ assert_eq!(
+ Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]) * 2,
+ Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]),
+ );
+ assert_eq!(
+ Block::from([128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]) * 2,
+ Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 133]),
+ );
+ assert_eq!(
+ Block::from([129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]) * 2,
+ Block::from([2, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 133]),
+ );
+ assert_eq!(
+ Block::from([129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]) * 3,
+ Block::from([131, 0, 0, 0, 1, 128, 0, 0, 0, 3, 0, 0, 0, 0, 0, 132]),
+ );
+ assert_eq!(
+ Block::from([129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]) * 4,
+ Block::from([4, 0, 0, 0, 2, 0, 0, 0, 0, 4, 0, 0, 0, 0, 1, 10]),
+ );
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 6f4d93f..f5dd3ab 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,15 +1,13 @@
use std::iter;
+mod block;
#[cfg(test)]
mod testvectors;
-type Block = [u8; 16];
+use block::Block;
type Key = [u8; 48];
type Tweak<'a> = &'a [&'a [u8]];
-static NULL: Block = [0; 16];
-static ONE: Block = [128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
-
pub struct Aez {
key: Key,
}
@@ -40,72 +38,14 @@ impl Aez {
}
}
-fn xor(lhs: &Block, rhs: &Block) -> Block {
- let mut result = [0; 16];
- for ((a, b), r) in lhs.iter().zip(rhs.iter()).zip(result.iter_mut()) {
- *r = a ^ b;
- }
- result
-}
-
-fn and(lhs: &Block, rhs: &Block) -> Block {
- let mut result = [0; 16];
- for ((a, b), r) in lhs.iter().zip(rhs.iter()).zip(result.iter_mut()) {
- *r = a & b;
- }
- result
-}
-
-fn or(lhs: &Block, rhs: &Block) -> Block {
- let mut result = [0; 16];
- for ((a, b), r) in lhs.iter().zip(rhs.iter()).zip(result.iter_mut()) {
- *r = a | b;
- }
- result
-}
-
-fn lshift(block: &Block, times: u32) -> Block {
- let mut block = block.clone();
- for _ in 0..times {
- let mut result = [0; 16];
- for (b, r) in block.iter().zip(result.iter_mut()) {
- *r = b << 1;
- }
- for (b, r) in block[1..].iter().zip(result.iter_mut()) {
- *r = *r | ((b & 0x80) >> 7);
- }
- block = result;
- }
- block
-}
-
-fn times(lhs: u32, block: &Block) -> Block {
- match lhs {
- 0 => NULL,
- 1 => *block,
- 2 => {
- let mut result = lshift(block, 1);
- if block[0] & 0x80 != 0 {
- result[15] ^= 135;
- }
- result
- }
- _ if lhs % 2 == 0 => times(2, &times(lhs / 2, block)),
- _ => xor(&times(lhs - 1, block), block),
- }
-}
-
-fn aesenc(mut block: Block, key: &Block) -> Block {
- aes::hazmat::cipher_round((&mut block).into(), key.into());
+fn aesenc(mut block: Block, key: &Block) -> block::Block {
+ aes::hazmat::cipher_round((&mut block.0).into(), &key.0.into());
block
}
fn aes4(keys: &[&Block; 5], block: &Block) -> Block {
aesenc(
- aesenc(
- aesenc(aesenc(xor(block, keys[0]), keys[1]), keys[2]),
- keys[3],
- ),
+ aesenc(aesenc(aesenc(*block ^ *keys[0], keys[1]), keys[2]), keys[3]),
keys[4],
)
}
@@ -119,7 +59,7 @@ fn aes10(keys: &[&Block; 11], block: &Block) -> Block {
aesenc(
aesenc(
aesenc(
- aesenc(aesenc(xor(block, keys[0]), keys[1]), keys[2]),
+ aesenc(aesenc(*block ^ *keys[0], keys[1]), keys[2]),
keys[3],
),
keys[4],
@@ -150,35 +90,6 @@ fn extract(key: &[u8]) -> [u8; 48] {
}
}
-fn clip_to_bits(block: &mut Block, mut bits: usize) {
- for byte in block {
- if bits == 0 {
- *byte = 0;
- } else if bits < 8 {
- *byte &= 0xff << (8 - bits);
- }
- bits = bits.saturating_sub(8);
- }
-}
-
-fn full_block(data: &[u8]) -> Block {
- let mut result = [0; 16];
- result[..data.len()].copy_from_slice(data);
- result
-}
-
-fn pad_block(block: &Block, mut bits: usize) -> Block {
- let mut block = *block;
- for byte in &mut block {
- if bits < 8 {
- *byte |= 0x80 >> bits;
- break;
- }
- bits = bits.saturating_sub(8);
- }
- block
-}
-
fn encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, message: &[u8]) -> Vec<u8> {
let auth_message = message
.iter()
@@ -186,8 +97,8 @@ fn encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, message: &[u8]) -> V
.chain(iter::repeat_n(0, tau as usize))
.collect::<Vec<_>>();
// We treat tau as bytes, but according to the spec, tau is actually in bits.
- let tau_block = tau_to_block(tau * 8);
- let mut tweaks = vec![&tau_block, nonce];
+ let tau_block = Block::from_int(tau * 8);
+ let mut tweaks = vec![&tau_block.0, nonce];
tweaks.extend(ad);
if message.is_empty() {
aez_prf(key, &tweaks, tau)
@@ -201,8 +112,8 @@ fn decrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, ciphertext: &[u8]) -
return None;
}
- let tau_block = tau_to_block(tau * 8);
- let mut tweaks = vec![&tau_block, nonce];
+ let tau_block = Block::from_int(tau * 8);
+ let mut tweaks = vec![&tau_block.0, nonce];
tweaks.extend(ad);
if ciphertext.len() == tau as usize {
@@ -237,7 +148,7 @@ fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
let n = mu / 2;
let delta = aez_hash(key, tweaks);
let round_count = match mu {
- 8 => 24,
+ 8 => 24u32,
16 => 16,
_ if mu < 128 => 10,
_ => 8,
@@ -246,48 +157,34 @@ fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
let (mut left, mut right);
// We might end up having to split at a nibble, so manually adjust for that
if n % 8 == 0 {
- left = full_block(&message[..n / 8]);
- right = full_block(&message[n / 8..]);
+ left = Block::from_slice(&message[..n / 8]);
+ right = Block::from_slice(&message[n / 8..]);
} else {
- left = full_block(&message[..n / 8 + 1]);
- clip_to_bits(&mut left, n);
- right = full_block(&message[n / 8..]);
- right = lshift(&right, 4);
+ assert!(n % 8 == 4);
+ left = Block::from_slice(&message[..n / 8 + 1]).clip(n);
+ right = Block::from_slice(&message[n / 8..]) << 4;
};
let i = if mu >= 128 { 6 } else { 7 };
for j in 0..round_count {
- let mut right_ = xor(
- &left,
- &e(
- 0,
- i,
- key,
- &xor(
- &xor(&delta, &pad_block(&right, n)),
- &(j as u128).to_be_bytes(),
- ),
- ),
- );
- clip_to_bits(&mut right_, n);
+ let right_ = (left ^ e(0, i, key, delta ^ right.pad(n) ^ Block::from_int(j))).clip(n);
(left, right) = (right, right_);
}
let mut ciphertext = Vec::new();
if n % 8 == 0 {
- ciphertext.extend_from_slice(&right[..n / 8]);
- ciphertext.extend_from_slice(&left[..n / 8]);
+ ciphertext.extend_from_slice(&right.0[..n / 8]);
+ ciphertext.extend_from_slice(&left.0[..n / 8]);
} else {
- ciphertext.extend_from_slice(&right[..n / 8 + 1]);
- for byte in &left[..n / 8 + 1] {
+ ciphertext.extend_from_slice(&right.0[..n / 8 + 1]);
+ for byte in &left.0[..n / 8 + 1] {
*ciphertext.last_mut().unwrap() |= byte >> 4;
ciphertext.push((byte & 0x0f) << 4);
}
ciphertext.pop();
}
if mu < 128 {
- let mut c = Block::default();
- c[..ciphertext.len()].copy_from_slice(&ciphertext);
- c = xor(&c, &and(&e(0, 3, key, &xor(&delta, &or(&c, &ONE))), &ONE));
- ciphertext = Vec::from(&c[..mu / 8]);
+ let mut c = Block::from_slice(&ciphertext);
+ c = c ^ (e(0, 3, key, delta ^ (c | Block::ONE)) & Block::ONE);
+ ciphertext = Vec::from(&c.0[..mu / 8]);
}
assert!(ciphertext.len() == message.len());
ciphertext
@@ -304,77 +201,76 @@ fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
for (i, (mi, mi_)) in block_pairs.iter().enumerate() {
let i = (i + 1) as i32;
- let w = xor(mi, &e(1, i, key, mi_));
- let x = xor(mi_, &e(0, 0, key, &w));
+ let w = *mi ^ e(1, i, key, *mi_);
+ let x = *mi_ ^ e(0, 0, key, w);
ws.push(w);
xs.push(x);
}
- let mut x = NULL;
+ let mut x = Block::NULL;
for xi in &xs {
- x = xor(&x, xi);
+ x = x ^ *xi;
}
match d {
0 => (),
_ if d <= 127 => {
- x = xor(&x, &e(0, 4, key, &pad_block(&m_u, d.into())));
+ x = x ^ e(0, 4, key, m_u.pad(d.into()));
}
_ => {
- x = xor(&x, &e(0, 4, key, &m_u));
- x = xor(&x, &e(0, 5, key, &pad_block(&m_v, len_v.into())));
+ x = x ^ e(0, 4, key, m_u);
+ x = x ^ e(0, 5, key, m_v.pad(len_v.into()));
}
}
- let s_x = xor(&m_x, &xor(&delta, &xor(&x, &e(0, 1, key, &m_y))));
- let s_y = xor(&m_y, &e(-1, 1, key, &s_x));
- let s = xor(&s_x, &s_y);
+ let s_x = m_x ^ delta ^ x ^ e(0, 1, key, m_y);
+ let s_y = m_y ^ e(-1, 1, key, s_x);
+ let s = s_x ^ s_y;
let mut cipher_pairs = Vec::new();
- let mut y = NULL;
+ let mut y = Block::NULL;
for (i, (wi, xi)) in ws.iter().zip(xs.iter()).enumerate() {
let i = (i + 1) as i32;
- let s_ = e(2, i, key, &s);
- let yi = xor(wi, &s_);
- let zi = xor(xi, &s_);
- let ci_ = xor(&yi, &e(0, 0, key, &zi));
- let ci = xor(&zi, &e(1, i, key, &ci_));
+ let s_ = e(2, i, key, s);
+ let yi = *wi ^ s_;
+ let zi = *xi ^ s_;
+ let ci_ = yi ^ e(0, 0, key, zi);
+ let ci = zi ^ e(1, i, key, ci_);
cipher_pairs.push((ci, ci_));
- y = xor(&y, &yi);
+ y = y ^ yi;
}
- let mut c_u = [0; 16];
- let mut c_v = [0; 16];
+ let mut c_u = Block::default();
+ let mut c_v = Block::default();
match d {
0 => (),
_ if d <= 127 => {
- c_u = xor(&m_u, &e(-1, 4, key, &s));
- clip_to_bits(&mut c_u, d.into());
- y = xor(&y, &e(0, 4, key, &pad_block(&c_u, d.into())));
+ c_u = (m_u ^ e(-1, 4, key, s)).clip(d.into());
+ y = y ^ e(0, 4, key, c_u.pad(d.into()));
}
_ => {
- c_u = xor(&m_u, &e(-1, 4, key, &s));
- c_v = xor(&m_v, &e(-1, 5, key, &s));
- clip_to_bits(&mut c_v, len_v.into());
- y = xor(&y, &e(0, 4, key, &c_u));
- y = xor(&y, &e(0, 5, key, &pad_block(&c_v, len_v.into())));
+ c_u = m_u ^ e(-1, 4, key, s);
+ c_v = (m_v ^ e(-1, 5, key, s)).clip(len_v.into());
+ y = y ^ e(0, 4, key, c_u);
+ y = y ^ e(0, 5, key, c_v.pad(len_v.into()));
}
}
- let c_y = xor(&s_x, &e(-1, 2, key, &s_y));
- let c_x = xor(&s_y, &xor(&delta, &xor(&y, &e(0, 2, key, &c_y))));
+ let c_y = s_x ^ e(-1, 2, key, s_y);
+ let c_x = s_y ^ delta ^ y ^ e(0, 2, key, c_y);
let mut ciphertext = Vec::new();
for (ci, ci_) in cipher_pairs {
- ciphertext.extend_from_slice(&ci);
- ciphertext.extend_from_slice(&ci_);
+ ciphertext.extend_from_slice(&ci.0);
+ ciphertext.extend_from_slice(&ci_.0);
}
- ciphertext.extend_from_slice(&c_u[..128.min(d) as usize / 8]);
- ciphertext.extend_from_slice(&c_v[..len_v as usize / 8]);
- ciphertext.extend_from_slice(&c_x);
- ciphertext.extend_from_slice(&c_y);
+ ciphertext.extend_from_slice(&c_u.0[..128.min(d) as usize / 8]);
+ ciphertext.extend_from_slice(&c_v.0[..len_v as usize / 8]);
+ ciphertext.extend_from_slice(&c_x.0);
+ ciphertext.extend_from_slice(&c_y.0);
+ assert!(ciphertext.len() == message.len());
ciphertext
}
@@ -392,7 +288,7 @@ fn decipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
let n = mu / 2;
let delta = aez_hash(key, tweaks);
let round_count = match mu {
- 8 => 24,
+ 8 => 24u32,
16 => 16,
_ if mu < 128 => 10,
_ => 8,
@@ -400,48 +296,33 @@ fn decipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
let mut message = Vec::from(message);
if mu < 128 {
- let mut c = Block::default();
- c[..message.len()].copy_from_slice(&message);
- c = xor(&c, &and(&e(0, 3, key, &xor(&delta, &or(&c, &ONE))), &ONE));
+ let mut c = Block::from_slice(&message);
+ c = c ^ (e(0, 3, key, delta ^ (c | Block::ONE)) & Block::ONE);
message.clear();
- message.extend(&c[..mu / 8]);
+ message.extend(&c.0[..mu / 8]);
}
let (mut left, mut right);
// We might end up having to split at a nibble, so manually adjust for that
if n % 8 == 0 {
- left = full_block(&message[..n / 8]);
- right = full_block(&message[n / 8..]);
+ left = Block::from_slice(&message[..n / 8]);
+ right = Block::from_slice(&message[n / 8..]);
} else {
- left = full_block(&message[..n / 8 + 1]);
- clip_to_bits(&mut left, n);
- right = full_block(&message[n / 8..]);
- right = lshift(&right, 4);
+ left = Block::from_slice(&message[..n / 8 + 1]).clip(n);
+ right = Block::from_slice(&message[n / 8..]) << 4;
};
let i = if mu >= 128 { 6 } else { 7 };
for j in (0..round_count).rev() {
- let mut right_ = xor(
- &left,
- &e(
- 0,
- i,
- key,
- &xor(
- &xor(&delta, &pad_block(&right, n)),
- &(j as u128).to_be_bytes(),
- ),
- ),
- );
- clip_to_bits(&mut right_, n);
+ let right_ = (left ^ e(0, i, key, delta ^ right.pad(n) ^ Block::from_int(j))).clip(n);
(left, right) = (right, right_);
}
let mut ciphertext = Vec::new();
if n % 8 == 0 {
- ciphertext.extend_from_slice(&right[..n / 8]);
- ciphertext.extend_from_slice(&left[..n / 8]);
+ ciphertext.extend_from_slice(&right.0[..n / 8]);
+ ciphertext.extend_from_slice(&left.0[..n / 8]);
} else {
- ciphertext.extend_from_slice(&right[..n / 8 + 1]);
- for byte in &left[..n / 8 + 1] {
+ ciphertext.extend_from_slice(&right.0[..n / 8 + 1]);
+ for byte in &left.0[..n / 8 + 1] {
*ciphertext.last_mut().unwrap() |= byte >> 4;
ciphertext.push((byte & 0x0f) << 4);
}
@@ -462,77 +343,75 @@ fn decipher_aez_core(key: &Key, tweaks: Tweak, cipher: &[u8]) -> Vec<u8> {
for (i, (ci, ci_)) in block_pairs.iter().enumerate() {
let i = (i + 1) as i32;
- let w = xor(ci, &e(1, i, key, ci_));
- let y = xor(ci_, &e(0, 0, key, &w));
+ let w = *ci ^ e(1, i, key, *ci_);
+ let y = *ci_ ^ e(0, 0, key, w);
ws.push(w);
ys.push(y);
}
- let mut y = NULL;
+ let mut y = Block::NULL;
for yi in &ys {
- y = xor(&y, yi);
+ y = y ^ *yi;
}
match d {
0 => (),
_ if d <= 127 => {
- y = xor(&y, &e(0, 4, key, &pad_block(&c_u, d.into())));
+ y = y ^ e(0, 4, key, c_u.pad(d.into()));
}
_ => {
- y = xor(&y, &e(0, 4, key, &c_u));
- y = xor(&y, &e(0, 5, key, &pad_block(&c_v, len_v.into())));
+ y = y ^ e(0, 4, key, c_u);
+ y = y ^ e(0, 5, key, c_v.pad(len_v.into()));
}
}
- let s_x = xor(&c_x, &xor(&delta, &xor(&y, &e(0, 2, key, &c_y))));
- let s_y = xor(&c_y, &e(-1, 2, key, &s_x));
- let s = xor(&s_x, &s_y);
+ let s_x = c_x ^ delta ^ y ^ e(0, 2, key, c_y);
+ let s_y = c_y ^ e(-1, 2, key, s_x);
+ let s = s_x ^ s_y;
let mut plain_pairs = Vec::new();
- let mut x = NULL;
+ let mut x = Block::NULL;
for (i, (wi, yi)) in ws.iter().zip(ys.iter()).enumerate() {
let i = (i + 1) as i32;
- let s_ = e(2, i, key, &s);
- let xi = xor(wi, &s_);
- let zi = xor(yi, &s_);
- let mi_ = xor(&xi, &e(0, 0, key, &zi));
- let mi = xor(&zi, &e(1, i, key, &mi_));
+ let s_ = e(2, i, key, s);
+ let xi = *wi ^ s_;
+ let zi = *yi ^ s_;
+ let mi_ = xi ^ e(0, 0, key, zi);
+ let mi = zi ^ e(1, i, key, mi_);
plain_pairs.push((mi, mi_));
- x = xor(&x, &xi);
+ x = x ^ xi;
}
- let mut m_u = [0; 16];
- let mut m_v = [0; 16];
+ let mut m_u = Block::default();
+ let mut m_v = Block::default();
match d {
0 => (),
_ if d <= 127 => {
- m_u = xor(&c_u, &e(-1, 4, key, &s));
- clip_to_bits(&mut m_u, d.into());
- x = xor(&x, &e(0, 4, key, &pad_block(&m_u, d.into())));
+ m_u = (c_u ^ e(-1, 4, key, s)).clip(d.into());
+ x = x ^ e(0, 4, key, m_u.pad(d.into()));
}
_ => {
- m_u = xor(&c_u, &e(-1, 4, key, &s));
- m_v = xor(&c_v, &e(-1, 5, key, &s));
- clip_to_bits(&mut m_v, len_v.into());
- x = xor(&x, &e(0, 4, key, &m_u));
- x = xor(&x, &e(0, 5, key, &pad_block(&m_v, len_v.into())));
+ m_u = c_u ^ e(-1, 4, key, s);
+ m_v = (c_v ^ e(-1, 5, key, s)).clip(len_v.into());
+ x = x ^ e(0, 4, key, m_u);
+ x = x ^ e(0, 5, key, m_v.pad(len_v.into()));
}
}
- let m_y = xor(&s_x, &e(-1, 1, key, &s_y));
- let m_x = xor(&s_y, &xor(&delta, &xor(&x, &e(0, 1, key, &m_y))));
+ let m_y = s_x ^ e(-1, 1, key, s_y);
+ let m_x = s_y ^ delta ^ x ^ e(0, 1, key, m_y);
let mut message = Vec::new();
for (mi, mi_) in plain_pairs {
- message.extend_from_slice(&mi);
- message.extend_from_slice(&mi_);
+ message.extend_from_slice(&mi.0);
+ message.extend_from_slice(&mi_.0);
}
- message.extend_from_slice(&m_u[..128.min(d) as usize / 8]);
- message.extend_from_slice(&m_v[..len_v as usize / 8]);
- message.extend_from_slice(&m_x);
- message.extend_from_slice(&m_y);
+ message.extend_from_slice(&m_u.0[..128.min(d) as usize / 8]);
+ message.extend_from_slice(&m_v.0[..len_v as usize / 8]);
+ message.extend_from_slice(&m_x.0);
+ message.extend_from_slice(&m_y.0);
message
}
@@ -540,9 +419,8 @@ fn split_blocks(mut message: &[u8]) -> (Vec<(Block, Block)>, Block, Block, Block
let num_blocks = (message.len() - 16 - 16) / 32;
let mut blocks = Vec::new();
for _ in 0..num_blocks {
- let (mut a, mut b) = ([0; 16], [0; 16]);
- a.copy_from_slice(&message[..16]);
- b.copy_from_slice(&message[16..32]);
+ let a = Block::from_slice(&message[..16]);
+ let b = Block::from_slice(&message[16..32]);
blocks.push((a, b));
message = &message[32..];
}
@@ -552,18 +430,17 @@ fn split_blocks(mut message: &[u8]) -> (Vec<(Block, Block)>, Block, Block, Block
message = &message[m_uv.len()..];
assert!(message.len() == 32);
- let mut m_u = [0; 16];
- let mut m_v = [0; 16];
+ let m_u;
+ let m_v;
if d <= 127 {
- m_u[..m_uv.len()].copy_from_slice(m_uv);
+ m_u = Block::from_slice(m_uv);
+ m_v = Block::default();
} else {
- m_u.copy_from_slice(&m_uv[..16]);
- m_v[..m_uv.len() - 16].copy_from_slice(&m_uv[16..]);
+ m_u = Block::from_slice(&m_uv[..16]);
+ m_v = Block::from_slice(&m_uv[16..]);
}
- let mut m_x = [0; 16];
- m_x.copy_from_slice(&message[..16]);
- let mut m_y = [0; 16];
- m_y.copy_from_slice(&message[16..]);
+ let m_x = Block::from_slice(&message[..16]);
+ let m_y = Block::from_slice(&message[16..]);
(blocks, m_u, m_v, m_x, m_y, d as u8)
}
@@ -571,29 +448,16 @@ fn pad_to_blocks(value: &[u8]) -> Vec<Block> {
let mut blocks = Vec::new();
for chunk in value.chunks(16) {
if chunk.len() == 16 {
- blocks.push(chunk.try_into().expect("we made sure the length fits"));
+ blocks.push(Block::from_slice(chunk));
} else {
- let mut block = Block::default();
- for (b, v) in block.iter_mut().zip(
- chunk
- .iter()
- .chain(iter::once(&0x80))
- .chain(iter::repeat(&0)),
- ) {
- *b = *v;
- }
- blocks.push(block)
+ blocks.push(Block::from_slice(chunk).pad(chunk.len() * 8));
}
}
blocks
}
-fn tau_to_block(tau: u32) -> Block {
- (tau as u128).to_be_bytes()
-}
-
fn aez_hash(key: &Key, tweaks: Tweak) -> Block {
- let mut hash = NULL;
+ let mut hash = Block::NULL;
for (i, tweak) in tweaks.iter().enumerate() {
// Adjust for zero-based vs one-based indexing
let j = i + 2 + 1;
@@ -601,33 +465,22 @@ fn aez_hash(key: &Key, tweaks: Tweak) -> Block {
// set l = 1 and then xor E_K^{j, 0}(10*). We could modify the last if branch to cover this
// as well, but then we need to fiddle with getting an empty chunk from an empty iterator.
if tweak.is_empty() {
- hash = xor(
- &hash,
- &e(
- j.try_into().unwrap(),
- 0,
- key,
- &[128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- ),
- );
+ hash = hash ^ e(j.try_into().unwrap(), 0, key, Block::ONE);
} else if tweak.len() % 16 == 0 {
for (l, chunk) in tweak.chunks(16).enumerate() {
- hash = xor(
- &hash,
- &e(
+ hash = hash
+ ^ e(
j.try_into().unwrap(),
(l + 1).try_into().unwrap(),
key,
- chunk.try_into().expect("we made sure the length fits"),
- ),
- );
+ Block::from_slice(chunk),
+ );
}
} else {
let blocks = pad_to_blocks(tweak);
for (l, chunk) in blocks.iter().enumerate() {
- hash = xor(
- &hash,
- &e(
+ hash = hash
+ ^ e(
j.try_into().unwrap(),
if l == blocks.len() - 1 {
0
@@ -635,9 +488,8 @@ fn aez_hash(key: &Key, tweaks: Tweak) -> Block {
(l + 1).try_into().unwrap()
},
key,
- chunk,
- ),
- );
+ *chunk,
+ );
}
}
}
@@ -649,146 +501,52 @@ fn aez_prf(key: &Key, tweaks: Tweak, tau: u32) -> Vec<u8> {
let mut index = 0u128;
let delta = aez_hash(key, tweaks);
while result.len() < tau as usize {
- let block = e(-1, 3, key, &xor(&delta, &index.to_be_bytes()));
- result.extend_from_slice(&block[..16.min(tau as usize - result.len())]);
+ let block = e(-1, 3, key, delta ^ Block::from_int(index));
+ result.extend_from_slice(&block.0[..16.min(tau as usize - result.len())]);
index += 1;
}
result
}
-fn e(j: i32, i: i32, key: &Key, block: &Block) -> Block {
+fn e(j: i32, i: i32, key: &Key, block: Block) -> Block {
let (key_i, key_j, key_l) = split_key(key);
if j == -1 {
let k = [
- &NULL, key_i, key_j, key_l, key_i, key_j, key_l, key_i, key_j, key_l, key_i,
+ &Block::NULL,
+ &key_i,
+ &key_j,
+ &key_l,
+ &key_i,
+ &key_j,
+ &key_l,
+ &key_i,
+ &key_j,
+ &key_l,
+ &key_i,
];
- let delta = times(i.try_into().expect("i was negative"), key_l);
- aes10(&k, &xor(block, &delta))
+ let delta = key_l * i.try_into().expect("i was negative");
+ aes10(&k, &(block ^ delta))
} else {
- let k = [&NULL, key_j, key_i, key_l, &NULL];
+ let k = [&Block::NULL, &key_j, &key_i, &key_l, &Block::NULL];
let exponent = if i % 8 == 0 { i / 8 } else { i / 8 + 1 };
- let delta = xor(
- &xor(
- &times(j.try_into().expect("j was negative"), key_j),
- &times(1 << exponent, key_i),
- ),
- &times((i % 8).try_into().expect("i was negative"), key_l),
- );
- aes4(&k, &xor(block, &delta))
+ let j: u32 = j.try_into().expect("j was negative");
+ let i: u32 = i.try_into().expect("i was negative");
+ let delta = (key_j * j) ^ (key_i * (1 << exponent)) ^ (key_l * (i % 8));
+ aes4(&k, &(block ^ delta))
}
}
-fn split_key(key: &Key) -> (&Block, &Block, &Block) {
- let (i, jl) = key.split_at(16);
- let (j, l) = jl.split_at(16);
+fn split_key(key: &Key) -> (Block, Block, Block) {
(
- i.try_into().unwrap(),
- j.try_into().unwrap(),
- l.try_into().unwrap(),
+ Block::from_slice(&key[..16]),
+ Block::from_slice(&key[16..32]),
+ Block::from_slice(&key[32..]),
)
}
#[cfg(test)]
mod test {
use super::*;
-
- #[test]
- fn test_xor() {
- assert_eq!(xor(&[1; 16], &[2; 16]), [3; 16]);
- }
-
- #[test]
- fn test_times() {
- assert_eq!(
- times(0, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]),
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- );
- assert_eq!(
- times(1, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]),
- [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
- );
- assert_eq!(
- times(2, &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]),
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]
- );
- assert_eq!(
- times(2, &[128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]),
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 133]
- );
- assert_eq!(
- times(2, &[129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]),
- [2, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 133]
- );
- assert_eq!(
- times(3, &[129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]),
- [131, 0, 0, 0, 1, 128, 0, 0, 0, 3, 0, 0, 0, 0, 0, 132]
- );
- assert_eq!(
- times(4, &[129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]),
- [4, 0, 0, 0, 2, 0, 0, 0, 0, 4, 0, 0, 0, 0, 1, 10]
- );
- }
-
- #[test]
- fn test_lshift() {
- assert_eq!(
- lshift(&[0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 1),
- [0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- );
- assert_eq!(
- lshift(&[0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 4),
- [0x10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- );
- assert_eq!(
- lshift(&[0x0A, 0xB0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 4),
- [0xAB, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- );
- assert_eq!(
- lshift(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 8),
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]
- );
- }
-
- #[test]
- fn test_pad_block() {
- assert_eq!(
- pad_block(&[0; 16], 0),
- [0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- );
- assert_eq!(
- pad_block(&[0; 16], 1),
- [0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- );
- assert_eq!(
- pad_block(&[0; 16], 8),
- [0, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- );
- }
-
- #[test]
- fn test_clip_to_bits() {
- let mut block;
-
- block = [0xFF; 16];
- clip_to_bits(&mut block, 0);
- assert_eq!(block, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
-
- block = [0xFF; 16];
- clip_to_bits(&mut block, 4);
- assert_eq!(block, [0xF0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
-
- block = [0xFF; 16];
- clip_to_bits(&mut block, 8);
- assert_eq!(block, [0xFF, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
-
- block = [0xFF; 16];
- clip_to_bits(&mut block, 9);
- assert_eq!(
- block,
- [0xFF, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- );
- }
-
#[test]
fn test_extract() {
for (a, b) in testvectors::EXTRACT_VECTORS {
@@ -805,9 +563,9 @@ mod test {
let k = hex::decode(k).unwrap();
let k = k.as_slice().try_into().unwrap();
let a = hex::decode(a).unwrap();
- let a = a.as_slice().try_into().unwrap();
+ let a = Block::from_slice(&a);
let b = hex::decode(b).unwrap();
- assert_eq!(&e(*j, *i, k, a), b.as_slice(), "{name}");
+ assert_eq!(&e(*j, *i, k, a).0, b.as_slice(), "{name}");
}
}
@@ -819,13 +577,13 @@ mod test {
let k = k.as_slice().try_into().unwrap();
let v = hex::decode(v).unwrap();
- let mut tweaks = vec![Vec::from(tau_to_block(*tau))];
+ let mut tweaks = vec![Vec::from(Block::from_int(*tau).0)];
for t in *tw {
tweaks.push(hex::decode(t).unwrap());
}
let tweaks = tweaks.iter().map(Vec::as_slice).collect::<Vec<_>>();
- assert_eq!(&aez_hash(&k, &tweaks), v.as_slice(), "{name}");
+ assert_eq!(&aez_hash(&k, &tweaks).0, v.as_slice(), "{name}");
}
}
@@ -849,10 +607,10 @@ mod test {
let c = hex::decode(c).unwrap();
if &encrypt(&k, &n, &ad, *tau, &m) == &c {
- println!("+ {name}");
+ //println!("+ {name}");
succ += 1;
} else {
- println!("- {name}");
+ println!("- {}", c.len());
failed += 1;
}
}
@@ -898,4 +656,13 @@ mod test {
let plain = aez.decrypt(&[0], &[b"foobar"], 16, &cipher).unwrap();
assert_eq!(plain, b"hi");
}
+
+ #[test]
+ fn test_encrypt_decrypt_long() {
+ let message = b"ene mene miste es rappelt in der kiste ene mene meck und du bist weg";
+ let aez = Aez::new(b"foobar");
+ let cipher = aez.encrypt(&[0], &[b"foobar"], 16, message);
+ let plain = aez.decrypt(&[0], &[b"foobar"], 16, &cipher).unwrap();
+ assert_eq!(plain, message);
+ }
}