示例#1
0
    def to_literal(self, ctx: FlyteContext, python_val: FlyteSchema,
                   python_type: Type[FlyteSchema],
                   expected: LiteralType) -> Literal:
        """Convert a FlyteSchema or a supported dataframe into a schema Literal.

        An existing ``FlyteSchema`` is uploaded (multipart) from its local path
        to its remote path; a random remote path is generated when none is set.
        Any other value is treated as a dataframe. It is written through a
        freshly created ``python_type`` schema object. Unless the dataframe's
        handler performs remote I/O itself, the local output is then uploaded.

        :raises TypeTransformerFailedError: if ``SchemaEngine`` has no handler
            for the type of ``python_val``.
        """
        if isinstance(python_val, FlyteSchema):
            target = python_val.remote_path
            if target is None or target == "":
                target = ctx.file_access.get_random_remote_path()
            ctx.file_access.put_data(python_val.local_path, target, is_multipart=True)
            return Literal(scalar=Scalar(schema=Schema(target, self._get_schema_type(python_type))))

        df_type = type(python_val)
        new_schema = python_type(
            local_path=ctx.file_access.get_random_local_directory(),
            remote_path=ctx.file_access.get_random_remote_directory(),
        )
        try:
            handler = SchemaEngine.get_handler(df_type)
        except ValueError as err:
            raise TypeTransformerFailedError(
                f"DataFrames of type {type(python_val)} are not supported currently"
            ) from err

        new_schema.open(df_type).write(python_val)
        # Handlers that write straight to the remote location need no separate upload.
        if not handler.handles_remote_io:
            ctx.file_access.put_data(new_schema.local_path, new_schema.remote_path, is_multipart=True)
        return Literal(scalar=Scalar(schema=Schema(new_schema.remote_path, self._get_schema_type(python_type))))
示例#2
0
def test_enum_type():
    """Round-trip an Enum through TypeEngine and check that invalid string values are rejected."""
    literal_type = TypeEngine.to_literal_type(Color)
    assert literal_type is not None
    assert literal_type.enum_type is not None
    assert literal_type.enum_type.values
    assert literal_type.enum_type.values == [member.value for member in Color]

    ctx = FlyteContextManager.current_context()
    red_literal = TypeEngine.to_literal(ctx, Color.RED, Color, TypeEngine.to_literal_type(Color))
    assert red_literal
    assert red_literal.scalar
    assert red_literal.scalar.primitive.string_value == "red"

    as_enum = TypeEngine.to_python_value(ctx, red_literal, Color)
    assert as_enum
    assert as_enum == Color.RED

    as_str = TypeEngine.to_python_value(ctx, red_literal, str)
    assert as_str
    assert as_str == "red"

    # Strings that are not an enum member's value must fail to decode into Color.
    for bad_value in (str(Color.RED), "bad"):
        with pytest.raises(ValueError):
            TypeEngine.to_python_value(ctx, Literal(scalar=Scalar(primitive=Primitive(string_value=bad_value))), Color)

    with pytest.raises(AssertionError):
        TypeEngine.to_literal_type(UnsupportedEnumValues)
示例#3
0
    def to_literal(self, ctx: FlyteContext, python_val: FlyteSchema,
                   python_type: Type[FlyteSchema],
                   expected: LiteralType) -> Literal:
        """Convert a FlyteSchema or a supported dataframe into a schema Literal.

        An existing ``FlyteSchema`` is uploaded (multipart) from its local path
        to its remote path; a random remote path is generated when none is set.
        Any other value is treated as a dataframe. It is written through a
        freshly created ``python_type`` schema object and, unless its handler
        performs remote I/O itself, the local output is uploaded.

        Whatever ``SchemaEngine.get_handler`` raises for an unsupported dataframe
        type propagates unchanged (presumably ``ValueError`` — confirm).
        """
        if isinstance(python_val, FlyteSchema):
            remote_path = python_val.remote_path
            if not remote_path:
                remote_path = ctx.file_access.get_random_remote_path()
            ctx.file_access.put_data(python_val.local_path,
                                     remote_path,
                                     is_multipart=True)
            return Literal(scalar=Scalar(schema=Schema(
                remote_path, self._get_schema_type(python_type))))

        df_type = type(python_val)
        # Resolve the handler before creating the schema object or writer, so an
        # unsupported dataframe type fails before any local scaffolding is created.
        h = SchemaEngine.get_handler(df_type)

        schema = python_type(
            local_path=ctx.file_access.get_random_local_directory(),
            remote_path=ctx.file_access.get_random_remote_directory(),
        )
        writer = schema.open(df_type)
        writer.write(python_val)
        # Handlers that write straight to the remote location need no separate upload.
        if not h.handles_remote_io:
            ctx.file_access.put_data(schema.local_path,
                                     schema.remote_path,
                                     is_multipart=True)
        return Literal(scalar=Scalar(schema=Schema(
            schema.remote_path, self._get_schema_type(python_type))))
