def _populate_db(self):
    """Reset DagRun/TaskInstance tables, then load the papermill example DAGs into the DB."""
    session = Session()
    # Wipe state left over from earlier tests.
    session.query(DagRun).delete()
    session.query(TaskInstance).delete()
    session.commit()
    session.close()

    bag = DagBag(
        include_examples=True,
        dag_folder=self.PAPERMILL_EXAMPLE_DAGS,
    )
    # Persist every parsed DAG plus its serialized form.
    for parsed_dag in bag.dags.values():
        parsed_dag.sync_to_db()
        SerializedDagModel.write_dag(parsed_dag)
def setUp(self):
    """Build a CSRF-free test client, a DB session, and a sample known-event payload."""
    super(TestKnownEventView, self).setUp()
    flask_app = application.create_app(testing=True)
    flask_app.config['WTF_CSRF_METHODS'] = []
    self.app = flask_app.test_client()
    self.session = Session()
    # Form payload used by the individual tests.
    self.known_event = dict(
        label='event-label',
        event_type='1',
        start_date='2017-06-05 12:00:00',
        end_date='2017-06-05 13:00:00',
        reported_by=self.user_id,
        description='',
    )
def setUp(self, mock_get_connection):
    """Build a SentryHook against a mocked connection plus mocked dag/task/ti fixtures."""
    # The integration flag must hold both before and after hook construction.
    self.assertEqual(TaskInstance._sentry_integration_, True)
    # Hook resolves its DSN from this mocked connection.
    mock_get_connection.return_value = Connection(
        host="https://[email protected]/123")
    self.sentry_hook = SentryHook("sentry_default")
    self.assertEqual(TaskInstance._sentry_integration_, True)
    # Minimal dag/task mocks; the task is its own only relative.
    self.dag = Mock(dag_id=DAG_ID)
    self.task = Mock(dag=self.dag, dag_id=DAG_ID, task_id=TASK_ID)
    self.task.__class__.__name__ = OPERATOR
    self.task.get_flat_relatives = MagicMock(return_value=[self.task])
    self.session = Session()
    self.ti = TaskInstance(self.task, execution_date=EXECUTION_DATE)
    # Any session.query(...) returns a canned query wrapping this ti.
    self.session.query = MagicMock(return_value=MockQuery(self.ti))
def tearDown(self):
    """Remove all DagRun/TaskInstance rows and scrub task-context env vars."""
    super(PythonOperatorTest, self).tearDown()
    session = Session()
    session.query(DagRun).delete()
    session.query(TI).delete()
    # BUGFIX: removed a leftover debug print of the DagRun count.
    session.commit()
    session.close()
    # Tasks export these variables into os.environ; remove them so they
    # cannot leak into the next test.
    for var in TI_CONTEXT_ENV_VARS:
        os.environ.pop(var, None)
def make_subdag(subdag_id, parent_dag_id, start_date, schedule_interval):
    """Build a SubDAG with one SQLSensorOperator per enabled SQLSensor row,
    wired upstream/downstream according to each sensor's parent labels.

    :param subdag_id: id of the subdag (appended to the parent dag id)
    :param parent_dag_id: id of the enclosing parent DAG
    :param start_date: start date for the subdag
    :param schedule_interval: schedule interval for the subdag
    :return: the populated DAG object
    :raises AirflowException: when a parent label has no matching sensor
    """
    subdag = DAG('%s.%s' % (parent_dag_id, subdag_id),
                 start_date=start_date,
                 schedule_interval=schedule_interval,
                 )
    session = Session()
    try:
        sensors = session.query(SQLSensor) \
            .filter(SQLSensor.enabled == True) \
            .all()  # noqa: E712 -- SQLAlchemy needs the explicit comparison
        graphs = []
        for sensor in sensors:
            graph = SQLSensorOperator(
                task_id=sensor.label,
                sql_sensor_id=sensor.id,
                conn_id=sensor.connection.conn_id,
                pool=sensor.pool.pool if sensor.pool else None,
                poke_interval=sensor.poke_interval,
                timeout=sensor.timeout,
                ttl=sensor.ttl,
                main_argument=sensor.main_argument,
                sql=sensor.sql,
                dag=subdag,
            )
            graphs.append(graph)

        def get_parent_ids(sql_sensor_id):
            # Resolve the parent labels stored on a sensor row into sensor ids.
            parent_labels = session.query(SQLSensor.parent_labels) \
                .filter(SQLSensor.id == sql_sensor_id) \
                .first()[0]
            parent_ids = []
            for label in parent_labels or []:
                # BUGFIX: .first() returns None when no row matches, so the old
                # code raised TypeError on the [0] index before the intended
                # AirflowException below could ever fire.
                row = session.query(SQLSensor.id) \
                    .filter(SQLSensor.label == label) \
                    .first()
                if not row or not row[0]:
                    raise AirflowException(
                        "Parent sensor with label '%s' doesn't exists." % label)
                parent_ids.append(row[0])
            return parent_ids

        for graph in graphs:
            parent_ids = get_parent_ids(graph.sql_sensor_id)
            for parent_id in parent_ids:
                for parent_graph in graphs:
                    if parent_id == parent_graph.sql_sensor_id:
                        graph.set_upstream(parent_graph)
    finally:
        # BUGFIX: the session was never closed, leaking a DB connection on
        # every DAG-file parse.
        session.close()
    return subdag
