Esempio n. 1
0
 def __init__(self, value):
     """Store *value* and select the basic-package descriptor for it.

     ``basestring`` values get the ``String`` descriptor; every other value
     gets the ``Float`` descriptor.
     """
     self.value = value
     # NOTE(review): non-string values are assumed to be floats -- confirm
     # callers never pass anything else.
     if isinstance(value, basestring):
         type_name = 'String'
     else:
         type_name = 'Float'
     registry = get_module_registry()
     self.type = registry.get_descriptor_by_name(
         'org.vistrails.vistrails.basic', type_name)
Esempio n. 2
0
    def make_port_specs(self):
        """Rebuild the port-spec caches from the pipeline's port modules.

        Each basic-package ``InputPort``/``OutputPort`` module in
        ``self.pipeline`` produces one port spec. The spec is stored in
        ``self._port_specs`` under ``(port_name, 'input'|'output')``,
        appended to ``_input_port_specs``/``_output_port_specs``, and the
        port module is recorded in ``_input_remap``/``_output_remap`` under
        the port name. All caches are reset first and stay empty when there
        is no pipeline.
        """
        self._port_specs = {}
        self._input_port_specs = []
        self._output_port_specs = []
        self._input_remap = {}
        self._output_remap = {}
        if self.pipeline is None:
            return

        # Port module name -> (direction, spec list, remap dict) it feeds.
        targets = {
            'OutputPort': ('output', self._output_port_specs,
                           self._output_remap),
            'InputPort': ('input', self._input_port_specs,
                          self._input_remap),
        }
        registry = get_module_registry()
        for module in self.pipeline.module_list:
            if module.name not in targets or module.package != basic_pkg:
                continue
            direction, spec_list, remap = targets[module.name]
            port_name, sigstring, optional, depth, _ = \
                self.get_port_spec_info(module)
            port_spec = registry.create_port_spec(port_name, direction,
                                                  None, sigstring,
                                                  optional, depth=depth)
            self._port_specs[(port_name, direction)] = port_spec
            spec_list.append(port_spec)
            remap[port_name] = module
Esempio n. 3
0
 def _get_module_descriptor(self):
     """Return the module descriptor, re-resolving a dead weak reference.

     The descriptor is cached as a ``weakref.ref``. When there is no cached
     reference, or its target has been collected, the descriptor is looked
     up again in the registry from ``self.descriptor_info``.
     """
     ref = self._module_descriptor
     if ref is None or ref() is None:
         registry = get_module_registry()
         descriptor = registry.get_descriptor_by_name(*self.descriptor_info)
         self._module_descriptor = weakref.ref(descriptor)
     return self._module_descriptor()
Esempio n. 4
0
 def updateModule(self, module):
     """Show the configuration widget for *module*.

     Does nothing while ``self.updateLocked`` is set. Otherwise it:
     checks for unsaved changes, records *module*, and clears the
     container. When both *module* and ``self.controller`` are set, it
     builds the module's registered configuration widget (or
     ``DefaultModuleConfigurationWidget`` when none is registered) and
     connects its ``doneConfigure``/``stateChanged`` signals. Finally it
     resets the change flag and the window title.
     """
     if self.updateLocked:
         return
     self.check_need_save_changes()
     self.module = module
     # Freeze and hide the container while its content is swapped.
     self.confWidget.setUpdatesEnabled(False)
     self.confWidget.setVisible(False)
     self.confWidget.clear()
     if module and self.controller:
         lookup = get_module_registry().get_configuration_widget
         try:
             widget_class = lookup(module.package, module.name,
                                   module.namespace)
         except ModuleRegistryException:
             # No specific configuration widget for this module.
             widget_class = None
         if not widget_class:
             widget_class = DefaultModuleConfigurationWidget
         widget = widget_class(module, self.controller)

         self.confWidget.setUpWidget(widget)
         self.connect(widget, QtCore.SIGNAL("doneConfigure"),
                      self.configureDone)
         self.connect(widget, QtCore.SIGNAL("stateChanged"),
                      self.stateChanged)
     self.confWidget.setUpdatesEnabled(True)
     self.confWidget.setVisible(True)
     self.hasChanges = False
     # Reset the title in case it was showing pending changes.
     self.setWindowTitle("Module Configuration")
Esempio n. 5
0
    def inspect_spreadsheet_cells(self, pipeline):
        """ inspect_spreadsheet_cells(pipeline: Pipeline) -> None
        Collect the id paths of all spreadsheet cell modules in *pipeline*.

        ``self.spreadsheet_cells`` is reset and filled with lists of module
        ids: ``[mId]`` for top-level cells, or ``[sub_id, ..., mId]`` for
        cells inside subworkflows (found with ``self.find_subworkflows``).
        It stays empty for a falsy pipeline.
        """
        registry = get_module_registry()
        self.spreadsheet_cells = []
        if not pipeline:
            return

        def visit(current, prefix):
            # The spreadsheet package is not always loaded.
            spreadsheet_pkg = \
                '%s.spreadsheet' % get_vistrails_default_pkg_prefix()
            if registry.has_module(spreadsheet_pkg, 'SpreadsheetCell'):
                cell_class = registry.get_descriptor_by_name(
                    spreadsheet_pkg, 'SpreadsheetCell').module
                for module_id, module in current.modules.iteritems():
                    descriptor = registry.get_descriptor_by_name(
                        module.package, module.name, module.namespace)
                    if issubclass(descriptor.module, cell_class):
                        self.spreadsheet_cells.append(prefix + [module_id])

            # Descend into every subworkflow that has a pipeline.
            for sub_id in self.find_subworkflows(current):
                sub_pipeline = current.modules[sub_id].pipeline
                if sub_pipeline is not None:
                    visit(sub_pipeline, prefix + [sub_id])

        visit(pipeline, [])
Esempio n. 6
0
    def can_convert(cls, sub_descs, super_descs):
        """Tell whether this converter maps *sub_descs* onto *super_descs*.

        The ``in_value`` input port must accept *sub_descs*, and the
        ``out_value`` output port must produce *super_descs*. In each case
        the descriptor lists must have the same length and be pairwise
        subclass-compatible. A ``Variant`` on either side of a pair always
        matches.
        """
        from vistrails.core.modules.module_registry import get_module_registry
        from vistrails.core.system import get_vistrails_basic_pkg_id
        reg = get_module_registry()
        basic_pkg = get_vistrails_basic_pkg_id()
        variant_desc = reg.get_descriptor_by_name(basic_pkg, 'Variant')
        desc = reg.get_descriptor(cls)

        def compatible(subs, supers):
            # Lengths are compared separately; izip stops at the shorter.
            return all(sub == variant_desc or sup == variant_desc or
                       reg.is_descriptor_subclass(sub, sup)
                       for sub, sup in izip(subs, supers))

        in_descs = reg.get_port_spec_from_descriptor(
            desc, 'in_value', 'input').descriptors()
        if len(sub_descs) != len(in_descs) or \
                not compatible(sub_descs, in_descs):
            return False
        # Only look the output port up once the input side matched.
        out_descs = reg.get_port_spec_from_descriptor(
            desc, 'out_value', 'output').descriptors()
        return (len(out_descs) == len(super_descs) and
                compatible(out_descs, super_descs))
    def register_self(cls, **kwargs):
        """Register *cls* with the module registry, along with its ports.

        *kwargs* are forwarded to ``registry.add_module``. The optional
        class attributes ``input_ports`` and ``output_ports`` are sequences
        of ``(port_name, types)`` pairs. Each type is either a class, used
        as-is, or a tuple of arguments for ``get_descriptor_by_name``, whose
        descriptor's module is used.

        Raises:
            TypeError: if a port type is neither a tuple nor a class.
        """
        registry = get_module_registry()

        def resolve_type(t):
            if isinstance(t, tuple):
                return registry.get_descriptor_by_name(*t).module
            if isinstance(t, type):
                return t
            # Explicit raise rather than ``assert False``: an assert is
            # stripped under -O and would silently register a None type.
            raise TypeError("Unknown type " + str(type(t)))

        registry.add_module(cls, **kwargs)
        try:
            ips = cls.input_ports
        except AttributeError:
            pass
        else:
            for (port_name, types) in ips:
                registry.add_input_port(cls, port_name,
                                        [resolve_type(t) for t in types])

        try:
            ops = cls.output_ports
        except AttributeError:
            pass
        else:
            for (port_name, types) in ops:
                registry.add_output_port(cls, port_name,
                                         [resolve_type(t) for t in types])
Esempio n. 8
0
    def find_descriptor(controller, pipeline, module_id, desired_version=''):
        """Find the registry descriptor for module *module_id* in *pipeline*.

        An abstraction upgrade registered for the module's descriptor info
        takes precedence and is returned as-is. Otherwise the descriptor is
        looked up by package, name and namespace. If that lookup fails and
        the module's package can handle missing modules, the package is
        asked to handle it and the lookup is retried once.

        Returns:
            The descriptor (or abstraction upgrade), or None when it cannot
            be found.
        """
        reg = get_module_registry()

        get_descriptor = reg.get_descriptor_by_name
        pm = get_package_manager()
        invalid_module = pipeline.modules[module_id]
        mpkg, mname, mnamespace = (invalid_module.package,
                                   invalid_module.name,
                                   invalid_module.namespace)
        pkg = pm.get_package(mpkg)
        # NOTE(review): this discards the caller's desired_version, so the
        # parameter currently has no effect; kept to preserve behavior.
        desired_version = ''
        d = None
        # don't check for abstraction/subworkflow since the old module
        # could be a subworkflow
        if reg.has_abs_upgrade(*invalid_module.descriptor_info):
            return reg.get_abs_upgrade(*invalid_module.descriptor_info)

        try:
            try:
                d = get_descriptor(mpkg, mname, mnamespace, '', desired_version)
            except MissingModule as e:
                r = None
                if pkg.can_handle_missing_modules():
                    r = pkg.handle_missing_module(controller, module_id,
                                                  pipeline)
                    d = get_descriptor(mpkg, mname, mnamespace, '',
                                       desired_version)
                if not r:
                    raise e
        except MissingModule:
            return None
        # The descriptor was found, possibly after the package handled the
        # missing module. Without this return the lookup result was
        # discarded and callers always got None.
        return d
Esempio n. 9
0
    def check_port_spec(module, port_name, port_type, descriptor=None,
                        sigstring=None):
        """Check that *module* has a *port_type* port named *port_name*.

        When *descriptor* declares the port, its signature must equal
        *sigstring* after old package identifiers in *sigstring* are
        rewritten to the packages' current identifiers. Otherwise (no
        descriptor, or the descriptor lacks the port), *module* itself must
        have a port spec with that name and type.

        Raises:
            UpgradeWorkflowError: if the signatures differ or the port is
                missing.
        """
        basic_pkg = get_vistrails_basic_pkg_id()
        reg = get_module_registry()
        found = False
        try:
            if descriptor is not None:
                declared = reg.get_port_spec_from_descriptor(
                    descriptor, port_name, port_type)
                found = True

                # NOTE(review): sigstring defaults to None and is passed
                # straight to the parser -- confirm callers always supply it
                # together with a descriptor.
                spec_tuples = parse_port_spec_string(sigstring, basic_pkg)
                for pos, spec_tuple in enumerate(spec_tuples):
                    pkg_id = spec_tuple[0]
                    current_id = reg.get_package_by_name(pkg_id).identifier
                    if current_id != pkg_id:
                        # Replace an old package identifier by the current one.
                        spec_tuples[pos] = (current_id,) + spec_tuple[1:]
                sigstring = create_port_spec_string(spec_tuples)
                if declared.sigstring != sigstring:
                    msg = ('%s port "%s" of module "%s" exists, but '
                           'signatures differ "%s" != "%s"') % \
                          (port_type.capitalize(), port_name, module.name,
                           declared.sigstring, sigstring)
                    raise UpgradeWorkflowError(msg, module, port_name,
                                               port_type)
        except MissingPort:
            pass

        if not found and \
                not module.has_portSpec_with_name((port_name, port_type)):
            msg = '%s port "%s" of module "%s" does not exist.' % \
                (port_type.capitalize(), port_name, module.name)
            raise UpgradeWorkflowError(msg, module, port_name, port_type)
Esempio n. 10
0
 def loadWidget(self, pipeline):
     """Build a widget with one labelled editor per visible pipeline alias.

     Aliases returned by ``self.plot.computeHiddenAliases()`` are skipped.
     Each remaining alias gets a row with a QLabel showing its name and an
     editor for the aliased parameter. The editor class comes from
     ``get_widget_class`` for the parameter's registry module, or is
     ``StandardConstantWidget`` when no such module exists. Editors are also
     stored in ``self.alias_widgets`` under the alias name.
     """
     from PyQt4 import QtGui
     container = QtGui.QWidget()
     rows = QtGui.QVBoxLayout()
     hidden = self.plot.computeHiddenAliases()
     registry = get_module_registry()
     for alias_name, (obj_type, obj_id, _, _, _) in \
             pipeline.aliases.iteritems():
         if alias_name in hidden:
             continue
         param = pipeline.db_get_object(obj_type, obj_id)
         identifier = param.identifier
         if identifier == '':
             # An empty identifier refers to the basic package.
             identifier = 'edu.utah.sci.vistrails.basic'
         p_module = registry.get_module_by_name(identifier, param.type,
                                                param.namespace)
         if p_module is None:
             editor_class = StandardConstantWidget
         else:
             editor_class = get_widget_class(p_module)
         editor = editor_class(param, None)
         row = QtGui.QHBoxLayout()
         row.addWidget(QtGui.QLabel(alias_name))
         row.addWidget(editor)
         rows.addLayout(row)
         self.alias_widgets[alias_name] = editor
     container.setLayout(rows)
     return container
