def __init__(self, *arg_ds, name=None, **kwarg_ds):
    """Build a sequential container from positional and keyword sub-systems.

    Positional systems are registered under their own ``name``; keyword
    systems are registered under the keyword used to pass them. Every entry
    must be a ``DynamicalSystem``.

    Raises
    ------
    errors.BrainPyError
        If any entry is not a ``DynamicalSystem``.
    """
    super(Sequential, self).__init__(name=name)
    self.implicit_nodes = Collector()

    def _ensure_ds(obj):
        # Same rejection message for positional and keyword entries.
        if not isinstance(obj, DynamicalSystem):
            raise errors.BrainPyError(
                f'Only support {DynamicalSystem.__name__}, '
                f'but we got {type(obj)}: {str(obj)}.')
        return obj

    for system in arg_ds:
        self.implicit_nodes[_ensure_ds(system).name] = system
    for alias, system in kwarg_ds.items():
        self.implicit_nodes[alias] = _ensure_ds(system)

    # For each child (in registration order): does its ``update`` declare a
    # parameter named "kwargs"? Only those children receive keyword inputs.
    self._return_kwargs = [
        'kwargs' in inspect.signature(child.update).parameters
        for child in self.implicit_nodes.values()
    ]
def nodes(self, method='absolute', _paths=None):
    """Gather this object and every reachable child node.

    Children are taken from the instance attributes that are ``Base``
    objects and from ``self.implicit_nodes``. Each ``(parent, child)`` edge
    is visited at most once (tracked in ``_paths``), which protects the
    recursion against circular references.

    Parameters
    ----------
    method : str
        ``'absolute'`` keys every node by its ``name``; ``'relative'`` keys
        nodes by their attribute path from this object, with this object
        itself stored under ``''``.
    _paths : set, optional
        Already-visited ``(id(parent), id(child))`` pairs shared across the
        recursion.

    Returns
    -------
    gather : Collector
        Mapping from key (name or path) to node.

    Raises
    ------
    ValueError
        If ``method`` is neither ``'absolute'`` nor ``'relative'``.
    """
    visited = set() if _paths is None else _paths
    gather = Collector()

    def _claim(child):
        # Mark the (self, child) edge as visited; tell whether it was new.
        edge = (id(self), id(child))
        if edge in visited:
            return False
        visited.add(edge)
        return True

    if method == 'absolute':
        children = []
        for attr_value in self.__dict__.values():
            if isinstance(attr_value, Base) and _claim(attr_value):
                gather[attr_value.name] = attr_value
                children.append(attr_value)
        for implicit in self.implicit_nodes.values():
            if _claim(implicit):
                gather[implicit.name] = implicit
                children.append(implicit)
        for child in children:
            gather.update(child.nodes(method=method, _paths=visited))
        gather[self.name] = self
    elif method == 'relative':
        gather[''] = self
        children = []
        for attr_name, attr_value in self.__dict__.items():
            if isinstance(attr_value, Base) and _claim(attr_value):
                gather[attr_name] = attr_value
                children.append((attr_name, attr_value))
        for alias, implicit in self.implicit_nodes.items():
            if _claim(implicit):
                gather[alias] = implicit
                children.append((alias, implicit))
        for prefix, child in children:
            sub_nodes = child.nodes(method=method, _paths=visited)
            for sub_path, sub_node in sub_nodes.items():
                # the child's own '' entry is already stored under ``prefix``
                if sub_path:
                    gather[f'{prefix}.{sub_path}'] = sub_node
    else:
        raise ValueError(f'No support for the method of "{method}".')
    return gather
def __init__(self, name=None):
    """Give the object a unique name and create its implicit containers.

    Parameters
    ----------
    name : str, optional
        Desired name. When omitted, ``unique_name`` generates one.
    """
    unique = self.unique_name(name=name)
    self.name = unique
    # verify the final name against this object
    naming.check_name_uniqueness(name=unique, obj=self)

    # containers for variables / child nodes that are not reachable through
    # plain ``self.xxx`` attributes
    self.implicit_vars = TensorCollector()
    self.implicit_nodes = Collector()
def __init__(self, steps=None, name=None):
    """Initialize the dynamical system and collect its step functions.

    Parameters
    ----------
    steps : tuple, list, dict, optional
        The step functions. A tuple or list may contain method names (looked
        up with ``getattr(self, name)``) or callables (keyed by their
        ``__name__``). A dict maps keys to callables. Defaults to
        ``('update',)``.
    name : str, optional
        The name of this object.

    Raises
    ------
    ModelBuildError
        If ``steps`` has an unsupported type or contains an invalid entry.
    """
    super(DynamicalSystem, self).__init__(name=name)

    # step functions
    if steps is None:
        steps = ('update',)
    self.steps = Collector()
    if isinstance(steps, (tuple, list)):
        for step in steps:
            if isinstance(step, str):
                self.steps[step] = getattr(self, step)
            elif callable(step):
                self.steps[step.__name__] = step
            else:
                # report the offending entry, not whichever entry comes first
                raise ModelBuildError(
                    _error_msg.format(step.__class__, str(step)))
    elif isinstance(steps, dict):
        for key, step in steps.items():
            if callable(step):
                self.steps[key] = step
            else:
                # report the offending value rather than the whole dict
                raise ModelBuildError(
                    _error_msg.format(step.__class__, str(step)))
    else:
        raise ModelBuildError(
            _error_msg.format(steps.__class__, str(steps)))
def __init__(self, *ds_tuple, steps=None, name=None, **ds_dict):
    """Group several dynamical systems into one container.

    Parameters
    ----------
    *ds_tuple : DynamicalSystem
        Children registered under their own ``name``.
    steps : optional
        Step functions forwarded to the parent constructor. Defaults to
        ``('update',)``.
    name : str, optional
        The name of this container.
    **ds_dict : DynamicalSystem
        Children registered under the keyword used to pass them.

    Raises
    ------
    ModelBuildError
        If a child is not a ``DynamicalSystem``.
    ValueError
        If two children would be registered under the same key.
    """
    # integrative step function
    if steps is None:
        steps = ('update',)
    super(Container, self).__init__(steps=steps, name=name)

    # children dynamical systems
    self.implicit_nodes = Collector()
    for ds in ds_tuple:
        if not isinstance(ds, DynamicalSystem):
            raise ModelBuildError(
                f'{self.__class__.__name__} receives instances of '
                f'DynamicalSystem, however, we got {type(ds)}.')
        if ds.name in self.implicit_nodes:
            raise ValueError(
                f'{ds.name} has been paired with {ds}. Please change a unique name.'
            )
        # Register right away: previously all positional children were
        # registered after the loop, so this duplicate check never fired.
        self.register_implicit_nodes({ds.name: ds})
    for key, ds in ds_dict.items():
        if not isinstance(ds, DynamicalSystem):
            raise ModelBuildError(
                f'{self.__class__.__name__} receives instances of '
                f'DynamicalSystem, however, we got {type(ds)}.')
        if key in self.implicit_nodes:
            raise ValueError(
                f'{key} has been paired with {ds}. Please change a unique name.'
            )
    self.register_implicit_nodes(ds_dict)
