def __init__(self, value):
    """Store ``value`` and look up the descriptor of its basic type.

    ``basestring`` instances are typed with the basic package's ``String``
    descriptor; every other value is typed with ``Float``.
    """
    self.value = value
    # Anything that is not a string is treated as a float.
    type_name = 'String' if isinstance(self.value, basestring) else 'Float'
    registry = get_module_registry()
    self.type = registry.get_descriptor_by_name(
        'org.vistrails.vistrails.basic', type_name)
def make_port_specs(self):
    """Rebuild the port-spec tables from the pipeline's port modules.

    All five tables are reset first.  When ``self.pipeline`` is set, every
    ``OutputPort``/``InputPort`` module of the basic package produces one
    port spec that is stored in ``self._port_specs`` under
    ``(port_name, port_type)``, appended to the matching spec list and
    recorded in the matching remap table (port name -> module).
    """
    self._port_specs = {}
    self._input_port_specs = []
    self._output_port_specs = []
    self._input_remap = {}
    self._output_remap = {}
    if self.pipeline is None:
        return

    # module name -> (port type, spec list to extend, remap table to fill)
    targets = {
        'OutputPort': ('output', self._output_port_specs,
                       self._output_remap),
        'InputPort': ('input', self._input_port_specs,
                      self._input_remap),
    }
    registry = get_module_registry()
    for module in self.pipeline.module_list:
        target = targets.get(module.name)
        if target is None or not (module.package == basic_pkg):
            continue
        port_type, spec_list, remap = target
        info = self.get_port_spec_info(module)
        (port_name, sigstring, optional, depth, _) = info
        port_spec = registry.create_port_spec(port_name, port_type, None,
                                              sigstring, optional,
                                              depth=depth)
        self._port_specs[(port_name, port_type)] = port_spec
        spec_list.append(port_spec)
        remap[port_name] = module
def _get_module_descriptor(self):
    """Return the module descriptor, caching it through a weak reference.

    The registry is queried with ``self.descriptor_info`` whenever no
    reference has been stored yet or the stored reference is dead.
    """
    cached = self._module_descriptor
    needs_lookup = cached is None or cached() is None
    if needs_lookup:
        registry = get_module_registry()
        self._module_descriptor = weakref.ref(
            registry.get_descriptor_by_name(*self.descriptor_info))
    return self._module_descriptor()
def updateModule(self, module):
    """Show the configuration widget for ``module``.

    Does nothing while updates are locked.  Otherwise pending changes are
    checked, the container is cleared, and, when both a module and a
    controller are available, the module's registered configuration
    widget (or ``DefaultModuleConfigurationWidget``) is installed and
    wired to ``configureDone`` / ``stateChanged``.
    """
    if self.updateLocked:
        return
    self.check_need_save_changes()
    self.module = module
    self.confWidget.setUpdatesEnabled(False)
    self.confWidget.setVisible(False)
    self.confWidget.clear()

    if module and self.controller:
        lookup = get_module_registry().get_configuration_widget
        try:
            widget_cls = lookup(module.package, module.name,
                                module.namespace)
        except ModuleRegistryException:
            widget_cls = None
        # Fall back to the generic widget when nothing usable is registered.
        widget_cls = widget_cls or DefaultModuleConfigurationWidget
        widget = widget_cls(module, self.controller)

        self.confWidget.setUpWidget(widget)
        for signal_name, slot in (("doneConfigure", self.configureDone),
                                  ("stateChanged", self.stateChanged)):
            self.connect(widget, QtCore.SIGNAL(signal_name), slot)

    self.confWidget.setUpdatesEnabled(True)
    self.confWidget.setVisible(True)
    self.hasChanges = False
    # we need to reset the title in case there were changes
    self.setWindowTitle("Module Configuration")
def inspect_spreadsheet_cells(self, pipeline):
    """ inspect_spreadsheet_cells(pipeline: Pipeline) -> None
    Collect the id path of every spreadsheet cell module in the pipeline.

    ``self.spreadsheet_cells`` is reset, then receives one list of module
    ids per module whose class derives from the spreadsheet package's
    ``SpreadsheetCell`` class.  Cells found in subworkflows are prefixed
    with the ids of the enclosing subworkflow modules.

    """
    registry = get_module_registry()
    self.spreadsheet_cells = []
    if not pipeline:
        return

    def collect_cells(current, id_path=None):
        prefix = [] if id_path is None else id_path
        # Sometimes we run without the spreadsheet!
        pkg_id = '%s.spreadsheet' % get_vistrails_default_pkg_prefix()
        if registry.has_module(pkg_id, 'SpreadsheetCell'):
            cell_descriptor = registry.get_descriptor_by_name(
                pkg_id, 'SpreadsheetCell')
            cell_cls = cell_descriptor.module
            for module_id, module in current.modules.iteritems():
                descriptor = registry.get_descriptor_by_name(
                    module.package, module.name, module.namespace)
                if issubclass(descriptor.module, cell_cls):
                    self.spreadsheet_cells.append(prefix + [module_id])

        # Descend into subworkflows that carry a pipeline of their own.
        for sub_id in self.find_subworkflows(current):
            subworkflow = current.modules[sub_id]
            if subworkflow.pipeline is not None:
                collect_cells(subworkflow.pipeline, prefix + [sub_id])

    collect_cells(pipeline)
def can_convert(cls, sub_descs, super_descs):
    """Tell whether this converter can turn ``sub_descs`` into ``super_descs``.

    ``sub_descs`` must line up with the descriptors of the ``in_value``
    input port, and the descriptors of the ``out_value`` output port must
    line up with ``super_descs``: same length, and each pair either
    involves the basic ``Variant`` descriptor or is in a subclass
    relationship according to the registry.
    """
    from vistrails.core.modules.module_registry import get_module_registry
    from vistrails.core.system import get_vistrails_basic_pkg_id

    reg = get_module_registry()
    variant_desc = reg.get_descriptor_by_name(get_vistrails_basic_pkg_id(),
                                              'Variant')
    desc = reg.get_descriptor(cls)

    def compatible(subs, supers):
        # Variant on either side is compatible with anything.
        return all(sub == variant_desc or
                   sup == variant_desc or
                   reg.is_descriptor_subclass(sub, sup)
                   for (sub, sup) in izip(subs, supers))

    in_port = reg.get_port_spec_from_descriptor(desc, 'in_value', 'input')
    if len(sub_descs) != len(in_port.descriptors()):
        return False
    if not compatible(sub_descs, in_port.descriptors()):
        return False

    out_port = reg.get_port_spec_from_descriptor(desc, 'out_value',
                                                 'output')
    if len(out_port.descriptors()) != len(super_descs):
        return False
    return compatible(out_port.descriptors(), super_descs)
def register_self(cls, **kwargs):
    """Register ``cls`` and its declared ports with the module registry.

    ``kwargs`` are forwarded to ``registry.add_module``.  The optional
    class attributes ``input_ports`` and ``output_ports`` are iterables of
    ``(port_name, types)`` pairs; each type is either a class or a tuple
    of arguments for ``registry.get_descriptor_by_name``.
    """
    registry = get_module_registry()

    def resolve_type(t):
        if isinstance(t, tuple):
            return registry.get_descriptor_by_name(*t).module
        if isinstance(t, type):
            return t
        assert False, ("Unknown type " + str(type(t)))

    registry.add_module(cls, **kwargs)

    # A class without the attribute simply declares no ports of that kind.
    for (port_name, types) in getattr(cls, 'input_ports', ()):
        registry.add_input_port(cls, port_name,
                                [resolve_type(t) for t in types])
    for (port_name, types) in getattr(cls, 'output_ports', ()):
        registry.add_output_port(cls, port_name,
                                 [resolve_type(t) for t in types])
def find_descriptor(controller, pipeline, module_id, desired_version=''):
    """Find the registry descriptor for the pipeline module ``module_id``.

    If an abstraction upgrade is registered for the module's descriptor
    info, that upgrade is returned.  Otherwise the descriptor is looked up
    by package/name/namespace; when it is missing and the module's package
    can handle missing modules, the package gets one chance to provide it
    before the lookup is retried.

    Returns the descriptor, or ``None`` when the module is still missing
    or the package reported that it did not handle it.  (Previously the
    resolved descriptor was computed but never returned.)
    """
    reg = get_module_registry()
    get_descriptor = reg.get_descriptor_by_name
    pm = get_package_manager()
    invalid_module = pipeline.modules[module_id]
    mpkg = invalid_module.package
    mname = invalid_module.name
    mnamespace = invalid_module.namespace
    pkg = pm.get_package(mpkg)
    # NOTE(review): the caller-supplied desired_version has always been
    # overridden here; kept as-is so existing callers see no change.
    desired_version = ''

    # don't check for abstraction/subworkflow since the old module
    # could be a subworkflow
    if reg.has_abs_upgrade(*invalid_module.descriptor_info):
        return reg.get_abs_upgrade(*invalid_module.descriptor_info)

    try:
        try:
            d = get_descriptor(mpkg, mname, mnamespace, '', desired_version)
        except MissingModule:
            handled = None
            if pkg.can_handle_missing_modules():
                handled = pkg.handle_missing_module(controller, module_id,
                                                    pipeline)
                d = get_descriptor(mpkg, mname, mnamespace, '',
                                   desired_version)
            if not handled:
                # Bare raise keeps the original traceback.
                raise
    except MissingModule:
        return None
    return d
def check_port_spec(module, port_name, port_type, descriptor=None,
                    sigstring=None):
    """Check that ``module`` has the given port, with a matching signature.

    When ``descriptor`` is given, the port spec is fetched from the
    registry and ``sigstring`` is normalised (old package identifiers are
    replaced by the current ones) before being compared with the spec's
    signature.  A missing registry port is tolerated if the module itself
    has a port spec named ``(port_name, port_type)``.

    Raises ``UpgradeWorkflowError`` when the signatures differ or the port
    cannot be found at all.
    """
    basic_pkg = get_vistrails_basic_pkg_id()
    reg = get_module_registry()
    found = False
    try:
        if descriptor is not None:
            spec = reg.get_port_spec_from_descriptor(descriptor, port_name,
                                                     port_type)
            found = True

            spec_tuples = parse_port_spec_string(sigstring, basic_pkg)
            for pos, spec_tuple in enumerate(spec_tuples):
                old_identifier = spec_tuple[0]
                port_pkg = reg.get_package_by_name(old_identifier)
                if port_pkg.identifier != old_identifier:
                    # we have an old identifier
                    spec_tuples[pos] = \
                        (port_pkg.identifier,) + spec_tuple[1:]
            sigstring = create_port_spec_string(spec_tuples)
            # sigstring = expand_port_spec_string(sigstring, basic_pkg)
            if spec.sigstring != sigstring:
                msg = ('%s port "%s" of module "%s" exists, but '
                       'signatures differ "%s" != "%s"') % \
                    (port_type.capitalize(), port_name, module.name,
                     spec.sigstring, sigstring)
                raise UpgradeWorkflowError(msg, module, port_name,
                                           port_type)
    except MissingPort:
        pass

    if found:
        return
    if not module.has_portSpec_with_name((port_name, port_type)):
        msg = '%s port "%s" of module "%s" does not exist.' % \
            (port_type.capitalize(), port_name, module.name)
        raise UpgradeWorkflowError(msg, module, port_name, port_type)
def loadWidget(self, pipeline):
    """Build a widget holding one labelled editor per visible alias.

    Aliases reported by ``self.plot.computeHiddenAliases()`` are skipped.
    Each remaining alias gets a row with its name and a constant widget,
    which is also stored in ``self.alias_widgets`` under the alias name.
    Returns the assembled ``QWidget``.
    """
    from PyQt4 import QtGui

    container = QtGui.QWidget()
    rows = QtGui.QVBoxLayout()
    hidden = self.plot.computeHiddenAliases()
    for alias, alias_info in pipeline.aliases.iteritems():
        (obj_type, obj_id, parentType, parentId, mId) = alias_info
        if alias in hidden:
            continue
        param = pipeline.db_get_object(obj_type, obj_id)
        # An empty identifier falls back to this basic package identifier.
        if param.identifier == '':
            pkg_identifier = 'edu.utah.sci.vistrails.basic'
        else:
            pkg_identifier = param.identifier
        registry = get_module_registry()
        param_module = registry.get_module_by_name(pkg_identifier,
                                                   param.type,
                                                   param.namespace)
        if param_module is None:
            editor_cls = StandardConstantWidget
        else:
            editor_cls = get_widget_class(param_module)
        editor = editor_cls(param, None)

        row = QtGui.QHBoxLayout()
        row.addWidget(QtGui.QLabel(alias))
        row.addWidget(editor)
        rows.addLayout(row)
        self.alias_widgets[alias] = editor

    container.setLayout(rows)
    return container
