Exemple #1
0
def pyarrow2sqlalchemy(  # pylint: disable=too-many-branches,too-many-return-statements
        dtype: pa.DataType, db_type: str) -> VisitableType:
    """Translate a PyArrow data type into the matching SQLAlchemy type.

    Parameters
    ----------
    dtype : pa.DataType
        PyArrow type to translate. Dictionary types are resolved through
        their value type.
    db_type : str
        Target database: "mysql", "postgresql" or "redshift". It is only
        consulted for float64, string and binary columns.

    Returns
    -------
    VisitableType
        SQLAlchemy type (class or instance) to use for the column.

    Raises
    ------
    exceptions.InvalidDatabaseType
        If a float64 or string type is requested for an unknown ``db_type``.
    exceptions.UnsupportedType
        For binary types on Redshift and for PyArrow types without a mapping.
    exceptions.UndetectedType
        If ``dtype`` is the PyArrow null type.

    """
    arrow = pa.types
    sa_types = sqlalchemy.types

    # Integers: int8 and int16 share the smallest SQL integer type.
    if arrow.is_int8(dtype) or arrow.is_int16(dtype):
        return sa_types.SmallInteger
    if arrow.is_int32(dtype):
        return sa_types.Integer
    if arrow.is_int64(dtype):
        return sa_types.BigInteger

    # Floating point: double precision comes from the dialect-specific types.
    if arrow.is_float32(dtype):
        return sa_types.Float
    if arrow.is_float64(dtype):
        if db_type == "mysql":
            return sqlalchemy.dialects.mysql.DOUBLE
        if db_type == "postgresql":
            return sqlalchemy.dialects.postgresql.DOUBLE_PRECISION
        if db_type == "redshift":
            return sqlalchemy_redshift.dialect.DOUBLE_PRECISION
        raise exceptions.InvalidDatabaseType(
            f"{db_type} is a invalid database type, please choose between postgresql, mysql and redshift."
        )  # pragma: no cover

    if arrow.is_boolean(dtype):
        return sa_types.Boolean

    if arrow.is_string(dtype):
        if db_type in ("mysql", "postgresql"):
            return sa_types.Text
        if db_type == "redshift":
            return sa_types.VARCHAR(length=256)
        raise exceptions.InvalidDatabaseType(
            f"{db_type} is a invalid database type. "
            "Please choose between postgresql, mysql and redshift."
        )  # pragma: no cover

    if arrow.is_timestamp(dtype):
        return sa_types.DateTime
    if arrow.is_date(dtype):
        return sa_types.Date

    if arrow.is_binary(dtype):
        if db_type == "redshift":
            raise exceptions.UnsupportedType(
                "Binary columns are not supported for Redshift."
            )  # pragma: no cover
        return sa_types.Binary

    if arrow.is_decimal(dtype):
        return sa_types.Numeric(precision=dtype.precision, scale=dtype.scale)

    if arrow.is_dictionary(dtype):
        # Dictionary-encoded columns map like their underlying value type.
        return pyarrow2sqlalchemy(dtype=dtype.value_type, db_type=db_type)

    if dtype == pa.null():  # pragma: no cover
        raise exceptions.UndetectedType(
            "We can not infer the data type from an entire null object column")
    raise exceptions.UnsupportedType(
        f"Unsupported Pyarrow type: {dtype}")  # pragma: no cover
