aboutsummaryrefslogtreecommitdiff
path: root/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib.rs')
-rw-r--r--src/lib.rs63
1 files changed, 37 insertions, 26 deletions
diff --git a/src/lib.rs b/src/lib.rs
index 41b8d72..af95fe3 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -642,11 +642,11 @@ fn cipher_aez_core<A: AsRef<[u8]>, T: IntoIterator<Item = A>>(
match mode {
Mode::Encipher => {
s_x = m_x ^ delta ^ x ^ e(0, 1, aez, m_y);
- s_y = m_y ^ e(-1, 1, aez, s_x);
+ s_y = m_y ^ e_neg(1, aez, s_x);
}
Mode::Decipher => {
s_x = m_x ^ delta ^ x ^ e(0, 2, aez, m_y);
- s_y = m_y ^ e(-1, 2, aez, s_x);
+ s_y = m_y ^ e_neg(2, aez, s_x);
}
}
let s = s_x ^ s_y;
@@ -659,12 +659,12 @@ fn cipher_aez_core<A: AsRef<[u8]>, T: IntoIterator<Item = A>>(
match d {
0 => (),
_ if d <= 127 => {
- c_u = (m_u ^ e(-1, 4, aez, s)).clip(d);
+ c_u = (m_u ^ e_neg(4, aez, s)).clip(d);
y = y ^ e(0, 4, aez, c_u.pad(d));
}
_ => {
- c_u = m_u ^ e(-1, 4, aez, s);
- c_v = (m_v ^ e(-1, 5, aez, s)).clip(len_v);
+ c_u = m_u ^ e_neg(4, aez, s);
+ c_v = (m_v ^ e_neg(5, aez, s)).clip(len_v);
y = y ^ e(0, 4, aez, c_u);
y = y ^ e(0, 5, aez, c_v.pad(len_v));
}
@@ -673,11 +673,11 @@ fn cipher_aez_core<A: AsRef<[u8]>, T: IntoIterator<Item = A>>(
let (c_x, c_y);
match mode {
Mode::Encipher => {
- c_y = s_x ^ e(-1, 2, aez, s_y);
+ c_y = s_x ^ e_neg(2, aez, s_y);
c_x = s_y ^ delta ^ y ^ e(0, 2, aez, c_y);
}
Mode::Decipher => {
- c_y = s_x ^ e(-1, 1, aez, s_y);
+ c_y = s_x ^ e_neg(1, aez, s_y);
c_x = s_y ^ delta ^ y ^ e(0, 1, aez, c_y);
}
}
@@ -728,13 +728,13 @@ fn aez_prf<A: AsRef<[u8]>, T: IntoIterator<Item = A>>(aez: &Aez, tweaks: T, buff
let delta = aez_hash(aez, tweaks);
for chunk in buffer.chunks_exact_mut(16) {
let chunk: &mut [u8; 16] = chunk.try_into().unwrap();
- let block = e(-1, 3, aez, delta ^ index);
+ let block = e_neg(3, aez, delta ^ index);
(block ^ Block::from(*chunk)).write_to(chunk);
index.count_up();
}
let suffix_start = buffer.len() - buffer.len() % 16;
let chunk = &mut buffer[suffix_start..];
- let block = e(-1, 3, aez, delta ^ index);
+ let block = e_neg(3, aez, delta ^ index);
for (a, b) in chunk.iter_mut().zip(block.bytes().iter()) {
*a ^= *b;
}
@@ -746,16 +746,14 @@ fn aez_prf<A: AsRef<[u8]>, T: IntoIterator<Item = A>>(aez: &Aez, tweaks: T, buff
/// temporary values and makes it much faster to compute E_K^{j, i+1}, E_K^{j, i+2}, ...
struct E<'a> {
aez: &'a Aez,
- i: u32,
+ i: usize,
kj_t_j: Block,
ki_p_i: Block,
}
impl<'a> E<'a> {
/// Create a new "suspended" computation of E_K^{j,i}.
- fn new(j: i32, i: u32, aez: &'a Aez) -> Self {
- assert!(j >= 0);
- let j: u32 = j.try_into().expect("j was negative");
+ fn new(j: usize, i: usize, aez: &'a Aez) -> Self {
let exponent = if i % 8 == 0 { i / 8 } else { i / 8 + 1 };
E {
aez,
@@ -767,7 +765,7 @@ impl<'a> E<'a> {
/// Complete this computation to evaluate E_K^{j,i}(block).
fn eval(&self, block: Block) -> Block {
- let delta = self.kj_t_j ^ self.ki_p_i ^ self.aez.key_l_multiples[self.i as usize % 8];
+ let delta = self.kj_t_j ^ self.ki_p_i ^ self.aez.key_l_multiples[self.i % 8];
self.aez.aes.aes4(block ^ delta)
}
@@ -862,17 +860,18 @@ impl<'a> Iterator for Eiter<'a> {
}
/// Shorthand to get E_K^{j,i}(block)
-fn e(j: i32, i: u32, aez: &Aez, block: Block) -> Block {
- if j == -1 {
- let delta = if i < 8 {
- aez.key_l_multiples[i as usize]
- } else {
- aez.key_l * i
- };
- aez.aes.aes10(block ^ delta)
+fn e(j: usize, i: usize, aez: &Aez, block: Block) -> Block {
+ E::new(j, i, aez).eval(block)
+}
+
+/// Computes E_K^{-1,i}(block)
+fn e_neg(i: usize, aez: &Aez, block: Block) -> Block {
+ let delta = if i < 8 {
+ aez.key_l_multiples[i]
} else {
- E::new(j, i, aez).eval(block)
- }
+ aez.key_l * i
+ };
+ aez.aes.aes10(block ^ delta)
}
fn split_key(key: &Key) -> (Block, Block, Block) {
@@ -920,7 +919,7 @@ pub mod primitives {
super::aez_prf(aez, tweaks, buffer)
}
- pub fn e(j: i32, i: u32, aez: &Aez, block: [u8; 16]) -> [u8; 16] {
+ pub fn e(j: usize, i: usize, aez: &Aez, block: [u8; 16]) -> [u8; 16] {
super::e(j, i, aez, Block::from(block)).bytes()
}
}
@@ -949,7 +948,19 @@ mod test {
let a = hex::decode(a).unwrap();
let a = Block::from_slice(&a);
let b = hex::decode(b).unwrap();
- assert_eq!(&e(*j, *i, &aez, a).bytes(), b.as_slice(), "{name}");
+ if *j >= 0 {
+ assert_eq!(
+ &e((*j).try_into().unwrap(), (*i).try_into().unwrap(), &aez, a).bytes(),
+ b.as_slice(),
+ "{name}"
+ );
+ } else {
+ assert_eq!(
+ &e_neg((*i).try_into().unwrap(), &aez, a).bytes(),
+ b.as_slice(),
+ "{name}"
+ );
+ }
}
}