def __init__(self, loop_values, iter_name, show_code=False, **jit_setting):
    """Set up state for rewriting ``for`` loops that iterate ``iter_name``.

    Parameters
    ----------
    loop_values : iterable
        The objects the loop iterates over.
    iter_name : str
        Source text of the iterated expression to match.
    show_code : bool
        Whether compiled code should be printed.
    jit_setting : dict
        Extra JIT options.
    """
    # configuration
    self.loop_values = loop_values
    self.iter_name = iter_name
    self.show_code = show_code
    self.jit_setting = jit_setting

    # results, filled while visiting the syntax tree
    self.success = False
    self.arguments = set()
    self.arg2call = {}
    self.nodes = Collector()
    self.code_scope = {}
def _jit_func(obj_or_fun, show_code=False, **jit_setting):
    """JIT-compile a function-like object, dispatching on its kind.

    Integrators, bound methods of ``Base`` objects, ``Function`` wrappers
    and callable ``Base`` objects are delegated to their dedicated
    compilers. Any other callable is compiled with ``numba.jit`` unless it
    is already a numba ``Dispatcher`` or ``inspector`` reports a numba type
    for it; such objects are returned unchanged.

    Parameters
    ----------
    obj_or_fun : callable
        The object to compile.
    show_code : bool
        Whether to print the recompiled code.
    jit_setting : dict
        Keyword arguments forwarded to ``numba.jit``.

    Returns
    -------
    result : dict
        Contains at least ``func``, ``arguments``, ``arg2call`` and ``nodes``.

    Raises
    ------
    ValueError
        If ``obj_or_fun`` is not callable.
    """
    if not callable(obj_or_fun):
        raise ValueError(
            f'Only callable objects can be JIT compiled, but we got '
            f'{type(obj_or_fun)}: {obj_or_fun!r}.')

    # integrator
    if isinstance(obj_or_fun, Integrator):
        return _jit_intg(obj_or_fun, show_code=show_code, **jit_setting)

    # bound method of a Base object
    if hasattr(obj_or_fun, '__self__') and isinstance(obj_or_fun.__self__, Base):
        return _jit_cls_func(obj_or_fun,
                             host=obj_or_fun.__self__,
                             show_code=show_code,
                             **jit_setting)

    # wrapped function
    if isinstance(obj_or_fun, Function):
        return _jit_Function(obj_or_fun, show_code=show_code, **jit_setting)

    # callable Base object: compile its ``__call__``
    if isinstance(obj_or_fun, Base):
        return _jit_cls_func(obj_or_fun.__call__,
                             host=obj_or_fun,
                             show_code=show_code,
                             **jit_setting)

    # native function: compile it unless numba already supports it
    func = obj_or_fun
    if (not isinstance(obj_or_fun, Dispatcher)
            and inspector.inspect_function(obj_or_fun)['numba_type'] is None):
        func = numba.jit(obj_or_fun, **jit_setting)
    return dict(func=func, arguments=set(), arg2call=Collector(), nodes=Collector())
def integral_func(*args, **kwargs):
    """Evaluate every sub-integral with the given positional/keyword inputs.

    Positional arguments are mapped, in order, onto the names in
    ``all_vps``; keyword arguments are added on top. ``dt`` defaults to
    ``math.get_dt()``. Each integral in ``integrals`` receives the value
    named by the first entry of its ``arg_names`` row positionally, plus
    every other named value that is available as a keyword.

    Returns
    -------
    list or object
        All results when ``self.f`` is a ``JointEq``, otherwise the single
        result.
    """
    params_in = Collector()
    for position, value in enumerate(args):
        params_in[all_vps[position]] = value
    params_in.update(kwargs)
    if 'dt' not in params_in:
        params_in['dt'] = math.get_dt()

    outputs = []
    for idx, integral in enumerate(integrals):
        names = arg_names[idx]
        state = params_in[names[0]]
        extra = {n: params_in[n] for n in names[1:] if n in params_in}
        outputs.append(integral(state, **extra))

    if isinstance(self.f, joint_eq.JointEq):
        return outputs
    return outputs[0]
def ints(self, method='absolute'):
    """Collect the integrators held by this node and its children.

    Parameters
    ----------
    method : str
        How nodes are keyed, forwarded to ``self.nodes``.

    Returns
    -------
    collector : Collector
        Mapping from ``'<node path>.<attribute>'`` (or just the attribute
        when the node path is empty) to the integrator.
    """
    # imported lazily and cached in the module global
    global Integrator
    if Integrator is None:
        from brainpy.integrators.base import Integrator

    found = Collector()
    for path, node in self.nodes(method=method).items():
        for attr in dir(node):
            candidate = getattr(node, attr)
            if not isinstance(candidate, Integrator):
                continue
            found[f'{path}.{attr}' if path else attr] = candidate
    return found
