def test_binding_data_list_static():
    """A BindingData built from a static list of strings is a pure literal collection."""
    string_list_literal_type = containers.List(primitives.String).to_flyte_literal_type()

    upstream_nodes = set()
    bd = interface.BindingData.from_python_std(
        string_list_literal_type, ['abc', 'cde'], upstream_nodes=upstream_nodes)

    # Static values introduce no upstream dependencies.
    assert not upstream_nodes
    assert bd.promise is None
    assert [b.scalar.primitive.string_value for b in bd.collection.bindings[:2]] == ['abc', 'cde']
    assert bd.map is None
    assert bd.scalar is None
    # Round-trip through the IDL representation.
    assert interface.BindingData.from_flyte_idl(bd.to_flyte_idl()) == bd

    # Neither a bare string nor a list of floats is a valid List<String>.
    for bad_value in ('abc', [1.0, 2.0, 3.0]):
        with pytest.raises(_user_exceptions.FlyteTypeException):
            interface.BindingData.from_python_std(string_list_literal_type, bad_value)
def infer_sdk_type_from_literal(self, literal):  # noqa
    """
    Infer the SDK type that describes an already-built literal.

    A collection is typed from its first element (recursively, via ``_helpers``); an
    empty collection becomes ``List(Void)``. A scalar is typed from whichever of its
    fields is set, checked in a fixed order (blob, none_type, schema, error, generic,
    binary, then the primitive value fields).

    :param flytekit.models.literals.Literal literal:
    :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType
    :raises NotImplementedError: for map literals, error scalars, and binary scalars whose
        tag does not start with the protobuf tag prefix.
    :raises flytekit.common.exceptions.system.FlyteSystemAssertion: if the literal matches
        none of the recognized shapes, including a literal with no scalar set or a scalar
        with no recognized field set.
    """
    if literal.collection is not None:
        items = literal.collection.literals
        if len(items) > 0:
            # Only the first element is inspected; the collection is presumably homogeneous.
            return _container_types.List(_helpers.infer_sdk_type_from_literal(items[0]))
        return _container_types.List(_base_sdk_types.Void)
    if literal.map is not None:
        raise NotImplementedError("TODO: Implement map")

    scalar = literal.scalar
    if scalar is None:
        # Neither collection, map nor scalar is set. Without this guard the attribute
        # lookups below fail with AttributeError instead of the intended assertion.
        raise _system_exceptions.FlyteSystemAssertion(
            "Received unknown literal: {}".format(literal))

    if scalar.blob is not None:
        return self._get_blob_impl_from_type(scalar.blob.metadata.type)
    if scalar.none_type is not None:
        return _base_sdk_types.Void
    if scalar.schema is not None:
        return _schema.schema_instantiator_from_proto(scalar.schema.type)
    if scalar.error is not None:
        raise NotImplementedError("TODO: Implement error from literal map")
    if scalar.generic is not None:
        return _primitive_types.Generic
    if scalar.binary is not None:
        tag = scalar.binary.tag
        if tag.startswith(_proto.Protobuf.TAG_PREFIX):
            # Strip the prefix to recover the protobuf type tag.
            return _proto_sdk_type_from_tag(tag[len(_proto.Protobuf.TAG_PREFIX):])
        raise NotImplementedError(
            "TODO: Binary is only supported for protobuf types currently")

    # A scalar with none of the fields above and no primitive previously crashed on
    # ``None.boolean``; it now falls through to the unknown-literal assertion.
    primitive = scalar.primitive
    if primitive is not None:
        if primitive.boolean is not None:
            return _primitive_types.Boolean
        if primitive.datetime is not None:
            return _primitive_types.Datetime
        if primitive.duration is not None:
            return _primitive_types.Timedelta
        if primitive.float_value is not None:
            return _primitive_types.Float
        if primitive.integer is not None:
            return _primitive_types.Integer
        if primitive.string_value is not None:
            return _primitive_types.String

    raise _system_exceptions.FlyteSystemAssertion(
        "Received unknown literal: {}".format(literal))
