def execute(self, context):
    """Run a Snowflake ``COPY INTO`` statement loading ``self.s3_keys`` from a stage.

    Builds the COPY statement (optionally with an explicit column list when
    ``self.columns_array`` is set) and executes it via a SnowflakeHook.

    :param context: Airflow task context (unused here).
    """
    snowflake_hook = SnowflakeHook(
        snowflake_conn_id=self.snowflake_conn_id)
    # Snowflake won't accept a list of files, it has to be a tuple-style
    # parenthesised list. But in Python tuple([1]) == (1,), whose trailing
    # comma is invalid for Snowflake — so build the literal by rewriting the
    # list repr's brackets instead.
    files = str(self.s3_keys).replace('[', '(').replace(']', ')')
    # we can extend this based on stage
    base_sql = """
    FROM @{stage}/
    files={files}
    file_format={file_format}
    """.format(stage=self.stage, files=files, file_format=self.file_format)
    if self.columns_array:
        copy_query = """
        COPY INTO {schema}.{table}({columns}) {base_sql}
        """.format(schema=self.schema, table=self.table,
                   columns=",".join(self.columns_array), base_sql=base_sql)
    else:
        copy_query = """
        COPY INTO {schema}.{table} {base_sql}
        """.format(schema=self.schema, table=self.table, base_sql=base_sql)
    self.log.info('Executing COPY command...')
    # Fix: removed a leftover ``print(snowflake_hook.get_uri())`` debug line —
    # it dumped the connection URI (which can embed credentials) to stdout.
    snowflake_hook.run(copy_query, self.autocommit)
    self.log.info("COPY command completed")
def _log_snowflake_resources(
    query_text,
    snowflake_conn_id,
    session_id=None,
    warehouse=None,
    database=None,
    role=None,
    schema=None,
    account=None,
):
    """Record Snowflake resource usage for *query_text* via the tracking backend.

    Builds a hook from the given connection settings and forwards the
    connection string, user and database to ``log_snowflake_resource_usage``.
    """
    snowflake_hook = SnowflakeHook(
        snowflake_conn_id,
        warehouse=warehouse,
        account=account,
        database=database,
        role=role,
        schema=schema,
    )
    uri = snowflake_hook.get_uri()
    # NOTE(review): relies on the hook's private _get_conn_params for the user.
    params = snowflake_hook._get_conn_params()
    log_snowflake_resource_usage(
        query_text,
        database=snowflake_hook.database,
        user=params["user"],
        connection_string=uri,
        session_id=session_id,
    )
def get_table_row_count(**context):
    """Run the configured offset-start query and print its single result."""
    offset_query = MEDICARE_PART_D_PROVIDER_OFF_SET_START
    print("Counting TABLE MAX Row ...")
    hook = SnowflakeHook(snowflake_conn_id="kyu_snowflake_conn",
                         warehouse="COMPUTE_WH")
    offset_start = hook.get_first(offset_query)[0]
    print("OFF_SET_START: " + str(offset_start))
def insert_row(timestamp, **context):
    """Insert a row into ``public.test_table_time`` tagged with *timestamp*.

    The Airflow context is serialised to JSON and stored via ``PARSE_JSON``.

    :param timestamp: timestamp string to record alongside the row.
    :param context: Airflow task context, serialised into the row.
    """
    import json

    pprint(context)
    dwh_hook = SnowflakeHook(snowflake_conn_id="snowflake_conn")
    # Fix: the original concatenated ``context`` (a dict) straight into the
    # SQL string — that raises TypeError and is SQL-injection-prone. Use a
    # parameterised query and serialise the context explicitly.
    result = dwh_hook.get_first(
        "insert into public.test_table_time select 'INSERT', %s, parse_json(%s)",
        parameters=(timestamp, json.dumps(context, default=str)),
    )
    # Fix: '$s' -> '%s' so the result is actually interpolated into the log.
    logging.info("Added row with timestamp : %s, and result is %s",
                 timestamp, result)
def get_table_max_incremental_number(**context):
    """Fetch and print the current MAX incremental number (NPI) from the table.

    :param context: Airflow task context (unused here).
    """
    max_npi_query = MEDICARE_PART_D_PROVIDER_TABLE_MAX_NPI
    print("Get Table MAX increment number...")
    dwh_hook = SnowflakeHook(snowflake_conn_id="kyu_snowflake_conn",
                             warehouse="COMPUTE_WH")
    # Fix: the original executed the same query twice (run() followed by
    # get_first()); one get_first() call is enough for a read-only MAX query.
    table_max_npi = dwh_hook.get_first(max_npi_query)[0]
    print("MAX Table increment number: " + str(table_max_npi))