示例#4
0
    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: FlyteDirectory,
        python_type: typing.Type[FlyteDirectory],
        expected: LiteralType,
    ) -> Literal:
        """Convert a FlyteDirectory or a directory path into a multipart blob Literal.

        Remote sources, and FlyteDirectory values whose ``remote_directory`` is
        ``False``, are referenced in place without uploading. Local directories
        are uploaded (multipart) to the user-specified remote directory, or to a
        random remote directory when none is given.

        :raises AssertionError: if ``python_val`` is not a FlyteDirectory,
            ``os.PathLike`` or ``str``, or if a local path is not a directory.
        """
        remote_directory = None
        should_upload = True

        if isinstance(python_val, FlyteDirectory):
            source_path = python_val.path
            if python_val.remote_directory is False:
                # remote_directory=False means: never upload, no matter what.
                should_upload = False
            else:
                # Use the user-specified remote directory unless it is ""/None,
                # in which case a random one is chosen below.
                remote_directory = python_val.remote_directory or None
        else:
            if not isinstance(python_val, (os.PathLike, str)):
                raise AssertionError(
                    f"Expected FlyteDirectory or os.PathLike object, received {type(python_val)}"
                )
            # Normalise path-like objects (e.g. pathlib.Path) to a plain string so
            # the blob URI is always a str rather than a Path object.
            source_path = os.fspath(python_val)
            # Only local paths can be checked for being a directory.
            if not ctx.file_access.is_remote(source_path):
                if not Path(source_path).is_dir():
                    raise AssertionError(
                        f"Expected a directory. {source_path} is not a directory"
                    )

        if ctx.file_access.is_remote(source_path) or not should_upload:
            # Remote values (e.g. s3://some/extant/dir/) are not copied into
            # Flyte's store; the literal simply points at the given path.
            uri = source_path
        else:
            # Local paths are uploaded to the Flyte store (for local execution
            # the "remote" store is just a subfolder).
            if remote_directory is None:
                remote_directory = ctx.file_access.get_random_remote_directory()
            ctx.file_access.put_data(source_path,
                                     remote_directory,
                                     is_multipart=True)
            uri = remote_directory

        meta = BlobMetadata(type=self._blob_type(
            format=self.get_format(python_type)))
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=uri)))
示例#5
0
def test_zero_floats():
    """Both an integer 0 literal and a float 0.0 literal decode to zero when read as ``float``."""
    ctx = FlyteContext.current_context()

    zero_literals = [
        Literal(scalar=Scalar(primitive=Primitive(integer=0))),
        Literal(scalar=Scalar(primitive=Primitive(float_value=0.0))),
    ]
    for zero_literal in zero_literals:
        assert TypeEngine.to_python_value(ctx, zero_literal, float) == 0
示例#6
0
    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: typing.Union[FlyteFile, os.PathLike, str],
        python_type: typing.Type[FlyteFile],
        expected: LiteralType,
    ) -> Literal:
        """Convert a FlyteFile or a file path into a single-file blob Literal.

        A FlyteFile that carries a remote source is referenced at that source.
        Remote paths, and FlyteFiles whose ``remote_path`` is ``False``, are
        referenced in place. Local files are uploaded to the user-specified
        remote path, or to a random remote path when none is given.

        :raises AssertionError: if ``python_val`` is ``None`` or is not a
            FlyteFile, ``os.PathLike`` or ``str``.
        """
        remote_path = None
        should_upload = True

        if python_val is None:
            raise AssertionError("None value cannot be converted to a file.")
        if isinstance(python_val, FlyteFile):
            # If the object has a remote source, then we just convert it back.
            if python_val._remote_source is not None:
                meta = BlobMetadata(type=self._blob_type(
                    format=self.get_format(python_type)))
                return Literal(scalar=Scalar(
                    blob=Blob(metadata=meta, uri=python_val._remote_source)))

            source_path = python_val.path
            if python_val.remote_path is False:
                # remote_path=False means: never upload, no matter what.
                should_upload = False
            else:
                # Use the user-specified remote path unless it is ""/None, in
                # which case a random one is chosen below.
                remote_path = python_val.remote_path or None
        else:
            if not isinstance(python_val, (os.PathLike, str)):
                raise AssertionError(
                    f"Expected FlyteFile or os.PathLike object, received {type(python_val)}"
                )
            # Normalise path-like objects (e.g. pathlib.Path) to a plain string so
            # the blob URI is always a str rather than a Path object.
            source_path = os.fspath(python_val)

        if ctx.file_access.is_remote(source_path) or not should_upload:
            # Remote values (e.g. https://raw.github.com/demo_data.csv) are not
            # copied into Flyte's store; the literal points at the given path.
            # TODO: Add copying functionality so that FlyteFile(path="s3://a", remote_path="s3://b") will copy.
            uri = source_path
        else:
            # Local paths are uploaded to the Flyte store (for local execution
            # the "remote" store is just a subfolder).
            if remote_path is None:
                remote_path = ctx.file_access.get_random_remote_path(
                    source_path)
            ctx.file_access.put_data(source_path,
                                     remote_path,
                                     is_multipart=False)
            uri = remote_path or source_path

        meta = BlobMetadata(type=self._blob_type(
            format=FlyteFilePathTransformer.get_format(python_type)))
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=uri)))
示例#7
0
def test_list_transformer():
    """A LiteralCollection of integer literals decodes to a Python ``List[int]``."""
    members = [Literal(scalar=Scalar(primitive=Primitive(integer=n))) for n in (3, 4)]
    collection_literal = Literal(collection=LiteralCollection(literals=members))

    decoded = TypeEngine.to_python_value(FlyteContext.current_context(), collection_literal, typing.List[int])
    assert decoded == [3, 4]
