def test_max(get_clients) -> None:
    """Check that ``MPCTensor.max()`` without ``dim`` matches ``torch.Tensor.max()``.

    A 1-D secret is shared between two parties, the global maximum is
    computed on the shared tensor, and the reconstructed value is compared
    with the plaintext maximum.

    Args:
        get_clients: Callable (presumably a pytest fixture) that is invoked
            with the number of parties and returns the clients used to
            build the session.
    """
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    secret = torch.Tensor([1, 2, 3, -1, -3])
    x = MPCTensor(secret=secret, session=session)

    max_val = x.max()
    # Check the *result* of max(); checking the input ``x`` would be
    # vacuously true because it was just constructed as an MPCTensor.
    assert isinstance(max_val, MPCTensor), "Expected max to be MPCTensor"

    expected = secret.max()
    res = max_val.reconstruct()
    assert res == expected, f"Expected max to be {expected}"
def test_max_dim(dim, keepdim, get_clients) -> None:
    """Check that ``MPCTensor.max(dim, keepdim)`` matches ``torch.Tensor.max``.

    A 3-D secret is shared between two parties, the maximum along ``dim`` is
    computed on the shared tensor, and both the reconstructed values and the
    reconstructed indices are compared with the plaintext results.

    Args:
        dim: Dimension along which the maximum is taken; forwarded to both
            the MPC and the plaintext ``max`` calls.
        keepdim: Whether the reduced dimension is retained; forwarded to both
            the MPC and the plaintext ``max`` calls.
        get_clients: Callable (presumably a pytest fixture) that is invoked
            with the number of parties and returns the clients used to
            build the session.

    NOTE(review): ``dim`` and ``keepdim`` are presumably supplied through
    pytest parametrization that is not visible in this block -- confirm.
    """
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    secret = torch.Tensor(
        [
            [[1, 2], [3, -1], [4, 5]],
            [[2, 5], [5, 1], [6, 42]],
        ]
    )
    x = MPCTensor(secret=secret, session=session)

    max_val, max_idx_val = x.max(dim=dim, keepdim=keepdim)
    # Check the *results* of max(); checking the input ``x`` would be
    # vacuously true because it was just constructed as an MPCTensor.
    assert isinstance(max_val, MPCTensor), "Expected max to be MPCTensor"
    assert isinstance(
        max_idx_val, MPCTensor
    ), "Expected indices for maximum to be MPCTensor"

    res_idx = max_idx_val.reconstruct()
    res_max = max_val.reconstruct()
    expected_max, expected_indices = secret.max(dim=dim, keepdim=keepdim)

    assert (
        res_idx == expected_indices
    ).all(), f"Expected indices for maximum to be {expected_indices}"
    assert (res_max == expected_max).all(), f"Expected max to be {expected_max}"