示例#1
0
def test_retrieve_data_with_row_count_limit_in_query(connect, fetchmany,
                                                     snowflake_datasource):
    """Retrieve data for a query carrying its own LIMIT, with row counting on."""
    snowflake_datasource.query = 'select name from favourite_drinks limit 10;'
    common = SnowflakeCommon()
    common.retrieve_data(connect, snowflake_datasource, get_row_count=True)
    # Four fetches in total: one of them accounts for selecting the
    # database and warehouse.
    expected_fetches = 4
    assert fetchmany.call_count == expected_fetches
    assert common.total_rows_count == 20
示例#2
0
def test_get_warehouse_without_filter(warehouse_result, execute_query, connect,
                                      mocker):
    """Without a name filter, get_warehouses lists both warehouses in order."""
    warehouses = SnowflakeCommon().get_warehouses(connect)
    assert warehouse_result.call_count == 1
    for position, expected_name in enumerate(('warehouse_1', 'warehouse_2')):
        assert warehouses[position] == expected_name
    assert len(warehouses) == 2
示例#3
0
 def _retrieve_data(self,
                    data_source: SnowflakeoAuth2DataSource) -> pd.DataFrame:
     """Fetch the data described by ``data_source`` as a DataFrame.

     A connection is opened on the source's database and warehouse, scoped by
     a ``with`` block around the ``SnowflakeCommon.retrieve_data`` call.
     """
     database, warehouse = data_source.database, data_source.warehouse
     with self._get_connection(database=database,
                               warehouse=warehouse) as conn:
         frame = SnowflakeCommon().retrieve_data(conn, data_source)
     return frame
示例#4
0
def test_get_slice_without_limit_with_offset(result, execute_query, connect,
                                             snowflake_datasource):
    """get_slice with only an offset (no limit) returns the 14 fixture rows."""
    common = SnowflakeCommon()
    ds: DataSlice = common.get_slice(connect, snowflake_datasource, offset=5)
    expected_rows = 14
    assert result.call_count == 3
    assert len(ds.df) == expected_rows
    assert ds.stats.total_returned_rows == expected_rows
示例#5
0
def test_get_slice_metadata(snowflake_datasource, mocker):
    """Slice stats are derived from the mocked fetchone/fetchall responses."""
    query = 'select name from favourite_drinks limit 12 offset 23;'
    snowflake_datasource.query = query
    connect = mocker.MagicMock()
    # Both responses hang off the same chained mock: cursor().execute().
    connect.cursor().execute().fetchone.return_value = [{'total_rows': 200}]
    connect.cursor().execute().fetchall.return_value = [{'c1': 2}]
    data_slice: DataSlice = SnowflakeCommon().get_slice(connect,
                                                        snowflake_datasource)
    stats = data_slice.stats
    assert stats.df_memory_size == 1360
    assert stats.total_returned_rows == 14
示例#6
0
 def describe(
     self,
     data_source: SnowflakeDataSource,
 ) -> Dict[str, str]:
     """Describe ``data_source.query`` via ``SnowflakeCommon.describe``.

     The warehouse is resolved with ``_set_warehouse`` first; the connection
     is then opened on the resolved database and warehouse.
     """
     resolved = self._set_warehouse(data_source)
     with self._get_connection(resolved.database,
                               resolved.warehouse) as conn:
         description = SnowflakeCommon().describe(conn, resolved.query)
     return description
示例#7
0
def test_execute_broken_query(execute_query, snowflake_datasource, mocker):
    """_execute_parallelized_queries surfaces a ProgrammingError for a broken query."""
    snowflake_datasource.query = 'select name from favourite_drinks limit 12 offset 23;'
    connection = mocker.MagicMock()
    with pytest.raises(ProgrammingError):
        call_kwargs = {
            'query': snowflake_datasource.query,
            'query_parameters': snowflake_datasource.parameters,
        }
        SnowflakeCommon()._execute_parallelized_queries(connection,
                                                        **call_kwargs)
