def test_external_task_cmd_line(self):
    """Run ``ttask_simple`` from a secondary thread and check it completed.

    Everything happens inside a fresh databand context whose task executor
    type is set to "local".
    """
    local_executor_conf = {RunConfig.task_executor_type: "local"}
    with new_dbnd_context(conf=local_executor_conf):

        def run():
            ttask_simple.dbnd_run()

        worker = Thread(target=run)
        worker.start()
        # Wait for the run to finish before inspecting the completion state.
        worker.join()

        finished_task = ttask_simple.task()
        assert finished_task._complete()
def test_sign_by_task_code_build(self):
    """Check that the task's module name is part of its signature source.

    The context is created with
    ``task_build.sign_with_full_qualified_name`` set to "True".
    """
    build_conf = {"task_build": {"sign_with_full_qualified_name": "True"}}
    with new_dbnd_context(conf=build_conf):
        # we need to recreate task definition with TTaskSigChange=True
        class TTaskSigChange(TTask):
            pass

        instance = TTaskSigChange()
        module_name = str(TTaskSigChange.__module__)
        assert module_name in instance.task_signature_source
def test_no_outputs(self, capsys):
    """Expect ``DatabandRunError`` when a declared output is never written."""

    class TMissingOutputs(PythonTask):
        some_output = output
        forgotten_output = output

        def run(self):
            # Only ``some_output`` is written; ``forgotten_output`` is not.
            self.some_output.write("")

    local_run_conf = {"run": {"task_executor_type": "local"}}
    expected_failure = pytest.raises(DatabandRunError, match="Failed tasks are:")
    with new_dbnd_context(conf=local_run_conf), expected_failure:
        TMissingOutputs().dbnd_run()
def test_simple_driver(self):
    """Build and run the sanity-check task under the "gcp_k8s" environment."""
    k8s_env_conf = {
        DatabandSystemConfig.env: "gcp_k8s",
        EnvConfig.local_engine: "local_engine",
        EnvConfig.remote_engine: "gcp_k8s_engine",
        KubernetesEngineConfig.debug: True,
    }
    with new_dbnd_context(conf=k8s_env_conf):
        sanity_task = dbnd_sanity_check.task(task_version="now")
        # The task must have picked up the environment selected above.
        assert sanity_task.task_env.task_name == "gcp_k8s"
        sanity_task.dbnd_run()
def test_dbnd_store_initialization(self):
    """Check which URL the mlflow tracking store hands to ``TrackingApiClient``.

    First run: no ``core.databand_url`` configured, the client is expected to
    be built with "http://localhost:8080". Second run: the configured
    "https://secure" URL is expected instead.
    """
    client_path = "dbnd_mlflow.tracking_store.TrackingApiClient"

    default_conf = {"mlflow_tracking": {"databand_tracking": True}}
    with new_dbnd_context(conf=default_conf), mock.patch(client_path) as client_cls:
        mlflow_tracking_integration_check.dbnd_run()
        client_cls.assert_called_once_with("http://localhost:8080")

    secure_conf = {
        "core": {"databand_url": "https://secure"},
        "mlflow_tracking": {"databand_tracking": True},
    }
    with new_dbnd_context(conf=secure_conf), mock.patch(client_path) as client_cls:
        mlflow_tracking_integration_check.dbnd_run(task_version="now")
        client_cls.assert_called_once_with("https://secure")
def test_to_read_input(self, capsys):
    """Expect ``DatabandRunError`` when a DataFrame input holds non-DataFrame text."""

    class TCorruptedInput(PythonTask):
        some_input = data[DataFrame]
        forgotten_output = output

        def run(self):
            # NOTE(review): ``some_output`` is not declared on this class;
            # presumably the run fails before this line matters - confirm.
            self.some_output.write("")

    local_run_conf = {"run": {"task_executor_type": "local"}}
    with new_dbnd_context(conf=local_run_conf):
        # NOTE(review): this keeps the return value of ``write()`` rather than
        # the target itself - verify that is what the task should receive.
        corrupted_input = self.target("some_input.json").write("corrupted dataframe")
        with pytest.raises(DatabandRunError, match="Failed tasks are:"):
            TCorruptedInput(some_input=corrupted_input).dbnd_run()
