Example #1
def _get_args(f):
    # 1. get the function arguments
    original_args = []
    args = []
    kwargs = []

    for name, par in inspect.signature(f).parameters.items():
        if par.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
            args.append(par.name)
        elif par.kind is inspect.Parameter.VAR_POSITIONAL:
            args.append(par.name)
        elif par.kind is inspect.Parameter.KEYWORD_ONLY:
            args.append(par.name)
        elif par.kind is inspect.Parameter.POSITIONAL_ONLY:
            raise errors.BrainPyError(
                'Do not support positional-only parameters, e.g., "/".')
        elif par.kind is inspect.Parameter.VAR_KEYWORD:
            kwargs.append(par.name)
        else:
            raise errors.BrainPyError(f'Unknown argument type: {par.kind}')

        original_args.append(str(par))

    # 2. analyze the function arguments
    #   2.1 class keywords
    class_kw = []
    if original_args[0] in CLASS_KEYWORDS:
        class_kw.append(original_args[0])
        original_args = original_args[1:]
        args = args[1:]
    for a in original_args:
        if a.split('=')[0].strip() in CLASS_KEYWORDS:
            raise errors.DiffEqError(f'Class keywords "{a}" must be defined '
                                     f'as the first argument.')
    return class_kw, args, kwargs, original_args
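
A minimal standalone sketch of the same parameter-kind classification, built only on `inspect`; the helper `classify_params` and the sample function `f` below are illustrative and not part of BrainPy:

import inspect

def classify_params(f):
    # Same idea as above: VAR_KEYWORD goes to "kwargs", positional-only is
    # rejected, and every other kind is collected as a named argument.
    args, kwargs = [], []
    for name, par in inspect.signature(f).parameters.items():
        if par.kind is inspect.Parameter.VAR_KEYWORD:
            kwargs.append(name)
        elif par.kind is inspect.Parameter.POSITIONAL_ONLY:
            raise TypeError('positional-only parameters are not supported')
        else:
            args.append(name)
    return args, kwargs

def f(self, x, *xs, tau=10., **params):
    pass

print(classify_params(f))  # (['self', 'x', 'xs', 'tau'], ['params'])
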
Example #2
    def __init__(self, *arg_ds, name=None, **kwarg_ds):
        super(Sequential, self).__init__(name=name)

        self.implicit_nodes = Collector()
        # check "args"
        for ds in arg_ds:
            if not isinstance(ds, DynamicalSystem):
                raise errors.BrainPyError(
                    f'Only support {DynamicalSystem.__name__}, '
                    f'but we got {type(ds)}: {str(ds)}.')
            self.implicit_nodes[ds.name] = ds

        # check "kwargs"
        for key, ds in kwarg_ds.items():
            if not isinstance(ds, DynamicalSystem):
                raise errors.BrainPyError(
                    f'Only support {DynamicalSystem.__name__}, '
                    f'but we got {type(ds)}: {str(ds)}.')
            self.implicit_nodes[key] = ds

        # all update functions
        self._return_kwargs = [
            'kwargs' in inspect.signature(ds.update).parameters.keys()
            for ds in self.implicit_nodes.values()
        ]
Example #3
def _device_reshape(x):
    """Reshape an input array in order to broadcast to multiple devices."""
    num_device = jax.local_device_count()

    if not hasattr(x, 'ndim'):
        raise errors.BrainPyError(
            f'Expected JaxArray, got {type(x)}. If you are trying to pass a scalar to '
            f'parallel, first convert it to a JaxArray, for example np.float(0.5)'
        )
    if x.ndim == 0:
        return np.broadcast_to(x, [num_device])
    if x.shape[0] % num_device != 0:
        raise errors.BrainPyError(
            f'The batch of shape {x.shape} must divide evenly among '
            f'{num_device} devices, but it does not.')
    return x.reshape((num_device, x.shape[0] // num_device) + x.shape[1:])
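
A NumPy-only sketch of the reshape performed above; `num_device` is hard-coded for illustration, whereas the real helper queries `jax.local_device_count()`:

import numpy as np

num_device = 4                            # stand-in for jax.local_device_count()
x = np.arange(8 * 3).reshape(8, 3)        # a batch of 8 samples
assert x.shape[0] % num_device == 0       # the batch must divide evenly
y = x.reshape((num_device, x.shape[0] // num_device) + x.shape[1:])
print(y.shape)                            # (4, 2, 3): one leading axis per device
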
Example #4
  def __init__(self, size, C=1., A=1e-3, V_th=0., method='exp_auto', name=None, **channels):
    super(CondNeuGroup, self).__init__(size, method=method, name=name)

    # parameters for neurons
    self.C = C
    self.A = A
    self.V_th = V_th

    # check 'channels'
    _channels = dict()
    for key in channels.keys():
      assert isinstance(key, str), f'Key must be a str, but got {type(key)}: {key}'
      assert isinstance(channels[key], (tuple, list)) and len(channels[key]) == 2
      assert isinstance(channels[key][0], type)
      assert isinstance(channels[key][1], dict)
      cls = channels[key][0]
      params = channels[key][1].copy()
      params['host'] = self
      params['method'] = method
      _channels[key] = (cls, params)

    # initialize children channels
    self.channels = Collector()
    for key, (ch, params) in _channels.items():
      self.channels[key] = ch(**params)
      if not isinstance(self.channels[key], Channel):
        raise errors.BrainPyError(f'{self.__class__.__name__} only receives {Channel} instance, '
                                  f'while we got {type(self.channels[key])}: {self.channels[key]}.')

    self.ion_channels = self.channels.subset(IonChannel)
    self.mol_channels = self.channels.subset(MolChannel)
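
A standalone sketch of the same "(class, params) pair -> instance" construction pattern; `DummyChannel` is a placeholder, not an actual BrainPy channel class:

class DummyChannel:
    def __init__(self, g_max=1.0, **kwargs):
        self.g_max = g_max

specs = {'Na': (DummyChannel, {'g_max': 120.}),
         'K': (DummyChannel, {'g_max': 36.})}
channels = {}
for key, (cls, params) in specs.items():
    # validate the spec, then instantiate the channel from its parameters
    assert isinstance(key, str) and isinstance(cls, type) and isinstance(params, dict)
    channels[key] = cls(**params)
print({k: v.g_max for k, v in channels.items()})  # {'Na': 120.0, 'K': 36.0}
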
Example #5
def _check_var(var):
    global math
    if math is None: from brainpy import math
    if not isinstance(var, math.ndarray):
        raise errors.BrainPyError(
            f'Element in "dyn_vars" must be an instance of '
            f'{math.ndarray.__name__}, but we got {type(var)}.')
Example #6
def _jit_DS(obj_or_fun, show_code=False, **jit_setting):
    if not isinstance(obj_or_fun, DynamicalSystem):
        raise errors.UnsupportedError(f'JIT compilation in numpy backend only '
                                      f'supports {Base.__name__}, but we got '
                                      f'{type(obj_or_fun)}.')
    if not hasattr(obj_or_fun, 'steps'):
        raise errors.BrainPyError(
            f'Please init this DynamicalSystem {obj_or_fun} first, '
            f'then apply JIT.')

    # function analysis
    for key, step in list(obj_or_fun.steps.items()):
        key = key.replace(".", "_")
        r = _jit_func(obj_or_fun=step, show_code=show_code, **jit_setting)
        if r['func'] != step:
            func = _form_final_call(f_org=step,
                                    f_rep=r['func'],
                                    arg2call=r['arg2call'],
                                    arguments=r['arguments'],
                                    nodes=r['nodes'],
                                    show_code=show_code,
                                    name=step.__name__)
            obj_or_fun.steps.replace(key, func)

    # dynamic system
    return obj_or_fun
Example #7
def _replace_this_forloop(tree,
                          iter_name,
                          loop_values,
                          show_code=False,
                          **jit_setting):
    """Replace the given for-loop.

  This function replaces a specific for-loop structure. For example, it turns
  a for-loop like

  >>> def update(_t, _dt):
  >>>    for step in self.child_steps.values():
  >>>        step(_t, _dt)

  to

  >>> def update(_t, _dt, AMPA_vec0_delay_g_data=None, AMPA_vec0_delay_g_in_idx=None,
  >>>            AMPA_vec0_delay_g_out_idx=None, AMPA_vec0_s=None, HH0_V=None, HH0_V_th=None,
  >>>            HH0_gNa=None, HH0_h=None, HH0_input=None, HH0_m=None, HH0_n=None, HH0_spike=None):
  >>>    HH0_step(_t, _dt, HH0_V=HH0_V, HH0_V_th=HH0_V_th, HH0_gNa=HH0_gNa,
  >>>             HH0_h=HH0_h, HH0_input=HH0_input, HH0_m=HH0_m, HH0_n=HH0_n,
  >>>             HH0_spike=HH0_spike)
  >>>    AMPA_vec0_step(_t, _dt, AMPA_vec0_delay_g_data=AMPA_vec0_delay_g_data,
  >>>                   AMPA_vec0_delay_g_in_idx=AMPA_vec0_delay_g_in_idx,
  >>>                   AMPA_vec0_delay_g_out_idx=AMPA_vec0_delay_g_out_idx,
  >>>                   AMPA_vec0_s=AMPA_vec0_s, HH0_V=HH0_V, HH0_input=HH0_input,
  >>>                   HH0_spike=HH0_spike)
  >>>    AMPA_vec0_delay_g_step(_t, _dt, AMPA_vec0_delay_g_in_idx=AMPA_vec0_delay_g_in_idx,
  >>>                           AMPA_vec0_delay_g_out_idx=AMPA_vec0_delay_g_out_idx)

  Parameters
  ----------
  tree : ast.Module
    The target code tree.
  iter_name : str
    The name of the iterable in the target for-loop.
  loop_values : list/tuple
    The values iterated over in the loop.
  show_code : bool
    Whether to show the formatted code.
  """
    assert isinstance(tree, ast.Module)

    replacer = ReplaceThisForLoop(loop_values=loop_values,
                                  iter_name=iter_name,
                                  show_code=show_code,
                                  **jit_setting)
    tree = replacer.visit(tree)
    if not replacer.success:
        raise errors.BrainPyError(
            f'Cannot find a for-loop over "{iter_name}". Currently we only '
            f'support for-loops of the form "for xxx in {iter_name}:". '
            f'Please check that your for-loop follows this structure.')

    return tree, replacer.arguments, replacer.arg2call, replacer.nodes, replacer.code_scope
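
The mechanism can be illustrated with a much smaller, standalone `ast.NodeTransformer` that unrolls a for-loop over a known iterable; this only mimics the idea and is not BrainPy's actual transformer (requires Python 3.9+ for `ast.unparse`):

import ast
from copy import deepcopy

class UnrollLoop(ast.NodeTransformer):
    # Toy version: replace "for x in ITER:" with one copy of the body per value.
    def __init__(self, iter_name, n):
        self.iter_name, self.n = iter_name, n

    def visit_For(self, node):
        if ast.unparse(node.iter) == self.iter_name:
            return [deepcopy(stmt) for _ in range(self.n) for stmt in node.body]
        return node

tree = ast.parse("for step in ITER:\n    print('run step')\n")
tree = ast.fix_missing_locations(UnrollLoop('ITER', 3).visit(tree))
print(ast.unparse(tree))  # three consecutive print('run step') statements
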
Example #8
def check_name_uniqueness(name, obj):
    """Check the uniqueness of the name for the object type."""
    if not name.isidentifier():
        raise errors.BrainPyError(f'"{name}" isn\'t a valid identifier '
                                  f'according to the Python language definition. '
                                  f'Please choose another name.')
    if name in _name2id:
        if _name2id[name] != id(obj):
            raise errors.UniqueNameError(
                f'In BrainPy, each object should have a unique name. '
                f'However, the name "{name}" of {obj} is already in use.')
    else:
        _name2id[name] = id(obj)
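
A condensed standalone sketch of the same name-registry idea; the `_registry` dict and `register` helper are illustrative, not BrainPy API:

_registry = {}  # name -> id(object), analogous to _name2id above

def register(name, obj):
    if not name.isidentifier():
        raise ValueError(f'"{name}" is not a valid Python identifier.')
    if _registry.setdefault(name, id(obj)) != id(obj):
        raise ValueError(f'The name "{name}" is already used by another object.')

class Thing: pass
a, b = Thing(), Thing()
register('neuron0', a)      # first registration succeeds
register('neuron0', a)      # re-registering the same object is fine
try:
    register('neuron0', b)  # a different object with the same name is rejected
except ValueError as e:
    print(e)
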
Example #9
    def load_states(self, filename, verbose=False, check_missing=False):
        """Load the model states.

    Parameters
    ----------
    filename : str
      The filename which stores the model states.
    verbose: bool
    check_missing: bool
    """
        if not os.path.exists(filename):
            raise errors.BrainPyError(f'Cannot find the file path: {filename}')
        elif filename.endswith('.hdf5') or filename.endswith('.h5'):
            io.load_h5(filename,
                       target=self,
                       verbose=verbose,
                       check=check_missing)
        elif filename.endswith('.pkl'):
            io.load_pkl(filename,
                        target=self,
                        verbose=verbose,
                        check=check_missing)
        elif filename.endswith('.npz'):
            io.load_npz(filename,
                        target=self,
                        verbose=verbose,
                        check=check_missing)
        elif filename.endswith('.mat'):
            io.load_mat(filename,
                        target=self,
                        verbose=verbose,
                        check=check_missing)
        else:
            raise errors.BrainPyError(
                f'Unknown file format: {filename}. We only support {io.SUPPORTED_FORMATS}.'
            )
Example #10
 def call(*args):
     un_replicated = [
         k for k, v in dyn_vars.items()
         if not isinstance(v.value, (ShardedDeviceArray, JaxprTracer,
                                     DynamicJaxprTracer))
     ]
     if len(un_replicated):
         raise errors.BrainPyError(
             f'Some variables were not replicated: {un_replicated}. '
             f'Did you forget to call xx.replicate() on them?')
     _args = []
     for i, x in enumerate(args):
         if i + 2 in static_broadcasted_argnums:
             _args.append(x)
         else:
             _args.append(jax.tree_map(_device_reshape, [x])[0])
     dyn_data = dyn_vars.dict()
     rand_data = rand_vars.dict()
     output, dyn_changes, rand_changes = pmapped_func(
         dyn_data, rand_data, *_args)
     dyn_vars.assign(dyn_changes)
     rand_vars.assign(rand_changes)
     return jax.tree_map(reduce_func, output)
Example #11
    def save_states(self, filename, all_vars=None, **setting):
        """Save the model states.

    Parameters
    ----------
    filename : str
      The file name in which to store the model states.
    all_vars: optional, dict, TensorCollector
    """
        if all_vars is None:
            all_vars = self.vars(method='relative').unique()

        if filename.endswith('.hdf5') or filename.endswith('.h5'):
            io.save_h5(filename, all_vars=all_vars)
        elif filename.endswith('.pkl'):
            io.save_pkl(filename, all_vars=all_vars)
        elif filename.endswith('.npz'):
            io.save_npz(filename, all_vars=all_vars, **setting)
        elif filename.endswith('.mat'):
            io.save_mat(filename, all_vars=all_vars)
        else:
            raise errors.BrainPyError(
                f'Unknown file format: {filename}. We only support {io.SUPPORTED_FORMATS}.'
            )
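
A self-contained sketch of the extension-based save/load round trip, limited to the .npz branch and using plain NumPy as a simplified stand-in for the ``io`` helpers:

import numpy as np

def save_npz(filename, all_vars):
    # simplified stand-in for io.save_npz: one named array per state variable
    np.savez(filename, **{k: np.asarray(v) for k, v in all_vars.items()})

def load_npz(filename):
    # simplified stand-in for io.load_npz: return a plain dict of arrays
    with np.load(filename) as data:
        return {k: data[k] for k in data.files}

states = {'V': np.full(10, -65.0), 'spike': np.zeros(10, dtype=bool)}
save_npz('checkpoint.npz', states)
print(load_npz('checkpoint.npz')['V'][:3])  # [-65. -65. -65.]
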
Example #12
def raster_plot(ts,
                sp_matrix,
                ax=None,
                marker='.',
                markersize=2,
                color='k',
                xlabel='Time (ms)',
                ylabel='Neuron index',
                xlim=None,
                ylim=None,
                title=None,
                show=False,
                **kwargs):
    """Show the rater plot of the spikes.

  Parameters
  ----------
  ts : np.ndarray
      The run times.
  sp_matrix : np.ndarray
      The spike matrix which records the spike information.
      It can be easily accessed by specifying the ``monitors``
      of NeuGroup by: ``neu = NeuGroup(..., monitors=['spike'])``
  ax : Axes
      The matplotlib axes to plot on. If ``None``, ``plt`` is used.
  markersize : int
      The size of the marker.
  color : str
      The color of the marker.
  xlim : list, tuple
      The xlim.
  ylim : list, tuple
      The ylim.
  xlabel : str
      The xlabel.
  ylabel : str
      The ylabel.
  show : bool
      Whether to show the figure.
  """

    sp_matrix = np.asarray(sp_matrix)
    if ts is None:
        raise errors.BrainPyError('Must provide "ts".')
    ts = np.asarray(ts)

    # get index and time
    elements = np.where(sp_matrix > 0.)
    index = elements[1]
    time = ts[elements[0]]

    # plot raster
    if ax is None:
        ax = plt
    ax.plot(time, index, marker + color, markersize=markersize, **kwargs)

    # xlabel
    if xlabel:
        plt.xlabel(xlabel)

    # ylabel
    if ylabel:
        plt.ylabel(ylabel)

    if xlim:
        plt.xlim(xlim[0], xlim[1])

    if ylim:
        plt.ylim(ylim[0], ylim[1])

    if title:
        plt.title(title)

    if show:
        plt.show()
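
A self-contained sketch with synthetic spike data, reproducing the core of the plotting logic above (NumPy and Matplotlib only):

import numpy as np
import matplotlib.pyplot as plt

ts = np.arange(0., 100., 0.1)                    # 1000 time points (ms)
sp_matrix = np.random.rand(len(ts), 50) < 0.02   # 50 neurons, ~2% spike probability
t_idx, n_idx = np.where(sp_matrix)               # time/neuron coordinates of spikes
plt.plot(ts[t_idx], n_idx, '.k', markersize=2)
plt.xlabel('Time (ms)')
plt.ylabel('Neuron index')
plt.show()
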
Example #13
    def visit_For(self, node):
        iter_ = tools.ast2code(ast.fix_missing_locations(node.iter))

        if iter_.strip() == self.iter_name:
            data_to_replace = Collector()
            final_node = ast.Module(body=[])
            self.success = True

            # target
            if not isinstance(node.target, ast.Name):
                raise errors.BrainPyError(
                    f'Only support a scalar iteration target, like "for x in xxxx:", '
                    f'not "for {tools.ast2code(ast.fix_missing_locations(node.target))} '
                    f'in {iter_}:".')
            target = node.target.id

            # for loop values
            for i, value in enumerate(self.loop_values):
                # module and code
                module = ast.Module(body=deepcopy(node).body)
                code = tools.ast2code(module)

                if isinstance(value, Base):  # transform Base objects
                    r = _analyze_cls_func_body(host=value,
                                               self_name=target,
                                               code=code,
                                               tree=module,
                                               show_code=self.show_code,
                                               **self.jit_setting)

                    new_code, arguments, arg2call, nodes, code_scope = r
                    self.arguments.update(arguments)
                    self.arg2call.update(arg2call)
                    self.nodes.update(nodes)
                    self.code_scope.update(code_scope)

                    final_node.body.extend(ast.parse(new_code).body)

                elif callable(value):  # transform functions
                    r = _jit_func(obj_or_fun=value,
                                  show_code=self.show_code,
                                  **self.jit_setting)
                    tree = _replace_func_call_by_tree(
                        deepcopy(module),
                        func_call=target,
                        arg_to_append=r['arguments'],
                        new_func_name=f'{target}_{i}')

                    # update import parameters
                    self.arguments.update(r['arguments'])
                    self.arg2call.update(r['arg2call'])
                    self.nodes.update(r['nodes'])

                    # replace the data
                    if isinstance(value, Base):
                        host = value
                        replace_name = f'{host.name}_{target}'
                    elif hasattr(value, '__self__') and isinstance(
                            value.__self__, Base):
                        host = value.__self__
                        replace_name = f'{host.name}_{target}'
                    else:
                        replace_name = f'{target}_{i}'
                    self.code_scope[replace_name] = r['func']
                    data_to_replace[f'{target}_{i}'] = replace_name

                    final_node.body.extend(tree.body)

                else:
                    raise errors.BrainPyError(
                        f'Only support JIT compilation of iterables of functions '
                        f'or Base objects, but we got:\n\n {value}')

            # replace words
            final_code = tools.ast2code(final_node)
            final_code = tools.word_replace(final_code,
                                            data_to_replace,
                                            exclude_dot=True)
            final_node = ast.parse(final_code)

        else:
            final_node = node

        self.generic_visit(final_node)
        return final_node
Example #14
def _analyze_cls_func_body(host,
                           self_name,
                           code,
                           tree,
                           show_code=False,
                           has_func_def=False,
                           **jit_setting):
    arguments, arg2call, nodes, code_scope = set(), dict(), Collector(), dict()

    # all self data
    self_data = re.findall('\\b' + self_name + '\\.[A-Za-z_][A-Za-z0-9_.]*\\b',
                           code)
    self_data = list(set(self_data))

    # analyze variables and functions accessed by the self.xx
    data_to_replace = {}
    for key in self_data:
        split_keys = key.split('.')
        if len(split_keys) < 2:
            raise errors.BrainPyError

        # get target and data
        target = host
        for i in range(1, len(split_keys)):
            next_target = getattr(target, split_keys[i])
            if isinstance(next_target, Integrator):
                break
            if not isinstance(next_target, Base):
                break
            target = next_target
        else:
            raise errors.BrainPyError
        data = getattr(target, split_keys[i])

        # analyze data
        if isinstance(data, math.numpy.Variable):  # data is a variable
            arguments.add(f'{target.name}_{split_keys[i]}')
            arg2call[
                f'{target.name}_{split_keys[i]}'] = f'{target.name}.{split_keys[-1]}.value'
            nodes[target.name] = target
            # replace the data
            if len(split_keys) == i + 1:
                data_to_replace[key] = f'{target.name}_{split_keys[i]}'
            else:
                data_to_replace[
                    key] = f'{target.name}_{split_keys[i]}.{".".join(split_keys[i + 1:])}'

        elif isinstance(data, np.random.RandomState):  # data is a RandomState
            # replace RandomState
            code_scope[f'{target.name}_{split_keys[i]}'] = np.random
            # replace the data
            if len(split_keys) == i + 1:
                data_to_replace[key] = f'{target.name}_{split_keys[i]}'
            else:
                data_to_replace[
                    key] = f'{target.name}_{split_keys[i]}.{".".join(split_keys[i + 1:])}'

        elif callable(data):  # data is a function
            assert len(split_keys) == i + 1
            r = _jit_func(obj_or_fun=data, show_code=show_code, **jit_setting)
            # if len(r['arguments']):
            tree = _replace_func_call_by_tree(tree,
                                              func_call=key,
                                              arg_to_append=r['arguments'])
            arguments.update(r['arguments'])
            arg2call.update(r['arg2call'])
            nodes.update(r['nodes'])
            code_scope[f'{target.name}_{split_keys[i]}'] = r['func']
            data_to_replace[
                key] = f'{target.name}_{split_keys[i]}'  # replace the data

        elif isinstance(
                data, (dict, list,
                       tuple)):  # data is a list/tuple/dict of function/object
            # get all values
            if isinstance(data, dict):  # check dict
                if len(split_keys) != i + 2 and split_keys[-1] != 'values':
                    raise errors.BrainPyError(
                        f'Only support iterating over dict.values(), but we got '
                        f'dict.{split_keys[-1]} for data: \n\n{data}')
                values = list(data.values())
                iter_name = key + '()'
            else:  # check list / tuple
                assert len(split_keys) == i + 1
                values = list(data)
                iter_name = key
                if len(values) > 0:
                    if not (callable(values[0])
                            or isinstance(values[0], Base)):
                        code_scope[f'{target.name}_{split_keys[i]}'] = data
                        if len(split_keys) == i + 1:
                            data_to_replace[
                                key] = f'{target.name}_{split_keys[i]}'
                        else:
                            data_to_replace[
                                key] = f'{target.name}_{split_keys[i]}.{".".join(split_keys[i + 1:])}'
                        continue
                        # raise errors.BrainPyError(f'Only support JIT an iterable objects of function '
                        #                           f'or Base object, but we got:\n\n {values[0]}')
            # replace this for-loop
            r = _replace_this_forloop(tree=tree,
                                      iter_name=iter_name,
                                      loop_values=values,
                                      show_code=show_code,
                                      **jit_setting)
            tree, _arguments, _arg2call, _nodes, _code_scope = r
            arguments.update(_arguments)
            arg2call.update(_arg2call)
            nodes.update(_nodes)
            code_scope.update(_code_scope)

        else:  # constants
            code_scope[f'{target.name}_{split_keys[i]}'] = data
            # replace the data
            if len(split_keys) == i + 1:
                data_to_replace[key] = f'{target.name}_{split_keys[i]}'
            else:
                data_to_replace[
                    key] = f'{target.name}_{split_keys[i]}.{".".join(split_keys[i + 1:])}'

    if has_func_def:
        tree.body[0].decorator_list.clear()
        tree.body[0].args.args.extend(
            [ast.Name(id=a) for a in sorted(arguments)])
        tree.body[0].args.defaults.extend(
            [ast.Constant(None) for _ in sorted(arguments)])
        tree.body[0].args.kwarg = None

    # replace words
    code = tools.ast2code(tree)
    code = tools.word_replace(code, data_to_replace, exclude_dot=True)

    return code, arguments, arg2call, nodes, code_scope
Example #15
def _check_node(node):
    if not isinstance(node, Base):
        raise errors.BrainPyError(f'Element in "nodes" must be an instance of '
                                  f'{Base.__name__}, but we got {type(node)}.')
Example #16
def vmap(func,
         dyn_vars=None,
         batched_vars=None,
         in_axes=0,
         out_axes=0,
         axis_name=None,
         reduce_func=None,
         auto_infer=False):
    """Vectorization compilation for class objects.

  Vectorized compile a function or a module to run in parallel on a single device.

  Examples
  --------

  Parameters
  ----------
  func : Base, function, callable
    The function or the module to compile.
  dyn_vars : dict
  batched_vars : dict
  in_axes : optional, int, sequence of int
    Specify which input array axes to map over. If each positional argument to
    ``obj_or_func`` is an array, then ``in_axes`` can be an integer, a None,
    or a tuple of integers and Nones with length equal to the number of
    positional arguments to ``obj_or_func``. An integer or ``None``
    indicates which array axis to map over for all arguments (with ``None``
    indicating not to map any axis), and a tuple indicates which axis to map
    for each corresponding positional argument. Axis integers must be in the
    range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of
    dimensions (axes) of the corresponding input array.

    If the positional arguments to ``obj_or_func`` are container types, the
    corresponding element of ``in_axes`` can itself be a matching container,
    so that distinct array axes can be mapped for different container
    elements. ``in_axes`` must be a container tree prefix of the positional
    argument tuple passed to ``obj_or_func``.

    At least one positional argument must have ``in_axes`` not None. The sizes
    of the mapped input axes for all mapped positional arguments must all be
    equal.

    Arguments passed as keywords are always mapped over their leading axis
    (i.e. axis index 0).
  out_axes : optional, int, tuple/list/dict
    Indicate where the mapped axis should appear in the output. All outputs
    with a mapped axis must have a non-None ``out_axes`` specification. Axis
    integers must be in the range ``[-ndim, ndim)`` for each output array,
    where ``ndim`` is the number of dimensions (axes) of the array returned
    by the :func:`vmap`-ed function, which is one more than the number of
    dimensions (axes) of the corresponding array returned by ``obj_or_func``.
  axis_name : optional

  Returns
  -------
  obj_or_func : Any
    Batched/vectorized version of ``obj_or_func`` with arguments that correspond to
    those of ``obj_or_func``, but with extra array axes at positions indicated by
    ``in_axes``, and a return value that corresponds to that of ``obj_or_func``, but
    with extra array axes at positions indicated by ``out_axes``.

  """
    from brainpy.building.brainobjects import DynamicalSystem

    if isinstance(func, DynamicalSystem):
        if len(func.steps):  # DynamicalSystem has step functions

            # dynamical variables
            all_vars = (dyn_vars or func.vars().unique())
            dyn_vars, rand_vars = TensorCollector(), TensorCollector()
            for key, val in all_vars.items():
                if isinstance(val, RandomState):
                    rand_vars[key] = val
                else:
                    dyn_vars[key] = val

            # in axes
            if in_axes is None:
                in_axes = {key: (None, 0) for key in func.steps.keys()}
            elif isinstance(in_axes, int):
                in_axes = {
                    key: (None, 0, in_axes)
                    for key in func.steps.keys()
                }
            elif isinstance(in_axes, (tuple, list)):
                in_axes = {
                    key: (None, 0) + tuple(in_axes)
                    for key in func.steps.keys()
                }
            elif isinstance(in_axes, dict):
                keys = list(func.steps.keys())
                if keys[0] not in in_axes:
                    in_axes = {key: (None, 0, in_axes) for key in keys}
                else:
                    in_axes = {
                        key: (None, 0) + tuple(in_axes[key])
                        for key in keys
                    }
            assert isinstance(in_axes, dict)

            # batch size index
            batch_idx = {}
            for key, axes in in_axes.items():
                for i, axis in enumerate(axes[2:]):
                    if axis is not None:
                        batch_idx[key] = (i, axis)
                        break
                else:
                    raise ValueError(f'Found no batch axis: {axes}.')

            # out axes
            if out_axes is None:
                out_axes = {key: 0 for key in func.steps.keys()}
            elif isinstance(out_axes, int):
                out_axes = {key: out_axes for key in func.steps.keys()}
            elif isinstance(out_axes, (tuple, list)):
                out_axes = {
                    key: tuple(out_axes) + (0, 0)
                    for key in func.steps.keys()
                }
            elif isinstance(out_axes, dict):
                keys = list(func.steps.keys())
                if keys[0] not in out_axes:
                    out_axes = {key: (out_axes, 0, 0) for key in keys}
                else:
                    out_axes = {
                        key: tuple(out_axes[key]) + (0, 0)
                        for key in keys
                    }
            assert isinstance(out_axes, dict)

            # reduce_func
            if reduce_func is None:
                reduce_func = lambda x: x.mean(axis=0)

            # vectorized map functions
            for key in func.steps.keys():
                func.steps[key] = _make_vmap(func=func.steps[key],
                                             dyn_vars=dyn_vars,
                                             rand_vars=rand_vars,
                                             in_axes=in_axes[key],
                                             out_axes=out_axes[key],
                                             axis_name=axis_name,
                                             batch_idx=batch_idx[key],
                                             reduce_func=reduce_func,
                                             f_name=key)

            return func

    if callable(func):
        if auto_infer:
            if dyn_vars is not None:
                dyn_vars = dyn_vars
            elif isinstance(func,
                            Base):  # Base has '__call__()' implementation
                dyn_vars = func.vars().unique()
            elif hasattr(func, '__self__'):
                if isinstance(func.__self__, Base):
                    dyn_vars = func.__self__.vars().unique()

        if dyn_vars is None:
            return jax.vmap(func,
                            in_axes=in_axes,
                            out_axes=out_axes,
                            axis_name=axis_name)

        else:
            # dynamical variables
            all_vars = dyn_vars
            dyn_vars, rand_vars = TensorCollector(), TensorCollector()
            for key, val in all_vars.items():
                if isinstance(val, RandomState):
                    rand_vars[key] = val
                else:
                    dyn_vars[key] = val

            # in axes
            if in_axes is None:
                in_axes = (None, 0)
            elif isinstance(in_axes, (int, dict)):
                in_axes = (None, 0, in_axes)
            elif isinstance(in_axes, (tuple, list)):
                in_axes = (None, 0) + tuple(in_axes)
            assert isinstance(in_axes, (tuple, list))

            # batch size index
            for i, axis in enumerate(in_axes[2:]):
                if axis is not None:
                    batch_idx = (i, axis)
                    break
            else:
                raise ValueError(f'Found no batch axis: {in_axes}.')

            # out axes
            if out_axes is None:
                out_axes = 0
            elif isinstance(out_axes, (int, dict)):
                out_axes = (out_axes, 0, 0)
            elif isinstance(out_axes, (tuple, list)):
                out_axes = tuple(out_axes) + (0, 0)
            assert isinstance(out_axes, (list, tuple))

            # reduce_func
            if reduce_func is None:
                reduce_func = lambda x: x.mean(axis=0)

            # jit function
            return _make_vmap(func=func,
                              dyn_vars=dyn_vars,
                              rand_vars=rand_vars,
                              in_axes=in_axes,
                              out_axes=out_axes,
                              axis_name=axis_name,
                              batch_idx=batch_idx,
                              reduce_func=reduce_func)

    else:
        raise errors.BrainPyError(
            f'Only support instance of {Base.__name__}, or a callable '
            f'function, but we got {type(func)}.')
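
The wrapper above ultimately builds on `jax.vmap`; a minimal plain-JAX sketch of the `in_axes`/`out_axes` semantics documented in the docstring:

import jax
import jax.numpy as jnp

def affine(x, w, b):
    return jnp.dot(w, x) + b

# Map only over the leading axis of the first argument; broadcast w and b.
batched_affine = jax.vmap(affine, in_axes=(0, None, None), out_axes=0)

xs = jnp.ones((8, 3))   # a batch of 8 inputs
w = jnp.ones((2, 3))
b = jnp.zeros(2)
print(batched_affine(xs, w, b).shape)  # (8, 2)
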
Example #17
def pmap(func,
         dyn_vars=None,
         axis_name=None,
         in_axes=0,
         out_axes=0,
         static_broadcasted_argnums=(),
         devices=None,
         backend=None,
         axis_size=None,
         donate_argnums=(),
         global_arg_shapes=None,
         reduce_func=None):
    """Parallel compilation for class objects.

  Parallel compile a function or a module to run on multiple devices in parallel.

  Parameters
  ----------
  func
  axis_name
  in_axes
  out_axes
  static_broadcasted_argnums
  devices
  backend
  axis_size
  donate_argnums
  global_arg_shapes

  Returns
  -------
  func : Any
    The parallelized version of ``func``.
  """
    from brainpy.building.brainobjects import DynamicalSystem

    if isinstance(func, DynamicalSystem):
        if len(func.steps):  # DynamicalSystem has step functions

            # dynamical variables
            all_vars = (dyn_vars or func.vars().unique())
            dyn_vars = TensorCollector()
            rand_vars = TensorCollector()
            for key, val in all_vars.items():
                if isinstance(val, RandomState):
                    rand_vars[key] = val
                else:
                    dyn_vars[key] = val

            # reduce function
            if reduce_func is None:
                reduce_func = jnp.concatenate

            # static broadcast-ed arguments
            if static_broadcasted_argnums is None:
                static_broadcasted_argnums = ()
            elif isinstance(static_broadcasted_argnums, int):
                static_broadcasted_argnums = (static_broadcasted_argnums + 2, )
            elif isinstance(static_broadcasted_argnums, (tuple, list)):
                static_broadcasted_argnums = tuple(
                    argnum + 2 for argnum in static_broadcasted_argnums)
            assert isinstance(static_broadcasted_argnums, (tuple, list))

            # jit functions
            for key in func.steps.keys():
                step = func.steps[key]
                func.steps[key] = _make_pmap(
                    dyn_vars=dyn_vars,
                    rand_vars=rand_vars,
                    func=step,
                    axis_name=axis_name,
                    in_axes=in_axes,
                    out_axes=out_axes,
                    static_broadcasted_argnums=static_broadcasted_argnums,
                    devices=devices,
                    backend=backend,
                    axis_size=axis_size,
                    donate_argnums=donate_argnums,
                    global_arg_shapes=global_arg_shapes,
                    reduce_func=reduce_func,
                    f_name=key)
            return func

    if callable(func):
        if dyn_vars is not None:
            dyn_vars = dyn_vars
        elif isinstance(func, Base):  # Base has '__call__()' implementation
            dyn_vars = func.vars().unique()
        elif hasattr(func, '__self__'):
            if isinstance(func.__self__, Base):
                dyn_vars = func.__self__.vars().unique()

        if dyn_vars is None:
            return jax.pmap(
                func,
                axis_name=axis_name,
                in_axes=in_axes,
                out_axes=out_axes,
                static_broadcasted_argnums=static_broadcasted_argnums,
                devices=devices,
                backend=backend,
                axis_size=axis_size,
                donate_argnums=donate_argnums,
                global_arg_shapes=global_arg_shapes)
        else:
            # dynamical variables
            all_vars = dyn_vars
            dyn_vars = TensorCollector()
            rand_vars = TensorCollector()
            for key, val in all_vars.items():
                if isinstance(val, RandomState):
                    rand_vars[key] = val
                else:
                    dyn_vars[key] = val

            # static broadcast-ed arguments
            if static_broadcasted_argnums is None:
                static_broadcasted_argnums = ()
            elif isinstance(static_broadcasted_argnums, int):
                static_broadcasted_argnums = (static_broadcasted_argnums + 2, )
            elif isinstance(static_broadcasted_argnums, (tuple, list)):
                static_broadcasted_argnums = tuple(
                    argnum + 2 for argnum in static_broadcasted_argnums)
            assert isinstance(static_broadcasted_argnums, (tuple, list))

            # reduce function
            if reduce_func is None:
                reduce_func = jnp.concatenate

            # jit function
            func.__call__ = _make_pmap(
                dyn_vars=dyn_vars,
                rand_vars=rand_vars,
                func=func,
                axis_name=axis_name,
                in_axes=in_axes,
                out_axes=out_axes,
                static_broadcasted_argnums=static_broadcasted_argnums,
                devices=devices,
                backend=backend,
                axis_size=axis_size,
                donate_argnums=donate_argnums,
                global_arg_shapes=global_arg_shapes,
                reduce_func=reduce_func)
            return func

    else:
        raise errors.BrainPyError(
            f'Only support instance of {Base.__name__}, or a callable function, '
            f'but we got {type(func)}.')
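
Likewise, this wrapper builds on `jax.pmap`; a minimal plain-JAX sketch, where the leading axis of the input must equal the number of local devices:

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
xs = jnp.arange(n_dev * 4.).reshape(n_dev, 4)   # one chunk of 4 values per device
sums = jax.pmap(lambda chunk: chunk.sum())(xs)  # each device reduces its own chunk
print(sums.shape)                               # (n_dev,)
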
Example #18
def line_plot(ts,
              val_matrix,
              plot_ids=None,
              ax=None,
              xlim=None,
              ylim=None,
              xlabel='Time (ms)',
              ylabel=None,
              legend=None,
              title=None,
              show=False,
              **kwargs):
    """Show the specified value in the given object (Neurons or Synapses.)

  Parameters
  ----------
  ts : np.ndarray
      The time steps.
  val_matrix : np.ndarray
      The value matrix which records the history trajectory.
      It can be easily accessed by specifying the ``monitors``
      of NeuGroup/SynConn by:
      ``neu/syn = NeuGroup/SynConn(..., monitors=[k1, k2])``
  plot_ids : None, int, tuple, list
      The index of the value to plot.
  ax : None, Axes
      The figure to plot.
  xlim : list, tuple
      The xlim.
  ylim : list, tuple
      The ylim.
  xlabel : str
      The xlabel.
  ylabel : str
      The ylabel.
  legend : str
      The prefix of the legend labels.
  show : bool
      Whether to show the figure.
  """
    # get plot_ids
    if plot_ids is None:
        plot_ids = [0]
    elif isinstance(plot_ids, int):
        plot_ids = [plot_ids]
    if not isinstance(plot_ids, (list, tuple)) and \
        not (isinstance(plot_ids, np.ndarray) and np.ndim(plot_ids) == 1):
        raise errors.BrainPyError(
            f'"plot_ids" specifies the value index to plot, it must '
            f'be a list/tuple/1D numpy.ndarray, not {type(plot_ids)}.')

    # get ax
    if ax is None:
        ax = plt

    val_matrix = val_matrix.reshape((val_matrix.shape[0], -1))
    # change data
    val_matrix = np.asarray(val_matrix)
    ts = np.asarray(ts)

    # plot
    if legend:
        for idx in plot_ids:
            ax.plot(ts, val_matrix[:, idx], label=f'{legend}-{idx}', **kwargs)
    else:
        for idx in plot_ids:
            ax.plot(ts, val_matrix[:, idx], **kwargs)

    # legend
    if legend:
        ax.legend()

    # xlim
    if xlim is not None:
        plt.xlim(xlim[0], xlim[1])

    # ylim
    if ylim is not None:
        plt.ylim(ylim[0], ylim[1])

    # xlabel
    if xlabel:
        plt.xlabel(xlabel)

    # ylabel
    if ylabel:
        plt.ylabel(ylabel)

    # title
    if title:
        plt.title(title)

    # show
    if show:
        plt.show()
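
A self-contained sketch with synthetic traces, mirroring what line_plot does for ``plot_ids=[0, 1]`` (NumPy and Matplotlib only):

import numpy as np
import matplotlib.pyplot as plt

ts = np.arange(0., 100., 0.1)                                         # time axis (ms)
val_matrix = np.stack([np.sin(0.1 * ts), np.cos(0.1 * ts)], axis=1)   # (time, 2 traces)
for idx in (0, 1):
    plt.plot(ts, val_matrix[:, idx], label=f'V-{idx}')
plt.xlabel('Time (ms)')
plt.legend()
plt.show()
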
Example #19
def jit(func,
        dyn_vars=None,
        static_argnames=None,
        device=None,
        auto_infer=True):
    """JIT (Just-In-Time) compilation for class objects.

  This function can Just-In-Time compile a pure function, but it can also
  JIT compile a :py:class:`brainpy.DynamicalSystem`, a
  :py:class:`brainpy.Base` object, or a bound method of a
  :py:class:`brainpy.Base` object.

  .. note::
    There are several notes when using JIT compilation.

    1. Avoid using scalar in a Variable, TrainVar, etc.

    For example,

    >>> import brainpy as bp
    >>> import brainpy.math as bm
    >>>
    >>> class Test(bp.Base):
    >>>   def __init__(self):
    >>>     super(Test, self).__init__()
    >>>     self.a = bm.Variable(1.)  # Avoid! DO NOT USE!
    >>>   def __call__(self, *args, **kwargs):
    >>>     self.a += 1.

    The above usage is deprecated, because it may cause several errors.
    Instead, we recommend you define the scalar value variable as:

    >>> class Test(bp.Base):
    >>>   def __init__(self):
    >>>     super(Test, self).__init__()
    >>>     self.a = bm.Variable(bm.array([1.]))  # wrapping the scalar in an array is recommended
    >>>   def __call__(self, *args, **kwargs):
    >>>     self.a += 1.

    Here, an array is recommended for updating the variable ``a``.

    2. ``jit`` compilation in ``brainpy.math`` does not support `static_argnums`.
       Instead, users should use `static_argnames`, and call the jitted function with
       keywords like ``jitted_func(arg1=var1, arg2=var2)``. For example,

    >>> def f(a, b, c=1.):
    >>>   if c > 0.: return a + b
    >>>   else: return a * b
    >>>
    >>> # ERROR! https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit
    >>> bm.jit(f)(1, 2, 0)
    jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where
    concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)
    >>> # this is right
    >>> bm.jit(f, static_argnames='c')(1, 2, c=0)
    DeviceArray(2, dtype=int32, weak_type=True)

  Examples
  --------

  You can JIT a :py:class:`brainpy.DynamicalSystem`

  >>> import brainpy as bp
  >>>
  >>> class LIF(bp.NeuGroup):
  >>>   pass
  >>> lif = bp.math.jit(LIF(10))

  You can JIT a :py:class:`brainpy.Base` object with ``__call__()`` implementation.

  >>> mlp = bp.layers.GRU(100, 200)
  >>> jit_mlp = bp.math.jit(mlp)

  You can also JIT a bound method of a :py:class:`brainpy.Base` object.

  >>> class Hello(bp.Base):
  >>>   def __init__(self):
  >>>     super(Hello, self).__init__()
  >>>     self.a = bp.math.Variable(bp.math.array(10.))
  >>>     self.b = bp.math.Variable(bp.math.array(2.))
  >>>   def transform(self):
  >>>     return self.a ** self.b
  >>>
  >>> test = Hello()
  >>> bp.math.jit(test.transform)

  Further, you can JIT a normal function, just as you would in JAX.

  >>> @bp.math.jit
  >>> def selu(x, alpha=1.67, lmbda=1.05):
  >>>   return lmbda * bp.math.where(x > 0, x, alpha * bp.math.exp(x) - alpha)


  Parameters
  ----------
  func : Base, function, callable
    The instance of Base or a function.
  dyn_vars : optional, dict, tuple, list, JaxArray
    These variables will be changed in the function, or needed in the computation.
  static_argnames : optional, str, list, tuple, dict
    An optional string or collection of strings specifying which named arguments to treat
    as static (compile-time constant). See the comment on ``static_argnums`` for details.
    If not provided but ``static_argnums`` is set, the default is based on calling
    ``inspect.signature(fun)`` to find corresponding named arguments.
  device: optional, Any
    This is an experimental feature and the API is likely to change.
    Optional, the Device the jitted function will run on. (Available devices
    can be retrieved via :py:func:`jax.devices`.) The default is inherited
    from XLA's DeviceAssignment logic and is usually to use
    ``jax.devices()[0]``.
  auto_infer : bool
    Automatically infer the dynamical variables.

  Returns
  -------
  func : Any
    A wrapped version of Base object or function, set up for just-in-time compilation.
  """
    from brainpy.building.brainobjects import DynamicalSystem

    if isinstance(func, DynamicalSystem):
        if len(func.steps):  # DynamicalSystem has step functions

            # dynamical variables
            if dyn_vars is None:
                if auto_infer:
                    dyn_vars = func.vars().unique()
                else:
                    dyn_vars = TensorCollector()
            if isinstance(dyn_vars, JaxArray):
                dyn_vars = TensorCollector({'_': dyn_vars})
            elif isinstance(dyn_vars, dict):
                dyn_vars = TensorCollector(dyn_vars)
            elif isinstance(dyn_vars, (tuple, list)):
                dyn_vars = TensorCollector(
                    {f'_v{i}': v
                     for i, v in enumerate(dyn_vars)})
            else:
                raise ValueError(f'Unknown type of "dyn_vars": {type(dyn_vars)}')

            # static arguments by name
            if static_argnames is None:
                static_argnames = {key: None for key in func.steps.keys()}
            elif isinstance(static_argnames, str):
                static_argnames = {
                    key: (static_argnames, )
                    for key in func.steps.keys()
                }
            elif isinstance(static_argnames, (tuple, list)) and isinstance(
                    static_argnames[0], str):
                static_argnames = {
                    key: static_argnames
                    for key in func.steps.keys()
                }
            assert isinstance(static_argnames, dict)

            # jit functions
            for key in list(func.steps.keys()):
                jitted_func = _make_jit(vars=dyn_vars,
                                        func=func.steps[key],
                                        static_argnames=static_argnames[key],
                                        device=device,
                                        f_name=key)
                func.steps.replace(key, jitted_func)
            return func

    if callable(func):
        if dyn_vars is not None:
            if isinstance(dyn_vars, JaxArray):
                dyn_vars = TensorCollector({'_': dyn_vars})
            elif isinstance(dyn_vars, dict):
                dyn_vars = TensorCollector(dyn_vars)
            elif isinstance(dyn_vars, (tuple, list)):
                dyn_vars = TensorCollector(
                    {f'_v{i}': v
                     for i, v in enumerate(dyn_vars)})
            else:
                raise ValueError(f'Unknown type of "dyn_vars": {type(dyn_vars)}')
        else:
            if auto_infer:
                if isinstance(func, Base):
                    dyn_vars = func.vars().unique()
                elif hasattr(func, '__self__') and isinstance(
                        func.__self__, Base):
                    dyn_vars = func.__self__.vars().unique()
                else:
                    dyn_vars = TensorCollector()
            else:
                dyn_vars = TensorCollector()

        if len(dyn_vars) == 0:  # pure function
            return jax.jit(func,
                           static_argnames=static_argnames,
                           device=device)

        else:  # Base object which implements __call__, or a bound method of a Base object
            return _make_jit(vars=dyn_vars,
                             func=func,
                             static_argnames=static_argnames,
                             device=device)

    else:
        raise errors.BrainPyError(
            f'Only support instance of {Base.__name__}, or a callable '
            f'function, but we got {type(func)}.')
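
A runnable plain-`jax.jit` sketch of the ``static_argnames`` behaviour described in the docstring; note that the static argument is passed by keyword:

import jax

def f(a, b, c=1.):
    # Python-level branching on `c` requires `c` to be a static (concrete) value.
    return a + b if c > 0. else a * b

f_jit = jax.jit(f, static_argnames='c')
print(f_jit(1., 2., c=0.))  # 2.0; recompiles for each distinct value of c
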
Example #20
  def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
                       tol_aux=1e-8, loss_screen=None):
    utils.output('I am making bifurcation analysis ...')

    xs = self.resolutions[self.x_var]
    vps = bm.meshgrid(xs, *tuple(self.resolutions[p] for p in self.target_par_names))
    vps = tuple(jnp.moveaxis(vp.value, 0, 1).flatten() for vp in vps)
    candidates = vps[0]
    pars = vps[1:]
    fixed_points, _, pars = self._get_fixed_points(candidates, *pars,
                                                   tol_aux=tol_aux,
                                                   loss_screen=loss_screen,
                                                   num_seg=len(xs))
    dfxdx = np.asarray(self.F_vmap_dfxdx(jnp.asarray(fixed_points), *pars))
    pars = tuple(np.asarray(p) for p in pars)

    if with_plot:
      if len(self.target_pars) == 1:
        container = {c: {'p': [], 'x': []} for c in stability.get_1d_stability_types()}

        # fixed point
        for p, x, dx in zip(pars[0], fixed_points, dfxdx):
          fp_type = stability.stability_analysis(dx)
          container[fp_type]['p'].append(p)
          container[fp_type]['x'].append(x)

        # visualization
        plt.figure(self.x_var)
        for fp_type, points in container.items():
          if len(points['x']):
            plot_style = stability.plot_scheme[fp_type]
            plt.plot(points['p'], points['x'], '.', **plot_style, label=fp_type)
        plt.xlabel(self.target_par_names[0])
        plt.ylabel(self.x_var)

        scale = (self.lim_scale - 1) / 2
        plt.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
        plt.ylim(*utils.rescale(self.target_vars[self.x_var], scale=scale))

        plt.legend()
        if show:
          plt.show()

      elif len(self.target_pars) == 2:
        container = {c: {'p0': [], 'p1': [], 'x': []} for c in stability.get_1d_stability_types()}

        # fixed point
        for p0, p1, x, dx in zip(pars[0], pars[1], fixed_points, dfxdx):
          fp_type = stability.stability_analysis(dx)
          container[fp_type]['p0'].append(p0)
          container[fp_type]['p1'].append(p1)
          container[fp_type]['x'].append(x)

        # visualization
        fig = plt.figure(self.x_var)
        ax = fig.add_subplot(projection='3d')
        for fp_type, points in container.items():
          if len(points['x']):
            plot_style = stability.plot_scheme[fp_type]
            xs = points['p0']
            ys = points['p1']
            zs = points['x']
            ax.scatter(xs, ys, zs, **plot_style, label=fp_type)

        ax.set_xlabel(self.target_par_names[0])
        ax.set_ylabel(self.target_par_names[1])
        ax.set_zlabel(self.x_var)

        scale = (self.lim_scale - 1) / 2
        ax.set_xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
        ax.set_ylim(*utils.rescale(self.target_pars[self.target_par_names[1]], scale=scale))
        ax.set_zlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))

        ax.grid(True)
        ax.legend()
        if show:
          plt.show()

      else:
        raise errors.BrainPyError(f'Cannot visualize co-dimension {len(self.target_pars)} '
                                  f'bifurcation.')
    if with_return:
      return fixed_points, pars, dfxdx