Exemple #1
0
 def to_regint_by_bit(self):
     """Convert this value to a regint with the ``convcbitvec`` instruction.

     When ``self.n`` is known it is used as the size of the destination
     regint; otherwise a default regint is allocated.  ``self.n`` is also
     passed to ``convcbitvec`` as its first operand.

     :return: the newly allocated regint holding the result
     """
     # Identity comparison: ``!= None`` goes through a possibly overloaded
     # ``__ne__`` and is flagged by linters (E711).
     if self.n is not None:
         res = regint(size=self.n)
     else:
         res = regint()
     inst.convcbitvec(self.n, res, self)
     return res
Exemple #2
0
def approximate_reciprocal(divisor, k, f, theta):
    """Return an approximation of ``1 / divisor`` where ``type(divisor) = cint``.

    The divisor is normalized by shifting out its leading zeros (counted
    over ``k`` bits).  ``q`` and ``e`` are then updated ``theta`` times in a
    runtime ``for_range`` loop, and the result is shifted right by
    ``2*k - 2*f - leading_zeros``.

    :param divisor: clear integer to invert (``cint``)
    :param k: bit width used for the decomposition and the fixed-point scale
    :param f: number of fractional bits of the result
    :param theta: number of iterations of the ``q``/``e`` update
    :return: the shifted approximation
    """
    def twos_complement(x):
        # Complement each of the k bits of x (accumulated MSB first) and add 1.
        bits = x.bit_decompose(k)[::-1]
        bit_array = Array(k, cint)
        bit_array.assign(bits)

        twos_result = MemValue(cint(0))

        @for_range(k)
        def block(i):
            val = twos_result.read()
            val <<= 1
            val += 1 - bit_array[i]
            twos_result.write(val)

        return twos_result.read() + 1

    # Bits of the divisor in reversed bit_decompose order
    # (presumably most significant first).
    bit_array = Array(k, cint)
    bits = divisor.bit_decompose(k)[::-1]
    bit_array.assign(bits)

    # The counter is allocated only once (it used to be allocated twice,
    # with the first allocation immediately discarded).
    flag = MemValue(regint(0))
    cnt_leading_zeros = MemValue(regint(0))
    normalized_divisor = MemValue(divisor)

    @for_range(k)
    def block(i):
        # flag turns 1 at the first set bit and stays 1.  Python binds ``|``
        # tighter than ``==``; the parentheses make that explicit.
        flag.write((flag.read() | bit_array[i]) == 1)

        @if_(flag.read() == 0)
        def block():
            # Still in the leading zeros: count it and shift the divisor up.
            cnt_leading_zeros.write(cnt_leading_zeros.read() + 1)
            normalized_divisor.write(normalized_divisor << 1)

    q = MemValue(two_power(k))
    e = MemValue(twos_complement(normalized_divisor.read()))

    @for_range(theta)
    def block(i):
        qread = q.read()
        eread = e.read()
        qread += (qread * eread) >> k
        eread = (eread * eread) >> k

        q.write(qread)
        e.write(eread)

    res = q >> (2 * k - 2 * f - cnt_leading_zeros)

    return res
Exemple #3
0
def approximate_reciprocal(divisor, k, f, theta):
    """Return an approximation of ``1 / divisor`` where ``type(divisor) = cint``.

    Unrolled variant: the ``theta`` updates of ``q``/``e`` are emitted as
    straight-line code using ``shift_two`` rather than a runtime loop.

    :param divisor: clear integer to invert (``cint``)
    :param k: bit width used for the decomposition and the fixed-point scale
    :param f: number of fractional bits of the result
    :param theta: number of (unrolled) update steps
    :return: the shifted approximation
    """
    def twos_complement(x):
        # Complement each of the k bits of x (accumulated MSB first) and add 1.
        bits = x.bit_decompose(k)[::-1]
        bit_array = Array(k, cint)
        bit_array.assign(bits)

        twos_result = MemValue(cint(0))

        @for_range(k)
        def block(i):
            val = twos_result.read()
            val <<= 1
            val += 1 - bit_array[i]
            twos_result.write(val)

        return twos_result.read() + 1

    # Bits of the divisor in reversed bit_decompose order
    # (presumably most significant first).
    bit_array = Array(k, cint)
    bits = divisor.bit_decompose(k)[::-1]
    bit_array.assign(bits)

    # The counter is allocated only once (it used to be allocated twice).
    flag = MemValue(regint(0))
    cnt_leading_zeros = MemValue(regint(0))
    normalized_divisor = MemValue(divisor)

    @for_range(k)
    def block(i):
        # flag turns 1 at the first set bit and stays 1; parentheses make the
        # ``|``-before-``==`` precedence explicit.
        flag.write((flag.read() | bit_array[i]) == 1)

        @if_(flag.read() == 0)
        def block():
            cnt_leading_zeros.write(cnt_leading_zeros.read() + 1)
            normalized_divisor.write(normalized_divisor << 1)

    q = MemValue(two_power(k))
    e = MemValue(twos_complement(normalized_divisor.read()))

    qr = q.read()
    er = e.read()

    # Unrolled at compile time: theta copies of the update are emitted.
    for _ in range(theta):
        qr = qr + shift_two(qr * er, k)
        er = shift_two(er * er, k)

    return shift_two(qr, (2*k - 2*f - cnt_leading_zeros))