def _jit_cls_func(f, code=None, host=None, show_code=False, **jit_setting):
    """JIT a class function.

    The source code of ``f`` is analyzed so that its ``self.xxx`` accesses
    are replaced by plain names. As the examples below show, static data of
    the host become entries of the namespace (e.g. ``HH0_gNa``), while
    variables become extra keyword arguments of the recompiled function
    (e.g. ``HH0_input=None``), whose values are described by ``arg2call``.
    The rewritten function is then compiled with ``numba.jit``.

    Examples
    --------

    Example 1: the model has static parameters.

    >>> import brainpy as bp
    >>>
    >>> class HH(bp.NeuGroup):
    >>>   def __init__(self, size, ENa=50., EK=-77., EL=-54.387, C=1.0,
    >>>                gNa=120., gK=36., gL=0.03, V_th=20., **kwargs):
    >>>     super(HH, self).__init__(size=size, **kwargs)
    >>>     # parameters
    >>>     self.ENa = ENa
    >>>     self.EK = EK
    >>>     self.EL = EL
    >>>     self.C = C
    >>>     self.gNa = gNa
    >>>     self.gK = gK
    >>>     self.gL = gL
    >>>     self.V_th = V_th
    >>>
    >>>   def derivative(self, V, m, h, n, t, Iext):
    >>>     alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))
    >>>     beta = 4.0 * np.exp(-(V + 65) / 18)
    >>>     dmdt = alpha * (1 - m) - beta * m
    >>>
    >>>     alpha = 0.07 * np.exp(-(V + 65) / 20.)
    >>>     beta = 1 / (1 + np.exp(-(V + 35) / 10))
    >>>     dhdt = alpha * (1 - h) - beta * h
    >>>
    >>>     alpha = 0.01 * (V + 55) / (1 - np.exp(-(V + 55) / 10))
    >>>     beta = 0.125 * np.exp(-(V + 65) / 80)
    >>>     dndt = alpha * (1 - n) - beta * n
    >>>
    >>>     I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
    >>>     I_K = (self.gK * n ** 4.0) * (V - self.EK)
    >>>     I_leak = self.gL * (V - self.EL)
    >>>     dVdt = (- I_Na - I_K - I_leak + Iext) / self.C
    >>>
    >>>     return dVdt, dmdt, dhdt, dndt
    >>>
    >>> r = _jit_cls_func(HH(10).derivative, show_code=True)

    The recompiled function:
    -------------------------
    def derivative(V, m, h, n, t, Iext):
        alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))
        beta = 4.0 * np.exp(-(V + 65) / 18)
        dmdt = alpha * (1 - m) - beta * m
        alpha = 0.07 * np.exp(-(V + 65) / 20.0)
        beta = 1 / (1 + np.exp(-(V + 35) / 10))
        dhdt = alpha * (1 - h) - beta * h
        alpha = 0.01 * (V + 55) / (1 - np.exp(-(V + 55) / 10))
        beta = 0.125 * np.exp(-(V + 65) / 80)
        dndt = alpha * (1 - n) - beta * n
        I_Na = HH0_gNa * m ** 3.0 * h * (V - HH0_ENa)
        I_K = HH0_gK * n ** 4.0 * (V - HH0_EK)
        I_leak = HH0_gL * (V - HH0_EL)
        dVdt = (-I_Na - I_K - I_leak + Iext) / HH0_C
        return dVdt, dmdt, dhdt, dndt

    The namespace of the above function:
    {'HH0_C': 1.0,
     'HH0_EK': -77.0,
     'HH0_EL': -54.387,
     'HH0_ENa': 50.0,
     'HH0_gK': 36.0,
     'HH0_gL': 0.03,
     'HH0_gNa': 120.0,
     'bp': <module 'brainpy' from '...'>}

    >>> r['func']
    CPUDispatcher(<function derivative at 0x...>)
    >>> r['arguments']
    set()
    >>> r['arg2call']
    {}
    >>> r['nodes']
    {'HH0': <__main__.<locals>.HH object at 0x...>}

    Example 2: the model has dynamical variables.

    >>> import brainpy as bp
    >>>
    >>> class HH(bp.NeuGroup):
    >>>   def __init__(self, size, ENa=50., EK=-77., EL=-54.387, C=1.0,
    >>>                gNa=120., gK=36., gL=0.03, V_th=20., **kwargs):
    >>>     super(HH, self).__init__(size=size, **kwargs)
    >>>     # parameters
    >>>     self.ENa = ENa
    >>>     self.EK = EK
    >>>     self.EL = EL
    >>>     self.C = C
    >>>     self.gNa = gNa
    >>>     self.gK = gK
    >>>     self.gL = gL
    >>>     self.V_th = V_th
    >>>     self.input = bp.math.numpy.Variable(np.zeros(size))
    >>>
    >>>   def derivative(self, V, m, h, n, t):
    >>>     alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))
    >>>     beta = 4.0 * np.exp(-(V + 65) / 18)
    >>>     dmdt = alpha * (1 - m) - beta * m
    >>>
    >>>     alpha = 0.07 * np.exp(-(V + 65) / 20.)
    >>>     beta = 1 / (1 + np.exp(-(V + 35) / 10))
    >>>     dhdt = alpha * (1 - h) - beta * h
    >>>
    >>>     alpha = 0.01 * (V + 55) / (1 - np.exp(-(V + 55) / 10))
    >>>     beta = 0.125 * np.exp(-(V + 65) / 80)
    >>>     dndt = alpha * (1 - n) - beta * n
    >>>
    >>>     I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
    >>>     I_K = (self.gK * n ** 4.0) * (V - self.EK)
    >>>     I_leak = self.gL * (V - self.EL)
    >>>     dVdt = (- I_Na - I_K - I_leak + self.input) / self.C
    >>>
    >>>     return dVdt, dmdt, dhdt, dndt
    >>>
    >>> r = _jit_cls_func(HH(10).derivative, show_code=True)

    The recompiled function:
    -------------------------
    def derivative(V, m, h, n, t, HH0_input=None):
        alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))
        beta = 4.0 * np.exp(-(V + 65) / 18)
        dmdt = alpha * (1 - m) - beta * m
        alpha = 0.07 * np.exp(-(V + 65) / 20.0)
        beta = 1 / (1 + np.exp(-(V + 35) / 10))
        dhdt = alpha * (1 - h) - beta * h
        alpha = 0.01 * (V + 55) / (1 - np.exp(-(V + 55) / 10))
        beta = 0.125 * np.exp(-(V + 65) / 80)
        dndt = alpha * (1 - n) - beta * n
        I_Na = HH0_gNa * m ** 3.0 * h * (V - HH0_ENa)
        I_K = HH0_gK * n ** 4.0 * (V - HH0_EK)
        I_leak = HH0_gL * (V - HH0_EL)
        dVdt = (-I_Na - I_K - I_leak + HH0_input) / HH0_C
        return dVdt, dmdt, dhdt, dndt

    The namespace of the above function:
    {'HH0_C': 1.0,
     'HH0_EK': -77.0,
     'HH0_EL': -54.387,
     'HH0_ENa': 50.0,
     'HH0_gK': 36.0,
     'HH0_gL': 0.03,
     'HH0_gNa': 120.0,
     'bp': <module 'brainpy' from '...'>}

    >>> r['func']
    CPUDispatcher(<function derivative at 0x...>)
    >>> r['arguments']
    {'HH0_input'}
    >>> r['arg2call']
    {'HH0_input': 'HH0.input.value'}
    >>> r['nodes']
    {'HH0': <__main__.<locals>.HH object at 0x...>}

    Parameters
    ----------
    f : callable
        The class function (usually a bound method) to compile.
    code : str, optional
        Source code of ``f``. When empty or omitted, it is retrieved with
        ``inspect.getsource(f)`` and de-indented.
    host : Base, optional
        The object owning ``f``. Defaults to ``f.__self__``.
    show_code : bool
        Whether to print the recompiled code and its namespace.
    jit_setting : dict
        Keyword arguments forwarded to ``numba.jit``.

    Returns
    -------
    result : dict
        ``func`` (the numba-compiled function), ``code`` (the rewritten
        source), ``code_scope`` (its namespace), ``arguments`` (names of the
        extra arguments), ``arg2call`` (how each extra argument is obtained
        from the nodes) and ``nodes`` (the involved nodes, keyed by name).
    """
    # ``host or f.__self__`` would silently discard a host whose truth value
    # is False (e.g. an object defining ``__len__``), so test for None.
    if host is None:
        host = f.__self__

    # data to return
    arguments = set()
    arg2call = dict()
    nodes = Collector()
    nodes[host.name] = host

    # code
    if not code:
        code = tools.deindent(inspect.getsource(f)).strip()
    func_name = f.__name__

    # code scope: the names the original function refers to
    closure_vars = inspect.getclosurevars(f)
    code_scope = dict(closure_vars.nonlocals)
    code_scope.update(closure_vars.globals)

    # analyze class function
    code, _arguments, _arg2call, _nodes, _code_scope = _analyze_cls_func(
        host=host, code=code, show_code=show_code, **jit_setting)
    arguments.update(_arguments)
    arg2call.update(_arg2call)
    nodes.update(_nodes)
    code_scope.update(_code_scope)

    # compile the new function in a copy of the namespace, so that the
    # returned ``code_scope`` is not modified by ``exec``
    code_scope_to_compile = code_scope.copy()
    if show_code:
        _show_compiled_codes(code, code_scope)
    exec(compile(code, '', 'exec'), code_scope_to_compile)
    func = numba.jit(code_scope_to_compile[func_name], **jit_setting)

    return dict(func=func,
                code=code,
                code_scope=code_scope,
                arguments=arguments,
                arg2call=arg2call,
                nodes=nodes)
