def test_does_not_convert_converted_functions():
    """Running ``convert_to_jit`` on its own output must be a no-op."""
    def jit_func():  # pragma: no cover
        return 5

    once = convert_to_jit(jit_func)
    twice = convert_to_jit(once)
    assert twice is once
    assert once() == 5
def visit_Call(self, node):
    """Rewrite one ``ast.Call`` node so fastats overrides take effect.

    Child nodes are transformed first (via ``generic_visit``), so calls
    nested inside the arguments are handled before this one.

    Only calls to a bare name (``f(...)``) are candidates for rewriting:

    * If ``f`` is not in ``self._globals`` (e.g. ``range``), the node is
      returned unchanged.
    * If ``f`` is in ``self._params`` (a user-supplied override), the call
      is redirected to a new global name bound to the jitted override. The
      original global is recorded in ``self._replaced`` and is itself
      replaced by its jitted version. Positional AND keyword arguments of
      the call are preserved.
    * Otherwise, if the global is neither a numpy ufunc nor a builtin, it
      is recorded in ``self._replaced`` and processed recursively with
      ``AstProcessor``, so overrides reach nested calls. The global is then
      replaced by the jitted result.

    Any other call target is returned unchanged. This covers an attribute
    such as ``np.zeros_like``, a subscript such as ``funcs[0]``, or an
    expression such as ``make()`` or a lambda.

    Side effects: mutates ``self._globals`` and ``self._replaced``.

    Returns the (possibly new) call node.
    """
    node = self.generic_visit(node)

    if hasattr(node.func, 'attr'):
        # This will be hit where you have module
        # functions such as np.zeros_like.
        # We don't want fastats to modify these -
        # we see the module prefix as an indicator
        # that the author specifically wants to use
        # that function, so we just return early here.
        # To bypass this early return, import the
        # function and use without the module prefix, ie
        # from numpy import zeros_like
        # ...
        # a = zeros_like(x)
        return node

    if not isinstance(node.func, ast.Name):
        # Subscripts (``funcs[0](x)``), calls returning callables
        # (``make()(x)``), lambdas etc. have no ``.id``; there is no
        # global name to override, so leave them untouched instead of
        # crashing with AttributeError.
        return node

    name = node.func.id
    if name not in self._globals:
        # This will be hit for items not in the
        # function globals, such as `range`
        return node

    if name in self._params:
        new_name = self.new_name_from_call_name(name)
        new_func = self._params[name]
        self._replaced[name] = self._globals[name]
        self._globals[name] = convert_to_jit(self._globals[name])
        self._globals[new_name] = convert_to_jit(new_func)

        new_node = ast.Call(
            func=ast.Name(id=new_name, ctx=ast.Load()),
            args=node.args,
            # Keep keyword arguments; dropping them would silently change
            # the semantics of the redirected call.
            keywords=node.keywords,
        )
        ast.copy_location(new_node, node)
        ast.fix_missing_locations(new_node)
        return new_node

    # Lazy import because it's circular.
    from fastats.core.ast_transforms.processor import AstProcessor

    orig_inner_func = self._globals[name]
    not_ufunc = not isinstance(orig_inner_func, np.ufunc)
    not_builtin = not isbuiltin(orig_inner_func)
    if not_ufunc and not_builtin:
        self._replaced[name] = orig_inner_func
        proc = AstProcessor(
            orig_inner_func, self._params, self._replaced, self._new_funcs
        )
        new_inner_func = proc.process()
        self._globals[name] = convert_to_jit(new_inner_func)

    ast.fix_missing_locations(node)
    return node