Esempio n. 11
0
 def compute(self):
     """Output a copy of the input transfer function with an updated range.

     The scalar range comes from the first available source: the producer
     output of the 'Input' port, the 'Dataset' port's VTK object, or the
     'Range' port. Besides 'TransferFunction', the derived VTK opacity and
     color functions are wrapped in new vtkPiecewiseFunction /
     vtkColorTransferFunction modules and set as outputs.
     """
     reg = get_module_registry()
     new_tf = copy.copy(self.getInputFromPort('TransferFunction'))
     if self.hasInputFromPort('Input'):
         port = self.getInputFromPort('Input')
         producer = port.vtkInstance.GetProducer()
         scalar_range = producer.GetOutput(
             port.vtkInstance.GetIndex()).GetScalarRange()
     elif self.hasInputFromPort('Dataset'):
         dataset = self.getInputFromPort('Dataset').vtkInstance
         scalar_range = dataset.GetScalarRange()
     else:
         scalar_range = self.getInputFromPort('Range')
     (new_tf._min_range, new_tf._max_range) = scalar_range

     self.setResult('TransferFunction', new_tf)
     opacity_tf, color_tf = new_tf.get_vtk_transfer_functions()

     of_module = reg.get_descriptor_by_name(vtk_pkg_identifier,
                                            'vtkPiecewiseFunction').module()
     of_module.vtkInstance = opacity_tf
     cf_module = reg.get_descriptor_by_name(vtk_pkg_identifier,
                                            'vtkColorTransferFunction').module()
     cf_module.vtkInstance = color_tf

     # NOTE(review): 'vtkPicewiseFunction' (sic) presumably matches the
     # declared output port name; do not rename it here alone.
     self.setResult('vtkPicewiseFunction', of_module)
     self.setResult('vtkColorTransferFunction', cf_module)
Esempio n. 12
0
 def updateController(self, controller):
     """ updateController(controller: VistrailController) -> None
     Rebuild the input forms for a controller's vistrail variables.

     Nothing happens unless *controller* differs from the current one.
     The new controller is stored even while ``self.updateLocked`` is set,
     but in that case the forms are left untouched. Variables whose module
     descriptor is missing are reported via ``debug.critical`` and skipped.
     """
     # we shouldn't do this whenever the controller changes...
     if self.controller != controller:
         self.controller = controller
         if self.updateLocked:
             return
         self.vWidget.clear()
         if not controller:
             self.vWidget.showPrompt(False)
             return
         reg = module_registry.get_module_registry()
         for var in list(controller.vistrail.vistrail_vars):
             try:
                 descriptor = reg.get_descriptor_by_name(var.package,
                                                         var.module,
                                                         var.namespace)
             except module_registry.ModuleRegistryException:
                 debug.critical("Missing Module Descriptor for vistrail"
                                " variable %s\nPackage: %s\nType: %s"
                                "\nNamespace: %s" %
                                (var.name, var.package, var.module,
                                 var.namespace))
                 continue
             self.vWidget.addVariable(var.uuid, var.name, descriptor,
                                      var.value)
         self.vWidget.showPromptByChildren()
Esempio n. 13
0
 def compute(self):
     """Reuse, resume or start this module's job, then publish its results.

     The job is keyed by ``self.getId(params)``. A cached result is applied
     directly. Otherwise a running job with that signature is resumed, or
     a new job is started and registered with the JobMonitor. The job's
     visible name is the module's ``__desc__`` annotation, or its
     descriptor name when there is none. After the monitor has checked the
     job, it is finished, its results are set, and they are cached.
     """
     params = self.readInputs()
     signature = self.getId(params)
     monitor = JobMonitor.getInstance()

     cached = monitor.getCache(signature)
     if cached:
         self.setResults(cached.parameters)
         return

     running = monitor.getJob(signature)
     if running:
         # Resume with the parameters stored for the running job.
         params = running.parameters
     else:
         params = self.startJob(params)
         # A user-given description takes precedence as the visible name.
         pipeline_module = \
             self.interpreter._persistent_pipeline.modules[self.id]
         if '__desc__' in pipeline_module.db_annotations_key_index:
             name = pipeline_module.get_annotation_by_key(
                 '__desc__').value.strip()
         else:
             name = get_module_registry().get_descriptor(
                 self.__class__).name
         monitor.addJob(signature, params, name)
     monitor.checkJob(self, signature, self.getMonitor(params))
     # Once checkJob returns, the job is treated as finished.
     params = self.finishJob(params)
     self.setResults(params)
     monitor.setCache(signature, params)
Esempio n. 14
0
def handle_module_upgrade_request(controller, module_id, pipeline):
    """Upgrade an outdated module of this package found in *pipeline*.

    A ``read`` namespace ``JSONFile`` module with any version other than
    "0.1.5" is replaced by a ``JSONObject`` module whose ``key_name``
    function is set to ``"_key"``, and the list of upgrade actions is
    returned. Every other module is handed to
    ``UpgradeWorkflowHandler.remap_module`` with the ``module_remap``
    table below, keyed by namespace-qualified module name.
    """
    old_module = pipeline.modules[module_id]
    is_old_json_file = (old_module.name == "JSONFile" and
                        old_module.version != "0.1.5" and
                        old_module.namespace == "read")
    if is_old_json_file:
        from vistrails.core.db.action import create_action
        from vistrails.core.modules.module_registry import get_module_registry
        from .read.read_json import JSONObject

        reg = get_module_registry()
        new_module = controller.create_module_from_descriptor(
            reg.get_descriptor(JSONObject),
            old_module.location.x,
            old_module.location.y)
        actions = UpgradeWorkflowHandler.replace_generic(
            controller, pipeline, old_module, new_module)

        key_function = controller.create_function(new_module, "key_name",
                                                  ["_key"])
        actions.append(create_action(
            [("add", key_function, "module", new_module.id)]))
        return actions

    self_to_value = {"src_port_remap": {"self": "value"}}
    module_remap = {
        "read|csv|CSVFile": [
            (None, "0.1.1", "read|CSVFile",
             {"src_port_remap": {"self": "value"}}),
        ],
        "read|numpy|NumPyArray": [
            (None, "0.1.1", "read|NumPyArray",
             {"src_port_remap": {"self": "value"}}),
        ],
        "read|CSVFile": [
            ("0.1.1", "0.1.2", None, {"src_port_remap": {"self": "value"}}),
            ("0.1.3", "0.1.5", None, {}),
        ],
        "read|NumPyArray": [
            ("0.1.1", "0.1.2", None, {"src_port_remap": {"self": "value"}}),
        ],
        "read|ExcelSpreadsheet": [
            ("0.1.1", "0.1.2", None, self_to_value),
            ("0.1.3", "0.1.4", None, {}),
        ],
        "read|JSONFile": [
            (None, "0.1.5", "read|JSONObject"),
        ],
    }

    return UpgradeWorkflowHandler.remap_module(controller, module_id,
                                               pipeline, module_remap)
Esempio n. 15
0
 def __subclasscheck__(self, other):
     """Implement ``issubclass(other, self)`` using registry descriptors.

     Python calls this on the metaclass, so ``self`` is the candidate base
     class and *other* is the class being tested. Classes that are not
     ModuleClass instances are never subclasses.

     Raises:
         TypeError: if *other* is not a class.
     """
     # *other* must be a class, i.e. an instance of ``type``. The former
     # ``issubclass(other, type)`` only accepted metaclasses, so every
     # ordinary class was rejected with TypeError.
     if not isinstance(other, type):
         raise TypeError("issubclass() arg 1 must be a class")
     if not isinstance(other, ModuleClass):
         return False
     reg = get_module_registry()
     # is_descriptor_subclass takes (sub, super): *other* is the candidate
     # subclass and *self* the base, not the other way round.
     return reg.is_descriptor_subclass(other.descriptor, self.descriptor)
Esempio n. 16
0
 def _get_base_descriptor(self):
     """Return the base descriptor, resolving it lazily from its id.

     The registry's ``descriptors_by_id`` is consulted only when no
     descriptor has been resolved yet and ``base_descriptor_id`` is
     non-negative. Otherwise the stored value (possibly None) is returned.
     """
     if self._base_descriptor is not None or self.base_descriptor_id < 0:
         return self._base_descriptor
     from vistrails.core.modules.module_registry import get_module_registry
     registry = get_module_registry()
     self._base_descriptor = \
         registry.descriptors_by_id[self.base_descriptor_id]
     return self._base_descriptor
Esempio n. 17
0
    def collectParameterActions(self):
        """ collectParameterActions() -> list
        Return a list of action lists corresponding to each dimension
        
        """
        if not self.pipeline:
            return None

        reg = get_module_registry()
        parameterValues = [[], [], [], []]
        counts = self.label.getCounts()
        for i in xrange(self.layout().count()):
            pEditor = self.layout().itemAt(i).widget()
            if pEditor and isinstance(pEditor, QParameterSetEditor):
                for paramWidget in pEditor.paramWidgets:
                    editor = paramWidget.editor
                    interpolator = editor.stackedEditors.currentWidget()
                    paramInfo = paramWidget.param
                    dim = paramWidget.getDimension()
                    if dim in [0, 1, 2, 3]:
                        count = counts[dim]
                        values = interpolator.get_values(count)
                        if not values:
                            return None
                        pId = paramInfo.id
                        pType = paramInfo.dbtype
                        parentType = paramInfo.parent_dbtype
                        parentId = paramInfo.parent_id
                        function = self.pipeline.db_get_object(parentType,
                                                               parentId)
                        old_param = self.pipeline.db_get_object(pType,pId)
                        pName = old_param.name
                        pAlias = old_param.alias
                        pIdentifier = old_param.identifier
                        actions = []
                        tmp_id = -1L
                        for v in values:
                            getter = reg.get_descriptor_by_name
                            desc = getter(paramInfo.identifier,
                                          paramInfo.type,
                                          paramInfo.namespace)
                            if not isinstance(v, str):
                                str_value = desc.module.translate_to_string(v)
                            else:
                                str_value = v
                            new_param = ModuleParam(id=tmp_id,
                                                    pos=old_param.pos,
                                                    name=pName,
                                                    alias=pAlias,
                                                    val=str_value,
                                                    type=paramInfo.type,
                                                    identifier=pIdentifier
                                                    )
                            action_spec = ('change', old_param, new_param,
                                           parentType, function.real_id)
                            action = vistrails.core.db.action.create_action([action_spec])
                            actions.append(action)
                        parameterValues[dim].append(actions)
                        tmp_id -= 1
        return [zip(*p) for p in parameterValues]
Esempio n. 18
0
def get_module_registry():
    """Return the module registry, caching it after the first call.

    The real accessor from ``vistrails.core.modules.module_registry`` is
    imported inside the function on first use. Its result is kept in the
    module-level ``_registry`` global and returned on later calls.
    """
    global _registry
    if _registry is not None:
        return _registry
    from vistrails.core.modules.module_registry import \
        get_module_registry as _real_get_module_registry
    _registry = _real_get_module_registry()
    return _registry
Esempio n. 19
0
    def init(self):
        """Initial setup of the Manager.

        Creates the DAT notifications for added and removed plots, loaders
        and operations. Subscribes ``new_package``/``deleted_package`` to
        package registration and removal. Finally passes every package that
        is already registered to ``self.new_package``.
        """
        app = get_vistrails_application()

        # Payloads: *_plot -> (plot: Plot),
        #           *_loader -> (loader: BaseVariableLoader),
        #           *_operation -> (loader: VariableOperation)
        for notification in ('dat_new_plot', 'dat_removed_plot',
                             'dat_new_loader', 'dat_removed_loader',
                             'dat_new_operation', 'dat_removed_operation'):
            app.create_notification(notification)

        app.register_notification("reg_new_package", self.new_package)
        app.register_notification("reg_deleted_package", self.deleted_package)

        # Load the Plots and VariableLoaders from already-loaded packages.
        for package in get_module_registry().package_list:
            self.new_package(package.identifier)
Esempio n. 20
0
    def createEditor(self, parent, option, index):
        """Create the editor for a cell of the port table.

        Column 2 (depth) gets a QSpinBox starting at 0. Column 1 (port type)
        gets a CompletingComboBox that lists every registered module,
        grouped under bold, non-selectable package headers, with
        ``Variant`` as the default item. Other columns use the default
        QItemDelegate editor.
        """
        registry = get_module_registry()
        column = index.column()
        if column == 2:  # Depth type
            spinbox = QtGui.QSpinBox(parent)
            spinbox.setValue(0)
            return spinbox
        if column != 1:
            return QtGui.QItemDelegate.createEditor(self, parent, option, index)

        # Port type
        combo = CompletingComboBox(parent)
        # FIXME just use descriptors here!!
        variant_desc = registry.get_descriptor_by_name(
            get_vistrails_basic_pkg_id(), 'Variant')
        for _, pkg in sorted(registry.packages.iteritems()):
            header = QtGui.QStandardItem("----- %s -----" % pkg.name)
            header.setData('', QtCore.Qt.UserRole)
            # Package headers only separate groups; they cannot be chosen.
            header.setFlags(header.flags() & ~(
                    QtCore.Qt.ItemIsEnabled | QtCore.Qt.ItemIsSelectable))
            font = header.font()
            font.setBold(True)
            header.setFont(font)
            combo.model().appendRow(header)
            for _, descriptor in sorted(pkg.descriptors.iteritems()):
                if descriptor is variant_desc:
                    # Index at which the Variant entry is about to be added.
                    variant_index = combo.count()
                combo.addItem("%s (%s)" % (descriptor.name,
                                           descriptor.identifier),
                              descriptor.sigstring)

        combo.select_default_item(variant_index)
        return combo