def __init__(self,
             target,
             monitors=None,
             inits=None,
             args=None,
             dyn_args=None,
             dyn_vars=None,
             jit=True,
             dt=None,
             numpy_mon_after_run=True,
             progress_bar=True):
    """Build a runner that steps an integrator and records monitored values.

    Parameters
    ----------
    target : Integrator
        The integrator to run.
    monitors : list, tuple, dict, Monitor, optional
        Names of ``target.variables`` to record.
    inits : list, tuple, dict, optional
        Initial values of ``target.variables``. A list/tuple must follow the
        order of ``target.variables``. Unspecified variables start at zero.
    args : dict, optional
        Static arguments of the integral function.
    dyn_args : dict, optional
        Time-varying arguments. All values must have the same length.
    dyn_vars : list, tuple, dict, optional
        Variables changed during the run. Defaults to
        ``target.vars().unique()``.
    jit : bool
        Build the step loop with ``math.make_loop`` (True) or a Python loop.
    dt : int, float, optional
        Time step. Defaults to ``math.get_dt()``.
    numpy_mon_after_run : bool
        Stored on the runner as ``self.numpy_mon_after_run``.
    progress_bar : bool
        Stored on the runner as ``self.progress_bar``.

    Raises
    ------
    RunningError
        For invalid ``dt``, ``target``, ``args``, ``dyn_args``, ``dyn_vars``
        or ``inits``.
    MonitorError
        For invalid ``monitors`` or monitored names unknown to the target.
    """
    super(IntegratorRunner, self).__init__()

    # parameters
    dt = math.get_dt() if dt is None else dt
    if not isinstance(dt, (int, float)):
        raise RunningError(f'"dt" must be scalar, but got {dt}')
    self.dt = dt
    self.jit = jit
    self.numpy_mon_after_run = numpy_mon_after_run
    self._pbar = None  # progress bar
    self.progress_bar = progress_bar

    # target
    if not isinstance(target, Integrator):
        raise RunningError(
            f'"target" must be an instance of {Integrator.__name__}, '
            f'but we got {type(target)}: {target}')
    self.target = target

    # arguments of the integral function (validated with ``raise`` rather
    # than ``assert``, which is stripped under ``python -O``)
    self._static_args = Collector()
    if args is not None:
        if not isinstance(args, dict):
            raise RunningError(
                f'"args" must be a dict, but we get {type(args)}: {args}')
        self._static_args.update(args)
    self._dyn_args = Collector()
    if dyn_args is not None:
        if not isinstance(dyn_args, dict):
            raise RunningError(
                f'"dyn_args" must be a dict, but we get {type(dyn_args)}: {dyn_args}')
        # an empty dict has nothing to compare (np.unique([]) has size 0)
        if len(dyn_args) > 0:
            sizes = np.unique([len(v) for v in dyn_args.values()])
            num_size = len(sizes)
            if num_size != 1:
                raise RunningError(
                    f'All values in "dyn_args" should have the same length. But we got '
                    f'{num_size}: {sizes}')
        self._dyn_args.update(dyn_args)

    # dynamical changed variables
    if dyn_vars is None:
        dyn_vars = self.target.vars().unique()
    if isinstance(dyn_vars, (list, tuple)):
        dyn_vars = {f'_v{i}': v for i, v in enumerate(dyn_vars)}
    if not isinstance(dyn_vars, dict):
        raise RunningError(
            f'"dyn_vars" must be a dict, but we got {type(dyn_vars)}')
    self.dyn_vars = TensorCollector(dyn_vars)

    # monitors
    if monitors is None:
        self.mon = Monitor(target=self, variables=[])
    elif isinstance(monitors, (list, tuple, dict)):
        self.mon = Monitor(target=self, variables=monitors)
    elif isinstance(monitors, Monitor):
        self.mon = monitors
        self.mon.target = self
    else:
        raise MonitorError(f'"monitors" only supports list/tuple/dict/ '
                           f'instance of Monitor, not {type(monitors)}.')
    self.mon.build()  # build the monitor
    for k in self.mon.item_names:
        if k not in self.target.variables:
            raise MonitorError(
                f'Variable "{k}" to monitor is not defined in the integrator {self.target}.'
            )

    # start simulation time
    self._start_t = None

    # initial values of the integrator variables
    if inits is None:
        inits = dict()
    elif isinstance(inits, (list, tuple)):
        if len(self.target.variables) != len(inits):
            raise RunningError(
                f'"inits" has {len(inits)} values, but the integrator has '
                f'{len(self.target.variables)} variables: {self.target.variables}')
        inits = {k: inits[i] for i, k in enumerate(self.target.variables)}
    if not isinstance(inits, dict):
        raise RunningError(
            f'"inits" must be a list/tuple/dict, but we got {type(inits)}')
    for k in inits:
        if k not in self.target.variables:
            raise RunningError(
                f'Variable "{k}" in "inits" is not defined in the integrator {self.target}.')
    # an empty ``inits`` falls back to size 1 instead of crashing in np.max
    init_sizes = [np.size(v) for v in inits.values()]
    max_size = max(init_sizes) if init_sizes else 1
    self.variables = TensorCollector({
        v: math.Variable(math.zeros(max_size)) for v in self.target.variables
    })
    for k, value in inits.items():
        self.variables[k][:] = value
    self.dyn_vars.update(self.variables)
    if len(self._dyn_args) > 0:
        self.idx = math.Variable(math.zeros(1, dtype=math.int_))
        self.dyn_vars['_idx'] = self.idx

    # build the update step
    if jit:
        _loop_func = math.make_loop(
            self._step,
            dyn_vars=self.dyn_vars,
            out_vars={k: self.variables[k] for k in self.mon.item_names})
    else:
        def _loop_func(t_and_dt):
            # Python fallback: step through every (t, dt) pair and record
            # the monitored variables after each step.
            out_vars = {k: [] for k in self.mon.item_names}
            times, dts = t_and_dt
            for _t, _dt in zip(times, dts):
                self._step([_t, _dt])
                for k in self.mon.item_names:
                    out_vars[k].append(math.as_device_array(self.variables[k]))
            return {k: math.asarray(out_vars[k]) for k in self.mon.item_names}
    self.step_func = _loop_func