def process(self):
    """Rewrite ``self.top_level_func`` with the configured overrides and jit it.

    Steps:
    1. Fetch the function's source and parse it to an AST. If
       ``ast.parse`` raises ``IndentationError`` (as it can for a standalone
       nested function), fall back to the slower ``uncompile`` /
       ``parse_snippet`` path.
    2. Run ``CallTransform`` over the tree, passing ``self._overrides``,
       the function's globals, ``self._replaced`` and ``self._new_funcs``.
    3. Force the top-level decorator list to ``[jit]``, recompile, and
       install the new code object on ``self.top_level_func`` in place.

    Side effects: stores ``jit`` in the function's module globals and
    replaces ``self.top_level_func.__code__``.

    Returns ``convert_to_jit`` applied to the rewritten function.
    """
    func = self.top_level_func
    src = inspect.getsource(func)

    try:
        tree = ast.parse(src)
    except IndentationError:
        # Standalone nested functions keep their indentation in the
        # retrieved source, so rebuild the tree from the code object.
        tree = parse_snippet(*uncompile(func.__code__))

    # Nested functions get `jit` added dynamically so that `nopython`
    # mode works, which means `jit` must always be resolvable from the
    # function's globals. This can go once numba supports nested
    # functions in nopython mode by default.
    func_globals = func.__globals__
    func_globals['jit'] = jit

    transformer = CallTransform(
        self._overrides, func_globals, self._replaced, self._new_funcs
    )
    rewritten = transformer.visit(tree)

    # TODO remove the fs decorator from within the ast code
    rewritten.body[0].decorator_list = [ast.Name(id='jit', ctx=ast.Load())]
    ast.fix_missing_locations(rewritten)

    if self._debug:
        pprint(ast.dump(rewritten))

    func.__code__ = recompile(rewritten, '<fastats>', 'exec')
    return convert_to_jit(func)
def test_converts_simple_function():
    """A plain Python function converts into a working jitted callable."""
    def add(a, b):  # pragma: no cover
        return a + b

    assert convert_to_jit(add)(1, 2) == 3
def test_does_not_convert_jitted_functions():
    """A function already decorated with ``@jit`` comes back unchanged."""
    @jit
    def jit_func():  # pragma: no cover
        return 5

    result = convert_to_jit(jit_func)
    assert result is jit_func
    assert result() == 5
def fs_wrapper(*args, **kwargs):
    """Call, or build, the fastats-transformed version of ``_func``.

    Keyword arguments are overrides. Each function-valued kwarg is first
    processed with ``AstProcessor`` and jitted; ``_func`` is then processed
    with those overrides, jitted, and either returned or called.

    The special kwarg ``return_callable`` is popped before processing. When
    truthy, the jitted callable is returned instead of being invoked with
    ``*args``. This also holds when no other kwargs are given.

    Kwarg values exposing ``undecorated`` (fs-decorated functions) are
    replaced by their undecorated form. Values without a ``__name__``
    (plain numbers, flags) are passed through unchanged rather than
    raising ``AttributeError``.
    """
    return_callable = kwargs.pop('return_callable', None)

    # This deliberately mutates the kwargs.
    # We don't want to have a fs-decorated function
    # as a kwarg to another, so we undecorate it first.
    for k, v in kwargs.items():
        if hasattr(v, 'undecorated'):
            kwargs[k] = v.undecorated

    if not kwargs:
        # No overrides means there is nothing to rewrite. Still honour
        # `return_callable` rather than silently calling the function.
        if return_callable:
            return convert_to_jit(_func)
        return _func(*args)

    # TODO : remove fastats keywords such as 'debug'
    # before passing into AstProcessor
    new_funcs = {}
    for v in kwargs.values():
        if isfunction(v) and v.__name__ not in kwargs:
            inner_replaced = {}
            processor = AstProcessor(v, kwargs, inner_replaced, new_funcs)
            new_funcs[v.__name__] = convert_to_jit(processor.process())

    # Swap each override for its processed version. Non-function values
    # have no `__name__`; getattr keeps them from raising here.
    new_kwargs = {}
    for k, v in kwargs.items():
        processed = new_funcs.get(getattr(v, '__name__', None))
        if processed:
            new_kwargs[k] = processed
    kwargs.update(new_kwargs)

    processor = AstProcessor(_func, kwargs, replaced, new_funcs)
    jitted = convert_to_jit(processor.process())
    if return_callable:
        return jitted
    return jitted(*args)
