Пример #1
0
    def init_registry(self, registry_filename=None):
        """Create the module registry and install it as the global registry.

        registry_filename -- when given, the registry is loaded from this
            file with core.db.io.open_registry and nothing else is done.
            Otherwise a fresh ModuleRegistry is created, the 'basic_modules'
            package is added, made the registry's default package and
            initialized, and the 'abstraction' (subworkflow) package is
            created (without joining the package list) and added to the
            registry.
        """
        if registry_filename is not None:
            self._registry = core.db.io.open_registry(registry_filename)
            self._registry.set_global()
            return

        self._registry = ModuleRegistry()
        self._registry.set_global()

        # The basic package is set up first and becomes the default package.
        basic_package = self.add_package('basic_modules')
        self._registry._default_package = basic_package
        self.initialize_packages({'basic_modules': 'core.modules.'})

        abstraction_pkg = self.add_package('abstraction', False)
        self._abstraction_pkg = abstraction_pkg
        # FIXME need to get this info from the package, but cannot
        # do this since controller isn't imported yet
        abstraction_pkg.identifier = 'local.abstractions'
        abstraction_pkg.name = 'My SubWorkflows'
        abstraction_pkg.version = '1.6'
        self._registry.add_package(abstraction_pkg)
 def make_registry(self):
     """Build a private ModuleRegistry holding this module's ports.

     The new registry is seeded with add_hierarchy(registry, self).  Every
     'OutputPort' module of self.pipeline yields a source port and every
     'InputPort' module a destination port.  Both are attached to the
     descriptor looked up under 'edu.utah.sci.vistrails.basic' with this
     module's name.
     """
     basic_descriptor = registry.get_descriptor_by_name(
         'edu.utah.sci.vistrails.basic', self.name)
     reg_module = basic_descriptor.module
     self._registry = ModuleRegistry()
     self._registry.add_hierarchy(registry, self)
     for module in self.pipeline.module_list:
         if module.name == 'OutputPort':
             role, endpoint = 'source', PortEndPoint.Source
         elif module.name == 'InputPort':
             role, endpoint = 'destination', PortEndPoint.Destination
         else:
             # Only the port-marker modules contribute ports.
             continue
         port = self.make_port_from_module(module, role)
         self._registry.add_port(reg_module, endpoint, port)
Пример #3
0
    def add_port_to_registry(self, port_spec):
        """Register port_spec as a port of this module in self.registry.

        The module descriptor is looked up in the global ``registry`` by
        (self.package, self.name, self.namespace).  If this module has no
        registry of its own yet, a ModuleRegistry is created and initialized
        with add_hierarchy(registry, self).

        port_spec.type == 'input' yields a destination port; any other type
        yields a source port.  port_spec.spec appears to be a parenthesized,
        comma-separated list of module names (judging from the slicing and
        splitting below).  Each name is resolved with
        registry.get_descriptor_from_name_only.
        """
        descriptor_module = registry.get_descriptor_by_name(
            self.package, self.name, self.namespace).module
        if self.registry is None:
            self.registry = ModuleRegistry()
            self.registry.add_hierarchy(registry, self)

        if port_spec.type == 'input':
            endpoint = PortEndPoint.Destination
        else:
            endpoint = PortEndPoint.Source
        # Drop the surrounding parentheses.  Tolerate whitespace around the
        # names, and treat "()" as an empty signature instead of looking up
        # a module named ''.
        type_names = [type_name.strip()
                      for type_name in port_spec.spec[1:-1].split(',')
                      if type_name.strip()]
        signature = [registry.get_descriptor_from_name_only(type_name).module
                     for type_name in type_names]
        port = Port()
        port.name = port_spec.name
        port.spec = core.modules.module_registry.PortSpec(signature)
        self.registry.add_port(descriptor_module, endpoint, port)
    def init_registry(self, registry_filename=None):
        """Create the module registry and install it as the global registry.

        registry_filename -- when given, the registry is loaded from this
            file with core.db.io.open_registry and nothing else is done.
            Otherwise a fresh ModuleRegistry is created, the 'basic_modules'
            package is added, made the registry's default package and
            initialized, and the 'abstraction' (subworkflow) package is
            created (without joining the package list) and added to the
            registry.
        """
        if registry_filename is not None:
            self._registry = core.db.io.open_registry(registry_filename)
            self._registry.set_global()
            return

        self._registry = ModuleRegistry()
        self._registry.set_global()

        # The basic package is set up first and becomes the default package.
        basic_package = self.add_package('basic_modules')
        self._registry._default_package = basic_package
        self.initialize_packages({'basic_modules': 'core.modules.'})

        abstraction_pkg = self.add_package('abstraction', False)
        self._abstraction_pkg = abstraction_pkg
        # FIXME need to get this info from the package, but cannot
        # do this since controller isn't imported yet
        abstraction_pkg.identifier = 'local.abstractions'
        abstraction_pkg.name = 'My SubWorkflows'
        abstraction_pkg.version = '1.6'
        self._registry.add_package(abstraction_pkg)