def execute(self, context):
    """Run the query under a tracker and push its session/query ids to XCom.

    NOTE(review): reads ``snowflake_conn_id``, ``database`` and
    ``select_query`` from module scope, not from ``self`` — verify intended.
    """
    hook = SnowflakeHook(snowflake_conn_id=snowflake_conn_id)
    with snowflake_query_tracker(database=database) as tracker:
        hook.run(select_query)
        session_id, query_id = tracker.get_last_session_with_query_id(
            many=False)
        task_instance = context["ti"]
        task_instance.xcom_push(key="session_id", value=session_id)
        task_instance.xcom_push(key="query_id", value=query_id)
def execute(self, context):
    """Run the configured data-quality query and fail when it returns nothing.

    :param context: Airflow task context (unused here).
    :raises ValueError: when the query yields no rows or an empty first row.
    """
    self.log.info(f'Checking DataQuality for {self._table}')
    snowflake_hook = SnowflakeHook(self._conn_id)
    self.log.info(f'Running query: {self._query}')
    records = snowflake_hook.get_records(self._query)
    # Fix: was ``len(records[0] < 1)`` — the misplaced paren compares a row
    # to an int (TypeError on Python 3) instead of checking the row length.
    if len(records) < 1 or len(records[0]) < 1:
        raise ValueError(
            f"Data quality check failed. {self._table} returned no results"
        )
    # Use the task logger consistently (was a bare ``logging.info``).
    self.log.info(f'Passed Data Quality Check for {self._table}')
def execute(self, context):
    """Fail unless the configured table contains at least one row.

    :param context: Airflow task context (unused here).
    :raises ValueError: when the COUNT query yields no usable result.
    """
    self.log.info(f'Checking HasRows for {self._table}')
    snowflake_hook = SnowflakeHook(self._conn_id)
    # Fixed typo in the log message ('rcord' -> 'record').
    self.log.info(f'Running record count: {self._table}')
    records = snowflake_hook.get_records(
        f"SELECT COUNT(1) FROM {self._table}")
    # Fix: was ``len(records[0] < 1)`` — the misplaced paren compares a row
    # to an int (TypeError on Python 3) instead of checking the row length.
    if len(records) < 1 or len(records[0]) < 1:
        raise ValueError(
            f"Data quality check failed. {self._table} returned no results"
        )
    # Use the task logger consistently (was a bare ``logging.info``).
    self.log.info(f'Passed Has Rows Check for {self._table}')
def fetch_data_from_snowflake():
    """Dump Snowflake roles/grants and users/role-memberships to flat files.

    Queries the audit database as ACCOUNTADMIN, collects every role with its
    grants and every user with its role memberships, then writes the four
    output files (roles, role grants, users, user roles).
    """
    hook = SnowflakeHook("snowflake_conn")
    connection = hook.get_conn()
    roles = []
    users = []
    with connection.cursor() as cur:
        cur.execute("USE DATABASE MY_AUDIT_DB;")
        cur.execute("USE ROLE ACCOUNTADMIN;")
        # Roles and the grants attached to each role.
        cur.execute("SHOW roles")
        for rec in cur.fetchall():
            roles.append(Role(rec[1], rec[9]))
        for role in roles:
            cur.execute("SHOW GRANTS TO ROLE " + role.name)
            for grant_row in cur.fetchall():
                role.add_grant(
                    RoleGrant(grant_row[1], grant_row[2], grant_row[3]), roles
                )
        # Users, paged 1000 rows at a time, then each user's roles.
        cur.execute("SHOW users")
        batch = cur.fetchmany(1000)
        while len(batch) > 0:
            for user_row in batch:
                users.append(User(user_row))
            batch = cur.fetchmany(1000)
        for user in users:
            cur.execute("SHOW GRANTS TO USER " + user.user_name)
            user.get_roles(cur.fetchall())
    with open(ROLES_PATH, "w") as roles_file, open(
        ROLE_GRANTS_PATH, "w"
    ) as role_grants_file:
        for role in roles:
            role.write_roles(roles_file)
            role.write_grants(role.name, "ROOT", role_grants_file)
    with open(USERS_PATH, "w") as users_file, open(
        USER_ROLES_PATH, "w"
    ) as user_roles_file:
        for user in users:
            user.write_user_record(users_file)
            user.write_roles(user_roles_file)
