Beispiel #1
0
def test_oauth_args_wrong_type_of_auth(
    get_identifier,
    is_closed,
    close,
    connect,
    retrieve_data,
    snowflake_connector_oauth,
    snowflake_datasource,
    mocker,
):
    """A connector switched to PLAIN authentication never refreshes an OAuth token.

    The leading positional arguments are presumably mocks injected by ``patch``
    decorators that are not visible here. The connection manager is cleaned in
    ``finally`` so cleanup also runs when ``_retrieve_data`` raises.
    """
    spy = mocker.spy(SnowflakeConnector, '_refresh_oauth_token')

    snowflake_connector_oauth.authentication_method = AuthenticationMethod.PLAIN
    try:
        snowflake_connector_oauth._retrieve_data(snowflake_datasource)
    finally:
        SnowflakeConnector.get_snowflake_connection_manager().force_clean()
    assert spy.call_count == 0
Beispiel #2
0
def snowflake_connector_malformed():
    """Return a connector built without an explicit ``authentication_method``."""
    settings = {
        'identifier': 'snowflake_test',
        'name': 'test_name',
        'user': '******',
        'password': '******',
        'account': 'test_account',
        'default_warehouse': 'warehouse_1',
    }
    return SnowflakeConnector(**settings)
Beispiel #3
0
def snowflake_connector():
    """Return a PLAIN-auth connector whose default warehouse is ``warehouse_1``."""
    settings = {
        'identifier': 'snowflake_test',
        'name': 'test_name',
        'authentication_method': AuthenticationMethod.PLAIN,
        'user': '******',
        'password': '******',
        'account': 'test_account',
        'default_warehouse': 'warehouse_1',
    }
    return SnowflakeConnector(**settings)
Beispiel #4
0
def test_get_connection_connect(rt, is_closed, close, connect,
                                snowflake_connector):
    """A PLAIN-auth connection forwards credentials and targets to ``connect``.

    ``rt`` is presumably the patched OAuth token-refresh hook. It must not be
    called for PLAIN authentication.
    """
    cm = SnowflakeConnector.get_snowflake_connection_manager()
    try:
        snowflake_connector._get_connection('test_database', 'test_warehouse')
        assert rt.call_count == 0
        # Keyword arguments of the first connect() call.
        connect_kwargs = connect.call_args_list[0][1]
        assert connect_kwargs['account'] == 'test_account'
        # NOTE(review): the snowflake_connector fixture visible in this file
        # passes redacted '******' credentials. Confirm these expected values
        # match the fixture actually in use.
        assert connect_kwargs['user'] == 'test_user'
        assert connect_kwargs['password'] == 'test_password'
        assert connect_kwargs['database'] == 'test_database'
        assert connect_kwargs['warehouse'] == 'test_warehouse'
    finally:
        cm.force_clean()
Beispiel #5
0
def test_snowflake_connection_alive(gat, is_closed, close, connect,
                                    snowflake_connector):
    """A freshly opened connection is registered in the connection manager.

    The manager's timing settings are shortened for this test and are always
    restored afterwards, even when the assertion fails. This keeps the shared
    manager from leaking altered timings into later tests.
    """
    cm = SnowflakeConnector.get_snowflake_connection_manager()
    saved_between_clean = cm.time_between_clean
    saved_keep_alive = cm.time_keep_alive
    cm.time_between_clean = 1
    cm.time_keep_alive = 5
    try:
        snowflake_connector._get_connection('test_database', 'test_warehouse')
        assert len(cm.connection_list) == 1
    finally:
        cm.time_between_clean = saved_between_clean
        cm.time_keep_alive = saved_keep_alive
        cm.force_clean()