def __update_conn_extra_tokens(self, auth, connection):
    """Persist the Twitter OAuth access tokens into the connection's extras.

    :param auth: OAuth handler carrying access_token / access_token_secret
    :param connection: the Airflow Connection to update
    """
    import json
    from airflow.settings import Session
    conn_extra = connection.extra_dejson
    conn_extra['twitter_access_token'] = auth.access_token
    conn_extra['twitter_access_token_secret'] = auth.access_token_secret
    # BUGFIX: set_extra() backs a string column; serialize the dict rather
    # than handing it a raw dict.
    connection.set_extra(json.dumps(conn_extra))
    session = Session()
    try:
        session.add(connection)
        session.commit()
    finally:
        # BUGFIX: the session was never closed.
        session.close()
    # BUGFIX: the two f-string fragments were concatenated without a space
    # ('..."and ...' in the emitted log line).
    self.log.info(f'Connection {connection.conn_id} updated with "twitter_access_token" '
                  f'and "twitter_access_token_secret" values.')
def test_charts(self):
    """The chart and chart_data endpoints render the example chart."""
    session = Session()
    chart_label = "Airflow task instance by type"
    # Look up the chart's id so we can address it by URL.
    chart_id = session.query(models.Chart) \
        .filter(models.Chart.label == chart_label) \
        .first().id
    session.close()

    chart_url = ('/admin/airflow/chart'
                 '?chart_id={}&iteration_no=1'.format(chart_id))
    assert "Airflow task instance by type" in \
        self.app.get(chart_url).data.decode('utf-8')

    data_url = ('/admin/airflow/chart_data'
                '?chart_id={}&iteration_no=1'.format(chart_id))
    assert "example" in self.app.get(data_url).data.decode('utf-8')
def unpause_dag(dag):
    """
    Wrapper around airflow.bin.cli.unpause. The issue is when we deploy the
    airflow dags they don't exist in the DagModel yet, so need to check if it
    exists first and then run the unpause.

    :param dag: DAG object
    """
    session = Session()
    try:
        dm = session.query(DagModel).filter(DagModel.dag_id == dag.dag_id).first()
        if dm:
            unpause(dag.default_args, dag)
    # BUGFIX: a bare ``except:`` also swallowed SystemExit/KeyboardInterrupt;
    # catch Exception instead while keeping the best-effort rollback.
    except Exception:
        session.rollback()
    finally:
        session.close()
def _get_target_dags(self):
    """Return the dag_ids of unpaused DAGs, optionally filtered by the
    ``self.dag_ids`` prefix masks.

    :return: list of dag_id strings
    """
    session = Session()
    try:
        # BUGFIX: query(DagModel.dag_id).all() yields 1-tuples, not strings,
        # so the startswith()/count() calls below raised AttributeError;
        # unpack the rows first.
        active_dags = [row[0] for row in session.query(DagModel.dag_id).filter(
            DagModel.is_paused.is_(False)).all()]
    finally:
        # BUGFIX: the session was never closed.
        session.close()

    if self.dag_ids is None:
        return active_dags  # subdags always included

    # Keep a dag when some mask matches as a prefix; unless subdags are
    # included, the mask must also sit at the same nesting depth ('.' count).
    return [
        dag_id for dag_id in active_dags
        if any(
            dag_id.startswith(dag_id_mask)
            for dag_id_mask in self.dag_ids
            if (self.include_subdags
                or dag_id.count('.') == dag_id_mask.count('.'))
        )
    ]
