diff options
| author | Daniel Schadt <kingdread@gmx.de> | 2025-04-05 19:03:10 +0200 | 
|---|---|---|
| committer | Daniel Schadt <kingdread@gmx.de> | 2025-04-05 19:03:10 +0200 | 
| commit | 71cdf50525f0cbb70673477510050669206df7f2 (patch) | |
| tree | 41e58ce93318dfaaf8f2c4f4dd91b879ead378af | |
| parent | 5cd9e4a71f0561d599ce5c7d498828ef5b8db2bb (diff) | |
| download | zears-71cdf50525f0cbb70673477510050669206df7f2.tar.gz zears-71cdf50525f0cbb70673477510050669206df7f2.tar.bz2 zears-71cdf50525f0cbb70673477510050669206df7f2.zip | |
use proper Block struct and operator overloading
| -rw-r--r-- | src/block.rs | 221 | ||||
| -rw-r--r-- | src/lib.rs | 565 | 
2 files changed, 387 insertions, 399 deletions
| diff --git a/src/block.rs b/src/block.rs new file mode 100644 index 0000000..e63062e --- /dev/null +++ b/src/block.rs @@ -0,0 +1,221 @@ +use std::ops::{BitAnd, BitOr, BitXor, Index, IndexMut, Mul, Shl, Shr}; + +/// A block, the unit of work that AEZ divides the message into. +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Block(pub [u8; 16]); + +impl Block { +    pub const NULL: Block = Block([0; 16]); +    pub const ONE: Block = Block([0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); + +    /// Create a block from a slice. +    /// +    /// If the slice is too long, it will be truncated. If the slice is too short, the remaining +    /// items are set to 0. +    pub fn from_slice(value: &[u8]) -> Self { +        let len = value.len().min(16); +        let mut array = [0; 16]; +        array[..len].copy_from_slice(&value[..len]); +        Block(array) +    } + +    /// Constructs a block representing the given integer. +    /// +    /// This corresponds to [x]_128 in the paper. +    pub fn from_int<I: Into<u128>>(value: I) -> Self { +        Block(value.into().to_be_bytes()) +    } + +    pub fn to_int(&self) -> u128 { +        u128::from_be_bytes(self.0) +    } + +    /// Pad the block to full length. +    /// +    /// The given length is the current length. +    /// +    /// This corresponds to X10* in the paper. +    pub fn pad(&self, length: usize) -> Block { +        assert!(length <= 127); +        Block::from_int(self.to_int() | (1 << (127 - length))) +    } + +    /// Clip the block by setting all bits beyond the given length to 0. +    pub fn clip(&self, mut length: usize) -> Block { +        let mut block = self.0; +        for byte in &mut block { +            if length == 0 { +                *byte = 0; +            } else if length < 8 { +                *byte &= 0xff << (8 - length); +            } +            length = length.saturating_sub(8); +        } +        Block(block) +    } +} + +impl From<[u8; 16]> for Block { +    fn from(value: [u8; 16]) -> Block { +        Block(value) +    } +} + +impl From<&[u8; 16]> for Block { +    fn from(value: &[u8; 16]) -> Block { +        Block(*value) +    } +} + +impl From<u128> for Block { +    fn from(value: u128) -> Block { +        Block(value.to_be_bytes()) +    } +} + +impl BitXor<Block> for Block { +    type Output = Block; +    fn bitxor(self, rhs: Block) -> Block { +        Block::from(self.to_int() ^ rhs.to_int()) +    } +} + +impl Shl<u32> for Block { +    type Output = Block; +    fn shl(self, rhs: u32) -> Block { +        Block::from(self.to_int() << rhs) +    } +} + +impl Shr<u32> for Block { +    type Output = Block; +    fn shr(self, rhs: u32) -> Block { +        Block::from(self.to_int() >> rhs) +    } +} + +impl BitAnd<Block> for Block { +    type Output = Block; +    fn bitand(self, rhs: Block) -> Block { +        Block::from(self.to_int() & rhs.to_int()) +    } +} + +impl BitOr<Block> for Block { +    type Output = Block; +    fn bitor(self, rhs: Block) -> Block { +        Block::from(self.to_int() | rhs.to_int()) +    } +} + +impl Index<usize> for Block { +    type Output = u8; +    fn index(&self, index: usize) -> &u8 { +        &self.0[index] +    } +} + +impl IndexMut<usize> for Block { +    fn index_mut(&mut self, index: usize) -> &mut u8 { +        &mut self.0[index] +    } +} + +impl Mul<u32> for Block { +    type Output = Block; +    fn mul(self, rhs: u32) -> Block { +        match rhs { +            0 => Block::NULL, +            1 => self, +            2 => { +                let mut result = self << 1; +                if self[0] & 0x80 != 0 { +                    result[15] ^= 135; +                } +                result +            } +            _ if rhs % 2 == 0 => self * 2 * (rhs / 2), +            _ => self * (rhs - 1) ^ self, +        } +    } +} + +#[cfg(test)] +mod test { +    use super::*; +    #[test] +    fn test_xor() { +        assert_eq!( +            Block::from([1; 16]) ^ Block::from([2; 16]), +            Block::from([3; 16]) +        ); +    } + +    #[test] +    fn test_pad() { +        assert_eq!( +            Block::from([0; 16]).pad(0), +            Block::from([0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), +        ); +        assert_eq!( +            Block::from([0; 16]).pad(1), +            Block::from([0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), +        ); +        assert_eq!( +            Block::from([0; 16]).pad(8), +            Block::from([0, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), +        ); +    } + +    #[test] +    fn test_shl() { +        assert_eq!( +            Block::from([0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) << 1, +            Block::from([0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), +        ); +        assert_eq!( +            Block::from([0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) << 4, +            Block::from([0x10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), +        ); +        assert_eq!( +            Block::from([0x0A, 0xB0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) << 4, +            Block::from([0xAB, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), +        ); +        assert_eq!( +            Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]) << 8, +            Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]), +        ); +    } + +    #[test] +    fn test_times() { +        assert_eq!( +            Block::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]) * 0, +            Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), +        ); +        assert_eq!( +            Block::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]) * 1, +            Block::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), +        ); +        assert_eq!( +            Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]) * 2, +            Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]), +        ); +        assert_eq!( +            Block::from([128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]) * 2, +            Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 133]), +        ); +        assert_eq!( +            Block::from([129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]) * 2, +            Block::from([2, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 133]), +        ); +        assert_eq!( +            Block::from([129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]) * 3, +            Block::from([131, 0, 0, 0, 1, 128, 0, 0, 0, 3, 0, 0, 0, 0, 0, 132]), +        ); +        assert_eq!( +            Block::from([129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]) * 4, +            Block::from([4, 0, 0, 0, 2, 0, 0, 0, 0, 4, 0, 0, 0, 0, 1, 10]), +        ); +    } +} @@ -1,15 +1,13 @@  use std::iter; +mod block;  #[cfg(test)]  mod testvectors; -type Block = [u8; 16]; +use block::Block;  type Key = [u8; 48];  type Tweak<'a> = &'a [&'a [u8]]; -static NULL: Block = [0; 16]; -static ONE: Block = [128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; -  pub struct Aez {      key: Key,  } @@ -40,72 +38,14 @@ impl Aez {      }  } -fn xor(lhs: &Block, rhs: &Block) -> Block { -    let mut result = [0; 16]; -    for ((a, b), r) in lhs.iter().zip(rhs.iter()).zip(result.iter_mut()) { -        *r = a ^ b; -    } -    result -} - -fn and(lhs: &Block, rhs: &Block) -> Block { -    let mut result = [0; 16]; -    for ((a, b), r) in lhs.iter().zip(rhs.iter()).zip(result.iter_mut()) { -        *r = a & b; -    } -    result -} - -fn or(lhs: &Block, rhs: &Block) -> Block { -    let mut result = [0; 16]; -    for ((a, b), r) in lhs.iter().zip(rhs.iter()).zip(result.iter_mut()) { -        *r = a | b; -    } -    result -} - -fn lshift(block: &Block, times: u32) -> Block { -    let mut block = block.clone(); -    for _ in 0..times { -        let mut result = [0; 16]; -        for (b, r) in block.iter().zip(result.iter_mut()) { -            *r = b << 1; -        } -        for (b, r) in block[1..].iter().zip(result.iter_mut()) { -            *r = *r | ((b & 0x80) >> 7); -        } -        block = result; -    } -    block -} - -fn times(lhs: u32, block: &Block) -> Block { -    match lhs { -        0 => NULL, -        1 => *block, -        2 => { -            let mut result = lshift(block, 1); -            if block[0] & 0x80 != 0 { -                result[15] ^= 135; -            } -            result -        } -        _ if lhs % 2 == 0 => times(2, ×(lhs / 2, block)), -        _ => xor(×(lhs - 1, block), block), -    } -} - -fn aesenc(mut block: Block, key: &Block) -> Block { -    aes::hazmat::cipher_round((&mut block).into(), key.into()); +fn aesenc(mut block: Block, key: &Block) -> block::Block { +    aes::hazmat::cipher_round((&mut block.0).into(), &key.0.into());      block  }  fn aes4(keys: &[&Block; 5], block: &Block) -> Block {      aesenc( -        aesenc( -            aesenc(aesenc(xor(block, keys[0]), keys[1]), keys[2]), -            keys[3], -        ), +        aesenc(aesenc(aesenc(*block ^ *keys[0], keys[1]), keys[2]), keys[3]),          keys[4],      )  } @@ -119,7 +59,7 @@ fn aes10(keys: &[&Block; 11], block: &Block) -> Block {                          aesenc(                              aesenc(                                  aesenc( -                                    aesenc(aesenc(xor(block, keys[0]), keys[1]), keys[2]), +                                    aesenc(aesenc(*block ^ *keys[0], keys[1]), keys[2]),                                      keys[3],                                  ),                                  keys[4], @@ -150,35 +90,6 @@ fn extract(key: &[u8]) -> [u8; 48] {      }  } -fn clip_to_bits(block: &mut Block, mut bits: usize) { -    for byte in block { -        if bits == 0 { -            *byte = 0; -        } else if bits < 8 { -            *byte &= 0xff << (8 - bits); -        } -        bits = bits.saturating_sub(8); -    } -} - -fn full_block(data: &[u8]) -> Block { -    let mut result = [0; 16]; -    result[..data.len()].copy_from_slice(data); -    result -} - -fn pad_block(block: &Block, mut bits: usize) -> Block { -    let mut block = *block; -    for byte in &mut block { -        if bits < 8 { -            *byte |= 0x80 >> bits; -            break; -        } -        bits = bits.saturating_sub(8); -    } -    block -} -  fn encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, message: &[u8]) -> Vec<u8> {      let auth_message = message          .iter() @@ -186,8 +97,8 @@ fn encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, message: &[u8]) -> V          .chain(iter::repeat_n(0, tau as usize))          .collect::<Vec<_>>();      // We treat tau as bytes, but according to the spec, tau is actually in bits. -    let tau_block = tau_to_block(tau * 8); -    let mut tweaks = vec![&tau_block, nonce]; +    let tau_block = Block::from_int(tau * 8); +    let mut tweaks = vec![&tau_block.0, nonce];      tweaks.extend(ad);      if message.is_empty() {          aez_prf(key, &tweaks, tau) @@ -201,8 +112,8 @@ fn decrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, ciphertext: &[u8]) -          return None;      } -    let tau_block = tau_to_block(tau * 8); -    let mut tweaks = vec![&tau_block, nonce]; +    let tau_block = Block::from_int(tau * 8); +    let mut tweaks = vec![&tau_block.0, nonce];      tweaks.extend(ad);      if ciphertext.len() == tau as usize { @@ -237,7 +148,7 @@ fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {      let n = mu / 2;      let delta = aez_hash(key, tweaks);      let round_count = match mu { -        8 => 24, +        8 => 24u32,          16 => 16,          _ if mu < 128 => 10,          _ => 8, @@ -246,48 +157,34 @@ fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {      let (mut left, mut right);      // We might end up having to split at a nibble, so manually adjust for that      if n % 8 == 0 { -        left = full_block(&message[..n / 8]); -        right = full_block(&message[n / 8..]); +        left = Block::from_slice(&message[..n / 8]); +        right = Block::from_slice(&message[n / 8..]);      } else { -        left = full_block(&message[..n / 8 + 1]); -        clip_to_bits(&mut left, n); -        right = full_block(&message[n / 8..]); -        right = lshift(&right, 4); +        assert!(n % 8 == 4); +        left = Block::from_slice(&message[..n / 8 + 1]).clip(n); +        right = Block::from_slice(&message[n / 8..]) << 4;      };      let i = if mu >= 128 { 6 } else { 7 };      for j in 0..round_count { -        let mut right_ = xor( -            &left, -            &e( -                0, -                i, -                key, -                &xor( -                    &xor(&delta, &pad_block(&right, n)), -                    &(j as u128).to_be_bytes(), -                ), -            ), -        ); -        clip_to_bits(&mut right_, n); +        let right_ = (left ^ e(0, i, key, delta ^ right.pad(n) ^ Block::from_int(j))).clip(n);          (left, right) = (right, right_);      }      let mut ciphertext = Vec::new();      if n % 8 == 0 { -        ciphertext.extend_from_slice(&right[..n / 8]); -        ciphertext.extend_from_slice(&left[..n / 8]); +        ciphertext.extend_from_slice(&right.0[..n / 8]); +        ciphertext.extend_from_slice(&left.0[..n / 8]);      } else { -        ciphertext.extend_from_slice(&right[..n / 8 + 1]); -        for byte in &left[..n / 8 + 1] { +        ciphertext.extend_from_slice(&right.0[..n / 8 + 1]); +        for byte in &left.0[..n / 8 + 1] {              *ciphertext.last_mut().unwrap() |= byte >> 4;              ciphertext.push((byte & 0x0f) << 4);          }          ciphertext.pop();      }      if mu < 128 { -        let mut c = Block::default(); -        c[..ciphertext.len()].copy_from_slice(&ciphertext); -        c = xor(&c, &and(&e(0, 3, key, &xor(&delta, &or(&c, &ONE))), &ONE)); -        ciphertext = Vec::from(&c[..mu / 8]); +        let mut c = Block::from_slice(&ciphertext); +        c = c ^ (e(0, 3, key, delta ^ (c | Block::ONE)) & Block::ONE); +        ciphertext = Vec::from(&c.0[..mu / 8]);      }      assert!(ciphertext.len() == message.len());      ciphertext @@ -304,77 +201,76 @@ fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {      for (i, (mi, mi_)) in block_pairs.iter().enumerate() {          let i = (i + 1) as i32; -        let w = xor(mi, &e(1, i, key, mi_)); -        let x = xor(mi_, &e(0, 0, key, &w)); +        let w = *mi ^ e(1, i, key, *mi_); +        let x = *mi_ ^ e(0, 0, key, w);          ws.push(w);          xs.push(x);      } -    let mut x = NULL; +    let mut x = Block::NULL;      for xi in &xs { -        x = xor(&x, xi); +        x = x ^ *xi;      }      match d {          0 => (),          _ if d <= 127 => { -            x = xor(&x, &e(0, 4, key, &pad_block(&m_u, d.into()))); +            x = x ^ e(0, 4, key, m_u.pad(d.into()));          }          _ => { -            x = xor(&x, &e(0, 4, key, &m_u)); -            x = xor(&x, &e(0, 5, key, &pad_block(&m_v, len_v.into()))); +            x = x ^ e(0, 4, key, m_u); +            x = x ^ e(0, 5, key, m_v.pad(len_v.into()));          }      } -    let s_x = xor(&m_x, &xor(&delta, &xor(&x, &e(0, 1, key, &m_y)))); -    let s_y = xor(&m_y, &e(-1, 1, key, &s_x)); -    let s = xor(&s_x, &s_y); +    let s_x = m_x ^ delta ^ x ^ e(0, 1, key, m_y); +    let s_y = m_y ^ e(-1, 1, key, s_x); +    let s = s_x ^ s_y;      let mut cipher_pairs = Vec::new(); -    let mut y = NULL; +    let mut y = Block::NULL;      for (i, (wi, xi)) in ws.iter().zip(xs.iter()).enumerate() {          let i = (i + 1) as i32; -        let s_ = e(2, i, key, &s); -        let yi = xor(wi, &s_); -        let zi = xor(xi, &s_); -        let ci_ = xor(&yi, &e(0, 0, key, &zi)); -        let ci = xor(&zi, &e(1, i, key, &ci_)); +        let s_ = e(2, i, key, s); +        let yi = *wi ^ s_; +        let zi = *xi ^ s_; +        let ci_ = yi ^ e(0, 0, key, zi); +        let ci = zi ^ e(1, i, key, ci_);          cipher_pairs.push((ci, ci_)); -        y = xor(&y, &yi); +        y = y ^ yi;      } -    let mut c_u = [0; 16]; -    let mut c_v = [0; 16]; +    let mut c_u = Block::default(); +    let mut c_v = Block::default();      match d {          0 => (),          _ if d <= 127 => { -            c_u = xor(&m_u, &e(-1, 4, key, &s)); -            clip_to_bits(&mut c_u, d.into()); -            y = xor(&y, &e(0, 4, key, &pad_block(&c_u, d.into()))); +            c_u = (m_u ^ e(-1, 4, key, s)).clip(d.into()); +            y = y ^ e(0, 4, key, c_u.pad(d.into()));          }          _ => { -            c_u = xor(&m_u, &e(-1, 4, key, &s)); -            c_v = xor(&m_v, &e(-1, 5, key, &s)); -            clip_to_bits(&mut c_v, len_v.into()); -            y = xor(&y, &e(0, 4, key, &c_u)); -            y = xor(&y, &e(0, 5, key, &pad_block(&c_v, len_v.into()))); +            c_u = m_u ^ e(-1, 4, key, s); +            c_v = (m_v ^ e(-1, 5, key, s)).clip(len_v.into()); +            y = y ^ e(0, 4, key, c_u); +            y = y ^ e(0, 5, key, c_v.pad(len_v.into()));          }      } -    let c_y = xor(&s_x, &e(-1, 2, key, &s_y)); -    let c_x = xor(&s_y, &xor(&delta, &xor(&y, &e(0, 2, key, &c_y)))); +    let c_y = s_x ^ e(-1, 2, key, s_y); +    let c_x = s_y ^ delta ^ y ^ e(0, 2, key, c_y);      let mut ciphertext = Vec::new();      for (ci, ci_) in cipher_pairs { -        ciphertext.extend_from_slice(&ci); -        ciphertext.extend_from_slice(&ci_); +        ciphertext.extend_from_slice(&ci.0); +        ciphertext.extend_from_slice(&ci_.0);      } -    ciphertext.extend_from_slice(&c_u[..128.min(d) as usize / 8]); -    ciphertext.extend_from_slice(&c_v[..len_v as usize / 8]); -    ciphertext.extend_from_slice(&c_x); -    ciphertext.extend_from_slice(&c_y); +    ciphertext.extend_from_slice(&c_u.0[..128.min(d) as usize / 8]); +    ciphertext.extend_from_slice(&c_v.0[..len_v as usize / 8]); +    ciphertext.extend_from_slice(&c_x.0); +    ciphertext.extend_from_slice(&c_y.0); +    assert!(ciphertext.len() == message.len());      ciphertext  } @@ -392,7 +288,7 @@ fn decipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {      let n = mu / 2;      let delta = aez_hash(key, tweaks);      let round_count = match mu { -        8 => 24, +        8 => 24u32,          16 => 16,          _ if mu < 128 => 10,          _ => 8, @@ -400,48 +296,33 @@ fn decipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {      let mut message = Vec::from(message);      if mu < 128 { -        let mut c = Block::default(); -        c[..message.len()].copy_from_slice(&message); -        c = xor(&c, &and(&e(0, 3, key, &xor(&delta, &or(&c, &ONE))), &ONE)); +        let mut c = Block::from_slice(&message); +        c = c ^ (e(0, 3, key, delta ^ (c | Block::ONE)) & Block::ONE);          message.clear(); -        message.extend(&c[..mu / 8]); +        message.extend(&c.0[..mu / 8]);      }      let (mut left, mut right);      // We might end up having to split at a nibble, so manually adjust for that      if n % 8 == 0 { -        left = full_block(&message[..n / 8]); -        right = full_block(&message[n / 8..]); +        left = Block::from_slice(&message[..n / 8]); +        right = Block::from_slice(&message[n / 8..]);      } else { -        left = full_block(&message[..n / 8 + 1]); -        clip_to_bits(&mut left, n); -        right = full_block(&message[n / 8..]); -        right = lshift(&right, 4); +        left = Block::from_slice(&message[..n / 8 + 1]).clip(n); +        right = Block::from_slice(&message[n / 8..]) << 4;      };      let i = if mu >= 128 { 6 } else { 7 };      for j in (0..round_count).rev() { -        let mut right_ = xor( -            &left, -            &e( -                0, -                i, -                key, -                &xor( -                    &xor(&delta, &pad_block(&right, n)), -                    &(j as u128).to_be_bytes(), -                ), -            ), -        ); -        clip_to_bits(&mut right_, n); +        let right_ = (left ^ e(0, i, key, delta ^ right.pad(n) ^ Block::from_int(j))).clip(n);          (left, right) = (right, right_);      }      let mut ciphertext = Vec::new();      if n % 8 == 0 { -        ciphertext.extend_from_slice(&right[..n / 8]); -        ciphertext.extend_from_slice(&left[..n / 8]); +        ciphertext.extend_from_slice(&right.0[..n / 8]); +        ciphertext.extend_from_slice(&left.0[..n / 8]);      } else { -        ciphertext.extend_from_slice(&right[..n / 8 + 1]); -        for byte in &left[..n / 8 + 1] { +        ciphertext.extend_from_slice(&right.0[..n / 8 + 1]); +        for byte in &left.0[..n / 8 + 1] {              *ciphertext.last_mut().unwrap() |= byte >> 4;              ciphertext.push((byte & 0x0f) << 4);          } @@ -462,77 +343,75 @@ fn decipher_aez_core(key: &Key, tweaks: Tweak, cipher: &[u8]) -> Vec<u8> {      for (i, (ci, ci_)) in block_pairs.iter().enumerate() {          let i = (i + 1) as i32; -        let w = xor(ci, &e(1, i, key, ci_)); -        let y = xor(ci_, &e(0, 0, key, &w)); +        let w = *ci ^ e(1, i, key, *ci_); +        let y = *ci_ ^ e(0, 0, key, w);          ws.push(w);          ys.push(y);      } -    let mut y = NULL; +    let mut y = Block::NULL;      for yi in &ys { -        y = xor(&y, yi); +        y = y ^ *yi;      }      match d {          0 => (),          _ if d <= 127 => { -            y = xor(&y, &e(0, 4, key, &pad_block(&c_u, d.into()))); +            y = y ^ e(0, 4, key, c_u.pad(d.into()));          }          _ => { -            y = xor(&y, &e(0, 4, key, &c_u)); -            y = xor(&y, &e(0, 5, key, &pad_block(&c_v, len_v.into()))); +            y = y ^ e(0, 4, key, c_u); +            y = y ^ e(0, 5, key, c_v.pad(len_v.into()));          }      } -    let s_x = xor(&c_x, &xor(&delta, &xor(&y, &e(0, 2, key, &c_y)))); -    let s_y = xor(&c_y, &e(-1, 2, key, &s_x)); -    let s = xor(&s_x, &s_y); +    let s_x = c_x ^ delta ^ y ^ e(0, 2, key, c_y); +    let s_y = c_y ^ e(-1, 2, key, s_x); +    let s = s_x ^ s_y;      let mut plain_pairs = Vec::new(); -    let mut x = NULL; +    let mut x = Block::NULL;      for (i, (wi, yi)) in ws.iter().zip(ys.iter()).enumerate() {          let i = (i + 1) as i32; -        let s_ = e(2, i, key, &s); -        let xi = xor(wi, &s_); -        let zi = xor(yi, &s_); -        let mi_ = xor(&xi, &e(0, 0, key, &zi)); -        let mi = xor(&zi, &e(1, i, key, &mi_)); +        let s_ = e(2, i, key, s); +        let xi = *wi ^ s_; +        let zi = *yi ^ s_; +        let mi_ = xi ^ e(0, 0, key, zi); +        let mi = zi ^ e(1, i, key, mi_);          plain_pairs.push((mi, mi_)); -        x = xor(&x, &xi); +        x = x ^ xi;      } -    let mut m_u = [0; 16]; -    let mut m_v = [0; 16]; +    let mut m_u = Block::default(); +    let mut m_v = Block::default();      match d {          0 => (),          _ if d <= 127 => { -            m_u = xor(&c_u, &e(-1, 4, key, &s)); -            clip_to_bits(&mut m_u, d.into()); -            x = xor(&x, &e(0, 4, key, &pad_block(&m_u, d.into()))); +            m_u = (c_u ^ e(-1, 4, key, s)).clip(d.into()); +            x = x ^ e(0, 4, key, m_u.pad(d.into()));          }          _ => { -            m_u = xor(&c_u, &e(-1, 4, key, &s)); -            m_v = xor(&c_v, &e(-1, 5, key, &s)); -            clip_to_bits(&mut m_v, len_v.into()); -            x = xor(&x, &e(0, 4, key, &m_u)); -            x = xor(&x, &e(0, 5, key, &pad_block(&m_v, len_v.into()))); +            m_u = c_u ^ e(-1, 4, key, s); +            m_v = (c_v ^ e(-1, 5, key, s)).clip(len_v.into()); +            x = x ^ e(0, 4, key, m_u); +            x = x ^ e(0, 5, key, m_v.pad(len_v.into()));          }      } -    let m_y = xor(&s_x, &e(-1, 1, key, &s_y)); -    let m_x = xor(&s_y, &xor(&delta, &xor(&x, &e(0, 1, key, &m_y)))); +    let m_y = s_x ^ e(-1, 1, key, s_y); +    let m_x = s_y ^ delta ^ x ^ e(0, 1, key, m_y);      let mut message = Vec::new();      for (mi, mi_) in plain_pairs { -        message.extend_from_slice(&mi); -        message.extend_from_slice(&mi_); +        message.extend_from_slice(&mi.0); +        message.extend_from_slice(&mi_.0);      } -    message.extend_from_slice(&m_u[..128.min(d) as usize / 8]); -    message.extend_from_slice(&m_v[..len_v as usize / 8]); -    message.extend_from_slice(&m_x); -    message.extend_from_slice(&m_y); +    message.extend_from_slice(&m_u.0[..128.min(d) as usize / 8]); +    message.extend_from_slice(&m_v.0[..len_v as usize / 8]); +    message.extend_from_slice(&m_x.0); +    message.extend_from_slice(&m_y.0);      message  } @@ -540,9 +419,8 @@ fn split_blocks(mut message: &[u8]) -> (Vec<(Block, Block)>, Block, Block, Block      let num_blocks = (message.len() - 16 - 16) / 32;      let mut blocks = Vec::new();      for _ in 0..num_blocks { -        let (mut a, mut b) = ([0; 16], [0; 16]); -        a.copy_from_slice(&message[..16]); -        b.copy_from_slice(&message[16..32]); +        let a = Block::from_slice(&message[..16]); +        let b = Block::from_slice(&message[16..32]);          blocks.push((a, b));          message = &message[32..];      } @@ -552,18 +430,17 @@ fn split_blocks(mut message: &[u8]) -> (Vec<(Block, Block)>, Block, Block, Block      message = &message[m_uv.len()..];      assert!(message.len() == 32); -    let mut m_u = [0; 16]; -    let mut m_v = [0; 16]; +    let m_u; +    let m_v;      if d <= 127 { -        m_u[..m_uv.len()].copy_from_slice(m_uv); +        m_u = Block::from_slice(m_uv); +        m_v = Block::default();      } else { -        m_u.copy_from_slice(&m_uv[..16]); -        m_v[..m_uv.len() - 16].copy_from_slice(&m_uv[16..]); +        m_u = Block::from_slice(&m_uv[..16]); +        m_v = Block::from_slice(&m_uv[16..]);      } -    let mut m_x = [0; 16]; -    m_x.copy_from_slice(&message[..16]); -    let mut m_y = [0; 16]; -    m_y.copy_from_slice(&message[16..]); +    let m_x = Block::from_slice(&message[..16]); +    let m_y = Block::from_slice(&message[16..]);      (blocks, m_u, m_v, m_x, m_y, d as u8)  } @@ -571,29 +448,16 @@ fn pad_to_blocks(value: &[u8]) -> Vec<Block> {      let mut blocks = Vec::new();      for chunk in value.chunks(16) {          if chunk.len() == 16 { -            blocks.push(chunk.try_into().expect("we made sure the length fits")); +            blocks.push(Block::from_slice(chunk));          } else { -            let mut block = Block::default(); -            for (b, v) in block.iter_mut().zip( -                chunk -                    .iter() -                    .chain(iter::once(&0x80)) -                    .chain(iter::repeat(&0)), -            ) { -                *b = *v; -            } -            blocks.push(block) +            blocks.push(Block::from_slice(chunk).pad(chunk.len() * 8));          }      }      blocks  } -fn tau_to_block(tau: u32) -> Block { -    (tau as u128).to_be_bytes() -} -  fn aez_hash(key: &Key, tweaks: Tweak) -> Block { -    let mut hash = NULL; +    let mut hash = Block::NULL;      for (i, tweak) in tweaks.iter().enumerate() {          // Adjust for zero-based vs one-based indexing          let j = i + 2 + 1; @@ -601,33 +465,22 @@ fn aez_hash(key: &Key, tweaks: Tweak) -> Block {          // set l = 1 and then xor E_K^{j, 0}(10*). We could modify the last if branch to cover this          // as well, but then we need to fiddle with getting an empty chunk from an empty iterator.          if tweak.is_empty() { -            hash = xor( -                &hash, -                &e( -                    j.try_into().unwrap(), -                    0, -                    key, -                    &[128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], -                ), -            ); +            hash = hash ^ e(j.try_into().unwrap(), 0, key, Block::ONE);          } else if tweak.len() % 16 == 0 {              for (l, chunk) in tweak.chunks(16).enumerate() { -                hash = xor( -                    &hash, -                    &e( +                hash = hash +                    ^ e(                          j.try_into().unwrap(),                          (l + 1).try_into().unwrap(),                          key, -                        chunk.try_into().expect("we made sure the length fits"), -                    ), -                ); +                        Block::from_slice(chunk), +                    );              }          } else {              let blocks = pad_to_blocks(tweak);              for (l, chunk) in blocks.iter().enumerate() { -                hash = xor( -                    &hash, -                    &e( +                hash = hash +                    ^ e(                          j.try_into().unwrap(),                          if l == blocks.len() - 1 {                              0 @@ -635,9 +488,8 @@ fn aez_hash(key: &Key, tweaks: Tweak) -> Block {                              (l + 1).try_into().unwrap()                          },                          key, -                        chunk, -                    ), -                ); +                        *chunk, +                    );              }          }      } @@ -649,146 +501,52 @@ fn aez_prf(key: &Key, tweaks: Tweak, tau: u32) -> Vec<u8> {      let mut index = 0u128;      let delta = aez_hash(key, tweaks);      while result.len() < tau as usize { -        let block = e(-1, 3, key, &xor(&delta, &index.to_be_bytes())); -        result.extend_from_slice(&block[..16.min(tau as usize - result.len())]); +        let block = e(-1, 3, key, delta ^ Block::from_int(index)); +        result.extend_from_slice(&block.0[..16.min(tau as usize - result.len())]);          index += 1;      }      result  } -fn e(j: i32, i: i32, key: &Key, block: &Block) -> Block { +fn e(j: i32, i: i32, key: &Key, block: Block) -> Block {      let (key_i, key_j, key_l) = split_key(key);      if j == -1 {          let k = [ -            &NULL, key_i, key_j, key_l, key_i, key_j, key_l, key_i, key_j, key_l, key_i, +            &Block::NULL, +            &key_i, +            &key_j, +            &key_l, +            &key_i, +            &key_j, +            &key_l, +            &key_i, +            &key_j, +            &key_l, +            &key_i,          ]; -        let delta = times(i.try_into().expect("i was negative"), key_l); -        aes10(&k, &xor(block, &delta)) +        let delta = key_l * i.try_into().expect("i was negative"); +        aes10(&k, &(block ^ delta))      } else { -        let k = [&NULL, key_j, key_i, key_l, &NULL]; +        let k = [&Block::NULL, &key_j, &key_i, &key_l, &Block::NULL];          let exponent = if i % 8 == 0 { i / 8 } else { i / 8 + 1 }; -        let delta = xor( -            &xor( -                ×(j.try_into().expect("j was negative"), key_j), -                ×(1 << exponent, key_i), -            ), -            ×((i % 8).try_into().expect("i was negative"), key_l), -        ); -        aes4(&k, &xor(block, &delta)) +        let j: u32 = j.try_into().expect("j was negative"); +        let i: u32 = i.try_into().expect("i was negative"); +        let delta = (key_j * j) ^ (key_i * (1 << exponent)) ^ (key_l * (i % 8)); +        aes4(&k, &(block ^ delta))      }  } -fn split_key(key: &Key) -> (&Block, &Block, &Block) { -    let (i, jl) = key.split_at(16); -    let (j, l) = jl.split_at(16); +fn split_key(key: &Key) -> (Block, Block, Block) {      ( -        i.try_into().unwrap(), -        j.try_into().unwrap(), -        l.try_into().unwrap(), +        Block::from_slice(&key[..16]), +        Block::from_slice(&key[16..32]), +        Block::from_slice(&key[32..]),      )  }  #[cfg(test)]  mod test {      use super::*; - -    #[test] -    fn test_xor() { -        assert_eq!(xor(&[1; 16], &[2; 16]), [3; 16]); -    } - -    #[test] -    fn test_times() { -        assert_eq!( -            times(0, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), -            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -        ); -        assert_eq!( -            times(1, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), -            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] -        ); -        assert_eq!( -            times(2, &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]), -            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2] -        ); -        assert_eq!( -            times(2, &[128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]), -            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 133] -        ); -        assert_eq!( -            times(2, &[129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]), -            [2, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 133] -        ); -        assert_eq!( -            times(3, &[129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]), -            [131, 0, 0, 0, 1, 128, 0, 0, 0, 3, 0, 0, 0, 0, 0, 132] -        ); -        assert_eq!( -            times(4, &[129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]), -            [4, 0, 0, 0, 2, 0, 0, 0, 0, 4, 0, 0, 0, 0, 1, 10] -        ); -    } - -    #[test] -    fn test_lshift() { -        assert_eq!( -            lshift(&[0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 1), -            [0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -        ); -        assert_eq!( -            lshift(&[0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 4), -            [0x10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -        ); -        assert_eq!( -            lshift(&[0x0A, 0xB0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 4), -            [0xAB, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -        ); -        assert_eq!( -            lshift(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 8), -            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0] -        ); -    } - -    #[test] -    fn test_pad_block() { -        assert_eq!( -            pad_block(&[0; 16], 0), -            [0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -        ); -        assert_eq!( -            pad_block(&[0; 16], 1), -            [0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -        ); -        assert_eq!( -            pad_block(&[0; 16], 8), -            [0, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -        ); -    } - -    #[test] -    fn test_clip_to_bits() { -        let mut block; - -        block = [0xFF; 16]; -        clip_to_bits(&mut block, 0); -        assert_eq!(block, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); - -        block = [0xFF; 16]; -        clip_to_bits(&mut block, 4); -        assert_eq!(block, [0xF0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); - -        block = [0xFF; 16]; -        clip_to_bits(&mut block, 8); -        assert_eq!(block, [0xFF, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); - -        block = [0xFF; 16]; -        clip_to_bits(&mut block, 9); -        assert_eq!( -            block, -            [0xFF, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -        ); -    } -      #[test]      fn test_extract() {          for (a, b) in testvectors::EXTRACT_VECTORS { @@ -805,9 +563,9 @@ mod test {              let k = hex::decode(k).unwrap();              let k = k.as_slice().try_into().unwrap();              let a = hex::decode(a).unwrap(); -            let a = a.as_slice().try_into().unwrap(); +            let a = Block::from_slice(&a);              let b = hex::decode(b).unwrap(); -            assert_eq!(&e(*j, *i, k, a), b.as_slice(), "{name}"); +            assert_eq!(&e(*j, *i, k, a).0, b.as_slice(), "{name}");          }      } @@ -819,13 +577,13 @@ mod test {              let k = k.as_slice().try_into().unwrap();              let v = hex::decode(v).unwrap(); -            let mut tweaks = vec![Vec::from(tau_to_block(*tau))]; +            let mut tweaks = vec![Vec::from(Block::from_int(*tau).0)];              for t in *tw {                  tweaks.push(hex::decode(t).unwrap());              }              let tweaks = tweaks.iter().map(Vec::as_slice).collect::<Vec<_>>(); -            assert_eq!(&aez_hash(&k, &tweaks), v.as_slice(), "{name}"); +            assert_eq!(&aez_hash(&k, &tweaks).0, v.as_slice(), "{name}");          }      } @@ -849,10 +607,10 @@ mod test {              let c = hex::decode(c).unwrap();              if &encrypt(&k, &n, &ad, *tau, &m) == &c { -                println!("+ {name}"); +                //println!("+ {name}");                  succ += 1;              } else { -                println!("- {name}"); +                println!("- {}", c.len());                  failed += 1;              }          } @@ -898,4 +656,13 @@ mod test {          let plain = aez.decrypt(&[0], &[b"foobar"], 16, &cipher).unwrap();          assert_eq!(plain, b"hi");      } + +    #[test] +    fn test_encrypt_decrypt_long() { +        let message = b"ene mene miste es rappelt in der kiste ene mene meck und du bist weg"; +        let aez = Aez::new(b"foobar"); +        let cipher = aez.encrypt(&[0], &[b"foobar"], 16, message); +        let plain = aez.decrypt(&[0], &[b"foobar"], 16, &cipher).unwrap(); +        assert_eq!(plain, message); +    }  } | 
