Ejemplo n.º 1
0
def test_node_task_with_inputs():
    """A task Node with two scalar input bindings survives an IDL round trip."""
    metadata = _get_sample_node_metadata()
    task_node = _workflow.TaskNode(reference_id=_generic_id)

    def _int_binding(var_name, number):
        # Bind ``var_name`` to a literal integer scalar.
        scalar = _literals.Scalar(primitive=_literals.Primitive(integer=number))
        return _literals.Binding(var=var_name,
                                 binding=_literals.BindingData(scalar=scalar))

    first = _int_binding("myvar", 5)
    second = _int_binding("myothervar", 99)

    node = _workflow.Node(
        id="some:node:id",
        metadata=metadata,
        inputs=[first, second],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_node,
    )

    def _verify(candidate, index, expected):
        # Common shape checks plus one positional input check.
        assert candidate.target == task_node
        assert candidate.id == "some:node:id"
        assert candidate.metadata == metadata
        assert len(candidate.inputs) == 2
        assert candidate.inputs[index] == expected

    _verify(node, 0, first)
    roundtripped = _workflow.Node.from_flyte_idl(node.to_flyte_idl())
    assert node == roundtripped
    _verify(roundtripped, 1, second)
Ejemplo n.º 2
0
def test_get_sdk_value_from_literal():
    """Literals (void, integer, mixed collection) convert to the expected Python values."""

    def _void():
        return _literals.Literal(scalar=_literals.Scalar(none_type=_literals.Void()))

    def _int(value):
        return _literals.Literal(scalar=_literals.Scalar(
            primitive=_literals.Primitive(integer=value)))

    convert = _type_helpers.get_sdk_value_from_literal

    # Void without a type hint becomes None.
    assert convert(_void()).to_python_std() is None
    # Void stays None even when an Integer SDK type is requested.
    assert convert(_void(), sdk_type=_sdk_types.Types.Integer).to_python_std() is None
    # A plain integer literal with an Integer hint.
    assert convert(_int(1), sdk_type=_sdk_types.Types.Integer).to_python_std() == 1
    # A collection mixing an integer and a void element.
    mixed = _literals.Literal(collection=_literals.LiteralCollection([_int(1), _void()]))
    assert convert(mixed).to_python_std() == [1, None]
Ejemplo n.º 3
0
def test_binding_data_map():
    """A nested BindingDataMap keeps its values across an IDL round trip."""

    def _scalar_binding(number):
        return literals.BindingData(scalar=literals.Scalar(
            primitive=literals.Primitive(integer=number)))

    inner = literals.BindingDataMap(bindings={
        "first": _scalar_binding(5),
        "second": _scalar_binding(57),
    })
    outer = literals.BindingDataMap(bindings={
        "three": _scalar_binding(2),
        "sample_map": literals.BindingData(map=inner),
    })
    original = literals.BindingData(map=outer)

    def _check(data, nested_key, nested_value):
        # Only the map alternative is set; check one top-level and one nested value.
        assert data.scalar is None
        assert data.promise is None
        assert data.collection is None
        assert data.value.bindings["three"].value.value.value == 2
        nested = data.value.bindings["sample_map"].value.bindings[nested_key]
        assert nested.value.value.value == nested_value

    _check(original, "second", 57)
    restored = literals.BindingData.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
    _check(restored, "first", 5)
