def _jit(func):
    """Recursively compile ``func`` and the user callables it references with numba.

    Functions accepted by ``func_in_numpy_or_math`` and objects that are
    already numba ``Dispatcher`` instances are returned unchanged. Otherwise
    every callable among ``func``'s nonlocal and global references (as
    reported by ``inspect.getclosurevars``) that is neither a numpy/math
    function nor a ``Dispatcher`` is itself passed through ``_jit``. If any
    reference was replaced, ``func`` is re-created from its source inside the
    updated namespace so that it binds to the compiled callees.

    NOTE(review): a function that refers to itself through a global name
    (direct recursion) is found again in its own globals, so this would
    recurse without bound. Also, ``inspect.getclosurevars`` raises
    ``TypeError`` for callables that are not Python functions/methods (e.g.
    classes) — confirm such objects never appear in the referenced globals.

    Parameters
    ----------
    func : callable
        The Python function to compile.

    Returns
    -------
    callable
        ``func`` unchanged (numpy/math function or ``Dispatcher``), or
        ``numba.njit`` applied to the (possibly re-created) function.
    """
    if func_in_numpy_or_math(func) or isinstance(func, Dispatcher):
        return func

    # Named `closure_vars` so the builtin `vars` is not shadowed.
    closure_vars = inspect.getclosurevars(func)
    code_scope = dict(closure_vars.nonlocals)
    code_scope.update(closure_vars.globals)

    # Replace each referenced user callable by its compiled version.
    modified = False
    for name, value in list(code_scope.items()):
        if (callable(value)
                and not func_in_numpy_or_math(value)
                and not isinstance(value, Dispatcher)):
            code_scope[name] = _jit(value)
            modified = True

    if modified:
        # Re-execute the source of `func` in the updated namespace so that its
        # global lookups resolve to the compiled callees.
        func_code = tools.deindent(tools.get_func_source(func))
        exec(compile(func_code, '', "exec"), code_scope)
        func = code_scope[func.__name__]
    return numba.njit(func)
def _add_try_except(code):
    """Wrap the body of one function definition in ``try/except NumbaError``.

    The source is split at every ``)`` followed (after optional whitespace)
    by ``:``. The first piece is taken as the signature, and the remaining
    pieces are re-joined with ``'):'`` as the body. Blank body lines are
    dropped. The body is de-indented and re-indented under a ``try:`` clause.
    The ``except`` handler prints the generated source.

    NOTE(review): a return annotation (``def f(x) -> T:``) does not match the
    signature pattern, so such definitions are not handled correctly.

    Parameters
    ----------
    code : str
        Source code of a single function definition.

    Returns
    -------
    tuple of (str, dict)
        The rewritten source, and the names it needs in its namespace:
        ``NumbaError`` (``numba.errors.NumbaError``) and ``_code_`` (the
        rewritten source itself).

    Raises
    ------
    ValueError
        If no ``)`` followed by ``:`` is found in ``code``.
    """
    splits = re.split(r'\)\s*?:', code)
    if len(splits) == 1:
        raise ValueError(f"Cannot analyze code:\n{code}")
    def_line = splits[0] + '):'
    body = '):'.join(splits[1:])
    body_lines = [line for line in body.split('\n') if line.strip()]
    main_code = tools.deindent("\n".join(body_lines))

    # The handler statement must be indented deeper than `except`; otherwise
    # the generated source fails to compile with an IndentationError.
    # `tools.indent(..., num_tabs=2, spaces_per_tab=2)` presumably yields
    # 4 spaces — TODO confirm.
    new_code = '\n'.join([
        def_line,
        '  try:',
        tools.indent(main_code, num_tabs=2, spaces_per_tab=2),
        '  except NumbaError:',
        '    print(_code_)',
    ])
    return new_code, {'NumbaError': numba.errors.NumbaError, '_code_': new_code}
