예제 #1
0
def Norm(b, k, f, kappa):
    """
        Computes secret integer values [c] and [v_prime] st.
        2^{k-1} <= c < 2^k and c = b*v_prime

        b: secret integer whose absolute value fits in k bits
        k: bit length used to decompose |b|
        f: not used by this implementation (kept for interface compatibility)
        kappa: statistical security parameter forwarded to LTZ

        Returns (c, v_prime) with v_prime = sign(b) * 2^{k-1-m}, where m is
        the index of the most significant set bit of |b|. For b == 0 no bit
        is set, so both outputs are 0.
    """
    temp = types.sint()
    comparison.LTZ(temp, b, k, kappa)
    sign = 1 - 2 * temp  # 1 - 2 * [b < 0]

    # x = |b|
    x = sign * b
    # bit_decompose yields the least significant bit first. Reversing before
    # and after the prefix OR turns it into a suffix OR: y[i] = OR(bits[i:]).
    bits = x.bit_decompose(k)[::-1]
    y = PreOR(bits)[::-1]

    # y is non-increasing, so z is one-hot at the most significant set bit.
    z = [0] * k
    for i in range(k - 1):
        z[i] = y[i] - y[i + 1]
    z[k - 1] = y[k - 1]

    # v = 2^{k-1-m}: the one-hot bit mirrored to the other end of the word.
    # Start from an initialized zero rather than an unset register.
    v = types.sint(0)
    for i in range(k):
        v += two_power(k - i - 1) * z[i]
    c = x * v
    v_prime = sign * v
    return c, v_prime
예제 #2
0
def Int2FL(a, gamma, l, kappa):
    """
    Convert a secret gamma-bit integer into floating-point components.

    Returns (v, p, z, s): the l-bit significand v, the exponent p, the zero
    flag z (from EQZ) and the sign bit s (from LTZ). p is forced to 0 when
    z is 1.
    """
    lam = gamma - 1
    s = types.sint()
    comparison.LTZ(s, a, gamma, kappa)
    z = EQZ(a, gamma, kappa)
    # From here on work with |a|.
    a = (1 - 2 * s) * a

    # Most significant bit first, so the prefix OR is 1 from the leading one on.
    msb_first = BitDec(a, lam, lam, kappa)
    msb_first.reverse()
    prefix_or = PreOR(msb_first, kappa)

    # 2^(number of leading zeros): shifts the leading one of |a| to the top.
    shift_factor = 1 + sum(2**idx * (1 - flag)
                           for idx, flag in enumerate(prefix_or))
    t = a * shift_factor
    p = -(lam - sum(prefix_or))

    if lam <= l:
        v = 2**(l - gamma + 1) * t
    elif types.sfloat.round_nearest:
        v, overflow = TruncRoundNearestAdjustOverflow(
            t, gamma - 1, l, kappa)
        p = p + overflow
    else:
        v = types.sint()
        comparison.Trunc(v, t, gamma - 1, gamma - l - 1, kappa, False)
        # TODO: Shouldnt this be only gamma

    p = (p + gamma - 1 - l) * (1 - z)
    return v, p, z, s
예제 #3
0
def norm_SQ(b, k):
    """
    Compute normalization helpers from the one-hot MSB indicator of b.

    z = MSB(b, k) is read up to index 2 * (k // 2). For even k that is
    index k, so z is assumed to be zero-padded to length k + 1 in that
    case -- NOTE(review): confirm against MSB.

    Returns (c, v, m, w):
      v = sum_i 2^(k-i-1) * z[i]   (bit-reversed position of the set bit)
      c = b * v
      m = sum_i (i + 1) * z[i]     (1-based index of the set bit)
      w = sum_j 2^j * w_array[j], w_array[0] = z[0] and
          w_array[j] = z[2j-1] + z[2j] for j >= 1
    """
    # calculation of z
    # x in order 0 - k
    z = MSB(b, k)
    # now reverse bits of z[i] to generate v
    v = types.sint(0)
    for i in range(k):
        v += (2**(k - i - 1)) * z[i]
    c = b * v

    # construct m
    m = types.sint(0)
    for i in range(k):
        m = m + (i + 1) * z[i]

    # construct w, changes from what is on the paper
    # and the documentation.
    # Floor division keeps k_over_2 an int on Python 2 and 3 alike; with
    # k / 2 on Python 3 the list sizing and range() below would fail.
    k_over_2 = k // 2 + 1
    w_array = [0] * k_over_2
    w_array[0] = z[0]
    for i in range(1, k_over_2):
        w_array[i] = z[2 * i - 1] + z[2 * i]

    w = types.sint(0)
    for i in range(k_over_2):
        w += (2**i) * w_array[i]

    # return computed values
    return c, v, m, w
예제 #4
0
def sint_cint_division(a, b, k, f, kappa):
    """
        type(a) = sint, type(b) = cint

        Fixed-point division of a secret a by a clear b, both with f
        fractional bits (2.0 is encoded as 2 * 2^f below). The numerator
        A and denominator B are repeatedly multiplied by the correction
        factor W = 2 - B, starting from approximate_reciprocal(|b|).
        theta = ceil(log2(k / 3.5)) iterations are performed; presumably k
        is large enough that theta >= 1 (the arrays need length >= 1).
        The result is sign(a) * sign(b) * A[theta - 1].
    """
    from types import cint, sint, Array
    from library import for_range

    theta = int(ceil(log(k / 3.5) / log(2)))
    # Fixed-point constant 2.0 with f fractional bits.
    two = cint(2) * AdvInteger.two_power(f)
    sign_b = cint(1) - 2 * cint(b < 0)
    sign_a = sint(1) - 2 * sint(a < 0)
    absolute_b = b * sign_b
    absolute_a = a * sign_a
    w0 = approximate_reciprocal(absolute_b, k, f, theta)

    # Per-iteration storage, since for_range emits a runtime loop.
    A = Array(theta, sint)
    B = Array(theta, cint)
    W = Array(theta, cint)

    A[0] = absolute_a
    B[0] = absolute_b
    W[0] = w0

    @for_range(1, theta)
    def block(i):
        # Secret numerator update with probabilistic truncation back to f bits.
        A[i] = AdvInteger.TruncPr(A[i - 1] * W[i - 1], 2 * k, f, kappa)
        # Clear denominator update; the shift removes the extra f bits.
        temp = (B[i - 1] * W[i - 1]) >> f
        # no reading and writing to the same variable in a for loop.
        W[i] = two - temp
        B[i] = temp

    return (sign_a * sign_b) * A[theta - 1]
예제 #5
0
def norm_simplified_SQ(b, k):
    """
    Simplified variant of norm_SQ that skips c and v.

    z = MSB(b, k) is read up to index 2 * (k // 2), so for even k it is
    assumed to be zero-padded to length k + 1 -- NOTE(review): confirm
    against MSB.

    Returns (m_odd, m, w):
      m_odd = sum of z[i] over even i, i.e. 1 when the 1-based index of
              the set bit is odd
      m     = sum_i (i + 1) * z[i]
      w     = sum_j 2^j * w_array[j], w_array[0] = z[0] and
              w_array[j] = z[2j-1] + z[2j] for j >= 1
    """
    z = MSB(b, k)
    # construct m
    m = types.sint(0)
    m_odd = 0
    for i in range(k):
        m = m + (i + 1) * z[i]
        # determine the parity of the input
        if i % 2 == 0:
            m_odd = m_odd + z[i]

    # construct w.
    # Floor division keeps k_over_2 an int on Python 2 and 3 alike.
    k_over_2 = k // 2 + 1
    w_array = [0] * k_over_2
    w_array[0] = z[0]
    for i in range(1, k_over_2):
        w_array[i] = z[2 * i - 1] + z[2 * i]

    # w aggregation
    w = types.sint(0)
    for i in range(k_over_2):
        w += (2**i) * w_array[i]

    # return computed values
    return m_odd, m, w
