def test_workflow_template():
    """Check that a one-node WorkflowTemplate survives a Flyte IDL round trip."""
    integer = _types.LiteralType(_types.SimpleType.INTEGER)
    declared_inputs = {"a": _interface.Variable(integer, "description1")}
    declared_outputs = {
        "b": _interface.Variable(integer, "description2"),
        "c": _interface.Variable(integer, "description3"),
    }

    # A single node that targets a task by reference and has no bindings.
    node = _workflow.Node(
        id="some:node:id",
        metadata=_get_sample_node_metadata(),
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=_workflow.TaskNode(reference_id=_generic_id),
    )

    template = _workflow.WorkflowTemplate(
        id=_generic_id,
        metadata=_workflow.WorkflowMetadata(),
        metadata_defaults=_workflow.WorkflowMetadataDefaults(),
        interface=_interface.TypedInterface(declared_inputs, declared_outputs),
        nodes=[node],
        outputs=[],
    )

    restored = _workflow.WorkflowTemplate.from_flyte_idl(template.to_flyte_idl())
    assert restored == template
def test_workflow_template_with_queuing_budget():
    """Round-trip a WorkflowTemplate whose metadata carries a 10 second queuing budget."""
    integer = _types.LiteralType(_types.SimpleType.INTEGER)
    declared_inputs = {"a": _interface.Variable(integer, "description1")}
    declared_outputs = {
        "b": _interface.Variable(integer, "description2"),
        "c": _interface.Variable(integer, "description3"),
    }

    node = _workflow.Node(
        id="some:node:id",
        metadata=_get_sample_node_metadata(),
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=_workflow.TaskNode(reference_id=_generic_id),
    )

    budget = timedelta(seconds=10)
    template = _workflow.WorkflowTemplate(
        id=_generic_id,
        metadata=_workflow.WorkflowMetadata(queuing_budget=budget),
        metadata_defaults=_workflow.WorkflowMetadataDefaults(),
        interface=_interface.TypedInterface(declared_inputs, declared_outputs),
        nodes=[node],
        outputs=[],
    )

    restored = _workflow.WorkflowTemplate.from_flyte_idl(template.to_flyte_idl())
    assert restored == template
def test_node_task_with_inputs():
    """Verify a task Node with two integer bindings, before and after an IDL round trip."""

    def integer_binding(var, value):
        # Bind `var` to a scalar integer literal.
        scalar = _literals.Scalar(primitive=_literals.Primitive(integer=value))
        return _literals.Binding(var=var, binding=_literals.BindingData(scalar=scalar))

    nm = _get_sample_node_metadata()
    task = _workflow.TaskNode(reference_id=_generic_id)
    first = integer_binding("myvar", 5)
    second = integer_binding("myothervar", 99)

    node = _workflow.Node(
        id="some:node:id",
        metadata=nm,
        inputs=[first, second],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task,
    )
    restored = _workflow.Node.from_flyte_idl(node.to_flyte_idl())

    # Properties that must hold for both the original and the round-tripped node.
    for candidate in (node, restored):
        assert candidate.target == task
        assert candidate.id == "some:node:id"
        assert candidate.metadata == nm
        assert len(candidate.inputs) == 2

    assert node.inputs[0] == first
    assert node == restored
    assert restored.inputs[1] == second
def test_branch_node():
    """Build a BranchNode with if / else-if / else arms and round-trip it through Flyte IDL.

    Every arm points at the same task node. The ``case`` arm uses a plain
    comparison (5 == 2); the single ``other`` arm uses the conjunction of two
    such comparisons.
    """

    def five_equals_two():
        # Fresh ``5 == 2`` comparison wrapped as a boolean expression; a new
        # object is built on every call so no instance is shared between arms.
        return _condition.BooleanExpression(
            comparison=_condition.ComparisonExpression(
                _condition.ComparisonExpression.Operator.EQ,
                _condition.Operand(primitive=_literals.Primitive(integer=5)),
                _condition.Operand(primitive=_literals.Primitive(integer=2)),
            )
        )

    nm = _get_sample_node_metadata()
    task = _workflow.TaskNode(reference_id=_generic_id)
    bd = _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=5)))
    bd2 = _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=99)))
    binding = _literals.Binding(var="myvar", binding=bd)
    binding2 = _literals.Binding(var="myothervar", binding=bd2)
    obj = _workflow.Node(
        id="some:node:id",
        metadata=nm,
        inputs=[binding, binding2],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task,
    )

    bn = _workflow.BranchNode(
        _workflow.IfElseBlock(
            case=_workflow.IfBlock(condition=five_equals_two(), then_node=obj),
            other=[
                _workflow.IfBlock(
                    condition=_condition.BooleanExpression(
                        conjunction=_condition.ConjunctionExpression(
                            _condition.ConjunctionExpression.LogicalOperator.AND,
                            five_equals_two(),
                            five_equals_two(),
                        )
                    ),
                    then_node=obj,
                )
            ],
            else_node=obj,
        )
    )

    # Previously the branch node was constructed and then discarded, so the
    # test verified nothing beyond "the constructors do not raise". Check the
    # IDL round trip the same way the sibling model tests do.
    bn2 = _workflow.BranchNode.from_flyte_idl(bn.to_flyte_idl())
    assert bn == bn2
