示例#1
0
def _get_connection(sid, did, trans_id):
    """
    Get the connection object of ERD.
    :param sid:
    :param did:
    :param trans_id:
    :return:
    """
    manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(sid)
    try:
        conn = manager.connection(did=did, conn_id=trans_id,
                                  auto_reconnect=True,
                                  use_binary_placeholder=True)
        status, msg = conn.connect()
        if not status:
            app.logger.error(msg)
            raise ConnectionLost(sid, conn.db, trans_id)

        return conn
    except Exception as e:
        app.logger.error(e)
        raise
示例#2
0
    def connection(self,
                   database=None,
                   conn_id=None,
                   auto_reconnect=True,
                   did=None,
                   async_=None,
                   use_binary_placeholder=False,
                   array_to_string=False):
        if database is not None:
            if hasattr(str, 'decode') and \
                    not isinstance(database, unicode):
                database = database.decode('utf-8')
            if did is not None:
                if did in self.db_info:
                    self.db_info[did]['datname'] = database
        else:
            if did is None:
                database = self.db
            elif did in self.db_info:
                database = self.db_info[did]['datname']
            else:
                maintenance_db_id = u'DB:{0}'.format(self.db)
                if maintenance_db_id in self.connections:
                    conn = self.connections[maintenance_db_id]
                    # try to connect maintenance db if not connected
                    if not conn.connected():
                        conn.connect()

                    if conn.connected():
                        status, res = conn.execute_dict(u"""
SELECT
    db.oid as did, db.datname, db.datallowconn,
    pg_encoding_to_char(db.encoding) AS serverencoding,
    has_database_privilege(db.oid, 'CREATE') as cancreate, datlastsysoid
FROM
    pg_database db
WHERE db.oid = {0}""".format(did))

                        if status and len(res['rows']) > 0:
                            for row in res['rows']:
                                self.db_info[did] = row
                                database = self.db_info[did]['datname']

                        if did not in self.db_info:
                            raise Exception(
                                gettext(
                                    "Could not find the specified database."))

        if not get_crypt_key()[0]:
            # the reason its not connected might be missing key
            raise CryptKeyMissing()

        if database is None:
            # Check SSH Tunnel is alive or not.
            if self.use_ssh_tunnel == 1:
                self.check_ssh_tunnel_alive()
            else:
                raise ConnectionLost(self.sid, None, None)

        my_id = (u'CONN:{0}'.format(conn_id)) if conn_id is not None else \
            (u'DB:{0}'.format(database))

        self.pinged = datetime.datetime.now()

        if my_id in self.connections:
            return self.connections[my_id]
        else:
            if async_ is None:
                async_ = 1 if conn_id is not None else 0
            else:
                async_ = 1 if async_ is True else 0
            self.connections[my_id] = Connection(
                self,
                my_id,
                database,
                auto_reconnect,
                async_,
                use_binary_placeholder=use_binary_placeholder,
                array_to_string=array_to_string)

            return self.connections[my_id]
示例#3
0
FROM
    pg_database db