Пример #5
0
class PackageManager(object):
    """Keeps track of the VisTrails packages known to the application.

    Packages are indexed by codepath (``_package_list``) and by identifier
    and version (``_package_versions``).  Dependencies between enabled
    packages are recorded in ``_dependency_graph``.  Only one instance may
    exist at a time: the constructor stores it in the module-level
    ``_package_manager`` and ``finalize_packages`` clears it.
    """
    # # add_package_menu_signal is emitted with a tuple containing the package
    # # identifier, package name and the menu item
    # add_package_menu_signal = QtCore.SIGNAL("add_package_menu")
    # # remove_package_menu_signal is emitted with the package identifier
    # remove_package_menu_signal = QtCore.SIGNAL("remove_package_menu")
    # # package_error_message_signal is emitted with the package identifier,
    # # package name and the error message
    # package_error_message_signal = QtCore.SIGNAL("package_error_message_signal")
    # # reloading_package_signal is emitted when a package reload has disabled
    # # the packages, but has not yet enabled them
    # reloading_package_signal = QtCore.SIGNAL("reloading_package_signal")

    class DependencyCycle(Exception):
        """Error describing a dependency cycle between two packages."""

        def __init__(self, p1, p2):
            self._package_1 = p1
            self._package_2 = p2

        def __str__(self):
            return ("Packages '%s' and '%s' have cyclic dependencies" %
                    (self._package_1, self._package_2))

    class MissingPackage(Exception):
        """Error raised when a requested package is not available."""

        def __init__(self, n):
            self._package_name = n

        def __str__(self):
            return "Package '%s' is missing." % self._package_name

    class PackageInternalError(Exception):
        """Error reporting a defect inside a package, e.g. a malformed
        'configuration' attribute."""

        def __init__(self, n, d):
            self._package_name = n
            self._description = d

        def __str__(self):
            return "Package '%s' has a bug: %s" % (self._package_name,
                                                   self._description)

    def import_packages_module(self):
        """Import the top-level ``packages`` module and cache it.

        If the configuration defines ``packageDirectory``, that directory is
        temporarily put first on ``sys.path`` so the import finds it there.
        ``sys.path`` is restored afterwards whether or not the import
        succeeds.  An ImportError is logged with debug.critical and
        re-raised.
        """
        if self._packages is not None:
            return self._packages
        # Imports standard packages directory
        conf = self._configuration
        old_sys_path = copy.copy(sys.path)
        if conf.check('packageDirectory'):
            sys.path.insert(0, conf.packageDirectory)
        try:
            import packages
        except ImportError:
            debug.critical('ImportError: "packages" sys.path: %s' % sys.path)
            raise
        finally:
            sys.path = old_sys_path
        self._packages = packages
        return packages

    def import_user_packages_module(self):
        """Import the ``userpackages`` module and cache it.

        If the configuration defines ``userPackageDirectory``, the parent of
        that directory is temporarily put first on ``sys.path``.  Presumably
        the directory itself is named ``userpackages``; confirm against the
        configuration defaults.  ``sys.path`` is restored afterwards.  An
        ImportError is logged with debug.critical and re-raised.
        """
        if self._userpackages is not None:
            return self._userpackages
        # Imports user packages directory
        conf = self._configuration
        old_sys_path = copy.copy(sys.path)
        if conf.check('userPackageDirectory'):
            sys.path.insert(
                0, os.path.join(conf.userPackageDirectory, os.path.pardir))
        try:
            import userpackages
        except ImportError:
            debug.critical('ImportError: "userpackages" sys.path: %s' %
                           sys.path)
            raise
        finally:
            sys.path = old_sys_path
        self._userpackages = userpackages
        return userpackages

    def __init__(self, configuration):
        """__init__(configuration: ConfigurationObject) -> PackageManager
        configuration is the persistent configuration object of the application.

        Raises VistrailsInternalError if a package manager already exists.
        The registry stays None until init_registry is called.
        """
        global _package_manager
        if _package_manager:
            m = "Package manager can only be constructed once."
            raise VistrailsInternalError(m)
        _package_manager = self
        self._configuration = configuration
        self._package_list = {}
        self._package_versions = {}
        self._dependency_graph = core.data_structures.graph.Graph()
        self._registry = None
        self._userpackages = None
        self._packages = None
        self._abstraction_pkg = None

    def init_registry(self, registry_filename=None):
        """Create the module registry and install it as the global registry.

        registry_filename -- when given, the registry is loaded from this
            file with core.db.io.open_registry and nothing else is done.
            Otherwise a fresh ModuleRegistry is created, the 'basic_modules'
            package is added, made the default package and initialized, and
            the 'abstraction' (subworkflow) package is created and added to
            the registry.  It is kept out of the package list until
            initialize_abstraction_pkg runs.
        """
        if registry_filename is not None:
            self._registry = core.db.io.open_registry(registry_filename)
            self._registry.set_global()
            return

        self._registry = ModuleRegistry()
        self._registry.set_global()

        # setup basic package
        basic_package = self.add_package('basic_modules')
        self._registry._default_package = basic_package
        self.initialize_packages({'basic_modules': 'core.modules.'})

        self._abstraction_pkg = self.add_package('abstraction', False)
        # FIXME need to get this info from the package, but cannot
        # do this since controller isn't imported yet
        self._abstraction_pkg.identifier = 'local.abstractions'
        self._abstraction_pkg.name = 'My SubWorkflows'
        self._abstraction_pkg.version = '1.6'
        self._registry.add_package(self._abstraction_pkg)

    def finalize_packages(self):
        """Finalize all installed packages and reset the manager.

        Call this only prior to exiting VisTrails.  It also clears the
        module-level package-manager singleton.
        """
        global _package_manager
        for package in self._package_list.values():
            package.finalize()
        self._package_list = {}
        self._package_versions = {}
        _package_manager = None

    def add_package(self, codepath, add_to_package_list=True):
        """Create a package for codepath without initializing it.

        To initialize it, call initialize_packages().  The package is put in
        the package list unless add_to_package_list is False.  Returns the
        package created by the registry.
        """
        package = self._registry.create_package(codepath)
        if add_to_package_list:
            self.add_to_package_list(codepath, package)
        return package

    def add_to_package_list(self, codepath, package):
        """Record package under codepath, replacing any previous entry."""
        self._package_list[codepath] = package

    def initialize_abstraction_pkg(self, prefix_dictionary):
        """Enable the subworkflow ('abstraction') package.

        The package is put in the package list, then enabled through
        late_enable_package with needs_add=False.  Raises Exception if
        _abstraction_pkg has not been set; only init_registry sets it, and
        only when no registry file is given.
        """
        if self._abstraction_pkg is None:
            raise Exception("Subworkflows packages is None")
        self.add_to_package_list(self._abstraction_pkg.codepath,
                                 self._abstraction_pkg)
        self.late_enable_package(self._abstraction_pkg.codepath,
                                 prefix_dictionary, False)

    def remove_package(self, codepath):
        """Disable and remove the package with the given codepath.

        Removes it from the dependency graph, the version index, its menu
        items and the package list.  It also finalizes the package, removes
        it from the registry and sends a "package_removed" notification.
        Raises KeyError if codepath is not in the package list.
        """
        pkg = self._package_list[codepath]
        self._dependency_graph.delete_vertex(pkg.identifier)
        del self._package_versions[pkg.identifier][pkg.version]
        if not self._package_versions[pkg.identifier]:
            del self._package_versions[pkg.identifier]
        self.remove_menu_items(pkg)
        pkg.finalize()
        del self._package_list[codepath]
        self._registry.remove_package(pkg)
        app = get_vistrails_application()
        app.send_notification("package_removed", codepath)

    def has_package(self, identifier, version=None):
        """has_package(identifier: string, version=None) -> Boolean.

        Returns True if some version of identifier is enabled.  If version
        is given, that exact version must be enabled.
        """
        if identifier in self._package_versions:
            return (version is None or
                    version in self._package_versions[identifier])
        return False

    def look_at_available_package(self, codepath):
        """look_at_available_package(codepath: string) -> Package

        Returns a Package object for an uninstalled package. This does
        NOT install a package.
        """
        return self._registry.create_package(codepath, False)

    def get_package(self, identifier, version=None):
        """Return the enabled package with the given identifier.

        If version is given, that exact version is returned.  Otherwise the
        version ranked highest by versions_increasing (starting from '0')
        is returned, or None if none ranks above '0'.  Raises KeyError if
        the identifier, or the requested version, is not enabled.
        """
        package_versions = self._package_versions[identifier]
        if version is not None:
            return package_versions[version]

        max_version = '0'
        max_pkg = None
        # Separate loop variable so the ``version`` argument is not shadowed.
        for candidate_version, pkg in package_versions.items():
            if versions_increasing(max_version, candidate_version):
                max_version = candidate_version
                max_pkg = pkg
        return max_pkg

    def get_package_by_codepath(self, codepath):
        """get_package_by_codepath(codepath: string) -> Package.

        Returns the package with the given codepath if it is in the package
        list.  Otherwise raises MissingPackage.
        """
        if codepath not in self._package_list:
            raise self.MissingPackage(codepath)
        return self._package_list[codepath]

    def get_package_by_identifier(self, identifier):
        """get_package_by_identifier(identifier: string) -> Package.

        Returns the registry's package with the given identifier.  Raises
        MissingPackage if the registry does not have it.
        """
        if identifier not in self._registry.packages:
            raise self.MissingPackage(identifier)
        return self._registry.packages[identifier]

    def get_package_configuration(self, codepath):
        """get_package_configuration(codepath: string) ->
        ConfigurationObject or None

        Returns the configuration object for the package, if existing,
        or None. Throws MissingPackage if package doesn't exist, and
        PackageInternalError if its 'configuration' attribute is not a
        ConfigurationObject.
        """
        pkg = self.get_package_by_codepath(codepath)

        if not hasattr(pkg.module, 'configuration'):
            return None
        c = pkg.module.configuration
        if not isinstance(c, ConfigurationObject):
            d = "'configuration' attribute should be a ConfigurationObject"
            raise self.PackageInternalError(codepath, d)
        return c

    def check_dependencies(self, package, deps):
        """Verify that every dependency in deps is enabled.

        Each entry of deps is an identifier string or a tuple
        (identifier[, min_version[, max_version]]).  When bounds are given,
        at least one enabled version v must satisfy
        versions_increasing(min_version, v) and
        versions_increasing(v, max_version) for each bound present.

        Returns True if all dependencies are satisfied.  Otherwise raises
        Package.MissingDependency with an (identifier, min_version,
        max_version) triple for each unsatisfied entry.
        """
        missing_deps = []
        for dep in deps:
            min_version = None
            max_version = None
            if isinstance(dep, tuple):
                identifier = dep[0]
                if len(dep) > 1:
                    min_version = dep[1]
                if len(dep) > 2:
                    max_version = dep[2]
            else:
                identifier = dep

            if identifier not in self._package_versions:
                missing_deps.append((identifier, None, None))
                continue
            if min_version is None and max_version is None:
                continue
            # Stop at the first enabled version inside the bounds.
            found_version = any(
                (min_version is None or
                 versions_increasing(min_version, version)) and
                (max_version is None or
                 versions_increasing(version, max_version))
                for version in self._package_versions[identifier])
            if not found_version:
                missing_deps.append((identifier, min_version, max_version))

        if missing_deps:
            raise Package.MissingDependency(package, missing_deps)
        return True

    def add_dependencies(self, package):
        """add_dependencies(package) -> None.  Register all
        dependencies a package contains by calling the appropriate
        callback.

        Every package except the basic one also depends on the basic
        package.  Dependencies are checked first with check_dependencies.
        Does not add multiple dependencies - if a dependency is already there,
        add_dependencies ignores it.
        """
        # Copy so the implicit basic-package dependency is not appended to a
        # list the package itself may hold on to.
        deps = list(package.dependencies())
        # FIXME don't hardcode this
        from core.modules.basic_modules import identifier as basic_pkg
        if package.identifier != basic_pkg:
            deps.append(basic_pkg)

        self.check_dependencies(package, deps)

        for dep in deps:
            dep_name = dep[0] if isinstance(dep, tuple) else dep
            if not self._dependency_graph.has_edge(package.identifier,
                                                   dep_name):
                self._dependency_graph.add_edge(package.identifier, dep_name)

    def late_enable_package(self,
                            package_codepath,
                            prefix_dictionary=None,
                            needs_add=True):
        """late_enable_package enables a package 'late', that is,
        after VisTrails initialization. All dependencies need to be
        already enabled.

        package_codepath -- codepath of the package to enable.
        prefix_dictionary -- optional mapping from codepath to the prefix
            passed to the package's load(); defaults to an empty mapping.
        needs_add -- when True the package is first added to the package
            list (VistrailsInternalError if the codepath is already there).
            When False it must already be in the list.

        If loading or initialization fails, the bookkeeping done here is
        undone and the exception is re-raised.
        """
        if prefix_dictionary is None:
            prefix_dictionary = {}
        if needs_add:
            if package_codepath in self._package_list:
                msg = 'duplicate package identifier: %s' % package_codepath
                raise VistrailsInternalError(msg)
            self.add_package(package_codepath)
        pkg = self.get_package_by_codepath(package_codepath)
        try:
            pkg.load(prefix_dictionary.get(pkg.codepath, None))
        except Exception:
            # invert self.add_package
            del self._package_list[package_codepath]
            raise
        self._dependency_graph.add_vertex(pkg.identifier)
        if pkg.identifier not in self._package_versions:
            self._package_versions[pkg.identifier] = {}
        self._package_versions[pkg.identifier][pkg.version] = pkg
        try:
            self.add_dependencies(pkg)
            # check_requirements is now called in pkg.initialize()
            self._registry.initialize_package(pkg)
            self._registry.signals.emit_new_package(pkg.identifier, True)
            app = get_vistrails_application()
            app.send_notification("package_added", package_codepath)
            self.add_menu_items(pkg)
        except Exception as e:
            del self._package_versions[pkg.identifier][pkg.version]
            if not self._package_versions[pkg.identifier]:
                del self._package_versions[pkg.identifier]
            self._dependency_graph.delete_vertex(pkg.identifier)
            # invert self.add_package
            del self._package_list[package_codepath]
            # If the package was added to the registry, make sure it is
            # removed when initialization fails.
            try:
                self._registry.remove_package(pkg)
            except MissingPackage:
                # NOTE(review): this is the module-level MissingPackage,
                # not PackageManager.MissingPackage; confirm it is imported
                # here (presumably the registry's exception).
                pass
            # Re-raise by name: on Python 2 the handler just above can leave
            # sys.exc_info() pointing at MissingPackage, so a bare ``raise``
            # could re-raise the wrong exception.
            raise e