def _jit_Function(func, show_code=False, **jit_setting):
    """JIT-compile a ``Function`` wrapper together with the nodes it accesses.

    The source of ``func._f`` is rewritten by ``_analyze_cls_func`` once for
    every entry of ``func._nodes``. Each node is used as the host, its key is
    passed as ``self_name``, and ``pop_self=True`` is set. The rewritten
    source and the updated namespace from one node feed into the next.
    Finally the source is executed and the resulting function is passed to
    ``numba.jit``.

    Parameters
    ----------
    func : Function
        The wrapper to compile. ``func._f`` is the underlying Python function
        and ``func._nodes`` maps names to the nodes it uses.
    show_code : bool
        Whether to display the rewritten source and its namespace.
    **jit_setting
        Forwarded to ``_analyze_cls_func`` and to ``numba.jit``.

    Returns
    -------
    dict
        ``func`` (result of ``numba.jit``), ``arguments`` (set),
        ``arg2call`` (dict) and ``nodes`` (dict keyed by node name).
    """
    assert isinstance(func, Function)
    raw_func = func._f

    # Namespace in which the rewritten source will be executed.
    closure = inspect.getclosurevars(raw_func)
    scope = {**closure.nonlocals, **closure.globals}

    source = tools.deindent(inspect.getsource(raw_func)).strip()

    all_arguments = set()
    all_arg2call = {}
    all_nodes = {node.name: node for node in func._nodes.values()}

    # Each node rewrites the (already partly rewritten) source in turn.
    for self_name, node in func._nodes.items():
        source, new_args, new_arg2call, new_nodes, scope = _analyze_cls_func(
            host=node, code=source, show_code=show_code, code_scope=scope,
            self_name=self_name, pop_self=True, **jit_setting)
        all_arguments.update(new_args)
        all_arg2call.update(new_arg2call)
        all_nodes.update(new_nodes)

    if show_code:
        _show_compiled_codes(source, scope)
    exec(compile(source, '', 'exec'), scope)
    compiled = numba.jit(scope[raw_func.__name__], **jit_setting)

    return dict(func=compiled,
                arguments=all_arguments,
                arg2call=all_arg2call,
                nodes=all_nodes)
def separate_variables(func_or_code):
    r"""Separate the expressions in a differential equation for each variable.

    For each returned name, the assignments of the function are scanned from
    the last one back to the first. An assignment is kept for that return
    value when it defines a name the return value still depends on. That
    name is then removed from the dependencies, and the identifiers of the
    assignment's right-hand side are added.

    Take the gating variables of the HH neuron model as an example:

    >>> eq_code = '''
    ... def integral(m, h, t, Iext, V):
    ...     alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))
    ...     beta = 4.0 * np.exp(-(V + 65) / 18)
    ...     dmdt = alpha * (1 - m) - beta * m
    ...
    ...     alpha = 0.07 * np.exp(-(V + 65) / 20.)
    ...     beta = 1 / (1 + np.exp(-(V + 35) / 10))
    ...     dhdt = alpha * (1 - h) - beta * h
    ...     return dmdt, dhdt
    ... '''
    >>> analysis = separate_variables(eq_code)

    ``analysis['code_lines_for_returns']`` then holds, in source order::

        'dmdt': ['alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))\n',
                 'beta = 4.0 * np.exp(-(V + 65) / 18)\n',
                 'dmdt = alpha * (1 - m) - beta * m\n']
        'dhdt': ['alpha = 0.07 * np.exp(-(V + 65) / 20.0)\n',
                 'beta = 1 / (1 + np.exp(-(V + 35) / 10))\n',
                 'dhdt = alpha * (1 - h) - beta * h\n']

    Parameters
    ----------
    func_or_code : callable, str
        The callable function or the function code.

    Returns
    -------
    analysis : tools.DictPlus
        Has the keys ``code_lines_for_returns``, ``variables_for_returns``
        and ``expressions_for_returns``. Each maps every returned name to
        the code lines, assigned-variable lists, or right-hand-side
        expressions it depends on, respectively. All are in source order.
    """
    if callable(func_or_code):
        func_or_code = tools.deindent(inspect.getsource(func_or_code))
    assert isinstance(func_or_code, str)

    analyser = DiffEqReader()
    analyser.visit(ast.parse(func_or_code))
    returns = analyser.returns
    variables = analyser.variables
    right_exprs = analyser.rights
    code_lines = analyser.code_lines

    return_requires = OrderedDict([(r, set(tools.get_identifiers(r))) for r in returns])
    code_lines_for_returns = OrderedDict([(r, []) for r in returns])
    variables_for_returns = OrderedDict([(r, []) for r in returns])
    expressions_for_returns = OrderedDict([(r, []) for r in returns])

    # Negative indices -1, -2, ..., -len(variables): last assignment first.
    length = len(variables)
    reverse_ids = list(reversed([i - length for i in range(length)]))
    for r in code_lines_for_returns.keys():
        for rid in reverse_ids:
            dep = []
            for v in variables[rid]:
                if v in return_requires[r]:
                    dep.append(v)
            if len(dep):
                code_lines_for_returns[r].append(code_lines[rid])
                variables_for_returns[r].append(variables[rid])
                expr = right_exprs[rid]
                expressions_for_returns[r].append(expr)
                # The name is now defined; what it needs becomes required.
                for d in dep:
                    return_requires[r].remove(d)
                return_requires[r].update(tools.get_identifiers(expr))

    # The lists were collected backwards; restore source order.
    for r in list(code_lines_for_returns.keys()):
        code_lines_for_returns[r] = code_lines_for_returns[r][::-1]
        variables_for_returns[r] = variables_for_returns[r][::-1]
        expressions_for_returns[r] = expressions_for_returns[r][::-1]

    analysis = tools.DictPlus(
        code_lines_for_returns=code_lines_for_returns,
        variables_for_returns=variables_for_returns,
        expressions_for_returns=expressions_for_returns,
    )
    return analysis