def test_string_list():
    """Parsing a JSON string list yields the expected literals and comparable objects."""
    raw_json = '["fdsa", "fff3", "fdsfhuie", "frfJliEILles", ""]'
    expected_values = ["fdsa", "fff3", "fdsfhuie", "frfJliEILles", ""]

    first = containers.List(primitives.String).from_string(raw_json)
    assert len(first.collection.literals) == 5
    assert first.to_python_std() == expected_values

    # Values from two independently constructed List<String> classes must compare equal.
    second = containers.List(primitives.String).from_string(raw_json)
    assert first == second
def get_sdk_type_from_literal_type(self, literal_type):
    """
    Map a literal type model onto the SDK type class that represents it.

    Collections recurse through ``_helpers`` on their element type. Schemas and blobs
    are delegated to their dedicated factories. Simple types are first checked for a
    protobuf tag stored in the type metadata (BINARY -> specific protobuf type,
    STRUCT -> generic protobuf type), then resolved via ``_SIMPLE_TYPE_LOOKUP_TABLE``.

    :param flytekit.models.types.LiteralType literal_type:
    :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType
    :raises NotImplementedError: for map types and simple types missing from the lookup table.
    :raises flytekit.common.exceptions.system.FlyteSystemAssertion: if no known field is set.
    """
    if literal_type.collection_type is not None:
        return _container_types.List(
            _helpers.get_sdk_type_from_literal_type(literal_type.collection_type))
    elif literal_type.map_value_type is not None:
        raise NotImplementedError("TODO: Implement map")
    elif literal_type.schema is not None:
        return _schema.schema_instantiator_from_proto(literal_type.schema)
    elif literal_type.blob is not None:
        return self._get_blob_impl_from_type(literal_type.blob)
    elif literal_type.simple is not None:
        # Metadata is optional. The BINARY check used to do ``key in None`` (TypeError)
        # while the STRUCT check guarded against it; both now use the same guard.
        metadata = literal_type.metadata or {}
        pb_key = _proto.Protobuf.PB_FIELD_KEY
        if literal_type.simple == _literal_type_models.SimpleType.BINARY and pb_key in metadata:
            return _proto_sdk_type_from_tag(metadata[pb_key])
        if literal_type.simple == _literal_type_models.SimpleType.STRUCT and pb_key in metadata:
            return _generic_proto_sdk_type_from_tag(metadata[pb_key])
        sdk_type = self._SIMPLE_TYPE_LOOKUP_TABLE.get(literal_type.simple)
        if sdk_type is None:
            raise NotImplementedError(
                "We haven't implemented this type yet: Simple type={}".format(literal_type.simple))
        return sdk_type
    else:
        raise _system_exceptions.FlyteSystemAssertion(
            "An unrecognized literal type was received: {}".format(literal_type))
def python_std_to_sdk_type(self, t):
    """
    Translate a user-facing type specification into an SDK type.

    A single-element list such as ``[Types.Integer]`` becomes a List SDK type whose
    element type is resolved recursively. An SDK type is returned as-is. Dict (map)
    specifications are rejected as not implemented.

    :param T t: User input. Should be of the form: Types.Integer, [Types.Integer],
        {Types.String: Types.Integer}, etc.
    :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType
    :raises flytekit.common.exceptions.user.FlyteAssertion: for lists not of length one,
        and for dicts.
    :raises flytekit.common.exceptions.user.FlyteTypeException: for anything else that
        is not an SDK type.
    """
    if isinstance(t, list):
        if len(t) == 1:
            return _container_types.List(_helpers.python_std_to_sdk_type(t[0]))
        raise _user_exceptions.FlyteAssertion(
            "When specifying a list type, there must be exactly one element in "
            "the list describing the contained type.")

    if isinstance(t, dict):
        raise _user_exceptions.FlyteAssertion("Map types are not yet implemented.")

    if not isinstance(t, _base_sdk_types.FlyteSdkType):
        raise _user_exceptions.FlyteTypeException(
            type(t),
            _base_sdk_types.FlyteSdkType,
            additional_msg="Should be of form similar to: Types.Integer, [Types.Integer], {Types.String: "
                           "Types.Integer}",
            received_value=t,
        )

    return t
