def test_user_trigger_parse_dag(self):
    port = 50101
    service_uri = 'localhost:{}'.format(port)
    storage = MemoryEventStorage()
    master = NotificationMaster(NotificationService(storage), port)
    master.run()
    mailbox = Mailbox()
    dag_trigger = DagTrigger("../../dags/test_scheduler_dags.py", -1, [], False, mailbox, 5, service_uri)
    dag_trigger.start()
    message = mailbox.get_message()
    message = SchedulerInnerEventUtil.to_inner_event(message)
    # only one dag is executable
    assert "test_task_start_date_scheduling" == message.dag_id
    sc = EventSchedulerClient(server_uri=service_uri, namespace='a')
    sc.trigger_parse_dag()
    dag_trigger.end()
    master.stop()
Example #2
class BaseExecutorTest(unittest.TestCase):
    def setUp(self):
        clear_db_event_model()
        self.master = NotificationMaster(service=NotificationService(EventModelStorage()))
        self.master.run()
        self.client = NotificationClient(server_uri="localhost:50051")

    def tearDown(self) -> None:
        self.master.stop()

    def test_get_event_buffer(self):
        executor = BaseExecutor()

        date = datetime.utcnow()
        try_number = 1
        key1 = ("my_dag1", "my_task1", date, try_number)
        key2 = ("my_dag2", "my_task1", date, try_number)
        key3 = ("my_dag2", "my_task2", date, try_number)
        state = State.SUCCESS
        executor.event_buffer[key1] = state
        executor.event_buffer[key2] = state
        executor.event_buffer[key3] = state

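        # get_event_buffer(dag_ids) removes and returns only the events for the
        # given dags; calling it with no argument flushes whatever is left.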
        self.assertEqual(len(executor.get_event_buffer(("my_dag1",))), 1)
        self.assertEqual(len(executor.get_event_buffer()), 2)
        self.assertEqual(len(executor.event_buffer), 0)

    @mock.patch('airflow.executors.base_executor.BaseExecutor.sync')
    @mock.patch('airflow.executors.base_executor.BaseExecutor.trigger_tasks')
    @mock.patch('airflow.settings.Stats.gauge')
    def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync):
        executor = BaseExecutor()
        executor.heartbeat()
        calls = [mock.call('executor.open_slots', mock.ANY),
                 mock.call('executor.queued_tasks', mock.ANY),
                 mock.call('executor.running_tasks', mock.ANY)]
        mock_stats_gauge.assert_has_calls(calls)

    def test_use_nf_executor(self):
        executor = BaseExecutor()
        executor.set_use_nf(True)
        executor.change_state('key', State.RUNNING)
        executor.change_state('key', State.SUCCESS)
        events = self.client.list_all_events(1)
        self.assertEqual(2, len(events))
class EventSchedulerJob(SchedulerJob):
    """
    EventSchedulerJob: The scheduler driven by events.
    The scheduler get the message from notification service, then scheduling the tasks which affected by the events.
    """

    __mapper_args__ = {'polymorphic_identity': 'EventSchedulerJob'}

    def __init__(self,
                 dag_id=None,
                 dag_ids=None,
                 subdir=settings.DAGS_FOLDER,
                 num_runs=conf.getint('scheduler', 'num_runs', fallback=-1),
                 processor_poll_interval=conf.getfloat(
                     'scheduler', 'processor_poll_interval', fallback=1),
                 use_local_nf=conf.getboolean('scheduler',
                                              'use_local_notification',
                                              fallback=True),
                 nf_host=conf.get('scheduler',
                                  'notification_host',
                                  fallback='localhost'),
                 nf_port=conf.getint('scheduler',
                                     'notification_port',
                                     fallback=50051),
                 unit_test_mode=conf.getboolean('core',
                                                'unit_test_mode',
                                                fallback=False),
                 executor_heartbeat_interval=conf.getint(
                     'scheduler', 'executor_heartbeat_interval', fallback=2),
                 run_duration=None,
                 do_pickle=False,
                 log=None,
                 *args,
                 **kwargs):
        super().__init__(dag_id, dag_ids, subdir, num_runs,
                         processor_poll_interval, run_duration, do_pickle, log,
                         *args, **kwargs)
        self.dag_trigger = None
        self.notification_master = None
        self.use_local_nf = use_local_nf
        self.nf_host = nf_host
        self.nf_port = nf_port
        self.mail_box = Mailbox()
        self.running = True
        self.dagrun_route = DagRunRoute()
        self.unit_test_mode = unit_test_mode
        self.executor_heartbeat_interval = executor_heartbeat_interval
        self.heartbeat_thread = None

    @provide_session
    def _get_dag_runs(self, event, session):
        dag_runs = []
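        # Scheduler-internal events (DAG_RUN_EXECUTABLE, TASK_STATUS_CHANGED, ...)
        # are handled per event type below.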
        if EventType.is_in(event.event_type) and EventType(
                event.event_type) != EventType.UNDEFINED:
            if EventType(event.event_type) == EventType.DAG_RUN_EXECUTABLE:
                dag_run_id = int(event.key)
                dag_run = session.query(DagRun).filter(
                    DagRun.id == dag_run_id).first()
                if dag_run is None:
                    self.log.error("DagRun is None id {0}".format(dag_run_id))
                    return dag_runs
                simple_dag = event.simple_dag
                dag_run.pickle_id = None
                # create route
                self.dagrun_route.add_dagrun(dag_run, simple_dag, session)
                dag_runs.append(dag_run)

            elif EventType(event.event_type) == EventType.TASK_STATUS_CHANGED:
                dag_id, task_id, execution_date = TaskInstanceHelper.from_task_key(
                    event.key)
                state, try_num = TaskInstanceHelper.from_event_value(
                    event.value)
                dag_run = self.dagrun_route.find_dagrun(dag_id, execution_date)
                if dag_run is None:
                    return dag_runs
                self._set_task_instance_state(dag_run, dag_id, task_id,
                                              execution_date, state, try_num)

                sync_dag_run = session.query(DagRun).filter(
                    DagRun.id == dag_run.id).first()
                if sync_dag_run.state in State.finished():
                    self.log.info(
                        "DagRun finished dag_id {0} execution_date {1} state {2}"
                        .format(dag_run.dag_id, dag_run.execution_date,
                                sync_dag_run.state))
                    if self.dagrun_route.find_dagrun_by_id(
                            sync_dag_run.id) is not None:
                        self.dagrun_route.remove_dagrun(dag_run, session)
                        self.log.debug("Route remove dag run {0}".format(
                            sync_dag_run.id))
                        self.mail_box.send_message(
                            DagRunFinishedEvent(dag_run.id,
                                                sync_dag_run.state))
                else:
                    dag_runs.append(dag_run)

            elif EventType(event.event_type) == EventType.DAG_RUN_FINISHED:
                self.log.debug("DagRun {0} finished".format(event.key))
            elif EventType(event.event_type) == EventType.STOP_SCHEDULER_CMD:
                if self.unit_test_mode:
                    self.running = False
                return dag_runs
        else:
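            # Any other event is routed to the dag runs whose task dependencies
            # reference its (key, event_type) pair; each matching task's event
            # handler decides the scheduling action for that task.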
            runs = self.dagrun_route.find_dagruns_by_event(
                event_key=event.key, event_type=event.event_type)
            if runs is not None:
                for run in runs:
                    task_deps = load_task_dependencies(dag_id=run.dag_id,
                                                       session=session)
                    tis = run.get_task_instances(session=session)
                    for ti in tis:
                        if ti.task_id not in task_deps:
                            continue
                        if (event.key,
                                event.event_type) in task_deps[ti.task_id]:
                            self.log.debug("{0} handle event {1}".format(
                                ti.task_id, event))
                            ts = TaskState.query_task_state(ti,
                                                            session=session)
                            handler = ts.event_handler
                            if handler is not None:
                                action = handler.handle_event(event,
                                                              ti=ti,
                                                              ts=ts,
                                                              session=session)
                                ts.action = action
                                session.merge(ts)
                                session.commit()
                                self.log.debug(
                                    "set task action {0} {1}".format(
                                        ti.task_id, action))
                dag_runs.extend(runs)
                session.commit()

        for dag_run in dag_runs:
            run_process_func(target=process_tasks,
                             args=(
                                 dag_run,
                                 self.dagrun_route.find_simple_dag(dag_run.id),
                                 self.log,
                             ))
        return dag_runs

    @provide_session
    def _sync_event_to_db(self, event: Event, session=None):
        EventModel.sync_event(event=event, session=session)

    @provide_session
    def _run_event_loop(self, session=None):
        """
        The main event-processing loop.
        :param session: the db session.
        :return: None
        """
        while self.running:
            event: Event = self.mail_box.get_message()
            self.log.debug('EVENT: {0}'.format(event))
            if not self.use_local_nf:
                # persist the event received from the remote notification service
                self._sync_event_to_db(event)
            try:
                dag_runs = self._get_dag_runs(event)
                if dag_runs is None or len(dag_runs) == 0:
                    continue
                # create SimpleDagBag
                simple_dags = []
                for dag_run in dag_runs:
                    simple_dags.append(
                        self.dagrun_route.find_simple_dag(
                            dagrun_id=dag_run.id))
                simple_dag_bag = SimpleDagBag(simple_dags)
                if not self._validate_and_run_task_instances(
                        simple_dag_bag=simple_dag_bag):
                    continue
            except Exception as e:
                self.log.exception(str(e))
        # scheduler end
        self.log.debug("_run_event_loop end")

    @provide_session
    def _init_route(self, session=None):
        """
        Init the DagRunRoute object from db.
        :param session:
        :return:
        """
        # running_dag_runs = session.query(DagRun).filter(DagRun.state == State.RUNNING).all()
        # for dag_run in running_dag_runs:
        #     dag_model = session.query(DagModel).filter(DagModel.dag_id == dag_run.dag_id).first()
        #     dagbag = models.DagBag(dag_model.fileloc)
        #     dag_run.dag = dagbag.get_dag(dag_run.dag_id)
        #     self.dagrun_route.add_dagrun(dag_run)
        # todo init route
        pass

    def _executor_heartbeat(self):
        while self.running:
            self.log.info("executor heartbeat...")
            self.executor.heartbeat()
            time.sleep(self.executor_heartbeat_interval)

    def _start_executor_heartbeat(self):

        self.heartbeat_thread = threading.Thread(
            target=self._executor_heartbeat, args=())
        self.heartbeat_thread.setDaemon(True)
        self.heartbeat_thread.start()

    def _stop_executor_heartbeat(self):
        self.running = False
        if self.heartbeat_thread is not None:
            self.heartbeat_thread.join()

    def _execute(self):
        """
        1. Init the DagRun route.
        2. Start the executor.
        3. Optionally start the notification master.
        4. Create the notification client.
        5. Start the DagTrigger.
        6. Run the scheduler event loop.
        :return:
        """
        notification_client = None
        try:
            self._init_route()
            self.executor.set_use_nf(True)
            self.executor.start()
            self.dag_trigger = DagTrigger(
                subdir=self.subdir,
                mailbox=self.mail_box,
                run_duration=self.run_duration,
                using_sqlite=self.using_sqlite,
                num_runs=self.num_runs,
                processor_poll_interval=self._processor_poll_interval)
            if self.use_local_nf:
                self.notification_master \
                    = NotificationMaster(service=NotificationService(EventModelStorage()), port=self.nf_port)
                self.notification_master.run()
                self.log.info("start notification service {0}".format(
                    self.nf_port))
                notification_client = NotificationClient(
                    server_uri="localhost:{0}".format(self.nf_port))
            else:
                notification_client \
                    = NotificationClient(server_uri="{0}:{1}".format(self.nf_host, self.nf_port))
            notification_client.start_listen_events(
                watcher=SCEventWatcher(self.mail_box))
            self.dag_trigger.start()
            self._start_executor_heartbeat()
            self._run_event_loop()
        except Exception as e:
            self.log.exception("Exception when executing _execute {0}".format(
                str(e)))
        finally:
            self.running = False
            self._stop_executor_heartbeat()
            if self.dag_trigger is not None:
                self.dag_trigger.stop()
            if notification_client is not None:
                notification_client.stop_listen_events()
            if self.notification_master is not None:
                self.notification_master.stop()
            self.executor.end()
            self.log.info("Exited execute event scheduler")

    @provide_session
    def _set_task_instance_state(self,
                                 dag_run,
                                 dag_id,
                                 task_id,
                                 execution_date,
                                 state,
                                 try_number,
                                 session=None):
        """
        Persist the task state to the db and, if the dag run has finished, update the dag run state as well.
        :param dag_run: DagRun object
        :param dag_id: the dag id
        :param task_id: the task id
        :param execution_date: the dag run execution date
        :param state: the task state to set.
        :param try_number: the task try_number.
        :param session:
        :return:
        """
        TI = models.TaskInstance
        qry = session.query(TI).filter(TI.dag_id == dag_id,
                                       TI.task_id == task_id,
                                       TI.execution_date == execution_date)
        ti = qry.first()
        if not ti:
            self.log.warning("TaskInstance %s went missing from the database",
                             ti)
            return
        ts = TaskState.query_task_state(ti, session)
        self.log.debug(
            "set task state dag_id {0} task_id {1} execution_date {2} try_number {3} "
            "current try_number {4} state {5} ack_id {6} action {7}.".format(
                dag_id, task_id, execution_date, try_number, ti.try_number,
                state, ts.ack_id, ts.action))
        is_restart = False
        if state == State.FAILED or state == State.SUCCESS or state == State.SHUTDOWN:
            if ti.try_number == try_number and ti.state == State.QUEUED:
                msg = ("Executor reports task instance {} finished ({}) "
                       "although the task says its {}. Was the task "
                       "killed externally?".format(ti, state, ti.state))
                Stats.incr('scheduler.tasks.killed_externally')
                self.log.error(msg)
                try:
                    # load the DAG so that failure callbacks and retries can run
                    simple_dag = self.dagrun_route.find_simple_dag(dag_run.id)
                    dagbag = models.DagBag(simple_dag.full_filepath)
                    dag = dagbag.get_dag(dag_id)
                    ti.task = dag.get_task(task_id)
                    ti.handle_failure(msg)
                except Exception:
                    self.log.error(
                        "Cannot load the dag bag to handle failure for %s"
                        ". Setting task to FAILED without callbacks or "
                        "retries. Do you have enough resources?", ti)
                ti.state = State.FAILED
                session.merge(ti)
            else:
                if ts.action is None:
                    self.log.debug(
                        "task dag_id {0} task_id {1} execution_date {2} action is None."
                        .format(dag_id, task_id, execution_date))
                elif TaskAction(ts.action) == TaskAction.RESTART:
                    # if ts.stop_flag is not None and ts.stop_flag == try_number:
                    ti.state = State.SCHEDULED
                    ts.action = None
                    ts.stop_flag = None
                    ts.ack_id = 0
                    session.merge(ti)
                    session.merge(ts)
                    self.log.debug(
                        "task dag_id {0} task_id {1} execution_date {2} try_number {3} restart action."
                        .format(dag_id, task_id, execution_date,
                                str(try_number)))
                    is_restart = True
                elif TaskAction(ts.action) == TaskAction.STOP:
                    # if ts.stop_flag is not None and ts.stop_flag == try_number:
                    ts.action = None
                    ts.stop_flag = None
                    ts.ack_id = 0
                    session.merge(ts)
                    self.log.debug(
                        "task dag_id {0} task_id {1} execution_date {2} try_number {3} stop action."
                        .format(dag_id, task_id, execution_date,
                                str(try_number)))
                else:
                    self.log.debug(
                        "task dag_id {0} task_id {1} execution_date {2} action {3}."
                        .format(dag_id, task_id, execution_date, ts.action))
            session.commit()

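        # Only overwrite the state of a task instance that is still RUNNING;
        # a restarted task instance has already been reset to SCHEDULED above.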
        if not is_restart and ti.state == State.RUNNING:
            self.log.debug(
                "set task dag_id {0} task_id {1} execution_date {2} state {3}".
                format(dag_id, task_id, execution_date, state))
            ti.state = state
            session.merge(ti)
        session.commit()
        # update dagrun state
        sync_dag_run = session.query(DagRun).filter(
            DagRun.id == dag_run.id).first()
        if sync_dag_run.state not in FINISHED_STATES:
            if self.dagrun_route.find_dagrun_by_id(sync_dag_run.id) is None:
                self.log.error(
                    "DagRun lost dag_id {0} task_id {1} execution_date {2}".
                    format(dag_id, task_id, execution_date))
            else:
                run_process_func(target=dag_run_update_state,
                                 args=(
                                     dag_run,
                                     self.dagrun_route.find_simple_dag(
                                         dag_run.id),
                                 ))

    @provide_session
    def _create_task_instances(self, dag_run, session=None):
        """
        Create the task instances for the given dag run's DAG by examining its
        active DAG runs, verifying their integrity (which creates any missing
        task instances) and updating their state.
        """

        # update the state of the previously active dag runs
        dag_runs = DagRun.find(dag_id=dag_run.dag_id,
                               state=State.RUNNING,
                               session=session)
        active_dag_runs = []
        for run in dag_runs:
            self.log.info("Examining DAG run %s", run)
            # don't consider runs that are executed in the future unless
            # specified by config and schedule_interval is None
            if run.execution_date > timezone.utcnow(
            ) and not dag_run.dag.allow_future_exec_dates:
                self.log.error("Execution date is in future: %s",
                               run.execution_date)
                continue

            if len(active_dag_runs) >= dag_run.dag.max_active_runs:
                self.log.info(
                    "Number of active dag runs reached max_active_run.")
                break

            # skip backfill dagruns for now as long as they are not really scheduled
            if run.is_backfill:
                continue

            run.dag = dag_run.dag

            # todo: preferably the integrity check happens at dag collection time
            run.verify_integrity(session=session)
            run.update_state(session=session)
            if run.state == State.RUNNING:
                make_transient(run)
                active_dag_runs.append(run)

    def _process_dags_and_create_dagruns(self, dagbag, dags, dagrun_out):
        """
        Iterates over the dags and processes them. Processing includes:

        1. Create appropriate DagRun(s) in the DB.
        2. Create appropriate TaskInstance(s) in the DB.
        3. Send emails for tasks that have missed SLAs.

        :param dagbag: a collection of DAGs to process
        :type dagbag: airflow.models.DagBag
        :param dags: the DAGs from the DagBag to process
        :type dags: list[airflow.models.DAG]
        :param dagrun_out: a list to which created DagRun objects are appended
        :type dagrun_out: list[DagRun]
        :rtype: None
        """
        for dag in dags:
            dag = dagbag.get_dag(dag.dag_id)
            if not dag:
                self.log.error("DAG ID %s was not found in the DagBag",
                               dag.dag_id)
                continue

            if dag.is_paused:
                self.log.info("Not processing DAG %s since it's paused",
                              dag.dag_id)
                continue

            self.log.info("Processing %s", dag.dag_id)

            dag_run = self.create_dag_run(dag)
            if dag_run:
                dag_run.dag = dag
                expected_start_date = dag.following_schedule(
                    dag_run.execution_date)
                if expected_start_date:
                    schedule_delay = dag_run.start_date - expected_start_date
                    Stats.timing(
                        'dagrun.schedule_delay.{dag_id}'.format(
                            dag_id=dag.dag_id), schedule_delay)
                self.log.info("Created %s", dag_run)
                self._create_task_instances(dag_run)
                self.log.info("Created tasks instances %s", dag_run)
                dagrun_out.append(dag_run)
            if conf.getboolean('core', 'CHECK_SLAS', fallback=True):
                self.manage_slas(dag)

    @provide_session
    def process_file(self,
                     file_path,
                     zombies,
                     pickle_dags=False,
                     session=None):
        """
        Process a Python file containing Airflow DAGs.

        This includes:

        1. Execute the file and look for DAG objects in the namespace.
        2. Pickle the DAG and save it to the DB (if necessary).
        3. For each DAG, see what tasks should run and create appropriate task
        instances in the DB.
        4. Record any errors importing the file into ORM
        5. Kill (in ORM) any task instances belonging to the DAGs that haven't
        issued a heartbeat in a while.

        Returns the SimpleDagRun objects created from the DAGs found in the
        file, together with the number of import errors.

        :param file_path: the path to the Python file that should be executed
        :type file_path: unicode
        :param zombies: zombie task instances to kill.
        :type zombies: list[airflow.utils.dag_processing.SimpleTaskInstance]
        :param pickle_dags: whether to serialize the DAGs found in the file and
            save them to the db
        :type pickle_dags: bool
        :return: a tuple of (list of SimpleDagRun objects made from the DAGs
            found in the file, number of import errors)
        :rtype: tuple[list[SimpleDagRun], int]
        """
        self.log.info("Processing file %s for tasks to queue", file_path)
        if session is None:
            session = settings.Session()
        # As DAGs are parsed from this file, they will be converted into SimpleDags

        try:
            dagbag = models.DagBag(file_path, include_examples=False)
        except Exception:
            self.log.exception("Failed at reloading the DAG file %s",
                               file_path)
            Stats.incr('dag_file_refresh_error', 1, 1)
            return [], 0

        if len(dagbag.dags) > 0:
            self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(),
                          file_path)
        else:
            self.log.warning("No viable dags retrieved from %s", file_path)
            self.update_import_errors(session, dagbag)
            return [], len(dagbag.import_errors)

        # Save individual DAGs in the ORM and update DagModel.last_scheduled_time
        for dag in dagbag.dags.values():
            dag.sync_to_db()

        paused_dag_ids = [
            dag.dag_id for dag in dagbag.dags.values() if dag.is_paused
        ]
        self.log.info("paused_dag_ids %s", paused_dag_ids)
        self.log.info("self %s", self.dag_ids)

        dag_to_pickle = {}
        # Pickle the DAGs (if necessary) and put them into a SimpleDag
        for dag_id in dagbag.dags:
            # Only return DAGs that are not paused
            if dag_id not in paused_dag_ids:
                dag = dagbag.get_dag(dag_id)
                pickle_id = None
                if pickle_dags:
                    pickle_id = dag.pickle(session).id
                dag_to_pickle[dag.dag_id] = pickle_id

        if len(self.dag_ids) > 0:
            dags = [
                dag for dag in dagbag.dags.values() if
                dag.dag_id in self.dag_ids and dag.dag_id not in paused_dag_ids
            ]
        else:
            dags = [
                dag for dag in dagbag.dags.values()
                if not dag.parent_dag and dag.dag_id not in paused_dag_ids
            ]

        # Not using multiprocessing.Queue() since it's no longer a separate
        # process and due to some unusual behavior. (empty() incorrectly
        # returns true as described in https://bugs.python.org/issue23582 )
        self.log.info("dags %s", dags)
        dag_run_out = []
        self._process_dags_and_create_dagruns(dagbag, dags, dag_run_out)
        self.log.info("dag run out %s", len(dag_run_out))
        simple_dag_runs = []
        for dag_run in dag_run_out:
            simple_dag_runs.append(
                SimpleDagRun(dag_run.id, SimpleDag(dag_run.dag)))
        # commit batch
        session.commit()

        # Record import errors into the ORM
        try:
            self.update_import_errors(session, dagbag)
        except Exception:
            self.log.exception("Error logging import errors!")
        try:
            dagbag.kill_zombies(zombies)
        except Exception:
            self.log.exception("Error killing zombies!")

        return simple_dag_runs, len(dagbag.import_errors)
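
A minimal usage sketch (not part of the original example), assuming the class above and the Airflow/AI Flow modules it imports are available; the DAG folder path is illustrative:

if __name__ == '__main__':
    # Hypothetical launcher: EventSchedulerJob ultimately inherits BaseJob,
    # whose run() method calls the _execute() defined above.
    job = EventSchedulerJob(
        subdir='/path/to/dags',   # illustrative DAG folder
        use_local_nf=True,        # start an embedded notification master
        nf_port=50051)
    job.run()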
Example #4
class TestEventBasedScheduler(unittest.TestCase):
    def setUp(self):
        db.clear_db_jobs()
        db.clear_db_dags()
        db.clear_db_serialized_dags()
        db.clear_db_runs()
        db.clear_db_task_execution()
        db.clear_db_message()
        self.scheduler = None
        self.port = 50102
        self.storage = MemoryEventStorage()
        self.master = NotificationMaster(NotificationService(self.storage),
                                         self.port)
        self.master.run()
        self.client = NotificationClient(server_uri="localhost:{}".format(
            self.port),
                                         default_namespace="test_namespace")
        time.sleep(1)

    def tearDown(self):
        self.master.stop()

    def _get_task_instance(self, dag_id, task_id, session):
        return session.query(TaskInstance).filter(
            TaskInstance.dag_id == dag_id,
            TaskInstance.task_id == task_id).first()

    def schedule_task_function(self):
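        # Drive the scheduler under test: once both task instances are SCHEDULED,
        # send a 'start' event; when both are RUNNING, send 'stop' and 'restart'
        # events and wait until sleep_1000_secs is KILLED while python_sleep keeps
        # running, then shut the scheduler down.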
        stopped = False
        while not stopped:
            with create_session() as session:
                ti_sleep_1000_secs = self._get_task_instance(
                    EVENT_BASED_SCHEDULER_DAG, 'sleep_1000_secs', session)
                ti_python_sleep = self._get_task_instance(
                    EVENT_BASED_SCHEDULER_DAG, 'python_sleep', session)
                if ti_sleep_1000_secs and ti_sleep_1000_secs.state == State.SCHEDULED and \
                   ti_python_sleep and ti_python_sleep.state == State.SCHEDULED:
                    self.client.send_event(
                        BaseEvent(key='start',
                                  value='',
                                  event_type='',
                                  namespace='test_namespace'))

                    while not stopped:
                        ti_sleep_1000_secs.refresh_from_db()
                        ti_python_sleep.refresh_from_db()
                        if ti_sleep_1000_secs and ti_sleep_1000_secs.state == State.RUNNING and \
                           ti_python_sleep and ti_python_sleep.state == State.RUNNING:
                            time.sleep(10)
                            break
                        else:
                            time.sleep(1)
                    self.client.send_event(
                        BaseEvent(key='stop',
                                  value='',
                                  event_type=UNDEFINED_EVENT_TYPE,
                                  namespace='test_namespace'))
                    self.client.send_event(
                        BaseEvent(key='restart',
                                  value='',
                                  event_type=UNDEFINED_EVENT_TYPE,
                                  namespace='test_namespace'))
                    while not stopped:
                        ti_sleep_1000_secs.refresh_from_db()
                        ti_python_sleep.refresh_from_db()
                        if ti_sleep_1000_secs and ti_sleep_1000_secs.state == State.KILLED and \
                           ti_python_sleep and ti_python_sleep.state == State.RUNNING:
                            stopped = True
                        else:
                            time.sleep(1)
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_event_based_scheduler(self):
        t = threading.Thread(target=self.schedule_task_function)
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_event_based_scheduler.py')

    def test_replay_message(self):
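        # The event is persisted as a message for scheduling job id 1234; a different
        # job id sees no unprocessed messages, while job 1234 can replay exactly the
        # event that was sent.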
        key = "stop"
        mailbox = Mailbox()
        mailbox.set_scheduling_job_id(1234)
        watcher = SchedulerEventWatcher(mailbox)
        self.client.start_listen_events(watcher=watcher,
                                        start_time=int(time.time() * 1000),
                                        version=None)
        self.send_event(key)
        msg: BaseEvent = mailbox.get_message()
        self.assertEqual(msg.key, key)
        with create_session() as session:
            msg_from_db = session.query(Message).first()
            expect_non_unprocessed = EventBasedScheduler.get_unprocessed_message(
                1000)
            self.assertEqual(0, len(expect_non_unprocessed))
            unprocessed = EventBasedScheduler.get_unprocessed_message(1234)
            self.assertEqual(unprocessed[0].serialized_message,
                             msg_from_db.data)
        deserialized_data = pickle.loads(msg_from_db.data)
        self.assertEqual(deserialized_data.key, key)
        self.assertEqual(msg, deserialized_data)

    def send_event(self, key):
        event = self.client.send_event(
            BaseEvent(key=key, event_type=UNDEFINED_EVENT_TYPE,
                      value="value1"))
        self.assertEqual(key, event.key)

    @provide_session
    def get_task_execution(self, dag_id, task_id, session):
        return session.query(TaskExecution).filter(
            TaskExecution.dag_id == dag_id,
            TaskExecution.task_id == task_id).all()

    @provide_session
    def get_latest_job_id(self, session):
        return session.query(BaseJob).order_by(sqlalchemy.desc(
            BaseJob.id)).first().id

    def start_scheduler(self, file_path):
        self.scheduler = EventBasedSchedulerJob(
            dag_directory=file_path,
            server_uri="localhost:{}".format(self.port),
            executor=LocalExecutor(3),
            max_runs=-1,
            refresh_dag_dir_interval=30)
        print("scheduler starting")
        self.scheduler.run()

    def wait_for_running(self):
        while True:
            if self.scheduler is not None:
                time.sleep(5)
                break
            else:
                time.sleep(1)

    def wait_for_task_execution(self, dag_id, task_id, expected_num):
        result = False
        check_nums = 100
        while check_nums > 0:
            time.sleep(2)
            check_nums = check_nums - 1
            tes = self.get_task_execution(dag_id, task_id)
            if len(tes) == expected_num:
                result = True
                break
        self.assertTrue(result)

    def wait_for_task(self, dag_id, task_id, expected_state):
        result = False
        check_nums = 100
        while check_nums > 0:
            time.sleep(2)
            check_nums = check_nums - 1
            with create_session() as session:
                ti = session.query(TaskInstance).filter(
                    TaskInstance.dag_id == dag_id,
                    TaskInstance.task_id == task_id).first()
            if ti and ti.state == expected_state:
                result = True
                break
        self.assertTrue(result)

    def test_notification(self):
        self.client.send_event(BaseEvent(key='a', value='b'))

    def run_a_task_function(self):
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'single',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    break
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_a_task(self):
        t = threading.Thread(target=self.run_a_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_single_task_dag.py')
        tes: List[TaskExecution] = self.get_task_execution("single", "task_1")
        self.assertEqual(len(tes), 1)

    def run_event_task_function(self):
        client = NotificationClient(server_uri="localhost:{}".format(
            self.port),
                                    default_namespace="")
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'event_dag',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    time.sleep(5)
                    client.send_event(
                        BaseEvent(key='start',
                                  value='',
                                  event_type='',
                                  namespace=''))
                    while True:
                        with create_session() as session_2:
                            tes_2 = session_2.query(TaskExecution).filter(
                                TaskExecution.dag_id == 'event_dag',
                                TaskExecution.task_id == 'task_2').all()
                            if len(tes_2) > 0:
                                break
                            else:
                                time.sleep(1)
                    break
                else:
                    time.sleep(1)
        client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_event_task(self):
        t = threading.Thread(target=self.run_event_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_event_task_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "event_dag", "task_2")
        self.assertEqual(len(tes), 1)

    def run_trigger_dag_function(self):
        ns_client = NotificationClient(server_uri="localhost:{}".format(
            self.port),
                                       default_namespace="")
        client = EventSchedulerClient(ns_client=ns_client)
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'trigger_dag',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    break
                else:
                    client.trigger_parse_dag()
                    result = client.schedule_dag('trigger_dag')
                    print('result {}'.format(result.dagrun_id))
                time.sleep(5)
        ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_trigger_dag(self):
        import multiprocessing
        p = multiprocessing.Process(target=self.run_trigger_dag_function,
                                    args=())
        p.start()
        self.start_scheduler('../../dags/test_run_trigger_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "trigger_dag", "task_1")
        self.assertEqual(len(tes), 1)

    def run_no_dag_file_function(self):
        ns_client = NotificationClient(server_uri="localhost:{}".format(
            self.port),
                                       default_namespace="")
        client = EventSchedulerClient(ns_client=ns_client)
        with create_session() as session:
            client.trigger_parse_dag()
            result = client.schedule_dag('no_dag')
            print('result {}'.format(result.dagrun_id))
            time.sleep(5)
        ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_no_dag_file_trigger_dag(self):
        import multiprocessing
        p = multiprocessing.Process(target=self.run_no_dag_file_function,
                                    args=())
        p.start()
        self.start_scheduler('../../dags/test_run_trigger_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "trigger_dag", "task_1")
        self.assertEqual(len(tes), 0)

    def run_trigger_task_function(self):
        # wait until the dag file has been parsed
        time.sleep(5)
        ns_client = NotificationClient(server_uri="localhost:{}".format(
            self.port),
                                       default_namespace="a")
        client = EventSchedulerClient(ns_client=ns_client)
        execution_context = client.schedule_dag('trigger_task')
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'trigger_task',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 0:
                    client.schedule_task('trigger_task', 'task_2',
                                         SchedulingAction.START,
                                         execution_context)
                    while True:
                        with create_session() as session_2:
                            tes_2 = session_2.query(TaskExecution).filter(
                                TaskExecution.dag_id == 'trigger_task',
                                TaskExecution.task_id == 'task_2').all()
                            if len(tes_2) > 0:
                                break
                            else:
                                time.sleep(1)
                    break
                else:
                    time.sleep(1)
        ns_client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_task_trigger_dag(self):
        import threading
        t = threading.Thread(target=self.run_trigger_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_task_trigger_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "trigger_task", "task_2")
        self.assertEqual(len(tes), 1)

    def run_ai_flow_function(self):
        client = NotificationClient(server_uri="localhost:{}".format(
            self.port),
                                    default_namespace="default",
                                    sender='1-job-name')
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'workflow_1',
                    TaskExecution.task_id == '1-job-name').all()
                if len(tes) > 0:
                    time.sleep(5)
                    client.send_event(
                        BaseEvent(key='key_1',
                                  value='value_1',
                                  event_type='UNDEFINED'))
                    client.send_event(
                        BaseEvent(key='key_2',
                                  value='value_2',
                                  event_type='UNDEFINED'))
                    while True:
                        with create_session() as session_2:
                            tes_2 = session_2.query(TaskExecution).filter(
                                TaskExecution.dag_id == 'workflow_1').all()
                            if len(tes_2) == 3:
                                break
                            else:
                                time.sleep(1)
                    break
                else:
                    time.sleep(1)
        time.sleep(3)
        client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_ai_flow_dag(self):
        import threading
        t = threading.Thread(target=self.run_ai_flow_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_aiflow_dag.py')
        tes: List[TaskExecution] = self.get_task_execution(
            "workflow_1", "1-job-name")
        self.assertEqual(len(tes), 1)

    def stop_dag_function(self):
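        # Wait until the 'sleep_to_be_stopped' task has a task execution, send a
        # StopDagEvent, wait for that task to be KILLED, then stop the scheduler.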
        stopped = False
        while not stopped:
            tes = self.get_task_execution(EVENT_BASED_SCHEDULER_DAG,
                                          'sleep_to_be_stopped')
            if tes and len(tes) == 1:
                self.client.send_event(
                    StopDagEvent(EVENT_BASED_SCHEDULER_DAG).to_event())
                while not stopped:
                    tes2 = self.get_task_execution(EVENT_BASED_SCHEDULER_DAG,
                                                   'sleep_to_be_stopped')
                    if tes2[0].state == State.KILLED:
                        stopped = True
                        time.sleep(5)
                    else:
                        time.sleep(1)
            else:
                time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_stop_dag(self):
        t = threading.Thread(target=self.stop_dag_function)
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_event_based_scheduler.py')
        with create_session() as session:
            from airflow.models import DagModel
            dag_model: DagModel = DagModel.get_dagmodel(
                EVENT_BASED_SCHEDULER_DAG)
            self.assertTrue(dag_model.is_paused)
            self.assertEqual(dag_model.get_last_dagrun().state, "killed")
            for ti in session.query(TaskInstance).filter(
                    TaskInstance.dag_id == EVENT_BASED_SCHEDULER_DAG):
                self.assertTrue(ti.state in [State.SUCCESS, State.KILLED])
            for te in session.query(TaskExecution).filter(
                    TaskExecution.dag_id == EVENT_BASED_SCHEDULER_DAG):
                self.assertTrue(te.state in [State.SUCCESS, State.KILLED])

    def run_periodic_task_function(self):
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'single',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) > 1:
                    break
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_periodic_task(self):
        t = threading.Thread(target=self.run_periodic_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_periodic_task_dag.py')
        tes: List[TaskExecution] = self.get_task_execution("single", "task_1")
        self.assertGreater(len(tes), 1)

    def run_one_task_function(self):
        self.wait_for_running()
        self.client.send_event(BaseEvent(key='a', value='a'))
        time.sleep(5)
        self.client.send_event(BaseEvent(key='a', value='a'))
        while True:
            with create_session() as session:
                tes = session.query(TaskExecution).filter(
                    TaskExecution.dag_id == 'single',
                    TaskExecution.task_id == 'task_1').all()
                if len(tes) >= 2:
                    break
                else:
                    time.sleep(1)
        self.client.send_event(StopSchedulerEvent(job_id=0).to_event())

    def test_run_one_task(self):
        t = threading.Thread(target=self.run_one_task_function, args=())
        t.setDaemon(True)
        t.start()
        self.start_scheduler('../../dags/test_multiple_trigger_task_dag.py')
Example #5
class TestHighAvailableAIFlowServer(unittest.TestCase):
    @staticmethod
    def start_aiflow_server(host, port):
        port = str(port)
        server_uri = host + ":" + port
        server = AIFlowServer(store_uri=_SQLITE_DB_URI,
                              port=port,
                              enabled_ha=True,
                              start_scheduler_service=False,
                              ha_server_uri=server_uri,
                              notification_uri='localhost:30031',
                              start_default_notification=False)
        server.run()
        return server

    def wait_for_new_members_detected(self, new_member_uri):
        while True:
            living_member = self.client.living_aiflow_members
            if new_member_uri in living_member:
                break
            else:
                time.sleep(1)

    def setUp(self) -> None:
        SqlAlchemyStore(_SQLITE_DB_URI)
        self.notification = NotificationMaster(
            service=NotificationService(storage=MemoryEventStorage()),
            port=30031)
        self.notification.run()
        self.server1 = AIFlowServer(store_uri=_SQLITE_DB_URI,
                                    port=50051,
                                    enabled_ha=True,
                                    start_scheduler_service=False,
                                    ha_server_uri='localhost:50051',
                                    notification_uri='localhost:30031',
                                    start_default_notification=False)
        self.server1.run()
        self.server2 = None
        self.server3 = None
        self.config = ProjectConfig()
        self.config.set_enable_ha(True)
        self.config.set_notification_service_uri('localhost:30031')
        self.client = AIFlowClient(
            server_uri='localhost:50052,localhost:50051',
            project_config=self.config)

    def tearDown(self) -> None:
        self.client.stop_listen_event()
        self.client.disable_high_availability()
        if self.server1 is not None:
            self.server1.stop()
        if self.server2 is not None:
            self.server2.stop()
        if self.server3 is not None:
            self.server3.stop()
        if self.notification is not None:
            self.notification.stop()
        store = SqlAlchemyStore(_SQLITE_DB_URI)
        base.metadata.drop_all(store.db_engine)

    def test_server_change(self) -> None:
        self.client.register_project("test_project")
        projects = self.client.list_project(10, 0)
        self.assertEqual(self.client.current_aiflow_uri, "localhost:50051")
        self.assertEqual(projects[0].name, "test_project")

        self.server2 = self.start_aiflow_server("localhost", 50052)
        self.wait_for_new_members_detected("localhost:50052")
        self.server1.stop()
        projects = self.client.list_project(10, 0)
        self.assertEqual(self.client.current_aiflow_uri, "localhost:50052")
        self.assertEqual(projects[0].name, "test_project")

        self.server3 = self.start_aiflow_server("localhost", 50053)
        self.wait_for_new_members_detected("localhost:50053")
        self.server2.stop()
        projects = self.client.list_project(10, 0)
        self.assertEqual(self.client.current_aiflow_uri, "localhost:50053")
        self.assertEqual(projects[0].name, "test_project")