def test_ref_plain_two_outputs():
    """Compile-time promise checks and a mocked local run for a two-output task reference."""
    ref_task = ReferenceEntity(
        TaskReference("proj", "domain", "some.name", "abc"),
        inputs=kwtypes(a=str, b=int),
        outputs=kwtypes(x=bool, y=int),
    )

    # Inside a compilation context a call yields promises attached to a single node.
    flyte_ctx = context_manager.FlyteContext.current_context()
    with flyte_ctx.new_compilation_context():
        out_x, out_y = ref_task(a="five", b=6)
        # Note - misnomer, these are not SdkNodes, they are annotated.Nodes
        assert out_x.ref.node is out_y.ref.node
        assert out_x.var == "x"
        assert out_y.var == "y"
        assert out_x.ref.node_id == "n0"
        assert len(out_x.ref.node.bindings) == 2

    @task
    def t2(q: bool, r: int) -> str:
        return f"q: {q} r: {r}"

    @workflow
    def wf1(a: str, b: int) -> str:
        x_out, y_out = ref_task(a=a, b=b)
        return t2(q=x_out, r=y_out)

    @patch(ref_task)
    def inner_test(ref_mock):
        # The mock stands in for the remote task during local workflow execution.
        ref_mock.return_value = (False, 30)
        assert wf1(a="hello", b=10) == "q: False r: 30"

    inner_test()
# --- Example #2 ---
    def __init__(
        self, name: str, task_config: SagemakerTrainingJobConfig, **kwargs,
    ):
        """
        Initialize a SageMaker builtin-algorithm training task.

        Args:
            name: name of this specific task. This should be unique within the project. A good strategy is to prefix
                  with the module name
            task_config: Config to use for the SagemakerBuiltinAlgorithms; its algorithm_specification and
                         training_job_resource_config must both be provided
            **kwargs: forwarded unchanged to the parent task constructor (e.g. metadata)

        Raises:
            ValueError: if task_config or either of its required sub-configs is missing
        """
        # Fail fast: a training job cannot be described without both sub-configs.
        if (
            task_config is None
            or task_config.algorithm_specification is None
            or task_config.training_job_resource_config is None
        ):
            raise ValueError("TaskConfig, algorithm_specification, training_job_resource_config are required")

        # The blob format for the train/validation inputs is derived from the algorithm's
        # declared input content type and used as the FlyteDirectory type parameter below.
        input_type = TypeVar(self._content_type_to_blob_format(task_config.algorithm_specification.input_content_type))

        interface = Interface(
            # TODO change train and validation to be FlyteDirectory when available
            inputs=kwtypes(
                static_hyperparameters=dict, train=FlyteDirectory[input_type], validation=FlyteDirectory[input_type]
            ),
            outputs=kwtypes(model=FlyteFile[self.OUTPUT_TYPE]),
        )
        super().__init__(
            self._SAGEMAKER_TRAINING_JOB_TASK, name, interface=interface, task_config=task_config, **kwargs,
        )
def test_lp_with_output():
    """Mocked local run plus serialization checks for a two-output launch-plan reference."""
    launch_plan_ref = get_reference_entity(
        _identifier_model.ResourceType.LAUNCH_PLAN,
        "proj",
        "dom",
        "app.other.flyte_entity",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs=kwtypes(x=bool, y=int),
    )

    @task
    def t1() -> (str, int):
        return "hello", 88

    @task
    def t2(q: bool, r: int) -> str:
        return f"q: {q} r: {r}"

    @workflow
    def wf1() -> str:
        t1_str, t1_int = t1()
        x_out, y_out = launch_plan_ref(a=t1_str, b=t1_int)
        return t2(q=x_out, r=y_out)

    @patch(launch_plan_ref)
    def inner_test(ref_mock):
        ref_mock.return_value = (False, 30)
        result = wf1()
        assert result == "q: False r: 30"
        # t1's outputs must be the arguments the mocked reference receives.
        ref_mock.assert_called_with(a="hello", b=88)

    inner_test()

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    sdk_wf = get_serializable(settings, wf1)
    # Node 0 is t1; node 1 is the launch-plan reference.
    lp_ref = sdk_wf.nodes[1].workflow_node.launchplan_ref
    assert lp_ref.project == "proj"
    assert lp_ref.name == "app.other.flyte_entity"
# --- Example #4 ---
def test_references():
    """Every kind of reference entity serializes as an already-registered object."""
    # The three stanzas differ only in the reference class; iterate instead of repeating.
    for reference_cls in (ReferenceLaunchPlan, ReferenceTask, ReferenceWorkflow):
        ref = reference_cls(
            "media",
            "stg",
            "some.name",
            "cafe",
            inputs=kwtypes(in1=str),
            outputs=kwtypes(),
        )
        serialized = get_serializable(serialization_settings, ref)
        assert serialized.has_registered
