예제 #1
0
 def setUp(self):
     """Prepare the test case.

     Turns off the ``skl2onnx`` logger, then calls
     ``register_rewritten_operators()`` while ``ResourceWarning`` is
     ignored (the filter is restored when the ``catch_warnings`` block
     exits). Returns ``self``.
     """
     getLogger('skl2onnx').disabled = True
     with warnings.catch_warnings():
         # Only silence ResourceWarning for the duration of the registration.
         warnings.simplefilter("ignore", ResourceWarning)
         register_rewritten_operators()
     return self
예제 #2
0
 def setup(self, runtime, N, nf, opset, dtype, optim):
     """asv API: build one benchmark case.

     Silences the ``skl2onnx`` logger, registers converters and rewritten
     operators, loads the pickled case named by ``self._name(nf, opset,
     dtype)``, resizes its data to ``N`` rows with ``make_n_rows``, and
     creates the ONNX graph plus runtime callables. The results are stored
     as ``self.onx`` and as the attributes ``rt_<runtime>``,
     ``rt_fct_<runtime>`` and ``rt_fct_track_<runtime>``. Finally calls
     ``set_config(assume_finite=True)``.
     """
     getLogger('skl2onnx').disabled = True
     register_converters()
     register_rewritten_operators()

     # NOTE(review): pickle.load can execute arbitrary code; this is only
     # acceptable if the file comes from a trusted source (presumably
     # generated by this benchmark suite) -- confirm.
     with open(self._name(nf, opset, dtype), "rb") as handle:
         case = pickle.load(handle)
     self.stored = case
     self.model = case['model']
     self.X, self.y = make_n_rows(case['X'], N, case['y'])

     onx, rt_, rt_fct_, rt_fct_track_ = self._create_onnx_and_runtime(
         runtime, self.model, self.X, opset, dtype, optim)
     self.onx = onx
     # Attribute names are suffixed with the runtime name.
     for prefix, value in (("rt_", rt_),
                           ("rt_fct_", rt_fct_),
                           ("rt_fct_track_", rt_fct_track_)):
         setattr(self, prefix + runtime, value)
     set_config(assume_finite=True)
 def setUp(self):
     """Register rewritten operators and sanity-check what was registered.

     Turns off the ``skl2onnx`` logger and calls
     ``register_rewritten_operators()`` with ``ResourceWarning`` ignored.
     The result must contain more than two entries, and each of the first
     two must contain ``'SklearnFunctionTransformer'``.
     """
     getLogger('skl2onnx').disabled = True
     with warnings.catch_warnings():
         warnings.simplefilter("ignore", ResourceWarning)
         registered = register_rewritten_operators()
     self.assertGreater(len(registered), 2)
     # The first two registrations are expected to be the
     # SklearnFunctionTransformer rewrites.
     for index in (0, 1):
         self.assertIn('SklearnFunctionTransformer', registered[index])
예제 #4
0
 def setUp(self):
     """Turn off the ``skl2onnx`` logger and register rewritten operators."""
     getLogger('skl2onnx').disabled = True
     register_rewritten_operators()
예제 #5
0
# -*- coding: utf-8 -*-
import sys
import os
import alabaster
from pyquickhelper.helpgen.default_conf import set_sphinx_variables, get_default_stylesheet
from sklearn.experimental import enable_hist_gradient_boosting
from mlprodict.onnx_conv import register_converters, register_rewritten_operators
register_converters()
try:
    register_rewritten_operators()
except KeyError:
    # A KeyError here is treated as "the installed sklearn-onnx has no
    # converter for HistGradientBoosting*": warn and keep building the docs
    # instead of failing.
    import warnings
    warnings.warn("converter for HistGradientBoosting* does not exist. "
                  "Upgrade sklearn-onnx")

# Local documentation helper modules (presumably Sphinx extensions, given the
# "_exts" folder name). They are imported directly when already on sys.path;
# otherwise the "_exts" folder next to this file is added and the imports
# are retried.
try:
    import generate_visual_graphs
    import generate_automated_pages
except ImportError:  # pragma: no cover
    this = os.path.dirname(__file__)
    # Fallback: make "<this dir>/_exts" importable, then retry.
    sys.path.append(os.path.join(this, '_exts'))
    import generate_visual_graphs
    import generate_automated_pages

# Put the directory containing this configuration file first on sys.path.
# (os.path.dirname is exactly os.path.split(...)[0]; the former single-argument
# os.path.join call was a no-op.)
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))

# Absolute path of the "phdoc_templates" folder next to this file.
local_template = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                              "phdoc_templates")

set_sphinx_variables(
    __file__,