class AES:
    """AES-128 block encryption that reports each step to a ``Trace`` object.

    The 16-byte state is held as a 4x4 matrix indexed ``state[column][row]``:
    ``plaintext.reshape(4, 4)`` puts bytes ``4*c .. 4*c+3`` in ``state[c]``,
    which is one AES column.  Before every SubBytes / ShiftRows / MixColumns /
    AddRoundKey step, the current state is passed to the method of the same
    name on a fresh ``Trace(b, noise)`` created for each ``encrypt`` call.

    Relies on the module-level ``Sbox``, ``Rcon`` and ``Trace`` defined
    elsewhere in this file.
    """

    def __init__(self, master_key, b=1., noise=0.):
        """Expand *master_key* and store the parameters forwarded to ``Trace``.

        Args:
            master_key: the cipher key as a non-negative integer; only its low
                128 bits are read, most significant byte first.
            b: forwarded unchanged to ``Trace`` on every ``encrypt`` call.
            noise: forwarded unchanged to ``Trace`` on every ``encrypt`` call.
        """
        self.key_schedule(master_key)
        self.b = b
        self.noise = noise

    def key_schedule(self, master_key):
        """Expand *master_key* into 44 four-byte words stored in ``self.round_keys``.

        Words ``4*r .. 4*r+3`` are the round key for round ``r`` (0..10), as
        consumed by ``encrypt``.
        """
        self.round_keys = self.key2matrix(master_key)
        for i in range(4, 4 * 11):
            prev = self.round_keys[i - 1]
            back = self.round_keys[i - 4]
            if i % 4 == 0:
                # RotWord + SubWord on the previous word, then the round
                # constant on the first byte.  Assumes Rcon[k] is the constant
                # for the k-th expansion step (Rcon[0] unused) -- confirm.
                word = [back[j] ^ Sbox[prev[(j + 1) % 4]] for j in range(4)]
                word[0] ^= Rcon[i // 4]
            else:
                word = [back[j] ^ prev[j] for j in range(4)]
            self.round_keys.append(word)

    def key2matrix(self, key):
        """Split the low 128 bits of integer *key* into four lists of four bytes.

        Bytes are taken most significant first, so ``matrix[w][k]`` is byte
        ``4*w + k`` of the big-endian 16-byte encoding.
        """
        return [
            [(key >> (8 * (15 - (4 * w + k)))) & 0xFF for k in range(4)]
            for w in range(4)
        ]

    def encrypt(self, plaintext):
        """Encrypt one 16-value block and return ``(ciphertext, trace)``.

        Args:
            plaintext: a NumPy array holding 16 values (presumably byte values
                0..255, since they index ``Sbox``).  It is NOT modified.

        Returns:
            A tuple of the ciphertext as a flat array (same dtype as the input)
            and the ``Trace`` object that observed this encryption.
        """
        self.trace = Trace(self.b, self.noise)
        # reshape() returns a view whenever possible and every step below works
        # in place, so copy to avoid overwriting the caller's plaintext array.
        self.plain_state = plaintext.reshape(4, 4).copy()
        state = self.plain_state
        self.__add_round_key(state, self.round_keys[:4])
        for rnd in range(1, 10):
            self.__round_encrypt(state, self.round_keys[4 * rnd:4 * (rnd + 1)])
        # The final round omits MixColumns.
        self.__sub_bytes(state)
        self.__shift_rows(state)
        self.__add_round_key(state, self.round_keys[40:])
        return state.flatten(), self.trace

    def __add_round_key(self, s, k):
        """XOR the 4x4 round key *k* into state *s* in place."""
        self.trace.AddRoundKey(s)
        for i in range(4):
            for j in range(4):
                s[i][j] ^= k[i][j]

    def __round_encrypt(self, state_matrix, key_matrix):
        """Apply one full middle round (with MixColumns) to the state in place."""
        self.__sub_bytes(state_matrix)
        self.__shift_rows(state_matrix)
        self.__mix_columns(state_matrix)
        self.__add_round_key(state_matrix, key_matrix)

    def __sub_bytes(self, s):
        """Replace every state byte with its ``Sbox`` entry, in place."""
        self.trace.SubBytes(s)
        for i in range(4):
            for j in range(4):
                s[i][j] = Sbox[s[i][j]]

    def __shift_rows(self, s):
        """Rotate row ``r`` of the state left by ``r`` columns, in place."""
        self.trace.ShiftRows(s)
        for row in range(1, 4):
            # Snapshot the row first so the rotation reads only old values.
            values = [s[col][row] for col in range(4)]
            for col in range(4):
                s[col][row] = values[(col + row) % 4]

    @staticmethod
    def _xtime(value):
        """Multiply *value* by 2 in GF(2^8), reducing with 0x1B on overflow."""
        return (((value << 1) ^ 0x1B) & 0xFF) if (value & 0x80) else (value << 1)

    def __mix_single_column(self, a):
        """Apply the MixColumns transform to the 4-byte column *a* in place."""
        xtime = self._xtime
        t = a[0] ^ a[1] ^ a[2] ^ a[3]
        u = a[0]  # a[0] is overwritten first but still needed for a[3].
        a[0] ^= t ^ xtime(a[0] ^ a[1])
        a[1] ^= t ^ xtime(a[1] ^ a[2])
        a[2] ^= t ^ xtime(a[2] ^ a[3])
        a[3] ^= t ^ xtime(a[3] ^ u)

    def __mix_columns(self, s):
        """Apply MixColumns to each of the four state columns, in place."""
        self.trace.MixColumns(s)
        for i in range(4):
            self.__mix_single_column(s[i])