Ejemplo n.º 4
0
def test_lp_default_handling():
    """Check how LaunchPlan.create treats required inputs, LP defaults and fixed inputs.

    ``my_wf`` has two required inputs; ``my_wf2`` gives ``b`` a Python default
    of 42. Default inputs stay in ``parameters`` (marked not required), while
    fixed inputs are removed from ``parameters`` and appear in ``fixed_inputs``.
    """

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        a = a + 2
        return a, "world-" + str(a)

    @workflow
    def my_wf(a: int, b: int) -> (str, str, int, int):
        x, y = t1(a=a)
        u, v = t1(a=b)
        return y, v, x, u

    lp = launch_plan.LaunchPlan.create("test1", my_wf)
    assert len(lp.parameters.parameters) == 2
    assert lp.parameters.parameters["a"].required
    assert lp.parameters.parameters["a"].default is None
    assert lp.parameters.parameters["b"].required
    assert lp.parameters.parameters["b"].default is None
    assert len(lp.fixed_inputs.literals) == 0

    lp_with_defaults = launch_plan.LaunchPlan.create("test2", my_wf, default_inputs={"a": 3})
    assert len(lp_with_defaults.parameters.parameters) == 2
    assert not lp_with_defaults.parameters.parameters["a"].required
    assert lp_with_defaults.parameters.parameters["a"].default == _literal_models.Literal(
        scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=3))
    )
    assert len(lp_with_defaults.fixed_inputs.literals) == 0

    lp_with_fixed = launch_plan.LaunchPlan.create("test3", my_wf, fixed_inputs={"a": 3})
    assert len(lp_with_fixed.parameters.parameters) == 1
    assert len(lp_with_fixed.fixed_inputs.literals) == 1
    assert lp_with_fixed.fixed_inputs.literals["a"] == _literal_models.Literal(
        scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=3))
    )

    @workflow
    def my_wf2(a: int, b: int = 42) -> (str, str, int, int):
        x, y = t1(a=a)
        u, v = t1(a=b)
        return y, v, x, u

    lp = launch_plan.LaunchPlan.create("test4", my_wf2)
    assert len(lp.parameters.parameters) == 2
    assert len(lp.fixed_inputs.literals) == 0

    lp_with_defaults = launch_plan.LaunchPlan.create("test5", my_wf2, default_inputs={"a": 3})
    assert len(lp_with_defaults.parameters.parameters) == 2
    assert len(lp_with_defaults.fixed_inputs.literals) == 0
    # The LP default supplies a=3 and the explicit b=3 takes precedence over the
    # workflow default of 42, so both t1 calls receive 3 and produce 5.
    assert lp_with_defaults(b=3) == ("world-5", "world-5", 5, 5)

    lp_with_fixed = launch_plan.LaunchPlan.create("test6", my_wf2, fixed_inputs={"a": 3})
    assert len(lp_with_fixed.parameters.parameters) == 1
    assert len(lp_with_fixed.fixed_inputs.literals) == 1
    # The fixed input (not a default) supplies a=3; the explicit b=3 takes
    # precedence over the workflow default of 42.
    assert lp_with_fixed(b=3) == ("world-5", "world-5", 5, 5)

    lp_with_fixed = launch_plan.LaunchPlan.create("test7", my_wf2, fixed_inputs={"b": 3})
    assert len(lp_with_fixed.parameters.parameters) == 1
    assert len(lp_with_fixed.fixed_inputs.literals) == 1
    # Fixing b records the fixed value (3) rather than the workflow default (42).
    assert lp_with_fixed.fixed_inputs.literals["b"] == _literal_models.Literal(
        scalar=_literal_models.Scalar(primitive=_literal_models.Primitive(integer=3))
    )
Ejemplo n.º 5
0
def test_launch_workflow_with_subworkflows(flyteclient,
                                           flyte_workflows_register):
    """Launch ``parent_wf`` with a=101 and check parent and subworkflow node I/O."""
    execution = launch_plan.FlyteLaunchPlan.fetch(
        PROJECT, "development", "workflows.basic.subworkflows.parent_wf",
        f"v{VERSION}").launch_with_literals(
            PROJECT,
            "development",
            literals.LiteralMap({
                "a":
                literals.Literal(
                    literals.Scalar(literals.Primitive(integer=101)))
            }),
        )
    execution.wait_for_completion()
    # check node execution inputs and outputs
    assert execution.node_executions["n0"].inputs == {"a": 101}
    assert execution.node_executions["n0"].outputs == {
        "t1_int_output": 103,
        "c": "world"
    }
    assert execution.node_executions["n1"].inputs == {"a": 103}
    assert execution.node_executions["n1"].outputs == {
        "o0": "world",
        "o1": "world"
    }

    # Check subworkflow task execution inputs and outputs. These must be real
    # assertions: a bare `x == y` expression statement is evaluated and discarded.
    subworkflow_node_executions = execution.node_executions[
        "n1"].subworkflow_node_executions
    assert subworkflow_node_executions["n1-0-n0"].inputs == {"a": 103}
    assert subworkflow_node_executions["n1-0-n1"].outputs == {
        "t1_int_output": 107,
        "c": "world"
    }
