Ejemplo n.º 1
0
    def to_literal(self, ctx: FlyteContext, python_val: FlyteSchema,
                   python_type: Type[FlyteSchema],
                   expected: LiteralType) -> Literal:
        """Serialize ``python_val`` into a schema-backed ``Literal``.

        A ``FlyteSchema`` instance has its ``local_path`` uploaded to its
        ``remote_path``; a random remote path is requested when that is
        ``None`` or empty. Any other value is written through a freshly built
        ``python_type`` schema object and uploaded, unless the handler for its
        type reports ``handles_remote_io``.

        :param ctx: context whose ``file_access`` supplies paths and uploads.
        :param python_val: a ``FlyteSchema`` or a value a schema handler can write.
        :param python_type: schema class used to derive the schema type and,
            for non-``FlyteSchema`` values, to build the intermediate schema.
        :param expected: not used by this method.
        :raises TypeTransformerFailedError: when ``SchemaEngine.get_handler``
            raises ``ValueError`` for ``type(python_val)``.
        """
        if isinstance(python_val, FlyteSchema):
            destination = python_val.remote_path
            if destination is None or destination == "":
                destination = ctx.file_access.get_random_remote_path()
            ctx.file_access.put_data(
                python_val.local_path, destination, is_multipart=True)
            uploaded = Schema(destination, self._get_schema_type(python_type))
            return Literal(scalar=Scalar(schema=uploaded))

        # Not a FlyteSchema: stage the value in a brand-new schema object.
        staging_dir = ctx.file_access.get_random_local_directory()
        target_dir = ctx.file_access.get_random_remote_directory()
        schema = python_type(local_path=staging_dir, remote_path=target_dir)

        value_type = type(python_val)
        try:
            handler = SchemaEngine.get_handler(value_type)
        except ValueError as err:
            raise TypeTransformerFailedError(
                f"DataFrames of type {value_type} are not supported currently"
            ) from err

        schema.open(value_type).write(python_val)

        # Handlers that do their own remote IO have already persisted the data.
        if not h_handles_remote(handler):
            ctx.file_access.put_data(
                schema.local_path, schema.remote_path, is_multipart=True)

        written = Schema(schema.remote_path, self._get_schema_type(python_type))
        return Literal(scalar=Scalar(schema=written))
Ejemplo n.º 2
0
def test_enum_type():
    """Round-trip the ``Color`` enum through ``TypeEngine`` and check the rejection paths."""
    literal_type = TypeEngine.to_literal_type(Color)
    assert literal_type is not None
    assert literal_type.enum_type is not None
    assert literal_type.enum_type.values
    assert literal_type.enum_type.values == [member.value for member in Color]

    ctx = FlyteContextManager.current_context()
    red_literal = TypeEngine.to_literal(
        ctx,
        Color.RED,
        Color,
        TypeEngine.to_literal_type(Color),
    )
    assert red_literal
    assert red_literal.scalar
    assert red_literal.scalar.primitive.string_value == "red"

    # The same literal converts back both to the enum member and to a plain string.
    as_enum = TypeEngine.to_python_value(ctx, red_literal, Color)
    assert as_enum
    assert as_enum == Color.RED

    as_text = TypeEngine.to_python_value(ctx, red_literal, str)
    assert as_text
    assert as_text == "red"

    # Strings that are not enum values must be rejected.
    for raw in (str(Color.RED), "bad"):
        with pytest.raises(ValueError):
            TypeEngine.to_python_value(
                ctx,
                Literal(scalar=Scalar(primitive=Primitive(string_value=raw))),
                Color,
            )

    with pytest.raises(AssertionError):
        TypeEngine.to_literal_type(UnsupportedEnumValues)
Ejemplo n.º 3
0
    def to_literal(self, ctx: FlyteContext, python_val: FlyteSchema,
                   python_type: Type[FlyteSchema],
                   expected: LiteralType) -> Literal:
        """Serialize ``python_val`` into a schema-backed ``Literal``.

        A ``FlyteSchema`` instance has its ``local_path`` uploaded to its
        ``remote_path``; a random remote path is requested when that is
        ``None`` or empty. Any other value is written through a freshly built
        ``python_type`` schema object and uploaded, unless the handler for its
        type reports ``handles_remote_io``.

        :param ctx: context whose ``file_access`` supplies paths and uploads.
        :param python_val: a ``FlyteSchema`` or a value a schema handler can write.
        :param python_type: schema class used to derive the schema type and,
            for non-``FlyteSchema`` values, to build the intermediate schema.
        :param expected: not used by this method.
        """
        if isinstance(python_val, FlyteSchema):
            destination = python_val.remote_path
            if destination is None or destination == "":
                destination = ctx.file_access.get_random_remote_path()
            ctx.file_access.put_data(
                python_val.local_path, destination, is_multipart=True)
            uploaded = Schema(destination, self._get_schema_type(python_type))
            return Literal(scalar=Scalar(schema=uploaded))

        # Not a FlyteSchema: stage the value in a brand-new schema object.
        staging_dir = ctx.file_access.get_random_local_directory()
        target_dir = ctx.file_access.get_random_remote_directory()
        schema = python_type(local_path=staging_dir, remote_path=target_dir)

        value_type = type(python_val)
        schema.open(value_type).write(python_val)

        # Handlers that do their own remote IO have already persisted the data.
        handler = SchemaEngine.get_handler(value_type)
        if not handler.handles_remote_io:
            ctx.file_access.put_data(
                schema.local_path, schema.remote_path, is_multipart=True)

        written = Schema(schema.remote_path, self._get_schema_type(python_type))
        return Literal(scalar=Scalar(schema=written))
