Ejemplo n.º 1
0
    def fulfil_bindings(binding_data, fulfilled_promises):
        """
        Substitutes promise values in binding_data with model Literal values built from python std values in
        fulfilled_promises.

        Scalar bindings are wrapped in a Literal unchanged, collection and map bindings are resolved recursively,
        and promise bindings are looked up by node id and output variable name and converted with the promise's
        sdk_type.

        :param _interface.BindingData binding_data:
        :param dict[Text,dict[Text,T]] fulfilled_promises: node id -> (output variable name -> python std value)
        :rtype: _literals.Literal (promise bindings return whatever ``sdk_type.from_python_std`` produces)
        :raises _system_exception.FlyteSystemAssertion: if a promised node output has not been produced, or if
            binding_data sets none of scalar, collection, promise or map.
        """
        if binding_data.scalar:
            return _literals.Literal(scalar=binding_data.scalar)
        elif binding_data.collection:
            return _literals.Literal(collection=_literals.LiteralCollection(
                [DynamicTask.fulfil_bindings(sub_binding_data, fulfilled_promises) for sub_binding_data in
                 binding_data.collection.bindings]))
        elif binding_data.promise:
            node_id = binding_data.promise.node_id
            var = binding_data.promise.var
            if node_id not in fulfilled_promises:
                raise _system_exception.FlyteSystemAssertion(
                    "Expecting output of node [{}] but that hasn't been produced.".format(node_id))
            node_output = fulfilled_promises[node_id]
            if var not in node_output:
                raise _system_exception.FlyteSystemAssertion(
                    "Expecting output [{}] of node [{}] but that hasn't been produced.".format(var, node_id))

            return binding_data.promise.sdk_type.from_python_std(node_output[var])
        elif binding_data.map:
            return _literals.Literal(map=_literals.LiteralMap(
                {
                    k: DynamicTask.fulfil_bindings(sub_binding_data, fulfilled_promises) for k, sub_binding_data in
                    _six.iteritems(binding_data.map.bindings)
                }))

        # No recognized binding kind is set: fail loudly rather than implicitly returning None.
        raise _system_exception.FlyteSystemAssertion(
            "Binding data [{}] has none of scalar, collection, promise or map set.".format(binding_data))
Ejemplo n.º 2
0
def get_promise(binding_data: _literal_models.BindingData, outputs_cache: Dict[Node, Dict[str, Promise]]) -> Promise:
    """
    Resolve one binding into a Promise, using ``outputs_cache`` to find upstream node outputs.

    See get_promise_map for the rest of the details. A promise binding returns the cached upstream Promise
    itself. Scalar, collection and map bindings are rebuilt into a fresh Promise named "placeholder" whose value
    is the equivalent Literal; collections and maps are resolved element by element.

    :param binding_data: the binding to resolve.
    :param outputs_cache: upstream node -> (output name -> Promise).
    :raises FlyteValidationException: if a promise binding is not a NodeOutput, or no binding kind is set.
    """
    upstream_ref = binding_data.promise
    if upstream_ref is not None:
        if not isinstance(upstream_ref, NodeOutput):
            raise FlyteValidationException(
                f"Binding data Promises have to be of the NodeOutput type {type(upstream_ref)} found"
            )
        # upstream_ref.node is the producing node; upstream_ref.var names which of its outputs we want.
        node_outputs = outputs_cache[upstream_ref.node]
        return node_outputs[upstream_ref.var]

    if binding_data.scalar is not None:
        return Promise(var="placeholder", val=_literal_models.Literal(scalar=binding_data.scalar))

    if binding_data.collection is not None:
        resolved_items = [get_promise(item, outputs_cache).val for item in binding_data.collection.bindings]
        collection = _literal_models.LiteralCollection(literals=resolved_items)
        return Promise(var="placeholder", val=_literal_models.Literal(collection=collection))

    if binding_data.map is not None:
        resolved_entries = {
            key: get_promise(item, outputs_cache).val for key, item in binding_data.map.bindings.items()
        }
        literal_map = _literal_models.LiteralMap(literals=resolved_entries)
        return Promise(var="placeholder", val=_literal_models.Literal(map=literal_map))

    raise FlyteValidationException("Binding type unrecognized.")
