Пример #1
0
def _assert_decorated_workflow_structure(w):
    """Assert the interface, node bindings, outputs and metadata of the test workflow.

    Used by ``test_workflow_decorator`` so that exactly the same checks run on
    the SDK workflow as built and on the copy reconstructed from flyte_idl.

    :param w: the workflow built from ``my_workflow`` in
        ``test_workflow_decorator`` -- either the SDK object returned by
        ``workflow.build_sdk_workflow_from_metaclass`` or the
        ``WorkflowTemplate`` obtained via ``from_flyte_idl``.
    """
    int_literal_type = primitives.Integer.to_flyte_literal_type()

    # Workflow interface: both declared inputs are integers.
    assert w.interface.inputs['input_1'].type == int_literal_type
    assert w.interface.inputs['input_2'].type == int_literal_type

    # n1: input 'a' is bound to the workflow input 'input_1'.
    n1_binding = w.nodes[0].inputs[0]
    assert n1_binding.var == 'a'
    assert n1_binding.binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert n1_binding.binding.promise.var == 'input_1'

    # n2: bound to the workflow input 'input_2'.
    n2_binding = w.nodes[1].inputs[0]
    assert n2_binding.binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert n2_binding.binding.promise.var == 'input_2'

    # n3: bound to the literal 100.
    assert w.nodes[2].inputs[0].binding.scalar.primitive.integer == 100

    # n4: input 'a' is bound to an output of n1.
    n4_binding = w.nodes[3].inputs[0]
    assert n4_binding.var == 'a'
    assert n4_binding.binding.promise.node_id == 'n1'

    # n5: a list binding mixing workflow inputs, a node output and a literal,
    # in the order given in my_workflow: [input_1, input_2, n3.outputs.b, 100].
    assert w.nodes[4].inputs[0].var == 'a'
    items = w.nodes[4].inputs[0].binding.collection.bindings
    assert items[0].promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert items[0].promise.var == 'input_1'
    assert items[1].promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert items[1].promise.var == 'input_2'
    assert items[2].promise.node_id == 'n3'
    assert items[2].promise.var == 'b'
    assert items[3].scalar.primitive.integer == 100

    # n6: input 'a' is bound to output 'b' of n5.
    n6_binding = w.nodes[5].inputs[0]
    assert n6_binding.var == 'a'
    assert n6_binding.binding.promise.node_id == 'n5'
    assert n6_binding.binding.promise.var == 'b'

    # Single workflow output 'a', wired to n1's output 'b'.
    assert len(w.outputs) == 1
    assert w.outputs[0].var == 'a'
    assert w.outputs[0].binding.promise.var == 'b'
    assert w.outputs[0].binding.promise.node_id == 'n1'

    assert w.metadata.on_failure == _workflow_models.WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE


def test_workflow_decorator():
    """Build a workflow from a decorated metaclass and verify its structure.

    The same structural checks are applied to the SDK workflow returned by
    ``workflow.build_sdk_workflow_from_metaclass`` and to the
    ``WorkflowTemplate`` produced by serializing it to flyte_idl and back.
    """
    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task
    def my_task(wf_params, a, b):
        b.set(a + 1)

    # Fixed placeholder identifier for the task.
    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK,
                                         'project', 'domain', 'my_task',
                                         'version')

    @inputs(a=[primitives.Integer])
    @outputs(b=[primitives.Integer])
    @python_task
    def my_list_task(wf_params, a, b):
        b.set([v + 1 for v in a])

    # Fixed placeholder identifier for the list task.
    my_list_task._id = _identifier.Identifier(_identifier.ResourceType.TASK,
                                              'project', 'domain',
                                              'my_list_task', 'version')

    class my_workflow(object):
        input_1 = promise.Input('input_1', primitives.Integer)
        input_2 = promise.Input('input_2',
                                primitives.Integer,
                                default=5,
                                help='Not required.')
        n1 = my_task(a=input_1)
        n2 = my_task(a=input_2)
        n3 = my_task(a=100)
        n4 = my_task(a=n1.outputs.b)
        n5 = my_list_task(a=[input_1, input_2, n3.outputs.b, 100])
        n6 = my_list_task(a=n5.outputs.b)
        # Presumably declares an explicit (non-data) ordering of n6 after n1.
        n1 >> n6
        a = workflow.Output('a', n1.outputs.b, sdk_type=primitives.Integer)

    w = workflow.build_sdk_workflow_from_metaclass(
        my_workflow,
        on_failure=_workflow_models.WorkflowMetadata.OnFailurePolicy.
        FAIL_AFTER_EXECUTABLE_NODES_COMPLETE)

    # Check the SDK workflow exactly as built.
    _assert_decorated_workflow_structure(w)

    # Give the workflow an identifier, round-trip it through flyte_idl, and
    # check that the reconstructed template has the same structure.
    w._id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, 'fake',
                                   'faker', 'fakest', 'fakerest')
    w = _workflow_models.WorkflowTemplate.from_flyte_idl(w.to_flyte_idl())
    _assert_decorated_workflow_structure(w)
Пример #2
0
 def wrapper(metaclass):
     """Turn ``metaclass`` into an SDK workflow.

     Delegates to ``_common_workflow.build_sdk_workflow_from_metaclass``,
     forwarding ``cls`` captured from the enclosing scope (not visible here).
     """
     return _common_workflow.build_sdk_workflow_from_metaclass(
         metaclass,
         cls=cls,
     )
Пример #3
0
 def wrapper(metaclass):
     """Turn ``metaclass`` into an SDK workflow.

     Delegates to ``_common_workflow.build_sdk_workflow_from_metaclass``,
     forwarding ``cls`` and ``queuing_budget`` captured from the enclosing
     scope (not visible here).
     """
     return _common_workflow.build_sdk_workflow_from_metaclass(
         metaclass,
         cls=cls,
         queuing_budget=queuing_budget,
     )