class TestSnowflakeAdapter(unittest.TestCase):
    """Unit tests for ``SnowflakeAdapter`` backed by a mocked Snowflake connector.

    ``snowflake.connector.connect`` is patched in ``setUp`` to return a
    ``MagicMock`` connection. The tests therefore only inspect the SQL handed
    to the mocked cursor's ``execute`` and the keyword arguments passed to
    ``connect``; no network access happens.
    """

    def setUp(self):
        # Put the global flag back when the test finishes so STRICT_MODE=False
        # does not leak into unrelated tests.
        self.addCleanup(
            setattr, flags, 'STRICT_MODE', getattr(flags, 'STRICT_MODE', False))
        flags.STRICT_MODE = False

        profile_cfg = {
            'outputs': {
                'test': {
                    'type': 'snowflake',
                    'account': 'test_account',
                    'user': '******',
                    'database': 'test_database',
                    'warehouse': 'test_warehouse',
                    'schema': 'public',
                },
            },
            'target': 'test',
        }

        project_cfg = {
            'name': 'X',
            'version': '0.1',
            'profile': 'test',
            'project-root': '/tmp/dbt/does-not-exist',
            'quoting': {
                'identifier': False,
                'schema': True,
            }
        }
        self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)

        self.handle = mock.MagicMock(
            spec=snowflake_connector.SnowflakeConnection)
        self.cursor = self.handle.cursor.return_value
        self.mock_execute = self.cursor.execute
        self.patcher = mock.patch(
            'dbt.adapters.snowflake.connections.snowflake.connector.connect')
        self.snowflake = self.patcher.start()
        # Registered as a cleanup (rather than only stopped in tearDown) so the
        # patch is undone even if a later step of setUp raises.
        self.addCleanup(self.patcher.stop)

        self.snowflake.return_value = self.handle
        self.adapter = SnowflakeAdapter(self.config)
        # patch our new adapter into the factory so macros behave
        inject_adapter('snowflake', self.adapter)

    def tearDown(self):
        # we want a unique self.handle every time.
        # The connector patch is stopped afterwards by the cleanup from setUp.
        self.adapter.cleanup_connections()

    def test_quoting_on_drop_schema(self):
        self.adapter.drop_schema(
            database='test_database',
            schema='test_schema'
        )

        self.mock_execute.assert_has_calls([
            mock.call('drop schema if exists test_database."test_schema" cascade', None)
        ])

    def test_quoting_on_drop(self):
        relation = self.adapter.Relation.create(
            database='test_database',
            schema='test_schema',
            identifier='test_table',
            type='table',
            quote_policy=self.adapter.config.quoting,
        )
        self.adapter.drop_relation(relation)

        self.mock_execute.assert_has_calls([
            mock.call(
                'drop table if exists test_database."test_schema".test_table cascade',
                None
            )
        ])

    def test_quoting_on_truncate(self):
        relation = self.adapter.Relation.create(
            database='test_database',
            schema='test_schema',
            identifier='test_table',
            type='table',
            quote_policy=self.adapter.config.quoting,
        )
        self.adapter.truncate_relation(relation)

        self.mock_execute.assert_has_calls([
            mock.call('truncate table test_database."test_schema".test_table', None)
        ])

    def test_quoting_on_rename(self):
        from_relation = self.adapter.Relation.create(
            database='test_database',
            schema='test_schema',
            identifier='table_a',
            type='table',
            quote_policy=self.adapter.config.quoting,
        )
        to_relation = self.adapter.Relation.create(
            database='test_database',
            schema='test_schema',
            identifier='table_b',
            type='table',
            quote_policy=self.adapter.config.quoting,
        )

        self.adapter.rename_relation(
            from_relation=from_relation,
            to_relation=to_relation
        )
        self.mock_execute.assert_has_calls([
            mock.call(
                'alter table test_database."test_schema".table_a rename to test_database."test_schema".table_b',
                None
            )
        ])

    def test_cancel_open_connections_empty(self):
        self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0)

    def test_cancel_open_connections_master(self):
        # Only the 'master' connection is open, so nothing is cancelled.
        self.adapter.connections.in_use['master'] = mock.MagicMock()
        self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0)

    def test_cancel_open_connections_single(self):
        """A non-master connection is aborted by session id via 'master'."""
        master = mock.MagicMock()
        model = mock.MagicMock()
        model.handle.session_id = 42

        self.adapter.connections.in_use.update({
            'master': master,
            'model': model,
        })
        with mock.patch.object(self.adapter.connections, 'add_query') as add_query:
            query_result = mock.MagicMock()
            add_query.return_value = (None, query_result)

            self.assertEqual(
                len(list(self.adapter.cancel_open_connections())), 1)

            add_query.assert_called_once_with(
                'select system$abort_session(42)', 'master')

    def test_client_session_keep_alive_false_by_default(self):
        self.adapter.connections.get(name='new_connection_with_new_config')
        self.snowflake.assert_has_calls([
            mock.call(
                account='test_account', autocommit=False,
                client_session_keep_alive=False, database='test_database',
                role=None, schema='public', user='******',
                warehouse='test_warehouse', private_key=None)
        ])

    def test_client_session_keep_alive_true(self):
        self.config.credentials = self.config.credentials.incorporate(
            client_session_keep_alive=True)
        self.adapter = SnowflakeAdapter(self.config)
        self.adapter.connections.get(name='new_connection_with_new_config')

        self.snowflake.assert_has_calls([
            mock.call(
                account='test_account', autocommit=False,
                client_session_keep_alive=True, database='test_database',
                role=None, schema='public', user='******',
                warehouse='test_warehouse', private_key=None)
        ])

    def test_user_pass_authentication(self):
        self.config.credentials = self.config.credentials.incorporate(
            password='******')
        self.adapter = SnowflakeAdapter(self.config)
        self.adapter.connections.get(name='new_connection_with_new_config')

        self.snowflake.assert_has_calls([
            mock.call(
                account='test_account', autocommit=False,
                client_session_keep_alive=False, database='test_database',
                password='******', role=None, schema='public',
                user='******', warehouse='test_warehouse', private_key=None)
        ])

    def test_authenticator_user_pass_authentication(self):
        self.config.credentials = self.config.credentials.incorporate(
            password='******', authenticator='test_sso_url')
        self.adapter = SnowflakeAdapter(self.config)
        self.adapter.connections.get(name='new_connection_with_new_config')

        self.snowflake.assert_has_calls([
            mock.call(
                account='test_account', autocommit=False,
                client_session_keep_alive=False, database='test_database',
                password='******', role=None, schema='public',
                user='******', warehouse='test_warehouse',
                authenticator='test_sso_url', private_key=None)
        ])

    def test_authenticator_externalbrowser_authentication(self):
        self.config.credentials = self.config.credentials.incorporate(
            authenticator='externalbrowser')
        self.adapter = SnowflakeAdapter(self.config)
        self.adapter.connections.get(name='new_connection_with_new_config')

        self.snowflake.assert_has_calls([
            mock.call(
                account='test_account', autocommit=False,
                client_session_keep_alive=False, database='test_database',
                role=None, schema='public', user='******',
                warehouse='test_warehouse', authenticator='externalbrowser',
                private_key=None)
        ])

    # Use mock.patch like the rest of the class so no separate ``patch``
    # import is required.
    @mock.patch('dbt.adapters.snowflake.SnowflakeConnectionManager._get_private_key', return_value='test_key')
    def test_authenticator_private_key_authentication(self, mock_get_private_key):
        """The (patched) loaded private key is forwarded to ``connect``."""
        self.config.credentials = self.config.credentials.incorporate(
            private_key_path='/tmp/test_key.p8',
            private_key_passphrase='p@ssphr@se')

        self.adapter = SnowflakeAdapter(self.config)
        self.adapter.connections.get(name='new_connection_with_new_config')

        self.snowflake.assert_has_calls([
            mock.call(
                account='test_account', autocommit=False,
                client_session_keep_alive=False, database='test_database',
                role=None, schema='public', user='******',
                warehouse='test_warehouse', private_key='test_key')
        ])
