コード例 #1
0
ファイル: base.py プロジェクト: SanaAwan5/mnist2
    def receive_msg(self, message_wrapper_json):
        """Decode an incoming message, execute it, and return the encoded reply.

        The raw message is unwrapped with ``self.decode_msg`` and deserialized
        with ``encode.decode``. It is then routed through
        ``self.process_message_type``, and the resulting response is serialized
        with ``encode.encode`` followed by ``self.encode_msg``.

        :param message_wrapper_json: the raw message received from another
            worker, in whatever wire format ``self.decode_msg`` accepts.
        :return: the encoded response. Per the original author, this value is
            typically only consumed during testing or local development with
            :class:`VirtualWorker` workers.
        """
        # Unwrap the transport encoding, then rebuild live objects bound to
        # this worker.
        unwrapped = self.decode_msg(message_wrapper_json)
        wrapper = encode.decode(unwrapped, worker=self)

        # ``keep_private`` (per the original comments) signals that the reply
        # should carry pointers rather than the underlying data, which stays on
        # this worker.
        reply, keep_private = self.process_message_type(wrapper)

        serialized_reply = encode.encode(
            reply, retrieve_pointers=False, private_local=keep_private
        )
        return self.encode_msg(serialized_reply)
コード例 #2
0
ファイル: base.py プロジェクト: SanaAwan5/mnist2
    def request_obj(self, obj_id, recipient):
        """Ask another worker to send the object with id ``obj_id`` to this worker.

        The request is sent through ``self.send_msg`` with
        ``message_type="req_obj"``. The reply is deserialized with
        ``encode.decode(..., worker=self)`` and returned.

        NOTE(review): earlier documentation said a client worker simply gets the
        object back, while a non-client stores it in its permanent registry.
        This method only decodes and returns, so any registration presumably
        happens inside ``encode.decode``. Confirm.

        :param obj_id: the id (str or int) of the object being requested.
        :param recipient: the worker (e.g. a :class:`VirtualWorker`) that
            currently holds the object and is asked to send it.
        :return: the decoded object.
        """
        raw_reply = self.send_msg(
            message=obj_id, message_type="req_obj", recipient=recipient
        )

        # NOTE(review): an earlier comment here described a workaround for a
        # bug where obj.grad was re-initialized without being re-registered (so
        # x.grad.id was missing). The workaround parked objects in
        # self._tmpobjects and returned a cleanup hook for clients. This method
        # does neither. If that workaround is still required, it lives
        # elsewhere; confirm before relying on grad ids of requested objects.
        return encode.decode(raw_reply, worker=self)
コード例 #3
0
ファイル: socket.py プロジェクト: SanaAwan5/mnist2
    def search(self, query):
        """Find tensors whose ids overlap with the string(s) in ``query``.

        If this worker object is a pointer to a remote worker (connected via a
        socket), the query is forwarded with ``message_type="query"``. The
        reply is decoded against ``self.hook.local_worker``. Otherwise the
        search runs locally via ``self._search``.

        :param query: a string or list of strings.
        :return: when ``self.is_pointer`` is truthy, a ``list`` of the decoded
            results (pointer tensors); otherwise whatever ``self._search``
            returns for the local tensors.
        """
        if not self.is_pointer:
            # Local worker: search our own objects directly.
            return self._search(query)

        raw_reply = self.send_msg(message=query, message_type="query", recipient=self)
        decoded_reply = self.decode_msg(raw_reply)
        matches = encode.decode(
            decoded_reply, worker=self.hook.local_worker, message_is_dict=True
        )
        return list(matches)
コード例 #4
0
    def test_encode_decode_json_python(self):
        """Round-trip nested Python objects through encode -> JSON -> decode twice.

        Focuses on values the stock JSON encoder cannot serialize directly,
        such as a torch Variable (already sent to ``bob``), tuples, and
        ``slice`` objects. The first decoded result must equal the result of
        round-tripping it a second time.
        """

        def round_trip(value):
            # encode.encode yields a 2-tuple here; only the payload is serialized.
            encoded, _ = encode.encode(value)
            return encode.decode(json.dumps(encoded), me)

        var = Var(torch.FloatTensor([[1, -1], [0, 1]]))
        var.send(bob)
        payload = [
            None,
            ({'marcel': (1, [1.3], var), 'proust': slice(0, 2, None)}, 3),
        ]

        first_pass = round_trip(payload)
        second_pass = round_trip(first_pass)
        assert first_pass == second_pass
コード例 #5
0
ファイル: base.py プロジェクト: SanaAwan5/mnist2
    def send_command(self, recipient, message, framework="torch"):
        """Send a framework command to another worker and return its decoded reply.

        :param recipient: the worker object to send to. A bare ``str``/``int``
            id is rejected.
        :param message: the command payload forwarded to ``self.send_msg``.
        :param framework: prefix of the message type; the message is sent with
            ``message_type=framework + "_cmd"`` (``"torch_cmd"`` by default).
        :return: the reply from ``self.send_msg``, deserialized with
            ``encode.decode(..., worker=self)``.
        :raises TypeError: if ``recipient`` is a ``str`` or ``int`` rather than
            a worker object.
        """
        passed_bare_id = isinstance(recipient, (str, int))
        if passed_bare_id:
            raise TypeError("Recipient should be a worker object not his id.")

        command_type = framework + "_cmd"
        raw_reply = self.send_msg(
            message=message, message_type=command_type, recipient=recipient
        )
        return encode.decode(raw_reply, worker=self)