Beispiel #6
0
def test_get_status_account_nok(is_closed, close, connect, gw,
                                snowflake_connector):
    """A ProgrammingError during the status check is reported verbatim.

    ``gw`` is presumably the patched warehouse lookup used by ``get_status``.
    """
    cm = SnowflakeConnector.get_snowflake_connection_manager()
    gw.side_effect = snowflake.connector.errors.ProgrammingError('Account nok')
    try:
        result = snowflake_connector.get_status()
        assert result == ConnectorStatus(
            status=False,
            error='Account nok',
            details=[
                ('Connection to Snowflake', False),
                ('Default warehouse exists', None),
            ],
        )
    finally:
        cm.force_clean()
Beispiel #7
0
def test_account_forbidden(is_closed, close, connect, gw, snowflake_connector):
    """A ForbiddenError during the status check yields an access-forbidden message.

    ``gw`` is presumably the patched warehouse lookup used by ``get_status``.
    """
    cm = SnowflakeConnector.get_snowflake_connection_manager()
    gw.side_effect = snowflake.connector.errors.ForbiddenError()
    try:
        result = snowflake_connector.get_status()
        assert result == ConnectorStatus(
            status=False,
            error=(
                f"Access forbidden, please check that you have access to the '{snowflake_connector.account}' account or try again later."
            ),
            details=[
                ('Connection to Snowflake', False),
                ('Default warehouse exists', None),
            ],
        )
    finally:
        cm.force_clean()
Beispiel #8
0
def test_describe(is_closed, close, connect, mocker, snowflake_datasource,
                  snowflake_connector):
    """``describe`` delegates exactly once to ``SnowflakeCommon.describe``."""
    cm = SnowflakeConnector.get_snowflake_connection_manager()
    mocked_common_describe = mocker.patch(
        'toucan_connectors.snowflake.snowflake_connector.SnowflakeCommon.describe',
        return_value={
            'toto': 'int',
            'tata': 'str'
        },
    )
    try:
        snowflake_connector.describe(snowflake_datasource)
        mocked_common_describe.assert_called_once()
    finally:
        cm.force_clean()
Beispiel #9
0
def test_account_failed_for_user(is_closed, close, connect, gw,
                                 snowflake_connector):
    """A DatabaseError during the status check yields a bad-credentials message.

    ``gw`` is presumably the patched warehouse lookup used by ``get_status``.
    """
    cm = SnowflakeConnector.get_snowflake_connection_manager()
    gw.side_effect = snowflake.connector.errors.DatabaseError()
    try:
        result = snowflake_connector.get_status()
        assert result == ConnectorStatus(
            status=False,
            error=(
                f"Connection failed for the user '{snowflake_connector.user}', please check your credentials"
            ),
            details=[
                ('Connection to Snowflake', False),
                ('Default warehouse exists', None),
            ],
        )
    finally:
        cm.force_clean()
Beispiel #10
0
def test_snowflake_connection_close(gat, is_closed, close, connect,
                                    snowflake_connector):
    """An idle connection gets closed by the connection manager's cleanup.

    With a 1 s keep-alive and a 1 s clean interval, ``close`` is expected to
    be called within a few seconds. The cleanup is presumably periodic and
    asynchronous, so this test waits for it.

    The test polls for up to 5 s rather than always sleeping the full 5 s.
    The manager's timing settings are restored in ``finally``.
    """
    cm = SnowflakeConnector.get_snowflake_connection_manager()
    saved_between_clean = cm.time_between_clean
    saved_keep_alive = cm.time_keep_alive
    cm.time_between_clean = 1
    cm.time_keep_alive = 1
    try:
        snowflake_connector._get_connection('test_database', 'test_warehouse')
        # Same 5 s budget as a fixed sleep, but returns as soon as close() ran.
        deadline = time.monotonic() + 5
        while close.call_count < 1 and time.monotonic() < deadline:
            time.sleep(0.1)
        assert close.call_count >= 1
    finally:
        cm.time_between_clean = saved_between_clean
        cm.time_keep_alive = saved_keep_alive
        cm.force_clean()
