Ejemplo n.º 1
0
def test_workflow_template():
    """Round-trip a one-node WorkflowTemplate through its Flyte IDL form.

    The template wraps a single task node and declares one integer input and
    two integer outputs. The object rebuilt from the IDL message must compare
    equal to the object the message was generated from.
    """
    integer = _types.LiteralType(_types.SimpleType.INTEGER)
    declared_inputs = {"a": _interface.Variable(integer, "description1")}
    declared_outputs = {
        "b": _interface.Variable(integer, "description2"),
        "c": _interface.Variable(integer, "description3"),
    }

    only_node = _workflow.Node(
        id="some:node:id",
        metadata=_get_sample_node_metadata(),
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=_workflow.TaskNode(reference_id=_generic_id),
    )
    template = _workflow.WorkflowTemplate(
        id=_generic_id,
        metadata=_workflow.WorkflowMetadata(),
        metadata_defaults=_workflow.WorkflowMetadataDefaults(),
        interface=_interface.TypedInterface(declared_inputs, declared_outputs),
        nodes=[only_node],
        outputs=[],
    )

    restored = _workflow.WorkflowTemplate.from_flyte_idl(
        template.to_flyte_idl()
    )
    assert restored == template
Ejemplo n.º 2
0
def test_task_node():
    """Check that a TaskNode keeps its reference id across an IDL round trip."""
    original = _workflow.TaskNode(reference_id=_generic_id)
    assert original.reference_id == _generic_id

    # Serialize to the IDL message and rebuild a model object from it.
    restored = _workflow.TaskNode.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
    assert restored.reference_id == _generic_id
Ejemplo n.º 3
0
def test_workflow_template_with_queuing_budget():
    """Round-trip a WorkflowTemplate whose metadata sets a queuing budget.

    Same shape as a plain one-node template, except that the workflow
    metadata is created with a ten second ``queuing_budget``.
    """
    budget = timedelta(seconds=10)
    integer = _types.LiteralType(_types.SimpleType.INTEGER)

    interface = _interface.TypedInterface(
        {"a": _interface.Variable(integer, "description1")},
        {
            "b": _interface.Variable(integer, "description2"),
            "c": _interface.Variable(integer, "description3"),
        },
    )
    task_backed_node = _workflow.Node(
        id="some:node:id",
        metadata=_get_sample_node_metadata(),
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=_workflow.TaskNode(reference_id=_generic_id),
    )
    template = _workflow.WorkflowTemplate(
        id=_generic_id,
        metadata=_workflow.WorkflowMetadata(queuing_budget=budget),
        metadata_defaults=_workflow.WorkflowMetadataDefaults(),
        interface=interface,
        nodes=[task_backed_node],
        outputs=[],
    )

    restored = _workflow.WorkflowTemplate.from_flyte_idl(
        template.to_flyte_idl()
    )
    assert restored == template
Ejemplo n.º 4
0
def test_node_task_with_inputs():
    """Check a task-backed Node with two literal input bindings.

    The node's accessors are verified both on the freshly built object and on
    the object rebuilt from its IDL message.
    """

    def integer_binding(var, value):
        # Bind ``var`` to a scalar integer literal.
        data = _literals.BindingData(
            scalar=_literals.Scalar(
                primitive=_literals.Primitive(integer=value)
            )
        )
        return _literals.Binding(var=var, binding=data)

    sample_metadata = _get_sample_node_metadata()
    task_target = _workflow.TaskNode(reference_id=_generic_id)
    first_binding = integer_binding("myvar", 5)
    second_binding = integer_binding("myothervar", 99)

    node = _workflow.Node(
        id="some:node:id",
        metadata=sample_metadata,
        inputs=[first_binding, second_binding],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_target,
    )
    assert node.target == task_target
    assert node.id == "some:node:id"
    assert node.metadata == sample_metadata
    assert len(node.inputs) == 2
    assert node.inputs[0] == first_binding

    restored = _workflow.Node.from_flyte_idl(node.to_flyte_idl())
    assert node == restored
    assert restored.target == task_target
    assert restored.id == "some:node:id"
    assert restored.metadata == sample_metadata
    assert len(restored.inputs) == 2
    assert restored.inputs[1] == second_binding
