示例#1
0
def test_workflow_template():
    """Check that a one-node WorkflowTemplate survives a Flyte IDL round trip.

    The template has a single task node, one integer input ("a") and two
    integer outputs ("b" and "c"); it is converted to IDL and back, and the
    result must compare equal to the original.
    """
    integer_type = _types.LiteralType(_types.SimpleType.INTEGER)
    declared_inputs = {"a": _interface.Variable(integer_type, "description1")}
    declared_outputs = {
        "b": _interface.Variable(integer_type, "description2"),
        "c": _interface.Variable(integer_type, "description3"),
    }

    # The only node in the workflow: a task node pointing at the generic id.
    single_node = _workflow.Node(
        id="some:node:id",
        metadata=_get_sample_node_metadata(),
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=_workflow.TaskNode(reference_id=_generic_id),
    )

    original = _workflow.WorkflowTemplate(
        id=_generic_id,
        metadata=_workflow.WorkflowMetadata(),
        metadata_defaults=_workflow.WorkflowMetadataDefaults(),
        interface=_interface.TypedInterface(declared_inputs, declared_outputs),
        nodes=[single_node],
        outputs=[],
    )

    round_tripped = _workflow.WorkflowTemplate.from_flyte_idl(original.to_flyte_idl())
    assert round_tripped == original
示例#2
0
def test_workflow_template_with_queuing_budget():
    """Check the IDL round trip of a WorkflowTemplate whose metadata sets a queuing budget.

    Same shape as the plain template test, except that the workflow metadata
    is built with ``queuing_budget=timedelta(seconds=10)``.
    """
    integer_type = _types.LiteralType(_types.SimpleType.INTEGER)
    declared_inputs = {"a": _interface.Variable(integer_type, "description1")}
    declared_outputs = {
        "b": _interface.Variable(integer_type, "description2"),
        "c": _interface.Variable(integer_type, "description3"),
    }

    budget = timedelta(seconds=10)
    metadata_with_budget = _workflow.WorkflowMetadata(queuing_budget=budget)

    # The only node in the workflow: a task node pointing at the generic id.
    single_node = _workflow.Node(
        id="some:node:id",
        metadata=_get_sample_node_metadata(),
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=_workflow.TaskNode(reference_id=_generic_id),
    )

    original = _workflow.WorkflowTemplate(
        id=_generic_id,
        metadata=metadata_with_budget,
        metadata_defaults=_workflow.WorkflowMetadataDefaults(),
        interface=_interface.TypedInterface(declared_inputs, declared_outputs),
        nodes=[single_node],
        outputs=[],
    )

    round_tripped = _workflow.WorkflowTemplate.from_flyte_idl(original.to_flyte_idl())
    assert round_tripped == original
示例#3
0
def test_workflow_closure():
    """Check that a WorkflowClosure (one workflow plus one task) survives a Flyte IDL round trip.

    The workflow and the task share one typed interface: an integer input "a"
    and integer outputs "b" and "c". The workflow contains a single task node,
    "my_node", whose input "a" is bound to the constant 5 and whose outputs
    "b" and "c" are exposed as the workflow outputs of the same names.
    """
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    typed_interface = _interface.TypedInterface(
        {'a': _interface.Variable(int_type, "description1")}, {
            'b': _interface.Variable(int_type, "description2"),
            'c': _interface.Variable(int_type, "description3")
        })

    # Node input binding: the literal integer 5 is fed into input "a".
    b0 = _literals.Binding(
        'a',
        _literals.BindingData(scalar=_literals.Scalar(
            primitive=_literals.Primitive(integer=5))))
    # Workflow output bindings: one per declared output, each promised from the
    # same-named output of "my_node". The second binding must target "c";
    # binding "b" twice would leave the declared output "c" unbound.
    b1 = _literals.Binding(
        'b',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'b')))
    b2 = _literals.Binding(
        'c',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'c')))

    node_metadata = _workflow.NodeMetadata(name='node1',
                                           timeout=timedelta(seconds=10),
                                           retries=_literals.RetryStrategy(0))

    task_metadata = _task.TaskMetadata(
        True,
        _task.RuntimeMetadata(_task.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                              "1.0.0", "python"), timedelta(days=1),
        _literals.RetryStrategy(3), "0.1.1b0", "This is deprecated!")

    # The same CPU entry is used for both requests and limits.
    cpu_resource = _task.Resources.ResourceEntry(
        _task.Resources.ResourceName.CPU, "1")
    resources = _task.Resources(requests=[cpu_resource], limits=[cpu_resource])

    task = _task.TaskTemplate(
        _identifier.Identifier(_identifier.ResourceType.TASK, "project",
                               "domain", "name", "version"),
        "python",
        task_metadata,
        typed_interface, {
            'a': 1,
            'b': {
                'c': 2,
                'd': 3
            }
        },
        container=_task.Container("my_image", ["this", "is", "a", "cmd"],
                                  ["this", "is", "an", "arg"], resources, {},
                                  {}))

    # The workflow's only node references the task by its identifier.
    task_node = _workflow.TaskNode(task.id)
    node = _workflow.Node(id='my_node',
                          metadata=node_metadata,
                          inputs=[b0],
                          upstream_node_ids=[],
                          output_aliases=[],
                          task_node=task_node)

    template = _workflow.WorkflowTemplate(
        id=_identifier.Identifier(_identifier.ResourceType.WORKFLOW, "project",
                                  "domain", "name", "version"),
        metadata=_workflow.WorkflowMetadata(),
        interface=typed_interface,
        nodes=[node],
        outputs=[b1, b2],
    )

    obj = _workflow_closure.WorkflowClosure(workflow=template, tasks=[task])
    assert len(obj.tasks) == 1

    obj2 = _workflow_closure.WorkflowClosure.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