Esempio n. 21
0
 def inspect_input_output_ports(self, pipeline):
     """ inspect_input_output_ports(pipeline: Pipeline) -> None
     Record the pipeline's InputPort/OutputPort modules (used for submodules).

     For each connection leaving an 'InputPort' module,
     ``self.input_ports[port_module_id]`` is set to ``(name, spec[0])``,
     where ``spec`` is ``registry.getInputPortSpec`` of the connected port.
     ``self.input_port_by_name[name]`` maps back to the port module id.
     The name is the port module's own name, or the connected port's name
     when that is empty. 'OutputPort' modules are handled the same way with
     ``getOutputPortSpec``. All four dicts are reset first and stay empty
     for a falsy pipeline.
     """
     registry = get_module_registry()
     self.input_ports = {}
     self.input_port_by_name = {}
     self.output_ports = {}
     self.output_port_by_name = {}
     if not pipeline:
         return
     modules = pipeline.modules
     for conn in pipeline.connections.itervalues():
         source, destination = conn.source, conn.destination
         src_module = modules[source.moduleId]
         dst_module = modules[destination.moduleId]
         if src_module.name == 'InputPort':
             spec = registry.getInputPortSpec(dst_module, destination.name)
             name = self.get_port_name(src_module)
             if name == '':
                 name = destination.name
             self.input_ports[src_module.id] = (name, spec[0])
             self.input_port_by_name[name] = src_module.id
         if dst_module.name == 'OutputPort':
             spec = registry.getOutputPortSpec(src_module, source.name)
             name = self.get_port_name(dst_module)
             if name == '':
                 name = source.name
             self.output_ports[dst_module.id] = (name, spec[0])
             self.output_port_by_name[name] = dst_module.id
Esempio n. 22
0
    def __init__(self, param, size, parent=None):
        """ QParameterWidget(param: ParameterInfo, size: int, parent: QWidget)
                             -> QParameterWidget
        Build the row widget for one explored parameter.

        From left to right the row holds: a fixed spacing, a 50px label
        showing ``param.spec.module``, a QParameterEditor for *param*, and a
        QDimensionSelector whose fifth radio button (index 4) is connected
        to ``disableParameter``.
        """
        QtGui.QWidget.__init__(self, parent)
        self.param = param
        self.prevWidget = 0

        hLayout = QtGui.QHBoxLayout(self)
        hLayout.setMargin(0)
        hLayout.setSpacing(0)
        self.setLayout(hLayout)

        # Fixed leading indent (5 + 16 + 5 px).
        hLayout.addSpacing(5+16+5)

        self.label = QtGui.QLabel(param.spec.module)
        self.label.setFixedWidth(50)
        hLayout.addWidget(self.label)

        # Invariant: the parameter's module class must be a Constant.
        # (The unused registry lookup that used to precede this was removed.)
        module = param.spec.descriptor.module
        assert issubclass(module, Constant)

        self.editor = QParameterEditor(param, size)
        hLayout.addWidget(self.editor)

        self.selector = QDimensionSelector()
        self.connect(self.selector.radioButtons[4],
                     QtCore.SIGNAL('toggled(bool)'),
                     self.disableParameter)
        hLayout.addWidget(self.selector)
Esempio n. 23
0
    def update_output_modules(self, *args, **kwargs):
        """Sync the mode widgets with the currently registered output modes.

        Collects the modes declared through ``_output_modes_dict`` by every
        registered subclass of the basic ``OutputModule``. For each mode
        without a widget it adds an OutputModeConfigurationWidget, using
        the ``outputDefaultSettings`` entry for that mode when one exists.
        It removes the widgets of modes that are no longer registered. Extra
        arguments are ignored (presumably so this can be used as a
        notification callback -- TODO confirm).
        """
        # need to find all currently loaded output modes (need to
        # check after modules are loaded and spin through registry)
        # and display them here
        reg = get_module_registry()
        output_d = reg.get_descriptor_by_name(get_vistrails_basic_pkg_id(),
                                              "OutputModule")
        modes = {}
        for d in reg.get_descriptor_subclasses(output_d):
            if hasattr(d.module, '_output_modes_dict'):
                for mode_type, (mode, _) in \
                        d.module._output_modes_dict.iteritems():
                    modes[mode_type] = mode

        for mode_type, mode in modes.iteritems():
            if mode_type in self.mode_widgets:
                continue
            mode_config = None
            output_settings = self.persistent_config.outputDefaultSettings
            if output_settings.has(mode_type):
                mode_config = getattr(output_settings, mode_type)
            widget = OutputModeConfigurationWidget(mode, mode_config)
            widget.fieldChanged.connect(self.field_was_changed)
            self.inner_layout.addWidget(widget)
            self.mode_widgets[mode_type] = widget

        # Iterate over a copy because entries are deleted inside the loop.
        for mode_type, widget in list(self.mode_widgets.items()):
            if mode_type not in modes:
                self.inner_layout.removeWidget(widget)
                # removeWidget() only detaches the widget from the layout.
                # It stays parented and visible until it is deleted.
                widget.deleteLater()
                del self.mode_widgets[mode_type]
    def theFunction(src, dst):
        """Wire the connection ``conn`` from module *src* into module *dst*.

        NOTE(review): ``conn`` is not a parameter. It is resolved from a
        scope not visible here -- confirm it is always bound.

        On first use, the basic Variant and InputPort descriptors are looked
        up and cached in module globals. The source output port is then
        enabled, and *dst* receives a ModuleConnector with one typecheck
        flag per source descriptor. The flags come from the
        errorOnConnectionTypeerror / errorOnVariantTypeerror settings and
        are a single False when *src* is an InputPort module.
        """
        global Variant_desc, InputPort_desc
        if Variant_desc is None:
            registry = get_module_registry()
            Variant_desc = registry.get_descriptor_by_name(
                    'org.vistrails.vistrails.basic', 'Variant')
            InputPort_desc = registry.get_descriptor_by_name(
                    'org.vistrails.vistrails.basic', 'InputPort')

        destination = conn.destination
        source = conn.source
        in_port_name = destination.name
        out_port_name = source.name
        src.enableOutputPort(out_port_name)
        conf = get_vistrails_configuration()
        strict_others = getattr(conf, 'errorOnConnectionTypeerror')
        strict_variant = (strict_others or
                          getattr(conf, 'errorOnVariantTypeerror'))
        if isinstance(src, InputPort_desc.module):
            typecheck = [False]
        else:
            # Variant-typed descriptors follow the Variant setting, all
            # others the general connection setting.
            typecheck = [strict_variant if desc is Variant_desc
                         else strict_others
                         for desc in source.spec.descriptors()]
        dst.set_input_port(
                in_port_name,
                ModuleConnector(src, out_port_name, destination.spec,
                                typecheck))
Esempio n. 25
0
    def begin_compute(self, obj):
        """Mark *obj* as computing in the view and log the start of its run.

        The module id is mapped through ``self.remap_id``. The logged
        module name is the registry descriptor name of ``obj``'s class.
        """
        module_id = self.remap_id(obj.id)
        self.view.set_module_computing(module_id)
        descriptor = get_module_registry().get_descriptor(obj.__class__)
        self.log.start_execution(obj, module_id, descriptor.name)
Esempio n. 26
0
 def __getattr__(self, attr_name):
     """Return a new API module for this package's module *attr_name*.

     Called only for attributes not found normally. The descriptor is
     looked up in ``self._package`` (empty namespace, the package's
     version), and a module is added from it through the API.

     Raises:
         AttributeError: for special ``__dunder__`` names and for
             ``_package`` itself. Without this, special-method probes
             (e.g. ``__getstate__`` during pickling) would trigger
             registry lookups that fail with non-AttributeError errors,
             and a missing ``_package`` would recurse forever.
     """
     if attr_name == '_package' or (attr_name.startswith('__') and
                                    attr_name.endswith('__')):
         raise AttributeError(attr_name)
     reg = get_module_registry()
     d = reg.get_descriptor_by_name(self._package.identifier,
                                    attr_name, '',
                                    self._package.version)
     vt_api = get_api()
     module = vt_api.add_module_from_descriptor(d)
     return module
Esempio n. 27
0
def registerControl(module, **kwargs):
    """Register a control module so all control modules share one look.

    The module is added to the registry with fixed right and left fringe
    outlines. Extra keyword arguments are forwarded to ``add_module``.
    """
    right_fringe = [(0.0, 0.0), (0.25, 0.5), (0.0, 1.0)]
    left_fringe = [(0.0, 0.0), (0.0, 1.0)]
    registry = get_module_registry()
    registry.add_module(module,
                        moduleRightFringe=right_fringe,
                        moduleLeftFringe=left_fringe,
                        **kwargs)
Esempio n. 28
0
    def summon(self):
        """Instantiate this module's class and hand it this module's state.

        The new instance receives attributes via ``transfer_attrs(self)``,
        and its ``registry`` attribute is set to the global module registry.
        """
        instance = self.module_descriptor.module()
        instance.transfer_attrs(self)

        # FIXME this may not be quite right because we don't have
        # self.registry anymore.  That said, I'm not sure how self.registry
        # would have worked for hybrids...
        instance.registry = get_module_registry()
        return instance
Esempio n. 29
0
def register_self():
    """Register VTKRenderOffscreen and its ports with the module registry.

    Inputs: 'renderer' (the VTK package's vtkRenderer module), and 'width'
    and 'height' (Integer). Output: 'image' (File).
    """
    registry = get_module_registry()
    renderer_module = registry.get_descriptor_by_name(
        vtk_pkg_identifier, 'vtkRenderer').module
    registry.add_module(VTKRenderOffscreen)
    for port_name, port_type in (('renderer', renderer_module),
                                 ('width', Integer),
                                 ('height', Integer)):
        registry.add_input_port(VTKRenderOffscreen, port_name, port_type)
    registry.add_output_port(VTKRenderOffscreen, 'image', File)
Esempio n. 30
0
def initialize(*args, **keywords):
    """Register the Map module with its input ports and 'Result' output.

    Positional and keyword arguments are accepted but not used.
    """
    reg = get_module_registry()
    reg.add_module(Map)
    input_ports = [
        ('FunctionPort', Module),
        ('InputList', List),
        ('InputPort', List),
        ('OutputPort', String),
    ]
    for port_name, port_cls in input_ports:
        reg.add_input_port(Map, port_name, (port_cls, ''))
    reg.add_output_port(Map, 'Result', (List, ''))
Esempio n. 31
0
    def matchQueryParam(self, template, target):
        """ matchQueryParam(template: Param, target: Param) -> bool
        Tell whether *target* satisfies the query *template*.

        If the two parameters differ in type, package identifier or
        namespace, the answer is False and the registry is not consulted.
        Otherwise the decision is delegated to ``query_compute`` on the
        module class registered for that type, using the template's
        ``queryMethod``.
        """
        same_kind = (template.type == target.type
                     and template.identifier == target.identifier
                     and template.namespace == target.namespace)
        if not same_kind:
            return False

        descriptor = get_module_registry().get_descriptor_by_name(
            template.identifier, template.type, template.namespace)
        return descriptor.module.query_compute(target.strValue,
                                               template.strValue,
                                               template.queryMethod)
Esempio n. 32
0
    def testSummonModule(self):
        """Check that summon creates a correct module"""
        from vistrails.core.modules.basic_modules import identifier as basic_pkg

        # A basic-package String module must summon into an instance of
        # the class registered for String.
        module = Module()
        module.name = "String"
        module.package = basic_pkg
        try:
            registry = get_module_registry()
            summoned = module.summon()
            string_cls = registry.get_descriptor_by_name(basic_pkg,
                                                         'String').module
            assert isinstance(summoned, string_cls)
        except NoSummon:
            self.fail("Expected to get a String object, got a NoSummon "
                      "exception")
Esempio n. 33
0
    def setup(self, cell, plot):
        """Rebuild the tab widget with one configuration tab per plot module.

        Stores *cell* and *plot* and clears the existing tabs. Then, for
        every plot module in the cell's pipeline, adds either the custom
        configuration widget the registry provides for it, or a DATPortsList
        of the module's ports. A module with neither gets no tab.
        """
        self.cell = cell
        self.plot = plot
        controller = cell._controller

        # Pipeline information for this cell
        pipeline_info = VistrailManager(controller).get_pipeline(
            cell.cellInfo)

        self.tabWidget.clear()

        plot_modules = get_plot_modules(pipeline_info,
                                        controller.current_pipeline)

        get_config_widget = get_module_registry().get_configuration_widget
        for module in plot_modules:
            try:
                widget_cls = get_config_widget(module.package, module.name,
                                               module.namespace)
            except ModuleRegistryException:
                widget_cls = None

            if widget_cls:
                # The package supplies its own configuration widget
                widget = widget_cls(module, controller)
                self.connect(widget, QtCore.SIGNAL("doneConfigure"),
                             self.configureDone)
                self.connect(widget, QtCore.SIGNAL("stateChanged"),
                             self.stateChanged)
            else:
                # Fall back to a list of ports, but only if there are some
                widget = DATPortsList(self)
                widget.update_module(module)
                if len(widget.port_spec_items) > 0:
                    widget.set_controller(controller)
                else:
                    widget = None

            if widget:
                self.tabWidget.addTab(widget, module.name)
Esempio n. 34
0
    def read_type(pipeline):
        """Return the type of a Variable, read from its pipeline.

        The pipeline must hold exactly one OutputPort module, and its 'name'
        function must be 'value'. The result is that module's 'spec'
        function, passed through ``resolve_descriptor``. In every other case
        None is returned.
        """
        output_port_cls = get_module_registry().get_module_by_name(
            'org.vistrails.vistrails.basic', 'OutputPort')
        outputs = find_modules_by_type(pipeline, [output_port_cls])
        if len(outputs) != 1:
            return None
        output = outputs[0]
        if get_function(output, 'name') != 'value':
            return None
        return resolve_descriptor(get_function(output, 'spec'))
Esempio n. 35
0
 def newPackage(self, package_identifier, prepend=False):
     """Return the tree item for *package_identifier*, creating it if absent.

     A new item is labelled with the package's name from the module
     registry. It is inserted at the top of the tree when *prepend* is true
     and appended at the end otherwise. The tree is currently sorted, so
     *prepend* has no visible effect.
     """
     if package_identifier in self.packages:
         return self.packages[package_identifier]
     label = get_module_registry().packages[package_identifier].name
     item = QPackageTreeWidgetItem(None, label, package_identifier)
     self.packages[package_identifier] = item
     if prepend:
         self.treeWidget.insertTopLevelItem(0, item)
     else:
         self.treeWidget.addTopLevelItem(item)
     return item