Ejemplo n.º 6
0
def test_launch_plan_spec():
    """LaunchPlanSpec round-trips through IDL with and without a raw-output prefix."""
    workflow_id = identifier.Identifier(identifier.ResourceType.TASK,
                                        "project", "domain", "name", "version")
    lp_metadata = launch_plan.LaunchPlanMetadata(
        schedule=schedule.Schedule("asdf", "1 3 4 5 6 7"), notifications=[])

    bool_var = interface.Variable(
        types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf")
    default_inputs = interface.ParameterMap({"ppp": interface.Parameter(var=bool_var)})

    one = literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))
    fixed = literals.LiteralMap({"a": one})

    labels = common.Labels({})
    annotations = common.Annotations({"my": "annotation"})
    role = common.AuthRole(assumable_iam_role="my:iam:role")
    parallelism = 100

    # Round-trip once with a populated prefix and once with an empty one.
    for prefix in ("s3://bucket", ""):
        spec = launch_plan.LaunchPlanSpec(
            workflow_id,
            lp_metadata,
            default_inputs,
            fixed,
            labels,
            annotations,
            role,
            common.RawOutputDataConfig(prefix),
            parallelism,
        )
        restored = launch_plan.LaunchPlanSpec.from_flyte_idl(spec.to_flyte_idl())
        assert restored == spec
Ejemplo n.º 7
0
def test_launch_workflow_with_args(flyteclient, flyte_workflows_register):
    """Launch my_wf with a=10, b="foobar" and verify node, task and workflow I/O."""
    fetched = launch_plan.FlyteLaunchPlan.fetch(
        PROJECT, "development", "workflows.basic.basic_workflow.my_wf",
        f"v{VERSION}")
    launch_inputs = literals.LiteralMap({
        "a": literals.Literal(literals.Scalar(literals.Primitive(integer=10))),
        "b": literals.Literal(
            literals.Scalar(literals.Primitive(string_value="foobar"))),
    })
    execution = fetched.launch_with_literals(PROJECT, "development", launch_inputs)
    execution.wait_for_completion()

    # node id -> (expected inputs, expected outputs)
    expected_io = {
        "n0": ({"a": 10}, {"t1_int_output": 12, "c": "world"}),
        "n1": ({"a": "world", "b": "foobar"}, {"o0": "foobarworld"}),
    }
    for node_id, (want_in, want_out) in expected_io.items():
        assert execution.node_executions[node_id].inputs == want_in
        assert execution.node_executions[node_id].outputs == want_out
    # The first task execution of each node sees the same I/O as the node.
    for node_id, (want_in, want_out) in expected_io.items():
        assert execution.node_executions[node_id].task_executions[0].inputs == want_in
        assert execution.node_executions[node_id].task_executions[0].outputs == want_out

    assert execution.inputs["a"] == 10
    assert execution.inputs["b"] == "foobar"
    assert execution.outputs["o0"] == 12
    assert execution.outputs["o1"] == "foobarworld"
def test_arrayjob_entrypoint_in_proc():
    """Run ``add_one`` in-process as a fake AWS Batch array job.

    Writes ``inputs.pb`` ({'a': 9}) into a "1" subdirectory and an
    ``indexlookup.pb`` whose single entry is 1. It then sets the Batch
    array-index environment variables, executes the task, and checks that the
    output map is exactly {'b': 10}. The environment variables are restored in a
    ``finally`` block, so a failing task or assertion cannot leak them into
    later tests. Variables that were unset beforehand are removed again.
    """
    config_path = os.path.join(os.path.dirname(__file__), 'fake.config')
    with _TemporaryConfiguration(config_path,
                                 internal_overrides={
                                     'project': 'test',
                                     'domain': 'development'
                                 }):
        with _utils.AutoDeletingTempDir("dir") as tmp_dir:
            literal_map = _type_helpers.pack_python_std_map_to_literal_map(
                {'a': 9},
                _type_map_from_variable_map(
                    _task_defs.add_one.interface.inputs))

            input_dir = os.path.join(tmp_dir.name, "1")
            os.mkdir(input_dir)  # auto cleanup will take this subdir into account

            input_file = os.path.join(input_dir, "inputs.pb")
            _utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file)

            # construct indexlookup.pb which has array: [1]
            mapped_index = _literals.Literal(
                _literals.Scalar(primitive=_literals.Primitive(integer=1)))
            index_lookup_collection = _literals.LiteralCollection(
                [mapped_index])
            index_lookup_file = os.path.join(tmp_dir.name, "indexlookup.pb")
            _utils.write_proto_to_file(index_lookup_collection.to_flyte_idl(),
                                       index_lookup_file)

            # Fake an array job through environment variables, remembering the
            # previous values (None when unset) so they can be restored.
            env_overrides = {
                'BATCH_JOB_ARRAY_INDEX_VAR_NAME': 'AWS_BATCH_JOB_ARRAY_INDEX',
                'AWS_BATCH_JOB_ARRAY_INDEX': '0',
            }
            saved_env = {name: os.environ.get(name) for name in env_overrides}
            os.environ.update(env_overrides)
            try:
                execute_task(_task_defs.add_one.task_module,
                             _task_defs.add_one.task_function_name,
                             tmp_dir.name, tmp_dir.name, False)

                raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std(
                    _literal_models.LiteralMap.from_flyte_idl(
                        _utils.load_proto_from_file(
                            _literals_pb2.LiteralMap,
                            os.path.join(input_dir, _constants.OUTPUT_FILE_NAME))),
                    _type_map_from_variable_map(
                        _task_defs.add_one.interface.outputs))
                assert raw_map['b'] == 10
                assert len(raw_map) == 1
            finally:
                # Restore exactly what was there before, including "unset" and
                # empty-string values.
                for name, previous in saved_env.items():
                    if previous is None:
                        os.environ.pop(name, None)
                    else:
                        os.environ[name] = previous