Ejemplo n.º 3
0
def test_lp_default_handling():
    """
    Check how launch plan default_inputs / fixed_inputs interact with workflow parameters and workflow defaults.
    """
    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        a = a + 2
        return a, "world-" + str(a)

    @workflow
    def my_wf(a: int, b: int) -> (str, str, int, int):
        x, y = t1(a=a)
        u, v = t1(a=b)
        return y, v, x, u

    lp = launch_plan.LaunchPlan.create("test1", my_wf)
    assert len(lp.parameters.parameters) == 2
    assert lp.parameters.parameters["a"].required
    assert lp.parameters.parameters["a"].default is None
    assert lp.parameters.parameters["b"].required
    assert lp.parameters.parameters["b"].default is None
    assert len(lp.fixed_inputs.literals) == 0

    lp_with_defaults = launch_plan.LaunchPlan.create("test2", my_wf, default_inputs={"a": 3})
    assert len(lp_with_defaults.parameters.parameters) == 2
    assert not lp_with_defaults.parameters.parameters["a"].required
    assert lp_with_defaults.parameters.parameters["a"].default == _literal_models.Literal(
        scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=3))
    )
    assert len(lp_with_defaults.fixed_inputs.literals) == 0

    # A fixed input is removed from the launch plan's parameters and recorded in fixed_inputs instead.
    lp_with_fixed = launch_plan.LaunchPlan.create("test3", my_wf, fixed_inputs={"a": 3})
    assert len(lp_with_fixed.parameters.parameters) == 1
    assert len(lp_with_fixed.fixed_inputs.literals) == 1
    assert lp_with_fixed.fixed_inputs.literals["a"] == _literal_models.Literal(
        scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=3))
    )

    @workflow
    def my_wf2(a: int, b: int = 42) -> (str, str, int, int):
        x, y = t1(a=a)
        u, v = t1(a=b)
        return y, v, x, u

    lp = launch_plan.LaunchPlan.create("test4", my_wf2)
    assert len(lp.parameters.parameters) == 2
    assert len(lp.fixed_inputs.literals) == 0

    lp_with_defaults = launch_plan.LaunchPlan.create("test5", my_wf2, default_inputs={"a": 3})
    assert len(lp_with_defaults.parameters.parameters) == 2
    assert len(lp_with_defaults.fixed_inputs.literals) == 0
    # a comes from the launch plan default; the explicit b=3 overrides the workflow default of 42.
    assert lp_with_defaults(b=3) == ("world-5", "world-5", 5, 5)

    lp_with_fixed = launch_plan.LaunchPlan.create("test6", my_wf2, fixed_inputs={"a": 3})
    assert len(lp_with_fixed.parameters.parameters) == 1
    assert len(lp_with_fixed.fixed_inputs.literals) == 1
    # a is fixed by the launch plan; the explicit b=3 overrides the workflow default of 42.
    assert lp_with_fixed(b=3) == ("world-5", "world-5", 5, 5)

    lp_with_fixed = launch_plan.LaunchPlan.create("test7", my_wf2, fixed_inputs={"b": 3})
    assert len(lp_with_fixed.parameters.parameters) == 1
    assert len(lp_with_fixed.fixed_inputs.literals) == 1
    # b is fixed to 3 by the launch plan (replacing the workflow default of 42), so only a is supplied.
    assert lp_with_fixed(a=3) == ("world-5", "world-5", 5, 5)
Ejemplo n.º 4
0
 def extract_value(
     ctx: FlyteContext, input_val: Any, val_type: type,
     flyte_literal_type: _type_models.LiteralType
 ) -> _literal_models.Literal:
     """
     Convert ``input_val`` into a Flyte Literal of type ``flyte_literal_type``.

     Lists are converted element by element into a LiteralCollection. Dicts are converted value by value into a
     LiteralMap, unless the target type is a STRUCT, in which case TypeEngine converts the whole dict. Promises
     contribute their already-computed literal. VoidPromises and tuples are rejected. Any other value is treated
     as a native Python value and handed to TypeEngine.to_literal.

     :param ctx: the current FlyteContext, passed through to TypeEngine.
     :param input_val: the value to convert.
     :param val_type: the Python type ``input_val`` is expected to have.
     :param flyte_literal_type: the Flyte literal type to produce.
     :raises Exception: if a list/dict is given but the literal type is not a collection/map (or STRUCT).
     :raises ValueError: if ``val_type`` yields no list sub-type and ``input_val`` is an empty list.
     :raises AssertionError: for VoidPromise and tuple inputs.
     """
     if isinstance(input_val, list):
         if flyte_literal_type.collection_type is None:
             raise Exception(
                 f"Not a collection type {flyte_literal_type} but got a list {input_val}"
             )
         try:
             sub_type = ListTransformer.get_sub_type(val_type)
         except ValueError:
             # No element type on val_type: fall back to the first element's type, if there is one.
             if len(input_val) == 0:
                 raise
             sub_type = type(input_val[0])
         literals = [
             extract_value(ctx, v, sub_type,
                           flyte_literal_type.collection_type)
             for v in input_val
         ]
         return _literal_models.Literal(
             collection=_literal_models.LiteralCollection(
                 literals=literals))
     elif isinstance(input_val, dict):
         if (flyte_literal_type.map_value_type is None
                 and flyte_literal_type.simple !=
                 _type_models.SimpleType.STRUCT):
             raise Exception(
                 f"Not a map type {flyte_literal_type} but got a map {input_val}"
             )
         if flyte_literal_type.simple == _type_models.SimpleType.STRUCT:
             # STRUCT targets are converted as a whole, so the dict's key/value types are not needed.
             return TypeEngine.to_literal(ctx, input_val, type(input_val),
                                          flyte_literal_type)
         # Only the value type drives the recursion; keys are used as-is.
         _, sub_type = DictTransformer.get_dict_types(val_type)
         literals = {
             k: extract_value(ctx, v, sub_type,
                              flyte_literal_type.map_value_type)
             for k, v in input_val.items()
         }
         return _literal_models.Literal(map=_literal_models.LiteralMap(
             literals=literals))
     elif isinstance(input_val, Promise):
         # In the example above, this handles the "in2=a" type of argument
         return input_val.val
     elif isinstance(input_val, VoidPromise):
         raise AssertionError(
             f"Outputs of a non-output producing task {input_val.task_name} cannot be passed to another task."
         )
     elif isinstance(input_val, tuple):
         # Each fragment after the first starts with a space so the implicitly concatenated message reads
         # correctly.
         raise AssertionError(
             "Tuples are not a supported type for individual values in Flyte - got a tuple -"
             f" {input_val}. If using named tuple in an inner task, please, de-reference the"
             " actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then"
             " return v.x, instead of v, even if this has a single element")
     else:
         # This handles native values, the 5 example
         return TypeEngine.to_literal(ctx, input_val, val_type,
                                      flyte_literal_type)