def tearDown(self):
    """Delete all DB rows created for TEST_DAG_ID (skipped on Kubernetes CI)."""
    if os.environ.get('KUBERNETES_VERSION') is not None:
        return
    session = Session()
    # Bulk-delete without syncing the (already discarded) in-memory state.
    for model in (DagRun, TaskInstance, TaskFail):
        session.query(model).filter(
            model.dag_id == TEST_DAG_ID).delete(synchronize_session=False)
    session.commit()
    session.close()
def setUp(self):
    """Register a test DAG with successful runs and build a CSRF-free client."""
    flask_app = application.create_app(testing=True)
    flask_app.config['WTF_CSRF_METHODS'] = []
    self.app = flask_app.test_client()
    self.session = Session()

    from airflow.www.views import dagbag
    from airflow.utils.state import State

    dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
    dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
    # One successful, externally triggered run per RUNS_DATA entry.
    self.runs = [
        dag.create_dagrun(
            run_id=rd[0],
            execution_date=rd[1],
            state=State.SUCCESS,
            external_trigger=True,
        )
        for rd in self.RUNS_DATA
    ]
def setUp(self):
    """Create two experimental pools for the pool API tests."""
    super(TestPoolApiExperimental, self).setUp()
    configuration.load_test_config()
    self.app = application.create_app(testing=True).test_client()
    self.session = Session()
    self.pools = []
    # Pools experimental_1 (0 slots) and experimental_2 (1 slot).
    for index in range(1, 3):
        pool_name = 'experimental_%s' % index
        new_pool = Pool(
            pool=pool_name,
            slots=index - 1,
            description=pool_name,
        )
        self.session.add(new_pool)
        self.pools.append(new_pool)
    self.session.commit()
    self.pool = self.pools[0]
def test_delete_dag_button_for_dag_on_scheduler_only(self):
    # Test for JIRA AIRFLOW-3233 (PR 4069):
    # The delete-dag URL should be generated correctly for DAGs
    # that exist on the scheduler (DB) but not the webserver DagBag
    test_dag_id = "non_existent_dag"
    session = Session()
    DM = models.DagModel
    session.query(DM).filter(DM.dag_id == 'example_bash_operator').update({'dag_id': test_dag_id})
    session.commit()
    try:
        resp = self.app.get('/', follow_redirects=True)
        self.assertIn('/delete?dag_id={}'.format(test_dag_id), resp.data.decode('utf-8'))
        self.assertIn("return confirmDeleteDag('{}')".format(test_dag_id),
                      resp.data.decode('utf-8'))
    finally:
        # BUGFIX: restore the renamed dag even if an assertion above fails,
        # so this test cannot poison the shared example_bash_operator fixture.
        session.query(DM).filter(DM.dag_id == test_dag_id).update(
            {'dag_id': 'example_bash_operator'})
        session.commit()
def clear_dag(dag):
    """
    Delete all TaskInstances and DagRuns of the specified dag_id.

    Also removes the dag's DagStat rows and its on-disk log directory.

    :param dag: DAG object
    """
    session = Session()
    try:
        session.query(TaskInstance).filter(TaskInstance.dag_id == dag.dag_id).delete()
        session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).delete()
        session.query(DagStat).filter(DagStat.dag_id == dag.dag_id).delete()
        session.commit()
        log_dir = conf.get('core', 'base_log_folder')
        full_dir = os.path.join(log_dir, dag.dag_id)
        shutil.rmtree(full_dir, ignore_errors=True)
    # BUGFIX: a bare ``except:`` also swallowed SystemExit/KeyboardInterrupt;
    # catch Exception instead while keeping the best-effort rollback.
    except Exception:
        session.rollback()
    finally:
        session.close()
