def start_notification_service(port: int = 50052,
                               db_conn: str = None,
                               enable_ha: bool = False,
                               server_uris: str = None,
                               create_table_if_not_exists: bool = True):
    if db_conn:
        storage = DbEventStorage(db_conn, create_table_if_not_exists)
    else:
        raise Exception(
            'Failed to start notification service without database connection info.')
    if enable_ha:
        if not server_uris:
            raise Exception("When HA enabled, server_uris must be set.")
        ha_storage = DbHighAvailabilityStorage(db_conn=db_conn)
        ha_manager = SimpleNotificationServerHaManager()
        service = HighAvailableNotificationService(
            storage, ha_manager, server_uris, ha_storage, 5000)
        master = NotificationMaster(service=service, port=int(port))
    else:
        master = NotificationMaster(service=NotificationService(storage), port=port)
    master.run(is_block=True)
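# A minimal usage sketch (not from the source): start the notification service
# backed by a local SQLite database. The "sqlite:///notification.db" URI and the
# __main__ guard are illustrative assumptions only.
if __name__ == '__main__':
    start_notification_service(port=50052,
                               db_conn="sqlite:///notification.db",
                               enable_ha=False)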
def start_ha_master(host, port):
    server_uri = host + ":" + str(port)
    storage = DbEventStorage()
    ha_manager = SimpleNotificationServerHaManager()
    ha_storage = DbHighAvailabilityStorage()
    service = HighAvailableNotificationService(
        storage, ha_manager, server_uri, ha_storage)
    master = NotificationMaster(service, port=port)
    master.run()
    return master
def start_master(cls, host, port):
    port = str(port)
    server_uri = host + ":" + port
    storage = DbEventStorage()
    ha_manager = SimpleNotificationServerHaManager()
    ha_storage = DbHighAvailabilityStorage(db_conn=_SQLITE_DB_URI)
    service = HighAvailableNotificationService(
        storage, ha_manager, server_uri, ha_storage, 5000)
    master = NotificationMaster(service, port=int(port))
    master.run()
    return master
def test_user_trigger_parse_dag(self):
    port = 50101
    service_uri = 'localhost:{}'.format(port)
    storage = MemoryEventStorage()
    master = NotificationMaster(NotificationService(storage), port)
    master.run()
    mailbox = Mailbox()
    dag_trigger = DagTrigger("../../dags/test_scheduler_dags.py", -1, [], False,
                             mailbox, 5, service_uri)
    dag_trigger.start()
    message = mailbox.get_message()
    message = SchedulerInnerEventUtil.to_inner_event(message)
    # only one dag is executable
    assert "test_task_start_date_scheduling" == message.dag_id
    sc = EventSchedulerClient(server_uri=service_uri, namespace='a')
    sc.trigger_parse_dag()
    dag_trigger.end()
    master.stop()
class BaseExecutorTest(unittest.TestCase):
    def setUp(self):
        clear_db_event_model()
        self.master = NotificationMaster(service=NotificationService(EventModelStorage()))
        self.master.run()
        self.client = NotificationClient(server_uri="localhost:50051")

    def tearDown(self) -> None:
        self.master.stop()

    def test_get_event_buffer(self):
        executor = BaseExecutor()
        date = datetime.utcnow()
        try_number = 1
        key1 = ("my_dag1", "my_task1", date, try_number)
        key2 = ("my_dag2", "my_task1", date, try_number)
        key3 = ("my_dag2", "my_task2", date, try_number)
        state = State.SUCCESS
        executor.event_buffer[key1] = state
        executor.event_buffer[key2] = state
        executor.event_buffer[key3] = state
        self.assertEqual(len(executor.get_event_buffer(("my_dag1",))), 1)
        self.assertEqual(len(executor.get_event_buffer()), 2)
        self.assertEqual(len(executor.event_buffer), 0)

    @mock.patch('airflow.executors.base_executor.BaseExecutor.sync')
    @mock.patch('airflow.executors.base_executor.BaseExecutor.trigger_tasks')
    @mock.patch('airflow.settings.Stats.gauge')
    def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync):
        executor = BaseExecutor()
        executor.heartbeat()
        calls = [mock.call('executor.open_slots', mock.ANY),
                 mock.call('executor.queued_tasks', mock.ANY),
                 mock.call('executor.running_tasks', mock.ANY)]
        mock_stats_gauge.assert_has_calls(calls)

    def test_use_nf_executor(self):
        executor = BaseExecutor()
        executor.set_use_nf(True)
        executor.change_state('key', State.RUNNING)
        executor.change_state('key', State.SUCCESS)
        events = self.client.list_all_events(1)
        self.assertEqual(2, len(events))
def test_trigger_parse_dag(self):
    import os
    port = 50102
    server_uri = "localhost:{}".format(port)
    storage = MemoryEventStorage()
    master = NotificationMaster(NotificationService(storage), port)
    master.run()
    dag_folder = os.path.abspath(os.path.dirname(__file__)) + "/../../dags"
    mailbox = Mailbox()
    dag_trigger = DagTrigger(dag_folder, -1, [], False, mailbox,
                             notification_service_uri=server_uri)
    dag_trigger.start()
    to_be_triggered = [dag_folder + "/test_event_based_scheduler.py",
                       dag_folder + "/test_event_task_dag.py",
                       dag_folder + "/test_event_based_executor.py",
                       dag_folder + "/test_scheduler_dags.py",
                       ]
    for file in to_be_triggered:
        self._send_request_and_receive_response(server_uri, file)
    dag_trigger.end()
class EventSchedulerJob(SchedulerJob):
    """
    EventSchedulerJob: a scheduler driven by events. It gets messages from the
    notification service and then schedules the tasks affected by those events.
    """
    __mapper_args__ = {'polymorphic_identity': 'EventSchedulerJob'}

    def __init__(self,
                 dag_id=None,
                 dag_ids=None,
                 subdir=settings.DAGS_FOLDER,
                 num_runs=conf.getint('scheduler', 'num_runs', fallback=-1),
                 processor_poll_interval=conf.getfloat(
                     'scheduler', 'processor_poll_interval', fallback=1),
                 use_local_nf=conf.getboolean('scheduler', 'use_local_notification', fallback=True),
                 nf_host=conf.get('scheduler', 'notification_host', fallback='localhost'),
                 nf_port=conf.getint('scheduler', 'notification_port', fallback=50051),
                 unit_test_mode=conf.getboolean('core', 'unit_test_mode', fallback=False),
                 executor_heartbeat_interval=conf.getint(
                     'scheduler', 'executor_heartbeat_interval', fallback=2),
                 run_duration=None,
                 do_pickle=False,
                 log=None,
                 *args, **kwargs):
        super().__init__(dag_id, dag_ids, subdir, num_runs, processor_poll_interval,
                         run_duration, do_pickle, log, *args, **kwargs)
        self.dag_trigger = None
        self.notification_master = None
        self.use_local_nf = use_local_nf
        self.nf_host = nf_host
        self.nf_port = nf_port
        self.mail_box = Mailbox()
        self.running = True
        self.dagrun_route = DagRunRoute()
        self.unit_test_mode = unit_test_mode
        self.executor_heartbeat_interval = executor_heartbeat_interval
        self.heartbeat_thread = None

    @provide_session
    def _get_dag_runs(self, event, session):
        dag_runs = []
        if EventType.is_in(event.event_type) and EventType(event.event_type) != EventType.UNDEFINED:
            if EventType(event.event_type) == EventType.DAG_RUN_EXECUTABLE:
                dag_run_id = int(event.key)
                dag_run = session.query(DagRun).filter(DagRun.id == dag_run_id).first()
                if dag_run is None:
                    self.log.error("DagRun is None id {0}".format(dag_run_id))
                    return dag_runs
                simple_dag = event.simple_dag
                dag_run.pickle_id = None
                # create route
                self.dagrun_route.add_dagrun(dag_run, simple_dag, session)
                dag_runs.append(dag_run)
            elif EventType(event.event_type) == EventType.TASK_STATUS_CHANGED:
                dag_id, task_id, execution_date = TaskInstanceHelper.from_task_key(event.key)
                state, try_num = TaskInstanceHelper.from_event_value(event.value)
                dag_run = self.dagrun_route.find_dagrun(dag_id, execution_date)
                if dag_run is None:
                    return dag_runs
                self._set_task_instance_state(dag_run, dag_id, task_id,
                                              execution_date, state, try_num)
                sync_dag_run = session.query(DagRun).filter(DagRun.id == dag_run.id).first()
                if sync_dag_run.state in State.finished():
                    self.log.info("DagRun finished dag_id {0} execution_date {1} state {2}"
                                  .format(dag_run.dag_id, dag_run.execution_date,
                                          sync_dag_run.state))
                    if self.dagrun_route.find_dagrun_by_id(sync_dag_run.id) is not None:
                        self.dagrun_route.remove_dagrun(dag_run, session)
                        self.log.debug("Route remove dag run {0}".format(sync_dag_run.id))
                        self.mail_box.send_message(
                            DagRunFinishedEvent(dag_run.id, sync_dag_run.state))
                else:
                    dag_runs.append(dag_run)
            elif EventType(event.event_type) == EventType.DAG_RUN_FINISHED:
                self.log.debug("DagRun {0} finished".format(event.key))
            elif EventType(event.event_type) == EventType.STOP_SCHEDULER_CMD:
                if self.unit_test_mode:
                    self.running = False
                return dag_runs
        else:
            runs = self.dagrun_route.find_dagruns_by_event(event_key=event.key,
                                                           event_type=event.event_type)
            if runs is not None:
                for run in runs:
                    task_deps = load_task_dependencies(dag_id=run.dag_id, session=session)
                    tis = run.get_task_instances(session=session)
                    for ti in tis:
                        if ti.task_id not in task_deps:
                            continue
                        if (event.key, event.event_type) in task_deps[ti.task_id]:
                            self.log.debug("{0} handle event {1}".format(ti.task_id, event))
                            ts = TaskState.query_task_state(ti, session=session)
                            handler = ts.event_handler
                            if handler is not None:
                                action = handler.handle_event(event, ti=ti, ts=ts, session=session)
                                ts.action = action
                                session.merge(ts)
                                session.commit()
                                self.log.debug("set task action {0} {1}".format(ti.task_id, action))
                dag_runs.extend(runs)
        session.commit()
        for dag_run in dag_runs:
            run_process_func(target=process_tasks,
                             args=(dag_run,
                                   self.dagrun_route.find_simple_dag(dag_run.id),
                                   self.log,))
        return dag_runs

    @provide_session
    def _sync_event_to_db(self, event: Event, session=None):
        EventModel.sync_event(event=event, session=session)

    @provide_session
    def _run_event_loop(self, session=None):
        """
        The main process event loop.
        :param session: the db session connection.
        :return: None
        """
        while self.running:
            event: Event = self.mail_box.get_message()
            self.log.debug('EVENT: {0}'.format(event))
            if not self.use_local_nf:
                self._sync_event_to_db(session)
            try:
                dag_runs = self._get_dag_runs(event)
                if dag_runs is None or len(dag_runs) == 0:
                    continue
                # create SimpleDagBag
                simple_dags = []
                for dag_run in dag_runs:
                    simple_dags.append(self.dagrun_route.find_simple_dag(dagrun_id=dag_run.id))
                simple_dag_bag = SimpleDagBag(simple_dags)
                if not self._validate_and_run_task_instances(simple_dag_bag=simple_dag_bag):
                    continue
            except Exception as e:
                self.log.exception(str(e))
        # scheduler end
        self.log.debug("_run_event_loop end")

    @provide_session
    def _init_route(self, session=None):
        """
        Init the DagRunRoute object from the db.
        :param session:
        :return:
        """
        # running_dag_runs = session.query(DagRun).filter(DagRun.state == State.RUNNING).all()
        # for dag_run in running_dag_runs:
        #     dag_model = session.query(DagModel).filter(DagModel.dag_id == dag_run.dag_id).first()
        #     dagbag = models.DagBag(dag_model.fileloc)
        #     dag_run.dag = dagbag.get_dag(dag_run.dag_id)
        #     self.dagrun_route.add_dagrun(dag_run)
        # todo init route
        pass

    def _executor_heartbeat(self):
        while self.running:
            self.log.info("executor heartbeat...")
            self.executor.heartbeat()
            time.sleep(self.executor_heartbeat_interval)

    def _start_executor_heartbeat(self):
        self.heartbeat_thread = threading.Thread(target=self._executor_heartbeat, args=())
        self.heartbeat_thread.setDaemon(True)
        self.heartbeat_thread.start()

    def _stop_executor_heartheat(self):
        self.running = False
        if self.heartbeat_thread is not None:
            self.heartbeat_thread.join()

    def _execute(self):
        """
        1. Init the DagRun route.
        2. Start the executor.
        3. Optionally start the notification master.
        4. Create the notification client.
        5. Start the DagTrigger.
        6. Run the scheduler event loop.
        :return:
        """
        notification_client = None
        try:
            self._init_route()
            self.executor.set_use_nf(True)
            self.executor.start()
            self.dag_trigger = DagTrigger(
                subdir=self.subdir,
                mailbox=self.mail_box,
                run_duration=self.run_duration,
                using_sqlite=self.using_sqlite,
                num_runs=self.num_runs,
                processor_poll_interval=self._processor_poll_interval)
            if self.use_local_nf:
                self.notification_master \
                    = NotificationMaster(service=NotificationService(EventModelStorage()),
                                         port=self.nf_port)
                self.notification_master.run()
                self.log.info("start notification service {0}".format(self.nf_port))
                notification_client = NotificationClient(
                    server_uri="localhost:{0}".format(self.nf_port))
            else:
                notification_client \
                    = NotificationClient(server_uri="{0}:{1}".format(self.nf_host, self.nf_port))
            notification_client.start_listen_events(watcher=SCEventWatcher(self.mail_box))
            self.dag_trigger.start()
            self._start_executor_heartbeat()
            self._run_event_loop()
        except Exception as e:
            self.log.exception("Exception when executing _execute {0}".format(str(e)))
        finally:
            self.running = False
            self._stop_executor_heartheat()
            if self.dag_trigger is not None:
                self.dag_trigger.stop()
            if notification_client is not None:
                notification_client.stop_listen_events()
            if self.notification_master is not None:
                self.notification_master.stop()
            self.executor.end()
            self.log.info("Exited execute event scheduler")

    @provide_session
    def _set_task_instance_state(self, dag_run, dag_id, task_id, execution_date,
                                 state, try_number, session=None):
        """
        Set the task state to the db and, if needed, mark the dagrun finished in the db.
        :param dag_run: DagRun object
        :param dag_id: dag identifier
        :param task_id: task identifier
        :param execution_date: the dag run execution date
        :param state: the task state that should be set.
        :param try_number: the task try_number.
        :param session:
        :return:
        """
        TI = models.TaskInstance
        qry = session.query(TI).filter(TI.dag_id == dag_id,
                                       TI.task_id == task_id,
                                       TI.execution_date == execution_date)
        ti = qry.first()
        if not ti:
            self.log.warning("TaskInstance %s went missing from the database", ti)
            return
        ts = TaskState.query_task_state(ti, session)
        self.log.debug(
            "set task state dag_id {0} task_id {1} execution_date {2} try_number {3} "
            "current try_number {4} state {5} ack_id {6} action {7}.".format(
                dag_id, task_id, execution_date, try_number, ti.try_number,
                state, ts.ack_id, ts.action))
        is_restart = False
        if state == State.FAILED or state == State.SUCCESS or state == State.SHUTDOWN:
            if ti.try_number == try_number and ti.state == State.QUEUED:
                msg = ("Executor reports task instance {} finished ({}) "
                       "although the task says its {}. Was the task "
                       "killed externally?".format(ti, state, ti.state))
                Stats.incr('scheduler.tasks.killed_externally')
                self.log.error(msg)
                try:
                    dag = self.task_route.find_dagrun(dag_id, execution_date)
                    ti.task = dag.get_task(task_id)
                    ti.handle_failure(msg)
                except Exception:
                    self.log.error("Cannot load the dag bag to handle failure for %s"
                                   ". Setting task to FAILED without callbacks or "
                                   "retries. Do you have enough resources?", ti)
                    ti.state = State.FAILED
                    session.merge(ti)
            else:
                if ts.action is None:
                    self.log.debug("task dag_id {0} task_id {1} execution_date {2} action is None."
                                   .format(dag_id, task_id, execution_date))
                elif TaskAction(ts.action) == TaskAction.RESTART:
                    # if ts.stop_flag is not None and ts.stop_flag == try_number:
                    ti.state = State.SCHEDULED
                    ts.action = None
                    ts.stop_flag = None
                    ts.ack_id = 0
                    session.merge(ti)
                    session.merge(ts)
                    self.log.debug(
                        "task dag_id {0} task_id {1} execution_date {2} try_number {3} restart action."
                        .format(dag_id, task_id, execution_date, str(try_number)))
                    is_restart = True
                elif TaskAction(ts.action) == TaskAction.STOP:
                    # if ts.stop_flag is not None and ts.stop_flag == try_number:
                    ts.action = None
                    ts.stop_flag = None
                    ts.ack_id = 0
                    session.merge(ts)
                    self.log.debug(
                        "task dag_id {0} task_id {1} execution_date {2} try_number {3} stop action."
                        .format(dag_id, task_id, execution_date, str(try_number)))
                else:
                    self.log.debug("task dag_id {0} task_id {1} execution_date {2} action {3}."
                                   .format(dag_id, task_id, execution_date, ts.action))
                session.commit()
        if not is_restart and ti.state == State.RUNNING:
            self.log.debug("set task dag_id {0} task_id {1} execution_date {2} state {3}".format(
                dag_id, task_id, execution_date, state))
            ti.state = state
            session.merge(ti)
        session.commit()
        # update dagrun state
        sync_dag_run = session.query(DagRun).filter(DagRun.id == dag_run.id).first()
        if sync_dag_run.state not in FINISHED_STATES:
            if self.dagrun_route.find_dagrun_by_id(sync_dag_run.id) is None:
                self.log.error("DagRun lost dag_id {0} task_id {1} execution_date {2}".format(
                    dag_id, task_id, execution_date))
            else:
                run_process_func(target=dag_run_update_state,
                                 args=(dag_run,
                                       self.dagrun_route.find_simple_dag(dag_run.id),))

    @provide_session
    def _create_task_instances(self, dag_run, session=None):
        """
        This method schedules the tasks for a single DAG by looking at the
        active DAG runs and adding task instances that should run to the queue.
        """
        # update the state of the previously active dag runs
        dag_runs = DagRun.find(dag_id=dag_run.dag_id, state=State.RUNNING, session=session)
        active_dag_runs = []
        for run in dag_runs:
            self.log.info("Examining DAG run %s", run)
            # don't consider runs that are executed in the future unless
            # specified by config and schedule_interval is None
            if run.execution_date > timezone.utcnow() and not dag_run.dag.allow_future_exec_dates:
                self.log.error("Execution date is in future: %s", run.execution_date)
                continue
            if len(active_dag_runs) >= dag_run.dag.max_active_runs:
                self.log.info("Number of active dag runs reached max_active_run.")
                break
            # skip backfill dagruns for now as long as they are not really scheduled
            if run.is_backfill:
                continue
            run.dag = dag_run.dag
            # todo: preferably the integrity check happens at dag collection time
            run.verify_integrity(session=session)
            run.update_state(session=session)
            if run.state == State.RUNNING:
                make_transient(run)
                active_dag_runs.append(run)

    def _process_dags_and_create_dagruns(self, dagbag, dags, dagrun_out):
        """
        Iterates over the dags and processes them. Processing includes:

        1. Create appropriate DagRun(s) in the DB.
        2. Create appropriate TaskInstance(s) in the DB.
        3. Send emails for tasks that have missed SLAs.

        :param dagbag: a collection of DAGs to process
        :type dagbag: airflow.models.DagBag
        :param dags: the DAGs from the DagBag to process
        :type dags: list[airflow.models.DAG]
        :param dagrun_out: a list to which created DagRun objects are appended
        :type dagrun_out: list[DagRun]
        :rtype: None
        """
        for dag in dags:
            dag = dagbag.get_dag(dag.dag_id)
            if not dag:
                self.log.error("DAG ID %s was not found in the DagBag", dag.dag_id)
                continue
            if dag.is_paused:
                self.log.info("Not processing DAG %s since it's paused", dag.dag_id)
                continue
            self.log.info("Processing %s", dag.dag_id)
            dag_run = self.create_dag_run(dag)
            if dag_run:
                dag_run.dag = dag
                expected_start_date = dag.following_schedule(dag_run.execution_date)
                if expected_start_date:
                    schedule_delay = dag_run.start_date - expected_start_date
                    Stats.timing('dagrun.schedule_delay.{dag_id}'.format(dag_id=dag.dag_id),
                                 schedule_delay)
                self.log.info("Created %s", dag_run)
                self._create_task_instances(dag_run)
                self.log.info("Created tasks instances %s", dag_run)
                dagrun_out.append(dag_run)
            if conf.getboolean('core', 'CHECK_SLAS', fallback=True):
                self.manage_slas(dag)

    @provide_session
    def process_file(self, file_path, zombies, pickle_dags=False, session=None):
        """
        Process a Python file containing Airflow DAGs.

        This includes:

        1. Execute the file and look for DAG objects in the namespace.
        2. Pickle the DAG and save it to the DB (if necessary).
        3. For each DAG, see what tasks should run and create appropriate task
           instances in the DB.
        4. Record any errors importing the file into ORM.
        5. Kill (in ORM) any task instances belonging to the DAGs that haven't
           issued a heartbeat in a while.

        Returns a list of SimpleDag objects that represent the DAGs found in
        the file.

        :param file_path: the path to the Python file that should be executed
        :type file_path: unicode
        :param zombies: zombie task instances to kill.
        :type zombies: list[airflow.utils.dag_processing.SimpleTaskInstance]
        :param pickle_dags: whether to serialize the DAGs found in the file and
            save them to the db
        :type pickle_dags: bool
        :return: a list of SimpleDagRuns made from the Dags found in the file
        :rtype: list[airflow.utils.dag_processing.SimpleDagBag]
        """
        self.log.info("Processing file %s for tasks to queue", file_path)
        if session is None:
            session = settings.Session()
        # As DAGs are parsed from this file, they will be converted into SimpleDags
        try:
            dagbag = models.DagBag(file_path, include_examples=False)
        except Exception:
            self.log.exception("Failed at reloading the DAG file %s", file_path)
            Stats.incr('dag_file_refresh_error', 1, 1)
            return [], []
        if len(dagbag.dags) > 0:
            self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(), file_path)
        else:
            self.log.warning("No viable dags retrieved from %s", file_path)
            self.update_import_errors(session, dagbag)
            return [], len(dagbag.import_errors)
        # Save individual DAGs in the ORM and update DagModel.last_scheduled_time
        for dag in dagbag.dags.values():
            dag.sync_to_db()
        paused_dag_ids = [dag.dag_id for dag in dagbag.dags.values() if dag.is_paused]
        self.log.info("paused_dag_ids %s", paused_dag_ids)
        self.log.info("self %s", self.dag_ids)
        dag_to_pickle = {}
        # Pickle the DAGs (if necessary) and put them into a SimpleDag
        for dag_id in dagbag.dags:
            # Only return DAGs that are not paused
            if dag_id not in paused_dag_ids:
                dag = dagbag.get_dag(dag_id)
                pickle_id = None
                if pickle_dags:
                    pickle_id = dag.pickle(session).id
                dag_to_pickle[dag.dag_id] = pickle_id
        if len(self.dag_ids) > 0:
            dags = [dag for dag in dagbag.dags.values()
                    if dag.dag_id in self.dag_ids and dag.dag_id not in paused_dag_ids]
        else:
            dags = [dag for dag in dagbag.dags.values()
                    if not dag.parent_dag and dag.dag_id not in paused_dag_ids]
        # Not using multiprocessing.Queue() since it's no longer a separate
        # process and due to some unusual behavior. (empty() incorrectly
        # returns true as described in https://bugs.python.org/issue23582 )
        self.log.info("dags %s", dags)
        dag_run_out = []
        self._process_dags_and_create_dagruns(dagbag, dags, dag_run_out)
        self.log.info("dag run out %s", len(dag_run_out))
        simple_dag_runs = []
        for dag_run in dag_run_out:
            simple_dag_runs.append(SimpleDagRun(dag_run.id, SimpleDag(dag_run.dag)))
        # commit batch
        session.commit()
        # Record import errors into the ORM
        try:
            self.update_import_errors(session, dagbag)
        except Exception:
            self.log.exception("Error logging import errors!")
        try:
            dagbag.kill_zombies(zombies)
        except Exception:
            self.log.exception("Error killing zombies!")
        return simple_dag_runs, len(dagbag.import_errors)
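# A minimal sketch (an assumption, not from the source) of starting the event-driven
# scheduler with the locally embedded notification service. The DAGs folder path is
# hypothetical; SchedulerJob-style jobs in Airflow are started with run(), which in
# turn calls _execute() above.
scheduler_job = EventSchedulerJob(subdir='/path/to/dags',  # hypothetical DAGs folder
                                  use_local_nf=True,       # start NotificationMaster in-process
                                  nf_port=50051)
scheduler_job.run()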
class TestEventBasedScheduler(unittest.TestCase):
    def setUp(self):
        db.clear_db_jobs()
        db.clear_db_dags()
        db.clear_db_serialized_dags()
        db.clear_db_runs()
        db.clear_db_task_execution()
        db.clear_db_message()
        self.scheduler = None
        self.port = 50102
        self.storage = MemoryEventStorage()
        self.master = NotificationMaster(NotificationService(self.storage), self.port)
        self.master.run()
        self.client = NotificationClient(server_uri="localhost:{}".format(self.port),
                                         default_namespace="test_namespace")
        time.sleep(1)

    def tearDown(self):
        self.master.stop()

    def _get_task_instance(self, dag_id, task_id, session):
        return session.query(TaskInstance).filter(
            TaskInstance.dag_id == dag_id,
            TaskInstance.task_id == task_id).first()

    def schedule_task_function(self):
        stopped = False
        while not stopped:
            with create_session() as session:
                ti_sleep_1000_secs = self._get_task_instance(
                    EVENT_BASED_SCHEDULER_DAG, 'sleep_1000_secs', session)
                ti_python_sleep = self._get_task_instance(
                    EVENT_BASED_SCHEDULER_DAG, 'python_sleep', session)
                if ti_sleep_1000_secs and ti_sleep_1000_secs.state == State.SCHEDULED and \
                        ti_python_sleep and ti_python_sleep.state == State.SCHEDULED:
                    self.client.send_event(
                        BaseEvent(key='start', value='', event_type='',
                                  namespace='test_namespace'))
                    while not stopped:
                        ti_sleep_1000_secs.refresh_from_db()
                        ti_python_sleep.refresh_from_db()
                        if ti_sleep_1000_secs and ti_sleep_1000_secs.state == State.RUNNING and \
                                ti_python_sleep and ti_python_sleep.state == State.RUNNING:
                            time.sleep(10)
                            break
                        else:
                            time.sleep(1)
                    self.client.send_event(
                        BaseEvent(key='stop', value='', event_type=UNDEFINED_EVENT_TYPE,
                                  namespace='test_namespace'))
                    self.client.send_event(
                        BaseEvent(key='restart', value='', event_type=UNDEFINED_EVENT_TYPE,
                                  namespace='test_namespace'))
                    while not stopped:
                        ti_sleep_1000_secs.refresh_from_db()
                        ti_python_sleep.refresh_from_db()
                        if ti_sleep_1000_secs and ti_sleep_1000_secs.state == State.KILLED and \
                                ti_python_sleep and ti_python_sleep.state == State.RUNNING:
                            stopped = True
                        else:
                            time.sleep(1)
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_event_based_scheduler(self):
        t = threading.Thread(target=self.schedule_task_function)
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_event_based_scheduler.py')

    def test_replay_message(self):
        key = "stop"
        mailbox = Mailbox()
        mailbox.set_scheduling_job_id(1234)
        watcher = SchedulerEventWatcher(mailbox)
        self.client.start_listen_events(watcher=watcher,
                                        start_time=int(time.time() * 1000),
                                        version=None)
        self.send_event(key)
        msg: BaseEvent = mailbox.get_message()
        self.assertEqual(msg.key, key)
        with create_session() as session:
            msg_from_db = session.query(Message).first()
            expect_non_unprocessed = EventBasedScheduler.get_unprocessed_message(1000)
            self.assertEqual(0, len(expect_non_unprocessed))
            unprocessed = EventBasedScheduler.get_unprocessed_message(1234)
            self.assertEqual(unprocessed[0].serialized_message, msg_from_db.data)
            deserialized_data = pickle.loads(msg_from_db.data)
            self.assertEqual(deserialized_data.key, key)
            self.assertEqual(msg, deserialized_data)

    def send_event(self, key):
        event = self.client.send_event(
            BaseEvent(key=key, event_type=UNDEFINED_EVENT_TYPE, value="value1"))
        self.assertEqual(key, event.key)

    @provide_session
    def get_task_execution(self, dag_id, task_id, session):
        return session.query(TaskExecution).filter(
            TaskExecution.dag_id == dag_id,
            TaskExecution.task_id == task_id).all()

    @provide_session
    def get_latest_job_id(self, session):
        return session.query(BaseJob).order_by(sqlalchemy.desc(BaseJob.id)).first().id

    def start_scheduler(self, file_path):
        self.scheduler = EventBasedSchedulerJob(
            dag_directory=file_path,
            server_uri="localhost:{}".format(self.port),
            executor=LocalExecutor(3),
            max_runs=-1,
            refresh_dag_dir_interval=30)
        print("scheduler starting")
        self.scheduler.run()

    def wait_for_running(self):
        while True:
            if self.scheduler is not None:
                time.sleep(5)
                break
            else:
                time.sleep(1)

    def wait_for_task_execution(self, dag_id, task_id, expected_num):
        result = False
        check_nums = 100
        while check_nums > 0:
            time.sleep(2)
            check_nums = check_nums - 1
            tes = self.get_task_execution(dag_id, task_id)
            if len(tes) == expected_num:
                result = True
                break
        self.assertTrue(result)

    def wait_for_task(self, dag_id, task_id, expected_state):
        result = False
        check_nums = 100
        while check_nums > 0:
            time.sleep(2)
            check_nums = check_nums - 1
            with create_session() as session:
                ti = session.query(TaskInstance).filter(
                    TaskInstance.dag_id == dag_id,
                    TaskInstance.task_id == task_id).first()
                if ti and ti.state == expected_state:
                    result = True
                    break
        self.assertTrue(result)

    def test_notification(self):
        self.client.send_event(BaseEvent(key='a', value='b'))

    def run_a_task_function(self):
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'single',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    break
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_a_task(self):
        t = threading.Thread(target=self.run_a_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_single_task_dag.py')
        tes: List[TaskExecution] = self.get_task_execution("single", "task_1")
        self.assertEqual(len(tes), 1)

    def run_event_task_function(self):
        client = NotificationClient(server_uri="localhost:{}".format(self.port),
                                    default_namespace="")
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'event_dag',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    time.sleep(5)
                    client.send_event(
                        BaseEvent(key='start', value='', event_type='', namespace=''))
                    while True:
                        with create_session() as session_2:
                            tes_2 = session_2.query(TaskExecution).filter(
                                TaskExecution.dag_id == 'event_dag',
                                TaskExecution.task_id == 'task_2').all()
                            if len(tes_2) > 0:
                                break
                            else:
                                time.sleep(1)
                    break
                else:
                    time.sleep(1)
        client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_event_task(self):
        t = threading.Thread(target=self.run_event_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_event_task_dag.py')
        tes: List[TaskExecution] = self.get_task_execution("event_dag", "task_2")
        self.assertEqual(len(tes), 1)

    def run_trigger_dag_function(self):
        ns_client = NotificationClient(server_uri="localhost:{}".format(self.port),
                                       default_namespace="")
        client = EventSchedulerClient(ns_client=ns_client)
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'trigger_dag',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    break
                else:
                    client.trigger_parse_dag()
                    result = client.schedule_dag('trigger_dag')
                    print('result {}'.format(result.dagrun_id))
                    time.sleep(5)
        ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_trigger_dag(self):
        import multiprocessing
        p = multiprocessing.Process(target=self.run_trigger_dag_function, args=())
        p.start()
        self.start_scheduler('../../dags/test_run_trigger_dag.py')
        tes: List[TaskExecution] = self.get_task_execution("trigger_dag", "task_1")
        self.assertEqual(len(tes), 1)

    def run_no_dag_file_function(self):
        ns_client = NotificationClient(server_uri="localhost:{}".format(self.port),
                                       default_namespace="")
        client = EventSchedulerClient(ns_client=ns_client)
        with create_session() as session:
            client.trigger_parse_dag()
            result = client.schedule_dag('no_dag')
            print('result {}'.format(result.dagrun_id))
            time.sleep(5)
        ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_no_dag_file_trigger_dag(self):
        import multiprocessing
        p = multiprocessing.Process(target=self.run_no_dag_file_function, args=())
        p.start()
        self.start_scheduler('../../dags/test_run_trigger_dag.py')
        tes: List[TaskExecution] = self.get_task_execution("trigger_dag", "task_1")
        self.assertEqual(len(tes), 0)

    def run_trigger_task_function(self):
        # wait until the dag file has been parsed
        time.sleep(5)
        ns_client = NotificationClient(server_uri="localhost:{}".format(self.port),
                                       default_namespace="a")
        client = EventSchedulerClient(ns_client=ns_client)
        execution_context = client.schedule_dag('trigger_task')
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'trigger_task',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    client.schedule_task('trigger_task', 'task_2',
                                         SchedulingAction.START, execution_context)
                    while True:
                        with create_session() as session_2:
                            tes_2 = session_2.query(TaskExecution).filter(
                                TaskExecution.dag_id == 'trigger_task',
                                TaskExecution.task_id == 'task_2').all()
                            if len(tes_2) > 0:
                                break
                            else:
                                time.sleep(1)
                    break
                else:
                    time.sleep(1)
        ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_task_trigger_dag(self):
        import threading
        t = threading.Thread(target=self.run_trigger_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_task_trigger_dag.py')
        tes: List[TaskExecution] = self.get_task_execution("trigger_task", "task_2")
        self.assertEqual(len(tes), 1)

    def run_ai_flow_function(self):
        client = NotificationClient(server_uri="localhost:{}".format(self.port),
                                    default_namespace="default",
                                    sender='1-job-name')
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'workflow_1',
                    TaskExecution.task_id == '1-job-name').all()
                if len(tes) > 0:
                    time.sleep(5)
                    client.send_event(
                        BaseEvent(key='key_1', value='value_1', event_type='UNDEFINED'))
                    client.send_event(
                        BaseEvent(key='key_2', value='value_2', event_type='UNDEFINED'))
                    while True:
                        with create_session() as session_2:
                            tes_2 = session_2.query(TaskExecution).filter(
                                TaskExecution.dag_id == 'workflow_1').all()
                            if len(tes_2) == 3:
                                break
                            else:
                                time.sleep(1)
                    break
                else:
                    time.sleep(1)
        time.sleep(3)
        client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_ai_flow_dag(self):
        import threading
        t = threading.Thread(target=self.run_ai_flow_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_aiflow_dag.py')
        tes: List[TaskExecution] = self.get_task_execution("workflow_1", "1-job-name")
        self.assertEqual(len(tes), 1)

    def stop_dag_function(self):
        stopped = False
        while not stopped:
            tes = self.get_task_execution(EVENT_BASED_SCHEDULER_DAG, 'sleep_to_be_stopped')
            if tes and len(tes) == 1:
                self.client.send_event(StopDagEvent(EVENT_BASED_SCHEDULER_DAG).to_event())
                while not stopped:
                    tes2 = self.get_task_execution(EVENT_BASED_SCHEDULER_DAG,
                                                   'sleep_to_be_stopped')
                    if tes2[0].state == State.KILLED:
                        stopped = True
                        time.sleep(5)
                    else:
                        time.sleep(1)
            else:
                time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_stop_dag(self):
        t = threading.Thread(target=self.stop_dag_function)
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_event_based_scheduler.py')
        with create_session() as session:
            from airflow.models import DagModel
            dag_model: DagModel = DagModel.get_dagmodel(EVENT_BASED_SCHEDULER_DAG)
            self.assertTrue(dag_model.is_paused)
            self.assertEqual(dag_model.get_last_dagrun().state, "killed")
            for ti in session.query(TaskInstance).filter(
                    TaskInstance.dag_id == EVENT_BASED_SCHEDULER_DAG):
                self.assertTrue(ti.state in [State.SUCCESS, State.KILLED])
            for te in session.query(TaskExecution).filter(
                    TaskExecution.dag_id == EVENT_BASED_SCHEDULER_DAG):
                self.assertTrue(te.state in [State.SUCCESS, State.KILLED])

    def run_periodic_task_function(self):
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'single',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 1:
                    break
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_periodic_task(self):
        t = threading.Thread(target=self.run_periodic_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_periodic_task_dag.py')
        tes: List[TaskExecution] = self.get_task_execution("single", "task_1")
        self.assertGreater(len(tes), 1)

    def run_one_task_function(self):
        self.wait_for_running()
        self.client.send_event(BaseEvent(key='a', value='a'))
        time.sleep(5)
        self.client.send_event(BaseEvent(key='a', value='a'))
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'single',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) >= 2:
                    break
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_one_task(self):
        t = threading.Thread(target=self.run_one_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_multiple_trigger_task_dag.py')
def start_notification_service(port: int = 50052):
    storage = MemoryEventStorage()
    notification_master \
        = NotificationMaster(service=NotificationService(storage), port=port)
    notification_master.run(is_block=True)
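# A minimal client-side sketch (assumed usage, not from the source): connect to the
# in-memory service started above. NotificationClient, BaseEvent and list_all_events(...)
# are used the same way as in the tests in this section; the event key/value and the
# meaning of the list_all_events argument are illustrative assumptions.
client = NotificationClient(server_uri="localhost:50052")
client.send_event(BaseEvent(key='model_trained', value='v1'))
events = client.list_all_events(1)  # assumed: list events starting from version 1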
class TestHighAvailableAIFlowServer(unittest.TestCase):
    @staticmethod
    def start_aiflow_server(host, port):
        port = str(port)
        server_uri = host + ":" + port
        server = AIFlowServer(store_uri=_SQLITE_DB_URI,
                              port=port,
                              enabled_ha=True,
                              start_scheduler_service=False,
                              ha_server_uri=server_uri,
                              notification_uri='localhost:30031',
                              start_default_notification=False)
        server.run()
        return server

    def wait_for_new_members_detected(self, new_member_uri):
        while True:
            living_member = self.client.living_aiflow_members
            if new_member_uri in living_member:
                break
            else:
                time.sleep(1)

    def setUp(self) -> None:
        SqlAlchemyStore(_SQLITE_DB_URI)
        self.notification = NotificationMaster(
            service=NotificationService(storage=MemoryEventStorage()), port=30031)
        self.notification.run()
        self.server1 = AIFlowServer(store_uri=_SQLITE_DB_URI,
                                    port=50051,
                                    enabled_ha=True,
                                    start_scheduler_service=False,
                                    ha_server_uri='localhost:50051',
                                    notification_uri='localhost:30031',
                                    start_default_notification=False)
        self.server1.run()
        self.server2 = None
        self.server3 = None
        self.config = ProjectConfig()
        self.config.set_enable_ha(True)
        self.config.set_notification_service_uri('localhost:30031')
        self.client = AIFlowClient(server_uri='localhost:50052,localhost:50051',
                                   project_config=self.config)

    def tearDown(self) -> None:
        self.client.stop_listen_event()
        self.client.disable_high_availability()
        if self.server1 is not None:
            self.server1.stop()
        if self.server2 is not None:
            self.server2.stop()
        if self.server3 is not None:
            self.server3.stop()
        if self.notification is not None:
            self.notification.stop()
        store = SqlAlchemyStore(_SQLITE_DB_URI)
        base.metadata.drop_all(store.db_engine)

    def test_server_change(self) -> None:
        self.client.register_project("test_project")
        projects = self.client.list_project(10, 0)
        self.assertEqual(self.client.current_aiflow_uri, "localhost:50051")
        self.assertEqual(projects[0].name, "test_project")
        self.server2 = self.start_aiflow_server("localhost", 50052)
        self.wait_for_new_members_detected("localhost:50052")
        self.server1.stop()
        projects = self.client.list_project(10, 0)
        self.assertEqual(self.client.current_aiflow_uri, "localhost:50052")
        self.assertEqual(projects[0].name, "test_project")
        self.server3 = self.start_aiflow_server("localhost", 50053)
        self.wait_for_new_members_detected("localhost:50053")
        self.server2.stop()
        projects = self.client.list_project(10, 0)
        self.assertEqual(self.client.current_aiflow_uri, "localhost:50053")
        self.assertEqual(projects[0].name, "test_project")
def run_server():
    storage = MemoryEventStorage()
    master = NotificationMaster(service=NotificationService(storage), port=50051)
    master.run(is_block=True)