Ejemplo n.º 4
0
    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: FlyteDirectory,
        python_type: typing.Type[FlyteDirectory],
        expected: LiteralType,
    ) -> Literal:
        """Turn a ``FlyteDirectory`` or a directory path into a blob ``Literal``.

        Remote sources, and ``FlyteDirectory`` values whose ``remote_directory``
        is ``False``, are referenced as-is. Everything else is uploaded to the
        user-specified remote directory, or to a random one when none is given.

        :param ctx: context whose ``file_access`` performs remote checks and uploads.
        :param python_val: a ``FlyteDirectory``, ``os.PathLike`` or ``str``.
        :param python_type: type passed to ``get_format`` for the blob format.
        :param expected: not used by this method.
        :raises AssertionError: when the value has an unsupported type, or is a
            local path that is not a directory.
        """
        upload_target = None
        upload_allowed = True

        if isinstance(python_val, FlyteDirectory):
            source_path = python_val.path
            if python_val.remote_directory is False:
                # remote_directory=False means: never upload, no matter what.
                upload_allowed = False
            else:
                # An empty value falls back to a random remote directory later on.
                upload_target = python_val.remote_directory or None
        else:
            # Plain path case.
            if not isinstance(python_val, (os.PathLike, str)):
                raise AssertionError(
                    f"Expected FlyteDirectory or os.PathLike object, received {type(python_val)}"
                )
            source_path = python_val
            # The directory check only makes sense for local paths.
            if not ctx.file_access.is_remote(source_path) and not Path(source_path).is_dir():
                raise AssertionError(
                    f"Expected a directory. {source_path} is not a directory"
                )

        if ctx.file_access.is_remote(source_path) or not upload_allowed:
            # Remote values (e.g. s3://some/extant/dir/) are not copied into Flyte's
            # store; the literal simply points at the given path.
            uri = source_path
        else:
            # Local paths are uploaded to the Flyte store (for local execution the
            # remote store is just a subfolder).
            if upload_target is None:
                upload_target = ctx.file_access.get_random_remote_directory()
            ctx.file_access.put_data(source_path, upload_target, is_multipart=True)
            uri = upload_target

        blob_type = self._blob_type(format=self.get_format(python_type))
        blob = Blob(metadata=BlobMetadata(type=blob_type), uri=uri)
        return Literal(scalar=Scalar(blob=blob))
Ejemplo n.º 5
0
def test_zero_floats():
    """Both an integer zero and a float zero literal must convert to ``0`` as ``float``."""
    ctx = FlyteContext.current_context()

    int_zero = Literal(scalar=Scalar(primitive=Primitive(integer=0)))
    float_zero = Literal(scalar=Scalar(primitive=Primitive(float_value=0.0)))

    for zero_literal in (int_zero, float_zero):
        assert TypeEngine.to_python_value(ctx, zero_literal, float) == 0
Ejemplo n.º 6
0
    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: typing.Union[FlyteFile, os.PathLike, str],
        python_type: typing.Type[FlyteFile],
        expected: LiteralType,
    ) -> Literal:
        """Turn a ``FlyteFile`` or a file path into a blob ``Literal``.

        A ``FlyteFile`` that carries a ``_remote_source`` is converted straight
        back to that URI. Remote sources, and ``FlyteFile`` values whose
        ``remote_path`` is ``False``, are referenced as-is. Everything else is
        uploaded to the user-specified remote path, or to a random one.

        :param ctx: context whose ``file_access`` performs remote checks and uploads.
        :param python_val: a ``FlyteFile``, ``os.PathLike`` or ``str``.
        :param python_type: type passed to ``get_format`` for the blob format.
        :param expected: not used by this method.
        :raises AssertionError: when the value is ``None`` or has an unsupported type.
        """
        if python_val is None:
            raise AssertionError("None value cannot be converted to a file.")

        upload_target = None
        upload_allowed = True

        if isinstance(python_val, FlyteFile):
            # A file that originated remotely is simply pointed back at its source.
            if python_val._remote_source is not None:
                source_type = self._blob_type(format=self.get_format(python_type))
                source_blob = Blob(
                    metadata=BlobMetadata(type=source_type),
                    uri=python_val._remote_source,
                )
                return Literal(scalar=Scalar(blob=source_blob))

            source_path = python_val.path
            if python_val.remote_path is False:
                # remote_path=False means: never upload, no matter what.
                upload_allowed = False
            else:
                # An empty value falls back to a random remote path later on.
                upload_target = python_val.remote_path or None
        elif isinstance(python_val, (os.PathLike, str)):
            source_path = python_val
        else:
            raise AssertionError(
                f"Expected FlyteFile or os.PathLike object, received {type(python_val)}"
            )

        if ctx.file_access.is_remote(source_path) or not upload_allowed:
            # Remote values (e.g. https://raw.github.com/demo_data.csv) are not copied
            # into Flyte's store; the literal simply points at the given path.
            # TODO: Add copying functionality so that FlyteFile(path="s3://a", remote_path="s3://b") will copy.
            uri = source_path
        else:
            # Local paths are uploaded to the Flyte store (for local execution the
            # remote store is just a subfolder).
            if upload_target is None:
                upload_target = ctx.file_access.get_random_remote_path(source_path)
            ctx.file_access.put_data(source_path, upload_target, is_multipart=False)
            uri = upload_target or source_path

        blob_type = self._blob_type(
            format=FlyteFilePathTransformer.get_format(python_type))
        blob = Blob(metadata=BlobMetadata(type=blob_type), uri=uri)
        return Literal(scalar=Scalar(blob=blob))
Ejemplo n.º 7
0
def test_list_transformer():
    """A literal collection of two integers converts to the matching ``List[int]``."""
    elements = [
        Literal(scalar=Scalar(primitive=Primitive(integer=number)))
        for number in (3, 4)
    ]
    collection_literal = Literal(collection=LiteralCollection(literals=elements))

    ctx = FlyteContext.current_context()
    converted = TypeEngine.to_python_value(ctx, collection_literal, typing.List[int])
    assert converted == [3, 4]
