def unbufferize(
    worker: AbstractWorker, protobuf_tensor: "AdditiveSharingTensorPB"
) -> "AdditiveSharingTensor":
    """
    Rebuild an AdditiveSharingTensor from its protobuf representation.

    Args:
        worker: the worker doing the deserialization; it becomes the owner
            of the returned tensor and is used to resolve the crypto provider
        protobuf_tensor: a protobuf object holding the attributes of the
            AdditiveSharingTensor

    Returns:
        AdditiveSharingTensor: the reconstructed tensor. Its ``child`` maps
        each location id to the share deserialized for that location.

    Examples:
        shared_tensor = unbufferize(worker, protobuf_tensor)
    """
    get_id = sy.serde.protobuf.proto.get_protobuf_id

    tensor_id = get_id(protobuf_tensor.id)
    crypto_provider_id = get_id(protobuf_tensor.crypto_provider_id)

    # The field size lives in a oneof group: find out which member is set,
    # then read that member and normalise it to an int.
    field_attr = protobuf_tensor.WhichOneof("field_size")
    field = int(getattr(protobuf_tensor, field_attr))
    dtype = protobuf_tensor.dtype

    tensor = AdditiveSharingTensor(
        owner=worker,
        id=tensor_id,
        field=field,
        dtype=dtype,
        crypto_provider=worker.get_worker(crypto_provider_id),
    )

    # Without location ids there are no shares to attach.
    if protobuf_tensor.location_ids is None:
        return tensor

    # location_ids and shares are paired positionally.
    tensor.child = {
        get_id(pb_location_id): sy.serde.protobuf.serde._unbufferize(worker, share)
        for pb_location_id, share in zip(
            protobuf_tensor.location_ids, protobuf_tensor.shares
        )
    }
    return tensor
def detail(worker: AbstractWorker, tensor_tuple: tuple) -> "AdditiveSharingTensor":
    """
    Rebuild an AdditiveSharingTensor from its simplified tuple form.

    Args:
        worker: the worker doing the deserialization; it becomes the owner
            of the returned tensor and is used to resolve the crypto provider
        tensor_tuple: a 7-tuple holding, in order: id, field, protocol,
            dtype (utf-8 encoded bytes), crypto provider, the chain of
            shares, and the garbage-collection setting

    Returns:
        AdditiveSharingTensor: the reconstructed tensor. Its ``child`` maps
        a worker id to each share.

    Examples:
        shared_tensor = detail(worker, data)
    """

    def _restore(simplified):
        # Shorthand for the msgpack detailer bound to this worker.
        return sy.serde.msgpack.serde._detail(worker, simplified)

    (
        raw_id,
        raw_field,
        raw_protocol,
        raw_dtype,
        raw_crypto_provider,
        raw_chain,
        garbage_collect,
    ) = tensor_tuple

    crypto_provider_id = _restore(raw_crypto_provider)

    tensor = AdditiveSharingTensor(
        owner=worker,
        id=_restore(raw_id),
        field=_restore(raw_field),
        protocol=_restore(raw_protocol),
        dtype=raw_dtype.decode("utf-8"),
        crypto_provider=worker.get_worker(crypto_provider_id),
    )

    shares = _restore(raw_chain)

    tensor.child = {}
    for share in shares:
        # A share with a location is remote and is keyed by that location;
        # otherwise it is local and is keyed by its owner.
        holder_id = share.location.id if share.location is not None else share.owner.id
        tensor.child[holder_id] = share

    # The garbage-collection setting is forwarded exactly as received.
    tensor.set_garbage_collect_data(garbage_collect)

    return tensor
def spdz_mul(cmd: Callable, x_sh, y_sh, crypto_provider: AbstractWorker, field: int):
    """Abstractly multiplies two shared tensors (mul or matmul).

    The crypto provider supplies a triple ``(a, b, a_mul_b)``. The inputs are
    masked with ``a`` and ``b``, the masked values are reconstructed, and the
    result is assembled as
    ``cmd(delta, epsilon) * j + cmd(delta, b) + cmd(a, epsilon) + a_mul_b``.

    Args:
        cmd: a callable of the equation to be computed (mul or matmul)
        x_sh (AdditiveSharingTensor): the left part of the operation
        y_sh (AdditiveSharingTensor): the right part of the operation
        crypto_provider (AbstractWorker): an AbstractWorker which is used to
            generate triples
        field (int): an integer denoting the size of the field

    Return:
        an AdditiveSharingTensor
    """
    assert isinstance(x_sh, sy.AdditiveSharingTensor)
    assert isinstance(y_sh, sy.AdditiveSharingTensor)

    workers = x_sh.locations

    # Get the triple for this operation, input shapes and set of workers.
    a, b, a_mul_b = crypto_provider.generate_triple(
        cmd, field, x_sh.shape, y_sh.shape, workers
    )

    # Mask both inputs first, then reconstruct the masked values.
    masked_x = x_sh - a
    masked_y = y_sh - b
    delta = masked_x.reconstruct()
    epsilon = masked_y.reconstruct()

    delta_epsilon = cmd(delta, epsilon)

    # Trick to keep only one child in the MultiPointerTensor (like in SNN):
    # a selector made of ones at the first location and zeros at every other
    # one multiplies delta_epsilon below.
    selector_shape = delta_epsilon.shape
    j1 = torch.ones(selector_shape).long().send(workers[0])
    j0 = torch.zeros(selector_shape).long().send(*workers[1:])

    if len(workers) == 2:
        selector_children = [j1, j0]
    else:
        # NOTE(review): with several remaining workers j0 is presumably a
        # wrapped multi-pointer whose children must be flattened - confirm.
        selector_children = [j1] + list(j0.child.child.values())
    j = sy.MultiPointerTensor(children=selector_children)

    delta_b = cmd(delta, b)
    a_epsilon = cmd(a, epsilon)

    # Accumulate left to right with out-of-place additions.
    result = delta_epsilon * j
    result = result + delta_b
    result = result + a_epsilon
    result = result + a_mul_b
    return result