Exemplo n.º 2
0
class TestSnowflakeAdapter(unittest.TestCase):
    """Unit tests for ``SnowflakeAdapter`` with a mocked Snowflake connector.

    ``snowflake.connector.connect``, the manifest state check and the
    query-header ``add`` hook are all patched in ``setUp``. Every SQL string
    the adapter executes is therefore expected to start with a ``/* dbt */``
    header line, and is inspected through the mocked cursor's ``execute``.
    """

    def setUp(self):
        # Put the global flag back when the test finishes so STRICT_MODE=False
        # does not leak into unrelated tests.
        self.addCleanup(
            setattr, flags, 'STRICT_MODE', getattr(flags, 'STRICT_MODE', False))
        flags.STRICT_MODE = False

        profile_cfg = {
            'outputs': {
                'test': {
                    'type': 'snowflake',
                    'account': 'test_account',
                    'user': '******',
                    'database': 'test_database',
                    'warehouse': 'test_warehouse',
                    'schema': 'public',
                },
            },
            'target': 'test',
        }

        project_cfg = {
            'name': 'X',
            'version': '0.1',
            'profile': 'test',
            'project-root': '/tmp/dbt/does-not-exist',
            'quoting': {
                'identifier': False,
                'schema': True,
            },
            'query-comment': 'dbt',
            'config-version': 2,
        }
        self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
        self.assertEqual(self.config.query_comment.comment, 'dbt')
        self.assertEqual(self.config.query_comment.append, False)

        self.handle = mock.MagicMock(
            spec=snowflake_connector.SnowflakeConnection)
        self.cursor = self.handle.cursor.return_value
        self.mock_execute = self.cursor.execute
        self.patcher = mock.patch(
            'dbt.adapters.snowflake.connections.snowflake.connector.connect'
        )
        self.snowflake = self.patcher.start()
        # Every patcher is stopped through addCleanup so it is undone even if a
        # later step of setUp raises (tearDown is skipped in that case).
        self.addCleanup(self.patcher.stop)

        # Create the Manifest.state_check patcher
        # NOTE(review): this decorator makes mock inject a fresh mock as an
        # extra trailing positional argument on every call. The class attribute
        # is replaced by a MagicMock (not a bound method), so presumably no
        # loader is passed in and ``self`` below is that injected mock rather
        # than a ManifestLoader -- confirm this is intended.
        @mock.patch('dbt.parser.manifest.ManifestLoader.build_manifest_state_check')
        def _mock_state_check(self):
            all_projects = self.all_projects
            return ManifestStateCheck(
                vars_hash=FileHash.from_contents('vars'),
                project_hashes={name: FileHash.from_contents(name) for name in all_projects},
                profile_hash=FileHash.from_contents('profile'),
            )
        self.load_state_check = mock.patch('dbt.parser.manifest.ManifestLoader.build_manifest_state_check')
        self.mock_state_check = self.load_state_check.start()
        self.addCleanup(self.load_state_check.stop)
        self.mock_state_check.side_effect = _mock_state_check

        self.snowflake.return_value = self.handle
        self.adapter = SnowflakeAdapter(self.config)
        self.adapter._macro_manifest_lazy = load_internal_manifest_macros(self.config)
        self.adapter.connections.query_header = MacroQueryStringSetter(self.config, self.adapter._macro_manifest_lazy)

        # Prefix every query with a fixed header so the expected SQL in the
        # tests is deterministic.
        self.qh_patch = mock.patch.object(self.adapter.connections.query_header, 'add')
        self.mock_query_header_add = self.qh_patch.start()
        self.addCleanup(self.qh_patch.stop)
        self.mock_query_header_add.side_effect = lambda q: '/* dbt */\n{}'.format(q)

        self.adapter.acquire_connection()
        inject_adapter(self.adapter, SnowflakePlugin)

    def tearDown(self):
        # we want a unique self.handle every time.
        # The patchers are stopped afterwards by the cleanups from setUp.
        self.adapter.cleanup_connections()

    def test_quoting_on_drop_schema(self):
        relation = SnowflakeAdapter.Relation.create(
            database='test_database',
            schema='test_schema',
            quote_policy=self.adapter.config.quoting
        )
        self.adapter.drop_schema(relation)

        self.mock_execute.assert_has_calls([
            mock.call('/* dbt */\ndrop schema if exists test_database."test_schema" cascade', None)
        ])

    def test_quoting_on_drop(self):
        relation = self.adapter.Relation.create(
            database='test_database',
            schema='test_schema',
            identifier='test_table',
            type='table',
            quote_policy=self.adapter.config.quoting,
        )
        self.adapter.drop_relation(relation)

        self.mock_execute.assert_has_calls([
            mock.call(
                '/* dbt */\ndrop table if exists test_database."test_schema".test_table cascade',
                None
            )
        ])

    def test_quoting_on_truncate(self):
        relation = self.adapter.Relation.create(
            database='test_database',
            schema='test_schema',
            identifier='test_table',
            type='table',
            quote_policy=self.adapter.config.quoting,
        )
        self.adapter.truncate_relation(relation)

        self.mock_execute.assert_has_calls([
            mock.call('/* dbt */\ntruncate table test_database."test_schema".test_table', None)
        ])

    def test_quoting_on_rename(self):
        from_relation = self.adapter.Relation.create(
            database='test_database',
            schema='test_schema',
            identifier='table_a',
            type='table',
            quote_policy=self.adapter.config.quoting,
        )
        to_relation = self.adapter.Relation.create(
            database='test_database',
            schema='test_schema',
            identifier='table_b',
            type='table',
            quote_policy=self.adapter.config.quoting,
        )

        self.adapter.rename_relation(
            from_relation=from_relation,
            to_relation=to_relation
        )
        self.mock_execute.assert_has_calls([
            mock.call(
                '/* dbt */\nalter table test_database."test_schema".table_a rename to test_database."test_schema".table_b',
                None
            )
        ])

    @contextmanager
    def current_warehouse(self, response):
        """Make ``select current_warehouse()`` report ``response`` inside the block.

        While active, the warehouse query gets a one-column description and a
        single ``[[response]]`` row. Any other SQL gets ``description = None``
        and the original ``fetchall`` return value. On exit, the cursor's
        ``fetchall`` return value, ``description`` and the ``execute`` side
        effect are restored.
        """
        # there is probably some elegant way built into mock.patch to do this
        fetchall_return = self.cursor.fetchall.return_value
        execute_side_effect = self.mock_execute.side_effect
        description = self.cursor.description

        def execute_effect(sql, *args, **kwargs):
            if sql == '/* dbt */\nselect current_warehouse() as warehouse':
                self.cursor.description = [['name']]
                self.cursor.fetchall.return_value = [[response]]
            else:
                self.cursor.description = None
                self.cursor.fetchall.return_value = fetchall_return
            return self.mock_execute.return_value

        self.mock_execute.side_effect = execute_effect
        try:
            yield
        finally:
            self.cursor.fetchall.return_value = fetchall_return
            self.mock_execute.side_effect = execute_side_effect
            # execute_effect overwrites description, so put it back as well.
            self.cursor.description = description

    def _strip_transactions(self):
        """Return the recorded ``execute`` calls minus bare BEGIN/COMMIT calls.

        A call is dropped only when it is exactly ``execute(<'BEGIN'|'COMMIT'>,
        None)`` with no keyword arguments.
        """
        result = []
        for call_args in self.mock_execute.call_args_list:
            args, kwargs = tuple(call_args)
            is_transactional = (
                len(kwargs) == 0 and
                len(args) == 2 and
                args[1] is None and
                args[0] in {'BEGIN', 'COMMIT'}
            )
            if not is_transactional:
                result.append(call_args)
        return result

    def test_pre_post_hooks_warehouse(self):
        """A configured warehouse is switched to and then switched back."""
        with self.current_warehouse('warehouse'):
            config = {'snowflake_warehouse': 'other_warehouse'}
            result = self.adapter.pre_model_hook(config)
            self.assertIsNotNone(result)
            calls = [
                mock.call('/* dbt */\nselect current_warehouse() as warehouse', None),
                mock.call('/* dbt */\nuse warehouse other_warehouse', None)
            ]
            self.mock_execute.assert_has_calls(calls)
            self.adapter.post_model_hook(config, result)
            calls.append(mock.call('/* dbt */\nuse warehouse warehouse', None))
            self.mock_execute.assert_has_calls(calls)

    def test_pre_post_hooks_no_warehouse(self):
        """Without a configured warehouse the hooks execute no SQL."""
        with self.current_warehouse('warehouse'):
            config = {}
            result = self.adapter.pre_model_hook(config)
            self.assertIsNone(result)
            self.mock_execute.assert_not_called()
            self.adapter.post_model_hook(config, result)
            self.mock_execute.assert_not_called()

    def test_cancel_open_connections_empty(self):
        self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0)

    def test_cancel_open_connections_master(self):
        # Only the current thread's 'master' connection exists: nothing to cancel.
        key = self.adapter.connections.get_thread_identifier()
        self.adapter.connections.thread_connections[key] = mock_connection('master')
        self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0)

    def test_cancel_open_connections_single(self):
        """The 'model' connection (another thread key) is aborted by session id."""
        master = mock_connection('master')
        model = mock_connection('model')
        model.handle.session_id = 42

        key = self.adapter.connections.get_thread_identifier()
        self.adapter.connections.thread_connections.update({
            key: master,
            1: model,
        })
        with mock.patch.object(self.adapter.connections, 'add_query') as add_query:
            query_result = mock.MagicMock()
            add_query.return_value = (None, query_result)

            self.assertEqual(
                len(list(self.adapter.cancel_open_connections())), 1)

            add_query.assert_called_once_with('select system$abort_session(42)')

    # In the tests below, ``conn.handle`` is accessed purely for its side
    # effect: connect is asserted not to have been called before that access,
    # and to have been called after it.

    def test_client_session_keep_alive_false_by_default(self):
        conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

        self.snowflake.assert_not_called()
        conn.handle
        self.snowflake.assert_has_calls([
            mock.call(
                account='test_account', autocommit=False,
                client_session_keep_alive=False, database='test_database',
                role=None, schema='public', user='******',
                warehouse='test_warehouse', private_key=None, application='dbt')
        ])

    def test_client_session_keep_alive_true(self):
        self.config.credentials = self.config.credentials.replace(
            client_session_keep_alive=True)
        self.adapter = SnowflakeAdapter(self.config)
        conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

        self.snowflake.assert_not_called()
        conn.handle
        self.snowflake.assert_has_calls([
            mock.call(
                account='test_account', autocommit=False,
                client_session_keep_alive=True, database='test_database',
                role=None, schema='public', user='******',
                warehouse='test_warehouse', private_key=None, application='dbt')
        ])

    def test_user_pass_authentication(self):
        self.config.credentials = self.config.credentials.replace(
            password='******',
        )
        self.adapter = SnowflakeAdapter(self.config)
        conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

        self.snowflake.assert_not_called()
        conn.handle
        self.snowflake.assert_has_calls([
            mock.call(
                account='test_account', autocommit=False,
                client_session_keep_alive=False, database='test_database',
                password='******', role=None, schema='public',
                user='******', warehouse='test_warehouse', private_key=None,
                application='dbt')
        ])

    def test_authenticator_user_pass_authentication(self):
        self.config.credentials = self.config.credentials.replace(
            password='******',
            authenticator='test_sso_url',
        )
        self.adapter = SnowflakeAdapter(self.config)
        conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

        self.snowflake.assert_not_called()
        conn.handle
        self.snowflake.assert_has_calls([
            mock.call(
                account='test_account', autocommit=False,
                client_session_keep_alive=False, database='test_database',
                password='******', role=None, schema='public',
                user='******', warehouse='test_warehouse',
                authenticator='test_sso_url', private_key=None,
                application='dbt', client_store_temporary_credential=True)
        ])

    def test_authenticator_externalbrowser_authentication(self):
        self.config.credentials = self.config.credentials.replace(
            authenticator='externalbrowser'
        )
        self.adapter = SnowflakeAdapter(self.config)
        conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

        self.snowflake.assert_not_called()
        conn.handle
        self.snowflake.assert_has_calls([
            mock.call(
                account='test_account', autocommit=False,
                client_session_keep_alive=False, database='test_database',
                role=None, schema='public', user='******',
                warehouse='test_warehouse', authenticator='externalbrowser',
                private_key=None, application='dbt', client_store_temporary_credential=True)
        ])

    def test_authenticator_oauth_authentication(self):
        self.config.credentials = self.config.credentials.replace(
            authenticator='oauth',
            token='my-oauth-token',
        )
        self.adapter = SnowflakeAdapter(self.config)
        conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

        self.snowflake.assert_not_called()
        conn.handle
        self.snowflake.assert_has_calls([
            mock.call(
                account='test_account', autocommit=False,
                client_session_keep_alive=False, database='test_database',
                role=None, schema='public', user='******',
                warehouse='test_warehouse', authenticator='oauth', token='my-oauth-token',
                private_key=None, application='dbt', client_store_temporary_credential=True)
        ])

    @mock.patch('dbt.adapters.snowflake.SnowflakeCredentials._get_private_key', return_value='test_key')
    def test_authenticator_private_key_authentication(self, mock_get_private_key):
        """The (patched) loaded private key is forwarded to ``connect``."""
        self.config.credentials = self.config.credentials.replace(
            private_key_path='/tmp/test_key.p8',
            private_key_passphrase='p@ssphr@se',
        )

        self.adapter = SnowflakeAdapter(self.config)
        conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

        self.snowflake.assert_not_called()
        conn.handle
        self.snowflake.assert_has_calls([
            mock.call(
                account='test_account', autocommit=False,
                client_session_keep_alive=False, database='test_database',
                role=None, schema='public', user='******',
                warehouse='test_warehouse', private_key='test_key',
                application='dbt')
        ])

    @mock.patch('dbt.adapters.snowflake.SnowflakeCredentials._get_private_key', return_value='test_key')
    def test_authenticator_private_key_authentication_no_passphrase(self, mock_get_private_key):
        """Same as above, but with no passphrase configured."""
        self.config.credentials = self.config.credentials.replace(
            private_key_path='/tmp/test_key.p8',
            private_key_passphrase=None,
        )

        self.adapter = SnowflakeAdapter(self.config)
        conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config')

        self.snowflake.assert_not_called()
        conn.handle
        self.snowflake.assert_has_calls([
            mock.call(
                account='test_account', autocommit=False,
                client_session_keep_alive=False, database='test_database',
                role=None, schema='public', user='******',
                warehouse='test_warehouse', private_key='test_key',
                application='dbt')
        ])