Ejemplo n.º 9
0
def test_infer_sdk_type_from_literal():
    """String literals infer the String SDK type; void literals infer Void."""
    infer = _type_helpers.infer_sdk_type_from_literal

    string_literal = _literals.Literal(scalar=_literals.Scalar(
        primitive=_literals.Primitive(string_value="abc")))
    assert infer(string_literal) == _sdk_types.Types.String

    void_literal = _literals.Literal(scalar=_literals.Scalar(none_type=_literals.Void()))
    assert infer(void_literal) is _base_sdk_types.Void
Ejemplo n.º 10
0
def test_model_promotion():
    """A literal integer collection promotes to a List(Integer) SDK value."""
    values = [0, 1, 2]
    list_type = containers.List(primitives.Integer)
    model = literals.Literal(collection=literals.LiteralCollection(literals=[
        literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=v)))
        for v in values
    ]))

    promoted = list_type.promote_from_model(model)
    assert len(promoted.collection.literals) == 3
    # Elements are promoted to SDK Integer instances, not left as raw models.
    assert isinstance(promoted.collection.literals[0], primitives.Integer)
    assert promoted == list_type.from_python_std(list(values))
    assert promoted == list_type([primitives.Integer(v) for v in values])
Ejemplo n.º 11
0
def test_launch_workflow_with_args(flyteclient, flyte_workflows_register):
    """Launch my_wf with a=10, b="foobar" and check the raw output literals."""
    fetched = launch_plan.FlyteLaunchPlan.fetch(
        PROJECT, "development", "workflows.basic.basic_workflow.my_wf",
        f"v{VERSION}")
    int_input = literals.Literal(literals.Scalar(literals.Primitive(integer=10)))
    str_input = literals.Literal(
        literals.Scalar(literals.Primitive(string_value="foobar")))
    execution = fetched.launch_with_literals(
        PROJECT,
        "development",
        literals.LiteralMap({"a": int_input, "b": str_input}),
    )
    execution.wait_for_completion()

    o0 = execution.outputs.literals["o0"]
    assert o0.scalar.primitive.integer == 12
    o1 = execution.outputs.literals["o1"]
    assert o1.scalar.primitive.string_value == "foobarworld"
Ejemplo n.º 12
0
def test_binding_data_collection():
    """A BindingDataCollection keeps its ordered values across an IDL round trip."""

    def _int_binding(number):
        return literals.BindingData(scalar=literals.Scalar(
            primitive=literals.Primitive(integer=number)))

    collection = literals.BindingDataCollection(bindings=[_int_binding(5), _int_binding(57)])
    original = literals.BindingData(collection=collection)

    def _check(data):
        # Only the collection alternative is set, holding 5 then 57.
        assert data.scalar is None
        assert data.promise is None
        assert data.collection is not None
        assert data.map is None
        assert data.value.bindings[0].value.value.value == 5
        assert data.value.bindings[1].value.value.value == 57

    _check(original)
    restored = literals.BindingData.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
    _check(restored)
Ejemplo n.º 13
0
def test_construct_literal_map_from_variable_map():
    """Text input '15' for an INTEGER variable parses to an integer scalar."""
    int_type = types.LiteralType(simple=types.SimpleType.INTEGER)
    variables = {'inputa': Variable(type=int_type, description="some description")}

    parsed = helpers.construct_literal_map_from_variable_map(variables, {'inputa': '15'})

    expected = literals.Scalar(primitive=literals.Primitive(integer=15))
    assert parsed.literals['inputa'].value == expected