Exemple #4
0
 def to_regint(self, dest=None):
     """Convert this value into a regint using the ``convcbit`` instruction.

     :param dest: regint to write into; a fresh one is allocated if omitted
     :raises CompilerError: if ``self.n`` exceeds 64 bits
     :return: the regint holding the converted value
     """
     target = regint() if dest is None else dest
     if self.n > 64:
         raise CompilerError('too many bits')
     inst.convcbit(target, self)
     return target
Exemple #5
0
 def load_int(self, value):
     """Load the public integer ``value`` limb by limb.

     ``value`` is split into ``ceil(self.n / self.unit)`` chunks of
     ``self.unit`` bits, least significant chunk first.  The chunks are
     written into a vector regint, which is then loaded via ``load_other``.
     """
     width = self.unit
     limb_count = math.ceil(self.n / width)
     limbs = regint(size=limb_count)
     remaining = value
     for idx in range(limb_count):
         low_part = remaining % 2 ** width
         limbs[idx].load_int(low_part)
         remaining >>= width
     self.load_other(limbs)
Exemple #6
0
 def get_bits_loop():
     """Draw ``bit_length`` random secret bits, store them and test them.

     Bit ``i`` is stored at memory address ``i``.  Returns ``c != 1``, where
     ``c`` is the revealed ``BitLTFull(bits, pbits, bit_length)``.
     Presumably this is a ``do_while`` body that repeats while the drawn bits
     are rejected; confirm against the enclosing function.
     """
     # TODO: can this use a vectorized load of the bits?
     # Bits are collected directly; the previous ``[sint(0)] * bit_length``
     # placeholder emitted a useless constant load on every iteration.
     tbits = []
     for i in range(bit_length):
         bit = sint.get_random_bit()
         bit.store_in_mem(i)
         tbits.append(bit)
     c = regint(BitLTFull(tbits, pbits, bit_length).reveal())
     return (c != 1)
Exemple #7
0
 def load_int(self, value):
     """Load the public integer ``value`` into this register.

     Widths of at most 64 bits go through a regint.  Wider registers only
     accept ``self.long_one()``, which is loaded as ``cint(1, size=self.n)``.

     :raises CompilerError: for any other value when ``self.n > 64``
     """
     wide = self.n > 64
     if wide and not (value == self.long_one()):
         raise CompilerError('loading long integers to cbits not supported')
     source = cint(1, size=self.n) if wide else regint(value)
     self.load_other(source)
Exemple #8
0
def convert_to_float(s):
    """Return a regint holding the IEEE-754 binary64 bit pattern of ``float(s)``.

    The 8-byte double encoding is reinterpreted as an unsigned 64-bit
    integer, with the sign bit as the most significant bit.
    """
    y = float(s)
    # Packing and unpacking with the same byte order reinterprets the bits
    # directly.  The former byte-by-byte reversal through
    # ``array.array('c', ...)`` only worked on Python 2: the 'c' typecode no
    # longer exists, and indexing ``bytes`` yields ``int``.
    (bit_pattern,) = struct.unpack("<Q", struct.pack("<d", y))
    return regint(bit_pattern)
Exemple #9
0
 def loop(j):
     """Emit one step of a two-run merge, writing the chosen element to ``B[j]``.

     Takes ``A[i0]`` (and advances ``i0``) when ``i0 < i_right`` and either
     ``i1 >= i_end`` or ``A[i0] <= A[i1]``; otherwise takes ``A[i1]`` and
     advances ``i1``.  ``i0``, ``i1``, ``i_right``, ``i_end``, ``A`` and ``B``
     come from the enclosing scope -- presumably a merge-sort pass over runs
     ``[i0, i_right)`` and ``[i1, i_end)``; confirm against the caller.

     :param j: output index into ``B``
     """
     # NOTE(review): the comparison A[i0] <= A[i1] is revealed, so its
     # outcome becomes public.
     if_then(and_(lambda: i0 < i_right,
                  or_(lambda: i1 >= i_end,
                      lambda: regint(reveal(A[i0] <= A[i1])))))
     # Condition true: consume the head of the first run.
     B[j] = A[i0]
     i0.iadd(1)
     else_then()
     # Condition false: consume the head of the second run.
     B[j] = A[i1]
     i1.iadd(1)
     end_if()
Exemple #10
0
 def __invert__(self):
     """Implement ``~self`` as ``self + one``.

     ``one`` has the same type as ``self``.  When the width is unknown or
     exceeds ``self.unit``, it is built from ``regint(1, size=self.n)`` via
     ``conv_regint_by_bit``; otherwise it is ``self.long_one()``.
     NOTE(review): this relies on ``+`` acting as bitwise XOR for this type;
     confirm for every subclass using it.
     """
     # An earlier variant emitted a dedicated ``nots`` instruction instead of
     # this addition.
     if self.n is None or self.n > self.unit:
         one = self.get_type(self.n)()
         self.conv_regint_by_bit(self.n, one, regint(1, size=self.n))
     else:
         one = self.new(value=self.long_one(), n=self.n)
     return self + one
