def test_workflow_node():
    """Calling an SdkWorkflow with no arguments is expected to raise NotImplementedError."""
    task_resource = _identifier.ResourceType.TASK

    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task()
    def my_task(wf_params, a, b):
        b.set(a + 1)

    my_task._id = _identifier.Identifier(task_resource, 'project', 'domain', 'my_task', 'version')

    @inputs(a=[primitives.Integer])
    @outputs(b=[primitives.Integer])
    @python_task
    def my_list_task(wf_params, a, b):
        b.set([v + 1 for v in a])

    my_list_task._id = _identifier.Identifier(task_resource, 'project', 'domain', 'my_list_task', 'version')

    # Workflow inputs: one mandatory, one with a default value.
    required = promise.Input('required', primitives.Integer)
    not_required = promise.Input('not_required', primitives.Integer, default=5, help='Not required.')

    n1 = my_task(a=required).assign_id_and_return('n1')
    n2 = my_task(a=not_required).assign_id_and_return('n2')
    n3 = my_task(a=100).assign_id_and_return('n3')
    n4 = my_task(a=n1.outputs.b).assign_id_and_return('n4')
    n5 = my_list_task(a=[required, not_required, n3.outputs.b, 100]).assign_id_and_return('n5')
    n6 = my_list_task(a=n5.outputs.b)  # no explicit id is assigned to this node

    # A nested list-of-lists output plus a plain scalar output.
    nested_out = workflow.Output(
        'nested_out',
        [n5.outputs.b, n6.outputs.b, [n1.outputs.b, n2.outputs.b]],
        sdk_type=[[primitives.Integer]],
    )
    scalar_out = workflow.Output('scalar_out', n1.outputs.b, sdk_type=primitives.Integer)

    sdk_wf = workflow.SdkWorkflow(
        inputs=[required, not_required],
        outputs=[nested_out, scalar_out],
        nodes=[n1, n2, n3, n4, n5, n6],
    )

    with _pytest.raises(NotImplementedError):
        sdk_wf()
    # TODO: Uncomment when sub-workflows are supported. """
def test_workflow_serialization():
    """Serializing a six-node SdkWorkflow yields a WorkflowSpec with 6 nodes, 2 inputs and 2 outputs."""
    task_resource = _identifier.ResourceType.TASK

    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task()
    def my_task(wf_params, a, b):
        b.set(a + 1)

    my_task._id = _identifier.Identifier(task_resource, 'project', 'domain', 'my_task', 'version')

    @inputs(a=[primitives.Integer])
    @outputs(b=[primitives.Integer])
    @python_task
    def my_list_task(wf_params, a, b):
        b.set([v + 1 for v in a])

    my_list_task._id = _identifier.Identifier(task_resource, 'project', 'domain', 'my_list_task', 'version')

    required = promise.Input('required', primitives.Integer)
    not_required = promise.Input('not_required', primitives.Integer, default=5, help='Not required.')

    n1 = my_task(a=required).assign_id_and_return('n1')
    n2 = my_task(a=not_required).assign_id_and_return('n2')
    n3 = my_task(a=100).assign_id_and_return('n3')
    n4 = my_task(a=n1.outputs.b).assign_id_and_return('n4')
    n5 = my_list_task(a=[required, not_required, n3.outputs.b, 100]).assign_id_and_return('n5')
    n6 = my_list_task(a=n5.outputs.b)

    nested_out = workflow.Output(
        'nested_out',
        [n5.outputs.b, n6.outputs.b, [n1.outputs.b, n2.outputs.b]],
        sdk_type=[[primitives.Integer]],
    )
    scalar_out = workflow.Output('scalar_out', n1.outputs.b, sdk_type=primitives.Integer)

    sdk_wf = workflow.SdkWorkflow(
        inputs=[required, not_required],
        outputs=[nested_out, scalar_out],
        nodes=[n1, n2, n3, n4, n5, n6],
    )

    spec = sdk_wf.serialize()
    assert isinstance(spec, _workflow_pb2.WorkflowSpec)

    template = spec.template
    assert len(template.nodes) == 6
    interface = template.interface
    assert len(interface.inputs.variables.keys()) == 2
    assert len(interface.outputs.variables.keys()) == 2