def compute(self):
    """Rescale a copy of the input transfer function and publish it.

    The scalar range comes from the ``Input`` port's producer output if
    connected, else from the ``Dataset`` port, else from the ``Range``
    port.  The adjusted copy is set on ``TransferFunction``; its opacity
    and colour VTK functions are wrapped in freshly created modules and
    set on the ``vtkPicewiseFunction`` and ``vtkColorTransferFunction``
    outputs.
    """
    reg = get_module_registry()
    new_tf = copy.copy(self.getInputFromPort('TransferFunction'))

    if self.hasInputFromPort('Input'):
        port = self.getInputFromPort('Input')
        producer = port.vtkInstance.GetProducer()
        data = producer.GetOutput(port.vtkInstance.GetIndex())
        value_range = data.GetScalarRange()
    elif self.hasInputFromPort('Dataset'):
        data = self.getInputFromPort('Dataset').vtkInstance
        value_range = data.GetScalarRange()
    else:
        value_range = self.getInputFromPort('Range')
    (new_tf._min_range, new_tf._max_range) = value_range

    self.setResult('TransferFunction', new_tf)
    (of, cf) = new_tf.get_vtk_transfer_functions()

    opacity_cls = reg.get_descriptor_by_name(vtk_pkg_identifier,
                                             'vtkPiecewiseFunction').module
    of_module = opacity_cls()
    of_module.vtkInstance = of

    color_cls = reg.get_descriptor_by_name(
        vtk_pkg_identifier, 'vtkColorTransferFunction').module
    cf_module = color_cls()
    cf_module.vtkInstance = cf

    self.setResult('vtkPicewiseFunction', of_module)
    self.setResult('vtkColorTransferFunction', cf_module)
def updateController(self, controller):
    """ updateController(controller: VistrailController) -> None
    Construct input forms for a controller's variables

    Nothing happens when ``controller`` is already the current one.  The
    new controller is always stored; while updates are locked the widget
    is left untouched.  Variables whose module descriptor cannot be found
    are reported through ``debug.critical`` and skipped.

    """
    # we shouldn't do this whenever the controller changes...
    if not (self.controller != controller):
        return
    self.controller = controller
    if self.updateLocked:
        return

    self.vWidget.clear()
    if not controller:
        self.vWidget.showPrompt(False)
        return

    reg = module_registry.get_module_registry()
    # Iterate over a snapshot of the variables.
    for var in list(controller.vistrail.vistrail_vars):
        try:
            descriptor = reg.get_descriptor_by_name(var.package,
                                                    var.module,
                                                    var.namespace)
        except module_registry.ModuleRegistryException:
            debug.critical("Missing Module Descriptor for vistrail"
                           " variable %s\nPackage: %s\nType: %s"
                           "\nNamespace: %s" %
                           (var.name, var.package, var.module,
                            var.namespace))
            continue
        self.vWidget.addVariable(var.uuid, var.name, descriptor,
                                 var.value)
    self.vWidget.showPromptByChildren()
def compute(self):
    """Run this module as a job tracked by the ``JobMonitor``.

    A cached result for the job signature is reused directly.  Otherwise
    the parameters of an already registered job are picked up, or a new
    job is started and registered under the module's ``__desc__``
    annotation (or its descriptor name when no annotation exists).  The
    job is then checked, finished, published and cached.
    """
    params = self.readInputs()
    signature = self.getId(params)
    jm = JobMonitor.getInstance()

    # use cached job if it exists
    cached = jm.getCache(signature)
    if cached:
        self.setResults(cached.parameters)
        return

    # check if job is running
    running = jm.getJob(signature)
    if running:
        params = running.parameters
    else:
        # start job
        params = self.startJob(params)
        # set visible name: a custom description wins over the type name
        pipeline_module = \
            self.interpreter._persistent_pipeline.modules[self.id]
        if '__desc__' in pipeline_module.db_annotations_key_index:
            annotation = pipeline_module.get_annotation_by_key('__desc__')
            name = annotation.value.strip()
        else:
            descriptor = get_module_registry().get_descriptor(
                self.__class__)
            name = descriptor.name
        jm.addJob(signature, params, name)

    # call method to check job
    jm.checkJob(self, signature, self.getMonitor(params))
    # job is finished, set outputs
    params = self.finishJob(params)
    self.setResults(params)
    jm.setCache(signature, params)
def handle_module_upgrade_request(controller, module_id, pipeline):
    """Return the upgrade actions for the module ``module_id``.

    A ``read|JSONFile`` module whose version is not "0.1.5" is replaced by
    a ``JSONObject`` module that additionally gets a ``key_name`` function
    set to ``"_key"``.  Every other module goes through the remap table
    handed to ``UpgradeWorkflowHandler.remap_module``.
    """
    old_module = pipeline.modules[module_id]

    is_old_json_file = (old_module.name == "JSONFile" and
                        old_module.version != "0.1.5" and
                        old_module.namespace == "read")
    if is_old_json_file:
        from vistrails.core.db.action import create_action
        from vistrails.core.modules.module_registry import \
            get_module_registry
        from .read.read_json import JSONObject

        new_desc = get_module_registry().get_descriptor(JSONObject)
        location = old_module.location
        new_module = controller.create_module_from_descriptor(
            new_desc, location.x, location.y)
        actions = UpgradeWorkflowHandler.replace_generic(
            controller, pipeline, old_module, new_module)

        new_function = controller.create_function(new_module, "key_name",
                                                  ["_key"])
        add_function = ("add", new_function, "module", new_module.id)
        actions.append(create_action([add_function]))
        return actions

    def self_to_value():
        # A fresh dict per entry, so no two remaps share state.
        return {"src_port_remap": {"self": "value"}}

    module_remap = {
        "read|csv|CSVFile": [
            (None, "0.1.1", "read|CSVFile", self_to_value()),
        ],
        "read|numpy|NumPyArray": [
            (None, "0.1.1", "read|NumPyArray", self_to_value()),
        ],
        "read|CSVFile": [
            ("0.1.1", "0.1.2", None, self_to_value()),
            ("0.1.3", "0.1.5", None, {}),
        ],
        "read|NumPyArray": [
            ("0.1.1", "0.1.2", None, self_to_value()),
        ],
        "read|ExcelSpreadsheet": [
            ("0.1.1", "0.1.2", None, self_to_value()),
            ("0.1.3", "0.1.4", None, {}),
        ],
        "read|JSONFile": [
            (None, "0.1.5", "read|JSONObject"),
        ],
    }

    return UpgradeWorkflowHandler.remap_module(controller, module_id,
                                               pipeline, module_remap)
def __subclasscheck__(self, other):
    """Implement ``issubclass(other, self)`` through the module registry.

    ``other`` is the candidate subclass.  Classes that are not
    ``ModuleClass`` instances are never subclasses; for the others the
    answer comes from ``is_descriptor_subclass``, which takes the
    candidate subclass first and the base second.

    Raises ``TypeError`` when ``other`` is not a class, like the builtin.
    """
    import inspect

    # The old test ``issubclass(other, type)`` only holds for metaclasses,
    # so every ordinary class ended up raising TypeError.
    if not inspect.isclass(other):
        raise TypeError("issubclass() arg 1 must be a class")
    if not isinstance(other, ModuleClass):
        return False
    reg = get_module_registry()
    # (candidate subclass, base): the arguments used to be swapped.
    return reg.is_descriptor_subclass(other.descriptor, self.descriptor)
def _get_base_descriptor(self):
    """Return the base descriptor, resolving it from the registry once.

    The lookup by ``self.base_descriptor_id`` happens only while nothing
    is cached and the id is non-negative; otherwise the cached value
    (possibly ``None``) is returned as is.
    """
    unresolved = self._base_descriptor is None
    if unresolved and self.base_descriptor_id >= 0:
        from vistrails.core.modules.module_registry import \
            get_module_registry
        by_id = get_module_registry().descriptors_by_id
        self._base_descriptor = by_id[self.base_descriptor_id]
    return self._base_descriptor
def collectParameterActions(self):
    """ collectParameterActions() -> list
    Return a list of action lists corresponding to each dimension

    Returns None when there is no pipeline or when an interpolator yields
    no values.  For every parameter assigned to one of the four
    dimensions, one 'change' action is created per interpolated value;
    the per-dimension lists are transposed with ``zip`` before returning.

    """
    if not self.pipeline:
        return None

    reg = get_module_registry()
    parameterValues = [[], [], [], []]
    counts = self.label.getCounts()
    for row in xrange(self.layout().count()):
        pEditor = self.layout().itemAt(row).widget()
        if not (pEditor and isinstance(pEditor, QParameterSetEditor)):
            continue
        for paramWidget in pEditor.paramWidgets:
            interpolator = \
                paramWidget.editor.stackedEditors.currentWidget()
            paramInfo = paramWidget.param
            dim = paramWidget.getDimension()
            if dim not in [0, 1, 2, 3]:
                continue

            values = interpolator.get_values(counts[dim])
            if not values:
                return None

            parentType = paramInfo.parent_dbtype
            function = self.pipeline.db_get_object(parentType,
                                                   paramInfo.parent_id)
            old_param = self.pipeline.db_get_object(paramInfo.dbtype,
                                                    paramInfo.id)
            pName = old_param.name
            pAlias = old_param.alias
            pIdentifier = old_param.identifier

            actions = []
            tmp_id = long(-1)
            for value in values:
                desc = reg.get_descriptor_by_name(paramInfo.identifier,
                                                  paramInfo.type,
                                                  paramInfo.namespace)
                # Plain strings are used as is; anything else is
                # serialised by the parameter's module class.
                if isinstance(value, str):
                    str_value = value
                else:
                    str_value = desc.module.translate_to_string(value)
                new_param = ModuleParam(id=tmp_id,
                                        pos=old_param.pos,
                                        name=pName,
                                        alias=pAlias,
                                        val=str_value,
                                        type=paramInfo.type,
                                        identifier=pIdentifier)
                action_spec = ('change', old_param, new_param,
                               parentType, function.real_id)
                actions.append(
                    vistrails.core.db.action.create_action([action_spec]))
            parameterValues[dim].append(actions)
            # tmp_id is reset for the next parameter, so this has no
            # lasting effect.
            tmp_id -= 1
    return [zip(*p) for p in parameterValues]
def get_module_registry():
    """Return the module registry, importing and caching it on first use.

    The import is done inside the function and the result is kept in the
    module-level ``_registry``.
    """
    global _registry
    if _registry is not None:
        return _registry
    from vistrails.core.modules.module_registry import \
        get_module_registry as _load_registry
    _registry = _load_registry()
    return _registry
def init(self):
    """Initial setup of the Manager.

    Creates the DAT notifications on the application, subscribes to the
    registry's package notifications and feeds every package currently in
    the registry to ``self.new_package``.
    """
    app = get_vistrails_application()

    notifications = (
        'dat_new_plot',           # dat_new_plot(plot: Plot)
        'dat_removed_plot',       # dat_removed_plot(plot: Plot)
        'dat_new_loader',         # dat_new_loader(loader: BaseVariableLoader)
        'dat_removed_loader',     # dat_removed_loader(loader: BaseVariableLoader)
        'dat_new_operation',      # dat_new_operation(loader: VariableOperation)
        'dat_removed_operation',  # dat_removed_operation(loader: VariableOperation)
    )
    for notification in notifications:
        app.create_notification(notification)

    app.register_notification("reg_new_package", self.new_package)
    app.register_notification("reg_deleted_package", self.deleted_package)

    # Load the Plots and VariableLoaders from the packages
    for package in get_module_registry().package_list:
        self.new_package(package.identifier)
def createEditor(self, parent, option, index):
    """Create the editor widget for the cell at ``index``.

    Column 2 (depth) gets a spin box starting at 0.  Column 1 (port type)
    gets a completing combo box listing every descriptor of every
    package, grouped under a bold, non-selectable header per package,
    with the basic ``Variant`` type preselected.  Other columns use the
    default ``QItemDelegate`` editor.
    """
    registry = get_module_registry()
    column = index.column()
    if column == 2:  # Depth type
        spinbox = QtGui.QSpinBox(parent)
        spinbox.setValue(0)
        return spinbox
    if column == 1:  # Port type
        combo = CompletingComboBox(parent)
        # FIXME just use descriptors here!!
        variant_desc = registry.get_descriptor_by_name(
            get_vistrails_basic_pkg_id(), 'Variant')
        # Stays None if Variant is never met in the loops below; it used
        # to be left unbound in that case (UnboundLocalError).
        variant_index = None
        for _, pkg in sorted(registry.packages.iteritems()):
            pkg_item = QtGui.QStandardItem("----- %s -----" % pkg.name)
            pkg_item.setData('', QtCore.Qt.UserRole)
            # Header rows can be neither selected nor activated.
            pkg_item.setFlags(pkg_item.flags() & ~(
                QtCore.Qt.ItemIsEnabled |
                QtCore.Qt.ItemIsSelectable))
            font = pkg_item.font()
            font.setBold(True)
            pkg_item.setFont(font)
            combo.model().appendRow(pkg_item)
            for _, descriptor in sorted(pkg.descriptors.iteritems()):
                if descriptor is variant_desc:
                    variant_index = combo.count()
                combo.addItem("%s (%s)" % (descriptor.name,
                                           descriptor.identifier),
                              descriptor.sigstring)

        if variant_index is not None:
            combo.select_default_item(variant_index)
        return combo
    return QtGui.QItemDelegate.createEditor(self, parent, option, index)
