def run_task_function(client: NotificationClient):
    """Submit and start the test workflow, then unblock task_2 with an event.

    Sends the 'key_1' event exactly once — as soon as task_2 has a single
    TaskExecution row — and returns when the dag_run reaches a finished state.
    """
    with af.global_config_file(workflow_config_file()):
        with af.config('task_2'):
            executor_1 = af.user_define_operation(
                af.PythonObjectExecutor(SimpleExecutor()))
        with af.config('task_5'):
            executor_2 = af.user_define_operation(
                af.PythonObjectExecutor(SimpleExecutor()))
        af.user_define_control_dependency(
            src=executor_2, dependency=executor_1, namespace='test',
            event_key='key_1', event_value='value_1', sender='*')
        workflow_info = af.workflow_operation.submit_workflow(workflow_name)
    af.workflow_operation.start_new_workflow_execution(workflow_name)
    event_pending = True
    while True:
        with create_session() as session:
            executions = session.query(TaskExecution).filter(
                TaskExecution.dag_id == 'test_project.test_workflow',
                TaskExecution.task_id == 'task_2').all()
            if event_pending and 1 == len(executions):
                # Fire the control event only once.
                client.send_event(BaseEvent(key='key_1', value='value_1'))
                event_pending = False
            dag_run = session.query(DagRun).filter(
                DagRun.dag_id == 'test_project.test_workflow').first()
            if dag_run is not None and dag_run.state in State.finished:
                break
            else:
                time.sleep(1)
def run_trigger_task_function(self):
    # waiting parsed dag file done,
    time.sleep(5)
    ns_client = NotificationClient(
        server_uri="localhost:{}".format(self.port), default_namespace="a")
    client = EventSchedulerClient(ns_client=ns_client)
    execution_context = client.schedule_dag('trigger_task')
    while True:
        with create_session() as session:
            task1_rows = session.query(TaskExecution).filter(
                TaskExecution.dag_id == 'trigger_task',
                TaskExecution.task_id == 'task_1').all()
            if task1_rows:
                # task_1 ran; manually schedule task_2 and wait for it.
                client.schedule_task('trigger_task', 'task_2',
                                     SchedulingAction.START,
                                     execution_context)
                while True:
                    with create_session() as inner_session:
                        task2_rows = inner_session.query(TaskExecution).filter(
                            TaskExecution.dag_id == 'trigger_task',
                            TaskExecution.task_id == 'task_2').all()
                        if task2_rows:
                            break
                        else:
                            time.sleep(1)
                break
            else:
                time.sleep(1)
    ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())
def __init__(self, server_uri=_SERVER_URI, notification_service_uri=None):
    """Initialize every parent client against the proper URI.

    All sub-clients talk to ``server_uri``; the notification client uses
    ``notification_service_uri`` when one is given explicitly.
    """
    MetadataClient.__init__(self, server_uri)
    ModelCenterClient.__init__(self, server_uri)
    DeployClient.__init__(self, server_uri)
    MetricClient.__init__(self, server_uri)
    notification_uri = (server_uri if notification_service_uri is None
                        else notification_service_uri)
    NotificationClient.__init__(self, notification_uri)
def _send_task_status_change_event(self):
    """Publish a TaskStateChangedEvent for this job's task instance."""
    event = TaskStateChangedEvent(
        self.task_instance.task_id,
        self.task_instance.dag_id,
        self.task_instance.execution_date,
        self.task_instance.state).to_event()
    # Namespace and sender come from the converted event itself.
    client = NotificationClient(self.server_uri,
                                default_namespace=event.namespace,
                                sender=event.sender)
    self.log.info("LocalTaskJob sending event: {}".format(event))
    client.send_event(event)
def run_test_fun():
    """Run ``test_function`` against a local notification server, always
    sending a StopSchedulerEvent afterwards so the scheduler under test exits.
    """
    time.sleep(3)
    client = NotificationClient(
        server_uri="localhost:{}".format(server_port()),
        default_namespace="test")
    # The original `except Exception as e: raise e` was a no-op re-raise that
    # only rewrote the traceback origin; try/finally alone gives the same
    # propagation with a clean traceback.
    try:
        test_function(client)
    finally:
        client.send_event(StopSchedulerEvent(job_id=0).to_event())
def run_no_dag_file_function(self):
    ns_client = NotificationClient(
        server_uri="localhost:{}".format(self.port), default_namespace="")
    client = EventSchedulerClient(ns_client=ns_client)
    with create_session() as session:
        # Schedule a dag id that has no dag file behind it.
        client.trigger_parse_dag()
        result = client.schedule_dag('no_dag')
        print('result {}'.format(result.dagrun_id))
        time.sleep(5)
    ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())
def execute(self, function_context: FunctionContext, input_list: List) -> List:
    """Send this executor's configured event and return an empty result list."""
    from notification_service.client import NotificationClient
    sender_client = NotificationClient(
        server_uri="localhost:{}".format(self.port),
        default_namespace="default",
        sender=self.sender)
    payload = BaseEvent(key=self.key,
                        value=self.value,
                        event_type=self.event_type)
    sender_client.send_event(payload)
    return []
def __init__(self, store_uri, notification_uri=None):
    """Create the model-repo store (Mongo or SQLAlchemy) and optional notifier."""
    engine = extract_db_engine_from_uri(store_uri)
    if DBType.value_of(engine) == DBType.MONGODB:
        username, password, host, port, db = parse_mongo_uri(store_uri)
        self.model_repo_store = MongoStore(host=host,
                                           port=int(port),
                                           username=username,
                                           password=password,
                                           db=db)
    else:
        self.model_repo_store = SqlAlchemyStore(store_uri)
    # Event notification is optional; stays None when no URI is configured.
    if notification_uri is not None:
        self.notification_client = NotificationClient(
            notification_uri, default_namespace=DEFAULT_NAMESPACE)
    else:
        self.notification_client = None
