Example #1
class TestTriggerDag(unittest.TestCase):

    def setUp(self):
        conf.load_test_config()
        app = application.create_app(testing=True)
        app.config['WTF_CSRF_METHODS'] = []
        self.app = app.test_client()
        self.session = Session()
        models.DagBag().get_dag("example_bash_operator").sync_to_db()

    def test_trigger_dag_button_normal_exist(self):
        resp = self.app.get('/', follow_redirects=True)
        self.assertIn('/trigger?dag_id=example_bash_operator', resp.data.decode('utf-8'))
        self.assertIn("return confirmDeleteDag('example_bash_operator')", resp.data.decode('utf-8'))

    def test_trigger_dag_button(self):

        test_dag_id = "example_bash_operator"

        DR = models.DagRun
        self.session.query(DR).delete()
        self.session.commit()

        self.app.get('/admin/airflow/trigger?dag_id={}'.format(test_dag_id))

        run = self.session.query(DR).filter(DR.dag_id == test_dag_id).first()
        self.assertIsNotNone(run)
        self.assertIn("manual__", run.run_id)

    def test_branch_list_without_dag_run(self):
        """This checks if the BranchPythonOperator supports branching off to a list of tasks."""
        self.branch_op = BranchPythonOperator(task_id='make_choice',
                                              dag=self.dag,
                                              python_callable=lambda: ['branch_1', 'branch_2'])
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag)
        self.branch_3.set_upstream(self.branch_op)
        self.dag.clear()

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        session = Session()
        tis = session.query(TI).filter(
            TI.dag_id == self.dag.dag_id,
            TI.execution_date == DEFAULT_DATE
        )
        session.close()

        expected = {
            "make_choice": State.SUCCESS,
            "branch_1": State.NONE,
            "branch_2": State.NONE,
            "branch_3": State.SKIPPED,
        }

        for ti in tis:
            if ti.task_id in expected:
                self.assertEqual(ti.state, expected[ti.task_id])
            else:
                self.fail('unexpected task instance: {}'.format(ti.task_id))

    def test_without_dag_run(self):
        """This checks the defensive behavior against non-existent tasks in a dag run"""
        self.branch_op = BranchPythonOperator(task_id='make_choice',
                                              dag=self.dag,
                                              python_callable=lambda: 'branch_1')
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.dag.clear()

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        session = Session()
        tis = session.query(TI).filter(
            TI.dag_id == self.dag.dag_id,
            TI.execution_date == DEFAULT_DATE
        )
        session.close()

        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_1':
                # should exist with state None
                self.assertEqual(ti.state, State.NONE)
            elif ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.SKIPPED)
            else:
                self.fail('unexpected task instance: {}'.format(ti.task_id))
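The same Session/TaskInstance query recurs throughout these examples. A small helper, sketched here under the assumption that the `Session` and `TI` names imported above are in scope, turns that pattern into a task_id-to-state map and makes the branch assertions table-driven:

def get_task_states(dag_id, execution_date):
    """Fetch the state of every task instance for one dag run (sketch)."""
    session = Session()
    try:
        tis = session.query(TI).filter(
            TI.dag_id == dag_id,
            TI.execution_date == execution_date).all()
        return {ti.task_id: ti.state for ti in tis}
    finally:
        session.close()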
Example #4
    def test_import_variables(self):
        content = ('{"str_key": "str_value", "int_key": 60,'
                   '"list_key": [1, 2], "dict_key": {"k_a": 2, "k_b": 3}}')
        try:
            # python 3+
            bytes_content = io.BytesIO(bytes(content, encoding='utf-8'))
        except TypeError:
            # python 2.7
            bytes_content = io.BytesIO(bytes(content))
        response = self.app.post(
            self.IMPORT_ENDPOINT,
            data={'file': (bytes_content, 'test.json')},
            follow_redirects=True
        )
        self.assertEqual(response.status_code, 200)
        session = Session()
        # Extract values from Variable
        db_dict = {x.key: x.get_val() for x in session.query(models.Variable).all()}
        session.close()
        self.assertIn('str_key', db_dict)
        self.assertIn('int_key', db_dict)
        self.assertIn('list_key', db_dict)
        self.assertIn('dict_key', db_dict)
        self.assertEqual('str_value', db_dict['str_key'])
        self.assertEqual('60', db_dict['int_key'])
        self.assertEqual('[1, 2]', db_dict['list_key'])

        case_a_dict = '{"k_a": 2, "k_b": 3}'
        case_b_dict = '{"k_b": 3, "k_a": 2}'
        try:
            self.assertEqual(case_a_dict, db_dict['dict_key'])
        except AssertionError:
            self.assertEqual(case_b_dict, db_dict['dict_key'])
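The try/except above exists because the key order of the serialized dict is not guaranteed: CPython only preserves dict insertion order as of 3.6, and the language only guarantees it from 3.7. An order-insensitive sketch, using a hypothetical helper instead of comparing raw strings:

import json

def assert_json_equal(test, expected_obj, actual_str):
    """Compare a JSON string against an expected object, ignoring key order (sketch)."""
    test.assertEqual(expected_obj, json.loads(actual_str))

# inside the test above, this would replace the two-case check:
#   assert_json_equal(self, {"k_a": 2, "k_b": 3}, db_dict['dict_key'])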
Example #5
class TestVariableView(unittest.TestCase):

    CREATE_ENDPOINT = '/admin/variable/new/?url=/admin/variable/'

    @classmethod
    def setUpClass(cls):
        super(TestVariableView, cls).setUpClass()
        session = Session()
        session.query(models.Variable).delete()
        session.commit()
        session.close()

    def setUp(self):
        super(TestVariableView, self).setUp()
        configuration.load_test_config()
        app = application.create_app(testing=True)
        app.config['WTF_CSRF_METHODS'] = []
        self.app = app.test_client()
        self.session = Session()
        self.variable = {
            'key': 'test_key',
            'val': 'text_val',
            'is_encrypted': True
        }

    def tearDown(self):
        self.session.query(models.Variable).delete()
        self.session.commit()
        self.session.close()
        super(TestVariableView, self).tearDown()

    def test_can_handle_error_on_decrypt(self):
        # create valid variable
        response = self.app.post(
            self.CREATE_ENDPOINT,
            data=self.variable,
            follow_redirects=True,
        )
        self.assertEqual(response.status_code, 200)

        # update the variable with a wrong value, given that is encrypted
        Var = models.Variable
        (self.session.query(Var)
            .filter(Var.key == self.variable['key'])
            .update({
                'val': 'failed_value_not_encrypted'
            }, synchronize_session=False))
        self.session.commit()

        # retrieve Variables page, should not fail and contain the Invalid
        # label for the variable
        response = self.app.get('/admin/variable', follow_redirects=True)
        self.assertEqual(response.status_code, 200)
        self.assertEqual(self.session.query(models.Variable).count(), 1)
        self.assertIn('<span class="label label-danger">Invalid</span>',
                      response.data.decode('utf-8'))
Example #6
    def tearDown(self):
        super(BranchOperatorTest, self).tearDown()

        session = Session()

        session.query(DagRun).delete()
        session.query(TI).delete()
        print(len(session.query(DagRun).all()))
        session.commit()
        session.close()
Example #7
 def tearDown(self):
     session = Session()
     session.query(DagRun).delete()
     session.query(TaskInstance).delete()
     session.commit()
     session.close()
     super(TestApiExperimental, self).tearDown()
Example #8
 def setUpClass(cls):
     super(TestApiExperimental, cls).setUpClass()
     session = Session()
     session.query(DagRun).delete()
     session.query(TaskInstance).delete()
     session.commit()
     session.close()
Example #9
def clear_session():
    """Manage airflow database state for tests"""
    session = Session()
    session.query(DagRun).delete()
    session.query(TI).delete()
    session.commit()
    session.close()
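clear_session gives each test an empty DagRun/TaskInstance table. A minimal sketch of applying it automatically, assuming pytest and the same imports used above:

import pytest

@pytest.fixture(autouse=True)
def clean_airflow_db():
    """Run every test in the module between two clear_session() calls."""
    clear_session()
    yield
    clear_session()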
Example #10
    def test_without_dag_run(self):
        """This checks the defensive against non existent tasks in a dag run"""
        value = False
        dag = DAG('shortcircuit_operator_test_without_dag_run',
                  default_args={
                       'owner': 'airflow',
                       'start_date': DEFAULT_DATE
                  },
                  schedule_interval=INTERVAL)
        short_op = ShortCircuitOperator(task_id='make_choice',
                                        dag=dag,
                                        python_callable=lambda: value)
        branch_1 = DummyOperator(task_id='branch_1', dag=dag)
        branch_1.set_upstream(short_op)
        branch_2 = DummyOperator(task_id='branch_2', dag=dag)
        branch_2.set_upstream(branch_1)
        upstream = DummyOperator(task_id='upstream', dag=dag)
        upstream.set_downstream(short_op)
        dag.clear()

        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        session = Session()
        tis = session.query(TI).filter(
            TI.dag_id == dag.dag_id,
            TI.execution_date == DEFAULT_DATE
        )

        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'upstream':
                # should not exist
                self.fail('task instance for `upstream` should not exist')
            elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.SKIPPED)
            else:
                self.fail('unexpected task instance: {}'.format(ti.task_id))

        value = True
        dag.clear()

        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'upstream':
                # should not exist
                self.fail('task instance for `upstream` should not exist')
            elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.NONE)
            else:
                self.fail('unexpected task instance: {}'.format(ti.task_id))

        session.close()
Example #11
def test_single_partition_with_templates():
    """Run the entire DAG, instead of just the operator. This is necessary to
    instantiate the templating functionality, which requires context from the
    scheduler."""

    bucket = "test"
    prefix = "dataset/v1/submission_date=20190101"

    client = boto3.client("s3")
    client.create_bucket(Bucket=bucket)
    client.put_object(Bucket=bucket, Body="", Key=prefix + "/part=1/_SUCCESS")
    client.put_object(Bucket=bucket, Body="", Key=prefix + "/part=2/garbage")

    dag = DAG("test_dag", default_args={"owner": "airflow", "start_date": DEFAULT_DATE})

    sensor_success = S3FSCheckSuccessSensor(
        task_id="test_success_template",
        bucket=bucket,
        prefix="dataset/v1/submission_date={{ ds_nodash }}/part=1",
        num_partitions=1,
        poke_interval=1,
        timeout=2,
        dag=dag,
    )
    sensor_failure = S3FSCheckSuccessSensor(
        task_id="test_failure_template",
        bucket=bucket,
        prefix="dataset/v1/submission_date={{ ds_nodash }}/part=2",
        num_partitions=1,
        poke_interval=1,
        timeout=2,
        dag=dag,
    )

    # execute everything for templating to work
    sensor_success.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
    with pytest.raises(AirflowSensorTimeout):
        sensor_failure.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    session = Session()
    tis = session.query(TaskInstance).filter(
        TaskInstance.dag_id == dag.dag_id, TaskInstance.execution_date == DEFAULT_DATE
    )
    session.close()

    count = 0
    for ti in tis:
        if ti.task_id == "test_success_template":
            assert ti.state == State.SUCCESS
        elif ti.task_id == "test_failure_template":
            assert ti.state == State.FAILED
        else:
            assert False, "unexpected task_id: {}".format(ti.task_id)
        count += 1
    assert count == 2
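The boto3 calls in this test create a bucket literally named "test", which only succeeds against a mocked S3 endpoint; in pytest suites that is commonly arranged with moto. A minimal sketch under that assumption (moto releases before 5.0 expose mock_s3; newer ones use mock_aws):

import pytest
from moto import mock_s3  # assumption: moto < 5.0

@pytest.fixture(autouse=True)
def s3_backend():
    """Route every boto3 S3 call inside a test to an in-memory store."""
    with mock_s3():
        yield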
Example #12
 def tearDown(self):
     session = Session()
     session.query(models.TaskInstance).filter_by(
         dag_id=TEST_DAG_ID).delete()
     session.query(TaskFail).filter_by(
         dag_id=TEST_DAG_ID).delete()
     session.commit()
     session.close()
Example #13
 def test_charts(self):
     session = Session()
     chart_label = "Airflow task instance by type"
     chart = session.query(models.Chart).filter(models.Chart.label == chart_label).first()
     chart_id = chart.id
     session.close()
     response = self.app.get("/admin/airflow/chart" "?chart_id={}&iteration_no=1".format(chart_id))
     assert "Airflow task instance by type" in response.data.decode("utf-8")
     response = self.app.get("/admin/airflow/chart_data" "?chart_id={}&iteration_no=1".format(chart_id))
     assert "example" in response.data.decode("utf-8")
     response = self.app.get("/admin/airflow/dag_details?dag_id=example_branch_operator")
     assert "run_this_first" in response.data.decode("utf-8")
Example #14
    def tearDown(self):
        super(ShortCircuitOperatorTest, self).tearDown()

        session = Session()

        session.query(DagRun).delete()
        session.query(TI).delete()
        session.commit()
        session.close()
Example #15
    def setUpClass(cls):
        super(PythonOperatorTest, cls).setUpClass()

        session = Session()

        session.query(DagRun).delete()
        session.query(TI).delete()
        session.commit()
        session.close()
Example #16
    def tearDown(self):
        super(PythonOperatorTest, self).tearDown()

        session = Session()

        session.query(DagRun).delete()
        session.query(TI).delete()
        print(len(session.query(DagRun).all()))
        session.commit()
        session.close()

        for var in TI_CONTEXT_ENV_VARS:
            if var in os.environ:
                del os.environ[var]
Example #17
 def test_charts(self):
     session = Session()
     chart_label = "Airflow task instance by type"
     chart = session.query(
         models.Chart).filter(models.Chart.label == chart_label).first()
     chart_id = chart.id
     session.close()
     response = self.app.get(
         '/admin/airflow/chart'
         '?chart_id={}&iteration_no=1'.format(chart_id))
     assert "Airflow task instance by type" in response.data.decode('utf-8')
     response = self.app.get(
         '/admin/airflow/chart_data'
         '?chart_id={}&iteration_no=1'.format(chart_id))
     assert "example" in response.data.decode('utf-8')
Example #18
 def setUp(self):
     conf.load_test_config()
     app = application.create_app(testing=True)
     app.config['WTF_CSRF_METHODS'] = []
     self.app = app.test_client()
     self.session = Session()
     models.DagBag().get_dag("example_bash_operator").sync_to_db()
Example #19
    def setUp(self):
        super(TestLogView, self).setUp()

        # Create a custom logging configuration
        configuration.load_test_config()
        logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG)
        current_dir = os.path.dirname(os.path.abspath(__file__))
        logging_config['handlers']['task']['base_log_folder'] = os.path.normpath(
            os.path.join(current_dir, 'test_logs'))
        logging_config['handlers']['task']['filename_template'] = \
            '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts | replace(":", ".") }}/{{ try_number }}.log'

        # Write the custom logging configuration to a file
        self.settings_folder = tempfile.mkdtemp()
        settings_file = os.path.join(self.settings_folder, "airflow_local_settings.py")
        new_logging_file = "LOGGING_CONFIG = {}".format(logging_config)
        with open(settings_file, 'w') as handle:
            handle.write(new_logging_file)
        sys.path.append(self.settings_folder)
        conf.set('core', 'logging_config_class', 'airflow_local_settings.LOGGING_CONFIG')

        app = application.create_app(testing=True)
        self.app = app.test_client()
        self.session = Session()
        from airflow.www.views import dagbag
        dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
        task = DummyOperator(task_id=self.TASK_ID, dag=dag)
        dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
        ti = TaskInstance(task=task, execution_date=self.DEFAULT_DATE)
        ti.try_number = 1
        self.session.merge(ti)
        self.session.commit()
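The scrape cuts off before this class's tearDown. A sketch of the matching cleanup, assumed rather than taken from the original, that undoes the sys.path and config changes so later tests fall back to the default logging configuration:

    def tearDown(self):
        self.session.query(TaskInstance).delete()
        self.session.commit()
        self.session.close()
        # undo the setUp changes (assumed cleanup; `shutil` must be imported)
        sys.path.remove(self.settings_folder)
        shutil.rmtree(self.settings_folder)
        conf.set('core', 'logging_config_class', '')
        super(TestLogView, self).tearDown()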
Example #20
 def setUp(self):
     conf.load_test_config()
     self.app, self.appbuilder = application.create_app(testing=True)
     self.app.config['WTF_CSRF_ENABLED'] = False
     self.client = self.app.test_client()
     self.session = Session()
     self.login()
Example #21
    def setUp(self):
        self.dagbag = models.DagBag(include_examples=True)
        self.dag1 = self.dagbag.dags['example_bash_operator']
        self.dag2 = self.dagbag.dags['example_subdag_operator']

        self.execution_dates = [days_ago(2), days_ago(1), days_ago(0)]

        self.session = Session()
Example #22
 def tearDown(self):
     configuration.test_mode()
     session = Session()
     session.query(models.User).delete()
     session.commit()
     session.close()
     configuration.conf.set("webserver", "authenticate", "False")
Example #23
class TestChartModelView(unittest.TestCase):

    CREATE_ENDPOINT = '/admin/chart/new/?url=/admin/chart/'

    @classmethod
    def setUpClass(cls):
        super(TestChartModelView, cls).setUpClass()
        session = Session()
        session.query(models.Chart).delete()
        session.query(models.User).delete()
        session.commit()
        user = models.User(username='******')
        session.add(user)
        session.commit()
        session.close()

    def setUp(self):
        super(TestChartModelView, self).setUp()
        configuration.load_test_config()
        app = application.create_app(testing=True)
        app.config['WTF_CSRF_METHODS'] = []
        self.app = app.test_client()
        self.session = Session()
        self.chart = {
            'label': 'chart',
            'owner': 'airflow',
            'conn_id': 'airflow_ci',
        }

    def tearDown(self):
        self.session.query(models.Chart).delete()
        self.session.commit()
        self.session.close()
        super(TestChartModelView, self).tearDown()

    @classmethod
    def tearDownClass(cls):
        session = Session()
        session.query(models.User).delete()
        session.commit()
        session.close()
        super(TestChartModelView, cls).tearDownClass()

    def test_create_chart(self):
        response = self.app.post(
            self.CREATE_ENDPOINT,
            data=self.chart,
            follow_redirects=True,
        )
        self.assertEqual(response.status_code, 200)
        self.assertEqual(self.session.query(models.Chart).count(), 1)

    def test_get_chart(self):
        response = self.app.get(
            '/admin/chart?sort=3',
            follow_redirects=True,
        )
        print(response.data)
        self.assertEqual(response.status_code, 200)
        self.assertIn('Sort by Owner', response.data.decode('utf-8'))
Example #24
 def setUpClass(cls):
     super(TestLogView, cls).setUpClass()
     session = Session()
     # pass separate criteria (SQL AND); Python `and` would drop conditions
     session.query(TaskInstance).filter(
         TaskInstance.dag_id == cls.DAG_ID,
         TaskInstance.task_id == cls.TASK_ID,
         TaskInstance.execution_date == cls.DEFAULT_DATE).delete()
     session.commit()
     session.close()
Example #25
 def setUp(self):
     super(TestChartModelView, self).setUp()
     configuration.load_test_config()
     app = application.create_app(testing=True)
     app.config['WTF_CSRF_METHODS'] = []
     self.app = app.test_client()
     self.session = Session()
     self.chart = {
         'label': 'chart',
         'owner': 'airflow',
         'conn_id': 'airflow_ci',
     }
Example #26
    def test_without_dag_run(self):
        """This checks the defensive against non existent tasks in a dag run"""
        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        session = Session()
        tis = session.query(TI).filter(
            TI.dag_id == self.dag.dag_id,
            TI.execution_date == DEFAULT_DATE
        )
        session.close()

        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'branch_1':
                # should exist with state None
                self.assertEqual(ti.state, State.NONE)
            elif ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.SKIPPED)
            else:
                self.fail('unexpected task instance: {}'.format(ti.task_id))
Example #27
 def setUp(self):
     super(TestVariableView, self).setUp()
     configuration.load_test_config()
     app = application.create_app(testing=True)
     app.config['WTF_CSRF_METHODS'] = []
     self.app = app.test_client()
     self.session = Session()
     self.variable = {
         'key': 'test_key',
         'val': 'text_val',
         'is_encrypted': True
     }
Example #28
 def setUp(self):
     super(TestPoolModelView, self).setUp()
     configuration.load_test_config()
     app = application.create_app(testing=True)
     app.config['WTF_CSRF_METHODS'] = []
     self.app = app.test_client()
     self.session = Session()
     self.pool = {
         'pool': 'test-pool',
         'slots': 777,
         'description': 'test-pool-description',
     }
Example #29
    def test_import_variable_fail(self):
        with mock.patch('airflow.models.Variable.set') as set_mock:
            set_mock.side_effect = UnicodeEncodeError
            content = '{"fail_key": "fail_val"}'

            try:
                # python 3+
                bytes_content = io.BytesIO(bytes(content, encoding='utf-8'))
            except TypeError:
                # python 2.7
                bytes_content = io.BytesIO(bytes(content))
            response = self.app.post(
                self.IMPORT_ENDPOINT,
                data={'file': (bytes_content, 'test.json')},
                follow_redirects=True
            )
            self.assertEqual(response.status_code, 200)
            session = Session()
            db_dict = {x.key: x.get_val() for x in session.query(models.Variable).all()}
            session.close()
            self.assertNotIn('fail_key', db_dict)
Example #30
    def test_delete_dag_button_for_dag_on_scheduler_only(self):
        # Test for JIRA AIRFLOW-3233 (PR 4069):
        # The delete-dag URL should be generated correctly for DAGs
        # that exist on the scheduler (DB) but not the webserver DagBag

        test_dag_id = "non_existent_dag"

        session = Session()
        DM = models.DagModel
        session.query(DM).filter(DM.dag_id == 'example_bash_operator').update({'dag_id': test_dag_id})
        session.commit()

        resp = self.app.get('/', follow_redirects=True)
        self.assertIn('/delete?dag_id={}'.format(test_dag_id), resp.data.decode('utf-8'))
        self.assertIn("return confirmDeleteDag('{}')".format(test_dag_id), resp.data.decode('utf-8'))

        session.query(DM).filter(DM.dag_id == test_dag_id).update({'dag_id': 'example_bash_operator'})
        session.commit()
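One caveat in the test above: the rename-back at the end only runs when the assertions pass. A sketch of a safer variant of the same body (same `session`, `DM`, and `test_dag_id` as above) that restores the shared example DAG row even on failure:

        try:
            session.query(DM).filter(DM.dag_id == 'example_bash_operator') \
                .update({'dag_id': test_dag_id})
            session.commit()
            resp = self.app.get('/', follow_redirects=True)
            self.assertIn('/delete?dag_id={}'.format(test_dag_id),
                          resp.data.decode('utf-8'))
        finally:
            # always rename back so later tests still find the example DAG
            session.query(DM).filter(DM.dag_id == test_dag_id) \
                .update({'dag_id': 'example_bash_operator'})
            session.commit()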
Example #31
class TestMarkTasks(unittest.TestCase):
    def setUp(self):
        self.dagbag = models.DagBag(include_examples=True)
        self.dag1 = self.dagbag.dags['test_example_bash_operator']
        self.dag2 = self.dagbag.dags['example_subdag_operator']

        self.execution_dates = [days_ago(2), days_ago(1)]

        drs = _create_dagruns(self.dag1,
                              self.execution_dates,
                              state=State.RUNNING,
                              run_id_template="scheduled__{}")
        for dr in drs:
            dr.dag = self.dag1
            dr.verify_integrity()

        drs = _create_dagruns(self.dag2,
                              [self.dag2.default_args['start_date']],
                              state=State.RUNNING,
                              run_id_template="scheduled__{}")

        for dr in drs:
            dr.dag = self.dag2
            dr.verify_integrity()

        self.session = Session()

    def snapshot_state(self, dag, execution_dates):
        TI = models.TaskInstance
        tis = self.session.query(TI).filter(
            TI.dag_id == dag.dag_id,
            TI.execution_date.in_(execution_dates)).all()

        self.session.expunge_all()

        return tis

    def verify_state(self, dag, task_ids, execution_dates, state, old_tis):
        TI = models.TaskInstance

        tis = self.session.query(TI).filter(
            TI.dag_id == dag.dag_id,
            TI.execution_date.in_(execution_dates)).all()

        self.assertTrue(len(tis) > 0)

        for ti in tis:
            if ti.task_id in task_ids and ti.execution_date in execution_dates:
                self.assertEqual(ti.state, state)
            else:
                for old_ti in old_tis:
                    if (old_ti.task_id == ti.task_id
                            and old_ti.execution_date == ti.execution_date):
                        self.assertEqual(ti.state, old_ti.state)

    def test_mark_tasks_now(self):
        # set one task to success but do not commit
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_1")
        altered = set_state(task=task,
                            execution_date=self.execution_dates[0],
                            upstream=False,
                            downstream=False,
                            future=False,
                            past=False,
                            state=State.SUCCESS,
                            commit=False)
        self.assertEqual(len(altered), 1)
        self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
                          None, snapshot)

        # set one and only one task to success
        altered = set_state(task=task,
                            execution_date=self.execution_dates[0],
                            upstream=False,
                            downstream=False,
                            future=False,
                            past=False,
                            state=State.SUCCESS,
                            commit=True)
        self.assertEqual(len(altered), 1)
        self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
                          State.SUCCESS, snapshot)

        # set no tasks
        altered = set_state(task=task,
                            execution_date=self.execution_dates[0],
                            upstream=False,
                            downstream=False,
                            future=False,
                            past=False,
                            state=State.SUCCESS,
                            commit=True)
        self.assertEqual(len(altered), 0)
        self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
                          State.SUCCESS, snapshot)

        # set task to other than success
        altered = set_state(task=task,
                            execution_date=self.execution_dates[0],
                            upstream=False,
                            downstream=False,
                            future=False,
                            past=False,
                            state=State.FAILED,
                            commit=True)
        self.assertEqual(len(altered), 1)
        self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
                          State.FAILED, snapshot)

        # don't alter other tasks
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_0")
        altered = set_state(task=task,
                            execution_date=self.execution_dates[0],
                            upstream=False,
                            downstream=False,
                            future=False,
                            past=False,
                            state=State.SUCCESS,
                            commit=True)
        self.assertEqual(len(altered), 1)
        self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
                          State.SUCCESS, snapshot)

    def test_mark_downstream(self):
        # test downstream
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_1")
        relatives = task.get_flat_relatives(upstream=False)
        task_ids = [t.task_id for t in relatives]
        task_ids.append(task.task_id)

        altered = set_state(task=task,
                            execution_date=self.execution_dates[0],
                            upstream=False,
                            downstream=True,
                            future=False,
                            past=False,
                            state=State.SUCCESS,
                            commit=True)
        self.assertEqual(len(altered), 3)
        self.verify_state(self.dag1, task_ids, [self.execution_dates[0]],
                          State.SUCCESS, snapshot)

    def test_mark_upstream(self):
        # test upstream
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("run_after_loop")
        relatives = task.get_flat_relatives(upstream=True)
        task_ids = [t.task_id for t in relatives]
        task_ids.append(task.task_id)

        altered = set_state(task=task,
                            execution_date=self.execution_dates[0],
                            upstream=True,
                            downstream=False,
                            future=False,
                            past=False,
                            state=State.SUCCESS,
                            commit=True)
        self.assertEqual(len(altered), 4)
        self.verify_state(self.dag1, task_ids, [self.execution_dates[0]],
                          State.SUCCESS, snapshot)

    def test_mark_tasks_future(self):
        # set one task to success towards end of scheduled dag runs
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_1")
        altered = set_state(task=task,
                            execution_date=self.execution_dates[0],
                            upstream=False,
                            downstream=False,
                            future=True,
                            past=False,
                            state=State.SUCCESS,
                            commit=True)
        self.assertEqual(len(altered), 2)
        self.verify_state(self.dag1, [task.task_id], self.execution_dates,
                          State.SUCCESS, snapshot)

    def test_mark_tasks_past(self):
        # set one task to success towards end of scheduled dag runs
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_1")
        altered = set_state(task=task,
                            execution_date=self.execution_dates[1],
                            upstream=False,
                            downstream=False,
                            future=False,
                            past=True,
                            state=State.SUCCESS,
                            commit=True)
        self.assertEqual(len(altered), 2)
        self.verify_state(self.dag1, [task.task_id], self.execution_dates,
                          State.SUCCESS, snapshot)

    def test_mark_tasks_subdag(self):
        # set one task to success towards end of scheduled dag runs
        task = self.dag2.get_task("section-1")
        relatives = task.get_flat_relatives(upstream=False)
        task_ids = [t.task_id for t in relatives]
        task_ids.append(task.task_id)

        altered = set_state(task=task,
                            execution_date=self.execution_dates[0],
                            upstream=False,
                            downstream=True,
                            future=False,
                            past=False,
                            state=State.SUCCESS,
                            commit=True)
        self.assertEqual(len(altered), 14)

        # cannot use snapshot here as that would require drilling down
        # the sub dag tree, essentially recreating the same code as in the
        # tested logic.
        self.verify_state(self.dag2, task_ids, [self.execution_dates[0]],
                          State.SUCCESS, [])

    def tearDown(self):
        self.dag1.clear()
        self.dag2.clear()

        # just to make sure we are fully cleaned up
        self.session.query(models.DagRun).delete()
        self.session.query(models.TaskInstance).delete()
        self.session.commit()

        self.session.close()
Example #32
def set_state(task, execution_date, upstream=False, downstream=False,
              future=False, past=False, state=State.SUCCESS, commit=False):
    """
    Set the state of a task instance and if needed its relatives. Can set state
    for future tasks (calculated from execution_date) and retroactively
    for past tasks. Will verify integrity of past dag runs in order to create
    tasks that did not exist. It will not create dag runs that are missing
    on the schedule (but it will for subdag dag runs if needed).
    :param task: the task from which to work. task.dag needs to be set
    :param execution_date: the execution date from which to start looking
    :param upstream: Mark all parents (upstream tasks)
    :param downstream: Mark all siblings (downstream tasks) of task_id, including SubDags
    :param future: Mark all future tasks on the interval of the dag up until
        last execution date.
    :param past: Retroactively mark all tasks starting from start_date of the DAG
    :param state: State to which the tasks need to be set
    :param commit: Commit tasks to be altered to the database
    :return: list of tasks that have been created and updated
    """
    assert timezone.is_localized(execution_date)

    # microseconds are supported by the database, but are not handled
    # correctly by airflow on e.g. the filesystem and in other places
    execution_date = execution_date.replace(microsecond=0)

    assert task.dag is not None
    dag = task.dag

    latest_execution_date = dag.latest_execution_date
    assert latest_execution_date is not None

    # determine date range of dag runs and tasks to consider
    end_date = latest_execution_date if future else execution_date

    if 'start_date' in dag.default_args:
        start_date = dag.default_args['start_date']
    elif dag.start_date:
        start_date = dag.start_date
    else:
        start_date = execution_date

    start_date = execution_date if not past else start_date

    if dag.schedule_interval == '@once':
        dates = [start_date]
    else:
        dates = dag.date_range(start_date=start_date, end_date=end_date)

    # find relatives (siblings = downstream, parents = upstream) if needed
    task_ids = [task.task_id]
    if downstream:
        relatives = task.get_flat_relatives(upstream=False)
        task_ids += [t.task_id for t in relatives]
    if upstream:
        relatives = task.get_flat_relatives(upstream=True)
        task_ids += [t.task_id for t in relatives]

    # verify the integrity of the dag runs in case a task was added or removed
    # set the confirmed execution dates as they might be different
    # from what was provided
    confirmed_dates = []
    drs = DagRun.find(dag_id=dag.dag_id, execution_date=dates)
    for dr in drs:
        dr.dag = dag
        dr.verify_integrity()
        confirmed_dates.append(dr.execution_date)

    # go through subdagoperators and create dag runs. We will only work
    # within the scope of the subdag. We won't propagate to the parent dag,
    # but we will propagate from parent to subdag.
    session = Session()
    dags = [dag]
    sub_dag_ids = []
    while len(dags) > 0:
        current_dag = dags.pop()
        for task_id in task_ids:
            if not current_dag.has_task(task_id):
                continue

            current_task = current_dag.get_task(task_id)
            if isinstance(current_task, SubDagOperator):
                # this works as a kind of integrity check
                # it creates missing dag runs for subdagoperators,
                # maybe this should be moved to dagrun.verify_integrity
                drs = _create_dagruns(current_task.subdag,
                                      execution_dates=confirmed_dates,
                                      state=State.RUNNING,
                                      run_id_template=BackfillJob.ID_FORMAT_PREFIX)

                for dr in drs:
                    dr.dag = current_task.subdag
                    dr.verify_integrity()
                    if commit:
                        dr.state = state
                        session.merge(dr)

                dags.append(current_task.subdag)
                sub_dag_ids.append(current_task.subdag.dag_id)

    # now look for the task instances that are affected
    TI = TaskInstance

    # get all tasks of the main dag that will be affected by a state change
    qry_dag = session.query(TI).filter(
        TI.dag_id == dag.dag_id,
        TI.execution_date.in_(confirmed_dates),
        TI.task_id.in_(task_ids)).filter(
        or_(TI.state.is_(None),
            TI.state != state)
    )

    # get *all* tasks of the sub dags
    if len(sub_dag_ids) > 0:
        qry_sub_dag = session.query(TI).filter(
            TI.dag_id.in_(sub_dag_ids),
            TI.execution_date.in_(confirmed_dates)).filter(
            or_(TI.state.is_(None),
                TI.state != state)
        )

    if commit:
        tis_altered = qry_dag.with_for_update().all()
        if len(sub_dag_ids) > 0:
            tis_altered += qry_sub_dag.with_for_update().all()
        for ti in tis_altered:
            ti.state = state
        session.commit()
    else:
        tis_altered = qry_dag.all()
        if len(sub_dag_ids) > 0:
            tis_altered += qry_sub_dag.all()

    session.expunge_all()
    session.close()

    return tis_altered
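A hypothetical usage sketch mirroring test_mark_tasks_now above; it assumes a `dag` whose runs already exist, a timezone-aware DEFAULT_DATE, and the State import. The commit=False call is a dry run, so the change can be previewed before anything is written:

task = dag.get_task('runme_0')

# dry run: list the task instances that would change, without writing
would_change = set_state(task=task, execution_date=DEFAULT_DATE,
                         downstream=True, state=State.SUCCESS, commit=False)
print('would alter {} task instances'.format(len(would_change)))

# apply for real; downstream relatives are marked SUCCESS as well
set_state(task=task, execution_date=DEFAULT_DATE,
          downstream=True, state=State.SUCCESS, commit=True)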
Example #33
        parent_ids = get_parent_ids(graph.sql_sensor_id)
        for parent_id in parent_ids:
            for parent_graph in graphs:
                if parent_id == parent_graph.sql_sensor_id:
                    graph.set_upstream(parent_graph)

    return subdag


def execute_statement(**kw):
    hook = BaseHook.get_connection(kw['conn_id']).get_hook()
    logging.info('Executing: %s', kw['sql'])
    hook.run(kw['sql'])


session = Session()
dags = session.query(SQLDag) \
    .filter(SQLSensor.enabled == True) \
    .all()


for _dag in dags:
    default_args['start_date'] = datetime.combine(
        _dag.start_date,
        dt.time(),
    )
    dag = DAG(_dag.id,
              default_args=default_args,
              schedule_interval=_dag.schedule_interval,
              )
    subdag_name = 'subdag_with_sensors'
Example #34
class ViewWithDateTimeAndNumRunsAndDagRunsFormTester:
    DAG_ID = 'dag_for_testing_dt_nr_dr_form'
    DEFAULT_DATE = datetime(2017, 9, 1)
    RUNS_DATA = [
        ('dag_run_for_testing_dt_nr_dr_form_4', datetime(2018, 4, 4)),
        ('dag_run_for_testing_dt_nr_dr_form_3', datetime(2018, 3, 3)),
        ('dag_run_for_testing_dt_nr_dr_form_2', datetime(2018, 2, 2)),
        ('dag_run_for_testing_dt_nr_dr_form_1', datetime(2018, 1, 1)),
    ]

    def __init__(self, test, endpoint):
        self.test = test
        self.endpoint = endpoint

    def setUp(self):
        configuration.load_test_config()
        app = application.create_app(testing=True)
        app.config['WTF_CSRF_METHODS'] = []
        self.app = app.test_client()
        self.session = Session()
        from airflow.www.views import dagbag
        from airflow.utils.state import State
        dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
        dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
        self.runs = []
        for rd in self.RUNS_DATA:
            run = dag.create_dagrun(run_id=rd[0],
                                    execution_date=rd[1],
                                    state=State.SUCCESS,
                                    external_trigger=True)
            self.runs.append(run)

    def tearDown(self):
        self.session.query(DagRun).filter(
            DagRun.dag_id == self.DAG_ID).delete()
        self.session.commit()
        self.session.close()

    def assertBaseDateAndNumRuns(self, base_date, num_runs, data):
        self.test.assertNotIn('name="base_date" value="{}"'.format(base_date),
                              data)
        self.test.assertNotIn(
            '<option selected="" value="{}">{}</option>'.format(
                num_runs, num_runs), data)

    def assertRunIsNotInDropdown(self, run, data):
        self.test.assertNotIn(run.execution_date.isoformat(), data)
        self.test.assertNotIn(run.run_id, data)

    def assertRunIsInDropdownNotSelected(self, run, data):
        self.test.assertIn(
            '<option value="{}">{}</option>'.format(
                run.execution_date.isoformat(), run.run_id), data)

    def assertRunIsSelected(self, run, data):
        self.test.assertIn(
            '<option selected value="{}">{}</option>'.format(
                run.execution_date.isoformat(), run.run_id), data)

    def test_with_default_parameters(self):
        """
        Tests graph view with no URL parameter.
        Should show all dag runs in the drop down.
        Should select the latest dag run.
        Should set base date to current date (not asserted)
        """
        response = self.app.get(self.endpoint)
        self.test.assertEqual(response.status_code, 200)
        data = response.data.decode('utf-8')
        self.test.assertIn('Base date:', data)
        self.test.assertIn('Number of runs:', data)
        self.assertRunIsSelected(self.runs[0], data)
        self.assertRunIsInDropdownNotSelected(self.runs[1], data)
        self.assertRunIsInDropdownNotSelected(self.runs[2], data)
        self.assertRunIsInDropdownNotSelected(self.runs[3], data)

    def test_with_execution_date_parameter_only(self):
        """
        Tests graph view with execution_date URL parameter.
        Scenario: click link from dag runs view.
        Should only show dag runs older than execution_date in the drop down.
        Should select the particular dag run.
        Should set base date to execution date.
        """
        response = self.app.get(self.endpoint + '&execution_date={}'.format(
            self.runs[1].execution_date.isoformat()))
        self.test.assertEqual(response.status_code, 200)
        data = response.data.decode('utf-8')
        self.assertBaseDateAndNumRuns(
            self.runs[1].execution_date,
            configuration.getint('webserver',
                                 'default_dag_run_display_number'), data)
        self.assertRunIsNotInDropdown(self.runs[0], data)
        self.assertRunIsSelected(self.runs[1], data)
        self.assertRunIsInDropdownNotSelected(self.runs[2], data)
        self.assertRunIsInDropdownNotSelected(self.runs[3], data)

    def test_with_base_date_and_num_runs_parameters_only(self):
        """
        Tests graph view with base_date and num_runs URL parameters.
        Should only show dag runs older than base_date in the drop down,
        limited to num_runs.
        Should select the latest dag run.
        Should set base date and num runs to submitted values.
        """
        response = self.app.get(self.endpoint +
                                '&base_date={}&num_runs=2'.format(
                                    self.runs[1].execution_date.isoformat()))
        self.test.assertEqual(response.status_code, 200)
        data = response.data.decode('utf-8')
        self.assertBaseDateAndNumRuns(self.runs[1].execution_date, 2, data)
        self.assertRunIsNotInDropdown(self.runs[0], data)
        self.assertRunIsSelected(self.runs[1], data)
        self.assertRunIsInDropdownNotSelected(self.runs[2], data)
        self.assertRunIsNotInDropdown(self.runs[3], data)

    def test_with_base_date_and_num_runs_and_execution_date_outside(self):
        """
        Tests graph view with base_date and num_runs and execution-date URL parameters.
        Scenario: change the base date and num runs and press "Go",
        the selected execution date is outside the new range.
        Should only show dag runs older than base_date in the drop down.
        Should select the latest dag run within the range.
        Should set base date and num runs to submitted values.
        """
        response = self.app.get(self.endpoint +
                                '&base_date={}&num_runs=42&execution_date={}'.
                                format(self.runs[1].execution_date.isoformat(
                                ), self.runs[0].execution_date.isoformat()))
        self.test.assertEqual(response.status_code, 200)
        data = response.data.decode('utf-8')
        self.assertBaseDateAndNumRuns(self.runs[1].execution_date, 42, data)
        self.assertRunIsNotInDropdown(self.runs[0], data)
        self.assertRunIsSelected(self.runs[1], data)
        self.assertRunIsInDropdownNotSelected(self.runs[2], data)
        self.assertRunIsInDropdownNotSelected(self.runs[3], data)

    def test_with_base_date_and_num_runs_and_execution_date_within(self):
        """
        Tests graph view with base_date and num_runs and execution-date URL parameters.
        Scenario: change the base date and num runs and press "Go",
        the selected execution date is within the new range.
        Should only show dag runs older than base_date in the drop down.
        Should select the dag run with the execution date.
        Should set base date and num runs to submitted values.
        """
        response = self.app.get(self.endpoint +
                                '&base_date={}&num_runs=5&execution_date={}'.
                                format(self.runs[2].execution_date.isoformat(
                                ), self.runs[3].execution_date.isoformat()))
        self.test.assertEqual(response.status_code, 200)
        data = response.data.decode('utf-8')
        self.assertBaseDateAndNumRuns(self.runs[2].execution_date, 5, data)
        self.assertRunIsNotInDropdown(self.runs[0], data)
        self.assertRunIsNotInDropdown(self.runs[1], data)
        self.assertRunIsInDropdownNotSelected(self.runs[2], data)
        self.assertRunIsSelected(self.runs[3], data)
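Note the class above is a plain helper rather than a TestCase: it receives the real test instance and an endpoint, so several views can share the same scenarios. A hypothetical wiring sketch (the class name and endpoint here are illustrative, not from the original):

class TestGraphViewDtNrDrForm(unittest.TestCase):
    def setUp(self):
        self.tester = ViewWithDateTimeAndNumRunsAndDagRunsFormTester(
            self, '/admin/airflow/graph?dag_id=dag_for_testing_dt_nr_dr_form')
        self.tester.setUp()

    def tearDown(self):
        self.tester.tearDown()

    def test_with_default_parameters(self):
        self.tester.test_with_default_parameters()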
Example #35
class TestKnownEventView(unittest.TestCase):

    CREATE_ENDPOINT = '/admin/knownevent/new/?url=/admin/knownevent/'

    @classmethod
    def setUpClass(cls):
        super(TestKnownEventView, cls).setUpClass()
        session = Session()
        session.query(models.KnownEvent).delete()
        session.query(models.User).delete()
        session.commit()
        user = models.User(username='******')
        session.add(user)
        session.commit()
        cls.user_id = user.id
        session.close()

    def setUp(self):
        super(TestKnownEventView, self).setUp()
        configuration.load_test_config()
        app = application.create_app(testing=True)
        app.config['WTF_CSRF_METHODS'] = []
        self.app = app.test_client()
        self.session = Session()
        self.known_event = {
            'label': 'event-label',
            'event_type': '1',
            'start_date': '2017-06-05 12:00:00',
            'end_date': '2017-06-05 13:00:00',
            'reported_by': self.user_id,
            'description': '',
        }

    def tearDown(self):
        self.session.query(models.KnownEvent).delete()
        self.session.commit()
        self.session.close()
        super(TestKnownEventView, self).tearDown()

    @classmethod
    def tearDownClass(cls):
        session = Session()
        session.query(models.User).delete()
        session.commit()
        session.close()
        super(TestKnownEventView, cls).tearDownClass()

    def test_create_known_event(self):
        response = self.app.post(
            self.CREATE_ENDPOINT,
            data=self.known_event,
            follow_redirects=True,
        )
        self.assertEqual(response.status_code, 200)
        self.assertEqual(self.session.query(models.KnownEvent).count(), 1)

    def test_create_known_event_with_end_date_earlier_than_start_date(self):
        self.known_event['end_date'] = '2017-06-05 11:00:00'
        response = self.app.post(
            self.CREATE_ENDPOINT,
            data=self.known_event,
            follow_redirects=True,
        )
        self.assertIn(
            'Field must be greater than or equal to Start Date.',
            response.data.decode('utf-8'),
        )
        self.assertEqual(self.session.query(models.KnownEvent).count(), 0)
Example #36
 def setUpClass(cls):
     super(TestKnownEventView, cls).setUpClass()
     session = Session()
     session.query(models.KnownEvent).delete()
     session.query(models.User).delete()
     session.commit()
     user = models.User(username='******')
     session.add(user)
     session.commit()
     cls.user_id = user.id
     session.close()
Example #37
def init_datavault2_bigdata_example():
    logging.info('Creating connections, pool and sql path')

    session = Session()

    def create_new_conn(session, attributes):
        conn_id = attributes.get("conn_id")
        new_conn = session.query(models.Connection).filter(
            models.Connection.conn_id == conn_id).first()
        if not new_conn:
            logging.info("No connection found")
            new_conn = models.Connection()
        new_conn.conn_id = conn_id
        new_conn.conn_type = attributes.get('conn_type')
        new_conn.host = attributes.get('host')
        new_conn.port = attributes.get('port')
        new_conn.schema = attributes.get('schema')
        new_conn.login = attributes.get('login')
        new_conn.set_password(attributes.get('password'))
        new_conn.set_extra(attributes.get('extra'))

        session.merge(new_conn)
        session.commit()

    create_new_conn(
        session, {
            "conn_id": "dvdrentals",
            "conn_type": "postgres",
            "host": "postgres",
            "port": 5432,
            "schema": "dvdrentals",
            "login": "******",
            "password": "******"
        })

    create_new_conn(
        session, {
            "conn_id": "filestore",
            "conn_type": "File",
            "host": "",
            "port": 0,
            "schema": "",
            "login": "",
            "password": "",
            "extra": json.dumps({"path": "/tmp/datavault2-bigdata-example"})
        })

    create_new_conn(
        session, {
            "conn_id": "hive_default",
            "conn_type": "hive_cli",
            "host": "hive",
            "schema": "default",
            "port": 10000,
            "login": "******",
            "password": "******",
            "extra": json.dumps({
                "hive_cli_params": "",
                "auth": "noSasl",
                "use_beeline": "true"
            })
        })

    create_new_conn(
        session, {
            "conn_id": "hive_datavault_raw",
            "conn_type": "hive_cli",
            "host": "hive",
            "schema": DATAVAULT,
            "port": 10000,
            "login": "******",
            "password": "******",
            "extra": json.dumps({
                "hive_cli_params": "",
                "auth": "noSasl",
                "use_beeline": "true"
            })
        })

    create_new_conn(
        session, {
            "conn_id": "hive_dvdrentals_staging",
            "conn_type": "hive_cli",
            "host": "hive",
            "schema": DVDRENTALS_STAGING,
            "port": 10000,
            "login": "******",
            "password": "******",
            "extra": json.dumps({
                "hive_cli_params": "",
                "auth": "noSasl",
                "use_beeline": "true"
            })
        })

    create_new_conn(
        session, {
            "conn_id": "hive_datavault_temp",
            "conn_type": "hive_cli",
            "host": "hive",
            "schema": 'dv_temp',
            "port": 10000,
            "login": "******",
            "password": "******",
            "extra": json.dumps({
                "hive_cli_params": "",
                "auth": "noSasl",
                "use_beeline": "true"
            })
        })

    create_new_conn(
        session, {
            "conn_id": "hiveserver2-dvstar",
            "conn_type": "hiveserver2",
            "host": "hive",
            "schema": 'dv_star',
            "login": "******",
            "port": 10000,
            "extra": json.dumps({"authMechanism": "NOSASL"})
        })

    session.close()
Example #38
def init_datavault2_example():
    logging.info('Creating connections, pool and sql path')

    session = Session()

    def create_new_conn(session, attributes):
        conn_id = attributes.get("conn_id")
        new_conn = session.query(models.Connection).filter(
            models.Connection.conn_id == conn_id).first()
        if not new_conn:
            logging.info("No connection found")
            new_conn = models.Connection()
        new_conn.conn_id = conn_id
        new_conn.conn_type = attributes.get('conn_type')
        new_conn.host = attributes.get('host')
        new_conn.port = attributes.get('port')
        new_conn.schema = attributes.get('schema')
        new_conn.login = attributes.get('login')
        new_conn.set_password(attributes.get('password'))
        new_conn.set_extra(attributes.get('extra'))

        session.merge(new_conn)
        session.commit()

    create_new_conn(
        session, {
            "conn_id": "adventureworks",
            "conn_type": "postgres",
            "host": "postgres",
            "port": 5432,
            "schema": "adventureworks",
            "login": "******",
            "password": "******"
        })

    create_new_conn(
        session, {
            "conn_id": "hive_default",
            "conn_type": "hive_cli",
            "host": "hive",
            "schema": "default",
            "port": 10000,
            "login": "******",
            "password": "******",
            "extra": json.dumps({
                "hive_cli_params": "",
                "auth": "noSasl",
                "use_beeline": "true"
            })
        })

    create_new_conn(
        session, {
            "conn_id": "hive_advworks_staging",
            "conn_type": "hive_cli",
            "host": "hive",
            "schema": ADVWORKS_STAGING,
            "port": 10000,
            "login": "******",
            "password": "******",
            "extra": json.dumps({
                "hive_cli_params": "",
                "auth": "noSasl",
                "use_beeline": "true"
            })
        })

    create_new_conn(
        session, {
            "conn_id": "hive_datavault_raw",
            "conn_type": "hive_cli",
            "host": "hive",
            "schema": DATAVAULT,
            "port": 10000,
            "login": "******",
            "password": "******",
            "extra": json.dumps({
                "hive_cli_params": "",
                "auth": "noSasl",
                "use_beeline": "true"
            })
        })

    create_new_conn(
        session, {
            "conn_id": "hive_datavault_temp",
            "conn_type": "hive_cli",
            "host": "hive",
            "schema": 'dv_temp',
            "port": 10000,
            "login": "******",
            "password": "******",
            "extra": json.dumps({
                "hive_cli_params": "",
                "auth": "noSasl",
                "use_beeline": "true"
            })
        })

    create_new_conn(
        session, {
            "conn_id": "hive_dv_star",
            "conn_type": "hive_cli",
            "host": "hive",
            "schema": 'dv_star',
            "port": 10000,
            "login": "******",
            "password": "******",
            "extra": json.dumps({
                "hive_cli_params": "",
                "auth": "noSasl",
                "use_beeline": "true"
            })
        })

    create_new_conn(
        session, {
            "conn_id": "hiveserver2-dvstar",
            "conn_type": "hiveserver2",
            "host": "hive",
            "schema": 'dv_star',
            "login": "******",
            "port": 10000,
            "extra": json.dumps({"authMechanism": "NOSASL"})
        })

    create_new_conn(
        session, {
            "conn_id": "gcp",
            "conn_type": "google_cloud_platform",
            "extra": json.dumps({
                "extra__google_cloud_platform__key_path":
                    "/usr/local/airflow/keyfile.json",
                "extra__google_cloud_platform__scope":
                    "https://www.googleapis.com/auth/cloud-platform"
            })
        })
    session.close()
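Once a seeding routine like the one above has run, the result can be spot-checked straight from the metadata database. A minimal sketch, assuming the same airflow.models and airflow.settings.Session the snippet relies on; it is a read-only sanity check, not part of the original code:

# Read-only sanity check: list the connection ids now in the metadata DB.
from airflow import models, settings

session = settings.Session()
try:
    for conn in session.query(models.Connection).order_by(
            models.Connection.conn_id):
        print(conn.conn_id, conn.conn_type, conn.host)
finally:
    session.close()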
Example No. 39
def tearDown(self):
    session = Session()
    session.query(DagRun).delete()
    session.commit()
    session.close()
    super(TestDagRunsEndpoint, self).tearDown()
Example No. 40
def setUp(self):
    self.session = Session()
    self.cleanup_dagruns()
    self.prepare_dagruns()
Example No. 41
def tearDown(self):
    session = Session()
    session.query(DagRun).delete()
    session.commit()
    session.close()
    super().tearDown()
Example No. 42
class TestDecorators(unittest.TestCase):
    EXAMPLE_DAG_DEFAULT_DATE = dates.days_ago(2)
    run_id = "test_{}".format(DagRun.id_for_date(EXAMPLE_DAG_DEFAULT_DATE))

    @classmethod
    def setUpClass(cls):
        cls.dagbag = DagBag(include_examples=True)
        app = application.create_app(testing=True)
        app.config['WTF_CSRF_METHODS'] = []
        cls.app = app.test_client()

    def setUp(self):
        self.session = Session()
        self.cleanup_dagruns()
        self.prepare_dagruns()

    def cleanup_dagruns(self):
        DR = DagRun
        dag_id = 'example_bash_operator'
        (self.session.query(DR).filter(DR.dag_id == dag_id).filter(
            DR.run_id == self.run_id).delete(synchronize_session='fetch'))
        self.session.commit()

    def prepare_dagruns(self):
        self.bash_dag = self.dagbag.dags['example_bash_operator']
        self.bash_dag.sync_to_db()

        self.bash_dagrun = self.bash_dag.create_dagrun(
            run_id=self.run_id,
            execution_date=self.EXAMPLE_DAG_DEFAULT_DATE,
            start_date=timezone.utcnow(),
            state=State.RUNNING)

    def check_last_log(self, dag_id, event, execution_date=None):
        qry = self.session.query(Log.dag_id, Log.task_id, Log.event,
                                 Log.execution_date, Log.owner, Log.extra)
        qry = qry.filter(Log.dag_id == dag_id, Log.event == event)
        if execution_date:
            qry = qry.filter(Log.execution_date == execution_date)
        logs = qry.order_by(Log.dttm.desc()).limit(5).all()
        self.assertGreaterEqual(len(logs), 1)
        self.assertTrue(logs[0].extra)

    def test_action_logging_get(self):
        url = '/admin/airflow/graph?dag_id=example_bash_operator&execution_date={}'.format(
            quote_plus(
                self.EXAMPLE_DAG_DEFAULT_DATE.isoformat().encode('utf-8')))
        self.app.get(url, follow_redirects=True)

        # In mysql backend, this commit() is needed to write down the logs
        self.session.commit()
        self.check_last_log("example_bash_operator",
                            event="graph",
                            execution_date=self.EXAMPLE_DAG_DEFAULT_DATE)

    def test_action_logging_post(self):
        form = dict(
            task_id="runme_1",
            dag_id="example_bash_operator",
            execution_date=self.EXAMPLE_DAG_DEFAULT_DATE.isoformat().encode(
                'utf-8'),
            upstream="false",
            downstream="false",
            future="false",
            past="false",
            only_failed="false",
        )
        self.app.post("/admin/airflow/clear", data=form)
        # In mysql backend, this commit() is needed to write down the logs
        self.session.commit()
        self.check_last_log("example_bash_operator",
                            event="clear",
                            execution_date=self.EXAMPLE_DAG_DEFAULT_DATE)
Example No. 43
def setUpClass(cls):
    super(TestVariableView, cls).setUpClass()
    session = Session()
    session.query(models.Variable).delete()
    session.commit()
    session.close()
Example No. 44
def tearDown(self):
    super(BranchOperatorTest, self).tearDown()

    session = Session()

    session.query(DagRun).delete()
    session.query(TI).delete()
    print(len(session.query(DagRun).all()))  # sanity check: should print 0
    session.commit()
    session.close()
Example No. 45
class TestPoolModelView(unittest.TestCase):

    CREATE_ENDPOINT = '/admin/pool/new/?url=/admin/pool/'

    @classmethod
    def setUpClass(cls):
        super(TestPoolModelView, cls).setUpClass()
        session = Session()
        session.query(models.Pool).delete()
        session.commit()
        session.close()

    def setUp(self):
        super(TestPoolModelView, self).setUp()
        configuration.load_test_config()
        app = application.create_app(testing=True)
        app.config['WTF_CSRF_METHODS'] = []
        self.app = app.test_client()
        self.session = Session()
        self.pool = {
            'pool': 'test-pool',
            'slots': 777,
            'description': 'test-pool-description',
        }

    def tearDown(self):
        self.session.query(models.Pool).delete()
        self.session.commit()
        self.session.close()
        super(TestPoolModelView, self).tearDown()

    def test_create_pool(self):
        response = self.app.post(
            self.CREATE_ENDPOINT,
            data=self.pool,
            follow_redirects=True,
        )
        self.assertEqual(response.status_code, 200)
        self.assertEqual(self.session.query(models.Pool).count(), 1)

    def test_create_pool_with_same_name(self):
        # create test pool
        self.app.post(
            self.CREATE_ENDPOINT,
            data=self.pool,
            follow_redirects=True,
        )
        # create pool with the same name
        response = self.app.post(
            self.CREATE_ENDPOINT,
            data=self.pool,
            follow_redirects=True,
        )
        self.assertIn('Already exists.', response.data.decode('utf-8'))
        self.assertEqual(self.session.query(models.Pool).count(), 1)

    def test_create_pool_with_empty_name(self):
        self.pool['pool'] = ''
        response = self.app.post(
            self.CREATE_ENDPOINT,
            data=self.pool,
            follow_redirects=True,
        )
        self.assertIn('This field is required.', response.data.decode('utf-8'))
        self.assertEqual(self.session.query(models.Pool).count(), 0)
Example No. 46
def tearDown(self):
    session = Session()
    session.query(models.User).delete()
    session.commit()
    session.close()
Example No. 47
def tearDown(self):
    super(PythonOperatorTest, self).tearDown()

    session = Session()

    session.query(DagRun).delete()
    session.query(TI).delete()
    print(len(session.query(DagRun).all()))  # sanity check: should print 0
    session.commit()
    session.close()

    for var in TI_CONTEXT_ENV_VARS:
        if var in os.environ:
            del os.environ[var]
Example No. 48
def setUpClass(cls):
    super(TestPoolModelView, cls).setUpClass()
    session = Session()
    session.query(models.Pool).delete()
    session.commit()
    session.close()
Example No. 49
def tearDown(self):
    session = Session()
    session.query(models.TaskInstance).filter_by(dag_id=TEST_DAG_ID).delete()
    session.query(TaskFail).filter_by(dag_id=TEST_DAG_ID).delete()
    session.commit()
    session.close()
Example No. 50
def setUpClass(cls):
    super(TestChartModelView, cls).setUpClass()
    session = Session()
    session.query(models.Chart).delete()
    session.query(models.User).delete()
    session.commit()
    user = models.User(username='******')
    session.add(user)
    session.commit()
    session.close()
Example No. 51
class TestLogView(unittest.TestCase):
    DAG_ID = 'dag_for_testing_log_view'
    TASK_ID = 'task_for_testing_log_view'
    DEFAULT_DATE = datetime(2017, 9, 1)
    ENDPOINT = '/admin/airflow/log?dag_id={dag_id}&task_id={task_id}&execution_date={execution_date}'.format(
        dag_id=DAG_ID,
        task_id=TASK_ID,
        execution_date=DEFAULT_DATE,
    )

    @classmethod
    def setUpClass(cls):
        super(TestLogView, cls).setUpClass()
        session = Session()
        # Conditions passed as separate filter() arguments are ANDed in SQL;
        # chaining them with Python's `and` collapses to a single condition
        # and would delete more rows than intended.
        session.query(TaskInstance).filter(
            TaskInstance.dag_id == cls.DAG_ID,
            TaskInstance.task_id == cls.TASK_ID,
            TaskInstance.execution_date == cls.DEFAULT_DATE).delete()
        session.commit()
        session.close()

    def setUp(self):
        super(TestLogView, self).setUp()

        # Create a custom logging configuration
        configuration.load_test_config()
        logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG)
        current_dir = os.path.dirname(os.path.abspath(__file__))
        logging_config['handlers']['task'][
            'base_log_folder'] = os.path.normpath(
                os.path.join(current_dir, 'test_logs'))
        logging_config['handlers']['task']['filename_template'] = \
            '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts | replace(":", ".") }}/{{ try_number }}.log'

        # Write the custom logging configuration to a file
        self.settings_folder = tempfile.mkdtemp()
        settings_file = os.path.join(self.settings_folder,
                                     "airflow_local_settings.py")
        new_logging_file = "LOGGING_CONFIG = {}".format(logging_config)
        with open(settings_file, 'w') as handle:
            handle.write(new_logging_file)
        sys.path.append(self.settings_folder)
        conf.set('core', 'logging_config_class',
                 'airflow_local_settings.LOGGING_CONFIG')

        app = application.create_app(testing=True)
        self.app = app.test_client()
        self.session = Session()
        from airflow.www.views import dagbag
        dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
        task = DummyOperator(task_id=self.TASK_ID, dag=dag)
        dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
        ti = TaskInstance(task=task, execution_date=self.DEFAULT_DATE)
        ti.try_number = 1
        self.session.merge(ti)
        self.session.commit()

    def tearDown(self):
        logging.config.dictConfig(DEFAULT_LOGGING_CONFIG)
        # filter() arguments combine with SQL AND (Python's `and` would not)
        self.session.query(TaskInstance).filter(
            TaskInstance.dag_id == self.DAG_ID,
            TaskInstance.task_id == self.TASK_ID,
            TaskInstance.execution_date == self.DEFAULT_DATE).delete()
        self.session.commit()
        self.session.close()

        sys.path.remove(self.settings_folder)
        shutil.rmtree(self.settings_folder)
        conf.set('core', 'logging_config_class', '')

        super(TestLogView, self).tearDown()

    def test_get_file_task_log(self):
        response = self.app.get(
            TestLogView.ENDPOINT,
            follow_redirects=True,
        )
        self.assertEqual(response.status_code, 200)
        self.assertIn('Log by attempts', response.data.decode('utf-8'))

    def test_get_logs_with_metadata_as_download_file(self):
        url_template = "/admin/airflow/get_logs_with_metadata?dag_id={}&" \
                       "task_id={}&execution_date={}&" \
                       "try_number={}&metadata={}&format=file"
        try_number = 1
        url = url_template.format(self.DAG_ID, self.TASK_ID,
                                  quote_plus(self.DEFAULT_DATE.isoformat()),
                                  try_number, json.dumps({}))
        response = self.app.get(url)
        expected_filename = '{}/{}/{}/{}.log'.format(
            self.DAG_ID, self.TASK_ID, self.DEFAULT_DATE.isoformat(),
            try_number)

        content_disposition = response.headers.get('Content-Disposition')
        self.assertTrue(content_disposition.startswith('attachment'))
        self.assertTrue(expected_filename in content_disposition)
        self.assertEqual(200, response.status_code)
        self.assertIn('Log for testing.', response.data.decode('utf-8'))

    def test_get_logs_with_metadata(self):
        url_template = "/admin/airflow/get_logs_with_metadata?dag_id={}&" \
                       "task_id={}&execution_date={}&" \
                       "try_number={}&metadata={}"
        response = \
            self.app.get(url_template.format(self.DAG_ID,
                                             self.TASK_ID,
                                             quote_plus(self.DEFAULT_DATE.isoformat()),
                                             1,
                                             json.dumps({})))

        self.assertIn('"message":', response.data.decode('utf-8'))
        self.assertIn('"metadata":', response.data.decode('utf-8'))
        self.assertIn('Log for testing.', response.data.decode('utf-8'))
        self.assertEqual(200, response.status_code)

    def test_get_logs_with_null_metadata(self):
        url_template = "/admin/airflow/get_logs_with_metadata?dag_id={}&" \
                       "task_id={}&execution_date={}&" \
                       "try_number={}&metadata=null"
        response = \
            self.app.get(url_template.format(self.DAG_ID,
                                             self.TASK_ID,
                                             quote_plus(self.DEFAULT_DATE.isoformat()),
                                             1))

        self.assertIn('"message":', response.data.decode('utf-8'))
        self.assertIn('"metadata":', response.data.decode('utf-8'))
        self.assertIn('Log for testing.', response.data.decode('utf-8'))
        self.assertEqual(200, response.status_code)
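The filename_template configured in setUp determines where each attempt's log file lands. A minimal sketch of rendering it by hand for this test's constants, assuming jinja2 (the engine Airflow uses for this template) and taking ts to be the execution date in ISO form:

from types import SimpleNamespace

from jinja2 import Template

ti = SimpleNamespace(dag_id='dag_for_testing_log_view',
                     task_id='task_for_testing_log_view')
template = Template('{{ ti.dag_id }}/{{ ti.task_id }}/'
                    '{{ ts | replace(":", ".") }}/{{ try_number }}.log')
print(template.render(ti=ti, ts='2017-09-01T00:00:00', try_number=1))
# dag_for_testing_log_view/task_for_testing_log_view/2017-09-01T00.00.00/1.log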
Example No. 52
def tearDownClass(cls):
    session = Session()
    session.query(models.User).delete()
    session.commit()
    session.close()
    super(TestChartModelView, cls).tearDownClass()
Example No. 53
    def tearDown(self):
        if os.environ.get('KUBERNETES_VERSION') is not None:
            return

        dag_ids_to_clean = [
            TEST_DAG_ID,
            self.TEST_SCHEDULE_WITH_NO_PREVIOUS_RUNS_DAG_ID,
            self.TEST_SCHEDULE_DAG_FAKE_SCHEDULED_PREVIOUS_DAG_ID,
            self.TEST_SCHEDULE_DAG_NO_END_DATE_UP_TO_TODAY_ONLY_DAG_ID,
            self.TEST_SCHEDULE_ONCE_DAG_ID,
            self.TEST_SCHEDULE_RELATIVEDELTA_DAG_ID,
            self.TEST_SCHEDULE_START_END_DATES_DAG_ID,
        ]
        session = Session()
        session.query(DagRun).filter(
            DagRun.dag_id.in_(dag_ids_to_clean)).delete(
            synchronize_session=False)
        session.query(TaskInstance).filter(
            TaskInstance.dag_id.in_(dag_ids_to_clean)).delete(
            synchronize_session=False)
        session.query(TaskFail).filter(
            TaskFail.dag_id.in_(dag_ids_to_clean)).delete(
            synchronize_session=False)
        session.commit()
        session.close()
Example No. 54
def reset_dr_db(dag_id):
    session = Session()
    dr = session.query(models.DagRun).filter_by(dag_id=dag_id)
    dr.delete()
    session.commit()
    session.close()
Example No. 55
def reset(dag_id=TEST_DAG_ID):
    session = Session()
    tis = session.query(models.TaskInstance).filter_by(dag_id=dag_id)
    tis.delete()
    session.commit()
    session.close()
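Helpers such as reset above are typically wired into a test's setUp and tearDown so every test starts from a clean metadata state. A minimal sketch, reusing reset and TEST_DAG_ID from the snippet; the test class itself is hypothetical:

import unittest


class SomeDagTest(unittest.TestCase):  # hypothetical test case
    def setUp(self):
        # wipe task instances for TEST_DAG_ID before every test
        reset()

    def tearDown(self):
        # and again afterwards, so failures don't leak into other tests
        reset()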
Example No. 56
def auto_conn():
    logging.info('Creating connections, pool and sql path')

    session = Session()

    def create_new_conn(session, attributes):
        new_conn = models.Connection()
        new_conn.conn_id = attributes.get("conn_id")
        new_conn.conn_type = attributes.get('conn_type')
        new_conn.host = attributes.get('host')
        new_conn.port = attributes.get('port')
        new_conn.schema = attributes.get('schema')
        new_conn.login = attributes.get('login')
        new_conn.extra = attributes.get('extra')
        # set_password() stores the password encrypted when a Fernet key is
        # configured; assigning new_conn.password directly would skip that.
        new_conn.set_password(attributes.get('password'))

        session.add(new_conn)
        session.commit()

    create_new_conn(session,
                    {"conn_id": configuration.get('s3', 's3_conn_id'),
                     "conn_type": configuration.get('s3', 's3_conn_type'),
                     "extra": configuration.get('s3', 's3_extra')})

    create_new_conn(session,
                    {"conn_id": configuration.get('mysql', 'mysql_conn_id'),
                     "conn_type": configuration.get('mysql', 'mysql_conn_type'),
                     "schema":configuration.get('mysql', 'mysql_schema'),
                     "host": configuration.get('mysql', 'mysql_host'),
                     "port": configuration.getint('mysql', 'mysql_port'),
                     "login": configuration.get('mysql', 'mysql_login'),
                     "password": configuration.get('mysql', 'mysql_password')})

    create_new_conn(session,
                    {"conn_id": configuration.get('postgresql', 'postgresql_conn_id'),
                     "conn_type": configuration.get('postgresql', 'postgresql_conn_type'),
                     "host": configuration.get('postgresql', 'postgresql_host'),
                     "port": configuration.getint('postgresql', 'postgresql_port'),
                     "schema": configuration.get('postgresql', 'postgresql_schema'),
                     "login": configuration.get('postgresql', 'postgresql_login'),
                     "password": configuration.get('postgresql', 'postgresql_password')})

    create_new_conn(session,
                    {"conn_id": "airflow_connection",
                     "conn_type": configuration.get('mysql', 'mysql_conn_type'),
                     "schema": "airflow",
                     "host": "localhost",
                     "login": "******",
                     "password": "******"})

    create_new_conn(session, {
        "conn_id": "mongo_connection",
        "conn_type": "mongo",
        "host": "13.126.117.239",
        "port": "27017",
        "login": "******",
        "password": "******"
    })

    session.close()
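auto_conn above pulls its values from custom sections such as [s3] and [mysql] in airflow.cfg and will fail if one is absent. A minimal sketch of a defensive lookup, assuming Airflow's configuration module raises AirflowConfigException for missing options (the behavior of the 1.x config wrapper this code targets); the helper name is an assumption:

from airflow import configuration
from airflow.exceptions import AirflowConfigException


def conf_get_or_default(section, key, default=None):
    # Return the configured value, or a default when the custom
    # section/key is not present in airflow.cfg.
    try:
        return configuration.get(section, key)
    except AirflowConfigException:
        return default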
Example No. 57
class TestLogView(TestBase):
    DAG_ID = 'dag_for_testing_log_view'
    TASK_ID = 'task_for_testing_log_view'
    DEFAULT_DATE = timezone.datetime(2017, 9, 1)
    ENDPOINT = 'log?dag_id={dag_id}&task_id={task_id}&' \
               'execution_date={execution_date}'.format(dag_id=DAG_ID,
                                                        task_id=TASK_ID,
                                                        execution_date=DEFAULT_DATE)

    def setUp(self):
        conf.load_test_config()

        # Create a custom logging configuration
        logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG)
        current_dir = os.path.dirname(os.path.abspath(__file__))
        logging_config['handlers']['task']['base_log_folder'] = os.path.normpath(
            os.path.join(current_dir, 'test_logs'))
        logging_config['handlers']['task']['filename_template'] = \
            '{{ ti.dag_id }}/{{ ti.task_id }}/' \
            '{{ ts | replace(":", ".") }}/{{ try_number }}.log'

        # Write the custom logging configuration to a file
        self.settings_folder = tempfile.mkdtemp()
        settings_file = os.path.join(self.settings_folder, "airflow_local_settings.py")
        new_logging_file = "LOGGING_CONFIG = {}".format(logging_config)
        with open(settings_file, 'w') as handle:
            handle.write(new_logging_file)
        sys.path.append(self.settings_folder)
        conf.set('core', 'logging_config_class', 'airflow_local_settings.LOGGING_CONFIG')

        self.app, self.appbuilder = application.create_app(testing=True)
        self.app.config['WTF_CSRF_ENABLED'] = False
        self.client = self.app.test_client()
        self.login()
        self.session = Session()

        from airflow.www_rbac.views import dagbag
        dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
        task = DummyOperator(task_id=self.TASK_ID, dag=dag)
        dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
        ti = TaskInstance(task=task, execution_date=self.DEFAULT_DATE)
        ti.try_number = 1
        self.session.merge(ti)
        self.session.commit()

    def tearDown(self):
        logging.config.dictConfig(DEFAULT_LOGGING_CONFIG)
        self.clear_table(TaskInstance)

        shutil.rmtree(self.settings_folder)
        conf.set('core', 'logging_config_class', '')

        self.logout()
        super(TestLogView, self).tearDown()

    def test_get_file_task_log(self):
        response = self.client.get(
            TestLogView.ENDPOINT,
            follow_redirects=True,
        )
        self.assertEqual(response.status_code, 200)
        self.assertIn('Log by attempts',
                      response.data.decode('utf-8'))

    def test_get_logs_with_metadata(self):
        url_template = "get_logs_with_metadata?dag_id={}&" \
                       "task_id={}&execution_date={}&" \
                       "try_number={}&metadata={}"
        response = \
            self.client.get(url_template.format(self.DAG_ID,
                                                self.TASK_ID,
                                                quote_plus(self.DEFAULT_DATE.isoformat()),
                                                1,
                                                json.dumps({})), follow_redirects=True)

        self.assertIn('"message":', response.data.decode('utf-8'))
        self.assertIn('"metadata":', response.data.decode('utf-8'))
        self.assertIn('Log for testing.', response.data.decode('utf-8'))
        self.assertEqual(200, response.status_code)

    def test_get_logs_with_null_metadata(self):
        url_template = "get_logs_with_metadata?dag_id={}&" \
                       "task_id={}&execution_date={}&" \
                       "try_number={}&metadata=null"
        response = \
            self.client.get(url_template.format(self.DAG_ID,
                                                self.TASK_ID,
                                                quote_plus(self.DEFAULT_DATE.isoformat()),
                                                1), follow_redirects=True)

        self.assertIn('"message":', response.data.decode('utf-8'))
        self.assertIn('"metadata":', response.data.decode('utf-8'))
        self.assertIn('Log for testing.', response.data.decode('utf-8'))
        self.assertEqual(200, response.status_code)
Example No. 58
class TestVariableView(unittest.TestCase):

    CREATE_ENDPOINT = '/admin/variable/new/?url=/admin/variable/'

    @classmethod
    def setUpClass(cls):
        super(TestVariableView, cls).setUpClass()
        session = Session()
        session.query(models.Variable).delete()
        session.commit()
        session.close()

    def setUp(self):
        super(TestVariableView, self).setUp()
        configuration.load_test_config()
        app = application.create_app(testing=True)
        app.config['WTF_CSRF_METHODS'] = []
        self.app = app.test_client()
        self.session = Session()
        self.variable = {
            'key': 'test_key',
            'val': 'text_val',
            'is_encrypted': True
        }

    def tearDown(self):
        self.session.query(models.Variable).delete()
        self.session.commit()
        self.session.close()
        super(TestVariableView, self).tearDown()

    def test_can_handle_error_on_decrypt(self):
        # create valid variable
        response = self.app.post(
            self.CREATE_ENDPOINT,
            data=self.variable,
            follow_redirects=True,
        )
        self.assertEqual(response.status_code, 200)

        # update the variable with a wrong value, given that is encrypted
        Var = models.Variable
        (self.session.query(Var).filter(
            Var.key == self.variable['key']).update(
                {'val': 'failed_value_not_encrypted'},
                synchronize_session=False))
        self.session.commit()

        # retrieve Variables page, should not fail and contain the Invalid
        # label for the variable
        response = self.app.get('/admin/variable', follow_redirects=True)
        self.assertEqual(response.status_code, 200)
        self.assertEqual(self.session.query(models.Variable).count(), 1)

    def test_xss_prevention(self):
        xss = "/admin/airflow/variables/asdf<img%20src=''%20onerror='alert(1);'>"

        response = self.app.get(
            xss,
            follow_redirects=True,
        )
        self.assertEqual(response.status_code, 404)
        self.assertNotIn("<img src='' onerror='alert(1);'>",
                         response.data.decode("utf-8"))
Example No. 59
    def start(self):
        self.task_queue = Queue()
        self.result_queue = Queue()
        framework = mesos_pb2.FrameworkInfo()
        framework.user = ''

        if not conf.get('mesos', 'MASTER'):
            self.log.error("Expecting mesos master URL for mesos executor")
            raise AirflowException("mesos.master not provided for mesos executor")

        master = conf.get('mesos', 'MASTER')

        framework.name = get_framework_name()

        if not conf.get('mesos', 'TASK_CPU'):
            task_cpu = 1
        else:
            task_cpu = conf.getint('mesos', 'TASK_CPU')

        if not conf.get('mesos', 'TASK_MEMORY'):
            task_memory = 256
        else:
            task_memory = conf.getint('mesos', 'TASK_MEMORY')

        if conf.getboolean('mesos', 'CHECKPOINT'):
            framework.checkpoint = True

            if conf.get('mesos', 'FAILOVER_TIMEOUT'):
                # Import here to work around a circular import error
                from airflow.models import Connection

                # Query the database to get the ID of the Mesos Framework, if available.
                conn_id = FRAMEWORK_CONNID_PREFIX + framework.name
                session = Session()
                connection = session.query(Connection).filter_by(conn_id=conn_id).first()
                if connection is not None:
                    # Set the Framework ID to let the scheduler reconnect
                    # with running tasks.
                    framework.id.value = connection.extra

                framework.failover_timeout = conf.getint(
                    'mesos', 'FAILOVER_TIMEOUT'
                )
        else:
            framework.checkpoint = False

        self.log.info(
            'MesosFramework master : %s, name : %s, cpu : %s, mem : %s, checkpoint : %s',
            master, framework.name,
            str(task_cpu), str(task_memory), str(framework.checkpoint)
        )

        implicit_acknowledgements = 1

        if conf.getboolean('mesos', 'AUTHENTICATE'):
            if not conf.get('mesos', 'DEFAULT_PRINCIPAL'):
                self.log.error("Expecting authentication principal in the environment")
                raise AirflowException(
                    "mesos.default_principal not provided in authenticated mode")
            if not conf.get('mesos', 'DEFAULT_SECRET'):
                self.log.error("Expecting authentication secret in the environment")
                raise AirflowException(
                    "mesos.default_secret not provided in authenticated mode")

            credential = mesos_pb2.Credential()
            credential.principal = conf.get('mesos', 'DEFAULT_PRINCIPAL')
            credential.secret = conf.get('mesos', 'DEFAULT_SECRET')

            framework.principal = credential.principal

            driver = mesos.native.MesosSchedulerDriver(
                AirflowMesosScheduler(self.task_queue,
                                      self.result_queue,
                                      task_cpu,
                                      task_memory),
                framework,
                master,
                implicit_acknowledgements,
                credential)
        else:
            framework.principal = 'Airflow'
            driver = mesos.native.MesosSchedulerDriver(
                AirflowMesosScheduler(self.task_queue,
                                      self.result_queue,
                                      task_cpu,
                                      task_memory),
                framework,
                master,
                implicit_acknowledgements)

        self.mesos_driver = driver
        self.mesos_driver.start()
Example No. 60
def tearDown(self):
    session = Session()
    session.query(DagRun).filter(DagRun.dag_id == TEST_DAG_ID).delete(
        synchronize_session=False)
    session.query(TaskInstance).filter(
        TaskInstance.dag_id == TEST_DAG_ID).delete(
            synchronize_session=False)
    session.query(TaskFail).filter(TaskFail.dag_id == TEST_DAG_ID).delete(
        synchronize_session=False)
    session.commit()
    session.close()
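Many of the tearDown snippets above repeat the same Session() / delete / commit / close sequence. A minimal sketch of factoring it into a context manager; the helper name create_session is an assumption here, though later Airflow versions ship an equivalent in airflow.utils.db:

from contextlib import contextmanager

from airflow.settings import Session


@contextmanager
def create_session():
    # Commit on success, roll back on error, always close.
    session = Session()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


# Usage in a tearDown:
#     with create_session() as session:
#         session.query(DagRun).delete()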