Example #1
    def __init__(
        self,
        subdir: str = settings.DAGS_FOLDER,
        num_runs: int = conf.getint('scheduler', 'num_runs'),
        num_times_parse_dags: int = -1,
        processor_poll_interval: float = conf.getfloat('scheduler', 'processor_poll_interval'),
        do_pickle: bool = False,
        log: logging.Logger = None,
        *args,
        **kwargs,
    ):
        self.subdir = subdir

        self.num_runs = num_runs
        # In specific tests, we want to stop the parse loop after the _files_ have been parsed a certain
        # number of times. This is only to support testing, and isn't something a user is likely to want to
        # configure -- they'll want num_runs
        self.num_times_parse_dags = num_times_parse_dags
        self._processor_poll_interval = processor_poll_interval

        self.do_pickle = do_pickle
        super().__init__(*args, **kwargs)

        if log:
            self._log = log

        # Check what SQL backend we use
        sql_conn: str = conf.get('core', 'sql_alchemy_conn').lower()
        self.using_sqlite = sql_conn.startswith('sqlite')
        self.using_mysql = sql_conn.startswith('mysql')

        self.max_tis_per_query: int = conf.getint('scheduler', 'max_tis_per_query')
        self.processor_agent: Optional[DagFileProcessorAgent] = None

        self.dagbag = DagBag(dag_folder=self.subdir, read_dags_from_db=True, load_op_links=False)
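For context, a job built with this constructor is normally started through the run() method inherited from BaseJob. A minimal sketch, assuming Airflow 2.x with an initialized metadata database:

from airflow.jobs.scheduler_job import SchedulerJob

# Run a single scheduling loop over the default DAGs folder, then exit (sketch only).
job = SchedulerJob(num_runs=1)
job.run()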
Example #2
def test_dags_integrity():
    dag_bag = DagBag()
    # Assert that all DAGs can be imported, i.e., all required DAG parameters
    # are specified and no task cycles are present.
    assert dag_bag.import_errors == {}

    for dag_id in dag_bag.dag_ids:
        dag = dag_bag.get_dag(dag_id=dag_id)
        # Assert that each DAG has at least one task.
        assert dag is not None
        assert len(dag.tasks) > 0
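A related pattern, sketched here under the assumption that pytest is available and the DAGs live in the default DAGS_FOLDER, parametrizes the same checks so each DAG's failure is reported individually:

import pytest
from airflow.models import DagBag

dag_bag = DagBag(include_examples=False)

@pytest.mark.parametrize("dag_id", dag_bag.dag_ids)
def test_dag_loads(dag_id):
    # Import errors are collected bag-wide; asserting per DAG keeps failures visible.
    assert dag_bag.import_errors == {}
    dag = dag_bag.get_dag(dag_id)
    assert dag is not None
    assert len(dag.tasks) > 0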
Example #3
    def start(self, ds, **kwargs):
        conf = kwargs['dag_run'].conf

        AIRFLOW_API = 'http://airflow-service.flow.svc:8080/'
        r = requests.get(f'{AIRFLOW_API}flow/kaapana/api/getdags')
        db_dags = []
        for key, value in r.json().items():
            db_dags.append(value['dag_id'])
        print('db', db_dags)

        airflow_home = os.environ.get('AIRFLOW_HOME')
        dagbag = DagBag(os.path.join(airflow_home, 'dags'))

        file_dags = []
        for key, dag in dagbag.dags.items():
            file_dags.append(dag.dag_id)
        print('file_dags', file_dags)

        dags_to_delete = [item for item in db_dags if item not in file_dags]
        print(dags_to_delete)
        for dag_id in dags_to_delete:
            print('Deleting', dag_id)
            r = requests.delete(
                f'{AIRFLOW_API}flow/api/experimental/dags/{dag_id}')
            print(r.status_code)
            print(r.text)

        return
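The snippet above deletes orphaned DAGs through Airflow's experimental API. On Airflow 2.2+ the stable REST API offers an equivalent endpoint; a hedged sketch (base URL and credentials are assumptions for illustration):

import requests

AIRFLOW_API = 'http://airflow-service.flow.svc:8080/flow'  # assumed base URL
dags_to_delete = ['some_removed_dag']  # placeholder; computed as in the snippet above

for dag_id in dags_to_delete:
    # DELETE /api/v1/dags/{dag_id} removes the DAG's metadata from the DB (Airflow >= 2.2).
    r = requests.delete(f'{AIRFLOW_API}/api/v1/dags/{dag_id}',
                        auth=('user', 'password'))  # hypothetical credentials
    print(dag_id, r.status_code)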
Example #4
    def __init__(self,
                 id,
                 mailbox: Mailbox,
                 task_event_manager: DagRunEventManager,
                 executor: BaseExecutor,
                 notification_client: NotificationClient,
                 context=None):
        super().__init__(context)
        self.id = id
        self.mailbox = mailbox
        self.task_event_manager: DagRunEventManager = task_event_manager
        self.executor = executor
        self.notification_client = notification_client
        self.dagbag = DagBag(read_dags_from_db=True)
        self._timer_handler = None
        self.timers = sched.scheduler()
Example #5
    def list(self):
        """
        DVC Pushes view displays information about commits generated by the DVC operators
        """
        operators: List[AnyDVCOperator] = []
        for dag in DagBag().dags.values():
            for task in dag.tasks:
                if (isinstance(task, DVCDownloadOperator)
                        or isinstance(task, DVCUpdateOperator)
                        or isinstance(task, DVCUpdateSensor)):
                    setattr(task, "dag", dag)
                    operators.append(task)

        repos: Dict[str, List[AnyDVCOperator]] = defaultdict(list)
        for operator in operators:
            repos[operator.dvc_repo].append(operator)

        all_commits: List[DVCCommit] = []
        for repo in repos.keys():
            hook = DVCHook(repo)
            all_commits += hook.list_dag_commits()
        for commit in all_commits:
            repo_url_info = parse_git_url(commit.dvc_repo)
            target_name = f"{repo_url_info.owner}/{repo_url_info.repo}"
            commit.dvc_repo_name = target_name

        return self.render_template(
            "dvc/pushes.html",
            all_commits=all_commits,
        )
Example #6
    def execute_callbacks(self,
                          dagbag: DagBag,
                          callback_requests: List[CallbackRequest],
                          session: Session = NEW_SESSION) -> None:
        """
        Execute on failure callbacks. These objects can come from SchedulerJob or from
        DagFileProcessorManager.

        :param dagbag: Dag Bag of dags
        :param callback_requests: failure callbacks to execute
        :param session: DB session.
        """
        for request in callback_requests:
            self.log.debug("Processing Callback Request: %s", request)
            try:
                if isinstance(request, TaskCallbackRequest):
                    self._execute_task_callbacks(dagbag, request)
                elif isinstance(request, SlaCallbackRequest):
                    self.manage_slas(dagbag.get_dag(request.dag_id),
                                     session=session)
                elif isinstance(request, DagCallbackRequest):
                    self._execute_dag_callbacks(dagbag, request, session)
            except Exception:
                self.log.exception(
                    "Error executing %s callback for file: %s",
                    request.__class__.__name__,
                    request.full_filepath,
                )

        session.commit()
Example #7
def make_dagster_repo_from_airflow_dags_path(
    dag_path,
    repo_name,
    safe_mode=True,
    store_serialized_dags=False,
    use_airflow_template_context=False,
):
    ''' Construct a Dagster repository corresponding to Airflow DAGs in dag_path.

    DagBag.get_dag() dependency requires Airflow DB to be initialized.

    Usage:

        Create `make_dagster_repo.py`:
            from dagster_airflow.dagster_pipeline_factory import make_dagster_repo_from_airflow_dags_path

            def make_repo_from_dir():
                return make_dagster_repo_from_airflow_dags_path(
                    '/path/to/dags/', 'my_repo_name'
                )
        Use RepositoryDefinition as usual, for example:
            `dagit -f path/to/make_dagster_repo.py -n make_repo_from_dir`

    Args:
        dag_path (str): Path to directory or file that contains Airflow Dags
        repo_name (str): Name for generated RepositoryDefinition
        safe_mode (bool): True to use Airflow's default heuristic to find files that contain DAGs
            (ie find files that contain both b'DAG' and b'airflow') (default: True)
        store_serialized_dags (bool): True to read Airflow DAGS from Airflow DB. False to read DAGS
            from Python files. (default: False)
        use_airflow_template_context (bool): If True, will call get_template_context() on the
            Airflow TaskInstance model which requires and modifies the DagRun table.
            (default: False)

    Returns:
        RepositoryDefinition
    '''
    check.str_param(dag_path, 'dag_path')
    check.str_param(repo_name, 'repo_name')
    check.bool_param(safe_mode, 'safe_mode')
    check.bool_param(store_serialized_dags, 'store_serialized_dags')
    check.bool_param(use_airflow_template_context,
                     'use_airflow_template_context')

    try:
        dag_bag = DagBag(
            dag_folder=dag_path,
            include_examples=False,  # Exclude Airflow example dags
            safe_mode=safe_mode,
            store_serialized_dags=store_serialized_dags,
        )
    except Exception:  # pylint: disable=broad-except
        raise DagsterAirflowError(
            'Error initializing airflow.models.dagbag object with arguments')

    return make_dagster_repo_from_airflow_dag_bag(
        dag_bag, repo_name, use_airflow_template_context)
Example #8
    def _check_for_existence(self, session) -> None:
        dag_to_wait = session.query(DagModel).filter(DagModel.dag_id == self.external_dag_id).first()

        if not dag_to_wait:
            raise AirflowException(f'The external DAG {self.external_dag_id} does not exist.')

        if not os.path.exists(dag_to_wait.fileloc):
            raise AirflowException(f'The external DAG {self.external_dag_id} was deleted.')

        if self.external_task_ids:
            refreshed_dag_info = DagBag(dag_to_wait.fileloc).get_dag(self.external_dag_id)
            for external_task_id in self.external_task_ids:
                if not refreshed_dag_info.has_task(external_task_id):
                    raise AirflowException(
                        f'The external task {external_task_id} in '
                        f'DAG {self.external_dag_id} does not exist.'
                    )
        self._has_checked_existence = True
Example #9
    def test_heartbeat_failed_fast(self):
        """
        Test that task heartbeat will sleep when it fails fast
        """
        self.mock_base_job_sleep.side_effect = time.sleep

        with create_session() as session:
            dagbag = DagBag(
                dag_folder=TEST_DAG_FOLDER,
                include_examples=False,
            )
            dag_id = 'test_heartbeat_failed_fast'
            task_id = 'test_heartbeat_failed_fast_op'
            dag = dagbag.get_dag(dag_id)
            task = dag.get_task(task_id)

            dag.create_dagrun(
                run_id="test_heartbeat_failed_fast_run",
                state=State.RUNNING,
                execution_date=DEFAULT_DATE,
                start_date=DEFAULT_DATE,
                session=session,
            )
            ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
            ti.refresh_from_db()
            ti.state = State.RUNNING
            ti.hostname = get_hostname()
            ti.pid = 1
            session.commit()

            job = LocalTaskJob(task_instance=ti,
                               executor=MockExecutor(do_update=False))
            job.heartrate = 2
            heartbeat_records = []
            job.heartbeat_callback = lambda session: heartbeat_records.append(
                job.latest_heartbeat)
            job._execute()
            self.assertGreater(len(heartbeat_records), 2)
            for i in range(1, len(heartbeat_records)):
                time1 = heartbeat_records[i - 1]
                time2 = heartbeat_records[i]
                # Assert that difference small enough
                delta = (time2 - time1).total_seconds()
                self.assertAlmostEqual(delta, job.heartrate, delta=0.05)
Example #10
    def run_dag(self,
                dag_id: str,
                dag_folder: str = DEFAULT_DAG_FOLDER) -> None:
        """
        Runs an example DAG by its ID.

        :param dag_id: id of a DAG to be run
        :type dag_id: str
        :param dag_folder: directory where to look for the specific DAG. Relative to AIRFLOW_HOME.
        :type dag_folder: str
        """
        if os.environ.get("RUN_AIRFLOW_1_10") == "true":
            # For system tests purpose we are changing airflow/providers
            # to side packages path of the installed providers package
            python = f"python{sys.version_info.major}.{sys.version_info.minor}"
            dag_folder = dag_folder.replace(
                "/opt/airflow/airflow/providers",
                f"/usr/local/lib/{python}/site-packages/airflow/providers",
            )
        self.log.info("Looking for DAG: %s in %s", dag_id, dag_folder)
        dag_bag = DagBag(dag_folder=dag_folder, include_examples=False)
        dag = dag_bag.get_dag(dag_id)
        if dag is None:
            raise AirflowException(
                "The Dag {dag_id} could not be found. It's either an import problem,"
                "wrong dag_id or DAG is not in provided dag_folder."
                "The content of the {dag_folder} folder is {content}".format(
                    dag_id=dag_id,
                    dag_folder=dag_folder,
                    content=os.listdir(dag_folder),
                ))

        self.log.info("Attempting to run DAG: %s", dag_id)
        if os.environ.get("RUN_AIRFLOW_1_10") == "true":
            dag.clear()
        else:
            dag.clear(dag_run_state=State.NONE)
        try:
            dag.run(ignore_first_depends_on_past=True, verbose=True)
        except Exception:
            self._print_all_log_files()
            raise
Example #11
    def run_dag(self,
                dag_id: str,
                dag_folder: str = DEFAULT_DAG_FOLDER) -> None:
        """
        Runs an example DAG by its ID.

        :param dag_id: id of a DAG to be run
        :type dag_id: str
        :param dag_folder: directory where to look for the specific DAG. Relative to AIRFLOW_HOME.
        :type dag_folder: str
        """
        if os.environ.get("RUN_AIRFLOW_1_10"):
            # For system tests purpose we are mounting airflow/providers to /providers folder
            # So that we can get example_dags from there
            dag_folder = dag_folder.replace("/opt/airflow/airflow/providers",
                                            "/providers")
            temp_dir = mkdtemp()
            os.rmdir(temp_dir)
            shutil.copytree(dag_folder, temp_dir)
            dag_folder = temp_dir
            self.correct_imports_for_airflow_1_10(temp_dir)
        self.log.info("Looking for DAG: %s in %s", dag_id, dag_folder)
        dag_bag = DagBag(dag_folder=dag_folder, include_examples=False)
        dag = dag_bag.get_dag(dag_id)
        if dag is None:
            raise AirflowException(
                "The Dag {dag_id} could not be found. It's either an import problem,"
                "wrong dag_id or DAG is not in provided dag_folder."
                "The content of the {dag_folder} folder is {content}".format(
                    dag_id=dag_id,
                    dag_folder=dag_folder,
                    content=os.listdir(dag_folder),
                ))

        self.log.info("Attempting to run DAG: %s", dag_id)
        dag.clear(reset_dag_runs=True)
        try:
            dag.run(ignore_first_depends_on_past=True, verbose=True)
        except Exception:
            self._print_all_log_files()
            raise
Example #12
    def test_localtaskjob_maintain_heart_rate(self):
        dagbag = DagBag(
            dag_folder=TEST_DAG_FOLDER,
            include_examples=False,
        )
        dag = dagbag.dags.get('test_localtaskjob_double_trigger')
        task = dag.get_task('test_localtaskjob_double_trigger_task')

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(
            run_id="test",
            state=State.SUCCESS,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )

        ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti_run.refresh_from_db()
        job1 = LocalTaskJob(task_instance=ti_run,
                            executor=SequentialExecutor())

        # this should make sure we only heartbeat once and exit at the second
        # loop in _execute()
        return_codes = [None, 0]

        def multi_return_code():
            return return_codes.pop(0)

        time_start = time.time()
        from airflow.task.task_runner.standard_task_runner import StandardTaskRunner

        with patch.object(StandardTaskRunner, 'start',
                          return_value=None) as mock_start:
            with patch.object(StandardTaskRunner,
                              'return_code') as mock_ret_code:
                mock_ret_code.side_effect = multi_return_code
                job1.run()
                self.assertEqual(mock_start.call_count, 1)
                self.assertEqual(mock_ret_code.call_count, 2)
        time_end = time.time()

        self.assertEqual(self.mock_base_job_sleep.call_count, 1)
        self.assertEqual(job1.state, State.SUCCESS)

        # Since we have patched the sleep call, it should not be sleeping to
        # keep up with the heart rate in other, unpatched places
        #
        # We already made sure the patched sleep call is only called once
        self.assertLess(time_end - time_start, job1.heartrate)
        session.close()
Example #13
    def _list_dags(self):
        dagbag = DagBag()
        dags = []

        for dag_id in dagbag.dags:
            orm_dag = DagModel.get_current(dag_id)
            # inactive DAGs can't be backfilled....
            is_active = (
                not orm_dag.is_paused) if orm_dag is not None else False

            if is_active:
                dags.append(dag_id)

        return dags
Example #14
def make_dagster_repo_from_airflow_example_dags(
        repo_name='airflow_example_dags_repo'):
    ''' Construct a Dagster repository for Airflow's example DAGs.

    Execution of the following Airflow example DAGs is not currently supported:
            'example_external_task_marker_child',
            'example_pig_operator',
            'example_skip_dag',
            'example_trigger_target_dag',
            'example_xcom',
            'test_utils',

    Usage:

        Create `make_dagster_repo.py`:
            from dagster_airflow.dagster_pipeline_factory import make_dagster_repo_from_airflow_example_dags

            def make_airflow_example_dags():
                return make_dagster_repo_from_airflow_example_dags()

        Use RepositoryDefinition as usual, for example:
            `dagit -f path/to/make_dagster_repo.py -n make_airflow_example_dags`

    Args:
        repo_name (str): Name for generated RepositoryDefinition

    Returns:
        RepositoryDefinition
    '''
    dag_bag = DagBag(
        dag_folder='some/empty/folder/with/no/dags',  # prevent defaulting to settings.DAGS_FOLDER
        include_examples=True,
    )

    # There is a bug in Airflow v1.10.8, v1.10.9, v1.10.10 where the python_callable for task
    # 'search_catalog' is missing a required positional argument '_'. It is currently fixed in master.
    # v1.10 stable: https://github.com/apache/airflow/blob/v1-10-stable/airflow/example_dags/example_complex.py#L133
    # master (05-05-2020): https://github.com/apache/airflow/blob/master/airflow/example_dags/example_complex.py#L136
    patch_airflow_example_dag(dag_bag)

    return make_dagster_repo_from_airflow_dag_bag(dag_bag, repo_name)
Example #15
    def test_mark_success_no_kill(self):
        """
        Test that ensures that mark_success in the UI doesn't cause
        the task to fail, and that the task exits
        """
        dagbag = DagBag(
            dag_folder=TEST_DAG_FOLDER,
            include_examples=False,
        )
        dag = dagbag.dags.get('test_mark_success')
        task = dag.get_task('task1')

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(
            run_id="test",
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()
        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
        process = multiprocessing.Process(target=job1.run)
        process.start()
        ti.refresh_from_db()
        for _ in range(0, 50):
            if ti.state == State.RUNNING:
                break
            time.sleep(0.1)
            ti.refresh_from_db()
        self.assertEqual(State.RUNNING, ti.state)
        ti.state = State.SUCCESS
        session.merge(ti)
        session.commit()

        process.join(timeout=10)
        self.assertFalse(process.is_alive())
        ti.refresh_from_db()
        self.assertEqual(State.SUCCESS, ti.state)
Example #16
    def list_dag_commits(
        self,
        temp_path: Optional[str] = None,
    ) -> List[DVCCommit]:
        """
        Returns list of all commits generated for the given DVC repository.

        :param temp_path: Optional temporary clone path
        :returns: List with commits generated by the DVC operators
        """
        _, temp_dir, repo, _ = clone_repo(self.dvc_repo, temp_path)
        commits = list(repo.iter_commits(max_count=100))

        results: List[DVCCommit] = []
        for commit in commits:
            message_footer = commit.message.split("\n")[-1].split(" ")
            if len(message_footer) == 2 and message_footer[0] == "dag:":
                # Find DAGs responsible for this commit
                commit_dags = [
                    dag for dag in DagBag().dags.values()
                    if dag.dag_id == message_footer[1]
                ]
                if len(commit_dags) > 0:
                    results.append(
                        DVCCommit(
                            dvc_repo=self.dvc_repo,
                            dvc_repo_name=self.dvc_repo,
                            files=[
                                file_path.replace(".dvc", "")
                                for file_path in commit.stats.files.keys()
                                if ".dvc" in file_path
                            ],
                            message="\n".join(commit.message.split("\n")[:-1]),
                            date=datetime.datetime.fromtimestamp(
                                commit.committed_date),
                            dag=commit_dags[0],
                            sha=commit.hexsha,
                            commit_url=f"{self.dvc_repo}/commits/{commit.hexsha}".replace(".git", ""),
                        ))
        return results
Example #17
    def test_localtaskjob_double_trigger(self):
        dagbag = DagBag(
            dag_folder=TEST_DAG_FOLDER,
            include_examples=False,
        )
        dag = dagbag.dags.get('test_localtaskjob_double_trigger')
        task = dag.get_task('test_localtaskjob_double_trigger_task')

        session = settings.Session()

        dag.clear()
        dr = dag.create_dagrun(
            run_id="test",
            state=State.SUCCESS,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = dr.get_task_instance(task_id=task.task_id, session=session)
        ti.state = State.RUNNING
        ti.hostname = get_hostname()
        ti.pid = 1
        session.merge(ti)
        session.commit()

        ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti_run.refresh_from_db()
        job1 = LocalTaskJob(task_instance=ti_run,
                            executor=SequentialExecutor())
        from airflow.task.task_runner.standard_task_runner import StandardTaskRunner

        with patch.object(StandardTaskRunner, 'start',
                          return_value=None) as mock_method:
            job1.run()
            mock_method.assert_not_called()

        ti = dr.get_task_instance(task_id=task.task_id, session=session)
        self.assertEqual(ti.pid, 1)
        self.assertEqual(ti.state, State.RUNNING)

        session.close()
Example #18
    def setup_attrs(self, configured_app, session) -> None:
        self.default_time = datetime(2020, 1, 1)

        clear_db_runs()
        clear_db_xcom()

        self.app = configured_app

        self.dag = self._create_dag()

        self.app.dag_bag = DagBag(os.devnull, include_examples=False)
        self.app.dag_bag.dags = {self.dag.dag_id: self.dag}  # type: ignore  # pylint: disable=no-member
        self.app.dag_bag.sync_to_db()  # type: ignore  # pylint: disable=no-member

        dr = DagRun(
            dag_id=self.dag.dag_id,
            run_id="TEST_DAG_RUN_ID",
            execution_date=self.default_time,
            run_type=DagRunType.MANUAL,
        )
        session.add(dr)
        session.commit()

        self.client = self.app.test_client()  # type:ignore
Example #19
    def process_file(
        self,
        file_path: str,
        callback_requests: List[CallbackRequest],
        pickle_dags: bool = False,
        session: Session = NEW_SESSION,
    ) -> Tuple[int, int]:
        """
        Process a Python file containing Airflow DAGs.

        This includes:

        1. Execute the file and look for DAG objects in the namespace.
        2. Execute any Callbacks if passed to this method.
        3. Serialize the DAGs and save them to the DB (or update the existing records in the DB).
        4. Pickle the DAGs and save them to the DB (if necessary).
        5. Mark any DAGs which are no longer present as inactive.
        6. Record any errors importing the file into the ORM.

        :param file_path: the path to the Python file that should be executed
        :param callback_requests: failure callback to execute
        :param pickle_dags: whether to serialize the DAGs found in the file and
            save them to the DB
        :param session: Sqlalchemy ORM Session
        :return: number of dags found, count of import errors
        :rtype: Tuple[int, int]
        """
        self.log.info("Processing file %s for tasks to queue", file_path)

        try:
            dagbag = DagBag(file_path, include_examples=False)
        except Exception:
            self.log.exception("Failed at reloading the DAG file %s",
                               file_path)
            Stats.incr('dag_file_refresh_error', 1, 1)
            return 0, 0

        if len(dagbag.dags) > 0:
            self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(),
                          file_path)
        else:
            self.log.warning("No viable dags retrieved from %s", file_path)
            self.update_import_errors(session, dagbag)
            if callback_requests:
                # If there were callback requests for this file but there was a
                # parse error, we still need to progress the state of TIs,
                # otherwise they might be stuck in queued/running forever!
                self.execute_callbacks_without_dag(callback_requests, session)
            return 0, len(dagbag.import_errors)

        self.execute_callbacks(dagbag, callback_requests, session)
        session.commit()

        # Save individual DAGs in the ORM
        dagbag.sync_to_db(session)
        session.commit()

        if pickle_dags:
            paused_dag_ids = DagModel.get_paused_dag_ids(
                dag_ids=dagbag.dag_ids)

            unpaused_dags: List[DAG] = [
                dag for dag_id, dag in dagbag.dags.items()
                if dag_id not in paused_dag_ids
            ]

            for dag in unpaused_dags:
                dag.pickle(session)

        # Record import errors into the ORM
        try:
            self.update_import_errors(session, dagbag)
        except Exception:
            self.log.exception("Error logging import errors!")

        # Record DAG warnings in the metadatabase.
        try:
            self.update_dag_warnings(session=session, dagbag=dagbag)
        except Exception:
            self.log.exception("Error logging DAG warnings.")

        return len(dagbag.dags), len(dagbag.import_errors)
Example #20
def main(num_runs, repeat, pre_create_dag_runs, executor_class, dag_ids):
    """
    This script can be used to measure the total "scheduler overhead" of Airflow.

    By overhead we mean: if every task executed instantly as soon as it was
    scheduled (i.e. it does nothing), how quickly could we schedule them?

    It will monitor the task completion of the Mock/stub executor (no actual
    tasks are run) and after the required number of dag runs for all the
    specified dags have completed all their tasks, it will cleanly shut down
    the scheduler.

    The dags you run with need to have an early enough start_date to create the
    desired number of runs.

    Care should be taken that other limits (DAG concurrency, pool size etc) are
    not the bottleneck. This script doesn't help you in that regard.

    It is recommended to repeat the test at least 3 times (`--repeat=3`, the
    default) so that you can get somewhat-accurate variance on the reported
    timing numbers, but this can be disabled for longer runs if needed.
    """

    # Turn on unit test mode so that we don't do any sleep() in the scheduler
    # loop - not needed on master, but this script can run against older
    # releases too!
    os.environ['AIRFLOW__CORE__UNIT_TEST_MODE'] = 'True'

    os.environ['AIRFLOW__CORE__DAG_CONCURRENCY'] = '500'

    # Set this so that dags can dynamically configure their end_date
    os.environ['AIRFLOW_BENCHMARK_MAX_DAG_RUNS'] = str(num_runs)
    os.environ['PERF_MAX_RUNS'] = str(num_runs)

    if pre_create_dag_runs:
        os.environ['AIRFLOW__SCHEDULER__USE_JOB_SCHEDULE'] = 'False'

    from airflow.jobs.scheduler_job import SchedulerJob
    from airflow.models.dagbag import DagBag
    from airflow.utils import db

    dagbag = DagBag()

    dags = []

    with db.create_session() as session:
        pause_all_dags(session)
        for dag_id in dag_ids:
            dag = dagbag.get_dag(dag_id)
            dag.sync_to_db(session=session)
            dags.append(dag)
            reset_dag(dag, session)

            next_run_date = dag.normalize_schedule(dag.start_date
                                                   or min(t.start_date
                                                          for t in dag.tasks))

            for _ in range(num_runs - 1):
                next_run_date = dag.following_schedule(next_run_date)

            end_date = dag.end_date or dag.default_args.get('end_date')
            if end_date != next_run_date:
                message = (
                    f"DAG {dag_id} has incorrect end_date ({end_date}) for number of runs! "
                    f"It should be "
                    f" {next_run_date}")
                sys.exit(message)

            if pre_create_dag_runs:
                create_dag_runs(dag, num_runs, session)

    ShortCircuitExecutor = get_executor_under_test(executor_class)

    executor = ShortCircuitExecutor(dag_ids_to_watch=dag_ids,
                                    num_runs=num_runs)
    scheduler_job = SchedulerJob(dag_ids=dag_ids,
                                 do_pickle=False,
                                 executor=executor)
    executor.scheduler_job = scheduler_job

    total_tasks = sum(len(dag.tasks) for dag in dags)

    if 'PYSPY' in os.environ:
        pid = str(os.getpid())
        filename = os.environ.get('PYSPY_O', 'flame-' + pid + '.html')
        os.spawnlp(os.P_NOWAIT, 'sudo', 'sudo', 'py-spy', 'record', '-o',
                   filename, '-p', pid, '--idle')

    times = []

    # Need a lambda to refer to the _latest_ value for scheduler_job, not just
    # the initial one
    code_to_test = lambda: scheduler_job.run()  # pylint: disable=unnecessary-lambda

    for count in range(repeat):
        gc.disable()
        start = time.perf_counter()

        code_to_test()
        times.append(time.perf_counter() - start)
        gc.enable()
        print("Run %d time: %.5f" % (count + 1, times[-1]))

        if count + 1 != repeat:
            with db.create_session() as session:
                for dag in dags:
                    reset_dag(dag, session)

            executor.reset(dag_ids)
            scheduler_job = SchedulerJob(dag_ids=dag_ids,
                                         do_pickle=False,
                                         executor=executor)
            executor.scheduler_job = scheduler_job

    print()
    print()
    msg = "Time for %d dag runs of %d dags with %d total tasks: %.4fs"

    if len(times) > 1:
        print((msg + " (±%.3fs)") %
              (num_runs, len(dags), total_tasks, statistics.mean(times),
               statistics.stdev(times)))
    else:
        print(msg % (num_runs, len(dags), total_tasks, times[0]))

    print()
    print()
Example #21
def get_dags():
    dagbag = DagBag()
    return dagbag.dags
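A possible usage sketch for this helper: iterate the returned mapping of dag_id to DAG objects (names below reuse the function defined above).

for dag_id, dag in get_dags().items():
    # Print each DAG's id, task count, and schedule for a quick overview.
    print(dag_id, len(dag.tasks), dag.schedule_interval)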
Example #22
class EventBasedScheduler(LoggingMixin):
    def __init__(self,
                 id,
                 mailbox: Mailbox,
                 task_event_manager: DagRunEventManager,
                 executor: BaseExecutor,
                 notification_client: NotificationClient,
                 context=None):
        super().__init__(context)
        self.id = id
        self.mailbox = mailbox
        self.task_event_manager: DagRunEventManager = task_event_manager
        self.executor = executor
        self.notification_client = notification_client
        self.dagbag = DagBag(read_dags_from_db=True)
        self._timer_handler = None
        self.timers = sched.scheduler()

    def sync(self):
        def call_regular_interval(
            delay: float,
            action: Callable,
            arguments=(),
            kwargs={},
        ):  # pylint: disable=dangerous-default-value
            def repeat(*args, **kwargs):
                action(*args, **kwargs)
                # This is not perfect. If we want a timer every 60s, but action
                # takes 10s to run, this will run it every 70s.
                # Good enough for now
                self._timer_handler = self.timers.enter(
                    delay, 1, repeat, args, kwargs)

            self._timer_handler = self.timers.enter(delay, 1, repeat,
                                                    arguments, kwargs)

        call_regular_interval(delay=1.0, action=self.executor.sync)
        self.timers.run()
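    # Editorial note (an assumption, not part of the original class): because `repeat`
    # re-enters itself only after `action` finishes, the effective period is
    # delay + runtime of action, as the comment above concedes. A drift-free variant
    # would schedule against absolute times, e.g. keep a `next_deadline` and call
    # self.timers.enterabs(next_deadline, 1, repeat, args, kwargs), advancing
    # next_deadline by `delay` on every iteration.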

    def _stop_timer(self):
        if self.timers and self._timer_handler:
            self.timers.cancel(self._timer_handler)

    def submit_sync_thread(self):
        threading.Thread(target=self.sync).start()

    def schedule(self):
        self.log.info("Starting the scheduler.")
        self._restore_unfinished_dag_run()
        while True:
            identified_message = self.mailbox.get_identified_message()
            origin_event = identified_message.deserialize()
            self.log.debug("Event: {}".format(origin_event))
            if SchedulerInnerEventUtil.is_inner_event(origin_event):
                event = SchedulerInnerEventUtil.to_inner_event(origin_event)
            else:
                event = origin_event
            with create_session() as session:
                if isinstance(event, BaseEvent):
                    dagruns = self._find_dagruns_by_event(event, session)
                    for dagrun in dagruns:
                        dag_run_id = DagRunId(dagrun.dag_id, dagrun.run_id)
                        self.task_event_manager.handle_event(dag_run_id, event)
                elif isinstance(event, RequestEvent):
                    self._process_request_event(event)
                elif isinstance(event, ResponseEvent):
                    continue
                elif isinstance(event, TaskSchedulingEvent):
                    self._schedule_task(event)
                elif isinstance(event, TaskStatusChangedEvent):
                    dagrun = self._find_dagrun(event.dag_id,
                                               event.execution_date, session)
                    tasks = self._find_schedulable_tasks(dagrun, session)
                    self._send_scheduling_task_events(tasks,
                                                      SchedulingAction.START)
                elif isinstance(event, DagExecutableEvent):
                    dagrun = self._create_dag_run(event.dag_id,
                                                  session=session)
                    tasks = self._find_schedulable_tasks(dagrun, session)
                    self._send_scheduling_task_events(tasks,
                                                      SchedulingAction.START)
                elif isinstance(event, EventHandleEvent):
                    dag_runs = DagRun.find(dag_id=event.dag_id,
                                           run_id=event.dag_run_id)
                    assert len(dag_runs) == 1
                    ti = dag_runs[0].get_task_instance(event.task_id)
                    self._send_scheduling_task_event(ti, event.action)
                elif isinstance(event, StopDagEvent):
                    self._stop_dag(event.dag_id, session)
                elif isinstance(event, ParseDagRequestEvent) or isinstance(
                        event, ParseDagResponseEvent):
                    pass
                elif isinstance(event, StopSchedulerEvent):
                    self.log.info("{} {}".format(self.id, event.job_id))
                    if self.id == event.job_id or 0 == event.job_id:
                        self.log.info("break the scheduler event loop.")
                        identified_message.remove_handled_message()
                        session.expunge_all()
                        break
                else:
                    self.log.error(
                        "can not handler the event {}".format(event))
                identified_message.remove_handled_message()
                session.expunge_all()
        self._stop_timer()

    def stop(self) -> None:
        self.mailbox.send_message(StopSchedulerEvent(self.id).to_event())
        self.log.info("Send stop event to the scheduler.")

    def recover(self, last_scheduling_id):
        self.log.info("Waiting for executor recovery...")
        self.executor.recover_state()
        unprocessed_messages = self.get_unprocessed_message(last_scheduling_id)
        self.log.info(
            "Recovering %s messages of last scheduler job with id: %s",
            len(unprocessed_messages), last_scheduling_id)
        for msg in unprocessed_messages:
            self.mailbox.send_identified_message(msg)

    @staticmethod
    def get_unprocessed_message(
            last_scheduling_id: int) -> List[IdentifiedMessage]:
        with create_session() as session:
            results: List[MSG] = session.query(MSG).filter(
                MSG.scheduling_job_id == last_scheduling_id,
                MSG.state == MessageState.QUEUED).order_by(asc(MSG.id)).all()
        unprocessed: List[IdentifiedMessage] = []
        for msg in results:
            unprocessed.append(IdentifiedMessage(msg.data, msg.id))
        return unprocessed

    def _find_dagrun(self, dag_id, execution_date, session) -> DagRun:
        dagrun = session.query(DagRun).filter(
            DagRun.dag_id == dag_id,
            DagRun.execution_date == execution_date).first()
        return dagrun

    def _create_dag_run(self,
                        dag_id,
                        session,
                        run_type=DagRunType.SCHEDULED) -> DagRun:
        with prohibit_commit(session) as guard:
            if settings.USE_JOB_SCHEDULE:
                """
                Unconditionally create a DAG run for the given DAG, and update the dag_model's fields to control
                if/when the next DAGRun should be created
                """
                try:
                    dag = self.dagbag.get_dag(dag_id, session=session)
                    dag_model = session \
                        .query(DagModel).filter(DagModel.dag_id == dag_id).first()
                    if dag_model is None:
                        return None
                    next_dagrun = dag_model.next_dagrun
                    dag_hash = self.dagbag.dags_hash.get(dag.dag_id)
                    run_id = None
                    if run_type == DagRunType.MANUAL:
                        run_id = f"{run_type}__{timezone.utcnow().isoformat()}"
                    dag_run = dag.create_dagrun(
                        run_type=run_type,
                        execution_date=next_dagrun,
                        run_id=run_id,
                        start_date=timezone.utcnow(),
                        state=State.RUNNING,
                        external_trigger=False,
                        session=session,
                        dag_hash=dag_hash,
                        creating_job_id=self.id,
                    )
                    if run_type == DagRunType.SCHEDULED:
                        self._update_dag_next_dagrun(dag_id, session)

                    # commit the session - Release the write lock on DagModel table.
                    guard.commit()
                    # END: create dagrun
                    return dag_run
                except SerializedDagNotFound:
                    self.log.exception(
                        "DAG '%s' not found in serialized_dag table", dag_id)
                    return None
                except Exception:
                    self.log.exception(
                        "Error occurred when create dag_run of dag: %s",
                        dag_id)

    def _update_dag_next_dagrun(self, dag_id, session):
        """
                Bulk update the next_dagrun and next_dagrun_create_after for all the dags.

                We batch the select queries to get info about all the dags at once
                """
        active_runs_of_dag = session \
            .query(func.count('*')).filter(
            DagRun.dag_id == dag_id,
            DagRun.state == State.RUNNING,
            DagRun.external_trigger.is_(False),
        ).scalar()
        dag_model = session \
            .query(DagModel).filter(DagModel.dag_id == dag_id).first()

        dag = self.dagbag.get_dag(dag_id, session=session)
        if dag.max_active_runs and active_runs_of_dag >= dag.max_active_runs:
            self.log.info(
                "DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs",
                dag.dag_id,
                active_runs_of_dag,
                dag.max_active_runs,
            )
            dag_model.next_dagrun_create_after = None
        else:
            dag_model.next_dagrun, dag_model.next_dagrun_create_after = dag.next_dagrun_info(
                dag_model.next_dagrun)

    def _schedule_task(self, scheduling_event: TaskSchedulingEvent):
        task_key = TaskInstanceKey(scheduling_event.dag_id,
                                   scheduling_event.task_id,
                                   scheduling_event.execution_date,
                                   scheduling_event.try_number)
        self.executor.schedule_task(task_key, scheduling_event.action)

    def _find_dagruns_by_event(self, event, session) -> Optional[List[DagRun]]:
        affect_dag_runs = []
        event_key = EventKey(event.key, event.event_type, event.namespace)
        dag_runs = session \
            .query(DagRun).filter(DagRun.state == State.RUNNING).all()
        self.log.debug('dag_runs {}'.format(len(dag_runs)))

        if dag_runs is None or len(dag_runs) == 0:
            return affect_dag_runs
        dags = session.query(SerializedDagModel).filter(
            SerializedDagModel.dag_id.in_(dag_run.dag_id
                                          for dag_run in dag_runs)).all()
        self.log.debug('dags {}'.format(len(dags)))

        affect_dags = set()
        for dag in dags:
            self.log.debug('dag config {}'.format(dag.event_relationships))
            self.log.debug('event key {} {} {}'.format(event.key,
                                                       event.event_type,
                                                       event.namespace))

            dep: DagEventDependencies = DagEventDependencies.from_json(
                dag.event_relationships)
            if dep.is_affect(event_key):
                affect_dags.add(dag.dag_id)
        if len(affect_dags) == 0:
            return affect_dag_runs
        for dag_run in dag_runs:
            if dag_run.dag_id in affect_dags:
                affect_dag_runs.append(dag_run)
        return affect_dag_runs

    def _find_schedulable_tasks(
            self,
            dag_run: DagRun,
            session: Session,
            check_execution_date=False) -> Optional[List[TI]]:
        """
        Make scheduling decisions about an individual dag run.

        :param dag_run: The DagRun to schedule
        :return: scheduled tasks
        """
        if not dag_run or dag_run.get_state() in State.finished:
            return
        try:
            dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id,
                                                    session=session)
        except SerializedDagNotFound:
            self.log.exception("DAG '%s' not found in serialized_dag table",
                               dag_run.dag_id)
            return None

        if not dag:
            self.log.error("Couldn't find dag %s in DagBag/DB!",
                           dag_run.dag_id)
            return None

        currently_active_runs = session.query(TI.execution_date).filter(
            TI.dag_id == dag_run.dag_id,
            TI.state.notin_(list(State.finished)),
        ).all()

        if check_execution_date and dag_run.execution_date > timezone.utcnow(
        ) and not dag.allow_future_exec_dates:
            self.log.warning("Execution date is in future: %s",
                             dag_run.execution_date)
            return None

        if dag.max_active_runs:
            if (len(currently_active_runs) >= dag.max_active_runs
                    and dag_run.execution_date not in currently_active_runs):
                self.log.info(
                    "DAG %s already has %d active runs, not queuing any tasks for run %s",
                    dag.dag_id,
                    len(currently_active_runs),
                    dag_run.execution_date,
                )
                return None

        self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session)

        schedulable_tis, callback_to_run = dag_run.update_state(
            session=session, execute_callbacks=False)
        dag_run.schedule_tis(schedulable_tis, session)

        query = (session.query(TI).outerjoin(TI.dag_run).filter(
            or_(DR.run_id.is_(None),
                DR.run_type != DagRunType.BACKFILL_JOB)).join(
                    TI.dag_model).filter(not_(DM.is_paused)).filter(
                        TI.state == State.SCHEDULED).options(
                            selectinload('dag_model')))
        scheduled_tis: List[TI] = with_row_locks(
            query,
            of=TI,
            **skip_locked(session=session),
        ).all()
        # filter need event tasks
        serialized_dag = session.query(SerializedDagModel).filter(
            SerializedDagModel.dag_id == dag_run.dag_id).first()
        dep: DagEventDependencies = DagEventDependencies.from_json(
            serialized_dag.event_relationships)
        event_task_set = dep.find_event_dependencies_tasks()
        final_scheduled_tis = []
        for ti in scheduled_tis:
            if ti.task_id not in event_task_set:
                final_scheduled_tis.append(ti)

        return final_scheduled_tis

    @provide_session
    def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session=None):
        """Only run DagRun.verify integrity if Serialized DAG has changed since it is slow"""
        latest_version = SerializedDagModel.get_latest_version_hash(
            dag_run.dag_id, session=session)
        if dag_run.dag_hash == latest_version:
            self.log.debug(
                "DAG %s not changed structure, skipping dagrun.verify_integrity",
                dag_run.dag_id)
            return

        dag_run.dag_hash = latest_version

        # Refresh the DAG
        dag_run.dag = self.dagbag.get_dag(dag_id=dag_run.dag_id,
                                          session=session)

        # Verify integrity also takes care of session.flush
        dag_run.verify_integrity(session=session)

    def _send_scheduling_task_event(self, ti: Optional[TI],
                                    action: SchedulingAction):
        if ti is None:
            return
        task_scheduling_event = TaskSchedulingEvent(ti.task_id, ti.dag_id,
                                                    ti.execution_date,
                                                    ti.try_number, action)
        self.mailbox.send_message(task_scheduling_event.to_event())

    def _send_scheduling_task_events(self, tis: Optional[List[TI]],
                                     action: SchedulingAction):
        if tis is None:
            return
        for ti in tis:
            self._send_scheduling_task_event(ti, action)

    @provide_session
    def _emit_pool_metrics(self, session: Session = None) -> None:
        pools = models.Pool.slots_stats(session=session)
        for pool_name, slot_stats in pools.items():
            Stats.gauge(f'pool.open_slots.{pool_name}', slot_stats["open"])
            Stats.gauge(f'pool.queued_slots.{pool_name}',
                        slot_stats[State.QUEUED])
            Stats.gauge(f'pool.running_slots.{pool_name}',
                        slot_stats[State.RUNNING])

    @staticmethod
    def _reset_unfinished_task_state(dag_run):
        with create_session() as session:
            to_be_reset = [
                s for s in State.unfinished
                if s not in [State.RUNNING, State.QUEUED]
            ]
            tis = dag_run.get_task_instances(to_be_reset, session)
            for ti in tis:
                ti.state = State.NONE
            session.commit()

    @provide_session
    def _restore_unfinished_dag_run(self, session):
        dag_runs = DagRun.next_dagruns_to_examine(
            session, max_number=sys.maxsize).all()
        if not dag_runs or len(dag_runs) == 0:
            return
        for dag_run in dag_runs:
            self._reset_unfinished_task_state(dag_run)
            tasks = self._find_schedulable_tasks(dag_run, session)
            self._send_scheduling_task_events(tasks, SchedulingAction.START)

    @provide_session
    def heartbeat_callback(self, session: Session = None) -> None:
        Stats.incr('scheduler_heartbeat', 1, 1)

    @provide_session
    def _process_request_event(self,
                               event: RequestEvent,
                               session: Session = None):
        try:
            message = BaseUserDefineMessage()
            message.from_json(event.body)
            if message.message_type == UserDefineMessageType.RUN_DAG:
                # todo make sure dag file is parsed.
                dagrun = self._create_dag_run(message.dag_id,
                                              session=session,
                                              run_type=DagRunType.MANUAL)
                if not dagrun:
                    self.log.error("Failed to create dag_run.")
                    # TODO Need to add ret_code and error_msg in ExecutionContext in case of exception
                    self.notification_client.send_event(
                        ResponseEvent(event.request_id, None).to_event())
                    return
                tasks = self._find_schedulable_tasks(dagrun, session, False)
                self._send_scheduling_task_events(tasks,
                                                  SchedulingAction.START)
                self.notification_client.send_event(
                    ResponseEvent(event.request_id, dagrun.run_id).to_event())
            elif message.message_type == UserDefineMessageType.STOP_DAG_RUN:
                dag_run = DagRun.get_run_by_id(session=session,
                                               dag_id=message.dag_id,
                                               run_id=message.dagrun_id)
                self._stop_dag_run(dag_run)
                self.notification_client.send_event(
                    ResponseEvent(event.request_id, dag_run.run_id).to_event())
            elif message.message_type == UserDefineMessageType.EXECUTE_TASK:
                dagrun = DagRun.get_run_by_id(session=session,
                                              dag_id=message.dag_id,
                                              run_id=message.dagrun_id)
                ti: TI = dagrun.get_task_instance(task_id=message.task_id)
                self.mailbox.send_message(
                    TaskSchedulingEvent(task_id=ti.task_id,
                                        dag_id=ti.dag_id,
                                        execution_date=ti.execution_date,
                                        try_number=ti.try_number,
                                        action=SchedulingAction(
                                            message.action)).to_event())
                self.notification_client.send_event(
                    ResponseEvent(event.request_id, dagrun.run_id).to_event())
        except Exception:
            self.log.exception("Error occurred when processing request event.")

    def _stop_dag(self, dag_id, session: Session):
        """
        Stop the dag. Pause the dag and cancel all running dag_runs and task_instances.
        """
        DagModel.get_dagmodel(dag_id, session)\
            .set_is_paused(is_paused=True, including_subdags=True, session=session)
        active_runs = DagRun.find(dag_id=dag_id, state=State.RUNNING)
        for dag_run in active_runs:
            self._stop_dag_run(dag_run)

    def _stop_dag_run(self, dag_run: DagRun):
        dag_run.stop_dag_run()
        for ti in dag_run.get_task_instances():
            if ti.state in State.unfinished:
                self.executor.schedule_task(ti.key, SchedulingAction.STOP)
Example #23
    def setUpClass(cls):
        cls.dagbag = DagBag(include_examples=True)
        cls.parser = cli_parser.get_parser()
Example #24
import logging
import json
from airflow.models.dagbag import DagBag

logging.basicConfig(level=logging.ERROR)

if __name__ == "__main__":
    dagbag = DagBag("/opt/airflow/dags")

    # todo - use graphviz dot thing ?
    out = {
        "graph": {
            d.dag_id: {
                "tasks": {
                    t.task_id: {
                        "downstream": [tt.task_id for tt in t.downstream_list],
                        "upstream": [tt.task_id for tt in t.upstream_list],
                    }
                    for t in d.tasks
                },
                "roots": [tt.task_id for tt in d.roots],
            },
        }
        for dag_id, d in dagbag.dags.items()
    }

    print(json.dumps(out))
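For illustration only (the DAG and task ids below are hypothetical), the printed JSON has roughly this shape:

# {
#   "graph": {
#     "my_dag": {
#       "tasks": {
#         "extract": {"downstream": ["load"], "upstream": []},
#         "load": {"downstream": [], "upstream": ["extract"]}
#       },
#       "roots": ["extract"]
#     }
#   }
# }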
Example #25
    def execute(self, context: Context):
        if isinstance(self.execution_date, datetime.datetime):
            parsed_execution_date = self.execution_date
        elif isinstance(self.execution_date, str):
            parsed_execution_date = timezone.parse(self.execution_date)
        else:
            parsed_execution_date = timezone.utcnow()

        if self.trigger_run_id:
            run_id = self.trigger_run_id
        else:
            run_id = DagRun.generate_run_id(DagRunType.MANUAL,
                                            parsed_execution_date)
        try:
            dag_run = trigger_dag(
                dag_id=self.trigger_dag_id,
                run_id=run_id,
                conf=self.conf,
                execution_date=parsed_execution_date,
                replace_microseconds=False,
            )

        except DagRunAlreadyExists as e:
            if self.reset_dag_run:
                self.log.info("Clearing %s on %s", self.trigger_dag_id,
                              parsed_execution_date)

                # Get target dag object and call clear()

                dag_model = DagModel.get_current(self.trigger_dag_id)
                if dag_model is None:
                    raise DagNotFound(
                        f"Dag id {self.trigger_dag_id} not found in DagModel")

                dag_bag = DagBag(dag_folder=dag_model.fileloc,
                                 read_dags_from_db=True)
                dag = dag_bag.get_dag(self.trigger_dag_id)
                dag.clear(start_date=parsed_execution_date,
                          end_date=parsed_execution_date)
                dag_run = DagRun.find(dag_id=dag.dag_id, run_id=run_id)[0]
            else:
                raise e
        if dag_run is None:
            raise RuntimeError("The dag_run should be set here!")
        # Store the execution date from the dag run (either created or found above) to
        # be used when creating the extra link on the webserver.
        ti = context['task_instance']
        ti.xcom_push(key=XCOM_EXECUTION_DATE_ISO,
                     value=dag_run.execution_date.isoformat())
        ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id)

        if self.wait_for_completion:
            # wait for dag to complete
            while True:
                self.log.info(
                    'Waiting for %s on %s to become allowed state %s ...',
                    self.trigger_dag_id,
                    dag_run.execution_date,
                    self.allowed_states,
                )
                time.sleep(self.poke_interval)

                dag_run.refresh_from_db()
                state = dag_run.state
                if state in self.failed_states:
                    raise AirflowException(
                        f"{self.trigger_dag_id} failed with failed states {state}"
                    )
                if state in self.allowed_states:
                    self.log.info("%s finished with allowed state %s",
                                  self.trigger_dag_id, state)
                    return
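
As a usage sketch, here is how the operator that owns this execute() method, TriggerDagRunOperator, is typically wired into a controller DAG. The DAG ids, schedule, and conf values below are placeholders rather than anything taken from the snippet above, and the import path assumes Airflow 2.x:

from datetime import datetime

from airflow import DAG
from airflow.operators.trigger_dagrun import TriggerDagRunOperator

with DAG(
    dag_id="controller_dag",              # placeholder controller DAG id
    start_date=datetime(2021, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    trigger_target = TriggerDagRunOperator(
        task_id="trigger_target",
        trigger_dag_id="target_dag",      # placeholder id of the DAG to trigger
        conf={"triggered_by": "controller_dag"},
        reset_dag_run=True,               # clear and re-run if the run already exists
        wait_for_completion=True,         # block until the triggered run finishes
        poke_interval=30,                 # seconds between state checks
    )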
Example No. 26
class Dashboard(BaseView):
    template_folder = os.path.join(os.path.dirname(__file__), 'templates')

    DATETIME_FORMAT = '%m/%d/%y %I:%M %p'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.airflow_session = settings.Session()
        self.airflow_dag_bag = DagBag()

    @expose('/')
    def list(self):
        dag_info = self.get_dag_info()

        event_dags = [
            dag for dag in dag_info
            if 'event' in dag['name'] or dag['name'] == 'full_scrape'
        ]
        bill_dags = [
            dag for dag in dag_info
            if 'bill' in dag['name'] or dag['name'] == 'full_scrape'
        ]

        event_last_run = self.get_last_successful_dagrun(event_dags)
        bill_last_run = self.get_last_successful_dagrun(bill_dags)

        event_next_run = self.get_next_dagrun(event_dags)
        bill_next_run = self.get_next_dagrun(bill_dags)

        events_in_db, bills_in_db, bills_in_index = self.get_db_info()

        metadata = {
            'data': dag_info,
            'event_last_run': event_last_run,
            'event_next_run': event_next_run,
            'events_in_db': events_in_db,
            'bill_last_run': bill_last_run,
            'bill_next_run': bill_next_run,
            'bills_in_db': bills_in_db,
            'bills_in_index': bills_in_index,
            'datetime_format': self.DATETIME_FORMAT,
        }

        return self.render_template('dashboard.html', **metadata)

    def get_dag_info(self):

        dags = [
            self.airflow_dag_bag.get_dag(dag_id)
            for dag_id in self.airflow_dag_bag.dag_ids
            if not dag_id.startswith('airflow_')
        ]  # Filter meta-DAGs

        data = []

        for d in dags:
            # get_last_dagrun() here is the module-level helper from airflow.models.dag
            # (referenced as `dag`), not a method on the DAG object `d`.
            last_run = dag.get_last_dagrun(d.dag_id,
                                           self.airflow_session,
                                           include_externally_triggered=True)

            if last_run:
                run_state = last_run.get_state()
                run_date_info = self._get_localized_time(
                    last_run.execution_date)

                last_successful_info = self._get_last_successful_run_date(d)

                next_scheduled = d.following_schedule(datetime.now(pytz.utc))
                next_scheduled_info = self._get_localized_time(next_scheduled)

            else:
                run_state = None

                run_date_info = {}
                last_successful_info = {}
                next_scheduled_info = {}

            dag_info = {
                'name': d.dag_id,
                'description': d.description,
                'run_state': run_state,
                'run_date': run_date_info,
                'last_successful_date': last_successful_info,
                'next_scheduled_date': next_scheduled_info,
            }

            data.append(dag_info)

        return data

    def get_last_successful_dagrun(self, dags):
        successful_runs = [
            dag for dag in dags if dag['last_successful_date'].get('pst_time')
        ]

        if successful_runs:
            return max(successful_runs,
                       key=lambda x: x['last_successful_date']['pst_time'])

    def get_next_dagrun(self, dags):
        scheduled_runs = [
            dag for dag in dags if dag['next_scheduled_date'].get('pst_time')
        ]

        if scheduled_runs:
            return min(scheduled_runs,
                       key=lambda x: x['next_scheduled_date']['pst_time'])

    def get_db_info(self):
        url_parts = {
            'hostname': os.getenv('LA_METRO_HOST', 'http://app:8000'),
            'api_key': os.getenv('LA_METRO_API_KEY', 'test key'),
        }

        endpoint = '{hostname}/object-counts/{api_key}'.format(**url_parts)

        response = requests.get(endpoint)

        try:
            response_json = response.json()

        except json.decoder.JSONDecodeError:
            print(response.text)

        else:
            if response_json['status_code'] == 200:
                return (response_json['event_count'],
                        response_json['bill_count'],
                        response_json['search_index_count'])

        return None, None, None

    def _get_localized_time(self, date):
        pst_time = date.astimezone(PACIFIC_TIMEZONE)
        cst_time = date.astimezone(CENTRAL_TIMEZONE)

        return {
            'pst_time': pst_time,
            'cst_time': cst_time,
        }

    def _get_last_successful_run_date(self, dag):
        run = self.airflow_session.query(dagrun.DagRun)\
                                  .filter(dagrun.DagRun.dag_id == dag.dag_id)\
                                  .filter(dagrun.DagRun.state == 'success')\
                                  .order_by(dagrun.DagRun.execution_date.desc())\
                                  .first()

        if run:
            return self._get_localized_time(run.execution_date)
        else:
            return {}
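
For context, a view like the Dashboard above is normally exposed through Airflow's plugin mechanism. A minimal sketch, assuming the class is a Flask-AppBuilder BaseView (as its use of render_template suggests) and using placeholder plugin and menu names:

from airflow.plugins_manager import AirflowPlugin

class DashboardPlugin(AirflowPlugin):
    # Register the view so it shows up in the webserver menu.
    name = "dashboard_plugin"           # placeholder plugin name
    appbuilder_views = [
        {
            "name": "Dashboard",        # placeholder menu item label
            "category": "Monitoring",   # placeholder menu category
            "view": Dashboard(),
        }
    ]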
Example No. 27
def test_dags_load_without_errors() -> None:
    dag_bag = DagBag(dag_folder=f"{path.dirname(path.abspath(__file__))}/../../../dags", include_examples=False)
    assert len(dag_bag.import_errors) == 0
Example No. 28
def get_dag(dag_id: str) -> DAG:
    dag_bag = DagBag()
    dag = dag_bag.get_dag(dag_id=dag_id)
    if dag is None:
        raise KeyError(f"DAG with ID '{dag_id}' does not exist.")
    return dag
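
A possible way to exercise this helper in a test, with a hypothetical DAG id:

def test_example_dag_structure():
    # "example_etl" is a placeholder; use a DAG id that exists in your DAGs folder.
    dag = get_dag("example_etl")
    assert dag.tasks, "expected the DAG to define at least one task"
    print(sorted(task.task_id for task in dag.tasks))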
Example No. 29
    def process_file(
        self,
        file_path: str,
        callback_requests: List[CallbackRequest],
        pickle_dags: bool = False,
        session: Session = None,
    ) -> Tuple[int, int]:
        """
        Process a Python file containing Airflow DAGs.

        This includes:

        1. Execute the file and look for DAG objects in the namespace.
        2. Execute any Callbacks if passed to this method.
        3. Serialize the DAGs and save them to the DB (or update the existing records).
        4. Pickle the DAGs and save them to the DB (if necessary).
        5. Record any errors importing the file into the ORM.

        :param file_path: the path to the Python file that should be executed
        :type file_path: str
        :param callback_requests: failure callbacks to execute
        :type callback_requests: List[airflow.utils.dag_processing.CallbackRequest]
        :param pickle_dags: whether to serialize the DAGs found in the file and
            save them to the DB
        :type pickle_dags: bool
        :param session: Sqlalchemy ORM Session
        :type session: Session
        :return: number of dags found, count of import errors
        :rtype: Tuple[int, int]
        """
        self.log.info("Processing file %s for tasks to queue", file_path)

        try:
            dagbag = DagBag(file_path,
                            include_examples=False,
                            include_smart_sensor=False)
        except Exception:  # pylint: disable=broad-except
            self.log.exception("Failed at reloading the DAG file %s",
                               file_path)
            Stats.incr('dag_file_refresh_error', 1, 1)
            return 0, 0

        if len(dagbag.dags) > 0:
            self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(),
                          file_path)
        else:
            self.log.warning("No viable dags retrieved from %s", file_path)
            self.update_import_errors(session, dagbag)
            return 0, len(dagbag.import_errors)

        self.execute_callbacks(dagbag, callback_requests)

        # Save individual DAGs in the ORM
        dagbag.sync_to_db()

        if pickle_dags:
            paused_dag_ids = DagModel.get_paused_dag_ids(
                dag_ids=dagbag.dag_ids)

            unpaused_dags: List[DAG] = [
                dag for dag_id, dag in dagbag.dags.items()
                if dag_id not in paused_dag_ids
            ]

            for dag in unpaused_dags:
                dag.pickle(session)

        # Record import errors into the ORM
        try:
            self.update_import_errors(session, dagbag)
        except Exception:  # pylint: disable=broad-except
            self.log.exception("Error logging import errors!")

        return len(dagbag.dags), len(dagbag.import_errors)
Example No. 30
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.airflow_session = settings.Session()
        self.airflow_dag_bag = DagBag()