Exemplo n.º 1
0
    def get_dag_params(self) -> Dict[str, Any]:
        """
        Merges default config with dag config, sets dag_id, and extrapolates dag_start_date

        After merging ``self.dag_config`` with ``self.default_config`` the
        following keys are normalised in place:

        * ``dagrun_timeout_sec`` is replaced by ``dagrun_timeout`` (``timedelta``)
        * ``default_args["retry_delay_sec"]`` is replaced by
          ``default_args["retry_delay"]`` (``timedelta``)
        * ``default_args["end_date"]`` and ``default_args["start_date"]`` are
          passed through ``utils.get_datetime`` together with
          ``default_args["timezone"]`` (``"UTC"`` when that key is absent)
        * ``on_success_callback_name``/``on_success_callback_file`` and
          ``on_failure_callback_name``/``on_failure_callback_file`` are
          resolved with ``utils.get_python_callable`` and stored under
          ``on_success_callback`` / ``on_failure_callback``

        :returns: dict of dag parameters
        :raises Exception: if the configs cannot be merged, or if a
            ``KeyError`` occurs while resolving ``start_date``
        :raises KeyError: if the merged config has no ``default_args`` entry
        """
        try:
            dag_params: Dict[str, Any] = utils.merge_configs(
                self.dag_config, self.default_config
            )
        except Exception as err:
            # Chain explicitly so the original failure is kept as the cause.
            raise Exception(
                f"Failed to merge config with default config, err: {err}"
            ) from err
        dag_params["dag_id"] = self.dag_name

        # Convert from 'dagrun_timeout_sec: int' to 'dagrun_timeout: timedelta'
        if utils.check_dict_key(dag_params, "dagrun_timeout_sec"):
            dag_params["dagrun_timeout"] = timedelta(
                seconds=dag_params["dagrun_timeout_sec"]
            )
            del dag_params["dagrun_timeout_sec"]

        default_args: Dict[str, Any] = dag_params["default_args"]

        # Convert from 'end_date: Union[str, datetime, date]' to 'end_date: datetime'
        if utils.check_dict_key(default_args, "end_date"):
            default_args["end_date"] = utils.get_datetime(
                date_value=default_args["end_date"],
                timezone=default_args.get("timezone", "UTC"),
            )

        # Convert from 'retry_delay_sec: int' to 'retry_delay: timedelta'
        if utils.check_dict_key(default_args, "retry_delay_sec"):
            default_args["retry_delay"] = timedelta(
                seconds=default_args["retry_delay_sec"]
            )
            del default_args["retry_delay_sec"]

        # A callback is only resolved when both its name and its file are set.
        callback_sources = (
            (
                "on_success_callback",
                "on_success_callback_name",
                "on_success_callback_file",
            ),
            (
                "on_failure_callback",
                "on_failure_callback_name",
                "on_failure_callback_file",
            ),
        )
        for target_key, name_key, file_key in callback_sources:
            if utils.check_dict_key(dag_params, name_key) and utils.check_dict_key(
                dag_params, file_key
            ):
                dag_params[target_key] = utils.get_python_callable(
                    dag_params[name_key],
                    dag_params[file_key],
                )

        try:
            # ensure that default_args dictionary contains key "start_date"
            # with "datetime" value in specified timezone
            default_args["start_date"] = utils.get_datetime(
                date_value=default_args["start_date"],
                timezone=default_args.get("timezone", "UTC"),
            )
        except KeyError as err:
            raise Exception(
                f"{self.dag_name} config is missing start_date, err: {err}"
            ) from err
        return dag_params
Exemplo n.º 2
0
    def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
        """
        Takes an operator and params and creates an instance of that operator.

        For ``PythonOperator`` the ``python_callable_name`` and
        ``python_callable_file`` entries are resolved with
        ``utils.get_python_callable`` and the result is stored under
        ``python_callable``. ``task_params`` is modified in place.

        :param operator: string passed to ``import_string`` to obtain the
            operator class
        :param task_params: keyword arguments passed to the operator
        :returns: instance of operator object
        :raises Exception: if the operator cannot be imported or the task
            cannot be created
        """
        try:
            # class is a Callable https://stackoverflow.com/a/34578836/3679900
            operator_obj: Callable[..., BaseOperator] = import_string(operator)
        except Exception as err:
            raise Exception(
                f"Failed to import operator: {operator}. err: {err}"
            ) from err
        try:
            if operator_obj == PythonOperator:
                # Both entries are read below, so reject the config as soon as
                # either one is missing instead of failing on a bare KeyError.
                if not task_params.get(
                    "python_callable_name"
                ) or not task_params.get("python_callable_file"):
                    raise Exception(
                        "Failed to create task. PythonOperator requires "
                        "`python_callable_name` and `python_callable_file` "
                        "parameters."
                    )
                task_params["python_callable"] = utils.get_python_callable(
                    task_params["python_callable_name"],
                    task_params["python_callable_file"],
                )
            task: BaseOperator = operator_obj(**task_params)
        except Exception as err:
            raise Exception(
                f"Failed to create {operator_obj} task. err: {err}"
            ) from err
        return task