Beispiel #11
0
def test_account_does_not_exists(is_closed, close, connect, gw,
                                 snowflake_connector):
    """An OperationalError during the status check yields a bad-account message.

    ``gw`` is presumably the patched warehouse lookup used by ``get_status``.
    """
    cm = SnowflakeConnector.get_snowflake_connection_manager()
    gw.side_effect = snowflake.connector.errors.OperationalError()
    try:
        result = snowflake_connector.get_status()
        assert result == ConnectorStatus(
            status=False,
            error=(
                f"Connection failed for the account '{snowflake_connector.account}', please check the Account field"
            ),
            details=[
                ('Connection to Snowflake', False),
                ('Default warehouse exists', None),
            ],
        )
    finally:
        cm.force_clean()
Beispiel #12
0
def test_get_connection_connect_oauth(get_identifier, rt, is_closed, close,
                                      connect, snowflake_connector_oauth):
    """An OAUTH connection refreshes the token once and passes it to ``connect``.

    ``rt`` is presumably the patched token-refresh hook. The expected token is
    an HS256 JWT whose payload decodes to ``{"exp": 42, "sub": "snowflake_user"}``.
    """
    cm = SnowflakeConnector.get_snowflake_connection_manager()
    try:
        snowflake_connector_oauth._get_connection('test_database',
                                                  'test_warehouse')
        assert rt.call_count == 1
        # Keyword arguments of the first connect() call.
        connect_kwargs = connect.call_args_list[0][1]
        assert connect_kwargs['account'] == 'test_account'
        assert (
            connect_kwargs['token'] ==
            'eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJleHAiOjQyLCJzdWIiOiJzbm93Zmxha2VfdXNlciJ9.NJDbR-tAepC_ANrg9m5PozycbcuWDgGi4o9sN9Pl27k'
        )
        assert connect_kwargs['database'] == 'test_database'
        assert connect_kwargs['warehouse'] == 'test_warehouse'
    finally:
        cm.force_clean()
Beispiel #13
0
def test_get_model_exception(is_closed, close, connect, mocker,
                             snowflake_datasource, snowflake_connector):
    """An error while reading database content propagates out of ``get_model``.

    The database list is still fetched exactly once before the failure.
    """
    cm = SnowflakeConnector.get_snowflake_connection_manager()
    mocked_common_get_databases = mocker.patch(
        'toucan_connectors.snowflake.snowflake_connector.SnowflakeCommon.get_databases',
        return_value=['booo'],
    )
    mocker.patch(
        'toucan_connectors.snowflake.snowflake_connector.SnowflakeCommon.get_db_content',
        side_effect=Exception,
    )
    try:
        with pytest.raises(Exception):
            snowflake_connector.get_model()
        mocked_common_get_databases.assert_called_once()
    finally:
        cm.force_clean()
Beispiel #14
0
def test_refresh_oauth_token(
    req_mock,
    get_identifier,
    is_closed,
    close,
    connect,
    retrieve_data,
    snowflake_connector_oauth,
    snowflake_datasource,
):
    """An expired or undecodable access token triggers exactly one token request.

    ``req_mock`` is presumably the patched HTTP call used for the refresh.
    Each scenario lets errors (including assertion failures) propagate
    unchanged, and cleans the connection manager in ``finally``.
    """
    cm = SnowflakeConnector.get_snowflake_connection_manager()

    # Expired JWT: 'exp' is 24 hours in the past.
    snowflake_connector_oauth.user_tokens_keeper.access_token = SecretStr(
        jwt.encode({'exp': datetime.now() - timedelta(hours=24)},
                   key='supersecret'))
    req_mock.return_value.status_code = 201
    req_mock.return_value.ok = False
    # NOTE(review): this configures the result of *calling* the response
    # object, not of ``response.json()``. Confirm which one the connector reads.
    req_mock.return_value.return_value = {
        'access_token': 'token',
        'refresh_token': 'token'
    }
    try:
        snowflake_connector_oauth._retrieve_data(snowflake_datasource)
        assert req_mock.call_count == 1
    finally:
        cm.force_clean()

    req_mock.reset_mock()

    # Invalid JWT: 'PLOP' is not a decodable token.
    snowflake_connector_oauth.user_tokens_keeper.access_token = SecretStr(
        'PLOP')
    try:
        snowflake_connector_oauth._retrieve_data(snowflake_datasource)
        assert req_mock.call_count == 1
    finally:
        cm.force_clean()