def _jit_cls_func(f, code=None, host=None, show_code=False, **jit_setting):
    """JIT-compile a method of a class instance.

    The method source is rewritten by ``_analyze_cls_func`` so that it no
    longer refers to ``self``. Static attributes of the host become names in
    the function namespace (``self.gNa`` -> ``HH0_gNa`` below). Dynamical
    variables become extra keyword arguments (``HH0_input`` below). The
    rewritten source is executed in a copy of its namespace, and the
    resulting function is passed to ``numba.jit``.

    Examples
    --------
    Example 1: the model has static parameters.

    >>> import brainpy as bp
    >>>
    >>> class HH(bp.NeuGroup):
    >>>   def __init__(self, size, ENa=50., EK=-77., EL=-54.387, C=1.0,
    >>>                gNa=120., gK=36., gL=0.03, V_th=20., **kwargs):
    >>>     super(HH, self).__init__(size=size, **kwargs)
    >>>     # parameters
    >>>     self.ENa = ENa
    >>>     self.EK = EK
    >>>     self.EL = EL
    >>>     self.C = C
    >>>     self.gNa = gNa
    >>>     self.gK = gK
    >>>     self.gL = gL
    >>>     self.V_th = V_th
    >>>
    >>>   def derivative(self, V, m, h, n, t, Iext):
    >>>     alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))
    >>>     beta = 4.0 * np.exp(-(V + 65) / 18)
    >>>     dmdt = alpha * (1 - m) - beta * m
    >>>
    >>>     alpha = 0.07 * np.exp(-(V + 65) / 20.)
    >>>     beta = 1 / (1 + np.exp(-(V + 35) / 10))
    >>>     dhdt = alpha * (1 - h) - beta * h
    >>>
    >>>     alpha = 0.01 * (V + 55) / (1 - np.exp(-(V + 55) / 10))
    >>>     beta = 0.125 * np.exp(-(V + 65) / 80)
    >>>     dndt = alpha * (1 - n) - beta * n
    >>>
    >>>     I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
    >>>     I_K = (self.gK * n ** 4.0) * (V - self.EK)
    >>>     I_leak = self.gL * (V - self.EL)
    >>>     dVdt = (- I_Na - I_K - I_leak + Iext) / self.C
    >>>
    >>>     return dVdt, dmdt, dhdt, dndt
    >>>
    >>> r = _jit_cls_func(HH(10).derivative, show_code=True)
    The recompiled function:
    -------------------------
    def derivative(V, m, h, n, t, Iext):
        alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))
        beta = 4.0 * np.exp(-(V + 65) / 18)
        dmdt = alpha * (1 - m) - beta * m
        alpha = 0.07 * np.exp(-(V + 65) / 20.0)
        beta = 1 / (1 + np.exp(-(V + 35) / 10))
        dhdt = alpha * (1 - h) - beta * h
        alpha = 0.01 * (V + 55) / (1 - np.exp(-(V + 55) / 10))
        beta = 0.125 * np.exp(-(V + 65) / 80)
        dndt = alpha * (1 - n) - beta * n
        I_Na = HH0_gNa * m ** 3.0 * h * (V - HH0_ENa)
        I_K = HH0_gK * n ** 4.0 * (V - HH0_EK)
        I_leak = HH0_gL * (V - HH0_EL)
        dVdt = (-I_Na - I_K - I_leak + Iext) / HH0_C
        return dVdt, dmdt, dhdt, dndt
    The namespace of the above function:
    {'HH0_C': 1.0,
     'HH0_EK': -77.0,
     'HH0_EL': -54.387,
     'HH0_ENa': 50.0,
     'HH0_gK': 36.0,
     'HH0_gL': 0.03,
     'HH0_gNa': 120.0,
     'bp': <module 'brainpy' from '...'>}
    >>> r['func']
    CPUDispatcher(<function derivative at 0x...>)
    >>> r['arguments']
    set()
    >>> r['arg2call']
    {}
    >>> r['nodes']
    {'HH0': <__main__.<locals>.HH object at 0x...>}

    Example 2: the model has dynamical variables. This is the same model as
    Example 1, with a ``Variable`` attribute used in place of ``Iext``.

    >>> import brainpy as bp
    >>>
    >>> class HH(bp.NeuGroup):
    >>>   def __init__(self, size, ENa=50., EK=-77., EL=-54.387, C=1.0,
    >>>                gNa=120., gK=36., gL=0.03, V_th=20., **kwargs):
    >>>     super(HH, self).__init__(size=size, **kwargs)
    >>>     # parameters
    >>>     self.ENa = ENa
    >>>     self.EK = EK
    >>>     self.EL = EL
    >>>     self.C = C
    >>>     self.gNa = gNa
    >>>     self.gK = gK
    >>>     self.gL = gL
    >>>     self.V_th = V_th
    >>>     self.input = bp.math.numpy.Variable(np.zeros(size))
    >>>
    >>>   def derivative(self, V, m, h, n, t):
    >>>     alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))
    >>>     beta = 4.0 * np.exp(-(V + 65) / 18)
    >>>     dmdt = alpha * (1 - m) - beta * m
    >>>
    >>>     alpha = 0.07 * np.exp(-(V + 65) / 20.)
    >>>     beta = 1 / (1 + np.exp(-(V + 35) / 10))
    >>>     dhdt = alpha * (1 - h) - beta * h
    >>>
    >>>     alpha = 0.01 * (V + 55) / (1 - np.exp(-(V + 55) / 10))
    >>>     beta = 0.125 * np.exp(-(V + 65) / 80)
    >>>     dndt = alpha * (1 - n) - beta * n
    >>>
    >>>     I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
    >>>     I_K = (self.gK * n ** 4.0) * (V - self.EK)
    >>>     I_leak = self.gL * (V - self.EL)
    >>>     dVdt = (- I_Na - I_K - I_leak + self.input) / self.C
    >>>
    >>>     return dVdt, dmdt, dhdt, dndt
    >>>
    >>> r = _jit_cls_func(HH(10).derivative, show_code=True)
    The recompiled function:
    -------------------------
    def derivative(V, m, h, n, t, HH0_input=None):
        alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))
        beta = 4.0 * np.exp(-(V + 65) / 18)
        dmdt = alpha * (1 - m) - beta * m
        alpha = 0.07 * np.exp(-(V + 65) / 20.0)
        beta = 1 / (1 + np.exp(-(V + 35) / 10))
        dhdt = alpha * (1 - h) - beta * h
        alpha = 0.01 * (V + 55) / (1 - np.exp(-(V + 55) / 10))
        beta = 0.125 * np.exp(-(V + 65) / 80)
        dndt = alpha * (1 - n) - beta * n
        I_Na = HH0_gNa * m ** 3.0 * h * (V - HH0_ENa)
        I_K = HH0_gK * n ** 4.0 * (V - HH0_EK)
        I_leak = HH0_gL * (V - HH0_EL)
        dVdt = (-I_Na - I_K - I_leak + HH0_input) / HH0_C
        return dVdt, dmdt, dhdt, dndt
    The namespace of the above function:
    {'HH0_C': 1.0,
     'HH0_EK': -77.0,
     'HH0_EL': -54.387,
     'HH0_ENa': 50.0,
     'HH0_gK': 36.0,
     'HH0_gL': 0.03,
     'HH0_gNa': 120.0,
     'bp': <module 'brainpy' from '...'>}
    >>> r['func']
    CPUDispatcher(<function derivative at 0x...>)
    >>> r['arguments']
    {'HH0_input'}
    >>> r['arg2call']
    {'HH0_input': 'HH0.input.value'}
    >>> r['nodes']
    {'HH0': <__main__.<locals>.HH object at 0x...>}

    Parameters
    ----------
    f : callable
        The method to compile. ``f.__name__`` selects the function that is
        fetched from the namespace after the rewritten source is executed.
    code : str, optional
        Source to analyze instead of ``inspect.getsource(f)``. Any falsy
        value falls back to the (de-indented, stripped) source of ``f``.
    host : object, optional
        The instance that owns the method; defaults to ``f.__self__``. It is
        stored in the returned ``nodes`` under ``host.name``.
    show_code : bool
        If True, display the rewritten source and its namespace.
    **jit_setting
        Keyword arguments forwarded to ``_analyze_cls_func`` and ``numba.jit``.

    Returns
    -------
    dict
        ``func``
            The object returned by ``numba.jit``.
        ``code``
            The rewritten source.
        ``code_scope``
            The namespace of the rewritten source. It does not include the
            function produced by executing that source.
        ``arguments``
            The set of extra argument names added to the signature.
        ``arg2call``
            Maps each extra argument to the attribute path that supplies it,
            e.g. ``'HH0.input.value'``.
        ``nodes``
            A ``Collector`` of the nodes involved, including ``host``.
    """
    host = host or f.__self__

    nodes = Collector()
    nodes[host.name] = host
    arguments, arg2call = set(), {}

    if not code:
        code = tools.deindent(inspect.getsource(f)).strip()
    func_name = f.__name__

    closure = inspect.getclosurevars(f)
    code_scope = {**closure.nonlocals, **closure.globals}

    code, new_args, new_arg2call, new_nodes, extra_scope = _analyze_cls_func(
        host=host, code=code, show_code=show_code, **jit_setting)
    arguments.update(new_args)
    arg2call.update(new_arg2call)
    nodes.update(new_nodes)
    code_scope.update(extra_scope)

    # Execute in a copy so that the returned `code_scope` is not polluted by
    # `exec` (which inserts `__builtins__` and the newly defined function).
    exec_scope = dict(code_scope)
    if show_code:
        _show_compiled_codes(code, code_scope)
    exec(compile(code, '', 'exec'), exec_scope)
    jitted = numba.jit(exec_scope[func_name], **jit_setting)

    return dict(func=jitted,
                code=code,
                code_scope=code_scope,
                arguments=arguments,
                arg2call=arg2call,
                nodes=nodes)
