def run_task_function(client: NotificationClient):
    """Submit a two-job workflow and drive it to completion via an event.

    task_5 is configured to depend on event ('key_1', 'value_1') relative to
    task_2; once task_2 has one TaskExecution row, that event is sent exactly
    once, then we poll until the DagRun reaches a finished state.
    """
    with af.global_config_file(workflow_config_file()):
        with af.config('task_2'):
            executor_1 = af.user_define_operation(
                af.PythonObjectExecutor(SimpleExecutor()))
        with af.config('task_5'):
            executor_2 = af.user_define_operation(
                af.PythonObjectExecutor(SimpleExecutor()))
        # task_5 starts only after the ('key_1', 'value_1') event arrives.
        af.user_define_control_dependency(src=executor_2,
                                          dependency=executor_1,
                                          namespace='test',
                                          event_key='key_1',
                                          event_value='value_1',
                                          sender='*')
        workflow_info = af.workflow_operation.submit_workflow(workflow_name)
        af.workflow_operation.start_new_workflow_execution(workflow_name)
        # flag guards against sending the trigger event more than once.
        flag = True
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'test_project.test_workflow',
                    TaskExecution.task_id == 'task_2').all()
                if 1 == len(tes) and flag:
                    client.send_event(
                        BaseEvent(key='key_1', value='value_1'))
                    flag = False
                dag_run = session.query(DagRun).filter(
                    DagRun.dag_id == 'test_project.test_workflow').first()
                if dag_run is not None and dag_run.state in State.finished:
                    break
                else:
                    time.sleep(1)
def run_trigger_task_function(self):
    """Schedule 'trigger_task', then explicitly START task_2 once task_1 ran.

    Polls TaskExecution rows; when task_1 has executed, issues a START action
    for task_2 and waits for its execution before stopping the scheduler.
    """
    # waiting parsed dag file done,
    time.sleep(5)
    ns_client = NotificationClient(server_uri="localhost:{}".format(
        self.port), default_namespace="a")
    client = EventSchedulerClient(ns_client=ns_client)
    execution_context = client.schedule_dag('trigger_task')
    while True:
        with create_session() as session:
            tes = session.query(TaskExecution).filter(
                TaskExecution.dag_id == 'trigger_task',
                TaskExecution.task_id == 'task_1').all()
            if len(tes) > 0:
                client.schedule_task('trigger_task', 'task_2',
                                     SchedulingAction.START,
                                     execution_context)
                # Inner poll: wait for task_2 to actually execute.
                while True:
                    with create_session() as session_2:
                        tes_2 = session_2.query(TaskExecution).filter(
                            TaskExecution.dag_id == 'trigger_task',
                            TaskExecution.task_id == 'task_2').all()
                        if len(tes_2) > 0:
                            break
                        else:
                            time.sleep(1)
                break
            else:
                time.sleep(1)
    # Shut the scheduler down so the test's main thread can finish.
    ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())
def run_event_task_function(self):
    """Send the 'start' event once event_dag.task_1 ran, then wait for task_2.

    Stops the scheduler after task_2 produces a TaskExecution row.
    """
    client = NotificationClient(server_uri="localhost:{}".format(
        self.port), default_namespace="")
    while True:
        with create_session() as session:
            tes = session.query(TaskExecution).filter(
                TaskExecution.dag_id == 'event_dag',
                TaskExecution.task_id == 'task_1').all()
            if len(tes) > 0:
                # Small grace period before triggering task_2.
                time.sleep(5)
                client.send_event(
                    BaseEvent(key='start', value='', event_type='', namespace=''))
                # Inner poll: wait for task_2 to execute.
                while True:
                    with create_session() as session_2:
                        tes_2 = session_2.query(TaskExecution).filter(
                            TaskExecution.dag_id == 'event_dag',
                            TaskExecution.task_id == 'task_2').all()
                        if len(tes_2) > 0:
                            break
                        else:
                            time.sleep(1)
                break
            else:
                time.sleep(1)
    client.send_event(StopSchedulerEvent(job_id=0).to_event())
def run_test_fun():
    """Run *test_function* with a fresh client; always stop the scheduler after."""
    time.sleep(3)
    uri = "localhost:{}".format(server_port())
    client = NotificationClient(server_uri=uri, default_namespace="test")
    try:
        test_function(client)
    except Exception as err:
        raise err
    finally:
        # Unconditionally shut the scheduler down, pass or fail.
        client.send_event(StopSchedulerEvent(job_id=0).to_event())
def _send_task_status_change_event(self):
    """Publish a TaskStateChangedEvent describing this job's task instance."""
    ti = self.task_instance
    state_event = TaskStateChangedEvent(ti.task_id,
                                        ti.dag_id,
                                        ti.execution_date,
                                        ti.state)
    event = state_event.to_event()
    client = NotificationClient(self.server_uri,
                                default_namespace=event.namespace,
                                sender=event.sender)
    self.log.info("LocalTaskJob sending event: {}".format(event))
    client.send_event(event)
def run_no_dag_file_function(self):
    """Attempt to schedule a dag ('no_dag') that has no dag file, then stop.

    Used to verify the scheduler handles scheduling requests for unknown dags.
    """
    ns_client = NotificationClient(server_uri="localhost:{}".format(
        self.port), default_namespace="")
    client = EventSchedulerClient(ns_client=ns_client)
    # NOTE(review): `session` is never used inside this block; the context
    # manager may be kept only for its setup/teardown side effects — confirm.
    with create_session() as session:
        client.trigger_parse_dag()
        result = client.schedule_dag('no_dag')
        print('result {}'.format(result.dagrun_id))
        time.sleep(5)
    ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())
