def test_get_var_type(self):
    """Variables stored in variables.xml report the expected module type."""
    xml_path = os.path.join(os.path.dirname(__file__), 'variables.xml')
    vistrail = XMLFileLocator(xml_path).load()

    # (tag of the variable's pipeline, attribute name on `basic`)
    expectations = (('dat-var-var1', 'Float'),
                    ('dat-var-var2', 'String'))
    for tag, type_name in expectations:
        descriptor = Variable.read_type(get_upgraded_pipeline(vistrail, tag))
        self.assertEqual(descriptor.module, getattr(basic, type_name))
def test1(self):
    """Exercises aliasing on modules"""
    import vistrails.core.system
    from vistrails.core.db.locator import XMLFileLocator
    resource = (vistrails.core.system.vistrails_root_directory() +
                '/tests/resources/dummy.xml')
    vistrail = XMLFileLocator(resource).load()

    # Two materializations of the same version must not share module objects.
    first = vistrail.getPipeline('final')
    second = vistrail.getPipeline('final')
    self.assertEquals(len(first.modules), len(second.modules))
    shared = any(first.modules[mod_id] is second.modules[mod_id]
                 for mod_id in first.modules.keys())
    if shared:
        self.fail("didn't expect aliases in two different pipelines")
def test1(self):
    """Exercises aliasing on modules"""
    import vistrails.core.vistrail
    # vistrails.core.system is used below. Import it explicitly instead of
    # relying on it having been imported as a side effect elsewhere.
    import vistrails.core.system
    from vistrails.core.db.locator import XMLFileLocator
    v = XMLFileLocator(vistrails.core.system.vistrails_root_directory() +
                       '/tests/resources/dummy.xml').load()

    # Two materializations of the same version must not share module objects.
    p1 = v.getPipeline('final')
    p2 = v.getPipeline('final')
    self.assertEqual(len(p1.modules), len(p2.modules))
    for k in p1.modules.keys():
        if p1.modules[k] is p2.modules[k]:
            self.fail("didn't expect aliases in two different pipelines")
def test2(self):
    """Smoke test: pipeline diffs between a few versions must not raise."""
    from vistrails.core.db.locator import XMLFileLocator
    import vistrails.core.system
    dummy_path = (vistrails.core.system.vistrails_root_directory() +
                  '/tests/resources/dummy.xml')
    vistrail = XMLFileLocator(dummy_path).load()

    # testing diff: version 17 against 27, then against 22, with both
    # diff flavours
    base = 17
    for diff in (vistrail.get_pipeline_diff,
                 vistrail.get_pipeline_diff_with_connections):
        for other in (27, 22):
            diff(base, other)
def test_ticket_73(self):
    """Tests serializing a custom-named module to disk."""
    source_path = (vistrails.core.system.vistrails_root_directory() +
                   '/tests/resources/test_ticket_73.xml')
    vistrail = XMLFileLocator(source_path).load()

    import tempfile
    handle, target_path = tempfile.mkstemp()
    # Only the path is needed; the locator writes the file itself.
    os.close(handle)
    target = XMLFileLocator(target_path)
    try:
        target.save(vistrail)
    finally:
        os.remove(target_path)
def test_python_source_2(self):
    """Running 'test_simple_success' executes exactly one module."""
    xml_path = (vistrails.core.system.vistrails_root_directory() +
                '/tests/resources/pythonsource.xml')
    workflows = [(XMLFileLocator(xml_path), "test_simple_success")]
    results = run_and_get_results(workflows, update_vistrail=False)
    first_result = results[0]
    self.assertEquals(len(first_result.executed), 1)
def test2(self):
    """Exercises aliasing on points"""
    from vistrails.core.db.locator import XMLFileLocator
    import vistrails.core.system
    v = XMLFileLocator(vistrails.core.system.vistrails_root_directory() +
                       '/tests/resources/dummy.xml').load()
    p1 = v.getPipeline('final')
    # Extra materialization whose result is deliberately unused; presumably
    # it checks that repeated getPipeline calls do not disturb module
    # positions. TODO confirm.
    v.getPipeline('final')
    p2 = v.getPipeline('final')

    # Both pipelines must contain the same module ids. The former izip over
    # the two sorted item lists stopped silently at the shorter one, so a
    # missing module went unnoticed.
    module_ids = sorted(p1.modules.keys())
    self.assertEqual(module_ids, sorted(p2.modules.keys()))
    for module_id in module_ids:
        m1 = p1.modules[module_id]
        m2 = p2.modules[module_id]
        self.assertEqual(m1.center.x, m2.center.x)
        self.assertEqual(m1.center.y, m2.center.y)
def test_change_parameter(self):
    """Versions 'v1' and 'v2' of test_change_vistrail.xml both run cleanly."""
    locator = XMLFileLocator(
        vistrails.core.system.vistrails_root_directory() +
        '/tests/resources/test_change_vistrail.xml')
    # Same check for both versions. Previously the second used the
    # deprecated assertEquals alias.
    for version in ("v1", "v2"):
        result = run([(locator, version)], update_vistrail=False)
        self.assertEqual(len(result), 0)
def test1(self):
    """getFirstCommonVersion is symmetric and never falls back to the root."""
    from vistrails.core.db.locator import XMLFileLocator
    import vistrails.core.system
    vistrail = XMLFileLocator(
        vistrails.core.system.vistrails_root_directory() +
        '/tests/resources/dummy.xml').load()

    version_pairs = (
        (36, 41),  # nodes in different branches
        (15, 36),  # nodes in the same branch
    )
    for left, right in version_pairs:
        p1 = vistrail.getFirstCommonVersion(left, right)
        p2 = vistrail.getFirstCommonVersion(right, left)
        self.assertEquals(p1, p2)

    # Only the last (same-branch) pair is checked against version 0.
    if p1 == 0 or p2 == 0:
        self.fail("vistrails tree is not single rooted.")