示例#4
0
def get_serializable_workflow(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: WorkflowBase,
) -> admin_workflow_models.WorkflowSpec:
    """Serialize a workflow entity into an admin ``WorkflowSpec``.

    Every node of ``entity`` except the global input node is serialized via
    ``get_serializable``. Sub-workflows are collected from two places: nodes
    whose ``flyte_entity`` is itself a workflow, and the leaf nodes (then /
    elif / else) of branch nodes whose ``flyte_entity`` is a workflow. For each
    one, both its template and its own sub-workflows are gathered.

    :param entity_mapping: Mapping passed through to ``get_serializable``.
    :param settings: Supplies project, domain and version for the workflow id.
    :param entity: The workflow to serialize.
    :return: A ``WorkflowSpec`` holding the workflow template and the
        de-duplicated sub-workflow templates, in first-seen order.
    :raises Exception: If a node wraps a workflow that is also a
        ``ReferenceEntity``, or if a node's sub-workflow does not serialize to
        a ``WorkflowSpec``.
    """
    # Get node models
    upstream_node_models = [
        get_serializable(entity_mapping, settings, n)
        for n in entity.nodes
        if n.id != _common_constants.GLOBAL_INPUT_NODE_ID
    ]

    sub_wfs = []
    for n in entity.nodes:
        if isinstance(n.flyte_entity, WorkflowBase):
            if isinstance(n.flyte_entity, ReferenceEntity):
                raise Exception(
                    f"Sorry, reference subworkflows do not work right now, please use the launch plan instead for the "
                    f"subworkflow you're trying to invoke. Node: {n}"
                )
            sub_wf_spec = get_serializable(entity_mapping, settings, n.flyte_entity)
            if not isinstance(sub_wf_spec, admin_workflow_models.WorkflowSpec):
                raise Exception(
                    f"Serialized form of a workflow should be an admin.WorkflowSpec but {type(sub_wf_spec)} found instead"
                )
            sub_wfs.append(sub_wf_spec.template)
            sub_wfs.extend(sub_wf_spec.sub_workflows)

        if isinstance(n.flyte_entity, BranchNode):
            if_else: workflow_model.IfElseBlock = n.flyte_entity._ifelse_block
            # See comment in get_serializable_branch_node also. Again this is a List[Node] even though it's supposed
            # to be a List[workflow_model.Node]
            # Materialized as a real list (dropping falsy entries such as a
            # missing else branch) so the value matches its annotation.
            leaf_nodes: List[Node] = [  # noqa
                leaf
                for leaf in [
                    if_else.case.then_node,
                    *([] if if_else.other is None else [x.then_node for x in if_else.other]),
                    if_else.else_node,
                ]
                if leaf
            ]
            for leaf_node in leaf_nodes:
                if isinstance(leaf_node.flyte_entity, WorkflowBase):
                    sub_wf_spec = get_serializable(entity_mapping, settings, leaf_node.flyte_entity)
                    sub_wfs.append(sub_wf_spec.template)
                    sub_wfs.extend(sub_wf_spec.sub_workflows)

    wf_id = _identifier_model.Identifier(
        resource_type=_identifier_model.ResourceType.WORKFLOW,
        project=settings.project,
        domain=settings.domain,
        name=entity.name,
        version=settings.version,
    )
    wf_t = workflow_model.WorkflowTemplate(
        id=wf_id,
        metadata=entity.workflow_metadata.to_flyte_model(),
        metadata_defaults=entity.workflow_metadata_defaults.to_flyte_model(),
        interface=entity.interface,
        nodes=upstream_node_models,
        outputs=entity.output_bindings,
    )

    # De-duplicate while keeping first-seen order. A plain set() yields the
    # templates in hash order, which makes the serialized spec's sub-workflow
    # order unstable; dict keys use the same hash/equality but keep insertion
    # order.
    unique_sub_wfs = list(dict.fromkeys(sub_wfs))
    return admin_workflow_models.WorkflowSpec(template=wf_t, sub_workflows=unique_sub_wfs)
