def visit_Call(self, node, level=0):
    """Rewrite calls to ``self.func_name``; recurse into every other call.

    For a call whose callee source text equals ``self.func_name``:

    * the first positional argument is dropped when it is a bare name equal
      to ``self.remove_self``;
    * a ``name=name`` keyword is appended for every name in
      ``self.arg_to_append`` that the call does not already pass by keyword;
    * the callee is replaced by ``self.new_func_name`` when that is set.

    Calls that do not match are kept, but their callee and arguments are
    still visited, so matching calls nested inside them are rewritten too.

    Parameters
    ----------
    node : ast.Call
        The call node being visited.
    level : int
        Unused; kept for interface compatibility.

    Returns
    -------
    ast.AST
        A new ``ast.Call`` (with ``node``'s location) for a matching call,
        otherwise ``node`` itself after visiting its children.

    Raises
    ------
    ValueError
        If the node has a non-None ``starargs`` or ``kwargs`` field
        (presumably only present on old-style ASTs).
    """
    if getattr(node, 'starargs', None) is not None:
        raise ValueError(
            "Variable number of arguments (*args) are not supported")
    if getattr(node, 'kwargs', None) is not None:
        raise ValueError("Keyword arguments (**kwargs) are not supported")

    # get function name
    call = tools.ast2code(node.func)
    if call != self.func_name:
        # Not the target function: keep the call, but still descend into its
        # callee and arguments so nested target calls get rewritten as well.
        return self.generic_visit(node)

    # ``self.visit`` (not ``generic_visit``) so an argument that is itself a
    # call -- e.g. ``f(f(x))`` -- also goes through this method.
    args = [self.visit(arg) for arg in node.args]
    keywords = [self.visit(keyword) for keyword in node.keywords]

    # Drop the explicit "self" argument, guarding against calls with no
    # positional arguments or a non-name first argument.
    if (self.remove_self and args and isinstance(args[0], ast.Name)
            and args[0].id == self.remove_self):
        args.pop(0)

    # Append ``name=name`` for every required name not already passed.
    passed = {keyword.arg for keyword in keywords}
    keywords.extend(
        ast.keyword(arg=name, value=ast.Name(id=name, ctx=ast.Load()))
        for name in self.arg_to_append if name not in passed)

    # final function
    if self.new_func_name:
        func = ast.parse(f'{self.new_func_name}()').body[0].value.func
    else:
        func = node.func
    new_call = ast.Call(func=func, args=args, keywords=keywords)
    return ast.copy_location(new_call, node)