def test1(self):
    """getFirstCommonVersion is symmetric and never falls back to the root."""
    from vistrails.core.db.locator import XMLFileLocator
    import vistrails.core.system
    path = (vistrails.core.system.vistrails_root_directory() +
            '/tests/resources/dummy.xml')
    vt = XMLFileLocator(path).load()

    def common_both_ways(a, b):
        # Query the common ancestor in both argument orders.
        return vt.getFirstCommonVersion(a, b), vt.getFirstCommonVersion(b, a)

    # testing nodes in different branches
    p1, p2 = common_both_ways(36, 41)
    self.assertEquals(p1, p2)

    # testing nodes in the same branch
    p1, p2 = common_both_ways(15, 36)
    self.assertEquals(p1, p2)
    if p1 == 0 or p2 == 0:
        self.fail("vistrails tree is not single rooted.")
def get_pipeline(self):
    """Return this plot's pipeline.

    If ``self.subworkflow`` is set, the pipeline is loaded from that XML file
    and upgraded. Otherwise ``self.pipeline_arg`` is used. It is called first
    if it is callable, and the result must be one of:
      * a Pipeline, returned as is;
      * a sequence ``('pipeline', pipeline)``, whose pipeline is returned;
      * a sequence ``('python_lists', ...)``, whose remaining items are
        passed to build_pipeline().

    :raises ValueError: if the value is tagged with anything else.
    """
    if self.subworkflow is not None:
        return get_upgraded_pipeline(XMLFileLocator(self.subworkflow).load())

    value = self.pipeline_arg
    if callable(value):
        value = value()
    if isinstance(value, Pipeline):
        return value

    if value[0] == 'pipeline':
        pipeline, = value[1:]
        return pipeline
    if value[0] == 'python_lists':
        return build_pipeline(*value[1:])
    raise ValueError("Plot pipeline is invalid value %s" %
                     abbrev(repr(value)))
def test_ticket_73(self):
    """Tests serializing a custom-named module to disk."""
    original = XMLFileLocator(
        vistrails.core.system.vistrails_root_directory() +
        '/tests/resources/test_ticket_73.xml')
    loaded = original.load()

    import tempfile
    fd, out_name = tempfile.mkstemp()
    os.close(fd)  # the locator reopens the file by name
    destination = XMLFileLocator(out_name)
    try:
        destination.save(loaded)
    finally:
        os.remove(out_name)
def test_cache(self):
    """Test if basic caching is working."""
    from vistrails.core.modules.basic_modules import StandardOutput
    old_compute = StandardOutput.compute
    # Replace StandardOutput.compute with a no-op for the duration of the
    # test; it is restored in the finally clause.
    StandardOutput.compute = lambda s: None

    try:
        from vistrails.core.db.locator import XMLFileLocator
        from vistrails.core.vistrail.controller import VistrailController
        from vistrails.core.db.io import load_vistrail

        locator = XMLFileLocator(
            vistrails.core.system.vistrails_root_directory() +
            '/tests/resources/dummy.xml')
        (v, abstractions, thumbnails, mashups) = load_vistrail(locator)

        # the controller will take care of upgrades
        controller = VistrailController(v, locator, abstractions,
                                        thumbnails, mashups)
        # Result intentionally discarded; the executed pipeline is the
        # controller's current (upgraded) pipeline.
        v.getPipeline('int chain')
        n = v.get_version_number('int chain')
        controller.change_selected_version(n)
        controller.flush_delayed_actions()
        p1 = controller.current_pipeline

        view = DummyView()
        interpreter = CachedInterpreter.get()
        interpreter.execute(p1,
                            locator=v,
                            current_version=n,
                            view=view)

        # to force fresh params
        v.getPipeline('int chain')
        controller.change_selected_version(n)
        controller.flush_delayed_actions()
        p2 = controller.current_pipeline
        result = interpreter.execute(p2,
                                     locator=v,
                                     current_version=n,
                                     view=view)
        # On the second run, only one module is expected to be newly added.
        self.assertEqual(len(result.modules_added), 1)
    finally:
        StandardOutput.compute = old_compute
def test2(self):
    """Smoke test: computing pipeline diffs must not raise."""
    from vistrails.core.db.locator import XMLFileLocator
    import vistrails.core.system
    loaded = XMLFileLocator(
        vistrails.core.system.vistrails_root_directory() +
        '/tests/resources/dummy.xml').load()

    # testing diff
    source, target_a, target_b = 17, 27, 22
    comparisons = [(source, target_a), (source, target_b)]
    for pair in comparisons:
        loaded.get_pipeline_diff(*pair)
    for pair in comparisons:
        loaded.get_pipeline_diff_with_connections(*pair)
def test1(self):
    """Running 'int chain' feeds the single value 2 to StandardOutput."""
    from vistrails.core.modules.basic_modules import StandardOutput
    seen_values = []

    def mycompute(s):
        # Record the 'value' input instead of running the real compute.
        seen_values.append(s.get_input('value'))

    orig_compute = StandardOutput.compute
    StandardOutput.compute = mycompute
    try:
        dummy = (vistrails.core.system.vistrails_root_directory() +
                 '/tests/resources/dummy.xml')
        locator = XMLFileLocator(dummy)
        outcome = run([(locator, "int chain")], update_vistrail=False)
        self.assertEqual(len(outcome), 0)
        self.assertEqual(seen_values, [2])
    finally:
        StandardOutput.compute = orig_compute
