コード例 #1
0
def exp2_fx(a, zero_output=False):
    """
    Power of two for fixed-point numbers.

    :param a: exponent for :math:`2^a` (sfix)
    :param zero_output: whether to output zero for very small values. If not, the result will be undefined.

    :return: :math:`2^a` if it is within the range. Undefined otherwise

    The exponent is split into a fractional part (the low ``a.f`` bits of
    ``a``) and an integer part. Two to the power of the integer part is
    built from its bits with ``floatingpoint.Pow2_from_bits``. Two to the
    power of the fractional part is approximated with the polynomial
    ``p_1045``. The result is the product of the two.

    Ring mode (``program.options.ring``):

    * Only the lowest ``n_int_bits = ceil(log2(a.k - a.f))`` integer bits
      are used. For an in-range non-negative ``a`` the integer part is below
      ``a.k - a.f <= 2**n_int_bits``.
    * For negative ``a`` those bits encode ``floor(a) + 2**n_int_bits``, so
      the product is scaled down by ``2**(2**n_int_bits)``. The scaled value
      is selected by the sign bit.

    Otherwise ``|a|`` is exponentiated and the reciprocal is returned for
    negative ``a``.

    ``zero_output`` is only supported in ring mode without split sharing.
    The other paths assert that it is not set.
    """
    if types.program.options.ring:
        sint = types.sint
        # how many bits to use from integer part: ceil(log2(a.k - a.f)).
        # Integer arithmetic keeps this exact (assumes a.k > a.f); the
        # float-based math.log can land just above an exact power of two,
        # which would make math.ceil overshoot by one.
        n_int_bits = (a.k - a.f - 1).bit_length()
        # bits a.f .. n_bits-1 of a are the integer bits that get used
        n_bits = a.f + n_int_bits
        # distance between a k-bit value and the top of the ring
        n_shift = int(types.program.options.ring) - a.k
        if types.program.use_split():
            # split-sharing path: obtain the bits of a.v directly
            assert not zero_output
            from Compiler.GC.types import sbitvec
            if types.program.use_split() == 3:
                # a.v as two summands, added with a carry-lookahead adder
                x = a.v.split_to_two_summands(a.k)
                bits = types._bitint.carry_lookahead_adder(x[0], x[1],
                                                           fewer_inv=False)
                # converting MSB first reduces the number of rounds
                s = sint.conv(bits[-1])
                # x0 ^ x1 ^ sum at position a.f equals the carry into that
                # position; the resulting term corrects raw_mod2m(a.f)
                # (presumably raw_mod2m does not fully reduce split
                # shares -- verify against its definition)
                lower_overflow = sint.conv(x[0][a.f]) + \
                                 sint.conv(x[0][a.f] ^ x[1][a.f] ^ bits[a.f])
                lower = a.v.raw_mod2m(a.f) - (lower_overflow << a.f)
            elif types.program.use_split() == 4:
                # four summands: two Wallace reductions, then one
                # carry-lookahead addition
                x = list(zip(*a.v.split_to_n_summands(a.k, 4)))
                bi = types._bitint
                red = bi.wallace_reduction
                sums1, carries1 = red(*x[:3], get_carry=False)
                sums2, carries2 = red(x[3], sums1, carries1, False)
                bits = bi.carry_lookahead_adder(sums2, carries2,
                                                fewer_inv=False)
                overflows = bi.full_adder(carries1[a.f], carries2[a.f],
                                        bits[a.f] ^ sums2[a.f] ^ carries2[a.f])
                # convert in reversed order, then restore the original
                # order for bit_compose
                overflows = reversed(list((sint.conv(x)
                                           for x in reversed(overflows))))
                lower_overflow = sint.bit_compose(sint.conv(x)
                                                  for x in overflows)
                s = sint.conv(bits[-1])
                lower = a.v.raw_mod2m(a.f) - (lower_overflow << a.f)
            else:
                bits = sbitvec(a.v, a.k)
                s = sint.conv(bits[-1])
                lower = sint.bit_compose(sint.conv(b) for b in bits[:a.f])
            higher_bits = bits[a.f:n_bits]
        else:
            # masked-reveal path: open a - r for a random r whose bits are
            # known in shared form, then a = masked + r (mod 2**a.k)
            if types.program.use_edabit():
                l = sint.get_edabit(a.f, True)
                u = sint.get_edabit(a.k - a.f, True)
                r_bits = l[1] + u[1]
                r = l[0] + (u[0] << a.f)
                lower_r = l[0]
            else:
                r_bits = [sint.get_random_bit() for i in range(a.k)]
                r = sint.bit_compose(r_bits)
                lower_r = sint.bit_compose(r_bits[:a.f])
            # shifting up by n_shift and back down keeps only the low a.k
            # bits of the revealed a - r
            shifted = ((a.v - r) << n_shift).reveal()
            masked_bits = (shifted >> n_shift).bit_decompose(a.k)
            # carry out of the low a.f bits of masked + r (bits passed
            # MSB first)
            lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1],
                                r_bits[a.f-1::-1])
            lower_masked = sint.bit_compose(masked_bits[:a.f])
            # (masked + r) mod 2**a.f, i.e. the fractional bits of a
            lower = lower_r + lower_masked - \
                    (sint.conv(lower_overflow) << (a.f))
            # bits a.f .. n_bits-1 of a; the last element is the carry out
            higher_bits = r_bits[0].bit_adder(r_bits[a.f:n_bits],
                                              masked_bits[a.f:n_bits],
                                              carry_in=lower_overflow,
                                              get_carry=True)
            # carry out of positions n_bits .. a.k-2 of masked + r, i.e.
            # the carry into the most significant bit
            carry = comparison.CarryOutLE(masked_bits[n_bits:-1],
                                          r_bits[n_bits:-1],
                                          higher_bits[-1])
            if zero_output:
                # should be for free (presumably because masked_bits are
                # public and the second operand is the constant 0)
                highest_bits = r_bits[0].ripple_carry_adder(
                    masked_bits[n_bits:-1], [0] * (a.k - n_bits),
                    carry_in=higher_bits[-1])
                # While all bits checked so far are 1, the carry chain of
                # masked + 0 matches that of masked + r, so XOR with r_bits
                # yields the bits of a. Hence the AND below is 1 iff every
                # bit of a at positions n_bits .. a.k-2 is 1.
                bits_to_check = [x.bit_xor(y)
                                 for x, y in zip(highest_bits[:-1],
                                                 r_bits[n_bits:-1])]
                t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y),
                                                 bits_to_check))
            # sign: most significant bit of masked + r
            s = carry.bit_xor(sint.conv(r_bits[-1])).bit_xor(masked_bits[-1])
            # drop the carry out; keep only the integer bits
            del higher_bits[-1]
        # fractional part of a as a fixed-point number
        c = types.sfix._new(lower, k=a.k, f=a.f)
        assert len(higher_bits) == n_bits - a.f
        pow2_bits = [sint.conv(x) for x in higher_bits]
        # two to the power of the integer value encoded by pow2_bits
        d = floatingpoint.Pow2_from_bits(pow2_bits)
        # evaluates fractional part of a in p_1045
        e = p_eval(p_1045, c)
        g = d * e
        # For negative a with floor(a) >= -2**n_int_bits the integer bits
        # encode floor(a) + 2**n_int_bits, so g is too large by a factor of
        # 2**(2**n_int_bits); drop that many low bits (assumes
        # round(k, m) truncates m bits).
        small_result = types.sfix._new(g.v.round(a.f + 2 ** n_int_bits,
                                            2 ** n_int_bits, signed=False,
                                            nearest=types.sfix.round_nearest),
                                       k=a.k, f=a.f)
        if zero_output:
            # negative exponents too small for n_int_bits yield 0
            small_result = t.if_else(small_result, 0)
        return s.if_else(small_result, g)
    else:
        assert not zero_output
        # obtain absolute value of a
        s = a < 0
        a = s.if_else(-a, a)
        # isolates fractional part of number
        b = trunc(a)
        c = a - b
        # two to the power of the integer part of |a| (a.k - a.f is passed
        # as the bound on the integer part -- verify against pow2)
        d = b.pow2(a.k - a.f)
        # evaluates fractional part of a in p_1045
        e = p_eval(p_1045, c)
        g = d * e
        # 2**a == 1 / 2**|a| for negative a
        return s.if_else(1 / g, g)