Ejemplo n.º 5
0
def test_branch_node():
    """Build a BranchNode with an if / else-if / else structure.

    The test contains no assertions: it only checks that the nested condition
    and branch model objects can be constructed without raising.
    """

    def integer_operand(value):
        # Operand wrapping an integer primitive.
        return _condition.Operand(primitive=_literals.Primitive(integer=value))

    def five_equals_two():
        # A fresh ``5 == 2`` comparison wrapped in a BooleanExpression.
        return _condition.BooleanExpression(
            comparison=_condition.ComparisonExpression(
                _condition.ComparisonExpression.Operator.EQ,
                integer_operand(5),
                integer_operand(2),
            )
        )

    def integer_binding(var, value):
        # Bind ``var`` to a scalar integer literal.
        data = _literals.BindingData(
            scalar=_literals.Scalar(
                primitive=_literals.Primitive(integer=value)
            )
        )
        return _literals.Binding(var=var, binding=data)

    sample_metadata = _get_sample_node_metadata()
    task_target = _workflow.TaskNode(reference_id=_generic_id)

    # The same task-backed node is used as the target of every branch.
    branch_target = _workflow.Node(
        id="some:node:id",
        metadata=sample_metadata,
        inputs=[
            integer_binding("myvar", 5),
            integer_binding("myothervar", 99),
        ],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_target,
    )

    if_block = _workflow.IfBlock(
        condition=five_equals_two(),
        then_node=branch_target,
    )
    both_comparisons = _condition.BooleanExpression(
        conjunction=_condition.ConjunctionExpression(
            _condition.ConjunctionExpression.LogicalOperator.AND,
            five_equals_two(),
            five_equals_two(),
        )
    )
    else_if_block = _workflow.IfBlock(
        condition=both_comparisons,
        then_node=branch_target,
    )

    bn = _workflow.BranchNode(
        _workflow.IfElseBlock(
            case=if_block,
            other=[else_if_block],
            else_node=branch_target,
        )
    )
Ejemplo n.º 6
0
def test_node_task_with_no_inputs():
    """Check a task-backed Node that has no input bindings.

    Target, id and metadata are verified on the freshly built node and again
    on the node rebuilt from its IDL message.
    """
    expected_id = "some:node:id"
    sample_metadata = _get_sample_node_metadata()
    task_target = _workflow.TaskNode(reference_id=_generic_id)

    node = _workflow.Node(
        id=expected_id,
        metadata=sample_metadata,
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_target,
    )
    assert node.target == task_target
    assert node.id == expected_id
    assert node.metadata == sample_metadata

    restored = _workflow.Node.from_flyte_idl(node.to_flyte_idl())
    assert node == restored
    assert restored.target == task_target
    assert restored.id == expected_id
    assert restored.metadata == sample_metadata
Ejemplo n.º 7
0
def test_future_task_document(task):
    """Check that a DynamicJobSpec survives an IDL round trip.

    The comparison is made on the protobuf text rendering of the IDL message
    produced before and after the round trip.

    :param task: Task model supplied by the caller; its ``id`` is referenced
        by the single node and the task itself is listed in the spec.
    """
    no_retries = _literals.RetryStrategy(0)
    node_metadata = _workflow.NodeMetadata(
        "node-name", _timedelta(minutes=10), no_retries
    )
    node = _workflow.Node(
        id="id",
        metadata=node_metadata,
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=_workflow.TaskNode(task.id),
    )
    # The result is discarded; the call only has to succeed.
    node.to_flyte_idl()

    spec = _dynamic_job.DynamicJobSpec(
        tasks=[task],
        nodes=[node],
        min_successes=1,
        outputs=[_literals.Binding("var", _literals.BindingData())],
        subworkflows=[],
    )

    text_before = text_format.MessageToString(spec.to_flyte_idl())
    rebuilt_spec = _dynamic_job.DynamicJobSpec.from_flyte_idl(
        spec.to_flyte_idl()
    )
    text_after = text_format.MessageToString(rebuilt_spec.to_flyte_idl())
    assert text_before == text_after
