def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> int:
    """
    :param dag_id: the dag_id of the DAG to delete
    :param keep_records_in_log: whether to keep records of the given dag_id in the Log table
        in the backend database (for reasons like auditing). The default value is True.
    :param session: session used
    :return: count of deleted dags
    """
    log.info("Deleting DAG: %s", dag_id)
    running_tis = (
        session.query(models.TaskInstance.state)
        .filter(models.TaskInstance.dag_id == dag_id)
        .filter(models.TaskInstance.state == State.RUNNING)
        .first()
    )
    if running_tis:
        raise AirflowException("TaskInstances still running")
    dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
    if dag is None:
        raise DagNotFound(f"Dag id {dag_id} not found")

    # Scheduler removes DAGs without files from serialized_dag table every dag_dir_list_interval.
    # There may be a lag, so explicitly removes serialized DAG here.
    if SerializedDagModel.has_dag(dag_id=dag_id, session=session):
        SerializedDagModel.remove_dag(dag_id=dag_id, session=session)

    count = 0

    for model in get_sqla_model_classes():
        if hasattr(model, "dag_id"):
            if keep_records_in_log and model.__name__ == 'Log':
                continue
            cond = or_(model.dag_id == dag_id, model.dag_id.like(dag_id + ".%"))
            count += session.query(model).filter(cond).delete(synchronize_session='fetch')

    if dag.is_subdag:
        parent_dag_id, task_id = dag_id.rsplit(".", 1)
        for model in TaskFail, models.TaskInstance:
            count += (
                session.query(model)
                .filter(model.dag_id == parent_dag_id, model.task_id == task_id)
                .delete()
            )

    # Delete entries in Import Errors table for a deleted DAG
    # This handles the case when the dag_id is changed in the file
    session.query(models.ImportError).filter(
        models.ImportError.filename == dag.fileloc
    ).delete(synchronize_session='fetch')

    return count

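# A minimal usage sketch for the delete_dag helper above (not part of the original snippet).
# It assumes a SQLAlchemy session is supplied by the caller; in Airflow the decorated version
# normally injects the session, and the exact import location of create_session varies across
# Airflow versions. The dag_id is illustrative only.
from airflow.utils.session import create_session  # location may differ between Airflow versions

def delete_dag_example():
    with create_session() as session:
        deleted = delete_dag("example_bash_operator", keep_records_in_log=True, session=session)
        print(f"Deleted {deleted} rows for the DAG and its subdags")
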
def test_should_respond_200_serialized(self):
    # Get the dag out of the dagbag before we patch it to an empty one
    SerializedDagModel.write_dag(self.app.dag_bag.get_dag(self.dag_id))
    dag_bag = DagBag(os.devnull, include_examples=False, read_dags_from_db=True)
    patcher = unittest.mock.patch.object(self.app, 'dag_bag', dag_bag)
    patcher.start()

    expected = {
        "class_ref": {
            "class_name": "DummyOperator",
            "module_path": "airflow.operators.dummy",
        },
        "depends_on_past": False,
        "downstream_task_ids": [self.task_id2],
        "end_date": None,
        "execution_timeout": None,
        "extra_links": [],
        "owner": "airflow",
        "pool": "default_pool",
        "pool_slots": 1.0,
        "priority_weight": 1.0,
        "queue": "default",
        "retries": 0.0,
        "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0},
        "retry_exponential_backoff": False,
        "start_date": "2020-06-15T00:00:00+00:00",
        "task_id": "op1",
        "template_fields": [],
        "trigger_rule": "all_success",
        "ui_color": "#e8f7e4",
        "ui_fgcolor": "#000",
        "wait_for_downstream": False,
        "weight_rule": "downstream",
    }
    response = self.client.get(
        f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}",
        environ_overrides={'REMOTE_USER': "******"},
    )
    assert response.status_code == 200
    assert response.json == expected
    patcher.stop()

def test_dag_trigger_parse_dag(self):
    mailbox = Mailbox()
    dag_trigger = DagTrigger("../../dags/test_scheduler_dags.py", -1, [], False, mailbox)
    dag_trigger.start()
    message = mailbox.get_message()
    message = SchedulerInnerEventUtil.to_inner_event(message)
    # only one dag is executable
    assert "test_task_start_date_scheduling" == message.dag_id
    assert DagModel.get_dagmodel(dag_id="test_task_start_date_scheduling") is not None
    assert DagModel.get_dagmodel(dag_id="test_start_date_scheduling") is not None
    assert SerializedDagModel.get(dag_id="test_task_start_date_scheduling") is not None
    assert SerializedDagModel.get(dag_id="test_start_date_scheduling") is not None
    dag_trigger.end()

def sync_to_db(self):
    """Save attributes of the DAGs in this DagBag to the DB."""
    # To avoid circular import - airflow.models.dagbag -> airflow.models.dag -> airflow.models.dagbag
    from airflow.models.dag import DAG
    from airflow.models.serialized_dag import SerializedDagModel

    self.log.debug("Calling the DAG.bulk_sync_to_db method")
    DAG.bulk_sync_to_db(self.dags.values())

    # Write Serialized DAGs to DB if DAG Serialization is turned on,
    # even though self.read_dags_from_db is False
    if settings.STORE_SERIALIZED_DAGS:
        self.log.debug("Calling the SerializedDagModel.bulk_sync_to_db method")
        SerializedDagModel.bulk_sync_to_db(self.dags.values())

def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> int:
    """
    :param dag_id: the dag_id of the DAG to delete
    :param keep_records_in_log: whether to keep records of the given dag_id in the Log table
        in the backend database (for reasons like auditing). The default value is True.
    :param session: session used
    :return: count of deleted dags
    """
    logger = LoggingMixin()
    logger.log.info("Deleting DAG: %s", dag_id)

    dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
    if dag is None:
        raise DagNotFound("Dag id {} not found".format(dag_id))

    # Scheduler removes DAGs without files from serialized_dag table every dag_dir_list_interval.
    # There may be a lag, so explicitly removes serialized DAG here.
    if STORE_SERIALIZED_DAGS and SerializedDagModel.has_dag(dag_id=dag_id, session=session):
        SerializedDagModel.remove_dag(dag_id=dag_id, session=session)

    count = 0

    # noinspection PyUnresolvedReferences,PyProtectedMember
    for model in models.base.Base._decl_class_registry.values():  # pylint: disable=protected-access
        if hasattr(model, "dag_id"):
            if keep_records_in_log and model.__name__ == 'Log':
                continue
            cond = or_(model.dag_id == dag_id, model.dag_id.like(dag_id + ".%"))
            count += session.query(model).filter(cond).delete(synchronize_session='fetch')

    if dag.is_subdag:
        parent_dag_id, task_id = dag_id.rsplit(".", 1)
        # Only models with a task_id column can be filtered on (parent_dag_id, task_id)
        for model in TaskFail, models.TaskInstance:
            count += session.query(model).filter(
                model.dag_id == parent_dag_id, model.task_id == task_id
            ).delete()

    # Delete entries in Import Errors table for a deleted DAG
    # This handles the case when the dag_id is changed in the file
    session.query(models.ImportError).filter(
        models.ImportError.filename == dag.fileloc
    ).delete(synchronize_session='fetch')

    return count

def test_should_response_200_serialized(self):
    # Create empty app with empty dagbag to check if DAG is read from db
    app_serialized = app.create_app(testing=True)
    dag_bag = DagBag(os.devnull, include_examples=False, read_dags_from_db=True)
    app_serialized.dag_bag = dag_bag
    client = app_serialized.test_client()

    SerializedDagModel.write_dag(self.dag)

    expected = {
        "class_ref": {
            "class_name": "DummyOperator",
            "module_path": "airflow.operators.dummy_operator",
        },
        "depends_on_past": False,
        "downstream_task_ids": [],
        "end_date": None,
        "execution_timeout": None,
        "extra_links": [],
        "owner": "airflow",
        "pool": "default_pool",
        "pool_slots": 1.0,
        "priority_weight": 1.0,
        "queue": "default",
        "retries": 0.0,
        "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0},
        "retry_exponential_backoff": False,
        "start_date": "2020-06-15T00:00:00+00:00",
        "task_id": "op1",
        "template_fields": [],
        "trigger_rule": "all_success",
        "ui_color": "#e8f7e4",
        "ui_fgcolor": "#000",
        "wait_for_downstream": False,
        "weight_rule": "downstream",
    }
    response = client.get(
        f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}",
        environ_overrides={'REMOTE_USER': "******"},
    )
    assert response.status_code == 200
    assert response.json == expected

def get_dag(self, dag_id):
    """
    Gets the DAG out of the dictionary, and refreshes it if expired

    :param dag_id: DAG Id
    :type dag_id: str
    """
    # Avoid circular import
    from airflow.models.dag import DagModel

    # Only read DAGs from DB if this dagbag is store_serialized_dags.
    if self.store_serialized_dags:
        # Import here so that serialized dag is only imported when serialization is enabled
        from airflow.models.serialized_dag import SerializedDagModel

        if dag_id not in self.dags:
            # Load from DB if not (yet) in the bag
            row = SerializedDagModel.get(dag_id)
            if not row:
                return None

            dag = row.dag
            for subdag in dag.subdags:
                self.dags[subdag.dag_id] = subdag
            self.dags[dag.dag_id] = dag

        return self.dags.get(dag_id)

    # If asking for a known subdag, we want to refresh the parent
    dag = None
    root_dag_id = dag_id
    if dag_id in self.dags:
        dag = self.dags[dag_id]
        if dag.is_subdag:
            root_dag_id = dag.parent_dag.dag_id

    # Needs to load from file for a store_serialized_dags dagbag.
    enforce_from_file = False
    if self.store_serialized_dags and dag is not None:
        from airflow.serialization.serialized_objects import SerializedDAG

        enforce_from_file = isinstance(dag, SerializedDAG)

    # If the dag corresponding to root_dag_id is absent or expired
    orm_dag = DagModel.get_current(root_dag_id)
    if (
        orm_dag
        and (
            root_dag_id not in self.dags
            or (orm_dag.last_expired and dag.last_loaded < orm_dag.last_expired)
        )
    ) or enforce_from_file:
        # Reprocess source file
        found_dags = self.process_file(
            filepath=correct_maybe_zipped(orm_dag.fileloc), only_if_updated=False
        )

        # If the source file no longer exports `dag_id`, delete it from self.dags
        if found_dags and dag_id in [found_dag.dag_id for found_dag in found_dags]:
            return self.dags[dag_id]
        elif dag_id in self.dags:
            del self.dags[dag_id]
    return self.dags.get(dag_id)

def _serialize_dag_capturing_errors(dag, session):
    """
    Try to serialize the dag to the DB, but make a note of any errors.

    We can't place them directly in import_errors, as this may be retried, and work
    the next time
    """
    if dag.is_subdag:
        return []
    try:
        # We can't use bulk_write_to_db as we want to capture each error individually
        dag_was_updated = SerializedDagModel.write_dag(
            dag,
            min_update_interval=settings.MIN_SERIALIZED_DAG_UPDATE_INTERVAL,
            session=session,
        )
        if dag_was_updated:
            self.log.debug("Syncing DAG permissions: %s to the DB", dag.dag_id)
            from airflow.www.security import ApplessAirflowSecurityManager

            security_manager = ApplessAirflowSecurityManager(session=session)
            security_manager.sync_perm_for_dag(dag.dag_id, dag.access_control)
        return []
    except OperationalError:
        raise
    except Exception:  # pylint: disable=broad-except
        return [
            (dag.fileloc, traceback.format_exc(limit=-self.dagbag_import_error_traceback_depth))
        ]

def _serialize_dag_capturing_errors(dag, session):
    """
    Try to serialize the dag to the DB, but make a note of any errors.

    We can't place them directly in import_errors, as this may be retried, and work
    the next time
    """
    if dag.is_subdag:
        return []
    try:
        # We can't use bulk_write_to_db as we want to capture each error individually
        dag_was_updated = SerializedDagModel.write_dag(
            dag,
            min_update_interval=settings.MIN_SERIALIZED_DAG_UPDATE_INTERVAL,
            session=session,
        )
        if dag_was_updated:
            self._sync_perm_for_dag(dag, session=session)
        return []
    except OperationalError:
        raise
    except Exception:
        self.log.exception("Failed to write serialized DAG: %s", dag.full_filepath)
        return [
            (dag.fileloc, traceback.format_exc(limit=-self.dagbag_import_error_traceback_depth))
        ]

def get_dag(self, dag_id):
    """
    Gets the DAG out of the dictionary, and refreshes it if expired

    :param dag_id: DAG Id
    :type dag_id: str
    """
    # Avoid circular import
    from airflow.models.dag import DagModel

    if self.read_dags_from_db:
        # Import here so that serialized dag is only imported when serialization is enabled
        from airflow.models.serialized_dag import SerializedDagModel

        if dag_id not in self.dags:
            # Load from DB if not (yet) in the bag
            self._add_dag_from_db(dag_id=dag_id)
            return self.dags.get(dag_id)

        # If DAG is in the DagBag, check the following
        # 1. if time has come to check if DAG is updated (controlled by min_serialized_dag_fetch_secs)
        # 2. check the last_updated column in SerializedDag table to see if Serialized DAG is updated
        # 3. if (2) is yes, fetch the Serialized DAG.
        min_serialized_dag_fetch_secs = timedelta(seconds=settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL)
        if (
            dag_id in self.dags_last_fetched
            and timezone.utcnow() > self.dags_last_fetched[dag_id] + min_serialized_dag_fetch_secs
        ):
            sd_last_updated_datetime = SerializedDagModel.get_last_updated_datetime(dag_id=dag_id)
            if sd_last_updated_datetime > self.dags_last_fetched[dag_id]:
                self._add_dag_from_db(dag_id=dag_id)

        return self.dags.get(dag_id)

    # If asking for a known subdag, we want to refresh the parent
    dag = None
    root_dag_id = dag_id
    if dag_id in self.dags:
        dag = self.dags[dag_id]
        if dag.is_subdag:
            root_dag_id = dag.parent_dag.dag_id

    # If DAG Model is absent, we can't check last_expired property. Is the DAG not yet synchronized?
    orm_dag = DagModel.get_current(root_dag_id)
    if not orm_dag:
        return self.dags.get(dag_id)

    # If the dag corresponding to root_dag_id is absent or expired
    is_missing = root_dag_id not in self.dags
    is_expired = orm_dag.last_expired and dag.last_loaded < orm_dag.last_expired
    if is_missing or is_expired:
        # Reprocess source file
        found_dags = self.process_file(
            filepath=correct_maybe_zipped(orm_dag.fileloc), only_if_updated=False
        )

        # If the source file no longer exports `dag_id`, delete it from self.dags
        if found_dags and dag_id in [found_dag.dag_id for found_dag in found_dags]:
            return self.dags[dag_id]
        elif dag_id in self.dags:
            del self.dags[dag_id]
    return self.dags.get(dag_id)

def test_user_defined_filter_and_macros_raise_error(self, get_dag_function):
    """
    Test that the Rendered View shows an error message for TIs that have not yet run
    when user-defined macros or filters cannot be resolved from the serialized DAG
    """
    get_dag_function.return_value = SerializedDagModel.get(self.dag.dag_id).dag

    self.assertEqual(self.task2.bash_command, 'echo {{ fullname("Apache", "Airflow") | hello }}')

    url = '/admin/airflow/rendered?task_id=task2&dag_id=testdag&execution_date={}'.format(
        self.percent_encode(self.default_date))

    resp = self.app.get(url, follow_redirects=True)
    self.assertNotIn("echo Hello Apache Airflow", resp.data.decode('utf-8'))

    if six.PY3:
        self.assertIn(
            "Webserver does not have access to User-defined Macros or Filters "
            "when Dag Serialization is enabled. Hence for the task that have not yet "
            "started running, please use 'airflow tasks render' for debugging the "
            "rendering of template_fields.<br/><br/>OriginalError: no filter named 'hello'",
            resp.data.decode('utf-8'))
    else:
        self.assertIn(
            "Webserver does not have access to User-defined Macros or Filters "
            "when Dag Serialization is enabled. Hence for the task that have not yet "
            "started running, please use 'airflow tasks render' for debugging the "
            "rendering of template_fields.",
            resp.data.decode('utf-8'))

def get_dag_by_deserialization(dag_id: str) -> "DAG":
    from airflow.models.serialized_dag import SerializedDagModel

    dag_model = SerializedDagModel.get(dag_id)
    if dag_model is None:
        raise AirflowException(f"Serialized DAG: {dag_id} could not be found")

    return dag_model.dag

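# A small sketch of how get_dag_by_deserialization might be called from code that prefers
# the serialized representation over re-parsing DAG files (not part of the original snippet).
# The dag_id is illustrative; AirflowException is raised when no serialized row exists.
def load_serialized_dag_or_none(dag_id: str):
    try:
        return get_dag_by_deserialization(dag_id)
    except AirflowException:
        # Fall back to None so the caller can decide whether to parse the DAG file instead
        return None
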
def test_remove_dags_by_filepath(self):
    """DAGs can be removed from database."""
    example_dags_list = list(self._write_example_dags().values())
    # Remove SubDags from the list as they are not stored in DB in a separate row
    # and are directly added in Json blob of the main DAG
    filtered_example_dags_list = [dag for dag in example_dags_list if not dag.is_subdag]
    # Tests removing by file path.
    dag_removed_by_file = filtered_example_dags_list[0]
    # remove repeated files for those DAGs that define multiple dags in the same file (set comprehension)
    example_dag_files = list({dag.full_filepath for dag in filtered_example_dags_list})
    example_dag_files.remove(dag_removed_by_file.full_filepath)
    SDM.remove_deleted_dags(example_dag_files)
    self.assertFalse(SDM.has_dag(dag_removed_by_file.dag_id))

def get_dag(self, dag_id):
    """
    Gets the DAG out of the dictionary, and refreshes it if expired

    :param dag_id: DAG Id
    :type dag_id: str
    """
    # Avoid circular import
    from airflow.models.dag import DagModel

    # Only read DAGs from DB if this dagbag is read_dags_from_db.
    if self.read_dags_from_db:
        # Import here so that serialized dag is only imported when serialization is enabled
        from airflow.models.serialized_dag import SerializedDagModel

        if dag_id not in self.dags:
            # Load from DB if not (yet) in the bag
            row = SerializedDagModel.get(dag_id)
            if not row:
                return None

            dag = row.dag
            for subdag in dag.subdags:
                self.dags[subdag.dag_id] = subdag
            self.dags[dag.dag_id] = dag

        return self.dags.get(dag_id)

    # If asking for a known subdag, we want to refresh the parent
    dag = None
    root_dag_id = dag_id
    if dag_id in self.dags:
        dag = self.dags[dag_id]
        if dag.is_subdag:
            root_dag_id = dag.parent_dag.dag_id

    # If DAG Model is absent, we can't check last_expired property. Is the DAG not yet synchronized?
    orm_dag = DagModel.get_current(root_dag_id)
    if not orm_dag:
        return self.dags.get(dag_id)

    # If the dag corresponding to root_dag_id is absent or expired
    is_missing = root_dag_id not in self.dags
    is_expired = orm_dag.last_expired and dag.last_loaded < orm_dag.last_expired
    if is_missing or is_expired:
        # Reprocess source file
        found_dags = self.process_file(
            filepath=correct_maybe_zipped(orm_dag.fileloc), only_if_updated=False
        )

        # If the source file no longer exports `dag_id`, delete it from self.dags
        if found_dags and dag_id in [found_dag.dag_id for found_dag in found_dags]:
            return self.dags[dag_id]
        elif dag_id in self.dags:
            del self.dags[dag_id]
    return self.dags.get(dag_id)

def _deactivate_stale_dags(self, session=None):
    """
    Detects DAGs which are no longer present in files, deactivates them,
    and removes them from the serialized_dag table
    """
    now = timezone.utcnow()
    elapsed_time_since_refresh = (now - self.last_deactivate_stale_dags_time).total_seconds()
    if elapsed_time_since_refresh > self.deactivate_stale_dags_interval:
        last_parsed = {
            fp: self.get_last_finish_time(fp) for fp in self.file_paths if self.get_last_finish_time(fp)
        }
        to_deactivate = set()
        dags_parsed = (
            session.query(DagModel.dag_id, DagModel.fileloc, DagModel.last_parsed_time)
            .filter(DagModel.is_active)
            .all()
        )
        for dag in dags_parsed:
            # The largest valid difference between a DagFileStat's last_finished_time and a DAG's
            # last_parsed_time is _processor_timeout. Longer than that indicates that the DAG is
            # no longer present in the file.
            if (
                dag.fileloc in last_parsed
                and (dag.last_parsed_time + self._processor_timeout) < last_parsed[dag.fileloc]
            ):
                self.log.info("DAG %s is missing and will be deactivated.", dag.dag_id)
                to_deactivate.add(dag.dag_id)

        if to_deactivate:
            deactivated = (
                session.query(DagModel)
                .filter(DagModel.dag_id.in_(to_deactivate))
                .update({DagModel.is_active: False}, synchronize_session="fetch")
            )
            if deactivated:
                self.log.info("Deactivated %i DAGs which are no longer present in file.", deactivated)

            for dag_id in to_deactivate:
                SerializedDagModel.remove_dag(dag_id)
                self.log.info("Deleted DAG %s in serialized_dag table", dag_id)

        self.last_deactivate_stale_dags_time = timezone.utcnow()

def test_event_op_dag_read_write(self):
    class TestHandler(EventMetHandler):
        def met(self, ti: TaskInstance, ts: TaskState) -> TaskAction:
            return TaskAction.START

    now = timezone.utcnow()
    dag_id = 'test_add_taskstate_0'
    dag = DAG(dag_id=dag_id, start_date=now)
    task0 = DummyOperator(task_id='backfill_task_0', owner='test', dag=dag)
    task0.add_event_dependency('key', "EVENT")
    task0.set_event_met_handler(TestHandler())
    SDM.write_dag(dag)
    with db.create_session() as session:
        sdag = session.query(SDM).first()
        dag = SerializedDAG.from_dict(sdag.data)
        self.assertEqual(dag_id, dag.dag_id)
        self.assertEqual(1, len(dag.task_dict["backfill_task_0"].event_dependencies()))

def test_get_dag_with_dag_serialization(self):
    """
    Test that Serialized DAG is updated in DagBag when it is updated in
    Serialized DAG table after 'min_serialized_dag_fetch_interval' seconds are passed.
    """
    with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 0)):
        example_bash_op_dag = DagBag(include_examples=True).dags.get("example_bash_operator")
        SerializedDagModel.write_dag(dag=example_bash_op_dag)

        dag_bag = DagBag(read_dags_from_db=True)
        ser_dag_1 = dag_bag.get_dag("example_bash_operator")
        ser_dag_1_update_time = dag_bag.dags_last_fetched["example_bash_operator"]
        self.assertEqual(example_bash_op_dag.tags, ser_dag_1.tags)
        self.assertEqual(ser_dag_1_update_time, tz.datetime(2020, 1, 5, 0, 0, 0))

    # Check that if min_serialized_dag_fetch_interval has not passed we do not fetch the DAG
    # from DB
    with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 4)):
        with assert_queries_count(0):
            self.assertEqual(dag_bag.get_dag("example_bash_operator").tags, ["example", "example2"])

    # Make a change in the DAG and write Serialized DAG to the DB
    with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 6)):
        example_bash_op_dag.tags += ["new_tag"]
        SerializedDagModel.write_dag(dag=example_bash_op_dag)

    # Since min_serialized_dag_fetch_interval is passed verify that calling 'dag_bag.get_dag'
    # fetches the Serialized DAG from DB
    with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 8)):
        with assert_queries_count(2):
            updated_ser_dag_1 = dag_bag.get_dag("example_bash_operator")
            updated_ser_dag_1_update_time = dag_bag.dags_last_fetched["example_bash_operator"]

    self.assertCountEqual(updated_ser_dag_1.tags, ["example", "example2", "new_tag"])
    self.assertGreater(updated_ser_dag_1_update_time, ser_dag_1_update_time)

def test_collect_dags_from_db(self):
    """DAGs are collected from Database"""
    example_dags_folder = airflow.example_dags.__path__[0]
    dagbag = DagBag(example_dags_folder)

    example_dags = dagbag.dags
    for dag in example_dags.values():
        SerializedDagModel.write_dag(dag)

    new_dagbag = DagBag(read_dags_from_db=True)
    self.assertEqual(len(new_dagbag.dags), 0)
    new_dagbag.collect_dags_from_db()
    new_dags = new_dagbag.dags
    self.assertEqual(len(example_dags), len(new_dags))
    for dag_id, dag in example_dags.items():
        serialized_dag = new_dags[dag_id]
        self.assertEqual(serialized_dag.dag_id, dag.dag_id)
        self.assertEqual(set(serialized_dag.task_dict), set(dag.task_dict))

def test_read_dags(self):
    """DAGs can be read from database."""
    example_dags = self._write_example_dags()
    serialized_dags = SDM.read_all_dags()
    self.assertTrue(len(example_dags) == len(serialized_dags))
    for dag_id, dag in example_dags.items():
        serialized_dag = serialized_dags[dag_id]
        self.assertTrue(serialized_dag.dag_id == dag.dag_id)
        self.assertTrue(set(serialized_dag.task_dict) == set(dag.task_dict))

def _refresh_dag_dir(self):
    """Refresh file paths from dag dir if we haven't done it for too long."""
    now = timezone.utcnow()
    elapsed_time_since_refresh = (now - self.last_dag_dir_refresh_time).total_seconds()
    if elapsed_time_since_refresh > self.dag_dir_list_interval:
        # Build up a list of Python files that could contain DAGs
        self.log.info("Searching for files in %s", self._dag_directory)
        self._file_paths = list_py_file_paths(self._dag_directory)
        self.last_dag_dir_refresh_time = now
        self.log.info("There are %s files in %s", len(self._file_paths), self._dag_directory)
        self.set_file_paths(self._file_paths)

        try:
            self.log.debug("Removing old import errors")
            self.clear_nonexistent_import_errors()
        except Exception:
            self.log.exception("Error removing old import errors")

        # Check if file path is a zipfile and get the full path of the python file.
        # Without this, SerializedDagModel.remove_deleted_files would delete zipped dags.
        # Likewise DagCode.remove_deleted_code
        dag_filelocs = []
        for fileloc in self._file_paths:
            if not fileloc.endswith(".py") and zipfile.is_zipfile(fileloc):
                with zipfile.ZipFile(fileloc) as z:
                    dag_filelocs.extend([
                        os.path.join(fileloc, info.filename)
                        for info in z.infolist()
                        if might_contain_dag(info.filename, True, z)
                    ])
            else:
                dag_filelocs.append(fileloc)

        SerializedDagModel.remove_deleted_dags(dag_filelocs)
        DagModel.deactivate_deleted_dags(self._file_paths)

        from airflow.models.dagcode import DagCode

        DagCode.remove_deleted_code(dag_filelocs)

def _add_dag_from_db(self, dag_id: str):
    """Add DAG to DagBag from DB"""
    from airflow.models.serialized_dag import SerializedDagModel

    row = SerializedDagModel.get(dag_id)
    if not row:
        raise ValueError(f"DAG '{dag_id}' not found in serialized_dag table")

    dag = row.dag
    for subdag in dag.subdags:
        self.dags[subdag.dag_id] = subdag
    self.dags[dag.dag_id] = dag
    self.dags_last_fetched[dag.dag_id] = timezone.utcnow()

def test_write_dag(self):
    """DAGs can be written into database."""
    example_dags = self._write_example_dags()

    with create_session() as session:
        for dag in example_dags.values():
            assert SDM.has_dag(dag.dag_id)
            result = session.query(SDM.fileloc, SDM.data).filter(SDM.dag_id == dag.dag_id).one()

            assert result.fileloc == dag.full_filepath
            # Verifies JSON schema.
            SerializedDAG.validate_schema(result.data)

def test_serialized_dag_is_updated_only_if_dag_is_changed(self):
    """Test Serialized DAG is updated if DAG is changed"""
    example_dags = make_example_dags(example_dags_module)
    example_bash_op_dag = example_dags.get("example_bash_operator")
    SDM.write_dag(dag=example_bash_op_dag)

    with create_session() as session:
        s_dag = session.query(SDM).get(example_bash_op_dag.dag_id)

        # Test that if DAG is not changed, Serialized DAG is not re-written and last_updated
        # column is not updated
        SDM.write_dag(dag=example_bash_op_dag)
        s_dag_1 = session.query(SDM).get(example_bash_op_dag.dag_id)

        self.assertEqual(s_dag_1.dag_hash, s_dag.dag_hash)
        self.assertEqual(s_dag.last_updated, s_dag_1.last_updated)

        # Update DAG
        example_bash_op_dag.tags += ["new_tag"]
        self.assertCountEqual(example_bash_op_dag.tags, ["example", "new_tag"])

        SDM.write_dag(dag=example_bash_op_dag)
        s_dag_2 = session.query(SDM).get(example_bash_op_dag.dag_id)

        self.assertNotEqual(s_dag.last_updated, s_dag_2.last_updated)
        self.assertNotEqual(s_dag.dag_hash, s_dag_2.dag_hash)
        self.assertEqual(s_dag_2.data["dag"]["tags"], ["example", "new_tag"])

def test_serialized_dag_is_updated_only_if_dag_is_changed(self):
    """Test Serialized DAG is updated if DAG is changed"""
    example_dags = make_example_dags(example_dags_module)
    example_bash_op_dag = example_dags.get("example_bash_operator")
    SDM.write_dag(dag=example_bash_op_dag)

    with create_session() as session:
        s_dag = session.query(SDM).get(example_bash_op_dag.dag_id)

        # Test that if DAG is not changed, Serialized DAG is not re-written and last_updated
        # column is not updated
        SDM.write_dag(dag=example_bash_op_dag)
        s_dag_1 = session.query(SDM).get(example_bash_op_dag.dag_id)

        assert s_dag_1.dag_hash == s_dag.dag_hash
        assert s_dag.last_updated == s_dag_1.last_updated

        # Update DAG
        example_bash_op_dag.tags += ["new_tag"]
        assert set(example_bash_op_dag.tags) == {"example", "example2", "new_tag"}

        SDM.write_dag(dag=example_bash_op_dag)
        s_dag_2 = session.query(SDM).get(example_bash_op_dag.dag_id)

        assert s_dag.last_updated != s_dag_2.last_updated
        assert s_dag.dag_hash != s_dag_2.dag_hash
        assert s_dag_2.data["dag"]["tags"] == ["example", "example2", "new_tag"]

def test_serialized_dag_is_updated_only_if_dag_is_changed(self):
    """Test Serialized DAG is updated if DAG is changed"""
    example_dags = make_example_dags(example_dags_module)
    example_bash_op_dag = example_dags.get("example_bash_operator")
    SDM.write_dag(dag=example_bash_op_dag)

    with create_session() as session:
        last_updated = session.query(SDM.last_updated).filter(
            SDM.dag_id == example_bash_op_dag.dag_id).one_or_none()

        # Test that if DAG is not changed, Serialized DAG is not re-written and last_updated
        # column is not updated
        SDM.write_dag(dag=example_bash_op_dag)
        last_updated_1 = session.query(SDM.last_updated).filter(
            SDM.dag_id == example_bash_op_dag.dag_id).one_or_none()

        self.assertEqual(last_updated, last_updated_1)

        # Update DAG
        example_bash_op_dag.tags += ["new_tag"]
        self.assertCountEqual(example_bash_op_dag.tags, ["example", "new_tag"])

        SDM.write_dag(dag=example_bash_op_dag)
        new_s_dag = session.query(SDM.last_updated, SDM.data).filter(
            SDM.dag_id == example_bash_op_dag.dag_id).one_or_none()

        self.assertNotEqual(last_updated, new_s_dag.last_updated)
        self.assertEqual(new_s_dag.data["dag"]["tags"], ["example", "new_tag"])

def setUp(self):
    app = application.create_app(testing=True)
    app.config['WTF_CSRF_METHODS'] = []
    self.app = app.test_client()

    self.default_date = datetime(2020, 3, 1)
    self.dag = DAG(
        "testdag",
        start_date=self.default_date,
        user_defined_filters={"hello": lambda name: 'Hello ' + name},
        user_defined_macros={"fullname": lambda fname, lname: fname + " " + lname},
    )

    self.task1 = BashOperator(
        task_id='task1', bash_command='{{ task_instance_key_str }}', dag=self.dag
    )
    self.task2 = BashOperator(
        task_id='task2',
        bash_command='echo {{ fullname("Apache", "Airflow") | hello }}',
        dag=self.dag,
    )
    SerializedDagModel.write_dag(self.dag)

    with create_session() as session:
        session.query(RTIF).delete()

def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session=None):
    """Only run DagRun.verify_integrity if Serialized DAG has changed since it is slow"""
    latest_version = SerializedDagModel.get_latest_version_hash(dag_run.dag_id, session=session)
    if dag_run.dag_hash == latest_version:
        self.log.debug("DAG %s not changed structure, skipping dagrun.verify_integrity", dag_run.dag_id)
        return

    dag_run.dag_hash = latest_version

    # Refresh the DAG
    dag_run.dag = self.dagbag.get_dag(dag_id=dag_run.dag_id, session=session)

    # Verify integrity also takes care of session.flush
    dag_run.verify_integrity(session=session)

def sync_to_db(self, session: Optional[Session] = None):
    """Save attributes of the DAGs in this DagBag to the DB."""
    # To avoid circular import - airflow.models.dagbag -> airflow.models.dag -> airflow.models.dagbag
    from airflow.models.dag import DAG
    from airflow.models.serialized_dag import SerializedDagModel

    # Retry 'DAG.bulk_write_to_db' & 'SerializedDagModel.bulk_sync_to_db' in case
    # of any Operational Errors
    # In case of failures, provide_session handles rollback
    for attempt in tenacity.Retrying(
        retry=tenacity.retry_if_exception_type(exception_types=OperationalError),
        wait=tenacity.wait_random_exponential(multiplier=0.5, max=5),
        stop=tenacity.stop_after_attempt(settings.MAX_DB_RETRIES),
        before_sleep=tenacity.before_sleep_log(self.log, logging.DEBUG),
        reraise=True,
    ):
        with attempt:
            self.log.debug(
                "Running dagbag.sync_to_db with retries. Try %d of %d",
                attempt.retry_state.attempt_number,
                settings.MAX_DB_RETRIES,
            )
            self.log.debug("Calling the DAG.bulk_sync_to_db method")
            try:
                DAG.bulk_write_to_db(self.dags.values(), session=session)

                # Write Serialized DAGs to DB
                self.log.debug("Calling the SerializedDagModel.bulk_sync_to_db method")
                SerializedDagModel.bulk_sync_to_db(self.dags.values(), session=session)
            except OperationalError:
                session.rollback()
                raise

def test_should_response_200_serialized(self):
    # Create empty app with empty dagbag to check if DAG is read from db
    app_serialized = app.create_app(testing=True)  # type:ignore
    dag_bag = DagBag(os.devnull, include_examples=False, store_serialized_dags=True)
    app_serialized.dag_bag = dag_bag  # type:ignore
    client = app_serialized.test_client()

    SerializedDagModel.write_dag(self.dag)

    expected = {
        'catchup': True,
        'concurrency': 16,
        'dag_id': 'test_dag',
        'dag_run_timeout': None,
        'default_view': 'tree',
        'description': None,
        'doc_md': 'details',
        'fileloc': __file__,
        'is_paused': None,
        'is_subdag': False,
        'orientation': 'LR',
        'owners': [],
        'schedule_interval': {'__type': 'TimeDelta', 'days': 1, 'microseconds': 0, 'seconds': 0},
        'start_date': '2020-06-15T00:00:00+00:00',
        'tags': None,
        'timezone': "Timezone('UTC')",
    }
    response = client.get(f"/api/v1/dags/{self.dag_id}/details")
    assert response.status_code == 200
    assert response.json == expected

def _add_dag_from_db(self, dag_id: str, session: Session):
    """Add DAG to DagBag from DB"""
    from airflow.models.serialized_dag import SerializedDagModel

    row = SerializedDagModel.get(dag_id, session)
    if not row:
        raise SerializedDagNotFound(f"DAG '{dag_id}' not found in serialized_dag table")

    row.load_op_links = self.load_op_links
    dag = row.dag
    for subdag in dag.subdags:
        self.dags[subdag.dag_id] = subdag
    self.dags[dag.dag_id] = dag
    self.dags_last_fetched[dag.dag_id] = timezone.utcnow()
    self.dags_hash[dag.dag_id] = row.dag_hash

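# A hedged end-to-end sketch tying the snippets above together (not part of the original
# snippets): serialize a parsed DAG to the database with SerializedDagModel.write_dag, then
# read it back through a DB-backed DagBag, which goes through _add_dag_from_db internally.
# The example DAG id is illustrative only.
import os

from airflow.models.dagbag import DagBag
from airflow.models.serialized_dag import SerializedDagModel

def serialized_round_trip():
    source_bag = DagBag(include_examples=True)          # parse DAG files from disk
    dag = source_bag.get_dag("example_bash_operator")
    SerializedDagModel.write_dag(dag)                    # persist the serialized form

    db_bag = DagBag(os.devnull, include_examples=False, read_dags_from_db=True)
    return db_bag.get_dag("example_bash_operator")       # loaded from the serialized_dag table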