def inspect_input_output_ports(self, pipeline):
    """ inspect_input_output_ports(pipeline: Pipeline) -> None
    Inspect the pipeline input/output ports, useful for submodule

    Fills ``input_ports`` / ``output_ports`` (port module id ->
    (name, first element of the port spec)) and the reverse lookups
    ``input_port_by_name`` / ``output_port_by_name`` (name -> module id).
    A port module without a name takes the name of the port it is
    connected to.

    """
    registry = get_module_registry()
    self.input_ports = {}
    self.input_port_by_name = {}
    self.output_ports = {}
    self.output_port_by_name = {}
    if not pipeline:
        return

    for _, conn in pipeline.connections.iteritems():
        source = pipeline.modules[conn.source.moduleId]
        target = pipeline.modules[conn.destination.moduleId]

        if source.name == 'InputPort':
            spec = registry.getInputPortSpec(target, conn.destination.name)
            port_name = self.get_port_name(source)
            if port_name == '':
                port_name = conn.destination.name
            self.input_ports[source.id] = (port_name, spec[0])
            self.input_port_by_name[port_name] = source.id

        if target.name == 'OutputPort':
            spec = registry.getOutputPortSpec(source, conn.source.name)
            port_name = self.get_port_name(target)
            if port_name == '':
                port_name = conn.source.name
            self.output_ports[target.id] = (port_name, spec[0])
            self.output_port_by_name[port_name] = target.id
def __init__(self, param, size, parent=None):
    """ QParameterWidget(param: ParameterInfo, size: int, parent: QWidget)
                         -> QParameterWidget

    Lay out, left to right: a fixed-width label showing the parameter's
    module, the parameter editor and the dimension selector.  The
    selector's radio button at index 4 is wired to
    ``self.disableParameter``.

    """
    QtGui.QWidget.__init__(self, parent)
    self.param = param
    self.prevWidget = 0

    row = QtGui.QHBoxLayout(self)
    row.setMargin(0)
    row.setSpacing(0)
    self.setLayout(row)

    row.addSpacing(5+16+5)

    self.label = QtGui.QLabel(param.spec.module)
    self.label.setFixedWidth(50)
    row.addWidget(self.label)

    registry = get_module_registry()
    # Only constant-typed parameters are supported.
    param_module = param.spec.descriptor.module
    assert issubclass(param_module, Constant)

    self.editor = QParameterEditor(param, size)
    row.addWidget(self.editor)

    self.selector = QDimensionSelector()
    skip_button = self.selector.radioButtons[4]
    self.connect(skip_button, QtCore.SIGNAL('toggled(bool)'),
                 self.disableParameter)
    row.addWidget(self.selector)
def update_output_modules(self, *args, **kwargs):
    """Synchronise the mode widgets with the registered output modes.

    Output modes are gathered from the ``_output_modes_dict`` of every
    descriptor that subclasses the basic ``OutputModule``.  A
    configuration widget is added for each mode type that has none yet,
    and widgets whose mode type is no longer found are taken out of the
    layout and forgotten.
    """
    # need to find all currently loaded output modes (need to
    # check after modules are loaded and spin through registry)
    # and display them here
    reg = get_module_registry()
    output_d = reg.get_descriptor_by_name(get_vistrails_basic_pkg_id(),
                                          "OutputModule")
    modes = {}
    for d in reg.get_descriptor_subclasses(output_d):
        if not hasattr(d.module, '_output_modes_dict'):
            continue
        for mode_type, (mode, _) in d.module._output_modes_dict.iteritems():
            modes[mode_type] = mode

    found_modes = set()
    for mode_type, mode in modes.iteritems():
        found_modes.add(mode_type)
        if mode_type in self.mode_widgets:
            continue
        output_settings = self.persistent_config.outputDefaultSettings
        if output_settings.has(mode_type):
            mode_config = getattr(output_settings, mode_type)
        else:
            mode_config = None
        widget = OutputModeConfigurationWidget(mode, mode_config)
        widget.fieldChanged.connect(self.field_was_changed)
        self.inner_layout.addWidget(widget)
        self.mode_widgets[mode_type] = widget

    # Collect the stale entries first so the dict is not changed while
    # it is being walked.
    stale = [mode_type for mode_type in self.mode_widgets
             if mode_type not in found_modes]
    for mode_type in stale:
        self.inner_layout.removeWidget(self.mode_widgets[mode_type])
        del self.mode_widgets[mode_type]
def theFunction(src, dst):
    """Connect ``src``'s output port to ``dst``'s input port for ``conn``.

    ``conn`` comes from the enclosing scope.  The ``Variant`` and
    ``InputPort`` descriptors are looked up once and kept in module
    globals.  Each source descriptor gets a type-check flag taken from
    the configuration: Variant descriptors follow
    ``errorOnVariantTypeerror`` (or the general flag if that is set), the
    others follow ``errorOnConnectionTypeerror``.  Sources that are
    ``InputPort`` instances are never type-checked.
    """
    global Variant_desc, InputPort_desc
    if Variant_desc is None:
        reg = get_module_registry()
        Variant_desc = reg.get_descriptor_by_name(
            'org.vistrails.vistrails.basic', 'Variant')
        InputPort_desc = reg.get_descriptor_by_name(
            'org.vistrails.vistrails.basic', 'InputPort')

    iport = conn.destination.name
    oport = conn.source.name
    src.enableOutputPort(oport)

    conf = get_vistrails_configuration()
    error_on_others = getattr(conf, 'errorOnConnectionTypeerror')
    error_on_variant = (error_on_others or
                        getattr(conf, 'errorOnVariantTypeerror'))

    if isinstance(src, InputPort_desc.module):
        typecheck = [False]
    else:
        typecheck = []
        for desc in conn.source.spec.descriptors():
            if desc is Variant_desc:
                typecheck.append(error_on_variant)
            else:
                typecheck.append(error_on_others)

    connector = ModuleConnector(src, oport, conn.destination.spec,
                                typecheck)
    dst.set_input_port(iport, connector)
def begin_compute(self, obj):
    """Mark ``obj`` as computing in the view and log its execution start.

    The module id is remapped before use; the logged name is the
    registry descriptor name of ``obj``'s class.
    """
    module_id = self.remap_id(obj.id)
    self.view.set_module_computing(module_id)

    descriptor = get_module_registry().get_descriptor(obj.__class__)
    self.log.start_execution(obj, module_id, descriptor.name)
def __getattr__(self, attr_name):
    """Treat ``attr_name`` as a module name and add that module via the API.

    The descriptor is looked up in this package (identifier and version
    taken from ``self._package``, empty namespace) and handed to the
    API's ``add_module_from_descriptor``; its result is returned.
    """
    package = self._package
    registry = get_module_registry()
    descriptor = registry.get_descriptor_by_name(package.identifier,
                                                 attr_name,
                                                 '',
                                                 package.version)
    return get_api().add_module_from_descriptor(descriptor)
def registerControl(module, **kwargs):
    """This function is used to register the control modules. In this way,
    all of them will have the same style and shape.

    ``kwargs`` are passed on to the registry's ``add_module`` together
    with the fixed left/right fringe shapes.
    """
    right_fringe = [(0.0, 0.0), (0.25, 0.5), (0.0, 1.0)]
    left_fringe = [(0.0, 0.0), (0.0, 1.0)]
    registry = get_module_registry()
    registry.add_module(module,
                        moduleRightFringe=right_fringe,
                        moduleLeftFringe=left_fringe,
                        **kwargs)
def summon(self):
    """Instantiate the descriptor's module class for this module.

    The new instance receives this module's attributes through
    ``transfer_attrs`` and is given the module registry before being
    returned.
    """
    instance = self.module_descriptor.module()
    instance.transfer_attrs(self)

    # FIXME this may not be quite right because we don't have self.registry
    # anymore.  That said, I'm not sure how self.registry would have
    # worked for hybrids...
    instance.registry = get_module_registry()
    return instance
def register_self():
    """Register ``VTKRenderOffscreen`` and its ports with the registry.

    Inputs: ``renderer`` (the VTK package's ``vtkRenderer`` module class),
    ``width`` and ``height`` (``Integer``).  Output: ``image`` (``File``).
    """
    registry = get_module_registry()
    renderer_cls = registry.get_descriptor_by_name(vtk_pkg_identifier,
                                                   'vtkRenderer').module
    registry.add_module(VTKRenderOffscreen)

    input_ports = (('renderer', renderer_cls),
                   ('width', Integer),
                   ('height', Integer))
    for port_name, port_type in input_ports:
        registry.add_input_port(VTKRenderOffscreen, port_name, port_type)
    registry.add_output_port(VTKRenderOffscreen, 'image', File)
def initialize(*args, **keywords):
    """Register the ``Map`` module with its input and output ports.

    The positional and keyword arguments are accepted but not used.
    """
    reg = get_module_registry()
    reg.add_module(Map)

    input_ports = [('FunctionPort', (Module, '')),
                   ('InputList', (List, '')),
                   ('InputPort', (List, '')),
                   ('OutputPort', (String, ''))]
    for port_name, signature in input_ports:
        reg.add_input_port(Map, port_name, signature)
    reg.add_output_port(Map, 'Result', (List, ''))
def matchQueryParam(self, template, target):
    """ matchQueryParam(template: Param, target: Param) -> bool
    Check to see if target can match with a query template

    Parameters differing in type, identifier or namespace never match;
    otherwise the decision is delegated to the ``query_compute`` of the
    template's module class.

    """
    for attr in ('type', 'identifier', 'namespace'):
        if getattr(template, attr) != getattr(target, attr):
            return False

    registry = get_module_registry()
    descriptor = registry.get_descriptor_by_name(template.identifier,
                                                 template.type,
                                                 template.namespace)
    return descriptor.module.query_compute(target.strValue,
                                           template.strValue,
                                           template.queryMethod)
def testSummonModule(self):
    """Check that summon creates a correct module"""
    from vistrails.core.modules.basic_modules import \
        identifier as basic_pkg

    module = Module()
    module.name = "String"
    module.package = basic_pkg
    try:
        registry = get_module_registry()
        summoned = module.summon()
        string_cls = registry.get_descriptor_by_name(basic_pkg,
                                                     'String').module
        assert isinstance(summoned, string_cls)
    except NoSummon:
        self.fail("Expected to get a String object, "
                  "got a NoSummon exception")
def setup(self, cell, plot):
    """Fill the tab widget with one configuration tab per plot module.

    A module with a registered configuration widget gets that widget,
    wired to ``configureDone`` / ``stateChanged``.  Any other module gets
    a ``DATPortsList``, but only when the list ends up with at least one
    port item; modules without either get no tab.
    """
    self.cell = cell
    self.plot = plot

    # Get pipeline of the cell
    mngr = VistrailManager(cell._controller)
    pipelineInfo = mngr.get_pipeline(cell.cellInfo)

    # Clear old tabs
    self.tabWidget.clear()

    # Get all of the plot modules in the pipeline
    plot_modules = get_plot_modules(pipelineInfo,
                                    cell._controller.current_pipeline)

    lookup = get_module_registry().get_configuration_widget
    for module in plot_modules:
        # Try to get custom config widget for the module
        try:
            widget_cls = lookup(module.package, module.name,
                                module.namespace)
        except ModuleRegistryException:
            widget_cls = None

        if widget_cls:
            # Use custom widget
            widget = widget_cls(module, cell._controller)
            self.connect(widget, QtCore.SIGNAL("doneConfigure"),
                         self.configureDone)
            self.connect(widget, QtCore.SIGNAL("stateChanged"),
                         self.stateChanged)
        else:
            # Use PortsList widget, only if module has ports
            widget = DATPortsList(self)
            widget.update_module(module)
            if not (len(widget.port_spec_items) > 0):
                continue
            widget.set_controller(cell._controller)

        # Add widget in new tab
        if widget:
            self.tabWidget.addTab(widget, module.name)
def read_type(pipeline):
    """Read the type of a Variable from its pipeline.

    The type is obtained from the 'spec' input port of the 'OutputPort'
    module.  ``None`` is returned unless the pipeline holds exactly one
    ``OutputPort`` module whose 'name' function is 'value'.
    """
    registry = get_module_registry()
    output_port_cls = registry.get_module_by_name(
        'org.vistrails.vistrails.basic', 'OutputPort')
    outputs = find_modules_by_type(pipeline, [output_port_cls])
    if len(outputs) != 1:
        return None

    output = outputs[0]
    if get_function(output, 'name') != 'value':
        return None
    return resolve_descriptor(get_function(output, 'spec'))
