def test_rst_reconstruct_zero_share_ptrs(get_clients, security) -> None:
    """Reconstructing an MPCTensor whose share pointers were removed must fail.

    A Falcon session over three clients shares a small float matrix, the
    ``share_ptrs`` list is then emptied, and ``reconstruct`` is expected to
    raise ``ValueError``.
    """
    clients = get_clients(3)
    falcon = Falcon(security)
    falcon_session = Session(protocol=falcon, parties=clients)
    SessionManager.setup_mpc(falcon_session)

    plaintext = torch.Tensor(
        [
            [1, -2.0, 0.0],
            [3.9, -4.394, -0.9],
            [-43, 100, -0.4343],
            [1.344, -5.0, 0.55],
        ]
    )
    shared = MPCTensor(secret=plaintext, session=falcon_session)
    shared.share_ptrs = []

    with pytest.raises(ValueError):
        shared.reconstruct()
def test_invalid_malicious_reconstruction(get_clients, parties):
    """Tampering with a share of a shared scalar must abort reconstruction.

    Under Falcon's ``"malicious"`` setting, element 0 of the first share
    pointer is increased by 4; ``reconstruct`` is then expected to raise
    ``ValueError``.
    """
    clients = get_clients(parties)
    malicious_session = Session(protocol=Falcon("malicious"), parties=clients)
    SessionManager.setup_mpc(malicious_session)

    shared = MPCTensor(secret=42.32, session=malicious_session)
    first_share = shared.share_ptrs[0]
    first_share[0] = first_share[0] + 4

    with pytest.raises(ValueError):
        shared.reconstruct()
def test_invalid_malicious_reconstruction(get_clients, parties):
    """Tampering with a share of a shared matrix must abort reconstruction.

    Under Falcon's ``"malicious"`` setting, element 0 of the first share
    pointer is increased by 4; ``reconstruct`` is then expected to raise
    ``ValueError``.
    """
    clients = get_clients(parties)
    malicious_session = Session(protocol=Falcon("malicious"), parties=clients)
    SessionManager.setup_mpc(malicious_session)

    plaintext = torch.Tensor(
        [
            [1, -2.0, 0.0],
            [3.9, -4.394, -0.9],
            [-43, 100, -0.4343],
            [1.344, -5.0, 0.55],
        ]
    )
    shared = MPCTensor(secret=plaintext, session=malicious_session)
    first_share = shared.share_ptrs[0]
    first_share[0] = first_share[0] + 4

    with pytest.raises(ValueError):
        shared.reconstruct()
def test_remote_not_tensor(get_clients) -> None:
    """Share remote Python ``Int`` and ``Float`` secrets and reconstruct them.

    Both secrets are created on ``bob_client`` and shared with an explicit
    ``shape=(1,)``; each reconstruction is compared with the fetched value.
    """
    alice_client, bob_client = get_clients(2)
    session = Session(parties=[alice_client, bob_client])
    SessionManager.setup_mpc(session)

    x_remote_int = bob_client.python.Int(5)
    x = MPCTensor(secret=x_remote_int, shape=(1,), session=session)
    result = x.reconstruct()

    # Compare against the fetched plaintext, as the float case below does.
    # Comparing the pointer object itself with a local tensor does not look
    # at the remote value. NOTE(review): confirm syft pointer `==` semantics.
    assert np.allclose(x_remote_int.get(), result, atol=1e-5)

    x_remote_float = bob_client.python.Float(5.4)
    x = MPCTensor(secret=x_remote_float, shape=(1,), session=session)
    result = x.reconstruct()

    assert np.allclose(x_remote_float.get(), result, atol=1e-5)
def test_local_secret_not_tensor(get_clients) -> None:
    """Share plain Python ``int`` and ``float`` secrets and reconstruct them.

    The int must come back equal; the float must be ``np.allclose`` to it.
    """
    alice_client, bob_client = get_clients(2)
    session = Session(parties=[alice_client, bob_client])
    SessionManager.setup_mpc(session)

    int_secret = 5
    int_shared = MPCTensor(secret=int_secret, session=session)
    int_result = int_shared.reconstruct()
    assert int_secret == int_result

    float_secret = 5.987
    float_shared = MPCTensor(secret=float_secret, session=session)
    float_result = float_shared.reconstruct()
    assert np.allclose(torch.tensor(float_secret), float_result)
