def test_workflow_node_lp():
    """A WorkflowNode built from a launch plan reference survives an IDL round trip."""
    original = _workflow.WorkflowNode(launchplan_ref=_generic_id)
    assert original.launchplan_ref == _generic_id
    assert original.reference == _generic_id

    idl = original.to_flyte_idl()
    restored = _workflow.WorkflowNode.from_flyte_idl(idl)
    assert original == restored
    assert restored.reference == _generic_id
    assert restored.launchplan_ref == _generic_id
def test_workflow_node_sw():
    """A WorkflowNode built from a sub-workflow reference survives an IDL round trip."""
    original = _workflow.WorkflowNode(sub_workflow_ref=_generic_id)
    assert original.sub_workflow_ref == _generic_id
    assert original.reference == _generic_id

    idl = original.to_flyte_idl()
    restored = _workflow.WorkflowNode.from_flyte_idl(idl)
    assert original == restored
    assert restored.reference == _generic_id
    assert restored.sub_workflow_ref == _generic_id
def get_serializable_node(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: Node,
) -> workflow_model.Node:
    """
    Translate an SDK ``Node`` into a ``workflow_model.Node``.

    Upstream nodes (excluding the global input node) are serialized first. The node's ``flyte_entity``
    is then serialized through ``get_serializable`` and referenced from the returned model as a task
    node, workflow node or branch node, depending on its type.

    :param entity_mapping: mapping of already-serialized entities, forwarded to ``get_serializable``.
    :param settings: serialization settings, forwarded to ``get_serializable``.
    :param entity: the SDK node to translate.
    :return: the serialized node model.
    :raises Exception: if the node has no flyte entity, if a reference entity has an unexpected
        resource type, or if the flyte entity is of a type this function cannot serialize.
    """
    target = entity.flyte_entity
    if target is None:
        raise Exception(f"Node {entity.id} has no flyte entity")

    upstream_sdk_nodes = []
    for upstream in entity.upstream_nodes:
        if upstream.id == _common_constants.GLOBAL_INPUT_NODE_ID:
            continue  # the global input node is never listed as an upstream dependency
        upstream_sdk_nodes.append(get_serializable(entity_mapping, settings, upstream))

    def _make_node(**payload):
        # Fields shared by every node kind; ``payload`` carries the task/workflow/branch-specific part.
        return workflow_model.Node(
            id=_dnsify(entity.id),
            metadata=entity.metadata,
            inputs=entity.bindings,
            upstream_node_ids=[n.id for n in upstream_sdk_nodes],
            output_aliases=[],
            **payload,
        )

    # ReferenceEntity instances also match the isinstance checks further down, so handle them first.
    if isinstance(target, ReferenceEntity):
        # Throw-away call (see compile_into_workflow in python_function_task): it is only used to place
        # a None value in entity_mapping for this reference.
        get_serializable(entity_mapping, settings, target)
        node_model = _make_node()
        resource_type = target.reference.resource_type
        if resource_type == _identifier_model.ResourceType.TASK:
            node_model._task_node = workflow_model.TaskNode(reference_id=target.id)
        elif resource_type == _identifier_model.ResourceType.WORKFLOW:
            node_model._workflow_node = workflow_model.WorkflowNode(sub_workflow_ref=target.id)
        elif resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
            node_model._workflow_node = workflow_model.WorkflowNode(launchplan_ref=target.id)
        else:
            raise Exception(f"Unexpected reference type {target}")
        return node_model

    if isinstance(target, PythonTask):
        task_spec = get_serializable(entity_mapping, settings, target)
        task_node = workflow_model.TaskNode(
            reference_id=task_spec.template.id,
            overrides=TaskNodeOverrides(resources=entity._resources),
        )
        node_model = _make_node(task_node=task_node)
        # Output aliases are applied only when the node defines some.
        if entity._aliases:
            node_model._output_aliases = entity._aliases
        return node_model

    if isinstance(target, WorkflowBase):
        wf_spec = get_serializable(entity_mapping, settings, target)
        return _make_node(workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_spec.template.id))

    if isinstance(target, BranchNode):
        return _make_node(branch_node=get_serializable(entity_mapping, settings, target))

    if isinstance(target, LaunchPlan):
        lp_spec = get_serializable(entity_mapping, settings, target)
        return _make_node(workflow_node=workflow_model.WorkflowNode(launchplan_ref=lp_spec.id))

    raise Exception(f"Node contained non-serializable entity {entity._flyte_entity}")
def _bindings_without_fixed_inputs(bindings, fixed_inputs):
    """
    Return the bindings that a launch-plan node should keep as its inputs.

    A binding is dropped when its ``var`` is present in ``fixed_inputs.literals``, because the launch
    plan already fixes that value. The remaining bindings keep their original order.

    :param bindings: the node's bindings; each must expose ``var``.
    :param fixed_inputs: the launch plan's fixed inputs. Each binding's ``var`` is tested for
        membership in ``fixed_inputs.literals``. ``None`` is treated as "nothing fixed".
    :return: a new list containing the bindings to keep.
    """
    if fixed_inputs is None:
        # NOTE(review): not known to happen for the launch plans handled here; handled defensively.
        return list(bindings)
    fixed_literals = fixed_inputs.literals  # resolved once instead of once per binding
    return [binding for binding in bindings if binding.var not in fixed_literals]