예제 #6
0
def Norm(b, k, f, kappa):
    """
        Computes secret integer values [c] and [v_prime] st.
        2^{k-1} <= c < 2^k and c = b*v_prime

        b: secret integer whose absolute value fits in k bits
        k: bit length used to decompose |b|
        f: not used by this implementation (kept for interface compatibility)
        kappa: statistical security parameter forwarded to LTZ

        Returns (c, v_prime) with v_prime = sign(b) * 2^{k-1-m}, where m is
        the index of the most significant set bit of |b|. For b == 0 no bit
        is set, so both outputs are 0.
    """
    temp = types.sint()
    comparison.LTZ(temp, b, k, kappa)
    sign = 1 - 2 * temp  # 1 - 2 * [b < 0]

    # x = |b|
    x = sign * b
    # bit_decompose yields the least significant bit first. Reversing before
    # and after the prefix OR turns it into a suffix OR: y[i] = OR(bits[i:]).
    bits = x.bit_decompose(k)[::-1]
    y = PreOR(bits)[::-1]

    # y is non-increasing, so z is one-hot at the most significant set bit.
    z = [0] * k
    for i in range(k - 1):
        z[i] = y[i] - y[i + 1]
    z[k - 1] = y[k - 1]

    # v = 2^{k-1-m}: the one-hot bit mirrored to the other end of the word.
    # Start from an initialized zero rather than an unset register.
    v = types.sint(0)
    for i in range(k):
        v += two_power(k - i - 1) * z[i]
    c = x * v
    v_prime = sign * v
    return c, v_prime
예제 #7
0
def SDiv(a, b, l, kappa, round_nearest=False):
    """
    Secret division a / b with l fractional bits, refined over
    theta = ceil(log2(l / 3.5)) iterations.

    All intermediate rescaling goes through .round(...) with the given
    round_nearest flag. The linear start value w = 2.9142 * 2^l - 2 * b
    presumably assumes b is already normalized (e.g. into [2^(l-1), 2^l))
    -- NOTE(review): confirm against callers.
    """
    theta = int(ceil(log(l / 3.5) / log(2)))
    alpha = two_power(2 * l)
    w = types.cint(int(2.9142 * two_power(l))) - 2 * b
    # x: error term 2^(2l) - b*w; y: first quotient approximation a*w.
    x = alpha - b * w
    y = a * w
    y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False)
    # Split x = x1 * 2^l + x2 with x2 = x mod 2^l.
    x2 = types.sint()
    comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False)
    x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True)
    for i in range(theta - 1):
        # y <- y * (2^l + x1) + (y * x2) / 2^l, then rescale.
        y = y * (x1 + two_power(l)) + (y * x2).round(
            2 * l, l, kappa, round_nearest, signed=False)
        y = y.round(2 * l + 1, l + 1, kappa, round_nearest, signed=False)
        # x <- (x / 2^l)^2 up to rounding, assembled from x1 and x2.
        x = x1 * x2 + (x2**2).round(
            2 * l + 1, l + 1, kappa, round_nearest, signed=False)
        x = x1 * x1 + x.round(
            2 * l + 1, l - 1, kappa, round_nearest, signed=False)
        x2 = types.sint()
        comparison.Mod2m(x2, x, 2 * l, l, kappa, False)
        x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True)
    # Final correction step followed by the signed default rounding.
    y = y * (x1 + two_power(l)) + (y * x2).round(
        2 * l, l, kappa, round_nearest, signed=False)
    y = y.round(2 * l + 1, l - 1, kappa, round_nearest)
    return y
예제 #8
0
파일: comparison.py 프로젝트: PuzzleEAA/eaa
def TruncRoundNearest(a, k, m, kappa):
    """
    Returns a / 2^m, rounded to the nearest integer.

    k: bit length of a
    m: compile-time integer

    Field version. Relies on the module-level const_rounds flag to choose
    between BitLTC1 and BitLTL for the bitwise comparison.
    """
    from types import sint, cint
    from library import reveal, load_int_to_secret
    if m == 1:
        # a + (a mod 2) is even, so the division is exact and rounds half up.
        lsb = sint()
        Mod2(lsb, a, k, kappa, False)
        return (a + lsb) / 2
    r_dprime = sint()
    r_prime = sint()
    r = [sint() for i in range(m)]
    u = sint()
    # r_prime is an m-bit random value with bits r; r_dprime masks the rest.
    PRandM(r_dprime, r_prime, r, k, m, kappa)
    # Open the masked value; the 2^(k-1) offset presumably keeps it
    # non-negative for signed a.
    c = reveal((cint(1) << (k - 1)) + a + (cint(1) << m) * r_dprime + r_prime)
    # Low m - 1 bits of the opened value.
    c_prime = c % (cint(1) << (m - 1))
    # u: comparison of c_prime with the low m - 1 random bits (presumably
    # u = [c_prime < r'_low], the borrow of the low part).
    if const_rounds:
        BitLTC1(u, c_prime, r[:-1], kappa)
    else:
        BitLTL(u, c_prime, r[:-1], kappa)
    # Bit m - 1 of the opened value.
    bit = ((c - c_prime) / (cint(1) << (m - 1))) % 2
    xor = bit + u - 2 * bit * u
    prod = xor * r[-1]
    # u_prime = xor * u + (1 - xor) * r[-1]
    u_prime = bit * u + u - 2 * bit * u + r[-1] - prod
    a_prime = (c % (cint(1) << m)) - r_prime + (cint(1) << m) * u_prime
    d = (a - a_prime) / (cint(1) << m)
    # xor XOR r[-1]: presumably bit m - 1 of a, which decides rounding up.
    rounding = xor + r[-1] - 2 * prod
    return d + rounding
예제 #9
0
def maskField(a, k, kappa):
    """
    Open a masked by PRandM randomness: a + 2^k * r'' + r'.

    Returns (opened, low_bits): the clear opened value and the k secret
    random bits that PRandM filled in for r'.
    """
    high = types.sint()
    low = types.sint()
    opened = types.cint()
    low_bits = [types.sint() for _ in range(k)]
    comparison.PRandM(high, low, low_bits, k, k, kappa)
    masked = a + two_power(k) * high + low
    asm_open(opened, masked)
    return opened, low_bits
예제 #10
0
def maskField(a, k, kappa):
    """Mask a with PRandM randomness, open it, and return (clear, bits)."""
    r_hi, r_lo, clear = types.sint(), types.sint(), types.cint()
    r_bits = list(types.sint() for _ in range(k))
    comparison.PRandM(r_hi, r_lo, r_bits, k, k, kappa)
    asm_open(clear, a + two_power(k) * r_hi + r_lo)
    return clear, r_bits
예제 #11
0
def EQZ(a, k, kappa):
    """
    Zero test of a k-bit secret through a masked opening.

    Each bit of the opened value is XORed with the matching PRandM random
    bit, and the result is 1 - KOR(...) of those XORs (presumably 1 exactly
    when a == 0).
    """
    mask_hi = types.sint()
    mask_lo = types.sint()
    opened = types.cint()
    diff_bits = [None] * k
    mask_bits = [types.sint() for _ in range(k)]
    comparison.PRandM(mask_hi, mask_lo, mask_bits, k, k, kappa)
    asm_open(opened, a + two_power(k) * mask_hi + mask_lo)
    for pos, clear_bit in enumerate(bits(opened, k)):
        r_bit = mask_bits[pos]
        # XOR of a public bit and a secret bit.
        diff_bits[pos] = clear_bit + r_bit - 2 * clear_bit * r_bit
    return 1 - KOR(diff_bits, kappa)
