Beispiel #1
0
def test_invalid_mpc_pointer(get_clients) -> None:
    """ABY3.truncate is expected to reject an MPCTensor with ShareTensor shares.

    The session is created without an explicit protocol, and the MPCTensor
    built on it is handed straight to ``ABY3.truncate``; a ValueError is
    expected.
    """
    session = Session(parties=get_clients(3))
    SessionManager.setup_mpc(session)

    # The shares behind this tensor are ShareTensor pointers.
    share_backed_tensor = MPCTensor(secret=1, session=session)

    with pytest.raises(ValueError):
        ABY3.truncate(share_backed_tensor, session, 2**32, None)
Beispiel #2
0
    def mul_master(
        x: MPCTensor,
        y: MPCTensor,
        session: Session,
        op_str: str,
        kwargs_: Dict[Any, Any],
    ) -> MPCTensor:
        """Multiply two secrets with Falcon, then truncate the product with ABY3.

        The ring size and configuration are read from the first share of ``x``.
        The multiplication routine is chosen by the session protocol's
        ``security_type``.

        Args:
            x (MPCTensor): First secret operand.
            y (MPCTensor): Second secret operand.
            session (Session): Session both tensors belong to.
            op_str (str): Name of the operation to perform.
            kwargs_ (Dict[Any, Any]): Extra keyword arguments forwarded to the
                multiplication routine (needed by operations like conv2d).

        Returns:
            result (MPCTensor): Truncated result of the operation.

        Raises:
            ValueError: If the session does not have exactly three parties.
            ValueError: If ``security_type`` is neither "semi-honest" nor
                "malicious".
        """
        if len(session.parties) != 3:
            raise ValueError("Falcon requires 3 parties")

        # Ring size and encoder config come from the first share of x.
        first_share = x.share_ptrs[0]
        ring_size = int(first_share.get_ring_size().get_copy())
        config = Config(**first_share.get_config().get_copy())

        security_type = session.protocol.security_type
        if security_type == "semi-honest":
            mul_routine = Falcon.mul_semi_honest
        elif security_type == "malicious":
            mul_routine = Falcon.mul_malicious
        else:
            raise ValueError("Invalid security_type for Falcon multiplication")

        product = mul_routine(x, y, session, op_str, ring_size, config,
                              **kwargs_)

        return ABY3.truncate(product, session, ring_size, config)
Beispiel #3
0
def test_truncation_algorithm1(get_clients, base, precision) -> None:
    """ABY3.truncate should match a plaintext division by the encoder scale.

    A Falcon semi-honest session is configured with the given encoder base and
    precision. The truncated MPC result is compared against the reconstructed
    raw (undecoded) value floor-divided by the fixed-point scale, then decoded.
    """
    session = Session(
        parties=get_clients(3),
        protocol=Falcon("semi-honest"),
        config=Config(encoder_base=base, encoder_precision=precision),
    )
    SessionManager.setup_mpc(session)

    plaintext = torch.tensor([[1.24, 4.51, 6.87], [7.87, 1301, 541]])
    secret = MPCTensor(secret=plaintext, session=session)

    truncated = ABY3.truncate(secret, session, session.ring_size,
                              session.config)

    # Reference value: raw encoded integers divided by the scale in the clear,
    # decoded with an encoder that uses the session's base and precision.
    encoder = FixedPointEncoder(base=session.config.encoder_base,
                                precision=session.config.encoder_precision)
    reference = encoder.decode(
        secret.reconstruct(decode=False) // encoder.scale)

    assert np.allclose(truncated.reconstruct(), reference, atol=1e-3)
Beispiel #4
0
    def truncate(self, input_tensor: "MPCTensor", op_str: str,
                 is_private: bool) -> "MPCTensor":
        """Truncate an operation's result when that operation requires it.

        Only operations listed in ``TRUNCATED_OPS`` are considered. The
        truncation path is selected by the session protocol's share class:

        * ``ShareTensor``: the result is divided by the fixed-point scale
          (``encoder_base ** encoder_precision``). This applies to public
          operations, and to private ones when the session has more than two
          parties.
        * ``ReplicatedSharedTensor``: public operations are truncated with
          ``ABY3.truncate``, using the ring size and config read from this
          tensor's first share.

        Every other case is returned unchanged.

        Args:
            input_tensor (MPCTensor): Result of the operation.
            op_str (str): Name of the operation.
            is_private (bool): Whether the operation was private.

        Returns:
            result (MPCTensor): The truncated result, or ``input_tensor`` itself
            when no truncation applies.
        """
        from sympc.protocol import ABY3
        from sympc.tensor import ReplicatedSharedTensor

        if op_str not in TRUNCATED_OPS:
            return input_tensor

        session = self.session
        share_class = session.protocol.share_class

        if share_class == ShareTensor and (not is_private
                                           or session.nr_parties > 2):
            # For a private op with two parties, the division happens in the
            # mul_parties function from spdz, so it is skipped here.
            scale = session.config.encoder_base**session.config.encoder_precision
            return input_tensor.truediv(scale)

        if share_class == ReplicatedSharedTensor and not is_private:
            first_share = self.share_ptrs[0]
            ring_size = int(first_share.get_ring_size().get_copy())
            config = Config(**first_share.get_config().get_copy())
            return ABY3.truncate(input_tensor, session, ring_size, config)

        # NOTE(review): private replicated ops are presumably truncated by the
        # protocol's own multiplication path - confirm.
        return input_tensor
Beispiel #5
0
def test_invalid_parties_trunc(get_clients) -> None:
    """ABY3.truncate is expected to raise ValueError on a two-party session."""
    two_party_session = Session(parties=get_clients(2))

    # No tensor is passed: the call should fail on the party count alone.
    with pytest.raises(ValueError):
        ABY3.truncate(None, two_party_session, 2**32, None)