示例#8
0
def test__describe_api_didnt_describe(connect, mocker):
    """_describe raises TypeError when cursor.describe() yields None."""
    cursor = mocker.MagicMock()
    cursor.describe.return_value = None
    connect.cursor.return_value = cursor
    query = 'SELECT steve_madden, IPO FROM STRATON_OAKMONT;'
    with pytest.raises(TypeError):
        SnowflakeCommon()._describe(connect, query)
    cursor.describe.assert_called_once()
示例#9
0
def test_get_slice_metadata_no_select_in_query(result, snowflake_datasource,
                                               mocker):
    """get_slice still returns a truthy slice for a non-SELECT (DDL) query."""
    ddl_statement = (
        'create table users as  (id integer default id_seq.nextval,  name varchar (100), '
        'preferences string, created_at timestamp); ')
    snowflake_datasource.query = ddl_statement
    connect = mocker.MagicMock()
    data_slice: DataSlice = SnowflakeCommon().get_slice(connect,
                                                        snowflake_datasource)
    assert result.call_count == 3
    assert data_slice
示例#10
0
def test_get_slice_metadata_wrong_response_from_count_query(
        snowflake_datasource, mocker):
    """get_slice still returns a slice when the count query misbehaves.

    First ``fetchone`` returns an unexpected payload; then it raises, in turn,
    Exception, IndexError, TypeError and AttributeError.
    """
    snowflake_datasource.query = 'select name from favourite_drinks limit 12 offset 23;'
    connect = mocker.MagicMock()
    connect.cursor().execute().fetchone.return_value = [{'error': 'invalid query'}]
    ds: DataSlice = SnowflakeCommon().get_slice(connect, snowflake_datasource)
    assert ds
    for failure in (Exception(), IndexError(), TypeError(), AttributeError()):
        # Re-walk the mock chain each time, exactly as the unrolled version did.
        connect.cursor().execute().fetchone.side_effect = failure
        ds = SnowflakeCommon().get_slice(connect, snowflake_datasource)
        assert ds
示例#11
0
 def _fetch_data(
     self,
     data_source: SnowflakeDataSource,
     offset: Optional[int] = None,
     limit: Optional[int] = None,
     get_row_count: bool = False,
 ) -> pd.DataFrame:
     """Fetch ``data_source`` as a DataFrame via ``SnowflakeCommon.fetch_data``.

     The warehouse is resolved with ``_set_warehouse`` before connecting.
     ``offset``, ``limit`` and ``get_row_count`` are forwarded positionally.
     """
     resolved = self._set_warehouse(data_source)
     with self._get_connection(resolved.database,
                               resolved.warehouse) as conn:
         frame = SnowflakeCommon().fetch_data(conn, resolved, offset, limit,
                                              get_row_count)
     return frame
示例#12
0
def test_get_db_content(connect, mocker):
    """get_db_content returns content equal to what ``_execute_query`` yields."""
    # One JSON-formatted entry per line; the concatenation equals the
    # column description used by both the mock and the expectation.
    columns_json = (
        '[\n'
        '  {\n    "name": "R_COMMENT",\n    "type": "TEXT"\n  },\n'
        '  {\n    "name": "R_COMMENT",\n    "type": "TEXT"\n  },\n'
        '  {\n    "name": "R_NAME",\n    "type": "TEXT"\n  },\n'
        '  {\n    "name": "R_REGIONKEY",\n    "type": "NUMBER"\n  },\n'
        '  {\n    "name": "R_REGIONKEY",\n    "type": "NUMBER"\n  },\n'
        '  {\n    "name": "R_NAME",\n    "type": "TEXT"\n  },\n'
        '  {\n    "name": "R_COMMENT",\n    "type": "TEXT"\n  },\n'
        '  {\n    "name": "R_NAME",\n    "type": "TEXT"\n  },\n'
        '  {\n    "name": "R_NAME",\n    "type": "TEXT"\n  },\n'
        '  {\n    "name": "R_REGIONKEY",\n    "type": "NUMBER"\n  },\n'
        '  {\n    "name": "R_COMMENT",\n    "type": "TEXT"\n  },\n'
        '  {\n    "name": "R_REGIONKEY",\n    "type": "NUMBER"\n  }\n'
        ']'
    )
    expected = {
        'DATABASE': 'SNOWFLAKE_SAMPLE_DATA',
        'SCHEMA': 'TPCH_SF1000',
        'TYPE': 'table',
        'NAME': 'REGION',
        'COLUMNS': columns_json,
    }
    scommon = SnowflakeCommon()
    # Hand the mock its own copy so the expectation stays an independent dict.
    mocker.patch.object(scommon, '_execute_query', return_value=dict(expected))
    assert scommon.get_db_content(connection=connect) == expected
