def from_source(cls, tensor_name, source, **kwargs):
    """Create a secret-shared FixedPointTensor from a local table or a peer's share.

    When ``source`` is a table it is encoded with the encoder and split into
    additive shares modulo ``q_field``. Random masks are sent to every other
    party, and this party keeps ``encoded - last_mask``. When ``source`` is a
    ``Party``, the share that party sent under ``tensor_name`` is received.

    Keyword arguments:
        q_field: field modulus (defaults to ``spdz.q_field``).
        encoder: encoder to use. If absent, a ``FixedPointEndec`` is built from
            ``base`` (default 10) and ``frac`` (default 4).

    Raises:
        ValueError: if ``source`` is neither a table nor a ``Party``.
    """
    spdz = cls.get_spdz()
    q_field = kwargs.get('q_field', spdz.q_field)
    if 'encoder' in kwargs:
        encoder = kwargs['encoder']
    else:
        encoder = FixedPointEndec(n=q_field, field=q_field,
                                  base=kwargs.get('base', 10),
                                  precision_fractional=kwargs.get('frac', 4))

    if is_table(source):
        encoded = encoder.encode(source)
        peers = spdz.other_parties
        previous_mask = urand_tensor(q_field, encoded, use_mix=spdz.use_mix_rand)
        spdz.communicator.remote_share(share=previous_mask, tensor_name=tensor_name, party=peers[0])
        for peer in peers[1:]:
            next_mask = urand_tensor(q_field, encoded, use_mix=spdz.use_mix_rand)
            # Each peer receives the difference of consecutive masks, so the
            # shares telescope back to ``encoded`` when summed mod q_field.
            delta = _table_binary_mod_op(next_mask, previous_mask, q_field, operator.sub)
            spdz.communicator.remote_share(share=delta, tensor_name=tensor_name, party=peer)
            previous_mask = next_mask
        share = _table_binary_mod_op(encoded, previous_mask, q_field, operator.sub)
    elif isinstance(source, Party):
        share = spdz.communicator.get_share(tensor_name=tensor_name, party=source)[0]
    else:
        raise ValueError(f"type={type(source)}")
    return FixedPointTensor(share, q_field, encoder, tensor_name)
def decrypt_tensor(tensor, private_key, otypes):
    """Decrypt every element of an ndarray, or of each value in a table.

    Args:
        tensor: ``np.ndarray`` or table whose values are arrays of ciphertexts.
        private_key: object exposing ``decrypt(ciphertext)``.
        otypes: forwarded to ``np.vectorize`` as the output type spec.

    Raises:
        NotImplementedError: for any other tensor type.
    """
    def _decrypt_array(cipher_array):
        return np.vectorize(private_key.decrypt, otypes)(cipher_array)

    if isinstance(tensor, np.ndarray):
        return _decrypt_array(tensor)
    if is_table(tensor):
        return tensor.mapValues(_decrypt_array)
    raise NotImplementedError(f"type={type(tensor)}")
def __rsub__(self, other):
    """Compute ``other - self`` element-wise and box the result.

    Args:
        other: a ``PaillierFixedPointTensor``/``FixedPointTensor`` (its ``value``
            is used), a table, or a scalar.

    Returns:
        A new tensor produced by ``self._boxed``.
    """
    def _reversed_sub(element, scalar):
        # The previous code passed ``-1 * operator.sub``, which raises
        # TypeError because a function cannot be multiplied by an int.
        # NOTE(review): relies on _table_scalar_op applying op(element, scalar);
        # confirm against its definition.
        return scalar - element

    if isinstance(other, (PaillierFixedPointTensor, FixedPointTensor)):
        return self._boxed(_table_binary_op(other.value, self.value, operator.sub))
    if is_table(other):
        return self._boxed(_table_binary_op(other, self.value, operator.sub))
    return self._boxed(_table_scalar_op(self.value, other, _reversed_sub))
def _transform_op(cls, tensor, op):
    """Apply ``op`` to a scalar, or to every element of an ndarray or table.

    Scalars (Python/NumPy ints and floats, ``FixedPointNumber``) are passed to
    ``op`` directly. Arrays produce a new object-dtype array of the same shape.
    Tables are transformed value by value via ``mapValues``.

    Raises:
        ValueError: for unsupported tensor types.
    """
    from fate_arch.session import is_table

    def _transform(block):
        result = np.zeros(shape=block.shape, dtype=object)
        flat_result = result.view().reshape(-1)  # writes land in ``result``
        for position, element in enumerate(block.view().reshape(-1)):
            flat_result[position] = op(element)
        return result

    scalar_types = (int, np.int16, np.int32, np.int64,
                    float, np.float16, np.float32, np.float64,
                    FixedPointNumber)
    if isinstance(tensor, scalar_types):
        return op(tensor)
    if isinstance(tensor, np.ndarray):
        return _transform(tensor)
    if is_table(tensor):
        return tensor.mapValues(_transform)
    raise ValueError(f"unsupported type: {type(tensor)}")
