Ejemplo n.º 1
0
    def wait(self,
             object_refs: List[ClientObjectRef],
             *,
             num_returns: int = 1,
             timeout: float = None,
             fetch_local: bool = True
             ) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
        """Ask the server which of ``object_refs`` are ready.

        Args:
            object_refs: A list whose items are all ClientObjectRef.
            num_returns: Sent to the server as ``num_returns``.
            timeout: Sent to the server as ``timeout``; ``None`` is
                transmitted as ``-1``.
            fetch_local: Accepted for signature compatibility; it is not
                placed in the request built by this method.

        Returns:
            A ``(ready, remaining)`` pair of ClientObjectRef lists built
            from the response's ``ready_object_ids`` and
            ``remaining_object_ids``.

        Raises:
            TypeError: If ``object_refs`` is not a list, or contains an
                item that is not a ClientObjectRef.
            Exception: If the server flags the response as not valid.
        """
        if not isinstance(object_refs, list):
            raise TypeError("wait() expected a list of ClientObjectRef, "
                            f"got {type(object_refs)}")
        for ref in object_refs:
            if isinstance(ref, ClientObjectRef):
                continue
            raise TypeError("wait() expected a list of ClientObjectRef, "
                            f"got list containing {type(ref)}")

        request = ray_client_pb2.WaitRequest(
            object_ids=[ref.id for ref in object_refs],
            num_returns=num_returns,
            # The wire format has no "None"; -1 stands in for "no timeout
            # given".
            timeout=-1 if timeout is None else timeout,
            client_id=self._client_id,
        )
        response = self.server.WaitObject(request, metadata=self.metadata)
        if not response.valid:
            # TODO(ameer): improve error/exceptions messages.
            raise Exception("Client Wait request failed. Reference invalid?")

        ready = list(map(ClientObjectRef, response.ready_object_ids))
        remaining = list(map(ClientObjectRef, response.remaining_object_ids))
        return (ready, remaining)
Ejemplo n.º 2
0
def test_wait(ray_start_regular_shared):
    """Check ``ray.wait`` over the client for valid and invalid inputs."""
    with ray_start_client_server() as ray:
        objectref = ray.put("hello world")
        ready, remaining = ray.wait([objectref])
        assert remaining == []
        assert ray.get(ready[0]) == "hello world"

        objectref2 = ray.put(5)
        ready, remaining = ray.wait([objectref, objectref2])
        outcome = (ready, remaining)
        # Either of the two refs is accepted as the "ready" one.
        assert outcome == ([objectref], [objectref2]) or \
            outcome == ([objectref2], [objectref])
        values = (ray.get(ready[0]), ray.get(remaining[0]))
        assert values == ("hello world", 5) \
            or values == (5, "hello world")

        with pytest.raises(Exception):
            # Reference not in the object store.
            ray.wait([ClientObjectRef(b"blabla")])
        # Argument is a string rather than a list.
        with pytest.raises(TypeError):
            ray.wait("blabla")
        # Argument is not wrapped in a list.
        with pytest.raises(TypeError):
            ray.wait(ClientObjectRef("blabla"))
        # List element is not a ClientObjectRef.
        with pytest.raises(TypeError):
            ray.wait(["blabla"])
Ejemplo n.º 3
0
def test_client_object_ref_basics(ray_start_regular):
    """Compare ClientObjectRef against ObjectRef for bytes- and Future-built refs."""
    with ray_start_client_server_pair() as pair:
        ray, server = pair
        ref = ray.put("Hello World")
        # Make sure ClientObjectRef is a subclass of ObjectRef
        assert isinstance(ref, ClientObjectRef)
        assert isinstance(ref, ObjectRef)

        # Invalid ref format.
        with pytest.raises(Exception):
            ClientObjectRef(b"\0")

        obj_id = b"\0" * 28
        resolved = Future()
        resolved.set_result(obj_id)
        server_ref = ObjectRef(obj_id)
        # The same checks run for a ref built from raw bytes and for one
        # built from an already-resolved Future.
        candidates = [ClientObjectRef(obj_id), ClientObjectRef(resolved)]
        for client_ref in candidates:
            client_members = set(client_ref.__dir__())
            server_members = set(server_ref.__dir__())
            # The client ref exposes exactly one extra member: "id".
            assert client_members - server_members == {"id"}
            assert server_members - client_members == set()

            # Test __eq__()
            assert client_ref == ClientObjectRef(obj_id)
            assert client_ref != ref
            assert client_ref != server_ref

            # Test other methods
            expected_hex = obj_id.hex()
            assert client_ref.__repr__() == f"ClientObjectRef({expected_hex})"
            assert client_ref.binary() == obj_id
            assert client_ref.hex() == expected_hex
            assert not client_ref.is_nil()
            assert client_ref.task_id() == server_ref.task_id()
            assert client_ref.job_id() == server_ref.job_id()
Ejemplo n.º 4
0
 def persistent_load(self, pid):
     """Turn a PickleStub persistent id into the matching client object.

     "Object" stubs become a ClientObjectRef and "Actor" stubs become a
     ClientActorHandle wrapping a ClientActorRef; any other stub type
     raises NotImplementedError.
     """
     assert isinstance(pid, PickleStub)
     if pid.type == "Object":
         return ClientObjectRef(pid.ref_id)
     if pid.type == "Actor":
         actor_ref = ClientActorRef(pid.ref_id)
         return ClientActorHandle(actor_ref)
     raise NotImplementedError("Being passed back an unknown stub")
