Example #1
0
    def base_init(
            self,
            source_conn_id,
            dwh_engine,
            dwh_conn_id,
            extract_strategy,
            load_strategy,
            target_table_name,
            target_schema_name,
            target_schema_suffix="_next",
            target_database_name=None,  # Only for Snowflake
            load_data_from=None,  # set a minimum date e.g. for reloading of data
            reload_data_from=None,  # load data from this date for new tables
            load_data_from_relative=None,  # optional timedelta for incremental
            load_data_until=None,  # set a maximum date
            load_data_until_relative=None,  # optional timedelta for incremental
            load_data_chunking_timedelta=None,  # optional timedelta to chunk by
            primary_key=None,  # either string or list of strings
            include_columns=None,  # list of columns to explicitly include
            exclude_columns=None,  # list of columns to exclude
            index_columns=None,  # list of columns to create an index on. can be
            # an expression, must be quoted in list if quoting is required.
            hash_columns=None,  # str or list of str - columns to hash pre-upload
            hash_salt=None,  # string salt part for hashing
            wait_for_seconds=120,  # seconds past data_interval_end to wait until
            # wait_for_seconds only applies for incremental loads
            add_metadata=True,
            rename_columns: Optional[Dict[str, str]] = None,  # Rename columns
            subsequent_field=None,  # field name to use for subsequent extract strategy
            default_timezone=None,  # specify a default time zone for tz-naive datetimes
            use_temp_pickling=True,  # use new upload method if True - use it by default
            pickling_upload_chunk_size=100000,  # default chunk size for pickle upload
            pickle_compression=None,  # data compression algorithm to use for pickles
            default_values=None,  # dict with default values for columns (to avoid nulls)
            cleaner_class=EWAHCleaner,
            cleaner_callables=None,  # callables or list of callables to run during cleaning
            uploader_class=None,  # Future: deprecate dwh_engine and use this kwarg instead
            additional_uploader_kwargs=None,
            deduplication_before_upload=False,
            *args,
            **kwargs):
        """Validate the operator configuration and store it on the instance.

        Remaining ``*args``/``**kwargs`` are forwarded to the parent class'
        ``__init__``. Single strings given for ``primary_key``,
        ``hash_columns``, ``include_columns`` and ``exclude_columns`` are
        wrapped into one-element lists. A single callable given as
        ``cleaner_callables`` is wrapped into a list. ``uploader_class``
        defaults to ``get_uploader(dwh_engine)``.

        ``index_columns`` defaults to a fresh empty list per instance; an
        explicit ``None`` is also normalized to ``[]``.

        Raises:
            AssertionError: if a type/range check fails. Examples are a
                non-dict ``default_values``/``rename_columns``, an unknown
                ``pickle_compression``, a negative ``wait_for_seconds``, a
                load strategy that does not fit the extract strategy, or an
                extract strategy this operator does not accept.
            Exception: for an unknown extract or load strategy, an upsert
                without ``primary_key``, a primary key missing from
                ``include_columns``, combined include/exclude columns, an
                unknown ``dwh_engine``, index columns on a non-PostgreSQL
                DWH, or ``target_database_name`` on a non-Snowflake DWH.
        """
        super().__init__(*args, **kwargs)

        # A literal [] default would be a single list shared by every
        # operator instance (and stored on self below); create a fresh one.
        if index_columns is None:
            index_columns = []

        if default_values:
            assert isinstance(default_values, dict)

        assert pickle_compression is None or pickle_compression in (
            "gzip",
            "bz2",
            "lzma",
        )

        assert isinstance(rename_columns, (type(None), dict))

        if default_timezone:
            assert dwh_engine in (EC.DWH_ENGINE_POSTGRES,
                                  )  # Only for PostgreSQL so far
            assert ";" not in default_timezone  # Avoid SQL Injection

        # Check if the extract and load strategies are allowed in combination
        # Also check if required params are supplied for the strategies
        if extract_strategy == EC.ES_SUBSEQUENT:
            assert subsequent_field
            assert load_strategy in (
                EC.LS_UPSERT,
                EC.LS_INSERT_ADD,
                EC.LS_INSERT_REPLACE,
            )
            # insert_delete makes no sense in the subsequent context
        elif extract_strategy == EC.ES_FULL_REFRESH:
            assert load_strategy in (EC.LS_INSERT_REPLACE, EC.LS_INSERT_ADD)
            # upsert makes no sense - it's a full refresh!
            # insert_delete makes no sense - what to delete? Well, everything!
        elif extract_strategy == EC.ES_INCREMENTAL:
            assert load_strategy in (EC.LS_UPSERT, EC.LS_INSERT_ADD)
            # replace makes no sense - it's incremental loading after all!
            # insert_delete makes sense but is not yet implemented
        else:
            raise Exception(
                "Invalid extract_strategy {0}!".format(extract_strategy))

        if load_strategy == EC.LS_UPSERT:
            # upserts require a (composite) primary key of some sort
            if not primary_key:
                raise Exception(
                    "If the load strategy is upsert, name of the primary"
                    " key(s) (primary_key) is required!")
        elif load_strategy in (EC.LS_INSERT_ADD, EC.LS_INSERT_REPLACE):
            pass  # No requirements
        else:
            raise Exception("Invalid load_strategy {0}!".format(load_strategy))

        # Make sure a child class did not overwrite template_fields
        _msg = "Operator must not overwrite template_fields!"
        for item in EWAHBaseOperator.template_fields:
            assert item in self.template_fields, _msg

        _msg = 'param "wait_for_seconds" must be a nonnegative integer!'
        assert isinstance(wait_for_seconds,
                          int) and wait_for_seconds >= 0, _msg
        _msg = "extract_strategy {0} not accepted for this operator!".format(
            extract_strategy, )
        assert self._ACCEPTED_EXTRACT_STRATEGIES.get(extract_strategy), _msg

        if isinstance(primary_key, str):
            primary_key = [primary_key]

        if isinstance(hash_columns, str):
            hash_columns = [hash_columns]

        if isinstance(include_columns, str):
            include_columns = [include_columns]

        if include_columns and primary_key:
            for col in primary_key:
                if col not in include_columns:
                    _msg = """
                        Primary key {0} is not in the include_columns list.
                        Make sure all primary keys are included.
                        """.format(col)
                    raise Exception(_msg)

        if exclude_columns and isinstance(exclude_columns, str):
            exclude_columns = [exclude_columns]

        if include_columns and exclude_columns:
            _msg = "Don't use include and exclude columns config at the same time!"
            raise Exception(_msg)

        if not dwh_engine or dwh_engine not in EC.DWH_ENGINES:
            _msg = "Invalid DWH Engine: {0}\n\nAccepted Engines:\n\t{1}".format(
                str(dwh_engine),
                "\n\t".join(EC.DWH_ENGINES),
            )
            raise Exception(_msg)

        if index_columns and dwh_engine != EC.DWH_ENGINE_POSTGRES:
            raise Exception("Indices are only allowed for PostgreSQL DWHs!")

        if dwh_engine != EC.DWH_ENGINE_SNOWFLAKE and target_database_name:
            raise Exception('Received argument for "target_database_name"!')

        _msg = "load_data_from_relative and load_data_until_relative must be"
        _msg += " timedelta if supplied!"
        assert isinstance(
            load_data_from_relative,
            (type(None), timedelta),
        ), _msg
        assert isinstance(
            load_data_until_relative,
            (type(None), timedelta),
        ), _msg
        _msg = "load_data_chunking_timedelta must be timedelta!"
        assert isinstance(
            load_data_chunking_timedelta,
            (type(None), timedelta),
        ), _msg

        if callable(cleaner_callables):
            cleaner_callables = [cleaner_callables]

        self.source_conn_id = source_conn_id
        self.dwh_engine = dwh_engine
        self.dwh_conn_id = dwh_conn_id
        self.extract_strategy = extract_strategy
        self.load_strategy = load_strategy
        self.target_table_name = target_table_name
        self.target_schema_name = target_schema_name
        self.target_schema_suffix = target_schema_suffix
        self.target_database_name = target_database_name
        self.load_data_from = load_data_from
        self.reload_data_from = reload_data_from
        self.load_data_from_relative = load_data_from_relative
        self.load_data_until = load_data_until
        self.load_data_until_relative = load_data_until_relative
        self.load_data_chunking_timedelta = load_data_chunking_timedelta
        # primary_key may be used by a child class at execution!
        self.primary_key = primary_key
        self.include_columns = include_columns
        self.exclude_columns = exclude_columns
        self.index_columns = index_columns
        self.hash_columns = hash_columns
        self.hash_salt = hash_salt
        self.wait_for_seconds = wait_for_seconds
        self.add_metadata = add_metadata
        self.rename_columns = rename_columns
        self.subsequent_field = subsequent_field
        self.default_timezone = default_timezone
        self.use_temp_pickling = use_temp_pickling
        self.pickling_upload_chunk_size = pickling_upload_chunk_size
        self.pickle_compression = pickle_compression
        self.default_values = default_values
        self.cleaner_class = cleaner_class
        self.cleaner_callables = cleaner_callables
        self.deduplication_before_upload = deduplication_before_upload

        self.uploader_class = uploader_class or get_uploader(self.dwh_engine)
        self.additional_uploader_kwargs = additional_uploader_kwargs or {}
