class Group(DBGroup, Module):
    """A module whose implementation is an embedded sub-pipeline.

    Port information is not stored on the group itself: it is derived on
    demand from the InputPort/OutputPort modules inside ``pipeline`` (see
    ``make_registry``).
    """

    ##########################################################################
    # Constructors and copy

    def __init__(self, *args, **kwargs):
        """Create a group.

        ``pipeline=`` is accepted as an alias for the underlying ``workflow``
        field.  Fields left unset get placeholder defaults (cache=1, id=-1,
        location (-1, -1), empty name/package/version).
        """
        if 'pipeline' in kwargs:
            kwargs['workflow'] = kwargs.pop('pipeline')
        DBGroup.__init__(self, *args, **kwargs)
        if self.cache is None:
            self.cache = 1
        if self.id is None:
            self.id = -1
        if self.location is None:
            self.location = Location(x=-1.0, y=-1.0)
        if self.name is None:
            self.name = ''
        if self.package is None:
            self.package = ''
        if self.version is None:
            self.version = ''
        self.portVisible = set()
        # Built lazily from the pipeline by the ``registry`` property.
        self._registry = None

    def __copy__(self):
        return Group.do_copy(self)

    def do_copy(self, new_ids=False, id_scope=None, id_remap=None):
        """Copy this group (optionally remapping ids) and retype it as Group."""
        cp = DBGroup.do_copy(self, new_ids, id_scope, id_remap)
        cp.__class__ = Group
        # Not shared with the original: the copy rebuilds its own registry
        # from its pipeline the first time ``registry`` is read.
        cp._registry = None
        cp.portVisible = copy.copy(self.portVisible)
        return cp

    @staticmethod
    def convert(_group):
        """Retype a DBGroup (and its location, workflow, functions and
        annotations) in place as VisTrails objects.

        Does nothing if ``_group`` is already a Group.
        """
        if _group.__class__ == Group:
            return
        _group.__class__ = Group
        _group._registry = None
        _group.portVisible = set()
        if _group.db_location:
            Location.convert(_group.db_location)
        if _group.db_workflow:
            from core.vistrail.pipeline import Pipeline
            Pipeline.convert(_group.db_workflow)
        for _function in _group.db_functions:
            ModuleFunction.convert(_function)
        for _annotation in _group.db_get_annotations():
            Annotation.convert(_annotation)

    ##########################################################################
    # Properties

    # We need to repeat these here because Module uses DBModule. ...
    id = DBGroup.db_id
    cache = DBGroup.db_cache
    annotations = DBGroup.db_annotations
    location = DBGroup.db_location
    center = DBGroup.db_location
    name = DBGroup.db_name
    label = DBGroup.db_name
    namespace = DBGroup.db_namespace
    package = DBGroup.db_package
    tag = DBGroup.db_tag
    version = DBGroup.db_version

    def summon(self):
        # Overrides Module.summon, so no module-registry lookup or
        # instantiation happens for a group; returns None.
        # (Original note: define this so that pipeline is copied over...)
        pass

    def is_group(self):
        return True

    pipeline = DBGroup.db_workflow

    def _get_registry(self):
        """Return the port registry, building it from the pipeline on first use."""
        if self._registry is None:
            self.make_registry()
        return self._registry
    registry = property(_get_registry)

    # override these from the Module class with defaults
    def _get_port_specs(self):
        return dict()
    port_specs = property(_get_port_specs)

    def has_portSpec_with_name(self, name):
        return False

    def get_portSpec_by_name(self, name):
        return None

    @staticmethod
    def make_port_from_module(module, port_type):
        """Build a Port from an InputPort/OutputPort module of the pipeline.

        ``module`` must carry a 'name' function and a 'spec' function; the
        first parameter of each holds the port name and the signature string.
        The signature string has its first and last characters stripped
        (presumably enclosing brackets), is split on ',', and each entry is
        split on ':' into at most three parts that are passed to
        ``registry.get_descriptor_by_name``.

        Raises VistrailsInternalError if the 'name' or 'spec' function is
        missing (previously this surfaced as an UnboundLocalError).
        """
        port_name = None
        port_spec = None
        for function in module.functions:
            if function.name == 'name':
                port_name = function.params[0].strValue
            elif function.name == 'spec':
                port_spec = function.params[0].strValue
        if port_name is None or port_spec is None:
            raise VistrailsInternalError(
                "port module is missing its 'name' or 'spec' function")
        port = Port(id=-1, name=port_name, type=port_type)
        signature = [registry.get_descriptor_by_name(*entry.split(':', 2)).module
                     for entry in port_spec[1:-1].split(',')]
        port.spec = core.modules.module_registry.PortSpec(signature)
        return port

    def make_registry(self):
        """Populate ``self._registry`` from the pipeline's port modules.

        Each OutputPort module becomes a source port and each InputPort
        module a destination port, attached to the descriptor found under
        package 'edu.utah.sci.vistrails.basic' with this group's name.
        """
        reg_module = registry.get_descriptor_by_name(
            'edu.utah.sci.vistrails.basic', self.name).module
        self._registry = ModuleRegistry()
        self._registry.add_hierarchy(registry, self)
        for module in self.pipeline.module_list:
            if module.name == 'OutputPort':
                port = self.make_port_from_module(module, 'source')
                self._registry.add_port(reg_module, PortEndPoint.Source, port)
            elif module.name == 'InputPort':
                port = self.make_port_from_module(module, 'destination')
                self._registry.add_port(reg_module,
                                        PortEndPoint.Destination, port)

    def sourcePorts(self):
        """Return the source (output) ports recorded in this group's registry."""
        return self.registry.module_source_ports(False, self.package,
                                                 self.name, self.namespace)

    def destinationPorts(self):
        """Return the destination (input) ports recorded in this group's
        registry."""
        return self.registry.module_destination_ports(False, self.package,
                                                      self.name,
                                                      self.namespace)

    ##########################################################################
    # Operators

    def __str__(self):
        """__str__() -> str - Returns a string representation of a Group
        object.
        """
        # Format the header before appending the child reprs so that '%'
        # characters inside them cannot break the formatting.
        header = '<group id="%s" name="%s" version="%s">' % \
            (str(self.id), str(self.name), str(self.version))
        return ''.join([header,
                        str(self.location),
                        str(self.functions),
                        str(self.annotations),
                        '</group>'])

    def __eq__(self, other):
        """ __eq__(other: Group) -> boolean
        Returns True if self and other have the same type, location,
        functions and annotations (ids, names and pipelines are not
        compared).  Used by == operator.
        """
        if type(other) != type(self):
            return False
        if self.location != other.location:
            return False
        if len(self.functions) != len(other.functions):
            return False
        if len(self.annotations) != len(other.annotations):
            return False
        for f, g in izip(self.functions, other.functions):
            if f != g:
                return False
        for f, g in izip(self.annotations, other.annotations):
            if f != g:
                return False
        return True

    def __ne__(self, other):
        return not self.__eq__(other)
