# Example No. 1
def test_protocol_graph_simple(protocols_a, protocols_b):
    """Check how a ``ProtocolGraph`` handles two sets of protocols.

    When ``protocols_b`` is added in a separate call after ``protocols_a``,
    the graph is expected to merge it into the existing protocols. The
    protocol count, dependants-graph size and number of root protocols
    therefore stay unchanged.

    When both sets are added in a single call, no merging is currently
    expected. Every protocol is kept, and the number of root protocols is
    twice the count observed for ``protocols_a`` alone.
    """
    expected_a = len(protocols_a)
    expected_all = expected_a + len(protocols_b)

    # Part 1: add the two sets one after the other -> expect a merge.
    sequential_graph = ProtocolGraph()
    sequential_graph.add_protocols(*protocols_a)

    dependants = sequential_graph._build_dependants_graph(
        sequential_graph.protocols, False, apply_reduction=True)

    assert len(sequential_graph.protocols) == expected_a
    assert len(dependants) == expected_a
    root_count = len(sequential_graph.root_protocols)

    sequential_graph.add_protocols(*protocols_b)

    dependants = sequential_graph._build_dependants_graph(
        sequential_graph.protocols, False, apply_reduction=False)

    assert len(sequential_graph.protocols) == expected_a
    assert len(dependants) == expected_a
    assert len(sequential_graph.root_protocols) == root_count

    # Part 2: add both sets in a single call -> currently no merge is
    # expected, so every protocol is retained.
    combined_graph = ProtocolGraph()
    combined_graph.add_protocols(*protocols_a, *protocols_b)

    dependants = combined_graph._build_dependants_graph(
        combined_graph.protocols, False, apply_reduction=False)

    assert len(combined_graph.protocols) == expected_all
    assert len(dependants) == expected_all
    assert len(combined_graph.root_protocols) == 2 * root_count
# Example No. 2
def test_protocol_graph_execution(calculation_backend, compute_resources):
    """Execute a two-protocol graph and check the value is propagated.

    ``protocol_b`` takes its input from ``protocol_a``'s output, so after
    execution ``protocol_b``'s serialized ``.output_value`` must equal
    ``protocol_a.input_value``.

    If ``calculation_backend`` is not ``None`` it is started before the graph
    runs. It is always stopped afterwards, even when execution or an
    assertion fails, so a failing run does not leave it running.
    """
    if calculation_backend is not None:
        calculation_backend.start()

    try:
        protocol_a = DummyInputOutputProtocol("protocol_a")
        protocol_a.input_value = 1
        protocol_b = DummyInputOutputProtocol("protocol_b")
        protocol_b.input_value = ProtocolPath("output_value", protocol_a.id)

        protocol_graph = ProtocolGraph()
        protocol_graph.add_protocols(protocol_a, protocol_b)

        with tempfile.TemporaryDirectory() as directory:

            results = protocol_graph.execute(directory, calculation_backend,
                                             compute_resources)

            final_result = results[protocol_b.id]

            # With a backend, the entry is presumably a future whose
            # ``result()`` must be fetched before it can be used.
            if calculation_backend is not None:
                final_result = final_result.result()

            # The second element is opened as a JSON file containing the
            # protocol's serialized outputs.
            with open(final_result[1]) as file:
                results_b = json.load(file, cls=TypedJSONDecoder)

        assert results_b[".output_value"] == protocol_a.input_value

        # NOTE(review): the in-memory protocol is only checked when
        # compute_resources is given; presumably only that execution path
        # updates the local object. Confirm.
        if compute_resources is not None:
            assert protocol_b.output_value == protocol_a.input_value

    finally:
        # Always shut the backend down, even if execution or an assertion
        # above failed, so it is not left running for later tests.
        if calculation_backend is not None:
            calculation_backend.stop()