Exemplo n.º 3
0
def test_get_python_callable_valid():
    """``get_python_callable`` returns a callable for ``print_test`` looked up in this file."""
    this_file = os.path.realpath(__file__)

    resolved = utils.get_python_callable("print_test", this_file)

    assert callable(resolved)
Exemplo n.º 4
0
    def build(self) -> Dict[str, Union[str, DAG]]:
        """
        Generates a DAG from the DAG parameters.

        The parameters come from ``self.get_dag_params()``. Settings that are
        absent there (concurrency, max active runs, default view, orientation)
        are read from the Airflow configuration. ``doc_md`` can additionally be
        loaded from an absolute file path or produced by a python callable; the
        callable, when configured, overrides the file. One operator is then
        created per entry in ``tasks`` and the configured ``dependencies`` are
        wired as upstream tasks.

        :returns: dict with dag_id and DAG object
        :type: Dict[str, Union[str, DAG]]
        :raises Exception: if ``doc_md_file_path`` is set but not absolute
        """
        dag_params: Dict[str, Any] = self.get_dag_params()
        dag: DAG = DAG(
            dag_id=dag_params["dag_id"],
            schedule_interval=dag_params["schedule_interval"],
            description=dag_params.get("description", ""),
            concurrency=dag_params.get(
                "concurrency",
                configuration.conf.getint("core", "dag_concurrency"),
            ),
            max_active_runs=dag_params.get(
                "max_active_runs",
                configuration.conf.getint("core", "max_active_runs_per_dag"),
            ),
            dagrun_timeout=dag_params.get("dagrun_timeout", None),
            default_view=dag_params.get(
                "default_view",
                configuration.conf.get("webserver", "dag_default_view"),
            ),
            orientation=dag_params.get(
                "orientation",
                configuration.conf.get("webserver", "dag_orientation"),
            ),
            on_success_callback=dag_params.get("on_success_callback", None),
            on_failure_callback=dag_params.get("on_failure_callback", None),
            default_args=dag_params.get("default_args", {}),
            doc_md=dag_params.get("doc_md", None),
        )

        doc_md_file_path = dag_params.get("doc_md_file_path")
        if doc_md_file_path:
            if not os.path.isabs(doc_md_file_path):
                raise Exception("`doc_md_file_path` must be absolute path")

            # Read with an explicit encoding so the rendered docs do not
            # depend on the locale of the machine parsing the DAG.
            with open(doc_md_file_path, "r", encoding="utf-8") as file:
                dag.doc_md = file.read()

        if dag_params.get("doc_md_python_callable_file") and dag_params.get(
            "doc_md_python_callable_name"
        ):
            doc_md_callable = utils.get_python_callable(
                dag_params.get("doc_md_python_callable_name"),
                dag_params.get("doc_md_python_callable_file"),
            )
            dag.doc_md = doc_md_callable(
                **dag_params.get("doc_md_python_arguments", {})
            )

        # tags parameter introduced in Airflow 1.10.8
        if version.parse(AIRFLOW_VERSION) >= version.parse("1.10.8"):
            dag.tags = dag_params.get("tags", None)

        tasks: Dict[str, Dict[str, Any]] = dag_params["tasks"]

        # add a property to mark this dag as auto-generated
        dag.is_dagfactory_auto_generated = True

        # create dictionary to track tasks and set dependencies
        tasks_dict: Dict[str, BaseOperator] = {}
        for task_name, task_conf in tasks.items():
            task_conf["task_id"] = task_name
            operator: str = task_conf["operator"]
            task_conf["dag"] = dag
            # entries listed in SYSTEM_PARAMS are not forwarded to the operator
            params: Dict[str, Any] = {
                key: value
                for key, value in task_conf.items()
                if key not in SYSTEM_PARAMS
            }
            task: BaseOperator = DagBuilder.make_task(
                operator=operator, task_params=params
            )
            tasks_dict[task.task_id] = task

        # set task dependencies after creating tasks, so every referenced
        # task already exists in tasks_dict
        for task_name, task_conf in tasks.items():
            if task_conf.get("dependencies"):
                source_task: BaseOperator = tasks_dict[task_name]
                for dep in task_conf["dependencies"]:
                    source_task.set_upstream(tasks_dict[dep])

        return {"dag_id": dag_params["dag_id"], "dag": dag}
