def _replace_expressions(self, expressions, name, y_sub, t_sub=None): """Replace expressions of df part. Parameters ---------- expressions : list, tuple The list/tuple of expressions. name : str The name of the new expression. y_sub : str The new name of the variable "y". t_sub : str, optional The new name of the variable "t". Returns ------- list_of_expr : list A list of expressions. """ return_expressions = [] # replacements replacement = {self.var_name: y_sub} if t_sub is not None: replacement[self.t_name] = t_sub # replace variables in expressions for expr in expressions: replace = False identifiers = expr.identifiers for repl_var in replacement.keys(): if repl_var in identifiers: replace = True break if replace: code = tools.word_replace(expr.code, replacement) new_expr = Expression(f"{expr.var_name}_{name}", code) return_expressions.append(new_expr) replacement[expr.var_name] = new_expr.var_name return return_expressions
def visit_For(self, node):
    """Unroll the ``for`` loop whose iterable is ``self.iter_name``.

    The loop body is duplicated once per entry of ``self.loop_values``:

    * ``Base`` entries are analysed with ``_analyze_cls_func_body``, with
      the loop target standing in for the object;
    * callable entries are compiled with ``_jit_func`` and calls to the
      loop target are renamed to a name unique to that entry, which is
      bound to the compiled function in ``self.code_scope``.

    Arguments, ``arg2call`` entries, nodes and code-scope entries collected
    along the way are merged into ``self``. Loops over any other iterable
    are kept as they are. Nested nodes of the result are visited
    afterwards.

    Parameters
    ----------
    node : ast.For
        The loop node being visited.

    Returns
    -------
    ast.AST
        An ``ast.Module`` holding the unrolled bodies, or ``node`` itself
        when its iterable is not ``self.iter_name``.

    Raises
    ------
    errors.BrainPyError
        If the loop target is not a single name, or an entry of
        ``self.loop_values`` is neither a ``Base`` object nor callable.
    """
    iter_ = tools.ast2code(ast.fix_missing_locations(node.iter))

    if iter_.strip() != self.iter_name:
        final_node = node
    else:
        data_to_replace = Collector()
        final_node = ast.Module(body=[])
        self.success = True

        # target
        if not isinstance(node.target, ast.Name):
            raise errors.BrainPyError(
                f'Only support scalar iter, like "for x in xxxx:", not "for '
                f'{tools.ast2code(ast.fix_missing_locations(node.target))} '
                f'in {iter_}:')
        target = node.target.id

        # one copy of the loop body per loop value
        for i, value in enumerate(self.loop_values):
            module = ast.Module(body=deepcopy(node).body)
            code = tools.ast2code(module)

            if isinstance(value, Base):
                # transform Base objects
                new_code, arguments, arg2call, nodes, code_scope = _analyze_cls_func_body(
                    host=value,
                    self_name=target,
                    code=code,
                    tree=module,
                    show_code=self.show_code,
                    **self.jit_setting)
                self.arguments.update(arguments)
                self.arg2call.update(arg2call)
                self.nodes.update(nodes)
                self.code_scope.update(code_scope)
                final_node.body.extend(ast.parse(new_code).body)

            elif callable(value):
                # transform functions
                r = _jit_func(obj_or_fun=value,
                              show_code=self.show_code,
                              **self.jit_setting)
                tree = _replace_func_call_by_tree(
                    deepcopy(module),
                    func_call=target,
                    arg_to_append=r['arguments'],
                    new_func_name=f'{target}_{i}')

                # update import parameters
                self.arguments.update(r['arguments'])
                self.arg2call.update(r['arg2call'])
                self.nodes.update(r['nodes'])

                # Bound methods of a Base host are prefixed with the host name.
                # The index ``i`` keeps the name unique per entry; otherwise
                # several methods of the same host would overwrite each other
                # in ``code_scope`` and every call would hit the last one.
                host = getattr(value, '__self__', None)
                if isinstance(host, Base):
                    replace_name = f'{host.name}_{target}_{i}'
                else:
                    replace_name = f'{target}_{i}'
                self.code_scope[replace_name] = r['func']
                data_to_replace[f'{target}_{i}'] = replace_name

                final_node.body.extend(tree.body)

            else:
                raise errors.BrainPyError(
                    f'Only support JIT an iterable objects of function '
                    f'or Base object, but we got:\n\n {value}')

        # rename the per-entry call names to their code-scope names
        final_code = tools.ast2code(final_node)
        final_code = tools.word_replace(final_code, data_to_replace, exclude_dot=True)
        final_node = ast.parse(final_code)

    self.generic_visit(final_node)
    return final_node