Ejemplo n.º 5
0
def test_infer_sdk_type_from_literal():
    """A string literal infers the String SDK type and a void literal infers the Void SDK type."""
    string_literal = _literals.Literal(
        scalar=_literals.Scalar(primitive=_literals.Primitive(string_value="abc")))
    inferred = _type_helpers.infer_sdk_type_from_literal(string_literal)
    assert inferred == _sdk_types.Types.String

    void_literal = _literals.Literal(scalar=_literals.Scalar(none_type=_literals.Void()))
    inferred = _type_helpers.infer_sdk_type_from_literal(void_literal)
    assert inferred is _base_sdk_types.Void
Ejemplo n.º 6
0
 def extract_value(
     ctx: FlyteContext, input_val: Any, val_type: type,
     flyte_literal_type: _type_models.LiteralType
 ) -> _literal_models.Literal:
     """
     Convert ``input_val`` into a Flyte Literal of type ``flyte_literal_type``.

     Lists are converted element by element into a LiteralCollection. Dicts are converted value by value into a
     LiteralMap, unless the target type is a STRUCT, in which case TypeEngine converts the whole dict. Promises
     contribute their already-computed literal and VoidPromises are rejected. Any other value is treated as a
     native Python value and handed to TypeEngine.to_literal.

     :param ctx: the current FlyteContext, passed through to TypeEngine.
     :param input_val: the value to convert.
     :param val_type: the Python type ``input_val`` is expected to have.
     :param flyte_literal_type: the Flyte literal type to produce.
     :raises Exception: if a list/dict is given but the literal type is not a collection/map (or STRUCT).
     :raises ValueError: if ``val_type`` yields no list sub-type and ``input_val`` is an empty list.
     :raises AssertionError: for VoidPromise inputs.
     """
     if isinstance(input_val, list):
         if flyte_literal_type.collection_type is None:
             raise Exception(
                 f"Not a collection type {flyte_literal_type} but got a list {input_val}"
             )
         try:
             sub_type = ListTransformer.get_sub_type(val_type)
         except ValueError:
             # No element type on val_type: fall back to the first element's type, if there is one.
             if len(input_val) == 0:
                 raise
             sub_type = type(input_val[0])
         literals = [
             extract_value(ctx, v, sub_type,
                           flyte_literal_type.collection_type)
             for v in input_val
         ]
         return _literal_models.Literal(
             collection=_literal_models.LiteralCollection(
                 literals=literals))
     elif isinstance(input_val, dict):
         if (flyte_literal_type.map_value_type is None
                 and flyte_literal_type.simple !=
                 _type_models.SimpleType.STRUCT):
             raise Exception(
                 f"Not a map type {flyte_literal_type} but got a map {input_val}"
             )
         if flyte_literal_type.simple == _type_models.SimpleType.STRUCT:
             # STRUCT targets are converted as a whole, so the dict's key/value types are not needed.
             return TypeEngine.to_literal(ctx, input_val, type(input_val),
                                          flyte_literal_type)
         # Only the value type drives the recursion; keys are used as-is.
         _, sub_type = DictTransformer.get_dict_types(val_type)
         literals = {
             k: extract_value(ctx, v, sub_type,
                              flyte_literal_type.map_value_type)
             for k, v in input_val.items()
         }
         return _literal_models.Literal(map=_literal_models.LiteralMap(
             literals=literals))
     elif isinstance(input_val, Promise):
         # In the example above, this handles the "in2=a" type of argument
         return input_val.val
     elif isinstance(input_val, VoidPromise):
         raise AssertionError(
             f"Outputs of a non-output producing task {input_val.task_name} cannot be passed to another task."
         )
     else:
         # This handles native values, the 5 example
         return TypeEngine.to_literal(ctx, input_val, val_type,
                                      flyte_literal_type)
