def test_binding_data_map():
    """Round-trip a nested BindingDataMap through its IDL representation.

    The outer map holds a scalar under ``"three"`` and a nested map under
    ``"sample_map"`` (``{"first": 5, "second": 57}``). Only the ``map``
    variant of the outer BindingData may be set, both before and after
    ``to_flyte_idl`` / ``from_flyte_idl``.
    """
    b1 = literals.BindingData(scalar=literals.Scalar(primitive=literals.Primitive(integer=5)))
    b2 = literals.BindingData(scalar=literals.Scalar(primitive=literals.Primitive(integer=57)))
    b3 = literals.BindingData(scalar=literals.Scalar(primitive=literals.Primitive(integer=2)))

    binding_map_sub = literals.BindingDataMap(bindings={"first": b1, "second": b2})
    binding_map = literals.BindingDataMap(
        bindings={
            "three": b3,
            "sample_map": literals.BindingData(map=binding_map_sub),
        }
    )

    obj = literals.BindingData(map=binding_map)
    assert obj.scalar is None
    assert obj.promise is None
    assert obj.collection is None
    # The map variant is the one that was set; the other BindingData tests
    # check all four variants, so do the same here.
    assert obj.map is not None
    assert obj.value.bindings["three"].value.value.value == 2
    assert obj.value.bindings["sample_map"].value.bindings["first"].value.value.value == 5
    assert obj.value.bindings["sample_map"].value.bindings["second"].value.value.value == 57

    obj2 = literals.BindingData.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
    assert obj2.scalar is None
    assert obj2.promise is None
    assert obj2.collection is None
    assert obj2.map is not None
    assert obj2.value.bindings["three"].value.value.value == 2
    assert obj2.value.bindings["sample_map"].value.bindings["first"].value.value.value == 5
    assert obj2.value.bindings["sample_map"].value.bindings["second"].value.value.value == 57
def test_node_task_with_inputs():
    """A task Node carrying two scalar input bindings survives an IDL round trip."""
    node_metadata = _get_sample_node_metadata()
    task_node = _workflow.TaskNode(reference_id=_generic_id)

    def _int_binding(var_name, number):
        # Binding of ``var_name`` to a literal integer scalar.
        return _literals.Binding(
            var=var_name,
            binding=_literals.BindingData(
                scalar=_literals.Scalar(primitive=_literals.Primitive(integer=number))
            ),
        )

    first_binding = _int_binding("myvar", 5)
    second_binding = _int_binding("myothervar", 99)

    node = _workflow.Node(
        id="some:node:id",
        metadata=node_metadata,
        inputs=[first_binding, second_binding],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_node,
    )

    def _check_common(candidate):
        # Properties that must match on both the original and the decoded node.
        assert candidate.target == task_node
        assert candidate.id == "some:node:id"
        assert candidate.metadata == node_metadata
        assert len(candidate.inputs) == 2

    _check_common(node)
    assert node.inputs[0] == first_binding

    round_tripped = _workflow.Node.from_flyte_idl(node.to_flyte_idl())
    assert node == round_tripped
    _check_common(round_tripped)
    assert round_tripped.inputs[1] == second_binding