class Base(object):
    """The Base class for whole BrainPy ecosystem.

    The subclass of Base includes:

    - ``DynamicalSystem`` in *brainpy.simulation.brainobjects.base.py*
    - ``Integrator`` in *brainpy.integrators.base.py*
    - ``Function`` in *brainpy.base.function.py*
    - ``AutoGrad`` in *brainpy.math.jax.autograd.py*
    - ``Optimizer`` in *brainpy.math.jax.optimizers.py*
    - ``Scheduler`` in *brainpy.math.jax.optimizers.py*
    """

    def __init__(self, name=None):
        # check whether the object has a unique name.
        self.name = self.unique_name(name=name)
        naming.check_name_uniqueness(name=self.name, obj=self)

        # Used to wrap the implicit variables
        # which cannot be accessed by self.xxx
        self.implicit_vars = TensorCollector()

        # Used to wrap the implicit children nodes
        # which cannot be accessed by self.xxx
        self.implicit_nodes = Collector()

    def register_implicit_vars(self, variables):
        """Register variables that are not reachable as ``self.xxx`` attributes."""
        assert isinstance(variables, dict)
        self.implicit_vars.update(variables)

    def register_implicit_nodes(self, nodes):
        """Register child nodes that are not reachable as ``self.xxx`` attributes."""
        assert isinstance(nodes, dict)
        self.implicit_nodes.update(nodes)

    def vars(self, method='absolute'):
        """Collect all variables in this node and the children nodes.

        Parameters
        ----------
        method : str
            The method to access the variables.

        Returns
        -------
        gather : TensorCollector
            The collection contained (the path, the variable).
        """
        global math
        if math is None:
            from brainpy import math

        nodes = self.nodes(method=method)
        gather = TensorCollector()
        for node_path, node in nodes.items():
            # The root node has an empty path in 'relative' mode; its
            # variables must not get a leading '.' (previously implicit
            # variables were keyed as '.name', unlike explicit ones).
            prefix = f'{node_path}.' if node_path else ''
            for k in dir(node):
                v = getattr(node, k)
                if isinstance(v, math.Variable):
                    gather[f'{prefix}{k}'] = v
            gather.update({f'{prefix}{k}': v for k, v in node.implicit_vars.items()})
        return gather

    def train_vars(self, method='absolute'):
        """The shortcut for retrieving all trainable variables.

        Parameters
        ----------
        method : str
            The method to access the variables. Support 'absolute' and 'relative'.

        Returns
        -------
        gather : TensorCollector
            The collection contained (the path, the trainable variable).
        """
        global math
        if math is None:
            from brainpy import math
        return self.vars(method=method).subset(math.TrainVar)

    def nodes(self, method='absolute', _paths=None):
        """Collect all children nodes.

        Parameters
        ----------
        method : str
            The method to access the nodes: 'absolute' keys nodes by name,
            'relative' keys them by attribute path ('' is this node).
        _paths : set, Optional
            The data structure to solve the circular reference.

        Returns
        -------
        gather : Collector
            The collection contained (the path, the node).

        Raises
        ------
        ValueError
            If ``method`` is not supported.
        """
        if _paths is None:
            _paths = set()
        gather = Collector()
        if method == 'absolute':
            nodes = []
            for v in self.__dict__.values():
                if isinstance(v, Base):
                    path = (id(self), id(v))
                    if path not in _paths:
                        _paths.add(path)
                        gather[v.name] = v
                        nodes.append(v)
            for node in self.implicit_nodes.values():
                path = (id(self), id(node))
                if path not in _paths:
                    _paths.add(path)
                    gather[node.name] = node
                    nodes.append(node)
            for v in nodes:
                gather.update(v.nodes(method=method, _paths=_paths))
            gather[self.name] = self

        elif method == 'relative':
            nodes = []
            gather[''] = self
            for k, v in self.__dict__.items():
                if isinstance(v, Base):
                    path = (id(self), id(v))
                    if path not in _paths:
                        _paths.add(path)
                        gather[k] = v
                        nodes.append((k, v))
            for key, node in self.implicit_nodes.items():
                path = (id(self), id(node))
                if path not in _paths:
                    _paths.add(path)
                    gather[key] = node
                    nodes.append((key, node))
            for k1, v1 in nodes:
                for k2, v2 in v1.nodes(method=method, _paths=_paths).items():
                    # the child's own '' entry is already stored under k1
                    if k2:
                        gather[f'{k1}.{k2}'] = v2

        else:
            raise ValueError(f'No support for the method of "{method}".')
        return gather

    def ints(self, method='absolute'):
        """Collect all integrators in this node and the children nodes.

        Parameters
        ----------
        method : str
            The method to access the integrators.

        Returns
        -------
        collector : Collector
            The collection contained (the path, the integrator).
        """
        global Integrator
        if Integrator is None:
            from brainpy.integrators.base import Integrator

        nodes = self.nodes(method=method)
        gather = Collector()
        for node_path, node in nodes.items():
            for k in dir(node):
                v = getattr(node, k)
                if isinstance(v, Integrator):
                    gather[f'{node_path}.{k}' if node_path else k] = v
        return gather

    def unique_name(self, name=None, type_=None):
        """Get the unique name for this object.

        Parameters
        ----------
        name : str, optional
            The expected name. If None, the default unique name will be returned.
            Otherwise, the provided name will be checked to guarantee its uniqueness.
        type_ : str, optional
            The name of this class, used for object naming.

        Returns
        -------
        name : str
            The unique name for this object.
        """
        if name is None:
            if type_ is None:
                return naming.get_unique_name(type_=self.__class__.__name__)
            else:
                return naming.get_unique_name(type_=type_)
        else:
            naming.check_name_uniqueness(name=name, obj=self)
            return name

    def load_states(self, filename, verbose=False, check_missing=False):
        """Load the model states.

        Parameters
        ----------
        filename : str
            The filename which stores the model states. Its extension
            selects the loader: ``.hdf5``/``.h5``, ``.pkl``, ``.npz`` or ``.mat``.
        verbose: bool
            Forwarded to the loader as ``verbose``.
        check_missing: bool
            Forwarded to the loader as ``check``.

        Raises
        ------
        errors.BrainPyError
            If the file does not exist or its format is unknown.
        """
        if not os.path.exists(filename):
            raise errors.BrainPyError(f'Cannot find the file path: {filename}')
        elif filename.endswith(('.hdf5', '.h5')):
            io.load_h5(filename, target=self, verbose=verbose, check=check_missing)
        elif filename.endswith('.pkl'):
            io.load_pkl(filename, target=self, verbose=verbose, check=check_missing)
        elif filename.endswith('.npz'):
            io.load_npz(filename, target=self, verbose=verbose, check=check_missing)
        elif filename.endswith('.mat'):
            io.load_mat(filename, target=self, verbose=verbose, check=check_missing)
        else:
            raise errors.BrainPyError(
                f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}'
            )

    def save_states(self, filename, all_vars=None, **setting):
        """Save the model states.

        Parameters
        ----------
        filename : str
            The file name which to store the model states. Its extension
            selects the writer: ``.hdf5``/``.h5``, ``.pkl``, ``.npz`` or ``.mat``.
        all_vars: optional, dict, TensorCollector
            The variables to save. Defaults to the unique relative variables
            of this node.
        setting : dict
            Extra options, only forwarded to the ``.npz`` writer.

        Raises
        ------
        errors.BrainPyError
            If the file format is unknown.
        """
        if all_vars is None:
            all_vars = self.vars(method='relative').unique()

        if filename.endswith(('.hdf5', '.h5')):
            io.save_h5(filename, all_vars=all_vars)
        elif filename.endswith('.pkl'):
            io.save_pkl(filename, all_vars=all_vars)
        elif filename.endswith('.npz'):
            io.save_npz(filename, all_vars=all_vars, **setting)
        elif filename.endswith('.mat'):
            io.save_mat(filename, all_vars=all_vars)
        else:
            raise errors.BrainPyError(
                f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}'
            )