def dot_local(self, other, target_name=None):
    """Dot this share table with ``other`` locally, reducing mod ``q_field`` and truncating.

    Args:
        other: ndarray, table, or a (table/numpy) ``FixedPointTensor`` whose
            ``value`` is used.
        target_name: name for the resulting tensor.

    Returns:
        For an ndarray ``other``, a boxed table tensor (one truncated dot per
        row). For a table ``other``, a ``fixedpoint_numpy.FixedPointTensor``.

    Raises:
        ValueError: for unsupported ``other`` types.
    """
    def _vec_dot(x, y, party_idx, q_field, endec):
        product = endec.truncate(np.dot(x, y) % q_field, party_idx)
        # Scalar dot results are wrapped so every row value stays an ndarray.
        return product if isinstance(product, np.ndarray) else np.array([product])

    if isinstance(other, (FixedPointTensor, fixedpoint_numpy.FixedPointTensor)):
        other = other.value

    if isinstance(other, np.ndarray):
        row_dot = functools.partial(_vec_dot, y=other,
                                    party_idx=self.get_spdz().party_idx,
                                    q_field=self.q_field, endec=self.endec)
        return self._boxed(self.value.mapValues(row_dot), target_name)
    if is_table(other):
        flat = table_dot_mod(self.value, other, self.q_field).reshape((1, -1))[0]
        flat = self.endec.truncate(flat, self.get_spdz().party_idx)
        return fixedpoint_numpy.FixedPointTensor(flat, self.q_field, self.endec, target_name)
    raise ValueError(f"type={type(other)}")
def rand_tensor(q_field, tensor):
    """Return random integers in ``[1, q_field)`` shaped like ``tensor`` (or each table value).

    The result is an object-dtype array of Python ints. When ``q_field`` fits in
    int64, values come from ``np.random.randint`` with an explicit int64 dtype,
    so seeded runs reproduce the previous output on 64-bit platforms. Larger
    fields fall back to Python's arbitrary-precision ``random.randrange``.
    The previous code used the platform default integer, which fails for any
    ``q_field`` above int32 on Windows (NumPy < 2) and above int64 everywhere.

    Raises:
        NotImplementedError: if ``tensor`` is neither a table nor an ndarray.
    """
    def _random_like(values):
        shape = np.shape(values)
        if q_field <= 2 ** 63:
            return np.random.randint(1, q_field, size=shape, dtype=np.int64).astype(object)
        import random
        count = int(np.prod(shape, dtype=np.int64))
        draws = [random.randrange(1, q_field) for _ in range(count)]
        return np.array(draws, dtype=object).reshape(shape)

    if is_table(tensor):
        return tensor.mapValues(_random_like)
    if isinstance(tensor, np.ndarray):
        return _random_like(tensor)
    raise NotImplementedError(f"type={type(tensor)}")
def __rsub__(self, other):
    """Compute ``(other - self) mod q_field`` and box the result.

    Args:
        other: a ``PaillierFixedPointTensor``/``FixedPointTensor`` (handled by
            delegating to ``other - self``), a table, or a scalar.

    Returns:
        A new tensor produced by ``self._boxed`` (or by ``other.__sub__``).
    """
    def _reversed_sub(element, scalar):
        # The previous code passed ``-1 * operator.sub``, which raises
        # TypeError because a function cannot be multiplied by an int.
        # NOTE(review): relies on _table_scalar_mod_op applying
        # op(element, scalar) before the mod; confirm against its definition.
        return scalar - element

    if isinstance(other, (PaillierFixedPointTensor, FixedPointTensor)):
        return other - self
    if is_table(other):
        z_value = _table_binary_mod_op(other, self.value, self.q_field, operator.sub)
    else:
        z_value = _table_scalar_mod_op(self.value, other, self.q_field, _reversed_sub)
    return self._boxed(z_value)
