def test_publish_and_subscribe_logs(ray_start_regular):
    """Publish one log batch through GCS and check a subscriber receives it.

    The GCS server address is looked up through a Redis client created from
    the ``redis_address`` entry of the ``ray_start_regular`` fixture value.
    """
    # NOTE(review): another function with this exact name is defined later in
    # the visible source. If both live in the same module, the later one
    # shadows this one and this test is never collected -- confirm which
    # variant should be kept.
    address_info = ray_start_regular
    redis = ray._private.services.create_redis_client(
        address_info["redis_address"],
        password=ray.ray_constants.REDIS_DEFAULT_PASSWORD,
    )
    gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)

    subscriber = GcsLogSubscriber(address=gcs_server_addr)
    # Subscribe before anything is published so the batch is not missed.
    subscriber.subscribe()
    try:
        publisher = GcsPublisher(address=gcs_server_addr)
        log_batch = {
            "ip": "127.0.0.1",
            "pid": 1234,
            "job": "0001",
            "is_err": False,
            "lines": ["line 1", "line 2"],
            "actor_name": "test actor",
            "task_name": "test task",
        }
        publisher.publish_logs(log_batch)

        # PID is treated as string, so the received batch carries "1234"
        # even though an int was published.
        expected = {**log_batch, "pid": "1234"}
        assert subscriber.poll() == expected
    finally:
        # Release the subscription even when publishing or the assertion
        # above fails, so a failing test does not leak it.
        subscriber.close()
def test_two_subscribers(ray_start_regular):
    """Test that an error and a log subscriber can receive concurrently.

    Two subscribers are registered with the GCS address taken from the
    ``ray_start_regular`` fixture value, each drained by its own thread,
    while a single publisher alternates between publishing an error and a
    log batch. Every message must arrive, in publishing order.
    """
    address_info = ray_start_regular
    gcs_server_addr = address_info["gcs_address"]
    num_messages = 100

    errors = []
    error_subscriber = GcsErrorSubscriber(address=gcs_server_addr)
    # Make sure subscription is registered before publishing starts.
    error_subscriber.subscribe()

    def receive_errors():
        while len(errors) < num_messages:
            _, msg = error_subscriber.poll()
            errors.append(msg)

    # Daemon thread: if messages go missing, the receiver stays blocked in
    # poll(). A non-daemon thread would then keep the interpreter alive after
    # the assertions below fail, turning a test failure into a hang.
    t1 = threading.Thread(target=receive_errors, daemon=True)
    t1.start()

    logs = []
    log_subscriber = GcsLogSubscriber(address=gcs_server_addr)
    # Make sure subscription is registered before publishing starts.
    log_subscriber.subscribe()

    def receive_logs():
        while len(logs) < num_messages:
            log_batch = log_subscriber.poll()
            logs.append(log_batch)

    # Daemon for the same reason as the error receiver thread.
    t2 = threading.Thread(target=receive_logs, daemon=True)
    t2.start()

    publisher = GcsPublisher(address=gcs_server_addr)
    for i in range(num_messages):
        publisher.publish_error(
            b"msg_id", ErrorTableData(error_message=f"error {i}")
        )
        publisher.publish_logs(
            {
                "ip": "127.0.0.1",
                "pid": "gcs",
                "job": "0001",
                "is_err": False,
                "lines": [f"log {i}"],
                "actor_name": "test actor",
                "task_name": "test task",
            }
        )

    # Bounded joins: a receiver that never gets all messages must fail the
    # test instead of blocking it forever.
    t1.join(timeout=10)
    assert len(errors) == num_messages, str(errors)
    assert not t1.is_alive(), str(errors)

    t2.join(timeout=10)
    assert len(logs) == num_messages, str(logs)
    assert not t2.is_alive(), str(logs)

    # Messages on each channel must arrive in the order they were published.
    for i in range(num_messages):
        assert errors[i].error_message == f"error {i}", str(errors)
        assert logs[i]["lines"][0] == f"log {i}", str(logs)
