示例#1
0
    def setup_plan_with_promises(self, *args):
        """Slightly modify this plan so that it can work with promises.

        Each argument that wraps a ``PromiseTensor`` (i.e. has a ``child``
        attribute that is a ``PromiseTensor``) gets this plan's id added to its
        ``plans`` set. The plan's procedure is then bound to ``args`` and a new
        output ``PromiseTensor`` is created to receive the plan's result.

        NOTE(review): a previous version of this docstring said the plan is also
        sent to the promises' location, but nothing in this method sends it —
        confirm whether callers are expected to do that.

        Args:
            *args: Plan inputs. At least one must wrap a ``PromiseTensor``.
                ``args[0]`` provides the torch type of the result.

        Returns:
            The wrapped output ``PromiseTensor``, owned by the owner of the last
            promise argument, with shape ``self.output_shape``.

        Raises:
            ValueError: If no argument wraps a ``PromiseTensor`` (including an
                empty ``args``).
        """
        prom_owner = None
        has_promise = False
        for arg in args:
            if hasattr(arg, "child") and isinstance(arg.child, PromiseTensor):
                arg.child.plans.add(self.id)
                # When several promises are given, the last one's owner wins.
                prom_owner = arg.owner
                has_promise = True

        if not has_promise:
            # Without a promise argument there is no owner for the output promise.
            raise ValueError(
                "setup_plan_with_promises expects at least one argument wrapping a PromiseTensor"
            )

        # As we cannot perform operation between different type of tensors with torch, all the
        # input tensors should have the same type and the result should also have this same type.
        result_type = args[0].torch_type()

        res = PromiseTensor(
            owner=prom_owner, shape=self.output_shape, tensor_type=result_type, plans=set()
        )

        self.procedure.update_args(args, self.procedure.result_ids)
        self.procedure.promise_out_id = res.id

        return res.wrap()
示例#2
0
            def method(self, *args, **kwargs):
                """Build and register a plan that applies ``method_name`` lazily.

                A plan is traced (via ``syft.func2plan``) from the shapes of
                ``self`` and ``args`` and calls ``method_name`` (taken from the
                enclosing scope) on ``self``. The plan id is recorded in
                ``self.plans`` and in the ``plans`` of every argument that is a
                ``PromiseTensor``. The plan is then sent to ``self.owner`` or
                registered there.

                Non-tensor arguments (neither ``torch.Tensor`` nor
                ``AbstractTensor``) are converted with ``torch.tensor``.

                NOTE(review): ``kwargs`` passed to this method are never used —
                confirm whether they should be forwarded to the plan.

                Returns:
                    A new, unwrapped ``PromiseTensor`` owned by ``self.owner``.
                    Its shape is the plan's output shape and its tensor type is
                    ``self.obj_type``. It is the target of the plan's result.
                """
                arg_shapes = [self.shape]

                # Convert scalar arguments to tensors to be able to use them with plans
                args = [
                    arg if isinstance(arg, (torch.Tensor, AbstractTensor)) else torch.tensor(arg)
                    for arg in args
                ]
                arg_shapes.extend(arg.shape for arg in args)

                @syft.func2plan(arg_shapes)
                def operation_method(self, *args, **kwargs):
                    return getattr(self, method_name)(*args, **kwargs)

                # Record the plan on this promise and on every promise argument,
                # presumably so it can be triggered once their values are set.
                self.plans.add(operation_method.id)
                for arg in args:
                    # NOTE(review): only bare PromiseTensor arguments are registered;
                    # a tensor whose ``.child`` is a PromiseTensor is not — confirm.
                    if isinstance(arg, PromiseTensor):
                        arg.plans.add(operation_method.id)

                operation_method.procedure.update_args(
                    [self, *args], operation_method.procedure.result_ids
                )

                promise_out = PromiseTensor(
                    owner=self.owner,
                    shape=operation_method.output_shape,
                    tensor_type=self.obj_type,
                    plans=set(),
                )
                operation_method.procedure.promise_out_id = promise_out.id

                # Make the plan available on the worker owning this promise.
                if operation_method.owner != self.owner:
                    operation_method.send(self.owner)
                else:  # otherwise object not registered on local worker
                    operation_method.owner.register_obj(operation_method)

                return promise_out
示例#3
0
 def BoolTensor(shape, *args, **kwargs):
     """Create a ``PromiseTensor`` promising a ``torch.BoolTensor`` and wrap it.

     Args:
         shape: Shape of the promised tensor.
         *args: Extra positional arguments passed to ``PromiseTensor`` after
             ``shape``.
         **kwargs: Extra keyword arguments passed to ``PromiseTensor``.

     Returns:
         The result of ``.wrap()`` on the new ``PromiseTensor``.
     """
     return PromiseTensor(shape,
                          tensor_type="torch.BoolTensor",
                          *args,
                          **kwargs).wrap()
示例#4
0
 def LongTensor(shape, *args, **kwargs):
     """Return a wrapped ``PromiseTensor`` whose promised type is ``torch.LongTensor``.

     ``shape`` and any extra positional/keyword arguments are passed straight
     through to the ``PromiseTensor`` constructor.
     """
     promise = PromiseTensor(shape, *args, tensor_type="torch.LongTensor", **kwargs)
     return promise.wrap()
示例#5
0
 def ShortTensor(shape, *args, **kwargs):
     """Build a ``PromiseTensor`` of type ``torch.ShortTensor`` and return it wrapped.

     Positional arguments after ``shape`` and all keyword arguments are
     forwarded unchanged to ``PromiseTensor``.
     """
     type_name = "torch.ShortTensor"
     return PromiseTensor(shape, *args, tensor_type=type_name, **kwargs).wrap()
示例#6
0
 def DoubleTensor(shape, *args, **kwargs):
     """Wrap a new ``PromiseTensor`` that promises a ``torch.DoubleTensor``.

     The constructor receives ``shape`` followed by ``args`` positionally,
     plus ``tensor_type`` and any caller-supplied keyword arguments.
     """
     positional = (shape, *args)
     double_promise = PromiseTensor(*positional, tensor_type="torch.DoubleTensor", **kwargs)
     return double_promise.wrap()