Exemple #11
0
 def loop(j):
     """Emit one step of a two-run merge, writing the chosen element to ``B[j]``.

     Takes ``A[i0]`` (and advances ``i0``) when ``i0 < i_right`` and either
     ``i1 >= i_end`` or ``A[i0] <= A[i1]``; otherwise takes ``A[i1]`` and
     advances ``i1``.  The indices, bounds and arrays come from the enclosing
     scope -- presumably a merge-sort pass; confirm against the caller.

     :param j: output index into ``B``
     """
     # NOTE(review): the comparison result is revealed (made public).
     if_then(and_(lambda: i0 < i_right,
                  or_(lambda: i1 >= i_end,
                      lambda: regint(reveal(A[i0] <= A[i1])))))
     # Condition true: take from the first run.
     B[j] = A[i0]
     i0.iadd(1)
     else_then()
     # Condition false: take from the second run.
     B[j] = A[i1]
     i1.iadd(1)
     end_if()
Exemple #12
0
 def load_int(self, value):
     """Load the public integer ``value`` into this register.

     Values with ``abs(value) < 2**31`` are loaded with ``ldbits`` after
     checking that they fit in ``self.n`` bits.  Larger values are reduced
     modulo ``2**self.n`` and loaded through a regint, which requires the
     reduced value to fit in 64 bits.

     :raises Exception: if a small value needs more than ``self.n`` bits
     :raises NotImplementedError: if the reduced value exceeds 64 bits
     """
     if abs(value) < 2**31:
         # ``abs(value) == 2**n`` already needs n + 1 bits (and is 0 modulo
         # 2**n), so the boundary itself must be rejected.
         if abs(value) >= (1 << self.n):
             raise Exception('public value %d longer than %d bits'
                             % (value, self.n))
         inst.ldbits(self, self.n, value)
     else:
         value %= 2**self.n
         if value >> 64 != 0:
             raise NotImplementedError('public value too large')
         self.load_other(regint(value))
Exemple #13
0
    def match(self, n_loops=None):
        """Run the proposal loop matching men to women.

        This appears to be a Gale-Shapley-style matching, judging by
        ``propose``, ``unengaged`` and ``m_prefs``.  In each round an
        unengaged man is popped and proposes to his next preference.

        :param n_loops: fixed number of rounds.  When ``None`` or larger than
            ``N * M``, a ``do_while`` loop runs until no unengaged man remains.
        """
        if n_loops is None or n_loops > self.N * self.M:
            loop = do_while
            init_rounds = self.N
        else:
            loop = for_range(n_loops)
            # Integer division: for_range needs an int, and ``/`` yields a
            # float on Python 3.
            init_rounds = n_loops // self.M

        def new_oram():
            # The three per-person tables share one shape.
            return self.oram_type(self.N, entry_size=log2(self.N),
                                  init_rounds=0, value_type=self.basic_type)

        self.wives = new_oram()
        self.husbands = new_oram()
        propose = new_oram()
        self.unengaged = OStack(self.N, oram_type=self.oram_type,
                                int_type=self.int_type)

        @for_range(init_rounds)
        def f(i):
            self.unengaged.append(i)

        rounds = types.MemValue(types.regint(0))

        @loop
        def f(i=None):
            rounds.iadd(1)
            time()
            man = self.unengaged.pop()
            pref = self.int_type(propose[man])
            if self.M < self.N and n_loops is None:

                @if_((pref == self.M).reveal())
                def f():
                    print_ln('run out of acceptable women')
                    crash()

            propose[man] = pref + 1
            self.propose(man, self.m_prefs[man][pref], True)
            print_ln('man: %s, pref: %s, left: %s',
                     *(x.reveal() for x in (man, pref, self.unengaged.size)))
            # Nonzero while unengaged men remain (do_while continuation).
            return types.regint((self.unengaged.size > 0).reveal())

        print_ln('%s rounds', rounds)

        @for_range(init_rounds)
        def f(i):
            types.cint(i).print_reg('wife')
            self.husbands[i].reveal().print_reg('husb')
Exemple #14
0
def do_loop(condition, loop_fn):
    """Run ``loop_fn`` repeatedly, threading a regint through the stack.

    ``condition`` (converted to regint unless it already is one) is pushed
    first.  Each iteration pops the current value, passes it to ``loop_fn``
    and pushes the result.  That result is also returned as the
    ``do_while`` continuation condition.  The final value is popped after
    the loop.
    """
    initial = condition if isinstance(condition, regint) else regint(condition)
    pushint(initial)

    def iteration():
        current = regint.pop()  # take the value left by the previous step
        updated = loop_fn(current)
        pushint(updated)
        return updated

    do_while(iteration)
    regint.pop()
Exemple #15
0
 def load_int(self, value):
     """Load the public integer ``value`` into this register.

     Widths up to 32 bits use ``ldbits`` directly.  Wider registers are
     filled limb by limb (``self.unit`` bits per limb, least significant
     first) through a vector regint.

     :raises Exception: if ``value`` needs more than ``self.n`` bits
     """
     # ``abs(value) == 2**n`` already needs n + 1 bits, so reject it too.
     if abs(value) >= (1 << self.n):
         raise Exception('public value %d longer than %d bits'
                         % (value, self.n))
     if self.n <= 32:
         inst.ldbits(self, self.n, value)
     else:
         size = math.ceil(self.n / self.unit)
         tmp = regint(size=size)
         for i in range(size):
             # Shift and mask by the same limb width used for the limb count
             # (previously hard-coded to 64).
             tmp[i].load_int((value >> (i * self.unit)) % 2**self.unit)
         self.load_other(tmp)