Ejemplo n.º 7
0
def test_launch_workflow_with_subworkflows(flyteclient,
                                           flyte_workflows_register):
    """
    Launch subworkflows.parent_wf with a=101 and check node, subworkflow-node inputs and outputs.
    """
    execution = launch_plan.FlyteLaunchPlan.fetch(
        PROJECT, "development", "workflows.basic.subworkflows.parent_wf",
        f"v{VERSION}").launch_with_literals(
            PROJECT,
            "development",
            literals.LiteralMap({
                "a":
                literals.Literal(
                    literals.Scalar(literals.Primitive(integer=101)))
            }),
        )
    execution.wait_for_completion()
    # check node execution inputs and outputs
    assert execution.node_executions["n0"].inputs == {"a": 101}
    assert execution.node_executions["n0"].outputs == {
        "t1_int_output": 103,
        "c": "world"
    }
    assert execution.node_executions["n1"].inputs == {"a": 103}
    assert execution.node_executions["n1"].outputs == {
        "o0": "world",
        "o1": "world"
    }

    # check subworkflow task execution inputs and outputs
    subworkflow_node_executions = execution.node_executions[
        "n1"].subworkflow_node_executions
    assert subworkflow_node_executions["n1-0-n0"].inputs == {"a": 103}
    assert subworkflow_node_executions["n1-0-n1"].outputs == {
        "t1_int_output": 107,
        "c": "world"
    }
Ejemplo n.º 8
0
def test_launch_plan_spec():
    """LaunchPlanSpec survives an IDL round trip both with and without a raw output data prefix."""
    identifier_model = identifier.Identifier(
        identifier.ResourceType.TASK, "project", "domain", "name", "version")

    schedule_model = schedule.Schedule("asdf", "1 3 4 5 6 7")
    metadata_model = launch_plan.LaunchPlanMetadata(schedule=schedule_model, notifications=[])

    bool_var = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf")
    parameter_map = interface.ParameterMap({"ppp": interface.Parameter(var=bool_var)})

    fixed_inputs = literals.LiteralMap({
        "a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))),
    })

    labels_model = common.Labels({})
    annotations_model = common.Annotations({"my": "annotation"})
    auth_role_model = common.AuthRole(assumable_iam_role="my:iam:role")

    # First with an explicit output prefix, then with an empty one.
    raw_output_configs = (
        common.RawOutputDataConfig("s3://bucket"),
        common.RawOutputDataConfig(""),
    )
    max_parallelism = 100

    for raw_output_config in raw_output_configs:
        spec = launch_plan.LaunchPlanSpec(
            identifier_model,
            metadata_model,
            parameter_map,
            fixed_inputs,
            labels_model,
            annotations_model,
            auth_role_model,
            raw_output_config,
            max_parallelism,
        )
        round_tripped = launch_plan.LaunchPlanSpec.from_flyte_idl(spec.to_flyte_idl())
        assert round_tripped == spec
Ejemplo n.º 9
0
def test_launch_workflow_with_args(flyteclient, flyte_workflows_register):
    """Launch basic_workflow.my_wf with literal inputs and verify node, task and workflow inputs/outputs."""
    launch_inputs = literals.LiteralMap({
        "a": literals.Literal(literals.Scalar(literals.Primitive(integer=10))),
        "b": literals.Literal(literals.Scalar(literals.Primitive(string_value="foobar"))),
    })
    flyte_lp = launch_plan.FlyteLaunchPlan.fetch(
        PROJECT, "development", "workflows.basic.basic_workflow.my_wf", f"v{VERSION}")
    execution = flyte_lp.launch_with_literals(PROJECT, "development", launch_inputs)
    execution.wait_for_completion()

    expectations = (
        ("n0", {"a": 10}, {"t1_int_output": 12, "c": "world"}),
        ("n1", {"a": "world", "b": "foobar"}, {"o0": "foobarworld"}),
    )

    # Node-level inputs and outputs.
    for node_id, expected_inputs, expected_outputs in expectations:
        assert execution.node_executions[node_id].inputs == expected_inputs
        assert execution.node_executions[node_id].outputs == expected_outputs

    # The first task execution of each node sees the same values.
    for node_id, expected_inputs, expected_outputs in expectations:
        assert execution.node_executions[node_id].task_executions[0].inputs == expected_inputs
        assert execution.node_executions[node_id].task_executions[0].outputs == expected_outputs

    assert execution.inputs["a"] == 10
    assert execution.inputs["b"] == "foobar"
    assert execution.outputs["o0"] == 12
    assert execution.outputs["o1"] == "foobarworld"