示例#8
0
 def encode(
     self,
     ctx: FlyteContext,
     sd: StructuredDataset,
     df_type: Type,
     protocol: str,
     format: str,
     structured_literal_type: StructuredDatasetType,
 ) -> Literal:
     """Encode ``sd`` with the encoder registered for (df_type, protocol, format).

     If the encoder left the metadata, or the metadata's structured dataset type,
     unset, they are filled in from ``structured_literal_type``. The format on that
     type is then overwritten with the encoder's ``supported_format``. It differs
     from ``format`` only when the fallback handler with format "" was selected.
     When ``structured_literal_type`` itself was installed, this mutates that object.
     """
     encoder: StructuredDatasetEncoder = self.get_encoder(df_type, protocol, format)
     encoded = encoder.encode(ctx, sd, structured_literal_type)

     # The literal carries its own type around, so make sure that type is at least
     # as informative as the interface type.
     if encoded.metadata is None:
         encoded._metadata = StructuredDatasetMetadata(structured_literal_type)
     if encoded.metadata.structured_dataset_type is None:
         encoded.metadata._structured_dataset_type = structured_literal_type

     encoded.metadata._structured_dataset_type.format = encoder.supported_format
     return Literal(scalar=Scalar(structured_dataset=encoded))
示例#9
0
    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: TensorFlow2ONNX,
        python_type: Type[TensorFlow2ONNX],
        expected: LiteralType,
    ) -> Literal:
        """Export the wrapped model to ONNX and upload it as a single-file blob Literal.

        The export config comes from ``extract_config(python_type)``; its ``__dict__``
        is copied and passed to ``to_onnx``.

        :raises TypeTransformerFailedError: if ``python_type`` carries no (truthy) config.
        """
        python_type, config = extract_config(python_type)
        if not config:
            raise TypeTransformerFailedError(f"{python_type}'s config is None")

        remote_path = ctx.file_access.get_random_remote_path()
        exported_path = to_onnx(ctx, python_val.model, config.__dict__.copy())
        ctx.file_access.put_data(exported_path, remote_path, is_multipart=False)

        onnx_blob_type = BlobType(format=self.ONNX_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE)
        blob = Blob(uri=remote_path, metadata=BlobMetadata(type=onnx_blob_type))
        return Literal(scalar=Scalar(blob=blob))
示例#10
0
    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: DoltTable,
        python_type: typing.Type[DoltTable],
        expected: LiteralType,
    ) -> Literal:
        """Optionally save the table's data to Dolt, then serialize the DoltTable as a generic struct.

        When both ``data`` and ``config.tablename`` are set, the dataframe is written to
        a temporary CSV file and saved with ``dolt_int.save``. The commit message names
        the current Flyte execution id.

        :raises AssertionError: if ``python_val`` is not a DoltTable.
        """
        if not isinstance(python_val, DoltTable):
            raise AssertionError(f"Value cannot be converted to a table: {python_val}")

        conf = python_val.config
        should_save = python_val.data is not None and conf.tablename is not None
        if should_save:
            db = dolt.Dolt(conf.db_path)
            with tempfile.NamedTemporaryFile() as csv_file:
                python_val.data.to_csv(csv_file.name, index=False)
                commit_message = f"Generated by Flyte execution id: {ctx.user_space_params.execution_id}"
                dolt_int.save(
                    db=db,
                    tablename=conf.tablename,
                    filename=csv_file.name,
                    branch_conf=conf.branch_conf,
                    meta_conf=conf.meta_conf,
                    remote_conf=conf.remote_conf,
                    save_args=conf.io_args,
                    commit_message=commit_message,
                )

        payload = Struct()
        payload.update(python_val.to_dict())
        return Literal(Scalar(generic=payload))
示例#11
0
    def to_literal(
        self, ctx: FlyteContext, python_val: os.PathLike, python_type: Type[os.PathLike], expected: LiteralType
    ) -> Literal:
        """Convert a file path into a blob Literal typed with ``expected.blob``.

        Remote paths (e.g. https://raw.github.com/demo_data.csv) are referenced
        as-is and not copied into Flyte's store. Local files are uploaded to a
        random remote path.
        """
        # TODO we could guess the mimetype and allow the format to be changed at runtime. thus a non existent format
        #      could be replaced with a guess format?

        # Normalise path-like objects (e.g. pathlib.Path) to a plain string so the
        # blob URI is always a str rather than a Path object.
        source_path = os.fspath(python_val)

        if ctx.file_access.is_remote(source_path):
            return Literal(scalar=Scalar(blob=Blob(metadata=BlobMetadata(expected.blob), uri=source_path)))

        # Only local files need a destination, so only generate one here.
        rpath = ctx.file_access.get_random_remote_path()
        ctx.file_access.put_data(source_path, rpath, is_multipart=False)
        return Literal(scalar=Scalar(blob=Blob(metadata=BlobMetadata(expected.blob), uri=rpath)))
