def _bufferize_torch_parameter(worker: AbstractWorker, param: torch.nn.Parameter) -> ParameterPB:
    """
    This function converts a torch.nn.Parameter into a ParameterPB Protobuf message.

    The parameter's id, its underlying data tensor and its requires_grad flag are
    always written. The gradient is written only when the parameter has one.

    Args:
        worker (AbstractWorker): the worker passed through to the nested
            serialization of the data and gradient tensors.
        param (torch.nn.Parameter): the parameter to be serialized.

    Returns:
        protobuf_param: the populated ParameterPB message.
    """
    protobuf_param = ParameterPB()
    set_protobuf_id(protobuf_param.id, param.id)
    protobuf_param.tensor.CopyFrom(syft.serde.protobuf.serde._bufferize(worker, param.data))
    protobuf_param.requires_grad = param.requires_grad

    # Compare against None explicitly: truth-testing a tensor raises RuntimeError
    # when it holds more than one element, and a one-element gradient equal to
    # zero would be falsy and therefore silently left out of the message.
    if param.grad is not None:
        protobuf_param.grad.CopyFrom(syft.serde.protobuf.serde._bufferize(worker, param.grad))

    return protobuf_param
def _bufferize_torch_tensor(worker: AbstractWorker, tensor: torch.Tensor) -> "TorchTensorPB":
    """
    This function converts a Torch tensor into a serialized tensor using Protobuf.

    The tensor contents are produced by _serialize_tensor. When the worker's
    serializer is TENSOR_SERIALIZATION.ALL they are stored as a nested Protobuf
    message (contents_data); for any other serializer they are stored in the
    contents_bin field.

    Args:
        worker (AbstractWorker): the worker whose serializer setting selects the
            content encoding; it is also passed to nested serialization calls.
        tensor (torch.Tensor): an input tensor to be serialized.

    Returns:
        protobuf_tensor: TorchTensorPB message holding the tensor id, contents,
            optional chain and grad_chain, optional description, and tags.
    """
    serialized_tensor = _serialize_tensor(worker, tensor)

    # The gradient is serialized recursively, except when there is no gradient
    # or when the tensor's child is a PointerTensor.
    grad_chain = None
    if tensor.grad is not None:
        wraps_pointer = hasattr(tensor, "child") and isinstance(tensor.child, PointerTensor)
        if not wraps_pointer:
            grad_chain = _bufferize_torch_tensor(worker, tensor.grad)

    # A tensor that has a child gets that child serialized as its chain.
    chain = None
    if hasattr(tensor, "child"):
        chain = syft.serde.protobuf.serde._bufferize(worker, tensor.child)

    protobuf_tensor = TorchTensorPB()
    set_protobuf_id(protobuf_tensor.id, tensor.id)

    protobuf_tensor.serializer = SERIALIZERS_SYFT_TO_PROTOBUF[worker.serializer]
    if worker.serializer == TENSOR_SERIALIZATION.ALL:
        # Message-typed field: must be filled with CopyFrom, not assignment.
        protobuf_tensor.contents_data.CopyFrom(serialized_tensor)
    else:
        protobuf_tensor.contents_bin = serialized_tensor

    if chain:
        protobuf_tensor.chain.CopyFrom(chain)
    if grad_chain:
        protobuf_tensor.grad_chain.CopyFrom(grad_chain)
    if tensor.description:
        protobuf_tensor.description = tensor.description

    protobuf_tensor.tags.extend(tensor.tags)

    return protobuf_tensor
def bufferize(worker: AbstractWorker, param: torch.nn.Parameter) -> ParameterPB:
    """
    This method converts a torch.nn.Parameter into a serialized parameter using ParameterPB.

    Args:
        worker (AbstractWorker): the worker passed through to the nested
            serialization of the data and gradient tensors.
        param (torch.nn.Parameter): input nn.parameter to be serialized.

    Returns:
        protobuf_param: serialized torch.nn.Parameter.
    """
    protobuf_param = ParameterPB()
    set_protobuf_id(protobuf_param.id, param.id)
    protobuf_param.tensor.CopyFrom(syft.serde.protobuf.serde._bufferize(worker, param.data))
    protobuf_param.requires_grad = param.requires_grad

    # Use an identity check rather than truthiness: bool() on a tensor with more
    # than one element raises RuntimeError, and a single-element zero gradient
    # evaluates to False, which would drop a gradient that is actually present.
    if param.grad is not None:
        protobuf_param.grad.CopyFrom(syft.serde.protobuf.serde._bufferize(worker, param.grad))

    return protobuf_param