Beispiel #1
0
def generate_matmul_triple_communication(shapes, workers):
    """Create a matmul triple, split each element into shares and send them out.

    Args:
        shapes: forwarded unchanged to ``generate_matmul_triple``.
        workers: sequence of workers; its length is the number of shares
            requested for every element of the triple, and the i-th share of
            each element is sent to the i-th worker.

    Returns:
        A list ``[gp_r, gp_s, gp_t]`` holding one generalized pointer tensor
        per element of the triple.
    """
    r, s, t = generate_matmul_triple(shapes)
    plain_values = (r, s, t)

    share_count = len(workers)
    shares_per_value = [share(value, share_count) for value in plain_values]

    # Ship every share to its worker before any pointer dict is assembled.
    for value_shares in shares_per_value:
        for piece, worker in zip(value_shares, workers):
            piece.send(worker)

    # Each pointer dict maps a share's location to the share with the head of
    # the pointer removed (its ``.child``); the result is then attached to the
    # corresponding plain value via ``.on``.
    triple = []
    for value, value_shares in zip(plain_values, shares_per_value):
        pointer_dict = {piece.location: piece.child for piece in value_shares}
        triple.append(sy._GeneralizedPointerTensor(pointer_dict).on(value))
    return triple
Beispiel #2
0
def generate_zero_shares_communication(alice, bob, *sizes):
    """Share an all-zero tensor between two workers.

    Args:
        alice: worker that receives the first share.
        bob: worker that receives the second share.
        *sizes: dimensions passed straight to ``torch.zeros``.

    Returns:
        A generalized pointer tensor mapping ``alice`` and ``bob`` to the
        ``.child`` of their respective shares.
    """
    zero_tensor = torch.zeros(*sizes)
    alice_share, bob_share = share(zero_tensor)

    # Deliver alice's share first, then bob's.
    for piece, owner in ((alice_share, alice), (bob_share, bob)):
        piece.send(owner)

    pointer_dict = {alice: alice_share.child, bob: bob_share.child}
    return sy._GeneralizedPointerTensor(pointer_dict)
Beispiel #3
0
def swap_shares(shares):
    """Exchange the two parties' shares of a shared tensor.

    Args:
        shares: object whose ``.child.pointer_tensor_dict`` has exactly two
            keys (unpacking fails otherwise); the keys are used as the two
            workers.

    Returns:
        A generalized pointer tensor, attached to an empty ``sy.LongTensor``,
        in which each worker's entry is the tensor that was sent over from the
        other worker.
    """
    pointers = shares.child.pointer_tensor_dict
    alice, bob = tuple(pointers.keys())

    # Adding 0 yields a new tensor object for each side before it is sent
    # (presumably so the original pointers are left untouched -- TODO confirm).
    from_alice = pointers[alice] + 0
    from_bob = pointers[bob] + 0

    # Cross-send: what came from alice goes to bob and vice versa.
    from_alice.send(bob)
    from_bob.send(alice)

    swapped = {alice: from_bob, bob: from_alice}
    return sy._GeneralizedPointerTensor(swapped).on(sy.LongTensor([]))