Exemplo n.º 5
0
    def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
        """
        Takes an operator and params and creates an instance of that operator.

        Before the operator is instantiated ``task_params`` is adjusted in place:

        * ``PythonOperator``: ``python_callable_name`` and
          ``python_callable_file`` are resolved with
          ``utils.get_python_callable`` into ``python_callable``
        * ``KubernetesPodOperator``: ``secrets``, ``ports``, ``volume_mounts``,
          ``volumes``, ``pod_runtime_info_envs``, ``full_pod_spec`` and
          ``init_containers`` are built from their dict form (``None`` when
          not provided)
        * ``execution_timeout_secs`` is replaced by ``execution_timeout``
          (``timedelta``)
        * ``variables_as_arguments``: each entry copies the Airflow Variable
          named by ``variable`` into the parameter named by ``attribute``,
          unless the Variable lookup returns ``None``

        :param operator: string passed to ``import_string`` to obtain the
            operator class
        :param task_params: keyword arguments passed to the operator
        :returns: instance of operator object
        :raises Exception: if the operator cannot be imported or the task
            cannot be created; the original error is attached as the cause
        """
        try:
            # class is a Callable https://stackoverflow.com/a/34578836/3679900
            operator_obj: Callable[..., BaseOperator] = import_string(operator)
        except Exception as err:
            # Only exception instances can be raised; raising a plain string
            # would fail with a TypeError and hide the real problem.
            raise Exception(f"Failed to import operator: {operator}") from err
        try:
            if operator_obj == PythonOperator:
                # Both entries are read below, so reject the config as soon as
                # either one is missing instead of failing on a bare KeyError.
                if not task_params.get(
                    "python_callable_name"
                ) or not task_params.get("python_callable_file"):
                    raise Exception(
                        "Failed to create task. PythonOperator requires "
                        "`python_callable_name` and `python_callable_file` "
                        "parameters."
                    )
                task_params["python_callable"] = utils.get_python_callable(
                    task_params["python_callable_name"],
                    task_params["python_callable_file"],
                )

            # KubernetesPodOperator
            if operator_obj == KubernetesPodOperator:
                secrets = task_params.get("secrets")
                task_params["secrets"] = (
                    [Secret(**v) for v in secrets]
                    if secrets is not None
                    else None
                )

                ports = task_params.get("ports")
                task_params["ports"] = (
                    [Port(**v) for v in ports] if ports is not None else None
                )

                volume_mounts = task_params.get("volume_mounts")
                task_params["volume_mounts"] = (
                    [VolumeMount(**v) for v in volume_mounts]
                    if volume_mounts is not None
                    else None
                )

                volumes = task_params.get("volumes")
                task_params["volumes"] = (
                    [Volume(**v) for v in volumes]
                    if volumes is not None
                    else None
                )

                runtime_envs = task_params.get("pod_runtime_info_envs")
                task_params["pod_runtime_info_envs"] = (
                    [PodRuntimeInfoEnv(**v) for v in runtime_envs]
                    if runtime_envs is not None
                    else None
                )

                full_pod_spec = task_params.get("full_pod_spec")
                task_params["full_pod_spec"] = (
                    V1Pod(**full_pod_spec)
                    if full_pod_spec is not None
                    else None
                )

                init_containers = task_params.get("init_containers")
                task_params["init_containers"] = (
                    [V1Container(**v) for v in init_containers]
                    if init_containers is not None
                    else None
                )

            if utils.check_dict_key(task_params, "execution_timeout_secs"):
                task_params["execution_timeout"] = timedelta(
                    seconds=task_params["execution_timeout_secs"]
                )
                del task_params["execution_timeout_secs"]

            # use variables as arguments on operator
            if utils.check_dict_key(task_params, "variables_as_arguments"):
                variables: List[Dict[str, str]] = task_params.get(
                    "variables_as_arguments"
                )
                for variable in variables:
                    # Look the Variable up once and reuse the value.
                    value = Variable.get(variable["variable"], default_var=None)
                    if value is not None:
                        task_params[variable["attribute"]] = value
                del task_params["variables_as_arguments"]

            task: BaseOperator = operator_obj(**task_params)
        except Exception as err:
            raise Exception(f"Failed to create {operator_obj} task") from err
        return task
Exemplo n.º 6
0
def test_get_python_callable_invalid_path():
    """``get_python_callable`` raises when given the path ``/not/absolute/path``."""
    bogus_file = "/not/absolute/path"

    with pytest.raises(Exception):
        utils.get_python_callable("print_test", bogus_file)
