def _does_table_exist(cursor: redshift_connector.Cursor, schema: Optional[str], table: str) -> bool: schema_str = f"TABLE_SCHEMA = '{schema}' AND" if schema else "" cursor.execute(f"SELECT true WHERE EXISTS (" f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE " f"{schema_str} TABLE_NAME = '{table}'" f");") return len(cursor.fetchall()) > 0
def _get_primary_keys(cursor: redshift_connector.Cursor, schema: str, table: str) -> List[str]: cursor.execute( f"SELECT indexdef FROM pg_indexes WHERE schemaname = '{schema}' AND tablename = '{table}'" ) result: str = cursor.fetchall()[0][0] rfields: List[str] = result.split("(")[1].strip(")").split(",") fields: List[str] = [field.strip().strip('"') for field in rfields] return fields
def _copy(
    cursor: redshift_connector.Cursor,
    path: str,
    table: str,
    iam_role: str,
    schema: Optional[str] = None,
) -> None:
    """Issue a Redshift COPY that loads the Parquet data at *path* into *table*.

    The table name is schema-qualified only when *schema* is provided; the
    given IAM role is used for S3 access.
    """
    table_name: str = table if schema is None else f"{schema}.{table}"
    sql: str = f"COPY {table_name} FROM '{path}'\nIAM_ROLE '{iam_role}'\nFORMAT AS PARQUET"
    _logger.debug("copy query:\n%s", sql)
    cursor.execute(sql)
def test_insert_data_column_stmt(mocked_csv, indexes, names, exp_execute_args, mocker):
    """insert_data_bulk builds the expected INSERT statement and bind params
    for the given parameter_indices / column_names combination."""
    # mock fetchone to return "True" to ensure the table_name and column_name
    # validation steps pass
    mocker.patch("redshift_connector.Cursor.fetchone", return_value=[1])
    mock_cursor: Cursor = Cursor.__new__(Cursor)
    # spy on the execute method, so we can check value of sql_query
    spy = mocker.spy(mock_cursor, "execute")
    # mock out the connection
    mock_cursor._c = Mock()
    mock_cursor.paramstyle = "qmark"
    # three data rows; batch_size=3 below means one INSERT carries them all
    mocked_csv.side_effect = [
        StringIO("""\col1,col2,col3\n1,3,foo\n2,5,bar\n-1,7,baz""")
    ]
    mock_cursor.insert_data_bulk(
        filename="mocked_csv",
        table_name="test_table",
        parameter_indices=indexes,
        column_names=names,
        delimiter=",",
        batch_size=3,
    )
    # the last execute() call must match the parametrized expectation:
    # [0] is the SQL text, [1] the bound parameter list
    assert spy.called is True
    assert spy.call_args[0][0] == exp_execute_args[0]
    assert spy.call_args[0][1] == exp_execute_args[1]
def test_insert_data_uses_batch_size(mocked_csv, batch_size, mocker):
    """insert_data_bulk must split 3 rows into ceil(3 / batch_size) INSERTs."""
    # make fetchone report success so table/column validation passes
    mocker.patch("redshift_connector.Cursor.fetchone", return_value=[1])
    mock_cursor: Cursor = Cursor.__new__(Cursor)
    # record every execute() call so the INSERT statements can be counted
    spy = mocker.spy(mock_cursor, "execute")
    mock_cursor._c = Mock()  # stand-in for the connection
    mock_cursor.paramstyle = "qmark"
    mocked_csv.side_effect = [
        StringIO("""\col1,col2,col3\n1,3,foo\n2,5,bar\n-1,7,baz""")
    ]
    mock_cursor.insert_data_bulk(
        filename="mocked_csv",
        table_name="test_table",
        parameter_indices=[0, 1, 2],
        column_names=["col1", "col2", "col3"],
        delimiter=",",
        batch_size=batch_size,
    )
    assert spy.called is True
    # count only the calls that executed an INSERT with bound parameters
    actual_insert_stmts_executed = sum(
        1
        for call in spy.mock_calls
        if len(call[1]) == 2 and "INSERT INTO" in call[1][0]
    )
    assert actual_insert_stmts_executed == ceil(3 / batch_size)
