Exemple #1
0
 def test_append_comment(self):
     """With ``append: True`` the comment follows the query on its own line.

     The expected SQL drops the query's last character and re-adds ``;``
     after the comment, so this assumes ``self.query`` ends with ``;``.
     """
     comment_cfg = {'comment': 'executed by dbt', 'append': True}
     self.project_cfg.update({'query-comment': comment_cfg})
     config = config_from_parts_or_dicts(self.project_cfg, self.profile_cfg)
     setter = MacroQueryStringSetter(config, mock.MagicMock(macros={}))
     rendered = setter.add(self.query)
     expected = '{}\n/* executed by dbt */;'.format(self.query[:-1])
     self.assertEqual(rendered, expected)
Exemple #2
0
    def setUp(self):
        """Build a Snowflake adapter wired to mocked connector objects.

        Builds a config whose 'query-comment' is the plain string ``'dbt'``
        and checks it is parsed as a non-appended comment. Then it:

        - patches the Snowflake connector's ``connect`` to return a
          ``MagicMock`` handle;
        - patches ``dbt.parser.manifest.make_parse_result``;
        - stubs the connection manager's query-header ``add`` so every query
          is prefixed with ``/* dbt */``;
        - acquires a connection and registers the adapter via
          ``inject_adapter``.

        NOTE(review): the patchers started here are not stopped in this
        method; presumably tearDown stops them — confirm.
        """
        flags.STRICT_MODE = False

        profile_cfg = {
            'outputs': {
                'test': {
                    'type': 'snowflake',
                    'account': 'test_account',
                    'user': '******',
                    'database': 'test_database',
                    'warehouse': 'test_warehouse',
                    'schema': 'public',
                },
            },
            'target': 'test',
        }

        project_cfg = {
            'name': 'X',
            'version': '0.1',
            'profile': 'test',
            'project-root': '/tmp/dbt/does-not-exist',
            'quoting': {
                'identifier': False,
                'schema': True,
            },
            'query-comment': 'dbt',
        }
        self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
        # The string form of 'query-comment' should yield comment 'dbt'
        # with append disabled (i.e. the comment is prepended).
        self.assertEqual(self.config.query_comment.comment, 'dbt')
        self.assertEqual(self.config.query_comment.append, False)

        self.handle = mock.MagicMock(
            spec=snowflake_connector.SnowflakeConnection)
        # Shortcuts to the cursor returned by handle.cursor() and its execute.
        self.cursor = self.handle.cursor.return_value
        self.mock_execute = self.cursor.execute
        self.patcher = mock.patch(
            'dbt.adapters.snowflake.connections.snowflake.connector.connect'
        )
        # self.snowflake is the patched `connect` callable.
        self.snowflake = self.patcher.start()

        self.load_patch = mock.patch('dbt.parser.manifest.make_parse_result')
        self.mock_parse_result = self.load_patch.start()
        self.mock_parse_result.return_value = ParseResult.rpc()

        # Make connect(...) hand back the mocked connection handle.
        self.snowflake.return_value = self.handle
        self.adapter = SnowflakeAdapter(self.config)

        self.adapter.connections.query_header = MacroQueryStringSetter(self.config, mock.MagicMock(macros={}))

        # Replace the header's `add` so queries get a fixed, predictable prefix.
        self.qh_patch = mock.patch.object(self.adapter.connections.query_header, 'add')
        self.mock_query_header_add = self.qh_patch.start()
        self.mock_query_header_add.side_effect = lambda q: '/* dbt */\n{}'.format(q)

        self.adapter.acquire_connection()
        inject_adapter(self.adapter)
