diff options
author | Daniel Schadt <kingdread@gmx.de> | 2025-04-05 19:03:10 +0200 |
---|---|---|
committer | Daniel Schadt <kingdread@gmx.de> | 2025-04-05 19:03:10 +0200 |
commit | 71cdf50525f0cbb70673477510050669206df7f2 (patch) | |
tree | 41e58ce93318dfaaf8f2c4f4dd91b879ead378af /src/block.rs | |
parent | 5cd9e4a71f0561d599ce5c7d498828ef5b8db2bb (diff) | |
download | zears-71cdf50525f0cbb70673477510050669206df7f2.tar.gz zears-71cdf50525f0cbb70673477510050669206df7f2.tar.bz2 zears-71cdf50525f0cbb70673477510050669206df7f2.zip |
use proper Block struct and operator overloading
Diffstat (limited to 'src/block.rs')
-rw-r--r-- | src/block.rs | 221 |
1 files changed, 221 insertions, 0 deletions
diff --git a/src/block.rs b/src/block.rs new file mode 100644 index 0000000..e63062e --- /dev/null +++ b/src/block.rs @@ -0,0 +1,221 @@ +use std::ops::{BitAnd, BitOr, BitXor, Index, IndexMut, Mul, Shl, Shr}; + +/// A block, the unit of work that AEZ divides the message into. +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Block(pub [u8; 16]); + +impl Block { + pub const NULL: Block = Block([0; 16]); + pub const ONE: Block = Block([0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); + + /// Create a block from a slice. + /// + /// If the slice is too long, it will be truncated. If the slice is too short, the remaining + /// items are set to 0. + pub fn from_slice(value: &[u8]) -> Self { + let len = value.len().min(16); + let mut array = [0; 16]; + array[..len].copy_from_slice(&value[..len]); + Block(array) + } + + /// Constructs a block representing the given integer. + /// + /// This corresponds to [x]_128 in the paper. + pub fn from_int<I: Into<u128>>(value: I) -> Self { + Block(value.into().to_be_bytes()) + } + + pub fn to_int(&self) -> u128 { + u128::from_be_bytes(self.0) + } + + /// Pad the block to full length. + /// + /// The given length is the current length. + /// + /// This corresponds to X10* in the paper. + pub fn pad(&self, length: usize) -> Block { + assert!(length <= 127); + Block::from_int(self.to_int() | (1 << (127 - length))) + } + + /// Clip the block by setting all bits beyond the given length to 0. + pub fn clip(&self, mut length: usize) -> Block { + let mut block = self.0; + for byte in &mut block { + if length == 0 { + *byte = 0; + } else if length < 8 { + *byte &= 0xff << (8 - length); + } + length = length.saturating_sub(8); + } + Block(block) + } +} + +impl From<[u8; 16]> for Block { + fn from(value: [u8; 16]) -> Block { + Block(value) + } +} + +impl From<&[u8; 16]> for Block { + fn from(value: &[u8; 16]) -> Block { + Block(*value) + } +} + +impl From<u128> for Block { + fn from(value: u128) -> Block { + Block(value.to_be_bytes()) + } +} + +impl BitXor<Block> for Block { + type Output = Block; + fn bitxor(self, rhs: Block) -> Block { + Block::from(self.to_int() ^ rhs.to_int()) + } +} + +impl Shl<u32> for Block { + type Output = Block; + fn shl(self, rhs: u32) -> Block { + Block::from(self.to_int() << rhs) + } +} + +impl Shr<u32> for Block { + type Output = Block; + fn shr(self, rhs: u32) -> Block { + Block::from(self.to_int() >> rhs) + } +} + +impl BitAnd<Block> for Block { + type Output = Block; + fn bitand(self, rhs: Block) -> Block { + Block::from(self.to_int() & rhs.to_int()) + } +} + +impl BitOr<Block> for Block { + type Output = Block; + fn bitor(self, rhs: Block) -> Block { + Block::from(self.to_int() | rhs.to_int()) + } +} + +impl Index<usize> for Block { + type Output = u8; + fn index(&self, index: usize) -> &u8 { + &self.0[index] + } +} + +impl IndexMut<usize> for Block { + fn index_mut(&mut self, index: usize) -> &mut u8 { + &mut self.0[index] + } +} + +impl Mul<u32> for Block { + type Output = Block; + fn mul(self, rhs: u32) -> Block { + match rhs { + 0 => Block::NULL, + 1 => self, + 2 => { + let mut result = self << 1; + if self[0] & 0x80 != 0 { + result[15] ^= 135; + } + result + } + _ if rhs % 2 == 0 => self * 2 * (rhs / 2), + _ => self * (rhs - 1) ^ self, + } + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_xor() { + assert_eq!( + Block::from([1; 16]) ^ Block::from([2; 16]), + Block::from([3; 16]) + ); + } + + #[test] + fn test_pad() { + assert_eq!( + Block::from([0; 16]).pad(0), + Block::from([0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + ); + assert_eq!( + Block::from([0; 16]).pad(1), + Block::from([0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + ); + assert_eq!( + Block::from([0; 16]).pad(8), + Block::from([0, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + ); + } + + #[test] + fn test_shl() { + assert_eq!( + Block::from([0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) << 1, + Block::from([0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + ); + assert_eq!( + Block::from([0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) << 4, + Block::from([0x10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + ); + assert_eq!( + Block::from([0x0A, 0xB0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) << 4, + Block::from([0xAB, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + ); + assert_eq!( + Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]) << 8, + Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]), + ); + } + + #[test] + fn test_times() { + assert_eq!( + Block::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]) * 0, + Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + ); + assert_eq!( + Block::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]) * 1, + Block::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), + ); + assert_eq!( + Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]) * 2, + Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]), + ); + assert_eq!( + Block::from([128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]) * 2, + Block::from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 133]), + ); + assert_eq!( + Block::from([129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]) * 2, + Block::from([2, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 133]), + ); + assert_eq!( + Block::from([129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]) * 3, + Block::from([131, 0, 0, 0, 1, 128, 0, 0, 0, 3, 0, 0, 0, 0, 0, 132]), + ); + assert_eq!( + Block::from([129, 0, 0, 0, 0, 128, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]) * 4, + Block::from([4, 0, 0, 0, 2, 0, 0, 0, 0, 4, 0, 0, 0, 0, 1, 10]), + ); + } +} |