def get_engine(
    connection: str,
    catalog_id: Optional[str] = None,
    boto3_session: Optional[boto3.Session] = None,
    **sqlalchemy_kwargs: Any,
) -> sqlalchemy.engine.Engine:
    """Build a SQLAlchemy Engine from a Glue Catalog Connection.

    Only Redshift, PostgreSQL and MySQL are supported.

    Parameters
    ----------
    connection : str
        Connection name.
    catalog_id : str, optional
        The ID of the Data Catalog from which to retrieve Databases.
        If none is provided, the AWS account ID is used by default.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receive None.
    sqlalchemy_kwargs
        keyword arguments forwarded to sqlalchemy.create_engine().
        For Redshift and PostgreSQL, ``executemany_mode`` and
        ``executemany_values_page_size`` are always overwritten.
        https://docs.sqlalchemy.org/en/13/core/engines.html

    Returns
    -------
    sqlalchemy.engine.Engine
        SQLAlchemy Engine.

    Raises
    ------
    exceptions.InvalidDatabaseType
        If the database type found in the JDBC URL is not supported.

    Examples
    --------
    >>> import awswrangler as wr
    >>> res = wr.catalog.get_engine(connection='my_connection')

    """
    properties: Dict[str, Any] = get_connection(
        name=connection, catalog_id=catalog_id, boto3_session=boto3_session
    )["ConnectionProperties"]

    # The URL is split on ":" once: segment 1 is the database type, segment 2
    # the host (slashes stripped) and segment 3 "<port>/<database>".
    # Presumably shaped like jdbc:postgresql://host:5432/db - the first
    # segment is ignored.
    url_parts = properties["JDBC_CONNECTION_URL"].split(":")
    db_type: str = url_parts[1].lower()
    host: str = url_parts[2].replace("/", "")
    port, database = url_parts[3].split("/")

    # Credentials are quoted before being embedded in the connection URL.
    user: str = _quote_plus(properties["USERNAME"])
    password: str = _quote_plus(properties["PASSWORD"])
    location = f"{user}:{password}@{host}:{port}/{database}"

    if db_type == "postgresql":
        _utils.ensure_postgresql_casts()

    if db_type == "mysql":
        return sqlalchemy.create_engine(f"mysql+pymysql://{location}", **sqlalchemy_kwargs)

    if db_type not in ("redshift", "postgresql"):
        raise exceptions.InvalidDatabaseType(
            f"{db_type} is not a valid Database type."
            " Only Redshift, PostgreSQL and MySQL are supported.")

    sqlalchemy_kwargs.update(
        executemany_mode="values",
        executemany_values_page_size=100_000,
    )
    return sqlalchemy.create_engine(f"{db_type}+psycopg2://{location}", **sqlalchemy_kwargs)
def get_engine(db_type: str, host: str, port: int, database: str, user: str,
               password: str,
               **sqlalchemy_kwargs: Any) -> sqlalchemy.engine.Engine:
    """Build a SQLAlchemy Engine from explicit connection arguments.

    Only Redshift, PostgreSQL and MySQL are supported.

    Parameters
    ----------
    db_type : str
        Database type: "redshift", "mysql" or "postgresql".
    host : str
        Host address.
    port : int
        Port number.
    database : str
        Database name.
    user : str
        Username. Inserted into the connection URL as given.
    password : str
        Password. Inserted into the connection URL as given.
    sqlalchemy_kwargs
        keyword arguments forwarded to sqlalchemy.create_engine().
        For Redshift and PostgreSQL, ``executemany_mode`` and
        ``executemany_values_page_size`` are always overwritten.
        https://docs.sqlalchemy.org/en/13/core/engines.html

    Returns
    -------
    sqlalchemy.engine.Engine
        SQLAlchemy Engine.

    Raises
    ------
    exceptions.InvalidDatabaseType
        If ``db_type`` is not one of the supported database types.

    Examples
    --------
    >>> import awswrangler as wr
    >>> engine = wr.db.get_engine(
    ...     db_type="postgresql",
    ...     host="...",
    ...     port=1234,
    ...     database="...",
    ...     user="******",
    ...     password="******"
    ... )

    """
    location = f"{user}:{password}@{host}:{port}/{database}"

    if db_type == "postgresql":
        _utils.ensure_postgresql_casts()

    if db_type == "mysql":
        return sqlalchemy.create_engine(f"mysql+pymysql://{location}", **sqlalchemy_kwargs)

    if db_type not in ("redshift", "postgresql"):
        raise exceptions.InvalidDatabaseType(
            f"{db_type} is not a valid Database type."
            " Only Redshift, PostgreSQL and MySQL are supported.")

    # Both psycopg2-backed dialects always use the "values" executemany mode.
    sqlalchemy_kwargs.update(
        executemany_mode="values",
        executemany_values_page_size=100_000,
    )
    return sqlalchemy.create_engine(f"{db_type}+psycopg2://{location}", **sqlalchemy_kwargs)