class PackageManager(object):
    """Keeps track of the VisTrails packages known to the application.

    Packages are indexed by codepath (``_package_list``) and by identifier
    and version (``_package_versions``).  Dependencies between enabled
    packages are recorded in ``_dependency_graph``.  Only one instance may
    exist at a time: the constructor stores it in the module-level
    ``_package_manager`` and ``finalize_packages`` clears it.
    """
    # # add_package_menu_signal is emitted with a tuple containing the package
    # # identifier, package name and the menu item
    # add_package_menu_signal = QtCore.SIGNAL("add_package_menu")
    # # remove_package_menu_signal is emitted with the package identifier
    # remove_package_menu_signal = QtCore.SIGNAL("remove_package_menu")
    # # package_error_message_signal is emitted with the package identifier,
    # # package name and the error message
    # package_error_message_signal = QtCore.SIGNAL("package_error_message_signal")
    # # reloading_package_signal is emitted when a package reload has disabled
    # # the packages, but has not yet enabled them
    # reloading_package_signal = QtCore.SIGNAL("reloading_package_signal")

    class DependencyCycle(Exception):
        """Error describing a dependency cycle between two packages."""

        def __init__(self, p1, p2):
            self._package_1 = p1
            self._package_2 = p2

        def __str__(self):
            return ("Packages '%s' and '%s' have cyclic dependencies" %
                    (self._package_1, self._package_2))

    class MissingPackage(Exception):
        """Error raised when a requested package is not available."""

        def __init__(self, n):
            self._package_name = n

        def __str__(self):
            return "Package '%s' is missing." % self._package_name

    class PackageInternalError(Exception):
        """Error reporting a defect inside a package, e.g. a malformed
        'configuration' attribute."""

        def __init__(self, n, d):
            self._package_name = n
            self._description = d

        def __str__(self):
            return "Package '%s' has a bug: %s" % (self._package_name,
                                                   self._description)

    def import_packages_module(self):
        """Import the top-level ``packages`` module and cache it.

        If the configuration defines ``packageDirectory``, that directory is
        temporarily put first on ``sys.path`` so the import finds it there.
        ``sys.path`` is restored afterwards whether or not the import
        succeeds.  An ImportError is logged with debug.critical and
        re-raised.
        """
        if self._packages is not None:
            return self._packages
        # Imports standard packages directory
        conf = self._configuration
        old_sys_path = copy.copy(sys.path)
        if conf.check('packageDirectory'):
            sys.path.insert(0, conf.packageDirectory)
        try:
            import packages
        except ImportError:
            debug.critical('ImportError: "packages" sys.path: %s' % sys.path)
            raise
        finally:
            sys.path = old_sys_path
        self._packages = packages
        return packages

    def import_user_packages_module(self):
        """Import the ``userpackages`` module and cache it.

        If the configuration defines ``userPackageDirectory``, the parent of
        that directory is temporarily put first on ``sys.path``.  Presumably
        the directory itself is named ``userpackages``; confirm against the
        configuration defaults.  ``sys.path`` is restored afterwards.  An
        ImportError is logged with debug.critical and re-raised.
        """
        if self._userpackages is not None:
            return self._userpackages
        # Imports user packages directory
        conf = self._configuration
        old_sys_path = copy.copy(sys.path)
        if conf.check('userPackageDirectory'):
            sys.path.insert(0, os.path.join(conf.userPackageDirectory,
                                            os.path.pardir))
        try:
            import userpackages
        except ImportError:
            debug.critical('ImportError: "userpackages" sys.path: %s' %
                           sys.path)
            raise
        finally:
            sys.path = old_sys_path
        self._userpackages = userpackages
        return userpackages

    def __init__(self, configuration):
        """__init__(configuration: ConfigurationObject) -> PackageManager
        configuration is the persistent configuration object of the application.

        Raises VistrailsInternalError if a package manager already exists.
        The registry stays None until init_registry is called.
        """
        global _package_manager
        if _package_manager:
            m = "Package manager can only be constructed once."
            raise VistrailsInternalError(m)
        _package_manager = self
        self._configuration = configuration
        self._package_list = {}
        self._package_versions = {}
        self._dependency_graph = core.data_structures.graph.Graph()
        self._registry = None
        self._userpackages = None
        self._packages = None
        self._abstraction_pkg = None

    def init_registry(self, registry_filename=None):
        """Create the module registry and install it as the global registry.

        registry_filename -- when given, the registry is loaded from this
            file with core.db.io.open_registry and nothing else is done.
            Otherwise a fresh ModuleRegistry is created, the 'basic_modules'
            package is added, made the default package and initialized, and
            the 'abstraction' (subworkflow) package is created and added to
            the registry.  It is kept out of the package list until
            initialize_abstraction_pkg runs.
        """
        if registry_filename is not None:
            self._registry = core.db.io.open_registry(registry_filename)
            self._registry.set_global()
            return

        self._registry = ModuleRegistry()
        self._registry.set_global()

        # setup basic package
        basic_package = self.add_package('basic_modules')
        self._registry._default_package = basic_package
        self.initialize_packages({'basic_modules': 'core.modules.'})

        self._abstraction_pkg = self.add_package('abstraction', False)
        # FIXME need to get this info from the package, but cannot
        # do this since controller isn't imported yet
        self._abstraction_pkg.identifier = 'local.abstractions'
        self._abstraction_pkg.name = 'My SubWorkflows'
        self._abstraction_pkg.version = '1.6'
        self._registry.add_package(self._abstraction_pkg)

    def finalize_packages(self):
        """Finalize all installed packages and reset the manager.

        Call this only prior to exiting VisTrails.  It also clears the
        module-level package-manager singleton.
        """
        global _package_manager
        for package in self._package_list.values():
            package.finalize()
        self._package_list = {}
        self._package_versions = {}
        _package_manager = None

    def add_package(self, codepath, add_to_package_list=True):
        """Create a package for codepath without initializing it.

        To initialize it, call initialize_packages().  The package is put in
        the package list unless add_to_package_list is False.  A
        "package_added" notification is sent in either case.  Returns the
        package created by the registry.
        """
        package = self._registry.create_package(codepath)
        if add_to_package_list:
            self.add_to_package_list(codepath, package)
        # NOTE(review): this notifies before the package is loaded or
        # initialized, and also for packages kept out of the package list
        # (e.g. 'abstraction' in init_registry).  If late_enable_package
        # later fails, no matching "package_removed" is sent.
        app = get_vistrails_application()
        app.send_notification("package_added", codepath)
        return package

    def add_to_package_list(self, codepath, package):
        """Record package under codepath, replacing any previous entry."""
        self._package_list[codepath] = package

    def initialize_abstraction_pkg(self, prefix_dictionary):
        """Enable the subworkflow ('abstraction') package.

        The package is put in the package list, then enabled through
        late_enable_package with needs_add=False.  Raises Exception if
        _abstraction_pkg has not been set; only init_registry sets it, and
        only when no registry file is given.
        """
        if self._abstraction_pkg is None:
            raise Exception("Subworkflows packages is None")
        self.add_to_package_list(self._abstraction_pkg.codepath,
                                 self._abstraction_pkg)
        self.late_enable_package(self._abstraction_pkg.codepath,
                                 prefix_dictionary, False)

    def remove_package(self, codepath):
        """Disable and remove the package with the given codepath.

        Removes it from the dependency graph, the version index, its menu
        items and the package list.  It also finalizes the package, removes
        it from the registry and sends a "package_removed" notification.
        Raises KeyError if codepath is not in the package list.
        """
        pkg = self._package_list[codepath]
        self._dependency_graph.delete_vertex(pkg.identifier)
        del self._package_versions[pkg.identifier][pkg.version]
        if not self._package_versions[pkg.identifier]:
            del self._package_versions[pkg.identifier]
        self.remove_menu_items(pkg)
        pkg.finalize()
        del self._package_list[codepath]
        self._registry.remove_package(pkg)
        app = get_vistrails_application()
        app.send_notification("package_removed", codepath)

    def has_package(self, identifier, version=None):
        """has_package(identifier: string, version=None) -> Boolean.

        Returns True if some version of identifier is enabled.  If version
        is given, that exact version must be enabled.
        """
        if identifier in self._package_versions:
            return (version is None or
                    version in self._package_versions[identifier])
        return False

    def look_at_available_package(self, codepath):
        """look_at_available_package(codepath: string) -> Package

        Returns a Package object for an uninstalled package. This does
        NOT install a package.
        """
        return self._registry.create_package(codepath, False)

    def get_package(self, identifier, version=None):
        """Return the enabled package with the given identifier.

        If version is given, that exact version is returned.  Otherwise the
        version ranked highest by versions_increasing (starting from '0')
        is returned, or None if none ranks above '0'.  Raises KeyError if
        the identifier, or the requested version, is not enabled.
        """
        package_versions = self._package_versions[identifier]
        if version is not None:
            return package_versions[version]

        max_version = '0'
        max_pkg = None
        # Separate loop variable so the ``version`` argument is not shadowed.
        for candidate_version, pkg in package_versions.items():
            if versions_increasing(max_version, candidate_version):
                max_version = candidate_version
                max_pkg = pkg
        return max_pkg

    def get_package_by_codepath(self, codepath):
        """get_package_by_codepath(codepath: string) -> Package.

        Returns the package with the given codepath if it is in the package
        list.  Otherwise raises MissingPackage.
        """
        if codepath not in self._package_list:
            raise self.MissingPackage(codepath)
        return self._package_list[codepath]

    def get_package_by_identifier(self, identifier):
        """get_package_by_identifier(identifier: string) -> Package.

        Returns the registry's package with the given identifier.  Raises
        MissingPackage if the registry does not have it.
        """
        if identifier not in self._registry.packages:
            raise self.MissingPackage(identifier)
        return self._registry.packages[identifier]

    def get_package_configuration(self, codepath):
        """get_package_configuration(codepath: string) ->
        ConfigurationObject or None

        Returns the configuration object for the package, if existing,
        or None. Throws MissingPackage if package doesn't exist, and
        PackageInternalError if its 'configuration' attribute is not a
        ConfigurationObject.
        """
        pkg = self.get_package_by_codepath(codepath)

        if not hasattr(pkg.module, 'configuration'):
            return None
        c = pkg.module.configuration
        if not isinstance(c, ConfigurationObject):
            d = "'configuration' attribute should be a ConfigurationObject"
            raise self.PackageInternalError(codepath, d)
        return c

    def check_dependencies(self, package, deps):
        """Verify that every dependency in deps is enabled.

        Each entry of deps is an identifier string or a tuple
        (identifier[, min_version[, max_version]]).  When bounds are given,
        at least one enabled version v must satisfy
        versions_increasing(min_version, v) and
        versions_increasing(v, max_version) for each bound present.

        Returns True if all dependencies are satisfied.  Otherwise raises
        Package.MissingDependency with an (identifier, min_version,
        max_version) triple for each unsatisfied entry.
        """
        missing_deps = []
        for dep in deps:
            min_version = None
            max_version = None
            if isinstance(dep, tuple):
                identifier = dep[0]
                if len(dep) > 1:
                    min_version = dep[1]
                if len(dep) > 2:
                    max_version = dep[2]
            else:
                identifier = dep

            if identifier not in self._package_versions:
                missing_deps.append((identifier, None, None))
                continue
            if min_version is None and max_version is None:
                continue
            # Stop at the first enabled version inside the bounds.
            found_version = any(
                (min_version is None or
                 versions_increasing(min_version, version)) and
                (max_version is None or
                 versions_increasing(version, max_version))
                for version in self._package_versions[identifier])
            if not found_version:
                missing_deps.append((identifier, min_version, max_version))

        if missing_deps:
            raise Package.MissingDependency(package, missing_deps)
        return True

    def add_dependencies(self, package):
        """add_dependencies(package) -> None.  Register all
        dependencies a package contains by calling the appropriate
        callback.

        Every package except the basic one also depends on the basic
        package.  Dependencies are checked first with check_dependencies.
        Does not add multiple dependencies - if a dependency is already there,
        add_dependencies ignores it.
        """
        # Copy so the implicit basic-package dependency is not appended to a
        # list the package itself may hold on to.
        deps = list(package.dependencies())
        # FIXME don't hardcode this
        from core.modules.basic_modules import identifier as basic_pkg
        if package.identifier != basic_pkg:
            deps.append(basic_pkg)

        self.check_dependencies(package, deps)

        for dep in deps:
            dep_name = dep[0] if isinstance(dep, tuple) else dep
            if not self._dependency_graph.has_edge(package.identifier,
                                                   dep_name):
                self._dependency_graph.add_edge(package.identifier, dep_name)

    def late_enable_package(self, package_codepath, prefix_dictionary=None,
                            needs_add=True):
        """late_enable_package enables a package 'late', that is,
        after VisTrails initialization. All dependencies need to be
        already enabled.

        package_codepath -- codepath of the package to enable.
        prefix_dictionary -- optional mapping from codepath to the prefix
            passed to the package's load(); defaults to an empty mapping.
        needs_add -- when True the package is first added to the package
            list (VistrailsInternalError if the codepath is already there).
            When False it must already be in the list.

        If loading or initialization fails, the bookkeeping done here is
        undone and the exception is re-raised.
        """
        if prefix_dictionary is None:
            prefix_dictionary = {}
        if needs_add:
            if package_codepath in self._package_list:
                msg = 'duplicate package identifier: %s' % package_codepath
                raise VistrailsInternalError(msg)
            self.add_package(package_codepath)
        pkg = self.get_package_by_codepath(package_codepath)
        try:
            pkg.load(prefix_dictionary.get(pkg.codepath, None))
        except Exception:
            # invert self.add_package
            del self._package_list[package_codepath]
            raise
        self._dependency_graph.add_vertex(pkg.identifier)
        if pkg.identifier not in self._package_versions:
            self._package_versions[pkg.identifier] = {}
        self._package_versions[pkg.identifier][pkg.version] = pkg
        try:
            self.add_dependencies(pkg)
            # check_requirements is now called in pkg.initialize()
            self._registry.initialize_package(pkg)
            # FIXME Empty packages still need to be added, but currently they
            # are not because newPackage is typically only called for the
            # first module inside a package.
            from core.modules.abstraction import \
                identifier as abstraction_identifier
            if pkg.identifier == abstraction_identifier:
                self._registry.signals.emit_new_package(
                    abstraction_identifier, True)
        except Exception as e:
            del self._package_versions[pkg.identifier][pkg.version]
            if not self._package_versions[pkg.identifier]:
                del self._package_versions[pkg.identifier]
            self._dependency_graph.delete_vertex(pkg.identifier)
            # invert self.add_package
            del self._package_list[package_codepath]
            # If the package was added to the registry, make sure it is
            # removed when initialization fails.
            try:
                self._registry.remove_package(pkg)
            except MissingPackage:
                # NOTE(review): this is the module-level MissingPackage,
                # not PackageManager.MissingPackage; confirm it is imported
                # here (presumably the registry's exception).
                pass
            # Re-raise by name: on Python 2 the handler just above can leave
            # sys.exc_info() pointing at MissingPackage, so a bare ``raise``
            # could re-raise the wrong exception.
            raise e