Esempio n. 36
0
def get_widget_class(descriptor, widget_type=None, widget_use=None,
                     return_default=True):
    """Return the constant configuration widget class for *descriptor*.

    The registry's constant config widget for (descriptor, widget_type,
    widget_use) is tried first. If there is none and *return_default* is
    true, the descriptor's module class may supply one through
    ``get_widget_class()``. Failing that, StandardConstantEnumWidget (for
    widget_type "enum") or StandardConstantWidget is returned as-is. Any
    other result goes through ``load_cls`` with the descriptor's package
    prefix.
    """
    reg = get_module_registry()
    cls = reg.get_constant_config_widget(descriptor, widget_type, widget_use)
    prefix = get_prefix(reg, descriptor)
    if cls is None and return_default:
        module_cls = descriptor.module
        if module_cls is not None and hasattr(module_cls, 'get_widget_class'):
            cls = module_cls.get_widget_class()
        if cls is None:
            return (StandardConstantEnumWidget if widget_type == "enum"
                    else StandardConstantWidget)
    return load_cls(cls, prefix)
Esempio n. 37
0
    def get_values(self):
        """Send the dialog's current settings to the controller as functions.

        Builds (port name, values) pairs for 'value', the local-copy
        settings ('localPath', 'readLocal', 'writeLocal') and 'ref'. A value
        of None is used where an option does not apply. The list is passed
        to ``self.controller.update_functions(self.module, ...)``.
        """
        from vistrails.core.modules.module_registry import get_module_registry
        PersistentRef = get_module_registry().get_descriptor_by_name(
            persistence_pkg, 'PersistentRef').module

        functions = []
        use_new_file = (self.new_file and self.new_file.get_path() and
                        self.managed_new.isChecked())
        if use_new_file:
            value_args = [self.new_file.get_path()]
        else:
            value_args = None
        functions.append(('value', value_args))

        ref = PersistentRef()
        if self.managed_new.isChecked():
            existing = self.existing_ref
            if existing and not existing._exists:
                # Keep the identity of the known reference whose _exists
                # flag is false
                ref.id, ref.version = existing.id, existing.version
            else:
                ref.id, ref.version = str(uuid.uuid1()), None
            ref.name = str(self.name_edit.text())
            ref.tags = str(self.tags_edit.text())
        elif self.managed_existing.isChecked():
            ref.id, ref.version, ref.name, ref.tags = \
                self.ref_widget.get_info()

        if self.keep_local.isChecked():
            functions.extend([
                ('localPath', [self.local_path.get_path()]),
                ('readLocal', [str(self.r_priority_local.isChecked())]),
                ('writeLocal',
                 [str(self.write_managed_checkbox.isChecked())]),
            ])
        else:
            ref.local_path = None
            functions.extend([('readLocal', None), ('writeLocal', None)])

        functions.append(('ref', [PersistentRef.translate_to_string(ref)]))
        self.controller.update_functions(self.module, functions)
Esempio n. 38
0
def handle_module_upgrade_request(controller, module_id, pipeline):
    """Upgrade an outdated module of this package.

    Version history that matters here:

    * Before 0.0.3, SQLSource's resultSet output had type ListOfElements,
      which no longer exists. In 0.0.3 it was a List.
    * In 0.1.0, SQLSource's output was renamed to ``result`` and became a
      Table. That change is incompatible and no upgrade is possible;
      resultSet is kept for compatibility for now.
    * Up to 0.0.4, DBConnection asked for a password itself. It now has a
      ``password`` input port instead, which a PasswordDialog from the
      dialogs package can feed.

    A DBConnection older than 0.1.0 is therefore replaced by a new
    DBConnection. Its 'password' port is connected to a new PasswordDialog
    labelled 'Server password'. Every other module goes through
    ``UpgradeWorkflowHandler.attempt_automatic_upgrade``.
    """
    old_module = pipeline.modules[module_id]
    needs_password_prompt = (
        old_module.name == 'DBConnection'
        and versions_increasing(old_module.version, '0.1.0'))
    if not needs_password_prompt:
        return UpgradeWorkflowHandler.attempt_automatic_upgrade(
            controller, pipeline, module_id)

    reg = get_module_registry()
    new_module = controller.create_module_from_descriptor(
        reg.get_descriptor(DBConnection))

    # Password prompt: add it, label it, and wire it into the new module
    dialog_desc = reg.get_descriptor_by_name(
        'org.vistrails.vistrails.dialogs', 'PasswordDialog')
    dialog = controller.create_module_from_descriptor(dialog_desc)
    ops = [('add', dialog)]
    ops += controller.update_function_ops(dialog, 'label',
                                          ['Server password'])
    ops.append(('add', controller.create_connection(dialog, 'result',
                                                    new_module, 'password')))

    # Swap the old module for the new one ('self' source port remapped to
    # 'connection'), then apply the password wiring
    upgrade_actions = UpgradeWorkflowHandler.replace_generic(
        controller, pipeline, old_module, new_module,
        src_port_remap={'self': 'connection'})
    return upgrade_actions + [create_action(ops)]
Esempio n. 39
0
def get_query_widget_class(descriptor, widget_type=None):
    """Return the query widget class for *descriptor*.

    A widget registered for the "query" use wins; no standard fallback is
    applied at that step. Next, the descriptor's module class may supply
    one through ``get_query_widget_class()``, loaded with the package
    prefix. Otherwise a DefaultQueryWidget is built. It wraps the
    descriptor's regular constant widget and offers the "==" and "!="
    operators.
    """
    cls = get_widget_class(descriptor, widget_type, "query", False)
    if cls is not None:
        return cls

    module_cls = descriptor.module
    if module_cls is not None and \
       hasattr(module_cls, 'get_query_widget_class'):
        cls = module_cls.get_query_widget_class()
    if cls is not None:
        return load_cls(cls, get_prefix(get_module_registry(), descriptor))

    class DefaultQueryWidget(BaseQueryWidget):
        def __init__(self, param, parent=None):
            BaseQueryWidget.__init__(self,
                                     get_widget_class(descriptor),
                                     ["==", "!="],
                                     param, parent)
    return DefaultQueryWidget
Esempio n. 40
0
    def contextMenuEvent(self, event):
        """Show the owning package's context menu, or defer to the item.

        The package is identified by the top-level QPackageTreeWidgetItem
        above the clicked item. If that package has a context menu, its
        entries are requested for the clicked item: None for the package
        item, '[namespace|]name' for a module item. The entries are shown
        and the event is consumed. Namespace items get no menu at all.
        Errors raised while building the menu are reported as warnings, and
        the event then falls through to the item's own contextMenuEvent.
        """
        item = self.itemAt(event.pos())
        if not item:
            return
        top = item
        while top.parent():
            top = top.parent()
        assert isinstance(top, QPackageTreeWidgetItem)
        package = get_module_registry().packages[top.identifier]
        try:
            if package.has_context_menu():
                if isinstance(item, QPackageTreeWidgetItem):
                    target = None
                elif isinstance(item, QNamespaceTreeWidgetItem):
                    return  # no context menu for namespaces
                elif isinstance(item, QModuleTreeWidgetItem):
                    descriptor = item.descriptor
                    target = descriptor.name
                    if descriptor.namespace:
                        target = '%s|%s' % (descriptor.namespace, target)
                else:
                    assert False, "fell through"
                entries = package.context_menu(target)
                if entries:
                    menu = QtGui.QMenu(self)
                    for label, callback in entries:
                        action = QtGui.QAction(label, self)
                        action.setStatusTip(label)
                        QtCore.QObject.connect(
                            action, QtCore.SIGNAL("triggered()"), callback)
                        menu.addAction(action)
                    menu.exec_(event.globalPos())
                return
        except Exception as e:
            debug.unexpected_exception(e)
            debug.warning("Got exception trying to display %s's "
                          "context menu in the palette: %s\n%s" %
                          (package.name, debug.format_exception(e),
                           traceback.format_exc()))

        item.contextMenuEvent(event, self)
Esempio n. 41
0
    def _get_variables_root(controller=None):
        """Return the version tagged 'dat-vars', creating it first if needed.

        That version is the common base of all DAT variables: a pipeline
        with a single OutputPort module whose 'name' is 'value'. When no
        controller is given, the application's current one is used. On
        return the controller has the root version selected.

        Returns a (controller, root_version, output_module_id) tuple.
        """
        if controller is None:
            controller = get_vistrails_application().get_controller()
            assert controller is not None

        if controller.vistrail.has_tag_str('dat-vars'):
            root_version = controller.vistrail.get_version_number('dat-vars')
        else:
            # Build the 'dat-vars' version on top of version 0
            controller.change_selected_version(0)
            reg = get_module_registry()
            output_desc = reg.get_descriptor_by_name(
                'org.vistrails.vistrails.basic', 'OutputPort')
            out_mod = controller.create_module_from_descriptor(output_desc)
            operations = [('add', out_mod)]
            operations.extend(
                controller.update_function_ops(out_mod, 'name', ['value']))

            action = create_action(operations)
            controller.add_new_action(action)
            root_version = controller.perform_action(action)
            controller.vistrail.set_tag(root_version, 'dat-vars')

        controller.change_selected_version(root_version)
        module_ids = controller.current_pipeline.modules.keys()
        assert len(module_ids) == 1
        return controller, root_version, module_ids[0]
Esempio n. 42
0
def addWidget(packagePath):
    """ addWidget(packagePath: str) -> package
    Add a new widget type to the spreadsheet registry supplying a
    basic set of spreadsheet widgets

    Imports the package at *packagePath* and calls its ``registerWidget``
    with the module registry, ``basic_modules`` and ``basicWidgets``. The
    package is then recorded in the spreadsheet registry and returned. If
    any step raises, the error is logged, the package is skipped and None
    is returned.
    """
    try:
        registry = get_module_registry()
        widget = importReturnLast(packagePath)
        if hasattr(widget, 'widgetName'):
            widgetName = widget.widgetName()
        else:
            widgetName = packagePath
        widget.registerWidget(registry, basic_modules, basicWidgets)
        spreadsheetRegistry.registerPackage(widget, packagePath)
        debug.log('  ==> Successfully import <%s>' % widgetName)
    except Exception as e:
        # Best effort: one broken widget package must not stop the others
        debug.log('  ==> Ignored package <%s>' % packagePath, e)
        widget = None
    return widget
Esempio n. 43
0
    def contextMenuEvent(self, event):
        """Show the owning package's context menu, or defer to the item.

        The top-level item above the clicked one is matched against the
        parent's ``packages`` mapping to find the package identifier. If
        that package provides a context menu name for the item's text, a
        one-entry menu is shown and the event is consumed. Choosing the
        entry calls ``package.callContextMenu`` when available. Exceptions
        are reported as warnings, and the event then falls through to the
        item's own contextMenuEvent.
        """
        # Just dispatches the menu event to the widget item
        item = self.itemAt(event.pos())
        if item:
            # find top level
            p = item
            while p.parent():
                p = p.parent()
            # get package identifier; the parent's packages values are
            # compared against a weak reference to the top-level item
            p_ref = weakref.ref(p)
            identifiers = [
                i for i, j in self.parent().packages.iteritems()
                if j == p_ref
            ]
            if identifiers:
                identifier = identifiers[0]
                registry = get_module_registry()
                package = registry.packages[identifier]
                try:
                    if package.has_contextMenuName():
                        name = package.contextMenuName(str(item.text(0)))
                        if name:
                            act = QtGui.QAction(name, self)
                            act.setStatusTip(name)

                            def callMenu():
                                if package.has_callContextMenu():
                                    package.callContextMenu(
                                        str(item.text(0)))

                            QtCore.QObject.connect(
                                act, QtCore.SIGNAL("triggered()"), callMenu)
                            menu = QtGui.QMenu(self)
                            menu.addAction(act)
                            menu.exec_(event.globalPos())
                        return
                except Exception as e:
                    # Format each argument separately: joining e.args
                    # directly raises TypeError for non-string arguments
                    # (e.g. an errno), which would crash this handler.
                    debug.warning(
                        "Got exception trying to display %s's "
                        "context menu in the palette: %s: %s" %
                        (package.name, type(e).__name__,
                         ', '.join('%s' % (arg,) for arg in e.args)))

            item.contextMenuEvent(event, self)
Esempio n. 44
0
    def inspect_spreadsheet_cells(self, pipeline):
        """ inspect_spreadsheet_cells(pipeline: Pipeline) -> None
        Record in self.spreadsheet_cells every module that needs a cell.

        Each entry is a list of module ids: the ids of the enclosing
        subworkflow modules (outermost first), then the module's own id.
        A module needs a cell when its descriptor is a SpreadsheetCell
        subclass, or an OutputModule subclass whose class has a
        'spreadsheet' mode. Nothing is recorded for an empty pipeline or
        when the spreadsheet package is not loaded.
        """
        self.spreadsheet_cells = []
        if not pipeline:
            return

        registry = get_module_registry()
        spreadsheet_pkg = 'org.vistrails.vistrails.spreadsheet'
        # Sometimes we run without the spreadsheet!
        if not registry.has_module(spreadsheet_pkg, 'SpreadsheetCell'):
            return
        cell_desc = registry.get_descriptor_by_name(spreadsheet_pkg,
                                                    'SpreadsheetCell')
        output_desc = registry.get_descriptor_by_name(
            'org.vistrails.vistrails.basic', 'OutputModule')

        def needs_cell(desc):
            if registry.is_descriptor_subclass(desc, cell_desc):
                return True
            return (registry.is_descriptor_subclass(desc, output_desc) and
                    desc.module.get_mode_class('spreadsheet') is not None)

        def visit(current, path):
            for module_id, module in current.modules.iteritems():
                desc = registry.get_descriptor_by_name(module.package,
                                                       module.name,
                                                       module.namespace)
                if needs_cell(desc):
                    self.spreadsheet_cells.append(path + [module_id])
            for sub_id in self.find_subworkflows(current):
                sub_module = current.modules[sub_id]
                if sub_module.pipeline is not None:
                    visit(sub_module.pipeline, path + [sub_id])

        visit(pipeline, [])