Ejemplo n.º 8
0
 def encode(
     self,
     ctx: FlyteContext,
     sd: StructuredDataset,
     df_type: Type,
     protocol: str,
     format: str,
     structured_literal_type: StructuredDatasetType,
 ) -> Literal:
     """Encode ``sd`` with the encoder registered for the given type, protocol and format.

     :param ctx: context forwarded to the encoder.
     :param sd: structured dataset to encode.
     :param df_type: dataframe type used to look up the encoder.
     :param protocol: storage protocol used to look up the encoder.
     :param format: storage format used to look up the encoder.
     :param structured_literal_type: type information forwarded to the encoder and
         used to fill in any metadata the encoder left unset.
     :return: a ``Literal`` wrapping the encoded structured dataset model.
     """
     encoder: StructuredDatasetEncoder = self.get_encoder(df_type, protocol, format)
     encoded = encoder.encode(ctx, sd, structured_literal_type)

     # The literal carries its own type, so make sure it is never less specific
     # than the interface type when the encoder left the metadata blank.
     if encoded.metadata is None:
         encoded._metadata = StructuredDatasetMetadata(structured_literal_type)
     if encoded.metadata.structured_dataset_type is None:
         encoded.metadata._structured_dataset_type = structured_literal_type

     # The recorded format is always the handler's own. It only differs from the
     # incoming format when the fallback handler with a format of "" was used.
     encoded.metadata._structured_dataset_type.format = encoder.supported_format

     return Literal(scalar=Scalar(structured_dataset=encoded))
Ejemplo n.º 9
0
    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: TensorFlow2ONNX,
        python_type: Type[TensorFlow2ONNX],
        expected: LiteralType,
    ) -> Literal:
        """Convert the wrapped model to ONNX, upload it and return a single-blob ``Literal``.

        :param ctx: context whose ``file_access`` supplies the remote path and upload.
        :param python_val: wrapper whose ``model`` attribute is converted.
        :param python_type: type from which the conversion config is extracted.
        :param expected: not used by this method.
        :raises TypeTransformerFailedError: when no config could be extracted.
        """
        python_type, config = extract_config(python_type)

        # Without a config there is nothing to drive the conversion.
        if not config:
            raise TypeTransformerFailedError(f"{python_type}'s config is None")

        upload_uri = ctx.file_access.get_random_remote_path()
        onnx_file = to_onnx(ctx, python_val.model, config.__dict__.copy())
        ctx.file_access.put_data(onnx_file, upload_uri, is_multipart=False)

        onnx_blob_type = BlobType(
            format=self.ONNX_FORMAT,
            dimensionality=BlobType.BlobDimensionality.SINGLE,
        )
        onnx_blob = Blob(uri=upload_uri, metadata=BlobMetadata(type=onnx_blob_type))
        return Literal(scalar=Scalar(blob=onnx_blob))
Ejemplo n.º 10
0
    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: DoltTable,
        python_type: typing.Type[DoltTable],
        expected: LiteralType,
    ) -> Literal:
        """Save the table's data to Dolt when present and return its dict form as a generic ``Literal``.

        :param ctx: context providing the execution id used in the commit message.
        :param python_val: table to convert; must be a ``DoltTable``.
        :param python_type: not used by this method.
        :param expected: not used by this method.
        :raises AssertionError: when ``python_val`` is not a ``DoltTable``.
        """
        if not isinstance(python_val, DoltTable):
            raise AssertionError(f"Value cannot be converted to a table: {python_val}")

        conf = python_val.config
        should_save = python_val.data is not None and python_val.config.tablename is not None
        if should_save:
            database = dolt.Dolt(conf.db_path)
            # The data is staged as CSV in a temporary file that is removed on exit.
            with tempfile.NamedTemporaryFile() as csv_file:
                python_val.data.to_csv(csv_file.name, index=False)
                commit_message = f"Generated by Flyte execution id: {ctx.user_space_params.execution_id}"
                dolt_int.save(
                    db=database,
                    tablename=conf.tablename,
                    filename=csv_file.name,
                    branch_conf=conf.branch_conf,
                    meta_conf=conf.meta_conf,
                    remote_conf=conf.remote_conf,
                    save_args=conf.io_args,
                    commit_message=commit_message,
                )

        payload = Struct()
        payload.update(python_val.to_dict())
        return Literal(Scalar(generic=payload))
Ejemplo n.º 11
0
    def to_literal(
        self, ctx: FlyteContext, python_val: os.PathLike, python_type: Type[os.PathLike], expected: LiteralType
    ) -> Literal:
        """Turn a file path into a blob ``Literal``, uploading local files first.

        Remote values are referenced as-is. Local files are uploaded to a random
        remote path, which is only requested when an upload actually happens.

        :param ctx: context whose ``file_access`` performs the remote check and upload.
        :param python_val: path of the file to convert.
        :param python_type: not used by this method.
        :param expected: literal type whose ``blob`` is used for the blob metadata.
        :return: a ``Literal`` whose blob URI is the original remote path or the
            upload destination.
        """
        # TODO we could guess the mimetype and allow the format to be changed at runtime. thus a non existent format
        #      could be replaced with a guess format?

        # For remote values, say https://raw.github.com/demo_data.csv, we will not upload to Flyte's store (S3/GCS)
        # and just return a literal with a uri equal to the path given
        if ctx.file_access.is_remote(python_val):
            return Literal(scalar=Scalar(blob=Blob(metadata=BlobMetadata(expected.blob), uri=python_val)))

        # For local files, we'll upload for the user. The random destination is
        # requested here, not up front, so remote inputs never allocate one.
        rpath = ctx.file_access.get_random_remote_path()
        ctx.file_access.put_data(python_val, rpath, is_multipart=False)
        return Literal(scalar=Scalar(blob=Blob(metadata=BlobMetadata(expected.blob), uri=rpath)))
