def _get_args(f):
    """Inspect the signature of ``f`` and classify its parameters.

    Parameters
    ----------
    f : callable
        The function whose signature is read with ``inspect.signature``.

    Returns
    -------
    class_kw : list of str
        The leading parameter when it is listed in ``CLASS_KEYWORDS``,
        otherwise an empty list.
    args : list of str
        Names of positional-or-keyword, ``*args`` and keyword-only
        parameters (the leading entry is dropped when a class keyword
        was found).
    kwargs : list of str
        Names of ``**kwargs`` parameters.
    original_args : list of str
        ``str(parameter)`` for every parameter, without the class keyword.

    Raises
    ------
    errors.BrainPyError
        If ``f`` declares a positional-only parameter or a parameter of an
        unknown kind.
    errors.DiffEqError
        If a class keyword appears anywhere but in the first position.
    """
    # 1. get the function arguments
    named_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.KEYWORD_ONLY,
    )
    original_args = []
    args = []
    kwargs = []
    for par in inspect.signature(f).parameters.values():
        if par.kind in named_kinds:
            args.append(par.name)
        elif par.kind is inspect.Parameter.POSITIONAL_ONLY:
            raise errors.BrainPyError(
                'Don not support positional only parameters, e.g., /')
        elif par.kind is inspect.Parameter.VAR_KEYWORD:
            kwargs.append(par.name)
        else:
            raise errors.BrainPyError(f'Unknown argument type: {par.kind}')
        original_args.append(str(par))

    # 2. analyze the function arguments
    #   2.1 class keywords
    class_kw = []
    # Guard against callables without any parameter: indexing an empty
    # list here used to raise IndexError.
    if original_args and original_args[0] in CLASS_KEYWORDS:
        class_kw.append(original_args[0])
        original_args = original_args[1:]
        args = args[1:]
    for a in original_args:
        # ``str(parameter)`` may carry a default ("x=1"); compare the name only.
        if a.split('=')[0].strip() in CLASS_KEYWORDS:
            raise errors.DiffEqError(f'Class keywords "{a}" must be defined '
                                     f'as the first argument.')
    return class_kw, args, kwargs, original_args
def __init__(self, *arg_ds, name=None, **kwarg_ds):
    """Collect the given dynamical systems into ``self.implicit_nodes``.

    Parameters
    ----------
    *arg_ds : DynamicalSystem
        Systems registered under their own ``name`` attribute.
    name : optional, str
        Forwarded to the parent initializer.
    **kwarg_ds : DynamicalSystem
        Systems registered under the keyword they were passed with.

    Raises
    ------
    errors.BrainPyError
        If any given object is not a ``DynamicalSystem``.
    """
    super(Sequential, self).__init__(name=name)
    self.implicit_nodes = Collector()

    def _validated(candidate):
        # Shared type check for positional and keyword systems.
        if isinstance(candidate, DynamicalSystem):
            return candidate
        raise errors.BrainPyError(
            f'Only support {DynamicalSystem.__name__}, '
            f'but we got {type(candidate)}: {str(candidate)}.')

    # positional systems first, keyed by their own name
    for system in arg_ds:
        self.implicit_nodes[_validated(system).name] = system

    # then keyword systems, keyed by the keyword
    for key, system in kwarg_ds.items():
        self.implicit_nodes[key] = _validated(system)

    # one flag per node: does its ``update`` declare a parameter named "kwargs"?
    self._return_kwargs = []
    for system in self.implicit_nodes.values():
        update_params = inspect.signature(system.update).parameters
        self._return_kwargs.append('kwargs' in update_params)
def _device_reshape(x):
    """Reshape ``x`` so that its leading axis is split across local devices.

    A 0-d input is broadcast to shape ``(num_device,)``. Otherwise the first
    axis of length ``B`` becomes two axes ``(num_device, B // num_device)``.

    Raises
    ------
    errors.BrainPyError
        If ``x`` has no ``ndim`` attribute, or if its first axis cannot be
        divided evenly by the number of local devices.
    """
    device_count = jax.local_device_count()

    if not hasattr(x, 'ndim'):
        raise errors.BrainPyError(
            f'Expected JaxArray, got {type(x)}. If you are trying to pass a scalar to '
            f'parallel, first convert it to a JaxArray, for example np.float(0.5)'
        )

    if x.ndim == 0:
        return np.broadcast_to(x, [device_count])

    per_device, leftover = divmod(x.shape[0], device_count)
    if leftover != 0:
        raise errors.BrainPyError(
            f'Must be able to equally divide batch {x.shape} among '
            f'{device_count} devices, but does not go equally.')

    new_shape = (device_count, per_device) + x.shape[1:]
    return x.reshape(new_shape)
def __init__(self, size, C=1., A=1e-3, V_th=0., method='exp_auto', name=None, **channels):
    """Build the neuron group and instantiate its channels.

    Parameters
    ----------
    size
        Forwarded to the parent initializer.
    C, A, V_th : float
        Stored unchanged on the instance as ``self.C``, ``self.A`` and
        ``self.V_th``.
    method : str
        Forwarded to the parent initializer and to every channel.
    name : optional, str
        Forwarded to the parent initializer.
    **channels
        Each value must be a 2-element tuple/list ``(channel_class, params)``
        where ``params`` is a dict of keyword arguments for the class.

    Raises
    ------
    AssertionError
        If a channel specification does not have the shape described above.
    errors.BrainPyError
        If an instantiated channel is not a ``Channel`` instance.
    """
    super(CondNeuGroup, self).__init__(size, method=method, name=name)

    # parameters for neurons
    self.C = C
    self.A = A
    self.V_th = V_th

    # check 'channels'
    # NOTE: iterate over (key, spec) pairs. The previous loop bound ``key`` to
    # the whole ``(key, value)`` tuple, so the ``str`` assertion always failed.
    _channels = dict()
    for key, spec in channels.items():
        assert isinstance(key, str), f'Key must be a str, but got {type(key)}: {key}'
        assert isinstance(spec, (tuple, list)) and len(spec) == 2
        assert isinstance(spec[0], type)
        assert isinstance(spec[1], dict)
        cls = spec[0]
        params = spec[1].copy()  # copy so the caller's dict is not mutated
        params['host'] = self
        params['method'] = method
        _channels[key] = (cls, params)

    # initialize children channels (only after every spec was validated)
    self.channels = Collector()
    for key, (ch, params) in _channels.items():
        self.channels[key] = ch(**params)
        if not isinstance(self.channels[key], Channel):
            raise errors.BrainPyError(f'{self.__class__.__name__} only receives {Channel} instance, '
                                      f'while we got {type(self.channels[key])}: {self.channels[key]}.')
    self.ion_channels = self.channels.subset(IonChannel)
    self.mol_channels = self.channels.subset(MolChannel)