示例#13
0
 def get_slice(
     self,
     data_source: SnowflakeDataSource,
     permissions: Optional[dict] = None,
     offset: int = 0,
     limit: Optional[int] = None,
     get_row_count: Optional[bool] = False,
 ) -> DataSlice:
     """Return a ``DataSlice`` of ``data_source`` via ``SnowflakeCommon.get_slice``.

     Args:
         data_source: datasource to query; its warehouse is resolved first.
         permissions: accepted for interface compatibility.
         offset: forwarded as ``offset``.
         limit: forwarded as ``limit``.
         get_row_count: forwarded as ``get_row_count``.
     """
     # NOTE(review): ``permissions`` is never forwarded or used here — confirm
     # this is intended.
     resolved = self._set_warehouse(data_source)
     slice_kwargs = dict(offset=offset, limit=limit, get_row_count=get_row_count)
     with self._get_connection(resolved.database,
                               resolved.warehouse) as conn:
         data_slice = SnowflakeCommon().get_slice(conn, resolved, **slice_kwargs)
     return data_slice
示例#14
0
 def get_model(self):
     """Build the database model from every database the connection can list.

     Database names are listed on one connection. Each database's content is
     then fetched on a thread pool via ``_get_connection_and_db_content``,
     which appends its records to the shared ``db_contents`` list. The first
     failing worker (in completion order) has its exception re-raised here.

     Returns:
         The result of ``DiscoverableConnector.format_db_model`` on the
         collected records.
     """
     with self._get_connection() as connection:
         databases = SnowflakeCommon().get_databases(connection=connection)
     db_contents = []
     with concurrent.futures.ThreadPoolExecutor() as executor:
         futures = [
             executor.submit(self._get_connection_and_db_content, db,
                             db_contents) for db in databases
         ]
         for future in concurrent.futures.as_completed(futures):
             # result() re-raises the worker's exception, if it failed.
             future.result()
             self.logger.info('query finished')
     return DiscoverableConnector.format_db_model(db_contents)
示例#15
0
def test__describe_api_changed(connect, mocker):
    """_describe maps cursor.describe() column metadata to type labels.

    As asserted below, type_code 14 maps to None and type_code 0 to 'float'.
    """
    cursor = mocker.MagicMock()
    describe_mock = cursor.describe
    mocker.patch('toucan_connectors.snowflake_common.json.dumps')

    class FakeColumn:
        # Minimal stand-in exposing the two attributes _describe reads.
        def __init__(self, name, type_code):
            self.name = name
            self.type_code = type_code

    describe_mock.return_value = [
        FakeColumn(name='steve_madden', type_code=14),
        FakeColumn(name='IPO', type_code=0),
    ]
    connect.cursor.return_value = cursor
    query = 'SELECT steve_madden, IPO FROM STRATON_OAKMONT;'
    described = SnowflakeCommon()._describe(connect, query)
    describe_mock.assert_called_once()
    assert described['steve_madden'] is None
    assert described['IPO'] == 'float'