Ejemplo n.º 12
0
 def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
     """Serialize a ``dataclass_json`` dataclass value into a generic struct ``Literal``.

     :param ctx: not used by this method.
     :param python_val: value to serialize via its ``to_json`` method.
     :param python_type: declared type, only used in the error message.
     :param expected: not used by this method.
     :raises AssertionError: when the value is not a dataclass, or its type does
         not derive from ``DataClassJsonMixin``.
     """
     value_type = type(python_val)

     if not dataclasses.is_dataclass(python_val):
         raise AssertionError(
             f"{value_type} is not of type @dataclass, only Dataclasses are supported for "
             "user defined datatypes in Flytekit"
         )

     if not issubclass(value_type, DataClassJsonMixin):
         raise AssertionError(
             f"Dataclass {python_type} should be decorated with @dataclass_json to be serialized correctly"
         )

     # The JSON document is parsed into a protobuf Struct.
     as_json = python_val.to_json()
     as_struct = _json_format.Parse(as_json, _struct.Struct())
     return Literal(scalar=Scalar(generic=as_struct))
Ejemplo n.º 13
0
 def to_literal(
     self,
     ctx: FlyteContext,
     python_val: DatasetProfileView,
     python_type: Type[DatasetProfileView],
     expected: LiteralType,
 ) -> Literal:
     """Write the profile view locally, upload it and return a blob ``Literal`` pointing at it.

     :param ctx: context whose ``file_access`` supplies paths and the upload.
     :param python_val: profile view written via its ``write`` method.
     :param python_type: not used by this method.
     :param expected: not used by this method.
     """
     destination = ctx.file_access.get_random_remote_directory()
     staging_path = ctx.file_access.get_random_local_path()

     python_val.write(staging_path)
     ctx.file_access.upload(staging_path, destination)

     profile_blob = Blob(uri=destination, metadata=BlobMetadata(type=self._TYPE_INFO))
     return Literal(scalar=Scalar(blob=profile_blob))
Ejemplo n.º 14
0
def test_dolt_table_to_literal_error():
    """A struct literal holding only dummy data must be rejected with ``ValueError``."""
    payload = Struct()
    payload.update({"dummy": "data"})
    bad_literal = Literal(Scalar(generic=payload))

    # The transformer method is called unbound, so self and ctx are passed as None.
    call_kwargs = dict(
        self=None,
        ctx=None,
        lv=bad_literal,
        expected_python_type=DoltTable,
    )
    with pytest.raises(ValueError):
        DoltTableNameTransformer.to_python_value(**call_kwargs)
Ejemplo n.º 15
0
 def to_literal(
     self,
     ctx: FlyteContext,
     python_val: pyspark.sql.DataFrame,
     python_type: Type[pyspark.sql.DataFrame],
     expected: LiteralType,
 ) -> Literal:
     """Write the Spark dataframe as Parquet to a random remote directory and return a schema ``Literal``.

     :param ctx: context whose ``file_access`` supplies the remote directory.
     :param python_val: dataframe handed to the schema writer.
     :param python_type: not used by this method.
     :param expected: not used by this method.
     """
     destination = ctx.file_access.get_random_remote_directory()

     # The writer is pointed directly at the remote directory.
     writer = SparkDataFrameSchemaWriter(
         to_path=destination,
         cols=None,
         fmt=SchemaFormat.PARQUET,
     )
     writer.write(python_val)

     schema_model = Schema(destination, self._get_schema_type())
     return Literal(scalar=Scalar(schema=schema_model))
Ejemplo n.º 16
0
def test_file_format_getting_python_value():
    """A ``txt`` blob literal converts to a ``FlyteFile`` reporting the ``txt`` extension."""
    transformer = TypeEngine.get_transformer(FlyteFile)
    ctx = FlyteContext.current_context()

    # This file probably won't exist, but that is fine: nothing is downloaded
    # unless the returned object is actually read.
    txt_metadata = BlobMetadata(type=BlobType(format="txt", dimensionality=0))
    txt_blob = Blob(metadata=txt_metadata, uri="file:///tmp/test")
    blob_literal = Literal(scalar=Scalar(blob=txt_blob))

    result = transformer.to_python_value(
        ctx,
        blob_literal,
        expected_python_type=FlyteFile["txt"],
    )
    assert isinstance(result, FlyteFile)
    assert result.extension() == "txt"
Ejemplo n.º 17
0
 def to_literal(
     self,
     ctx: FlyteContext,
     python_val: pandas.DataFrame,
     python_type: Type[pandas.DataFrame],
     expected: LiteralType,
 ) -> Literal:
     """Write the dataframe as Parquet locally, upload the directory and return a schema ``Literal``.

     :param ctx: context whose ``file_access`` supplies directories and the upload.
     :param python_val: dataframe handed to the schema writer.
     :param python_type: not used by this method.
     :param expected: not used by this method.
     """
     staging_dir = ctx.file_access.get_random_local_directory()
     writer = PandasSchemaWriter(
         local_dir=staging_dir,
         cols=None,
         fmt=SchemaFormat.PARQUET,
     )
     writer.write(python_val)

     # Ship the whole staging directory to a fresh remote location.
     destination = ctx.file_access.get_random_remote_directory()
     ctx.file_access.put_data(staging_dir, destination, is_multipart=True)

     schema_model = Schema(destination, self._get_schema_type())
     return Literal(scalar=Scalar(schema=schema_model))
