Example #1
    def _populate_db(self):
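        # Clear any existing DagRun/TaskInstance rows before re-syncing the example DAGs below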
        session = Session()
        session.query(DagRun).delete()
        session.query(TaskInstance).delete()
        session.commit()
        session.close()

        dagbag = DagBag(
            include_examples=True,
            dag_folder=self.PAPERMILL_EXAMPLE_DAGS,
        )
        for dag in dagbag.dags.values():
            dag.sync_to_db()
            SerializedDagModel.write_dag(dag)
Example #2
 def setUp(self):
     super(TestKnownEventView, self).setUp()
     app = application.create_app(testing=True)
     app.config['WTF_CSRF_METHODS'] = []
     self.app = app.test_client()
     self.session = Session()
     self.known_event = {
         'label': 'event-label',
         'event_type': '1',
         'start_date': '2017-06-05 12:00:00',
         'end_date': '2017-06-05 13:00:00',
         'reported_by': self.user_id,
         'description': '',
     }
Example #3
    def setUp(self, mock_get_connection):
        self.assertEqual(TaskInstance._sentry_integration_, True)
        mock_get_connection.return_value = Connection(
            host="https://[email protected]/123")
        self.sentry_hook = SentryHook("sentry_default")
        self.assertEqual(TaskInstance._sentry_integration_, True)
        self.dag = Mock(dag_id=DAG_ID)
        self.task = Mock(dag=self.dag, dag_id=DAG_ID, task_id=TASK_ID)
        self.task.__class__.__name__ = OPERATOR
        self.task.get_flat_relatives = MagicMock(return_value=[self.task])

        self.session = Session()
        self.ti = TaskInstance(self.task, execution_date=EXECUTION_DATE)
        self.session.query = MagicMock(return_value=MockQuery(self.ti))
Example #4
    def tearDown(self):
        super(PythonOperatorTest, self).tearDown()

        session = Session()

        session.query(DagRun).delete()
        session.query(TI).delete()
        print(len(session.query(DagRun).all()))
        session.commit()
        session.close()

        for var in TI_CONTEXT_ENV_VARS:
            if var in os.environ:
                del os.environ[var]
Example #5
def make_subdag(subdag_id, parent_dag_id, start_date, schedule_interval):
    subdag = DAG('%s.%s' % (parent_dag_id, subdag_id),
                 start_date=start_date,
                 schedule_interval=schedule_interval,
                 )

    session = Session()
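    # Build one SQLSensorOperator task per enabled sensor row stored in the metadata DB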
    sensors = session.query(SQLSensor) \
        .filter(SQLSensor.enabled == True) \
        .all()

    graphs = []
    for sensor in sensors:
        graph = SQLSensorOperator(
            task_id=sensor.label,
            sql_sensor_id=sensor.id,
            conn_id=sensor.connection.conn_id,
            pool=sensor.pool.pool if sensor.pool else None,
            poke_interval=sensor.poke_interval,
            timeout=sensor.timeout,
            ttl=sensor.ttl,
            main_argument=sensor.main_argument,
            sql=sensor.sql,
            dag=subdag,
        )
        graphs.append(graph)

    def get_parent_ids(sql_sensor_id):
        parent_labels = session.query(SQLSensor.parent_labels) \
            .filter(SQLSensor.id == sql_sensor_id) \
            .first()[0]
        parent_ids = []
        for label in parent_labels or []:
            result = session.query(SQLSensor.id) \
                .filter(SQLSensor.label == label) \
                .first()
            if not result:
                raise AirflowException("Parent sensor with label '%s' doesn't exist." % label)
            parent_ids.append(result[0])
        return parent_ids

    for graph in graphs:
        parent_ids = get_parent_ids(graph.sql_sensor_id)
        for parent_id in parent_ids:
            for parent_graph in graphs:
                if parent_id == parent_graph.sql_sensor_id:
                    graph.set_upstream(parent_graph)

    return subdag
Example #6
    def __update_conn_extra_tokens(self, auth, connection):
        from airflow.settings import Session

        conn_extra = connection.extra_dejson

        conn_extra['twitter_access_token'] = auth.access_token
        conn_extra['twitter_access_token_secret'] = auth.access_token_secret
        connection.set_extra(conn_extra)

        session = Session()
        session.add(connection)
        session.commit()

        self.log.info(f'Connection {connection.conn_id} updated with "twitter_access_token" '
                      f'and "twitter_access_token_secret" values.')
Example #7
 def test_charts(self):
     session = Session()
     chart_label = "Airflow task instance by type"
     chart = session.query(
         models.Chart).filter(models.Chart.label == chart_label).first()
     chart_id = chart.id
     session.close()
     response = self.app.get(
         '/admin/airflow/chart'
         '?chart_id={}&iteration_no=1'.format(chart_id))
     assert "Airflow task instance by type" in response.data.decode('utf-8')
     response = self.app.get(
         '/admin/airflow/chart_data'
         '?chart_id={}&iteration_no=1'.format(chart_id))
     assert "example" in response.data.decode('utf-8')
Example #8
def unpause_dag(dag):
    """
    Wrapper around airflow.bin.cli.unpause. The issue is when we deploy the airflow dags they don't exist
    in the DagModel yet, so need to check if it exists first and then run the unpause.
    :param dag: DAG object
    """
    session = Session()
    try:
        dm = session.query(DagModel).filter(DagModel.dag_id == dag.dag_id).first()
        if dm:
            unpause(dag.default_args, dag)
    except Exception:
        session.rollback()
    finally:
        session.close()
Example #9
 def _get_target_dags(self):
     session = Session()
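     # Only unpaused DAGs are candidates; optionally narrow them to the configured dag_id prefixes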
     active_dags = session.query(DagModel.dag_id).filter(
         DagModel.is_paused.is_(False)).all()
     if self.dag_ids is None:
         target_dags = active_dags  # subdags always included
     else:
         target_dags = [
             dag_id for dag_id in active_dags if any(
                 dag_id.startswith(dag_id_mask)
                 for dag_id_mask in self.dag_ids
                 if (self.include_subdags
                     or dag_id.count('.') == dag_id_mask.count('.'))
             )
         ]
     return target_dags
Example #10
    def tearDown(self):
        if os.environ.get('KUBERNETES_VERSION') is not None:
            return

        session = Session()
        session.query(DagRun).filter(
            DagRun.dag_id == TEST_DAG_ID).delete(
            synchronize_session=False)
        session.query(TaskInstance).filter(
            TaskInstance.dag_id == TEST_DAG_ID).delete(
            synchronize_session=False)
        session.query(TaskFail).filter(
            TaskFail.dag_id == TEST_DAG_ID).delete(
            synchronize_session=False)
        session.commit()
        session.close()
Example #11
 def setUp(self):
     app = application.create_app(testing=True)
     app.config['WTF_CSRF_METHODS'] = []
     self.app = app.test_client()
     self.session = Session()
     from airflow.www.views import dagbag
     from airflow.utils.state import State
     dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
     dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
     self.runs = []
     for rd in self.RUNS_DATA:
         run = dag.create_dagrun(run_id=rd[0],
                                 execution_date=rd[1],
                                 state=State.SUCCESS,
                                 external_trigger=True)
         self.runs.append(run)
Example #12
 def setUp(self):
     super(TestPoolApiExperimental, self).setUp()
     configuration.load_test_config()
     app = application.create_app(testing=True)
     self.app = app.test_client()
     self.session = Session()
     self.pools = []
     for i in range(2):
         name = 'experimental_%s' % (i + 1)
         pool = Pool(
             pool=name,
             slots=i,
             description=name,
         )
         self.session.add(pool)
         self.pools.append(pool)
     self.session.commit()
     self.pool = self.pools[0]
Example #13
    def test_delete_dag_button_for_dag_on_scheduler_only(self):
        # Test for JIRA AIRFLOW-3233 (PR 4069):
        # The delete-dag URL should be generated correctly for DAGs
        # that exist on the scheduler (DB) but not the webserver DagBag

        test_dag_id = "non_existent_dag"

        session = Session()
        DM = models.DagModel
        session.query(DM).filter(DM.dag_id == 'example_bash_operator').update({'dag_id': test_dag_id})
        session.commit()

        resp = self.app.get('/', follow_redirects=True)
        self.assertIn('/delete?dag_id={}'.format(test_dag_id), resp.data.decode('utf-8'))
        self.assertIn("return confirmDeleteDag('{}')".format(test_dag_id), resp.data.decode('utf-8'))

        session.query(DM).filter(DM.dag_id == test_dag_id).update({'dag_id': 'example_bash_operator'})
        session.commit()
Example #14
def clear_dag(dag):
    """
    Delete all TaskInstances and DagRuns of the specified dag_id.
    :param dag: DAG object
    """
    session = Session()
    try:
        session.query(TaskInstance).filter(TaskInstance.dag_id == dag.dag_id).delete()
        session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).delete()
        session.query(DagStat).filter(DagStat.dag_id == dag.dag_id).delete()
        session.commit()
        log_dir = conf.get('core', 'base_log_folder')
        full_dir = os.path.join(log_dir, dag.dag_id)
        shutil.rmtree(full_dir, ignore_errors=True)
    except Exception:
        session.rollback()
    finally:
        session.close()
Example #15
    def test_without_dag_run(self):
        """This checks the defensive against non existent tasks in a dag run"""
        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        session = Session()
        tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id,
                                       TI.execution_date == DEFAULT_DATE)
        session.close()

        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEquals(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_1':
                # should not exist
                raise
            elif ti.task_id == 'branch_2':
                self.assertEquals(ti.state, State.SKIPPED)
            else:
                raise
Example #16
    def setUp(self):
        super(TestPoolApiExperimental, self).setUp()
        app = application.create_app(testing=True)
        self.app = app.test_client()
        self.session = Session()

        clear_db_pools()
        self.pools = [Pool.get_default_pool()]
        for i in range(self.USER_POOL_COUNT):
            name = 'experimental_%s' % (i + 1)
            pool = Pool(
                pool=name,
                slots=i,
                description=name,
            )
            self.session.add(pool)
            self.pools.append(pool)
        self.session.commit()
        self.pool = self.pools[-1]
Example #17
    def setUp(self):

        self.sentry = ConfiguredSentry()

        # Mock the Dag
        self.dag = Mock(dag_id=DAG_ID, params=[])
        self.dag.task_ids = [TASK_ID]

        # Mock the task
        self.task = Mock(dag=self.dag, dag_id=DAG_ID, task_id=TASK_ID, params=[], pool_slots=1)
        self.task.__class__.__name__ = OPERATOR

        self.ti = TaskInstance(self.task, execution_date=EXECUTION_DATE)
        self.ti.operator = OPERATOR
        self.ti.state = STATE

        self.dag.get_task_instances = MagicMock(return_value=[self.ti])

        self.session = Session()
Example #18
    def test_edit_disabled_fields(self):
        response = self.app.post(
            self.EDIT_URL,
            data={
                "fileloc": "/etc/passwd",
                "description": "Set in tests",
            },
            follow_redirects=True,
        )
        self.assertEqual(response.status_code, 200)
        session = Session()
        DM = models.DagModel
        dm = session.query(DM).filter(
            DM.dag_id == 'example_bash_operator').one()
        session.close()

        self.assertEqual(dm.description, "Set in tests")
        self.assertNotEqual(dm.fileloc, "/etc/passwd",
                            "Disabled fields shouldn't be updated")
Example #19
    def setUp(self):
        configuration.conf.set("webserver", "authenticate", "True")
        configuration.conf.set("webserver", "auth_backend",
                               "airflow.contrib.auth.backends.password_auth")

        app = application.create_app()
        app.config['TESTING'] = True
        self.app = app.test_client()
        from airflow.contrib.auth.backends.password_auth import PasswordUser

        session = Session()
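        # Seed a password-auth user so the test client can authenticate against the app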
        user = models.User()
        password_user = PasswordUser(user)
        password_user.username = '******'
        password_user.password = '******'
        print(password_user._password)
        session.add(password_user)
        session.commit()
        session.close()
Example #20
    def setUp(self):
        configuration.load_test_config()
        try:
            configuration.conf.add_section("api")
        except DuplicateSectionError:
            pass

        configuration.conf.set("api", "auth_backend",
                               "airflow.contrib.auth.backends.password_auth")

        self.app = application.create_app(testing=True)

        session = Session()
        user = models.User()
        password_user = PasswordUser(user)
        password_user.username = '******'
        password_user.password = '******'
        session.add(password_user)
        session.commit()
        session.close()
Example #21
    def registered(self, driver, frameworkId, masterInfo):
        logging.info("AirflowScheduler registered to mesos with framework ID %s", frameworkId.value)

        if configuration.getboolean('mesos', 'CHECKPOINT') and configuration.get('mesos', 'FAILOVER_TIMEOUT'):
            # Import here to work around a circular import error
            from airflow.models import Connection

            # Update the Framework ID in the database.
            session = Session()
            conn_id = FRAMEWORK_CONNID_PREFIX + get_framework_name()
            connection = Session.query(Connection).filter_by(conn_id=conn_id).first()
            if connection is None:
                connection = Connection(conn_id=conn_id, conn_type='mesos_framework-id',
                                        extra=frameworkId.value)
            else:
                connection.extra = frameworkId.value

            session.add(connection)
            session.commit()
            Session.remove()
Example #22
def initialize_gcp_connection(connection_id, gcp_project_name, gcp_key_path):
    """
    Method to initialize Google Cloud Platform Connection
    :param connection_id: ID of the Airflow connection
    :param gcp_project_name: Name of the GCP project
    :param gcp_key_path: Path of the associated GCP key in the project
    Note:
        Modified date: 11-04-2021
        Author: TB
    """
    def create_new_connection(airflow_session, attributes):
        new_conn = models.Connection()
        new_conn.conn_id = attributes.get("conn_id")
        new_conn.conn_type = attributes.get('conn_type')

        scopes = [
            "https://www.googleapis.com/auth/datastore",
            "https://www.googleapis.com/auth/bigquery",
            "https://www.googleapis.com/auth/devstorage.read_write",
            "https://www.googleapis.com/auth/logging.write",
            "https://www.googleapis.com/auth/cloud-platform",
        ]

        conn_extra = {
            "extra__google_cloud_platform__scope": ",".join(scopes),
            "extra__google_cloud_platform__project": gcp_project_name,
            "extra__google_cloud_platform__key_path": gcp_key_path
        }

        conn_extra_json = json.dumps(conn_extra)
        new_conn.set_extra(conn_extra_json)
        airflow_session.add(new_conn)
        airflow_session.commit()

    session = Session()
    create_new_connection(session, {
        "conn_id": connection_id,
        "conn_type": "google_cloud_platform"
    })
    session.commit()
    session.close()
Example #23
def get_filter_by_user_dagids_detail():
    res = get_filter_by_user()
    curr_user = airflow.login.current_user
    session = Session()
    dags = []
    if res:
        result = session.execute("""
                                     SELECT dag_name
                                       FROM dcmp_dag
                                      WHERE last_editor_user_id IN (
                                                    select user_id
                                                      from dcmp_user_profile
                                                     where user_id = %s
                                                )
                                 """ % curr_user.user.id)
        records = []
        if result.rowcount > 0:
            records = result.fetchall()
        session.close()
        for x in records:
            dags.append(x.dag_name)
    return dags
Example #24
    def test_import_variable_fail(self):
        with mock.patch('airflow.models.Variable.set') as set_mock:
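            # Force Variable.set to raise so the imported key must never be stored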
            set_mock.side_effect = UnicodeEncodeError
            content = '{"fail_key": "fail_val"}'

            try:
                # python 3+
                bytes_content = io.BytesIO(bytes(content, encoding='utf-8'))
            except TypeError:
                # python 2.7
                bytes_content = io.BytesIO(bytes(content))
            response = self.app.post(
                self.IMPORT_ENDPOINT,
                data={'file': (bytes_content, 'test.json')},
                follow_redirects=True
            )
            self.assertEqual(response.status_code, 200)
            session = Session()
            db_dict = {x.key: x.get_val() for x in session.query(models.Variable).all()}
            session.close()
            self.assertNotIn('fail_key', db_dict)
Example #25
    def execute(self, context: dict[str, Any]) -> None:
        # Do not trigger next dag when param no_next_dag is available
        # Due to bug in Airflow, dagrun misses 'conf' attribute
        # when DAG is triggered from another DAG
        dag_run = context["dag_run"]
        if dag_run is not None and (getattr(dag_run, "conf")
                                    or {}).get("no_next_dag"):
            self.log.info(
                "Not starting next dag ('no_next_dag' in dag_run config)!")
            return
        current_dag_id = self.dag.dag_id
        self.log.info("Starting dag %s", current_dag_id)
        session = Session()
        active_dag_ids = [
            d.dag_id
            for d in session.query(DagModel).filter(not_(DagModel.is_paused)).
            filter(DagModel.dag_id.like(f"{self.dag_id_prefix}%")).order_by(
                "dag_id")
        ]

        try:
            current_dag_idx = active_dag_ids.index(current_dag_id)
        except ValueError:
            self.log.error("Current dag %s is not active.", current_dag_id)
            return

        try:
            self.trigger_dag_id = active_dag_ids[current_dag_idx + 1]
            self.log.info("Next dag to trigger %s", self.trigger_dag_id)
        except IndexError:
            self.log.info("Current dag %s is the last dag.", current_dag_id)
            return

        # If the next Dag is currently running, we do not trigger it
        if DagRun.find(dag_id=self.trigger_dag_id, state=State.RUNNING):
            self.log.info("Not starting next dag %s, it is still running.",
                          self.trigger_dag_id)
            return

        super().execute(context)
Example #26
def get_filter_by_user_dagid(dagid):
    res = get_filter_by_user()
    curr_user = airflow.login.current_user
    session = Session()
    if res:
        result = session.execute("""
                                     SELECT dag_name
                                       FROM dcmp_dag
                                      WHERE last_editor_user_id = %s
                                 """ % curr_user.user.id)
        try:
            if result.rowcount > 0:
                records = result.fetchall()
                dags = []
                for x in records:
                    dags.append(x.dag_name)
                return dagid in dags  # whether dagid belongs to the current user's DAGs
            else:
                return False
        finally:
            session.close()
    else:
        # no restriction
        return True
Example #27
    def setUp(self):
        super(TestLogView, self).setUp()
        # Make sure that the configure_logging is not cached
        self.old_modules = dict(sys.modules)

        conf.load_test_config()

        # Create a custom logging configuration
        configuration.load_test_config()
        logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG)
        current_dir = os.path.dirname(os.path.abspath(__file__))
        logging_config['handlers']['task'][
            'base_log_folder'] = os.path.normpath(
                os.path.join(current_dir, 'test_logs'))
        logging_config['handlers']['task']['filename_template'] = \
            '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts | replace(":", ".") }}/{{ try_number }}.log'

        # Write the custom logging configuration to a file
        self.settings_folder = tempfile.mkdtemp()
        settings_file = os.path.join(self.settings_folder,
                                     "airflow_local_settings.py")
        new_logging_file = "LOGGING_CONFIG = {}".format(logging_config)
        with open(settings_file, 'w') as handle:
            handle.writelines(new_logging_file)
        sys.path.append(self.settings_folder)
        conf.set('core', 'logging_config_class',
                 'airflow_local_settings.LOGGING_CONFIG')

        app = application.create_app(testing=True)
        self.app = app.test_client()
        self.session = Session()
        from airflow.www.views import dagbag
        dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
        task = DummyOperator(task_id=self.TASK_ID, dag=dag)
        dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
        ti = TaskInstance(task=task, execution_date=self.DEFAULT_DATE)
        ti.try_number = 1
        self.session.merge(ti)
        self.session.commit()
Example #28
    def setUp(self):
        super(TestLogView, self).setUp()

        configuration.load_test_config()
        logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG)
        current_dir = os.path.dirname(os.path.abspath(__file__))
        logging_config['handlers']['file.task'][
            'base_log_folder'] = os.path.normpath(
                os.path.join(current_dir, 'test_logs'))
        logging.config.dictConfig(logging_config)

        app = application.create_app(testing=True)
        self.app = app.test_client()
        self.session = Session()
        from airflow.www.views import dagbag
        dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
        task = DummyOperator(task_id=self.TASK_ID, dag=dag)
        dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
        ti = TaskInstance(task=task, execution_date=self.DEFAULT_DATE)
        ti.try_number = 1
        self.session.merge(ti)
        self.session.commit()
Example #29
class RepomanAirflowPlugin(AirflowPlugin):
    """
    Airflow Plugin
    """
    def __init__(self):
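        # Ensure the DAGRepo table exists before the plugin's views query it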
        DAGRepo.__table__.create(engine, checkfirst=True)

        super(RepomanAirflowPlugin, self).__init__()

    name = "airflow_repoman"

    appbuilder_views = [{
        "name": "DAG Repos",
        "category": "Admin",
        "view": DAGRepoView()
    }]

    admin_views = [
        DAGRepoAdminView(DAGRepo,
                         Session(),
                         category="Admin",
                         name="DAG Repos")
    ]
Example #30
def run_dag_for_each_file(dag_to_trigger, **context) -> None:
    file_names = get_return_value_from_previous_task(context)
    message = 'None type passed from previous task. Accepted types are set, list or tuple.'
    assert file_names is not None, message

    session = Session()
    files_triggered = []
    for file_name in file_names:

        # check if a file has already been triggered for processing
        if session.query(DagRun).filter(and_(DagRun.run_id.startswith(file_name + '_'),
                                             DagRun.state == 'running')).first():
            continue

        trigger_dag(dag_id=dag_to_trigger,
                    run_id='{}_{}'.format(file_name, uuid4()),
                    conf=json.dumps({'file': file_name}),
                    execution_date=None,
                    replace_microseconds=False)

        files_triggered.append(file_name)

    logger.info('triggered %s for %s files: %s' % (dag_to_trigger, len(files_triggered), files_triggered))