def _check_var(var):
    """Raise ``errors.BrainPyError`` unless ``var`` is a ``math.ndarray``.

    ``brainpy.math`` is imported on first use, when the module-level ``math``
    is still ``None``.
    """
    global math
    if math is None:
        from brainpy import math

    if isinstance(var, math.ndarray):
        return

    raise errors.BrainPyError(
        f'Element in "dyn_vars" must be an instance of '
        f'{math.ndarray.__name__}, but we got {type(var)}.')
def _jit_DS(obj_or_fun, show_code=False, **jit_setting):
    """JIT every step function of a ``DynamicalSystem`` in place.

    Parameters
    ----------
    obj_or_fun : DynamicalSystem
        The system whose ``steps`` are processed.
    show_code : bool
        Forwarded to ``_jit_func`` and ``_form_final_call``.
    **jit_setting
        Forwarded to ``_jit_func``.

    Returns
    -------
    DynamicalSystem
        The same ``obj_or_fun`` object, with rewritten steps replaced.

    Raises
    ------
    errors.UnsupportedError
        If ``obj_or_fun`` is not a ``DynamicalSystem``.
    errors.BrainPyError
        If ``obj_or_fun`` has no ``steps`` attribute.
    """
    if not isinstance(obj_or_fun, DynamicalSystem):
        # The message now names the type that is actually checked
        # (it used to mention ``Base``).
        raise errors.UnsupportedError(f'JIT compilation in numpy backend only '
                                      f'supports {DynamicalSystem.__name__}, but we got '
                                      f'{type(obj_or_fun)}.')
    if not hasattr(obj_or_fun, 'steps'):
        raise errors.BrainPyError(
            f'Please init this DynamicalSystem {obj_or_fun} first, '
            f'then apply JIT.')

    # function analysis
    # Iterate over a snapshot because ``steps`` is modified inside the loop.
    for key, step in list(obj_or_fun.steps.items()):
        # NOTE(review): the step is replaced under the dot-free key, not the
        # original one - confirm that ``steps.replace`` expects this form.
        key = key.replace(".", "_")
        r = _jit_func(obj_or_fun=step, show_code=show_code, **jit_setting)
        if r['func'] != step:
            func = _form_final_call(f_org=step,
                                    f_rep=r['func'],
                                    arg2call=r['arg2call'],
                                    arguments=r['arguments'],
                                    nodes=r['nodes'],
                                    show_code=show_code,
                                    name=step.__name__)
            obj_or_fun.steps.replace(key, func)

    # dynamic system
    return obj_or_fun
def _replace_this_forloop(tree, iter_name, loop_values, show_code=False, **jit_setting):
    """Replace the for-loop that iterates over ``iter_name``.

    The tree is visited with a ``ReplaceThisForLoop`` transformer. The goal
    is to turn a loop such as

    >>> def update(_t, _dt):
    >>>    for step in self.child_steps.values():
    >>>        step(_t, _dt)

    into explicit calls, for example

    >>> def update(_t, _dt, AMPA_vec0_delay_g_data=None, AMPA_vec0_delay_g_in_idx=None,
    >>>            AMPA_vec0_delay_g_out_idx=None, AMPA_vec0_s=None, HH0_V=None, HH0_V_th=None,
    >>>            HH0_gNa=None, HH0_h=None, HH0_input=None, HH0_m=None, HH0_n=None, HH0_spike=None):
    >>>    HH0_step(_t, _dt, HH0_V=HH0_V, HH0_V_th=HH0_V_th, HH0_gNa=HH0_gNa,
    >>>             HH0_h=HH0_h, HH0_input=HH0_input, HH0_m=HH0_m, HH0_n=HH0_n,
    >>>             HH0_spike=HH0_spike)
    >>>    AMPA_vec0_step(_t, _dt, AMPA_vec0_delay_g_data=AMPA_vec0_delay_g_data,
    >>>                   AMPA_vec0_delay_g_in_idx=AMPA_vec0_delay_g_in_idx,
    >>>                   AMPA_vec0_delay_g_out_idx=AMPA_vec0_delay_g_out_idx,
    >>>                   AMPA_vec0_s=AMPA_vec0_s, HH0_V=HH0_V, HH0_input=HH0_input,
    >>>                   HH0_spike=HH0_spike)
    >>>    AMPA_vec0_delay_g_step(_t, _dt, AMPA_vec0_delay_g_in_idx=AMPA_vec0_delay_g_in_idx,
    >>>                           AMPA_vec0_delay_g_out_idx=AMPA_vec0_delay_g_out_idx)

    Parameters
    ----------
    tree : ast.Module
        The target code tree.
    iter_name : str
        The source text of the for-loop iterable to look for.
    loop_values : list/tuple
        The contents of the iterable in the current loop.
    show_code : bool
        Whether to show the formatted code.
    **jit_setting
        Forwarded to ``ReplaceThisForLoop``.

    Returns
    -------
    tuple
        ``(tree, arguments, arg2call, nodes, code_scope)``: the visited tree
        followed by the attributes collected on the transformer.

    Raises
    ------
    errors.BrainPyError
        If the transformer did not report success.
    """
    assert isinstance(tree, ast.Module)

    transformer = ReplaceThisForLoop(loop_values=loop_values,
                                     iter_name=iter_name,
                                     show_code=show_code,
                                     **jit_setting)
    new_tree = transformer.visit(tree)

    if transformer.success:
        return (new_tree,
                transformer.arguments,
                transformer.arg2call,
                transformer.nodes,
                transformer.code_scope)

    raise errors.BrainPyError(
        f'Do not find the for-loop for "{iter_name}", '
        f'currently we only support for-loop like '
        f'"for xxx in {iter_name}:". Does your for-loop '
        f'structure is not like this. ')
def check_name_uniqueness(name, obj):
    """Register ``name`` for ``obj`` or fail if another object already owns it.

    The owner is tracked by ``id(obj)`` in the module-level ``_name2id``.

    Raises
    ------
    errors.BrainPyError
        If ``name`` is not a valid Python identifier.
    errors.UniqueNameError
        If ``name`` is already registered for a different object.
    """
    if not name.isidentifier():
        raise errors.BrainPyError(f'"{name}" isn\'t a valid identifier '
                                  f'according to Python language definition. '
                                  f'Please choose another name.')

    try:
        owner_id = _name2id[name]
    except KeyError:
        # first time this name is seen: record its owner
        _name2id[name] = id(obj)
        return

    if owner_id != id(obj):
        raise errors.UniqueNameError(
            f'In BrainPy, each object should have a unique name. '
            f'However, we detect that {obj} has a used name "{name}".')
def load_states(self, filename, verbose=False, check_missing=False):
    """Load the model states from a file, dispatching on its extension.

    Parameters
    ----------
    filename : str
        The file which stores the model states. Supported suffixes are
        ``.hdf5``/``.h5``, ``.pkl``, ``.npz`` and ``.mat``.
    verbose : bool
        Forwarded to the loader as ``verbose``.
    check_missing : bool
        Forwarded to the loader as ``check``.

    Raises
    ------
    errors.BrainPyError
        If the path does not exist or the suffix is not supported.
    """
    if not os.path.exists(filename):
        raise errors.BrainPyError(f'Cannot find the file path: {filename}')

    if filename.endswith(('.hdf5', '.h5')):
        loader = io.load_h5
    elif filename.endswith('.pkl'):
        loader = io.load_pkl
    elif filename.endswith('.npz'):
        loader = io.load_npz
    elif filename.endswith('.mat'):
        loader = io.load_mat
    else:
        raise errors.BrainPyError(
            f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}'
        )

    loader(filename, target=self, verbose=verbose, check=check_missing)