def _save_graph(self, task):
    """Run ``task``, save the run, and check that the dump can be loaded back.

    :param task: task (or task name) handed to ``dbnd_run_task``.
    :return: the run object produced by ``dbnd_run_task``.
    """
    run_conf = {
        RunConfig.task_executor_type: override(TaskExecutorType.local),
        CoreConfig.tracker: override(["console"]),
    }
    with new_dbnd_context(conf=run_conf) as dc:
        run = dc.dbnd_run_task(task_or_task_name=task)
        run.save_run()

        # Round-trip through the driver dump written by save_run().
        reloaded = DatabandRun.load_run(
            dump_file=run.driver_dump,
            disable_tracking_api=False,
        )
        assert reloaded
        return run
def schedule(ctx):
    """Manage scheduled jobs"""
    # Start from a fresh shared-state dict and publish the headers first.
    shared = {}
    ctx.obj = shared
    shared["headers"] = SCHEDULED_JOB_HEADERS

    from dbnd import new_dbnd_context

    # The context manager is entered by hand and is not exited in this
    # function.
    # NOTE(review): presumably it must stay open for whatever consumes
    # ``ctx.obj`` later - confirm whether an explicit exit is needed.
    context_manager = new_dbnd_context(
        autoload_modules=False,
        conf={"core": {"tracker": ""}},
    )
    dbnd_context = context_manager.__enter__()
    ctx.obj["scheduled_job_service"] = dbnd_context.scheduled_job_service
def test_single_instance(self):
    """Every lookup of "my_dummy" inside one context yields the same object."""
    with new_dbnd_context(name="first") as ctx:
        reference = ctx.settings.get_config("my_dummy")
        same_config_lookups = (
            Config.from_databand_context("my_dummy"),
            DummyConfig.from_databand_context(),
            ctx.settings.get_config("my_dummy"),
        )

        for looked_up in same_config_lookups:
            assert reference is looked_up

        # A mutation through one reference must be visible through the others.
        reference.foo = "123"
        for looked_up in same_config_lookups:
            assert looked_up.foo == "123"
def test_plugin_loading(self):
    """Check the tracking URI produced for a secure URL with duplication on."""
    plugin_conf = {
        "core": {"databand_url": "https://secure"},
        "mlflow_tracking": {
            "databand_tracking": True,
            "duplicate_tracking_to": "http://mlflow",
        },
    }
    expected_uri = "dbnd+s://secure?duplicate_tracking_to=http%253A%252F%252Fmlflow"
    with new_dbnd_context(conf=plugin_conf):
        assert get_tracking_uri() == expected_uri
def ipython(): """Get ipython shell with Databand's context""" # noinspection PyUnresolvedReferences from dbnd_web import models # noqa from dbnd import new_dbnd_context from airflow.utils.db import create_session import IPython with new_dbnd_context( name="ipython", autoload_modules=False) as ctx, create_session() as session: header = "\n\t".join([ "Welcome to \033[91mDataband\033[0m's ipython command.\nPredefined variable are", "\033[92m\033[1mctx\033[0m -> dbnd_context", "\033[92m\033[1msession\033[0m -> DB session", "\033[92m\033[1mmodels\033[0m -> dbnd models", ]) IPython.embed(colors="neutral", header=header)
def test_auto_load(self):
    """Run "task_auto_config" by name with user config and module settings."""
    autoload_conf = {
        "autotestconfig": {
            "param_datetime": "2018-01-01",
            "param_int": "42",
        },
        "core": {
            "user_configs": "autotestconfig",
            "user_init": "test_dbnd.settings.autoloaded_config.user_code_load_config",
        },
        "databand": {"module": "test_dbnd.settings.autoloaded_config"},
    }
    with new_dbnd_context(conf=autoload_conf) as dc:  # type: DatabandContext
        # The task is referenced only by name, so it has to be resolvable
        # from inside this context.
        dc.dbnd_run_task(task_or_task_name="task_auto_config")
