# Example #1
# 0
class Module(DBModule):
    """ Represents a module from a Pipeline """

    ##########################################################################
    # Constructor and copy

    def __init__(self, *args, **kwargs):
        """__init__(*args, **kwargs) -> None

        Initialise the underlying DBModule with the given arguments, then
        fill in defaults for any field still set to None:
        cache=1, id=-1, location=Location(x=-1.0, y=-1.0),
        name='', package='', version=''.

        The visible-port set starts empty.  No per-module port registry
        exists until a port spec is registered (see add_port_to_registry).
        """
        DBModule.__init__(self, *args, **kwargs)
        if self.cache is None:
            self.cache = 1
        if self.id is None:
            self.id = -1
        if self.location is None:
            self.location = Location(x=-1.0, y=-1.0)
        if self.name is None:
            self.name = ''
        if self.package is None:
            self.package = ''
        if self.version is None:
            self.version = ''
        self.portVisible = set()
        self.registry = None

    def __copy__(self):
        """__copy__() -> Module - Returns a clone of itself"""
        return Module.do_copy(self)

    def do_copy(self, new_ids=False, id_scope=None, id_remap=None):
        """do_copy(new_ids=False, id_scope=None, id_remap=None) -> Module

        Copy this module through DBModule.do_copy (the arguments are passed
        through unchanged) and re-tag the result as a Module.

        The per-module port registry is not shared with the original. It is
        rebuilt from the copy's own db_portSpecs. The visible-port set is
        shallow-copied.
        """
        cp = DBModule.do_copy(self, new_ids, id_scope, id_remap)
        cp.__class__ = Module
        # Rebuild the registry from the copy's own port specs instead of
        # copying the original's registry object.
        cp.registry = None
        for port_spec in cp.db_portSpecs:
            cp.add_port_to_registry(port_spec)
        cp.portVisible = copy.copy(self.portVisible)
        return cp

    @staticmethod
    def convert(_module):
        """convert(_module) -> None

        Re-tag `_module` in place as a Module and convert its children.
        Each port spec is passed to PortSpec.convert and registered in a
        fresh per-module registry. The location (if any) goes to
        Location.convert, functions to ModuleFunction.convert, and
        annotations to Annotation.convert. The visible-port set is reset to
        empty.
        """
        _module.__class__ = Module
        _module.registry = None
        for _port_spec in _module.db_portSpecs:
            PortSpec.convert(_port_spec)
            _module.add_port_to_registry(_port_spec)
        if _module.db_location:
            Location.convert(_module.db_location)
        for _function in _module.db_functions:
            ModuleFunction.convert(_function)
        for _annotation in _module.db_get_annotations():
            Annotation.convert(_annotation)

        _module.portVisible = set()

    ##########################################################################

    # Short aliases for the DBModule db_* properties.
    # Note that `center` aliases the same field as `location`, and `label`
    # aliases the same field as `name`.
    id = DBModule.db_id
    cache = DBModule.db_cache
    annotations = DBModule.db_annotations
    location = DBModule.db_location
    center = DBModule.db_location
    name = DBModule.db_name
    label = DBModule.db_name
    namespace = DBModule.db_namespace
    package = DBModule.db_package
    tag = DBModule.db_tag
    version = DBModule.db_version

    # type check this (list, hash)
    def _get_functions(self):
        """Return db_functions after sorting it in place by db_pos."""
        self.db_functions.sort(key=lambda x: x.db_pos)
        return self.db_functions

    def _set_functions(self, functions):
        """Replace db_functions with `functions` as given."""
        # want to convert functions to hash...?
        self.db_functions = functions

    functions = property(_get_functions, _set_functions)

    def add_function(self, function):
        """Append `function` via DBModule.db_add_function."""
        self.db_add_function(function)

    def add_annotation(self, annotation):
        """Add `annotation` via DBModule.db_add_annotation."""
        self.db_add_annotation(annotation)

    def delete_annotation(self, annotation):
        """Remove `annotation` via DBModule.db_delete_annotation."""
        self.db_delete_annotation(annotation)

    def has_annotation_with_key(self, key):
        """Return DBModule.db_has_annotation_with_key(key)."""
        return self.db_has_annotation_with_key(key)

    def get_annotation_by_key(self, key):
        """Return DBModule.db_get_annotation_by_key(key)."""
        return self.db_get_annotation_by_key(key)

    def _get_port_specs(self):
        """Return the id-indexed port-spec mapping (db_portSpecs_id_index)."""
        return self.db_portSpecs_id_index

    port_specs = property(_get_port_specs)

    def has_portSpec_with_name(self, name):
        """Return DBModule.db_has_portSpec_with_name(name)."""
        return self.db_has_portSpec_with_name(name)

    def get_portSpec_by_name(self, name):
        """Return DBModule.db_get_portSpec_by_name(name)."""
        return self.db_get_portSpec_by_name(name)

    def summon(self):
        """summon() -> new instance of the registered module class

        Instantiate the class registered in the global registry for
        (package, name, namespace), then adjust the instance:
        - If cache != 1, is_cacheable is overridden to always return False.
        - If the instance has srcPortsOrder, it is set to the names of this
          module's destinationPorts().
        - Its registry is set to this module's per-module registry when one
          exists, otherwise to the global registry.
        """
        get = registry.get_descriptor_by_name
        result = get(self.package, self.name, self.namespace).module()
        if self.cache != 1:
            result.is_cacheable = lambda *args: False
        if hasattr(result, 'srcPortsOrder'):
            result.srcPortsOrder = [p.name for p in self.destinationPorts()]
        result.registry = self.registry or registry
        return result

    def getNumFunctions(self):
        """getNumFunctions() -> int - Returns the number of functions """
        return len(self.functions)

    def sourcePorts(self):
        """sourcePorts() -> list of Port
        Returns list of source (output) ports module supports.

        The global registry is queried first (with True as the first
        argument). If this module has its own registry, its ports (queried
        with False) are appended.
        """
        ports = registry.module_source_ports(True, self.package, self.name,
                                             self.namespace)
        if self.registry:
            ports.extend(self.registry.module_source_ports(
                False, self.package, self.name, self.namespace))
        return ports

    def destinationPorts(self):
        """destinationPorts() -> list of Port
        Returns list of destination (input) ports module supports.

        The global registry is queried first (with True as the first
        argument). If this module has its own registry, its ports (queried
        with False) are appended.
        """
        ports = registry.module_destination_ports(True, self.package,
                                                  self.name, self.namespace)
        if self.registry:
            ports.extend(self.registry.module_destination_ports(
                False, self.package, self.name, self.namespace))
        return ports

    def _registry_module(self):
        """Return the class registered in the global registry under this
        module's (package, name, namespace)."""
        descriptor = registry.get_descriptor_by_name(self.package, self.name,
                                                     self.namespace)
        return descriptor.module

    @staticmethod
    def _spec_signature(port_spec):
        """Return the registered classes named in port_spec.spec.

        The first and last characters of the spec string are dropped
        (presumably enclosing parentheses -- TODO confirm). The remainder is
        split on commas, and each piece is resolved with
        registry.get_descriptor_from_name_only.
        """
        type_names = port_spec.spec[1:-1].split(',')
        return [registry.get_descriptor_from_name_only(type_name).module
                for type_name in type_names]

    def add_port_to_registry(self, port_spec):
        """add_port_to_registry(port_spec) -> None

        Register `port_spec` on this module's private registry. It becomes a
        Destination endpoint when port_spec.type == 'input', otherwise a
        Source endpoint.

        The private registry is created on first use: a new ModuleRegistry
        is set up via add_hierarchy(<global registry>, self).
        """
        module = self._registry_module()
        if self.registry is None:
            self.registry = ModuleRegistry()
            self.registry.add_hierarchy(registry, self)

        if port_spec.type == 'input':
            endpoint = PortEndPoint.Destination
        else:
            endpoint = PortEndPoint.Source
        signature = self._spec_signature(port_spec)
        port = Port()
        port.name = port_spec.name
        port.spec = core.modules.module_registry.PortSpec(signature)
        self.registry.add_port(module, endpoint, port)

    def delete_port_from_registry(self, id):
        """delete_port_from_registry(id) -> None

        Remove the port described by port_specs[id] from this module's
        private registry. It is removed as an input port when its type is
        'input', otherwise as an output port.

        Raises VistrailsInternalError if `id` is not in port_specs.
        """
        if id not in self.port_specs:
            raise VistrailsInternalError("id missing in port_specs")
        portSpec = self.port_specs[id]
        # The signature is still resolved here (as before), so an unknown
        # type name in the spec raises at the same point it used to.
        signature = self._spec_signature(portSpec)
        # NOTE(review): add_port_to_registry builds Port() without arguments.
        # Only port.name is used below, so this construction is kept as-is.
        port = Port(signature)
        port.name = portSpec.name
        port.spec = core.modules.module_registry.PortSpec(signature)

        module = self._registry_module()
        assert isinstance(self.registry, ModuleRegistry)

        if portSpec.type == 'input':
            self.registry.delete_input_port(module, port.name)
        else:
            self.registry.delete_output_port(module, port.name)

    ##########################################################################
    # Debugging

    def show_comparison(self, other):
        """show_comparison(other) -> None

        Debugging aid that prints the first difference found between self
        and other. It checks, in order: type, id, name, cache, location,
        function count, annotation count, then each function and each
        annotation pairwise.

        If nothing differs, it prints "No difference found" and asserts that
        the modules compare equal. Annotations are checked here because
        __eq__ compares them; otherwise that final assert could fail after
        reporting no difference.
        """
        if type(other) != type(self):
            print("Type mismatch")
            print("%s %s" % (type(self), type(other)))
        elif self.id != other.id:
            print("id mismatch")
            print("%s %s" % (self.id, other.id))
        elif self.name != other.name:
            print("name mismatch")
            print("%s %s" % (self.name, other.name))
        elif self.cache != other.cache:
            print("cache mismatch")
            print("%s %s" % (self.cache, other.cache))
        elif self.location != other.location:
            print("location mismatch")
            # FIXME Location has no show_comparison
        elif len(self.functions) != len(other.functions):
            print("function length mismatch")
            print("%s %s" % (len(self.functions), len(other.functions)))
        elif len(self.annotations) != len(other.annotations):
            print("annotation length mismatch")
            print("%s %s" % (len(self.annotations), len(other.annotations)))
        else:
            for f, g in izip(self.functions, other.functions):
                if f != g:
                    print("function mismatch")
                    f.show_comparison(g)
                    return
            for a, b in izip(self.annotations, other.annotations):
                if a != b:
                    print("annotation mismatch")
                    print("%s %s" % (a, b))
                    return
            print("No difference found")
            assert self == other

    ##########################################################################
    # Operators

    def __str__(self):
        """__str__() -> str Returns a string representation of itself. """
        return ("(Module '%s' id=%s functions:%s port_specs:%s)@%X" %
                (self.name,
                 self.id,
                 [str(f) for f in self.functions],
                 [str(port_spec) for port_spec in self.db_portSpecs],
                 id(self)))

    def __eq__(self, other):
        """ __eq__(other: Module) -> boolean
        Returns True if self and other have the same attributes. Used by ==
        operator.

        Compares exact type, name, cache, location, and the functions and
        annotations pairwise, in order. Note that id, package, namespace,
        version and port specs are not compared.
        """
        if type(other) != type(self):
            return False
        if self.name != other.name:
            return False
        if self.cache != other.cache:
            return False
        if self.location != other.location:
            return False
        if len(self.functions) != len(other.functions):
            return False
        if len(self.annotations) != len(other.annotations):
            return False
        for f, g in izip(self.functions, other.functions):
            if f != g:
                return False
        for f, g in izip(self.annotations, other.annotations):
            if f != g:
                return False
        return True

    def __ne__(self, other):
        """Inverse of __eq__."""
        return not self.__eq__(other)