コード例 #1
0
def exp2_fx(a, zero_output=False):
    """
    Power of two for fixed-point numbers.

    Two implementations are selected by the compiler options:

    * ring (``types.program.options.ring`` set): the secret exponent is
      decomposed into its ``a.f`` fractional bits, its lowest
      ``n_int_bits`` integer bits and its sign bit. This happens either via
      the split-share adders (``types.program.use_split()``) or via a
      masked reveal using random bits / edaBits. The result is
      ``2**(integer bits) * p_eval(p_1045, fraction)``, rescaled when the
      sign bit is set.
    * otherwise: computes ``2**|a|`` from the integer and fractional parts
      of ``|a|`` and returns its reciprocal for negative ``a``.

    :param a: exponent for :math:`2^a` (sfix)
    :param zero_output: whether to output zero for very small values. If
        not, the result will be undefined. Only supported in the ring path
        without split shares; asserted to be false otherwise.

    :return: :math:`2^a` if it is within the range. Undefined otherwise
    """
    if types.program.options.ring:
        sint = types.sint
        # NOTE(review): intbitint is assigned but never used in this function
        intbitint = types.intbitint
        # how many bits to use from integer part: ceil(log2(k - f));
        # presumably larger integer parts are outside the supported range
        n_int_bits = int(math.ceil(math.log(a.k - a.f, 2)))
        # fractional bits plus the integer bits that are actually used
        n_bits = a.f + n_int_bits
        # ring bit length minus a.k; used below to shift the masked value
        # left and back right, presumably discarding bits above a.k
        n_shift = int(types.program.options.ring) - a.k
        if types.program.use_split():
            # zero_output is only implemented in the non-split path below
            assert not zero_output
            from Compiler.GC.types import sbitvec
            if types.program.use_split() == 3:
                # a.v as two summands; their binary sum gives the bits of a
                x = a.v.split_to_two_summands(a.k)
                bits = types._bitint.carry_lookahead_adder(x[0], x[1],
                                                           fewer_inv=False)
                # converting MSB first reduces the number of rounds
                s = sint.conv(bits[-1])
                # x[0][a.f] ^ x[1][a.f] ^ bits[a.f] recovers the carry into
                # bit a.f; together with x[0][a.f] it corrects the result
                # of raw_mod2m(a.f) below
                # NOTE(review): exact correction depends on raw_mod2m
                # semantics, confirm
                lower_overflow = sint.conv(x[0][a.f]) + \
                                 sint.conv(x[0][a.f] ^ x[1][a.f] ^ bits[a.f])
                # lower a.f bits of a (fractional part)
                lower = a.v.raw_mod2m(a.f) - (lower_overflow << a.f)
            elif types.program.use_split() == 4:
                # a.v as four summands, transposed so that x has one entry
                # per summand (x[0] .. x[3] are used below)
                x = list(zip(*a.v.split_to_n_summands(a.k, 4)))
                bi = types._bitint
                red = bi.wallace_reduction
                # reduce three summands, then fold in the fourth, leaving
                # two operands for the final carry-lookahead addition
                sums1, carries1 = red(*x[:3], get_carry=False)
                sums2, carries2 = red(x[3], sums1, carries1, False)
                bits = bi.carry_lookahead_adder(sums2, carries2,
                                                fewer_inv=False)
                # combine the carries at position a.f into an overflow
                # count out of the lower a.f bits
                overflows = bi.full_adder(carries1[a.f], carries2[a.f],
                                        bits[a.f] ^ sums2[a.f] ^ carries2[a.f])
                # convert in reverse order (MSB first, as above) but
                # restore the original order for bit_compose
                overflows = reversed(list((sint.conv(x)
                                           for x in reversed(overflows))))
                lower_overflow = sint.bit_compose(sint.conv(x)
                                                  for x in overflows)
                # sign bit
                s = sint.conv(bits[-1])
                # lower a.f bits of a (fractional part)
                lower = a.v.raw_mod2m(a.f) - (lower_overflow << a.f)
            else:
                # direct bit decomposition of a.v into a.k bits
                bits = sbitvec(a.v, a.k)
                s = sint.conv(bits[-1])
                lower = sint.bit_compose(sint.conv(b) for b in bits[:a.f])
            # the n_int_bits lowest integer bits
            higher_bits = bits[a.f:n_bits]
        else:
            # random mask r together with its bits and its lower a.f bits
            if types.program.use_edabit():
                l = sint.get_edabit(a.f, True)
                u = sint.get_edabit(a.k - a.f, True)
                r_bits = l[1] + u[1]
                r = l[0] + (u[0] << a.f)
                lower_r = l[0]
            else:
                r_bits = [sint.get_random_bit() for i in range(a.k)]
                r = sint.bit_compose(r_bits)
                lower_r = sint.bit_compose(r_bits[:a.f])
            # reveal a - r, keeping only the lowest a.k bits of the
            # revealed value
            shifted = ((a.v - r) << n_shift).reveal()
            masked_bits = (shifted >> n_shift).bit_decompose(a.k)
            # carry out of adding the lower a.f bits of the revealed value
            # and of r (both passed MSB first via the reversed slices)
            lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1],
                                r_bits[a.f-1::-1])
            lower_masked = sint.bit_compose(masked_bits[:a.f])
            # lower a.f bits of a = r_low + masked_low - carry * 2^f
            lower = lower_r + lower_masked - \
                    (sint.conv(lower_overflow) << (a.f))
            # bits a.f .. n_bits-1 of a; with get_carry=True the last
            # element is the carry out, used just below and then deleted
            higher_bits = r_bits[0].bit_adder(r_bits[a.f:n_bits],
                                              masked_bits[a.f:n_bits],
                                              carry_in=lower_overflow,
                                              get_carry=True)
            # carry into the most significant bit
            carry = comparison.CarryOutLE(masked_bits[n_bits:-1],
                                          r_bits[n_bits:-1],
                                          higher_bits[-1])
            if zero_output:
                # should be for free
                highest_bits = r_bits[0].ripple_carry_adder(
                    masked_bits[n_bits:-1], [0] * (a.k - n_bits),
                    carry_in=higher_bits[-1])
                bits_to_check = [x.bit_xor(y)
                                 for x, y in zip(highest_bits[:-1],
                                                 r_bits[n_bits:-1])]
                # t is the AND of bits_to_check; t == 0 selects a zero
                # result below. NOTE(review): presumably t == 1 iff the
                # integer bits above n_bits indicate a in range, confirm
                t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y),
                                                 bits_to_check))
            # sign
            s = carry.bit_xor(sint.conv(r_bits[-1])).bit_xor(masked_bits[-1])
            # drop the carry bit, leaving exactly n_int_bits integer bits
            del higher_bits[-1]
        # fractional part of a as fixed-point number
        c = types.sfix._new(lower, k=a.k, f=a.f)
        assert(len(higher_bits) == n_bits - a.f)
        pow2_bits = [sint.conv(x) for x in higher_bits]
        # 2 to the power of the integer bits
        d = floatingpoint.Pow2_from_bits(pow2_bits)
        # presumably p_1045 approximates 2^x for the fractional part
        e = p_eval(p_1045, c)
        g = d * e
        # for negative a the integer bits are the low two's complement
        # bits, so d carries an extra factor 2^(2^n_int_bits); round()
        # presumably removes 2 ** n_int_bits bits to undo it
        small_result = types.sfix._new(g.v.round(a.f + 2 ** n_int_bits,
                                            2 ** n_int_bits, signed=False,
                                            nearest=types.sfix.round_nearest),
                                       k=a.k, f=a.f)
        if zero_output:
            small_result = t.if_else(small_result, 0)
        return s.if_else(small_result, g)
    else:
        assert not zero_output
        # obtain absolute value of a
        s = a < 0
        a = s.if_else(-a, a)
        # isolates fractional part of number
        b = trunc(a)
        c = a - b
        # computes 2 to the power of the integer part b
        d = b.pow2(a.k - a.f)
        # evaluates fractional part of a in p_1045
        e = p_eval(p_1045, c)
        g = d * e
        # 2^-|a| = 1 / 2^|a| for negative inputs
        return s.if_else(1 / g, g)