def _flat_attr_name(target, split_keys, i):
    """Return the flat name that replaces the access path ``split_keys``.

    The data at position ``i`` (an attribute of ``target``) is named
    ``<target.name>_<split_keys[i]>``. Attributes after position ``i`` are
    re-appended with dots, e.g. ``['self', 'V', 'shape']`` with ``i=1``
    becomes ``'<target.name>_V.shape'``.
    """
    flat = f'{target.name}_{split_keys[i]}'
    trailing = split_keys[i + 1:]
    return '.'.join([flat] + trailing) if trailing else flat


def _analyze_cls_func_body(host, self_name, code, tree, show_code=False,
                           has_func_def=False, **jit_setting):
    """Rewrite code that accesses attributes of ``host`` through ``self_name``.

    Every ``<self_name>.<attr>[.<attr>...]`` path found in ``code`` is
    resolved on ``host``, descending through nested ``Base`` objects, and
    replaced by a flat name:

    * ``Variable`` data becomes a new argument (``arguments``/``arg2call``);
    * ``np.random.RandomState`` data is replaced by ``np.random`` in
      ``code_scope``;
    * callables are compiled with ``_jit_func`` and their calls rewritten;
    * dicts/lists/tuples of functions or ``Base`` objects have the ``for``
      loop over them unrolled by ``_replace_this_forloop``; containers of
      other values are stored in ``code_scope`` as they are;
    * anything else is stored in ``code_scope`` as a constant.

    Parameters
    ----------
    host : Base
        The object that ``self_name`` refers to.
    self_name : str
        The name used for ``host`` inside ``code`` (e.g. ``"self"``).
    code : str
        Source of ``tree``; scanned for attribute-access paths.
    tree : ast.AST
        Parsed form of ``code``; may be rewritten.
    show_code : bool
        Forwarded to the JIT helpers.
    has_func_def : bool
        If True, ``tree.body[0]`` is a function definition. Its decorators
        are removed, and the collected arguments are appended as parameters
        defaulting to ``None``.
    **jit_setting
        Forwarded to the JIT helpers.

    Returns
    -------
    tuple
        ``(code, arguments, arg2call, nodes, code_scope)``.

    Raises
    ------
    errors.BrainPyError
        If an access path only refers to ``Base`` objects, or a dict of
        functions/``Base`` objects is iterated by anything but ``.values()``.
    """
    arguments, arg2call, nodes, code_scope = set(), dict(), Collector(), dict()

    # All `<self_name>.xxx[.yyy...]` access paths in the code. They are sorted
    # (longest first) so the generated code does not depend on set iteration
    # order.
    pattern = '\\b' + self_name + '\\.[A-Za-z_][A-Za-z0-9_.]*\\b'
    self_data = sorted(set(re.findall(pattern, code)), key=lambda s: (-len(s), s))

    # analyze variables and functions accessed by the self.xx
    data_to_replace = {}
    for key in self_data:
        split_keys = key.split('.')
        if len(split_keys) < 2:
            raise errors.BrainPyError(
                f'Cannot analyze "{key}": expected an attribute access on '
                f'"{self_name}".')

        # Descend through nested Base objects; stop at the first attribute
        # that is an Integrator or not a Base object at all.
        target = host
        for i in range(1, len(split_keys)):
            next_target = getattr(target, split_keys[i])
            if isinstance(next_target, Integrator) or not isinstance(next_target, Base):
                break
            target = next_target
        else:
            raise errors.BrainPyError(
                f'Cannot analyze "{key}": every attribute in it refers to a '
                f'Base object.')
        data = getattr(target, split_keys[i])
        flat_name = f'{target.name}_{split_keys[i]}'

        if isinstance(data, math.numpy.Variable):
            # data is a variable: pass it in as an argument. The argument is
            # fed by the variable itself (`split_keys[i]`), not by the last
            # attribute of a longer path such as `self.V.shape`.
            arguments.add(flat_name)
            arg2call[flat_name] = f'{target.name}.{split_keys[i]}.value'
            nodes[target.name] = target
            data_to_replace[key] = _flat_attr_name(target, split_keys, i)

        elif isinstance(data, np.random.RandomState):
            # data is a RandomState: the generated code uses `np.random`
            code_scope[flat_name] = np.random
            data_to_replace[key] = _flat_attr_name(target, split_keys, i)

        elif callable(data):
            # data is a function: compile it and append its extra arguments
            # to every call
            assert len(split_keys) == i + 1
            r = _jit_func(obj_or_fun=data, show_code=show_code, **jit_setting)
            tree = _replace_func_call_by_tree(tree,
                                              func_call=key,
                                              arg_to_append=r['arguments'])
            arguments.update(r['arguments'])
            arg2call.update(r['arg2call'])
            nodes.update(r['nodes'])
            code_scope[flat_name] = r['func']
            data_to_replace[key] = flat_name

        elif isinstance(data, (dict, list, tuple)):
            # data is a list/tuple/dict of functions/objects (or constants)
            if isinstance(data, dict):
                if len(split_keys) != i + 2 and split_keys[-1] != 'values':
                    raise errors.BrainPyError(
                        f'Only support iter dict.values(). while we got '
                        f'dict.{split_keys[-1]} for data: \n\n{data}')
                values = list(data.values())
                iter_name = key + '()'
            else:
                assert len(split_keys) == i + 1
                values = list(data)
                iter_name = key

            if values and not (callable(values[0]) or isinstance(values[0], Base)):
                # a container of plain values is exposed as it is
                code_scope[flat_name] = data
                data_to_replace[key] = _flat_attr_name(target, split_keys, i)
                continue

            # Unrolling replaces the loop over `iter_name` with the dict
            # *values*, so a dict must be iterated exactly through `.values()`
            # (not `.keys()`, `.items()`, ...).
            if values and isinstance(data, dict) and (
                    len(split_keys) != i + 2 or split_keys[-1] != 'values'):
                raise errors.BrainPyError(
                    f'Only support iter dict.values(). while we got '
                    f'dict.{split_keys[-1]} for data: \n\n{data}')

            # replace this for-loop
            tree, _arguments, _arg2call, _nodes, _code_scope = _replace_this_forloop(
                tree=tree,
                iter_name=iter_name,
                loop_values=values,
                show_code=show_code,
                **jit_setting)
            arguments.update(_arguments)
            arg2call.update(_arg2call)
            nodes.update(_nodes)
            code_scope.update(_code_scope)

        else:
            # constants
            code_scope[flat_name] = data
            data_to_replace[key] = _flat_attr_name(target, split_keys, i)

    if has_func_def:
        func_def = tree.body[0]
        func_def.decorator_list.clear()
        # Append the collected arguments as parameters defaulting to None.
        # NOTE(review): `ast.arg` is the canonical node type for parameters;
        # `ast.Name` is kept here because it is what the code generator has
        # been fed so far.
        func_def.args.args.extend([ast.Name(id=a) for a in sorted(arguments)])
        func_def.args.defaults.extend([ast.Constant(None) for _ in sorted(arguments)])
        func_def.args.kwarg = None

    # replace words
    code = tools.ast2code(tree)
    code = tools.word_replace(code, data_to_replace, exclude_dot=True)
    return code, arguments, arg2call, nodes, code_scope