class Module(DBModule):

    """ Represents a module from a Pipeline """

    ##########################################################################
    # Constructor and copy

    def __init__(self, *args, **kwargs):
        """Create a module; unset fields get placeholder defaults (cache=1,
        id=-1, location (-1, -1), empty name/package/version)."""
        DBModule.__init__(self, *args, **kwargs)
        if self.cache is None:
            self.cache = 1
        if self.id is None:
            self.id = -1
        if self.location is None:
            self.location = Location(x=-1.0, y=-1.0)
        if self.name is None:
            self.name = ''
        if self.package is None:
            self.package = ''
        if self.version is None:
            self.version = ''
        self.portVisible = set()
        # Per-module registry for ports declared via port specs; created on
        # demand by add_port_to_registry.
        self.registry = None

    def __copy__(self):
        """__copy__() -> Module - Returns a clone of itself"""
        return Module.do_copy(self)

    def do_copy(self, new_ids=False, id_scope=None, id_remap=None):
        """Copy this module (optionally remapping ids) and retype it as Module."""
        cp = DBModule.do_copy(self, new_ids, id_scope, id_remap)
        cp.__class__ = Module
        # Rebuild the per-module registry from the copied port specs instead
        # of sharing the original's registry object.
        cp.registry = None
        for port_spec in cp.db_portSpecs:
            cp.add_port_to_registry(port_spec)
        cp.portVisible = copy.copy(self.portVisible)
        return cp

    @staticmethod
    def convert(_module):
        """Retype a DBModule (and its port specs, location, functions and
        annotations) in place as VisTrails objects, rebuilding its registry."""
        _module.__class__ = Module
        _module.registry = None
        for _port_spec in _module.db_portSpecs:
            PortSpec.convert(_port_spec)
            _module.add_port_to_registry(_port_spec)
        if _module.db_location:
            Location.convert(_module.db_location)
        for _function in _module.db_functions:
            ModuleFunction.convert(_function)
        for _annotation in _module.db_get_annotations():
            Annotation.convert(_annotation)
        _module.portVisible = set()

    ##########################################################################

    id = DBModule.db_id
    cache = DBModule.db_cache
    annotations = DBModule.db_annotations
    location = DBModule.db_location
    center = DBModule.db_location
    name = DBModule.db_name
    label = DBModule.db_name
    namespace = DBModule.db_namespace
    package = DBModule.db_package
    tag = DBModule.db_tag
    version = DBModule.db_version

    # type check this (list, hash)
    def _get_functions(self):
        # Sorts db_functions in place by position so callers always see them
        # in db_pos order.
        self.db_functions.sort(key=lambda x: x.db_pos)
        return self.db_functions

    def _set_functions(self, functions):
        # want to convert functions to hash...?
        self.db_functions = functions
    functions = property(_get_functions, _set_functions)

    def add_function(self, function):
        self.db_add_function(function)

    def add_annotation(self, annotation):
        self.db_add_annotation(annotation)

    def delete_annotation(self, annotation):
        self.db_delete_annotation(annotation)

    def has_annotation_with_key(self, key):
        return self.db_has_annotation_with_key(key)

    def get_annotation_by_key(self, key):
        return self.db_get_annotation_by_key(key)

    def _get_port_specs(self):
        return self.db_portSpecs_id_index
    port_specs = property(_get_port_specs)

    def has_portSpec_with_name(self, name):
        return self.db_has_portSpec_with_name(name)

    def get_portSpec_by_name(self, name):
        return self.db_get_portSpec_by_name(name)

    def summon(self):
        """Instantiate the registered module class for this pipeline module.

        Disables caching when ``cache`` != 1, fills ``srcPortsOrder`` (when
        the instance has it) with this module's destination port names, and
        attaches the per-module registry, falling back to the global one.
        """
        descriptor = registry.get_descriptor_by_name(self.package, self.name,
                                                     self.namespace)
        result = descriptor.module()
        if self.cache != 1:
            result.is_cacheable = lambda *args: False
        if hasattr(result, 'srcPortsOrder'):
            # NOTE(review): populated from *destination* (input) ports;
            # presumably intentional -- confirm against srcPortsOrder users.
            result.srcPortsOrder = [p.name for p in self.destinationPorts()]
        result.registry = self.registry or registry
        return result

    def getNumFunctions(self):
        """getNumFunctions() -> int - Returns the number of functions """
        return len(self.functions)

    def sourcePorts(self):
        """sourcePorts() -> list of Port
        Returns list of source (output) ports module supports: the globally
        registered ones (sorted) followed by any from the per-module registry.
        """
        # Copy before extending so the list handed back by the global
        # registry is never mutated, whether or not it is shared.
        ports = list(registry.module_source_ports(True, self.package,
                                                  self.name,
                                                  self.namespace))
        if self.registry:
            ports.extend(self.registry.module_source_ports(False,
                                                           self.package,
                                                           self.name,
                                                           self.namespace))
        return ports

    def destinationPorts(self):
        """destinationPorts() -> list of Port
        Returns list of destination (input) ports module supports: the
        globally registered ones (sorted) followed by any from the per-module
        registry.
        """
        # Copy before extending; see sourcePorts.
        ports = list(registry.module_destination_ports(True, self.package,
                                                       self.name,
                                                       self.namespace))
        if self.registry:
            ports.extend(self.registry.module_destination_ports(
                False, self.package, self.name, self.namespace))
        return ports

    def add_port_to_registry(self, port_spec):
        """Register the port described by ``port_spec`` in the per-module
        registry, creating that registry on first use.

        'input' specs become destination ports; anything else becomes a
        source port.
        """
        module = registry.get_descriptor_by_name(self.package, self.name,
                                                 self.namespace).module
        if self.registry is None:
            self.registry = ModuleRegistry()
            self.registry.add_hierarchy(registry, self)

        if port_spec.type == 'input':
            endpoint = PortEndPoint.Destination
        else:
            endpoint = PortEndPoint.Source
        # The first and last characters of the spec string are stripped
        # (presumably enclosing brackets) before splitting on ','.
        signature = [registry.get_descriptor_from_name_only(spec).module
                     for spec in port_spec.spec[1:-1].split(',')]
        port = Port()
        port.name = port_spec.name
        port.spec = core.modules.module_registry.PortSpec(signature)
        self.registry.add_port(module, endpoint, port)

    def delete_port_from_registry(self, id):
        """Remove from the per-module registry the port described by the
        port spec with the given id.

        Only the port's name and direction are needed for removal, so the
        signature types are not resolved.

        Raises VistrailsInternalError if ``id`` is not a known port spec id
        or if no per-module registry exists.
        """
        if id not in self.port_specs:
            raise VistrailsInternalError("id missing in port_specs")
        portSpec = self.port_specs[id]
        module = registry.get_descriptor_by_name(self.package, self.name,
                                                 self.namespace).module
        if not isinstance(self.registry, ModuleRegistry):
            raise VistrailsInternalError("module has no port registry")
        if portSpec.type == 'input':
            self.registry.delete_input_port(module, portSpec.name)
        else:
            self.registry.delete_output_port(module, portSpec.name)

    ##########################################################################
    # Debugging

    def show_comparison(self, other):
        """Print the first field that differs between self and other.

        If nothing differs, prints "No difference found" and asserts that
        self == other.
        """
        if type(other) != type(self):
            print("Type mismatch")
            print("%s %s" % (type(self), type(other)))
        elif self.id != other.id:
            print("id mismatch")
            print("%s %s" % (self.id, other.id))
        elif self.name != other.name:
            print("name mismatch")
            print("%s %s" % (self.name, other.name))
        elif self.cache != other.cache:
            print("cache mismatch")
            print("%s %s" % (self.cache, other.cache))
        elif self.location != other.location:
            print("location mismatch")
            # FIXME Location has no show_comparison
            # self.location.show_comparison(other.location)
        elif len(self.functions) != len(other.functions):
            print("function length mismatch")
            print("%s %s" % (len(self.functions), len(other.functions)))
        else:
            for f, g in izip(self.functions, other.functions):
                if f != g:
                    print("function mismatch")
                    f.show_comparison(g)
                    return
            print("No difference found")
            assert self == other

    ##########################################################################
    # Operators

    def __str__(self):
        """__str__() -> str
        Returns a string representation of itself.
        """
        return ("(Module '%s' id=%s functions:%s port_specs:%s)@%X" %
                (self.name,
                 self.id,
                 [str(f) for f in self.functions],
                 [str(port_spec) for port_spec in self.db_portSpecs],
                 id(self)))

    def __eq__(self, other):
        """ __eq__(other: Module) -> boolean
        Returns True if self and other have the same type, name, cache flag,
        location, functions and annotations (ids are not compared).
        Used by == operator.
        """
        if type(other) != type(self):
            return False
        if self.name != other.name:
            return False
        if self.cache != other.cache:
            return False
        if self.location != other.location:
            return False
        if len(self.functions) != len(other.functions):
            return False
        if len(self.annotations) != len(other.annotations):
            return False
        for f, g in izip(self.functions, other.functions):
            if f != g:
                return False
        for f, g in izip(self.annotations, other.annotations):
            if f != g:
                return False
        return True

    def __ne__(self, other):
        return not self.__eq__(other)
