示例#1
0
        def run_task_function(client: NotificationClient):
            """Submit a two-task workflow and drive it to completion via an event.

            A control dependency on event ('key_1', 'value_1') is declared
            between the two tasks; once task_2's TaskExecution row appears,
            that event is sent exactly once, then we poll until the DagRun
            reaches a finished state.
            """
            with af.global_config_file(workflow_config_file()):
                with af.config('task_2'):
                    executor_1 = af.user_define_operation(
                        af.PythonObjectExecutor(SimpleExecutor()))
                with af.config('task_5'):
                    executor_2 = af.user_define_operation(
                        af.PythonObjectExecutor(SimpleExecutor()))
                # NOTE(review): presumably this makes one task wait on the
                # event published for the other — confirm src/dependency
                # semantics against af.user_define_control_dependency docs.
                af.user_define_control_dependency(src=executor_2,
                                                  dependency=executor_1,
                                                  namespace='test',
                                                  event_key='key_1',
                                                  event_value='value_1',
                                                  sender='*')
                workflow_info = af.workflow_operation.submit_workflow(
                    workflow_name)

            af.workflow_operation.start_new_workflow_execution(workflow_name)
            flag = True  # guards so the event is sent only once
            while True:
                with create_session() as session:
                    tes = session.query(TaskExecution).filter(
                        TaskExecution.dag_id == 'test_project.test_workflow',
                        TaskExecution.task_id == 'task_2').all()
                    # Fire the unblocking event once task_2 has exactly one
                    # TaskExecution record.
                    if 1 == len(tes) and flag:
                        client.send_event(
                            BaseEvent(key='key_1', value='value_1'))
                        flag = False
                    dag_run = session.query(DagRun).filter(
                        DagRun.dag_id == 'test_project.test_workflow').first()
                    if dag_run is not None and dag_run.state in State.finished:
                        break
                    else:
                        time.sleep(1)
示例#2
0
 def run_trigger_task_function(self):
     """Drive the 'trigger_task' dag: after task_1 has executed, manually
     START task_2 through the scheduler client, wait for it to appear,
     then tell the scheduler to stop."""
     # waiting parsed dag file done,
     time.sleep(5)
     ns_client = NotificationClient(server_uri="localhost:{}".format(
         self.port),
                                    default_namespace="a")
     client = EventSchedulerClient(ns_client=ns_client)
     execution_context = client.schedule_dag('trigger_task')
     while True:
         with create_session() as session:
             # Poll until task_1 has at least one TaskExecution row.
             tes = session.query(TaskExecution).filter(
                 TaskExecution.dag_id == 'trigger_task',
                 TaskExecution.task_id == 'task_1').all()
             if len(tes) > 0:
                 client.schedule_task('trigger_task', 'task_2',
                                      SchedulingAction.START,
                                      execution_context)
                 # Inner poll: wait for task_2's execution row to show up.
                 while True:
                     with create_session() as session_2:
                         tes_2 = session_2.query(TaskExecution).filter(
                             TaskExecution.dag_id == 'trigger_task',
                             TaskExecution.task_id == 'task_2').all()
                         if len(tes_2) > 0:
                             break
                         else:
                             time.sleep(1)
                 break
             else:
                 time.sleep(1)
     ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())
 def __init__(self, server_uri=_SERVER_URI, notification_service_uri=None):
     """Initialize all underlying service clients.

     Every client is bound to ``server_uri``; the notification client uses
     the dedicated ``notification_service_uri`` when one is provided.
     """
     for base in (MetadataClient, ModelCenterClient,
                  DeployClient, MetricClient):
         base.__init__(self, server_uri)
     if notification_service_uri is None:
         NotificationClient.__init__(self, server_uri)
     else:
         NotificationClient.__init__(self, notification_service_uri)
 def _send_task_status_change_event(self):
     """Publish a TaskStateChangedEvent describing this job's task instance."""
     ti = self.task_instance
     event = TaskStateChangedEvent(ti.task_id, ti.dag_id,
                                   ti.execution_date, ti.state).to_event()
     # The client namespaces/sender-tags itself from the event being sent.
     client = NotificationClient(self.server_uri,
                                 default_namespace=event.namespace,
                                 sender=event.sender)
     self.log.info("LocalTaskJob sending event: {}".format(event))
     client.send_event(event)
 def run_test_fun():
     """Run ``test_function`` against a local notification server.

     Always sends a StopSchedulerEvent afterwards so the scheduler under
     test shuts down, whether or not the test body raised.
     """
     time.sleep(3)
     client = NotificationClient(server_uri="localhost:{}".format(server_port()),
                                 default_namespace="test")
     # The original ``except Exception as e: raise e`` was redundant: it
     # re-raised unconditionally while mutating the traceback. Plain
     # try/finally propagates the exception untouched and still
     # guarantees the shutdown event is sent.
     try:
         test_function(client)
     finally:
         client.send_event(StopSchedulerEvent(job_id=0).to_event())
示例#6
0
 def run_no_dag_file_function(self):
     """Schedule a dag that has no dag file, print the resulting run id,
     then stop the scheduler."""
     notification = NotificationClient(server_uri="localhost:{}".format(
         self.port),
                                       default_namespace="")
     scheduler_client = EventSchedulerClient(ns_client=notification)
     with create_session() as session:
         scheduler_client.trigger_parse_dag()
         result = scheduler_client.schedule_dag('no_dag')
         print('result {}'.format(result.dagrun_id))
         time.sleep(5)
     notification.send_event(StopSchedulerEvent(job_id=0).to_event())
示例#7
0
 def execute(self, function_context: FunctionContext,
             input_list: List) -> List:
     """Send the single BaseEvent this executor is configured with.

     Returns an empty output list; the event is the only side effect.
     """
     from notification_service.client import NotificationClient
     sender_client = NotificationClient(
         server_uri="localhost:{}".format(self.port),
         default_namespace="default",
         sender=self.sender)
     event = BaseEvent(key=self.key,
                       value=self.value,
                       event_type=self.event_type)
     sender_client.send_event(event)
     return []
示例#8
0
 def __init__(self, store_uri, notification_uri=None):
     """Create the model-repo backing store and, optionally, a notification
     client.

     :param store_uri: database URI; MongoDB gets a MongoStore, anything
         else a SqlAlchemyStore.
     :param notification_uri: when given, events are published to this
         notification server.
     """
     engine = extract_db_engine_from_uri(store_uri)
     if DBType.value_of(engine) != DBType.MONGODB:
         self.model_repo_store = SqlAlchemyStore(store_uri)
     else:
         username, password, host, port, db = parse_mongo_uri(store_uri)
         self.model_repo_store = MongoStore(host=host,
                                            port=int(port),
                                            username=username,
                                            password=password,
                                            db=db)
     self.notification_client = None
     if notification_uri is not None:
         self.notification_client = NotificationClient(notification_uri, default_namespace=DEFAULT_NAMESPACE)