Ejemplo n.º 10
0
def test_arrayjob_entrypoint_in_proc():
    """
    Run the add_one task through execute_task as if it were array-job element 0 and check its output.

    The array-job environment variables are overridden only for the duration of the run and are always restored
    (or removed, if they were previously unset), even when execution or an assertion fails.
    """
    with _TemporaryConfiguration(os.path.join(os.path.dirname(__file__),
                                              'fake.config'),
                                 internal_overrides={
                                     'project': 'test',
                                     'domain': 'development'
                                 }):
        with _utils.AutoDeletingTempDir("dir") as tmp_dir:
            literal_map = _type_helpers.pack_python_std_map_to_literal_map(
                {'a': 9},
                _type_map_from_variable_map(
                    _task_defs.add_one.interface.inputs))

            input_dir = os.path.join(tmp_dir.name, "1")
            os.mkdir(
                input_dir)  # auto cleanup will take this subdir into account

            input_file = os.path.join(input_dir, "inputs.pb")
            _utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file)

            # construct indexlookup.pb which has array: [1]
            mapped_index = _literals.Literal(
                _literals.Scalar(primitive=_literals.Primitive(integer=1)))
            index_lookup_collection = _literals.LiteralCollection(
                [mapped_index])
            index_lookup_file = os.path.join(tmp_dir.name, "indexlookup.pb")
            _utils.write_proto_to_file(index_lookup_collection.to_flyte_idl(),
                                       index_lookup_file)

            # fake arrayjob task by setting environment variables
            env_overrides = {
                'BATCH_JOB_ARRAY_INDEX_VAR_NAME': 'AWS_BATCH_JOB_ARRAY_INDEX',
                'AWS_BATCH_JOB_ARRAY_INDEX': '0',
            }
            original_env = {name: os.environ.get(name) for name in env_overrides}
            os.environ.update(env_overrides)
            try:
                execute_task(_task_defs.add_one.task_module,
                             _task_defs.add_one.task_function_name, tmp_dir.name,
                             tmp_dir.name, False)

                raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std(
                    _literal_models.LiteralMap.from_flyte_idl(
                        _utils.load_proto_from_file(
                            _literals_pb2.LiteralMap,
                            os.path.join(input_dir, _constants.OUTPUT_FILE_NAME))),
                    _type_map_from_variable_map(
                        _task_defs.add_one.interface.outputs))
                assert raw_map['b'] == 10
                assert len(raw_map) == 1
            finally:
                # Restore the environment exactly, so these variables do not leak into other tests.
                for name, value in original_env.items():
                    if value is None:
                        os.environ.pop(name, None)
                    else:
                        os.environ[name] = value
Ejemplo n.º 11
0
def test_model_promotion():
    """Promoting a collection Literal via containers.List yields Integer SDK elements equal to [0, 1, 2]."""
    list_type = containers.List(primitives.Integer)
    element_literals = [
        literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=i)))
        for i in range(3)
    ]
    list_model = literals.Literal(collection=literals.LiteralCollection(literals=element_literals))

    list_obj = list_type.promote_from_model(list_model)
    assert len(list_obj.collection.literals) == 3
    assert isinstance(list_obj.collection.literals[0], primitives.Integer)
    assert list_obj == list_type.from_python_std([0, 1, 2])
    assert list_obj == list_type([primitives.Integer(i) for i in range(3)])
Ejemplo n.º 12
0
def test_launch_workflow_with_args(flyteclient, flyte_workflows_register):
    """Launch basic_workflow.my_wf with a=10, b="foobar" and check the raw output literals."""
    launch_inputs = literals.LiteralMap({
        "a": literals.Literal(literals.Scalar(literals.Primitive(integer=10))),
        "b": literals.Literal(literals.Scalar(literals.Primitive(string_value="foobar"))),
    })
    flyte_lp = launch_plan.FlyteLaunchPlan.fetch(
        PROJECT, "development", "workflows.basic.basic_workflow.my_wf", f"v{VERSION}")
    execution = flyte_lp.launch_with_literals(PROJECT, "development", launch_inputs)
    execution.wait_for_completion()

    assert execution.outputs.literals["o0"].scalar.primitive.integer == 12
    assert execution.outputs.literals["o1"].scalar.primitive.string_value == "foobarworld"