def call(*args):
    """Run ``pmapped_func`` on the current variable values and write back changes.

    Every positional argument is reshaped with ``_device_reshape`` unless its
    position (offset by 2) is listed in ``static_broadcasted_argnums``. The
    output is passed through ``reduce_func`` leaf by leaf.

    Raises
    ------
    errors.BrainPyError
        If the value of any dynamical variable is not a sharded array or tracer.
    """
    replicated_types = (ShardedDeviceArray, JaxprTracer, DynamicJaxprTracer)
    un_replicated = []
    for var_name, var in dyn_vars.items():
        if not isinstance(var.value, replicated_types):
            un_replicated.append(var_name)
    if un_replicated:
        raise errors.BrainPyError(
            f'Some variables were not replicated: {un_replicated}.'
            f'did you forget to call xx.replicate() on them?')

    # NOTE(review): the "+ 2" presumably accounts for the two data arguments
    # placed in front of ``*args`` below - confirm against the enclosing code.
    _args = [
        x if i + 2 in static_broadcasted_argnums
        else jax.tree_map(_device_reshape, [x])[0]
        for i, x in enumerate(args)
    ]

    output, dyn_changes, rand_changes = pmapped_func(
        dyn_vars.dict(), rand_vars.dict(), *_args)

    dyn_vars.assign(dyn_changes)
    rand_vars.assign(rand_changes)
    return jax.tree_map(reduce_func, output)
def save_states(self, filename, all_vars=None, **setting):
    """Save the model states to a file, dispatching on its extension.

    Parameters
    ----------
    filename : str
        The file in which to store the model states. Supported suffixes are
        ``.hdf5``/``.h5``, ``.pkl``, ``.npz`` and ``.mat``.
    all_vars : optional, dict, TensorCollector
        The variables to save. Defaults to
        ``self.vars(method='relative').unique()``.
    **setting
        Extra keyword arguments; only forwarded to the ``.npz`` writer.

    Raises
    ------
    errors.BrainPyError
        If the suffix is not supported.
    """
    variables = all_vars
    if variables is None:
        variables = self.vars(method='relative').unique()

    if filename.endswith(('.hdf5', '.h5')):
        io.save_h5(filename, all_vars=variables)
        return
    if filename.endswith('.pkl'):
        io.save_pkl(filename, all_vars=variables)
        return
    if filename.endswith('.npz'):
        io.save_npz(filename, all_vars=variables, **setting)
        return
    if filename.endswith('.mat'):
        io.save_mat(filename, all_vars=variables)
        return

    raise errors.BrainPyError(
        f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}'
    )
def raster_plot(ts, sp_matrix, ax=None, marker='.', markersize=2, color='k',
                xlabel='Time (ms)', ylabel='Neuron index', xlim=None, ylim=None,
                title=None, show=False, **kwargs):
    """Show the raster plot of the spikes.

    Parameters
    ----------
    ts : np.ndarray
        The run times.
    sp_matrix : np.ndarray
        The spike matrix which records the spike information. Entries greater
        than zero are drawn; the first axis is indexed by ``ts`` and the second
        axis gives the plotted index.
        It can be easily accessed by specifying the ``monitors``
        of NeuGroup by: ``neu = NeuGroup(..., monitors=['spike'])``
    ax : Axes
        The axes to draw on. When omitted, the current pyplot axes are used.
    marker : str
        The marker style; concatenated with ``color`` into the format string.
    markersize : int
        The size of the marker.
    color : str
        The color of the marker.
    xlabel : str
        The xlabel.
    ylabel : str
        The ylabel.
    xlim : list, tuple
        The xlim.
    ylim : list, tuple
        The ylim.
    title : str
        The title.
    show : bool
        Show the figure.
    **kwargs
        Forwarded to ``ax.plot``.

    Raises
    ------
    errors.BrainPyError
        If ``ts`` is None.
    """
    sp_matrix = np.asarray(sp_matrix)
    if ts is None:
        raise errors.BrainPyError('Must provide "ts".')
    ts = np.asarray(ts)

    # get index and time
    elements = np.where(sp_matrix > 0.)
    index = elements[1]
    time = ts[elements[0]]

    # plot raster
    if ax is None:
        ax = plt.gca()
    ax.plot(time, index, marker + color, markersize=markersize, **kwargs)

    # Decorate the axes that were drawn on. The pyplot helpers used before act
    # on the *current* axes, which need not be the ``ax`` passed by the caller.
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
    if xlim:
        ax.set_xlim(xlim[0], xlim[1])
    if ylim:
        ax.set_ylim(ylim[0], ylim[1])
    if title:
        ax.set_title(title)

    if show:
        plt.show()
def visit_For(self, node):
    """Unroll a ``for`` loop whose iterable matches ``self.iter_name``.

    For a matching loop, the loop body is emitted once per entry of
    ``self.loop_values`` and the results are concatenated into a new
    ``ast.Module``. Non-matching loops are kept as they are. In both cases
    the resulting node is visited further with ``generic_visit``.

    Parameters
    ----------
    node : ast.For
        The for-loop node being visited.

    Returns
    -------
    ast.AST
        The replacement module for a matching loop, otherwise ``node``.

    Raises
    ------
    errors.BrainPyError
        If the loop target is not a plain name, or if a loop value is
        neither a ``Base`` object nor a callable.
    """
    iter_ = tools.ast2code(ast.fix_missing_locations(node.iter))

    if iter_.strip() == self.iter_name:
        data_to_replace = Collector()
        final_node = ast.Module(body=[])
        self.success = True

        # target
        if not isinstance(node.target, ast.Name):
            raise errors.BrainPyError(
                f'Only support scalar iter, like "for x in xxxx:", not "for '
                f'{tools.ast2code(ast.fix_missing_locations(node.target))} '
                f'in {iter_}:')
        target = node.target.id

        # for loop values
        for i, value in enumerate(self.loop_values):
            # fresh copy of the loop body for this value
            module = ast.Module(body=deepcopy(node).body)
            code = tools.ast2code(module)

            if isinstance(value, Base):  # transform Base objects
                r = _analyze_cls_func_body(host=value,
                                           self_name=target,
                                           code=code,
                                           tree=module,
                                           show_code=self.show_code,
                                           **self.jit_setting)
                new_code, arguments, arg2call, nodes, code_scope = r
                self.arguments.update(arguments)
                self.arg2call.update(arg2call)
                self.nodes.update(nodes)
                self.code_scope.update(code_scope)
                final_node.body.extend(ast.parse(new_code).body)

            elif callable(value):  # transform functions
                r = _jit_func(obj_or_fun=value,
                              show_code=self.show_code,
                              **self.jit_setting)
                tree = _replace_func_call_by_tree(
                    deepcopy(module),
                    func_call=target,
                    arg_to_append=r['arguments'],
                    new_func_name=f'{target}_{i}')

                # update import parameters
                self.arguments.update(r['arguments'])
                self.arg2call.update(r['arg2call'])
                self.nodes.update(r['nodes'])

                # replace the data
                # ``value`` cannot be a Base instance here (handled by the
                # branch above), so only bound methods of a Base object get
                # a host-qualified name.
                if hasattr(value, '__self__') and isinstance(value.__self__, Base):
                    replace_name = f'{value.__self__.name}_{target}'
                else:
                    replace_name = f'{target}_{i}'
                self.code_scope[replace_name] = r['func']
                data_to_replace[f'{target}_{i}'] = replace_name

                final_node.body.extend(tree.body)

            else:
                raise errors.BrainPyError(
                    f'Only support JIT an iterable objects of function '
                    f'or Base object, but we got:\n\n {value}')

        # replace words
        final_code = tools.ast2code(final_node)
        final_code = tools.word_replace(final_code,
                                        data_to_replace,
                                        exclude_dot=True)
        final_node = ast.parse(final_code)

    else:
        final_node = node

    self.generic_visit(final_node)
    return final_node
