コード例 #1
0
 def test_plot_isotonic_regression(self):
     """Run the isotonic-regression example script without ONNX conversion.

     A ``NameError`` is tolerated only if its message mentions ``'y'``
     (a local variable referenced from a list comprehension in the script);
     any other ``NameError`` makes the test fail, and other exceptions
     propagate.
     """
     script = os.path.join(
         os.path.dirname(__file__), "data", "plot_isotonic_regression.py")
     try:
         verify_script(script, try_onnx=False)
     except NameError as exc:
         # Known limitation: comprehension scoping hides a local variable.
         self.assertIn("'y'", str(exc))
コード例 #2
0
 def test_plot_kernel_ridge_regression(self):
     """Run the kernel-ridge-regression example without ONNX conversion.

     Checks that ``verify_script`` returns a dictionary, then passes the
     local names ending with ``_onnx`` to ``assertNotEmpty``.
     """
     script = os.path.join(
         os.path.dirname(__file__), "data", "plot_kernel_ridge_regression.py")
     result = verify_script(script, try_onnx=False)
     self.assertIsInstance(result, dict)
     local_vars = result['locals']
     # NOTE(review): ``filter`` returns a lazy iterator without ``len``; if
     # ``assertNotEmpty`` only rejects ``None``/zero-length values, this check
     # can never fail. Consider ``list(filter(...))`` once it is confirmed
     # that ``_onnx`` names exist when ``try_onnx=False``.
     onnx_names = filter(lambda name: name.endswith('_onnx'), local_vars)
     self.assertNotEmpty(onnx_names)
コード例 #3
0
 def test_plot_anomaly_comparison(self):
     """Run the anomaly-comparison example without ONNX conversion.

     The dictionary returned by ``verify_script`` must expose the script's
     locals, and they must include ``algorithm``.
     """
     script = os.path.join(
         os.path.dirname(__file__), "data", "plot_anomaly_comparison.py")
     result = verify_script(script, try_onnx=False)
     self.assertIsInstance(result, dict)
     self.assertIn('algorithm', result['locals'])
コード例 #4
0
 def test_plot_iris_logistic(self):
     """Run the iris logistic-regression example with default options.

     ``verify_script`` is called without overriding ``try_onnx``; the
     script's locals must contain ``X``, ``logreg`` and ``logreg_onnx``.
     """
     script = os.path.join(
         os.path.dirname(__file__), "data", "plot_iris_logistic.py")
     result = verify_script(script)
     self.assertIsInstance(result, dict)
     local_vars = result['locals']
     for expected in ('X', 'logreg', 'logreg_onnx'):
         self.assertIn(expected, local_vars)