示例#9
0
 def run_ai_flow_function(self):
     """Once '1-job-name' is executing in workflow_1, emit the two events
     the downstream jobs wait on, wait for all three task executions,
     then stop the scheduler."""
     client = NotificationClient(server_uri="localhost:{}".format(
         self.port),
                                 default_namespace="default",
                                 sender='1-job-name')
     while True:
         with create_session() as session:
             # Poll until the first job's TaskExecution row exists.
             tes = session.query(TaskExecution).filter(
                 TaskExecution.dag_id == 'workflow_1',
                 TaskExecution.task_id == '1-job-name').all()
             if len(tes) > 0:
                 time.sleep(5)
                 client.send_event(
                     BaseEvent(key='key_1',
                               value='value_1',
                               event_type='UNDEFINED'))
                 client.send_event(
                     BaseEvent(key='key_2',
                               value='value_2',
                               event_type='UNDEFINED'))
                 # Wait until the whole workflow has three task executions.
                 while True:
                     with create_session() as session_2:
                         tes_2 = session_2.query(TaskExecution).filter(
                             TaskExecution.dag_id == 'workflow_1').all()
                         if len(tes_2) == 3:
                             break
                         else:
                             time.sleep(1)
                 break
             else:
                 time.sleep(1)
     time.sleep(3)
     client.send_event(StopSchedulerEvent(job_id=0).to_event())
示例#10
0
 def __init__(self,
              dag_directory,
              server_uri=None,
              max_runs=-1,
              refresh_dag_dir_interval=conf.getint(
                  'scheduler', 'refresh_dag_dir_interval', fallback=30),
              *args,
              **kwargs):
     """Wire up the event-based scheduler job.

     :param dag_directory: directory scanned for dag files.
     :param server_uri: notification server URI, shared by the DagTrigger
         and the scheduler's own NotificationClient.
     :param max_runs: forwarded to DagTrigger (-1 presumably means
         unlimited — confirm against DagTrigger).
     :param refresh_dag_dir_interval: seconds between dag-dir refreshes.
         NOTE(review): this default is evaluated once at import time.
     """
     super().__init__(*args, **kwargs)
     # Central message queue shared by trigger, event manager, executor
     # and scheduler.
     self.mailbox: Mailbox = Mailbox()
     self.dag_trigger: DagTrigger = DagTrigger(
         dag_directory=dag_directory,
         max_runs=max_runs,
         dag_ids=None,
         pickle_dags=False,
         mailbox=self.mailbox,
         refresh_dag_dir_interval=refresh_dag_dir_interval,
         notification_service_uri=server_uri)
     self.task_event_manager = DagRunEventManager(self.mailbox)
     self.executor.set_mailbox(self.mailbox)
     self.notification_client: NotificationClient = NotificationClient(
         server_uri=server_uri, default_namespace=SCHEDULER_NAMESPACE)
     self.scheduler: EventBasedScheduler = EventBasedScheduler(
         self.id, self.mailbox, self.task_event_manager, self.executor,
         self.notification_client)
     self.last_scheduling_id = self._last_scheduler_job_id()
示例#11
0
 def setUp(self):
     """Reset every scheduler-related table, then start a fresh in-memory
     notification master and a client bound to it."""
     for cleanup in (db.clear_db_jobs, db.clear_db_dags,
                     db.clear_db_serialized_dags, db.clear_db_runs,
                     db.clear_db_task_execution, db.clear_db_message):
         cleanup()
     self.scheduler = None
     self.port = 50102
     self.storage = MemoryEventStorage()
     self.master = NotificationMaster(NotificationService(self.storage),
                                      self.port)
     self.master.run()
     self.client = NotificationClient(
         server_uri="localhost:{}".format(self.port),
         default_namespace="test_namespace")
     # Give the master a moment to come up before tests use the client.
     time.sleep(1)
示例#12
0
    def stop_workflow(self, workflow_name) -> bool:
        """
        Stop the workflow. No more workflow execution(Airflow dag_run) would be scheduled and all running jobs would be stopped.

        :param workflow_name: workflow name
        :return: True if succeed
        """
        # TODO For now, simply return True as long as message is sent successfully,
        #  actually we need a response from
        try:
            NotificationClient(self.server_uri,
                               SCHEDULER_NAMESPACE).send_event(
                StopDagEvent(workflow_name).to_event())
        except Exception:
            return False
        return True
示例#13
0
 def run_trigger_dag_function(self):
     """Repeatedly parse and schedule 'trigger_dag' until task_1 has an
     execution record, then stop the scheduler."""
     ns_client = NotificationClient(server_uri="localhost:{}".format(
         self.port),
                                    default_namespace="")
     client = EventSchedulerClient(ns_client=ns_client)
     while True:
         with create_session() as session:
             tes = session.query(TaskExecution).filter(
                 TaskExecution.dag_id == 'trigger_dag',
                 TaskExecution.task_id == 'task_1').all()
             if len(tes) > 0:
                 break
             else:
                 # Not there yet: re-trigger dag parsing and schedule a run.
                 client.trigger_parse_dag()
                 result = client.schedule_dag('trigger_dag')
                 print('result {}'.format(result.dagrun_id))
             time.sleep(5)
     ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())