def get_integral_step(diff_eq, *args):
    """Generate the source of one exponential-Euler step for ``diff_eq``.

    The drift ``f`` is expanded with sympy. If it contains a term linear in
    ``diff_eq.var_name`` with coefficient ``A``, the drift update is
    ``(exp(A*dt) - 1) / A * f``. Otherwise (no such term, including when the
    variable only appears non-linearly) it falls back to the Euler update
    ``dt * f``. For stochastic equations the diffusion term ``g * dW`` is
    added, multiplied by ``exp(A*dt)`` when a linear part exists (by 1
    otherwise). ``dW`` is the ``_normal_like_`` sample scaled by
    ``sqrt(dt)`` in both cases. ``_normal_like_`` is assumed to draw
    unit-variance samples shaped like the variable — TODO confirm.

    Parameters
    ----------
    diff_eq : object
        The parsed differential equation. Its ``var_name``,
        ``get_f_expressions``, ``get_g_expressions``, ``is_stochastic``,
        ``return_intermediates``, ``func_args`` and ``expr_names`` are used.
    *args
        Unused.

    Returns
    -------
    str
        The generated code lines. The final line assigns the updated
        variable and ``diff_eq.return_intermediates`` to ``_res``. The
        names in ``func_args`` and ``expr_names`` are prefixed with ``_``.
    """
    dt = backend.get_dt()
    var_name = diff_eq.var_name
    f_expressions = diff_eq.get_f_expressions(substitute_vars=var_name)

    # code lines
    code_lines = [str(expr) for expr in f_expressions[:-1]]

    # get the linear system using sympy
    f_res = f_expressions[-1]
    df_expr = ast_analysis.str2sympy(f_res.code).expr.expand()
    s_df = sympy.Symbol(f"{f_res.var_name}")
    code_lines.append(f'{s_df.name} = {ast_analysis.sympy2str(df_expr)}')
    var = sympy.Symbol(var_name, real=True)

    # get df part
    s_linear = sympy.Symbol(f'_{var_name}_linear')
    s_linear_exp = sympy.Symbol(f'_{var_name}_linear_exp')
    s_df_part = sympy.Symbol(f'_{var_name}_df_part')

    # Coefficient of the term linear in `var`. It stays None when there is no
    # such term: `collect` has no `var` key when `var` only appears
    # non-linearly (e.g. `var**2`), so indexing it directly would KeyError.
    linear = None
    if df_expr.has(var):
        linear = sympy.collect(df_expr, var, evaluate=False).get(var)

    if linear is not None:
        # linear
        code_lines.append(f'{s_linear.name} = {ast_analysis.sympy2str(linear)}')
        # linear exponential
        linear_exp = sympy.exp(linear * dt)
        code_lines.append(f'{s_linear_exp.name} = {ast_analysis.sympy2str(linear_exp)}')
        # df part
        df_part = (s_linear_exp - 1) / s_linear * s_df
        code_lines.append(f'{s_df_part.name} = {ast_analysis.sympy2str(df_part)}')
    else:
        # no linear part: exp(0 * dt) == 1
        code_lines.append(f'{s_linear_exp.name} = 1.0')
        # df part (Euler update)
        code_lines.append(f'{s_df_part.name} = {ast_analysis.sympy2str(dt * s_df)}')

    # get dg part
    if diff_eq.is_stochastic:
        # dW: the sqrt(dt) scaling is applied here so that it holds in both
        # branches above (it used to be folded into `_linear_exp` in the
        # non-linear branch only).
        noise = f'_normal_like_({var_name}) * {dt ** 0.5}'
        code_lines.append(f'_{var_name}_dW = {noise}')
        # expressions of the stochastic part
        g_expressions = diff_eq.get_g_expressions()
        code_lines.extend([str(expr) for expr in g_expressions[:-1]])
        g_expr = g_expressions[-1].code
        # get the dg_part
        s_dg_part = sympy.Symbol(f'_{var_name}_dg_part')
        code_lines.append(f'_{var_name}_dg_part = {g_expr} * _{var_name}_dW')
    else:
        s_dg_part = 0

    # update expression
    update = var + s_df_part + s_dg_part * s_linear_exp

    # The actual update step
    code_lines.append(f'{var_name} = {ast_analysis.sympy2str(update)}')
    return_expr = ', '.join([var_name] + diff_eq.return_intermediates)
    code_lines.append(f'_res = {return_expr}')

    # final
    code = '\n'.join(code_lines)
    subs_dict = {arg: f'_{arg}' for arg in diff_eq.func_args + diff_eq.expr_names}
    code = tools.word_replace(code, subs_dict)
    return code