def test_without_dag_run(self):
    """This checks the defensive against non existent tasks in a dag run"""
    self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    session = Session()
    # BUGFIX: materialize the query with .all() before closing the session;
    # the old code closed the session and then iterated the lazy Query.
    tis = session.query(TI).filter(
        TI.dag_id == self.dag.dag_id,
        TI.execution_date == DEFAULT_DATE
    ).all()
    session.close()

    for ti in tis:
        if ti.task_id == 'make_choice':
            # assertEquals is a deprecated alias; use assertEqual.
            self.assertEqual(ti.state, State.SUCCESS)
        elif ti.task_id == 'branch_1':
            # should not exist
            # BUGFIX: a bare ``raise`` with no active exception raises a
            # confusing RuntimeError; fail with an explicit message instead.
            raise ValueError('unexpected task instance: branch_1')
        elif ti.task_id == 'branch_2':
            self.assertEqual(ti.state, State.SKIPPED)
        else:
            raise ValueError('unexpected task instance: %s' % ti.task_id)
def setUp(self):
    """Reset the pools table and create USER_POOL_COUNT extra pools."""
    super(TestPoolApiExperimental, self).setUp()
    self.app = application.create_app(testing=True).test_client()
    self.session = Session()
    clear_db_pools()
    # Keep the default pool first; append the user pools behind it.
    self.pools = [Pool.get_default_pool()]
    for index in range(1, self.USER_POOL_COUNT + 1):
        pool_name = 'experimental_%s' % index
        new_pool = Pool(
            pool=pool_name,
            slots=index - 1,
            description=pool_name,
        )
        self.session.add(new_pool)
        self.pools.append(new_pool)
    self.session.commit()
    self.pool = self.pools[-1]
def setUp(self):
    """Create a configured Sentry instance plus mocked dag/task/ti fixtures."""
    self.sentry = ConfiguredSentry()

    # Mocked dag exposing a single task id.
    self.dag = Mock(dag_id=DAG_ID, params=[])
    self.dag.task_ids = [TASK_ID]

    # Mocked task bound to the dag above.
    self.task = Mock(
        dag=self.dag,
        dag_id=DAG_ID,
        task_id=TASK_ID,
        params=[],
        pool_slots=1,
    )
    self.task.__class__.__name__ = OPERATOR

    self.ti = TaskInstance(self.task, execution_date=EXECUTION_DATE)
    self.ti.operator = OPERATOR
    self.ti.state = STATE

    self.dag.get_task_instances = MagicMock(return_value=[self.ti])
    self.session = Session()
def test_edit_disabled_fields(self):
    """Editing a DAG must update description but never the read-only fileloc."""
    response = self.app.post(
        self.EDIT_URL,
        data={
            "fileloc": "/etc/passwd",
            "description": "Set in tests",
        },
        follow_redirects=True,
    )
    self.assertEqual(response.status_code, 200)

    session = Session()
    dag_model = session.query(models.DagModel).filter(
        models.DagModel.dag_id == 'example_bash_operator').one()
    session.close()

    self.assertEqual(dag_model.description, "Set in tests")
    self.assertNotEqual(dag_model.fileloc, "/etc/passwd",
                        "Disabled fields shouldn't be updated")
def setUp(self):
    """Enable password authentication and create a test user in the DB."""
    configuration.conf.set("webserver", "authenticate", "True")
    configuration.conf.set("webserver", "auth_backend",
                           "airflow.contrib.auth.backends.password_auth")
    app = application.create_app()
    app.config['TESTING'] = True
    self.app = app.test_client()
    from airflow.contrib.auth.backends.password_auth import PasswordUser
    session = Session()
    user = models.User()
    password_user = PasswordUser(user)
    password_user.username = '******'
    password_user.password = '******'
    # BUGFIX: removed a leftover debug print of the hashed password
    # (needless credential-material leakage into test output).
    session.add(password_user)
    session.commit()
    session.close()
def setUp(self):
    """Point the experimental API at password auth and create a test user."""
    configuration.load_test_config()
    try:
        configuration.conf.add_section("api")
    except DuplicateSectionError:
        # Section already present from a previous run.
        pass
    configuration.conf.set("api", "auth_backend",
                           "airflow.contrib.auth.backends.password_auth")
    self.app = application.create_app(testing=True)

    session = Session()
    password_user = PasswordUser(models.User())
    password_user.username = '******'
    password_user.password = '******'
    session.add(password_user)
    session.commit()
    session.close()