Exemple #4
0
def get_engine(db_type: str, host: str, port: int, database: str, user: str,
               password: str) -> sqlalchemy.engine.Engine:
    """Build a SQLAlchemy Engine from explicit connection arguments.

    Only Redshift, PostgreSQL and MySQL are supported.

    Parameters
    ----------
    db_type : str
        Database type: "redshift", "mysql" or "postgresql".
    host : str
        Host address.
    port : int
        Port number.
    database : str
        Database name.
    user : str
        Username. Inserted into the connection URL as given.
    password : str
        Password. Inserted into the connection URL as given.

    Returns
    -------
    sqlalchemy.engine.Engine
        SQLAlchemy Engine. Redshift and PostgreSQL engines are created with
        ``sslmode=verify-ca`` in their connect arguments.

    Raises
    ------
    exceptions.InvalidDatabaseType
        If ``db_type`` is not one of the supported database types.

    Examples
    --------
    >>> import awswrangler as wr
    >>> engine = wr.db.get_engine(
    ...     db_type="postgresql",
    ...     host="...",
    ...     port=1234,
    ...     database="...",
    ...     user="******",
    ...     password="******"
    ... )

    """
    location = f"{user}:{password}@{host}:{port}/{database}"

    if db_type == "postgresql":
        _utils.ensure_postgresql_casts()

    if db_type == "mysql":
        return sqlalchemy.create_engine(f"mysql+pymysql://{location}", echo=False)

    if db_type not in ("redshift", "postgresql"):
        raise exceptions.InvalidDatabaseType(  # pragma: no cover
            f"{db_type} is not a valid Database type."
            " Only Redshift, PostgreSQL and MySQL are supported.")

    # Options shared by both psycopg2-backed dialects.
    psycopg2_options = {
        "echo": False,
        "executemany_mode": "values",
        "executemany_values_page_size": 100_000,
        "connect_args": {'sslmode': 'verify-ca'},
    }
    return sqlalchemy.create_engine(f"{db_type}+psycopg2://{location}", **psycopg2_options)
Exemple #5
0
def connect(
    connection: Optional[str] = None,
    secret_id: Optional[str] = None,
    catalog_id: Optional[str] = None,
    dbname: Optional[str] = None,
    boto3_session: Optional[boto3.Session] = None,
    ssl_context: Optional[Dict[Any, Any]] = None,
    timeout: Optional[int] = None,
    tcp_keepalive: bool = True,
) -> pg8000.Connection:
    """Return a pg8000 connection from a Glue Catalog Connection.

    https://github.com/tlocke/pg8000

    Parameters
    ----------
    connection : Optional[str]
        Glue Catalog Connection name.
    secret_id: Optional[str]:
        Specifies the secret containing the version that you want to retrieve.
        You can specify either the Amazon Resource Name (ARN) or the friendly name of the secret.
    catalog_id : str, optional
        The ID of the Data Catalog.
        If none is provided, the AWS account ID is used by default.
    dbname: Optional[str]
        Optional database name to overwrite the stored one.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receive None.
    ssl_context: Optional[Dict]
        This governs SSL encryption for TCP/IP sockets.
        This parameter is forward to pg8000.
        https://github.com/tlocke/pg8000#functions
    timeout: Optional[int]
        This is the time in seconds before the connection to the server will time out.
        The default is None which means no timeout.
        This parameter is forward to pg8000.
        https://github.com/tlocke/pg8000#functions
    tcp_keepalive: bool
        If True then use TCP keepalive. The default is True.
        This parameter is forward to pg8000.
        https://github.com/tlocke/pg8000#functions

    Returns
    -------
    pg8000.Connection
        pg8000 connection.

    Raises
    ------
    exceptions.InvalidDatabaseType
        If the resolved connection is not of kind "postgresql".

    Examples
    --------
    >>> import awswrangler as wr
    >>> con = wr.postgresql.connect("MY_GLUE_CONNECTION")
    >>> with con.cursor() as cursor:
    >>>     cursor.execute("SELECT 1")
    >>>     print(cursor.fetchall())
    >>> con.close()

    """
    attrs: _db_utils.ConnectionAttributes = _db_utils.get_connection_attributes(
        connection=connection,
        secret_id=secret_id,
        catalog_id=catalog_id,
        dbname=dbname,
        boto3_session=boto3_session)
    if attrs.kind != "postgresql":
        # The exception must be raised, not merely instantiated; otherwise a
        # connection of the wrong kind is silently handed to pg8000.
        raise exceptions.InvalidDatabaseType(
            f"Invalid connection type ({attrs.kind}. It must be a postgresql connection.)"
        )
    return pg8000.connect(
        user=attrs.user,
        database=attrs.database,
        password=attrs.password,
        port=attrs.port,
        host=attrs.host,
        ssl_context=ssl_context,
        timeout=timeout,
        tcp_keepalive=tcp_keepalive,
    )