def test_tuple(self): from vistrails.core.vistrail.module_param import ModuleParam from vistrails.core.vistrail.module_function import ModuleFunction from vistrails.core.utils import DummyView from vistrails.core.vistrail.module import Module import vistrails.db.domain id_scope = vistrails.db.domain.IdScope() interpreter = vistrails.core.interpreter.default.get_default_interpreter( ) v = DummyView() p = vistrails.core.vistrail.pipeline.Pipeline() params = [ ModuleParam( id=id_scope.getNewId(ModuleParam.vtType), pos=0, type='Float', val='2.0', ), ModuleParam( id=id_scope.getNewId(ModuleParam.vtType), pos=1, type='Float', val='2.0', ) ] function = ModuleFunction(id=id_scope.getNewId(ModuleFunction.vtType), name='input') function.add_parameters(params) module = Module(id=id_scope.getNewId(Module.vtType), name='TestTupleExecution', package='org.vistrails.vistrails.console_mode_test', version='0.9.1') module.add_function(function) p.add_module(module) interpreter.execute(p, locator=XMLFileLocator('foo'), current_version=1L, view=v)
def test2(self):
    """Exercises aliasing on points"""
    from vistrails.core.db.locator import XMLFileLocator
    import vistrails.core.system
    v = XMLFileLocator(vistrails.core.system.vistrails_root_directory() +
                       '/tests/resources/dummy.xml').load()
    p1 = v.getPipeline('final')
    # Extra materialization whose result is deliberately unused; presumably
    # it checks that repeated getPipeline calls do not disturb module
    # positions. TODO confirm.
    v.getPipeline('final')
    p2 = v.getPipeline('final')

    # Both pipelines must contain the same module ids. The former izip over
    # the two sorted item lists stopped silently at the shorter one, so a
    # missing module went unnoticed.
    module_ids = sorted(p1.modules.keys())
    self.assertEqual(module_ids, sorted(p2.modules.keys()))
    for module_id in module_ids:
        m1 = p1.modules[module_id]
        m2 = p2.modules[module_id]
        self.assertEqual(m1.center.x, m2.center.x)
        self.assertEqual(m1.center.y, m2.center.y)
def execute_wf(wf, output_port): # Save the workflow in a temporary file temp_wf_fd, temp_wf = tempfile.mkstemp() try: f = open(temp_wf, 'w') f.write(wf) f.close() os.close(temp_wf_fd) # Clean the cache interpreter = get_default_interpreter() interpreter.flush() # Load the Pipeline from the temporary file vistrail = Vistrail() locator = XMLFileLocator(temp_wf) workflow = locator.load(Pipeline) # Build a Vistrail from this single Pipeline action_list = [] for module in workflow.module_list: action_list.append(('add', module)) for connection in workflow.connection_list: action_list.append(('add', connection)) action = vistrails.core.db.action.create_action(action_list) vistrail.add_action(action, 0L) vistrail.update_id_scope() tag = 'parallel flow' vistrail.addTag(tag, action.id) # Build a controller and execute controller = VistrailController() controller.set_vistrail(vistrail, None) controller.change_selected_version(vistrail.get_version_number(tag)) execution = controller.execute_current_workflow( custom_aliases=None, custom_params=None, extra_info=None, reason='API Pipeline Execution') # Build a list of errors errors = [] pipeline = vistrail.getPipeline(tag) execution_errors = execution[0][0].errors if execution_errors: for key in execution_errors: module = pipeline.modules[key] msg = '%s: %s' % (module.name, execution_errors[key]) errors.append(msg) # Get the execution log from the controller try: module_log = controller.log.workflow_execs[0].item_execs[0] except IndexError: errors.append("Module log not found") return dict(errors=errors) else: machine = controller.log.workflow_execs[0].machines[ module_log.machine_id] xml_log = serialize(module_log) machine_log = serialize(machine) # Get the output value output = None serializable = None if not execution_errors: executed_module, = execution[0][0].executed executed_module = execution[0][0].objects[executed_module] try: output = executed_module.get_output(output_port) except ModuleError: errors.append("Output 
port not found: %s" % output_port) return dict(errors=errors) reg = vistrails.core.modules.module_registry.get_module_registry() base_classes = inspect.getmro(type(output)) if Module in base_classes: serializable = reg.get_descriptor(type(output)).sigstring output = output.serialize() # Return the dictionary, that will be sent back to the client return dict(errors=errors, output=output, serializable=serializable, xml_log=xml_log, machine_log=machine_log) finally: os.unlink(temp_wf)
def execute_wf(wf, output_port): # Save the workflow in a temporary file temp_wf_fd, temp_wf = tempfile.mkstemp() try: f = open(temp_wf, 'w') f.write(wf) f.close() os.close(temp_wf_fd) # Clean the cache interpreter = get_default_interpreter() interpreter.flush() # Load the Pipeline from the temporary file vistrail = Vistrail() locator = XMLFileLocator(temp_wf) workflow = locator.load(Pipeline) # Build a Vistrail from this single Pipeline action_list = [] for module in workflow.module_list: action_list.append(('add', module)) for connection in workflow.connection_list: action_list.append(('add', connection)) action = vistrails.core.db.action.create_action(action_list) vistrail.add_action(action, 0L) vistrail.update_id_scope() tag = 'parallel flow' vistrail.addTag(tag, action.id) # Build a controller and execute controller = VistrailController() controller.set_vistrail(vistrail, None) controller.change_selected_version(vistrail.get_version_number(tag)) execution = controller.execute_current_workflow( custom_aliases=None, custom_params=None, extra_info=None, reason='API Pipeline Execution') # Build a list of errors errors = [] pipeline = vistrail.getPipeline(tag) execution_errors = execution[0][0].errors if execution_errors: for key in execution_errors: module = pipeline.modules[key] msg = '%s: %s' %(module.name, execution_errors[key]) errors.append(msg) # Get the execution log from the controller try: module_log = controller.log.workflow_execs[0].item_execs[0] except IndexError: errors.append("Module log not found") return dict(errors=errors) else: machine = controller.log.workflow_execs[0].machines[ module_log.machine_id] xml_log = serialize(module_log) machine_log = serialize(machine) # Get the output value output = None serializable = None if not execution_errors: executed_module, = execution[0][0].executed executed_module = execution[0][0].objects[executed_module] try: output = executed_module.get_output(output_port) except ModuleError: errors.append("Output port 
not found: %s" % output_port) return dict(errors=errors) reg = vistrails.core.modules.module_registry.get_module_registry() base_classes = inspect.getmro(type(output)) if Module in base_classes: serializable = reg.get_descriptor(type(output)).sigstring output = output.serialize() # Return the dictionary, that will be sent back to the client return dict(errors=errors, output=output, serializable=serializable, xml_log=xml_log, machine_log=machine_log) finally: os.unlink(temp_wf)
def test_empty_action_chain_2(self):
    """The action chain from a version to itself is empty."""
    from vistrails.core.db.locator import XMLFileLocator
    import vistrails.core.system
    dummy_xml = (vistrails.core.system.vistrails_root_directory() +
                 '/tests/resources/dummy.xml')
    vt = XMLFileLocator(dummy_xml).load()
    chain = vt.actionChain(17, 17)
    assert chain == []