示例#12
0
 def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
     """Serialize a ``@dataclass_json`` dataclass instance into a generic struct Literal.

     The value's ``to_json()`` output is parsed into a protobuf ``Struct``.

     :raises AssertionError: if ``python_val`` is not a dataclass, or its type does
         not subclass ``DataClassJsonMixin``.
     """
     is_dataclass_value = dataclasses.is_dataclass(python_val)
     if not is_dataclass_value:
         raise AssertionError(
             f"{type(python_val)} is not of type @dataclass, only Dataclasses are supported for "
             f"user defined datatypes in Flytekit"
         )

     has_json_mixin = issubclass(type(python_val), DataClassJsonMixin)
     if not has_json_mixin:
         raise AssertionError(
             f"Dataclass {python_type} should be decorated with @dataclass_json to be " f"serialized correctly"
         )

     json_text = python_val.to_json()
     generic = _json_format.Parse(json_text, _struct.Struct())
     return Literal(scalar=Scalar(generic=generic))
示例#13
0
 def to_literal(
     self,
     ctx: FlyteContext,
     python_val: DatasetProfileView,
     python_type: Type[DatasetProfileView],
     expected: LiteralType,
 ) -> Literal:
     """Write the profile view locally, upload it, and return a blob Literal pointing at the upload.

     NOTE(review): the local target comes from ``get_random_local_path`` while the remote
     target comes from ``get_random_remote_directory``. Confirm whether
     ``DatasetProfileView.write`` produces a file or a directory, and that ``upload``
     matches that.
     """
     destination = ctx.file_access.get_random_remote_directory()
     staging_path = ctx.file_access.get_random_local_path()
     python_val.write(staging_path)
     ctx.file_access.upload(staging_path, destination)

     blob = Blob(uri=destination, metadata=BlobMetadata(type=self._TYPE_INFO))
     return Literal(scalar=Scalar(blob=blob))
示例#14
0
def test_dolt_table_to_literal_error():
    """Decoding a generic struct that has no ``config`` entry into a DoltTable raises ValueError."""
    struct_without_config = Struct()
    struct_without_config.update({"dummy": "data"})
    bad_literal = Literal(Scalar(generic=struct_without_config))

    with pytest.raises(ValueError):
        DoltTableNameTransformer.to_python_value(self=None, ctx=None, lv=bad_literal, expected_python_type=DoltTable)
示例#15
0
 def to_literal(
     self,
     ctx: FlyteContext,
     python_val: pyspark.sql.DataFrame,
     python_type: Type[pyspark.sql.DataFrame],
     expected: LiteralType,
 ) -> Literal:
     """Write a Spark DataFrame as parquet straight to a random remote directory.

     No local staging or separate upload happens here; the writer is given the
     remote path directly. Returns a schema Literal pointing at that directory.
     """
     target_dir = ctx.file_access.get_random_remote_directory()
     writer = SparkDataFrameSchemaWriter(to_path=target_dir, cols=None, fmt=SchemaFormat.PARQUET)
     writer.write(python_val)

     schema = Schema(target_dir, self._get_schema_type())
     return Literal(scalar=Scalar(schema=schema))
示例#16
0
def test_file_format_getting_python_value():
    """The blob format of a literal is surfaced as the extension of the decoded FlyteFile."""
    transformer = TypeEngine.get_transformer(FlyteFile)
    ctx = FlyteContext.current_context()

    # The URI need not exist: it is only downloaded if the returned file is actually read.
    txt_metadata = BlobMetadata(type=BlobType(format="txt", dimensionality=0))
    lv = Literal(scalar=Scalar(blob=Blob(metadata=txt_metadata, uri="file:///tmp/test")))

    pv = transformer.to_python_value(ctx, lv, expected_python_type=FlyteFile["txt"])
    assert isinstance(pv, FlyteFile)
    assert pv.extension() == "txt"
示例#17
0
 def to_literal(
     self,
     ctx: FlyteContext,
     python_val: pandas.DataFrame,
     python_type: Type[pandas.DataFrame],
     expected: LiteralType,
 ) -> Literal:
     """Write a pandas DataFrame as parquet locally, upload it, and return a schema Literal.

     The local directory is uploaded (multipart) to a random remote directory, which
     the returned schema references.
     """
     staging_dir = ctx.file_access.get_random_local_directory()
     PandasSchemaWriter(local_dir=staging_dir, cols=None, fmt=SchemaFormat.PARQUET).write(python_val)

     target_dir = ctx.file_access.get_random_remote_directory()
     ctx.file_access.put_data(staging_dir, target_dir, is_multipart=True)
     return Literal(scalar=Scalar(schema=Schema(target_dir, self._get_schema_type())))