Exemple #16
0
def do_loop(condition, loop_fn):
    """Repeat ``loop_fn`` with its state kept on the regint stack.

    :param condition: initial state/condition; converted to ``regint``
        unless it already is one
    :param loop_fn: called with the popped state on every iteration.  Its
        return value is pushed back and decides whether ``do_while``
        continues.
    """
    if not isinstance(condition, regint):
        condition = regint(condition)
    pushint(condition)

    def body():
        state = loop_fn(regint.pop())
        pushint(state)
        return state

    do_while(body)
    # Discard the state left on the stack by the last iteration.
    regint.pop()
Exemple #17
0
def BitDecFull(a):
    """Bit-decompose the secret ``a`` with respect to the prime ``program.P``.

    Primes of more than 63 bits are handled by ``BitDecFullBig``.

    Otherwise:
    - Random secret bits encoding some ``b < p`` are drawn by rejection
      sampling and kept at memory addresses ``0 .. bit_length - 1``.
    - ``c = a - b`` is revealed.
    - The bits of ``a`` are rebuilt with ``BitAdd`` from the bits of ``b`` and
      the bits of either ``c`` or ``2**bit_length + c - p``, depending on
      whether ``b + c`` reaches ``p``.  This assumes ``BitLTFull(x, y, n)``
      computes ``x < y``.

    :param a: secret value (``sint``)
    :return: list of ``bit_length`` secret bits
    """
    from Compiler.types import sint, regint
    from Compiler.library import do_while
    p = program.P
    bit_length = p.bit_length()
    if bit_length > 63:
        return BitDecFullBig(a)
    pbits = list(bits(p, bit_length))

    # Loop until we get some random integers less than p
    @do_while
    def get_bits_loop():
        # TODO: can this use a vectorized load of the bits?
        # Collect bits directly; a ``[sint(0)] * n`` placeholder would emit a
        # wasted constant load inside the loop.
        tbits = []
        for i in range(bit_length):
            bit = sint.get_random_bit()
            bit.store_in_mem(i)
            tbits.append(bit)
        c = regint(BitLTFull(tbits, pbits, bit_length).reveal())
        return (c != 1)

    bbits = [sint.load_mem(i) for i in range(bit_length)]
    b = SumBits(bbits, bit_length)
    # Reveal c in the correct range
    c = regint((a - b).reveal())
    # Presumably the regint conversion can yield a negative value; map it
    # back into [0, p).
    bit = c < 0
    c = c + p * bit
    czero = (c == 0)
    t = (p - c).bit_decompose(bit_length)
    # q = 1 when b >= p - c, i.e. when b + c wraps past p.
    q = 1 - BitLTFull(bbits, t, bit_length)
    fbar = ((1 << bit_length) + c - p).bit_decompose(bit_length)
    fbard = regint(c).bit_decompose(bit_length)
    # Select the bits of c - p (mod 2**bit_length) when wrapping, else of c.
    g = [(fbar[i] - fbard[i]) * q + fbard[i] for i in range(bit_length)]
    h = BitAdd(bbits, g)
    # When c == 0 the result is taken directly from the bits of b.
    return [(1 - czero) * h[i] + czero * bbits[i] for i in range(bit_length)]
Exemple #18
0
def while_loop(loop_body, condition, arg):
    """Run ``loop_body`` while ``condition`` holds, threading a regint value.

    ``condition(arg)`` is evaluated first.  If it is a plain Python
    ``bool``/``int`` that is false, nothing is emitted.  Otherwise ``arg`` is
    pushed, and under ``if_statement`` on that first result a ``do_while``
    loop repeatedly:
    - pops the value,
    - applies ``loop_body``,
    - pushes the result,
    - tests ``condition`` on it.
    The final value is popped afterwards.

    :raises CompilerError: if ``condition`` is not callable
    """
    if not callable(condition):
        raise CompilerError('Condition must be callable')
    first_check = condition(arg)
    if isinstance(first_check, (bool, int)) and not first_check:
        # Known to be false at compile time: the body never runs.
        return
    pushint(arg if isinstance(arg, regint) else regint(arg))

    def step():
        next_value = loop_body(regint.pop())
        pushint(next_value)
        return condition(next_value)

    if_statement(first_check, lambda: do_while(step))
    regint.pop()
Exemple #19
0
def while_loop(loop_body, condition, arg):
    """Emit a guarded ``do_while`` loop driven by ``condition``.

    :param loop_body: maps the current regint value to the next one
    :param condition: callable applied to the initial ``arg`` and to every
        new value; a false compile-time result skips the loop entirely
    :param arg: initial value, converted to regint if necessary
    :raises CompilerError: if ``condition`` is not callable
    """
    if not callable(condition):
        raise CompilerError('Condition must be callable')
    initial_check = condition(arg)
    statically_false = isinstance(initial_check, (bool, int)) and not initial_check
    if statically_false:
        return
    start = arg if isinstance(arg, regint) else regint(arg)
    pushint(start)

    def iterate():
        value = loop_body(regint.pop())
        pushint(value)  # keep the value on the stack for the next round
        return condition(value)

    if_statement(initial_check, lambda: do_while(iterate))
    regint.pop()