Ejemplo n.º 8
0
def test_workflow_closure():
    """Round-trip a WorkflowClosure holding one workflow and one task.

    A container task template and a workflow template whose only node
    references that task are bundled into a closure. The closure rebuilt from
    its IDL message must compare equal to the original.
    """
    integer = _types.LiteralType(_types.SimpleType.INTEGER)
    interface = _interface.TypedInterface(
        {"a": _interface.Variable(integer, "description1")},
        {
            "b": _interface.Variable(integer, "description2"),
            "c": _interface.Variable(integer, "description3"),
        },
    )

    # One literal input binding for the node, and two workflow output
    # bindings that are promises on outputs of 'my_node'.
    literal_five = _literals.Scalar(primitive=_literals.Primitive(integer=5))
    node_input = _literals.Binding(
        "a", _literals.BindingData(scalar=literal_five)
    )
    workflow_outputs = [
        _literals.Binding(
            "b",
            _literals.BindingData(
                promise=_types.OutputReference("my_node", output_name)
            ),
        )
        for output_name in ("b", "c")
    ]

    node_metadata = _workflow.NodeMetadata(
        name="node1",
        timeout=timedelta(seconds=10),
        retries=_literals.RetryStrategy(0),
    )

    runtime = _task.RuntimeMetadata(
        _task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"
    )
    task_metadata = _task.TaskMetadata(
        True,
        runtime,
        timedelta(days=1),
        _literals.RetryStrategy(3),
        "0.1.1b0",
        "This is deprecated!",
    )

    one_cpu = _task.Resources.ResourceEntry(
        _task.Resources.ResourceName.CPU, "1"
    )
    resources = _task.Resources(requests=[one_cpu], limits=[one_cpu])

    task_id = _identifier.Identifier(
        _identifier.ResourceType.TASK, "project", "domain", "name", "version"
    )
    custom = {"a": 1, "b": {"c": 2, "d": 3}}
    container = _task.Container(
        "my_image",
        ["this", "is", "a", "cmd"],
        ["this", "is", "an", "arg"],
        resources,
        {},
        {},
    )
    task = _task.TaskTemplate(
        task_id,
        "python",
        task_metadata,
        interface,
        custom,
        container=container,
    )

    node = _workflow.Node(
        id="my_node",
        metadata=node_metadata,
        inputs=[node_input],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=_workflow.TaskNode(task.id),
    )

    workflow_id = _identifier.Identifier(
        _identifier.ResourceType.WORKFLOW,
        "project",
        "domain",
        "name",
        "version",
    )
    template = _workflow.WorkflowTemplate(
        id=workflow_id,
        metadata=_workflow.WorkflowMetadata(),
        interface=interface,
        nodes=[node],
        outputs=workflow_outputs,
    )

    closure = _workflow_closure.WorkflowClosure(workflow=template, tasks=[task])
    assert len(closure.tasks) == 1

    restored = _workflow_closure.WorkflowClosure.from_flyte_idl(
        closure.to_flyte_idl()
    )
    assert closure == restored