示例#14
0
 def setUpClass(cls):
     """Start a Mongo-backed notification master and a shared client."""
     cls.storage = MongoEventStorage(host="127.0.0.1",
                                     port=27017,
                                     db="test")
     cls.master = NotificationMaster(NotificationService(cls.storage))
     cls.master.run()
     cls.client = NotificationClient(server_uri="localhost:50051")
 def _execute(self):
     """
     1. Init the DagRun route.
     2. Start the executor.
     3. Option of start the notification master.
     4. Create the notification client.
     5. Start the DagTrigger.
     6. Run the scheduler event loop.
     :return:
     """
     notification_client = None
     try:
         self._init_route()
         self.executor.set_use_nf(True)
         self.executor.start()
         self.dag_trigger = DagTrigger(
             subdir=self.subdir,
             mailbox=self.mail_box,
             run_duration=self.run_duration,
             using_sqlite=self.using_sqlite,
             num_runs=self.num_runs,
             processor_poll_interval=self._processor_poll_interval)
         if self.use_local_nf:
             # Embedded notification master backed by the DB event storage.
             self.notification_master \
                 = NotificationMaster(service=NotificationService(EventModelStorage()), port=self.nf_port)
             self.notification_master.run()
             self.log.info("start notification service {0}".format(
                 self.nf_port))
             notification_client = NotificationClient(
                 server_uri="localhost:{0}".format(self.nf_port))
         else:
             # Connect to an externally managed notification server instead.
             notification_client \
                 = NotificationClient(server_uri="{0}:{1}".format(self.nf_host, self.nf_port))
         # Forward every incoming event into the scheduler mailbox.
         notification_client.start_listen_events(
             watcher=SCEventWatcher(self.mail_box))
         self.dag_trigger.start()
         self._start_executor_heartbeat()
         self._run_event_loop()
     except Exception as e:
         self.log.exception("Exception when executing _execute {0}".format(
             str(e)))
     finally:
         # Tear down components roughly in reverse order of startup.
         self.running = False
         # NOTE(review): "heartheat" looks misspelled but must match the
         # method's definition elsewhere in this class — confirm.
         self._stop_executor_heartheat()
         if self.dag_trigger is not None:
             self.dag_trigger.stop()
         if notification_client is not None:
             notification_client.stop_listen_events()
         if self.notification_master is not None:
             self.notification_master.stop()
         self.executor.end()
         self.log.info("Exited execute event scheduler")
 def wait_for_master_started(cls, server_uri="localhost:50051"):
     """Block until a NotificationClient can be created for *server_uri*.

     Retries up to 100 times with a 10s pause between attempts and raises
     chained from the last connection error if the server never comes up.
     """
     last_exception = None
     # The attempt index itself is unused, hence ``_``.
     for _ in range(100):
         try:
             return NotificationClient(server_uri=server_uri,
                                       enable_ha=True)
         except Exception as e:
             # Record the failure *before* sleeping so an interrupt during
             # the sleep cannot lose the underlying cause.
             last_exception = e
             time.sleep(10)
     raise Exception("The server %s is unavailable." %
                     server_uri) from last_exception
示例#17
0
 def run_airflow_dag_function(self):
     """Deploy a one-task workflow to Airflow, run it, print the dagrun id,
     then stop the scheduler."""
     # waiting parsed dag file done
     from datetime import datetime
     ns_client = NotificationClient(server_uri='localhost:50051')
     with af.global_config_file(test_util.get_workflow_config_file()):
         with af.config('task_1'):
             cmd_executor = af.user_define_operation(
                 output_num=0,
                 executor=CmdExecutor(cmd_line=['echo "hello world!"']))
     af.deploy_to_airflow(test_util.get_project_path(),
                          dag_id='test_dag_111',
                          default_args={
                              'schedule_interval': None,
                              'start_date': datetime(2025, 12, 1),
                          })
     context = af.run(project_path=test_util.get_project_path(),
                      dag_id='test_dag_111',
                      scheduler_type=SchedulerType.AIRFLOW)
     print(context.dagrun_id)
     time.sleep(5)
     ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())
示例#18
0
 def run_event_task_function(self):
     """Once task_1 of 'event_dag' has executed, send the 'start' event
     that unblocks task_2, wait for it to appear, then stop the
     scheduler."""
     client = NotificationClient(server_uri="localhost:{}".format(
         self.port),
                                 default_namespace="")
     while True:
         with create_session() as session:
             # Poll until task_1 has at least one TaskExecution row.
             tes = session.query(TaskExecution).filter(
                 TaskExecution.dag_id == 'event_dag',
                 TaskExecution.task_id == 'task_1').all()
             if len(tes) > 0:
                 time.sleep(5)
                 client.send_event(
                     BaseEvent(key='start',
                               value='',
                               event_type='',
                               namespace=''))
                 # Inner poll: wait for task_2's execution row.
                 while True:
                     with create_session() as session_2:
                         tes_2 = session_2.query(TaskExecution).filter(
                             TaskExecution.dag_id == 'event_dag',
                             TaskExecution.task_id == 'task_2').all()
                         if len(tes_2) > 0:
                             break
                         else:
                             time.sleep(1)
                 break
             else:
                 time.sleep(1)
     client.send_event(StopSchedulerEvent(job_id=0).to_event())
class MemoryStorageTest(unittest.TestCase, NotificationTest):
    """Runs the shared NotificationTest suite against MemoryEventStorage."""

    @classmethod
    def set_up_class(cls):
        # Kept separate from setUpClass — presumably so subclasses can
        # swap in a different storage; confirm against sibling test classes.
        cls.storage = MemoryEventStorage()
        cls.master = NotificationMaster(NotificationService(cls.storage))
        cls.master.run()

    @classmethod
    def setUpClass(cls):
        cls.set_up_class()

    @classmethod
    def tearDownClass(cls):
        cls.master.stop()

    def setUp(self):
        # Fresh storage and client for every test.
        self.storage.clean_up()
        self.client = NotificationClient(server_uri="localhost:50051")

    def tearDown(self):
        # Stop both the multi-event and the single-event listener threads.
        self.client.stop_listen_events()
        self.client.stop_listen_event()
示例#20
0
    def test_send_listening_on_different_server(self):
        """Events sent through a second HA server must reach a listener on
        the first client that starts from the first event's version."""
        event_list = []

        class TestWatch(EventWatcher):
            # Collects every received event into the shared list.
            def __init__(self, event_list) -> None:
                super().__init__()
                self.event_list = event_list

            def process(self, events: List[BaseEvent]):
                self.event_list.extend(events)

        self.master2 = self.start_master("localhost", "50052")
        self.wait_for_new_members_detected("localhost:50052")
        another_client = NotificationClient(server_uri="localhost:50052")
        try:
            event1 = another_client.send_event(BaseEvent(key="key1", value="value1"))
            # Listening starts *after* event1's version, so only the two
            # subsequent events should be collected.
            self.client.start_listen_events(watcher=TestWatch(event_list), version=event1.version)
            another_client.send_event(BaseEvent(key="key2", value="value2"))
            another_client.send_event(BaseEvent(key="key3", value="value3"))
        finally:
            self.client.stop_listen_events()
        self.assertEqual(2, len(event_list))
示例#21
0
 def change_state(self, key, state):
     """Record a task state change: push it to the notification service
     when ``use_nf`` is set, otherwise buffer it locally for polling.

     :param key: (dag_id, task_id, execution_date, try_number) tuple.
     :param state: the new task state.
     """
     self.log.debug("Changing state: %s %s", key, state)
     self.running.pop(key, None)
     if self.use_nf:
         # Lazily create the notification client on first use.
         if self.client is None:
             self.client: NotificationClient = NotificationClient(
                 server_uri="{0}:{1}".format(self.nf_host,
                                             self.nf_port))
         dag_id, task_id, execution_date, try_number = key
         self.client.send_event(TaskStatusEvent(
             task_instance_key=TaskInstanceHelper.to_task_key(dag_id, task_id, execution_date),
             status=TaskInstanceHelper.to_event_value(state, try_number)))
     else:
         self.event_buffer[key] = state
