def mongoclient(self):
    if not hasattr(self, "_mc"):
        if self.conn.ssh_conn_id:
            if self.conn.conn_style == "uri":
                raise Exception("Cannot have SSH tunnel with uri connection type!")
            if not hasattr(self, "_ssh_hook"):
                self._ssh_hook = EWAHBaseHook.get_hook_from_conn_id(
                    conn_id=self.conn.ssh_conn_id
                )
                self.local_bind_address = self._ssh_hook.start_tunnel(
                    self.conn.host, self.conn.port
                )
        else:
            self.local_bind_address = (self.conn.host, self.conn.port)

        conn_kwargs = {"tz_aware": True}

        if self.conn.conn_style == "uri":
            conn_kwargs["host"] = self.conn.uri
        else:
            conn_kwargs["host"] = self.local_bind_address[0]
            conn_kwargs["port"] = self.local_bind_address[1]
            if self.conn.username:
                conn_kwargs["username"] = self.conn.username
            if self.conn.password:
                conn_kwargs["password"] = self.conn.password

        with TemporaryDirectory() as tmp_dir:
            if self.conn.tls:
                conn_kwargs["tls"] = True
            with NamedTemporaryFile(dir=tmp_dir) as ssl_cert:
                with NamedTemporaryFile(dir=tmp_dir) as ssl_private:
                    if self.conn.ssl_cert:
                        ssl_cert.write(self.conn.ssl_cert.encode())
                        ssl_cert.seek(0)
                        conn_kwargs["ssl_certfile"] = os.path.abspath(ssl_cert.name)
                    if self.conn.ssl_private:
                        ssl_private.write(self.conn.ssl_private.encode())
                        ssl_private.seek(0)
                        conn_kwargs["ssl_keyfile"] = os.path.abspath(ssl_private.name)
                    if self.conn.ssl_password:
                        conn_kwargs[
                            "tlsCertificateKeyFilePassword"
                        ] = self.conn.ssl_password
                    if self.conn.tls_insecure:
                        conn_kwargs["tlsInsecure"] = True
                    if self.conn.auth_source:
                        conn_kwargs["authSource"] = self.conn.auth_source
                    if self.conn.auth_mechanism:
                        conn_kwargs["authMechanism"] = self.conn.auth_mechanism
                    self._mc = MongoClient(**conn_kwargs)

    return self._mc
def dbconn(self):
    if not hasattr(self, "_dbconn"):
        if hasattr(self.conn, "ssh_conn_id") and self.conn.ssh_conn_id:
            if not hasattr(self, "_ssh_hook"):
                self._ssh_hook = EWAHBaseHook.get_hook_from_conn_id(
                    conn_id=self.conn.ssh_conn_id
                )
                self.local_bind_address = self._ssh_hook.start_tunnel(
                    self.conn.host, self.conn.port or self._DEFAULT_PORT
                )
        else:
            self.local_bind_address = self.conn.host, self.conn.port
        self._dbconn = self._get_db_conn()
    return self._dbconn
def kickoff_func(schema_full, schema_suffix, dwh_conn_id):
    # kickoff: create new dataset
    schema = schema_full + schema_suffix
    conn = EWAHBaseHook.get_hook_from_conn_id(dwh_conn_id).dbconn
    # delete dataset first if it already exists
    print("Deleting the dataset {0} if it already exists.".format(schema))
    conn.delete_dataset(schema, delete_contents=True, not_found_ok=True)
    print("Creating the dataset {0}.".format(schema))
    conn.create_dataset(schema)
    print("Done!")
def __init__(
    self,
    workbook_key,  # can be seen in the URL of the workbook
    sheet_key,  # name of the worksheet
    sheet_columns,  # list or dict[column name, position] defining which columns to load
    start_row=2,  # in what row does the data begin?
    end_row=None,  # optional: what is the last row? None gets all data
    *args,
    **kwargs
):
    super().__init__(*args, **kwargs)

    credentials = BaseHook.get_connection(self.source_conn_id).extra_dejson
    credentials = credentials.get("client_secrets", credentials)
    _msg = "Google Service Account Credentials misspecified!"
    _msg += " Example of a correct specification: {0}".format(
        json.dumps(self._SAMPLE_JSON)
    )
    for key in self._SAMPLE_JSON["client_secrets"]:
        if not key in credentials:
            raise Exception(_msg)

    column_match = {}
    if isinstance(sheet_columns, list):
        i = 0
        for column in sheet_columns:
            i += 1
            column_match[i] = column
    elif isinstance(sheet_columns, dict):
        column_match = {
            self._translate_alphanumeric_column(value): key
            for key, value in sheet_columns.items()
        }
    else:
        raise Exception("sheet_columns must be a list or a dict!")

    self.client_secrets = credentials
    self.column_match = column_match
    self.workbook_key = workbook_key
    self.sheet_key = sheet_key
    self.start_row = start_row
    self.end_row = end_row
def __init__(
    self,
    workbook_key,  # can be seen in the URL of the workbook
    sheet_key,  # name of the worksheet
    start_row=2,  # in what row does the data begin?
    end_row=None,  # optional: what is the last row? None gets all data
    *args,
    **kwargs
):
    super().__init__(*args, **kwargs)

    credentials = BaseHook.get_connection(self.source_conn_id).extra_dejson
    credentials = credentials.get("client_secrets", credentials)
    _msg = "Google Service Account Credentials misspecified!"
    _msg += " Example of a correct specification: {0}".format(
        json.dumps(self._SAMPLE_JSON)
    )
    for key in self._SAMPLE_JSON["client_secrets"]:
        if not key in credentials:
            raise Exception(_msg)

    column_match = {}
    for col_key, col_def in self.columns_definition.items():
        if (not col_def) or (not col_def.get(EC.QBC_FIELD_GSHEET_COLNO)):
            raise Exception(
                (
                    "Column {0} is missing information regarding the "
                    + "position of the column in the sheet."
                ).format(col_key)
            )
        column_match.update({
            self._translate_alphanumeric_column(
                col_def[EC.QBC_FIELD_GSHEET_COLNO],
            ): col_key,
        })

    self.client_secrets = credentials
    self.column_match = column_match
    self.workbook_key = workbook_key
    self.sheet_key = sheet_key
    self.start_row = start_row
    self.end_row = end_row
def execute(self, context):
    """Why this method is defined here:
    When executing a task, airflow calls this method. Generally, this
    method contains the "business logic" of the individual operator.
    However, EWAH may want to do some actions for all operators. Thus,
    the child operators shall have an ewah_execute() function which is
    called by this general execute() method.
    """

    # required for metadata in data upload
    self._execution_time = datetime_utcnow_with_tz()
    self._context = context

    self.uploader = self.uploader(EWAHBaseHook.get_connection(self.dwh_conn_id))

    if self.source_conn_id:
        # resolve conn id here & delete the object to avoid usage elsewhere
        self.source_conn = EWAHBaseHook.get_connection(self.source_conn_id)
        self.source_hook = self.source_conn.get_hook()
    del self.source_conn_id

    if self._CONN_TYPE:
        _msg = "Error - connection type must be {0}!".format(self._CONN_TYPE)
        assert self._CONN_TYPE == self.source_conn.conn_type, _msg

    temp_schema_name = self.target_schema_name + self.target_schema_suffix
    # Create a new copy of the target table.
    # This is so data is loaded into a new table and if data loading
    # fails, the original data is not corrupted. At a new try or re-run,
    # the original table is just copied anew.
    if not self.extract_strategy == EC.ES_FULL_REFRESH:
        # Full refresh always drops and replaces the tables completely
        self.uploader.copy_table(
            old_schema=self.target_schema_name,
            old_table=self.target_table_name,
            new_schema=temp_schema_name,
            new_table=self.target_table_name,
            database_name=self.target_database_name,
        )

    # set load_data_from and load_data_until as required
    data_from = ada(self.load_data_from)
    data_until = ada(self.load_data_until)
    if self.extract_strategy == EC.ES_INCREMENTAL:
        _tdz = timedelta(days=0)  # aka timedelta zero
        _ed = context["execution_date"]
        _ned = context["next_execution_date"]

        # normal incremental load
        _ed -= self.load_data_from_relative or _tdz
        data_from = max(_ed, data_from or _ed)
        if not self.test_if_target_table_exists():
            # Load data from scratch!
            data_from = ada(self.reload_data_from) or data_from

        _ned += self.load_data_until_relative or _tdz
        data_until = min(_ned, data_until or _ned)
    elif self.extract_strategy == EC.ES_FULL_REFRESH:
        # Values may still be set as static values
        data_from = ada(self.reload_data_from) or data_from
    else:
        _msg = "Must define load_data_from etc. behavior for load strategy!"
        raise Exception(_msg)

    self.data_from = data_from
    self.data_until = data_until
    # del variables to make sure they are not used later on
    del self.load_data_from
    del self.reload_data_from
    del self.load_data_until
    del self.load_data_from_relative
    del self.load_data_until_relative

    # Have an option to wait until a short period (e.g. 2 minutes) past
    # the incremental loading range timeframe to ensure that all data is
    # loaded, useful e.g. if APIs lag or if server timestamps are not
    # perfectly accurate.
    # When a DAG is executed as soon as possible, some data sources
    # may not immediately have up to date data from their API.
    # E.g. querying all data until 12.30pm only gives all relevant data
    # after 12.32pm due to some internal delays. In those cases, make
    # sure the (incremental loading) DAGs don't execute too quickly.
    if self.wait_for_seconds and self.extract_strategy == EC.ES_INCREMENTAL:
        wait_until = context.get("next_execution_date")
        if wait_until:
            wait_until += timedelta(seconds=self.wait_for_seconds)
            self.log.info("Awaiting execution until {0}...".format(str(wait_until)))
        while wait_until and datetime_utcnow_with_tz() < wait_until:
            # Only sleep a maximum of 5s at a time
            wait_for_timedelta = wait_until - datetime_utcnow_with_tz()
            time.sleep(max(0, min(wait_for_timedelta.total_seconds(), 5)))

    # execute operator
    if self.load_data_chunking_timedelta and data_from and data_until:
        # Chunking to avoid OOM
        assert data_until > data_from
        assert self.load_data_chunking_timedelta > timedelta(days=0)
        while self.data_from < data_until:
            self.data_until = self.data_from
            self.data_until += self.load_data_chunking_timedelta
            self.data_until = min(self.data_until, data_until)
            self.ewah_execute(context)
            self.data_from += self.load_data_chunking_timedelta
    else:
        self.ewah_execute(context)

    # if PostgreSQL and arg given: create indices
    for column in self.index_columns:
        assert self.dwh_engine == EC.DWH_ENGINE_POSTGRES
        # Use hashlib to create a unique 63 character string as index
        # name to avoid breaching index name length limits & accidental
        # duplicates / missing indices due to name truncation leading to
        # identical index names.
        self.uploader.dwh_hook.execute(
            self._INDEX_QUERY.format(
                "__ewah_"
                + hashlib.blake2b(
                    (
                        temp_schema_name
                        + "."
                        + self.target_table_name
                        + "."
                        + column
                    ).encode(),
                    digest_size=28,
                ).hexdigest(),
                self.target_schema_name + self.target_schema_suffix,
                self.target_table_name,
                column,
            )
        )

    # commit only at the end, so that no data may be committed before an
    # error occurs.
    self.log.info("Now committing changes!")
    self.uploader.commit()
    self.uploader.close()
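# Hedged illustration (not part of the source): a minimal sketch of a child operator
# hooking into the execute() flow above by implementing ewah_execute(). The base
# class name `EWAHBaseOperator` and the `fetch_rows_from_source` helper are assumed
# for illustration only; `upload_data` mirrors the helper used by other operators
# in this codebase.
class MyApiOperator(EWAHBaseOperator):  # assumed base class name
    def ewah_execute(self, context):
        # execute() has already set self.data_from and self.data_until
        rows = fetch_rows_from_source(self.data_from, self.data_until)  # hypothetical
        self.upload_data(rows)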
def final_func(schema_name, schema_suffix, dwh_conn_id):
    # final: move new data into the final dataset
    conn = EWAHBaseHook.get_hook_from_conn_id(dwh_conn_id).dbconn

    # get dataset objects
    try:  # create final dataset if not exists
        ds_final = conn.get_dataset(schema_name)
    except:
        print("Creating dataset {0}".format(schema_name))
        ds_final = conn.create_dataset(schema_name)
    ds_temp = conn.get_dataset(schema_name + schema_suffix)

    # copy all tables from temp dataset to final dataset
    new_tables = conn.list_tables(ds_temp)
    new_table_ids = [table.table_id for table in conn.list_tables(ds_temp)]
    old_table_ids = [table.table_id for table in conn.list_tables(ds_final)]
    copy_jobs = []
    for table in new_tables:
        print("Copying table {0} from temp to final dataset".format(table.table_id))
        try:
            old_table = conn.get_table(
                table=TableReference(dataset_ref=ds_final, table_id=table.table_id)
            )
            conn.delete_table(old_table)
        except:
            # ignore failure, fails if old table does not exist to begin with
            pass
        finally:
            final_table = ds_final.table(table.table_id)
            copy_jobs.append(conn.copy_table(table, final_table))

    # delete tables that don't exist in temp dataset from final dataset
    for table_id in old_table_ids:
        if not table_id in new_table_ids:
            print("Deleting table {0}".format(table_id))
            conn.delete_table(
                conn.get_table(
                    TableReference(dataset_ref=ds_final, table_id=table_id)
                )
            )

    # make sure all copy jobs succeeded
    while copy_jobs:
        sleep(0.1)
        job = copy_jobs.pop(0)
        job.result()
        assert job.state in ("RUNNING", "DONE")
        if job.state == "RUNNING":
            copy_jobs.append(job)
        else:
            print(
                "Successfully copied {0}".format(
                    job.__dict__["_properties"]["configuration"]["copy"][
                        "destinationTable"
                    ]["tableId"]
                )
            )

    # delete temp dataset
    print("Deleting temp dataset.")
    conn.delete_dataset(ds_temp, delete_contents=True, not_found_ok=False)
    print("Done.")
def dbt_dags_factory_legacy(
    dwh_engine,
    dwh_conn_id,
    project_name,
    dbt_schema_name,
    airflow_conn_id,
    dag_base_name="DBT_run",
    analytics_reader=None,  # list of users of DWH who are read-only
    schedule_interval=timedelta(hours=1),
    start_date=datetime(2019, 1, 1),
    default_args=None,
    folder=None,
    models=None,
    exclude=None,
):
    if analytics_reader:
        for statement in (
            "insert",
            "update",
            "delete",
            "drop",
            "create",
            "select",
            ";",
            "grant",
        ):
            for reader in analytics_reader:
                if statement in reader.lower():
                    raise Exception(
                        "Error! The analytics reader {0} is invalid.".format(reader)
                    )

        # analytics_reader = analytics_reader.split(',')
        analytics_reader_sql = f'\nGRANT USAGE ON SCHEMA "{dbt_schema_name}"'
        analytics_reader_sql += ' TO "{0}";'
        analytics_reader_sql += (
            f'''\nGRANT SELECT ON ALL TABLES IN SCHEMA "{dbt_schema_name}"'''
            + ' TO "{0}";'
        )
        analytics_reader_sql = "".join(
            [analytics_reader_sql.format(i) for i in analytics_reader]
        )

    if models and not (type(models) == str):
        models = " --models " + " ".join(models)
    else:
        models = ""

    if exclude and not (type(exclude) == str):
        exclude = " --exclude " + " ".join(exclude)
    else:
        exclude = ""

    flags = models + exclude

    dag = DAG(
        dag_base_name,
        catchup=False,
        max_active_runs=1,
        schedule_interval=schedule_interval,
        start_date=start_date,
        default_args=default_args,
    )

    dag_full_refresh = DAG(
        dag_base_name + "_full_refresh",
        catchup=False,
        max_active_runs=1,
        schedule_interval=None,
        start_date=start_date,
        default_args=default_args,
    )

    folder = folder or (
        os.environ.get("AIRFLOW_HOME") or conf.get("core", "airflow_home")
    ).replace(
        "airflow_home/airflow",
        "dbt_home",
    )

    bash_command = """
    cd {1}
    source env/bin/activate
    cd {2}
    dbt {0}
    """.format(
        "{0}",
        folder,
        project_name,
    )

    sensor_sql = """
        SELECT
            CASE WHEN COUNT(*) = 0 THEN 1 ELSE 0 END
            -- only run if exactly equal to 0
        FROM public.dag_run
        WHERE dag_id IN ('{0}', '{1}')
          and state = 'running'
          and not (run_id = '{2}')
    """.format(
        dag._dag_id,
        dag_full_refresh._dag_id,
        "{{ run_id }}",
    )

    # refactor?! not coupled to values in profiles.yml!
    if dwh_engine == EC.DWH_ENGINE_POSTGRES:
        conn = BaseHook.get_connection(dwh_conn_id)
        env = {
            "DBT_DWH_HOST": str(conn.host),
            "DBT_DWH_USER": str(conn.login),
            "DBT_DWH_PASS": str(conn.password),
            "DBT_DWH_PORT": str(conn.port),
            "DBT_DWH_DBNAME": str(conn.schema),
            "DBT_DWH_SCHEMA": dbt_schema_name,
            "DBT_PROFILES_DIR": folder,
        }
    elif dwh_engine == EC.DWH_ENGINE_SNOWFLAKE:
        analytics_conn = BaseHook.get_connection(dwh_conn_id)
        analytics_conn_extra = analytics_conn.extra_dejson
        env = {
            "DBT_ACCOUNT": analytics_conn_extra.get(
                "account",
                analytics_conn.host,
            ),
            "DBT_USER": analytics_conn.login,
            "DBT_PASS": analytics_conn.password,
            "DBT_ROLE": analytics_conn_extra.get("role"),
            "DBT_DB": analytics_conn_extra.get("database"),
            "DBT_WH": analytics_conn_extra.get("warehouse"),
            "DBT_SCHEMA": dbt_schema_name,
            "DBT_PROFILES_DIR": folder,
        }
    else:
        raise ValueError("DWH type not implemented!")

    # with dag:
    snsr = EWAHSqlSensor(
        task_id="sense_dbt_conflict_avoided",
        conn_id=airflow_conn_id,
        sql=sensor_sql,
        poke_interval=5 * 60,
        mode="reschedule",  # don't block a worker and pool slot
        dag=dag,
    )

    dbt_seed = BashOperator(
        task_id="run_dbt_seed",
        bash_command=bash_command.format("seed"),
        env=env,
        dag=dag,
    )

    dbt_run = BashOperator(
        task_id="run_dbt",
        bash_command=bash_command.format("run" + flags),
        env=env,
        dag=dag,
    )

    dbt_test = BashOperator(
        task_id="test_dbt",
        bash_command=bash_command.format("test" + flags),
        env=env,
        dag=dag,
    )

    dbt_docs = BashOperator(
        task_id="create_dbt_docs",
        bash_command=bash_command.format("docs generate"),
        env=env,
        dag=dag,
    )

    snsr >> dbt_seed >> dbt_run >> dbt_test

    if analytics_reader:
        # This should not occur when using Snowflake
        read_rights = PostgresOperator(
            task_id="grant_access_to_read_users",
            sql=analytics_reader_sql,
            postgres_conn_id=dwh_conn_id,
            dag=dag,
        )
        dbt_test >> read_rights >> dbt_docs
    else:
        dbt_test >> dbt_docs

    # with dag_full_refresh:
    snsr = EWAHSqlSensor(
        task_id="sense_dbt_conflict_avoided",
        conn_id=airflow_conn_id,
        sql=sensor_sql,
        poke_interval=5 * 60,
        mode="reschedule",  # don't block a worker and pool slot
        dag=dag_full_refresh,
    )

    dbt_seed = BashOperator(
        task_id="run_dbt_seed",
        bash_command=bash_command.format("seed"),
        env=env,
        dag=dag_full_refresh,
    )

    dbt_run = BashOperator(
        task_id="run_dbt",
        bash_command=bash_command.format("run --full-refresh" + flags),
        env=env,
        dag=dag_full_refresh,
    )

    dbt_test = BashOperator(
        task_id="test_dbt",
        bash_command=bash_command.format("test" + flags),
        env=env,
        dag=dag_full_refresh,
    )

    dbt_docs = BashOperator(
        task_id="create_dbt_docs",
        bash_command=bash_command.format("docs generate"),
        env=env,
        dag=dag_full_refresh,
    )

    snsr >> dbt_seed >> dbt_run >> dbt_test

    if analytics_reader:
        read_rights = PostgresOperator(
            task_id="grant_access_to_read_users",
            sql=analytics_reader_sql,
            postgres_conn_id=dwh_conn_id,
            dag=dag_full_refresh,
        )
        dbt_test >> read_rights >> dbt_docs
    else:
        dbt_test >> dbt_docs

    return (dag, dag_full_refresh)
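# Hedged usage sketch (not part of the source): dbt_dags_factory_legacy returns the
# regular and the full-refresh DAG as a tuple, so a DAG file would typically unpack
# both and expose them at module level for the Airflow scheduler to pick up.
# The connection ids, project name, and schema below are placeholders.
dbt_dag, dbt_dag_full_refresh = dbt_dags_factory_legacy(
    dwh_engine=EC.DWH_ENGINE_POSTGRES,
    dwh_conn_id="dwh",                      # placeholder connection id
    project_name="my_dbt_project",          # placeholder dbt project folder name
    dbt_schema_name="analytics",            # placeholder target schema
    airflow_conn_id="airflow_metadata_db",  # placeholder metadata DB connection
)
globals()[dbt_dag._dag_id] = dbt_dag
globals()[dbt_dag_full_refresh._dag_id] = dbt_dag_full_refresh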
def execute(self, context):
    # env to be used in processes later
    env = os.environ.copy()
    env["PIP_USER"] = "******"

    # create a new temp folder, all action happens in here
    with TemporaryDirectory(prefix="__ewah_dbt_operator_") as tmp_dir:
        # clone repo into temp directory
        repo_dir = tmp_dir + os.path.sep + "repo"
        if self.repo_type == "git":
            # Clone repo into temp folder
            git_hook = EWAHBaseHook.get_hook_from_conn_id(conn_id=self.git_conn_id)
            git_hook.clone_repo(repo_dir, env)
        elif self.repo_type == "local":
            # Copy local version of the repository into temp folder
            copy_tree(self.local_path, repo_dir)
        else:
            raise Exception("Not Implemented!")

        # create a virtual environment in temp folder
        venv_folder = tmp_dir + os.path.sep + "venv"
        self.log.info(
            "creating a new virtual environment in {0}...".format(venv_folder)
        )
        venv.create(venv_folder, with_pip=True)

        dbt_dir = repo_dir
        if self.subfolder:
            if not self.subfolder[:1] == os.path.sep:
                self.subfolder = os.path.sep + self.subfolder
            dbt_dir += self.subfolder

        dwh_hook = EWAHBaseHook.get_hook_from_conn_id(self.dwh_conn_id)
        # in case of SSH: execute a query to create the connection and tunnel
        dwh_hook.execute("SELECT 1 AS a -- Testing the connection")
        dwh_conn = dwh_hook.conn

        # read profile name and dbt version & create temporary profiles.yml
        project_yml_file = dbt_dir
        if not project_yml_file[-1:] == os.path.sep:
            project_yml_file += os.path.sep
        project_yml_file += "dbt_project.yml"
        project_yml = yaml.load(open(project_yml_file, "r"), Loader=Loader)
        profile_name = project_yml["profile"]
        dbt_version = self.dbt_version or project_yml.get("require-dbt-version")
        del self.dbt_version  # Make sure it can't accidentally be used below
        assert dbt_version, "Must supply dbt_version or set require-dbt-version!"
        if isinstance(dbt_version, str):
            if not dbt_version.startswith(("=", "<", ">")):
                dbt_version = "==" + dbt_version
        elif isinstance(dbt_version, list):
            dbt_version = ",".join(dbt_version)
        else:
            raise Exception("dbt_version must be a string or a list of strings!")

        self.log.info('Creating temp profile "{0}"'.format(profile_name))
        profiles_yml = {
            "config": {
                "send_anonymous_usage_stats": False,
                "use_colors": False,  # colors won't be useful in logs
            },
        }
        if self.dwh_engine == EC.DWH_ENGINE_POSTGRES:
            mb_database = dwh_conn.schema
            profiles_yml[profile_name] = {
                "target": "prod",  # same as the output defined below
                "outputs": {
                    "prod": {  # for postgres
                        "type": "postgres",
                        "host": dwh_hook.local_bind_address[0],
                        "port": dwh_hook.local_bind_address[1],
                        "user": dwh_conn.login,
                        "pass": dwh_conn.password,
                        "dbname": dwh_conn.schema,
                        "schema": self.schema_name,
                        "threads": self.threads,
                        "keepalives_idle": self.keepalives_idle,
                    },
                },
            }
        elif self.dwh_engine == EC.DWH_ENGINE_SNOWFLAKE:
            mb_database = self.database_name or dwh_conn.database
            profiles_yml[profile_name] = {
                "target": "prod",  # same as the output defined below
                "outputs": {
                    "prod": {  # for snowflake
                        "type": "snowflake",
                        "account": dwh_conn.account,
                        "user": dwh_conn.user,
                        "password": dwh_conn.password,
                        "role": dwh_conn.role,
                        "database": self.database_name or dwh_conn.database,
                        "warehouse": dwh_conn.warehouse,
                        "schema": self.schema_name or dwh_conn.schema,
                        "threads": self.threads,
                        "keepalives_idle": self.keepalives_idle,
                    },
                },
            }
        elif self.dwh_engine == EC.DWH_ENGINE_BIGQUERY:
            mb_database = self.database_name
            profiles_yml[profile_name] = {
                "target": "prod",  # same as the output defined below
                "outputs": {
                    "prod": {
                        "type": "bigquery",
                        "method": "service-account-json",
                        "project": self.database_name,
                        "dataset": self.schema_name,
                        "threads": self.threads,
                        "timeout_seconds": self.keepalives_idle or 300,
                        "priority": "interactive",
                        "keyfile_json": json.loads(dwh_conn.service_account_json),
                    },
                },
            }
            if dwh_conn.location:
                profiles_yml[profile_name]["outputs"]["prod"][
                    "location"
                ] = dwh_conn.location
        else:
            raise Exception("DWH Engine not implemented!")

        # install dbt into created venv
        cmd = []
        cmd.append("source {0}/bin/activate".format(venv_folder))
        cmd.append("pip install --quiet --upgrade pip setuptools")
        if re.search(r"[^0-9\.]0(\.[0-9]+)?(\.[0-9]+)?$", dbt_version):
            # regex checks whether the (last) version starts with 0
            # if true, version <1.0.0 required
            cmd.append(
                'pip install --quiet --upgrade "MarkupSafe<=2.0.1" "dbt{0}"'.format(
                    dbt_version
                )
            )
        else:
            # Different pip behavior since dbt 1.0.0
            cmd.append(
                'pip install --quiet --upgrade "MarkupSafe<=2.0.1" "dbt-{0}{1}"'.format(
                    {
                        EC.DWH_ENGINE_POSTGRES: "postgres",
                        EC.DWH_ENGINE_SNOWFLAKE: "snowflake",
                        EC.DWH_ENGINE_BIGQUERY: "bigquery",
                    }[self.dwh_engine],
                    dbt_version,
                )
            )
        cmd.append("dbt --version")
        cmd.append("deactivate")
        assert run_cmd(cmd, env, self.log.info) == 0

        # run commands with correct profile in the venv in the temp folder
        profiles_yml_name = tmp_dir + os.path.sep + "profiles.yml"
        env["DBT_PROFILES_DIR"] = os.path.abspath(tmp_dir)
        with open(profiles_yml_name, "w") as profiles_file:
            # write profile into profiles.yml file
            yaml.dump(profiles_yml, profiles_file, default_flow_style=False)

        # run dbt commands
        self.log.info("Now running dbt commands!")
        cmd = []
        cmd.append("cd {0}".format(dbt_dir))
        cmd.append("source {0}/bin/activate".format(venv_folder))
        if self.repo_type == "local":
            cmd.append("dbt clean")
        cmd.append("dbt deps")
        [cmd.append("dbt {0}".format(dc)) for dc in self.dbt_commands]
        cmd.append("deactivate")
        assert run_cmd(cmd, env, self.log.info) == 0

        if self.metabase_conn_id:
            # Push docs to Metabase at the end of the run!
            metabase_hook = EWAHBaseHook.get_hook_from_conn_id(
                conn_id=self.metabase_conn_id
            )
            metabase_hook.push_dbt_docs_to_metabase(
                dbt_project_path=dbt_dir,
                dbt_database_name=mb_database,
            )
def __init__(
    self,
    account_ids,
    insight_fields,
    level,
    time_increment=1,
    breakdowns=None,
    execution_waittime_seconds=15,  # wait for a while before execution between
    # account_ids to avoid hitting rate limits during backfill
    pagination_limit=1000,
    async_job_read_frequency_seconds=5,
    *args,
    **kwargs
):
    if kwargs.get("update_on_columns"):
        raise Exception("update_on_columns is set by operator!")

    if not hasattr(account_ids, "__iter__"):
        raise Exception(
            "account_ids must be an iterable, such as a list,"
            + " of strings or integers!"
        )

    if level == self.levels.ad:
        kwargs["update_on_columns"] = [
            "ad_id",
            "date_start",
            "date_stop",
        ] + (breakdowns or [])
        insight_fields += ["ad_id", "ad_name"]
        insight_fields = list(set(insight_fields))
    else:
        raise Exception("Specified level not supported!")

    if not (
        (type(time_increment) == str and time_increment in ["monthly", "all_days"])
        or (
            type(time_increment) == int
            and time_increment >= 1
            and time_increment <= 90
        )
    ):
        raise Exception(
            "time_increment must either be an integer "
            + 'between 1 and 90, or a string of either "monthly" '
            + 'or "all_days". Recommended and default is the integer 1.'
        )

    allowed_insight_fields = [
        _attr[1]
        for _attr in [
            member
            for member in inspect.getmembers(
                AdsInsights.Field,
                lambda a: not (inspect.isroutine(a)),
            )
            if not (member[0].startswith("__") and member[0].endswith("__"))
        ]
    ]
    for i_f in insight_fields:
        if not i_f in allowed_insight_fields:
            raise Exception(
                (
                    "Field {0} is not an accepted value for insight_fields! "
                    + "Accepted field values:\n\t{1}\n"
                ).format(i_f, "\n\t".join(allowed_insight_fields))
            )

    super().__init__(*args, **kwargs)

    credentials = BaseHook.get_connection(self.source_conn_id)
    extra = credentials.extra_dejson

    # Note: app_secret is not always required!
    if not extra.get("app_id"):
        raise Exception('Connection extra must contain an "app_id"!')
    if not extra.get("access_token", credentials.password):
        raise Exception(
            'Connection extra must contain an "access_token" '
            + "if it is not saved as the connection password!"
        )

    self.credentials = {
        "app_id": extra.get("app_id"),
        "app_secret": extra.get("app_secret"),
        "access_token": extra.get("access_token", credentials.password),
    }

    self.account_ids = account_ids
    self.insight_fields = insight_fields
    self.level = level
    self.time_increment = time_increment
    self.breakdowns = breakdowns
    self.execution_waittime_seconds = execution_waittime_seconds
    self.pagination_limit = pagination_limit
    self.async_job_read_frequency_seconds = async_job_read_frequency_seconds
def _get_hook(self):
    conn = EWAHBaseHook.get_connection(conn_id=self.conn_id)
    if not conn.conn_type.startswith("ewah"):
        raise Exception("Must use an appropriate EWAH custom connection type!")
    return conn.get_hook()
def get_data_in_batches(self, endpoint, page_size=100, batch_size=10000, data_from=None):
    endpoint = self._ENDPOINTS.get(endpoint, endpoint)
    params = {}
    if data_from:
        assert endpoint == self._ENDPOINT_DAGRUNS
        # get DagRuns that ended since data_from
        params["end_date_gte"] = data_from.isoformat()
        params["order_by"] = "end_date"

    auth = requests.auth.HTTPBasicAuth(self.conn.login, self.conn.password)

    if self.conn.ssh_conn_id:
        ssh_hook = EWAHBaseHook.get_hook_from_conn_id(conn_id=self.conn.ssh_conn_id)
        ssh_host = self.conn.host or "localhost"
        ssh_host = ssh_host.replace("https://", "").replace("http://", "")
        ssh_port = self.conn.port
        if not ssh_port:
            if self.conn.protocol == "http":
                ssh_port = 80
            else:
                ssh_port = 443
        else:
            ssh_port = int(ssh_port)
        local_bind_address = ssh_hook.start_tunnel(ssh_host, ssh_port)
        host = "{2}://{0}:{1}".format(
            local_bind_address[0],
            str(local_bind_address[1]),
            self.conn.protocol or "http",
        )
    else:
        host = self.conn.host
        if not host.startswith("http"):
            host = self.conn.protocol + "://" + host

    url = self._BASE_URL.format(host, endpoint)

    params["limit"] = page_size
    params["offset"] = 0

    data = []
    i = 0
    while True:
        i += 1
        self.log.info("Making request {0} to {1}...".format(i, url))
        request = requests.get(url, params=params, auth=auth)
        assert request.status_code == 200, request.text
        response = request.json()
        keys = list(response.keys())
        if "total_entries" in keys:
            # Most endpoints use pagination + give "total_entries" for requests.
            # The key to get the data from the response may differ by endpoint.
            if keys[0] == "total_entries":
                data_key = keys[1]
            else:
                data_key = keys[0]
            if not response[data_key]:
                # You know that you fetched all items if an empty list is returned
                # (Note: total_entries is not reliable)
                yield data
                data = []
                break
            data += response[data_key]
            if len(data) >= batch_size:
                yield data
                data = []
        else:
            # Rare endpoint that does not paginate (usually singletons)
            yield [response]
            break
        params["offset"] = params["offset"] + params["limit"]

    if self.conn.ssh_conn_id:
        ssh_hook.stop_tunnel()
        del ssh_hook
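# Hedged usage sketch (not part of the source): get_data_in_batches is a generator
# yielding lists of records, so a caller would consume it batch by batch and hand
# each batch onward. `hook` and `uploader` are placeholder objects for illustration.
def load_dagruns_incrementally(hook, uploader, data_from):
    for batch in hook.get_data_in_batches(
        endpoint="dagruns",  # placeholder endpoint key
        page_size=100,
        batch_size=10000,
        data_from=data_from,
    ):
        uploader.upload_data(batch)  # hypothetical sink for each batch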
def __init__(
    self,
    api,  # one of _API_CORE_V3, _API_CORE_V4, _API_MULTI
    view_id,
    dimensions,
    metrics,
    page_size=10000,
    include_empty_rows=True,
    sampling_level=None,
    *args,
    **kwargs
):
    if kwargs.get("update_on_columns"):
        raise Exception(
            "update_on_columns supplied, but the field is "
            + "auto-generated by the operator!"
        )

    if not api in self._ACCEPTED_API:
        raise Exception(
            "api must be one of these: {0}".format(", ".join(self._ACCEPTED_API))
        )

    if api == self._API_MULTI:
        shorthand = "mcf:"
    else:
        shorthand = "ga:"

    dimensions = [
        ("" if dim.startswith(shorthand) else shorthand) + dim for dim in dimensions
    ]
    metrics = [
        ("" if metric.startswith(shorthand) else shorthand) + metric
        for metric in metrics
    ]

    kwargs.update({"update_on_columns": [dim[3:] for dim in dimensions]})

    self.api = api
    self.view_id = view_id
    self.sampling_level = sampling_level
    self.dimensions = dimensions
    self.metrics = metrics
    self.page_size = page_size
    self.include_empty_rows = include_empty_rows

    self.metricMap = {
        "METRIC_TYPE_UNSPECIFIED": "varchar(255)",
        "CURRENCY": "decimal(20,5)",
        "INTEGER": "int(11)",
        "FLOAT": "decimal(20,5)",
        "PERCENT": "decimal(20,5)",
        "TIME": "time",
    }

    super().__init__(*args, **kwargs)

    credentials = BaseHook.get_connection(self.source_conn_id).extra_dejson
    if not credentials.get("client_secrets"):
        _msg = "Google Analytics Credentials misspecified!"
        _msg += " Example of a correct specification: {0}".format(
            json.dumps(self._SAMPLE_JSON)
        )
        for key in self._SAMPLE_JSON["client_secrets"]:
            if not key in credentials:
                raise Exception(_msg)

    if len(dimensions) > 7:
        raise Exception(
            (
                "Can only fetch up to 7 dimensions!"
                + " Currently {0} dimensions"
            ).format(str(len(dimensions)))
        )

    if len(metrics) > 10:
        raise Exception(
            (
                "Can only fetch up to 10 metrics!"
                + " Currently {0} metrics"
            ).format(str(len(metrics)))
        )

    if self.page_size > 10000:
        raise Exception("Please specify a page size equal to or lower than 10000.")
def execute(self, context):
    # env to be used in processes later
    env = os.environ.copy()

    # create a new temp folder, all action happens in here
    with TemporaryDirectory(prefix="__ewah_dbt_operator_") as tmp_dir:
        # clone repo into temp directory
        repo_dir = tmp_dir + os.path.sep + "repo"
        if self.repo_type == "git":
            git_hook = EWAHBaseHook.get_hook_from_conn_id(conn_id=self.git_conn_id)
            git_hook.clone_repo(repo_dir, env)
        else:
            raise Exception("Not Implemented!")

        # create a virtual environment in temp folder
        venv_folder = tmp_dir + os.path.sep + "venv"
        self.log.info(
            "creating a new virtual environment in {0}...".format(venv_folder)
        )
        venv.create(venv_folder, with_pip=True)

        # install dbt into created venv
        self.log.info("installing dbt=={0}".format(self.dbt_version))
        cmd = []
        cmd.append("source {0}/bin/activate".format(venv_folder))
        cmd.append("pip install --quiet --upgrade dbt=={0}".format(self.dbt_version))
        cmd.append("dbt --version")
        cmd.append("deactivate")
        assert run_cmd(cmd, env, self.log.info) == 0

        dbt_dir = repo_dir
        if self.subfolder:
            if not self.subfolder[:1] == os.path.sep:
                self.subfolder = os.path.sep + self.subfolder
            dbt_dir += self.subfolder

        dwh_conn = EWAHBaseHook.get_connection(self.dwh_conn_id)

        # read profile name & create temporary profiles.yml
        project_yml_file = dbt_dir
        if not project_yml_file[-1:] == os.path.sep:
            project_yml_file += os.path.sep
        project_yml_file += "dbt_project.yml"
        project_yml = yaml.load(open(project_yml_file, "r"), Loader=Loader)
        profile_name = project_yml["profile"]

        self.log.info('Creating temp profile "{0}"'.format(profile_name))
        profiles_yml = {
            "config": {
                "send_anonymous_usage_stats": False,
                "use_colors": False,  # colors won't be useful in logs
            },
        }
        if self.dwh_engine == EC.DWH_ENGINE_POSTGRES:
            profiles_yml[profile_name] = {
                "target": "prod",  # same as the output defined below
                "outputs": {
                    "prod": {  # for postgres
                        "type": "postgres",
                        "host": dwh_conn.host,
                        "port": dwh_conn.port or "5432",
                        "user": dwh_conn.login,
                        "pass": dwh_conn.password,
                        "dbname": dwh_conn.schema,
                        "schema": self.schema_name,
                        "threads": self.threads,
                        "keepalives_idle": self.keepalives_idle,
                    },
                },
            }
        elif self.dwh_engine == EC.DWH_ENGINE_SNOWFLAKE:
            profiles_yml[profile_name] = {
                "target": "prod",  # same as the output defined below
                "outputs": {
                    "prod": {  # for snowflake
                        "type": "snowflake",
                        "account": dwh_conn.account,
                        "user": dwh_conn.user,
                        "password": dwh_conn.password,
                        "role": dwh_conn.role,
                        "database": self.database_name or dwh_conn.database,
                        "warehouse": dwh_conn.warehouse,
                        "schema": self.schema_name or dwh_conn.schema,
                        "threads": self.threads,
                        "keepalives_idle": self.keepalives_idle,
                    },
                },
            }
        else:
            raise Exception("DWH Engine not implemented!")

        # run commands with correct profile in the venv in the temp folder
        profiles_yml_name = tmp_dir + os.path.sep + "profiles.yml"
        env["DBT_PROFILES_DIR"] = os.path.abspath(tmp_dir)
        with open(profiles_yml_name, "w") as profiles_file:
            # write profile into profiles.yml file
            yaml.dump(profiles_yml, profiles_file, default_flow_style=False)

        # run dbt commands
        self.log.info("Now running dbt commands!")
        cmd = []
        cmd.append("cd {0}".format(dbt_dir))
        cmd.append("source {0}/bin/activate".format(venv_folder))
        cmd.append("dbt deps")
        [cmd.append("dbt {0}".format(dc)) for dc in self.dbt_commands]
        cmd.append("deactivate")
        assert run_cmd(cmd, env, self.log.info) == 0

    # if applicable: close SSH tunnel
    if hasattr(self, "ssh_tunnel_forwarder"):
        self.log.info("Stopping!")
        self.ssh_tunnel_forwarder.stop()
        del self.ssh_tunnel_forwarder
def execute_snowflake(sql, conn_id, **kwargs):
    hook = EWAHBaseHook.get_hook_from_conn_id(conn_id)
    hook.execute(sql)
    hook.close()
def start_tunnel(
    self,
    remote_host: str,
    remote_port: int,
    tunnel_timeout: Optional[int] = 30,
) -> Tuple[str, int]:
    """Starts the SSH tunnel with port forwarding. Returns the host and port
    tuple that can be used to connect to the remote.

    :param remote_host: Host of the remote that the tunnel should port forward
        to. Can be "localhost" e.g. if tunneling into a server that hosts a
        database.
    :param remote_port: Port that goes along with remote_host.
    :param tunnel_timeout: Optional timeout setting. Supply a higher number if
        the default (30s) is too low.
    :returns: Local bind address aka tuple of local_bind_host and
        local_bind_port. Calls to local_bind_host:local_bind_port will be
        forwarded to remote_host:remote_port via the SSH tunnel.
    """
    if not hasattr(self, "_ssh_tunnel_forwarder"):
        # Tunnel is not started yet - start it now!

        # Set a specific tunnel timeout if applicable
        if tunnel_timeout:
            old_timeout = sshtunnel.TUNNEL_TIMEOUT
            sshtunnel.TUNNEL_TIMEOUT = tunnel_timeout

        try:
            # Build kwargs dict for SSH Tunnel Forwarder
            if self.conn.ssh_proxy_server:
                # Use the proxy SSH server as target
                self._ssh_hook = EWAHBaseHook.get_hook_from_conn_id(
                    conn_id=self.conn.ssh_proxy_server,
                )
                kwargs = {
                    "ssh_address_or_host": self._ssh_hook.start_tunnel(
                        self.conn.host, self.conn.port or 22
                    ),
                    "remote_bind_address": (remote_host, remote_port),
                }
            else:
                kwargs = {
                    "ssh_address_or_host": (self.conn.host, self.conn.port or 22),
                    "remote_bind_address": (remote_host, remote_port),
                }
            if self.conn.username:
                kwargs["ssh_username"] = self.conn.username
            if self.conn.password:
                kwargs["ssh_password"] = self.conn.password

            # Save private key in a temporary file, if applicable
            with NamedTemporaryFile() as keyfile:
                if self.conn.private_key:
                    keyfile.write(self.conn.private_key.encode())
                    keyfile.flush()
                    kwargs["ssh_pkey"] = os.path.abspath(keyfile.name)
                self.log.info(
                    "Opening SSH Tunnel to {0}:{1}...".format(
                        *kwargs["ssh_address_or_host"]
                    )
                )
                self._ssh_tunnel_forwarder = sshtunnel.SSHTunnelForwarder(**kwargs)
                self._ssh_tunnel_forwarder.start()
        except:
            # Set package constant back to original setting, if applicable
            if tunnel_timeout:
                sshtunnel.TUNNEL_TIMEOUT = old_timeout
            raise

    return ("localhost", self._ssh_tunnel_forwarder.local_bind_port)
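# Hedged usage sketch (not part of the source): the tuple returned by start_tunnel
# is the local end of the tunnel, so a database client connects to that address
# instead of the remote host directly. `ssh_hook` is assumed to be an instance of
# the hook defining start_tunnel above; psycopg2 and the credentials are purely
# illustrative placeholders.
import psycopg2

def connect_through_tunnel(ssh_hook):
    local_host, local_port = ssh_hook.start_tunnel("localhost", 5432)
    return psycopg2.connect(
        host=local_host,        # usually "localhost"
        port=local_port,        # ephemeral local bind port chosen by sshtunnel
        user="db_user",         # placeholder credentials
        password="db_password",
        dbname="db_name",
    )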
def execute(self, context):
    hook = EWAHBaseHook.get_hook_from_conn_id(self.postgres_conn_id)
    hook.execute(sql=self.sql, params=self.parameters, commit=True)
    hook.close()  # SSH tunnel does not close if hook is not closed first
def etl_schema_tasks(
    dag,
    dwh_engine,
    dwh_conn_id,
    target_schema_name,
    target_schema_suffix="_next",
    target_database_name=None,
    read_right_users=None,  # Only for PostgreSQL
    **additional_task_args
):
    if dwh_engine == EC.DWH_ENGINE_POSTGRES:
        sql_kickoff = """
            DROP SCHEMA IF EXISTS "{schema_name}{schema_suffix}" CASCADE;
            CREATE SCHEMA "{schema_name}{schema_suffix}";
        """.format(
            schema_name=target_schema_name,
            schema_suffix=target_schema_suffix,
        )
        sql_final = """
            DROP SCHEMA IF EXISTS "{schema_name}" CASCADE;
            ALTER SCHEMA "{schema_name}{schema_suffix}" RENAME TO "{schema_name}";
        """.format(
            schema_name=target_schema_name,
            schema_suffix=target_schema_suffix,
        )

        # Don't fail final task just because a user or role that should
        # be granted read rights does not exist!
        grant_rights_sql = """
            DO $$
            BEGIN
              GRANT USAGE ON SCHEMA "{target_schema_name}" TO {user};
              GRANT SELECT ON ALL TABLES IN SCHEMA "{target_schema_name}" TO {user};
              EXCEPTION WHEN OTHERS THEN -- catches any error
                RAISE NOTICE 'not granting rights - user does not exist!';
            END
            $$;
        """
        if read_right_users:
            if not isinstance(read_right_users, list):
                raise Exception("Arg read_right_users must be of type List!")
            for user in read_right_users:
                if re.search(r"\s", user) or (";" in user):
                    _msg = "No whitespace or semicolons allowed in usernames!"
                    raise ValueError(_msg)
                sql_final += grant_rights_sql.format(
                    target_schema_name=target_schema_name,
                    user=user,
                )

        task_1_args = deepcopy(additional_task_args)
        task_2_args = deepcopy(additional_task_args)
        task_1_args.update({
            "sql": sql_kickoff,
            "task_id": "kickoff_{0}".format(target_schema_name),
            "dag": dag,
            "postgres_conn_id": dwh_conn_id,
        })
        task_2_args.update({
            "sql": sql_final,
            "task_id": "final_{0}".format(target_schema_name),
            "dag": dag,
            "postgres_conn_id": dwh_conn_id,
        })
        return (PGO(**task_1_args), PGO(**task_2_args))
    elif dwh_engine == EC.DWH_ENGINE_SNOWFLAKE:
        target_database_name = target_database_name or (
            EWAHBaseHook.get_connection(dwh_conn_id).database
        )
        sql_kickoff = """
            DROP SCHEMA IF EXISTS "{database}"."{schema_name}{schema_suffix}" CASCADE;
            CREATE SCHEMA "{database}"."{schema_name}{schema_suffix}";
        """.format(
            database=target_database_name,
            schema_name=target_schema_name,
            schema_suffix=target_schema_suffix,
        )
        sql_final = """
            DROP SCHEMA IF EXISTS "{database}"."{schema_name}" CASCADE;
            ALTER SCHEMA "{database}"."{schema_name}{schema_suffix}"
                RENAME TO "{schema_name}";
        """.format(
            database=target_database_name,
            schema_name=target_schema_name,
            schema_suffix=target_schema_suffix,
        )

        def execute_snowflake(sql, conn_id, **kwargs):
            hook = EWAHBaseHook.get_hook_from_conn_id(conn_id)
            hook.execute(sql)
            hook.close()

        task_1_args = deepcopy(additional_task_args)
        task_2_args = deepcopy(additional_task_args)
        task_1_args.update({
            "task_id": "kickoff_{0}".format(target_schema_name),
            "python_callable": execute_snowflake,
            "op_kwargs": {
                "sql": sql_kickoff,
                "conn_id": dwh_conn_id,
            },
            "provide_context": True,
            "dag": dag,
        })
        task_2_args.update({
            "task_id": "final_{0}".format(target_schema_name),
            "python_callable": execute_snowflake,
            "op_kwargs": {
                "sql": sql_final,
                "conn_id": dwh_conn_id,
            },
            "provide_context": True,
            "dag": dag,
        })
        return (PO(**task_1_args), PO(**task_2_args))
    elif dwh_engine == EC.DWH_ENGINE_GS:
        # create dummy tasks
        return (
            DO(task_id="kickoff", dag=dag),
            DO(task_id="final", dag=dag),
        )
    else:
        raise ValueError("Feature not implemented!")
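# Hedged usage sketch (not part of the source): etl_schema_tasks returns a
# (kickoff, final) task pair that implements the "load into a suffixed schema,
# then swap it in" pattern, so loading tasks are typically chained in between.
# `dag` and `load_orders_task` are hypothetical objects from the surrounding
# DAG file; the connection id and schema name are placeholders.
kickoff_task, final_task = etl_schema_tasks(
    dag=dag,
    dwh_engine=EC.DWH_ENGINE_POSTGRES,
    dwh_conn_id="dwh",               # placeholder connection id
    target_schema_name="analytics",  # placeholder schema name
)
kickoff_task >> load_orders_task >> final_task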
def execute(self, context):
    """Why this method is defined here:
    When executing a task, airflow calls this method. Generally, this
    method contains the "business logic" of the individual operator.
    However, EWAH may want to do some actions for all operators. Thus,
    the child operators shall have an ewah_execute() function which is
    called by this general execute() method.
    """

    self.log.info(
        """

        Running EWAH Operator {0}.
        DWH: {1} (connection id: {2})
        Extract Strategy: {3}
        Load Strategy: {4}

        """.format(
            str(self),
            self.dwh_engine,
            self.dwh_conn_id,
            self.extract_strategy,
            self.load_strategy,
        )
    )

    # required for metadata in data upload
    self._execution_time = datetime_utcnow_with_tz()
    self._context = context

    cleaner_callables = self.cleaner_callables or []

    if self.source_conn_id:
        # resolve conn id here & delete the object to avoid usage elsewhere
        self.source_conn = EWAHBaseHook.get_connection(self.source_conn_id)
        self.source_hook = self.source_conn.get_hook()
        if callable(getattr(self.source_hook, "get_cleaner_callables", None)):
            hook_callables = self.source_hook.get_cleaner_callables()
            if callable(hook_callables):
                cleaner_callables.append(hook_callables)
            elif hook_callables:
                # Ought to be list of callables
                cleaner_callables += hook_callables
    del self.source_conn_id

    if self._CONN_TYPE:
        assert (
            self._CONN_TYPE == self.source_conn.conn_type
        ), "Error - connection type must be {0}!".format(self._CONN_TYPE)

    uploader_callables = self.uploader_class.get_cleaner_callables()
    if callable(uploader_callables):
        cleaner_callables.append(uploader_callables)
    elif uploader_callables:
        cleaner_callables += uploader_callables

    self.uploader = self.uploader_class(
        dwh_conn=EWAHBaseHook.get_connection(self.dwh_conn_id),
        cleaner=self.cleaner_class(
            default_row=self.default_values,
            include_columns=self.include_columns,
            exclude_columns=self.exclude_columns,
            add_metadata=self.add_metadata,
            rename_columns=self.rename_columns,
            hash_columns=self.hash_columns,
            hash_salt=self.hash_salt,
            additional_callables=cleaner_callables,
        ),
        table_name=self.target_table_name,
        schema_name=self.target_schema_name,
        schema_suffix=self.target_schema_suffix,
        database_name=self.target_database_name,
        primary_key=self.primary_key,
        load_strategy=self.load_strategy,
        use_temp_pickling=self.use_temp_pickling,
        pickling_upload_chunk_size=self.pickling_upload_chunk_size,
        pickle_compression=self.pickle_compression,
        deduplication_before_upload=self.deduplication_before_upload,
        **self.additional_uploader_kwargs,
    )

    # If applicable: set the session's default time zone
    if self.default_timezone:
        self.uploader.dwh_hook.execute(
            "SET timezone TO '{0}'".format(self.default_timezone)
        )

    # Create a new copy of the target table.
    # This is so data is loaded into a new table and if data loading
    # fails, the original data is not corrupted. At a new try or re-run,
    # the original table is just copied anew.
    if not self.load_strategy == EC.LS_INSERT_REPLACE:
        # insert_replace always drops and replaces the tables completely
        self.uploader.copy_table()

    # set load_data_from and load_data_until as required
    data_from = ada(self.load_data_from)
    data_until = ada(self.load_data_until)
    if self.extract_strategy == EC.ES_INCREMENTAL:
        _tdz = timedelta(days=0)  # aka timedelta zero
        _ed = context["data_interval_start"]
        _ned = context["data_interval_end"]

        # normal incremental load
        _ed -= self.load_data_from_relative or _tdz
        data_from = min(_ed, data_from or _ed)
        if not self.test_if_target_table_exists():
            # Load data from scratch!
            data_from = ada(self.reload_data_from) or data_from

        _ned += self.load_data_until_relative or _tdz
        data_until = max(_ned, data_until or _ned)
    elif self.extract_strategy in (EC.ES_FULL_REFRESH, EC.ES_SUBSEQUENT):
        # Values may still be set as static values
        data_from = ada(self.reload_data_from) or data_from
    else:
        _msg = "Must define load_data_from etc. behavior for load strategy!"
        raise Exception(_msg)

    self.data_from = data_from
    self.data_until = data_until
    # del variables to make sure they are not used later on
    del self.load_data_from
    del self.reload_data_from
    del self.load_data_until
    del self.load_data_until_relative
    if not self.extract_strategy == EC.ES_SUBSEQUENT:
        # keep this param for subsequent loads
        del self.load_data_from_relative

    # Have an option to wait until a short period (e.g. 2 minutes) past
    # the incremental loading range timeframe to ensure that all data is
    # loaded, useful e.g. if APIs lag or if server timestamps are not
    # perfectly accurate.
    # When a DAG is executed as soon as possible, some data sources
    # may not immediately have up to date data from their API.
    # E.g. querying all data until 12.30pm only gives all relevant data
    # after 12.32pm due to some internal delays. In those cases, make
    # sure the (incremental loading) DAGs don't execute too quickly.
    if self.wait_for_seconds and self.extract_strategy == EC.ES_INCREMENTAL:
        wait_until = context.get("data_interval_end")
        if wait_until:
            wait_until += timedelta(seconds=self.wait_for_seconds)
            self.log.info("Awaiting execution until {0}...".format(str(wait_until)))
        while wait_until and datetime_utcnow_with_tz() < wait_until:
            # Only sleep a maximum of 5s at a time
            wait_for_timedelta = wait_until - datetime_utcnow_with_tz()
            time.sleep(max(0, min(wait_for_timedelta.total_seconds(), 5)))

    # execute operator
    if self.load_data_chunking_timedelta and data_from and data_until:
        # Chunking to avoid OOM
        assert data_until > data_from
        assert self.load_data_chunking_timedelta > timedelta(days=0)
        while self.data_from < data_until:
            self.data_until = min(
                self.data_from + self.load_data_chunking_timedelta, data_until
            )
            self.log.info(
                "Now loading from {0} to {1}...".format(
                    str(self.data_from), str(self.data_until)
                )
            )
            self.ewah_execute(context)
            self.data_from += self.load_data_chunking_timedelta
    else:
        self.ewah_execute(context)

    # Run final scripts
    # TODO: Include indexes into uploader and then remove this step
    self.uploader.finalize_upload()

    # if PostgreSQL and arg given: create indices
    for column in self.index_columns:
        assert self.dwh_engine == EC.DWH_ENGINE_POSTGRES
        # Use hashlib to create a unique 63 character string as index
        # name to avoid breaching index name length limits & accidental
        # duplicates / missing indices due to name truncation leading to
        # identical index names.
        self.uploader.dwh_hook.execute(
            self._INDEX_QUERY.format(
                "__ewah_"
                + hashlib.blake2b(
                    (
                        self.target_schema_name
                        + self.target_schema_suffix
                        + "."
                        + self.target_table_name
                        + "."
                        + column
                    ).encode(),
                    digest_size=28,
                ).hexdigest(),
                self.target_schema_name + self.target_schema_suffix,
                self.target_table_name,
                column,
            )
        )

    # commit only at the end, so that no data may be committed before an
    # error occurs.
    self.log.info("Now committing changes!")
    self.uploader.commit()
    self.uploader.close()
def execute_for_shop(
    self,
    context,
    shop_id,
    params,
    source_conn_id,
    auth_type,
):
    # Get data from shopify via REST API

    def add_get_transactions(data, shop, version, req_kwargs):
        # workaround to add transactions to orders
        self.log.info("Requesting transactions of orders...")
        base_url = "https://{shop}.myshopify.com/admin/api/{version}/orders/{id}/transactions.json"
        base_url = base_url.format(**{
            "shop": shop,
            "version": version,
            "id": "{id}",
        })
        for datum in data:
            id = datum["id"]
            # self.log.info('getting transactions for order {0}'.format(id))
            time.sleep(1)  # avoid hitting api call requested per second limit
            url = base_url.format(id=id)
            req = requests.get(url, **req_kwargs)
            if not req.status_code == 200:
                self.log.info("response: " + str(req.status_code))
                self.log.info("request text: " + req.text)
                raise Exception("non-200 response!")
            transactions = json.loads(req.text).get("transactions", [])
            datum["transactions"] = transactions
        return data

    def add_get_inventoryitems(data, shop, version, req_kwargs):
        # workaround to get inventory item data (i.e. costs) for products
        self.log.info("Requesting inventory items of product variants...")
        base_url = (
            "https://{shop}.myshopify.com/admin/api/{version}/inventory_items.json"
        )
        url = base_url.format(
            shop=shop,
            version=version,
        )
        kwargs = copy.deepcopy(req_kwargs)
        for datum in data:
            ids = [v["inventory_item_id"] for v in datum.get("variants", [])]
            if ids:
                kwargs["params"] = {"ids": copy.deepcopy(ids)}
                time.sleep(1)  # avoid hitting api call requested limit
                req = requests.get(url, **kwargs)
                if not req.status_code == 200:
                    self.log.info("response: " + str(req.status_code))
                    self.log.info("request text: " + req.text)
                    raise Exception("non-200 response!")
                inv_items = json.loads(req.text).get("inventory_items", [])
                datum["inventory_items"] = inv_items
        return data

    def add_get_events(data, shop, version, req_kwargs):
        # workaround to add events of an order to orders
        self.log.info("Requesting events of orders...")
        base_url = "https://{shop}.myshopify.com/admin/api/{version}/orders/{id}/events.json"
        base_url = base_url.format(
            shop=shop,
            version=version,
            id="{id}",
        )
        for datum in data:
            id = datum["id"]
            time.sleep(1)
            url = base_url.format(id=id)
            req = requests.get(url, **req_kwargs)
            if not req.status_code == 200:
                self.log.info("response: " + str(req.status_code))
                self.log.info("request text: " + req.text)
                raise Exception("non-200 response!")
            events = json.loads(req.text).get("events", [])
            datum["events"] = events
        return data

    url = self._base_url.format(**{
        "shop": shop_id,
        "version": self.api_version,
        "object": self.object_metadata.get(
            "_object_url",
            self.shopify_object,
        ),
    })

    # get connection for the applicable shop
    conn = BaseHook.get_connection(source_conn_id)
    login = conn.login
    password = conn.password

    if auth_type == "access_token":
        headers = {
            "X-Shopify-Access-Token": password,
        }
        kwargs_init = {
            "headers": headers,
            "params": params,
        }
        kwargs_links = {"headers": headers}
    elif auth_type == "basic_auth":
        kwargs_init = {
            "params": params,
            "auth": HTTPBasicAuth(login, password),
        }
        kwargs_links = {"auth": HTTPBasicAuth(login, password)}
    else:
        raise Exception("Authentication type not accepted!")

    # get and upload data
    self.log.info(
        "Requesting data from REST API - url: {0}, params: {1}".format(
            url, str(params)
        )
    )
    req_kwargs = kwargs_init
    is_first = True
    while is_first or (r.status_code == 200 and url):
        r = requests.get(url, **req_kwargs)
        if is_first:
            is_first = False
            req_kwargs = kwargs_links
        data = json.loads(r.text or "{}").get(
            self.object_metadata.get(
                "_name_in_request_data",
                self.shopify_object,
            )
        )
        if self.get_transactions_with_orders:
            data = add_get_transactions(
                data=data,
                shop=shop_id,
                version=self.api_version,
                req_kwargs=kwargs_links,
            )
        if self.get_events_with_orders:
            data = add_get_events(
                data=data,
                shop=shop_id,
                version=self.api_version,
                req_kwargs=kwargs_links,
            )
        if self.get_inventory_data_with_product_variants:
            data = add_get_inventoryitems(
                data=data,
                shop=shop_id,
                version=self.api_version,
                req_kwargs=kwargs_links,
            )
        self.upload_data(data)
        self.log.info("Requesting next page of data...")
        if r.headers.get("Link") and r.headers["Link"][-9:] == 'el="next"':
            url = r.headers["Link"][1:-13]
        else:
            url = None

    if not r.status_code == 200:
        raise Exception(
            "Shopify request returned an error {1}: {0}".format(
                r.text,
                str(r.status_code),
            )
        )