Exemplo n.º 3
0
class TestSnowflakeAdapter(unittest.TestCase):
    """Unit tests for ``SnowflakeAdapter`` backed by a mocked Snowflake connector.

    ``snowflake.connector.connect`` is patched in ``setUp`` to return a
    ``MagicMock`` connection. The tests therefore only inspect the SQL handed
    to the mocked cursor's ``execute`` and the keyword arguments passed to
    ``connect``.
    """

    def setUp(self):
        # Put the global flag back when the test finishes so STRICT_MODE=False
        # does not leak into unrelated tests.
        self.addCleanup(setattr, flags, 'STRICT_MODE',
                        getattr(flags, 'STRICT_MODE', False))
        flags.STRICT_MODE = False

        profile_cfg = {
            'outputs': {
                'test': {
                    'type': 'snowflake',
                    'account': 'test_account',
                    'user': '******',
                    'database': 'test_database',
                    'warehouse': 'test_warehouse',
                    'schema': 'public',
                },
            },
            'target': 'test',
        }

        project_cfg = {
            'name': 'X',
            'version': '0.1',
            'profile': 'test',
            'project-root': '/tmp/dbt/does-not-exist',
            'quoting': {
                'identifier': False,
                'schema': True,
            }
        }
        self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)

        self.handle = mock.MagicMock(
            spec=snowflake_connector.SnowflakeConnection)
        self.cursor = self.handle.cursor.return_value
        self.mock_execute = self.cursor.execute
        self.patcher = mock.patch(
            'dbt.adapters.snowflake.connections.snowflake.connector.connect')
        self.snowflake = self.patcher.start()
        # Registered as a cleanup (rather than only stopped in tearDown) so the
        # patch is undone even if a later step of setUp raises.
        self.addCleanup(self.patcher.stop)

        self.snowflake.return_value = self.handle
        self.adapter = SnowflakeAdapter(self.config)
        # patch our new adapter into the factory so macros behave
        inject_adapter('snowflake', self.adapter)

    def tearDown(self):
        # we want a unique self.handle every time.
        # The connector patch is stopped afterwards by the cleanup from setUp.
        self.adapter.cleanup_connections()

    def test_quoting_on_drop_schema(self):
        self.adapter.drop_schema(database='test_database',
                                 schema='test_schema')

        self.mock_execute.assert_has_calls([
            mock.call(
                'drop schema if exists test_database."test_schema" cascade',
                None)
        ])

    def test_quoting_on_drop(self):
        relation = self.adapter.Relation.create(
            database='test_database',
            schema='test_schema',
            identifier='test_table',
            type='table',
            quote_policy=self.adapter.config.quoting,
        )
        self.adapter.drop_relation(relation)

        self.mock_execute.assert_has_calls([
            mock.call(
                'drop table if exists test_database."test_schema".test_table cascade',
                None)
        ])

    def test_quoting_on_truncate(self):
        relation = self.adapter.Relation.create(
            database='test_database',
            schema='test_schema',
            identifier='test_table',
            type='table',
            quote_policy=self.adapter.config.quoting,
        )
        self.adapter.truncate_relation(relation)

        self.mock_execute.assert_has_calls([
            mock.call('truncate table test_database."test_schema".test_table',
                      None)
        ])

    def test_quoting_on_rename(self):
        from_relation = self.adapter.Relation.create(
            database='test_database',
            schema='test_schema',
            identifier='table_a',
            type='table',
            quote_policy=self.adapter.config.quoting,
        )
        to_relation = self.adapter.Relation.create(
            database='test_database',
            schema='test_schema',
            identifier='table_b',
            type='table',
            quote_policy=self.adapter.config.quoting,
        )

        self.adapter.rename_relation(from_relation=from_relation,
                                     to_relation=to_relation)
        self.mock_execute.assert_has_calls([
            mock.call(
                'alter table test_database."test_schema".table_a rename to test_database."test_schema".table_b',
                None)
        ])

    def test_cancel_open_connections_empty(self):
        self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0)

    def test_cancel_open_connections_master(self):
        # Only the 'master' connection is open, so nothing is cancelled.
        self.adapter.connections.in_use['master'] = mock.MagicMock()
        self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0)

    def test_cancel_open_connections_single(self):
        """A non-master connection is aborted by session id via 'master'."""
        master = mock.MagicMock()
        model = mock.MagicMock()
        model.handle.session_id = 42

        self.adapter.connections.in_use.update({
            'master': master,
            'model': model,
        })
        with mock.patch.object(self.adapter.connections,
                               'add_query') as add_query:
            query_result = mock.MagicMock()
            add_query.return_value = (None, query_result)

            self.assertEqual(len(list(self.adapter.cancel_open_connections())),
                             1)

            add_query.assert_called_once_with(
                'select system$abort_session(42)', 'master')

    def test_client_session_keep_alive_false_by_default(self):
        self.adapter.connections.get(name='new_connection_with_new_config')
        self.snowflake.assert_has_calls([
            mock.call(account='test_account',
                      autocommit=False,
                      client_session_keep_alive=False,
                      database='test_database',
                      role=None,
                      schema='public',
                      user='******',
                      warehouse='test_warehouse',
                      private_key=None)
        ])

    def test_client_session_keep_alive_true(self):
        self.config.credentials = self.config.credentials.incorporate(
            client_session_keep_alive=True)
        self.adapter = SnowflakeAdapter(self.config)
        self.adapter.connections.get(name='new_connection_with_new_config')

        self.snowflake.assert_has_calls([
            mock.call(account='test_account',
                      autocommit=False,
                      client_session_keep_alive=True,
                      database='test_database',
                      role=None,
                      schema='public',
                      user='******',
                      warehouse='test_warehouse',
                      private_key=None)
        ])

    def test_user_pass_authentication(self):
        self.config.credentials = self.config.credentials.incorporate(
            password='******')
        self.adapter = SnowflakeAdapter(self.config)
        self.adapter.connections.get(name='new_connection_with_new_config')

        self.snowflake.assert_has_calls([
            mock.call(account='test_account',
                      autocommit=False,
                      client_session_keep_alive=False,
                      database='test_database',
                      password='******',
                      role=None,
                      schema='public',
                      user='******',
                      warehouse='test_warehouse',
                      private_key=None)
        ])

    def test_authenticator_user_pass_authentication(self):
        self.config.credentials = self.config.credentials.incorporate(
            password='******', authenticator='test_sso_url')
        self.adapter = SnowflakeAdapter(self.config)
        self.adapter.connections.get(name='new_connection_with_new_config')

        self.snowflake.assert_has_calls([
            mock.call(account='test_account',
                      autocommit=False,
                      client_session_keep_alive=False,
                      database='test_database',
                      password='******',
                      role=None,
                      schema='public',
                      user='******',
                      warehouse='test_warehouse',
                      authenticator='test_sso_url',
                      private_key=None)
        ])

    def test_authenticator_externalbrowser_authentication(self):
        self.config.credentials = self.config.credentials.incorporate(
            authenticator='externalbrowser')
        self.adapter = SnowflakeAdapter(self.config)
        self.adapter.connections.get(name='new_connection_with_new_config')

        self.snowflake.assert_has_calls([
            mock.call(account='test_account',
                      autocommit=False,
                      client_session_keep_alive=False,
                      database='test_database',
                      role=None,
                      schema='public',
                      user='******',
                      warehouse='test_warehouse',
                      authenticator='externalbrowser',
                      private_key=None)
        ])

    # Use mock.patch like the rest of the class so no separate ``patch``
    # import is required.
    @mock.patch(
        'dbt.adapters.snowflake.SnowflakeConnectionManager._get_private_key',
        return_value='test_key')
    def test_authenticator_private_key_authentication(self,
                                                      mock_get_private_key):
        """The (patched) loaded private key is forwarded to ``connect``."""
        self.config.credentials = self.config.credentials.incorporate(
            private_key_path='/tmp/test_key.p8',
            private_key_passphrase='p@ssphr@se')

        self.adapter = SnowflakeAdapter(self.config)
        self.adapter.connections.get(name='new_connection_with_new_config')

        self.snowflake.assert_has_calls([
            mock.call(account='test_account',
                      autocommit=False,
                      client_session_keep_alive=False,
                      database='test_database',
                      role=None,
                      schema='public',
                      user='******',
                      warehouse='test_warehouse',
                      private_key='test_key')
        ])