def dag_factory_idempotent(
    dag_name: str,
    dwh_engine: str,
    dwh_conn_id: str,
    start_date: datetime,
    el_operator: Type[EWAHBaseOperator],
    operator_config: dict,
    target_schema_name: str,
    target_schema_suffix: str = "_next",
    target_database_name: Optional[str] = None,
    default_args: Optional[dict] = None,
    schedule_interval_backfill: timedelta = timedelta(days=1),
    schedule_interval_future: timedelta = timedelta(hours=1),
    end_date: Optional[datetime] = None,
    read_right_users: Optional[Union[List[str], str]] = None,
    additional_dag_args: Optional[dict] = None,
    additional_task_args: Optional[dict] = None,
    logging_func: Optional[Callable] = None,
    dagrun_timeout_factor: Optional[float] = None,
    task_timeout_factor: Optional[float] = 0.8,
    **kwargs,
) -> Tuple[DAG, DAG, DAG]:
    """Returns a tuple of three DAGs associated with incremental data loading.

    The three DAGs are returned in this order:
    - Current DAG (index 0, named "{dag_name}_Idempotent")
    - Backfill DAG (index 1, named "{dag_name}_Idempotent_Backfill")
    - Reset DAG (index 2, named "{dag_name}_Idempotent_Reset")

    The Reset DAG has no schedule. It pauses the other two DAGs, deletes them
    from Airflow via ``airflow dags delete`` and drops the target schema
    (and its suffixed loading schema) from the DWH.

    The Backfill DAG runs in a long schedule interval (e.g. a week) from
    start_date on. Each run of this DAG fetches a relatively long period worth
    of data. The purpose is to Backfill the DWH. Only tables with the
    incremental extract strategy are loaded by the Backfill DAG.

    The Current DAG runs in a short schedule interval (e.g. one hour). It has a
    dynamic start date which is the end of the last full schedule interval
    period of the backfill DAG. This DAG keeps the data in the DWH fresh.

    Each of the Current and Backfill DAGs waits for the final task of its own
    previous run before starting a new one.

    :param dag_name: Base name of the DAG. The returned DAGs will be named
        after dag_name with the suffixes "_Idempotent_Reset",
        "_Idempotent_Backfill", and "_Idempotent".
    :param dwh_engine: Type of the DWH (e.g. postgresql).
    :param dwh_conn_id: Airflow connection ID with DWH credentials.
    :param start_date: Start date of the DAGs (i.e. of the Backfill DAG).
        Must be timezone aware.
    :param el_operator: A subclass of EWAHBaseOperator that is used to load
        the individual tables.
    :param operator_config: Dictionary with a required "tables" key mapping
        table names to their (possibly empty or None) config dicts, and an
        optional "general_config" dict whose entries apply to every table.
        A table config's "extract_strategy" defaults to incremental.
    :param target_schema_name: Name of the schema in the DWH that receives the data.
    :param target_schema_suffix: Suffix used during data loading process. The DAG
        creates a new schema "{target_schema_name}{target_schema_suffix}" during
        loading.
    :param target_database_name: Name of the database (Snowflake) or dataset
        (BigQuery), if applicable.
    :param default_args: A dictionary given to the DAGs as default_args param.
    :param schedule_interval_backfill: The schedule interval of the Backfill
        DAG. Must be at least 1 day and not smaller than
        schedule_interval_future.
    :param schedule_interval_future: The schedule interval of the Current DAG.
        Must not be larger than schedule_interval_backfill. It is recommended
        not to go below 30 minutes. An appropriate schedule interval can be
        found via trial and error. The Current DAG runtime must be less than
        this param in order for EWAH to work properly.
    :param end_date: Airflow DAG kwarg end_date.
    :param read_right_users: List of strings of users or roles that should
        receive read rights on the loaded tables. Can also be a comma-separated
        string instead of a list of strings.
    :param additional_dag_args: kwargs applied to the DAG. Can be any DAG
        kwarg that is not used directly within the function. The dict passed
        in is not modified.
    :param additional_task_args: kwargs applied to the tasks. Can be any Task
        kwarg, although some may be overwritten by the function. The dict
        passed in is not modified.
    :param logging_func: Pass a callable for logging output. Defaults to print.
    :param dagrun_timeout_factor: Optional number in (0, 1]. If set, each
        DAG's dagrun_timeout is this factor times that DAG's schedule
        interval, unless additional_dag_args contains "dagrun_timeout"
        (default: None, i.e. no factor-based timeout).
    :param task_timeout_factor: If set, each task's execution_timeout is this
        factor times its DAG's schedule interval, unless additional_task_args
        contains "execution_timeout" (default: 0.8). The sensors that wait
        for the previous run receive no execution_timeout from this factor.
    :returns: Tuple (Current DAG, Backfill DAG, Reset DAG).
    :raises Exception: on invalid schedule intervals, missing "tables",
        invalid read_right_users, or a timezone-naive start_date.
    :raises AssertionError: on an invalid dagrun_timeout_factor or an
        unsupported extract strategy in a table config.
    """

    def raise_exception(msg: str) -> None:
        """Add information to error message before raising."""
        raise Exception("DAG: {0} - Error: {1}".format(dag_name, msg))

    logging_func = logging_func or print

    if kwargs:
        logging_func("unused config: {0}".format(str(kwargs)))

    # Work on copies: keys are popped/deleted below, and the caller's dicts
    # may be shared between several DAG factory calls.
    additional_dag_args = dict(additional_dag_args or {})
    additional_task_args = dict(additional_task_args or {})

    if not isinstance(schedule_interval_future, timedelta):
        raise_exception("Schedule intervals must be datetime.timedelta!")
    if not isinstance(schedule_interval_backfill, timedelta):
        raise_exception("Schedule intervals must be datetime.timedelta!")
    if schedule_interval_backfill < timedelta(days=1):
        raise_exception("Backfill schedule interval cannot be below 1 day!")
    if schedule_interval_backfill < schedule_interval_future:
        raise_exception(
            "Backfill schedule interval must be larger than"
            + " regular schedule interval!"
        )
    if not operator_config.get("tables"):
        raise_exception('Requires a "tables" dictionary in operator_config!')
    if read_right_users is not None:
        if isinstance(read_right_users, str):
            read_right_users = [u.strip() for u in read_right_users.split(",")]
        if not isinstance(read_right_users, Iterable):
            raise_exception("read_right_users must be an iterable or string!")

    current_time = datetime_utcnow_with_tz()
    if not start_date.tzinfo:
        raise_exception("start_date must be timezone aware!")

    # Make switch halfway between latest normal DAG run and the
    #   data_interval_end of the next-to-run backfill DAG
    #   --> no interruption of the system, airflow has time to register
    #   the change, the backfill DAG can run once unimpeded and the
    #   normal DAG can then resume as per normal. Note: in that case,
    #   keep both DAGs active!
    current_time += schedule_interval_future / 2
    # How much time has passed in total between start_date and now?
    switch_absolute_date = current_time - start_date
    # How often could the backfill DAG run in that time frame?
    switch_absolute_date /= schedule_interval_backfill
    switch_absolute_date = int(switch_absolute_date)
    # What is the exact datetime after the last of those runs?
    switch_absolute_date *= schedule_interval_backfill
    switch_absolute_date += start_date
    # --> switch_absolute_date is always in the (recent) past

    if end_date:
        backfill_end_date = min(switch_absolute_date, end_date)
    else:
        backfill_end_date = switch_absolute_date

    # .get() for the Current DAG, then .pop() for the Backfill DAG so that the
    # key is not passed a second time via **additional_dag_args below.
    if dagrun_timeout_factor:
        _msg = "dagrun_timeout_factor must be a number between 0 and 1!"
        assert isinstance(dagrun_timeout_factor, (int, float)) and (
            0 < dagrun_timeout_factor <= 1
        ), _msg
        dagrun_timeout = additional_dag_args.get(
            "dagrun_timeout", dagrun_timeout_factor * schedule_interval_future
        )
        # Each DAG's timeout scales with its own schedule interval.
        dagrun_timeout_backfill = additional_dag_args.pop(
            "dagrun_timeout", dagrun_timeout_factor * schedule_interval_backfill
        )
    else:
        dagrun_timeout = additional_dag_args.get("dagrun_timeout")
        dagrun_timeout_backfill = additional_dag_args.pop("dagrun_timeout", None)

    if task_timeout_factor:
        execution_timeout = additional_task_args.get(
            "execution_timeout", task_timeout_factor * schedule_interval_future
        )
        execution_timeout_backfill = additional_task_args.pop(
            "execution_timeout", task_timeout_factor * schedule_interval_backfill
        )
    else:
        execution_timeout = additional_task_args.get("execution_timeout")
        execution_timeout_backfill = additional_task_args.pop("execution_timeout", None)

    dags = (
        DAG(  # Current DAG
            dag_name + "_Idempotent",
            start_date=switch_absolute_date,
            end_date=end_date,
            schedule_interval=schedule_interval_future,
            catchup=True,
            max_active_runs=1,
            default_args=default_args,
            dagrun_timeout=dagrun_timeout,
            **additional_dag_args,
        ),
        DAG(  # Backfill DAG
            dag_name + "_Idempotent_Backfill",
            start_date=start_date,
            end_date=backfill_end_date,
            schedule_interval=schedule_interval_backfill,
            catchup=True,
            max_active_runs=1,
            default_args=default_args,
            dagrun_timeout=dagrun_timeout_backfill,
            **additional_dag_args,
        ),
        DAG(  # Reset DAG
            dag_name + "_Idempotent_Reset",
            start_date=start_date,
            end_date=end_date,
            schedule_interval=None,
            catchup=False,
            max_active_runs=1,
            default_args=default_args,
            **additional_dag_args,
        ),
    )

    # Create reset DAG
    reset_bash_command = " && ".join(  # First pause DAGs, then delete their metadata
        [
            "airflow dags pause {dag_name}_Idempotent",
            "airflow dags pause {dag_name}_Idempotent_Backfill",
            "airflow dags delete {dag_name}_Idempotent -y",
            "airflow dags delete {dag_name}_Idempotent_Backfill -y",
        ]
    ).format(dag_name=dag_name)
    reset_task = BashOperator(
        bash_command=reset_bash_command,
        task_id="reset_by_deleting_all_task_instances",
        dag=dags[2],
        **additional_task_args,
    )
    drop_sql = """
        DROP SCHEMA IF EXISTS "{target_schema_name}" CASCADE;
        DROP SCHEMA IF EXISTS "{target_schema_name}{suffix}" CASCADE;
    """.format(
        target_schema_name=target_schema_name,
        suffix=target_schema_suffix,
    )
    if dwh_engine == EC.DWH_ENGINE_POSTGRES:
        drop_task = PGO(
            sql=drop_sql,
            postgres_conn_id=dwh_conn_id,
            task_id="delete_previous_schema_if_exists",
            dag=dags[2],
            **additional_task_args,
        )
    elif dwh_engine == EC.DWH_ENGINE_SNOWFLAKE:
        drop_task = SnowflakeOperator(
            sql=drop_sql,
            snowflake_conn_id=dwh_conn_id,
            database=target_database_name,
            task_id="delete_previous_schema_if_exists",
            dag=dags[2],
            **additional_task_args,
        )
    else:
        # Any other engine is treated as BigQuery (backtick-quoted names)
        drop_sql = """
            DROP SCHEMA IF EXISTS `{0}` CASCADE;
            DROP SCHEMA IF EXISTS `{1}` CASCADE;
        """.format(
            target_schema_name,
            target_schema_name + target_schema_suffix,
        )
        drop_task = BigqueryOperator(
            sql=drop_sql,
            bigquery_conn_id=dwh_conn_id,
            project=target_database_name,
            task_id="delete_previous_schema_if_exists",
            dag=dags[2],
            **additional_task_args,
        )

    reset_task >> drop_task

    # Incremental DAG schema tasks
    kickoff, final = get_uploader(dwh_engine).get_schema_tasks(
        dag=dags[0],
        dwh_engine=dwh_engine,
        target_schema_name=target_schema_name,
        target_schema_suffix=target_schema_suffix,
        target_database_name=target_database_name,
        dwh_conn_id=dwh_conn_id,
        read_right_users=read_right_users,
        execution_timeout=execution_timeout,
        **additional_task_args,
    )

    # Backfill DAG schema tasks
    kickoff_backfill, final_backfill = get_uploader(dwh_engine).get_schema_tasks(
        dag=dags[1],
        dwh_engine=dwh_engine,
        target_schema_name=target_schema_name,
        target_schema_suffix=target_schema_suffix,
        target_database_name=target_database_name,
        dwh_conn_id=dwh_conn_id,
        read_right_users=read_right_users,
        execution_timeout=execution_timeout_backfill,
        **additional_task_args,
    )

    # add table creation tasks
    arg_dict = deepcopy(additional_task_args)
    arg_dict.update(operator_config.get("general_config") or {})
    # Default reload_data_from to start_date
    arg_dict.setdefault("reload_data_from", start_date)
    for table, table_config in operator_config["tables"].items():
        # A table may be listed without any config (None)
        table_config = table_config or {}
        table_kwargs = deepcopy(arg_dict)
        table_kwargs.update(table_config)

        # Overwrite / ignore changes to these kwargs:
        table_kwargs.update(
            {
                # the value can be given as None in the table config
                "extract_strategy": (
                    table_kwargs.get("extract_strategy") or EC.ES_INCREMENTAL
                ),
                "task_id": "extract_load_" + re.sub(r"[^a-zA-Z0-9_]", "", table),
                "dwh_engine": dwh_engine,
                "dwh_conn_id": dwh_conn_id,
                "target_table_name": table_config.get("target_table_name", table),
                "target_schema_name": target_schema_name,
                "target_schema_suffix": target_schema_suffix,
                "target_database_name": target_database_name,
            }
        )
        assert table_kwargs["extract_strategy"] in (
            EC.ES_FULL_REFRESH,
            EC.ES_SUBSEQUENT,
            EC.ES_INCREMENTAL,
        )
        table_kwargs["load_strategy"] = table_kwargs.get(
            "load_strategy",
            EC.DEFAULT_LS_PER_ES[table_kwargs["extract_strategy"]],
        )

        if table_kwargs["extract_strategy"] == EC.ES_INCREMENTAL:
            # Backfill ignores non-incremental extract strategy types
            task_backfill = el_operator(
                dag=dags[1], execution_timeout=execution_timeout_backfill, **table_kwargs
            )
            kickoff_backfill >> task_backfill >> final_backfill

        task = el_operator(dag=dags[0], execution_timeout=execution_timeout, **table_kwargs)
        kickoff >> task >> final

    # For the unlikely case that there is no incremental task
    kickoff_backfill >> final_backfill

    # Make sure incremental loading stops if there is an error!
    if additional_task_args.get("task_timeout_factor"):
        # sensors shall have no timeouts!
        del additional_task_args["task_timeout_factor"]
    ets = (
        ExtendedETS(
            task_id="sense_previous_instance",
            allowed_states=["success", "skipped"],
            external_dag_id=dags[0].dag_id,
            external_task_id=final.task_id,
            execution_delta=schedule_interval_future,
            backfill_dag_id=dags[1].dag_id,
            backfill_external_task_id=final_backfill.task_id,
            backfill_execution_delta=schedule_interval_backfill,
            dag=dags[0],
            poke_interval=5 * 60,
            mode="reschedule",  # don't block a worker and pool slot
            **additional_task_args,
        ),
        ExtendedETS(
            task_id="sense_previous_instance",
            allowed_states=["success", "skipped"],
            external_dag_id=dags[1].dag_id,
            external_task_id=final_backfill.task_id,
            execution_delta=schedule_interval_backfill,
            dag=dags[1],
            poke_interval=5 * 60,
            mode="reschedule",  # don't block a worker and pool slot
            **additional_task_args,
        ),
    )
    ets[0] >> kickoff
    ets[1] >> kickoff_backfill

    return dags