예제 #12
0
def Trunc(a, l, m, kappa, compute_modulo=False, signed=False):
    """ Oblivious truncation by secret m

    Returns a >> m for a secret a of bit length l. If m is a compile-time
    constant and compute_modulo is False, the cheaper comparison.Trunc is
    used (signed is only forwarded there). With compute_modulo=True the
    pair (a mod 2^m, 2^m) is returned instead.
    """
    if util.is_constant(m) and not compute_modulo:
        # cheaper
        res = type(a)(size=a.size)
        comparison.Trunc(res, a, l, m, kappa, signed=signed)
        return res
    if l == 1:
        # Single-bit case, m presumably a bit: 2^m = 1 + m,
        # a mod 2^m = a * m and a >> m = a * (1 - m).
        if compute_modulo:
            return a * m, 1 + m
        else:
            return a * (1 - m)
    if program.Program.prog.options.ring and not compute_modulo:
        return TruncInRing(a, l, Pow2(m, l, kappa))
    r = [types.sint() for i in range(l)]
    r_dprime = types.sint(0)
    r_prime = types.sint(0)
    rk = types.sint()
    c = types.cint()
    ci = [types.cint() for i in range(l)]
    d = types.sint()
    # x: presumably the unary encoding of m (see the assertion on sum(x));
    # pow2m = 2^m.
    x, pow2m = B2U(m, l, kappa)
    #assert(pow2m.value == 2**m.value)
    #assert(sum(b.value for b in x) == m.value)
    # Random bits selected by x go into r_prime (the part below 2^m),
    # the others into r_dprime.
    for i in range(l):
        bit(r[i])
        t1 = two_power(i) * r[i]
        t2 = t1 * x[i]
        r_prime += t2
        r_dprime += t1 - t2
    #assert(r_prime.value == (sum(2**i*x[i].value*r[i].value for i in range(l)) % comparison.program.P))
    comparison.PRandInt(rk, kappa)
    # Statistical masking above bit l.
    r_dprime += two_power(l) * rk
    #assert(r_dprime.value == (2**l * rk.value + sum(2**i*(1 - x[i].value)*r[i].value for i in range(l)) % comparison.program.P))
    asm_open(c, a + r_dprime + r_prime)
    # ci[i] = c mod 2^i for every candidate shift amount i.
    for i in range(1, l):
        ci[i] = c % two_power(i)
        #assert(ci[i].value == c.value % 2**i)
    # Obliviously pick ci[m]: x[i - 1] - x[i] is presumably 1 only at i == m.
    c_dprime = sum(ci[i] * (x[i - 1] - x[i]) for i in range(1, l))
    #assert(c_dprime.value == (sum(ci[i].value*(x[i-1].value - x[i].value) for i in range(1,l)) % comparison.program.P))
    # d: presumably [c_dprime < r_prime], the borrow of the low part.
    lts(d, c_dprime, r_prime, l, kappa)
    if compute_modulo:
        # a mod 2^m
        b = c_dprime - r_prime + pow2m * d
        return b, pow2m
    else:
        # to_shift = a - (a mod 2^m) + 2^m * d, hence the final "- d".
        to_shift = a - c_dprime + r_prime
        if program.Program.prog.options.ring:
            shifted = TruncInRing(to_shift, l, pow2m)
        else:
            pow2inv = Inv(pow2m)
            #assert(pow2inv.value * pow2m.value % comparison.program.P == 1)
            shifted = to_shift * pow2inv
        b = shifted - d
    return b
예제 #13
0
def EQZ(a, k, kappa):
    """
    Zero test through a masked opening (startopen/stopopen variant).

    Returns 1 - KOR(d), where d[i] is the XOR of bit i of the opened value
    with the i-th PRandM random bit. kappa is only passed to PRandM here.
    """
    hi, lo = types.sint(), types.sint()
    opened = types.cint()
    d = [None] * k
    rbits = [types.sint() for _ in range(k)]
    PRandM(hi, lo, rbits, k, k, kappa)
    startopen(a + two_power(k) * hi + lo)
    stopopen(opened)
    for idx, cbit in enumerate(bits(opened, k)):
        d[idx] = cbit + rbits[idx] - 2 * cbit * rbits[idx]
    return 1 - KOR(d)
예제 #14
0
def TruncPrField(a, k, m, kappa=None):
    """
    Probabilistic truncation of a k-bit secret a by m bits (field version).

    Opens 2^(k-1) + a + 2^m * r'' + r' and returns (a - a') / 2^m with
    a' = (opened mod 2^m) - r'. kappa defaults to 40.
    """
    if kappa is None:
        kappa = 40

    shifted = two_power(k - 1) + a
    low, high = types.sint(), types.sint()
    comparison.PRandM(high, low, [types.sint() for _ in range(m)],
                      k, m, kappa)
    pow_m = two_power(m)
    mask = pow_m * high + low
    clear = (shifted + mask).reveal()
    a_low = clear % pow_m - low
    return (a - a_low) / pow_m
예제 #15
0
def Norm(b, k, f, kappa, simplex_flag=False):
    """
        Computes secret integer values [c] and [v_prime] st.
        2^{k-1} <= c < 2^k and c = b*v_prime

        simplex_flag: when truthy, b is assumed non-negative and the secret
            sign computation is skipped (sign is the constant 1).
        f and kappa are not used by this implementation.

        Returns (part_reciprocal, signed_acc) = (|b| * acc, sign * acc),
        where acc = 2^{k-1-m} and m is the index of the most significant
        set bit of |b| (acc is 0 when b == 0).
    """
    # For simplex, we can get rid of computing abs(b).
    # Truthiness instead of == True / == False: any other flag value used to
    # leave temp as None and fail on the arithmetic below.
    if simplex_flag:
        temp = types.cint(0)
    else:
        temp = types.sint(b < 0)

    sign = 1 - 2 * temp  # 1 - 2 * [b < 0]
    absolute_val = sign * b

    # The next 2 lines compute the suffix OR (OR of bits[i:]) of the
    # little-endian bit decomposition by running PreOR on reversed bits.
    bits = absolute_val.bit_decompose(k)[::-1]
    suffixes = AdvInteger.PreOR(bits)[::-1]

    # z is one-hot at the most significant set bit of |b|.
    z = [0] * k
    for i in range(k - 1):
        z[i] = suffixes[i] - suffixes[i + 1]
    z[k - 1] = suffixes[k - 1]

    # acc = 2^{k-1-m}: the one-hot bit mirrored to the other end.
    acc = types.cint(0)
    for i in range(k):
        acc += AdvInteger.two_power(k - i - 1) * z[i]

    part_reciprocal = absolute_val * acc
    signed_acc = sign * acc

    return part_reciprocal, signed_acc