def newPackage(self, package_identifier, prepend=False):
    """Return the tree item for a package, creating it when needed.

    A new item is labelled with the package name from the registry,
    remembered in ``self.packages`` and inserted as first top-level item
    when ``prepend`` is true, or appended otherwise.
    """
    # prepend places at the front of the list of packages,
    # by default adds to the end of the list of packages
    # Right now the list is sorted so prepend has no effect
    if package_identifier in self.packages:
        return self.packages[package_identifier]

    package = get_module_registry().packages[package_identifier]
    item = QPackageTreeWidgetItem(None, package.name, package_identifier)
    self.packages[package_identifier] = item

    if not prepend:
        self.treeWidget.addTopLevelItem(item)
    else:
        self.treeWidget.insertTopLevelItem(0, item)
    return item
def get_widget_class(descriptor, widget_type=None, widget_use=None,
                     return_default=True):
    """Resolve the constant-configuration widget class for ``descriptor``.

    The registry is asked first.  When it has nothing and
    ``return_default`` is true, the module's own ``get_widget_class()``
    is tried, and finally a standard widget is returned directly (the
    enum variant when ``widget_type`` is "enum").  Every other result,
    including a missing class with ``return_default`` false, goes
    through ``load_cls`` with the descriptor's prefix.
    """
    reg = get_module_registry()
    cls = reg.get_constant_config_widget(descriptor, widget_type,
                                         widget_use)
    prefix = get_prefix(reg, descriptor)

    if cls is not None or not return_default:
        return load_cls(cls, prefix)

    if descriptor.module is not None and \
            hasattr(descriptor.module, 'get_widget_class'):
        cls = descriptor.module.get_widget_class()
    if cls is not None:
        return load_cls(cls, prefix)

    if widget_type == "enum":
        return StandardConstantEnumWidget
    return StandardConstantWidget
def get_values(self):
    """Push the dialog state to the module as a list of functions.

    Builds the 'value', 'localPath', 'readLocal', 'writeLocal' and 'ref'
    entries from the widgets, fills a new ``PersistentRef`` from either
    the "managed new" or the "managed existing" controls, and passes the
    list to ``self.controller.update_functions``.
    """
    from vistrails.core.modules.module_registry import get_module_registry

    reg = get_module_registry()
    PersistentRef = \
        reg.get_descriptor_by_name(persistence_pkg, 'PersistentRef').module

    functions = []
    has_new_file = (self.new_file and self.new_file.get_path() and
                    self.managed_new.isChecked())
    if has_new_file:
        functions.append(('value', [self.new_file.get_path()]))
    else:
        functions.append(('value', None))

    ref = PersistentRef()
    if self.managed_new.isChecked():
        # Reuse the identity of an existing ref that is not stored yet.
        if self.existing_ref and not self.existing_ref._exists:
            ref.id = self.existing_ref.id
            ref.version = self.existing_ref.version
        else:
            ref.id = str(uuid.uuid1())
            ref.version = None
        ref.name = str(self.name_edit.text())
        ref.tags = str(self.tags_edit.text())
    elif self.managed_existing.isChecked():
        (ref.id, ref.version, ref.name, ref.tags) = \
            self.ref_widget.get_info()

    if self.keep_local.isChecked():
        local_path = [self.local_path.get_path()]
        read_local = [str(self.r_priority_local.isChecked())]
        write_local = [str(self.write_managed_checkbox.isChecked())]
        functions.append(('localPath', local_path))
        functions.append(('readLocal', read_local))
        functions.append(('writeLocal', write_local))
    else:
        ref.local_path = None
        functions.append(('readLocal', None))
        functions.append(('writeLocal', None))

    functions.append(('ref', [PersistentRef.translate_to_string(ref)]))
    self.controller.update_functions(self.module, functions)
def handle_module_upgrade_request(controller, module_id, pipeline):
    """Return the upgrade actions for the module ``module_id``.

    A ``DBConnection`` module older than 0.1.0 is replaced by a new
    ``DBConnection`` whose ``password`` port is fed by a newly created
    ``PasswordDialog`` module labelled 'Server password'.  Every other
    module goes through the automatic upgrade.
    """
    # Before 0.0.3, SQLSource's resultSet output was type ListOfElements
    # (which doesn't exist anymore)
    # In 0.0.3, SQLSource's resultSet output was type List
    # In 0.1.0, SQLSource's output was renamed to result and is now a
    # Table; this is totally incompatible and no upgrade code is possible
    # the resultSet is kept for now for compatibility

    # Up to 0.0.4, DBConnection would ask for a password if one was
    # necessary; this behavior has not been kept. There is now a password
    # input port, to which you can connect a PasswordDialog from package
    # dialogs if needed

    old_module = pipeline.modules[module_id]
    needs_password_module = (
        old_module.name == 'DBConnection' and
        versions_increasing(old_module.version, '0.1.0'))
    if not needs_password_module:
        return UpgradeWorkflowHandler.attempt_automatic_upgrade(
            controller, pipeline, module_id)

    # DBConnection module from before 0.1.0: automatically add the
    # password prompt module
    reg = get_module_registry()

    # Creates the new module
    new_module = controller.create_module_from_descriptor(
        reg.get_descriptor(DBConnection))

    # Create the password module
    password_desc = reg.get_descriptor_by_name(
        'org.vistrails.vistrails.dialogs', 'PasswordDialog')
    password_module = controller.create_module_from_descriptor(
        password_desc)

    # Adds a 'label' function to the password module
    ops = [('add', password_module)]
    ops.extend(controller.update_function_ops(password_module, 'label',
                                              ['Server password']))

    # Connects the password module to the new module
    conn = controller.create_connection(password_module, 'result',
                                        new_module, 'password')
    ops.append(('add', conn))

    # Replaces the old module with the new one
    upgrade_actions = UpgradeWorkflowHandler.replace_generic(
        controller, pipeline,
        old_module, new_module,
        src_port_remap={'self': 'connection'})

    password_fix_action = create_action(ops)
    return upgrade_actions + [password_fix_action]
def get_query_widget_class(descriptor, widget_type=None):
    """Resolve the query widget class for ``descriptor``.

    A class found by ``get_widget_class`` for the "query" use is returned
    unchanged.  Otherwise the module's own ``get_query_widget_class()``
    result is loaded through ``load_cls``; failing that, a
    ``BaseQueryWidget`` subclass is generated that wraps the descriptor's
    regular widget class and offers the "==" and "!=" operators.
    """
    cls = get_widget_class(descriptor, widget_type, "query", False)
    if cls is not None:
        return cls

    if descriptor.module is not None and \
            hasattr(descriptor.module, 'get_query_widget_class'):
        cls = descriptor.module.get_query_widget_class()

    if cls is not None:
        reg = get_module_registry()
        return load_cls(cls, get_prefix(reg, descriptor))

    class DefaultQueryWidget(BaseQueryWidget):
        def __init__(self, param, parent=None):
            BaseQueryWidget.__init__(self,
                                     get_widget_class(descriptor),
                                     ["==", "!="],
                                     param, parent)
    return DefaultQueryWidget
def contextMenuEvent(self, event):
    """Just dispatches the menu event to the widget item

    If the owning package declares a context menu, the entries it
    returns for the clicked item are shown instead of the item's own
    menu; namespace items get no menu in that case.  An exception raised
    while building the package menu is reported and the item's own
    handler is used.
    """
    item = self.itemAt(event.pos())
    if not item:
        return

    # find top level
    top = item
    while top.parent():
        top = top.parent()

    # get package identifier
    assert isinstance(top, QPackageTreeWidgetItem)
    package = get_module_registry().packages[top.identifier]
    try:
        if package.has_context_menu():
            if isinstance(item, QPackageTreeWidgetItem):
                text = None
            elif isinstance(item, QNamespaceTreeWidgetItem):
                return  # no context menu for namespaces
            elif isinstance(item, QModuleTreeWidgetItem):
                descriptor = item.descriptor
                text = descriptor.name
                if descriptor.namespace:
                    text = '%s|%s' % (descriptor.namespace, text)
            else:
                assert False, "fell through"

            menu_items = package.context_menu(text)
            if menu_items:
                menu = QtGui.QMenu(self)
                for label, callback in menu_items:
                    act = QtGui.QAction(label, self)
                    act.setStatusTip(label)
                    QtCore.QObject.connect(act,
                                           QtCore.SIGNAL("triggered()"),
                                           callback)
                    menu.addAction(act)
                menu.exec_(event.globalPos())
            return
    except Exception as e:
        debug.unexpected_exception(e)
        debug.warning("Got exception trying to display %s's "
                      "context menu in the palette: %s\n%s" % (
                          package.name,
                          debug.format_exception(e),
                          traceback.format_exc()))

    item.contextMenuEvent(event, self)
def _get_variables_root(controller=None):
    """Create or get the version tagged 'dat-vars'

    This is the base version of all DAT variables. It consists of a single
    OutputPort module with name 'value'.

    Returns ``(controller, root_version, outmod_id)``; the controller is
    left with the root version selected.
    """
    if controller is None:
        controller = get_vistrails_application().get_controller()
        assert controller is not None

    vistrail = controller.vistrail
    if not vistrail.has_tag_str('dat-vars'):
        # Create the 'dat-vars' version
        controller.change_selected_version(0)
        reg = get_module_registry()

        # Add an OutputPort module
        descriptor = reg.get_descriptor_by_name(
            'org.vistrails.vistrails.basic', 'OutputPort')
        out_mod = controller.create_module_from_descriptor(descriptor)
        operations = [('add', out_mod)]

        # Add a function to this module
        operations.extend(
            controller.update_function_ops(out_mod, 'name', ['value']))

        # Perform the operations
        action = create_action(operations)
        controller.add_new_action(action)
        root_version = controller.perform_action(action)
        # Tag as 'dat-vars'
        controller.vistrail.set_tag(root_version, 'dat-vars')
    else:
        root_version = vistrail.get_version_number('dat-vars')

    controller.change_selected_version(root_version)

    # The root pipeline must hold the output module and nothing else.
    module_ids = controller.current_pipeline.modules.keys()
    assert len(module_ids) == 1
    outmod_id = module_ids[0]

    return controller, root_version, outmod_id
def addWidget(packagePath):
    """ addWidget(packagePath: str) -> package
    Add a new widget type to the spreadsheet registry supplying a
    basic set of spreadsheet widgets

    Returns the imported widget package, or None when importing or
    registering it failed (the failure is logged and otherwise ignored).

    """
    try:
        registry = get_module_registry()
        widget = importReturnLast(packagePath)
        # Packages may advertise a display name; fall back to the path
        if hasattr(widget, 'widgetName'):
            widgetName = widget.widgetName()
        else:
            widgetName = packagePath
        widget.registerWidget(registry, basic_modules, basicWidgets)
        spreadsheetRegistry.registerPackage(widget, packagePath)
        debug.log(' ==> Successfully import <%s>' % widgetName)
    except Exception as e:
        # Best effort: a broken widget package must not stop the others
        debug.log(' ==> Ignored package <%s>' % packagePath, e)
        widget = None
    # The docstring promises the package; previously nothing was returned
    # and the `widget = None` assignment above was dead code.
    return widget
def contextMenuEvent(self, event):
    """Dispatch a context-menu event for the item under the cursor.

    The clicked item's top-level ancestor is matched (by weak reference)
    against ``self.parent().packages`` to find the package identifier.
    If that package exposes a context menu name, a single-action menu is
    shown and the event is consumed; otherwise the event is forwarded to
    the item's own ``contextMenuEvent``.
    """
    item = self.itemAt(event.pos())
    if not item:
        return

    # find top level
    p = item
    while p.parent():
        p = p.parent()

    # get package identifier
    top_ref = weakref.ref(p)
    identifiers = [i for i, j in self.parent().packages.iteritems()
                   if j == top_ref]
    if identifiers:
        identifier = identifiers[0]
        registry = get_module_registry()
        package = registry.packages[identifier]
        try:
            if package.has_contextMenuName():
                name = package.contextMenuName(str(item.text(0)))
                if name:
                    act = QtGui.QAction(name, self)
                    act.setStatusTip(name)

                    def callMenu():
                        if package.has_callContextMenu():
                            package.callContextMenu(str(item.text(0)))

                    QtCore.QObject.connect(
                        act, QtCore.SIGNAL("triggered()"), callMenu)
                    menu = QtGui.QMenu(self)
                    menu.addAction(act)
                    menu.exec_(event.globalPos())
                return
        except Exception as e:
            # e.args may hold non-string values (e.g. KeyError(1) or an
            # errno); joining them raw would raise TypeError from inside
            # this handler, so render those with repr() first.
            details = ', '.join(
                a if isinstance(a, basestring) else repr(a)
                for a in e.args)
            debug.warning(
                "Got exception trying to display %s's "
                "context menu in the palette: %s: %s" % (
                    package.name, type(e).__name__, details))

    item.contextMenuEvent(event, self)
