Beispiel #1
0
def test_binding_data_list_static():
    """A constant list of strings binds without promises and survives an IDL round trip.

    Also checks that a non-list value and a list of the wrong element type are
    both rejected with ``FlyteTypeException``.
    """

    def _string_list_literal_type():
        # Rebuilt on every call, exactly like the inline expressions it replaces.
        return containers.List(primitives.String).to_flyte_literal_type()

    collected_upstream = set()
    binding = interface.BindingData.from_python_std(
        _string_list_literal_type(),
        ['abc', 'cde'],
        upstream_nodes=collected_upstream,
    )

    # Constant values must not register any upstream node dependency.
    assert len(collected_upstream) == 0
    assert binding.promise is None
    for position, expected_text in enumerate(['abc', 'cde']):
        assert binding.collection.bindings[position].scalar.primitive.string_value == expected_text
    assert binding.map is None
    assert binding.scalar is None

    round_tripped = interface.BindingData.from_flyte_idl(binding.to_flyte_idl())
    assert round_tripped == binding

    # A bare string is not a list, and floats are not strings.
    for rejected_value in ('abc', [1.0, 2.0, 3.0]):
        with pytest.raises(_user_exceptions.FlyteTypeException):
            interface.BindingData.from_python_std(
                _string_list_literal_type(), rejected_value)
Beispiel #2
0
 def infer_sdk_type_from_literal(self, literal):  # noqa
     """
     Infer the SDK type that describes a concrete literal value.

     Fields are checked in a fixed order and the first populated one decides
     the result: collection, map, then the scalar payload (blob, none_type,
     schema, error, generic, binary), then the primitive fields (boolean,
     datetime, duration, float_value, integer, string_value).

     :param flytekit.models.literals.Literal literal:
     :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType
     :raises NotImplementedError: for map literals, error scalars, and binary
         scalars whose tag does not start with the protobuf tag prefix.
     :raises FlyteSystemAssertion: when none of the checked fields is populated.
     """
     if literal.collection is not None:
         members = literal.collection.literals
         if len(members) == 0:
             # Nothing to inspect, so an empty collection is typed List<Void>.
             return _container_types.List(_base_sdk_types.Void)
         # The element type is inferred from the first member only.
         return _container_types.List(
             _helpers.infer_sdk_type_from_literal(members[0]))

     if literal.map is not None:
         raise NotImplementedError("TODO: Implement map")

     scalar = literal.scalar
     if scalar.blob is not None:
         return self._get_blob_impl_from_type(scalar.blob.metadata.type)
     if scalar.none_type is not None:
         return _base_sdk_types.Void
     if scalar.schema is not None:
         return _schema.schema_instantiator_from_proto(scalar.schema.type)
     if scalar.error is not None:
         raise NotImplementedError("TODO: Implement error from literal map")
     if scalar.generic is not None:
         return _primitive_types.Generic
     if scalar.binary is not None:
         tag = scalar.binary.tag
         prefix = _proto.Protobuf.TAG_PREFIX
         if not tag.startswith(prefix):
             raise NotImplementedError(
                 "TODO: Binary is only supported for protobuf types currently"
             )
         # Only the part of the tag after the protobuf prefix is passed on.
         return _proto_sdk_type_from_tag(tag[len(prefix):])

     # Ordered (field name, SDK type) pairs; the first non-None field wins.
     primitive = scalar.primitive
     primitive_candidates = (
         ("boolean", _primitive_types.Boolean),
         ("datetime", _primitive_types.Datetime),
         ("duration", _primitive_types.Timedelta),
         ("float_value", _primitive_types.Float),
         ("integer", _primitive_types.Integer),
         ("string_value", _primitive_types.String),
     )
     for field_name, candidate_type in primitive_candidates:
         if getattr(primitive, field_name) is not None:
             return candidate_type

     raise _system_exceptions.FlyteSystemAssertion(
         "Received unknown literal: {}".format(literal))
Beispiel #3
0
def test_string_list():
    """A JSON list of strings parses correctly, and equal parses compare equal."""
    json_payload = '["fdsa", "fff3", "fdsfhuie", "frfJliEILles", ""]'
    expected_values = ["fdsa", "fff3", "fdsfhuie", "frfJliEILles", ""]

    first_parse = containers.List(primitives.String).from_string(json_payload)
    assert len(first_parse.collection.literals) == len(expected_values)
    assert first_parse.to_python_std() == expected_values

    # Two independently constructed List<String> types must produce equal values.
    second_parse = containers.List(primitives.String).from_string(json_payload)
    assert first_parse == second_parse