def execute(workflowJSON): ''' Execute a workflow from it's JSON representation ''' debug('convert json to xml') workflowXML = json2xml(workflowJSON) #temp_wf_fd, temp_wf = tempfile.mkstemp('.xml') debug('create temporary file') temp_wf_fd, temp_wf = tempfile.mkstemp() try: f = open(temp_wf, 'w') f.write(workflowXML) f.close() os.close(temp_wf_fd) #load workflow temp file into vistrails #vt.load_workflow(temp_wf) #execute workflow #execution = vt.execute() debug('Load the Pipeline from the temporary file') vistrail = Vistrail() locator = XMLFileLocator(temp_wf) workflow = locator.load(Pipeline) debug('Build a Vistrail from this single Pipeline') action_list = [] for module in workflow.module_list: action_list.append(('add', module)) for connection in workflow.connection_list: action_list.append(('add', connection)) action = vistrails.core.db.action.create_action(action_list) debug('add actions') vistrail.add_action(action, 0L) vistrail.update_id_scope() tag = 'climatepipes' vistrail.addTag(tag, action.id) debug('Build a controller and execute') controller = VistrailController() controller.set_vistrail(vistrail, None) controller.change_selected_version(vistrail.get_version_number(tag)) execution = controller.execute_current_workflow( custom_aliases=None, custom_params=None, extra_info=None, reason='API Pipeline Execution') debug('get result') execution_pipeline = execution[0][0] if len(execution_pipeline.errors) > 0: error("Executing workflow") for key in execution_pipeline.errors: error(execution_pipeline.errors[key]) print execution_pipeline.errors[key] return None modules = execution_pipeline.objects for id, module in modules.iteritems(): if isinstance(module, ToGeoJSON): return json.dumps({'result': module.JSON, 'error': None }) finally: os.unlink(temp_wf)
def test_version_graph(self):
    """Smoke test: building the version graph of dummy.xml must not raise."""
    from vistrails.core.db.locator import XMLFileLocator
    import vistrails.core.system
    root = vistrails.core.system.vistrails_root_directory()
    vistrail = XMLFileLocator(root + '/tests/resources/dummy.xml').load()
    vistrail.getVersionGraph()
def test_dynamic_module_error(self):
    """Running dynamic_module_error.xml must produce a non-empty result."""
    resource = (vistrails.core.system.vistrails_root_directory() +
                '/tests/resources/dynamic_module_error.xml')
    outcome = run([(XMLFileLocator(resource), "test")],
                  update_vistrail=False)
    self.assertNotEqual(len(outcome), 0)
def test_python_source(self):
    """Running 'testPortsAndFail' must produce an empty result."""
    resource = (vistrails.core.system.vistrails_root_directory() +
                '/tests/resources/pythonsource.xml')
    outcome = run([(XMLFileLocator(resource), "testPortsAndFail")],
                  update_vistrail=False)
    self.assertEqual(len(outcome), 0)
