Esempio n. 1
0
    def find_fps_with_opt_solver(self, candidates, opt_method=None):
        """Refine candidate points into fixed points with a nonlinear optimizer.

        Every row of ``candidates`` is the initial guess of one minimization of
        ``self.f_loss``. All minimizations run in one jitted, vmapped call, and
        only the runs whose result reports ``success`` are kept.

        Parameters
        ----------
        candidates : bm.JaxArray, jax.numpy.ndarray
          A 2-D array holding one initial point per row.
        opt_method : callable, optional
          Called as ``opt_method(f, x0)``; its result must expose ``x``,
          ``fun`` and ``success``. Defaults to
          ``minimize(f, x0, method='BFGS')``.

        Notes
        -----
        Stores the successful solutions, their losses and their row indices
        in ``self._fixed_points``, ``self._losses`` and
        ``self._selected_ids`` (as NumPy arrays).
        """
        assert (bm.ndim(candidates) == 2
                and isinstance(candidates, (bm.JaxArray, jax.numpy.ndarray)))

        if opt_method is None:
            def opt_method(f, x0):
                return minimize(f, x0, method='BFGS')

        if self.verbose:
            print("Optimizing to find fixed points:")

        def _solve_one(x0):
            # One optimization run starting from a single candidate row.
            return opt_method(self.f_loss, x0)

        batched_solver = bm.jit(bm.vmap(_solve_one))
        result = batched_solver(bm.as_device_array(candidates))

        # Indices of the runs the optimizer reports as converged.
        kept = jax.numpy.where(result.success)[0]
        self._fixed_points = np.asarray(result.x[kept])
        self._losses = np.asarray(result.fun[kept])
        self._selected_ids = np.asarray(kept)

        if self.verbose:
            print(f'    Found {len(kept)} fixed points from '
                  f'{len(candidates)} initial points.')
Esempio n. 2
0
    def build_monitors(self, show_code=False):
        """Build a function returning the current values of all monitors.

        Each monitored attribute must be a ``math.Variable``. Its ``value``
        (flattened when it is not 1-D, and optionally indexed) becomes one
        element of the tuple returned by the generated function. Monitor
        intervals are not supported by this runner.

        Parameters
        ----------
        show_code : bool
          If True, print the generated source and the scope it is compiled in.

        Returns
        -------
        callable
          ``func(_t, _dt)`` returning a tuple of monitored values, or a
          function returning ``None`` when nothing is monitored.

        Raises
        ------
        RunningError
          If a monitored attribute is not a ``math.Variable``.
        ValueError
          If a monitor specifies an ``interval``.
        """
        monitors = utils.check_and_format_monitors(host=self.target, mon=self.mon)

        host = self.target
        scope = dict(sys=sys)
        exprs = []
        for key, target, variable, idx, interval in monitors:
            scope[host.name] = host
            scope[target.name] = target

            # Resolve the (possibly dotted) attribute path on the target.
            data = target
            for attr in variable.split('.'):
                data = getattr(data, attr)

            if not isinstance(data, math.Variable):
                raise RunningError(
                    f'"{key}" in {target} is not a dynamically changed Variable, '
                    f'its value will not change, we think there is no need to '
                    f'monitor its trajectory.')

            suffix = '.value' if math.ndim(data) == 1 else '.value.flatten()'
            expr = f'{target.name}.{variable}{suffix}'

            # An index is injected into the scope under a key-derived name.
            if idx is not None:
                idx_name = f'_{key.replace(".", "_")}_idx'
                expr = f'{expr}[{idx_name}]'
                scope[idx_name] = math.asarray(idx).value

            exprs.append(expr)
            if interval is not None:
                raise ValueError(
                    f'Running with "{self.__class__.__name__}" does '
                    f'not support "interval" in the monitor.')

        if not exprs:
            return lambda _t, _dt: None

        # The trailing ", " makes the generated function always return a tuple.
        joined = ", ".join(exprs) + ", "
        code_lines = [f'return {joined}']
        # Snapshot before compiling, since the helper receives ``scope`` itself.
        scope_snapshot = dict(scope)
        code, func = tools.code_lines_to_func(lines=code_lines,
                                              func_name=_mon_func_name,
                                              func_args=['_t', '_dt'],
                                              scope=scope)
        if show_code:
            print(code)
            print()
            pprint(scope_snapshot)
            print()
        return func
