Ejemplo n.º 1
0
def test_grad_add_backward(get_clients) -> None:
    """Check GradAdd.backward when both operands share the gradient's shape.

    A four-element gradient is shared among four parties and passed to
    ``GradAdd.backward`` together with a context recording ``(4,)`` for both
    ``x_shape`` and ``y_shape``. Each of the two returned values must
    reconstruct to the original gradient.

    Args:
        get_clients: Callable invoked with the number of parties to create
            (presumably a pytest fixture -- confirm against conftest).
    """
    clients = get_clients(4)

    upstream = torch.Tensor([1, 2, 3, 4])
    upstream_shared = upstream.share(parties=clients)

    # Both operands are recorded with the same shape as the gradient.
    saved_ctx = {
        "x_shape": (4,),
        "y_shape": (4,),
    }
    x_grad_shared, y_grad_shared = GradAdd.backward(saved_ctx, upstream_shared)

    # x is checked first, then y, mirroring the order of the returned pair.
    for shared in (x_grad_shared, y_grad_shared):
        matches = shared.reconstruct() == upstream
        assert matches.all()
Ejemplo n.º 2
0
def test_grad_add_different_dims_backward(get_clients) -> None:
    """Check GradAdd.backward when the operands were recorded with different shapes.

    The gradient literal has a leading dimension of size one, while the
    context records ``(2, 3)`` for ``x_shape`` and ``(1, 3)`` for ``y_shape``.
    The x result is expected to equal the gradient itself; the y result is
    expected to equal ``[[7, 11, 15]]``, which is the element-wise sum of the
    gradient's two rows.

    Args:
        get_clients: Callable invoked with the number of parties to create
            (presumably a pytest fixture -- confirm against conftest).
    """
    clients = get_clients(4)

    upstream = torch.Tensor([[[2, 4, 6], [5, 7, 9]]])
    expected_x = upstream
    expected_y = torch.Tensor([[7, 11, 15]])
    upstream_shared = upstream.share(parties=clients)

    saved_ctx = {
        "x_shape": (2, 3),
        "y_shape": (1, 3),
    }
    x_grad_shared, y_grad_shared = GradAdd.backward(saved_ctx, upstream_shared)

    x_matches = x_grad_shared.reconstruct() == expected_x
    assert x_matches.all()

    y_matches = y_grad_shared.reconstruct() == expected_y
    assert y_matches.all()
Ejemplo n.º 3
0
def test_grad_add_forward(get_clients) -> None:
    """Check that GradAdd.forward on two shared tensors reconstructs to their sum.

    Two 2x3 tensors are shared among four parties and passed to
    ``GradAdd.forward`` with an initially empty context. The reconstructed
    result must equal the plain ``x + y`` computed on the original tensors.

    Args:
        get_clients: Callable invoked with the number of parties to create
            (presumably a pytest fixture -- confirm against conftest).
    """
    clients = get_clients(4)

    lhs = torch.Tensor([[1, 2, 3], [4, 5, 6]])
    rhs = torch.Tensor([[1, 4, 6], [8, 10, 12]])
    lhs_shared = lhs.share(parties=clients)
    rhs_shared = rhs.share(parties=clients)

    # The context starts out empty when handed to forward.
    saved_ctx = {}
    total_shared = GradAdd.forward(saved_ctx, lhs_shared, rhs_shared)
    total = total_shared.reconstruct()

    plain_total = lhs + rhs
    matches = total == plain_total
    assert matches.all()