예제 #1
0
def test_type_resolution():
    """Each python type must resolve to exactly the listed transformer class."""
    # Parameterised typing aliases, bare typing aliases and builtins are all
    # covered. `type(...) is cls` demands the exact class, not a subclass.
    expected = [
        (typing.List[int], ListTransformer),
        (typing.List, ListTransformer),
        (list, ListTransformer),
        (typing.Dict[str, int], DictTransformer),
        (typing.Dict, DictTransformer),
        (dict, DictTransformer),
        (int, SimpleTransformer),
        (os.PathLike, PathLikeTransformer),
    ]
    for python_type, transformer_cls in expected:
        assert type(TypeEngine.get_transformer(python_type)) is transformer_cls
예제 #2
0
def _change_unrecognized_type_to_pickle(t: Type[T]) -> Type[T]:
    """Replace a python type Flyte has no transformer for with ``FlytePickle[t]``.

    ``List[X]`` and ``Dict[str, X]`` are rebuilt with their element/value type
    rewritten recursively, so only the unsupported inner type gets pickled.
    Every other type is probed with ``TypeEngine.get_transformer``. This
    includes generics not handled by those two cases, such as
    ``Dict[int, X]`` or ``Tuple[...]``. A ``ValueError`` from that probe is
    treated as "unsupported".

    :param t: the python type to check.
    :return: ``t`` itself when a transformer exists, a rebuilt
        ``List``/``Dict`` for the container cases, or ``FlytePickle[t]``
        when no transformer is found.
    """
    try:
        if hasattr(t, "__origin__") and hasattr(t, "__args__"):
            origin, args = t.__origin__, t.__args__
            if origin is list:
                return typing.List[_change_unrecognized_type_to_pickle(args[0])]
            if origin is dict and args[0] == str:
                return typing.Dict[str, _change_unrecognized_type_to_pickle(args[1])]
        # The probe must also run for generics that fell through the container
        # cases above. Previously it sat in an `else` branch, so e.g.
        # Dict[int, Foo] or Tuple[Foo] were returned unchecked and never got
        # the pickle fallback.
        TypeEngine.get_transformer(t)
    except ValueError:
        logger.warning(
            f"Unsupported Type {t} found, Flyte will default to use PickleFile as the transport. "
            f"Pickle can only be used to send objects between the exact same version of Python, "
            f"and we strongly recommend to use python type that flyte support."
        )
        return FlytePickle[t]
    return t
예제 #3
0
def test_file_format_getting_python_value():
    """A txt blob literal converts back into a FlyteFile with a "txt" extension."""
    transformer = TypeEngine.get_transformer(FlyteFile)
    ctx = FlyteContext.current_context()

    # The URI probably doesn't exist, which is fine: per the original author,
    # nothing is downloaded unless the returned FlyteFile is actually read.
    blob_type = BlobType(format="txt", dimensionality=0)
    blob = Blob(metadata=BlobMetadata(type=blob_type), uri="file:///tmp/test")
    lv = Literal(scalar=Scalar(blob=blob))

    pv = transformer.to_python_value(ctx, lv, expected_python_type=FlyteFile["txt"])
    assert isinstance(pv, FlyteFile)
    assert pv.extension() == "txt"
예제 #4
0
def test_file_guess():
    """FlyteFile types round-trip through literal type and back via guessing."""
    transformer = TypeEngine.get_transformer(FlyteFile)

    # Checked in order: first with a declared "txt" format, then with none.
    cases = ((FlyteFile["txt"], "txt"), (FlyteFile, ""))
    for python_type, fmt in cases:
        lt = transformer.get_literal_type(python_type)
        assert lt.blob.format == fmt
        assert lt.blob.dimensionality == 0

        guessed = transformer.guess_python_type(lt)
        assert issubclass(guessed, FlyteFile)
        assert guessed.extension() == fmt
예제 #5
0
def test_file_formats_getting_literal_type():
    """FlyteFile's literal blob format follows the format it is parameterised with."""
    transformer = TypeEngine.get_transformer(FlyteFile)

    # (python type, expected blob format), checked in this order.
    cases = [
        # Bare FlyteFile carries an empty format.
        (FlyteFile, ""),
        # Format given as a plain string.
        (FlyteFile["txt"], "txt"),
        # Format given as a TypeVar name.
        (FlyteFile[typing.TypeVar("jpg")], "jpg"),
        # Bare FlyteFile re-checked after the parameterised lookups above:
        # it must still report the empty default format.
        (FlyteFile, ""),
        # A leading dot in the TypeVar name does not end up in the format.
        (FlyteFile[typing.TypeVar(".png")], "png"),
    ]
    for python_type, expected_format in cases:
        lt = transformer.get_literal_type(python_type)
        assert lt.blob.format == expected_format
예제 #6
0
def test_type_resolution():
    """PipelineModel must resolve to exactly a PySparkPipelineModelTransformer."""
    # NOTE(review): another test in this listing is also named
    # `test_type_resolution`; if both end up in one module, the later
    # definition shadows the earlier one under pytest -- rename if so.
    resolved = TypeEngine.get_transformer(PipelineModel)
    assert type(resolved) is PySparkPipelineModelTransformer