예제 #1
0
def hardtanh(tensor: MPCTensor,
             min_value: float = -1,
             max_value: float = 1) -> MPCTensor:
    """Calculates hardtanh of given tensor.

    Defined elementwise as
        max_value  if x > max_value
        min_value  if x < min_value
        x          otherwise

    With the default bounds this clamps the tensor to [-1, 1].

    Args:
        tensor (MPCTensor): whose hardtanh has to be calculated
        min_value (float): minimum value of the linear region range. Default: -1
        max_value (float): maximum value of the linear region range. Default: 1

    Returns:
        MPCTensor: calculated MPCTensor

    Raises:
        ValueError: if ``min_value`` is greater than ``max_value``.
    """
    # The two-ReLU identity below only yields a clamp when min <= max;
    # with inverted bounds it silently returns a meaningless value.
    if min_value > max_value:
        raise ValueError(
            f"min_value ({min_value}) must not be greater than "
            f"max_value ({max_value})"
        )

    # Clamp expressed with two ReLUs so only `relu` and scalar arithmetic
    # on the tensor are needed:
    #   x < min:          0         - 0         + min = min
    #   min <= x <= max:  (x - min) - 0         + min = x
    #   x > max:          (x - min) - (x - max) + min = max
    intermediate = relu(tensor - min_value) - relu(tensor - max_value)
    return intermediate + min_value
예제 #2
0
def test_relu(get_clients) -> None:
    """Check that MPC `relu` matches PyTorch's plaintext relu on a small vector."""
    # Build a two-party session and register it before creating shared tensors.
    parties = get_clients(2)
    mpc_session = Session(parties=parties)
    SessionManager.setup_mpc(mpc_session)

    # Inputs cover negative, zero, positive and fractional values.
    plaintext = torch.Tensor([-2, -1.5, 0, 1, 1.5, 2])
    shared = MPCTensor(secret=plaintext, session=mpc_session)

    # Plaintext reference computed with PyTorch.
    expected = torch.nn.functional.relu(plaintext)

    # Evaluate relu on the MPCTensor, then reconstruct the result.
    shared_result = relu(shared)
    actual = shared_result.reconstruct()

    # Every element must match the reference exactly.
    assert all(actual == expected)