def binding_data_from_python_std(
    ctx: _flyte_context.FlyteContext,
    expected_literal_type: _type_models.LiteralType,
    t_value: typing.Any,
    t_value_type: type,
) -> _literals_models.BindingData:
    """Build the BindingData that binds ``t_value`` to a node input.

    :param ctx: Flyte context, forwarded to ``TypeEngine.to_literal``.
    :param expected_literal_type: Literal type the input is declared with.
    :param t_value: The Python value (or Promise) being bound.
    :param t_value_type: Python type of ``t_value``; used to derive element
        types for lists and dicts.
    :return: A BindingData holding a promise reference, a collection, a map,
        or a scalar.
    :raises AssertionError: for a VoidPromise; for a list whose expected type
        has no collection type; for a dict whose expected type is neither a
        map nor STRUCT; for a tuple.
    """
    # A Promise that is not ready yet is bound by reference to the upstream
    # output. A ready Promise is not special-cased here and falls through to
    # the scalar conversion at the end.
    if isinstance(t_value, Promise):
        if not t_value.is_ready:
            return _literals_models.BindingData(promise=t_value.ref)

    elif isinstance(t_value, VoidPromise):
        raise AssertionError(
            f"Cannot pass output from task {t_value.task_name} that produces no outputs to a downstream task"
        )

    elif isinstance(t_value, list):
        if expected_literal_type.collection_type is None:
            raise AssertionError(
                f"this should be a list and it is not: {type(t_value)} vs {expected_literal_type}"
            )
        sub_type = ListTransformer.get_sub_type(t_value_type)
        # Each element may itself be a Promise, list, dict, ... so recurse.
        collection = _literals_models.BindingDataCollection(
            bindings=[
                binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type)
                for t in t_value
            ]
        )
        return _literals_models.BindingData(collection=collection)

    elif isinstance(t_value, dict):
        if (
            expected_literal_type.map_value_type is None
            and expected_literal_type.simple != _type_models.SimpleType.STRUCT
        ):
            raise AssertionError(
                f"this should be a Dictionary type and it is not: {type(t_value)} vs {expected_literal_type}"
            )
        # Only the value type is needed; keys are used as-is.
        _, v_type = DictTransformer.get_dict_types(t_value_type)
        if expected_literal_type.simple == _type_models.SimpleType.STRUCT:
            # STRUCT input: the whole dict is converted as one literal and
            # bound as a scalar.
            lit = TypeEngine.to_literal(ctx, t_value, type(t_value), expected_literal_type)
            return _literals_models.BindingData(scalar=lit.scalar)
        m = _literals_models.BindingDataMap(
            bindings={
                k: binding_data_from_python_std(ctx, expected_literal_type.map_value_type, v, v_type)
                for k, v in t_value.items()
            }
        )
        return _literals_models.BindingData(map=m)

    elif isinstance(t_value, tuple):
        # Tuples (including NamedTuples) are rejected with an actionable
        # message instead of being handed to the scalar conversion below.
        raise AssertionError(
            "Tuples are not a supported type for individual values in Flyte - got a tuple -"
            f" {t_value}. If using named tuple in an inner task, please, de-reference the"
            " actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then"
            " return v.x, instead of v, even if this has a single element"
        )

    # This is the scalar case - e.g. my_task(in1=5)
    scalar = TypeEngine.to_literal(ctx, t_value, t_value_type, expected_literal_type).scalar
    return _literals_models.BindingData(scalar=scalar)
def test_branch_node():
    """Build a BranchNode with if / else-if / else arms and round-trip it through IDL."""
    nm = _get_sample_node_metadata()
    task = _workflow.TaskNode(reference_id=_generic_id)
    bd = _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=5)))
    bd2 = _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=99)))
    binding = _literals.Binding(var="myvar", binding=bd)
    binding2 = _literals.Binding(var="myothervar", binding=bd2)
    obj = _workflow.Node(
        id="some:node:id",
        metadata=nm,
        inputs=[binding, binding2],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task,
    )

    def _eq_5_2():
        # A fresh ``5 == 2`` boolean expression on each call, so no model
        # object is shared between branch arms.
        return _condition.BooleanExpression(
            comparison=_condition.ComparisonExpression(
                _condition.ComparisonExpression.Operator.EQ,
                _condition.Operand(primitive=_literals.Primitive(integer=5)),
                _condition.Operand(primitive=_literals.Primitive(integer=2)),
            )
        )

    bn = _workflow.BranchNode(
        _workflow.IfElseBlock(
            case=_workflow.IfBlock(condition=_eq_5_2(), then_node=obj),
            other=[
                _workflow.IfBlock(
                    condition=_condition.BooleanExpression(
                        conjunction=_condition.ConjunctionExpression(
                            _condition.ConjunctionExpression.LogicalOperator.AND,
                            _eq_5_2(),
                            _eq_5_2(),
                        )
                    ),
                    then_node=obj,
                )
            ],
            else_node=obj,
        )
    )

    # Previously the test only constructed the node; verify it serializes
    # and decodes back to an equal object with the arms intact.
    bn2 = _workflow.BranchNode.from_flyte_idl(bn.to_flyte_idl())
    assert bn == bn2
    assert bn2.if_else.case.then_node == obj
    assert len(bn2.if_else.other) == 1
    assert bn2.if_else.else_node == obj
def test_binding_data_collection():
    """A collection BindingData keeps its element order and variant across an IDL round trip."""

    def _assert_collection_of_5_and_57(binding_data):
        # Only the collection variant may be populated, holding [5, 57].
        assert binding_data.scalar is None
        assert binding_data.promise is None
        assert binding_data.collection is not None
        assert binding_data.map is None
        assert binding_data.value.bindings[0].value.value.value == 5
        assert binding_data.value.bindings[1].value.value.value == 57

    elements = [
        literals.BindingData(scalar=literals.Scalar(primitive=literals.Primitive(integer=number)))
        for number in (5, 57)
    ]
    original = literals.BindingData(collection=literals.BindingDataCollection(bindings=elements))
    _assert_collection_of_5_and_57(original)

    decoded = literals.BindingData.from_flyte_idl(original.to_flyte_idl())
    assert original == decoded
    _assert_collection_of_5_and_57(decoded)