Beispiel #4
0
 def get_sdk_type_from_literal_type(self, literal_type):
     """
     Resolve the SDK type corresponding to a Flyte literal type model.

     Collections recurse into their element type. Schema and blob types are
     delegated to dedicated helpers. A simple type resolves as follows:
     - BINARY or STRUCT with a protobuf field key in its metadata resolves
       through the protobuf tag helpers.
     - Otherwise it resolves through ``self._SIMPLE_TYPE_LOOKUP_TABLE``.

     :param flytekit.models.types.LiteralType literal_type:
     :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType
     :raises NotImplementedError: for map types, and for simple types that have
         no entry in ``self._SIMPLE_TYPE_LOOKUP_TABLE``.
     :raises FlyteSystemAssertion: when none of the checked fields is set.
     """
     if literal_type.collection_type is not None:
         return _container_types.List(
             _helpers.get_sdk_type_from_literal_type(
                 literal_type.collection_type))
     elif literal_type.map_value_type is not None:
         raise NotImplementedError("TODO: Implement map")
     elif literal_type.schema is not None:
         return _schema.schema_instantiator_from_proto(literal_type.schema)
     elif literal_type.blob is not None:
         return self._get_blob_impl_from_type(literal_type.blob)
     elif literal_type.simple is not None:
         # ``metadata`` may be empty or unset. Check truthiness before the
         # membership test: ``key in None`` would raise TypeError. The same
         # guard applies to both BINARY and STRUCT.
         metadata = literal_type.metadata
         has_proto_tag = bool(metadata) and _proto.Protobuf.PB_FIELD_KEY in metadata
         if has_proto_tag:
             if literal_type.simple == _literal_type_models.SimpleType.BINARY:
                 return _proto_sdk_type_from_tag(
                     metadata[_proto.Protobuf.PB_FIELD_KEY])
             if literal_type.simple == _literal_type_models.SimpleType.STRUCT:
                 return _generic_proto_sdk_type_from_tag(
                     metadata[_proto.Protobuf.PB_FIELD_KEY])
         sdk_type = self._SIMPLE_TYPE_LOOKUP_TABLE.get(literal_type.simple)
         if sdk_type is None:
             raise NotImplementedError(
                 "We haven't implemented this type yet:  Simple type={}".
                 format(literal_type.simple))
         return sdk_type
     else:
         raise _system_exceptions.FlyteSystemAssertion(
             "An unrecognized literal type was received: {}".format(
                 literal_type))
Beispiel #5
0
 def python_std_to_sdk_type(self, t):
     """
     Normalize a user-supplied type specification into an SDK type.

     A single-element list such as ``[Types.Integer]`` becomes a ``List`` of
     the converted element type. An SDK type instance is returned unchanged.
     Dict (map) specifications are rejected because map types are not
     implemented yet.

     :param T t: User input.  Should be of the form: Types.Integer, [Types.Integer], {Types.String:
     Types.Integer}, etc.
     :rtype: flytekit.common.types.base_sdk_types.FlyteSdkType
     :raises FlyteAssertion: for a list that does not hold exactly one element,
         or for any dict.
     :raises FlyteTypeException: for anything that is not a list, dict or
         ``FlyteSdkType``.
     """
     if isinstance(t, list):
         if len(t) != 1:
             raise _user_exceptions.FlyteAssertion(
                 "When specifying a list type, there must be exactly one element in "
                 "the list describing the contained type.")
         (element_spec,) = t
         return _container_types.List(_helpers.python_std_to_sdk_type(element_spec))

     if isinstance(t, dict):
         raise _user_exceptions.FlyteAssertion(
             "Map types are not yet implemented.")

     if not isinstance(t, _base_sdk_types.FlyteSdkType):
         raise _user_exceptions.FlyteTypeException(
             type(t),
             _base_sdk_types.FlyteSdkType,
             additional_msg=(
                 "Should be of form similar to: Types.Integer, [Types.Integer], {Types.String: "
                 "Types.Integer}"
             ),
             received_value=t,
         )

     return t
