def test_get_nodes():
    """Check that ``EOWorkflow.get_nodes`` returns the workflow's nodes in the order given,
    and that executing those nodes by hand, one after another, reproduces the workflow's
    output.
    """
    in_node = EONode(InputTask())
    inc_node0 = EONode(IncTask(), inputs=[in_node])
    inc_node1 = EONode(IncTask(), inputs=[inc_node0])
    inc_node2 = EONode(IncTask(), inputs=[inc_node1])
    output_node = EONode(OutputTask(name="out"), inputs=[inc_node2])

    eow = EOWorkflow([in_node, inc_node0, inc_node1, inc_node2, output_node])

    returned_nodes = eow.get_nodes()
    assert [
        in_node,
        inc_node0,
        inc_node1,
        inc_node2,
        output_node,
    ] == returned_nodes, "Returned nodes differ from original nodes"

    arguments_dict = {in_node: {"val": 2}, inc_node0: {"d": 2}}
    workflow_res = eow.execute(arguments_dict)

    # The nodes form a linear chain (each node's only input is the previous node), so running
    # them manually means feeding each task the previous task's single result as its positional
    # argument, plus whatever keyword arguments ``arguments_dict`` holds for that node.
    manual_res = []
    for node in returned_nodes:
        manual_res = [node.task.execute(*manual_res, **arguments_dict.get(node, {}))]

    assert (
        workflow_res.outputs["out"] == manual_res[0]
    ), "Manually running returned nodes produces different results."
def test_workflow_from_endnodes():
    """Check that ``EOWorkflow.from_endnodes`` builds the same workflow as listing nodes explicitly.

    Verifies that:
      * the resulting workflow has the same set of nodes as the explicit one,
      * both workflows produce identical outputs when executed in worker processes,
      * repeating end nodes (or passing an intermediate node as well) adds no extra nodes.
    """
    input_node1 = EONode(InputTask())
    # NOTE(review): input_node2 and divide_node deliberately share the name "some name",
    # presumably to check that nodes are distinguished by identity rather than by name.
    input_node2 = EONode(InputTask(), name="some name")
    divide_node = EONode(DivideTask(), inputs=(input_node1, input_node2), name="some name")
    output_node = EONode(OutputTask(name="out"), inputs=[divide_node])

    regular_workflow = EOWorkflow([input_node1, input_node2, divide_node, output_node])
    endnode_workflow = EOWorkflow.from_endnodes(output_node)

    assert isinstance(endnode_workflow, EOWorkflow)
    assert set(endnode_workflow.get_nodes()) == set(regular_workflow.get_nodes()), "Nodes are different"

    def make_arguments(k):
        """Build the per-node execution arguments for input value index ``k``."""
        return {input_node1: {"val": k**3}, input_node2: {"val": k**2}}

    ks = range(2, 100)
    with concurrent.futures.ProcessPoolExecutor(max_workers=5) as executor:
        regular_results = [executor.submit(regular_workflow.execute, make_arguments(k)) for k in ks]
        endnode_results = [executor.submit(endnode_workflow.execute, make_arguments(k)) for k in ks]
    # Leaving the ``with`` block calls ``executor.shutdown(wait=True)``, so every future is
    # finished here; an explicit ``shutdown()`` call inside the block would be redundant.

    # Compare pairwise so a failure reports which input diverged.
    for k, regular, endnode in zip(ks, regular_results, endnode_results):
        assert (
            regular.result().outputs["out"] == endnode.result().outputs["out"]
        ), f"Outputs of regular and endnode workflows differ for k={k}"

    endnode_duplicates = EOWorkflow.from_endnodes(output_node, output_node, divide_node)
    assert set(endnode_duplicates.get_nodes()) == set(
        regular_workflow.get_nodes()
    ), "Fails if endnodes are repeated"