def test_node_task_with_no_inputs():
    """Verify a task Node without bindings, before and after an IDL round trip."""
    nm = _get_sample_node_metadata()
    task = _workflow.TaskNode(reference_id=_generic_id)
    node = _workflow.Node(
        id="some:node:id",
        metadata=nm,
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task,
    )
    restored = _workflow.Node.from_flyte_idl(node.to_flyte_idl())

    assert node == restored
    # The same three properties are expected on both sides of the round trip.
    for candidate in (node, restored):
        assert candidate.target == task
        assert candidate.id == "some:node:id"
        assert candidate.metadata == nm
def test_future_task_document(task):
    """Check that a DynamicJobSpec renders to the same protobuf text after an IDL round trip.

    :param task: Task supplied by the caller (its ``id`` is referenced by the node
        and the task itself is listed in the spec).
    """
    metadata = _workflow.NodeMetadata(
        "node-name",
        _timedelta(minutes=10),
        _literals.RetryStrategy(0),
    )
    node = _workflow.Node(
        id="id",
        metadata=metadata,
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=_workflow.TaskNode(task.id),
    )
    # Result is intentionally discarded: this only checks the node serializes without raising.
    node.to_flyte_idl()

    spec = _dynamic_job.DynamicJobSpec(
        tasks=[task],
        nodes=[node],
        min_successes=1,
        outputs=[_literals.Binding("var", _literals.BindingData())],
        subworkflows=[],
    )

    direct_text = text_format.MessageToString(spec.to_flyte_idl())
    restored = _dynamic_job.DynamicJobSpec.from_flyte_idl(spec.to_flyte_idl())
    round_trip_text = text_format.MessageToString(restored.to_flyte_idl())
    assert direct_text == round_trip_text