def execute(self, function_context: FunctionContext, input_list: List) -> List:
    """Send one configured BaseEvent; produces no downstream output."""
    from notification_service.client import NotificationClient
    uri = "localhost:{}".format(self.port)
    client = NotificationClient(server_uri=uri,
                                default_namespace="default",
                                sender=self.sender)
    event = BaseEvent(key=self.key,
                      value=self.value,
                      event_type=self.event_type)
    client.send_event(event)
    return []
def _send_request_and_receive_response(self, server_uri, file_path):
    """Send a PARSE_DAG_REQUEST event and assert the matching response arrives.

    :param server_uri: notification server URI to connect to
    :param file_path: dag file path carried as the event value
    """
    # Unique key per call so repeated runs don't collide on the same key.
    key = '{}_{}'.format(file_path, time.time_ns())
    client = NotificationClient(server_uri=server_uri,
                                default_namespace=SCHEDULER_NAMESPACE)
    event = BaseEvent(key=key,
                      event_type=SchedulerInnerEventType.PARSE_DAG_REQUEST.value,
                      value=file_path)
    client.send_event(event)
    watcher: ResponseWatcher = ResponseWatcher()
    # NOTE(review): listening starts after the request was sent — this relies
    # on the server delivering/replaying events for the key; confirm no race.
    client.start_listen_event(key=key,
                              event_type=SchedulerInnerEventType.PARSE_DAG_RESPONSE.value,
                              watcher=watcher)
    res: BaseEvent = watcher.get_result()
    # Fix: assertEquals is a deprecated alias of assertEqual (removed in
    # Python 3.12); use the canonical name.
    self.assertEqual(event.key, res.key)
    self.assertEqual(event.value, file_path)
def stop_workflow(self, workflow_name) -> bool:
    """
    Stop the workflow. No more workflow execution(Airflow dag_run) would be scheduled and all running jobs would be stopped.
    :param workflow_name: workflow name
    :return: True if succeed
    """
    # TODO For now, simply return True as long as message is sent successfully,
    # actually we need a response from
    try:
        client = NotificationClient(self.server_uri, SCHEDULER_NAMESPACE)
        client.send_event(StopDagEvent(workflow_name).to_event())
    except Exception:
        # Best-effort: any failure to deliver the stop message reports False.
        return False
    return True
def run_trigger_dag_function(self):
    """Repeatedly parse+schedule 'trigger_dag' until task_1 executes, then stop.

    Each failed poll re-triggers dag parsing and scheduling before sleeping.
    """
    ns_client = NotificationClient(server_uri="localhost:{}".format(
        self.port), default_namespace="")
    client = EventSchedulerClient(ns_client=ns_client)
    while True:
        with create_session() as session:
            tes = session.query(TaskExecution).filter(
                TaskExecution.dag_id == 'trigger_dag',
                TaskExecution.task_id == 'task_1').all()
            if len(tes) > 0:
                break
            else:
                # Not executed yet: request a fresh parse and schedule again.
                client.trigger_parse_dag()
                result = client.schedule_dag('trigger_dag')
                print('result {}'.format(result.dagrun_id))
                time.sleep(5)
    ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())
def run_airflow_dag_function(self):
    """Deploy a one-task project to Airflow and run it, then stop the scheduler."""
    # waiting parsed dag file done
    from datetime import datetime
    ns_client = NotificationClient(server_uri='localhost:50051')
    with af.global_config_file(test_util.get_workflow_config_file()):
        with af.config('task_1'):
            cmd_executor = af.user_define_operation(
                output_num=0,
                executor=CmdExecutor(cmd_line=['echo "hello world!"']))
    af.deploy_to_airflow(test_util.get_project_path(),
                         dag_id='test_dag_111',
                         default_args={
                             'schedule_interval': None,
                             'start_date': datetime(2025, 12, 1),
                         })
    context = af.run(project_path=test_util.get_project_path(),
                     dag_id='test_dag_111',
                     scheduler_type=SchedulerType.AIRFLOW)
    print(context.dagrun_id)
    time.sleep(5)
    ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())
def run_ai_flow_function(self):
    """Trigger the remaining jobs of 'workflow_1' once '1-job-name' executed.

    Sends key_1/key_2 events, waits until the workflow has 3 TaskExecution
    rows, then stops the scheduler.
    """
    client = NotificationClient(server_uri="localhost:{}".format(
        self.port), default_namespace="default", sender='1-job-name')
    while True:
        with create_session() as session:
            tes = session.query(TaskExecution).filter(
                TaskExecution.dag_id == 'workflow_1',
                TaskExecution.task_id == '1-job-name').all()
            if len(tes) > 0:
                time.sleep(5)
                client.send_event(
                    BaseEvent(key='key_1', value='value_1', event_type='UNDEFINED'))
                client.send_event(
                    BaseEvent(key='key_2', value='value_2', event_type='UNDEFINED'))
                # Inner poll: wait until all three jobs of workflow_1 executed.
                while True:
                    with create_session() as session_2:
                        tes_2 = session_2.query(TaskExecution).filter(
                            TaskExecution.dag_id == 'workflow_1').all()
                        if len(tes_2) == 3:
                            break
                        else:
                            time.sleep(1)
                break
            else:
                time.sleep(1)
    time.sleep(3)
    client.send_event(StopSchedulerEvent(job_id=0).to_event())
def test_send_listening_on_different_server(self):
    """Events sent through a second HA server are seen by a listener on the first."""
    event_list = []

    class TestWatch(EventWatcher):
        # Collect every delivered event into the shared list.
        def __init__(self, event_list) -> None:
            super().__init__()
            self.event_list = event_list

        def process(self, events: List[BaseEvent]):
            self.event_list.extend(events)

    self.master2 = self.start_master("localhost", "50052")
    self.wait_for_new_members_detected("localhost:50052")
    another_client = NotificationClient(server_uri="localhost:50052")
    try:
        event1 = another_client.send_event(BaseEvent(key="key1", value="value1"))
        # Listen from event1's version, so only key2/key3 are delivered.
        self.client.start_listen_events(watcher=TestWatch(event_list),
                                        version=event1.version)
        another_client.send_event(BaseEvent(key="key2", value="value2"))
        another_client.send_event(BaseEvent(key="key3", value="value3"))
    finally:
        self.client.stop_listen_events()
    self.assertEqual(2, len(event_list))
