Beispiel #1
0
def test_get_sdk_value_from_literal():
    """
    Exercise ``_type_helpers.get_sdk_value_from_literal`` on a void scalar (with and without an
    explicit ``sdk_type``), on an integer primitive, and on a collection mixing both.
    """

    def make_void():
        # Fresh literal whose scalar carries the none type.
        return _literals.Literal(scalar=_literals.Scalar(none_type=_literals.Void()))

    def make_integer(number):
        # Fresh literal whose scalar carries an integer primitive.
        return _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=number)))

    # Void literal without any type hint.
    sdk_value = _type_helpers.get_sdk_value_from_literal(make_void())
    assert sdk_value.to_python_std() is None

    # Void literal, this time with an Integer type hint.
    sdk_value = _type_helpers.get_sdk_value_from_literal(
        make_void(),
        sdk_type=_sdk_types.Types.Integer,
    )
    assert sdk_value.to_python_std() is None

    # Integer literal with the matching type hint.
    sdk_value = _type_helpers.get_sdk_value_from_literal(
        make_integer(1),
        sdk_type=_sdk_types.Types.Integer,
    )
    assert sdk_value.to_python_std() == 1

    # Collection holding an integer followed by a void.
    members = [make_integer(1), make_void()]
    sdk_value = _type_helpers.get_sdk_value_from_literal(
        _literals.Literal(collection=_literals.LiteralCollection(members)))
    assert sdk_value.to_python_std() == [1, None]
Beispiel #2
0
    def fulfil_bindings(binding_data, fulfilled_promises):
        """
        Resolve ``binding_data`` into a literal, replacing promises with values taken from
        ``fulfilled_promises``. Collections and maps are resolved recursively, element by element.

        The fields of ``binding_data`` are tested by truthiness in this order: scalar, collection,
        promise, map. If none of them is set the function falls through and returns None.

        :param _interface.BindingData binding_data:
        :param dict[Text,T] fulfilled_promises: looked up first by promise node id, then by output name
        :rtype: a ``_literals.Literal`` for scalar, collection and map bindings; for a promise binding,
            whatever the promise's ``sdk_type.from_python_std`` returns
        :raises _system_exception.FlyteSystemAssertion: if the promised node, or the promised output of
            that node, is missing from ``fulfilled_promises``
        """
        if binding_data.scalar:
            return _literals.Literal(scalar=binding_data.scalar)

        if binding_data.collection:
            resolved_items = []
            for child_binding in binding_data.collection.bindings:
                resolved_items.append(DynamicTask.fulfil_bindings(child_binding, fulfilled_promises))
            return _literals.Literal(collection=_literals.LiteralCollection(resolved_items))

        if binding_data.promise:
            promise = binding_data.promise
            if promise.node_id not in fulfilled_promises:
                raise _system_exception.FlyteSystemAssertion(
                    "Expecting output of node [{}] but that hasn't been produced.".format(promise.node_id))

            produced_outputs = fulfilled_promises[promise.node_id]
            if promise.var not in produced_outputs:
                raise _system_exception.FlyteSystemAssertion(
                    "Expecting output [{}] of node [{}] but that hasn't been produced.".format(
                        promise.var,
                        promise.node_id))

            # Convert the python std value produced upstream using the promise's own SDK type.
            return promise.sdk_type.from_python_std(produced_outputs[promise.var])

        if binding_data.map:
            resolved_entries = {}
            for key, child_binding in _six.iteritems(binding_data.map.bindings):
                resolved_entries[key] = DynamicTask.fulfil_bindings(child_binding, fulfilled_promises)
            return _literals.Literal(map=_literals.LiteralMap(resolved_entries))
