def test_cat(get_clients):
    """Concatenating two MPCTensors should match torch.cat on the secrets."""
    parties = get_clients(2)
    secret_a = torch.Tensor([0.0, 1, -2, 3, -4])
    secret_b = torch.Tensor([-4, 3, -2, 1, 0.0])
    expected = torch.cat([secret_a, secret_b])

    session = Session(parties=parties)
    SessionManager.setup_mpc(session)

    shared_a = MPCTensor(secret=secret_a, session=session)
    shared_b = MPCTensor(secret=secret_b, session=session)
    result = cat([shared_a, shared_b])

    assert (expected == result.reconstruct()).all()
def stack(tensors: List, dim: int = 0) -> MPCTensor:
    """Stack a sequence of MPCTensors along a new dimension.

    Args:
        tensors (List): sequence of tensors to stack.
        dim (int): dimension to insert; has to be between 0 and the number
            of dimensions of the stacked tensors (inclusive).

    Returns:
        MPCTensor: the stacked MPCTensor.
    """
    session = tensors[0].session

    # Pair each party's uuid with that party's share of every input tensor.
    uuids = [str(uuid) for uuid in session.rank_to_uuid.values()]
    per_party_shares = zip(*[tensor.share_ptrs for tensor in tensors])
    args = [(uid, *shares) for uid, shares in zip(uuids, per_party_shares)]

    result_shares = parallel_execution(stack_share_tensor, session.parties)(args)

    from sympc.tensor import MPCTensor

    # Derive the public result shape by stacking empty placeholders.
    placeholders = [torch.empty(tensor.shape) for tensor in tensors]
    expected_shape = torch.stack(placeholders, dim=dim).shape

    return MPCTensor(shares=result_shares, session=session, shape=expected_shape)
def cat(tensors: List, dim: int = 0) -> MPCTensor:
    """Concatenate the given sequence of MPCTensors in the given dimension.

    Args:
        tensors (List): sequence of tensors to concatenate.
        dim (int): the dimension over which the tensors are concatenated.

    Returns:
        MPCTensor: the concatenated MPCTensor.
    """
    session = tensors[0].session

    # Pair each party's uuid with that party's share of every input tensor.
    uuids = [str(uuid) for uuid in session.rank_to_uuid.values()]
    per_party_shares = zip(*[tensor.share_ptrs for tensor in tensors])
    args = [(uid, *shares) for uid, shares in zip(uuids, per_party_shares)]

    result_shares = parallel_execution(cat_share_tensor, session.parties)(args)

    from sympc.tensor import MPCTensor

    # Derive the public result shape by concatenating empty placeholders.
    placeholders = [torch.empty(tensor.shape) for tensor in tensors]
    expected_shape = torch.cat(placeholders, dim=dim).shape

    return MPCTensor(shares=result_shares, session=session, shape=expected_shape)
def test_max_multiple_max(get_clients) -> None:
    """argmax should raise when the secret holds more than one maximum."""
    parties = get_clients(2)
    session = Session(parties=parties)
    SessionManager.setup_mpc(session)

    # The value 3 appears twice, so a single argmax index does not exist.
    shared = MPCTensor(secret=torch.Tensor([1, 2, 3, -1, 3]), session=session)

    with pytest.raises(ValueError):
        shared.argmax()
def test_exception_value_error(get_clients) -> None:
    """reciprocal should raise ValueError for an unsupported method name."""
    parties = get_clients(2)
    session = Session(parties=parties)
    SessionManager.setup_mpc(session)

    plain = torch.Tensor([-2.0, 6.0, 2.0, 3.0, -5.0, -0.5])
    shared = MPCTensor(secret=plain, session=session)

    with pytest.raises(ValueError):
        reciprocal(shared, method="exp")
def test_softmax_single_along_dim(get_clients) -> None:
    """softmax on a single-column tensor should match F.softmax."""
    parties = get_clients(2)
    plain = torch.arange(4, dtype=torch.float).view(4, 1)
    # Default-dim call kept to mirror the MPC call below.
    expected = F.softmax(plain)

    session = Session(parties=parties)
    SessionManager.setup_mpc(session)
    shared = MPCTensor(secret=plain, session=session)

    result = softmax(shared)
    assert torch.allclose(expected, result.reconstruct(), atol=1e-2)