Ejemplo n.º 14
0
def test_branch_node():
    """Build an if / elif / else BranchNode around a task node and check it round-trips.

    The original version only constructed the model and asserted nothing; this
    test now verifies the IDL round trip and the shape of the if-else block.
    """
    nm = _get_sample_node_metadata()
    task = _workflow.TaskNode(reference_id=_generic_id)
    bd = _literals.BindingData(scalar=_literals.Scalar(
        primitive=_literals.Primitive(integer=5)))
    bd2 = _literals.BindingData(scalar=_literals.Scalar(
        primitive=_literals.Primitive(integer=99)))
    binding = _literals.Binding(var='myvar', binding=bd)
    binding2 = _literals.Binding(var='myothervar', binding=bd2)

    obj = _workflow.Node(id='some:node:id',
                         metadata=nm,
                         inputs=[binding, binding2],
                         upstream_node_ids=[],
                         output_aliases=[],
                         task_node=task)

    def _five_equals_two():
        # A fresh `5 == 2` boolean comparison expression.
        return _condition.BooleanExpression(
            comparison=_condition.ComparisonExpression(
                _condition.ComparisonExpression.Operator.EQ,
                _condition.Operand(primitive=_literals.Primitive(integer=5)),
                _condition.Operand(primitive=_literals.Primitive(integer=2))))

    bn = _workflow.BranchNode(
        _workflow.IfElseBlock(
            case=_workflow.IfBlock(condition=_five_equals_two(), then_node=obj),
            other=[
                _workflow.IfBlock(
                    condition=_condition.BooleanExpression(
                        conjunction=_condition.ConjunctionExpression(
                            _condition.ConjunctionExpression.LogicalOperator.AND,
                            _five_equals_two(),
                            _five_equals_two())),
                    then_node=obj)
            ],
            else_node=obj))

    bn2 = _workflow.BranchNode.from_flyte_idl(bn.to_flyte_idl())
    assert bn == bn2
    assert bn2.if_else.case.then_node == obj
    assert len(bn2.if_else.other) == 1
    assert bn2.if_else.other[0].then_node == obj
    assert bn2.if_else.else_node == obj
Ejemplo n.º 15
0
def test_binding_data_scalar():
    """A scalar BindingData keeps its value and leaves other kinds unset after a round trip."""
    original = literals.BindingData(scalar=literals.Scalar(
        primitive=literals.Primitive(integer=5)))

    def _check(data):
        assert data.value.value.value == 5
        for attr in ("promise", "collection", "map"):
            assert getattr(data, attr) is None

    _check(original)
    restored = literals.BindingData.from_flyte_idl(original.to_flyte_idl())
    assert original == restored
    _check(restored)
Ejemplo n.º 16
0
def test_construct_literal_map_from_parameter_map():
    """A required INTEGER parameter parses from text and must be supplied."""
    int_type = types.LiteralType(simple=types.SimpleType.INTEGER)
    required_param = Parameter(var=Variable(type=int_type, description="some description"),
                               required=True)
    params = ParameterMap(parameters={"inputa": required_param})

    parsed = helpers.construct_literal_map_from_parameter_map(params, {"inputa": "15"})
    expected = literals.Scalar(primitive=literals.Primitive(integer=15))
    assert parsed.literals["inputa"].value == expected

    # Omitting a required parameter is an error.
    with pytest.raises(Exception):
        helpers.construct_literal_map_from_parameter_map(params, {})
Ejemplo n.º 17
0
def test_lp_serialize():
    """Serialized launch plans expose required inputs and LP defaults correctly."""

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        bumped = a + 2
        return bumped, "world-" + str(bumped)

    @task
    def t2(a: str, b: str) -> str:
        return b + a

    @workflow
    def my_subwf(a: int) -> (str, str):
        first_int, first_str = t1(a=a)
        _, second_str = t1(a=first_int)
        return first_str, second_str

    plain_lp = launch_plan.LaunchPlan.create("serialize_test1", my_subwf)
    defaulted_lp = launch_plan.LaunchPlan.create("serialize_test2", my_subwf,
                                                 default_inputs={"a": 3})

    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
        env={},
    )

    # Without defaults, `a` is a required parameter.
    serialized = get_serializable(OrderedDict(), settings, plain_lp)
    assert len(serialized.default_inputs.parameters) == 1
    assert serialized.default_inputs.parameters["a"].required
    assert len(serialized.fixed_inputs.literals) == 0

    # With an LP default, `a` is optional and carries the default literal.
    serialized = get_serializable(OrderedDict(), settings, defaulted_lp)
    assert len(serialized.default_inputs.parameters) == 1
    assert not serialized.default_inputs.parameters["a"].required
    expected_default = _literal_models.Literal(scalar=_literal_models.Scalar(
        primitive=_literal_models.Primitive(integer=3)))
    assert serialized.default_inputs.parameters["a"].default == expected_default
    assert len(serialized.fixed_inputs.literals) == 0

    # Adding a check to make sure oneof is respected. Tricky with booleans... if a default is specified, the
    # required field needs to be None, not False.
    reparsed = Parameter.from_flyte_idl(serialized.default_inputs.parameters["a"].to_flyte_idl())
    assert reparsed.default is not None