def _analyze_cls_func_body(host, self_name, code, tree, show_code=False, has_func_def=False, **jit_setting): arguments, arg2call, nodes, code_scope = set(), dict(), Collector(), dict() # all self data self_data = re.findall('\\b' + self_name + '\\.[A-Za-z_][A-Za-z0-9_.]*\\b', code) self_data = list(set(self_data)) # analyze variables and functions accessed by the self.xx data_to_replace = {} for key in self_data: split_keys = key.split('.') if len(split_keys) < 2: raise errors.BrainPyError # get target and data target = host for i in range(1, len(split_keys)): next_target = getattr(target, split_keys[i]) if isinstance(next_target, Integrator): break if not isinstance(next_target, Base): break target = next_target else: raise errors.BrainPyError data = getattr(target, split_keys[i]) # analyze data if isinstance(data, math.numpy.Variable): # data is a variable arguments.add(f'{target.name}_{split_keys[i]}') arg2call[ f'{target.name}_{split_keys[i]}'] = f'{target.name}.{split_keys[-1]}.value' nodes[target.name] = target # replace the data if len(split_keys) == i + 1: data_to_replace[key] = f'{target.name}_{split_keys[i]}' else: data_to_replace[ key] = f'{target.name}_{split_keys[i]}.{".".join(split_keys[i + 1:])}' elif isinstance(data, np.random.RandomState): # data is a RandomState # replace RandomState code_scope[f'{target.name}_{split_keys[i]}'] = np.random # replace the data if len(split_keys) == i + 1: data_to_replace[key] = f'{target.name}_{split_keys[i]}' else: data_to_replace[ key] = f'{target.name}_{split_keys[i]}.{".".join(split_keys[i + 1:])}' elif callable(data): # data is a function assert len(split_keys) == i + 1 r = _jit_func(obj_or_fun=data, show_code=show_code, **jit_setting) # if len(r['arguments']): tree = _replace_func_call_by_tree(tree, func_call=key, arg_to_append=r['arguments']) arguments.update(r['arguments']) arg2call.update(r['arg2call']) nodes.update(r['nodes']) code_scope[f'{target.name}_{split_keys[i]}'] = r['func'] data_to_replace[ key] = 
f'{target.name}_{split_keys[i]}' # replace the data elif isinstance( data, (dict, list, tuple)): # data is a list/tuple/dict of function/object # get all values if isinstance(data, dict): # check dict if len(split_keys) != i + 2 and split_keys[-1] != 'values': raise errors.BrainPyError( f'Only support iter dict.values(). while we got ' f'dict.{split_keys[-1]} for data: \n\n{data}') values = list(data.values()) iter_name = key + '()' else: # check list / tuple assert len(split_keys) == i + 1 values = list(data) iter_name = key if len(values) > 0: if not (callable(values[0]) or isinstance(values[0], Base)): code_scope[f'{target.name}_{split_keys[i]}'] = data if len(split_keys) == i + 1: data_to_replace[ key] = f'{target.name}_{split_keys[i]}' else: data_to_replace[ key] = f'{target.name}_{split_keys[i]}.{".".join(split_keys[i + 1:])}' continue # raise errors.BrainPyError(f'Only support JIT an iterable objects of function ' # f'or Base object, but we got:\n\n {values[0]}') # replace this for-loop r = _replace_this_forloop(tree=tree, iter_name=iter_name, loop_values=values, show_code=show_code, **jit_setting) tree, _arguments, _arg2call, _nodes, _code_scope = r arguments.update(_arguments) arg2call.update(_arg2call) nodes.update(_nodes) code_scope.update(_code_scope) else: # constants code_scope[f'{target.name}_{split_keys[i]}'] = data # replace the data if len(split_keys) == i + 1: data_to_replace[key] = f'{target.name}_{split_keys[i]}' else: data_to_replace[ key] = f'{target.name}_{split_keys[i]}.{".".join(split_keys[i + 1:])}' if has_func_def: tree.body[0].decorator_list.clear() tree.body[0].args.args.extend( [ast.Name(id=a) for a in sorted(arguments)]) tree.body[0].args.defaults.extend( [ast.Constant(None) for _ in sorted(arguments)]) tree.body[0].args.kwarg = None # replace words code = tools.ast2code(tree) code = tools.word_replace(code, data_to_replace, exclude_dot=True) return code, arguments, arg2call, nodes, code_scope
def _check_node(node):
    """Raise ``errors.BrainPyError`` unless ``node`` is a ``Base`` instance."""
    if isinstance(node, Base):
        return
    raise errors.BrainPyError(f'Element in "nodes" must be an instance of '
                              f'{Base.__name__}, but we got {type(node)}.')