def registered(self, driver, frameworkId, masterInfo):
    """Mesos callback: record the assigned framework id, persisting it to the
    DB when checkpointing with a failover timeout is configured.

    :param driver: the Mesos scheduler driver
    :param frameworkId: framework id assigned by the master
    :param masterInfo: info about the elected Mesos master
    """
    logging.info("AirflowScheduler registered to mesos with framework ID %s",
                 frameworkId.value)
    if configuration.getboolean('mesos', 'CHECKPOINT') and \
            configuration.get('mesos', 'FAILOVER_TIMEOUT'):
        # Import here to work around a circular import error
        from airflow.models import Connection

        # Update the Framework ID in the database.
        session = Session()
        conn_id = FRAMEWORK_CONNID_PREFIX + get_framework_name()
        # BUGFIX: the lookup went through the Session factory/class while the
        # add/commit used the ``session`` instance; use one session throughout.
        connection = session.query(Connection).filter_by(conn_id=conn_id).first()
        if connection is None:
            connection = Connection(conn_id=conn_id,
                                    conn_type='mesos_framework-id',
                                    extra=frameworkId.value)
        else:
            connection.extra = frameworkId.value
        session.add(connection)
        session.commit()
        # Scoped-session cleanup for this thread.
        Session.remove()
def initialize_gcp_connection(connection_id, gcp_project_name, gcp_key_path):
    """
    Method to initialize Google Cloud Platform Connection
    :param connection_id: ID of the Airflow connection
    :param gcp_project_name: Name of the GCP project
    :param gcp_key_path: Path of the associated GCP key in the project
    Note:
    Modified date: 11-04-2021
    Author: TB
    """
    def create_new_connection(airflow_session, attributes):
        # Build the Connection row with GCP extras and commit it through
        # the caller's session.
        new_conn = models.Connection()
        new_conn.conn_id = attributes.get("conn_id")
        new_conn.conn_type = attributes.get('conn_type')
        scopes = [
            "https://www.googleapis.com/auth/datastore",
            "https://www.googleapis.com/auth/bigquery",
            "https://www.googleapis.com/auth/devstorage.read_write",
            "https://www.googleapis.com/auth/logging.write",
            "https://www.googleapis.com/auth/cloud-platform",
        ]
        conn_extra = {
            "extra__google_cloud_platform__scope": ",".join(scopes),
            "extra__google_cloud_platform__project": gcp_project_name,
            "extra__google_cloud_platform__key_path": gcp_key_path
        }
        new_conn.set_extra(json.dumps(conn_extra))
        airflow_session.add(new_conn)
        airflow_session.commit()

    session = Session()
    try:
        create_new_connection(session, {
            "conn_id": connection_id,
            "conn_type": "google_cloud_platform"
        })
        # BUGFIX: dropped the redundant second commit (the helper already
        # commits the new connection).
    finally:
        # BUGFIX: close the session even when the insert fails.
        session.close()
def get_filter_by_user_dagids_detail():
    """Return the dag names last edited by the current user.

    :return: list of dag_name strings; empty when user filtering is disabled
             or the user owns no dags
    """
    from sqlalchemy import text

    dags = []
    if not get_filter_by_user():
        # Filtering disabled: nothing to report.
        return dags
    curr_user = airflow.login.current_user
    session = Session()
    try:
        # BUGFIX: the user id was interpolated into the SQL string with ``%``
        # (SQL-injection-prone and fragile); use a bound parameter instead.
        result = session.execute(
            text("""
                SELECT dag_name
                FROM dcmp_dag
                WHERE last_editor_user_id IN (
                    SELECT user_id FROM dcmp_user_profile WHERE user_id = :user_id
                )
            """),
            {"user_id": curr_user.user.id},
        )
        dags = [row.dag_name for row in result.fetchall()]
    finally:
        # BUGFIX: the session leaked on every path that returned no rows.
        session.close()
    return dags
