Example #1
0
class Group(DBGroup, Module):
    """A module that wraps a whole sub-workflow.

    Storage comes from DBGroup and most behaviour from Module.  The wrapped
    workflow is exposed as ``pipeline``; the group's ports are built from the
    'InputPort' and 'OutputPort' modules found inside that workflow.
    """

    ##########################################################################
    # Constructors and copy

    def __init__(self, *args, **kwargs):
        """Create a group, filling unset fields with defaults.

        ``pipeline=`` is accepted as an alias for the ``workflow=`` keyword.
        Defaults: cache=1, id=-1, location=Location(x=-1.0, y=-1.0), and
        empty strings for name, package and version.
        """
        if 'pipeline' in kwargs:
            # The sub-workflow is stored under DBGroup's 'workflow' field.
            kwargs['workflow'] = kwargs.pop('pipeline')
        DBGroup.__init__(self, *args, **kwargs)
        if self.cache is None:
            self.cache = 1
        if self.id is None:
            self.id = -1
        if self.location is None:
            self.location = Location(x=-1.0, y=-1.0)
        if self.name is None:
            self.name = ''
        if self.package is None:
            self.package = ''
        if self.version is None:
            self.version = ''
        self.portVisible = set()
        # Built lazily by the ``registry`` property.
        self._registry = None

    def __copy__(self):
        """__copy__() -> Group - Returns a clone of itself (see do_copy)."""
        return Group.do_copy(self)

    def do_copy(self, new_ids=False, id_scope=None, id_remap=None):
        """Copy via DBGroup.do_copy and restore Group-specific state.

        The port registry is not copied: it is reset to None and rebuilt
        from the copied pipeline on the first access to ``registry``.
        """
        cp = DBGroup.do_copy(self, new_ids, id_scope, id_remap)
        cp.__class__ = Group
        cp._registry = None
        cp.portVisible = copy.copy(self.portVisible)
        return cp

    @staticmethod
    def convert(_group):
        """Convert a DBGroup and its children to domain classes in place.

        Does nothing if ``_group`` is already a Group.
        """
        if _group.__class__ == Group:
            return
        _group.__class__ = Group
        _group._registry = None
        _group.portVisible = set()
        if _group.db_location:
            Location.convert(_group.db_location)
        if _group.db_workflow:
            # Function-level import, presumably to avoid an import cycle.
            from core.vistrail.pipeline import Pipeline
            Pipeline.convert(_group.db_workflow)
        for _function in _group.db_functions:
            ModuleFunction.convert(_function)
        for _annotation in _group.db_get_annotations():
            Annotation.convert(_annotation)


    ##########################################################################
    # Properties

    # Module's aliases are bound to DBModule's descriptors, so Group has to
    # rebind each of them to the corresponding DBGroup descriptor.
    id = DBGroup.db_id
    cache = DBGroup.db_cache
    annotations = DBGroup.db_annotations
    location = DBGroup.db_location
    center = DBGroup.db_location
    name = DBGroup.db_name
    label = DBGroup.db_name
    namespace = DBGroup.db_namespace
    package = DBGroup.db_package
    tag = DBGroup.db_tag
    version = DBGroup.db_version

    def summon(self):
        """No-op override of Module.summon; always returns None."""
        # define this so that pipeline is copied over...
        pass

    def is_group(self):
        """Return True: this module is a group."""
        return True

    pipeline = DBGroup.db_workflow

    def _get_registry(self):
        # Build on first use (and after copy/convert, which reset it).
        if self._registry is None:
            self.make_registry()
        return self._registry
    registry = property(_get_registry)

    # Override these from the Module class with defaults: a group declares
    # no port specs of its own; its ports come from the wrapped pipeline.
    def _get_port_specs(self):
        return dict()
    port_specs = property(_get_port_specs)
    def has_portSpec_with_name(self, name):
        return False
    def get_portSpec_by_name(self, name):
        return None

    @staticmethod
    def make_port_from_module(module, port_type):
        """Build a Port from an 'InputPort'/'OutputPort' pipeline module.

        The port name and spec string are read from the first parameter of
        the module's 'name' and 'spec' functions (the last occurrence wins).
        The spec's first and last characters (presumably parentheses) are
        stripped.  The remainder is split on ',' and each entry is split on
        ':' into the arguments of the global registry's
        get_descriptor_by_name, whose modules form the port signature.

        Raises ValueError if the module has no 'name' or no 'spec' function.
        """
        port_name = None
        port_spec = None
        for function in module.functions:
            if function.name == 'name':
                port_name = function.params[0].strValue
            if function.name == 'spec':
                port_spec = function.params[0].strValue
        if port_name is None or port_spec is None:
            raise ValueError("module %r lacks a 'name' or 'spec' function; "
                             "cannot build a port from it" % module.name)
        port = Port(id=-1,
                    name=port_name,
                    type=port_type)
        signature = [registry.get_descriptor_by_name(*entry.split(':', 2)).module
                     for entry in port_spec[1:-1].split(',')]
        port.spec = core.modules.module_registry.PortSpec(signature)
        return port

    def make_registry(self):
        """Build ``self._registry`` holding the ports of this group.

        Each 'OutputPort' module in the pipeline becomes a source port and
        each 'InputPort' module a destination port.  All ports are added to
        the descriptor looked up as ('edu.utah.sci.vistrails.basic',
        self.name) in the global registry.
        """
        reg_module = \
            registry.get_descriptor_by_name('edu.utah.sci.vistrails.basic',
                                            self.name).module
        self._registry = ModuleRegistry()
        self._registry.add_hierarchy(registry, self)
        for module in self.pipeline.module_list:
            if module.name == 'OutputPort':
                port = self.make_port_from_module(module, 'source')
                self._registry.add_port(reg_module, PortEndPoint.Source, port)
            elif module.name == 'InputPort':
                port = self.make_port_from_module(module, 'destination')
                self._registry.add_port(reg_module, PortEndPoint.Destination,
                                        port)

    def sourcePorts(self):
        """Return the source (output) ports registered for this group."""
        return self.registry.module_source_ports(False, self.package,
                                                 self.name, self.namespace)

    def destinationPorts(self):
        """Return the destination (input) ports registered for this group."""
        return self.registry.module_destination_ports(False, self.package,
                                                      self.name, self.namespace)

    ##########################################################################
    # Operators

    def __str__(self):
        """__str__() -> str - Returns a string representation of a Group.

        The header is formatted before the location/functions/annotations
        text is appended, so '%' characters in that text are left alone.
        """
        header = '<group id="%s" name="%s" version="%s">' % \
            (str(self.id), str(self.name), str(self.version))
        return ''.join([header,
                        str(self.location),
                        str(self.functions),
                        str(self.annotations),
                        '</group>'])

    def __eq__(self, other):
        """ __eq__(other: Group) -> boolean
        Returns True if other has exactly the same type and the same
        location, functions and annotations (pairwise, in order).  Used by
        == operator.

        """
        if type(other) != type(self):
            return False
        if self.location != other.location:
            return False
        if len(self.functions) != len(other.functions):
            return False
        if len(self.annotations) != len(other.annotations):
            return False
        for f,g in izip(self.functions, other.functions):
            if f != g:
                return False
        for f,g in izip(self.annotations, other.annotations):
            if f != g:
                return False
        return True

    def __ne__(self, other):
        """Negation of __eq__."""
        return not self.__eq__(other)
