예제 #1
0
    def to_s3(self,
              sql: str,
              path: str,
              connection: Any,
              engine: str = "mysql") -> str:
        """
        Write a query result on S3.

        :param sql: SQL Query
        :param path: AWS S3 path to write the data (e.g. s3://...)
        :param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
        :param engine: Only "mysql" by now
        :return: Manifest S3 path
        """
        if "mysql" not in engine.lower():
            raise InvalidEngine(
                f"{engine} is not a valid engine. Please use 'mysql'!")
        # Strip a trailing slash so the ".manifest" suffix attaches cleanly.
        # BUG FIX: the original assigned path[-1] (the last character) instead
        # of path[:-1], destroying the whole path whenever it ended with "/".
        path = path[:-1] if path[-1] == "/" else path
        self._session.s3.delete_objects(path=path)
        # Aurora MySQL "SELECT ... INTO OUTFILE S3" with MANIFEST ON writes
        # the data files plus a <path>.manifest listing them.
        sql = f"{sql}\n" \
              f"INTO OUTFILE S3 '{path}'\n" \
              "FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' ESCAPED BY '\\\\'\n" \
              "LINES TERMINATED BY '\\n'\n" \
              "MANIFEST ON\n" \
              "OVERWRITE ON"
        with connection.cursor() as cursor:
            logger.debug(sql)
            cursor.execute(sql)
        connection.commit()
        return path + ".manifest"
예제 #2
0
 def _get_schema(dataframe,
                 dataframe_type: str,
                 preserve_index: bool,
                 engine: str = "mysql") -> List[Tuple[str, str]]:
     """Build the (column name, Aurora type) schema for a Dataframe.

     Raises InvalidEngine for engines other than mysql/postgres and
     InvalidDataframeType for Dataframe types other than pandas.
     """
     engine_lower = engine.lower()
     if "postgres" in engine_lower:
         converter = data_types.pyarrow2postgres
     elif "mysql" in engine_lower:
         converter = data_types.pyarrow2mysql
     else:
         raise InvalidEngine(
             f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!"
         )
     if dataframe_type.lower() != "pandas":
         raise InvalidDataframeType(
             f"{dataframe_type} is not a valid DataFrame type. Please use 'pandas'!"
         )
     pyarrow_schema: List[Tuple[str, str]] = \
         data_types.extract_pyarrow_schema_from_pandas(
             dataframe=dataframe,
             preserve_index=preserve_index,
             indexes_position="right")
     return [(name, converter(dtype)) for name, dtype in pyarrow_schema]
예제 #3
0
 def _validate_connection(
         database: str,
         host: str,
         port: Union[str, int],
         user: str,
         password: str,
         engine: str = "mysql",
         tcp_keepalive: bool = True,
         application_name: str = "aws-data-wrangler-validation",
         validation_timeout: int = 10) -> None:
     """Open a short-lived probe connection and close it immediately,
     failing fast on a bad host, bad credentials, or an invalid engine."""
     engine_lower = engine.lower()
     if "postgres" in engine_lower:
         probe = pg8000.connect(database=database,
                                host=host,
                                port=int(port),
                                user=user,
                                password=password,
                                ssl=True,
                                application_name=application_name,
                                tcp_keepalive=tcp_keepalive,
                                timeout=validation_timeout)
     elif "mysql" in engine_lower:
         probe = pymysql.connect(database=database,
                                 host=host,
                                 port=int(port),
                                 user=user,
                                 password=password,
                                 program_name=application_name,
                                 connect_timeout=validation_timeout)
     else:
         raise InvalidEngine(
             f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!"
         )
     probe.close()
