def test_grad_mul_backward(get_clients) -> None:
    """Check that ``GradMul.backward`` applies the product rule to shared tensors.

    For ``z = x * y`` the partial derivatives are ``dz/dx = y`` and
    ``dz/dy = x``. The upstream gradient is therefore expected to be
    multiplied element-wise by the *other* operand. Results are compared
    after reconstruction with a relative tolerance of ``1e-3``.
    """
    clients = get_clients(4)

    upstream = torch.Tensor([[1, 2], [3, 4]])
    lhs = torch.Tensor([[1, 2], [3, -4]])
    rhs = torch.Tensor([[1, -4], [8, 9]])

    # Share the operands first (x, then y), then the upstream gradient.
    ctx = {name: tensor.share(parties=clients) for name, tensor in (("x", lhs), ("y", rhs))}
    upstream_mpc = upstream.share(parties=clients)

    grad_x_mpc, grad_y_mpc = GradMul.backward(ctx, upstream_mpc)

    # Each gradient pairs with the opposite operand times the upstream grad.
    expectations = (
        (grad_x_mpc, rhs * upstream),
        (grad_y_mpc, lhs * upstream),
    )
    for shared_result, wanted in expectations:
        assert np.allclose(shared_result.reconstruct(), wanted, rtol=1e-3)
def test_grad_mul_forward(get_clients) -> None:
    """Check that ``GradMul.forward`` multiplies shared tensors and fills ``ctx``.

    The forward pass must record both operands under the keys ``"x"`` and
    ``"y"``, which are the keys the backward pass reads. The reconstructed
    product must match the plain element-wise product within ``rtol=1e-3``.
    """
    clients = get_clients(4)

    lhs = torch.Tensor([[1, 2], [3, -4]])
    rhs = torch.Tensor([[1, -4], [8, 9]])
    lhs_mpc, rhs_mpc = [tensor.share(parties=clients) for tensor in (lhs, rhs)]

    ctx = {}
    product_mpc = GradMul.forward(ctx, lhs_mpc, rhs_mpc)

    # Both operands must be stashed in the context for the backward step.
    for key in ("x", "y"):
        assert key in ctx

    assert np.allclose(product_mpc.reconstruct(), lhs * rhs, rtol=1e-3)