Exemplo n.º 1
0
    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: TensorFlow2ONNX,
        python_type: Type[TensorFlow2ONNX],
        expected: LiteralType,
    ) -> Literal:
        """Export ``python_val.model`` to ONNX, upload it, and return a single-file blob literal.

        The ONNX export options come from the config extracted from ``python_type``
        by ``extract_config``.

        Raises:
            TypeTransformerFailedError: if no (truthy) config is attached to ``python_type``.
        """
        python_type, config = extract_config(python_type)
        if not config:
            raise TypeTransformerFailedError(f"{python_type}'s config is None")

        remote_path = ctx.file_access.get_random_remote_path()
        # A shallow copy of the config's attributes is handed to to_onnx
        # (presumably so the config object itself is not mutated — confirm).
        exported_file = to_onnx(ctx, python_val.model, config.__dict__.copy())
        ctx.file_access.put_data(exported_file, remote_path, is_multipart=False)

        onnx_blob_type = BlobType(
            format=self.ONNX_FORMAT,
            dimensionality=BlobType.BlobDimensionality.SINGLE,
        )
        blob = Blob(uri=remote_path, metadata=BlobMetadata(type=onnx_blob_type))
        return Literal(scalar=Scalar(blob=blob))
Exemplo n.º 2
0
def test_dont_convert_remotes():
    """A remote ``s3://`` FlyteFile passed through a dynamic task must keep its URI."""

    @task
    def t1(in1: FlyteFile):
        print(in1)

    @dynamic
    def dyn(in1: FlyteFile):
        t1(in1=in1)

    remote_file = FlyteFile("s3://anything")

    current = context_manager.FlyteContext.current_context()
    settings = context_manager.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    with current.new_serialization_settings(serialization_settings=settings) as ctx:
        with ctx.new_execution_context(mode=ExecutionState.Mode.TASK_EXECUTION) as ctx:
            literal = TypeEngine.to_literal(
                ctx,
                remote_file,
                FlyteFile,
                BlobType("", dimensionality=BlobType.BlobDimensionality.SINGLE),
            )
            inputs = LiteralMap(literals={"in1": literal})
            workflow = dyn.dispatch_execute(ctx, inputs)
            bound_uri = workflow.nodes[0].inputs[0].binding.scalar.blob.uri
            assert bound_uri == "s3://anything"
Exemplo n.º 3
0
def test_get_literal_type():
    """FlytePickle must map to a single-file blob in the Python-pickle format."""
    transformer = FlytePickleTransformer()
    actual = transformer.get_literal_type(FlytePickle)

    pickle_blob = BlobType(
        format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT,
        dimensionality=BlobType.BlobDimensionality.SINGLE,
    )
    assert actual == LiteralType(blob=pickle_blob)
Exemplo n.º 4
0
def test_dont_convert_remotes():
    """A remote ``s3://`` FlyteFile passed through a dynamic task must keep its URI.

    Also checks that an ``int`` is rejected when converted to a FlyteFile literal.
    """

    @task
    def t1(in1: FlyteFile):
        print(in1)

    @dynamic
    def dyn(in1: FlyteFile):
        t1(in1=in1)

    remote_file = FlyteFile("s3://anything")
    single_blob = BlobType("", dimensionality=BlobType.BlobDimensionality.SINGLE)

    manager = context_manager.FlyteContextManager
    settings = flytekit.configuration.SerializationSettings(
        project="test_proj",
        domain="test_domain",
        version="abc",
        image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
        env={},
    )
    serialization_ctx = manager.current_context().with_serialization_settings(settings)

    with manager.with_context(serialization_ctx):
        ctx = manager.current_context()
        task_state = ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        with manager.with_context(ctx.with_execution_state(task_state)) as ctx:
            literal = TypeEngine.to_literal(ctx, remote_file, FlyteFile, single_blob)
            workflow = dyn.dispatch_execute(ctx, LiteralMap(literals={"in1": literal}))
            bound_uri = workflow.nodes[0].inputs[0].binding.scalar.blob.uri
            assert bound_uri == "s3://anything"

            with pytest.raises(TypeError, match="No automatic conversion found from type <class 'int'>"):
                TypeEngine.to_literal(ctx, 3, FlyteFile, single_blob)
