def hardtanh(tensor: MPCTensor, min_value: float = -1, max_value: float = 1) -> MPCTensor:
    """Calculates hardtanh of given tensor.

    Defined as
        max_value  if x > max_value
        min_value  if x < min_value
        x          otherwise

    The clamp is expressed with two ReLUs so that only ``relu`` is needed:
    ``relu(x - min_value) - relu(x - max_value)`` equals ``0`` below the range,
    ``x - min_value`` inside it and ``max_value - min_value`` above it. Adding
    ``min_value`` back shifts that onto ``[min_value, max_value]``.

    Args:
        tensor (MPCTensor): whose hardtanh has to be calculated
        min_value (float): minimum value of the linear region range. Default: -1
        max_value (float): maximum value of the linear region range. Default: 1

    Returns:
        MPCTensor: calculated MPCTensor

    Raises:
        ValueError: if ``min_value`` is greater than ``max_value``. The ReLU
            identity above only yields a clamp for an ordered range; a reversed
            range would silently produce a non-monotone result.
    """
    if min_value > max_value:
        raise ValueError(
            f"min_value ({min_value}) must not be greater than max_value ({max_value})"
        )
    intermediate = relu(tensor - min_value) - relu(tensor - max_value)
    return intermediate + min_value
def test_relu(get_clients) -> None:
    """MPC ``relu`` must agree with ``torch.nn.functional.relu`` on mixed-sign input."""
    parties = get_clients(2)
    mpc_session = Session(parties=parties)
    SessionManager.setup_mpc(mpc_session)

    # Negative, zero and positive values, including non-integers.
    plain = torch.Tensor([-2, -1.5, 0, 1, 1.5, 2])
    shared = MPCTensor(secret=plain, session=mpc_session)

    actual = relu(shared)
    expected = torch.nn.functional.relu(plain)

    assert all(actual.reconstruct() == expected)