Esempio n. 45
0
    def test_operation_resolution(self):
        # find_operation must pick the most specific overload for the given
        # argument descriptors, or raise InvalidOperation.
        import dat.tests.pkg_test_operations.init as pkg

        from vistrails.core.modules.basic_modules import Integer, String
        from vistrails.packages.URL.init import DownloadFile

        gd = get_module_registry().get_descriptor

        def resolve(name, *module_classes):
            return find_operation(name, [gd(cls) for cls in module_classes])

        self.assertIs(resolve('overload_std', String, Integer),
                      pkg.overload_std_3)

        with self.assertRaises(InvalidOperation) as cm:
            resolve('overload_std', String, DownloadFile)
        self.assertIn("Found no match", cm.exception.message)

        with self.assertRaises(InvalidOperation) as cm:
            resolve('nonexistent')
        self.assertIn("There is no ", cm.exception.message)

        self.assertIs(resolve('overload_custom', pkg.ModC, pkg.ModD),
                      pkg.overload_custom_2)

        with self.assertRaises(InvalidOperation):
            resolve('overload_custom', pkg.ModE, pkg.ModE)

        self.assertIs(resolve('overload_custom', pkg.ModD, pkg.ModD),
                      pkg.overload_custom_1)
Esempio n. 46
0
    def compute(self):
        """Output a copy of the input TransferFunction with a new scalar range.

        The range comes from the first available source:

        * the output behind the 'Input' port (producer output at the
          port's index);
        * the 'Dataset' input;
        * the 'Range' input.

        The copy is set on 'TransferFunction'. The two functions returned
        by its ``get_vtk_transfer_functions()`` are set on
        'vtkPicewiseFunction' and 'vtkColorTransferFunction'.
        """
        tf = self.get_input('TransferFunction')
        # Shallow copy: the range is assigned on the copy, not the input
        new_tf = copy.copy(tf)
        if self.has_input('Input'):
            port = self.get_input('Input')
            output = port.GetProducer().GetOutput(port.GetIndex())
            scalar_range = output.GetScalarRange()
        elif self.has_input('Dataset'):
            scalar_range = self.get_input('Dataset').GetScalarRange()
        else:
            scalar_range = self.get_input('Range')
        (new_tf._min_range, new_tf._max_range) = scalar_range

        self.set_output('TransferFunction', new_tf)
        (of, cf) = new_tf.get_vtk_transfer_functions()

        # NOTE(review): 'Picewise' looks like a typo, but it is the output
        # port name used here; renaming it would presumably break existing
        # connections.
        self.set_output('vtkPicewiseFunction', of)
        self.set_output('vtkColorTransferFunction', cf)
Esempio n. 47
0
    def check_port_spec(module,
                        port_name,
                        port_type,
                        descriptor=None,
                        sigstring=None):
        """Raise UpgradeWorkflowError unless the requested port exists.

        If *descriptor* is given and declares the port, the declared
        signature must equal *sigstring*. Before comparing, each package
        identifier in *sigstring* is replaced by the identifier of the
        package the registry resolves it to. If the descriptor does not
        declare the port (MissingPort), or no descriptor is given, *module*
        itself must have a port spec named (port_name, port_type).
        """
        basic_pkg = get_vistrails_basic_pkg_id()

        reg = get_module_registry()
        found = False
        try:
            if descriptor is not None:
                spec = reg.get_port_spec_from_descriptor(descriptor,
                                                         port_name,
                                                         port_type)
                found = True

                # Normalize old package identifiers to the current ones
                spec_tuples = parse_port_spec_string(sigstring, basic_pkg)
                for idx, spec_tuple in enumerate(spec_tuples):
                    current_id = reg.get_package_by_name(
                        spec_tuple[0]).identifier
                    if current_id != spec_tuple[0]:
                        spec_tuples[idx] = (current_id, ) + spec_tuple[1:]
                sigstring = create_port_spec_string(spec_tuples)
                if spec.sigstring != sigstring:
                    msg = ('%s port "%s" of module "%s" exists, but '
                           'signatures differ "%s" != "%s"') % \
                           (port_type.capitalize(), port_name, module.name,
                            spec.sigstring, sigstring)
                    raise UpgradeWorkflowError(msg, module, port_name,
                                               port_type)
        except MissingPort:
            pass

        if found or module.has_portSpec_with_name((port_name, port_type)):
            return
        msg = '%s port "%s" of module "%s" does not exist.' % \
            (port_type.capitalize(), port_name, module.name)
        raise UpgradeWorkflowError(msg, module, port_name, port_type)
Esempio n. 48
0
    def connect_var(self, vt_var, dest_module, dest_portname):
        """Queue operations connecting vistrail variable *vt_var* to a port.

        Adapted from VistrailController#connect_vistrail_var(). If the
        controller finds no module for the variable, one is created at
        *dest_module*'s location and an 'add' operation for it is queued.
        If a module already exists and ``check_vistrail_var_connected``
        reports it is connected to *dest_portname*, nothing is queued.
        Otherwise an 'add' for a connection from the variable module's
        'value' port is appended to self.operations.
        """
        self._ensure_version()
        var_type_desc = get_module_registry().get_descriptor_by_name(
            vt_var.package, vt_var.module, vt_var.namespace)
        x, y = dest_module.location.x, dest_module.location.y

        ctrl = self.controller
        var_module = ctrl.find_vistrail_var_module(vt_var.uuid)
        if var_module is None:
            var_module = ctrl.create_vistrail_var_module(var_type_desc,
                                                         x, y,
                                                         vt_var.uuid)
            self.operations.append(('add', var_module))
        elif ctrl.check_vistrail_var_connected(var_module, dest_module,
                                               dest_portname):
            return
        self.operations.append(
            ('add', ctrl.create_connection(var_module, 'value',
                                           dest_module, dest_portname)))
Esempio n. 49
0
def registerSelf():
    """ registerSelf() -> None
    Register VTKCell and its input ports with the module registry.

    Also registers vtkRendererToSpreadsheet as an output mode of
    vtkRendererOutput. The VTK-typed input ports are added one at a time.
    A failure on one of them is logged as a warning and does not prevent
    the rest from being registered.
    """
    from base_module import vtkRendererOutput
    vtkRendererOutput.register_output_mode(vtkRendererToSpreadsheet)

    registry = get_module_registry()
    registry.add_module(VTKCell)
    registry.add_input_port(VTKCell, "Location", CellLocation)
    from vistrails.core import debug
    vtk_ports = (("AddRenderer", 'vtkRenderer'),
                 ("SetRenderView", 'vtkRenderView'),
                 ("InteractionHandler", 'vtkInteractionHandler'),
                 ("InteractorStyle", 'vtkInteractorStyle'),
                 ("AddPicker", 'vtkAbstractPicker'))
    for port, vtk_class_name in vtk_ports:
        try:
            signature = '(%s:%s)' % (vtk_pkg_identifier, vtk_class_name)
            registry.add_input_port(VTKCell, port, signature)
        except Exception as e:
            debug.warning(
                "Got an exception adding VTKCell's %s input "
                "port" % port, e)
Esempio n. 50
0
def initialize():
    """Register the TensorFlow wrapper modules with the module registry.

    Registers the base modules, AutoOptimizer plus the optimizers returned
    by ``register_optimizers``, and AutoOperation plus the operations from
    tensorflow's standard ops and its train, nn and image namespaces. Each
    ``register_operations`` call receives the set of names collected so
    far (for ``tensorflow.train``, also the optimizer names), presumably so
    those are not registered again.
    """
    from tensorflow.python.ops import standard_ops

    reg = get_module_registry()

    for mod_cls in base_modules:
        reg.add_module(mod_cls)

    # Optimizers
    reg.add_module(AutoOptimizer)
    optimizers = {'Optimizer'}
    optimizers.update(register_optimizers(reg))

    # Operations
    reg.add_module(AutoOperation)
    ops = set(wrapped)
    ops.update(register_operations(reg, standard_ops, '', ops))
    ops.update(register_operations(reg, tensorflow.train, 'train',
                                   ops | optimizers))
    for namespace in ('nn', 'image'):
        ops.update(register_operations(reg, getattr(tensorflow, namespace),
                                       namespace, ops))
Esempio n. 51
0
    def __init__(self, identifier, version=''):
        """Index the modules of package *identifier* by namespace.

        ``self._namespaces`` is a recursive ``(namespace_dict, modules)``
        pair. ``namespace_dict`` maps each '|'-separated namespace component
        to another such pair, and ``modules`` lists the descriptors that sit
        directly at that level. Every ``modules`` list, at every depth, is
        sorted by descriptor name.
        """
        # NOTE(review): presumably set before the lookup so the attribute
        # exists on a partially built instance.
        self._package = None
        # namespace_dict : {namespace : (namespace_dict, modules)}
        self._namespaces = ({}, [])
        reg = get_module_registry()
        self._package = reg.get_package_by_name(identifier, version)
        for desc in self._package.descriptor_list:
            if desc.namespace:
                namespaces = desc.namespace.split('|')
            else:
                namespaces = []
            cur_namespace, cur_modules = self._namespaces
            for namespace in namespaces:
                if namespace not in cur_namespace:
                    cur_namespace[namespace] = ({}, [])
                cur_namespace, cur_modules = cur_namespace[namespace]
            cur_modules.append(desc)

        # Sort the module lists of all namespaces, nested ones included.
        # An explicit work stack is needed: rebinding the name a ``for``
        # loop iterates over does not extend that loop, so such a loop
        # would only sort the root level.
        pending = [self._namespaces]
        while pending:
            namespaces, modules = pending.pop()
            modules.sort(key=lambda d: d.name)
            pending.extend(namespaces.values())
Esempio n. 52
0
    def setup_pipeline(self, pipeline, **kwargs):
        """setup_pipeline(controller, pipeline, locator, currentVersion,
                          view, aliases, **kwargs)
        Matches a pipeline with the persistent pipeline and creates
        instances of modules that aren't in the cache.

        Recognized keyword options are popped from **kwargs with their
        defaults. Any leftover keyword raises VistrailsInternalError.

        The pipeline is validated (through the controller when one is
        given). Aliases, vistrail variables and params are then applied,
        and the pipeline is merged into the persistent pipeline. For every
        newly added module, an object is summoned and registered in
        self._objects under its persistent id. Each of the module's
        functions is wired to it as a ModuleConnector fed by a Null
        (no params), a Constant (one param) or an InternalTuple of
        Constants (several params).

        Errors while creating a Constant are recorded in ``errors`` (a
        ModuleError keyed by ``i``) and the object's id is appended to
        ``to_delete``.
        NOTE(review): in the code shown here, ``errors``, ``to_delete`` and
        several fetched options (locator, logger, sinks, ...) are not used
        further, and nothing is returned; this excerpt appears truncated.
        """
        def fetch(name, default):
            # Pop the option so leftover (unknown) kwargs can be detected
            return kwargs.pop(name, default)

        controller = fetch('controller', None)
        locator = fetch('locator', None)
        current_version = fetch('current_version', None)
        # The default DummyView() is constructed on every call
        view = fetch('view', DummyView())
        vistrail_variables = fetch('vistrail_variables', None)
        aliases = fetch('aliases', None)
        params = fetch('params', None)
        extra_info = fetch('extra_info', None)
        # NOTE(review): the default is the DummyLogController class itself,
        # not an instance; confirm this is intended.
        logger = fetch('logger', DummyLogController)
        sinks = fetch('sinks', None)
        reason = fetch('reason', None)
        actions = fetch('actions', None)
        done_summon_hooks = fetch('done_summon_hooks', [])
        module_executed_hook = fetch('module_executed_hook', [])
        stop_on_error = fetch('stop_on_error', True)
        parent_exec = fetch('parent_exec', None)
        job_monitor = fetch('job_monitor', None)

        reg = get_module_registry()

        if len(kwargs) > 0:
            raise VistrailsInternalError('Wrong parameters passed '
                                         'to setup_pipeline: %s' % kwargs)

        def create_null():
            """Creates a Null value"""
            getter = reg.get_descriptor_by_name
            descriptor = getter(basic_pkg, 'Null')
            return descriptor.module()

        def create_constant(param, module):
            """Creates a Constant from a parameter spec.

            The constant takes *module*'s id. Its value is param.strValue,
            or the string form of the constant's default_value when
            strValue is empty.
            """
            getter = reg.get_descriptor_by_name
            desc = getter(param.identifier, param.type, param.namespace)
            constant = desc.module()
            constant.id = module.id
            #             if param.evaluatedStrValue:
            #                 constant.setValue(param.evaluatedStrValue)
            if param.strValue != '':
                constant.setValue(param.strValue)
            else:
                constant.setValue(
                    constant.translate_to_string(constant.default_value))
            return constant

        ### BEGIN METHOD ###