def vmap(func, dyn_vars=None, batched_vars=None, in_axes=0, out_axes=0,
         axis_name=None, reduce_func=None, auto_infer=False):
    """Vectorization compilation for class objects.

    Vectorized compile a function or a module to run in parallel on a single device.

    Parameters
    ----------
    func : Base, function, callable
        The function or the module to compile.
    dyn_vars : dict
        The variables used by ``func``. ``RandomState`` entries are handled
        separately from the other variables.
    batched_vars : dict
        Currently unused by this function.
    in_axes : optional, int, sequence of int
        Specify which input array axes to map over.

        If each positional argument to ``func`` is an array, then ``in_axes``
        can be an integer, a None, or a tuple of integers and Nones with length
        equal to the number of positional arguments to ``func``. An integer or
        ``None`` indicates which array axis to map over for all arguments (with
        ``None`` indicating not to map any axis), and a tuple indicates which
        axis to map for each corresponding positional argument. Axis integers
        must be in the range ``[-ndim, ndim)`` for each array, where ``ndim`` is
        the number of dimensions (axes) of the corresponding input array.

        If the positional arguments to ``func`` are container types, the
        corresponding element of ``in_axes`` can itself be a matching
        container, so that distinct array axes can be mapped for different
        container elements. ``in_axes`` must be a container tree prefix of the
        positional argument tuple passed to ``func``.

        At least one positional argument must have ``in_axes`` not None. The
        sizes of the mapped input axes for all mapped positional arguments must
        all be equal.

        Arguments passed as keywords are always mapped over their leading axis
        (i.e. axis index 0).
    out_axes : optional, int, tuple/list/dict
        Indicate where the mapped axis should appear in the output. All outputs
        with a mapped axis must have a non-None ``out_axes`` specification. Axis
        integers must be in the range ``[-ndim, ndim)`` for each output array,
        where ``ndim`` is the number of dimensions (axes) of the array returned
        by the :func:`vmap`-ed function, which is one more than the number of
        dimensions (axes) of the corresponding array returned by ``func``.
    axis_name : optional
        Forwarded to ``jax.vmap`` / ``_make_vmap``.
    reduce_func : optional, callable
        Forwarded to ``_make_vmap``. Defaults to the mean over axis 0.
    auto_infer : bool
        For a plain callable, whether to look up ``dyn_vars`` on ``func`` (or
        on the object a bound method belongs to) when none are given.

    Returns
    -------
    func : Any
        Batched/vectorized version of ``func``. For a ``DynamicalSystem`` with
        step functions, the same object with its steps replaced.

    Raises
    ------
    ValueError
        If no batch axis can be found in ``in_axes``.
    errors.BrainPyError
        If ``func`` is neither a ``DynamicalSystem`` with steps nor callable.
    """
    from brainpy.building.brainobjects import DynamicalSystem

    def _split_random_states(all_vars):
        # Separate RandomState entries from the other dynamical variables.
        plain, random = TensorCollector(), TensorCollector()
        for key, val in all_vars.items():
            if isinstance(val, RandomState):
                random[key] = val
            else:
                plain[key] = val
        return plain, random

    def _first_batch_axis(axes):
        # The first two entries of ``axes`` are the ones added in front by
        # this function; search the user-facing part only.
        for i, axis in enumerate(axes[2:]):
            if axis is not None:
                return i, axis
        raise ValueError(f'Found no batch axis: {axes}.')

    if isinstance(func, DynamicalSystem):
        if len(func.steps):  # DynamicalSystem has step functions

            # dynamical variables
            # The input used to be overwritten by an empty collector before it
            # was iterated, so every variable was silently dropped.
            dyn_vars, rand_vars = _split_random_states(
                dyn_vars or func.vars().unique())

            # in axes
            if in_axes is None:
                in_axes = {key: (None, 0) for key in func.steps.keys()}
            elif isinstance(in_axes, int):
                in_axes = {key: (None, 0, in_axes) for key in func.steps.keys()}
            elif isinstance(in_axes, (tuple, list)):
                in_axes = {key: (None, 0) + tuple(in_axes) for key in func.steps.keys()}
            elif isinstance(in_axes, dict):
                keys = list(func.steps.keys())
                if keys[0] not in in_axes:
                    # a dict that is not keyed by step name is one pytree spec
                    in_axes = {key: (None, 0, in_axes) for key in keys}
                else:
                    in_axes = {key: (None, 0) + tuple(in_axes[key]) for key in keys}
            assert isinstance(in_axes, dict)

            # batch size index
            batch_idx = {key: _first_batch_axis(axes)
                         for key, axes in in_axes.items()}

            # out axes
            if out_axes is None:
                out_axes = {key: 0 for key in func.steps.keys()}
            elif isinstance(out_axes, int):
                out_axes = {key: out_axes for key in func.steps.keys()}
            elif isinstance(out_axes, (tuple, list)):
                out_axes = {key: tuple(out_axes) + (0, 0) for key in func.steps.keys()}
            elif isinstance(out_axes, dict):
                keys = list(func.steps.keys())
                if keys[0] not in out_axes:
                    out_axes = {key: (out_axes, 0, 0) for key in keys}
                else:
                    out_axes = {key: tuple(out_axes[key]) + (0, 0) for key in keys}
            assert isinstance(out_axes, dict)

            # reduce_func
            if reduce_func is None:
                reduce_func = lambda x: x.mean(axis=0)

            # vectorized map functions
            for key in func.steps.keys():
                func.steps[key] = _make_vmap(func=func.steps[key],
                                             dyn_vars=dyn_vars,
                                             rand_vars=rand_vars,
                                             in_axes=in_axes[key],
                                             out_axes=out_axes[key],
                                             axis_name=axis_name,
                                             batch_idx=batch_idx[key],
                                             reduce_func=reduce_func,
                                             f_name=key)

            return func

    if callable(func):
        if auto_infer:
            if dyn_vars is not None:
                dyn_vars = dyn_vars
            elif isinstance(func, Base):  # Base has '__call__()' implementation
                dyn_vars = func.vars().unique()
            elif hasattr(func, '__self__'):
                if isinstance(func.__self__, Base):
                    dyn_vars = func.__self__.vars().unique()

        if dyn_vars is None:
            return jax.vmap(func,
                            in_axes=in_axes,
                            out_axes=out_axes,
                            axis_name=axis_name)

        else:
            # dynamical variables (same shadowing fix as above)
            dyn_vars, rand_vars = _split_random_states(dyn_vars)

            # in axes
            if in_axes is None:
                in_axes = (None, 0)
            elif isinstance(in_axes, (int, dict)):
                in_axes = (None, 0, in_axes)
            elif isinstance(in_axes, (tuple, list)):
                in_axes = (None, 0) + tuple(in_axes)
            assert isinstance(in_axes, (tuple, list))

            # batch size index
            # The old loop iterated over the empty ``batch_idx`` dict itself
            # and therefore never ran. Use the same ``(position, axis)`` pair
            # that the step-function branch passes to ``_make_vmap``.
            batch_idx = _first_batch_axis(in_axes)

            # out axes
            # NOTE(review): ``out_axes=None`` becomes 0 and then fails the
            # assertion below - confirm the intended value.
            if out_axes is None:
                out_axes = 0
            elif isinstance(out_axes, (int, dict)):
                out_axes = (out_axes, 0, 0)
            elif isinstance(out_axes, (tuple, list)):
                out_axes = tuple(out_axes) + (0, 0)
            assert isinstance(out_axes, (list, tuple))

            # reduce_func
            if reduce_func is None:
                reduce_func = lambda x: x.mean(axis=0)

            # jit function
            return _make_vmap(func=func,
                              dyn_vars=dyn_vars,
                              rand_vars=rand_vars,
                              in_axes=in_axes,
                              out_axes=out_axes,
                              axis_name=axis_name,
                              batch_idx=batch_idx,
                              reduce_func=reduce_func)

    else:
        raise errors.BrainPyError(
            f'Only support instance of {Base.__name__}, or a callable '
            f'function, but we got {type(func)}.')