Exemple #6
0
def connect(
    connection: Optional[str] = None,
    secret_id: Optional[str] = None,
    catalog_id: Optional[str] = None,
    dbname: Optional[str] = None,
    odbc_driver_version: int = 17,
    boto3_session: Optional[boto3.Session] = None,
    timeout: Optional[int] = 0,
) -> "pyodbc.Connection":
    """Open a pyodbc connection described by a Glue Catalog Connection.

    https://github.com/mkleehammer/pyodbc

    Parameters
    ----------
    connection : Optional[str]
        Glue Catalog Connection name.
    secret_id: Optional[str]:
        Specifies the secret containing the version that you want to retrieve.
        You can specify either the Amazon Resource Name (ARN) or the friendly name of the secret.
    catalog_id : str, optional
        The ID of the Data Catalog.
        If none is provided, the AWS account ID is used by default.
    dbname: Optional[str]
        Optional database name to overwrite the stored one.
    odbc_driver_version : int
        Major version of the ODBC Driver that is installed and should be used.
        Defaults to 17.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receive None.
    timeout: Optional[int]
        This is the time in seconds before the connection to the server will time out.
        Defaults to 0.
        This parameter is forwarded to pyodbc.
        https://github.com/mkleehammer/pyodbc/wiki/The-pyodbc-Module#connect

    Returns
    -------
    pyodbc.Connection
        pyodbc connection.

    Raises
    ------
    exceptions.InvalidDatabaseType
        If the resolved connection is not of kind "sqlserver".

    Examples
    --------
    >>> import awswrangler as wr
    >>> con = wr.sqlserver.connect(connection="MY_GLUE_CONNECTION", odbc_driver_version=17)
    >>> with con.cursor() as cursor:
    >>>     cursor.execute("SELECT 1")
    >>>     print(cursor.fetchall())
    >>> con.close()

    """
    attrs: _db_utils.ConnectionAttributes = _db_utils.get_connection_attributes(
        connection=connection,
        secret_id=secret_id,
        catalog_id=catalog_id,
        dbname=dbname,
        boto3_session=boto3_session,
    )
    if attrs.kind != "sqlserver":
        raise exceptions.InvalidDatabaseType(
            f"Invalid connection type ({attrs.kind}. It must be a sqlserver connection.)"
        )

    # ODBC connection string: ";"-separated KEY=value pairs.
    odbc_fields = (
        f"DRIVER={{ODBC Driver {odbc_driver_version} for SQL Server}}",
        f"SERVER={attrs.host},{attrs.port}",
        f"DATABASE={attrs.database}",
        f"UID={attrs.user}",
        f"PWD={attrs.password}",
    )
    return pyodbc.connect(";".join(odbc_fields), timeout=timeout)