示例#18
0
def test_protos():
    """A protobuf message round-trips through a STRUCT literal tagged with its pb_type."""
    ctx = FlyteContext.current_context()

    message = errors_pb2.ContainerError(code="code", message="message")
    literal_type = TypeEngine.to_literal_type(errors_pb2.ContainerError)
    assert literal_type.simple == SimpleType.STRUCT
    assert literal_type.metadata["pb_type"] == "flyteidl.core.errors_pb2.ContainerError"

    encoded = TypeEngine.to_literal(ctx, message, errors_pb2.ContainerError, literal_type)
    decoded = TypeEngine.to_python_value(ctx, encoded, errors_pb2.ContainerError)
    assert decoded == message

    # A primitive integer literal cannot be decoded into a protobuf message.
    int_literal = Literal(scalar=Scalar(primitive=Primitive(integer=4)))
    with pytest.raises(AssertionError):
        TypeEngine.to_python_value(ctx, int_literal, errors_pb2.ContainerError)
示例#19
0
    def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
        """Serialize ``python_val`` with cloudpickle into a single-file blob Literal.

        The value is dumped to a file inside a freshly created random local
        directory. The file is uploaded to a random remote path and referenced
        with the pickle blob format.

        NOTE(review): the uploaded blob is a pickle. Whatever later loads it
        must only do so for trusted data.
        """
        meta = BlobMetadata(
            type=_core_types.BlobType(
                format=self.PYTHON_PICKLE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
            )
        )
        # Dump the task output into a pickle file inside a fresh local directory.
        local_dir = ctx.file_access.get_random_local_directory()
        os.makedirs(local_dir, exist_ok=True)
        # os.path.join discards all earlier components when a later one is
        # absolute. Joining the full random local path would therefore silently
        # ignore local_dir, so only its final path component is used as the file name.
        file_name = os.path.basename(os.path.normpath(ctx.file_access.get_random_local_path()))
        uri = os.path.join(local_dir, file_name)
        with open(uri, "w+b") as outfile:
            cloudpickle.dump(python_val, outfile)

        remote_path = ctx.file_access.get_random_remote_path(uri)
        ctx.file_access.put_data(uri, remote_path, is_multipart=False)
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))
示例#20
0
def test_dolt_table_to_literal(mocker):
    """Decoding a DoltTable literal loads the table data (Dolt and CSV I/O are mocked)."""
    expected_df = pandas.DataFrame()
    patches = (
        ("dolt_integrations.core.load", None),
        ("doltcli.Dolt", None),
        ("pandas.read_csv", expected_df),
    )
    for target, value in patches:
        mocker.patch(target, return_value=value)

    config_struct = Struct()
    config_struct.update({"config": {"db_path": "", "tablename": "t"}})
    lv = Literal(Scalar(generic=config_struct))

    result = DoltTableNameTransformer.to_python_value(self=None, ctx=None, lv=lv, expected_python_type=DoltTable)
    assert result.data.equals(expected_df)
示例#21
0
    def to_literal(self, ctx: FlyteContext, python_val: np.ndarray,
                   python_type: Type[np.ndarray],
                   expected: LiteralType) -> Literal:
        """Save a NumPy array to a ``.npy`` file, upload it, and return a single-file blob Literal.

        ``np.save`` is called with ``allow_pickle=False``, so object-dtype arrays are
        rejected instead of being serialized with pickle.
        """
        npy_blob_type = _core_types.BlobType(
            format=self.NUMPY_ARRAY_FORMAT,
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )
        meta = BlobMetadata(type=npy_blob_type)

        npy_file = ctx.file_access.get_random_local_path() + ".npy"
        pathlib.Path(npy_file).parent.mkdir(parents=True, exist_ok=True)
        np.save(file=npy_file, arr=python_val, allow_pickle=False)

        destination = ctx.file_access.get_random_remote_path(npy_file)
        ctx.file_access.put_data(npy_file, destination, is_multipart=False)
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=destination)))
示例#22
0
 def to_literal(
     self,
     ctx: FlyteContext,
     python_val: MyDataset,
     python_type: Type[MyDataset],
     expected: LiteralType,
 ) -> Literal:
     """
     Convert a ``MyDataset`` into its Literal form.

     The dataset's ``base_dir`` is uploaded to a random remote directory chosen by
     Flyte. The returned blob Literal points at that directory.
     """
     destination = ctx.file_access.get_random_remote_directory()
     ctx.file_access.upload_directory(python_val.base_dir, destination)

     pointer = Blob(uri=destination, metadata=BlobMetadata(type=self._TYPE_INFO))
     return Literal(scalar=Scalar(blob=pointer))
