def _write_two_example_dags(self):
    example_dags = make_example_dags(example_dags_module)
    bash_dag = example_dags['example_bash_operator']
    DagCode(bash_dag.fileloc).sync_to_db()
    xcom_dag = example_dags['example_xcom']
    DagCode(xcom_dag.fileloc).sync_to_db()
    return [bash_dag, xcom_dag]
def _refresh_dag_dir(self): """Refresh file paths from dag dir if we haven't done it for too long.""" now = timezone.utcnow() elapsed_time_since_refresh = ( now - self.last_dag_dir_refresh_time).total_seconds() if elapsed_time_since_refresh > self.dag_dir_list_interval: # Build up a list of Python files that could contain DAGs self.log.info("Searching for files in %s", self._dag_directory) self._file_paths = list_py_file_paths(self._dag_directory) self.last_dag_dir_refresh_time = now self.log.info("There are %s files in %s", len(self._file_paths), self._dag_directory) self.set_file_paths(self._file_paths) try: self.log.debug("Removing old import errors") self.clear_nonexistent_import_errors() except Exception: self.log.exception("Error removing old import errors") SerializedDagModel.remove_deleted_dags(self._file_paths) DagModel.deactivate_deleted_dags(self._file_paths) from airflow.models.dagcode import DagCode DagCode.remove_deleted_code(self._file_paths)
def test_code_from_db(admin_client):
    dag = DagBag(include_examples=True).get_dag("example_bash_operator")
    DagCode(dag.fileloc, DagCode._get_code_from_file(dag.fileloc)).sync_to_db()
    url = 'code?dag_id=example_bash_operator'
    resp = admin_client.get(url)
    check_content_not_in_response('Failed to load file', resp)
    check_content_in_response('example_bash_operator', resp)
def _cleanup_stale_dags(self):
    """
    Clean up any DAGs that we have not loaded recently. There are two parts to the cleanup:

    1. Mark DAGs that haven't been seen as inactive
    2. Delete any DAG serializations for DAGs that haven't been seen
    """
    if 0 < self._dag_cleanup_interval < (
            timezone.utcnow() - self.last_dag_cleanup_time).total_seconds():
        # In the worst case, every DAG should have been processed within
        # file_process_interval + processor_timeout + min_serialized_dag_update_interval
        max_processing_time = self._processor_timeout + \
            timedelta(seconds=self._file_process_interval) + \
            timedelta(seconds=self._min_serialized_dag_update_interval)
        min_last_seen_date = timezone.utcnow() - max_processing_time

        self.log.info(
            "Deactivating DAGs that haven't been touched since %s",
            min_last_seen_date.isoformat())
        airflow.models.DAG.deactivate_stale_dags(min_last_seen_date)

        if STORE_SERIALIZED_DAGS:
            from airflow.models.serialized_dag import SerializedDagModel
            SerializedDagModel.remove_stale_dags(min_last_seen_date)

        if self.store_dag_code:
            from airflow.models.dagcode import DagCode
            DagCode.remove_unused_code()

        self.last_dag_cleanup_time = timezone.utcnow()
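# A worked example of the max_processing_time bound above, under hypothetical
# settings: processor_timeout=50s, file_process_interval=30s and
# min_serialized_dag_update_interval=30s give a 110s window, so DAGs last
# touched more than 110s ago are treated as stale.
from datetime import timedelta

max_processing_time = (timedelta(seconds=50)
                       + timedelta(seconds=30)
                       + timedelta(seconds=30))
assert max_processing_time == timedelta(seconds=110)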
def _refresh_dag_dir(self): """ Refresh file paths from dag dir if we haven't done it for too long. """ now = timezone.utcnow() elapsed_time_since_refresh = (now - self.last_dag_dir_refresh_time).total_seconds() if elapsed_time_since_refresh > self.dag_dir_list_interval: # Build up a list of Python files that could contain DAGs self.log.info("Searching for files in %s", self._dag_directory) self._file_paths = list_py_file_paths(self._dag_directory) self.last_dag_dir_refresh_time = now self.log.info("There are %s files in %s", len(self._file_paths), self._dag_directory) self.set_file_paths(self._file_paths) # noinspection PyBroadException try: self.log.debug("Removing old import errors") self.clear_nonexistent_import_errors() # pylint: disable=no-value-for-parameter except Exception: # pylint: disable=broad-except self.log.exception("Error removing old import errors") if STORE_SERIALIZED_DAGS: from airflow.models.serialized_dag import SerializedDagModel from airflow.models.dag import DagModel SerializedDagModel.remove_deleted_dags(self._file_paths) DagModel.deactivate_deleted_dags(self._file_paths) if self.store_dag_code: from airflow.models.dagcode import DagCode DagCode.remove_deleted_code(self._file_paths)
def test_bulk_sync_to_db(self):
    """DAG code can be bulk written into database."""
    example_dags = make_example_dags(example_dags_module)
    files = [dag.fileloc for dag in example_dags.values()]

    with create_session() as session:
        DagCode.bulk_sync_to_db(files, session=session)
        session.commit()

    self._compare_example_dags(example_dags)
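# A minimal usage sketch (not from the source tests), reusing the helpers
# already in scope above. `my_dag_paths` is a hypothetical list of DAG files;
# they must exist on disk, since sync_to_db reads their source.
my_dag_paths = ['/airflow/dags/etl.py', '/airflow/dags/reporting.py']  # hypothetical

# Per-file upsert, as _write_two_example_dags does:
for path in my_dag_paths:
    DagCode(path).sync_to_db()

# Or one session for all files, as test_bulk_sync_to_db exercises:
with create_session() as session:
    DagCode.bulk_sync_to_db(my_dag_paths, session=session)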
def _compare_example_dags(self, example_dags):
    with create_session() as session:
        for dag in example_dags.values():
            self.assertTrue(DagCode.has_dag(dag.fileloc))
            dag_fileloc_hash = DagCode.dag_fileloc_hash(dag.fileloc)
            result = session.query(
                DagCode.fileloc, DagCode.fileloc_hash, DagCode.source_code) \
                .filter(DagCode.fileloc == dag.fileloc) \
                .filter(DagCode.fileloc_hash == dag_fileloc_hash) \
                .one()

            self.assertEqual(result.fileloc, dag.fileloc)
            with open_maybe_zipped(dag.fileloc, 'r') as source:
                source_code = source.read()
            self.assertEqual(result.source_code, source_code)
def upgrade(): """Apply add source code table""" op.create_table( 'dag_code', # pylint: disable=no-member sa.Column('fileloc_hash', sa.BigInteger(), nullable=False, primary_key=True, autoincrement=False), sa.Column('fileloc', sa.String(length=2000), nullable=False), sa.Column('source_code', sa.UnicodeText(), nullable=False), sa.Column('last_updated', sa.TIMESTAMP(timezone=True), nullable=False)) conn = op.get_bind() if conn.dialect.name not in ('sqlite'): op.drop_index('idx_fileloc_hash', 'serialized_dag') op.alter_column(table_name='serialized_dag', column_name='fileloc_hash', type_=sa.BigInteger(), nullable=False) op.create_index( # pylint: disable=no-member 'idx_fileloc_hash', 'serialized_dag', ['fileloc_hash']) sessionmaker = sa.orm.sessionmaker() session = sessionmaker(bind=conn) serialized_dags = session.query(SerializedDagModel).all() for dag in serialized_dags: dag.fileloc_hash = DagCode.dag_fileloc_hash(dag.fileloc) session.merge(dag) session.commit()
def __init__(self, dag: DAG):
    self.dag_id = dag.dag_id
    self.fileloc = dag.full_filepath
    self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
    self.data = SerializedDAG.to_dict(dag)
    self.last_updated = timezone.utcnow()
    self.dag_hash = hashlib.md5(
        json.dumps(self.data, sort_keys=True).encode("utf-8")).hexdigest()
def delete_workflow(self, project_name: Text, workflow_name: Text) -> Optional[WorkflowInfo]:
    dag_id = self.airflow_dag_id(project_name, workflow_name)
    if not self.dag_exist(dag_id):
        return None
    deploy_path = self.config.properties().get('airflow_deploy_path')
    if deploy_path is None:
        raise Exception("airflow_deploy_path config not set!")
    airflow_file_path = os.path.join(deploy_path, dag_id + '.py')
    if os.path.exists(airflow_file_path):
        os.remove(airflow_file_path)

    # stop all workflow executions
    self.kill_all_workflow_execution(project_name, workflow_name)

    # clean db meta
    with create_session() as session:
        dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
        session.query(DagTag).filter(DagTag.dag_id == dag_id).delete()
        session.query(DagModel).filter(DagModel.dag_id == dag_id).delete()
        session.query(DagCode).filter(
            DagCode.fileloc_hash == DagCode.dag_fileloc_hash(dag.fileloc)).delete()
        session.query(SerializedDagModel).filter(
            SerializedDagModel.dag_id == dag_id).delete()
        session.query(DagRun).filter(DagRun.dag_id == dag_id).delete()
        session.query(TaskState).filter(TaskState.dag_id == dag_id).delete()
        session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).delete()
        session.query(TaskExecution).filter(TaskExecution.dag_id == dag_id).delete()
    return WorkflowInfo(namespace=project_name, workflow_name=workflow_name)
def _compare_example_dags(self, example_dags):
    with create_session() as session:
        for dag in example_dags.values():
            if dag.is_subdag:
                dag.fileloc = dag.parent_dag.fileloc
            assert DagCode.has_dag(dag.fileloc)
            dag_fileloc_hash = DagCode.dag_fileloc_hash(dag.fileloc)
            result = (
                session.query(DagCode.fileloc, DagCode.fileloc_hash, DagCode.source_code)
                .filter(DagCode.fileloc == dag.fileloc)
                .filter(DagCode.fileloc_hash == dag_fileloc_hash)
                .one()
            )

            assert result.fileloc == dag.fileloc
            with open_maybe_zipped(dag.fileloc, 'r') as source:
                source_code = source.read()
            assert result.source_code == source_code
def __init__(self, dag: DAG):
    self.dag_id = dag.dag_id
    self.fileloc = dag.full_filepath
    self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
    self.data = SerializedDAG.to_dict(dag)
    self.last_updated = timezone.utcnow()
    dag_event_deps = DagEventDependencies(dag)
    self.event_relationships = DagEventDependencies.to_json(dag_event_deps)
    self.dag_hash = hashlib.md5(
        json.dumps(self.data, sort_keys=True).encode("utf-8")).hexdigest()
def _refresh_dag_dir(self): """Refresh file paths from dag dir if we haven't done it for too long.""" now = timezone.utcnow() elapsed_time_since_refresh = ( now - self.last_dag_dir_refresh_time).total_seconds() if elapsed_time_since_refresh > self.dag_dir_list_interval: # Build up a list of Python files that could contain DAGs self.log.info("Searching for files in %s", self._dag_directory) self._file_paths = list_py_file_paths(self._dag_directory) self.last_dag_dir_refresh_time = now self.log.info("There are %s files in %s", len(self._file_paths), self._dag_directory) self.set_file_paths(self._file_paths) try: self.log.debug("Removing old import errors") self.clear_nonexistent_import_errors() except Exception: self.log.exception("Error removing old import errors") # Check if file path is a zipfile and get the full path of the python file. # Without this, SerializedDagModel.remove_deleted_files would delete zipped dags. # Likewise DagCode.remove_deleted_code dag_filelocs = [] for fileloc in self._file_paths: if not fileloc.endswith(".py") and zipfile.is_zipfile(fileloc): with zipfile.ZipFile(fileloc) as z: dag_filelocs.extend([ os.path.join(fileloc, info.filename) for info in z.infolist() if might_contain_dag(info.filename, True, z) ]) else: dag_filelocs.append(fileloc) SerializedDagModel.remove_deleted_dags(dag_filelocs) DagModel.deactivate_deleted_dags(self._file_paths) from airflow.models.dagcode import DagCode DagCode.remove_deleted_code(dag_filelocs)
def test_remove_unused_code(self):
    example_dags = make_example_dags(example_dags_module)
    self._write_example_dags()

    bash_dag = example_dags['example_bash_operator']
    with create_session() as session:
        for model in models.base.Base._decl_class_registry.values():  # pylint: disable=protected-access
            if hasattr(model, "dag_id"):
                session.query(model) \
                    .filter(model.dag_id == bash_dag.dag_id) \
                    .delete(synchronize_session='fetch')

        self.assertEqual(
            session.query(DagCode).filter(DagCode.fileloc == bash_dag.fileloc).count(), 1)

        DagCode.remove_unused_code()

        self.assertEqual(
            session.query(DagCode).filter(DagCode.fileloc == bash_dag.fileloc).count(), 0)
def get_code(dag_id: str) -> str:
    """Return Python code of a given dag_id.

    :param dag_id: DAG id
    :return: code of the DAG
    """
    dag = check_and_get_dag(dag_id=dag_id)

    try:
        return DagCode.get_code_by_fileloc(dag.fileloc)
    except (OSError, DagCodeNotFound) as exception:
        error_message = f"Error {str(exception)} while reading code of DAG id {dag_id}"
        raise AirflowException(error_message, exception)
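# A minimal usage sketch (hypothetical call site), assuming the
# example_bash_operator DAG has already been synced to the DB as in
# test_code_from_db above:
source = get_code("example_bash_operator")
assert 'example_bash_operator' in source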
def recover(self, last_scheduling_id):
    lost_dag_codes = DagCode.recover_lost_dag_code()
    self.log.info(
        "Found %s dags missing from the DAG folder; recovered them from the DB. Dag paths: %s",
        len(lost_dag_codes), lost_dag_codes)
    self.log.info("Waiting for executor recovery...")
    self.executor.recover_state()
    unprocessed_messages = self.get_unprocessed_message(last_scheduling_id)
    self.log.info(
        "Recovering %s messages of last scheduler job with id: %s",
        len(unprocessed_messages), last_scheduling_id)
    for msg in unprocessed_messages:
        self.mailbox.send_identified_message(msg)
def test_code_can_be_read_when_no_access_to_file(self):
    """
    Test that code can be retrieved from DB when you do not have access to the code file.
    Source code should exist in at least one of DB or file.
    """
    example_dag = make_example_dags(example_dags_module).get('example_bash_operator')
    example_dag.sync_to_db()

    # Mock that there is no access to the DAG file
    with patch('airflow.models.dagcode.open_maybe_zipped') as mock_open:
        mock_open.side_effect = FileNotFoundError
        dag_code = DagCode.get_code_by_fileloc(example_dag.fileloc)

    for test_string in ['example_bash_operator', 'also_run_this', 'run_this_last']:
        assert test_string in dag_code
def upgrade(): """Create DagCode Table.""" from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() class SerializedDagModel(Base): __tablename__ = 'serialized_dag' # There are other columns here, but these are the only ones we need for the SELECT/UPDATE we are doing dag_id = sa.Column(sa.String(250), primary_key=True) fileloc = sa.Column(sa.String(2000), nullable=False) fileloc_hash = sa.Column(sa.BigInteger, nullable=False) """Apply add source code table""" op.create_table( 'dag_code', # pylint: disable=no-member sa.Column('fileloc_hash', sa.BigInteger(), nullable=False, primary_key=True, autoincrement=False), sa.Column('fileloc', sa.String(length=2000), nullable=False), sa.Column('source_code', sa.UnicodeText(), nullable=False), sa.Column('last_updated', sa.TIMESTAMP(timezone=True), nullable=False), ) conn = op.get_bind() if conn.dialect.name != 'sqlite': if conn.dialect.name == "mssql": op.drop_index('idx_fileloc_hash', 'serialized_dag') op.alter_column(table_name='serialized_dag', column_name='fileloc_hash', type_=sa.BigInteger(), nullable=False) if conn.dialect.name == "mssql": op.create_index('idx_fileloc_hash', 'serialized_dag', ['fileloc_hash']) sessionmaker = sa.orm.sessionmaker() session = sessionmaker(bind=conn) serialized_dags = session.query(SerializedDagModel).all() for dag in serialized_dags: dag.fileloc_hash = DagCode.dag_fileloc_hash(dag.fileloc) session.merge(dag) session.commit()
def get_dag_source(file_token: str):
    """Get source code using file token."""
    secret_key = current_app.config["SECRET_KEY"]
    auth_s = URLSafeSerializer(secret_key)
    try:
        path = auth_s.loads(file_token)
        dag_source = DagCode.code(path)
    except (BadSignature, FileNotFoundError):
        raise NotFound("Dag source not found")

    return_type = request.accept_mimetypes.best_match(['text/plain', 'application/json'])
    if return_type == 'text/plain':
        return Response(dag_source, headers={'Content-Type': return_type})
    if return_type == 'application/json':
        content = dag_source_schema.dumps(dict(content=dag_source))
        return Response(content, headers={'Content-Type': return_type})
    return Response("Not Allowed Accept Header", status=406)
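# A minimal sketch of how a matching file_token can be minted (not from the
# source): URLSafeSerializer is itsdangerous', as used above, and the key must
# equal the SECRET_KEY the endpoint reads from the app config. The path and
# key are hypothetical.
from itsdangerous import URLSafeSerializer

secret_key = "my-secret-key"  # hypothetical; must match current_app.config["SECRET_KEY"]
file_token = URLSafeSerializer(secret_key).dumps("/airflow/dags/example.py")
# A client would then request the DAG source with this token and an Accept
# header of text/plain or application/json, as handled above.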
def remove_deleted_dags(cls, alive_dag_filelocs: List[str], session=None):
    """Deletes DAGs not included in alive_dag_filelocs.

    :param alive_dag_filelocs: file paths of alive DAGs
    :param session: ORM Session
    """
    alive_fileloc_hashes = [
        DagCode.dag_fileloc_hash(fileloc) for fileloc in alive_dag_filelocs
    ]

    log.debug(
        "Deleting Serialized DAGs (for which DAG files are deleted) from %s table",
        cls.__tablename__)

    session.execute(
        cls.__table__.delete().where(
            and_(cls.fileloc_hash.notin_(alive_fileloc_hashes),
                 cls.fileloc.notin_(alive_dag_filelocs))))
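# A minimal usage sketch (hypothetical paths), assuming this classmethod lives
# on SerializedDagModel as the log message suggests: after a directory scan,
# pass the files that still exist so rows for deleted files are purged. Note
# the delete filters on both the hash and the path, so a hash collision alone
# cannot remove a still-alive DAG's row.
alive = ['/airflow/dags/etl.py', '/airflow/dags/reporting.py']
with create_session() as session:
    SerializedDagModel.remove_deleted_dags(alive, session=session)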
def __init__(self, dag: DAG):
    self.dag_id = dag.dag_id
    self.fileloc = dag.fileloc
    self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
    self.last_updated = timezone.utcnow()

    dag_data = SerializedDAG.to_dict(dag)
    dag_data_json = json.dumps(dag_data, sort_keys=True).encode("utf-8")
    self.dag_hash = hashlib.md5(dag_data_json).hexdigest()

    if COMPRESS_SERIALIZED_DAGS:
        self._data = None
        self._data_compressed = zlib.compress(dag_data_json)
    else:
        self._data = dag_data
        self._data_compressed = None

    # Serves as a cache so there is no need to decompress and load when
    # accessing the data field while COMPRESS_SERIALIZED_DAGS is True.
    self.__data_cache = dag_data
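# A minimal sketch of the compression round trip used above: zlib.compress on
# the canonical JSON bytes, then zlib.decompress + json.loads to recover the
# dict when the data field is read.
import json
import zlib

payload = json.dumps({"dag_id": "demo"}, sort_keys=True).encode("utf-8")
blob = zlib.compress(payload)
assert json.loads(zlib.decompress(blob)) == {"dag_id": "demo"}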
def test_dag_fileloc_hash(self):
    """Verifies the correctness of hashing file path."""
    self.assertEqual(DagCode.dag_fileloc_hash('/airflow/dags/test_dag.py'),
                     33826252060516589)
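# A small sanity sketch: dag_fileloc_hash is a pure function of the path, so
# it can be precomputed and used for indexed lookups, as the queries in
# _compare_example_dags above do.
h1 = DagCode.dag_fileloc_hash('/airflow/dags/test_dag.py')
h2 = DagCode.dag_fileloc_hash('/airflow/dags/test_dag.py')
assert h1 == h2 == 33826252060516589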
def _write_example_dags(self):
    example_dags = make_example_dags(example_dags_module)
    for dag in example_dags.values():
        DagCode(dag.fileloc).sync_to_db()
    return example_dags
def __init__(self, dag):
    self.dag_id = dag.dag_id
    self.fileloc = dag.full_filepath
    self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
    self.data = SerializedDAG.to_dict(dag)
    self.last_updated = timezone.utcnow()
def test_write_dag(self):
    self._write_example_dags()
    self.assertEqual(len(DagCode.recover_lost_dag_code()), 0)