def test_softmax(get_clients, dim) -> None:
    """softmax along each dimension should match F.softmax on the secret."""
    parties = get_clients(2)
    plain = torch.arange(-6, 6, dtype=torch.float).view(3, 4)
    expected = F.softmax(plain, dim=dim)

    session = Session(parties=parties)
    SessionManager.setup_mpc(session)
    shared = MPCTensor(secret=plain, session=session)

    result = softmax(shared, dim=dim)
    assert torch.allclose(expected, result.reconstruct(), atol=1e-2)
def test_exception_value_error(get_clients) -> None:
    """tanh should raise ValueError for an unsupported method name.

    Args:
        get_clients: fixture returning the requested number of party clients.
    """
    clients = get_clients(2)
    x_secret = torch.Tensor([0.0, 1, -2, 3, -4])

    session = Session(parties=clients)
    SessionManager.setup_mpc(session)
    x = MPCTensor(secret=x_secret, session=session)

    # Fix: removed a dead `torch.tanh(x_secret)` call whose result was
    # discarded -- it contributed nothing to the test.
    with pytest.raises(ValueError):
        tanh(x, method="exp")
def test_sigmoid(get_clients, method) -> None:
    """sigmoid should approximate torch.sigmoid for each approximation method."""
    parties = get_clients(2)
    plain = torch.Tensor([0.0, 1, -2, 3, -4])
    expected = torch.sigmoid(plain)

    session = Session(parties=parties)
    SessionManager.setup_mpc(session)
    shared = MPCTensor(secret=plain, session=session)

    result = sigmoid(shared, method)
    assert torch.allclose(expected, result.reconstruct(), atol=1e-1)
def test_max(get_clients) -> None:
    """max over a 1-D secret should reconstruct to torch's max value.

    Args:
        get_clients: fixture returning the requested number of party clients.
    """
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    secret = torch.Tensor([1, 2, 3, -1, -3])
    x = MPCTensor(secret=secret, session=session)

    max_val = x.max()
    # Fix: the original asserted isinstance(x, MPCTensor), which is trivially
    # true -- the value returned by max() is what must be an MPCTensor.
    assert isinstance(max_val, MPCTensor), "Expected max to be MPCTensor"

    expected = secret.max()
    res = max_val.reconstruct()
    assert res == expected, f"Expected max to be {expected}"
def test_log(get_clients) -> None:
    """log should approximate torch.log when using a custom encoder precision."""
    parties = get_clients(2)
    plain = torch.Tensor([0.1, 0.5, 2, 5, 10])
    expected = torch.log(plain)

    # Session configured with a custom (higher) encoder precision.
    config = Config(encoder_precision=20)
    session = Session(parties=parties, config=config)
    SessionManager.setup_mpc(session)

    shared = MPCTensor(secret=plain, session=session)
    result = log(shared)

    assert torch.allclose(expected, result.reconstruct(), atol=1e-1)
def test_argmax_dim(dim, keepdim, get_clients) -> None:
    """argmax along a dimension should match torch.argmax on the secret.

    Args:
        dim: dimension over which argmax is computed (parametrized).
        keepdim: whether the reduced dimension is kept (parametrized).
        get_clients: fixture returning the requested number of party clients.
    """
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    secret = torch.Tensor([[[1, 2], [3, -1], [4, 5]], [[2, 5], [5, 1], [6, 42]]])
    x = MPCTensor(secret=secret, session=session)

    argmax_val = x.argmax(dim=dim, keepdim=keepdim)
    # Fix: the original asserted isinstance(x, MPCTensor), which is trivially
    # true -- the value returned by argmax() is what must be checked.
    assert isinstance(argmax_val, MPCTensor), "Expected argmax to be MPCTensor"

    res = argmax_val.reconstruct()
    expected = secret.argmax(dim=dim, keepdim=keepdim).float()
    assert (res == expected).all(), f"Expected argmax to be {expected}"
def test_reciprocal(method, get_clients) -> None:
    """reciprocal should approximate torch.reciprocal for each method."""
    parties = get_clients(2)
    session = Session(parties=parties)
    SessionManager.setup_mpc(session)

    plain = torch.Tensor([-2.0, 6.0, 2.0, 3.0, -5.0, -0.5])
    shared = MPCTensor(secret=plain, session=session)

    expected = torch.reciprocal(plain)
    result = reciprocal(shared, method=method)

    assert torch.allclose(expected, result.reconstruct(), atol=1e-1)