Exemplo n.º 5
0
def test_file_format_getting_python_value():
    """Converting a "txt" blob literal to ``FlyteFile["txt"]`` yields a FlyteFile with extension "txt"."""
    transformer = TypeEngine.get_transformer(FlyteFile)

    ctx = FlyteContext.current_context()

    # This file probably won't exist, but it's okay. It won't be downloaded unless we try to read the thing returned
    lv = Literal(
        scalar=Scalar(
            blob=Blob(
                metadata=BlobMetadata(
                    # Named enum member instead of the bare literal 0 (SINGLE == 0).
                    type=BlobType(format="txt", dimensionality=BlobType.BlobDimensionality.SINGLE)
                ),
                uri="file:///tmp/test",
            )
        )
    )

    pv = transformer.to_python_value(ctx, lv, expected_python_type=FlyteFile["txt"])
    assert isinstance(pv, FlyteFile)
    assert pv.extension() == "txt"
Exemplo n.º 6
0
def test_to_python_value_and_literal():
    """Round-trip a NumPy array through NumpyArrayTransformer and compare element-wise."""
    ctx = context_manager.FlyteContext.current_context()
    transformer = NumpyArrayTransformer()
    original = np.array([1, 2, 3])
    literal_type = transformer.get_literal_type(np.ndarray)

    literal = transformer.to_literal(ctx, original, type(original), literal_type)  # type: ignore
    blob = literal.scalar.blob
    expected_metadata = BlobMetadata(
        type=BlobType(
            format=NumpyArrayTransformer.NUMPY_ARRAY_FORMAT,
            dimensionality=BlobType.BlobDimensionality.SINGLE,
        )
    )
    assert blob.metadata == expected_metadata
    assert blob.uri is not None

    restored = transformer.to_python_value(ctx, literal, np.ndarray)
    assert_array_equal(restored, original)
Exemplo n.º 7
0
class MyDatasetTransformer(TypeTransformer[MyDataset]):
    """Type transformer that tells Flytekit how to (de)serialize ``MyDataset``.

    A ``MyDataset`` travels as a multipart blob. Its ``base_dir`` is uploaded as a
    directory, and the literal carries the remote directory URI.
    """

    # Opaque ("binary") blob that may span multiple files (MULTIPART).
    _TYPE_INFO = BlobType(
        format="binary", dimensionality=BlobType.BlobDimensionality.MULTIPART
    )

    def __init__(self):
        super().__init__(name="mydataset-transform", t=MyDataset)

    def get_literal_type(self, t: Type[MyDataset]) -> LiteralType:
        """
        Tell the Flytekit type system which Flyte literal type ``MyDataset`` corresponds to.
        Here that is a blob of format ``binary`` (Flyte should not try to introspect it)
        whose dimensionality is MULTIPART, i.e. it may contain more than one file.
        """
        return LiteralType(blob=self._TYPE_INFO)

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: MyDataset,
        python_type: Type[MyDataset],
        expected: LiteralType,
    ) -> Literal:
        """
        Convert a ``MyDataset`` object into its Literal representation.

        Uploads ``python_val.base_dir`` to a random remote directory obtained from
        ``ctx.file_access`` and returns a blob literal pointing at that directory.
        """
        # Step 1: upload all the data into a remote location recommended by Flyte.
        remote_dir = ctx.file_access.get_random_remote_directory()
        ctx.file_access.upload_directory(python_val.base_dir, remote_dir)
        # Step 2: return a pointer to remote_dir in the form of a literal.
        return Literal(
            scalar=Scalar(
                blob=Blob(uri=remote_dir, metadata=BlobMetadata(type=self._TYPE_INFO))
            )
        )

    def to_python_value(
        self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[MyDataset]
    ) -> MyDataset:
        """
        Re-hydrate a ``MyDataset`` from a Flyte Literal.

        Downloads the directory referenced by ``lv.scalar.blob.uri`` into a fresh
        random local directory and wraps that directory in a ``MyDataset``.
        """
        # Step 1: download the remote data locally.
        local_dir = ctx.file_access.get_random_local_directory()
        ctx.file_access.download_directory(lv.scalar.blob.uri, local_dir)
        # Step 2: create the MyDataset object.
        return MyDataset(base_dir=local_dir)