Exemple #3
0
    def setUp(self):
        """Build a Postgres adapter wired to a mocked ``psycopg2`` module.

        Steps, in order:

        - patch ``dbt.adapters.postgres.connections.psycopg2`` so that
          ``connect`` returns a ``MagicMock`` handle;
        - stub the connection manager's query-header ``add`` to prefix
          ``/* dbt */``;
        - acquire a connection and register the adapter via
          ``inject_adapter``;
        - patch ``dbt.parser.manifest.make_parse_result``.

        NOTE(review): the patchers started here are not stopped in this
        method; presumably tearDown stops them — confirm.
        """
        flags.STRICT_MODE = False

        # Stored on self (not just a local), presumably so tests can
        # reference the target settings.
        self.target_dict = {
            'type': 'postgres',
            'dbname': 'postgres',
            'user': '******',
            'host': 'thishostshouldnotexist',
            'pass': '******',
            'port': 5432,
            'schema': 'public'
        }

        profile_cfg = {
            'outputs': {
                'test': self.target_dict,
            },
            'target': 'test'
        }
        project_cfg = {
            'name': 'X',
            'version': '0.1',
            'profile': 'test',
            'project-root': '/tmp/dbt/does-not-exist',
            'quoting': {
                'identifier': False,
                'schema': True,
            },
        }

        self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)

        self.handle = mock.MagicMock(spec=psycopg2_extensions.connection)
        # Shortcuts to the cursor returned by handle.cursor() and its execute.
        self.cursor = self.handle.cursor.return_value
        self.mock_execute = self.cursor.execute
        # The whole psycopg2 module reference in the connections module is
        # replaced with a mock.
        self.patcher = mock.patch('dbt.adapters.postgres.connections.psycopg2')
        self.psycopg2 = self.patcher.start()

        self.psycopg2.connect.return_value = self.handle
        self.adapter = PostgresAdapter(self.config)
        self.adapter.connections.query_header = MacroQueryStringSetter(
            self.config, mock.MagicMock(macros={}))

        # Replace the header's `add` so queries get a fixed, predictable prefix.
        self.qh_patch = mock.patch.object(
            self.adapter.connections.query_header, 'add')
        self.mock_query_header_add = self.qh_patch.start()
        self.mock_query_header_add.side_effect = lambda q: '/* dbt */\n{}'.format(
            q)
        self.adapter.acquire_connection()
        inject_adapter(self.adapter)

        # Started after the adapter is injected; nothing above appears to
        # depend on it — confirm if reordering.
        self.load_patch = mock.patch('dbt.parser.manifest.make_parse_result')
        self.mock_parse_result = self.load_patch.start()
        self.mock_parse_result.return_value = ParseResult.rpc()