Ejemplo n.º 13
0
def test_infer_proto_from_literal():
    """A binary literal tagged with the protobuf prefix infers an SDK type wrapping that protobuf class."""
    type_engine = _flyte_engine.FlyteDefaultTypeEngine()
    proto_tag = f"{_proto.Protobuf.TAG_PREFIX}flyteidl.core.errors_pb2.ContainerError"
    binary_literal = _literal_models.Literal(
        scalar=_literal_models.Scalar(binary=_literal_models.Binary(value="", tag=proto_tag))
    )

    inferred = type_engine.infer_sdk_type_from_literal(binary_literal)
    assert inferred.pb_type == _errors_pb2.ContainerError
Ejemplo n.º 14
0
def test_blob_promote_from_model():
    """Promoting a single-file blob Literal preserves its URI, format and dimensionality."""
    single = _core_types.BlobType.BlobDimensionality.SINGLE
    blob_type = _core_types.BlobType(format="f", dimensionality=single)
    blob_model = _literal_models.Blob(_literal_models.BlobMetadata(blob_type), "some/path")
    literal = _literal_models.Literal(scalar=_literal_models.Scalar(blob=blob_model))

    promoted = blobs.Blob.promote_from_model(literal)
    assert promoted.value.blob.uri == "some/path"
    assert promoted.value.blob.metadata.type.format == "f"
    assert promoted.value.blob.metadata.type.dimensionality == single
Ejemplo n.º 15
0
def test_get_sdk_value_from_literal():
    """Void, integer and mixed collection literals convert back to the expected python std values."""
    def void_literal():
        return _literals.Literal(scalar=_literals.Scalar(none_type=_literals.Void()))

    def int_literal(value):
        return _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=value)))

    to_sdk = _type_helpers.get_sdk_value_from_literal

    assert to_sdk(void_literal()).to_python_std() is None
    assert to_sdk(void_literal(), sdk_type=_sdk_types.Types.Integer).to_python_std() is None
    assert to_sdk(int_literal(1), sdk_type=_sdk_types.Types.Integer).to_python_std() == 1

    mixed = _literals.Literal(collection=_literals.LiteralCollection([int_literal(1), void_literal()]))
    assert to_sdk(mixed).to_python_std() == [1, None]
Ejemplo n.º 16
0
def test_scalar_literals(scalar_value_pair):
    """A scalar Literal exposes only its scalar, both before and after an IDL round trip."""
    scalar, _python_value = scalar_value_pair

    def assert_scalar_only(lit):
        assert lit.value == scalar
        assert lit.scalar == scalar
        assert lit.collection is None
        assert lit.map is None

    original = literals.Literal(scalar=scalar)
    assert_scalar_only(original)

    round_tripped = literals.Literal.from_flyte_idl(original.to_flyte_idl())
    assert original == round_tripped
    assert_scalar_only(round_tripped)
Ejemplo n.º 17
0
def test_lp_serialize():
    """Serialize launch plans with and without a default for ``a`` and inspect their default_inputs."""
    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        a = a + 2
        return a, "world-" + str(a)

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_subwf(a: int) -> (str, str):
        x, y = t1(a=a)
        u, v = t1(a=x)
        return y, v

    plain_lp = launch_plan.LaunchPlan.create("serialize_test1", my_subwf)
    defaulted_lp = launch_plan.LaunchPlan.create("serialize_test2", my_subwf, default_inputs={"a": 3})

    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )

    serialized = get_serializable(OrderedDict(), settings, plain_lp)
    assert len(serialized.default_inputs.parameters) == 1
    assert serialized.default_inputs.parameters["a"].required
    assert len(serialized.fixed_inputs.literals) == 0

    serialized = get_serializable(OrderedDict(), settings, defaulted_lp)
    assert len(serialized.default_inputs.parameters) == 1
    assert not serialized.default_inputs.parameters["a"].required
    expected_default = _literal_models.Literal(
        scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=3)))
    assert serialized.default_inputs.parameters["a"].default == expected_default
    assert len(serialized.fixed_inputs.literals) == 0

    # The parameter's oneof must survive the IDL round trip: once a default is set, `required` has to be
    # unset (None) rather than False, otherwise the default would be lost.
    round_tripped = Parameter.from_flyte_idl(serialized.default_inputs.parameters["a"].to_flyte_idl())
    assert round_tripped.default is not None