def separate_variables(func_or_code):
    r"""Separate the expressions in a differential equation for each variable.

    For every returned name, the assignments are walked from the last one back
    to the first. An assignment is kept when it defines a name that the return
    value still (transitively) depends on. That name is then removed from the
    dependencies, and the identifiers of the right-hand side are added.

    Take the HH neuron model as an example:

    >>> eq_code = '''
    ... def derivative(V, m, h, n, t, C, gNa, ENa, gK, EK, gL, EL, Iext):
    ...     alpha = 0.1 * (V + 40) / (1 - bp.math.exp(-(V + 40) / 10))
    ...     beta = 4.0 * bp.math.exp(-(V + 65) / 18)
    ...     dmdt = alpha * (1 - m) - beta * m
    ...
    ...     alpha = 0.07 * bp.math.exp(-(V + 65) / 20.)
    ...     beta = 1 / (1 + bp.math.exp(-(V + 35) / 10))
    ...     dhdt = alpha * (1 - h) - beta * h
    ...
    ...     alpha = 0.01 * (V + 55) / (1 - bp.math.exp(-(V + 55) / 10))
    ...     beta = 0.125 * bp.math.exp(-(V + 65) / 80)
    ...     dndt = alpha * (1 - n) - beta * n
    ...
    ...     I_Na = (gNa * m ** 3.0 * h) * (V - ENa)
    ...     I_K = (gK * n ** 4.0) * (V - EK)
    ...     I_leak = gL * (V - EL)
    ...     dVdt = (- I_Na - I_K - I_leak + Iext) / C
    ...
    ...     return dVdt, dmdt, dhdt, dndt
    ... '''
    >>> separate_variables(eq_code)
    {'code_lines_for_returns':
        {'dVdt': ['I_Na = gNa * m ** 3.0 * h * (V - ENa)\n',
                  'I_K = gK * n ** 4.0 * (V - EK)\n',
                  'I_leak = gL * (V - EL)\n',
                  'dVdt = (-I_Na - I_K - I_leak + Iext) / C\n'],
         'dhdt': ['alpha = 0.07 * bp.math.exp(-(V + 65) / 20.0)\n',
                  'beta = 1 / (1 + bp.math.exp(-(V + 35) / 10))\n',
                  'dhdt = alpha * (1 - h) - beta * h\n'],
         'dmdt': ['alpha = 0.1 * (V + 40) / (1 - bp.math.exp(-(V + 40) / 10))\n',
                  'beta = 4.0 * bp.math.exp(-(V + 65) / 18)\n',
                  'dmdt = alpha * (1 - m) - beta * m\n'],
         'dndt': ['alpha = 0.01 * (V + 55) / (1 - bp.math.exp(-(V + 55) / 10))\n',
                  'beta = 0.125 * bp.math.exp(-(V + 65) / 80)\n',
                  'dndt = alpha * (1 - n) - beta * n\n']},
     'expressions_for_returns':
        {'dVdt': ['gNa * m ** 3.0 * h * (V - ENa)',
                  'gK * n ** 4.0 * (V - EK)',
                  'gL * (V - EL)',
                  '(-I_Na - I_K - I_leak + Iext) / C'],
         'dhdt': ['0.07 * bp.math.exp(-(V + 65) / 20.0)',
                  '1 / (1 + bp.math.exp(-(V + 35) / 10))',
                  'alpha * (1 - h) - beta * h'],
         'dmdt': ['0.1 * (V + 40) / (1 - bp.math.exp(-(V + 40) / 10))',
                  '4.0 * bp.math.exp(-(V + 65) / 18)',
                  'alpha * (1 - m) - beta * m'],
         'dndt': ['0.01 * (V + 55) / (1 - bp.math.exp(-(V + 55) / 10))',
                  '0.125 * bp.math.exp(-(V + 65) / 80)',
                  'alpha * (1 - n) - beta * n']},
     'variables_for_returns':
        {'dVdt': [['I_Na'], ['I_K'], ['I_leak'], ['dVdt']],
         'dhdt': [['alpha'], ['beta'], ['dhdt']],
         'dmdt': [['alpha'], ['beta'], ['dmdt']],
         'dndt': [['alpha'], ['beta'], ['dndt']]}}

    Parameters
    ----------
    func_or_code : callable, str
        The callable function or the function code. Lambda functions are
        rejected.

    Returns
    -------
    analysis : tools.DictPlus
        The code lines, assigned variables and right-hand-side expressions
        needed by each return variable, in source order.

    Raises
    ------
    errors.AnalyzerError
        If ``func_or_code`` is a lambda function.
    """
    if callable(func_or_code):
        if tools.is_lambda_function(func_or_code):
            raise errors.AnalyzerError(
                f'Cannot analyze lambda function: {func_or_code}.')
        func_or_code = tools.deindent(inspect.getsource(func_or_code))
    assert isinstance(func_or_code, str)

    reader = DiffEqReader()
    reader.visit(ast.parse(func_or_code))
    returns = reader.returns
    assigned = reader.variables
    rights = reader.rights
    lines = reader.code_lines

    required = OrderedDict((ret, set(tools.get_identifiers(ret))) for ret in returns)
    lines_of = OrderedDict((ret, []) for ret in returns)
    vars_of = OrderedDict((ret, []) for ret in returns)
    exprs_of = OrderedDict((ret, []) for ret in returns)

    n_assign = len(assigned)
    for ret in lines_of:
        needed = required[ret]
        # Walk the assignments backwards: -1, -2, ..., -n_assign.
        for idx in range(-1, -n_assign - 1, -1):
            hits = [name for name in assigned[idx] if name in needed]
            if not hits:
                continue
            expr = rights[idx]
            lines_of[ret].append(lines[idx])
            vars_of[ret].append(assigned[idx])
            exprs_of[ret].append(expr)
            for name in hits:
                needed.remove(name)
            needed.update(tools.get_identifiers(expr))

    # Items were gathered from last to first; restore source order.
    for ret in lines_of:
        lines_of[ret].reverse()
        vars_of[ret].reverse()
        exprs_of[ret].reverse()

    return tools.DictPlus(
        code_lines_for_returns=lines_of,
        variables_for_returns=vars_of,
        expressions_for_returns=exprs_of,
    )