def snowflake_run(
    hook: SnowflakeHook,
    sql: Union[str, List[str]],
    autocommit: bool = False,
    parameters: Union[Mapping, Iterable, None] = None,
) -> Tuple[int, List[str]]:
    """
    Execute one or more SQL statements, returning session_id and query_id(s).

    Pass a list of statements to have them executed sequentially on the same
    connection/cursor.

    :param hook: Configured SnowflakeHook with Snowflake connection details.
    :param sql: the sql statement to be executed (str) or a list of
        sql statements to execute
    :param autocommit: What to set the connection's autocommit setting to
        before executing the query.
    :param parameters: The parameters to render the SQL query with.
    :return: session_id and query_id(s) assigned by Snowflake to each
        submitted query
    """
    if isinstance(sql, six.string_types):
        sql = [sql]
    with closing(hook.get_conn()) as conn:
        session_id = conn.session_id
        query_ids = []
        if hook.supports_autocommit:
            hook.set_autocommit(conn, autocommit)
        with closing(conn.cursor()) as cur:
            for statement in sql:
                if sys.version_info[0] < 3:
                    statement = statement.encode("utf-8")
                if parameters is None:
                    hook.log.info(statement)
                    result = cur.execute(statement)
                else:
                    hook.log.info(
                        "{} with parameters {}".format(statement, parameters))
                    result = cur.execute(statement, parameters)
                query_ids.append(result.sfqid)
        # If autocommit was set to False for db that supports autocommit,
        # or if db does not supports autocommit, we do a manual commit.
        if not hook.get_autocommit(conn):
            conn.commit()
    logger.info(
        "Executed queries '{}', got session_id {}, query_id {}".format(
            sql, session_id, query_ids
        )
    )
    return (session_id, query_ids)
def snowflake_get_first(
    hook: SnowflakeHook, sql: str, parameters: Union[Mapping, Iterable, None] = None
) -> Tuple[Any, int, str]:
    """
    Execute *sql* and return its first row together with the session_id and
    query_id generated by Snowflake.

    :param hook: Configured SnowflakeHook with Snowflake connection details.
    :param sql: the sql statement to be executed (str) or a list of
        sql statements to execute
    :param parameters: The parameters to render the SQL query with.
    :return: Query result alongside with session_id and query_id assigned
        by Snowflake to each submitted query
    """
    if sys.version_info[0] < 3:
        sql = sql.encode("utf-8")
    with closing(hook.get_conn()) as conn:
        session_id = conn.session_id
        with closing(conn.cursor()) as cur:
            if parameters is None:
                cur.execute(sql)
            else:
                cur.execute(sql, parameters)
            query_id = cur.sfqid
            return cur.fetchone(), session_id, query_id
def qa_checks(**context):
    """Run the QA SQL file and open a GitHub issue for every failing check.

    Skips entirely when the QA file is missing, GitHub credentials are not
    configured, or we are running in CI.
    """
    ready = (
        os.path.exists(qa_file)
        and GIT_USER is not None
        and GIT_TOKEN is not None
        and ENVIRONMENT != 'CI'
    )
    if not ready:
        return
    with open(qa_file, 'r') as fd:
        statements = fd.read().split(";")
    sf_hook = SnowflakeHook(snowflake_conn_id=Variable.get(
        "SNOWFLAKE_CONNECTION", default_var="SNOWFLAKE"))
    for statement in statements:
        # Ignore empty fragments / trailing whitespace between semicolons.
        if len(statement.strip()) <= 5:
            continue
        failures = sf_hook.get_pandas_df(statement)
        if len(failures.index) == 0:
            continue
        for _, row in failures.iterrows():
            make_github_issue(
                'QA Failed for ' + row['TABLE_NAME'],
                "Error: " + row['ERROR_DESC'] + "\n"
                + "Error Count: " + str(row['ERROR_COUNT']) + "\n"
                + row['ERROR_CONDITION'],
                ['bug', 'qa'],
            )
def get_hook(self):
    """Build a SnowflakeHook from this operator's connection settings."""
    return SnowflakeHook(
        snowflake_conn_id=self.snowflake_conn_id,
        warehouse=self.warehouse,
        database=self.database,
        role=self.role,
        schema=self.schema,
        authenticator=self.authenticator,
    )
def execute(self, context):
    """Run the query, shipping results to S3 when an S3 connection is set.

    A string ``self.s3_conn_id`` selects the results-to-S3 path; anything
    else runs the query without collecting results.
    """
    logging.info('Snowflake Query Operator Starting')
    hook = SnowflakeHook(snowflake_conn_id=self.snowflake_conn_id)
    if not isinstance(self.s3_conn_id, str):
        self.execute_no_results(hook=hook)
    else:
        self.execute_results_to_s3(hook=hook)
    logging.info('Snowflake Query Operator Complete')