Ejemplo n.º 5
0
Archivo: worker.py Proyecto: alipay/ray
 def _put_pickled(self, data, client_ref_id: bytes):
     """Send already-serialized ``data`` to the server and return its ref.

     Args:
         data: Payload placed in the ``data`` field of the PutRequest.
         client_ref_id: Copied onto the request when it is not None.

     Returns:
         A ClientObjectRef wrapping the ``id`` of the server's response.

     Raises:
         BaseException: The error carried in ``resp.error`` when the
             response is not valid.
         pickle.UnpicklingError, TypeError: If ``resp.error`` cannot be
             deserialized, or deserializes to something that is not an
             exception.
     """
     req = ray_client_pb2.PutRequest(data=data)
     if client_ref_id is not None:
         req.client_ref_id = client_ref_id
     resp = self.data_client.PutObject(req)
     if not resp.valid:
         # Guard only the deserialization step. Raising inside the ``try``
         # would also catch a server-side TypeError/UnpicklingError that
         # deserialized correctly and log it as a deserialization failure.
         # NOTE(review): this unpickles bytes received from the server,
         # which is only safe when the server is trusted.
         try:
             remote_error = cloudpickle.loads(resp.error)
         except (pickle.UnpicklingError, TypeError):
             logger.exception("Failed to deserialize {}".format(resp.error))
             raise
         raisable = isinstance(remote_error, BaseException) or (
             isinstance(remote_error, type)
             and issubclass(remote_error, BaseException))
         if not raisable:
             # The payload unpickled, but not into an exception.
             logger.error("Failed to deserialize {}".format(resp.error))
             raise TypeError(
                 "Server error payload is not an exception: "
                 f"{type(remote_error)}")
         raise remote_error
     return ClientObjectRef(resp.id)
Ejemplo n.º 6
0
 def _put(self, val):
     """Serialize ``val``, send it in a PutRequest, and return its ref.

     Raises:
         TypeError: If ``val`` is itself a ClientObjectRef.
     """
     if isinstance(val, ClientObjectRef):
         raise TypeError(
             "Calling 'put' on an ObjectRef is not allowed "
             "(similarly, returning an ObjectRef from a remote "
             "function is not allowed). If you really want to "
             "do this, you can wrap the ObjectRef in a list and "
             "call 'put' on it (or return it).")
     payload = dumps_from_client(val, self._client_id)
     request = ray_client_pb2.PutRequest(data=payload)
     response = self.data_client.PutObject(request)
     return ClientObjectRef(response.id)
Ejemplo n.º 7
0
def test_put_get(ray_start_regular_shared):
    """Round-trip a value through put/get and check ref equality semantics."""
    with ray_start_client_server() as ray:
        ref = ray.put("hello world")
        print(ref)

        assert ray.get(ref) == "hello world"

        # Make sure ray.put(1) == 1 is False and does not raise an exception.
        ref = ray.put(1)
        assert not ref == 1
        # Make sure it returns True when necessary as well.
        assert ref == ClientObjectRef(ref.id)
Ejemplo n.º 8
0
def test_client_object_ref_basics(ray_start_regular):
    """Check ClientObjectRef construction, equality, and basic accessors."""
    with ray_start_client_server_pair() as pair:
        ray, server = pair
        ref = ray.put("Hello World")
        # Make sure ClientObjectRef is a subclass of ObjectRef
        assert isinstance(ref, ClientObjectRef)
        assert isinstance(ref, ObjectRef)

        # Invalid ref format.
        with pytest.raises(Exception):
            ClientObjectRef(b"\0")

        # Test __eq__()
        raw_id = b"\0" * 28
        assert ClientObjectRef(raw_id) == ClientObjectRef(raw_id)
        assert ClientObjectRef(raw_id) != ref
        assert ClientObjectRef(raw_id) != ObjectRef(raw_id)

        # Accessors reflect the raw id the ref was built from.
        expected_hex = raw_id.hex()
        assert ClientObjectRef(raw_id).__repr__() == \
            f"ClientObjectRef({expected_hex})"
        assert ClientObjectRef(raw_id).binary() == raw_id
        assert ClientObjectRef(raw_id).hex() == expected_hex
        assert not ClientObjectRef(raw_id).is_nil()
Ejemplo n.º 9
0
 def _put(self, val, *, client_ref_id: bytes = None):
     """Serialize ``val``, send it in a PutRequest, and return its ref.

     Args:
         val: The value to store; must not be a ClientObjectRef.
         client_ref_id: Copied onto the request when it is not None.

     Returns:
         A ClientObjectRef wrapping the ``id`` of a valid response.

     Raises:
         TypeError: If ``val`` is itself a ClientObjectRef.
         Exception: Whatever ``resp.error`` deserializes to when the
             response is not valid.
     """
     if isinstance(val, ClientObjectRef):
         raise TypeError(
             "Calling 'put' on an ObjectRef is not allowed "
             "(similarly, returning an ObjectRef from a remote "
             "function is not allowed). If you really want to "
             "do this, you can wrap the ObjectRef in a list and "
             "call 'put' on it (or return it).")
     payload = dumps_from_client(val, self._client_id)
     request = ray_client_pb2.PutRequest(data=payload)
     if client_ref_id is not None:
         request.client_ref_id = client_ref_id
     response = self.data_client.PutObject(request)
     if response.valid:
         return ClientObjectRef(response.id)
     # Invalid response: re-raise the error the server sent back.
     try:
         raise cloudpickle.loads(response.error)
     except pickle.UnpicklingError:
         logger.exception("Failed to deserialize {}".format(response.error))
         raise