def encrypt_tensor(tensor, public_key):
    """Encrypt an ndarray, or each table value, by adding elements to one encryption of 0.

    ``public_key.encrypt(0)`` is called exactly once, before the type check, and
    the result is added to every plaintext element.

    NOTE(review): every element reuses the same ``encrypted_zero``. Whether the
    sums are re-randomised depends on the cipher's addition/serialisation;
    confirm ciphertexts are obfuscated before they leave this party.

    Raises:
        NotImplementedError: if ``tensor`` is neither an ndarray nor a table.
    """
    zero_cipher = public_key.encrypt(0)

    def _encrypt_array(plain_array):
        return np.vectorize(lambda plain: zero_cipher + plain)(plain_array)

    if isinstance(tensor, np.ndarray):
        return _encrypt_array(tensor)
    if is_table(tensor):
        return tensor.mapValues(_encrypt_array)
    raise NotImplementedError(f"type={type(tensor)}")
def rand_tensor(q_field, tensor):
    """Return an object array of random ints shaped like ``tensor`` (or each table value).

    Values come from ``random.randint(1, q_field)`` (inclusive bounds, Python
    big ints, so any ``q_field`` works). The previous implementation built
    ``[... for _ in tensor]``, which iterates only the first axis: a 2-D input
    produced a 1-D result of the wrong shape. The result now matches the
    full input shape. For 1-D input the draws and their order are unchanged.

    Raises:
        NotImplementedError: if ``tensor`` is neither a table nor an ndarray.
    """
    def _random_like(values):
        out = np.empty(np.shape(values), dtype=object)
        flat = out.reshape(-1)  # view: ``out`` is freshly allocated and contiguous
        for i in range(flat.size):
            flat[i] = random.randint(1, q_field)
        return out

    if is_table(tensor):
        return tensor.mapValues(_random_like)
    if isinstance(tensor, np.ndarray):
        return _random_like(tensor)
    raise NotImplementedError(f"type={type(tensor)}")
def __mul__(self, other):
    """Multiply element-wise (no modular reduction) and box the result.

    Args:
        other: a ``FixedPointTensor`` (its ``value`` is used), a table, or a scalar.
    """
    if isinstance(other, FixedPointTensor):
        return self._boxed(_table_binary_op(self.value, other.value, operator.mul))
    if is_table(other):
        return self._boxed(_table_binary_op(self.value, other, operator.mul))
    return self._boxed(_table_scalar_op(self.value, other, operator.mul))
def decode(self, integer_tensor):
    """Decode fixed-point integers with ``self._decode``.

    Integer scalars are wrapped in a 0-d ndarray first. Tables are decoded
    value by value.

    Raises:
        ValueError: for unsupported types.
    """
    if isinstance(integer_tensor, (int, np.int16, np.int32, np.int64)):
        integer_tensor = np.array(integer_tensor)
    if isinstance(integer_tensor, np.ndarray):
        return self._decode(integer_tensor)
    if is_table(integer_tensor):
        # The previous ``mapValues(lambda x: f)`` returned the partial object
        # itself for every value instead of applying it; map the decoder directly.
        return integer_tensor.mapValues(self._decode)
    raise ValueError(f"unsupported type: {type(integer_tensor)}")
def encode(self, float_tensor, check_range=True):
    """Encode floats to fixed-point field elements with ``self._encode``.

    Float scalars (Python ``float`` or any NumPy floating scalar) are wrapped in
    a 0-d ndarray first. Tables are encoded value by value. ``check_range`` is
    forwarded to ``self._encode``.

    Raises:
        ValueError: for unsupported types.
    """
    # ``np.float`` was merely an alias of ``float`` and was removed in NumPy
    # 1.24, where evaluating it raised AttributeError on every call.
    # ``np.floating`` also admits float16/float32 scalars.
    if isinstance(float_tensor, (float, np.floating)):
        float_tensor = np.array(float_tensor)
    if isinstance(float_tensor, np.ndarray):
        return self._encode(float_tensor, check_range)
    if is_table(float_tensor):
        return float_tensor.mapValues(functools.partial(self._encode, check_range=check_range))
    raise ValueError(f"unsupported type: {type(float_tensor)}")
def truncate(self, integer_tensor, idx=0):
    """Truncate fixed-point integers with ``self._truncate``, passing party index ``idx``.

    Integer scalars are wrapped in a 0-d ndarray first. Tables are truncated
    value by value.

    Raises:
        ValueError: for unsupported types.
    """
    integer_scalars = (int, np.int16, np.int32, np.int64)
    if isinstance(integer_tensor, integer_scalars):
        integer_tensor = np.array(integer_tensor)

    if isinstance(integer_tensor, np.ndarray):
        return self._truncate(integer_tensor, idx)
    if is_table(integer_tensor):
        per_value = functools.partial(self._truncate, idx=idx)
        return integer_tensor.mapValues(per_value)
    raise ValueError(f"unsupported type: {type(integer_tensor)}")