def test_import_variable_fail(self):
    """A variable whose Variable.set raises must not end up in the DB."""
    with mock.patch('airflow.models.Variable.set') as set_mock:
        set_mock.side_effect = UnicodeEncodeError
        content = '{"fail_key": "fail_val"}'
        try:
            # python 3+
            payload = io.BytesIO(bytes(content, encoding='utf-8'))
        except TypeError:
            # python 2.7
            payload = io.BytesIO(bytes(content))

        response = self.app.post(
            self.IMPORT_ENDPOINT,
            data={'file': (payload, 'test.json')},
            follow_redirects=True,
        )
        self.assertEqual(response.status_code, 200)

        session = Session()
        stored = {var.key: var.get_val()
                  for var in session.query(models.Variable).all()}
        session.close()
        self.assertNotIn('fail_key', stored)
def execute(self, context: dict[str, Any]) -> None:
    """Trigger the next active dag (by dag_id order within the prefix) after
    the current one, unless suppressed or already running.

    :param context: Airflow task context; its dag_run conf may carry a
        ``no_next_dag`` flag that suppresses the trigger
    """
    # Do not trigger next dag when param no_next_dag is available
    # Due to bug in Airflow, dagrun misses 'conf' attribute
    # when DAG is triggered from another DAG
    dag_run = context["dag_run"]
    # BUGFIX: getattr() without a default raises AttributeError on exactly the
    # conf-less dag runs this guard exists for; fall back to None.
    if dag_run is not None and (getattr(dag_run, "conf", None) or {}).get("no_next_dag"):
        self.log.info(
            "Not starting next dag ('no_next_dag' in dag_run config)!")
        return

    current_dag_id = self.dag.dag_id
    self.log.info("Starting dag %s", current_dag_id)

    session = Session()
    try:
        active_dag_ids = [
            d.dag_id
            for d in session.query(DagModel).filter(not_(DagModel.is_paused)).
            filter(DagModel.dag_id.like(f"{self.dag_id_prefix}%")).order_by(
                "dag_id")
        ]
    finally:
        # BUGFIX: the session was never closed.
        session.close()

    try:
        current_dag_idx = active_dag_ids.index(current_dag_id)
    except ValueError:
        self.log.error("Current dag %s is not active.", current_dag_id)
        return

    try:
        self.trigger_dag_id = active_dag_ids[current_dag_idx + 1]
        self.log.info("Next dag to trigger %s", self.trigger_dag_id)
    except IndexError:
        self.log.info("Current dag %s is the last dag.", current_dag_id)
        return

    # If the next Dag is currently running, we do not trigger it
    if DagRun.find(dag_id=self.trigger_dag_id, state=State.RUNNING):
        self.log.info("Not starting next dag %s, it is still running.",
                      self.trigger_dag_id)
        return

    super().execute(context)
def get_filter_by_user_dagid(dagid):
    """Check whether ``dagid`` is visible to the current user.

    :param dagid: dag name to check
    :return: True when filtering is disabled or the user last edited the dag
    """
    from sqlalchemy import text

    if not get_filter_by_user():
        # No restriction configured.
        return True
    curr_user = airflow.login.current_user
    session = Session()
    try:
        # BUGFIX: the user id was interpolated into the SQL string with ``%``
        # (SQL-injection-prone and fragile); use a bound parameter instead.
        result = session.execute(
            text("SELECT dag_name FROM dcmp_dag WHERE last_editor_user_id = :user_id"),
            {"user_id": curr_user.user.id},
        )
        owned = [row.dag_name for row in result.fetchall()]
    finally:
        # BUGFIX: session.close() sat after the return statements and never ran.
        session.close()
    # Is dagid one of the user's dag ids?
    return dagid in owned