def _create_workflow(self, name, tasks):
    """
    Create workflow for the pipeline.

    Converts the given operators into tasks and nodes via ``self._create_tasks``, assembles the nodes into an
    ``SdkWorkflow`` with no inputs or outputs, copies each operator's upstream task ids onto the node with the
    matching id, and returns a ``WorkflowClosure`` bundling the workflow with the tasks its task nodes reference.

    :param str name: Pipeline name. Not used by this method.
    :param list[airflow.models.BaseOperator] tasks: Operators to convert into workflow tasks/nodes.
    :rtype: WorkflowClosure
    """
    # Upstream dependencies keyed by the operator's task_id. Nodes are looked up below by their (string) id, so
    # keying by the operator object itself could never match and the dependencies were silently dropped.
    # NOTE(review): assumes _create_tasks gives each node the same id as its operator's task_id -- confirm.
    # upstream_task_ids may be a set; sorting yields a list with a deterministic order.
    deps = {t.task_id: sorted(t.upstream_task_ids) for t in tasks}

    sdk_tasks, nodes = self._create_tasks(tasks)

    # Create map to look up tasks by their fully-qualified name. This map goes from something like
    # app.workflows.MyWorkflow.task_one to the task_one SdkRunnable task object
    tmap = {}
    for t in sdk_tasks:
        # This mocks an Admin registration, setting the reference id to the name of the task itself
        t.target._reference_id = t.id
        tmap[t.id] = t

    w = workflow_common.SdkWorkflow(inputs=[], outputs=[], nodes=nodes)

    task_templates = []
    for n in w.nodes:
        if n.task_node is not None:
            task_templates.append(tmap[n.task_node.reference_id])
        # TODO: sub_dags should be converted to sub-workflows
        # elif n.workflow_node is not None:
        #     n.workflow_node._launchplan_ref = n.workflow_node.id
        #     n.workflow_node._sub_workflow_ref = n.workflow_node.id
        if n.id in deps:
            n._upstream_node_ids = deps[n.id]

    # Create the WorkflowClosure object that wraps both the workflow and its tasks
    return WorkflowClosure(workflow=w, tasks=task_templates)
def test_workflow_node():
    """
    Calling an SdkWorkflow validates its arguments and produces a node.

    The node has bound inputs (defaults applied or overridden), references the workflow, and exposes
    typed output promises.
    """
    task_resource = _identifier.ResourceType.TASK

    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task()
    def my_task(wf_params, a, b):
        b.set(a + 1)

    my_task._id = _identifier.Identifier(task_resource, 'project', 'domain', 'my_task', 'version')

    @inputs(a=[primitives.Integer])
    @outputs(b=[primitives.Integer])
    @python_task
    def my_list_task(wf_params, a, b):
        b.set([v + 1 for v in a])

    my_list_task._id = _identifier.Identifier(task_resource, 'project', 'domain', 'my_list_task', 'version')

    required = promise.Input('required', primitives.Integer)
    not_required = promise.Input('not_required', primitives.Integer, default=5, help='Not required.')

    n1 = my_task(a=required).assign_id_and_return('n1')
    n2 = my_task(a=not_required).assign_id_and_return('n2')
    n3 = my_task(a=100).assign_id_and_return('n3')
    n4 = my_task(a=n1.outputs.b).assign_id_and_return('n4')
    n5 = my_list_task(a=[required, not_required, n3.outputs.b, 100]).assign_id_and_return('n5')
    n6 = my_list_task(a=n5.outputs.b)

    nested_out = workflow.Output(
        'nested_out',
        [n5.outputs.b, n6.outputs.b, [n1.outputs.b, n2.outputs.b]],
        sdk_type=[[primitives.Integer]],
    )
    scalar_out = workflow.Output('scalar_out', n1.outputs.b, sdk_type=primitives.Integer)

    sdk_wf = workflow.SdkWorkflow(
        inputs=[required, not_required],
        outputs=[nested_out, scalar_out],
        nodes=[n1, n2, n3, n4, n5, n6],
    )

    # Invalid invocations, in order: missing required input, positional args, wrong type, unknown arg name.
    bad_calls = [
        (_user_exceptions.FlyteAssertion, (), {}),
        (_user_exceptions.FlyteAssertion, (1, 2), {}),
        (_user_exceptions.FlyteTypeException, (), {'required': 'abc', 'not_required': 1}),
        (_user_exceptions.FlyteAssertion, (), {'required': 1, 'bad_arg': 1}),
    ]
    for expected_exc, call_args, call_kwargs in bad_calls:
        with _pytest.raises(expected_exc):
            sdk_wf(*call_args, **call_kwargs)

    def check_bound_inputs(node, not_required_value):
        # Inputs are expected in the order not_required, required; required is always bound to 10 here.
        assert node.inputs[0].var == 'not_required'
        assert node.inputs[0].binding.scalar.primitive.integer == not_required_value
        assert node.inputs[1].var == 'required'
        assert node.inputs[1].binding.scalar.primitive.integer == 10

    # The default for not_required is applied...
    node = sdk_wf(required=10)
    check_bound_inputs(node, 5)

    # ...and can be overridden.
    node = sdk_wf(required=10, not_required=50)
    check_bound_inputs(node, 50)

    # The node refers back to the workflow it was created from.
    sdk_wf._id = 'fake'
    assert node.workflow_node.sub_workflow_ref == 'fake'
    sdk_wf._id = None

    # Output promises carry the declared types, names and the node id; the id is given as 'node-id*'
    # and expected back as 'node-id'.
    node.assign_id_and_return('node-id*')  # dns'ified
    expected_outputs = [
        ('scalar_out', primitives.Integer.to_flyte_literal_type()),
        ('nested_out', containers.List(containers.List(primitives.Integer)).to_flyte_literal_type()),
    ]
    for out_name, literal_type in expected_outputs:
        out = node.outputs[out_name]
        assert out.sdk_type.to_flyte_literal_type() == literal_type
        assert out.var == out_name
        assert out.node_id == 'node-id'