def _log_snowflake_table(
    table,
    snowflake_conn_id,
    warehouse=None,
    database=None,
    role=None,
    schema=None,
    account=None,
):
    """Log *table* metadata to the tracking backend and return its result."""
    sf_hook = SnowflakeHook(
        snowflake_conn_id=snowflake_conn_id,
        warehouse=warehouse,
        account=account,
        database=database,
        role=role,
        schema=schema,
    )
    return log_snowflake_table(table, sf_hook.get_uri(), database, schema)
def pivot_data(**kwargs):
    """Pivot Snowflake state data to wide format and upload the CSV to S3."""
    # Connect to Snowflake and pull the raw long-format data.
    conn = SnowflakeHook(snowflake_conn_id='snowflake').get_conn()
    frame = pd.read_sql('SELECT DATE, STATE, POSITIVE FROM STATE_DATA;', conn)
    # One row per DATE, one column per STATE.
    wide = frame.pivot(index='DATE', columns='STATE',
                       values='POSITIVE').reset_index()
    # Ship the pivoted frame to S3 as CSV.
    S3Hook(aws_conn_id=S3_CONN_ID).load_string(
        wide.to_csv(index=False),
        '{0}.csv'.format(filename),
        bucket_name=BUCKET,
        replace=True,
    )
def snowflake_db_monitor(**op_kwarg):
    """Log per-table shape metrics (rows x columns) for the monitored schema.

    Collects column counts per table from ``GET_COLUMNS`` and row counts from
    ``GET_DB_ROW_INFO``, then emits min/max/mean/median metrics plus one
    ``(columns, rows)`` shape metric per table.
    """
    snowflake_hook = SnowflakeHook(snowflake_conn_id="test_snowflake_conn")
    with snowflake_query_tracker(database=DATABASE, schema=SCHEMA) as st:
        snowflake_tables = snowflake_hook.get_pandas_df(GET_COLUMNS)
        snowflake_shapes = DataFrame()
        snowflake_tables = snowflake_tables[
            snowflake_tables["schema_name"] == "{}".format(SCHEMA)]
        snowflake_shapes["column_count"] = snowflake_tables.groupby(
            "table_name").nunique("column_name")["column_name"]
        snowflake_shapes["table_name"] = snowflake_tables["table_name"].unique()
        table_row_info = {}
        snowflake_rows = snowflake_hook.get_records(GET_DB_ROW_INFO)
        for tablename, row_count in snowflake_rows:
            table_row_info[tablename] = row_count
        row_counts = list(table_row_info.values())
        log_metric("Max table row count", max(row_counts))
        log_metric("Min table row count", min(row_counts))
        log_metric("Mean table row count", round(mean(row_counts), 2))
        log_metric("Median table row count", median(row_counts))
        # Tables with no row-count record default to 0 rows.
        snowflake_shapes["row_count"] = (snowflake_shapes["table_name"].map(
            table_row_info).fillna(0).astype(int))
        for _, row in snowflake_shapes.iterrows():
            log_metric(
                "{} shape".format(row["table_name"]),
                (row["column_count"], row["row_count"]),
            )
        log_metric("Max table column count",
                   snowflake_shapes["column_count"].max())
        # Fix: this metric logged .max() instead of .min() (copy-paste bug),
        # so "Min table column count" always equalled the max.
        log_metric("Min table column count",
                   snowflake_shapes["column_count"].min())
        log_metric("Mean table column count",
                   round(snowflake_shapes["column_count"].mean(), 2))
        log_metric("Median table column count",
                   snowflake_shapes["column_count"].median())
def execute(self, context):
    """Run the query, then log table metadata and resource usage.

    NOTE(review): reads ``snowflake_conn_id``, ``select_query``, ``table``,
    ``database`` and ``schema`` from module scope, not ``self``.
    """
    hook = SnowflakeHook(snowflake_conn_id=snowflake_conn_id)
    session_id, query_ids = snowflake_run(hook, select_query)
    uri = hook.get_uri()
    log_snowflake_table(
        table_name=table,
        connection_string=uri,
        database=database,
        schema=schema,
        key=f"example1.{table}",
        with_preview=False,
        raise_on_error=False,
    )
    log_snowflake_resource_usage(
        database=database,
        key=f"example1.{session_id}{query_ids[0]}",
        connection_string=uri,
        query_ids=query_ids,
        session_id=int(session_id),
        raise_on_error=True,
    )