#         if self.debugger:
#             self.debugger.update()
        to_delete = []
        errors = {}

        if controller is not None:
            # Controller is none for sub_modules
            controller.validate(pipeline)
        else:
            pipeline.validate()

        self.resolve_aliases(pipeline, aliases)
        if vistrail_variables:
            self.resolve_variables(vistrail_variables, pipeline)

        self.update_params(pipeline, params)

        (tmp_to_persistent_module_map, conn_map, module_added_set,
         conn_added_set) = self.add_to_persistent_pipeline(pipeline)

        # Create the new objects; `i` is a key of
        # tmp_to_persistent_module_map (presumably the module's id in
        # *pipeline*)
        for i in module_added_set:
            persistent_id = tmp_to_persistent_module_map[i]
            module = self._persistent_pipeline.modules[persistent_id]
            obj = self._objects[persistent_id] = module.summon()
            obj.interpreter = self
            obj.id = persistent_id
            obj.signature = module._signature

            # Checking if output should be stored
            if module.has_annotation_with_key('annotate_output'):
                annotate_output = module.get_annotation_by_key(
                    'annotate_output')
                #print annotate_output
                if annotate_output:
                    obj.annotate_output = True

            for f in module.functions:
                connector = None
                if len(f.params) == 0:
                    # No parameter: feed the port from a Null value
                    connector = ModuleConnector(create_null(), 'value',
                                                f.get_spec('output'))
                elif len(f.params) == 1:
                    # One parameter: feed the port from a single Constant
                    p = f.params[0]
                    try:
                        constant = create_constant(p, module)
                        connector = ModuleConnector(constant, 'value',
                                                    f.get_spec('output'))
                    except Exception, e:
                        debug.unexpected_exception(e)
                        err = ModuleError(
                            module,
                            "Uncaught exception creating Constant from "
                            "%r: %s" % (p.strValue, debug.format_exception(e)))
                        errors[i] = err
                        to_delete.append(obj.id)
                else:
                    # Several parameters: one Constant per tuple position,
                    # bundled by an InternalTuple that feeds the port
                    tupleModule = vistrails.core.interpreter.base.InternalTuple(
                    )
                    tupleModule.length = len(f.params)
                    for (j, p) in enumerate(f.params):
                        try:
                            constant = create_constant(p, module)
                            constant.update()
                            connector = ModuleConnector(
                                constant, 'value', f.get_spec('output'))
                            tupleModule.set_input_port(j, connector)
                        except Exception, e:
                            debug.unexpected_exception(e)
                            err = ModuleError(
                                module, "Uncaught exception creating Constant "
                                "from %r: %s" %
                                (p.strValue, debug.format_exception(e)))
                            errors[i] = err
                            to_delete.append(obj.id)
                    # Replaces the per-element connector; the tuple is still
                    # connected even if some element failed above
                    connector = ModuleConnector(tupleModule, 'value',
                                                f.get_spec('output'))
                if connector:
                    obj.set_input_port(f.name, connector, is_method=True)
Esempio n. 53
0
    def execute(self, *args, **kwargs):
        """Execute the pipeline.

        Positional arguments are either input values (created from
        ``module == value``, where `module` is a Module from the pipeline and
        `value` is some value or Function instance) for the pipeline's
        InputPorts, or Module instances (to select sink modules).

        Keyword arguments are also used to set InputPort by looking up inputs
        by name.

        When no input values are given and the owning vistrail's current
        version is this pipeline's version, execution goes through the
        vistrail's controller. Otherwise the pipeline (a copy of it, when
        input values were given) is run by the default interpreter, with one
        constant module connected to each supplied InputPort's
        'ExternalPipe' port.

        Returns an ``ExecutionResults`` wrapping the execution result.

        Raises:
          ExecutionErrors: if the execution result reports errors.
          ValueError: if an InputPort receives more than one value, or a
            ``module == value`` pair targets a module that is not an
            InputPort.
          KeyError: possibly, when a keyword argument does not name an input
            (raised by ``self.get_input``).
          RuntimeError: if the converted value list for an InputPort does not
            have exactly one element (building a tuple is not implemented).

        Example::

           input_bound = pipeline.get_input('higher_bound')
           input_url = pipeline.get_input('url')
           sinkmodule = pipeline.get_module(32)
           pipeline.execute(sinkmodule,
                            input_bound == vt.Function(Integer, 10),
                            input_url == 'http://www.vistrails.org/',
                            resolution=15)  # kwarg: only one equal sign
        """
        # Module ids selected as sinks, and {InputPort module id: value(s)}
        sinks = set()
        inputs = {}

        reg = get_module_registry()
        InputPort_desc = reg.get_descriptor_by_name(
            get_vistrails_basic_pkg_id(), 'InputPort')

        # Read args: value pairs feed InputPorts, bare Modules become sinks;
        # any other positional argument is ignored
        for arg in args:
            if isinstance(arg, ModuleValuePair):
                if arg.module.id in inputs:
                    raise ValueError("Multiple values set for InputPort %r" %
                                     get_inputoutput_name(arg.module))
                if not reg.is_descriptor_subclass(arg.module.module_descriptor,
                                                  InputPort_desc):
                    raise ValueError("Module %d is not an InputPort" %
                                     arg.module.id)
                inputs[arg.module.id] = arg.value
            elif isinstance(arg, Module):
                sinks.add(arg.module_id)

        # Read kwargs: the keyword is the input's name
        for key, value in kwargs.iteritems():
            key = self.get_input(key)  # Might raise KeyError
            if key.module_id in inputs:
                raise ValueError("Multiple values set for InputPort %r" %
                                 get_inputoutput_name(key.module))
            inputs[key.module_id] = value

        reason = "API pipeline execution"
        # An empty sink set is passed on as None
        sinks = sinks or None

        # Use controller only if no inputs were passed in
        if (not inputs and self.vistrail is not None
                and self.vistrail.current_version == self.version):
            controller = self.vistrail.controller
            results, changed = controller.execute_workflow_list([[
                controller.locator,  # locator
                self.version,  # version
                self.pipeline,  # pipeline
                DummyView(),  # view
                None,  # custom_aliases
                None,  # custom_params
                reason,  # reason
                sinks,  # sinks
                None,  # extra_info
            ]])
            # Exactly one workflow was submitted, so exactly one result
            result, = results
        else:
            pipeline = self.pipeline
            if inputs:
                # Work on a copy so the stored pipeline is left untouched
                id_scope = IdScope(1)
                pipeline = pipeline.do_copy(False, id_scope)

                # A hack to get ids from id_scope that we know won't collide:
                # make them negative
                id_scope.getNewId = lambda t, g=id_scope.getNewId: -g(t)

                create_module = \
                        VistrailController.create_module_from_descriptor_static
                create_function = VistrailController.create_function_static
                create_connection = VistrailController.create_connection_static
                # Fills in the ExternalPipe ports
                for module_id, values in inputs.iteritems():
                    module = pipeline.modules[module_id]
                    if not isinstance(values, (list, tuple)):
                        values = [values]

                    # Guess the type of the InputPort
                    _, sigstrings, _, _, _ = get_port_spec_info(
                        pipeline, module)
                    sigstrings = parse_port_spec_string(sigstrings)

                    # Convert whatever we got to a list of strings, for the
                    # pipeline
                    # NOTE(review): izip stops at the shorter sequence, so
                    # surplus values (or surplus port types) are silently
                    # dropped rather than reported.
                    values = [
                        reg.convert_port_val(val, sigstring, None)
                        for val, sigstring in izip(values, sigstrings)
                    ]

                    if len(values) == 1:
                        # Create the constant module
                        constant_desc = reg.get_descriptor_by_name(
                            *sigstrings[0])
                        constant_mod = create_module(id_scope, constant_desc)
                        func = create_function(id_scope, constant_mod, 'value',
                                               values)
                        constant_mod.add_function(func)
                        pipeline.add_module(constant_mod)

                        # Connect it to the ExternalPipe port
                        conn = create_connection(id_scope, constant_mod,
                                                 'value', module,
                                                 'ExternalPipe')
                        pipeline.db_add_connection(conn)
                    else:
                        # Also reached when no value survived conversion
                        raise RuntimeError("TODO : create tuple")

            interpreter = get_default_interpreter()
            result = interpreter.execute(pipeline, reason=reason, sinks=sinks)

        if result.errors:
            raise ExecutionErrors(self, result)
        else:
            return ExecutionResults(self, result)
Esempio n. 54
0
def get_module_registry():
    """Return the module registry, importing and memoizing it on first use.

    The first call fetches the registry through
    ``vistrails.core.modules.module_registry.get_module_registry`` and stores
    it in the module-level ``_registry`` global; every later call returns
    that stored object without importing again.
    """
    global _registry
    if _registry is not None:
        return _registry
    # Imported at call time so the registry module is only loaded when first
    # needed; aliased so it does not read as a recursive call.
    from vistrails.core.modules.module_registry import \
        get_module_registry as _load_registry
    _registry = _load_registry()
    return _registry
Esempio n. 55
0
    def performParameterExploration(self):
        """ performParameterExploration() -> None
        Perform the exploration by collecting a list of actions
        corresponding to each dimension.

        The current exploration is first persisted as the 'paramexp' of the
        current version (replacing any previous one). If there is a current
        pipeline and at least one action, the explored pipelines are
        generated, positioned in spreadsheet cells when the spreadsheet
        package provides CellLocation and SheetReference, and executed one by
        one while a modal progress dialog counts executed modules. Cancelling
        the dialog stops before the next pipeline starts.

        """
        registry = get_module_registry()
        actions = self.peWidget.table.collectParameterActions()
        spreadsheet_pkg = 'org.vistrails.vistrails.spreadsheet'
        # Set the annotation to persist the parameter exploration
        # TODO: For now, we just replace the existing exploration - Later we should append them.
        xmlString = "<paramexps>\n" + self.getParameterExploration() + "\n</paramexps>"
        self.controller.vistrail.set_paramexp(self.currentVersion, xmlString)
        self.controller.set_changed(True)

        if not (self.controller.current_pipeline and actions):
            return

        explorer = ActionBasedParameterExploration()
        (pipelines, performedActions) = explorer.explore(
            self.controller.current_pipeline, actions)

        # Size of each exploration dimension; an empty dimension counts as 1
        dim = [max(1, len(a)) for a in actions]
        if (registry.has_module(spreadsheet_pkg, 'CellLocation') and
                registry.has_module(spreadsheet_pkg, 'SheetReference')):
            # NOTE(review): assumes at least 3 dimensions (dim[0..2]) — the
            # table presumably always reports them; confirm.
            modifiedPipelines = self.virtualCell.positionPipelines(
                'PE#%d %s' % (QParameterExplorationTab.explorationId,
                              self.controller.name),
                dim[2], dim[1], dim[0], pipelines, self.controller)
        else:
            modifiedPipelines = pipelines

        # mCount[i] is the progress value at which pipeline i starts, i.e.
        # the number of modules in all pipelines executed before it.
        # (The previous version added pipeline i's own module count instead
        # of pipeline i-1's, skewing every offset after the first.)
        mCount = []
        totalProgress = 0
        for p in modifiedPipelines:
            mCount.append(totalProgress)
            totalProgress += len(p.modules)

        # Now execute the pipelines
        progress = QtGui.QProgressDialog('Performing Parameter '
                                         'Exploration...',
                                         '&Cancel',
                                         0, totalProgress)
        progress.setWindowTitle('Parameter Exploration')
        progress.setWindowModality(QtCore.Qt.WindowModal)
        progress.show()

        QParameterExplorationTab.explorationId += 1
        interpreter = get_default_interpreter()

        def moduleExecuted(objId):
            """Advance the progress dialog by one executed module."""
            if not progress.wasCanceled():
                # progress.setValue(progress.value()+1) was crashing when
                # called from multithreaded code; invoking setValue through
                # the meta-object system instead (thanks to Terence for
                # submitting this fix).
                QtCore.QMetaObject.invokeMethod(progress, "setValue",
                                QtCore.Q_ARG(int, progress.value()+1))
                QtCore.QCoreApplication.processEvents()

        for pi, pipeline in enumerate(modifiedPipelines):
            progress.setValue(mCount[pi])
            QtCore.QCoreApplication.processEvents()
            if progress.wasCanceled():
                break
            kwargs = {'locator': self.controller.locator,
                      'current_version': self.controller.current_version,
                      'view': self.controller.current_pipeline_scene,
                      'module_executed_hook': [moduleExecuted],
                      'reason': 'Parameter Exploration',
                      'actions': performedActions[pi],
                      }
            interpreter.execute(pipeline, **kwargs)
        progress.setValue(totalProgress)
Esempio n. 56
0
def get_module_registry():
    """Return the registry provided by vistrails.core.modules.module_registry.

    The import is done at call time, so the registry module is not loaded
    until this function is first called.
    """
    from vistrails.core.modules.module_registry import \
        get_module_registry as _fetch_registry
    registry = _fetch_registry()
    return registry