Ejemplo n.º 9
0
def get_serializable_node(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: Node,
) -> workflow_model.Node:
    """Convert a workflow node into its ``workflow_model.Node`` form.

    Every upstream node other than the global input node is serialized first,
    and the ids of those results become the upstream ids of the returned
    model. The entity wrapped by the node then decides the node target: a
    task node for tasks, a workflow node for sub-workflows and launch plans,
    or a branch node for branches.

    :param entity_mapping: Mapping passed through to ``get_serializable``.
    :param settings: Serialization settings passed through to
        ``get_serializable``.
    :param entity: The node to convert.
    :return: The model node describing ``entity``.
    :raises Exception: If the node wraps no entity, wraps a reference entity
        with an unexpected resource type, or wraps an entity of a type that
        is not handled here.
    """
    if entity.flyte_entity is None:
        raise Exception(f"Node {entity.id} has no flyte entity")

    serialized_upstream = [
        get_serializable(entity_mapping, settings, upstream)
        for upstream in entity.upstream_nodes
        if upstream.id != _common_constants.GLOBAL_INPUT_NODE_ID
    ]

    def shared_fields():
        # Constructor arguments that are identical for every kind of target.
        return {
            "id": _dnsify(entity.id),
            "metadata": entity.metadata,
            "inputs": entity.bindings,
            "upstream_node_ids": [done.id for done in serialized_upstream],
            "output_aliases": [],
        }

    wrapped = entity.flyte_entity

    # Reference entities also inherit from the classes checked further down,
    # so they have to be handled before anything else.
    if isinstance(wrapped, ReferenceEntity):
        # Throw-away call: see the comment in compile_into_workflow in
        # python_function_task. It is only used to place a None value in the
        # entity_mapping.
        get_serializable(entity_mapping, settings, wrapped)
        reference_model = workflow_model.Node(**shared_fields())
        resource_type = wrapped.reference.resource_type
        if resource_type == _identifier_model.ResourceType.TASK:
            reference_model._task_node = workflow_model.TaskNode(
                reference_id=wrapped.id
            )
        elif resource_type == _identifier_model.ResourceType.WORKFLOW:
            reference_model._workflow_node = workflow_model.WorkflowNode(
                sub_workflow_ref=wrapped.id
            )
        elif resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
            reference_model._workflow_node = workflow_model.WorkflowNode(
                launchplan_ref=wrapped.id
            )
        else:
            raise Exception(f"Unexpected reference type {wrapped}")
        return reference_model

    if isinstance(wrapped, PythonTask):
        task_spec = get_serializable(entity_mapping, settings, wrapped)
        fields = shared_fields()
        task_target = workflow_model.TaskNode(
            reference_id=task_spec.template.id,
            overrides=TaskNodeOverrides(resources=entity._resources),
        )
        task_model = workflow_model.Node(task_node=task_target, **fields)
        if entity._aliases:
            task_model._output_aliases = entity._aliases
        return task_model

    if isinstance(wrapped, WorkflowBase):
        wf_spec = get_serializable(entity_mapping, settings, wrapped)
        fields = shared_fields()
        sub_workflow_target = workflow_model.WorkflowNode(
            sub_workflow_ref=wf_spec.template.id
        )
        return workflow_model.Node(workflow_node=sub_workflow_target, **fields)

    if isinstance(wrapped, BranchNode):
        fields = shared_fields()
        branch_target = get_serializable(entity_mapping, settings, wrapped)
        return workflow_model.Node(branch_node=branch_target, **fields)

    if isinstance(wrapped, LaunchPlan):
        lp_spec = get_serializable(entity_mapping, settings, wrapped)
        fields = shared_fields()
        launch_plan_target = workflow_model.WorkflowNode(
            launchplan_ref=lp_spec.id
        )
        return workflow_model.Node(workflow_node=launch_plan_target, **fields)

    raise Exception(f"Node contained non-serializable entity {entity._flyte_entity}")