class ModelCenterService(model_center_service_pb2_grpc.ModelCenterServiceServicer):
    """gRPC servicer for registered-model / model-version CRUD.

    Persists to MongoDB or a SQLAlchemy store depending on the store URI and,
    when a notification URI is given, publishes model-version change events
    through a NotificationClient.
    """

    def __init__(self, store_uri, notification_uri=None):
        # Pick the backing store implementation from the URI's engine.
        db_engine = extract_db_engine_from_uri(store_uri)
        if DBType.value_of(db_engine) == DBType.MONGODB:
            username, password, host, port, db = parse_mongo_uri(store_uri)
            self.model_repo_store = MongoStore(host=host,
                                               port=int(port),
                                               username=username,
                                               password=password,
                                               db=db)
        else:
            self.model_repo_store = SqlAlchemyStore(store_uri)
        # Event publishing is optional; None disables notifications.
        self.notification_client = None
        if notification_uri is not None:
            self.notification_client = NotificationClient(
                notification_uri, default_namespace=DEFAULT_NAMESPACE)

    @catch_exception
    def createRegisteredModel(self, request, context):
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.create_registered_model(
            registered_model_param.model_name,
            registered_model_param.model_desc)
        return _wrap_response(registered_model_meta.to_meta_proto())

    @catch_exception
    def updateRegisteredModel(self, request, context):
        model_meta_param = RegisteredModel.from_proto(request)
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.update_registered_model(
            RegisteredModel(model_meta_param.model_name),
            registered_model_param.model_name,
            registered_model_param.model_desc)
        # Store returns None when there was nothing to update.
        return _wrap_response(None if registered_model_meta is None
                              else registered_model_meta.to_meta_proto())

    @catch_exception
    def deleteRegisteredModel(self, request, context):
        model_meta_param = RegisteredModel.from_proto(request)
        self.model_repo_store.delete_registered_model(
            RegisteredModel(model_name=model_meta_param.model_name))
        # Echo the request's model_meta back as the response payload.
        return _wrap_response(request.model_meta)

    @catch_exception
    def listRegisteredModels(self, request, context):
        registered_models = self.model_repo_store.list_registered_models()
        return _wrap_response(RegisteredModelMetas(
            registered_models=[registered_model.to_meta_proto()
                               for registered_model in registered_models]))

    @catch_exception
    def getRegisteredModelDetail(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        registered_model_detail = self.model_repo_store.get_registered_model_detail(
            RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(None if registered_model_detail is None
                              else registered_model_detail.to_detail_proto())

    @catch_exception
    def createModelVersion(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.create_model_version(
            model_meta_param.model_name,
            model_version_param.model_path,
            model_version_param.model_type,
            model_version_param.version_desc,
            model_version_param.current_stage)
        # Map the version's stage to its notification event type.
        event_type = MODEL_VERSION_TO_EVENT_TYPE.get(
            ModelVersionStage.from_string(model_version_param.current_stage))
        if self.notification_client is not None:
            self.notification_client.send_event(
                BaseEvent(model_version_meta.model_name,
                          json.dumps(model_version_meta.__dict__),
                          event_type))
        return _wrap_response(model_version_meta.to_meta_proto())

    @catch_exception
    def updateModelVersion(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.update_model_version(
            model_meta_param,
            model_version_param.model_path,
            model_version_param.model_type,
            model_version_param.version_desc,
            model_version_param.current_stage)
        # Only publish an event when the stage actually changed in the request.
        if model_version_param.current_stage is not None:
            event_type = MODEL_VERSION_TO_EVENT_TYPE.get(
                ModelVersionStage.from_string(model_version_param.current_stage))
            if self.notification_client is not None:
                self.notification_client.send_event(
                    BaseEvent(model_version_meta.model_name,
                              json.dumps(model_version_meta.__dict__),
                              event_type))
        return _wrap_response(None if model_version_meta is None
                              else model_version_meta.to_meta_proto())

    @catch_exception
    def deleteModelVersion(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        self.model_repo_store.delete_model_version(model_meta_param)
        if self.notification_client is not None:
            self.notification_client.send_event(
                BaseEvent(model_meta_param.model_name,
                          json.dumps(model_meta_param.__dict__),
                          ModelVersionEventType.MODEL_DELETED))
        return _wrap_response(request.model_meta)

    @catch_exception
    def getModelVersionDetail(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        model_version_meta = self.model_repo_store.get_model_version_detail(model_meta_param)
        return _wrap_response(None if model_version_meta is None
                              else model_version_meta.to_meta_proto())
class HaServerTest(unittest.TestCase):
    """Tests for high-availability failover of the notification server."""

    @classmethod
    def start_master(cls, host, port):
        """Start and return one HA-enabled NotificationMaster on host:port."""
        port = str(port)
        server_uri = host + ":" + port
        storage = DbEventStorage()
        ha_manager = SimpleNotificationServerHaManager()
        ha_storage = DbHighAvailabilityStorage()
        service = HighAvailableNotificationService(
            storage, ha_manager, server_uri, ha_storage)
        master = NotificationMaster(service, port=int(port))
        master.run()
        return master

    @classmethod
    def setUpClass(cls):
        cls.storage = DbEventStorage()
        cls.master1 = None
        cls.master2 = None
        cls.master3 = None

    def setUp(self):
        self.storage.clean_up()
        self.master1 = self.start_master("localhost", "50051")
        self.client = NotificationClient(server_uri="localhost:50051",
                                         enable_ha=True)

    def tearDown(self):
        self.client.stop_listen_events()
        self.client.stop_listen_event()
        self.client.disable_high_availability()
        if self.master1 is not None:
            self.master1.stop()
        if self.master2 is not None:
            self.master2.stop()
        if self.master3 is not None:
            self.master3.stop()

    def wait_for_new_members_detected(self, new_member_uri):
        """Poll up to ~10s until the client's member list contains the URI."""
        for i in range(100):
            living_member = self.client.living_members
            if new_member_uri in living_member:
                break
            else:
                time.sleep(0.1)

    def test_server_change(self):
        """Client fails over to a surviving server and keeps the same events."""
        self.client.send_event(BaseEvent(key="key", value="value1"))
        self.client.send_event(BaseEvent(key="key", value="value2"))
        self.client.send_event(BaseEvent(key="key", value="value3"))
        results = self.client.list_all_events()
        self.assertEqual(self.client.current_uri, "localhost:50051")
        self.master2 = self.start_master("localhost", "50052")
        self.wait_for_new_members_detected("localhost:50052")
        self.master1.stop()
        results2 = self.client.list_all_events()
        self.assertEqual(results, results2)
        self.assertEqual(self.client.current_uri, "localhost:50052")
        self.master3 = self.start_master("localhost", "50053")
        self.wait_for_new_members_detected("localhost:50053")
        self.master2.stop()
        results3 = self.client.list_all_events()
        self.assertEqual(results2, results3)
        self.assertEqual(self.client.current_uri, "localhost:50053")

    def test_send_listening_on_different_server(self):
        """Events sent via a second server reach a listener on the first."""
        event_list = []

        class TestWatch(EventWatcher):
            def __init__(self, event_list) -> None:
                super().__init__()
                self.event_list = event_list

            def process(self, events: List[BaseEvent]):
                self.event_list.extend(events)

        self.master2 = self.start_master("localhost", "50052")
        self.wait_for_new_members_detected("localhost:50052")
        another_client = NotificationClient(server_uri="localhost:50052")
        try:
            event1 = another_client.send_event(BaseEvent(key="key1", value="value1"))
            # Listening from event1's version, only key2/key3 are delivered.
            self.client.start_listen_events(watcher=TestWatch(event_list),
                                            version=event1.version)
            another_client.send_event(BaseEvent(key="key2", value="value2"))
            another_client.send_event(BaseEvent(key="key3", value="value3"))
        finally:
            self.client.stop_listen_events()
        self.assertEqual(2, len(event_list))

    def test_start_with_multiple_servers(self):
        """Client connects to the first reachable URI of a comma-separated list."""
        self.client.disable_high_availability()
        self.client = NotificationClient(server_uri="localhost:55001,localhost:50051",
                                         enable_ha=True)
        # Fix: assertTrue(a, b) treats b as the failure *message*, so the
        # original assertion was vacuous; assertEqual actually compares.
        self.assertEqual(self.client.current_uri, "localhost:50051")
def execute(self, context):
    """Publish the operator's configured event through ``self.uri``."""
    client = NotificationClient(server_uri=self.uri)
    client.send_event(event=self.event)
class NotificationTest(unittest.TestCase):
    """End-to-end tests of send/list/listen against one NotificationMaster."""

    @classmethod
    def setUpClass(cls):
        cls.storage = EventModelStorage()
        cls.master = NotificationMaster(NotificationService(cls.storage))
        cls.master.run()

    @classmethod
    def tearDownClass(cls):
        cls.master.stop()

    def setUp(self):
        # Fresh storage and client per test.
        self.storage.clean_up()
        self.client = NotificationClient(server_uri="localhost:50051")

    def tearDown(self):
        self.client.stop_listen_events()
        self.client.stop_listen_event()

    def test_send_event(self):
        event = self.client.send_event(Event(key="key", value="value1"))
        self.assertTrue(event.version > 0)

    def test_list_events(self):
        event1 = self.client.send_event(Event(key="key", value="value1"))
        event2 = self.client.send_event(Event(key="key", value="value2"))
        event3 = self.client.send_event(Event(key="key", value="value3"))
        # Listing from event1's version excludes event1 itself.
        events = self.client.list_events("key", version=event1.version)
        self.assertEqual(2, len(events))

    def test_listen_events(self):
        event_list = []

        class TestWatch(EventWatcher):
            # Collects every delivered event into the shared list.
            def __init__(self, event_list) -> None:
                super().__init__()
                self.event_list = event_list

            def process(self, events: List[Event]):
                self.event_list.extend(events)

        event1 = self.client.send_event(Event(key="key", value="value1"))
        self.client.start_listen_event(key="key",
                                       watcher=TestWatch(event_list),
                                       version=event1.version)
        event = self.client.send_event(Event(key="key", value="value2"))
        event = self.client.send_event(Event(key="key", value="value3"))
        self.client.stop_listen_event("key")
        events = self.client.list_events("key", version=event1.version)
        self.assertEqual(2, len(events))
        self.assertEqual(2, len(event_list))

    def test_all_listen_events(self):
        event = self.client.send_event(Event(key="key", value="value1"))
        event = self.client.send_event(Event(key="key", value="value2"))
        # Use the second event's create time as the listing lower bound.
        start_time = event.create_time
        event = self.client.send_event(Event(key="key", value="value3"))
        events = self.client.list_all_events(start_time)
        self.assertEqual(2, len(events))

    def test_listen_all_events(self):
        event_list = []

        class TestWatch(EventWatcher):
            # Collects every delivered event into the shared list.
            def __init__(self, event_list) -> None:
                super().__init__()
                self.event_list = event_list

            def process(self, events: List[Event]):
                self.event_list.extend(events)

        try:
            self.client.start_listen_events(watcher=TestWatch(event_list))
            event = self.client.send_event(Event(key="key1", value="value1"))
            event = self.client.send_event(Event(key="key2", value="value2"))
            event = self.client.send_event(Event(key="key3", value="value3"))
        finally:
            self.client.stop_listen_events()
        self.assertEqual(3, len(event_list))
class TestEventBasedScheduler(unittest.TestCase):
    """Integration tests for the event-based scheduler.

    Each test starts the scheduler in the main thread (blocking) and drives
    it from a helper thread/process that polls the DB and sends events,
    finally sending StopSchedulerEvent so the scheduler's run() returns.
    """

    def setUp(self):
        # Reset all Airflow metadata tables so each test starts clean.
        db.clear_db_jobs()
        db.clear_db_dags()
        db.clear_db_serialized_dags()
        db.clear_db_runs()
        db.clear_db_task_execution()
        db.clear_db_message()
        self.scheduler = None
        self.port = 50102
        self.storage = MemoryEventStorage()
        self.master = NotificationMaster(NotificationService(self.storage), self.port)
        self.master.run()
        self.client = NotificationClient(server_uri="localhost:{}".format(
            self.port), default_namespace="test_namespace")
        time.sleep(1)

    def tearDown(self):
        self.master.stop()

    def _get_task_instance(self, dag_id, task_id, session):
        # First matching TaskInstance or None.
        return session.query(TaskInstance).filter(
            TaskInstance.dag_id == dag_id,
            TaskInstance.task_id == task_id).first()

    def schedule_task_function(self):
        """Drive both tasks SCHEDULED -> RUNNING, then stop/restart one.

        Ends when sleep_1000_secs is KILLED while python_sleep still runs.
        """
        stopped = False
        while not stopped:
            with create_session() as session:
                ti_sleep_1000_secs = self._get_task_instance(
                    EVENT_BASED_SCHEDULER_DAG, 'sleep_1000_secs', session)
                ti_python_sleep = self._get_task_instance(
                    EVENT_BASED_SCHEDULER_DAG, 'python_sleep', session)
                if ti_sleep_1000_secs and ti_sleep_1000_secs.state == State.SCHEDULED and \
                        ti_python_sleep and ti_python_sleep.state == State.SCHEDULED:
                    self.client.send_event(
                        BaseEvent(key='start', value='',
                                  event_type='', namespace='test_namespace'))
                    # Wait until both tasks are actually RUNNING.
                    while not stopped:
                        ti_sleep_1000_secs.refresh_from_db()
                        ti_python_sleep.refresh_from_db()
                        if ti_sleep_1000_secs and ti_sleep_1000_secs.state == State.RUNNING and \
                                ti_python_sleep and ti_python_sleep.state == State.RUNNING:
                            time.sleep(10)
                            break
                        else:
                            time.sleep(1)
                    self.client.send_event(
                        BaseEvent(key='stop', value='',
                                  event_type=UNDEFINED_EVENT_TYPE,
                                  namespace='test_namespace'))
                    self.client.send_event(
                        BaseEvent(key='restart', value='',
                                  event_type=UNDEFINED_EVENT_TYPE,
                                  namespace='test_namespace'))
                    # Wait for the stop/restart to take effect.
                    while not stopped:
                        ti_sleep_1000_secs.refresh_from_db()
                        ti_python_sleep.refresh_from_db()
                        if ti_sleep_1000_secs and ti_sleep_1000_secs.state == State.KILLED and \
                                ti_python_sleep and ti_python_sleep.state == State.RUNNING:
                            stopped = True
                        else:
                            time.sleep(1)
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_event_based_scheduler(self):
        t = threading.Thread(target=self.schedule_task_function)
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_event_based_scheduler.py')

    def test_replay_message(self):
        """An event is persisted as a Message and replayable per job id."""
        key = "stop"
        mailbox = Mailbox()
        mailbox.set_scheduling_job_id(1234)
        watcher = SchedulerEventWatcher(mailbox)
        self.client.start_listen_events(
            watcher=watcher,
            start_time=int(time.time() * 1000),
            version=None)
        self.send_event(key)
        msg: BaseEvent = mailbox.get_message()
        self.assertEqual(msg.key, key)
        with create_session() as session:
            msg_from_db = session.query(Message).first()
            # A different job id must see no unprocessed messages.
            expect_non_unprocessed = EventBasedScheduler.get_unprocessed_message(1000)
            self.assertEqual(0, len(expect_non_unprocessed))
            unprocessed = EventBasedScheduler.get_unprocessed_message(1234)
            self.assertEqual(unprocessed[0].serialized_message, msg_from_db.data)
            deserialized_data = pickle.loads(msg_from_db.data)
            self.assertEqual(deserialized_data.key, key)
            self.assertEqual(msg, deserialized_data)

    def send_event(self, key):
        event = self.client.send_event(
            BaseEvent(key=key, event_type=UNDEFINED_EVENT_TYPE, value="value1"))
        self.assertEqual(key, event.key)

    @provide_session
    def get_task_execution(self, dag_id, task_id, session):
        # All TaskExecution rows for dag_id/task_id.
        return session.query(TaskExecution).filter(
            TaskExecution.dag_id == dag_id,
            TaskExecution.task_id == task_id).all()

    @provide_session
    def get_latest_job_id(self, session):
        return session.query(BaseJob).order_by(sqlalchemy.desc(
            BaseJob.id)).first().id

    def start_scheduler(self, file_path):
        """Run the scheduler for file_path; blocks until StopSchedulerEvent."""
        self.scheduler = EventBasedSchedulerJob(
            dag_directory=file_path,
            server_uri="localhost:{}".format(self.port),
            executor=LocalExecutor(3),
            max_runs=-1,
            refresh_dag_dir_interval=30)
        print("scheduler starting")
        self.scheduler.run()

    def wait_for_running(self):
        # Poll until start_scheduler has assigned self.scheduler.
        while True:
            if self.scheduler is not None:
                time.sleep(5)
                break
            else:
                time.sleep(1)

    def wait_for_task_execution(self, dag_id, task_id, expected_num):
        """Poll up to ~200s for an exact TaskExecution count; assert success."""
        result = False
        check_nums = 100
        while check_nums > 0:
            time.sleep(2)
            check_nums = check_nums - 1
            tes = self.get_task_execution(dag_id, task_id)
            if len(tes) == expected_num:
                result = True
                break
        self.assertTrue(result)

    def wait_for_task(self, dag_id, task_id, expected_state):
        """Poll up to ~200s for the TaskInstance state; assert success."""
        result = False
        check_nums = 100
        while check_nums > 0:
            time.sleep(2)
            check_nums = check_nums - 1
            with create_session() as session:
                ti = session.query(TaskInstance).filter(
                    TaskInstance.dag_id == dag_id,
                    TaskInstance.task_id == task_id).first()
                if ti and ti.state == expected_state:
                    result = True
                    break
        self.assertTrue(result)

    def test_notification(self):
        self.client.send_event(BaseEvent(key='a', value='b'))

    def run_a_task_function(self):
        # Stop the scheduler once 'single.task_1' has executed.
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'single',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    break
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_a_task(self):
        t = threading.Thread(target=self.run_a_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_single_task_dag.py')
        tes: List[TaskExecution] = self.get_task_execution("single", "task_1")
        self.assertEqual(len(tes), 1)

    def run_event_task_function(self):
        # Send 'start' once event_dag.task_1 executed, then wait for task_2.
        client = NotificationClient(server_uri="localhost:{}".format(
            self.port), default_namespace="")
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'event_dag',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    time.sleep(5)
                    client.send_event(
                        BaseEvent(key='start', value='', event_type='', namespace=''))
                    while True:
                        with create_session() as session_2:
                            tes_2 = session_2.query(TaskExecution).filter(
                                TaskExecution.dag_id == 'event_dag',
                                TaskExecution.task_id == 'task_2').all()
                            if len(tes_2) > 0:
                                break
                            else:
                                time.sleep(1)
                    break
                else:
                    time.sleep(1)
        client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_event_task(self):
        t = threading.Thread(target=self.run_event_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_event_task_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "event_dag", "task_2")
        self.assertEqual(len(tes), 1)

    def run_trigger_dag_function(self):
        # Re-parse and re-schedule 'trigger_dag' until task_1 executes.
        ns_client = NotificationClient(server_uri="localhost:{}".format(
            self.port), default_namespace="")
        client = EventSchedulerClient(ns_client=ns_client)
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'trigger_dag',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    break
                else:
                    client.trigger_parse_dag()
                    result = client.schedule_dag('trigger_dag')
                    print('result {}'.format(result.dagrun_id))
                    time.sleep(5)
        ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_trigger_dag(self):
        import multiprocessing
        p = multiprocessing.Process(target=self.run_trigger_dag_function,
                                    args=())
        p.start()
        self.start_scheduler('../../dags/test_run_trigger_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "trigger_dag", "task_1")
        self.assertEqual(len(tes), 1)

    def run_no_dag_file_function(self):
        # Try to schedule a dag that has no dag file, then stop the scheduler.
        ns_client = NotificationClient(server_uri="localhost:{}".format(
            self.port), default_namespace="")
        client = EventSchedulerClient(ns_client=ns_client)
        with create_session() as session:
            client.trigger_parse_dag()
            result = client.schedule_dag('no_dag')
            print('result {}'.format(result.dagrun_id))
            time.sleep(5)
        ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_no_dag_file_trigger_dag(self):
        import multiprocessing
        p = multiprocessing.Process(target=self.run_no_dag_file_function,
                                    args=())
        p.start()
        self.start_scheduler('../../dags/test_run_trigger_dag.py')
        # Nothing may have been executed for the unknown dag.
        tes: List[TaskExecution] = self.get_task_execution(
            "trigger_dag", "task_1")
        self.assertEqual(len(tes), 0)

    def run_trigger_task_function(self):
        # waiting parsed dag file done,
        time.sleep(5)
        ns_client = NotificationClient(server_uri="localhost:{}".format(
            self.port), default_namespace="a")
        client = EventSchedulerClient(ns_client=ns_client)
        execution_context = client.schedule_dag('trigger_task')
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'trigger_task',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    # task_1 ran: explicitly START task_2 in the same dagrun.
                    client.schedule_task('trigger_task', 'task_2',
                                         SchedulingAction.START,
                                         execution_context)
                    while True:
                        with create_session() as session_2:
                            tes_2 = session_2.query(TaskExecution).filter(
                                TaskExecution.dag_id == 'trigger_task',
                                TaskExecution.task_id == 'task_2').all()
                            if len(tes_2) > 0:
                                break
                            else:
                                time.sleep(1)
                    break
                else:
                    time.sleep(1)
        ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_task_trigger_dag(self):
        import threading
        t = threading.Thread(target=self.run_trigger_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_task_trigger_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "trigger_task", "task_2")
        self.assertEqual(len(tes), 1)

    def run_ai_flow_function(self):
        # Trigger the rest of workflow_1 after '1-job-name' has executed.
        client = NotificationClient(server_uri="localhost:{}".format(
            self.port), default_namespace="default", sender='1-job-name')
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'workflow_1',
                    TaskExecution.task_id == '1-job-name').all()
                if len(tes) > 0:
                    time.sleep(5)
                    client.send_event(
                        BaseEvent(key='key_1', value='value_1',
                                  event_type='UNDEFINED'))
                    client.send_event(
                        BaseEvent(key='key_2', value='value_2',
                                  event_type='UNDEFINED'))
                    # Wait until all three jobs of workflow_1 have executed.
                    while True:
                        with create_session() as session_2:
                            tes_2 = session_2.query(TaskExecution).filter(
                                TaskExecution.dag_id == 'workflow_1').all()
                            if len(tes_2) == 3:
                                break
                            else:
                                time.sleep(1)
                    break
                else:
                    time.sleep(1)
        time.sleep(3)
        client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_ai_flow_dag(self):
        import threading
        t = threading.Thread(target=self.run_ai_flow_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_aiflow_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "workflow_1", "1-job-name")
        self.assertEqual(len(tes), 1)

    def stop_dag_function(self):
        """Send StopDagEvent once the target task ran; wait until it is KILLED."""
        stopped = False
        while not stopped:
            tes = self.get_task_execution(EVENT_BASED_SCHEDULER_DAG,
                                          'sleep_to_be_stopped')
            if tes and len(tes) == 1:
                self.client.send_event(
                    StopDagEvent(EVENT_BASED_SCHEDULER_DAG).to_event())
                while not stopped:
                    tes2 = self.get_task_execution(EVENT_BASED_SCHEDULER_DAG,
                                                   'sleep_to_be_stopped')
                    if tes2[0].state == State.KILLED:
                        stopped = True
                        time.sleep(5)
                    else:
                        time.sleep(1)
            else:
                time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_stop_dag(self):
        t = threading.Thread(target=self.stop_dag_function)
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_event_based_scheduler.py')
        with create_session() as session:
            from airflow.models import DagModel
            dag_model: DagModel = DagModel.get_dagmodel(
                EVENT_BASED_SCHEDULER_DAG)
            # Stopping the dag pauses it and kills its last run.
            self.assertTrue(dag_model.is_paused)
            self.assertEqual(dag_model.get_last_dagrun().state, "killed")
            for ti in session.query(TaskInstance).filter(
                    TaskInstance.dag_id == EVENT_BASED_SCHEDULER_DAG):
                self.assertTrue(ti.state in [State.SUCCESS, State.KILLED])
            for te in session.query(TaskExecution).filter(
                    TaskExecution.dag_id == EVENT_BASED_SCHEDULER_DAG):
                self.assertTrue(te.state in [State.SUCCESS, State.KILLED])

    def run_periodic_task_function(self):
        # Wait until the periodic task has executed more than once.
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'single',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 1:
                    break
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_periodic_task(self):
        t = threading.Thread(target=self.run_periodic_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_periodic_task_dag.py')
        tes: List[TaskExecution] = self.get_task_execution("single", "task_1")
        self.assertGreater(len(tes), 1)

    def run_one_task_function(self):
        # Trigger the same task twice via events; expect two executions.
        self.wait_for_running()
        self.client.send_event(BaseEvent(key='a', value='a'))
        time.sleep(5)
        self.client.send_event(BaseEvent(key='a', value='a'))
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'single',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) >= 2:
                    break
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_one_task(self):
        t = threading.Thread(target=self.run_one_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_multiple_trigger_task_dag.py')
class EventSchedulerClient(object):
    """Client for the event-based scheduler over the notification service.

    Each call sends a request event and blocks on a ResponseWatcher for the
    correlated response; requests and responses are matched by a generated
    unique id.
    """

    def __init__(self, server_uri=None, namespace=None, ns_client=None):
        # An externally supplied NotificationClient takes precedence over
        # constructing a new one from server_uri/namespace.
        if ns_client is None:
            self.ns_client = NotificationClient(server_uri, namespace)
        else:
            self.ns_client = ns_client

    @staticmethod
    def generate_id(id):
        # time_ns() keeps ids unique even for rapid successive requests.
        return '{}_{}'.format(id, time.time_ns())

    def trigger_parse_dag(self) -> bool:
        """Ask the scheduler to re-parse DAG files; block until it responds.

        The response payload is not inspected — only its arrival matters —
        so this always returns True once the scheduler has answered.
        """
        id = self.generate_id('')
        watcher: ResponseWatcher = ResponseWatcher()
        handler: ThreadEventWatcherHandle = self.ns_client.start_listen_event(
            key=id,
            event_type=SchedulerInnerEventType.PARSE_DAG_RESPONSE.value,
            namespace=SCHEDULER_NAMESPACE,
            watcher=watcher)
        try:
            self.ns_client.send_event(BaseEvent(
                key=id,
                event_type=SchedulerInnerEventType.PARSE_DAG_REQUEST.value,
                value=''))
            watcher.get_result()
        finally:
            # Fix: the original leaked the listener if send/wait raised.
            handler.stop()
        return True

    def _request(self, id, body) -> ResponseEvent:
        """Send a RequestEvent carrying *body*, block for the correlated
        ResponseEvent, and always stop the listener (even on error)."""
        watcher: ResponseWatcher = ResponseWatcher()
        handler: ThreadEventWatcherHandle = self.ns_client.start_listen_event(
            key=id,
            event_type=SchedulerInnerEventType.RESPONSE.value,
            namespace=SCHEDULER_NAMESPACE,
            watcher=watcher)
        try:
            self.ns_client.send_event(
                RequestEvent(request_id=id, body=body).to_event())
            return ResponseEvent.from_base_event(watcher.get_result())
        finally:
            handler.stop()

    def schedule_dag(self, dag_id) -> ExecutionContext:
        """Start a new run of *dag_id*; returns its execution context
        (the response body is the dagrun id)."""
        id = self.generate_id(dag_id)
        result = self._request(id, RunDagMessage(dag_id).to_json())
        return ExecutionContext(dagrun_id=result.body)

    def stop_dag_run(self, dag_id, context: ExecutionContext) -> ExecutionContext:
        """Stop the dag run identified by *context*."""
        id = self.generate_id(str(dag_id) + str(context.dagrun_id))
        result = self._request(id, StopDagRunMessage(
            dag_id=dag_id,
            dagrun_id=context.dagrun_id).to_json())
        return ExecutionContext(dagrun_id=result.body)

    def schedule_task(self, dag_id: str, task_id: str, action: SchedulingAction,
                      context: ExecutionContext) -> ExecutionContext:
        """Apply *action* (start/stop/restart...) to one task of the dag run
        in *context*."""
        id = self.generate_id(context.dagrun_id)
        result = self._request(id, ExecuteTaskMessage(
            dag_id=dag_id,
            task_id=task_id,
            dagrun_id=context.dagrun_id,
            action=action.value).to_json())
        return ExecutionContext(dagrun_id=result.body)
class ModelCenterService(model_center_service_pb2_grpc.ModelCenterServiceServicer):
    """gRPC servicer for the model center.

    Persists registered models / model versions through a SqlAlchemyStore
    and publishes model-version changes on the notification channel.
    """

    def __init__(self, store_uri, server_uri, notification_uri=None):
        self.model_repo_store = SqlAlchemyStore(store_uri)
        # Fall back to the model-center server itself when no dedicated
        # notification endpoint is configured.
        if notification_uri is None:
            self.notification_client = NotificationClient(server_uri)
        else:
            self.notification_client = NotificationClient(notification_uri)

    def _send_model_version_event(self, meta, event_type):
        # Single place for the notification payload: key is the model name,
        # value is the meta object's __dict__ serialized as JSON.
        self.notification_client.send_event(
            BaseEvent(meta.model_name, json.dumps(meta.__dict__), event_type))

    @catch_exception
    def createRegisteredModel(self, request, context):
        """Create a registered model from the request params."""
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.create_registered_model(
            registered_model_param.model_name,
            ModelType.Name(registered_model_param.model_type),
            registered_model_param.model_desc)
        return _wrap_response(registered_model_meta.to_meta_proto())

    @catch_exception
    def updateRegisteredModel(self, request, context):
        """Update a registered model; responds with None when the model
        does not exist."""
        model_meta_param = RegisteredModel.from_proto(request)
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.update_registered_model(
            RegisteredModel(model_meta_param.model_name),
            registered_model_param.model_name,
            ModelType.Name(registered_model_param.model_type),
            registered_model_param.model_desc)
        return _wrap_response(None if registered_model_meta is None
                              else registered_model_meta.to_meta_proto())

    @catch_exception
    def deleteRegisteredModel(self, request, context):
        """Delete a registered model; echoes the request's model_meta."""
        model_meta_param = RegisteredModel.from_proto(request)
        self.model_repo_store.delete_registered_model(
            RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(request.model_meta)

    @catch_exception
    def listRegisteredModels(self, request, context):
        """List all registered models as meta protos."""
        registered_models = self.model_repo_store.list_registered_models()
        return _wrap_response(RegisteredModelMetas(
            registered_models=[model.to_meta_proto()
                               for model in registered_models]))

    @catch_exception
    def getRegisteredModelDetail(self, request, context):
        """Get the detail view of one registered model, or None."""
        model_meta_param = ModelVersion.from_proto(request)
        registered_model_detail = \
            self.model_repo_store.get_registered_model_detail(
                RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(None if registered_model_detail is None
                              else registered_model_detail.to_detail_proto())

    @catch_exception
    def createModelVersion(self, request, context):
        """Create a model version and publish the matching stage event."""
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.create_model_version(
            model_meta_param.model_name,
            model_version_param.model_path,
            model_version_param.model_metric,
            model_version_param.model_flavor,
            model_version_param.version_desc,
            model_version_param.current_stage)
        event_type = MODEL_VERSION_TO_EVENT_TYPE.get(
            ModelVersionStage.from_string(model_version_param.current_stage))
        self._send_model_version_event(model_version_meta, event_type)
        return _wrap_response(model_version_meta.to_meta_proto())

    @catch_exception
    def updateModelVersion(self, request, context):
        """Update a model version; publishes a stage event only when a
        current_stage was supplied."""
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.update_model_version(
            model_meta_param,
            model_version_param.model_path,
            model_version_param.model_metric,
            model_version_param.model_flavor,
            model_version_param.version_desc,
            model_version_param.current_stage)
        if model_version_param.current_stage is not None:
            # NOTE(review): if the store returns None here (version missing)
            # this raises inside the event send, as the original did; the
            # @catch_exception decorator is presumed to handle it — confirm.
            event_type = MODEL_VERSION_TO_EVENT_TYPE.get(
                ModelVersionStage.from_string(
                    model_version_param.current_stage))
            self._send_model_version_event(model_version_meta, event_type)
        return _wrap_response(None if model_version_meta is None
                              else model_version_meta.to_meta_proto())

    @catch_exception
    def deleteModelVersion(self, request, context):
        """Delete a model version and publish a MODEL_DELETED event."""
        model_meta_param = ModelVersion.from_proto(request)
        self.model_repo_store.delete_model_version(model_meta_param)
        self._send_model_version_event(model_meta_param,
                                       ModelVersionEventType.MODEL_DELETED)
        return _wrap_response(request.model_meta)

    @catch_exception
    def getModelVersionDetail(self, request, context):
        """Get the meta of one model version, or None when absent."""
        model_meta_param = ModelVersion.from_proto(request)
        model_version_meta = self.model_repo_store.get_model_version_detail(
            model_meta_param)
        return _wrap_response(None if model_version_meta is None
                              else model_version_meta.to_meta_proto())