def dag_factory_atomic(dag_name: str,
                       dwh_engine: str,
                       dwh_conn_id: str,
                       start_date: datetime,
                       el_operator: Type[EWAHBaseOperator],
                       operator_config: dict,
                       target_schema_name: str,
                       target_schema_suffix: str = "_next",
                       target_database_name: Optional[str] = None,
                       default_args: Optional[dict] = None,
                       schedule_interval: Union[str,
                                                timedelta] = timedelta(days=1),
                       end_date: Optional[datetime] = None,
                       read_right_users: Optional[Union[List[str],
                                                        str]] = None,
                       additional_dag_args: Optional[dict] = None,
                       additional_task_args: Optional[dict] = None,
                       logging_func: Optional[Callable] = None,
                       dagrun_timeout_factor: Optional[float] = None,
                       task_timeout_factor: Optional[float] = None,
                       **kwargs) -> Tuple[DAG]:
    """Return a one-element tuple with a DAG that loads all tables atomically.

    Every table listed in ``operator_config["tables"]`` gets one
    ``el_operator`` task placed between the kickoff and final schema tasks of
    the DWH uploader. Only the full refresh and subsequent extract strategies
    are allowed; a table's extract strategy defaults to full refresh.

    :param dag_name: Name of the DAG.
    :param dwh_engine: Type of the DWH (e.g. postgresql).
    :param dwh_conn_id: Airflow connection ID with DWH credentials.
    :param start_date: Start date of the DAG. For a timedelta
        schedule_interval it is shifted forward so that only the most recent
        interval is scheduled (see the comments in the body).
    :param el_operator: A subclass of EWAHBaseOperator used to load tables.
    :param operator_config: Dictionary with a "tables" key mapping table
        names to their (possibly empty or None) config dicts, and an optional
        "general_config" dict whose entries apply to every table.
    :param target_schema_name: Name of the schema in the DWH that receives
        the data.
    :param target_schema_suffix: Suffix of the schema used while loading.
    :param target_database_name: Name of the database (Snowflake) or dataset
        (BigQuery), if applicable.
    :param default_args: A dictionary given to the DAG as default_args param.
    :param schedule_interval: timedelta, or a cron string (catchup disabled).
    :param end_date: Airflow DAG kwarg end_date.
    :param read_right_users: List of users or roles that should receive read
        rights, or a comma-separated string of them.
    :param additional_dag_args: kwargs applied to the DAG. Not modified.
    :param additional_task_args: kwargs applied to the tasks. Not modified.
    :param logging_func: Callable for logging output. Defaults to print.
    :param dagrun_timeout_factor: Optional number in (0, 1]. If set,
        dagrun_timeout defaults to this factor times schedule_interval.
    :param task_timeout_factor: If set, execution_timeout defaults to this
        factor times schedule_interval.
    :returns: Tuple containing only the DAG.
    :raises Exception: if "tables" is missing, read_right_users is invalid,
        or a timeout factor is needed with a cron-string schedule_interval.
    :raises AssertionError: on an invalid schedule_interval, an invalid
        dagrun_timeout_factor, or an unsupported table extract strategy.
    """
    def raise_exception(msg: str) -> None:
        """Add information to error message before raising."""
        raise Exception("DAG: {0} - Error: {1}".format(dag_name, msg))

    def scaled_interval(factor: float, factor_name: str) -> timedelta:
        """Return factor * schedule_interval; needs a timedelta interval."""
        if not isinstance(schedule_interval, timedelta):
            raise_exception(
                "{0} requires a timedelta schedule_interval! Set the timeout"
                " explicitly when using a cron schedule.".format(factor_name))
        return factor * schedule_interval

    logging_func = logging_func or print

    if kwargs:
        logging_func("unused config: {0}".format(str(kwargs)))

    # Work on copies: timeouts are written into these dicts below, and the
    # caller's dicts may be shared between several DAG factory calls.
    additional_dag_args = dict(additional_dag_args or {})
    additional_task_args = dict(additional_task_args or {})

    if operator_config.get("tables") is None:
        raise_exception('Requires a "tables" dictionary in operator_config!')

    if read_right_users is not None:
        if isinstance(read_right_users, str):
            read_right_users = [u.strip() for u in read_right_users.split(",")]
        if not isinstance(read_right_users, Iterable):
            raise_exception("read_right_users must be an iterable or string!")

    if isinstance(schedule_interval, str):
        # Allow using cron-style schedule intervals
        catchup = False
        assert croniter.is_valid(
            schedule_interval
        ), "schedule_interval is neither timedelta nor valid cron!"
    else:
        assert isinstance(
            schedule_interval,
            timedelta), "schedule_interval must be cron-string or timedelta!"

        catchup = True
        # fake catchup = True: between start_date and end_date is one schedule_interval
        # --> run the full refreshes every schedule_interval at the same time instead
        # of having a drift in execution time!
        if end_date:
            end_date = min(end_date, datetime_utcnow_with_tz())
        else:
            end_date = datetime_utcnow_with_tz()

        start_date += (int((end_date - start_date) / schedule_interval) -
                       1) * schedule_interval

        # case 1: end_date = start_date + schedule_interval
        # if the division result is a precise integer, that implies a definite end_date
        # --> adjust to get exactly one schedule_interval delta between start_date and
        # end_date to have one last run available (that should have run before end_date)

        # case 2: end_date > (start_date + schedule_interval)
        # Airflow executes after data_interval_end - start_date has to be
        # between exactly 1 and below 2 time schedule_interval before end_date!
        # end_date - 2*schedule_interval < start_date <= end_date - schedule_interval

        # Make sure only one scheduled execution runs, while manual triggers work!
        end_date = start_date + 2 * schedule_interval - timedelta(seconds=1)

    if dagrun_timeout_factor:
        _msg = "dagrun_timeout_factor must be a number between 0 and 1!"
        assert isinstance(
            dagrun_timeout_factor,
            (int, float)) and (0 < dagrun_timeout_factor <= 1), _msg
        # An explicitly configured timeout wins over the factor
        if "dagrun_timeout" not in additional_dag_args:
            additional_dag_args["dagrun_timeout"] = scaled_interval(
                dagrun_timeout_factor, "dagrun_timeout_factor")

    if task_timeout_factor:
        # An explicitly configured timeout wins over the factor
        if "execution_timeout" not in additional_task_args:
            additional_task_args["execution_timeout"] = scaled_interval(
                task_timeout_factor, "task_timeout_factor")

    dag = DAG(
        dag_name,
        catchup=catchup,
        default_args=default_args,
        max_active_runs=1,
        schedule_interval=schedule_interval,
        start_date=start_date,
        end_date=end_date,
        **additional_dag_args,
    )

    kickoff, final = get_uploader(dwh_engine).get_schema_tasks(
        dag=dag,
        dwh_engine=dwh_engine,
        dwh_conn_id=dwh_conn_id,
        target_schema_name=target_schema_name,
        target_schema_suffix=target_schema_suffix,
        target_database_name=target_database_name,
        read_right_users=read_right_users,
        **additional_task_args,
    )

    base_config = deepcopy(additional_task_args)
    base_config.update(operator_config.get("general_config") or {})
    with dag:
        for table, table_source_config in operator_config["tables"].items():
            # A table may be listed without any config (None)
            table_source_config = table_source_config or {}
            table_config = deepcopy(base_config)
            table_config.update(table_source_config)
            table_config.update({
                "task_id":
                "extract_load_" + re.sub(r"[^a-zA-Z0-9_]", "", table),
                "dwh_engine":
                dwh_engine,
                "dwh_conn_id":
                dwh_conn_id,
                # Default to full refresh; the value can be None in table conf
                "extract_strategy":
                table_config.get("extract_strategy") or EC.ES_FULL_REFRESH,
                "target_table_name":
                table_source_config.get("target_table_name", table),
                "target_schema_name":
                target_schema_name,
                "target_schema_suffix":
                target_schema_suffix,
                "target_database_name":
                target_database_name,
            })
            # Atomic DAG only works with full refresh and subsequent strategies!
            assert table_config["extract_strategy"] in (
                EC.ES_FULL_REFRESH,
                EC.ES_SUBSEQUENT,
            )
            table_config["load_strategy"] = table_config.get(
                "load_strategy",
                EC.DEFAULT_LS_PER_ES[table_config["extract_strategy"]],
            )
            table_task = el_operator(**table_config)
            kickoff >> table_task >> final

    return (dag, )
