Example #1
    def send(
        self,
        *location,
        inplace: bool = False,
        user: object = None,
        local_autograd: bool = False,
        requires_grad: bool = False,
        preinitialize_grad: bool = False,
        no_wrap: bool = False,
        garbage_collect_data: bool = True,
    ):
        """Gets the pointer to a new remote object.

        One of the most commonly used methods in PySyft, this method serializes the object upon
        which it is called (self), sends the object to a remote worker, creates a pointer to
        that worker, and then returns that pointer from this function.

        Args:
            location: The BaseWorker object to which you want to send this object. Note that
                this is never the abstract BaseWorker itself, but an instance of a class
                implementing the BaseWorker abstraction.
            inplace: If True, return the same object instance, else a new wrapper.
            user (object, optional): User credentials to be verified.
            local_autograd: Use the autograd system on the local machine instead of PyTorch's
                autograd on the workers.
            requires_grad: Defaults to False. If True, whenever the remote value of this tensor
                has its gradient updated (for example when calling .backward()), a call is made
                to set the local gradient value as well.
            preinitialize_grad: Initialize the gradient for AutogradTensors to a tensor.
            no_wrap: If True, wrap() is not called on the created pointer.
            garbage_collect_data: argument passed down to create_pointer()

        Returns:
            A torch.Tensor[PointerTensor] pointer to self. Note that this
            object will likely be wrapped by a torch.Tensor wrapper.

        Raises:
            SendNotPermittedError: Raised if send is not permitted on this tensor.
        """

        if not self.allow(user=user):
            raise SendNotPermittedError()

        # If you send a pointer p1, you want the new pointer-to-pointer p2 to control
        # the garbage collection, not the old p1 (here self). That way, as long as p2
        # is alive, GCing p1 won't delete the remote tensor; if you do want to delete
        # it, you can still call `del p2` explicitly.
        # This allows chaining multiple .send().send() calls.
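        # Sketch of the chaining behaviour (alice and bob are hypothetical workers,
        # not defined in this snippet):
        #   p1 = x.send(alice)   # remote tensor lives on alice, p1 controls its GC
        #   p2 = p1.send(bob)    # p1 hands GC control of the remote tensor over to p2
        #   del p1               # the tensor on alice survives, p2 still reaches it
        #   del p2               # now the remote tensor can be garbage collected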

        if len(location) == 1:

            location = location[0]

            if hasattr(self, "child") and isinstance(self.child, PointerTensor):
                self.child.garbage_collect_data = False
                if self._is_parameter():
                    self.data.child.garbage_collect_data = False

            ptr = self.owner.send(
                self,
                location,
                local_autograd=local_autograd,
                requires_grad=requires_grad,
                preinitialize_grad=preinitialize_grad,
                garbage_collect_data=garbage_collect_data,
            )

            ptr.description = self.description
            ptr.tags = self.tags

            # The last pointer should control remote GC, not the previous self.ptr
            if hasattr(self, "ptr") and self.ptr is not None:
                ptr_ = self.ptr()
                if ptr_ is not None:
                    ptr_.garbage_collect_data = False

            # we need to cache this weak reference to the pointer so that
            # if this method gets called multiple times we can simply re-use
            # the same pointer which was previously created
            self.ptr = weakref.ref(ptr)

            if self._is_parameter():
                if inplace:
                    self.is_wrapper = True
                    with torch.no_grad():
                        self.set_()
                    self.data = ptr
                    output = self
                else:
                    if no_wrap:
                        raise ValueError("Parameters can't accept no_wrap=True")
                    wrapper = torch.Tensor()
                    param_wrapper = torch.nn.Parameter(wrapper)
                    param_wrapper.is_wrapper = True
                    with torch.no_grad():
                        param_wrapper.set_()
                    param_wrapper.data = ptr
                    output = param_wrapper
            else:
                if inplace:
                    self.is_wrapper = True
                    self.set_()
                    self.child = ptr
                    return self
                else:
                    output = ptr if no_wrap else ptr.wrap()

            if self.requires_grad:
                # This is for AutogradTensor to work on MultiPointerTensors
                # With pre-initialized gradients, this should get it from AutogradTensor.grad
                if preinitialize_grad:
                    grad = output.child.grad
                else:
                    grad = output.attr("grad")

                output.grad = grad

                # Because of the way PyTorch works, .grad is prone to
                # create entirely new Python objects for the tensor, which
                # inadvertently deletes our custom attributes (like .child)
                # But, if we keep a backup reference around, PyTorch seems
                # to re-use it, which means .grad keeps the attributes we
                # want it to keep. #HackAlert
                output.backup_grad = grad

            if local_autograd:
                output = syft.AutogradTensor(data=output, preinitialize_grad=preinitialize_grad).on(
                    output
                )

        else:

            children = [self.clone().send(loc, no_wrap=True) for loc in location]

            output = syft.MultiPointerTensor(children=children)

            if not no_wrap:
                output = output.wrap()

        return output
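
For context, here is a minimal usage sketch of this method under the PySyft 0.2.x-style API; the hook, worker, and tensor names are illustrative and not part of the example above:

import torch
import syft as sy

hook = sy.TorchHook(torch)              # patches torch.Tensor so .send() is available
bob = sy.VirtualWorker(hook, id="bob")  # illustrative in-process worker

x = torch.tensor([1.0, 2.0, 3.0])
x_ptr = x.send(bob)    # serializes x, stores it on bob, returns a wrapped PointerTensor
y_ptr = x_ptr + x_ptr  # the operation runs on bob; y_ptr points to the remote result
y = y_ptr.get()        # retrieves the result back to the local worker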
Example #2
    def _before_send(self, *location, user: object = None, **kwargs):
        if not self.allow(user):
            raise SendNotPermittedError()
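
A self-contained sketch of the same permission-check pattern, using only hypothetical stand-in classes (no real PySyft objects are involved):

class SendNotPermittedError(Exception):
    """Stand-in for syft.exceptions.SendNotPermittedError."""


class GuardedTensor:
    """Hypothetical object that gates sending on a per-user allow list."""

    def __init__(self, allowed_users):
        self.allowed_users = set(allowed_users)

    def allow(self, user=None):
        return user in self.allowed_users

    def _before_send(self, *location, user=None, **kwargs):
        if not self.allow(user):
            raise SendNotPermittedError()


t = GuardedTensor(allowed_users={"alice"})
t._before_send(user="alice")    # permitted, no exception raised
try:
    t._before_send(user="eve")  # not on the allow list
except SendNotPermittedError:
    print("send not permitted for this user")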