Exemplo n.º 8
0
def test_to_python_value_and_literal():
    """Round-trip a plain string through FlytePickleTransformer."""
    ctx = context_manager.FlyteContext.current_context()
    transformer = FlytePickleTransformer()
    original = "fake_output"
    literal_type = transformer.get_literal_type(FlytePickle)

    literal = transformer.to_literal(ctx, original, type(original), literal_type)  # type: ignore
    blob = literal.scalar.blob
    pickle_metadata = BlobMetadata(
        type=BlobType(
            format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT,
            dimensionality=BlobType.BlobDimensionality.SINGLE,
        )
    )
    assert blob.metadata == pickle_metadata
    assert blob.uri is not None

    restored = transformer.to_python_value(ctx, literal, str)
    assert restored == original
Exemplo n.º 9
0
def test_to_python_value_and_literal(transformer, python_type, format,
                                     python_val):
    """Round-trip ``python_val`` through ``transformer`` and check the restored value.

    Tensors must compare equal. Modules must have the same number of parameters,
    each equal to its counterpart. Any other value must come back as a ``dict``.
    """
    ctx = context_manager.FlyteContext.current_context()
    tf = transformer
    lt = tf.get_literal_type(python_type)

    lv = tf.to_literal(ctx, python_val, type(python_val), lt)  # type: ignore
    assert lv.scalar.blob.metadata == BlobMetadata(type=BlobType(
        format=format,
        dimensionality=BlobType.BlobDimensionality.SINGLE,
    ))
    assert lv.scalar.blob.uri is not None

    output = tf.to_python_value(ctx, lv, python_type)
    if isinstance(python_val, torch.Tensor):
        assert torch.equal(output, python_val)
    elif isinstance(python_val, torch.nn.Module):
        restored_params = list(output.parameters())
        expected_params = list(python_val.parameters())
        # zip() alone stops at the shorter list, so a model that lost or gained
        # parameters would otherwise pass silently.
        assert len(restored_params) == len(expected_params)
        for p1, p2 in zip(restored_params, expected_params):
            assert torch.equal(p1.data, p2.data)
    else:
        assert isinstance(output, dict)
Exemplo n.º 10
0
 def get_literal_type(self, t: Type[ScikitLearn2ONNX]) -> LiteralType:
     """Describe ``ScikitLearn2ONNX`` values as a single-file blob in ``self.ONNX_FORMAT``."""
     single_file_onnx = BlobType(
         format=self.ONNX_FORMAT,
         dimensionality=BlobType.BlobDimensionality.SINGLE,
     )
     return LiteralType(blob=single_file_onnx)
Exemplo n.º 11
0
 def _blob_type(self, format: str) -> BlobType:
     """Build a single-file ``BlobType`` tagged with the given ``format``."""
     single = BlobType.BlobDimensionality.SINGLE
     blob_type = BlobType(format=format, dimensionality=single)
     return blob_type
Exemplo n.º 12
0
def test_get_literal_type(transformer, python_type, format):
    """The transformer must report a single-file blob of the parametrized ``format``."""
    actual = transformer.get_literal_type(python_type)
    expected_blob = BlobType(
        format=format,
        dimensionality=BlobType.BlobDimensionality.SINGLE,
    )
    assert actual == LiteralType(blob=expected_blob)
Exemplo n.º 13
0
def test_get_literal_type():
    """``np.ndarray`` must map to a single-file blob in the NumPy array format."""
    transformer = NumpyArrayTransformer()
    actual = transformer.get_literal_type(np.ndarray)

    numpy_blob = BlobType(
        format=NumpyArrayTransformer.NUMPY_ARRAY_FORMAT,
        dimensionality=BlobType.BlobDimensionality.SINGLE,
    )
    assert actual == LiteralType(blob=numpy_blob)