Exemple #20
0
 def load_int(self, value):
     """Load the public integer ``value`` into this register.

     The loading strategy depends on the width ``self.n``:
     - up to 32 bits: ``ldbits``;
     - up to 64 bits: through a regint;
     - up to 128 bits: as a 64-bit low part plus an (n - 64)-bit high part.

     :raises Exception: if ``value`` needs more than ``self.n`` bits
     :raises NotImplementedError: if ``self.n > 128``
     """
     # ``abs(value) == 2**n`` already needs n + 1 bits, so reject it too.
     if abs(value) >= (1 << self.n):
         raise Exception('public value %d longer than %d bits'
                         % (value, self.n))
     if self.n <= 32:
         inst.ldbits(self, self.n, value)
     elif self.n <= 64:
         self.load_other(regint(value))
     elif self.n <= 128:
         lower = sbits.get_type(64)(value % 2**64)
         upper = sbits.get_type(self.n - 64)(value >> 64)
         self.mov(self, lower + (upper << 64))
     else:
         raise NotImplementedError('more than 128 bits wanted')
Exemple #21
0
 def load_int(self, value):
     """Load the public integer ``value`` into this register.

     The loading strategy depends on the width ``self.n``:
     - up to 32 bits: ``ldbits``;
     - up to 64 bits: through a regint;
     - up to 128 bits: as a 64-bit low part plus an (n - 64)-bit high part.

     :raises Exception: if ``value`` needs more than ``self.n`` bits
     :raises NotImplementedError: if ``self.n > 128``
     """
     # ``abs(value) == 2**n`` already needs n + 1 bits, so reject it too.
     if abs(value) >= (1 << self.n):
         raise Exception('public value %d longer than %d bits'
                         % (value, self.n))
     if self.n <= 32:
         inst.ldbits(self, self.n, value)
     elif self.n <= 64:
         self.load_other(regint(value))
     elif self.n <= 128:
         lower = sbits.get_type(64)(value % 2**64)
         upper = sbits.get_type(self.n - 64)(value >> 64)
         self.mov(self, lower + (upper << 64))
     else:
         raise NotImplementedError('more than 128 bits wanted')
def BitDecFull(a):
    """Bit-decompose the secret ``a`` with respect to the program prime ``p``.

    Steps:
    - A random ``b`` with known secret bits is obtained.  If ``p`` is close
      enough to a power of two, uniform random bits are used directly.
      Otherwise bits are drawn per vector entry until they encode a value
      below ``p``.
    - ``c = a - b`` is revealed.
    - The bits of ``a`` are computed with ``BitAdd`` from the bits of ``b``
      and the bits of either ``c`` or ``2**bit_length + c - p``, the latter
      when ``b + c`` reaches ``p``.  This assumes ``BITLT(x, y, n)`` computes
      ``x < y``.

    :param a: secret value (``sint``); may be vectorized (``a.size``)
    :return: list of ``p.bit_length()`` secret bits
    """
    # NOTE(review): break_point is imported but not used in this function.
    from .library import get_program, do_while, if_, break_point
    from .types import sint, regint, longint
    p = get_program().prime
    # NOTE(review): assert is stripped under -O; p must be set and nonzero.
    assert p
    bit_length = p.bit_length()
    logp = int(round(math.log(p, 2)))
    # Relative distance of p from the nearest power of two, compared with
    # the statistical security parameter.
    if abs(p - 2**logp) / p < 2**-get_program().security:
        # inspired by Rabbit (https://eprint.iacr.org/2021/119)
        # no need for exact randomness generation
        # if modulo a power of two is close enough
        bbits = [sint.get_random_bit(size=a.size) for i in range(logp)]
        # When p lies just above 2**logp it has one more bit; pad with zero.
        if logp != bit_length:
            bbits += [sint(0, size=a.size)]
    else:
        bbits = [sint(size=a.size) for i in range(bit_length)]
        # One candidate bit list per vector entry j.
        tbits = [[sint(size=1) for i in range(bit_length)]
                 for j in range(a.size)]
        pbits = util.bit_decompose(p)
        # Loop until we get some random integers less than p
        # done[j] becomes the revealed acceptance flag for entry j.
        done = [regint(0) for i in range(a.size)]

        @do_while
        def get_bits_loop():
            for j in range(a.size):

                # Only redraw entries that have not been accepted yet.
                @if_(done[j] == 0)
                def _():
                    for i in range(bit_length):
                        tbits[j][i].link(sint.get_random_bit())
                    c = regint(BITLT(tbits[j], pbits, bit_length).reveal())
                    done[j].link(c)

            # Continue while some entry is still unaccepted.
            return (sum(done) != a.size)

        # Copy accepted candidate bits into lane j of each vectorized bit.
        for j in range(a.size):
            for i in range(bit_length):
                movs(bbits[i][j], tbits[j][i])
    b = sint.bit_compose(bbits)
    # c = a - b is made public.
    c = (a - b).reveal()
    t = (p - c).bit_decompose(bit_length)
    c = longint(c, bit_length)
    czero = (c == 0)
    # q = 1 when b >= p - c, i.e. b + c wraps past p.
    q = 1 - BITLT(bbits, t, bit_length)
    fbar = ((1 << bit_length) + c - p).bit_decompose(bit_length)
    fbard = c.bit_decompose(bit_length)
    # Select bits of c - p (mod 2**bit_length) when wrapping, else of c.
    g = [(fbar[i] - fbard[i]) * q + fbard[i] for i in range(bit_length)]
    h = BitAdd(bbits, g)
    # When c == 0 the result is taken directly from the bits of b.
    abits = [(1 - czero) * h[i] + czero * bbits[i] for i in range(bit_length)]
    return abits