def test_workflow_closure():
    """Round-trip a WorkflowClosure (one workflow template plus its single task) through Flyte IDL.

    The workflow has one node, ``my_node``, which takes the literal input ``a``
    and whose outputs ``b`` and ``c`` are bound to the workflow outputs of the
    same names.
    """
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    typed_interface = _interface.TypedInterface(
        {"a": _interface.Variable(int_type, "description1")},
        {
            "b": _interface.Variable(int_type, "description2"),
            "c": _interface.Variable(int_type, "description3"),
        },
    )

    b0 = _literals.Binding(
        "a",
        _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=5))),
    )
    b1 = _literals.Binding("b", _literals.BindingData(promise=_types.OutputReference("my_node", "b")))
    # This binding was previously also named "b", duplicating b1; it forwards
    # the node's "c" output, which the interface declares as output "c".
    b2 = _literals.Binding("c", _literals.BindingData(promise=_types.OutputReference("my_node", "c")))

    node_metadata = _workflow.NodeMetadata(
        name="node1",
        timeout=timedelta(seconds=10),
        retries=_literals.RetryStrategy(0),
    )

    task_metadata = _task.TaskMetadata(
        True,
        _task.RuntimeMetadata(_task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
        timedelta(days=1),
        _literals.RetryStrategy(3),
        "0.1.1b0",
        "This is deprecated!",
    )

    cpu_resource = _task.Resources.ResourceEntry(_task.Resources.ResourceName.CPU, "1")
    resources = _task.Resources(requests=[cpu_resource], limits=[cpu_resource])

    task = _task.TaskTemplate(
        _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version"),
        "python",
        task_metadata,
        typed_interface,
        {"a": 1, "b": {"c": 2, "d": 3}},
        container=_task.Container(
            "my_image",
            ["this", "is", "a", "cmd"],
            ["this", "is", "an", "arg"],
            resources,
            {},
            {},
        ),
    )

    task_node = _workflow.TaskNode(task.id)
    node = _workflow.Node(
        id="my_node",
        metadata=node_metadata,
        inputs=[b0],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_node,
    )

    template = _workflow.WorkflowTemplate(
        id=_identifier.Identifier(_identifier.ResourceType.WORKFLOW, "project", "domain", "name", "version"),
        metadata=_workflow.WorkflowMetadata(),
        # Passed explicitly, matching the other WorkflowTemplate constructions in this file.
        metadata_defaults=_workflow.WorkflowMetadataDefaults(),
        interface=typed_interface,
        nodes=[node],
        outputs=[b1, b2],
    )

    obj = _workflow_closure.WorkflowClosure(workflow=template, tasks=[task])
    assert len(obj.tasks) == 1

    obj2 = _workflow_closure.WorkflowClosure.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
def get_serializable_node(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: Node,
) -> workflow_model.Node:
    """
    Build the ``workflow_model.Node`` for an SDK node.

    Upstream nodes (other than the global input node) are serialized first; the node's
    ``flyte_entity`` is then serialized and referenced from the resulting model according to its type.

    :param entity_mapping: Passed through unchanged to every ``get_serializable`` call.
    :param settings: Passed through unchanged to every ``get_serializable`` call.
    :param entity: The node to convert.
    :raises Exception: If the node has no flyte entity, if a reference entity has an unexpected
        resource type, or if the flyte entity is of a type not handled here.
    """
    if entity.flyte_entity is None:
        raise Exception(f"Node {entity.id} has no flyte entity")

    upstream_sdk_nodes = [
        get_serializable(entity_mapping, settings, upstream)
        for upstream in entity.upstream_nodes
        if upstream.id != _common_constants.GLOBAL_INPUT_NODE_ID
    ]

    def shared_fields():
        # Node arguments that are identical for every kind of target entity.
        return {
            "id": _dnsify(entity.id),
            "metadata": entity.metadata,
            "inputs": entity.bindings,
            "upstream_node_ids": [n.id for n in upstream_sdk_nodes],
            "output_aliases": [],
        }

    target = entity.flyte_entity

    # Reference entities also inherit from the classes checked further down, so address them first.
    if isinstance(target, ReferenceEntity):
        # This is a throw away call.
        # See the comment in compile_into_workflow in python_function_task. This is just used to place a
        # None value in the entity_mapping.
        get_serializable(entity_mapping, settings, target)
        ref = target
        node_model = workflow_model.Node(**shared_fields())
        resource_type = ref.reference.resource_type
        if resource_type == _identifier_model.ResourceType.TASK:
            node_model._task_node = workflow_model.TaskNode(reference_id=ref.id)
        elif resource_type == _identifier_model.ResourceType.WORKFLOW:
            node_model._workflow_node = workflow_model.WorkflowNode(sub_workflow_ref=ref.id)
        elif resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
            node_model._workflow_node = workflow_model.WorkflowNode(launchplan_ref=ref.id)
        else:
            raise Exception(f"Unexpected reference type {ref}")
        return node_model

    if isinstance(target, PythonTask):
        task_spec = get_serializable(entity_mapping, settings, target)
        node_model = workflow_model.Node(
            **shared_fields(),
            task_node=workflow_model.TaskNode(
                reference_id=task_spec.template.id,
                overrides=TaskNodeOverrides(resources=entity._resources),
            ),
        )
        if entity._aliases:
            node_model._output_aliases = entity._aliases
        return node_model

    if isinstance(target, WorkflowBase):
        wf_spec = get_serializable(entity_mapping, settings, target)
        return workflow_model.Node(
            **shared_fields(),
            workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_spec.template.id),
        )

    if isinstance(target, BranchNode):
        return workflow_model.Node(
            **shared_fields(),
            branch_node=get_serializable(entity_mapping, settings, target),
        )

    if isinstance(target, LaunchPlan):
        lp_spec = get_serializable(entity_mapping, settings, target)
        return workflow_model.Node(
            **shared_fields(),
            workflow_node=workflow_model.WorkflowNode(launchplan_ref=lp_spec.id),
        )

    raise Exception(f"Node contained non-serializable entity {entity._flyte_entity}")