Ejemplo n.º 18
0
def test_protos():
    """A protobuf message round-trips through ``TypeEngine``; a non-struct literal is rejected."""
    ctx = FlyteContext.current_context()
    proto_cls = errors_pb2.ContainerError

    message = proto_cls(code="code", message="message")
    literal_type = TypeEngine.to_literal_type(proto_cls)
    assert literal_type.simple == SimpleType.STRUCT
    assert literal_type.metadata["pb_type"] == "flyteidl.core.errors_pb2.ContainerError"

    as_literal = TypeEngine.to_literal(ctx, message, proto_cls, literal_type)
    round_tripped = TypeEngine.to_python_value(ctx, as_literal, proto_cls)
    assert round_tripped == message

    # Error path: an integer primitive is not a valid protobuf literal.
    int_literal = Literal(scalar=Scalar(primitive=Primitive(integer=4)))
    with pytest.raises(AssertionError):
        TypeEngine.to_python_value(ctx, int_literal, proto_cls)
Ejemplo n.º 19
0
    def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
        """Pickle ``python_val`` with cloudpickle, upload the file and return a single-blob ``Literal``.

        :param ctx: context whose ``file_access`` supplies paths and the upload.
        :param python_val: object to pickle.
        :param python_type: not used by this method.
        :param expected: not used by this method.
        """
        pickle_blob_type = _core_types.BlobType(
            format=self.PYTHON_PICKLE_FORMAT,
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )
        meta = BlobMetadata(type=pickle_blob_type)

        # Dump the task output into a pickle file under a fresh local directory.
        staging_dir = ctx.file_access.get_random_local_directory()
        os.makedirs(staging_dir, exist_ok=True)
        pickle_file = os.path.join(staging_dir, ctx.file_access.get_random_local_path())
        with open(pickle_file, "w+b") as sink:
            cloudpickle.dump(python_val, sink)

        destination = ctx.file_access.get_random_remote_path(pickle_file)
        ctx.file_access.put_data(pickle_file, destination, is_multipart=False)

        pickle_blob = Blob(metadata=meta, uri=destination)
        return Literal(scalar=Scalar(blob=pickle_blob))
Ejemplo n.º 20
0
def test_dolt_table_to_literal(mocker):
    """With Dolt and CSV reading patched out, the resulting table exposes the patched dataframe."""
    empty_frame = pandas.DataFrame()

    # Replace the Dolt integration and CSV loading with canned return values.
    mocker.patch("dolt_integrations.core.load", return_value=None)
    mocker.patch("doltcli.Dolt", return_value=None)
    mocker.patch("pandas.read_csv", return_value=empty_frame)

    payload = Struct()
    payload.update({"config": {"db_path": "", "tablename": "t"}})
    table_literal = Literal(Scalar(generic=payload))

    # The transformer method is called unbound, so self and ctx are passed as None.
    table = DoltTableNameTransformer.to_python_value(
        self=None,
        ctx=None,
        lv=table_literal,
        expected_python_type=DoltTable,
    )

    assert table.data.equals(empty_frame)
Ejemplo n.º 21
0
    def to_literal(self, ctx: FlyteContext, python_val: np.ndarray,
                   python_type: Type[np.ndarray],
                   expected: LiteralType) -> Literal:
        """Save the array to a local ``.npy`` file, upload it and return a single-blob ``Literal``.

        :param ctx: context whose ``file_access`` supplies paths and the upload.
        :param python_val: array to save.
        :param python_type: not used by this method.
        :param expected: not used by this method.
        """
        array_blob_type = _core_types.BlobType(
            format=self.NUMPY_ARRAY_FORMAT,
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )
        meta = BlobMetadata(type=array_blob_type)

        npy_file = ctx.file_access.get_random_local_path() + ".npy"
        # Make sure the directory that will hold the file exists.
        pathlib.Path(npy_file).parent.mkdir(parents=True, exist_ok=True)

        # allow_pickle=False stops numpy from pickling object arrays (dtype=object).
        np.save(file=npy_file, arr=python_val, allow_pickle=False)

        destination = ctx.file_access.get_random_remote_path(npy_file)
        ctx.file_access.put_data(npy_file, destination, is_multipart=False)

        array_blob = Blob(metadata=meta, uri=destination)
        return Literal(scalar=Scalar(blob=array_blob))
Ejemplo n.º 22
0
 def to_literal(
     self,
     ctx: FlyteContext,
     python_val: MyDataset,
     python_type: Type[MyDataset],
     expected: LiteralType,
 ) -> Literal:
     """
     Convert a ``MyDataset`` object into its ``Literal`` representation by uploading its
     ``base_dir`` and returning a blob that points at the uploaded location.
     """
     # Upload everything under base_dir to a remote directory supplied by Flyte.
     destination = ctx.file_access.get_random_remote_directory()
     ctx.file_access.upload_directory(python_val.base_dir, destination)

     # The literal is just a pointer to that remote directory.
     dataset_blob = Blob(uri=destination, metadata=BlobMetadata(type=self._TYPE_INFO))
     return Literal(scalar=Scalar(blob=dataset_blob))
Ejemplo n.º 23
0
    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: Union[FlyteFile, FlyteSchema, str],
        python_type: Type[GreatExpectationsType],
        expected: LiteralType,
    ) -> Literal:
        """Delegate conversion to the transformer matching the configured datatype.

        The datatype is the first element returned by ``get_config`` for
        ``python_type``. Schemas and files are handed to their own transformers;
        strings become a string primitive.

        :param ctx: context forwarded to the delegated transformer.
        :param python_val: value to convert.
        :param python_type: type whose config selects the datatype.
        :param expected: literal type forwarded to the delegated transformer.
        :raises TypeError: when the datatype is not a schema, file or string type.
        """
        datatype = GreatExpectationsTypeTransformer.get_config(python_type)[0]

        if issubclass(datatype, FlyteSchema):
            schema_transformer = FlyteSchemaTransformer()
            return schema_transformer.to_literal(ctx, python_val, datatype, expected)

        if issubclass(datatype, FlyteFile):
            file_transformer = FlyteFilePathTransformer()
            return file_transformer.to_literal(ctx, python_val, datatype, expected)

        if issubclass(datatype, str):
            text = Primitive(string_value=python_val)
            return Literal(scalar=Scalar(primitive=text))

        raise TypeError(f"{datatype} is not a supported type")
