Example #1
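Builds the task three times, setting its parameter under different key casings (TParam, tparam, TPARAM); every layer takes effect even with validate_no_extra_params set to ParamValidation.error, showing that parameter building is case-insensitive.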
    def test_case_insensitive_parameter_building(self):
        # First run with correct case
        with config({
                "CaseSensitiveParameterTask": {
                    "TParam": 2,
                    "validate_no_extra_params": ParamValidation.error,
                }
        }):
            task = CaseSensitiveParameterTask()
            assert task.TParam == 2
        # Second run with incorrect lower case
        with config({
                "CaseSensitiveParameterTask": {
                    "tparam": 3,
                    "validate_no_extra_params": ParamValidation.error,
                }
        }):
            task = CaseSensitiveParameterTask()
            assert task.TParam == 3
        # Third run with incorrect upper case
        with config({
                "CaseSensitiveParameterTask": {
                    "TPARAM": 4,
                    "validate_no_extra_params": ParamValidation.error,
                }
        }):
            task = CaseSensitiveParameterTask()
            assert task.TParam == 4
Example #2
    def test_with_task_class_extend_and_override_layer(self, config_sections):
        with config(config_sections):
            with config({DummyConfig.dict_config: override({"override": "win"})}):
                # can't merge into override value
                example_task_with_task_config_extend.task(
                    expected_dict={"override": "win"}
                ).dbnd_run()
Example #3
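Building a task whose config section sets _from to its own name must not raise; contrast with Example #10, where the same self-reference for an unknown task is expected to fail.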
    def test_no_error_on_same_from(self):
        @task
        def task_with_from():
            return

        with config({"task_with_from": {"_from": "task_with_from"}}):
            get_task_registry().build_dbnd_task("task_with_from")
Example #4
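Uses a parameter object (TTNFirstTask.param) as the config key; the value propagates to both child tasks of the pipeline.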
    def test_param_override(self):
        with config({TTNFirstTask.param: "223"}):
            t = TNamePipe(child="aaa")

            assert str(t.first) != str(t.second)
            assert "223" in t.first.task.param
            assert "223" in t.second.task.param
Example #5
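Enables enable_spark_context_inplace so the inline word-count task reuses the already-created SparkSession instead of spawning its own.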
    def test_spark_inline_same_context(self):
        from pyspark.sql import SparkSession
        from dbnd_test_scenarios.spark.spark_tasks_inline import word_count_inline

        with SparkSession.builder.getOrCreate() as sc:
            with config({SparkLocalEngineConfig.enable_spark_context_inplace: True}):
                assert_run_task(word_count_inline.t(text=__file__))
Example #6
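The config section of a namespaced task is addressed as "<namespace>.<class name>" ("mynamespace.A" here).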
    def testWithNamespaceConfig(self):
        class A(TTask):
            task_namespace = "mynamespace"
            p = parameter[int]

        with config({"mynamespace.A": {"p": "999"}}):
            assert 999 == A().p
Example #7
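A full script: builds an Airflow TaskInstance, redirects stdout/stderr into the "airflow.task" logger, and runs dbnd_sanity_check with the local task executor inside a config layer.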
def main():
    dbnd_bootstrap()

    from airflow.models import TaskInstance
    from airflow.utils.log.logging_mixin import redirect_stderr, redirect_stdout
    from airflow.utils.timezone import utcnow

    from dbnd import config
    from dbnd._core.constants import TaskExecutorType
    from dbnd._core.settings import RunConfig
    from dbnd._core.task_run.task_run import TaskRun
    from dbnd.tasks.basics import dbnd_sanity_check
    from dbnd_examples.dbnd_airflow import bash_dag

    airflow_task_log = logging.getLogger("airflow.task")

    task = bash_dag.t3
    execution_date = utcnow()
    ti = TaskInstance(task, execution_date)

    ti.init_run_context(raw=False)
    logger.warning("Running with task_log %s",
                   airflow_task_log.handlers[0].handler.baseFilename)
    with redirect_stdout(airflow_task_log, logging.INFO), redirect_stderr(
            airflow_task_log, logging.WARN):
        logger.warning("from redirect")

        logger.warning("after patch")
        with config({RunConfig.task_executor_type: TaskExecutorType.local}):
            run = dbnd_sanity_check.dbnd_run()
            tr = run.root_task_run  # type: TaskRun

    logger.warning("TR: %s %s %s", tr, tr.task_tracker_url,
                   tr.log.local_log_file)
Example #8
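Passes the mapping through the explicit config_values= keyword rather than positionally.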
    def test_override_config__context(self):
        # same problem as previous test
        with config(config_values={
                task_from_config.task.parameter_from_config: "from_context"
        }):
            # this layer is about config
            task_from_config.dbnd_run(expected="from_context")
Example #9
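Assembles the "core" tracking settings from function arguments before starting the heartbeat loop; a dedicated new_dbnd_context is entered only when the tracker API is "db".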
def send_heartbeat(run_uid, databand_url, heartbeat_interval, driver_pid,
                   tracker, tracker_api):
    from dbnd import config
    from dbnd._core.settings import CoreConfig
    from dbnd._core.task_executor.heartbeat_sender import send_heartbeat_continuously

    with config({
            "core": {
                "tracker": tracker.split(","),
                "tracker_api": tracker_api,
                "databand_url": databand_url,
            }
    }):
        required_context = []
        if tracker_api == "db":
            from dbnd import new_dbnd_context

            required_context.append(
                new_dbnd_context(name="send_heartbeat",
                                 autoload_modules=False))

        with nested_context.nested(*required_context):
            tracking_store = CoreConfig().get_tracking_store()

            send_heartbeat_continuously(run_uid, tracking_store,
                                        heartbeat_interval, driver_pid)
Example #10
    def test_error_on_same_from(self):
        with pytest.raises(Exception):
            with config({
                    "unknown_task_with_from": {
                        "_from": "unknown_task_with_from"
                    }
            }):
                get_task_registry().build_dbnd_task("unknown_task_with_from")
Example #11
    def test_read_(self):
        actual = read_from_config_files(
            [scenario_path("config_files", "test_config_reader.cfg")]
        )
        with config(actual):
            assert (
                config.get("test_config_reader", "test_config_reader") == "test_value"
            )
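The scenario file itself is not shown here; given the assertion, it presumably contains a single section along these lines:

    [test_config_reader]
    test_config_reader = test_value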
Example #12
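Passing merge_settings=_ConfigMergeSettings(replace_section=True) replaces the whole "tracking" section instead of merging into it, so settings from config files cannot leak into tests.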
def tracking_config_empty():
    # Enforce "tracking" config section so that changes in config files won't affect tests
    with config(
        {"tracking": {}},
            source="dbnd_test_context",
            merge_settings=_ConfigMergeSettings(replace_section=True),
    ):
        return TrackingConfig()
Example #13
    def test_full_name_config(self):
        conf = {
            SimplestTask.task_definition.full_task_family: {
                "simplest_param": "from_config"
            }
        }
        with config(conf):
            assert SimplestTask().simplest_param == "from_config"
Example #14
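Exercises all three ParamValidation modes (error, warn, disabled) against the unknown key "t_parammm", then checks that the core config sections are validated too.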
    def test_wrong_config_validation(self):
        # raise exception
        with pytest.raises(UnknownParameterError) as e:
            with config({
                    "TTask": {
                        "t_parammm": 2,
                        "validate_no_extra_params": ParamValidation.error,
                    }
            }):
                TTask()

        assert "Did you mean: t_param" in e.value.help_msg

        # log warning to log
        with config({
                "TTask": {
                    "t_parammm": 2,
                    "validate_no_extra_params": ParamValidation.warn,
                }
        }):
            TTask()
        # tried to add a capsys assert here but couldn't get it to work

        # do nothing
        with config({
                "TTask": {
                    "t_parammm": 2,
                    "validate_no_extra_params": ParamValidation.disabled,
                }
        }):
            TTask()

        # handle core config sections too
        # there might be other extra params in the config, in which case
        # a DatabandBuildError will be raised
        with pytest.raises(DatabandError):
            with config({
                    "config": {
                        "validate_no_extra_params": ParamValidation.error
                    },
                    "core": {
                        "blabla": "bla"
                    },
            }):
                CoreConfig()
Example #15
    def test_override_inheritance_config(self):
        with config({
                SecondTask.param: "from_config_context",
                FooConfig.bar: "from_config_context",
        }):
            t = SecondTask()
            assert "from_config_context" == t.param, "t.param"
            assert "from_config_context" == t.foo.bar, "t.foo.bar"
            assert "SecondTask.foo.defaults" == t.foo.quz, "t.foo.quz"
Example #16
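Per the in-code comment, this layer is expected to have no effect: param1 keeps its default and param2 keeps the call-site value despite both being set in config.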
def test_task_in_conf():
    # in_conf - shouldn't affect anything
    with config(
        {"task_in_conf": {"param1": "conf_value", "param2": "conf_value"}},
        source="test_source",
    ):
        param1, param2 = task_in_conf(param2="param2_value")
        assert param1 == "default_value"
        assert param2 == "param2_value"
Example #17
    def test_override_inheritance_legacy(self):
        with config({
                SecondTask.param: "config.context",
                FooConfig.bar: "config.context"
        }):
            t = SecondTask()
            assert "config.context" == t.param
            assert "config.context" == t.foo.bar
            assert "SecondTask.foo.defaults" == t.foo.quz
Example #18
    def test_override_inheritance_config(self):
        with config({
                SecondTask.param: "from_config_context",
                FooConfig.bar: "from_config_context",
        }):
            t = SecondTask()
            assert "from_config_context" == t.param
            assert "SecondTask.task_config" == t.foo.bar
            assert "SecondTask.task_config" == t.foo.quz
Example #19
    def test_override_config_values_from_context(self):
        with config({
                "task_from_config": {
                    "parameter_from_config": "from_context_override"
                }
        }):
            assert (config.get(
                "task_from_config",
                "parameter_from_config") == "from_context_override")
            task_from_config.dbnd_run(expected="from_context_override")
Example #20
    def test_override_inheritance_legacy(self):
        with config({
                SecondTask.param: "config.context",
                FooConfig.bar: "config.context"
        }):
            t = SecondTask()
            assert t.param == "config.context"
            # foo is created inside SecondTask, where task_config is applied
            assert t.foo.bar == "SecondTask.task_config"
            assert t.foo.quz == "SecondTask.task_config"
Example #21
    def test_spark_inline_same_context(self):
        from pyspark.sql import SparkSession

        from dbnd_examples.orchestration.dbnd_spark.word_count import word_count_inline
        from dbnd_spark.local.local_spark_config import SparkLocalEngineConfig

        with SparkSession.builder.getOrCreate() as sc:
            with config({SparkLocalEngineConfig.enable_spark_context_inplace: True}):
                task_instance = word_count_inline.t(text=__file__)
                assert_run_task(task_instance)
Example #22
def tracking_config_empty():
    # Enforce "tracking" config section so that changes in config files won't affect tests
    with config(
        {
            "tracking":
            replace_section_with(
                {"value_reporting_strategy": ValueTrackingLevel.ALL})
        },
            source="dbnd_test_context",
    ):
        return TrackingConfig()
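Note the two ways of replacing a section rather than merging into it: Example #12 passes merge_settings=_ConfigMergeSettings(replace_section=True) alongside a plain dict, while this example wraps the section value in replace_section_with(...).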
Example #23
    def test_task_name_priority(self):
        with config({
                "Second_aaa": {
                    "param": "per_name"
                },
                TTNFirstTask.param: "224"
        }):
            t = TNamePipe(child="aaa")

            assert str(t.first) != str(t.second)
            assert "224" in t.first.task.param
            assert "per_name" in t.second.task.param
Example #24
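A pytest fixture: pins the run name, sets heartbeat_interval_s to -1 (presumably disabling it), roots local outputs under tmpdir, and yields a fresh DatabandContext.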
def databand_test_context(
    request, tmpdir, databand_context_kwargs, databand_config
):  # type: (...) -> DatabandContext

    test_config = {
        "run": {
            "name": _run_name_for_test_request(request),
            "heartbeat_interval_s": -1,
        },
        "local": {"root": str(tmpdir.join("local_root"))},
    }
    with config(test_config, source="databand_test_context"), new_dbnd_context(
        **databand_context_kwargs
    ) as t:
        yield t
Example #25
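A metaclass hook: a config layer built from get_custom_config() is pushed for the duration of task construction and popped automatically when the with-block exits.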
    def __call__(cls, *args, **kwargs):
        """
        Extension of the Task metaclass: we add our own configuration based on
        some values, so that the moment a task is created, that config is used.
        """

        config_values = cls.get_custom_config()

        # create new config layer, so when we are out of this process -> config is back to the previous value
        with config(
            config_values=config_values,
            source=cls.task_definition.task_passport.format_source_name(
                ".get_custom_config"
            ),
        ):
            return super(AdvancedConfigTaskMetaclass, cls).__call__(*args, **kwargs)
Example #26
    def test_read_from_multiple_sources(self):
        from pyspark.sql import SparkSession

        from dbnd_examples.orchestration.dbnd_spark.read_from_multiple_sources import (
            data_source_complicated_pipeline,
        )
        from dbnd_spark.local.local_spark_config import SparkLocalEngineConfig

        with SparkSession.builder.getOrCreate() as sc:
            with config({SparkLocalEngineConfig.enable_spark_context_inplace: True}):
                task_instance = data_source_complicated_pipeline.t(
                    root_path=_data_for_spark_path("read_from_multiple_sources"),
                    extra_data=_data_for_spark_path(
                        "read_from_multiple_sources1", "configID=1", "1.tsv"
                    ),
                )
                assert_run_task(task_instance)
Example #27
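Resolves a task run for tracking by trying, in order: the current task run, a TaskRunMock built from the DBND_TASK_RUN_ATTEMPT_UID environment variable, an ambient Airflow context, and finally a new in-place run.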
def try_get_or_create_task_run():
    # type: () -> Optional[TaskRun]
    task_run = try_get_current_task_run()
    if task_run:
        return task_run

    try:
        from dbnd._core.task_run.task_run_tracker import TaskRunTracker
        from dbnd._core.configuration.environ_config import DBND_TASK_RUN_ATTEMPT_UID

        tra_uid = os.environ.get(DBND_TASK_RUN_ATTEMPT_UID)
        if tra_uid:
            task_run = TaskRunMock(tra_uid)
            from dbnd import config
            from dbnd._core.settings import CoreConfig

            with config({CoreConfig.tracker_raise_on_error: False},
                        source="ondemand_tracking"):
                tracking_store = CoreConfig().get_tracking_store()
                trt = TaskRunTracker(task_run, tracking_store)
                task_run.tracker = trt
                return task_run

        # let's check if we are in airflow env
        from dbnd._core.inplace_run.airflow_dag_inplace_tracking import (
            try_get_airflow_context,
        )

        airflow_context = try_get_airflow_context()
        if airflow_context:
            from dbnd._core.inplace_run.airflow_dag_inplace_tracking import (
                get_airflow_tracking_manager,
            )

            atm = get_airflow_tracking_manager(airflow_context)
            if atm:
                return atm.airflow_operator__task_run
        from dbnd._core.inplace_run.inplace_run_manager import is_inplace_run

        if is_inplace_run():
            return dbnd_run_start()

    except Exception:
        logger.info("Failed during dbnd inplace tracking init.", exc_info=True)
        return None
Example #28
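The TaskRunMock fallback from Example #27 extracted into a standalone helper; per its docstring, currently used only for Spark.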
def _get_task_run_mock(tra_uid):
    """
    We need a better implementation for this;
    currently it is used only for Spark.
    """
    try:
        from dbnd._core.task_run.task_run_tracker import TaskRunTracker

        task_run = TaskRunMock(tra_uid)
        from dbnd import config
        from dbnd._core.settings import CoreConfig

        with config(
            {CoreConfig.tracker_raise_on_error: False}, source="on_demand_tracking"
        ):
            tracking_store = CoreConfig().get_tracking_store()
            trt = TaskRunTracker(task_run, tracking_store)
            task_run.tracker = trt
            return task_run
    except Exception:
        logger.info("Failed during dbnd inplace tracking init.", exc_info=True)
        return None
Example #29
    def test_read_environ_config(self):
        os.environ["DBND__TEST_SECTION__TEST_KEY"] = "TEST_VALUE"
        actual = read_environ_config()
        with config(actual):
            assert config.get("test_section", "test_key") == "TEST_VALUE"
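A minimal sketch of the same convention with hypothetical section and key names, assuming the imports used by the test above: a variable DBND__<SECTION>__<KEY> surfaces as config.get("<section>", "<key>"):

    # hypothetical names, for illustration only
    os.environ["DBND__MY_SECTION__MY_KEY"] = "my_value"
    with config(read_environ_config()):
        assert config.get("my_section", "my_key") == "my_value"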
Example #30
    def test_extend_k8s_labels(self, name, test_config, expected):
        with new_dbnd_context(conf={CoreConfig.tracker: "console"}):
            ts = read_from_config_stream(seven.StringIO(test_config))
            with config(ts):
                task_with_extend.dbnd_run(name, expected)