Example #2
0
class Module(DBModule):
    """ Represents a module from a Pipeline.

    A domain wrapper around DBModule.  Besides the stored fields, a module
    may carry a local ``registry`` (a ModuleRegistry) holding the extra
    ports declared by its port specs.  It stays None until
    add_port_to_registry is first called.
    """

    ##########################################################################
    # Constructor and copy

    def __init__(self, *args, **kwargs):
        """Create a module, filling unset fields with defaults.

        Defaults: cache=1, id=-1, location=Location(x=-1.0, y=-1.0), and
        empty strings for name, package and version.  The local registry
        starts as None.
        """
        DBModule.__init__(self, *args, **kwargs)
        if self.cache is None:
            self.cache = 1
        if self.id is None:
            self.id = -1
        if self.location is None:
            self.location = Location(x=-1.0, y=-1.0)
        if self.name is None:
            self.name = ''
        if self.package is None:
            self.package = ''
        if self.version is None:
            self.version = ''
        self.portVisible = set()
        self.registry = None

    def __copy__(self):
        """__copy__() -> Module - Returns a clone of itself"""
        return Module.do_copy(self)

    def do_copy(self, new_ids=False, id_scope=None, id_remap=None):
        """Copy via DBModule.do_copy and rebuild Module-specific state.

        The local registry is not shared with the original: a fresh one is
        rebuilt from the copy's own port specs.
        """
        cp = DBModule.do_copy(self, new_ids, id_scope, id_remap)
        cp.__class__ = Module
        cp.registry = None
        for port_spec in cp.db_portSpecs:
            cp.add_port_to_registry(port_spec)
        cp.portVisible = copy.copy(self.portVisible)
        return cp

    @staticmethod
    def convert(_module):
        """Convert a DBModule and its children to domain classes in place.

        Port specs are converted and registered in a fresh local registry.
        Location, functions and annotations are converted as well.
        """
        _module.__class__ = Module
        _module.registry = None
        for _port_spec in _module.db_portSpecs:
            PortSpec.convert(_port_spec)
            _module.add_port_to_registry(_port_spec)
        if _module.db_location:
            Location.convert(_module.db_location)
        for _function in _module.db_functions:
            ModuleFunction.convert(_function)
        for _annotation in _module.db_get_annotations():
            Annotation.convert(_annotation)

        _module.portVisible = set()

    ##########################################################################
    # Properties

    id = DBModule.db_id
    cache = DBModule.db_cache
    annotations = DBModule.db_annotations
    location = DBModule.db_location
    center = DBModule.db_location
    name = DBModule.db_name
    label = DBModule.db_name
    namespace = DBModule.db_namespace
    package = DBModule.db_package
    tag = DBModule.db_tag
    version = DBModule.db_version

    # type check this (list, hash)
    def _get_functions(self):
        # Sorts the stored list in place by position on every access.
        self.db_functions.sort(key=lambda x: x.db_pos)
        return self.db_functions
    def _set_functions(self, functions):
        # want to convert functions to hash...?
        self.db_functions = functions
    functions = property(_get_functions, _set_functions)
    def add_function(self, function):
        self.db_add_function(function)

    def add_annotation(self, annotation):
        self.db_add_annotation(annotation)
    def delete_annotation(self, annotation):
        self.db_delete_annotation(annotation)
    def has_annotation_with_key(self, key):
        return self.db_has_annotation_with_key(key)
    def get_annotation_by_key(self, key):
        return self.db_get_annotation_by_key(key)

    def _get_port_specs(self):
        # Port specs indexed by id.
        return self.db_portSpecs_id_index
    port_specs = property(_get_port_specs)
    def has_portSpec_with_name(self, name):
        return self.db_has_portSpec_with_name(name)
    def get_portSpec_by_name(self, name):
        return self.db_get_portSpec_by_name(name)

    def summon(self):
        """Instantiate the module class registered for this module.

        The class is looked up by (package, name, namespace) in the global
        registry.  If cache != 1, the instance's is_cacheable is replaced
        with a function that always returns False.  If the instance has a
        srcPortsOrder attribute, it is set to the names of this module's
        destination ports.  The instance's registry is this module's local
        registry, or the global one when there is none.
        """
        get = registry.get_descriptor_by_name
        result = get(self.package, self.name, self.namespace).module()
        if self.cache != 1:
            result.is_cacheable = lambda *args: False
        if hasattr(result, 'srcPortsOrder'):
            # NOTE(review): srcPortsOrder is filled from destination ports;
            # confirm this naming is intended.
            result.srcPortsOrder = [p.name for p in self.destinationPorts()]
        result.registry = self.registry or registry
        return result

    def getNumFunctions(self):
        """getNumFunctions() -> int - Returns the number of functions """
        return len(self.functions)


    def sourcePorts(self):
        """sourcePorts() -> list of Port
        Returns list of source (output) ports module supports: the global
        registry's ports followed by those in the local registry, if any.

        """
        # Copy before extending so a list owned by the global registry is
        # never mutated here (unclear from here whether it hands out a
        # shared list).
        ports = list(registry.module_source_ports(True, self.package,
                                                  self.name, self.namespace))
        if self.registry:
            ports.extend(self.registry.module_source_ports(False, self.package,
                                                           self.name,
                                                           self.namespace))
        return ports

    def destinationPorts(self):
        """destinationPorts() -> list of Port
        Returns list of destination (input) ports module supports: the
        global registry's ports followed by those in the local registry, if
        any.

        """
        # Copy for the same reason as in sourcePorts.
        ports = list(registry.module_destination_ports(True, self.package,
                                                       self.name,
                                                       self.namespace))
        if self.registry:
            ports.extend(self.registry.module_destination_ports(
                False, self.package, self.name, self.namespace))
        return ports

    def add_port_to_registry(self, port_spec):
        """Register the port described by ``port_spec`` in the local registry.

        The local registry is created (with this module's hierarchy copied
        from the global registry) on first use.  Input specs become
        destination ports; any other type becomes a source port.
        """
        module = \
            registry.get_descriptor_by_name(self.package, self.name, self.namespace).module
        if self.registry is None:
            self.registry = ModuleRegistry()
            self.registry.add_hierarchy(registry, self)

        if port_spec.type == 'input':
            endpoint = PortEndPoint.Destination
        else:
            endpoint = PortEndPoint.Source
        # Strip the first/last characters (presumably parentheses) and
        # resolve each comma-separated entry to a module class.
        portSpecs = port_spec.spec[1:-1].split(',')
        signature = [registry.get_descriptor_from_name_only(spec).module
                     for spec in portSpecs]
        port = Port()
        port.name = port_spec.name
        port.spec = core.modules.module_registry.PortSpec(signature)
        self.registry.add_port(module, endpoint, port)

    def delete_port_from_registry(self, id):
        """Remove the port described by port spec ``id`` from the registry.

        Only the spec's name and type are needed to remove the port, so no
        Port object or signature is built.

        Raises VistrailsInternalError if ``id`` is not a known port spec id
        or the module has no local registry.
        """
        if id not in self.port_specs:
            raise VistrailsInternalError("id missing in port_specs")
        portSpec = self.port_specs[id]

        module = \
            registry.get_descriptor_by_name(self.package, self.name, self.namespace).module
        # Explicit check rather than assert, so it survives python -O.
        if not isinstance(self.registry, ModuleRegistry):
            raise VistrailsInternalError("module has no local registry")

        if portSpec.type == 'input':
            self.registry.delete_input_port(module, portSpec.name)
        else:
            self.registry.delete_output_port(module, portSpec.name)

    ##########################################################################
    # Debugging

    def show_comparison(self, other):
        """Print the first difference found between self and other.

        Checks in order: type, id, name, cache, location, number of
        functions, then each function pair.  If nothing differs it prints
        "No difference found" and asserts self == other.
        """
        if type(other) != type(self):
            print("Type mismatch")
            print("%s %s" % (type(self), type(other)))
        elif self.id != other.id:
            print("id mismatch")
            print("%s %s" % (self.id, other.id))
        elif self.name != other.name:
            print("name mismatch")
            print("%s %s" % (self.name, other.name))
        elif self.cache != other.cache:
            print("cache mismatch")
            print("%s %s" % (self.cache, other.cache))
        elif self.location != other.location:
            print("location mismatch")
            # FIXME Location has no show_comparison
            # self.location.show_comparison(other.location)
        elif len(self.functions) != len(other.functions):
            print("function length mismatch")
            print("%s %s" % (len(self.functions), len(other.functions)))
        else:
            for f, g in izip(self.functions, other.functions):
                if f != g:
                    print("function mismatch")
                    f.show_comparison(g)
                    return
            print("No difference found")
            assert self == other

    ##########################################################################
    # Operators

    def __str__(self):
        """__str__() -> str Returns a string representation of itself. """
        return ("(Module '%s' id=%s functions:%s port_specs:%s)@%X" %
                (self.name,
                 self.id,
                 [str(f) for f in self.functions],
                 [str(port_spec) for port_spec in self.db_portSpecs],
                 id(self)))

    def __eq__(self, other):
        """ __eq__(other: Module) -> boolean
        Returns True if other has exactly the same type and the same name,
        cache, location, functions and annotations (pairwise, in order).
        Used by == operator.

        """
        if type(other) != type(self):
            return False
        if self.name != other.name:
            return False
        if self.cache != other.cache:
            return False
        if self.location != other.location:
            return False
        if len(self.functions) != len(other.functions):
            return False
        if len(self.annotations) != len(other.annotations):
            return False
        for f, g in izip(self.functions, other.functions):
            if f != g:
                return False
        for f, g in izip(self.annotations, other.annotations):
            if f != g:
                return False
        return True

    def __ne__(self, other):
        """Negation of __eq__."""
        return not self.__eq__(other)