Beispiel #3
0
def get_promise(binding_data: _literal_models.BindingData, outputs_cache: Dict[Node, Dict[str, Promise]]) -> Promise:
    """
    Turn a binding into a Promise object.

    A promise binding is resolved through ``outputs_cache`` (keyed by node, then by output name).
    Scalar, collection and map bindings are wrapped into a new Promise named "placeholder" whose
    value is the corresponding Literal; collections and maps are resolved recursively.
    Please see get_promise_map for the rest of the details.

    :raises FlyteValidationException: if the promise is not a NodeOutput, or if no binding field is set.
    """
    if binding_data.promise is not None:
        if isinstance(binding_data.promise, NodeOutput):
            # binding_data.promise.var is the name of the upstream node's output we want
            upstream = binding_data.promise
            return outputs_cache[upstream.node][upstream.var]
        raise FlyteValidationException(
            f"Binding data Promises have to be of the NodeOutput type {type(binding_data.promise)} found"
        )

    if binding_data.scalar is not None:
        scalar_literal = _literal_models.Literal(scalar=binding_data.scalar)
        return Promise(var="placeholder", val=scalar_literal)

    if binding_data.collection is not None:
        items = [get_promise(child, outputs_cache).val for child in binding_data.collection.bindings]
        collection_literal = _literal_models.Literal(
            collection=_literal_models.LiteralCollection(literals=items)
        )
        return Promise(var="placeholder", val=collection_literal)

    if binding_data.map is not None:
        entries = {
            key: get_promise(child, outputs_cache).val for key, child in binding_data.map.bindings.items()
        }
        map_literal = _literal_models.Literal(map=_literal_models.LiteralMap(literals=entries))
        return Promise(var="placeholder", val=map_literal)

    raise FlyteValidationException("Binding type unrecognized.")
Beispiel #4
0
 def extract_value(
     ctx: FlyteContext, input_val: Any, val_type: type,
     flyte_literal_type: _type_models.LiteralType
 ) -> _literal_models.Literal:
     """
     Convert ``input_val`` into a Literal matching ``flyte_literal_type``.

     - list: each element is converted recursively and wrapped in a LiteralCollection.
     - dict: delegated to ``TypeEngine.to_literal`` when the literal type is a STRUCT, otherwise each
       value is converted recursively and wrapped in a LiteralMap.
     - Promise: its ``val`` is returned as is.
     - anything else (except VoidPromise and tuple, which are rejected): ``TypeEngine.to_literal``.

     :raises Exception: if a list is given for a non-collection type, or a dict is given for a type
         that is neither a map nor a STRUCT.
     :raises ValueError: re-raised from ``ListTransformer.get_sub_type`` when the list is empty, so no
         element type can be inferred.
     :raises AssertionError: if ``input_val`` is a VoidPromise or a tuple.
     """
     if isinstance(input_val, list):
         if flyte_literal_type.collection_type is None:
             raise Exception(
                 f"Not a collection type {flyte_literal_type} but got a list {input_val}"
             )
         try:
             sub_type = ListTransformer.get_sub_type(val_type)
         except ValueError:
             # Fall back to the runtime type of the first element; impossible for an empty list.
             if len(input_val) == 0:
                 raise
             sub_type = type(input_val[0])
         literals = [
             extract_value(ctx, v, sub_type,
                           flyte_literal_type.collection_type)
             for v in input_val
         ]
         return _literal_models.Literal(
             collection=_literal_models.LiteralCollection(
                 literals=literals))
     elif isinstance(input_val, dict):
         if (flyte_literal_type.map_value_type is None
                 and flyte_literal_type.simple !=
                 _type_models.SimpleType.STRUCT):
             raise Exception(
                 f"Not a map type {flyte_literal_type} but got a map {input_val}"
             )
         # Only the value type is needed below; the key type is deliberately discarded.
         _, sub_type = DictTransformer.get_dict_types(val_type)
         if flyte_literal_type.simple == _type_models.SimpleType.STRUCT:
             return TypeEngine.to_literal(ctx, input_val, type(input_val),
                                          flyte_literal_type)
         else:
             literals = {
                 k: extract_value(ctx, v, sub_type,
                                  flyte_literal_type.map_value_type)
                 for k, v in input_val.items()
             }
             return _literal_models.Literal(map=_literal_models.LiteralMap(
                 literals=literals))
     elif isinstance(input_val, Promise):
         # A Promise already carries its literal value.
         return input_val.val
     elif isinstance(input_val, VoidPromise):
         raise AssertionError(
             f"Outputs of a non-output producing task {input_val.task_name} cannot be passed to another task."
         )
     elif isinstance(input_val, tuple):
         # Adjacent literals are concatenated verbatim, so each continuation starts with a space.
         raise AssertionError(
             "Tuples are not a supported type for individual values in Flyte - got a tuple -"
             f" {input_val}. If using named tuple in an inner task, please, de-reference the"
             " actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then"
             " return v.x, instead of v, even if this has a single element")
     else:
         # Plain native python values.
         return TypeEngine.to_literal(ctx, input_val, val_type,
                                      flyte_literal_type)
