def exp2_fx(a, zero_output=False):
    """
    Power of two for fixed-point numbers.

    The result is formed as a product ``d * e``:

    - ``d`` is a power of two obtained from (some of) the integer bits of
      ``a``.
    - ``e`` is the polynomial ``p_1045`` evaluated at the fractional part
      of ``a``.

    In ring mode (``types.program.options.ring``), only the lowest
    ``n_int_bits = ceil(log2(a.k - a.f))`` bits of the integer part are
    used. The bit decomposition of ``a`` is obtained in one of two ways:

    - by splitting the share into summands (``use_split()``), or
    - by masking with random bits/edaBits and revealing the masked value.

    Outside ring mode, the absolute value of ``a`` is used and the result
    is inverted for negative inputs.

    :param a: exponent for :math:`2^a` (sfix)
    :param zero_output: whether to output zero for very small values.
        If not, the result will be undefined. Only supported in ring mode
        without ``use_split()``; asserted otherwise.

    :return: :math:`2^a` if it is within the range. Undefined otherwise
    """
    if types.program.options.ring:
        sint = types.sint
        # NOTE(review): intbitint is not used anywhere in this function
        intbitint = types.intbitint
        # how many bits to use from integer part
        n_int_bits = int(math.ceil(math.log(a.k - a.f, 2)))
        n_bits = a.f + n_int_bits
        # ring bits above the k-bit value. The masked difference is shifted
        # left by this amount before revealing and shifted back afterwards,
        # so that only its low a.k bits are decomposed below.
        n_shift = int(types.program.options.ring) - a.k
        if types.program.use_split():
            # Bits of a are computed by adding up summands of a split share
            # with a binary adder, instead of by masking.
            assert not zero_output
            from Compiler.GC.types import sbitvec
            if types.program.use_split() == 3:
                x = a.v.split_to_two_summands(a.k)
                bits = types._bitint.carry_lookahead_adder(x[0], x[1],
                                                           fewer_inv=False)
                # converting MSB first reduces the number of rounds
                s = sint.conv(bits[-1])
                # x0 ^ x1 ^ sum at position a.f recovers the carry into that
                # position. Its combination with x[0][a.f], times 2**a.f, is
                # removed from raw_mod2m so that lower keeps the fractional
                # bits. NOTE(review): confirm this derivation.
                lower_overflow = sint.conv(x[0][a.f]) + \
                    sint.conv(x[0][a.f] ^ x[1][a.f] ^ bits[a.f])
                lower = a.v.raw_mod2m(a.f) - (lower_overflow << a.f)
            elif types.program.use_split() == 4:
                # four summands: two Wallace-tree reductions, then a
                # carry-lookahead adder for the final two rows
                x = list(zip(*a.v.split_to_n_summands(a.k, 4)))
                bi = types._bitint
                red = bi.wallace_reduction
                sums1, carries1 = red(*x[:3], get_carry=False)
                sums2, carries2 = red(x[3], sums1, carries1, False)
                bits = bi.carry_lookahead_adder(sums2, carries2,
                                                fewer_inv=False)
                # The carries at position a.f (presumably those crossing
                # from the fractional into the integer part) are added with
                # a full adder, giving a two-bit overflow count.
                overflows = bi.full_adder(carries1[a.f], carries2[a.f],
                                          bits[a.f] ^ sums2[a.f] ^ carries2[a.f])
                # conversion in reversed (MSB-first) order, order restored
                overflows = reversed(list((sint.conv(x)
                                           for x in reversed(overflows))))
                lower_overflow = sint.bit_compose(sint.conv(x)
                                                  for x in overflows)
                s = sint.conv(bits[-1])
                lower = a.v.raw_mod2m(a.f) - (lower_overflow << a.f)
            else:
                # direct conversion of the share to a bit vector
                bits = sbitvec(a.v, a.k)
                s = sint.conv(bits[-1])
                lower = sint.bit_compose(sint.conv(b) for b in bits[:a.f])
            # the n_int_bits lowest bits of the integer part
            higher_bits = bits[a.f:n_bits]
        else:
            # Mask a with a random r whose bits are available in shared
            # form, reveal a - r, and rebuild the needed bits of a by
            # adding the revealed bits and the bits of r again.
            if types.program.use_edabit():
                l = sint.get_edabit(a.f, True)
                u = sint.get_edabit(a.k - a.f, True)
                r_bits = l[1] + u[1]
                r = l[0] + (u[0] << a.f)
                lower_r = l[0]
            else:
                r_bits = [sint.get_random_bit() for i in range(a.k)]
                r = sint.bit_compose(r_bits)
                lower_r = sint.bit_compose(r_bits[:a.f])
            shifted = ((a.v - r) << n_shift).reveal()
            masked_bits = (shifted >> n_shift).bit_decompose(a.k)
            # carry out of adding the fractional bits of the masked value
            # and of r (both slices run from bit a.f-1 down to bit 0)
            lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1],
                                                    r_bits[a.f-1::-1])
            lower_masked = sint.bit_compose(masked_bits[:a.f])
            # fractional part of a: sum of both fractional parts minus the
            # carry that moved into the integer part
            lower = lower_r + lower_masked - \
                (sint.conv(lower_overflow) << (a.f))
            # Integer bits a.f .. n_bits-1 of a. With get_carry=True the
            # last element is the carry out, which is deleted further below.
            higher_bits = r_bits[0].bit_adder(r_bits[a.f:n_bits],
                                              masked_bits[a.f:n_bits],
                                              carry_in=lower_overflow,
                                              get_carry=True)
            # carry into the most significant bit, used for the sign below
            carry = comparison.CarryOutLE(masked_bits[n_bits:-1],
                                          r_bits[n_bits:-1],
                                          higher_bits[-1])
            if zero_output:
                # should be for free
                highest_bits = r_bits[0].ripple_carry_adder(
                    masked_bits[n_bits:-1], [0] * (a.k - n_bits),
                    carry_in=higher_bits[-1])
                bits_to_check = [x.bit_xor(y) for x, y in zip(highest_bits[:-1],
                                                              r_bits[n_bits:-1])]
                # t is the AND of all bits_to_check. Where t is 0, the
                # negative-input result is replaced by 0 further below.
                # NOTE(review): presumably t flags that the unused upper
                # integer bits of a are all ones (a not far below zero).
                t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y),
                                                 bits_to_check))
            # sign
            s = carry.bit_xor(sint.conv(r_bits[-1])).bit_xor(masked_bits[-1])
            del higher_bits[-1]
        # fractional part of a as sfix
        c = types.sfix._new(lower, k=a.k, f=a.f)
        assert(len(higher_bits) == n_bits - a.f)
        pow2_bits = [sint.conv(x) for x in higher_bits]
        # d = 2 ** (value of the selected integer bits)
        d = floatingpoint.Pow2_from_bits(pow2_bits)
        # presumably e approximates 2 ** c via the polynomial p_1045
        e = p_eval(p_1045, c)
        g = d * e
        # For negative a, the n_int_bits low integer bits (two's complement)
        # encode floor(a) + 2 ** n_int_bits, so g is too large by a factor
        # of 2 ** (2 ** n_int_bits). The rounding presumably shifts that
        # many bits away.
        small_result = types.sfix._new(g.v.round(a.f + 2 ** n_int_bits,
                                                 2 ** n_int_bits, signed=False,
                                                 nearest=types.sfix.round_nearest),
                                       k=a.k, f=a.f)
        if zero_output:
            small_result = t.if_else(small_result, 0)
        # s is the sign bit of a
        return s.if_else(small_result, g)
    else:
        assert not zero_output
        # obtain absolute value of a
        s = a < 0
        a = s.if_else(-a, a)
        # isolates fractional part of number
        b = trunc(a)
        c = a - b
        # 2 to the power of the integer part b (via pow2, not a square)
        d = b.pow2(a.k - a.f)
        # evaluates fractional part of a in p_1045
        e = p_eval(p_1045, c)
        g = d * e
        # 2 ** -|a| = 1 / 2 ** |a| for negative inputs
        return s.if_else(1 / g, g)