def visit_For(self, node):
    """Unroll a ``for`` loop over ``self.iter_name`` into straight-line code.

    When the loop's iterable source text equals ``self.iter_name``, the loop
    is replaced by one copy of its body per entry of ``self.loop_values``:

    * a ``Base`` value is analyzed by ``_analyze_cls_func_body`` with the
      loop variable playing the role of "self";
    * a callable value is compiled by ``_jit_func`` and calls through the
      loop variable are redirected to a per-iteration name.

    The ``arguments``/``arg2call``/``nodes``/``code_scope`` collected along
    the way are merged into the transformer's attributes, and
    ``self.success`` is set to True. Any other loop is left unchanged. In
    both cases the resulting node's children are visited.

    Parameters
    ----------
    node : ast.For
        The loop being visited.

    Returns
    -------
    ast.AST
        An ``ast.Module`` holding the unrolled body for a matching loop,
        otherwise ``node`` itself.

    Raises
    ------
    errors.BrainPyError
        If the loop target is not a plain name, or a loop value is neither a
        ``Base`` object nor callable.
    """
    iter_ = tools.ast2code(ast.fix_missing_locations(node.iter))
    if iter_.strip() != self.iter_name:
        self.generic_visit(node)
        return node

    self.success = True
    if not isinstance(node.target, ast.Name):
        raise errors.BrainPyError(
            f'Only support scalar iter, like "for x in xxxx:", not "for '
            f'{tools.ast2code(ast.fix_missing_locations(node.target))} '
            f'in {iter_}:"')
    target = node.target.id

    data_to_replace = Collector()
    final_node = ast.Module(body=[])
    for i, value in enumerate(self.loop_values):
        # Every iteration works on its own fresh copy of the loop body.
        module = ast.Module(body=deepcopy(node.body))

        if isinstance(value, Base):
            # transform Base objects
            new_code, arguments, arg2call, nodes, code_scope = _analyze_cls_func_body(
                host=value, self_name=target, code=tools.ast2code(module),
                tree=module, show_code=self.show_code, **self.jit_setting)
            self.arguments.update(arguments)
            self.arg2call.update(arg2call)
            self.nodes.update(nodes)
            self.code_scope.update(code_scope)
            final_node.body.extend(ast.parse(new_code).body)

        elif callable(value):
            # transform functions
            r = _jit_func(obj_or_fun=value, show_code=self.show_code,
                          **self.jit_setting)
            tree = _replace_func_call_by_tree(
                module, func_call=target, arg_to_append=r['arguments'],
                new_func_name=f'{target}_{i}')
            self.arguments.update(r['arguments'])
            self.arg2call.update(r['arg2call'])
            self.nodes.update(r['nodes'])

            # The iteration index is part of the name even for bound methods:
            # two methods of the same host (``[obj.a, obj.b]``) would
            # otherwise share one code_scope entry and both iterations would
            # call whichever function was stored last.
            host = getattr(value, '__self__', None)
            if isinstance(host, Base):
                replace_name = f'{host.name}_{target}_{i}'
            else:
                replace_name = f'{target}_{i}'
            self.code_scope[replace_name] = r['func']
            data_to_replace[f'{target}_{i}'] = replace_name
            final_node.body.extend(tree.body)

        else:
            raise errors.BrainPyError(
                f'Only support JIT an iterable objects of function '
                f'or Base object, but we got:\n\n {value}')

    # replace words
    final_code = tools.word_replace(tools.ast2code(final_node),
                                    data_to_replace, exclude_dot=True)
    final_node = ast.parse(final_code)
    self.generic_visit(final_node)
    return final_node
def _flattened_data_name(target, split_keys, i):
    """Return the replacement text for a ``self.a.b...`` access resolved on ``target``.

    ``split_keys[i]`` is the attribute found on ``target``. It is flattened
    into the identifier ``<target.name>_<attr>``, and any trailing
    attributes (``split_keys[i + 1:]``) are kept as a dotted suffix. For
    example, ``['self', 'V', 'shape']`` with ``i == 1`` gives
    ``'<name>_V.shape'``.
    """
    flat = f'{target.name}_{split_keys[i]}'
    rest = split_keys[i + 1:]
    return f'{flat}.{".".join(rest)}' if rest else flat