コード例 #5
0
    def test_plot_examples(self):
        """Run every scikit-learn ``plot_*.py`` example through ``verify_script``.

        The scikit-learn repository is cloned (only if its ``examples`` folder
        is not already present) into a temporary folder, and that tree is
        walked. Each selected example is executed with ``verify_script`` and
        classified as:

        * ``has_onnx`` -- at least one local name ends with ``_onnx``;
        * ``no_onnx`` -- the script ran but produced no ``_onnx`` name;
        * ``noconv`` -- ``NotFittedError`` or ``MissingShapeCalculator``;
        * ``issues`` -- another tolerated failure.

        A summary is written to ``stats.csv`` in a temporary folder, and the
        ``onx_info`` of every ``has_onnx`` example is pickled next to it.

        Raises:
            RuntimeError: if no example produced any ``_onnx`` local.
        """
        fLOG(__file__, self._testMethodName, OutputPrint=__name__ == "__main__")
        datas = ["git"]

        no_onnx = {}
        noconv = {}
        skipped = {}
        issues = {}
        has_onnx = {}
        # Folders actually examined, so the failure message names real paths
        # rather than the "git" placeholder.
        searched = []

        for data in datas:  # pylint: disable=R1702
            if data == "git":
                temp_git = get_temp_folder(
                    __file__, "temp_sklearn", clean=False)
                examples = os.path.join(temp_git, "examples")
                if not os.path.exists(examples):
                    fLOG('cloning scikit-learn...')
                    clone(temp_git, "github.com", "scikit-learn",
                          "scikit-learn", fLOG=fLOG)
                    fLOG('done.')
                data = examples
            fold = os.path.join(os.path.dirname(__file__), data)
            searched.append(fold)
            if not os.path.exists(fold):
                continue
            for ind_root, (root, _, files) in enumerate(os.walk(fold)):
                last = os.path.split(root)[-1]
                if last in TestLONGSklearnExample.skipped_folder:
                    continue
                for ind, nfile in enumerate(files):
                    # Global file index (1000 slots per walked folder) used to
                    # run only the [begin, end) slice; end < 0 means no bound.
                    full_ind = ind_root * 1000 + ind
                    if full_ind < TestLONGSklearnExample.begin:
                        skipped[nfile] = None
                        continue
                    if (full_ind >= TestLONGSklearnExample.end and
                            TestLONGSklearnExample.end >= 0):
                        skipped[nfile] = None
                        continue
                    # Normalise separators before matching the skip list.
                    nfile = nfile.replace("\\", "/")
                    if nfile in TestLONGSklearnExample.skipped_examples:
                        skipped[nfile] = None
                        continue
                    ext = os.path.splitext(nfile)[-1]
                    if ext != '.py':
                        continue
                    name = os.path.split(nfile)[-1]
                    if not name.startswith('plot_'):
                        continue
                    fLOG("verify {}/{}:{} - '{}'".format(
                        ind + 1, len(files), full_ind, nfile))
                    plot = os.path.join(root, nfile)

                    # Handler order matters: scikit-learn's NotFittedError
                    # derives from ValueError/AttributeError, so it must be
                    # caught before the broader handlers below.
                    try:
                        res = verify_script(
                            plot, existing_loc=TestLONGSklearnExample.existing_loc)
                    except NotFittedError as e:
                        fLOG('    model was not trained',
                             str(e).split('\n')[0])
                        noconv[nfile] = e
                        continue
                    except MissingShapeCalculator as e:
                        fLOG('    missing converter', str(e).split('\n')[0])
                        noconv[nfile] = e
                        continue
                    except MissingVariableError as e:
                        issues[nfile] = e
                        fLOG('    missing variable', str(e).split('\n')[0])
                        continue
                    except ValueError as e:
                        # Tolerated only for this scikit-learn release; record
                        # it like every other tolerated failure so it shows up
                        # in stats.csv.
                        if skl_version == "0.22.2.post1":
                            issues[nfile] = e
                            fLOG('    value error', str(e).split('\n')[0])
                            continue
                        raise
                    except (KeyError, NameError, RuntimeError, TypeError,
                            ImportError, AttributeError) as e:
                        issues[nfile] = e
                        fLOG('    local function', str(e).split('\n')[0])
                        continue

                    if res is not None:
                        if any(n.endswith('_onnx') for n in res['locals']):
                            fLOG('   ONNX ok')
                            has_onnx[nfile] = res['onx_info']
                        else:
                            no_onnx[nfile] = res['locals']
                            fLOG('   no onnx')
                    else:
                        fLOG('   issue')

        if not has_onnx:
            raise RuntimeError("Unable to find any example in\n{}".format(
                "\n".join(searched)))

        rows = []
        for kind, results in [('no_onnx', no_onnx),
                              ('noconv', noconv),
                              ('issues', issues),
                              ]:
            for k, v in results.items():
                # Flatten to one line so each example stays one CSV row.
                sv = str(v).replace("\r", "").replace("\n", " ")
                rows.append(dict(name=k, result=sv, kind=kind))
        for k in has_onnx:
            rows.append(dict(name=k, result='OK', kind='ONNX'))

        temp = get_temp_folder(__file__, 'temp_plot_examples')
        stats = os.path.join(temp, "stats.csv")
        pandas.DataFrame(rows).to_csv(stats, index=False)
        for k, v in has_onnx.items():
            pkl = os.path.join(temp, k + '.pkl')
            with open(pkl, 'wb') as f:
                pickle.dump(v, f)