Ejemplo n.º 18
0
def test_scalar_primitive():
    """A float primitive Scalar keeps its value, and other value kinds stay unset."""
    scalar = literals.Scalar(primitive=literals.Primitive(float_value=5.6))

    def _only_primitive_set(candidate):
        # Every value kind other than `primitive` must be None.
        for attr in ("error", "blob", "binary", "schema", "none_type"):
            assert getattr(candidate, attr) is None

    assert scalar.value.value == 5.6
    _only_primitive_set(scalar)

    idl = scalar.to_flyte_idl()
    assert idl.primitive.float_value == 5.6

    restored = literals.Scalar.from_flyte_idl(scalar.to_flyte_idl())
    assert restored == scalar
    _only_primitive_set(restored)
Ejemplo n.º 19
0
    def setUp(self):
        """Build the task input map, an engine context and a distributed training task."""
        with _utils.AutoDeletingTempDir("input_dir") as tmp:
            one = _literals.Literal(scalar=_literals.Scalar(
                primitive=_literals.Primitive(integer=1)))
            self._task_input = _literals.LiteralMap({"input_1": one})

            exec_id = WorkflowExecutionIdentifier(project="unit_test",
                                                  domain="unit_test",
                                                  name="unit_test")
            # NOTE(review): tmp_dir points inside a context-managed temp dir that
            # is exited at the end of setUp. Confirm AutoDeletingTempDir's
            # cleanup does not leave tests with a removed directory.
            self._context = _common_engine.EngineContext(
                execution_id=exec_id,
                execution_date=_datetime.datetime.utcnow(),
                stats=MockStats(),
                logging=None,
                tmp_dir=tmp.name,
            )

            resource_config = TrainingJobResourceConfig(
                instance_type="ml.m4.xlarge",
                instance_count=2,
                volume_size_in_gb=25,
            )
            algo_spec = AlgorithmSpecification(
                input_mode=InputMode.FILE,
                input_content_type=InputContentType.TEXT_CSV,
                metric_definitions=[
                    MetricDefinition(name="Validation error",
                                     regex="validation:error")
                ],
            )

            # Distributed training task with no output-persist predicate, so the
            # default predicate is used.
            @inputs(input_1=Types.Integer)
            @outputs(model=Types.Blob)
            @custom_training_job_task(
                training_job_resource_config=resource_config,
                algorithm_specification=algo_spec,
            )
            def my_distributed_task(wf_params, input_1, model):
                pass

            self._my_distributed_task = my_distributed_task
            assert type(self._my_distributed_task) == CustomTrainingJobTask
Ejemplo n.º 20
0
def test_old_style_role():
    """A spec using the old-style ``auth`` field still yields an auth_role on parse."""
    wf_id = identifier.Identifier(identifier.ResourceType.TASK,
                                  "project", "domain", "name", "version")
    lp_metadata = launch_plan.LaunchPlanMetadata(
        schedule=schedule.Schedule("asdf", "1 3 4 5 6 7"), notifications=[])

    bool_var = interface.Variable(
        types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf")
    default_inputs = interface.ParameterMap({"ppp": interface.Parameter(var=bool_var)})

    one = literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))
    fixed_inputs = literals.LiteralMap({"a": one})

    # Build the raw protobuf by hand so the old-style `auth` field is populated
    # (rather than `auth_role`).
    legacy_spec = _launch_plan_idl.LaunchPlanSpec(
        workflow_id=wf_id.to_flyte_idl(),
        entity_metadata=lp_metadata.to_flyte_idl(),
        default_inputs=default_inputs.to_flyte_idl(),
        fixed_inputs=fixed_inputs.to_flyte_idl(),
        labels=common.Labels({}).to_flyte_idl(),
        annotations=common.Annotations({"my": "annotation"}).to_flyte_idl(),
        raw_output_data_config=common.RawOutputDataConfig("s3://bucket").to_flyte_idl(),
        auth=_launch_plan_idl.Auth(kubernetes_service_account="my:service:account"),
    )

    parsed = launch_plan.LaunchPlanSpec.from_flyte_idl(legacy_spec)

    # The old-style service account is expected to surface as assumable_iam_role.
    assert parsed.auth_role.assumable_iam_role == "my:service:account"