示例#23
0
    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: Union[FlyteFile, FlyteSchema, str],
        python_type: Type[GreatExpectationsType],
        expected: LiteralType,
    ) -> Literal:
        """Delegate serialization based on the datatype configured on ``python_type``.

        FlyteSchema and FlyteFile datatypes are handed to their own transformers.
        A ``str`` datatype becomes a string primitive literal.

        :raises TypeError: for any other datatype.
        """
        datatype = GreatExpectationsTypeTransformer.get_config(python_type)[0]

        if issubclass(datatype, FlyteSchema):
            return FlyteSchemaTransformer().to_literal(ctx, python_val, datatype, expected)
        if issubclass(datatype, FlyteFile):
            return FlyteFilePathTransformer().to_literal(ctx, python_val, datatype, expected)
        if issubclass(datatype, str):
            return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val)))
        raise TypeError(f"{datatype} is not a supported type")
示例#24
0
 def to_literal(
     self,
     ctx: FlyteContext,
     python_val: pandas.DataFrame,
     python_type: Type[pandera.typing.DataFrame],
     expected: LiteralType,
 ) -> Literal:
     """Write a pandas DataFrame as parquet with the pandera schema's column dtypes and upload it.

     The local directory is uploaded (multipart) to a random remote directory, which
     the returned schema Literal references.

     :raises AssertionError: if ``python_val`` is not a pandas DataFrame.
     """
     if not isinstance(python_val, pandas.DataFrame):
         raise AssertionError(
             f"Only Pandas Dataframe object can be returned from a task, returned object type {type(python_val)}"
         )

     staging_dir = ctx.file_access.get_random_local_directory()
     writer = PandasSchemaWriter(
         local_dir=staging_dir,
         cols=self._get_col_dtypes(python_type),
         fmt=SchemaFormat.PARQUET,
     )
     writer.write(python_val)

     target_dir = ctx.file_access.get_random_remote_directory()
     ctx.file_access.put_data(staging_dir, target_dir, is_multipart=True)
     schema = Schema(target_dir, self._get_schema_type(python_type))
     return Literal(scalar=Scalar(schema=schema))
示例#25
0
    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: PyTorchCheckpoint,
        python_type: Type[PyTorchCheckpoint],
        expected: LiteralType,
    ) -> Literal:
        """Save a PyTorchCheckpoint to a ``.pt`` file, upload it, and return a single-file blob Literal.

        Each dataclass field of ``python_val`` with a truthy value is processed:

        * ``module`` / ``optimizer`` contribute ``<name>_state_dict`` (their ``state_dict()``).
        * ``hyperparameters`` — a dict, NamedTuple or dataclass instance — is merged into
          the top level of the saved mapping.

        :raises TypeTransformerFailedError: if nothing would be saved, or if
            ``hyperparameters`` is a plain tuple (it has no field names to save under).
        """
        meta = BlobMetadata(
            type=_core_types.BlobType(
                format=self.PYTORCH_CHECKPOINT_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
            )
        )

        local_path = ctx.file_access.get_random_local_path() + ".pt"
        pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)

        to_save = {}
        for field in fields(python_val):
            value = getattr(python_val, field.name)
            # Falsy values (e.g. None) are skipped entirely.
            if not value:
                continue

            if field.name in ["module", "optimizer"]:
                to_save[field.name + "_state_dict"] = value.state_dict()
            elif field.name == "hyperparameters":
                if isinstance(value, dict):
                    to_save.update(value)
                elif isinstance(value, tuple):
                    # Only NamedTuples provide _asdict(); a plain tuple would
                    # otherwise fail with an opaque AttributeError.
                    if not hasattr(value, "_asdict"):
                        raise TypeTransformerFailedError(
                            f"hyperparameters must be a dict, NamedTuple or dataclass, got plain tuple {value}"
                        )
                    to_save.update(value._asdict())
                elif is_dataclass(value):
                    to_save.update(asdict(value))
                # NOTE(review): hyperparameters of any other type are silently ignored.

        if not to_save:
            raise TypeTransformerFailedError(f"Cannot save empty {python_val}")

        # save checkpoint to a file
        torch.save(to_save, local_path)

        remote_path = ctx.file_access.get_random_remote_path(local_path)
        ctx.file_access.put_data(local_path, remote_path, is_multipart=False)
        return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))