def test_reprs():
    """short_string/verbose_string render flat and nested integer lists as expected."""
    flat = containers.List(primitives.Integer).from_python_std(list(_range(3)))
    assert flat.short_string() == "List<Integer>(len=3, [Integer(0), Integer(1), Integer(2)])"
    assert flat.verbose_string() == (
        "List<Integer>(\n"
        "\tlen=3,\n"
        "\t[\n"
        "\t\tInteger(0),\n"
        "\t\tInteger(1),\n"
        "\t\tInteger(2)\n"
        "\t]\n"
        ")"
    )

    nested = containers.List(containers.List(primitives.Integer)).from_python_std(
        [list(_range(3)), list(_range(3))])
    assert nested.short_string() == (
        "List<List<Integer>>(len=2, [List<Integer>(len=3, [Integer(0), Integer(1), Integer(2)]), "
        "List<Integer>(len=3, [Integer(0), Integer(1), Integer(2)])])"
    )
    assert nested.verbose_string() == (
        "List<List<Integer>>(\n"
        "\tlen=2,\n"
        "\t[\n"
        "\t\tList<Integer>(\n"
        "\t\t\tlen=3,\n"
        "\t\t\t[\n"
        "\t\t\t\tInteger(0),\n"
        "\t\t\t\tInteger(1),\n"
        "\t\t\t\tInteger(2)\n"
        "\t\t\t]\n"
        "\t\t),\n"
        "\t\tList<Integer>(\n"
        "\t\t\tlen=3,\n"
        "\t\t\t[\n"
        "\t\t\t\tInteger(0),\n"
        "\t\t\t\tInteger(1),\n"
        "\t\t\t\tInteger(2)\n"
        "\t\t\t]\n"
        "\t\t)\n"
        "\t]\n"
        ")"
    )
def test_empty_parsing():
    """An empty JSON list parses; nested lists/maps inside a List<String> are rejected."""
    string_list_type = containers.List(primitives.String)
    assert len(string_list_type.from_string("[]")) == 0

    # The String primitive type does not allow lists or maps to be converted.
    for bad_json in ('["fdjs", []]', '["fdjs", {}]'):
        with pytest.raises(_user_exceptions.FlyteTypeException):
            string_list_type.from_string(bad_json)
def test_nested_list():
    """List<List<Integer>> has the right literal type and round-trips values."""
    list_type = containers.List(containers.List(primitives.Integer))

    # Outer and inner levels are both collections; only the innermost is a simple type.
    outer_type = list_type.to_flyte_literal_type()
    assert outer_type.simple is None
    assert outer_type.map_value_type is None
    assert outer_type.schema is None
    inner_type = outer_type.collection_type
    assert inner_type.simple is None
    assert inner_type.map_value_type is None
    assert inner_type.schema is None
    assert inner_type.collection_type.simple == literal_types.SimpleType.INTEGER

    gt = [[1, 2, 3], [4, 5, 6], []]
    list_value = list_type.from_python_std(gt)
    assert list_value.to_python_std() == gt
    assert list_type.from_flyte_idl(list_value.to_flyte_idl()) == list_value

    # Check each row's literals individually, including the empty trailing row.
    for row_literal, expected_row in zip(list_value.collection.literals, gt):
        row_values = [item.scalar.primitive.integer for item in row_literal.collection.literals]
        assert row_values == expected_row

    parsed = list_type.from_string("[[1, 2, 3], [4, 5, 6]]")
    assert len(parsed) == 2
    assert len(parsed.collection.literals[0]) == 3