def test_arrayjob_entrypoint_in_proc():
    """
    Run ``execute_task`` for ``_task_defs.add_one`` in-process while faking an array job through
    environment variables, then check that the written output map is ``{'b': 10}``.

    The two environment variables set by the test are always restored (or removed when they were
    not set before), even if the task execution or an assertion fails.
    """
    with _TemporaryConfiguration(os.path.join(os.path.dirname(__file__),
                                              'fake.config'),
                                 internal_overrides={
                                     'project': 'test',
                                     'domain': 'development'
                                 }):
        with _utils.AutoDeletingTempDir("dir") as tmp_dir:
            literal_map = _type_helpers.pack_python_std_map_to_literal_map(
                {'a': 9},
                _type_map_from_variable_map(
                    _task_defs.add_one.interface.inputs))

            # Presumably "1" matches the mapped index stored in indexlookup.pb below.
            input_dir = os.path.join(tmp_dir.name, "1")
            os.mkdir(
                input_dir)  # auto cleanup will take this subdir into account

            input_file = os.path.join(input_dir, "inputs.pb")
            _utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file)

            # construct indexlookup.pb which has array: [1]
            mapped_index = _literals.Literal(
                _literals.Scalar(primitive=_literals.Primitive(integer=1)))
            index_lookup_collection = _literals.LiteralCollection(
                [mapped_index])
            index_lookup_file = os.path.join(tmp_dir.name, "indexlookup.pb")
            _utils.write_proto_to_file(index_lookup_collection.to_flyte_idl(),
                                       index_lookup_file)

            # fake arrayjob task by setting environment variables
            env_overrides = {
                'BATCH_JOB_ARRAY_INDEX_VAR_NAME': 'AWS_BATCH_JOB_ARRAY_INDEX',
                'AWS_BATCH_JOB_ARRAY_INDEX': '0',
            }
            saved_env = {name: os.environ.get(name) for name in env_overrides}
            os.environ.update(env_overrides)
            try:
                execute_task(_task_defs.add_one.task_module,
                             _task_defs.add_one.task_function_name,
                             tmp_dir.name, tmp_dir.name, False)

                raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std(
                    _literal_models.LiteralMap.from_flyte_idl(
                        _utils.load_proto_from_file(
                            _literals_pb2.LiteralMap,
                            os.path.join(input_dir,
                                         _constants.OUTPUT_FILE_NAME))),
                    _type_map_from_variable_map(
                        _task_defs.add_one.interface.outputs))
                assert raw_map['b'] == 10
                assert len(raw_map) == 1
            finally:
                # reset the env vars; a variable that did not exist before is removed again
                for name, previous in saved_env.items():
                    if previous is None:
                        os.environ.pop(name, None)
                    else:
                        os.environ[name] = previous
Beispiel #6
0
def test_literal_collection(literal_value_pair):
    """
    Build a LiteralCollection from three copies of the given literal and verify its contents,
    both on the freshly built object and on the object obtained from a Flyte IDL round trip.

    :param literal_value_pair: a (literal, python value) pair; only the literal is used here.
    """
    lit, _ = literal_value_pair
    obj = literals.LiteralCollection([lit, lit, lit])
    assert all(ll == lit for ll in obj.literals)
    assert len(obj.literals) == 3

    obj2 = literals.LiteralCollection.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
    # Inspect the round-tripped object itself rather than re-checking the original.
    assert all(ll == lit for ll in obj2.literals)
    assert len(obj2.literals) == 3