def run_ai_flow_function(self):
    client = NotificationClient(
        server_uri="localhost:{}".format(self.port),
        default_namespace="default",
        sender='1-job-name')
    while True:
        with create_session() as session:
            job_rows = session.query(TaskExecution).filter(
                TaskExecution.dag_id == 'workflow_1',
                TaskExecution.task_id == '1-job-name').all()
            if job_rows:
                time.sleep(5)
                # Unblock the two downstream jobs of workflow_1.
                client.send_event(BaseEvent(key='key_1', value='value_1',
                                            event_type='UNDEFINED'))
                client.send_event(BaseEvent(key='key_2', value='value_2',
                                            event_type='UNDEFINED'))
                while True:
                    with create_session() as inner_session:
                        all_rows = inner_session.query(TaskExecution).filter(
                            TaskExecution.dag_id == 'workflow_1').all()
                        if len(all_rows) == 3:
                            break
                        else:
                            time.sleep(1)
                break
            else:
                time.sleep(1)
    time.sleep(3)
    client.send_event(StopSchedulerEvent(job_id=0).to_event())
def __init__(self, dag_directory,
             server_uri=None,
             max_runs=-1,
             refresh_dag_dir_interval=conf.getint(
                 'scheduler', 'refresh_dag_dir_interval', fallback=30),
             *args, **kwargs):
    """Wire up mailbox, DAG trigger, event manager, executor and scheduler.

    :param dag_directory: directory scanned for DAG files by the DagTrigger.
    :param server_uri: notification server URI, shared by the DagTrigger and
        the NotificationClient created below.
    :param max_runs: run limit forwarded to the DagTrigger (-1 presumably
        means unlimited — TODO confirm against DagTrigger).
    :param refresh_dag_dir_interval: seconds between DAG-dir rescans, read
        from the [scheduler] config section with a fallback of 30.
    """
    super().__init__(*args, **kwargs)
    # Central message queue shared by every component created below.
    self.mailbox: Mailbox = Mailbox()
    self.dag_trigger: DagTrigger = DagTrigger(
        dag_directory=dag_directory,
        max_runs=max_runs,
        dag_ids=None,
        pickle_dags=False,
        mailbox=self.mailbox,
        refresh_dag_dir_interval=refresh_dag_dir_interval,
        notification_service_uri=server_uri)
    self.task_event_manager = DagRunEventManager(self.mailbox)
    # The executor reports task events through the same mailbox.
    self.executor.set_mailbox(self.mailbox)
    self.notification_client: NotificationClient = NotificationClient(
        server_uri=server_uri, default_namespace=SCHEDULER_NAMESPACE)
    self.scheduler: EventBasedScheduler = EventBasedScheduler(
        self.id,
        self.mailbox,
        self.task_event_manager,
        self.executor,
        self.notification_client)
    self.last_scheduling_id = self._last_scheduler_job_id()
def setUp(self):
    """Reset metadata DB and start a fresh local notification master."""
    for clear in (db.clear_db_jobs, db.clear_db_dags,
                  db.clear_db_serialized_dags, db.clear_db_runs,
                  db.clear_db_task_execution, db.clear_db_message):
        clear()
    self.scheduler = None
    self.port = 50102
    self.storage = MemoryEventStorage()
    self.master = NotificationMaster(NotificationService(self.storage),
                                     self.port)
    self.master.run()
    self.client = NotificationClient(
        server_uri="localhost:{}".format(self.port),
        default_namespace="test_namespace")
    # Give the master a moment to come up before tests use the client.
    time.sleep(1)
def stop_workflow(self, workflow_name) -> bool:
    """
    Stop the workflow. No more workflow execution(Airflow dag_run) would be
    scheduled and all running jobs would be stopped.
    :param workflow_name: workflow name
    :return: True if succeed
    """
    # TODO For now, simply return True as long as message is sent
    #  successfully; actually we need a response from the scheduler — confirm.
    try:
        NotificationClient(self.server_uri, SCHEDULER_NAMESPACE).send_event(
            StopDagEvent(workflow_name).to_event())
        return True
    except Exception:
        return False
def run_trigger_dag_function(self):
    ns_client = NotificationClient(
        server_uri="localhost:{}".format(self.port), default_namespace="")
    client = EventSchedulerClient(ns_client=ns_client)
    while True:
        with create_session() as session:
            rows = session.query(TaskExecution).filter(
                TaskExecution.dag_id == 'trigger_dag',
                TaskExecution.task_id == 'task_1').all()
            if rows:
                break
            else:
                # Not scheduled yet: (re)trigger parsing and scheduling.
                client.trigger_parse_dag()
                result = client.schedule_dag('trigger_dag')
                print('result {}'.format(result.dagrun_id))
                time.sleep(5)
    ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())
def setUpClass(cls):
    """Start a notification master backed by a local MongoDB store."""
    mongo_conf = {
        "host": "127.0.0.1",
        "port": 27017,
        "db": "test",
    }
    cls.storage = MongoEventStorage(**mongo_conf)
    cls.master = NotificationMaster(NotificationService(cls.storage))
    cls.master.run()
    cls.client = NotificationClient(server_uri="localhost:50051")