def _analyze_cls_func_body(host, self_name, code, tree, show_code=False,
                           has_func_def=False, **jit_setting):
    """Rewrite every ``<self_name>.xxx`` access so the code no longer needs ``host``.

    Each distinct access ``<self_name>.a.b...`` found in ``code`` is resolved
    on ``host``, descending through nested ``Base`` attributes (stopping at an
    ``Integrator`` or the first non-``Base`` attribute). The data it reaches
    is then handled as follows:

    * ``math.numpy.Variable``: becomes a function argument; ``arg2call``
      records how to fetch its ``.value`` and ``nodes`` records its owner;
    * ``np.random.RandomState``: replaced by the ``np.random`` module;
    * callable: compiled with ``_jit_func``; its call sites in ``tree`` get
      the compiled function's extra arguments appended;
    * dict (only via ``.values()``) / list / tuple: containers of callables
      or ``Base`` objects unroll the matching for-loop through
      ``_replace_this_forloop``; containers of plain data are constants;
    * anything else: a constant placed in ``code_scope``.

    Parameters
    ----------
    host : Base
        The object ``self_name`` refers to.
    self_name : str
        The identifier used for the host in ``code`` (e.g. ``"self"``).
    code : str
        Source text of ``tree``; only used to find the accesses.
    tree : ast.AST
        The AST to rewrite; its rewritten form is what gets returned.
    show_code : bool
        Forwarded to the compilation helpers.
    has_func_def : bool
        When True, ``tree.body[0]`` is a function definition. Its decorators
        and ``**kwargs`` are removed, and the collected arguments are
        appended as ``None``-defaulted parameters.
    **jit_setting
        Forwarded to the compilation helpers.

    Returns
    -------
    tuple
        ``(code, arguments, arg2call, nodes, code_scope)``, where ``code`` is
        the rewritten source.

    Raises
    ------
    errors.BrainPyError
        If an access cannot be resolved, or a dict is used other than
        through ``.values()``.
    """
    arguments, arg2call, nodes, code_scope = set(), dict(), Collector(), dict()
    data_to_replace = {}

    # All distinct ``self.xxx[.yyy...]`` accesses. They are sorted so the
    # analysis is deterministic: a plain ``set`` iterates strings in
    # hash-randomized order, and several keys can map to the same argument.
    pattern = r'\b' + self_name + r'\.[A-Za-z_][A-Za-z0-9_.]*\b'
    self_data = sorted(set(re.findall(pattern, code)))

    for key in self_data:
        split_keys = key.split('.')
        if len(split_keys) < 2:
            raise errors.BrainPyError(f'Cannot analyze the data access "{key}".')

        # Walk down nested Base objects. Afterwards ``target`` owns the data
        # and ``split_keys[i]`` is the data's attribute name on it.
        target = host
        for i in range(1, len(split_keys)):
            next_target = getattr(target, split_keys[i])
            if isinstance(next_target, Integrator) or not isinstance(next_target, Base):
                break
            target = next_target
        else:
            raise errors.BrainPyError(
                f'Cannot analyze "{key}": every attribute in it is a Base object.')
        data = getattr(target, split_keys[i])
        flat_name = f'{target.name}_{split_keys[i]}'

        if isinstance(data, math.numpy.Variable):
            # The argument is fetched from the attribute that holds the
            # Variable (split_keys[i]), not from the last key: ``self.V.shape``
            # must read ``V.value``, not ``shape.value``.
            arguments.add(flat_name)
            arg2call[flat_name] = f'{target.name}.{split_keys[i]}.value'
            nodes[target.name] = target
            data_to_replace[key] = _flattened_data_name(target, split_keys, i)

        elif isinstance(data, np.random.RandomState):
            # A RandomState is swapped for the ``np.random`` module.
            code_scope[flat_name] = np.random
            data_to_replace[key] = _flattened_data_name(target, split_keys, i)

        elif callable(data):
            assert len(split_keys) == i + 1, f'Cannot analyze "{key}".'
            r = _jit_func(obj_or_fun=data, show_code=show_code, **jit_setting)
            tree = _replace_func_call_by_tree(tree, func_call=key,
                                              arg_to_append=r['arguments'])
            arguments.update(r['arguments'])
            arg2call.update(r['arg2call'])
            nodes.update(r['nodes'])
            code_scope[flat_name] = r['func']
            data_to_replace[key] = flat_name

        elif isinstance(data, (dict, list, tuple)):
            if isinstance(data, dict):
                # Exactly ``self.xxx.values`` is supported for dicts; any
                # other access (``.keys``, ``.items``, bare use) is rejected.
                if len(split_keys) != i + 2 or split_keys[-1] != 'values':
                    raise errors.BrainPyError(
                        f'Only support iter dict.values(). while we got '
                        f'dict.{split_keys[-1]} for data: \n\n{data}')
                values = list(data.values())
                iter_name = key + '()'
            else:
                assert len(split_keys) == i + 1, f'Cannot analyze "{key}".'
                values = list(data)
                iter_name = key

            if values and not (callable(values[0]) or isinstance(values[0], Base)):
                # Containers of plain data are passed through as constants.
                code_scope[flat_name] = data
                data_to_replace[key] = _flattened_data_name(target, split_keys, i)
                continue

            # Containers of functions / Base objects: unroll the for-loop.
            tree, _arguments, _arg2call, _nodes, _code_scope = _replace_this_forloop(
                tree=tree, iter_name=iter_name, loop_values=values,
                show_code=show_code, **jit_setting)
            arguments.update(_arguments)
            arg2call.update(_arg2call)
            nodes.update(_nodes)
            code_scope.update(_code_scope)

        else:
            # constants
            code_scope[flat_name] = data
            data_to_replace[key] = _flattened_data_name(target, split_keys, i)

    if has_func_def:
        func_def = tree.body[0]
        func_def.decorator_list.clear()
        new_args = sorted(arguments)
        func_def.args.args.extend([ast.Name(id=a) for a in new_args])
        func_def.args.defaults.extend([ast.Constant(None) for _ in new_args])
        func_def.args.kwarg = None

    # replace words
    code = tools.ast2code(tree)
    code = tools.word_replace(code, data_to_replace, exclude_dot=True)
    return code, arguments, arg2call, nodes, code_scope
