예제 #1
0
 def test_should_read_logging_configuration(self):
     """ConfigInfo output mentions StackdriverTaskHandler when remote logs go to stackdriver://."""
     remote_logging_conf = {
         ('logging', 'remote_logging'): 'True',
         ('logging', 'remote_base_log_folder'): 'stackdriver://logs-name',
     }
     with conf_vars(remote_logging_conf):
         # Reload while the overrides are active so module-level values are recomputed.
         importlib.reload(airflow_local_settings)
         configure_logging()
         config_info = info_command.ConfigInfo(info_command.NullAnonymizer())
         rendered = str(config_info)
         self.assertIn("StackdriverTaskHandler", rendered)
예제 #2
0
    def test_config_use_original_when_original_and_fallback_are_present(self):
        """FERNET_KEY keeps its own value even when FERNET_KEY_CMD is also configured."""
        self.assertTrue(conf.has_option("core", "FERNET_KEY"))
        self.assertFalse(conf.has_option("core", "FERNET_KEY_CMD"))

        original_key = conf.get('core', 'FERNET_KEY')

        # Adding a *_CMD option must not change what is read for the plain key.
        cmd_override = {('core', 'FERNET_KEY_CMD'): 'printf HELLO'}
        with conf_vars(cmd_override):
            key_with_cmd_present = conf.get("core", "FERNET_KEY")

        self.assertEqual(original_key, key_with_cmd_present)