예제 #16
0
def TruncPrRing(a, k, m, signed=True):
    """
    Probabilistic truncation of a k-bit secret a by m bits in the ring
    setting (ring size read from program options).

    Returns a unchanged when m == 0. With signed=True, a is offset by
    2^(k-1) before truncating and 2^(k-m-1) is subtracted afterwards.
    """
    if m == 0:
        return a
    n_ring = int(program.Program.prog.options.ring)
    assert n_ring >= k, '%d too large' % k
    if k == n_ring:
        # No spare bits above k: add m random bits to the low part so that
        # the a % 2^m leaked by TruncLeakyInRing is random.
        for i in range(m):
            a += types.sint.get_random_bit() << i
        return comparison.TruncLeakyInRing(a, k, m, signed=signed)
    else:
        from types import sint
        if signed:
            a += (1 << (k - 1))
        if program.Program.prog.use_trunc_pr:
            # Dedicated truncation instruction.
            res = sint()
            trunc_pr(res, a, k, m)
        else:
            # extra bit to mask overflow
            r_bits = [sint.get_random_bit() for i in range(k + 1)]
            n_shift = n_ring - len(r_bits)
            tmp = a + sint.bit_compose(r_bits)
            # Move the low k + 1 bits of tmp to the top of the ring so the
            # opening exposes only those (masked) bits.
            masked = (tmp << n_shift).reveal()
            # Drop the top (overflow) bit and the low m bits.
            shifted = (masked << 1 >> (n_shift + m + 1))
            # Top revealed bit XOR r's top bit: presumably the carry out of
            # the low k bits of a + r.
            overflow = r_bits[-1].bit_xor(masked >> (n_ring - 1))
            res = shifted - sint.bit_compose(r_bits[m:k]) + \
                  (overflow << (k - m))
        if signed:
            res -= (1 << (k - m - 1))
        return res
예제 #17
0
def PreORC(a, kappa=None, m=None, raw=False):
    """
    Prefix OR of the secret bits a, constant-round while len(a) fits the
    available bit length (chunked otherwise).

    a: list of secret bits (sint or sgf2n).
    kappa: statistical security parameter. It is subtracted from the
        modulus bit length for non-sgf2n input, so the None default raises
        a TypeError there -- NOTE(review).
    m: number of trailing prefix ORs to return (defaults to len(a)).
    raw: for sgf2n input, skip extracting bit 0 of the prefix products.
    """
    k = len(a)
    if k == 1:
        return [a[0]]
    m = m or k
    # Longest input handled by a single constant-round pass.
    if isinstance(a[0], types.sgf2n):
        max_k = program.Program.prog.galois_length - 1
    else:
        max_k = int(log(program.Program.prog.P) / log(2)) - kappa
    if k <= max_k:
        p = [None] * m
        # NOTE(review): with m == k the loops below also write p[0].
        if m == k:
            p[0] = a[0]
        if isinstance(a[0], types.sgf2n):
            # Prefix products of 3 - a[i]; bit 0 presumably signals that all
            # inputs so far were 0.
            b = comparison.PreMulC([3 - a[i] for i in range(k)])
            for i in range(m):
                tmp = b[k-1-i]
                if not raw:
                    tmp = tmp.bit_decompose()[0]
                p[m-1-i] = 1 - tmp
        else:
            # For bits a[i], the prefix product of (a[i] + 1) is 2^(ones so
            # far), which is odd iff no one has appeared: OR = 1 - parity.
            t = [types.sint() for i in range(m)]
            b = comparison.PreMulC([a[i] + 1 for i in range(k)])
            for i in range(m):
                comparison.Mod2(t[i], b[k-1-i], k, kappa, False)
                p[m-1-i] = 1 - t[i]
        return p
    else:
        # not constant-round anymore
        # Prefix OR per chunk of max_k bits, then OR every later chunk with
        # the running OR of the preceding chunks' last entries.
        s = [PreORC(a[i:i+max_k], kappa, raw=raw) for i in range(0,k,max_k)]
        t = PreORC([si[-1] for si in s[:-1]], kappa, raw=raw)
        return sum(([or_op(x, y) for x in si] for si,y in zip(s[1:],t)), s[0])
예제 #18
0
def PreORC(a, kappa=None, m=None, raw=False):
    """
    Prefix OR of the secret bits a, constant-round while len(a) fits the
    available bit length (chunked otherwise).

    a: list of secret bits (sint or sgf2n).
    kappa: statistical security parameter. It is subtracted from the
        modulus bit length for non-sgf2n input, so the None default raises
        a TypeError there -- NOTE(review).
    m: number of trailing prefix ORs to return (defaults to len(a)).
    raw: for sgf2n input, skip extracting bit 0 of the prefix products.
    """
    k = len(a)
    if k == 1:
        return [a[0]]
    m = m or k
    # Longest input handled by a single constant-round pass.
    if isinstance(a[0], types.sgf2n):
        max_k = program.Program.prog.galois_length - 1
    else:
        max_k = int(log(program.Program.prog.P) / log(2)) - kappa
    if k <= max_k:
        p = [None] * m
        # NOTE(review): with m == k the loops below also write p[0].
        if m == k:
            p[0] = a[0]
        if isinstance(a[0], types.sgf2n):
            # Prefix products of 3 - a[i]; bit 0 presumably signals that all
            # inputs so far were 0.
            b = comparison.PreMulC([3 - a[i] for i in range(k)])
            for i in range(m):
                tmp = b[k - 1 - i]
                if not raw:
                    tmp = tmp.bit_decompose()[0]
                p[m - 1 - i] = 1 - tmp
        else:
            # For bits a[i], the prefix product of (a[i] + 1) is 2^(ones so
            # far), which is odd iff no one has appeared: OR = 1 - parity.
            t = [types.sint() for i in range(m)]
            b = comparison.PreMulC([a[i] + 1 for i in range(k)])
            for i in range(m):
                comparison.Mod2(t[i], b[k - 1 - i], k, kappa, False)
                p[m - 1 - i] = 1 - t[i]
        return p
    else:
        # not constant-round anymore
        # Prefix OR per chunk of max_k bits, then OR every later chunk with
        # the running OR of the preceding chunks' last entries.
        s = [
            PreORC(a[i:i + max_k], kappa, raw=raw) for i in range(0, k, max_k)
        ]
        t = PreORC([si[-1] for si in s[:-1]], kappa, raw=raw)
        return sum(([or_op(x, y) for x in si] for si, y in zip(s[1:], t)),
                   s[0])
예제 #19
0
def B2U_from_Pow2(pow2a, l, kappa):
    """
    Derive l secret unary bits from a secret power of two pow2a.

    pow2a is masked with l random bits plus a kappa-bit random integer
    above 2^l and opened. Each opened bit is XORed with its random bit,
    and entry i of the result is 1 - PreOR(...)[i].
    """
    rand_bits = [types.sint() for _ in range(l)]
    rand_int = types.sint()
    opened = types.cint()
    for r_bit in rand_bits:
        bit(r_bit)
    comparison.PRandInt(rand_int, kappa)
    low_mask = sum(two_power(i) * rand_bits[i] for i in range(l))
    asm_open(opened, pow2a + two_power(l) * rand_int + low_mask)
    comparison.program.curr_tape.require_bit_length(l + kappa)
    clear_bits = list(bits(opened, l))
    xored = [clear_bits[i] + rand_bits[i] - 2 * clear_bits[i] * rand_bits[i]
             for i in range(l)]
    prefix = PreOR(xored, kappa)
    return [1 - prefix[i] for i in range(l)]