def test_raises_for_non_func():
    """Objects that are not plain functions are rejected with TypeError."""
    not_functions = (
        'And Now for Something Completely Different',
        {'answer': 42},
        partial(sum),  # callable, but not a function
    )
    for candidate in not_functions:
        with raises(TypeError):
            convert_to_jit(candidate)
from unittest import TestCase import numpy as np from numpy.testing import assert_allclose from pytest import mark from fastats.core.ast_transforms.convert_to_jit import convert_to_jit from fastats.linear_algebra import qr, qr_classical_gram_schmidt from fastats.scaling.scaling import standard from tests.data.datasets import SKLearnDataSets qr_jit = convert_to_jit(qr) qr_classical_gram_schmidt_jit = convert_to_jit(qr_classical_gram_schmidt) class QRTestMixin: @staticmethod def assert_orthonormal(Q): n = Q.shape[1] assert_allclose(Q.T @ Q, np.eye(n), atol=1e-10) @staticmethod def check_versus_expectations(Q, Q_expected, R, R_expected, A): assert_allclose(Q, Q_expected) assert_allclose(R, R_expected) assert_allclose(Q @ R, A) def test_ucla_4x3(self): """
import numpy as np import statsmodels.api as sm from pytest import approx, raises from sklearn import datasets from fastats.linear_algebra import ( ols, ols_cholesky, ols_qr, ols_svd, add_intercept, adjusted_r_squared, adjusted_r_squared_no_intercept, fitted_values, mean_standard_error_residuals, r_squared, r_squared_no_intercept, residuals, standard_error, sum_of_squared_residuals, t_statistic, f_statistic, f_statistic_no_intercept, drop_missing) from fastats.core.ast_transforms.convert_to_jit import convert_to_jit drop_missing_jit = convert_to_jit(drop_missing) class BaseOLS(TestCase): def setUp(self): self._data = datasets.load_diabetes() self._labels = [ 'age', 'sex', 'bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6' ] class SklearnDiabetesOLS: """ Linear Regression example taken from the fast.ai course 'Numerical Linear Algebra' """
size = len(A) for i in range(size): for k in range(size): total = nsum(L[i, 0:i] * U[0:i, k]) U[i, k] = A[i, k] - total for k in range(size): if i == k: L[i, i] = 1.0 else: total = nsum(L[k, 0:i] * U[0:i, i]) L[k, i] = (A[k, i] - total) / U[i, i] lu_inplace_jit = convert_to_jit(lu_inplace) def lu(A): """ This performs LU Decomposition on `A`. This takes a square matrix `A`. This scales as O(n^3). This allocates `L` and `U` on each call. Example ------- This is the example from wikipedia:
from unittest import TestCase import numpy as np from numpy.testing import assert_allclose from pytest import approx from fastats.linear_algebra import lu, lu_inplace, lu_compact from fastats.core.ast_transforms.convert_to_jit import convert_to_jit lu_jit = convert_to_jit(lu) lu_compact_jit = convert_to_jit(lu_compact) class LUDecompValidator: """ This is a mixin class which tests both the raw Python and the jit-compiled version of the `lu()` function. """ A, L, U = None, None, None def setUp(self): self._A = np.array(self.A) def test_lu_outputs_numpy(self): L = np.zeros_like(self._A) U = np.zeros_like(self._A) lu_inplace(self._A, L, U) assert L.tolist() == self.L
def test_does_not_convert_math_builtins():
    """Builtin ``math`` functions come back from ``convert_to_jit`` untouched."""
    math_builtins = (
        math.atan2, math.atanh, math.degrees, math.exp, math.floor,
        math.log, math.sin, math.sinh, math.tan, math.tanh,
    )
    assert all(convert_to_jit(fn) is fn for fn in math_builtins)