def test_two_subscribers(ray_start_regular):
    """Tests concurrently subscribing to two channels work."""
    gcs_server_addr = ray_start_regular["gcs_address"]

    num_messages = 100

    # Register the error subscription before any publishing starts so no
    # message can be missed.
    errors = []
    error_subscriber = GcsErrorSubscriber(address=gcs_server_addr)
    error_subscriber.subscribe()

    def receive_errors():
        # Drain the error channel until all expected messages arrived.
        while len(errors) < num_messages:
            _, msg = error_subscriber.poll()
            errors.append(msg)

    error_thread = threading.Thread(target=receive_errors)
    error_thread.start()

    # Same for the log channel: subscribe first, then start draining.
    logs = []
    log_subscriber = GcsLogSubscriber(address=gcs_server_addr)
    log_subscriber.subscribe()

    def receive_logs():
        while len(logs) < num_messages:
            logs.append(log_subscriber.poll())

    log_thread = threading.Thread(target=receive_logs)
    log_thread.start()

    # Interleave error and log publishes on the two channels.
    publisher = GcsPublisher(address=gcs_server_addr)
    for i in range(num_messages):
        publisher.publish_error(
            b"msg_id", ErrorTableData(error_message=f"error {i}"))
        publisher.publish_logs(
            {
                "ip": "127.0.0.1",
                "pid": "gcs",
                "job": "0001",
                "is_err": False,
                "lines": [f"log {i}"],
                "actor_name": "test actor",
                "task_name": "test task",
            }
        )

    # Both receiver threads must finish within the timeout with a full set
    # of messages, delivered in publish order.
    error_thread.join(timeout=10)
    assert len(errors) == num_messages, str(errors)
    assert not error_thread.is_alive(), str(errors)

    log_thread.join(timeout=10)
    assert len(logs) == num_messages, str(logs)
    assert not log_thread.is_alive(), str(logs)

    for i in range(num_messages):
        assert errors[i].error_message == f"error {i}", str(errors)
        assert logs[i]["lines"][0] == f"log {i}", str(logs)
async def test_aio_publish_and_subscribe_error_info(ray_start_regular):
    """Round-trips two error messages through the async GCS pub/sub API."""
    gcs_server_addr = ray_start_regular["gcs_address"]

    # Subscribe before publishing so both messages are buffered for us.
    subscriber = GcsAioErrorSubscriber(address=gcs_server_addr)
    await subscriber.subscribe()

    publisher = GcsAioPublisher(address=gcs_server_addr)
    published = [
        (b"aaa_id", ErrorTableData(error_message="test error message 1")),
        (b"bbb_id", ErrorTableData(error_message="test error message 2")),
    ]
    for error_id, error_data in published:
        await publisher.publish_error(error_id, error_data)

    # Messages must come back as (id, data) pairs in publish order.
    for expected in published:
        assert await subscriber.poll() == expected

    await subscriber.close()
async def test_aio_publish_and_subscribe_error_info_legacy(ray_start_regular):
    """Round-trips two error messages through the async GCS pub/sub API,
    resolving the GCS address from Redis (legacy code path).

    NOTE(review): this test previously shared its name with another
    ``test_aio_publish_and_subscribe_error_info`` defined earlier in the
    file; since this definition came second it rebound the module-level
    name, so the earlier (non-Redis) variant was shadowed and never
    collected by pytest. Renamed with a ``_legacy`` suffix so both tests
    run.
    """
    address_info = ray_start_regular
    # Legacy address resolution: look up the GCS server address via Redis.
    redis = ray._private.services.create_redis_client(
        address_info["redis_address"],
        password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
    gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)

    # Subscribe before publishing so no message is dropped.
    subscriber = GcsAioSubscriber(address=gcs_server_addr)
    await subscriber.subscribe_error()

    publisher = GcsAioPublisher(address=gcs_server_addr)
    err1 = ErrorTableData(error_message="test error message 1")
    err2 = ErrorTableData(error_message="test error message 2")
    await publisher.publish_error(b"aaa_id", err1)
    await publisher.publish_error(b"bbb_id", err2)

    # Messages are delivered as (id, data) pairs in publish order.
    assert await subscriber.poll_error() == (b"aaa_id", err1)
    assert await subscriber.poll_error() == (b"bbb_id", err2)

    await subscriber.close()
def construct_error_message(job_id, error_type, message, timestamp):
    """Construct a serialized ErrorTableData object.

    Args:
        job_id: The ID of the job that the error should go to. If this is
            nil, then the error will go to all drivers.
        error_type: The type of the error.
        message: The error message.
        timestamp: The time of the error.

    Returns:
        The serialized object.
    """
    # Build the protobuf message in one shot via keyword arguments; the
    # serialized bytes depend only on field values, not assignment order.
    error_info = ErrorTableData(
        job_id=job_id.binary(),
        type=error_type,
        error_message=message,
        timestamp=timestamp,
    )
    return error_info.SerializeToString()
def test_subscribe_two_channels(ray_start_regular):
    """Tests concurrently subscribing to two channels work."""
    address_info = ray_start_regular
    # Legacy address resolution: look up the GCS server address via Redis.
    redis = ray._private.services.create_redis_client(
        address_info["redis_address"],
        password=ray.ray_constants.REDIS_DEFAULT_PASSWORD)
    gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)

    num_messages = 100

    errors = []

    def receive_errors():
        # Each receiver thread owns its subscriber and drains until all
        # expected messages have arrived.
        subscriber = GcsErrorSubscriber(address=gcs_server_addr)
        subscriber.subscribe()
        while len(errors) < num_messages:
            _, msg = subscriber.poll()
            errors.append(msg)

    logs = []

    def receive_logs():
        subscriber = GcsLogSubscriber(address=gcs_server_addr)
        subscriber.subscribe()
        while len(logs) < num_messages:
            logs.append(subscriber.poll())

    error_thread = threading.Thread(target=receive_errors)
    error_thread.start()
    log_thread = threading.Thread(target=receive_logs)
    log_thread.start()

    # Interleave error and log publishes on the two channels.
    publisher = GcsPublisher(address=gcs_server_addr)
    for i in range(num_messages):
        publisher.publish_error(
            b"msg_id", ErrorTableData(error_message=f"error {i}"))
        publisher.publish_logs({
            "ip": "127.0.0.1",
            "pid": "gcs",
            "job": "0001",
            "is_err": False,
            "lines": [f"line {i}"],
            "actor_name": "test actor",
            "task_name": "test task",
        })

    # Both threads must finish within the timeout with a full, ordered set
    # of messages.
    error_thread.join(timeout=10)
    assert not error_thread.is_alive(), len(errors)
    assert len(errors) == num_messages, len(errors)

    log_thread.join(timeout=10)
    assert not log_thread.is_alive(), len(logs)
    assert len(logs) == num_messages, len(logs)

    for i in range(num_messages):
        assert errors[i].error_message == f"error {i}"
        assert logs[i]["lines"][0] == f"line {i}"