def _execute(self):
    """
    1. Init the DagRun route.
    2. Start the executor.
    3. Option of start the notification master.
    4. Create the notification client.
    5. Start the DagTrigger.
    6. Run the scheduler event loop.
    :return:
    """
    notification_client = None
    try:
        self._init_route()
        self.executor.set_use_nf(True)
        self.executor.start()
        self.dag_trigger = DagTrigger(
            subdir=self.subdir,
            mailbox=self.mail_box,
            run_duration=self.run_duration,
            using_sqlite=self.using_sqlite,
            num_runs=self.num_runs,
            processor_poll_interval=self._processor_poll_interval)
        if self.use_local_nf:
            # Embedded notification master, listening on self.nf_port.
            self.notification_master \
                = NotificationMaster(service=NotificationService(EventModelStorage()),
                                     port=self.nf_port)
            self.notification_master.run()
            self.log.info("start notification service {0}".format(
                self.nf_port))
            notification_client = NotificationClient(
                server_uri="localhost:{0}".format(self.nf_port))
        else:
            # External notification service at nf_host:nf_port.
            notification_client \
                = NotificationClient(server_uri="{0}:{1}".format(self.nf_host, self.nf_port))
        # All scheduler events are routed into the mailbox via the watcher.
        notification_client.start_listen_events(
            watcher=SCEventWatcher(self.mail_box))
        self.dag_trigger.start()
        self._start_executor_heartbeat()
        self._run_event_loop()
    except Exception as e:
        self.log.exception("Exception when executing _execute {0}".format(
            str(e)))
    finally:
        self.running = False
        # NOTE(review): "heartheat" typo matches the method's definition
        # elsewhere in the file — do not rename here alone.
        self._stop_executor_heartheat()
        if self.dag_trigger is not None:
            self.dag_trigger.stop()
        if notification_client is not None:
            notification_client.stop_listen_events()
        if self.notification_master is not None:
            self.notification_master.stop()
        self.executor.end()
        self.log.info("Exited execute event scheduler")
def wait_for_master_started(cls, server_uri="localhost:50051"):
    """Poll until an HA-enabled client can connect; give up after 100 tries."""
    last_exception = None
    for _ in range(100):
        try:
            return NotificationClient(server_uri=server_uri, enable_ha=True)
        except Exception as e:
            time.sleep(10)
            last_exception = e
    raise Exception("The server %s is unavailable." % server_uri) \
        from last_exception
def run_airflow_dag_function(self):
    # waiting parsed dag file done
    from datetime import datetime
    ns_client = NotificationClient(server_uri='localhost:50051')
    with af.global_config_file(test_util.get_workflow_config_file()):
        with af.config('task_1'):
            cmd_executor = af.user_define_operation(
                output_num=0,
                executor=CmdExecutor(cmd_line=['echo "hello world!"']))
        af.deploy_to_airflow(test_util.get_project_path(),
                             dag_id='test_dag_111',
                             default_args={
                                 'schedule_interval': None,
                                 'start_date': datetime(2025, 12, 1),
                             })
        context = af.run(project_path=test_util.get_project_path(),
                         dag_id='test_dag_111',
                         scheduler_type=SchedulerType.AIRFLOW)
    print(context.dagrun_id)
    time.sleep(5)
    ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())
def run_event_task_function(self):
    client = NotificationClient(
        server_uri="localhost:{}".format(self.port), default_namespace="")
    while True:
        with create_session() as session:
            task1_rows = session.query(TaskExecution).filter(
                TaskExecution.dag_id == 'event_dag',
                TaskExecution.task_id == 'task_1').all()
            if task1_rows:
                time.sleep(5)
                # Fire the 'start' event task_2 is waiting for.
                client.send_event(BaseEvent(key='start', value='',
                                            event_type='', namespace=''))
                while True:
                    with create_session() as inner_session:
                        task2_rows = inner_session.query(TaskExecution).filter(
                            TaskExecution.dag_id == 'event_dag',
                            TaskExecution.task_id == 'task_2').all()
                        if task2_rows:
                            break
                        else:
                            time.sleep(1)
                break
            else:
                time.sleep(1)
    client.send_event(StopSchedulerEvent(job_id=0).to_event())
class MemoryStorageTest(unittest.TestCase, NotificationTest):
    """Runs the shared NotificationTest suite against an in-memory store."""

    @classmethod
    def set_up_class(cls):
        cls.storage = MemoryEventStorage()
        cls.master = NotificationMaster(NotificationService(cls.storage))
        cls.master.run()

    @classmethod
    def setUpClass(cls):
        cls.set_up_class()

    @classmethod
    def tearDownClass(cls):
        cls.master.stop()

    def setUp(self):
        # Fresh storage and client per test.
        self.storage.clean_up()
        self.client = NotificationClient(server_uri="localhost:50051")

    def tearDown(self):
        self.client.stop_listen_events()
        self.client.stop_listen_event()
def test_send_listening_on_different_server(self):
    """Events sent via a second server are received by the first client."""
    received = []

    class CollectingWatcher(EventWatcher):
        def __init__(self, sink) -> None:
            super().__init__()
            self.sink = sink

        def process(self, events: List[BaseEvent]):
            self.sink.extend(events)

    self.master2 = self.start_master("localhost", "50052")
    self.wait_for_new_members_detected("localhost:50052")
    another_client = NotificationClient(server_uri="localhost:50052")
    try:
        first = another_client.send_event(
            BaseEvent(key="key1", value="value1"))
        # Listen from just after the first event's version.
        self.client.start_listen_events(watcher=CollectingWatcher(received),
                                        version=first.version)
        another_client.send_event(BaseEvent(key="key2", value="value2"))
        another_client.send_event(BaseEvent(key="key3", value="value3"))
    finally:
        self.client.stop_listen_events()
    self.assertEqual(2, len(received))
