def test_exit_codes_filter(self):
    """Test that the `exit_codes` argument properly filters, returning `None` if the `node` has different status.

    A single handler filtered on ``ExitCode(400)`` is registered. Called with a node whose exit status matches, it
    should return its ``ProcessHandlerReport``; called with a node with another exit status, it should return ``None``.
    """
    exit_code_filter = ExitCode(400)

    # This process node should match the exit code filter of the error handler
    node_match = ProcessNode()
    node_match.set_exit_status(exit_code_filter.status)

    # This one should not because it has a different exit status
    node_skip = ProcessNode()
    node_skip.set_exit_status(200)  # Some other exit status

    class ArithmeticAddBaseWorkChain(BaseRestartWorkChain):
        """Minimal base restart workchain for the ``ArithmeticAddCalculation``."""

        _process_class = ArithmeticAddCalculation

        @process_handler(exit_codes=exit_code_filter)
        def _(self, node):
            return ProcessHandlerReport()

    # Create dummy process instance
    process = ArithmeticAddBaseWorkChain()

    # Materialize the handlers and pin their count: if none were registered, a bare loop over them would assert
    # nothing and the test would pass vacuously.
    handlers = list(process.get_process_handlers())
    assert len(handlers) == 1

    handler = handlers[0]

    # The `node_match` should match the `exit_codes` filter and so return a report instance
    assert isinstance(handler(process, node_match), ProcessHandlerReport)

    # The `node_skip` has a wrong exit status and so should get skipped, returning `None`
    assert handler(process, node_skip) is None
def test_empty_exit_codes_list(self):
    """Check that a ``process_handler`` declared with an empty ``exit_codes`` list is not run.

    The registered handler raises ``ValueError`` if it is ever called, so ``inspect_process`` must complete
    without invoking it.
    """

    class SomeWorkChain(BaseRestartWorkChain):
        _process_class = ArithmeticAddCalculation

        @process_handler(exit_codes=[])
        def should_not_run(self, node):
            raise ValueError('This should not run.')

    finished_child = ProcessNode()
    finished_child.set_process_state(ProcessState.FINISHED)

    workchain = SomeWorkChain()
    workchain.setup()
    workchain.ctx.iteration = 1
    workchain.ctx.children = [finished_child]

    workchain.inspect_process()
def test_priority(self):
    """Test that the handlers are called in order of their `priority`.

    Four handlers with priorities 100, 200, 300 and 400 are registered in a shuffled order. Each one appends its
    name to a list attribute on the child node and returns a report with a distinct exit code. The test checks the
    recorded call order and that ``inspect_process`` returns the exit code of the last handler called.
    """
    attribute_key = 'handlers_called'

    def record_call(node, name):
        """Append ``name`` to the list stored under ``attribute_key`` on ``node``."""
        handlers_called = node.get_attribute(attribute_key, default=[])
        handlers_called.append(name)
        node.set_attribute(attribute_key, handlers_called)

    class ArithmeticAddBaseWorkChain(BaseRestartWorkChain):
        """Implementation of a possible BaseRestartWorkChain for the ``ArithmeticAddCalculation``."""

        _process_class = ArithmeticAddCalculation

        # Register some handlers that should be called in order of 4 -> 3 -> 2 -> 1 but are on purpose registered in
        # a different order. When called, they should add their name to `handlers_called` attribute of the node.
        # This can then be checked after invoking `inspect_process` to ensure they were called in the right order

        @process_handler(priority=100)
        def handler_01(self, node):
            """Example handler returning ExitCode 100."""
            record_call(node, 'handler_01')
            return ProcessHandlerReport(False, ExitCode(100))

        @process_handler(priority=300)
        def handler_03(self, node):
            """Example handler returning ExitCode 300."""
            record_call(node, 'handler_03')
            return ProcessHandlerReport(False, ExitCode(300))

        @process_handler(priority=200)
        def handler_02(self, node):
            """Example handler returning ExitCode 200."""
            record_call(node, 'handler_02')
            return ProcessHandlerReport(False, ExitCode(200))

        @process_handler(priority=400)
        def handler_04(self, node):
            """Example handler returning ExitCode 400."""
            record_call(node, 'handler_04')
            return ProcessHandlerReport(False, ExitCode(400))

    child = ProcessNode()
    child.set_process_state(ProcessState.FINISHED)
    child.set_exit_status(400)
    process = ArithmeticAddBaseWorkChain()
    process.setup()
    process.ctx.iteration = 1
    process.ctx.children = [child]

    # Last called handler should be `handler_01` which returned `ExitCode(100)`
    assert process.inspect_process() == ExitCode(100)
    assert child.get_attribute(attribute_key, []) == ['handler_04', 'handler_03', 'handler_02', 'handler_01']
def exposed_outputs(
    self,
    node: orm.ProcessNode,
    process_class: Type['Process'],
    namespace: Optional[str] = None,
    agglomerate: bool = True
) -> AttributeDict:
    """Return the outputs which were exposed from the ``process_class`` and emitted by the specific ``node``

    :param node: process node whose outputs to try and retrieve
    :param process_class: the process class whose outputs were exposed; used as the key into this process's
        spec to look up which output names were exposed in each namespace
    :param namespace: Namespace in which to search for exposed outputs.
    :param agglomerate: If set to true, all parent namespaces of the given ``namespace`` will also
        be searched for outputs. Outputs in lower-lying namespaces take precedence.

    :returns: exposed outputs. Each key is the output port name, prefixed with ``<namespace><separator>`` when
        the output was exposed in a non-``None`` namespace.
    """
    spec = self.spec()
    namespace_separator = spec.namespace_separator
    exposed_outputs_map = spec._exposed_outputs  # pylint: disable=protected-access

    # maps each exposed top-level name to the namespace it was found exposed in
    output_key_map = {}
    # maps each top-level name to all output port names that belong to it
    top_namespace_map = collections.defaultdict(list)
    link_types = (LinkType.CREATE, LinkType.RETURN)
    process_outputs_dict = node.get_outgoing(link_type=link_types).nested()

    for port_name in process_outputs_dict:
        top_namespace = port_name.split(namespace_separator)[0]
        top_namespace_map[top_namespace].append(port_name)

    for port_namespace in self._get_namespace_list(namespace=namespace, agglomerate=agglomerate):
        # Only consult the spec when there is something to match. This keeps the mapping from being accessed
        # (and, if it is a defaultdict, populated) when the node has no outputs.
        if not top_namespace_map:
            continue
        # only the top-level key is stored in _exposed_outputs; look the names up once per namespace
        exposed_names = exposed_outputs_map[port_namespace][process_class]
        for top_name in top_namespace_map:
            if top_name in exposed_names:
                # A match in a later namespace of the list overwrites an earlier one.
                output_key_map[top_name] = port_namespace

    result = {}

    for top_name, port_namespace in output_key_map.items():
        # collect all outputs belonging to the given top_name
        for port_name in top_namespace_map[top_name]:
            if port_namespace is None:
                result[port_name] = process_outputs_dict[port_name]
            else:
                result[port_namespace + namespace_separator + port_name] = process_outputs_dict[port_name]

    return AttributeDict(result)