def setUp(self):
    """Install a custom task-log configuration in a temp settings module,
    then build the app, a test client, and one TaskInstance for log tests."""
    super(TestLogView, self).setUp()

    # Make sure that the configure_logging is not cached
    self.old_modules = dict(sys.modules)

    conf.load_test_config()

    # Create a custom logging configuration
    configuration.load_test_config()
    logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    logging_config['handlers']['task'][
        'base_log_folder'] = os.path.normpath(
            os.path.join(current_dir, 'test_logs'))
    logging_config['handlers']['task']['filename_template'] = \
        '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts | replace(":", ".") }}/{{ try_number }}.log'

    # Write the custom logging configuration to a file
    self.settings_folder = tempfile.mkdtemp()
    settings_file = os.path.join(self.settings_folder, "airflow_local_settings.py")
    new_logging_file = "LOGGING_CONFIG = {}".format(logging_config)
    with open(settings_file, 'w') as handle:
        # BUGFIX: writelines() on a plain string iterates it character by
        # character; write() emits the module source in one call.
        handle.write(new_logging_file)
    sys.path.append(self.settings_folder)
    conf.set('core', 'logging_config_class',
             'airflow_local_settings.LOGGING_CONFIG')

    app = application.create_app(testing=True)
    self.app = app.test_client()
    self.session = Session()
    from airflow.www.views import dagbag

    dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
    task = DummyOperator(task_id=self.TASK_ID, dag=dag)
    dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
    ti = TaskInstance(task=task, execution_date=self.DEFAULT_DATE)
    ti.try_number = 1
    self.session.merge(ti)
    self.session.commit()
def setUp(self):
    """Point file-based task logging at test_logs and seed one TaskInstance."""
    super(TestLogView, self).setUp()
    configuration.load_test_config()

    # Redirect the file.task handler to the test_logs folder next to this file.
    logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG)
    here = os.path.dirname(os.path.abspath(__file__))
    logging_config['handlers']['file.task']['base_log_folder'] = \
        os.path.normpath(os.path.join(here, 'test_logs'))
    logging.config.dictConfig(logging_config)

    app = application.create_app(testing=True)
    self.app = app.test_client()
    self.session = Session()
    from airflow.www.views import dagbag

    dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
    task = DummyOperator(task_id=self.TASK_ID, dag=dag)
    dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
    ti = TaskInstance(task=task, execution_date=self.DEFAULT_DATE)
    ti.try_number = 1
    self.session.merge(ti)
    self.session.commit()
class RepomanAirflowPlugin(AirflowPlugin):
    """Airflow plugin registering the DAG-repo views.

    Instantiating the plugin also creates the DAGRepo table (checkfirst)
    so the views have a backing store.
    """
    def __init__(self):
        # Ensure the backing table exists before any view touches it.
        DAGRepo.__table__.create(engine, checkfirst=True)
        super(RepomanAirflowPlugin, self).__init__()
    name = "airflow_repoman"
    # FAB/appbuilder (RBAC UI) view registration.
    appbuilder_views = [{
        "name": "DAG Repos",
        "category": "Admin",
        "view": DAGRepoView()
    }]
    # Legacy flask-admin view registration.
    # NOTE(review): Session() here is evaluated at class-definition (import)
    # time -- confirm that opening a session at import is intended.
    admin_views = [
        DAGRepoAdminView(DAGRepo, Session(), category="Admin", name="DAG Repos")
    ]
def run_dag_for_each_file(dag_to_trigger, **context) -> None:
    """Trigger ``dag_to_trigger`` once per file name returned by the previous
    task, skipping files that already have a running dag run.

    :param dag_to_trigger: dag_id of the dag to trigger for each file
    :param context: Airflow context used to read the upstream return value
    :raises ValueError: when the previous task returned None
    """
    file_names = get_return_value_from_previous_task(context)
    # BUGFIX: ``assert`` is stripped under ``python -O``; validate explicitly.
    if file_names is None:
        raise ValueError('None type passed from previous task. '
                         'Accepted types are set, list or tuple.')

    session = Session()
    files_triggered = []
    try:
        for file_name in file_names:
            # check if a file has already been triggered for processing
            running = session.query(DagRun).filter(
                and_(DagRun.run_id.startswith(file_name + '_'),
                     DagRun.state == 'running')).first()
            if running:
                continue
            trigger_dag(dag_id=dag_to_trigger,
                        run_id='{}_{}'.format(file_name, uuid4()),
                        conf=json.dumps({'file': file_name}),
                        execution_date=None,
                        replace_microseconds=False)
            files_triggered.append(file_name)
    finally:
        # BUGFIX: the session was never closed.
        session.close()
    # Lazy %-style logging args instead of eager string formatting.
    logger.info('triggered %s for %s files: %s',
                dag_to_trigger, len(files_triggered), files_triggered)