# Example No. 3
def test_protocol_group_resume():
    """A test that protocol groups can recover after being killed
    (e.g. by a worker being killed due to hitting a wallclock limit).

    Steps:

    1. Execute a group containing two protocols.
    2. Delete its output file so the group looks unfinished.
    3. Execute, in the same directory, a group with the same id that has
       two extra protocols.
    4. Check that all of the second group's outputs end up defined.
    """
    import tempfile  # Function-scope import keeps this test self-contained.

    compute_resources = ComputeResources()

    # Execute inside a temporary directory. This stops the test leaving a
    # ``graph_a`` folder in the working directory, and stops files left
    # behind by an earlier run from influencing the result.
    with tempfile.TemporaryDirectory() as root_directory:

        graph_directory = os.path.join(root_directory, "graph_a")

        # Fake a protocol group which executes the first
        # two protocols and then 'gets killed'.
        protocol_a = DummyInputOutputProtocol("protocol_a")
        protocol_a.input_value = 1
        protocol_b = DummyInputOutputProtocol("protocol_b")
        protocol_b.input_value = ProtocolPath("output_value", protocol_a.id)

        protocol_group_a = ProtocolGroup("group_a")
        protocol_group_a.add_protocols(protocol_a, protocol_b)

        protocol_graph = ProtocolGraph()
        protocol_graph.add_protocols(protocol_group_a)
        protocol_graph.execute(graph_directory,
                               compute_resources=compute_resources)

        # Remove the group's output file so it appears that the protocol
        # group had not completed.
        os.unlink(
            os.path.join(graph_directory, protocol_group_a.id,
                         f"{protocol_group_a.id}_output.json"))

        # Build the 'full' group. It contains the same first two protocols
        # plus the last two, which had not been executed when the group
        # was 'killed'.
        protocol_a = DummyInputOutputProtocol("protocol_a")
        protocol_a.input_value = 1
        protocol_b = DummyInputOutputProtocol("protocol_b")
        protocol_b.input_value = ProtocolPath("output_value", protocol_a.id)
        protocol_c = DummyInputOutputProtocol("protocol_c")
        protocol_c.input_value = ProtocolPath("output_value", protocol_b.id)
        protocol_d = DummyInputOutputProtocol("protocol_d")
        protocol_d.input_value = ProtocolPath("output_value", protocol_c.id)

        protocol_group_a = ProtocolGroup("group_a")
        protocol_group_a.add_protocols(protocol_a, protocol_b, protocol_c,
                                       protocol_d)

        protocol_graph = ProtocolGraph()
        protocol_graph.add_protocols(protocol_group_a)
        protocol_graph.execute(graph_directory,
                               compute_resources=compute_resources)

        assert all(x != UNDEFINED for x in protocol_group_a.outputs.values())
# Example No. 4
def test_protocol_group_merging():
    """Check that two identically shaped protocol chains are merged.

    Each chain contains a protocol group. When both chains are added to
    the same ``ProtocolGraph`` in separate calls, they should collapse into
    one chain, and the surviving group must keep the original group's
    schema.
    """

    def _make_chain(prefix):
        # Shape of the returned chain (g..l live inside the group):
        #     .-------------------.
        #     |          / i - j -|- b
        # a - | g - h - |         |
        #     |          \ k - l -|- c
        #     .-------------------.
        head = DummyInputOutputProtocol(prefix + "protocol_a")
        head.input_value = 1

        fork = build_fork(prefix)
        fork[0].input_value = ProtocolPath("output_value", head.id)

        group = ProtocolGroup(prefix + "protocol_group")
        group.add_protocols(*fork)

        upper_tail = DummyInputOutputProtocol(prefix + "protocol_b")
        upper_tail.input_value = ProtocolPath(
            "output_value", group.id, "protocol_j")

        lower_tail = DummyInputOutputProtocol(prefix + "protocol_c")
        lower_tail.input_value = ProtocolPath(
            "output_value", group.id, "protocol_l")

        return [head, group, upper_tail, lower_tail]

    first_chain = _make_chain("a_")
    second_chain = _make_chain("b_")

    graph = ProtocolGraph()
    for chain in (first_chain, second_chain):
        graph.add_protocols(*chain)

    # The second chain should have been folded into the first one.
    assert len(graph.protocols) == len(first_chain)
    assert "a_protocol_group" in graph.protocols

    expected_schema = first_chain[1].schema.json()
    actual_schema = graph.protocols["a_protocol_group"].schema.json()
    assert expected_schema == actual_schema
# Example No. 5
    def __init__(self):

        super(WorkflowGraph, self).__init__()

        self._workflows_to_execute = {}
        self._protocol_graph = ProtocolGraph()