Пример #7
0
class Module(DBModule):
    """ Represents a module from a Pipeline """

    ##########################################################################
    # Constructor and copy

    def __init__(self, *args, **kwargs):
        DBModule.__init__(self, *args, **kwargs)
        if self.cache is None:
            self.cache = 1
        if self.id is None:
            self.id = -1
        if self.location is None:
            self.location = Location(x=-1.0, y=-1.0)
        if self.name is None:
            self.name = ''
        if self.package is None:
            self.package = ''
        if self.version is None:
            self.version = ''
        self.portVisible = set()
        self.registry = None

    def __copy__(self):
        """__copy__() -> Module - Returns a clone of itself"""
        return Module.do_copy(self)

    def do_copy(self, new_ids=False, id_scope=None, id_remap=None):
        cp = DBModule.do_copy(self, new_ids, id_scope, id_remap)
        cp.__class__ = Module
        # cp.registry = copy.copy(self.registry)
        cp.registry = None
        for port_spec in cp.db_portSpecs:
            cp.add_port_to_registry(port_spec)
        cp.portVisible = copy.copy(self.portVisible)
        return cp

    @staticmethod
    def convert(_module):
	_module.__class__ = Module
	_module.registry = None
        for _port_spec in _module.db_portSpecs:
            PortSpec.convert(_port_spec)
            _module.add_port_to_registry(_port_spec)
        if _module.db_location:
            Location.convert(_module.db_location)
	for _function in _module.db_functions:
	    ModuleFunction.convert(_function)
        for _annotation in _module.db_get_annotations():
            Annotation.convert(_annotation)

        _module.portVisible = set()

    ##########################################################################

    id = DBModule.db_id
    cache = DBModule.db_cache
    annotations = DBModule.db_annotations
    location = DBModule.db_location
    center = DBModule.db_location
    name = DBModule.db_name
    label = DBModule.db_name
    namespace = DBModule.db_namespace
    package = DBModule.db_package
    tag = DBModule.db_tag
    version = DBModule.db_version

    # type check this (list, hash)
    def _get_functions(self):
        self.db_functions.sort(key=lambda x: x.db_pos)
        return self.db_functions
    def _set_functions(self, functions):
	# want to convert functions to hash...?
        self.db_functions = functions
    functions = property(_get_functions, _set_functions)
    def add_function(self, function):
        self.db_add_function(function)

    def add_annotation(self, annotation):
        self.db_add_annotation(annotation)
    def delete_annotation(self, annotation):
        self.db_delete_annotation(annotation)
    def has_annotation_with_key(self, key):
        return self.db_has_annotation_with_key(key)
    def get_annotation_by_key(self, key):
        return self.db_get_annotation_by_key(key)        

    def _get_port_specs(self):
        return self.db_portSpecs_id_index
    port_specs = property(_get_port_specs)
    def has_portSpec_with_name(self, name):
        return self.db_has_portSpec_with_name(name)
    def get_portSpec_by_name(self, name):
        return self.db_get_portSpec_by_name(name)

    def summon(self):
        get = registry.get_descriptor_by_name
        result = get(self.package, self.name, self.namespace).module()
        if self.cache != 1:
            result.is_cacheable = lambda *args: False
        if hasattr(result, 'srcPortsOrder'):
            result.srcPortsOrder = [p.name for p in self.destinationPorts()]
        result.registry = self.registry or registry
        return result

    def getNumFunctions(self):
        """getNumFunctions() -> int - Returns the number of functions """
        return len(self.functions)


    def sourcePorts(self):
        """sourcePorts() -> list of Port 
        Returns list of source (output) ports module supports.

        """

        ports = registry.module_source_ports(True, self.package, self.name, self.namespace)
        if self.registry:
            ports.extend(self.registry.module_source_ports(False, self.package, self.name, self.namespace))
        return ports

    def destinationPorts(self):
        """destinationPorts() -> list of Port 
        Returns list of destination (input) ports module supports

        """
        ports = registry.module_destination_ports(True, self.package, self.name, self.namespace)
        if self.registry:
            ports.extend(self.registry.module_destination_ports(False, self.package, self.name, self.namespace))
        return ports

    def add_port_to_registry(self, port_spec):
        module = \
            registry.get_descriptor_by_name(self.package, self.name, self.namespace).module
        if self.registry is None:
            self.registry = ModuleRegistry()
            self.registry.add_hierarchy(registry, self)

        if port_spec.type == 'input':
            endpoint = PortEndPoint.Destination
        else:
            endpoint = PortEndPoint.Source
        portSpecs = port_spec.spec[1:-1].split(',')
        signature = [registry.get_descriptor_from_name_only(spec).module
                     for spec in portSpecs]
        port = Port()
        port.name = port_spec.name
        port.spec = core.modules.module_registry.PortSpec(signature)
        self.registry.add_port(module, endpoint, port)        

    def delete_port_from_registry(self, id):
        if not id in self.port_specs:
            raise VistrailsInternalError("id missing in port_specs")
        portSpec = self.port_specs[id]
        portSpecs = portSpec.spec[1:-1].split(',')
        signature = [registry.get_descriptor_from_name_only(spec).module
                     for spec in portSpecs]
        port = Port(signature)
        port.name = portSpec.name
        port.spec = core.modules.module_registry.PortSpec(signature)

        module = \
            registry.get_descriptor_by_name(self.package, self.name, self.namespace).module
        assert isinstance(self.registry, ModuleRegistry)

        if portSpec.type == 'input':
            self.registry.delete_input_port(module, port.name)
        else:
            self.registry.delete_output_port(module, port.name)

    ##########################################################################
    # Debugging

    def show_comparison(self, other):
        if type(other) != type(self):
            print "Type mismatch"
            print type(self), type(other)
        elif self.id != other.id:
            print "id mismatch"
            print self.id, other.id
        elif self.name != other.name:
            print "name mismatch"
            print self.name, other.name
        elif self.cache != other.cache:
            print "cache mismatch"
            print self.cache, other.cache
        elif self.location != other.location:
            print "location mismatch"
            # FIXME Location has no show_comparison
            # self.location.show_comparison(other.location)
        elif len(self.functions) != len(other.functions):
            print "function length mismatch"
            print len(self.functions), len(other.functions)
        else:
            for f, g in izip(self.functions, other.functions):
                if f != g:
                    print "function mismatch"
                    f.show_comparison(g)
                    return
            print "No difference found"
            assert self == other

    ##########################################################################
    # Operators

    def __str__(self):
        """__str__() -> str Returns a string representation of itself. """
        return ("(Module '%s' id=%s functions:%s port_specs:%s)@%X" %
                (self.name,
                 self.id,
                 [str(f) for f in self.functions],
                 [str(port_spec) for port_spec in self.db_portSpecs],
                 id(self)))

    def __eq__(self, other):
        """ __eq__(other: Module) -> boolean
        Returns True if self and other have the same attributes. Used by == 
        operator. 
        
        """
        if type(other) != type(self):
            return False
        if self.name != other.name:
            return False
        if self.cache != other.cache:
            return False
        if self.location != other.location:
            return False
        if len(self.functions) != len(other.functions):
            return False
        if len(self.annotations) != len(other.annotations):
            return False
        for f, g in izip(self.functions, other.functions):
            if f != g:
                return False
        for f, g in izip(self.annotations, other.annotations):
            if f != g:
                return False
        return True

    def __ne__(self, other):
        return not self.__eq__(other)