def test_get_schemas_considers_args(_input, is_single_database_metadata_val, mocker):
    """get_schemas() must place catalog in the SQL text and schema_pattern in
    the bind parameters, for either value of is_single_database_metadata."""
    catalog, schema_pattern = _input
    # neutralize real query execution; we only inspect the SQL handed to execute()
    mocker.patch("redshift_connector.Cursor.execute", return_value=None)
    mocker.patch("redshift_connector.Cursor.fetchall", return_value=None)
    mock_cursor: Cursor = Cursor.__new__(Cursor)
    mock_cursor.paramstyle = "mocked"
    mock_connection: Connection = Connection.__new__(Connection)
    mock_cursor._c = mock_connection
    spy = mocker.spy(mock_cursor, "execute")
    # is_single_database_metadata is a property, so its __get__ is mocked
    with patch(
            "redshift_connector.Connection.is_single_database_metadata",
            new_callable=PropertyMock()) as mock_is_single_database_metadata:
        mock_is_single_database_metadata.__get__ = Mock(
            return_value=is_single_database_metadata_val)
        mock_cursor.get_schemas(catalog, schema_pattern)
    assert spy.called
    assert spy.call_count == 1
    if schema_pattern is not None:
        # should be in parameterized portion
        assert schema_pattern in spy.call_args[0][1]
    if catalog is not None:
        # catalog is embedded in the SQL text itself
        assert catalog in spy.call_args[0][0]
def test__get_catalog_filter_conditions_considers_args(
        _input, is_single_database_metadata_val):
    """_get_catalog_filter_conditions must reference the right database column
    (current_database(), database_name, or a caller-supplied column) and embed
    the catalog value; with no catalog it returns an empty string."""
    catalog, api_supported_only_for_connected_database, database_col_name = _input
    mock_cursor: Cursor = Cursor.__new__(Cursor)
    mock_connection: Connection = Connection.__new__(Connection)
    mock_cursor._c = mock_connection
    # is_single_database_metadata is a property, so its __get__ is mocked
    with patch(
            "redshift_connector.Connection.is_single_database_metadata",
            new_callable=PropertyMock()) as mock_is_single_database_metadata:
        mock_is_single_database_metadata.__get__ = Mock(
            return_value=is_single_database_metadata_val)
        result: str = mock_cursor._get_catalog_filter_conditions(
            catalog, api_supported_only_for_connected_database, database_col_name)
    if catalog is not None:
        assert catalog in result
        if is_single_database_metadata_val or api_supported_only_for_connected_database:
            # restricted to the connected database
            assert "current_database()" in result
            assert catalog in result
        elif database_col_name is None:
            # default cross-database column name
            assert "database_name" in result
        else:
            assert database_col_name in result
    else:
        # no catalog -> no filter clause at all
        assert result == ""
def test_raw_connection_property_warns():
    """Reading Cursor.connection (a DB-API extension) must emit a UserWarning."""
    cursor: Cursor = Cursor.__new__(Cursor)
    cursor._c = Connection.__new__(Connection)
    with pytest.warns(UserWarning, match="DB-API extension cursor.connection used"):
        cursor.connection
def test_get_catalogs_considers_args(is_single_database_metadata_val, mocker):
    """get_catalogs() issues the current-database query in single-database
    mode, otherwise the SVV_REDSHIFT_DATABASES query."""
    # neutralize real query execution; we only inspect the SQL handed to execute()
    mocker.patch("redshift_connector.Cursor.execute", return_value=None)
    mocker.patch("redshift_connector.Cursor.fetchall", return_value=None)
    mock_cursor: Cursor = Cursor.__new__(Cursor)
    mock_cursor.paramstyle = "mocked"
    mock_connection: Connection = Connection.__new__(Connection)
    mock_cursor._c = mock_connection
    spy = mocker.spy(mock_cursor, "execute")
    # is_single_database_metadata is a property, so its __get__ is mocked
    with patch(
            "redshift_connector.Connection.is_single_database_metadata",
            new_callable=PropertyMock()) as mock_is_single_database_metadata:
        mock_is_single_database_metadata.__get__ = Mock(
            return_value=is_single_database_metadata_val)
        mock_cursor.get_catalogs()
    assert spy.called
    assert spy.call_count == 1
    if is_single_database_metadata_val:
        assert "select current_database as TABLE_CAT FROM current_database()" in spy.call_args[0][0]
    else:
        assert (
            "SELECT CAST(database_name AS varchar(124)) AS TABLE_CAT FROM PG_CATALOG.SVV_REDSHIFT_DATABASES "
            in spy.call_args[0][0])
def test_handle_ROW_DESCRIPTION_missing_ps_raises():
    """handle_ROW_DESCRIPTION must fail fast when the cursor has no prepared statement."""
    connection = Connection.__new__(Connection)
    cursor = Cursor.__new__(Cursor)
    cursor.ps = None  # no prepared statement bound to this cursor
    with pytest.raises(InterfaceError, match="Cursor is missing prepared statement"):
        connection.handle_ROW_DESCRIPTION(b"\x00", cursor)