Beispiel #6
0
def test_reprs():
    """Check the short and verbose string forms of flat and nested integer lists."""
    flat_type = containers.List(primitives.Integer)
    flat_value = flat_type.from_python_std(list(_range(3)))

    assert flat_value.short_string() == (
        "List<Integer>(len=3, [Integer(0), Integer(1), Integer(2)])"
    )
    assert flat_value.verbose_string() == (
        "List<Integer>(\n"
        "\tlen=3,\n"
        "\t[\n"
        "\t\tInteger(0),\n"
        "\t\tInteger(1),\n"
        "\t\tInteger(2)\n"
        "\t]\n"
        ")"
    )

    nested_type = containers.List(containers.List(primitives.Integer))
    nested_value = nested_type.from_python_std([list(_range(3)), list(_range(3))])

    assert nested_value.short_string() == (
        "List<List<Integer>>(len=2, [List<Integer>(len=3, [Integer(0), Integer(1), Integer(2)]), "
        "List<Integer>(len=3, [Integer(0), Integer(1), Integer(2)])])"
    )
    # Each nesting level adds one tab of indentation in the verbose form.
    assert nested_value.verbose_string() == (
        "List<List<Integer>>(\n"
        "\tlen=2,\n"
        "\t[\n"
        "\t\tList<Integer>(\n"
        "\t\t\tlen=3,\n"
        "\t\t\t[\n"
        "\t\t\t\tInteger(0),\n"
        "\t\t\t\tInteger(1),\n"
        "\t\t\t\tInteger(2)\n"
        "\t\t\t]\n"
        "\t\t),\n"
        "\t\tList<Integer>(\n"
        "\t\t\tlen=3,\n"
        "\t\t\t[\n"
        "\t\t\t\tInteger(0),\n"
        "\t\t\t\tInteger(1),\n"
        "\t\t\t\tInteger(2)\n"
        "\t\t\t]\n"
        "\t\t)\n"
        "\t]\n"
        ")"
    )
Beispiel #7
0
def test_empty_parsing():
    """An empty JSON list parses to an empty value; nested containers are rejected."""
    string_list_type = containers.List(primitives.String)
    assert len(string_list_type.from_string("[]")) == 0

    # The String primitive cannot be built from a JSON list or object element.
    for malformed in ('["fdjs", []]', '["fdjs", {}]'):
        with pytest.raises(_user_exceptions.FlyteTypeException):
            string_list_type.from_string(malformed)
Beispiel #8
0
def test_nested_list():
    """List<List<Integer>> exposes the right literal type and round-trips values."""
    nested_type = containers.List(containers.List(primitives.Integer))

    outer_literal_type = nested_type.to_flyte_literal_type()
    inner_literal_type = outer_literal_type.collection_type
    # Both container levels must carry only a collection_type.
    for container_literal_type in (outer_literal_type, inner_literal_type):
        assert container_literal_type.simple is None
        assert container_literal_type.map_value_type is None
        assert container_literal_type.schema is None
    assert inner_literal_type.collection_type.simple == literal_types.SimpleType.INTEGER

    rows = [[1, 2, 3], [4, 5, 6], []]
    nested_value = nested_type.from_python_std(rows)
    assert nested_value.to_python_std() == rows
    assert nested_type.from_flyte_idl(nested_value.to_flyte_idl()) == nested_value

    outer_literals = nested_value.collection.literals
    for row_index, row in enumerate(rows[:2]):
        inner_literals = outer_literals[row_index].collection.literals
        for column_index, number in enumerate(row):
            assert inner_literals[column_index].scalar.primitive.integer == number

    # The trailing empty row must stay an empty collection.
    assert len(outer_literals[2].collection.literals) == 0

    parsed = nested_type.from_string("[[1, 2, 3], [4, 5, 6]]")
    assert len(parsed) == 2
    assert len(parsed.collection.literals[0]) == 3