Ejemplo n.º 24
0
 def to_literal(
     self,
     ctx: FlyteContext,
     python_val: pandas.DataFrame,
     python_type: Type[pandera.typing.DataFrame],
     expected: LiteralType,
 ) -> Literal:
     """Write the dataframe as Parquet locally, upload the directory and return a schema ``Literal``.

     :param ctx: context whose ``file_access`` supplies directories and the upload.
     :param python_val: dataframe to write; must be a ``pandas.DataFrame``.
     :param python_type: type used to derive the column dtypes and the schema type.
     :param expected: not used by this method.
     :raises AssertionError: when ``python_val`` is not a ``pandas.DataFrame``.
     """
     # Reject anything that is not a pandas dataframe up front.
     if not isinstance(python_val, pandas.DataFrame):
         raise AssertionError(
             f"Only Pandas Dataframe object can be returned from a task, returned object type {type(python_val)}"
         )

     staging_dir = ctx.file_access.get_random_local_directory()
     writer = PandasSchemaWriter(
         local_dir=staging_dir,
         cols=self._get_col_dtypes(python_type),
         fmt=SchemaFormat.PARQUET,
     )
     writer.write(python_val)

     destination = ctx.file_access.get_random_remote_directory()
     ctx.file_access.put_data(staging_dir, destination, is_multipart=True)

     schema_model = Schema(destination, self._get_schema_type(python_type))
     return Literal(scalar=Scalar(schema=schema_model))
Ejemplo n.º 25
0
    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: PyTorchCheckpoint,
        python_type: Type[PyTorchCheckpoint],
        expected: LiteralType,
    ) -> Literal:
        """Save the checkpoint's contents to a local ``.pt`` file, upload it and return a blob ``Literal``.

        Truthy ``module`` and ``optimizer`` fields contribute their state dicts
        under ``<name>_state_dict``. A truthy ``hyperparameters`` field is merged
        in when it is a dict, a tuple exposing ``_asdict`` or a dataclass.

        :param ctx: context whose ``file_access`` supplies paths and the upload.
        :param python_val: dataclass instance whose fields are inspected.
        :param python_type: not used by this method.
        :param expected: not used by this method.
        :raises TypeTransformerFailedError: when nothing was collected to save.
        """
        checkpoint_blob_type = _core_types.BlobType(
            format=self.PYTORCH_CHECKPOINT_FORMAT,
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )
        meta = BlobMetadata(type=checkpoint_blob_type)

        checkpoint_file = ctx.file_access.get_random_local_path() + ".pt"
        pathlib.Path(checkpoint_file).parent.mkdir(parents=True, exist_ok=True)

        payload = {}
        for spec in fields(python_val):
            name = spec.name
            value = getattr(python_val, name)
            # Empty or unset fields contribute nothing.
            if not value:
                continue

            if name in ["module", "optimizer"]:
                payload[name + "_state_dict"] = value.state_dict()
            elif name == "hyperparameters":
                if isinstance(value, dict):
                    payload.update(value)
                elif isinstance(value, tuple):
                    payload.update(value._asdict())
                elif is_dataclass(value):
                    payload.update(asdict(value))

        if not payload:
            raise TypeTransformerFailedError(f"Cannot save empty {python_val}")

        # Persist the collected state, then ship the file to remote storage.
        torch.save(payload, checkpoint_file)

        destination = ctx.file_access.get_random_remote_path(checkpoint_file)
        ctx.file_access.put_data(checkpoint_file, destination, is_multipart=False)

        checkpoint_blob = Blob(metadata=meta, uri=destination)
        return Literal(scalar=Scalar(blob=checkpoint_blob))
Ejemplo n.º 26
0
def test_dict_transformer():
    """Check ``DictTransformer`` type inference, an empty-dict round trip and its rejection paths."""
    d = DictTransformer()

    def assert_struct(lit: LiteralType):
        assert lit is not None
        assert lit.simple == SimpleType.STRUCT

    def assert_innermost(lit: LiteralType,
                         expected: LiteralType,
                         expected_depth: int = 1):
        # Walk down the map_value_type chain; the innermost type must equal
        # ``expected`` and be reached within ``expected_depth`` steps.
        depth = 0
        while True:
            assert depth <= expected_depth
            assert lit is not None
            if lit.map_value_type is None:
                assert lit == expected
                return
            lit = lit.map_value_type
            depth += 1

    # Type inference: these are inferred as plain structs.
    for struct_like in (dict, typing.Dict[int, int]):
        assert_struct(d.get_literal_type(struct_like))

    # Type inference: (python type, innermost simple type, maximum depth).
    str_dict = typing.Dict[str, str]
    nested_cases = [
        (typing.Dict[str, str], SimpleType.STRING, 1),
        (typing.Dict[str, int], SimpleType.INTEGER, 1),
        (typing.Dict[str, datetime.datetime], SimpleType.DATETIME, 1),
        (typing.Dict[str, datetime.timedelta], SimpleType.DURATION, 1),
        (typing.Dict[str, dict], SimpleType.STRUCT, 1),
        (typing.Dict[str, str_dict], SimpleType.STRING, 2),
        (typing.Dict[str, typing.Dict[int, str]], SimpleType.STRUCT, 2),
        (typing.Dict[str, typing.Dict[str, str_dict]], SimpleType.STRING, 3),
        (typing.Dict[str, typing.Dict[str, typing.Dict[str, dict]]],
         SimpleType.STRUCT, 3),
        (typing.Dict[str, typing.Dict[str, typing.Dict[int, dict]]],
         SimpleType.STRUCT, 2),
    ]
    for python_type, innermost, max_depth in nested_cases:
        assert_innermost(
            d.get_literal_type(python_type),
            LiteralType(simple=innermost),
            expected_depth=max_depth,
        )

    ctx = FlyteContext.current_context()

    # An empty dict survives a round trip.
    empty_literal = d.to_literal(ctx, {}, typing.Dict,
                                 LiteralType(SimpleType.STRUCT))
    assert d.to_python_value(ctx, empty_literal, typing.Dict) == {}

    # Literal to python: each of these must be rejected. The literals are built
    # lazily so that construction also happens inside the ``raises`` block.
    rejected = [
        (lambda: Literal(scalar=Scalar(primitive=Primitive(integer=10))), dict),
        (lambda: Literal(), dict),
        (lambda: Literal(map=LiteralMap(literals={"x": None})), dict),
        (lambda: Literal(map=LiteralMap(literals={"x": None})),
         typing.Dict[int, str]),
    ]
    for build_literal, target_type in rejected:
        with pytest.raises(TypeError):
            d.to_python_value(ctx, build_literal(), target_type)

    # A well-formed map literal converts without raising.
    one = Literal(scalar=Scalar(primitive=Primitive(integer=1)))
    d.to_python_value(
        ctx,
        Literal(map=LiteralMap(literals={"x": one})),
        typing.Dict[str, int],
    )
