コード例 #1
0
    def test_exit_codes_filter(self):
        """Test that the `exit_codes` argument properly filters, returning `None` if the `node` has different status."""

        exit_code_filter = ExitCode(400)

        # This process node should match the exit code filter of the error handler
        node_match = ProcessNode()
        node_match.set_exit_status(exit_code_filter.status)

        # This one should not because it has a different exit status
        node_skip = ProcessNode()
        node_skip.set_exit_status(200)  # Some other exit status

        class ArithmeticAddBaseWorkChain(BaseRestartWorkChain):
            """Minimal base restart workchain for the ``ArithmeticAddCalculation``."""

            _process_class = ArithmeticAddCalculation

            @process_handler(exit_codes=exit_code_filter)
            def _(self, node):
                return ProcessHandlerReport()

        # Create dummy process instance
        process = ArithmeticAddBaseWorkChain()

        # Exactly one handler is registered above. Check that explicitly: with an empty handler list the loop below
        # would run zero times and the test would pass without asserting anything.
        handlers = list(process.get_process_handlers())
        assert len(handlers) == 1

        for handler in handlers:
            # The `node_match` should match the `exit_codes` filter and so return a report instance
            assert isinstance(handler(process, node_match), ProcessHandlerReport)

            # The `node_skip` has a wrong exit status and so should get skipped, returning `None`
            assert handler(process, node_skip) is None
コード例 #2
0
    def test_empty_exit_codes_list(self):
        """A ``process_handler`` declared with an empty ``exit_codes`` list must never be invoked.

        The handler raises ``ValueError`` if it is called, so ``inspect_process`` returning normally is the check.
        """
        class SomeWorkChain(BaseRestartWorkChain):
            _process_class = ArithmeticAddCalculation

            @process_handler(exit_codes=[])
            def should_not_run(self, node):
                raise ValueError('This should not run.')

        finished_child = ProcessNode()
        finished_child.set_process_state(ProcessState.FINISHED)

        workchain = SomeWorkChain()
        workchain.setup()
        workchain.ctx.iteration = 1
        workchain.ctx.children = [finished_child]

        # No assertion needed: an invocation of `should_not_run` would surface as a `ValueError` here.
        workchain.inspect_process()
コード例 #3
0
    def test_priority(self):
        """Test that the handlers are called in order of their `priority`."""
        attribute_key = 'handlers_called'

        def record_call(node, handler_name):
            """Append ``handler_name`` to the list stored under ``attribute_key`` on ``node``."""
            handlers_called = node.get_attribute(attribute_key, default=[])
            handlers_called.append(handler_name)
            node.set_attribute(attribute_key, handlers_called)

        class ArithmeticAddBaseWorkChain(BaseRestartWorkChain):
            """Implementation of a possible BaseRestartWorkChain for the ``ArithmeticAddCalculation``."""

            _process_class = ArithmeticAddCalculation

            # Register handlers that should be called in order 4 -> 3 -> 2 -> 1 (highest priority first), but that are
            # deliberately registered in a different order. When called, each records its name in the
            # `handlers_called` attribute of the node. After invoking `inspect_process`, that list shows whether they
            # were called in the right order.
            @process_handler(priority=100)
            def handler_01(self, node):
                """Example handler returning ExitCode 100."""
                record_call(node, 'handler_01')
                return ProcessHandlerReport(False, ExitCode(100))

            @process_handler(priority=300)
            def handler_03(self, node):
                """Example handler returning ExitCode 300."""
                record_call(node, 'handler_03')
                return ProcessHandlerReport(False, ExitCode(300))

            @process_handler(priority=200)
            def handler_02(self, node):
                """Example handler returning ExitCode 200."""
                record_call(node, 'handler_02')
                return ProcessHandlerReport(False, ExitCode(200))

            @process_handler(priority=400)
            def handler_04(self, node):
                """Example handler returning ExitCode 400."""
                record_call(node, 'handler_04')
                return ProcessHandlerReport(False, ExitCode(400))

        child = ProcessNode()
        child.set_process_state(ProcessState.FINISHED)
        child.set_exit_status(400)
        process = ArithmeticAddBaseWorkChain()
        process.setup()
        process.ctx.iteration = 1
        process.ctx.children = [child]

        # Last called handler should be `handler_01` which returned `ExitCode(100)`
        assert process.inspect_process() == ExitCode(100)
        assert child.get_attribute(attribute_key, []) == [
            'handler_04', 'handler_03', 'handler_02', 'handler_01'
        ]
コード例 #4
0
ファイル: process.py プロジェクト: zhonger/aiida-core
    def exposed_outputs(self,
                        node: orm.ProcessNode,
                        process_class: Type['Process'],
                        namespace: Optional[str] = None,
                        agglomerate: bool = True) -> AttributeDict:
        """Return the outputs which were exposed from the ``process_class`` and emitted by the specific ``node``

        :param node: process node whose outputs to try and retrieve
        :param process_class: process class whose exposed outputs are looked up in this process's spec
        :param namespace: Namespace in which to search for exposed outputs.
        :param agglomerate: If set to true, all parent namespaces of the given ``namespace`` will also
            be searched for outputs. Outputs in lower-lying namespaces take precedence.

        :returns: exposed outputs, keyed by output name, prefixed with the namespace (joined by the spec's
            namespace separator) whenever that namespace is not ``None``

        """
        # Resolve the spec once instead of on every iteration of the nested loop below.
        spec = self.spec()
        namespace_separator = spec.namespace_separator
        exposed_outputs_by_namespace = spec._exposed_outputs  # pylint: disable=protected-access

        # Group the node's CREATE/RETURN outputs by their top-level name, because only top-level keys are stored in
        # ``_exposed_outputs``.
        top_namespace_map = collections.defaultdict(list)
        link_types = (LinkType.CREATE, LinkType.RETURN)
        process_outputs_dict = node.get_outgoing(link_type=link_types).nested()

        for port_name in process_outputs_dict:
            top_namespace = port_name.split(namespace_separator)[0]
            top_namespace_map[top_namespace].append(port_name)

        # Map each exposed top-level name to the namespace it was found in. A namespace later in the list overwrites
        # an earlier match (presumably `_get_namespace_list` runs from outermost to innermost, which gives the
        # documented "lower-lying namespaces take precedence" — TODO confirm).
        output_key_map = {}
        for port_namespace in self._get_namespace_list(namespace=namespace, agglomerate=agglomerate):
            for top_name in top_namespace_map:
                if top_name in exposed_outputs_by_namespace[port_namespace][process_class]:
                    output_key_map[top_name] = port_namespace

        result = {}
        for top_name, port_namespace in output_key_map.items():
            # Collect all outputs belonging to the given top_name, prefixing the namespace when there is one.
            for port_name in top_namespace_map[top_name]:
                if port_namespace is None:
                    key = port_name
                else:
                    key = port_namespace + namespace_separator + port_name
                result[key] = process_outputs_dict[port_name]

        return AttributeDict(result)