def test_params_without_preview(self):
    """Check the task banner when ``tracking.log_value_preview`` is False.

    The expected table below shows "***" in the value column for every
    parameter and for the result.
    """
    # NOTE(review): the expected table appears to have lost its line breaks
    # and column spacing in this copy; verify the literal against the banner
    # text before relying on this assertion.
    PARAMS_WITHOUT_PREVIEWS = dedent(""" Name Kind Type Format Source -= Value =- num_param param int t.t.t.t_f[default] *** list_param param List t.t.t.t_f[default] "***" none_param param object t.t.t.t_f[default] *** result output object .pickle "***" """)
    with new_dbnd_context(conf={"tracking": {"log_value_preview": False}}):

        # One int, one list and one None default - the three rows of the
        # expected table.
        @task
        def t_f(num_param=12, list_param=[1, 2, 3], none_param=None):
            return "123"

        run = t_f.dbnd_run()
        task_run = run.root_task_run
        # Banner text of the root task, read through its visualiser.
        actual = task_run.task.ctrl.visualiser._banner.get_banner_str()
        assert PARAMS_WITHOUT_PREVIEWS in actual
def test_log_exception_to_server(self):
    """A raised ``TestException`` results in exactly one log_exception request."""
    server_conf = {
        "core": {
            "databand_access_token": "token",
            "databand_url": "some_url",
        }
    }
    request_path = "dbnd.utils.api_client.ApiClient._send_request"
    with new_dbnd_context(server_conf), patch(request_path) as send_request_mock:
        with pytest.raises(TestException):
            raising_function()

        assert send_request_mock.call_count == 1

        # Inspect the keyword arguments of the single recorded request.
        sent = send_request_mock.call_args.kwargs
        assert sent["url"] == "/api/v1/log_exception"
        assert sent["json"]["source"] == "tracking-sdk"

        stack_trace = sent["json"]["stack_trace"]
        assert "in raising_function" in stack_trace
        assert "TestException: some error message" in stack_trace
def test_no_skip_after_failure(self, fake_api_request):
    """With ``tracker_raise_on_error`` off, the skip callback is never called."""
    tolerant_conf = {
        "core": {"tracker_raise_on_error": False},
        "databand": {"verbose": True},
    }
    with new_dbnd_context(conf=tolerant_conf) as ctx:
        skip_callback = patch.object(
            TrackingAsyncWebChannel,
            "_background_worker_skip_processing_callback",
        )
        with skip_callback as fake_skip:
            store = TrackingStoreThroughChannel.build_with_async_web_channel(ctx)
            # Every API request from here on raises "not reachable".
            fake_api_request.side_effect = DatabandWebserverNotReachableError(
                "fake_message"
            )

            store.heartbeat(get_uuid())  # fail here
            store.heartbeat(get_uuid())  # no skip here
            store.flush()

            fake_skip.assert_not_called()
def test_init_protoweb_channel(self):
    """With ``tracker_api`` "proto", the tracking store uses the proto web channel."""
    # testing env should not include protobuf package
    # otherwise this test is useless
    with pytest.raises(ImportError):
        from google import protobuf  # noqa: F401

    proto_conf = {
        "core": {
            "databand_url": "http://fake-url.dbnd.local:8080",
            "tracker": ["api"],
            "tracker_api": "proto",
            "tracker_raise_on_error": True,
            "allow_vendored_package": True,
        }
    }
    with new_dbnd_context(conf=proto_conf) as dc:
        store = first_store(dc.tracking_store)
        store_type = store.__class__.__name__
        assert store_type == "TrackingStoreThroughChannel", store

        channel_type = store.channel.__class__.__name__
        assert channel_type == "TrackingProtoWebChannel", store.channel

        # an extra check that protobuf is available:
        from google import protobuf  # noqa: F401
def send_heartbeat(run_uid, databand_url, heartbeat_interval, driver_pid,
                   tracker, tracker_api):
    # Configure the tracker settings, then hand over to
    # send_heartbeat_continuously() with the resulting tracking store.
    # ``tracker`` is a comma-separated string and is split into a list here.
    from dbnd import config
    from dbnd._core.task_executor.heartbeat_sender import send_heartbeat_continuously

    core_settings = {
        "tracker": tracker.split(","),
        "tracker_api": tracker_api,
        "databand_url": databand_url,
    }
    with config({"core": core_settings}):
        # Only the "db" tracker api gets a dedicated databand context.
        extra_contexts = []
        if tracker_api == "db":
            from dbnd import new_dbnd_context

            extra_contexts.append(new_dbnd_context(name="send_heartbeat"))

        with nested_context.nested(*extra_contexts):
            store = get_databand_context().tracking_store
            send_heartbeat_continuously(
                run_uid, store, heartbeat_interval, driver_pid
            )