def test_tanh(get_clients) -> None:
    """tanh via the sigmoid method should approximate torch.tanh."""
    parties = get_clients(2)
    plain = torch.Tensor([0.0, 1, -2, 3, -4])
    expected = torch.tanh(plain)

    session = Session(parties=parties)
    SessionManager.setup_mpc(session)
    shared = MPCTensor(secret=plain, session=session)

    result = tanh(shared, method="sigmoid")
    assert torch.allclose(expected, result.reconstruct(), atol=1e-2)

    # An unsupported method name must be rejected.
    with pytest.raises(ValueError):
        result = tanh(shared, method="exp")
def test_max_dim(dim, keepdim, get_clients) -> None:
    """max along a dimension should match torch.max values and indices.

    Args:
        dim: dimension over which max is computed (parametrized).
        keepdim: whether the reduced dimension is kept (parametrized).
        get_clients: fixture returning the requested number of party clients.
    """
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    secret = torch.Tensor([[[1, 2], [3, -1], [4, 5]], [[2, 5], [5, 1], [6, 42]]])
    x = MPCTensor(secret=secret, session=session)

    max_val, max_idx_val = x.max(dim=dim, keepdim=keepdim)
    # Fix: the original asserted isinstance(x, MPCTensor), which is trivially
    # true -- the tensors returned by max() are what must be MPCTensors.
    assert isinstance(max_val, MPCTensor), "Expected max values to be MPCTensor"
    assert isinstance(max_idx_val, MPCTensor), "Expected max indices to be MPCTensor"

    res_idx = max_idx_val.reconstruct()
    res_max = max_val.reconstruct()
    expected_max, expected_indices = secret.max(dim=dim, keepdim=keepdim)

    assert (
        res_idx == expected_indices
    ).all(), f"Expected indices for maximum to be {expected_indices}"
    assert (res_max == expected_max).all(), f"Expected max to be {expected_max}"
