Example #1
def test__get_catalog_filter_conditions_considers_args(
        _input, is_single_database_metadata_val):
    catalog, api_supported_only_for_connected_database, database_col_name = _input

    mock_cursor: Cursor = Cursor.__new__(Cursor)
    mock_connection: Connection = Connection.__new__(Connection)
    mock_cursor._c = mock_connection

    with patch(
            "redshift_connector.Connection.is_single_database_metadata",
            new_callable=PropertyMock()) as mock_is_single_database_metadata:
        mock_is_single_database_metadata.__get__ = Mock(
            return_value=is_single_database_metadata_val)
        result: str = mock_cursor._get_catalog_filter_conditions(
            catalog, api_supported_only_for_connected_database,
            database_col_name)

    if catalog is not None:
        assert catalog in result
        if is_single_database_metadata_val or api_supported_only_for_connected_database:
            assert "current_database()" in result
            assert catalog in result
        elif database_col_name is None:
            assert "database_name" in result
        else:
            assert database_col_name in result
    else:
        assert result == ""
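The tests in this listing drive the Connection.is_single_database_metadata property through patch(..., new_callable=PropertyMock()) and then override __get__ on the result. For reference, here is a minimal, self-contained sketch of the more conventional form of this pattern; the Widget class and test below are hypothetical and not part of redshift_connector:

from unittest.mock import PropertyMock, patch


class Widget:
    @property
    def flag(self) -> bool:
        raise RuntimeError("the real property would need a live connection")


def test_flag_is_patched() -> None:
    # Passing the PropertyMock class (not an instance) lets patch.object build the
    # replacement property and hand it back for configuration via return_value.
    with patch.object(Widget, "flag", new_callable=PropertyMock) as mock_flag:
        mock_flag.return_value = True
        assert Widget().flag is True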
Example #2
def test_get_schemas_considers_args(_input, is_single_database_metadata_val,
                                    mocker):
    catalog, schema_pattern = _input
    mocker.patch("redshift_connector.Cursor.execute", return_value=None)
    mocker.patch("redshift_connector.Cursor.fetchall", return_value=None)

    mock_cursor: Cursor = Cursor.__new__(Cursor)
    mock_cursor.paramstyle = "mocked"
    mock_connection: Connection = Connection.__new__(Connection)
    mock_cursor._c = mock_connection
    spy = mocker.spy(mock_cursor, "execute")

    with patch(
            "redshift_connector.Connection.is_single_database_metadata",
            new_callable=PropertyMock()) as mock_is_single_database_metadata:
        mock_is_single_database_metadata.__get__ = Mock(
            return_value=is_single_database_metadata_val)
        mock_cursor.get_schemas(catalog, schema_pattern)

    assert spy.called
    assert spy.call_count == 1

    if schema_pattern is not None:  # should be in parameterized portion
        assert schema_pattern in spy.call_args[0][1]

    if catalog is not None:
        assert catalog in spy.call_args[0][0]
Example #3
def test_get_catalogs_considers_args(is_single_database_metadata_val, mocker):
    mocker.patch("redshift_connector.Cursor.execute", return_value=None)
    mocker.patch("redshift_connector.Cursor.fetchall", return_value=None)

    mock_cursor: Cursor = Cursor.__new__(Cursor)
    mock_cursor.paramstyle = "mocked"
    mock_connection: Connection = Connection.__new__(Connection)
    mock_cursor._c = mock_connection
    spy = mocker.spy(mock_cursor, "execute")

    with patch(
            "redshift_connector.Connection.is_single_database_metadata",
            new_callable=PropertyMock()) as mock_is_single_database_metadata:
        mock_is_single_database_metadata.__get__ = Mock(
            return_value=is_single_database_metadata_val)
        mock_cursor.get_catalogs()

    assert spy.called
    assert spy.call_count == 1

    if is_single_database_metadata_val:
        assert "select current_database as TABLE_CAT FROM current_database()" in spy.call_args[
            0][0]
    else:
        assert (
            "SELECT CAST(database_name AS varchar(124)) AS TABLE_CAT FROM PG_CATALOG.SVV_REDSHIFT_DATABASES "
            in spy.call_args[0][0])
Example #4
def test_raw_connection_property_warns():
    mock_cursor: Cursor = Cursor.__new__(Cursor)
    mock_cursor._c = Connection.__new__(Connection)

    with pytest.warns(UserWarning,
                      match="DB-API extension cursor.connection used"):
        mock_cursor.connection
Example #5
def test_handle_ROW_DESCRIPTION_missing_ps_raises():
    mock_connection = Connection.__new__(Connection)
    mock_cursor = Cursor.__new__(Cursor)
    mock_cursor.ps = None

    with pytest.raises(InterfaceError,
                       match="Cursor is missing prepared statement"):
        mock_connection.handle_ROW_DESCRIPTION(b"\x00", mock_cursor)
Example #6
def test_handle_ROW_DESCRIPTION_missing_row_desc_raises():
    mock_connection = Connection.__new__(Connection)
    mock_cursor = Cursor.__new__(Cursor)
    mock_cursor.ps = {}

    with pytest.raises(InterfaceError,
                       match="Prepared Statement is missing row description"):
        mock_connection.handle_ROW_DESCRIPTION(b"\x00", mock_cursor)
Example #7
def test_is_multidatabases_catalog_enable_in_server(_input):
    param_status, exp_val = _input
    mock_connection = Connection.__new__(Connection)
    mock_connection.parameter_statuses: deque = deque()

    if param_status is not None:
        mock_connection.parameter_statuses.append(
            (b"datashare_enabled", param_status.encode()))

    assert mock_connection._is_multi_databases_catalog_enable_in_server == exp_val
Example #8
def db_table(
        request,
        con: redshift_connector.Connection) -> redshift_connector.Connection:
    filterwarnings("ignore", "DB-API extension cursor.next()")
    filterwarnings("ignore", "DB-API extension cursor.__iter__()")
    con.paramstyle = "format"  # type: ignore
    with con.cursor() as cursor:
        cursor.execute("drop table if exists book")
        cursor.execute(
            "create Temp table book(bookname varchar,author‎ varchar)")

    def fin() -> None:
        try:
            with con.cursor() as cursor:
                cursor.execute("drop table if exists book")
        except redshift_connector.ProgrammingError:
            pass

    request.addfinalizer(fin)
    return con
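A hedged sketch of how a test might consume the db_table fixture above; the book table, its columns, and the %s placeholders follow the fixture (con.paramstyle = "format"), but the test itself is illustrative and not part of the suite:

def test_insert_and_read_book(db_table) -> None:
    with db_table.cursor() as cursor:
        # %s placeholders match the "format" paramstyle set by the fixture
        cursor.execute(
            "insert into book (bookname, author) values (%s, %s)",
            ("A Book", "An Author"),
        )
        cursor.execute("select bookname, author from book")
        rows = cursor.fetchall()
    assert len(rows) == 1
    assert list(rows[0]) == ["A Book", "An Author"]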
Example #9
def db_table(
        request,
        con: redshift_connector.Connection) -> redshift_connector.Connection:
    filterwarnings("ignore", "DB-API extension cursor.next()")
    filterwarnings("ignore", "DB-API extension cursor.__iter__()")
    con.paramstyle = "format"  # type: ignore
    with con.cursor() as cursor:
        cursor.execute("DROP TABLE IF EXISTS t1")
        cursor.execute(
            "CREATE TEMPORARY TABLE t1 (f1 int primary key, f2 bigint not null, f3 varchar(50) null) "
        )

    def fin() -> None:
        try:
            with con.cursor() as cursor:
                cursor.execute("drop table if exists t1")
        except redshift_connector.ProgrammingError:
            pass

    request.addfinalizer(fin)
    return con
Example #10
def test_is_single_database_metadata(_input):
    param_status, database_metadata_current_db_only_val, exp_val = _input

    mock_connection = Connection.__new__(Connection)
    mock_connection.parameter_statuses: deque = deque()
    mock_connection._database_metadata_current_db_only = database_metadata_current_db_only_val

    if param_status is not None:
        mock_connection.parameter_statuses.append(
            (b"datashare_enabled", param_status.encode()))

    assert mock_connection.is_single_database_metadata == exp_val
Example #11
def test_handle_ROW_DESCRIPTION_base(_input):
    data, exp_result = _input
    mock_connection = Connection.__new__(Connection)
    mock_connection._client_protocol_version = ClientProtocolVersion.BASE_SERVER.value
    mock_cursor = Cursor.__new__(Cursor)
    mock_cursor.ps = {"row_desc": []}

    mock_connection.handle_ROW_DESCRIPTION(data, mock_cursor)
    assert mock_cursor.ps is not None
    assert "row_desc" in mock_cursor.ps
    assert len(mock_cursor.ps["row_desc"]) == 1
    assert exp_result[0].items() <= mock_cursor.ps["row_desc"][0].items()
    assert "func" in mock_cursor.ps["row_desc"][0]
Example #12
def test_handle_ROW_DESCRIPTION_extended_metadata(_input, protocol):
    data, exp_result = _input
    mock_connection = Connection.__new__(Connection)
    mock_connection._client_protocol_version = protocol
    mock_cursor = Cursor.__new__(Cursor)
    mock_cursor.ps = {"row_desc": []}

    mock_connection.handle_ROW_DESCRIPTION(data, mock_cursor)
    assert mock_cursor.ps is not None
    assert "row_desc" in mock_cursor.ps
    assert len(mock_cursor.ps["row_desc"]) == 1
    assert exp_result[0].items() <= mock_cursor.ps["row_desc"][0].items()
    assert "func" in mock_cursor.ps["row_desc"][0]
Example #13
def test_get_procedures_considers_args(_input, mocker):
    catalog, schema_pattern, procedure_name_pattern = _input
    mocker.patch("redshift_connector.Cursor.execute", return_value=None)
    mocker.patch("redshift_connector.Cursor.fetchall", return_value=None)
    mocker.patch("redshift_connector.Connection.is_single_database_metadata",
                 return_value=True)

    mock_cursor: Cursor = Cursor.__new__(Cursor)
    mock_connection: Connection = Connection.__new__(Connection)
    mock_cursor._c = mock_connection

    mock_cursor.paramstyle = "mocked_val"
    spy = mocker.spy(mock_cursor, "execute")

    mock_cursor.get_procedures(catalog, schema_pattern, procedure_name_pattern)
    assert spy.called
    assert spy.call_count == 1
    assert catalog not in spy.call_args[0][1]
    for arg in (schema_pattern, procedure_name_pattern):
        if arg is not None:
            assert arg in spy.call_args[0][1]
Example #14
def test_get_tables_considers_args(is_single_database_metadata_val, _input,
                                   schema_pattern_type, mocker):
    catalog, schema_pattern, table_name_pattern = _input
    mocker.patch("redshift_connector.Cursor.execute", return_value=None)
    # mock the return value from __schema_pattern_match, as its return value is used in get_tables().
    # The result of the other potential call to this method in get_tables() is simply returned,
    # so it has no impact here.
    mocker.patch(
        "redshift_connector.Cursor.fetchall",
        return_value=None
        if schema_pattern_type == "EXTERNAL_SCHEMA_QUERY" else tuple("mock"),
    )

    mock_cursor: Cursor = Cursor.__new__(Cursor)
    mock_cursor.paramstyle = "mocked"
    mock_connection: Connection = Connection.__new__(Connection)
    mock_cursor._c = mock_connection
    spy = mocker.spy(mock_cursor, "execute")

    with patch(
            "redshift_connector.Connection.is_single_database_metadata",
            new_callable=PropertyMock()) as mock_is_single_database_metadata:
        mock_is_single_database_metadata.__get__ = Mock(
            return_value=is_single_database_metadata_val)
        mock_cursor.get_tables(catalog, schema_pattern, table_name_pattern)

    assert spy.called

    if schema_pattern is not None and is_single_database_metadata_val:
        assert spy.call_count == 2  # call in __schema_pattern_match(), get_tables()
    else:
        assert spy.call_count == 1

    if catalog is not None:
        assert catalog in spy.call_args[0][0]

    for arg in (schema_pattern, table_name_pattern):
        if arg is not None:
            assert arg in spy.call_args[0][1]
Example #15
def test_client_os_version_is_present():
    mock_connection: Connection = Connection.__new__(Connection)
    assert mock_connection.client_os_version is not None
    assert isinstance(mock_connection.client_os_version, str)
Example #16
def to_sql(
    df: pd.DataFrame,
    con: redshift_connector.Connection,
    table: str,
    schema: str,
    mode: str = "append",
    index: bool = False,
    dtype: Optional[Dict[str, str]] = None,
    diststyle: str = "AUTO",
    distkey: Optional[str] = None,
    sortstyle: str = "COMPOUND",
    sortkey: Optional[List[str]] = None,
    primary_keys: Optional[List[str]] = None,
    varchar_lengths_default: int = 256,
    varchar_lengths: Optional[Dict[str, int]] = None,
) -> None:
    """Write records stored in a DataFrame into Redshift.

    Note
    ----
    For large DataFrames (1K+ rows) consider the function **wr.redshift.copy()**.


    Parameters
    ----------
    df : pandas.DataFrame
        Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html
    con : redshift_connector.Connection
        Use redshift_connector.connect() to use credentials directly
        or wr.redshift.connect() to fetch it from the Glue Catalog.
    table : str
        Table name
    schema : str
        Schema name
    mode : str
        Append, overwrite or upsert.
    index : bool
        True to store the DataFrame index as a column in the table,
        otherwise False to ignore it.
    dtype : Dict[str, str], optional
        Dictionary of column names and Redshift types to be cast.
        Useful when you have columns with undetermined or mixed data types.
        (e.g. {'col name': 'VARCHAR(10)', 'col2 name': 'FLOAT'})
    diststyle : str
        Redshift distribution styles. Must be in ["AUTO", "EVEN", "ALL", "KEY"].
        https://docs.aws.amazon.com/redshift/latest/dg/t_Distributing_data.html
    distkey : str, optional
        Specifies a column name or positional number for the distribution key.
    sortstyle : str
        Sorting can be "COMPOUND" or "INTERLEAVED".
        https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html
    sortkey : List[str], optional
        List of columns to be sorted.
    primary_keys : List[str], optional
        Primary keys.
    varchar_lengths_default : int
        The size that will be set for all VARCHAR columns not specified with varchar_lengths.
    varchar_lengths : Dict[str, int], optional
        Dict of VARCHAR length by columns. (e.g. {"col1": 10, "col5": 200}).

    Returns
    -------
    None
        None.

    Examples
    --------
    Writing to Redshift using a Glue Catalog connection

    >>> import awswrangler as wr
    >>> con = wr.redshift.connect("MY_GLUE_CONNECTION")
    >>> wr.redshift.to_sql(
    ...     df=df,
    ...     table="my_table",
    ...     schema="public",
    ...     con=con
    ... )
    >>> con.close()

    """
    if df.empty is True:
        raise exceptions.EmptyDataFrame()
    _validate_connection(con=con)
    con.autocommit = False
    try:
        with con.cursor() as cursor:
            created_table, created_schema = _create_table(
                df=df,
                path=None,
                cursor=cursor,
                table=table,
                schema=schema,
                mode=mode,
                index=index,
                dtype=dtype,
                diststyle=diststyle,
                sortstyle=sortstyle,
                distkey=distkey,
                sortkey=sortkey,
                primary_keys=primary_keys,
                varchar_lengths_default=varchar_lengths_default,
                varchar_lengths=varchar_lengths,
            )
            if index:
                df.reset_index(level=df.index.names, inplace=True)
            placeholders: str = ", ".join(["%s"] * len(df.columns))
            schema_str = f"{created_schema}." if created_schema else ""
            sql: str = f"INSERT INTO {schema_str}{created_table} VALUES ({placeholders})"
            _logger.debug("sql: %s", sql)
            parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
            cursor.executemany(sql, parameters)
            if table != created_table:  # upsert
                _upsert(cursor=cursor,
                        schema=schema,
                        table=table,
                        temp_table=created_table,
                        primary_keys=primary_keys)
            con.commit()
    except Exception as ex:
        con.rollback()
        _logger.error(ex)
        raise
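The docstring example above goes through a Glue Catalog connection; a minimal sketch of calling to_sql with a connection opened directly via redshift_connector.connect() is shown below. The endpoint, credentials, and DataFrame contents are placeholders:

import awswrangler as wr
import pandas as pd
import redshift_connector

# Placeholder cluster endpoint and credentials
con = redshift_connector.connect(
    host="examplecluster.abc123xyz789.us-west-2.redshift.amazonaws.com",
    database="dev",
    user="awsuser",
    password="my_password",
)
df = pd.DataFrame({"id": [1, 2], "name": ["foo", "bar"]})
wr.redshift.to_sql(df=df, con=con, table="my_table", schema="public", mode="overwrite")
con.close()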
Example #17
def unload_to_files(
    sql: str,
    path: str,
    con: redshift_connector.Connection,
    iam_role: str,
    region: Optional[str] = None,
    max_file_size: Optional[float] = None,
    kms_key_id: Optional[str] = None,
    manifest: bool = False,
    use_threads: bool = True,
    partition_cols: Optional[List[str]] = None,
    boto3_session: Optional[boto3.Session] = None,
) -> None:
    """Unload Parquet files on s3 from a Redshift query result (Through the UNLOAD command).

    https://docs.aws.amazon.com/redshift/latest/dg/r_UNLOAD.html

    Note
    ----
    When `use_threads=True`, the number of threads
    spawned is taken from os.cpu_count().

    Parameters
    ----------
    sql: str
        SQL query.
    path : Union[str, List[str]]
        S3 path to write stage files (e.g. s3://bucket_name/any_name/)
    con : redshift_connector.Connection
        Use redshift_connector.connect() to use credentials directly
        or wr.redshift.connect() to fetch it from the Glue Catalog.
    iam_role : str
        AWS IAM role with the related permissions.
    region : str, optional
        Specifies the AWS Region where the target Amazon S3 bucket is located.
        REGION is required for UNLOAD to an Amazon S3 bucket that isn't in the
        same AWS Region as the Amazon Redshift cluster. By default, UNLOAD
        assumes that the target Amazon S3 bucket is located in the same AWS
        Region as the Amazon Redshift cluster.
    max_file_size : float, optional
        Specifies the maximum size (MB) of files that UNLOAD creates in Amazon S3.
        Specify a decimal value between 5.0 MB and 6200.0 MB. If None, the default
        maximum file size is 6200.0 MB.
    kms_key_id : str, optional
        Specifies the key ID for an AWS Key Management Service (AWS KMS) key to be
        used to encrypt data files on Amazon S3.
    use_threads : bool
        True to enable concurrent requests, False to disable multiple threads.
        If enabled os.cpu_count() will be used as the max number of threads.
    manifest : bool
        Unload a manifest file on S3.
    partition_cols: List[str], optional
        Specifies the partition keys for the unload operation.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receives None.

    Returns
    -------
    None

    Examples
    --------
    >>> import awswrangler as wr
    >>> con = wr.redshift.connect("MY_GLUE_CONNECTION")
    >>> wr.redshift.unload_to_files(
    ...     sql="SELECT * FROM public.mytable",
    ...     path="s3://bucket/extracted_parquet_files/",
    ...     con=con,
    ...     iam_role="arn:aws:iam::XXX:role/XXX"
    ... )
    >>> con.close()


    """
    path = path if path.endswith("/") else f"{path}/"
    session: boto3.Session = _utils.ensure_session(session=boto3_session)
    s3.delete_objects(path=path,
                      use_threads=use_threads,
                      boto3_session=session)
    with con.cursor() as cursor:
        partition_str: str = f"\nPARTITION BY ({','.join(partition_cols)})" if partition_cols else ""
        manifest_str: str = "\nmanifest" if manifest is True else ""
        region_str: str = f"\nREGION AS '{region}'" if region is not None else ""
        max_file_size_str: str = f"\nMAXFILESIZE AS {max_file_size} MB" if max_file_size is not None else ""
        kms_key_id_str: str = f"\nKMS_KEY_ID '{kms_key_id}'" if kms_key_id is not None else ""
        sql = (f"UNLOAD ('{sql}')\n"
               f"TO '{path}'\n"
               f"IAM_ROLE '{iam_role}'\n"
               "ALLOWOVERWRITE\n"
               "PARALLEL ON\n"
               "FORMAT PARQUET\n"
               "ENCRYPTED"
               f"{kms_key_id_str}"
               f"{partition_str}"
               f"{region_str}"
               f"{max_file_size_str}"
               f"{manifest_str};")
        _logger.debug("sql: \n%s", sql)
        cursor.execute(sql)
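A hedged usage sketch exercising the optional arguments handled in the body above (partition_cols, region, max_file_size): based on the string building shown, such a call appends PARTITION BY, REGION AS, and MAXFILESIZE clauses to the generated UNLOAD statement. The bucket, query, role ARN, and column name are placeholders:

import awswrangler as wr

con = wr.redshift.connect("MY_GLUE_CONNECTION")
wr.redshift.unload_to_files(
    sql="SELECT * FROM public.sales",                            # placeholder query
    path="s3://my-bucket/unload/sales/",                         # placeholder prefix
    con=con,
    iam_role="arn:aws:iam::111111111111:role/redshift-unload",   # placeholder ARN
    partition_cols=["region"],
    max_file_size=512.0,
    region="us-east-1",
)
con.close()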
Example #18
def copy_from_files(  # pylint: disable=too-many-locals,too-many-arguments
    path: str,
    con: redshift_connector.Connection,
    table: str,
    schema: str,
    iam_role: str,
    parquet_infer_sampling: float = 1.0,
    mode: str = "append",
    diststyle: str = "AUTO",
    distkey: Optional[str] = None,
    sortstyle: str = "COMPOUND",
    sortkey: Optional[List[str]] = None,
    primary_keys: Optional[List[str]] = None,
    varchar_lengths_default: int = 256,
    varchar_lengths: Optional[Dict[str, int]] = None,
    use_threads: bool = True,
    boto3_session: Optional[boto3.Session] = None,
    s3_additional_kwargs: Optional[Dict[str, str]] = None,
) -> None:
    """Load Parquet files from S3 to a Table on Amazon Redshift (Through COPY command).

    https://docs.aws.amazon.com/redshift/latest/dg/r_COPY.html

    Note
    ----
    If the table does not exist yet,
    it will be automatically created for you
    using the Parquet metadata to
    infer the columns data types.

    Note
    ----
    When `use_threads=True`, the number of threads
    spawned is taken from os.cpu_count().

    Parameters
    ----------
    path : str
        S3 prefix (e.g. s3://bucket/prefix/)
    con : redshift_connector.Connection
        Use redshift_connector.connect() to use credentials directly
        or wr.redshift.connect() to fetch it from the Glue Catalog.
    table : str
        Table name
    schema : str
        Schema name
    iam_role : str
        AWS IAM role with the related permissions.
    parquet_infer_sampling : float
        Random sample ratio of files that will have the metadata inspected.
        Must be `0.0 < sampling <= 1.0`.
        The higher, the more accurate.
        The lower, the faster.
    mode : str
        Append, overwrite or upsert.
    diststyle : str
        Redshift distribution styles. Must be in ["AUTO", "EVEN", "ALL", "KEY"].
        https://docs.aws.amazon.com/redshift/latest/dg/t_Distributing_data.html
    distkey : str, optional
        Specifies a column name or positional number for the distribution key.
    sortstyle : str
        Sorting can be "COMPOUND" or "INTERLEAVED".
        https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html
    sortkey : List[str], optional
        List of columns to be sorted.
    primary_keys : List[str], optional
        Primary keys.
    varchar_lengths_default : int
        The size that will be set for all VARCHAR columns not specified with varchar_lengths.
    varchar_lengths : Dict[str, int], optional
        Dict of VARCHAR length by columns. (e.g. {"col1": 10, "col5": 200}).
    use_threads : bool
        True to enable concurrent requests, False to disable multiple threads.
        If enabled os.cpu_count() will be used as the max number of threads.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receives None.
    s3_additional_kwargs:
        Forward to botocore requests. Valid parameters: "ACL", "Metadata", "ServerSideEncryption", "StorageClass",
        "SSECustomerAlgorithm", "SSECustomerKey", "SSEKMSKeyId", "SSEKMSEncryptionContext", "Tagging".
        e.g. s3_additional_kwargs={'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': 'YOUR_KMS_KEY_ARN'}

    Returns
    -------
    None
        None.

    Examples
    --------
    >>> import awswrangler as wr
    >>> con = wr.redshift.connect("MY_GLUE_CONNECTION")
    >>> wr.redshift.copy_from_files(
    ...     path="s3://bucket/my_parquet_files/",
    ...     con=con,
    ...     table="my_table",
    ...     schema="public"
    ...     iam_role="arn:aws:iam::XXX:role/XXX"
    ... )
    >>> con.close()

    """
    con.autocommit = False
    try:
        with con.cursor() as cursor:
            created_table, created_schema = _create_table(
                df=None,
                path=path,
                parquet_infer_sampling=parquet_infer_sampling,
                cursor=cursor,
                table=table,
                schema=schema,
                mode=mode,
                diststyle=diststyle,
                sortstyle=sortstyle,
                distkey=distkey,
                sortkey=sortkey,
                primary_keys=primary_keys,
                varchar_lengths_default=varchar_lengths_default,
                varchar_lengths=varchar_lengths,
                index=False,
                dtype=None,
                use_threads=use_threads,
                boto3_session=boto3_session,
                s3_additional_kwargs=s3_additional_kwargs,
            )
            _copy(
                cursor=cursor,
                path=path,
                table=created_table,
                schema=created_schema,
                iam_role=iam_role,
            )
            if table != created_table:  # upsert
                _upsert(cursor=cursor,
                        schema=schema,
                        table=table,
                        temp_table=created_table,
                        primary_keys=primary_keys)
            con.commit()
    except Exception as ex:
        con.rollback()
        _logger.error(ex)
        raise
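A sketch of the upsert path in copy_from_files: with mode="upsert" the data is staged into a temporary table (_create_table returns a name different from table) and then merged on the primary keys via _upsert(). The call below assumes the function is exposed as wr.redshift.copy_from_files, as in the unload example above; the path, ARN, and key names are placeholders:

import awswrangler as wr

con = wr.redshift.connect("MY_GLUE_CONNECTION")
wr.redshift.copy_from_files(
    path="s3://my-bucket/staged_parquet/",                     # placeholder prefix
    con=con,
    table="my_table",
    schema="public",
    iam_role="arn:aws:iam::111111111111:role/redshift-copy",   # placeholder ARN
    mode="upsert",
    primary_keys=["id"],
)
con.close()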
Example #19
def test_inspect_int(_input):
    input_val, expected_type = _input
    mock_connection = Connection.__new__(Connection)
    assert mock_connection.inspect_int(input_val) == expected_type
Example #20
def test_handle_COPY_DONE():
    mock_connection = Connection.__new__(Connection)
    assert hasattr(mock_connection, "_copy_done") is False
    mock_connection.handle_COPY_DONE(None, None)
    assert mock_connection._copy_done is True
Example #21
def test_handle_ERROR_RESPONSE(_input):
    server_msg, expected_decoded_msg, expected_error = _input
    mock_connection = Connection.__new__(Connection)
    mock_connection.handle_ERROR_RESPONSE(server_msg, None)
    assert type(mock_connection.error) == expected_error
    assert str(expected_decoded_msg) in str(mock_connection.error)
Example #22
def test_client_os_version_is_not_present():
    mock_connection: Connection = Connection.__new__(Connection)

    with patch("platform.platform", side_effect=Exception("not for you")):
        assert mock_connection.client_os_version == "unknown"