Esempio n. 1
0
def test_two_subscribers(ray_start_regular):
    """Tests that concurrently subscribing to two channels works.

    An error subscriber and a log subscriber are created against the GCS
    address provided by the ``ray_start_regular`` fixture. Each subscriber
    is polled from its own thread while the main thread publishes
    ``num_messages`` errors and ``num_messages`` log batches. The test then
    checks that both threads finished and that every message was received
    in publication order.
    """

    address_info = ray_start_regular
    gcs_server_addr = address_info["gcs_address"]

    num_messages = 100

    errors = []
    error_subscriber = GcsErrorSubscriber(address=gcs_server_addr)
    # Make sure subscription is registered before publishing starts.
    error_subscriber.subscribe()

    def receive_errors():
        # Collect error payloads until the expected count is reached.
        while len(errors) < num_messages:
            _, msg = error_subscriber.poll()
            errors.append(msg)

    # Daemon thread: if messages are lost the loop above never terminates
    # (presumably poll() blocks waiting for data), and a non-daemon thread
    # would keep the interpreter alive after the assertions below fail.
    t1 = threading.Thread(target=receive_errors, daemon=True)
    t1.start()

    logs = []
    log_subscriber = GcsLogSubscriber(address=gcs_server_addr)
    # Make sure subscription is registered before publishing starts.
    log_subscriber.subscribe()

    def receive_logs():
        # Collect log batches until the expected count is reached.
        while len(logs) < num_messages:
            log_batch = log_subscriber.poll()
            logs.append(log_batch)

    # Daemon for the same reason as t1: never hang the run on failure.
    t2 = threading.Thread(target=receive_logs, daemon=True)
    t2.start()

    publisher = GcsPublisher(address=gcs_server_addr)
    for i in range(num_messages):
        publisher.publish_error(b"msg_id", ErrorTableData(error_message=f"error {i}"))
        publisher.publish_logs(
            {
                "ip": "127.0.0.1",
                "pid": "gcs",
                "job": "0001",
                "is_err": False,
                "lines": [f"log {i}"],
                "actor_name": "test actor",
                "task_name": "test task",
            }
        )

    # Give each receiver a bounded amount of time to drain its channel.
    t1.join(timeout=10)
    assert len(errors) == num_messages, str(errors)
    assert not t1.is_alive(), str(errors)

    t2.join(timeout=10)
    assert len(logs) == num_messages, str(logs)
    assert not t2.is_alive(), str(logs)

    # Messages must arrive in the order they were published.
    for i in range(num_messages):
        assert errors[i].error_message == f"error {i}", str(errors)
        assert logs[i]["lines"][0] == f"log {i}", str(logs)
Esempio n. 2
0
async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
    """Tests publishing two errors and polling them back with asyncio.

    Subscribes to the error channel at the GCS address provided by the
    ``ray_start_regular`` fixture, publishes two ``ErrorTableData``
    messages, and checks that polling returns them with their IDs in
    publication order.
    """
    address_info = ray_start_regular
    gcs_server_addr = address_info["gcs_address"]

    subscriber = GcsAioErrorSubscriber(address=gcs_server_addr)
    await subscriber.subscribe()

    # Close the subscriber even when publishing or an assertion fails, so
    # a failing test does not leave the subscription open.
    try:
        publisher = GcsAioPublisher(address=gcs_server_addr)
        err1 = ErrorTableData(error_message="test error message 1")
        err2 = ErrorTableData(error_message="test error message 2")
        await publisher.publish_error(b"aaa_id", err1)
        await publisher.publish_error(b"bbb_id", err2)

        assert await subscriber.poll() == (b"aaa_id", err1)
        assert await subscriber.poll() == (b"bbb_id", err2)
    finally:
        await subscriber.close()