def test_fetch_dataframe_warns_user(_input, mocker):
    """fetch_dataframe should emit the expected UserWarning for the given
    column description."""
    data, exp_warning_msg = _input
    cursor: Cursor = Cursor.__new__(Cursor)
    # feed a single-column description and a single dummy row
    mocker.patch("redshift_connector.Cursor._getDescription", return_value=[data])
    mocker.patch("redshift_connector.Cursor.__next__", return_value=["blah"])
    with pytest.warns(UserWarning, match=exp_warning_msg):
        cursor.fetch_dataframe(1)
def test_handle_ROW_DESCRIPTION_missing_row_desc_raises():
    """handle_ROW_DESCRIPTION must reject a prepared statement lacking row_desc."""
    connection = Connection.__new__(Connection)
    cursor = Cursor.__new__(Cursor)
    cursor.ps = {}  # prepared statement present but without "row_desc"
    with pytest.raises(InterfaceError,
                       match="Prepared Statement is missing row description"):
        connection.handle_ROW_DESCRIPTION(b"\x00", cursor)
def test_fetch_dataframe_no_results(mocker):
    """fetch_dataframe returns None when iteration stops before any row."""
    cursor: Cursor = Cursor.__new__(Cursor)
    mocker.patch("redshift_connector.Cursor._getDescription", return_value=["test"])
    # __next__ raising immediately simulates an exhausted result set
    mocker.patch(
        "redshift_connector.Cursor.__next__",
        side_effect=StopIteration("mocked end"),
    )
    assert cursor.fetch_dataframe(1) is None
def test_handle_ROW_DESCRIPTION_extended_metadata(_input, protocol):
    """Under extended-metadata protocols, parsing *data* must append exactly one
    column description containing the expected fields plus a "func" entry."""
    data, exp_result = _input
    connection = Connection.__new__(Connection)
    connection._client_protocol_version = protocol
    cursor = Cursor.__new__(Cursor)
    cursor.ps = {"row_desc": []}

    connection.handle_ROW_DESCRIPTION(data, cursor)

    assert cursor.ps is not None
    assert "row_desc" in cursor.ps
    parsed = cursor.ps["row_desc"]
    assert len(parsed) == 1
    # every expected key/value pair must be present in the parsed description
    assert exp_result[0].items() <= parsed[0].items()
    assert "func" in parsed[0]
def test_handle_ROW_DESCRIPTION_base(_input):
    """With the BASE_SERVER protocol, parsing *data* must append exactly one
    column description containing the expected fields plus a "func" entry."""
    data, exp_result = _input
    connection = Connection.__new__(Connection)
    connection._client_protocol_version = ClientProtocolVersion.BASE_SERVER.value
    cursor = Cursor.__new__(Cursor)
    cursor.ps = {"row_desc": []}

    connection.handle_ROW_DESCRIPTION(data, cursor)

    assert cursor.ps is not None
    assert "row_desc" in cursor.ps
    parsed = cursor.ps["row_desc"]
    assert len(parsed) == 1
    # every expected key/value pair must be present in the parsed description
    assert exp_result[0].items() <= parsed[0].items()
    assert "func" in parsed[0]
def _upsert(
    cursor: redshift_connector.Cursor,
    table: str,
    temp_table: str,
    schema: str,
    primary_keys: Optional[List[str]] = None,
) -> None:
    """Merge *temp_table* into *schema*.*table*.

    Rows in the target that match the staging rows on every primary key are
    deleted, all staging rows are inserted, and the staging table is dropped.

    Raises
    ------
    exceptions.InvalidRedshiftPrimaryKeys
        When no primary keys are supplied and none can be discovered.
    """
    if not primary_keys:
        primary_keys = _get_primary_keys(cursor=cursor, schema=schema, table=table)
    _logger.debug("primary_keys: %s", primary_keys)
    if not primary_keys:
        raise exceptions.InvalidRedshiftPrimaryKeys()
    # Build "target.pk = staging.pk AND ..." across all primary keys.
    equals_clause: str = f"{table}.%s = {temp_table}.%s"
    join_clause: str = " AND ".join(equals_clause % (pk, pk) for pk in primary_keys)
    delete_sql: str = f"DELETE FROM {schema}.{table} USING {temp_table} WHERE {join_clause}"
    _logger.debug(delete_sql)
    cursor.execute(delete_sql)
    insert_sql: str = f"INSERT INTO {schema}.{table} SELECT * FROM {temp_table}"
    _logger.debug(insert_sql)
    cursor.execute(insert_sql)
    _drop_table(cursor=cursor, schema=schema, table=temp_table)