def pmap(func, dyn_vars=None, axis_name=None, in_axes=0, out_axes=0,
         static_broadcasted_argnums=(), devices=None, backend=None,
         axis_size=None, donate_argnums=(), global_arg_shapes=None,
         reduce_func=None):
    """Parallel compilation for class objects.

    Parallel compile a function or a module to run on multiple devices in parallel.

    Parameters
    ----------
    func : Base, function, callable
        The function or the module to compile.
    dyn_vars : optional, dict
        The variables used by ``func``. ``RandomState`` entries are handled
        separately from the other variables. When omitted, they are looked up
        on ``func`` (or on the object a bound method belongs to).
    axis_name, in_axes, out_axes, devices, backend, axis_size, donate_argnums, global_arg_shapes
        Forwarded unchanged to ``jax.pmap`` (no variables) or ``_make_pmap``.
    static_broadcasted_argnums : optional, int, tuple/list of int
        Positions of static arguments. Forwarded unchanged to ``jax.pmap``;
        shifted by 2 before being passed to ``_make_pmap``.
    reduce_func : optional, callable
        Forwarded to ``_make_pmap``. Defaults to ``jnp.concatenate``.

    Returns
    -------
    func : Any
        For a ``DynamicalSystem`` with step functions, the same object with
        its steps replaced. For a callable without variables, the result of
        ``jax.pmap``. Otherwise ``func`` itself (see the review note below).

    Raises
    ------
    errors.BrainPyError
        If ``func`` is neither a ``DynamicalSystem`` with steps nor callable.
    """
    from brainpy.building.brainobjects import DynamicalSystem

    def _shifted_static_argnums(argnums):
        # NOTE(review): the offset of 2 presumably accounts for two leading
        # arguments added by the wrapper - confirm against ``_make_pmap``.
        if argnums is None:
            argnums = ()
        elif isinstance(argnums, int):
            argnums = (argnums + 2,)
        elif isinstance(argnums, (tuple, list)):
            argnums = tuple(argnum + 2 for argnum in argnums)
        assert isinstance(argnums, (tuple, list))
        return argnums

    if isinstance(func, DynamicalSystem):
        if len(func.steps):  # DynamicalSystem has step functions

            # dynamical variables
            all_vars = (dyn_vars or func.vars().unique())
            dyn_vars = TensorCollector()
            rand_vars = TensorCollector()
            for key, val in all_vars.items():
                if isinstance(val, RandomState):
                    rand_vars[key] = val
                else:
                    dyn_vars[key] = val

            # reduce function
            if reduce_func is None:
                reduce_func = jnp.concatenate

            # static broadcast-ed arguments
            static_broadcasted_argnums = _shifted_static_argnums(
                static_broadcasted_argnums)

            # jit functions
            for key in func.steps.keys():
                step = func.steps[key]
                func.steps[key] = _make_pmap(dyn_vars=dyn_vars,
                                             rand_vars=rand_vars,
                                             func=step,
                                             axis_name=axis_name,
                                             in_axes=in_axes,
                                             out_axes=out_axes,
                                             static_broadcasted_argnums=static_broadcasted_argnums,
                                             devices=devices,
                                             backend=backend,
                                             axis_size=axis_size,
                                             donate_argnums=donate_argnums,
                                             global_arg_shapes=global_arg_shapes,
                                             reduce_func=reduce_func,
                                             f_name=key)
            return func

    if callable(func):
        if dyn_vars is not None:
            dyn_vars = dyn_vars
        elif isinstance(func, Base):  # Base has '__call__()' implementation
            dyn_vars = func.vars().unique()
        elif hasattr(func, '__self__'):
            if isinstance(func.__self__, Base):
                dyn_vars = func.__self__.vars().unique()

        if dyn_vars is None:
            return jax.pmap(func,
                            axis_name=axis_name,
                            in_axes=in_axes,
                            out_axes=out_axes,
                            static_broadcasted_argnums=static_broadcasted_argnums,
                            devices=devices,
                            backend=backend,
                            axis_size=axis_size,
                            donate_argnums=donate_argnums,
                            global_arg_shapes=global_arg_shapes)
        else:
            # dynamical variables
            # Keep a handle on the input before rebinding ``dyn_vars``. The
            # old code iterated over the freshly created empty collector, so
            # every variable was silently dropped.
            all_vars = dyn_vars
            dyn_vars = TensorCollector()
            rand_vars = TensorCollector()
            for key, val in all_vars.items():
                if isinstance(val, RandomState):
                    rand_vars[key] = val
                else:
                    dyn_vars[key] = val

            # static broadcast-ed arguments
            static_broadcasted_argnums = _shifted_static_argnums(
                static_broadcasted_argnums)

            # reduce function
            if reduce_func is None:
                reduce_func = jnp.concatenate

            # jit function
            # NOTE(review): assigning ``__call__`` on an instance does not
            # change what ``func(...)`` does, and bound methods reject
            # attribute assignment. Returning the wrapper directly is probably
            # what was meant; left as is because it changes the return value.
            func.__call__ = _make_pmap(dyn_vars=dyn_vars,
                                       rand_vars=rand_vars,
                                       func=func,
                                       axis_name=axis_name,
                                       in_axes=in_axes,
                                       out_axes=out_axes,
                                       static_broadcasted_argnums=static_broadcasted_argnums,
                                       devices=devices,
                                       backend=backend,
                                       axis_size=axis_size,
                                       donate_argnums=donate_argnums,
                                       global_arg_shapes=global_arg_shapes,
                                       reduce_func=reduce_func)
            return func

    else:
        raise errors.BrainPyError(
            f'Only support instance of {Base.__name__}, or a callable function, '
            f'but we got {type(func)}.')