Exemple #7
0
def connect(
    connection: Optional[str] = None,
    secret_id: Optional[str] = None,
    catalog_id: Optional[str] = None,
    dbname: Optional[str] = None,
    boto3_session: Optional[boto3.Session] = None,
    read_timeout: Optional[int] = None,
    write_timeout: Optional[int] = None,
    connect_timeout: int = 10,
) -> pymysql.connections.Connection:
    """Open a pymysql connection described by a Glue Catalog Connection.

    https://pymysql.readthedocs.io

    Parameters
    ----------
    connection : str
        Glue Catalog Connection name.
    secret_id: Optional[str]:
        Specifies the secret containing the version that you want to retrieve.
        You can specify either the Amazon Resource Name (ARN) or the friendly name of the secret.
    catalog_id : str, optional
        The ID of the Data Catalog.
        If none is provided, the AWS account ID is used by default.
    dbname: Optional[str]
        Optional database name to overwrite the stored one.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receive None.
    read_timeout: Optional[int]
        The timeout for reading from the connection in seconds (default: None - no timeout).
        This parameter is forward to pymysql.
        https://pymysql.readthedocs.io/en/latest/modules/connections.html
    write_timeout: Optional[int]
        The timeout for writing to the connection in seconds (default: None - no timeout)
        This parameter is forward to pymysql.
        https://pymysql.readthedocs.io/en/latest/modules/connections.html
    connect_timeout: int
        Timeout before throwing an exception when connecting.
        (default: 10, min: 1, max: 31536000)
        This parameter is forward to pymysql.
        https://pymysql.readthedocs.io/en/latest/modules/connections.html

    Returns
    -------
    pymysql.connections.Connection
        pymysql connection.

    Raises
    ------
    exceptions.InvalidDatabaseType
        If the resolved connection is not of kind "mysql".

    Examples
    --------
    >>> import awswrangler as wr
    >>> con = wr.mysql.connect("MY_GLUE_CONNECTION")
    >>> with con.cursor() as cursor:
    >>>     cursor.execute("SELECT 1")
    >>>     print(cursor.fetchall())
    >>> con.close()

    """
    attrs: _db_utils.ConnectionAttributes = _db_utils.get_connection_attributes(
        connection=connection,
        secret_id=secret_id,
        catalog_id=catalog_id,
        dbname=dbname,
        boto3_session=boto3_session,
    )
    if attrs.kind == "mysql":
        return pymysql.connect(
            user=attrs.user,
            database=attrs.database,
            password=attrs.password,
            port=attrs.port,
            host=attrs.host,
            read_timeout=read_timeout,
            write_timeout=write_timeout,
            connect_timeout=connect_timeout,
        )
    # Any other connection kind is rejected.
    raise exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a MySQL connection.)")
Exemple #8
0
def connect(
    connection: Optional[str] = None,
    secret_id: Optional[str] = None,
    catalog_id: Optional[str] = None,
    dbname: Optional[str] = None,
    boto3_session: Optional[boto3.Session] = None,
    read_timeout: Optional[int] = None,
    write_timeout: Optional[int] = None,
    connect_timeout: int = 10,
    cursorclass: Type[Cursor] = Cursor,
) -> "pymysql.connections.Connection[Any]":
    """Open a pymysql connection from a Glue Catalog Connection or Secrets Manager.

    https://pymysql.readthedocs.io

    Note
    ----
    You MUST pass a `connection` OR `secret_id`.
    Here is an example of the secret structure in Secrets Manager:
    {
    "host":"mysql-instance-wrangler.dr8vkeyrb9m1.us-east-1.rds.amazonaws.com",
    "username":"******",
    "password":"******",
    "engine":"mysql",
    "port":"3306",
    "dbname": "mydb" # Optional
    }

    Note
    ----
    It is only possible to configure SSL using Glue Catalog Connection. More at:
    https://docs.aws.amazon.com/glue/latest/dg/connection-defining.html

    Note
    ----
    Consider using SSCursor `cursorclass` for queries that return a lot of data. More at:
    https://pymysql.readthedocs.io/en/latest/modules/cursors.html#pymysql.cursors.SSCursor

    Parameters
    ----------
    connection : str
        Glue Catalog Connection name.
    secret_id: Optional[str]:
        Specifies the secret containing the connection details that you want to retrieve.
        You can specify either the Amazon Resource Name (ARN) or the friendly name of the secret.
    catalog_id : str, optional
        The ID of the Data Catalog.
        If none is provided, the AWS account ID is used by default.
    dbname: Optional[str]
        Optional database name to overwrite the stored one.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receive None.
    read_timeout: Optional[int]
        The timeout for reading from the connection in seconds (default: None - no timeout).
        This parameter is forward to pymysql.
        https://pymysql.readthedocs.io/en/latest/modules/connections.html
    write_timeout: Optional[int]
        The timeout for writing to the connection in seconds (default: None - no timeout)
        This parameter is forward to pymysql.
        https://pymysql.readthedocs.io/en/latest/modules/connections.html
    connect_timeout: int
        Timeout before throwing an exception when connecting.
        (default: 10, min: 1, max: 31536000)
        This parameter is forward to pymysql.
        https://pymysql.readthedocs.io/en/latest/modules/connections.html
    cursorclass : Cursor
        Cursor class to use, e.g. SSCursor; defaults to :class:`pymysql.cursors.Cursor`
        https://pymysql.readthedocs.io/en/latest/modules/cursors.html

    Returns
    -------
    pymysql.connections.Connection
        pymysql connection.

    Raises
    ------
    exceptions.InvalidDatabaseType
        If the resolved connection is not of kind "mysql".

    Examples
    --------
    >>> import awswrangler as wr
    >>> con = wr.mysql.connect("MY_GLUE_CONNECTION")
    >>> with con.cursor() as cursor:
    >>>     cursor.execute("SELECT 1")
    >>>     print(cursor.fetchall())
    >>> con.close()

    """
    attrs: _db_utils.ConnectionAttributes = _db_utils.get_connection_attributes(
        connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session
    )
    if attrs.kind == "mysql":
        return pymysql.connect(
            user=attrs.user,
            database=attrs.database,
            password=attrs.password,
            port=attrs.port,
            host=attrs.host,
            ssl=attrs.ssl_context,  # type: ignore
            read_timeout=read_timeout,
            write_timeout=write_timeout,
            connect_timeout=connect_timeout,
            cursorclass=cursorclass,
        )
    # Any other connection kind is rejected.
    raise exceptions.InvalidDatabaseType(
        f"Invalid connection type ({attrs.kind}. It must be a MySQL connection.)"
    )