def class2func(cls_func, host, func_name=None, show_code=False):
    """Turn a step function of ``host`` into a standalone Numba-jitted function.

    The body of ``cls_func`` is analysed with ``analyze_step_func``:

    * ``self.xxx`` data read by the body becomes extra parameters;
      callables among them are placed in the code scope instead;
    * delay push/pull calls are redirected to Numba-compiled helpers from
      ``delay``;
    * data assigned without indexing is returned from the new function.

    The generated source defines ``new_<func_name>``. It is compiled with
    ``exec`` in the analysed code scope and wrapped with
    ``numba.jit(**NUMBA_PROFILE)``.

    Parameters
    ----------
    cls_func : function
        The function of the instantiated class.
    host : object
        The instance owning ``cls_func``. Its ``name`` prefixes the
        generated call/assign strings.
    func_name : str, optional
        The function name. If not given, ``cls_func.__name__`` is used.
    show_code : bool
        Whether to print the generated code and its code scope.

    Returns
    -------
    func : function
        The jitted function.
    calls : list of str
        One string per parameter of the generated function, in order,
        naming what to pass for it.
    assigns : list of str
        Attribute paths (rooted at ``host.name``) matching the returned
        values, in order.

    Raises
    ------
    errors.ModelDefError
        If an argument of ``cls_func`` is neither an attribute of ``host``
        nor a system keyword.
    ValueError
        If a delay push/pull call has an unsupported number of arguments.
    """
    _, arguments = utils.get_args(cls_func)
    if func_name is None:
        func_name = cls_func.__name__
    host_name = host.name

    # where each declared argument comes from
    calls = []
    for arg in arguments:
        if hasattr(host, arg):
            calls.append(f'{host_name}.{arg}')
            continue
        if arg in backend.SYSTEM_KEYWORDS:
            calls.append(arg)
            continue
        raise errors.ModelDefError(
            f'Step function "{func_name}" of {host} '
            f'define an unknown argument "{arg}" which is not '
            f'an attribute of {host} nor the system keywords '
            f'{backend.SYSTEM_KEYWORDS}.')

    # analysis of the function body
    analysis = analyze_step_func(host=host, f=cls_func)
    delay_calls = analysis['delay_call']
    source = analysis['code_string']
    code_scope = analysis['code_scope']
    return_data = analysis['self_data_without_index_in_left']
    pass_data = sorted(set(analysis['self_data_in_right'] +
                           analysis['self_data_with_index_in_left']))
    num_indent = get_num_indent(source)

    # delay calls: pick the compiled push/pull helper by argument count
    early_subs = {}
    late_subs = {}
    for call in delay_calls.values():
        arity = len(call['args'] + call['keywords'])
        if call['type'] == 'push':
            helper_name = {2: 'push_type2', 1: 'push_type1'}.get(arity)
            if helper_name is None:
                raise ValueError(f'Unknown delay push. {call}')
        else:
            helper_name = {1: 'pull_type1', 0: 'pull_type0'}.get(arity)
            if helper_name is None:
                raise ValueError(f'Unknown delay pull. {call}')
        helper = numba.njit(getattr(delay, helper_name))

        owner_call = call['func']
        pass_data.remove(owner_call)
        pass_data.extend(call['data_need_pass'])
        early_subs[call['org_call']] = call['rep_call']
        flat_owner = owner_call.replace('.', '_')
        late_subs[owner_call] = flat_owner
        code_scope[flat_owner] = helper

    for original, replacement in early_subs.items():
        source = source.replace(original, replacement)

    # extra parameters: data the body needs (callables go to the code scope)
    new_args = arguments + []
    for data in sorted(set(pass_data)):
        path = data.split('.')
        flat = data.replace('.', '_')
        late_subs[data] = flat
        value = host
        for attr in path[1:]:
            value = getattr(value, attr)
        if callable(value):
            code_scope[flat] = value
        else:
            new_args.append(flat)
            calls.append('.'.join([host_name] + path[1:]))

    # data the body assigns without indexing is returned
    assigns, returns = [], []
    for data in return_data:
        flat = data.replace('.', '_')
        assigns.append('.'.join([host_name] + data.split('.')[1:]))
        returns.append(flat)
        late_subs[data] = flat

    code_scope[host_name] = host

    # assemble the new function's source
    header = f'def new_{func_name}({", ".join(new_args)}):\n'
    source = header + tools.indent(source, spaces_per_tab=2)
    if returns:
        source += f'\n{" " * num_indent + " "}return {", ".join(returns)}'
    source = tools.word_replace(source, late_subs)

    if show_code:
        print(source)
        print(code_scope)
        print()

    # recompile
    exec(compile(source, '', 'exec'), code_scope)
    jitted = numba.jit(**NUMBA_PROFILE)(code_scope[f'new_{func_name}'])
    return jitted, calls, assigns