def launch(self, context, scheduled_run_info: ScheduledRunInfo):
    """Run the scheduled command as a ``Launcher`` task in an "airflow" context.

    :param context: mapping that provides the "dag" whose ``dag_id`` becomes
        the launcher task name.
    :param scheduled_run_info: passed through unchanged to ``dbnd_run_task``.
    """
    # Local, non-parallel execution with logging disabled.
    launch_conf = {
        RunConfig.task_executor_type: override(TaskExecutorType.local),
        RunConfig.parallel: override(False),
        LoggingConfig.disabled: override(True),
    }
    with new_dbnd_context(name="airflow", conf=launch_conf) as dc:
        dag_id = context.get("dag").dag_id
        launcher = Launcher(
            scheduled_cmd=self.scheduled_cmd,
            task_name=dag_id,
            task_version="now",
            task_is_system=True,
            shell=self.shell,
        )
        dc.dbnd_run_task(
            task_or_task_name=launcher,
            scheduled_run_info=scheduled_run_info,
            send_heartbeat=False,
        )
def test_start_heartbeat_sender(self):
    """The heartbeat sender spawns one subprocess carrying the run uid.

    ``subprocess.Popen`` is patched; the test checks it was called once, that
    its ``env`` contains ``ENV_DBND__NO_PLUGINS``, and that "testtest" (the
    mocked run uid) is in the first positional argument.
    """
    # we are not going to mock settings as that's too much work
    heartbeat_conf = {
        RunConfig.heartbeat_interval_s: 5,
        RunConfig.hearbeat_disable_plugins: 5,
    }
    with new_dbnd_context(conf=heartbeat_conf) as dc:
        run = MagicMock(DatabandRun)
        run_properties = {
            "run_uid": "testtest",
            "context": dc,
            "run_local_root": dc.env.dbnd_local_root,
        }
        # Properties have to be installed on the mock's type, not the instance.
        for prop_name, prop_value in run_properties.items():
            setattr(type(run), prop_name, PropertyMock(return_value=prop_value))

        run_executor = MagicMock(RunExecutor)
        type(run_executor).run = PropertyMock(return_value=run)

        with patch("subprocess.Popen") as mock_popen:
            heartbeat = start_heartbeat_sender(run_executor)
            with heartbeat:
                logger.info("running with heartbeat")

            mock_popen.assert_called_once()
            popen_call = mock_popen.call_args_list[-1]
            assert ENV_DBND__NO_PLUGINS in popen_call.kwargs["env"]
            assert "testtest" in popen_call.args[0]
def test_circle_dependency_type(self):
    """A pipeline with a dependency cycle fails both at build and at run time."""

    @task
    def task_a(a):
        return a

    @task
    def task_b(a):
        return a

    @pipeline
    def task_circle():
        a = task_a(1)
        b = task_b(a)
        # b already consumes a; adding a as b's downstream closes the loop.
        b.task.set_downstream(a)
        return b

    cycle_message = "A cyclic dependency occurred"
    recheck_conf = {"core": {"recheck_circle_dependencies": "True"}}
    with new_dbnd_context(conf=recheck_conf):
        with pytest.raises(DatabandBuildError, match=cycle_message):
            task_circle.task()

        with pytest.raises(DatabandRunError, match=cycle_message):
            task_circle.dbnd_run()
def build_task(root_task, **kwargs):
    """Build the task named ``root_task`` with ``kwargs`` as its config section.

    :param root_task: task name; also used as the config section key.
    :param kwargs: values placed under that section for the build.
    :return: whatever the task registry's ``build_dbnd_task`` returns.
    """
    from dbnd import new_dbnd_context

    task_section = {root_task: kwargs}
    with new_dbnd_context(conf=task_section):
        registry = get_task_registry()
        return registry.build_dbnd_task(task_name=root_task)
def test_input_is_missing_file(self):
    """Expect ``DatabandRunError`` when the task input points at a missing file."""
    local_run_conf = {"run": {"task_executor_type": "local"}}
    expected_failure = pytest.raises(DatabandRunError, match="Failed tasks are:")
    with new_dbnd_context(conf=local_run_conf), expected_failure:
        # Task construction stays inside the raises block, as in the
        # original flow.
        task_with_missing_input = TTaskWithInput(t_input="file_that_not_exists")
        assert_run_task(task_with_missing_input)