WHERE db.oid = {0}""".format(did))

                        if status and len(res['rows']) > 0:
                            for row in res['rows']:
                                self.db_info[did] = row
                                database = self.db_info[did]['datname']

                        if did not in self.db_info:
                            raise Exception(gettext(
                                "Could not find the specified database."
                            ))

        if database is None:
            raise ConnectionLost(self.sid, None, None)

        my_id = (u'CONN:{0}'.format(conn_id)) if conn_id is not None else \
            (u'DB:{0}'.format(database))

        self.pinged = datetime.datetime.now()

        if my_id in self.connections:
            return self.connections[my_id]
        else:
            if async is None:
                async = 1 if conn_id is not None else 0
            else:
                async = 1 if async is True else 0
            self.connections[my_id] = Connection(
                self, my_id, database, auto_reconnect, async,
示例#4
0
class StartRunningQueryTest(BaseTestGenerator):
    """
    Check that the start_running_query method works as intended
    """

    scenarios = [
        ('When gridData is not present in the session, it returns an error',
         dict(
             function_parameters=dict(sql=dict(sql='some sql',
                                               explain_plan=None),
                                      trans_id=123,
                                      http_session=dict()),
             pickle_load_return=None,
             get_driver_exception=False,
             get_connection_lost_exception=False,
             manager_connection_exception=None,
             is_connected_to_server=False,
             connection_connect_return=None,
             execute_async_return_value=None,
             is_begin_required=False,
             is_rollback_required=False,
             apply_explain_plan_wrapper_if_needed_return_value='some sql',
             expect_make_json_response_to_have_been_called_with=dict(
                 success=0,
                 errormsg='Transaction ID not found in the session.',
                 info='DATAGRID_TRANSACTION_REQUIRED',
                 status=404,
             ),
             expect_internal_server_error_called_with=None,
             expected_logger_error=None,
             expect_execute_void_called_with='some sql',
         )),
        ('When transactionId is not present in the gridData, '
         'it returns an error',
         dict(
             function_parameters=dict(sql=dict(sql='some sql',
                                               explain_plan=None),
                                      trans_id=123,
                                      http_session=dict(gridData=dict())),
             pickle_load_return=None,
             get_driver_exception=False,
             get_connection_lost_exception=False,
             manager_connection_exception=None,
             is_connected_to_server=False,
             connection_connect_return=None,
             execute_async_return_value=None,
             is_begin_required=False,
             is_rollback_required=False,
             apply_explain_plan_wrapper_if_needed_return_value='some sql',
             expect_make_json_response_to_have_been_called_with=dict(
                 success=0,
                 errormsg='Transaction ID not found in the session.',
                 info='DATAGRID_TRANSACTION_REQUIRED',
                 status=404,
             ),
             expect_internal_server_error_called_with=None,
             expected_logger_error=None,
             expect_execute_void_called_with='some sql',
         )),
        ('When the command information for the transaction '
         'cannot be retrieved, '
         'it returns an error',
         dict(
             function_parameters=dict(
                 sql=dict(sql='some sql', explain_plan=None),
                 trans_id=123,
                 http_session=dict(gridData={'123': dict(command_obj='')})),
             pickle_load_return=None,
             get_driver_exception=False,
             get_connection_lost_exception=False,
             manager_connection_exception=None,
             is_connected_to_server=False,
             connection_connect_return=None,
             execute_async_return_value=None,
             is_begin_required=False,
             is_rollback_required=False,
             apply_explain_plan_wrapper_if_needed_return_value='some sql',
             expect_make_json_response_to_have_been_called_with=dict(data=dict(
                 status=False,
                 result='Either transaction object or session object '
                 'not found.',
                 can_edit=False,
                 can_filter=False,
                 info_notifier_timeout=5000,
                 notifies=None,
                 transaction_status=None)),
             expect_internal_server_error_called_with=None,
             expected_logger_error=None,
             expect_execute_void_called_with='some sql',
         )),
        ('When exception happens while retrieving the database driver, '
         'it returns an error',
         dict(
             function_parameters=dict(
                 sql=dict(sql='some sql', explain_plan=None),
                 trans_id=123,
                 http_session=dict(gridData={'123': dict(command_obj='')})),
             pickle_load_return=MagicMock(conn_id=1,
                                          update_fetched_row_cnt=MagicMock()),
             get_driver_exception=True,
             get_connection_lost_exception=False,
             manager_connection_exception=None,
             is_connected_to_server=False,
             connection_connect_return=None,
             execute_async_return_value=None,
             is_begin_required=False,
             is_rollback_required=False,
             apply_explain_plan_wrapper_if_needed_return_value='some sql',
             expect_make_json_response_to_have_been_called_with=None,
             expect_internal_server_error_called_with=dict(
                 errormsg='get_driver exception'),
             expected_logger_error=get_driver_exception,
             expect_execute_void_called_with='some sql',
         )),
        ('When ConnectionLost happens while retrieving the '
         'database connection, '
         'it returns an error',
         dict(
             function_parameters=dict(
                 sql=dict(sql='some sql', explain_plan=None),
                 trans_id=123,
                 http_session=dict(gridData={'123': dict(command_obj='')})),
             pickle_load_return=MagicMock(conn_id=1,
                                          update_fetched_row_cnt=MagicMock()),
             get_driver_exception=False,
             get_connection_lost_exception=False,
             manager_connection_exception=ConnectionLost('1', '2', '3'),
             is_connected_to_server=False,
             connection_connect_return=None,
             execute_async_return_value=None,
             is_begin_required=False,
             is_rollback_required=False,
             apply_explain_plan_wrapper_if_needed_return_value='some sql',
             expect_make_json_response_to_have_been_called_with=None,
             expect_internal_server_error_called_with=None,
             expected_logger_error=None,
             expect_execute_void_called_with='some sql',
         )),
        ('When SSHTunnelConnectionLost happens while retrieving the '
         'database connection, '
         'it returns an error',
         dict(
             function_parameters=dict(
                 sql=dict(sql='some sql', explain_plan=None),
                 trans_id=123,
                 http_session=dict(gridData={'123': dict(command_obj='')})),
             pickle_load_return=MagicMock(conn_id=1,
                                          update_fetched_row_cnt=MagicMock()),
             get_driver_exception=False,
             get_connection_lost_exception=False,
             manager_connection_exception=SSHTunnelConnectionLost('1.1.1.1'),
             is_connected_to_server=False,
             connection_connect_return=None,
             execute_async_return_value=None,
             is_begin_required=False,
             is_rollback_required=False,
             apply_explain_plan_wrapper_if_needed_return_value='some sql',
             expect_make_json_response_to_have_been_called_with=None,
             expect_internal_server_error_called_with=None,
             expected_logger_error=None,
             expect_execute_void_called_with='some sql',
         )),
        ('When is not connected to the server and fails to connect, '
         'it returns an error',
         dict(
             function_parameters=dict(
                 sql=dict(sql='some sql', explain_plan=None),
                 trans_id=123,
                 http_session=dict(gridData={'123': dict(command_obj='')})),
             pickle_load_return=MagicMock(conn_id=1,
                                          update_fetched_row_cnt=MagicMock()),
             get_driver_exception=False,
             get_connection_lost_exception=True,
             manager_connection_exception=None,
             is_connected_to_server=False,
             connection_connect_return=[False, 'Unable to connect to server'],
             execute_async_return_value=None,
             is_begin_required=False,
             is_rollback_required=False,
             apply_explain_plan_wrapper_if_needed_return_value='some sql',
             expect_make_json_response_to_have_been_called_with=None,
             expect_internal_server_error_called_with=dict(
                 errormsg='Unable to connect to server'),
             expected_logger_error=get_connection_lost_exception,
             expect_execute_void_called_with='some sql',
         )),
        ('When server is connected and start query async request, '
         'it returns an success message',
         dict(
             function_parameters=dict(
                 sql=dict(sql='some sql', explain_plan=None),
                 trans_id=123,
                 http_session=dict(gridData={'123': dict(command_obj='')})),
             pickle_load_return=MagicMock(conn_id=1,
                                          update_fetched_row_cnt=MagicMock(),
                                          set_connection_id=MagicMock(),
                                          auto_commit=True,
                                          auto_rollback=False,
                                          can_edit=lambda: True,
                                          can_filter=lambda: True),
             get_driver_exception=False,
             get_connection_lost_exception=False,
             manager_connection_exception=None,
             is_connected_to_server=True,
             connection_connect_return=None,
             execute_async_return_value=[True, 'async function result output'],
             is_begin_required=False,
             is_rollback_required=False,
             apply_explain_plan_wrapper_if_needed_return_value='some sql',
             expect_make_json_response_to_have_been_called_with=dict(
                 data=dict(status=True,
                           result='async function result output',
                           can_edit=True,
                           can_filter=True,
                           info_notifier_timeout=5000,
                           notifies=None,
                           transaction_status=None)),
             expect_internal_server_error_called_with=None,
             expected_logger_error=None,
             expect_execute_void_called_with='some sql',
         )),
        ('When server is connected and start query async request and '
         'begin is required, '
         'it returns an success message',
         dict(
             function_parameters=dict(
                 sql=dict(sql='some sql', explain_plan=None),
                 trans_id=123,
                 http_session=dict(gridData={'123': dict(command_obj='')})),
             pickle_load_return=MagicMock(conn_id=1,
                                          update_fetched_row_cnt=MagicMock(),
                                          set_connection_id=MagicMock(),
                                          auto_commit=True,
                                          auto_rollback=False,
                                          can_edit=lambda: True,
                                          can_filter=lambda: True),
             get_driver_exception=False,
             get_connection_lost_exception=False,
             manager_connection_exception=None,
             is_connected_to_server=True,
             connection_connect_return=None,
             execute_async_return_value=[True, 'async function result output'],
             is_begin_required=True,
             is_rollback_required=False,
             apply_explain_plan_wrapper_if_needed_return_value='some sql',
             expect_make_json_response_to_have_been_called_with=dict(
                 data=dict(status=True,
                           result='async function result output',
                           can_edit=True,
                           can_filter=True,
                           info_notifier_timeout=5000,
                           notifies=None,
                           transaction_status=None)),
             expect_internal_server_error_called_with=None,
             expected_logger_error=None,
             expect_execute_void_called_with='some sql',
         )),
        ('When server is connected and start query async request and '
         'rollback is required, '
         'it returns an success message',
         dict(
             function_parameters=dict(
                 sql=dict(sql='some sql', explain_plan=None),
                 trans_id=123,
                 http_session=dict(gridData={'123': dict(command_obj='')})),
             pickle_load_return=MagicMock(conn_id=1,
                                          update_fetched_row_cnt=MagicMock(),
                                          set_connection_id=MagicMock(),
                                          auto_commit=True,
                                          auto_rollback=False,
                                          can_edit=lambda: True,
                                          can_filter=lambda: True),
             get_driver_exception=False,
             get_connection_lost_exception=False,
             manager_connection_exception=None,
             is_connected_to_server=True,
             connection_connect_return=None,
             execute_async_return_value=[True, 'async function result output'],
             is_begin_required=False,
             is_rollback_required=True,
             apply_explain_plan_wrapper_if_needed_return_value='some sql',
             expect_make_json_response_to_have_been_called_with=dict(
                 data=dict(status=True,
                           result='async function result output',
                           can_edit=True,
                           can_filter=True,
                           info_notifier_timeout=5000,
                           notifies=None,
                           transaction_status=None)),
             expect_internal_server_error_called_with=None,
             expected_logger_error=None,
             expect_execute_void_called_with='some sql',
         )),
        ('When server is connected and start query async request with '
         'explain plan wrapper, '
         'it returns an success message',
         dict(
             function_parameters=dict(
                 sql=dict(sql='some sql', explain_plan=None),
                 trans_id=123,
                 http_session=dict(gridData={'123': dict(command_obj='')})),
             pickle_load_return=MagicMock(conn_id=1,
                                          update_fetched_row_cnt=MagicMock(),
                                          set_connection_id=MagicMock(),
                                          auto_commit=True,
                                          auto_rollback=False,
                                          can_edit=lambda: True,
                                          can_filter=lambda: True),
             get_driver_exception=False,
             get_connection_lost_exception=False,
             manager_connection_exception=None,
             is_connected_to_server=True,
             connection_connect_return=None,
             execute_async_return_value=[True, 'async function result output'],
             is_begin_required=False,
             is_rollback_required=True,
             apply_explain_plan_wrapper_if_needed_return_value='EXPLAIN '
             'PLAN some sql',
             expect_make_json_response_to_have_been_called_with=dict(
                 data=dict(status=True,
                           result='async function result output',
                           can_edit=True,
                           can_filter=True,
                           info_notifier_timeout=5000,
                           notifies=None,
                           transaction_status=None)),
             expect_internal_server_error_called_with=None,
             expected_logger_error=None,
             expect_execute_void_called_with='EXPLAIN PLAN some sql',
         )),
    ]

    @patch('pgadmin.tools.sqleditor.utils.start_running_query'
           '.apply_explain_plan_wrapper_if_needed')
    @patch('pgadmin.tools.sqleditor.utils.start_running_query'
           '.make_json_response')
    @patch('pgadmin.tools.sqleditor.utils.start_running_query.pickle')
    @patch('pgadmin.tools.sqleditor.utils.start_running_query.get_driver')
    @patch('pgadmin.tools.sqleditor.utils.start_running_query'
           '.internal_server_error')
    @patch('pgadmin.tools.sqleditor.utils.start_running_query'
           '.update_session_grid_transaction')
    def runTest(self, update_session_grid_transaction_mock,
                internal_server_error_mock, get_driver_mock, pickle_mock,
                make_json_response_mock,
                apply_explain_plan_wrapper_if_needed_mock):
        """Check correct function is called to handle to run query."""
        self.connection = None

        self.loggerMock = MagicMock(error=MagicMock())
        expected_response = Response(
            response=json.dumps({'errormsg': 'some value'}))
        make_json_response_mock.return_value = expected_response
        if self.expect_internal_server_error_called_with is not None:
            internal_server_error_mock.return_value = expected_response
        pickle_mock.loads.return_value = self.pickle_load_return
        blueprint_mock = MagicMock(info_notifier_timeout=MagicMock(
            get=lambda: 5))

        # Save value for the later use
        self.is_begin_required_for_sql_query = \
            StartRunningQuery.is_begin_required_for_sql_query
        self.is_rollback_statement_required = \
            StartRunningQuery.is_rollback_statement_required

        if self.is_begin_required:
            StartRunningQuery.is_begin_required_for_sql_query = MagicMock(
                return_value=True)
        else:
            StartRunningQuery.is_begin_required_for_sql_query = MagicMock(
                return_value=False)
        if self.is_rollback_required:
            StartRunningQuery.is_rollback_statement_required = MagicMock(
                return_value=True)
        else:
            StartRunningQuery.is_rollback_statement_required = MagicMock(
                return_value=False)

        apply_explain_plan_wrapper_if_needed_mock.return_value = \
            self.apply_explain_plan_wrapper_if_needed_return_value

        manager = self.__create_manager()
        if self.get_driver_exception:
            get_driver_mock.side_effect = get_driver_exception
        elif self.get_connection_lost_exception:
            get_driver_mock.side_effect = get_connection_lost_exception
        else:
            get_driver_mock.return_value = MagicMock(
                connection_manager=lambda session_id: manager)

        try:
            result = StartRunningQuery(
                blueprint_mock,
                self.loggerMock).execute(**self.function_parameters)
            if self.manager_connection_exception is not None:
                self.fail('Exception: "' +
                          str(self.manager_connection_exception) +
                          '" excepted but not raised')

            self.assertEqual(result, expected_response)

        except AssertionError:
            raise
        except Exception as exception:
            self.assertEqual(self.manager_connection_exception, exception)

        self.__mock_assertions(internal_server_error_mock,
                               make_json_response_mock)
        if self.is_connected_to_server:
            apply_explain_plan_wrapper_if_needed_mock.assert_called_with(
                manager, self.function_parameters['sql'])

    def __create_manager(self):
        self.connection = MagicMock(
            connected=lambda: self.is_connected_to_server,
            connect=MagicMock(),
            execute_async=MagicMock(),
            execute_void=MagicMock(),
            get_notifies=MagicMock(),
            transaction_status=MagicMock(),
        )
        self.connection.connect.return_value = self.connection_connect_return
        self.connection.get_notifies.return_value = None
        self.connection.transaction_status.return_value = None
        self.connection.execute_async.return_value = \
            self.execute_async_return_value
        if self.manager_connection_exception is None:

            def connection_function(did, conn_id, use_binary_placeholder,
                                    array_to_string, auto_reconnect):
                return self.connection

            manager = MagicMock(connection=connection_function)

        else:
            manager = MagicMock()
            manager.connection.side_effect = self.manager_connection_exception

        return manager

    def __mock_assertions(self, internal_server_error_mock,
                          make_json_response_mock):
        if self.expect_make_json_response_to_have_been_called_with is not None:
            make_json_response_mock.assert_called_with(
                **self.expect_make_json_response_to_have_been_called_with)
        else:
            make_json_response_mock.assert_not_called()
        if self.expect_internal_server_error_called_with is not None:
            internal_server_error_mock.assert_called_with(
                **self.expect_internal_server_error_called_with)
        else:
            internal_server_error_mock.assert_not_called()
        if self.execute_async_return_value is not None:
            self.connection.execute_async.assert_called_with(
                self.expect_execute_void_called_with)
        else:
            self.connection.execute_async.assert_not_called()

        if self.expected_logger_error is not None:
            self.loggerMock.error.assert_called_with(
                self.expected_logger_error)
        else:
            self.loggerMock.error.assert_not_called()

        if self.is_begin_required:
            self.connection.execute_void.assert_called_with('BEGIN;')
        elif not self.is_rollback_required:
            self.connection.execute_void.assert_not_called()
        if self.is_rollback_required:
            self.connection.execute_void.assert_called_with('ROLLBACK;')
        elif not self.is_begin_required:
            self.connection.execute_void.assert_not_called()

    def tearDown(self):
        #  Reset methods to the original state
        StartRunningQuery.is_begin_required_for_sql_query = \
            staticmethod(self.is_begin_required_for_sql_query)
        StartRunningQuery.is_rollback_statement_required = \
            staticmethod(self.is_rollback_statement_required)
示例#5
0
    def connection(self, **kwargs):
        database = kwargs.get('database', None)
        conn_id = kwargs.get('conn_id', None)
        auto_reconnect = kwargs.get('auto_reconnect', True)
        did = kwargs.get('did', None)
        async_ = kwargs.get('async_', None)
        use_binary_placeholder = kwargs.get('use_binary_placeholder', False)
        array_to_string = kwargs.get('array_to_string', False)

        if database is not None:
            if did is not None and did in self.db_info:
                self.db_info[did]['datname'] = database
        else:
            if did is None:
                database = self.db
            elif did in self.db_info:
                database = self.db_info[did]['datname']
            else:
                maintenance_db_id = 'DB:{0}'.format(self.db)
                if maintenance_db_id in self.connections:
                    conn = self.connections[maintenance_db_id]
                    # try to connect maintenance db if not connected
                    if not conn.connected():
                        conn.connect()

                    if conn.connected():
                        status, res = conn.execute_dict("""