def _jit_intg(f, show_code=False, **jit_setting):
    """Compile an ``Integrator`` into a numba-jitted function.

    If ``f.integral`` is a bound method (the exponential Euler methods), the
    work is delegated to ``_jit_cls_func`` and its result is returned.
    Otherwise the generated source in ``f.code_lines`` is re-parsed and each
    function in ``f.derivative`` is compiled:

    * numba ``Dispatcher`` objects are used as they are;
    * functions hosted by a ``DynamicalSystem`` go through ``_jit_cls_func``,
      and their call sites get the extra data arguments appended;
    * other functions are passed to ``numba.jit``.

    If at least one derivative had to be compiled, the integral function is
    rebuilt with the collected arguments as ``None``-defaulted parameters,
    exec'd into a copy of ``f.code_scope`` and jitted.

    Parameters
    ----------
    f : Integrator
        The integrator to compile.
    show_code : bool
        Forwarded to ``_jit_cls_func``. When True, the rebuilt source is
        also passed to ``_show_compiled_codes`` before compiling.
    **jit_setting
        Keyword arguments forwarded to ``numba.jit`` / ``_jit_cls_func``.

    Returns
    -------
    dict
        ``func`` (the jitted function, or ``f`` itself when nothing needed
        compiling), ``arguments`` (set of extra argument names), ``arg2call``
        and ``nodes``.
    """
    # TODO: integrator has "integral", "code_lines", "code_scope", "func_name", "derivative",
    assert isinstance(f, Integrator)

    # exponential euler methods
    if hasattr(f.integral, '__self__'):
        return _jit_cls_func(f=f.integral, code="\n".join(f.code_lines),
                             show_code=show_code, **jit_setting)

    func_name = f.func_name
    tree = ast.parse('\n'.join(f.code_lines))
    code_scope = dict(f.code_scope)
    arguments = set()
    arg2call = dict()
    nodes = Collector()

    # When the integrator itself is bound to a DynamicalSystem, that system
    # hosts every derivative. The leading "self" parameter of the generated
    # function is dropped, and so is the matching argument at call sites.
    f_node = None
    remove_self = None
    if hasattr(f, '__self__') and isinstance(f.__self__, DynamicalSystem):
        f_node = f.__self__
        remove_self = tree.body[0].args.args.pop(0).arg

    need_recompile = False
    for key, func in f.derivative.items():
        if isinstance(func, Dispatcher):
            # Already numba-compiled; use it as it is.
            continue
        need_recompile = True

        # Compare with None rather than using truthiness: a host object may
        # define __len__/__bool__ and be falsy while still being the host.
        if f_node is not None:
            func_node = f_node
        elif hasattr(func, '__self__') and isinstance(func.__self__, DynamicalSystem):
            func_node = func.__self__
        else:
            func_node = None

        if func_node is None:
            # Plain function: jit it directly.
            code_scope[key] = numba.jit(func, **jit_setting)
            continue

        r = _jit_cls_func(f=func, host=func_node, show_code=show_code, **jit_setting)
        if len(r['arguments']) or remove_self:
            tree = _replace_func_call_by_tree(tree, func_call=key,
                                              arg_to_append=r['arguments'],
                                              remove_self=remove_self)
        code_scope[key] = r['func']
        arguments.update(r['arguments'])
        arg2call.update(r['arg2call'])
        nodes.update(r['nodes'])
        nodes[func_node.name] = func_node

    if not need_recompile:
        return dict(func=f, arguments=arguments, arg2call=arg2call, nodes=nodes)

    # Rebuild the integral with the collected data as extra parameters.
    func_def = tree.body[0]
    func_def.decorator_list.clear()
    new_args = sorted(arguments)
    func_def.args.args.extend([ast.Name(id=a) for a in new_args])
    func_def.args.defaults.extend([ast.Constant(None) for _ in new_args])
    code = tools.ast2code(tree)

    # compile functions
    if show_code:
        _show_compiled_codes(code, code_scope)
    exec(compile(code, '', 'exec'), code_scope)
    jit_f = numba.jit(code_scope[func_name], **jit_setting)
    return dict(func=jit_f, arguments=arguments, arg2call=arg2call, nodes=nodes)
