Esempio n. 1
0
    def _build_onnx_runtime_numpy_compile(self, opsets):
        """
        Second part of @see me _build_onnx_runtime_numpy.

        Compiles and executes the script stored in ``self.numpy_code_``,
        retrieves the single function it defines whose name starts with
        ``numpy_``, optionally compiles it with numba (when
        ``self.runtime == 'numba'``) and wraps it into a
        ``FunctionTransformer``.

        :param opsets: mapping read with key ``''`` to obtain the opset
            version (presumably the default ONNX domain — confirm with
            the caller); ``None`` is treated as an empty mapping
        :return: ``FunctionTransformer(fct, accept_sparse=True)`` whose
            attribute ``op_version`` is ``opsets['']`` or
            ``__max_supported_opset__`` when that key is missing
        :raises AssertionError: if the script does not compile or raises
            an exception while being executed; the message includes the
            code listing and anything the script printed
        :raises RuntimeError: if the script does not define exactly one
            name starting with ``numpy_``
        """
        try:
            compiled_code = compile(self.numpy_code_, '<string>', 'exec')
        except SyntaxError as e:  # pragma: no cover
            raise AssertionError("Unable to compile a script due to %r. "
                                 "\n--CODE--\n%s"
                                 "" % (e, print_code(self.numpy_code_))) from e

        # exec() receives distinct globals and locals: the script's
        # top-level definitions (including the expected ``numpy_*``
        # function) land in ``loc``, while functions it defines resolve
        # their free names through ``glo``. ``glo`` is a copy so the script
        # cannot rebind this module's globals.
        glo = globals().copy()
        loc = {
            'numpy': numpy,
            'dict': dict,
            'list': list,
            'print': print,
            'sorted': sorted,
            'collections': collections,
            'inspect': inspect,
            'helper': helper,
            'scipy_special': scipy_special,
            'scipy_distance': scipy_distance,
            'array_feature_extrator': array_feature_extrator,
            'argmin_use_numpy_select_last_index':
            argmin_use_numpy_select_last_index,
            'argmax_use_numpy_select_last_index':
            argmax_use_numpy_select_last_index,
            'make_slice': make_slice
        }
        # Anything the script writes to stdout/stderr is captured so it can
        # be reported if execution fails.
        out = io.StringIO()
        err = io.StringIO()
        with redirect_stdout(out), redirect_stderr(err):
            try:
                exec(compiled_code, glo, loc)  # pylint: disable=W0122
            except Exception as e:  # pragma: no cover
                raise AssertionError(
                    "Unable to execute a script due to %r. "
                    "\n--OUT--\n%s\n--ERR--\n%s\n--CODE--\n%s"
                    "" % (e, out.getvalue(), err.getvalue(),
                          print_code(self.numpy_code_))) from e

        # None of the pre-seeded helper names starts with 'numpy_', so only
        # names defined by the script itself can match.
        names = [k for k in loc if k.startswith('numpy_')]
        if len(names) != 1:
            raise RuntimeError(  # pragma: no cover
                "Unable to guess which function is the one, names=%r."
                "" % sorted(names))
        fct = loc[names[0]]
        if self.runtime == 'numba':
            # Imported lazily: numba is only required for this runtime.
            from numba import jit
            fct = jit(nopython=self.nopython)(fct)
        cl = FunctionTransformer(fct, accept_sparse=True)
        if opsets is None:
            opsets = {}
        cl.op_version = opsets.get('', __max_supported_opset__)
        return cl