def test_list():
    """List<Integer>: literal type, round-trip, string parsing, and rejection of bad input."""
    int_list_type = containers.List(primitives.Integer)

    literal_type = int_list_type.to_flyte_literal_type()
    assert literal_type.simple is None
    assert literal_type.map_value_type is None
    assert literal_type.schema is None
    assert literal_type.collection_type.simple == literal_types.SimpleType.INTEGER

    expected = [1, 2, 3, 4]
    list_value = int_list_type.from_python_std([1, 2, 3, 4])
    assert list_value.to_python_std() == expected
    assert int_list_type.from_flyte_idl(list_value.to_flyte_idl()) == list_value
    for idx, value in enumerate(expected):
        assert list_value.collection.literals[idx].scalar.primitive.integer == value

    assert int_list_type.from_string("[1, 2, 3,4]") == list_value

    # Python values that are not a flat list of integers.
    for bad_value in (["a", "b", "c", "d"], [1, 2, 3, "abc"], 1, [[1]]):
        with pytest.raises(_user_exceptions.FlyteTypeException):
            int_list_type.from_python_std(bad_value)

    # Strings that are not JSON lists of integers (wrong element type, nesting,
    # quoted JSON, unterminated list).
    bad_strings = ('["fdsa"]', "[1, 2, 3, []]", "'[\"not list json\"]'", '["unclosed","list"')
    for bad_string in bad_strings:
        with pytest.raises(_user_exceptions.FlyteTypeException):
            int_list_type.from_string(bad_string)
def test_model_promotion():
    """A raw Literal collection model can be promoted into a List<Integer> SDK value."""
    int_list_type = containers.List(primitives.Integer)
    element_literals = [
        literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=value)))
        for value in range(3)
    ]
    list_model = literals.Literal(collection=literals.LiteralCollection(literals=element_literals))

    promoted = int_list_type.promote_from_model(list_model)
    assert len(promoted.collection.literals) == 3
    # Promotion should yield SDK element objects, not raw models.
    assert isinstance(promoted.collection.literals[0], primitives.Integer)
    assert promoted == int_list_type.from_python_std([0, 1, 2])
    assert promoted == int_list_type([primitives.Integer(value) for value in range(3)])
def test_input():
    """Input exposes consistent properties before and after rename_and_return_reference."""

    def _assert_input(inp, name, sdk_default, default, help_text, description, sdk_type):
        # Shared checks for an optional (non-required) Input.
        assert inp.name == name
        # Compare by value: identity checks against literals (``is 1``) depend on
        # interpreter int caching and trigger SyntaxWarning on Python 3.8+.
        assert inp.sdk_default == sdk_default
        assert inp.default == default
        assert inp.sdk_required is False
        assert inp.required is None
        assert inp.help == help_text
        assert inp.var.description == description
        assert inp.sdk_type == sdk_type

    def _assert_before_and_after_rename(inp, **expected):
        # Renaming must change only the name; every other property is preserved.
        _assert_input(inp, '', **expected)
        _assert_input(inp.rename_and_return_reference('new_name'), 'new_name', **expected)

    _assert_before_and_after_rename(
        Input(primitives.Integer, help="blah", default=None),
        sdk_default=None,
        default=base_sdk_types.Void(),
        help_text="blah",
        description="blah",
        sdk_type=primitives.Integer,
    )

    _assert_before_and_after_rename(
        Input(primitives.Integer, default=1),
        sdk_default=1,
        default=primitives.Integer(1),
        help_text=None,
        description="",
        sdk_type=primitives.Integer,
    )

    # An input cannot be both required and defaulted.
    with pytest.raises(_user_exceptions.FlyteAssertion):
        Input(primitives.Integer, required=True, default=1)

    int_list_type = containers.List(primitives.Integer)
    _assert_before_and_after_rename(
        Input([primitives.Integer], default=[1, 2]),
        sdk_default=[1, 2],
        default=int_list_type([primitives.Integer(1), primitives.Integer(2)]),
        help_text=None,
        description="",
        sdk_type=int_list_type,
    )