Exemple #4
0
    def setUp(self):
        """Build an Oracle adapter wired to a mocked ``oracle`` module.

        Steps, in order:

        - patch ``dbt.adapters.oracle.connections.oracle`` so that
          ``connect`` returns a ``MagicMock`` handle;
        - patch ``dbt.parser.manifest.make_parse_result``;
        - load the internal macro manifest and use it for the query header;
        - stub the header's ``add`` to prefix ``/* dbt */``;
        - acquire a connection and register the adapter with
          ``OraclePlugin``.

        NOTE(review): the patchers started here are not stopped in this
        method; presumably tearDown stops them — confirm.
        """
        # Stored on self (not just a local), presumably so tests can
        # reference the target settings.
        self.target_dict = {
            'type': 'oracle',
            'service': 'test',
            'username': '******',
            'host': 'thishostshouldnotexist',
            'password': '******',
            'port': 1521,
            'schema': 'public'
        }

        profile_cfg = {
            'outputs': {
                'test': self.target_dict,
            },
            'target': 'test'
        }
        project_cfg = {
            'name': 'X',
            'version': '0.1',
            'profile': 'test',
            'project-root': '/tmp/dbt/does-not-exist',
            'quoting': {
                'identifier': False,
                'schema': True,
            },
            'config-version': 2,
        }

        self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)

        self.handle = mock.MagicMock(spec=cx_Oracle.Connection)
        # Shortcuts to the cursor returned by handle.cursor() and its execute.
        self.cursor = self.handle.cursor.return_value
        self.mock_execute = self.cursor.execute
        self.patcher = mock.patch('dbt.adapters.oracle.connections.oracle')
        self.oracle = self.patcher.start()

        self.load_patch = mock.patch('dbt.parser.manifest.make_parse_result')
        self.mock_parse_result = self.load_patch.start()
        self.mock_parse_result.return_value = ParseResult.rpc()

        self.oracle.connect.return_value = self.handle
        self.adapter = OracleAdapter(self.config)
        # Real internal macros (not a MagicMock) back the query header here.
        self.adapter._macro_manifest_lazy = load_internal_manifest_macros(self.config)
        self.adapter.connections.query_header = MacroQueryStringSetter(self.config, self.adapter._macro_manifest_lazy)

        # Replace the header's `add` so queries get a fixed, predictable prefix.
        self.qh_patch = mock.patch.object(self.adapter.connections.query_header, 'add')
        self.mock_query_header_add = self.qh_patch.start()
        self.mock_query_header_add.side_effect = lambda q: '/* dbt */\n{}'.format(q)
        self.adapter.acquire_connection()
        inject_adapter(self.adapter, OraclePlugin)
    def get_adapter(self, target):
        """Return a BigQuery adapter configured for profile target *target*.

        Works on shallow copies of ``self.project_cfg`` and
        ``self.raw_profile``, so the originals are not mutated. The
        query-header ``add`` is patched so every query is prefixed with
        ``/* dbt */``; the patcher and its mock are stored on
        ``self.qh_patch`` and ``self.mock_query_header_add``.
        """
        project_cfg = self.project_cfg.copy()
        profile_cfg = self.raw_profile.copy()
        profile_cfg['target'] = target

        cfg = config_from_parts_or_dicts(project=project_cfg,
                                         profile=profile_cfg)
        adapter = BigQueryAdapter(cfg)

        adapter.connections.query_header = MacroQueryStringSetter(
            cfg, MagicMock(macros={}))

        def _prefix_with_dbt_comment(query):
            return '/* dbt */\n{}'.format(query)

        self.qh_patch = patch.object(adapter.connections.query_header, 'add')
        self.mock_query_header_add = self.qh_patch.start()
        self.mock_query_header_add.side_effect = _prefix_with_dbt_comment

        inject_adapter(adapter, BigQueryPlugin)
        return adapter
Exemple #6
0
 def set_query_header(self, manifest: Manifest) -> None:
     """Install a ``MacroQueryStringSetter`` built from the profile and *manifest*."""
     setter = MacroQueryStringSetter(self.profile, manifest)
     self.query_header = setter