def __add__(self, other):
    """Add ``other`` element-wise.

    Adding a ``PaillierFixedPointTensor`` returns a new, unreduced
    ``PaillierFixedPointTensor``. Every other operand (``FixedPointTensor``,
    table, scalar) is added modulo ``self.q_field`` and boxed.
    """
    if isinstance(other, PaillierFixedPointTensor):
        return PaillierFixedPointTensor(_table_binary_op(self.value, other.value, operator.add))

    if isinstance(other, FixedPointTensor):
        rhs = other.value
    elif is_table(other):
        rhs = other
    else:
        return self._boxed(_table_scalar_mod_op(self.value, other, self.q_field, operator.add))
    return self._boxed(_table_binary_mod_op(self.value, rhs, self.q_field, operator.add))
def rand_tensor(q_field, tensor):
    """Fill a tensor-shaped container with values from ``rand_number_generator``.

    For a table, each value ``x`` becomes a 1-D array with one draw per item of
    ``x``. For an ndarray, the result has the same shape as ``tensor``.

    Raises:
        NotImplementedError: if ``tensor`` is neither a table nor an ndarray.
    """
    def _draw():
        return rand_number_generator(q_field=q_field)

    if is_table(tensor):
        return tensor.mapValues(
            lambda row: np.array([_draw() for _ in row], dtype=FixedPointNumber))

    if not isinstance(tensor, np.ndarray):
        raise NotImplementedError(f"type={type(tensor)}")

    out = np.zeros(shape=tensor.shape, dtype=FixedPointNumber)
    flat = out.view().reshape(-1)  # writes through to ``out``
    for position in range(flat.size):
        flat[position] = _draw()
    return out
def urand_tensor(q_field, tensor, use_mix=False):
    """Return random ints from ``random.SystemRandom().randint(1, q_field)`` (inclusive).

    For a table, each value ``x`` becomes an object array with one draw per item
    of ``x``. With ``use_mix`` the table is instead processed partition-wise by
    ``_mix_rand_func``. For an ndarray, the result has the same shape as ``tensor``.

    Raises:
        NotImplementedError: if ``tensor`` is neither a table nor an ndarray.
    """
    def _secure_row(row):
        # One SystemRandom per row. It reads os.urandom and keeps no state, so
        # this matches building a fresh instance for every element.
        rng = random.SystemRandom()
        return np.array([rng.randint(1, q_field) for _ in row], dtype=object)

    if is_table(tensor):
        if use_mix:
            return tensor.mapPartitions(functools.partial(_mix_rand_func, q_field=q_field),
                                        use_previous_behavior=False,
                                        preserves_partitioning=True)
        return tensor.mapValues(_secure_row)

    if isinstance(tensor, np.ndarray):
        rng = random.SystemRandom()
        out = np.zeros(shape=tensor.shape, dtype=object)
        flat = out.view().reshape(-1)  # writes through to ``out``
        for position in range(flat.size):
            flat[position] = rng.randint(1, q_field)
        return out

    raise NotImplementedError(f"type={type(tensor)}")
def from_source(cls, tensor_name, source, **kwargs):
    """Create a FixedPointTensor share from a local table or from a peer's (cipher) share.

    Table source: a random table from ``urand_tensor`` becomes this party's share.
    ``source - encoder.decode(random)`` is sent to ``spdz.other_parties[-1]``.

    Party source: the share sent by that party under ``tensor_name`` is received.
    If ``is_cipher_source`` (default True), it is decrypted with
    ``kwargs["cipher"].distribute_decrypt`` and then encoded with the encoder.

    Keyword arguments:
        q_field: field modulus (defaults to ``spdz.q_field``).
        encoder: encoder to use. If absent, a ``FixedPointEndec`` is built from
            ``base`` (default 10) and ``frac`` (default 4).
        is_cipher_source: whether the received share must be decrypted.
        cipher: object providing ``distribute_decrypt``; required for cipher sources.

    Raises:
        ValueError: if ``source`` is neither a table nor a ``Party``, or if a
            cipher source is received without ``cipher``.
    """
    spdz = cls.get_spdz()
    q_field = kwargs['q_field'] if 'q_field' in kwargs else spdz.q_field
    if 'encoder' in kwargs:
        encoder = kwargs['encoder']
    else:
        base = kwargs['base'] if 'base' in kwargs else 10
        frac = kwargs['frac'] if 'frac' in kwargs else 4
        encoder = FixedPointEndec(n=q_field, field=q_field, base=base, precision_fractional=frac)
    if is_table(source):
        # This party keeps the (encoded-domain) random table as its share ...
        _pre = urand_tensor(q_field, source, use_mix=spdz.use_mix_rand)
        share = _pre
        # ... and the last other party receives source minus the decoded random values.
        spdz.communicator.remote_share(share=_table_binary_op(
            source, encoder.decode(_pre), operator.sub), tensor_name=tensor_name, party=spdz.other_parties[-1])
        return FixedPointTensor(value=share, q_field=q_field, endec=encoder, tensor_name=tensor_name)
    elif isinstance(source, Party):
        share = spdz.communicator.get_share(tensor_name=tensor_name, party=source)[0]
        is_cipher_source = kwargs[
            'is_cipher_source'] if 'is_cipher_source' in kwargs else True
        if is_cipher_source:
            cipher = kwargs.get("cipher")
            if cipher is None:
                raise ValueError("Cipher is not provided")
            share = cipher.distribute_decrypt(share)
            share = encoder.encode(share)
        # NOTE(review): with is_cipher_source=False the received share is
        # returned without encoding; confirm the sender provides encoded values.
        return FixedPointTensor(value=share, q_field=q_field, endec=encoder, tensor_name=tensor_name)
    else:
        raise ValueError(f"type={type(source)}")