def inspect_spreadsheet_cells(self, pipeline):
    """ inspect_spreadsheet_cells(pipeline: Pipeline) -> None
    Collect into self.spreadsheet_cells the id path of every module that
    produces a spreadsheet cell, descending into subworkflows

    """
    self.spreadsheet_cells = []
    if not pipeline:
        return

    registry = get_module_registry()
    # Sometimes we run without the spreadsheet!
    if not registry.has_module('org.vistrails.vistrails.spreadsheet',
                               'SpreadsheetCell'):
        return

    cell_desc = registry.get_descriptor_by_name(
        'org.vistrails.vistrails.spreadsheet', 'SpreadsheetCell')
    output_desc = registry.get_descriptor_by_name(
        'org.vistrails.vistrails.basic', 'OutputModule')
    is_subclass = registry.is_descriptor_subclass

    def find_spreadsheet_cells(pipeline, root_id=None):
        prefix = [] if root_id is None else root_id
        for module_id, module in pipeline.modules.iteritems():
            desc = registry.get_descriptor_by_name(
                module.package, module.name, module.namespace)
            # Either a SpreadsheetCell subclass, or an output module
            # that has a 'spreadsheet' mode
            produces_cell = is_subclass(desc, cell_desc) or (
                is_subclass(desc, output_desc) and
                desc.module.get_mode_class('spreadsheet') is not None)
            if produces_cell:
                self.spreadsheet_cells.append(prefix + [module_id])

        for subworkflow_id in self.find_subworkflows(pipeline):
            inner = pipeline.modules[subworkflow_id].pipeline
            if inner is not None:
                find_spreadsheet_cells(inner, prefix + [subworkflow_id])

    find_spreadsheet_cells(pipeline)
def test_operation_resolution(self):
    """Check that find_operation() picks the right overload or fails.
    """
    import dat.tests.pkg_test_operations.init as pkg
    from vistrails.core.modules.basic_modules import Integer, String
    from vistrails.packages.URL.init import DownloadFile

    describe = get_module_registry().get_descriptor

    def resolve(name, *module_classes):
        # Look up an operation from module classes rather than descriptors
        return find_operation(name, [describe(c) for c in module_classes])

    # Standard modules
    self.assertIs(resolve('overload_std', String, Integer),
                  pkg.overload_std_3)
    with self.assertRaises(InvalidOperation) as cm:
        resolve('overload_std', String, DownloadFile)
    self.assertIn("Found no match", cm.exception.message)

    # Unknown operation name
    with self.assertRaises(InvalidOperation) as cm:
        resolve('nonexistent')
    self.assertIn("There is no ", cm.exception.message)

    # Modules from the test package
    self.assertIs(resolve('overload_custom', pkg.ModC, pkg.ModD),
                  pkg.overload_custom_2)
    with self.assertRaises(InvalidOperation) as cm:
        resolve('overload_custom', pkg.ModE, pkg.ModE)
    self.assertIs(resolve('overload_custom', pkg.ModD, pkg.ModD),
                  pkg.overload_custom_1)
def compute(self):
    """Rescale a copy of the input transfer function to a scalar range.

    The range comes from the first available of: the output behind the
    'Input' port, the 'Dataset' input, or the explicit 'Range' input.
    The copy is set on 'TransferFunction' and its two VTK functions on
    'vtkPicewiseFunction' and 'vtkColorTransferFunction'.
    """
    tf = self.get_input('TransferFunction')
    # Work on a (shallow) copy so the upstream object is not rebound
    new_tf = copy.copy(tf)

    if self.has_input('Input'):
        port = self.get_input('Input')
        output = port.GetProducer().GetOutput(port.GetIndex())
        scalar_range = output.GetScalarRange()
    elif self.has_input('Dataset'):
        scalar_range = self.get_input('Dataset').GetScalarRange()
    else:
        scalar_range = self.get_input('Range')
    (new_tf._min_range, new_tf._max_range) = scalar_range

    self.set_output('TransferFunction', new_tf)
    (of, cf) = new_tf.get_vtk_transfer_functions()
    # NOTE: 'vtkPicewiseFunction' (sic) is the port name as written here
    self.set_output('vtkPicewiseFunction', of)
    self.set_output('vtkColorTransferFunction', cf)
def check_port_spec(module, port_name, port_type, descriptor=None,
                    sigstring=None):
    """Check that a port exists with the expected signature.

    When `descriptor` is given, the port is looked up on it and its
    signature compared with `sigstring` (after replacing old package
    identifiers with current ones).  Otherwise, or if the descriptor has
    no such port, the port must exist as a port spec on `module`.

    Raises UpgradeWorkflowError when the signatures differ or the port
    cannot be found.
    """
    basic_pkg = get_vistrails_basic_pkg_id()
    reg = get_module_registry()
    found = False

    try:
        if descriptor is not None:
            spec = reg.get_port_spec_from_descriptor(descriptor, port_name,
                                                     port_type)
            found = True

            spec_tuples = parse_port_spec_string(sigstring, basic_pkg)
            for idx, spec_tuple in enumerate(spec_tuples):
                current_id = reg.get_package_by_name(
                    spec_tuple[0]).identifier
                if current_id != spec_tuple[0]:
                    # we have an old identifier
                    spec_tuples[idx] = (current_id,) + spec_tuple[1:]
            sigstring = create_port_spec_string(spec_tuples)

            if spec.sigstring != sigstring:
                msg = ('%s port "%s" of module "%s" exists, but '
                       'signatures differ "%s" != "%s"') % (
                           port_type.capitalize(), port_name, module.name,
                           spec.sigstring, sigstring)
                raise UpgradeWorkflowError(msg, module, port_name,
                                           port_type)
    except MissingPort:
        pass

    if found or module.has_portSpec_with_name((port_name, port_type)):
        return
    msg = '%s port "%s" of module "%s" does not exist.' % (
        port_type.capitalize(), port_name, module.name)
    raise UpgradeWorkflowError(msg, module, port_name, port_type)
def connect_var(self, vt_var, dest_module, dest_portname):
    """Queue the operations connecting a vistrail variable to a port.

    The variable's module is created (at the destination module's
    location) if it does not exist yet.  Nothing is queued when the
    variable is already connected to `dest_portname` of `dest_module`.
    """
    self._ensure_version()

    registry = get_module_registry()
    var_type_desc = registry.get_descriptor_by_name(
        vt_var.package, vt_var.module, vt_var.namespace)

    location = dest_module.location
    x, y = location.x, location.y

    # Adapted from VistrailController#connect_vistrail_var()
    controller = self.controller
    var_module = controller.find_vistrail_var_module(vt_var.uuid)
    if var_module is not None:
        if controller.check_vistrail_var_connected(
                var_module, dest_module, dest_portname):
            return
    else:
        var_module = controller.create_vistrail_var_module(
            var_type_desc, x, y, vt_var.uuid)
        self.operations.append(('add', var_module))

    connection = controller.create_connection(
        var_module, 'value', dest_module, dest_portname)
    self.operations.append(('add', connection))
def registerSelf():
    """ registerSelf() -> None
    Register VTKCell, its input ports and the spreadsheet output mode
    with the module registry

    """
    from base_module import vtkRendererOutput
    from vistrails.core import debug

    vtkRendererOutput.register_output_mode(vtkRendererToSpreadsheet)

    registry = get_module_registry()
    registry.add_module(VTKCell)
    registry.add_input_port(VTKCell, "Location", CellLocation)

    # (port name, VTK module name) pairs; each is added on a best-effort
    # basis so that one failing port only produces a warning
    vtk_ports = (
        ("AddRenderer", 'vtkRenderer'),
        ("SetRenderView", 'vtkRenderView'),
        ("InteractionHandler", 'vtkInteractionHandler'),
        ("InteractorStyle", 'vtkInteractorStyle'),
        ("AddPicker", 'vtkAbstractPicker'),
    )
    for port, module in vtk_ports:
        signature = '(%s:%s)' % (vtk_pkg_identifier, module)
        try:
            registry.add_input_port(VTKCell, port, signature)
        except Exception as e:
            debug.warning(
                "Got an exception adding VTKCell's %s input port" % port,
                e)
def initialize():
    """Register the base modules, optimizers and wrapped operations.

    Operations are registered namespace by namespace; the names already
    registered are passed along on each call.
    """
    from tensorflow.python.ops import standard_ops

    reg = get_module_registry()
    for base_module in base_modules:
        reg.add_module(base_module)

    # Optimizers
    reg.add_module(AutoOptimizer)
    optimizers = set(['Optimizer'])
    optimizers.update(register_optimizers(reg))

    # Operations
    reg.add_module(AutoOperation)
    ops = set(wrapped)

    registered = register_operations(reg, standard_ops, '', ops)
    ops.update(registered)

    # The 'train' namespace additionally receives the optimizer names
    registered = register_operations(reg, tensorflow.train, 'train',
                                     ops | optimizers)
    ops.update(registered)

    registered = register_operations(reg, tensorflow.nn, 'nn', ops)
    ops.update(registered)

    registered = register_operations(reg, tensorflow.image, 'image', ops)
    ops.update(registered)
def __init__(self, identifier, version=''):
    """Build the namespace tree of a package's module descriptors.

    `identifier` and `version` are passed to the registry's
    get_package_by_name().  The tree is stored in self._namespaces as
    nested ``(sub_namespaces_dict, modules_list)`` pairs, with '|'
    separating namespace levels; the module list of every level is
    sorted by descriptor name.
    """
    self._package = None
    # namespace_dict : {namespace : (namespace_dict, modules)}
    self._namespaces = ({}, [])

    reg = get_module_registry()
    self._package = reg.get_package_by_name(identifier, version)

    for desc in self._package.descriptor_list:
        path = desc.namespace.split('|') if desc.namespace else []
        # Walk down the tree, creating missing levels on the way
        sub_namespaces, modules = self._namespaces
        for name in path:
            sub_namespaces, modules = sub_namespaces.setdefault(
                name, ({}, []))
        modules.append(desc)

    # Sort the modules of EVERY level.  The previous code rebound the
    # iterable inside its own `for` loop, which does not affect the
    # iterator already running, so only the root level was ever sorted.
    pending = [self._namespaces]
    while pending:
        sub_namespaces, modules = pending.pop()
        modules.sort(key=lambda d: d.name)
        pending.extend(sub_namespaces.values())