def test_get_procedures_considers_args(_input, mocker):
    """get_procedures() must bind schema/procedure patterns as parameters and
    must NOT bind the catalog value."""
    catalog, schema_pattern, procedure_name_pattern = _input
    # neutralize real query execution; we only inspect the SQL handed to execute()
    mocker.patch("redshift_connector.Cursor.execute", return_value=None)
    mocker.patch("redshift_connector.Cursor.fetchall", return_value=None)
    mocker.patch("redshift_connector.Connection.is_single_database_metadata",
                 return_value=True)
    mock_cursor: Cursor = Cursor.__new__(Cursor)
    mock_connection: Connection = Connection.__new__(Connection)
    mock_cursor._c = mock_connection
    mock_cursor.paramstyle = "mocked_val"
    spy = mocker.spy(mock_cursor, "execute")
    mock_cursor.get_procedures(catalog, schema_pattern, procedure_name_pattern)
    assert spy.called
    assert spy.call_count == 1
    # catalog never travels via bind parameters
    assert catalog not in spy.call_args[0][1]
    for arg in (schema_pattern, procedure_name_pattern):
        if arg is not None:
            assert arg in spy.call_args[0][1]
def test_get_tables_considers_args(is_single_database_metadata_val, _input,
                                   schema_pattern_type, mocker):
    """get_tables() must embed catalog in the SQL text, bind the patterns as
    parameters, and issue an extra schema-pattern lookup in single-database mode."""
    catalog, schema_pattern, table_name_pattern = _input
    mocker.patch("redshift_connector.Cursor.execute", return_value=None)
    # mock the return value from __schema_pattern_match as it's return value is used in get_tables()
    # the other potential call to this method in get_tables() result is simply returned, so at this time
    # it has no impact
    mocker.patch(
        "redshift_connector.Cursor.fetchall",
        return_value=None if schema_pattern_type == "EXTERNAL_SCHEMA_QUERY" else tuple("mock"),
    )
    mock_cursor: Cursor = Cursor.__new__(Cursor)
    mock_cursor.paramstyle = "mocked"
    mock_connection: Connection = Connection.__new__(Connection)
    mock_cursor._c = mock_connection
    spy = mocker.spy(mock_cursor, "execute")
    # is_single_database_metadata is a property, so its __get__ is mocked
    with patch(
            "redshift_connector.Connection.is_single_database_metadata",
            new_callable=PropertyMock()) as mock_is_single_database_metadata:
        mock_is_single_database_metadata.__get__ = Mock(
            return_value=is_single_database_metadata_val)
        mock_cursor.get_tables(catalog, schema_pattern, table_name_pattern)
    assert spy.called
    if schema_pattern is not None and is_single_database_metadata_val:
        assert spy.call_count == 2  # call in __schema_pattern_match(), get_tables()
    else:
        assert spy.call_count == 1
    if catalog is not None:
        # catalog is embedded in the SQL text itself
        assert catalog in spy.call_args[0][0]
    for arg in (schema_pattern, table_name_pattern):
        if arg is not None:
            assert arg in spy.call_args[0][1]
def test_insert_data_column_names_indexes_mismatch_raises(
        indexes, names, mocker):
    """insert_data_bulk rejects column_names and parameter_indices of
    different lengths with an InterfaceError."""
    # fetchone reports success so table/column validation does not interfere
    mocker.patch("redshift_connector.Cursor.fetchone", return_value=[1])
    cursor: Cursor = Cursor.__new__(Cursor)
    cursor._c = Mock()  # stand-in for the connection
    cursor.paramstyle = "qmark"
    expected_message = "Column names and parameter indexes must be the same length"
    with pytest.raises(InterfaceError, match=expected_message):
        cursor.insert_data_bulk(
            filename="test_file",
            table_name="test_table",
            parameter_indices=indexes,
            column_names=names,
            delimiter=",",
        )
def execute_ddl_2(cursor: redshift_connector.Cursor) -> None:
    """Execute xddl2 followed by ddl2 on *cursor* (order is significant)."""
    for statement in (xddl2, ddl2):
        cursor.execute(statement)
def _drop_table(cursor: redshift_connector.Cursor, schema: Optional[str], table: str) -> None:
    """Run DROP TABLE IF EXISTS for *table*, schema-qualified when *schema* is truthy."""
    if schema:
        qualified_name = f"{schema}.{table}"
    else:
        qualified_name = table
    sql = f"DROP TABLE IF EXISTS {qualified_name}"
    _logger.debug("Drop table query:\n%s", sql)
    cursor.execute(sql)