def test_workflow_node():
    """Calling an SdkRunnableWorkflow yields a node with validated bindings and promised outputs."""

    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task()
    def my_task(wf_params, a, b):
        b.set(a + 1)

    my_task._id = _identifier.Identifier(
        _identifier.ResourceType.TASK, "project", "domain", "my_task", "version")

    @inputs(a=[primitives.Integer])
    @outputs(b=[primitives.Integer])
    @python_task
    def my_list_task(wf_params, a, b):
        b.set([v + 1 for v in a])

    my_list_task._id = _identifier.Identifier(
        _identifier.ResourceType.TASK, "project", "domain", "my_list_task", "version")

    required_input = promise.Input("required", primitives.Integer)
    optional_input = promise.Input("not_required", primitives.Integer, default=5, help="Not required.")
    input_list = [required_input, optional_input]

    node1 = my_task(a=required_input).assign_id_and_return("n1")
    node2 = my_task(a=optional_input).assign_id_and_return("n2")
    node3 = my_task(a=100).assign_id_and_return("n3")
    node4 = my_task(a=node1.outputs.b).assign_id_and_return("n4")
    node5 = my_list_task(
        a=[required_input, optional_input, node3.outputs.b, 100]).assign_id_and_return("n5")
    node6 = my_list_task(a=node5.outputs.b)

    wf_out = [
        _local_workflow.Output(
            "nested_out",
            [node5.outputs.b, node6.outputs.b, [node1.outputs.b, node2.outputs.b]],
            sdk_type=[[primitives.Integer]],
        ),
        _local_workflow.Output("scalar_out", node1.outputs.b, sdk_type=primitives.Integer),
    ]
    w = _local_workflow.SdkRunnableWorkflow.construct_from_class_definition(
        inputs=input_list, outputs=wf_out, nodes=[node1, node2, node3, node4, node5, node6])

    # Invalid calls: missing required input, positional args, wrong type, unknown arg name.
    rejected_calls = [
        (_user_exceptions.FlyteAssertion, (), {}),
        (_user_exceptions.FlyteAssertion, (1, 2), {}),
        (_user_exceptions.FlyteTypeException, (), {"required": "abc", "not_required": 1}),
        (_user_exceptions.FlyteAssertion, (), {"required": 1, "bad_arg": 1}),
    ]
    for expected_exc, call_args, call_kwargs in rejected_calls:
        with _pytest.raises(expected_exc):
            w(*call_args, **call_kwargs)

    def _assert_bindings(node, not_required, required):
        # Bindings are listed with "not_required" first, then "required".
        assert node.inputs[0].var == "not_required"
        assert node.inputs[0].binding.scalar.primitive.integer == not_required
        assert node.inputs[1].var == "required"
        assert node.inputs[1].binding.scalar.primitive.integer == required

    # The default value is used when the optional input is omitted...
    _assert_bindings(w(required=10), not_required=5, required=10)
    # ...and overridden when it is supplied.
    node = w(required=10, not_required=50)
    _assert_bindings(node, not_required=50, required=10)

    # The node references the workflow's id.
    w.id = "fake"
    assert node.workflow_node.sub_workflow_ref == "fake"
    w.id = None

    # Outputs are promised against the (dns-ified) node id.
    node.assign_id_and_return("node-id*")
    scalar_out = node.outputs["scalar_out"]
    assert scalar_out.sdk_type.to_flyte_literal_type() == primitives.Integer.to_flyte_literal_type()
    assert scalar_out.var == "scalar_out"
    assert scalar_out.node_id == "node-id"
    nested_out = node.outputs["nested_out"]
    assert (nested_out.sdk_type.to_flyte_literal_type() ==
            containers.List(containers.List(primitives.Integer)).to_flyte_literal_type())
    assert nested_out.var == "nested_out"
    assert nested_out.node_id == "node-id"