def setup_pipeline(self, pipeline, **kwargs):
    """setup_pipeline(controller, pipeline, locator, currentVersion,
    view, aliases, **kwargs)
    Matches a pipeline with the persistent pipeline and creates
    instances of modules that aren't in the cache.

    All options are passed as keyword arguments and are optional:
    controller, locator, current_version, view, vistrail_variables,
    aliases, params, extra_info, logger, sinks, reason, actions,
    done_summon_hooks, module_executed_hook, stop_on_error, parent_exec
    and job_monitor.  Any other keyword raises VistrailsInternalError.

    Failures while building a module's constants are recorded in the
    local `errors` dict (keyed by the temporary module id) and the
    module's persistent id is appended to `to_delete`.

    """
    def fetch(name, default):
        # Pops the option so that leftovers can be detected below
        return kwargs.pop(name, default)
    controller = fetch('controller', None)
    locator = fetch('locator', None)
    current_version = fetch('current_version', None)
    view = fetch('view', DummyView())
    vistrail_variables = fetch('vistrail_variables', None)
    aliases = fetch('aliases', None)
    params = fetch('params', None)
    extra_info = fetch('extra_info', None)
    logger = fetch('logger', DummyLogController)
    sinks = fetch('sinks', None)
    reason = fetch('reason', None)
    actions = fetch('actions', None)
    done_summon_hooks = fetch('done_summon_hooks', [])
    module_executed_hook = fetch('module_executed_hook', [])
    stop_on_error = fetch('stop_on_error', True)
    parent_exec = fetch('parent_exec', None)
    job_monitor = fetch('job_monitor', None)

    reg = get_module_registry()

    # Every known option has been popped: anything left is a caller error
    if len(kwargs) > 0:
        raise VistrailsInternalError('Wrong parameters passed '
                                     'to setup_pipeline: %s' % kwargs)

    def create_null():
        """Creates a Null value"""
        getter = reg.get_descriptor_by_name
        descriptor = getter(basic_pkg, 'Null')
        return descriptor.module()

    def create_constant(param, module):
        """Creates a Constant from a parameter spec"""
        getter = reg.get_descriptor_by_name
        desc = getter(param.identifier, param.type, param.namespace)
        constant = desc.module()
        constant.id = module.id
        # if param.evaluatedStrValue:
        #     constant.setValue(param.evaluatedStrValue)
        if param.strValue != '':
            constant.setValue(param.strValue)
        else:
            # Empty string: fall back to the constant's default value
            constant.setValue(
                constant.translate_to_string(constant.default_value))
        return constant

    ### BEGIN METHOD ###

    # if self.debugger:
    #     self.debugger.update()
    to_delete = []
    errors = {}

    if controller is not None:
        # Controller is none for sub_modules
        controller.validate(pipeline)
    else:
        pipeline.validate()

    self.resolve_aliases(pipeline, aliases)
    if vistrail_variables:
        self.resolve_variables(vistrail_variables, pipeline)

    self.update_params(pipeline, params)

    (tmp_to_persistent_module_map,
     conn_map,
     module_added_set,
     conn_added_set) = self.add_to_persistent_pipeline(pipeline)

    # Create the new objects
    for i in module_added_set:
        persistent_id = tmp_to_persistent_module_map[i]
        module = self._persistent_pipeline.modules[persistent_id]
        obj = self._objects[persistent_id] = module.summon()
        obj.interpreter = self
        obj.id = persistent_id
        obj.signature = module._signature

        # Checking if output should be stored
        if module.has_annotation_with_key('annotate_output'):
            annotate_output = module.get_annotation_by_key(
                'annotate_output')
            #print annotate_output
            if annotate_output:
                obj.annotate_output = True

        # Turn each function of the module into an input connector:
        # no parameter -> Null, one -> Constant, several -> tuple
        for f in module.functions:
            connector = None
            if len(f.params) == 0:
                connector = ModuleConnector(create_null(), 'value',
                                            f.get_spec('output'))
            elif len(f.params) == 1:
                p = f.params[0]
                try:
                    constant = create_constant(p, module)
                    connector = ModuleConnector(constant, 'value',
                                                f.get_spec('output'))
                except Exception, e:
                    debug.unexpected_exception(e)
                    err = ModuleError(
                        module,
                        "Uncaught exception creating Constant from "
                        "%r: %s" % (p.strValue, debug.format_exception(e)))
                    errors[i] = err
                    to_delete.append(obj.id)
            else:
                tupleModule = vistrails.core.interpreter.base.InternalTuple(
                )
                tupleModule.length = len(f.params)
                for (j, p) in enumerate(f.params):
                    try:
                        constant = create_constant(p, module)
                        constant.update()
                        connector = ModuleConnector(
                            constant, 'value', f.get_spec('output'))
                        tupleModule.set_input_port(j, connector)
                    except Exception, e:
                        debug.unexpected_exception(e)
                        err = ModuleError(
                            module,
                            "Uncaught exception creating Constant "
                            "from %r: %s" %
                            (p.strValue, debug.format_exception(e)))
                        errors[i] = err
                        to_delete.append(obj.id)
                # The tuple itself is what gets connected to the port
                connector = ModuleConnector(tupleModule, 'value',
                                            f.get_spec('output'))
            if connector:
                obj.set_input_port(f.name, connector, is_method=True)
def execute(self, *args, **kwargs):
    """Execute the pipeline.

    Positional arguments are either input values (created from
    ``module == value``, where `module` is a Module from the pipeline and
    `value` is some value or Function instance) for the pipeline's
    InputPorts, or Module instances (to select sink modules).

    Keyword arguments are also used to set InputPort by looking up inputs
    by name.

    Example::

       input_bound = pipeline.get_input('higher_bound')
       input_url = pipeline.get_input('url')
       sinkmodule = pipeline.get_module(32)
       pipeline.execute(sinkmodule,
                        input_bound == vt.Function(Integer, 10),
                        input_url == 'http://www.vistrails.org/',
                        resolution=15)  # kwarg: only one equal sign

    Returns an ExecutionResults wrapping the interpreter's result.

    Raises ValueError if an InputPort is given several values or if a
    value is given for a module that is not an InputPort, RuntimeError
    if an input needs more than one value (not implemented), and
    ExecutionErrors if the execution reported errors.
    """
    sinks = set()
    inputs = {}

    reg = get_module_registry()
    InputPort_desc = reg.get_descriptor_by_name(
        get_vistrails_basic_pkg_id(), 'InputPort')

    # Read args
    for arg in args:
        if isinstance(arg, ModuleValuePair):
            if arg.module.id in inputs:
                raise ValueError("Multiple values set for InputPort %r" %
                                 get_inputoutput_name(arg.module))
            if not reg.is_descriptor_subclass(arg.module.module_descriptor,
                                              InputPort_desc):
                raise ValueError("Module %d is not an InputPort" %
                                 arg.module.id)
            inputs[arg.module.id] = arg.value
        elif isinstance(arg, Module):
            sinks.add(arg.module_id)
        # NOTE(review): arguments of any other type are silently ignored

    # Read kwargs
    for key, value in kwargs.iteritems():
        key = self.get_input(key)  # Might raise KeyError
        if key.module_id in inputs:
            raise ValueError("Multiple values set for InputPort %r" %
                             get_inputoutput_name(key.module))
        inputs[key.module_id] = value

    reason = "API pipeline execution"
    # An empty set of sinks is passed on as None
    sinks = sinks or None

    # Use controller only if no inputs were passed in
    if (not inputs and self.vistrail is not None and
            self.vistrail.current_version == self.version):
        controller = self.vistrail.controller
        results, changed = controller.execute_workflow_list([[
            controller.locator,  # locator
            self.version,  # version
            self.pipeline,  # pipeline
            DummyView(),  # view
            None,  # custom_aliases
            None,  # custom_params
            reason,  # reason
            sinks,  # sinks
            None,  # extra_info
        ]])
        # Exactly one workflow was submitted, so exactly one result
        result, = results
    else:
        pipeline = self.pipeline
        if inputs:
            # Work on a copy: the constant modules added below must not
            # end up in self.pipeline
            id_scope = IdScope(1)
            pipeline = pipeline.do_copy(False, id_scope)

            # A hack to get ids from id_scope that we know won't collide:
            # make them negative
            id_scope.getNewId = lambda t, g=id_scope.getNewId: -g(t)

            create_module = \
                VistrailController.create_module_from_descriptor_static
            create_function = VistrailController.create_function_static
            create_connection = VistrailController.create_connection_static
            # Fills in the ExternalPipe ports
            for module_id, values in inputs.iteritems():
                module = pipeline.modules[module_id]
                if not isinstance(values, (list, tuple)):
                    values = [values]

                # Guess the type of the InputPort
                _, sigstrings, _, _, _ = get_port_spec_info(
                    pipeline, module)
                sigstrings = parse_port_spec_string(sigstrings)

                # Convert whatever we got to a list of strings, for the
                # pipeline
                values = [
                    reg.convert_port_val(val, sigstring, None)
                    for val, sigstring in izip(values, sigstrings)
                ]

                if len(values) == 1:
                    # Create the constant module
                    constant_desc = reg.get_descriptor_by_name(
                        *sigstrings[0])
                    constant_mod = create_module(id_scope, constant_desc)
                    func = create_function(id_scope, constant_mod, 'value',
                                           values)
                    constant_mod.add_function(func)
                    pipeline.add_module(constant_mod)

                    # Connect it to the ExternalPipe port
                    conn = create_connection(id_scope, constant_mod,
                                             'value', module,
                                             'ExternalPipe')
                    pipeline.db_add_connection(conn)
                else:
                    raise RuntimeError("TODO : create tuple")

        interpreter = get_default_interpreter()
        result = interpreter.execute(pipeline, reason=reason, sinks=sinks)

    if result.errors:
        raise ExecutionErrors(self, result)
    else:
        return ExecutionResults(self, result)
def performParameterExploration(self):
    """ performParameterExploration() -> None
    Perform the exploration by collecting a list of actions
    corresponding to each dimension

    The exploration is first stored on the vistrail for the current
    version, then one pipeline per combination is generated and executed
    while a cancellable progress dialog is shown.

    """
    registry = get_module_registry()
    actions = self.peWidget.table.collectParameterActions()
    spreadsheet_pkg = 'org.vistrails.vistrails.spreadsheet'

    # Set the annotation to persist the parameter exploration
    # TODO: For now, we just replace the existing exploration - Later we
    # should append them.
    xmlString = "<paramexps>\n" + self.getParameterExploration() + "\n</paramexps>"
    self.controller.vistrail.set_paramexp(self.currentVersion, xmlString)
    self.controller.set_changed(True)

    if self.controller.current_pipeline and actions:
        explorer = ActionBasedParameterExploration()
        (pipelines, performedActions) = explorer.explore(
            self.controller.current_pipeline, actions)

        # Size of each dimension (at least 1)
        # NOTE(review): dim[2] below assumes at least three dimensions
        dim = [max(1, len(a)) for a in actions]
        # Only lay the pipelines out on a sheet if the spreadsheet
        # package provides the modules needed for it
        if (registry.has_module(spreadsheet_pkg, 'CellLocation') and
                registry.has_module(spreadsheet_pkg, 'SheetReference')):
            modifiedPipelines = self.virtualCell.positionPipelines(
                'PE#%d %s' % (QParameterExplorationTab.explorationId,
                              self.controller.name),
                dim[2], dim[1], dim[0], pipelines, self.controller)
        else:
            modifiedPipelines = pipelines

        # mCount[i] is the progress value shown right before pipeline i
        # is executed (a running total of module counts, starting at 0)
        mCount = []
        for p in modifiedPipelines:
            if len(mCount)==0:
                mCount.append(0)
            else:
                mCount.append(len(p.modules)+mCount[len(mCount)-1])

        # Now execute the pipelines
        totalProgress = sum([len(p.modules) for p in modifiedPipelines])
        progress = QtGui.QProgressDialog('Performing Parameter '
                                         'Exploration...',
                                         '&Cancel',
                                         0, totalProgress)
        progress.setWindowTitle('Parameter Exploration')
        progress.setWindowModality(QtCore.Qt.WindowModal)
        progress.show()

        QParameterExplorationTab.explorationId += 1
        interpreter = get_default_interpreter()
        for pi in xrange(len(modifiedPipelines)):
            progress.setValue(mCount[pi])
            QtCore.QCoreApplication.processEvents()
            if progress.wasCanceled():
                break

            def moduleExecuted(objId):
                # Hook passed to the interpreter as module_executed_hook;
                # advances the progress dialog by one step
                if not progress.wasCanceled():
                    #progress.setValue(progress.value()+1)
                    # The call above was crashing when used by
                    # multithreaded code, so the value is set through
                    # invokeMethod instead (thanks to Terence for
                    # submitting this fix).
                    QtCore.QMetaObject.invokeMethod(progress, "setValue",
                                                    QtCore.Q_ARG(int,progress.value()+1))
                    QtCore.QCoreApplication.processEvents()

            kwargs = {'locator': self.controller.locator,
                      'current_version': self.controller.current_version,
                      'view': self.controller.current_pipeline_scene,
                      'module_executed_hook': [moduleExecuted],
                      'reason': 'Parameter Exploration',
                      'actions': performedActions[pi],
                      }
            interpreter.execute(modifiedPipelines[pi], **kwargs)
        # Fill the bar, whether the loop completed or was cancelled
        progress.setValue(totalProgress)
def get_module_registry():
    """Return the module registry, importing its module only when called.
    """
    from vistrails.core.modules import module_registry
    return module_registry.get_module_registry()
