def load_table(
    dataframe,
    dataframe_type,
    manifest_path,
    schema_name,
    table_name,
    redshift_conn,
    num_files,
    iam_role,
    mode="append",
    preserve_index=False,
):
    cursor = redshift_conn.cursor()
    # Overwrite mode: drop any existing table before recreating it below.
    if mode == "overwrite":
        cursor.execute("-- AWS DATA WRANGLER\n"
                       f"DROP TABLE IF EXISTS {schema_name}.{table_name}")
    # Derive the (column name, Redshift type) pairs from the dataframe and
    # create the target table if it does not exist yet.
    schema = Redshift._get_redshift_schema(
        dataframe=dataframe,
        dataframe_type=dataframe_type,
        preserve_index=preserve_index,
    )
    cols_str = "".join([f"{col[0]} {col[1]},\n" for col in schema])[:-2]
    sql = (
        "-- AWS DATA WRANGLER\n"
        f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (\n{cols_str}"
        ") DISTSTYLE AUTO")
    cursor.execute(sql)
    # COPY the Parquet files listed in the manifest into the table.
    sql = ("-- AWS DATA WRANGLER\n"
           f"COPY {schema_name}.{table_name} FROM '{manifest_path}'\n"
           f"IAM_ROLE '{iam_role}'\n"
           "MANIFEST\n"
           "FORMAT AS PARQUET")
    cursor.execute(sql)
    # Check STL_LOAD_COMMITS to confirm that every expected file was committed;
    # roll back the transaction otherwise.
    cursor.execute(
        "-- AWS DATA WRANGLER\n SELECT pg_last_copy_id() AS query_id")
    query_id = cursor.fetchall()[0][0]
    sql = (
        "-- AWS DATA WRANGLER\n"
        f"SELECT COUNT(*) as num_files_loaded FROM STL_LOAD_COMMITS WHERE query = {query_id}"
    )
    cursor.execute(sql)
    num_files_loaded = cursor.fetchall()[0][0]
    if num_files_loaded != num_files:
        redshift_conn.rollback()
        cursor.close()
        raise RedshiftLoadError(
            f"Redshift load rollbacked. {num_files_loaded} files counted. {num_files} expected."
        )
    redshift_conn.commit()
    cursor.close()
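# Illustrative sketch, not part of the library: assuming _get_redshift_schema()
# yields (column name, Redshift type) pairs, the join above builds the column
# list of the CREATE TABLE statement, and the [:-2] slice strips the trailing
# ",\n". Schema, table, and column names here are hypothetical.
example_schema = [("id", "BIGINT"), ("name", "VARCHAR(256)"), ("price", "FLOAT8")]
example_cols_str = "".join([f"{col[0]} {col[1]},\n" for col in example_schema])[:-2]
# example_cols_str == "id BIGINT,\nname VARCHAR(256),\nprice FLOAT8"
# so the generated statement reads:
# CREATE TABLE IF NOT EXISTS my_schema.my_table (
# id BIGINT,
# name VARCHAR(256),
# price FLOAT8) DISTSTYLE AUTO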
def load_table(
    dataframe,
    dataframe_type,
    manifest_path,
    schema_name,
    table_name,
    redshift_conn,
    num_files,
    iam_role,
    diststyle="AUTO",
    distkey=None,
    sortstyle="COMPOUND",
    sortkey=None,
    mode="append",
    preserve_index=False,
):
    """
    Load Parquet files into a Redshift table using a manifest file.
    Creates the table if necessary.

    :param dataframe: Pandas or Spark Dataframe
    :param dataframe_type: "pandas" or "spark"
    :param manifest_path: S3 path for manifest file (E.g. S3://...)
    :param schema_name: Redshift schema
    :param table_name: Redshift table name
    :param redshift_conn: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
    :param num_files: Number of files to be loaded
    :param iam_role: AWS IAM role with the related permissions
    :param diststyle: Redshift distribution styles. Must be in ["AUTO", "EVEN", "ALL", "KEY"]
           (https://docs.aws.amazon.com/redshift/latest/dg/t_Distributing_data.html)
    :param distkey: Specifies a column name or positional number for the distribution key
    :param sortstyle: Sorting can be "COMPOUND" or "INTERLEAVED"
           (https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html)
    :param sortkey: List of columns to be sorted
    :param mode: append or overwrite
    :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
    :return: None
    """
    cursor = redshift_conn.cursor()
    if mode == "overwrite":
        Redshift._create_table(
            cursor=cursor,
            dataframe=dataframe,
            dataframe_type=dataframe_type,
            schema_name=schema_name,
            table_name=table_name,
            diststyle=diststyle,
            distkey=distkey,
            sortstyle=sortstyle,
            sortkey=sortkey,
            preserve_index=preserve_index,
        )
    sql = ("-- AWS DATA WRANGLER\n"
           f"COPY {schema_name}.{table_name} FROM '{manifest_path}'\n"
           f"IAM_ROLE '{iam_role}'\n"
           "MANIFEST\n"
           "FORMAT AS PARQUET")
    cursor.execute(sql)
    cursor.execute(
        "-- AWS DATA WRANGLER\n SELECT pg_last_copy_id() AS query_id")
    query_id = cursor.fetchall()[0][0]
    sql = (
        "-- AWS DATA WRANGLER\n"
        f"SELECT COUNT(*) as num_files_loaded FROM STL_LOAD_COMMITS WHERE query = {query_id}"
    )
    cursor.execute(sql)
    num_files_loaded = cursor.fetchall()[0][0]
    if num_files_loaded != num_files:
        redshift_conn.rollback()
        cursor.close()
        raise RedshiftLoadError(
            f"Redshift load rollbacked. {num_files_loaded} files counted. {num_files} expected."
        )
    redshift_conn.commit()
    cursor.close()
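# Illustrative usage sketch, not from the library: the DataFrame, connection, S3
# paths, and IAM role below are hypothetical placeholders. The docstring above notes
# the connection can come from Redshift.generate_connection(); any PEP 249
# connection works. num_files must equal the number of Parquet files listed in the
# manifest, otherwise the load is rolled back.
load_table(
    dataframe=sales_df,                                     # hypothetical Pandas DataFrame
    dataframe_type="pandas",
    manifest_path="s3://my-bucket/staging/manifest.json",   # hypothetical manifest path
    schema_name="public",
    table_name="sales",
    redshift_conn=con,                                      # hypothetical PEP 249 connection
    num_files=4,                                            # files listed in the manifest
    iam_role="arn:aws:iam::123456789012:role/my-copy-role", # hypothetical IAM role
    diststyle="KEY",
    distkey="customer_id",
    sortstyle="COMPOUND",
    sortkey=["order_date"],
    mode="overwrite",
)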
def load_table(dataframe,
               dataframe_type,
               manifest_path,
               schema_name,
               table_name,
               redshift_conn,
               num_files,
               iam_role,
               diststyle="AUTO",
               distkey=None,
               sortstyle="COMPOUND",
               sortkey=None,
               primary_keys: Optional[List[str]] = None,
               mode="append",
               preserve_index=False,
               cast_columns=None):
    """
    Load Parquet files into a Redshift table using a manifest file.
    Creates the table if necessary.

    :param dataframe: Pandas or Spark Dataframe
    :param dataframe_type: "pandas" or "spark"
    :param manifest_path: S3 path for manifest file (E.g. S3://...)
    :param schema_name: Redshift schema
    :param table_name: Redshift table name
    :param redshift_conn: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
    :param num_files: Number of files to be loaded
    :param iam_role: AWS IAM role with the related permissions
    :param diststyle: Redshift distribution styles. Must be in ["AUTO", "EVEN", "ALL", "KEY"]
           (https://docs.aws.amazon.com/redshift/latest/dg/t_Distributing_data.html)
    :param distkey: Specifies a column name or positional number for the distribution key
    :param sortstyle: Sorting can be "COMPOUND" or "INTERLEAVED"
           (https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html)
    :param sortkey: List of columns to be sorted
    :param primary_keys: Primary keys
    :param mode: append, overwrite or upsert
    :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
    :param cast_columns: Dictionary of column names and Redshift types to cast.
           (E.g. {"col name": "INT", "col2 name": "FLOAT"})
    :return: None
    """
    final_table_name: Optional[str] = None
    temp_table_name: Optional[str] = None
    with redshift_conn.cursor() as cursor:
        if mode == "overwrite":
            Redshift._create_table(cursor=cursor,
                                   dataframe=dataframe,
                                   dataframe_type=dataframe_type,
                                   schema_name=schema_name,
                                   table_name=table_name,
                                   diststyle=diststyle,
                                   distkey=distkey,
                                   sortstyle=sortstyle,
                                   sortkey=sortkey,
                                   primary_keys=primary_keys,
                                   preserve_index=preserve_index,
                                   cast_columns=cast_columns)
            table_name = f"{schema_name}.{table_name}"
        elif mode == "upsert":
            # Upsert: COPY into a session-scoped temporary table, then merge it
            # into the target table further below.
            guid: str = pa.compat.guid()
            temp_table_name = f"temp_redshift_{guid}"
            final_table_name = table_name
            table_name = temp_table_name
            sql: str = f"CREATE TEMPORARY TABLE {temp_table_name} (LIKE {schema_name}.{final_table_name})"
            logger.debug(sql)
            cursor.execute(sql)
        else:
            table_name = f"{schema_name}.{table_name}"
        sql = ("-- AWS DATA WRANGLER\n"
               f"COPY {table_name} FROM '{manifest_path}'\n"
               f"IAM_ROLE '{iam_role}'\n"
               "MANIFEST\n"
               "FORMAT AS PARQUET")
        logger.debug(sql)
        cursor.execute(sql)
        cursor.execute("-- AWS DATA WRANGLER\n SELECT pg_last_copy_id() AS query_id")
        query_id = cursor.fetchall()[0][0]
        sql = ("-- AWS DATA WRANGLER\n"
               f"SELECT COUNT(DISTINCT filename) as num_files_loaded "
               f"FROM STL_LOAD_COMMITS "
               f"WHERE query = {query_id}")
        logger.debug(sql)
        cursor.execute(sql)
        num_files_loaded = cursor.fetchall()[0][0]
        if num_files_loaded != num_files:
            redshift_conn.rollback()
            raise RedshiftLoadError(
                f"Redshift load rollbacked. {num_files_loaded} files counted. {num_files} expected.")
        if (mode == "upsert") and (final_table_name is not None):
            if not primary_keys:
                primary_keys = Redshift.get_primary_keys(connection=redshift_conn,
                                                         schema=schema_name,
                                                         table=final_table_name)
            if not primary_keys:
                raise InvalidRedshiftPrimaryKeys()
            # Merge: delete target rows that match the staged rows on the primary
            # keys, then insert everything from the temporary table.
            equals_clause = f"{final_table_name}.%s = {temp_table_name}.%s"
            join_clause = " AND ".join([equals_clause % (pk, pk) for pk in primary_keys])
            sql = f"DELETE FROM {schema_name}.{final_table_name} USING {temp_table_name} WHERE {join_clause}"
            logger.debug(sql)
            cursor.execute(sql)
            sql = f"INSERT INTO {schema_name}.{final_table_name} SELECT * FROM {temp_table_name}"
            logger.debug(sql)
            cursor.execute(sql)
        redshift_conn.commit()