Example #1
def execute_script(connection, script, log):
    connection = get_connection(connection)

    ranges = grt.modules.MysqlSqlFacade.getSqlStatementRanges(script)
    for start, length in ranges:
        if grt.query_status():
            raise grt.UserInterrupt()
        statement = script[start:start + length]
        try:
            grt.send_info("Execute statement", statement)
            grt.log_debug3("DbMySQLFE", "Execute %s\n" % statement)
            connection.execute(statement)
        except db_utils.QueryError, exc:
            if log:
                entry = grt.classes.GrtLogEntry()
                entry.owner = log
                entry.name = str(exc)
                entry.entryType = 2
                log.entries.append(entry)
            grt.send_warning("%s" % exc)
            grt.log_error("DbMySQLFE",
                          "Exception executing '%s': %s\n" % (statement, exc))
            return False
        except Exception, exc:
            if log:
                entry = grt.classes.GrtLogEntry()
                entry.owner = log
                entry.name = "Exception: " + str(exc)
                entry.entryType = 2
                log.entries.append(entry)
            grt.send_warning("Exception caught: %s" % exc)
            grt.log_error("DbMySQLFE",
                          "Exception executing '%s': %s\n" % (statement, exc))
            return False
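The loop above slices the script by the (start, length) pairs returned by getSqlStatementRanges(). A standalone illustration of that slicing, with hypothetical range values rather than real parser output:

# Standalone sketch of the (start, length) slicing used above;
# the range values here are hypothetical, not produced by the real parser.
script = "CREATE TABLE t (id INT);\nDROP TABLE t;"
ranges = [(0, 24), (25, 13)]
statements = [script[start:start + length] for start, length in ranges]
assert statements == ["CREATE TABLE t (id INT);", "DROP TABLE t;"]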
Example #2
def filter_warnings(mtype, text, detail):
    # filter out parser warnings about stub creation/reuse from the message stream, since
    # they're harmless
    if mtype == "WARNING" and (" stub " in text or "Stub " in text):
        grt.send_info(text)
        return True
    return False
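filter_warnings() is meant to be pushed onto the GRT message stream while a parse runs; Example #36 below shows the real call site. A minimal sketch of the install/remove pattern, assuming catalog and sql are already in scope (the try/finally wrapper is an addition for safety, not taken from the original code):

grt.push_message_handler(filter_warnings)   # divert stub-related WARNINGs
try:
    grt.modules.MysqlSqlFacade.parseSqlScriptString(catalog, sql)
finally:
    grt.pop_message_handler()               # always restore the message stream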
Example #4
def connect(conn, password=''):
    """ Establish a connection to a database and return a Python DB API 2.0 connection object.
    
    :param conn:      An instance of :class:`db_mgmt_Connection` that contains the needed parameters
                      to set the connection up. You must ensure that this object has a :attr:`driver`
                      attribute with a :attr:`driverLibraryName` attribute that specifies a python module
                      name that will be imported and its :meth:`connect` method called to actually perform
                      the connection.

    :type conn: db_mgmt_Connection

    :param password:  A password to authenticate the user specified in :attr:`conn` with (optional).

    :type password: string

    :returns: A Python DB API 2.0 connection object that can be used to communicate to the target RDBMS.
    """

    connection_string = get_odbc_connection_string(conn, password)
    import re
    connection_string_fixed = re.sub("(.*PWD=)([^;]*)(.*)", r"\1XXXX\3", connection_string)
    connection_string_fixed = re.sub("(.*PASSWORD=)([^;]*)(.*)", r"\1XXXX\3", connection_string_fixed)
    grt.send_info('Opening ODBC connection to %s...' % connection_string_fixed)

    library = __import__(conn.driver.driverLibraryName, globals(), locals())
    connection = library.connect(connection_string, password=password)

    return connection 
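The two re.sub() calls only mask the password for logging; the unmasked string is still what gets passed to the driver. A runnable sketch of the masking (the connection string is a made-up ODBC DSN):

# Sketch of the PWD-masking regex above, runnable standalone.
# The DSN contents here are hypothetical.
import re

raw = "DRIVER={FreeTDS};SERVER=db1;UID=sa;PWD=s3cret;DATABASE=sales"
masked = re.sub("(.*PWD=)([^;]*)(.*)", r"\1XXXX\3", raw)
print(masked)  # DRIVER={FreeTDS};SERVER=db1;UID=sa;PWD=XXXX;DATABASE=sales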
Example #5
def execute_script(connection, script, log):
    connection = get_connection(connection)

    ranges = grt.modules.MysqlSqlFacade.getSqlStatementRanges(script)
    for start, length in ranges:
        if grt.query_status():
            raise grt.UserInterrupt()
        statement = script[start:start+length]
        try:
            grt.send_info("Execute statement", statement)
            grt.log_debug3("DbMySQLFE", "Execute %s\n" % statement)
            connection.execute(statement)
        except db_utils.QueryError, exc:
            if log:
                entry = grt.classes.GrtLogEntry()
                entry.owner = log
                entry.name = str(exc)
                entry.entryType = 2
                log.entries.append(entry)
            grt.send_warning("%s" % exc)
            grt.log_error("DbMySQLFE", "Exception executing '%s': %s\n" % (statement, exc))
            return False
        except Exception, exc:
            if log:
                entry = grt.classes.GrtLogEntry()
                entry.owner = log
                entry.name = "Exception: " + str(exc)
                entry.entryType = 2
                log.entries.append(entry)
            grt.send_warning("Exception caught: %s" % exc)
            grt.log_error("DbMySQLFE", "Exception executing '%s': %s\n" % (statement, exc))
            return False
Example #6
def reverseEngineerViews(cls, connection, schema):
    for view_name in cls.getViewNames(connection, schema.owner.name,
                                      schema.name):
        grt.send_info(
            '%s reverseEngineerViews: Cannot reverse engineer view "%s"' %
            (cls.getTargetDBMSName(), view_name))
    return 0
Example #7
def connect(conn, password=''):
    """ Establish a connection to a database and return a Python DB API 2.0 connection object.
    
    :param conn:      An instance of :class:`db_mgmt_Connection` that contains the needed parameters
                      to set the connection up. You must ensure that this object has a :attr:`driver`
                      attribute with a :attr:`driverLibraryName` attribute that specifies a python module
                      name that will be imported and its :meth:`connect` method called to actually perform
                      the connection.

    :type conn: db_mgmt_Connection

    :param password:  A password to authenticate the user specified in :attr:`conn` with (optional).

    :type password: string

    :returns: A Python DB API 2.0 connection object that can be used to communicate to the target RDBMS.
    """

    connection_string = get_odbc_connection_string(conn, password)
    import re
    connection_string_fixed = re.sub("(.*PWD=)([^;]*)(.*)", r"\1XXXX\3",
                                     connection_string)
    connection_string_fixed = re.sub("(.*PASSWORD=)([^;]*)(.*)", r"\1XXXX\3",
                                     connection_string_fixed)
    grt.send_info('Opening ODBC connection to %s...' % connection_string_fixed)

    library = __import__(conn.driver.driverLibraryName, globals(), locals())
    connection = library.connect(connection_string, password=password)

    return connection
Example #8
def reverseEngineerTriggers(cls, connection, schema):
    # Unfortunately it seems that there's no way to get the SQL definition of a trigger with ODBC
    for trigger_name in cls.getTriggerNames(connection, schema.owner.name,
                                            schema.name):
        grt.send_info(
            '%s reverseEngineerTriggers: Cannot reverse engineer trigger "%s"'
            % (cls.getTargetDBMSName(), trigger_name))
    return 0
Example #9
def reverseEngineerFunctions(cls, connection, schema):
    # Unfortunately it seems that there's no way to get the SQL definition of a stored procedure/function with ODBC
    for function_name in cls.getFunctionNames(connection,
                                              schema.owner.name,
                                              schema.name):
        grt.send_info(
            '%s reverseEngineerFunctions: Cannot reverse engineer function "%s"'
            % (cls.getTargetDBMSName(), function_name))
    return 0
Example #10
def task_post_processing(self):
    selected_option = self.main.plan.state.applicationData.get("schemaMappingMethod")
    # nothing needs to be done for drop_catalog
    if selected_option == "drop_schema":
        grt.send_info("Merging reverse engineered schema objects into a single schema...")
        self._merge_schemata()
    elif selected_option == "merge_schema":
        grt.send_info("Merging and renaming reverse engineered schema objects into a single schema...")
        self._merge_schemata(prefix='schema_name')
Example #11
    def reverseEngineer(self):
        """Perform reverse engineering of selected schemas into the migration.sourceCatalog node"""
        self.connect()

        grt.send_info(
            "Reverse engineering %s from %s" %
            (", ".join(self.selectedSchemataNames), self.selectedCatalogName))
        self.state.sourceCatalog = self._rev_eng_module.reverseEngineer(
            self.connection, self.selectedCatalogName,
            self.selectedSchemataNames, self.state.applicationData)
Example #12
def connect(connection, password):
    '''Establishes a connection to the server and stores the connection object in the connections pool.

    It first looks for a connection with the given connection parameters in the connections pool to
    reuse existent connections. If such connection is found it queries the server to ensure that the
    connection is alive and reestablishes it if is dead. If no suitable connection is found in the
    connections pool, a new one is created and stored in the pool.

    Parameters:
    ===========
        connection:  an object of the class db_mgmt_Connection storing the parameters
                     for the connection.
        password:    a string with the password to use for the connection.
    '''
    con = None
    host_identifier = connection.hostIdentifier
    try:
        con = get_connection(connection)
        try:
            if not con.cursor().execute('SELECT 1'):
                raise Exception("connection error")
        except Exception, exc:
            grt.send_info("Connection to %s apparently lost, reconnecting..." %
                          connection.hostIdentifier)
            raise NotConnectedError("Connection error")
    except NotConnectedError, exc:
        grt.send_info("Connecting to %s..." % host_identifier)
        import pyodbc
        try:
            con = db_driver.connect(connection, password)
            # Sybase metadata query SPs use things that don't work inside transactions, so enable autocommit
            con.autocommit = True

            # Adds data type conversion functions for pyodbc


#            if connection.driver.driverLibraryName == 'pyodbc':
#                cursor = con.cursor()
#                version = con.execute("SELECT CAST(SERVERPROPERTY('ProductVersion') AS VARCHAR)").fetchone()[0]
#                majorVersion = int(version.split('.', 1)[0])
#                if majorVersion >= 9:
#                    con.add_output_converter(-150, lambda value: value if value is None else value.decode('utf-16'))
#                    con.add_output_converter(0, lambda value: value if value is None else value.decode('utf-16'))
#                else:
#                    con.add_output_converter(-150, lambda value: value if value is None else str(value))
#                    con.add_output_converter(0, lambda value: value if value is None else str(value))

        except pyodbc.Error, odbc_err:
            # 28000 is from native SQL Server driver... 42000 seems to be from FreeTDS
            # FIXME: This should be tuned for Sybase
            if len(odbc_err.args) == 2 and odbc_err.args[0] in (
                    '28000', '42000') and "(18456)" in odbc_err.args[1]:
                raise grt.DBLoginError(odbc_err.args[1])
Example #13
    def connect(cls, connection, password):
        '''Establishes a connection to the server and stores the connection object in the connections pool.

        It first looks for a connection with the given connection parameters in the connections pool to
        reuse existent connections. If such connection is found it queries the server to ensure that the
        connection is alive and reestablishes it if is dead. If no suitable connection is found in the
        connections pool, a new one is created and stored in the pool.

        Parameters:
        ===========
            connection:  an object of the class db_mgmt_Connection storing the parameters
                         for the connection.
            password:    a string with the password to use for the connection (ignored for SQLite).
        '''
        con = None
        try:
            con = cls.get_connection(connection)
            try:
                if not con.cursor().execute('SELECT 1'):
                    raise Exception('connection error')
            except Exception, exc:
                grt.send_info(
                    'Connection to %s apparently lost, reconnecting...' %
                    connection.hostIdentifier)
                raise NotConnectedError('Connection error')
        except NotConnectedError, exc:
            grt.send_info('Connecting to %s...' % connection.hostIdentifier)
            if connection.driver.driverLibraryName == 'sqlanydb':
                import sqlanydbwrapper as sqlanydb  # Replace this to a direct sqlanydb import when it complies with PEP 249
                connstr = replace_string_parameters(
                    connection.driver.connectionStringTemplate,
                    dict(connection.parameterValues))
                import ast
                try:
                    all_params_dict = ast.literal_eval(connstr)
                except Exception, exc:
                    grt.send_error(
                        'The given connection string is not a valid python dict: %s'
                        % connstr)
                    raise
                # Remove unreplaced parameters:
                params = dict(
                    (key, value) for key, value in all_params_dict.iteritems()
                    if not (value.startswith('%') and value.endswith('%')))
                params['password'] = password
                conn_params = dict(params)
                conn_params['password'] = '******'
                connection.parameterValues['wbcopytables_connection_string'] = repr(conn_params)

                con = sqlanydb.connect(**params)
Example #14
    def connect(cls, connection, password):
        """Establishes a connection to the server and stores the connection object in the connections pool.

        It first looks for a connection with the given connection parameters in the connections pool to
        reuse existent connections. If such connection is found it queries the server to ensure that the
        connection is alive and reestablishes it if is dead. If no suitable connection is found in the
        connections pool, a new one is created and stored in the pool.

        Parameters:
        ===========
            connection:  an object of the class db_mgmt_Connection storing the parameters
                         for the connection.
            password:    a string with the password to use for the connection (ignored for SQLite).
        """
        con = None
        try:
            con = cls.get_connection(connection)
            try:
                if not con.cursor().execute("SELECT 1"):
                    raise Exception("connection error")
            except Exception, exc:
                grt.send_info("Connection to %s apparently lost, reconnecting..." % connection.hostIdentifier)
                raise NotConnectedError("Connection error")
        except NotConnectedError, exc:
            grt.send_info("Connecting to %s..." % connection.hostIdentifier)
            if connection.driver.driverLibraryName == "sqlanydb":
                import sqlanydbwrapper as sqlanydb  # Replace this to a direct sqlanydb import when it complies with PEP 249

                connstr = replace_string_parameters(
                    connection.driver.connectionStringTemplate, dict(connection.parameterValues)
                )
                import ast

                try:
                    all_params_dict = ast.literal_eval(connstr)
                except Exception, exc:
                    grt.send_error("The given connection string is not a valid python dict: %s" % connstr)
                    raise
                # Remove unreplaced parameters:
                params = dict(
                    (key, value)
                    for key, value in all_params_dict.iteritems()
                    if not (value.startswith("%") and value.endswith("%"))
                )
                params["password"] = password
                conn_params = dict(params)
                conn_params["password"] = "******"
                connection.parameterValues["wbcopytables_connection_string"] = repr(conn_params)

                con = sqlanydb.connect(**params)
Example #15
def connect(connection, password):
    '''Establishes a connection to the server and stores the connection object in the connections pool.

    It first looks for a connection with the given connection parameters in the connections pool to
    reuse existent connections. If such connection is found it queries the server to ensure that the
    connection is alive and reestablishes it if is dead. If no suitable connection is found in the
    connections pool, a new one is created and stored in the pool.

    Parameters:
    ===========
        connection:  an object of the class db_mgmt_Connection storing the parameters
                     for the connection.
        password:    a string with the password to use for the connection.
    '''
    con = None
    host_identifier = connection.hostIdentifier
    try:
        con = get_connection(connection)
        try:
            if not con.cursor().execute('SELECT 1'):
                raise Exception("connection error")
        except Exception, exc:
            grt.send_info("Connection to %s apparently lost, reconnecting..." % connection.hostIdentifier)
            raise NotConnectedError("Connection error")
    except NotConnectedError, exc:
        grt.send_info("Connecting to %s..." % host_identifier)
        import pyodbc
        try:
            con = db_driver.connect(connection, password)
            # Sybase metadata query SPs use things that don't work inside transactions, so enable autocommit
            con.autocommit = True

            # Adds data type conversion functions for pyodbc
#            if connection.driver.driverLibraryName == 'pyodbc':
#                cursor = con.cursor()
#                version = con.execute("SELECT CAST(SERVERPROPERTY('ProductVersion') AS VARCHAR)").fetchone()[0]
#                majorVersion = int(version.split('.', 1)[0])
#                if majorVersion >= 9:
#                    con.add_output_converter(-150, lambda value: value if value is None else value.decode('utf-16'))
#                    con.add_output_converter(0, lambda value: value if value is None else value.decode('utf-16'))
#                else:
#                    con.add_output_converter(-150, lambda value: value if value is None else str(value))
#                    con.add_output_converter(0, lambda value: value if value is None else str(value))

        except pyodbc.Error, odbc_err:
            # 28000 is from native SQL Server driver... 42000 seems to be from FreeTDS
            # FIXME: This should be tuned for Sybase
            if len(odbc_err.args) == 2 and odbc_err.args[0] in ('28000', '42000') and "(18456)" in odbc_err.args[1]:
                raise grt.DBLoginError(odbc_err.args[1])
Example #16
def testInstanceSettingByName(what, server_instance):
    global test_ssh_connection
    profile = ServerProfile(server_instance)

    if what == "connect_to_host":
        if test_ssh_connection:
            test_ssh_connection = None

        print "Connecting to %s" % profile.ssh_hostname

        try:
            test_ssh_connection = wb_admin_control.WbAdminControl(
                profile, connect_sql=False)
            test_ssh_connection.init()
            grt.send_info("connected.")
        except Exception, exc:
            import traceback
            traceback.print_exc()
            return "ERROR " + str(exc)
        except:
            return "ERROR"
Example #17
def testInstanceSettingByName(what, server_instance):
    global test_ssh_connection
    print "What", what
    profile = ServerProfile(server_instance)

    if what == "connect_to_host":
        if test_ssh_connection:
            test_ssh_connection = None

        print "Connecting to %s" % profile.ssh_hostname

        try:
            test_ssh_connection = wb_admin_control.WbAdminControl(profile, connect_sql=False)
            test_ssh_connection.init()
            grt.send_info("connected.")
        except Exception, exc:
            import traceback

            traceback.print_exc()
            return "ERROR " + str(exc)
        except:
            return "ERROR"
Example #18
    def connect(cls, connection, password):
        '''Establishes a connection to the server and stores the connection object in the connections pool.

        It first looks for a connection with the given connection parameters in the connections pool to
        reuse existent connections. If such a connection is found, it queries the server to ensure that the
        connection is alive and reestablishes it if is dead. If no suitable connection is found in the
        connections pool, a new one is created and stored in the pool.

        Parameters:
        ===========
            connection:  an object of the class db_mgmt_Connection storing the parameters
                         for the connection.
            password:    a string with the password to use for the connection.
        '''
        try:
            con = cls.get_connection(connection)
            try:
                if not con.cursor().execute('SELECT 1'):
                    raise Exception("connection error")
            except Exception as exc:
                grt.send_info("Connection to %s apparently lost, reconnecting..." % connection.hostIdentifier)
                raise NotConnectedError("Connection error")
        except NotConnectedError as exc:
            grt.send_info("Connecting to %s..." % connection.hostIdentifier)
            con = db_driver.connect(connection, password)
            if not con:
                grt.send_error('Connection failed', str(exc))
                raise
            grt.send_info("Connected")
            cls._connections[connection.__id__] = {"connection": con}
        return 1
Example #19
    def connect(cls, connection, password):
        '''Establishes a connection to the server and stores the connection object in the connections pool.

        It first looks for a connection with the given connection parameters in the connections pool to
        reuse existent connections. If such connection is found it queries the server to ensure that the
        connection is alive and reestablishes it if is dead. If no suitable connection is found in the
        connections pool, a new one is created and stored in the pool.

        Parameters:
        ===========
            connection:  an object of the class db_mgmt_Connection storing the parameters
                         for the connection.
            password:    a string with the password to use for the connection (ignored for SQLite).
        '''
        con = None
        try:
            con = cls.get_connection(connection)
            try:
                if not con.cursor().execute('SELECT 1'):
                    raise Exception('connection error')
            except Exception, exc:
                grt.send_info('Connection to %s apparently lost, reconnecting...' % connection.hostIdentifier)
                raise NotConnectedError('Connection error')
        except NotConnectedError, exc:
            grt.send_info('Connecting to %s...' % connection.hostIdentifier)
            con = sqlite3.connect(connection.parameterValues['dbfile'])
            if not con:
                grt.send_error('Connection failed', str(exc))
                raise
            connection.parameterValues['wbcopytables_connection_string'] = "'" + connection.parameterValues['dbfile'] + "'"
            grt.send_info('Connected')
            cls._connections[connection.__id__] = {'connection': con}
Example #20
    def connect(cls, connection, password):
        '''Establishes a connection to the server and stores the connection object in the connections pool.

        It first looks for a connection with the given connection parameters in the connections pool to
        reuse existent connections. If such connection is found it queries the server to ensure that the
        connection is alive and reestablishes it if is dead. If no suitable connection is found in the
        connections pool, a new one is created and stored in the pool.

        Parameters:
        ===========
            connection:  an object of the class db_mgmt_Connection storing the parameters
                         for the connection.
            password:    a string with the password to use for the connection.
        '''
        try:
            con = cls.get_connection(connection)
            try:
                if not con.cursor().execute('SELECT 1'):
                    raise Exception("connection error")
            except Exception, exc:
                grt.send_info("Connection to %s apparently lost, reconnecting..." % connection.hostIdentifier)
                raise NotConnectedError("Connection error")
        except NotConnectedError, exc:
            grt.send_info("Connecting to %s..." % connection.hostIdentifier)
            con = db_driver.connect(connection, password)
            if not con:
                grt.send_error('Connection failed', str(exc))
                raise
            grt.send_info("Connected")
            cls._connections[connection.__id__] = {"connection": con}
Example #21
def testInstanceSettingByName(what, connection, server_instance):
    global test_ssh_connection
    log_debug("Test %s in %s\n" % (what, connection.name))

    profile = ServerProfile(connection, server_instance)
    if what == "connect_to_host":
        if test_ssh_connection:
            test_ssh_connection = None

        log_info("Instance test: Connecting to %s\n" % profile.ssh_hostname)

        try:
            test_ssh_connection = wb_admin_control.WbAdminControl(profile, None, connect_sql=False, test_only = True)
            test_ssh_connection.init()
                
            grt.send_info("connected.")
        except Exception, exc:
            log_error("Exception: %s\n" % exc.message)
            import traceback
            log_debug2("Backtrace was: ", traceback.format_stack())
            return "ERROR "+str(exc)
        except:
            return "ERROR"
Example #22
def testInstanceSettingByName(what, connection, server_instance):
    global test_ssh_connection
    log_debug("Test %s in %s\n" % (what, connection.name))

    profile = ServerProfile(connection, server_instance)
    if what == "connect_to_host":
        if test_ssh_connection:
            test_ssh_connection = None

        log_info("Instance test: Connecting to %s\n" % profile.ssh_hostname)

        try:
            test_ssh_connection = wb_admin_control.WbAdminControl(profile, None, connect_sql=False, test_only=True)
            test_ssh_connection.init()

            grt.send_info("connected.")
        except Exception, exc:
            log_error("Exception: %s\n" % exc.message)
            import traceback
            log_debug2("Backtrace was: ", traceback.format_stack())
            return "ERROR "+str(exc)
        except:
            return "ERROR"
Example #23
def testInstanceSettingByName(what, connection, server_instance):
    global test_ssh_connection

    log_debug(_this_file, "Test %s in %s\n" % (what, connection.name))

    profile = ServerProfile(connection, server_instance)

    if what == "connect_to_host":
        if test_ssh_connection:
            test_ssh_connection = None

        log_info(_this_file, "Instance test: Connecting to %s\n" % profile.ssh_hostname)

        try:
            test_ssh_connection = wb_admin_control.WbAdminControl(profile, connect_sql=False)
            test_ssh_connection.init()
            grt.send_info("connected.")
        except Exception, exc:
            import traceback
            traceback.print_exc()
            return "ERROR "+str(exc)
        except:
            return "ERROR"
Example #24
def connect(connection, password):
    try:
        con = get_connection(connection)
        try:
            con.ping()
        except Exception:
            grt.send_info("Reconnecting to %s..." % connection.hostIdentifier)
            con.disconnect()
            con.connect()
            grt.send_info("Connection restablished")
    except NotConnectedError:
        con = db_utils.MySQLConnection(connection, password=password)
        grt.send_info("Connecting to %s..." % connection.hostIdentifier)
        con.connect()
        grt.send_info("Connected")
        _connections[connection.__id__] = con
    return 1
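Examples #24 through #27 share the same pool-backed idiom: try the cached connection, ping it, and fall back to a fresh connect on NotConnectedError. The helpers themselves are not shown on this page; a minimal sketch of how they presumably fit together:

# Stand-ins for the module-level helpers the plugins define; their real
# definitions are assumptions inferred from how the examples use them.
_connections = {}

class NotConnectedError(Exception):
    pass

def get_connection(connection):
    # Cached connections are keyed by the GRT object id.
    try:
        return _connections[connection.__id__]
    except KeyError:
        raise NotConnectedError("Not connected")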
Example #26
def connect(connection, password):
    try:
        con = get_connection(connection)
        try:
            con.ping()
        except Exception:
            grt.send_info("Reconnecting to %s..." % connection.hostIdentifier)
            con.disconnect()
            con.connect()
            grt.send_info("Connection restablished")
    except NotConnectedError:
        con = MySQLConnection(connection, password = password)
        host_identifier = connection.hostIdentifier
        grt.send_info("Connecting to %s..." % host_identifier)
        con.connect()        
        _connections[connection.__id__] = con
        version = "Unknown version"
        result = execute_query(connection, "SHOW VARIABLES LIKE 'version'")
        if result and result.nextRow():
            version = result.stringByIndex(2)
        grt.send_info("Connected to %s, %s" % (host_identifier, version))
    return 1
Example #27
def connect(connection, password):
    try:
        con = get_connection(connection)
        try:
            con.ping()
        except Exception:
            grt.send_info("Reconnecting to %s..." % connection.hostIdentifier)
            con.disconnect()
            con.connect()
            grt.send_info("Connection restablished")
    except NotConnectedError:
        con = MySQLConnection(connection, password=password)
        host_identifier = connection.hostIdentifier
        grt.send_info("Connecting to %s..." % host_identifier)
        con.connect()
        _connections[connection.__id__] = con
        version = "Unknown version"
        result = execute_query(connection, "SHOW VARIABLES LIKE 'version'")
        if result and result.nextRow():
            version = result.stringByIndex(2)
        grt.send_info("Connected to %s, %s" % (host_identifier, version))
    return 1
Example #28
    def connect(cls, connection, password):
        '''Establishes a connection to the server and stores the connection object in the connections pool.

        It first looks for a connection with the given connection parameters in the connections pool to
        reuse existent connections. If such connection is found it queries the server to ensure that the
        connection is alive and reestablishes it if is dead. If no suitable connection is found in the
        connections pool, a new one is created and stored in the pool.

        Parameters:
        ===========
            connection:  an object of the class db_mgmt_Connection storing the parameters
                         for the connection.
            password:    a string with the password to use for the connection (ignored for SQLite).
        '''
        con = None
        try:
            con = cls.get_connection(connection)
            try:
                if not con.cursor().execute('SELECT 1'):
                    raise Exception('connection error')
            except Exception as exc:
                grt.send_info(
                    'Connection to %s apparently lost, reconnecting...' %
                    connection.hostIdentifier)
                raise NotConnectedError('Connection error')
        except NotConnectedError as exc:
            grt.send_info('Connecting to %s...' % connection.hostIdentifier)
            con = sqlite3.connect(connection.parameterValues['dbfile'])
            if not con:
                grt.send_error('Connection failed', str(exc))
                raise
            connection.parameterValues['wbcopytables_connection_string'] = (
                "'" + connection.parameterValues['dbfile'] + "'")
            grt.send_info('Connected')
            cls._connections[connection.__id__] = {'connection': con}
        if con:
            ver = cls.execute_query(connection,
                                    "SELECT sqlite_version()").fetchone()[0]
            grt.log_info('SQLite RE',
                         'Connected to %s, %s' % (connection.name, ver))
            ver_parts = server_version_str2tuple(ver) + (0, 0, 0, 0)
            version = grt.classes.GrtVersion()
            version.majorNumber, version.minorNumber, version.releaseNumber, version.buildNumber = ver_parts[:4]
            cls._connections[connection.__id__]['version'] = version
        return 1
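The padding idiom near the end is worth noting: appending (0, 0, 0, 0) before slicing [:4] means short version strings still unpack into four fields. A standalone sketch (the parser helper here is an assumption; the real server_version_str2tuple is not shown on this page):

def server_version_str2tuple(s):
    # Assumed behaviour: "3.8" -> (3, 8)
    return tuple(int(p) for p in s.split('.') if p.isdigit())

ver_parts = server_version_str2tuple("3.8") + (0, 0, 0, 0)
major, minor, release, build = ver_parts[:4]  # (3, 8, 0, 0)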
Example #29
    def connect(cls, connection, password):
        '''Establishes a connection to the server and stores the connection object in the connections pool.

        It first looks for a connection with the given connection parameters in the connections pool to
        reuse existent connections. If such connection is found it queries the server to ensure that the
        connection is alive and reestablishes it if is dead. If no suitable connection is found in the
        connections pool, a new one is created and stored in the pool.

        Parameters:
        ===========
            connection:  an object of the class db_mgmt_Connection storing the parameters
                         for the connection.
            password:    a string with the password to use for the connection (ignored for SQLite).
        '''
        con = None
        try:
            con = cls.get_connection(connection)
            try:
                if not con.cursor().execute('SELECT 1'):
                    raise Exception('connection error')
            except Exception, exc:
                grt.send_info(
                    'Connection to %s apparently lost, reconnecting...' %
                    connection.hostIdentifier)
                raise NotConnectedError('Connection error')
        except NotConnectedError, exc:
            grt.send_info('Connecting to %s...' % connection.hostIdentifier)
            con = sqlite3.connect(connection.parameterValues['dbfile'])
            if not con:
                grt.send_error('Connection failed', str(exc))
                raise
            connection.parameterValues['wbcopytables_connection_string'] = (
                "'" + connection.parameterValues['dbfile'] + "'")
            grt.send_info('Connected')
            cls._connections[connection.__id__] = {'connection': con}
Example #30
def testInstanceSettingByName(what, connection, server_instance):
    global test_ssh_connection
    log_debug("Test %s in %s\n" % (what, connection.name))

    profile = ServerProfile(connection, server_instance)
    if what == "connect_to_host":
        if test_ssh_connection:
            test_ssh_connection = None

        log_info("Instance test: Connecting to %s\n" % profile.ssh_hostname)

        try:
            test_ssh_connection = wb_admin_control.WbAdminControl(
                profile, None, connect_sql=False, test_only=True)
            test_ssh_connection.init()

            grt.send_info("connected.")
        except Exception as exc:
            log_error("Exception: %s\n" % str(exc))
            import traceback
            log_debug2("Backtrace was: ", traceback.format_stack())
            return "ERROR " + str(exc)
        except:
            return "ERROR"

        try:
            test_ssh_connection.acquire_admin_access()
        except Exception as exc:

            log_error("Exception: %s\n" % str(exc))
            import traceback
            log_debug2("Backtrace was: ", traceback.format_stack())
            return "ERROR " + str(exc)

        os_info = test_ssh_connection.detect_operating_system_version()
        if os_info:
            os_type, os_name, os_variant, os_version = os_info
            log_info("Instance test: detected remote OS: %s (%s), %s, %s\n" % os_info)

            # check if the admin access error was because of wrong OS set
            if os_type != profile.target_os:
                return "ERROR Wrong Remote OS configured for connection. Set to %s, but was detected as %s" % (
                    profile.target_os, os_type)
        else:
            log_warning(
                "Instance test: could not determine OS version information\n")

            return "ERROR Could not determine remote OS details"

        return "OK"

    elif what == "disconnect":
        if test_ssh_connection:
            test_ssh_connection = None
        return "OK"

    elif what == "check_privileges":
        return "ERROR"

    elif what in ("find_config_file", "check_config_path",
                  "check_config_section"):
        config_file = profile.config_file_path
        print("Check if %s exists in remote host" % config_file)
        try:
            if not test_ssh_connection.ssh.fileExists(config_file):
                return "ERROR File %s doesn't exist" % config_file
            else:
                print("File was found in expected location")
        except IOError:
            return 'ERROR Could not verify the existence of the file %s' % config_file

        if what == "check_config_path":
            return "OK"

        section = profile.config_file_section
        cfg_file_content = ""
        print("Check if %s section exists in %s" % (section, config_file))
        try:
            #local_file = test_ssh_connection.fetch_file(config_file)
            cfg_file_content = test_ssh_connection.server_helper.get_file_content(
                path=config_file)
        except Exception as exc:
            import traceback
            traceback.print_exc()
            return "ERROR " + str(exc)

        if ("[" + section + "]") in cfg_file_content:
            return "OK"
        return "ERROR Couldn't find section %s in the remote config file %s" % (
            section, config_file)

    elif what in ("find_config_file/local", "check_config_path/local",
                  "check_config_section/local"):
        config_file = profile.config_file_path
        config_file = wb_admin_control.WbAdminControl(
            profile, None,
            connect_sql=False).expand_path_variables(config_file)
        print("Check if %s can be accessed" % config_file)
        if os.path.exists(config_file):
            print("File was found at the expected location")
        else:
            return "ERROR File %s doesn't exist" % config_file

        if what == "check_config_path/local":
            return "OK"

        section = profile.config_file_section
        print("Check if section for instance %s exists in %s" %
              (section, config_file))
        if check_if_config_file_has_section(open(config_file, "r"), section):
            print("[%s] section found in configuration file" % section)
            return "OK"
        return "ERROR Couldn't find section [%s] in the config file %s" % (
            section, config_file)

    elif what == "find_error_files":
        return "ERROR"

    elif what == "check_admin_commands":
        path = profile.start_server_cmd
        cmd_start = None
        if path.startswith("/"):
            cmd_start = path.split()[0]
            if not test_ssh_connection.ssh.fileExists(cmd_start):
                return "ERROR %s is invalid" % path

        path = profile.stop_server_cmd
        if path.startswith("/"):
            cmd = path.split()[0]
            if cmd != cmd_start and not test_ssh_connection.ssh.fileExists(
                    cmd):
                return "ERROR %s is invalid" % path

        return "OK"

    elif what == "check_admin_commands/local":
        path = profile.start_server_cmd
        cmd_start = None
        if path.startswith("/"):
            cmd_start = path.split()[0]
            if not os.path.exists(cmd_start):
                return "ERROR %s is invalid" % path

        path = profile.stop_server_cmd
        if path.startswith("/"):
            cmd = path.split()[0]
            if cmd != cmd_start and not os.path.exists(cmd):
                return "ERROR %s is invalid" % path

        return "OK"

    return "ERROR bad command"
Example #31
        except pyodbc.Error, odbc_err:
            # 28000 is from native SQL Server driver... 42000 seems to be from FreeTDS
            # FIXME: This should be tuned for Sybase
            if len(odbc_err.args) == 2 and odbc_err.args[0] in (
                    '28000', '42000') and "(18456)" in odbc_err.args[1]:
                raise grt.DBLoginError(odbc_err.args[1])

        if not con:
            grt.send_error('Connection failed', str(exc))
            raise

        _connections[connection.__id__] = {"connection": con}
        _connections[connection.__id__]["version"] = getServerVersion(connection)
        version = execute_query(connection, "SELECT @@version").fetchone()[0]
        grt.send_info("Connected to %s, %s" % (host_identifier, version))
    return 1


@ModuleInfo.export(grt.INT, grt.classes.db_mgmt_Connection)
def disconnect(connection):
    if connection.__id__ in _connections:
        del _connections[connection.__id__]  # pyodbc cursors are automatically closed when deleted
    return 0


@ModuleInfo.export(grt.INT, grt.classes.db_mgmt_Connection)
def isConnected(connection):
    return 1 if connection.__id__ in _connections else 0
Example #32
    def reverseEngineer(cls, connection, catalog_name, schemata_list, context):
        grt.send_progress(0, "Reverse engineering catalog information")
        cls.check_interruption()
        catalog = cls.reverseEngineerCatalog(connection, catalog_name)

        # calculate total workload 1st
        grt.send_progress(0.1, 'Preparing...')
        table_count_per_schema = {}
        view_count_per_schema = {}
        routine_count_per_schema = {}
        trigger_count_per_schema = {}
        total_count_per_schema = {}

        get_tables = context.get("reverseEngineerTables", True)
        get_triggers = context.get("reverseEngineerTriggers", True)
        get_views = context.get("reverseEngineerViews", True)
        get_routines = context.get("reverseEngineerRoutines", True)

        # 10% of the progress is for preparation
        total = 1e-10  # total should not be zero to avoid DivisionByZero exceptions
        i = 0.0
        accumulated_progress = 0.1
        for schema_name in schemata_list:
            cls.check_interruption()
            table_count_per_schema[schema_name] = len(cls.getTableNames(connection, catalog_name, schema_name)) if get_tables else 0
            view_count_per_schema[schema_name] = len(cls.getViewNames(connection, catalog_name, schema_name)) if get_views else 0
            cls.check_interruption()
            routine_count_per_schema[schema_name] = len(cls.getProcedureNames(connection, catalog_name, schema_name)) + len(cls.getFunctionNames(connection, catalog_name, schema_name)) if get_routines else 0
            trigger_count_per_schema[schema_name] = len(cls.getTriggerNames(connection, catalog_name, schema_name)) if get_triggers else 0

            total_count_per_schema[schema_name] = (table_count_per_schema[schema_name] + view_count_per_schema[schema_name] +
                                                   routine_count_per_schema[schema_name] + trigger_count_per_schema[schema_name] + 1e-10)
            total += total_count_per_schema[schema_name]

            grt.send_progress(accumulated_progress + 0.1 * (i / (len(schemata_list) + 1e-10) ), "Gathered stats for %s" % schema_name)
            i += 1.0

        # Now take 60% in the first pass of reverse engineering:
        accumulated_progress = 0.2
        for schema_name in schemata_list:
            schema_progress_share = 0.6 * (total_count_per_schema.get(schema_name, 0.0) / total)
            schema = find_object_with_name(catalog.schemata, schema_name) 

            if schema:
                # Reverse engineer tables:
                step_progress_share = schema_progress_share * (table_count_per_schema[schema_name] / (total_count_per_schema[schema_name] + 1e-10))
                if get_tables:
                    cls.check_interruption()
                    grt.send_info('Reverse engineering tables from %s' % schema_name)
                    grt.begin_progress_step(accumulated_progress, accumulated_progress + step_progress_share)
                    # Remove previous first pass marks that may exist if the user goes back and attempt rev eng again:
                    progress_flags = cls._connections[connection.__id__].setdefault('_rev_eng_progress_flags', set())
                    progress_flags.discard('%s_tables_first_pass' % schema_name)
                    cls.reverseEngineerTables(connection, schema)
                    grt.end_progress_step()
        
                accumulated_progress += step_progress_share
                grt.send_progress(accumulated_progress, 'First pass of table reverse engineering for schema %s completed!' % schema_name)
        
                # Reverse engineer views:
                step_progress_share = schema_progress_share * (view_count_per_schema[schema_name] / (total_count_per_schema[schema_name] + 1e-10))
                if get_views:
                    cls.check_interruption()
                    grt.send_info('Reverse engineering views from %s' % schema_name)
                    grt.begin_progress_step(accumulated_progress, accumulated_progress + step_progress_share)
                    cls.reverseEngineerViews(connection, schema)
                    grt.end_progress_step()
        
                accumulated_progress += step_progress_share
                grt.send_progress(accumulated_progress, 'Reverse engineering of views for schema %s completed!' % schema_name)
        
                # Reverse engineer routines:
                step_progress_share = schema_progress_share * (routine_count_per_schema[schema_name] / (total_count_per_schema[schema_name] + 1e-10))
                if get_routines:
                    cls.check_interruption()
                    grt.send_info('Reverse engineering routines from %s' % schema_name)
                    grt.begin_progress_step(accumulated_progress, accumulated_progress + step_progress_share)
                    grt.begin_progress_step(0.0, 0.5)
                    cls.reverseEngineerProcedures(connection, schema)
                    cls.check_interruption()
                    grt.end_progress_step()
                    grt.begin_progress_step(0.5, 1.0)
                    cls.reverseEngineerFunctions(connection, schema)
                    grt.end_progress_step()
                    grt.end_progress_step()
        
                accumulated_progress += step_progress_share
                grt.send_progress(accumulated_progress, 'Reverse engineering of routines for schema %s completed!' % schema_name)
        
                # Reverse engineer triggers:
                step_progress_share = schema_progress_share * (trigger_count_per_schema[schema_name] / (total_count_per_schema[schema_name] + 1e-10))
                if get_triggers:
                    cls.check_interruption()
                    grt.send_info('Reverse engineering triggers from %s' % schema_name)
                    grt.begin_progress_step(accumulated_progress, accumulated_progress + step_progress_share)
                    cls.reverseEngineerTriggers(connection, schema)
                    grt.end_progress_step()
        
                accumulated_progress = 0.8
                grt.send_progress(accumulated_progress, 'Reverse engineering of triggers for schema %s completed!' % schema_name)
            else:  # No schema with the given name was found
                grt.send_warning('The schema %s was not found in the catalog %s. Skipping it.' % (schema_name, catalog_name) )
                
        # Now the second pass for reverse engineering tables:
        if get_tables:
            total_tables = sum(table_count_per_schema[schema.name] for schema in catalog.schemata if schema.name in schemata_list)
            for schema in catalog.schemata:
                if schema.name not in schemata_list:
                    continue
                cls.check_interruption()
                step_progress_share = 0.2 * (table_count_per_schema[schema.name] / (total_tables + 1e-10))
                grt.send_info('Reverse engineering foreign keys for tables in schema %s' % schema.name)
                grt.begin_progress_step(accumulated_progress, accumulated_progress + step_progress_share)
                cls.reverseEngineerTables(connection, schema)
                grt.end_progress_step()
        
                accumulated_progress += step_progress_share
                grt.send_progress(accumulated_progress, 'Second pass of table reverse engineering for schema %s completed!' % schema.name)
            

        grt.send_progress(1.0, 'Reverse engineering completed!')
        return catalog
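The progress arithmetic above divides the bar into fixed bands: 0.0-0.1 preparation, 0.1-0.2 stat gathering, 0.2-0.8 the first reverse-engineering pass, 0.8-1.0 the foreign-key pass, with each schema weighted by its object count and 1e-10 guarding empty divisors. A standalone sketch of the weighting (schema names and counts hypothetical):

counts = {'sales': 12, 'hr': 3, 'empty': 0}  # objects per schema (hypothetical)
total = 1e-10                                # guard against an empty catalog
for n in counts.values():
    total += n + 1e-10                       # per-schema guard, as above

accumulated = 0.2                            # first pass spans 0.2 -> 0.8
for name in sorted(counts):
    share = 0.6 * ((counts[name] + 1e-10) / total)
    accumulated += share
print(accumulated)  # ~0.8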
Example #33
#                    con.add_output_converter(0, lambda value: value if value is None else str(value))

        except pyodbc.Error, odbc_err:
            # 28000 is from native SQL Server driver... 42000 seems to be from FreeTDS
            # FIXME: This should be tuned for Sybase
            if len(odbc_err.args) == 2 and odbc_err.args[0] in ('28000', '42000') and "(18456)" in odbc_err.args[1]:
                raise grt.DBLoginError(odbc_err.args[1])

        if not con:
            grt.send_error('Connection failed', str(exc))
            raise
        
        _connections[connection.__id__] = {"connection": con}
        _connections[connection.__id__]["version"] = getServerVersion(connection)
        version = execute_query(connection, "SELECT @@version").fetchone()[0]
        grt.send_info("Connected to %s, %s" % (host_identifier, version))
    return 1


@ModuleInfo.export(grt.INT, grt.classes.db_mgmt_Connection)
def disconnect(connection):
    if connection.__id__ in _connections:
        del _connections[connection.__id__]  # pyodbc cursors are automatically closed when deleted
    return 0


@ModuleInfo.export(grt.INT, grt.classes.db_mgmt_Connection)
def isConnected(connection):
    return 1 if connection.__id__ in _connections else 0

Example #34
    def connect(cls, connection, password):
        '''Establishes a connection to the server and stores the connection object in the connections pool.

        It first looks for a connection with the given connection parameters in the connections pool to
        reuse existent connections. If such connection is found it queries the server to ensure that the
        connection is alive and reestablishes it if is dead. If no suitable connection is found in the
        connections pool, a new one is created and stored in the pool.

        Parameters:
        ===========
            connection:  an object of the class db_mgmt_Connection storing the parameters
                         for the connection.
            password:    a string with the password to use for the connection (ignored for SQLite).
        '''
        con = None
        try:
            con = cls.get_connection(connection)
            try:
                if not con.cursor().execute('SELECT 1'):
                    raise Exception('connection error')
            except Exception as exc:
                grt.send_info(
                    'Connection to %s apparently lost, reconnecting...' %
                    connection.hostIdentifier)
                raise NotConnectedError('Connection error')
        except NotConnectedError as exc:
            grt.send_info('Connecting to %s...' % connection.hostIdentifier)
            if connection.driver.driverLibraryName == 'sqlanydb':
                import sqlanydbwrapper as sqlanydb  # Replace this to a direct sqlanydb import when it complies with PEP 249
                connstr = replace_string_parameters(
                    connection.driver.connectionStringTemplate,
                    dict(connection.parameterValues))
                import ast
                try:
                    all_params_dict = ast.literal_eval(connstr)
                except Exception as exc:
                    grt.send_error(
                        'The given connection string is not a valid python dict: %s'
                        % connstr)
                    raise
                # Remove unreplaced parameters:
                params = dict(
                    (key, value)
                    for key, value in list(all_params_dict.items())
                    if not (value.startswith('%') and value.endswith('%')))
                params['password'] = password
                conn_params = dict(params)
                conn_params['password'] = '******'
                connection.parameterValues['wbcopytables_connection_string'] = repr(conn_params)

                con = sqlanydb.connect(**params)
            else:
                con = db_driver.connect(connection, password)
            if not con:
                grt.send_error('Connection failed', str(exc))
                raise
            grt.send_info('Connected')
            cls._connections[connection.__id__] = {'connection': con}
        if con:
            ver = cls.execute_query(connection,
                                    "SELECT @@version").fetchone()[0]
            grt.log_info("SQLAnywhere RE",
                         "Connected to %s, %s\n" % (connection.name, ver))
            ver_parts = server_version_str2tuple(ver) + (0, 0, 0, 0)
            version = grt.classes.GrtVersion()
            version.majorNumber, version.minorNumber, version.releaseNumber, version.buildNumber = ver_parts[:4]
            cls._connections[connection.__id__]["version"] = version

        return 1
Example #35
    try:
        import time
        f = urllib2.urlopen(url)
    except Exception, exc:
        raise
    try:
        file_size = int(f.info().getheaders("Content-Length")[0])
    except:
        file_size = -1
    try:
        outf = open(destPath, "w+b")
    except Exception, exc:
        raise Exception("Can't create file "+destPath+":"+str(exc))

    fetched = 0.0
    grt.send_info("0:%i:Downloading..." % file_size)
    while True:
        buf = f.read(8192)
        if not buf:
            break
        fetched += len(buf)
        outf.write(buf)
        grt.send_info("%i:%i:Downloading..."%(fetched,file_size))
    outf.close()
    grt.send_info("%i:%i:Finished"%(fetched,file_size))
    return destPath


@ModuleInfo.export(grt.INT, grt.STRING)
def checkUpdate(currentVersion):
    try:
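The %i:%i: prefix in the send_info() calls above encodes current:total progress for the downloader UI. A tiny helper form of the same convention (the helper name is hypothetical, not part of the original module):

def send_download_progress(fetched, file_size, text):
    # Same "current:total:text" convention as the calls above.
    grt.send_info("%i:%i:%s" % (fetched, file_size, text))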
Example #36
def reverseEngineer(connection, catalog_name, schemata_list, context):
    catalog = grt.classes.db_mysql_Catalog()
    catalog.name = catalog_name
    catalog.simpleDatatypes.remove_all()
    catalog.simpleDatatypes.extend(connection.driver.owner.simpleDatatypes)
    
    table_names_per_schema = {}
    routine_names_per_schema = {}
    trigger_names_per_schema = {}
    
    def filter_warnings(mtype, text, detail):
        # filter out parser warnings about stub creation/reuse from the message stream, since
        # they're harmless
        if mtype == "WARNING" and (" stub " in text or "Stub " in text):
            grt.send_info(text)
            return True
        return False
    
    version = getServerVersion(connection)
    
    get_tables = context.get("reverseEngineerTables", True)
    get_triggers = context.get("reverseEngineerTriggers", True) and (version.majorNumber, version.minorNumber, version.releaseNumber) >= (5, 1, 21)
    get_views = context.get("reverseEngineerViews", True)
    get_routines = context.get("reverseEngineerRoutines", True)
    
    # calculate total workload 1st
    
    # 10% of the progress is for preparation
    
    grt.send_progress(0, "Preparing...")
    total = 0
    i = 0.0
    for schema_name in schemata_list:
        check_interruption()
        if get_tables and get_views:
            table_names = getAllTableNames(connection, catalog_name, schema_name)
        elif get_tables:
            table_names = getTableNames(connection, catalog_name, schema_name)
        elif get_views:
            table_names = getViewNames(connection, catalog_name, schema_name)
        else:
            table_names = []
        total += len(table_names)
        table_names_per_schema[schema_name] = table_names
        check_interruption()
        if get_routines:
            procedure_names = getProcedureNames(connection, catalog_name, schema_name)
            check_interruption()
            function_names = getFunctionNames(connection, catalog_name, schema_name)
            check_interruption()
            total += len(procedure_names)
            total += len(function_names)
            routine_names_per_schema[schema_name] = procedure_names, function_names
        else:
            routine_names_per_schema[schema_name] = [], []
        if get_triggers:
            trigger_names = getTriggerNames(connection, catalog_name, schema_name)
            total += len(trigger_names)
        else:
            trigger_names = []
        trigger_names_per_schema[schema_name] = trigger_names
        
        grt.send_progress(0.1 * (i/len(schemata_list)), "Preparing...")
        i += 1.0

    def wrap_sql(sql, schema):
        return "USE `%s`;\n%s"%(escape_sql_identifier(schema), sql)

    def wrap_routine_sql(sql):
        return "DELIMITER $$\n"+sql

    i = 0.0
    for schema_name in schemata_list:
        schema = grt.classes.db_mysql_Schema()
        schema.owner = catalog
        schema.name = schema_name
        catalog.schemata.append(schema)

        if get_tables or get_views:
            grt.send_info("Reverse engineering tables from %s" % schema_name)
            for table_name in table_names_per_schema[schema_name]:
                check_interruption()
                grt.send_progress(0.1 + 0.9 * (i / total), "Retrieving table %s.%s..." % (schema_name, table_name))
                result = execute_query(connection, "SHOW CREATE TABLE `%s`.`%s`" % (escape_sql_identifier(schema_name), escape_sql_identifier(table_name)))
                i += 0.5
                grt.send_progress(0.1 + 0.9 * (i / total), "Reverse engineering %s.%s..." % (schema_name, table_name))
                if result and result.nextRow():
                    sql = result.stringByIndex(2)
                    grt.push_message_handler(filter_warnings)
                    grt.begin_progress_step(0.1 + 0.9 * (i / total), 0.1 + 0.9 * ((i+0.5) / total))
                    grt.modules.MysqlSqlFacade.parseSqlScriptString(catalog, wrap_sql(sql, schema_name))
                    grt.end_progress_step()
                    grt.pop_message_handler()
                    i += 0.5
                else:
                    raise Exception("Could not fetch table information for %s.%s" % (schema_name, table_name))

        if get_triggers:
            grt.send_info("Reverse engineering triggers from %s" % schema_name)
            for trigger_name in trigger_names_per_schema[schema_name]:
                check_interruption()
                grt.send_progress(0.1 + 0.9 * (i / total), "Retrieving trigger %s.%s..." % (schema_name, trigger_name))
                result = execute_query(connection, "SHOW CREATE TRIGGER `%s`.`%s`" % (escape_sql_identifier(schema_name), escape_sql_identifier(trigger_name)))
                i += 0.5
                grt.send_progress(0.1 + 0.9 * (i / total), "Reverse engineering %s.%s..." % (schema_name, trigger_name))
                if result and result.nextRow():
                    sql = result.stringByName("SQL Original Statement")
                    grt.begin_progress_step(0.1 + 0.9 * (i / total), 0.1 + 0.9 * ((i+0.5) / total))
                    grt.modules.MysqlSqlFacade.parseSqlScriptString(catalog, wrap_sql(wrap_routine_sql(sql), schema_name))
                    grt.end_progress_step()
                    i += 0.5
                else:
                    raise Exception("Could not fetch trigger information for %s.%s" % (schema_name, trigger_name))
        
        if get_routines:
            grt.send_info("Reverse engineering stored procedures from %s" % schema_name)
            procedure_names, function_names = routine_names_per_schema[schema_name]
            for name in procedure_names:
                check_interruption()
                grt.send_progress(0.1 + 0.9 * (i / total), "Retrieving stored procedure %s.%s..." % (schema_name, name))
                result = execute_query(connection, "SHOW CREATE PROCEDURE `%s`.`%s`" % (escape_sql_identifier(schema_name), escape_sql_identifier(name)))
                i += 0.5
                grt.send_progress(0.1 + 0.9 * (i / total), "Reverse engineering %s.%s..." % (schema_name, name))
                if result and result.nextRow():
                    sql = result.stringByName("Create Procedure")
                    grt.begin_progress_step(0.1 + 0.9 * (i / total), 0.1 + 0.9 * ((i+0.5) / total))
                    grt.modules.MysqlSqlFacade.parseSqlScriptString(catalog, wrap_sql(wrap_routine_sql(sql), schema_name))
                    grt.end_progress_step()
                    i += 0.5
                else:
                    raise Exception("Could not fetch procedure information for %s.%s" % (schema_name, name))

            grt.send_info("Reverse engineering functions from %s" % schema_name)
            for name in function_names:
                check_interruption()
                grt.send_progress(0.1 + 0.9 * (i / total), "Retrieving function %s.%s..." % (schema_name, name))
                result = execute_query(connection, "SHOW CREATE FUNCTION `%s`.`%s`" % (escape_sql_identifier(schema_name), escape_sql_identifier(name)))
                i += 0.5
                grt.send_progress(0.1 + 0.9 * (i / total), "Reverse engineering %s.%s..." % (schema_name, name))
                if result and result.nextRow():
                    sql = result.stringByName("Create Function")
                    grt.begin_progress_step(0.1 + 0.9 * (i / total), 0.1 + 0.9 * ((i+0.5) / total))
                    grt.modules.MysqlSqlFacade.parseSqlScriptString(catalog, wrap_sql(wrap_routine_sql(sql), schema_name))
                    grt.end_progress_step()
                    i += 0.5
                else:
                    raise Exception("Could not fetch function information for %s.%s" % (schema_name, name))

    grt.send_progress(1.0, "Reverse engineered %i objects" % total)
    
    # check for any stub tables left
    empty_schemas = []
    for schema in catalog.schemata:
        schema_has_stub_tables = False
        for table in reversed(schema.tables):
            if table.isStub:
                grt.send_warning("Table %s was referenced from another table, but was not reverse engineered" % table.name)
                schema.tables.remove(table)
                schema_has_stub_tables = True
        if not schema.tables and not schema.views and not schema.routines and schema_has_stub_tables:
            empty_schemas.append(schema)
    for schema in empty_schemas:
        catalog.schemata.remove(schema)

    return catalog
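
A hypothetical invocation of the function above; the connection object would come from the Migration Wizard, the schema names are illustrative, and only the four context keys shown are consulted:

# Hypothetical usage sketch for the reverseEngineer() above.
context = {
    "reverseEngineerTables": True,
    "reverseEngineerViews": True,
    "reverseEngineerRoutines": False,  # skip procedures and functions
    "reverseEngineerTriggers": False,  # skip triggers
}
catalog = reverseEngineer(connection, "def", ["sakila", "world"], context)
for schema in catalog.schemata:
    print("%s: %i tables" % (schema.name, len(schema.tables)))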
Example #37
                # Remove unreplaced parameters:
                params = dict( (key, value) for key, value in all_params_dict.iteritems()
                                            if not (value.startswith('%') and value.endswith('%'))
                             )
                params['password'] = password
                conn_params = dict(params)
                conn_params['password'] = '******'
                connection.parameterValues['wbcopytables_connection_string'] = repr(conn_params)
                
                con = sqlanydb.connect(**params)
            else:
                con = db_driver.connect(connection, password)
            if not con:
                grt.send_error('Connection failed', str(exc))
                raise
            grt.send_info('Connected')
            cls._connections[connection.__id__] = {'connection': con}
        if con:
            ver = cls.execute_query(connection, "SELECT @@version").fetchone()[0]
            grt.log_info("SQLAnywhere RE", "Connected to %s, %s\n" % (connection.name, ver))
            ver_parts = server_version_str2tuple(ver) + (0, 0, 0, 0)
            version = grt.classes.GrtVersion()
            version.majorNumber, version.minorNumber, version.releaseNumber, version.buildNumber = ver_parts[:4]
            cls._connections[connection.__id__]["version"] = version

        return 1

    @classmethod
    @release_cursors
    def getCatalogNames(cls, connection):
        """Returns a list of the available catalogs.
    def _merge_schemata(self, prefix=''):
        catalog = self.main.plan.migrationSource.catalog
        schema = catalog.schemata[0]
        # preserve the original name of the catalog
        schema.oldName = schema.name

        module_db = self.main.plan.migrationSource.module_db()

        # otypes is something like ['tables', 'views', 'routines']:
        otypes = [ suptype[0] for suptype in self.main.plan.migrationSource.supportedObjectTypes ]

        # Update names for the objects of this first schema:
        if prefix:
            actual_prefix = (schema.name if prefix == 'schema_name' else schema.__id__) + '_'
            for otype in otypes:
                for obj in getattr(schema, otype):
                    # this will be used later during datacopy to refer to the original object to copy from
                    obj.oldName = module_db.quoteIdentifier(schema.oldName)+"."+module_db.quoteIdentifier(obj.name)
                    oname = obj.name
                    obj.name = actual_prefix + obj.name
                    grt.send_info("Object %s was renamed to %s" % (oname, obj.name))
        else:
            for otype in otypes:
                for obj in getattr(schema, otype):
                    # this will be used later during datacopy to refer to the original object to copy from
                    obj.oldName = module_db.quoteIdentifier(schema.name)+"."+module_db.quoteIdentifier(obj.name)

        schema.name = catalog.name
        if not prefix:
            known_names = dict( (otype, set(obj.name for obj in getattr(schema, otype))) for otype in otypes)

        for other_schema in list(catalog.schemata)[1:]:
            if other_schema.defaultCharacterSetName != schema.defaultCharacterSetName:
                grt.send_warning('While merging schema %s into %s: Default charset for schemata differs (%s vs %s). Setting default charset to %s' % (other_schema.name, schema.name, other_schema.defaultCharacterSetName, schema.defaultCharacterSetName, schema.defaultCharacterSetName))
                self.main.plan.state.addMigrationLogEntry(0, other_schema, None,
                      'While merging schema %s into %s: Default charset for schemata differs (%s vs %s). Setting default charset to %s' % (other_schema.name, schema.name, other_schema.defaultCharacterSetName, schema.defaultCharacterSetName, schema.defaultCharacterSetName))

            if other_schema.defaultCollationName != schema.defaultCollationName:
                grt.send_warning('While merging schema %s into %s: Default collation for schemata differs (%s vs %s). Setting default collation to %s' % (other_schema.name, schema.name, other_schema.defaultCollationName, schema.defaultCollationName, schema.defaultCollationName))
                self.main.plan.state.addMigrationLogEntry(0, other_schema, None,
                      'While merging schema %s into %s: Default collation for schemata differs (%s vs %s). Setting default collation to %s' % (other_schema.name, schema.name, other_schema.defaultCollationName, schema.defaultCollationName, schema.defaultCollationName))

            for otype in otypes:
                other_objects = getattr(other_schema, otype)
                if not prefix:
                    repeated_object_names = known_names[otype].intersection(obj.name for obj in other_objects)
                    if repeated_object_names:
                        objects_dict = dict( (obj.name, obj) for obj in other_objects )
                        for repeated_object_name in repeated_object_names:
                            objects_dict[repeated_object_name].name += '_' + other_schema.name
                            grt.send_warning('The name of the %(otype)s "%(oname)s" conflicts with other %(otype)s names: renamed to "%(onewname)s"' % { 'otype':otype[:-1],
                                                                  'oname':repeated_object_name,
                                                                  'onewname':objects_dict[repeated_object_name].name })
        
                            self.main.plan.state.addMigrationLogEntry(0, other_schema, None,
                                  'The name of the %(otype)s "%(oname)s" conflicts with other %(otype)s names: renamed to "%(onewname)s"' % { 'otype':otype[:-1],
                                                                                                                                              'oname':repeated_object_name,
                                                                                                                                              'onewname':objects_dict[repeated_object_name].name }
                                                                      )
                        known_names[otype].update(obj.name for obj in other_objects)  # add the names, not the objects
                else:
                    actual_prefix = (other_schema.name if prefix == 'schema_name' else schema.__id__) + '_'

                getattr(schema, otype).extend(other_objects)
                for obj in other_objects:
                    # this will be used later during datacopy to refer to the original object to copy from
                    obj.oldName = module_db.quoteIdentifier(obj.owner.name)+"."+module_db.quoteIdentifier(obj.name)
                    
                    obj.owner = schema
                    if prefix:
                        oname = obj.name
                        obj.name = actual_prefix + obj.name
                        grt.send_info("Object %s was renamed to %s" % (oname, obj.name))

        # Keep only the merged schema:
        catalog.schemata.remove_all()
        catalog.schemata.append(schema)
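
The duplicate-name rule used above can be shown in isolation. A standalone sketch with plain sets instead of GRT objects (all names here are illustrative):

# Sketch of the conflict-renaming rule: a colliding object gets the source
# schema's name appended, and the merged name set is updated afterwards.
known_names = {"tables": {"customers", "orders"}}
other_schema_name = "legacy"
other_tables = ["orders", "invoices"]  # "orders" collides

merged = []
for name in other_tables:
    if name in known_names["tables"]:
        name = name + "_" + other_schema_name  # "orders" -> "orders_legacy"
    merged.append(name)
known_names["tables"].update(merged)
print(merged)  # ['orders_legacy', 'invoices']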
Example #39
    try:
        import time
        f = urllib2.urlopen(url)
    except Exception, exc:
        raise
    try:
        file_size = int(f.info().getheaders("Content-Length")[0])
    except:
        file_size = -1
    try:
        outf = open(destPath, "w+b")
    except Exception, exc:
        raise Exception("Can't create file " + destPath + ":" + str(exc))

    fetched = 0.0
    grt.send_info("0:%i:Downloading..." % file_size)
    while True:
        buf = f.read(8192)
        if not buf:
            break
        fetched += len(buf)
        outf.write(buf)
        grt.send_info("%i:%i:Downloading..." % (fetched, file_size))
    outf.close()
    grt.send_info("%i:%i:Finished" % (fetched, file_size))
    return destPath


@ModuleInfo.export(grt.INT, grt.STRING)
def checkUpdate(currentVersion):
    try:
Example #40
def reverseEngineer(connection, catalog_name, schemata_list, context):
    catalog = grt.classes.db_mysql_Catalog()
    catalog.name = catalog_name
    catalog.simpleDatatypes.remove_all()
    catalog.simpleDatatypes.extend(connection.driver.owner.simpleDatatypes)
    
    table_names_per_schema = {}
    routine_names_per_schema = {}
    trigger_names_per_schema = {}
    
    def filter_warnings(mtype, text, detail):
        # filter out parser warnings about stub creation/reuse from the message stream, since
        # they're harmless
        if mtype == "WARNING" and (" stub " in text or "Stub " in text):
            grt.send_info(text)
            return True
        return False
    
    version = getServerVersion(connection)
    
    get_tables = context.get("reverseEngineerTables", True)
    get_triggers = context.get("reverseEngineerTriggers", True) and (version.majorNumber, version.minorNumber, version.releaseNumber) >= (5, 1, 21)
    get_views = context.get("reverseEngineerViews", True)
    get_routines = context.get("reverseEngineerRoutines", True)
    
    # calculate total workload 1st
    
    # 10% of the progress is for preparation
    
    grt.send_progress(0, "Preparing...")
    total = 0
    i = 0.0
    for schema_name in schemata_list:
        check_interruption()
        if get_tables and get_views:
            table_names = getAllTableNames(connection, catalog_name, schema_name)
        elif get_tables:
            table_names = getTableNames(connection, catalog_name, schema_name)
        elif get_views:
            table_names = getViewNames(connection, catalog_name, schema_name)
        else:
            table_names = []
        total += len(table_names)
        table_names_per_schema[schema_name] = table_names
        check_interruption()
        if get_routines:
            procedure_names = getProcedureNames(connection, catalog_name, schema_name)
            check_interruption()
            function_names = getFunctionNames(connection, catalog_name, schema_name)
            check_interruption()
            total += len(procedure_names)
            total += len(function_names)
            routine_names_per_schema[schema_name] = procedure_names, function_names
        else:
            routine_names_per_schema[schema_name] = [], []
        if get_triggers:
            trigger_names = getTriggerNames(connection, catalog_name, schema_name)
            total += len(trigger_names)
        else:
            trigger_names = []
        trigger_names_per_schema[schema_name] = trigger_names
        
        grt.send_progress(0.1 * (i/len(schemata_list)), "Preparing...")
        i += 1.0

    def wrap_sql(sql, schema):
        return "USE `%s`;\n%s"%(escape_sql_identifier(schema), sql)

    def wrap_routine_sql(sql):
        return "DELIMITER $$\n"+sql

    i = 0.0
    for schema_name in schemata_list:
        schema = grt.classes.db_mysql_Schema()
        schema.owner = catalog
        schema.name = schema_name
        catalog.schemata.append(schema)
        context = grt.modules.MySQLParserServices.createParserContext(catalog.characterSets, getServerVersion(connection), getServerMode(connection), 1)
        options = {}

        if get_tables or get_views:
            grt.send_info("Reverse engineering tables from %s" % schema_name)
            for table_name in table_names_per_schema[schema_name]:
                check_interruption()
                grt.send_progress(0.1 + 0.9 * (i / total), "Retrieving table %s.%s..." % (schema_name, table_name))
                result = execute_query(connection, "SHOW CREATE TABLE `%s`.`%s`" % (escape_sql_identifier(schema_name), escape_sql_identifier(table_name)))
                i += 0.5
                grt.send_progress(0.1 + 0.9 * (i / total), "Reverse engineering %s.%s..." % (schema_name, table_name))
                if result and result.nextRow():
                    sql = result.stringByIndex(2)
                    grt.push_message_handler(filter_warnings)
                    grt.begin_progress_step(0.1 + 0.9 * (i / total), 0.1 + 0.9 * ((i+0.5) / total))
                    grt.modules.MySQLParserServices.parseSQLIntoCatalogSql(context, catalog, wrap_sql(sql, schema_name), options)
                    grt.end_progress_step()
                    grt.pop_message_handler()
                    i += 0.5
                else:
                    raise Exception("Could not fetch table information for %s.%s" % (schema_name, table_name))

        if get_triggers:
            grt.send_info("Reverse engineering triggers from %s" % schema_name)
            for trigger_name in trigger_names_per_schema[schema_name]:
                check_interruption()
                grt.send_progress(0.1 + 0.9 * (i / total), "Retrieving trigger %s.%s..." % (schema_name, trigger_name))
                result = execute_query(connection, "SHOW CREATE TRIGGER `%s`.`%s`" % (escape_sql_identifier(schema_name), escape_sql_identifier(trigger_name)))
                i += 0.5
                grt.send_progress(0.1 + 0.9 * (i / total), "Reverse engineering %s.%s..." % (schema_name, trigger_name))
                if result and result.nextRow():
                    sql = result.stringByName("SQL Original Statement")
                    grt.begin_progress_step(0.1 + 0.9 * (i / total), 0.1 + 0.9 * ((i+0.5) / total))
                    grt.modules.MySQLParserServices.parseSQLIntoCatalogSql(context, catalog, wrap_sql(wrap_routine_sql(sql), schema_name), options)
                    grt.end_progress_step()
                    i += 0.5
                else:
                    raise Exception("Could not fetch trigger information for %s.%s" % (schema_name, trigger_name))
        
        if get_routines:
            grt.send_info("Reverse engineering stored procedures from %s" % schema_name)
            procedure_names, function_names = routine_names_per_schema[schema_name]
            for name in procedure_names:
                check_interruption()
                grt.send_progress(0.1 + 0.9 * (i / total), "Retrieving stored procedure %s.%s..." % (schema_name, name))
                result = execute_query(connection, "SHOW CREATE PROCEDURE `%s`.`%s`" % (escape_sql_identifier(schema_name), escape_sql_identifier(name)))
                i += 0.5
                grt.send_progress(0.1 + 0.9 * (i / total), "Reverse engineering %s.%s..." % (schema_name, name))
                if result and result.nextRow():
                    sql = result.stringByName("Create Procedure")
                    grt.begin_progress_step(0.1 + 0.9 * (i / total), 0.1 + 0.9 * ((i+0.5) / total))
                    grt.modules.MySQLParserServices.parseSQLIntoCatalogSql(context, catalog, wrap_sql(wrap_routine_sql(sql), schema_name), options)
                    grt.end_progress_step()
                    i += 0.5
                else:
                    raise Exception("Could not fetch procedure information for %s.%s" % (schema_name, name))

            grt.send_info("Reverse engineering functions from %s" % schema_name)
            for name in function_names:
                check_interruption()
                grt.send_progress(0.1 + 0.9 * (i / total), "Retrieving function %s.%s..." % (schema_name, name))
                result = execute_query(connection, "SHOW CREATE FUNCTION `%s`.`%s`" % (escape_sql_identifier(schema_name), escape_sql_identifier(name)))
                i += 0.5
                grt.send_progress(0.1 + 0.9 * (i / total), "Reverse engineering %s.%s..." % (schema_name, name))
                if result and result.nextRow():
                    sql = result.stringByName("Create Function")
                    grt.begin_progress_step(0.1 + 0.9 * (i / total), 0.1 + 0.9 * ((i+0.5) / total))
                    grt.modules.MySQLParserServices.parseSQLIntoCatalogSql(context, catalog, wrap_sql(wrap_routine_sql(sql), schema_name), options)
                    grt.end_progress_step()
                    i += 0.5
                else:
                    raise Exception("Could not fetch function information for %s.%s" % (schema_name, name))

    grt.send_progress(1.0, "Reverse engineered %i objects" % total)
    
    # check for any stub tables left
    empty_schemas = []
    for schema in catalog.schemata:
        schema_has_stub_tables = False
        for table in reversed(schema.tables):
            if table.isStub:
                grt.send_warning("Table %s was referenced from another table, but was not reverse engineered" % table.name)
                schema.tables.remove(table)
                schema_has_stub_tables = True
        if not schema.tables and not schema.views and not schema.routines and schema_has_stub_tables:
            empty_schemas.append(schema)
    for schema in empty_schemas:
        catalog.schemata.remove(schema)

    return catalog
Example #41
#                    con.add_output_converter(0, lambda value: value if value is None else value.decode('utf-16'))
#                else:
#                    con.add_output_converter(-150, lambda value: value if value is None else str(value))
#                    con.add_output_converter(0, lambda value: value if value is None else str(value))

        except pyodbc.Error, odbc_err:
            # 28000 is from native SQL Server driver... 42000 seems to be from FreeTDS
            # FIXME: This should be tuned for Sybase
            if len(odbc_err.args) == 2 and odbc_err.args[0] in (
                    '28000', '42000') and "(18456)" in odbc_err.args[1]:
                raise grt.DBLoginError(odbc_err.args[1])

        if not con:
            grt.send_error('Connection failed', str(exc))
            raise
        grt.send_info("Connected")

        _connections[connection.__id__] = {"connection": con}
        _connections[connection.__id__]["version"] = getServerVersion(
            connection)
    return 1


@ModuleInfo.export(grt.INT, grt.classes.db_mgmt_Connection)
def disconnect(connection):
    if connection.__id__ in _connections:
        del _connections[connection.__id__]  # pyodbc cursors are automatically closed when deleted
    return 0
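
All of these modules share the same in-memory connection cache, keyed by the GRT connection object's __id__. A minimal sketch of the accessor side of that pattern (get_connection here is a hypothetical helper, not part of the snippet above):

# _connections maps connection.__id__ to per-connection state such as the
# DB-API connection and the detected server version.
_connections = {}

def get_connection(connection):
    entry = _connections.get(connection.__id__)
    if entry is None:
        raise RuntimeError("Not connected to %s" % connection.name)
    return entry["connection"]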
Example #42
    def reverseEngineer(self):
        """Perform reverse engineering of selected schemas into the migration.sourceCatalog node"""
        self.connect()

        grt.send_info("Reverse engineering %s from %s" % (", ".join(self.selectedSchemataNames), self.selectedCatalogName))
        self.state.sourceCatalog = self._rev_eng_module.reverseEngineer(self.connection, self.selectedCatalogName, self.selectedSchemataNames, self.state.applicationData)
Example #43
def reverseEngineer(connection, catalog_name, schemata_list, options):
    """Reverse engineers a Sybase ASE database.

    This is the function that will be called by the Migration Wizard to reverse engineer
    a Sybase database. All the other reverseEngineer* functions are not actually required
    and should not be considered part of this module API even though they are currently
    being exposed. This function calls the other reverseEngineer* functions to complete
    the full reverse engineer process.
    """
    grt.send_progress(0, "Reverse engineering catalog information")
    catalog = grt.classes.db_sybase_Catalog()
    catalog.name = catalog_name
    catalog.simpleDatatypes.remove_all()
    catalog.simpleDatatypes.extend(connection.driver.owner.simpleDatatypes)
    catalog.defaultCollationName = ''  #   FIXME: Find out the right collation for the catalog

    grt.send_progress(0.05, "Reverse engineering User Data Types...")
    check_interruption()  #
    reverseEngineerUserDatatypes(connection, catalog)

    # calculate total workload 1st
    grt.send_progress(0.1, 'Preparing...')
    table_count_per_schema = {}
    view_count_per_schema = {}
    routine_count_per_schema = {}
    trigger_count_per_schema = {}
    total_count_per_schema = {}

    get_tables = options.get("reverseEngineerTables", True)
    get_triggers = options.get("reverseEngineerTriggers", True)
    get_views = options.get("reverseEngineerViews", True)
    get_routines = options.get("reverseEngineerRoutines", True)

    # 10% of the progress is for preparation
    total = 1e-10  # total should not be zero to avoid DivisionByZero exceptions
    i = 1.0
    accumulated_progress = 0.1
    for schema_name in schemata_list:
        check_interruption()
        table_count_per_schema[schema_name] = len(
            getTableNames(connection, catalog_name,
                          schema_name)) if get_tables else 0
        view_count_per_schema[schema_name] = len(
            getViewNames(connection, catalog_name,
                         schema_name)) if get_views else 0
        check_interruption()
        routine_count_per_schema[schema_name] = len(
            getProcedureNames(connection, catalog_name, schema_name)) + len(
                getFunctionNames(connection, catalog_name,
                                 schema_name)) if get_routines else 0
        trigger_count_per_schema[schema_name] = len(
            getTriggerNames(connection, catalog_name,
                            schema_name)) if get_triggers else 0

        total_count_per_schema[schema_name] = (
            table_count_per_schema[schema_name] +
            view_count_per_schema[schema_name] +
            routine_count_per_schema[schema_name] +
            trigger_count_per_schema[schema_name] + 1e-10)
        total += total_count_per_schema[schema_name]

        grt.send_progress(
            accumulated_progress + 0.1 * (i / (len(schemata_list) + 1e-10)),
            "Gathered stats for %s" % schema_name)
        i += 1.0

    # Now take 60% in the first pass of reverse engineering:
    accumulated_progress = 0.2
    grt.reset_progress_steps()
    grt.begin_progress_step(accumulated_progress, accumulated_progress + 0.6)
    accumulated_schema_progress = 0.0
    for schema_name in schemata_list:
        schema_progress_share = total_count_per_schema.get(schema_name,
                                                           0.0) / total

        grt.begin_progress_step(
            accumulated_schema_progress,
            accumulated_schema_progress + schema_progress_share)

        this_schema_progress = 0.0

        schema = grt.classes.db_sybase_Schema()
        schema.owner = catalog
        schema.name = schema_name
        schema.defaultCollationName = catalog.defaultCollationName
        catalog.schemata.append(schema)

        # Reverse engineer tables:
        step_progress_share = table_count_per_schema[schema_name] / (
            total_count_per_schema[schema_name] + 1e-10)
        if get_tables:
            check_interruption()
            grt.send_info('Reverse engineering %i tables from %s' %
                          (table_count_per_schema[schema_name], schema_name))
            grt.begin_progress_step(this_schema_progress,
                                    this_schema_progress + step_progress_share)
            # Remove previous first pass marks that may exist if the user goes back and attempt rev eng again:
            progress_flags = _connections[connection.__id__].setdefault(
                '_rev_eng_progress_flags', set())
            progress_flags.discard('%s_tables_first_pass' % schema_name)
            reverseEngineerTables(connection, schema)
            grt.end_progress_step()

        this_schema_progress += step_progress_share
        grt.send_progress(
            this_schema_progress,
            'First pass of table reverse engineering for schema %s completed!'
            % schema_name)

        # Reverse engineer views:
        step_progress_share = view_count_per_schema[schema_name] / (
            total_count_per_schema[schema_name] + 1e-10)
        if get_views:
            check_interruption()
            grt.send_info('Reverse engineering %i views from %s' %
                          (view_count_per_schema[schema_name], schema_name))
            grt.begin_progress_step(this_schema_progress,
                                    this_schema_progress + step_progress_share)
            reverseEngineerViews(connection, schema)
            grt.end_progress_step()

        this_schema_progress += step_progress_share
        grt.send_progress(
            this_schema_progress,
            'Reverse engineering of views for schema %s completed!' %
            schema_name)

        # Reverse engineer routines:
        step_progress_share = routine_count_per_schema[schema_name] / (
            total_count_per_schema[schema_name] + 1e-10)
        if get_routines:
            check_interruption()
            grt.send_info('Reverse engineering %i routines from %s' %
                          (routine_count_per_schema[schema_name], schema_name))
            grt.begin_progress_step(
                this_schema_progress,
                this_schema_progress + step_progress_share / 2)
            schema.routines.remove_all()
            reverseEngineerProcedures(connection, schema)
            grt.end_progress_step()
            check_interruption()
            grt.begin_progress_step(
                this_schema_progress + step_progress_share / 2,
                this_schema_progress + step_progress_share)
            reverseEngineerFunctions(connection, schema)
            grt.end_progress_step()

        this_schema_progress += step_progress_share
        grt.send_progress(
            this_schema_progress,
            'Reverse engineering of routines for schema %s completed!' %
            schema_name)

        # Reverse engineer triggers:
        step_progress_share = trigger_count_per_schema[schema_name] / (
            total_count_per_schema[schema_name] + 1e-10)
        if get_triggers:
            check_interruption()
            grt.send_info('Reverse engineering %i triggers from %s' %
                          (trigger_count_per_schema[schema_name], schema_name))
            grt.begin_progress_step(this_schema_progress,
                                    this_schema_progress + step_progress_share)
            reverseEngineerTriggers(connection, schema)
            grt.end_progress_step()

        this_schema_progress += step_progress_share
        grt.send_progress(
            this_schema_progress,
            'Reverse engineering of triggers for schema %s completed!' %
            schema_name)

        accumulated_schema_progress += schema_progress_share
        grt.end_progress_step()

    grt.end_progress_step()

    # Now the second pass for reverse engineering tables:
    accumulated_progress = 0.8
    if get_tables:
        total_tables = sum(table_count_per_schema[schema.name]
                           for schema in catalog.schemata)
        for schema in catalog.schemata:
            check_interruption()
            step_progress_share = 0.2 * (table_count_per_schema[schema.name] /
                                         (total_tables + 1e-10))
            grt.send_info(
                'Reverse engineering foreign keys for tables in schema %s' %
                schema.name)
            grt.begin_progress_step(accumulated_progress,
                                    accumulated_progress + step_progress_share)
            reverseEngineerTables(connection, schema)
            grt.end_progress_step()

            accumulated_progress += step_progress_share
            grt.send_progress(
                accumulated_progress,
                'Second pass of table reverse engineering for schema %s completed!'
                % schema.name)

    grt.send_progress(1.0, 'Reverse engineering completed!')
    return catalog
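
The nested begin_progress_step()/end_progress_step() calls remap every send_progress() value into a sub-range of the enclosing step, which is how the per-schema and per-object-type shares compose. grt's real implementation is native code; the toy stand-in below only illustrates the arithmetic being assumed:

# Toy model of nested progress steps: a stack of (lo, hi) ranges.
_step_stack = [(0.0, 1.0)]

def begin_progress_step(lo, hi):
    a, b = _step_stack[-1]
    _step_stack.append((a + lo * (b - a), a + hi * (b - a)))

def end_progress_step():
    _step_stack.pop()

def send_progress(p, msg):
    a, b = _step_stack[-1]
    print("%.3f %s" % (a + p * (b - a), msg))

begin_progress_step(0.2, 0.8)
send_progress(0.5, "halfway through the middle step")  # prints 0.500 ...
end_progress_step()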
Example #44
    def reverseEngineer(cls, connection, catalog_name, schemata_list, context):
        grt.send_progress(0, "Reverse engineering catalog information")
        cls.check_interruption()
        catalog = cls.reverseEngineerCatalog(connection, catalog_name)

        # calculate total workload 1st
        grt.send_progress(0.1, 'Preparing...')
        table_count_per_schema = {}
        view_count_per_schema = {}
        routine_count_per_schema = {}
        trigger_count_per_schema = {}
        total_count_per_schema = {}

        get_tables = context.get("reverseEngineerTables", True)
        get_triggers = context.get("reverseEngineerTriggers", True)
        get_views = context.get("reverseEngineerViews", True)
        get_routines = context.get("reverseEngineerRoutines", True)

        # 10% of the progress is for preparation
        total = 1e-10  # total should not be zero to avoid DivisionByZero exceptions
        i = 0.0
        accumulated_progress = 0.1
        for schema_name in schemata_list:
            cls.check_interruption()
            table_count_per_schema[schema_name] = len(cls.getTableNames(connection, catalog_name, schema_name)) if get_tables else 0
            view_count_per_schema[schema_name] = len(cls.getViewNames(connection, catalog_name, schema_name)) if get_views else 0
            cls.check_interruption()
            routine_count_per_schema[schema_name] = len(cls.getProcedureNames(connection, catalog_name, schema_name)) + len(cls.getFunctionNames(connection, catalog_name, schema_name)) if get_routines else 0
            trigger_count_per_schema[schema_name] = len(cls.getTriggerNames(connection, catalog_name, schema_name)) if get_triggers else 0

            total_count_per_schema[schema_name] = (table_count_per_schema[schema_name] + view_count_per_schema[schema_name] +
                                                   routine_count_per_schema[schema_name] + trigger_count_per_schema[schema_name] + 1e-10)
            total += total_count_per_schema[schema_name]

            grt.send_progress(accumulated_progress + 0.1 * (i / (len(schemata_list) + 1e-10) ), "Gathered stats for %s" % schema_name)
            i += 1.0

        # Now take 60% in the first pass of reverse engineering:
        accumulated_progress = 0.2
        for schema_name in schemata_list:
            schema_progress_share = 0.6 * (total_count_per_schema.get(schema_name, 0.0) / total)
            schema = find_object_with_name(catalog.schemata, schema_name) 

            if schema:
                # Reverse engineer tables:
                step_progress_share = schema_progress_share * (table_count_per_schema[schema_name] / (total_count_per_schema[schema_name] + 1e-10))
                if get_tables:
                    cls.check_interruption()
                    grt.send_info('Reverse engineering tables from %s' % schema_name)
                    grt.begin_progress_step(accumulated_progress, accumulated_progress + step_progress_share)
                    # Remove previous first pass marks that may exist if the user goes back and attempt rev eng again:
                    progress_flags = cls._connections[connection.__id__].setdefault('_rev_eng_progress_flags', set())
                    progress_flags.discard('%s_tables_first_pass' % schema_name)
                    cls.reverseEngineerTables(connection, schema)
                    grt.end_progress_step()
        
                accumulated_progress += step_progress_share
                grt.send_progress(accumulated_progress, 'First pass of table reverse engineering for schema %s completed!' % schema_name)
        
                # Reverse engineer views:
                step_progress_share = schema_progress_share * (view_count_per_schema[schema_name] / (total_count_per_schema[schema_name] + 1e-10))
                if get_views:
                    cls.check_interruption()
                    grt.send_info('Reverse engineering views from %s' % schema_name)
                    grt.begin_progress_step(accumulated_progress, accumulated_progress + step_progress_share)
                    cls.reverseEngineerViews(connection, schema)
                    grt.end_progress_step()
        
                accumulated_progress += step_progress_share
                grt.send_progress(accumulated_progress, 'Reverse engineering of views for schema %s completed!' % schema_name)
        
                # Reverse engineer routines:
                step_progress_share = schema_progress_share * (routine_count_per_schema[schema_name] / (total_count_per_schema[schema_name] + 1e-10))
                if get_routines:
                    cls.check_interruption()
                    grt.send_info('Reverse engineering routines from %s' % schema_name)
                    grt.begin_progress_step(accumulated_progress, accumulated_progress + step_progress_share)
                    grt.begin_progress_step(0.0, 0.5)
                    cls.reverseEngineerProcedures(connection, schema)
                    cls.check_interruption()
                    grt.end_progress_step()
                    grt.begin_progress_step(0.5, 1.0)
                    cls.reverseEngineerFunctions(connection, schema)
                    grt.end_progress_step()
                    grt.end_progress_step()
        
                accumulated_progress += step_progress_share
                grt.send_progress(accumulated_progress, 'Reverse engineering of routines for schema %s completed!' % schema_name)
        
                # Reverse engineer triggers:
                step_progress_share = schema_progress_share * (trigger_count_per_schema[schema_name] / (total_count_per_schema[schema_name] + 1e-10))
                if get_triggers:
                    cls.check_interruption()
                    grt.send_info('Reverse engineering triggers from %s' % schema_name)
                    grt.begin_progress_step(accumulated_progress, accumulated_progress + step_progress_share)
                    cls.reverseEngineerTriggers(connection, schema)
                    grt.end_progress_step()
        
                accumulated_progress = 0.8
                grt.send_progress(accumulated_progress, 'Reverse engineering of triggers for schema %s completed!' % schema_name)
            else:  # No schema with the given name was found
                grt.send_warning('The schema %s was not found in the catalog %s. Skipping it.' % (schema_name, catalog_name) )
                
        # Now the second pass for reverse engineering tables:
        if get_tables:
            total_tables = sum(table_count_per_schema[schema.name] for schema in catalog.schemata if schema.name in schemata_list)
            for schema in catalog.schemata:
                if schema.name not in schemata_list:
                    continue
                cls.check_interruption()
                step_progress_share = 0.2 * (table_count_per_schema[schema.name] / (total_tables + 1e-10))
                grt.send_info('Reverse engineering foreign keys for tables in schema %s' % schema.name)
                grt.begin_progress_step(accumulated_progress, accumulated_progress + step_progress_share)
                cls.reverseEngineerTables(connection, schema)
                grt.end_progress_step()
        
                accumulated_progress += step_progress_share
                grt.send_progress(accumulated_progress, 'Second pass of table reverse engineering for schema %s completed!' % schema.name)
            

        grt.send_progress(1.0, 'Reverse engineering completed!')
        return catalog
Example #45
def createCatalogObjects(connection, catalog, objectCreationParams, creationLog):
    """Create catalog objects in the server for the specified connection. The catalog must have been 
    previously processed with generateSQLCreateStatements(), so that the objects have their temp_sql 
    attributes set with their respective SQL CREATE statements.
    """

    def makeLogObject(obj):
        if creationLog is not None:
            log = grt.classes.GrtLogObject()
            log.logObject = obj
            creationLog.append(log)
            return log
        else:
            return None
    
    try:
        grt.send_progress(0.0, "Creating schema in target MySQL server at %s..." % connection.hostIdentifier)
        
        preamble = catalog.customData["migration:preamble"]
        grt.send_progress(0.0, "Executing preamble script...")
        execute_script(connection, preamble.temp_sql, makeLogObject(preamble))

        i = 0.0
        for schema in catalog.schemata:
            grt.begin_progress_step(i, i + 1.0 / len(catalog.schemata))
            i += 1.0 / len(catalog.schemata)

            if schema.commentedOut:
                grt.send_progress(1.0, "Skipping schema %s... " % schema.name)
                grt.end_progress_step()
                continue

            total = len(schema.tables) + len(schema.views) + len(schema.routines) + sum([len(table.triggers) for table in schema.tables])

            grt.send_progress(0.0, "Creating schema %s..." % schema.name)
            execute_script(connection, schema.temp_sql, makeLogObject(schema))

            tcount = 0
            vcount = 0
            rcount = 0
            trcount = 0
            o = 0
            for table in schema.tables:
                if table.commentedOut:
                    grt.send_progress(float(o) / total, "Skipping table %s.%s" % (schema.name, table.name))
                else:
                    grt.send_progress(float(o) / total, "Creating table %s.%s" % (schema.name, table.name))
                o += 1
                if not table.commentedOut and execute_script(connection, table.temp_sql, makeLogObject(table)):
                    tcount += 1

            for view in schema.views:
                if view.commentedOut:
                    grt.send_progress(float(o) / total, "Skipping view %s.%s" % (schema.name, view.name))
                else:
                    grt.send_progress(float(o) / total, "Creating view %s.%s" % (schema.name, view.name))
                o += 1
                if not view.commentedOut and execute_script(connection, view.temp_sql, makeLogObject(view)):
                    vcount += 1

            for routine in schema.routines:
                if routine.commentedOut:
                    grt.send_progress(float(o) / total, "Skipping routine %s.%s" % (schema.name, routine.name))
                else:
                    grt.send_progress(float(o) / total, "Creating routine %s.%s" % (schema.name, routine.name))
                o += 1
                if not routine.commentedOut and execute_script(connection, routine.temp_sql, makeLogObject(routine)):
                    rcount += 1

            for table in schema.tables:
                for trigger in table.triggers:
                    if trigger.commentedOut:
                        grt.send_progress(float(o) / total, "Skipping trigger %s.%s.%s" % (schema.name, table.name, trigger.name))
                    else:
                        grt.send_progress(float(o) / total, "Creating trigger %s.%s.%s" % (schema.name, table.name, trigger.name))
                    o += 1
                    if not trigger.commentedOut and execute_script(connection, trigger.temp_sql, makeLogObject(trigger)):
                        trcount += 1

            grt.send_info("Scripts for %i tables, %i views and %i routines were executed for schema %s" % (tcount, vcount, rcount, schema.name))
            grt.end_progress_step()

        postamble = catalog.customData["migration:postamble"]
        grt.send_progress(1.0, "Executing postamble script...")
        execute_script(connection, postamble.temp_sql, makeLogObject(postamble))

        grt.send_progress(1.0, "Schema created")
    except grt.UserInterrupt:
        grt.send_info("Cancellation request detected, interrupting schema creation.")
        raise
    
    return 1
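
The docstring's precondition can be made explicit with a hypothetical call sequence. generateSQLCreateStatements is named by the docstring; its exact signature, the module path, and the log handling below are assumptions for illustration:

# Hypothetical call order: generate temp_sql first, then create the objects.
creation_log = []
grt.modules.DbMySQLFE.generateSQLCreateStatements(catalog, version, {})  # assumed signature
grt.modules.DbMySQLFE.createCatalogObjects(connection, catalog, {}, creation_log)
for log in creation_log:
    for entry in log.entries:
        print(log.logObject.name, entry.name)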
Example #46
    def _merge_schemata(self, prefix=''):
        catalog = self.main.plan.migrationSource.catalog
        schema = catalog.schemata[0]
        # preserve the original name of the catalog
        schema.oldName = schema.name

        module_db = self.main.plan.migrationSource.module_db()

        # otypes is something like ['tables', 'views', 'routines']:
        otypes = [
            suptype[0]
            for suptype in self.main.plan.migrationSource.supportedObjectTypes
        ]

        # Update names for the objects of this first schema:
        if prefix:
            actual_prefix = (schema.name if prefix == 'schema_name' else
                             schema.__id__) + '_'
            for otype in otypes:
                for obj in getattr(schema, otype):
                    # this will be used later during datacopy to refer to the original object to copy from
                    obj.oldName = module_db.quoteIdentifier(
                        schema.oldName) + "." + module_db.quoteIdentifier(
                            obj.name)
                    oname = obj.name
                    obj.name = actual_prefix + obj.name
                    grt.send_info("Object %s was renamed to %s" %
                                  (oname, obj.name))
        else:
            for otype in otypes:
                for obj in getattr(schema, otype):
                    # this will be used later during datacopy to refer to the original object to copy from
                    obj.oldName = module_db.quoteIdentifier(
                        schema.name) + "." + module_db.quoteIdentifier(
                            obj.name)

        schema.name = catalog.name
        if not prefix:
            known_names = dict(
                (otype, set(obj.name for obj in getattr(schema, otype)))
                for otype in otypes)

        for other_schema in list(catalog.schemata)[1:]:
            if other_schema.defaultCharacterSetName != schema.defaultCharacterSetName:
                grt.send_warning(
                    'While merging schema %s into %s: Default charset for schemas differs (%s vs %s). Setting default charset to %s'
                    % (other_schema.name, schema.name,
                       other_schema.defaultCharacterSetName,
                       schema.defaultCharacterSetName,
                       schema.defaultCharacterSetName))
                self.main.plan.state.addMigrationLogEntry(
                    0, other_schema, None,
                    'While merging schema %s into %s: Default charset for schemas differs (%s vs %s). Setting default charset to %s'
                    % (other_schema.name, schema.name,
                       other_schema.defaultCharacterSetName,
                       schema.defaultCharacterSetName,
                       schema.defaultCharacterSetName))

            if other_schema.defaultCollationName != schema.defaultCollationName:
                grt.send_warning(
                    'While merging schema %s into %s: Default collation for schemas differs (%s vs %s). Setting default collation to %s'
                    %
                    (other_schema.name, schema.name,
                     other_schema.defaultCollationName,
                     schema.defaultCollationName, schema.defaultCollationName))
                self.main.plan.state.addMigrationLogEntry(
                    0, other_schema, None,
                    'While merging schema %s into %s: Default collation for schemas differs (%s vs %s). Setting default collation to %s'
                    %
                    (other_schema.name, schema.name,
                     other_schema.defaultCollationName,
                     schema.defaultCollationName, schema.defaultCollationName))

            for otype in otypes:
                other_objects = getattr(other_schema, otype)
                if not prefix:
                    repeated_object_names = known_names[otype].intersection(
                        obj.name for obj in other_objects)
                    if repeated_object_names:
                        objects_dict = dict(
                            (obj.name, obj) for obj in other_objects)
                        for repeated_object_name in repeated_object_names:
                            objects_dict[
                                repeated_object_name].name += '_' + other_schema.name
                            grt.send_warning(
                                'The name of the %(otype)s "%(oname)s" conflicts with other %(otype)s names: renamed to "%(onewname)s"'
                                % {
                                    'otype':
                                    otype[:-1],
                                    'oname':
                                    repeated_object_name,
                                    'onewname':
                                    objects_dict[repeated_object_name].name
                                })

                            self.main.plan.state.addMigrationLogEntry(
                                0, other_schema, None,
                                'The name of the %(otype)s "%(oname)s" conflicts with other %(otype)s names: renamed to "%(onewname)s"'
                                % {
                                    'otype':
                                    otype[:-1],
                                    'oname':
                                    repeated_object_name,
                                    'onewname':
                                    objects_dict[repeated_object_name].name
                                })
                        known_names[otype].update(obj.name for obj in other_objects)  # add the names, not the objects
                else:
                    actual_prefix = (other_schema.name if prefix
                                     == 'schema_name' else schema.__id__) + '_'

                getattr(schema, otype).extend(other_objects)
                for obj in other_objects:
                    # this will be used later during datacopy to refer to the original object to copy from
                    obj.oldName = module_db.quoteIdentifier(
                        obj.owner.name) + "." + module_db.quoteIdentifier(
                            obj.name)

                    obj.owner = schema
                    if prefix:
                        oname = obj.name
                        obj.name = actual_prefix + obj.name
                        grt.send_info("Object %s was renamed to %s" %
                                      (oname, obj.name))

        # Keep only the merged schema:
        catalog.schemata.remove_all()
        catalog.schemata.append(schema)
Example #47
    def reverseEngineerViews(cls, connection, schema):
        for view_name in cls.getViewNames(connection, schema.owner.name, schema.name):
            grt.send_info('%s reverseEngineerViews: Cannot reverse engineer view "%s"' % (cls.getTargetDBMSName(), view_name))
        return 0
Example #48
def reverseEngineer(connection, catalog_name, schemata_list, options):
    """Reverse engineers a Sybase ASE database.

    This is the function that will be called by the Migration Wizard to reverse engineer
    a Sybase database. All the other reverseEngineer* functions are not actually required
    and should not be considered part of this module API even though they are currently
    being exposed. This function calls the other reverseEngineer* functions to complete
    the full reverse engineer process.
    """
    grt.send_progress(0, "Reverse engineering catalog information")
    catalog = grt.classes.db_sybase_Catalog()
    catalog.name = catalog_name
    catalog.simpleDatatypes.remove_all()
    catalog.simpleDatatypes.extend(connection.driver.owner.simpleDatatypes)
    catalog.defaultCollationName = '' #   FIXME: Find out the right collation for the catalog
    
    grt.send_progress(0.05, "Reverse engineering User Data Types...")
    check_interruption()  #
    reverseEngineerUserDatatypes(connection, catalog)

    # calculate total workload 1st
    grt.send_progress(0.1, 'Preparing...')
    table_count_per_schema = {}
    view_count_per_schema = {}
    routine_count_per_schema = {}
    trigger_count_per_schema = {}
    total_count_per_schema = {}

    get_tables = options.get("reverseEngineerTables", True)
    get_triggers = options.get("reverseEngineerTriggers", True)
    get_views = options.get("reverseEngineerViews", True)
    get_routines = options.get("reverseEngineerRoutines", True)

    # 10% of the progress is for preparation
    total = 1e-10  # total should not be zero to avoid DivisionByZero exceptions
    i = 1.0
    accumulated_progress = 0.1
    for schema_name in schemata_list:
        check_interruption()
        table_count_per_schema[schema_name] = len(getTableNames(connection, catalog_name, schema_name)) if get_tables else 0
        view_count_per_schema[schema_name] = len(getViewNames(connection, catalog_name, schema_name)) if get_views else 0
        check_interruption()
        routine_count_per_schema[schema_name] = len(getProcedureNames(connection, catalog_name, schema_name)) + len(getFunctionNames(connection, catalog_name, schema_name)) if get_routines else 0
        trigger_count_per_schema[schema_name] = len(getTriggerNames(connection, catalog_name, schema_name)) if get_triggers else 0

        total_count_per_schema[schema_name] = (table_count_per_schema[schema_name] + view_count_per_schema[schema_name] +
                                               routine_count_per_schema[schema_name] + trigger_count_per_schema[schema_name] + 1e-10)
        total += total_count_per_schema[schema_name]

        grt.send_progress(accumulated_progress + 0.1 * (i / (len(schemata_list) + 1e-10) ), "Gathered stats for %s" % schema_name)
        i += 1.0

    # Now take 60% in the first pass of reverse engineering:
    accumulated_progress = 0.2
    grt.reset_progress_steps()
    grt.begin_progress_step(accumulated_progress, accumulated_progress + 0.6)
    accumulated_schema_progress = 0.0
    for schema_name in schemata_list:
        schema_progress_share = total_count_per_schema.get(schema_name, 0.0) / total

        grt.begin_progress_step(accumulated_schema_progress, accumulated_schema_progress + schema_progress_share)
        
        this_schema_progress = 0.0

        schema = grt.classes.db_sybase_Schema()
        schema.owner = catalog
        schema.name = schema_name
        schema.defaultCollationName = catalog.defaultCollationName
        catalog.schemata.append(schema)

        # Reverse engineer tables:
        step_progress_share = table_count_per_schema[schema_name] / (total_count_per_schema[schema_name] + 1e-10)
        if get_tables:
            check_interruption()
            grt.send_info('Reverse engineering %i tables from %s' % (table_count_per_schema[schema_name], schema_name))
            grt.begin_progress_step(this_schema_progress, this_schema_progress + step_progress_share)
            # Remove previous first-pass marks that may exist if the user goes back and attempts reverse engineering again:
            progress_flags = _connections[connection.__id__].setdefault('_rev_eng_progress_flags', set())
            progress_flags.discard('%s_tables_first_pass' % schema_name)
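            # reverseEngineerTables() is expected to check this flag: when it is absent this
            # is the first pass (table skeletons); the second pass further below re-runs it
            # to pick up foreign keys once every referenced table exists.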
            reverseEngineerTables(connection, schema)
            grt.end_progress_step()

        this_schema_progress += step_progress_share
        grt.send_progress(this_schema_progress, 'First pass of table reverse engineering for schema %s completed!' % schema_name)

        # Reverse engineer views:
        step_progress_share = view_count_per_schema[schema_name] / (total_count_per_schema[schema_name] + 1e-10)
        if get_views:
            check_interruption()
            grt.send_info('Reverse engineering %i views from %s' % (view_count_per_schema[schema_name], schema_name))
            grt.begin_progress_step(this_schema_progress, this_schema_progress + step_progress_share)
            reverseEngineerViews(connection, schema)
            grt.end_progress_step()

        this_schema_progress += step_progress_share
        grt.send_progress(this_schema_progress, 'Reverse engineering of views for schema %s completed!' % schema_name)

        # Reverse engineer routines:
        step_progress_share = routine_count_per_schema[schema_name] / (total_count_per_schema[schema_name] + 1e-10)
        if get_routines:
            check_interruption()
            grt.send_info('Reverse engineering %i routines from %s' % (routine_count_per_schema[schema_name], schema_name))
            grt.begin_progress_step(this_schema_progress, this_schema_progress + step_progress_share/2)
            schema.routines.remove_all()
            reverseEngineerProcedures(connection, schema)
            grt.end_progress_step()
            check_interruption()
            grt.begin_progress_step(this_schema_progress + step_progress_share/2, this_schema_progress + step_progress_share)
            reverseEngineerFunctions(connection, schema)
            grt.end_progress_step()

        this_schema_progress += step_progress_share
        grt.send_progress(this_schema_progress, 'Reverse engineering of routines for schema %s completed!' % schema_name)

        # Reverse engineer triggers:
        step_progress_share = trigger_count_per_schema[schema_name] / (total_count_per_schema[schema_name] + 1e-10)
        if get_triggers:
            check_interruption()
            grt.send_info('Reverse engineering %i triggers from %s' % (trigger_count_per_schema[schema_name], schema_name))
            grt.begin_progress_step(this_schema_progress, this_schema_progress + step_progress_share)
            reverseEngineerTriggers(connection, schema)
            grt.end_progress_step()

        this_schema_progress += step_progress_share
        grt.send_progress(this_schema_progress, 'Reverse engineering of triggers for schema %s completed!' % schema_name)
    
        accumulated_schema_progress += schema_progress_share
        grt.end_progress_step()

    grt.end_progress_step()

    # Now the second pass for reverse engineering tables:
    accumulated_progress = 0.8
    if get_tables:
        total_tables = sum(table_count_per_schema[schema.name] for schema in catalog.schemata)
        for schema in catalog.schemata:
            check_interruption()
            step_progress_share = 0.2 * (table_count_per_schema[schema.name] / (total_tables + 1e-10))
            grt.send_info('Reverse engineering foreign keys for tables in schema %s' % schema.name)
            grt.begin_progress_step(accumulated_progress, accumulated_progress + step_progress_share)
            reverseEngineerTables(connection, schema)
            grt.end_progress_step()
    
            accumulated_progress += step_progress_share
            grt.send_progress(accumulated_progress, 'Second pass of table reverse engineering for schema %s completed!' % schema.name)
        

    grt.send_progress(1.0, 'Reverse engineering completed!')
    return catalog
    @classmethod
    def reverseEngineerFunctions(cls, connection, schema):
        # Unfortunately there seems to be no way to get the SQL definition of a stored procedure/function over ODBC
        for function_name in cls.getFunctionNames(connection, schema.owner.name, schema.name):
            grt.send_info('%s reverseEngineerFunctions: Cannot reverse engineer function "%s"' % (cls.getTargetDBMSName(), function_name))
        return 0

    @classmethod
    def reverseEngineerTriggers(cls, connection, schema):
        # Unfortunately there seems to be no way to get the SQL definition of a trigger over ODBC
        for trigger_name in cls.getTriggerNames(connection, schema.owner.name, schema.name):
            grt.send_info('%s reverseEngineerTriggers: Cannot reverse engineer trigger "%s"' % (cls.getTargetDBMSName(), trigger_name))
        return 0
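The weighting logic in reverseEngineer() above maps each schema's object counts onto the
0.2-0.8 progress window, nesting grt.begin_progress_step()/grt.end_progress_step() ranges
per schema and then per object type. The following standalone sketch (made-up object
counts, no grt dependency) reproduces just that arithmetic, including the 1e-10 guards
that keep empty schemata from dividing by zero:

EPS = 1e-10

counts = {
    'sales': {'tables': 12, 'views': 3, 'routines': 5, 'triggers': 2},
    'hr':    {'tables': 4,  'views': 0, 'routines': 1, 'triggers': 0},
}

# Mirrors the total/total_count_per_schema accumulation in reverseEngineer()
total = EPS + sum(sum(c.values()) + EPS for c in counts.values())

start = 0.2  # the first pass occupies the 0.2-0.8 window
for name, c in sorted(counts.items()):
    schema_total = sum(c.values()) + EPS
    share = 0.6 * (schema_total / total)  # this schema's slice of the window
    inner = start
    for kind in ('tables', 'views', 'routines', 'triggers'):
        step = share * (c[kind] / schema_total)  # sub-range for this object type
        print('%-8s %-8s progress %.3f -> %.3f' % (name, kind, inner, inner + step))
        inner += step
    start += share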
Example #51
def createCatalogObjects(connection, catalog, objectCreationParams,
                         creationLog):
    """Create catalog objects in the server for the specified connection. The catalog must have been 
    previously processed with generateSQLCreateStatements(), so that the objects have their temp_sql 
    attributes set with their respective SQL CREATE statements.
    """
    def makeLogObject(obj):
        if creationLog is not None:
            log = grt.classes.GrtLogObject()
            log.logObject = obj
            creationLog.append(log)
            return log
        else:
            return None

    try:
        grt.send_progress(
            0.0, "Creating schema in target MySQL server at %s..." %
            connection.hostIdentifier)

        preamble = catalog.customData["migration:preamble"]
        grt.send_progress(0.0, "Executing preamble script...")
        execute_script(connection, preamble.temp_sql, makeLogObject(preamble))

        i = 0.0
        for schema in catalog.schemata:
            grt.begin_progress_step(i, i + 1.0 / len(catalog.schemata))
            i += 1.0 / len(catalog.schemata)

            if schema.commentedOut:
                grt.send_progress(1.0, "Skipping schema %s... " % schema.name)
                grt.end_progress_step()
                continue

            total = len(schema.tables) + len(schema.views) + len(
                schema.routines) + sum(
                    [len(table.triggers) for table in schema.tables])

            grt.send_progress(0.0, "Creating schema %s..." % schema.name)
            execute_script(connection, schema.temp_sql, makeLogObject(schema))

            tcount = 0
            vcount = 0
            rcount = 0
            trcount = 0
            o = 0
            for table in schema.tables:
                if table.commentedOut:
                    grt.send_progress(
                        float(o) / total,
                        "Skipping table %s.%s" % (schema.name, table.name))
                else:
                    grt.send_progress(
                        float(o) / total,
                        "Creating table %s.%s" % (schema.name, table.name))
                o += 1
                if not table.commentedOut and execute_script(
                        connection, table.temp_sql, makeLogObject(table)):
                    tcount += 1

            for view in schema.views:
                if view.commentedOut:
                    grt.send_progress(
                        float(o) / total,
                        "Skipping view %s.%s" % (schema.name, view.name))
                else:
                    grt.send_progress(
                        float(o) / total,
                        "Creating view %s.%s" % (schema.name, view.name))
                o += 1
                if not view.commentedOut and execute_script(
                        connection, view.temp_sql, makeLogObject(view)):
                    vcount += 1

            for routine in schema.routines:
                if routine.commentedOut:
                    grt.send_progress(
                        float(o) / total,
                        "Skipping routine %s.%s" % (schema.name, routine.name))
                else:
                    grt.send_progress(
                        float(o) / total,
                        "Creating routine %s.%s" % (schema.name, routine.name))
                o += 1
                if not routine.commentedOut and execute_script(
                        connection, routine.temp_sql, makeLogObject(routine)):
                    rcount += 1

            for table in schema.tables:
                for trigger in table.triggers:
                    if trigger.commentedOut:
                        grt.send_progress(
                            float(o) / total, "Skipping trigger %s.%s.%s" %
                            (schema.name, table.name, trigger.name))
                    else:
                        grt.send_progress(
                            float(o) / total, "Creating trigger %s.%s.%s" %
                            (schema.name, table.name, trigger.name))
                    o += 1
                    if not trigger.commentedOut and execute_script(
                            connection, trigger.temp_sql,
                            makeLogObject(trigger)):
                        trcount += 1

            grt.send_info(
                "Scripts for %i tables, %i views, %i routines and %i triggers were executed for schema %s"
                % (tcount, vcount, rcount, trcount, schema.name))
            grt.end_progress_step()

        postamble = catalog.customData["migration:postamble"]
        grt.send_progress(1.0, "Executing postamble script...")
        execute_script(connection, postamble.temp_sql,
                       makeLogObject(postamble))

        grt.send_progress(1.0, "Schema created")
    except grt.UserInterrupt:
        grt.send_info(
            "Cancellation request detected, interrupting schema creation.")
        raise

    return 1
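createCatalogObjects() executes whatever is in each object's temp_sql, so an object that
generateSQLCreateStatements() never processed is simply skipped without any error. A small
pre-flight walk over the catalog can flag that earlier. The helper below is hypothetical
(not part of the Workbench API); it relies only on attributes the function above already
uses (customData, schemata, tables, views, routines, triggers, commentedOut, temp_sql):

def find_objects_missing_sql(catalog):
    missing = []
    # The pre/post scripts live in customData, as accessed above
    for key in ("migration:preamble", "migration:postamble"):
        try:
            obj = catalog.customData[key]
        except KeyError:
            obj = None
        if obj is None or not getattr(obj, "temp_sql", ""):
            missing.append(key)
    for schema in catalog.schemata:
        objects = [schema] + list(schema.tables) + list(schema.views) + list(schema.routines)
        for table in schema.tables:
            objects += list(table.triggers)
        for obj in objects:
            # Objects flagged commentedOut are skipped here too, mirroring the loops above
            if not obj.commentedOut and not getattr(obj, "temp_sql", ""):
                missing.append("%s: %s" % (schema.name, obj.name))
    return missing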
                    (key, value)
                    for key, value in all_params_dict.iteritems()
                    if not (value.startswith("%") and value.endswith("%"))
                )
                params["password"] = password
                conn_params = dict(params)
                conn_params["password"] = "******"
                connection.parameterValues["wbcopytables_connection_string"] = repr(conn_params)

                con = sqlanydb.connect(**params)
            else:
                con = db_driver.connect(connection, password)
            if not con:
                grt.send_error("Connection failed", str(exc))
                raise
            grt.send_info("Connected")
            cls._connections[connection.__id__] = {"connection": con}
        if con:
            ver = cls.execute_query(connection, "SELECT @@version").fetchone()[0]
            grt.log_info("SQLAnywhere RE", "Connected to %s, %s\n" % (connection.name, ver))
            ver_parts = server_version_str2tuple(ver) + (0, 0, 0, 0)
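            # Padding with (0, 0, 0, 0) and slicing to [:4] below guarantees four
            # version components even if the server reports a short string like "16.0"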
            version = grt.classes.GrtVersion()
            version.majorNumber, version.minorNumber, version.releaseNumber, version.buildNumber = ver_parts[:4]
            cls._connections[connection.__id__]["version"] = version

        return 1

    @classmethod
    @release_cursors
    def getCatalogNames(cls, connection):
        """Returns a list of the available catalogs.