Beispiel #9
0
def test_list():
    """List<Integer> has the right literal type, round-trips, and rejects bad input."""
    int_list_type = containers.List(primitives.Integer)

    literal_type = int_list_type.to_flyte_literal_type()
    assert literal_type.simple is None
    assert literal_type.map_value_type is None
    assert literal_type.schema is None
    assert literal_type.collection_type.simple == literal_types.SimpleType.INTEGER

    numbers = [1, 2, 3, 4]
    int_list = int_list_type.from_python_std(numbers)
    assert int_list.to_python_std() == numbers
    assert int_list_type.from_flyte_idl(int_list.to_flyte_idl()) == int_list

    for index, number in enumerate(numbers):
        assert int_list.collection.literals[index].scalar.primitive.integer == number

    # String parsing tolerates missing whitespace after commas.
    assert int_list_type.from_string("[1, 2, 3,4]") == int_list

    # Python values that are not flat lists of integers.
    rejected_python_values = (
        ["a", "b", "c", "d"],
        [1, 2, 3, "abc"],
        1,
        [[1]],
    )
    for rejected in rejected_python_values:
        with pytest.raises(_user_exceptions.FlyteTypeException):
            int_list_type.from_python_std(rejected)

    # Strings that are mistyped, quoted, or not valid JSON lists.
    rejected_strings = (
        '["fdsa"]',
        "[1, 2, 3, []]",
        "'[\"not list json\"]'",
        '["unclosed","list"',
    )
    for rejected in rejected_strings:
        with pytest.raises(_user_exceptions.FlyteTypeException):
            int_list_type.from_string(rejected)
Beispiel #10
0
def test_model_promotion():
    """A raw LiteralCollection model promotes to a typed List<Integer> SDK value."""
    int_list_type = containers.List(primitives.Integer)
    raw_model = literals.Literal(collection=literals.LiteralCollection(literals=[
        literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=value)))
        for value in range(3)
    ]))

    promoted = int_list_type.promote_from_model(raw_model)
    assert len(promoted.collection.literals) == 3
    # Promotion turns each element into the SDK primitive class.
    assert isinstance(promoted.collection.literals[0], primitives.Integer)
    assert promoted == int_list_type.from_python_std([0, 1, 2])
    assert promoted == int_list_type([primitives.Integer(value) for value in range(3)])
Beispiel #11
0
def test_input():
    """Check ``Input`` bookkeeping for scalar and list types, before and after renaming.

    Covers the default (empty) name, ``sdk_default`` vs. the wrapped
    ``default`` value, the ``sdk_required``/``required`` flags, help text
    mirrored into the variable description, and rejection of
    ``required=True`` together with a default.
    """
    # Scalar input with help text and an explicit ``None`` default.
    i = Input(primitives.Integer, help="blah", default=None)
    assert i.name == ''
    assert i.sdk_default is None
    assert i.default == base_sdk_types.Void()
    assert i.sdk_required is False
    assert i.required is None
    assert i.help == "blah"
    assert i.var.description == "blah"
    assert i.sdk_type == primitives.Integer

    # Renaming must change only the name.
    i = i.rename_and_return_reference('new_name')
    assert i.name == 'new_name'
    assert i.sdk_default is None
    assert i.default == base_sdk_types.Void()
    assert i.sdk_required is False
    assert i.required is None
    assert i.help == "blah"
    assert i.var.description == "blah"
    assert i.sdk_type == primitives.Integer

    # Scalar input with a concrete default. Compare with ``==``: ``is``
    # against an int literal depends on interpreter small-int caching and
    # emits a SyntaxWarning on Python 3.8+.
    i = Input(primitives.Integer, default=1)
    assert i.name == ''
    assert i.sdk_default == 1
    assert i.default == primitives.Integer(1)
    assert i.sdk_required is False
    assert i.required is None
    assert i.help is None
    assert i.var.description == ""
    assert i.sdk_type == primitives.Integer

    i = i.rename_and_return_reference('new_name')
    assert i.name == 'new_name'
    assert i.sdk_default == 1
    assert i.default == primitives.Integer(1)
    assert i.sdk_required is False
    assert i.required is None
    assert i.help is None
    assert i.var.description == ""
    assert i.sdk_type == primitives.Integer

    # A required input cannot also carry a default.
    with pytest.raises(_user_exceptions.FlyteAssertion):
        Input(primitives.Integer, required=True, default=1)

    # List input: the ``[Type]`` shorthand yields a List container type.
    i = Input([primitives.Integer], default=[1, 2])
    assert i.name == ''
    assert i.sdk_default == [1, 2]
    assert i.default == containers.List(primitives.Integer)([primitives.Integer(1), primitives.Integer(2)])
    assert i.sdk_required is False
    assert i.required is None
    assert i.help is None
    assert i.var.description == ""
    assert i.sdk_type == containers.List(primitives.Integer)

    i = i.rename_and_return_reference('new_name')
    assert i.name == 'new_name'
    assert i.sdk_default == [1, 2]
    assert i.default == containers.List(primitives.Integer)([primitives.Integer(1), primitives.Integer(2)])
    assert i.sdk_required is False
    assert i.required is None
    assert i.help is None
    assert i.var.description == ""
    assert i.sdk_type == containers.List(primitives.Integer)