def test_binding_data_scalar():
    """A scalar BindingData exposes only its scalar variant, before and after an IDL round trip."""

    def _assert_scalar_five(binding_data):
        assert binding_data.value.value.value == 5
        # None of the non-scalar variants may be set.
        for unused_variant in (binding_data.promise, binding_data.collection, binding_data.map):
            assert unused_variant is None

    original = literals.BindingData(
        scalar=literals.Scalar(primitive=literals.Primitive(integer=5))
    )
    _assert_scalar_five(original)

    decoded = literals.BindingData.from_flyte_idl(original.to_flyte_idl())
    assert original == decoded
    _assert_scalar_five(decoded)
def test_codecov():
    """Cover get_promise's validation errors and a minimal one-node ImperativeWorkflow."""
    # get_promise rejects an empty BindingData ...
    with pytest.raises(FlyteValidationException):
        get_promise(literal_models.BindingData(), {})
    # ... and one whose "promise" is not an output reference.
    with pytest.raises(FlyteValidationException):
        get_promise(literal_models.BindingData(promise=3), {})

    @task
    def t1(a: str) -> str:
        return a + " world"

    imperative_wf = ImperativeWorkflow(name="my.workflow")
    imperative_wf.add_workflow_input("in1", str)
    first_node = imperative_wf.add_entity(t1, a=imperative_wf.inputs["in1"])
    imperative_wf.add_workflow_output("from_n0t1", first_node.outputs["o0"])

    assert imperative_wf(in1="hello") == "hello world"
    # A positional call is expected to raise AssertionError.
    with pytest.raises(AssertionError):
        imperative_wf(3)
    # An unknown input name is expected to raise ValueError.
    with pytest.raises(ValueError):
        imperative_wf(in2="hello")
def test_binding_data_promise():
    """A promise BindingData keeps its output reference across an IDL round trip."""
    obj = literals.BindingData(promise=_types.OutputReference("some_node", "myvar"))
    assert obj.scalar is None
    assert obj.promise is not None
    assert obj.collection is None
    assert obj.map is None
    assert obj.value.node_id == "some_node"
    assert obj.value.var == "myvar"

    obj2 = literals.BindingData.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
    assert obj2.scalar is None
    assert obj2.promise is not None
    assert obj2.collection is None
    assert obj2.map is None
    # The referenced node and variable must also survive serialization.
    assert obj2.value.node_id == "some_node"
    assert obj2.value.var == "myvar"
def test_future_task_document(task):
    """A DynamicJobSpec wrapping ``task`` serializes identically after an IDL round trip.

    :param task: task template supplied by the test fixture; its ``id`` is
        used as the node's task reference.
    """
    node_metadata = _workflow.NodeMetadata(
        "node-name", _timedelta(minutes=10), _literals.RetryStrategy(0)
    )
    node = _workflow.Node(
        id="id",
        metadata=node_metadata,
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=_workflow.TaskNode(task.id),
    )
    # Smoke check: the node on its own must serialize without raising.
    node.to_flyte_idl()

    doc = _dynamic_job.DynamicJobSpec(
        tasks=[task],
        nodes=[node],
        min_successes=1,
        outputs=[_literals.Binding("var", _literals.BindingData())],
        subworkflows=[],
    )

    expected_text = text_format.MessageToString(doc.to_flyte_idl())
    decoded_spec = _dynamic_job.DynamicJobSpec.from_flyte_idl(doc.to_flyte_idl())
    actual_text = text_format.MessageToString(decoded_spec.to_flyte_idl())
    assert expected_text == actual_text
def to_binding(p: Promise) -> _literals_models.Binding:
    """Return a Binding of ``p.var`` to the promise's output reference ``p.ref``."""
    var_name = p.var
    promise_data = _literals_models.BindingData(promise=p.ref)
    return _literals_models.Binding(var=var_name, binding=promise_data)