Ejemplo n.º 27
0
    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: Union[StructuredDataset, typing.Any],
        python_type: Union[Type[StructuredDataset], Type],
        expected: LiteralType,
    ) -> Literal:
        """Convert a ``StructuredDataset`` or a raw dataframe into a structured-dataset ``Literal``.

        The structured dataset type starts with the default format registered
        for ``python_type`` (``None`` when there is none) and is replaced by a
        copy of ``expected.structured_dataset_type`` when that is set.

        For a ``StructuredDataset`` value under a ``StructuredDataset`` type:

        1. a value that still carries its original literal is returned as-is;
        2. a value without a dataframe is wrapped as a literal around its ``uri``;
        3. a value with a dataframe is handed to ``encode``.

        Any other value is treated as a dataframe, wrapped in a new
        ``StructuredDataset`` and handed to ``encode`` with the default format
        and protocol registered for ``python_type``.

        :param ctx: context forwarded to ``encode``.
        :param python_val: a ``StructuredDataset`` or a dataframe instance.
        :param python_type: declared type; columns/format extras are stripped
            from it via ``extract_cols_and_format``.
        :param expected: literal type whose ``structured_dataset_type``, when
            present, overrides the default type information.
        :raises ValueError: when both a literal and a dataframe are set, or when
            neither a dataframe nor a ``uri`` is set.
        """
        # Make a copy in case we need to hand off to encoders, since we can't be sure of mutations.
        # Check first to see if it's even an SD type. For backwards compatibility, we may be getting a FlyteSchema
        python_type, *attrs = extract_cols_and_format(python_type)
        # In case it's a FlyteSchema
        sdt = StructuredDatasetType(
            format=self.DEFAULT_FORMATS.get(python_type, None))

        if expected and expected.structured_dataset_type:
            sdt = StructuredDatasetType(
                columns=expected.structured_dataset_type.columns,
                format=expected.structured_dataset_type.format,
                external_schema_type=expected.structured_dataset_type.
                external_schema_type,
                external_schema_bytes=expected.structured_dataset_type.
                external_schema_bytes,
            )

        # If the type signature has the StructuredDataset class, it will, or at least should, also be a
        # StructuredDataset instance.
        if issubclass(python_type, StructuredDataset) and isinstance(
                python_val, StructuredDataset):
            # There are three cases that we need to take care of here.

            # 1. A task returns a StructuredDataset that was just a passthrough input. If this happens
            # then return the original literals.StructuredDataset without invoking any encoder
            #
            # Ex.
            #   def t1(dataset: Annotated[StructuredDataset, my_cols]) -> Annotated[StructuredDataset, my_cols]:
            #       return dataset
            if python_val._literal_sd is not None:
                if python_val.dataframe is not None:
                    raise ValueError(
                        f"Shouldn't have specified both literal {python_val._literal_sd} and dataframe {python_val.dataframe}"
                    )
                return Literal(scalar=Scalar(
                    structured_dataset=python_val._literal_sd))

            # 2. A task returns a python StructuredDataset with a uri.
            # Note: this case is also what happens we start a local execution of a task with a python StructuredDataset.
            #  It gets converted into a literal first, then back into a python StructuredDataset.
            #
            # Ex.
            #   def t2(uri: str) -> Annotated[StructuredDataset, my_cols]
            #       return StructuredDataset(uri=uri)
            if python_val.dataframe is None:
                if not python_val.uri:
                    raise ValueError(
                        f"If dataframe is not specified, then the uri should be specified. {python_val}"
                    )
                sd_model = literals.StructuredDataset(
                    uri=python_val.uri,
                    metadata=StructuredDatasetMetadata(
                        structured_dataset_type=sdt),
                )
                return Literal(scalar=Scalar(structured_dataset=sd_model))

            # 3. This is the third and probably most common case. The python StructuredDataset object wraps a dataframe
            # that we will need to invoke an encoder for. Figure out which encoder to call and invoke it.
            df_type = type(python_val.dataframe)
            # Without a uri the protocol is the default registered for the dataframe type;
            # otherwise it is derived from the uri itself.
            if python_val.uri is None:
                protocol = self.DEFAULT_PROTOCOLS[df_type]
            else:
                protocol = protocol_prefix(python_val.uri)
            # When the type carries no format, fall back to the dataset's DEFAULT_FILE_FORMAT.
            return self.encode(
                ctx,
                python_val,
                df_type,
                protocol,
                sdt.format or typing.cast(StructuredDataset,
                                          python_val).DEFAULT_FILE_FORMAT,
                sdt,
            )

        # Otherwise assume it's a dataframe instance. Wrap it with some defaults
        # NOTE(review): unlike the .get() lookup at the top, these are plain index
        # lookups — presumably they raise KeyError for a type with no registered
        # default; confirm that is the intended failure mode.
        fmt = self.DEFAULT_FORMATS[python_type]
        protocol = self.DEFAULT_PROTOCOLS[python_type]
        meta = StructuredDatasetMetadata(
            structured_dataset_type=expected.
            structured_dataset_type if expected else None)

        sd = StructuredDataset(dataframe=python_val, metadata=meta)
        return self.encode(ctx, sd, python_type, protocol, fmt, sdt)