Example #4
0
    def __init__(
            self,
            source_conn_id,
            dwh_engine,
            dwh_conn_id,
            extract_strategy,
            target_table_name,
            target_schema_name,
            target_schema_suffix="_next",
            target_database_name=None,  # Only for Snowflake
            load_data_from=None,  # set a minimum date e.g. for reloading of data
            reload_data_from=None,  # load data from this date for new tables
            load_data_from_relative=None,  # optional timedelta for incremental
            load_data_until=None,  # set a maximum date
            load_data_until_relative=None,  # optional timedelta for incremental
            load_data_chunking_timedelta=None,  # optional timedelta to chunk by
            columns_definition=None,
            update_on_columns=None,
            primary_key_column_name=None,
            clean_data_before_upload=True,
            exclude_columns=None,  # list of columns to exclude, if no
            # columns_definition was supplied (e.g. for select * with sql)
            index_columns=None,  # list of columns to create an index on. can be
            # an expression, must be quoted in list if quoting is required.
            hash_columns=None,  # str or list of str - columns to hash pre-upload
            hashlib_func_name="sha256",  # specify hashlib hashing function
            wait_for_seconds=120,  # seconds past next_execution_date to wait until
            # wait_for_seconds only applies for incremental loads
            add_metadata=True,
            rename_columns: Optional[Dict[str, str]] = None,  # Rename columns
            *args,
            **kwargs):
        """Validate the extract/load configuration and store it on the operator.

        Remaining positional/keyword arguments are forwarded to the parent
        operator's ``__init__``.

        ``exclude_columns`` and ``index_columns`` default to a fresh empty
        list per instance (``None`` is treated as "no columns").

        If ``update_on_columns`` is not given but ``primary_key_column_name``
        is (a str, or a list/tuple of str), it is used as
        ``update_on_columns``.

        Raises:
            AssertionError: for invalid ``rename_columns``, overwritten
                ``template_fields``, a negative/non-int ``wait_for_seconds``,
                an extract strategy not in ``_ACCEPTED_EXTRACT_STRATEGIES``,
                an unknown ``hashlib_func_name``, or non-timedelta relative /
                chunking arguments.
            Exception: for mutually exclusive or missing arguments (hashing
                without data cleaning, columns_definition together with
                exclude_columns, unknown DWH engine, indices on non-PostgreSQL
                DWHs, target_database_name on non-Snowflake DWHs, missing
                columns_definition where required, primary_key_column_name
                together with update_on_columns, or an incremental load
                without any way to identify rows).
        """
        super().__init__(*args, **kwargs)

        # Fresh lists per instance: a shared mutable default would be stored
        # on self and thus shared (and mutable) across all operator instances.
        exclude_columns = [] if exclude_columns is None else exclude_columns
        index_columns = [] if index_columns is None else index_columns

        assert not (rename_columns and columns_definition)
        assert rename_columns is None or isinstance(rename_columns, dict)

        for item in EWAHBaseOperator.template_fields:
            # Make sure template_fields was not overwritten
            _msg = "Operator must not overwrite template_fields!"
            assert item in self.template_fields, _msg

        _msg = 'param "wait_for_seconds" must be a nonnegative integer!'
        assert isinstance(wait_for_seconds,
                          int) and wait_for_seconds >= 0, _msg
        _msg = "extract_strategy {0} not accepted for this operator!".format(
            extract_strategy, )
        assert self._ACCEPTED_EXTRACT_STRATEGIES.get(extract_strategy), _msg

        if hash_columns and not clean_data_before_upload:
            _msg = "column hashing is only possible with data cleaning!"
            raise Exception(_msg)
        elif isinstance(hash_columns, str):
            # Normalize a single column name to a one-element list
            hash_columns = [hash_columns]
        if hashlib_func_name:
            _msg = "Invalid hashing function: hashlib.{0}()"
            _msg = _msg.format(hashlib_func_name)
            assert hasattr(hashlib, hashlib_func_name), _msg

        if columns_definition and exclude_columns:
            raise Exception("Must not supply both columns_definition and " +
                            "exclude_columns!")

        if not dwh_engine or dwh_engine not in EC.DWH_ENGINES:
            _msg = "Invalid DWH Engine: {0}\n\nAccepted Engines:\n\t{1}".format(
                str(dwh_engine),
                "\n\t".join(EC.DWH_ENGINES),
            )
            raise Exception(_msg)

        if index_columns and dwh_engine != EC.DWH_ENGINE_POSTGRES:
            raise Exception("Indices are only allowed for PostgreSQL DWHs!")

        if dwh_engine != EC.DWH_ENGINE_SNOWFLAKE and target_database_name:
            raise Exception('Received argument for "target_database_name"!')

        if self._REQUIRES_COLUMNS_DEFINITION:
            if not columns_definition:
                raise Exception("This operator requires the argument " +
                                "columns_definition!")

        if primary_key_column_name and update_on_columns:
            raise Exception("Cannot supply BOTH primary_key_column_name AND" +
                            " update_on_columns!")

        if extract_strategy != EC.ES_FULL_REFRESH:
            # Required settings for incremental loads
            # Update condition for new load strategies as required
            has_pk_in_definition = bool(columns_definition) and any(
                bool(col_def.get(EC.QBC_FIELD_PK))
                for col_def in columns_definition.values())
            if not (update_on_columns or primary_key_column_name
                    or has_pk_in_definition):
                raise Exception(
                    "If this is incremental loading of a table, " +
                    "one of the following is required:" +
                    "\n- List of columns to update on (update_on_columns)" +
                    "\n- Name of the primary key (primary_key_column_name)" +
                    "\n- Column definition (columns_definition) that includes"
                    + " the primary key(s)")

        _msg = "load_data_from_relative and load_data_until_relative must be"
        _msg += " timedelta if supplied!"
        assert isinstance(
            load_data_from_relative,
            (type(None), timedelta),
        ), _msg
        assert isinstance(
            load_data_until_relative,
            (type(None), timedelta),
        ), _msg
        _msg = "load_data_chunking_timedelta must be timedelta!"
        assert isinstance(
            load_data_chunking_timedelta,
            (type(None), timedelta),
        ), _msg

        self.source_conn_id = source_conn_id
        self.dwh_engine = dwh_engine
        self.dwh_conn_id = dwh_conn_id
        self.extract_strategy = extract_strategy
        self.target_table_name = target_table_name
        self.target_schema_name = target_schema_name
        self.target_schema_suffix = target_schema_suffix
        self.target_database_name = target_database_name
        self.load_data_from = load_data_from
        self.reload_data_from = reload_data_from
        self.load_data_from_relative = load_data_from_relative
        self.load_data_until = load_data_until
        self.load_data_until_relative = load_data_until_relative
        self.load_data_chunking_timedelta = load_data_chunking_timedelta
        self.columns_definition = columns_definition
        # Fall back to the primary key(s) as the columns to update on
        if (not update_on_columns) and primary_key_column_name:
            if isinstance(primary_key_column_name, str):
                update_on_columns = [primary_key_column_name]
            elif isinstance(primary_key_column_name, (list, tuple)):
                update_on_columns = primary_key_column_name
        self.update_on_columns = update_on_columns
        self.clean_data_before_upload = clean_data_before_upload
        self.primary_key_column_name = primary_key_column_name  # may be used ...
        #   ... by a child class at execution!
        self.exclude_columns = exclude_columns
        self.index_columns = index_columns
        self.hash_columns = hash_columns
        self.hashlib_func_name = hashlib_func_name
        self.wait_for_seconds = wait_for_seconds
        self.add_metadata = add_metadata
        self.rename_columns = rename_columns

        self.uploader = get_uploader(self.dwh_engine)

        # NOTE(review): _msg is not used in the visible code; the excerpt
        # appears truncated here — the original likely asserts on it next.
        _msg = "DWH hook does not support extract strategy {0}!".format(
            extract_strategy, )