def dot(self, other, target_name=None):
    """Dot this (Paillier) table tensor with ``other`` using plain ``np.dot`` (no modulus).

    Args:
        other: ndarray, table, or a (table/numpy) ``FixedPointTensor`` whose
            ``value`` is used.
        target_name: name for the resulting tensor.

    Returns:
        For an ndarray ``other``, a boxed table tensor (one dot per row). For a
        table ``other``, a ``fixedpoint_numpy.PaillierFixedPointTensor``.

    Raises:
        ValueError: for unsupported ``other`` types.
    """
    def _row_dot(row, rhs):
        result = np.dot(row, rhs)
        if isinstance(result, np.ndarray):
            return result
        return np.array([result])  # keep scalar results as 1-element arrays

    if isinstance(other, (FixedPointTensor, fixedpoint_numpy.FixedPointTensor)):
        other = other.value

    if isinstance(other, np.ndarray):
        rhs = other
        return self._boxed(self.value.mapValues(lambda row: _row_dot(row, rhs)), target_name)
    if is_table(other):
        flat = table_dot(self.value, other).reshape((1, -1))[0]
        return fixedpoint_numpy.PaillierFixedPointTensor(flat, target_name)
    raise ValueError(f"type={type(other)}")
def encode(self, float_tensor, check_range=True):
    """Encode floats as fixed-point integers reduced modulo ``self.field``.

    Each value is scaled by ``self.base ** self.precision_fractional`` and cast
    to int64. The cast truncates toward zero and does not round.
    NOTE(review): e.g. 1.15 * 100 -> 114; confirm truncation is intended.
    The result is then taken ``% self.field``. Float scalars are wrapped in a
    0-d ndarray. Tables are encoded value by value.

    Args:
        float_tensor: float scalar, ndarray, or table of ndarrays.
        check_range: if True, require every scaled value ``v`` to satisfy
            ``|v| < field / 2``.

    Raises:
        AssertionError: if the range check fails. This is raised explicitly so
            the check also runs under ``python -O`` and keeps the exception
            type the former ``assert`` produced.
        ValueError: for unsupported types (previously ``None`` was returned silently).
    """
    if isinstance(float_tensor, (float, np.floating)):
        float_tensor = np.array(float_tensor)

    scale = self.base ** self.precision_fractional
    field = self.field
    # For integer v, |v| < field / 2  <=>  |v| <= (field - 1) // 2. The integer
    # form avoids float rounding and the OverflowError ``field / 2`` raises
    # for fields beyond the float range.
    bound = (field - 1) // 2

    if isinstance(float_tensor, np.ndarray):
        upscaled = (float_tensor * scale).astype(np.int64)
        if check_range and not (np.abs(upscaled) <= bound).all():
            raise AssertionError(
                f"{float_tensor} cannot be correctly embedded: choose bigger field or a lower precision")
        return upscaled % field

    if is_table(float_tensor):
        # Lambdas capture plain locals instead of ``self``.
        upscaled = float_tensor.mapValues(lambda x: (x * scale).astype(np.int64))
        if check_range and upscaled.filter(lambda k, v: (np.abs(v) > bound).any()).count() != 0:
            raise AssertionError(
                f"{float_tensor} cannot be correctly embedded: choose bigger field or a lower precision")
        return upscaled.mapValues(lambda x: x % field)

    raise ValueError(f"unsupported type: {type(float_tensor)}")