Beispiel #15
0
def snowflake_connector_oauth(mocker):
    """Return an OAUTH connector whose token and SSO credential keepers are mocks.

    The token keeper exposes the access/refresh tokens as ``SecretStr`` values
    and a mocked ``update_tokens``. The SSO keeper carries the client id and
    secret.
    """
    sso_credentials = mocker.Mock(
        client_id=OAUTH_CLIENT_ID,
        client_secret=SecretStr(OAUTH_CLIENT_SECRET),
    )
    tokens = mocker.Mock(
        access_token=SecretStr(OAUTH_ACCESS_TOKEN),
        refresh_token=SecretStr(OAUTH_REFRESH_TOKEN),
        update_tokens=mocker.Mock(),
    )
    settings = {
        'name': 'test_name',
        'authentication_method': AuthenticationMethod.OAUTH,
        'user': '******',
        'password': '******',
        'account': 'test_account',
        'token_endpoint': OAUTH_TOKEN_ENDPOINT,
        'token_endpoint_content_type': OAUTH_TOKEN_ENDPOINT_CONTENT_TYPE,
        'user_tokens_keeper': tokens,
        'sso_credentials_keeper': sso_credentials,
        'default_warehouse': 'default_wh',
    }
    return SnowflakeConnector(**settings)
Beispiel #16
0
def test_schema_fields_order():
    """The connector's JSON schema lists its properties in a fixed order."""
    expected_order = [
        'type',
        'name',
        'account',
        'authentication_method',
        'user',
        'password',
        'token_endpoint',
        'token_endpoint_content_type',
        'role',
        'default_warehouse',
        'retry_policy',
        'secrets_storage_version',
        'sso_credentials_keeper',
        'user_tokens_keeper',
    ]
    schema = JsonWrapper.loads(SnowflakeConnector.schema_json())
    # Iterating a dict yields its keys in insertion order.
    assert list(schema['properties']) == expected_order
Beispiel #17
0
def test_oauth_args_endpoint_not_200(req_mock, is_closed, close, connect,
                                     snowflake_connector_oauth,
                                     snowflake_datasource):
    """A 401 from the token endpoint surfaces as an error from ``_retrieve_data``.

    The access token is expired, presumably forcing a refresh request. The
    mocked response's ``raise_for_status`` raises ``HTTP Error 401``.
    ``pytest.raises`` makes the test fail if no error is raised at all.
    """
    cm = SnowflakeConnector.get_snowflake_connection_manager()
    snowflake_connector_oauth.user_tokens_keeper.access_token = SecretStr(
        jwt.encode({'exp': datetime.now() - timedelta(hours=24)},
                   key='supersecret'))

    def fake_raise_for_status():
        raise HTTPError('url', 401, 'Unauthorized', {}, None)

    req_mock.return_value.status_code = 401
    req_mock.return_value.ok = False
    req_mock.return_value.raise_for_status = fake_raise_for_status

    try:
        with pytest.raises(Exception) as excinfo:
            snowflake_connector_oauth._retrieve_data(snowflake_datasource)
        assert str(excinfo.value) == 'HTTP Error 401: Unauthorized'
        assert req_mock.call_count == 1
    finally:
        cm.force_clean()
Beispiel #18
0
def test_datasource_get_form(gd, gw, is_closed, close, connect,
                             snowflake_connector, snowflake_datasource):
    """The datasource form pre-fills the warehouse with the connector default."""
    try:
        result = snowflake_datasource.get_form(snowflake_connector, {})
        assert result['properties']['warehouse']['default'] == 'warehouse_1'
    finally:
        SnowflakeConnector.get_snowflake_connection_manager().force_clean()