示例#26
0
def test_dict_transformer():
    """DictTransformer: literal-type inference for nested dicts and rejection of malformed map literals."""
    d = DictTransformer()

    def assert_struct(lit: LiteralType):
        assert lit is not None
        assert lit.simple == SimpleType.STRUCT

    def recursive_assert(lit: LiteralType,
                         expected: LiteralType,
                         expected_depth: int = 1,
                         curr_depth: int = 0):
        # Walk down map_value_type links until a leaf type, checking the depth bound at every level.
        while True:
            assert curr_depth <= expected_depth
            assert lit is not None
            if lit.map_value_type is None:
                assert lit == expected
                return
            lit = lit.map_value_type
            curr_depth += 1

    # Type inference: non-str keys or untyped dicts become STRUCT.
    assert_struct(d.get_literal_type(dict))
    assert_struct(d.get_literal_type(typing.Dict[int, int]))

    # (python type, expected leaf simple type, max nesting depth)
    nested_cases = [
        (typing.Dict[str, str], SimpleType.STRING, 1),
        (typing.Dict[str, int], SimpleType.INTEGER, 1),
        (typing.Dict[str, datetime.datetime], SimpleType.DATETIME, 1),
        (typing.Dict[str, datetime.timedelta], SimpleType.DURATION, 1),
        (typing.Dict[str, dict], SimpleType.STRUCT, 1),
        (typing.Dict[str, typing.Dict[str, str]], SimpleType.STRING, 2),
        (typing.Dict[str, typing.Dict[int, str]], SimpleType.STRUCT, 2),
        (typing.Dict[str, typing.Dict[str, typing.Dict[str, str]]], SimpleType.STRING, 3),
        (typing.Dict[str, typing.Dict[str, typing.Dict[str, dict]]], SimpleType.STRUCT, 3),
        (typing.Dict[str, typing.Dict[str, typing.Dict[int, dict]]], SimpleType.STRUCT, 2),
    ]
    for py_type, leaf_simple, depth in nested_cases:
        recursive_assert(d.get_literal_type(py_type), LiteralType(simple=leaf_simple), expected_depth=depth)

    ctx = FlyteContext.current_context()

    empty_literal = d.to_literal(ctx, {}, typing.Dict, LiteralType(SimpleType.STRUCT))
    assert d.to_python_value(ctx, empty_literal, typing.Dict) == {}

    # Literal to python: malformed literals are rejected with TypeError.
    with pytest.raises(TypeError):
        d.to_python_value(ctx, Literal(scalar=Scalar(primitive=Primitive(integer=10))), dict)
    with pytest.raises(TypeError):
        d.to_python_value(ctx, Literal(), dict)
    with pytest.raises(TypeError):
        d.to_python_value(ctx, Literal(map=LiteralMap(literals={"x": None})), dict)
    with pytest.raises(TypeError):
        d.to_python_value(ctx, Literal(map=LiteralMap(literals={"x": None})), typing.Dict[int, str])

    well_formed = Literal(map=LiteralMap(literals={"x": Literal(scalar=Scalar(primitive=Primitive(integer=1)))}))
    d.to_python_value(ctx, well_formed, typing.Dict[str, int])
示例#27
0
    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: Union[StructuredDataset, typing.Any],
        python_type: Union[Type[StructuredDataset], Type],
        expected: LiteralType,
    ) -> Literal:
        """Convert a StructuredDataset or a bare dataframe into a structured-dataset Literal.

        Cases, in the order they are checked:

        1. A StructuredDataset wrapping an existing literal (e.g. a pass-through task
           input) is returned as that literal, without invoking any encoder.
        2. A StructuredDataset with no dataframe but a ``uri`` becomes a literal
           pointing at that uri.
        3. A StructuredDataset wrapping a dataframe is encoded with the encoder chosen
           by dataframe type, protocol (from the uri, or the type's default) and format.
        4. Anything else is treated as a bare dataframe and encoded with the default
           format and protocol registered for ``python_type``.

        :raises ValueError: if a StructuredDataset has both a literal and a dataframe,
            or has neither a dataframe nor a uri.
        """
        python_type, *_ = extract_cols_and_format(python_type)

        # Prefer the interface's declared dataset type; otherwise fall back to the
        # default format registered for python_type (None if unregistered).
        if expected and expected.structured_dataset_type:
            declared = expected.structured_dataset_type
            sdt = StructuredDatasetType(
                columns=declared.columns,
                format=declared.format,
                external_schema_type=declared.external_schema_type,
                external_schema_bytes=declared.external_schema_bytes,
            )
        else:
            sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, None))

        if not (issubclass(python_type, StructuredDataset) and isinstance(python_val, StructuredDataset)):
            # Case 4: a bare dataframe instance; wrap it with the registered defaults.
            fmt = self.DEFAULT_FORMATS[python_type]
            protocol = self.DEFAULT_PROTOCOLS[python_type]
            meta = StructuredDatasetMetadata(
                structured_dataset_type=expected.structured_dataset_type if expected else None
            )
            wrapped = StructuredDataset(dataframe=python_val, metadata=meta)
            return self.encode(ctx, wrapped, python_type, protocol, fmt, sdt)

        # Case 1: pass-through of an existing literal.
        if python_val._literal_sd is not None:
            if python_val.dataframe is not None:
                raise ValueError(
                    f"Shouldn't have specified both literal {python_val._literal_sd} and dataframe {python_val.dataframe}"
                )
            return Literal(scalar=Scalar(structured_dataset=python_val._literal_sd))

        # Case 2: only a uri. Local executions also take this path when a python
        # StructuredDataset is converted to a literal and back.
        if python_val.dataframe is None:
            if not python_val.uri:
                raise ValueError(
                    f"If dataframe is not specified, then the uri should be specified. {python_val}"
                )
            uri_model = literals.StructuredDataset(
                uri=python_val.uri,
                metadata=StructuredDatasetMetadata(structured_dataset_type=sdt),
            )
            return Literal(scalar=Scalar(structured_dataset=uri_model))

        # Case 3: wraps a dataframe that needs encoding.
        df_type = type(python_val.dataframe)
        if python_val.uri is None:
            protocol = self.DEFAULT_PROTOCOLS[df_type]
        else:
            protocol = protocol_prefix(python_val.uri)
        fmt = sdt.format or typing.cast(StructuredDataset, python_val).DEFAULT_FILE_FORMAT
        return self.encode(ctx, python_val, df_type, protocol, fmt, sdt)