示例#22
0
class BaseExecutorTest(unittest.TestCase):
    """Tests for BaseExecutor's event buffer and notification integration."""

    def setUp(self):
        # Fresh DB-backed notification master plus a client for each test.
        clear_db_event_model()
        self.master = NotificationMaster(service=NotificationService(EventModelStorage()))
        self.master.run()
        self.client = NotificationClient(server_uri="localhost:50051")

    def tearDown(self) -> None:
        self.master.stop()

    def test_get_event_buffer(self):
        """Filtered reads drain only the matching entries from the buffer."""
        executor = BaseExecutor()

        date = datetime.utcnow()
        try_number = 1
        key1 = ("my_dag1", "my_task1", date, try_number)
        key2 = ("my_dag2", "my_task1", date, try_number)
        key3 = ("my_dag2", "my_task2", date, try_number)
        state = State.SUCCESS
        executor.event_buffer[key1] = state
        executor.event_buffer[key2] = state
        executor.event_buffer[key3] = state

        # NOTE(review): these assertions imply get_event_buffer *removes*
        # what it returns (buffer is empty at the end) — confirm against
        # BaseExecutor's implementation.
        self.assertEqual(len(executor.get_event_buffer(("my_dag1",))), 1)
        self.assertEqual(len(executor.get_event_buffer()), 2)
        self.assertEqual(len(executor.event_buffer), 0)

    @mock.patch('airflow.executors.base_executor.BaseExecutor.sync')
    @mock.patch('airflow.executors.base_executor.BaseExecutor.trigger_tasks')
    @mock.patch('airflow.settings.Stats.gauge')
    def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync):
        """heartbeat() must publish the three executor gauges."""
        executor = BaseExecutor()
        executor.heartbeat()
        calls = [mock.call('executor.open_slots', mock.ANY),
                 mock.call('executor.queued_tasks', mock.ANY),
                 mock.call('executor.running_tasks', mock.ANY)]
        mock_stats_gauge.assert_has_calls(calls)

    def test_use_nf_executor(self):
        """With use_nf on, every state change becomes a notification event."""
        executor = BaseExecutor()
        executor.set_use_nf(True)
        executor.change_state('key', State.RUNNING)
        executor.change_state('key', State.SUCCESS)
        events = self.client.list_all_events(1)
        self.assertEqual(2, len(events))
 def _send_request_and_receive_response(self, server_uri, file_path):
     """Send a PARSE_DAG_REQUEST for *file_path* and assert the matching
     PARSE_DAG_RESPONSE echoes the same key and value.

     The key embeds ``time.time_ns()`` so concurrent requests for the same
     file cannot collide.
     """
     key = '{}_{}'.format(file_path, time.time_ns())
     client = NotificationClient(server_uri=server_uri,
                                 default_namespace=SCHEDULER_NAMESPACE)
     event = BaseEvent(key=key,
                       event_type=SchedulerInnerEventType.PARSE_DAG_REQUEST.value,
                       value=file_path)
     client.send_event(event)
     watcher: ResponseWatcher = ResponseWatcher()
     client.start_listen_event(key=key,
                               event_type=SchedulerInnerEventType.PARSE_DAG_RESPONSE.value,
                               watcher=watcher)
     res: BaseEvent = watcher.get_result()
     # assertEquals is a deprecated alias of assertEqual. The second check
     # previously compared event.value (set to file_path above) with
     # file_path — a tautology; it must inspect the *response* value.
     self.assertEqual(event.key, res.key)
     self.assertEqual(res.value, file_path)
class HaDbStorageTest(unittest.TestCase, NotificationTest):
    """
    This test is used to ensure the high availability would not break the original functionality.
    """
    @classmethod
    def set_up_class(cls):
        # Three HA masters sharing one DB-backed storage.
        cls.storage = DbEventStorage()
        cls.master1 = start_ha_master("localhost", 50051)
        # The server startup is asynchronous, we need to wait for a while
        # to ensure it writes its metadata to the db.
        time.sleep(0.1)
        cls.master2 = start_ha_master("localhost", 50052)
        time.sleep(0.1)
        cls.master3 = start_ha_master("localhost", 50053)
        time.sleep(0.1)

    @classmethod
    def setUpClass(cls):
        cls.set_up_class()

    @classmethod
    def tearDownClass(cls):
        cls.master1.stop()
        cls.master2.stop()
        cls.master3.stop()

    def setUp(self):
        # Client deliberately targets the *second* master so tests exercise
        # member discovery rather than the first/primary server.
        self.storage.clean_up()
        self.client = NotificationClient(server_uri="localhost:50052",
                                         enable_ha=True,
                                         list_member_interval_ms=1000,
                                         retry_timeout_ms=10000)

    def tearDown(self):
        # Stop both listener kinds, then disable HA background threads.
        self.client.stop_listen_events()
        self.client.stop_listen_event()
        self.client.disable_high_availability()