def execute(modules, connections=None, add_port_specs=None,
            enable_pkg=True, full_results=False):
    """Build a pipeline and execute it.

    This is useful to simply build a pipeline in a test case, and run it. When
    doing that, intercept_result() can be used to check the results of each
    module.

    modules is a list of module tuples describing the modules to be created,
    with the following format:
        [('ModuleName', 'package.identifier', [
            # Functions
            ('port_name', [
                # Function parameters
                ('Signature', 'value-as-string'),
            ]),
        ])]
    A ModuleName of the form 'namespace|Name' sets the module's namespace.

    connections is a list of tuples describing the connections to make, with
    the following format:
        [
            (source_module_index, 'source_port_name',
             dest_module_index, 'dest_port_name'),
         ]

    add_port_specs is a list of specs to add to modules, with the following
    format:
        [
            (mod_id, 'input'/'output', 'portname',
             '(port_sig)'),
        ]
    It is useful to test modules that can have custom ports through a
    configuration widget.

    If a module's package is not loaded and enable_pkg is True, the package
    and its dependencies are enabled on the fly. Otherwise the MissingPackage
    error is re-raised.

    The function returns the 'errors' dict it gets from the interpreter, so
    you should use a construct like self.assertFalse(execute(...)) if the
    execution is not supposed to fail. With full_results=True, the whole
    result object returned by the interpreter is returned instead.

    For example, this creates (and runs) a Float module with its value set to
    44, connected to a PythonCalc module, connected to a StandardOutput:

        self.assertFalse(execute([
                ('Float', 'org.vistrails.vistrails.basic', [
                    ('value', [('Float', '44.0')]),
                ]),
                ('PythonCalc', 'org.vistrails.vistrails.pythoncalc', [
                    ('value2', [('Float', '2.0')]),
                    ('op', [('String', '-')]),
                ]),
                ('StandardOutput', 'org.vistrails.vistrails.basic', []),
            ],
            [
                (0, 'value', 1, 'value1'),
                (1, 'value', 2, 'value'),
            ]))
    """
    from vistrails.core.db.locator import XMLFileLocator
    from vistrails.core.modules.module_registry import MissingPackage
    from vistrails.core.packagemanager import get_package_manager
    from vistrails.core.utils import DummyView
    from vistrails.core.vistrail.connection import Connection
    from vistrails.core.vistrail.module import Module
    from vistrails.core.vistrail.module_function import ModuleFunction
    from vistrails.core.vistrail.module_param import ModuleParam
    from vistrails.core.vistrail.pipeline import Pipeline
    from vistrails.core.vistrail.port import Port
    from vistrails.core.vistrail.port_spec import PortSpec
    from vistrails.core.interpreter.noncached import Interpreter

    # Avoid shared mutable default arguments.
    if connections is None:
        connections = []
    if add_port_specs is None:
        add_port_specs = []

    pm = get_package_manager()

    port_spec_per_module = {}  # mod_id -> [portspec: PortSpec]
    next_psi_id = 0  # ids for PortSpecItems, unique across all port specs
    for i, (mod_id, inout, name, sig) in enumerate(add_port_specs):
        mod_specs = port_spec_per_module.setdefault(mod_id, [])
        ps = PortSpec(id=i,
                      name=name,
                      type=inout,
                      sigstring=sig,
                      sort_key=-1)
        for psi in ps.port_spec_items:
            psi.id = next_psi_id
            next_psi_id += 1
        mod_specs.append(ps)

    pipeline = Pipeline()
    module_list = []
    for i, (name, identifier, functions) in enumerate(modules):
        function_list = []
        try:
            pkg = pm.get_package(identifier)
        except MissingPackage:
            if not enable_pkg:
                raise
            # Enable the package and its dependencies, in dependency order
            dep_graph = pm.build_dependency_graph([identifier])
            for pkg_id in pm.get_ordered_dependencies(dep_graph):
                pkg = pm.identifier_is_available(pkg_id)
                if pkg is None:
                    raise
                pm.late_enable_package(pkg.codepath)
            pkg = pm.get_package(identifier)

        for func_name, params in functions:
            param_list = [ModuleParam(pos=pos, type=param_type, val=param_val)
                          for pos, (param_type, param_val)
                          in enumerate(params)]
            function_list.append(ModuleFunction(name=func_name,
                                                parameters=param_list))

        # 'namespace|Name' -> namespace='namespace', name='Name'
        name = name.rsplit('|', 1)
        if len(name) == 2:
            namespace, name = name
        else:
            namespace = None
            name, = name
        module = Module(name=name,
                        namespace=namespace,
                        package=identifier,
                        version=pkg.version,
                        id=i,
                        functions=function_list)
        for port_spec in port_spec_per_module.get(i, []):
            module.add_port_spec(port_spec)
        pipeline.add_module(module)
        module_list.append(module)

    for i, (sid, sport, did, dport) in enumerate(connections):
        s_sig = module_list[sid].get_port_spec(sport, 'output').sigstring
        d_sig = module_list[did].get_port_spec(dport, 'input').sigstring
        # Connection i owns port ids 2*i (source) and 2*i + 1 (destination)
        pipeline.add_connection(Connection(
            id=i,
            ports=[
                Port(id=i * 2,
                     type='source',
                     moduleId=sid,
                     name=sport,
                     signature=s_sig),
                Port(id=i * 2 + 1,
                     type='destination',
                     moduleId=did,
                     name=dport,
                     signature=d_sig),
            ]))

    interpreter = Interpreter.get()
    result = interpreter.execute(
        pipeline,
        locator=XMLFileLocator('foo.xml'),
        current_version=1,
        view=DummyView())
    if full_results:
        return result
    else:
        # Allows to do self.assertFalse(execute(...))
        return result.errors
