Exemple #1
0
        def trigger_origin_backward(grad):
            """
            Push ``grad`` back to the origin tensor and trigger backward there.

            Two computation messages are sent, in this order, to the worker
            resolved from ``origin``: ``set_grad`` and then ``backward``. Both
            target ``id_at_origin`` and carry ``grad`` as their only argument.

            Args:
                grad: the gradient tensor being set
            """
            origin_worker = self.owner.get_worker(origin)

            # The gradient is set first, then backward() is called, each as
            # its own message to the same worker.
            for command_name in ("set_grad", "backward"):
                command = TensorCommandMessage.computation(
                    command_name, id_at_origin, (grad,), {}, None
                )
                self.owner.send_msg(message=command, location=origin_worker)
Exemple #2
0
    async def async_send_command(
        self,
        message: tuple,
        return_ids: str = None,
        return_value: bool = False
    ) -> Union[List[PointerTensor], PointerTensor]:
        """
        Sends a command through a message to the server part attached to the client.

        Args:
            message: A tuple ``(name, target, args_, kwargs_)`` describing the
                command to execute.
            return_ids: A list of strings indicating the ids of the
                tensors that should be returned as response to the command execution.
                When omitted, a single fresh id is popped from ``sy.ID_PROVIDER``.
            return_value: Forwarded unchanged to ``TensorCommandMessage.computation``.

        Returns:
            A list of PointerTensors or a single PointerTensor if just one response is expected.
            If the remote call returns a value that is neither ``None`` nor ``bytes``,
            that value is returned as-is.

        Note: this is the async version of send_command, with the major difference that you
        directly call it on the client worker (so we don't have the recipient kw argument)
        """
        if return_ids is None:
            return_ids = (sy.ID_PROVIDER.pop(),)

        name, target, args_, kwargs_ = message

        # Build the command before touching the connection, so a failure while
        # building it cannot leave the worker disconnected.
        command = TensorCommandMessage.computation(
            name, target, args_, kwargs_, return_ids, return_value
        )

        # Close the existing websocket connection in order to open an asynchronous
        # connection. The standard connection is reopened in ``finally`` so that
        # it is restored even when sending fails with an unexpected exception.
        self.close()
        try:
            ret_val = await self.async_send_msg(command)
        except ResponseSignatureError as e:
            ret_val = None
            return_ids = e.ids_generated
        finally:
            self.connect()

        # A concrete (non-bytes) result is handed straight back to the caller.
        if ret_val is not None and not isinstance(ret_val, bytes):
            return ret_val

        # Otherwise, build local pointers to the results stored remotely.
        responses = [
            PointerTensor(
                location=self,
                id_at_location=return_id,
                owner=sy.local_worker,
                id=sy.ID_PROVIDER.pop(),
            )
            for return_id in return_ids
        ]
        if len(return_ids) == 1:
            return responses[0]
        return responses
Exemple #3
0
    def send_command(
        self,
        recipient: "BaseWorker",
        cmd_name: str,
        target: PointerTensor = None,
        args_: tuple = (),
        kwargs_: dict = None,
        return_ids: str = None,
        return_value: bool = False,
    ) -> Union[List[PointerTensor], PointerTensor]:
        """
        Sends a command through a message to a recipient worker.

        Args:
            recipient: A recipient worker.
            cmd_name: Name of the command to execute on the recipient.
            target: Target pointer Tensor.
            args_: additional args for command execution.
            kwargs_: additional kwargs for command execution. When omitted, a
                fresh empty dict is used for every call.
            return_ids: A list of strings indicating the ids of the
                tensors that should be returned as response to the command execution.
                When omitted, a single fresh id is popped from ``sy.ID_PROVIDER``.
            return_value: Forwarded unchanged to ``TensorCommandMessage.computation``.

        Returns:
            A list of PointerTensors or a single PointerTensor if just one response is expected.
            If the recipient returns a value that is neither ``None`` nor ``bytes``,
            that value is returned as-is.
        """
        # A literal ``{}`` default would be one dict shared by every call. It is
        # passed into the message, so any mutation would leak between calls.
        if kwargs_ is None:
            kwargs_ = {}
        if return_ids is None:
            return_ids = (sy.ID_PROVIDER.pop(),)

        try:
            message = TensorCommandMessage.computation(cmd_name, target, args_,
                                                       kwargs_, return_ids,
                                                       return_value)
            ret_val = self.send_msg(message, location=recipient)
        except ResponseSignatureError as e:
            ret_val = None
            return_ids = e.ids_generated

        # A concrete (non-bytes) result is handed straight back to the caller.
        if ret_val is not None and not isinstance(ret_val, bytes):
            return ret_val

        # Otherwise, build local pointers to the results stored on the recipient.
        responses = [
            PointerTensor(
                location=recipient,
                id_at_location=return_id,
                owner=self,
                id=sy.ID_PROVIDER.pop(),
            )
            for return_id in return_ids
        ]
        if len(return_ids) == 1:
            return responses[0]
        return responses