Exemplo n.º 7
0
    def get_dag_params(self) -> Dict[str, Any]:
        """
        Merge the DAG config with the default config and normalise the result.

        Besides setting ``dag_id`` to ``self.dag_name`` this:

        * rejects ``task_groups`` on Airflow versions below 2.0.0
        * maps the string ``"None"`` for ``schedule_interval`` to ``None``
        * replaces ``dagrun_timeout_sec`` and ``default_args["retry_delay_sec"]``
          with ``timedelta`` values under ``dagrun_timeout`` / ``retry_delay``
        * passes ``default_args["end_date"]`` and ``default_args["start_date"]``
          through ``utils.get_datetime`` with ``default_args["timezone"]``
          (``"UTC"`` when absent)
        * imports ``sla_miss_callback``, ``on_success_callback`` and
          ``on_failure_callback`` given as strings, both in ``default_args``
          and at DAG level
        * resolves ``on_success_callback_name``/``_file`` and
          ``on_failure_callback_name``/``_file`` pairs with
          ``utils.get_python_callable``

        :returns: dict of dag parameters
        """
        try:
            merged: Dict[str, Any] = utils.merge_configs(
                self.dag_config, self.default_config
            )
        except Exception as err:
            raise Exception(
                "Failed to merge config with default config") from err
        merged["dag_id"] = self.dag_name

        if merged.get("task_groups") and version.parse(
            AIRFLOW_VERSION
        ) < version.parse("2.0.0"):
            raise Exception(
                "`task_groups` key can only be used with Airflow 2.x.x")

        if (
            utils.check_dict_key(merged, "schedule_interval")
            and merged["schedule_interval"] == "None"
        ):
            merged["schedule_interval"] = None

        # 'dagrun_timeout_sec: int' -> 'dagrun_timeout: timedelta'
        if utils.check_dict_key(merged, "dagrun_timeout_sec"):
            merged["dagrun_timeout"] = timedelta(
                seconds=merged["dagrun_timeout_sec"]
            )
            del merged["dagrun_timeout_sec"]

        default_args = merged["default_args"]

        # 'end_date: Union[str, datetime, date]' -> 'end_date: datetime'
        if utils.check_dict_key(default_args, "end_date"):
            default_args["end_date"] = utils.get_datetime(
                date_value=default_args["end_date"],
                timezone=default_args.get("timezone", "UTC"),
            )

        # 'retry_delay_sec: int' -> 'retry_delay: timedelta'
        if utils.check_dict_key(default_args, "retry_delay_sec"):
            default_args["retry_delay"] = timedelta(
                seconds=default_args["retry_delay_sec"]
            )
            del default_args["retry_delay_sec"]

        # Callbacks written as strings are imported, first inside
        # default_args and then at DAG level.
        importable_callbacks = (
            "sla_miss_callback",
            "on_success_callback",
            "on_failure_callback",
        )
        for scope in (default_args, merged):
            for callback_key in importable_callbacks:
                if utils.check_dict_key(scope, callback_key) and isinstance(
                    scope[callback_key], str
                ):
                    scope[callback_key] = import_string(scope[callback_key])

        # A name/file pair is only resolved when both keys are set.
        file_based_callbacks = (
            (
                "on_success_callback",
                "on_success_callback_name",
                "on_success_callback_file",
            ),
            (
                "on_failure_callback",
                "on_failure_callback_name",
                "on_failure_callback_file",
            ),
        )
        for target_key, name_key, file_key in file_based_callbacks:
            if utils.check_dict_key(merged, name_key) and utils.check_dict_key(
                merged, file_key
            ):
                merged[target_key] = utils.get_python_callable(
                    merged[name_key],
                    merged[file_key],
                )

        try:
            # default_args must end up with a "start_date" converted by
            # utils.get_datetime in the configured timezone
            default_args["start_date"] = utils.get_datetime(
                date_value=default_args["start_date"],
                timezone=default_args.get("timezone", "UTC"),
            )
        except KeyError as err:
            raise Exception(
                f"{self.dag_name} config is missing start_date") from err
        return merged
Exemplo n.º 8
0
    def build(self) -> Dict[str, Union[str, DAG]]:
        """
        Create the DAG and its tasks from the DAG parameters.

        The keyword arguments for ``DAG`` are assembled from
        ``self.get_dag_params()``, with several fallbacks read from the Airflow
        configuration. Some keys depend on the running Airflow version
        (``max_active_tasks`` from 2.2.0, ``concurrency`` before that).
        ``doc_md`` may also come from an absolute file path or from a python
        callable; the callable, when configured, overrides the file. Task
        groups are created with ``self.make_task_groups``, one operator is
        built per entry in ``tasks``, and dependencies are wired by
        ``self.set_dependencies``.

        :returns: dict with dag_id and DAG object
        :type: Dict[str, Union[str, DAG]]
        """
        params: Dict[str, Any] = self.get_dag_params()
        conf = configuration.conf

        kwargs: Dict[str, Any] = {
            "dag_id": params["dag_id"],
            "schedule_interval": params.get(
                "schedule_interval", timedelta(days=1)
            ),
        }

        airflow_version = version.parse(AIRFLOW_VERSION)

        # the default description depends on the Airflow version
        if airflow_version >= version.parse("1.10.11"):
            kwargs["description"] = params.get("description", None)
        else:
            kwargs["description"] = params.get("description", "")

        # Airflow 2.2.0 and later take max_active_tasks instead of concurrency
        if airflow_version >= version.parse("2.2.0"):
            kwargs["max_active_tasks"] = params.get(
                "max_active_tasks",
                conf.getint("core", "max_active_tasks_per_dag"),
            )
        else:
            kwargs["concurrency"] = params.get(
                "concurrency", conf.getint("core", "dag_concurrency")
            )

        kwargs["catchup"] = params.get(
            "catchup", conf.getboolean("scheduler", "catchup_by_default")
        )
        kwargs["max_active_runs"] = params.get(
            "max_active_runs", conf.getint("core", "max_active_runs_per_dag")
        )
        kwargs["dagrun_timeout"] = params.get("dagrun_timeout", None)
        kwargs["default_view"] = params.get(
            "default_view", conf.get("webserver", "dag_default_view")
        )
        kwargs["orientation"] = params.get(
            "orientation", conf.get("webserver", "dag_orientation")
        )
        kwargs["sla_miss_callback"] = params.get("sla_miss_callback", None)
        kwargs["on_success_callback"] = params.get("on_success_callback", None)
        kwargs["on_failure_callback"] = params.get("on_failure_callback", None)
        kwargs["default_args"] = params.get("default_args", None)
        kwargs["doc_md"] = params.get("doc_md", None)

        dag: DAG = DAG(**kwargs)

        doc_md_path = params.get("doc_md_file_path")
        if doc_md_path:
            if not os.path.isabs(doc_md_path):
                raise Exception("`doc_md_file_path` must be absolute path")

            with open(doc_md_path, "r", encoding="utf-8") as handle:
                dag.doc_md = handle.read()

        callable_file = params.get("doc_md_python_callable_file")
        callable_name = params.get("doc_md_python_callable_name")
        if callable_file and callable_name:
            doc_md_callable = utils.get_python_callable(
                callable_name, callable_file
            )
            dag.doc_md = doc_md_callable(
                **params.get("doc_md_python_arguments", {})
            )

        # tags parameter introduced in Airflow 1.10.8
        if airflow_version >= version.parse("1.10.8"):
            dag.tags = params.get("tags", None)

        tasks: Dict[str, Dict[str, Any]] = params["tasks"]

        # flag the DAG as auto-generated
        dag.is_dagfactory_auto_generated = True

        # task groups, keyed by name
        task_groups_dict: Dict[str, "TaskGroup"] = self.make_task_groups(
            params.get("task_groups", {}), dag
        )

        # build every operator and remember it by task_id
        tasks_dict: Dict[str, BaseOperator] = {}
        for task_name, task_conf in tasks.items():
            task_conf["task_id"] = task_name
            operator: str = task_conf["operator"]
            task_conf["dag"] = dag
            # attach the task to its task group, if one is named
            if task_groups_dict and task_conf.get("task_group_name"):
                group_name = task_conf.get("task_group_name")
                task_conf["task_group"] = task_groups_dict[group_name]
            # entries listed in SYSTEM_PARAMS are not forwarded to the operator
            operator_kwargs: Dict[str, Any] = {
                key: value
                for key, value in task_conf.items()
                if key not in SYSTEM_PARAMS
            }
            task: BaseOperator = DagBuilder.make_task(
                operator=operator, task_params=operator_kwargs
            )
            tasks_dict[task.task_id] = task

        # dependencies are set once all tasks exist
        self.set_dependencies(
            tasks,
            tasks_dict,
            params.get("task_groups", {}),
            task_groups_dict,
        )

        return {"dag_id": params["dag_id"], "dag": dag}
