Esempio n. 1
0
def test_max_multiple_max(get_clients) -> None:
    """Check that argmax rejects a secret whose maximum value is not unique.

    The secret ``[1, 2, 3, -1, 3]`` holds its maximum (3) at two positions.
    The test expects ``MPCTensor.argmax`` to raise ``ValueError`` for it.
    """
    session = Session(parties=get_clients(2))
    SessionManager.setup_mpc(session)

    duplicated_max = torch.Tensor([1, 2, 3, -1, 3])
    mpc_tensor = MPCTensor(secret=duplicated_max, session=session)

    with pytest.raises(ValueError):
        mpc_tensor.argmax()
Esempio n. 2
0
def test_argmax(get_clients) -> None:
    """Check that argmax over a flat secret matches plaintext ``torch.argmax``.

    Two parties share ``[1, 2, 3, -1, -3]``. The reconstructed argmax must
    equal ``secret.argmax()`` cast to float.
    """
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    secret = torch.Tensor([1, 2, 3, -1, -3])
    x = MPCTensor(secret=secret, session=session)

    argmax_val = x.argmax()
    # Check the type of the *result*; ``x`` was constructed as an MPCTensor
    # above, so asserting on it could never fail.
    assert isinstance(argmax_val, MPCTensor), "Expected argmax to be MPCTensor"

    # The reconstructed value is compared as float, so cast the plaintext index.
    expected = secret.argmax().float()
    res = argmax_val.reconstruct()
    assert (res == expected).all(), f"Expected argmax to be {expected}"
Esempio n. 3
0
def test_argmax_dim(dim, keepdim, get_clients) -> None:
    """Check that argmax along a dimension matches plaintext ``torch.argmax``.

    ``dim`` and ``keepdim`` are forwarded unchanged to both ``MPCTensor.argmax``
    and ``torch.Tensor.argmax``. They are presumably supplied by pytest
    parametrization outside this view (TODO confirm). The 2x3x2 secret has no
    tied maximum along any axis, so the expected indices are unambiguous.
    """
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    secret = torch.Tensor([[[1, 2], [3, -1], [4, 5]], [[2, 5], [5, 1], [6, 42]]])
    x = MPCTensor(secret=secret, session=session)

    argmax_val = x.argmax(dim=dim, keepdim=keepdim)
    # Check the type of the *result*; ``x`` was constructed as an MPCTensor
    # above, so asserting on it could never fail.
    assert isinstance(argmax_val, MPCTensor), "Expected argmax to be MPCTensor"

    res = argmax_val.reconstruct()
    # The reconstructed values are compared as floats, so cast the plaintext indices.
    expected = secret.argmax(dim=dim, keepdim=keepdim).float()
    assert (res == expected).all(), f"Expected argmax to be {expected}"