Exemple #1
0
class BaseWorker(AbstractWorker):
    """Contains functionality to all workers.

    Other workers will extend this class to inherit all functionality necessary
    for PySyft's protocol. Extensions of this class override two key methods
    _send_msg() and _recv_msg() which are responsible for defining the
    procedure for sending a binary message to another worker.

    At its core, BaseWorker (and all workers) is a collection of objects owned
    by a certain machine. Each worker defines how it interacts with objects on
    other workers as well as how other workers interact with objects owned by
    itself. Objects are either tensors or of any type supported by the PySyft
    protocol.

    Args:
        hook: A reference to the TorchHook object which is used
            to modify PyTorch with PySyft's functionality.
        id: An optional string or integer unique id of the worker.
        known_workers: An optional dictionary of all known workers on a
            network which this worker may need to communicate with in the
            future. The key of each should be each worker's unique ID and
            the value should be a worker class which extends BaseWorker.
            Extensions of BaseWorker will include advanced functionality
            for adding to this dictionary(node discovery). In some cases,
            one can initialize this with known workers to help bootstrap
            the network.
        data: Initialize workers with data on creating worker object
        is_client_worker: An optional boolean parameter to indicate
            whether this worker is associated with an end user client. If
            so, it assumes that the client will maintain control over when
            variables are instantiated or deleted as opposed to handling
            tensor/variable/model lifecycle internally. Set to True if this
            object is not where the objects will be stored, but is instead
            a pointer to a worker that exists elsewhere.
        log_msgs: An optional boolean parameter to indicate whether all
            messages should be saved into a log for later review. This is
            primarily a development/testing feature.
        auto_add: Determines whether to automatically add this worker to the
            list of known workers.
        message_pending_time (optional): A number of seconds to delay the messages to be sent.
            The argument may be a floating point number for subsecond
            precision.
    """
    def __init__(
        self,
        hook: "FrameworkHook",
        id: Union[int, str] = 0,
        data: Union[List, tuple] = None,
        is_client_worker: bool = False,
        log_msgs: bool = False,
        verbose: bool = None,
        auto_add: bool = True,
        message_pending_time: Union[int, float] = 0,
    ):
        """Initializes a BaseWorker.

        See the class docstring for most arguments. ``verbose`` controls
        whether sent/received messages are printed; when it is None the value
        is taken from ``hook.verbose`` if the hook has that attribute, and
        defaults to False otherwise.
        """
        super().__init__()
        self.hook = hook

        self.object_store = ObjectStore(owner=self)

        self.id = id
        self.is_client_worker = is_client_worker
        self.log_msgs = log_msgs
        if verbose is None:
            self.verbose = hook.verbose if hasattr(hook, "verbose") else False
        else:
            self.verbose = verbose

        if isinstance(hook, sy.TorchHook) and hasattr(hook, "_syft_workers"):
            hook._syft_workers.add(self)

        self.auto_add = auto_add
        self._message_pending_time = message_pending_time
        self.msg_history = list()

        # For performance, we cache all possible message types.
        # FIXME: there is no (non-forced) ObjectDeleteMessage yet, so
        # handle_delete_object_msg is intentionally not routed here;
        # ForceObjectDeleteMessage maps to the force-delete handler only.
        self._message_router = {
            TensorCommandMessage: self.execute_tensor_command,
            PlanCommandMessage: self.execute_plan_command,
            WorkerCommandMessage: self.execute_worker_command,
            ObjectMessage: self.handle_object_msg,
            ObjectRequestMessage: self.respond_to_obj_req,
            ForceObjectDeleteMessage: self.handle_force_delete_object_msg,
            IsNoneMessage: self.is_object_none,
            GetShapeMessage: self.handle_get_shape_message,
            SearchMessage: self.respond_to_search,
        }

        self._plan_command_router = {
            codes.PLAN_CMDS.FETCH_PLAN: self._fetch_plan_remote,
            codes.PLAN_CMDS.FETCH_PROTOCOL: self._fetch_protocol_remote,
        }

        self.load_data(data)

        # Declare workers as appropriate
        self._known_workers = {}
        if auto_add:
            if hook is not None and hook.local_worker is not None:
                known_workers = self.hook.local_worker._known_workers
                if self.id in known_workers:
                    if isinstance(known_workers[self.id], type(self)):
                        # If a worker with this id already exists and it has the
                        # same type as the one being created, we copy all the attributes
                        # of the existing worker to this one.
                        self.__dict__.update(known_workers[self.id].__dict__)
                    else:
                        raise RuntimeError(
                            "Worker initialized with the same id and different types."
                        )
                else:
                    hook.local_worker.add_worker(self)
                    # Iterate over a snapshot: add_worker() on the peers below
                    # mutates worker registries while we loop.
                    for worker_id, worker in list(
                        hook.local_worker._known_workers.items()
                    ):
                        if worker_id not in self._known_workers:
                            self.add_worker(worker)
                        if self.id not in worker._known_workers:
                            worker.add_worker(self)
            else:
                # Make the local worker aware of itself
                # self is the to-be-created local worker
                self.add_worker(self)

        if hook is None:
            self.framework = None
        else:
            # TODO[jvmancuso]: avoid branching here if possible, maybe by changing code in
            #     execute_tensor_command or command_guard to not expect an attribute named "torch"
            #     (#2530)
            self.framework = hook.framework
            if hasattr(hook, "torch"):
                self.torch = self.framework
                self.remote = Remote(self, "torch")
            elif hasattr(hook, "tensorflow"):
                self.tensorflow = self.framework
                self.remote = Remote(self, "tensorflow")

        # storage object for crypto primitives
        self.crypto_store = PrimitiveStorage(owner=self)

    # register_obj(obj, obj_id=None) is defined once, further down in this
    # class; a shadowed duplicate that passed ``self`` as the object was removed.

    def clear_objects(self, return_self: bool = True):
        """Empty this worker's object store.

        Returning ``self`` is kept only for backward compatibility with the
        Udacity Secure and Private ML course.

        Args:
            return_self: when truthy (the default) the worker itself is
                returned; otherwise None is returned.

        Returns:
            This worker if ``return_self`` is truthy, else None.
        """
        self.object_store.clear_objects()

        # The Udacity course relies on getting the worker back.
        if not return_self:
            return None
        return self

    @contextmanager
    def registration_enabled(self):
        """Temporarily let this worker register objects.

        ``register_obj`` and ``de_register_obj`` do nothing while
        ``is_client_worker`` is True. This context manager clears that flag
        for the duration of the ``with`` block and restores the previous
        value on exit, even if the block raises.

        Yields:
            This worker.
        """
        previous = self.is_client_worker
        self.is_client_worker = False
        try:
            yield self
        finally:
            # Restore rather than force True: a worker that was not a client
            # worker must not become one just by entering this context.
            self.is_client_worker = previous

    def remove_worker_from_registry(self, worker_id):
        """Forget a worker known to this worker.

        Args:
            worker_id: id of the worker to drop from the known workers.

        Raises:
            KeyError: if no worker with ``worker_id`` is known.
        """
        self._known_workers.pop(worker_id)

    def remove_worker_from_local_worker_registry(self):
        """Drop this worker from the known workers of ``hook.local_worker``."""
        local_worker = self.hook.local_worker
        local_worker.remove_worker_from_registry(worker_id=self.id)

    def load_data(
        self, data: List[Union[FrameworkTensorType, AbstractTensor]]
    ) -> None:
        """Register initial data with this worker.

        Each tensor in ``data`` is passed to ``register_obj`` and then has its
        ``owner`` set to this worker. Nothing happens when ``data`` is empty
        or None.

        Args:
            data: A list of tensors.
        """
        if not data:
            return
        for item in data:
            self.register_obj(item)
            item.owner = self

    def send_msg(self, message: Message, location: "BaseWorker") -> object:
        """Serialize a message, send it to a worker and return the reply.

        Every outgoing message goes through this method. The message is
        serialized, handed to ``_send_msg`` (implemented by subclasses), and
        the binary reply is deserialized before being returned.

        Args:
            message: A Message object.
            location: A BaseWorker instance that lets you provide the
                destination to send the message.

        Returns:
            The deserialized response from the worker at ``location``.
        """
        if self.verbose:
            print(f"worker {self} sending {message} to {location}")

        # Step 1: serialize the message to a binary
        bin_message = sy.serde.serialize(message, worker=self)

        # Step 2: send the message and wait for a response
        bin_response = self._send_msg(bin_message, location)

        # Step 3: deserialize the response
        response = sy.serde.deserialize(bin_response, worker=self)

        return response

    def recv_msg(self, bin_message: bytes) -> bytes:
        """Handle an incoming binary message and return a binary reply.

        Every incoming message goes through this method. The message is
        deserialized, optionally logged, dispatched through
        ``self._message_router`` by its exact type, and the handler's return
        value is serialized as the reply.

        Args:
            bin_message: A binary serialized message.

        Returns:
            The binary serialized response.

        Raises:
            KeyError: if the message type has no entry in the router.
        """
        # Step 0: deserialize message
        msg = sy.serde.deserialize(bin_message, worker=self)

        # Step 1: save message and/or log it out
        if self.log_msgs:
            self.msg_history.append(msg)

        if self.verbose:
            print(f"worker {self} received {type(msg).__name__} {msg.contents}"
                  if hasattr(msg, "contents") else
                  f"worker {self} received {type(msg).__name__}")

        # Step 2: route message to appropriate function
        response = self._message_router[type(msg)](msg)

        # Step 3: Serialize the message to simple python objects
        bin_response = sy.serde.serialize(response, worker=self)

        return bin_response

    # SECTION: recv_msg() uses self._message_router to route to these methods

    def send(
        self,
        obj: Union[FrameworkTensorType, AbstractTensor],
        workers: "BaseWorker",
        ptr_id: Union[str, int] = None,
        garbage_collect_data=None,
        requires_grad=False,
        create_pointer=True,
        **kwargs,
    ) -> ObjectPointer:
        """Sends tensor to the worker(s).

        Send a syft or torch tensor/object and its child, sub-child, etc (all the
        syft chain of children) to a worker, or a list of workers, with a given
        remote storage address.

        Args:
            obj: A syft/framework tensor/object to send.
            workers: A BaseWorker object (or a one-element list/tuple of them)
                representing the worker that will receive the object.
            ptr_id: An optional string or integer indicating the remote id of
                the object on the remote worker(s).
            garbage_collect_data: argument passed down to create_pointer()
            requires_grad: Default to False. If true, whenever the remote value of this tensor
                will have its gradient updated (for example when calling .backward()), a call
                will be made to set back the local gradient value.
            create_pointer: if set to False, no pointer to the remote value will be built.
            **kwargs: extra keyword arguments forwarded to create_pointer().

        Example:
            >>> import torch
            >>> import syft as sy
            >>> hook = sy.TorchHook(torch)
            >>> bob = sy.VirtualWorker(hook)
            >>> x = torch.Tensor([1, 2, 3, 4])
            >>> x.send(bob, 1000)
            Will result in bob having the tensor x with id 1000

        Returns:
            A pointer to the remote object; ``obj`` itself if it cannot create
            pointers (or is a Protocol); None if ``create_pointer`` is False.

        Raises:
            NotImplementedError: if more than one worker is given.
        """
        if not isinstance(workers, (list, tuple)):
            workers = [workers]

        assert len(workers) > 0, "Please provide workers to receive the data"

        if len(workers) > 1:
            # If multiple workers are provided, you want to send the same tensor
            # to all the workers. You'll get multiple pointers, or a pointer
            # with different locations.
            raise NotImplementedError(
                "Sending to multiple workers is not supported at the moment"
            )

        worker = self.get_worker(workers[0])

        if requires_grad:
            # Tag the object so the receiver can route gradients back here.
            obj.origin = self.id
            obj.id_at_origin = obj.id

        try:
            # Send the object
            self.send_obj(obj, worker)
        finally:
            # Always clear the temporary tags, even if sending failed.
            if requires_grad:
                obj.origin = None
                obj.id_at_origin = None

        # If we don't need to create the pointer
        if not create_pointer:
            return None

        # Create the pointer if needed
        if hasattr(obj, "create_pointer") and not isinstance(
                obj,
                sy.Protocol):  # TODO: this seems like hack to check a type
            if ptr_id is None:  # Define a remote id if not specified
                ptr_id = sy.ID_PROVIDER.pop()

            pointer = type(obj).create_pointer(
                obj,
                owner=self,
                location=worker,
                id_at_location=obj.id,
                register=True,
                ptr_id=ptr_id,
                garbage_collect_data=garbage_collect_data,
                **kwargs,
            )
        else:
            pointer = obj

        return pointer

    def handle_object_msg(self, obj_msg: ObjectMessage):
        """Store an object received from another worker.

        If the object is a framework tensor that requires grad and carries
        ``origin``/``id_at_origin`` (set by the sender's ``send`` when
        ``requires_grad`` is used), a backward hook is registered on it so
        that gradient updates are reported back to the origin worker.

        Args:
            obj_msg: ObjectMessage whose ``object`` is a framework tensor or a
                subclass of AbstractTensor with an id.
        """
        # This should be a good seam for separating Workers from ObjectStore (someday),
        # so that Workers have ObjectStores instead of being ObjectStores. That would open
        # up the possibility of having a separate ObjectStore for each user, or for each
        # Plan/Protocol, etc. As Syft moves toward multi-tenancy with Grid and so forth,
        # that will probably be useful for providing security and permissioning. In that
        # future, this might look like `self.object_store.set_obj(obj_msg.object)`
        obj = obj_msg.object

        self.object_store.set_obj(obj)

        if isinstance(obj, FrameworkTensor):
            tensor = obj
            if (tensor.requires_grad and tensor.origin is not None
                    and tensor.id_at_origin is not None):
                tensor.register_hook(
                    tensor.trigger_origin_backward_hook(
                        tensor.origin, tensor.id_at_origin))

    def handle_delete_object_msg(self, msg: ForceObjectDeleteMessage):
        """Remove ``msg.object_id`` from the object store (non-forced removal).

        NOTE: not reachable through recv_msg at the moment, because there is no
        non-forced ObjectDeleteMessage to route to it.
        """
        target_id = msg.object_id
        self.object_store.rm_obj(target_id)

    def handle_force_delete_object_msg(self, msg: ForceObjectDeleteMessage):
        """Forcefully remove ``msg.object_id`` from the object store."""
        target_id = msg.object_id
        self.object_store.force_rm_obj(target_id)

    def execute_tensor_command(self,
                               cmd: TensorCommandMessage) -> PointerTensor:
        """Dispatch a tensor command to the matching executor.

        Computation actions go to ``execute_computation_action``; any other
        action is treated as a communication action.
        """
        action = cmd.action
        if not isinstance(action, ComputationAction):
            return self.execute_communication_action(action)
        return self.execute_computation_action(action)

    def execute_computation_action(self,
                                   action: ComputationAction) -> PointerTensor:
        """Execute a computation action received from another worker.

        If ``action.target`` is set, ``action.name`` is called as a method on
        that target. The target may be an int object id (looked up in this
        worker's registry), the string "self" (this worker), or another
        string (resolved with ``self.search``, which must yield exactly one
        result). Otherwise ``action.name`` is a dotted path resolved as
        attributes starting from this worker (e.g. "torch.nn.functional.relu")
        and called as a function.

        Args:
            action: The ComputationAction carrying the op name, target, args,
                kwargs, return ids and the ``return_value`` flag.

        Returns:
            The registered response when ``return_value`` is True or the
            registered response is a plain int/float/bool/str; otherwise None.
            Also None for in-place methods, for an int target that is not
            found, and when the op itself returns None.

        Raises:
            ResponseSignatureError: when registering the response against
                ``return_ids`` fails; re-raised carrying the ids recorded while
                re-registering the response.
        """

        op_name = action.name
        _self = action.target
        args_ = action.args
        kwargs_ = action.kwargs
        return_ids = action.return_ids
        return_value = action.return_value

        # Handle methods
        if _self is not None:
            if type(_self) == int:
                # Target is an object id. BaseWorker.get_obj is called
                # explicitly, bypassing any subclass override of get_obj.
                _self = BaseWorker.get_obj(self, _self)
                if _self is None:
                    return
            elif isinstance(_self, str):
                if _self == "self":
                    _self = self
                else:
                    res: list = self.search(_self)
                    assert (
                        len(res) == 1
                    ), f"Searching for {_self} on {self.id}. /!\\ {len(res)} found"
                    _self = res[0]
            if sy.framework.is_inplace_method(op_name):
                # TODO[jvmancuso]: figure out a good way to generalize the
                # above check (#2530)
                # In-place methods: the result is discarded and None returned.
                getattr(_self, op_name)(*args_, **kwargs_)
                return
            else:
                try:
                    response = getattr(_self, op_name)(*args_, **kwargs_)
                except TypeError:
                    # Retry once with bytes arguments decoded as UTF-8 strings.
                    # TODO Andrew thinks this is gross, please fix. Instead need to properly deserialize strings
                    new_args = [
                        arg.decode("utf-8") if isinstance(arg, bytes) else arg
                        for arg in args_
                    ]
                    response = getattr(_self, op_name)(*new_args, **kwargs_)
        # Handle functions
        else:
            # At this point, the command is ALWAYS a path to a
            # function (i.e., torch.nn.functional.relu). Thus,
            # we need to fetch this function and run it.

            sy.framework.command_guard(op_name)

            # Resolve the dotted path attribute by attribute, starting from
            # this worker (e.g. self.torch, which __init__ sets from the hook).
            paths = op_name.split(".")
            command = self
            for path in paths:
                command = getattr(command, path)

            response = command(*args_, **kwargs_)

        # some functions don't return anything (such as .backward())
        # so we need to check for that here.
        if response is not None:
            # Register response and create pointers for tensor elements
            try:
                response = hook_args.register_response(op_name, response,
                                                       list(return_ids), self)
                # TODO: Does this mean I can set return_value to False and still get a response? That seems surprising.
                if return_value or isinstance(response,
                                              (int, float, bool, str)):
                    return response
                else:
                    return None
            except ResponseSignatureError:
                # The response did not fit return_ids. Re-register it through
                # the ID provider, seeded with return_ids and recording every id
                # it hands out, then report those ids back to the caller.
                return_id_provider = sy.ID_PROVIDER
                return_id_provider.set_next_ids(return_ids, check_ids=False)
                return_id_provider.start_recording_ids()
                response = hook_args.register_response(op_name, response,
                                                       return_id_provider,
                                                       self)
                new_ids = return_id_provider.get_recorded_ids()
                raise ResponseSignatureError(new_ids)

    def execute_communication_action(
            self, action: CommunicationAction) -> PointerTensor:
        """Forward a locally owned object to other workers.

        Destination workers (``action.args``) are resolved first. If this
        worker does not own ``action.target``, None is returned. Otherwise the
        target object is sent with ``action.kwargs`` and the resulting pointer
        is marked as not garbage-collecting its remote data. When
        ``requires_grad`` is set the pointer is registered under a fresh id;
        otherwise the local copy of the target is removed from the store.
        """
        target = action.target
        owner = target.owner
        destinations = [self.get_worker(dest_id) for dest_id in action.args]
        options = action.kwargs

        if owner != self:
            return None

        pointer = owner.send(self.get_obj(target.id), *destinations, **options)
        pointer.garbage_collect_data = False

        if not options.get("requires_grad", False):
            self.object_store.rm_obj(target.id)
            return pointer

        return hook_args.register_response("send", pointer,
                                           [sy.ID_PROVIDER.pop()], self)

    def execute_worker_command(self, message: "WorkerCommandMessage"):
        """Execute a worker-level command received from another worker.

        ``message.command_name`` names a method of this worker, and
        ``message.message`` is an ``(args, kwargs, return_ids)`` triple.

        Args:
            message: The WorkerCommandMessage to execute.

        Returns:
            The command's return value. If that value is a framework tensor,
            it is registered locally under ``return_ids[0]`` and None is
            returned instead.
        """
        command_name = message.command_name
        args_, kwargs_, return_ids = message.message

        response = getattr(self, command_name)(*args_, **kwargs_)
        #  TODO [midokura-silvia]: send the tensor directly
        #  TODO this code is currently necessary for the async_fit method in websocket_client.py
        if isinstance(response, FrameworkTensor):
            self.register_obj(obj=response, obj_id=return_ids[0])
            return None
        return response

    def execute_plan_command(self, msg: PlanCommandMessage):
        """Run a plan-related command through the plan command router.

        All plan commands share this one message type instead of each having
        a dedicated message class.

        Args:
            msg: A PlanCommandMessage carrying ``command_name`` and ``args``.

        Returns:
            Whatever the routed handler returns.

        Raises:
            PlanCommandUnknownError: if ``command_name`` has no handler.
        """
        name, call_args = msg.command_name, msg.args
        router = self._plan_command_router

        try:
            handler = router[name]
        except KeyError:
            raise PlanCommandUnknownError(name)

        return handler(*call_args)

    def send_command(
        self,
        recipient: "BaseWorker",
        cmd_name: str,
        target: PointerTensor = None,
        args_: tuple = (),
        kwargs_: dict = None,
        return_ids: Union[List, tuple] = None,
        return_value: bool = False,
    ) -> Union[List[PointerTensor], PointerTensor]:
        """Send a computation command to a recipient worker.

        Args:
            recipient: The worker that should execute the command.
            cmd_name: Name of the command (a method name, or a dotted function
                path when ``target`` is None).
            target: Target pointer tensor, or None for a function call.
            args_: Positional args for command execution.
            kwargs_: Keyword args for command execution; None means no kwargs.
            return_ids: Ids under which the recipient should store the
                results. A single fresh id is generated when None.
            return_value: Forwarded in the message. It asks the executing
                worker to return the result itself rather than only store it.

        Returns:
            PointerTensors to the recipient's results: a list, or a single
            PointerTensor when exactly one return id is used. If the recipient
            replied with a concrete value (anything other than None or
            bytes), that value is returned as-is.
        """
        if kwargs_ is None:
            # Fresh dict per call instead of a shared mutable default.
            kwargs_ = {}
        if return_ids is None:
            return_ids = tuple([sy.ID_PROVIDER.pop()])

        try:
            message = TensorCommandMessage.computation(cmd_name, target, args_,
                                                       kwargs_, return_ids,
                                                       return_value)
            ret_val = self.send_msg(message, location=recipient)
        except ResponseSignatureError as e:
            # The recipient used different ids than requested; point at those.
            ret_val = None
            return_ids = e.ids_generated

        if ret_val is None or type(ret_val) == bytes:
            responses = []
            for return_id in return_ids:
                response = PointerTensor(
                    location=recipient,
                    id_at_location=return_id,
                    owner=self,
                    id=sy.ID_PROVIDER.pop(),
                )
                responses.append(response)

            if len(return_ids) == 1:
                responses = responses[0]
        else:
            responses = ret_val
        return responses

    def get_obj(self, obj_id: Union[str, int]) -> object:
        """Fetch an object from this worker's registry by id.

        As a side effect, if the object has a ``child`` that supports
        ``set_garbage_collect_data``, that flag is switched off.

        Args:
            obj_id: A string or integer id of an object to look up.

        Returns:
            The stored object, or None when it has a truthy ``private``
            attribute.
        """
        found = self.object_store.get_obj(obj_id)

        # Objects fetched here are very likely to be serialized and sent back,
        # after which the local copy gets garbage collected; that must not
        # trigger deletion of the remote data it points to.
        if hasattr(found, "child") and hasattr(found.child,
                                               "set_garbage_collect_data"):
            found.child.set_garbage_collect_data(value=False)

        return None if getattr(found, "private", False) else found

    def respond_to_obj_req(self, msg: ObjectRequestMessage):
        """Hand a stored object to a requester and de-register it locally.

        Args:
            msg: ObjectRequestMessage carrying ``object_id`` and ``user``.
                ``msg.reason`` is part of the message but is not consulted
                here.

        Returns:
            The requested object, after it has been de-registered from this
            worker.

        Raises:
            GetNotPermittedError: if the object defines ``allow`` and it
                rejects ``msg.user``.
        """
        obj = self.get_obj(msg.object_id)
        if hasattr(obj, "allow") and not obj.allow(msg.user):
            raise GetNotPermittedError()

        self.de_register_obj(obj)
        return obj

    def register_obj(self, obj: object, obj_id: Union[str, int] = None):
        """Store ``obj`` in this worker's object store.

        This is a no-op while this worker is a client worker
        (``is_client_worker``). It is generally called by internal machinery
        (hooks and workers) rather than by end users.

        Args:
            obj: The object (e.g. a tensor) to register.
            obj_id: Optional string or integer id to register it under.
        """
        if self.is_client_worker:
            return
        self.object_store.register_obj(obj, obj_id=obj_id)

    def de_register_obj(self, obj: object, _recurse_torch_objs: bool = True):
        """Remove ``obj`` from this worker's object store.

        This is a no-op while this worker is a client worker.

        Args:
            obj: the object to deregister.
            _recurse_torch_objs: whether the object is more complex and needs
                to be explored (forwarded to the object store).
        """
        if self.is_client_worker:
            return
        self.object_store.de_register_obj(obj, _recurse_torch_objs)

    # SECTION: convenience methods for constructing frequently used messages

    def send_obj(self, obj: object, location: "BaseWorker"):
        """Wrap ``obj`` in an ObjectMessage and send it to ``location``.

        Args:
            obj: The object (e.g. a torch tensor) to send.
            location: The BaseWorker that should receive the object.

        Returns:
            The deserialized reply from ``location``.
        """
        outgoing = ObjectMessage(obj)
        return self.send_msg(outgoing, location)

    def request_obj(self,
                    obj_id: Union[str, int],
                    location: "BaseWorker",
                    user=None,
                    reason: str = "") -> object:
        """Ask ``location`` for the object stored under ``obj_id``.

        Args:
            obj_id: A string or integer id of the object to look up.
            location: The BaseWorker holding the object.
            user: Optional credentials used by the remote side for
                permission checks.
            reason: Optional description of why the object is requested.

        Returns:
            The deserialized object sent back by ``location``.
        """
        request = ObjectRequestMessage(obj_id, user, reason)
        return self.send_msg(request, location)

    # SECTION: Manage the workers network

    def get_worker(self,
                   id_or_worker: Union[str, int, "BaseWorker"],
                   fail_hard: bool = False) -> Union[str, int, AbstractWorker]:
        """Resolve a worker id to a worker, or register a worker reference.

        Tensors often only store the id of a foreign worker, which may or may
        not be known when they are deserialized. This method lets such ids
        resolve automatically, and records newly seen worker references.

        - A ``bytes`` id is first decoded as UTF-8.
        - A str/int id is looked up among the known workers.
        - Anything else is treated as a worker object. It is added to the
          known workers if its id is new and then returned.

        Args:
            id_or_worker: A string/int/bytes worker id, or a worker object.
            fail_hard: When True, an unknown id raises
                WorkerNotFoundException instead of only being logged.

        Returns:
            The resolved worker. For an unknown id (without ``fail_hard``),
            the id itself is returned.

        Example:
            >>> import torch
            >>> import syft as sy
            >>> hook = sy.TorchHook(torch)
            >>> me = hook.local_worker
            >>> bob = sy.VirtualWorker(hook, id="bob")
            >>> me.get_worker("bob") is bob
            True
            >>> me.get_worker(bob) is bob
            True
        """
        if isinstance(id_or_worker, bytes):
            id_or_worker = str(id_or_worker, "utf-8")

        if not isinstance(id_or_worker, (str, int)):
            return self._get_worker(id_or_worker)
        return self._get_worker_based_on_id(id_or_worker, fail_hard=fail_hard)

    def _get_worker(self, worker: AbstractWorker):
        """Return ``worker``, first adding it to the known workers if its id is new."""
        already_known = worker.id in self._known_workers
        if not already_known:
            self.add_worker(worker)
        return worker

    def _get_worker_based_on_id(self,
                                worker_id: Union[str, int],
                                fail_hard: bool = False):
        """Look up a known worker by id.

        Returns this worker for its own id. An unknown id is returned
        unchanged after logging a warning, or raises WorkerNotFoundException
        when ``fail_hard`` is True.
        """
        # A worker should always know itself
        if worker_id == self.id:
            return self

        resolved = self._known_workers.get(worker_id, worker_id)
        if resolved == worker_id:
            if fail_hard:
                raise WorkerNotFoundException
            logger.warning("Worker %s couldn't recognize worker %s", self.id,
                           worker_id)
        return resolved

    def add_worker(self, worker: "BaseWorker"):
        """Adds a single worker.

        Adds a worker to the _known_workers internal to the BaseWorker, which
        lets this worker communicate with it (send and receive objects,
        commands, or information about the network). A worker already known
        under the same id is replaced, and a warning is logged.

        Args:
            worker (:class:`BaseWorker`): A BaseWorker object representing the
                pointer to a remote worker, which must have a unique id.

        Returns:
            This worker, to allow chaining.

        Example:
            >>> import torch
            >>> import syft as sy
            >>> hook = sy.TorchHook(torch)
            >>> me = hook.local_worker
            >>> bob = sy.VirtualWorker(hook, id="bob")
            >>> _ = me.add_worker(bob)
            >>> x = torch.Tensor([1, 2, 3, 4, 5])
            >>> ptr = x.send(bob)
            >>> ptr.get()
            tensor([1., 2., 3., 4., 5.])
        """
        if worker.id in self._known_workers:
            logger.warning(
                "Worker %s already exists. Replacing old worker which could "
                "cause unexpected behavior",
                worker.id,
            )
        self._known_workers[worker.id] = worker

        return self

    def add_workers(self, workers: List["BaseWorker"]):
        """Register every worker in ``workers`` by calling :meth:`add_worker`.

        Args:
            workers: The BaseWorker instances to register, in order.

        Returns:
            ``self``, so calls can be chained.
        """
        for new_worker in workers:
            self.add_worker(new_worker)
        return self

    def __str__(self):
        """Return a short human-readable description of the worker.

        The text has the form ``<ClassName id:<id> #objects:<count>>``, where
        ``ClassName`` is the bare class name (no module path) and ``count`` is
        the number of entries in this worker's object store. For example, a
        VirtualWorker with id "bob" holding no objects renders as
        ``<VirtualWorker id:bob #objects:0>``.

        Note:
            __repr__ calls this method by default.
        """
        # type(self).__name__ is the bare class name; this is what the old
        # str(type(self)).split("'")[1].split(".")[-1] parsing produced.
        class_name = type(self).__name__
        n_objects = len(self.object_store._objects)
        return f"<{class_name} id:{self.id} #objects:{n_objects}>"

    def __repr__(self):
        """Return the same text as ``str(worker)``."""
        return str(self)

    def __getitem__(self, idx):
        """Look up an object by id, i.e. ``worker[idx]``.

        Delegates to ``self.object_store.get_obj(idx, None)``.
        """
        store = self.object_store
        return store.get_obj(idx, None)

    def is_object_none(self, msg):
        """Tell whether the object referenced by ``msg.object_id`` is None.

        Args:
            msg: A message carrying an ``object_id`` attribute (presumably an
                IsNoneMessage -- confirm against callers).

        Returns:
            True if the stored object is None, False otherwise.

        Raises:
            ObjectNotFoundError: If no object with that id is stored on this
                worker. A missing object is distinct from a stored None.
        """
        obj_id = msg.object_id
        if obj_id in self.object_store._objects:
            return self.get_obj(msg.object_id) is None
        raise ObjectNotFoundError(obj_id, self)

    def request_is_remote_tensor_none(self, pointer: PointerTensor):
        """Ask the worker holding a pointer's target whether that value is None.

        The pointer must be valid. If there is no target at all (which is
        different from having a target whose value is None), the request
        results in an error.

        Args:
            pointer: The pointer whose remote target should be inspected.

        Returns:
            A boolean stating whether the remote value is None.
        """
        query = IsNoneMessage(pointer.id_at_location)
        return self.send_msg(query, location=pointer.location)

    def handle_get_shape_message(self, msg: GetShapeMessage) -> List:
        """Return the shape of a stored tensor as a plain list.

        A list is returned rather than the framework's shape object (e.g. a
        torch.Size) to bypass serializing that object.

        Args:
            msg: Message whose ``tensor_id`` names the tensor to inspect.

        Returns:
            A list containing the tensor shape.
        """
        target = self.get_obj(msg.tensor_id)
        shape = target.shape
        return list(shape)

    def request_remote_tensor_shape(self,
                                    pointer: PointerTensor) -> FrameworkShape:
        """Ask the worker holding a pointer's target for that tensor's shape.

        Args:
            pointer: A pointer whose remote tensor shape is wanted.

        Returns:
            The shape object built by ``sy.hook.create_shape`` from the
            shape list sent back by the remote worker (a torch.Size for
            torch).
        """
        raw_shape = self.send_msg(GetShapeMessage(pointer.id_at_location),
                                  location=pointer.location)
        return sy.hook.create_shape(raw_shape)

    def fetch_plan(self,
                   plan_id: Union[str, int],
                   location: "BaseWorker",
                   copy: bool = False) -> "Plan":  # noqa: F821
        """Fetch the plan with the given ``plan_id`` from ``location``.

        This is the local side of the exchange. It sends a ``"fetch_plan"``
        PlanCommandMessage to ``location``, presumably answered there by
        ``_fetch_plan_remote``.

        Args:
            plan_id: Id of the plan to fetch.
            location: The worker whose registry holds the plan.
            copy: Forwarded to the remote side; if True, a copy of the plan
                is requested.

        Returns:
            Whatever the remote worker sends back: the plan if it exists,
            None otherwise.
        """
        return self.send_msg(PlanCommandMessage("fetch_plan", (plan_id, copy)),
                             location=location)

    def _fetch_plan_remote(self, plan_id: Union[str, int],
                           copy: bool) -> "Plan":  # noqa: F821
        """Return the plan registered under ``plan_id`` on this worker.

        This is the remote-execution side of ``fetch_plan``.

        Args:
            plan_id: Id of the plan to look up.
            copy: If truthy, return ``plan.copy()`` instead of the stored plan.

        Returns:
            The plan (or its copy), or None when the id is unknown or the
            stored object is not a ``sy.Plan``.
        """
        if plan_id not in self.object_store._objects:
            return None

        candidate = self.object_store.get_obj(plan_id)
        if not isinstance(candidate, sy.Plan):
            return None

        return candidate.copy() if copy else candidate

    def fetch_protocol(self,
                       protocol_id: Union[str, int],
                       location: "BaseWorker",
                       copy: bool = False) -> "Protocol":  # noqa: F821
        """Fetch the protocol with the given ``protocol_id`` from ``location``.

        This is the local side of the exchange. It sends a
        ``"fetch_protocol"`` PlanCommandMessage to ``location``.

        Args:
            protocol_id: Id of the protocol to fetch.
            location: The worker whose registry holds the protocol.
            copy: Forwarded to the remote side as part of the message.

        Returns:
            A protocol if a protocol with the given ``protocol_id`` exists,
            None otherwise.
        """
        # Return annotation is "Protocol" (it previously said "Plan"), to
        # match _fetch_protocol_remote.
        message = PlanCommandMessage("fetch_protocol", (protocol_id, copy))
        return self.send_msg(message, location=location)

    def _fetch_protocol_remote(self, protocol_id: Union[str, int],
                               copy: bool) -> "Protocol":  # noqa: F821
        """Target function of fetch_protocol: find and return a protocol.

        Args:
            protocol_id: Id of the protocol to look up.
            copy: Accepted for symmetry with ``_fetch_plan_remote``.

        Returns:
            The stored ``sy.Protocol``, or None when the id is unknown or the
            stored object is not a protocol.
        """
        # NOTE(review): unlike _fetch_plan_remote, ``copy`` is ignored and the
        # stored protocol itself is always returned.
        if protocol_id not in self.object_store._objects:
            return None
        candidate = self.object_store.get_obj(protocol_id)
        return candidate if isinstance(candidate, sy.Protocol) else None

    def search(self, query: Union[List[Union[str, int]], str, int]) -> List:
        """Search for a match between the query terms and a tensor's Id, Tag, or Description.

        The query is an AND query: every item in ``query`` must match an
        object for it to be included in the results. Search by id is
        supported but is not the preferred option. As soon as an item matches
        an object id, that single object is returned and all other terms are
        discarded.

        Args:
            query: A single tag/id, or a list of tags/ids, to match against.
                An empty list returns every object registered under at least
                one tag.

        Returns:
            A list of matching objects (empty if nothing matches).

        TODO Search on description is not supported for the moment
        """
        if isinstance(query, (str, int)):
            query = [query]
        elif len(query) == 0:
            # Empty query returns all the tagged and registered values.
            tagged_ids = set()
            for object_ids in self.object_store._tag_to_object_ids.values():
                tagged_ids.update(object_ids)
            return [self.get_obj(result_id) for result_id in tagged_ids]

        results = None
        for query_item in query:
            result_by_id = self.object_store.find_by_id(query_item)
            # Compare to None explicitly. A found object may itself be falsy,
            # or (multi-element tensors) raise when coerced to bool.
            if result_by_id is not None:
                results = {result_by_id}
                break

            # results_by_tag can be empty.
            results_by_tag = set(self.object_store.find_by_tag(query_item))

            # `results is None` only before the first term. An empty set from
            # an earlier term must stay empty, otherwise the AND degrades into
            # "last non-empty tag wins".
            if results is None:
                results = results_by_tag
            else:
                results = results & results_by_tag

        return list(results) if results is not None else []

    def respond_to_search(self, msg: SearchMessage) -> List[PointerTensor]:
        """Handle a search request sent by a remote worker.

        Runs :meth:`search` on ``msg.query`` and replaces every match with a
        wrapped pointer (owned by ``sy.local_worker`` and carrying the
        object's tags) instead of returning the objects themselves.
        """
        # garbage_collect_data=False: whoever searched for the tensor does not
        # own it, so deciding when to delete it is someone else's call.
        return [
            found.create_pointer(garbage_collect_data=False,
                                 owner=sy.local_worker,
                                 tags=found.tags).wrap()
            for found in self.search(msg.query)
        ]

    def request_search(self, query: List[str], location: "BaseWorker") -> List:
        """Ask ``location`` to run a search and register the results locally.

        Args:
            query: The tags or id used in the search.
            location: The remote worker that performs the search.

        Returns:
            A list of pointers to the results. Each pointer is also
            registered on this worker via ``register_obj``.
        """
        found = self.send_msg(SearchMessage(query), location=location)
        for pointer in found:
            self.register_obj(pointer)
        return found

    def find_or_request(self, tag, location):
        """Return local objects tagged ``tag``, or search ``location`` for them.

        If the tag is known locally, the local matches are returned, and every
        one of them must point at ``location``. Otherwise a remote search is
        performed on ``location``.
        """
        local_hits = self.object_store.find_by_tag(tag)
        if not local_hits:
            return self.request_search(tag, location=location)
        assert all(hit.location.id == location.id for hit in local_hits)
        return local_hits

    def _get_msg(self, index):
        """Return the entry at ``index`` in ``self.msg_history``.

        Mostly useful for testing.

        Args:
            index: Position of the message in the history (any index
                accepted by the history container).
        """
        history = self.msg_history
        return history[index]

    @property
    def message_pending_time(self):
        """Delay, in seconds, applied when messaging between virtual workers."""
        pending = self._message_pending_time
        return pending

    @message_pending_time.setter
    def message_pending_time(self, seconds: Union[int, float]) -> None:
        """Set the delay applied before messages are sent between workers.

        Args:
            seconds: Number of seconds to delay messages. Floats are allowed
                for sub-second precision.
        """
        if self.verbose:
            notice = f"Set message pending time to {seconds} seconds."
            print(notice)
        self._message_pending_time = seconds

    @staticmethod
    def create_worker_command_message(command_name: str,
                                      return_ids=None,
                                      *args,
                                      **kwargs):
        """Build a WorkerCommandMessage for ``command_name``.

        Args:
            command_name: Name of the command that shall be called.
            return_ids: Optional ids for the return values (for remote
                objects). Defaults to an empty list. Because it precedes
                ``*args``, the first extra positional argument lands here.
            *args: Passed to the call of ``command_name``.
            **kwargs: Passed to the call of ``command_name``.

        Returns:
            A WorkerCommandMessage whose payload is ``(args, kwargs, return_ids)``.
        """
        ids = [] if return_ids is None else return_ids
        return WorkerCommandMessage(command_name, (args, kwargs, ids))

    def feed_crypto_primitive_store(self, types_primitives: dict):
        """Add ``types_primitives`` to this worker's crypto store."""
        store = self.crypto_store
        store.add_primitives(types_primitives)

    def list_tensors(self):
        """Return ``str()`` of the object store's ``_tensors`` collection."""
        tensors = self.object_store._tensors
        return str(tensors)

    def tensors_count(self):
        """Return the number of entries in the object store's ``_tensors``."""
        tensors = self.object_store._tensors
        return len(tensors)

    def list_objects(self):
        """Return ``str()`` of the object store's ``_objects`` collection."""
        objects = self.object_store._objects
        return str(objects)

    def objects_count(self):
        """Return the number of entries in the object store's ``_objects``."""
        objects = self.object_store._objects
        return len(objects)

    def _log_msgs(self, value):
        """Set ``self.log_msgs`` to ``value``."""
        setattr(self, "log_msgs", value)

    @property
    def serializer(self, workers=None) -> codes.TENSOR_SERIALIZATION:
        """Pick the tensor serialization strategy for the workers involved.

        Tensors can be serialized efficiently between workers that share the
        same deep-learning framework, but must be converted to lists or
        json-like objects otherwise.

        Args:
            workers: (Optional) a worker or list of workers involved in the
                serialization. If not provided, every value of
                ``self._known_workers`` that is an AbstractWorker is used.
                This worker is always included. Because this is a property,
                attribute access always leaves ``workers`` as None. It can only
                be supplied via ``BaseWorker.serializer.fget(worker, workers)``.

        Returns:
            ``codes.TENSOR_SERIALIZATION.TORCH`` if every involved worker's
            framework is named "torch", otherwise
            ``codes.TENSOR_SERIALIZATION.ALL``.
        """
        if workers is None:
            workers = [
                w for w in self._known_workers.values()
                if isinstance(w, AbstractWorker)
            ]
        elif isinstance(workers, list):
            # Copy so that appending self below never mutates the caller's list.
            workers = list(workers)
        else:
            workers = [workers]

        workers.append(self)

        # A worker with no framework contributes "None", which forces ALL.
        frameworks = {
            worker.framework.__name__ if worker.framework is not None else "None"
            for worker in workers
        }

        if frameworks == {"torch"}:
            return codes.TENSOR_SERIALIZATION.TORCH
        return codes.TENSOR_SERIALIZATION.ALL

    @staticmethod
    def simplify(_worker: AbstractWorker, worker: AbstractWorker) -> tuple:
        """Serialize a worker reference as a 1-tuple holding its simplified id.

        Args:
            _worker: The worker performing the serialization.
            worker: The worker being referenced.
        """
        simplified_id = sy.serde.msgpack.serde._simplify(_worker, worker.id)
        return (simplified_id,)

    @staticmethod
    def detail(worker: AbstractWorker,
               worker_tuple: tuple) -> Union[AbstractWorker, int, str]:
        """Rebuild a worker reference from the tuple produced by ``simplify``.

        Args:
            worker: The worker doing the deserialization.
            worker_tuple: A tuple whose first item is the simplified worker id.

        Returns:
            Whatever ``worker.get_worker`` resolves the id to: a worker
            instance, or the bare worker id.
        """
        raw_id = worker_tuple[0]
        worker_id = sy.serde.msgpack.serde._detail(worker, raw_id)
        return worker.get_worker(worker_id)

    @staticmethod
    def force_simplify(_worker: AbstractWorker,
                       worker: AbstractWorker) -> tuple:
        """Serialize a worker together with its objects and ``auto_add`` flag.

        Returns:
            ``(simplified id, simplified object_store._objects, auto_add)``.
        """
        serde = sy.serde.msgpack.serde
        simplified_id = serde._simplify(_worker, worker.id)
        simplified_objects = serde._simplify(_worker,
                                             worker.object_store._objects)
        return simplified_id, simplified_objects, worker.auto_add

    @staticmethod
    def force_detail(worker: AbstractWorker, worker_tuple: tuple) -> tuple:
        """Rebuild a VirtualWorker (with its objects) from ``force_simplify`` output.

        Args:
            worker: The worker doing the deserialization.
            worker_tuple: ``(simplified id, simplified objects, auto_add)``.

        Returns:
            A new ``sy.VirtualWorker`` whose object store holds the detailed
            objects.
        """
        worker_id, simplified_objects, auto_add = worker_tuple
        serde = sy.serde.msgpack.serde
        worker_id = serde._detail(worker, worker_id)

        rebuilt = sy.VirtualWorker(sy.hook, worker_id, auto_add=auto_add)
        objects = serde._detail(worker, simplified_objects)
        rebuilt.object_store._objects = objects

        # Make sure the objects weren't accidentally double registered on the
        # deserializing worker; they should only live on `rebuilt`.
        for obj in objects.values():
            if obj.id in worker.object_store._objects:
                worker.object_store.rm_obj(obj.id)

        return rebuilt

    @classmethod
    def is_framework_supported(cls, framework: str) -> bool:
        """Tell whether ``framework`` names a supported framework package.

        The comparison is case-insensitive.

        Args:
            framework: Framework name, e.g. "torch".

        Returns:
            True if the lower-cased name is in ``framework_packages``.
        """
        normalized = framework.lower()
        return normalized in framework_packages
