def build_dbnd_task(self, task_name, task_kwargs=None, expected_type=None):
    task_kwargs = task_kwargs or dict()
    task_kwargs.setdefault("task_name", task_name)

    task_cls = self.get_task_cls(task_name)  # type: Type[Task]
    if is_airflow_enabled():
        from dbnd_airflow.dbnd_task_executor.airflow_operator_as_dbnd import (
            AirflowDagAsDbndTask,
        )

        if issubclass(task_cls, AirflowDagAsDbndTask):
            # we are running an old-style dag
            dag = self._get_aiflow_dag(task_name)
            airflow_task = AirflowDagAsDbndTask.build_dbnd_task_from_dag(dag=dag)
            return airflow_task

    try:
        logger.debug("Building %s task", task_cls.task_definition.full_task_family)
        obj = task_cls(**task_kwargs)
    except Exception:
        exc = get_databand_context().settings.log.format_exception_as_str(
            sys.exc_info(), isolate=True
        )
        logger.error("Failed to build %s: \n\n%s", task_cls.get_task_family(), exc)
        raise

    if expected_type and not issubclass(task_cls, expected_type):
        raise friendly_error.task_registry.wrong_type_for_task(
            task_name, task_cls, expected_type
        )
    return obj
def is_save_run(self, run, task_runs):
    core_settings = run.context.settings.core
    if core_settings.always_save_pipeline:
        return True
    if core_settings.disable_save_pipeline:
        return False

    if any(tr.task._conf__require_run_dump_file for tr in task_runs):
        return True

    if self.target_engine.require_submit:
        return True

    if self.task_executor_type == TaskExecutorType.local:
        return False

    if is_airflow_enabled():
        from dbnd_airflow.executors import AirflowTaskExecutorType

        return self.task_executor_type not in [
            AirflowTaskExecutorType.airflow_inprocess,
            TaskExecutorType.local,
        ]
    return True
def initialize_band(self):
    try:
        band_context = [
            self.task._auto_load_save_params(
                auto_read=False, normalize_on_change=True
            )
        ]
        if is_airflow_enabled():
            from dbnd_airflow.dbnd_task_executor.airflow_operators_catcher import (
                get_databand_op_catcher_dag,
            )

            band_context.append(get_databand_op_catcher_dag())

        with nested(*band_context):
            band = self.task.band()  # this one would be normalized
            self.task._task_band_result = band
        self.task_band_result = band  # real value
    except Exception as ex:
        logger.error(
            self.visualiser.banner(
                msg="Failed to run %s" % _band_call_str(self.task),
                color="red",
                exc_info=sys.exc_info(),
            )
        )
        if self.task._conf__decorator_spec:
            raise
        raise friendly_error.task_build.failed_to_call_band(ex, self.task)
def is_in_airflow_dag_build_context():
    if not is_airflow_enabled():
        return False

    from dbnd_airflow.functional.dbnd_functional_dag import (
        is_in_airflow_dag_build_context as airflow__is_in_airflow_dag_build_context,
    )

    return airflow__is_in_airflow_dag_build_context()
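# Hypothetical call-site sketch (the callables `run_task` and `build_airflow_operator`
# are illustrative assumptions, not part of the original module): the helper above is
# meant as a cheap guard so that Airflow-specific code paths are only taken when a
# decorated task is being expanded inside a `with DAG(...):` block.
def _example_call_site(run_task, build_airflow_operator):
    if is_in_airflow_dag_build_context():
        # inside a DAG definition - emit an Airflow operator instead of executing
        return build_airflow_operator()
    # outside a DAG definition - execute the task in-process as usual
    return run_task()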
def dbnd_setup_plugin():
    from dbnd_docker.docker.docker_engine_config import DockerEngineConfig
    from dbnd_docker.docker.docker_task import DockerRunTask

    register_config_cls(DockerEngineConfig)
    register_config_cls(DockerRunTask)

    if is_airflow_enabled():
        from dbnd_docker.kubernetes.kubernetes_engine_config import (
            KubernetesEngineConfig,
        )

        register_config_cls(KubernetesEngineConfig)
        logger.debug("Registered kubernetes plugin")
def start(
    self,
    root_task_name,
    in_memory=True,
    run_uid=None,
    airflow_context=False,
    job_name=None,
):
    if try_get_databand_context():
        return

    if not airflow_context and not self._atexit_registered:
        atexit.register(self.stop)
        if is_airflow_enabled():
            from airflow.settings import dispose_orm

            atexit.unregister(dispose_orm)

    c = {
        "run": {"skip_completed": False},  # we don't want to "check" as script is task_version="now"
        "task": {"task_in_memory_outputs": in_memory},  # do not save any outputs
    }
    config.set_values(config_values=c, override=True, source="dbnd_start")

    context_kwargs = {"name": "airflow"} if airflow_context else {}

    # create databand context
    dc = self._enter_cm(new_dbnd_context(**context_kwargs))  # type: DatabandContext

    root_task = _build_inline_root_task(
        root_task_name, airflow_context=airflow_context
    )

    # create databand run
    dr = self._enter_cm(
        new_databand_run(
            context=dc,
            task_or_task_name=root_task,
            run_uid=run_uid,
            existing_run=False,
            job_name=job_name,
        )
    )  # type: DatabandRun

    if run_uid:
        root_task_run_uid = get_task_run_uid(run_uid, root_task_name)
    else:
        root_task_run_uid = None
    dr._init_without_run(root_task_run_uid=root_task_run_uid)

    self._start_taskrun(dr.driver_task_run)
    self._start_taskrun(dr.root_task_run)
    return dr
def dbnd_bootstrap():
    global _dbnd_bootstrap
    global _dbnd_bootstrap_started

    if _dbnd_bootstrap_started:
        return
    _dbnd_bootstrap_started = True

    dbnd_system_bootstrap()

    from targets.marshalling import register_basic_data_marshallers

    register_basic_data_marshallers()

    _surpress_loggers()
    _suppress_warnings()
    enable_osx_forked_request_calls()

    if is_airflow_enabled():
        from dbnd_airflow.bootstrap import airflow_bootstrap

        airflow_bootstrap()

    register_dbnd_plugins()

    from dbnd._core.configuration import environ_config
    from dbnd._core.utils.basics.load_python_module import run_user_func
    from dbnd._core.plugin.dbnd_plugins import pm
    from dbnd._core.configuration.dbnd_config import config

    user_plugins = config.get("core", "plugins", None)
    if user_plugins:
        register_dbnd_user_plugins(user_plugins.split(","))

    if is_unit_test_mode():
        pm.hook.dbnd_setup_unittest()

    pm.hook.dbnd_setup_plugin()

    if is_sigquit_handler_on():
        from dbnd._core.utils.basics.signal_utils import (
            register_sigquit_stack_dump_handler,
        )

        register_sigquit_stack_dump_handler()

    # now we can run user code (at driver/task)
    user_preinit = environ_config.get_user_preinit()
    if user_preinit:
        run_user_func(user_preinit)

    # if any code calls dbnd_bootstrap again for any reason,
    # this flag prevents endless recursion
    _dbnd_bootstrap = True
def stop(self, at_exit=True, update_run_state=True):
    if update_run_state:
        databand_run = try_get_databand_run()
        if databand_run:
            root_tr = databand_run.task.current_task_run
            root_tr.finished_time = utcnow()

            for tr in databand_run.task_runs:
                if tr.task_run_state == TaskRunState.FAILED:
                    root_tr.set_task_run_state(TaskRunState.UPSTREAM_FAILED)
                    databand_run.set_run_state(RunState.FAILED)
                    break
            else:
                root_tr.set_task_run_state(TaskRunState.SUCCESS)
                databand_run.set_run_state(RunState.SUCCESS)
            logger.info(databand_run.describe.run_banner_for_finished())

    self._close_all_context_managers()

    if at_exit and is_airflow_enabled():
        from airflow.settings import dispose_orm

        dispose_orm()
def _is_save_run_pickle(self, task_runs, remote_engine):
    if self.run_config.always_save_pipeline:
        return True
    if self.run_config.disable_save_pipeline:
        return False

    if any(tr.task._conf__require_run_dump_file for tr in task_runs):
        return True

    if remote_engine.require_submit:
        return True

    if self.task_executor_type == TaskExecutorType.local:
        return False

    if is_airflow_enabled():
        from dbnd_airflow.executors import AirflowTaskExecutorType

        return self.task_executor_type not in [
            AirflowTaskExecutorType.airflow_inprocess,
            TaskExecutorType.local,
        ]
    return True
def _run_spark_submit(self, application, jars):
    # task_env = get_cloud_config(Clouds.local)
    spark_local_config = SparkLocalEngineConfig()
    _config = self.config
    deploy = self.deploy

    AIRFLOW_ON = is_airflow_enabled()

    if AIRFLOW_ON:
        from airflow.contrib.hooks.spark_submit_hook import SparkSubmitHook
        from airflow.exceptions import AirflowException as SparkException
    else:
        from dbnd_spark._vendor.airflow.spark_hook import (
            SparkException,
            SparkSubmitHook,
        )

    spark = SparkSubmitHook(
        conf=_config.conf,
        conn_id=spark_local_config.conn_id,
        name=self.job.job_id,
        application_args=list_of_strings(self.task.application_args()),
        java_class=self.task.main_class,
        files=deploy.arg_files(_config.files),
        py_files=deploy.arg_files(self.task.get_py_files()),
        driver_class_path=_config.driver_class_path,
        jars=deploy.arg_files(jars),
        packages=_config.packages,
        exclude_packages=_config.exclude_packages,
        repositories=_config.repositories,
        total_executor_cores=_config.total_executor_cores,
        executor_cores=_config.executor_cores,
        executor_memory=_config.executor_memory,
        driver_memory=_config.driver_memory,
        keytab=_config.keytab,
        principal=_config.principal,
        num_executors=_config.num_executors,
        env_vars=self._get_env_vars(),
        verbose=_config.verbose,
    )
    if not AIRFLOW_ON:
        # If there is no Airflow, there is no Connection, so we take the
        # connection information from the Spark config instead
        spark.set_connection(spark_local_config.conn_uri)

    log_buffer = StringIO()
    with log_buffer as lb:
        dbnd_log_handler = self._capture_submit_log(spark, lb)
        try:
            # sync the application file to the remote location if needed
            spark.submit(application=deploy.sync(application))
        except SparkException as ex:
            return_code = self._get_spark_return_code_from_exception(ex)
            if return_code != "0":
                error_snippets = parse_spark_log_safe(
                    log_buffer.getvalue().split(os.linesep)
                )
                raise failed_to_run_spark_script(
                    self,
                    spark._build_spark_submit_command(application=application),
                    application,
                    return_code,
                    error_snippets,
                )
            else:
                raise failed_spark_status(ex)
        finally:
            spark.log.handlers = [
                h for h in spark.log.handlers if h is not dbnd_log_handler
            ]
def _get_task_cls(self, task_name):
    from dbnd._core.utils.basics.load_python_module import load_python_module

    task_cls = self._get_registered_task_cls(task_name)
    if task_cls:
        return task_cls

    # check if we have an override/definition in config
    config_task_type = config.get(task_name, "_type", None)
    if config_task_type:
        _validate_no_recursion_in_config(task_name, config_task_type, "_type")
        try:
            return self._get_task_cls(config_task_type)
        except Exception:
            logger.error(
                "Failed to load type required by [%s] using _type=%s",
                task_name,
                config_task_type,
            )
            raise

    config_task_type = config.get(task_name, "_from", None)
    if config_task_type:
        _validate_no_recursion_in_config(task_name, config_task_type, "_from")
        return self._get_task_cls(config_task_type)

    if "." in task_name:
        parts = task_name.split(".")
        possible_root_task = parts.pop()
        possible_module = ".".join(parts)

        # try to load the module and check again for existence
        load_python_module(possible_module, "task name '%s'" % task_name)

        task_cls = self._get_registered_task_cls(task_name)
        if task_cls:
            return task_cls

        # check if the task exists but the user forgot to decorate it with @task
        task_module = sys.modules.get(possible_module)
        if task_module and hasattr(task_module, possible_root_task):
            user_func = getattr(task_module, possible_root_task)
            if callable(user_func):
                # a non-decorated function was found - decorate and return it
                from dbnd._core.decorator import dbnd_func_proxy

                decorated_task = dbnd_func_proxy.task(user_func)
                setattr(task_module, possible_root_task, decorated_task)
                logger.warning(
                    "Found non-decorated task: %s. "
                    "Please decorate this task with the proper symbol @pipeline / @task.\n"
                    "Auto-decorating and treating it as @task ...",
                    task_name,
                )
                return decorated_task.task

    if is_airflow_enabled():
        from dbnd_airflow.dbnd_task_executor.airflow_operator_as_dbnd import (
            AirflowDagAsDbndTask,
        )

        dag = self._get_aiflow_dag(task_name)
        if dag:
            return AirflowDagAsDbndTask
    return None
def initialize_band(self):
    try:
        band_context = []
        if is_airflow_enabled():
            from dbnd_airflow.dbnd_task_executor.airflow_operators_catcher import (
                get_databand_op_catcher_dag,
            )

            band_context.append(get_databand_op_catcher_dag())

        original_param_values = []
        for param_value in self.task.task_params.get_param_values(
            ParameterFilters.OUTPUTS
        ):
            if param_value.name == "task_band" or isinstance(
                param_value.parameter, FuncResultParameter
            ):
                continue
            original_param_values.append((param_value, param_value.value))

        with nested(*band_context):
            band = self.task.band()  # this one would be normalized
            self.task._task_band_result = band
        self.task_band_result = band  # real value

        from dbnd import PipelineTask

        if isinstance(self.task, PipelineTask):
            # after .band() has finished, all user outputs of the .band should be defined
            for param_value, _ in original_param_values:
                # we want to validate only user-facing parameters;
                # they should have assigned values by this moment,
                # a pipeline task can not have None outputs after the band call
                if param_value.parameter.system:
                    continue
                if is_not_defined(param_value.value):
                    raise friendly_error.task_build.pipeline_task_has_unassigned_outputs(
                        task=self.task, param=param_value.parameter
                    )

        # now let's normalize if the user has changed outputs
        for param_value, original_value in original_param_values:
            if param_value.value is original_value:
                continue
            try:
                from dbnd._core.utils.task_utils import to_targets

                normalized_value = to_targets(param_value.value)
                param_value.update_param_value(normalized_value)
            except Exception as ex:
                raise friendly_error.task_build.failed_to_assign_param_value_at_band(
                    ex, param_value.parameter, param_value.value, self.task
                )
    except Exception as ex:
        logger.warning(
            self.visualiser.banner(
                msg="Failed to run %s" % _band_call_str(self.task),
                color="red",
                exc_info=sys.exc_info(),
            )
        )
        if self.task.task_decorator:
            # just re-raise, we already have an error from the "run" function
            raise
        raise friendly_error.task_build.failed_to_call_band(ex, self.task)
def calculate_task_executor_type(submit_tasks, remote_engine, settings):
    run_config = settings.run
    parallel = run_config.parallel
    task_executor_type = run_config.task_executor_type

    if task_executor_type is None:
        if is_airflow_enabled():
            from dbnd_airflow.executors import AirflowTaskExecutorType

            task_executor_type = AirflowTaskExecutorType.airflow_inprocess
        else:
            task_executor_type = TaskExecutorType.local

    if is_airflow_enabled():
        from dbnd_airflow.executors import AirflowTaskExecutorType

        if parallel:
            if task_executor_type == TaskExecutorType.local:
                logger.warning(
                    "Auto switching to engine type '%s' due to parallel mode.",
                    AirflowTaskExecutorType.airflow_multiprocess_local,
                )
                task_executor_type = AirflowTaskExecutorType.airflow_multiprocess_local

            if task_executor_type == AirflowTaskExecutorType.airflow_inprocess:
                logger.warning(
                    "Auto switching to engine type '%s' due to parallel mode.",
                    AirflowTaskExecutorType.airflow_multiprocess_local,
                )
                task_executor_type = AirflowTaskExecutorType.airflow_multiprocess_local

        if (
            task_executor_type == AirflowTaskExecutorType.airflow_multiprocess_local
            or task_executor_type == AirflowTaskExecutorType.airflow_kubernetes
        ):
            if "sqlite" in settings.core.sql_alchemy_conn:
                if settings.run.enable_concurent_sqlite:
                    logger.warning(
                        "You are running parallel execution on top of sqlite database! (see run.enable_concurent_sqlite)"
                    )
                else:
                    # in theory sqlite can support a decent amount of parallelism,
                    # but in practice, the way airflow works, each process holds the db
                    # exclusively locked, which leads to "sqlite DB is locked" exceptions
                    raise friendly_error.execute_engine.parallel_or_remote_sqlite(
                        "parallel"
                    )

        if is_plugin_enabled("dbnd-docker"):
            from dbnd_docker.kubernetes.kubernetes_engine_config import (
                KubernetesEngineConfig,
            )

            if (
                submit_tasks
                and isinstance(remote_engine, KubernetesEngineConfig)
                and run_config.enable_airflow_kubernetes
            ):
                if task_executor_type != AirflowTaskExecutorType.airflow_kubernetes:
                    logger.info("Using dedicated kubernetes executor for this run")
                    task_executor_type = AirflowTaskExecutorType.airflow_kubernetes
                    parallel = True
    else:
        if parallel:
            logger.warning("Airflow is not installed, parallel mode is not supported")

    all_executor_types = [TaskExecutorType.local]
    if is_airflow_enabled():
        from dbnd_airflow.executors import AirflowTaskExecutorType

        all_executor_types.extend(AirflowTaskExecutorType.all())

    if task_executor_type not in all_executor_types:
        raise DatabandConfigError("Unsupported engine type %s" % task_executor_type)

    return task_executor_type, parallel
def init_airflow_test_config():
    if is_airflow_enabled():
        from airflow import configuration
        from airflow.configuration import TEST_CONFIG_FILE

        configuration.conf.read(TEST_CONFIG_FILE)
def _try_get_task_from_airflow_op(value):
    if is_airflow_enabled():
        from dbnd_airflow.dbnd_task_executor.converters import try_operator_to_dbnd_task

        return try_operator_to_dbnd_task(value)
    return None
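# Hypothetical call-site sketch (assuming try_operator_to_dbnd_task() returns None for
# values it cannot convert, as the "try_" naming suggests): because the wrapper above
# also returns None when Airflow is not installed, callers can treat None uniformly as
# "not convertible" and keep the original value.
def _example_convert_value(value):
    task = _try_get_task_from_airflow_op(value)
    if task is not None:
        return task
    return value  # not an Airflow operator (or Airflow is not installed)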
def sql_conn_repr(self):
    if is_airflow_enabled():
        from sqlalchemy.engine.url import make_url

        return repr(make_url(self.get_sql_alchemy_conn()))
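# Minimal sketch of why repr(make_url(...)) is used above: SQLAlchemy's URL.__repr__
# masks the password component, so the connection string can be logged safely, while
# str() renders it in full. The URI below is a made-up example, not a value taken
# from the original codebase.
from sqlalchemy.engine.url import make_url

example_uri = "postgresql://scott:secret@localhost:5432/dbnd"  # hypothetical
print(repr(make_url(example_uri)))  # password is rendered as '***'
print(str(make_url(example_uri)))   # str() would expose the password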