logger = logging.getLogger(__name__)


def custom_logging_mod():
    """Return databand's base logging config with an extra file handler.

    A handler named "my_file" writing to "/tmp/my_custom_file" is registered
    and appended to the root logger's handlers.
    """
    log_settings = get_databand_context().settings.log
    logging_config = log_settings.get_dbnd_logging_config_base()

    file_handler = {
        "class": "logging.FileHandler",
        "formatter": "formatter",
        "filename": "/tmp/my_custom_file",
        "encoding": "utf-8",
    }
    logging_config["handlers"]["my_file"] = file_handler
    logging_config["root"]["handlers"].append("my_file")
    return logging_config


@task
def task_with_some_logging():
    """Emit messages through the root logger and the module logger."""
    logging.info("root logger")
    logger.info("custom logger")
    logger.warning("warning!")
    return "ok"


if __name__ == "__main__":
    # Dotted path of the function that supplies the logging dict config.
    mod_func = "dbnd_examples.feature_system.custom_logging.custom_logging_mod"
    custom_log_conf = {"log": {"custom_dict_config": mod_func}}
    with new_dbnd_context(conf=custom_log_conf):
        task_with_some_logging.dbnd_run()
def test_extend_k8s_labels(self, name, test_config, expected):
    """Run ``task_with_extend`` under a config parsed from ``test_config``.

    :param name: first positional argument forwarded to the task run.
    :param test_config: config text, parsed via ``read_from_config_stream``.
    :param expected: second positional argument forwarded to the task run.
    """
    console_tracker_conf = {CoreConfig.tracker: "console"}
    with new_dbnd_context(conf=console_tracker_conf):
        config_stream = seven.StringIO(test_config)
        parsed_config = read_from_config_stream(config_stream)
        with config(parsed_config):
            task_with_extend.dbnd_run(name, expected)
def test_no_outputs(self, capsys):
    """Running ``TMissingOutputs`` is expected to raise ``DatabandRunError``."""
    local_run_conf = {"run": {"task_executor_type": "local"}}
    expected_failure = pytest.raises(DatabandRunError, match="Failed tasks are:")
    with new_dbnd_context(conf=local_run_conf), expected_failure:
        TMissingOutputs().dbnd_run()
def test_verbose_build(self):
    """An ``override`` for ``TTask.t_param`` is applied with verbose build on."""
    verbose_conf = {"task_build": {"verbose": "True"}}
    with new_dbnd_context(conf=verbose_conf):
        param_override = {TTask.t_param: "test_driver"}
        built_task = TTask(override=param_override)
        assert built_task.t_param == "test_driver"
def test_simple_build(self):
    """A ``TTask`` built in a default context keeps its param and has an output."""
    with new_dbnd_context():
        built_task = TTask(t_param="test_driver")
        assert built_task.t_param == "test_driver"
        # The output only has to be truthy.
        assert built_task.t_output
def test_verbose_build(self):
    """An ``override`` for the spark driver memory reaches ``spark_config``."""
    verbose_conf = {"task_build": {"verbose": "True"}}
    with new_dbnd_context(conf=verbose_conf):
        memory_override = {SparkConfig.driver_memory: "test_driver"}
        spark_task = SimpleSparkTask(override=memory_override)
        assert spark_task.spark_config.driver_memory == "test_driver"
def test_different_instances(self):
    """Two distinct contexts hand out distinct "my_dummy" config objects."""
    with new_dbnd_context(name="first") as outer_ctx:
        outer_config = outer_ctx.settings.get_config("my_dummy")

        with new_dbnd_context(name="second") as inner_ctx:
            inner_config = inner_ctx.settings.get_config("my_dummy")
            assert outer_config is not inner_config
def test_foreign_context_should_not_fail(self):
    """A task built and run in one context can feed a later, separate run."""
    with new_dbnd_context():
        upstream = SimplestTask()
        upstream.dbnd_run()

    # NOTE(review): placed outside the ``with`` so the upstream task comes
    # from a different ("foreign") context - confirm the intended nesting.
    TTaskWithInput(t_input=upstream).dbnd_run()