def test_rst_reconstruct_zero_share_ptrs(get_clients, security) -> None:
    parties = get_clients(3)
    protocol = Falcon(security)
    session = Session(protocol=protocol, parties=parties)
    SessionManager.setup_mpc(session)

    secret = torch.Tensor([[1, -2.0, 0.0], [3.9, -4.394, -0.9],
                           [-43, 100, -0.4343], [1.344, -5.0, 0.55]])

    a = MPCTensor(secret=secret, session=session)
    a.share_ptrs = []
    with pytest.raises(ValueError):
        a.reconstruct()
Example #2
def test_invalid_malicious_reconstruction(get_clients, parties):
    parties = get_clients(parties)
    protocol = Falcon("malicious")
    session = Session(protocol=protocol, parties=parties)
    SessionManager.setup_mpc(session)

    secret = 42.32

    tensor = MPCTensor(secret=secret, session=session)
    tensor.share_ptrs[0][0] = tensor.share_ptrs[0][0] + 4

    with pytest.raises(ValueError):
        tensor.reconstruct()
Example #3
def test_invalid_malicious_reconstruction(get_clients, parties):
    parties = get_clients(parties)
    protocol = Falcon("malicious")
    session = Session(protocol=protocol, parties=parties)
    SessionManager.setup_mpc(session)

    secret = torch.Tensor([[1, -2.0, 0.0], [3.9, -4.394, -0.9],
                           [-43, 100, -0.4343], [1.344, -5.0, 0.55]])

    tensor = MPCTensor(secret=secret, session=session)
    tensor.share_ptrs[0][0] = tensor.share_ptrs[0][0] + 4

    with pytest.raises(ValueError):
        tensor.reconstruct()
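Both malicious-reconstruction tests above rely on the redundancy of replicated secret sharing: every share is held by more than one party, so a reconstruction under the `malicious` setting can cross-check the copies and abort on a mismatch. A minimal plaintext sketch of that idea (hypothetical values, not the library's actual internals):

# Plaintext intuition only: in 3-party replicated sharing each share is held
# by two parties, so reconstruction can compare the redundant copies.
party_views = [
    {"sh1": 10, "sh2": 20},       # party 0 holds sh1 and sh2
    {"sh2": 20, "sh3": 12},       # party 1 holds sh2 and sh3
    {"sh3": 12, "sh1": 10 + 4},   # party 2's copy of sh1 was tampered with
]
if party_views[0]["sh1"] != party_views[2]["sh1"]:
    raise ValueError("Shares are inconsistent; aborting reconstruction")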
Example #4
def test_remote_not_tensor(get_clients) -> None:
    alice_client, bob_client = get_clients(2)
    session = Session(parties=[alice_client, bob_client])
    SessionManager.setup_mpc(session)

    x_remote_int = bob_client.python.Int(5)
    x = MPCTensor(secret=x_remote_int, shape=(1, ), session=session)
    result = x.reconstruct()

    assert x_remote_int == result

    x_remote_int = bob_client.python.Float(5.4)
    x = MPCTensor(secret=x_remote_int, shape=(1, ), session=session)
    result = x.reconstruct()

    assert np.allclose(x_remote_int.get(), result, atol=1e-5)
Example #5
def test_local_secret_not_tensor(get_clients) -> None:
    alice_client, bob_client = get_clients(2)
    session = Session(parties=[alice_client, bob_client])
    SessionManager.setup_mpc(session)

    x_int = 5
    x = MPCTensor(secret=x_int, session=session)
    result = x.reconstruct()

    assert x_int == result

    x_float = 5.987
    x = MPCTensor(secret=x_float, session=session)
    result = x.reconstruct()

    assert np.allclose(torch.tensor(x_float), result)
Example #6
def test_get_grad_input_padding(get_clients, common_args: List,
                                nr_parties) -> None:
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    grad = torch.Tensor([[[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0,
                                                              1.0]]]])
    grad_mpc = MPCTensor(secret=grad, session=session)

    input_size, stride, padding, kernel_size, dilation = common_args

    expected_padding = torch.nn.functional.grad._grad_input_padding(
        grad,
        input_size,
        (stride, stride),
        (padding, padding),
        kernel_size,
        (dilation, dilation),
    )

    args = [[el] + common_args + [session] for el in grad_mpc.share_ptrs]
    shares = parallel_execution(GradConv2d.get_grad_input_padding,
                                grad_mpc.session.parties)(args)
    grad_input_padding = MPCTensor(shares=shares, session=grad_mpc.session)
    output_padding_tensor = grad_input_padding.reconstruct()
    # Every party computes the same padding, so the reconstructed sum of the
    # shares is nr_parties times the padding value.
    output_padding_tensor /= grad_mpc.session.nr_parties
    calculated_padding = tuple(output_padding_tensor.to(torch.int).tolist())

    assert calculated_padding == expected_padding
Example #7
def test_rst_distribute_reconstruct_float_secret(get_clients, parties,
                                                 security) -> None:
    parties = get_clients(parties)
    protocol = Falcon(security)
    session = Session(protocol=protocol, parties=parties)
    SessionManager.setup_mpc(session)

    secret = 43.2
    a = MPCTensor(secret=secret, session=session)
    assert np.allclose(secret, a.reconstruct(), atol=1e-3)
Example #8
def test_remote_mpc_with_shape(get_clients) -> None:
    alice_client, bob_client = get_clients(2)
    session = Session(parties=[alice_client, bob_client])
    SessionManager.setup_mpc(session)

    x_remote = alice_client.torch.Tensor([1, -2, 0.3])
    x = MPCTensor(secret=x_remote, shape=(1, 3), session=session)
    result = x.reconstruct()

    assert np.allclose(x_remote.get(), result, atol=1e-5)
Example #9
def test_rst_distribute_reconstruct_tensor_secret(get_clients, parties,
                                                  security) -> None:
    parties = get_clients(parties)
    protocol = Falcon(security)
    session = Session(protocol=protocol, parties=parties)
    SessionManager.setup_mpc(session)

    secret = torch.Tensor([[1, -2.0, 0.0], [3.9, -4.394, -0.9],
                           [-43, 100, -0.4343], [1.344, -5.0, 0.55]])

    a = MPCTensor(secret=secret, session=session)
    assert np.allclose(secret, a.reconstruct(), atol=1e-3)
Example #10
def fss_op(x1: MPCTensor, x2: MPCTensor, op="eq") -> MPCTensor:
    """Define the workflow for a binary operation using Function Secret Sharing.

    Currently supported operations are == and <=, corresponding to
    op='eq' and op='comp' respectively.

    Args:
        x1 (MPCTensor): First private value.
        x2 (MPCTensor): Second private value.
        op: Type of operation to perform, should be 'eq' or 'comp'. Defaults to eq.

    Returns:
        MPCTensor: Shares of the comparison.
    """
    assert not th.cuda.is_available()  # nosec

    # FIXME: Better handle the case where x1 or x2 is not a MPCTensor. For the moment
    # FIXME: we cast it into a MPCTensor at the expense of extra communication
    session = x1.session
    dtype = session.tensor_type

    shape = MPCTensor._get_shape("sub", x1.shape, x2.shape)
    n_values = shape.numel()

    CryptoPrimitiveProvider.generate_primitives(
        f"fss_{op}",
        sessions=session.session_ptrs,
        g_kwargs={"n_values": n_values},
        p_kwargs={},
    )

    args = zip(session.session_ptrs, x1.share_ptrs, x2.share_ptrs)
    args = [list(el) + [op] for el in args]

    shares = parallel_execution(mask_builder, session.parties)(args)

    # TODO: don't do .reconstruct(), this should be done remotely between the evaluators
    mask_value = MPCTensor(shares=shares, session=session)
    # Note: "n" is a module-level constant (share bit length) in the original source.
    mask_value = mask_value.reconstruct(decode=False) % 2**n

    # TODO: add dtype to args
    args = [(session.session_ptrs[i], th.IntTensor([i]), mask_value, op)
            for i in range(2)]

    shares = parallel_execution(evaluate, session.parties)(args)

    response = MPCTensor(session=session, shares=shares, shape=shape)
    response.shape = shape
    return response
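A minimal usage sketch for fss_op, assuming the get_clients fixture and the Session/MPCTensor API used in the tests above (names and expected output are illustrative):

# Hypothetical usage of fss_op (a sketch, not a test from the library).
clients = get_clients(2)
session = Session(parties=clients)
SessionManager.setup_mpc(session)

x = MPCTensor(secret=torch.Tensor([1, 2, 3]), session=session)
y = MPCTensor(secret=torch.Tensor([1, 5, 3]), session=session)

eq = fss_op(x, y, op="eq")   # shares of the elementwise comparison x == y
print(eq.reconstruct())      # expected roughly [1, 0, 1]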
Example #11
def test_reconstruct(get_clients) -> None:
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    a_rand = 3
    a = ShareTensor(data=a_rand, encoder_precision=0)
    a_shares = MPCTensor.generate_shares(a, 2, torch.long)

    a_shares_copy = MPCTensor.generate_shares(a_rand, 2, torch.long)

    x_secret = torch.Tensor([1, -2, 3.0907, -4.870])
    x = MPCTensor(secret=x_secret, session=session)
    x = x.reconstruct()

    assert torch.allclose(x_secret, x)
Example #12
def test_reconstruct(get_clients) -> None:
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    a_rand = 3
    a = ShareTensor(data=a_rand, config=Config(encoder_precision=0))
    MPCTensor.generate_shares(secret=a, nr_parties=2, tensor_type=torch.long)

    MPCTensor.generate_shares(
        secret=a_rand, nr_parties=2, config=Config(), tensor_type=torch.long
    )

    x_secret = torch.Tensor([1, -2, 3.0907, -4.870])
    x = MPCTensor(secret=x_secret, session=session)
    x = x.reconstruct()

    assert np.allclose(x_secret, x)
Example #13
def test_ops_public_mul_integer_parties(get_clients, parties, security):

    config = Config(encoder_base=1, encoder_precision=0)

    parties = get_clients(parties)
    protocol = Falcon(security)
    session = Session(protocol=protocol, parties=parties, config=config)
    SessionManager.setup_mpc(session)

    secret = torch.tensor([[-100, 20, 30], [-90, 1000, 1], [1032, -323, 15]])
    value = 8

    op = getattr(operator, "mul")
    tensor = MPCTensor(secret=secret, session=session)
    shares = [op(share, value) for share in tensor.share_ptrs]
    result = MPCTensor(shares=shares, session=session)

    assert (result.reconstruct() == (secret * value)).all()
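The test works because multiplying every share by a public constant scales the shared secret by that constant, with no extra communication. The identity for additive sharing, checked in plain integers:

# Share-wise multiplication by a public value scales the secret (plain ints).
secret, v = 30, 8
shares = [17, -4, secret - 17 + 4]   # additive shares summing to `secret`
assert sum(shares) == secret
assert sum(s * v for s in shares) == secret * v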
Example #14
def test_truncation_algorithm1(get_clients, base, precision) -> None:
    parties = get_clients(3)
    falcon = Falcon("semi-honest")
    config = Config(encoder_base=base, encoder_precision=precision)
    session = Session(parties=parties, protocol=falcon, config=config)
    SessionManager.setup_mpc(session)

    x = torch.tensor([[1.24, 4.51, 6.87], [7.87, 1301, 541]])

    x_mpc = MPCTensor(secret=x, session=session)

    result = ABY3.truncate(x_mpc, session, session.ring_size, session.config)

    fp_encoder = FixedPointEncoder(base=session.config.encoder_base,
                                   precision=session.config.encoder_precision)
    expected_res = x_mpc.reconstruct(decode=False) // fp_encoder.scale
    expected_res = fp_encoder.decode(expected_res)

    assert np.allclose(result.reconstruct(), expected_res, atol=1e-3)
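For reference, the fixed-point encoding the test manipulates can be sketched in plain Python, assuming FixedPointEncoder uses scale = base ** precision (as the expected-result computation suggests; values hypothetical):

# Fixed-point encode/decode sketch (plain Python).
base, precision = 2, 16
scale = base ** precision        # fixed-point scale factor
x = 1.24
encoded = round(x * scale)       # raw ring value, as reconstruct(decode=False)
decoded = encoded / scale        # decoding recovers ~1.24, up to encoding error
print(decoded)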
Example #15
    def bit_decomposition_ttp(x: MPCTensor,
                              session: Session) -> List[MPCTensor]:
        """Perform ABY3 bit decomposition using orchestrator as ttp.

        Args:
            x (MPCTensor): Arithmetic shares of secret.
            session (Session): session the share belongs to.

        Returns:
            b_sh (List[MPCTensor]): Returns binary shares of each bit of the secret.

        TODO: We should modify to use parallel prefix adder, which requires multiprocessing.
        """
        # Decoding is not done as they are shares of PRRS.
        tensor = x.reconstruct(decode=False)
        b_sh: List[MPCTensor] = []  # binary shares of bits
        ring_size = session.ring_size
        shares_sum = ReplicatedSharedTensor.shares_sum
        ring_bits = get_nr_bits(ring_size)

        for idx in range(ring_bits):
            bit_mask = torch.ones(tensor.shape, dtype=tensor.dtype) << idx
            secret = (tensor & bit_mask).type(torch.bool)
            # Note: "gen" is a module-level CSPRNG generator in the original source.
            r1 = torch.empty(size=tensor.shape,
                             dtype=torch.bool).random_(generator=gen)
            r2 = torch.empty(size=tensor.shape,
                             dtype=torch.bool).random_(generator=gen)
            r3 = shares_sum([secret, r1, r2], ring_size=2)
            shares = [r1, r2, r3]
            config = Config(encoder_base=1, encoder_precision=0)
            sh_ptr = ReplicatedSharedTensor.distribute_shares(shares=shares,
                                                              session=session,
                                                              ring_size=2,
                                                              config=config)
            b_mpc = MPCTensor(shares=sh_ptr,
                              session=session,
                              shape=tensor.shape)
            b_sh.append(b_mpc)

        return b_sh
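Two relations the method relies on can be checked in plain Python (a sketch, not MPC code): r3 = secret ^ r1 ^ r2 yields an XOR (ring size 2) sharing of each bit, and the secret is recovered from its bits as sum(bit_i << i):

import random

# XOR (ring-size-2) sharing, mirroring r3 = shares_sum([secret, r1, r2], 2).
secret_bit = 1
r1, r2 = random.randint(0, 1), random.randint(0, 1)
r3 = secret_bit ^ r1 ^ r2
assert r1 ^ r2 ^ r3 == secret_bit

# Bit recombination: the decomposed bits reassemble the original value.
value, ring_bits = 43, 8
bits = [(value >> i) & 1 for i in range(ring_bits)]
assert sum(b << i for i, b in enumerate(bits)) == value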
Example #16
    def mul_malicious(
        x: MPCTensor,
        y: MPCTensor,
        session: Session,
        op_str: str,
        ring_size: int,
        config: Config,
        **kwargs_: Dict[Any, Any],
    ) -> MPCTensor:
        """Falcon malicious multiplication.

        Args:
            x (MPCTensor): Secret
            y (MPCTensor): Another secret
            session (Session): Session the tensors belong to
            op_str (str): Operation string.
            ring_size (int): Ring size of the underlying tensor.
            config (Config): The configuration (base, precision) of the underlying tensor.
            kwargs_ (Dict[Any, Any]): Kwargs for some operations like conv2d

        Returns:
            result(MPCTensor): Result of the operation.

        Raises:
            ValueError : If the shares are not valid.
        """
        shape_x = tuple(x.shape)
        shape_y = tuple(y.shape)

        result = Falcon.mul_semi_honest(x,
                                        y,
                                        session,
                                        op_str,
                                        ring_size,
                                        config,
                                        reshare=True,
                                        **kwargs_)

        args = [list(sh) + [op_str] for sh in zip(x.share_ptrs, y.share_ptrs)]
        try:
            mask = parallel_execution(Falcon.falcon_mask,
                                      session.parties)(args)
        except EmptyPrimitiveStore:
            CryptoPrimitiveProvider.generate_primitives(
                f"beaver_{op_str}",
                session=session,
                g_kwargs={
                    "session": session,
                    "a_shape": shape_x,
                    "b_shape": shape_y,
                    "nr_parties": session.nr_parties,
                    "ring_size": ring_size,
                    "config": config,
                    **kwargs_,
                },
                p_kwargs={
                    "a_shape": shape_x,
                    "b_shape": shape_y
                },
            )
            mask = parallel_execution(Falcon.falcon_mask,
                                      session.parties)(args)

        # zip on pointers is compute intensive
        mask_local = [mask[idx].get() for idx in range(session.nr_parties)]
        eps_shares, delta_shares = zip(*mask_local)

        eps_plaintext = ReplicatedSharedTensor.reconstruct(eps_shares)
        delta_plaintext = ReplicatedSharedTensor.reconstruct(delta_shares)

        args = [
            list(sh) + [eps_plaintext, delta_plaintext, op_str]
            for sh in zip(result.share_ptrs)
        ]

        triple_shares = parallel_execution(Falcon.triple_verification,
                                           session.parties)(args, kwargs_)

        triple = MPCTensor(shares=triple_shares, session=x.session)

        if (triple.reconstruct(decode=False) == 0).all():
            return result
        else:
            raise ValueError("Computation Aborted: Malicious behavior.")
Example #17
class Linear(SMPCModule):
    """A Linear SMPC Layer."""

    __slots__ = [
        "weight",
        "bias",
        "session",
        "in_features",
        "out_features",
        "_parameters",
    ]

    in_features: Tuple[int]
    out_features: Tuple[int]
    weight: MPCTensor
    bias: Optional[MPCTensor]
    _parameters: Dict[str, MPCTensor]

    def __init__(self, session) -> None:
        """The initializer for the Linear layer.

        Args:
            session (Session): the session used to identify the layer
        """
        self.bias = None
        self._parameters = None
        self.session = session

    def forward(self, x: MPCTensor) -> MPCTensor:
        """Do a feedforward through the layer.

        Args:
            x (MPCTensor): the input

        Returns:
            An MPCTensor that results from applying the layer-specific operation on the input
        """
        res = x @ self.weight.t()

        if self.bias is not None:
            res = res + self.bias

        return res

    __call__ = forward

    def parameters(self, recurse: bool = False) -> MPCTensor:
        """Get the parameters of the Linear module.

        Args:
            recurse (bool): For the moment not used. TODO

        Yields:
            Each parameter of the module
        """
        for param in self._parameters.values():
            yield param

    def share_state_dict(
        self,
        state_dict: Dict[str, Any],
        additional_attributes: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Share the parameters of the normal Linear layer.

        Args:
            state_dict (Dict[str, Any]): the state dict that would be shared
            additional_attributes (Dict[str, Any]): Attributes apart from weights.
        """
        bias = None
        if ispointer(state_dict):
            weight = state_dict["weight"].resolve_pointer_type()
            if "bias" in weight.client.python.List(state_dict).get():
                bias = state_dict["bias"].resolve_pointer_type()
            shape = weight.client.python.Tuple(weight.shape)
            shape = shape.get()
        else:
            weight = state_dict["weight"]
            bias = state_dict.get("bias")
            shape = state_dict["weight"].shape

        self.out_features, self.in_features = shape
        self.weight = MPCTensor(secret=weight,
                                session=self.session,
                                shape=shape,
                                requires_grad=True)
        self._parameters = OrderedDict({"weight": self.weight})

        if bias is not None:
            self.bias = MPCTensor(
                secret=bias,
                session=self.session,
                shape=(self.out_features, ),
                requires_grad=True,
            )
            self._parameters["bias"] = self.bias

    def reconstruct_state_dict(self) -> Dict[str, Any]:
        """Reconstruct the shared state dict.

        Returns:
            The reconstructed state dict (Dict[str, Any])
        """
        state_dict = OrderedDict()
        state_dict["weight"] = self.weight.reconstruct()

        if self.bias is not None:
            state_dict["bias"] = self.bias.reconstruct()

        return state_dict

    @staticmethod
    def get_torch_module(linear_module: "Linear") -> torch.nn.Module:
        """Get a torch module from a given MPC Layer module.

        The parameters of the models are not set.

        Args:
            linear_module (Linear): the MPC Linear layer

        Returns:
            A torch Linear module
        """
        bias = linear_module.bias is not None
        module = torch.nn.Linear(
            in_features=linear_module.in_features,
            out_features=linear_module.out_features,
            bias=bias,
        )
        return module
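A minimal usage sketch for the Linear layer (hypothetical; assumes a session set up as in the tests above, with SessionManager.setup_mpc already called):

# Hypothetical usage of the SMPC Linear layer (a sketch).
torch_linear = torch.nn.Linear(in_features=3, out_features=2)

smpc_linear = Linear(session)
smpc_linear.share_state_dict(torch_linear.state_dict())

x = MPCTensor(secret=torch.randn(2, 3), session=session)
y = smpc_linear(x)        # x @ W.t() + b, computed on shares
print(y.reconstruct())    # approximately torch_linear(plaintext x)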
Example #18
def mul_master(x: MPCTensor, y: MPCTensor, op_str: str,
               kwargs_: Dict[Any, Any]) -> MPCTensor:
    """Function that is executed by the orchestrator to multiply two secret values.

    Args:
        x (MPCTensor): First value to multiply with.
        y (MPCTensor): Second value to multiply with.
        op_str (str): Operation string.
        kwargs_ (dict): TODO:Add docstring.

    Raises:
        ValueError: If op_str not in EXPECTED_OPS.

    Returns:
        MPCTensor: Result of the multiplication.
    """
    if op_str not in EXPECTED_OPS:
        raise ValueError(f"{op_str} should be in {EXPECTED_OPS}")

    session = x.session

    shape_x = tuple(x.shape)
    shape_y = tuple(y.shape)

    CryptoPrimitiveProvider.generate_primitives(
        f"beaver_{op_str}",
        sessions=session.session_ptrs,
        g_kwargs={
            "a_shape": shape_x,
            "b_shape": shape_y,
            "nr_parties": session.nr_parties,
            **kwargs_,
        },
        p_kwargs={
            "a_shape": shape_x,
            "b_shape": shape_y
        },
    )

    args = [
        list(el) + [op_str]
        for el in zip(session.session_ptrs, x.share_ptrs, y.share_ptrs)
    ]

    mask = parallel_execution(spdz_mask, session.parties)(args)
    eps_shares, delta_shares = zip(*mask)

    eps = MPCTensor(shares=eps_shares, session=session)
    delta = MPCTensor(shares=delta_shares, session=session)

    eps_plaintext = eps.reconstruct(decode=False)
    delta_plaintext = delta.reconstruct(decode=False)

    # Arguments that must be sent to all parties
    common_args = [eps_plaintext, delta_plaintext, op_str]

    # Specific arguments to each party
    args = [[el] + common_args for el in session.session_ptrs]

    shares = parallel_execution(mul_parties, session.parties)(args, kwargs_)

    result = MPCTensor(shares=shares, session=session)

    return result
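mul_parties combines the opened eps and delta with a precomputed Beaver triple (a, b, c = a * b). The identity it relies on, checked in plain integers (a sketch, not the library's code):

import random

# Beaver-triple identity: with eps = x - a and delta = y - b opened in
# plaintext, x * y == c + eps * b + delta * a + eps * delta.
x_val, y_val = 7, 5
a, b = random.randint(0, 100), random.randint(0, 100)
c = a * b
eps, delta = x_val - a, y_val - b
assert c + eps * b + delta * a + eps * delta == x_val * y_val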
Example #19
def fss_op(x1: MPCTensor, x2: MPCTensor, op="eq") -> MPCTensor:
    """Define the workflow for a binary operation using Function Secret Sharing.

    Currently supported operations are == and <=, corresponding to
    op='eq' and op='comp' respectively.

    Args:
        x1 (MPCTensor): First private value.
        x2 (MPCTensor): Second private value.
        op: Type of operation to perform, should be 'eq' or 'comp'. Defaults to eq.

    Returns:
        MPCTensor: Shares of the comparison.
    """
    if th.cuda.is_available():
        # FSS is currently not supported on GPU.
        # https://stackoverflow.com/a/62145307/8878627

        # When the CUDA_VISIBLE_DEVICES environment variable is not set,
        # CUDA is not used even if available. Hence, we default to None
        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        warnings.warn("Temporarily disabling CUDA as FSS does not support it")
    else:
        cuda_visible_devices = None

    # FIXME: Better handle the case where x1 or x2 is not a MPCTensor. For the moment
    # FIXME: we cast it into a MPCTensor at the expense of extra communication
    session = x1.session

    shape = MPCTensor._get_shape("sub", x1.shape, x2.shape)
    n_values = shape.numel()

    CryptoPrimitiveProvider.generate_primitives(
        f"fss_{op}",
        sessions=session.session_ptrs,
        g_kwargs={"n_values": n_values},
        p_kwargs={},
    )

    args = zip(session.session_ptrs, x1.share_ptrs, x2.share_ptrs)
    args = [list(el) + [op] for el in args]

    shares = parallel_execution(mask_builder, session.parties)(args)

    # TODO: don't do .reconstruct(), this should be done remotely between the evaluators
    mask_value = MPCTensor(shares=shares, session=session)
    # Note: "n" is a module-level constant (share bit length) in the original source.
    mask_value = mask_value.reconstruct(decode=False) % 2**n

    # TODO: add dtype to args
    args = [(session.session_ptrs[i], th.IntTensor([i]), mask_value, op)
            for i in range(2)]

    shares = parallel_execution(evaluate, session.parties)(args)

    response = MPCTensor(session=session, shares=shares, shape=shape)
    response.shape = shape

    if cuda_visible_devices is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices

    return response
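The save/disable/restore handling of CUDA_VISIBLE_DEVICES above could equally be packaged as a context manager (a sketch, not part of the library; unlike the inline version, it also restores the originally-unset state instead of leaving an empty string):

import os
from contextlib import contextmanager

@contextmanager
def cuda_hidden():
    """Temporarily hide CUDA devices via CUDA_VISIBLE_DEVICES."""
    saved = os.environ.get("CUDA_VISIBLE_DEVICES")
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    try:
        yield
    finally:
        if saved is None:
            del os.environ["CUDA_VISIBLE_DEVICES"]
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = saved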
Example #20
class Conv2d(SMPCModule):
    """Convolutional 2D."""

    __slots__ = (
        "session",
        "weight",
        "bias",
        "kernel_size",
        "stride",
        "padding",
        "dilation",
        "groups",
        "_parameters",
    )

    in_channels: int
    out_channels: int
    kernel_size: Tuple[int, ...]
    stride: Tuple[int, ...]
    padding: Tuple[int, ...]
    dilation: Tuple[int, ...]
    groups: int
    weight: List[MPCTensor]
    bias: Optional[MPCTensor]
    _parameters: OrderedDict

    def __init__(self, session: Session) -> None:
        """Initialize Conv2d layer.

        The stride, padding, dilation and groups are hardcoded for the moment.

        Args:
            session (Session): the session used to identify the layer
        """
        self.session = session
        self.in_channels = None
        self.out_channels = None
        self.stride = 1
        self.padding = 0
        self.dilation = 1
        self.groups = 1
        self._parameters = None

    def forward(self, x: MPCTensor) -> MPCTensor:
        """Do a feedforward through the layer.

        Args:
            x (MPCTensor): the input

        Returns:
            An MPCTensor representing the layer specific operation applied on the input
        """
        res = x.conv2d(
            weight=self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )

        return res

    __call__ = forward

    def set_additional_attributes(self, attributes: Dict) -> None:
        """Sets attributes of conv apart from weights.

        Args:
            attributes (Dict): Attributes with their values.

        Raises:
            ValueError: If the attribute does not exist.
        """
        for attr in attributes.keys():
            if hasattr(self, attr):
                setattr(self, attr, attributes[attr])
            else:
                raise ValueError(
                    f"Attribute {attr} does not exist in SyMPC module.")

    def share_state_dict(
        self,
        state_dict: Dict[str, Any],
        additional_attributes: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Share the parameters of the normal Conv2d layer.

        Args:
            state_dict (Dict[str, Any]): the state dict that would be shared.
            additional_attributes (Dict[str, Any]): Attributes of conv apart from weights.

        """
        bias = None
        if ispointer(state_dict):

            weight = state_dict["weight"].resolve_pointer_type()
            if "bias" in weight.client.python.List(state_dict).get():
                bias = state_dict["bias"].resolve_pointer_type()
            shape = weight.client.python.Tuple(weight.shape)
            shape = shape.get()

        else:
            weight = state_dict["weight"]
            bias = state_dict.get("bias")
            shape = state_dict["weight"].shape

        if ispointer(additional_attributes):
            self.set_additional_attributes(
                additional_attributes.get().resolve_pointer_type())
        else:
            self.set_additional_attributes(additional_attributes)

        # Weight shape (out_channel, in_channels/groups, kernel_size_w, kernel_size_h)
        # we have groups == 1

        (
            self.out_channels,
            self.in_channels,
            kernel_size_w,
            kernel_size_h,
        ) = shape

        self.kernel_size = (kernel_size_w, kernel_size_h)
        self.weight = MPCTensor(secret=weight,
                                session=self.session,
                                shape=shape)
        self._parameters = OrderedDict({"weight": self.weight})

        if bias is not None:
            self.bias = MPCTensor(secret=bias,
                                  session=self.session,
                                  shape=(self.out_channels, ))
            self._parameters["bias"] = self.bias

    def parameters(self, recurse: bool = False) -> MPCTensor:
        """Get the parameters of the Linear module.

        Args:
            recurse (bool): For the moment not used. TODO

        Yields:
            Each parameter of the module
        """
        for param in self._parameters.values():
            yield param

    def reconstruct_state_dict(self) -> Dict[str, Any]:
        """Reconstruct the shared state dict.

        Returns:
            Dict[str, Any]: The reconstructed state dict.
        """
        state_dict = OrderedDict()
        state_dict["weight"] = self.weight.reconstruct()

        if self.bias is not None:
            state_dict["bias"] = self.bias.reconstruct()

        return state_dict

    @staticmethod
    def get_torch_module(conv_module: "Conv2d") -> torch.nn.Module:
        """Get a torch module from a given MPC Conv2d module.

        The parameters of the models are not set.

        Args:
            conv_module (Conv2d): the MPC Conv2d layer

        Returns:
            torch.nn.Module: A torch Conv2d module.
        """
        bias = conv_module.bias is not None
        module = torch.nn.Conv2d(
            in_channels=conv_module.in_channels,
            out_channels=conv_module.out_channels,
            kernel_size=conv_module.kernel_size,
            stride=conv_module.stride,
            padding=conv_module.padding,
            dilation=conv_module.dilation,
            groups=conv_module.groups,
            bias=bias,
        )
        return module
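A usage sketch for the Conv2d layer, mirroring the Linear example (hypothetical; the additional-attributes dict carries the convolution hyperparameters apart from the weights):

# Hypothetical usage of the SMPC Conv2d layer (a sketch; `session` is assumed
# to be set up as in the tests above).
torch_conv = torch.nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3)

smpc_conv = Conv2d(session)
smpc_conv.share_state_dict(
    torch_conv.state_dict(),
    additional_attributes={"stride": (1, 1), "padding": (0, 0)},
)

x = MPCTensor(secret=torch.randn(1, 1, 8, 8), session=session)
y = smpc_conv(x)          # conv2d computed on shares
print(y.reconstruct())    # approximately torch_conv(plaintext x)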
Example #21
class Conv2d(SMPCModule):
    """Convolutional 2D."""

    __slots__ = [
        "session",
        "weight",
        "bias",
        "kernel_size",
        "stride",
        "padding",
        "dilation",
        "groups",
    ]

    in_channels: int
    out_channels: int
    kernel_size: Tuple[int, ...]
    stride: Tuple[int, ...]
    padding: Tuple[int, ...]
    dilation: Tuple[int, ...]
    groups: int
    weight: List[MPCTensor]
    bias: Optional[MPCTensor]

    def __init__(self, session: Session) -> None:
        """Initialize Conv2d layer.

        The stride, padding, dilation and groups are hardcoded for the moment.

        Args:
            session (Session): the session used to identify the layer
        """
        self.session = session
        self.stride = 1
        self.padding = 0
        self.dilation = 1
        self.groups = 1

    def forward(self, x: MPCTensor) -> MPCTensor:
        """Do a feedforward through the layer.

        Args:
            x (MPCTensor): the input

        Returns:
            An MPCTensor representing the layer specific operation applied on the input
        """
        res = x.conv2d(
            weight=self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )

        return res

    __call__ = forward

    def share_state_dict(
        self,
        state_dict: Dict[str, Any],
    ) -> None:
        """Share the parameters of the normal Conv2d layer.

        Args:
            state_dict (Dict[str, Any]): the state dict that would be shared.

        Raises:
            ValueError: If kernel sizes mismatch "kernel_size_w" and "kernel_size_h"
        """
        bias = None
        if ispointer(state_dict):
            weight = state_dict["weight"].resolve_pointer_type()
            if "bias" in weight.client.python.List(state_dict).get():
                bias = state_dict["bias"].resolve_pointer_type()
            shape = weight.client.python.Tuple(weight.shape)
            shape = shape.get()
        else:
            weight = state_dict["weight"]
            bias = state_dict.get("bias")
            shape = state_dict["weight"].shape

        # Weight shape (out_channel, in_channels/groups, kernel_size_w, kernel_size_h)
        # we have groups == 1
        (
            self.out_channels,
            self.in_channels,
            kernel_size_w,
            kernel_size_h,
        ) = shape

        if kernel_size_w != kernel_size_h:
            raise ValueError(
                f"Kernel sizes mismatch {kernel_size_w} and {kernel_size_h}")

        self.kernel_size = kernel_size_w

        self.weight = MPCTensor(secret=weight,
                                session=self.session,
                                shape=shape)

        if bias is not None:
            self.bias = MPCTensor(secret=bias,
                                  session=self.session,
                                  shape=(self.out_channels, ))

    def reconstruct_state_dict(self) -> Dict[str, Any]:
        """Reconstruct the shared state dict.

        Returns:
            Dict[str, Any]: The reconstructed state dict.
        """
        state_dict = OrderedDict()
        state_dict["weight"] = self.weight.reconstruct()

        if self.bias is not None:
            state_dict["bias"] = self.bias.reconstruct()

        return state_dict

    @staticmethod
    def get_torch_module(conv_module: "Conv2d") -> torch.nn.Module:
        """Get a torch module from a given MPC Conv2d module.

        The parameters of the models are not set.

        Args:
            conv_module (Conv2d): the MPC Conv2d layer

        Returns:
            torch.nn.Module: A torch Conv2d module.
        """
        bias = conv_module.bias is not None
        module = torch.nn.Conv2d(
            in_channels=conv_module.in_channels,
            out_channels=conv_module.out_channels,
            kernel_size=conv_module.kernel_size,
            bias=bias,
        )
        return module
Example #22
class Linear(SMPCModule):
    __slots__ = ["weight", "bias", "session", "in_features", "out_features"]

    in_features: Tuple[int]
    out_features: Tuple[int]
    weight: MPCTensor
    bias: Optional[MPCTensor]

    def __init__(self, session) -> None:
        """The initializer for the Linear layer.

        Args:
            session (Session): the session used to identify the layer
        """

        self.bias = None
        self.session = session

    def forward(self, x: MPCTensor) -> MPCTensor:
        """Do a feedforward through the layer.

        Args:
            x (MPCTensor): the input

        Returns:
            An MPCTensor that results from applying the layer-specific operation on the input
        """

        res = x @ self.weight.T

        if self.bias is not None:
            res = res + self.bias

        return res

    __call__ = forward

    def share_state_dict(
        self,
        state_dict: Dict[str, Any],
    ) -> None:
        """Share the parameters of the normal Linear layer.

        Args:
            state_dict (Dict[str, Any]): the state dict that would be shared
        """

        bias = None
        if ispointer(state_dict):
            weight = state_dict["weight"].resolve_pointer_type()
            if "bias" in weight.client.python.List(state_dict).get():
                bias = state_dict["bias"].resolve_pointer_type()
            shape = weight.client.python.Tuple(weight.shape)
            shape = shape.get()
        else:
            weight = state_dict["weight"]
            bias = state_dict.get("bias")
            shape = state_dict["weight"].shape

        self.out_features, self.in_features = shape
        self.weight = MPCTensor(secret=weight,
                                session=self.session,
                                shape=shape)

        if bias is not None:
            self.bias = MPCTensor(secret=bias,
                                  session=self.session,
                                  shape=(self.out_features, ))

    def reconstruct_state_dict(self) -> Dict[str, Any]:
        """Reconstruct the shared state dict.

        Returns:
            The reconstructed state dict (Dict[str, Any])
        """

        state_dict = OrderedDict()
        state_dict["weight"] = self.weight.reconstruct()

        if self.bias is not None:
            state_dict["bias"] = self.bias.reconstruct()

        return state_dict

    @staticmethod
    def get_torch_module(linear_module: "Linear") -> torch.nn.Module:
        """Get a torch module from a given MPC Layer module The parameters of
        the models are not set.

        Args:
            linear_module (Linear): the MPC Linear layer

        Returns:
            A torch Linear module
        """

        bias = linear_module.bias is not None
        module = torch.nn.Linear(
            in_features=linear_module.in_features,
            out_features=linear_module.out_features,
            bias=bias,
        )
        return module