Esempio n. 3
0
 def make_conn(self, x):
     """Build the Gaussian connection matrix between the positions in ``x``.

     ``self.dist`` is applied to the matrix of pairwise differences
     ``x[i] - x[j]`` (presumably elementwise — confirm against ``dist``), and
     each resulting distance ``d`` is mapped to
     ``J0 * exp(-0.5 * (d / a) ** 2) / (sqrt(2 * pi) * a)``.

     Parameters
     ----------
     x : 1-D array
       Positions; ``bm.ndim(x)`` must be 1.

     Returns
     -------
     Jxx : array
       The connection matrix, ``(len(x), len(x))`` when ``self.dist``
       preserves shape.
     """
     assert bm.ndim(x) == 1
     # (n, 1) - (1, n) broadcasts to the full (n, n) pairwise difference, so an
     # explicit n x n copy of ``x`` (via ``bm.repeat``) is unnecessary.
     x_left = bm.reshape(x, (-1, 1))
     x_right = bm.reshape(x, (1, -1))
     d = self.dist(x_left - x_right)
     Jxx = self.J0 * bm.exp(
         -0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a)
     return Jxx
Esempio n. 4
0
    def compute_jacobians(self, points):
        """Evaluate the Jacobian matrix at each of the given points.

        Parameters
        ----------
        points : np.ndarray, bm.JaxArray, jax.ndarray
          Either one point of shape ``(num_dim,)`` or a batch of points of
          shape ``(num_point, num_dim)``.

        Returns
        -------
        jacobians : bm.JaxArray
          Output of ``self.f_jacob_batch`` on the batch, documented as having
          shape ``(num_point, num_dim, num_dim)``.
        """
        # A single 1-D point is promoted to a batch holding just that point.
        batch = bm.asarray([points]) if bm.ndim(points) == 1 else points
        assert bm.ndim(batch) == 2
        return self.f_jacob_batch(bm.asarray(batch))
Esempio n. 5
0
    def __init__(self, size, delay, dtype=None, dt=None, **kwargs):
        """Create a constant delay buffer.

        Parameters
        ----------
        size : int, tuple of int, list of int
          Shape of the delayed data.
        delay : int, float, or 1-D array-like
          Delay length in time units. An ``int``/``float`` gives a uniform
          delay. Anything else is treated as one delay per element, which is
          only supported for a 1-D ``size``.
        dtype : optional
          Data type of the buffer.
        dt : float, optional
          Time step; defaults to ``bm.get_dt()``.
        **kwargs
          Forwarded to the parent class constructor.

        Raises
        ------
        ModelBuildError
          If ``size`` is not an int/tuple/list, or a heterogeneous delay is
          not 1-D or its length differs from ``size[0]``.
        NotImplementedError
          If a heterogeneous delay is combined with a multi-dimensional size.
        """
        # dt
        self.dt = bm.get_dt() if dt is None else dt

        # data size
        if isinstance(size, int):
            size = (size, )
        if not isinstance(size, (tuple, list)):
            raise ModelBuildError(
                f'"size" must be a tuple/list of int, but we got {type(size)}: {size}'
            )
        self.size = tuple(size)

        # delay time length (stored exactly as given by the caller)
        self.delay = delay

        # data and operations
        if isinstance(delay, (int, float)):  # uniform delay
            self.uniform_delay = True
            self.num_step = int(pm.ceil(delay / self.dt)) + 1
            # out_idx starts at row 0 and in_idx at the last row of the buffer.
            self.out_idx = bm.Variable(bm.array([0], dtype=bm.uint32))
            self.in_idx = bm.Variable(
                bm.array([self.num_step - 1], dtype=bm.uint32))
            self.data = bm.Variable(
                bm.zeros((self.num_step, ) + self.size, dtype=dtype))

        else:  # non-uniform delay
            self.uniform_delay = False
            if not len(self.size) == 1:
                raise NotImplementedError(
                    f'Currently, BrainPy only supports 1D heterogeneous '
                    f'delays, while we got the heterogeneous delay with '
                    f'{len(self.size)}-dimensions.')
            self.num = size2len(self.size)
            # Convert first, so that lists/tuples also expose ``.ndim`` and
            # ``.shape`` for the checks and error messages below.
            delay = bm.asarray(delay)
            if bm.ndim(delay) != 1:
                raise ModelBuildError(f'Only support a 1D non-uniform delay. '
                                      f'But we got {delay.ndim}D: {delay}')
            if delay.shape[0] != self.size[0]:
                raise ModelBuildError(
                    f"The first shape of the delay time size must "
                    f"be the same with the delay data size. But "
                    f"we got {delay.shape[0]} != {self.size[0]}")
            delay = bm.around(delay / self.dt)
            self.diag = bm.array(bm.arange(self.num), dtype=bm.int_)
            self.num_step = bm.array(delay, dtype=bm.uint32) + 1
            self.in_idx = bm.Variable(self.num_step - 1)
            self.out_idx = bm.Variable(bm.zeros(self.num, dtype=bm.uint32))
            # Use the normalized tuple ``self.size`` (a list ``size`` cannot be
            # concatenated to a tuple) and a concrete Python int for the
            # longest per-element delay, which sets the buffer length.
            self.data = bm.Variable(
                bm.zeros((int(self.num_step.max()), ) + self.size, dtype=dtype))

        super(ConstantDelay, self).__init__(**kwargs)