Exemplo n.º 9
0
    def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
        """
        Takes an operator and params and creates an instance of that operator.

        Before the operator is instantiated ``task_params`` is adjusted in
        place; helper-only keys are removed after use:

        * ``PythonOperator`` / ``BranchPythonOperator``:
          ``python_callable_name`` and ``python_callable_file`` are resolved
          into ``python_callable``
        * ``SqlSensor``: optional ``success`` / ``failure`` callables come
          from ``*_check_name`` + ``*_check_file``, or else from
          ``*_check_lambda``
        * ``HttpSensor``: ``response_check`` comes from
          ``response_check_name`` + ``response_check_file``, or else from
          ``response_check_lambda``; one of the two forms is required
        * ``KubernetesPodOperator``: ``secrets``, ``ports``, ``volume_mounts``,
          ``volumes``, ``pod_runtime_info_envs``, ``full_pod_spec`` and
          ``init_containers`` are built from their dict form (``None`` when
          not provided)
        * ``execution_timeout_secs``, ``sla_secs`` and ``execution_delta_secs``
          are replaced by ``timedelta`` values under ``execution_timeout``,
          ``sla`` and ``execution_delta``
        * ``execution_date_fn_name`` + ``execution_date_fn_file`` are resolved
          into ``execution_date_fn``
        * ``on_execute_callback`` (Airflow 2.0.0 and later only),
          ``on_failure_callback``, ``on_success_callback`` and
          ``on_retry_callback`` are passed through ``import_string``
        * ``variables_as_arguments``: each entry copies the Airflow Variable
          named by ``variable`` into the parameter named by ``attribute``,
          unless the Variable lookup returns ``None``

        :param operator: string passed to ``import_string`` to obtain the
            operator class
        :param task_params: keyword arguments passed to the operator
        :returns: instance of operator object
        :raises Exception: if the operator cannot be imported or the task
            cannot be created; the original error is attached as the cause
        """
        try:
            # class is a Callable https://stackoverflow.com/a/34578836/3679900
            operator_obj: Callable[..., BaseOperator] = import_string(operator)
        except Exception as err:
            raise Exception(f"Failed to import operator: {operator}") from err
        try:
            if operator_obj in [PythonOperator, BranchPythonOperator]:
                # Both entries are read below, so reject the config as soon as
                # either one is missing instead of failing on a bare KeyError.
                if not task_params.get(
                    "python_callable_name"
                ) or not task_params.get("python_callable_file"):
                    raise Exception(
                        "Failed to create task. PythonOperator and "
                        "BranchPythonOperator requires `python_callable_name` "
                        "and `python_callable_file` parameters."
                    )
                task_params["python_callable"] = utils.get_python_callable(
                    task_params["python_callable_name"],
                    task_params["python_callable_file"],
                )
                # remove dag-factory specific parameters
                # Airflow 2.0 doesn't allow these to be passed to operator
                del task_params["python_callable_name"]
                del task_params["python_callable_file"]

            # Custom success and failure callables for SqlSensor are optional,
            # so nothing is raised when they are absent. If both a callable
            # file and a lambda are configured for the same check, the callable
            # file takes precedence and the lambda entry is left untouched.
            if operator_obj in [SqlSensor]:
                sql_checks = (
                    (
                        "success",
                        "success_check_name",
                        "success_check_file",
                        "success_check_lambda",
                    ),
                    (
                        "failure",
                        "failure_check_name",
                        "failure_check_file",
                        "failure_check_lambda",
                    ),
                )
                for target_key, name_key, file_key, lambda_key in sql_checks:
                    if task_params.get(file_key) and task_params.get(name_key):
                        task_params[target_key] = utils.get_python_callable(
                            task_params[name_key],
                            task_params[file_key],
                        )
                        del task_params[name_key]
                        del task_params[file_key]
                    elif task_params.get(lambda_key):
                        task_params[
                            target_key
                        ] = utils.get_python_callable_lambda(
                            task_params[lambda_key]
                        )
                        del task_params[lambda_key]

            if operator_obj in [HttpSensor]:
                if not (
                    task_params.get("response_check_name")
                    and task_params.get("response_check_file")
                ) and not task_params.get("response_check_lambda"):
                    raise Exception(
                        "Failed to create task. HttpSensor requires "
                        "`response_check_name` and `response_check_file` "
                        "parameters or `response_check_lambda` parameter."
                    )
                if task_params.get("response_check_file"):
                    task_params["response_check"] = utils.get_python_callable(
                        task_params["response_check_name"],
                        task_params["response_check_file"],
                    )
                    # remove dag-factory specific parameters
                    # Airflow 2.0 doesn't allow these to be passed to operator
                    del task_params["response_check_name"]
                    del task_params["response_check_file"]
                else:
                    task_params[
                        "response_check"
                    ] = utils.get_python_callable_lambda(
                        task_params["response_check_lambda"]
                    )
                    # remove dag-factory specific parameters
                    # Airflow 2.0 doesn't allow these to be passed to operator
                    del task_params["response_check_lambda"]

            # KubernetesPodOperator
            if operator_obj == KubernetesPodOperator:
                secrets = task_params.get("secrets")
                task_params["secrets"] = (
                    [Secret(**v) for v in secrets]
                    if secrets is not None
                    else None
                )

                ports = task_params.get("ports")
                task_params["ports"] = (
                    [Port(**v) for v in ports] if ports is not None else None
                )

                volume_mounts = task_params.get("volume_mounts")
                task_params["volume_mounts"] = (
                    [VolumeMount(**v) for v in volume_mounts]
                    if volume_mounts is not None
                    else None
                )

                volumes = task_params.get("volumes")
                task_params["volumes"] = (
                    [Volume(**v) for v in volumes]
                    if volumes is not None
                    else None
                )

                runtime_envs = task_params.get("pod_runtime_info_envs")
                task_params["pod_runtime_info_envs"] = (
                    [PodRuntimeInfoEnv(**v) for v in runtime_envs]
                    if runtime_envs is not None
                    else None
                )

                full_pod_spec = task_params.get("full_pod_spec")
                task_params["full_pod_spec"] = (
                    V1Pod(**full_pod_spec)
                    if full_pod_spec is not None
                    else None
                )

                init_containers = task_params.get("init_containers")
                task_params["init_containers"] = (
                    [V1Container(**v) for v in init_containers]
                    if init_containers is not None
                    else None
                )

            # '<name>_secs: int' -> '<name>: timedelta'
            seconds_keys = (
                ("execution_timeout_secs", "execution_timeout"),
                ("sla_secs", "sla"),
                ("execution_delta_secs", "execution_delta"),
            )
            for secs_key, delta_key in seconds_keys:
                if utils.check_dict_key(task_params, secs_key):
                    task_params[delta_key] = timedelta(
                        seconds=task_params[secs_key]
                    )
                    del task_params[secs_key]

            if utils.check_dict_key(
                task_params, "execution_date_fn_name"
            ) and utils.check_dict_key(task_params, "execution_date_fn_file"):
                task_params["execution_date_fn"] = utils.get_python_callable(
                    task_params["execution_date_fn_name"],
                    task_params["execution_date_fn_file"],
                )
                del task_params["execution_date_fn_name"]
                del task_params["execution_date_fn_file"]

            # on_execute_callback is an Airflow 2.0 feature
            if utils.check_dict_key(
                task_params, "on_execute_callback"
            ) and version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
                task_params["on_execute_callback"] = import_string(
                    task_params["on_execute_callback"]
                )

            for callback_key in (
                "on_failure_callback",
                "on_success_callback",
                "on_retry_callback",
            ):
                if utils.check_dict_key(task_params, callback_key):
                    task_params[callback_key] = import_string(
                        task_params[callback_key]
                    )

            # use variables as arguments on operator
            if utils.check_dict_key(task_params, "variables_as_arguments"):
                variables: List[Dict[str, str]] = task_params.get(
                    "variables_as_arguments"
                )
                for variable in variables:
                    # Look the Variable up once and reuse the value.
                    value = Variable.get(variable["variable"], default_var=None)
                    if value is not None:
                        task_params[variable["attribute"]] = value
                del task_params["variables_as_arguments"]

            task: BaseOperator = operator_obj(**task_params)
        except Exception as err:
            raise Exception(f"Failed to create {operator_obj} task") from err
        return task