def visit_For(self, node):
    """Unroll a ``for`` loop whose iterable is ``self.iter_name``.

    Each entry of ``self.loop_values`` gets its own copy of the loop body.
    ``Base`` objects are inlined through ``_analyze_cls_func_body``; other
    callables are compiled with ``_jit_func`` and called under a unique
    name stored in ``self.code_scope``. Arguments, ``arg2call``, nodes and
    namespace entries produced along the way are accumulated on ``self``.
    Loops over other iterables are left untouched.

    Raises
    ------
    errors.BrainPyError
        If the loop target is not a single name, or a loop value is neither
        a ``Base`` object nor callable.
    """
    iter_ = tools.ast2code(ast.fix_missing_locations(node.iter))
    if iter_.strip() == self.iter_name:
        data_to_replace = Collector()
        final_node = ast.Module(body=[])
        self.success = True

        # target
        if not isinstance(node.target, ast.Name):
            raise errors.BrainPyError(
                f'Only support scalar iter, like "for x in xxxx:", not "for '
                f'{tools.ast2code(ast.fix_missing_locations(node.target))} '
                f'in {iter_}:".')
        target = node.target.id

        # for loop values
        for i, value in enumerate(self.loop_values):
            # module and code
            module = ast.Module(body=deepcopy(node).body)
            code = tools.ast2code(module)

            if isinstance(value, Base):
                # transform Base objects
                new_code, arguments, arg2call, nodes, code_scope = _analyze_cls_func_body(
                    host=value,
                    self_name=target,
                    code=code,
                    tree=module,
                    show_code=self.show_code,
                    **self.jit_setting)
                self.arguments.update(arguments)
                self.arg2call.update(arg2call)
                self.nodes.update(nodes)
                self.code_scope.update(code_scope)
                final_node.body.extend(ast.parse(new_code).body)

            elif callable(value):
                # transform functions
                r = _jit_func(obj_or_fun=value,
                              show_code=self.show_code,
                              **self.jit_setting)
                tree = _replace_func_call_by_tree(deepcopy(module),
                                                  func_call=target,
                                                  arg_to_append=r['arguments'],
                                                  new_func_name=f'{target}_{i}')

                # update import parameters
                self.arguments.update(r['arguments'])
                self.arg2call.update(r['arg2call'])
                self.nodes.update(r['nodes'])

                # Replace the data. Base objects were handled above, so only
                # bound methods of Base objects get a host-qualified name.
                # The loop index keeps two methods of the same host apart
                # (previously they shared one code_scope entry).
                if hasattr(value, '__self__') and isinstance(value.__self__, Base):
                    replace_name = f'{value.__self__.name}_{target}_{i}'
                else:
                    replace_name = f'{target}_{i}'
                self.code_scope[replace_name] = r['func']
                data_to_replace[f'{target}_{i}'] = replace_name

                final_node.body.extend(tree.body)

            else:
                raise errors.BrainPyError(
                    f'Only support JIT an iterable objects of function '
                    f'or Base object, but we got:\n\n {value}')

        # replace words
        final_code = tools.ast2code(final_node)
        final_code = tools.word_replace(final_code, data_to_replace, exclude_dot=True)
        final_node = ast.parse(final_code)

    else:
        final_node = node

    self.generic_visit(final_node)
    return final_node
class Sequential(Module):
    """Basic sequential object to control data flow.

    The output of each child's ``update`` is fed (through ``_check_args``)
    into the next child's ``update``.

    Parameters
    ----------
    arg_ds
        The modules without name specifications.
    name : str, optional
        The name of the sequential module.
    kwarg_ds
        The modules with name specifications.
    """

    def __init__(self, *arg_ds, name=None, **kwarg_ds):
        super(Sequential, self).__init__(name=name)

        self.implicit_nodes = Collector()
        # check "args"
        for ds in arg_ds:
            if not isinstance(ds, DynamicalSystem):
                raise errors.BrainPyError(
                    f'Only support {DynamicalSystem.__name__}, '
                    f'but we got {type(ds)}: {str(ds)}.')
            self.implicit_nodes[ds.name] = ds

        # check "kwargs"
        for key, ds in kwarg_ds.items():
            if not isinstance(ds, DynamicalSystem):
                raise errors.BrainPyError(
                    f'Only support {DynamicalSystem.__name__}, '
                    f'but we got {type(ds)}: {str(ds)}.')
            self.implicit_nodes[key] = ds

        # whether each child's ``update`` declares a parameter named "kwargs"
        self._return_kwargs = [
            'kwargs' in inspect.signature(ds.update).parameters.keys()
            for ds in self.implicit_nodes.values()
        ]

    def _check_kwargs(self, i, kwargs):
        # Only forward keyword inputs to children that can accept them.
        return kwargs if self._return_kwargs[i] else dict()

    def update(self, *args, **kwargs):
        """Functional call.

        Parameters
        ----------
        args : list, tuple
            The positional inputs of the first child.
        kwargs : dict
            Keyword inputs shared across modules. They are passed to every
            child whose ``update`` declares a parameter named ``kwargs``,
            and omitted for the other children.

        Returns
        -------
        The output of the last child's ``update``.

        Raises
        ------
        errors.BrainPyError
            If this sequential module has no children.
        """
        ds = list(self.implicit_nodes.values())
        if not ds:
            raise errors.BrainPyError(
                f'{self.__class__.__name__} "{self.name}" has no child '
                f'modules to update.')
        # first layer
        args = ds[0].update(*args, **self._check_kwargs(0, kwargs))
        # other layers
        for i in range(1, len(ds)):
            args = ds[i].update(*_check_args(args=args), **self._check_kwargs(i, kwargs))
        return args

    def __getitem__(self, key: int):
        return list(self.implicit_nodes.values())[key]
def _analyze_cls_func_body(host,
                           self_name,
                           code,
                           tree,
                           show_code=False,
                           has_func_def=False,
                           **jit_setting):
    """Rewrite the ``<self_name>.xxx`` accesses of a class-function body.

    Every ``<self_name>.attr[.attr...]`` expression in ``code`` is resolved
    against ``host`` (descending through nested ``Base`` nodes) and replaced
    by a plain name ``<node name>_<attr>``:

    - ``Variable`` data become extra arguments (recorded in ``arguments``
      and ``arg2call``);
    - ``RandomState`` objects are replaced by ``np.random``;
    - callables are compiled with ``_jit_func``;
    - lists/tuples and ``dict.values()`` of callables or ``Base`` objects
      have their ``for`` loops unrolled via ``_replace_this_forloop``;
    - anything else is put into ``code_scope`` as a constant.

    Parameters
    ----------
    host : Base
        The object bound to ``self_name``.
    self_name : str
        The name used for the host in ``code`` (usually ``'self'``).
    code : str
        The source code to analyze.
    tree : ast.AST
        The parsed syntax tree of ``code``.
    show_code : bool
        Whether compiled code is printed.
    has_func_def : bool
        When True, ``tree.body[0]`` is treated as the function definition:
        its decorators and ``**kwargs`` are dropped and the collected
        arguments are appended with ``None`` defaults.
    jit_setting : dict
        Extra JIT options.

    Returns
    -------
    tuple
        ``(code, arguments, arg2call, nodes, code_scope)``.

    Raises
    ------
    errors.BrainPyError
        If an access cannot be translated.
    """
    arguments, arg2call, nodes, code_scope = set(), dict(), Collector(), dict()

    # all distinct ``<self_name>.xxx`` expressions; sorted so that the
    # rewrite is deterministic from run to run
    self_data = re.findall('\\b' + self_name + '\\.[A-Za-z_][A-Za-z0-9_.]*\\b', code)
    self_data = sorted(set(self_data))

    # analyze variables and functions accessed by the self.xx
    data_to_replace = {}
    for key in self_data:
        split_keys = key.split('.')
        if len(split_keys) < 2:
            raise errors.BrainPyError(f'Cannot analyze the attribute access "{key}".')

        # get target and data: descend through nested Base nodes and stop at
        # the first attribute that is an Integrator or not a Base object
        target = host
        for i in range(1, len(split_keys)):
            next_target = getattr(target, split_keys[i])
            if isinstance(next_target, Integrator):
                break
            if not isinstance(next_target, Base):
                break
            target = next_target
        else:
            raise errors.BrainPyError(
                f'"{key}" refers to a {Base.__name__} object itself; only its '
                f'data or functions can be used in JIT code.')
        data = getattr(target, split_keys[i])

        # ``self.a.b.c`` with ``a`` resolved on ``target`` -> ``<target>_a.b.c``
        new_name = f'{target.name}_{split_keys[i]}'
        rest = split_keys[i + 1:]
        replacement = f'{new_name}.{".".join(rest)}' if rest else new_name

        # analyze data
        if isinstance(data, math.numpy.Variable):
            # data is a variable: pass its value in as an extra argument.
            # ``split_keys[i]`` (not ``split_keys[-1]``) names the variable,
            # e.g. ``self.V.shape`` must be fetched from ``<target>.V.value``.
            arguments.add(new_name)
            arg2call[new_name] = f'{target.name}.{split_keys[i]}.value'
            nodes[target.name] = target
            data_to_replace[key] = replacement

        elif isinstance(data, np.random.RandomState):
            # data is a RandomState: use the global ``np.random`` instead
            code_scope[new_name] = np.random
            data_to_replace[key] = replacement

        elif callable(data):
            # data is a function
            if rest:
                raise errors.BrainPyError(
                    f'Cannot JIT "{key}": accessing attributes of the '
                    f'function "{split_keys[i]}" is not supported.')
            r = _jit_func(obj_or_fun=data, show_code=show_code, **jit_setting)
            tree = _replace_func_call_by_tree(tree,
                                              func_call=key,
                                              arg_to_append=r['arguments'])
            arguments.update(r['arguments'])
            arg2call.update(r['arg2call'])
            nodes.update(r['nodes'])
            code_scope[new_name] = r['func']
            data_to_replace[key] = new_name

        elif isinstance(data, (dict, list, tuple)):
            # data is a list/tuple/dict of function/object
            if isinstance(data, dict):
                # only ``<self>.d.values()`` is supported; both conditions
                # must hold (``and`` used to let ``.keys()``/``.items()`` pass)
                if len(split_keys) != i + 2 or split_keys[-1] != 'values':
                    raise errors.BrainPyError(
                        f'Only support iter dict.values(). while we got '
                        f'dict.{split_keys[-1]} for data: \n\n{data}')
                values = list(data.values())
                iter_name = key + '()'
            else:
                if rest:
                    raise errors.BrainPyError(
                        f'Cannot JIT "{key}": only direct iteration over the '
                        f'{type(data).__name__} "{split_keys[i]}" is supported.')
                values = list(data)
                iter_name = key

            if len(values) > 0:
                if not (callable(values[0]) or isinstance(values[0], Base)):
                    # a container of plain data: treat it as a constant
                    code_scope[new_name] = data
                    data_to_replace[key] = replacement
                    continue

            # replace this for-loop
            r = _replace_this_forloop(tree=tree,
                                      iter_name=iter_name,
                                      loop_values=values,
                                      show_code=show_code,
                                      **jit_setting)
            tree, _arguments, _arg2call, _nodes, _code_scope = r
            arguments.update(_arguments)
            arg2call.update(_arg2call)
            nodes.update(_nodes)
            code_scope.update(_code_scope)

        else:
            # constants
            code_scope[new_name] = data
            data_to_replace[key] = replacement

    if has_func_def:
        tree.body[0].decorator_list.clear()
        tree.body[0].args.args.extend([ast.Name(id=a) for a in sorted(arguments)])
        tree.body[0].args.defaults.extend([ast.Constant(None) for _ in sorted(arguments)])
        tree.body[0].args.kwarg = None

    # replace words
    code = tools.ast2code(tree)
    code = tools.word_replace(code, data_to_replace, exclude_dot=True)

    return code, arguments, arg2call, nodes, code_scope