def test_workflow_closure():
    """Round-trip a WorkflowClosure (one task, one node) through its IDL form."""
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    typed_interface = _interface.TypedInterface(
        {"a": _interface.Variable(int_type, "description1")},
        {
            "b": _interface.Variable(int_type, "description2"),
            "c": _interface.Variable(int_type, "description3"),
        },
    )

    b0 = _literals.Binding(
        "a",
        _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=5))),
    )
    # Workflow outputs: each declared output ("b" and "c") is bound to the
    # node output of the same name. b2 previously reused var "b", leaving
    # "c" unbound and "b" bound twice.
    b1 = _literals.Binding("b", _literals.BindingData(promise=_types.OutputReference("my_node", "b")))
    b2 = _literals.Binding("c", _literals.BindingData(promise=_types.OutputReference("my_node", "c")))

    node_metadata = _workflow.NodeMetadata(
        name="node1", timeout=timedelta(seconds=10), retries=_literals.RetryStrategy(0)
    )
    task_metadata = _task.TaskMetadata(
        True,
        _task.RuntimeMetadata(_task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
        timedelta(days=1),
        _literals.RetryStrategy(3),
        "0.1.1b0",
        "This is deprecated!",
    )
    cpu_resource = _task.Resources.ResourceEntry(_task.Resources.ResourceName.CPU, "1")
    resources = _task.Resources(requests=[cpu_resource], limits=[cpu_resource])

    task = _task.TaskTemplate(
        _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version"),
        "python",
        task_metadata,
        typed_interface,
        {"a": 1, "b": {"c": 2, "d": 3}},
        container=_task.Container(
            "my_image",
            ["this", "is", "a", "cmd"],
            ["this", "is", "an", "arg"],
            resources,
            {},
            {},
        ),
    )

    task_node = _workflow.TaskNode(task.id)
    node = _workflow.Node(
        id="my_node",
        metadata=node_metadata,
        inputs=[b0],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_node,
    )

    template = _workflow.WorkflowTemplate(
        id=_identifier.Identifier(_identifier.ResourceType.WORKFLOW, "project", "domain", "name", "version"),
        metadata=_workflow.WorkflowMetadata(),
        interface=typed_interface,
        nodes=[node],
        outputs=[b1, b2],
    )

    obj = _workflow_closure.WorkflowClosure(workflow=template, tasks=[task])
    assert len(obj.tasks) == 1

    obj2 = _workflow_closure.WorkflowClosure.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
def binding_data_from_python_std(
    ctx: _flyte_context.FlyteContext,
    expected_literal_type: _type_models.LiteralType,
    t_value: typing.Any,
    t_value_type: type,
) -> _literals_models.BindingData:
    """Build the BindingData that binds ``t_value`` to a node input.

    :param ctx: Flyte context, forwarded to ``TypeEngine.to_literal``.
    :param expected_literal_type: Literal type the input is declared with.
    :param t_value: The Python value (or Promise) being bound.
    :param t_value_type: Python type of ``t_value``; used to derive element
        types for lists and dicts.
    :return: A BindingData holding a promise reference, a collection, a map,
        or a scalar.
    :raises AssertionError: for a VoidPromise; for a list whose expected type
        has no collection type; for a dict whose expected type is neither a
        map nor STRUCT; for a tuple.
    """
    # A Promise that is not ready yet is bound by reference to the upstream
    # output. A ready Promise is not special-cased here and falls through to
    # the scalar conversion at the end.
    if isinstance(t_value, Promise):
        if not t_value.is_ready:
            return _literals_models.BindingData(promise=t_value.ref)

    elif isinstance(t_value, VoidPromise):
        raise AssertionError(
            f"Cannot pass output from task {t_value.task_name} that produces no outputs to a downstream task"
        )

    elif isinstance(t_value, list):
        if expected_literal_type.collection_type is None:
            raise AssertionError(
                f"this should be a list and it is not: {type(t_value)} vs {expected_literal_type}"
            )
        sub_type = ListTransformer.get_sub_type(t_value_type)
        # Each element may itself be a Promise, list, dict, ... so recurse.
        collection = _literals_models.BindingDataCollection(
            bindings=[
                binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type)
                for t in t_value
            ]
        )
        return _literals_models.BindingData(collection=collection)

    elif isinstance(t_value, dict):
        if (
            expected_literal_type.map_value_type is None
            and expected_literal_type.simple != _type_models.SimpleType.STRUCT
        ):
            raise AssertionError(
                f"this should be a Dictionary type and it is not: {type(t_value)} vs {expected_literal_type}"
            )
        # Only the value type is needed; keys are used as-is.
        _, v_type = DictTransformer.get_dict_types(t_value_type)
        if expected_literal_type.simple == _type_models.SimpleType.STRUCT:
            # STRUCT input: the whole dict is converted as one literal and
            # bound as a scalar.
            lit = TypeEngine.to_literal(ctx, t_value, type(t_value), expected_literal_type)
            return _literals_models.BindingData(scalar=lit.scalar)
        m = _literals_models.BindingDataMap(
            bindings={
                k: binding_data_from_python_std(ctx, expected_literal_type.map_value_type, v, v_type)
                for k, v in t_value.items()
            }
        )
        return _literals_models.BindingData(map=m)

    elif isinstance(t_value, tuple):
        # Tuples (including NamedTuples) are rejected with an actionable
        # message. The implicitly concatenated pieces each carry their own
        # leading space so the words do not run together.
        raise AssertionError(
            "Tuples are not a supported type for individual values in Flyte - got a tuple -"
            f" {t_value}. If using named tuple in an inner task, please, de-reference the"
            " actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then"
            " return v.x, instead of v, even if this has a single element"
        )

    # This is the scalar case - e.g. my_task(in1=5)
    scalar = TypeEngine.to_literal(ctx, t_value, t_value_type, expected_literal_type).scalar
    return _literals_models.BindingData(scalar=scalar)