def updateFromPipeline(self, pipeline, controller):
    """ updateFromPipeline(pipeline: Pipeline) -> None
    Read the list of aliases and parameters from the pipeline

    The tree is cleared, then rebuilt with: an 'Aliases' root (only if
    the pipeline has aliases), a 'Vistrail Variables' root (hidden unless
    a vistrail variable module is found) and one item per module that
    has parameters, listing both the functions already set and the
    constant input ports that are still unset.

    `controller` is used to look up vistrail variables by uuid.

    """
    self.clear()
    if not pipeline:
        return

    # Update the aliases
    if len(pipeline.aliases) > 0:
        aliasRoot = QParameterTreeWidgetItem(None, self, ['Aliases'])
        aliasRoot.setFlags(QtCore.Qt.ItemIsEnabled)
        for (alias, info) in pipeline.aliases.iteritems():
            ptype, pId, parentType, parentId, mId = info
            parameter = pipeline.db_get_object(ptype, pId)
            function = pipeline.db_get_object(parentType, parentId)
            v = parameter.strValue
            port_spec = function.get_spec('input')
            port_spec_item = port_spec.port_spec_items[parameter.pos]
            label = ['%s = %s' % (alias, v)]
            pInfo = ParameterInfo(module_id=mId,
                                  name=function.name,
                                  pos=parameter.pos,
                                  value=v,
                                  spec=port_spec_item,
                                  is_alias=True)
            QParameterTreeWidgetItem((alias, [pInfo]), aliasRoot, label)
        aliasRoot.setExpanded(True)

    # Stays hidden until a vistrail variable module is encountered
    vistrailVarsRoot = QParameterTreeWidgetItem(None, self,
                                                ['Vistrail Variables'])
    vistrailVarsRoot.setHidden(True)

    # Now go through all modules and functions
    inspector = PipelineInspector()
    inspector.inspect_ambiguous_modules(pipeline)
    sortedModules = sorted(pipeline.modules.iteritems(),
                           key=lambda item: item[1].name)

    reg = get_module_registry()

    for mId, module in sortedModules:
        if module.is_vistrail_var():
            vistrailVarsRoot.setHidden(False)
            vistrailVarsRoot.setExpanded(True)
            port_spec = module.get_port_spec('value', 'input')
            if not port_spec:
                debug.critical("Not port_spec for value in module %s" %
                               module)
                continue
            # Skip variable modules the controller does not know about
            if not controller.has_vistrail_variable_with_uuid(
                    module.get_vistrail_var()):
                continue
            vv = controller.get_vistrail_variable_by_uuid(
                module.get_vistrail_var())

            label = ['%s = %s' % (vv.name, vv.value)]
            pList = [
                ParameterInfo(module_id=mId,
                              name=port_spec.name,
                              pos=port_spec.port_spec_items[pId].pos,
                              value="",
                              spec=port_spec.port_spec_items[pId],
                              is_alias=False)
                for pId in xrange(len(port_spec.port_spec_items))
            ]
            QParameterTreeWidgetItem((vv.name, pList), vistrailVarsRoot,
                                     label)
            # Variable modules are not listed as regular modules
            continue

        # Names of the functions that are set, so that they are not
        # listed again as available parameters below
        function_names = {}
        # Add existing parameters
        mLabel = [module.name]
        # Tree item of this module, created lazily on first parameter
        moduleItem = None
        if len(module.functions) > 0:
            for fId in xrange(len(module.functions)):
                function = module.functions[fId]
                function_names[function.name] = function
                if len(function.params) == 0:
                    continue
                if moduleItem == None:
                    if inspector.annotated_modules.has_key(mId):
                        annotatedId = inspector.annotated_modules[mId]
                        moduleItem = QParameterTreeWidgetItem(
                            annotatedId, self, mLabel)
                    else:
                        moduleItem = QParameterTreeWidgetItem(
                            None, self, mLabel)
                v = ', '.join([p.strValue for p in function.params])
                label = ['%s(%s)' % (function.name, v)]

                try:
                    port_spec = function.get_spec('input')
                except Exception:
                    debug.critical("get_spec failed: %s %s %s" % \
                                   (module, function, function.sigstring))
                    continue
                port_spec_items = port_spec.port_spec_items
                pList = [
                    ParameterInfo(module_id=mId,
                                  name=function.name,
                                  pos=function.params[pId].pos,
                                  value=function.params[pId].strValue,
                                  spec=port_spec_items[pId],
                                  is_alias=False)
                    for pId in xrange(len(function.params))
                ]
                mName = module.name
                if moduleItem.parameter is not None:
                    mName += '(%d)' % moduleItem.parameter
                fName = '%s :: %s' % (mName, function.name)
                QParameterTreeWidgetItem((fName, pList), moduleItem, label)

        # Add available parameters
        if module.is_valid:
            for port_spec in module.destinationPorts():
                if (port_spec.name in function_names or
                        not port_spec.is_valid or
                        not len(port_spec.port_spec_items) or
                        not reg.is_constant(port_spec)):
                    # The function already exists or is empty
                    # or contains non-constant modules
                    continue
                if moduleItem == None:
                    if inspector.annotated_modules.has_key(mId):
                        annotatedId = inspector.annotated_modules[mId]
                        moduleItem = QParameterTreeWidgetItem(
                            annotatedId, self, mLabel, False)
                    else:
                        moduleItem = QParameterTreeWidgetItem(
                            None, self, mLabel, False)
                v = ', '.join(
                    [p.module for p in port_spec.port_spec_items])
                label = ['%s(%s)' % (port_spec.name, v)]
                pList = [
                    ParameterInfo(module_id=mId,
                                  name=port_spec.name,
                                  pos=port_spec.port_spec_items[pId].pos,
                                  value="",
                                  spec=port_spec.port_spec_items[pId],
                                  is_alias=False)
                    for pId in xrange(len(port_spec.port_spec_items))
                ]
                mName = module.name
                if moduleItem.parameter is not None:
                    mName += '(%d)' % moduleItem.parameter
                fName = '%s :: %s' % (mName, port_spec.name)
                QParameterTreeWidgetItem((fName, pList), moduleItem, label,
                                         False)
        if moduleItem:
            moduleItem.setExpanded(True)
    self.toggleUnsetParameters(self.showUnsetParameters)
def convert_data(child, parent_type, parent_id):
    """Build the new DB* object corresponding to `child`.

    The object created depends on ``child.vtType``.  Modules and port
    signatures get their package identifier from the registry, looked up
    by name only; converted modules are also recorded in the global
    `module_map`.  Returns None for an unrecognized vtType.
    """
    from vistrails.core.vistrail.port_spec import PortSpec
    from vistrails.core.modules.module_registry import get_module_registry
    global module_map, translate_ports

    registry = get_module_registry()
    kind = child.vtType

    if kind == 'module':
        package = registry.get_descriptor_from_name_only(
            child.db_name).identifier
        module_map[child.db_id] = (child.db_name, package)
        return DBModule(id=child.db_id,
                        cache=child.db_cache,
                        abstraction=0,
                        name=child.db_name,
                        package=package)

    if kind == 'connection':
        return DBConnection(id=child.db_id)

    if kind == 'portSpec':
        return DBPortSpec(id=child.db_id,
                          name=child.db_name,
                          type=child.db_type,
                          spec=child.db_spec)

    if kind == 'function':
        # Only functions that belong to a module get their name translated
        if parent_type == 'module':
            name = translate_vtk(parent_id, child.db_name)
        else:
            name = child.db_name
        return DBFunction(id=child.db_id, pos=child.db_pos, name=name)

    if kind == 'parameter':
        return DBParameter(id=child.db_id,
                           pos=child.db_pos,
                           name=child.db_name,
                           type=child.db_type,
                           val=child.db_val,
                           alias=child.db_alias)

    if kind == 'location':
        return DBLocation(id=child.db_id, x=child.db_x, y=child.db_y)

    if kind == 'annotation':
        return DBAnnotation(id=child.db_id,
                            key=child.db_key,
                            value=child.db_value)

    if kind == 'port':
        sig = child.db_sig
        if '(' in sig and ')' in sig:
            # "name(spec1,spec2)": split the name from the spec list
            open_at = sig.find('(')
            close_at = sig.find(')')
            specs = sig[open_at + 1:close_at]
            name = translate_vtk(child.db_moduleId, sig[:open_at], specs)
            qualified = [
                registry.get_descriptor_from_name_only(
                    spec_name).identifier + ':' + spec_name
                for spec_name in specs.split(',')
            ]
            spec = '(' + ','.join(qualified) + ')'
        else:
            name = sig
            spec = ''
        return DBPort(id=child.db_id,
                      type=child.db_type,
                      moduleId=child.db_moduleId,
                      moduleName=child.db_moduleName,
                      name=name,
                      spec=spec)
