def test_grad_add_backward(get_clients) -> None:
    """Backward of addition with equal input shapes forwards the gradient as-is.

    Both returned gradients must reconstruct to the incoming gradient.
    """
    clients = get_clients(4)

    upstream = torch.Tensor([1, 2, 3, 4])
    upstream_shared = upstream.share(parties=clients)

    # Both operands had the same shape, so no broadcast reduction is expected.
    context = {"x_shape": (4,), "y_shape": (4,)}

    grad_wrt_x, grad_wrt_y = GradAdd.backward(context, upstream_shared)

    assert (grad_wrt_x.reconstruct() == upstream).all()
    assert (grad_wrt_y.reconstruct() == upstream).all()
def test_grad_add_different_dims_backward(get_clients) -> None:
    """Backward of addition must reduce the gradient back to each input's shape.

    The incoming gradient has shape (1, 2, 3): it carries an extra leading
    dimension relative to both recorded input shapes. The expected results are:

    * x gradient (x_shape (2, 3)): the leading size-1 dimension is removed and
      the values are unchanged.
    * y gradient (y_shape (1, 3)): the gradient is also summed over the rows
      that were broadcast, giving [[2 + 5, 4 + 7, 6 + 9]] = [[7, 11, 15]].
    """
    parties = get_clients(4)

    grad = torch.Tensor([[[2, 4, 6], [5, 7, 9]]])  # shape (1, 2, 3)

    # Spell the expected x gradient out with x's own shape (2, 3). Previously
    # this was `grad` itself (shape (1, 2, 3)), and the assertion only passed
    # because `==` silently broadcast the two shapes.
    grad_x = torch.Tensor([[2, 4, 6], [5, 7, 9]])  # shape (2, 3)
    grad_y = torch.Tensor([[7, 11, 15]])  # shape (1, 3): column sums of grad_x

    grad_mpc = grad.share(parties=parties)
    ctx = {"x_shape": (2, 3), "y_shape": (1, 3)}

    res_mpc_x, res_mpc_y = GradAdd.backward(ctx, grad_mpc)

    assert (res_mpc_x.reconstruct() == grad_x).all()
    assert (res_mpc_y.reconstruct() == grad_y).all()
def test_grad_add_forward(get_clients) -> None:
    """Forward of GradAdd on shared tensors reconstructs to the plaintext sum."""
    clients = get_clients(4)

    lhs = torch.Tensor([[1, 2, 3], [4, 5, 6]])
    rhs = torch.Tensor([[1, 4, 6], [8, 10, 12]])
    plain_sum = lhs + rhs

    lhs_shared = lhs.share(parties=clients)
    rhs_shared = rhs.share(parties=clients)

    # forward receives a fresh, empty context dict.
    context = {}
    sum_shared = GradAdd.forward(context, lhs_shared, rhs_shared)

    assert (sum_shared.reconstruct() == plain_sum).all()