Beispiel #12
0
def test_workflow_node():
    """Build an ``SdkRunnableWorkflow`` from task nodes and use it as a sub-workflow node.

    NOTE(review): another function named ``test_workflow_node`` is defined
    later in this file. If both end up in one module, the later definition
    replaces this one and pytest only collects the later one.
    """
    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task()
    def my_task(wf_params, a, b):
        b.set(a + 1)

    my_task._id = _identifier.Identifier(
        _identifier.ResourceType.TASK, "project", "domain", "my_task", "version")

    @inputs(a=[primitives.Integer])
    @outputs(b=[primitives.Integer])
    @python_task
    def my_list_task(wf_params, a, b):
        b.set([item + 1 for item in a])

    my_list_task._id = _identifier.Identifier(
        _identifier.ResourceType.TASK, "project", "domain", "my_list_task", "version")

    required_input = promise.Input("required", primitives.Integer)
    optional_input = promise.Input(
        "not_required", primitives.Integer, default=5, help="Not required.")
    workflow_inputs = [required_input, optional_input]

    n1 = my_task(a=required_input).assign_id_and_return("n1")
    n2 = my_task(a=optional_input).assign_id_and_return("n2")
    n3 = my_task(a=100).assign_id_and_return("n3")
    n4 = my_task(a=n1.outputs.b).assign_id_and_return("n4")
    n5 = my_list_task(
        a=[required_input, optional_input, n3.outputs.b, 100]).assign_id_and_return("n5")
    n6 = my_list_task(a=n5.outputs.b)

    workflow_outputs = [
        _local_workflow.Output(
            "nested_out",
            [n5.outputs.b, n6.outputs.b, [n1.outputs.b, n2.outputs.b]],
            sdk_type=[[primitives.Integer]],
        ),
        _local_workflow.Output("scalar_out", n1.outputs.b, sdk_type=primitives.Integer),
    ]

    wf = _local_workflow.SdkRunnableWorkflow.construct_from_class_definition(
        inputs=workflow_inputs,
        outputs=workflow_outputs,
        nodes=[n1, n2, n3, n4, n5, n6],
    )

    # (expected exception, positional args, keyword args) for invalid calls.
    invalid_calls = [
        (_user_exceptions.FlyteAssertion, (), {}),  # required input missing
        (_user_exceptions.FlyteAssertion, (1, 2), {}),  # positional args rejected
        (_user_exceptions.FlyteTypeException, (), {"required": "abc", "not_required": 1}),  # wrong type
        (_user_exceptions.FlyteAssertion, (), {"required": 1, "bad_arg": 1}),  # unknown input name
    ]
    for expected_error, call_args, call_kwargs in invalid_calls:
        with _pytest.raises(expected_error):
            wf(*call_args, **call_kwargs)

    # The optional input falls back to its default (5) unless overridden (50).
    # "not_required" is reported before "required" in node.inputs.
    binding_cases = [
        ({"required": 10}, 5),
        ({"required": 10, "not_required": 50}, 50),
    ]
    for call_kwargs, expected_optional in binding_cases:
        node = wf(**call_kwargs)
        assert node.inputs[0].var == "not_required"
        assert node.inputs[0].binding.scalar.primitive.integer == expected_optional
        assert node.inputs[1].var == "required"
        assert node.inputs[1].binding.scalar.primitive.integer == 10

    # The node references the workflow by its current id.
    wf.id = "fake"
    assert node.workflow_node.sub_workflow_ref == "fake"
    wf.id = None

    # The '*' is expected to be stripped from the assigned node id.
    node.assign_id_and_return("node-id*")

    scalar_promise = node.outputs["scalar_out"]
    assert scalar_promise.sdk_type.to_flyte_literal_type() == primitives.Integer.to_flyte_literal_type()
    assert scalar_promise.var == "scalar_out"
    assert scalar_promise.node_id == "node-id"

    nested_promise = node.outputs["nested_out"]
    assert nested_promise.sdk_type.to_flyte_literal_type() == containers.List(
        containers.List(primitives.Integer)).to_flyte_literal_type()
    assert nested_promise.var == "nested_out"
    assert nested_promise.node_id == "node-id"