def backward(ctx: Dict[str, Any], grad: MPCTensor) -> Tuple[MPCTensor]:
    """Perform the backward pass for the conv2d operation.

    Args:
        ctx (Dict[str, Any]): Context used to retrieve the information for the backward pass
        grad (MPCTensor): The gradient that came from the child nodes

    Returns:
        (input_grad, weight_grad) (Tuple[MPCTensor]): The gradients passed to the
            input and kernel nodes.
    """
    # Unpack everything saved by the forward pass.
    input = ctx["input"]
    weight = ctx["weight"]
    stride = ctx["stride"]
    padding = ctx["padding"]
    dilation = ctx["dilation"]
    groups = ctx["groups"]

    # Kernel spatial size, taken directly from the weight tensor.
    weight_size = (weight.shape[2], weight.shape[3])
    in_channels = input.shape[1]
    out_channels = grad.shape[1]
    min_batch = input.shape[0]

    # Gradient w.r.t input of the Conv.
    # Each party computes the output padding the transposed convolution needs
    # to reproduce the exact input shape.
    common_args = [
        tuple(input.shape),
        stride,
        padding,
        weight_size,
        dilation,
        grad.session,
    ]
    args = [[el] + common_args for el in grad.share_ptrs]
    shares = parallel_execution(
        GradConv2d.get_grad_input_padding, grad.session.parties
    )(args)
    grad_input_padding = MPCTensor(shares=shares, session=grad.session)

    # NOTE(review): the reconstruction is divided by nr_parties -- presumably
    # every party returned the same public padding value as its "share", so the
    # sum is nr_parties times the value; confirm against get_grad_input_padding.
    output_padding_tensor = grad_input_padding.reconstruct()
    output_padding_tensor /= grad.session.nr_parties
    output_padding = tuple(output_padding_tensor.to(torch.int).tolist())

    input_grad = grad.conv_transpose2d(
        weight, None, stride, output_padding, dilation, groups
    )

    # Gradient w.r.t weights of the Conv.
    # Rearrange grad and input so the weight gradient can be expressed as a
    # grouped convolution of the (reshaped) input by the (reshaped) gradient.
    grad = grad.repeat(1, in_channels // groups, 1, 1)
    grad = grad.view(grad.shape[0] * grad.shape[1], 1, grad.shape[2], grad.shape[3])
    input = input.view(
        1, input.shape[0] * input.shape[1], input.shape[2], input.shape[3]
    )

    # NOTE(review): stride and dilation appear deliberately swapped here --
    # consistent with the standard trick that the weight gradient of a strided
    # conv is a dilated conv; confirm against torch.nn.grad.conv2d_weight.
    weight_grad = input.conv2d(
        weight=grad,
        bias=None,
        dilation=stride,
        padding=padding,
        stride=dilation,
        groups=in_channels * min_batch,
    )
    # Split the grouped-conv output back into per-batch chunks and sum them.
    weight_grad = weight_grad.view(
        min_batch,
        weight_grad.shape[1] // min_batch,
        weight_grad.shape[2],
        weight_grad.shape[3],
    )
    weight_grad = (
        weight_grad.sum(0)
        .view(
            in_channels // groups,
            out_channels,
            weight_grad.shape[2],
            weight_grad.shape[3],
        )
        .transpose(0, 1)
    )
    # Trim any extra rows/cols produced by the convolution above.
    # NOTE(review): weight_size[1] is used on dim 2 and weight_size[0] on dim 3;
    # this is fine for square kernels -- confirm intent for rectangular ones.
    weight_grad = weight_grad.narrow(2, 0, weight_size[1])
    weight_grad = weight_grad.narrow(3, 0, weight_size[0])

    return input_grad, weight_grad
def helper_argmax(
    x: MPCTensor,
    dim: Optional[Union[int, Tuple[int]]] = None,
    keepdim: bool = False,
    one_hot: bool = False,
) -> MPCTensor:
    """Compute argmax using pairwise comparisons.

    Makes the number of rounds fixed, here it is 2. This is inspired from CrypTen.

    Args:
        x (MPCTensor): the MPCTensor on which to compute helper_argmax on
        dim (Union[int, Tuple[int]]): compute argmax over a specific dimension(s)
        keepdim (bool): when one_hot is true, keep all the dimensions of the tensor
        one_hot (bool): return the argmax as a one hot vector

    Returns:
        Given the args, it returns a one hot encoding (as an MPCTensor)
        or the index of the maximum value

    Raises:
        ValueError: In case more max values are found and we need to return the index
    """
    # for each share in MPCTensor
    #   do the algorithm portrayed in paper (helper_argmax_pairwise)
    #   results in creating two matrices and subtracting them
    session = x.session

    # With no dim, argmax is computed over the flattened tensor.
    prep_x = x.flatten() if dim is None else x
    args = [
        [str(uuid), share_ptr_tensor, dim]
        for uuid, share_ptr_tensor in zip(
            session.rank_to_uuid.values(), prep_x.share_ptrs
        )
    ]
    shares = parallel_execution(helper_argmax_pairwise, session.parties)(args)

    # Shape of the pairwise-difference tensor, fetched from the first party.
    res_shape = shares[0].shape.get()
    x_pairwise = MPCTensor(shares=shares, session=x.session, shape=res_shape)

    # with the MPCTensor tensor we check what entries are positive
    # then we check what columns of M matrix have m-1 non-zero entries after comparison
    # (by summing over cols)
    pairwise_comparisons = x_pairwise >= 0

    # re-compute row_length
    _dim = -1 if dim is None else dim
    row_length = x.shape[_dim] if x.shape[_dim] > 1 else 2

    result = pairwise_comparisons.sum(0)
    result = result >= (row_length - 1)

    res_shape = res_shape[1:]  # Remove the leading dimension because of sum(0)

    if not one_hot:
        # Turn the one-hot marker tensor into an index by weighting each
        # position with its index and summing.
        if dim is None:
            check = result * torch.Tensor([i for i in range(np.prod(res_shape))])
        else:
            size = [1 for _ in range(len(res_shape))]
            size[dim] = res_shape[dim]
            check = result * torch.Tensor(
                [i for i in range(res_shape[_dim])]
            ).view(size)

        if dim is not None:
            argmax = check.sum(dim=dim, keepdim=keepdim)
        else:
            argmax = check.sum()
            if (argmax >= row_length).reconstruct():
                # In case we have 2 max values, rather then returning an invalid index
                # we raise an exception
                raise ValueError("There are multiple argmax values")

        result = argmax

    return result