Exemple #23
0
 def on_call(self, base, bases):
     """Emit a call from the active basic block to this function's block.

     :param base: optional value passed with ``starg`` (converted to regint)
     :param bases: not used here -- NOTE(review): presumably kept for
         interface compatibility
     :raises CompilerError: if this function's block is not registered on
         the tape
     :return: ``unmemorize(self.result)`` when a result is set, else None
     """
     if base is not None:
         instructions.starg(regint(base))
     block = self.basic_block
     tape = get_tape()
     if block not in tape.function_basicblocks:
         raise CompilerError('unknown function')
     old_block = tape.active_basicblock
     old_block.set_exit(instructions.CALL(0, add_to_prog=False), block)
     # The CALL-based variant needs no explicit return-address register; the
     # former unused lookup of the return-address slot and the unused
     # ``new_reg('r')`` allocation have been removed.
     tape.start_new_basicblock(name='call-' + self.name)
     tape.active_basicblock.set_return(old_block, self.last_sub_block)
     tape.req_node.children.append(self.node)
     if self.result is not None:
         return unmemorize(self.result)
Exemple #24
0
def or_(*terms):
    """Build a short-circuit OR of ``terms`` with nested if/else blocks.

    Each term is a callable that emits and returns a condition.  The emitted
    code is ``if t1: r = 1 else: if t2: r = 1 else: ... r = 0``, so a term's
    code only runs when all previous terms were false.  ``r`` lives in a
    memory cell of type 'ci' allocated here.

    :return: a callable that loads ``r`` as a regint and frees the cell;
        presumably it should be called exactly once, after the emitted code
    """
    # not thread-safe
    p_res = instructions.program.malloc(1, 'ci')
    for term in terms:
        if_then(term())
        store_in_mem(1, p_res)
        else_then()
    # Innermost else: reached only when every term was false.
    store_in_mem(0, p_res)
    # Close one if/else per term.
    for _ in terms:
        end_if()

    def load_result():
        res = regint.load_mem(p_res)
        instructions.program.free(p_res, 'ci')
        return res

    return load_result
Exemple #25
0
def or_(*terms):
    """Build a short-circuit OR of ``terms`` with nested if/else blocks.

    Each term is a callable that emits and returns a condition.  The emitted
    code is ``if t1: r = 1 else: if t2: r = 1 else: ... r = 0``, so a term's
    code only runs when all previous terms were false.  ``r`` lives in a
    memory cell of type 'r' allocated here.

    :return: a callable that loads ``r`` as a regint and frees the cell;
        presumably it should be called exactly once, after the emitted code
    """
    # not thread-safe
    p_res = instructions.program.malloc(1, 'r')
    for term in terms:
        if_then(term())
        store_in_mem(1, p_res)
        else_then()
    # Innermost else: reached only when every term was false.
    store_in_mem(0, p_res)
    # Close one if/else per term.
    for _ in terms:
        end_if()

    def load_result():
        res = regint.load_mem(p_res)
        instructions.program.free(p_res, 'r')
        return res

    return load_result
Exemple #26
0
 def match(self, n_loops=None):
     """Run the proposal loop matching men to women.

     This appears to be a Gale-Shapley-style matching, judging by
     ``propose``, ``unengaged`` and ``m_prefs``.  In each round an unengaged
     man is popped and proposes to his next preference.

     :param n_loops: fixed number of rounds.  When ``None`` or larger than
         ``N * M``, a ``do_while`` loop runs until no unengaged man remains.
     """
     if n_loops is None or n_loops > self.N * self.M:
         loop = do_while
         init_rounds = self.N
     else:
         loop = for_range(n_loops)
         # Integer division: for_range needs an int, and ``/`` yields a
         # float on Python 3.
         init_rounds = n_loops // self.M

     def new_oram():
         # The three per-person tables share one shape.
         return self.oram_type(self.N, entry_size=log2(self.N),
                               init_rounds=0, value_type=self.basic_type)

     self.wives = new_oram()
     self.husbands = new_oram()
     propose = new_oram()
     self.unengaged = OStack(self.N, oram_type=self.oram_type,
                             int_type=self.int_type)

     @for_range(init_rounds)
     def f(i):
         self.unengaged.append(i)

     rounds = types.MemValue(types.regint(0))

     @loop
     def f(i=None):
         rounds.iadd(1)
         time()
         man = self.unengaged.pop()
         pref = self.int_type(propose[man])
         if self.M < self.N and n_loops is None:
             @if_((pref == self.M).reveal())
             def f():
                 print_ln('run out of acceptable women')
                 crash()
         propose[man] = pref + 1
         self.propose(man, self.m_prefs[man][pref], True)
         print_ln('man: %s, pref: %s, left: %s',
                  *(x.reveal() for x in (man, pref, self.unengaged.size)))
         # Nonzero while unengaged men remain (do_while continuation).
         return types.regint((self.unengaged.size > 0).reveal())

     print_ln('%s rounds', rounds)

     @for_range(init_rounds)
     def f(i):
         types.cint(i).print_reg('wife')
         self.husbands[i].reveal().print_reg('husb')
