def fix_prec(self, *args, storage="auto", field_type="int100", **kwargs):
    if not kwargs.get("owner"):
        kwargs["owner"] = self.owner

    if self.is_wrapper:
        self.child = self.child.fix_prec(*args, **kwargs)
        return self

    base = kwargs.get("base", 10)
    prec_fractional = kwargs.get("precision_fractional", 3)
    max_precision = _get_maximum_precision()
    need_large_prec = self._requires_large_precision(max_precision, base, prec_fractional)

    if storage == "crt":
        assert (
            "field" not in kwargs
        ), 'When storage is set to "crt", choose the field size with the field_type argument'

        possible_field_types = list(_moduli_for_fields.keys())
        assert (
            field_type in possible_field_types
        ), f"Choose field_type in {possible_field_types} to build CRT tensors"

        residues = {}
        for mod in _moduli_for_fields[field_type]:
            residues[mod] = (
                syft.FixedPrecisionTensor(*args, field=mod, **kwargs)
                .on(self)
                .child.fix_precision(check_range=False)
                .wrap()
            )

        return syft.CRTPrecisionTensor(residues, *args, **kwargs).wrap()

    if need_large_prec or storage == "large":
        return (
            syft.LargePrecisionTensor(*args, **kwargs)
            .on(self)
            .child.fix_large_precision()
            .wrap()
        )
    else:
        assert not need_large_prec, "This tensor needs large precision to be correctly stored"
        if "internal_type" in kwargs:
            warnings.warn(
                "do not provide internal_type if data does not need LargePrecisionTensor to be stored"
            )
            del kwargs["internal_type"]
        return syft.FixedPrecisionTensor(*args, **kwargs).on(self).enc_fix_prec()
def fix_prec(self, *args, no_wrap: bool = False, **kwargs):
    """
    Convert a tensor or syft tensor to fixed precision

    Args:
        *args (tuple): args to transmit to the fixed precision tensor
        no_wrap (bool): if True, we don't add a wrapper on top of the fixed precision tensor
        **kwargs (dict): kwargs to transmit to the fixed precision tensor
    """
    if not kwargs.get("owner"):
        kwargs["owner"] = self.owner

    if self.is_wrapper:
        child = self.child.fix_prec(*args, **kwargs)
        if no_wrap:
            return child
        else:
            return child.wrap()

    base = kwargs.get("base", 10)
    prec_fractional = kwargs.get("precision_fractional", 3)
    max_precision = _get_maximum_precision()

    fpt_tensor = syft.FixedPrecisionTensor(*args, **kwargs).on(self, wrap=False).fix_precision()

    if not no_wrap:
        fpt_tensor = fpt_tensor.wrap()
    return fpt_tensor
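# Hedged usage sketch for the fix_prec variant above. It assumes the PySyft 0.2.x
# hooked-torch API (sy.TorchHook adds fix_prec/float_prec to torch tensors); the
# tensor values below are illustrative only and do not come from the source.
import torch
import syft as sy

hook = sy.TorchHook(torch)

x = torch.tensor([1.25, -2.5, 3.0])
x_fp = x.fix_prec(precision_fractional=3)  # wrapped FixedPrecisionTensor
x_child = x.fix_prec(no_wrap=True)         # the FixedPrecisionTensor itself, no torch wrapper
x_back = x_fp.float_prec()                 # decode back to a float tensor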
def fix_prec(self, *args, **kwargs):
    base = kwargs.get("base", 10)
    prec_fractional = kwargs.get("precision_fractional", 3)
    max_precision = _get_maximum_precision()
    if self._requires_large_precision(max_precision, base, prec_fractional):
        return (
            syft.LargePrecisionTensor(*args, **kwargs)
            .on(self)
            .child.fix_large_precision()
            .wrap()
        )
    else:
        return syft.FixedPrecisionTensor(*args, **kwargs).on(self).enc_fix_prec().wrap()
def fix_prec(self, *args, no_wrap: bool = False, **kwargs):
    """
    Convert a tensor or syft tensor to fixed precision

    Args:
        *args (tuple): args to transmit to the fixed precision tensor
        no_wrap (bool): if True, we don't add a wrapper on top of the fixed precision tensor
        **kwargs (dict): kwargs to transmit to the fixed precision tensor
    """
    if not kwargs.get("owner"):
        kwargs["owner"] = self.owner

    wrap_dtype = None
    if kwargs.get("protocol") == "falcon":
        wrap_dtype = torch.int64
        del kwargs["protocol"]

    if self.is_wrapper:
        child = self.child.fix_prec(*args, **kwargs)
        if no_wrap:
            return child
        else:
            return child.wrap()

    base = kwargs.get("base", 9)
    prec_fractional = kwargs.get("precision_fractional", 3)
    max_precision = _get_maximum_precision()

    fpt_tensor = syft.FixedPrecisionTensor(*args, **kwargs).on(self, wrap=False).fix_precision()

    if not no_wrap:
        fpt_tensor = fpt_tensor.wrap(type=wrap_dtype)  # hhk: the old version was fpt_tensor.wrap()
    return fpt_tensor
def fix_prec(self, *args, storage="auto", field_type="int100", no_wrap: bool = False, **kwargs):
    """
    Convert a tensor or syft tensor to fixed precision

    Args:
        *args (tuple): args to transmit to the fixed precision tensor
        storage (str): code to define the type of fixed precision tensor
            (values in ("auto", "crt", "large"))
        field_type (str): code to define a storage type (only for CRTPrecisionTensor)
        no_wrap (bool): if True, we don't add a wrapper on top of the fixed precision tensor
        **kwargs (dict): kwargs to transmit to the fixed precision tensor
    """
    if not kwargs.get("owner"):
        kwargs["owner"] = self.owner

    if self.is_wrapper:
        self.child = self.child.fix_prec(*args, **kwargs)
        if no_wrap:
            return self.child
        else:
            return self

    base = kwargs.get("base", 10)
    prec_fractional = kwargs.get("precision_fractional", 3)
    max_precision = _get_maximum_precision()
    need_large_prec = self._requires_large_precision(max_precision, base, prec_fractional)

    if storage == "crt":
        assert (
            "field" not in kwargs
        ), 'When storage is set to "crt", choose the field size with the field_type argument'

        possible_field_types = list(_moduli_for_fields.keys())
        assert (
            field_type in possible_field_types
        ), f"Choose field_type in {possible_field_types} to build CRT tensors"

        residues = {}
        for mod in _moduli_for_fields[field_type]:
            residues[mod] = (
                syft.FixedPrecisionTensor(*args, field=mod, **kwargs)
                .on(self, wrap=False)
                .fix_precision(check_range=False)
                .wrap()
            )

        fpt_tensor = syft.CRTPrecisionTensor(residues, *args, **kwargs)

    elif need_large_prec or storage == "large":
        fpt_tensor = (
            syft.LargePrecisionTensor(*args, **kwargs).on(self, wrap=False).fix_large_precision()
        )

    else:
        assert not need_large_prec, "This tensor needs large precision to be correctly stored"
        if "internal_type" in kwargs:
            warnings.warn(
                "do not provide internal_type if data does not need LargePrecisionTensor to be stored"
            )
            del kwargs["internal_type"]
        fpt_tensor = (
            syft.FixedPrecisionTensor(*args, **kwargs).on(self, wrap=False).fix_precision()
        )

    if not no_wrap:
        fpt_tensor = fpt_tensor.wrap()
    return fpt_tensor
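# Hedged sketch of the storage / field_type options handled by the variant above
# (same hooked-torch API assumption as before; whether a given build exposes the
# "crt" and "large" storage back-ends depends on the PySyft version, so treat this
# purely as an illustration of the keyword arguments):
import torch
import syft as sy

hook = sy.TorchHook(torch)

y = torch.tensor([1.25, -2.5, 3.0])

y_auto = y.fix_prec()                                   # storage="auto": plain FixedPrecisionTensor
y_crt = y.fix_prec(storage="crt", field_type="int100")  # CRT residues, field chosen via field_type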
def _pool2d(
    input, kernel_size: int = 2, stride: int = 2, padding=0, dilation=1, ceil_mode=None, mode="avg"
):
    if isinstance(kernel_size, tuple):
        assert kernel_size[0] == kernel_size[1]
        kernel_size = kernel_size[0]
    if isinstance(stride, tuple):
        assert stride[0] == stride[1]
        stride = stride[0]

    input_fp = input
    input = input.child

    locations = input.locations

    im_reshaped_shares = {}
    params = {}
    for location in locations:
        input_share = input.child[location.id]
        im_reshaped_shares[location.id], *params[location.id] = remote(
            _pre_pool, location=location
        )(input_share, kernel_size, stride, padding, dilation, return_value=False, return_arity=6)

    im_reshaped = sy.AdditiveSharingTensor(im_reshaped_shares, **input.get_class_attributes())

    if mode == "max":
        # We have optimisations when the kernel is small, namely a square of size 2 or 3,
        # to reduce the number of rounds and the total number of comparisons.
        # See more in Appendix C.3 of https://arxiv.org/pdf/2006.04593.pdf
        def max_half_split(tensor4d, half_size):
            """Split the tensor into 2 halves on the last dim and return the elementwise maximum."""
            left, right = tensor4d[:, :, :, :half_size], tensor4d[:, :, :, half_size:]
            max_half = left + (right >= left) * (right - left)
            return max_half

        if im_reshaped.shape[-1] == 4:
            # Compute the max as a binary tree: 2 steps are needed for 4 values
            res = max_half_split(im_reshaped, 2)
            res = max_half_split(res, 1)
        elif im_reshaped.shape[-1] == 9:
            # For 9 values we need 4 steps: we process the first 8 values and then
            # compute the max with the 9th value
            res = max_half_split(im_reshaped[:, :, :, :8], 4)
            res = max_half_split(res, 2)
            left = max_half_split(res, 1)
            right = im_reshaped[:, :, :, 8:]
            res = left + (right >= left) * (right - left)
        else:
            res = im_reshaped.max(dim=-1)
    elif mode == "avg":
        res = im_reshaped.mean(dim=-1)
    else:
        raise ValueError(f"In pool2d, mode should be avg or max, not {mode}.")

    res_shares = {}
    for location in locations:
        res_share = res.child[location.id]
        res_share = remote(_post_pool, location=location)(res_share, *params[location.id])
        res_shares[location.id] = res_share

    result_fp = sy.FixedPrecisionTensor(**input_fp.get_class_attributes()).on(
        sy.AdditiveSharingTensor(res_shares, **res.get_class_attributes()), wrap=False
    )
    return result_fp
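# A plain-torch sketch of the split-and-compare max reduction used in _pool2d above.
# Illustrative only: here the tensor is public, whereas in _pool2d each ">=" is a
# private comparison on shares, which is why the binary-tree reduction that minimises
# the number of comparison rounds matters.
import torch

def max_half_split(tensor4d, half_size):
    left, right = tensor4d[..., :half_size], tensor4d[..., half_size:]
    # max(left, right) expressed without max(): left + (right >= left) * (right - left)
    return left + (right >= left) * (right - left)

t = torch.randint(0, 100, (1, 1, 1, 4))
r = max_half_split(t, 2)   # 4 values -> 2 (one round of comparisons)
r = max_half_split(r, 1)   # 2 values -> 1 (one more round)
assert torch.equal(r.squeeze(-1), t.max(dim=-1).values)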
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """
    Overloads torch.nn.functional.conv2d to be able to use MPC on convolutional networks.
    The idea is to unroll the input and weight matrices to compute a matrix multiplication
    equivalent to the convolution.

    Args:
        input: input image
        weight: convolution kernels
        bias: optional additive bias
        stride: stride of the convolution kernels
        padding: implicit paddings on both sides of the input
        dilation: spacing between kernel elements
        groups: split the input into groups; in_channels should be divisible by the number of groups

    Returns:
        the result of the convolution as a fixed precision tensor.
    """
    input_fp, weight_fp = input, weight

    if isinstance(input.child, FrameworkTensor) or isinstance(weight.child, FrameworkTensor):
        assert isinstance(input.child, FrameworkTensor)
        assert isinstance(weight.child, FrameworkTensor)

        im_reshaped, weight_reshaped, *params = _pre_conv(
            input, weight, bias, stride, padding, dilation, groups
        )

        if groups > 1:
            res = []
            chunks_im = torch.chunk(im_reshaped, groups, dim=2)
            chunks_weights = torch.chunk(weight_reshaped, groups, dim=0)
            for g in range(groups):
                tmp = chunks_im[g].matmul(chunks_weights[g])
                res.append(tmp)
            result = torch.cat(res, dim=2)
        else:
            result = im_reshaped.matmul(weight_reshaped)

        result = _post_conv(bias, result, *params)
        return result.wrap()

    input, weight = input.child, weight.child
    if bias is not None:
        bias = bias.child
        assert isinstance(
            bias, sy.AdditiveSharingTensor
        ), "Have you provided bias as a kwarg? If so, please remove `bias=`."

    locations = input.locations

    im_reshaped_shares = {}
    weight_reshaped_shares = {}
    params = {}
    for location in locations:
        input_share = input.child[location.id]
        weight_share = weight.child[location.id]
        bias_share = bias.child[location.id] if bias is not None else None
        (
            im_reshaped_shares[location.id],
            weight_reshaped_shares[location.id],
            *params[location.id],
        ) = remote(_pre_conv, location=location)(
            input_share,
            weight_share,
            bias_share,
            stride,
            padding,
            dilation,
            groups,
            return_value=False,
            return_arity=6,
        )

    im_reshaped = sy.FixedPrecisionTensor(**input_fp.get_class_attributes()).on(
        sy.AdditiveSharingTensor(im_reshaped_shares, **input.get_class_attributes()), wrap=False
    )
    weight_reshaped = sy.FixedPrecisionTensor(**weight_fp.get_class_attributes()).on(
        sy.AdditiveSharingTensor(weight_reshaped_shares, **input.get_class_attributes()), wrap=False
    )

    # Now that everything is set up, we can compute the convolution as a simple matmul
    if groups > 1:
        res = []
        chunks_im = torch.chunk(im_reshaped, groups, dim=2)
        chunks_weights = torch.chunk(weight_reshaped, groups, dim=0)
        for g in range(groups):
            tmp = chunks_im[g].matmul(chunks_weights[g])
            res.append(tmp)
        res_fp = torch.cat(res, dim=2)
        res = res_fp.child
    else:
        res_fp = im_reshaped.matmul(weight_reshaped)
        res = res_fp.child

    # and then we reshape the result
    res_shares = {}
    for location in locations:
        bias_share = bias.child[location.id] if bias is not None else None
        res_share = res.child[location.id]
        res_share = remote(_post_conv, location=location)(
            bias_share, res_share, *params[location.id]
        )
        res_shares[location.id] = res_share

    result_fp = sy.FixedPrecisionTensor(**res_fp.get_class_attributes()).on(
        sy.AdditiveSharingTensor(res_shares, **res.get_class_attributes()), wrap=False
    )
    return result_fp
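# A plain-torch sketch of the "unroll + matmul" trick the conv2d docstring above
# describes. Illustrative only: it uses public tensors and torch.nn.functional.unfold
# rather than the _pre_conv/_post_conv helpers that operate on secret shares.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)  # (batch, in_channels, H, W)
w = torch.randn(4, 3, 3, 3)  # (out_channels, in_channels, kH, kW)

cols = F.unfold(x, kernel_size=3)  # im2col: (1, 3*3*3, 36), one column per output position
w_mat = w.view(4, -1)              # (4, 3*3*3)
out = w_mat.matmul(cols)           # the convolution as a single matmul: (1, 4, 36)
out = out.view(1, 4, 6, 6)         # reshape back to the spatial output

assert torch.allclose(out, F.conv2d(x, w), atol=1e-5)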