Exemple #7
0
class BaseConnectionManager(metaclass=abc.ABCMeta):
    """Abstract base class for an adapter's connection manager.

    At most one ``Connection`` is tracked per (process id, thread id) pair in
    ``thread_connections``; access to that dict is guarded by ``lock``.

    Methods to implement:
        - exception_handler
        - cancel_open
        - open
        - begin
        - commit
        - clear_transaction (a default implementation is provided)
        - execute

    You must also set the 'TYPE' class attribute with a class-unique constant
    string.
    """
    TYPE: str = NotImplemented

    def __init__(self, profile: AdapterRequiredConfig):
        self.profile = profile
        # Keyed by get_thread_identifier(); guarded by self.lock.
        self.thread_connections: Dict[Hashable, Connection] = {}
        # Re-entrant, so a method holding it can call other locking helpers
        # (release() calls get_if_exists() while holding it).
        self.lock: RLock = dbt.flags.MP_CONTEXT.RLock()
        # Set by set_query_header(); while None, _add_query_comment() returns
        # the SQL unchanged.
        self.query_header: Optional[MacroQueryStringSetter] = None

    def set_query_header(self, manifest: Manifest) -> None:
        """Install a ``MacroQueryStringSetter`` built from *manifest*."""
        self.query_header = MacroQueryStringSetter(self.profile, manifest)

    @staticmethod
    def get_thread_identifier() -> Hashable:
        """Return a ``(pid, thread ident)`` key for the calling thread."""
        # note that get_ident() may be re-used, but we should never experience
        # that within a single process
        return (os.getpid(), get_ident())

    def get_thread_connection(self) -> Connection:
        """Return the calling thread's connection.

        :raises dbt.exceptions.InvalidConnectionException: if the calling
            thread has no registered connection.
        """
        key = self.get_thread_identifier()
        with self.lock:
            if key not in self.thread_connections:
                raise dbt.exceptions.InvalidConnectionException(
                    key, list(self.thread_connections))
            return self.thread_connections[key]

    def set_thread_connection(self, conn):
        """Register *conn* as the calling thread's connection.

        :raises dbt.exceptions.InternalException: if the calling thread
            already has a registered connection.
        """
        key = self.get_thread_identifier()
        # Hold the lock like every other accessor of thread_connections:
        # cleanup_all() iterates the dict under the lock, and an unlocked
        # insert could change its size mid-iteration.
        with self.lock:
            if key in self.thread_connections:
                raise dbt.exceptions.InternalException(
                    'In set_thread_connection, existing connection exists '
                    'for {}'.format(key))
            self.thread_connections[key] = conn

    def get_if_exists(self) -> Optional[Connection]:
        """Return the calling thread's connection, or None if it has none."""
        key = self.get_thread_identifier()
        with self.lock:
            return self.thread_connections.get(key)

    def clear_thread_connection(self) -> None:
        """Forget the calling thread's connection, if any (no close/rollback)."""
        key = self.get_thread_identifier()
        with self.lock:
            if key in self.thread_connections:
                del self.thread_connections[key]

    def clear_transaction(self) -> None:
        """Clear any existing transactions.

        Rolls back an open transaction on the calling thread's connection,
        then runs begin() followed by commit().
        """
        conn = self.get_thread_connection()
        if conn is not None:
            if conn.transaction_open:
                self._rollback(conn)
            self.begin()
            self.commit()

    @abc.abstractmethod
    def exception_handler(self, sql: str) -> ContextManager:
        """Create a context manager that handles exceptions caused by database
        interactions.

        :param str sql: The SQL string that the block inside the context
            manager is executing.
        :return: A context manager that handles exceptions raised by the
            underlying database.
        """
        raise dbt.exceptions.NotImplementedException(
            '`exception_handler` is not implemented for this adapter!')

    def set_connection_name(self, name: Optional[str] = None) -> Connection:
        """Give the calling thread's connection *name*, creating it if needed.

        If the thread has no connection, a new one in the INIT state is
        registered. An open connection that already has *name* is returned
        as-is. An open connection with another name is re-used under the new
        name. Otherwise the handle is replaced with a ``LazyHandle`` wrapping
        ``self.open``.

        :param name: Connection name; ``None`` selects ``'master'``.
        :raises dbt.exceptions.CompilerException: if *name* is not a string.
        """
        conn_name: str
        if name is None:
            # if a name isn't specified, we'll re-use a single handle
            # named 'master'
            conn_name = 'master'
        else:
            if not isinstance(name, str):
                raise dbt.exceptions.CompilerException(
                    f'For connection name, got {name} - not a string!')
            # Redundant at runtime after the check above; kept for type
            # narrowing.
            assert isinstance(name, str)
            conn_name = name

        conn = self.get_if_exists()
        if conn is None:
            conn = Connection(type=Identifier(self.TYPE),
                              name=None,
                              state=ConnectionState.INIT,
                              transaction_open=False,
                              handle=None,
                              credentials=self.profile.credentials)
            self.set_thread_connection(conn)

        if conn.name == conn_name and conn.state == 'open':
            return conn

        logger.debug('Acquiring new {} connection "{}".'.format(
            self.TYPE, conn_name))

        if conn.state == 'open':
            logger.debug(
                'Re-using an available connection from the pool (formerly {}).'
                .format(conn.name))
        else:
            conn.handle = LazyHandle(self.open)

        conn.name = conn_name
        return conn

    @abc.abstractmethod
    def cancel_open(self) -> Optional[List[str]]:
        """Cancel all open connections on the adapter. (passable)"""
        raise dbt.exceptions.NotImplementedException(
            '`cancel_open` is not implemented for this adapter!')

    # `abc.abstractclassmethod` is deprecated; stacking classmethod over
    # abstractmethod is the documented equivalent.
    @classmethod
    @abc.abstractmethod
    def open(cls, connection: Connection) -> Connection:
        """Open the given connection on the adapter and return it.

        This may mutate the given connection (in particular, its state and its
        handle).

        This should be thread-safe, or hold the lock if necessary. The given
        connection should not be in either in_use or available.
        """
        raise dbt.exceptions.NotImplementedException(
            '`open` is not implemented for this adapter!')

    def release(self) -> None:
        """Release the calling thread's connection while keeping it tracked.

        An open connection has any open transaction rolled back. A connection
        in any other state is closed. If the rollback or close raises, the
        connection is dropped from ``thread_connections`` and the error is
        re-raised.
        """
        with self.lock:
            conn = self.get_if_exists()
            if conn is None:
                return

        try:
            if conn.state == 'open':
                if conn.transaction_open is True:
                    self._rollback(conn)
            else:
                self.close(conn)
        except Exception:
            # if rollback or close failed, remove our busted connection
            self.clear_thread_connection()
            raise

    def cleanup_all(self) -> None:
        """Close every tracked connection, then forget all of them."""
        with self.lock:
            for connection in self.thread_connections.values():
                if connection.state not in {'closed', 'init'}:
                    logger.debug("Connection '{}' was left open.".format(
                        connection.name))
                else:
                    logger.debug("Connection '{}' was properly closed.".format(
                        connection.name))
                self.close(connection)

            # garbage collect these connections
            self.thread_connections.clear()

    @abc.abstractmethod
    def begin(self) -> None:
        """Begin a transaction. (passable)"""
        raise dbt.exceptions.NotImplementedException(
            '`begin` is not implemented for this adapter!')

    @abc.abstractmethod
    def commit(self) -> None:
        """Commit a transaction. (passable)"""
        raise dbt.exceptions.NotImplementedException(
            '`commit` is not implemented for this adapter!')

    @classmethod
    def _rollback_handle(cls, connection: Connection) -> None:
        """Perform the actual rollback operation.

        Best-effort: any exception from the handle is logged at debug level
        and swallowed.
        """
        try:
            connection.handle.rollback()
        except Exception:
            logger.debug('Failed to rollback {}'.format(connection.name),
                         exc_info=True)

    @classmethod
    def _close_handle(cls, connection: Connection) -> None:
        """Perform the actual close operation."""
        # On windows, sometimes connection handles don't have a close() attr.
        if hasattr(connection.handle, 'close'):
            logger.debug('On {}: Close'.format(connection.name))
            connection.handle.close()
        else:
            logger.debug('On {}: No close available on handle'.format(
                connection.name))

    @classmethod
    def _rollback(cls, connection: Connection) -> None:
        """Roll back the given connection and mark it as having no transaction.

        :raises dbt.exceptions.CompilerException: in strict mode, if
            *connection* is not a ``Connection``.
        :raises dbt.exceptions.InternalException: if the connection has no
            open transaction.
        """
        if dbt.flags.STRICT_MODE:
            if not isinstance(connection, Connection):
                raise dbt.exceptions.CompilerException(
                    f'In _rollback, got {connection} - not a Connection!')

        if connection.transaction_open is False:
            raise dbt.exceptions.InternalException(
                'Tried to rollback transaction on connection "{}", but '
                'it does not have one open!'.format(connection.name))

        logger.debug('On {}: ROLLBACK'.format(connection.name))
        cls._rollback_handle(connection)

        connection.transaction_open = False

    @classmethod
    def close(cls, connection: Connection) -> Connection:
        """Close *connection* (rolling back any open transaction) and return it.

        Connections already in the CLOSED or INIT state are returned
        unchanged.

        :raises dbt.exceptions.CompilerException: in strict mode, if
            *connection* is not a ``Connection``.
        """
        if dbt.flags.STRICT_MODE:
            if not isinstance(connection, Connection):
                raise dbt.exceptions.CompilerException(
                    f'In close, got {connection} - not a Connection!')

        # if the connection is in closed or init, there's nothing to do
        if connection.state in {ConnectionState.CLOSED, ConnectionState.INIT}:
            return connection

        if connection.transaction_open and connection.handle:
            cls._rollback_handle(connection)
        connection.transaction_open = False

        cls._close_handle(connection)
        connection.state = ConnectionState.CLOSED

        return connection

    def commit_if_has_connection(self) -> None:
        """If the named connection exists, commit the current transaction."""
        connection = self.get_if_exists()
        if connection:
            self.commit()

    def _add_query_comment(self, sql: str) -> str:
        """Return *sql* with the query header applied, if one is configured."""
        if self.query_header is None:
            return sql
        return self.query_header.add(sql)

    @abc.abstractmethod
    def execute(self,
                sql: str,
                auto_begin: bool = False,
                fetch: bool = False) -> Tuple[str, agate.Table]:
        """Execute the given SQL.

        :param str sql: The sql to execute.
        :param bool auto_begin: If set, and dbt is not currently inside a
            transaction, automatically begin one.
        :param bool fetch: If set, fetch results.
        :return: A tuple of the status and the results (empty if fetch=False).
        :rtype: Tuple[str, agate.Table]
        """
        raise dbt.exceptions.NotImplementedException(
            '`execute` is not implemented for this adapter!')