def analyze_step_func(host, f):
    """Analyze the step functions in a population.

    Parameters
    ----------
    host : Population
        The data and the function host.
    f : callable
        The step function.

    Returns
    -------
    results : dict
        A dict with the following keys:

        * ``delay_call``: ``StepFuncReader.delay_call`` after visiting ``f``;
        * ``code_string``: the reader's collected lines joined by newlines;
        * ``code_scope``: the closure variables (nonlocals and globals) of
          ``f``;
        * ``self_data_in_right``: sorted ``self.xxx`` names on assignment
          right-hand sides;
        * ``self_data_without_index_in_left``: sorted ``self.xxx`` names on
          assignment left-hand sides;
        * ``self_data_with_index_in_left``: sorted ``self.xxx`` names that
          are indexed (``self.xxx[...]``) on left-hand sides.

        The three ``self_data_*`` lists stay empty unless the first
        parameter of ``f`` is one of ``backend.CLASS_KEYWORDS``.
    """
    code_string = tools.deindent(inspect.getsource(f)).strip()
    tree = ast.parse(code_string)

    # Name of the first parameter; stripped so surrounding whitespace from
    # the unparser cannot hide a class keyword.
    args = tools.ast2code(ast.fix_missing_locations(tree.body[0].args)).split(',')
    self_name = args[0].strip()

    # code AST analysis
    formatter = StepFuncReader(host=host)
    formatter.visit(tree)

    self_data_in_right = []
    self_data_without_index_in_left = []
    self_data_with_index_in_left = []
    if self_name in backend.CLASS_KEYWORDS:
        attr_pattern = r'\b' + self_name + r'\.[A-Za-z_][A-Za-z0-9_.]*\b'
        # Non-greedy ``\[.*?\]`` so that several indexed targets on one line
        # (``self.a[i], self.b[j] = ...``) are all found. The greedy form ran
        # to the last ``]`` and swallowed every target after the first.
        indexed_pattern = r'(\b' + self_name + r'\.[A-Za-z_][A-Za-z0-9_.]*)\[.*?\]'
        rights = ', \n'.join(formatter.rights)
        lefts = ', \n'.join(formatter.lefts)
        self_data_in_right = set(re.findall(attr_pattern, rights))
        self_data_without_index_in_left = set(re.findall(attr_pattern, lefts))
        self_data_with_index_in_left = set(re.findall(indexed_pattern, lefts))

    # Closure cells shadow module globals at runtime, so nonlocals must take
    # precedence over globals here as well.
    closure_vars = inspect.getclosurevars(f)
    code_scope = dict(closure_vars.globals)
    code_scope.update(closure_vars.nonlocals)

    return {
        'delay_call': formatter.delay_call,
        'code_string': '\n'.join(formatter.lines),
        'code_scope': code_scope,
        'self_data_in_right': sorted(self_data_in_right),
        'self_data_without_index_in_left': sorted(self_data_without_index_in_left),
        'self_data_with_index_in_left': sorted(self_data_with_index_in_left),
    }