Ejemplo n.º 28
0
def _register_default_type_transformers():
    """Register the built-in transformers with ``TypeEngine``.

    Registration happens in this order:

    1. ``SimpleTransformer`` entries for ``int``, ``float``, ``bool``, ``str``,
       ``datetime``, ``timedelta`` and ``None``.
    2. The list, dict, text-IO, path-like and binary-IO transformers.
    3. ``RestrictedType`` entries for ``tuple``, ``typing.Tuple`` and
       ``typing.NamedTuple``.
    """

    def register_simple(name, python_type, literal_type, to_literal, from_literal):
        # Every simple transformer is constructed and registered the same way.
        TypeEngine.register(
            SimpleTransformer(name, python_type, literal_type, to_literal, from_literal))

    register_simple(
        "int",
        int,
        _primitives.Integer.to_flyte_literal_type(),
        lambda value: Literal(scalar=Scalar(primitive=Primitive(integer=value))),
        lambda lit: lit.scalar.primitive.integer,
    )
    register_simple(
        "float",
        float,
        _primitives.Float.to_flyte_literal_type(),
        lambda value: Literal(scalar=Scalar(primitive=Primitive(float_value=value))),
        # Reading a float back goes through a dedicated helper rather than a plain attribute access.
        _check_and_covert_float,
    )
    register_simple(
        "bool",
        bool,
        _primitives.Boolean.to_flyte_literal_type(),
        lambda value: Literal(scalar=Scalar(primitive=Primitive(boolean=value))),
        lambda lit: lit.scalar.primitive.boolean,
    )
    register_simple(
        "str",
        str,
        _primitives.String.to_flyte_literal_type(),
        lambda value: Literal(scalar=Scalar(primitive=Primitive(string_value=value))),
        lambda lit: lit.scalar.primitive.string_value,
    )
    register_simple(
        "datetime",
        _datetime.datetime,
        _primitives.Datetime.to_flyte_literal_type(),
        lambda value: Literal(scalar=Scalar(primitive=Primitive(datetime=value))),
        lambda lit: lit.scalar.primitive.datetime,
    )
    register_simple(
        "timedelta",
        _datetime.timedelta,
        _primitives.Timedelta.to_flyte_literal_type(),
        lambda value: Literal(scalar=Scalar(primitive=Primitive(duration=value))),
        lambda lit: lit.scalar.primitive.duration,
    )
    # The "none" transformer maps everything to None in both directions.
    register_simple(
        "none",
        None,
        _type_models.LiteralType(simple=_type_models.SimpleType.NONE),
        lambda _value: None,
        lambda _lit: None,
    )

    # Container and IO transformers take no constructor arguments; each one is instantiated
    # right before it is registered, in the listed order.
    for transformer_cls in (
            ListTransformer,
            DictTransformer,
            TextIOTransformer,
            PathLikeTransformer,
            BinaryIOTransformer,
    ):
        TypeEngine.register(transformer_cls())

    # Tuples are registered as restricted types. typing's Tuples are unsupported: even though their
    # contents can be inspected, Flyte's type system doesn't support them currently.
    # typing.NamedTuple is listed as well even though task functions themselves can return one;
    # NOTE(review): the restriction presumably targets a NamedTuple return signature that contains
    # another NamedTuple inside it -- confirm.
    # Flyte IDL could fake tuples as structs, but that is deferred for now because the IDL may be
    # amended to support tuples formally.
    for restricted_name, restricted_type in (
            ("non typed tuple", tuple),
            ("non typed tuple", typing.Tuple),
            ("named tuple", typing.NamedTuple),
    ):
        TypeEngine.register(RestrictedType(restricted_name, restricted_type))
Ejemplo n.º 29
0
 def dict_to_generic_literal(v: dict) -> Literal:
     """Wrap a dict in a ``Literal`` whose scalar carries a generic struct.

     ``v`` is serialized to JSON text with ``_json.dumps`` and that text is parsed
     into a fresh ``_struct.Struct``, which becomes the scalar's ``generic`` field.
     """
     json_text = _json.dumps(v)
     generic_struct = _json_format.Parse(json_text, _struct.Struct())
     generic_scalar = Scalar(generic=generic_struct)
     return Literal(scalar=generic_scalar)
Ejemplo n.º 30
0
 def to_literal(self, ctx: FlyteContext, python_val: T,
                python_type: Type[T], expected: LiteralType) -> Literal:
     """Convert ``python_val`` into a ``Literal`` holding a generic struct.

     The value is turned into a dict by ``_MessageToDict`` and loaded into a new
     ``Struct``. ``ctx``, ``python_type`` and ``expected`` are not used here.
     """
     generic_struct = Struct()
     message_fields = _MessageToDict(python_val)
     generic_struct.update(message_fields)
     generic_scalar = Scalar(generic=generic_struct)
     return Literal(scalar=generic_scalar)