class FederatedClient:
    """A Client able to execute federated learning in local datasets.

    Local datasets live in ``self.datasets`` (a dict keyed by dataset name).
    The model and loss function are looked up in ``self.object_store`` using
    the ids stored on the current TrainConfig. The optimizer is built lazily
    on the first ``fit`` after a TrainConfig has been set.
    """

    def __init__(self, datasets=None):
        """Create a client.

        Args:
            datasets: Optional dict mapping dataset keys to datasets. The dict
                is used as-is (not copied).
        """
        super().__init__()
        self.datasets = datasets if datasets is not None else dict()
        self.optimizer = None
        self.train_config = None
        self.object_store = ObjectStore(owner=self)

    def add_dataset(self, dataset, key: str):
        """Register ``dataset`` under ``key``.

        Raises:
            ValueError: If ``key`` is already in use.
        """
        if key not in self.datasets:
            self.datasets[key] = dataset
        else:
            raise ValueError(f"Key {key} already exists in Datasets")

    def remove_dataset(self, key: str):
        """Remove the dataset stored under ``key``; unknown keys are ignored."""
        if key in self.datasets:
            del self.datasets[key]

    def get_obj(self, obj_id: Union[str, int]) -> object:
        """Returns the object from registry.

        Look up an object from the registry using its ID.

        Args:
            obj_id: A string or integer id of an object to look up.

        Returns:
            Object with id equals to `obj_id`.
        """
        return self.object_store.get_obj(obj_id)

    def set_obj(self, obj: object):
        """Register an object, caching TrainConfig objects on the client.

        A TrainConfig is kept in ``self.train_config`` (not in the object
        store). Setting one discards the current optimizer so that it is
        rebuilt from the new config on the next ``fit``.

        Args:
            obj: An object to be registered.
        """
        if isinstance(obj, TrainConfig):
            self.train_config = obj
            self.optimizer = None
        else:
            self.object_store.set_obj(obj)

    def _check_train_config(self):
        """Raise ValueError if no TrainConfig has been set yet."""
        if self.train_config is None:
            raise ValueError("Operation needs TrainConfig object to be set.")

    def _build_optimizer(self, optimizer_name: str, model,
                         optimizer_args: dict) -> th.optim.Optimizer:
        """Build an optimizer if needed.

        An already-built optimizer is reused as-is.

        Args:
            optimizer_name: Name of a class in ``torch.optim`` (e.g. "SGD").
            model: Model whose ``parameters()`` are optimized, unless
                ``optimizer_args`` already provides "params".
            optimizer_args: Keyword arguments for the optimizer constructor.
                Not modified. None is treated as an empty dict.

        Returns:
            A Torch Optimizer.

        Raises:
            ValueError: If ``optimizer_name`` is not found in ``torch.optim``.
        """
        if self.optimizer is not None:
            return self.optimizer

        if optimizer_name not in dir(th.optim):
            raise ValueError(f"Unknown optimizer: {optimizer_name}")

        optimizer_cls = getattr(th.optim, optimizer_name)
        # Work on a copy. The caller's dict is usually the TrainConfig's, and
        # must not keep a one-shot parameter generator around.
        kwargs = dict(optimizer_args or {})
        kwargs.setdefault("params", model.parameters())
        self.optimizer = optimizer_cls(**kwargs)
        return self.optimizer

    def fit(self, dataset_key: str, device: str = "cpu", **kwargs):
        """Fits a model on the local dataset as specified in the local TrainConfig object.

        Args:
            dataset_key: Identifier of the local dataset that shall be used for training.
            device: Device the data batches are moved to (e.g. "cpu", "cuda").
            **kwargs: Unused.

        Returns:
            loss: Training loss on the last batch of training data (None if
            no batch was processed).

        Raises:
            ValueError: If no TrainConfig is set or the dataset is unknown.
        """
        self._check_train_config()

        if dataset_key not in self.datasets:
            raise ValueError(f"Dataset {dataset_key} unknown.")

        model = self.object_store.get_obj(self.train_config._model_id).obj
        loss_fn = self.object_store.get_obj(self.train_config._loss_fn_id).obj

        self._build_optimizer(self.train_config.optimizer,
                              model,
                              optimizer_args=self.train_config.optimizer_args)

        return self._fit(model=model,
                         dataset_key=dataset_key,
                         loss_fn=loss_fn,
                         device=device)

    def _create_data_loader(self, dataset_key: str, shuffle: bool = False):
        """Build a single-process DataLoader over ``self.datasets[dataset_key]``.

        Uses the TrainConfig's batch size, with random sampling if
        ``shuffle`` is set and sequential sampling otherwise.
        """
        data_range = range(len(self.datasets[dataset_key]))
        if shuffle:
            sampler = RandomSampler(data_range)
        else:
            sampler = SequentialSampler(data_range)
        data_loader = th.utils.data.DataLoader(
            self.datasets[dataset_key],
            batch_size=self.train_config.batch_size,
            sampler=sampler,
            num_workers=0,
        )
        return data_loader

    def _fit(self, model, dataset_key, loss_fn, device="cpu"):
        """Run the training loop and return the loss of the last batch.

        Trains for ``train_config.epochs`` epochs. If
        ``train_config.max_nr_batches`` is non-negative, training stops
        entirely once that many batches (counted across epochs) have run.
        """
        model.train()
        data_loader = self._create_data_loader(
            dataset_key=dataset_key, shuffle=self.train_config.shuffle)

        max_nr_batches = self.train_config.max_nr_batches
        loss = None
        iteration_count = 0

        for _ in range(self.train_config.epochs):
            for data, target in data_loader:
                # Set gradients to zero
                self.optimizer.zero_grad()

                # Update model
                output = model(data.to(device))
                loss = loss_fn(target=target.to(device), pred=output)
                loss.backward()
                self.optimizer.step()

                # A negative limit means "no limit". Return rather than break,
                # so the limit also ends the outer epoch loop.
                iteration_count += 1
                if 0 <= max_nr_batches <= iteration_count:
                    return loss

        return loss

    def evaluate(
        self,
        dataset_key: str,
        return_histograms: bool = False,
        nr_bins: int = -1,
        return_loss: bool = True,
        return_raw_accuracy: bool = True,
        device: str = "cpu",
    ):
        """Evaluates a model on the local dataset as specified in the local TrainConfig object.

        Args:
            dataset_key: Identifier of the local dataset that shall be used for evaluation.
            return_histograms: If True, calculate the histograms of predicted classes.
            nr_bins: Used together with return_histograms. Provide the number of
                classes/bins; must be positive when return_histograms is True.
            return_loss: If True, loss is calculated additionally.
            return_raw_accuracy: If True, return nr_correct_predictions and nr_predictions
            device: "cuda" or "cpu" (anything other than "cuda" means "cpu").

        Returns:
            Dictionary containing depending on the provided flags:
                * loss: summed batch losses divided by the dataset size.
                * nr_correct_predictions: number of correct predictions.
                * nr_predictions: total number of predictions.
                * histogram_predictions: histogram of predictions.
                * histogram_target: histogram of target values in the dataset.

        Raises:
            ValueError: If no TrainConfig is set, the dataset is unknown, or
                histograms are requested with a non-positive ``nr_bins``.
        """
        self._check_train_config()

        if dataset_key not in self.datasets:
            raise ValueError(f"Dataset {dataset_key} unknown.")
        if return_histograms and nr_bins <= 0:
            raise ValueError(
                "nr_bins must be a positive integer when return_histograms is True."
            )

        eval_result = dict()
        model = self.object_store.get_obj(self.train_config._model_id).obj
        loss_fn = self.object_store.get_obj(self.train_config._loss_fn_id).obj
        model.eval()
        device = "cuda" if device == "cuda" else "cpu"
        data_loader = self._create_data_loader(dataset_key=dataset_key,
                                               shuffle=False)
        test_loss = 0.0
        correct = 0
        if return_histograms:
            hist_target = np.zeros(nr_bins)
            hist_pred = np.zeros(nr_bins)

        with th.no_grad():
            for data, target in data_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                if return_loss:
                    test_loss += loss_fn(output,
                                         target).item()  # sum up batch loss
                # Index of the max log-probability.
                pred = output.argmax(dim=1, keepdim=True)
                if return_histograms:
                    # numpy cannot read CUDA tensors; move to host memory first.
                    hist, _ = np.histogram(target.cpu().numpy(),
                                           bins=nr_bins,
                                           range=(0, nr_bins))
                    hist_target += hist
                    hist, _ = np.histogram(pred.cpu().numpy(),
                                           bins=nr_bins,
                                           range=(0, nr_bins))
                    hist_pred += hist
                if return_raw_accuracy:
                    correct += pred.eq(target.view_as(pred)).sum().item()

        if return_loss:
            test_loss /= len(data_loader.dataset)
            eval_result["loss"] = test_loss
        if return_raw_accuracy:
            eval_result["nr_correct_predictions"] = correct
            eval_result["nr_predictions"] = len(data_loader.dataset)
        if return_histograms:
            eval_result["histogram_predictions"] = hist_pred
            eval_result["histogram_target"] = hist_target

        return eval_result

    def _log_msgs(self, value):
        """Set ``self.log_msgs`` to ``value``."""
        self.log_msgs = value