Esempio n. 57
0
    def updateFromPipeline(self, pipeline, controller):
        """ updateFromPipeline(pipeline: Pipeline, controller) -> None
        Read the list of aliases and parameters from the pipeline and
        rebuild this tree from scratch.

        The tree gets an 'Aliases' root (only when the pipeline has aliases),
        a 'Vistrail Variables' root (made visible only when a vistrail
        variable module is found), and one item per module listing its set
        parameters followed by its unset input ports whose types are all
        constants. When `pipeline` is falsy the tree is only cleared.

        """
        self.clear()
        if not pipeline:
            return

        # Update the aliases
        if len(pipeline.aliases) > 0:
            aliasRoot = QParameterTreeWidgetItem(None, self, ['Aliases'])
            aliasRoot.setFlags(QtCore.Qt.ItemIsEnabled)
            for (alias, info) in pipeline.aliases.iteritems():
                ptype, pId, parentType, parentId, mId = info
                parameter = pipeline.db_get_object(ptype, pId)
                function = pipeline.db_get_object(parentType, parentId)
                v = parameter.strValue
                port_spec = function.get_spec('input')
                port_spec_item = port_spec.port_spec_items[parameter.pos]
                label = ['%s = %s' % (alias, v)]
                pInfo = ParameterInfo(module_id=mId,
                                      name=function.name,
                                      pos=parameter.pos,
                                      value=v,
                                      spec=port_spec_item,
                                      is_alias=True)
                QParameterTreeWidgetItem((alias, [pInfo]), aliasRoot, label)
            aliasRoot.setExpanded(True)

        # Hidden until at least one vistrail variable module is found
        vistrailVarsRoot = QParameterTreeWidgetItem(None, self,
                                                    ['Vistrail Variables'])
        vistrailVarsRoot.setHidden(True)

        # Now go through all modules and functions, in module-name order
        inspector = PipelineInspector()
        inspector.inspect_ambiguous_modules(pipeline)
        sortedModules = sorted(pipeline.modules.iteritems(),
                               key=lambda item: item[1].name)

        reg = get_module_registry()

        for mId, module in sortedModules:
            if module.is_vistrail_var():
                vistrailVarsRoot.setHidden(False)
                vistrailVarsRoot.setExpanded(True)
                port_spec = module.get_port_spec('value', 'input')
                if not port_spec:
                    debug.critical("Not port_spec for value in module %s" %
                                   module)
                    continue

                if not controller.has_vistrail_variable_with_uuid(
                        module.get_vistrail_var()):
                    continue
                vv = controller.get_vistrail_variable_by_uuid(
                    module.get_vistrail_var())

                label = ['%s = %s' % (vv.name, vv.value)]
                pList = [
                    ParameterInfo(module_id=mId,
                                  name=port_spec.name,
                                  pos=item.pos,
                                  value="",
                                  spec=item,
                                  is_alias=False)
                    for item in port_spec.port_spec_items
                ]
                QParameterTreeWidgetItem((vv.name, pList), vistrailVarsRoot,
                                         label)
                continue

            # Every function set on the module (even parameterless ones), so
            # those ports are not listed again as "available" below
            function_names = {}
            mLabel = [module.name]
            # The inspector's annotation for this module, or None when it
            # has none; used as the module item's parameter
            annotatedId = inspector.annotated_modules.get(mId)
            moduleItem = None

            # Add existing parameters
            for function in module.functions:
                function_names[function.name] = function
                if not function.params:
                    continue
                if moduleItem is None:
                    moduleItem = QParameterTreeWidgetItem(
                        annotatedId, self, mLabel)
                v = ', '.join([p.strValue for p in function.params])
                label = ['%s(%s)' % (function.name, v)]

                try:
                    port_spec = function.get_spec('input')
                except Exception:
                    debug.critical("get_spec failed: %s %s %s" %
                                   (module, function, function.sigstring))
                    continue
                port_spec_items = port_spec.port_spec_items
                pList = [
                    ParameterInfo(module_id=mId,
                                  name=function.name,
                                  pos=param.pos,
                                  value=param.strValue,
                                  spec=port_spec_items[pId],
                                  is_alias=False)
                    for pId, param in enumerate(function.params)
                ]
                mName = module.name
                if moduleItem.parameter is not None:
                    mName += '(%d)' % moduleItem.parameter
                fName = '%s :: %s' % (mName, function.name)
                QParameterTreeWidgetItem((fName, pList), moduleItem, label)

            # Add available parameters
            if module.is_valid:
                for port_spec in module.destinationPorts():
                    if (port_spec.name in function_names
                            or not port_spec.is_valid
                            or not len(port_spec.port_spec_items)
                            or not reg.is_constant(port_spec)):
                        # The function already exists or is empty
                        # or contains non-constant modules
                        continue
                    if moduleItem is None:
                        moduleItem = QParameterTreeWidgetItem(
                            annotatedId, self, mLabel, False)
                    v = ', '.join(
                        [p.module for p in port_spec.port_spec_items])
                    label = ['%s(%s)' % (port_spec.name, v)]
                    pList = [
                        ParameterInfo(module_id=mId,
                                      name=port_spec.name,
                                      pos=item.pos,
                                      value="",
                                      spec=item,
                                      is_alias=False)
                        for item in port_spec.port_spec_items
                    ]
                    mName = module.name
                    if moduleItem.parameter is not None:
                        mName += '(%d)' % moduleItem.parameter
                    fName = '%s :: %s' % (mName, port_spec.name)
                    QParameterTreeWidgetItem((fName, pList), moduleItem, label,
                                             False)
            if moduleItem:
                moduleItem.setExpanded(True)
        self.toggleUnsetParameters(self.showUnsetParameters)
Esempio n. 58
0
def convert_data(child, parent_type, parent_id):
    """Convert one child object into the corresponding DB* object.

    Dispatches on ``child.vtType`` ('module', 'connection', 'portSpec',
    'function', 'parameter', 'location', 'annotation' or 'port') and copies
    the relevant ``db_*`` fields into a new DBModule, DBConnection,
    DBPortSpec, DBFunction, DBParameter, DBLocation, DBAnnotation or DBPort.
    Returns None for any other vtType.

    For modules, the package identifier is looked up in the module registry
    by name, and ``(name, package)`` is recorded in the module-level
    ``module_map`` under the module id. Function names under a module and
    port names are passed through ``translate_vtk``. Port signatures of the
    form ``name(T1,T2,...)`` have each type qualified as
    ``<package identifier>:<type>``.
    """
    # NOTE(review): PortSpec is not used below; kept in case the import has
    # side effects callers rely on.
    from vistrails.core.vistrail.port_spec import PortSpec
    from vistrails.core.modules.module_registry import get_module_registry
    global module_map

    registry = get_module_registry()
    if child.vtType == 'module':
        descriptor = registry.get_descriptor_from_name_only(child.db_name)
        package = descriptor.identifier
        module_map[child.db_id] = (child.db_name, package)
        return DBModule(id=child.db_id,
                        cache=child.db_cache,
                        abstraction=0,
                        name=child.db_name,
                        package=package)
    elif child.vtType == 'connection':
        return DBConnection(id=child.db_id)
    elif child.vtType == 'portSpec':
        return DBPortSpec(id=child.db_id,
                          name=child.db_name,
                          type=child.db_type,
                          spec=child.db_spec)
    elif child.vtType == 'function':
        if parent_type == 'module':
            name = translate_vtk(parent_id, child.db_name)
        else:
            name = child.db_name
        return DBFunction(id=child.db_id, pos=child.db_pos, name=name)
    elif child.vtType == 'parameter':
        return DBParameter(id=child.db_id,
                           pos=child.db_pos,
                           name=child.db_name,
                           type=child.db_type,
                           val=child.db_val,
                           alias=child.db_alias)
    elif child.vtType == 'location':
        return DBLocation(id=child.db_id, x=child.db_x, y=child.db_y)
    elif child.vtType == 'annotation':
        return DBAnnotation(id=child.db_id,
                            key=child.db_key,
                            value=child.db_value)
    elif child.vtType == 'port':
        sig = child.db_sig
        if '(' in sig and ')' in sig:
            open_idx = sig.find('(')
            name = sig[:open_idx]
            specs = sig[open_idx + 1:sig.find(')')]
            name = translate_vtk(child.db_moduleId, name, specs)

            # Qualify each type name with the identifier of its package.
            # Blank entries (an empty signature such as "GetOutput()" splits
            # into ['']) are skipped instead of being looked up as a module
            # named '', so that case yields "()".
            new_specs = []
            for spec_name in specs.split(','):
                if not spec_name.strip():
                    continue
                descriptor = registry.get_descriptor_from_name_only(spec_name)
                new_specs.append(descriptor.identifier + ':' + spec_name)
            spec = '(' + ','.join(new_specs) + ')'
        else:
            name = sig
            spec = ''
        return DBPort(id=child.db_id,
                      type=child.db_type,
                      moduleId=child.db_moduleId,
                      moduleName=child.db_moduleName,
                      name=name,
                      spec=spec)
Esempio n. 59
0
def write_workflow_to_python(pipeline):
    """Writes a pipeline to a Python source file.

    :returns: An iterable over the lines (as unicode) of the generated script.
    """
    # The set of all currently bound symbols in the resulting script's global
    # scope
    # These are either variables from translated modules (internal or output
    # ports) or imported names
    all_vars = set(reserved)

    # Should we import izip or product from itertools?
    import_izip = False
    import_product = False

    # The parts of the final generated script
    text = []

    # The modules that have been translated, maps from the module's id to a
    # Script object
    modules = dict()

    # The preludes that have been collected
    # A "prelude" is a piece of code that is supposed to go at the top of the
    # file, and that shouldn't be repeated; things like import statements,
    # function/class definition, and constants
    preludes = []

    reg = get_module_registry()

    # set port specs and module depth
    pipeline.validate(False)

    # ########################################
    # Walk through the pipeline to get all the codes
    #
    for module_id in pipeline.graph.vertices_topological_sort():
        module = pipeline.modules[module_id]
        print("Processing module %s %d" % (module.name, module_id))

        desc = module.module_descriptor
        module_class = desc.module

        # Gets the code
        code_preludes = []
        if not hasattr(module_class, 'to_python_script'):
            # Use vistrails API to execute module
            code, code_preludes = generate_api_code(module)
        elif module_class.to_python_script is None:
            debug.critical("Module %s cannot be converted to Python" %
                           module.name)
            code = Script(
                "# <Missing code>\n"
                "# %s has empty function to_python_script()\n"
                "# VisTrails cannot currently export such modules" %
                module.name, 'variables', 'variables')
        else:
            # Call the module to get the base code
            code = module_class.to_python_script(module)
            if isinstance(code, tuple):
                code, code_preludes = code
            print("Got code:\n%r" % (code, ))
            assert isinstance(code, Script)

        modules[module_id] = code
        preludes.extend(code_preludes)

    # ########################################
    # Processes the preludes and writes the beginning of the file
    #
    print("Writing preludes")
    # Adds all imported modules to the list of symbols
    for prelude in preludes:
        all_vars.update(prelude.imported_pkgs)
    # Removes collisions
    prelude_renames = {}
    for prelude in preludes:
        prelude_renames.update(prelude.avoid_collisions(all_vars))
    # remove duplicates
    final_preludes = []
    final_prelude_set = set()
    for prelude in [unicode(p) for p in preludes]:
        if prelude not in final_prelude_set:
            final_prelude_set.add(prelude)
            final_preludes.append(prelude)

    # Writes the preludes
    for prelude in final_preludes:
        text.append(prelude)
    #if preludes:
    #    text.append('# PRELUDE ENDS -- pipeline code follows\n\n')
    text.append('')

    # ########################################
    # Walk through the pipeline a second time to generate the full script
    #
    first = True
    # outer name of output is different from name in code if loop is used
    # so we keep this mapping: {(module_id, oport_name): outer_name}
    output_loop_map = {}
    for module_id in pipeline.graph.vertices_topological_sort():
        module = pipeline.modules[module_id]
        desc = module.module_descriptor
        print("Writing module %s %d" % (module.name, module_id))

        if not first:
            text.append('\n')
        else:
            first = False

        # Annotation, used to rebuild the pipeline
        text.append("# MODULE %d %s" % (module_id, desc.sigstring))

        code = modules[module_id]

        # Gets all the module's input and output port names
        input_ports = set(
            reg.module_destination_ports_from_descriptor(False, desc))
        input_ports.update(module.input_port_specs)

        input_port_names = set(p.name for p in input_ports)
        iports = dict((p.name, p) for p in input_ports)
        connected_inputs = set()
        for _, conn_id in pipeline.graph.edges_to(module_id):
            conn = pipeline.connections[conn_id]
            connected_inputs.add(utf8(conn.destination.name))
        for function in module.functions:
            connected_inputs.add(utf8(function.name))
        output_ports = set(reg.module_source_ports_from_descriptor(
            False, desc))
        output_ports.update(module.output_port_specs)
        output_port_names = set(p.name for p in output_ports)
        connected_outputs = set()
        for _, conn_id in pipeline.graph.edges_from(module_id):
            conn = pipeline.connections[conn_id]
            connected_outputs.add(utf8(conn.source.name))

        # Changes symbol names in this piece of code to prevent collisions
        # with already-encountered code
        old_all_vars = set(all_vars)
        code.normalize(input_port_names, output_port_names, all_vars)
        # Now, code knows what its inputs and outputs are
        print("Normalized code:\n%r" % (code, ))
        print("New vars in all_vars: %r" % (all_vars - old_all_vars, ))
        print("used_inputs: %r" % (code.used_inputs, ))

        # build final inputs as result of merging connections
        # {dest_name: {source_name, depth_diff}}
        combined_inputs = {}
        # Adds functions
        for function in module.functions:
            port = utf8(function.name)
            code.unset_inputs.discard(port)
            if code.skip_functions:
                continue
            if port not in code.used_inputs:
                print("NOT adding function %s (not used in script)" % port)
                continue

            ## Creates a variable with the value
            name = make_unique(port, all_vars)
            if len(function.params) == 1:
                value = function.params[0].value()
            else:
                value = [p.value() for p in function.params]

            depth = -iports[port].depth
            print("Function %s: var %s, value %r" % (port, name, value))
            text.append("# FUNCTION %s %s" % (port, name))
            text.append('%s = %r' % (name, value))
            if port not in combined_inputs:
                combined_inputs[port] = []
            combined_inputs[port].append((name, depth))

        # Sets input connections
        conn_ids = sorted(
            [conn_id for _, conn_id in pipeline.graph.edges_to(module_id)])
        for conn_id in conn_ids:
            conn = pipeline.connections[conn_id]
            dst = conn.destination
            port = utf8(dst.name)
            if port not in code.used_inputs:
                print("NOT connecting port %s (not used in script)" % dst.name)
                continue
            src = conn.source

            # Tells the code what the variable was
            src_mod = modules[src.moduleId]
            name = output_loop_map.get((src.moduleId, src.name),
                                       src_mod.get_output(src.name))

            # get depth difference to destination port
            depth = pipeline._connection_depths.get(conn_id, 0)
            # account for module looping
            depth += pipeline.modules[src.moduleId].list_depth

            print("Input %s: var %s" % (port, name))
            text.append("# CONNECTION %s %s" % (dst.name, name))
            #code.set_input(utf8(dst.name), name)
            code.unset_inputs.discard(utf8(dst.name))
            if port not in combined_inputs:
                combined_inputs[port] = []
            combined_inputs[port].append((name, depth))

        # Sets default values
        if code.unset_inputs:
            print("unset_inputs: %r" % (code.unset_inputs, ))
            for port in set(code.unset_inputs):
                if port not in code.used_inputs:
                    continue
                # Creates a variable with the value
                name = make_unique(port, all_vars)
                default = iports[port].defaults
                if len(default) == 1:
                    default = default[0]
                print("Default: %s: var %s, value %r" % (port, name, default))
                text.append("# DEFAULT %s %s" % (port, name))
                text.append('%s = %s' % (name, default))
                code.set_input(port, name)

        # merge connections
        # [(port, [names], depth)]
        merged_inputs = []
        # keep input port order
        for iport in sorted(input_ports, key=lambda x: (x.sort_key, x.id)):
            port = iport.name
            if port not in combined_inputs:
                continue
            conns = combined_inputs[port]
            descs = iports[port].descriptors()
            can_merge = ((len(descs) == 1 and isinstance(descs[0], List))
                         or iports[port].depth > 0)
            if len(conns) == 1 or not can_merge:
                # no merge needed
                name = conns[0][0]
                depth = conns[0][1]
                if depth >= 0:
                    # add directly
                    new_names = [name]
                    for i in xrange(depth):
                        name = make_unique(new_names[-1] + '_item', all_vars)
                        new_names.append(name)
                    code.set_input(port, name)
                    merged_inputs.append((port, new_names, depth))
                else:
                    # wrap to correct depth as new variable
                    while depth:
                        depth += 1
                        name = '[%s]' % name
                    new_name = make_unique(code.inputs[port], all_vars)
                    text.append('%s = %s' % (new_name, name))
                    code.set_input(port, new_name)
                    merged_inputs.append((port, [new_name], 0))
            else:
                # merge connections
                items = []
                for name, depth in conns:
                    # wrap to max depth
                    while depth < 0:
                        depth += 1
                        name = '[%s]' % name
                    items.append(name)
                # assign names to list items in loop
                # like "xItem", "xItemItem", etc.
                new_names = [make_unique(code.inputs[port], all_vars)]
                max_depth = max([c[1] for c in conns])
                if max_depth > 0:
                    for i in xrange(max_depth):
                        new_name = make_unique(new_names[-1] + '_item',
                                               all_vars)
                        new_names.append(new_name)
                text.append('%s = %s' % (new_names[0], ' + '.join(items)))
                code.set_input(port, new_names[-1])
                merged_inputs.append((port, new_names, max_depth))

        #### Prepare looping #################################################
        max_level = module.list_depth
        # names for the looped output ports
        output_sub_ports = dict([(port, [code.get_output(port)])
                                 for port in connected_outputs])
        # add output names to global output name list
        for port, names in output_sub_ports.iteritems():
            output_loop_map[(module_id, port)] = names[0]
        offset_levels = []
        loop_args = []
        for level in xrange(1, max_level + 1):
            # Output values are collected from the inside and out
            for port in connected_outputs:
                sub_ports = output_sub_ports[port]
                new_name = make_unique(sub_ports[-1] + '_item', all_vars)
                code.set_output(port, new_name)
                sub_ports.append(new_name)

            # construct the loop code
            combine_type = 'cartesian'
            if max_level == 1:
                # first level may use a complex port combination
                cps = {}
                for cp in module.control_parameters:
                    cps[cp.name] = cp.value
                if ModuleControlParam.LOOP_KEY in cps:
                    combine_type = cps[ModuleControlParam.LOOP_KEY]
                    if combine_type not in ['cartesian', 'pairwise']:
                        combine_type = ast.literal_eval(combine_type)
            if combine_type == 'cartesian':
                # make nested for loops for all iterated ports
                items = []
                for port, names, depth in merged_inputs:
                    if depth >= level:
                        # map parent level name to this level
                        items.append((names[level], names[level - 1]))
                offset = 0
                loop_arg = []
                for item in items:
                    loop_arg.append('    ' * offset + 'for %s in %s:' % item)
                    offset += 1
                loop_args.append('\n'.join(loop_arg))
                offset_levels.append(offset)
            elif combine_type == 'pairwise':
                #  zip all iterated ports
                prev_items, cur_items = [], []
                for port, names, depth in merged_inputs:
                    if depth >= level:
                        # map parent level name to this level
                        prev_items.append(names[level - 1])
                        cur_items.append(names[level])
                loop_arg = 'for %s in izip(%s):' % (', '.join(cur_items),
                                                    ', '.join(prev_items))
                import_izip = True
                loop_args.append(loop_arg)
                offset_levels.append(1)
            else:
                #  construct custom combination
                name_map = {}
                for port, names, depth in merged_inputs:
                    if depth >= level:
                        # map parent level name to this level
                        name_map[port] = (names[level - 1], names[level])
                loop_arg = "for %s in %s:" % combine(combine_type, name_map)
                if 'izip' in loop_arg:
                    import_izip = True
                if 'product(' in loop_arg:
                    import_product = True
                loop_args.append(loop_arg)
                offset_levels.append(1)

        # TODO: handle while loop

        #### Write loop start ################################################
        for level in xrange(0, max_level):
            offset = sum(offset_levels[:level])

            for port in connected_outputs:
                sub_ports = output_sub_ports[port]
                text.append('   ' * offset + '%s = []' % sub_ports[level])
                sub_ports.append(new_name)

            for_loop = '\n'.join(
                ['    ' * offset + f for f in loop_args[level].split('\n')])
            text.append(for_loop)

        #### Write module code ###############################################
        code_text = unicode(code)
        if offset_levels:
            code_text = code_text.split('\n')
            code_text = '\n'.join(
                ['    ' * sum(offset_levels) + t for t in code_text])
            # node.increase_indent(sum(offset_levels)*4)
        # Ok, add the module's code
        print("Rendering code")
        text.append(code_text)
        print("Total new vars: %r" % (all_vars - old_all_vars, ))

        #### Write loop end ##################################################
        for level in reversed(xrange(0, max_level)):
            offset = sum(offset_levels[:level + 1])
            for port in connected_outputs:
                sub_ports = output_sub_ports[port]
                text.append('    ' * offset + '%s.append(%s)' %
                            (sub_ports[level], sub_ports[level - 1]))

    if import_izip or import_product:
        if not import_product:
            text.insert(0, 'from itertools import izip\n')
        elif not import_product:
            text.insert(0, 'from itertools import product\n')
        else:
            text.insert(0, 'from itertools import izip, product\n')
    return text