Ejemplo n.º 18
0
    def setUp(self):
        """Build the task input map, an EngineContext and a CustomTrainingJobTask for the tests."""
        with _utils.AutoDeletingTempDir("input_dir") as input_dir:
            # NOTE(review): AutoDeletingTempDir presumably removes input_dir when this ``with`` exits at the end
            # of setUp, yet its path is kept in self._context.tmp_dir — confirm tests never write there.
            one = _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=1)))
            self._task_input = _literals.LiteralMap({"input_1": one})

            execution_id = WorkflowExecutionIdentifier(project="unit_test", domain="unit_test", name="unit_test")
            self._context = _common_engine.EngineContext(
                execution_id=execution_id,
                execution_date=_datetime.datetime.utcnow(),
                stats=MockStats(),
                logging=None,
                tmp_dir=input_dir.name,
            )

            resource_config = TrainingJobResourceConfig(
                instance_type="ml.m4.xlarge",
                instance_count=2,
                volume_size_in_gb=25,
            )
            algorithm_spec = AlgorithmSpecification(
                input_mode=InputMode.FILE,
                input_content_type=InputContentType.TEXT_CSV,
                metric_definitions=[MetricDefinition(name="Validation error", regex="validation:error")],
            )

            # Distributed training task defined without an output-persist predicate, so the default is used.
            @inputs(input_1=Types.Integer)
            @outputs(model=Types.Blob)
            @custom_training_job_task(
                training_job_resource_config=resource_config,
                algorithm_specification=algorithm_spec,
            )
            def my_distributed_task(wf_params, input_1, model):
                pass

            self._my_distributed_task = my_distributed_task
            assert type(self._my_distributed_task) is CustomTrainingJobTask
Ejemplo n.º 19
0
def test_old_style_role():
    """
    Build a LaunchPlanSpec IDL that uses the legacy ``auth`` field and check how it is read back into auth_role.
    """
    workflow_id = identifier.Identifier(
        identifier.ResourceType.TASK, "project", "domain", "name", "version")
    metadata = launch_plan.LaunchPlanMetadata(
        schedule=schedule.Schedule("asdf", "1 3 4 5 6 7"), notifications=[])

    bool_var = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf")
    default_inputs = interface.ParameterMap({"ppp": interface.Parameter(var=bool_var)})
    fixed_inputs = literals.LiteralMap({
        "a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))),
    })

    labels_model = common.Labels({})
    annotations_model = common.Annotations({"my": "annotation"})
    raw_output_config = common.RawOutputDataConfig("s3://bucket")
    legacy_auth = _launch_plan_idl.Auth(kubernetes_service_account="my:service:account")

    old_style_spec = _launch_plan_idl.LaunchPlanSpec(
        workflow_id=workflow_id.to_flyte_idl(),
        entity_metadata=metadata.to_flyte_idl(),
        default_inputs=default_inputs.to_flyte_idl(),
        fixed_inputs=fixed_inputs.to_flyte_idl(),
        labels=labels_model.to_flyte_idl(),
        annotations=annotations_model.to_flyte_idl(),
        raw_output_data_config=raw_output_config.to_flyte_idl(),
        auth=legacy_auth,
    )

    parsed = launch_plan.LaunchPlanSpec.from_flyte_idl(old_style_spec)
    assert parsed.auth_role.assumable_iam_role == "my:service:account"
Ejemplo n.º 20
0
from flytekit.common import constants, utils
from flytekit.common.exceptions import scopes
from flytekit.configuration import TemporaryConfiguration
from flytekit.engines.flyte import engine
from flytekit.models import common as _common_models
from flytekit.models import execution as _execution_models
from flytekit.models import launch_plan as _launch_plan_models
from flytekit.models import literals
from flytekit.models import task as _task_models
from flytekit.models.admin import common as _common
from flytekit.models.core import errors, identifier
from flytekit.sdk import test_utils

# Single-integer literal maps used as canned execution inputs (a=1) and outputs (b=2), plus an empty map.
_INPUT_MAP = literals.LiteralMap(
    literals={"a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))},
)
_OUTPUT_MAP = literals.LiteralMap(
    literals={"b": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=2)))},
)
_EMPTY_LITERAL_MAP = literals.LiteralMap(literals={})


@pytest.fixture(scope="function", autouse=True)
def temp_config():
    with TemporaryConfiguration(
        os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            "../../../common/configs/local.config",
        ),
        internal_overrides={
Ejemplo n.º 21
0
from flyteidl.core import errors_pb2
from mock import MagicMock, patch, PropertyMock

from flytekit.common import constants, utils
from flytekit.common.exceptions import scopes
from flytekit.configuration import TemporaryConfiguration
from flytekit.engines.flyte import engine
from flytekit.models import literals, execution as _execution_models, common as _common_models, launch_plan as \
    _launch_plan_models
from flytekit.models.core import errors, identifier
from flytekit.sdk import test_utils


# Single-integer literal maps used as canned execution inputs (a=1) and outputs (b=2).
_INPUT_MAP = literals.LiteralMap(
    literals={'a': literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))},
)
_OUTPUT_MAP = literals.LiteralMap(
    literals={'b': literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=2)))},
)