예제 #4
0
 def _get_load_sql(path: str,
                   schema_name: str,
                   table_name: str,
                   engine: str,
                   region: str = "us-east-1") -> str:
     """Render the engine-specific SQL statement that loads CSV data from S3
     into schema_name.table_name."""
     engine_lower = engine.lower()
     if "postgres" in engine_lower:
         bucket, key = Aurora._parse_path(path=path)
         # aws_s3.table_import_from_s3: '' = all columns; single quotes inside
         # the options literal are doubled for SQL escaping.
         return ("-- AWS DATA WRANGLER\n"
                 "SELECT aws_s3.table_import_from_s3(\n"
                 f"'{schema_name}.{table_name}',\n"
                 "'',\n"
                 "'(FORMAT CSV, DELIMITER '','', QUOTE ''\"'', ESCAPE ''\"'')',\n"
                 f"'({bucket},{key},{region})')")
     if "mysql" in engine_lower:
         # Aurora MySQL loads every file listed in the S3 manifest.
         return ("-- AWS DATA WRANGLER\n"
                 f"LOAD DATA FROM S3 MANIFEST '{path}'\n"
                 "REPLACE\n"
                 f"INTO TABLE {schema_name}.{table_name}\n"
                 "FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' ESCAPED BY '\"'\n"
                 "LINES TERMINATED BY '\\n'")
     raise InvalidEngine(
         f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
예제 #5
0
    def generate_connection(database: str,
                            host: str,
                            port: Union[str, int],
                            user: str,
                            password: str,
                            engine: str = "mysql",
                            tcp_keepalive: bool = True,
                            application_name: str = "aws-data-wrangler",
                            connection_timeout: Optional[int] = None,
                            validation_timeout: int = 10):
        """
        Generate a valid connection object.

        :param database: The name of the database instance to connect with.
        :param host: The hostname of the Aurora server to connect with.
        :param port: The TCP/IP port of the Aurora server instance.
        :param user: The username to connect to the Aurora database with.
        :param password: The user password to connect to the server with.
        :param engine: "mysql" or "postgres"
        :param tcp_keepalive: If True then use TCP keepalive
        :param application_name: Application name
        :param connection_timeout: Connection Timeout
        :param validation_timeout: Timeout to try to validate the connection
        :return: PEP 249 compatible connection
        """
        # Fail fast with a short-lived probe connection before handing the
        # caller a real one.
        Aurora._validate_connection(database=database,
                                    host=host,
                                    port=port,
                                    user=user,
                                    password=password,
                                    engine=engine,
                                    tcp_keepalive=tcp_keepalive,
                                    application_name=application_name,
                                    validation_timeout=validation_timeout)
        if "postgres" in engine.lower():
            conn = pg8000.connect(database=database,
                                  host=host,
                                  port=int(port),
                                  user=user,
                                  password=password,
                                  ssl=True,
                                  application_name=application_name,
                                  tcp_keepalive=tcp_keepalive,
                                  timeout=connection_timeout)
        elif "mysql" in engine.lower():
            # BUG FIX: the returned connection must honor connection_timeout
            # (as the postgres branch does); validation_timeout is only for
            # the probe above. The original passed validation_timeout here,
            # silently ignoring the caller's connection_timeout.
            conn = pymysql.connect(database=database,
                                   host=host,
                                   port=int(port),
                                   user=user,
                                   password=password,
                                   program_name=application_name,
                                   connect_timeout=connection_timeout)
        else:
            raise InvalidEngine(
                f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!"
            )
        return conn
예제 #6
0
    def load_table(dataframe: pd.DataFrame,
                   dataframe_type: str,
                   load_paths: List[str],
                   schema_name: str,
                   table_name: str,
                   connection: Any,
                   num_files: int,
                   columns: Optional[List[str]] = None,
                   mode: str = "append",
                   preserve_index: bool = False,
                   engine: str = "mysql",
                   region: str = "us-east-1"):
        """
        Load text/CSV files into an Aurora table through a manifest file,
        creating the table first when necessary.

        :param dataframe: Pandas or Spark Dataframe
        :param dataframe_type: "pandas" or "spark"
        :param load_paths: S3 paths to be loaded (E.g. S3://...)
        :param schema_name: Aurora schema
        :param table_name: Aurora table name
        :param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
        :param num_files: Number of files to be loaded
        :param columns: List of columns to load
        :param mode: append or overwrite
        :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
        :param engine: "mysql" or "postgres"
        :param region: AWS S3 bucket region (Required only for postgres engine)
        :return: None
        """
        engine_lower = engine.lower()
        if "postgres" in engine_lower:
            Aurora.load_table_postgres(dataframe=dataframe,
                                       dataframe_type=dataframe_type,
                                       load_paths=load_paths,
                                       schema_name=schema_name,
                                       table_name=table_name,
                                       connection=connection,
                                       mode=mode,
                                       preserve_index=preserve_index,
                                       region=region,
                                       columns=columns)
            return
        if "mysql" in engine_lower:
            # MySQL consumes a single manifest file (the first path).
            Aurora.load_table_mysql(dataframe=dataframe,
                                    dataframe_type=dataframe_type,
                                    manifest_path=load_paths[0],
                                    schema_name=schema_name,
                                    table_name=table_name,
                                    connection=connection,
                                    mode=mode,
                                    preserve_index=preserve_index,
                                    num_files=num_files,
                                    columns=columns)
            return
        raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
예제 #7
0
    def _create_table(
            cursor,
            dataframe,
            dataframe_type,
            schema_name,
            table_name,
            preserve_index=False,
            engine: str = "mysql",
            columns: Optional[List[str]] = None,
            varchar_default_length: int = 256,
            varchar_lengths: Optional[Dict[str, int]] = None) -> None:
        """
        Create the Aurora table, dropping any previous one first.

        :param cursor: A PEP 249 compatible cursor
        :param dataframe: Pandas or Spark Dataframe
        :param dataframe_type: "pandas" or "spark"
        :param schema_name: Redshift schema
        :param table_name: Redshift table name
        :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
        :param engine: "mysql" or "postgres"
        :param columns: List of columns to load
        :param varchar_default_length: The size that will be set for all VARCHAR columns not specified with varchar_lengths
        :param varchar_lengths: Dict of VARCHAR length by columns. (e.g. {"col1": 10, "col5": 200})
        :return: None
        """
        drop_sql = (f"-- AWS DATA WRANGLER\n"
                    f"DROP TABLE IF EXISTS {schema_name}.{table_name}")
        logger.debug(f"Drop table query:\n{drop_sql}")
        engine_lower = engine.lower()
        if "postgres" in engine_lower:
            cursor.execute(drop_sql)
        elif "mysql" in engine_lower:
            # MySQL emits an "Unknown table" warning (not an error) when the
            # table is absent; silence it so DROP ... IF EXISTS stays quiet.
            with warnings.catch_warnings():
                warnings.filterwarnings(action="ignore",
                                        message=".*Unknown table.*")
                cursor.execute(drop_sql)
        else:
            raise InvalidEngine(
                f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!"
            )
        schema = Aurora._get_schema(
            dataframe=dataframe,
            dataframe_type=dataframe_type,
            preserve_index=preserve_index,
            engine=engine,
            columns=columns,
            varchar_default_length=varchar_default_length,
            varchar_lengths=varchar_lengths)
        # "name type" pairs joined as the column list of the CREATE statement.
        cols_str: str = ",\n".join(f"{name} {dtype}" for name, dtype in schema)
        create_sql = (f"-- AWS DATA WRANGLER\n"
                      f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (\n"
                      f"{cols_str})")
        logger.debug(f"Create table query:\n{create_sql}")
        cursor.execute(create_sql)
예제 #8
0
    def _create_table(cursor,
                      dataframe,
                      dataframe_type,
                      schema_name,
                      table_name,
                      preserve_index=False,
                      engine: str = "mysql"):
        """
        Create the Aurora table, dropping any previous one first.

        :param cursor: A PEP 249 compatible cursor
        :param dataframe: Pandas or Spark Dataframe
        :param dataframe_type: "pandas" or "spark"
        :param schema_name: Redshift schema
        :param table_name: Redshift table name
        :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
        :param engine: "mysql" or "postgres"
        :return: None
        """
        drop_sql = (f"-- AWS DATA WRANGLER\n"
                    f"DROP TABLE IF EXISTS {schema_name}.{table_name}")
        logger.debug(f"Drop table query:\n{drop_sql}")
        engine_lower = engine.lower()
        if "postgres" in engine_lower:
            cursor.execute(drop_sql)
        elif "mysql" in engine_lower:
            # MySQL emits an "Unknown table" warning (not an error) when the
            # table is absent; silence it so DROP ... IF EXISTS stays quiet.
            with warnings.catch_warnings():
                warnings.filterwarnings(action="ignore",
                                        message=".*Unknown table.*")
                cursor.execute(drop_sql)
        else:
            raise InvalidEngine(
                f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!"
            )
        schema = Aurora._get_schema(dataframe=dataframe,
                                    dataframe_type=dataframe_type,
                                    preserve_index=preserve_index,
                                    engine=engine)
        # "name type" pairs joined as the column list of the CREATE statement.
        cols_str: str = ",\n".join(f"{name} {dtype}" for name, dtype in schema)
        create_sql = (f"-- AWS DATA WRANGLER\n"
                      f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (\n"
                      f"{cols_str})")
        logger.debug(f"Create table query:\n{create_sql}")
        cursor.execute(create_sql)
예제 #9
0
 def _get_load_sql(path: str,
                   schema_name: str,
                   table_name: str,
                   engine: str,
                   region: str = "us-east-1",
                   columns: Optional[List[str]] = None) -> str:
     """Render the engine-specific SQL that loads CSV data from S3 into
     schema_name.table_name, optionally restricted to a subset of columns."""
     engine_lower = engine.lower()
     if "postgres" in engine_lower:
         bucket, key = Aurora._parse_path(path=path)
         cols_str: str = "" if columns is None else ",".join(columns)
         # aws_s3.table_import_from_s3: single quotes inside the options
         # literal are doubled for SQL escaping.
         return ("-- AWS DATA WRANGLER\n"
                 "SELECT aws_s3.table_import_from_s3(\n"
                 f"'{schema_name}.{table_name}',\n"
                 f"'{cols_str}',\n"
                 "'(FORMAT CSV, DELIMITER '','', QUOTE ''\"'', ESCAPE ''\"'')',\n"
                 f"'({bucket},{key},{region})')")
     if "mysql" in engine_lower:
         if columns is None:
             cols_str = ""
         else:
             # builds e.g.: (@col1,@col2) SET col1=@col1,col2=@col2
             placeholders = ",".join(f"@{c}" for c in columns)
             assignments = ",".join(f"{c}=@{c}" for c in columns)
             cols_str = f"({placeholders}) SET {assignments}"
             logger.debug(f"cols_str: {cols_str}")
         return ("-- AWS DATA WRANGLER\n"
                 f"LOAD DATA FROM S3 MANIFEST '{path}'\n"
                 "REPLACE\n"
                 f"INTO TABLE {schema_name}.{table_name}\n"
                 "FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' ESCAPED BY '\"'\n"
                 "LINES TERMINATED BY '\\n'"
                 f"{cols_str}")
     raise InvalidEngine(
         f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
예제 #10
0
 def _get_schema(
     dataframe,
     dataframe_type: str,
     preserve_index: bool,
     engine: str = "mysql",
     columns: Optional[List[str]] = None,
     varchar_default_length: int = 256,
     varchar_lengths: Optional[Dict[str,
                                    int]] = None) -> List[Tuple[str, str]]:
     """Build the (column name, Aurora type) schema for a Dataframe,
     keeping only *columns* (when given) and sizing VARCHARs per
     *varchar_lengths* / *varchar_default_length*."""
     lengths: Dict[str, int] = varchar_lengths if varchar_lengths is not None else {}
     engine_lower = engine.lower()
     if "postgres" in engine_lower:
         converter = data_types.pyarrow2postgres
     elif "mysql" in engine_lower:
         converter = data_types.pyarrow2mysql
     else:
         raise InvalidEngine(
             f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!"
         )
     if dataframe_type.lower() != "pandas":
         raise InvalidDataframeType(
             f"{dataframe_type} is not a valid DataFrame type. Please use 'pandas'!"
         )
     pyarrow_schema: List[Tuple[str, str]] = \
         data_types.extract_pyarrow_schema_from_pandas(
             dataframe=dataframe,
             preserve_index=preserve_index,
             indexes_position="right")
     schema_built: List[Tuple[str, str]] = []
     for name, dtype in pyarrow_schema:
         if columns is not None and name not in columns:
             continue  # column not requested by the caller
         varchar_length = lengths.get(name, varchar_default_length)
         schema_built.append((name,
                              converter(dtype=dtype,
                                        varchar_length=varchar_length)))
     return schema_built
예제 #11
0
                                  password=password,
                                  ssl=True,
                                  application_name=application_name,
                                  tcp_keepalive=tcp_keepalive,
                                  timeout=connection_timeout)
        elif "mysql" in engine.lower():
            conn = pymysql.connect(database=database,
                                   host=host,
                                   port=int(port),
                                   user=user,
                                   password=password,
                                   program_name=application_name,
                                   connect_timeout=connection_timeout)
        else:
            raise InvalidEngine(
                f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!"
            )
        return conn

    def write_load_manifest(
        self, manifest_path: str, objects_paths: List[str]
    ) -> Dict[str, List[Dict[str, Union[str, bool]]]]:
        manifest: Dict[str, List[Dict[str, Union[str, bool]]]] = {
            "entries": []
        }
        path: str
        for path in objects_paths:
            entry: Dict[str, Union[str, bool]] = {
                "url": path,
                "mandatory": True
            }