Exemple #27
0
 def on_call(self, base, bases):
     """Emit a jump-based call to this function's basic block.

     The caller block is terminated with a ``jmp`` to the function.  A
     return address is loaded into a fresh 'ci' register (``ldint`` with a
     0 placeholder, remembered as ``return_address_store``) and stored with
     ``stmint`` to the function's return-address slot.  A new block named
     ``'call-' + self.name`` is then started as the return point.

     :param base: optional value passed with ``starg`` (converted to regint)
     :param bases: not used here
     :raises CompilerError: if the function's block is not registered
     :return: ``unmemorize(self.result)`` when a result is set, else None
     """
     if base is not None:
         instructions.starg(regint(base))
     target_block = self.basic_block
     tape = get_tape()
     if target_block not in tape.function_basicblocks:
         raise CompilerError('unknown function')
     caller_block = tape.active_basicblock
     caller_block.set_exit(instructions.jmp(0, add_to_prog=False), target_block)
     ret_slot = tape.function_basicblocks[target_block]
     ret_reg = tape.new_reg('ci')
     # Presumably the 0 immediate is patched to the real address later.
     caller_block.return_address_store = instructions.ldint(ret_reg, 0)
     instructions.stmint(ret_reg, ret_slot)
     tape.start_new_basicblock(name='call-' + self.name)
     tape.active_basicblock.set_return(caller_block, self.last_sub_block)
     tape.req_node.children.append(self.node)
     if self.result is not None:
         return unmemorize(self.result)
Exemple #28
0
 def f(i=None):
     """One proposal round of the matching loop.

     The next unengaged man proposes to his next preference.  Aborts via
     ``crash()`` if his preference index reaches ``self.M``; this check is
     only emitted when ``M < N`` and ``n_loops`` is None.

     :return: regint that is nonzero while unengaged men remain
     """
     rounds.iadd(1)
     time()
     suitor = self.unengaged.pop()
     choice_idx = self.int_type(propose[suitor])
     if self.M < self.N and n_loops is None:
         @if_((choice_idx == self.M).reveal())
         def _exhausted():
             print_ln('run out of acceptable women')
             crash()
     propose[suitor] = choice_idx + 1
     self.propose(suitor, self.m_prefs[suitor][choice_idx], True)
     revealed = (v.reveal() for v in (suitor, choice_idx, self.unengaged.size))
     print_ln('man: %s, pref: %s, left: %s', *revealed)
     still_open = (self.unengaged.size > 0).reveal()
     return types.regint(still_open)
Exemple #29
0
 def f(i=None):
     """Loop body: pop an unengaged man and let him propose once.

     ``rounds`` counts the iterations.  The returned regint is nonzero while
     the unengaged stack is not empty.
     """
     rounds.iadd(1)
     time()
     proposer = self.unengaged.pop()
     rank = self.int_type(propose[proposer])
     guard_needed = self.M < self.N and n_loops is None
     if guard_needed:
         @if_((rank == self.M).reveal())
         def _no_more_choices():
             print_ln('run out of acceptable women')
             crash()
     propose[proposer] = rank + 1
     target = self.m_prefs[proposer][rank]
     self.propose(proposer, target, True)
     print_ln('man: %s, pref: %s, left: %s',
              *(val.reveal() for val in (proposer, rank, self.unengaged.size)))
     return types.regint((self.unengaged.size > 0).reveal())
Exemple #30
0
def bits(a, m):
    """Get the ``m`` low bits of ``a``.

    For a Python ``int`` the result is a list of 0/1 ints, least significant
    bit first.  Otherwise ``a`` is processed in 64-bit chunks: each chunk is
    converted with ``convmodp`` into a regint and decomposed, and every bit
    is wrapped in a ``cint``.
    """
    if isinstance(a, int):
        return [(a >> pos) & 1 for pos in range(m)]
    from Compiler.types import regint, cint
    out = []
    remaining = m
    while remaining > 0:
        chunk = regint()
        convmodp(chunk, a, bitlength=0)
        out.extend(cint(bit) for bit in chunk.bit_decompose(min(64, remaining)))
        remaining -= 64
        if remaining > 0:
            # Move the next 64-bit chunk into the low bits.
            shifted = cint()
            shrci(shifted, a, 64)
            a = shifted
    return out
Exemple #31
0
def bits(a, m):
    """Return the ``m`` low bits of ``a``.

    :param a: a Python ``int``, or a clear register value processed 64 bits
        at a time via ``convmodp``/``shrci``
    :param m: number of bits wanted
    :return: for an ``int``, 0/1 ints with the least significant bit first;
        otherwise ``cint`` bits in ``bit_decompose`` order per chunk
    """
    if isinstance(a, int):
        shifted = a
        collected = []
        for _ in range(m):
            collected.append(shifted & 1)
            shifted >>= 1
        return collected
    from Compiler.types import regint, cint
    collected = []
    left = m
    while left > 0:
        low = regint()
        convmodp(low, a, bitlength=0)
        collected += [cint(b) for b in low.bit_decompose(min(64, left))]
        left -= 64
        if left > 0:
            nxt = cint()
            shrci(nxt, a, 64)
            a = nxt
    return collected