Beispiel #19
0
def test_set_warehouse(snowflake_connector, snowflake_datasource):
    """A datasource without a warehouse falls back to the connector default."""
    snowflake_datasource.warehouse = None
    try:
        new_data_source = snowflake_connector._set_warehouse(snowflake_datasource)
        assert new_data_source.warehouse == 'warehouse_1'
    finally:
        SnowflakeConnector.get_snowflake_connection_manager().force_clean()
Beispiel #20
0
def test_get_database_with_filter_found(gd, is_closed, close, connect,
                                        snowflake_connector):
    """Filtering databases by an existing name returns only that database."""
    try:
        result = snowflake_connector._get_databases('database_1')
        assert result[0] == 'database_1'
        assert len(result) == 1
    finally:
        SnowflakeConnector.get_snowflake_connection_manager().force_clean()
Beispiel #21
0
def test_get_warehouse_without_filter(gw, is_closed, close, connect,
                                      snowflake_connector):
    """Without a filter, every warehouse from the (mocked) lookup is returned."""
    try:
        result = snowflake_connector._get_warehouses()
        assert result[0] == 'warehouse_1'
        assert result[1] == 'warehouse_2'
    finally:
        SnowflakeConnector.get_snowflake_connection_manager().force_clean()
Beispiel #22
0
def test_get_warehouse_with_filter_not_found(gw, is_closed, close, connect,
                                             snowflake_connector):
    """Filtering by an unknown warehouse name yields an empty result."""
    try:
        result = snowflake_connector._get_warehouses('warehouse_3')
        assert len(result) == 0
    finally:
        SnowflakeConnector.get_snowflake_connection_manager().force_clean()
Beispiel #23
0
def test_get_unique_datasource_identifier():
    """Cache keys depend on datasource content and connector account.

    They must not depend on the insertion order of the datasource parameters.
    """

    def make_connector(account):
        # PLAIN-auth connector that differs only by account.
        return SnowflakeConnector(
            identifier='snowflake_test',
            name='test_name',
            authentication_method=AuthenticationMethod.PLAIN,
            user='******',
            password='******',
            account=account,
            default_warehouse='warehouse_1',
        )

    def make_datasource(database, parameters):
        # Datasource that differs only by database and parameters.
        return SnowflakeDataSource(
            name='test_name',
            domain='test_domain',
            database=database,
            warehouse='warehouse_1',
            query='test_query with %(foo)s and %(pokemon)s',
            query_object={
                'schema': 'SHOW_SCHEMA',
                'table': 'MY_TABLE',
                'columns': ['col1', 'col2']
            },
            parameters=parameters,
        )

    snowflake_connector = make_connector('test_account')

    datasource = make_datasource('database_1', {
        'foo': 'bar',
        'pokemon': 'pikachu'
    })
    key = snowflake_connector.get_cache_key(datasource)

    # Same parameters, different insertion order. A repeated 'foo' key would
    # silently collapse this into an exact copy of `datasource`.
    datasource2 = make_datasource('database_1', {
        'pokemon': 'pikachu',
        'foo': 'bar'
    })
    key2 = snowflake_connector.get_cache_key(datasource2)

    assert key == key2

    # A different database must change the key.
    datasource3 = make_datasource('database_2', {
        'foo': 'bar',
        'pokemon': 'pikachu'
    })
    key3 = snowflake_connector.get_cache_key(datasource3)
    assert key != key3

    # A different account must change the key for every datasource.
    another_snowflake_connector = make_connector('another_test_account')
    for ds in (datasource, datasource2, datasource3):
        assert snowflake_connector.get_cache_key(
            ds) != another_snowflake_connector.get_cache_key(ds)