def line_plot(ts, val_matrix, plot_ids=None, ax=None, xlim=None, ylim=None,
              xlabel='Time (ms)', ylabel=None, legend=None, title=None,
              show=False, **kwargs):
    """Show the specified value in the given object (Neurons or Synapses.)

    Parameters
    ----------
    ts : np.ndarray
        The time steps.
    val_matrix : np.ndarray
        The value matrix which record the history trajectory. It is flattened
        to two dimensions, keeping the first axis.
        It can be easily accessed by specifying the ``monitors``
        of NeuGroup/SynConn by:
        ``neu/syn = NeuGroup/SynConn(..., monitors=[k1, k2])``
    plot_ids : None, int, tuple, a_list
        The index of the value to plot. Defaults to ``[0]``.
    ax : None, Axes
        The axes to draw on. When omitted, the current pyplot axes are used.
    xlim : list, tuple
        The xlim.
    ylim : list, tuple
        The ylim.
    xlabel : str
        The xlabel.
    ylabel : str
        The ylabel.
    legend : str
        The prefix of legend for plot.
    title : str
        The title.
    show : bool
        Whether show the figure.
    **kwargs
        Forwarded to ``ax.plot``.

    Raises
    ------
    errors.BrainPyError
        If ``plot_ids`` is not an int, list, tuple or 1D numpy array.
    """
    # get plot_ids
    if plot_ids is None:
        plot_ids = [0]
    elif isinstance(plot_ids, int):
        plot_ids = [plot_ids]
    if not isinstance(plot_ids, (list, tuple)) and \
            not (isinstance(plot_ids, np.ndarray) and np.ndim(plot_ids) == 1):
        raise errors.BrainPyError(
            f'"plot_ids" specifies the value index to plot, it must '
            f'be a list/tuple/1D numpy.ndarray, not {type(plot_ids)}.')

    # change data
    # Convert before reshaping so that plain sequences (without a
    # ``reshape`` method) are accepted as well.
    val_matrix = np.asarray(val_matrix)
    val_matrix = val_matrix.reshape((val_matrix.shape[0], -1))
    ts = np.asarray(ts)

    # get ax
    if ax is None:
        ax = plt.gca()

    # plot
    for idx in plot_ids:
        if legend:
            ax.plot(ts, val_matrix[:, idx], label=f'{legend}-{idx}', **kwargs)
        else:
            ax.plot(ts, val_matrix[:, idx], **kwargs)

    # Decorate the axes that were drawn on. The pyplot helpers used before act
    # on the *current* axes, which need not be the ``ax`` passed by the caller.
    if legend:
        ax.legend()
    if xlim is not None:
        ax.set_xlim(xlim[0], xlim[1])
    if ylim is not None:
        ax.set_ylim(ylim[0], ylim[1])
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
    if title:
        ax.set_title(title)

    # show
    if show:
        plt.show()
def jit(func, dyn_vars=None, static_argnames=None, device=None, auto_infer=True):
    """JIT (Just-In-Time) compilation for class objects.

    This function has the same ability to Just-In-Time compile a pure function,
    but it can also JIT compile a :py:class:`brainpy.DynamicalSystem`, or a
    :py:class:`brainpy.Base` object, or a bounded method for a
    :py:class:`brainpy.Base` object.

    .. note::
      There are several notes when using JIT compilation.

      1. Avoid using scalar in a Variable, TrainVar, etc.

      For example,

      >>> import brainpy as bp
      >>> import brainpy.math as bm
      >>>
      >>> class Test(bp.Base):
      >>>   def __init__(self):
      >>>     super(Test, self).__init__()
      >>>     self.a = bm.Variable(1.)  # Avoid! DO NOT USE!
      >>>   def __call__(self, *args, **kwargs):
      >>>     self.a += 1.

      The above usage is deprecated, because it may cause several errors.
      Instead, we recommend you define the scalar value variable as:

      >>> class Test(bp.Base):
      >>>   def __init__(self):
      >>>     super(Test, self).__init__()
      >>>     self.a = bm.Variable(bm.array([1.]))  # use array to wrap a scalar is recommended
      >>>   def __call__(self, *args, **kwargs):
      >>>     self.a += 1.

      Here, a ndarray is recommended to used to update the variable ``a``.

      2. ``jit`` compilation in ``brainpy.math`` does not support `static_argnums`.
      Instead, users should use `static_argnames`, and call the jitted function with
      keywords like ``jitted_func(arg1=var1, arg2=var2)``. For example,

      >>> def f(a, b, c=1.):
      >>>   if c > 0.: return a + b
      >>>   else: return a * b
      >>>
      >>> # ERROR! https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit
      >>> bm.jit(f)(1, 2, 0)
      jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where
      concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)
      >>> # this is right
      >>> bm.jit(f, static_argnames='c')(1, 2, 0)
      DeviceArray(2, dtype=int32, weak_type=True)

    Examples
    --------

    You can JIT a :py:class:`brainpy.DynamicalSystem`

    >>> import brainpy as bp
    >>>
    >>> class LIF(bp.NeuGroup):
    >>>   pass
    >>> lif = bp.math.jit(LIF(10))

    You can JIT a :py:class:`brainpy.Base` object with ``__call__()`` implementation.

    >>> mlp = bp.layers.GRU(100, 200)
    >>> jit_mlp = bp.math.jit(mlp)

    You can also JIT a bounded method of a :py:class:`brainpy.Base` object.

    >>> class Hello(bp.Base):
    >>>   def __init__(self):
    >>>     super(Hello, self).__init__()
    >>>     self.a = bp.math.Variable(bp.math.array(10.))
    >>>     self.b = bp.math.Variable(bp.math.array(2.))
    >>>   def transform(self):
    >>>     return self.a ** self.b
    >>>
    >>> test = Hello()
    >>> bp.math.jit(test.transform)

    Further, you can JIT a normal function, just used like in JAX.

    >>> @bp.math.jit
    >>> def selu(x, alpha=1.67, lmbda=1.05):
    >>>   return lmbda * bp.math.where(x > 0, x, alpha * bp.math.exp(x) - alpha)

    Parameters
    ----------
    func : Base, function, callable
        The instance of Base or a function.
    dyn_vars : optional, dict, tuple, list, JaxArray
        These variables will be changed in the function, or needed in the computation.
    static_argnames : optional, str, list, tuple, dict
        An optional string or collection of strings specifying which named arguments to treat
        as static (compile-time constant). See the comment on ``static_argnums`` for details.
        If not provided but ``static_argnums`` is set, the default is based on calling
        ``inspect.signature(fun)`` to find corresponding named arguments.
        For a ``DynamicalSystem``, a dict maps each step name to its own setting.
    device: optional, Any
        This is an experimental feature and the API is likely to change.
        Optional, the Device the jitted function will run on. (Available devices
        can be retrieved via :py:func:`jax.devices`.) The default is inherited
        from XLA's DeviceAssignment logic and is usually to use
        ``jax.devices()[0]``.
    auto_infer : bool
        Automatical infer the dynamical variables.

    Returns
    -------
    func : Any
        A wrapped version of Base object or function, set up for just-in-time compilation.

    Raises
    ------
    ValueError
        If ``dyn_vars`` is given with an unsupported type.
    errors.BrainPyError
        If ``func`` is neither a ``DynamicalSystem`` with steps nor callable.
    """
    from brainpy.building.brainobjects import DynamicalSystem

    def _as_collector(variables):
        # Normalise the accepted ``dyn_vars`` forms into a TensorCollector.
        if isinstance(variables, JaxArray):
            return TensorCollector({'_': variables})
        if isinstance(variables, dict):
            return TensorCollector(variables)
        if isinstance(variables, (tuple, list)):
            return TensorCollector({f'_v{i}': v for i, v in enumerate(variables)})
        raise ValueError(f'"dyn_vars" must be a JaxArray, dict, tuple or list, '
                         f'but we got {type(variables)}.')

    if isinstance(func, DynamicalSystem):
        if len(func.steps):  # DynamicalSystem has step functions

            # dynamical variables
            if dyn_vars is None:
                dyn_vars = func.vars().unique() if auto_infer else TensorCollector()
            dyn_vars = _as_collector(dyn_vars)

            # static arguments by name
            if static_argnames is None:
                static_argnames = {key: None for key in func.steps.keys()}
            elif isinstance(static_argnames, str):
                static_argnames = {key: (static_argnames,) for key in func.steps.keys()}
            elif isinstance(static_argnames, (tuple, list)) and (
                    len(static_argnames) == 0 or isinstance(static_argnames[0], str)):
                # An empty sequence used to raise IndexError on the ``[0]``
                # lookup; it is now shared by every step like a non-empty one.
                static_argnames = {key: static_argnames for key in func.steps.keys()}
            assert isinstance(static_argnames, dict)

            # jit functions
            # Iterate over a snapshot because ``steps`` is modified in the loop.
            for key in list(func.steps.keys()):
                jitted_func = _make_jit(vars=dyn_vars,
                                        func=func.steps[key],
                                        static_argnames=static_argnames[key],
                                        device=device,
                                        f_name=key)
                func.steps.replace(key, jitted_func)
            return func

    if callable(func):
        if dyn_vars is not None:
            dyn_vars = _as_collector(dyn_vars)
        elif auto_infer and isinstance(func, Base):
            dyn_vars = func.vars().unique()
        elif auto_infer and hasattr(func, '__self__') and isinstance(func.__self__, Base):
            dyn_vars = func.__self__.vars().unique()
        else:
            dyn_vars = TensorCollector()

        if len(dyn_vars) == 0:  # pure function
            return jax.jit(func,
                           static_argnames=static_argnames,
                           device=device)

        else:  # Base object which implements __call__, or bounded method of Base object
            return _make_jit(vars=dyn_vars,
                             func=func,
                             static_argnames=static_argnames,
                             device=device)

    else:
        raise errors.BrainPyError(
            f'Only support instance of {Base.__name__}, or a callable '
            f'function, but we got {type(func)}.')