def get_serializable_node(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: Node,
    options: Optional[Options] = None,
) -> workflow_model.Node:
    """
    Translate an SDK ``Node`` into a ``workflow_model.Node``.

    Upstream nodes (excluding the global input node) are serialized first. The node's ``flyte_entity``
    is then serialized through ``get_serializable`` and referenced from the returned model as a task
    node, workflow node or branch node, depending on its type. Both local and remote (``Flyte*``)
    entities are handled. For launch plans, bindings that the launch plan fixes itself are removed
    from the node's inputs.

    :param entity_mapping: mapping of already-serialized entities, forwarded to ``get_serializable``.
    :param settings: serialization settings, forwarded to ``get_serializable``.
    :param entity: the SDK node to translate.
    :param options: optional options, forwarded to every ``get_serializable`` call.
    :return: the serialized node model.
    :raises Exception: if the node has no flyte entity, if a reference entity has an unexpected
        resource type, or if the flyte entity is of a type this function cannot serialize.
    """
    target = entity.flyte_entity
    if target is None:
        raise Exception(f"Node {entity.id} has no flyte entity")

    # TODO: Try to move back up following config refactor - https://github.com/flyteorg/flyte/issues/2214
    from flytekit.remote.launch_plan import FlyteLaunchPlan
    from flytekit.remote.task import FlyteTask
    from flytekit.remote.workflow import FlyteWorkflow

    upstream_sdk_nodes = [
        get_serializable(entity_mapping, settings, n, options=options)
        for n in entity.upstream_nodes
        if n.id != _common_constants.GLOBAL_INPUT_NODE_ID
    ]

    def _make_node(inputs, **payload):
        # Fields shared by every node kind; ``payload`` carries the task/workflow/branch-specific part.
        return workflow_model.Node(
            id=_dnsify(entity.id),
            metadata=entity.metadata,
            inputs=inputs,
            upstream_node_ids=[n.id for n in upstream_sdk_nodes],
            output_aliases=[],
            **payload,
        )

    # Reference entities also inherit from the classes checked below, so address them first.
    if isinstance(target, ReferenceEntity):
        ref_template = get_serializable(entity_mapping, settings, target, options=options).template
        node_model = _make_node(entity.bindings)
        resource_type = ref_template.resource_type
        if resource_type == _identifier_model.ResourceType.TASK:
            node_model._task_node = workflow_model.TaskNode(reference_id=ref_template.id)
        elif resource_type == _identifier_model.ResourceType.WORKFLOW:
            node_model._workflow_node = workflow_model.WorkflowNode(sub_workflow_ref=ref_template.id)
        elif resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
            node_model._workflow_node = workflow_model.WorkflowNode(launchplan_ref=ref_template.id)
        else:
            raise Exception(f"Unexpected resource type for reference entity {target}: {resource_type}")
        return node_model

    if isinstance(target, PythonTask):
        task_spec = get_serializable(entity_mapping, settings, target, options=options)
        node_model = _make_node(
            entity.bindings,
            task_node=workflow_model.TaskNode(
                reference_id=task_spec.template.id,
                overrides=TaskNodeOverrides(resources=entity._resources),
            ),
        )
        # Output aliases are applied only when the node defines some.
        if entity._aliases:
            node_model._output_aliases = entity._aliases
        return node_model

    if isinstance(target, WorkflowBase):
        wf_spec = get_serializable(entity_mapping, settings, target, options=options)
        return _make_node(
            entity.bindings,
            workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_spec.template.id),
        )

    if isinstance(target, BranchNode):
        return _make_node(
            entity.bindings,
            branch_node=get_serializable(entity_mapping, settings, target, options=options),
        )

    if isinstance(target, LaunchPlan):
        lp_spec = get_serializable(entity_mapping, settings, target, options=options)
        # Node's inputs should not contain the data which is fixed input.
        return _make_node(
            _bindings_without_fixed_inputs(entity.bindings, target.fixed_inputs),
            workflow_node=workflow_model.WorkflowNode(launchplan_ref=lp_spec.id),
        )

    if isinstance(target, FlyteTask):
        # Recursive call doesn't do anything except put the entity on the map.
        get_serializable(entity_mapping, settings, target, options=options)
        return _make_node(
            entity.bindings,
            task_node=workflow_model.TaskNode(
                reference_id=target.id,
                overrides=TaskNodeOverrides(resources=entity._resources),
            ),
        )

    if isinstance(target, FlyteWorkflow):
        wf_template = get_serializable(entity_mapping, settings, target, options=options)
        # Each sub-workflow is serialized as well (presumably so it lands in entity_mapping);
        # the dict keys are not needed.
        for sub_wf in target.sub_workflows.values():
            get_serializable(entity_mapping, settings, sub_wf, options=options)
        return _make_node(
            entity.bindings,
            workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_template.id),
        )

    if isinstance(target, FlyteLaunchPlan):
        # Recursive call doesn't do anything except put the entity on the map.
        get_serializable(entity_mapping, settings, target, options=options)
        # Node's inputs should not contain the data which is fixed input.
        return _make_node(
            _bindings_without_fixed_inputs(entity.bindings, target.fixed_inputs),
            workflow_node=workflow_model.WorkflowNode(launchplan_ref=target.id),
        )

    raise Exception(f"Node contained non-serializable entity {entity._flyte_entity}")