Exemplo n.º 1
0
def test_ref_sub_wf():
    """Check how a workflow reference entity behaves locally, in compilation and in serialization.

    1. Called outside any compilation context, it raises an error telling the
       user to mock it out.
    2. Called inside a compilation context, it raises when inputs are missing.
       With inputs supplied it returns a ``VoidPromise`` (the reference is
       declared with ``outputs={}``).
    3. Serializing a workflow that calls it as a sub-workflow raises.
    """
    ref_entity = get_reference_entity(
        _identifier_model.ResourceType.WORKFLOW,
        "proj",
        "dom",
        "app.other.sub_wf",
        "123",
        inputs=kwtypes(a=str, b=int),
        outputs={},
    )

    ctx = context_manager.FlyteContext.current_context()
    # Inspect the exception itself via ``e.value``. Since pytest 5, ``str(e)`` on an
    # ExceptionInfo is its repr, which truncates long messages.
    with pytest.raises(Exception) as e:
        ref_entity()
    assert "You must mock this out" in str(e.value)

    with context_manager.FlyteContextManager.with_context(ctx.with_new_compilation_state()):
        with pytest.raises(Exception) as e:
            ref_entity()
        assert "Input was not specified" in str(e.value)

        output = ref_entity(a="hello", b=3)
        assert isinstance(output, VoidPromise)

    @workflow
    def wf1(a: str, b: int):
        ref_entity(a=a, b=b)

    serialization_settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    # Subworkflows as references don't work (probably ever). Serializing one would
    # need a network call to admin to fetch the sub-workflow's structure, and the
    # whole point of reference entities is that there is no network call.
    with pytest.raises(Exception):
        get_serializable(OrderedDict(), serialization_settings, wf1)
Exemplo n.º 2
0
def test_container():
    """Fast-serializing a raw ``ContainerTask`` must not inject pyflyte arguments.

    The container task runs ``cat /tmp/a`` in an ``alpine`` image. Its serialized
    container args must not contain any pyflyte entrypoint, even when
    ``fast=True`` is requested.
    """
    raw_task = ContainerTask(
        "raw",
        image="alpine",
        inputs=kwtypes(a=int, b=str),
        input_data_dir="/tmp",
        output_data_dir="/tmp",
        command=["cat"],
        arguments=["/tmp/a"],
        requests=Resources(mem="400Mi", cpu="1"),
    )

    task_spec = get_serializable(OrderedDict(), serialization_settings, raw_task, fast=True)
    # ``args`` is a list of strings. ``"pyflyte" not in args`` would only reject an
    # element exactly equal to "pyflyte". Check each argument for the substring so
    # entrypoints like "pyflyte-fast-execute" are caught too.
    container_args = task_spec.template.container.args or []
    assert not any("pyflyte" in arg for arg in container_args)
Exemplo n.º 3
0
def test_references():
    """Reference launch plans, tasks and workflows serialize to ``ReferenceSpec`` objects.

    Each reference kind is built with the same project/domain/name/version and
    interface, then serialized. The test checks that the result is a
    ``ReferenceSpec`` whose ``ReferenceTemplate`` carries the entity's reference
    id and the expected resource type.
    """
    resource_type = identifier_models.ResourceType
    expectations = [
        (ReferenceLaunchPlan, resource_type.LAUNCH_PLAN),
        (ReferenceTask, resource_type.TASK),
        (ReferenceWorkflow, resource_type.WORKFLOW),
    ]

    for reference_cls, expected_resource_type in expectations:
        reference = reference_cls(
            "media",
            "stg",
            "some.name",
            "cafe",
            inputs=kwtypes(in1=str),
            outputs=kwtypes(),
        )
        spec = get_serializable(OrderedDict(), serialization_settings, reference)

        assert isinstance(spec, ReferenceSpec)
        template = spec.template
        assert isinstance(template, ReferenceTemplate)
        assert template.id == reference.reference.id
        assert template.resource_type == expected_resource_type
import typing

import pandas as pd
import pyarrow as pa
import pytest

from flytekit.core import context_manager
from flytekit.core.base_task import kwtypes
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured import basic_dfs
from flytekit.types.structured.structured_dataset import (
    StructuredDataset,
    StructuredDatasetDecoder,
    StructuredDatasetEncoder,
)

# Column names mapped to Python types via ``kwtypes``, including nested
# ``typing`` generics. This requires ``import typing`` at module level.
my_cols = kwtypes(w=typing.Dict[str, typing.Dict[str, int]], x=typing.List[typing.List[int]], y=int, z=str)

# Arrow schema with an int32 field and a string field.
fields = [("some_int", pa.int32()), ("some_string", pa.string())]
arrow_schema = pa.schema(fields)


def test_pandas():
    df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
    encoder = basic_dfs.PandasToParquetEncodingHandler("/")
    decoder = basic_dfs.ParquetToPandasDecodingHandler("/")

    ctx = context_manager.FlyteContextManager.current_context()
    sd = StructuredDataset(dataframe=df)
    sd_type = StructuredDatasetType(format="parquet")
    sd_lit = encoder.encode(ctx, sd, sd_type)