コード例 #1
0
def test_max(get_clients) -> None:
    """Check that ``MPCTensor.max()`` over all elements matches ``torch.max``.

    Args:
        get_clients: factory called with a client count that returns that
            many clients (presumably a pytest fixture defined elsewhere —
            confirm in conftest).
    """
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    secret = torch.Tensor([1, 2, 3, -1, -3])
    x = MPCTensor(secret=secret, session=session)

    max_val = x.max()
    # Check the *result* of max(), not the input: ``x`` is an MPCTensor by
    # construction, so asserting on it could never catch a wrong return type.
    assert isinstance(max_val, MPCTensor), "Expected max to be MPCTensor"

    expected = secret.max()
    res = max_val.reconstruct()
    assert res == expected, f"Expected max to be {expected}"
コード例 #2
0
def test_max_dim(dim, keepdim, get_clients) -> None:
    """Check ``MPCTensor.max(dim, keepdim)`` against ``torch.max(dim, keepdim)``.

    Both the maximum values and their indices are reconstructed and compared
    with the plaintext result. The input below has no ties along any
    dimension, so the expected indices are unambiguous.

    Args:
        dim: dimension to reduce over (presumably supplied by a
            ``pytest.mark.parametrize`` decorator not visible here — confirm).
        keepdim: whether the reduced dimension is retained (likewise
            presumably parametrized).
        get_clients: factory called with a client count that returns that
            many clients (presumably a pytest fixture).
    """
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    secret = torch.Tensor([[[1, 2], [3, -1], [4, 5]], [[2, 5], [5, 1], [6, 42]]])
    x = MPCTensor(secret=secret, session=session)

    max_val, max_idx_val = x.max(dim=dim, keepdim=keepdim)
    # Check the returned tensors, not the input ``x`` (which is always an
    # MPCTensor by construction and so proves nothing).
    assert isinstance(max_val, MPCTensor), "Expected max to be MPCTensor"
    assert isinstance(max_idx_val, MPCTensor), "Expected argmax to be MPCTensor"

    res_idx = max_idx_val.reconstruct()
    res_max = max_val.reconstruct()
    expected_max, expected_indices = secret.max(dim=dim, keepdim=keepdim)

    # ``==`` broadcasts, so compare shapes first; otherwise a result with the
    # wrong ``keepdim`` layout could still compare equal element-wise.
    assert (
        res_idx.shape == expected_indices.shape
    ), f"Expected indices shape {expected_indices.shape}, got {res_idx.shape}"
    assert (
        res_max.shape == expected_max.shape
    ), f"Expected max shape {expected_max.shape}, got {res_max.shape}"

    assert (
        res_idx == expected_indices
    ).all(), f"Expected indices for maximum to be {expected_indices}"
    assert (res_max == expected_max).all(), f"Expected max to be {expected_max}"