Esempio n. 60
0
def assignPipelineCellLocations(pipeline, sheetName,
                                row, col, cellIds=None,
                                minRowCount=None, minColCount=None):
    """Return a copy of *pipeline* whose cell modules target one sheet cell.

    Every module referenced by *cellIds* is redirected to the given
    spreadsheet position:

    * modules whose descriptor is a subclass of ``SpreadsheetCell`` lose
      their incoming ``Location`` connections and are wired to a new
      ``CellLocation`` module (``Row``/``Column`` functions), which is itself
      fed by a new ``SheetReference`` module (``SheetName`` function, plus
      optional ``MinRowCount`` / ``MinColumnCount`` functions);
    * modules whose descriptor is a subclass of ``OutputModule`` lose every
      ``configuration`` connection and function, and receive a single new
      ``configuration`` function of the form
      ``{'spreadsheet': {'row': row - 1, 'col': col - 1, ...}}``.

    All new modules, functions and connections are created with negated ids
    drawn from the target pipeline's ``tmp_id`` scope.

    :param pipeline: pipeline to patch; it is copied with ``copy.copy`` and
        the copy is what gets modified and returned.
    :param sheetName: target sheet name (passed through ``str`` for cell
        modules; omitted from output-module config when None).
    :param row: row passed as-is to ``CellLocation``; output modules get
        ``row - 1`` (presumably 1-based here, 0-based there -- confirm).
    :param col: column, handled like *row*.
    :param cellIds: iterable whose entries are either an int/long module id
        in the top-level pipeline, or a sequence of ids descending through
        nested ``m.pipeline`` sub-pipelines. When None, spreadsheet cells are
        discovered with ``PipelineInspector``.
    :param minRowCount: optional minimum sheet row count.
    :param minColCount: optional minimum sheet column count.
    :returns: the patched copy of *pipeline*.
    """
    from contextlib import contextmanager

    reg = get_module_registry()
    spreadsheet_cell_desc = reg.get_descriptor_by_name(spreadsheet_pkg,
                                                       'SpreadsheetCell')
    output_module_desc = reg.get_descriptor_by_name(
            'org.vistrails.vistrails.basic', 'OutputModule')

    create_module = VistrailController.create_module_static
    create_function = VistrailController.create_function_static
    create_connection = VistrailController.create_connection_static

    pipeline = copy.copy(pipeline)
    root_pipeline = pipeline
    if cellIds is None:
        inspector = PipelineInspector()
        inspector.inspect_spreadsheet_cells(pipeline)
        inspector.inspect_ambiguous_modules(pipeline)
        cellIds = inspector.spreadsheet_cells

    @contextmanager
    def negative_id_scope(pipeline):
        """Yield ``pipeline.tmp_id`` patched to hand out negated ids.

        HACK: ``getNewId`` is replaced on the *class* of ``pipeline.tmp_id``,
        so every instance of that class is affected while the block runs.
        Negating presumably keeps the new ids from colliding with the
        existing (positive) ones -- confirm. The original method is always
        restored, even if the body raises, so the patch cannot leak.
        """
        id_scope = pipeline.tmp_id
        id_class = id_scope.__class__
        orig_getNewId = id_class.getNewId

        def getNewId(self, objType):
            return -orig_getNewId(self, objType)

        id_class.getNewId = getNewId
        try:
            yield id_scope
        finally:
            id_class.getNewId = orig_getNewId

    def fix_cell_module(pipeline, mId):
        """Wire spreadsheet cell *mId* to a new SheetReference/CellLocation."""
        # Drop every existing connection into the cell's 'Location' port
        # (ids are collected first so we don't mutate while iterating).
        conns_to_delete = [c.id for c in pipeline.connection_list
                           if c.destinationId == mId and
                           c.destination.name == 'Location']
        for c_id in conns_to_delete:
            pipeline.delete_connection(c_id)

        # Creation order below is kept identical to the original so the
        # allocated ids are the same.
        with negative_id_scope(pipeline) as id_scope:
            # Sheet reference selecting the sheet by name
            sheetReference = create_module(id_scope, spreadsheet_pkg,
                                           "SheetReference")
            sheetNameFunction = create_function(id_scope, sheetReference,
                                                "SheetName", [str(sheetName)])
            sheetReference.add_function(sheetNameFunction)

            if minRowCount is not None:
                minRowFunction = create_function(id_scope, sheetReference,
                                                 "MinRowCount",
                                                 [str(minRowCount)])
                sheetReference.add_function(minRowFunction)
            if minColCount is not None:
                minColFunction = create_function(id_scope, sheetReference,
                                                 "MinColumnCount",
                                                 [str(minColCount)])
                sheetReference.add_function(minColFunction)

            # Cell location with the requested row and column
            cellLocation = create_module(id_scope, spreadsheet_pkg,
                                         "CellLocation")
            rowFunction = create_function(id_scope, cellLocation, "Row",
                                          [str(row)])
            colFunction = create_function(id_scope, cellLocation, "Column",
                                          [str(col)])
            cellLocation.add_function(rowFunction)
            cellLocation.add_function(colFunction)

            # SheetReference -> CellLocation -> spreadsheet cell
            sheet_conn = create_connection(id_scope, sheetReference, "value",
                                           cellLocation, "SheetReference")
            cell_module = pipeline.get_module_by_id(mId)
            cell_conn = create_connection(id_scope, cellLocation, "value",
                                          cell_module, "Location")

            pipeline.add_module(sheetReference)
            pipeline.add_module(cellLocation)
            pipeline.add_connection(sheet_conn)
            pipeline.add_connection(cell_conn)

    def fix_output_module(pipeline, mId):
        """Replace output module *mId*'s configuration with the cell position."""
        # Remove all connections into the 'configuration' input port
        conns_to_delete = [c.id for c in pipeline.connection_list
                           if c.destinationId == mId and
                           c.destination.name == 'configuration']
        for c_id in conns_to_delete:
            pipeline.delete_connection(c_id)

        m = pipeline.modules[mId]

        # Remove all functions set on the 'configuration' input port
        funcs_to_delete = [f.real_id for f in m.functions
                           if f.name == 'configuration']
        for f_id in funcs_to_delete:
            m.delete_function_by_real_id(f_id)

        config = {'row': row - 1, 'col': col - 1}
        if minRowCount is not None:
            config['sheetRowCount'] = minRowCount
        if minColCount is not None:
            config['sheetColCount'] = minColCount
        if sheetName is not None:
            config['sheetName'] = sheetName
        config = {'spreadsheet': config}

        with negative_id_scope(pipeline) as id_scope:
            config_function = create_function(id_scope, m,
                                              'configuration', [repr(config)])
            m.add_function(config_function)

    for id_list in cellIds:
        cell_pipeline = pipeline

        # Resolve the (possibly nested) module id: a bare id addresses the
        # top-level pipeline; a sequence walks down through m.pipeline.
        if isinstance(id_list, (int, long)):
            mId = id_list
            m = cell_pipeline.modules[mId]
        else:
            id_iter = iter(id_list)
            mId = next(id_iter)
            m = cell_pipeline.modules[mId]
            for mId in id_iter:
                cell_pipeline = m.pipeline
                m = cell_pipeline.modules[mId]

        if reg.is_descriptor_subclass(m.module_descriptor,
                                      spreadsheet_cell_desc):
            fix_cell_module(cell_pipeline, mId)
        elif reg.is_descriptor_subclass(m.module_descriptor,
                                        output_module_desc):
            fix_output_module(cell_pipeline, mId)

    return root_pipeline