aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/accessor.rs82
-rw-r--r--src/lib.rs306
2 files changed, 242 insertions, 146 deletions
diff --git a/src/accessor.rs b/src/accessor.rs
new file mode 100644
index 0000000..89f5251
--- /dev/null
+++ b/src/accessor.rs
@@ -0,0 +1,82 @@
+use super::block::Block;
+
+pub struct BlockAccessor<'a> {
+ data: &'a mut [u8],
+ m_uv_len: usize,
+ m_u_len: usize,
+ m_v_len: usize,
+ num_block_pairs: usize,
+}
+
+impl<'a> BlockAccessor<'a> {
+ pub fn new(message: &'a mut [u8]) -> Self {
+ let num_block_pairs = (message.len() - 16 - 16) / 32;
+ let m_uv_len = (message.len() % 32) * 8;
+ Self {
+ data: message,
+ m_uv_len,
+ m_u_len: 128.min(m_uv_len),
+ m_v_len: m_uv_len.saturating_sub(128),
+ num_block_pairs,
+ }
+ }
+
+ pub fn m_uv_len(&self) -> usize {
+ self.m_uv_len
+ }
+
+ fn suffix_start(&self) -> usize {
+ self.num_block_pairs * 32
+ }
+
+ pub fn m_u(&self) -> Block {
+ let start = self.suffix_start();
+ Block::from_slice(&self.data[start..start + self.m_u_len / 8])
+ }
+
+ pub fn set_m_u(&mut self, m_u: Block) {
+ let start = self.suffix_start();
+ self.data[start..start + self.m_u_len / 8].copy_from_slice(&m_u.0[..self.m_u_len / 8]);
+ }
+
+ pub fn m_v(&self) -> Block {
+ let start = self.suffix_start();
+ Block::from_slice(&self.data[start + self.m_u_len / 8..start + self.m_uv_len / 8])
+ }
+
+ pub fn set_m_v(&mut self, m_v: Block) {
+ let start = self.suffix_start();
+ self.data[start + self.m_u_len / 8..start + self.m_uv_len / 8]
+ .copy_from_slice(&m_v.0[..self.m_v_len / 8]);
+ }
+
+ pub fn m_x(&self) -> Block {
+ let start = self.suffix_start() + self.m_uv_len / 8;
+ Block::from_slice(&self.data[start..start + 16])
+ }
+
+ pub fn set_m_x(&mut self, m_x: Block) {
+ let start = self.suffix_start() + self.m_uv_len / 8;
+ self.data[start..start + 16].copy_from_slice(&m_x.0);
+ }
+
+ pub fn m_y(&self) -> Block {
+ let start = self.suffix_start() + self.m_uv_len / 8;
+ Block::from_slice(&self.data[start + 16..start + 32])
+ }
+
+ pub fn set_m_y(&mut self, m_y: Block) {
+ let start = self.suffix_start() + self.m_uv_len / 8;
+ self.data[start + 16..start + 32].copy_from_slice(&m_y.0);
+ }
+
+ pub fn pairs_mut<'b>(
+ &'b mut self,
+ ) -> impl Iterator<Item = (&'b mut [u8; 16], &'b mut [u8; 16])> {
+ let stop = self.suffix_start();
+ self.data[..stop]
+ .chunks_exact_mut(32)
+ .map(move |x| x.split_at_mut(16))
+ .map(move |(x, y)| (x.try_into().unwrap(), y.try_into().unwrap()))
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 08f0570..58e488e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -85,10 +85,12 @@
use constant_time_eq::constant_time_eq;
+mod accessor;
mod block;
#[cfg(test)]
mod testvectors;
+use accessor::BlockAccessor;
use block::Block;
type Key = [u8; 48];
type Tweak<'a> = &'a [&'a [u8]];
@@ -135,7 +137,10 @@ impl Aez {
tau: u32,
data: &[u8],
) -> Vec<u8> {
- encrypt(&self.key, nonce, associated_data, tau, data)
+ let mut buffer = vec![0; data.len() + tau as usize];
+ buffer[..data.len()].copy_from_slice(data);
+ encrypt(&self.key, nonce, associated_data, tau, &mut buffer);
+ buffer
}
/// Decrypts the given ciphertext.
@@ -154,7 +159,13 @@ impl Aez {
tau: u32,
data: &[u8],
) -> Option<Vec<u8>> {
- decrypt(&self.key, nonce, associated_data, tau, data)
+ let mut buffer = Vec::from(data);
+ let len = match decrypt(&self.key, nonce, associated_data, tau, &mut buffer) {
+ None => return None,
+ Some(m) => m.len(),
+ };
+ buffer.truncate(len);
+ Some(buffer)
}
}
@@ -210,28 +221,35 @@ fn extract(key: &[u8]) -> [u8; 48] {
}
}
-fn encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, message: &[u8]) -> Vec<u8> {
- let mut auth_message = Vec::with_capacity(message.len() + tau as usize);
- auth_message.extend_from_slice(&message);
- while auth_message.len() < message.len() + tau as usize {
- auth_message.extend_from_slice(
- &ZEROES[..ZEROES
- .len()
- .min(tau as usize - (auth_message.len() - message.len()))],
- );
+fn append_auth(data_len: usize, buffer: &mut [u8]) {
+ let mut total_len = data_len;
+ while total_len < buffer.len() {
+ let block_size = ZEROES.len().min(buffer.len() - total_len);
+ buffer[total_len..total_len + block_size].copy_from_slice(&ZEROES[..block_size]);
+ total_len += block_size;
}
+}
+
+fn encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, buffer: &mut [u8]) {
// We treat tau as bytes, but according to the spec, tau is actually in bits.
let tau_block = Block::from_int(tau as u128 * 8);
let mut tweaks = vec![&tau_block.0, nonce];
tweaks.extend(ad);
- if message.is_empty() {
- aez_prf(key, &tweaks, tau)
+ assert!(buffer.len() >= tau as usize);
+ if buffer.len() == tau as usize {
+ buffer.copy_from_slice(&aez_prf(key, &tweaks, tau));
} else {
- encipher(key, &tweaks, &auth_message)
+ encipher(key, &tweaks, buffer);
}
}
-fn decrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, ciphertext: &[u8]) -> Option<Vec<u8>> {
+fn decrypt<'a>(
+ key: &Key,
+ nonce: &[u8],
+ ad: &[&[u8]],
+ tau: u32,
+ ciphertext: &'a mut [u8],
+) -> Option<&'a [u8]> {
if ciphertext.len() < tau as usize {
return None;
}
@@ -242,14 +260,14 @@ fn decrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, ciphertext: &[u8]) -
if ciphertext.len() == tau as usize {
if constant_time_eq(&ciphertext, &aez_prf(key, &tweaks, tau)) {
- return Some(Vec::new());
+ return Some(&[]);
} else {
return None;
}
}
- let x = decipher(key, &tweaks, ciphertext);
- let (m, auth) = x.split_at(ciphertext.len() - tau as usize);
+ decipher(key, &tweaks, ciphertext);
+ let (m, auth) = ciphertext.split_at(ciphertext.len() - tau as usize);
assert!(auth.len() == tau as usize);
let comparator = if tau as usize <= ZEROES.len() {
&ZEROES[..tau as usize]
@@ -257,13 +275,13 @@ fn decrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, ciphertext: &[u8]) -
&vec![0; tau as usize]
};
if constant_time_eq(&auth, comparator) {
- Some(Vec::from(m))
+ Some(m)
} else {
None
}
}
-fn encipher(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
+fn encipher(key: &Key, tweaks: Tweak, message: &mut [u8]) {
if message.len() < 256 / 8 {
encipher_aez_tiny(key, tweaks, message)
} else {
@@ -271,7 +289,7 @@ fn encipher(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
}
}
-fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
+fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &mut [u8]) {
let mu = message.len() * 8;
assert!(mu < 256);
let n = mu / 2;
@@ -298,48 +316,54 @@ fn encipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
let right_ = (left ^ e(0, i, key, delta ^ right.pad(n) ^ Block::from_int(j))).clip(n);
(left, right) = (right, right_);
}
- let mut ciphertext = Vec::new();
if n % 8 == 0 {
- ciphertext.extend_from_slice(&right.0[..n / 8]);
- ciphertext.extend_from_slice(&left.0[..n / 8]);
+ message[..n / 8].copy_from_slice(&right.0[..n / 8]);
+ message[n / 8..].copy_from_slice(&left.0[..n / 8]);
} else {
- ciphertext.extend_from_slice(&right.0[..n / 8 + 1]);
+ let mut index = n / 8;
+ message[..index + 1].copy_from_slice(&right.0[..index + 1]);
for byte in &left.0[..n / 8 + 1] {
- *ciphertext.last_mut().unwrap() |= byte >> 4;
- ciphertext.push((byte & 0x0f) << 4);
+ message[index] |= byte >> 4;
+ if index < message.len() - 1 {
+ message[index + 1] = (byte & 0x0f) << 4;
+ }
+ index += 1;
}
- ciphertext.pop();
}
if mu < 128 {
- let mut c = Block::from_slice(&ciphertext);
+ let mut c = Block::from_slice(&message);
c = c ^ (e(0, 3, key, delta ^ (c | Block::ONE)) & Block::ONE);
- ciphertext = Vec::from(&c.0[..mu / 8]);
+ message.copy_from_slice(&c.0[..mu / 8]);
}
- assert!(ciphertext.len() == message.len());
- ciphertext
}
-fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
+fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &mut [u8]) {
assert!(message.len() >= 32);
let delta = aez_hash(key, tweaks);
- let (block_pairs, m_u, m_v, m_x, m_y, d) = split_blocks(message);
+ let mut blocks = BlockAccessor::new(message);
+ let (m_u, m_v, m_x, m_y, d) = (
+ blocks.m_u(),
+ blocks.m_v(),
+ blocks.m_x(),
+ blocks.m_y(),
+ blocks.m_uv_len(),
+ );
let len_v = d.saturating_sub(128);
- let mut ws = Vec::new();
- let mut xs = Vec::new();
-
+ let mut x = Block::NULL;
let mut e1_eval = E::new(1, 0, key);
- for (mi, mi_) in block_pairs.iter() {
+
+ for (raw_mi, raw_mi_) in blocks.pairs_mut() {
e1_eval.advance();
- let w = *mi ^ e1_eval.eval(*mi_);
- let x = *mi_ ^ e(0, 0, key, w);
- ws.push(w);
- xs.push(x);
- }
+ let mi = Block::from(*raw_mi);
+ let mi_ = Block::from(*raw_mi_);
+ let wi = mi ^ e1_eval.eval(mi_);
+ let xi = mi_ ^ e(0, 0, key, wi);
- let mut x = Block::NULL;
- for xi in &xs {
- x = x ^ *xi;
+ *raw_mi = wi.0;
+ *raw_mi_ = xi.0;
+
+ x = x ^ xi;
}
match d {
@@ -357,20 +381,22 @@ fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
let s_y = m_y ^ e(-1, 1, key, s_x);
let s = s_x ^ s_y;
- let mut cipher_pairs = Vec::new();
let mut y = Block::NULL;
let mut e2_eval = E::new(2, 0, key);
let mut e1_eval = E::new(1, 0, key);
- for (wi, xi) in ws.iter().zip(xs.iter()) {
+ for (raw_wi, raw_xi) in blocks.pairs_mut() {
e2_eval.advance();
e1_eval.advance();
+ let wi = Block::from(*raw_wi);
+ let xi = Block::from(*raw_xi);
let s_ = e2_eval.eval(s);
- let yi = *wi ^ s_;
- let zi = *xi ^ s_;
+ let yi = wi ^ s_;
+ let zi = xi ^ s_;
let ci_ = yi ^ e(0, 0, key, zi);
let ci = zi ^ e1_eval.eval(ci_);
- cipher_pairs.push((ci, ci_));
+ *raw_wi = ci.0;
+ *raw_xi = ci_.0;
y = y ^ yi;
}
@@ -394,29 +420,22 @@ fn encipher_aez_core(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
let c_y = s_x ^ e(-1, 2, key, s_y);
let c_x = s_y ^ delta ^ y ^ e(0, 2, key, c_y);
- let mut ciphertext = Vec::new();
- for (ci, ci_) in cipher_pairs {
- ciphertext.extend_from_slice(&ci.0);
- ciphertext.extend_from_slice(&ci_.0);
- }
- ciphertext.extend_from_slice(&c_u.0[..128.min(d) as usize / 8]);
- ciphertext.extend_from_slice(&c_v.0[..len_v as usize / 8]);
- ciphertext.extend_from_slice(&c_x.0);
- ciphertext.extend_from_slice(&c_y.0);
- assert!(ciphertext.len() == message.len());
- ciphertext
+ blocks.set_m_u(c_u);
+ blocks.set_m_v(c_v);
+ blocks.set_m_x(c_x);
+ blocks.set_m_y(c_y);
}
-fn decipher(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
- if message.len() < 256 / 8 {
- decipher_aez_tiny(key, tweaks, message)
+fn decipher(key: &Key, tweaks: Tweak, buffer: &mut [u8]) {
+ if buffer.len() < 256 / 8 {
+ decipher_aez_tiny(key, tweaks, buffer);
} else {
- decipher_aez_core(key, tweaks, message)
+ decipher_aez_core(key, tweaks, buffer);
}
}
-fn decipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
- let mu = message.len() * 8;
+fn decipher_aez_tiny(key: &Key, tweaks: Tweak, buffer: &mut [u8]) {
+ let mu = buffer.len() * 8;
assert!(mu < 256);
let n = mu / 2;
let delta = aez_hash(key, tweaks);
@@ -427,65 +446,69 @@ fn decipher_aez_tiny(key: &Key, tweaks: Tweak, message: &[u8]) -> Vec<u8> {
_ => 8,
};
- let mut message = Vec::from(message);
if mu < 128 {
- let mut c = Block::from_slice(&message);
+ let mut c = Block::from_slice(buffer);
c = c ^ (e(0, 3, key, delta ^ (c | Block::ONE)) & Block::ONE);
- message.clear();
- message.extend(&c.0[..mu / 8]);
+ buffer.copy_from_slice(&c.0[..mu / 8]);
}
let (mut left, mut right);
// We might end up having to split at a nibble, so manually adjust for that
if n % 8 == 0 {
- left = Block::from_slice(&message[..n / 8]);
- right = Block::from_slice(&message[n / 8..]);
+ left = Block::from_slice(&buffer[..n / 8]);
+ right = Block::from_slice(&buffer[n / 8..]);
} else {
- left = Block::from_slice(&message[..n / 8 + 1]).clip(n);
- right = Block::from_slice(&message[n / 8..]) << 4;
+ left = Block::from_slice(&buffer[..n / 8 + 1]).clip(n);
+ right = Block::from_slice(&buffer[n / 8..]) << 4;
};
let i = if mu >= 128 { 6 } else { 7 };
for j in (0..round_count).rev() {
let right_ = (left ^ e(0, i, key, delta ^ right.pad(n) ^ Block::from_int(j))).clip(n);
(left, right) = (right, right_);
}
- let mut ciphertext = Vec::new();
+
if n % 8 == 0 {
- ciphertext.extend_from_slice(&right.0[..n / 8]);
- ciphertext.extend_from_slice(&left.0[..n / 8]);
+ buffer[..n / 8].copy_from_slice(&right.0[..n / 8]);
+ buffer[n / 8..].copy_from_slice(&left.0[..n / 8]);
} else {
- ciphertext.extend_from_slice(&right.0[..n / 8 + 1]);
+ let mut index = n / 8;
+ buffer[..index + 1].copy_from_slice(&right.0[..index + 1]);
for byte in &left.0[..n / 8 + 1] {
- *ciphertext.last_mut().unwrap() |= byte >> 4;
- ciphertext.push((byte & 0x0f) << 4);
+ buffer[index] |= byte >> 4;
+ if index < buffer.len() - 1 {
+ buffer[index + 1] = (byte & 0x0f) << 4;
+ }
+ index += 1;
}
- ciphertext.pop();
}
- assert!(ciphertext.len() == message.len());
- ciphertext
}
-fn decipher_aez_core(key: &Key, tweaks: Tweak, cipher: &[u8]) -> Vec<u8> {
- assert!(cipher.len() >= 32);
+fn decipher_aez_core(key: &Key, tweaks: Tweak, buffer: &mut [u8]) {
+ assert!(buffer.len() >= 32);
let delta = aez_hash(key, tweaks);
- let (block_pairs, c_u, c_v, c_x, c_y, d) = split_blocks(cipher);
+ let mut blocks = BlockAccessor::new(buffer);
+ let (c_u, c_v, c_x, c_y, d) = (
+ blocks.m_u(),
+ blocks.m_v(),
+ blocks.m_x(),
+ blocks.m_y(),
+ blocks.m_uv_len(),
+ );
let len_v = d.saturating_sub(128);
- let mut ws = Vec::new();
- let mut ys = Vec::new();
-
+ let mut y = Block::NULL;
let mut e1_eval = E::new(1, 0, key);
- for (ci, ci_) in block_pairs.iter() {
+ for (raw_ci, raw_ci_) in blocks.pairs_mut() {
e1_eval.advance();
- let w = *ci ^ e1_eval.eval(*ci_);
- let y = *ci_ ^ e(0, 0, key, w);
- ws.push(w);
- ys.push(y);
- }
+ let ci = Block::from(*raw_ci);
+ let ci_ = Block::from(*raw_ci_);
+ let wi = ci ^ e1_eval.eval(ci_);
+ let yi = ci_ ^ e(0, 0, key, wi);
- let mut y = Block::NULL;
- for yi in &ys {
- y = y ^ *yi;
+ *raw_ci = wi.0;
+ *raw_ci_ = yi.0;
+
+ y = y ^ yi;
}
match d {
@@ -503,20 +526,23 @@ fn decipher_aez_core(key: &Key, tweaks: Tweak, cipher: &[u8]) -> Vec<u8> {
let s_y = c_y ^ e(-1, 2, key, s_x);
let s = s_x ^ s_y;
- let mut plain_pairs = Vec::new();
let mut x = Block::NULL;
let mut e2_eval = E::new(2, 0, key);
let mut e1_eval = E::new(1, 0, key);
- for (wi, yi) in ws.iter().zip(ys.iter()) {
+ for (raw_wi, raw_yi) in blocks.pairs_mut() {
e2_eval.advance();
e1_eval.advance();
+ let wi = Block::from(*raw_wi);
+ let yi = Block::from(*raw_yi);
let s_ = e2_eval.eval(s);
- let xi = *wi ^ s_;
- let zi = *yi ^ s_;
+ let xi = wi ^ s_;
+ let zi = yi ^ s_;
let mi_ = xi ^ e(0, 0, key, zi);
let mi = zi ^ e1_eval.eval(mi_);
- plain_pairs.push((mi, mi_));
+ *raw_wi = mi.0;
+ *raw_yi = mi_.0;
+
x = x ^ xi;
}
@@ -540,45 +566,10 @@ fn decipher_aez_core(key: &Key, tweaks: Tweak, cipher: &[u8]) -> Vec<u8> {
let m_y = s_x ^ e(-1, 1, key, s_y);
let m_x = s_y ^ delta ^ x ^ e(0, 1, key, m_y);
- let mut message = Vec::new();
- for (mi, mi_) in plain_pairs {
- message.extend_from_slice(&mi.0);
- message.extend_from_slice(&mi_.0);
- }
- message.extend_from_slice(&m_u.0[..128.min(d) as usize / 8]);
- message.extend_from_slice(&m_v.0[..len_v as usize / 8]);
- message.extend_from_slice(&m_x.0);
- message.extend_from_slice(&m_y.0);
- message
-}
-
-fn split_blocks(mut message: &[u8]) -> (Vec<(Block, Block)>, Block, Block, Block, Block, u8) {
- let num_blocks = (message.len() - 16 - 16) / 32;
- let mut blocks = Vec::new();
- for _ in 0..num_blocks {
- let a = Block::from_slice(&message[..16]);
- let b = Block::from_slice(&message[16..32]);
- blocks.push((a, b));
- message = &message[32..];
- }
- let m_uv = &message[..message.len() - 32];
- let d = m_uv.len() * 8;
- assert!(d < 256);
- message = &message[m_uv.len()..];
- assert!(message.len() == 32);
-
- let m_u;
- let m_v;
- if d <= 127 {
- m_u = Block::from_slice(m_uv);
- m_v = Block::default();
- } else {
- m_u = Block::from_slice(&m_uv[..16]);
- m_v = Block::from_slice(&m_uv[16..]);
- }
- let m_x = Block::from_slice(&message[..16]);
- let m_y = Block::from_slice(&message[16..]);
- (blocks, m_u, m_v, m_x, m_y, d as u8)
+ blocks.set_m_u(m_u);
+ blocks.set_m_v(m_v);
+ blocks.set_m_x(m_x);
+ blocks.set_m_y(m_y);
}
fn pad_to_blocks(value: &[u8]) -> Vec<Block> {
@@ -803,6 +794,29 @@ mod test {
}
}
+ fn vec_encrypt(key: &Key, nonce: &[u8], ad: &[&[u8]], tau: u32, message: &[u8]) -> Vec<u8> {
+ let mut v = vec![0; message.len() + tau as usize];
+ v[..message.len()].copy_from_slice(message);
+ encrypt(key, nonce, ad, tau, &mut v);
+ v
+ }
+
+ fn vec_decrypt(
+ key: &Key,
+ nonce: &[u8],
+ ad: &[&[u8]],
+ tau: u32,
+ ciphertext: &[u8],
+ ) -> Option<Vec<u8>> {
+ let mut v = Vec::from(ciphertext);
+ let len = match decrypt(key, nonce, ad, tau, &mut v) {
+ None => return None,
+ Some(m) => m.len(),
+ };
+ v.truncate(len);
+ Some(v)
+ }
+
#[test]
fn test_encrypt() {
let mut failed = 0;
@@ -822,7 +836,7 @@ mod test {
let m = hex::decode(m).unwrap();
let c = hex::decode(c).unwrap();
- if &encrypt(&k, &n, &ad, *tau, &m) == &c {
+ if &vec_encrypt(&k, &n, &ad, *tau, &m) == &c {
println!("+ {name}");
succ += 1;
} else {
@@ -853,7 +867,7 @@ mod test {
let m = hex::decode(m).unwrap();
let c = hex::decode(c).unwrap();
- if decrypt(&k, &n, &ad, *tau, &c) == Some(m) {
+ if vec_decrypt(&k, &n, &ad, *tau, &c) == Some(m) {
println!("+ {name}");
succ += 1;
} else {