Пример #8
0
def open_registry(filename):
    """open_registry(filename: str) -> registry
    Load a module registry from the XML file at `filename`, pass it through
    ModuleRegistry.convert, and return that same object.

    """
    # Imported at call time rather than module level -- presumably to avoid
    # an import cycle with core.modules (TODO confirm).
    from core.modules.module_registry import ModuleRegistry
    loaded = db.services.io.open_registry_from_xml(filename)
    ModuleRegistry.convert(loaded)
    return loaded
Пример #9
0
class Group(DBGroup, Module):
    """A pipeline module that wraps a nested workflow (its `pipeline`).

    A Group declares no port specs of its own (port_specs is always empty).
    Its ports are derived by make_registry from the InputPort/OutputPort
    modules inside the nested pipeline, and are stored in a private
    ModuleRegistry that is built on first access of the `registry` property.

    """

    ##########################################################################
    # Constructors and copy

    def __init__(self, *args, **kwargs):
        """__init__(*args, **kwargs) -> None
        Accept `pipeline=` as an alias for DBGroup's `workflow=` keyword,
        then fill in defaults for attributes DBGroup left as None.

        """
        if 'pipeline' in kwargs:
            kwargs['workflow'] = kwargs['pipeline']
            del kwargs['pipeline']
        DBGroup.__init__(self, *args, **kwargs)
        if self.cache is None:
            self.cache = 1
        if self.id is None:
            self.id = -1
        if self.location is None:
            self.location = Location(x=-1.0, y=-1.0)
        if self.name is None:
            self.name = ''
        if self.package is None:
            self.package = ''
        if self.version is None:
            self.version = ''
        self.portVisible = set()
        self._registry = None

    def __copy__(self):
        return Group.do_copy(self)

    def do_copy(self, new_ids=False, id_scope=None, id_remap=None):
        """do_copy(new_ids, id_scope, id_remap) -> Group
        Copy the DB state via DBGroup.do_copy; the copy's registry is left
        unset and is rebuilt lazily from its pipeline by _get_registry.

        """
        cp = DBGroup.do_copy(self, new_ids, id_scope, id_remap)
        cp.__class__ = Group
        cp._registry = None
        cp.portVisible = copy.copy(self.portVisible)
        return cp

    @staticmethod
    def convert(_group):
        """convert(_group) -> None
        Re-class a DBGroup (and its location, nested workflow, functions and
        annotations) in place as core objects. No-op if already a Group.

        """
        if _group.__class__ == Group:
            return
        _group.__class__ = Group
        _group._registry = None
        _group.portVisible = set()
        if _group.db_location:
            Location.convert(_group.db_location)
        if _group.db_workflow:
            from core.vistrail.pipeline import Pipeline
            Pipeline.convert(_group.db_workflow)
        for _function in _group.db_functions:
            ModuleFunction.convert(_function)
        for _annotation in _group.db_get_annotations():
            Annotation.convert(_annotation)


    ##########################################################################
    # Properties

    # We need to repeat these here because Module uses DBModule. ...
    id = DBGroup.db_id
    cache = DBGroup.db_cache
    annotations = DBGroup.db_annotations
    location = DBGroup.db_location
    center = DBGroup.db_location
    name = DBGroup.db_name
    label = DBGroup.db_name
    namespace = DBGroup.db_namespace
    package = DBGroup.db_package
    tag = DBGroup.db_tag
    version = DBGroup.db_version

    def summon(self):
        # define this so that pipeline is copied over...
        pass

    def is_group(self):
        return True

    pipeline = DBGroup.db_workflow

    def _get_registry(self):
        """Return the private registry, building it on first access."""
        # _registry is either None or the ModuleRegistry built by
        # make_registry; test identity so an empty registry is not rebuilt.
        if self._registry is None:
            self.make_registry()
        return self._registry
    registry = property(_get_registry)

    # override these from the Module class with defaults
    def _get_port_specs(self):
        return dict()
    port_specs = property(_get_port_specs)
    def has_portSpec_with_name(self, name):
        return False
    def get_portSpec_by_name(self, name):
        return None

    @staticmethod
    def make_port_from_module(module, port_type):
        """make_port_from_module(module, port_type: str) -> Port
        Build a Port from an InputPort/OutputPort module in the pipeline.

        The port name and spec are read from the first parameter of the
        module's 'name' and 'spec' functions. The spec has its first and
        last characters stripped, is split on ',', and each entry is split on
        ':' into at most three parts passed to
        registry.get_descriptor_by_name (presumably package:name[:namespace]
        -- TODO confirm).

        NOTE(review): if the module has no 'name' or 'spec' function this
        raises UnboundLocalError.

        """
        for function in module.functions:
            if function.name == 'name':
                port_name = function.params[0].strValue
            if function.name == 'spec':
                port_spec = function.params[0].strValue
        port = Port(id=-1,
                    name=port_name,
                    type=port_type)
        portSpecs = port_spec[1:-1].split(',')
        signature = []
        for s in portSpecs:
            spec = s.split(':', 2)
            signature.append(registry.get_descriptor_by_name(*spec).module)
        port.spec = core.modules.module_registry.PortSpec(signature)
        return port

    def make_registry(self):
        """make_registry() -> None
        Build self._registry: a new ModuleRegistry seeded with the global
        hierarchy, with one source port per OutputPort module and one
        destination port per InputPort module in the nested pipeline, all
        attached to the basic-package descriptor named self.name.

        """
        reg_module = \
            registry.get_descriptor_by_name('edu.utah.sci.vistrails.basic',
                                            self.name).module
        self._registry = ModuleRegistry()
        self._registry.add_hierarchy(registry, self)
        for module in self.pipeline.module_list:
            if module.name == 'OutputPort':
                port = self.make_port_from_module(module, 'source')
                self._registry.add_port(reg_module, PortEndPoint.Source, port)
            elif module.name == 'InputPort':
                port = self.make_port_from_module(module, 'destination')
                self._registry.add_port(reg_module, PortEndPoint.Destination,
                                        port)

    def sourcePorts(self):
        """sourcePorts() -> list of Port - ports from the private registry."""
        return self.registry.module_source_ports(False, self.package,
                                                 self.name, self.namespace)

    def destinationPorts(self):
        """destinationPorts() -> list of Port - ports from the private
        registry."""
        return self.registry.module_destination_ports(False, self.package,
                                                      self.name, self.namespace)

    ##########################################################################
    # Operators

    def __str__(self):
        """__str__() -> str - Returns a string representation of a
        Group object.

        """
        # Format the header on its own: the location/function/annotation
        # text appended afterwards may contain '%' characters, which must
        # never reach the % operator.
        # Group declares no abstraction_id property (unlike
        # AbstractionModule), so fall back to None instead of raising.
        header = '<group id="%s" abstraction_id="%s" version="%s">' % \
            (str(self.id), str(getattr(self, 'abstraction_id', None)),
             str(self.version))
        return ''.join([header,
                        str(self.location),
                        str(self.functions),
                        str(self.annotations),
                        '</group>'])

    def __eq__(self, other):
        """ __eq__(other: Group) -> boolean
        Returns True if self and other have the same attributes. Used by ==
        operator. Compares type, location, functions and annotations only.

        """
        if type(other) != type(self):
            return False
        if self.location != other.location:
            return False
        if len(self.functions) != len(other.functions):
            return False
        if len(self.annotations) != len(other.annotations):
            return False
        for f,g in izip(self.functions, other.functions):
            if f != g:
                return False
        for f,g in izip(self.annotations, other.annotations):
            if f != g:
                return False
        return True

    def __ne__(self, other):
        return not self.__eq__(other)