def test_workflow():
    """
    Build a six-node SdkWorkflow and check its interface and node bindings.

    The same checks are repeated after a round-trip through ``to_flyte_idl`` /
    ``WorkflowTemplate.from_flyte_idl``, together with the list-node bindings and the workflow output.
    """
    task_resource = _identifier.ResourceType.TASK

    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task()
    def my_task(wf_params, a, b):
        b.set(a + 1)

    my_task._id = _identifier.Identifier(task_resource, 'project', 'domain', 'my_task', 'version')

    @inputs(a=[primitives.Integer])
    @outputs(b=[primitives.Integer])
    @python_task
    def my_list_task(wf_params, a, b):
        b.set([v + 1 for v in a])

    my_list_task._id = _identifier.Identifier(task_resource, 'project', 'domain', 'my_list_task', 'version')

    input_1 = promise.Input('input_1', primitives.Integer)
    input_2 = promise.Input('input_2', primitives.Integer, default=5, help='Not required.')

    n1 = my_task(a=input_1).assign_id_and_return('n1')
    n2 = my_task(a=input_2).assign_id_and_return('n2')
    n3 = my_task(a=100).assign_id_and_return('n3')
    n4 = my_task(a=n1.outputs.b).assign_id_and_return('n4')
    n5 = my_list_task(a=[input_1, input_2, n3.outputs.b, 100]).assign_id_and_return('n5')
    n6 = my_list_task(a=n5.outputs.b)
    n1 >> n6  # extra edge between n1 and n6 that is not implied by data flow

    sdk_wf = workflow.SdkWorkflow(
        inputs=[input_1, input_2],
        outputs=[workflow.Output('a', n1.outputs.b, sdk_type=primitives.Integer)],
        nodes=[n1, n2, n3, n4, n5, n6],
    )

    def check_shared_bindings(wf):
        # Assertions that must hold both for the SdkWorkflow and for its IDL round-tripped template.
        int_type = primitives.Integer.to_flyte_literal_type()
        for input_name in ('input_1', 'input_2'):
            assert wf.interface.inputs[input_name].type == int_type
        assert wf.nodes[0].inputs[0].var == 'a'
        for idx, input_name in enumerate(('input_1', 'input_2')):
            bound = wf.nodes[idx].inputs[0].binding.promise
            assert bound.node_id == constants.GLOBAL_INPUT_NODE_ID
            assert bound.var == input_name
        assert wf.nodes[2].inputs[0].binding.scalar.primitive.integer == 100
        assert wf.nodes[3].inputs[0].var == 'a'
        assert wf.nodes[3].inputs[0].binding.promise.node_id == n1.id

    check_shared_bindings(sdk_wf)

    # Give the workflow an identifier, then round-trip it through flyte IDL.
    sdk_wf._id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, 'fake', 'faker', 'fakest', 'fakerest')
    template = _workflow_models.WorkflowTemplate.from_flyte_idl(sdk_wf.to_flyte_idl())
    check_shared_bindings(template)

    # n5 binds a collection: two global inputs, one upstream promise and one literal.
    list_input = template.nodes[4].inputs[0]
    assert list_input.var == 'a'
    bindings = list_input.binding.collection.bindings
    for idx, input_name in enumerate(('input_1', 'input_2')):
        assert bindings[idx].promise.node_id == constants.GLOBAL_INPUT_NODE_ID
        assert bindings[idx].promise.var == input_name
    assert bindings[2].promise.node_id == n3.id
    assert bindings[2].promise.var == 'b'
    assert bindings[3].scalar.primitive.integer == 100

    # n6 consumes n5's output.
    last_input = template.nodes[5].inputs[0]
    assert last_input.var == 'a'
    assert last_input.binding.promise.node_id == n5.id
    assert last_input.binding.promise.var == 'b'

    # The single workflow output forwards n1's output.
    assert len(template.outputs) == 1
    only_output = template.outputs[0]
    assert only_output.var == 'a'
    assert only_output.binding.promise.var == 'b'
    assert only_output.binding.promise.node_id == 'n1'