示例#16
0
def test_fetch_data_warehouse_none(execute_query, execute_parallelized,
                                   connect):
    """With ``warehouse=None`` the connection's warehouse must not be switched.

    Checked by asserting that ``execute_query`` is never called by fetch_data.
    """
    query_object = {
        'schema': 'SHOW_SCHEMA',
        'table': 'MY_TABLE',
        'columns': ['col1', 'col2'],
    }
    parameters = {'foo': 'bar', 'pokemon': 'pikachu'}
    datasource = SnowflakeDataSource(
        name='test_name',
        domain='test_domain',
        database='database_1',
        warehouse=None,
        query='select * from my_table where toto=%(foo);',
        query_object=query_object,
        parameters=parameters,
    )
    SnowflakeCommon().fetch_data(connect, datasource)
    assert execute_query.call_count == 0
示例#17
0
def test_get_database_with_filter_one_result(database_result, execute_query,
                                             connect):
    """Filtering on 'database_1' returns exactly that one database."""
    wanted = 'database_1'
    databases = SnowflakeCommon().get_databases(connect, wanted)
    assert database_result.call_count == 1
    assert databases[0] == wanted
    assert len(databases) == 1
示例#18
0
 def _get_unique_datasource_identifier(
         self, data_source: SnowflakeDataSource) -> dict:
     """Identify ``data_source`` by its ``SnowflakeCommon.render_datasource`` output."""
     common = SnowflakeCommon()
     identifier = common.render_datasource(data_source)
     return identifier
示例#19
0
def test_retrieve_total_rows():
    """The value given to set_total_returned_rows_count is read back."""
    row_count = 20
    common = SnowflakeCommon()
    common.set_total_returned_rows_count(row_count)
    assert common.total_returned_rows_count == row_count
示例#20
0
 def _get_databases(self, database_name: Optional[str] = None) -> List[str]:
     """List database names, passing ``database_name`` through as the filter.

     ``database_name`` is also used as the connection's database.
     """
     with self._get_connection(database=database_name) as conn:
         names = SnowflakeCommon().get_databases(conn, database_name)
     return names
示例#21
0
 def _get_connection_and_db_content(self, database: str, db_contents: List):
     """Append the content records of ``database`` to ``db_contents`` in place.

     Connects with ``self.default_warehouse``. The caller's list is extended,
     not replaced, so the caller sees the new records. Returns None.
     """
     with self._get_connection(database=database,
                               warehouse=self.default_warehouse) as conn:
         records = SnowflakeCommon().get_db_content(conn).to_dict('records')
         db_contents.extend(records)
示例#22
0
def test_describe(connect, mocker):
    """SnowflakeCommon.describe calls QueryManager.describe exactly once."""
    target = 'toucan_connectors.query_manager.QueryManager.describe'
    patched_describe = mocker.patch(target)
    query = 'SELECT FAIRY_DUST FROM STRATON_OAKMONT;'
    SnowflakeCommon().describe(connect, query)
    patched_describe.assert_called_once()
示例#23
0
def test_get_warehouse_with_filter_no_result(warehouse_result, execute_query,
                                             connect):
    """Filtering on an unknown warehouse name yields an empty list."""
    unknown = 'warehouse_3'
    warehouses = SnowflakeCommon().get_warehouses(connect, unknown)
    assert warehouse_result.call_count == 1
    assert len(warehouses) == 0
示例#24
0
def test_retrieve_data(result, execute_query, connect, snowflake_datasource):
    """retrieve_data returns a 14-row DataFrame from the fixture data."""
    common = SnowflakeCommon()
    df: pd.DataFrame = common.retrieve_data(connect, snowflake_datasource)
    assert result.call_count == 3
    assert len(df) == 14
示例#25
0
 def _get_warehouses(self,
                     warehouse_name: Optional[str] = None) -> List[str]:
     """List warehouse names, passing ``warehouse_name`` through as the filter.

     ``warehouse_name`` is also used as the connection's warehouse.
     """
     with self._get_connection(warehouse=warehouse_name) as conn:
         names = SnowflakeCommon().get_warehouses(conn, warehouse_name)
     return names