def test_get_description_no_ps():
    """_getDescription yields None when no prepared statement exists."""
    cursor: Cursor = Cursor.__new__(Cursor)
    cursor.ps = None
    assert cursor._getDescription() is None
def _create_table(
    df: Optional[pd.DataFrame],
    path: Optional[Union[str, List[str]]],
    cursor: redshift_connector.Cursor,
    table: str,
    schema: str,
    mode: str,
    index: bool,
    dtype: Optional[Dict[str, str]],
    diststyle: str,
    sortstyle: str,
    distkey: Optional[str],
    sortkey: Optional[List[str]],
    primary_keys: Optional[List[str]],
    varchar_lengths_default: int,
    varchar_lengths: Optional[Dict[str, int]],
    parquet_infer_sampling: float = 1.0,
    use_threads: bool = True,
    boto3_session: Optional[boto3.Session] = None,
    s3_additional_kwargs: Optional[Dict[str, str]] = None,
) -> Tuple[str, Optional[str]]:
    """Create (or prepare) the target Redshift table for a load.

    Returns ``(table_name, schema)`` to load into. For ``mode="upsert"``
    against an existing table this is a session-scoped temp table (schema is
    None); otherwise it is the requested table, created if needed with column
    types inferred from *df* or from the Parquet files at *path*.
    """
    if mode == "overwrite":
        # Drop first so the CREATE below rebuilds the table from scratch.
        _drop_table(cursor=cursor, schema=schema, table=table)
    elif _does_table_exist(cursor=cursor, schema=schema, table=table) is True:
        if mode == "upsert":
            # Stage into a uniquely named temp table cloned from the target;
            # the caller merges and drops it later.
            guid: str = uuid.uuid4().hex
            temp_table: str = f"temp_redshift_{guid}"
            sql: str = f"CREATE TEMPORARY TABLE {temp_table} (LIKE {schema}.{table})"
            _logger.debug(sql)
            cursor.execute(sql)
            return temp_table, None
        # Table already exists and mode is not upsert: load straight into it.
        return table, schema
    # Normalize style keywords; defaults match Redshift's own defaults.
    diststyle = diststyle.upper() if diststyle else "AUTO"
    sortstyle = sortstyle.upper() if sortstyle else "COMPOUND"
    if df is not None:
        # Infer column types from the DataFrame (optionally including its index).
        redshift_types: Dict[str, str] = _data_types.database_types_from_pandas(
            df=df,
            index=index,
            dtype=dtype,
            varchar_lengths_default=varchar_lengths_default,
            varchar_lengths=varchar_lengths,
            converter_func=_data_types.pyarrow2redshift,
        )
    elif path is not None:
        # Infer column types by sampling the Parquet files on S3.
        redshift_types = _redshift_types_from_path(
            path=path,
            varchar_lengths_default=varchar_lengths_default,
            varchar_lengths=varchar_lengths,
            parquet_infer_sampling=parquet_infer_sampling,
            use_threads=use_threads,
            boto3_session=boto3_session,
            s3_additional_kwargs=s3_additional_kwargs,
        )
    else:
        raise ValueError("df and path are None.You MUST pass at least one.")
    _validate_parameters(
        redshift_types=redshift_types,
        diststyle=diststyle,
        distkey=distkey,
        sortstyle=sortstyle,
        sortkey=sortkey,
    )
    # "col type,\n" per column; the final [:-2] trims the trailing ",\n".
    cols_str: str = "".join([f"{k} {v},\n" for k, v in redshift_types.items()])[:-2]
    primary_keys_str: str = f",\nPRIMARY KEY ({', '.join(primary_keys)})" if primary_keys else ""
    # DISTKEY is only valid with DISTSTYLE KEY.
    distkey_str: str = f"\nDISTKEY({distkey})" if distkey and diststyle == "KEY" else ""
    sortkey_str: str = f"\n{sortstyle} SORTKEY({','.join(sortkey)})" if sortkey else ""
    sql = (f"CREATE TABLE IF NOT EXISTS {schema}.{table} (\n"
           f"{cols_str}"
           f"{primary_keys_str}"
           f")\nDISTSTYLE {diststyle}"
           f"{distkey_str}"
           f"{sortkey_str}")
    _logger.debug("Create table query:\n%s", sql)
    cursor.execute(sql)
    return table, schema
def test_execute_no_connection_raises_interface_error():
    """execute() on a cursor whose connection is gone must raise 'Cursor closed'."""
    orphan_cursor: Cursor = Cursor.__new__(Cursor)
    orphan_cursor._c = None  # simulate a closed / detached connection
    with pytest.raises(InterfaceError, match="Cursor closed"):
        orphan_cursor.execute("blah")