Esempio n. 6
0
def brentq_candidates(vmap_f, *values, args=()):
  """Locate grid brackets along the first axis where ``vmap_f`` changes sign.

  A full grid is built from ``values``, with the first value varying along
  axis 0. ``vmap_f`` is evaluated once on the flattened grid coordinates,
  followed by ``args``. For each pair of neighbouring first-axis points
  ``(x[i], x[i + 1])`` that share the other coordinates, the pair is kept
  when ``sign(f(x[i])) * sign(f(x[i + 1])) <= 0``.

  Returns
  -------
  x_starts, x_ends
    Left and right first-axis coordinates of every kept bracket.
  other_vals
    Tuple holding, for each remaining value, its coordinate at each bracket.
  """
  raw = tuple((v.value if isinstance(v, bm.JaxArray) else v) for v in values)
  num_x = raw[0].shape[0]

  # 'ij' indexing puts the first value on axis 0; it is the same grid the
  # default 'xy' meshgrid gives after swapping its first two axes.
  grids = tuple(g.flatten() for g in jnp.meshgrid(*raw, indexing='ij'))

  signs = jnp.sign(vmap_f(*(grids + args))).reshape((num_x, -1))
  width = signs.shape[1]

  # Pair row i with row i + 1. The last row has no successor (roll wraps
  # around to row 0), so both factors are forced to +1 there to exclude it.
  current = signs.at[-1].set(1)
  following = jnp.roll(signs, -1, axis=0).at[-1].set(1)
  hits = jnp.where((current * following).flatten() <= 0)[0]

  x_grid = grids[0]
  # In the flattened grid, index + width moves one step along the first axis.
  return x_grid[hits], x_grid[hits + width], tuple(g[hits] for g in grids[1:])
Esempio n. 7
0
    def build_monitors(self, show_code=False):
        """Build the function that records monitored variables.

        For every monitor configured in ``self.mon``, the following are taken
        into account:

        1. the monitored variable, which must be a ``math.Variable`` (its
           ``value`` is flattened when it is not 1-D);
        2. the optional index selecting part of that value;
        3. the optional interval. Without one, the value is appended to
           ``self.mon.item_contents[key]`` on every call. With one, it is
           appended (together with ``_t`` under ``key + ".t"``) only once
           ``_t`` reaches the next scheduled time, which starts at
           ``interval`` and then advances by ``interval``.

        Parameters
        ----------
        show_code : bool
          If True, print the generated source and the scope it is compiled in.

        Returns
        -------
        callable
          ``func(_t, _dt)``, or a no-op when no monitor is configured.

        Raises
        ------
        RunningError
          If a monitored attribute is not a ``math.Variable``.
        """
        monitors = utils.check_and_format_monitors(host=self.target, mon=self.mon)

        host = self.target
        lines = []
        scope = dict(sys=sys, self_mon=self.mon)
        for key, target, variable, idx, interval in monitors:
            scope[host.name] = host
            scope[target.name] = target

            # Resolve the (possibly dotted) attribute path on the target.
            data = target
            for attr in variable.split('.'):
                data = getattr(data, attr)

            if not isinstance(data, math.Variable):
                raise RunningError(
                    f'"{key}" in {target} is not a dynamically changed Variable, '
                    f'its value will not change, we think there is no need to '
                    f'monitor its trajectory.')

            # Identifier-safe form of the key for names injected into the scope.
            safe_key = key.replace(".", "_")
            suffix = '.value' if math.ndim(data) == 1 else '.value.flatten()'
            expr = f'{target.name}.{variable}{suffix}'
            if idx is not None:
                idx_name = f'_{safe_key}_idx'
                expr = f'{expr}[{idx_name}]'
                scope[idx_name] = math.asarray(idx).value

            record = f'self_mon.item_contents["{key}"].append({expr})'
            if interval is None:
                lines.append(record)
                continue

            # The next recording time lives in the function's global scope so
            # the generated code can advance it across calls.
            next_time = f'_{safe_key}_next_time'
            scope[next_time] = interval
            lines += [
                f'global {next_time}',
                f'if _t >= {next_time}:',
                f'  {record}',
                f'  self_mon.item_contents["{key}.t"].append(_t)',
                f'  {next_time} += {interval}',
            ]

        if not lines:
            return lambda _t, _dt: None

        # Snapshot before compiling, since the helper receives ``scope`` itself.
        scope_snapshot = dict(scope)
        code, func = tools.code_lines_to_func(lines=lines,
                                              func_name=_mon_func_name,
                                              func_args=['_t', '_dt'],
                                              scope=scope)
        if show_code:
            print(code)
            print()
            pprint(scope_snapshot)
            print()
        return func