Code Example #1
    def __init__(self, *arg_ds, name=None, **kwarg_ds):
        super(Sequential, self).__init__(name=name)

        self.implicit_nodes = Collector()
        # check "args"
        for ds in arg_ds:
            if not isinstance(ds, DynamicalSystem):
                raise errors.BrainPyError(
                    f'Only support {DynamicalSystem.__name__}, '
                    f'but we got {type(ds)}: {str(ds)}.')
            self.implicit_nodes[ds.name] = ds

        # check "kwargs"
        for key, ds in kwarg_ds.items():
            if not isinstance(ds, DynamicalSystem):
                raise errors.BrainPyError(
                    f'Only support {DynamicalSystem.__name__}, '
                    f'but we got {type(ds)}: {str(ds)}.')
            self.implicit_nodes[key] = ds

        # all update functions
        self._return_kwargs = [
            'kwargs' in inspect.signature(ds.update).parameters.keys()
            for ds in self.implicit_nodes.values()
        ]
Code Example #2
    def nodes(self, method='absolute', _paths=None):
        """Collect all children nodes.

    Parameters
    ----------
    method : str
      The method to access the nodes.
    _paths : set, Optional
      The data structure to solve the circular reference.

    Returns
    -------
    gather : Collector
      The collection of (path, node) pairs.
    """
        if _paths is None:
            _paths = set()
        gather = Collector()
        if method == 'absolute':
            nodes = []
            for k, v in self.__dict__.items():
                if isinstance(v, Base):
                    path = (id(self), id(v))
                    if path not in _paths:
                        _paths.add(path)
                        gather[v.name] = v
                        nodes.append(v)
            for node in self.implicit_nodes.values():
                path = (id(self), id(node))
                if path not in _paths:
                    _paths.add(path)
                    gather[node.name] = node
                    nodes.append(node)
            for v in nodes:
                gather.update(v.nodes(method=method, _paths=_paths))
            gather[self.name] = self

        elif method == 'relative':
            nodes = []
            gather[''] = self
            for k, v in self.__dict__.items():
                if isinstance(v, Base):
                    path = (id(self), id(v))
                    if path not in _paths:
                        _paths.add(path)
                        gather[k] = v
                        nodes.append((k, v))
            for key, node in self.implicit_nodes.items():
                path = (id(self), id(node))
                if path not in _paths:
                    _paths.add(path)
                    gather[key] = node
                    nodes.append((key, node))
            for k1, v1 in nodes:
                for k2, v2 in v1.nodes(method=method, _paths=_paths).items():
                    if k2: gather[f'{k1}.{k2}'] = v2

        else:
            raise ValueError(f'No support for the method of "{method}".')
        return gather
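
This traversal is easier to see on a toy object graph. The sketch below is a self-contained, simplified re-implementation (plain dicts instead of Collector, no implicit_nodes) that keeps the two key ideas: 'absolute' keys nodes by their global .name while 'relative' keys them by the attribute path, and the shared set of (id(parent), id(child)) pairs stops circular references from recursing forever.

class Node:
    _count = 0

    def __init__(self):
        Node._count += 1
        self.name = f'Node{Node._count}'

    def nodes(self, method='absolute', _paths=None):
        _paths = set() if _paths is None else _paths
        gather, children = {}, []
        for attr, v in self.__dict__.items():
            if isinstance(v, Node) and (id(self), id(v)) not in _paths:
                _paths.add((id(self), id(v)))
                gather[v.name if method == 'absolute' else attr] = v
                children.append((attr, v))
        for attr, child in children:
            for k, v in child.nodes(method, _paths).items():
                gather[k if method == 'absolute' else f'{attr}.{k}'] = v
        if method == 'absolute':
            gather[self.name] = self
        return gather


net = Node()               # Node1
net.pre = Node()           # Node2
net.post = Node()          # Node3
net.pre.post = net.post    # a shared child
net.post.back = net        # a circular reference
print(sorted(net.nodes('absolute')))  # ['Node1', 'Node2', 'Node3']
print(sorted(net.nodes('relative')))  # ['post', 'pre', 'pre.post', 'pre.post.back']
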
Code Example #3
    def __init__(self, name=None):
        # check whether the object has a unique name.
        self.name = self.unique_name(name=name)
        naming.check_name_uniqueness(name=self.name, obj=self)

        # Used to wrap the implicit variables
        # which cannot be accessed by self.xxx
        self.implicit_vars = TensorCollector()

        # Used to wrap the implicit children nodes
        # which cannot be accessed by self.xxx
        self.implicit_nodes = Collector()
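
The naming helpers themselves are not part of these snippets. Below is a minimal sketch of what they plausibly do, consistent with the auto-generated names such as "HH0" seen in the docstring of Code Example #10: one counter per type name, plus a registry that rejects reusing an explicit name for a different object. The helper names mirror naming.get_unique_name and naming.check_name_uniqueness, but the bodies are assumptions, not the library's implementation.

import itertools
from collections import defaultdict

_counters = defaultdict(itertools.count)   # one counter per type name
_registry = {}                             # explicit name -> object

def get_unique_name(type_):
    return f'{type_}{next(_counters[type_])}'

def check_name_uniqueness(name, obj):
    if name in _registry and _registry[name] is not obj:
        raise ValueError(f'the name "{name}" is already used by another object')
    _registry[name] = obj

print(get_unique_name('HH'), get_unique_name('HH'))  # HH0 HH1
check_name_uniqueness('net', object())               # first use: fine
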
Code Example #4
File: base.py (Project: PKU-NIP-Lab/BrainPy)
    def __init__(self, steps=None, name=None):
        super(DynamicalSystem, self).__init__(name=name)

        # step functions
        if steps is None:
            steps = ('update', )
        self.steps = Collector()
        if isinstance(steps, tuple):
            for step in steps:
                if isinstance(step, str):
                    self.steps[step] = getattr(self, step)
                elif callable(step):
                    self.steps[step.__name__] = step
                else:
                    raise ModelBuildError(
                        _error_msg.format(step.__class__, str(step)))
        elif isinstance(steps, dict):
            for key, step in steps.items():
                if callable(step):
                    self.steps[key] = step
                else:
                    raise ModelBuildError(
                        _error_msg.format(steps.__class__, str(steps)))
        else:
            raise ModelBuildError(
                _error_msg.format(steps.__class__, str(steps)))
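
A runnable toy of the normalization above, outside of BrainPy: a steps specification may be a tuple of method names and/or callables, or a dict of name-to-callable pairs, and either way it ends up as an ordered mapping from step name to callable (a plain dict stands in for Collector here).

class Toy:
    def update(self, t):
        return t + 1


def normalize_steps(obj, steps=('update',)):
    out = {}
    if isinstance(steps, tuple):
        for step in steps:
            if isinstance(step, str):
                out[step] = getattr(obj, step)   # look the method up by name
            elif callable(step):
                out[step.__name__] = step        # keep the callable's own name
            else:
                raise TypeError(f'unsupported step: {step!r}')
    elif isinstance(steps, dict):
        out.update(steps)
    else:
        raise TypeError(f'unsupported steps spec: {steps!r}')
    return out


print(list(normalize_steps(Toy())))                          # ['update']
print(list(normalize_steps(Toy(), {'decay': lambda t: t})))  # ['decay']
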
Code Example #5
File: base.py (Project: PKU-NIP-Lab/BrainPy)
    def __init__(self, *ds_tuple, steps=None, name=None, **ds_dict):
        # integrative step function
        if steps is None:
            steps = ('update', )
        super(Container, self).__init__(steps=steps, name=name)

        # children dynamical systems
        self.implicit_nodes = Collector()
        for ds in ds_tuple:
            if not isinstance(ds, DynamicalSystem):
                raise ModelBuildError(
                    f'{self.__class__.__name__} receives instances of '
                    f'DynamicalSystem, however, we got {type(ds)}.')
            if ds.name in self.implicit_nodes:
                raise ValueError(
                    f'{ds.name} has been paired with {ds}. Please change a unique name.'
                )
        self.register_implicit_nodes({node.name: node for node in ds_tuple})
        for key, ds in ds_dict.items():
            if not isinstance(ds, DynamicalSystem):
                raise ModelBuildError(
                    f'{self.__class__.__name__} receives instances of '
                    f'DynamicalSystem, however, we got {type(ds)}.')
            if key in self.implicit_nodes:
                raise ValueError(
                    f'{key} has been paired with {ds}. Please change a unique name.'
                )
        self.register_implicit_nodes(ds_dict)
Code Example #6
File: ast2numba.py (Project: PKU-NIP-Lab/BrainPy)
    def __init__(self, loop_values, iter_name, show_code=False, **jit_setting):
        self.success = False

        # targets
        self.loop_values = loop_values
        self.iter_name = iter_name

        # setting
        self.show_code = show_code
        self.jit_setting = jit_setting

        # results
        self.arguments = set()
        self.arg2call = dict()
        self.nodes = Collector()
        self.code_scope = dict()
Code Example #7
File: ast2numba.py (Project: PKU-NIP-Lab/BrainPy)
def _jit_func(obj_or_fun, show_code=False, **jit_setting):
    if callable(obj_or_fun):
        # integrator
        if isinstance(obj_or_fun, Integrator):
            return _jit_intg(obj_or_fun, show_code=show_code, **jit_setting)

        # bounded method
        elif hasattr(obj_or_fun, '__self__') and isinstance(
                obj_or_fun.__self__, Base):
            return _jit_cls_func(obj_or_fun,
                                 host=obj_or_fun.__self__,
                                 show_code=show_code,
                                 **jit_setting)

        # wrapped function
        elif isinstance(obj_or_fun, Function):
            return _jit_Function(obj_or_fun,
                                 show_code=show_code,
                                 **jit_setting)

        # base class function
        elif isinstance(obj_or_fun, Base):
            return _jit_cls_func(obj_or_fun.__call__,
                                 host=obj_or_fun,
                                 show_code=show_code,
                                 **jit_setting)

        else:
            # native function
            if not isinstance(obj_or_fun, Dispatcher):
                if inspector.inspect_function(
                        obj_or_fun)['numba_type'] is None:
                    f = numba.jit(obj_or_fun, **jit_setting)
                    return dict(func=f,
                                arguments=set(),
                                arg2call=Collector(),
                                nodes=Collector())
            # numba function or innate supported function
            return dict(func=obj_or_fun,
                        arguments=set(),
                        arg2call=Collector(),
                        nodes=Collector())

    else:
        raise ValueError
Code Example #8
        def integral_func(*args, **kwargs):
            # format arguments
            params_in = Collector()
            for i, arg in enumerate(args):
                params_in[all_vps[i]] = arg
            params_in.update(kwargs)
            if 'dt' not in params_in:
                params_in['dt'] = math.get_dt()

            # call integrals
            results = []
            for i, int_fun in enumerate(integrals):
                _key = arg_names[i][0]
                r = int_fun(
                    params_in[_key], **{
                        arg: params_in[arg]
                        for arg in arg_names[i][1:] if arg in params_in
                    })
                results.append(r)
            return results if isinstance(self.f,
                                         joint_eq.JointEq) else results[0]
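
A self-contained toy of the packing logic above: positional arguments are zipped onto the joint list of variables/parameters (all_vps), 'dt' falls back to a default, and each sub-equation only receives the names it declares (arg_names). The two lambda "integrals" are stand-ins; the real ones come from the ODE/SDE builders.

def make_joint(integrals, arg_names, all_vps, default_dt=0.1):
    def integral_func(*args, **kwargs):
        params = dict(zip(all_vps, args))      # map positional args onto names
        params.update(kwargs)
        params.setdefault('dt', default_dt)    # same role as math.get_dt()
        return [f(params[names[0]],
                  **{a: params[a] for a in names[1:] if a in params})
                for f, names in zip(integrals, arg_names)]
    return integral_func


# two toy "integrals": dv depends on (v, w, t, dt); dw only on (w, t, dt)
dv = lambda v, w=0., t=0., dt=0.1: v + dt * (-v + w)
dw = lambda w, t=0., dt=0.1: w + dt * (-w)

joint = make_joint([dv, dw],
                   arg_names=[['v', 'w', 't', 'dt'], ['w', 't', 'dt']],
                   all_vps=['v', 'w', 't'])
print([round(x, 3) for x in joint(1.0, 0.5, 0.0)])   # [0.95, 0.45]
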
Code Example #9
    def ints(self, method='absolute'):
        """Collect all integrators in this node and the children nodes.

    Parameters
    ----------
    method : str
      The method to access the integrators.

    Returns
    -------
    collector : Collector
      The collection of (path, integrator) pairs.
    """
        global Integrator
        if Integrator is None: from brainpy.integrators.base import Integrator

        nodes = self.nodes(method=method)
        gather = Collector()
        for node_path, node in nodes.items():
            for k in dir(node):
                v = getattr(node, k)
                if isinstance(v, Integrator):
                    gather[f'{node_path}.{k}' if node_path else k] = v
        return gather
Code Example #10
File: ast2numba.py (Project: PKU-NIP-Lab/BrainPy)
def _jit_cls_func(f, code=None, host=None, show_code=False, **jit_setting):
    """JIT a class function.

  Examples
  --------

  Example 1: the model has static parameters.

  >>> import brainpy as bp
  >>>
  >>> class HH(bp.NeuGroup):
  >>>     def __init__(self, size, ENa=50., EK=-77., EL=-54.387, C=1.0,
  >>>                  gNa=120., gK=36., gL=0.03, V_th=20., **kwargs):
  >>>       super(HH, self).__init__(size=size, **kwargs)
  >>>       # parameters
  >>>       self.ENa = ENa
  >>>       self.EK = EK
  >>>       self.EL = EL
  >>>       self.C = C
  >>>       self.gNa = gNa
  >>>       self.gK = gK
  >>>       self.gL = gL
  >>>       self.V_th = V_th
  >>>
  >>>     def derivaitve(self, V, m, h, n, t, Iext):
  >>>       alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))
  >>>       beta = 4.0 * np.exp(-(V + 65) / 18)
  >>>       dmdt = alpha * (1 - m) - beta * m
  >>>
  >>>       alpha = 0.07 * np.exp(-(V + 65) / 20.)
  >>>       beta = 1 / (1 + np.exp(-(V + 35) / 10))
  >>>       dhdt = alpha * (1 - h) - beta * h
  >>>
  >>>       alpha = 0.01 * (V + 55) / (1 - np.exp(-(V + 55) / 10))
  >>>       beta = 0.125 * np.exp(-(V + 65) / 80)
  >>>       dndt = alpha * (1 - n) - beta * n
  >>>
  >>>       I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
  >>>       I_K = (self.gK * n ** 4.0) * (V - self.EK)
  >>>       I_leak = self.gL * (V - self.EL)
  >>>       dVdt = (- I_Na - I_K - I_leak + Iext) / self.C
  >>>
  >>>       return dVdt, dmdt, dhdt, dndt
  >>>
  >>> r = _jit_cls_func(HH(10).derivaitve, show_code=True)

  The recompiled function:
  -------------------------

  def derivaitve(V, m, h, n, t, Iext):
      alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))
      beta = 4.0 * np.exp(-(V + 65) / 18)
      dmdt = alpha * (1 - m) - beta * m
      alpha = 0.07 * np.exp(-(V + 65) / 20.0)
      beta = 1 / (1 + np.exp(-(V + 35) / 10))
      dhdt = alpha * (1 - h) - beta * h
      alpha = 0.01 * (V + 55) / (1 - np.exp(-(V + 55) / 10))
      beta = 0.125 * np.exp(-(V + 65) / 80)
      dndt = alpha * (1 - n) - beta * n
      I_Na = HH0_gNa * m ** 3.0 * h * (V - HH0_ENa)
      I_K = HH0_gK * n ** 4.0 * (V - HH0_EK)
      I_leak = HH0_gL * (V - HH0_EL)
      dVdt = (-I_Na - I_K - I_leak + Iext) / HH0_C
      return dVdt, dmdt, dhdt, dndt

  The namespace of the above function:
  {'HH0_C': 1.0,
   'HH0_EK': -77.0,
   'HH0_EL': -54.387,
   'HH0_ENa': 50.0,
   'HH0_gK': 36.0,
   'HH0_gL': 0.03,
   'HH0_gNa': 120.0,
   'bp': <module 'brainpy' from 'D:\\codes\\Projects\\BrainPy\\brainpy\\__init__.py'>}
  >>> r['func']
  CPUDispatcher(<function derivaitve at 0x0000020DF1647DC0>)
  >>> r['arguments']
  set()
  >>> r['arg2call']
  {}
  >>> r['nodes']
  {'HH0': <__main__.<locals>.HH object at 0x0000020DF1623910>}


  Example 2: the model has dynamical variables.

  >>> import brainpy as bp
  >>>
  >>> class HH(bp.NeuGroup):
  >>>     def __init__(self, size, ENa=50., EK=-77., EL=-54.387, C=1.0,
  >>>                  gNa=120., gK=36., gL=0.03, V_th=20., **kwargs):
  >>>       super(HH, self).__init__(size=size, **kwargs)
  >>>       # parameters
  >>>       self.ENa = ENa
  >>>       self.EK = EK
  >>>       self.EL = EL
  >>>       self.C = C
  >>>       self.gNa = gNa
  >>>       self.gK = gK
  >>>       self.gL = gL
  >>>       self.V_th = V_th
  >>>       self.input = bp.math.numpy.Variable(np.zeros(size))
  >>>
  >>>     def derivaitve(self, V, m, h, n, t):
  >>>       alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))
  >>>       beta = 4.0 * np.exp(-(V + 65) / 18)
  >>>       dmdt = alpha * (1 - m) - beta * m
  >>>
  >>>       alpha = 0.07 * np.exp(-(V + 65) / 20.)
  >>>       beta = 1 / (1 + np.exp(-(V + 35) / 10))
  >>>       dhdt = alpha * (1 - h) - beta * h
  >>>
  >>>       alpha = 0.01 * (V + 55) / (1 - np.exp(-(V + 55) / 10))
  >>>       beta = 0.125 * np.exp(-(V + 65) / 80)
  >>>       dndt = alpha * (1 - n) - beta * n
  >>>
  >>>       I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
  >>>       I_K = (self.gK * n ** 4.0) * (V - self.EK)
  >>>       I_leak = self.gL * (V - self.EL)
  >>>       dVdt = (- I_Na - I_K - I_leak + self.input) / self.C
  >>>
  >>>       return dVdt, dmdt, dhdt, dndt
  >>>
  >>> r = _jit_cls_func(HH(10).derivaitve, show_code=True)

  The recompiled function:
  -------------------------

  def derivaitve(V, m, h, n, t, HH0_input=None):
      alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))
      beta = 4.0 * np.exp(-(V + 65) / 18)
      dmdt = alpha * (1 - m) - beta * m
      alpha = 0.07 * np.exp(-(V + 65) / 20.0)
      beta = 1 / (1 + np.exp(-(V + 35) / 10))
      dhdt = alpha * (1 - h) - beta * h
      alpha = 0.01 * (V + 55) / (1 - np.exp(-(V + 55) / 10))
      beta = 0.125 * np.exp(-(V + 65) / 80)
      dndt = alpha * (1 - n) - beta * n
      I_Na = HH0_gNa * m ** 3.0 * h * (V - HH0_ENa)
      I_K = HH0_gK * n ** 4.0 * (V - HH0_EK)
      I_leak = HH0_gL * (V - HH0_EL)
      dVdt = (-I_Na - I_K - I_leak + HH0_input) / HH0_C
      return dVdt, dmdt, dhdt, dndt

  The namespace of the above function:
  {'HH0_C': 1.0,
   'HH0_EK': -77.0,
   'HH0_EL': -54.387,
   'HH0_ENa': 50.0,
   'HH0_gK': 36.0,
   'HH0_gL': 0.03,
   'HH0_gNa': 120.0,
   'bp': <module 'brainpy' from 'D:\\codes\\Projects\\BrainPy\\brainpy\\__init__.py'>}
  >>> r['func']
  CPUDispatcher(<function derivaitve at 0x0000020DF1647DC0>)
  >>> r['arguments']
  {'HH0_input'}
  >>> r['arg2call']
  {'HH0_input': 'HH0.input.value'}
  >>> r['nodes']
  {'HH0': <__main__.<locals>.HH object at 0x00000219AE495E80>}

  Parameters
  ----------
  f : callable
    The class function (usually a bound method) to JIT.
  code : str, optional
    The source code of "f". If None, it is read with inspect.getsource.
  host : Base, optional
    The object hosting "f". Defaults to f.__self__.
  show_code : bool
    Whether to print the recompiled code and its namespace.
  jit_setting : dict
    Extra keyword settings passed to numba.jit.

  Returns
  -------
  dict
    The compiled "func", the rewritten "code" and its "code_scope", together
    with the collected "arguments", "arg2call" and "nodes".
  """
    host = (host or f.__self__)

    # data to return
    arguments = set()
    arg2call = dict()
    nodes = Collector()
    nodes[host.name] = host

    # code
    code = (code or tools.deindent(inspect.getsource(f)).strip())
    # function name
    func_name = f.__name__
    # code scope
    closure_vars = inspect.getclosurevars(f)
    code_scope = dict(closure_vars.nonlocals)
    code_scope.update(closure_vars.globals)
    # analyze class function
    code, _arguments, _arg2call, _nodes, _code_scope = _analyze_cls_func(
        host=host, code=code, show_code=show_code, **jit_setting)
    arguments.update(_arguments)
    arg2call.update(_arg2call)
    nodes.update(_nodes)
    code_scope.update(_code_scope)

    # compile new function
    # code, _scope = _add_try_except(code)
    # code_scope.update(_scope)
    code_scope_to_compile = code_scope.copy()
    if show_code:
        _show_compiled_codes(code, code_scope)
    exec(compile(code, '', 'exec'), code_scope_to_compile)
    func = code_scope_to_compile[func_name]
    func = numba.jit(func, **jit_setting)

    # returns
    return dict(func=func,
                code=code,
                code_scope=code_scope,
                arguments=arguments,
                arg2call=arg2call,
                nodes=nodes)
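
The docstring above shows the effect of the rewrite: self.gNa becomes a namespace constant HH0_gNa, and a Variable such as self.input becomes an extra argument HH0_input. The stdlib-only sketch below reproduces just the constant-lifting step on a toy class, without numba, Variables, or the AST machinery; Decay and lift are made up for illustration, and it has to run from a file because inspect.getsource needs the class's source.

import inspect
import re
import textwrap


class Decay:
    def __init__(self, tau=10.):
        self.tau = tau
        self.name = 'Decay0'          # mirrors Base.unique_name

    def derivative(self, V, t):
        return -V / self.tau


def lift(bound_method):
    host = bound_method.__self__
    src = textwrap.dedent(inspect.getsource(bound_method))
    src = re.sub(r'\(\s*self\s*,\s*', '(', src, count=1)   # drop the "self" arg
    scope = {}
    for ref in set(re.findall(r'\bself\.([A-Za-z_]\w*)\b', src)):
        new_name = f'{host.name}_{ref}'
        scope[new_name] = getattr(host, ref)                # lift the constant
        src = re.sub(rf'\bself\.{ref}\b', new_name, src)    # rewrite the access
    exec(compile(src, '<lifted>', 'exec'), scope)
    return scope[bound_method.__name__], src, scope


func, code, scope = lift(Decay().derivative)
print(func(2.0, 0.0))                                       # -0.2
print(sorted(k for k in scope if k.startswith('Decay0_')))  # ['Decay0_tau']
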
Code Example #11
File: runner.py (Project: PKU-NIP-Lab/BrainPy)
    def __init__(self,
                 target,
                 monitors=None,
                 inits=None,
                 args=None,
                 dyn_args=None,
                 dyn_vars=None,
                 jit=True,
                 dt=None,
                 numpy_mon_after_run=True,
                 progress_bar=True):
        super(IntegratorRunner, self).__init__()

        # parameters
        dt = math.get_dt() if dt is None else dt
        if not isinstance(dt, (int, float)):
            raise RunningError(f'"dt" must be scalar, but got {dt}')
        self.dt = dt
        self.jit = jit
        self.numpy_mon_after_run = numpy_mon_after_run
        self._pbar = None  # progress bar
        self.progress_bar = progress_bar

        # target
        if not isinstance(target, Integrator):
            raise RunningError(
                f'"target" must be an instance of {Integrator.__name__}, '
                f'but we got {type(target)}: {target}')
        self.target = target

        # arguments of the integral function
        self._static_args = Collector()
        if args is not None:
            assert isinstance(
                args, dict
            ), f'"args" must be a dict, but we get {type(args)}: {args}'
            self._static_args.update(args)
        self._dyn_args = Collector()
        if dyn_args is not None:
            assert isinstance(
                dyn_args, dict
            ), f'"dyn_args" must be a dict, but we get {type(dyn_args)}: {dyn_args}'
            sizes = np.unique([len(v) for v in dyn_args.values()])
            num_size = len(sizes)
            if num_size != 1:
                raise RunningError(
                    f'All values in "dyn_args" should have the same length. But we got '
                    f'{num_size}: {sizes}')
            self._dyn_args.update(dyn_args)

        # dynamical changed variables
        if dyn_vars is None:
            dyn_vars = self.target.vars().unique()
        if isinstance(dyn_vars, (list, tuple)):
            dyn_vars = {f'_v{i}': v for i, v in enumerate(dyn_vars)}
        if not isinstance(dyn_vars, dict):
            raise RunningError(
                f'"dyn_vars" must be a dict, but we got {type(dyn_vars)}')
        self.dyn_vars = TensorCollector(dyn_vars)

        # monitors
        if monitors is None:
            self.mon = Monitor(target=self, variables=[])
        elif isinstance(monitors, (list, tuple, dict)):
            self.mon = Monitor(target=self, variables=monitors)
        elif isinstance(monitors, Monitor):
            self.mon = monitors
            self.mon.target = self
        else:
            raise MonitorError(f'"monitors" only supports list/tuple/dict/ '
                               f'instance of Monitor, not {type(monitors)}.')
        self.mon.build()  # build the monitor
        for k in self.mon.item_names:
            if k not in self.target.variables:
                raise MonitorError(
                    f'Variable "{k}" to monitor is not defined in the integrator {self.target}.'
                )

        # start simulation time
        self._start_t = None

        # Variables
        if inits is not None:
            if isinstance(inits, (list, tuple)):
                assert len(self.target.variables) == len(inits)
                inits = {
                    k: inits[i]
                    for i, k in enumerate(self.target.variables)
                }
            assert isinstance(inits, dict)
            sizes = np.unique([np.size(v) for v in list(inits.values())])
            max_size = np.max(sizes)
        else:
            max_size = 1
            inits = dict()
        self.variables = TensorCollector({
            v: math.Variable(math.zeros(max_size))
            for v in self.target.variables
        })
        for k in inits.keys():
            self.variables[k][:] = inits[k]
        self.dyn_vars.update(self.variables)
        if len(self._dyn_args) > 0:
            self.idx = math.Variable(math.zeros(1, dtype=math.int_))
            self.dyn_vars['_idx'] = self.idx

        # build the update step
        if jit:
            _loop_func = math.make_loop(
                self._step,
                dyn_vars=self.dyn_vars,
                out_vars={k: self.variables[k]
                          for k in self.mon.item_names})
        else:

            def _loop_func(t_and_dt):
                out_vars = {k: [] for k in self.mon.item_names}
                times, dts = t_and_dt
                for i in range(len(times)):
                    _t = times[i]
                    _dt = dts[i]
                    self._step([_t, _dt])
                    for k in self.mon.item_names:
                        out_vars[k].append(
                            math.as_device_array(self.variables[k]))
                out_vars = {
                    k: math.asarray(out_vars[k])
                    for k in self.mon.item_names
                }
                return out_vars

        self.step_func = _loop_func
Code Example #12
class Base(object):
    """The Base class for whole BrainPy ecosystem.

  The subclass of Base includes:

  - ``DynamicalSystem`` in *brainpy.simulation.brainobjects.base.py*
  - ``Integrator`` in *brainpy.integrators.base.py*
  - ``Function`` in *brainpy.base.function.py*
  - ``AutoGrad`` in *brainpy.math.jax.autograd.py*
  - ``Optimizer`` in *brainpy.math.jax.optimizers.py*
  - ``Scheduler`` in *brainpy.math.jax.optimizers.py*

  """
    def __init__(self, name=None):
        # check whether the object has a unique name.
        self.name = self.unique_name(name=name)
        naming.check_name_uniqueness(name=self.name, obj=self)

        # Used to wrap the implicit variables
        # which cannot be accessed by self.xxx
        self.implicit_vars = TensorCollector()

        # Used to wrap the implicit children nodes
        # which cannot be accessed by self.xxx
        self.implicit_nodes = Collector()

    def register_implicit_vars(self, variables):
        assert isinstance(variables, dict)
        self.implicit_vars.update(variables)

    def register_implicit_nodes(self, nodes):
        assert isinstance(nodes, dict)
        self.implicit_nodes.update(nodes)

    def vars(self, method='absolute'):
        """Collect all variables in this node and the children nodes.

    Parameters
    ----------
    method : str
      The method to access the variables.

    Returns
    -------
    gather : TensorCollector
      The collection of (path, variable) pairs.
    """
        global math
        if math is None: from brainpy import math

        nodes = self.nodes(method=method)
        gather = TensorCollector()
        for node_path, node in nodes.items():
            for k in dir(node):
                v = getattr(node, k)
                if isinstance(v, math.Variable):
                    gather[f'{node_path}.{k}' if node_path else k] = v
            gather.update(
                {f'{node_path}.{k}': v
                 for k, v in node.implicit_vars.items()})
        return gather

    def train_vars(self, method='absolute'):
        """The shortcut for retrieving all trainable variables.

    Parameters
    ----------
    method : str
      The method to access the variables. Support 'absolute' and 'relative'.

    Returns
    -------
    gather : TensorCollector
      The collection of (path, trainable variable) pairs.
    """
        global math
        if math is None:
            from brainpy import math
        return self.vars(method=method).subset(math.TrainVar)

    def nodes(self, method='absolute', _paths=None):
        """Collect all children nodes.

    Parameters
    ----------
    method : str
      The method to access the nodes.
    _paths : set, Optional
      The data structure to solve the circular reference.

    Returns
    -------
    gather : Collector
      The collection of (path, node) pairs.
    """
        if _paths is None:
            _paths = set()
        gather = Collector()
        if method == 'absolute':
            nodes = []
            for k, v in self.__dict__.items():
                if isinstance(v, Base):
                    path = (id(self), id(v))
                    if path not in _paths:
                        _paths.add(path)
                        gather[v.name] = v
                        nodes.append(v)
            for node in self.implicit_nodes.values():
                path = (id(self), id(node))
                if path not in _paths:
                    _paths.add(path)
                    gather[node.name] = node
                    nodes.append(node)
            for v in nodes:
                gather.update(v.nodes(method=method, _paths=_paths))
            gather[self.name] = self

        elif method == 'relative':
            nodes = []
            gather[''] = self
            for k, v in self.__dict__.items():
                if isinstance(v, Base):
                    path = (id(self), id(v))
                    if path not in _paths:
                        _paths.add(path)
                        gather[k] = v
                        nodes.append((k, v))
            for key, node in self.implicit_nodes.items():
                path = (id(self), id(node))
                if path not in _paths:
                    _paths.add(path)
                    gather[key] = node
                    nodes.append((key, node))
            for k1, v1 in nodes:
                for k2, v2 in v1.nodes(method=method, _paths=_paths).items():
                    if k2: gather[f'{k1}.{k2}'] = v2

        else:
            raise ValueError(f'No support for the method of "{method}".')
        return gather

    def ints(self, method='absolute'):
        """Collect all integrators in this node and the children nodes.

    Parameters
    ----------
    method : str
      The method to access the integrators.

    Returns
    -------
    collector : Collector
      The collection of (path, integrator) pairs.
    """
        global Integrator
        if Integrator is None: from brainpy.integrators.base import Integrator

        nodes = self.nodes(method=method)
        gather = Collector()
        for node_path, node in nodes.items():
            for k in dir(node):
                v = getattr(node, k)
                if isinstance(v, Integrator):
                    gather[f'{node_path}.{k}' if node_path else k] = v
        return gather

    def unique_name(self, name=None, type_=None):
        """Get the unique name for this object.

    Parameters
    ----------
    name : str, optional
      The expected name. If None, the default unique name will be returned.
      Otherwise, the provided name will be checked to guarantee its uniqueness.
    type_ : str, optional
      The name of this class, used for object naming.

    Returns
    -------
    name : str
      The unique name for this object.
    """
        if name is None:
            if type_ is None:
                return naming.get_unique_name(type_=self.__class__.__name__)
            else:
                return naming.get_unique_name(type_=type_)
        else:
            naming.check_name_uniqueness(name=name, obj=self)
            return name

    def load_states(self, filename, verbose=False, check_missing=False):
        """Load the model states.

    Parameters
    ----------
    filename : str
      The filename which stores the model states.
    verbose: bool
    check_missing: bool
    """
        if not os.path.exists(filename):
            raise errors.BrainPyError(f'Cannot find the file path: {filename}')
        elif filename.endswith('.hdf5') or filename.endswith('.h5'):
            io.load_h5(filename,
                       target=self,
                       verbose=verbose,
                       check=check_missing)
        elif filename.endswith('.pkl'):
            io.load_pkl(filename,
                        target=self,
                        verbose=verbose,
                        check=check_missing)
        elif filename.endswith('.npz'):
            io.load_npz(filename,
                        target=self,
                        verbose=verbose,
                        check=check_missing)
        elif filename.endswith('.mat'):
            io.load_mat(filename,
                        target=self,
                        verbose=verbose,
                        check=check_missing)
        else:
            raise errors.BrainPyError(
                f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}'
            )

    def save_states(self, filename, all_vars=None, **setting):
        """Save the model states.

    Parameters
    ----------
    filename : str
      The file name in which to store the model states.
    all_vars: optional, dict, TensorCollector
    """
        if all_vars is None:
            all_vars = self.vars(method='relative').unique()

        if filename.endswith('.hdf5') or filename.endswith('.h5'):
            io.save_h5(filename, all_vars=all_vars)
        elif filename.endswith('.pkl'):
            io.save_pkl(filename, all_vars=all_vars)
        elif filename.endswith('.npz'):
            io.save_npz(filename, all_vars=all_vars, **setting)
        elif filename.endswith('.mat'):
            io.save_mat(filename, all_vars=all_vars)
        else:
            raise errors.BrainPyError(
                f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}'
            )
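
The variable and integrator collectors above share one scanning pattern: walk every collected node, inspect all of its attributes with dir(), and keep those of the wanted type under a '<node path>.<attribute>' key. A runnable toy of just that scan follows; Var and Neuron are stand-ins for brainpy.math.Variable and a model class.

class Var:                       # stands in for brainpy.math.Variable
    def __init__(self, value):
        self.value = value


class Neuron:
    def __init__(self):
        self.V = Var(0.0)        # collected
        self.tau = 10.0          # ignored: a plain constant


def collect(nodes, kind):
    gather = {}
    for path, node in nodes.items():
        for k in dir(node):
            v = getattr(node, k)
            if isinstance(v, kind):
                gather[f'{path}.{k}' if path else k] = v
    return gather


print(list(collect({'HH0': Neuron()}, Var)))   # ['HH0.V']
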
Code Example #13
File: ast2numba.py (Project: PKU-NIP-Lab/BrainPy)
    def visit_For(self, node):
        iter_ = tools.ast2code(ast.fix_missing_locations(node.iter))

        if iter_.strip() == self.iter_name:
            data_to_replace = Collector()
            final_node = ast.Module(body=[])
            self.success = True

            # target
            if not isinstance(node.target, ast.Name):
                raise errors.BrainPyError(
                    f'Only support scalar iter, like "for x in xxxx:", not "for '
                    f'{tools.ast2code(ast.fix_missing_locations(node.target))} '
                    f'in {iter_}:')
            target = node.target.id

            # for loop values
            for i, value in enumerate(self.loop_values):
                # module and code
                module = ast.Module(body=deepcopy(node).body)
                code = tools.ast2code(module)

                if isinstance(value, Base):  # transform Base objects
                    r = _analyze_cls_func_body(host=value,
                                               self_name=target,
                                               code=code,
                                               tree=module,
                                               show_code=self.show_code,
                                               **self.jit_setting)

                    new_code, arguments, arg2call, nodes, code_scope = r
                    self.arguments.update(arguments)
                    self.arg2call.update(arg2call)
                    self.nodes.update(nodes)
                    self.code_scope.update(code_scope)

                    final_node.body.extend(ast.parse(new_code).body)

                elif callable(value):  # transform functions
                    r = _jit_func(obj_or_fun=value,
                                  show_code=self.show_code,
                                  **self.jit_setting)
                    tree = _replace_func_call_by_tree(
                        deepcopy(module),
                        func_call=target,
                        arg_to_append=r['arguments'],
                        new_func_name=f'{target}_{i}')

                    # update import parameters
                    self.arguments.update(r['arguments'])
                    self.arg2call.update(r['arg2call'])
                    self.nodes.update(r['nodes'])

                    # replace the data
                    if isinstance(value, Base):
                        host = value
                        replace_name = f'{host.name}_{target}'
                    elif hasattr(value, '__self__') and isinstance(
                            value.__self__, Base):
                        host = value.__self__
                        replace_name = f'{host.name}_{target}'
                    else:
                        replace_name = f'{target}_{i}'
                    self.code_scope[replace_name] = r['func']
                    data_to_replace[f'{target}_{i}'] = replace_name

                    final_node.body.extend(tree.body)

                else:
                    raise errors.BrainPyError(
                        f'Only support JIT an iterable objects of function '
                        f'or Base object, but we got:\n\n {value}')

            # replace words
            final_code = tools.ast2code(final_node)
            final_code = tools.word_replace(final_code,
                                            data_to_replace,
                                            exclude_dot=True)
            final_node = ast.parse(final_code)

        else:
            final_node = node

        self.generic_visit(final_node)
        return final_node
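
The visitor above unrolls a for-loop whose iterable is known at JIT time, emitting one specialized copy of the body per element and renaming the loop variable so each copy can be bound to its own compiled callee. The self-contained sketch below reproduces only that unrolling skeleton with ast (Python 3.9+ for ast.unparse); it skips the Base/function-specific rewriting and the word replacement that the real pass performs.

import ast


class UnrollFor(ast.NodeTransformer):
    """Unroll `for x in ITER_NAME:` into one copy of the body per element,
    renaming the loop variable to x_0, x_1, ..."""

    def __init__(self, iter_name, n):
        self.iter_name = iter_name
        self.n = n

    def visit_For(self, node):
        if not (isinstance(node.iter, ast.Name) and node.iter.id == self.iter_name):
            return self.generic_visit(node)
        target = node.target.id
        new_body = []
        for i in range(self.n):
            for stmt in node.body:
                copied = ast.parse(ast.unparse(stmt)).body[0]   # deep copy
                for sub in ast.walk(copied):                    # rename x -> x_i
                    if isinstance(sub, ast.Name) and sub.id == target:
                        sub.id = f'{target}_{i}'
                new_body.append(copied)
        return new_body                     # replaces the For node in place


tree = ast.parse("for f in fs:\n    y = f(y)\n")
tree = ast.fix_missing_locations(UnrollFor('fs', 2).visit(tree))
print(ast.unparse(tree))
# y = f_0(y)
# y = f_1(y)
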
Code Example #14
class Sequential(Module):
    """Basic sequential object to control data flow.

  Parameters
  ----------
  arg_ds
    The modules without name specifications.
  name : str, optional
    The name of the sequential module.
  kwarg_ds
    The modules with name specifications.
  """
    def __init__(self, *arg_ds, name=None, **kwarg_ds):
        super(Sequential, self).__init__(name=name)

        self.implicit_nodes = Collector()
        # check "args"
        for ds in arg_ds:
            if not isinstance(ds, DynamicalSystem):
                raise errors.BrainPyError(
                    f'Only support {DynamicalSystem.__name__}, '
                    f'but we got {type(ds)}: {str(ds)}.')
            self.implicit_nodes[ds.name] = ds

        # check "kwargs"
        for key, ds in kwarg_ds.items():
            if not isinstance(ds, DynamicalSystem):
                raise errors.BrainPyError(
                    f'Only support {DynamicalSystem.__name__}, '
                    f'but we got {type(ds)}: {str(ds)}.')
            self.implicit_nodes[key] = ds

        # all update functions
        self._return_kwargs = [
            'kwargs' in inspect.signature(ds.update).parameters.keys()
            for ds in self.implicit_nodes.values()
        ]

    def _check_kwargs(self, i, kwargs):
        return kwargs if self._return_kwargs[i] else dict()

    def update(self, *args, **kwargs):
        """Functional call.

    Parameters
    ----------
    args : list, tuple
      The *args arguments.
    kwargs : dict
      The shared keyword arguments passed across modules. They are forwarded
      to a submodule only when that submodule's "update" function accepts
      "**kwargs".
    """
        ds = list(self.implicit_nodes.values())
        # first layer
        args = ds[0].update(*args, **self._check_kwargs(0, kwargs))
        # other layers
        for i in range(1, len(self.implicit_nodes)):
            args = ds[i].update(*_check_args(args=args),
                                **self._check_kwargs(i, kwargs))
        return args

    def __getitem__(self, key: int):
        return list(self.implicit_nodes.values())[key]
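
The only subtle part of Sequential.update is the kwargs-forwarding rule: the shared keyword arguments are handed to a layer only if that layer's update signature declares **kwargs. A runnable toy of just that rule, with Plain and Configurable as made-up layer classes:

import inspect


class Plain:
    def update(self, x):
        return x + 1


class Configurable:
    def update(self, x, **kwargs):
        return x * kwargs.get('gain', 1)


layers = [Plain(), Configurable()]
accepts_kwargs = ['kwargs' in inspect.signature(l.update).parameters
                  for l in layers]
print(accepts_kwargs)            # [False, True]

x, shared = 3, {'gain': 10}
for layer, forward in zip(layers, accepts_kwargs):
    x = layer.update(x, **(shared if forward else {}))
print(x)                         # (3 + 1) * 10 = 40
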
Code Example #15
File: ast2numba.py (Project: PKU-NIP-Lab/BrainPy)
def _analyze_cls_func_body(host,
                           self_name,
                           code,
                           tree,
                           show_code=False,
                           has_func_def=False,
                           **jit_setting):
    arguments, arg2call, nodes, code_scope = set(), dict(), Collector(), dict()

    # all self data
    self_data = re.findall('\\b' + self_name + '\\.[A-Za-z_][A-Za-z0-9_.]*\\b',
                           code)
    self_data = list(set(self_data))

    # analyze variables and functions accessed by the self.xx
    data_to_replace = {}
    for key in self_data:
        split_keys = key.split('.')
        if len(split_keys) < 2:
            raise errors.BrainPyError

        # get target and data
        target = host
        for i in range(1, len(split_keys)):
            next_target = getattr(target, split_keys[i])
            if isinstance(next_target, Integrator):
                break
            if not isinstance(next_target, Base):
                break
            target = next_target
        else:
            raise errors.BrainPyError
        data = getattr(target, split_keys[i])

        # analyze data
        if isinstance(data, math.numpy.Variable):  # data is a variable
            arguments.add(f'{target.name}_{split_keys[i]}')
            arg2call[
                f'{target.name}_{split_keys[i]}'] = f'{target.name}.{split_keys[-1]}.value'
            nodes[target.name] = target
            # replace the data
            if len(split_keys) == i + 1:
                data_to_replace[key] = f'{target.name}_{split_keys[i]}'
            else:
                data_to_replace[
                    key] = f'{target.name}_{split_keys[i]}.{".".join(split_keys[i + 1:])}'

        elif isinstance(data, np.random.RandomState):  # data is a RandomState
            # replace RandomState
            code_scope[f'{target.name}_{split_keys[i]}'] = np.random
            # replace the data
            if len(split_keys) == i + 1:
                data_to_replace[key] = f'{target.name}_{split_keys[i]}'
            else:
                data_to_replace[
                    key] = f'{target.name}_{split_keys[i]}.{".".join(split_keys[i + 1:])}'

        elif callable(data):  # data is a function
            assert len(split_keys) == i + 1
            r = _jit_func(obj_or_fun=data, show_code=show_code, **jit_setting)
            # if len(r['arguments']):
            tree = _replace_func_call_by_tree(tree,
                                              func_call=key,
                                              arg_to_append=r['arguments'])
            arguments.update(r['arguments'])
            arg2call.update(r['arg2call'])
            nodes.update(r['nodes'])
            code_scope[f'{target.name}_{split_keys[i]}'] = r['func']
            data_to_replace[
                key] = f'{target.name}_{split_keys[i]}'  # replace the data

        elif isinstance(
                data, (dict, list,
                       tuple)):  # data is a list/tuple/dict of function/object
            # get all values
            if isinstance(data, dict):  # check dict
                if len(split_keys) != i + 2 and split_keys[-1] != 'values':
                    raise errors.BrainPyError(
                        f'Only support iter dict.values(). while we got '
                        f'dict.{split_keys[-1]}  for data: \n\n{data}')
                values = list(data.values())
                iter_name = key + '()'
            else:  # check list / tuple
                assert len(split_keys) == i + 1
                values = list(data)
                iter_name = key
                if len(values) > 0:
                    if not (callable(values[0])
                            or isinstance(values[0], Base)):
                        code_scope[f'{target.name}_{split_keys[i]}'] = data
                        if len(split_keys) == i + 1:
                            data_to_replace[
                                key] = f'{target.name}_{split_keys[i]}'
                        else:
                            data_to_replace[
                                key] = f'{target.name}_{split_keys[i]}.{".".join(split_keys[i + 1:])}'
                        continue
                        # raise errors.BrainPyError(f'Only support JIT an iterable objects of function '
                        #                           f'or Base object, but we got:\n\n {values[0]}')
            # replace this for-loop
            r = _replace_this_forloop(tree=tree,
                                      iter_name=iter_name,
                                      loop_values=values,
                                      show_code=show_code,
                                      **jit_setting)
            tree, _arguments, _arg2call, _nodes, _code_scope = r
            arguments.update(_arguments)
            arg2call.update(_arg2call)
            nodes.update(_nodes)
            code_scope.update(_code_scope)

        else:  # constants
            code_scope[f'{target.name}_{split_keys[i]}'] = data
            # replace the data
            if len(split_keys) == i + 1:
                data_to_replace[key] = f'{target.name}_{split_keys[i]}'
            else:
                data_to_replace[
                    key] = f'{target.name}_{split_keys[i]}.{".".join(split_keys[i + 1:])}'

    if has_func_def:
        tree.body[0].decorator_list.clear()
        tree.body[0].args.args.extend(
            [ast.Name(id=a) for a in sorted(arguments)])
        tree.body[0].args.defaults.extend(
            [ast.Constant(None) for _ in sorted(arguments)])
        tree.body[0].args.kwarg = None

    # replace words
    code = tools.ast2code(tree)
    code = tools.word_replace(code, data_to_replace, exclude_dot=True)

    return code, arguments, arg2call, nodes, code_scope
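
The whole pass is driven by the regular expression that pulls every self.xxx access out of the method's source; each match is then dispatched on its runtime type (a Variable becomes an extra argument, a callable is JIT-compiled recursively, a plain constant is dropped into the code scope). A self-contained snippet showing just the extraction step, with "self" as the instance name:

import re

code = '''
def derivative(self, V, t):
    I = self.gNa * (V - self.ENa) + self.noise.normal()
    return (-I + self.input) / self.C
'''

# the same pattern as above, with "self" as the instance name
refs = sorted(set(re.findall(r'\bself\.[A-Za-z_][A-Za-z0-9_.]*\b', code)))
print(refs)
# ['self.C', 'self.ENa', 'self.gNa', 'self.input', 'self.noise.normal']
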
Code Example #16
File: ast2numba.py (Project: PKU-NIP-Lab/BrainPy)
def _jit_intg(f, show_code=False, **jit_setting):
    # TODO: integrator has "integral", "code_lines", "code_scope", "func_name", "derivative",
    assert isinstance(f, Integrator)

    # exponential euler methods
    if hasattr(f.integral, '__self__'):
        return _jit_cls_func(f=f.integral,
                             code="\n".join(f.code_lines),
                             show_code=show_code,
                             **jit_setting)

    # information in the integrator
    func_name = f.func_name
    raw_func = f.derivative
    tree = ast.parse('\n'.join(f.code_lines))
    code_scope = {key: val for key, val in f.code_scope.items()}

    # essential information
    arguments = set()
    arg2call = dict()
    nodes = Collector()

    # jit raw functions
    f_node = None
    remove_self = None
    if hasattr(f, '__self__') and isinstance(f.__self__, DynamicalSystem):
        f_node = f.__self__
        _arg = tree.body[0].args.args.pop(0)  # remove "self" arg
        # remove "self" in functional call
        remove_self = _arg.arg

    need_recompile = False
    for key, func in raw_func.items():
        # get node of host
        func_node = None
        if f_node:
            func_node = f_node
        elif hasattr(func, '__self__') and isinstance(func.__self__,
                                                      DynamicalSystem):
            func_node = func.__self__

        # get new compiled function
        if isinstance(func, Dispatcher):
            continue
        elif func_node is not None:
            need_recompile = True
            r = _jit_cls_func(f=func,
                              host=func_node,
                              show_code=show_code,
                              **jit_setting)
            if len(r['arguments']) or remove_self:
                tree = _replace_func_call_by_tree(tree,
                                                  func_call=key,
                                                  arg_to_append=r['arguments'],
                                                  remove_self=remove_self)
            code_scope[key] = r['func']
            arguments.update(r['arguments'])  # update arguments
            arg2call.update(r['arg2call'])  # update arg2call
            nodes.update(r['nodes'])  # update nodes
            nodes[func_node.name] = func_node  # update nodes
        else:
            need_recompile = True
            code_scope[key] = numba.jit(func, **jit_setting)

    if need_recompile:
        tree.body[0].decorator_list.clear()
        tree.body[0].args.args.extend(
            [ast.Name(id=a) for a in sorted(arguments)])
        tree.body[0].args.defaults.extend(
            [ast.Constant(None) for _ in sorted(arguments)])
        code = tools.ast2code(tree)
        # code, _scope = _add_try_except(code)
        # code_scope.update(_scope)
        # code_scope_backup = {k: v for k, v in code_scope.items()}
        # compile functions
        if show_code:
            _show_compiled_codes(code, code_scope)
        exec(compile(code, '', 'exec'), code_scope)
        new_f = code_scope[func_name]
        # new_f.brainpy_data = {key: val for key, val in f.brainpy_data.items()}
        # new_f.brainpy_data['code_lines'] = code.strip().split('\n')
        # new_f.brainpy_data['code_scope'] = code_scope_backup
        jit_f = numba.jit(new_f, **jit_setting)
        return dict(func=jit_f,
                    arguments=arguments,
                    arg2call=arg2call,
                    nodes=nodes)
    else:
        return dict(func=f,
                    arguments=arguments,
                    arg2call=arg2call,
                    nodes=nodes)
Code Example #17
File: runner.py (Project: PKU-NIP-Lab/BrainPy)
class IntegratorRunner(Runner):
    def __init__(self,
                 target,
                 monitors=None,
                 inits=None,
                 args=None,
                 dyn_args=None,
                 dyn_vars=None,
                 jit=True,
                 dt=None,
                 numpy_mon_after_run=True,
                 progress_bar=True):
        super(IntegratorRunner, self).__init__()

        # parameters
        dt = math.get_dt() if dt is None else dt
        if not isinstance(dt, (int, float)):
            raise RunningError(f'"dt" must be scalar, but got {dt}')
        self.dt = dt
        self.jit = jit
        self.numpy_mon_after_run = numpy_mon_after_run
        self._pbar = None  # progress bar
        self.progress_bar = progress_bar

        # target
        if not isinstance(target, Integrator):
            raise RunningError(
                f'"target" must be an instance of {Integrator.__name__}, '
                f'but we got {type(target)}: {target}')
        self.target = target

        # arguments of the integral function
        self._static_args = Collector()
        if args is not None:
            assert isinstance(
                args, dict
            ), f'"args" must be a dict, but we get {type(args)}: {args}'
            self._static_args.update(args)
        self._dyn_args = Collector()
        if dyn_args is not None:
            assert isinstance(
                dyn_args, dict
            ), f'"dyn_args" must be a dict, but we get {type(dyn_args)}: {dyn_args}'
            sizes = np.unique([len(v) for v in dyn_args.values()])
            num_size = len(sizes)
            if num_size != 1:
                raise RunningError(
                    f'All values in "dyn_args" should have the same length. But we got '
                    f'{num_size}: {sizes}')
            self._dyn_args.update(dyn_args)

        # dynamical changed variables
        if dyn_vars is None:
            dyn_vars = self.target.vars().unique()
        if isinstance(dyn_vars, (list, tuple)):
            dyn_vars = {f'_v{i}': v for i, v in enumerate(dyn_vars)}
        if not isinstance(dyn_vars, dict):
            raise RunningError(
                f'"dyn_vars" must be a dict, but we got {type(dyn_vars)}')
        self.dyn_vars = TensorCollector(dyn_vars)

        # monitors
        if monitors is None:
            self.mon = Monitor(target=self, variables=[])
        elif isinstance(monitors, (list, tuple, dict)):
            self.mon = Monitor(target=self, variables=monitors)
        elif isinstance(monitors, Monitor):
            self.mon = monitors
            self.mon.target = self
        else:
            raise MonitorError(f'"monitors" only supports list/tuple/dict/ '
                               f'instance of Monitor, not {type(monitors)}.')
        self.mon.build()  # build the monitor
        for k in self.mon.item_names:
            if k not in self.target.variables:
                raise MonitorError(
                    f'Variable "{k}" to monitor is not defined in the integrator {self.target}.'
                )

        # start simulation time
        self._start_t = None

        # Variables
        if inits is not None:
            if isinstance(inits, (list, tuple)):
                assert len(self.target.variables) == len(inits)
                inits = {
                    k: inits[i]
                    for i, k in enumerate(self.target.variables)
                }
            assert isinstance(inits, dict)
            sizes = np.unique([np.size(v) for v in list(inits.values())])
            max_size = np.max(sizes)
        else:
            max_size = 1
            inits = dict()
        self.variables = TensorCollector({
            v: math.Variable(math.zeros(max_size))
            for v in self.target.variables
        })
        for k in inits.keys():
            self.variables[k][:] = inits[k]
        self.dyn_vars.update(self.variables)
        if len(self._dyn_args) > 0:
            self.idx = math.Variable(math.zeros(1, dtype=math.int_))
            self.dyn_vars['_idx'] = self.idx

        # build the update step
        if jit:
            _loop_func = math.make_loop(
                self._step,
                dyn_vars=self.dyn_vars,
                out_vars={k: self.variables[k]
                          for k in self.mon.item_names})
        else:

            def _loop_func(t_and_dt):
                out_vars = {k: [] for k in self.mon.item_names}
                times, dts = t_and_dt
                for i in range(len(times)):
                    _t = times[i]
                    _dt = dts[i]
                    self._step([_t, _dt])
                    for k in self.mon.item_names:
                        out_vars[k].append(
                            math.as_device_array(self.variables[k]))
                out_vars = {
                    k: math.asarray(out_vars[k])
                    for k in self.mon.item_names
                }
                return out_vars

        self.step_func = _loop_func

    def _post(self, times, returns):  # monitor
        self.mon.ts = times
        for key in self.mon.item_names:
            self.mon.item_contents[key] = math.asarray(returns[key])

    def _step(self, t_and_dt):
        # arguments
        kwargs = dict()
        kwargs.update(self.variables)
        kwargs.update({'t': t_and_dt[0], 'dt': t_and_dt[1]})
        kwargs.update(self._static_args)
        if len(self._dyn_args) > 0:
            kwargs.update({k: v[self.idx] for k, v in self._dyn_args.items()})
            self.idx += 1
        # call integrator function
        update_values = self.target(**kwargs)
        for i, v in enumerate(self.target.variables):
            self.variables[v].update(update_values[i])
        if self.progress_bar:
            id_tap(lambda *args: self._pbar.update(), ())

    def run(self, duration, start_t=None):
        self.__call__(duration, start_t)

    def __call__(self, duration, start_t=None):
        """The running function.

    Parameters
    ----------
    duration : float, int, tuple, list
      The running duration.
    start_t : float, optional

    Returns
    -------
    running_time : float
      The total running time.
    """
        if len(self._dyn_args) > 0:
            self.dyn_vars['_idx'][0] = 0

        # time step
        if start_t is None:
            if self._start_t is None:
                start_t = 0.
            else:
                start_t = float(self._start_t)
        end_t = float(start_t + duration)
        # times
        times = math.arange(start_t, end_t, self.dt)
        time_steps = math.ones_like(times) * self.dt
        # running
        if self.progress_bar:
            self._pbar = tqdm.auto.tqdm(total=times.size)
            self._pbar.set_description(
                f"Running a duration of {round(float(duration), 3)} ({times.size} steps)",
                refresh=True)
        t0 = time.time()
        hists = self.step_func([times.value, time_steps.value])
        running_time = time.time() - t0
        if self.progress_bar:
            self._pbar.close()
        # post-running
        self._post(times, hists)
        self._start_t = end_t
        if self.numpy_mon_after_run:
            self.mon.numpy()
        return running_time
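
A hedged usage sketch of the runner above. It assumes the BrainPy version these snippets come from (the PKU-NIP-Lab/BrainPy 2.x era), where bp.odeint builds an Integrator exposing the variable names the runner monitors and initializes; the exact import path of IntegratorRunner depends on that version and is left out here, and the constructor arguments used are exactly those shown in the __init__ above.

import brainpy as bp


# a toy one-variable ODE: dV/dt = -V / tau
@bp.odeint(method='euler')
def decay(V, t, tau=10.):
    return -V / tau


runner = IntegratorRunner(decay, monitors=['V'], inits={'V': 1.}, dt=0.1)
runner.run(100.)                                # simulate 100 time units
print(runner.mon.ts.shape)                      # recorded time points
print(runner.mon.item_contents['V'].shape)      # monitored trajectory of V
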