Exemplo n.º 1
0
    def execute_communication_action(self, action: CommunicationAction) -> PointerTensor:
        """Send a locally stored object to the workers listed in a communication action.

        Only the worker that owns the action's target executes it; for any
        other worker the action is ignored and nothing else is evaluated.

        Args:
            action: The communication action. ``action.target`` identifies the
                object to send (its ``owner`` and ``id`` are used),
                ``action.args`` holds the ids of the destination workers, and
                ``action.kwargs`` is forwarded to ``send``.

        Returns:
            The response of ``owner.send`` (passed through
            ``hook_args.register_response`` when ``kwargs["requires_grad"]`` is
            truthy), or ``None`` when this worker does not own the target.
        """
        owner = action.target.owner
        if owner != self.worker:
            # Not ours to execute: skip resolving destination workers entirely.
            return None

        destinations = [self.worker.get_worker(id_) for id_ in action.args]
        kwargs_ = action.kwargs

        obj = self.get_obj(action.target.id)
        response = owner.send(obj, *destinations, **kwargs_)
        # NOTE(review): presumably disables garbage collection of the sent data
        # through this response object so dropping it does not delete the remote
        # copy - confirm against the pointer implementation.
        response.garbage_collect_data = False
        if kwargs_.get("requires_grad", False):
            # Register the response under a freshly popped id from the global ID provider.
            response = hook_args.register_response(
                "send", response, [sy.ID_PROVIDER.pop()], self.worker
            )
        else:
            # Without requires_grad the local copy is removed from the object store.
            self.object_store.rm_obj(action.target.id)
        return response
Exemplo n.º 2
0
 def execute_communication_action(self, action: CommunicationAction) -> PointerTensor:
     """Forward the object referenced by ``action`` to its destination workers.

     The action is ignored unless this worker is the action's source.

     Args:
         action: Communication action carrying ``obj_id``, ``source``,
             ``destinations`` and ``kwargs`` (the latter forwarded to ``send``).

     Returns:
         The result of ``send``, wrapped by ``hook_args.register_response``
         when ``kwargs["requires_grad"]`` is truthy, or ``None`` when this
         worker is not the source.
     """
     target_id, origin, recipients, options = (
         action.obj_id,
         action.source,
         action.destinations,
         action.kwargs,
     )
     sender = self.get_worker(origin)
     if sender != self:
         return None

     payload = self.get_obj(target_id)
     result = sender.send(payload, *recipients, **options)
     result.garbage_collect_data = False
     if not options.get("requires_grad", False):
         # No gradient tracking requested: drop the local copy after sending.
         self.rm_obj(target_id)
         return result
     return hook_args.register_response("send", result, [sy.ID_PROVIDER.pop()], self)
Exemplo n.º 3
0
    def execute_communication_action(
            self, action: CommunicationAction) -> PointerTensor:
        """Send the object named by ``action.obj_id`` to ``action.destinations``.

        Only the source worker acts on the action; any other worker returns
        ``None``. After sending, the response is registered with ``hook_args``
        under a freshly popped id and the local object is removed.

        Args:
            action: Communication action providing ``obj_id``, ``source``,
                ``destinations`` and ``kwargs`` (forwarded to ``send``).

        Returns:
            The registered response, or ``None`` if this worker is not the
            source.
        """
        target_id, origin, recipients, options = (
            action.obj_id,
            action.source,
            action.destinations,
            action.kwargs,
        )
        sender = self.get_worker(origin)
        if sender != self:
            return None

        payload = self.get_obj(target_id)
        sent = sender.send(payload, *recipients, **options)
        registered = hook_args.register_response(
            "send", sent, [sy.ID_PROVIDER.pop()], self
        )
        # The local copy is always dropped once it has been sent.
        self.rm_obj(target_id)
        return registered
Exemplo n.º 4
0
    def execute_communication_action(
            self, action: CommunicationAction) -> PointerTensor:
        """Send the object named by ``action.obj_id`` to ``action.destinations``.

        Only the source worker acts on the action; any other worker returns
        ``None``. The response is registered with ``hook_args`` under a freshly
        popped id. The local object is removed only when ``kwargs["inplace"]``
        is truthy.

        Args:
            action: Communication action providing ``obj_id``, ``source``,
                ``destinations`` and ``kwargs`` (forwarded to ``send``).

        Returns:
            The registered response, or ``None`` if this worker is not the
            source.
        """
        target_id, origin, recipients, options = (
            action.obj_id,
            action.source,
            action.destinations,
            action.kwargs,
        )
        sender = self.get_worker(origin)
        if sender != self:
            return None

        payload = self.get_obj(target_id)
        sent = sender.send(payload, *recipients, **options)
        registered = hook_args.register_response(
            "send", sent, [sy.ID_PROVIDER.pop()], self
        )

        # @lariffle: remote objects are only removed for in-place operations;
        # removing them otherwise could leave stale pointers, which we really
        # want to avoid.
        # TODO: needs more discussion
        if options.get("inplace"):
            self.rm_obj(target_id)
        return registered