Exemplo n.º 10
0
    def make_task(operator: str, task_params: Dict[str, Any],
                  af_vars: Dict[str, Any]) -> BaseOperator:
        """
        Takes an operator and params and creates an instance of that operator.

        dag-factory specific keys found in ``task_params`` (``*_name`` /
        ``*_file`` pairs, ``*_lambda``, ``*_secs``, ``variables_as_arguments``
        and ``af_vars_as_arguments``) are converted into the keyword arguments
        that are handed to the operator and are then removed from the dict.
        ``task_params`` is modified in place; for EcsOperator tasks the nested
        ``overrides`` / ``network_configuration`` structures are updated in
        place as well.

        :param operator: import path of the operator class, as understood by
            ``import_string``
        :param task_params: keyword arguments for the operator, possibly
            containing the dag-factory specific keys listed above
        :param af_vars: mapping consulted for the ``ECS_SECURITY_GROUPS``,
            ``ECS_SUBNETS`` and ``ECS_CLUSTER`` settings of EcsOperator tasks
            and for the ``af_vars_as_arguments`` lookups
        :returns: instance of operator object
        :raises Exception: if the operator cannot be imported, or if preparing
            the parameters or instantiating the operator fails (the original
            error is chained as the cause)
        """
        try:
            # class is a Callable https://stackoverflow.com/a/34578836/3679900
            operator_obj: Callable[..., BaseOperator] = import_string(operator)
        except Exception as err:
            raise Exception(f"Failed to import operator: {operator}") from err

        def callable_from_file(name_key: str, file_key: str) -> Callable:
            """Load the callable described by a name/file pair of params."""
            loaded = utils.get_python_callable(task_params[name_key],
                                               task_params[file_key])
            # remove dag-factory specific parameters
            # Airflow 2.0 doesn't allow these to be passed to operator
            del task_params[name_key]
            del task_params[file_key]
            return loaded

        def callable_from_lambda(lambda_key: str) -> Callable:
            """Build the callable described by a lambda param, then drop it."""
            loaded = utils.get_python_callable_lambda(task_params[lambda_key])
            # remove dag-factory specific parameters
            # Airflow 2.0 doesn't allow these to be passed to operator
            del task_params[lambda_key]
            return loaded

        def secs_to_timedelta(secs_key: str, target_key: str) -> None:
            """Replace a ``*_secs`` param with the matching timedelta param."""
            if utils.check_dict_key(task_params, secs_key):
                task_params[target_key] = timedelta(
                    seconds=task_params[secs_key])
                del task_params[secs_key]

        try:
            if operator_obj in [
                    PythonOperator, BranchPythonOperator, PythonSensor
            ]:
                has_callable = bool(task_params.get("python_callable"))
                has_name_and_file = bool(
                    task_params.get("python_callable_name")
                    and task_params.get("python_callable_file"))
                # Reject a half-specified name/file pair here as well, so the
                # user gets the explanatory message instead of a KeyError.
                if not has_callable and not has_name_and_file:
                    raise Exception(
                        "Failed to create task. PythonOperator, "
                        "BranchPythonOperator and PythonSensor requires "
                        "`python_callable_name` and `python_callable_file` "
                        "parameters.\nOptionally you can load python_callable "
                        "from a file. with the special pyyaml notation:\n"
                        "  python_callable_file: !!python/name:my_module.my_func"
                    )
                if not has_callable:
                    task_params["python_callable"] = callable_from_file(
                        "python_callable_name", "python_callable_file")

            # Check for the custom success and failure callables in SqlSensor.
            # These are considered optional, so no failures in case they
            # aren't found. Note: there's no reason to declare both a callable
            # file and a lambda function for success/failure parameter. If
            # both are found the object will not throw an error, instead the
            # callable file will take precedence over the lambda function.
            if operator_obj == SqlSensor:
                for outcome in ("success", "failure"):
                    name_key = f"{outcome}_check_name"
                    file_key = f"{outcome}_check_file"
                    lambda_key = f"{outcome}_check_lambda"
                    if task_params.get(file_key) and task_params.get(name_key):
                        task_params[outcome] = callable_from_file(
                            name_key, file_key)
                    elif task_params.get(lambda_key):
                        task_params[outcome] = callable_from_lambda(lambda_key)

            if operator_obj == HttpSensor:
                has_check_pair = (task_params.get("response_check_name")
                                  and task_params.get("response_check_file"))
                if (not has_check_pair
                        and not task_params.get("response_check_lambda")):
                    raise Exception(
                        "Failed to create task. HttpSensor requires "
                        "`response_check_name` and `response_check_file` "
                        "parameters or `response_check_lambda` parameter.")
                if task_params.get("response_check_file"):
                    task_params["response_check"] = callable_from_file(
                        "response_check_name", "response_check_file")
                else:
                    task_params["response_check"] = callable_from_lambda(
                        "response_check_lambda")

            # KubernetesPodOperator: turn plain config mappings into the
            # objects the operator expects; absent entries become None.
            if operator_obj == KubernetesPodOperator:
                secrets = task_params.get("secrets")
                task_params["secrets"] = (None if secrets is None else
                                          [Secret(**v) for v in secrets])

                ports = task_params.get("ports")
                task_params["ports"] = (None if ports is None else
                                        [Port(**v) for v in ports])

                volume_mounts = task_params.get("volume_mounts")
                task_params["volume_mounts"] = (
                    None if volume_mounts is None else
                    [VolumeMount(**v) for v in volume_mounts])

                volumes = task_params.get("volumes")
                task_params["volumes"] = (None if volumes is None else
                                          [Volume(**v) for v in volumes])

                runtime_envs = task_params.get("pod_runtime_info_envs")
                task_params["pod_runtime_info_envs"] = (
                    None if runtime_envs is None else
                    [PodRuntimeInfoEnv(**v) for v in runtime_envs])

                full_pod_spec = task_params.get("full_pod_spec")
                task_params["full_pod_spec"] = (
                    None if full_pod_spec is None else V1Pod(**full_pod_spec))

                init_containers = task_params.get("init_containers")
                task_params["init_containers"] = (
                    None if init_containers is None else
                    [V1Container(**v) for v in init_containers])

            if operator_obj == DockerOperator:
                if task_params.get("environment") is not None:
                    # A value naming a set OS environment variable is replaced
                    # by that variable's content; otherwise it is kept as is.
                    task_params["environment"] = {
                        k: os.environ.get(v, v)
                        for k, v in task_params["environment"].items()
                    }

            if operator_obj == EcsOperator:
                for container in task_params["overrides"][
                        "containerOverrides"]:
                    if container.get("environment") is not None:
                        for env in container["environment"]:
                            # Same substitution rule as for DockerOperator.
                            env["value"] = os.environ.get(
                                env["value"], env["value"])

                if ("ECS_SECURITY_GROUPS" in af_vars
                        and "network_configuration" in task_params):
                    vpc_config = task_params["network_configuration"][
                        "awsvpcConfiguration"]
                    vpc_config["securityGroups"] = af_vars[
                        "ECS_SECURITY_GROUPS"]

                if ("ECS_SUBNETS" in af_vars
                        and "network_configuration" in task_params):
                    vpc_config = task_params["network_configuration"][
                        "awsvpcConfiguration"]
                    vpc_config["subnets"] = af_vars["ECS_SUBNETS"]

                cluster = af_vars.get("ECS_CLUSTER")
                if cluster:
                    task_params["cluster"] = cluster
                    # Task definition gets a "<cluster>_" prefix and the log
                    # group a "/<cluster>" suffix.
                    task_params["task_definition"] = (
                        cluster + "_" + task_params["task_definition"]).lower()
                    task_params["awslogs_group"] = (
                        task_params["awslogs_group"] + "/" + cluster.lower())

            secs_to_timedelta("execution_timeout_secs", "execution_timeout")
            secs_to_timedelta("sla_secs", "sla")
            secs_to_timedelta("execution_delta_secs", "execution_delta")

            if utils.check_dict_key(
                    task_params,
                    "execution_date_fn_name") and utils.check_dict_key(
                        task_params, "execution_date_fn_file"):
                task_params["execution_date_fn"] = callable_from_file(
                    "execution_date_fn_name", "execution_date_fn_file")

            # on_execute_callback is an Airflow 2.0 feature
            if utils.check_dict_key(
                    task_params, "on_execute_callback"
            ) and version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"):
                task_params["on_execute_callback"] = import_string(
                    task_params["on_execute_callback"])

            # The remaining callbacks are given as import paths.
            for callback_key in ("on_failure_callback", "on_success_callback",
                                 "on_retry_callback"):
                if utils.check_dict_key(task_params, callback_key):
                    task_params[callback_key] = import_string(
                        task_params[callback_key])

            # use Airflow Variables as arguments on operator
            if utils.check_dict_key(task_params, "variables_as_arguments"):
                for variable in task_params.get("variables_as_arguments"):
                    # Single lookup: the value that is checked is the value
                    # that gets assigned.
                    value = Variable.get(variable["variable"],
                                         default_var=None)
                    if value is not None:
                        task_params[variable["attribute"]] = value
                del task_params["variables_as_arguments"]

            # use entries of af_vars as arguments on operator
            if utils.check_dict_key(task_params, "af_vars_as_arguments"):
                for variable in task_params.get("af_vars_as_arguments"):
                    value = af_vars.get(variable["variable"], None)
                    if value is not None:
                        task_params[variable["attribute"]] = value
                del task_params["af_vars_as_arguments"]

            task: BaseOperator = operator_obj(**task_params)
        except Exception as err:
            raise Exception(f"Failed to create {operator_obj} task") from err
        return task
Exemplo n.º 11
0
def test_get_python_callable_missing_param_name():
    """``utils.get_python_callable`` must raise when the callable name is None."""
    missing_name = None
    callable_file = "/not/absolute/path"

    with pytest.raises(Exception):
        utils.get_python_callable(missing_name, callable_file)
Exemplo n.º 12
0
def test_get_python_callable_missing_param_file():
    """``utils.get_python_callable`` must raise when the callable file is None."""
    callable_name = "print_test"
    missing_file = None

    with pytest.raises(Exception):
        utils.get_python_callable(callable_name, missing_file)