@pytest.fixture(scope="function", autouse=True)
def temp_config():
    with TemporaryConfiguration(
            os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../../common/configs/local.config'),
            internal_overrides={
                'image': 'myflyteimage:{}'.format(
Ejemplo n.º 22
0
                 types.SchemaType.SchemaColumn(
                     "b",
                     types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN),
                 types.SchemaType.SchemaColumn(
                     "c",
                     types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME),
                 types.SchemaType.SchemaColumn(
                     "d",
                     types.SchemaType.SchemaColumn.SchemaColumnType.DURATION),
                 types.SchemaType.SchemaColumn(
                     "e",
                     types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT),
                 types.SchemaType.SchemaColumn(
                     "f",
                     types.SchemaType.SchemaColumn.SchemaColumnType.STRING),
             ]))))
]

# Each scalar wrapped in a Literal, paired with its python value.
LIST_OF_SCALAR_LITERALS_AND_PYTHON_VALUE = [
    (literals.Literal(scalar=scalar), value)
    for scalar, value in LIST_OF_SCALARS_AND_PYTHON_VALUES
]

# Each scalar literal repeated three times in a LiteralCollection, paired with the matching 3-element list.
LIST_OF_LITERAL_COLLECTIONS_AND_PYTHON_VALUE = [
    (literals.LiteralCollection(literals=[lit] * 3), [value] * 3)
    for lit, value in LIST_OF_SCALAR_LITERALS_AND_PYTHON_VALUE
]

LIST_OF_ALL_LITERALS_AND_VALUES = (
    LIST_OF_SCALAR_LITERALS_AND_PYTHON_VALUE + LIST_OF_LITERAL_COLLECTIONS_AND_PYTHON_VALUE
)
Ejemplo n.º 23
0
import pytest

from flytekit.models import common as _common_models
from flytekit.models import execution as _execution
from flytekit.models import literals as _literals
from flytekit.models.core import execution as _core_exec
from flytekit.models.core import identifier as _identifier
from tests.flytekit.common import parameterizers as _parameterizers

# Single-integer literal maps used as canned execution inputs (a=1) and outputs (b=2).
_INPUT_MAP = _literals.LiteralMap(
    literals={"a": _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=1)))},
)
_OUTPUT_MAP = _literals.LiteralMap(
    literals={"b": _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=2)))},
)


def test_execution_metadata():
    """ExecutionMetadata exposes mode, principal and nesting, and all three survive an IDL round trip."""
    obj = _execution.ExecutionMetadata(
        _execution.ExecutionMetadata.ExecutionMode.MANUAL, "tester", 1)
    assert obj.mode == _execution.ExecutionMetadata.ExecutionMode.MANUAL
    assert obj.principal == "tester"
    assert obj.nesting == 1
    obj2 = _execution.ExecutionMetadata.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
    assert obj2.mode == _execution.ExecutionMetadata.ExecutionMode.MANUAL
    assert obj2.principal == "tester"
    assert obj2.nesting == 1
Ejemplo n.º 24
0
LIST_OF_SCALARS_AND_PYTHON_VALUES = [
    (literals.Scalar(primitive=literals.Primitive(integer=100)), 100),
    (literals.Scalar(primitive=literals.Primitive(float_value=500.0)), 500.0),
    (literals.Scalar(primitive=literals.Primitive(boolean=True)), True),
    (literals.Scalar(primitive=literals.Primitive(string_value="hello")),
     "hello"),
    (
        literals.Scalar(primitive=literals.Primitive(duration=timedelta(
            seconds=5))),
        timedelta(seconds=5),
    ),
    (literals.Scalar(none_type=literals.Void()), None),
    (
        literals.Scalar(union=literals.Union(
            value=literals.Literal(scalar=literals.Scalar(
                primitive=literals.Primitive(integer=10))),
            stored_type=types.LiteralType(simple=types.SimpleType.INTEGER,
                                          structure=types.TypeStructure(
                                              tag="int")),
        )),
        10,
    ),
    (
        literals.Scalar(union=literals.Union(
            value=literals.Literal(scalar=literals.Scalar(
                primitive=literals.Primitive(integer=10))),
            stored_type=types.LiteralType(simple=types.SimpleType.INTEGER,
                                          structure=types.TypeStructure(
                                              tag="int")),
        )),
        10,