Beispiel #4
0
def msb(a_sh, alice, bob):
    """Run the numbered protocol steps below on the shared tensor ``a_sh``.

    Judging by the name and the cited SecureNN paper, this presumably computes
    shares of the most significant bit of ``a_sh`` -- confirm against the paper.

    The input is flattened with ``view(-1)`` for the computation and the result
    is viewed back to the original shape before being returned.

    Args:
        a_sh: shared tensor; must support ``get_shape``, ``view`` and
            arithmetic with other shared tensors.
        alice: one of the two workers holding shares.
        bob: the other worker holding shares.

    Returns:
        The tensor ``a`` built in step 10, reshaped to ``a_sh``'s input shape.

    Note:
        Relies on names defined outside this function: ``L``, ``decompose``,
        ``private_compare`` and ``syft``.
    """

    # Remember the caller's shape, then work on a flat view.
    input_shape = a_sh.get_shape()
    a_sh = a_sh.view(-1)

    # the commented out numbers below correspond to the
    # line numbers in Table 5 of the SecureNN paper
    # https://eprint.iacr.org/2018/442.pdf

    # 1)
    # Sample a random mask x (via random_(L - 1)), decompose it into bits, and
    # share x, its bits, and the last bit slice between bob and alice.

    x = torch.LongTensor(a_sh.get_shape()).random_(L - 1)
    x_bit = decompose(x)
    x_sh = x.share(bob, alice)
    x_bit_0 = x_bit[..., -1:]  # pretty sure decompose is backwards...
    x_bit_sh_0 = x_bit_0.share(bob, alice).child.child  # least -> greatest from left -> right
    x_bit_sh = x_bit.share(bob, alice)

    # 2)
    # Mask twice the input with the shared random x.
    y_sh = 2 * a_sh
    r_sh = y_sh + x_sh

    # 3)
    # r is fetched with .get() and then re-sent to both workers, together with
    # the last slice of its bit decomposition.
    r = r_sh.get()  # .send(bob, alice) #TODO: make this secure by exchanging shares remotely
    r_0 = decompose(r)[..., -1].send(bob, alice)
    r = r.send(bob, alice)

    # j is all zeros on bob and all ones on alice, combined into one
    # generalized pointer; j_0 is its last slice along the final axis.
    j0 = torch.zeros(x_bit_sh.get_shape()).long().send(bob).child
    j1 = (torch.ones(x_bit_sh.get_shape())).long().send(alice).child
    j = syft._GeneralizedPointerTensor({bob: j0, alice: j1}, torch_type='syft.LongTensor').wrap(True)
    j_0 = j[..., -1]

    # 4)
    # Random 0/1 tensor BETA sent to both workers, then the private comparison
    # of the shared bits of x against r.
    BETA = (torch.rand(a_sh.get_shape()) > 0.5).long().send(bob, alice)
    BETA_prime = private_compare(x_bit_sh,
                                 r,
                                 BETA=BETA,
                                 j=j,
                                 alice=alice,
                                 bob=bob).long()
    # 5)
    # Share the comparison result and strip two wrapper levels (.child.child).
    BETA_prime_sh = BETA_prime.share(bob, alice).child.child

    # NOTE(review): no step 6) is marked in this code.

    # 7)
    # Has the form a + b - 2ab, which is XOR for bits -- presumably shares of
    # BETA xor BETA_prime, with j_0 selecting which side adds BETA (TODO confirm).
    _lambda = syft._SNNTensor(BETA_prime_sh + (j_0 * BETA) - (2 * BETA * BETA_prime_sh)).wrap(True)

    # 8)
    # Same a + b - 2ab form applied to the last bit slices of x and r.
    _delta = syft._SNNTensor(x_bit_sh_0.squeeze(-1) + (j_0 * r_0) - (2 * r_0 * x_bit_sh_0.squeeze(-1))).wrap(True)

    # 9)
    theta = _lambda * _delta

    # 10)
    # Combine _lambda and _delta (again a + b - 2ab, using theta as the
    # product) and add a fresh sharing of zeros.
    u = torch.zeros(list(theta.get_shape())).long().share(alice, bob)
    a = _lambda + _delta - (2 * theta) + u

    # Restore the caller's original shape.
    return a.view(*list(input_shape))
Beispiel #5
0
    def prepPC(self):
        """Build a fixed set of inputs: shared bits, ``r``, ``BETA`` and ``j``.

        The returned tuple lines up with the ``x_bit_sh, r, BETA, j`` arguments
        that ``private_compare`` is called with elsewhere -- presumably this is
        a fixture for exercising it (TODO confirm).

        Returns:
            ``(x_bit_sh, r, BETA, j)`` where
            * ``x_bit_sh`` is ``decompose([1, 2, 3, 4])`` shared across
              ``self.workers``,
            * ``r`` is a tensor of ones sent to the workers, plus 2,
            * ``BETA`` is a tensor of zeros sent to the workers,
            * ``j`` is a generalized pointer that is all zeros on
              ``self.workers[0]`` and all ones on ``self.workers[1]``.
        """
        values = torch.LongTensor([1, 2, 3, 4])
        shared_bits = decompose(values).share(*self.workers)

        ones_on_workers = torch.ones(values.get_shape()).long().send(*self.workers)
        r = ones_on_workers + 2

        # Zeros live on the first worker, ones on the second.
        zeros_on_first = (
            torch.zeros(shared_bits.get_shape()).long().send(self.workers[0]).child
        )
        ones_on_second = (
            torch.ones(shared_bits.get_shape()).long().send(self.workers[1]).child
        )
        pointer_dict = {
            self.workers[0]: zeros_on_first,
            self.workers[1]: ones_on_second,
        }
        j = sy._GeneralizedPointerTensor(
            pointer_dict, torch_type="syft.LongTensor"
        ).wrap(True)

        # 4)
        beta = torch.zeros(values.get_shape()).long().send(*self.workers)

        return shared_bits, r, beta, j