Example #1
        def run_task_function(client: NotificationClient):
            with af.global_config_file(workflow_config_file()):
                with af.config('task_2'):
                    executor_1 = af.user_define_operation(
                        af.PythonObjectExecutor(SimpleExecutor()))
                with af.config('task_5'):
                    executor_2 = af.user_define_operation(
                        af.PythonObjectExecutor(SimpleExecutor()))
                af.user_define_control_dependency(src=executor_2,
                                                  dependency=executor_1,
                                                  namespace='test',
                                                  event_key='key_1',
                                                  event_value='value_1',
                                                  sender='*')
                workflow_info = af.workflow_operation.submit_workflow(
                    workflow_name)

            af.workflow_operation.start_new_workflow_execution(workflow_name)
            flag = True
            while True:
                with create_session() as session:
                    tes = session.query(TaskExecution).filter(
                        TaskExecution.dag_id == 'test_project.test_workflow',
                        TaskExecution.task_id == 'task_2').all()
                    if 1 == len(tes) and flag:
                        client.send_event(
                            BaseEvent(key='key_1', value='value_1'))
                        flag = False
                    dag_run = session.query(DagRun).filter(
                        DagRun.dag_id == 'test_project.test_workflow').first()
                    if dag_run is not None and dag_run.state in State.finished:
                        break
                    else:
                        time.sleep(1)
Example #2
 def run_trigger_task_function(self):
     # wait for the DAG file to be parsed
     time.sleep(5)
     ns_client = NotificationClient(server_uri="localhost:{}".format(
         self.port),
                                    default_namespace="a")
     client = EventSchedulerClient(ns_client=ns_client)
     execution_context = client.schedule_dag('trigger_task')
     while True:
         with create_session() as session:
             tes = session.query(TaskExecution).filter(
                 TaskExecution.dag_id == 'trigger_task',
                 TaskExecution.task_id == 'task_1').all()
             if len(tes) > 0:
                 client.schedule_task('trigger_task', 'task_2',
                                      SchedulingAction.START,
                                      execution_context)
                 while True:
                     with create_session() as session_2:
                         tes_2 = session_2.query(TaskExecution).filter(
                             TaskExecution.dag_id == 'trigger_task',
                             TaskExecution.task_id == 'task_2').all()
                         if len(tes_2) > 0:
                             break
                         else:
                             time.sleep(1)
                 break
             else:
                 time.sleep(1)
     ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())
Example #3
 def run_event_task_function(self):
     client = NotificationClient(server_uri="localhost:{}".format(
         self.port),
                                 default_namespace="")
     while True:
         with create_session() as session:
             tes = session.query(TaskExecution).filter(
                 TaskExecution.dag_id == 'event_dag',
                 TaskExecution.task_id == 'task_1').all()
             if len(tes) > 0:
                 time.sleep(5)
                 client.send_event(
                     BaseEvent(key='start',
                               value='',
                               event_type='',
                               namespace=''))
                 while True:
                     with create_session() as session_2:
                         tes_2 = session_2.query(TaskExecution).filter(
                             TaskExecution.dag_id == 'event_dag',
                             TaskExecution.task_id == 'task_2').all()
                         if len(tes_2) > 0:
                             break
                         else:
                             time.sleep(1)
                 break
             else:
                 time.sleep(1)
     client.send_event(StopSchedulerEvent(job_id=0).to_event())
Example #4
 def run_test_fun():
     time.sleep(3)
     client = NotificationClient(server_uri="localhost:{}".format(server_port()),
                                 default_namespace="test")
     try:
         test_function(client)
     except Exception as e:
         raise e
     finally:
         client.send_event(StopSchedulerEvent(job_id=0).to_event())
Example #5
 def _send_task_status_change_event(self):
     task_status_changed_event = TaskStateChangedEvent(
         self.task_instance.task_id, self.task_instance.dag_id,
         self.task_instance.execution_date, self.task_instance.state)
     event = task_status_changed_event.to_event()
     client = NotificationClient(self.server_uri,
                                 default_namespace=event.namespace,
                                 sender=event.sender)
     self.log.info("LocalTaskJob sending event: {}".format(event))
     client.send_event(event)
Example #6
 def run_no_dag_file_function(self):
     ns_client = NotificationClient(server_uri="localhost:{}".format(
         self.port),
                                    default_namespace="")
     client = EventSchedulerClient(ns_client=ns_client)
     with create_session() as session:
         client.trigger_parse_dag()
         result = client.schedule_dag('no_dag')
         print('result {}'.format(result.dagrun_id))
         time.sleep(5)
     ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())
Example #7
 def execute(self, function_context: FunctionContext,
             input_list: List) -> List:
     from notification_service.client import NotificationClient
     client = NotificationClient(server_uri="localhost:{}".format(
         self.port),
                                 default_namespace="default",
                                 sender=self.sender)
     client.send_event(
         BaseEvent(key=self.key,
                   value=self.value,
                   event_type=self.event_type))
     return []
Example #8
 def _send_request_and_receive_response(self, server_uri, file_path):
     key = '{}_{}'.format(file_path, time.time_ns())
     client = NotificationClient(server_uri=server_uri,
                                 default_namespace=SCHEDULER_NAMESPACE)
     event = BaseEvent(key=key,
                       event_type=SchedulerInnerEventType.PARSE_DAG_REQUEST.value,
                       value=file_path)
     client.send_event(event)
     watcher: ResponseWatcher = ResponseWatcher()
     client.start_listen_event(key=key,
                               event_type=SchedulerInnerEventType.PARSE_DAG_RESPONSE.value,
                               watcher=watcher)
     res: BaseEvent = watcher.get_result()
     self.assertEqual(event.key, res.key)
     self.assertEqual(event.value, file_path)
Example #9
    def stop_workflow(self, workflow_name) -> bool:
        """
        Stop the workflow. No more workflow execution(Airflow dag_run) would be scheduled and all running jobs would be stopped.

        :param workflow_name: workflow name
        :return: True if succeed
        """
        # TODO For now, simply return True as long as message is sent successfully,
        #  actually we need a response from
        try:
            notification_client = NotificationClient(self.server_uri,
                                                     SCHEDULER_NAMESPACE)
            notification_client.send_event(
                StopDagEvent(workflow_name).to_event())
            return True
        except Exception:
            return False
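The method above is fire-and-forget: it only publishes a StopDagEvent on the notification channel and reports whether the send succeeded, as the TODO notes. Below is a minimal sketch of that send path; the send_stop_request helper name and the 'scheduler' namespace value are assumptions for illustration, and the NotificationClient import path is the one used in Example #7.

from notification_service.client import NotificationClient

SCHEDULER_NAMESPACE = 'scheduler'  # assumption: the namespace the scheduler listens on

def send_stop_request(server_uri, stop_event) -> bool:
    """Send a pre-built stop event, e.g. StopDagEvent(workflow_name).to_event(),
    and report only whether the message reached the notification server."""
    try:
        client = NotificationClient(server_uri, SCHEDULER_NAMESPACE)
        client.send_event(stop_event)
        return True
    except Exception:
        return False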
Example #10
 def run_trigger_dag_function(self):
     ns_client = NotificationClient(server_uri="localhost:{}".format(
         self.port),
                                    default_namespace="")
     client = EventSchedulerClient(ns_client=ns_client)
     while True:
         with create_session() as session:
             tes = session.query(TaskExecution).filter(
                 TaskExecution.dag_id == 'trigger_dag',
                 TaskExecution.task_id == 'task_1').all()
             if len(tes) > 0:
                 break
             else:
                 client.trigger_parse_dag()
                 result = client.schedule_dag('trigger_dag')
                 print('result {}'.format(result.dagrun_id))
             time.sleep(5)
     ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())
Example #11
 def run_airflow_dag_function(self):
     # wait for the DAG file to be parsed
     from datetime import datetime
     ns_client = NotificationClient(server_uri='localhost:50051')
     with af.global_config_file(test_util.get_workflow_config_file()):
         with af.config('task_1'):
             cmd_executor = af.user_define_operation(
                 output_num=0,
                 executor=CmdExecutor(cmd_line=['echo "hello world!"']))
     af.deploy_to_airflow(test_util.get_project_path(),
                          dag_id='test_dag_111',
                          default_args={
                              'schedule_interval': None,
                              'start_date': datetime(2025, 12, 1),
                          })
     context = af.run(project_path=test_util.get_project_path(),
                      dag_id='test_dag_111',
                      scheduler_type=SchedulerType.AIRFLOW)
     print(context.dagrun_id)
     time.sleep(5)
     ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())
Example #12
 def run_ai_flow_function(self):
     client = NotificationClient(server_uri="localhost:{}".format(
         self.port),
                                 default_namespace="default",
                                 sender='1-job-name')
     while True:
         with create_session() as session:
             tes = session.query(TaskExecution).filter(
                 TaskExecution.dag_id == 'workflow_1',
                 TaskExecution.task_id == '1-job-name').all()
             if len(tes) > 0:
                 time.sleep(5)
                 client.send_event(
                     BaseEvent(key='key_1',
                               value='value_1',
                               event_type='UNDEFINED'))
                 client.send_event(
                     BaseEvent(key='key_2',
                               value='value_2',
                               event_type='UNDEFINED'))
                 while True:
                     with create_session() as session_2:
                         tes_2 = session_2.query(TaskExecution).filter(
                             TaskExecution.dag_id == 'workflow_1').all()
                         if len(tes_2) == 3:
                             break
                         else:
                             time.sleep(1)
                 break
             else:
                 time.sleep(1)
     time.sleep(3)
     client.send_event(StopSchedulerEvent(job_id=0).to_event())
Example #13
    def test_send_listening_on_different_server(self):
        event_list = []

        class TestWatch(EventWatcher):
            def __init__(self, event_list) -> None:
                super().__init__()
                self.event_list = event_list

            def process(self, events: List[BaseEvent]):
                self.event_list.extend(events)

        self.master2 = self.start_master("localhost", "50052")
        self.wait_for_new_members_detected("localhost:50052")
        another_client = NotificationClient(server_uri="localhost:50052")
        try:
            event1 = another_client.send_event(BaseEvent(key="key1", value="value1"))
            self.client.start_listen_events(watcher=TestWatch(event_list), version=event1.version)
            another_client.send_event(BaseEvent(key="key2", value="value2"))
            another_client.send_event(BaseEvent(key="key3", value="value3"))
        finally:
            self.client.stop_listen_events()
        self.assertEqual(2, len(event_list))
Example #14
class ModelCenterService(model_center_service_pb2_grpc.ModelCenterServiceServicer):

    def __init__(self, store_uri, notification_uri=None):
        db_engine = extract_db_engine_from_uri(store_uri)
        if DBType.value_of(db_engine) == DBType.MONGODB:
            username, password, host, port, db = parse_mongo_uri(store_uri)
            self.model_repo_store = MongoStore(host=host,
                                               port=int(port),
                                               username=username,
                                               password=password,
                                               db=db)
        else:
            self.model_repo_store = SqlAlchemyStore(store_uri)
        self.notification_client = None
        if notification_uri is not None:
            self.notification_client = NotificationClient(notification_uri, default_namespace=DEFAULT_NAMESPACE)

    @catch_exception
    def createRegisteredModel(self, request, context):
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.create_registered_model(registered_model_param.model_name,
                                                                              registered_model_param.model_desc)
        return _wrap_response(registered_model_meta.to_meta_proto())

    @catch_exception
    def updateRegisteredModel(self, request, context):
        model_meta_param = RegisteredModel.from_proto(request)
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.update_registered_model(
            RegisteredModel(model_meta_param.model_name),
            registered_model_param.model_name,
            registered_model_param.model_desc)
        return _wrap_response(None if registered_model_meta is None else registered_model_meta.to_meta_proto())

    @catch_exception
    def deleteRegisteredModel(self, request, context):
        model_meta_param = RegisteredModel.from_proto(request)
        self.model_repo_store.delete_registered_model(RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(request.model_meta)

    @catch_exception
    def listRegisteredModels(self, request, context):
        registered_models = self.model_repo_store.list_registered_models()
        return _wrap_response(RegisteredModelMetas(registered_models=[registered_model.to_meta_proto()
                                                                      for registered_model in registered_models]))

    @catch_exception
    def getRegisteredModelDetail(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        registered_model_detail = self.model_repo_store.get_registered_model_detail(
            RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(None if registered_model_detail is None else registered_model_detail.to_detail_proto())

    @catch_exception
    def createModelVersion(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.create_model_version(model_meta_param.model_name,
                                                                        model_version_param.model_path,
                                                                        model_version_param.model_type,
                                                                        model_version_param.version_desc,
                                                                        model_version_param.current_stage)
        event_type = MODEL_VERSION_TO_EVENT_TYPE.get(ModelVersionStage.from_string(model_version_param.current_stage))
        if self.notification_client is not None:
            self.notification_client.send_event(BaseEvent(model_version_meta.model_name,
                                                          json.dumps(model_version_meta.__dict__),
                                                          event_type))
        return _wrap_response(model_version_meta.to_meta_proto())

    @catch_exception
    def updateModelVersion(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.update_model_version(model_meta_param,
                                                                        model_version_param.model_path,
                                                                        model_version_param.model_type,
                                                                        model_version_param.version_desc,
                                                                        model_version_param.current_stage)
        if model_version_param.current_stage is not None:
            event_type = MODEL_VERSION_TO_EVENT_TYPE.get(
                ModelVersionStage.from_string(model_version_param.current_stage))
            if self.notification_client is not None:
                self.notification_client.send_event(BaseEvent(model_version_meta.model_name,
                                                              json.dumps(model_version_meta.__dict__),
                                                              event_type))
        return _wrap_response(None if model_version_meta is None else model_version_meta.to_meta_proto())

    @catch_exception
    def deleteModelVersion(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        self.model_repo_store.delete_model_version(model_meta_param)
        if self.notification_client is not None:
            self.notification_client.send_event(BaseEvent(model_meta_param.model_name,
                                                          json.dumps(model_meta_param.__dict__),
                                                          ModelVersionEventType.MODEL_DELETED))
        return _wrap_response(request.model_meta)

    @catch_exception
    def getModelVersionDetail(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        model_version_meta = self.model_repo_store.get_model_version_detail(model_meta_param)
        return _wrap_response(None if model_version_meta is None else model_version_meta.to_meta_proto())
Example #15
class HaServerTest(unittest.TestCase):

    @classmethod
    def start_master(cls, host, port):
        port = str(port)
        server_uri = host + ":" + port
        storage = DbEventStorage()
        ha_manager = SimpleNotificationServerHaManager()
        ha_storage = DbHighAvailabilityStorage()
        service = HighAvailableNotificationService(
            storage,
            ha_manager,
            server_uri,
            ha_storage)
        master = NotificationMaster(service, port=int(port))
        master.run()
        return master

    @classmethod
    def setUpClass(cls):
        cls.storage = DbEventStorage()
        cls.master1 = None
        cls.master2 = None
        cls.master3 = None

    def setUp(self):
        self.storage.clean_up()
        self.master1 = self.start_master("localhost", "50051")
        self.client = NotificationClient(server_uri="localhost:50051", enable_ha=True)

    def tearDown(self):
        self.client.stop_listen_events()
        self.client.stop_listen_event()
        self.client.disable_high_availability()
        if self.master1 is not None:
            self.master1.stop()
        if self.master2 is not None:
            self.master2.stop()
        if self.master3 is not None:
            self.master3.stop()

    def wait_for_new_members_detected(self, new_member_uri):
        for i in range(100):
            living_member = self.client.living_members
            if new_member_uri in living_member:
                break
            else:
                time.sleep(0.1)

    def test_server_change(self):
        self.client.send_event(BaseEvent(key="key", value="value1"))
        self.client.send_event(BaseEvent(key="key", value="value2"))
        self.client.send_event(BaseEvent(key="key", value="value3"))
        results = self.client.list_all_events()
        self.assertEqual(self.client.current_uri, "localhost:50051")
        self.master2 = self.start_master("localhost", "50052")
        self.wait_for_new_members_detected("localhost:50052")
        self.master1.stop()
        results2 = self.client.list_all_events()
        self.assertEqual(results, results2)
        self.assertEqual(self.client.current_uri, "localhost:50052")
        self.master3 = self.start_master("localhost", "50053")
        self.wait_for_new_members_detected("localhost:50053")
        self.master2.stop()
        results3 = self.client.list_all_events()
        self.assertEqual(results2, results3)
        self.assertEqual(self.client.current_uri, "localhost:50053")

    def test_send_listening_on_different_server(self):
        event_list = []

        class TestWatch(EventWatcher):
            def __init__(self, event_list) -> None:
                super().__init__()
                self.event_list = event_list

            def process(self, events: List[BaseEvent]):
                self.event_list.extend(events)

        self.master2 = self.start_master("localhost", "50052")
        self.wait_for_new_members_detected("localhost:50052")
        another_client = NotificationClient(server_uri="localhost:50052")
        try:
            event1 = another_client.send_event(BaseEvent(key="key1", value="value1"))
            self.client.start_listen_events(watcher=TestWatch(event_list), version=event1.version)
            another_client.send_event(BaseEvent(key="key2", value="value2"))
            another_client.send_event(BaseEvent(key="key3", value="value3"))
        finally:
            self.client.stop_listen_events()
        self.assertEqual(2, len(event_list))

    def test_start_with_multiple_servers(self):
        self.client.disable_high_availability()
        self.client = NotificationClient(server_uri="localhost:55001,localhost:50051", enable_ha=True)
        self.assertEqual(self.client.current_uri, "localhost:50051")
Example #16
 def execute(self, context):
     notification_client = NotificationClient(server_uri=self.uri)
     notification_client.send_event(event=self.event)
Example #17
class NotificationTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.storage = EventModelStorage()
        cls.master = NotificationMaster(NotificationService(cls.storage))
        cls.master.run()

    @classmethod
    def tearDownClass(cls):
        cls.master.stop()

    def setUp(self):
        self.storage.clean_up()
        self.client = NotificationClient(server_uri="localhost:50051")

    def tearDown(self):
        self.client.stop_listen_events()
        self.client.stop_listen_event()

    def test_send_event(self):
        event = self.client.send_event(Event(key="key", value="value1"))
        self.assertTrue(event.version > 0)

    def test_list_events(self):
        event1 = self.client.send_event(Event(key="key", value="value1"))
        event2 = self.client.send_event(Event(key="key", value="value2"))
        event3 = self.client.send_event(Event(key="key", value="value3"))
        events = self.client.list_events("key", version=event1.version)
        self.assertEqual(2, len(events))

    def test_listen_events(self):
        event_list = []

        class TestWatch(EventWatcher):
            def __init__(self, event_list) -> None:
                super().__init__()
                self.event_list = event_list

            def process(self, events: List[Event]):
                self.event_list.extend(events)

        event1 = self.client.send_event(Event(key="key", value="value1"))
        self.client.start_listen_event(key="key",
                                       watcher=TestWatch(event_list),
                                       version=event1.version)
        event = self.client.send_event(Event(key="key", value="value2"))
        event = self.client.send_event(Event(key="key", value="value3"))
        self.client.stop_listen_event("key")
        events = self.client.list_events("key", version=event1.version)
        self.assertEqual(2, len(events))
        self.assertEqual(2, len(event_list))

    def test_all_listen_events(self):
        event = self.client.send_event(Event(key="key", value="value1"))
        event = self.client.send_event(Event(key="key", value="value2"))
        start_time = event.create_time
        event = self.client.send_event(Event(key="key", value="value3"))
        events = self.client.list_all_events(start_time)
        self.assertEqual(2, len(events))

    def test_listen_all_events(self):
        event_list = []

        class TestWatch(EventWatcher):
            def __init__(self, event_list) -> None:
                super().__init__()
                self.event_list = event_list

            def process(self, events: List[Event]):
                self.event_list.extend(events)

        try:
            self.client.start_listen_events(watcher=TestWatch(event_list))
            event = self.client.send_event(Event(key="key1", value="value1"))
            event = self.client.send_event(Event(key="key2", value="value2"))
            event = self.client.send_event(Event(key="key3", value="value3"))
        finally:
            self.client.stop_listen_events()
        self.assertEqual(3, len(event_list))
Example #18
class TestEventBasedScheduler(unittest.TestCase):
    def setUp(self):
        db.clear_db_jobs()
        db.clear_db_dags()
        db.clear_db_serialized_dags()
        db.clear_db_runs()
        db.clear_db_task_execution()
        db.clear_db_message()
        self.scheduler = None
        self.port = 50102
        self.storage = MemoryEventStorage()
        self.master = NotificationMaster(NotificationService(self.storage),
                                         self.port)
        self.master.run()
        self.client = NotificationClient(server_uri="localhost:{}".format(
            self.port),
                                         default_namespace="test_namespace")
        time.sleep(1)

    def tearDown(self):
        self.master.stop()

    def _get_task_instance(self, dag_id, task_id, session):
        return session.query(TaskInstance).filter(
            TaskInstance.dag_id == dag_id,
            TaskInstance.task_id == task_id).first()

    def schedule_task_function(self):
        stopped = False
        while not stopped:
            with create_session() as session:
                ti_sleep_1000_secs = self._get_task_instance(
                    EVENT_BASED_SCHEDULER_DAG, 'sleep_1000_secs', session)
                ti_python_sleep = self._get_task_instance(
                    EVENT_BASED_SCHEDULER_DAG, 'python_sleep', session)
                if ti_sleep_1000_secs and ti_sleep_1000_secs.state == State.SCHEDULED and \
                   ti_python_sleep and ti_python_sleep.state == State.SCHEDULED:
                    self.client.send_event(
                        BaseEvent(key='start',
                                  value='',
                                  event_type='',
                                  namespace='test_namespace'))

                    while not stopped:
                        ti_sleep_1000_secs.refresh_from_db()
                        ti_python_sleep.refresh_from_db()
                        if ti_sleep_1000_secs and ti_sleep_1000_secs.state == State.RUNNING and \
                           ti_python_sleep and ti_python_sleep.state == State.RUNNING:
                            time.sleep(10)
                            break
                        else:
                            time.sleep(1)
                    self.client.send_event(
                        BaseEvent(key='stop',
                                  value='',
                                  event_type=UNDEFINED_EVENT_TYPE,
                                  namespace='test_namespace'))
                    self.client.send_event(
                        BaseEvent(key='restart',
                                  value='',
                                  event_type=UNDEFINED_EVENT_TYPE,
                                  namespace='test_namespace'))
                    while not stopped:
                        ti_sleep_1000_secs.refresh_from_db()
                        ti_python_sleep.refresh_from_db()
                        if ti_sleep_1000_secs and ti_sleep_1000_secs.state == State.KILLED and \
                           ti_python_sleep and ti_python_sleep.state == State.RUNNING:
                            stopped = True
                        else:
                            time.sleep(1)
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_event_based_scheduler(self):
        t = threading.Thread(target=self.schedule_task_function)
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_event_based_scheduler.py')

    def test_replay_message(self):
        key = "stop"
        mailbox = Mailbox()
        mailbox.set_scheduling_job_id(1234)
        watcher = SchedulerEventWatcher(mailbox)
        self.client.start_listen_events(watcher=watcher,
                                        start_time=int(time.time() * 1000),
                                        version=None)
        self.send_event(key)
        msg: BaseEvent = mailbox.get_message()
        self.assertEqual(msg.key, key)
        with create_session() as session:
            msg_from_db = session.query(Message).first()
            expect_non_unprocessed = EventBasedScheduler.get_unprocessed_message(
                1000)
            self.assertEqual(0, len(expect_non_unprocessed))
            unprocessed = EventBasedScheduler.get_unprocessed_message(1234)
            self.assertEqual(unprocessed[0].serialized_message,
                             msg_from_db.data)
        deserialized_data = pickle.loads(msg_from_db.data)
        self.assertEqual(deserialized_data.key, key)
        self.assertEqual(msg, deserialized_data)

    def send_event(self, key):
        event = self.client.send_event(
            BaseEvent(key=key, event_type=UNDEFINED_EVENT_TYPE,
                      value="value1"))
        self.assertEqual(key, event.key)

    @provide_session
    def get_task_execution(self, dag_id, task_id, session):
        return session.query(TaskExecution).filter(
            TaskExecution.dag_id == dag_id,
            TaskExecution.task_id == task_id).all()

    @provide_session
    def get_latest_job_id(self, session):
        return session.query(BaseJob).order_by(sqlalchemy.desc(
            BaseJob.id)).first().id

    def start_scheduler(self, file_path):
        self.scheduler = EventBasedSchedulerJob(
            dag_directory=file_path,
            server_uri="localhost:{}".format(self.port),
            executor=LocalExecutor(3),
            max_runs=-1,
            refresh_dag_dir_interval=30)
        print("scheduler starting")
        self.scheduler.run()

    def wait_for_running(self):
        while True:
            if self.scheduler is not None:
                time.sleep(5)
                break
            else:
                time.sleep(1)

    def wait_for_task_execution(self, dag_id, task_id, expected_num):
        result = False
        check_nums = 100
        while check_nums > 0:
            time.sleep(2)
            check_nums = check_nums - 1
            tes = self.get_task_execution(dag_id, task_id)
            if len(tes) == expected_num:
                result = True
                break
        self.assertTrue(result)

    def wait_for_task(self, dag_id, task_id, expected_state):
        result = False
        check_nums = 100
        while check_nums > 0:
            time.sleep(2)
            check_nums = check_nums - 1
            with create_session() as session:
                ti = session.query(TaskInstance).filter(
                    TaskInstance.dag_id == dag_id,
                    TaskInstance.task_id == task_id).first()
            if ti and ti.state == expected_state:
                result = True
                break
        self.assertTrue(result)

    def test_notification(self):
        self.client.send_event(BaseEvent(key='a', value='b'))

    def run_a_task_function(self):
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'single',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    break
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_a_task(self):
        t = threading.Thread(target=self.run_a_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_single_task_dag.py')
        tes: List[TaskExecution] = self.get_task_execution("single", "task_1")
        self.assertEqual(len(tes), 1)

    def run_event_task_function(self):
        client = NotificationClient(server_uri="localhost:{}".format(
            self.port),
                                    default_namespace="")
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'event_dag',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    time.sleep(5)
                    client.send_event(
                        BaseEvent(key='start',
                                  value='',
                                  event_type='',
                                  namespace=''))
                    while True:
                        with create_session() as session_2:
                            tes_2 = session_2.query(TaskExecution).filter(
                                TaskExecution.dag_id == 'event_dag',
                                TaskExecution.task_id == 'task_2').all()
                            if len(tes_2) > 0:
                                break
                            else:
                                time.sleep(1)
                    break
                else:
                    time.sleep(1)
        client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_event_task(self):
        t = threading.Thread(target=self.run_event_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_event_task_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "event_dag", "task_2")
        self.assertEqual(len(tes), 1)

    def run_trigger_dag_function(self):
        ns_client = NotificationClient(server_uri="localhost:{}".format(
            self.port),
                                       default_namespace="")
        client = EventSchedulerClient(ns_client=ns_client)
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'trigger_dag',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    break
                else:
                    client.trigger_parse_dag()
                    result = client.schedule_dag('trigger_dag')
                    print('result {}'.format(result.dagrun_id))
                time.sleep(5)
        ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_trigger_dag(self):
        import multiprocessing
        p = multiprocessing.Process(target=self.run_trigger_dag_function,
                                    args=())
        p.start()
        self.start_scheduler('../../dags/test_run_trigger_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "trigger_dag", "task_1")
        self.assertEqual(len(tes), 1)

    def run_no_dag_file_function(self):
        ns_client = NotificationClient(server_uri="localhost:{}".format(
            self.port),
                                       default_namespace="")
        client = EventSchedulerClient(ns_client=ns_client)
        with create_session() as session:
            client.trigger_parse_dag()
            result = client.schedule_dag('no_dag')
            print('result {}'.format(result.dagrun_id))
            time.sleep(5)
        ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_no_dag_file_trigger_dag(self):
        import multiprocessing
        p = multiprocessing.Process(target=self.run_no_dag_file_function,
                                    args=())
        p.start()
        self.start_scheduler('../../dags/test_run_trigger_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "trigger_dag", "task_1")
        self.assertEqual(len(tes), 0)

    def run_trigger_task_function(self):
        # wait for the DAG file to be parsed
        time.sleep(5)
        ns_client = NotificationClient(server_uri="localhost:{}".format(
            self.port),
                                       default_namespace="a")
        client = EventSchedulerClient(ns_client=ns_client)
        execution_context = client.schedule_dag('trigger_task')
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'trigger_task',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    client.schedule_task('trigger_task', 'task_2',
                                         SchedulingAction.START,
                                         execution_context)
                    while True:
                        with create_session() as session_2:
                            tes_2 = session_2.query(TaskExecution).filter(
                                TaskExecution.dag_id == 'trigger_task',
                                TaskExecution.task_id == 'task_2').all()
                            if len(tes_2) > 0:
                                break
                            else:
                                time.sleep(1)
                    break
                else:
                    time.sleep(1)
        ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_task_trigger_dag(self):
        import threading
        t = threading.Thread(target=self.run_trigger_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_task_trigger_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "trigger_task", "task_2")
        self.assertEqual(len(tes), 1)

    def run_ai_flow_function(self):
        client = NotificationClient(server_uri="localhost:{}".format(
            self.port),
                                    default_namespace="default",
                                    sender='1-job-name')
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'workflow_1',
                    TaskExecution.task_id == '1-job-name').all()
                if len(tes) > 0:
                    time.sleep(5)
                    client.send_event(
                        BaseEvent(key='key_1',
                                  value='value_1',
                                  event_type='UNDEFINED'))
                    client.send_event(
                        BaseEvent(key='key_2',
                                  value='value_2',
                                  event_type='UNDEFINED'))
                    while True:
                        with create_session() as session_2:
                            tes_2 = session_2.query(TaskExecution).filter(
                                TaskExecution.dag_id == 'workflow_1').all()
                            if len(tes_2) == 3:
                                break
                            else:
                                time.sleep(1)
                    break
                else:
                    time.sleep(1)
        time.sleep(3)
        client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_ai_flow_dag(self):
        import threading
        t = threading.Thread(target=self.run_ai_flow_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_aiflow_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "workflow_1", "1-job-name")
        self.assertEqual(len(tes), 1)

    def stop_dag_function(self):
        stopped = False
        while not stopped:
            tes = self.get_task_execution(EVENT_BASED_SCHEDULER_DAG,
                                          'sleep_to_be_stopped')
            if tes and len(tes) == 1:
                self.client.send_event(
                    StopDagEvent(EVENT_BASED_SCHEDULER_DAG).to_event())
                while not stopped:
                    tes2 = self.get_task_execution(EVENT_BASED_SCHEDULER_DAG,
                                                   'sleep_to_be_stopped')
                    if tes2[0].state == State.KILLED:
                        stopped = True
                        time.sleep(5)
                    else:
                        time.sleep(1)
            else:
                time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_stop_dag(self):
        t = threading.Thread(target=self.stop_dag_function)
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_event_based_scheduler.py')
        with create_session() as session:
            from airflow.models import DagModel
            dag_model: DagModel = DagModel.get_dagmodel(
                EVENT_BASED_SCHEDULER_DAG)
            self.assertTrue(dag_model.is_paused)
            self.assertEqual(dag_model.get_last_dagrun().state, "killed")
            for ti in session.query(TaskInstance).filter(
                    TaskInstance.dag_id == EVENT_BASED_SCHEDULER_DAG):
                self.assertTrue(ti.state in [State.SUCCESS, State.KILLED])
            for te in session.query(TaskExecution).filter(
                    TaskExecution.dag_id == EVENT_BASED_SCHEDULER_DAG):
                self.assertTrue(te.state in [State.SUCCESS, State.KILLED])

    def run_periodic_task_function(self):
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'single',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 1:
                    break
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_periodic_task(self):
        t = threading.Thread(target=self.run_periodic_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_periodic_task_dag.py')
        tes: List[TaskExecution] = self.get_task_execution("single", "task_1")
        self.assertGreater(len(tes), 1)

    def run_one_task_function(self):
        self.wait_for_running()
        self.client.send_event(BaseEvent(key='a', value='a'))
        time.sleep(5)
        self.client.send_event(BaseEvent(key='a', value='a'))
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'single',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) >= 2:
                    break
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_one_task(self):
        t = threading.Thread(target=self.run_one_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_multiple_trigger_task_dag.py')
Example #19
class EventSchedulerClient(object):
    def __init__(self, server_uri=None, namespace=None, ns_client=None):
        if ns_client is None:
            self.ns_client = NotificationClient(server_uri, namespace)
        else:
            self.ns_client = ns_client

    @staticmethod
    def generate_id(id):
        return '{}_{}'.format(id, time.time_ns())

    def trigger_parse_dag(self) -> bool:
        id = self.generate_id('')
        watcher: ResponseWatcher = ResponseWatcher()
        handler: ThreadEventWatcherHandle \
            = self.ns_client.start_listen_event(key=id,
                                                event_type=SchedulerInnerEventType.PARSE_DAG_RESPONSE.value,
                                                namespace=SCHEDULER_NAMESPACE, watcher=watcher)

        self.ns_client.send_event(
            BaseEvent(
                key=id,
                event_type=SchedulerInnerEventType.PARSE_DAG_REQUEST.value,
                value=''))
        result = watcher.get_result()
        handler.stop()
        return True

    def schedule_dag(self, dag_id) -> ExecutionContext:
        id = self.generate_id(dag_id)
        watcher: ResponseWatcher = ResponseWatcher()
        handler: ThreadEventWatcherHandle \
            = self.ns_client.start_listen_event(key=id,
                                                event_type=SchedulerInnerEventType.RESPONSE.value,
                                                namespace=SCHEDULER_NAMESPACE, watcher=watcher)
        self.ns_client.send_event(
            RequestEvent(request_id=id,
                         body=RunDagMessage(dag_id).to_json()).to_event())
        result: ResponseEvent = ResponseEvent.from_base_event(
            watcher.get_result())
        handler.stop()
        return ExecutionContext(dagrun_id=result.body)

    def stop_dag_run(self, dag_id,
                     context: ExecutionContext) -> ExecutionContext:
        id = self.generate_id(str(dag_id) + str(context.dagrun_id))
        watcher: ResponseWatcher = ResponseWatcher()
        handler: ThreadEventWatcherHandle \
            = self.ns_client.start_listen_event(key=id,
                                                event_type=SchedulerInnerEventType.RESPONSE.value,
                                                namespace=SCHEDULER_NAMESPACE, watcher=watcher)
        self.ns_client.send_event(
            RequestEvent(
                request_id=id,
                body=StopDagRunMessage(
                    dag_id=dag_id,
                    dagrun_id=context.dagrun_id).to_json()).to_event())
        result: ResponseEvent = ResponseEvent.from_base_event(
            watcher.get_result())
        handler.stop()
        return ExecutionContext(dagrun_id=result.body)

    def schedule_task(self, dag_id: str, task_id: str,
                      action: SchedulingAction,
                      context: ExecutionContext) -> ExecutionContext:
        id = self.generate_id(context.dagrun_id)
        watcher: ResponseWatcher = ResponseWatcher()
        handler: ThreadEventWatcherHandle \
            = self.ns_client.start_listen_event(key=id,
                                                event_type=SchedulerInnerEventType.RESPONSE.value,
                                                namespace=SCHEDULER_NAMESPACE, watcher=watcher)
        self.ns_client.send_event(
            RequestEvent(request_id=id,
                         body=ExecuteTaskMessage(
                             dag_id=dag_id,
                             task_id=task_id,
                             dagrun_id=context.dagrun_id,
                             action=action.value).to_json()).to_event())
        result: ResponseEvent = ResponseEvent.from_base_event(
            watcher.get_result())
        handler.stop()
        return ExecutionContext(dagrun_id=result.body)
Example #20
class ModelCenterService(model_center_service_pb2_grpc.ModelCenterServiceServicer):

    def __init__(self, store_uri, server_uri, notification_uri=None):
        self.model_repo_store = SqlAlchemyStore(store_uri)
        if notification_uri is None:
            self.notification_client = NotificationClient(server_uri)
        else:
            self.notification_client = NotificationClient(notification_uri)

    @catch_exception
    def createRegisteredModel(self, request, context):
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.create_registered_model(registered_model_param.model_name,
                                                                              ModelType.Name(
                                                                                  registered_model_param.model_type),
                                                                              registered_model_param.model_desc)
        return _wrap_response(registered_model_meta.to_meta_proto())

    @catch_exception
    def updateRegisteredModel(self, request, context):
        model_meta_param = RegisteredModel.from_proto(request)
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.update_registered_model(
            RegisteredModel(model_meta_param.model_name),
            registered_model_param.model_name,
            ModelType.Name(
                registered_model_param.model_type),
            registered_model_param.model_desc)
        return _wrap_response(None if registered_model_meta is None else registered_model_meta.to_meta_proto())

    @catch_exception
    def deleteRegisteredModel(self, request, context):
        model_meta_param = RegisteredModel.from_proto(request)
        self.model_repo_store.delete_registered_model(RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(request.model_meta)

    @catch_exception
    def listRegisteredModels(self, request, context):
        registered_models = self.model_repo_store.list_registered_models()
        return _wrap_response(RegisteredModelMetas(registered_models=[registered_model.to_meta_proto()
                                                                      for registered_model in registered_models]))

    @catch_exception
    def getRegisteredModelDetail(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        registered_model_detail = self.model_repo_store.get_registered_model_detail(
            RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(None if registered_model_detail is None else registered_model_detail.to_detail_proto())

    @catch_exception
    def createModelVersion(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.create_model_version(model_meta_param.model_name,
                                                                        model_version_param.model_path,
                                                                        model_version_param.model_metric,
                                                                        model_version_param.model_flavor,
                                                                        model_version_param.version_desc,
                                                                        model_version_param.current_stage)
        model_type = MODEL_VERSION_TO_EVENT_TYPE.get(ModelVersionStage.from_string(model_version_param.current_stage))
        self.notification_client.send_event(BaseEvent(model_version_meta.model_name,
                                                      json.dumps(model_version_meta.__dict__),
                                                      model_type))
        return _wrap_response(model_version_meta.to_meta_proto())

    @catch_exception
    def updateModelVersion(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.update_model_version(model_meta_param,
                                                                        model_version_param.model_path,
                                                                        model_version_param.model_metric,
                                                                        model_version_param.model_flavor,
                                                                        model_version_param.version_desc,
                                                                        model_version_param.current_stage)
        if model_version_param.current_stage is not None:
            model_type = MODEL_VERSION_TO_EVENT_TYPE.get(ModelVersionStage.from_string(model_version_param.current_stage))
            self.notification_client.send_event(BaseEvent(model_version_meta.model_name,
                                                          json.dumps(model_version_meta.__dict__),
                                                          model_type))
        return _wrap_response(None if model_version_meta is None else model_version_meta.to_meta_proto())

    @catch_exception
    def deleteModelVersion(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        self.model_repo_store.delete_model_version(model_meta_param)
        self.notification_client.send_event(BaseEvent(model_meta_param.model_name,
                                                      json.dumps(model_meta_param.__dict__),
                                                      ModelVersionEventType.MODEL_DELETED))
        return _wrap_response(request.model_meta)

    @catch_exception
    def getModelVersionDetail(self, request, context):
        model_meta_param = ModelVersion.from_proto(request)
        model_version_meta = self.model_repo_store.get_model_version_detail(model_meta_param)
        return _wrap_response(None if model_version_meta is None else model_version_meta.to_meta_proto())