def connect(
    connection: str,
    catalog_id: Optional[str] = None,
    boto3_session: Optional[boto3.Session] = None,
    ssl: bool = True,
    timeout: Optional[int] = None,
    max_prepared_statements: int = 1000,
    tcp_keepalive: bool = True,
) -> redshift_connector.Connection:
    """Return a redshift_connector connection from a Glue Catalog Connection.

    https://github.com/aws/amazon-redshift-python-driver

    Parameters
    ----------
    connection : str
        Glue Catalog Connection name.
    catalog_id : str, optional
        The ID of the Data Catalog.
        If none is provided, the AWS account ID is used by default.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receive None.
    ssl: bool
        This governs SSL encryption for TCP/IP sockets.
        This parameter is forward to redshift_connector.
        https://github.com/aws/amazon-redshift-python-driver
    timeout: Optional[int]
        This is the time in seconds before the connection to the server will time out.
        The default is None which means no timeout.
        This parameter is forward to redshift_connector.
        https://github.com/aws/amazon-redshift-python-driver
    max_prepared_statements: int
        This parameter is forward to redshift_connector.
        https://github.com/aws/amazon-redshift-python-driver
    tcp_keepalive: bool
        If True then use TCP keepalive. The default is True.
        This parameter is forward to redshift_connector.
        https://github.com/aws/amazon-redshift-python-driver

    Returns
    -------
    redshift_connector.Connection
        redshift_connector connection.

    Raises
    ------
    exceptions.InvalidDatabaseType
        If the resolved connection is not of kind "redshift".

    Examples
    --------
    >>> import awswrangler as wr
    >>> con = wr.redshift.connect("MY_GLUE_CONNECTION")
    >>> with con.cursor() as cursor:
    >>>     cursor.execute("SELECT 1")
    >>>     print(cursor.fetchall())
    >>> con.close()

    """
    attrs: _db_utils.ConnectionAttributes = _db_utils.get_connection_attributes(
        connection=connection,
        catalog_id=catalog_id,
        boto3_session=boto3_session)
    if attrs.kind != "redshift":
        # The exception must be raised, not merely instantiated; otherwise a
        # connection of the wrong kind is silently handed to redshift_connector.
        raise exceptions.InvalidDatabaseType(
            f"Invalid connection type ({attrs.kind}. It must be a redshift connection.)"
        )
    return redshift_connector.connect(
        user=attrs.user,
        database=attrs.database,
        password=attrs.password,
        port=attrs.port,
        host=attrs.host,
        ssl=ssl,
        timeout=timeout,
        max_prepared_statements=max_prepared_statements,
        tcp_keepalive=tcp_keepalive,
    )