コード例 #2
0
def exp2_fx(a):
    """
    Power of two for fixed-point numbers.

    :param a: exponent for :math:`2^a` (sfix)

    :return: :math:`2^a` if it is within the range. Undefined otherwise

    The exponent is split into a fractional part (the low ``a.f`` bits of
    ``a``) and an integer part. Two to the power of the integer part is
    built from its bits with ``floatingpoint.Pow2_from_bits``. Two to the
    power of the fractional part is approximated with the polynomial
    ``p_1045``. The result is the product of the two.

    Ring mode (``program.options.ring``):

    * Only the lowest ``n_int_bits = ceil(log2(a.k - a.f))`` integer bits
      are used.
    * For negative ``a`` the product is scaled down by
      ``2**(2**n_int_bits)``.
    * Negative exponents whose high bits (positions ``n_bits .. a.k-2``)
      are not all 1 produce 0.

    Otherwise ``|a|`` is exponentiated and the reciprocal is returned for
    negative ``a``.
    """
    if types.program.options.ring:
        sint = types.sint
        # how many bits to use from integer part: ceil(log2(a.k - a.f)).
        # Integer arithmetic keeps this exact (assumes a.k > a.f); the
        # float-based math.log can land just above an exact power of two,
        # which would make math.ceil overshoot by one.
        n_int_bits = (a.k - a.f - 1).bit_length()
        # bits a.f .. n_bits-1 of a are the integer bits that get used
        n_bits = a.f + n_int_bits
        # distance between a k-bit value and the top of the ring
        n_shift = int(types.program.options.ring) - a.k
        # random mask r whose bits are known in shared form
        if types.program.use_edabit():
            l = sint.get_edabit(a.f, True)
            u = sint.get_edabit(a.k - a.f, True)
            r_bits = l[1] + u[1]
            r = l[0] + (u[0] << a.f)
            lower_r = l[0]
        else:
            r_bits = [sint.get_random_bit() for i in range(a.k)]
            r = sint.bit_compose(r_bits)
            lower_r = sint.bit_compose(r_bits[:a.f])
        # reveal a - r; shifting up by n_shift and back down keeps only the
        # low a.k bits, so a = masked + r (mod 2**a.k)
        shifted = ((a.v - r) << n_shift).reveal()
        masked_bits = (shifted >> n_shift).bit_decompose(a.k)
        # carry out of the low a.f bits of masked + r (bits passed MSB first)
        lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1],
                            r_bits[a.f-1::-1])
        lower_masked = sint.bit_compose(masked_bits[:a.f])
        # (masked + r) mod 2**a.f, i.e. the fractional bits of a
        lower = lower_r + lower_masked - (sint.conv(lower_overflow) << (a.f))
        c = types.sfix._new(lower, k=a.k, f=a.f)
        # bits a.f .. n_bits-1 of a; the last element is the carry out
        higher_bits = r_bits[0].bit_adder(r_bits[a.f:n_bits],
                                          masked_bits[a.f:n_bits],
                                          carry_in=lower_overflow,
                                          get_carry=True)
        assert len(higher_bits) == n_bits - a.f + 1
        pow2_bits = [sint.conv(x) for x in higher_bits]
        # two to the power of the integer bits (carry out excluded)
        d = floatingpoint.Pow2_from_bits(pow2_bits[:-1])
        # evaluates fractional part of a in p_1045
        e = p_eval(p_1045, c)
        g = d * e
        # For negative a the integer bits encode floor(a) + 2**n_int_bits,
        # so g is too large by 2**(2**n_int_bits); drop that many low bits
        # (assumes round(k, m) truncates m bits).
        small_result = types.sfix._new(g.v.round(a.f + 2 ** n_int_bits,
                                            2 ** n_int_bits, signed=False,
                                            nearest=types.sfix.round_nearest),
                                       k=a.k, f=a.f)
        # carry out of positions n_bits .. a.k-2 of masked + r, i.e. the
        # carry into the most significant bit
        carry = comparison.CarryOutLE(masked_bits[n_bits:-1],
                                      r_bits[n_bits:-1],
                                      higher_bits[-1])
        # should be for free (presumably because masked_bits are public and
        # the second operand is the constant 0)
        highest_bits = r_bits[0].ripple_carry_adder(
            masked_bits[n_bits:-1], [0] * (a.k - n_bits),
            carry_in=higher_bits[-1])
        # While all bits checked so far are 1, the carry chain of
        # masked + 0 matches that of masked + r, so XOR with r_bits yields
        # the bits of a. Hence t is 1 iff every bit of a at positions
        # n_bits .. a.k-2 is 1.
        bits_to_check = [x.bit_xor(y)
                         for x, y in zip(highest_bits[:-1], r_bits[n_bits:-1])]
        t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y),
                                        bits_to_check))
        # sign: most significant bit of masked + r
        s = carry.bit_xor(sint.conv(r_bits[-1])).bit_xor(masked_bits[-1])
        return s.if_else(t.if_else(small_result, 0), g)
    else:
        # obtain absolute value of a
        s = a < 0
        # multiply by -1 when s == 1 and by 1 when s == 0
        a = (s * (-2) + 1) * a
        # isolates fractional part of number
        b = trunc(a)
        c = a - b
        # two to the power of the integer part of |a| (a.k - a.f is passed
        # as the bound on the integer part -- verify against pow2)
        d = b.pow2(a.k - a.f)
        # evaluates fractional part of a in p_1045
        e = p_eval(p_1045, c)
        g = d * e
        # arithmetic select on the 0/1 bit s: g if a >= 0, 1 / g if a < 0
        return (1 - s) * g + s / g