Exemple #32
0
def BitDecFull(a):
    """Bit-decompose the secret ``a`` with respect to the prime ``options.prime``.

    Steps:
    - For every vector entry, random bits are drawn until they encode a value
      ``b < p``.
    - ``c = a - b`` is revealed.
    - The bits of ``a`` are computed with ``BitAdd`` from the bits of ``b``
      and the bits of either ``c`` or ``2**bit_length + c - p``, the latter
      when ``b + c`` reaches ``p``.  This assumes ``BITLT(x, y, n)`` computes
      ``x < y``.

    :param a: secret value (``sint``); may be vectorized (``a.size``)
    :return: list of ``p.bit_length()`` secret bits
    """
    # NOTE(review): break_point is imported but not used in this function.
    from .library import get_program, do_while, if_, break_point
    from .types import sint, regint, longint
    # NOTE(review): if options.prime can be None, int() raises TypeError
    # before the assert below is reached.
    p = int(get_program().options.prime)
    assert p
    bit_length = p.bit_length()
    bbits = [sint(size=a.size) for i in range(bit_length)]
    # One candidate bit list per vector entry j.
    tbits = [[sint(size=1) for i in range(bit_length)] for j in range(a.size)]
    pbits = util.bit_decompose(p)
    # Loop until we get some random integers less than p
    # done[j] becomes the revealed acceptance flag for entry j.
    done = [regint(0) for i in range(a.size)]

    @do_while
    def get_bits_loop():
        for j in range(a.size):

            # Only redraw entries that have not been accepted yet.
            @if_(done[j] == 0)
            def _():
                for i in range(bit_length):
                    tbits[j][i].link(sint.get_random_bit())
                c = regint(BITLT(tbits[j], pbits, bit_length).reveal())
                done[j].link(c)

        # Continue while some entry is still unaccepted.
        return (sum(done) != a.size)

    # Copy accepted candidate bits into lane j of each vectorized bit.
    for j in range(a.size):
        for i in range(bit_length):
            movs(bbits[i][j], tbits[j][i])
    b = sint.bit_compose(bbits)
    # c = a - b is made public.
    c = (a - b).reveal()
    t = (p - c).bit_decompose(bit_length)
    c = longint(c, bit_length)
    czero = (c == 0)
    # q = 1 when b >= p - c, i.e. b + c wraps past p.
    q = 1 - BITLT(bbits, t, bit_length)
    fbar = ((1 << bit_length) + c - p).bit_decompose(bit_length)
    fbard = c.bit_decompose(bit_length)
    # Select bits of c - p (mod 2**bit_length) when wrapping, else of c.
    g = [(fbar[i] - fbard[i]) * q + fbard[i] for i in range(bit_length)]
    h = BitAdd(bbits, g)
    # When c == 0 the result is taken directly from the bits of b.
    abits = [(1 - czero) * h[i] + czero * bbits[i] for i in range(bit_length)]
    return abits
Exemple #33
0
def load_int(value, size=None):
    """Return ``regint(value, size=size)``."""
    loaded = regint(value, size=size)
    return loaded
Exemple #34
0
def public_input():
    """Return a new regint filled by the ``pubinput`` instruction."""
    value = regint()
    pubinput(value)
    return value
Exemple #35
0
def get_thread_number():
    """Return a new regint loaded with the ``ldtn`` instruction."""
    thread_no = regint()
    ldtn(thread_no)
    return thread_no
Exemple #36
0
 def load_int(self, value):
     """Load ``value`` by first placing it in a regint."""
     as_regint = regint(value)
     self.load_other(as_regint)
 def _():
     """Redraw the candidate bits of entry ``j`` and record whether they pass.

     Each of the ``bit_length`` bits of ``tbits[j]`` is linked to a fresh
     random bit.  ``done[j]`` is then linked to the revealed
     ``BITLT(tbits[j], pbits, bit_length)``.
     """
     candidate = tbits[j]
     for pos in range(bit_length):
         candidate[pos].link(sint.get_random_bit())
     accepted = regint(BITLT(candidate, pbits, bit_length).reveal())
     done[j].link(accepted)
Exemple #38
0
def get_arg():
    """Return a new regint loaded with the ``ldarg`` instruction."""
    argument = regint()
    ldarg(argument)
    return argument
Exemple #39
0
def public_input():
    """Read a public input via ``pubinput`` into a freshly allocated regint."""
    received = regint()
    pubinput(received)
    return received
Exemple #40
0
 def load_int(self, value):
     """Load the public integer ``value`` through a temporary regint."""
     staged = regint(value)
     self.load_other(staged)
Exemple #41
0
 def load_int(self, value):
     """Load the public integer ``value`` into this register.

     Values with ``abs(value) < 2**31`` are loaded with the ``ldsi``
     instruction.  Larger ones are first placed in a regint and loaded via
     ``load_other``.
     """
     if abs(value) >= 2**31:
         self.load_other(regint(value))
     else:
         inst.ldsi(self, value)
Exemple #42
0
def get_arg():
    """Load the current argument register value (``ldarg``) into a new regint."""
    arg_reg = regint()
    ldarg(arg_reg)
    return arg_reg
Exemple #43
0
def get_thread_number():
    """Return a regint filled by ``ldtn`` (presumably the current thread number)."""
    tn = regint()
    ldtn(tn)
    return tn
Exemple #44
0
def open_channel(channel=0):
    """Open ``channel`` with the ``open_chan`` instruction.

    :param channel: channel number, 0 by default
    :return: the regint written by ``open_chan``.  Per the original
        description it reports any possible error; presumably this is a
        status code, to be confirmed against the instruction documentation.
    """
    status = regint()
    open_chan(status, channel)
    return status
Exemple #45
0
def load_int(value, size=None):
    """Create a regint from ``value``, forwarding ``size`` to the constructor."""
    reg = regint(value, size=size)
    return reg