예제 #20
0
def Trunc(a, l, m, kappa, compute_modulo=False):
    """ Oblivious truncation by secret m

    Returns a >> m for a secret a of bit length l, or the pair
    (a mod 2^m, 2^m) when compute_modulo is True.
    """
    if l == 1:
        # Single-bit case, m presumably a bit: 2^m = 1 + m,
        # a mod 2^m = a * m and a >> m = a * (1 - m).
        if compute_modulo:
            return a * m, 1 + m
        else:
            return a * (1 - m)
    r = [types.sint() for i in range(l)]
    r_dprime = types.sint(0)
    r_prime = types.sint(0)
    rk = types.sint()
    c = types.cint()
    ci = [types.cint() for i in range(l)]
    d = types.sint()
    # x: presumably the unary encoding of m (see the assertion on sum(x));
    # pow2m = 2^m.
    x, pow2m = B2U(m, l, kappa)
    #assert(pow2m.value == 2**m.value)
    #assert(sum(b.value for b in x) == m.value)
    # Random bits selected by x go into r_prime (the part below 2^m),
    # the others into r_dprime.
    for i in range(l):
        bit(r[i])
        t1 = two_power(i) * r[i]
        t2 = t1*x[i]
        r_prime += t2
        r_dprime += t1 - t2
    #assert(r_prime.value == (sum(2**i*x[i].value*r[i].value for i in range(l)) % comparison.program.P))
    comparison.PRandInt(rk, kappa)
    # Statistical masking above bit l.
    r_dprime += two_power(l) * rk
    #assert(r_dprime.value == (2**l * rk.value + sum(2**i*(1 - x[i].value)*r[i].value for i in range(l)) % comparison.program.P))
    asm_open(c, a + r_dprime + r_prime)
    # ci[i] = c mod 2^i for every candidate shift amount i.
    for i in range(1,l):
        ci[i] = c % two_power(i)
        #assert(ci[i].value == c.value % 2**i)
    # Obliviously pick ci[m]: x[i - 1] - x[i] is presumably 1 only at i == m.
    c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l))
    #assert(c_dprime.value == (sum(ci[i].value*(x[i-1].value - x[i].value) for i in range(1,l)) % comparison.program.P))
    # d: presumably [c_dprime < r_prime], the borrow of the low part.
    lts(d, c_dprime, r_prime, l, kappa)
    if compute_modulo:
        # a mod 2^m
        b = c_dprime - r_prime + pow2m * d
        return b, pow2m
    else:
        # to_shift = a - (a mod 2^m) + 2^m * d, hence the final "- d".
        to_shift = a - c_dprime + r_prime
        if program.Program.prog.options.ring:
            shifted = TruncInRing(to_shift, l, pow2m)
        else:
            pow2inv = Inv(pow2m)
            #assert(pow2inv.value * pow2m.value % comparison.program.P == 1)
            shifted = to_shift * pow2inv
        b = shifted - d
    return b
def KMulC(a):
    """
    Return the product of all items in a.

    The result is written into a fresh sint by PreMulC_without_inverses.
    """
    from types import sint, cint
    product = sint()
    PreMulC_without_inverses(product, a)
    return product
예제 #22
0
def TruncPr(a, k, m, kappa=None):
    """ Probabilistic truncation [a/2^m + u]
        where Pr[u = 1] = (a % 2^m) / 2^m

        a is treated as a k-bit value; kappa defaults to 40.
    """
    kappa = 40 if kappa is None else kappa

    offset_a = two_power(k - 1) + a
    low_mask, high_mask = types.sint(), types.sint()
    mask_bits = [types.sint() for _ in range(m)]
    comparison.PRandM(high_mask, low_mask, mask_bits, k, m, kappa)
    scale = two_power(m)
    opened = (offset_a + (scale * high_mask + low_mask)).reveal()
    low_part = opened % scale - low_mask
    return (a - low_part) / scale
예제 #23
0
def Pow2(a, l, kappa):
    """
    Compute 2^a for a secret a that presumably satisfies 0 <= a < l.

    a is decomposed into m = ceil(log2(l)) secret bits t_i. The result is
    the product over i of (t_i * 2^(2^i) + 1 - t_i), i.e. 2^(2^i) where
    bit i is set and 1 otherwise.

    l: positive integer bound on a.
    Raises ValueError if l < 1.
    """
    if l < 1:
        raise ValueError('Pow2 requires l >= 1, got %r' % (l,))
    # Exact integer ceil(log2(l)). The float expression ceil(log(l, 2)) can
    # overshoot by one at exact powers of two due to rounding.
    m = (l - 1).bit_length()
    t = BitDec(a, m, m, kappa)
    for i in range(m):
        t[i] = t[i] * two_power(2**i) + 1 - t[i]
    return KMul(t)
예제 #24
0
def Pow2(a, l, kappa):
    """
    Compute 2^a for a secret a that presumably satisfies 0 <= a < l.

    a is decomposed into m = ceil(log2(l)) secret bits t_i. The result is
    the product over i of (t_i * 2^(2^i) + 1 - t_i), i.e. 2^(2^i) where
    bit i is set and 1 otherwise.

    l: positive integer bound on a.
    Raises ValueError if l < 1.
    """
    if l < 1:
        raise ValueError('Pow2 requires l >= 1, got %r' % (l,))
    # Exact integer ceil(log2(l)). The float expression ceil(log(l, 2)) can
    # overshoot by one at exact powers of two due to rounding.
    m = (l - 1).bit_length()
    t = BitDec(a, m, m, kappa)
    for i in range(m):
        t[i] = t[i] * two_power(2**i) + 1 - t[i]
    return KMul(t)
예제 #25
0
def LTZ(s, a, k, kappa):
    """
    s = (a ?< 0)

    k: bit length of a

    The result is written into the caller-provided register s. a is
    truncated by k - 1 bits with the signed flag set, and
    subsfi(s, top, 0) stores the final value in s.
    """
    from types import sint
    top = sint()
    Trunc(top, a, k, k - 1, kappa, True)
    subsfi(s, top, 0)
예제 #26
0
def BitDec(a, k, m, kappa, bits_to_compute=None):
    """
    Secret bit decomposition of the low m bits of a k-bit secret a.

    Opens 2^(k+kappa) + 2^k + a - (2^m * r'' + r') for random r'', r'
    from PRandM (r' has the secret bits r). The m public bits of that
    opening are then added to r with BitAdd, and the final carry is
    dropped.

    bits_to_compute: forwarded unchanged to BitAdd.
    """
    r_dprime = types.sint()
    r_prime = types.sint()
    c = types.cint()
    r = [types.sint() for i in range(m)]
    comparison.PRandM(r_dprime, r_prime, r, k, m, kappa)
    # The 2^(k+kappa) + 2^k offset presumably keeps the opened value
    # non-negative for signed a.
    pow2 = two_power(k + kappa)
    asm_open(c, pow2 + two_power(k) + a - two_power(m) * r_dprime - r_prime)
    # Dropped: a try/except around `pass` that could never raise and relied
    # on Python-2-only print statements.
    return BitAdd(list(bits(c, m)), r, bits_to_compute)[:-1]
예제 #27
0
def BitDecField(a, k, m, kappa, bits_to_compute=None):
    """
    Secret bit decomposition of the low m bits of a k-bit secret a (field).

    Opens 2^(k+kappa) + 2^k + a - (2^m * r'' + r') for random r'', r'
    from PRandM (r' has the secret bits r). The m public bits are added to
    r with intbitint.bit_adder, whose full output is returned.

    bits_to_compute: accepted for interface compatibility; not used here.
    """
    r_dprime = types.sint()
    r_prime = types.sint()
    c = types.cint()
    r = [types.sint() for i in range(m)]
    comparison.PRandM(r_dprime, r_prime, r, k, m, kappa)
    # The 2^(k+kappa) + 2^k offset presumably keeps the opened value
    # non-negative for signed a.
    pow2 = two_power(k + kappa)
    asm_open(c, pow2 + two_power(k) + a - two_power(m) * r_dprime - r_prime)
    # Dropped: a try/except around `pass` that could never raise and relied
    # on Python-2-only print statements.
    return types.intbitint.bit_adder(list(bits(c, m)), r)