def test_publish_and_subscribe_logs(ray_start_regular):
    """Publish one log batch through GCS and check a subscriber receives it.

    The GCS server address is read from the ``gcs_address`` entry of the
    ``ray_start_regular`` fixture value.
    """
    # NOTE(review): an earlier function with this exact name appears above in
    # the visible source. If both live in the same module, this later
    # definition is the one that takes effect -- confirm the earlier variant
    # is meant to be dropped.
    address_info = ray_start_regular
    gcs_server_addr = address_info["gcs_address"]

    subscriber = GcsLogSubscriber(address=gcs_server_addr)
    # Subscribe before anything is published so the batch is not missed.
    subscriber.subscribe()
    try:
        publisher = GcsPublisher(address=gcs_server_addr)
        log_batch = {
            "ip": "127.0.0.1",
            "pid": 1234,
            "job": "0001",
            "is_err": False,
            "lines": ["line 1", "line 2"],
            "actor_name": "test actor",
            "task_name": "test task",
        }
        publisher.publish_logs(log_batch)

        # PID is treated as string, so the received batch carries "1234"
        # even though an int was published.
        expected = {**log_batch, "pid": "1234"}
        assert subscriber.poll() == expected
    finally:
        # Release the subscription even when publishing or the assertion
        # above fails, so a failing test does not leak it.
        subscriber.close()
def test_subscribe_two_channels(ray_start_regular):
    """Test that subscribing to two channels concurrently works.

    An error subscriber and a log subscriber are each drained by their own
    thread while a single publisher alternates between publishing an error
    and a log batch. Every message must arrive, in publishing order.

    The GCS server address is looked up through a Redis client created from
    the ``redis_address`` entry of the ``ray_start_regular`` fixture value.
    """
    address_info = ray_start_regular
    redis = ray._private.services.create_redis_client(
        address_info["redis_address"],
        password=ray.ray_constants.REDIS_DEFAULT_PASSWORD,
    )
    gcs_server_addr = gcs_utils.get_gcs_address_from_redis(redis)
    num_messages = 100

    # Create and register both subscriptions on the main thread *before* any
    # receiver thread or the publisher starts. Subscribing inside the threads
    # races with the publishing loop below: messages published before a
    # subscription is registered are never delivered, which makes the test
    # flaky.
    errors = []
    error_subscriber = GcsErrorSubscriber(address=gcs_server_addr)
    error_subscriber.subscribe()

    logs = []
    log_subscriber = GcsLogSubscriber(address=gcs_server_addr)
    log_subscriber.subscribe()

    def receive_errors():
        while len(errors) < num_messages:
            _, msg = error_subscriber.poll()
            errors.append(msg)

    def receive_logs():
        while len(logs) < num_messages:
            log_batch = log_subscriber.poll()
            logs.append(log_batch)

    # Daemon threads: a receiver that never gets all of its messages stays
    # blocked in poll(). A non-daemon thread would then keep the interpreter
    # alive after the assertions below fail.
    t1 = threading.Thread(target=receive_errors, daemon=True)
    t1.start()

    t2 = threading.Thread(target=receive_logs, daemon=True)
    t2.start()

    publisher = GcsPublisher(address=gcs_server_addr)
    for i in range(num_messages):
        publisher.publish_error(
            b"msg_id", ErrorTableData(error_message=f"error {i}")
        )
        publisher.publish_logs(
            {
                "ip": "127.0.0.1",
                "pid": "gcs",
                "job": "0001",
                "is_err": False,
                "lines": [f"line {i}"],
                "actor_name": "test actor",
                "task_name": "test task",
            }
        )

    # Bounded joins: fail the test rather than block forever when a receiver
    # does not finish.
    t1.join(timeout=10)
    assert not t1.is_alive(), len(errors)
    assert len(errors) == num_messages, len(errors)

    t2.join(timeout=10)
    assert not t2.is_alive(), len(logs)
    assert len(logs) == num_messages, len(logs)

    # Messages on each channel must arrive in the order they were published.
    for i in range(num_messages):
        assert errors[i].error_message == f"error {i}"
        assert logs[i]["lines"][0] == f"line {i}"