def get_snowflake_connection_dict_from_airflow(conn_id):
    """Return a dictionary usable with `snowflake.connector.connect`.

    Reads login/password/schema from the Airflow connection and the
    account/warehouse/database from its JSON extras; missing values become
    empty strings.
    """
    from airflow.contrib.hooks.snowflake_hook import SnowflakeHook

    conn = SnowflakeHook.get_connection(conn_id)
    extras = conn.extra_dejson
    return {
        'user': conn.login,
        'password': conn.password or '',
        'schema': conn.schema or '',
        'database': extras.get('database', None) or '',
        'account': extras.get('account', None) or '',
        'warehouse': extras.get('warehouse', None) or '',
    }
def execute(self, context):
    """
    Executes one or more SQL statements passed a string.

    Uses the Snowflake Connection.execute_string, which accepts multiple
    statements separated by semi-colons (newlines and comments allowed) and
    by default returns a sequence of Cursor objects in execution order.
    ``self.sql`` has already been rendered by the Jinja template engine by
    the time it gets here.

    See https://docs.snowflake.net/manuals/user-guide/python-connector-api.html#id1
    """
    logging.info('Starting SnowflakeSQL.execute, snowflake_conn_id=' +
                 self.snowflake_conn_id)
    try:
        conn = SnowflakeHook().get_conn(conn_name=self.snowflake_conn_id)
        conn.execute_string(self.sql)
    finally:
        logging.info('Finished SnowflakeSQL.execute')
def create_or_replace_view(**context):
    """Create or replace the Medicare Part D provider reporting view."""
    print("creating or replacing existing view.")
    hook = SnowflakeHook(snowflake_conn_id="kyu_snowflake_conn",
                         warehouse="COMPUTE_WH")
    hook.run(MEDICARE_PART_D_PROVIDER_LOAD_CREATE_VIEW, autocommit=True)
def load_stage_to_master(**context):
    """Move staged Medicare Part D provider rows into the master table."""
    print("Loading stage to master...")
    hook = SnowflakeHook(snowflake_conn_id="kyu_snowflake_conn",
                         warehouse="COMPUTE_WH")
    hook.run(MEDICARE_PART_D_PROVIDER_LOAD_MASTER_TABLE, autocommit=True)
def stage_file(**context):
    """Copy the S3 source file into the Snowflake staging table."""
    print("Loading S3 file into staging table")
    hook = SnowflakeHook(snowflake_conn_id="kyu_snowflake_conn",
                         warehouse="COMPUTE_WH")
    hook.run(MEDICARE_PART_D_PROVIDER_COPY_TO_STAGE, autocommit=True)
def truncate_stage_table(**context):
    """Empty the staging table ahead of the next load."""
    print("Truncate staging table")
    hook = SnowflakeHook(snowflake_conn_id="kyu_snowflake_conn",
                         warehouse="COMPUTE_WH")
    hook.run(MEDICARE_PART_D_PROVIDER_TRUNCATE_STAGE, autocommit=True)
def get_conn():
    """Open and return a live Snowflake connection for the module's conn_id."""
    return SnowflakeHook(snowflake_conn_id=conn_id).get_conn()
def row_count(**context):
    """Log how many rows ``public.tapas`` currently holds."""
    hook = SnowflakeHook(snowflake_conn_id="snowflake_all")
    first_row = hook.get_first("select count(*) from public.tapas")
    logging.info("Number of rows in `public.tapas` - %s", first_row[0])
def get_hook(self):
    """Build a SnowflakeHook for this operator's conn/warehouse/database."""
    return SnowflakeHook(
        snowflake_conn_id=self.snowflake_conn_id,
        warehouse=self.warehouse,
        database=self.database,
    )
def update_customers(**kwargs):
    """Apply the module-level ``update_query`` against Snowflake."""
    SnowflakeHook(snowflake_conn_id=SNOWFLAKE_CONNECTION_ID).run(update_query)
def process_customers(**kwargs):
    """Fetch customer rows from Snowflake and hand them to process_records."""
    hook = SnowflakeHook(snowflake_conn_id=SNOWFLAKE_CONNECTION_ID)
    customers = hook.get_records(select_query)
    # Process records
    process_records(customers)
def get_row_count(**context):
    """Log the row count of ``private.<table_name>``."""
    hook = SnowflakeHook(snowflake_conn_id="snowflake_common")
    first_row = hook.get_first(
        "select count(*) from private.{}".format(table_name))
    logging.info("Number of rows in `private.%s` - %s",
                 table_name, first_row[0])