def apply_operation_subworkflow(controller, op, subworkflow, args):
    """Build a new Variable by applying an operation stored as a subworkflow.

    :param controller: controller the new Variable is attached to.
    :param op: the requested operation. ``op.parameters`` is matched to
        ``args`` by position, and ``op.return_type`` types the result.
    :param subworkflow: filename of the XML file holding the operation.
    :param args: list of Variable objects, the operation's arguments. Each one
        is connected in place of the subworkflow's InputPort modules whose
        'name' matches the corresponding parameter.
    :returns: a Variable whose output is the (module, port) that fed the
        subworkflow's OutputPort, or None if no connection reached one.
    """
    registry = get_module_registry()
    basic_pkg = 'org.vistrails.vistrails.basic'
    inputport_desc = registry.get_descriptor_by_name(basic_pkg, 'InputPort')
    outputport_desc = registry.get_descriptor_by_name(basic_pkg, 'OutputPort')

    generator = PipelineGenerator(controller)

    # Load the operation subworkflow
    operation_pipeline = get_upgraded_pipeline(
        XMLFileLocator(subworkflow).load())

    # Copy every module except the InputPort/OutputPort placeholders
    placeholders = (inputport_desc, outputport_desc)
    copied = {}  # old module id -> new module
    for module in operation_pipeline.modules.itervalues():
        if module.module_descriptor not in placeholders:
            copied[module.id] = generator.copy_module(module)

    # Re-create internal connections, and remember which modules the
    # InputPorts feed and which module feeds the OutputPort
    param_targets = {}  # param name -> [(module, input port name)]
    output = None  # (module, port name)
    for conn in operation_pipeline.connection_list:
        src_id = conn.source.moduleId
        dst_id = conn.destination.moduleId
        src = operation_pipeline.modules[src_id]
        dest = operation_pipeline.modules[dst_id]
        if src.module_descriptor is inputport_desc:
            param_targets.setdefault(get_function(src, 'name'), []).append(
                (copied[dst_id], conn.destination.name))
        elif dest.module_descriptor is outputport_desc:
            output = (copied[src_id], conn.source.name)
        else:
            generator.connect_modules(copied[src_id], conn.source.name,
                                      copied[dst_id], conn.destination.name)

    # Splice in each argument's own subworkflow and wire its output to the
    # InputPorts named after the matching parameter
    for index, arg in enumerate(args):
        generator.append_operations(arg._generator.operations)
        src_mod = arg._output_module
        src_port = arg._outputport_name
        targets = param_targets.get(op.parameters[index].name, [])
        for dst_mod, dst_port in targets:
            generator.connect_modules(src_mod, src_port, dst_mod, dst_port)

    return Variable(type=op.return_type, controller=controller,
                    generator=generator, output=output)
def apply_operation_subworkflow(controller, op, subworkflow, args):
    """Build a new Variable by applying an operation stored as a subworkflow.

    :param controller: controller the new Variable is attached to.
    :param op: the requested operation. ``op.parameters[i]`` names the
        parameter bound to ``args[i]``, and ``op.return_type`` types the
        result.
    :param subworkflow: filename of an XML file.
    :param args: list of Variable objects. They are connected in place of the
        subworkflow's InputPort modules.
    :returns: the new Variable. Its output is the (module, port) wired to the
        OutputPort, or None if there is none.
    """
    reg = get_module_registry()
    in_desc = reg.get_descriptor_by_name(
        'org.vistrails.vistrails.basic', 'InputPort')
    out_desc = reg.get_descriptor_by_name(
        'org.vistrails.vistrails.basic', 'OutputPort')

    generator = PipelineGenerator(controller)

    # Add the operation subworkflow
    loaded = XMLFileLocator(subworkflow).load()
    op_pipeline = get_upgraded_pipeline(loaded)
    modules = op_pipeline.modules

    # Copy every module but the InputPorts and the OutputPort
    new_modules = {}  # old module id -> new module
    for old in modules.itervalues():
        if old.module_descriptor in (in_desc, out_desc):
            continue
        new_modules[old.id] = generator.copy_module(old)

    # Copy the connections and locate the input ports and the output port
    inputs_by_param = {}  # param name -> [(module, input port name)]
    output = None  # (module, port name)
    for connection in op_pipeline.connection_list:
        source, destination = connection.source, connection.destination
        src = modules[source.moduleId]
        dest = modules[destination.moduleId]
        if src.module_descriptor is in_desc:
            bucket = inputs_by_param.setdefault(get_function(src, 'name'), [])
            bucket.append((new_modules[destination.moduleId],
                           destination.name))
            continue
        if dest.module_descriptor is out_desc:
            output = (new_modules[source.moduleId], source.name)
            continue
        generator.connect_modules(
            new_modules[source.moduleId], source.name,
            new_modules[destination.moduleId], destination.name)

    # Add the parameter subworkflows
    for position, variable in enumerate(args):
        generator.append_operations(variable._generator.operations)
        o_mod = variable._output_module
        o_port = variable._outputport_name
        param_name = op.parameters[position].name
        for i_mod, i_port in inputs_by_param.get(param_name, []):
            generator.connect_modules(o_mod, o_port, i_mod, i_port)

    return Variable(
        type=op.return_type,
        controller=controller,
        generator=generator,
        output=output)