Beispiel #24
0
def test_get_status_without_warehouses(gw, is_closed, close, connect,
                                       snowflake_connector):
    """``get_status`` reports a failed status in this mocked setup.

    NOTE(review): the name suggests ``gw`` returns no warehouses here. Confirm
    against its patch configuration.
    """
    try:
        connector_status = snowflake_connector.get_status()
        assert not connector_status.status
    finally:
        SnowflakeConnector.get_snowflake_connection_manager().force_clean()
Beispiel #25
0
def test_get_model(is_closed, close, connect, mocker, snowflake_datasource,
                   snowflake_connector):
    """``get_model`` turns one db-content row into one table model entry.

    The mocked ``COLUMNS`` cell is the 2-space-indented JSON text of
    ``expected_columns``. Deriving both the mocked input and the expected
    output from a single list keeps them in sync.
    """
    import json

    cm = SnowflakeConnector.get_snowflake_connection_manager()

    # Column list including duplicates. get_model() is expected to pass it
    # through unchanged and in order.
    column_types = {
        'R_COMMENT': 'TEXT',
        'R_NAME': 'TEXT',
        'R_REGIONKEY': 'NUMBER'
    }
    column_names = [
        'R_COMMENT', 'R_COMMENT', 'R_NAME', 'R_REGIONKEY', 'R_REGIONKEY',
        'R_NAME', 'R_COMMENT', 'R_NAME', 'R_NAME', 'R_REGIONKEY', 'R_COMMENT',
        'R_REGIONKEY'
    ]
    expected_columns = [{
        'name': name,
        'type': column_types[name]
    } for name in column_names]

    mocked_common_get_databases = mocker.patch(
        'toucan_connectors.snowflake.snowflake_connector.SnowflakeCommon.get_databases',
        return_value=['booo'],
    )
    mocked_common_get_db_content = mocker.patch(
        'toucan_connectors.snowflake.snowflake_connector.SnowflakeCommon.get_db_content',
        return_value=pd.DataFrame([{
            'DATABASE': 'SNOWFLAKE_SAMPLE_DATA',
            'SCHEMA': 'TPCH_SF1000',
            'TYPE': 'table',
            'NAME': 'REGION',
            # indent=2 reproduces the exact text of the hand-written literal.
            'COLUMNS': json.dumps(expected_columns, indent=2),
        }]),
    )
    try:
        res = snowflake_connector.get_model()
        mocked_common_get_databases.assert_called_once()
        mocked_common_get_db_content.assert_called_once()
        assert res == [{
            'name': 'REGION',
            'schema': 'TPCH_SF1000',
            'database': 'SNOWFLAKE_SAMPLE_DATA',
            'type': 'table',
            'columns': expected_columns,
        }]
    finally:
        cm.force_clean()
import pytest

from toucan_connectors.snowflake import SnowflakeConnector, SnowflakeDataSource

# Connector and datasource shared by test_snowflake below.
sc = SnowflakeConnector(
    name='test_name',
    user='******',
    password='******',
    account='test_account',
)

sd = SnowflakeDataSource(
    name='test_name',
    domain='test_domain',
    database='test_database',
    warehouse='test_warehouse',
    query='test_query',
)


def test_snowflake(mocker):
    """``get_df`` connects with the connector settings and reads the query.

    Both ``snowflake.connector.connect`` and ``pandas.read_sql`` are patched.
    The query must be read through the connection object that ``connect``
    returned.
    """
    snock = mocker.patch('snowflake.connector.connect')
    reasq = mocker.patch('pandas.read_sql')

    sc.get_df(sd)

    snock.assert_called_once_with(user='******',
                                  password='******',
                                  account='test_account',
                                  database='test_database',
                                  warehouse='test_warehouse',
                                  ocsp_response_cache_filename=None)

    # Use the mock's return value directly. Calling ``snock()`` here would
    # record a spurious extra connect() call on the mock.
    reasq.assert_called_once_with('test_query', con=snock.return_value)