def delete_dag(dag_id: str,
               keep_records_in_log: bool = True,
               session=None) -> int:
    """
    :param dag_id: the dag_id of the DAG to delete
    :param keep_records_in_log: whether keep records of the given dag_id
        in the Log table in the backend database (for reasons like auditing).
        The default value is True.
    :param session: session used
    :return count of deleted dags
    """
    log.info("Deleting DAG: %s", dag_id)
    running_tis = (session.query(models.TaskInstance.state).filter(
        models.TaskInstance.dag_id == dag_id).filter(
            models.TaskInstance.state == State.RUNNING).first())
    if running_tis:
        raise AirflowException("TaskInstances still running")
    dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
    if dag is None:
        raise DagNotFound(f"Dag id {dag_id} not found")

    # Scheduler removes DAGs without files from serialized_dag table every dag_dir_list_interval.
    # There may be a lag, so explicitly remove the serialized DAG here.
    if SerializedDagModel.has_dag(dag_id=dag_id, session=session):
        SerializedDagModel.remove_dag(dag_id=dag_id, session=session)

    count = 0

    for model in get_sqla_model_classes():
        if hasattr(model, "dag_id"):
            if keep_records_in_log and model.__name__ == 'Log':
                continue
            cond = or_(model.dag_id == dag_id,
                       model.dag_id.like(dag_id + ".%"))
            count += session.query(model).filter(cond).delete(
                synchronize_session='fetch')
    if dag.is_subdag:
        parent_dag_id, task_id = dag_id.rsplit(".", 1)
        for model in TaskFail, models.TaskInstance:
            count += (session.query(model).filter(
                model.dag_id == parent_dag_id,
                model.task_id == task_id).delete())

    # Delete entries in Import Errors table for a deleted DAG
    # This handles the case when the dag_id is changed in the file
    session.query(models.ImportError).filter(
        models.ImportError.filename == dag.fileloc).delete(
            synchronize_session='fetch')

    return count
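
A minimal usage sketch for delete_dag above, assuming it is called with an explicit SQLAlchemy session (for example one opened via airflow.utils.session.create_session); in Airflow itself the session is normally injected by the provide_session decorator, which is not shown in the snippet. The dag_id is illustrative.

# Hedged usage sketch: delete a DAG's metadata while keeping its audit
# entries in the Log table (the default behaviour).
from airflow.utils.session import create_session

with create_session() as session:
    deleted = delete_dag("my_example_dag",  # hypothetical dag_id
                         keep_records_in_log=True,
                         session=session)
    print(f"Deleted {deleted} rows of DAG metadata")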
Example #2
    def test_should_respond_200_serialized(self):

        # Get the dag out of the dagbag before we patch it to an empty one
        SerializedDagModel.write_dag(self.app.dag_bag.get_dag(self.dag_id))

        dag_bag = DagBag(os.devnull,
                         include_examples=False,
                         read_dags_from_db=True)
        patcher = unittest.mock.patch.object(self.app, 'dag_bag', dag_bag)
        patcher.start()

        expected = {
            "class_ref": {
                "class_name": "DummyOperator",
                "module_path": "airflow.operators.dummy",
            },
            "depends_on_past": False,
            "downstream_task_ids": [self.task_id2],
            "end_date": None,
            "execution_timeout": None,
            "extra_links": [],
            "owner": "airflow",
            "pool": "default_pool",
            "pool_slots": 1.0,
            "priority_weight": 1.0,
            "queue": "default",
            "retries": 0.0,
            "retry_delay": {
                "__type": "TimeDelta",
                "days": 0,
                "seconds": 300,
                "microseconds": 0
            },
            "retry_exponential_backoff": False,
            "start_date": "2020-06-15T00:00:00+00:00",
            "task_id": "op1",
            "template_fields": [],
            "trigger_rule": "all_success",
            "ui_color": "#e8f7e4",
            "ui_fgcolor": "#000",
            "wait_for_downstream": False,
            "weight_rule": "downstream",
        }
        response = self.client.get(
            f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}",
            environ_overrides={'REMOTE_USER': "******"})
        assert response.status_code == 200
        assert response.json == expected
        patcher.stop()
    def test_dag_trigger_parse_dag(self):
        mailbox = Mailbox()
        dag_trigger = DagTrigger("../../dags/test_scheduler_dags.py", -1, [], False, mailbox)
        dag_trigger.start()

        message = mailbox.get_message()
        message = SchedulerInnerEventUtil.to_inner_event(message)
        # only one dag is executable
        assert "test_task_start_date_scheduling" == message.dag_id

        assert DagModel.get_dagmodel(dag_id="test_task_start_date_scheduling") is not None
        assert DagModel.get_dagmodel(dag_id="test_start_date_scheduling") is not None
        assert SerializedDagModel.get(dag_id="test_task_start_date_scheduling") is not None
        assert SerializedDagModel.get(dag_id="test_start_date_scheduling") is not None
        dag_trigger.end()
Example #4
    def sync_to_db(self):
        """
        Save attributes about the list of DAGs to the DB.
        """
        # To avoid circular import - airflow.models.dagbag -> airflow.models.dag -> airflow.models.dagbag
        from airflow.models.dag import DAG
        from airflow.models.serialized_dag import SerializedDagModel
        self.log.debug("Calling the DAG.bulk_sync_to_db method")
        DAG.bulk_sync_to_db(self.dags.values())
        # Write Serialized DAGs to DB if DAG Serialization is turned on
        # Even though self.read_dags_from_db is False
        if settings.STORE_SERIALIZED_DAGS:
            self.log.debug(
                "Calling the SerializedDagModel.bulk_sync_to_db method")
            SerializedDagModel.bulk_sync_to_db(self.dags.values())
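
A minimal usage sketch for sync_to_db above, assuming a DagBag built from a local DAG folder; the folder path is illustrative, and whether serialized DAGs are also written depends on the serialization setting referenced in the method.

# Hedged usage sketch: fill a DagBag from a folder and persist its DAGs
# (and, when serialization is enabled, their serialized form) to the DB.
from airflow.models.dagbag import DagBag

dag_bag = DagBag("/path/to/dags", include_examples=False)  # hypothetical path
dag_bag.sync_to_db()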
Example #5
def delete_dag(dag_id: str,
               keep_records_in_log: bool = True,
               session=None) -> int:
    """
    :param dag_id: the dag_id of the DAG to delete
    :param keep_records_in_log: whether keep records of the given dag_id
        in the Log table in the backend database (for reasons like auditing).
        The default value is True.
    :param session: session used
    :return count of deleted dags
    """
    logger = LoggingMixin()
    logger.log.info("Deleting DAG: %s", dag_id)
    dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
    if dag is None:
        raise DagNotFound("Dag id {} not found".format(dag_id))

    # Scheduler removes DAGs without files from serialized_dag table every dag_dir_list_interval.
    # There may be a lag, so explicitly remove the serialized DAG here.
    if STORE_SERIALIZED_DAGS and SerializedDagModel.has_dag(dag_id=dag_id,
                                                            session=session):
        SerializedDagModel.remove_dag(dag_id=dag_id, session=session)

    count = 0

    # noinspection PyUnresolvedReferences,PyProtectedMember
    for model in models.base.Base._decl_class_registry.values():  # pylint: disable=protected-access
        if hasattr(model, "dag_id"):
            if keep_records_in_log and model.__name__ == 'Log':
                continue
            cond = or_(model.dag_id == dag_id,
                       model.dag_id.like(dag_id + ".%"))
            count += session.query(model).filter(cond).delete(
                synchronize_session='fetch')
    if dag.is_subdag:
        parent_dag_id, task_id = dag_id.rsplit(".", 1)
        for model in models.DagRun, TaskFail, models.TaskInstance:
            count += session.query(model).filter(
                model.dag_id == parent_dag_id,
                model.task_id == task_id).delete()

    # Delete entries in Import Errors table for a deleted DAG
    # This handles the case when the dag_id is changed in the file
    session.query(models.ImportError).filter(
        models.ImportError.filename == dag.fileloc).delete(
            synchronize_session='fetch')

    return count
Example #6
    def test_should_response_200_serialized(self):
        # Create empty app with empty dagbag to check if DAG is read from db
        app_serialized = app.create_app(testing=True)
        dag_bag = DagBag(os.devnull,
                         include_examples=False,
                         read_dags_from_db=True)
        app_serialized.dag_bag = dag_bag
        client = app_serialized.test_client()

        SerializedDagModel.write_dag(self.dag)

        expected = {
            "class_ref": {
                "class_name": "DummyOperator",
                "module_path": "airflow.operators.dummy_operator",
            },
            "depends_on_past": False,
            "downstream_task_ids": [],
            "end_date": None,
            "execution_timeout": None,
            "extra_links": [],
            "owner": "airflow",
            "pool": "default_pool",
            "pool_slots": 1.0,
            "priority_weight": 1.0,
            "queue": "default",
            "retries": 0.0,
            "retry_delay": {
                "__type": "TimeDelta",
                "days": 0,
                "seconds": 300,
                "microseconds": 0
            },
            "retry_exponential_backoff": False,
            "start_date": "2020-06-15T00:00:00+00:00",
            "task_id": "op1",
            "template_fields": [],
            "trigger_rule": "all_success",
            "ui_color": "#e8f7e4",
            "ui_fgcolor": "#000",
            "wait_for_downstream": False,
            "weight_rule": "downstream",
        }
        response = client.get(
            f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}",
            environ_overrides={'REMOTE_USER': "******"})
        assert response.status_code == 200
        assert response.json == expected
Example #7
    def get_dag(self, dag_id):
        """
        Gets the DAG out of the dictionary, and refreshes it if expired

        :param dag_id: DAG Id
        :type dag_id: str
        """
        # Avoid circular import
        from airflow.models.dag import DagModel

        # Only read DAGs from DB if this dagbag is store_serialized_dags.
        if self.store_serialized_dags:
            # Import here so that serialized dag is only imported when serialization is enabled
            from airflow.models.serialized_dag import SerializedDagModel
            if dag_id not in self.dags:
                # Load from DB if not (yet) in the bag
                row = SerializedDagModel.get(dag_id)
                if not row:
                    return None

                dag = row.dag
                for subdag in dag.subdags:
                    self.dags[subdag.dag_id] = subdag
                self.dags[dag.dag_id] = dag

            return self.dags.get(dag_id)

        # If asking for a known subdag, we want to refresh the parent
        dag = None
        root_dag_id = dag_id
        if dag_id in self.dags:
            dag = self.dags[dag_id]
            if dag.is_subdag:
                root_dag_id = dag.parent_dag.dag_id

        # Needs to load from file for a store_serialized_dags dagbag.
        enforce_from_file = False
        if self.store_serialized_dags and dag is not None:
            from airflow.serialization.serialized_objects import SerializedDAG
            enforce_from_file = isinstance(dag, SerializedDAG)

        # If the dag corresponding to root_dag_id is absent or expired
        orm_dag = DagModel.get_current(root_dag_id)
        if (orm_dag and
            (root_dag_id not in self.dags or
             (orm_dag.last_expired and dag.last_loaded < orm_dag.last_expired))
            ) or enforce_from_file:
            # Reprocess source file
            found_dags = self.process_file(
                filepath=correct_maybe_zipped(orm_dag.fileloc),
                only_if_updated=False)

            # If the source file no longer exports `dag_id`, delete it from self.dags
            if found_dags and dag_id in [found_dag.dag_id for found_dag in found_dags]:
                return self.dags[dag_id]
            elif dag_id in self.dags:
                del self.dags[dag_id]
        return self.dags.get(dag_id)
Example #8
        def _serialize_dag_capturing_errors(dag, session):
            """
            Try to serialize the dag to the DB, but make a note of any errors.

            We can't place them directly in import_errors, as this may be retried, and work the next time
            """
            if dag.is_subdag:
                return []
            try:
                # We can't use bulk_write_to_db as we want to capture each error individually
                dag_was_updated = SerializedDagModel.write_dag(
                    dag,
                    min_update_interval=settings.MIN_SERIALIZED_DAG_UPDATE_INTERVAL,
                    session=session,
                )
                if dag_was_updated:
                    self.log.debug("Syncing DAG permissions: %s to the DB",
                                   dag.dag_id)
                    from airflow.www.security import ApplessAirflowSecurityManager

                    security_manager = ApplessAirflowSecurityManager(
                        session=session)
                    security_manager.sync_perm_for_dag(dag.dag_id,
                                                       dag.access_control)
                return []
            except OperationalError:
                raise
            except Exception:  # pylint: disable=broad-except
                return [(dag.fileloc,
                         traceback.format_exc(
                             limit=-self.dagbag_import_error_traceback_depth))]
Example #9
        def _serialize_dag_capturing_errors(dag, session):
            """
            Try to serialize the dag to the DB, but make a note of any errors.

            We can't place them directly in import_errors, as this may be retried, and work the next time
            """
            if dag.is_subdag:
                return []
            try:
                # We can't use bulk_write_to_db as we want to capture each error individually
                dag_was_updated = SerializedDagModel.write_dag(
                    dag,
                    min_update_interval=settings.MIN_SERIALIZED_DAG_UPDATE_INTERVAL,
                    session=session,
                )
                if dag_was_updated:
                    self._sync_perm_for_dag(dag, session=session)
                return []
            except OperationalError:
                raise
            except Exception:
                self.log.exception("Failed to write serialized DAG: %s",
                                   dag.full_filepath)
                return [(dag.fileloc,
                         traceback.format_exc(
                             limit=-self.dagbag_import_error_traceback_depth))]
Example #10
    def get_dag(self, dag_id):
        """
        Gets the DAG out of the dictionary, and refreshes it if expired

        :param dag_id: DAG Id
        :type dag_id: str
        """
        # Avoid circular import
        from airflow.models.dag import DagModel

        if self.read_dags_from_db:
            # Import here so that serialized dag is only imported when serialization is enabled
            from airflow.models.serialized_dag import SerializedDagModel
            if dag_id not in self.dags:
                # Load from DB if not (yet) in the bag
                self._add_dag_from_db(dag_id=dag_id)
                return self.dags.get(dag_id)

            # If DAG is in the DagBag, check the following
            # 1. if time has come to check if DAG is updated (controlled by min_serialized_dag_fetch_secs)
            # 2. check the last_updated column in SerializedDag table to see if Serialized DAG is updated
            # 3. if (2) is yes, fetch the Serialized DAG.
            min_serialized_dag_fetch_secs = timedelta(seconds=settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL)
            if (
                dag_id in self.dags_last_fetched and
                timezone.utcnow() > self.dags_last_fetched[dag_id] + min_serialized_dag_fetch_secs
            ):
                sd_last_updated_datetime = SerializedDagModel.get_last_updated_datetime(dag_id=dag_id)
                if sd_last_updated_datetime > self.dags_last_fetched[dag_id]:
                    self._add_dag_from_db(dag_id=dag_id)

            return self.dags.get(dag_id)

        # If asking for a known subdag, we want to refresh the parent
        dag = None
        root_dag_id = dag_id
        if dag_id in self.dags:
            dag = self.dags[dag_id]
            if dag.is_subdag:
                root_dag_id = dag.parent_dag.dag_id

        # If DAG Model is absent, we can't check last_expired property. Is the DAG not yet synchronized?
        orm_dag = DagModel.get_current(root_dag_id)
        if not orm_dag:
            return self.dags.get(dag_id)

        # If the dag corresponding to root_dag_id is absent or expired
        is_missing = root_dag_id not in self.dags
        is_expired = (orm_dag.last_expired and dag.last_loaded < orm_dag.last_expired)
        if is_missing or is_expired:
            # Reprocess source file
            found_dags = self.process_file(
                filepath=correct_maybe_zipped(orm_dag.fileloc), only_if_updated=False)

            # If the source file no longer exports `dag_id`, delete it from self.dags
            if found_dags and dag_id in [found_dag.dag_id for found_dag in found_dags]:
                return self.dags[dag_id]
            elif dag_id in self.dags:
                del self.dags[dag_id]
        return self.dags.get(dag_id)
    def test_user_defined_filter_and_macros_raise_error(
            self, get_dag_function):
        """
        Test that the Rendered View is able to show rendered values
        even for TIs that have not yet executed
        """
        get_dag_function.return_value = SerializedDagModel.get(
            self.dag.dag_id).dag

        self.assertEqual(self.task2.bash_command,
                         'echo {{ fullname("Apache", "Airflow") | hello }}')

        url = (
            '/admin/airflow/rendered?task_id=task2&dag_id=testdag&execution_date={}'
            .format(self.percent_encode(self.default_date)))

        resp = self.app.get(url, follow_redirects=True)
        self.assertNotIn("echo Hello Apache Airflow",
                         resp.data.decode('utf-8'))

        if six.PY3:
            self.assertIn(
                "Webserver does not have access to User-defined Macros or Filters "
                "when Dag Serialization is enabled. Hence for the task that have not yet "
                "started running, please use &#39;airflow tasks render&#39; for debugging the "
                "rendering of template_fields.<br/><br/>OriginalError: no filter named &#39;hello&#39",
                resp.data.decode('utf-8'))
        else:
            self.assertIn(
                "Webserver does not have access to User-defined Macros or Filters "
                "when Dag Serialization is enabled. Hence for the task that have not yet "
                "started running, please use &#39;airflow tasks render&#39; for debugging the "
                "rendering of template_fields.", resp.data.decode('utf-8'))
Example #12
def get_dag_by_deserialization(dag_id: str) -> "DAG":
    from airflow.models.serialized_dag import SerializedDagModel

    dag_model = SerializedDagModel.get(dag_id)
    if dag_model is None:
        raise AirflowException(f"Serialized DAG: {dag_id} could not be found")

    return dag_model.dag
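
A minimal usage sketch for get_dag_by_deserialization above, assuming the DAG has already been written to the serialized_dag table (for example via SerializedDagModel.write_dag); the dag_id is illustrative.

# Hedged usage sketch: load a DAG purely from its serialized representation.
from airflow.exceptions import AirflowException

try:
    dag = get_dag_by_deserialization("example_bash_operator")
    print(dag.dag_id, sorted(dag.task_dict))
except AirflowException as err:
    print(f"DAG is not serialized yet: {err}")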
    def test_remove_dags_by_filepath(self):
        """DAGs can be removed from database."""
        example_dags_list = list(self._write_example_dags().values())
        # Remove SubDags from the list as they are not stored in DB in a separate row
        # and are directly added in Json blob of the main DAG
        filtered_example_dags_list = [
            dag for dag in example_dags_list if not dag.is_subdag
        ]
        # Tests removing by file path.
        dag_removed_by_file = filtered_example_dags_list[0]
        # remove repeated files for those DAGs that define multiple dags in the same file (set comprehension)
        example_dag_files = list(
            {dag.full_filepath
             for dag in filtered_example_dags_list})
        example_dag_files.remove(dag_removed_by_file.full_filepath)
        SDM.remove_deleted_dags(example_dag_files)
        self.assertFalse(SDM.has_dag(dag_removed_by_file.dag_id))
Example #14
    def get_dag(self, dag_id):
        """
        Gets the DAG out of the dictionary, and refreshes it if expired

        :param dag_id: DAG Id
        :type dag_id: str
        """
        # Avoid circular import
        from airflow.models.dag import DagModel

        # Only read DAGs from DB if this dagbag is read_dags_from_db.
        if self.read_dags_from_db:
            # Import here so that serialized dag is only imported when serialization is enabled
            from airflow.models.serialized_dag import SerializedDagModel
            if dag_id not in self.dags:
                # Load from DB if not (yet) in the bag
                row = SerializedDagModel.get(dag_id)
                if not row:
                    return None

                dag = row.dag
                for subdag in dag.subdags:
                    self.dags[subdag.dag_id] = subdag
                self.dags[dag.dag_id] = dag

            return self.dags.get(dag_id)

        # If asking for a known subdag, we want to refresh the parent
        dag = None
        root_dag_id = dag_id
        if dag_id in self.dags:
            dag = self.dags[dag_id]
            if dag.is_subdag:
                root_dag_id = dag.parent_dag.dag_id

        # If DAG Model is absent, we can't check last_expired property. Is the DAG not yet synchronized?
        orm_dag = DagModel.get_current(root_dag_id)
        if not orm_dag:
            return self.dags.get(dag_id)

        # If the dag corresponding to root_dag_id is absent or expired
        is_missing = root_dag_id not in self.dags
        is_expired = (orm_dag.last_expired
                      and dag.last_loaded < orm_dag.last_expired)
        if is_missing or is_expired:
            # Reprocess source file
            found_dags = self.process_file(
                filepath=correct_maybe_zipped(orm_dag.fileloc),
                only_if_updated=False)

            # If the source file no longer exports `dag_id`, delete it from self.dags
            if found_dags and dag_id in [found_dag.dag_id for found_dag in found_dags]:
                return self.dags[dag_id]
            elif dag_id in self.dags:
                del self.dags[dag_id]
        return self.dags.get(dag_id)
Example #15
    def _deactivate_stale_dags(self, session=None):
        """
        Detect DAGs which are no longer present in files.

        Deactivate them and remove them from the serialized_dag table.
        """
        now = timezone.utcnow()
        elapsed_time_since_refresh = (
            now - self.last_deactivate_stale_dags_time).total_seconds()
        if elapsed_time_since_refresh > self.deactivate_stale_dags_interval:
            last_parsed = {
                fp: self.get_last_finish_time(fp)
                for fp in self.file_paths if self.get_last_finish_time(fp)
            }
            to_deactivate = set()
            dags_parsed = (session.query(DagModel.dag_id, DagModel.fileloc,
                                         DagModel.last_parsed_time).filter(
                                             DagModel.is_active).all())
            for dag in dags_parsed:
                # The largest valid difference between a DagFileStat's last_finished_time and a DAG's
                # last_parsed_time is _processor_timeout. Longer than that indicates that the DAG is
                # no longer present in the file.
                if (dag.fileloc in last_parsed
                        and (dag.last_parsed_time + self._processor_timeout) <
                        last_parsed[dag.fileloc]):
                    self.log.info("DAG %s is missing and will be deactivated.",
                                  dag.dag_id)
                    to_deactivate.add(dag.dag_id)

            if to_deactivate:
                deactivated = (session.query(DagModel).filter(
                    DagModel.dag_id.in_(to_deactivate)).update(
                        {DagModel.is_active: False},
                        synchronize_session="fetch"))
                if deactivated:
                    self.log.info(
                        "Deactivated %i DAGs which are no longer present in file.",
                        deactivated)

                for dag_id in to_deactivate:
                    SerializedDagModel.remove_dag(dag_id)
                    self.log.info("Deleted DAG %s in serialized_dag table",
                                  dag_id)

            self.last_deactivate_stale_dags_time = timezone.utcnow()
    def test_event_op_dag_read_write(self):
        class TestHandler(EventMetHandler):
            def met(self, ti: TaskInstance, ts: TaskState) -> TaskAction:
                return TaskAction.START

        now = timezone.utcnow()
        dag_id = 'test_add_taskstate_0'
        dag = DAG(dag_id=dag_id, start_date=now)
        task0 = DummyOperator(task_id='backfill_task_0', owner='test', dag=dag)
        task0.add_event_dependency('key', "EVENT")
        task0.set_event_met_handler(TestHandler())
        SDM.write_dag(dag)
        with db.create_session() as session:
            sdag = session.query(SDM).first()
            dag = SerializedDAG.from_dict(sdag.data)
        self.assertEqual(dag_id, dag.dag_id)
        self.assertEqual(
            1, len(dag.task_dict["backfill_task_0"].event_dependencies()))
Example #17
    def test_get_dag_with_dag_serialization(self):
        """
        Test that Serialized DAG is updated in DagBag when it is updated in
        Serialized DAG table after 'min_serialized_dag_fetch_interval' seconds are passed.
        """

        with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 0)):
            example_bash_op_dag = DagBag(
                include_examples=True).dags.get("example_bash_operator")
            SerializedDagModel.write_dag(dag=example_bash_op_dag)

            dag_bag = DagBag(read_dags_from_db=True)
            ser_dag_1 = dag_bag.get_dag("example_bash_operator")
            ser_dag_1_update_time = dag_bag.dags_last_fetched[
                "example_bash_operator"]
            self.assertEqual(example_bash_op_dag.tags, ser_dag_1.tags)
            self.assertEqual(ser_dag_1_update_time,
                             tz.datetime(2020, 1, 5, 0, 0, 0))

        # Check that if min_serialized_dag_fetch_interval has not passed we do not fetch the DAG
        # from DB
        with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 4)):
            with assert_queries_count(0):
                self.assertEqual(
                    dag_bag.get_dag("example_bash_operator").tags,
                    ["example", "example2"])

        # Make a change in the DAG and write Serialized DAG to the DB
        with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 6)):
            example_bash_op_dag.tags += ["new_tag"]
            SerializedDagModel.write_dag(dag=example_bash_op_dag)

        # Since min_serialized_dag_fetch_interval is passed verify that calling 'dag_bag.get_dag'
        # fetches the Serialized DAG from DB
        with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 8)):
            with assert_queries_count(2):
                updated_ser_dag_1 = dag_bag.get_dag("example_bash_operator")
                updated_ser_dag_1_update_time = dag_bag.dags_last_fetched[
                    "example_bash_operator"]

        self.assertCountEqual(updated_ser_dag_1.tags,
                              ["example", "example2", "new_tag"])
        self.assertGreater(updated_ser_dag_1_update_time,
                           ser_dag_1_update_time)
Example #18
    def test_collect_dags_from_db(self):
        """DAGs are collected from Database"""
        example_dags_folder = airflow.example_dags.__path__[0]
        dagbag = DagBag(example_dags_folder)

        example_dags = dagbag.dags
        for dag in example_dags.values():
            SerializedDagModel.write_dag(dag)

        new_dagbag = DagBag(read_dags_from_db=True)
        self.assertEqual(len(new_dagbag.dags), 0)
        new_dagbag.collect_dags_from_db()
        new_dags = new_dagbag.dags
        self.assertEqual(len(example_dags), len(new_dags))
        for dag_id, dag in example_dags.items():
            serialized_dag = new_dags[dag_id]

            self.assertEqual(serialized_dag.dag_id, dag.dag_id)
            self.assertEqual(set(serialized_dag.task_dict), set(dag.task_dict))
    def test_read_dags(self):
        """DAGs can be read from database."""
        example_dags = self._write_example_dags()
        serialized_dags = SDM.read_all_dags()
        self.assertTrue(len(example_dags) == len(serialized_dags))
        for dag_id, dag in example_dags.items():
            serialized_dag = serialized_dags[dag_id]

            self.assertTrue(serialized_dag.dag_id == dag.dag_id)
            self.assertTrue(set(serialized_dag.task_dict) == set(dag.task_dict))
Example #20
    def _refresh_dag_dir(self):
        """Refresh file paths from dag dir if we haven't done it for too long."""
        now = timezone.utcnow()
        elapsed_time_since_refresh = (
            now - self.last_dag_dir_refresh_time).total_seconds()
        if elapsed_time_since_refresh > self.dag_dir_list_interval:
            # Build up a list of Python files that could contain DAGs
            self.log.info("Searching for files in %s", self._dag_directory)
            self._file_paths = list_py_file_paths(self._dag_directory)
            self.last_dag_dir_refresh_time = now
            self.log.info("There are %s files in %s", len(self._file_paths),
                          self._dag_directory)
            self.set_file_paths(self._file_paths)

            try:
                self.log.debug("Removing old import errors")
                self.clear_nonexistent_import_errors()
            except Exception:
                self.log.exception("Error removing old import errors")

            # Check if file path is a zipfile and get the full path of the python file.
            # Without this, SerializedDagModel.remove_deleted_files would delete zipped dags.
            # Likewise DagCode.remove_deleted_code
            dag_filelocs = []
            for fileloc in self._file_paths:
                if not fileloc.endswith(".py") and zipfile.is_zipfile(fileloc):
                    with zipfile.ZipFile(fileloc) as z:
                        dag_filelocs.extend([
                            os.path.join(fileloc, info.filename)
                            for info in z.infolist()
                            if might_contain_dag(info.filename, True, z)
                        ])
                else:
                    dag_filelocs.append(fileloc)

            SerializedDagModel.remove_deleted_dags(dag_filelocs)
            DagModel.deactivate_deleted_dags(self._file_paths)

            from airflow.models.dagcode import DagCode

            DagCode.remove_deleted_code(dag_filelocs)
Example #21
    def _add_dag_from_db(self, dag_id: str):
        """Add DAG to DagBag from DB"""
        from airflow.models.serialized_dag import SerializedDagModel
        row = SerializedDagModel.get(dag_id)
        if not row:
            raise ValueError(f"DAG '{dag_id}' not found in serialized_dag table")

        dag = row.dag
        for subdag in dag.subdags:
            self.dags[subdag.dag_id] = subdag
        self.dags[dag.dag_id] = dag
        self.dags_last_fetched[dag.dag_id] = timezone.utcnow()
Example #22
    def test_write_dag(self):
        """DAGs can be written into database."""
        example_dags = self._write_example_dags()

        with create_session() as session:
            for dag in example_dags.values():
                assert SDM.has_dag(dag.dag_id)
                result = session.query(SDM.fileloc, SDM.data).filter(SDM.dag_id == dag.dag_id).one()

                assert result.fileloc == dag.full_filepath
                # Verifies JSON schema.
                SerializedDAG.validate_schema(result.data)
Example #23
    def test_serialized_dag_is_updated_only_if_dag_is_changed(self):
        """Test Serialized DAG is updated if DAG is changed"""

        example_dags = make_example_dags(example_dags_module)
        example_bash_op_dag = example_dags.get("example_bash_operator")
        SDM.write_dag(dag=example_bash_op_dag)

        with create_session() as session:
            s_dag = session.query(SDM).get(example_bash_op_dag.dag_id)

            # Test that if DAG is not changed, Serialized DAG is not re-written and last_updated
            # column is not updated
            SDM.write_dag(dag=example_bash_op_dag)
            s_dag_1 = session.query(SDM).get(example_bash_op_dag.dag_id)

            self.assertEqual(s_dag_1.dag_hash, s_dag.dag_hash)
            self.assertEqual(s_dag.last_updated, s_dag_1.last_updated)

            # Update DAG
            example_bash_op_dag.tags += ["new_tag"]
            self.assertCountEqual(example_bash_op_dag.tags, ["example", "new_tag"])

            SDM.write_dag(dag=example_bash_op_dag)
            s_dag_2 = session.query(SDM).get(example_bash_op_dag.dag_id)

            self.assertNotEqual(s_dag.last_updated, s_dag_2.last_updated)
            self.assertNotEqual(s_dag.dag_hash, s_dag_2.dag_hash)
            self.assertEqual(s_dag_2.data["dag"]["tags"], ["example", "new_tag"])
Example #24
    def test_serialized_dag_is_updated_only_if_dag_is_changed(self):
        """Test Serialized DAG is updated if DAG is changed"""

        example_dags = make_example_dags(example_dags_module)
        example_bash_op_dag = example_dags.get("example_bash_operator")
        SDM.write_dag(dag=example_bash_op_dag)

        with create_session() as session:
            s_dag = session.query(SDM).get(example_bash_op_dag.dag_id)

            # Test that if DAG is not changed, Serialized DAG is not re-written and last_updated
            # column is not updated
            SDM.write_dag(dag=example_bash_op_dag)
            s_dag_1 = session.query(SDM).get(example_bash_op_dag.dag_id)

            assert s_dag_1.dag_hash == s_dag.dag_hash
            assert s_dag.last_updated == s_dag_1.last_updated

            # Update DAG
            example_bash_op_dag.tags += ["new_tag"]
            assert set(example_bash_op_dag.tags) == {
                "example", "example2", "new_tag"
            }

            SDM.write_dag(dag=example_bash_op_dag)
            s_dag_2 = session.query(SDM).get(example_bash_op_dag.dag_id)

            assert s_dag.last_updated != s_dag_2.last_updated
            assert s_dag.dag_hash != s_dag_2.dag_hash
            assert s_dag_2.data["dag"]["tags"] == [
                "example", "example2", "new_tag"
            ]
    def test_serialized_dag_is_updated_only_if_dag_is_changed(self):
        """Test Serialized DAG is updated if DAG is changed"""

        example_dags = make_example_dags(example_dags_module)
        example_bash_op_dag = example_dags.get("example_bash_operator")
        SDM.write_dag(dag=example_bash_op_dag)

        with create_session() as session:
            last_updated = session.query(SDM.last_updated).filter(
                SDM.dag_id == example_bash_op_dag.dag_id).one_or_none()

            # Test that if DAG is not changed, Serialized DAG is not re-written and last_updated
            # column is not updated
            SDM.write_dag(dag=example_bash_op_dag)
            last_updated_1 = session.query(SDM.last_updated).filter(
                SDM.dag_id == example_bash_op_dag.dag_id).one_or_none()

            self.assertEqual(last_updated, last_updated_1)

            # Update DAG
            example_bash_op_dag.tags += ["new_tag"]
            self.assertCountEqual(example_bash_op_dag.tags,
                                  ["example", "new_tag"])

            SDM.write_dag(dag=example_bash_op_dag)
            new_s_dag = session.query(SDM.last_updated, SDM.data).filter(
                SDM.dag_id == example_bash_op_dag.dag_id).one_or_none()

            self.assertNotEqual(last_updated, new_s_dag.last_updated)
            self.assertEqual(new_s_dag.data["dag"]["tags"],
                             ["example", "new_tag"])
    def setUp(self):
        app = application.create_app(testing=True)
        app.config['WTF_CSRF_METHODS'] = []
        self.app = app.test_client()
        self.default_date = datetime(2020, 3, 1)
        self.dag = DAG(
            "testdag",
            start_date=self.default_date,
            user_defined_filters={"hello": lambda name: 'Hello ' + name},
            user_defined_macros={
                "fullname": lambda fname, lname: fname + " " + lname
            })
        self.task1 = BashOperator(task_id='task1',
                                  bash_command='{{ task_instance_key_str }}',
                                  dag=self.dag)
        self.task2 = BashOperator(
            task_id='task2',
            bash_command='echo {{ fullname("Apache", "Airflow") | hello }}',
            dag=self.dag)
        SerializedDagModel.write_dag(self.dag)
        with create_session() as session:
            session.query(RTIF).delete()
Example #27
    def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session=None):
        """Only run DagRun.verify integrity if Serialized DAG has changed since it is slow"""
        latest_version = SerializedDagModel.get_latest_version_hash(dag_run.dag_id, session=session)
        if dag_run.dag_hash == latest_version:
            self.log.debug("DAG %s not changed structure, skipping dagrun.verify_integrity", dag_run.dag_id)
            return

        dag_run.dag_hash = latest_version

        # Refresh the DAG
        dag_run.dag = self.dagbag.get_dag(dag_id=dag_run.dag_id, session=session)

        # Verify integrity also takes care of session.flush
        dag_run.verify_integrity(session=session)
Example #28
    def sync_to_db(self, session: Optional[Session] = None):
        """Save attributes about list of DAG to the DB."""
        # To avoid circular import - airflow.models.dagbag -> airflow.models.dag -> airflow.models.dagbag
        from airflow.models.dag import DAG
        from airflow.models.serialized_dag import SerializedDagModel

        # Retry 'DAG.bulk_write_to_db' & 'SerializedDagModel.bulk_sync_to_db' in case
        # of any Operational Errors
        # In case of failures, provide_session handles rollback
        for attempt in tenacity.Retrying(
                retry=tenacity.retry_if_exception_type(
                    exception_types=OperationalError),
                wait=tenacity.wait_random_exponential(multiplier=0.5, max=5),
                stop=tenacity.stop_after_attempt(settings.MAX_DB_RETRIES),
                before_sleep=tenacity.before_sleep_log(self.log,
                                                       logging.DEBUG),
                reraise=True,
        ):
            with attempt:
                self.log.debug(
                    "Running dagbag.sync_to_db with retries. Try %d of %d",
                    attempt.retry_state.attempt_number,
                    settings.MAX_DB_RETRIES,
                )
                self.log.debug("Calling the DAG.bulk_sync_to_db method")
                try:
                    DAG.bulk_write_to_db(self.dags.values(), session=session)

                    # Write Serialized DAGs to DB
                    self.log.debug(
                        "Calling the SerializedDagModel.bulk_sync_to_db method"
                    )
                    SerializedDagModel.bulk_sync_to_db(self.dags.values(),
                                                       session=session)
                except OperationalError:
                    session.rollback()
                    raise
    def test_should_response_200_serialized(self):
        # Create empty app with empty dagbag to check if DAG is read from db
        app_serialized = app.create_app(testing=True)  # type:ignore
        dag_bag = DagBag(os.devnull,
                         include_examples=False,
                         store_serialized_dags=True)
        app_serialized.dag_bag = dag_bag  # type:ignore
        client = app_serialized.test_client()

        SerializedDagModel.write_dag(self.dag)

        expected = {
            'catchup': True,
            'concurrency': 16,
            'dag_id': 'test_dag',
            'dag_run_timeout': None,
            'default_view': 'tree',
            'description': None,
            'doc_md': 'details',
            'fileloc': __file__,
            'is_paused': None,
            'is_subdag': False,
            'orientation': 'LR',
            'owners': [],
            'schedule_interval': {
                '__type': 'TimeDelta',
                'days': 1,
                'microseconds': 0,
                'seconds': 0
            },
            'start_date': '2020-06-15T00:00:00+00:00',
            'tags': None,
            'timezone': "Timezone('UTC')"
        }
        response = client.get(f"/api/v1/dags/{self.dag_id}/details")
        assert response.status_code == 200
        assert response.json == expected
Example #30
    def _add_dag_from_db(self, dag_id: str, session: Session):
        """Add DAG to DagBag from DB"""
        from airflow.models.serialized_dag import SerializedDagModel

        row = SerializedDagModel.get(dag_id, session)
        if not row:
            raise SerializedDagNotFound(f"DAG '{dag_id}' not found in serialized_dag table")

        row.load_op_links = self.load_op_links
        dag = row.dag
        for subdag in dag.subdags:
            self.dags[subdag.dag_id] = subdag
        self.dags[dag.dag_id] = dag
        self.dags_last_fetched[dag.dag_id] = timezone.utcnow()
        self.dags_hash[dag.dag_id] = row.dag_hash