示例#1
0
def test_grad_transpose_forward(get_clients) -> None:
    """Check that ``GradT.forward`` yields the transpose of a shared tensor.

    A 2x3 plaintext tensor is secret-shared among the parties returned by
    ``get_clients(4)``. It is passed through ``GradT.forward`` with an empty
    dict as the context. The reconstructed result must match the plaintext
    transpose element-wise.
    """
    plaintext = torch.Tensor([[1, 2, 3], [4, 5, 6]])
    shared = plaintext.share(parties=get_clients(4))

    # The forward pass only receives a fresh, empty context object here.
    output = GradT.forward({}, shared).reconstruct()

    # A 2x3 vs 3x2 comparison cannot broadcast, so a missing transpose
    # fails loudly instead of passing silently.
    assert (output == plaintext.t()).all()
示例#2
0
def test_grad_transpose_backward(get_clients) -> None:
    """Check that ``GradT.backward`` maps a transposed gradient back.

    The gradient is shared in its transposed (3x2) layout among the parties
    returned by ``get_clients(4)``. After the backward pass with an empty
    dict as the context, the reconstructed result must equal the original
    2x3 gradient.
    """
    clients = get_clients(4)
    grad_plain = torch.Tensor([[1, 2, 3], [4, 5, 6]])
    shared_grad = grad_plain.t().share(parties=clients)

    # The backward pass only receives a fresh, empty context object here.
    reconstructed = GradT.backward({}, shared_grad).reconstruct()

    assert (reconstructed == grad_plain).all()