Exemple #8
0
    def setUp(self):
        """Build a Snowflake adapter wired to mocked connector objects.

        Builds a config whose 'query-comment' is the plain string ``'dbt'``
        and checks it is parsed as a non-appended comment. Then it:

        - patches the Snowflake connector's ``connect`` to return a
          ``MagicMock`` handle;
        - patches ``ManifestLoader.build_manifest_state_check`` with a
          side-effect that returns a fixed ``ManifestStateCheck``;
        - wires the query header to the internal macro manifest and stubs
          its ``add`` to prefix ``/* dbt */``;
        - acquires a connection and registers the adapter with
          ``SnowflakePlugin``.

        NOTE(review): the patchers started here are not stopped in this
        method; presumably tearDown stops them — confirm.
        """
        flags.STRICT_MODE = False

        profile_cfg = {
            'outputs': {
                'test': {
                    'type': 'snowflake',
                    'account': 'test_account',
                    'user': '******',
                    'database': 'test_database',
                    'warehouse': 'test_warehouse',
                    'schema': 'public',
                },
            },
            'target': 'test',
        }

        project_cfg = {
            'name': 'X',
            'version': '0.1',
            'profile': 'test',
            'project-root': '/tmp/dbt/does-not-exist',
            'quoting': {
                'identifier': False,
                'schema': True,
            },
            'query-comment': 'dbt',
            'config-version': 2,
        }
        self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
        # The string form of 'query-comment' should yield comment 'dbt'
        # with append disabled (i.e. the comment is prepended).
        self.assertEqual(self.config.query_comment.comment, 'dbt')
        self.assertEqual(self.config.query_comment.append, False)

        self.handle = mock.MagicMock(
            spec=snowflake_connector.SnowflakeConnection)
        # Shortcuts to the cursor returned by handle.cursor() and its execute.
        self.cursor = self.handle.cursor.return_value
        self.mock_execute = self.cursor.execute
        self.patcher = mock.patch(
            'dbt.adapters.snowflake.connections.snowflake.connector.connect'
        )
        # self.snowflake is the patched `connect` callable.
        self.snowflake = self.patcher.start()

        # Create the Manifest.state_check patcher
        # NOTE(review): `_mock_state_check` becomes the side_effect of the
        # patched method below. Its own @mock.patch decorator appears to be
        # what supplies the single positional argument bound to `self` (a
        # fresh MagicMock), because a MagicMock class attribute is not bound
        # to the loader instance when called. If so, `self` is NOT the
        # ManifestLoader and iterating `all_projects` yields nothing — confirm
        # before removing the decorator or relying on `project_hashes`.
        # `config` below is unused.
        @mock.patch('dbt.parser.manifest.ManifestLoader.build_manifest_state_check')
        def _mock_state_check(self):
            config = self.root_project
            all_projects = self.all_projects
            return ManifestStateCheck(
                vars_hash=FileHash.from_contents('vars'),
                project_hashes={name: FileHash.from_contents(name) for name in all_projects},
                profile_hash=FileHash.from_contents('profile'),
            )
        self.load_state_check = mock.patch('dbt.parser.manifest.ManifestLoader.build_manifest_state_check')
        self.mock_state_check = self.load_state_check.start()
        self.mock_state_check.side_effect = _mock_state_check

        # Make connect(...) hand back the mocked connection handle.
        self.snowflake.return_value = self.handle
        self.adapter = SnowflakeAdapter(self.config)
        # Real internal macros (not a MagicMock) back the query header here.
        self.adapter._macro_manifest_lazy = load_internal_manifest_macros(self.config)
        self.adapter.connections.query_header = MacroQueryStringSetter(self.config, self.adapter._macro_manifest_lazy)

        # Replace the header's `add` so queries get a fixed, predictable prefix.
        self.qh_patch = mock.patch.object(self.adapter.connections.query_header, 'add')
        self.mock_query_header_add = self.qh_patch.start()
        self.mock_query_header_add.side_effect = lambda q: '/* dbt */\n{}'.format(q)

        self.adapter.acquire_connection()
        inject_adapter(self.adapter, SnowflakePlugin)
