Example #1
    def ewah_execute(self, context):
        data_from = self.data_from or context["dag"].start_date
        data_until = self.data_until or datetime_utcnow_with_tz()

        format_str = "%Y-%m-%d"
        currency_str = "{0}{1}=X".format(*self.currency_pair)
        data = YahooFinancials([currency_str]).get_historical_price_data(
            data_from.strftime(format_str),
            data_until.strftime(format_str),
            self.frequency,
        )
        self.upload_data(data[currency_str]["prices"])
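
A minimal sketch of the string handling above, assuming a hypothetical currency_pair of ("EUR", "USD"): the Yahoo ticker becomes "EURUSD=X" and the date range is passed as "%Y-%m-%d" strings.

# Sketch only - currency_pair and the dates below are illustrative assumptions.
from datetime import datetime, timezone

currency_pair = ("EUR", "USD")
currency_str = "{0}{1}=X".format(*currency_pair)  # -> "EURUSD=X"
data_from = datetime(2023, 1, 1, tzinfo=timezone.utc)
data_until = datetime(2023, 1, 31, tzinfo=timezone.utc)
print(currency_str, data_from.strftime("%Y-%m-%d"), data_until.strftime("%Y-%m-%d"))
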
Example #2
    def execute(self, context):
        """Why this method is defined here:
        When executing a task, airflow calls this method. Generally, this
        method contains the "business logic" of the individual operator.
        However, EWAH may want to do some actions for all operators. Thus,
        the child operators shall have an ewah_execute() function which is
        called by this general execute() method.
        """

        self.log.info("""

            Running EWAH Operator {0}.
            DWH: {1} (connection id: {2})
            Extract Strategy: {3}
            Load Strategy: {4}

            """.format(
            str(self),
            self.dwh_engine,
            self.dwh_conn_id,
            self.extract_strategy,
            self.load_strategy,
        ))

        # required for metadata in data upload
        self._execution_time = datetime_utcnow_with_tz()
        self._context = context

        cleaner_callables = self.cleaner_callables or []

        if self.source_conn_id:
            # resolve conn id here & delete the object to avoid usage elsewhere
            self.source_conn = EWAHBaseHook.get_connection(self.source_conn_id)
            self.source_hook = self.source_conn.get_hook()
            if callable(
                    getattr(self.source_hook, "get_cleaner_callables", None)):
                hook_callables = self.source_hook.get_cleaner_callables()
                if callable(hook_callables):
                    cleaner_callables.append(hook_callables)
                elif hook_callables:
                    # Ought to be list of callables
                    cleaner_callables += hook_callables
        del self.source_conn_id

        if self._CONN_TYPE:
            assert (self._CONN_TYPE == self.source_conn.conn_type
                    ), "Error - connection type must be {0}!".format(
                        self._CONN_TYPE)

        uploader_callables = self.uploader_class.get_cleaner_callables()
        if callable(uploader_callables):
            cleaner_callables.append(uploader_callables)
        elif uploader_callables:
            cleaner_callables += uploader_callables

        self.uploader = self.uploader_class(
            dwh_conn=EWAHBaseHook.get_connection(self.dwh_conn_id),
            cleaner=self.cleaner_class(
                default_row=self.default_values,
                include_columns=self.include_columns,
                exclude_columns=self.exclude_columns,
                add_metadata=self.add_metadata,
                rename_columns=self.rename_columns,
                hash_columns=self.hash_columns,
                hash_salt=self.hash_salt,
                additional_callables=cleaner_callables,
            ),
            table_name=self.target_table_name,
            schema_name=self.target_schema_name,
            schema_suffix=self.target_schema_suffix,
            database_name=self.target_database_name,
            primary_key=self.primary_key,
            load_strategy=self.load_strategy,
            use_temp_pickling=self.use_temp_pickling,
            pickling_upload_chunk_size=self.pickling_upload_chunk_size,
            pickle_compression=self.pickle_compression,
            deduplication_before_upload=self.deduplication_before_upload,
            **self.additional_uploader_kwargs,
        )

        # If applicable: set the session's default time zone
        if self.default_timezone:
            self.uploader.dwh_hook.execute("SET timezone TO '{0}'".format(
                self.default_timezone))

        # Create a new copy of the target table.
        # Data is loaded into this copy so that, if loading fails, the
        # original data is not corrupted. On a retry or re-run, the
        # original table is simply copied anew.
        if self.load_strategy != EC.LS_INSERT_REPLACE:
            # insert_replace always drops and replaces the tables completely
            self.uploader.copy_table()

        # set load_data_from and load_data_until as required
        data_from = ada(self.load_data_from)
        data_until = ada(self.load_data_until)
        if self.extract_strategy == EC.ES_INCREMENTAL:
            _tdz = timedelta(days=0)  # aka timedelta zero
            _ed = context["data_interval_start"]
            _ned = context["data_interval_end"]

            # normal incremental load
            _ed -= self.load_data_from_relative or _tdz
            data_from = min(_ed, data_from or _ed)
            if not self.test_if_target_table_exists():
                # Load data from scratch!
                data_from = ada(self.reload_data_from) or data_from

            _ned += self.load_data_until_relative or _tdz
            data_until = max(_ned, data_until or _ned)

        elif self.extract_strategy in (EC.ES_FULL_REFRESH, EC.ES_SUBSEQUENT):
            # Values may still be set as static values
            data_from = ada(self.reload_data_from) or data_from

        else:
            _msg = "Must define load_data_from etc. behavior for load strategy!"
            raise Exception(_msg)

        self.data_from = data_from
        self.data_until = data_until
        # del variables to make sure they are not used later on
        del self.load_data_from
        del self.reload_data_from
        del self.load_data_until
        del self.load_data_until_relative
        if self.extract_strategy != EC.ES_SUBSEQUENT:
            # keep this param for subsequent loads
            del self.load_data_from_relative

        # Optionally wait until a short period (e.g. 2 minutes) past the end
        # of the incremental loading timeframe to ensure that all data is
        # available, useful e.g. if APIs lag or if server timestamps are not
        # perfectly accurate.
        # When a DAG is executed as soon as possible, some data sources
        # may not immediately serve up-to-date data via their API.
        # E.g. querying all data until 12:30pm may only return complete
        # results after 12:32pm due to internal delays. In those cases, make
        # sure the (incremental loading) DAGs don't execute too early.
        if self.wait_for_seconds and self.extract_strategy == EC.ES_INCREMENTAL:
            wait_until = context.get("data_interval_end")
            if wait_until:
                wait_until += timedelta(seconds=self.wait_for_seconds)
                self.log.info("Awaiting execution until {0}...".format(
                    str(wait_until), ))
            while wait_until and datetime_utcnow_with_tz() < wait_until:
                # Only sleep a maximum of 5s at a time
                wait_for_timedelta = wait_until - datetime_utcnow_with_tz()
                time.sleep(max(0, min(wait_for_timedelta.total_seconds(), 5)))

        # execute operator
        if self.load_data_chunking_timedelta and data_from and data_until:
            # Chunking to avoid OOM
            assert data_until > data_from
            assert self.load_data_chunking_timedelta > timedelta(days=0)
            while self.data_from < data_until:
                self.data_until = min(
                    self.data_from + self.load_data_chunking_timedelta,
                    data_until)
                self.log.info("Now loading from {0} to {1}...".format(
                    str(self.data_from), str(self.data_until)))
                self.ewah_execute(context)
                self.data_from += self.load_data_chunking_timedelta
        else:
            self.ewah_execute(context)
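        # Illustration of the chunking above (hypothetical values): with
        # data_from = Jan 1, data_until = Jan 10 and a chunking timedelta of
        # 3 days, ewah_execute() is called for Jan 1-4, Jan 4-7 and Jan 7-10,
        # keeping each individual load small enough to avoid OOM.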

        # Run final scripts
        # TODO: Include indexes into uploader and then remove this step
        self.uploader.finalize_upload()

        # if PostgreSQL and arg given: create indices
        for column in self.index_columns:
            assert self.dwh_engine == EC.DWH_ENGINE_POSTGRES
            # Use hashlib to create a unique 63 character string as index
            # name to avoid breaching index name length limits & accidental
            # duplicates / missing indices due to name truncation leading to
            # identical index names.
            self.uploader.dwh_hook.execute(
                self._INDEX_QUERY.format(
                    "__ewah_" + hashlib.blake2b(
                        (self.target_schema_name + self.target_schema_suffix +
                         "." + self.target_table_name + "." + column).encode(),
                        digest_size=28,
                    ).hexdigest(),
                    self.target_schema_name + self.target_schema_suffix,
                    self.target_table_name,
                    column,
                ))

        # Commit only at the end, so that no data is committed if an error
        # occurs along the way.
        self.log.info("Now committing changes!")
        self.uploader.commit()
        self.uploader.close()
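
The docstring above describes a template-method pattern: execute() handles the shared EWAH plumbing (uploader setup, loading windows, chunking, commit) and delegates the source-specific extraction to ewah_execute(). A minimal sketch of a child operator, assuming this import path for EWAHBaseOperator and omitting the class-level configuration a real operator needs:

# Sketch only - the import path and the operator below are illustrative assumptions.
from ewah.operators.base import EWAHBaseOperator

class EWAHStaticExampleOperator(EWAHBaseOperator):
    """Hypothetical operator: the general execute() above calls ewah_execute()."""

    def ewah_execute(self, context):
        # Source-specific extraction only; upload_data() runs the cleaner and
        # uploader that execute() configured.
        rows = [{"id": 1, "value": "a"}, {"id": 2, "value": "b"}]
        self.upload_data(rows)
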
Example #3
def dag_factory_idempotent(
    dag_name: str,
    dwh_engine: str,
    dwh_conn_id: str,
    start_date: datetime,
    el_operator: Type[EWAHBaseOperator],
    operator_config: dict,
    target_schema_name: str,
    target_schema_suffix: str = "_next",
    target_database_name: Optional[str] = None,
    default_args: Optional[dict] = None,
    schedule_interval_backfill: timedelta = timedelta(days=1),
    schedule_interval_future: timedelta = timedelta(hours=1),
    end_date: Optional[datetime] = None,
    read_right_users: Optional[Union[List[str], str]] = None,
    additional_dag_args: Optional[dict] = None,
    additional_task_args: Optional[dict] = None,
    logging_func: Optional[Callable] = None,
    dagrun_timeout_factor: Optional[float] = None,
    task_timeout_factor: Optional[float] = 0.8,
    **kwargs,
) -> Tuple[DAG, DAG, DAG]:
    """Returns a tuple of three DAGs associated with incremental data loading.

    The three DAGs are:
    - Reset DAG
    - Backfill DAG
    - Current DAG

    The Reset DAG pauses the other two DAGs, deletes their metadata (DAG runs
    and task instances) from Airflow, and drops the related schemas from the
    DWH.

    The Backfill DAG runs on a long schedule interval (e.g. one week) from
    start_date onwards. Each run fetches a relatively long period's worth of
    data. Its purpose is to backfill the DWH.

    The Current DAG runs on a short schedule interval (e.g. one hour). It has
    a dynamic start date, which is the end of the last full schedule interval
    period of the Backfill DAG. This DAG keeps the data in the DWH fresh.

    :param dag_name: Base name of the DAG. The returned DAGs will be named
        after dag_name with the suffixes "_Idempotent_Reset",
        "_Idempotent_Backfill", and "_Idempotent".
    :param dwh_engine: Type of the DWH (e.g. postgresql).
    :param dwh_conn_id: Airflow connection ID with DWH credentials.
    :param start_date: Start date of the DAGs (i.e. of the Backfill DAG).
    :param el_operator: A subclass of EWAHBaseOperator that is used to load
        the individual tables.
    :param operator_config: Dictionary with a required "tables" key mapping
        table names to their individual configuration, plus an optional
        "general_config" key with settings applied to all tables.
    :param target_schema_name: Name of the schema in the DWH that receives the data.
    :param target_schema_suffix: Suffix used during data loading process. The DAG
        creates a new schema "{target_schema_name}{target_schema_suffix}" during
        loading.
    :param target_database_name: Name of the database (Snowflake) or dataset
        (BigQuery), if applicable.
    :param default_args: A dictionary given to the DAGs as default_args param.
    :param schedule_interval_backfill: The schedule interval of the Backfill
        DAG. Must be at least 1 day. Must be larger than
        schedule_interval_future.
    :param schedule_interval_future: The schedule interval of the Current DAG.
        Must be smaller than schedule_interval_backfill. It is recommended not
        to go below 30 minutes. An appropriate schedule interval can be found
        via trial and error. The Current DAG runtime must be less than this
        param in order for EWAH to work properly.
    :param end_date: Airflow DAG kwarg end_date.
    :param read_right_users: List of strings of users or roles that should
        receive read rights on the loaded tables. Can also be a comma-separated
        string instead of a list of strings.
    :param additional_dag_args: kwargs applied to the DAG. Can be any DAG
        kwarg that is not used directly within the function.
    :param additional_task_args: kwargs applied to the tasks. Can be any Task
        kwarg, although some may be overwritten by the function.
    :param logging_func: Pass a callable for logging output. Defaults to print.
    :param dagrun_timeout_factor: Set a timeout for DAG runs as a fraction of
        their schedule interval, so they fail if they exceed it.
    :param task_timeout_factor: Set a timeout for individual tasks as a
        fraction of the schedule interval (default: 0.8).
    """

    def raise_exception(msg: str) -> None:
        """Add information to error message before raising."""
        raise Exception("DAG: {0} - Error: {1}".format(dag_name, msg))

    logging_func = logging_func or print

    if kwargs:
        logging_func("unused config: {0}".format(str(kwargs)))

    additional_dag_args = additional_dag_args or {}
    additional_task_args = additional_task_args or {}

    if not isinstance(schedule_interval_future, timedelta):
        raise_exception("Schedule intervals must be datetime.timedelta!")
    if not isinstance(schedule_interval_backfill, timedelta):
        raise_exception("Schedule intervals must be datetime.timedelta!")
    if schedule_interval_backfill < timedelta(days=1):
        raise_exception("Backfill schedule interval cannot be below 1 day!")
    if schedule_interval_backfill < schedule_interval_future:
        raise_exception(
            "Backfill schedule interval must be larger than"
            + " regular schedule interval!"
        )
    if not operator_config.get("tables"):
        raise_exception('Requires a "tables" dictionary in operator_config!')
    if read_right_users is not None:
        if isinstance(read_right_users, str):
            read_right_users = [u.strip() for u in read_right_users.split(",")]
        if not isinstance(read_right_users, Iterable):
            raise_exception("read_right_users must be an iterable or string!")

    current_time = datetime_utcnow_with_tz()
    if not start_date.tzinfo:
        raise_exception("start_date must be timezone aware!")

    # Make the switch halfway between the latest normal DAG run and the
    #   data_interval_end of the next-to-run backfill DAG
    #   --> no interruption of the system, Airflow has time to register
    #   the change, the backfill DAG can run once unimpeded and the
    #   normal DAG can then resume as normal. Note: in that case,
    #   keep both DAGs active!
    current_time += schedule_interval_future / 2
    # How much time has passed in total between start_date and now?
    switch_absolute_date = current_time - start_date
    # How often could the backfill DAG run in that time frame?
    switch_absolute_date /= schedule_interval_backfill
    switch_absolute_date = int(switch_absolute_date)
    # What is the exact datetime after the last of those runs?
    switch_absolute_date *= schedule_interval_backfill
    switch_absolute_date += start_date
    # --> switch_absolute_date is always in the (recent) past
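    # Illustrative arithmetic (hypothetical values): with
    #   start_date = 2023-01-01 00:00 UTC,
    #   schedule_interval_backfill = timedelta(days=7), and
    #   shifted current_time = 2023-03-15 12:30 UTC,
    # roughly 73.5 days have passed, int(73.5 / 7) = 10 backfill intervals
    # fit, so switch_absolute_date = 2023-01-01 + 10 * 7 days
    # = 2023-03-12 00:00 UTC, i.e. in the recent past.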

    # Make sure that the backfill and normal DAG start_date and
    #   schedule_interval calculations were successful and correct
    backfill_timedelta = switch_absolute_date - start_date
    backfill_tasks_count = backfill_timedelta / schedule_interval_backfill

    if end_date:
        backfill_end_date = min(switch_absolute_date, end_date)
    else:
        backfill_end_date = switch_absolute_date

    if dagrun_timeout_factor:
        _msg = "dagrun_timeout_factor must be a number between 0 and 1!"
        assert isinstance(dagrun_timeout_factor, (int, float)) and (
            0 < dagrun_timeout_factor <= 1
        ), _msg
        dagrun_timeout = additional_dag_args.get(
            "dagrun_timeout", dagrun_timeout_factor * schedule_interval_future
        )
        dagrun_timeout_backfill = additional_dag_args.pop(
            "dagrun_timeout", dagrun_timeout_factor * schedule_interval_future
        )
    else:
        dagrun_timeout = additional_dag_args.get("dagrun_timeout")
        dagrun_timeout_backfill = additional_dag_args.pop("dagrun_timeout", None)

    if task_timeout_factor:
        execution_timeout = additional_task_args.get(
            "execution_timeout", task_timeout_factor * schedule_interval_future
        )
        execution_timeout_backfill = additional_task_args.pop(
            "execution_timeout", task_timeout_factor * schedule_interval_backfill
        )
    else:
        execution_timeout = additional_task_args.get("execution_timeout")
        execution_timeout_backfill = additional_task_args.pop("execution_timeout", None)

    dags = (
        DAG(  # Current DAG
            dag_name + "_Idempotent",
            start_date=switch_absolute_date,
            end_date=end_date,
            schedule_interval=schedule_interval_future,
            catchup=True,
            max_active_runs=1,
            default_args=default_args,
            dagrun_timeout=dagrun_timeout,
            **additional_dag_args,
        ),
        DAG(  # Backfill DAG
            dag_name + "_Idempotent_Backfill",
            start_date=start_date,
            end_date=backfill_end_date,
            schedule_interval=schedule_interval_backfill,
            catchup=True,
            max_active_runs=1,
            default_args=default_args,
            dagrun_timeout=dagrun_timeout_backfill,
            **additional_dag_args,
        ),
        DAG(  # Reset DAG
            dag_name + "_Idempotent_Reset",
            start_date=start_date,
            end_date=end_date,
            schedule_interval=None,
            catchup=False,
            max_active_runs=1,
            default_args=default_args,
            **additional_dag_args,
        ),
    )

    # Create reset DAG
    reset_bash_command = " && ".join(  # First pause DAGs, then delete their metadata
        [
            "airflow dags pause {dag_name}_Idempotent",
            "airflow dags pause {dag_name}_Idempotent_Backfill",
            "airflow dags delete {dag_name}_Idempotent -y",
            "airflow dags delete {dag_name}_Idempotent_Backfill -y",
        ]
    ).format(dag_name=dag_name)
    reset_task = BashOperator(
        bash_command=reset_bash_command,
        task_id="reset_by_deleting_all_task_instances",
        dag=dags[2],
        **additional_task_args,
    )
    drop_sql = """
        DROP SCHEMA IF EXISTS "{target_schema_name}" CASCADE;
        DROP SCHEMA IF EXISTS "{target_schema_name}{suffix}" CASCADE;
    """.format(
        target_schema_name=target_schema_name,
        suffix=target_schema_suffix,
    )
    if dwh_engine == EC.DWH_ENGINE_POSTGRES:
        drop_task = PGO(
            sql=drop_sql,
            postgres_conn_id=dwh_conn_id,
            task_id="delete_previous_schema_if_exists",
            dag=dags[2],
            **additional_task_args,
        )
    elif dwh_engine == EC.DWH_ENGINE_SNOWFLAKE:
        drop_task = SnowflakeOperator(
            sql=drop_sql,
            snowflake_conn_id=dwh_conn_id,
            database=target_database_name,
            task_id="delete_previous_schema_if_exists",
            dag=dags[2],
            **additional_task_args,
        )
    else:
        drop_sql = """
            DROP SCHEMA IF EXISTS `{0}` CASCADE;
            DROP SCHEMA IF EXISTS `{1}` CASCADE;
        """.format(
            target_schema_name,
            target_schema_name + target_schema_suffix,
        )
        drop_task = BigqueryOperator(
            sql=drop_sql,
            bigquery_conn_id=dwh_conn_id,
            project=target_database_name,
            task_id="delete_previous_schema_if_exists",
            dag=dags[2],
            **additional_task_args,
        )

    reset_task >> drop_task

    # Incremental DAG schema tasks
    kickoff, final = get_uploader(dwh_engine).get_schema_tasks(
        dag=dags[0],
        dwh_engine=dwh_engine,
        target_schema_name=target_schema_name,
        target_schema_suffix=target_schema_suffix,
        target_database_name=target_database_name,
        dwh_conn_id=dwh_conn_id,
        read_right_users=read_right_users,
        execution_timeout=execution_timeout,
        **additional_task_args,
    )

    # Backfill DAG schema tasks
    kickoff_backfill, final_backfill = get_uploader(dwh_engine).get_schema_tasks(
        dag=dags[1],
        dwh_engine=dwh_engine,
        target_schema_name=target_schema_name,
        target_schema_suffix=target_schema_suffix,
        target_database_name=target_database_name,
        dwh_conn_id=dwh_conn_id,
        read_right_users=read_right_users,
        execution_timeout=execution_timeout_backfill,
        **additional_task_args,
    )

    # add table creation tasks
    arg_dict = deepcopy(additional_task_args)
    arg_dict.update(operator_config.get("general_config", {}))
    # Default reload_data_from to start_date
    arg_dict["reload_data_from"] = arg_dict.get("reload_data_from", start_date)
    for table in operator_config["tables"].keys():
        kwargs = deepcopy(arg_dict)
        kwargs.update(operator_config["tables"][table] or {})

        # Overwrite / ignore changes to these kwargs:
        kwargs.update(
            {
                "extract_strategy": kwargs.get("extract_strategy", EC.ES_INCREMENTAL),
                "task_id": "extract_load_" + re.sub(r"[^a-zA-Z0-9_]", "", table),
                "dwh_engine": dwh_engine,
                "dwh_conn_id": dwh_conn_id,
                "target_table_name": operator_config["tables"][table].get(
                    "target_table_name", table
                ),
                "target_schema_name": target_schema_name,
                "target_schema_suffix": target_schema_suffix,
                "target_database_name": target_database_name,
            }
        )
        assert kwargs["extract_strategy"] in (
            EC.ES_FULL_REFRESH,
            EC.ES_SUBSEQUENT,
            EC.ES_INCREMENTAL,
        )
        kwargs["load_strategy"] = kwargs.get(
            "load_strategy",
            EC.DEFAULT_LS_PER_ES[kwargs["extract_strategy"]],
        )

        if kwargs["extract_strategy"] == EC.ES_INCREMENTAL:
            # Backfill ignores non-incremental extract strategy types
            task_backfill = el_operator(
                dag=dags[1], execution_timeout=execution_timeout_backfill, **kwargs
            )
            kickoff_backfill >> task_backfill >> final_backfill

        task = el_operator(dag=dags[0], execution_timeout=execution_timeout, **kwargs)
        kickoff >> task >> final

    # For the unlikely case that there is no incremental task
    kickoff_backfill >> final_backfill

    # Make sure incremental loading stops if there is an error!
    if additional_task_args.get("task_timeout_factor"):
        # sensors shall have no timeouts!
        del additional_task_args["task_timeout_factor"]
    ets = (
        ExtendedETS(
            task_id="sense_previous_instance",
            allowed_states=["success", "skipped"],
            external_dag_id=dags[0]._dag_id,
            external_task_id=final.task_id,
            execution_delta=schedule_interval_future,
            backfill_dag_id=dags[1]._dag_id,
            backfill_external_task_id=final_backfill.task_id,
            backfill_execution_delta=schedule_interval_backfill,
            dag=dags[0],
            poke_interval=5 * 60,
            mode="reschedule",  # don't block a worker and pool slot
            **additional_task_args,
        ),
        ExtendedETS(
            task_id="sense_previous_instance",
            allowed_states=["success", "skipped"],
            external_dag_id=dags[1]._dag_id,
            external_task_id=final_backfill.task_id,
            execution_delta=schedule_interval_backfill,
            dag=dags[1],
            poke_interval=5 * 60,
            mode="reschedule",  # don't block a worker and pool slot
            **additional_task_args,
        ),
    )
    ets[0] >> kickoff
    ets[1] >> kickoff_backfill

    return dags
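
A hedged usage sketch (the operator class, connection IDs, and table names below are placeholders; EC and the factory itself are assumed to be importable from the EWAH package): the factory returns the Current, Backfill, and Reset DAGs, which must end up as module-level variables so the Airflow scheduler discovers them.

# Sketch only - MyEWAHOperator, connection IDs and table names are placeholders.
from datetime import datetime, timedelta, timezone

dag_current, dag_backfill, dag_reset = dag_factory_idempotent(
    dag_name="EL_My_Source",
    dwh_engine=EC.DWH_ENGINE_POSTGRES,
    dwh_conn_id="dwh_postgres",
    start_date=datetime(2023, 1, 1, tzinfo=timezone.utc),
    el_operator=MyEWAHOperator,  # any EWAHBaseOperator subclass
    operator_config={
        "general_config": {"source_conn_id": "my_source"},
        "tables": {"orders": {}, "customers": {}},
    },
    target_schema_name="raw_my_source",
    schedule_interval_backfill=timedelta(days=7),
    schedule_interval_future=timedelta(hours=1),
)
# Airflow only schedules DAGs that are module-level attributes of the DAG file.
for _dag in (dag_current, dag_backfill, dag_reset):
    globals()[_dag._dag_id] = _dag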
Example #4
def dbt_dags_factory(
        airflow_conn_id,
        repo_type,
        dwh_engine,
        dwh_conn_id,
        database_name=None,
        git_conn_id=None,  # if provided, expecting private SSH key in conn extra
        local_path=None,
        dbt_version=None,  # Defaults to require-dbt-version in dbt_project.yml
        subfolder=None,  # optional: supply if dbt project is in a subfolder
        threads=4,  # see https://docs.getdbt.com/dbt-cli/configure-your-profile/#understanding-threads
        schema_name="analytics",  # see https://docs.getdbt.com/dbt-cli/configure-your-profile/#understanding-target-schemas
        keepalives_idle=0,  # see https://docs.getdbt.com/reference/warehouse-profiles/postgres-profile/
        dag_base_name="T_dbt_run",
        schedule_interval: Union[str, timedelta] = timedelta(hours=1),
        start_date=datetime(2019, 1, 1),
        default_args=None,
        run_flags=None,  # e.g. --model tag:base
        project=None,  # BigQuery alias
        dataset=None,  # BigQuery alias
        dagrun_timeout_factor=None,  # doesn't apply to full refresh
        task_timeout_factor=0.8,  # doesn't apply to full refresh
        metabase_conn_id=None,  # push docs to Metabase after full refresh run if exists
):
    run_flags = run_flags or ""  # use empty string instead of None

    # only PostgreSQL & Snowflake implemented as of now!
    assert dwh_engine in (
        EC.DWH_ENGINE_POSTGRES,
        EC.DWH_ENGINE_SNOWFLAKE,
        EC.DWH_ENGINE_BIGQUERY,
    )

    if isinstance(schedule_interval, str):
        # Allow using cron-style schedule intervals
        assert croniter.is_valid(
            schedule_interval
        ), "schedule_interval is neither timedelta nor not valid cron!"
        catchup = False
        end_date = None
    else:
        catchup = True
        # if start_date is timezone offset-naive, assume UTC and turn it into offset-aware
        if not start_date.tzinfo:
            start_date = start_date.replace(tzinfo=pytz.utc)

        start_date += (int(
            (datetime_utcnow_with_tz() - start_date) / schedule_interval) -
                       1) * schedule_interval
        end_date = start_date + 2 * schedule_interval - timedelta(seconds=1)

    dag_kwargs = {
        "catchup": catchup,
        "start_date": start_date,
        "end_date": end_date,
        "default_args": default_args,
        "max_active_runs": 1,
    }

    if dagrun_timeout_factor and isinstance(schedule_interval, timedelta):
        _msg = "dagrun_timeout_factor must be a number between 0 and 1!"
        assert isinstance(
            dagrun_timeout_factor,
            (int, float)) and (0 < dagrun_timeout_factor <= 1), _msg
        dagrun_timeout = dagrun_timeout_factor * schedule_interval
    else:
        dagrun_timeout = None

    if task_timeout_factor and isinstance(schedule_interval, timedelta):
        _msg = "task_timeout_factor must be a number between 0 and 1!"
        assert isinstance(
            task_timeout_factor,
            (int, float)) and (0 < task_timeout_factor <= 1), _msg
        execution_timeout = task_timeout_factor * schedule_interval
    else:
        execution_timeout = None

    dag_1 = DAG(
        dag_base_name,
        schedule_interval=schedule_interval,
        dagrun_timeout=dagrun_timeout,
        **dag_kwargs,
    )
    dag_2 = DAG(dag_base_name + "_full_refresh",
                schedule_interval=None,
                **dag_kwargs)

    sensor_sql = """
        SELECT
            -- only succeed if there is no other running DagRun
            CASE WHEN COUNT(*) = 0 THEN 1 ELSE 0 END
        FROM public.dag_run
        WHERE dag_id IN ('{0}', '{1}')
          AND state = 'running'
          AND data_interval_start < '{2}' -- DagRun's data_interval_end
          AND NOT (run_id = '{3}' AND dag_id = '{4}')
          -- Note: data_interval_end = data_interval_start if run_type = 'manual'
    """.format(
        dag_1._dag_id,
        dag_2._dag_id,
        "{{ data_interval_end }}",
        "{{ run_id }}",
        "{{ dag._dag_id }}",
    )

    snsr_1 = EWAHSqlSensor(
        task_id="sense_dbt_conflict_avoided",
        conn_id=airflow_conn_id,
        sql=sensor_sql,
        poke_interval=5 * 60,
        mode="reschedule",  # don't block a worker and pool slot
        dag=dag_1,
    )
    snsr_2 = EWAHSqlSensor(
        task_id="sense_dbt_conflict_avoided",
        conn_id=airflow_conn_id,
        sql=sensor_sql,
        poke_interval=5 * 60,
        mode="reschedule",  # don't block a worker and pool slot
        dag=dag_2,
    )

    dbt_kwargs = {
        "repo_type": repo_type,
        "dwh_conn_id": dwh_conn_id,
        "git_conn_id": git_conn_id,
        "local_path": local_path,
        "dbt_version": dbt_version,
        "subfolder": subfolder,
        "threads": threads,
        "schema_name": schema_name,
        "keepalives_idle": 0,
        "dwh_engine": dwh_engine,
        "database_name": database_name,
        "project": project,
        "dataset": dataset,
    }

    run_1 = EWAHdbtOperator(
        task_id="dbt_run",
        dbt_commands=["seed", f"run {run_flags}"],
        dag=dag_1,
        execution_timeout=execution_timeout,
        **dbt_kwargs,
    )
    run_2 = EWAHdbtOperator(
        task_id="dbt_run",
        dbt_commands=[
            "seed --full-refresh", f"run --full-refresh {run_flags}"
        ],
        dag=dag_2,
        # If metabase_conn_id exists, push dbt docs to Metabase after full refresh run
        metabase_conn_id=metabase_conn_id,
        **dbt_kwargs,
    )

    run_flags_freshness = (run_flags.replace("--models", "--select").replace(
        "--model", "--select").replace("-m ", "-s "))
    test_1 = EWAHdbtOperator(
        task_id="dbt_test",
        dbt_commands=[
            f"test {run_flags}", f"source freshness {run_flags_freshness}"
        ],
        dag=dag_1,
        execution_timeout=execution_timeout,
        **dbt_kwargs,
    )
    test_2 = EWAHdbtOperator(task_id="dbt_test",
                             dbt_commands=f"test {run_flags}",
                             dag=dag_2,
                             **dbt_kwargs)

    snsr_1 >> run_1 >> test_1
    snsr_2 >> run_2 >> test_2

    return (dag_1, dag_2)
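
A brief usage sketch (connection IDs and flag values below are placeholders; EC and the factory are assumed to be importable as above): a timedelta schedule_interval enables the rolling catchup window computed in the factory, while a cron string disables catchup.

# Sketch only - connection IDs and flag values are placeholders.
from datetime import timedelta

dbt_dag, dbt_dag_full_refresh = dbt_dags_factory(
    airflow_conn_id="airflow_metadata_db",
    repo_type="git",
    dwh_engine=EC.DWH_ENGINE_POSTGRES,
    dwh_conn_id="dwh_postgres",
    git_conn_id="dbt_repo_ssh",
    schedule_interval=timedelta(hours=1),  # or a cron string such as "0 * * * *"
    run_flags="--model tag:base",
)
globals()[dbt_dag._dag_id] = dbt_dag
globals()[dbt_dag_full_refresh._dag_id] = dbt_dag_full_refresh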
Example #5
def dag_factory_atomic(dag_name: str,
                       dwh_engine: str,
                       dwh_conn_id: str,
                       start_date: datetime,
                       el_operator: Type[EWAHBaseOperator],
                       operator_config: dict,
                       target_schema_name: str,
                       target_schema_suffix: str = "_next",
                       target_database_name: Optional[str] = None,
                       default_args: Optional[dict] = None,
                       schedule_interval: Union[str,
                                                timedelta] = timedelta(days=1),
                       end_date: Optional[datetime] = None,
                       read_right_users: Optional[Union[List[str],
                                                        str]] = None,
                       additional_dag_args: Optional[dict] = None,
                       additional_task_args: Optional[dict] = None,
                       logging_func: Optional[Callable] = None,
                       dagrun_timeout_factor: Optional[float] = None,
                       task_timeout_factor: Optional[float] = None,
                       **kwargs) -> Tuple[DAG]:
    def raise_exception(msg: str) -> None:
        """Add information to error message before raising."""
        raise Exception("DAG: {0} - Error: {1}".format(dag_name, msg))

    logging_func = logging_func or print

    if kwargs:
        logging_func("unused config: {0}".format(str(kwargs)))

    additional_dag_args = additional_dag_args or {}
    additional_task_args = additional_task_args or {}

    if read_right_users is not None:
        if isinstance(read_right_users, str):
            read_right_users = [u.strip() for u in read_right_users.split(",")]
        if not isinstance(read_right_users, Iterable):
            raise_exception("read_right_users must be an iterable or string!")

    if isinstance(schedule_interval, str):
        # Allow using cron-style schedule intervals
        catchup = False
        assert croniter.is_valid(
            schedule_interval
        ), "schedule_interval is neither timedelta nor not valid cron!"
    else:
        assert isinstance(
            schedule_interval,
            timedelta), "schedule_interval must be cron-string or timedelta!"

        catchup = True
        # fake catchup = True: keep exactly one schedule_interval between
        # start_date and end_date --> run the full refreshes at the same time
        # every schedule_interval instead of drifting in execution time!
        if end_date:
            end_date = min(end_date, datetime_utcnow_with_tz())
        else:
            end_date = datetime_utcnow_with_tz()

        start_date += (int((end_date - start_date) / schedule_interval) -
                       1) * schedule_interval

        # case 1: end_date = start_date + schedule_interval
        # if the division result is a precise integer, that implies a definite end_date
        # --> adjust to get exactly one schedule_interval delta between start_date and
        # end_date to have one last run available (that should have run before end_date)

        # case 2: end_date > (start_date + schedule_interval)
        # Airflow executes after data_interval_end, so start_date has to be
        # at least 1 and less than 2 schedule_intervals before end_date:
        # end_date - 2*schedule_interval < start_date <= end_date - schedule_interval

        # Make sure only one scheduled execution ever runs, but manual triggers still work!
        end_date = start_date + 2 * schedule_interval - timedelta(seconds=1)

    if dagrun_timeout_factor:
        _msg = "dagrun_timeout_factor must be a number between 0 and 1!"
        assert isinstance(
            dagrun_timeout_factor,
            (int, float)) and (0 < dagrun_timeout_factor <= 1), _msg
        additional_dag_args["dagrun_timeout"] = additional_dag_args.get(
            "dagrun_timeout", dagrun_timeout_factor * schedule_interval)

    if task_timeout_factor:
        additional_task_args["execution_timeout"] = additional_task_args.get(
            "execution_timeout", task_timeout_factor * schedule_interval)

    dag = DAG(
        dag_name,
        catchup=catchup,
        default_args=default_args,
        max_active_runs=1,
        schedule_interval=schedule_interval,
        start_date=start_date,
        end_date=end_date,
        **additional_dag_args,
    )

    kickoff, final = get_uploader(dwh_engine).get_schema_tasks(
        dag=dag,
        dwh_engine=dwh_engine,
        dwh_conn_id=dwh_conn_id,
        target_schema_name=target_schema_name,
        target_schema_suffix=target_schema_suffix,
        target_database_name=target_database_name,
        read_right_users=read_right_users,
        **additional_task_args,
    )

    base_config = deepcopy(additional_task_args)
    base_config.update(operator_config.get("general_config", {}))
    with dag:
        for table in operator_config["tables"].keys():
            table_config = deepcopy(base_config)
            table_config.update(operator_config["tables"][table] or {})
            table_config.update({
                "task_id": "extract_load_" + re.sub(r"[^a-zA-Z0-9_]", "", table),
                "dwh_engine": dwh_engine,
                "dwh_conn_id": dwh_conn_id,
                # Default to full refresh; value can be given as None in table conf
                "extract_strategy": table_config.get("extract_strategy", None)
                or EC.ES_FULL_REFRESH,
                "target_table_name": operator_config["tables"][table].get(
                    "target_table_name", table
                ),
                "target_schema_name": target_schema_name,
                "target_schema_suffix": target_schema_suffix,
                "target_database_name": target_database_name,
            })
            # Atomic DAG only works with full refresh and subsequent strategies!
            assert table_config["extract_strategy"] in (
                EC.ES_FULL_REFRESH,
                EC.ES_SUBSEQUENT,
            )
            table_config["load_strategy"] = table_config.get(
                "load_strategy",
                EC.DEFAULT_LS_PER_ES[table_config["extract_strategy"]],
            )
            table_task = el_operator(**table_config)
            kickoff >> table_task >> final

    return (dag, )
Example #6
def dag_factory_mixed(
    dag_name: str,
    dwh_engine: str,
    dwh_conn_id: str,
    airflow_conn_id: str,
    start_date: datetime,
    el_operator: Type[EWAHBaseOperator],
    operator_config: dict,
    target_schema_name: str,
    target_schema_suffix: str = "_next",
    target_database_name: Optional[str] = None,
    default_args: Optional[dict] = None,
    schedule_interval_full_refresh: timedelta = timedelta(days=1),
    schedule_interval_incremental: timedelta = timedelta(hours=1),
    end_date: Optional[datetime] = None,
    read_right_users: Optional[Union[List[str], str]] = None,
    additional_dag_args: Optional[dict] = None,
    additional_task_args: Optional[dict] = None,
    logging_func: Optional[Callable] = None,
    dagrun_timeout_factor: Optional[float] = None,
    task_timeout_factor: Optional[float] = 0.8,
    **kwargs,
) -> Tuple[DAG, DAG]:
    def raise_exception(msg: str) -> None:
        """Add information to error message before raising."""
        raise Exception("DAG: {0} - Error: {1}".format(dag_name, msg))

    logging_func = logging_func or print

    if kwargs:
        logging_func("unused config: {0}".format(str(kwargs)))

    additional_dag_args = additional_dag_args or {}
    additional_task_args = additional_task_args or {}

    if read_right_users is not None:
        if isinstance(read_right_users, str):
            read_right_users = [u.strip() for u in read_right_users.split(",")]
        if not isinstance(read_right_users, Iterable):
            raise_exception("read_right_users must be an iterable or string!")
    if not isinstance(schedule_interval_full_refresh, timedelta):
        raise_exception("schedule_interval_full_refresh must be timedelta!")
    if not isinstance(schedule_interval_incremental, timedelta):
        raise_exception("schedule_interval_incremental must be timedelta!")
    if schedule_interval_incremental >= schedule_interval_full_refresh:
        _msg = "schedule_interval_incremental must be shorter than "
        _msg += "schedule_interval_full_refresh!"
        raise_exception(_msg)

    """Calculate the datetimes and timedeltas for the two DAGs.

    Full Refresh: The start_date should be chosen such that there is always only
    one DAG execution to be executed at any given point in time.
    See dag_factory_atomic for the same calculation including detailed comments.

    The Incremental DAG starts at the start date + schedule interval of the
    Full Refresh DAG, so that the Incremental executions only happen after
    the Full Refresh execution.
    """
    if not start_date.tzinfo:
        # a timezone-naive start_date is not allowed
        raise_exception("start_date must be timezone aware!")
    time_now = datetime_utcnow_with_tz()

    if end_date:
        end_date = min(end_date, time_now)
    else:
        end_date = time_now

    if start_date > time_now:
        # Start date for both is in the future
        # start_date_fr = start_date
        # start_date_inc = start_date
        pass
    else:
        start_date += (
            int((end_date - start_date) / schedule_interval_full_refresh)
            * schedule_interval_full_refresh
        )
        if start_date == end_date:
            start_date -= schedule_interval_full_refresh
        else:
            start_date -= schedule_interval_full_refresh
            end_date = (
                start_date + 2 * schedule_interval_full_refresh - timedelta(seconds=1)
            )

        # _td = int((time_now - start_date) / schedule_interval_full_refresh) - 2
        # start_date_fr = start_date + _td * schedule_interval_full_refresh
        # start_date_inc = start_date_fr + schedule_interval_full_refresh

    default_args = default_args or {}
    default_args_fr = deepcopy(default_args)
    default_args_inc = deepcopy(default_args)

    if dagrun_timeout_factor:
        _msg = "dagrun_timeout_factor must be a number between 0 and 1!"
        assert isinstance(dagrun_timeout_factor, (int, float)) and (
            0 < dagrun_timeout_factor <= 1
        ), _msg
        dagrun_timeout_inc = dagrun_timeout_factor * schedule_interval_incremental
        dagrun_timeout_fr = dagrun_timeout_factor * schedule_interval_full_refresh
    else:  # In case of 0 set to None
        dagrun_timeout_inc = None
        dagrun_timeout_fr = None

    if task_timeout_factor:
        _msg = "task_timeout_factor must be a number between 0 and 1!"
        assert isinstance(task_timeout_factor, (int, float)) and (
            0 < task_timeout_factor <= 1
        ), _msg
        execution_timeout_fr = task_timeout_factor * schedule_interval_full_refresh
        execution_timeout_inc = task_timeout_factor * schedule_interval_incremental
    else:
        execution_timeout_fr = None
        execution_timeout_inc = None

    dag_name_fr = dag_name + "_Mixed_Atomic"
    dag_name_inc = dag_name + "_Mixed_Idempotent"
    dags = (
        DAG(
            dag_name_fr,
            start_date=start_date,
            end_date=end_date,
            schedule_interval=schedule_interval_full_refresh,
            catchup=True,
            max_active_runs=1,
            default_args=default_args_fr,
            dagrun_timeout=dagrun_timeout_fr,
            **additional_dag_args,
        ),
        DAG(
            dag_name_inc,
            start_date=start_date + schedule_interval_full_refresh,
            end_date=start_date + 2 * schedule_interval_full_refresh,
            schedule_interval=schedule_interval_incremental,
            catchup=True,
            max_active_runs=1,
            default_args=default_args_inc,
            dagrun_timeout=dagrun_timeout_inc,
            **additional_dag_args,
        ),
        DAG(  # Reset DAG
            dag_name + "_Mixed_Reset",
            start_date=start_date,
            end_date=end_date,
            schedule_interval=None,
            catchup=False,
            max_active_runs=1,
            default_args=default_args,
            **additional_dag_args,
        ),
    )

    # Create reset DAG
    reset_bash_command = " && ".join(  # First pause DAGs, then delete their metadata
        [
            "airflow dags pause {dag_name}_Mixed_Atomic",
            "airflow dags pause {dag_name}_Mixed_Idempotent",
            "airflow dags delete {dag_name}_Mixed_Atomic -y",
            "airflow dags delete {dag_name}_Mixed_Idempotent -y",
        ]
    ).format(dag_name=dag_name)
    reset_task = BashOperator(
        bash_command=reset_bash_command,
        task_id="reset_by_deleting_all_task_instances",
        dag=dags[2],
        **additional_task_args,
    )
    drop_sql = """
        DROP SCHEMA IF EXISTS "{target_schema_name}" CASCADE;
        DROP SCHEMA IF EXISTS "{target_schema_name}{suffix}" CASCADE;
    """.format(
        target_schema_name=target_schema_name,
        suffix=target_schema_suffix,
    )
    if dwh_engine == EC.DWH_ENGINE_POSTGRES:
        drop_task = PGO(
            sql=drop_sql,
            postgres_conn_id=dwh_conn_id,
            task_id="delete_previous_schema_if_exists",
            dag=dags[2],
            **additional_task_args,
        )
    elif dwh_engine == EC.DWH_ENGINE_SNOWFLAKE:
        drop_task = SnowflakeOperator(
            sql=drop_sql,
            snowflake_conn_id=dwh_conn_id,
            database=target_database_name,
            task_id="delete_previous_schema_if_exists",
            dag=dags[2],
            **additional_task_args,
        )
    else:
        raise_exception(f'DWH "{dwh_engine}" not implemented for this task!')

    kickoff_fr, final_fr = get_uploader(dwh_engine).get_schema_tasks(
        dag=dags[0],
        dwh_engine=dwh_engine,
        dwh_conn_id=dwh_conn_id,
        target_schema_name=target_schema_name,
        target_schema_suffix=target_schema_suffix,
        target_database_name=target_database_name,
        read_right_users=read_right_users,
        execution_timeout=execution_timeout_fr,
        **additional_task_args,
    )

    kickoff_inc, final_inc = get_uploader(dwh_engine).get_schema_tasks(
        dag=dags[1],
        dwh_engine=dwh_engine,
        dwh_conn_id=dwh_conn_id,
        target_schema_name=target_schema_name,
        target_schema_suffix=target_schema_suffix,
        target_database_name=target_database_name,
        read_right_users=read_right_users,
        execution_timeout=execution_timeout_inc,
        **additional_task_args,
    )

    sql_fr = """
        SELECT
             -- only run if there are no active DAGs that have to finish first
            CASE WHEN COUNT(*) = 0 THEN 1 ELSE 0 END
        FROM public.dag_run
        WHERE state = 'running'
          AND (
                (dag_id = '{0}' AND data_interval_start < '{1}')
            OR  (dag_id = '{2}' AND data_interval_start < '{3}')
          )
    """.format(
        dags[0]._dag_id,  # fr
        "{{ data_interval_start }}",  # no previous full refresh, please!
        dags[1]._dag_id,  # inc
        "{{ data_interval_end }}",  # no old incremental running, please!
    )

    # Sense whether a previous instance is running OR whether any incremental
    # loads are running, except the incremental load of the same period, which
    # is expected and will wait
    fr_snsr = EWAHSqlSensor(
        task_id="sense_run_validity",
        conn_id=airflow_conn_id,
        sql=sql_fr,
        dag=dags[0],
        poke_interval=5 * 60,
        mode="reschedule",  # don't block a worker and pool slot
        **additional_task_args,
    )

    # Sense if the previous instance is complete, except if it is the first
    # one; in that case, check for a full refresh of the same period
    inc_ets = ExtendedETS(
        task_id="sense_run_validity",
        allowed_states=["success"],
        external_dag_id=dags[1]._dag_id,
        external_task_id=final_inc.task_id,
        execution_delta=schedule_interval_incremental,
        backfill_dag_id=dags[0]._dag_id,
        backfill_external_task_id=final_fr.task_id,
        backfill_execution_delta=schedule_interval_full_refresh,
        dag=dags[1],
        poke_interval=5 * 60,
        mode="reschedule",  # don't block a worker and pool slot
        **additional_task_args,
    )

    fr_snsr >> kickoff_fr
    inc_ets >> kickoff_inc

    for table in operator_config["tables"].keys():
        arg_dict_inc = deepcopy(additional_task_args)
        arg_dict_inc.update(operator_config.get("general_config", {}))
        op_conf = operator_config["tables"][table] or {}
        arg_dict_inc.update(op_conf)
        arg_dict_inc.update(
            {
                "extract_strategy": arg_dict_inc.get(
                    "extract_strategy", EC.ES_INCREMENTAL
                ),
                "task_id": "extract_load_" + re.sub(r"[^a-zA-Z0-9_]", "", table),
                "dwh_engine": dwh_engine,
                "dwh_conn_id": dwh_conn_id,
                "target_table_name": op_conf.get("target_table_name", table),
                "target_schema_name": target_schema_name,
                "target_schema_suffix": target_schema_suffix,
                "target_database_name": target_database_name,
            }
        )
        arg_dict_inc["load_strategy"] = arg_dict_inc.get(
            "load_strategy", EC.DEFAULT_LS_PER_ES[arg_dict_inc["extract_strategy"]]
        )
        arg_dict_fr = deepcopy(arg_dict_inc)
        arg_dict_fr["extract_strategy"] = EC.ES_FULL_REFRESH
        arg_dict_fr["load_strategy"] = EC.LS_INSERT_REPLACE

        arg_dict_fr["execution_timeout"] = execution_timeout_fr
        arg_dict_inc["execution_timeout"] = execution_timeout_inc

        task_fr = el_operator(dag=dags[0], **arg_dict_fr)
        task_inc = el_operator(dag=dags[1], **arg_dict_inc)

        kickoff_fr >> task_fr >> final_fr
        kickoff_inc >> task_inc >> final_inc

    return dags