class AbstractionModule(DBAbstractionRef):
    """A pipeline module that references a stored abstraction.

    The referenced workflow (``abstraction`` at ``version``) is materialized
    on demand as ``pipeline``.  Its 'InputPort'/'OutputPort' modules define
    this module's ports.
    """

    ##########################################################################
    # Constructors and copy

    def __init__(self, *args, **kwargs):
        """Create an abstraction module; id defaults to -1."""
        DBAbstractionRef.__init__(self, *args, **kwargs)
        if self.id is None:
            self.id = -1
        self.portVisible = set()
        # Built lazily by the ``registry`` property.
        self._registry = None
        self.abstraction = None
        # FIXME should we have a registry for an abstraction module?

    def __copy__(self):
        """__copy__() -> AbstractionModule - Returns a clone of itself."""
        return AbstractionModule.do_copy(self)

    def do_copy(self, new_ids=False, id_scope=None, id_remap=None):
        """Copy via DBAbstractionRef.do_copy and restore in-memory state.

        Unlike Group, the copy shares this object's registry and abstraction
        references.
        """
        cp = DBAbstractionRef.do_copy(self, new_ids, id_scope, id_remap)
        cp.__class__ = AbstractionModule
        cp.portVisible = copy.copy(self.portVisible)
        cp._registry = self._registry
        cp.abstraction = self.abstraction
        return cp

    @staticmethod
    def convert(_abstraction_module):
        """Convert a DBAbstractionRef and its children in place.

        Does nothing if the object is already an AbstractionModule.
        """
        if _abstraction_module.__class__ == AbstractionModule:
            return
        _abstraction_module.__class__ = AbstractionModule
        if _abstraction_module.db_location:
            Location.convert(_abstraction_module.db_location)
        for _function in _abstraction_module.db_functions:
            ModuleFunction.convert(_function)
        for _annotation in _abstraction_module.db_get_annotations():
            Annotation.convert(_annotation)
        _abstraction_module.portVisible = set()
        _abstraction_module._registry = None
        _abstraction_module.abstraction = None


    ##########################################################################
    # Properties

    id = DBAbstractionRef.db_id
    cache = DBAbstractionRef.db_cache
    abstraction_id = DBAbstractionRef.db_abstraction_id
    location = DBAbstractionRef.db_location
    center = DBAbstractionRef.db_location
    version = DBAbstractionRef.db_version
    tag = DBAbstractionRef.db_name
    label = DBAbstractionRef.db_name
    # Fixed identity used for registry lookups in make_registry.
    name = 'Abstraction'
    package = 'edu.utah.sci.vistrails.basic'
    namespace = None
    annotations = DBAbstractionRef.db_annotations

    def _get_functions(self):
        # Sorts the stored list in place by position on every access.
        self.db_functions.sort(key=lambda x: x.db_pos)
        return self.db_functions
    def _set_functions(self, functions):
        # want to convert functions to hash...?
        self.db_functions = functions
    functions = property(_get_functions, _set_functions)

    def _get_pipeline(self):
        """Materialize ``self.abstraction`` at ``self.version`` as a Pipeline.

        A new workflow is materialized and converted on every access.
        """
        from core.vistrail.pipeline import Pipeline
        import db.services.vistrail
        workflow = db.services.vistrail.materializeWorkflow(self.abstraction,
                                                            self.version)
        Pipeline.convert(workflow)
        return workflow
    pipeline = property(_get_pipeline)

    def _get_registry(self):
        # Build on first use (and after convert, which resets it).
        if self._registry is None:
            self.make_registry()
        return self._registry
    registry = property(_get_registry)

    def add_annotation(self, annotation):
        self.db_add_annotation(annotation)
    def delete_annotation(self, annotation):
        self.db_delete_annotation(annotation)
    def has_annotation_with_key(self, key):
        return self.db_has_annotation_with_key(key)
    def get_annotation_by_key(self, key):
        return self.db_get_annotation_by_key(key)

    def getNumFunctions(self):
        """getNumFunctions() -> int - Returns the number of functions """
        return len(self.functions)

    def summon(self):
        """Return None; abstractions are expanded rather than summoned."""
        # we shouldn't ever call this since we're expanding abstractions
        return None

    @staticmethod
    def make_port_from_module(module, port_type):
        """Build a Port from an 'InputPort'/'OutputPort' pipeline module.

        The port name and spec string are read from the first parameter of
        the module's 'name' and 'spec' functions (the last occurrence wins).
        The spec's first and last characters (presumably parentheses) are
        stripped.  The remainder is split on ',' and each entry is split on
        ':' into the arguments of the global registry's
        get_descriptor_by_name, whose modules form the port signature.

        Raises ValueError if the module has no 'name' or no 'spec' function.
        """
        port_name = None
        port_spec = None
        for function in module.functions:
            if function.name == 'name':
                port_name = function.params[0].strValue
            if function.name == 'spec':
                port_spec = function.params[0].strValue
        if port_name is None or port_spec is None:
            raise ValueError("module %r lacks a 'name' or 'spec' function; "
                             "cannot build a port from it" % module.name)
        port = Port(id=-1,
                    name=port_name,
                    type=port_type)
        signature = [registry.get_descriptor_by_name(*entry.split(':', 2)).module
                     for entry in port_spec[1:-1].split(',')]
        port.spec = core.modules.module_registry.PortSpec(signature)
        return port

    def make_registry(self):
        """Build ``self._registry`` holding the ports of the abstraction.

        Each 'OutputPort' module in the pipeline becomes a source port and
        each 'InputPort' module a destination port.  All ports are added to
        the descriptor looked up as ('edu.utah.sci.vistrails.basic',
        self.name) in the global registry.
        """
        reg_module = \
            registry.get_descriptor_by_name('edu.utah.sci.vistrails.basic',
                                            self.name).module
        self._registry = ModuleRegistry()
        self._registry.add_hierarchy(registry, self)
        for module in self.pipeline.module_list:
            if module.name == 'OutputPort':
                port = self.make_port_from_module(module, 'source')
                self._registry.add_port(reg_module, PortEndPoint.Source, port)
            elif module.name == 'InputPort':
                port = self.make_port_from_module(module, 'destination')
                self._registry.add_port(reg_module, PortEndPoint.Destination,
                                        port)

    def sourcePorts(self):
        """Return a source port for each 'OutputPort' module in the pipeline."""
        return [self.make_port_from_module(module, 'source')
                for module in self.pipeline.module_list
                if module.name == 'OutputPort']

    def destinationPorts(self):
        """Return a destination port for each 'InputPort' module."""
        return [self.make_port_from_module(module, 'destination')
                for module in self.pipeline.module_list
                if module.name == 'InputPort']

    ##########################################################################
    # Operators

    def __str__(self):
        """__str__() -> str - Returns a string representation of an
        AbstractionModule object.

        The header is formatted before the location/functions/annotations
        text is appended, so '%' characters in that text are left alone.
        """
        header = ('<abstraction_module id="%s" abstraction_id="%s" '
                  'version="%s">' % (str(self.id), str(self.abstraction_id),
                                     str(self.version)))
        return ''.join([header,
                        str(self.location),
                        str(self.functions),
                        str(self.annotations),
                        '</abstraction_module>'])

    def __eq__(self, other):
        """ __eq__(other: AbstractionModule) -> boolean
        Returns True if other has exactly the same type and the same
        location, functions and annotations (pairwise, in order).  Used by
        == operator.

        """
        if type(other) != type(self):
            return False
        if self.location != other.location:
            return False
        if len(self.functions) != len(other.functions):
            return False
        if len(self.annotations) != len(other.annotations):
            return False
        for f,g in izip(self.functions, other.functions):
            if f != g:
                return False
        for f,g in izip(self.annotations, other.annotations):
            if f != g:
                return False
        return True

    def __ne__(self, other):
        """Negation of __eq__."""
        return not self.__eq__(other)