def test_get_grad_input_padding(get_clients, common_args: List, nr_parties) -> None:
    """Check ``GradConv2d.get_grad_input_padding`` against PyTorch's reference.

    Each party runs ``get_grad_input_padding`` on its share pointer. The
    outputs are rebuilt into an MPCTensor, reconstructed, divided by the
    number of parties and compared with ``_grad_input_padding``.
    """
    # Use the parametrized party count; it was previously hard-coded to 2,
    # leaving the ``nr_parties`` argument unused.
    clients = get_clients(nr_parties)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    grad = torch.Tensor([[[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]]])
    grad_mpc = MPCTensor(secret=grad, session=session)

    input_size, stride, padding, kernel_size, dilation = common_args
    expected_padding = torch.nn.functional.grad._grad_input_padding(
        grad,
        input_size,
        (stride, stride),
        (padding, padding),
        kernel_size,
        (dilation, dilation),
    )

    args = [[el] + common_args + [session] for el in grad_mpc.share_ptrs]
    shares = parallel_execution(GradConv2d.get_grad_input_padding, grad_mpc.session.parties)(args)
    grad_input_padding = MPCTensor(shares=shares, session=grad_mpc.session)
    output_padding_tensor = grad_input_padding.reconstruct()
    # NOTE(review): dividing by the party count assumes every party returns
    # the full padding, so reconstruction yields nr_parties times the value.
    output_padding_tensor /= grad_mpc.session.nr_parties
    calculated_padding = tuple(output_padding_tensor.to(torch.int).tolist())

    assert calculated_padding == expected_padding
def test_rst_distribute_reconstruct_float_secret(get_clients, parties, security) -> None:
    """A float secret shared under Falcon reconstructs to within ``atol=1e-3``."""
    clients = get_clients(parties)
    falcon = Falcon(security)
    falcon_session = Session(protocol=falcon, parties=clients)
    SessionManager.setup_mpc(falcon_session)

    plaintext = 43.2
    shared = MPCTensor(secret=plaintext, session=falcon_session)
    assert np.allclose(plaintext, shared.reconstruct(), atol=1e-3)
def test_remote_mpc_with_shape(get_clients) -> None:
    """Share a tensor that lives on ``alice_client``, passing ``shape`` explicitly."""
    alice_client, bob_client = get_clients(2)
    session = Session(parties=[alice_client, bob_client])
    SessionManager.setup_mpc(session)

    remote_secret = alice_client.torch.Tensor([1, -2, 0.3])
    shared = MPCTensor(secret=remote_secret, shape=(1, 3), session=session)
    # Reconstruct before fetching the remote secret, as in the original order.
    reconstructed = shared.reconstruct()

    assert np.allclose(remote_secret.get(), reconstructed, atol=1e-5)
def test_rst_distribute_reconstruct_tensor_secret(get_clients, parties, security) -> None:
    """A float matrix shared under Falcon reconstructs to within ``atol=1e-3``."""
    clients = get_clients(parties)
    falcon = Falcon(security)
    falcon_session = Session(protocol=falcon, parties=clients)
    SessionManager.setup_mpc(falcon_session)

    plaintext = torch.Tensor(
        [
            [1, -2.0, 0.0],
            [3.9, -4.394, -0.9],
            [-43, 100, -0.4343],
            [1.344, -5.0, 0.55],
        ]
    )
    shared = MPCTensor(secret=plaintext, session=falcon_session)
    assert np.allclose(plaintext, shared.reconstruct(), atol=1e-3)
def fss_op(x1: MPCTensor, x2: MPCTensor, op="eq") -> MPCTensor:
    """Define the workflow for a binary operation using Function Secret Sharing.

    Currently supported operations are equality (``op='eq'``) and
    less-or-equal (``op='comp'``).

    Args:
        x1 (MPCTensor): First private value.
        x2 (MPCTensor): Second private value.
        op: Type of operation to perform, should be 'eq' or 'comp'. Defaults to eq.

    Returns:
        MPCTensor: Shares of the comparison.
    """
    # FSS is not run on GPU here; this raises AssertionError when CUDA is available.
    assert not th.cuda.is_available()  # nosec

    # FIXME: Better handle the case where x1 or x2 is not a MPCTensor. For the moment
    # FIXME: we cast it into a MPCTensor at the expense of extra communication
    session = x1.session

    # Result shape follows the broadcasting rule used for "sub".
    shape = MPCTensor._get_shape("sub", x1.shape, x2.shape)
    n_values = shape.numel()

    CryptoPrimitiveProvider.generate_primitives(
        f"fss_{op}",
        sessions=session.session_ptrs,
        g_kwargs={"n_values": n_values},
        p_kwargs={},
    )

    args = zip(session.session_ptrs, x1.share_ptrs, x2.share_ptrs)
    args = [list(el) + [op] for el in args]

    shares = parallel_execution(mask_builder, session.parties)(args)

    # TODO: don't do .reconstruct(), this should be done remotely between the evaluators
    mask_value = MPCTensor(shares=shares, session=session)
    # NOTE(review): ``n`` is not defined in this function; presumably a
    # module-level bit width for the FSS ring — confirm.
    mask_value = mask_value.reconstruct(decode=False) % 2**n

    # Evaluation is done by exactly two parties (indices 0 and 1).
    # TODO: add dtype to args
    args = [(session.session_ptrs[i], th.IntTensor([i]), mask_value, op) for i in range(2)]

    shares = parallel_execution(evaluate, session.parties)(args)

    response = MPCTensor(session=session, shares=shares, shape=shape)
    response.shape = shape
    return response
