Ejemplo n.º 1
0
    def fit(self, dataset_key: str, **kwargs):
        """Send the "fit" command to the remote worker and fetch the result object.

        ``return_ids`` is taken from ``**kwargs`` instead of being a named
        parameter, because a named parameter would not match the signature of
        the VirtualWorker.fit() method. Keeping the signatures aligned makes it
        possible to switch between virtual and websocket workers.

        Args:
            dataset_key: Identifier of the dataset which shall be used for the training.
            **kwargs:
                return_ids (List[str], optional): Ids passed along with the
                    "fit" command. When the key is absent, a one-element list
                    is built from ``sy.ID_PROVIDER.pop()``.

        Returns:
            The deserialized response to an ObjectRequestMessage issued for the
            first return id.
        """
        if "return_ids" in kwargs:
            result_ids = kwargs["return_ids"]
        else:
            # Only draw a fresh id when the caller did not supply any.
            result_ids = [sy.ID_PROVIDER.pop()]

        # Issue the "fit" command; its deserialized reply is not used here.
        self._send_msg_and_deserialize(
            "fit", return_ids=result_ids, dataset_key=dataset_key
        )

        # Request the object registered under the first return id and hand
        # back the deserialized answer.
        request = ObjectRequestMessage(result_ids[0], None, "")
        serialized_request = sy.serde.serialize(request)
        raw_response = self._send_msg(serialized_request)
        return sy.serde.deserialize(raw_response)
Ejemplo n.º 2
0
    def request_obj(self, obj_id: Union[str, int], location: "BaseWorker") -> object:
        """Ask ``location`` for the object identified by ``obj_id``.

        Args:
            obj_id: A string or integer id of the object to look up.
            location: The BaseWorker instance the request is sent to.

        Returns:
            The value returned by ``send_msg`` for the ObjectRequestMessage
            (described by the original author as a torch Tensor or Variable).
        """
        request = ObjectRequestMessage(obj_id)
        return self.send_msg(request, location)
Ejemplo n.º 3
0
    def request_obj(
        self, obj_id: Union[str, int], location: "BaseWorker", user=None, reason: str = ""
    ) -> object:
        """Ask ``location`` for the object identified by ``obj_id``.

        Args:
            obj_id (int or string): A string or integer id of the object to look up.
            location (BaseWorker): The BaseWorker instance the request is sent to.
            user (object, optional): User credentials to perform user authentication;
                forwarded unchanged inside the request message.
            reason (string, optional): A description of why the data scientist wants
                to see the object; forwarded unchanged inside the request message.

        Returns:
            The value returned by ``send_msg`` for the ObjectRequestMessage
            (described by the original author as a torch Tensor or Variable).
        """
        request = ObjectRequestMessage(obj_id, user, reason)
        return self.send_msg(request, location)
Ejemplo n.º 4
0
    async def async_fit(self,
                        dataset_key: str,
                        device: str = "cpu",
                        return_ids: List[int] = None):
        """Asynchronously send the "fit" command, then fetch the result object.

        The standard connection is closed, the command is exchanged over a
        temporary asynchronous websocket connection, and the standard
        connection is reopened before the result is requested.

        Args:
            dataset_key: Identifier of the dataset which shall be used for the training.
            device: Value forwarded as the ``device`` argument of the "fit"
                command. Defaults to "cpu".
            return_ids: List of return ids. When None, a one-element list is
                built from ``sy.ID_PROVIDER.pop()``.

        Returns:
            The deserialized response to an ObjectRequestMessage issued for the
            first return id. See the return value of the FederatedClient.fit()
            method.
        """
        result_ids = [sy.ID_PROVIDER.pop()] if return_ids is None else return_ids

        # The existing websocket connection has to be closed before an
        # asynchronous one can be opened.
        # This code is not tested with secure connections (wss protocol).
        self.close()
        async with websockets.connect(
            self.url,
            timeout=TIMEOUT_INTERVAL,
            max_size=None,
            ping_timeout=TIMEOUT_INTERVAL,
        ) as connection:
            fit_command = self.create_worker_command_message(
                command_name="fit",
                return_ids=result_ids,
                dataset_key=dataset_key,
                device=device,
            )

            # The serialized command is hex-encoded and sent as the str() of
            # the resulting bytes object.
            hex_payload = binascii.hexlify(sy.serde.serialize(fit_command))
            await connection.send(str(hex_payload))

            # Wait for the reply; per the original author it is None, so it is
            # deliberately dropped.
            await connection.recv()

        # Back to the standard connection for the follow-up request.
        self.connect()

        # Retrieve the result tensor of the fit() method via an object request.
        request = ObjectRequestMessage(result_ids[0], None, "")
        serialized_request = sy.serde.serialize(request)
        raw_response = self._send_msg(serialized_request)
        return sy.serde.deserialize(raw_response)
Ejemplo n.º 5
0
def test_recv_msg():
    """Exercise recv_msg in two steps.

    Step 1 delivers a tensor to a worker through recv_msg.

    Step 2 uses recv_msg again to request that same tensor back from the
    worker and checks the reply.
    """

    # --- Step 1: deliver a tensor to alice ---

    # A fresh id keeps the worker name distinct.
    worker_id = sy.ID_PROVIDER.pop()
    alice = VirtualWorker(sy.torch.hook, id=f"alice{worker_id}")

    # The payload to deliver.
    tensor = torch.Tensor([100, 100])

    # Wrap the tensor in a message, serialize it and hand it to alice.
    serialized_delivery = serde.serialize(ObjectMessage(tensor))
    alice.recv_msg(serialized_delivery)

    # The tensor's id must now be present in alice's registry.
    assert tensor.id in alice.object_store._objects

    # --- Step 2: request the tensor back from alice ---

    # Build and serialize the request for the tensor's id.
    serialized_request = serde.serialize(ObjectRequestMessage(tensor.id, None, ""))

    # Let alice process the request and decode her reply.
    reply = alice.recv_msg(serialized_request)
    returned = sy.serde.deserialize(reply)

    # The raw reply must be exactly of type bytes.
    assert type(reply) == bytes

    # The decoded object must carry the id of the tensor that was sent.
    assert returned.id == tensor.id