示例#25
0
class ModelCenterService(model_center_service_pb2_grpc.ModelCenterServiceServicer):
    """gRPC servicer for the model center.

    Persists registered models / model versions in a Mongo- or SQL-backed
    store and, when configured with a notification URI, publishes an event
    for every model-version change.
    """

    def __init__(self, store_uri, notification_uri=None):
        # Choose the backing store from the URI's database engine.
        db_engine = extract_db_engine_from_uri(store_uri)
        if DBType.value_of(db_engine) == DBType.MONGODB:
            username, password, host, port, db = parse_mongo_uri(store_uri)
            self.model_repo_store = MongoStore(host=host,
                                               port=int(port),
                                               username=username,
                                               password=password,
                                               db=db)
        else:
            self.model_repo_store = SqlAlchemyStore(store_uri)
        # Event publishing is optional; None disables it everywhere below.
        self.notification_client = None
        if notification_uri is not None:
            self.notification_client = NotificationClient(notification_uri, default_namespace=DEFAULT_NAMESPACE)

    @catch_exception
    def createRegisteredModel(self, request, context):
        """Create a registered model from the request's name and description."""
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.create_registered_model(registered_model_param.model_name,
                                                                              registered_model_param.model_desc)
        return _wrap_response(registered_model_meta.to_meta_proto())

    @catch_exception
    def updateRegisteredModel(self, request, context):
        """Update a registered model; responds with None when nothing matched."""
        model_meta_param = RegisteredModel.from_proto(request)
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.update_registered_model(
            RegisteredModel(model_meta_param.model_name),
            registered_model_param.model_name,
            registered_model_param.model_desc)
        return _wrap_response(None if registered_model_meta is None else registered_model_meta.to_meta_proto())

    @catch_exception
    def deleteRegisteredModel(self, request, context):
        """Delete a registered model by name and echo back the request meta."""
        model_meta_param = RegisteredModel.from_proto(request)
        self.model_repo_store.delete_registered_model(RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(request.model_meta)

    @catch_exception
    def listRegisteredModels(self, request, context):
        """Return metas of every registered model."""
        registered_models = self.model_repo_store.list_registered_models()
        return _wrap_response(RegisteredModelMetas(registered_models=[registered_model.to_meta_proto()
                                                                      for registered_model in registered_models]))

    @catch_exception
    def getRegisteredModelDetail(self, request, context):
        """Return the full detail of one registered model, or None."""
        model_meta_param = ModelVersion.from_proto(request)
        registered_model_detail = self.model_repo_store.get_registered_model_detail(
            RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(None if registered_model_detail is None else registered_model_detail.to_detail_proto())

    @catch_exception
    def createModelVersion(self, request, context):
        """Create a model version; publish a stage-specific event if enabled."""
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.create_model_version(model_meta_param.model_name,
                                                                        model_version_param.model_path,
                                                                        model_version_param.model_type,
                                                                        model_version_param.version_desc,
                                                                        model_version_param.current_stage)
        # Event type is derived from the version's stage.
        event_type = MODEL_VERSION_TO_EVENT_TYPE.get(ModelVersionStage.from_string(model_version_param.current_stage))
        if self.notification_client is not None:
            self.notification_client.send_event(BaseEvent(model_version_meta.model_name,
                                                          json.dumps(model_version_meta.__dict__),
                                                          event_type))
        return _wrap_response(model_version_meta.to_meta_proto())

    @catch_exception
    def updateModelVersion(self, request, context):
        """Update a model version; publish an event only on a stage change."""
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.update_model_version(model_meta_param,
                                                                        model_version_param.model_path,
                                                                        model_version_param.model_type,
                                                                        model_version_param.version_desc,
                                                                        model_version_param.current_stage)
        if model_version_param.current_stage is not None:
            event_type = MODEL_VERSION_TO_EVENT_TYPE.get(
                ModelVersionStage.from_string(model_version_param.current_stage))
            if self.notification_client is not None:
                self.notification_client.send_event(BaseEvent(model_version_meta.model_name,
                                                              json.dumps(model_version_meta.__dict__),
                                                              event_type))
        return _wrap_response(None if model_version_meta is None else model_version_meta.to_meta_proto())

    @catch_exception
    def deleteModelVersion(self, request, context):
        """Delete a model version and publish a MODEL_DELETED event if enabled."""
        model_meta_param = ModelVersion.from_proto(request)
        self.model_repo_store.delete_model_version(model_meta_param)
        if self.notification_client is not None:
            self.notification_client.send_event(BaseEvent(model_meta_param.model_name,
                                                          json.dumps(model_meta_param.__dict__),
                                                          ModelVersionEventType.MODEL_DELETED))
        return _wrap_response(request.model_meta)

    @catch_exception
    def getModelVersionDetail(self, request, context):
        """Return the meta of one model version, or None when absent."""
        model_meta_param = ModelVersion.from_proto(request)
        model_version_meta = self.model_repo_store.get_model_version_detail(model_meta_param)
        return _wrap_response(None if model_version_meta is None else model_version_meta.to_meta_proto())
示例#26
0
 def execute(self, context):
     """Send the operator's configured event to its notification server."""
     NotificationClient(server_uri=self.uri).send_event(event=self.event)
 def setUp(self):
     """Clean storage and connect an HA-enabled client to the second master."""
     self.storage.clean_up()
     self.client = NotificationClient(
         server_uri="localhost:50052",
         enable_ha=True,
         list_member_interval_ms=1000,
         retry_timeout_ms=10000)
示例#28
0
class TestEventBasedScheduler(unittest.TestCase):
    """Integration tests for the event-based scheduler.

    Each test starts a NotificationMaster + client in setUp, runs the
    scheduler (blocking) on the main thread via start_scheduler(), and drives
    it from a daemon thread/process that watches the DB, sends events, and
    finally posts a StopSchedulerEvent so the scheduler loop exits.

    Fix over the original: ``Thread.setDaemon(True)`` (deprecated since
    Python 3.10) is replaced with the ``daemon`` attribute, and redundant
    method-local ``import threading`` statements are dropped (the module-level
    import is already used elsewhere in this class).
    """

    def setUp(self):
        """Reset every scheduler DB table and start a notification master/client."""
        db.clear_db_jobs()
        db.clear_db_dags()
        db.clear_db_serialized_dags()
        db.clear_db_runs()
        db.clear_db_task_execution()
        db.clear_db_message()
        self.scheduler = None
        self.port = 50102
        self.storage = MemoryEventStorage()
        self.master = NotificationMaster(NotificationService(self.storage),
                                         self.port)
        self.master.run()
        self.client = NotificationClient(server_uri="localhost:{}".format(
            self.port),
                                         default_namespace="test_namespace")
        # Give the gRPC server a moment to come up before tests use the client.
        time.sleep(1)

    def tearDown(self):
        """Stop the notification master started in setUp."""
        self.master.stop()

    def _get_task_instance(self, dag_id, task_id, session):
        """Return the first TaskInstance matching (dag_id, task_id), or None."""
        return session.query(TaskInstance).filter(
            TaskInstance.dag_id == dag_id,
            TaskInstance.task_id == task_id).first()

    def schedule_task_function(self):
        """Driver thread for test_event_based_scheduler.

        Waits until both tasks are SCHEDULED, sends a 'start' event, waits
        until both are RUNNING, then sends 'stop' and 'restart' events and
        waits for the expected KILLED/RUNNING end states before stopping
        the scheduler.
        """
        stopped = False
        while not stopped:
            with create_session() as session:
                ti_sleep_1000_secs = self._get_task_instance(
                    EVENT_BASED_SCHEDULER_DAG, 'sleep_1000_secs', session)
                ti_python_sleep = self._get_task_instance(
                    EVENT_BASED_SCHEDULER_DAG, 'python_sleep', session)
                if ti_sleep_1000_secs and ti_sleep_1000_secs.state == State.SCHEDULED and \
                   ti_python_sleep and ti_python_sleep.state == State.SCHEDULED:
                    self.client.send_event(
                        BaseEvent(key='start',
                                  value='',
                                  event_type='',
                                  namespace='test_namespace'))

                    while not stopped:
                        ti_sleep_1000_secs.refresh_from_db()
                        ti_python_sleep.refresh_from_db()
                        if ti_sleep_1000_secs and ti_sleep_1000_secs.state == State.RUNNING and \
                           ti_python_sleep and ti_python_sleep.state == State.RUNNING:
                            time.sleep(10)
                            break
                        else:
                            time.sleep(1)
                    self.client.send_event(
                        BaseEvent(key='stop',
                                  value='',
                                  event_type=UNDEFINED_EVENT_TYPE,
                                  namespace='test_namespace'))
                    self.client.send_event(
                        BaseEvent(key='restart',
                                  value='',
                                  event_type=UNDEFINED_EVENT_TYPE,
                                  namespace='test_namespace'))
                    while not stopped:
                        ti_sleep_1000_secs.refresh_from_db()
                        ti_python_sleep.refresh_from_db()
                        if ti_sleep_1000_secs and ti_sleep_1000_secs.state == State.KILLED and \
                           ti_python_sleep and ti_python_sleep.state == State.RUNNING:
                            stopped = True
                        else:
                            time.sleep(1)
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_event_based_scheduler(self):
        """Run the full start/stop/restart scenario against the event DAG."""
        t = threading.Thread(target=self.schedule_task_function)
        t.daemon = True
        t.start()
        self.start_scheduler('../../dags/test_event_based_scheduler.py')

    def test_replay_message(self):
        """An event sent to the server is persisted as a Message and can be
        replayed (deserialized) for the scheduling job that missed it."""
        key = "stop"
        mailbox = Mailbox()
        mailbox.set_scheduling_job_id(1234)
        watcher = SchedulerEventWatcher(mailbox)
        self.client.start_listen_events(watcher=watcher,
                                        start_time=int(time.time() * 1000),
                                        version=None)
        self.send_event(key)
        msg: BaseEvent = mailbox.get_message()
        self.assertEqual(msg.key, key)
        with create_session() as session:
            msg_from_db = session.query(Message).first()
            # A different job id must see no unprocessed messages.
            expect_non_unprocessed = EventBasedScheduler.get_unprocessed_message(
                1000)
            self.assertEqual(0, len(expect_non_unprocessed))
            unprocessed = EventBasedScheduler.get_unprocessed_message(1234)
            self.assertEqual(unprocessed[0].serialized_message,
                             msg_from_db.data)
        deserialized_data = pickle.loads(msg_from_db.data)
        self.assertEqual(deserialized_data.key, key)
        self.assertEqual(msg, deserialized_data)

    def send_event(self, key):
        """Send a BaseEvent with the given key and assert the server echoes it."""
        event = self.client.send_event(
            BaseEvent(key=key, event_type=UNDEFINED_EVENT_TYPE,
                      value="value1"))
        self.assertEqual(key, event.key)

    @provide_session
    def get_task_execution(self, dag_id, task_id, session):
        """Return all TaskExecution rows for (dag_id, task_id)."""
        return session.query(TaskExecution).filter(
            TaskExecution.dag_id == dag_id,
            TaskExecution.task_id == task_id).all()

    @provide_session
    def get_latest_job_id(self, session):
        """Return the id of the most recently created BaseJob."""
        return session.query(BaseJob).order_by(sqlalchemy.desc(
            BaseJob.id)).first().id

    def start_scheduler(self, file_path):
        """Run the event-based scheduler over one DAG file (blocks until a
        StopSchedulerEvent arrives)."""
        self.scheduler = EventBasedSchedulerJob(
            dag_directory=file_path,
            server_uri="localhost:{}".format(self.port),
            executor=LocalExecutor(3),
            max_runs=-1,
            refresh_dag_dir_interval=30)
        print("scheduler starting")
        self.scheduler.run()

    def wait_for_running(self):
        """Block until start_scheduler() has assigned self.scheduler."""
        while True:
            if self.scheduler is not None:
                time.sleep(5)
                break
            else:
                time.sleep(1)

    def wait_for_task_execution(self, dag_id, task_id, expected_num):
        """Poll (up to ~200s) until the task has exactly expected_num executions."""
        result = False
        check_nums = 100
        while check_nums > 0:
            time.sleep(2)
            check_nums = check_nums - 1
            tes = self.get_task_execution(dag_id, task_id)
            if len(tes) == expected_num:
                result = True
                break
        self.assertTrue(result)

    def wait_for_task(self, dag_id, task_id, expected_state):
        """Poll (up to ~200s) until the task instance reaches expected_state."""
        result = False
        check_nums = 100
        while check_nums > 0:
            time.sleep(2)
            check_nums = check_nums - 1
            with create_session() as session:
                ti = session.query(TaskInstance).filter(
                    TaskInstance.dag_id == dag_id,
                    TaskInstance.task_id == task_id).first()
            if ti and ti.state == expected_state:
                result = True
                break
        self.assertTrue(result)

    def test_notification(self):
        """Smoke test: the client can send a plain event."""
        self.client.send_event(BaseEvent(key='a', value='b'))

    def run_a_task_function(self):
        """Driver: stop the scheduler once 'single.task_1' has executed."""
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'single',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    break
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_a_task(self):
        """A single-task DAG executes its task exactly once."""
        t = threading.Thread(target=self.run_a_task_function, args=())
        t.daemon = True
        t.start()
        self.start_scheduler('../../dags/test_single_task_dag.py')
        tes: List[TaskExecution] = self.get_task_execution("single", "task_1")
        self.assertEqual(len(tes), 1)

    def run_event_task_function(self):
        """Driver: after task_1 runs, send the 'start' event that triggers
        task_2, wait for it, then stop the scheduler."""
        client = NotificationClient(server_uri="localhost:{}".format(
            self.port),
                                    default_namespace="")
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'event_dag',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    time.sleep(5)
                    client.send_event(
                        BaseEvent(key='start',
                                  value='',
                                  event_type='',
                                  namespace=''))
                    while True:
                        with create_session() as session_2:
                            tes_2 = session_2.query(TaskExecution).filter(
                                TaskExecution.dag_id == 'event_dag',
                                TaskExecution.task_id == 'task_2').all()
                            if len(tes_2) > 0:
                                break
                            else:
                                time.sleep(1)
                    break
                else:
                    time.sleep(1)
        client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_event_task(self):
        """An event-triggered downstream task runs exactly once."""
        t = threading.Thread(target=self.run_event_task_function, args=())
        t.daemon = True
        t.start()
        self.start_scheduler('../../dags/test_event_task_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "event_dag", "task_2")
        self.assertEqual(len(tes), 1)

    def run_trigger_dag_function(self):
        """Driver: repeatedly parse + schedule 'trigger_dag' until its task_1
        has executed, then stop the scheduler."""
        ns_client = NotificationClient(server_uri="localhost:{}".format(
            self.port),
                                       default_namespace="")
        client = EventSchedulerClient(ns_client=ns_client)
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'trigger_dag',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    break
                else:
                    client.trigger_parse_dag()
                    result = client.schedule_dag('trigger_dag')
                    print('result {}'.format(result.dagrun_id))
                time.sleep(5)
        ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_trigger_dag(self):
        """Externally triggering a DAG via EventSchedulerClient runs it once."""
        import multiprocessing
        p = multiprocessing.Process(target=self.run_trigger_dag_function,
                                    args=())
        p.start()
        self.start_scheduler('../../dags/test_run_trigger_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "trigger_dag", "task_1")
        self.assertEqual(len(tes), 1)

    def run_no_dag_file_function(self):
        """Driver: schedule a DAG id with no backing file, then stop."""
        ns_client = NotificationClient(server_uri="localhost:{}".format(
            self.port),
                                       default_namespace="")
        client = EventSchedulerClient(ns_client=ns_client)
        with create_session() as session:
            client.trigger_parse_dag()
            result = client.schedule_dag('no_dag')
            print('result {}'.format(result.dagrun_id))
            time.sleep(5)
        ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_no_dag_file_trigger_dag(self):
        """Triggering a nonexistent DAG id must execute nothing."""
        import multiprocessing
        p = multiprocessing.Process(target=self.run_no_dag_file_function,
                                    args=())
        p.start()
        self.start_scheduler('../../dags/test_run_trigger_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "trigger_dag", "task_1")
        self.assertEqual(len(tes), 0)

    def run_trigger_task_function(self):
        """Driver: schedule 'trigger_task', and once task_1 has run, start
        task_2 explicitly via schedule_task, wait for it, then stop."""
        # waiting parsed dag file done,
        time.sleep(5)
        ns_client = NotificationClient(server_uri="localhost:{}".format(
            self.port),
                                       default_namespace="a")
        client = EventSchedulerClient(ns_client=ns_client)
        execution_context = client.schedule_dag('trigger_task')
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'trigger_task',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    client.schedule_task('trigger_task', 'task_2',
                                         SchedulingAction.START,
                                         execution_context)
                    while True:
                        with create_session() as session_2:
                            tes_2 = session_2.query(TaskExecution).filter(
                                TaskExecution.dag_id == 'trigger_task',
                                TaskExecution.task_id == 'task_2').all()
                            if len(tes_2) > 0:
                                break
                            else:
                                time.sleep(1)
                    break
                else:
                    time.sleep(1)
        ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_task_trigger_dag(self):
        """Explicit per-task scheduling (START action) runs task_2 once."""
        t = threading.Thread(target=self.run_trigger_task_function, args=())
        t.daemon = True
        t.start()
        self.start_scheduler('../../dags/test_task_trigger_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "trigger_task", "task_2")
        self.assertEqual(len(tes), 1)

    def run_ai_flow_function(self):
        """Driver for the AIFlow-style DAG: after '1-job-name' executes,
        send its two completion events and wait for all three task
        executions of 'workflow_1' before stopping."""
        client = NotificationClient(server_uri="localhost:{}".format(
            self.port),
                                    default_namespace="default",
                                    sender='1-job-name')
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'workflow_1',
                    TaskExecution.task_id == '1-job-name').all()
                if len(tes) > 0:
                    time.sleep(5)
                    client.send_event(
                        BaseEvent(key='key_1',
                                  value='value_1',
                                  event_type='UNDEFINED'))
                    client.send_event(
                        BaseEvent(key='key_2',
                                  value='value_2',
                                  event_type='UNDEFINED'))
                    while True:
                        with create_session() as session_2:
                            tes_2 = session_2.query(TaskExecution).filter(
                                TaskExecution.dag_id == 'workflow_1').all()
                            if len(tes_2) == 3:
                                break
                            else:
                                time.sleep(1)
                    break
                else:
                    time.sleep(1)
        time.sleep(3)
        client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_ai_flow_dag(self):
        """An AIFlow workflow DAG runs its first job exactly once."""
        t = threading.Thread(target=self.run_ai_flow_function, args=())
        t.daemon = True
        t.start()
        self.start_scheduler('../../dags/test_aiflow_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "workflow_1", "1-job-name")
        self.assertEqual(len(tes), 1)

    def stop_dag_function(self):
        """Driver: once 'sleep_to_be_stopped' is executing, send StopDagEvent
        and wait for the task to be KILLED, then stop the scheduler."""
        stopped = False
        while not stopped:
            tes = self.get_task_execution(EVENT_BASED_SCHEDULER_DAG,
                                          'sleep_to_be_stopped')
            if tes and len(tes) == 1:
                self.client.send_event(
                    StopDagEvent(EVENT_BASED_SCHEDULER_DAG).to_event())
                while not stopped:
                    tes2 = self.get_task_execution(EVENT_BASED_SCHEDULER_DAG,
                                                   'sleep_to_be_stopped')
                    if tes2[0].state == State.KILLED:
                        stopped = True
                        time.sleep(5)
                    else:
                        time.sleep(1)
            else:
                time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_stop_dag(self):
        """Stopping a DAG pauses it, kills its run, and leaves every task
        instance/execution in SUCCESS or KILLED."""
        t = threading.Thread(target=self.stop_dag_function)
        t.daemon = True
        t.start()
        self.start_scheduler('../../dags/test_event_based_scheduler.py')
        with create_session() as session:
            from airflow.models import DagModel
            dag_model: DagModel = DagModel.get_dagmodel(
                EVENT_BASED_SCHEDULER_DAG)
            self.assertTrue(dag_model.is_paused)
            self.assertEqual(dag_model.get_last_dagrun().state, "killed")
            for ti in session.query(TaskInstance).filter(
                    TaskInstance.dag_id == EVENT_BASED_SCHEDULER_DAG):
                self.assertTrue(ti.state in [State.SUCCESS, State.KILLED])
            for te in session.query(TaskExecution).filter(
                    TaskExecution.dag_id == EVENT_BASED_SCHEDULER_DAG):
                self.assertTrue(te.state in [State.SUCCESS, State.KILLED])

    def run_periodic_task_function(self):
        """Driver: stop the scheduler once the periodic task has run twice."""
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'single',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 1:
                    break
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_periodic_task(self):
        """A periodic task accumulates more than one execution."""
        t = threading.Thread(target=self.run_periodic_task_function, args=())
        t.daemon = True
        t.start()
        self.start_scheduler('../../dags/test_periodic_task_dag.py')
        tes: List[TaskExecution] = self.get_task_execution("single", "task_1")
        self.assertGreater(len(tes), 1)

    def run_one_task_function(self):
        """Driver: fire the trigger event twice and wait for two executions."""
        self.wait_for_running()
        self.client.send_event(BaseEvent(key='a', value='a'))
        time.sleep(5)
        self.client.send_event(BaseEvent(key='a', value='a'))
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'single',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) >= 2:
                    break
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_one_task(self):
        """The same event key can trigger the task multiple times."""
        t = threading.Thread(target=self.run_one_task_function, args=())
        t.daemon = True
        t.start()
        self.start_scheduler('../../dags/test_multiple_trigger_task_dag.py')
Example #29
0
class ModelCenterService(model_center_service_pb2_grpc.ModelCenterServiceServicer):
    """gRPC servicer for the model center.

    Persists registered models and model versions via SqlAlchemyStore and
    broadcasts model-version lifecycle changes through a NotificationClient.

    Fix over the original: ``updateModelVersion`` dereferenced
    ``model_version_meta.model_name`` even though the store may return None
    (the final ``return`` already handles that case) — the event send is now
    guarded on a non-None result.
    """

    def __init__(self, store_uri, server_uri, notification_uri=None):
        """Create the store; notifications go to notification_uri when given,
        otherwise to server_uri."""
        self.model_repo_store = SqlAlchemyStore(store_uri)
        if notification_uri is None:
            self.notification_client = NotificationClient(server_uri)
        else:
            self.notification_client = NotificationClient(notification_uri)

    @catch_exception
    def createRegisteredModel(self, request, context):
        """Register a new model and return its meta proto."""
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.create_registered_model(registered_model_param.model_name,
                                                                              ModelType.Name(
                                                                                  registered_model_param.model_type),
                                                                              registered_model_param.model_desc)
        return _wrap_response(registered_model_meta.to_meta_proto())

    @catch_exception
    def updateRegisteredModel(self, request, context):
        """Update a registered model's name/type/desc; wraps None if absent."""
        model_meta_param = RegisteredModel.from_proto(request)
        registered_model_param = RegisteredModelParam.from_proto(request)
        registered_model_meta = self.model_repo_store.update_registered_model(
            RegisteredModel(model_meta_param.model_name),
            registered_model_param.model_name,
            ModelType.Name(
                registered_model_param.model_type),
            registered_model_param.model_desc)
        return _wrap_response(None if registered_model_meta is None else registered_model_meta.to_meta_proto())

    @catch_exception
    def deleteRegisteredModel(self, request, context):
        """Delete a registered model and echo back the request's model meta."""
        model_meta_param = RegisteredModel.from_proto(request)
        self.model_repo_store.delete_registered_model(RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(request.model_meta)

    @catch_exception
    def listRegisteredModels(self, request, context):
        """Return meta protos of all registered models."""
        registered_models = self.model_repo_store.list_registered_models()
        return _wrap_response(RegisteredModelMetas(registered_models=[registered_model.to_meta_proto()
                                                                      for registered_model in registered_models]))

    @catch_exception
    def getRegisteredModelDetail(self, request, context):
        """Return one registered model's detail proto; wraps None if absent."""
        model_meta_param = ModelVersion.from_proto(request)
        registered_model_detail = self.model_repo_store.get_registered_model_detail(
            RegisteredModel(model_name=model_meta_param.model_name))
        return _wrap_response(None if registered_model_detail is None else registered_model_detail.to_detail_proto())

    @catch_exception
    def createModelVersion(self, request, context):
        """Create a model version and broadcast the stage-mapped event."""
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.create_model_version(model_meta_param.model_name,
                                                                        model_version_param.model_path,
                                                                        model_version_param.model_metric,
                                                                        model_version_param.model_flavor,
                                                                        model_version_param.version_desc,
                                                                        model_version_param.current_stage)
        # Map the version's stage to its lifecycle event type (e.g. CREATED).
        model_type = MODEL_VERSION_TO_EVENT_TYPE.get(
            ModelVersionStage.from_string(model_version_param.current_stage))
        self.notification_client.send_event(BaseEvent(model_version_meta.model_name,
                                                      json.dumps(model_version_meta.__dict__),
                                                      model_type))
        return _wrap_response(model_version_meta.to_meta_proto())

    @catch_exception
    def updateModelVersion(self, request, context):
        """Update a model version; on a stage change of an existing version,
        broadcast the stage-mapped event. Wraps None when nothing matched."""
        model_meta_param = ModelVersion.from_proto(request)
        model_version_param = ModelVersionParam.from_proto(request)
        model_version_meta = self.model_repo_store.update_model_version(model_meta_param,
                                                                        model_version_param.model_path,
                                                                        model_version_param.model_metric,
                                                                        model_version_param.model_flavor,
                                                                        model_version_param.version_desc,
                                                                        model_version_param.current_stage)
        # Guard on model_version_meta too: the store returns None when no
        # matching version exists, and sending an event would then raise.
        if model_version_param.current_stage is not None and model_version_meta is not None:
            model_type = MODEL_VERSION_TO_EVENT_TYPE.get(
                ModelVersionStage.from_string(model_version_param.current_stage))
            self.notification_client.send_event(BaseEvent(model_version_meta.model_name,
                                                          json.dumps(model_version_meta.__dict__),
                                                          model_type))
        return _wrap_response(None if model_version_meta is None else model_version_meta.to_meta_proto())

    @catch_exception
    def deleteModelVersion(self, request, context):
        """Delete a model version and broadcast MODEL_DELETED."""
        model_meta_param = ModelVersion.from_proto(request)
        self.model_repo_store.delete_model_version(model_meta_param)
        self.notification_client.send_event(BaseEvent(model_meta_param.model_name,
                                                      json.dumps(model_meta_param.__dict__),
                                                      ModelVersionEventType.MODEL_DELETED))
        return _wrap_response(request.model_meta)

    @catch_exception
    def getModelVersionDetail(self, request, context):
        """Return one model version's meta proto; wraps None if absent."""
        model_meta_param = ModelVersion.from_proto(request)
        model_version_meta = self.model_repo_store.get_model_version_detail(model_meta_param)
        return _wrap_response(None if model_version_meta is None else model_version_meta.to_meta_proto())
Example #30
0
 def __init__(self, store_uri, server_uri, notification_uri=None):
     """Create the model store; notifications go to notification_uri when
     given, otherwise fall back to server_uri."""
     self.model_repo_store = SqlAlchemyStore(store_uri)
     target_uri = server_uri if notification_uri is None else notification_uri
     self.notification_client = NotificationClient(target_uri)