Ejemplo n.º 21
0
def test_boolean_primitive():
    """Primitive(boolean=True) exposes only `boolean` and is distinct from other primitives."""

    def _distinct_primitives():
        # Primitives that must never compare equal to Primitive(boolean=True).
        return [
            literals.Primitive(integer=0),
            literals.Primitive(boolean=False),
            literals.Primitive(datetime=datetime.now()),
            literals.Primitive(duration=timedelta(minutes=1)),
            literals.Primitive(float_value=1.0),
            literals.Primitive(string_value="abc"),
        ]

    def _check(prim):
        assert prim.integer is None
        assert prim.boolean is True
        assert prim.datetime is None
        assert prim.duration is None
        assert prim.float_value is None
        assert prim.string_value is None
        assert prim.value is True
        for other in _distinct_primitives():
            assert prim != other

    truthy = literals.Primitive(boolean=True)
    _check(truthy)

    restored = literals.Primitive.from_flyte_idl(truthy.to_flyte_idl())
    assert truthy == restored
    _check(restored)

    falsy = literals.Primitive(boolean=False)
    assert falsy.value is False

    # A non-bool value in the boolean slot must fail to serialize.
    with pytest.raises(Exception):
        literals.Primitive(boolean=datetime.now()).to_flyte_idl()
Ejemplo n.º 22
0
def test_datetime_primitive():
    """Primitive(datetime=...) exposes only `datetime` and round-trips through IDL.

    The timestamp is created timezone-aware with ``datetime.now(pytz.UTC)``,
    which gives the same value as ``utcnow().replace(tzinfo=UTC)`` without
    calling ``datetime.utcnow()`` (deprecated since Python 3.12).
    """
    dt = datetime.now(pytz.UTC)

    def _check(prim):
        assert prim.integer is None
        assert prim.boolean is None
        assert prim.datetime == dt
        assert prim.duration is None
        assert prim.float_value is None
        assert prim.string_value is None
        assert prim.value == dt
        assert prim != literals.Primitive(integer=0)
        assert prim != literals.Primitive(boolean=False)
        assert prim != literals.Primitive(datetime=dt + timedelta(seconds=1))
        assert prim != literals.Primitive(duration=timedelta(minutes=1))
        assert prim != literals.Primitive(float_value=1.0)
        assert prim != literals.Primitive(string_value="abc")

    obj = literals.Primitive(datetime=dt)
    _check(obj)

    obj2 = literals.Primitive.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
    _check(obj2)

    # A non-datetime value in the datetime slot must fail to serialize.
    with pytest.raises(Exception):
        literals.Primitive(datetime=1.0).to_flyte_idl()
Ejemplo n.º 23
0
 def __init__(self, value):
     """
     Wrap an integer as the primitive of a scalar literal.

     :param int value: Int value to wrap
     """
     primitive = _literals.Primitive(integer=value)
     super(Integer, self).__init__(scalar=_literals.Scalar(primitive=primitive))
Ejemplo n.º 24
0
 def __init__(self, value):
     """
     Wrap a duration as the primitive of a scalar literal.

     :param datetime.timedelta value: value to wrap
     """
     primitive = _literals.Primitive(duration=value)
     super(Timedelta, self).__init__(scalar=_literals.Scalar(primitive=primitive))
Ejemplo n.º 25
0
 def __init__(self, value):
     """
     Wrap a datetime as the primitive of a scalar literal.

     :param datetime.datetime value: value to wrap
     """
     primitive = _literals.Primitive(datetime=value)
     super(Datetime, self).__init__(scalar=_literals.Scalar(primitive=primitive))
Ejemplo n.º 26
0
 def __init__(self, value):
     """
     Wrap a text value as the string primitive of a scalar literal.

     :param Text value: value to wrap
     """
     primitive = _literals.Primitive(string_value=value)
     super(String, self).__init__(scalar=_literals.Scalar(primitive=primitive))