def visit_Call(self, node, level=0):
    """Record delay ``push``/``pull`` calls made on the host so they can be inlined.

    For a call ``<kw>.<...>.push(...)`` or ``<kw>.<...>.pull(...)``, where
    ``<kw>`` is one of ``backend.CLASS_KEYWORDS`` (presumably ``self``) and
    the owner resolved from ``self.host`` is a ``delays.ConstantDelay``, an
    entry is stored in ``self.visited_calls[node]``. The entry holds:

    * ``type``: ``'push'`` or ``'pull'``;
    * ``org_call``: the original call text;
    * ``rep_call``: an equivalent indexing expression on the delay's
      ``delay_data``;
    * ``data_need_pass``: the attribute paths that expression reads.

    Calls already in ``self.visited_calls``, or whose root name is not a
    class keyword, are returned immediately without visiting their children.
    All other calls end with ``self.generic_visit(node)``.

    Parameters
    ----------
    node : ast.Call
        The call node being visited.
    level : int
        Unused.

    Raises
    ------
    ValueError
        If the node has a non-None ``starargs`` or ``kwargs`` field.
    errors.CodeError
        If ``push`` gets other than 1 or 2 arguments, or ``pull`` more than 1.
    KeyError
        If ``push``/``pull`` is called only by keyword with a keyword other
        than ``idx_or_val``/``value``/``idx``. A zero-argument ``push`` also
        raises KeyError, because ``kw_args['idx_or_val']`` is looked up
        before the argument count is checked.
    """
    if getattr(node, 'starargs', None) is not None:
        raise ValueError("Variable number of arguments not supported")
    if getattr(node, 'kwargs', None) is not None:
        raise ValueError("Keyword arguments not supported")
    if node in self.visited_calls:
        return node
    # NOTE(review): visit_attr presumably returns the dotted names of the
    # callee innermost-attribute first; reversed so calls[0] is the root
    # name -- TODO confirm against visit_attr.
    calls = self.visit_attr(node.func)
    calls = calls[::-1]
    # get the object and the function
    if calls[0] not in backend.CLASS_KEYWORDS:
        return node
    obj = self.host
    for data in calls[1:-1]:
        obj = getattr(obj, data)
    obj_func = getattr(obj, calls[-1])
    # get function arguments (as source text)
    args = []
    for arg in node.args:
        args.append(tools.ast2code(ast.fix_missing_locations(arg)))
    kw_args = OrderedDict()
    for keyword in node.keywords:
        kw_args[keyword.arg] = tools.ast2code(
            ast.fix_missing_locations(keyword.value))
    # TASK 1 : extract delay push and delay pull
    # ------
    # Replace the delay function call to the delay_data
    # index. In such a way, delay function will be removed.
    # ------
    if calls[-1] in ['push', 'pull'] and isinstance(
            obj, delays.ConstantDelay) and callable(obj_func):
        # Dotted path of the delay object in the source, e.g. "self.delay".
        dvar4call = '.'.join(calls[0:-1])
        uniform_delay = getattr(obj, 'uniform_delay')
        if calls[-1] == 'push':
            data_need_pass = [
                f'{dvar4call}.delay_data', f'{dvar4call}.delay_in_idx'
            ]
            idx_or_val = kw_args['idx_or_val'] if len(
                args) == 0 else args[0]
            if len(args) + len(kw_args) == 1:
                # push(val): write the whole slot at delay_in_idx.
                rep_expression = f'{dvar4call}.delay_data[{dvar4call}.delay_in_idx] = {idx_or_val}'
            elif len(args) + len(kw_args) == 2:
                # push(idx, value): write only the entries selected by idx.
                value = kw_args['value'] if len(args) <= 1 else args[1]
                if uniform_delay:
                    # A uniform delay uses delay_in_idx directly as the row.
                    rep_expression = f'{dvar4call}.delay_data[{dvar4call}.delay_in_idx, {idx_or_val}] = {value}'
                else:
                    # A non-uniform delay indexes delay_in_idx with idx as well.
                    rep_expression = f'{dvar4call}.delay_data[{dvar4call}.delay_in_idx[{idx_or_val}], {idx_or_val}] = {value}'
            else:
                raise errors.CodeError(
                    f'Cannot analyze the code: \n\n'
                    f'{tools.ast2code(ast.fix_missing_locations(node))}')
        else:
            data_need_pass = [
                f'{dvar4call}.delay_data', f'{dvar4call}.delay_out_idx'
            ]
            if len(args) + len(kw_args) == 0:
                # pull(): read the whole slot at delay_out_idx.
                rep_expression = f'{dvar4call}.delay_data[{dvar4call}.delay_out_idx]'
            elif len(args) + len(kw_args) == 1:
                # pull(idx): read only the entries selected by idx.
                idx = kw_args['idx'] if len(args) == 0 else args[0]
                if uniform_delay:
                    rep_expression = f'{dvar4call}.delay_data[{dvar4call}.delay_out_idx, {idx}]'
                else:
                    rep_expression = f'{dvar4call}.delay_data[{dvar4call}.delay_out_idx[{idx}], {idx}]'
            else:
                raise errors.CodeError(
                    f'Cannot analyze the code: \n\n'
                    f'{tools.ast2code(ast.fix_missing_locations(node))}')
        org_call = tools.ast2code(ast.fix_missing_locations(node))
        self.visited_calls[node] = dict(type=calls[-1], org_call=org_call, rep_call=rep_expression, data_need_pass=data_need_pass)
    self.generic_visit(node)