Beispiel #13
0
def test_workflow_node():
    """Wire task nodes into an ``SdkWorkflow`` and check calling it as a node factory.

    NOTE(review): a function with this same name is defined earlier in this
    file. If both are in one module, this definition replaces the earlier one
    at import time.
    """
    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task()
    def my_task(wf_params, a, b):
        b.set(a + 1)

    my_task._id = _identifier.Identifier(
        _identifier.ResourceType.TASK, 'project', 'domain', 'my_task', 'version')

    @inputs(a=[primitives.Integer])
    @outputs(b=[primitives.Integer])
    @python_task
    def my_list_task(wf_params, a, b):
        b.set([x + 1 for x in a])

    my_list_task._id = _identifier.Identifier(
        _identifier.ResourceType.TASK, 'project', 'domain', 'my_list_task', 'version')

    req = promise.Input('required', primitives.Integer)
    opt = promise.Input('not_required', primitives.Integer, default=5, help='Not required.')

    first = my_task(a=req).assign_id_and_return('n1')
    second = my_task(a=opt).assign_id_and_return('n2')
    third = my_task(a=100).assign_id_and_return('n3')
    fourth = my_task(a=first.outputs.b).assign_id_and_return('n4')
    fifth = my_list_task(a=[req, opt, third.outputs.b, 100]).assign_id_and_return('n5')
    sixth = my_list_task(a=fifth.outputs.b)

    nested_output = workflow.Output(
        'nested_out',
        [fifth.outputs.b, sixth.outputs.b, [first.outputs.b, second.outputs.b]],
        sdk_type=[[primitives.Integer]])
    scalar_output = workflow.Output('scalar_out', first.outputs.b, sdk_type=primitives.Integer)

    wf = workflow.SdkWorkflow(
        inputs=[req, opt],
        outputs=[nested_output, scalar_output],
        nodes=[first, second, third, fourth, fifth, sixth])

    def _expect_bindings(node, not_required_value):
        # "not_required" is reported first, "required" (always 10 here) second.
        assert node.inputs[0].var == 'not_required'
        assert node.inputs[0].binding.scalar.primitive.integer == not_required_value
        assert node.inputs[1].var == 'required'
        assert node.inputs[1].binding.scalar.primitive.integer == 10

    with _pytest.raises(_user_exceptions.FlyteAssertion):
        wf()  # required input not supplied
    with _pytest.raises(_user_exceptions.FlyteAssertion):
        wf(1, 2)  # positional arguments are rejected
    with _pytest.raises(_user_exceptions.FlyteTypeException):
        wf(required='abc', not_required=1)  # type mismatch
    with _pytest.raises(_user_exceptions.FlyteAssertion):
        wf(required=1, bad_arg=1)  # unknown input name

    # Default value is used when the optional input is omitted...
    _expect_bindings(wf(required=10), 5)
    # ...and replaced when it is given explicitly.
    overridden = wf(required=10, not_required=50)
    _expect_bindings(overridden, 50)

    # The node references the workflow by its current id.
    wf._id = 'fake'
    assert overridden.workflow_node.sub_workflow_ref == 'fake'
    wf._id = None

    # The '*' is expected to be stripped from the assigned node id.
    overridden.assign_id_and_return('node-id*')

    scalar_promise = overridden.outputs['scalar_out']
    assert scalar_promise.sdk_type.to_flyte_literal_type() == primitives.Integer.to_flyte_literal_type()
    assert scalar_promise.var == 'scalar_out'
    assert scalar_promise.node_id == 'node-id'

    nested_promise = overridden.outputs['nested_out']
    actual_nested_type = nested_promise.sdk_type.to_flyte_literal_type()
    assert actual_nested_type == containers.List(
        containers.List(primitives.Integer)).to_flyte_literal_type()
    assert nested_promise.var == 'nested_out'
    assert nested_promise.node_id == 'node-id'