def setup_plan_with_promises(self, *args):
    """Wire this plan to the promises found among ``args``.

    Every argument whose ``child`` is a ``PromiseTensor`` gets this plan's id
    added to that promise's ``plans`` set. A new ``PromiseTensor`` standing for
    the plan's output is then created, the procedure's arguments are updated
    with ``args``, and the output promise's id is stored in
    ``self.procedure.promise_out_id``.

    This method does not itself send the plan anywhere.

    Args:
        *args: The inputs of the plan. At least one of them must wrap a
            ``PromiseTensor``, and the first one must provide ``torch_type()``.

    Returns:
        The wrapped output ``PromiseTensor``. Its owner is the owner of the
        last promise-wrapping argument and its tensor type is the torch type
        of the first argument.

    Raises:
        ValueError: If none of ``args`` wraps a ``PromiseTensor``.
    """
    promise_args = [
        arg
        for arg in args
        if hasattr(arg, "child") and isinstance(arg.child, PromiseTensor)
    ]
    if not promise_args:
        # Without a promise there is no owner for the output promise. Fail with
        # a clear message before any ``plans`` set has been modified.
        raise ValueError(
            "setup_plan_with_promises requires at least one argument wrapping a PromiseTensor"
        )

    for arg in promise_args:
        arg.child.plans.add(self.id)

    # When several promises are given, the last one decides the owner.
    prom_owner = promise_args[-1].owner

    # As torch cannot perform operations between tensors of different types,
    # all the input tensors should share one type and the result has that same
    # type, so the first argument is representative.
    result_type = args[0].torch_type()
    res = PromiseTensor(
        owner=prom_owner,
        shape=self.output_shape,
        tensor_type=result_type,
        plans=set(),
    )

    self.procedure.update_args(args, self.procedure.result_ids)
    self.procedure.promise_out_id = res.id

    return res.wrap()
def method(self, *args, **kwargs):
    """Record a call to ``method_name`` on this promise as a plan.

    ``method_name`` is read from the enclosing scope, which is not part of
    this block. The call is not executed directly: a plan wrapping
    ``getattr(self, method_name)`` is built with ``syft.func2plan`` from the
    shapes of ``self`` and of every positional argument, and is attached to
    this promise and to every ``PromiseTensor`` argument.

    Args:
        *args: Positional arguments of the operation. Anything that is neither
            a ``torch.Tensor`` nor an ``AbstractTensor`` is converted with
            ``torch.tensor`` so that it has a shape.
        **kwargs: Accepted for signature compatibility; not used while the
            plan is being set up here.

    Returns:
        A new (unwrapped) ``PromiseTensor`` standing for the result, owned by
        ``self.owner`` and typed with ``self.obj_type``.
    """
    arg_shapes = [self.shape]

    # Convert scalar arguments to tensors to be able to use them with plans.
    args = [
        arg if isinstance(arg, (torch.Tensor, AbstractTensor)) else torch.tensor(arg)
        for arg in args
    ]
    arg_shapes.extend(arg.shape for arg in args)

    @syft.func2plan(arg_shapes)
    def operation_method(self, *args, **kwargs):
        return getattr(self, method_name)(*args, **kwargs)

    # Attach the plan to every promise it depends on.
    self.plans.add(operation_method.id)
    for arg in args:
        if isinstance(arg, PromiseTensor):
            arg.plans.add(operation_method.id)

    operation_method.procedure.update_args(
        [self, *args], operation_method.procedure.result_ids
    )

    promise_out = PromiseTensor(
        owner=self.owner,
        shape=operation_method.output_shape,
        tensor_type=self.obj_type,
        plans=set(),
    )
    operation_method.procedure.promise_out_id = promise_out.id

    if operation_method.owner != self.owner:
        operation_method.send(self.owner)
    else:
        # otherwise object not registered on local worker
        operation_method.owner.register_obj(operation_method)

    return promise_out
def BoolTensor(shape, *args, **kwargs):
    """Create a wrapped promise for a ``torch.BoolTensor``.

    Args:
        shape: Passed as the first positional argument of ``PromiseTensor``.
        *args: Extra positional arguments forwarded to ``PromiseTensor``.
        **kwargs: Extra keyword arguments forwarded to ``PromiseTensor``.
            ``tensor_type`` must not be among them since it is fixed here.

    Returns:
        The result of calling ``wrap()`` on the new ``PromiseTensor``.
    """
    # The extra positionals are variadic (``*args``), like the sibling
    # constructors, so ``BoolTensor(shape)`` works without a dummy argument.
    return PromiseTensor(shape, *args, tensor_type="torch.BoolTensor", **kwargs).wrap()
def LongTensor(shape, *args, **kwargs):
    """Create a wrapped promise for a ``torch.LongTensor``.

    ``shape``, ``*args`` and ``**kwargs`` are forwarded to ``PromiseTensor``
    with ``tensor_type`` fixed to ``"torch.LongTensor"``; the wrapped promise
    is returned.
    """
    promise = PromiseTensor(shape, *args, tensor_type="torch.LongTensor", **kwargs)
    return promise.wrap()
def ShortTensor(shape, *args, **kwargs):
    """Create a wrapped promise for a ``torch.ShortTensor``.

    ``shape``, ``*args`` and ``**kwargs`` are forwarded to ``PromiseTensor``
    with ``tensor_type`` fixed to ``"torch.ShortTensor"``; the wrapped promise
    is returned.
    """
    promise = PromiseTensor(shape, *args, tensor_type="torch.ShortTensor", **kwargs)
    return promise.wrap()
def DoubleTensor(shape, *args, **kwargs):
    """Create a wrapped promise for a ``torch.DoubleTensor``.

    ``shape``, ``*args`` and ``**kwargs`` are forwarded to ``PromiseTensor``
    with ``tensor_type`` fixed to ``"torch.DoubleTensor"``; the wrapped promise
    is returned.
    """
    promise = PromiseTensor(shape, *args, tensor_type="torch.DoubleTensor", **kwargs)
    return promise.wrap()