def write_workflow_to_python(pipeline): """Writes a pipeline to a Python source file. :returns: An iterable over the lines (as unicode) of the generated script. """ # The set of all currently bound symbols in the resulting script's global # scope # These are either variables from translated modules (internal or output # ports) or imported names all_vars = set(reserved) # Should we import izip or product from itertools? import_izip = False import_product = False # The parts of the final generated script text = [] # The modules that have been translated, maps from the module's id to a # Script object modules = dict() # The preludes that have been collected # A "prelude" is a piece of code that is supposed to go at the top of the # file, and that shouldn't be repeated; things like import statements, # function/class definition, and constants preludes = [] reg = get_module_registry() # set port specs and module depth pipeline.validate(False) # ######################################## # Walk through the pipeline to get all the codes # for module_id in pipeline.graph.vertices_topological_sort(): module = pipeline.modules[module_id] print("Processing module %s %d" % (module.name, module_id)) desc = module.module_descriptor module_class = desc.module # Gets the code code_preludes = [] if not hasattr(module_class, 'to_python_script'): # Use vistrails API to execute module code, code_preludes = generate_api_code(module) elif module_class.to_python_script is None: debug.critical("Module %s cannot be converted to Python" % module.name) code = Script( "# <Missing code>\n" "# %s has empty function to_python_script()\n" "# VisTrails cannot currently export such modules" % module.name, 'variables', 'variables') else: # Call the module to get the base code code = module_class.to_python_script(module) if isinstance(code, tuple): code, code_preludes = code print("Got code:\n%r" % (code, )) assert isinstance(code, Script) modules[module_id] = code preludes.extend(code_preludes) # 
######################################## # Processes the preludes and writes the beginning of the file # print("Writing preludes") # Adds all imported modules to the list of symbols for prelude in preludes: all_vars.update(prelude.imported_pkgs) # Removes collisions prelude_renames = {} for prelude in preludes: prelude_renames.update(prelude.avoid_collisions(all_vars)) # remove duplicates final_preludes = [] final_prelude_set = set() for prelude in [unicode(p) for p in preludes]: if prelude not in final_prelude_set: final_prelude_set.add(prelude) final_preludes.append(prelude) # Writes the preludes for prelude in final_preludes: text.append(prelude) #if preludes: # text.append('# PRELUDE ENDS -- pipeline code follows\n\n') text.append('') # ######################################## # Walk through the pipeline a second time to generate the full script # first = True # outer name of output is different from name in code if loop is used # so we keep this mapping: {(module_id, oport_name): outer_name} output_loop_map = {} for module_id in pipeline.graph.vertices_topological_sort(): module = pipeline.modules[module_id] desc = module.module_descriptor print("Writing module %s %d" % (module.name, module_id)) if not first: text.append('\n') else: first = False # Annotation, used to rebuild the pipeline text.append("# MODULE %d %s" % (module_id, desc.sigstring)) code = modules[module_id] # Gets all the module's input and output port names input_ports = set( reg.module_destination_ports_from_descriptor(False, desc)) input_ports.update(module.input_port_specs) input_port_names = set(p.name for p in input_ports) iports = dict((p.name, p) for p in input_ports) connected_inputs = set() for _, conn_id in pipeline.graph.edges_to(module_id): conn = pipeline.connections[conn_id] connected_inputs.add(utf8(conn.destination.name)) for function in module.functions: connected_inputs.add(utf8(function.name)) output_ports = set(reg.module_source_ports_from_descriptor( False, desc)) 
output_ports.update(module.output_port_specs) output_port_names = set(p.name for p in output_ports) connected_outputs = set() for _, conn_id in pipeline.graph.edges_from(module_id): conn = pipeline.connections[conn_id] connected_outputs.add(utf8(conn.source.name)) # Changes symbol names in this piece of code to prevent collisions # with already-encountered code old_all_vars = set(all_vars) code.normalize(input_port_names, output_port_names, all_vars) # Now, code knows what its inputs and outputs are print("Normalized code:\n%r" % (code, )) print("New vars in all_vars: %r" % (all_vars - old_all_vars, )) print("used_inputs: %r" % (code.used_inputs, )) # build final inputs as result of merging connections # {dest_name: {source_name, depth_diff}} combined_inputs = {} # Adds functions for function in module.functions: port = utf8(function.name) code.unset_inputs.discard(port) if code.skip_functions: continue if port not in code.used_inputs: print("NOT adding function %s (not used in script)" % port) continue ## Creates a variable with the value name = make_unique(port, all_vars) if len(function.params) == 1: value = function.params[0].value() else: value = [p.value() for p in function.params] depth = -iports[port].depth print("Function %s: var %s, value %r" % (port, name, value)) text.append("# FUNCTION %s %s" % (port, name)) text.append('%s = %r' % (name, value)) if port not in combined_inputs: combined_inputs[port] = [] combined_inputs[port].append((name, depth)) # Sets input connections conn_ids = sorted( [conn_id for _, conn_id in pipeline.graph.edges_to(module_id)]) for conn_id in conn_ids: conn = pipeline.connections[conn_id] dst = conn.destination port = utf8(dst.name) if port not in code.used_inputs: print("NOT connecting port %s (not used in script)" % dst.name) continue src = conn.source # Tells the code what the variable was src_mod = modules[src.moduleId] name = output_loop_map.get((src.moduleId, src.name), src_mod.get_output(src.name)) # get depth 
difference to destination port depth = pipeline._connection_depths.get(conn_id, 0) # account for module looping depth += pipeline.modules[src.moduleId].list_depth print("Input %s: var %s" % (port, name)) text.append("# CONNECTION %s %s" % (dst.name, name)) #code.set_input(utf8(dst.name), name) code.unset_inputs.discard(utf8(dst.name)) if port not in combined_inputs: combined_inputs[port] = [] combined_inputs[port].append((name, depth)) # Sets default values if code.unset_inputs: print("unset_inputs: %r" % (code.unset_inputs, )) for port in set(code.unset_inputs): if port not in code.used_inputs: continue # Creates a variable with the value name = make_unique(port, all_vars) default = iports[port].defaults if len(default) == 1: default = default[0] print("Default: %s: var %s, value %r" % (port, name, default)) text.append("# DEFAULT %s %s" % (port, name)) text.append('%s = %s' % (name, default)) code.set_input(port, name) # merge connections # [(port, [names], depth)] merged_inputs = [] # keep input port order for iport in sorted(input_ports, key=lambda x: (x.sort_key, x.id)): port = iport.name if port not in combined_inputs: continue conns = combined_inputs[port] descs = iports[port].descriptors() can_merge = ((len(descs) == 1 and isinstance(descs[0], List)) or iports[port].depth > 0) if len(conns) == 1 or not can_merge: # no merge needed name = conns[0][0] depth = conns[0][1] if depth >= 0: # add directly new_names = [name] for i in xrange(depth): name = make_unique(new_names[-1] + '_item', all_vars) new_names.append(name) code.set_input(port, name) merged_inputs.append((port, new_names, depth)) else: # wrap to correct depth as new variable while depth: depth += 1 name = '[%s]' % name new_name = make_unique(code.inputs[port], all_vars) text.append('%s = %s' % (new_name, name)) code.set_input(port, new_name) merged_inputs.append((port, [new_name], 0)) else: # merge connections items = [] for name, depth in conns: # wrap to max depth while depth < 0: depth += 1 name 
= '[%s]' % name items.append(name) # assign names to list items in loop # like "xItem", "xItemItem", etc. new_names = [make_unique(code.inputs[port], all_vars)] max_depth = max([c[1] for c in conns]) if max_depth > 0: for i in xrange(max_depth): new_name = make_unique(new_names[-1] + '_item', all_vars) new_names.append(new_name) text.append('%s = %s' % (new_names[0], ' + '.join(items))) code.set_input(port, new_names[-1]) merged_inputs.append((port, new_names, max_depth)) #### Prepare looping ################################################# max_level = module.list_depth # names for the looped output ports output_sub_ports = dict([(port, [code.get_output(port)]) for port in connected_outputs]) # add output names to global output name list for port, names in output_sub_ports.iteritems(): output_loop_map[(module_id, port)] = names[0] offset_levels = [] loop_args = [] for level in xrange(1, max_level + 1): # Output values are collected from the inside and out for port in connected_outputs: sub_ports = output_sub_ports[port] new_name = make_unique(sub_ports[-1] + '_item', all_vars) code.set_output(port, new_name) sub_ports.append(new_name) # construct the loop code combine_type = 'cartesian' if max_level == 1: # first level may use a complex port combination cps = {} for cp in module.control_parameters: cps[cp.name] = cp.value if ModuleControlParam.LOOP_KEY in cps: combine_type = cps[ModuleControlParam.LOOP_KEY] if combine_type not in ['cartesian', 'pairwise']: combine_type = ast.literal_eval(combine_type) if combine_type == 'cartesian': # make nested for loops for all iterated ports items = [] for port, names, depth in merged_inputs: if depth >= level: # map parent level name to this level items.append((names[level], names[level - 1])) offset = 0 loop_arg = [] for item in items: loop_arg.append(' ' * offset + 'for %s in %s:' % item) offset += 1 loop_args.append('\n'.join(loop_arg)) offset_levels.append(offset) elif combine_type == 'pairwise': # zip all iterated ports 
prev_items, cur_items = [], [] for port, names, depth in merged_inputs: if depth >= level: # map parent level name to this level prev_items.append(names[level - 1]) cur_items.append(names[level]) loop_arg = 'for %s in izip(%s):' % (', '.join(cur_items), ', '.join(prev_items)) import_izip = True loop_args.append(loop_arg) offset_levels.append(1) else: # construct custom combination name_map = {} for port, names, depth in merged_inputs: if depth >= level: # map parent level name to this level name_map[port] = (names[level - 1], names[level]) loop_arg = "for %s in %s:" % combine(combine_type, name_map) if 'izip' in loop_arg: import_izip = True if 'product(' in loop_arg: import_product = True loop_args.append(loop_arg) offset_levels.append(1) # TODO: handle while loop #### Write loop start ################################################ for level in xrange(0, max_level): offset = sum(offset_levels[:level]) for port in connected_outputs: sub_ports = output_sub_ports[port] text.append(' ' * offset + '%s = []' % sub_ports[level]) sub_ports.append(new_name) for_loop = '\n'.join( [' ' * offset + f for f in loop_args[level].split('\n')]) text.append(for_loop) #### Write module code ############################################### code_text = unicode(code) if offset_levels: code_text = code_text.split('\n') code_text = '\n'.join( [' ' * sum(offset_levels) + t for t in code_text]) # node.increase_indent(sum(offset_levels)*4) # Ok, add the module's code print("Rendering code") text.append(code_text) print("Total new vars: %r" % (all_vars - old_all_vars, )) #### Write loop end ################################################## for level in reversed(xrange(0, max_level)): offset = sum(offset_levels[:level + 1]) for port in connected_outputs: sub_ports = output_sub_ports[port] text.append(' ' * offset + '%s.append(%s)' % (sub_ports[level], sub_ports[level - 1])) if import_izip or import_product: if not import_product: text.insert(0, 'from itertools import izip\n') elif not 
import_product: text.insert(0, 'from itertools import product\n') else: text.insert(0, 'from itertools import izip, product\n') return text
def assignPipelineCellLocations(pipeline, sheetName, row, col,
                                cellIds=None, minRowCount=None,
                                minColCount=None):
    """Return a copy of `pipeline` whose cells point at one sheet location.

    The pipeline is copied with ``copy.copy`` and every cell module listed
    in `cellIds` is rewired on that copy:

    * modules whose descriptor is a subclass of the spreadsheet package's
      ``SpreadsheetCell`` lose their incoming ``Location`` connections and
      get a new ``SheetReference -> CellLocation -> cell`` chain carrying
      `sheetName`, `row` and `col`;
    * modules whose descriptor is a subclass of the basic package's
      ``OutputModule`` lose everything set on their ``configuration`` port
      and get a single ``configuration`` function holding
      ``repr({'spreadsheet': {...}})`` with zero-based ``row``/``col``.

    Modules and functions created here get negative ids (see
    `negated_ids_patch` below).

    :param pipeline: the pipeline to copy and modify
    :param sheetName: sheet name; passed through ``str()`` for
        ``SheetReference``, stored as-is (when not None) for output modules
    :param row: row written to ``CellLocation`` as-is and to the output
        module configuration as ``row - 1``
    :param col: column, handled like `row`
    :param cellIds: iterable of cell locations; each entry is either a
        module id or a sequence of module ids walking down through nested
        ``m.pipeline`` attributes.  When None, the cells are discovered
        with a ``PipelineInspector``.
    :param minRowCount: optional minimum sheet row count
    :param minColCount: optional minimum sheet column count
    :returns: the modified copy of `pipeline`

    NOTE(review): only the top-level pipeline is copied here; whether the
    nested pipelines reached through ``m.pipeline`` are shared with the
    caller's pipeline depends on the pipeline's copy semantics -- confirm.
    """
    reg = get_module_registry()
    spreadsheet_cell_desc = reg.get_descriptor_by_name(spreadsheet_pkg,
                                                       'SpreadsheetCell')
    output_module_desc = reg.get_descriptor_by_name(
        'org.vistrails.vistrails.basic', 'OutputModule')

    create_module = VistrailController.create_module_static
    create_function = VistrailController.create_function_static
    create_connection = VistrailController.create_connection_static

    pipeline = copy.copy(pipeline)
    root_pipeline = pipeline
    if cellIds is None:
        inspector = PipelineInspector()
        inspector.inspect_spreadsheet_cells(pipeline)
        inspector.inspect_ambiguous_modules(pipeline)
        cellIds = inspector.spreadsheet_cells

    def delete_connections_to(pipeline, mId, port_name):
        """Delete every connection ending on `port_name` of module `mId`."""
        # Collect the ids first: connection_list must not be mutated while
        # it is being iterated.
        conns_to_delete = [c.id for c in pipeline.connection_list
                           if (c.destinationId == mId and
                               c.destination.name == port_name)]
        for c_id in conns_to_delete:
            pipeline.delete_connection(c_id)

    def negated_ids_patch(pipeline):
        """Make the pipeline's id scope class hand out negated ids.

        This is a hack: getNewId is replaced on the *class* of
        ``pipeline.tmp_id``.  Returns ``(id_scope, restore)``; the caller
        MUST call ``restore()`` in a ``finally`` clause, otherwise the
        class stays patched after an error.
        """
        id_scope = pipeline.tmp_id
        id_scope_cls = id_scope.__class__
        orig_getNewId = id_scope_cls.getNewId

        def getNewId(self, objType):
            return -orig_getNewId(self, objType)

        def restore():
            id_scope_cls.getNewId = orig_getNewId

        id_scope_cls.getNewId = getNewId
        return id_scope, restore

    def fix_cell_module(pipeline, mId):
        """Feed spreadsheet cell `mId` from a fresh CellLocation module."""
        delete_connections_to(pipeline, mId, 'Location')

        id_scope, restore_ids = negated_ids_patch(pipeline)
        try:
            # Add a sheet reference with a specific name
            sheetReference = create_module(id_scope, spreadsheet_pkg,
                                           "SheetReference")
            sheetNameFunction = create_function(id_scope, sheetReference,
                                                "SheetName",
                                                [str(sheetName)])
            sheetReference.add_function(sheetNameFunction)

            if minRowCount is not None:
                minRowFunction = create_function(id_scope, sheetReference,
                                                 "MinRowCount",
                                                 [str(minRowCount)])
                sheetReference.add_function(minRowFunction)
            if minColCount is not None:
                minColFunction = create_function(id_scope, sheetReference,
                                                 "MinColumnCount",
                                                 [str(minColCount)])
                sheetReference.add_function(minColFunction)

            # Add a cell location module with a specific row and column
            cellLocation = create_module(id_scope, spreadsheet_pkg,
                                         "CellLocation")
            rowFunction = create_function(id_scope, cellLocation, "Row",
                                          [str(row)])
            colFunction = create_function(id_scope, cellLocation, "Column",
                                          [str(col)])
            cellLocation.add_function(rowFunction)
            cellLocation.add_function(colFunction)

            # Then connect the SheetReference to the CellLocation
            sheet_conn = create_connection(id_scope, sheetReference, "value",
                                           cellLocation, "SheetReference")

            # Then connect the CellLocation to the spreadsheet cell
            cell_module = pipeline.get_module_by_id(mId)
            cell_conn = create_connection(id_scope, cellLocation, "value",
                                          cell_module, "Location")

            pipeline.add_module(sheetReference)
            pipeline.add_module(cellLocation)
            pipeline.add_connection(sheet_conn)
            pipeline.add_connection(cell_conn)
        finally:
            # Always undo the class-level patch, even when creation failed.
            restore_ids()

    def fix_output_module(pipeline, mId):
        """Replace the 'configuration' of output module `mId`."""
        # Remove all connections to 'configuration' input port
        delete_connections_to(pipeline, mId, 'configuration')

        m = pipeline.modules[mId]

        # Remove all functions on 'configuration' input port
        funcs_to_delete = [f.real_id for f in m.functions
                           if f.name == 'configuration']
        for f_id in funcs_to_delete:
            m.delete_function_by_real_id(f_id)

        # The output module configuration uses zero-based coordinates.
        config = {'row': row - 1, 'col': col - 1}
        if minRowCount is not None:
            config['sheetRowCount'] = minRowCount
        if minColCount is not None:
            config['sheetColCount'] = minColCount
        if sheetName is not None:
            config['sheetName'] = sheetName
        config = {'spreadsheet': config}

        id_scope, restore_ids = negated_ids_patch(pipeline)
        try:
            config_function = create_function(id_scope, m, 'configuration',
                                              [repr(config)])
            m.add_function(config_function)
        finally:
            # Always undo the class-level patch, even when creation failed.
            restore_ids()

    for id_list in cellIds:
        cell_pipeline = pipeline

        # find at which depth we need to be working
        if isinstance(id_list, (int, long)):
            mId = id_list
            m = cell_pipeline.modules[mId]
        else:
            # Walk down the nested pipelines; the last id is the cell itself.
            id_iter = iter(id_list)
            mId = next(id_iter)
            m = cell_pipeline.modules[mId]
            for mId in id_iter:
                cell_pipeline = m.pipeline
                m = cell_pipeline.modules[mId]

        if reg.is_descriptor_subclass(m.module_descriptor,
                                      spreadsheet_cell_desc):
            fix_cell_module(cell_pipeline, mId)
        elif reg.is_descriptor_subclass(m.module_descriptor,
                                        output_module_desc):
            fix_output_module(cell_pipeline, mId)

    return root_pipeline