コード例 #2
0
def sha3_256(x):
    """
    This function implements SHA3-256 for inputs of up to 1080 bits::

        from circuit import sha3_256
        a = sbitvec.from_vec([])
        b = sbitvec(sint(0xcc), 8)
        for x in a, b:
            sha3_256(x).elements()[0].reveal().print_reg()

    This should output the first two test vectors of SHA3-256 in
    byte-reversed order::

        0x4a43f8804b0ad882fa493be44dff80f562d661a05647c15166d71ebff8c6ffa7
        0xf0d7aa0ab2d92d580bb080e17cbb52627932ba37f085d3931270d31c39357067

    (SHA3-256 of the empty string is ``a7ffc6f8...80f8434a`` and of the
    byte ``0xcc`` is ``67703539...0aaad7f0``.)

    Note that :py:obj:`sint` to :py:obj:`sbitvec` conversion is only
    implemented for computation modulo a power of two.

    :param x: message as :py:obj:`sbitvec`; a whole number of bytes and
        at most 1080 bits, so that it fits into a single padded block
    :return: the 256-bit digest as :py:obj:`sbitvec`
    """

    global Keccak_f
    if Keccak_f is None:
        # only one instance
        Keccak_f = Circuit('Keccak_f')

    # lane width, rate (1600 - 2 * 256 bits of capacity) and output length
    w = 64
    r = 1088
    n_rate_lanes = r // w
    out_bits = 256

    # whole bytes
    assert len(x.v) % 8 == 0
    # only one block: message plus the 8-bit padding byte must fit the rate
    assert len(x.v) < r
    # number of parallel instances carried by each bit vector
    if x.v:
        n = x.v[0].n
    else:
        n = 1
    # 0x06: SHA-3 domain bits 01 followed by the first bit of pad10*1
    d = sbitvec([sbits.get_type(8)(0x06)] * n)
    sbn = sbits.get_type(n)
    padding = [sbn(0)] * (r - 8 - len(x.v))
    P_flat = x.v + d.v + padding
    assert len(P_flat) == r
    # final 1 bit of pad10*1 at the end of the rate
    P_flat[-1] = ~P_flat[-1]
    P1 = [P_flat[i * w:(i + 1) * w] for i in range(n_rate_lanes)]

    # state: S[lx][ly] holds lane number lx + 5 * ly; absorb the single
    # padded block into the first n_rate_lanes lanes
    S = [[[sbn(0) for i in range(w)] for j in range(5)] for k in range(5)]
    for lx in range(5):
        for ly in range(5):
            lane = lx + 5 * ly
            if lane < n_rate_lanes:
                for i in range(w):
                    S[lx][ly][i] ^= P1[lane][i]

    # Convert between the S[lx][ly][i] layout and the flat 1600-bit vector
    # used as circuit input/output; bits are reversed within each byte and
    # the whole vector is reversed (the circuit's wire order)
    def flatten(S):
        res = [None] * 1600
        for y in range(5):
            for x in range(5):
                for i in range(w):
                    j = (5 * y + x) * w + i // 8 * 8 + 7 - i % 8
                    res[1600 - 1 - j] = S[x][y][i]
        return res

    def unflatten(S_flat):
        res = [[[None] * w for j in range(5)] for i in range(5)]
        for y in range(5):
            for x in range(5):
                for i in range(w):
                    j = (5 * y + x) * w + i // 8 * 8 + 7 - i % 8
                    res[x][y][i] = S_flat[1600 - 1 - j]
        return res

    S = unflatten(Keccak_f(flatten(S)))

    # squeeze: output the rate lanes in increasing lane number, using the
    # same S[lx][ly] <-> lx + 5 * ly mapping as during absorption
    Z = []
    while len(Z) < out_bits:
        for ly in range(5):
            for lx in range(5):
                if lx + 5 * ly < n_rate_lanes:
                    Z += S[lx][ly]
        if len(Z) < out_bits:
            S = unflatten(Keccak_f(flatten(S)))
    return sbitvec.from_vec(Z[:out_bits])