def _jit_intg(f, show_code=False, **jit_setting):
    """JIT-compile an :class:`Integrator` with Numba.

    If ``f.integral`` is a bound method, compilation is delegated to
    ``_jit_cls_func``. Otherwise the integrator's generated source
    (``f.code_lines``) is re-parsed, and every function it calls
    (``f.derivative``) is prepared as follows:

    - functions whose host is a ``DynamicalSystem`` go through
      ``_jit_cls_func``, and their extra arguments are appended to the
      calls and to the integral's signature;
    - numba ``Dispatcher`` objects (already compiled) are left untouched;
    - any other function is wrapped with ``numba.jit``.

    Parameters
    ----------
    f : Integrator
        The integrator to compile.
    show_code : bool
        Whether to display the regenerated code.
    **jit_setting
        Keyword arguments forwarded to ``numba.jit``.

    Returns
    -------
    dict
        ``func`` (the compiled integral, or ``f`` itself when nothing needed
        recompiling), ``arguments``, ``arg2call`` and ``nodes``.
    """
    # NOTE: relies on the integrator attributes "integral", "code_lines",
    # "code_scope", "func_name" and "derivative".
    assert isinstance(f, Integrator)

    # exponential euler methods
    if hasattr(f.integral, '__self__'):
        return _jit_cls_func(f=f.integral, code="\n".join(f.code_lines),
                             show_code=show_code, **jit_setting)

    # information in the integrator
    func_name = f.func_name
    raw_func = f.derivative
    tree = ast.parse('\n'.join(f.code_lines))
    code_scope = dict(f.code_scope)

    # essential information
    arguments = set()
    arg2call = dict()
    nodes = Collector()

    # If the integrator itself is bound to a DynamicalSystem, that system
    # hosts every derivative function and "self" is removed from the
    # signature and from the function calls.
    f_node = None
    remove_self = None
    if hasattr(f, '__self__') and isinstance(f.__self__, DynamicalSystem):
        f_node = f.__self__
        remove_self = tree.body[0].args.args.pop(0).arg

    need_recompile = False
    for key, func in raw_func.items():
        # already compiled by numba
        if isinstance(func, Dispatcher):
            continue

        # get the host node of the function
        if f_node is not None:
            func_node = f_node
        elif hasattr(func, '__self__') and isinstance(func.__self__, DynamicalSystem):
            func_node = func.__self__
        else:
            func_node = None

        need_recompile = True
        if func_node is not None:
            r = _jit_cls_func(f=func, host=func_node, show_code=show_code, **jit_setting)
            if len(r['arguments']) or remove_self:
                tree = _replace_func_call_by_tree(tree, func_call=key,
                                                  arg_to_append=r['arguments'],
                                                  remove_self=remove_self)
            code_scope[key] = r['func']
            arguments.update(r['arguments'])
            arg2call.update(r['arg2call'])
            nodes.update(r['nodes'])
            nodes[func_node.name] = func_node
        else:
            code_scope[key] = numba.jit(func, **jit_setting)

    if not need_recompile:
        return dict(func=f, arguments=arguments, arg2call=arg2call, nodes=nodes)

    # Rebuild the integral with the collected extra arguments appended to its
    # signature. Signature entries must be ``ast.arg`` nodes (``ast.Name`` is
    # not a valid parameter node); each new argument defaults to None.
    func_def = tree.body[0]
    new_args = sorted(arguments)
    func_def.decorator_list.clear()
    func_def.args.args.extend([ast.arg(arg=a) for a in new_args])
    func_def.args.defaults.extend([ast.Constant(None) for _ in new_args])
    code = tools.ast2code(tree)

    # compile functions
    if show_code:
        _show_compiled_codes(code, code_scope)
    exec(compile(code, '', 'exec'), code_scope)
    new_f = code_scope[func_name]
    jit_f = numba.jit(new_f, **jit_setting)
    return dict(func=jit_f, arguments=arguments, arg2call=arg2call, nodes=nodes)
