Example #1
    def load_table(
        dataframe,
        dataframe_type,
        manifest_path,
        schema_name,
        table_name,
        redshift_conn,
        num_files,
        iam_role,
        mode="append",
        preserve_index=False,
    ):
        cursor = redshift_conn.cursor()
        # Drop the existing table when overwriting; it is recreated below.
        if mode == "overwrite":
            cursor.execute("-- AWS DATA WRANGLER\n"
                           f"DROP TABLE IF EXISTS {schema_name}.{table_name}")
        # Derive the Redshift column definitions from the DataFrame and create
        # the table if it does not exist yet.
        schema = Redshift._get_redshift_schema(
            dataframe=dataframe,
            dataframe_type=dataframe_type,
            preserve_index=preserve_index,
        )
        cols_str = "".join([f"{col[0]} {col[1]},\n" for col in schema])[:-2]
        sql = (
            "-- AWS DATA WRANGLER\n"
            f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (\n{cols_str}"
            ") DISTSTYLE AUTO")
        cursor.execute(sql)
        # COPY the Parquet files listed in the manifest into the table.
        sql = ("-- AWS DATA WRANGLER\n"
               f"COPY {schema_name}.{table_name} FROM '{manifest_path}'\n"
               f"IAM_ROLE '{iam_role}'\n"
               "MANIFEST\n"
               "FORMAT AS PARQUET")
        cursor.execute(sql)
        # Check how many files the COPY committed and roll back if the count
        # does not match the expected number of files.
        cursor.execute(
            "-- AWS DATA WRANGLER\n SELECT pg_last_copy_id() AS query_id")
        query_id = cursor.fetchall()[0][0]
        sql = (
            "-- AWS DATA WRANGLER\n"
            f"SELECT COUNT(*) as num_files_loaded FROM STL_LOAD_COMMITS WHERE query = {query_id}"
        )
        cursor.execute(sql)
        num_files_loaded = cursor.fetchall()[0][0]
        if num_files_loaded != num_files:
            redshift_conn.rollback()
            cursor.close()
            raise RedshiftLoadError(
                f"Redshift load rolled back. {num_files_loaded} files counted. {num_files} expected."
            )
        redshift_conn.commit()
        cursor.close()
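
For context, here is a minimal sketch of how this first version might be called, assuming load_table is exposed as a static method of the Redshift class (as the Redshift._get_redshift_schema call above suggests) and that the Parquet files and the manifest already sit in S3. The connection, bucket, and IAM role ARN below are illustrative placeholders; any PEP 249 connection (for example one built with Redshift.generate_connection(), mentioned in the later docstrings) would work in place of psycopg2.

import pandas as pd
import psycopg2  # any PEP 249 / DB API 2.0 connection works here

df = pd.DataFrame({"id": [1, 2, 3], "value": [0.1, 0.2, 0.3]})

conn = psycopg2.connect(
    host="my-cluster.abc123.us-east-1.redshift.amazonaws.com",  # placeholder endpoint
    port=5439,
    dbname="dev",
    user="awsuser",
    password="secret",
)

Redshift.load_table(
    dataframe=df,
    dataframe_type="pandas",
    manifest_path="s3://my-bucket/staging/manifest.json",        # placeholder path
    schema_name="public",
    table_name="my_table",
    redshift_conn=conn,
    num_files=3,                                                 # files referenced by the manifest
    iam_role="arn:aws:iam::123456789012:role/my-redshift-role",  # placeholder ARN
    mode="overwrite",
)
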
Example #2
    def load_table(
        dataframe,
        dataframe_type,
        manifest_path,
        schema_name,
        table_name,
        redshift_conn,
        num_files,
        iam_role,
        diststyle="AUTO",
        distkey=None,
        sortstyle="COMPOUND",
        sortkey=None,
        mode="append",
        preserve_index=False,
    ):
        """
        Load Parquet files into a Redshift table using a manifest file.
        Creates the table if necessary.

        :param dataframe: Pandas or Spark DataFrame
        :param dataframe_type: "pandas" or "spark"
        :param manifest_path: S3 path of the manifest file (e.g. s3://...)
        :param schema_name: Redshift schema
        :param table_name: Redshift table name
        :param redshift_conn: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
        :param num_files: Number of files to be loaded
        :param iam_role: AWS IAM role with the related permissions
        :param diststyle: Redshift distribution style. Must be one of ["AUTO", "EVEN", "ALL", "KEY"] (https://docs.aws.amazon.com/redshift/latest/dg/t_Distributing_data.html)
        :param distkey: Specifies a column name or positional number for the distribution key
        :param sortstyle: Sorting can be "COMPOUND" or "INTERLEAVED" (https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html)
        :param sortkey: List of columns to be sorted
        :param mode: append or overwrite
        :param preserve_index: Should we preserve the DataFrame index? (Pandas DataFrames only)
        :return: None
        """
        cursor = redshift_conn.cursor()
        if mode == "overwrite":
            Redshift._create_table(
                cursor=cursor,
                dataframe=dataframe,
                dataframe_type=dataframe_type,
                schema_name=schema_name,
                table_name=table_name,
                diststyle=diststyle,
                distkey=distkey,
                sortstyle=sortstyle,
                sortkey=sortkey,
                preserve_index=preserve_index,
            )
        sql = ("-- AWS DATA WRANGLER\n"
               f"COPY {schema_name}.{table_name} FROM '{manifest_path}'\n"
               f"IAM_ROLE '{iam_role}'\n"
               "MANIFEST\n"
               "FORMAT AS PARQUET")
        cursor.execute(sql)
        cursor.execute(
            "-- AWS DATA WRANGLER\n SELECT pg_last_copy_id() AS query_id")
        query_id = cursor.fetchall()[0][0]
        sql = (
            "-- AWS DATA WRANGLER\n"
            f"SELECT COUNT(*) as num_files_loaded FROM STL_LOAD_COMMITS WHERE query = {query_id}"
        )
        cursor.execute(sql)
        num_files_loaded = cursor.fetchall()[0][0]
        if num_files_loaded != num_files:
            redshift_conn.rollback()
            cursor.close()
            raise RedshiftLoadError(
                f"Redshift load rollbacked. {num_files_loaded} files counted. {num_files} expected."
            )
        redshift_conn.commit()
        cursor.close()
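
All three versions hand Redshift a COPY ... MANIFEST statement, so the object at manifest_path must be a valid COPY manifest listing the staged Parquet files. The helper that writes that manifest is not part of the excerpts above, but a minimal sketch of producing one with boto3 could look like the following (bucket, keys, and sizes are placeholders; for columnar formats such as Parquet, Redshift expects each entry to carry the object size in meta.content_length).

import json
import boto3

files = [
    ("s3://my-bucket/staging/part-0.parquet", 2971),
    ("s3://my-bucket/staging/part-1.parquet", 2952),
    ("s3://my-bucket/staging/part-2.parquet", 2963),
]
manifest = {
    "entries": [
        {"url": url, "mandatory": True, "meta": {"content_length": size}}
        for url, size in files
    ]
}
boto3.client("s3").put_object(
    Bucket="my-bucket",                        # placeholder bucket
    Key="staging/manifest.json",               # this object becomes manifest_path
    Body=json.dumps(manifest).encode("utf-8"),
)

The num_files argument passed to load_table should equal len(files) (3 here); that is the count the function compares against STL_LOAD_COMMITS after the COPY.
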
Example #3
    def load_table(dataframe,
                   dataframe_type,
                   manifest_path,
                   schema_name,
                   table_name,
                   redshift_conn,
                   num_files,
                   iam_role,
                   diststyle="AUTO",
                   distkey=None,
                   sortstyle="COMPOUND",
                   sortkey=None,
                   primary_keys: Optional[List[str]] = None,
                   mode="append",
                   preserve_index=False,
                   cast_columns=None):
        """
        Load Parquet files into a Redshift table using a manifest file.
        Creates the table if necessary.

        :param dataframe: Pandas or Spark DataFrame
        :param dataframe_type: "pandas" or "spark"
        :param manifest_path: S3 path of the manifest file (e.g. s3://...)
        :param schema_name: Redshift schema
        :param table_name: Redshift table name
        :param redshift_conn: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
        :param num_files: Number of files to be loaded
        :param iam_role: AWS IAM role with the related permissions
        :param diststyle: Redshift distribution style. Must be one of ["AUTO", "EVEN", "ALL", "KEY"] (https://docs.aws.amazon.com/redshift/latest/dg/t_Distributing_data.html)
        :param distkey: Specifies a column name or positional number for the distribution key
        :param sortstyle: Sorting can be "COMPOUND" or "INTERLEAVED" (https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html)
        :param sortkey: List of columns to be sorted
        :param primary_keys: Primary keys
        :param mode: append, overwrite or upsert
        :param preserve_index: Should we preserve the DataFrame index? (Pandas DataFrames only)
        :param cast_columns: Dictionary of column names and Redshift types to cast (e.g. {"col name": "INT", "col2 name": "FLOAT"})
        :return: None
        """
        final_table_name: Optional[str] = None
        temp_table_name: Optional[str] = None
        with redshift_conn.cursor() as cursor:
            if mode == "overwrite":
                Redshift._create_table(cursor=cursor,
                                       dataframe=dataframe,
                                       dataframe_type=dataframe_type,
                                       schema_name=schema_name,
                                       table_name=table_name,
                                       diststyle=diststyle,
                                       distkey=distkey,
                                       sortstyle=sortstyle,
                                       sortkey=sortkey,
                                       primary_keys=primary_keys,
                                       preserve_index=preserve_index,
                                       cast_columns=cast_columns)
                table_name = f"{schema_name}.{table_name}"
            elif mode == "upsert":
                guid: str = pa.compat.guid()
                temp_table_name = f"temp_redshift_{guid}"
                final_table_name = table_name
                table_name = temp_table_name
                sql: str = f"CREATE TEMPORARY TABLE {temp_table_name} (LIKE {schema_name}.{final_table_name})"
                logger.debug(sql)
                cursor.execute(sql)
            else:
                table_name = f"{schema_name}.{table_name}"

            sql = ("-- AWS DATA WRANGLER\n"
                   f"COPY {table_name} FROM '{manifest_path}'\n"
                   f"IAM_ROLE '{iam_role}'\n"
                   "MANIFEST\n"
                   "FORMAT AS PARQUET")
            logger.debug(sql)
            cursor.execute(sql)
            cursor.execute("-- AWS DATA WRANGLER\n SELECT pg_last_copy_id() AS query_id")
            query_id = cursor.fetchall()[0][0]
            sql = ("-- AWS DATA WRANGLER\n"
                   f"SELECT COUNT(DISTINCT filename) as num_files_loaded "
                   f"FROM STL_LOAD_COMMITS "
                   f"WHERE query = {query_id}")
            logger.debug(sql)
            cursor.execute(sql)
            num_files_loaded = cursor.fetchall()[0][0]
            if num_files_loaded != num_files:
                redshift_conn.rollback()
                raise RedshiftLoadError(
                    f"Redshift load rollbacked. {num_files_loaded} files counted. {num_files} expected.")

            if (mode == "upsert") and (final_table_name is not None):
                if not primary_keys:
                    primary_keys = Redshift.get_primary_keys(connection=redshift_conn,
                                                             schema=schema_name,
                                                             table=final_table_name)
                if not primary_keys:
                    raise InvalidRedshiftPrimaryKeys()
                equals_clause = f"{final_table_name}.%s = {temp_table_name}.%s"
                join_clause = " AND ".join([equals_clause % (pk, pk) for pk in primary_keys])
                sql = f"DELETE FROM {schema_name}.{final_table_name} USING {temp_table_name} WHERE {join_clause}"
                logger.debug(sql)
                cursor.execute(sql)
                sql = f"INSERT INTO {schema_name}.{final_table_name} SELECT * FROM {temp_table_name}"
                logger.debug(sql)
                cursor.execute(sql)

        redshift_conn.commit()
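
The third version adds an upsert mode: the COPY lands in a temporary table, matching rows are deleted from the target by primary key, and the staged rows are then inserted. A minimal sketch of calling it that way, reusing the placeholder connection and manifest from the earlier sketches (the primary key column and ARN are illustrative):

Redshift.load_table(
    dataframe=df,
    dataframe_type="pandas",
    manifest_path="s3://my-bucket/staging/manifest.json",
    schema_name="public",
    table_name="my_table",
    redshift_conn=conn,
    num_files=3,
    iam_role="arn:aws:iam::123456789012:role/my-redshift-role",  # placeholder ARN
    primary_keys=["id"],   # rows in my_table with matching ids are replaced
    mode="upsert",
)

If primary_keys is omitted, the function falls back to Redshift.get_primary_keys() to look them up on the existing table, and raises InvalidRedshiftPrimaryKeys when none are found.
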