def get_serializable_node(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: Node,
    options: Optional[Options] = None,
) -> workflow_model.Node:
    """
    Build the ``workflow_model.Node`` for an SDK node.

    Upstream nodes (other than the global input node) are serialized first; the node's
    ``flyte_entity`` is then serialized and referenced from the resulting model according to its type.
    For launch plan targets, bindings whose variable appears in the launch plan's fixed inputs are
    left out of the node's inputs.

    :param entity_mapping: Passed through unchanged to every ``get_serializable`` call.
    :param settings: Passed through unchanged to every ``get_serializable`` call.
    :param entity: The node to convert.
    :param options: Forwarded as ``options=`` to every ``get_serializable`` call.
    :raises Exception: If the node has no flyte entity, if a reference entity resolves to an
        unexpected resource type, or if the flyte entity is of a type not handled here.
    """
    if entity.flyte_entity is None:
        raise Exception(f"Node {entity.id} has no flyte entity")

    # TODO: Try to move back up following config refactor - https://github.com/flyteorg/flyte/issues/2214
    from flytekit.remote.launch_plan import FlyteLaunchPlan
    from flytekit.remote.task import FlyteTask
    from flytekit.remote.workflow import FlyteWorkflow

    upstream_sdk_nodes = [
        get_serializable(entity_mapping, settings, upstream, options=options)
        for upstream in entity.upstream_nodes
        if upstream.id != _common_constants.GLOBAL_INPUT_NODE_ID
    ]

    def shared_fields(inputs=None):
        # Node arguments common to every kind of target; `inputs` overrides the node's own bindings.
        return {
            "id": _dnsify(entity.id),
            "metadata": entity.metadata,
            "inputs": entity.bindings if inputs is None else inputs,
            "upstream_node_ids": [n.id for n in upstream_sdk_nodes],
            "output_aliases": [],
        }

    def bindings_without_fixed_inputs():
        # Node's inputs should not contain the data which is fixed input
        return [b for b in entity.bindings if b.var not in entity.flyte_entity.fixed_inputs.literals]

    target = entity.flyte_entity

    # Reference entities also inherit from the classes checked further down, so address them first.
    if isinstance(target, ReferenceEntity):
        ref_spec = get_serializable(entity_mapping, settings, target, options=options)
        ref_template = ref_spec.template
        node_model = workflow_model.Node(**shared_fields())
        if ref_template.resource_type == _identifier_model.ResourceType.TASK:
            node_model._task_node = workflow_model.TaskNode(reference_id=ref_template.id)
        elif ref_template.resource_type == _identifier_model.ResourceType.WORKFLOW:
            node_model._workflow_node = workflow_model.WorkflowNode(sub_workflow_ref=ref_template.id)
        elif ref_template.resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
            node_model._workflow_node = workflow_model.WorkflowNode(launchplan_ref=ref_template.id)
        else:
            raise Exception(
                f"Unexpected resource type for reference entity {entity.flyte_entity}: {ref_template.resource_type}"
            )
        return node_model

    if isinstance(target, PythonTask):
        task_spec = get_serializable(entity_mapping, settings, target, options=options)
        node_model = workflow_model.Node(
            **shared_fields(),
            task_node=workflow_model.TaskNode(
                reference_id=task_spec.template.id,
                overrides=TaskNodeOverrides(resources=entity._resources),
            ),
        )
        if entity._aliases:
            node_model._output_aliases = entity._aliases
        return node_model

    if isinstance(target, WorkflowBase):
        wf_spec = get_serializable(entity_mapping, settings, target, options=options)
        return workflow_model.Node(
            **shared_fields(),
            workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_spec.template.id),
        )

    if isinstance(target, BranchNode):
        return workflow_model.Node(
            **shared_fields(),
            branch_node=get_serializable(entity_mapping, settings, target, options=options),
        )

    if isinstance(target, LaunchPlan):
        lp_spec = get_serializable(entity_mapping, settings, target, options=options)
        node_input = bindings_without_fixed_inputs()
        return workflow_model.Node(
            **shared_fields(node_input),
            workflow_node=workflow_model.WorkflowNode(launchplan_ref=lp_spec.id),
        )

    if isinstance(target, FlyteTask):
        # Recursive call doesn't do anything except put the entity on the map.
        get_serializable(entity_mapping, settings, target, options=options)
        return workflow_model.Node(
            **shared_fields(),
            task_node=workflow_model.TaskNode(
                reference_id=target.id,
                overrides=TaskNodeOverrides(resources=entity._resources),
            ),
        )

    if isinstance(target, FlyteWorkflow):
        wf_template = get_serializable(entity_mapping, settings, target, options=options)
        # Serialize each sub-workflow as well; the return values are not used here.
        for sub_wf in target.sub_workflows.values():
            get_serializable(entity_mapping, settings, sub_wf, options=options)
        return workflow_model.Node(
            **shared_fields(),
            workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_template.id),
        )

    if isinstance(target, FlyteLaunchPlan):
        # Recursive call doesn't do anything except put the entity on the map.
        get_serializable(entity_mapping, settings, target, options=options)
        node_input = bindings_without_fixed_inputs()
        return workflow_model.Node(
            **shared_fields(node_input),
            workflow_node=workflow_model.WorkflowNode(launchplan_ref=target.id),
        )

    raise Exception(f"Node contained non-serializable entity {entity._flyte_entity}")