def analyze_step_func(host, f):
    """Analyze the step function of a population.

    The source of ``f`` is parsed and visited by ``StepFuncReader``. When the
    first parameter of ``f`` is one of ``backend.CLASS_KEYWORDS``, the
    ``<first_param>.attr`` references are collected with regular expressions
    in three groups:

    - those on the right-hand side of statements;
    - those on the left-hand side;
    - left-hand-side references that are followed by ``[...]``, for which
      only the attribute path before the bracket is recorded.

    Parameters
    ----------
    host : Population
        The data and the function host.
    f : callable
        The step function.

    Returns
    -------
    results : dict
        ``delay_call`` and ``code_string`` (from the reader), ``code_scope``
        (the nonlocal and global names referenced by ``f``), and the sorted,
        de-duplicated lists ``self_data_in_right``,
        ``self_data_without_index_in_left`` and
        ``self_data_with_index_in_left``.
    """
    code_string = tools.deindent(inspect.getsource(f)).strip()
    tree = ast.parse(code_string)

    # Parameter list of the function, split on commas.
    arg_names = tools.ast2code(
        ast.fix_missing_locations(tree.body[0].args)).split(',')
    first_arg = arg_names[0]

    formatter = StepFuncReader(host=host)
    formatter.visit(tree)

    is_method = first_arg in backend.CLASS_KEYWORDS
    # `self.attr`, `self.a.b`, ... references.
    attr_pattern = r'\b' + first_arg + r'\.[A-Za-z_][A-Za-z0-9_.]*\b'
    # `self.attr[...]`: capture only the attribute path before the index.
    indexed_pattern = r'(\b' + first_arg + r'\.[A-Za-z_][A-Za-z0-9_.]*)\[.*\]'

    if is_method:
        right_refs = re.findall(attr_pattern, ', \n'.join(formatter.rights))
    else:
        right_refs = []

    left_code = ', \n'.join(formatter.lefts)
    if is_method:
        left_refs = re.findall(attr_pattern, left_code)
        indexed_left_refs = re.findall(indexed_pattern, left_code)
    else:
        left_refs = []
        indexed_left_refs = []

    closure = inspect.getclosurevars(f)
    code_scope = {**closure.nonlocals, **closure.globals}

    return {
        'delay_call': formatter.delay_call,
        'code_string': '\n'.join(formatter.lines),
        'code_scope': code_scope,
        'self_data_in_right': sorted(set(right_refs)),
        'self_data_without_index_in_left': sorted(set(left_refs)),
        'self_data_with_index_in_left': sorted(set(indexed_left_refs)),
    }