예제 #28
0
def LTZ(s, a, k, kappa):
    """
    s = (a ?< 0)

    k: bit length of a

    Signed truncation of a by k - 1 bits into a fresh register, then
    subsfi(s, shifted, 0) writes the result into s.
    """
    from types import sint
    shifted = sint()
    signed_trunc = True
    Trunc(shifted, a, k, k - 1, kappa, signed_trunc)
    subsfi(s, shifted, 0)
예제 #29
0
def Trunc(a, l, m, kappa, compute_modulo=False):
    """ Oblivious truncation by secret m

    Field version. Returns a >> m for a secret a of bit length l, or the
    pair (a mod 2^m, 2^m) when compute_modulo is True.
    """
    if l == 1:
        # Single-bit case, m presumably a bit: 2^m = 1 + m,
        # a mod 2^m = a * m and a >> m = a * (1 - m).
        if compute_modulo:
            return a * m, 1 + m
        else:
            return a * (1 - m)
    r = [types.sint() for i in range(l)]
    r_dprime = types.sint(0)
    r_prime = types.sint(0)
    rk = types.sint()
    c = types.cint()
    ci = [types.cint() for i in range(l)]
    d = types.sint()
    # x: presumably the unary encoding of m (see the assertion on sum(x));
    # pow2m = 2^m.
    x, pow2m = B2U(m, l, kappa)
    #assert(pow2m.value == 2**m.value)
    #assert(sum(b.value for b in x) == m.value)
    # Random bits selected by x go into r_prime (the part below 2^m),
    # the others into r_dprime.
    for i in range(l):
        bit(r[i])
        t1 = two_power(i) * r[i]
        t2 = t1 * x[i]
        r_prime += t2
        r_dprime += t1 - t2
    #assert(r_prime.value == (sum(2**i*x[i].value*r[i].value for i in range(l)) % comparison.program.P))
    comparison.PRandInt(rk, kappa)
    # Statistical masking above bit l.
    r_dprime += two_power(l) * rk
    #assert(r_dprime.value == (2**l * rk.value + sum(2**i*(1 - x[i].value)*r[i].value for i in range(l)) % comparison.program.P))
    asm_open(c, a + r_dprime + r_prime)
    # ci[i] = c mod 2^i for every candidate shift amount i.
    for i in range(1, l):
        ci[i] = c % two_power(i)
        #assert(ci[i].value == c.value % 2**i)
    # Obliviously pick ci[m]: x[i - 1] - x[i] is presumably 1 only at i == m.
    c_dprime = sum(ci[i] * (x[i - 1] - x[i]) for i in range(1, l))
    #assert(c_dprime.value == (sum(ci[i].value*(x[i-1].value - x[i].value) for i in range(1,l)) % comparison.program.P))
    # d: presumably [c_dprime < r_prime], the borrow of the low part.
    lts(d, c_dprime, r_prime, l, kappa)
    if compute_modulo:
        # a mod 2^m
        b = c_dprime - r_prime + pow2m * d
        return b, pow2m
    else:
        # (a - (a mod 2^m) + 2^m * d) / 2^m, then remove the borrow d.
        pow2inv = Inv(pow2m)
        #assert(pow2inv.value * pow2m.value % comparison.program.P == 1)
        b = (a - c_dprime + r_prime) * pow2inv - d
    return b
예제 #30
0
def KMulC(a):
    """
    Return just the product of all items in a.

    The module-level use_inv flag selects PreMulC_with_inverses or
    PreMulC_without_inverses to fill the result register.
    """
    from types import sint, cint
    result = sint()
    premul = PreMulC_with_inverses if use_inv else PreMulC_without_inverses
    premul(result, a)
    return result
예제 #31
0
def TruncRoundNearest(a, k, m, kappa, signed=False):
    """
    Returns a / 2^m, rounded to the nearest integer.

    k: bit length of a
    m: compile-time integer

    Implemented as truncation of a + 2^(m-1), with bit length k + 1
    presumably leaving room for the carry from adding half.
    """
    from types import sint
    rounded = sint()
    half = 1 << (m - 1)
    Trunc(rounded, a + half, k + 1, m, kappa, signed)
    return rounded
예제 #32
0
def Inv(a):
    """Invert a non-zero secret value a.

    A secret r from a square pair masks a. The product r * a is opened,
    inverted in the clear, and multiplied back by r, giving 1 / a.
    """
    rand, rand_sq = types.sint(), types.sint()
    opened, opened_inv = types.cint(), types.cint()
    one = types.cint()
    ldi(one, 1)
    square(rand, rand_sq)
    asm_open(opened, rand * a)
    divc(opened_inv, one, opened)
    return opened_inv * rand
예제 #33
0
def SDiv(a, b, l, kappa):
    """
    Secret division a / b with l fractional bits, refined over
    theta = ceil(log2(l / 3.5)) iterations using probabilistic truncation.

    The linear start value w = 2.9142 * 2^l - 2 * b presumably assumes b
    is already normalized (e.g. into [2^(l-1), 2^l)) -- NOTE(review):
    confirm against callers.
    """
    theta = int(ceil(log(l / 3.5) / log(2)))
    alpha = two_power(2*l)
    # Multiplying by beta divides exactly by 2^l.
    beta = 1 / types.cint(two_power(l))
    w = types.cint(int(2.9142 * two_power(l))) - 2 * b
    # x: error term 2^(2l) - b*w; y: first quotient approximation a*w.
    x = alpha - b * w
    y = a * w
    y = TruncPr(y, 2 * l, l, kappa)
    # Split x = x1 * 2^l + x2 with x2 = x mod 2^l.
    x2 = types.sint()
    comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False)
    x1 = (x - x2) * beta
    for i in range(theta-1):
        # y <- y * (2^l + x1) + (y * x2) / 2^l, then rescale.
        y = y * (x1 + two_power(l)) + TruncPr(y * x2, 2 * l, l, kappa)
        y = TruncPr(y, 2 * l + 1, l + 1, kappa)
        # x <- (x / 2^l)^2 up to rounding, assembled from x1 and x2.
        x = x1 * x2 + TruncPr(x2**2, 2 * l + 1, l + 1, kappa)
        x = x1 * x1 + TruncPr(x, 2 * l + 1, l - 1, kappa)
        x2 = types.sint()
        comparison.Mod2m(x2, x, 2 * l, l, kappa, False)
        x1 = (x - x2) * beta
    y = y * (x1 + two_power(l)) + TruncPr(y * x2, 2 * l, l, kappa)
    y = TruncPr(y, 2 * l + 1, l - 1, kappa)
    return y
예제 #34
0
def Inv(a):
    """Invert a non-zero secret value a.

    inverse() fills a secret pair (t[0], t[1]); only t[0] is needed here.
    The product t[0] * a is opened, inverted in the clear with divc, and
    multiplied back by t[0], giving 1 / a.
    """
    # Only two secret registers are needed (the original allocated an
    # unused third one).
    t = [types.sint() for i in range(2)]
    c = [types.cint() for i in range(2)]
    one = types.cint()
    ldi(one, 1)
    inverse(t[0], t[1])
    s = t[0] * a
    asm_open(c[0], s)
    # c[1] = 1 / (t[0] * a)
    divc(c[1], one, c[0])
    return c[1] * t[0]