示例#28
0
def _register_default_type_transformers():
    """Register flytekit's built-in transformers with ``TypeEngine``.

    Registered, in order:
      * ``SimpleTransformer``s for int, float, bool, str, datetime, timedelta and None,
      * the list, dict, text-IO, path-like and binary-IO transformers,
      * ``RestrictedType`` entries for tuple, ``typing.Tuple`` and ``typing.NamedTuple``.

    Each transformer is constructed and then immediately registered before the
    next one is built.
    """

    def _defaults():
        # Scalar types. Each SimpleTransformer is given: a name, the python type,
        # the Flyte literal type, a python -> Literal converter and a
        # Literal -> python converter.
        yield SimpleTransformer(
            "int",
            int,
            _primitives.Integer.to_flyte_literal_type(),
            lambda x: Literal(scalar=Scalar(primitive=Primitive(integer=x))),
            lambda x: x.scalar.primitive.integer,
        )
        yield SimpleTransformer(
            "float",
            float,
            _primitives.Float.to_flyte_literal_type(),
            lambda x: Literal(scalar=Scalar(primitive=Primitive(float_value=x))),
            _check_and_covert_float,
        )
        yield SimpleTransformer(
            "bool",
            bool,
            _primitives.Boolean.to_flyte_literal_type(),
            lambda x: Literal(scalar=Scalar(primitive=Primitive(boolean=x))),
            lambda x: x.scalar.primitive.boolean,
        )
        yield SimpleTransformer(
            "str",
            str,
            _primitives.String.to_flyte_literal_type(),
            lambda x: Literal(scalar=Scalar(primitive=Primitive(string_value=x))),
            lambda x: x.scalar.primitive.string_value,
        )
        yield SimpleTransformer(
            "datetime",
            _datetime.datetime,
            _primitives.Datetime.to_flyte_literal_type(),
            lambda x: Literal(scalar=Scalar(primitive=Primitive(datetime=x))),
            lambda x: x.scalar.primitive.datetime,
        )
        yield SimpleTransformer(
            "timedelta",
            _datetime.timedelta,
            _primitives.Timedelta.to_flyte_literal_type(),
            lambda x: Literal(scalar=Scalar(primitive=Primitive(duration=x))),
            lambda x: x.scalar.primitive.duration,
        )
        # None maps to the NONE simple type; both converters just return None.
        yield SimpleTransformer(
            "none",
            None,
            _type_models.LiteralType(simple=_type_models.SimpleType.NONE),
            lambda x: None,
            lambda x: None,
        )

        # Container and IO transformers.
        yield ListTransformer()
        yield DictTransformer()
        yield TextIOTransformer()
        yield PathLikeTransformer()
        yield BinaryIOTransformer()

        # Untyped tuples are unsupported because there is no way to know what the
        # inner type is. typing's Tuples are unsupported as well: even though you
        # can look inside them, Flyte's type system doesn't support these currently.
        # Confusing note: typing.NamedTuple is restricted here even though task
        # functions themselves can return one; what is rejected is a NamedTuple
        # used as a value type, e.g. a task's NamedTuple return signature that
        # contains another NamedTuple inside it.
        # Flyte IDL could fake tuples as structs, but that is deliberately held off
        # for now, as the IDL may be amended formally to support tuples.
        yield RestrictedType("non typed tuple", tuple)
        yield RestrictedType("non typed tuple", typing.Tuple)
        yield RestrictedType("named tuple", typing.NamedTuple)

    # Consuming the generator lazily keeps "build one, register one" ordering.
    for transformer in _defaults():
        TypeEngine.register(transformer)
示例#29
0
 def dict_to_generic_literal(v: dict) -> Literal:
     """Wrap a plain dict in a Literal whose scalar carries a generic ``Struct``.

     The dict is dumped to a JSON string and parsed back into a fresh
     ``Struct``. Its contents must therefore survive ``_json.dumps``.

     :param v: the dictionary to convert.
     :return: ``Literal(scalar=Scalar(generic=<Struct>))``.
     """
     serialized = _json.dumps(v)
     generic = _json_format.Parse(serialized, _struct.Struct())
     return Literal(scalar=Scalar(generic=generic))
示例#30
0
 def to_literal(self, ctx: FlyteContext, python_val: T,
                python_type: Type[T], expected: LiteralType) -> Literal:
     """Convert ``python_val`` into a Literal holding a generic ``Struct``.

     ``python_val`` is turned into a dict by ``_MessageToDict`` (presumably a
     protobuf message; verify against callers). That dict is copied into a new
     ``Struct``. ``ctx``, ``python_type`` and ``expected`` are not used.
     """
     generic = Struct()
     payload = _MessageToDict(python_val)
     generic.update(payload)
     return Literal(scalar=Scalar(generic=generic))