Exemple #4
0
    def register_hook(self, hook_function):
        """
        This allows to register torch hooks on remote tensors. Such an operation
        is tricky because you can't really send the hook function to a remote
        party, as python functions are not serializable within PySyft. So you
        need to keep it attached to the PointerTensor.
        On the other hand, the PointerTensor cannot watch for gradient updates
        or be triggered natively by torch when backpropagation happens. That's
        why we actually remotely set a hook that we call a callback hook, whose
        only function is to call back the pointer during backpropagation
        so that the hook function is effectively run.
        So the workflow is: the remote hook is triggered by pytorch, a message
        is sent back to the pointer owner which has the hook function, then
        the hook function is run remotely on the remote gradient, and a
        termination message is returned to the gradient owner.

        Args:
            hook_function (Callable): the function to run when the hook is
                triggered. It should be able to run on PointerTensor, other-
                wise you will get an error, which will be hard to understand
                as only the backward engine of torch will return a generic
                error.
        """
        # store the hook_function
        self._hook_function = hook_function
        # The hook function can run on tensor.grad_fn, but we always register it
        # on the tensor because we only interact remotely with tensors.
        # `self` can be a pointer to tensor.grad_fn, but we can easily retrieve
        # the pointer to the tensor by temporarily setting self.point_to_attr
        # to None. Note that the id & id_at_location are the same, so now
        # self is (temporarily) a direct reference to tensor, but self.id in
        # the message also refers to the tensor while we might need to refer
        # to the tensor.grad_fn, that's why trigger_hook_function actually
        # checks the .grad_fn attribute
        point_to_attr = self.point_to_attr
        self.point_to_attr = None
        try:
            # send a request to set a hook to trigger back the real hook
            self.owner.send_command(
                recipient=self.location,
                cmd_name="register_callback_hook",
                target=self,
                args_=(),
                kwargs_={
                    # args & kwargs are not provided, they will be filled by
                    # the remote party
                    "message":
                    TensorCommandMessage.computation("trigger_hook_function",
                                                     self.id, (), {}, None),
                    "location":
                    self.owner.id,
                },
            )
        finally:
            # Restore the attribute even if sending fails. Otherwise this
            # pointer would stay redirected to the tensor for good.
            self.point_to_attr = point_to_attr
Exemple #5
0
 def remote_send(self, destination: AbstractWorker, requires_grad: bool = False):
     """Ask the worker holding the pointed-to tensor to send it to ``destination``.

     Example: if C holds a pointer ``ptr`` to a tensor on A and calls
     ``ptr.remote_send(B)``, A moves the tensor to B. C then holds a pointer to
     a pointer on A, which in turn points to the tensor on B.

     Args:
         destination: the worker the remote value should be sent to
         requires_grad: if true, updating the grad of the remote tensor on
             destination B will trigger a message to update the gradient of
             the value on A.

     Returns:
         This pointer (``self``), unchanged.
     """
     command = TensorCommandMessage.communication(
         "remote_send",
         self,
         (destination.id,),
         {"inplace": False, "requires_grad": requires_grad},
         (self.id,),
     )
     self.owner.send_msg(message=command, location=self.location)
     return self
Exemple #6
0
    def send_command(self,
                     recipient: "BaseWorker",
                     message: tuple,
                     return_ids: str = None
                     ) -> Union[List[PointerTensor], PointerTensor]:
        """
        Sends a command through a message to a recipient worker.

        Args:
            recipient: A recipient worker.
            message: A tuple ``(name, target, args_, kwargs_)`` describing the
                command to execute.
            return_ids: A list of strings indicating the ids of the
                tensors that should be returned as response to the command execution.
                When omitted, a single fresh id is popped from ``sy.ID_PROVIDER``.

        Returns:
            A list of PointerTensors or a single PointerTensor if just one response is expected.
            If the recipient returns a value that is neither ``None`` nor ``bytes``,
            that value is returned as-is.
        """
        if return_ids is None:
            return_ids = (sy.ID_PROVIDER.pop(),)

        name, target, args_, kwargs_ = message

        try:
            # Use a separate name so the ``message`` argument is not shadowed.
            command = TensorCommandMessage.computation(name, target, args_,
                                                       kwargs_, return_ids)
            ret_val = self.send_msg(command, location=recipient)
        except ResponseSignatureError as e:
            ret_val = None
            return_ids = e.ids_generated

        # A concrete (non-bytes) result is handed straight back to the caller.
        if ret_val is not None and not isinstance(ret_val, bytes):
            return ret_val

        # Otherwise, build local pointers to the results stored on the recipient.
        responses = [
            PointerTensor(
                location=recipient,
                id_at_location=return_id,
                owner=self,
                id=sy.ID_PROVIDER.pop(),
            )
            for return_id in return_ids
        ]
        if len(return_ids) == 1:
            return responses[0]
        return responses