def detail(worker: AbstractWorker, protocol_tuple: tuple) -> "Protocol":
    """Rebuild a Protocol object from its simplified tuple form.

    Args:
        worker: the worker doing the deserialization; it becomes the owner
            of the returned Protocol
        protocol_tuple: a 5-tuple holding, in order: id, tags, description,
            the plan references as (owner id, plan id) pairs, and the
            workers_resolved flag

    Returns:
        protocol: a Protocol object. Its plans are (owner, pointer) pairs
        when the workers are resolved, and (worker.id, plan) pairs otherwise.
    """
    (
        protocol_id,
        tags,
        description,
        plans_reference,
        workers_resolved,
    ) = (sy.serde._detail(worker, item) for item in protocol_tuple)

    plans = []
    for owner_id, plan_id in plans_reference:
        if not workers_resolved:
            # Try to fetch the plan from its owner; if that worker is not
            # known, fall back to the object stored on this worker.
            try:
                plan_owner = worker.get_worker(owner_id, fail_hard=True)
                search_hits = worker.request_search(plan_id, location=plan_owner)
                plan_pointer = search_hits[0]
                plan = plan_pointer.get()
            except WorkerNotFoundException:
                plan = worker.get_obj(plan_id)
            plans.append((worker.id, plan))
            continue

        # Workers are resolved: keep a registered pointer to the remote plan.
        plan_owner = worker.get_worker(owner_id, fail_hard=True)
        search_hits = worker.request_search(plan_id, location=plan_owner)
        plan_pointer = search_hits[0]
        worker.register_obj(plan_pointer)
        plans.append((plan_owner, plan_pointer))

    return sy.Protocol(
        plans=plans,
        id=protocol_id,
        owner=worker,
        tags=tags,
        description=description,
    )
def detail(worker: AbstractWorker, tensor_tuple: tuple) -> "PointerTensor":
    """
    Rebuild a PointerTensor from its simplified tuple form.

    When the pointer targets the deserializing worker itself, the object it
    points at is loaded and returned in place of a pointer.

    Args:
        worker: the worker doing the deserialization
        tensor_tuple: an 8-tuple holding, in order: the pointer id, the id at
            the location, the location's worker id, the attribute path, the
            shape, the garbage-collection flag, the tags and the description

    Returns:
        PointerTensor: a PointerTensor owned by ``worker`` when the target
        lives on another worker. Otherwise the result of
        ``worker.get_obj(id_at_location)``, after following the attribute
        path and wrapping where needed.

    Examples:
        ptr = detail(worker, data)
    """
    restore = syft.serde.msgpack.serde._detail

    (
        raw_obj_id,
        raw_id_at_location,
        raw_worker_id,
        raw_point_to_attr,
        shape,
        garbage_collect_data,
        tags,
        description,
    ) = tensor_tuple

    obj_id = restore(worker, raw_obj_id)
    id_at_location = restore(worker, raw_id_at_location)
    worker_id = restore(worker, raw_worker_id)
    point_to_attr = restore(worker, raw_point_to_attr)

    if shape is not None:
        shape = syft.hook.create_shape(restore(worker, shape))

    targets_this_worker = worker_id == worker.id
    if not targets_this_worker:
        # The target lives elsewhere: keep a pointer to it.
        location = syft.hook.local_worker.get_worker(worker_id)
        return PointerTensor(
            location=location,
            id_at_location=id_at_location,
            owner=worker,
            id=obj_id,
            shape=shape,
            garbage_collect_data=garbage_collect_data,
            tags=tags,
            description=description,
        )

    # The pointer received points at the current worker, so the object
    # itself is loaded instead.
    tensor = worker.get_obj(id_at_location)

    if point_to_attr is not None and tensor is not None:
        # Follow the dotted attribute path, ignoring empty segments.
        for attr in point_to_attr.split("."):
            if not attr:
                continue
            tensor = getattr(tensor, attr)

        # A wrapper needs no wrapping, and neither does a plain framework
        # tensor. Anything else is wrapped so that other interfaces can
        # use it.
        if tensor is not None and not (
            tensor.is_wrapper or isinstance(tensor, FrameworkTensor)
        ):
            tensor = tensor.wrap()

    return tensor