def test_reconstruct(get_clients) -> None:
    """``generate_shares`` accepts ShareTensor/int secrets; MPCTensor reconstructs.

    A float vector shared between two clients must reconstruct to values that
    are ``torch.allclose`` to the original.
    """
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    a_rand = 3
    a = ShareTensor(data=a_rand, encoder_precision=0)
    # Smoke checks: generate_shares must accept both a ShareTensor and a plain
    # int secret. The returned shares are never inspected, so they are not bound.
    MPCTensor.generate_shares(a, 2, torch.long)
    MPCTensor.generate_shares(a_rand, 2, torch.long)

    x_secret = torch.Tensor([1, -2, 3.0907, -4.870])
    x = MPCTensor(secret=x_secret, session=session)
    reconstructed = x.reconstruct()

    assert torch.allclose(x_secret, reconstructed)
def test_reconstruct(get_clients) -> None:
    """``generate_shares`` accepts ShareTensor/int secrets; MPCTensor reconstructs.

    A float vector shared between two clients must reconstruct to values that
    are ``np.allclose`` to the original.
    """
    clients = get_clients(2)
    session = Session(parties=clients)
    SessionManager.setup_mpc(session)

    int_secret = 3
    share = ShareTensor(data=int_secret, config=Config(encoder_precision=0))
    # Smoke checks only: both calls must succeed; their results are unused.
    MPCTensor.generate_shares(secret=share, nr_parties=2, tensor_type=torch.long)
    MPCTensor.generate_shares(
        secret=int_secret, nr_parties=2, config=Config(), tensor_type=torch.long
    )

    plaintext = torch.Tensor([1, -2, 3.0907, -4.870])
    shared = MPCTensor(secret=plaintext, session=session)
    reconstructed = shared.reconstruct()

    assert np.allclose(plaintext, reconstructed)
def test_ops_public_mul_integer_parties(get_clients, parties, security):
    """Multiplying every share by a public integer scales the secret.

    The session uses ``encoder_base=1`` and ``encoder_precision=0``
    (presumably an identity fixed-point encoding). Each share pointer is
    multiplied locally by 8, and the rebuilt MPCTensor must reconstruct to
    ``secret * 8`` exactly.
    """
    config = Config(encoder_base=1, encoder_precision=0)
    clients = get_clients(parties)
    falcon = Falcon(security)
    session = Session(protocol=falcon, parties=clients, config=config)
    SessionManager.setup_mpc(session)

    secret = torch.tensor([[-100, 20, 30], [-90, 1000, 1], [1032, -323, 15]])
    factor = 8

    shared = MPCTensor(secret=secret, session=session)
    scaled_ptrs = [operator.mul(ptr, factor) for ptr in shared.share_ptrs]
    product = MPCTensor(shares=scaled_ptrs, session=session)

    assert (product.reconstruct() == (secret * factor)).all()
