Пример #1
0
def test_workflow_node_lp():
    """A launch-plan WorkflowNode exposes its id and survives a Flyte IDL round trip."""
    node = _workflow.WorkflowNode(launchplan_ref=_generic_id)
    # Both the specific accessor and the generic ``reference`` must yield the id.
    for accessor in ("launchplan_ref", "reference"):
        assert getattr(node, accessor) == _generic_id

    round_tripped = _workflow.WorkflowNode.from_flyte_idl(node.to_flyte_idl())
    assert node == round_tripped
    for accessor in ("reference", "launchplan_ref"):
        assert getattr(round_tripped, accessor) == _generic_id
Пример #2
0
def test_workflow_node_sw():
    """A sub-workflow WorkflowNode exposes its id and survives a Flyte IDL round trip."""
    node = _workflow.WorkflowNode(sub_workflow_ref=_generic_id)
    # Both the specific accessor and the generic ``reference`` must yield the id.
    for accessor in ("sub_workflow_ref", "reference"):
        assert getattr(node, accessor) == _generic_id

    round_tripped = _workflow.WorkflowNode.from_flyte_idl(node.to_flyte_idl())
    assert node == round_tripped
    for accessor in ("reference", "sub_workflow_ref"):
        assert getattr(round_tripped, accessor) == _generic_id
Пример #3
0
def get_serializable_node(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: Node,
) -> workflow_model.Node:
    """Serialize an SDK ``Node`` into a ``workflow_model.Node``.

    The node's upstream nodes and its ``flyte_entity`` are serialized through
    ``get_serializable``. The returned node model then targets that entity in one
    of three ways, depending on the entity type:

    - a task node;
    - a workflow node (sub-workflow or launch plan);
    - a branch node.

    :param entity_mapping: Mapping passed through to ``get_serializable``, which
        records serialized entities in it.
    :param settings: Serialization settings, passed through to ``get_serializable``.
    :param entity: The SDK node to serialize.
    :return: The serialized node model.
    :raises Exception: If the node has no flyte entity, if a reference entity has an
        unexpected resource type, or if the entity type is not serializable.
    """
    if entity.flyte_entity is None:
        raise Exception(f"Node {entity.id} has no flyte entity")

    flyte_entity = entity.flyte_entity

    # The global input node is excluded from the upstream dependencies.
    upstream_node_ids = [
        get_serializable(entity_mapping, settings, n).id
        for n in entity.upstream_nodes
        if n.id != _common_constants.GLOBAL_INPUT_NODE_ID
    ]

    def _make_node(**target):
        # Every node kind shares these fields. ``target`` supplies the
        # kind-specific field (task_node / workflow_node / branch_node), if any.
        return workflow_model.Node(
            id=_dnsify(entity.id),
            metadata=entity.metadata,
            inputs=entity.bindings,
            upstream_node_ids=upstream_node_ids,
            output_aliases=[],
            **target,
        )

    # Reference entities also inherit from the classes checked below, so address them first.
    if isinstance(flyte_entity, ReferenceEntity):
        # This is a throw away call.
        # See the comment in compile_into_workflow in python_function_task. This is just used to place a None value
        # in the entity_mapping.
        get_serializable(entity_mapping, settings, flyte_entity)
        node_model = _make_node()
        resource_type = flyte_entity.reference.resource_type
        if resource_type == _identifier_model.ResourceType.TASK:
            node_model._task_node = workflow_model.TaskNode(reference_id=flyte_entity.id)
        elif resource_type == _identifier_model.ResourceType.WORKFLOW:
            node_model._workflow_node = workflow_model.WorkflowNode(sub_workflow_ref=flyte_entity.id)
        elif resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
            node_model._workflow_node = workflow_model.WorkflowNode(launchplan_ref=flyte_entity.id)
        else:
            raise Exception(f"Unexpected reference type {flyte_entity}")
        return node_model

    if isinstance(flyte_entity, PythonTask):
        task_spec = get_serializable(entity_mapping, settings, flyte_entity)
        node_model = _make_node(
            task_node=workflow_model.TaskNode(
                reference_id=task_spec.template.id,
                overrides=TaskNodeOverrides(resources=entity._resources),
            ),
        )
        # Replace the default empty aliases only when the node declares some.
        if entity._aliases:
            node_model._output_aliases = entity._aliases
        return node_model

    if isinstance(flyte_entity, WorkflowBase):
        wf_spec = get_serializable(entity_mapping, settings, flyte_entity)
        return _make_node(workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_spec.template.id))

    if isinstance(flyte_entity, BranchNode):
        return _make_node(branch_node=get_serializable(entity_mapping, settings, flyte_entity))

    if isinstance(flyte_entity, LaunchPlan):
        lp_spec = get_serializable(entity_mapping, settings, flyte_entity)
        return _make_node(workflow_node=workflow_model.WorkflowNode(launchplan_ref=lp_spec.id))

    raise Exception(f"Node contained non-serializable entity {flyte_entity}")