Exemplo n.º 5
0
    def execute_command(self, message: tuple) -> PointerTensor:
        """
        Executes commands received from other workers.

        Args:
            message: A tuple of the form
                ``((command_name, _self, args, kwargs), return_ids)``.
                When ``_self`` is not ``None`` the command is a method call on
                ``_self``: an int is treated as an object id and the string
                ``"self"`` means this worker. When ``_self`` is ``None``,
                ``command_name`` is a dotted function path resolved by
                attribute lookup starting from this worker.

        Returns:
            The response passed through ``hook_args.register_response``. Returns
            ``None`` for in-place methods, for an object id that resolves to
            ``None``, and for commands that return ``None``.

        Raises:
            ResponseSignatureError: re-raised with the ids recorded by the ID
                provider when the response does not fit ``return_ids``.
        """
        (command_name, _self, args, kwargs), return_ids = message

        if _self is not None:
            # Handle methods.
            if type(_self) == int:
                # An int is looked up as an object id (explicitly via
                # BaseWorker.get_obj, presumably to bypass subclass overrides).
                _self = BaseWorker.get_obj(self, _self)
                if _self is None:
                    return
            if type(_self) == str and _self == "self":
                _self = self
            if sy.framework.is_inplace_method(command_name):
                # TODO[jvmancuso]: figure out a good way to generalize the
                # above check (#2530)
                getattr(_self, command_name)(*args, **kwargs)
                return
            try:
                response = getattr(_self, command_name)(*args, **kwargs)
            except TypeError:
                # TODO Andrew thinks this is gross, please fix. Instead need to properly deserialize strings
                new_args = [
                    arg.decode("utf-8") if isinstance(arg, bytes) else arg
                    for arg in args
                ]
                response = getattr(_self, command_name)(*new_args, **kwargs)
        else:
            # Handle functions. At this point, the command is ALWAYS a path to a
            # function (i.e., torch.nn.functional.relu). Thus, we need to fetch
            # this function and run it.
            sy.framework.command_guard(command_name)

            command = self
            for path in command_name.split("."):
                command = getattr(command, path)

            response = command(*args, **kwargs)

        # some functions don't return anything (such as .backward())
        # so we need to check for that here.
        if response is None:
            return None

        # Register response and create pointers for tensor elements
        try:
            return hook_args.register_response(command_name, response,
                                               list(return_ids), self)
        except ResponseSignatureError:
            # Re-register while recording the ids actually consumed, then
            # report them to the caller through the exception.
            return_id_provider = sy.ID_PROVIDER
            return_id_provider.set_next_ids(return_ids, check_ids=False)
            return_id_provider.start_recording_ids()
            hook_args.register_response(command_name, response,
                                        return_id_provider, self)
            new_ids = return_id_provider.get_recorded_ids()
            raise ResponseSignatureError(new_ids)
Exemplo n.º 6
0
    def execute_computation_action(self, action: ComputationAction) -> PointerTensor:
        """
        Executes a computation action received from another worker.

        Args:
            action: The computation action. Its fields are used as follows:
                ``name`` is the method name or dotted function path to run.
                ``target`` is ``None`` for a function call. Otherwise it is an
                int object id, the string ``"self"`` (this worker), another
                string looked up with ``self.worker.search``, or an object.
                ``args`` and ``kwargs`` are the call arguments.
                ``return_ids`` are the ids used to register the response.
                ``return_value`` controls whether the response is returned.

        Returns:
            The registered response when ``return_value`` is truthy or the
            response is an int, float, bool or str; otherwise None. Also
            returns None for in-place methods, for an int target that resolves
            to None, and for calls that return None.

        Raises:
            ValueError: if a string target matches zero or several objects.
            ResponseSignatureError: re-raised with the ids recorded by the ID
                provider when the response does not fit ``return_ids``.
        """

        op_name = action.name
        _self = action.target
        args_ = action.args
        kwargs_ = action.kwargs
        return_ids = action.return_ids
        return_value = action.return_value

        if _self is not None:
            # Handle methods.
            if type(_self) == int:
                _self = self.get_obj(_self)
                if _self is None:
                    return
            elif isinstance(_self, str):
                if _self == "self":
                    _self = self.worker
                else:
                    res: list = self.worker.search(_self)
                    if len(res) != 1:
                        raise ValueError(
                            f"Searching for {_self} on {self.worker.id}. /!\\ {len(res)} found"
                        )
                    _self = res[0]
            if sy.framework.is_inplace_method(op_name):
                # TODO[jvmancuso]: figure out a good way to generalize the
                # above check (#2530)
                getattr(_self, op_name)(*args_, **kwargs_)
                return
            try:
                response = getattr(_self, op_name)(*args_, **kwargs_)
            except TypeError:
                # TODO Andrew thinks this is gross, please fix. Instead need to
                # properly deserialize strings
                new_args = [
                    arg.decode("utf-8") if isinstance(arg, bytes) else arg for arg in args_
                ]
                response = getattr(_self, op_name)(*new_args, **kwargs_)
        else:
            # Handle functions. At this point, the command is ALWAYS a path to a
            # function (i.e., torch.nn.functional.relu). Thus, we need to fetch
            # this function and run it.
            sy.framework.command_guard(op_name)

            command = self.worker
            for path in op_name.split("."):
                command = getattr(command, path)

            response = command(*args_, **kwargs_)

        # some functions don't return anything (such as .backward())
        # so we need to check for that here.
        if response is None:
            return None

        # Register response and create pointers for tensor elements
        try:
            response = hook_args.register_response(
                op_name, response, list(return_ids), self.worker
            )
        except ResponseSignatureError:
            # Re-register while recording the ids actually consumed, then
            # report them to the caller through the exception.
            return_id_provider = sy.ID_PROVIDER
            return_id_provider.set_next_ids(return_ids, check_ids=False)
            return_id_provider.start_recording_ids()
            hook_args.register_response(op_name, response, return_id_provider, self.worker)
            new_ids = return_id_provider.get_recorded_ids()
            raise ResponseSignatureError(new_ids)

        # TODO: Does this mean I can set return_value to False and still
        # get a response? That seems surprising.
        if return_value or isinstance(response, (int, float, bool, str)):
            return response
        return None