예제 #3
0
    def setup_class(cls) -> None:
        """Build the test app with the remote-user auth backend and create an Admin user."""
        auth_override = {
            ("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend",
        }
        with conf_vars(auth_override):
            cls.app = app.create_app(testing=True)  # type:ignore
        # TODO: Add new role for each view to test permission
        create_user(cls.app, username="******", role="Admin")  # type: ignore

        cls.client = None
예제 #4
0
 def setUpClass(cls) -> None:
     """Create the app (with SKIP_DAGS_PARSING set) and an Admin test user."""
     super().setUpClass()
     auth_override = {
         ("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend",
     }
     with mock.patch.dict("os.environ", SKIP_DAGS_PARSING="True"), conf_vars(auth_override):
         cls.app = app.create_app(testing=True)  # type:ignore
     # TODO: Add new role for each view to test permission.
     create_user(cls.app, username="******", role="Admin")  # type: ignore
예제 #5
0
 def test_safe_mode_disabled(self):
     """With safe mode disabled, an empty python file should be discovered."""
     with NamedTemporaryFile(dir=self.empty_dir, suffix=".py") as empty_py:
         with conf_vars({('core', 'dags_folder'): self.empty_dir}):
             dagbag = models.DagBag(include_examples=False, safe_mode=False)
         stats = dagbag.dagbag_stats
         self.assertEqual(len(stats), 1)
         # Stats report the file path relative to dags_folder, with a leading slash.
         expected_file = "/" + os.path.basename(empty_py.name)
         self.assertEqual(stats[0].file, expected_file)
예제 #6
0
 def test_should_support_plugin(self):
     """An executor registered through a plugin module loads by its dotted name."""
     plugin_module = make_module('airflow.executors.' + TEST_PLUGIN_NAME, [FakeExecutor])
     executors_modules.append(plugin_module)
     self.addCleanup(self.remove_executor_module)

     executor_path = f"{TEST_PLUGIN_NAME}.FakeExecutor"
     with conf_vars({("core", "executor"): executor_path}):
         executor = ExecutorLoader.get_default_executor()
         self.assertIsNotNone(executor)
         self.assertIn("FakeExecutor", executor.__class__.__name__)
예제 #7
0
 def test_load_custom_statsd_client(self):
     """A statsd client class configured by dotted path is used as Stats.statsd."""
     try:
         with conf_vars({
             ("metrics", "statsd_on"): "True",
             ("metrics", "statsd_custom_client_path"): f"{__name__}.CustomStatsd",
         }):
             importlib.reload(airflow.stats)
             assert isinstance(airflow.stats.Stats.statsd, CustomStatsd)
     finally:
         # Avoid side-effects: reload with the default config even when the
         # assertion above fails, so the custom client cannot leak into other tests.
         importlib.reload(airflow.stats)
예제 #8
0
    def test_config_throw_error_when_original_and_fallback_is_absent(self):
        """Reading FERNET_KEY raises once it is unset and no FERNET_KEY_CMD exists."""
        assert conf.has_option("core", "FERNET_KEY")
        assert not conf.has_option("core", "FERNET_KEY_CMD")

        # A None override presumably removes the option for the block's duration.
        with conf_vars({('core', 'fernet_key'): None}):
            with pytest.raises(AirflowConfigException) as exc_info:
                conf.get("core", "FERNET_KEY")

        expected = "section/key [core/fernet_key] not found in config"
        assert str(exc_info.value) == expected
예제 #9
0
    def test_config_throw_error_when_original_and_fallback_is_absent(self):
        """Reading FERNET_KEY raises once it is unset and no FERNET_KEY_CMD exists."""
        self.assertTrue(conf.has_option("core", "FERNET_KEY"))
        self.assertFalse(conf.has_option("core", "FERNET_KEY_CMD"))

        # A None override presumably removes the option for the block's duration.
        with conf_vars({('core', 'fernet_key'): None}):
            with self.assertRaises(AirflowConfigException) as raised:
                conf.get("core", "FERNET_KEY")

        actual_message = str(raised.exception)
        self.assertEqual(
            "section/key [core/fernet_key] not found in config", actual_message
        )
    def test_hiding_config(self, sensitive_variable_fields, key,
                           expected_result):
        """should_hide_value_for_key honours the configured sensitive_var_conn_names."""
        from airflow.utils.log.secrets_masker import get_sensitive_variables_fields

        try:
            with conf_vars({
                ('core', 'sensitive_var_conn_names'):
                    str(sensitive_variable_fields)
            }):
                # The field list is cached; clear it so the override is picked up.
                get_sensitive_variables_fields.cache_clear()
                assert expected_result == should_hide_value_for_key(key)
        finally:
            # Drop the override-derived cached value even if the assertion failed,
            # so later tests do not see this parametrization's sensitive fields.
            get_sensitive_variables_fields.cache_clear()
예제 #11
0
 def test_sql_alchemy_invalid_connect_args(self, mock_create_engine,
                                           mock_sessionmaker,
                                           mock_scoped_session,
                                           mock_setup_event_handlers):
     """configure_orm raises when sql_alchemy_connect_args names a non-existent path."""
     bad_connect_args_conf = {
         ('core', 'sql_alchemy_connect_args'): 'does.not.exist',
         ('core', 'sql_alchemy_pool_enabled'): 'False',
     }
     # Same nesting as two stacked `with` statements: assertRaises is outermost.
     with self.assertRaises(AirflowConfigException), conf_vars(bad_connect_args_conf):
         settings.configure_orm()
    def test_should_remove_show_paused_from_url_params(self, show_paused,
                                                       hide_by_default, expected_result):
        """_should_remove_show_paused_from_url_params matches the parametrized expectation."""
        hide_paused_conf = {('webserver', 'hide_paused_dags_by_default'): str(hide_by_default)}
        with conf_vars(hide_paused_conf):
            actual = utils._should_remove_show_paused_from_url_params(show_paused, hide_by_default)
            self.assertEqual(expected_result, actual)
예제 #13
0
 def setUpClass(cls) -> None:
     """Create the app, a user with audit-log read permission, and a user with none."""
     super().setUpClass()
     auth_backend = "tests.test_utils.remote_user_api_auth_backend"
     with conf_vars({("api", "auth_backend"): auth_backend}):
         cls.app = app.create_app(testing=True)  # type:ignore
     audit_log_read = (permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)
     create_user(
         cls.app,  # type:ignore
         username="******",
         role_name="Test",
         permissions=[audit_log_read],  # type: ignore
     )
     create_user(cls.app, username="******", role_name="TestNoPermissions")  # type: ignore
예제 #14
0
def test_configuration_do_not_expose_config(admin_client):
    """With expose_config off, the configuration page shows a notice instead of the config."""
    with conf_vars({('webserver', 'expose_config'): 'False'}):
        response = admin_client.get('configuration', follow_redirects=True)
    expected_snippets = [
        'Airflow Configuration',
        '# Your Airflow administrator chose not to expose the configuration, '
        'most likely for security reasons.',
    ]
    check_content_in_response(expected_snippets, response)
예제 #15
0
    def test_run_next_exception(self, mock_get_kube_client,
                                mock_kubernetes_job_watcher):
        """A pod-creation ApiException (quota exceeded) leaves the task queued;
        once the API stops failing, the next sync creates the pod and drains the queue.
        """
        import sys

        path = sys.path[0] + '/tests/kubernetes/pod_generator_base_with_secrets.yaml'

        # When a quota is exceeded this is the ApiException we get
        response = HTTPResponse(
            body=
            '{"kind": "Status", "apiVersion": "v1", "metadata": {}, "status": "Failure", '
            '"message": "pods \\"podname\\" is forbidden: exceeded quota: compute-resources, '
            'requested: limits.memory=4Gi, used: limits.memory=6508Mi, limited: limits.memory=10Gi", '
            '"reason": "Forbidden", "details": {"name": "podname", "kind": "pods"}, "code": 403}'
        )
        response.status = 403
        response.reason = "Forbidden"

        # A mock kube_client that throws errors when making a pod. A plain MagicMock
        # suffices: get_kube_client is mocked to return it, so nothing needs patching.
        mock_kube_client = mock.MagicMock()
        mock_kube_client.create_namespaced_pod = mock.MagicMock(
            side_effect=ApiException(http_resp=response))
        mock_get_kube_client.return_value = mock_kube_client
        mock_api_client = mock.MagicMock()
        mock_api_client.sanitize_for_serialization.return_value = {}
        mock_kube_client.api_client = mock_api_client
        config = {
            ('kubernetes', 'pod_template_file'): path,
        }
        with conf_vars(config):
            kubernetes_executor = self.kubernetes_executor
            kubernetes_executor.start()
            # Execute a task while the Api Throws errors
            try_number = 1
            kubernetes_executor.execute_async(
                key=('dag', 'task', datetime.utcnow(), try_number),
                queue=None,
                command=['airflow', 'tasks', 'run', 'true', 'some_parameter'],
            )
            kubernetes_executor.sync()
            kubernetes_executor.sync()

            assert mock_kube_client.create_namespaced_pod.called
            # The task should still be queued after the failed launch attempts.
            assert not kubernetes_executor.task_queue.empty()

            # Disable the ApiException and clear the recorded calls, so the check
            # below proves that a *new* pod-creation call was made.
            mock_kube_client.create_namespaced_pod.side_effect = None
            mock_kube_client.create_namespaced_pod.reset_mock()

            # Execute the task without errors should empty the queue
            kubernetes_executor.sync()
            assert mock_kube_client.create_namespaced_pod.called
            assert kubernetes_executor.task_queue.empty()
예제 #16
0
    def test_task_instance_info(self):
        """Experimental task-instance endpoint: success case plus each error status."""
        with conf_vars(
            {("core", "store_serialized_dags"): self.dag_serialization}
        ):
            url_template = '/api/experimental/dags/{}/dag_runs/{}/tasks/{}'
            dag_id = 'example_bash_operator'
            task_id = 'also_run_this'
            execution_date = utcnow().replace(microsecond=0)
            datetime_string = quote_plus(execution_date.isoformat())
            wrong_datetime_string = quote_plus(datetime(1990, 1, 1, 1, 1, 1).isoformat())

            # Create DagRun
            trigger_dag(dag_id=dag_id,
                        run_id='test_task_instance_info_run',
                        execution_date=execution_date)

            def fetch(*url_parts):
                """GET the endpoint built from url_parts; return (status, decoded body)."""
                resp = self.client.get(url_template.format(*url_parts))
                return resp.status_code, resp.data.decode('utf-8')

            # Correct execution
            status, body = fetch(dag_id, datetime_string, task_id)
            self.assertEqual(200, status)
            self.assertIn('state', body)
            self.assertNotIn('error', body)

            error_cases = (
                # nonexistent dag
                (('does_not_exist_dag', datetime_string, task_id), 404),
                # nonexistent task
                ((dag_id, datetime_string, 'does_not_exist_task'), 404),
                # nonexistent dag run (wrong execution_date)
                ((dag_id, wrong_datetime_string, task_id), 404),
                # bad datetime format
                ((dag_id, 'not_a_datetime', task_id), 400),
            )
            for url_parts, expected_status in error_cases:
                status, body = fetch(*url_parts)
                self.assertEqual(expected_status, status)
                self.assertIn('error', body)
예제 #17
0
    def test_xcom_deserialize_with_pickle_to_json_switch(self):
        """An XCom stored with pickling enabled is still readable after switching it off."""
        payload = {"key": "value"}
        xcom_coords = dict(
            key="xcom_test3",
            dag_id="test_dag",
            task_id="test_task3",
            execution_date=timezone.utcnow(),
        )

        with conf_vars({("core", "enable_xcom_pickling"): "True"}):
            XCom.set(value=payload, **xcom_coords)

        with conf_vars({("core", "enable_xcom_pickling"): "False"}):
            fetched = XCom.get_one(**xcom_coords)

        assert fetched == payload
    def assert_remote_logs(self, expected_message, ti):
        """Read ti's logs with Stackdriver remote logging on and check expected_message appears."""
        remote_conf = {
            ('logging', 'remote_logging'): 'True',
            ('logging', 'remote_base_log_folder'): f"stackdriver://{self.log_name}",
        }
        with provide_gcp_context(GCP_STACKDDRIVER), conf_vars(remote_conf):
            from airflow.config_templates import airflow_local_settings
            # Reload so module-level logging settings are recomputed under the overrides.
            importlib.reload(airflow_local_settings)
            settings.configure_logging()

            reader = TaskLogReader()
            chunks = reader.read_log_stream(ti, try_number=None, metadata={})
            logs = "\n".join(chunks)
            self.assertIn(expected_message, logs)
예제 #19
0
    def test_does_not_send_stats_using_statsd_when_statsd_and_dogstatsd_both_on(
            self):
        """With statsd_on and statsd_datadog_enabled both set, only DogStatsd is configured."""
        from datadog import DogStatsd

        try:
            with conf_vars({
                ('metrics', 'statsd_on'): 'True',
                ('metrics', 'statsd_datadog_enabled'): 'True'
            }):
                importlib.reload(airflow.stats)
                assert isinstance(airflow.stats.Stats.dogstatsd, DogStatsd)
                assert not hasattr(airflow.stats.Stats, 'statsd')
        finally:
            # Reload with the default config even if an assertion failed, so the
            # DogStatsd-backed Stats cannot leak into later tests.
            importlib.reload(airflow.stats)
    def test_should_support_plugin(self):
        """An executor provided by a registered plugin loads via '<plugin>.<class>'."""
        plugins_manager.plugins = [FakePlugin()]
        self.addCleanup(self.remove_executor_plugins)

        executor_conf = {("core", "executor"): f"{TEST_PLUGIN_NAME}.FakeExecutor"}
        with conf_vars(executor_conf):
            loaded = ExecutorLoader.get_default_executor()
            self.assertIsNotNone(loaded)
            self.assertIn("FakeExecutor", loaded.__class__.__name__)
예제 #21
0
    def setup_class(cls) -> None:
        """Enable expose_config for the class, build the app, and create two test users."""
        # Left open for the whole class; presumably closed in a teardown hook — confirm.
        cls.exit_stack = ExitStack()
        cls.exit_stack.enter_context(conf_vars({('webserver', 'expose_config'): 'True'}))

        auth_override = {("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend"}
        with conf_vars(auth_override):
            cls.app = app.create_app(testing=True)  # type:ignore

        config_read = (permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)
        create_user(
            cls.app,  # type:ignore
            username="******",
            role_name="Test",
            permissions=[config_read],  # type: ignore
        )
        create_user(cls.app, username="******", role_name="TestNoPermissions")  # type: ignore

        cls.client = None
예제 #22
0
 def test_illegal_args(self):
     """
     Tests that, with allow_illegal_arguments enabled, Operators accept illegal
     arguments but emit a PendingDeprecationWarning whose message contains ``msg``.
     """
     msg = 'Invalid arguments were passed to BashOperator (task_id: test_illegal_args).'
     with conf_vars({('operators', 'allow_illegal_arguments'): 'True'}):
         with self.assertWarns(PendingDeprecationWarning) as warning:
             BashOperator(task_id='test_illegal_args',
                          bash_command='echo success',
                          dag=self.dag,
                          illegal_argument_1234='hello?')
     # Inspect the recorded warnings only after assertWarns has closed and verified
     # that a PendingDeprecationWarning was raised.
     assert any(msg in str(w.message) for w in warning.warnings)
예제 #23
0
    def setUpClass(cls) -> None:
        """Create the app, three users with different DAG permissions, and three DAGs.

        ``cls.dag`` has doc_md and params, ``cls.dag2`` has no doc_md, and ``cls.dag3``
        has no DAG-level start_date (only its task sets one).
        """
        super().setUpClass()
        with conf_vars({
            ("api", "auth_backend"):
                "tests.test_utils.remote_user_api_auth_backend"
        }):
            cls.app = app.create_app(testing=True)  # type:ignore
        create_user(
            cls.app,  # type: ignore
            username="******",
            role_name="Test",
            permissions=[
                (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG),
                (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
                (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
                (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG),
            ],
        )
        # Each `# type: ignore` sits on the line containing `cls.app`, where mypy reports.
        create_user(
            cls.app,  # type: ignore
            username="******",
            role_name="TestNoPermissions",
        )
        create_user(
            cls.app,  # type: ignore
            username="******",
            role_name="TestGranularDag",
        )
        # Give the TestGranularDag role edit and read access to TEST_DAG_1.
        cls.app.appbuilder.sm.sync_perm_for_dag(  # type: ignore  # pylint: disable=no-member
            "TEST_DAG_1",
            access_control={'TestGranularDag': [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]},
        )

        with DAG(cls.dag_id,
                 start_date=datetime(2020, 6, 15),
                 doc_md="details",
                 params={"foo": 1}) as dag:
            DummyOperator(task_id=cls.task_id)

        with DAG(cls.dag2_id, start_date=datetime(2020, 6, 15)) as dag2:  # no doc_md
            DummyOperator(task_id=cls.task_id)

        with DAG(cls.dag3_id) as dag3:  # DAG start_date set to None
            DummyOperator(task_id=cls.task_id,
                          start_date=datetime(2019, 6, 12))

        cls.dag = dag  # type:ignore
        cls.dag2 = dag2  # type: ignore
        cls.dag3 = dag3  # type: ignore

        # A DagBag on os.devnull without examples, populated with exactly these three DAGs.
        dag_bag = DagBag(os.devnull, include_examples=False)
        dag_bag.dags = {dag.dag_id: dag, dag2.dag_id: dag2, dag3.dag_id: dag3}

        cls.app.dag_bag = dag_bag  # type:ignore
예제 #24
0
    def test_get_dag_runs_invalid_dag_id(self):
        """Listing runs for an unknown DAG id returns 400 and a non-list payload."""
        serialization_conf = {("core", "store_serialized_dags"): self.dag_serialzation}
        with conf_vars(serialization_conf):
            dag_runs_url = '/api/experimental/dags/{}/dag_runs'.format('DUMMY_DAG')

            resp = self.app.get(dag_runs_url)
            self.assertEqual(400, resp.status_code)
            payload = json.loads(resp.data.decode('utf-8'))

            self.assertNotIsInstance(payload, list)
예제 #25
0
    def test_should_response_200_serialized(self):
        """The task endpoint serves a task whose DAG is read back from the serialized-DAG table."""
        # Create empty app with empty dagbag to check if DAG is read from db
        auth_override = {
            ("api", "auth_backend"): "tests.test_utils.remote_user_api_auth_backend",
        }
        with conf_vars(auth_override):
            serialized_app = app.create_app(testing=True)
        serialized_app.dag_bag = DagBag(os.devnull, include_examples=False, read_dags_from_db=True)
        test_client = serialized_app.test_client()

        SerializedDagModel.write_dag(self.dag)

        expected = {
            "class_ref": {
                "class_name": "DummyOperator",
                "module_path": "airflow.operators.dummy_operator",
            },
            "depends_on_past": False,
            "downstream_task_ids": [],
            "end_date": None,
            "execution_timeout": None,
            "extra_links": [],
            "owner": "airflow",
            "pool": "default_pool",
            "pool_slots": 1.0,
            "priority_weight": 1.0,
            "queue": "default",
            "retries": 0.0,
            "retry_delay": {
                "__type": "TimeDelta",
                "days": 0,
                "seconds": 300,
                "microseconds": 0
            },
            "retry_exponential_backoff": False,
            "start_date": "2020-06-15T00:00:00+00:00",
            "task_id": "op1",
            "template_fields": [],
            "trigger_rule": "all_success",
            "ui_color": "#e8f7e4",
            "ui_fgcolor": "#000",
            "wait_for_downstream": False,
            "weight_rule": "downstream",
        }
        task_url = f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}"
        response = test_client.get(task_url, environ_overrides={'REMOTE_USER': "******"})
        assert response.status_code == 200
        assert response.json == expected
예제 #26
0
    def test_get_dag_runs_no_runs(self):
        """Listing runs for a DAG that has none returns 200 and an empty list."""
        serialization_conf = {("core", "store_serialized_dags"): self.dag_serialzation}
        with conf_vars(serialization_conf):
            dag_runs_url = '/api/experimental/dags/{}/dag_runs'.format('example_bash_operator')

            resp = self.app.get(dag_runs_url)
            self.assertEqual(200, resp.status_code)
            payload = json.loads(resp.data.decode('utf-8'))

            self.assertIsInstance(payload, list)
            self.assertEqual(len(payload), 0)
    def assert_remote_logs(self, expected_message, ti):
        """Read ti's logs with GCS remote logging on and check expected_message appears."""
        remote_conf = {
            ('logging', 'remote_logging'): 'True',
            ('logging', 'remote_base_log_folder'): f"gs://{self.bucket_name}/path/to/logs",
            ('logging', 'remote_log_conn_id'): "google_cloud_default",
        }
        with provide_gcp_context(GCP_GCS_KEY), conf_vars(remote_conf):
            from airflow.config_templates import airflow_local_settings
            # Reload so module-level logging settings are recomputed under the overrides.
            importlib.reload(airflow_local_settings)
            settings.configure_logging()

            reader = TaskLogReader()
            chunks = reader.read_log_stream(ti, try_number=None, metadata={})
            logs = "\n".join(chunks)
            self.assertIn(expected_message, logs)
예제 #28
0
    def test_get_dag_code(self):
        """The DAG code endpoint returns source for a known DAG and 404 for an unknown one."""
        with conf_vars({("core", "store_serialized_dags"): self.dag_serialzation}):
            code_url = '/api/experimental/dags/{}/code'

            found = self.client.get(code_url.format('example_bash_operator'))
            self.assertIn('BashOperator(', found.data.decode('utf-8'))
            self.assertEqual(200, found.status_code)

            missing = self.client.get(code_url.format('xyz'))
            self.assertEqual(404, missing.status_code)
예제 #29
0
 def test_send_mime_ssl(self, mock_smtp, mock_smtp_ssl):
     """With smtp_ssl on, only the SSL SMTP mock is called, with host/port/timeout kwargs."""
     mock_smtp_ssl.return_value = mock.Mock()
     with conf_vars({('smtp', 'smtp_ssl'): 'True'}):
         utils.email.send_mime_email('from', 'to', MIMEMultipart(), dryrun=False)
     assert not mock_smtp.called
     expected_kwargs = dict(
         host=conf.get('smtp', 'SMTP_HOST'),
         port=conf.getint('smtp', 'SMTP_PORT'),
         timeout=conf.getint('smtp', 'SMTP_TIMEOUT'),
     )
     mock_smtp_ssl.assert_called_once_with(**expected_kwargs)
예제 #30
0
 def test_send_mime_ssl(self, mock_smtp, mock_smtp_ssl):
     """With smtp_ssl on, only the SSL SMTP mock is called, with positional host and port."""
     for smtp_mock in (mock_smtp, mock_smtp_ssl):
         smtp_mock.return_value = mock.Mock()
     ssl_conf = {('smtp', 'smtp_ssl'): 'True'}
     with conf_vars(ssl_conf):
         utils.email.send_MIME_email('from', 'to', MIMEMultipart(), dryrun=False)
     self.assertFalse(mock_smtp.called)
     expected_host = conf.get('smtp', 'SMTP_HOST')
     expected_port = conf.getint('smtp', 'SMTP_PORT')
     mock_smtp_ssl.assert_called_once_with(expected_host, expected_port)