예제 #35
0
def SDiv(a, b, l, kappa):
    """
    Secret division a / b with l fractional bits, refined over
    theta = ceil(log2(l / 3.5)) iterations using probabilistic truncation.

    The linear start value w = 2.9142 * 2^l - 2 * b presumably assumes b
    is already normalized (e.g. into [2^(l-1), 2^l)) -- NOTE(review):
    confirm against callers.
    """
    theta = int(ceil(log(l / 3.5) / log(2)))
    alpha = two_power(2 * l)
    # Multiplying by beta divides exactly by 2^l.
    beta = 1 / types.cint(two_power(l))
    w = types.cint(int(2.9142 * two_power(l))) - 2 * b
    # x: error term 2^(2l) - b*w; y: first quotient approximation a*w.
    x = alpha - b * w
    y = a * w
    y = TruncPr(y, 2 * l, l, kappa)
    # Split x = x1 * 2^l + x2 with x2 = x mod 2^l.
    x2 = types.sint()
    comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False)
    x1 = (x - x2) * beta
    for i in range(theta - 1):
        # y <- y * (2^l + x1) + (y * x2) / 2^l, then rescale.
        y = y * (x1 + two_power(l)) + TruncPr(y * x2, 2 * l, l, kappa)
        y = TruncPr(y, 2 * l + 1, l + 1, kappa)
        # x <- (x / 2^l)^2 up to rounding, assembled from x1 and x2.
        x = x1 * x2 + TruncPr(x2**2, 2 * l + 1, l + 1, kappa)
        x = x1 * x1 + TruncPr(x, 2 * l + 1, l - 1, kappa)
        x2 = types.sint()
        comparison.Mod2m(x2, x, 2 * l, l, kappa, False)
        x1 = (x - x2) * beta
    y = y * (x1 + two_power(l)) + TruncPr(y * x2, 2 * l, l, kappa)
    y = TruncPr(y, 2 * l + 1, l - 1, kappa)
    return y
예제 #36
0
def Inv(a):
    """
    Invert a non-zero secret value a via a masked opening.

    square() supplies a secret mask (its first output); mask * a is
    opened, 1 / (mask * a) is computed in the clear, and the result is
    multiplied by the mask again.
    """
    mask, mask_sq = [types.sint() for _ in range(2)]
    clear_prod, clear_inv = [types.cint() for _ in range(2)]
    one = types.cint()
    ldi(one, 1)
    square(mask, mask_sq)
    product = mask * a
    asm_open(clear_prod, product)
    divc(clear_inv, one, clear_prod)
    return clear_inv * mask
예제 #37
0
def Int2FL(a, gamma, l, kappa):
    """
    Convert a secret gamma-bit integer a into floating-point form.

    Returns (v, p, z, s): the l-bit significand v, the exponent p, the zero
    flag z (EQZ of a) and the sign bit s (LTZ of a). p is zeroed when z
    is 1.
    """
    width = gamma - 1
    s = types.sint()
    comparison.LTZ(s, a, gamma, kappa)
    z = EQZ(a, gamma, kappa)
    magnitude = (1 - 2 * s) * a

    # Decompose |a| most significant bit first, so that the prefix OR is 1
    # from the leading one onwards.
    decomposed = BitDec(magnitude, width, width, kappa)
    decomposed.reverse()
    seen_one = PreOR(decomposed, kappa)

    # 2^(number of leading zeros): normalizes |a|.
    normalizer = 1 + sum(2**pos * (1 - flag)
                         for pos, flag in enumerate(seen_one))
    t = magnitude * normalizer
    p = -(width - sum(seen_one))

    if gamma - 1 > l:
        if not types.sfloat.round_nearest:
            v = types.sint()
            comparison.Trunc(v, t, gamma - 1, gamma - l - 1, kappa, False)
        else:
            v, overflow = TruncRoundNearestAdjustOverflow(
                t, gamma - 1, l, kappa)
            p = p + overflow
    else:
        v = 2**(l - gamma + 1) * t
    p = (p + gamma - 1 - l) * (1 - z)
    return v, p, z, s
예제 #38
0
def Mod2mRing(a_prime, a, k, m, signed):
    """
    Compute a mod 2^m for a secret a in the ring setting.

    If a_prime is not None, the result is also moved into it. The result
    is returned either way. signed is accepted but not used.
    """
    assert(int(program.options.ring) >= k)
    from Compiler.types import sint, intbitint, cint
    ring_bits = int(program.options.ring)
    shift = ring_bits - m
    rand_bits = [sint.get_random_bit() for _ in range(m)]
    rand = sint.bit_compose(rand_bits)
    # Shifting up by ring_bits - m discards everything above bit m before
    # the opening; shifting the clear value back gives (a + rand) mod 2^m.
    low = ((a + rand) << shift).reveal() >> shift
    borrow = sint()
    BitLTL(borrow, low, rand_bits, 0)
    res = (borrow << m) + low - rand
    if a_prime is not None:
        movs(a_prime, res)
    return res
예제 #39
0
def Mod2mRing(a_prime, a, k, m, signed):
    """
    Compute a mod 2^m for a secret a in the ring setting.

    If a_prime is not None, the result is also moved into it. The result
    is returned either way. signed is accepted but not used.
    """
    assert(int(program.options.ring) >= k)
    from Compiler.types import sint, intbitint, cint
    shift = int(program.options.ring) - m
    # m random bits and their composed value r_prime.
    r = [sint.get_random_bit() for i in range(m)]
    r_prime = sint.bit_compose(r)
    tmp = a + r_prime
    # Shift up to drop bits above m before opening, then back down:
    # c_prime = (a + r_prime) mod 2^m.
    c_prime = (tmp << shift).reveal() >> shift
    # u: presumably [c_prime < r_prime], the wrap-around correction.
    u = sint()
    BitLTL(u, c_prime, r, 0)
    res = (u << m) + c_prime - r_prime
    if a_prime is not None:
        movs(a_prime, res)
    return res
예제 #40
0
def TruncRoundNearest(a, k, m, kappa):
    """
    Returns a / 2^m, rounded to the nearest integer.

    k: bit length of a
    m: compile-time integer

    Relies on the module-level const_rounds flag to choose between BitLTC1
    and BitLTL. The m == 1 case has a separate ring-setting path.
    """
    from types import sint, cint
    from library import reveal, load_int_to_secret
    if m == 1:
        # (a + (a mod 2)) / 2 rounds half up; a + lsb is even.
        if program.options.ring:
            lsb = Mod2mRing(None, a, k, 1, False)
            return TruncRing(None, a + lsb, k + 1, 1, False)
        else:
            lsb = sint()
            Mod2(lsb, a, k, kappa, False)
            return (a + lsb) / 2
    r_dprime = sint()
    r_prime = sint()
    r = [sint() for i in range(m)]
    u = sint()
    # r_prime is an m-bit random value with bits r; r_dprime masks the rest.
    PRandM(r_dprime, r_prime, r, k, m, kappa)
    # The 2^(k-1) offset presumably keeps the opened value non-negative.
    c = reveal((cint(1) << (k - 1)) + a + (cint(1) << m) * r_dprime + r_prime)
    # Low m - 1 bits of the opened value.
    c_prime = c % (cint(1) << (m - 1))
    # u: comparison of c_prime with the low m - 1 random bits (presumably
    # the borrow [c_prime < r'_low]).
    if const_rounds:
        BitLTC1(u, c_prime, r[:-1], kappa)
    else:
        BitLTL(u, c_prime, r[:-1], kappa)
    # Bit m - 1 of the opened value.
    bit = ((c - c_prime) >> (m - 1)) % 2
    xor = bit + u - 2 * bit * u
    prod = xor * r[-1]
    # u_prime = xor * u + (1 - xor) * r[-1]
    u_prime = bit * u + u - 2 * bit * u + r[-1] - prod
    a_prime = (c % (cint(1) << m)) - r_prime + (cint(1) << m) * u_prime
    d = (a - a_prime) >> m
    # xor XOR r[-1]: presumably bit m - 1 of a, which decides rounding up.
    rounding = xor + r[-1] - 2 * prod
    return d + rounding