Exemple #5
0
def dag_factory_mixed(
    dag_name: str,
    dwh_engine: str,
    dwh_conn_id: str,
    airflow_conn_id: str,
    start_date: datetime,
    el_operator: Type[EWAHBaseOperator],
    operator_config: dict,
    target_schema_name: str,
    target_schema_suffix: str = "_next",
    target_database_name: Optional[str] = None,
    default_args: Optional[dict] = None,
    schedule_interval_full_refresh: timedelta = timedelta(days=1),
    schedule_interval_incremental: timedelta = timedelta(hours=1),
    end_date: Optional[datetime] = None,
    read_right_users: Optional[Union[List[str], str]] = None,
    additional_dag_args: Optional[dict] = None,
    additional_task_args: Optional[dict] = None,
    logging_func: Optional[Callable] = None,
    dagrun_timeout_factor: Optional[float] = None,
    task_timeout_factor: Optional[float] = 0.8,
    **kwargs,
) -> Tuple[DAG, DAG, DAG]:
    """Create the DAGs for a mixed full refresh / incremental load.

    Returns a tuple of three DAGs:
        0. ``<dag_name>_Mixed_Atomic``: full refresh of every table in
           ``operator_config["tables"]`` on ``schedule_interval_full_refresh``.
        1. ``<dag_name>_Mixed_Idempotent``: incremental loads of the same
           tables on ``schedule_interval_incremental``.
        2. ``<dag_name>_Mixed_Reset``: unscheduled DAG that pauses and deletes
           the two DAGs above, then drops the target schemas.

    Unknown keyword arguments are reported via ``logging_func`` (default
    ``print``) and otherwise ignored.

    Raises:
        Exception: on invalid arguments (message prefixed with the DAG name),
            or if ``dwh_engine`` is neither PostgreSQL nor Snowflake.
        AssertionError: if a timeout factor is not in (0, 1].
    """

    def raise_exception(msg: str) -> None:
        """Add information to error message before raising."""
        raise Exception("DAG: {0} - Error: {1}".format(dag_name, msg))

    logging_func = logging_func or print

    if kwargs:
        logging_func("unused config: {0}".format(str(kwargs)))

    additional_dag_args = additional_dag_args or {}
    additional_task_args = additional_task_args or {}

    if read_right_users is not None:
        if isinstance(read_right_users, str):
            read_right_users = [u.strip() for u in read_right_users.split(",")]
        if not isinstance(read_right_users, Iterable):
            raise_exception("read_right_users must be an iterable or string!")
    if not isinstance(schedule_interval_full_refresh, timedelta):
        raise_exception("schedule_interval_full_refresh must be timedelta!")
    if not isinstance(schedule_interval_incremental, timedelta):
        raise_exception("schedule_interval_incremental must be timedelta!")
    if schedule_interval_incremental >= schedule_interval_full_refresh:
        _msg = "schedule_interval_incremental must be shorter than "
        _msg += "schedule_interval_full_refresh!"
        raise_exception(_msg)

    # Calculate the datetimes and timedeltas for the two DAGs.
    #
    # Full Refresh: The start_date should be chosen such that there is always
    # only one DAG execution to be executed at any given point in time.
    # See dag_factory_atomic for the same calculation including detailed
    # comments.
    #
    # The Incremental DAG starts at the start date + schedule interval of the
    # Full Refresh DAG, so that the Incremental executions only happen after
    # the Full Refresh execution.
    if not start_date.tzinfo:
        # Naive datetimes are rejected, not silently interpreted as UTC
        raise_exception("start_date must be timezone aware!")
    if end_date and not end_date.tzinfo:
        # Comparing a naive end_date with the aware time_now would otherwise
        # fail with an opaque TypeError inside min()
        raise_exception("end_date must be timezone aware!")
    time_now = datetime_utcnow_with_tz()

    if end_date:
        end_date = min(end_date, time_now)
    else:
        end_date = time_now

    if start_date <= time_now:
        # Move start_date forward by whole full refresh intervals, then step
        # back one interval (a start_date in the future is left untouched).
        start_date += (
            int((end_date - start_date) / schedule_interval_full_refresh)
            * schedule_interval_full_refresh
        )
        aligned_with_end = start_date == end_date
        start_date -= schedule_interval_full_refresh
        if not aligned_with_end:
            end_date = (
                start_date + 2 * schedule_interval_full_refresh - timedelta(seconds=1)
            )

    default_args = default_args or {}
    default_args_fr = deepcopy(default_args)
    default_args_inc = deepcopy(default_args)

    if dagrun_timeout_factor:
        _msg = "dagrun_timeout_factor must be a number between 0 and 1!"
        assert isinstance(dagrun_timeout_factor, (int, float)) and (
            0 < dagrun_timeout_factor <= 1
        ), _msg
        dagrun_timeout_inc = dagrun_timeout_factor * schedule_interval_incremental
        dagrun_timeout_fr = dagrun_timeout_factor * schedule_interval_full_refresh
    else:  # In case of 0 set to None
        dagrun_timeout_inc = None
        dagrun_timeout_fr = None

    if task_timeout_factor:
        _msg = "task_timeout_factor must be a number between 0 and 1!"
        assert isinstance(task_timeout_factor, (int, float)) and (
            0 < task_timeout_factor <= 1
        ), _msg
        execution_timeout_fr = task_timeout_factor * schedule_interval_full_refresh
        execution_timeout_inc = task_timeout_factor * schedule_interval_incremental
    else:
        execution_timeout_fr = None
        execution_timeout_inc = None

    dag_name_fr = dag_name + "_Mixed_Atomic"
    dag_name_inc = dag_name + "_Mixed_Idempotent"
    dags = (
        DAG(
            dag_name_fr,
            start_date=start_date,
            end_date=end_date,
            schedule_interval=schedule_interval_full_refresh,
            catchup=True,
            max_active_runs=1,
            default_args=default_args_fr,
            dagrun_timeout=dagrun_timeout_fr,
            **additional_dag_args,
        ),
        DAG(
            dag_name_inc,
            start_date=start_date + schedule_interval_full_refresh,
            end_date=start_date + 2 * schedule_interval_full_refresh,
            schedule_interval=schedule_interval_incremental,
            catchup=True,
            max_active_runs=1,
            default_args=default_args_inc,
            dagrun_timeout=dagrun_timeout_inc,
            **additional_dag_args,
        ),
        DAG(  # Reset DAG
            dag_name + "_Mixed_Reset",
            start_date=start_date,
            end_date=end_date,
            schedule_interval=None,
            catchup=False,
            max_active_runs=1,
            default_args=default_args,
            **additional_dag_args,
        ),
    )

    # Create reset DAG
    reset_bash_command = " && ".join(  # First pause DAGs, then delete their metadata
        [
            "airflow dags pause {dag_name}_Mixed_Atomic",
            "airflow dags pause {dag_name}_Mixed_Idempotent",
            "airflow dags delete {dag_name}_Mixed_Atomic -y",
            "airflow dags delete {dag_name}_Mixed_Idempotent -y",
        ]
    ).format(dag_name=dag_name)
    reset_task = BashOperator(
        bash_command=reset_bash_command,
        task_id="reset_by_deleting_all_task_instances",
        dag=dags[2],
        **additional_task_args,
    )
    drop_sql = """
        DROP SCHEMA IF EXISTS "{target_schema_name}" CASCADE;
        DROP SCHEMA IF EXISTS "{target_schema_name}{suffix}" CASCADE;
    """.format(
        target_schema_name=target_schema_name,
        suffix=target_schema_suffix,
    )
    if dwh_engine == EC.DWH_ENGINE_POSTGRES:
        drop_task = PGO(
            sql=drop_sql,
            postgres_conn_id=dwh_conn_id,
            task_id="delete_previous_schema_if_exists",
            dag=dags[2],
            **additional_task_args,
        )
    elif dwh_engine == EC.DWH_ENGINE_SNOWFLAKE:
        drop_task = SnowflakeOperator(
            sql=drop_sql,
            snowflake_conn_id=dwh_conn_id,
            database=target_database_name,
            task_id="delete_previous_schema_if_exists",
            dag=dags[2],
            **additional_task_args,
        )
    else:
        raise_exception(f'DWH "{dwh_engine}" not implemented for this task!')

    # Only drop the schemas once the load DAGs are paused and deleted, so a
    # still-active run cannot recreate or write into them concurrently.
    reset_task >> drop_task

    kickoff_fr, final_fr = get_uploader(dwh_engine).get_schema_tasks(
        dag=dags[0],
        dwh_engine=dwh_engine,
        dwh_conn_id=dwh_conn_id,
        target_schema_name=target_schema_name,
        target_schema_suffix=target_schema_suffix,
        target_database_name=target_database_name,
        read_right_users=read_right_users,
        execution_timeout=execution_timeout_fr,
        **additional_task_args,
    )

    kickoff_inc, final_inc = get_uploader(dwh_engine).get_schema_tasks(
        dag=dags[1],
        dwh_engine=dwh_engine,
        dwh_conn_id=dwh_conn_id,
        target_schema_name=target_schema_name,
        target_schema_suffix=target_schema_suffix,
        target_database_name=target_database_name,
        read_right_users=read_right_users,
        execution_timeout=execution_timeout_inc,
        **additional_task_args,
    )

    sql_fr = """
        SELECT
             -- only run if there are no active DAGs that have to finish first
            CASE WHEN COUNT(*) = 0 THEN 1 ELSE 0 END
        FROM public.dag_run
        WHERE state = 'running'
          AND (
                (dag_id = '{0}' AND data_interval_start < '{1}')
            OR  (dag_id = '{2}' AND data_interval_start < '{3}')
          )
    """.format(
        dags[0]._dag_id,  # fr
        "{{ data_interval_start }}",  # no previous full refresh, please!
        dags[1]._dag_id,  # inc
        "{{ data_interval_end }}",  # no old incremental running, please!
    )

    # Sense if a previous instance runs OR if any incremental loads run
    # except incremental load of the same time, which is expected and waits
    fr_snsr = EWAHSqlSensor(
        task_id="sense_run_validity",
        conn_id=airflow_conn_id,
        sql=sql_fr,
        dag=dags[0],
        poke_interval=5 * 60,
        mode="reschedule",  # don't block a worker and pool slot
        **additional_task_args,
    )

    # Sense if a previous instance is complete excepts if its the first, then
    # check for a full refresh of the same time
    inc_ets = ExtendedETS(
        task_id="sense_run_validity",
        allowed_states=["success"],
        external_dag_id=dags[1]._dag_id,
        external_task_id=final_inc.task_id,
        execution_delta=schedule_interval_incremental,
        backfill_dag_id=dags[0]._dag_id,
        backfill_external_task_id=final_fr.task_id,
        backfill_execution_delta=schedule_interval_full_refresh,
        dag=dags[1],
        poke_interval=5 * 60,
        mode="reschedule",  # don't block a worker and pool slot
        **additional_task_args,
    )

    fr_snsr >> kickoff_fr
    inc_ets >> kickoff_inc

    for table, table_conf in operator_config["tables"].items():
        arg_dict_inc = deepcopy(additional_task_args)
        arg_dict_inc.update(operator_config.get("general_config", {}))
        op_conf = table_conf or {}
        arg_dict_inc.update(op_conf)
        arg_dict_inc.update(
            {
                # `or`: the value can be given explicitly as None in the config
                "extract_strategy": arg_dict_inc.get("extract_strategy")
                or EC.ES_INCREMENTAL,
                "task_id": "extract_load_" + re.sub(r"[^a-zA-Z0-9_]", "", table),
                "dwh_engine": dwh_engine,
                "dwh_conn_id": dwh_conn_id,
                "target_table_name": op_conf.get("target_table_name", table),
                "target_schema_name": target_schema_name,
                "target_schema_suffix": target_schema_suffix,
                "target_database_name": target_database_name,
            }
        )
        arg_dict_inc["load_strategy"] = arg_dict_inc.get(
            "load_strategy"
        ) or EC.DEFAULT_LS_PER_ES[arg_dict_inc["extract_strategy"]]
        # The full refresh task shares the config but always replaces the data
        arg_dict_fr = deepcopy(arg_dict_inc)
        arg_dict_fr["extract_strategy"] = EC.ES_FULL_REFRESH
        arg_dict_fr["load_strategy"] = EC.LS_INSERT_REPLACE

        arg_dict_fr["execution_timeout"] = execution_timeout_fr
        arg_dict_inc["execution_timeout"] = execution_timeout_inc

        task_fr = el_operator(dag=dags[0], **arg_dict_fr)
        task_inc = el_operator(dag=dags[1], **arg_dict_inc)

        kickoff_fr >> task_fr >> final_fr
        kickoff_inc >> task_inc >> final_inc

    return dags