def test_invalid_mpc_pointer(get_clients) -> None:
    """Expect ABY3.truncate to raise ValueError for a sharetensor-backed MPCTensor."""
    clients = get_clients(3)
    plain_session = Session(parties=clients)
    SessionManager.setup_mpc(plain_session)

    # Session is built without an explicit protocol; per the test's intent the
    # secret is held through sharetensor pointers, which truncation rejects.
    shared_one = MPCTensor(secret=1, session=plain_session)

    with pytest.raises(ValueError):
        ABY3.truncate(shared_one, plain_session, 2**32, None)
def mul_master(
    x: MPCTensor,
    y: MPCTensor,
    session: Session,
    op_str: str,
    kwargs_: Dict[Any, Any],
) -> MPCTensor:
    """Multiply two secrets with Falcon and truncate the product.

    The ring size and config are fetched from the first share pointer of
    ``x``. The multiplication routine is then selected from the session's
    security type, and the raw product is truncated with ABY3.

    Args:
        x (MPCTensor): Left secret operand.
        y (MPCTensor): Right secret operand.
        session (Session): Session both tensors belong to.
        op_str (str): Name of the operation to perform.
        kwargs_ (Dict[Any, Any]): Extra keyword arguments forwarded to the
            multiplication routine (e.g. for conv2d).

    Returns:
        result (MPCTensor): Truncated result of the operation.

    Raises:
        ValueError: If the session does not have exactly three parties.
        ValueError: If the session's security_type is neither
            "semi-honest" nor "malicious".
    """
    if len(session.parties) != 3:
        raise ValueError("Falcon requires 3 parties")

    # Ring size and encoder config are read from x's first share.
    lead_share = x.share_ptrs[0]
    ring_size = int(lead_share.get_ring_size().get_copy())
    config = Config(**lead_share.get_config().get_copy())

    security = session.protocol.security_type
    if security == "semi-honest":
        multiply = Falcon.mul_semi_honest
    elif security == "malicious":
        multiply = Falcon.mul_malicious
    else:
        raise ValueError("Invalid security_type for Falcon multiplication")

    product = multiply(x, y, session, op_str, ring_size, config, **kwargs_)
    return ABY3.truncate(product, session, ring_size, config)
def test_truncation_algorithm1(get_clients, base, precision) -> None:
    """Compare ABY3.truncate with floor-dividing the encoded values by the scale."""
    clients = get_clients(3)
    session = Session(
        parties=clients,
        protocol=Falcon("semi-honest"),
        config=Config(encoder_base=base, encoder_precision=precision),
    )
    SessionManager.setup_mpc(session)

    plaintext = torch.tensor([[1.24, 4.51, 6.87], [7.87, 1301, 541]])
    shared = MPCTensor(secret=plaintext, session=session)

    truncated = ABY3.truncate(shared, session, session.ring_size, session.config)

    encoder = FixedPointEncoder(
        base=session.config.encoder_base,
        precision=session.config.encoder_precision,
    )
    # Reference value: take the still-encoded reconstruction, floor-divide it
    # by the encoder scale, then decode it.
    raw_encoded = shared.reconstruct(decode=False)
    expected = encoder.decode(raw_encoded // encoder.scale)

    assert np.allclose(truncated.reconstruct(), expected, atol=1e-3)
def truncate(self, input_tensor: "MPCTensor", op_str: str, is_private: bool) -> "MPCTensor":
    """Truncate the result of an operation when the operation requires it.

    Only operations listed in TRUNCATED_OPS are considered. The strategy
    depends on the session protocol's share class:

    * ShareTensor: divide by ``encoder_base ** encoder_precision`` unless the
      op is private with at most two parties.
    * ReplicatedSharedTensor: use ABY3.truncate for public ops only, with the
      ring size and config taken from ``self.share_ptrs[0]``.

    In every other case the input is returned unchanged.

    Args:
        input_tensor (MPCTensor): Result of the operation.
        op_str (str): Operation name.
        is_private (bool): Whether the operation is private.

    Returns:
        result (MPCTensor): Truncated result, or ``input_tensor`` itself when
            no truncation applies.
    """
    from sympc.protocol import ABY3
    from sympc.tensor import ReplicatedSharedTensor

    if op_str not in TRUNCATED_OPS:
        return input_tensor

    share_class = self.session.protocol.share_class

    if share_class == ShareTensor and (not is_private or self.session.nr_parties > 2):
        # Per the original note, private ops are divided in spdz's
        # mul_parties; with more than two parties the division happens here.
        encoder_cfg = self.session.config
        scale = encoder_cfg.encoder_base**encoder_cfg.encoder_precision
        return input_tensor.truediv(scale)

    if share_class == ReplicatedSharedTensor and not is_private:
        # NOTE(review): ring size/config come from self's shares, not from
        # input_tensor's; presumably they always match.
        lead_share = self.share_ptrs[0]
        ring_size = int(lead_share.get_ring_size().get_copy())
        config = Config(**lead_share.get_config().get_copy())
        return ABY3.truncate(input_tensor, self.session, ring_size, config)

    # Private replicated ops are presumably truncated by the protocol itself
    # (e.g. Falcon's mul_master applies ABY3.truncate to its result).
    return input_tensor
def test_invalid_parties_trunc(get_clients) -> None:
    """Expect ABY3.truncate to raise ValueError for a two-party session."""
    two_party_session = Session(parties=get_clients(2))

    with pytest.raises(ValueError):
        ABY3.truncate(None, two_party_session, 2**32, None)