def test_workflow_node():
    """Calling an SdkWorkflow yields a node with validated bindings and promised outputs."""

    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task()
    def my_task(wf_params, a, b):
        b.set(a + 1)

    my_task._id = _identifier.Identifier(
        _identifier.ResourceType.TASK, 'project', 'domain', 'my_task', 'version')

    @inputs(a=[primitives.Integer])
    @outputs(b=[primitives.Integer])
    @python_task
    def my_list_task(wf_params, a, b):
        b.set([v + 1 for v in a])

    my_list_task._id = _identifier.Identifier(
        _identifier.ResourceType.TASK, 'project', 'domain', 'my_list_task', 'version')

    required_input = promise.Input('required', primitives.Integer)
    optional_input = promise.Input('not_required', primitives.Integer, default=5, help='Not required.')
    input_list = [required_input, optional_input]

    node1 = my_task(a=required_input).assign_id_and_return('n1')
    node2 = my_task(a=optional_input).assign_id_and_return('n2')
    node3 = my_task(a=100).assign_id_and_return('n3')
    node4 = my_task(a=node1.outputs.b).assign_id_and_return('n4')
    node5 = my_list_task(
        a=[required_input, optional_input, node3.outputs.b, 100]).assign_id_and_return('n5')
    node6 = my_list_task(a=node5.outputs.b)

    wf_out = [
        workflow.Output(
            'nested_out',
            [node5.outputs.b, node6.outputs.b, [node1.outputs.b, node2.outputs.b]],
            sdk_type=[[primitives.Integer]]),
        workflow.Output('scalar_out', node1.outputs.b, sdk_type=primitives.Integer),
    ]
    w = workflow.SdkWorkflow(
        inputs=input_list, outputs=wf_out, nodes=[node1, node2, node3, node4, node5, node6])

    # Invalid calls: missing required input, positional args, wrong type, unknown arg name.
    rejected_calls = [
        (_user_exceptions.FlyteAssertion, (), {}),
        (_user_exceptions.FlyteAssertion, (1, 2), {}),
        (_user_exceptions.FlyteTypeException, (), {'required': 'abc', 'not_required': 1}),
        (_user_exceptions.FlyteAssertion, (), {'required': 1, 'bad_arg': 1}),
    ]
    for expected_exc, call_args, call_kwargs in rejected_calls:
        with _pytest.raises(expected_exc):
            w(*call_args, **call_kwargs)

    def _assert_bindings(node, not_required, required):
        # Bindings are listed with 'not_required' first, then 'required'.
        assert node.inputs[0].var == 'not_required'
        assert node.inputs[0].binding.scalar.primitive.integer == not_required
        assert node.inputs[1].var == 'required'
        assert node.inputs[1].binding.scalar.primitive.integer == required

    # The default value is used when the optional input is omitted...
    _assert_bindings(w(required=10), not_required=5, required=10)
    # ...and overridden when it is supplied.
    node = w(required=10, not_required=50)
    _assert_bindings(node, not_required=50, required=10)

    # The node references the workflow's id.
    w._id = 'fake'
    assert node.workflow_node.sub_workflow_ref == 'fake'
    w._id = None

    # Outputs are promised against the (dns-ified) node id.
    node.assign_id_and_return('node-id*')
    scalar_out = node.outputs['scalar_out']
    assert scalar_out.sdk_type.to_flyte_literal_type() == primitives.Integer.to_flyte_literal_type()
    assert scalar_out.var == 'scalar_out'
    assert scalar_out.node_id == 'node-id'
    nested_out = node.outputs['nested_out']
    assert nested_out.sdk_type.to_flyte_literal_type() == \
        containers.List(containers.List(primitives.Integer)).to_flyte_literal_type()
    assert nested_out.var == 'nested_out'
    assert nested_out.node_id == 'node-id'