Exemple #9
0
 def test_disable_query_comment(self):
     """An empty 'query-comment' leaves the query text untouched."""
     self.project_cfg.update({'query-comment': ''})
     setter = MacroQueryStringSetter(
         config_from_parts_or_dicts(self.project_cfg, self.profile_cfg),
         mock.MagicMock(macros={}),
     )
     result = setter.add(self.query)
     self.assertEqual(result, self.query)
Exemple #10
0
 def test_comment_should_prepend_query_by_default(self):
     """With the default config, a ``/* ... */`` line is prepended to the query."""
     config = config_from_parts_or_dicts(self.project_cfg, self.profile_cfg)
     query_header = MacroQueryStringSetter(config,
                                           mock.MagicMock(macros={}))
     sql = query_header.add(self.query)
     # Raw f-string: '\*' and '\/' in a plain literal are invalid escape
     # sequences (SyntaxWarning on Python 3.12+). re.escape makes the query
     # match literally even if it contains regex metacharacters.
     pattern = rf'^/\*.*\*/\n{re.escape(self.query)}$'
     # assertRegex uses re.search; the leading '^' keeps it anchored at the
     # start, and failures report the text and pattern.
     self.assertRegex(sql, pattern)
Exemple #11
0
 def set_query_header(self, manifest=None) -> None:
     """Install the query-comment setter for this connection manager.

     Without a *manifest*, a ``QueryStringSetter`` built from the profile
     alone is used. Otherwise a ``MacroQueryStringSetter`` built from the
     profile and *manifest* is used.
     """
     if manifest is None:
         self.query_header = QueryStringSetter(self.profile)
         return
     self.query_header = MacroQueryStringSetter(self.profile, manifest)
    def setUp(self):
        """Build a Postgres adapter wired to a mocked ``psycopg2`` module.

        Steps, in order:

        - patch ``dbt.adapters.postgres.connections.psycopg2`` so that
          ``connect`` returns a ``MagicMock`` handle;
        - patch ``ManifestLoader.build_manifest_state_check`` with a
          side-effect returning a fixed ``ManifestStateCheck``;
        - wire the query header to the internal macro manifest and stub its
          ``add`` to prefix ``/* dbt */``;
        - acquire a connection and register the adapter with
          ``PostgresPlugin``.

        NOTE(review): the patchers started here are not stopped in this
        method; presumably tearDown stops them — confirm.
        """
        flags.STRICT_MODE = False

        # Stored on self (not just a local), presumably so tests can
        # reference the target settings.
        self.target_dict = {
            'type': 'postgres',
            'dbname': 'postgres',
            'user': '******',
            'host': 'thishostshouldnotexist',
            'pass': '******',
            'port': 5432,
            'schema': 'public'
        }

        profile_cfg = {
            'outputs': {
                'test': self.target_dict,
            },
            'target': 'test'
        }
        project_cfg = {
            'name': 'X',
            'version': '0.1',
            'profile': 'test',
            'project-root': '/tmp/dbt/does-not-exist',
            'quoting': {
                'identifier': False,
                'schema': True,
            },
            'config-version': 2,
        }

        self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)

        self.handle = mock.MagicMock(spec=psycopg2_extensions.connection)
        # Shortcuts to the cursor returned by handle.cursor() and its execute.
        self.cursor = self.handle.cursor.return_value
        self.mock_execute = self.cursor.execute
        # The whole psycopg2 module reference in the connections module is
        # replaced with a mock.
        self.patcher = mock.patch('dbt.adapters.postgres.connections.psycopg2')
        self.psycopg2 = self.patcher.start()

        # Create the Manifest.state_check patcher
        # NOTE(review): `_mock_state_check` becomes the side_effect of the
        # patched method below. Its own @mock.patch decorator appears to be
        # what supplies the single positional argument bound to `self` (a
        # fresh MagicMock), because a MagicMock class attribute is not bound
        # to the loader instance when called. If so, `self` is NOT the
        # ManifestLoader and iterating `all_projects` yields nothing — confirm
        # before removing the decorator or relying on `project_hashes`.
        # `config` below is unused.
        @mock.patch(
            'dbt.parser.manifest.ManifestLoader.build_manifest_state_check')
        def _mock_state_check(self):
            config = self.root_project
            all_projects = self.all_projects
            return ManifestStateCheck(
                vars_hash=FileHash.from_contents('vars'),
                project_hashes={
                    name: FileHash.from_contents(name)
                    for name in all_projects
                },
                profile_hash=FileHash.from_contents('profile'),
            )

        self.load_state_check = mock.patch(
            'dbt.parser.manifest.ManifestLoader.build_manifest_state_check')
        self.mock_state_check = self.load_state_check.start()
        self.mock_state_check.side_effect = _mock_state_check

        self.psycopg2.connect.return_value = self.handle
        self.adapter = PostgresAdapter(self.config)
        # Real internal macros (not a MagicMock) back the query header here.
        self.adapter._macro_manifest_lazy = load_internal_manifest_macros(
            self.config)
        self.adapter.connections.query_header = MacroQueryStringSetter(
            self.config, self.adapter._macro_manifest_lazy)

        # Replace the header's `add` so queries get a fixed, predictable prefix.
        self.qh_patch = mock.patch.object(
            self.adapter.connections.query_header, 'add')
        self.mock_query_header_add = self.qh_patch.start()
        self.mock_query_header_add.side_effect = lambda q: '/* dbt */\n{}'.format(
            q)
        self.adapter.acquire_connection()
        inject_adapter(self.adapter, PostgresPlugin)