def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
                     tol_aux=1e-8, loss_screen=None):
    """Search for fixed points over the parameter grid and optionally plot them.

    The resolution grid of ``self.x_var`` is meshed with the resolution grid
    of every name in ``self.target_par_names``; the flattened combinations
    are passed to ``self._get_fixed_points``. Each fixed point that comes
    back is labelled by ``stability.stability_analysis`` applied to the
    matching entry of ``self.F_vmap_dfxdx`` (presumably df/dx evaluated at
    the fixed point -- confirm against that helper).

    Parameters
    ----------
    with_plot : bool
        Draw the diagram. One target parameter gives a 2D ``plt.plot``
        figure, two give a 3D scatter; any other count raises.
    show : bool
        Call ``plt.show()`` once the figure has been drawn.
    with_return : bool
        Return the computed values instead of ``None``.
    tol_aux : float
        Forwarded unchanged to ``self._get_fixed_points``.
    loss_screen : optional
        Forwarded unchanged to ``self._get_fixed_points``.

    Returns
    -------
    tuple or None
        ``(fixed_points, pars, dfxdx)`` when ``with_return`` is true:
        ``fixed_points`` exactly as returned by ``self._get_fixed_points``,
        ``pars`` as a tuple of NumPy arrays and ``dfxdx`` as a NumPy array.
        ``None`` otherwise.

    Raises
    ------
    errors.BrainPyError
        If ``with_plot`` is true and the number of target parameters is
        neither 1 nor 2.
    """
    utils.output('I am making bifurcation analysis ...')

    # Flattened grid of (candidate x, parameter values...) combinations.
    x_grid = self.resolutions[self.x_var]
    par_grids = [self.resolutions[name] for name in self.target_par_names]
    mesh = bm.meshgrid(x_grid, *par_grids)
    flat = tuple(jnp.moveaxis(m.value, 0, 1).flatten() for m in mesh)
    candidates, pars = flat[0], flat[1:]

    # The helper also hands back the parameter values paired with each
    # fixed point it kept, so ``pars`` is rebound here.
    fixed_points, _, pars = self._get_fixed_points(
        candidates, *pars,
        tol_aux=tol_aux,
        loss_screen=loss_screen,
        num_seg=len(x_grid))
    dfxdx = np.asarray(self.F_vmap_dfxdx(jnp.asarray(fixed_points), *pars))
    pars = tuple(np.asarray(p) for p in pars)

    if with_plot:
        num_par = len(self.target_pars)

        if num_par == 1:
            # Bucket the fixed points by stability type.
            groups = {kind: {'p': [], 'x': []}
                      for kind in stability.get_1d_stability_types()}
            for par, fp, slope in zip(pars[0], fixed_points, dfxdx):
                bucket = groups[stability.stability_analysis(slope)]
                bucket['p'].append(par)
                bucket['x'].append(fp)

            # 2D diagram: parameter on the x-axis, variable on the y-axis.
            plt.figure(self.x_var)
            for kind, pts in groups.items():
                if not pts['x']:
                    continue  # keep empty types out of the legend
                style = stability.plot_scheme[kind]
                plt.plot(pts['p'], pts['x'], '.', **style, label=kind)
            plt.xlabel(self.target_par_names[0])
            plt.ylabel(self.x_var)

            margin = (self.lim_scale - 1) / 2
            plt.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]],
                                    scale=margin))
            plt.ylim(*utils.rescale(self.target_vars[self.x_var], scale=margin))

            plt.legend()
            if show:
                plt.show()

        elif num_par == 2:
            # Bucket the fixed points by stability type.
            groups = {kind: {'p0': [], 'p1': [], 'x': []}
                      for kind in stability.get_1d_stability_types()}
            for par0, par1, fp, slope in zip(pars[0], pars[1], fixed_points, dfxdx):
                bucket = groups[stability.stability_analysis(slope)]
                bucket['p0'].append(par0)
                bucket['p1'].append(par1)
                bucket['x'].append(fp)

            # 3D diagram: the two parameters span the base plane.
            fig = plt.figure(self.x_var)
            ax = fig.add_subplot(projection='3d')
            for kind, pts in groups.items():
                if not pts['x']:
                    continue  # keep empty types out of the legend
                style = stability.plot_scheme[kind]
                ax.scatter(pts['p0'], pts['p1'], pts['x'], **style, label=kind)

            ax.set_xlabel(self.target_par_names[0])
            ax.set_ylabel(self.target_par_names[1])
            ax.set_zlabel(self.x_var)

            margin = (self.lim_scale - 1) / 2
            ax.set_xlim(*utils.rescale(self.target_pars[self.target_par_names[0]],
                                       scale=margin))
            ax.set_ylim(*utils.rescale(self.target_pars[self.target_par_names[1]],
                                       scale=margin))
            ax.set_zlim(*utils.rescale(self.target_vars[self.x_var], scale=margin))

            ax.grid(True)
            ax.legend()
            if show:
                plt.show()

        else:
            raise errors.BrainPyError(f'Cannot visualize co-dimension {len(self.target_pars)} '
                                      f'bifurcation.')

    if not with_return:
        return None
    return fixed_points, pars, dfxdx