Example #1
def get_task_instance(dag_id, task_id, execution_date):
    """Return the task object identified by the given dag_id and task_id."""

    dagbag = DagBag()

    # Check DAG exists.
    if dag_id not in dagbag.dags:
        error_message = "Dag id {} not found".format(dag_id)
        raise DagNotFound(error_message)

    # Get DAG object and check Task Exists
    dag = dagbag.get_dag(dag_id)
    if not dag.has_task(task_id):
        error_message = 'Task {} not found in dag {}'.format(task_id, dag_id)
        raise TaskNotFound(error_message)

    # Get DagRun object and check that it exists
    dagrun = dag.get_dagrun(execution_date=execution_date)
    if not dagrun:
        error_message = ('Dag Run for date {} not found in dag {}'
                         .format(execution_date, dag_id))
        raise DagRunNotFound(error_message)

    # Get task instance object and check that it exists
    task_instance = dagrun.get_task_instance(task_id)
    if not task_instance:
        error_message = ('Task {} instance for date {} not found'
                         .format(task_id, execution_date))
        raise TaskInstanceNotFound(error_message)

    return task_instance
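
A quick illustrative call of the helper above; the dag_id, task_id and date are made up and must match an existing DAG run in a real deployment:

# Hypothetical usage sketch; dag_id, task_id and the date are illustrative.
from datetime import datetime

try:
    ti = get_task_instance('example_bash_operator', 'runme_0',
                           datetime(2019, 1, 1))
    print(ti.state)
except (DagNotFound, TaskNotFound, DagRunNotFound,
        TaskInstanceNotFound) as exc:
    print(str(exc))
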
Example #2
def downgrade():
    engine = settings.engine
    connection = op.get_bind()
    if engine.dialect.has_table(connection, 'task_instance'):
        sessionmaker = sa.orm.sessionmaker()
        session = sessionmaker(bind=connection)
        dagbag = DagBag(settings.DAGS_FOLDER)
        query = session.query(sa.func.count(TaskInstance.max_tries)).filter(
            TaskInstance.max_tries != -1
        )
        while query.scalar():
            tis = session.query(TaskInstance).filter(
                TaskInstance.max_tries != -1
            ).limit(BATCH_SIZE).all()
            for ti in tis:
                dag = dagbag.get_dag(ti.dag_id)
                if not dag or not dag.has_task(ti.task_id):
                    ti.try_number = 0
                else:
                    task = dag.get_task(ti.task_id)
                    # max_tries - try_number is the number of retries the
                    # task instance has left, so the current try_number
                    # should be the maximum number of self-retries
                    # (task.retries) minus the number of tries it has left.
                    ti.try_number = max(0, task.retries - (ti.max_tries -
                                                           ti.try_number))
                ti.max_tries = -1
                session.merge(ti)
            session.commit()
        session.commit()
    op.drop_column('task_instance', 'max_tries')
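
The restoration arithmetic in the comment above is easy to sanity-check with concrete numbers (illustrative values, not taken from any real table):

# Illustrative check of: try_number = max(0, retries - (max_tries - try_number))
retries = 3      # task.retries configured on the task
max_tries = 5    # ti.max_tries stored by the upgrade
try_number = 4   # ti.try_number at downgrade time

# 5 - 4 = 1 self-retry left, so 3 - 1 = 2 tries were already consumed.
assert max(0, retries - (max_tries - try_number)) == 2
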
Example #3
    def test_trigger_dag_for_date(self):
        url_template = '/api/experimental/dags/{}/dag_runs'
        dag_id = 'example_bash_operator'
        hour_from_now = utcnow() + timedelta(hours=1)
        execution_date = datetime(hour_from_now.year, hour_from_now.month,
                                  hour_from_now.day, hour_from_now.hour)
        datetime_string = execution_date.isoformat()

        # Test Correct execution
        response = self.client.post(url_template.format(dag_id),
                                    data=json.dumps(
                                        {'execution_date': datetime_string}),
                                    content_type="application/json")
        self.assertEqual(200, response.status_code)

        dagbag = DagBag()
        dag = dagbag.get_dag(dag_id)
        dag_run = dag.get_dagrun(execution_date)
        self.assertTrue(
            dag_run,
            'Dag Run not found for execution date {}'.format(execution_date))

        # Test error for nonexistent dag
        response = self.client.post(
            url_template.format('does_not_exist_dag'),
            data=json.dumps({'execution_date': execution_date.isoformat()}),
            content_type="application/json")
        self.assertEqual(404, response.status_code)

        # Test error for bad datetime format
        response = self.client.post(url_template.format(dag_id),
                                    data=json.dumps(
                                        {'execution_date': 'not_a_datetime'}),
                                    content_type="application/json")
        self.assertEqual(400, response.status_code)
Example #4
    def test_subdag_deadlock(self):
        dagbag = DagBag()
        dag = dagbag.get_dag('test_subdag_deadlock')
        dag.clear()
        subdag = dagbag.get_dag('test_subdag_deadlock.subdag')
        subdag.clear()

        # first make sure subdag has failed
        self.assertRaises(AirflowException, subdag.run, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        # now make sure dag picks up the subdag error
        self.assertRaises(AirflowException, dag.run, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
Example #5
    def setUp(self):
        self.dagbag = DagBag(include_examples=True)
        self.dag1 = self.dagbag.dags['example_bash_operator']
        self.dag2 = self.dagbag.dags['example_subdag_operator']

        self.execution_dates = [days_ago(2), days_ago(1)]

        drs = _create_dagruns(self.dag1,
                              self.execution_dates,
                              state=State.RUNNING,
                              run_id_template="scheduled__{}")
        for dr in drs:
            dr.dag = self.dag1
            dr.verify_integrity()

        drs = _create_dagruns(self.dag2,
                              [self.dag2.default_args['start_date']],
                              state=State.RUNNING,
                              run_id_template="scheduled__{}")

        for dr in drs:
            dr.dag = self.dag2
            dr.verify_integrity()

        self.session = Session()
Example #6
    def heartbeat(self):
        """
        Override the scheduler heartbeat to determine when the test is complete
        """
        super(SchedulerMetricsJob, self).heartbeat()
        session = settings.Session()
        # Get all the relevant task instances
        TI = TaskInstance
        successful_tis = (session.query(TI)
                          .filter(TI.dag_id.in_(DAG_IDS))
                          .filter(TI.state.in_([State.SUCCESS]))
                          .all())
        session.commit()

        dagbag = DagBag(SUBDIR)
        dags = [dagbag.dags[dag_id] for dag_id in DAG_IDS]
        # The tasks in perf_dag_1 and perf_dag_2 have a daily schedule interval.
        num_task_instances = sum((timezone.utcnow() - task.start_date).days
                                 for dag in dags for task in dag.tasks)

        if (len(successful_tis) == num_task_instances
                or (timezone.utcnow() - self.start_date).total_seconds() >
                MAX_RUNTIME_SECS):
            if len(successful_tis) == num_task_instances:
                self.log.info("All tasks processed! Printing stats.")
            else:
                self.log.info(
                    "Test timeout reached. Printing available stats.")
            self.print_stats()
            set_dags_paused_state(True)
            sys.exit()
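
The expected-count arithmetic in heartbeat() follows from the daily schedule noted in the comment: each task yields one task instance per whole day since its start_date. A tiny illustrative check with made-up dates:

from datetime import datetime

now = datetime(2019, 1, 11, 12, 0)
task_start_dates = [datetime(2019, 1, 1), datetime(2019, 1, 6)]

# One task instance per elapsed day, summed over all tasks.
assert sum((now - sd).days for sd in task_start_dates) == 10 + 5
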
Example #7
def get_dag_runs(dag_id, state=None):
    """
    Returns a list of Dag Runs for a specific DAG ID.
    :param dag_id: String identifier of a DAG
    :param state: queued|running|success...
    :return: List of DAG runs of a DAG with requested state,
    or all runs if the state is not specified
    """
    dagbag = DagBag()

    # Check DAG exists.
    if dag_id not in dagbag.dags:
        error_message = "Dag id {} not found".format(dag_id)
        raise AirflowException(error_message)

    dag_runs = list()
    state = state.lower() if state else None
    for run in DagRun.find(dag_id=dag_id, state=state):
        dag_runs.append({
            'id': run.id,
            'run_id': run.run_id,
            'state': run.state,
            'dag_id': run.dag_id,
            'execution_date': run.execution_date.isoformat(),
            'start_date': ((run.start_date or '') and
                           run.start_date.isoformat()),
            'dag_run_url': url_for('Airflow.graph', dag_id=run.dag_id,
                                   execution_date=run.execution_date)
        })

    return dag_runs
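
For reference, a single element of the returned list looks roughly like this; all values are made up, and the exact dag_run_url depends on how url_for resolves the Airflow.graph endpoint:

# Illustrative shape of one get_dag_runs() entry; values are made up.
{
    'id': 42,
    'run_id': 'scheduled__2019-01-01T00:00:00+00:00',
    'state': 'success',
    'dag_id': 'example_bash_operator',
    'execution_date': '2019-01-01T00:00:00+00:00',
    'start_date': '2019-01-01T00:05:03+00:00',
    'dag_run_url': '/graph?dag_id=example_bash_operator&execution_date=...'
}
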
Example #8
def get_task(dag_id, task_id):
    """Return the task object identified by the given dag_id and task_id."""
    dagbag = DagBag()

    # Check DAG exists.
    if dag_id not in dagbag.dags:
        error_message = "Dag id {} not found".format(dag_id)
        raise DagNotFound(error_message)

    # Get DAG object and check Task Exists
    dag = dagbag.get_dag(dag_id)
    if not dag.has_task(task_id):
        error_message = 'Task {} not found in dag {}'.format(task_id, dag_id)
        raise TaskNotFound(error_message)

    # Return the task.
    return dag.get_task(task_id)
Example #9
    def setUp(self):
        self.dagbag = DagBag(include_examples=True)
        self.dag1 = self.dagbag.dags['example_bash_operator']
        self.dag2 = self.dagbag.dags['example_subdag_operator']

        self.execution_dates = [days_ago(2), days_ago(1), days_ago(0)]

        self.session = Session()
Example #10
def upgrade():
    op.add_column('task_instance', sa.Column('max_tries', sa.Integer, server_default="-1"))
    # Check that the task_instance table exists before the data migration.
    # This check is needed for databases that do not create tables until
    # the migration finishes; checking first prevents errors from querying
    # a non-existent task_instance table.
    connection = op.get_bind()
    inspector = Inspector.from_engine(connection)
    tables = inspector.get_table_names()

    if 'task_instance' in tables:
        # Get current session
        sessionmaker = sa.orm.sessionmaker()
        session = sessionmaker(bind=connection)
        dagbag = DagBag(settings.DAGS_FOLDER)
        query = session.query(sa.func.count(TaskInstance.max_tries)).filter(
            TaskInstance.max_tries == -1
        )
        # Query the database in batches to prevent loading the entire table
        # into memory and causing an out-of-memory error.
        while query.scalar():
            tis = session.query(TaskInstance).filter(
                TaskInstance.max_tries == -1
            ).limit(BATCH_SIZE).all()
            for ti in tis:
                dag = dagbag.get_dag(ti.dag_id)
                if not dag or not dag.has_task(ti.task_id):
                    # The task_instance table might be out of date: the DAG
                    # or task might have been modified or deleted from the
                    # dagbag while still appearing in the task_instance
                    # table. In this case we do not retry a task that can
                    # no longer be parsed.
                    ti.max_tries = ti.try_number
                else:
                    task = dag.get_task(ti.task_id)
                    if task.retries:
                        ti.max_tries = task.retries
                    else:
                        ti.max_tries = ti.try_number
                session.merge(ti)

            session.commit()
        # Commit the current session.
        session.commit()
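
The count-then-limit loop above is a general pattern for migrating large tables without loading them into memory at once. A minimal stand-alone sketch of the same pattern, using a hypothetical model Foo with a processed flag (not part of the migration above):

import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()


class Foo(Base):  # hypothetical model, used only for this sketch
    __tablename__ = 'foo'
    id = sa.Column(sa.Integer, primary_key=True)
    processed = sa.Column(sa.Boolean, default=False)


def process_in_batches(session, batch_size=1000):
    # Re-evaluate the COUNT on every pass; each committed batch shrinks it.
    pending = session.query(sa.func.count(Foo.id)).filter(
        Foo.processed.is_(False))
    while pending.scalar():
        for row in (session.query(Foo)
                    .filter(Foo.processed.is_(False))
                    .limit(batch_size)):
            row.processed = True  # the update removes the row from the filter
        session.commit()
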
Example #11
def get_dag_run_state(dag_id, execution_date):
    """Return the task object identified by the given dag_id and task_id."""

    dagbag = DagBag()

    # Check DAG exists.
    if dag_id not in dagbag.dags:
        error_message = "Dag id {} not found".format(dag_id)
        raise DagNotFound(error_message)

    # Get DAG object
    dag = dagbag.get_dag(dag_id)

    # Get DagRun object and check that it exists
    dagrun = dag.get_dagrun(execution_date=execution_date)
    if not dagrun:
        error_message = ('Dag Run for date {} not found in dag {}'.format(
            execution_date, dag_id))
        raise DagRunNotFound(error_message)

    return {'state': dagrun.get_state()}
Example #12
    def test_find_zombies(self):
        manager = DagFileProcessorManager(
            dag_directory='directory',
            file_paths=['abc.txt'],
            max_runs=1,
            processor_factory=MagicMock().return_value,
            signal_conn=MagicMock(),
            stat_queue=MagicMock(),
            result_queue=MagicMock(),
            async_mode=True)

        dagbag = DagBag(TEST_DAG_FOLDER)
        with create_session() as session:
            session.query(LJ).delete()
            dag = dagbag.get_dag('example_branch_operator')
            task = dag.get_task(task_id='run_this_first')

            ti = TI(task, DEFAULT_DATE, State.RUNNING)
            lj = LJ(ti)
            lj.state = State.SHUTDOWN
            lj.id = 1
            ti.job_id = lj.id

            session.add(lj)
            session.add(ti)
            session.commit()

            manager._last_zombie_query_time = timezone.utcnow() - timedelta(
                seconds=manager._zombie_threshold_secs + 1)
            zombies = manager._find_zombies()
            self.assertEqual(1, len(zombies))
            self.assertIsInstance(zombies[0], SimpleTaskInstance)
            self.assertEqual(ti.dag_id, zombies[0].dag_id)
            self.assertEqual(ti.task_id, zombies[0].task_id)
            self.assertEqual(ti.execution_date, zombies[0].execution_date)

            session.query(TI).delete()
            session.query(LJ).delete()
Example #13
    def setUp(self):
        self.dagbag = DagBag(
            dag_folder=TEST_DAG_FOLDER,
            include_examples=False,
        )
        logger.info('Loaded DAGS:')
        logger.info(self.dagbag.dagbag_report())

        try:
            subprocess.check_output(
                ['sudo', 'useradd', '-m', TEST_USER, '-g',
                 str(os.getegid())])
        except OSError as e:
            if e.errno == errno.ENOENT:
                raise unittest.SkipTest(
                    "The 'useradd' command did not exist so unable to test "
                    "impersonation; Skipping Test. These tests can only be "
                    "run on a linux host that supports 'useradd'.")
            raise
        except subprocess.CalledProcessError:
            raise unittest.SkipTest(
                "The 'useradd' command exited non-zero; Skipping tests. Does "
                "the current user have permission to run 'useradd' without a "
                "password prompt (check sudoers file)?")
Example #14
    def test_trigger_dag(self, mock):
        client = self.client
        test_dag_id = "example_bash_operator"
        DagBag(include_examples=True)

        # non existent
        with self.assertRaises(AirflowException):
            client.trigger_dag(dag_id="blablabla")

        with freeze_time(EXECDATE):
            # no execution date, execution date should be set automatically
            client.trigger_dag(dag_id=test_dag_id)
            mock.assert_called_once_with(run_id="manual__{0}".format(EXECDATE_ISO),
                                         execution_date=EXECDATE_NOFRACTIONS,
                                         state=State.RUNNING,
                                         conf=None,
                                         external_trigger=True)
            mock.reset_mock()

            # execution date with microseconds cutoff
            client.trigger_dag(dag_id=test_dag_id, execution_date=EXECDATE)
            mock.assert_called_once_with(run_id="manual__{0}".format(EXECDATE_ISO),
                                         execution_date=EXECDATE_NOFRACTIONS,
                                         state=State.RUNNING,
                                         conf=None,
                                         external_trigger=True)
            mock.reset_mock()

            # run id
            run_id = "my_run_id"
            client.trigger_dag(dag_id=test_dag_id, run_id=run_id)
            mock.assert_called_once_with(run_id=run_id,
                                         execution_date=EXECDATE_NOFRACTIONS,
                                         state=State.RUNNING,
                                         conf=None,
                                         external_trigger=True)
            mock.reset_mock()

            # test conf
            conf = '{"name": "John"}'
            client.trigger_dag(dag_id=test_dag_id, conf=conf)
            mock.assert_called_once_with(run_id="manual__{0}".format(EXECDATE_ISO),
                                         execution_date=EXECDATE_NOFRACTIONS,
                                         state=State.RUNNING,
                                         conf=json.loads(conf),
                                         external_trigger=True)
            mock.reset_mock()
Example #15
    def test_on_kill(self):
        """
        Test that ensures that clearing in the UI SIGTERMS
        the task
        """
        path = "/tmp/airflow_on_kill"
        try:
            os.unlink(path)
        except OSError:
            pass

        dagbag = DagBag(
            dag_folder=TEST_DAG_FOLDER,
            include_examples=False,
        )
        dag = dagbag.dags.get('test_on_kill')
        task = dag.get_task('task1')

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(run_id="test",
                          state=State.RUNNING,
                          execution_date=DEFAULT_DATE,
                          start_date=DEFAULT_DATE,
                          session=session)
        ti = TI(task=task, execution_date=DEFAULT_DATE)
        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)

        runner = StandardTaskRunner(job1)
        runner.start()

        # give the task some time to start up
        time.sleep(3)

        runner.terminate()

        with open(path, "r") as f:
            self.assertEqual("ON_KILL_TEST", f.readline())
Example #16
    def test_next_execution(self):
        # A scaffolding function
        def reset_dr_db(dag_id):
            session = Session()
            dr = session.query(models.DagRun).filter_by(dag_id=dag_id)
            dr.delete()
            session.commit()
            session.close()

        EXAMPLE_DAGS_FOLDER = os.path.join(
            os.path.dirname(
                os.path.dirname(
                    os.path.dirname(os.path.realpath(__file__))
                )
            ),
            "airflow/example_dags"
        )

        dagbag = DagBag(dag_folder=EXAMPLE_DAGS_FOLDER,
                        include_examples=False)
        dag_ids = ['example_bash_operator',  # schedule_interval is '0 0 * * *'
                   'latest_only',  # schedule_interval is timedelta(hours=4)
                   'example_python_operator',  # schedule_interval=None
                   'example_xcom']  # schedule_interval="@once"

        # The details below are determined by the schedule_interval
        # of the example DAGs
        now = timezone.utcnow()
        next_execution_time_for_dag1 = pytz.utc.localize(
            datetime.combine(
                now.date() + timedelta(days=1),
                time(0)
            )
        )
        next_execution_time_for_dag2 = now + timedelta(hours=4)
        expected_output = [str(next_execution_time_for_dag1),
                           str(next_execution_time_for_dag2),
                           "None",
                           "None"]

        for i, dag_id in enumerate(dag_ids):

            # Clear DAG runs so there is no execution history for each DAG
            reset_dr_db(dag_id)

            p = subprocess.Popen(["airflow", "next_execution", dag_id,
                                  "--subdir", EXAMPLE_DAGS_FOLDER],
                                 stdout=subprocess.PIPE)
            p.wait()
            stdout = []
            for line in p.stdout:
                stdout.append(str(line.decode("utf-8").rstrip()))

            # The `next_execution` command is inapplicable if no execution
            # record is found; it prints `None` in such cases
            self.assertEqual(stdout[-1], "None")

            dag = dagbag.dags[dag_id]
            # Create a DagRun for each DAG, to prepare for next step
            dag.create_dagrun(
                run_id='manual__' + now.isoformat(),
                execution_date=now,
                start_date=now,
                state=State.FAILED
            )

            p = subprocess.Popen(["airflow", "next_execution", dag_id,
                                  "--subdir", EXAMPLE_DAGS_FOLDER],
                                 stdout=subprocess.PIPE)
            p.wait()
            stdout = []
            for line in p.stdout:
                stdout.append(str(line.decode("utf-8").rstrip()))
            self.assertEqual(stdout[-1], expected_output[i])

            reset_dr_db(dag_id)
Example #17
class ImpersonationTest(unittest.TestCase):
    def setUp(self):
        self.dagbag = DagBag(
            dag_folder=TEST_DAG_FOLDER,
            include_examples=False,
        )
        logger.info('Loaded DAGS:')
        logger.info(self.dagbag.dagbag_report())

        try:
            subprocess.check_output(
                ['sudo', 'useradd', '-m', TEST_USER, '-g',
                 str(os.getegid())])
        except OSError as e:
            if e.errno == errno.ENOENT:
                raise unittest.SkipTest(
                    "The 'useradd' command did not exist so unable to test "
                    "impersonation; Skipping Test. These tests can only be "
                    "run on a linux host that supports 'useradd'.")
            raise
        except subprocess.CalledProcessError:
            raise unittest.SkipTest(
                "The 'useradd' command exited non-zero; Skipping tests. Does "
                "the current user have permission to run 'useradd' without a "
                "password prompt (check sudoers file)?")

    def tearDown(self):
        subprocess.check_output(['sudo', 'userdel', '-r', TEST_USER])

    def run_backfill(self, dag_id, task_id):
        dag = self.dagbag.get_dag(dag_id)
        dag.clear()

        jobs.BackfillJob(dag=dag,
                         start_date=DEFAULT_DATE,
                         end_date=DEFAULT_DATE).run()

        ti = models.TaskInstance(task=dag.get_task(task_id),
                                 execution_date=DEFAULT_DATE)
        ti.refresh_from_db()

        self.assertEqual(ti.state, State.SUCCESS)

    def test_impersonation(self):
        """
        Tests that impersonating a unix user works
        """
        self.run_backfill('test_impersonation', 'test_impersonated_user')

    def test_no_impersonation(self):
        """
        If default_impersonation=None, tests that the job is run
        as the current user (which will be a sudoer)
        """
        self.run_backfill(
            'test_no_impersonation',
            'test_superuser',
        )

    def test_default_impersonation(self):
        """
        If default_impersonation=TEST_USER, tests that the job defaults
        to running as TEST_USER for a test without run_as_user set
        """
        os.environ['AIRFLOW__CORE__DEFAULT_IMPERSONATION'] = TEST_USER

        try:
            self.run_backfill('test_default_impersonation',
                              'test_deelevated_user')
        finally:
            del os.environ['AIRFLOW__CORE__DEFAULT_IMPERSONATION']

    def test_impersonation_custom(self):
        """
        Tests that impersonation using a unix user works with custom packages in
        PYTHONPATH
        """
        # PYTHONPATH is already set in script triggering tests
        assert 'PYTHONPATH' in os.environ

        self.run_backfill('impersonation_with_custom_pkg', 'exec_python_fn')

    def test_impersonation_subdag(self):
        """
        Tests that impersonation using a subdag correctly passes the right configuration
        :return:
        """
        self.run_backfill('impersonation_subdag', 'test_subdag_operation')
Example #18
def initdb(rbac=False):
    session = settings.Session()

    from airflow import models
    upgradedb()

    merge_conn(
        models.Connection(
            conn_id='airflow_db', conn_type='mysql',
            host='mysql', login='******', password='',
            schema='airflow'))
    merge_conn(
        models.Connection(
            conn_id='beeline_default', conn_type='beeline', port="10000",
            host='localhost', extra="{\"use_beeline\": true, \"auth\": \"\"}",
            schema='default'))
    merge_conn(
        models.Connection(
            conn_id='bigquery_default', conn_type='google_cloud_platform',
            schema='default'))
    merge_conn(
        models.Connection(
            conn_id='local_mysql', conn_type='mysql',
            host='localhost', login='******', password='******',
            schema='airflow'))
    merge_conn(
        models.Connection(
            conn_id='presto_default', conn_type='presto',
            host='localhost',
            schema='hive', port=3400))
    merge_conn(
        models.Connection(
            conn_id='google_cloud_default', conn_type='google_cloud_platform',
            schema='default',))
    merge_conn(
        models.Connection(
            conn_id='hive_cli_default', conn_type='hive_cli',
            schema='default',))
    merge_conn(
        models.Connection(
            conn_id='hiveserver2_default', conn_type='hiveserver2',
            host='localhost',
            schema='default', port=10000))
    merge_conn(
        models.Connection(
            conn_id='metastore_default', conn_type='hive_metastore',
            host='localhost', extra="{\"authMechanism\": \"PLAIN\"}",
            port=9083))
    merge_conn(
        models.Connection(
            conn_id='mongo_default', conn_type='mongo',
            host='mongo', port=27017))
    merge_conn(
        models.Connection(
            conn_id='mysql_default', conn_type='mysql',
            login='******',
            schema='airflow',
            host='mysql'))
    merge_conn(
        models.Connection(
            conn_id='postgres_default', conn_type='postgres',
            login='******',
            password='******',
            schema='airflow',
            host='postgres'))
    merge_conn(
        models.Connection(
            conn_id='sqlite_default', conn_type='sqlite',
            host='/tmp/sqlite_default.db'))
    merge_conn(
        models.Connection(
            conn_id='http_default', conn_type='http',
            host='https://www.google.com/'))
    merge_conn(
        models.Connection(
            conn_id='mssql_default', conn_type='mssql',
            host='localhost', port=1433))
    merge_conn(
        models.Connection(
            conn_id='vertica_default', conn_type='vertica',
            host='localhost', port=5433))
    merge_conn(
        models.Connection(
            conn_id='wasb_default', conn_type='wasb',
            extra='{"sas_token": null}'))
    merge_conn(
        models.Connection(
            conn_id='webhdfs_default', conn_type='hdfs',
            host='localhost', port=50070))
    merge_conn(
        models.Connection(
            conn_id='ssh_default', conn_type='ssh',
            host='localhost'))
    merge_conn(
        models.Connection(
            conn_id='sftp_default', conn_type='sftp',
            host='localhost', port=22, login='******',
            extra='''
                {"key_file": "~/.ssh/id_rsa", "no_host_key_check": true}
            '''))
    merge_conn(
        models.Connection(
            conn_id='fs_default', conn_type='fs',
            extra='{"path": "/"}'))
    merge_conn(
        models.Connection(
            conn_id='aws_default', conn_type='aws',
            extra='{"region_name": "us-east-1"}'))
    merge_conn(
        models.Connection(
            conn_id='spark_default', conn_type='spark',
            host='yarn', extra='{"queue": "root.default"}'))
    merge_conn(
        models.Connection(
            conn_id='druid_broker_default', conn_type='druid',
            host='druid-broker', port=8082, extra='{"endpoint": "druid/v2/sql"}'))
    merge_conn(
        models.Connection(
            conn_id='druid_ingest_default', conn_type='druid',
            host='druid-overlord', port=8081, extra='{"endpoint": "druid/indexer/v1/task"}'))
    merge_conn(
        models.Connection(
            conn_id='redis_default', conn_type='redis',
            host='redis', port=6379,
            extra='{"db": 0}'))
    merge_conn(
        models.Connection(
            conn_id='sqoop_default', conn_type='sqoop',
            host='rmdbs', extra=''))
    merge_conn(
        models.Connection(
            conn_id='emr_default', conn_type='emr',
            extra='''
                {   "Name": "default_job_flow_name",
                    "LogUri": "s3://my-emr-log-bucket/default_job_flow_location",
                    "ReleaseLabel": "emr-4.6.0",
                    "Instances": {
                        "Ec2KeyName": "mykey",
                        "Ec2SubnetId": "somesubnet",
                        "InstanceGroups": [
                            {
                                "Name": "Master nodes",
                                "Market": "ON_DEMAND",
                                "InstanceRole": "MASTER",
                                "InstanceType": "r3.2xlarge",
                                "InstanceCount": 1
                            },
                            {
                                "Name": "Slave nodes",
                                "Market": "ON_DEMAND",
                                "InstanceRole": "CORE",
                                "InstanceType": "r3.2xlarge",
                                "InstanceCount": 1
                            }
                        ],
                        "TerminationProtected": false,
                        "KeepJobFlowAliveWhenNoSteps": false
                    },
                    "Applications":[
                        { "Name": "Spark" }
                    ],
                    "VisibleToAllUsers": true,
                    "JobFlowRole": "EMR_EC2_DefaultRole",
                    "ServiceRole": "EMR_DefaultRole",
                    "Tags": [
                        {
                            "Key": "app",
                            "Value": "analytics"
                        },
                        {
                            "Key": "environment",
                            "Value": "development"
                        }
                    ]
                }
            '''))
    merge_conn(
        models.Connection(
            conn_id='databricks_default', conn_type='databricks',
            host='localhost'))
    merge_conn(
        models.Connection(
            conn_id='qubole_default', conn_type='qubole',
            host='localhost'))
    merge_conn(
        models.Connection(
            conn_id='segment_default', conn_type='segment',
            extra='{"write_key": "my-segment-write-key"}'))
    merge_conn(
        models.Connection(
            conn_id='azure_data_lake_default', conn_type='azure_data_lake',
            extra='{"tenant": "<TENANT>", "account_name": "<ACCOUNTNAME>" }'))
    merge_conn(
        models.Connection(
            conn_id='azure_cosmos_default', conn_type='azure_cosmos',
            extra='{"database_name": "<DATABASE_NAME>", "collection_name": "<COLLECTION_NAME>" }'))
    merge_conn(
        models.Connection(
            conn_id='cassandra_default', conn_type='cassandra',
            host='cassandra', port=9042))

    # Known event types
    KET = models.KnownEventType
    if not session.query(KET).filter(KET.know_event_type == 'Holiday').first():
        session.add(KET(know_event_type='Holiday'))
    if not session.query(KET).filter(KET.know_event_type == 'Outage').first():
        session.add(KET(know_event_type='Outage'))
    if not session.query(KET).filter(
            KET.know_event_type == 'Natural Disaster').first():
        session.add(KET(know_event_type='Natural Disaster'))
    if not session.query(KET).filter(
            KET.know_event_type == 'Marketing Campaign').first():
        session.add(KET(know_event_type='Marketing Campaign'))
    session.commit()

    dagbag = DagBag()
    # Save individual DAGs in the ORM
    for dag in dagbag.dags.values():
        dag.sync_to_db()
    # Deactivate the unknown ones
    models.DAG.deactivate_unknown_dags(dagbag.dags.keys())

    Chart = models.Chart
    chart_label = "Airflow task instance by type"
    chart = session.query(Chart).filter(Chart.label == chart_label).first()
    if not chart:
        chart = Chart(
            label=chart_label,
            conn_id='airflow_db',
            chart_type='bar',
            x_is_date=False,
            sql=(
                "SELECT state, COUNT(1) as number "
                "FROM task_instance "
                "WHERE dag_id LIKE 'example%' "
                "GROUP BY state"),
        )
        session.add(chart)
        session.commit()

    if rbac:
        from flask_appbuilder.security.sqla import models
        from flask_appbuilder.models.sqla import Base
        Base.metadata.create_all(settings.engine)
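
merge_conn itself is not shown in this listing; in Airflow it only inserts the connection when no row with that conn_id exists yet, which keeps initdb idempotent. A rough sketch of that behavior (not the exact implementation):

# Rough sketch of merge_conn's insert-if-absent behavior; not verbatim.
def merge_conn(conn, session):
    existing = (session.query(models.Connection)
                .filter(models.Connection.conn_id == conn.conn_id)
                .first())
    if not existing:
        session.add(conn)
        session.commit()
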
Example #19
    def setUp(self):
        self.session = settings.Session()
        self.dagbag = DagBag(include_examples=True)
        self.dag_id = 'example_bash_operator'
        self.dag = self.dagbag.dags[self.dag_id]
Example #20
    def setUp(self):
        configuration.load_test_config()
        self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True)
        self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
        self.dag = DAG(TEST_DAG_ID, default_args=self.args)
Example #21
    def setUp(self):
        self.dagbag = DagBag(include_examples=True)
        self.cluster = LocalCluster()
Example #22
    def setUp(self):
        self.dagbag = DagBag(include_examples=True)
Example #23
    def _run_dag(self):
        dag_bag = DagBag(dag_folder=TESTS_DAG_FOLDER, include_examples=False)
        self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
        dag = dag_bag.get_dag(self.dag_id)
        dag.clear(reset_dag_runs=True)
        dag.run(ignore_first_depends_on_past=True, verbose=True)