Ejemplo n.º 27
0
def test_workflow_closure():
    """A WorkflowClosure with one task and a one-node template round-trips through IDL.

    The template's interface declares outputs ``b`` and ``c``; each output
    binding uses the matching variable name. Previously both bindings were
    named 'b', which duplicated one output and left ``c`` unbound.
    """
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    typed_interface = _interface.TypedInterface(
        {'a': _interface.Variable(int_type, "description1")}, {
            'b': _interface.Variable(int_type, "description2"),
            'c': _interface.Variable(int_type, "description3")
        })

    b0 = _literals.Binding(
        'a',
        _literals.BindingData(scalar=_literals.Scalar(
            primitive=_literals.Primitive(integer=5))))
    b1 = _literals.Binding(
        'b',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'b')))
    b2 = _literals.Binding(
        'c',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'c')))

    node_metadata = _workflow.NodeMetadata(name='node1',
                                           timeout=timedelta(seconds=10),
                                           retries=_literals.RetryStrategy(0))

    task_metadata = _task.TaskMetadata(
        True,
        _task.RuntimeMetadata(_task.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                              "1.0.0", "python"), timedelta(days=1),
        _literals.RetryStrategy(3), "0.1.1b0", "This is deprecated!")

    cpu_resource = _task.Resources.ResourceEntry(
        _task.Resources.ResourceName.CPU, "1")
    resources = _task.Resources(requests=[cpu_resource], limits=[cpu_resource])

    task = _task.TaskTemplate(
        _identifier.Identifier(_identifier.ResourceType.TASK, "project",
                               "domain", "name", "version"),
        "python",
        task_metadata,
        typed_interface, {
            'a': 1,
            'b': {
                'c': 2,
                'd': 3
            }
        },
        container=_task.Container("my_image", ["this", "is", "a", "cmd"],
                                  ["this", "is", "an", "arg"], resources, {},
                                  {}))

    task_node = _workflow.TaskNode(task.id)
    node = _workflow.Node(id='my_node',
                          metadata=node_metadata,
                          inputs=[b0],
                          upstream_node_ids=[],
                          output_aliases=[],
                          task_node=task_node)

    template = _workflow.WorkflowTemplate(
        id=_identifier.Identifier(_identifier.ResourceType.WORKFLOW, "project",
                                  "domain", "name", "version"),
        metadata=_workflow.WorkflowMetadata(),
        interface=typed_interface,
        nodes=[node],
        outputs=[b1, b2],
    )

    obj = _workflow_closure.WorkflowClosure(workflow=template, tasks=[task])
    assert len(obj.tasks) == 1
    assert obj.tasks[0] == task
    assert obj.workflow == template

    obj2 = _workflow_closure.WorkflowClosure.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
Ejemplo n.º 28
0
 def __init__(self, value):
     """
     Wrap a float as the primitive of a scalar literal.

     :param float value: value to wrap
     """
     primitive = _literals.Primitive(float_value=value)
     super(Float, self).__init__(scalar=_literals.Scalar(primitive=primitive))
Ejemplo n.º 29
0
import pytest

from flytekit.models import common as _common_models
from flytekit.models import execution as _execution
from flytekit.models import literals as _literals
from flytekit.models.core import execution as _core_exec
from flytekit.models.core import identifier as _identifier
from tests.flytekit.common import parameterizers as _parameterizers

def _single_int_literal_map(name, number):
    """Return a LiteralMap holding one integer scalar literal under ``name``."""
    primitive = _literals.Primitive(integer=number)
    return _literals.LiteralMap({
        name: _literals.Literal(scalar=_literals.Scalar(primitive=primitive))
    })


# Canned execution I/O: {"a": 1} as inputs, {"b": 2} as outputs.
_INPUT_MAP = _single_int_literal_map("a", 1)
_OUTPUT_MAP = _single_int_literal_map("b", 2)


def test_execution_metadata():
    """ExecutionMetadata keeps mode, principal and nesting across an IDL round trip."""
    obj = _execution.ExecutionMetadata(
        _execution.ExecutionMetadata.ExecutionMode.MANUAL, "tester", 1)
    assert obj.mode == _execution.ExecutionMetadata.ExecutionMode.MANUAL
    assert obj.principal == "tester"
    assert obj.nesting == 1
    obj2 = _execution.ExecutionMetadata.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
    assert obj2.mode == _execution.ExecutionMetadata.ExecutionMode.MANUAL
    assert obj2.principal == "tester"
    # Check nesting after the round trip too, mirroring the pre-round-trip checks.
    assert obj2.nesting == 1
Ejemplo n.º 30
0
 def __init__(self, value):
     """
     Wrap a bool as the primitive of a scalar literal.

     :param bool value: value to wrap
     """
     primitive = _literals.Primitive(boolean=value)
     super(Boolean, self).__init__(scalar=_literals.Scalar(primitive=primitive))