def spdz_mul(
    cmd: Callable, shares: dict, other_shares, crypto_provider: AbstractWorker, field: int, **kwargs
):
    """Abstractly multiply two additively shared tensors using a triple.

    Args:
        cmd: a callable of the equation to be computed (e.g. mul or matmul).
        shares: a dictionary {location_id -> PointerTensor} of shares
            corresponding to self. Equivalent to calling self.child.
        other_shares: a dictionary {location_id -> PointerTensor} of shares
            corresponding to the tensor being multiplied by self. Assumed to
            have the same keys as ``shares``.
        crypto_provider: an AbstractWorker which is used to generate triples.
        field: an integer denoting the size of the field. In this function it
            is only forwarded to ``crypto_provider.generate_triple``.
        **kwargs: accepted but not used by this function.

    Returns:
        A dict keyed by the same locations as ``shares``, where each value is
        that location's share of the result.
    """
    locations = list(shares.keys())
    shares_shape = shares[locations[0]].shape
    other_shape = other_shares[locations[0]].shape

    # Triple (a, b, c) from the crypto provider. Presumably c is a sharing of
    # cmd(a, b) (Beaver triple) -- TODO confirm against generate_triple.
    triple = crypto_provider.generate_triple(cmd, field, shares_shape, other_shape, locations)
    a, b, c = triple
    a, b, c = a.child, b.child, c.child

    # Each party masks its own shares: d = x - a, e = y - b (still shared).
    d = {}
    e = {}
    for location in locations:
        d[location] = shares[location] - a[location]
        e[location] = other_shares[location] - b[location]

    # Open the masked values locally by fetching and summing every share.
    delta = torch.zeros(shares_shape, dtype=torch.long)
    epsilon = torch.zeros(other_shape, dtype=torch.long)
    for location in locations:
        d_temp = d[location].get()
        e_temp = e[location].get()
        delta = delta + d_temp
        epsilon = epsilon + e_temp

    # Public term cmd(delta, epsilon), added once below.
    delta_epsilon = cmd(delta, epsilon)

    # Pointers are kept in dicts so they stay alive until the function returns.
    delta_ptrs = {}
    epsilon_ptrs = {}
    a_epsilon = {}
    delta_b = {}
    z = {}
    for location in locations:
        delta_ptrs[location] = delta.send(location)
        epsilon_ptrs[location] = epsilon.send(location)
        a_epsilon[location] = cmd(a[location], epsilon_ptrs[location])
        delta_b[location] = cmd(delta_ptrs[location], b[location])
        z[location] = a_epsilon[location] + delta_b[location] + c[location]

    # The public delta/epsilon product must be contributed by exactly one
    # party, so it is only added to the first location's share.
    delta_epsilon_pointer = delta_epsilon.send(locations[0])
    z[locations[0]] = z[locations[0]] + delta_epsilon_pointer
    return z
def spdz_mul(cmd: Callable, x_sh, y_sh, crypto_provider: AbstractWorker, field: int):
    """Abstractly multiplies two tensors (mul or matmul)

    Args:
        cmd: a callable of the equation to be computed (mul or matmul)
        x_sh (AdditiveSharingTensor): the left part of the operation
        y_sh (AdditiveSharingTensor): the right part of the operation
        crypto_provider (AbstractWorker): an AbstractWorker which is used to generate triples
        field (int): an integer denoting the size of the field

    Return:
        an AdditiveSharingTensor
    """
    assert isinstance(x_sh, sy.AdditiveSharingTensor)
    assert isinstance(y_sh, sy.AdditiveSharingTensor)

    locations = x_sh.locations

    # Get triples (a, b, a_mul_b) sized for the two operands.
    a, b, a_mul_b = crypto_provider.generate_triple(cmd, field, x_sh.shape, y_sh.shape, locations)

    # Mask both operands with the triple, then open the masked values.
    delta = x_sh - a
    epsilon = y_sh - b
    delta = delta.reconstruct()
    epsilon = epsilon.reconstruct()
    delta_epsilon = cmd(delta, epsilon)

    # Trick to keep only one child in the MultiPointerTensor (like in SNN):
    # j holds ones on locations[0] and zeros on every other location, so
    # delta_epsilon * j contributes the public term on exactly one worker.
    j1 = torch.ones(delta_epsilon.shape).long().send(locations[0], **no_wrap)
    j0 = torch.zeros(delta_epsilon.shape).long().send(*locations[1:], **no_wrap)
    if len(locations) == 2:
        # A single remaining destination: j0 is used directly as one child.
        j = sy.MultiPointerTensor(children=[j1, j0])
    else:
        # Several destinations: j0.child is presumably a dict
        # {location -> pointer}. Its values() view must be materialized as a
        # list before concatenation, otherwise list + dict_values raises
        # TypeError.
        j = sy.MultiPointerTensor(children=[j1] + list(j0.child.values()))

    delta_b = cmd(delta, b)
    a_epsilon = cmd(a, epsilon)

    return delta_epsilon * j + delta_b + a_epsilon + a_mul_b