SELECT
    db.oid as did, db.datname, db.datallowconn,
    pg_catalog.pg_encoding_to_char(db.encoding) AS serverencoding,
    pg_catalog.has_database_privilege(db.oid, 'CREATE') as cancreate,
    datlastsysoid, datistemplate
FROM
    pg_catalog.pg_database db
WHERE db.oid = {0}""".format(did))

                        if status and len(res['rows']) > 0:
                            for row in res['rows']:
                                self.db_info[did] = row
                                database = self.db_info[did]['datname']

                        if did not in self.db_info:
                            raise ObjectGone(
                                gettext(
                                    "Could not find the specified database."))

        if not get_crypt_key()[0]:
            # the reason its not connected might be missing key
            raise CryptKeyMissing()

        if database is None:
            # Check SSH Tunnel is alive or not.
            if self.use_ssh_tunnel == 1:
                self.check_ssh_tunnel_alive()
            else:
                raise ConnectionLost(self.sid, None, None)

        my_id = ('CONN:{0}'.format(conn_id)) if conn_id is not None else \
            ('DB:{0}'.format(database))

        self.pinged = datetime.datetime.now()

        if my_id in self.connections:
            return self.connections[my_id]
        else:
            if async_ is None:
                async_ = 1 if conn_id is not None else 0
            else:
                async_ = 1 if async_ is True else 0
            self.connections[my_id] = Connection(
                self,
                my_id,
                database,
                auto_reconnect=auto_reconnect,
                async_=async_,
                use_binary_placeholder=use_binary_placeholder,
                array_to_string=array_to_string)

            return self.connections[my_id]