Ejemplo n.º 10
0
def get_serializable_node(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: Node,
    options: Optional[Options] = None,
) -> workflow_model.Node:
    """Convert a workflow node into its ``workflow_model.Node`` form.

    Every upstream node other than the global input node is serialized first,
    and the ids of those results become the upstream ids of the returned
    model. The entity wrapped by the node then decides the node target: a
    task node for tasks, a workflow node for sub-workflows and launch plans,
    or a branch node for branches. For launch plans, bindings whose variable
    is one of the launch plan's fixed inputs are left out of the node inputs.

    :param entity_mapping: Mapping passed through to ``get_serializable``.
    :param settings: Serialization settings passed through to
        ``get_serializable``.
    :param entity: The node to convert.
    :param options: Optional options passed through to ``get_serializable``.
    :return: The model node describing ``entity``.
    :raises Exception: If the node wraps no entity, wraps a reference entity
        with an unexpected resource type, or wraps an entity of a type that
        is not handled here.
    """
    if entity.flyte_entity is None:
        raise Exception(f"Node {entity.id} has no flyte entity")

    # TODO: Try to move back up following config refactor - https://github.com/flyteorg/flyte/issues/2214
    from flytekit.remote.launch_plan import FlyteLaunchPlan
    from flytekit.remote.task import FlyteTask
    from flytekit.remote.workflow import FlyteWorkflow

    def serialize(target):
        # Forward to get_serializable with the arguments of this call.
        return get_serializable(entity_mapping, settings, target, options=options)

    serialized_upstream = [
        serialize(upstream)
        for upstream in entity.upstream_nodes
        if upstream.id != _common_constants.GLOBAL_INPUT_NODE_ID
    ]

    def shared_fields(inputs=None):
        # Constructor arguments that are identical for every kind of target.
        # ``inputs`` replaces the node's own bindings when it is given.
        return {
            "id": _dnsify(entity.id),
            "metadata": entity.metadata,
            "inputs": entity.bindings if inputs is None else inputs,
            "upstream_node_ids": [done.id for done in serialized_upstream],
            "output_aliases": [],
        }

    def non_fixed_bindings():
        # Node inputs must not contain data that is a fixed input of the
        # launch plan.
        return [
            binding
            for binding in entity.bindings
            if binding.var not in entity.flyte_entity.fixed_inputs.literals
        ]

    wrapped = entity.flyte_entity

    # Reference entities also inherit from the classes checked further down,
    # so they have to be handled before anything else.
    if isinstance(wrapped, ReferenceEntity):
        ref_template = serialize(wrapped).template
        reference_model = workflow_model.Node(**shared_fields())
        if ref_template.resource_type == _identifier_model.ResourceType.TASK:
            reference_model._task_node = workflow_model.TaskNode(
                reference_id=ref_template.id
            )
        elif ref_template.resource_type == _identifier_model.ResourceType.WORKFLOW:
            reference_model._workflow_node = workflow_model.WorkflowNode(
                sub_workflow_ref=ref_template.id
            )
        elif ref_template.resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
            reference_model._workflow_node = workflow_model.WorkflowNode(
                launchplan_ref=ref_template.id
            )
        else:
            raise Exception(
                f"Unexpected resource type for reference entity {entity.flyte_entity}: {ref_template.resource_type}"
            )
        return reference_model

    if isinstance(wrapped, PythonTask):
        task_spec = serialize(wrapped)
        fields = shared_fields()
        task_target = workflow_model.TaskNode(
            reference_id=task_spec.template.id,
            overrides=TaskNodeOverrides(resources=entity._resources),
        )
        task_model = workflow_model.Node(task_node=task_target, **fields)
        if entity._aliases:
            task_model._output_aliases = entity._aliases
        return task_model

    if isinstance(wrapped, WorkflowBase):
        wf_spec = serialize(wrapped)
        fields = shared_fields()
        sub_workflow_target = workflow_model.WorkflowNode(
            sub_workflow_ref=wf_spec.template.id
        )
        return workflow_model.Node(workflow_node=sub_workflow_target, **fields)

    if isinstance(wrapped, BranchNode):
        fields = shared_fields()
        branch_target = serialize(wrapped)
        return workflow_model.Node(branch_node=branch_target, **fields)

    if isinstance(wrapped, LaunchPlan):
        lp_spec = serialize(wrapped)
        fields = shared_fields(inputs=non_fixed_bindings())
        launch_plan_target = workflow_model.WorkflowNode(
            launchplan_ref=lp_spec.id
        )
        return workflow_model.Node(workflow_node=launch_plan_target, **fields)

    if isinstance(wrapped, FlyteTask):
        # Recursive call doesn't do anything except put the entity on the map.
        serialize(wrapped)
        fields = shared_fields()
        remote_task_target = workflow_model.TaskNode(
            reference_id=entity.flyte_entity.id,
            overrides=TaskNodeOverrides(resources=entity._resources),
        )
        return workflow_model.Node(task_node=remote_task_target, **fields)

    if isinstance(wrapped, FlyteWorkflow):
        wf_template = serialize(wrapped)
        # Each sub-workflow is serialized too; the results are not used here.
        for sub_wf in entity.flyte_entity.sub_workflows.values():
            serialize(sub_wf)
        fields = shared_fields()
        remote_workflow_target = workflow_model.WorkflowNode(
            sub_workflow_ref=wf_template.id
        )
        return workflow_model.Node(workflow_node=remote_workflow_target, **fields)

    if isinstance(wrapped, FlyteLaunchPlan):
        # Recursive call doesn't do anything except put the entity on the map.
        serialize(wrapped)
        fields = shared_fields(inputs=non_fixed_bindings())
        remote_launch_plan_target = workflow_model.WorkflowNode(
            launchplan_ref=entity.flyte_entity.id
        )
        return workflow_model.Node(
            workflow_node=remote_launch_plan_target, **fields
        )

    raise Exception(
        f"Node contained non-serializable entity {entity._flyte_entity}")