Пример #4
0
def get_serializable_node(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: Node,
    options: Optional[Options] = None,
) -> workflow_model.Node:
    """Serialize an SDK ``Node`` into a ``workflow_model.Node``.

    The node's upstream nodes and its ``flyte_entity`` are serialized through
    ``get_serializable``. The returned node model then targets that entity in one
    of three ways, depending on the entity type:

    - a task node;
    - a workflow node (sub-workflow or launch plan);
    - a branch node.

    Both local SDK entities and remote (``Flyte*``) entities are handled. For
    launch plans, bindings for the launch plan's fixed inputs are dropped from
    the node's inputs.

    :param entity_mapping: Mapping passed through to ``get_serializable``.
    :param settings: Serialization settings, passed through to ``get_serializable``.
    :param entity: The SDK node to serialize.
    :param options: Optional options, passed through to every ``get_serializable`` call.
    :return: The serialized node model.
    :raises Exception: If the node has no flyte entity, if a reference entity has an
        unexpected resource type, or if the entity type is not serializable.
    """
    if entity.flyte_entity is None:
        raise Exception(f"Node {entity.id} has no flyte entity")

    # TODO: Try to move back up following config refactor - https://github.com/flyteorg/flyte/issues/2214
    from flytekit.remote.launch_plan import FlyteLaunchPlan
    from flytekit.remote.task import FlyteTask
    from flytekit.remote.workflow import FlyteWorkflow

    flyte_entity = entity.flyte_entity

    # The global input node is excluded from the upstream dependencies.
    upstream_node_ids = [
        get_serializable(entity_mapping, settings, n, options=options).id
        for n in entity.upstream_nodes
        if n.id != _common_constants.GLOBAL_INPUT_NODE_ID
    ]

    def _make_node(inputs=None, **target):
        # Every node kind shares these fields. ``inputs`` overrides the node's
        # bindings when given. ``target`` supplies the kind-specific field
        # (task_node / workflow_node / branch_node), if any.
        return workflow_model.Node(
            id=_dnsify(entity.id),
            metadata=entity.metadata,
            inputs=entity.bindings if inputs is None else inputs,
            upstream_node_ids=upstream_node_ids,
            output_aliases=[],
            **target,
        )

    def _bindings_without_fixed_inputs(fixed_inputs):
        # Node's inputs should not contain the data which is fixed input
        return [b for b in entity.bindings if b.var not in fixed_inputs.literals]

    # Reference entities also inherit from the classes checked below, so address them first.
    if isinstance(flyte_entity, ReferenceEntity):
        ref_template = get_serializable(entity_mapping, settings, flyte_entity, options=options).template
        node_model = _make_node()
        resource_type = ref_template.resource_type
        if resource_type == _identifier_model.ResourceType.TASK:
            node_model._task_node = workflow_model.TaskNode(reference_id=ref_template.id)
        elif resource_type == _identifier_model.ResourceType.WORKFLOW:
            node_model._workflow_node = workflow_model.WorkflowNode(sub_workflow_ref=ref_template.id)
        elif resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
            node_model._workflow_node = workflow_model.WorkflowNode(launchplan_ref=ref_template.id)
        else:
            raise Exception(
                f"Unexpected resource type for reference entity {entity.flyte_entity}: {ref_template.resource_type}"
            )
        return node_model

    if isinstance(flyte_entity, PythonTask):
        task_spec = get_serializable(entity_mapping, settings, flyte_entity, options=options)
        node_model = _make_node(
            task_node=workflow_model.TaskNode(
                reference_id=task_spec.template.id,
                overrides=TaskNodeOverrides(resources=entity._resources),
            ),
        )
        # Replace the default empty aliases only when the node declares some.
        if entity._aliases:
            node_model._output_aliases = entity._aliases
        return node_model

    if isinstance(flyte_entity, WorkflowBase):
        wf_spec = get_serializable(entity_mapping, settings, flyte_entity, options=options)
        return _make_node(workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_spec.template.id))

    if isinstance(flyte_entity, BranchNode):
        return _make_node(branch_node=get_serializable(entity_mapping, settings, flyte_entity, options=options))

    if isinstance(flyte_entity, LaunchPlan):
        lp_spec = get_serializable(entity_mapping, settings, flyte_entity, options=options)
        return _make_node(
            inputs=_bindings_without_fixed_inputs(flyte_entity.fixed_inputs),
            workflow_node=workflow_model.WorkflowNode(launchplan_ref=lp_spec.id),
        )

    if isinstance(flyte_entity, FlyteTask):
        # Recursive call doesn't do anything except put the entity on the map.
        get_serializable(entity_mapping, settings, flyte_entity, options=options)
        return _make_node(
            task_node=workflow_model.TaskNode(
                reference_id=flyte_entity.id,
                overrides=TaskNodeOverrides(resources=entity._resources),
            ),
        )

    if isinstance(flyte_entity, FlyteWorkflow):
        wf_template = get_serializable(entity_mapping, settings, flyte_entity, options=options)
        # Sub-workflows are serialized too, after the parent workflow.
        for sub_wf in flyte_entity.sub_workflows.values():
            get_serializable(entity_mapping, settings, sub_wf, options=options)
        return _make_node(workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_template.id))

    if isinstance(flyte_entity, FlyteLaunchPlan):
        # Recursive call doesn't do anything except put the entity on the map.
        get_serializable(entity_mapping, settings, flyte_entity, options=options)
        return _make_node(
            inputs=_bindings_without_fixed_inputs(flyte_entity.fixed_inputs),
            workflow_node=workflow_model.WorkflowNode(launchplan_ref=flyte_entity.id),
        )

    raise Exception(f"Node contained non-serializable entity {flyte_entity}")