Exemplo n.º 1
0
    def test_multiple_transform(self):
        """Check that every intermediate output of ``mul_1.pb`` can be fitted.

        ``OnnxTransformer.enumerate_create`` must produce at least one
        ``(key, transformer)`` pair for the model.  Each transformer is
        fitted and then applied to a small two-column float DataFrame;
        a ``RuntimeError`` raised by ``transform`` is tolerated.
        """
        x = pandas.DataFrame(data=[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                             columns=["X1", "X2"])
        name = self.get_name("mul_1.pb")
        with open(name, "rb") as f:
            content = f.read()

        res = list(OnnxTransformer.enumerate_create(content))
        # A bare ``assert`` is removed under ``python -O``; the unittest
        # assertion always runs and reports a proper test failure.
        self.assertGreater(len(res), 0)
        for _, tr in res:
            tr.fit()
            try:
                tr.transform(x)
            except RuntimeError:
                # Deliberate best effort: presumably some intermediate
                # outputs cannot be evaluated on this input (TODO confirm).
                # Only RuntimeError is tolerated; anything else fails.
                pass
Exemplo n.º 2
0
 def test_pipeline_iris(self):
     """Check every intermediate output of a converted PCA + LR pipeline.

     The pipeline is trained on iris, converted to ONNX, and one
     transformer per intermediate output is created from the serialized
     model.  The test asserts that output keys are unique, that each
     output keeps one row per sample, and that the set of output shapes
     is exactly the expected one.
     """
     iris = load_iris()
     X, y = iris.data, iris.target
     n_samples = X.shape[0]
     pipe = make_pipeline(PCA(n_components=2), LogisticRegression())
     pipe.fit(X, y)
     # NOTE(review): the declared input type fixes the batch dimension to 1
     # while all n_samples rows are transformed below; confirm the runtime
     # accepts this (newer converters usually use ``None`` there).
     onx = convert_sklearn(pipe,
                           initial_types=[('input',
                                           FloatTensorType(
                                               (1, X.shape[1])))])
     onx_bytes = onx.SerializeToString()
     res = list(OnnxTransformer.enumerate_create(onx_bytes))
     outputs = []
     shapes = set()
     for k, tr in res:
         outputs.append(k)
         tr.fit()
         # Distinct name: the original code rebound ``y`` here and so
         # silently clobbered the training target.
         transformed = tr.transform(X)
         self.assertEqual(transformed.shape[0], n_samples)
         shapes.add(transformed.shape)
     # Every enumerated output key must be unique.
     self.assertEqual(len(set(outputs)), len(outputs))
     self.assertEqual(shapes, {(n_samples, 3), (n_samples, 4),
                               (n_samples, 2), (n_samples, )})
fig, ax = plt.subplots(figsize=(40, 20))  # one large figure with a single axes
ax.imshow(image)  # ``image`` is presumably built earlier in the script (not shown)
ax.axis('off')  # hide axis lines, ticks and labels around the picture

###########################
# Visualize intermediate outputs
# ++++++++++++++++++++++++++++++

from skonnxrt.sklapi import OnnxTransformer  # noqa

with open("TfidfVectorizer.onnx", "rb") as f:
    content = f.read()

# Named ``text`` rather than ``input`` so the builtin ``input()`` is not
# shadowed.  NOTE(review): ``corpus`` is defined earlier in the script.
text = corpus[2]
print("with input:", [text])
# Each item pairs a node name with a transformer for that node's output;
# fit it and print what the node computes for the sample text.
for node_name, transformer in OnnxTransformer.enumerate_create(content):
    print("-> node '{}'".format(node_name))
    transformer.fit()
    print(transformer.transform(text))

#################################
# **Versions used for this example**

import numpy  # noqa
import sklearn  # noqa

# Table-driven report: (label, module) pairs, printed in a fixed order.
for _label, _module in (("numpy:", numpy), ("scikit-learn:", sklearn)):
    print(_label, _module.__version__)

import onnx  # noqa
import onnxruntime  # noqa
import skl2onnx  # noqa
import skonnxrt  # noqa

for _label, _module in (("onnx: ", onnx),
                        ("onnxruntime: ", onnxruntime),
                        ("scikit-onnxruntime: ", skonnxrt),
                        ("skl2onnx: ", skl2onnx)):
    print(_label, _module.__version__)