class AbstractionModule(DBAbstractionRef):
    """A pipeline module that refers to a subworkflow (abstraction).

    The referenced workflow is materialized on demand from self.abstraction
    at self.version (see the `pipeline` property); its InputPort/OutputPort
    modules define this module's ports.

    """

    ##########################################################################
    # Constructors and copy

    def __init__(self, *args, **kwargs):
        DBAbstractionRef.__init__(self, *args, **kwargs)
        if self.id is None:
            self.id = -1
        self.portVisible = set()
        self._registry = None
        self.abstraction = None
        # FIXME should we have a registry for an abstraction module?

    def __copy__(self):
        return AbstractionModule.do_copy(self)

    def do_copy(self, new_ids=False, id_scope=None, id_remap=None):
        """do_copy(new_ids, id_scope, id_remap) -> AbstractionModule
        Copy the DB state via DBAbstractionRef.do_copy. The registry and the
        abstraction are shared with the original by reference, not copied.

        """
        cp = DBAbstractionRef.do_copy(self, new_ids, id_scope, id_remap)
        cp.__class__ = AbstractionModule
        cp.portVisible = copy.copy(self.portVisible)
        cp._registry = self._registry
        cp.abstraction = self.abstraction
        return cp

    @staticmethod
    def convert(_abstraction_module):
        """convert(_abstraction_module) -> None
        Re-class a DBAbstractionRef (and its location, functions and
        annotations) in place as core objects. No-op if already converted.

        """
        if _abstraction_module.__class__ == AbstractionModule:
            return
        _abstraction_module.__class__ = AbstractionModule
        if _abstraction_module.db_location:
            Location.convert(_abstraction_module.db_location)
        for _function in _abstraction_module.db_functions:
            ModuleFunction.convert(_function)
        for _annotation in _abstraction_module.db_get_annotations():
            Annotation.convert(_annotation)
        _abstraction_module.portVisible = set()
        _abstraction_module._registry = None
        _abstraction_module.abstraction = None


    ##########################################################################
    # Properties

    id = DBAbstractionRef.db_id
    cache = DBAbstractionRef.db_cache
    abstraction_id = DBAbstractionRef.db_abstraction_id
    location = DBAbstractionRef.db_location
    center = DBAbstractionRef.db_location
    version = DBAbstractionRef.db_version
    tag = DBAbstractionRef.db_name
    label = DBAbstractionRef.db_name
    name = 'Abstraction'
    package = 'edu.utah.sci.vistrails.basic'
    namespace = None
    annotations = DBAbstractionRef.db_annotations

    def _get_functions(self):
        """Return db_functions, sorted in place by their db_pos."""
        self.db_functions.sort(key=lambda x: x.db_pos)
        return self.db_functions
    def _set_functions(self, functions):
        # want to convert functions to hash...?
        self.db_functions = functions
    functions = property(_get_functions, _set_functions)

    def _get_pipeline(self):
        """Materialize self.abstraction at self.version as a Pipeline.

        NOTE: nothing is cached -- every access of `pipeline` materializes
        and converts the workflow again.

        """
        from core.vistrail.pipeline import Pipeline
        import db.services.vistrail
        workflow = db.services.vistrail.materializeWorkflow(self.abstraction,
                                                            self.version)
        Pipeline.convert(workflow)
        return workflow
    pipeline = property(_get_pipeline)

    def _get_registry(self):
        """Return the private registry, building it on first access."""
        # _registry is either None or the ModuleRegistry built by
        # make_registry (or shared by do_copy); test identity so an empty
        # registry is not rebuilt.
        if self._registry is None:
            self.make_registry()
        return self._registry
    registry = property(_get_registry)

    def add_annotation(self, annotation):
        self.db_add_annotation(annotation)
    def delete_annotation(self, annotation):
        self.db_delete_annotation(annotation)
    def has_annotation_with_key(self, key):
        return self.db_has_annotation_with_key(key)
    def get_annotation_by_key(self, key):
        return self.db_get_annotation_by_key(key)

    def getNumFunctions(self):
        """getNumFunctions() -> int - Returns the number of functions """
        return len(self.functions)

    def summon(self):
        # we shouldn't ever call this since we're expanding abstractions
        return None

    @staticmethod
    def make_port_from_module(module, port_type):
        """make_port_from_module(module, port_type: str) -> Port
        Build a Port from an InputPort/OutputPort module in the pipeline.

        The port name and spec are read from the first parameter of the
        module's 'name' and 'spec' functions. The spec has its first and
        last characters stripped, is split on ',', and each entry is split on
        ':' into at most three parts passed to
        registry.get_descriptor_by_name (presumably package:name[:namespace]
        -- TODO confirm).

        NOTE(review): if the module has no 'name' or 'spec' function this
        raises UnboundLocalError.

        """
        for function in module.functions:
            if function.name == 'name':
                port_name = function.params[0].strValue
            if function.name == 'spec':
                port_spec = function.params[0].strValue
        port = Port(id=-1,
                    name=port_name,
                    type=port_type)
        portSpecs = port_spec[1:-1].split(',')
        signature = []
        for s in portSpecs:
            spec = s.split(':', 2)
            signature.append(registry.get_descriptor_by_name(*spec).module)
        port.spec = core.modules.module_registry.PortSpec(signature)
        return port

    def make_registry(self):
        """make_registry() -> None
        Build self._registry: a new ModuleRegistry seeded with the global
        hierarchy, with one source port per OutputPort module and one
        destination port per InputPort module in the materialized pipeline,
        all attached to the basic-package descriptor named self.name.

        """
        reg_module = \
            registry.get_descriptor_by_name('edu.utah.sci.vistrails.basic',
                                            self.name).module
        self._registry = ModuleRegistry()
        self._registry.add_hierarchy(registry, self)
        for module in self.pipeline.module_list:
            if module.name == 'OutputPort':
                port = self.make_port_from_module(module, 'source')
                self._registry.add_port(reg_module, PortEndPoint.Source, port)
            elif module.name == 'InputPort':
                port = self.make_port_from_module(module, 'destination')
                self._registry.add_port(reg_module, PortEndPoint.Destination,
                                        port)

    def sourcePorts(self):
        """sourcePorts() -> list of Port
        One 'source' port per OutputPort module in the materialized pipeline.

        """
        return [self.make_port_from_module(module, 'source')
                for module in self.pipeline.module_list
                if module.name == 'OutputPort']

    def destinationPorts(self):
        """destinationPorts() -> list of Port
        One 'destination' port per InputPort module in the materialized
        pipeline.

        """
        return [self.make_port_from_module(module, 'destination')
                for module in self.pipeline.module_list
                if module.name == 'InputPort']

    ##########################################################################
    # Operators

    def __str__(self):
        """__str__() -> str - Returns a string representation of an
        AbstractionModule object.

        """
        # Format the header on its own: the location/function/annotation
        # text appended afterwards may contain '%' characters, which must
        # never reach the % operator.
        header = ('<abstraction_module id="%s" abstraction_id="%s" '
                  'version="%s">' % (str(self.id), str(self.abstraction_id),
                                     str(self.version)))
        return ''.join([header,
                        str(self.location),
                        str(self.functions),
                        str(self.annotations),
                        '</abstraction_module>'])

    def __eq__(self, other):
        """ __eq__(other: AbstractionModule) -> boolean
        Returns True if self and other have the same attributes. Used by ==
        operator. Compares type, location, functions and annotations only.

        """
        if type(other) != type(self):
            return False
        if self.location != other.location:
            return False
        if len(self.functions) != len(other.functions):
            return False
        if len(self.annotations) != len(other.annotations):
            return False
        for f,g in izip(self.functions, other.functions):
            if f != g:
                return False
        for f,g in izip(self.annotations, other.annotations):
            if f != g:
                return False
        return True

    def __ne__(self, other):
        return not self.__eq__(other)