Esempio n. 3
0
async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
    """Tests publishing two errors and polling them back with asyncio.

    The GCS address is looked up through a Redis client built from the
    ``ray_start_regular`` fixture's ``redis_address``. Two
    ``ErrorTableData`` messages are published and the test checks that
    polling the error channel returns them with their IDs in publication
    order.
    """
    address_info = ray_start_regular
    redis = ray._private.services.create_redis_client(
        address_info["redis_address"],
        password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)

    gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)

    subscriber = GcsAioSubscriber(address=gcs_server_addr)
    await subscriber.subscribe_error()

    # Close the subscriber even when publishing or an assertion fails, so
    # a failing test does not leave the subscription open.
    try:
        publisher = GcsAioPublisher(address=gcs_server_addr)
        err1 = ErrorTableData(error_message="test error message 1")
        err2 = ErrorTableData(error_message="test error message 2")
        await publisher.publish_error(b"aaa_id", err1)
        await publisher.publish_error(b"bbb_id", err2)

        assert await subscriber.poll_error() == (b"aaa_id", err1)
        assert await subscriber.poll_error() == (b"bbb_id", err2)
    finally:
        await subscriber.close()
Esempio n. 4
0
def construct_error_message(job_id, error_type, message, timestamp):
    """Build an ErrorTableData record and return it in serialized form.

    Args:
        job_id: ID of the job the error is addressed to; its ``binary()``
            form is stored in the record. A nil ID means the error goes
            to all drivers.
        error_type: The type of the error.
        message: The error message.
        timestamp: The time of the error.

    Returns:
        The result of ``SerializeToString()`` on the populated record.
    """
    error_record = ErrorTableData()

    # Fields are filled one at a time via attribute assignment.
    job_id_bytes = job_id.binary()
    error_record.job_id = job_id_bytes
    error_record.type = error_type
    error_record.error_message = message
    error_record.timestamp = timestamp

    serialized = error_record.SerializeToString()
    return serialized
Esempio n. 5
0
def test_subscribe_two_channels(ray_start_regular):
    """Tests that concurrently subscribing to two channels works.

    The GCS address is looked up through a Redis client built from the
    ``ray_start_regular`` fixture's ``redis_address``. An error subscriber
    and a log subscriber are each polled from their own thread while the
    main thread publishes ``num_messages`` errors and log batches. The
    test then checks that both threads finished and that every message was
    received in publication order.
    """

    address_info = ray_start_regular
    redis = ray._private.services.create_redis_client(
        address_info["redis_address"],
        password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)

    gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)

    num_messages = 100

    errors = []
    error_subscriber = GcsErrorSubscriber(address=gcs_server_addr)
    # Subscribe on the main thread, before any publishing. Subscribing
    # inside the worker thread would race with the publish loop below and
    # messages sent before the subscription exists could be missed.
    error_subscriber.subscribe()

    def receive_errors():
        # Collect error payloads until the expected count is reached.
        while len(errors) < num_messages:
            _, msg = error_subscriber.poll()
            errors.append(msg)

    logs = []
    log_subscriber = GcsLogSubscriber(address=gcs_server_addr)
    # Same as above: register the subscription before publishing starts.
    log_subscriber.subscribe()

    def receive_logs():
        # Collect log batches until the expected count is reached.
        while len(logs) < num_messages:
            log_batch = log_subscriber.poll()
            logs.append(log_batch)

    # Daemon threads: if messages are lost the receive loops never finish
    # (presumably poll() blocks waiting for data), and non-daemon threads
    # would keep the interpreter alive after the assertions below fail.
    t1 = threading.Thread(target=receive_errors, daemon=True)
    t1.start()

    t2 = threading.Thread(target=receive_logs, daemon=True)
    t2.start()

    publisher = GcsPublisher(address=gcs_server_addr)
    for i in range(num_messages):
        publisher.publish_error(b"msg_id",
                                ErrorTableData(error_message=f"error {i}"))
        publisher.publish_logs({
            "ip": "127.0.0.1",
            "pid": "gcs",
            "job": "0001",
            "is_err": False,
            "lines": [f"line {i}"],
            "actor_name": "test actor",
            "task_name": "test task",
        })

    # Give each receiver a bounded amount of time to drain its channel.
    t1.join(timeout=10)
    assert not t1.is_alive(), len(errors)
    assert len(errors) == num_messages, len(errors)

    t2.join(timeout=10)
    assert not t2.is_alive(), len(logs)
    assert len(logs) == num_messages, len(logs)

    # Messages must arrive in the order they were published.
    for i in range(num_messages):
        assert errors[i].error_message == f"error {i}"
        assert logs[i]["lines"][0] == f"line {i}"