示例#5
0
def get_serializable_workflow(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: WorkflowBase,
    options: Optional[Options] = None,
) -> admin_workflow_models.WorkflowSpec:
    """Serialize a workflow entity into an admin ``WorkflowSpec``.

    All nodes except the global input node are serialized through
    ``get_serializable``. Sub-workflows are gathered from nodes (and from the
    then / elif / else leaves of branch nodes) whose ``flyte_entity`` is either
    a locally defined workflow or a ``FlyteWorkflow``. The gathered
    sub-workflows are de-duplicated and sorted by ``short_string()``.

    :param entity_mapping: Mapping passed through to ``get_serializable``.
    :param settings: Supplies project, domain and version for the workflow id.
    :param entity: The workflow to serialize.
    :param options: Optional settings forwarded to every ``get_serializable`` call.
    :return: The ``WorkflowSpec`` for ``entity``.
    :raises Exception: If a node wraps a ``ReferenceWorkflow``.
    :raises TypeError: If a node's sub-workflow does not serialize to a
        ``WorkflowSpec``.
    """
    # TODO: Try to move up following config refactor - https://github.com/flyteorg/flyte/issues/2214
    from flytekit.remote.workflow import FlyteWorkflow

    # Node models for everything but the global input node.
    node_models = []
    for node in entity.nodes:
        if node.id == _common_constants.GLOBAL_INPUT_NODE_ID:
            continue
        node_models.append(get_serializable(entity_mapping, settings, node, options))

    collected = []
    for node in entity.nodes:
        if isinstance(node.flyte_entity, WorkflowBase):
            # We are currently not supporting reference workflows since these will
            # require a network call to flyteadmin to populate the WorkflowTemplate
            # object
            if isinstance(node.flyte_entity, ReferenceWorkflow):
                raise Exception(
                    "Reference sub-workflows are currently unsupported. Use reference launch plans instead."
                )
            spec = get_serializable(entity_mapping, settings, node.flyte_entity, options)
            if not isinstance(spec, admin_workflow_models.WorkflowSpec):
                raise TypeError(
                    f"Unexpected type for serialized form of workflow. Expected {admin_workflow_models.WorkflowSpec}, but got {type(spec)}"
                )
            collected.append(spec.template)
            collected.extend(spec.sub_workflows)

        if isinstance(node.flyte_entity, FlyteWorkflow):
            # The FlyteWorkflow object itself is collected, not the
            # return value of get_serializable.
            get_serializable(entity_mapping, settings, node.flyte_entity, options)
            collected.append(node.flyte_entity)
            collected.extend(node.flyte_entity.sub_workflows.values())

        if not isinstance(node.flyte_entity, BranchNode):
            continue

        branch: workflow_model.IfElseBlock = node.flyte_entity._ifelse_block
        # See comment in get_serializable_branch_node also. Again this is a List[Node] even though it's supposed
        # to be a List[workflow_model.Node]
        candidates: List[Node] = [branch.case.then_node]  # noqa
        if branch.other is not None:
            candidates.extend(alternative.then_node for alternative in branch.other)
        candidates.append(branch.else_node)

        for leaf in candidates:
            # Skip falsy entries, e.g. a missing else branch.
            if not leaf:
                continue
            if isinstance(leaf.flyte_entity, WorkflowBase):
                leaf_spec = get_serializable(entity_mapping, settings, leaf.flyte_entity, options)
                collected.append(leaf_spec.template)
                collected.extend(leaf_spec.sub_workflows)
            elif isinstance(leaf.flyte_entity, FlyteWorkflow):
                get_serializable(entity_mapping, settings, leaf.flyte_entity, options)
                collected.append(leaf.flyte_entity)
                collected.extend(leaf.flyte_entity.sub_workflows.values())

    workflow_id = _identifier_model.Identifier(
        resource_type=_identifier_model.ResourceType.WORKFLOW,
        project=settings.project,
        domain=settings.domain,
        name=entity.name,
        version=settings.version,
    )
    template = workflow_model.WorkflowTemplate(
        id=workflow_id,
        metadata=entity.workflow_metadata.to_flyte_model(),
        metadata_defaults=entity.workflow_metadata_defaults.to_flyte_model(),
        interface=entity.interface,
        nodes=node_models,
        outputs=entity.output_bindings,
    )

    # De-duplicate, then sort by short_string() for a stable order.
    distinct = set(collected)
    ordered = sorted(distinct, key=lambda wf: wf.short_string())
    return admin_workflow_models.WorkflowSpec(template=template, sub_workflows=ordered)