예제 #41
0
def TruncRoundNearest(a, k, m, kappa, signed=False):
    """
    Returns a / 2^m, rounded to the nearest integer.

    k: bit length of a
    m: compile-time integer

    Normally truncates a + 2^(m-1) using bit length k + 1. When k already
    equals the ring size there is no room for that extra bit, so two
    ring truncations are chained instead.
    """
    ring_size = int(program.options.ring)
    if k != ring_size:
        from types import sint
        result = sint()
        Trunc(result, a + (1 << (m - 1)), k + 1, m, kappa, signed)
        return result
    # cannot work with bit length k+1
    partial = TruncRing(None, a, k, m - 1, signed)
    return TruncRing(None, partial + 1, k - m + 1, 1, signed)
예제 #42
0
def TruncZeroesInRing(a, k, m, signed):
    """
    Returns a >> m.
    Requires 2^m | a and a < 2^k.

    With signed set, the top result bit is treated as a sign bit and
    2^(k-m) is subtracted when it is set.
    """
    from types import sint, intbitint, cint, cgf2n
    # Number of result bits and how far they are moved up in the ring.
    n_bits = k - m
    n_shift = int(program.options.ring) - n_bits
    r_bits = [sint.get_random_bit() for i in range(n_bits)]
    r = sint.bit_compose(r_bits)
    # Align a >> m (exact since 2^m | a) and r to the top of the ring, so the
    # opening exposes only the n_bits bits of (a >> m) - r.
    shifted = ((a << (n_shift - m)) - (r << n_shift)).reveal()
    masked = shifted >> n_shift
    # Add r back bitwise in secret to recover the bits of a >> m.
    res_bits = intbitint.bit_adder(r_bits, masked.bit_decompose(n_bits))
    res = sint.bit_compose(res_bits)
    if signed:
        res = sint.conv(res_bits[-1].if_else(res - (sint(1) << n_bits), res))
    return res
예제 #43
0
def TruncZeroesInRing(a, k, m, signed):
    """
    Returns a >> m.
    Requires 2^m | a and a < 2^k.

    With signed set, the top result bit is treated as a sign bit and
    2^(k-m) is subtracted when it is set.
    """
    from types import sint, intbitint, cint, cgf2n
    # Number of result bits and how far they are moved up in the ring.
    n_bits = k - m
    n_shift = int(program.options.ring) - n_bits
    r_bits = [sint.get_random_bit() for i in range(n_bits)]
    r = sint.bit_compose(r_bits)
    # Align a >> m (exact since 2^m | a) and r to the top of the ring, so the
    # opening exposes only the n_bits bits of (a >> m) - r.
    shifted = ((a << (n_shift - m)) - (r << n_shift)).reveal()
    masked = shifted >> n_shift
    # Add r back bitwise in secret to recover the bits of a >> m.
    res_bits = intbitint.bit_adder(r_bits, masked.bit_decompose(n_bits))
    res = sint.bit_compose(res_bits)
    if signed:
        res = sint.conv(res_bits[-1].if_else(res - (sint(1) << n_bits),
                                             res))
    return res
예제 #44
0
def TruncLeakyInRing(a, k, m, signed):
    """
    Returns a >> m.
    Requires a < 2^k and leaks a % 2^m (needs to be constant or random).

    With signed set, a is offset by 2^(k-1) before truncation and
    2^(k-m-1) is subtracted from the result.
    """
    assert k > m
    assert int(program.options.ring) >= k
    from types import sint, intbitint, cint, cgf2n
    n_bits = k - m
    n_shift = int(program.options.ring) - n_bits
    r_bits = [sint.get_random_bit() for i in range(n_bits)]
    r = sint.bit_compose(r_bits)
    if signed:
        a += (1 << (k - 1))
    # Bits m..k-1 of a land at the top of the ring, masked by r. The low m
    # bits of a sit just below them unmasked, which is the documented leak.
    shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal()
    masked = shifted >> n_shift
    # u: presumably [masked < r], i.e. whether adding r wrapped around.
    u = sint()
    BitLTL(u, masked, r_bits, 0)
    res = (u << n_bits) + masked - r
    if signed:
        res -= (1 << (n_bits - 1))
    return res
예제 #45
0
def FLRound(x, mode):
    """ Rounding with floating point output.
    *mode*: 0 -> floor, 1 -> ceil, -1 -> trunc

    x is a secret float providing v (significand), p (exponent), z (zero
    flag), s (sign), vlen, plen and kappa. Returns the components
    (v, p, z, s) of the rounded value.
    """
    v1, p1, z1, s1, l, k = x.v, x.p, x.z, x.s, x.vlen, x.plen
    # a = [p1 < 0]: only a negative exponent leaves fractional bits.
    a = types.sint()
    comparison.LTZ(a, p1, k, x.kappa)
    # b = [p1 < 1 - l]: then |x| < 1 and the whole significand is fractional.
    b = p1.less_than(-l + 1, k, x.kappa)
    # When a and not b: v2 = v1 mod 2^(-p1) (the fractional bits) and
    # inv_2pow_p1 = 2^(-p1). Otherwise the shift is 0, giving v2 = 0 and 1.
    v2, inv_2pow_p1 = Trunc(v1, l, -a * (1 - b) * x.p, x.kappa, True)
    # c = [fractional part is zero]
    c = EQZ(v2, l, x.kappa)
    if mode == -1:
        away_from_zero = 0
        mode = x.s
    else:
        # mode XOR sign: floor of a negative or ceil of a positive value.
        away_from_zero = mode + s1 - 2 * mode * s1
    v = v1 - v2 + (1 - c) * inv_2pow_p1 * away_from_zero
    # d = [v overflowed to 2^l]: renormalize to 2^(l-1); exponent bumped below.
    d = v.equal(two_power(l), l + 1, x.kappa)
    v = d * two_power(l-1) + (1 - d) * v
    v = a * ((1 - b) * v + b * away_from_zero * two_power(l-1)) + (1 - a) * v1
    s = (1 - b * mode) * s1
    z = or_op(EQZ(v, l, x.kappa), z1)
    v = v * (1 - z)
    p = ((p1 + d * a) * (1 - b) + b * away_from_zero * (1 - l)) * (1 - z)
    return v, p, z, s
예제 #46
0
def TruncPrRing(a, k, m):
    """
    Truncate the k-bit secret a by m bits via comparison.TruncRing with
    the signed flag set. Returns a itself when m == 0.
    """
    if m != 0:
        res = types.sint()
        comparison.TruncRing(res, a, k, m, True)
        return res
    return a
예제 #47
0
 def expand(self):
     """
     Expand into primitive instructions: args[1] - args[2] is computed
     with subs into a fresh secret register, and LTZ of it (bit length
     args[3], security parameter args[4]) is written into args[0].
     """
     from types import sint
     result, lhs, rhs, bit_length, sec_param = self.args[:5]
     diff = sint()
     subs(diff, lhs, rhs)
     comparison.LTZ(result, diff, bit_length, sec_param)