class AbstractionModule(DBAbstractionRef):
    """Reference to one version of a stored abstraction (sub-workflow).

    Ports are derived from the InputPort/OutputPort modules of the
    referenced workflow version (see ``pipeline`` and ``make_registry``).
    """

    ##########################################################################
    # Constructors and copy

    def __init__(self, *args, **kwargs):
        DBAbstractionRef.__init__(self, *args, **kwargs)
        if self.id is None:
            self.id = -1
        self.portVisible = set()
        # Built lazily by the ``registry`` property.
        self._registry = None
        # Object passed to materializeWorkflow by ``pipeline``; presumably
        # assigned by the caller before ports are queried -- confirm.
        self.abstraction = None
        # FIXME should we have a registry for an abstraction module?

    def __copy__(self):
        return AbstractionModule.do_copy(self)

    def do_copy(self, new_ids=False, id_scope=None, id_remap=None):
        """Copy this reference (optionally remapping ids).

        The copy shares the original's registry and abstraction objects.
        """
        cp = DBAbstractionRef.do_copy(self, new_ids, id_scope, id_remap)
        cp.__class__ = AbstractionModule
        cp.portVisible = copy.copy(self.portVisible)
        cp._registry = self._registry
        cp.abstraction = self.abstraction
        return cp

    @staticmethod
    def convert(_abstraction_module):
        """Retype a DBAbstractionRef (and its location, functions and
        annotations) in place; does nothing if already converted."""
        if _abstraction_module.__class__ == AbstractionModule:
            return
        _abstraction_module.__class__ = AbstractionModule
        if _abstraction_module.db_location:
            Location.convert(_abstraction_module.db_location)
        for _function in _abstraction_module.db_functions:
            ModuleFunction.convert(_function)
        for _annotation in _abstraction_module.db_get_annotations():
            Annotation.convert(_annotation)
        _abstraction_module.portVisible = set()
        _abstraction_module._registry = None
        _abstraction_module.abstraction = None

    ##########################################################################
    # Properties

    id = DBAbstractionRef.db_id
    cache = DBAbstractionRef.db_cache
    abstraction_id = DBAbstractionRef.db_abstraction_id
    location = DBAbstractionRef.db_location
    center = DBAbstractionRef.db_location
    version = DBAbstractionRef.db_version
    tag = DBAbstractionRef.db_name
    label = DBAbstractionRef.db_name
    name = 'Abstraction'
    package = 'edu.utah.sci.vistrails.basic'
    namespace = None
    annotations = DBAbstractionRef.db_annotations

    def _get_functions(self):
        # Sorts db_functions in place by position before returning them.
        self.db_functions.sort(key=lambda x: x.db_pos)
        return self.db_functions

    def _set_functions(self, functions):
        # want to convert functions to hash...?
        self.db_functions = functions
    functions = property(_get_functions, _set_functions)

    def _get_pipeline(self):
        """Materialize the referenced version of ``self.abstraction`` as a
        Pipeline.  A new pipeline is built on every access."""
        from core.vistrail.pipeline import Pipeline
        import db.services.vistrail
        workflow = db.services.vistrail.materializeWorkflow(self.abstraction,
                                                            self.version)
        Pipeline.convert(workflow)
        return workflow
    pipeline = property(_get_pipeline)

    def _get_registry(self):
        """Return the port registry, building it on first use."""
        if self._registry is None:
            self.make_registry()
        return self._registry
    registry = property(_get_registry)

    def add_annotation(self, annotation):
        self.db_add_annotation(annotation)

    def delete_annotation(self, annotation):
        self.db_delete_annotation(annotation)

    def has_annotation_with_key(self, key):
        return self.db_has_annotation_with_key(key)

    def get_annotation_by_key(self, key):
        return self.db_get_annotation_by_key(key)

    def getNumFunctions(self):
        """getNumFunctions() -> int - Returns the number of functions """
        return len(self.functions)

    def summon(self):
        # we shouldn't ever call this since we're expanding abstractions
        return None

    @staticmethod
    def make_port_from_module(module, port_type):
        """Build a Port from an InputPort/OutputPort module of the pipeline.

        ``module`` must carry a 'name' function and a 'spec' function; the
        first parameter of each holds the port name and the signature string.
        The signature string has its first and last characters stripped
        (presumably enclosing brackets), is split on ',', and each entry is
        split on ':' into at most three parts that are passed to
        ``registry.get_descriptor_by_name``.

        Raises VistrailsInternalError if the 'name' or 'spec' function is
        missing (previously this surfaced as an UnboundLocalError).
        """
        port_name = None
        port_spec = None
        for function in module.functions:
            if function.name == 'name':
                port_name = function.params[0].strValue
            elif function.name == 'spec':
                port_spec = function.params[0].strValue
        if port_name is None or port_spec is None:
            raise VistrailsInternalError(
                "port module is missing its 'name' or 'spec' function")
        port = Port(id=-1, name=port_name, type=port_type)
        signature = [registry.get_descriptor_by_name(*entry.split(':', 2)).module
                     for entry in port_spec[1:-1].split(',')]
        port.spec = core.modules.module_registry.PortSpec(signature)
        return port

    def make_registry(self):
        """Populate ``self._registry`` from the pipeline's port modules.

        Each OutputPort module becomes a source port and each InputPort
        module a destination port, attached to the 'Abstraction' descriptor
        of package 'edu.utah.sci.vistrails.basic'.
        """
        reg_module = registry.get_descriptor_by_name(
            'edu.utah.sci.vistrails.basic', self.name).module
        self._registry = ModuleRegistry()
        self._registry.add_hierarchy(registry, self)
        for module in self.pipeline.module_list:
            if module.name == 'OutputPort':
                port = self.make_port_from_module(module, 'source')
                self._registry.add_port(reg_module, PortEndPoint.Source, port)
            elif module.name == 'InputPort':
                port = self.make_port_from_module(module, 'destination')
                self._registry.add_port(reg_module,
                                        PortEndPoint.Destination, port)

    def sourcePorts(self):
        """Return one source Port per OutputPort module in the pipeline."""
        return [self.make_port_from_module(module, 'source')
                for module in self.pipeline.module_list
                if module.name == 'OutputPort']

    def destinationPorts(self):
        """Return one destination Port per InputPort module in the pipeline."""
        return [self.make_port_from_module(module, 'destination')
                for module in self.pipeline.module_list
                if module.name == 'InputPort']

    ##########################################################################
    # Operators

    def __str__(self):
        """__str__() -> str - Returns a string representation of an
        AbstractionModule object.
        """
        # Format the header before appending the child reprs so that '%'
        # characters inside them cannot break the formatting.
        header = ('<abstraction_module id="%s" abstraction_id="%s" '
                  'version="%s">' % (str(self.id), str(self.abstraction_id),
                                     str(self.version)))
        return ''.join([header,
                        str(self.location),
                        str(self.functions),
                        str(self.annotations),
                        '</abstraction_module>'])

    def __eq__(self, other):
        """ __eq__(other: AbstractionModule) -> boolean
        Returns True if self and other have the same type, location,
        functions and annotations (ids, abstraction ids and versions are not
        compared).  Used by == operator.
        """
        if type(other) != type(self):
            return False
        if self.location != other.location:
            return False
        if len(self.functions) != len(other.functions):
            return False
        if len(self.annotations) != len(other.annotations):
            return False
        for f, g in izip(self.functions, other.functions):
            if f != g:
                return False
        for f, g in izip(self.annotations, other.annotations):
            if f != g:
                return False
        return True

    def __ne__(self, other):
        return not self.__eq__(other)