def sha3_256(x):
    """
    This function implements SHA3-256 for inputs of up to 1080 bits::

      from circuit import sha3_256
      a = sbitvec.from_vec([])
      b = sbitvec(sint(0xcc), 8)
      for x in a, b:
          sha3_256(x).elements()[0].reveal().print_reg()

    This should output the SHA3-256 digests of the empty message and of
    the single byte ``0xcc`` in byte-reversed order::

      0x4a43f8804b0ad882fa493be44dff80f562d661a05647c15166d71ebff8c6ffa7
      0xf0d7aa0ab2d92d580bb080e17cbb52627932ba37f085d3931270d31c39357067

    Note that :py:obj:`sint` to :py:obj:`sbitvec` conversion is only
    implemented for computation modulo a power of two.

    Only a single block is processed. The message, the ``0x06`` padding
    byte and the zero padding must fit into the rate of ``r = 1088``
    bits, hence the 1080-bit limit.

    :param x: message as sbitvec; its length must be a multiple of 8 and
        at most 1080 bits
    :return: 256-bit digest as sbitvec
    """
    global Keccak_f
    if Keccak_f is None:
        # only one instance: the circuit is loaded once and then reused
        Keccak_f = Circuit('Keccak_f')

    # whole bytes
    assert len(x.v) % 8 == 0
    # only one block (rate r = 1600 - 2 * 256 bits for SHA3-256)
    r = 1088
    assert len(x.v) < 1088
    # number of parallel instances carried by each bit
    if x.v:
        n = x.v[0].n
    else:
        n = 1
    w = 64              # lane width in bits
    n_lanes = r // w    # number of lanes covered by the rate
    sbn = sbits.get_type(n)

    # SHA-3 padding: append the byte 0x06, zero-fill up to r bits, then
    # flip the very last bit (the final padding bit).
    d = sbitvec([sbits.get_type(8)(0x06)] * n)
    padding = [sbn(0)] * (r - 8 - len(x.v))
    P_flat = x.v + d.v + padding
    assert len(P_flat) == r
    P_flat[-1] = ~P_flat[-1]
    P1 = [P_flat[i * w:(i + 1) * w] for i in range(n_lanes)]

    # absorb: state lane S[lx][ly] receives message lane lx + 5 * ly
    S = [[[sbn(0) for i in range(w)] for i in range(5)] for i in range(5)]
    for lx in range(5):
        for ly in range(5):
            if lx + 5 * ly < n_lanes:
                for i in range(w):
                    S[lx][ly][i] ^= P1[lx + 5 * ly][i]

    def position(lx, ly, i):
        # Index into the circuit's flat 1600-bit vector for bit i of lane
        # (lx, ly): lanes in order lx + 5 * ly, bit order reversed within
        # each byte, and the whole vector reversed.
        j = (5 * ly + lx) * w + i // 8 * 8 + 7 - i % 8
        return 1600 - 1 - j

    def flatten(S):
        # 5x5xw state -> flat wire list for Keccak_f
        res = [None] * 1600
        for ly in range(5):
            for lx in range(5):
                for i in range(w):
                    res[position(lx, ly, i)] = S[lx][ly][i]
        return res

    def unflatten(S_flat):
        # flat wire list from Keccak_f -> 5x5xw state (inverse of flatten)
        res = [[[None] * w for j in range(5)] for i in range(5)]
        for ly in range(5):
            for lx in range(5):
                for i in range(w):
                    res[lx][ly][i] = S_flat[position(lx, ly, i)]
        return res

    S = unflatten(Keccak_f(flatten(S)))

    # Squeeze: read the rate lanes in the same order (lx + 5 * ly) in which
    # they were absorbed. The indexing must match the absorb step
    # (S[lx][ly]); reading S[ly][lx] would transpose the state.
    Z = []
    while len(Z) <= 256:
        for ly in range(5):
            for lx in range(5):
                if lx + 5 * ly < n_lanes:
                    Z += S[lx][ly]
        if len(Z) <= 256:
            S = unflatten(Keccak_f(flatten(S)))
    return sbitvec.from_vec(Z[:256])