Beispiel #7
0
 def extract_value(
     ctx: FlyteContext, input_val: Any, val_type: type,
     flyte_literal_type: _type_models.LiteralType
 ) -> _literal_models.Literal:
     """
     Produce the Literal for ``input_val`` according to ``flyte_literal_type``.

     Lists become a LiteralCollection and dicts become a LiteralMap (or are handed to
     ``TypeEngine.to_literal`` for STRUCT types), converting their members recursively. A Promise
     yields its ``val``; a VoidPromise is rejected; any other value goes through
     ``TypeEngine.to_literal``.

     :raises Exception: when a list or dict is supplied but the literal type is not a matching
         collection / map / STRUCT type.
     :raises ValueError: propagated from ``ListTransformer.get_sub_type`` for an empty list.
     :raises AssertionError: when ``input_val`` is a VoidPromise.
     """
     if isinstance(input_val, list):
         element_literal_type = flyte_literal_type.collection_type
         if element_literal_type is None:
             raise Exception(
                 f"Not a collection type {flyte_literal_type} but got a list {input_val}"
             )
         try:
             element_type = ListTransformer.get_sub_type(val_type)
         except ValueError:
             # Without a declared element type, infer it from the first item when there is one.
             if len(input_val) == 0:
                 raise
             element_type = type(input_val[0])
         converted_items = []
         for item in input_val:
             converted_items.append(
                 extract_value(ctx, item, element_type,
                               flyte_literal_type.collection_type))
         collection = _literal_models.LiteralCollection(literals=converted_items)
         return _literal_models.Literal(collection=collection)

     if isinstance(input_val, dict):
         is_struct = flyte_literal_type.simple == _type_models.SimpleType.STRUCT
         if flyte_literal_type.map_value_type is None and not is_struct:
             raise Exception(
                 f"Not a map type {flyte_literal_type} but got a map {input_val}"
             )
         _key_type, value_type = DictTransformer.get_dict_types(val_type)
         if is_struct:
             return TypeEngine.to_literal(ctx, input_val, type(input_val),
                                          flyte_literal_type)
         converted_entries = {}
         for key, value in input_val.items():
             converted_entries[key] = extract_value(
                 ctx, value, value_type, flyte_literal_type.map_value_type)
         literal_map = _literal_models.LiteralMap(literals=converted_entries)
         return _literal_models.Literal(map=literal_map)

     if isinstance(input_val, Promise):
         # The promise already holds the literal to pass along.
         return input_val.val

     if isinstance(input_val, VoidPromise):
         raise AssertionError(
             f"Outputs of a non-output producing task {input_val.task_name} cannot be passed to another task."
         )

     # Everything else is treated as a native python value.
     return TypeEngine.to_literal(ctx, input_val, val_type,
                                  flyte_literal_type)
Beispiel #8
0
def test_model_promotion():
    """
    Promote a literal model holding the integers 0, 1, 2 through ``containers.List(primitives.Integer)``
    and check the result against the same list built from python values and from SDK Integer objects.
    """
    list_type = containers.List(primitives.Integer)

    # Model-level literals for 0, 1 and 2.
    integer_models = [
        literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=number)))
        for number in range(3)
    ]
    list_model = literals.Literal(
        collection=literals.LiteralCollection(literals=integer_models))

    list_obj = list_type.promote_from_model(list_model)

    promoted_items = list_obj.collection.literals
    assert len(promoted_items) == 3
    assert isinstance(list_obj.collection.literals[0], primitives.Integer)

    assert list_obj == list_type.from_python_std([0, 1, 2])
    assert list_obj == list_type([primitives.Integer(number) for number in range(3)])
Beispiel #9
0
                 types.SchemaType.SchemaColumn(
                     "b",
                     types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN),
                 types.SchemaType.SchemaColumn(
                     "c",
                     types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME),
                 types.SchemaType.SchemaColumn(
                     "d",
                     types.SchemaType.SchemaColumn.SchemaColumnType.DURATION),
                 types.SchemaType.SchemaColumn(
                     "e",
                     types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT),
                 types.SchemaType.SchemaColumn(
                     "f",
                     types.SchemaType.SchemaColumn.SchemaColumnType.STRING),
             ]))))
]

# Every scalar wrapped in a Literal, paired with its python value.
LIST_OF_SCALAR_LITERALS_AND_PYTHON_VALUE = [
    (literals.Literal(scalar=scalar), py_value)
    for scalar, py_value in LIST_OF_SCALARS_AND_PYTHON_VALUES
]

# A three-element LiteralCollection of each scalar literal, paired with the matching python list.
LIST_OF_LITERAL_COLLECTIONS_AND_PYTHON_VALUE = [
    (literals.LiteralCollection(literals=[literal] * 3), [py_value] * 3)
    for literal, py_value in LIST_OF_SCALAR_LITERALS_AND_PYTHON_VALUE
]

# Scalar literals first, then the collections.
LIST_OF_ALL_LITERALS_AND_VALUES = (
    LIST_OF_SCALAR_LITERALS_AND_PYTHON_VALUE
    + LIST_OF_LITERAL_COLLECTIONS_AND_PYTHON_VALUE
)
Beispiel #10
0
 def __init__(self, value):
     """
     Wrap the given values in a literal collection and pass it to the parent constructor.

     :param list[flytekit.common.types.base_sdk_types.FlyteSdkValue] value: List value to wrap
     """
     wrapped = _literals.LiteralCollection(literals=value)
     super(TypedListImpl, self).__init__(collection=wrapped)