def change_state(self, key, state):
    """Record a task state change, forwarding it to the notification
    service when use_nf is set, otherwise buffering it locally.

    :param key: (dag_id, task_id, execution_date, try_number) tuple.
    :param state: new task state value.
    """
    self.log.debug("Changing state: %s %s", key, state)
    self.running.pop(key, None)
    if self.use_nf:
        # Lazily create the client on first use.
        if self.client is None:
            self.client: NotificationClient = NotificationClient(
                server_uri="{0}:{1}".format(self.nf_host, self.nf_port))
        dag_id, task_id, execution_date, try_number = key
        self.client.send_event(TaskStatusEvent(
            task_instance_key=TaskInstanceHelper.to_task_key(
                dag_id, task_id, execution_date),
            status=TaskInstanceHelper.to_event_value(state, try_number)))
    else:
        # Legacy path: buffer state for get_event_buffer() consumers.
        self.event_buffer[key] = state
class BaseExecutorTest(unittest.TestCase):
    # Tests run against a real local NotificationMaster on port 50051.

    def setUp(self):
        clear_db_event_model()
        self.master = NotificationMaster(
            service=NotificationService(EventModelStorage()))
        self.master.run()
        self.client = NotificationClient(server_uri="localhost:50051")

    def tearDown(self) -> None:
        self.master.stop()

    def test_get_event_buffer(self):
        executor = BaseExecutor()
        date = datetime.utcnow()
        try_number = 1
        key1 = ("my_dag1", "my_task1", date, try_number)
        key2 = ("my_dag2", "my_task1", date, try_number)
        key3 = ("my_dag2", "my_task2", date, try_number)
        state = State.SUCCESS
        executor.event_buffer[key1] = state
        executor.event_buffer[key2] = state
        executor.event_buffer[key3] = state
        # Filtering by dag_id returns only that DAG's entries and drains them.
        self.assertEqual(len(executor.get_event_buffer(("my_dag1",))), 1)
        self.assertEqual(len(executor.get_event_buffer()), 2)
        # After both calls the buffer is empty.
        self.assertEqual(len(executor.event_buffer), 0)

    @mock.patch('airflow.executors.base_executor.BaseExecutor.sync')
    @mock.patch('airflow.executors.base_executor.BaseExecutor.trigger_tasks')
    @mock.patch('airflow.settings.Stats.gauge')
    def test_gauge_executor_metrics(self, mock_stats_gauge,
                                    mock_trigger_tasks, mock_sync):
        # heartbeat() must emit the three executor gauges.
        executor = BaseExecutor()
        executor.heartbeat()
        calls = [mock.call('executor.open_slots', mock.ANY),
                 mock.call('executor.queued_tasks', mock.ANY),
                 mock.call('executor.running_tasks', mock.ANY)]
        mock_stats_gauge.assert_has_calls(calls)

    def test_use_nf_executor(self):
        # With use_nf set, state changes go to the notification server.
        executor = BaseExecutor()
        executor.set_use_nf(True)
        executor.change_state('key', State.RUNNING)
        executor.change_state('key', State.SUCCESS)
        events = self.client.list_all_events(1)
        self.assertEqual(2, len(events))
def _send_request_and_receive_response(self, server_uri, file_path):
    """Send a PARSE_DAG request event and assert the matching response arrives.

    :param server_uri: notification server to talk to.
    :param file_path: DAG file path used as the event value.
    """
    # Unique key ties the response event to this specific request.
    key = '{}_{}'.format(file_path, time.time_ns())
    client = NotificationClient(server_uri=server_uri,
                                default_namespace=SCHEDULER_NAMESPACE)
    event = BaseEvent(
        key=key,
        event_type=SchedulerInnerEventType.PARSE_DAG_REQUEST.value,
        value=file_path)
    client.send_event(event)
    watcher: ResponseWatcher = ResponseWatcher()
    client.start_listen_event(
        key=key,
        event_type=SchedulerInnerEventType.PARSE_DAG_RESPONSE.value,
        watcher=watcher)
    res: BaseEvent = watcher.get_result()
    # assertEquals is a deprecated alias removed in Python 3.12; use assertEqual.
    self.assertEqual(event.key, res.key)
    self.assertEqual(event.value, file_path)
class HaDbStorageTest(unittest.TestCase, NotificationTest):
    """
    This test is used to ensure the high availability would not break the
    original functionality.
    """

    @classmethod
    def set_up_class(cls):
        cls.storage = DbEventStorage()
        cls.master1 = start_ha_master("localhost", 50051)
        # The server startup is asynchronous, we need to wait for a while
        # to ensure it writes its metadata to the db.
        time.sleep(0.1)
        cls.master2 = start_ha_master("localhost", 50052)
        time.sleep(0.1)
        cls.master3 = start_ha_master("localhost", 50053)
        time.sleep(0.1)

    @classmethod
    def setUpClass(cls):
        cls.set_up_class()

    @classmethod
    def tearDownClass(cls):
        cls.master1.stop()
        cls.master2.stop()
        cls.master3.stop()

    def setUp(self):
        # Fresh storage and HA client per test; connects to the second master.
        self.storage.clean_up()
        self.client = NotificationClient(server_uri="localhost:50052",
                                         enable_ha=True,
                                         list_member_interval_ms=1000,
                                         retry_timeout_ms=10000)

    def tearDown(self):
        self.client.stop_listen_events()
        self.client.stop_listen_event()
        self.client.disable_high_availability()
