def test__get_catalog_filter_conditions_considers_args(_input, is_single_database_metadata_val):
    catalog, api_supported_only_for_connected_database, database_col_name = _input
    mock_cursor: Cursor = Cursor.__new__(Cursor)
    mock_connection: Connection = Connection.__new__(Connection)
    mock_cursor._c = mock_connection

    with patch(
        "redshift_connector.Connection.is_single_database_metadata", new_callable=PropertyMock()
    ) as mock_is_single_database_metadata:
        mock_is_single_database_metadata.__get__ = Mock(return_value=is_single_database_metadata_val)
        result: str = mock_cursor._get_catalog_filter_conditions(
            catalog, api_supported_only_for_connected_database, database_col_name
        )

    if catalog is not None:
        assert catalog in result
        if is_single_database_metadata_val or api_supported_only_for_connected_database:
            assert "current_database()" in result
            assert catalog in result
        elif database_col_name is None:
            assert "database_name" in result
        else:
            assert database_col_name in result
    else:
        assert result == ""

def test_get_schemas_considers_args(_input, is_single_database_metadata_val, mocker):
    catalog, schema_pattern = _input
    mocker.patch("redshift_connector.Cursor.execute", return_value=None)
    mocker.patch("redshift_connector.Cursor.fetchall", return_value=None)
    mock_cursor: Cursor = Cursor.__new__(Cursor)
    mock_cursor.paramstyle = "mocked"
    mock_connection: Connection = Connection.__new__(Connection)
    mock_cursor._c = mock_connection
    spy = mocker.spy(mock_cursor, "execute")

    with patch(
        "redshift_connector.Connection.is_single_database_metadata", new_callable=PropertyMock()
    ) as mock_is_single_database_metadata:
        mock_is_single_database_metadata.__get__ = Mock(return_value=is_single_database_metadata_val)
        mock_cursor.get_schemas(catalog, schema_pattern)

    assert spy.called
    assert spy.call_count == 1
    if schema_pattern is not None:  # should be in parameterized portion
        assert schema_pattern in spy.call_args[0][1]
    if catalog is not None:
        assert catalog in spy.call_args[0][0]

def test_get_catalogs_considers_args(is_single_database_metadata_val, mocker):
    mocker.patch("redshift_connector.Cursor.execute", return_value=None)
    mocker.patch("redshift_connector.Cursor.fetchall", return_value=None)
    mock_cursor: Cursor = Cursor.__new__(Cursor)
    mock_cursor.paramstyle = "mocked"
    mock_connection: Connection = Connection.__new__(Connection)
    mock_cursor._c = mock_connection
    spy = mocker.spy(mock_cursor, "execute")

    with patch(
        "redshift_connector.Connection.is_single_database_metadata", new_callable=PropertyMock()
    ) as mock_is_single_database_metadata:
        mock_is_single_database_metadata.__get__ = Mock(return_value=is_single_database_metadata_val)
        mock_cursor.get_catalogs()

    assert spy.called
    assert spy.call_count == 1
    if is_single_database_metadata_val:
        assert "select current_database as TABLE_CAT FROM current_database()" in spy.call_args[0][0]
    else:
        assert (
            "SELECT CAST(database_name AS varchar(124)) AS TABLE_CAT FROM PG_CATALOG.SVV_REDSHIFT_DATABASES "
            in spy.call_args[0][0]
        )

def test_raw_connection_property_warns():
    mock_cursor: Cursor = Cursor.__new__(Cursor)
    mock_cursor._c = Connection.__new__(Connection)

    with pytest.warns(UserWarning, match="DB-API extension cursor.connection used"):
        mock_cursor.connection

def test_handle_ROW_DESCRIPTION_missing_ps_raises():
    mock_connection = Connection.__new__(Connection)
    mock_cursor = Cursor.__new__(Cursor)
    mock_cursor.ps = None

    with pytest.raises(InterfaceError, match="Cursor is missing prepared statement"):
        mock_connection.handle_ROW_DESCRIPTION(b"\x00", mock_cursor)

def test_handle_ROW_DESCRIPTION_missing_row_desc_raises():
    mock_connection = Connection.__new__(Connection)
    mock_cursor = Cursor.__new__(Cursor)
    mock_cursor.ps = {}

    with pytest.raises(InterfaceError, match="Prepared Statement is missing row description"):
        mock_connection.handle_ROW_DESCRIPTION(b"\x00", mock_cursor)

def test_is_multidatabases_catalog_enable_in_server(_input):
    param_status, exp_val = _input
    mock_connection = Connection.__new__(Connection)
    mock_connection.parameter_statuses: deque = deque()

    if param_status is not None:
        mock_connection.parameter_statuses.append((b"datashare_enabled", param_status.encode()))

    assert mock_connection._is_multi_databases_catalog_enable_in_server == exp_val

def db_table(request, con: redshift_connector.Connection) -> redshift_connector.Connection:
    filterwarnings("ignore", "DB-API extension cursor.next()")
    filterwarnings("ignore", "DB-API extension cursor.__iter__()")
    con.paramstyle = "format"  # type: ignore
    with con.cursor() as cursor:
        cursor.execute("drop table if exists book")
        cursor.execute("create Temp table book(bookname varchar,author varchar)")

    def fin() -> None:
        try:
            with con.cursor() as cursor:
                cursor.execute("drop table if exists book")
        except redshift_connector.ProgrammingError:
            pass

    request.addfinalizer(fin)
    return con

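# Illustrative sketch (not from the source): a test that consumes the db_table fixture above.
# It relies on the "format" paramstyle set by the fixture, so placeholders are %s. The test
# name and the inserted row values are hypothetical.
def test_insert_into_book(db_table: redshift_connector.Connection) -> None:
    with db_table.cursor() as cursor:
        cursor.execute(
            "insert into book (bookname, author) values (%s, %s)",
            ("A Clash of Kings", "George R. R. Martin"),
        )
        cursor.execute("select count(*) from book")
        assert cursor.fetchone()[0] == 1
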
def db_table(request, con: redshift_connector.Connection) -> redshift_connector.Connection:
    filterwarnings("ignore", "DB-API extension cursor.next()")
    filterwarnings("ignore", "DB-API extension cursor.__iter__()")
    con.paramstyle = "format"  # type: ignore
    with con.cursor() as cursor:
        cursor.execute("DROP TABLE IF EXISTS t1")
        cursor.execute(
            "CREATE TEMPORARY TABLE t1 (f1 int primary key, f2 bigint not null, f3 varchar(50) null) "
        )

    def fin() -> None:
        try:
            with con.cursor() as cursor:
                cursor.execute("drop table if exists t1")
        except redshift_connector.ProgrammingError:
            pass

    request.addfinalizer(fin)
    return con

def test_is_single_database_metadata(_input):
    param_status, database_metadata_current_db_only_val, exp_val = _input
    mock_connection = Connection.__new__(Connection)
    mock_connection.parameter_statuses: deque = deque()
    mock_connection._database_metadata_current_db_only = database_metadata_current_db_only_val

    if param_status is not None:
        mock_connection.parameter_statuses.append((b"datashare_enabled", param_status.encode()))

    assert mock_connection.is_single_database_metadata == exp_val

def test_handle_ROW_DESCRIPTION_base(_input):
    data, exp_result = _input
    mock_connection = Connection.__new__(Connection)
    mock_connection._client_protocol_version = ClientProtocolVersion.BASE_SERVER.value
    mock_cursor = Cursor.__new__(Cursor)
    mock_cursor.ps = {"row_desc": []}

    mock_connection.handle_ROW_DESCRIPTION(data, mock_cursor)

    assert mock_cursor.ps is not None
    assert "row_desc" in mock_cursor.ps
    assert len(mock_cursor.ps["row_desc"]) == 1
    assert exp_result[0].items() <= mock_cursor.ps["row_desc"][0].items()
    assert "func" in mock_cursor.ps["row_desc"][0]

def test_handle_ROW_DESCRIPTION_extended_metadata(_input, protocol):
    data, exp_result = _input
    mock_connection = Connection.__new__(Connection)
    mock_connection._client_protocol_version = protocol
    mock_cursor = Cursor.__new__(Cursor)
    mock_cursor.ps = {"row_desc": []}

    mock_connection.handle_ROW_DESCRIPTION(data, mock_cursor)

    assert mock_cursor.ps is not None
    assert "row_desc" in mock_cursor.ps
    assert len(mock_cursor.ps["row_desc"]) == 1
    assert exp_result[0].items() <= mock_cursor.ps["row_desc"][0].items()
    assert "func" in mock_cursor.ps["row_desc"][0]

def test_get_procedures_considers_args(_input, mocker):
    catalog, schema_pattern, procedure_name_pattern = _input
    mocker.patch("redshift_connector.Cursor.execute", return_value=None)
    mocker.patch("redshift_connector.Cursor.fetchall", return_value=None)
    mocker.patch("redshift_connector.Connection.is_single_database_metadata", return_value=True)
    mock_cursor: Cursor = Cursor.__new__(Cursor)
    mock_connection: Connection = Connection.__new__(Connection)
    mock_cursor._c = mock_connection
    mock_cursor.paramstyle = "mocked_val"
    spy = mocker.spy(mock_cursor, "execute")

    mock_cursor.get_procedures(catalog, schema_pattern, procedure_name_pattern)

    assert spy.called
    assert spy.call_count == 1
    assert catalog not in spy.call_args[0][1]
    for arg in (schema_pattern, procedure_name_pattern):
        if arg is not None:
            assert arg in spy.call_args[0][1]

def test_get_tables_considers_args(is_single_database_metadata_val, _input, schema_pattern_type, mocker):
    catalog, schema_pattern, table_name_pattern = _input
    mocker.patch("redshift_connector.Cursor.execute", return_value=None)
    # Mock the return value of __schema_pattern_match, as its return value is used in get_tables().
    # The other potential call to this method in get_tables() simply has its result returned,
    # so at this time it has no impact.
    mocker.patch(
        "redshift_connector.Cursor.fetchall",
        return_value=None if schema_pattern_type == "EXTERNAL_SCHEMA_QUERY" else tuple("mock"),
    )
    mock_cursor: Cursor = Cursor.__new__(Cursor)
    mock_cursor.paramstyle = "mocked"
    mock_connection: Connection = Connection.__new__(Connection)
    mock_cursor._c = mock_connection
    spy = mocker.spy(mock_cursor, "execute")

    with patch(
        "redshift_connector.Connection.is_single_database_metadata", new_callable=PropertyMock()
    ) as mock_is_single_database_metadata:
        mock_is_single_database_metadata.__get__ = Mock(return_value=is_single_database_metadata_val)
        mock_cursor.get_tables(catalog, schema_pattern, table_name_pattern)

    assert spy.called
    if schema_pattern is not None and is_single_database_metadata_val:
        assert spy.call_count == 2  # call in __schema_pattern_match(), get_tables()
    else:
        assert spy.call_count == 1
    if catalog is not None:
        assert catalog in spy.call_args[0][0]
    for arg in (schema_pattern, table_name_pattern):
        if arg is not None:
            assert arg in spy.call_args[0][1]

def test_client_os_version_is_present():
    mock_connection: Connection = Connection.__new__(Connection)
    assert mock_connection.client_os_version is not None
    assert isinstance(mock_connection.client_os_version, str)

def to_sql(
    df: pd.DataFrame,
    con: redshift_connector.Connection,
    table: str,
    schema: str,
    mode: str = "append",
    index: bool = False,
    dtype: Optional[Dict[str, str]] = None,
    diststyle: str = "AUTO",
    distkey: Optional[str] = None,
    sortstyle: str = "COMPOUND",
    sortkey: Optional[List[str]] = None,
    primary_keys: Optional[List[str]] = None,
    varchar_lengths_default: int = 256,
    varchar_lengths: Optional[Dict[str, int]] = None,
) -> None:
    """Write records stored in a DataFrame into Redshift.

    Note
    ----
    For large DataFrames (1K+ rows) consider the function **wr.redshift.copy()**.

    Parameters
    ----------
    df : pandas.DataFrame
        Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html
    con : redshift_connector.Connection
        Use redshift_connector.connect() to use credentials directly or
        wr.redshift.connect() to fetch it from the Glue Catalog.
    table : str
        Table name
    schema : str
        Schema name
    mode : str
        Append, overwrite or upsert.
    index : bool
        True to store the DataFrame index as a column in the table,
        otherwise False to ignore it.
    dtype : Dict[str, str], optional
        Dictionary of column names and Redshift types to be cast.
        Useful when you have columns with undetermined or mixed data types.
        (e.g. {'col name': 'VARCHAR(10)', 'col2 name': 'FLOAT'})
    diststyle : str
        Redshift distribution styles. Must be in ["AUTO", "EVEN", "ALL", "KEY"].
        https://docs.aws.amazon.com/redshift/latest/dg/t_Distributing_data.html
    distkey : str, optional
        Specifies a column name or positional number for the distribution key.
    sortstyle : str
        Sorting can be "COMPOUND" or "INTERLEAVED".
        https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html
    sortkey : List[str], optional
        List of columns to be sorted.
    primary_keys : List[str], optional
        Primary keys.
    varchar_lengths_default : int
        The size that will be set for all VARCHAR columns not specified with varchar_lengths.
    varchar_lengths : Dict[str, int], optional
        Dict of VARCHAR length by columns. (e.g. {"col1": 10, "col5": 200}).

    Returns
    -------
    None
        None.

    Examples
    --------
    Writing to Redshift using a Glue Catalog Connection

    >>> import awswrangler as wr
    >>> con = wr.redshift.connect("MY_GLUE_CONNECTION")
    >>> wr.redshift.to_sql(
    ...     df=df,
    ...     table="my_table",
    ...     schema="public",
    ...     con=con
    ... )
    >>> con.close()

    """
    if df.empty is True:
        raise exceptions.EmptyDataFrame()
    _validate_connection(con=con)
    con.autocommit = False
    try:
        with con.cursor() as cursor:
            created_table, created_schema = _create_table(
                df=df,
                path=None,
                cursor=cursor,
                table=table,
                schema=schema,
                mode=mode,
                index=index,
                dtype=dtype,
                diststyle=diststyle,
                sortstyle=sortstyle,
                distkey=distkey,
                sortkey=sortkey,
                primary_keys=primary_keys,
                varchar_lengths_default=varchar_lengths_default,
                varchar_lengths=varchar_lengths,
            )
            if index:
                df.reset_index(level=df.index.names, inplace=True)
            placeholders: str = ", ".join(["%s"] * len(df.columns))
            schema_str = f"{created_schema}." if created_schema else ""
            sql: str = f"INSERT INTO {schema_str}{created_table} VALUES ({placeholders})"
            _logger.debug("sql: %s", sql)
            parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
            cursor.executemany(sql, parameters)
            if table != created_table:  # upsert
                _upsert(
                    cursor=cursor, schema=schema, table=table, temp_table=created_table, primary_keys=primary_keys
                )
            con.commit()
    except Exception as ex:
        con.rollback()
        _logger.error(ex)
        raise

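# Illustrative sketch (not from the source): calling to_sql() with mode="upsert", which takes the
# branch above where the rows land in a staging table and _upsert() merges them into the target
# on the given primary_keys. The Glue connection name, table, and column names are assumptions.
import awswrangler as wr
import pandas as pd

con = wr.redshift.connect("MY_GLUE_CONNECTION")
wr.redshift.to_sql(
    df=pd.DataFrame({"id": [1, 2], "name": ["foo", "bar"]}),
    table="my_table",
    schema="public",
    con=con,
    mode="upsert",
    primary_keys=["id"],
)
con.close()
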
def unload_to_files(
    sql: str,
    path: str,
    con: redshift_connector.Connection,
    iam_role: str,
    region: Optional[str] = None,
    max_file_size: Optional[float] = None,
    kms_key_id: Optional[str] = None,
    manifest: bool = False,
    use_threads: bool = True,
    partition_cols: Optional[List[str]] = None,
    boto3_session: Optional[boto3.Session] = None,
) -> None:
    """Unload Parquet files on S3 from a Redshift query result (through the UNLOAD command).

    https://docs.aws.amazon.com/redshift/latest/dg/r_UNLOAD.html

    Note
    ----
    In case of `use_threads=True` the number of threads that will be spawned is obtained from os.cpu_count().

    Parameters
    ----------
    sql : str
        SQL query.
    path : Union[str, List[str]]
        S3 path to write stage files (e.g. s3://bucket_name/any_name/)
    con : redshift_connector.Connection
        Use redshift_connector.connect() to use credentials directly or
        wr.redshift.connect() to fetch it from the Glue Catalog.
    iam_role : str
        AWS IAM role with the related permissions.
    region : str, optional
        Specifies the AWS Region where the target Amazon S3 bucket is located.
        REGION is required for UNLOAD to an Amazon S3 bucket that isn't in the
        same AWS Region as the Amazon Redshift cluster. By default, UNLOAD
        assumes that the target Amazon S3 bucket is located in the same AWS
        Region as the Amazon Redshift cluster.
    max_file_size : float, optional
        Specifies the maximum size (MB) of files that UNLOAD creates in Amazon S3.
        Specify a decimal value between 5.0 MB and 6200.0 MB. If None, the default
        maximum file size is 6200.0 MB.
    kms_key_id : str, optional
        Specifies the key ID for an AWS Key Management Service (AWS KMS) key to be
        used to encrypt data files on Amazon S3.
    use_threads : bool
        True to enable concurrent requests, False to disable multiple threads.
        If enabled os.cpu_count() will be used as the max number of threads.
    manifest : bool
        Unload a manifest file on S3.
    partition_cols : List[str], optional
        Specifies the partition keys for the unload operation.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receives None.

    Returns
    -------
    None

    Examples
    --------
    >>> import awswrangler as wr
    >>> con = wr.redshift.connect("MY_GLUE_CONNECTION")
    >>> wr.redshift.unload_to_files(
    ...     sql="SELECT * FROM public.mytable",
    ...     path="s3://bucket/extracted_parquet_files/",
    ...     con=con,
    ...     iam_role="arn:aws:iam::XXX:role/XXX"
    ... )
    >>> con.close()

    """
    path = path if path.endswith("/") else f"{path}/"
    session: boto3.Session = _utils.ensure_session(session=boto3_session)
    s3.delete_objects(path=path, use_threads=use_threads, boto3_session=session)
    with con.cursor() as cursor:
        partition_str: str = f"\nPARTITION BY ({','.join(partition_cols)})" if partition_cols else ""
        manifest_str: str = "\nmanifest" if manifest is True else ""
        region_str: str = f"\nREGION AS '{region}'" if region is not None else ""
        max_file_size_str: str = f"\nMAXFILESIZE AS {max_file_size} MB" if max_file_size is not None else ""
        kms_key_id_str: str = f"\nKMS_KEY_ID '{kms_key_id}'" if kms_key_id is not None else ""
        sql = (
            f"UNLOAD ('{sql}')\n"
            f"TO '{path}'\n"
            f"IAM_ROLE '{iam_role}'\n"
            "ALLOWOVERWRITE\n"
            "PARALLEL ON\n"
            "FORMAT PARQUET\n"
            "ENCRYPTED"
            f"{kms_key_id_str}"
            f"{partition_str}"
            f"{region_str}"
            f"{max_file_size_str}"
            f"{manifest_str};"
        )
        _logger.debug("sql: \n%s", sql)
        cursor.execute(sql)

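# Illustrative sketch (not from the source): a partitioned, size-capped unload through
# unload_to_files(). With these arguments, the function above emits an UNLOAD statement that
# includes PARTITION BY (year), REGION AS 'us-east-1' and MAXFILESIZE AS 512.0 MB. The query,
# bucket, role ARN, and partition column are assumptions.
import awswrangler as wr

con = wr.redshift.connect("MY_GLUE_CONNECTION")
wr.redshift.unload_to_files(
    sql="SELECT * FROM public.mytable",
    path="s3://bucket/extracted_parquet_files/",
    con=con,
    iam_role="arn:aws:iam::XXX:role/XXX",
    partition_cols=["year"],
    region="us-east-1",
    max_file_size=512.0,
)
con.close()
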
def copy_from_files(  # pylint: disable=too-many-locals,too-many-arguments
    path: str,
    con: redshift_connector.Connection,
    table: str,
    schema: str,
    iam_role: str,
    parquet_infer_sampling: float = 1.0,
    mode: str = "append",
    diststyle: str = "AUTO",
    distkey: Optional[str] = None,
    sortstyle: str = "COMPOUND",
    sortkey: Optional[List[str]] = None,
    primary_keys: Optional[List[str]] = None,
    varchar_lengths_default: int = 256,
    varchar_lengths: Optional[Dict[str, int]] = None,
    use_threads: bool = True,
    boto3_session: Optional[boto3.Session] = None,
    s3_additional_kwargs: Optional[Dict[str, str]] = None,
) -> None:
    """Load Parquet files from S3 to a Table on Amazon Redshift (through the COPY command).

    https://docs.aws.amazon.com/redshift/latest/dg/r_COPY.html

    Note
    ----
    If the table does not exist yet, it will be automatically created for you
    using the Parquet metadata to infer the columns data types.

    Note
    ----
    In case of `use_threads=True` the number of threads that will be spawned is obtained from os.cpu_count().

    Parameters
    ----------
    path : str
        S3 prefix (e.g. s3://bucket/prefix/)
    con : redshift_connector.Connection
        Use redshift_connector.connect() to use credentials directly or
        wr.redshift.connect() to fetch it from the Glue Catalog.
    table : str
        Table name
    schema : str
        Schema name
    iam_role : str
        AWS IAM role with the related permissions.
    parquet_infer_sampling : float
        Random sample ratio of files that will have the metadata inspected.
        Must be `0.0 < sampling <= 1.0`.
        The higher, the more accurate.
        The lower, the faster.
    mode : str
        Append, overwrite or upsert.
    diststyle : str
        Redshift distribution styles. Must be in ["AUTO", "EVEN", "ALL", "KEY"].
        https://docs.aws.amazon.com/redshift/latest/dg/t_Distributing_data.html
    distkey : str, optional
        Specifies a column name or positional number for the distribution key.
    sortstyle : str
        Sorting can be "COMPOUND" or "INTERLEAVED".
        https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html
    sortkey : List[str], optional
        List of columns to be sorted.
    primary_keys : List[str], optional
        Primary keys.
    varchar_lengths_default : int
        The size that will be set for all VARCHAR columns not specified with varchar_lengths.
    varchar_lengths : Dict[str, int], optional
        Dict of VARCHAR length by columns. (e.g. {"col1": 10, "col5": 200}).
    use_threads : bool
        True to enable concurrent requests, False to disable multiple threads.
        If enabled os.cpu_count() will be used as the max number of threads.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receives None.
    s3_additional_kwargs : Dict[str, str], optional
        Forward to botocore requests. Valid parameters: "ACL", "Metadata",
        "ServerSideEncryption", "StorageClass", "SSECustomerAlgorithm",
        "SSECustomerKey", "SSEKMSKeyId", "SSEKMSEncryptionContext", "Tagging".
        e.g. s3_additional_kwargs={'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': 'YOUR_KMS_KEY_ARN'}

    Returns
    -------
    None
        None.

    Examples
    --------
    >>> import awswrangler as wr
    >>> con = wr.redshift.connect("MY_GLUE_CONNECTION")
    >>> wr.redshift.copy_from_files(
    ...     path="s3://bucket/my_parquet_files/",
    ...     con=con,
    ...     table="my_table",
    ...     schema="public",
    ...     iam_role="arn:aws:iam::XXX:role/XXX"
    ... )
    >>> con.close()

    """
    con.autocommit = False
    try:
        with con.cursor() as cursor:
            created_table, created_schema = _create_table(
                df=None,
                path=path,
                parquet_infer_sampling=parquet_infer_sampling,
                cursor=cursor,
                table=table,
                schema=schema,
                mode=mode,
                diststyle=diststyle,
                sortstyle=sortstyle,
                distkey=distkey,
                sortkey=sortkey,
                primary_keys=primary_keys,
                varchar_lengths_default=varchar_lengths_default,
                varchar_lengths=varchar_lengths,
                index=False,
                dtype=None,
                use_threads=use_threads,
                boto3_session=boto3_session,
                s3_additional_kwargs=s3_additional_kwargs,
            )
            _copy(
                cursor=cursor,
                path=path,
                table=created_table,
                schema=created_schema,
                iam_role=iam_role,
            )
            if table != created_table:  # upsert
                _upsert(
                    cursor=cursor, schema=schema, table=table, temp_table=created_table, primary_keys=primary_keys
                )
            con.commit()
    except Exception as ex:
        con.rollback()
        _logger.error(ex)
        raise

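# Illustrative sketch (not from the source): copy_from_files() with mode="upsert" loads the Parquet
# prefix into a staging table via COPY, then the branch above calls _upsert() to merge it into the
# target on the given primary_keys. The prefix, table, role ARN, and key column are assumptions.
import awswrangler as wr

con = wr.redshift.connect("MY_GLUE_CONNECTION")
wr.redshift.copy_from_files(
    path="s3://bucket/my_parquet_files/",
    con=con,
    table="my_table",
    schema="public",
    iam_role="arn:aws:iam::XXX:role/XXX",
    mode="upsert",
    primary_keys=["id"],
)
con.close()
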
def test_inspect_int(_input):
    input_val, expected_type = _input
    mock_connection = Connection.__new__(Connection)
    assert mock_connection.inspect_int(input_val) == expected_type

def test_handle_COPY_DONE():
    mock_connection = Connection.__new__(Connection)
    assert hasattr(mock_connection, "_copy_done") is False

    mock_connection.handle_COPY_DONE(None, None)

    assert mock_connection._copy_done is True

def test_handle_ERROR_RESPONSE(_input):
    server_msg, expected_decoded_msg, expected_error = _input
    mock_connection = Connection.__new__(Connection)

    mock_connection.handle_ERROR_RESPONSE(server_msg, None)

    assert type(mock_connection.error) == expected_error
    assert str(expected_decoded_msg) in str(mock_connection.error)

def test_client_os_version_is_not_present():
    mock_connection: Connection = Connection.__new__(Connection)

    with patch("platform.platform", side_effect=Exception("not for you")):
        assert mock_connection.client_os_version == "unknown"