class IntegratorRunner(Runner):
    """Run a standalone :class:`Integrator` over time and record its variables.

    Parameters
    ----------
    target : Integrator
        The integrator to run.
    monitors : list, tuple, dict, Monitor, optional
        Integrator variables to record; each must be one of
        ``target.variables``.
    inits : list, tuple, dict, optional
        Initial values of the integrator variables (a list/tuple is matched
        to ``target.variables`` by position). Variables without an initial
        value start at zero.
    args : dict, optional
        Arguments passed unchanged to the integrator at every step.
    dyn_args : dict, optional
        Arguments that change at every step. All values must have the same
        length; the k-th element is passed at the k-th step of each run.
    dyn_vars : dict, list, tuple, optional
        Variables changed during the run. Defaults to
        ``target.vars().unique()``.
    jit : bool
        Whether to build the time loop with ``math.make_loop``.
    dt : int, float, optional
        Time step. Defaults to ``math.get_dt()``.
    numpy_mon_after_run : bool
        Whether to call ``self.mon.numpy()`` after each run.
    progress_bar : bool
        Whether to display a progress bar during a run.
    """

    def __init__(self, target, monitors=None, inits=None, args=None, dyn_args=None,
                 dyn_vars=None, jit=True, dt=None, numpy_mon_after_run=True,
                 progress_bar=True):
        super(IntegratorRunner, self).__init__()

        # parameters
        dt = math.get_dt() if dt is None else dt
        if not isinstance(dt, (int, float)):
            raise RunningError(f'"dt" must be scalar, but got {dt}')
        self.dt = dt
        self.jit = jit
        self.numpy_mon_after_run = numpy_mon_after_run
        self._pbar = None  # progress bar, created at each run
        self.progress_bar = progress_bar

        # target
        if not isinstance(target, Integrator):
            raise RunningError(
                f'"target" must be an instance of {Integrator.__name__}, '
                f'but we got {type(target)}: {target}')
        self.target = target

        # static arguments of the integral function
        self._static_args = Collector()
        if args is not None:
            assert isinstance(args, dict), \
                f'"args" must be a dict, but we get {type(args)}: {args}'
            self._static_args.update(args)

        # per-step arguments of the integral function
        self._dyn_args = Collector()
        if dyn_args is not None:
            assert isinstance(dyn_args, dict), \
                f'"dyn_args" must be a dict, but we get {type(dyn_args)}: {dyn_args}'
            # An empty dict means "no dynamical arguments"; checking it would
            # report 0 distinct lengths and raise spuriously.
            if dyn_args:
                sizes = np.unique([len(v) for v in dyn_args.values()])
                num_size = len(sizes)
                if num_size != 1:
                    raise RunningError(
                        f'All values in "dyn_args" should have the same length. But we got '
                        f'{num_size}: {sizes}')
            self._dyn_args.update(dyn_args)

        # dynamically changed variables
        if dyn_vars is None:
            dyn_vars = self.target.vars().unique()
        if isinstance(dyn_vars, (list, tuple)):
            dyn_vars = {f'_v{i}': v for i, v in enumerate(dyn_vars)}
        if not isinstance(dyn_vars, dict):
            raise RunningError(
                f'"dyn_vars" must be a dict, but we got {type(dyn_vars)}')
        self.dyn_vars = TensorCollector(dyn_vars)

        # monitors
        if monitors is None:
            self.mon = Monitor(target=self, variables=[])
        elif isinstance(monitors, (list, tuple, dict)):
            self.mon = Monitor(target=self, variables=monitors)
        elif isinstance(monitors, Monitor):
            self.mon = monitors
            self.mon.target = self
        else:
            raise MonitorError(f'"monitors" only supports list/tuple/dict/ '
                               f'instance of Monitor, not {type(monitors)}.')
        self.mon.build()  # build the monitor
        for k in self.mon.item_names:
            if k not in self.target.variables:
                raise MonitorError(
                    f'Variable "{k}" to monitor is not defined in the integrator {self.target}.')

        # start simulation time
        self._start_t = None

        # initial values of the integrator variables
        if inits is None:
            inits = dict()
        else:
            if isinstance(inits, (list, tuple)):
                assert len(self.target.variables) == len(inits)
                inits = dict(zip(self.target.variables, inits))
            assert isinstance(inits, dict)
        # All variables share the largest initial size (1 when no initial
        # value is given); ``np.max`` of an empty array would raise.
        init_sizes = [np.size(v) for v in inits.values()]
        max_size = int(max(init_sizes)) if init_sizes else 1
        self.variables = TensorCollector({
            v: math.Variable(math.zeros(max_size)) for v in self.target.variables
        })
        for k, value in inits.items():
            self.variables[k][:] = value
        self.dyn_vars.update(self.variables)
        if len(self._dyn_args) > 0:
            # step counter used to index the per-step arguments
            self.idx = math.Variable(math.zeros(1, dtype=math.int_))
            self.dyn_vars['_idx'] = self.idx

        # build the update loop
        if jit:
            _loop_func = math.make_loop(
                self._step, dyn_vars=self.dyn_vars,
                out_vars={k: self.variables[k] for k in self.mon.item_names})
        else:
            def _loop_func(t_and_dt):
                out_vars = {k: [] for k in self.mon.item_names}
                times, dts = t_and_dt
                for _t, _dt in zip(times, dts):
                    self._step([_t, _dt])
                    for k in self.mon.item_names:
                        out_vars[k].append(math.as_device_array(self.variables[k]))
                return {k: math.asarray(v) for k, v in out_vars.items()}
        self.step_func = _loop_func

    def _post(self, times, returns):
        """Store the time points and monitored histories into ``self.mon``."""
        self.mon.ts = times
        for key in self.mon.item_names:
            self.mon.item_contents[key] = math.asarray(returns[key])

    def _step(self, t_and_dt):
        """Advance the integrator by one step at time ``t_and_dt[0]``."""
        # arguments of the integral function
        kwargs = dict()
        kwargs.update(self.variables)
        kwargs.update({'t': t_and_dt[0], 'dt': t_and_dt[1]})
        kwargs.update(self._static_args)
        if len(self._dyn_args) > 0:
            kwargs.update({k: v[self.idx] for k, v in self._dyn_args.items()})
            self.idx += 1

        # call the integrator and write back the new values
        update_values = self.target(**kwargs)
        for i, v in enumerate(self.target.variables):
            self.variables[v].update(update_values[i])
        if self.progress_bar:
            id_tap(lambda *args: self._pbar.update(), ())

    def run(self, duration, start_t=None):
        """Run the integrator; see :meth:`__call__`.

        Returns
        -------
        running_time : float
            The wall-clock running time in seconds.
        """
        return self.__call__(duration, start_t)

    def __call__(self, duration, start_t=None):
        """The running function.

        Parameters
        ----------
        duration : float, int
            The running duration.
        start_t : float, optional
            The start time. Defaults to the end time of the previous run,
            or 0 for the first run.

        Returns
        -------
        running_time : float
            The wall-clock running time in seconds.
        """
        if len(self._dyn_args) > 0:
            self.dyn_vars['_idx'][0] = 0

        # time range
        if start_t is None:
            start_t = 0. if self._start_t is None else float(self._start_t)
        end_t = float(start_t + duration)
        times = math.arange(start_t, end_t, self.dt)
        time_steps = math.ones_like(times) * self.dt

        # running
        if self.progress_bar:
            self._pbar = tqdm.auto.tqdm(total=times.size)
            self._pbar.set_description(
                f"Running a duration of {round(float(duration), 3)} ({times.size} steps)",
                refresh=True)
        try:
            t0 = time.time()
            hists = self.step_func([times.value, time_steps.value])
            running_time = time.time() - t0
        finally:
            # close the progress bar even if the run fails
            if self.progress_bar:
                self._pbar.close()

        # post-running
        self._post(times, hists)
        self._start_t = end_t
        if self.numpy_mon_after_run:
            self.mon.numpy()
        return running_time