def create_pipeline(controller, recipe, row, column, var_sheetname,
                    typecast=None):
    """Create a pipeline from a recipe and return its information.

    Starting from the root version, the plot's subworkflow (or the pipeline
    produced by its callback) is copied into a new version.
    - The SpreadsheetCell gets a CellLocation and a SheetReference, which are
      created if the plot does not contain them.
    - Each recipe parameter, or a port's default value, is connected to the
      plot's InputPorts or aliased functions.

    :param controller: controller the new version is added to. Its selected
        version is changed to the new pipeline.
    :param recipe: the recipe (plot + parameters) to materialize.
    :param row, column: written to the CellLocation as ``row + 1`` and
        ``column + 1``.
    :param var_sheetname: connected to the SheetReference's 'SheetName' port.
    :param typecast: forwarded to add_variable_subworkflow_typecast().
    :returns: PipelineInformation(version, DATRecipe, conn_map, port_map).
    :raises ValueError: if the plot callback returns an unrecognized value,
        or if the plot has neither a subworkflow nor a callback.
    """
    # Build from the root version
    controller.change_selected_version(0)

    reg = get_module_registry()

    generator = PipelineGenerator(controller)

    inputport_desc = reg.get_descriptor_by_name(
        'org.vistrails.vistrails.basic', 'InputPort')

    # Add the plot subworkflow
    if recipe.plot.subworkflow is not None:
        locator = XMLFileLocator(recipe.plot.subworkflow)
        vistrail = locator.load()
        plot_pipeline = get_upgraded_pipeline(vistrail)
    elif recipe.plot.callback is not None:
        callback_ret = recipe.plot.callback()
        if isinstance(callback_ret, Pipeline):
            plot_pipeline = callback_ret
        elif callback_ret[0] == 'pipeline':
            plot_pipeline, = callback_ret[1:]
        elif callback_ret[0] == 'python_lists':
            plot_pipeline = build_pipeline(*callback_ret[1:])
        else:
            # Wrap in a 1-tuple. A tuple value would otherwise be consumed as
            # the format arguments and raise TypeError instead.
            raise ValueError("Plot callback returned invalid value %r" %
                             (callback_ret[0],))
    else:
        # This was 'assert False', which vanishes under 'python -O' and then
        # led to a NameError on plot_pipeline below.
        raise ValueError("Plot '%s' has neither a subworkflow nor a "
                         "callback" % recipe.plot.name)

    # Modules feeding an InputPort (e.g. its 'Default' value) are not copied
    connected_to_inputport = set(
        c.source.moduleId
        for c in plot_pipeline.connection_list
        if (plot_pipeline.modules[c.destination.moduleId]
            .module_descriptor is inputport_desc))

    # Copy every module but the InputPorts and up
    plot_modules_map = dict()  # old module id -> new module
    for module in plot_pipeline.modules.itervalues():
        if (module.module_descriptor is not inputport_desc and
                module.id not in connected_to_inputport):
            plot_modules_map[module.id] = generator.copy_module(module)

    del connected_to_inputport

    def _get_or_create_module(moduleType):
        """Returns or creates a new module of the given type.

        Returns (module, created). If the plot already has such a module,
        its copy is returned with created=False. Warns if multiple modules
        of that type were found.
        """
        modules = find_modules_by_type(plot_pipeline, [moduleType])
        if not modules:
            desc = reg.get_descriptor_from_module(moduleType)
            module = controller.create_module_from_descriptor(desc)
            generator.add_module(module)
            return module, True
        else:
            # Currently we do not support multiple cell locations in one
            # pipeline but this may be a feature in the future, to have
            # linked visualizations in multiple cells
            if len(modules) > 1:
                warnings.warn("Found multiple %s modules in plot "
                              "subworkflow, only using one." % moduleType)
            return plot_modules_map[modules[0].id], False

    # Connect the CellLocation to the SpreadsheetCell
    cell_modules = find_modules_by_type(plot_pipeline,
                                        [SpreadsheetCell])
    if cell_modules:
        cell_module = plot_modules_map[cell_modules[0].id]

        # Add a CellLocation module if the plot subworkflow didn't contain one
        location_module, new_location = _get_or_create_module(CellLocation)

        if new_location:
            # Connect the CellLocation to the SpreadsheetCell
            generator.connect_modules(
                location_module, 'value',
                cell_module, 'Location')

        generator.update_function(
            location_module, 'Row', [str(row + 1)])
        generator.update_function(
            location_module, 'Column', [str(column + 1)])

        if len(cell_modules) > 1:
            warnings.warn("Plot subworkflow '%s' contains more than "
                          "one spreadsheet cell module. Only one "
                          "was connected to a location module." %
                          recipe.plot.name)

        # Add a SheetReference module
        sheetref_module, new_sheetref = _get_or_create_module(SheetReference)

        if new_sheetref or new_location:
            # Connection the SheetReference to the CellLocation
            generator.connect_modules(
                sheetref_module, 'value',
                location_module, 'SheetReference')

        generator.connect_var(
            var_sheetname,
            sheetref_module,
            'SheetName')
    else:
        warnings.warn("Plot subworkflow '%s' does not contain a "
                      "spreadsheet cell module" % recipe.plot.name)

    # TODO : use walk_modules() to find all modules above an InputPort's
    # 'Default' port and ignore them in the following loop

    # Copy the connections and locate the input ports
    plot_params = dict()  # param name -> [(module, input port name)]
    for connection in plot_pipeline.connection_list:
        src = plot_pipeline.modules[connection.source.moduleId]
        dest = plot_pipeline.modules[connection.destination.moduleId]
        if dest.module_descriptor is inputport_desc:
            continue
        elif src.module_descriptor is inputport_desc:
            param = get_function(src, 'name')
            ports = plot_params.setdefault(param, [])
            ports.append((
                plot_modules_map[connection.destination.moduleId],
                connection.destination.name))
        else:
            generator.connect_modules(
                plot_modules_map[connection.source.moduleId],
                connection.source.name,
                plot_modules_map[connection.destination.moduleId],
                connection.destination.name)

    # Find the constant ports declared with aliases
    aliases = {port.name: port for port in recipe.plot.ports if port.is_alias}
    for module in plot_pipeline.module_list:
        for function in module.functions:
            remove = False
            for param in function.parameters:
                if param.alias in aliases:
                    plot_params[param.alias] = [(
                        plot_modules_map[module.id],
                        function.name)]
                    remove = True

            if remove:
                # Remove the function from the generated pipeline
                generator.update_function(
                    plot_modules_map[module.id],
                    function.name,
                    None)
    del aliases

    # Adds default values for unset constants
    parameters_incl_defaults = dict(recipe.parameters)
    for port in recipe.plot.ports:
        if (isinstance(port, ConstantPort) and
                port.default_value is not None and
                port.name not in recipe.parameters):
            parameters_incl_defaults[port.name] = [RecipeParameterValue(
                constant=port.default_value)]

    # Maps a port name to the list of parameters
    # for each parameter, we have a list of connections tying it to modules of
    # the plot
    conn_map = dict()  # param: str -> [[conn_id: int]]

    name_to_port = {port.name: port for port in recipe.plot.ports}
    actual_parameters = {}
    for port_name, parameters in parameters_incl_defaults.iteritems():
        plot_ports = plot_params.get(port_name, [])
        p_conns = conn_map[port_name] = []
        actual_values = []
        for parameter in parameters:
            if parameter.type == RecipeParameterValue.VARIABLE:
                conns, actual_param = add_variable_subworkflow_typecast(
                    generator,
                    parameter.variable,
                    plot_ports,
                    name_to_port[port_name].type,
                    typecast=typecast)
                p_conns.append(conns)
                actual_values.append(actual_param)
            else:  # parameter.type == RecipeParameterValue.CONSTANT
                desc = name_to_port[port_name].type
                p_conns.append(add_constant_module(
                    generator,
                    desc,
                    parameter.constant,
                    plot_ports))
                actual_values.append(parameter)
        actual_parameters[port_name] = actual_values
    del name_to_port

    pipeline_version = generator.perform_action()
    controller.vistrail.change_description(
        "Created DAT plot %s" % recipe.plot.name,
        pipeline_version)
    # FIXME : from_root seems to be necessary here, I don't know why
    controller.change_selected_version(pipeline_version, from_root=True)

    # Convert the modules to module ids in the port_map
    port_map = dict()
    for param, portlist in plot_params.iteritems():
        port_map[param] = [(module.id, port) for module, port in portlist]

    return PipelineInformation(
        pipeline_version,
        DATRecipe(recipe.plot, actual_parameters),
        conn_map, port_map)