Exemplo n.º 4
0
 def tearDown(self):
     """Release adapter connections and undo the connector patch.

     Runs after every test because we want a unique ``self.handle`` each time.
     """
     try:
         SnowflakeAdapter.cleanup_connections()
     finally:
         # Stop the patch even if connection cleanup raises; otherwise the
         # mocked connector would stay installed for later tests.
         self.patcher.stop()
Exemplo n.º 5
0
class TestSnowflakeAdapter(unittest.TestCase):
    def setUp(self):
        """Create a SnowflakeAdapter whose ``connect`` is replaced by a mock.

        ``snowflake.connector.connect`` (as seen from
        ``dbt.adapters.snowflake.impl``) is patched to return ``self.handle``,
        a ``MagicMock`` specced on ``SnowflakeConnection``. The mock cursor's
        ``execute`` is exposed as ``self.mock_execute`` for SQL assertions.
        A connection is opened eagerly so tests can assert on the connect
        call without opening one themselves.
        """
        flags.STRICT_MODE = False

        # NOTE: 'test_databse' (sic) is the value the connect() assertions in
        # this class expect; change them together if it is ever corrected.
        profile_cfg = {
            'outputs': {
                'test': {
                    'type': 'snowflake',
                    'account': 'test_account',
                    'user': '******',
                    'password': '******',
                    'database': 'test_databse',
                    'warehouse': 'test_warehouse',
                    'schema': 'public',
                },
            },
            'target': 'test',
        }

        project_cfg = {
            'name': 'X',
            'version': '0.1',
            'profile': 'test',
            'project-root': '/tmp/dbt/does-not-exist',
            'quoting': {
                'identifier': False,
                'schema': True,
            }
        }
        self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)

        self.handle = mock.MagicMock(
            spec=snowflake_connector.SnowflakeConnection)
        self.cursor = self.handle.cursor.return_value
        self.mock_execute = self.cursor.execute
        self.patcher = mock.patch(
            'dbt.adapters.snowflake.impl.snowflake.connector.connect')
        self.snowflake = self.patcher.start()
        try:
            self.snowflake.return_value = self.handle
            self.adapter = SnowflakeAdapter(self.config)
            self.adapter.get_connection()
        except Exception:
            # unittest does not call tearDown when setUp raises, so undo the
            # patch here instead of leaking it into subsequent tests.
            self.patcher.stop()
            raise

    def tearDown(self):
        """Close the adapter's connections, then stop the connect patch.

        ``self.patcher.stop()`` runs in a ``finally`` clause so that a failure
        in ``cleanup_connections()`` cannot leave the patched ``connect``
        active for later tests.
        """
        # we want a unique self.handle every time.
        try:
            self.adapter.cleanup_connections()
        finally:
            self.patcher.stop()

    def test_quoting_on_drop_schema(self):
        """drop_schema quotes the schema name (project quoting.schema=True)."""
        self.adapter.drop_schema(schema='test_schema')

        expected_sql = 'drop schema if exists "test_schema" cascade'
        self.mock_execute.assert_has_calls([mock.call(expected_sql, None)])

    def test_quoting_on_drop(self):
        """drop quotes the schema but leaves the table identifier bare."""
        self.adapter.drop(
            schema='test_schema',
            relation='test_table',
            relation_type='table',
        )

        expected_sql = 'drop table if exists "test_schema".test_table cascade'
        self.mock_execute.assert_has_calls([mock.call(expected_sql, None)])

    def test_quoting_on_truncate(self):
        """truncate quotes the schema but leaves the table identifier bare."""
        self.adapter.truncate(table='test_table', schema='test_schema')

        expected = mock.call('truncate table "test_schema".test_table', None)
        self.mock_execute.assert_has_calls([expected])

    def test_quoting_on_rename(self):
        """rename quotes the schema; the table names stay unquoted."""
        self.adapter.rename(
            from_name='table_a',
            to_name='table_b',
            schema='test_schema',
        )

        expected_sql = 'alter table "test_schema".table_a rename to table_b'
        self.mock_execute.assert_has_calls([mock.call(expected_sql, None)])

    def test_client_session_keep_alive_false_by_default(self):
        """The connection opened in setUp uses keep-alive False by default."""
        expected_connect = mock.call(
            user='******',
            password='******',
            account='test_account',
            warehouse='test_warehouse',
            database='test_databse',
            schema='public',
            role=None,
            autocommit=False,
            client_session_keep_alive=False,
        )
        self.snowflake.assert_has_calls([expected_connect])

    def test_client_session_keep_alive_true(self):
        """client_session_keep_alive=True in credentials reaches connect."""
        keep_alive_creds = self.config.credentials.incorporate(
            client_session_keep_alive=True)
        self.config.credentials = keep_alive_creds
        self.adapter = SnowflakeAdapter(self.config)
        self.adapter.get_connection(name='new_connection_with_new_config')

        expected_connect = mock.call(
            user='******',
            password='******',
            account='test_account',
            warehouse='test_warehouse',
            database='test_databse',
            schema='public',
            role=None,
            autocommit=False,
            client_session_keep_alive=True,
        )
        self.snowflake.assert_has_calls([expected_connect])