class ModelCenterService(model_center_service_pb2_grpc.ModelCenterServiceServicer):
    """gRPC servicer for the model center, backed by a model-repo store and
    optionally publishing model-version change events via a NotificationClient."""

    def __init__(self, store_uri, notification_uri=None):
        # Pick the backing store from the URI's database engine.
        db_engine = extract_db_engine_from_uri(store_uri)
        if DBType.value_of(db_engine) == DBType.MONGODB:
            username, password, host, port, db = parse_mongo_uri(store_uri)
            self.model_repo_store = MongoStore(host=host,
                                               port=int(port),
                                               username=username,
                                               password=password,
                                               db=db)
        else:
            self.model_repo_store = SqlAlchemyStore(store_uri)
        # Notification is optional; stays None when no URI is configured.
        self.notification_client = None
        if notification_uri is not None:
            self.notification_client = NotificationClient(
                notification_uri, default_namespace=DEFAULT_NAMESPACE)

    @catch_exception
    def createRegisteredModel(self, request, context):
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.create_registered_model(
            registered_model_param.model_name,
            registered_model_param.model_desc)
        return _wrap_response(registered_model_meta.to_meta_proto())

    @catch_exception
    def updateRegisteredModel(self, request, context):
        model_meta_param = RegisteredModel.from_proto(request)
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.update_registered_model(
            RegisteredModel(model_meta_param.model_name),
            registered_model_param.model_name,
            registered_model_param.model_desc)
        return _wrap_response(None if registered_model_meta is None
                              else registered_model_meta.to_meta_proto())

    @catch_exception
    def deleteRegisteredModel(self, request, context):
        model_meta_param = RegisteredModel.from_proto(request)
        self.model_repo_store.delete_registered_model(
            RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(request.model_meta)

    @catch_exception
    def listRegisteredModels(self, request, context):
        registered_models = self.model_repo_store.list_registered_models()
        return _wrap_response(RegisteredModelMetas(
            registered_models=[registered_model.to_meta_proto()
                               for registered_model in registered_models]))

    @catch_exception
    def getRegisteredModelDetail(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        registered_model_detail = self.model_repo_store.get_registered_model_detail(
            RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(None if registered_model_detail is None
                              else registered_model_detail.to_detail_proto())

    @catch_exception
    def createModelVersion(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.create_model_version(
            model_meta_param.model_name,
            model_version_param.model_path,
            model_version_param.model_type,
            model_version_param.version_desc,
            model_version_param.current_stage)
        # Map the new stage to its notification event type.
        event_type = MODEL_VERSION_TO_EVENT_TYPE.get(
            ModelVersionStage.from_string(model_version_param.current_stage))
        if self.notification_client is not None:
            self.notification_client.send_event(
                BaseEvent(model_version_meta.model_name,
                          json.dumps(model_version_meta.__dict__),
                          event_type))
        return _wrap_response(model_version_meta.to_meta_proto())

    @catch_exception
    def updateModelVersion(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.update_model_version(
            model_meta_param,
            model_version_param.model_path,
            model_version_param.model_type,
            model_version_param.version_desc,
            model_version_param.current_stage)
        # Only stage transitions produce a notification event.
        if model_version_param.current_stage is not None:
            event_type = MODEL_VERSION_TO_EVENT_TYPE.get(
                ModelVersionStage.from_string(model_version_param.current_stage))
            if self.notification_client is not None:
                self.notification_client.send_event(
                    BaseEvent(model_version_meta.model_name,
                              json.dumps(model_version_meta.__dict__),
                              event_type))
        return _wrap_response(None if model_version_meta is None
                              else model_version_meta.to_meta_proto())

    @catch_exception
    def deleteModelVersion(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        self.model_repo_store.delete_model_version(model_meta_param)
        if self.notification_client is not None:
            self.notification_client.send_event(
                BaseEvent(model_meta_param.model_name,
                          json.dumps(model_meta_param.__dict__),
                          ModelVersionEventType.MODEL_DELETED))
        return _wrap_response(request.model_meta)

    @catch_exception
    def getModelVersionDetail(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        model_version_meta = self.model_repo_store.get_model_version_detail(
            model_meta_param)
        return _wrap_response(None if model_version_meta is None
                              else model_version_meta.to_meta_proto())
def execute(self, context):
    """Publish this operator's configured event to the notification server."""
    NotificationClient(server_uri=self.uri).send_event(event=self.event)
def setUp(self):
    """Clean the store and create a fresh HA client against the second master."""
    self.storage.clean_up()
    self.client = NotificationClient(
        server_uri="localhost:50052",
        enable_ha=True,
        list_member_interval_ms=1000,
        retry_timeout_ms=10000)
class TestEventBasedScheduler(unittest.TestCase):
    """End-to-end tests for the EventBasedSchedulerJob.

    Each test starts a driver thread/process that watches TaskExecution rows,
    sends notification events to move the DAG along, and finally sends a
    StopSchedulerEvent so start_scheduler() returns.
    """

    def setUp(self):
        # Start each test from a clean metadata database.
        db.clear_db_jobs()
        db.clear_db_dags()
        db.clear_db_serialized_dags()
        db.clear_db_runs()
        db.clear_db_task_execution()
        db.clear_db_message()
        self.scheduler = None
        self.port = 50102
        self.storage = MemoryEventStorage()
        self.master = NotificationMaster(NotificationService(self.storage),
                                         self.port)
        self.master.run()
        self.client = NotificationClient(server_uri="localhost:{}".format(
            self.port), default_namespace="test_namespace")
        time.sleep(1)

    def tearDown(self):
        self.master.stop()

    def _get_task_instance(self, dag_id, task_id, session):
        return session.query(TaskInstance).filter(
            TaskInstance.dag_id == dag_id,
            TaskInstance.task_id == task_id).first()

    def schedule_task_function(self):
        # Drives test_event_based_scheduler: start both tasks, then stop one
        # and restart the other, until sleep_1000_secs is KILLED.
        stopped = False
        while not stopped:
            with create_session() as session:
                ti_sleep_1000_secs = self._get_task_instance(
                    EVENT_BASED_SCHEDULER_DAG, 'sleep_1000_secs', session)
                ti_python_sleep = self._get_task_instance(
                    EVENT_BASED_SCHEDULER_DAG, 'python_sleep', session)
                if ti_sleep_1000_secs and ti_sleep_1000_secs.state == State.SCHEDULED and \
                        ti_python_sleep and ti_python_sleep.state == State.SCHEDULED:
                    self.client.send_event(
                        BaseEvent(key='start', value='',
                                  event_type='', namespace='test_namespace'))
                    while not stopped:
                        ti_sleep_1000_secs.refresh_from_db()
                        ti_python_sleep.refresh_from_db()
                        if ti_sleep_1000_secs and ti_sleep_1000_secs.state == State.RUNNING and \
                                ti_python_sleep and ti_python_sleep.state == State.RUNNING:
                            time.sleep(10)
                            break
                        else:
                            time.sleep(1)
                    self.client.send_event(
                        BaseEvent(key='stop', value='',
                                  event_type=UNDEFINED_EVENT_TYPE,
                                  namespace='test_namespace'))
                    self.client.send_event(
                        BaseEvent(key='restart', value='',
                                  event_type=UNDEFINED_EVENT_TYPE,
                                  namespace='test_namespace'))
                    while not stopped:
                        ti_sleep_1000_secs.refresh_from_db()
                        ti_python_sleep.refresh_from_db()
                        if ti_sleep_1000_secs and ti_sleep_1000_secs.state == State.KILLED and \
                                ti_python_sleep and ti_python_sleep.state == State.RUNNING:
                            stopped = True
                        else:
                            time.sleep(1)
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_event_based_scheduler(self):
        t = threading.Thread(target=self.schedule_task_function)
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_event_based_scheduler.py')

    def test_replay_message(self):
        # A message delivered to the mailbox must also be persisted and
        # retrievable as unprocessed for the owning scheduling job id.
        key = "stop"
        mailbox = Mailbox()
        mailbox.set_scheduling_job_id(1234)
        watcher = SchedulerEventWatcher(mailbox)
        self.client.start_listen_events(
            watcher=watcher,
            start_time=int(time.time() * 1000),
            version=None)
        self.send_event(key)
        msg: BaseEvent = mailbox.get_message()
        self.assertEqual(msg.key, key)
        with create_session() as session:
            msg_from_db = session.query(Message).first()
            expect_non_unprocessed = EventBasedScheduler.get_unprocessed_message(
                1000)
            self.assertEqual(0, len(expect_non_unprocessed))
            unprocessed = EventBasedScheduler.get_unprocessed_message(1234)
            self.assertEqual(unprocessed[0].serialized_message,
                             msg_from_db.data)
            deserialized_data = pickle.loads(msg_from_db.data)
            self.assertEqual(deserialized_data.key, key)
            self.assertEqual(msg, deserialized_data)

    def send_event(self, key):
        event = self.client.send_event(
            BaseEvent(key=key, event_type=UNDEFINED_EVENT_TYPE,
                      value="value1"))
        self.assertEqual(key, event.key)

    @provide_session
    def get_task_execution(self, dag_id, task_id, session):
        return session.query(TaskExecution).filter(
            TaskExecution.dag_id == dag_id,
            TaskExecution.task_id == task_id).all()

    @provide_session
    def get_latest_job_id(self, session):
        return session.query(BaseJob).order_by(sqlalchemy.desc(
            BaseJob.id)).first().id

    def start_scheduler(self, file_path):
        # Blocks until the driver sends StopSchedulerEvent.
        self.scheduler = EventBasedSchedulerJob(
            dag_directory=file_path,
            server_uri="localhost:{}".format(self.port),
            executor=LocalExecutor(3),
            max_runs=-1,
            refresh_dag_dir_interval=30)
        print("scheduler starting")
        self.scheduler.run()

    def wait_for_running(self):
        while True:
            if self.scheduler is not None:
                time.sleep(5)
                break
            else:
                time.sleep(1)

    def wait_for_task_execution(self, dag_id, task_id, expected_num):
        # Poll up to 200 seconds for the expected TaskExecution count.
        result = False
        check_nums = 100
        while check_nums > 0:
            time.sleep(2)
            check_nums = check_nums - 1
            tes = self.get_task_execution(dag_id, task_id)
            if len(tes) == expected_num:
                result = True
                break
        self.assertTrue(result)

    def wait_for_task(self, dag_id, task_id, expected_state):
        # Poll up to 200 seconds for the task instance to reach the state.
        result = False
        check_nums = 100
        while check_nums > 0:
            time.sleep(2)
            check_nums = check_nums - 1
            with create_session() as session:
                ti = session.query(TaskInstance).filter(
                    TaskInstance.dag_id == dag_id,
                    TaskInstance.task_id == task_id).first()
                if ti and ti.state == expected_state:
                    result = True
                    break
        self.assertTrue(result)

    def test_notification(self):
        self.client.send_event(BaseEvent(key='a', value='b'))

    def run_a_task_function(self):
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'single',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    break
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_a_task(self):
        t = threading.Thread(target=self.run_a_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_single_task_dag.py')
        tes: List[TaskExecution] = self.get_task_execution("single", "task_1")
        self.assertEqual(len(tes), 1)

    def run_event_task_function(self):
        client = NotificationClient(server_uri="localhost:{}".format(
            self.port), default_namespace="")
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'event_dag',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    time.sleep(5)
                    # Fire the event task_2 waits on.
                    client.send_event(
                        BaseEvent(key='start', value='',
                                  event_type='', namespace=''))
                    while True:
                        with create_session() as session_2:
                            tes_2 = session_2.query(TaskExecution).filter(
                                TaskExecution.dag_id == 'event_dag',
                                TaskExecution.task_id == 'task_2').all()
                            if len(tes_2) > 0:
                                break
                            else:
                                time.sleep(1)
                    break
                else:
                    time.sleep(1)
        client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_event_task(self):
        t = threading.Thread(target=self.run_event_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_event_task_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "event_dag", "task_2")
        self.assertEqual(len(tes), 1)

    def run_trigger_dag_function(self):
        ns_client = NotificationClient(server_uri="localhost:{}".format(
            self.port), default_namespace="")
        client = EventSchedulerClient(ns_client=ns_client)
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'trigger_dag',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    break
                else:
                    client.trigger_parse_dag()
                    result = client.schedule_dag('trigger_dag')
                    print('result {}'.format(result.dagrun_id))
                    time.sleep(5)
        ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_trigger_dag(self):
        import multiprocessing
        p = multiprocessing.Process(target=self.run_trigger_dag_function,
                                    args=())
        p.start()
        self.start_scheduler('../../dags/test_run_trigger_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "trigger_dag", "task_1")
        self.assertEqual(len(tes), 1)

    def run_no_dag_file_function(self):
        ns_client = NotificationClient(server_uri="localhost:{}".format(
            self.port), default_namespace="")
        client = EventSchedulerClient(ns_client=ns_client)
        with create_session() as session:
            # Schedule a dag id with no dag file behind it.
            client.trigger_parse_dag()
            result = client.schedule_dag('no_dag')
            print('result {}'.format(result.dagrun_id))
            time.sleep(5)
        ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_no_dag_file_trigger_dag(self):
        import multiprocessing
        p = multiprocessing.Process(target=self.run_no_dag_file_function,
                                    args=())
        p.start()
        self.start_scheduler('../../dags/test_run_trigger_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "trigger_dag", "task_1")
        self.assertEqual(len(tes), 0)

    def run_trigger_task_function(self):
        # waiting parsed dag file done,
        time.sleep(5)
        ns_client = NotificationClient(server_uri="localhost:{}".format(
            self.port), default_namespace="a")
        client = EventSchedulerClient(ns_client=ns_client)
        execution_context = client.schedule_dag('trigger_task')
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'trigger_task',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    # task_1 ran; manually start task_2 and wait for it.
                    client.schedule_task('trigger_task', 'task_2',
                                         SchedulingAction.START,
                                         execution_context)
                    while True:
                        with create_session() as session_2:
                            tes_2 = session_2.query(TaskExecution).filter(
                                TaskExecution.dag_id == 'trigger_task',
                                TaskExecution.task_id == 'task_2').all()
                            if len(tes_2) > 0:
                                break
                            else:
                                time.sleep(1)
                    break
                else:
                    time.sleep(1)
        ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_task_trigger_dag(self):
        import threading
        t = threading.Thread(target=self.run_trigger_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_task_trigger_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "trigger_task", "task_2")
        self.assertEqual(len(tes), 1)

    def run_ai_flow_function(self):
        client = NotificationClient(server_uri="localhost:{}".format(
            self.port), default_namespace="default", sender='1-job-name')
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'workflow_1',
                    TaskExecution.task_id == '1-job-name').all()
                if len(tes) > 0:
                    time.sleep(5)
                    # Unblock the two downstream jobs of workflow_1.
                    client.send_event(
                        BaseEvent(key='key_1', value='value_1',
                                  event_type='UNDEFINED'))
                    client.send_event(
                        BaseEvent(key='key_2', value='value_2',
                                  event_type='UNDEFINED'))
                    while True:
                        with create_session() as session_2:
                            tes_2 = session_2.query(TaskExecution).filter(
                                TaskExecution.dag_id == 'workflow_1').all()
                            if len(tes_2) == 3:
                                break
                            else:
                                time.sleep(1)
                    break
                else:
                    time.sleep(1)
        time.sleep(3)
        client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_ai_flow_dag(self):
        import threading
        t = threading.Thread(target=self.run_ai_flow_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_aiflow_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "workflow_1", "1-job-name")
        self.assertEqual(len(tes), 1)

    def stop_dag_function(self):
        stopped = False
        while not stopped:
            tes = self.get_task_execution(EVENT_BASED_SCHEDULER_DAG,
                                          'sleep_to_be_stopped')
            if tes and len(tes) == 1:
                self.client.send_event(
                    StopDagEvent(EVENT_BASED_SCHEDULER_DAG).to_event())
                while not stopped:
                    tes2 = self.get_task_execution(EVENT_BASED_SCHEDULER_DAG,
                                                   'sleep_to_be_stopped')
                    if tes2[0].state == State.KILLED:
                        stopped = True
                        time.sleep(5)
                    else:
                        time.sleep(1)
            else:
                time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_stop_dag(self):
        # After StopDagEvent: dag is paused, the run is killed and every
        # task instance/execution ends SUCCESS or KILLED.
        t = threading.Thread(target=self.stop_dag_function)
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_event_based_scheduler.py')
        with create_session() as session:
            from airflow.models import DagModel
            dag_model: DagModel = DagModel.get_dagmodel(
                EVENT_BASED_SCHEDULER_DAG)
            self.assertTrue(dag_model.is_paused)
            self.assertEqual(dag_model.get_last_dagrun().state, "killed")
            for ti in session.query(TaskInstance).filter(
                    TaskInstance.dag_id == EVENT_BASED_SCHEDULER_DAG):
                self.assertTrue(ti.state in [State.SUCCESS, State.KILLED])
            for te in session.query(TaskExecution).filter(
                    TaskExecution.dag_id == EVENT_BASED_SCHEDULER_DAG):
                self.assertTrue(te.state in [State.SUCCESS, State.KILLED])

    def run_periodic_task_function(self):
        # Wait until the periodic task produced more than one execution.
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'single',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 1:
                    break
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_periodic_task(self):
        t = threading.Thread(target=self.run_periodic_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_periodic_task_dag.py')
        tes: List[TaskExecution] = self.get_task_execution("single", "task_1")
        self.assertGreater(len(tes), 1)

    def run_one_task_function(self):
        self.wait_for_running()
        # Two identical events should yield two task executions.
        self.client.send_event(BaseEvent(key='a', value='a'))
        time.sleep(5)
        self.client.send_event(BaseEvent(key='a', value='a'))
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'single',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) >= 2:
                    break
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_one_task(self):
        t = threading.Thread(target=self.run_one_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_multiple_trigger_task_dag.py')
class ModelCenterService(model_center_service_pb2_grpc.ModelCenterServiceServicer):
    """gRPC servicer backing the model center.

    Persists registered models and model versions through a
    :class:`SqlAlchemyStore` and publishes model-version lifecycle events
    through a :class:`NotificationClient`.
    """

    def __init__(self, store_uri, server_uri, notification_uri=None):
        """
        :param store_uri: database URI for the SqlAlchemyStore backend.
        :param server_uri: notification server URI used when no dedicated
                           ``notification_uri`` is given.
        :param notification_uri: optional separate URI for the notification
                                 service; falls back to ``server_uri``.
        """
        self.model_repo_store = SqlAlchemyStore(store_uri)
        if notification_uri is None:
            self.notification_client = NotificationClient(server_uri)
        else:
            self.notification_client = NotificationClient(notification_uri)

    @catch_exception
    def createRegisteredModel(self, request, context):
        """Create a registered model from the request and return its meta proto."""
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.create_registered_model(
            registered_model_param.model_name,
            ModelType.Name(registered_model_param.model_type),
            registered_model_param.model_desc)
        return _wrap_response(registered_model_meta.to_meta_proto())

    @catch_exception
    def updateRegisteredModel(self, request, context):
        """Update an existing registered model; responds with None payload
        when the model does not exist."""
        model_meta_param = RegisteredModel.from_proto(request)
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.update_registered_model(
            RegisteredModel(model_meta_param.model_name),
            registered_model_param.model_name,
            ModelType.Name(registered_model_param.model_type),
            registered_model_param.model_desc)
        return _wrap_response(None if registered_model_meta is None
                              else registered_model_meta.to_meta_proto())

    @catch_exception
    def deleteRegisteredModel(self, request, context):
        """Delete a registered model; echoes the request's model meta back."""
        model_meta_param = RegisteredModel.from_proto(request)
        self.model_repo_store.delete_registered_model(
            RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(request.model_meta)

    @catch_exception
    def listRegisteredModels(self, request, context):
        """Return the metas of all registered models."""
        registered_models = self.model_repo_store.list_registered_models()
        return _wrap_response(RegisteredModelMetas(
            registered_models=[registered_model.to_meta_proto()
                               for registered_model in registered_models]))

    @catch_exception
    def getRegisteredModelDetail(self, request, context):
        """Return the detail proto of one registered model, or None payload
        when it does not exist."""
        model_meta_param = ModelVersion.from_proto(request)
        registered_model_detail = self.model_repo_store.get_registered_model_detail(
            RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(None if registered_model_detail is None
                              else registered_model_detail.to_detail_proto())

    @catch_exception
    def createModelVersion(self, request, context):
        """Create a model version and broadcast a stage-specific event."""
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.create_model_version(
            model_meta_param.model_name,
            model_version_param.model_path,
            model_version_param.model_metric,
            model_version_param.model_flavor,
            model_version_param.version_desc,
            model_version_param.current_stage)
        # Event type is derived from the version's stage (e.g. GENERATED,
        # VALIDATED, ...); payload is the full version meta as JSON.
        model_type = MODEL_VERSION_TO_EVENT_TYPE.get(
            ModelVersionStage.from_string(model_version_param.current_stage))
        self.notification_client.send_event(
            BaseEvent(model_version_meta.model_name,
                      json.dumps(model_version_meta.__dict__),
                      model_type))
        return _wrap_response(model_version_meta.to_meta_proto())

    @catch_exception
    def updateModelVersion(self, request, context):
        """Update a model version; broadcast a stage event only when the
        stage changed AND the update actually matched a version."""
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.update_model_version(
            model_meta_param,
            model_version_param.model_path,
            model_version_param.model_metric,
            model_version_param.model_flavor,
            model_version_param.version_desc,
            model_version_param.current_stage)
        # BUG FIX: the store may return None (no matching version) — the old
        # code then raised AttributeError on model_version_meta.model_name
        # instead of reaching the None-aware return below.
        if model_version_param.current_stage is not None \
                and model_version_meta is not None:
            model_type = MODEL_VERSION_TO_EVENT_TYPE.get(
                ModelVersionStage.from_string(model_version_param.current_stage))
            self.notification_client.send_event(
                BaseEvent(model_version_meta.model_name,
                          json.dumps(model_version_meta.__dict__),
                          model_type))
        return _wrap_response(None if model_version_meta is None
                              else model_version_meta.to_meta_proto())

    @catch_exception
    def deleteModelVersion(self, request, context):
        """Delete a model version and broadcast a MODEL_DELETED event."""
        model_meta_param = ModelVersion.from_proto(request)
        self.model_repo_store.delete_model_version(model_meta_param)
        self.notification_client.send_event(
            BaseEvent(model_meta_param.model_name,
                      json.dumps(model_meta_param.__dict__),
                      ModelVersionEventType.MODEL_DELETED))
        return _wrap_response(request.model_meta)

    @catch_exception
    def getModelVersionDetail(self, request, context):
        """Return one model version's meta proto, or None payload when it
        does not exist."""
        model_meta_param = ModelVersion.from_proto(request)
        model_version_meta = self.model_repo_store.get_model_version_detail(
            model_meta_param)
        return _wrap_response(None if model_version_meta is None
                              else model_version_meta.to_meta_proto())
def __init__(self, store_uri, server_uri, notification_uri=None): self.model_repo_store = SqlAlchemyStore(store_uri) if notification_uri is None: self.notification_client = NotificationClient(server_uri) else: self.notification_client = NotificationClient(notification_uri)