def test_truncation_algorithm1(get_clients, base, precision) -> None:
    """``ABY3.truncate`` matches floor-dividing the encoded value by the scale.

    The expected value is the raw (undecoded) reconstruction floor-divided by
    the encoder scale and then decoded; the comparison uses ``atol=1e-3``.
    """
    clients = get_clients(3)
    falcon = Falcon("semi-honest")
    config = Config(encoder_base=base, encoder_precision=precision)
    session = Session(parties=clients, protocol=falcon, config=config)
    SessionManager.setup_mpc(session)

    plaintext = torch.tensor([[1.24, 4.51, 6.87], [7.87, 1301, 541]])
    shared = MPCTensor(secret=plaintext, session=session)

    truncated = ABY3.truncate(shared, session, session.ring_size, session.config)

    encoder = FixedPointEncoder(
        base=session.config.encoder_base,
        precision=session.config.encoder_precision,
    )
    raw = shared.reconstruct(decode=False)
    expected = encoder.decode(raw // encoder.scale)

    assert np.allclose(truncated.reconstruct(), expected, atol=1e-3)
def bit_decomposition_ttp(x: MPCTensor, session: Session) -> List[MPCTensor]:
    """Perform ABY3 bit decomposition using the orchestrator as a TTP.

    The raw (undecoded) value of ``x`` is reconstructed. For every bit
    position of the session ring, that bit is isolated, split into three
    boolean shares and distributed as a ``ReplicatedSharedTensor`` over ring
    size 2.

    Args:
        x (MPCTensor): Arithmetic shares of secret.
        session (Session): session the share belongs to.

    Returns:
        List[MPCTensor]: Binary shares of each bit of the secret, starting
        with bit 0.

    TODO: We should modify to use parallel prefix adder, which requires multiprocessing.
    """
    # Decoding is not done as they are shares of PRRS.
    plaintext = x.reconstruct(decode=False)
    ring_size = session.ring_size
    add_shares = ReplicatedSharedTensor.shares_sum

    bit_sharings: List[MPCTensor] = []  # binary shares of bits
    for bit_idx in range(get_nr_bits(ring_size)):
        mask = torch.ones(plaintext.shape, dtype=plaintext.dtype) << bit_idx
        bit = (plaintext & mask).type(torch.bool)

        # Two random boolean shares. The third is bit + a + b via shares_sum
        # in ring size 2, so (assuming addition modulo the ring) all three
        # sum back to ``bit``.
        share_a = torch.empty(size=plaintext.shape, dtype=torch.bool).random_(generator=gen)
        share_b = torch.empty(size=plaintext.shape, dtype=torch.bool).random_(generator=gen)
        share_c = add_shares([bit, share_a, share_b], ring_size=2)

        bit_config = Config(encoder_base=1, encoder_precision=0)
        share_ptrs = ReplicatedSharedTensor.distribute_shares(
            shares=[share_a, share_b, share_c],
            session=session,
            ring_size=2,
            config=bit_config,
        )
        bit_sharings.append(MPCTensor(shares=share_ptrs, session=session, shape=plaintext.shape))

    return bit_sharings
def mul_malicious(
    x: MPCTensor,
    y: MPCTensor,
    session: Session,
    op_str: str,
    ring_size: int,
    config: Config,
    **kwargs_: Dict[Any, Any],
) -> MPCTensor:
    """Falcon malicious multiplication.

    The product is first computed with ``Falcon.mul_semi_honest`` (with
    resharing). The parties then compute masks with ``Falcon.falcon_mask``;
    if the primitive store is empty, Beaver primitives are generated and the
    masking is retried. The fetched epsilon/delta shares are reconstructed
    and passed to ``Falcon.triple_verification`` on each result share. The
    product is returned only if the verification output reconstructs to all
    zeros.

    Args:
        x (MPCTensor): Secret
        y (MPCTensor): Another secret
        session (Session): Session the tensors belong to
        op_str (str): Operation string.
        ring_size (int) : Ring size of the underlying tensor.
        config (Config): The configuration(base,precision) of the underlying tensor.
        kwargs_ (Dict[Any, Any]): Kwargs for some operations like conv2d

    Returns:
        result(MPCTensor): Result of the operation.

    Raises:
        ValueError : If the shares are not valid.
    """
    a_shape = tuple(x.shape)
    b_shape = tuple(y.shape)

    result = Falcon.mul_semi_honest(
        x, y, session, op_str, ring_size, config, reshare=True, **kwargs_
    )

    mask_args = [[x_ptr, y_ptr, op_str] for x_ptr, y_ptr in zip(x.share_ptrs, y.share_ptrs)]
    try:
        mask = parallel_execution(Falcon.falcon_mask, session.parties)(mask_args)
    except EmptyPrimitiveStore:
        generator_kwargs = {
            "session": session,
            "a_shape": a_shape,
            "b_shape": b_shape,
            "nr_parties": session.nr_parties,
            "ring_size": ring_size,
            "config": config,
            **kwargs_,
        }
        CryptoPrimitiveProvider.generate_primitives(
            f"beaver_{op_str}",
            session=session,
            g_kwargs=generator_kwargs,
            p_kwargs={"a_shape": a_shape, "b_shape": b_shape},
        )
        mask = parallel_execution(Falcon.falcon_mask, session.parties)(mask_args)

    # Fetch each party's (eps, delta) pair one by one;
    # zip on pointers is compute intensive.
    fetched = [mask[party].get() for party in range(session.nr_parties)]
    eps_shares, delta_shares = zip(*fetched)

    eps_plaintext = ReplicatedSharedTensor.reconstruct(eps_shares)
    delta_plaintext = ReplicatedSharedTensor.reconstruct(delta_shares)

    verify_args = [
        [share_ptr, eps_plaintext, delta_plaintext, op_str] for share_ptr in result.share_ptrs
    ]
    triple_shares = parallel_execution(Falcon.triple_verification, session.parties)(
        verify_args, kwargs_
    )
    check = MPCTensor(shares=triple_shares, session=x.session)

    if not (check.reconstruct(decode=False) == 0).all():
        raise ValueError("Computation Aborted: Malicious behavior.")
    return result
class Linear(SMPCModule):
    """A Linear SMPC Layer.

    ``share_state_dict`` turns a (local or remote) ``torch.nn.Linear`` state
    dict into MPCTensor parameters. ``forward`` then computes
    ``x @ weight.t()`` plus the bias when one was shared.
    """

    __slots__ = [
        "weight",
        "bias",
        "session",
        "in_features",
        "out_features",
        "_parameters",
    ]

    # Unpacked from the (out_features, in_features) weight shape, so both are ints.
    in_features: int
    out_features: int
    weight: MPCTensor
    bias: Optional[MPCTensor]
    # None until share_state_dict populates it.
    _parameters: Optional[Dict[str, MPCTensor]]

    def __init__(self, session) -> None:
        """The initializer for the Linear layer.

        Args:
            session (Session): the session used to identify the layer
        """
        self.bias = None
        self._parameters = None
        self.session = session

    def forward(self, x: MPCTensor) -> MPCTensor:
        """Do a feedforward through the layer.

        Args:
            x (MPCTensor): the input

        Returns:
            An MPCTensor that results by applying the layer specific operation on the input
        """
        res = x @ self.weight.t()
        if self.bias is not None:
            res = res + self.bias

        return res

    __call__ = forward

    def parameters(self, recurse: bool = False) -> MPCTensor:
        """Get the parameters of the Linear module.

        Args:
            recurse (bool): For the moment not used. TODO

        Yields:
            Each parameter of the module. Nothing is yielded before
            ``share_state_dict`` has been called.
        """
        # Previously this raised AttributeError on ``None.values()``.
        if self._parameters is None:
            return
        yield from self._parameters.values()

    def share_state_dict(
        self,
        state_dict: Dict[str, Any],
        additional_attributes: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Share the parameters of the normal Linear layer.

        Args:
            state_dict (Dict[str, Any]): the state dict that would be shared
            additional_attributes (Dict[str, Any]): Attributes apart from weights.
                Not used by this layer; accepted so its signature matches
                other SMPC modules.
        """
        bias = None
        if ispointer(state_dict):
            weight = state_dict["weight"].resolve_pointer_type()
            if "bias" in weight.client.python.List(state_dict).get():
                bias = state_dict["bias"].resolve_pointer_type()
            shape = weight.client.python.Tuple(weight.shape)
            shape = shape.get()
        else:
            weight = state_dict["weight"]
            bias = state_dict.get("bias")
            shape = state_dict["weight"].shape

        self.out_features, self.in_features = shape
        self.weight = MPCTensor(secret=weight, session=self.session, shape=shape, requires_grad=True)
        self._parameters = OrderedDict({"weight": self.weight})

        if bias is not None:
            self.bias = MPCTensor(
                secret=bias,
                session=self.session,
                shape=(self.out_features,),
                requires_grad=True,
            )
            self._parameters["bias"] = self.bias

    def reconstruct_state_dict(self) -> Dict[str, Any]:
        """Reconstruct the shared state dict.

        Returns:
            The reconstructed state dict (Dict[str, Any])
        """
        state_dict = OrderedDict()
        state_dict["weight"] = self.weight.reconstruct()

        if self.bias is not None:
            state_dict["bias"] = self.bias.reconstruct()

        return state_dict

    @staticmethod
    def get_torch_module(linear_module: "Linear") -> torch.nn.Module:
        """Get a torch module from a given MPC Layer module.

        The parameters of the models are not set.

        Args:
            linear_module (Linear): the MPC Linear layer

        Returns:
            A torch Linear module
        """
        bias = linear_module.bias is not None
        module = torch.nn.Linear(
            in_features=linear_module.in_features,
            out_features=linear_module.out_features,
            bias=bias,
        )
        return module
def mul_master(x: MPCTensor, y: MPCTensor, op_str: str, kwargs_: Dict[Any, Any]) -> MPCTensor:
    """Function that is executed by the orchestrator to multiply two secret values.

    Beaver primitives for ``op_str`` are generated for the operand shapes.
    Each party then computes its masked (epsilon, delta) shares with
    ``spdz_mask``. The orchestrator opens both values, and each party
    finishes the product with ``mul_parties``.

    Args:
        x (MPCTensor): First value to multiply with.
        y (MPCTensor): Second value to multiply with.
        op_str (str): Operation string.
        kwargs_ (dict): Extra keyword arguments forwarded to primitive
            generation and to ``mul_parties``.

    Raises:
        ValueError: If op_str not in EXPECTED_OPS.

    Returns:
        MPCTensor: Result of the multiplication.
    """
    if op_str not in EXPECTED_OPS:
        raise ValueError(f"{op_str} should be in {EXPECTED_OPS}")

    session = x.session
    a_shape = tuple(x.shape)
    b_shape = tuple(y.shape)

    generator_kwargs = {
        "a_shape": a_shape,
        "b_shape": b_shape,
        "nr_parties": session.nr_parties,
        **kwargs_,
    }
    CryptoPrimitiveProvider.generate_primitives(
        f"beaver_{op_str}",
        sessions=session.session_ptrs,
        g_kwargs=generator_kwargs,
        p_kwargs={"a_shape": a_shape, "b_shape": b_shape},
    )

    mask_args = [
        [session_ptr, x_ptr, y_ptr, op_str]
        for session_ptr, x_ptr, y_ptr in zip(session.session_ptrs, x.share_ptrs, y.share_ptrs)
    ]
    masked = parallel_execution(spdz_mask, session.parties)(mask_args)

    # Each party returned an (epsilon, delta) pair; regroup by component.
    eps_shares, delta_shares = zip(*masked)
    eps_tensor = MPCTensor(shares=eps_shares, session=session)
    delta_tensor = MPCTensor(shares=delta_shares, session=session)

    eps_plaintext = eps_tensor.reconstruct(decode=False)
    delta_plaintext = delta_tensor.reconstruct(decode=False)

    # Same opened values for everyone, prefixed by each party's session pointer.
    shared_tail = [eps_plaintext, delta_plaintext, op_str]
    party_args = [[session_ptr, *shared_tail] for session_ptr in session.session_ptrs]

    product_shares = parallel_execution(mul_parties, session.parties)(party_args, kwargs_)
    return MPCTensor(shares=product_shares, session=session)
def fss_op(x1: MPCTensor, x2: MPCTensor, op="eq") -> MPCTensor:
    """Define the workflow for a binary operation using Function Secret Sharing.

    Currently supported operations are equality (``op='eq'``) and
    less-or-equal (``op='comp'``).

    When CUDA is available, ``CUDA_VISIBLE_DEVICES`` is set to ``""`` for the
    duration of the call (with a warning). It is restored afterwards, even if
    an exception is raised; it is removed again if it was originally unset.

    Args:
        x1 (MPCTensor): First private value.
        x2 (MPCTensor): Second private value.
        op: Type of operation to perform, should be 'eq' or 'comp'. Defaults to eq.

    Returns:
        MPCTensor: Shares of the comparison.
    """
    # FSS is currently not supported on GPU.
    # https://stackoverflow.com/a/62145307/8878627
    cuda_disabled = th.cuda.is_available()
    previous_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if cuda_disabled:
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        warnings.warn("Temporarily disabling CUDA as FSS does not support it")

    try:
        # FIXME: Better handle the case where x1 or x2 is not a MPCTensor. For the moment
        # FIXME: we cast it into a MPCTensor at the expense of extra communication
        session = x1.session

        shape = MPCTensor._get_shape("sub", x1.shape, x2.shape)
        n_values = shape.numel()

        CryptoPrimitiveProvider.generate_primitives(
            f"fss_{op}",
            sessions=session.session_ptrs,
            g_kwargs={"n_values": n_values},
            p_kwargs={},
        )

        args = zip(session.session_ptrs, x1.share_ptrs, x2.share_ptrs)
        args = [list(el) + [op] for el in args]

        shares = parallel_execution(mask_builder, session.parties)(args)

        # TODO: don't do .reconstruct(), this should be done remotely between the evaluators
        mask_value = MPCTensor(shares=shares, session=session)
        # NOTE(review): ``n`` is not defined in this function; presumably a
        # module-level bit width for the FSS ring — confirm.
        mask_value = mask_value.reconstruct(decode=False) % 2**n

        # TODO: add dtype to args
        args = [(session.session_ptrs[i], th.IntTensor([i]), mask_value, op) for i in range(2)]

        shares = parallel_execution(evaluate, session.parties)(args)

        response = MPCTensor(session=session, shares=shares, shape=shape)
        response.shape = shape
        return response
    finally:
        # Previously an originally-unset variable was left as "" (CUDA stayed
        # hidden for the rest of the process), and nothing was restored on error.
        if cuda_disabled:
            if previous_visible_devices is None:
                os.environ.pop("CUDA_VISIBLE_DEVICES", None)
            else:
                os.environ["CUDA_VISIBLE_DEVICES"] = previous_visible_devices
class Conv2d(SMPCModule):
    """Convolutional 2D.

    ``share_state_dict`` turns a ``torch.nn.Conv2d`` state dict into
    MPCTensor parameters. Stride, padding, dilation and groups default to
    1, 0, 1 and 1 and can be overridden through ``additional_attributes``.
    """

    __slots__ = (
        "session",
        "weight",
        "bias",
        "kernel_size",
        "stride",
        "padding",
        "dilation",
        "groups",
        "_parameters",
    )

    in_channels: int
    out_channels: int
    kernel_size: Tuple[int, ...]
    stride: Tuple[int, ...]
    padding: Tuple[int, ...]
    dilation: Tuple[int, ...]
    groups: int
    weight: List[MPCTensor]
    bias: Optional[MPCTensor]
    _parameters: OrderedDict

    def __init__(self, session: Session) -> None:
        """Initialize Conv2d layer.

        The stride, padding, dilation and groups are hardcoded for the moment.

        Args:
            session (Session): the session used to identify the layer
        """
        self.session = session
        self.in_channels = None
        self.out_channels = None
        self.stride = 1
        self.padding = 0
        self.dilation = 1
        self.groups = 1
        # Defined up front so forward/reconstruct_state_dict/get_torch_module
        # work when the shared state dict has no "bias" entry.
        self.bias = None
        self._parameters = None

    def forward(self, x: MPCTensor) -> MPCTensor:
        """Do a feedforward through the layer.

        Args:
            x (MPCTensor): the input

        Returns:
            An MPCTensor representing the layer specific operation applied on the input
        """
        res = x.conv2d(
            weight=self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )

        return res

    __call__ = forward

    def set_additional_attributes(self, attributes: Dict) -> None:
        """Sets attributes of conv apart from weights.

        Args:
            attributes (Dict): Attributes with their values.

        Raises:
            ValueError: If the attribute does not exist.
        """
        for attr in attributes.keys():
            if hasattr(self, attr):
                setattr(self, attr, attributes[attr])
            else:
                raise ValueError(f"Attribute {attr} does not exist in SyMPC module.")

    def share_state_dict(
        self,
        state_dict: Dict[str, Any],
        additional_attributes: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Share the parameters of the normal Conv2d layer.

        Args:
            state_dict (Dict[str, Any]): the state dict that would be shared.
            additional_attributes (Dict[str, Any]): Attributes of conv apart from weights.
                Optional; when omitted the current attributes are kept.
        """
        bias = None
        if ispointer(state_dict):
            weight = state_dict["weight"].resolve_pointer_type()
            if "bias" in weight.client.python.List(state_dict).get():
                bias = state_dict["bias"].resolve_pointer_type()
            shape = weight.client.python.Tuple(weight.shape)
            shape = shape.get()
        else:
            weight = state_dict["weight"]
            bias = state_dict.get("bias")
            shape = state_dict["weight"].shape

        if ispointer(additional_attributes):
            # NOTE(review): .get() runs before resolve_pointer_type(); confirm
            # the fetched object supports resolve_pointer_type().
            self.set_additional_attributes(additional_attributes.get().resolve_pointer_type())
        elif additional_attributes is not None:
            # The default None used to reach set_additional_attributes and
            # fail on ``None.keys()``.
            self.set_additional_attributes(additional_attributes)

        # Weight shape (out_channel, in_channels/groups, kernel_size_w, kernel_size_h)
        # we have groups == 1
        (
            self.out_channels,
            self.in_channels,
            kernel_size_w,
            kernel_size_h,
        ) = shape

        self.kernel_size = (kernel_size_w, kernel_size_h)

        self.weight = MPCTensor(secret=weight, session=self.session, shape=shape)
        self._parameters = OrderedDict({"weight": self.weight})

        if bias is not None:
            self.bias = MPCTensor(secret=bias, session=self.session, shape=(self.out_channels,))
            self._parameters["bias"] = self.bias

    def parameters(self, recurse: bool = False) -> MPCTensor:
        """Get the parameters of the Conv2d module.

        Args:
            recurse (bool): For the moment not used. TODO

        Yields:
            Each parameter of the module. Nothing is yielded before
            ``share_state_dict`` has been called.
        """
        # Previously this raised AttributeError on ``None.values()``.
        if self._parameters is None:
            return
        yield from self._parameters.values()

    def reconstruct_state_dict(self) -> Dict[str, Any]:
        """Reconstruct the shared state dict.

        Returns:
            Dict[str, Any]: The reconstructed state dict.
        """
        state_dict = OrderedDict()
        state_dict["weight"] = self.weight.reconstruct()

        if self.bias is not None:
            state_dict["bias"] = self.bias.reconstruct()

        return state_dict

    @staticmethod
    def get_torch_module(conv_module: "Conv2d") -> torch.nn.Module:
        """Get a torch module from a given MPC Conv2d module.

        The parameters of the models are not set.

        Args:
            conv_module (Conv2d): the MPC Conv2d layer

        Returns:
            torch.nn.Module: A torch Conv2d module.
        """
        bias = conv_module.bias is not None
        module = torch.nn.Conv2d(
            in_channels=conv_module.in_channels,
            out_channels=conv_module.out_channels,
            kernel_size=conv_module.kernel_size,
            stride=conv_module.stride,
            padding=conv_module.padding,
            dilation=conv_module.dilation,
            groups=conv_module.groups,
            bias=bias,
        )
        return module
class Conv2d(SMPCModule):
    """Convolutional 2D.

    ``share_state_dict`` turns a ``torch.nn.Conv2d`` state dict into
    MPCTensor parameters. Stride, padding, dilation and groups are fixed at
    1, 0, 1 and 1, and only square kernels are accepted.
    """

    __slots__ = [
        "session",
        "weight",
        "bias",
        "kernel_size",
        "stride",
        "padding",
        "dilation",
        "groups",
    ]

    in_channels: int
    out_channels: int
    kernel_size: Tuple[int, ...]
    stride: Tuple[int, ...]
    padding: Tuple[int, ...]
    dilation: Tuple[int, ...]
    groups: int
    weight: List[MPCTensor]
    bias: Optional[MPCTensor]

    def __init__(self, session: Session) -> None:
        """Initialize Conv2d layer.

        The stride, padding, dilation and groups are hardcoded for the moment.

        Args:
            session (Session): the session used to identify the layer
        """
        self.session = session
        self.stride = 1
        self.padding = 0
        self.dilation = 1
        self.groups = 1
        # Defined up front so forward/reconstruct_state_dict/get_torch_module
        # work when the shared state dict has no "bias" entry; it was
        # previously left unset in that case.
        self.bias = None

    def forward(self, x: MPCTensor) -> MPCTensor:
        """Do a feedforward through the layer.

        Args:
            x (MPCTensor): the input

        Returns:
            An MPCTensor representing the layer specific operation applied on the input
        """
        res = x.conv2d(
            weight=self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )

        return res

    __call__ = forward

    def share_state_dict(
        self,
        state_dict: Dict[str, Any],
    ) -> None:
        """Share the parameters of the normal Conv2d layer.

        Args:
            state_dict (Dict[str, Any]): the state dict that would be shared.

        Raises:
            ValueError: If kernel sizes mismatch "kernel_size_w" and "kernel_size_h"
        """
        bias = None
        if ispointer(state_dict):
            weight = state_dict["weight"].resolve_pointer_type()
            if "bias" in weight.client.python.List(state_dict).get():
                bias = state_dict["bias"].resolve_pointer_type()
            shape = weight.client.python.Tuple(weight.shape)
            shape = shape.get()
        else:
            weight = state_dict["weight"]
            bias = state_dict.get("bias")
            shape = state_dict["weight"].shape

        # Weight shape (out_channel, in_channels/groups, kernel_size_w, kernel_size_h)
        # we have groups == 1
        (
            self.out_channels,
            self.in_channels,
            kernel_size_w,
            kernel_size_h,
        ) = shape

        if kernel_size_w != kernel_size_h:
            raise ValueError(f"Kernel sizes mismatch {kernel_size_w} and {kernel_size_h}")

        self.kernel_size = kernel_size_w

        self.weight = MPCTensor(secret=weight, session=self.session, shape=shape)

        if bias is not None:
            self.bias = MPCTensor(secret=bias, session=self.session, shape=(self.out_channels,))

    def reconstruct_state_dict(self) -> Dict[str, Any]:
        """Reconstruct the shared state dict.

        Returns:
            Dict[str, Any]: The reconstructed state dict.
        """
        state_dict = OrderedDict()
        state_dict["weight"] = self.weight.reconstruct()

        if self.bias is not None:
            state_dict["bias"] = self.bias.reconstruct()

        return state_dict

    @staticmethod
    def get_torch_module(conv_module: "Conv2d") -> torch.nn.Module:
        """Get a torch module from a given MPC Conv2d module.

        The parameters of the models are not set.

        Args:
            conv_module (Conv2d): the MPC Conv2d layer

        Returns:
            torch.nn.Module: A torch Conv2d module.
        """
        bias = conv_module.bias is not None
        # Pass the same stride/padding/dilation/groups that forward() uses,
        # rather than relying on torch's defaults happening to match them.
        module = torch.nn.Conv2d(
            in_channels=conv_module.in_channels,
            out_channels=conv_module.out_channels,
            kernel_size=conv_module.kernel_size,
            stride=conv_module.stride,
            padding=conv_module.padding,
            dilation=conv_module.dilation,
            groups=conv_module.groups,
            bias=bias,
        )
        return module
class Linear(SMPCModule):
    """Secret-shared counterpart of ``torch.nn.Linear``.

    ``share_state_dict`` turns a (local or remote) state dict into MPCTensor
    parameters. ``forward`` computes ``x @ weight.T`` plus the bias when one
    was shared.
    """

    __slots__ = ["weight", "bias", "session", "in_features", "out_features"]

    in_features: Tuple[int]
    out_features: Tuple[int]
    weight: MPCTensor
    bias: Optional[MPCTensor]

    def __init__(self, session) -> None:
        """Create an empty layer bound to ``session``; no bias until one is shared.

        Args:
            session (Session): the session used to identify the layer
        """
        self.bias = None
        self.session = session

    def forward(self, x: MPCTensor) -> MPCTensor:
        """Apply the shared affine map to ``x``.

        Args:
            x (MPCTensor): the input

        Returns:
            An MPCTensor equal to ``x @ weight.T`` (plus bias, if present)
        """
        projected = x @ self.weight.T
        return projected if self.bias is None else projected + self.bias

    __call__ = forward

    def share_state_dict(
        self,
        state_dict: Dict[str, Any],
    ) -> None:
        """Secret-share the weight (and optional bias) from ``state_dict``.

        Remote (pointer) state dicts are resolved, and their shape is fetched
        before sharing.

        Args:
            state_dict (Dict[str, Any]): the state dict that would be shared
        """
        if ispointer(state_dict):
            weight = state_dict["weight"].resolve_pointer_type()
            remote_keys = weight.client.python.List(state_dict).get()
            bias = state_dict["bias"].resolve_pointer_type() if "bias" in remote_keys else None
            shape = weight.client.python.Tuple(weight.shape).get()
        else:
            weight = state_dict["weight"]
            bias = state_dict.get("bias")
            shape = weight.shape

        self.out_features, self.in_features = shape
        self.weight = MPCTensor(secret=weight, session=self.session, shape=shape)

        if bias is not None:
            self.bias = MPCTensor(secret=bias, session=self.session, shape=(self.out_features,))

    def reconstruct_state_dict(self) -> Dict[str, Any]:
        """Open the shared parameters back into a plain state dict.

        Returns:
            The reconstructed state dict (Dict[str, Any]) with "weight" and,
            when present, "bias"
        """
        state_dict = OrderedDict(weight=self.weight.reconstruct())
        if self.bias is not None:
            state_dict["bias"] = self.bias.reconstruct()
        return state_dict

    @staticmethod
    def get_torch_module(linear_module: "Linear") -> torch.nn.Module:
        """Build an uninitialised ``torch.nn.Linear`` with the same dimensions.

        The parameters of the models are not set.

        Args:
            linear_module (Linear): the MPC Linear layer

        Returns:
            A torch Linear module
        """
        return torch.nn.Linear(
            in_features=linear_module.in_features,
            out_features=linear_module.out_features,
            bias=linear_module.bias is not None,
        )