def test15(self):
    """Load the dummy.xml test vistrail from the VisTrails resources.

    The test makes no assertions: it passes as long as locating and
    loading the file raises no exception.
    """
    # NOTE(review): this module is not referenced below; it is kept in case
    # importing it has side effects the load relies on -- confirm.
    import vistrails.core.vistrail
    from vistrails.core.db.locator import XMLFileLocator
    import vistrails.core.system

    resource_path = (vistrails.core.system.vistrails_root_directory() +
                     '/tests/resources/dummy.xml')
    locator = XMLFileLocator(resource_path)
    loaded_vistrail = locator.load()
def _read_metadata(self, package_identifier): """Reads a plot's ports from the subworkflow file Finds each InputPort module and gets the parameter name, optional flag and type from its 'name', 'optional' and 'spec' input functions. If input ports were declared in this Plot, we check that they are indeed present and were all listed (either list all of them or none). If the module type is a subclass of Constant, we will assume the port is to be set via direct input (ConstantPort), else by dragging a variable (DataPort). We also automatically add aliased input ports of compatible constant types as optional ConstantPort's. """ if self.subworkflow is None: return locator = XMLFileLocator(self.subworkflow) vistrail = locator.load() pipeline = get_upgraded_pipeline(vistrail) inputports = find_modules_by_type(pipeline, [InputPort]) if not inputports: raise ValueError("No InputPort module") currentports = {port.name: port for port in self.ports} seenports = set() for port in inputports: name = get_function(port, 'name') if not name: raise ValueError( "Subworkflow of plot '%s' in package '%s' has an " "InputPort with no name" % ( self.name, package_identifier)) if name in seenports: raise ValueError( "Subworkflow of plot '%s' in package '%s' has several " "InputPort modules with name '%s'" % ( self.name, package_identifier, name)) spec = get_function(port, 'spec') optional = get_function(port, 'optional') if optional == 'True': optional = True elif optional == 'False': optional = False else: optional = None try: currentport = currentports[name] except KeyError: # If the package didn't provide any port, it's ok, we can # discover them. 
But if some were present and some were # forgotten, emit a warning if currentports: warnings.warn( "Declaration of plot '%s' in package '%s' omitted " "port '%s'" % ( self.name, package_identifier, name)) if not spec: warnings.warn( "Subworkflow of plot '%s' in package '%s' has an " "InputPort '%s' with no type; assuming Module" % ( self.name, package_identifier, name)) spec = 'org.vistrails.vistrails.basic:Module' if not optional: optional = False type = resolve_descriptor(spec, package_identifier) if issubclass(type.module, Constant): currentport = ConstantPort( name=name, type=type, optional=optional) else: currentport = DataPort( name=name, type=type, optional=optional) self.ports.append(currentport) else: currentspec = (currentport.type.identifier + ':' + currentport.type.name) if ((spec and spec != currentspec) or (optional is not None and optional != currentport.optional)): warnings.warn( "Declaration of port '%s' from plot '%s' in " "package '%s' differs from subworkflow " "contents" % ( name, self.name, package_identifier)) spec = currentspec type = resolve_descriptor(currentspec, package_identifier) # Get info from the PortSpec currentport.default_value = None currentport.enumeration = None try: (default_type, default_value, entry_type, enum_values) = read_port_specs( pipeline, port) if default_value is not None: if not issubclass(default_type, type.module): raise ValueError("incompatible type %r" % (( default_type, type.module),)) elif default_type is type.module: currentport.default_value = default_value currentport.entry_type = entry_type currentport.enum_values = enum_values except ValueError, e: raise ValueError( "Error reading specs for port '%s' from plot '%s' of " "package '%s': %s" % ( name, self.name, package_identifier, e.args[0])) seenports.add(name)