# --- Example #5 ---
def test_container():
    """A raw ContainerTask serializes without the pyflyte entrypoint in its args."""

    @task
    def t1(a: int) -> (int, str):
        return a + 2, str(a) + "-HELLO"

    raw_container_task = ContainerTask(
        "raw",
        image="alpine",
        inputs=kwtypes(a=int, b=str),
        input_data_dir="/tmp",
        output_data_dir="/tmp",
        command=["cat"],
        arguments=["/tmp/a"],
    )

    serialized = get_serializable(serialization_settings, raw_container_task, fast=True)
    # Raw containers run their own command; flytekit's entrypoint must not be injected.
    assert "pyflyte" not in serialized.container.args
def test_ref_plain_no_outputs():
    """A reference entity without outputs behaves as a void node and can be mocked."""
    no_output_ref = ReferenceEntity(
        TaskReference("proj", "domain", "some.name", "abc"),
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    # Reference entities should always raise an exception when not mocked out.
    with pytest.raises(Exception) as e:
        no_output_ref(a="fdsa", b=3)
    assert "You must mock this out" in f"{e}"

    @workflow
    def wf1(a: str, b: int):
        no_output_ref(a=a, b=b)

    @patch(no_output_ref)
    def inner_test(ref_mock):
        ref_mock.return_value = None
        assert wf1(a="fdsa", b=3) is None

    inner_test()

    result_type = typing.NamedTuple("DummyNamedTuple", t1_int_output=int, c=str)

    @task
    def t1(a: int) -> result_type:
        a = a + 2
        return a, "world-" + str(a)

    @workflow
    def wf2(a: int):
        t1_int, c = t1(a=a)
        no_output_ref(a=c, b=t1_int)

    @patch(no_output_ref)
    def inner_test2(ref_mock):
        ref_mock.return_value = None
        assert wf2(a=3) is None
        ref_mock.assert_called_with(a="world-5", b=5)

    inner_test2()

    # Test nodes: the reference node must depend on t1's node.
    ref_node = wf2._nodes[1]
    assert ref_node._upstream_nodes[0] is wf2._nodes[0]
def test_lps(resource_type):
    """Reference entities of each resource type compile into the matching node kind."""
    entity = get_reference_entity(
        resource_type, "proj", "dom", "app.other.flyte_entity", "123", inputs=kwtypes(a=str, b=int), outputs={},
    )

    ctx = context_manager.FlyteContext.current_context()
    # Outside a compilation context, calling an unmocked reference must fail.
    with pytest.raises(Exception) as e:
        entity()
    assert "You must mock this out" in f"{e}"

    with ctx.new_compilation_context() as ctx:
        # Compilation still validates that every input was supplied.
        with pytest.raises(Exception) as e:
            entity()
        assert "Input was not specified" in f"{e}"

        promise = entity(a="hello", b=3)
        assert isinstance(promise, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        entity(a=a, b=b)

    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    sdk_wf = get_serializable(settings, wf1)
    assert len(sdk_wf.interface.inputs) == 2
    assert len(sdk_wf.interface.outputs) == 0
    assert len(sdk_wf.nodes) == 1

    # Each resource type serializes into a different node field; the identifier
    # checks are the same for all three, so pick the field then assert once.
    node = sdk_wf.nodes[0]
    if resource_type == _identifier_model.ResourceType.LAUNCH_PLAN:
        ref = node.workflow_node.launchplan_ref
    elif resource_type == _identifier_model.ResourceType.WORKFLOW:
        ref = node.workflow_node.sub_workflow_ref
    else:
        ref = node.task_node.reference_id
    assert ref.project == "proj"
    assert ref.name == "app.other.flyte_entity"
def test_serialization():
    """The Hive task and its wrapping workflow serialize with the expected query and interface."""
    hive_task = HiveTask(
        name="flytekit.demo.hive_task.hivequery1",
        inputs=kwtypes(my_schema=FlyteSchema, ds=str),
        config=HiveConfig(cluster_label="flyte"),
        query_template="""
            set engine=tez;
            insert overwrite directory '{{ .rawOutputDataPrefix }}' stored as parquet  -- will be unique per retry
            select *
            from blah
            where ds = '{{ .Inputs.ds }}' and uri = '{{ .inputs.my_schema }}'
        """,
        # the schema literal's backend uri will be equal to the value of .raw_output_data
        output_schema_type=FlyteSchema,
    )

    @workflow
    def my_wf(in_schema: FlyteSchema, ds: str) -> FlyteSchema:
        return hive_task(my_schema=in_schema, ds=ds)

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="proj",
        domain="dom",
        version="123",
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        env={},
    )

    serialized_task = get_serializable(settings, hive_task)
    query_text = serialized_task.custom["query"]["query"]
    assert "{{ .rawOutputDataPrefix" in query_text
    assert "insert overwrite directory" in query_text
    assert len(serialized_task.interface.inputs) == 2
    assert len(serialized_task.interface.outputs) == 1

    serialized_wf = get_serializable(settings, my_wf)
    # The single workflow output is bound to the hive task's "results" output on node n0.
    assert serialized_wf.interface.outputs["o0"].type.schema is not None
    assert serialized_wf.outputs[0].var == "o0"
    assert serialized_wf.outputs[0].binding.promise.node_id == "n0"
    assert serialized_wf.outputs[0].binding.promise.var == "results"