Example #1
    def mongoclient(self):
        if not hasattr(self, "_mc"):
            if self.conn.ssh_conn_id:
                if self.conn.conn_style == "uri":
                    raise Exception(
                        "Cannot have SSH tunnel with uri connection type!")
                if not hasattr(self, "_ssh_hook"):
                    self._ssh_hook = EWAHBaseHook.get_hook_from_conn_id(
                        conn_id=self.conn.ssh_conn_id)
                    self.local_bind_address = self._ssh_hook.start_tunnel(
                        self.conn.host, self.conn.port)
            else:
                self.local_bind_address = (self.conn.host, self.conn.port)

            conn_kwargs = {"tz_aware": True}
            if self.conn.conn_style == "uri":
                conn_kwargs["host"] = self.conn.uri
            else:
                conn_kwargs["host"] = self.local_bind_address[0]
                conn_kwargs["port"] = self.local_bind_address[1]
                if self.conn.username:
                    conn_kwargs["username"] = self.conn.username
                if self.conn.password:
                    conn_kwargs["password"] = self.conn.password

            with TemporaryDirectory() as tmp_dir:
                if self.conn.tls:
                    conn_kwargs["tls"] = True
                with NamedTemporaryFile(dir=tmp_dir) as ssl_cert:
                    with NamedTemporaryFile(dir=tmp_dir) as ssl_private:
                        if self.conn.ssl_cert:
                            ssl_cert.write(self.conn.ssl_cert.encode())
                            ssl_cert.seek(0)
                            conn_kwargs["ssl_certfile"] = os.path.abspath(
                                ssl_cert.name)
                        if self.conn.ssl_private:
                            ssl_private.write(self.conn.ssl_private.encode())
                            ssl_private.seek(0)
                            conn_kwargs["ssl_keyfile"] = os.path.abspath(
                                ssl_private.name)
                        if self.conn.ssl_password:
                            conn_kwargs[
                                "tlsCertificateKeyFilePassword"] = self.conn.ssl_password
                        if self.conn.tls_insecure:
                            conn_kwargs["tlsInsecure"] = True
                        if self.conn.auth_source:
                            conn_kwargs["authSource"] = self.conn.auth_source
                        if self.conn.auth_mechanism:
                            conn_kwargs[
                                "authMechanism"] = self.conn.auth_mechanism
                        self._mc = MongoClient(**conn_kwargs)

        return self._mc
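
For context, a minimal sketch of the same connection pattern without the hook plumbing, using only pymongo; host, port, and credentials are placeholder assumptions.

from pymongo import MongoClient

# Build the keyword arguments first, then hand them to MongoClient in one
# call, mirroring the conn_kwargs pattern above. Values are placeholders.
conn_kwargs = {
    "tz_aware": True,
    "host": "localhost",   # or the local bind address of an SSH tunnel
    "port": 27017,
    "username": "reader",  # only include if the connection defines one
    "tls": True,
}
client = MongoClient(**conn_kwargs)
print(client.admin.command("ping"))  # simple connectivity check
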
Example #2
    def dbconn(self):
        if not hasattr(self, "_dbconn"):
            if hasattr(self.conn, "ssh_conn_id") and self.conn.ssh_conn_id:
                if not hasattr(self, "_ssh_hook"):
                    self._ssh_hook = EWAHBaseHook.get_hook_from_conn_id(
                        conn_id=self.conn.ssh_conn_id)
                    self.local_bind_address = self._ssh_hook.start_tunnel(
                        self.conn.host, self.conn.port or self._DEFAULT_PORT)
            else:
                self.local_bind_address = (self.conn.host, self.conn.port)
            self._dbconn = self._get_db_conn()
        return self._dbconn
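
The same lazy-initialization idiom in isolation, sketched with sqlite3 so it runs without any external infrastructure.

import sqlite3

class LazyConnHook:
    """Sketch of the caching pattern above: the connection is created on
    first access and memoized on the instance afterwards."""

    def __init__(self, database=":memory:"):
        self.database = database

    @property
    def dbconn(self):
        if not hasattr(self, "_dbconn"):
            self._dbconn = sqlite3.connect(self.database)
        return self._dbconn

hook = LazyConnHook()
assert hook.dbconn is hook.dbconn  # repeated access returns the same object
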
Example #3
def kickoff_func(schema_full, schema_suffix, dwh_conn_id):
    # kickoff: create new dataset
    schema = schema_full + schema_suffix
    conn = EWAHBaseHook.get_hook_from_conn_id(dwh_conn_id).dbconn
    # delete dataset first if it already exists
    print("Deleting the dataset {0} if it already exists.".format(schema))
    conn.delete_dataset(schema, delete_contents=True, not_found_ok=True)
    print("Creating the dataset {0}.".format(schema))
    conn.create_dataset(schema)
    print("Done!")
Example #4
    def __init__(
            self,
            workbook_key,  # can be seen in the URL of the workbook
            sheet_key,  # name of the worksheet
            sheet_columns,  # list or dict[column name, position] defining which columns to load
            start_row=2,  # in what row does the data begin?
            end_row=None,  # optional: what is the last row? None gets all data
            *args,
            **kwargs):
        super().__init__(*args, **kwargs)

        credentials = BaseHook.get_connection(self.source_conn_id).extra_dejson
        credentials = credentials.get("client_secrets", credentials)

        _msg = "Google Service Account Credentials misspecified!"
        _msg += " Example of a correct specifidation: {0}".format(
            json.dumps(self._SAMPLE_JSON))
        for key in self._SAMPLE_JSON["client_secrets"]:
            if key not in credentials:
                raise Exception(_msg)

        column_match = {}
        if isinstance(sheet_columns, list):
            for position, column in enumerate(sheet_columns, start=1):
                column_match[position] = column
        elif isinstance(sheet_columns, dict):
            column_match = {
                self._translate_alphanumeric_column(value): key
                for key, value in sheet_columns.items()
            }
        else:
            raise Exception("sheet_columns must be a list or a dict!")

        self.client_secrets = credentials
        self.column_match = column_match
        self.workbook_key = workbook_key
        self.sheet_key = sheet_key
        self.start_row = start_row
        self.end_row = end_row
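
The helper _translate_alphanumeric_column is referenced but not shown in these excerpts; a possible sketch of such a translation, assuming "A" maps to 1, "AB" to 28, and integers pass through unchanged.

def translate_alphanumeric_column(value):
    """Turn a spreadsheet column label like 'A' or 'AB' into its 1-based
    position; integer positions are passed through unchanged."""
    if isinstance(value, int):
        return value
    position = 0
    for char in value.strip().upper():
        position = position * 26 + (ord(char) - ord("A") + 1)
    return position

assert translate_alphanumeric_column("A") == 1
assert translate_alphanumeric_column("AB") == 28
assert translate_alphanumeric_column(3) == 3
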
Example #5
    def __init__(
            self,
            workbook_key,  # can be seen in the URL of the workbook
            sheet_key,  # name of the worksheet
            start_row=2,  # in what row does the data begin?
            end_row=None,  # optional: what is the last row? None gets all data
            *args,
            **kwargs):
        super().__init__(*args, **kwargs)

        credentials = BaseHook.get_connection(self.source_conn_id).extra_dejson
        credentials = credentials.get("client_secrets", credentials)

        _msg = "Google Service Account Credentials misspecified!"
        _msg += " Example of a correct specifidation: {0}".format(
            json.dumps(self._SAMPLE_JSON))
        for key in self._SAMPLE_JSON["client_secrets"]:
            if key not in credentials:
                raise Exception(_msg)

        column_match = {}
        for col_key, col_def in self.columns_definition.items():
            if (not col_def) or (not col_def.get(EC.QBC_FIELD_GSHEET_COLNO)):
                raise Exception(
                    ("Column {0} is missing information regarding the " +
                     "position of the column in the sheet.").format(col_key))
            column_match[self._translate_alphanumeric_column(
                col_def[EC.QBC_FIELD_GSHEET_COLNO])] = col_key

        self.client_secrets = credentials
        self.column_match = column_match
        self.workbook_key = workbook_key
        self.sheet_key = sheet_key
        self.start_row = start_row
        self.end_row = end_row
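
The _SAMPLE_JSON constant is not shown either; a sketch of the validation idea with an explicit, assumed key set.

SAMPLE_JSON = {
    "client_secrets": {  # assumed required keys, for illustration only
        "type": "service_account",
        "project_id": "my-project",
        "private_key": "-----BEGIN PRIVATE KEY-----\n...",
        "client_email": "loader@my-project.iam.gserviceaccount.com",
    }
}

def validate_credentials(credentials):
    missing = [key for key in SAMPLE_JSON["client_secrets"] if key not in credentials]
    if missing:
        raise ValueError("Credentials are missing keys: {0}".format(missing))

validate_credentials(SAMPLE_JSON["client_secrets"])  # passes
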
Example #6
    def execute(self, context):
        """Why this method is defined here:
        When executing a task, airflow calls this method. Generally, this
        method contains the "business logic" of the individual operator.
        However, EWAH may want to do some actions for all operators. Thus,
        the child operators shall have an ewah_execute() function which is
        called by this general execute() method.
        """

        # required for metadata in data upload
        self._execution_time = datetime_utcnow_with_tz()
        self._context = context

        self.uploader = self.uploader(
            EWAHBaseHook.get_connection(self.dwh_conn_id))

        if self.source_conn_id:
            # resolve conn id here & delete the object to avoid usage elsewhere
            self.source_conn = EWAHBaseHook.get_connection(self.source_conn_id)
            self.source_hook = self.source_conn.get_hook()
        del self.source_conn_id

        if self._CONN_TYPE:
            _msg = "Error - connection type must be {0}!".format(
                self._CONN_TYPE)
            assert self._CONN_TYPE == self.source_conn.conn_type, _msg

        temp_schema_name = self.target_schema_name + self.target_schema_suffix
        # Create a new copy of the target table.
        # This is so data is loaded into a new table and if data loading
        # fails, the original data is not corrupted. At a new try or re-run,
        # the original table is just copied anew.
        if self.extract_strategy != EC.ES_FULL_REFRESH:
            # Full refresh always drops and replaces the tables completely
            self.uploader.copy_table(
                old_schema=self.target_schema_name,
                old_table=self.target_table_name,
                new_schema=temp_schema_name,
                new_table=self.target_table_name,
                database_name=self.target_database_name,
            )

        # set load_data_from and load_data_until as required
        data_from = ada(self.load_data_from)
        data_until = ada(self.load_data_until)
        if self.extract_strategy == EC.ES_INCREMENTAL:
            _tdz = timedelta(days=0)  # aka timedelta zero
            _ed = context["execution_date"]
            _ned = context["next_execution_date"]

            # normal incremental load
            _ed -= self.load_data_from_relative or _tdz
            data_from = max(_ed, data_from or _ed)
            if not self.test_if_target_table_exists():
                # Load data from scratch!
                data_from = ada(self.reload_data_from) or data_from

            _ned += self.load_data_until_relative or _tdz
            data_until = min(_ned, data_until or _ned)

        elif self.extract_strategy == EC.ES_FULL_REFRESH:
            # Values may still be set as static values
            data_from = ada(self.reload_data_from) or data_from

        else:
            _msg = "Must define load_data_from etc. behavior for load strategy!"
            raise Exception(_msg)

        self.data_from = data_from
        self.data_until = data_until
        # del variables to make sure they are not used later on
        del self.load_data_from
        del self.reload_data_from
        del self.load_data_until
        del self.load_data_from_relative
        del self.load_data_until_relative

        # Have an option to wait until a short period (e.g. 2 minutes) past
        # the incremental loading range timeframe to ensure that all data is
        # loaded, useful e.g. if APIs lag or if server timestamps are not
        # perfectly accurate.
        # When a DAG is executed as soon as possible, some data sources
        # may not immediately have up-to-date data available via their API.
        # E.g., a query for all data until 12:30pm may only return complete
        # results after 12:32pm due to internal delays. In those cases, make
        # sure the (incremental loading) DAGs don't execute too quickly.
        if self.wait_for_seconds and self.extract_strategy == EC.ES_INCREMENTAL:
            wait_until = context.get("next_execution_date")
            if wait_until:
                wait_until += timedelta(seconds=self.wait_for_seconds)
                self.log.info("Awaiting execution until {0}...".format(
                    str(wait_until), ))
            while wait_until and datetime_utcnow_with_tz() < wait_until:
                # Only sleep a maximum of 5s at a time
                wait_for_timedelta = wait_until - datetime_utcnow_with_tz()
                time.sleep(max(0, min(wait_for_timedelta.total_seconds(), 5)))

        # execute operator
        if self.load_data_chunking_timedelta and data_from and data_until:
            # Chunking to avoid OOM
            assert data_until > data_from
            assert self.load_data_chunking_timedelta > timedelta(days=0)
            while self.data_from < data_until:
                self.data_until = self.data_from
                self.data_until += self.load_data_chunking_timedelta
                self.data_until = min(self.data_until, data_until)
                self.ewah_execute(context)
                self.data_from += self.load_data_chunking_timedelta
        else:
            self.ewah_execute(context)

        # if PostgreSQL and arg given: create indices
        for column in self.index_columns:
            assert self.dwh_engine == EC.DWH_ENGINE_POSTGRES
            # Use hashlib to create a unique 63 character string as index
            # name to avoid breaching index name length limits & accidental
            # duplicates / missing indices due to name truncation leading to
            # identical index names.
            self.uploader.dwh_hook.execute(
                self._INDEX_QUERY.format(
                    "__ewah_" + hashlib.blake2b(
                        (temp_schema_name + "." + self.target_table_name +
                         "." + column).encode(),
                        digest_size=28,
                    ).hexdigest(),
                    self.target_schema_name + self.target_schema_suffix,
                    self.target_table_name,
                    column,
                ))

        # commit only at the end, so that no data may be committed before an
        # error occurs.
        self.log.info("Now committing changes!")
        self.uploader.commit()
        self.uploader.close()
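
The chunking loop at the end of the method can be hard to follow inline; an equivalent standalone sketch of the windowing logic.

from datetime import datetime, timedelta

def chunk_ranges(data_from, data_until, chunk_size):
    """Yield (from, until) windows covering [data_from, data_until), so each
    extraction stays small enough to avoid out-of-memory errors."""
    current = data_from
    while current < data_until:
        yield current, min(current + chunk_size, data_until)
        current += chunk_size

for start, end in chunk_ranges(
        datetime(2023, 1, 1), datetime(2023, 1, 10), timedelta(days=4)):
    print(start, "->", end)  # three windows: 4 + 4 + 1 days
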
Example #7
        def final_func(schema_name, schema_suffix, dwh_conn_id):
            # final: move new data into the final dataset
            conn = EWAHBaseHook.get_hook_from_conn_id(dwh_conn_id).dbconn
            # get dataset objects
            try:  # create final dataset if not exists
                ds_final = conn.get_dataset(schema_name)
            except Exception:
                print("Creating dataset {0}".format(schema_name))
                ds_final = conn.create_dataset(schema_name)
            ds_temp = conn.get_dataset(schema_name + schema_suffix)

            # copy all tables from temp dataset to final dataset
            new_tables = conn.list_tables(ds_temp)
            new_table_ids = [
                table.table_id for table in conn.list_tables(ds_temp)
            ]
            old_table_ids = [
                table.table_id for table in conn.list_tables(ds_final)
            ]
            copy_jobs = []
            for table in new_tables:
                print("Copying table {0} from temp to final dataset".format(
                    table.table_id))
                try:
                    old_table = conn.get_table(table=TableReference(
                        dataset_ref=ds_final, table_id=table.table_id))
                    conn.delete_table(old_table)
                except Exception:
                    # ignore failure, fails if old table does not exist to begin with
                    pass
                finally:
                    final_table = ds_final.table(table.table_id)
                    copy_jobs.append(conn.copy_table(table, final_table))

            # delete tables that don't exist in temp dataset from final dataset
            for table_id in old_table_ids:
                if table_id not in new_table_ids:
                    print("Deleting table {0}".format(table_id))
                    conn.delete_table(
                        conn.get_table(
                            TableReference(dataset_ref=ds_final,
                                           table_id=table_id)))

            # make sure all copy jobs succeeded
            while copy_jobs:
                sleep(0.1)
                job = copy_jobs.pop(0)
                job.result()
                assert job.state in ("RUNNING", "DONE")
                if job.state == "RUNNING":
                    copy_jobs.append(job)
                else:
                    print("Successfully copied {0}".format(
                        job.__dict__["_properties"]["configuration"]["copy"]
                        ["destinationTable"]["tableId"]))

            # delete temp dataset
            print("Deleting temp dataset.")
            conn.delete_dataset(ds_temp,
                                delete_contents=True,
                                not_found_ok=False)

            print("Done.")
Example #8
def dbt_dags_factory_legacy(
    dwh_engine,
    dwh_conn_id,
    project_name,
    dbt_schema_name,
    airflow_conn_id,
    dag_base_name="DBT_run",
    analytics_reader=None,  # list of users of DWH who are read-only
    schedule_interval=timedelta(hours=1),
    start_date=datetime(2019, 1, 1),
    default_args=None,
    folder=None,
    models=None,
    exclude=None,
):

    if analytics_reader:
        for statement in (
                "insert",
                "update",
                "delete",
                "drop",
                "create",
                "select",
                ";",
                "grant",
        ):
            for reader in analytics_reader:
                if statement in reader.lower():
                    raise Exception("Error! The analytics reader {0} " +
                                    "is invalid.".format(reader))

        # analytics_reader = analytics_reader.split(',')
        analytics_reader_sql = f'\nGRANT USAGE ON SCHEMA "{dbt_schema_name}"'
        analytics_reader_sql += ' TO "{0}";'
        analytics_reader_sql += (f'''
        \nGRANT SELECT ON ALL TABLES IN SCHEMA "{dbt_schema_name}"''' +
                                 ' TO "{0}";')
        analytics_reader_sql = "".join(
            [analytics_reader_sql.format(i) for i in analytics_reader])

    if models and not isinstance(models, str):
        models = " --models " + " ".join(models)
    else:
        models = ""

    if exclude and not isinstance(exclude, str):
        exclude = " --exclude " + " ".join(exclude)
    else:
        exclude = ""

    flags = models + exclude

    dag = DAG(
        dag_base_name,
        catchup=False,
        max_active_runs=1,
        schedule_interval=schedule_interval,
        start_date=start_date,
        default_args=default_args,
    )

    dag_full_refresh = DAG(
        dag_base_name + "_full_refresh",
        catchup=False,
        max_active_runs=1,
        schedule_interval=None,
        start_date=start_date,
        default_args=default_args,
    )

    folder = folder or (os.environ.get("AIRFLOW_HOME")
                        or conf.get("core", "airflow_home")).replace(
                            "airflow_home/airflow",
                            "dbt_home",
                        )

    bash_command = """
    cd {1}
    source env/bin/activate
    cd {2}
    dbt {0}
    """.format(
        "{0}",
        folder,
        project_name,
    )

    sensor_sql = """
        SELECT
            CASE WHEN COUNT(*) = 0 THEN 1 ELSE 0 END -- only run if exactly equal to 0
        FROM public.dag_run
        WHERE dag_id IN ('{0}', '{1}')
        and state = 'running'
        and not (run_id = '{2}')
    """.format(
        dag._dag_id,
        dag_full_refresh._dag_id,
        "{{ run_id }}",
    )

    # refactor?! not coupled to values in profiles.yml!
    if dwh_engine == EC.DWH_ENGINE_POSTGRES:
        conn = BaseHook.get_connection(dwh_conn_id)
        env = {
            "DBT_DWH_HOST": str(conn.host),
            "DBT_DWH_USER": str(conn.login),
            "DBT_DWH_PASS": str(conn.password),
            "DBT_DWH_PORT": str(conn.port),
            "DBT_DWH_DBNAME": str(conn.schema),
            "DBT_DWH_SCHEMA": dbt_schema_name,
            "DBT_PROFILES_DIR": folder,
        }
    elif dwh_engine == EC.DWH_ENGINE_SNOWFLAKE:
        analytics_conn = BaseHook.get_connection(dwh_conn_id)
        analytics_conn_extra = analytics_conn.extra_dejson
        env = {
            "DBT_ACCOUNT": analytics_conn_extra.get("account", analytics_conn.host),
            "DBT_USER": analytics_conn.login,
            "DBT_PASS": analytics_conn.password,
            "DBT_ROLE": analytics_conn_extra.get("role"),
            "DBT_DB": analytics_conn_extra.get("database"),
            "DBT_WH": analytics_conn_extra.get("warehouse"),
            "DBT_SCHEMA": dbt_schema_name,
            "DBT_PROFILES_DIR": folder,
        }
    else:
        raise ValueError("DWH type not implemented!")

    # with dag:
    snsr = EWAHSqlSensor(
        task_id="sense_dbt_conflict_avoided",
        conn_id=airflow_conn_id,
        sql=sensor_sql,
        poke_interval=5 * 60,
        mode="reschedule",  # don't block a worker and pool slot
        dag=dag,
    )

    dbt_seed = BashOperator(
        task_id="run_dbt_seed",
        bash_command=bash_command.format("seed"),
        env=env,
        dag=dag,
    )

    dbt_run = BashOperator(
        task_id="run_dbt",
        bash_command=bash_command.format("run" + flags),
        env=env,
        dag=dag,
    )

    dbt_test = BashOperator(
        task_id="test_dbt",
        bash_command=bash_command.format("test" + flags),
        env=env,
        dag=dag,
    )

    dbt_docs = BashOperator(
        task_id="create_dbt_docs",
        bash_command=bash_command.format("docs generate"),
        env=env,
        dag=dag,
    )

    snsr >> dbt_seed >> dbt_run >> dbt_test

    if analytics_reader:
        # This should not occur when using Snowflake
        read_rights = PostgresOperator(
            task_id="grant_access_to_read_users",
            sql=analytics_reader_sql,
            postgres_conn_id=dwh_conn_id,
            dag=dag,
        )
        dbt_test >> read_rights >> dbt_docs
    else:
        dbt_test >> dbt_docs

    # with dag_full_refresh:
    snsr = EWAHSqlSensor(
        task_id="sense_dbt_conflict_avoided",
        conn_id=airflow_conn_id,
        sql=sensor_sql,
        poke_interval=5 * 60,
        mode="reschedule",  # don't block a worker and pool slot
        dag=dag_full_refresh,
    )

    dbt_seed = BashOperator(
        task_id="run_dbt_seed",
        bash_command=bash_command.format("seed"),
        env=env,
        dag=dag_full_refresh,
    )

    dbt_run = BashOperator(
        task_id="run_dbt",
        bash_command=bash_command.format("run --full-refresh" + flags),
        env=env,
        dag=dag_full_refresh,
    )

    dbt_test = BashOperator(
        task_id="test_dbt",
        bash_command=bash_command.format("test" + flags),
        env=env,
        dag=dag_full_refresh,
    )

    dbt_docs = BashOperator(
        task_id="create_dbt_docs",
        bash_command=bash_command.format("docs generate"),
        env=env,
        dag=dag_full_refresh,
    )

    snsr >> dbt_seed >> dbt_run >> dbt_test

    if analytics_reader:
        read_rights = PostgresOperator(
            task_id="grant_access_to_read_users",
            sql=analytics_reader_sql,
            postgres_conn_id=dwh_conn_id,
            dag=dag_full_refresh,
        )
        dbt_test >> read_rights >> dbt_docs
    else:
        dbt_test >> dbt_docs
    return (dag, dag_full_refresh)
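
Stripped of the sensor and the grant steps, the task chain built above boils down to a DAG like the following sketch; the dag id, schedule, and bash commands are placeholder assumptions.

from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.bash import BashOperator

with DAG(
    "dbt_run_sketch",
    start_date=datetime(2023, 1, 1),
    schedule_interval=timedelta(hours=1),
    catchup=False,
    max_active_runs=1,
) as dag:
    dbt_seed = BashOperator(task_id="run_dbt_seed", bash_command="dbt seed")
    dbt_run = BashOperator(task_id="run_dbt", bash_command="dbt run")
    dbt_test = BashOperator(task_id="test_dbt", bash_command="dbt test")
    dbt_docs = BashOperator(task_id="create_dbt_docs", bash_command="dbt docs generate")

    dbt_seed >> dbt_run >> dbt_test >> dbt_docs
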
Example #9
    def execute(self, context):

        # env to be used in processes later
        env = os.environ.copy()
        env["PIP_USER"] = "******"

        # create a new temp folder, all action happens in here
        with TemporaryDirectory(prefix="__ewah_dbt_operator_") as tmp_dir:
            # clone repo into temp directory
            repo_dir = tmp_dir + os.path.sep + "repo"
            if self.repo_type == "git":
                # Clone repo into temp folder
                git_hook = EWAHBaseHook.get_hook_from_conn_id(
                    conn_id=self.git_conn_id)
                git_hook.clone_repo(repo_dir, env)
            elif self.repo_type == "local":
                # Copy local version of the repository into temp folder
                copy_tree(self.local_path, repo_dir)
            else:
                raise Exception("Not Implemented!")

            # create a virtual environment in temp folder
            venv_folder = tmp_dir + os.path.sep + "venv"
            self.log.info(
                "creating a new virtual environment in {0}...".format(
                    venv_folder, ))
            venv.create(venv_folder, with_pip=True)
            dbt_dir = repo_dir
            if self.subfolder:
                if not self.subfolder.startswith(os.path.sep):
                    self.subfolder = os.path.sep + self.subfolder
                dbt_dir += self.subfolder

            dwh_hook = EWAHBaseHook.get_hook_from_conn_id(self.dwh_conn_id)
            # in case of SSH: execute a query to create the connection and tunnel
            dwh_hook.execute("SELECT 1 AS a -- Testing the connection")
            dwh_conn = dwh_hook.conn

            # read profile name and dbt version & create temporary profiles.yml
            project_yml_file = dbt_dir
            if not project_yml_file.endswith(os.path.sep):
                project_yml_file += os.path.sep
            project_yml_file += "dbt_project.yml"
            with open(project_yml_file, "r") as project_yml_fobj:
                project_yml = yaml.load(project_yml_fobj, Loader=Loader)
            profile_name = project_yml["profile"]
            dbt_version = self.dbt_version or project_yml.get(
                "require-dbt-version")
            del self.dbt_version  # Make sure it can't accidentally be used below
            assert dbt_version, "Must supply dbt_version or set require-dbt-version!"
            if isinstance(dbt_version, str):
                if not dbt_version.startswith(("=", "<", ">")):
                    dbt_version = "==" + dbt_version
            elif isinstance(dbt_version, list):
                dbt_version = ",".join(dbt_version)
            else:
                raise Exception(
                    "dbt_version must be a string or a list of strings!")
            self.log.info('Creating temp profile "{0}"'.format(profile_name))
            profiles_yml = {
                "config": {
                    "send_anonymous_usage_stats": False,
                    "use_colors": False,  # colors won't be useful in logs
                },
            }
            if self.dwh_engine == EC.DWH_ENGINE_POSTGRES:
                mb_database = dwh_conn.schema
                profiles_yml[profile_name] = {
                    "target": "prod",  # same as the output defined below
                    "outputs": {
                        "prod": {  # for postgres
                            "type": "postgres",
                            "host": dwh_hook.local_bind_address[0],
                            "port": dwh_hook.local_bind_address[1],
                            "user": dwh_conn.login,
                            "pass": dwh_conn.password,
                            "dbname": dwh_conn.schema,
                            "schema": self.schema_name,
                            "threads": self.threads,
                            "keepalives_idle": self.keepalives_idle,
                        },
                    },
                }
            elif self.dwh_engine == EC.DWH_ENGINE_SNOWFLAKE:
                mb_database = self.database_name or dwh_conn.database
                profiles_yml[profile_name] = {
                    "target": "prod",  # same as the output defined below
                    "outputs": {
                        "prod": {  # for snowflake
                            "type": "snowflake",
                            "account": dwh_conn.account,
                            "user": dwh_conn.user,
                            "password": dwh_conn.password,
                            "role": dwh_conn.role,
                            "database": self.database_name or dwh_conn.database,
                            "warehouse": dwh_conn.warehouse,
                            "schema": self.schema_name or dwh_conn.schema,
                            "threads": self.threads,
                            "keepalives_idle": self.keepalives_idle,
                        },
                    },
                }
            elif self.dwh_engine == EC.DWH_ENGINE_BIGQUERY:
                mb_database = self.database_name
                profiles_yml[profile_name] = {
                    "target": "prod",  # same as the output defined below
                    "outputs": {
                        "prod": {
                            "type":
                            "bigquery",
                            "method":
                            "service-account-json",
                            "project":
                            self.database_name,
                            "dataset":
                            self.schema_name,
                            "threads":
                            self.threads,
                            "timeout_seconds":
                            self.keepalives_idle or 300,
                            "priority":
                            "interactive",
                            "keyfile_json":
                            json.loads(dwh_conn.service_account_json),
                        },
                    },
                }
                if dwh_conn.location:
                    profiles_yml[profile_name]["outputs"]["prod"][
                        "location"] = dwh_conn.location
            else:
                raise Exception("DWH Engine not implemented!")

            # install dbt into created venv
            cmd = []
            cmd.append("source {0}/bin/activate".format(venv_folder))
            cmd.append("pip install --quiet --upgrade pip setuptools")
            if re.search("[^0-9\.]0(\.[0-9]+)?(\.[0-9]+)?$", dbt_version):
                # regex checks whether the (last) version start with 0
                # if true, version <1.0.0 required
                cmd.append(
                    'pip install --quiet --upgrade "MarkupSafe<=2.0.1" "dbt{0}"'
                    .format(dbt_version))
            else:
                # Different pip behavior since dbt 1.0.0
                cmd.append(
                    'pip install --quiet --upgrade "MarkupSafe<=2.0.1" "dbt-{0}{1}"'
                    .format(
                        {
                            EC.DWH_ENGINE_POSTGRES: "postgres",
                            EC.DWH_ENGINE_SNOWFLAKE: "snowflake",
                            EC.DWH_ENGINE_BIGQUERY: "bigquery",
                        }[self.dwh_engine],
                        dbt_version,
                    ))

            cmd.append("dbt --version")
            cmd.append("deactivate")
            assert run_cmd(cmd, env, self.log.info) == 0

            # run commands with correct profile in the venv in the temp folder
            profiles_yml_name = tmp_dir + os.path.sep + "profiles.yml"
            env["DBT_PROFILES_DIR"] = os.path.abspath(tmp_dir)
            with open(profiles_yml_name, "w") as profiles_file:
                # write profile into profiles.yml file
                yaml.dump(profiles_yml,
                          profiles_file,
                          default_flow_style=False)

                # run dbt commands
                self.log.info("Now running commands dbt!")
                cmd = []
                cmd.append("cd {0}".format(dbt_dir))
                cmd.append("source {0}/bin/activate".format(venv_folder))
                if self.repo_type == "local":
                    cmd.append("dbt clean")
                cmd.append("dbt deps")
                [cmd.append("dbt {0}".format(dc)) for dc in self.dbt_commands]
                cmd.append("deactivate")
                assert run_cmd(cmd, env, self.log.info) == 0

                if self.metabase_conn_id:
                    # Push docs to Metabase at the end of the run!
                    metabase_hook = EWAHBaseHook.get_hook_from_conn_id(
                        conn_id=self.metabase_conn_id)
                    metabase_hook.push_dbt_docs_to_metabase(
                        dbt_project_path=dbt_dir,
                        dbt_database_name=mb_database,
                    )
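
The temporary profiles.yml written above ends up looking roughly like the following; a sketch that dumps a Postgres profile to disk, with all values as placeholders.

import yaml

profiles_yml = {
    "config": {"send_anonymous_usage_stats": False, "use_colors": False},
    "my_profile": {  # must match the "profile" entry of dbt_project.yml
        "target": "prod",
        "outputs": {
            "prod": {
                "type": "postgres",
                "host": "localhost",
                "port": 5432,
                "user": "dbt",
                "pass": "secret",
                "dbname": "analytics",
                "schema": "dbt",
                "threads": 4,
            },
        },
    },
}

with open("profiles.yml", "w") as profiles_file:
    yaml.dump(profiles_yml, profiles_file, default_flow_style=False)
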
Example #10
    def __init__(
            self,
            account_ids,
            insight_fields,
            level,
            time_increment=1,
            breakdowns=None,
            execution_waittime_seconds=15,  # wait for a while before execution
            #   between account_ids to avoid hitting rate limits during backfill
            pagination_limit=1000,
            async_job_read_frequency_seconds=5,
            *args,
            **kwargs):

        if kwargs.get("update_on_columns"):
            raise Exception("update_on_columns is set by operator!")

        if not hasattr(account_ids, "__iter__"):
            raise Exception(
                "account_ids must be an iterable, such as a list," +
                " of strings or integers!")

        if level == self.levels.ad:
            kwargs["update_on_columns"] = [
                "ad_id",
                "date_start",
                "date_stop",
            ] + (breakdowns or [])
            insight_fields += ["ad_id", "ad_name"]
            insight_fields = list(set(insight_fields))
        else:
            raise Exception("Specified level not supported!")

        if not ((isinstance(time_increment, str)
                 and time_increment in ["monthly", "all_days"]) or
                (isinstance(time_increment, int)
                 and 1 <= time_increment <= 90)):
            raise Exception(
                "time_increment must either be an integer " +
                'between 1 and 90, or a string of either "monthly" ' +
                'or "all_days". Recommended and default is the integer 1.')

        allowed_insight_fields = [
            _attr[1] for _attr in [
                member for member in inspect.getmembers(
                    AdsInsights.Field,
                    lambda a: not (inspect.isroutine(a)),
                ) if not (
                    member[0].startswith("__") and member[0].endswith("__"))
            ]
        ]
        for i_f in insight_fields:
            if i_f not in allowed_insight_fields:
                raise Exception((
                    "Field {0} is not an accepted value for insight_fields! " +
                    "Accepted field values:\n\t{1}\n").format(
                        i_f, "\n\t".join(allowed_insight_fields)))

        super().__init__(*args, **kwargs)

        credentials = BaseHook.get_connection(self.source_conn_id)
        extra = credentials.extra_dejson

        # Note: app_secret is not always required!
        if not extra.get("app_id"):
            raise Exception('Connection extra must contain an "app_id"!')
        if not extra.get("access_token", credentials.password):
            raise Exception(
                'Connection extra must contain an "access_token" ' +
                "if it is not saved as the connection password!")

        self.credentials = {
            "app_id": extra.get("app_id"),
            "app_secret": extra.get("app_secret"),
            "access_token": extra.get("access_token", credentials.password),
        }

        self.account_ids = account_ids
        self.insight_fields = insight_fields
        self.level = level
        self.time_increment = time_increment
        self.breakdowns = breakdowns
        self.execution_waittime_seconds = execution_waittime_seconds
        self.pagination_limit = pagination_limit
        self.async_job_read_frequency_seconds = async_job_read_frequency_seconds
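
The inspect.getmembers construction for allowed_insight_fields is dense; the same idea on a stand-in class, where Field is a hypothetical substitute for AdsInsights.Field.

import inspect

class Field:
    # Stand-in for facebook_business's AdsInsights.Field, for illustration.
    ad_id = "ad_id"
    ad_name = "ad_name"
    impressions = "impressions"

allowed_fields = [
    value
    for name, value in inspect.getmembers(Field, lambda a: not inspect.isroutine(a))
    if not (name.startswith("__") and name.endswith("__"))
]
print(allowed_fields)  # ['ad_id', 'ad_name', 'impressions']
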
Example #11
    def _get_hook(self):
        conn = EWAHBaseHook.get_connection(conn_id=self.conn_id)
        if not conn.conn_type.startswith("ewah"):
            raise Exception(
                "Must use an appropriate EWAH custom connection type!")
        return conn.get_hook()
Example #12
    def get_data_in_batches(self,
                            endpoint,
                            page_size=100,
                            batch_size=10000,
                            data_from=None):
        endpoint = self._ENDPOINTS.get(endpoint, endpoint)
        params = {}
        if data_from:
            assert endpoint == self._ENDPOINT_DAGRUNS
            # get DagRuns that ended since data_from
            params["end_date_gte"] = data_from.isoformat()
            params["order_by"] = "end_date"
        auth = requests.auth.HTTPBasicAuth(self.conn.login, self.conn.password)
        if self.conn.ssh_conn_id:
            ssh_hook = EWAHBaseHook.get_hook_from_conn_id(
                conn_id=self.conn.ssh_conn_id)
            ssh_host = self.conn.host or "localhost"
            ssh_host = ssh_host.replace("https://", "").replace("http://", "")
            ssh_port = self.conn.port
            if not ssh_port:
                if self.conn.protocol == "http":
                    ssh_port = 80
                else:
                    ssh_port = 443
            else:
                ssh_port = int(ssh_port)
            local_bind_address = ssh_hook.start_tunnel(ssh_host, ssh_port)
            host = "{2}://{0}:{1}".format(
                local_bind_address[0],
                str(local_bind_address[1]),
                self.conn.protocol or "http",
            )
        else:
            host = self.conn.host
            if not host.startswith("http"):
                host = self.conn.protocol + "://" + host
        url = self._BASE_URL.format(host, endpoint)
        params["limit"] = page_size
        params["offset"] = 0
        data = []
        i = 0
        while True:
            i += 1
            self.log.info("Making request {0} to {1}...".format(i, url))
            request = requests.get(url, params=params, auth=auth)
            assert request.status_code == 200, request.text
            response = request.json()
            keys = list(response.keys())
            if "total_entries" in keys:
                # Most endpoints use pagination and return "total_entries"
                # The key that holds the data may differ by endpoint
                if keys[0] == "total_entries":
                    data_key = keys[1]
                else:
                    data_key = keys[0]
                if not response[data_key]:
                    # You know that you fetched all items if an empty list is returned
                    # (Note: total_entries is not reliable)
                    yield data
                    data = []
                    break
                data += response[data_key]
                if len(data) >= batch_size:
                    yield data
                    data = []
            else:
                # Rare endpoint that does not paginate (usually singletons)
                yield [response]
                break
            params["offset"] = params["offset"] + params["limit"]

        if self.conn.ssh_conn_id:
            ssh_hook.stop_tunnel()
            del ssh_hook
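
The offset/limit pagination at the core of this method, sketched against a generic JSON API; the URL and the response key are assumptions.

import requests

def fetch_all(url, page_size=100):
    """Yield every item from an offset/limit-paginated endpoint; stops when
    an empty page comes back, mirroring the loop above."""
    params = {"limit": page_size, "offset": 0}
    while True:
        response = requests.get(url, params=params)
        response.raise_for_status()
        page = response.json()["items"]  # assumed response key
        if not page:
            break
        yield from page
        params["offset"] += params["limit"]
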
Example #13
    def __init__(
            self,
            api,  # one of _API_CORE_V3, _API_CORE_V4, _API_MULTI
            view_id,
            dimensions,
            metrics,
            page_size=10000,
            include_empty_rows=True,
            sampling_level=None,
            *args,
            **kwargs):
        if kwargs.get("update_on_columns"):
            raise Exception("update_on_columns supplied, but the field is " +
                            "auto-generated by the operator!")
        if api not in self._ACCEPTED_API:
            raise Exception("api must be one of these: {0}".format(
                ", ".join(self._ACCEPTED_API), ))

        if api == self._API_MULTI:
            shorthand = "mcf:"
        else:
            shorthand = "ga:"

        dimensions = [("" if dim.startswith(shorthand) else shorthand) + dim
                      for dim in dimensions]
        metrics = [("" if metric.startswith(shorthand) else shorthand) + metric
                   for metric in metrics]

        kwargs.update({"update_on_columns": [dim[3:] for dim in dimensions]})

        self.api = api
        self.view_id = view_id
        self.sampling_level = sampling_level
        self.dimensions = dimensions
        self.metrics = metrics
        self.page_size = page_size
        self.include_empty_rows = include_empty_rows

        self.metricMap = {
            "METRIC_TYPE_UNSPECIFIED": "varchar(255)",
            "CURRENCY": "decimal(20,5)",
            "INTEGER": "int(11)",
            "FLOAT": "decimal(20,5)",
            "PERCENT": "decimal(20,5)",
            "TIME": "time",
        }

        super().__init__(*args, **kwargs)

        credentials = BaseHook.get_connection(self.source_conn_id).extra_dejson

        if not credentials.get("client_secrets"):
            _msg = "Google Analytics Credentials misspecified!"
            _msg += " Example of a correct specifidation: {0}".format(
                json.dumps(self._SAMPLE_JSON))
            for key in self._SAMPLE_JSON["client_secrets"]:
                if key not in credentials:
                    raise Exception(_msg)

        if len(dimensions) > 7:
            raise Exception(
                ("Can only fetch up to 7 dimensions!" +
                 " Currently {0} Dimensions").format(str(len(dimensions)), ))

        if len(metrics) > 10:
            raise Exception(
                ("Can only fetch up to 10 metrics!" +
                 " Currently {0} Dimensions").format(str(len(metrics)), ))

        if self.page_size > 10000:
            raise Exception(
                "Please specify a page size equal to or lower than 10000.")
Example #14
    def execute(self, context):

        # env to be used in processes later
        env = os.environ.copy()

        # create a new temp folder, all action happens in here
        with TemporaryDirectory(prefix="__ewah_dbt_operator_") as tmp_dir:
            # clone repo into temp directory
            repo_dir = tmp_dir + os.path.sep + "repo"
            if self.repo_type == "git":
                git_hook = EWAHBaseHook.get_hook_from_conn_id(
                    conn_id=self.git_conn_id)
                git_hook.clone_repo(repo_dir, env)
            else:
                raise Exception("Not Implemented!")

            # create a virtual environment in temp folder
            venv_folder = tmp_dir + os.path.sep + "venv"
            self.log.info(
                "creating a new virtual environment in {0}...".format(
                    venv_folder, ))
            venv.create(venv_folder, with_pip=True)

            # install dbt into created venv
            self.log.info("installing dbt=={0}".format(self.dbt_version))
            cmd = []
            cmd.append("source {0}/bin/activate".format(venv_folder))
            cmd.append("pip install --quiet --upgrade dbt=={0}".format(
                self.dbt_version))
            cmd.append("dbt --version")
            cmd.append("deactivate")
            assert run_cmd(cmd, env, self.log.info) == 0

            dbt_dir = repo_dir
            if self.subfolder:
                if not self.subfolder.startswith(os.path.sep):
                    self.subfolder = os.path.sep + self.subfolder
                dbt_dir += self.subfolder

            dwh_conn = EWAHBaseHook.get_connection(self.dwh_conn_id)

            # read profile name & create temporary profiles.yml
            project_yml_file = dbt_dir
            if not project_yml_file.endswith(os.path.sep):
                project_yml_file += os.path.sep
            project_yml_file += "dbt_project.yml"
            with open(project_yml_file, "r") as project_yml_fobj:
                project_yml = yaml.load(project_yml_fobj, Loader=Loader)
            profile_name = project_yml["profile"]
            self.log.info('Creating temp profile "{0}"'.format(profile_name))
            profiles_yml = {
                "config": {
                    "send_anonymous_usage_stats": False,
                    "use_colors": False,  # colors won't be useful in logs
                },
            }
            if self.dwh_engine == EC.DWH_ENGINE_POSTGRES:
                profiles_yml[profile_name] = {
                    "target": "prod",  # same as the output defined below
                    "outputs": {
                        "prod": {  # for postgres
                            "type": "postgres",
                            "host": dwh_conn.host,
                            "port": dwh_conn.port or "5432",
                            "user": dwh_conn.login,
                            "pass": dwh_conn.password,
                            "dbname": dwh_conn.schema,
                            "schema": self.schema_name,
                            "threads": self.threads,
                            "keepalives_idle": self.keepalives_idle,
                        },
                    },
                }
            elif self.dwh_engine == EC.DWH_ENGINE_SNOWFLAKE:
                profiles_yml[profile_name] = {
                    "target": "prod",  # same as the output defined below
                    "outputs": {
                        "prod": {  # for snowflake
                            "type": "snowflake",
                            "account": dwh_conn.account,
                            "user": dwh_conn.user,
                            "password": dwh_conn.password,
                            "role": dwh_conn.role,
                            "database": self.database_name or dwh_conn.database,
                            "warehouse": dwh_conn.warehouse,
                            "schema": self.schema_name or dwh_conn.schema,
                            "threads": self.threads,
                            "keepalives_idle": self.keepalives_idle,
                        },
                    },
                }
            else:
                raise Exception("DWH Engine not implemented!")

            # run commands with correct profile in the venv in the temp folder
            profiles_yml_name = tmp_dir + os.path.sep + "profiles.yml"
            env["DBT_PROFILES_DIR"] = os.path.abspath(tmp_dir)
            with open(profiles_yml_name, "w") as profiles_file:
                # write profile into profiles.yml file
                yaml.dump(profiles_yml,
                          profiles_file,
                          default_flow_style=False)

                # run dbt commands
                self.log.info("Now running commands dbt!")
                cmd = []
                cmd.append("cd {0}".format(dbt_dir))
                cmd.append("source {0}/bin/activate".format(venv_folder))
                cmd.append("dbt deps")
                [cmd.append("dbt {0}".format(dc)) for dc in self.dbt_commands]
                cmd.append("deactivate")
                assert run_cmd(cmd, env, self.log.info) == 0

            # if applicable: close SSH tunnel
            if hasattr(self, "ssh_tunnel_forwarder"):
                self.log.info("Stopping!")
                self.ssh_tunnel_forwarder.stop()
                del self.ssh_tunnel_forwarder
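
run_cmd is used throughout these operators but not shown in the excerpts; a possible sketch of such a helper that runs a list of shell commands in one bash session and returns the exit code.

import subprocess

def run_cmd(commands, env, logger=print):
    """Hypothetical sketch of the run_cmd helper: chain the commands with
    '&&', stream output to the logger, and return bash's exit code."""
    process = subprocess.Popen(
        ["/usr/bin/env", "bash", "-c", " && ".join(commands)],
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    for line in iter(process.stdout.readline, b""):
        logger(line.decode(errors="replace").rstrip())
    return process.wait()
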
Example #15
def execute_snowflake(sql, conn_id, **kwargs):
    hook = EWAHBaseHook.get_hook_from_conn_id(conn_id)
    hook.execute(sql)
    hook.close()
Example #16
    def start_tunnel(self,
                     remote_host: str,
                     remote_port: int,
                     tunnel_timeout: Optional[int] = 30) -> Tuple[str, int]:
        """Starts the SSH tunnel with port forwarding. Returns the host and port tuple
        that can be used to connect to the remote.

        :param remote_host: Host of the remote that the tunnel should port forward to.
            Can be "localhost" e.g. if tunneling into a server that hosts a database.
        :param remote_port: Port that goes along with remote_host.
        :param tunnel_timeout: Optional timeout setting. Supply a higher number if the
            default (30s) is too low.
        :returns: Local bind address aka tuple of local_bind_host and local_bind_port.
            Calls to local_bind_host:local_bind_port will be forwarded to
            remote_host:remote_port via the SSH tunnel.
        """

        if not hasattr(self, "_ssh_tunnel_forwarder"):
            # Tunnel is not started yet - start it now!

            # Set a specific tunnel timeout if applicable
            if tunnel_timeout:
                old_timeout = sshtunnel.TUNNEL_TIMEOUT
                sshtunnel.TUNNEL_TIMEOUT = tunnel_timeout

            try:
                # Build kwargs dict for SSH Tunnel Forwarder
                if self.conn.ssh_proxy_server:
                    # Use the proxy SSH server as target
                    self._ssh_hook = EWAHBaseHook.get_hook_from_conn_id(
                        conn_id=self.conn.ssh_proxy_server, )
                    kwargs = {
                        "ssh_address_or_host":
                        self._ssh_hook.start_tunnel(self.conn.host,
                                                    self.conn.port or 22),
                        "remote_bind_address": (remote_host, remote_port),
                    }
                else:
                    kwargs = {
                        "ssh_address_or_host": (self.conn.host, self.conn.port
                                                or 22),
                        "remote_bind_address": (remote_host, remote_port),
                    }

                if self.conn.username:
                    kwargs["ssh_username"] = self.conn.username
                if self.conn.password:
                    kwargs["ssh_password"] = self.conn.password

                # Save private key in a temporary file, if applicable
                with NamedTemporaryFile() as keyfile:
                    if self.conn.private_key:
                        keyfile.write(self.conn.private_key.encode())
                        keyfile.flush()
                        kwargs["ssh_pkey"] = os.path.abspath(keyfile.name)
                    self.log.info("Opening SSH Tunnel to {0}:{1}...".format(
                        *kwargs["ssh_address_or_host"]))
                    self._ssh_tunnel_forwarder = sshtunnel.SSHTunnelForwarder(
                        **kwargs)
                    self._ssh_tunnel_forwarder.start()
            except:
                # Set package constant back to original setting, if applicable
                if tunnel_timeout:
                    sshtunnel.TUNNEL_TIMEOUT = old_timeout
                raise

        return ("localhost", self._ssh_tunnel_forwarder.local_bind_port)
Example #17
    def execute(self, context):
        hook = EWAHBaseHook.get_hook_from_conn_id(self.postgres_conn_id)
        hook.execute(sql=self.sql, params=self.parameters, commit=True)
        hook.close()  # SSH tunnel does not close if hook is not closed first
Example #18
def etl_schema_tasks(
        dag,
        dwh_engine,
        dwh_conn_id,
        target_schema_name,
        target_schema_suffix="_next",
        target_database_name=None,
        read_right_users=None,  # Only for PostgreSQL
        **additional_task_args):

    if dwh_engine == EC.DWH_ENGINE_POSTGRES:
        sql_kickoff = """
            DROP SCHEMA IF EXISTS "{schema_name}{schema_suffix}" CASCADE;
            CREATE SCHEMA "{schema_name}{schema_suffix}";
        """.format(
            schema_name=target_schema_name,
            schema_suffix=target_schema_suffix,
        )
        sql_final = """
            DROP SCHEMA IF EXISTS "{schema_name}" CASCADE;
            ALTER SCHEMA "{schema_name}{schema_suffix}"
                RENAME TO "{schema_name}";
        """.format(
            schema_name=target_schema_name,
            schema_suffix=target_schema_suffix,
        )

        # Don't fail final task just because a user or role that should
        # be granted read rights does not exist!
        grant_rights_sql = """
            DO $$
            BEGIN
              GRANT USAGE ON SCHEMA "{target_schema_name}" TO {user};
              GRANT SELECT ON ALL TABLES
                IN SCHEMA "{target_schema_name}" TO {user};
              EXCEPTION WHEN OTHERS THEN -- catches any error
                RAISE NOTICE 'not granting rights - user does not exist!';
            END
            $$;
        """
        if read_right_users:
            if not isinstance(read_right_users, list):
                raise Exception("Arg read_right_users must be of type List!")
            for user in read_right_users:
                if re.search(r"\s", user) or (";" in user):
                    _msg = "No whitespace or semicolons allowed in usernames!"
                    raise ValueError(_msg)
                sql_final += grant_rights_sql.format(
                    target_schema_name=target_schema_name,
                    user=user,
                )

        task_1_args = deepcopy(additional_task_args)
        task_2_args = deepcopy(additional_task_args)
        task_1_args.update({
            "sql": sql_kickoff,
            "task_id": "kickoff_{0}".format(target_schema_name),
            "dag": dag,
            "postgres_conn_id": dwh_conn_id,
        })
        task_2_args.update({
            "sql": sql_final,
            "task_id": "final_{0}".format(target_schema_name),
            "dag": dag,
            "postgres_conn_id": dwh_conn_id,
        })
        return (PGO(**task_1_args), PGO(**task_2_args))
    elif dwh_engine == EC.DWH_ENGINE_SNOWFLAKE:
        target_database_name = target_database_name or (
            EWAHBaseHook.get_connection(dwh_conn_id).database)
        sql_kickoff = """
            DROP SCHEMA IF EXISTS
                "{database}"."{schema_name}{schema_suffix}" CASCADE;
            CREATE SCHEMA "{database}"."{schema_name}{schema_suffix}";
        """.format(
            database=target_database_name,
            schema_name=target_schema_name,
            schema_suffix=target_schema_suffix,
        )
        sql_final = """
            DROP SCHEMA IF EXISTS "{database}"."{schema_name}" CASCADE;
            ALTER SCHEMA "{database}"."{schema_name}{schema_suffix}"
                RENAME TO "{schema_name}";
        """.format(
            database=target_database_name,
            schema_name=target_schema_name,
            schema_suffix=target_schema_suffix,
        )

        def execute_snowflake(sql, conn_id, **kwargs):
            hook = EWAHBaseHook.get_hook_from_conn_id(conn_id)
            hook.execute(sql)
            hook.close()

        task_1_args = deepcopy(additional_task_args)
        task_2_args = deepcopy(additional_task_args)
        task_1_args.update({
            "task_id": "kickoff_{0}".format(target_schema_name),
            "python_callable": execute_snowflake,
            "op_kwargs": {
                "sql": sql_kickoff,
                "conn_id": dwh_conn_id,
            },
            "provide_context": True,
            "dag": dag,
        })
        task_2_args.update({
            "task_id": "final_{0}".format(target_schema_name),
            "python_callable": execute_snowflake,
            "op_kwargs": {
                "sql": sql_final,
                "conn_id": dwh_conn_id,
            },
            "provide_context": True,
            "dag": dag,
        })
        return (PO(**task_1_args), PO(**task_2_args))
    elif dwh_engine == EC.DWH_ENGINE_GS:
        # create dummy tasks
        return (
            DO(
                task_id="kickoff",
                dag=dag,
            ),
            DO(
                task_id="final",
                dag=dag,
            ),
        )
    else:
        raise ValueError("Feature not implemented!")
Example #19
    def execute(self, context):
        """Why this method is defined here:
        When executing a task, airflow calls this method. Generally, this
        method contains the "business logic" of the individual operator.
        However, EWAH may want to do some actions for all operators. Thus,
        the child operators shall have an ewah_execute() function which is
        called by this general execute() method.
        """

        self.log.info("""

            Running EWAH Operator {0}.
            DWH: {1} (connection id: {2})
            Extract Strategy: {3}
            Load Strategy: {4}

            """.format(
            str(self),
            self.dwh_engine,
            self.dwh_conn_id,
            self.extract_strategy,
            self.load_strategy,
        ))

        # required for metadata in data upload
        self._execution_time = datetime_utcnow_with_tz()
        self._context = context

        cleaner_callables = self.cleaner_callables or []

        if self.source_conn_id:
            # resolve conn id here & delete the object to avoid usage elsewhere
            self.source_conn = EWAHBaseHook.get_connection(self.source_conn_id)
            self.source_hook = self.source_conn.get_hook()
            if callable(
                    getattr(self.source_hook, "get_cleaner_callables", None)):
                hook_callables = self.source_hook.get_cleaner_callables()
                if callable(hook_callables):
                    cleaner_callables.append(hook_callables)
                elif hook_callables:
                    # Ought to be list of callables
                    cleaner_callables += hook_callables
        del self.source_conn_id

        if self._CONN_TYPE:
            assert (self._CONN_TYPE == self.source_conn.conn_type
                    ), "Error - connection type must be {0}!".format(
                        self._CONN_TYPE)

        uploader_callables = self.uploader_class.get_cleaner_callables()
        if callable(uploader_callables):
            cleaner_callables.append(uploader_callables)
        elif uploader_callables:
            cleaner_callables += uploader_callables

        self.uploader = self.uploader_class(
            dwh_conn=EWAHBaseHook.get_connection(self.dwh_conn_id),
            cleaner=self.cleaner_class(
                default_row=self.default_values,
                include_columns=self.include_columns,
                exclude_columns=self.exclude_columns,
                add_metadata=self.add_metadata,
                rename_columns=self.rename_columns,
                hash_columns=self.hash_columns,
                hash_salt=self.hash_salt,
                additional_callables=cleaner_callables,
            ),
            table_name=self.target_table_name,
            schema_name=self.target_schema_name,
            schema_suffix=self.target_schema_suffix,
            database_name=self.target_database_name,
            primary_key=self.primary_key,
            load_strategy=self.load_strategy,
            use_temp_pickling=self.use_temp_pickling,
            pickling_upload_chunk_size=self.pickling_upload_chunk_size,
            pickle_compression=self.pickle_compression,
            deduplication_before_upload=self.deduplication_before_upload,
            **self.additional_uploader_kwargs,
        )

        # If applicable: set the session's default time zone
        if self.default_timezone:
            self.uploader.dwh_hook.execute("SET timezone TO '{0}'".format(
                self.default_timezone))

        # Create a new copy of the target table.
        # This is so data is loaded into a new table and if data loading
        # fails, the original data is not corrupted. At a new try or re-run,
        # the original table is just copied anew.
        if self.load_strategy != EC.LS_INSERT_REPLACE:
            # insert_replace always drops and replaces the tables completely
            self.uploader.copy_table()

        # set load_data_from and load_data_until as required
        data_from = ada(self.load_data_from)
        data_until = ada(self.load_data_until)
        if self.extract_strategy == EC.ES_INCREMENTAL:
            _tdz = timedelta(days=0)  # aka timedelta zero
            _ed = context["data_interval_start"]
            _ned = context["data_interval_end"]

            # normal incremental load
            _ed -= self.load_data_from_relative or _tdz
            data_from = min(_ed, data_from or _ed)
            if not self.test_if_target_table_exists():
                # Load data from scratch!
                data_from = ada(self.reload_data_from) or data_from

            _ned += self.load_data_until_relative or _tdz
            data_until = max(_ned, data_until or _ned)

        elif self.extract_strategy in (EC.ES_FULL_REFRESH, EC.ES_SUBSEQUENT):
            # Values may still be set as static values
            data_from = ada(self.reload_data_from) or data_from

        else:
            _msg = "Must define load_data_from etc. behavior for this extract strategy!"
            raise Exception(_msg)

        self.data_from = data_from
        self.data_until = data_until
        # del variables to make sure they are not used later on
        del self.load_data_from
        del self.reload_data_from
        del self.load_data_until
        del self.load_data_until_relative
        if self.extract_strategy != EC.ES_SUBSEQUENT:
            # keep this param for subsequent loads
            del self.load_data_from_relative

        # Optionally wait until a short period (e.g. 2 minutes) past the end
        # of the incremental loading timeframe, to ensure that all data is
        # available, e.g. if the source API lags or if server timestamps are
        # not perfectly accurate. When a DAG runs as soon as possible, a data
        # source may not yet serve complete data for the interval: querying
        # all data until 12.30pm may only return complete results after
        # 12.32pm due to internal delays. The wait prevents incremental loads
        # from executing too early in such cases.
        if self.wait_for_seconds and self.extract_strategy == EC.ES_INCREMENTAL:
            wait_until = context.get("data_interval_end")
            if wait_until:
                wait_until += timedelta(seconds=self.wait_for_seconds)
                self.log.info("Awaiting execution until {0}...".format(
                    str(wait_until), ))
            while wait_until and datetime_utcnow_with_tz() < wait_until:
                # Only sleep a maximum of 5s at a time
                wait_for_timedelta = wait_until - datetime_utcnow_with_tz()
                time.sleep(max(0, min(wait_for_timedelta.total_seconds(), 5)))

        # execute operator
        if self.load_data_chunking_timedelta and data_from and data_until:
            # Chunking to avoid OOM
            assert data_until > data_from
            assert self.load_data_chunking_timedelta > timedelta(days=0)
            while self.data_from < data_until:
                self.data_until = min(
                    self.data_from + self.load_data_chunking_timedelta,
                    data_until)
                self.log.info("Now loading from {0} to {1}...".format(
                    str(self.data_from), str(self.data_until)))
                self.ewah_execute(context)
                self.data_from += self.load_data_chunking_timedelta
        else:
            self.ewah_execute(context)

        # Run final scripts
        # TODO: Include indexes into uploader and then remove this step
        self.uploader.finalize_upload()

        # if PostgreSQL and arg given: create indices
        for column in self.index_columns:
            assert self.dwh_engine == EC.DWH_ENGINE_POSTGRES
            # Use hashlib to create a unique 63 character string as index
            # name to avoid breaching index name length limits & accidental
            # duplicates / missing indices due to name truncation leading to
            # identical index names.
            self.uploader.dwh_hook.execute(
                self._INDEX_QUERY.format(
                    "__ewah_" + hashlib.blake2b(
                        (self.target_schema_name + self.target_schema_suffix +
                         "." + self.target_table_name + "." + column).encode(),
                        digest_size=28,
                    ).hexdigest(),
                    self.target_schema_name + self.target_schema_suffix,
                    self.target_table_name,
                    column,
                ))

        # Commit only at the end, so that no data is committed if an error
        # occurs beforehand.
        self.log.info("Now committing changes!")
        self.uploader.commit()
        self.uploader.close()
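Given the plumbing above, a concrete operator only has to implement ewah_execute() and push its data through self.upload_data() for the loading window that execute() has prepared. A minimal sketch, assuming a hypothetical source hook method get_data_in_batches() and that the base class is importable as EWAHBaseOperator (both are assumptions, not verified EWAH API):

# from ewah.operators.base import EWAHBaseOperator  # import path assumed

class MyApiOperator(EWAHBaseOperator):
    # Hypothetical child operator: execute() has already resolved
    # self.source_hook and set self.data_from / self.data_until for the
    # current (possibly chunked) loading window before calling this method.

    def ewah_execute(self, context):
        batches = self.source_hook.get_data_in_batches(  # assumed hook method
            data_from=self.data_from,
            data_until=self.data_until,
        )
        for batch in batches:
            self.upload_data(batch)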
Example #20
    def execute_for_shop(
        self,
        context,
        shop_id,
        params,
        source_conn_id,
        auth_type,
    ):
        # Get data from shopify via REST API
        def add_get_transactions(data, shop, version, req_kwargs):
            # workaround to add transactions to orders
            self.log.info("Requesting transactions of orders...")
            base_url = "https://{shop}.myshopify.com/admin/api/{version}/orders/{id}/transactions.json"
            base_url = base_url.format(**{
                "shop": shop,
                "version": version,
                "id": "{id}",
            })

            for datum in data:
                order_id = datum["id"]
                # avoid hitting the API's calls-per-second limit
                time.sleep(1)
                url = base_url.format(id=order_id)
                req = requests.get(url, **req_kwargs)
                if req.status_code != 200:
                    self.log.info("response: " + str(req.status_code))
                    self.log.info("request text: " + req.text)
                    raise Exception("non-200 response!")
                transactions = json.loads(req.text).get("transactions", [])
                datum["transactions"] = transactions

            return data

        def add_get_inventoryitems(data, shop, version, req_kwargs):
            # workaround to get inventory item data (i.e. costs) for products
            self.log.info("Requesting inventory items of product variants...")
            base_url = (
                "https://{shop}.myshopify.com/admin/api/{version}/inventory_items.json"
            )
            url = base_url.format(
                shop=shop,
                version=version,
            )

            kwargs = copy.deepcopy(req_kwargs)

            for datum in data:
                ids = [
                    v["inventory_item_id"] for v in datum.get("variants", [])
                ]
                if ids:
                    kwargs["params"] = {"ids": copy.deepcopy(ids)}
                    # avoid hitting the API's calls-per-second limit
                    time.sleep(1)
                    req = requests.get(url, **kwargs)
                    if req.status_code != 200:
                        self.log.info("response: " + str(req.status_code))
                        self.log.info("request text: " + req.text)
                        raise Exception("non-200 response!")
                    inv_items = json.loads(req.text).get("inventory_items", [])
                    datum["inventory_items"] = inv_items

            return data

        def add_get_events(data, shop, version, req_kwargs):
            # workaround to add events of an order to orders
            self.log.info("Requesting events of orders...")
            base_url = "https://{shop}.myshopify.com/admin/api/{version}/orders/{id}/events.json"
            base_url = base_url.format(
                shop=shop,
                version=version,
                id="{id}",
            )

            for datum in data:
                order_id = datum["id"]
                # avoid hitting the API's calls-per-second limit
                time.sleep(1)
                url = base_url.format(id=order_id)
                req = requests.get(url, **req_kwargs)
                if req.status_code != 200:
                    self.log.info("response: " + str(req.status_code))
                    self.log.info("request text: " + req.text)
                    raise Exception("non-200 response!")
                events = json.loads(req.text).get("events", [])
                datum["events"] = events

            return data

        url = self._base_url.format(
            **{
                "shop": shop_id,
                "version": self.api_version,
                "object": self.object_metadata.get(
                    "_object_url",
                    self.shopify_object,
                ),
            })

        # get connection for the applicable shop
        conn = BaseHook.get_connection(source_conn_id)
        login = conn.login
        password = conn.password

        if auth_type == "access_token":
            headers = {
                "X-Shopify-Access-Token": password,
            }
            kwargs_init = {
                "headers": headers,
                "params": params,
            }
            kwargs_links = {"headers": headers}
        elif auth_type == "basic_auth":
            kwargs_init = {
                "params": params,
                "auth": HTTPBasicAuth(login, password),
            }
            kwargs_links = {"auth": HTTPBasicAuth(login, password)}
        else:
            raise Exception("Authentication type not accepted!")

        # get and upload data
        self.log.info(
            "Requesting data from REST API - url: {0}, params: {1}".format(
                url, str(params)))
        req_kwargs = kwargs_init
        is_first = True
        while is_first or (r.status_code == 200 and url):
            r = requests.get(url, **req_kwargs)
            if is_first:
                is_first = False
                req_kwargs = kwargs_links
            data = json.loads(r.text or "{}").get(
                self.object_metadata.get(
                    "_name_in_request_data",
                    self.shopify_object,
                ))
            if self.get_transactions_with_orders:
                data = add_get_transactions(
                    data=data,
                    shop=shop_id,
                    version=self.api_version,
                    req_kwargs=kwargs_links,
                )
            if self.get_events_with_orders:
                data = add_get_events(
                    data=data,
                    shop=shop_id,
                    version=self.api_version,
                    req_kwargs=kwargs_links,
                )
            if self.get_inventory_data_with_product_variants:
                data = add_get_inventoryitems(
                    data=data,
                    shop=shop_id,
                    version=self.api_version,
                    req_kwargs=kwargs_links,
                )
            self.upload_data(data)
            self.log.info("Requesting next page of data...")
            if r.headers.get("Link") and r.headers["Link"][-9:] == 'el="next"':
                url = r.headers["Link"][1:-13]
            else:
                url = None

        if r.status_code != 200:
            raise Exception(
                "Shopify request returned an error {1}: {0}".format(
                    r.text,
                    str(r.status_code),
                ))
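The pagination above extracts the next-page URL by slicing the raw Link header, which assumes the header carries exactly one link of the form <url>; rel="next". A sketch of the same loop using the Link parsing that requests already provides via r.links; url and req_kwargs stand in for the values built in execute_for_shop above:

import requests

def iterate_pages(url, req_kwargs):
    # Yield one decoded JSON payload per page, following rel="next" links.
    while url:
        r = requests.get(url, **req_kwargs)
        r.raise_for_status()
        yield r.json()
        # r.links is requests' parsed view of the Link header, e.g.
        # {"next": {"url": "https://...", "rel": "next"}} when more pages exist
        url = r.links.get("next", {}).get("url")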