예제 #1
0
    def __init__(self,
                 f_cell,
                 f_type='continuous',
                 f_loss_batch=None,
                 verbose=True):
        """Build the loss and Jacobian functions used by the fixed-point search.

        Parameters
        ----------
        f_cell : callable
            For ``f_type='continuous'``, the continuous derivative function,
            whose zeros are the fixed points. For ``f_type='discrete'``, the
            discrete update function, whose fixed points satisfy
            ``h == f_cell(h)``.
        f_type : str
            ``'continuous'`` (default) or ``'discrete'``.
        f_loss_batch : callable, optional
            User-supplied batched loss. When ``None``, a mean-squared residual
            matching ``f_type`` is built automatically.
        verbose : bool
            Stored on ``self.verbose``.

        Raises
        ------
        AnalyzerError
            If ``f_type`` is neither ``'discrete'`` nor ``'continuous'``.
        """
        self.verbose = verbose
        if f_type not in ['discrete', 'continuous']:
            raise AnalyzerError(
                f'Only support "continuous" (continuous derivative function) or '
                f'"discrete" (discrete update function), not {f_type}.')

        # functions
        self.f_cell = f_cell

        # Single-point loss. It is chosen from ``f_type`` independently of
        # ``f_loss_batch`` so that supplying a custom batched loss does not
        # change this objective. A discrete map's fixed points satisfy
        # ``h == f_cell(h)``, not ``f_cell(h) == 0``.
        if f_type == 'discrete':
            self.f_loss = bm.jit(lambda h: bm.mean((h - f_cell(h))**2))
        else:
            self.f_loss = bm.jit(lambda h: bm.mean(f_cell(h)**2))

        # Batched loss: mean-squared residual reduced over axis 1, with
        # ``f_cell`` vectorized via ``bm.vmap``.
        # (Assumes ``h`` is (batch, dim) — TODO confirm against callers.)
        if f_loss_batch is not None:
            self.f_loss_batch = f_loss_batch
        elif f_type == 'discrete':
            self.f_loss_batch = bm.jit(lambda h: bm.mean(
                (h - bm.vmap(f_cell, auto_infer=False)(h))**2, axis=1))
        else:
            self.f_loss_batch = bm.jit(lambda h: bm.mean(
                (bm.vmap(f_cell, auto_infer=False)(h))**2, axis=1))

        # Jacobian of ``f_cell``, vectorized over a batch of points and jitted.
        self.f_jacob_batch = bm.jit(bm.vmap(bm.jacobian(f_cell)))

        # essential variables (all start as None)
        self._losses = None
        self._fixed_points = None
        self._selected_ids = None
        self.opt_losses = None
예제 #2
0
 def F_vmap_jacobian(self):
   """Return the cached, jitted and vmapped Jacobian of the 2-D vector field.

   The vector field maps ``xy`` to ``[F_fx(x, y, *args), F_fy(x, y, *args)]``.
   Its Jacobian is built with ``bm.jacobian``, vectorized with ``bm.vmap`` and
   compiled with ``bm.jit`` on ``self.jit_device``. The result is built only
   on the first call and stored in ``self.analyzed_results`` under
   ``C.F_vmap_jacobian``.
   """
   key = C.F_vmap_jacobian
   if key in self.analyzed_results:
     return self.analyzed_results[key]

   def vector_field(xy, *args):
     # Stack both components evaluated at the point (xy[0], xy[1]).
     return jnp.array([self.F_fx(xy[0], xy[1], *args),
                       self.F_fy(xy[0], xy[1], *args)])

   jac_fn = bm.jit(bm.vmap(bm.jacobian(vector_field)), device=self.jit_device)
   self.analyzed_results[key] = jac_fn
   return self.analyzed_results[key]
예제 #3
0
    def find_fps_with_opt_solver(self, candidates, opt_method=None):
        """Optimize fixed points with nonlinear optimization solvers.

        Every row of ``candidates`` is the initial point of an independent
        minimization of ``self.f_loss``. The minimizations are vectorized with
        ``bm.vmap`` and compiled with ``bm.jit``. Only runs whose result
        reports ``success`` are kept.

        Parameters
        ----------
        candidates : bm.JaxArray or jax.numpy.ndarray
            2-D array of initial points, one per row.
        opt_method : callable, optional
            Called as ``opt_method(f, x0)``. It must return an object exposing
            ``success``, ``x`` and ``fun``. Defaults to
            ``minimize(f, x0, method='BFGS')``.

        Raises
        ------
        ValueError
            If ``candidates`` is not a 2-D ``bm.JaxArray`` or
            ``jax.numpy.ndarray``. This is checked explicitly rather than with
            ``assert``, so the validation is not stripped under ``python -O``.

        Notes
        -----
        The converged points, their loss values and their row indices within
        ``candidates`` are stored as NumPy arrays on ``self._fixed_points``,
        ``self._losses`` and ``self._selected_ids``.
        """
        if not isinstance(candidates, (bm.JaxArray, jax.numpy.ndarray)):
            raise ValueError(
                f'"candidates" must be a bm.JaxArray or jax.numpy.ndarray, '
                f'not {type(candidates)}.')
        if bm.ndim(candidates) != 2:
            raise ValueError(
                f'"candidates" must be a 2-D array (one initial point per row), '
                f'but got {bm.ndim(candidates)} dimension(s).')
        if opt_method is None:
            opt_method = lambda f, x0: minimize(f, x0, method='BFGS')
        if self.verbose:
            print("Optimizing to find fixed points:")
        # One optimizer run per initial point, vectorized over the rows.
        f_opt = bm.jit(bm.vmap(lambda x0: opt_method(self.f_loss, x0)))
        res = f_opt(bm.as_device_array(candidates))
        # Keep only the runs that the solver reports as successful.
        valid_ids = jax.numpy.where(res.success)[0]
        self._fixed_points = np.asarray(res.x[valid_ids])
        self._losses = np.asarray(res.fun[valid_ids])
        self._selected_ids = np.asarray(valid_ids)
        if self.verbose:
            print(
                f'    '
                f'Found {len(valid_ids)} fixed points from {len(candidates)} initial points.'
            )
예제 #4
0
def get_sign2(f, *xyz, args=()):
  """Evaluate ``sign(f)`` on the Cartesian grid spanned by the vectors in ``xyz``.

  Parameters
  ----------
  f : callable
      Called as ``f(*grid_coords, *args)``. It is vectorized with ``bm.vmap``
      over the flattened grid coordinates (axis 0 of each), while ``args`` are
      passed through unbatched.
  *xyz : array-like
      1-D coordinate vectors, one per grid axis. ``bm.JaxArray`` inputs are
      unwrapped.
  args : sequence
      Extra non-batched arguments appended after the coordinates.

  Returns
  -------
  jnp.ndarray
      Array of shape ``(len(xyz[0]), len(xyz[1]), ...)`` holding ``sign(f)``
      at each grid point, indexed in the order of ``xyz``.
  """
  args = tuple(args)
  # Every flattened coordinate array is 1-D, so each is mapped over axis 0.
  # Using ``range(len(xyz))`` here would request axis 1, 2, ... of 1-D arrays.
  in_axes = (0,) * len(xyz) + (None,) * len(args)
  f = bm.jit(bm.vmap(f_without_jaxarray_return(f), in_axes=in_axes))
  xyz = tuple((v.value if isinstance(v, bm.JaxArray) else v) for v in xyz)
  # 'ij' indexing gives grids of shape (len(x), len(y), ...), matching ``shape``.
  # It also works for a single coordinate, where swapping axes 0/1 would fail.
  XYZ = tuple(v.flatten() for v in jnp.meshgrid(*xyz, indexing='ij'))
  # ``reshape`` needs a concrete tuple; a generator would be consumed/rejected.
  shape = tuple(len(v) for v in xyz)
  return jnp.sign(f(*(XYZ + args))).reshape(shape)
예제 #5
0
def roots_of_1d_by_xy(f, starts, ends, args):
  """Solve one bracketed 1-D root problem per entry and pair each root with its ``args``.

  ``jax_brentq`` is vectorized with ``bm.vmap`` over ``starts``, ``ends`` and
  ``args``, then compiled with ``bm.jit``. Only entries whose status equals
  ``ECONVERGED`` are returned.

  Returns
  -------
  tuple
      ``(xs, ys)``: the converged roots and the matching entries of ``args``.
  """
  solver = bm.jit(bm.vmap(jax_brentq(f_without_jaxarray_return(f))))
  result = solver(starts, ends, (args,))
  converged = jnp.where(result['status'] == ECONVERGED)[0]
  return result['root'][converged], args[converged]
예제 #6
0
def brentq_roots(f, starts, ends, *vmap_args, args=()):
  """Find one root of ``f`` per bracket ``[starts[i], ends[i]]`` with ``jax_brentq``.

  Parameters
  ----------
  f : callable
      Function whose roots are sought.
  starts, ends : array-like
      Bracket endpoints, batched along axis 0.
  *vmap_args : array-like
      Extra arguments batched along axis 0 together with the brackets.
  args : sequence
      Extra arguments shared (not batched) by every bracket.

  Returns
  -------
  tuple
      ``(roots, vmap_args)``: the roots whose status equals ``ECONVERGED``,
      and each ``vmap_args`` entry filtered to the same brackets.
  """
  args = tuple(args)
  all_args = vmap_args + args
  if all_args:
    in_axes = (0, 0, tuple([0] * len(vmap_args)) + tuple([None] * len(args)))
    vmap_f_opt = bm.jit(bm.vmap(jax_brentq(f), in_axes=in_axes))
    res = vmap_f_opt(starts, ends, all_args)
  else:
    # vmap requires exactly one in_axes entry per positional argument, so the
    # extra-args slot must be omitted when nothing is passed in it.
    vmap_f_opt = bm.jit(bm.vmap(jax_brentq(f), in_axes=(0, 0)))
    res = vmap_f_opt(starts, ends)
  valid_idx = jnp.where(res['status'] == ECONVERGED)[0]
  roots = res['root'][valid_idx]
  vmap_args = tuple(a[valid_idx] for a in vmap_args)
  return roots, vmap_args
예제 #7
0
def roots_of_1d_by_x(f, candidates, args=()):
  """Find the roots of ``f(x, *args)`` at and between the sampled ``candidates``.

  ``f`` is evaluated at every candidate. Candidates where the value is exactly
  zero are returned as roots. Each pair of neighbouring candidates with a
  strict sign change brackets another root, which is refined with the
  ``jax_brentq`` solver (vectorized with ``bm.vmap``). Only refinements whose
  status equals ``ECONVERGED`` are kept.

  Parameters
  ----------
  f : callable
      Called as ``f(x, *args)`` on the whole candidate array.
  candidates : array-like
      1-D sample points. ``bm.JaxArray`` is unwrapped. Presumably sorted, since
      neighbouring entries are treated as brackets.
  args : sequence
      Extra arguments shared by every evaluation. ``bm.JaxArray`` entries are
      unwrapped.

  Returns
  -------
  jnp.ndarray
      Exact-zero candidates, followed by the converged refined roots.
  """
  f = f_without_jaxarray_return(f)
  candidates = candidates.value if isinstance(candidates, bm.JaxArray) else candidates
  # Test each argument's own type. ``candidates`` is already unwrapped above,
  # so testing it here would never unwrap any argument.
  args = tuple(a.value if isinstance(a, bm.JaxArray) else a for a in args)
  vals = f(candidates, *args)
  signs = jnp.sign(vals)
  # Candidates that are already exact roots.
  zero_sign_idx = jnp.where(signs == 0)[0]
  fps = candidates[zero_sign_idx]
  # Index i marks a strict sign change between candidates[i] and candidates[i + 1].
  candidate_ids = jnp.where(signs[:-1] * signs[1:] < 0)[0]
  if candidate_ids.size == 0:
    return fps
  starts = candidates[candidate_ids]
  ends = candidates[candidate_ids + 1]
  f_opt = bm.jit(bm.vmap(jax_brentq(f), in_axes=(0, 0, None)))
  res = f_opt(starts, ends, args)
  valid_idx = jnp.where(res['status'] == ECONVERGED)[0]
  fps2 = res['root'][valid_idx]
  return jnp.concatenate([fps, fps2])
예제 #8
0
 def F_vmap_dfxdx(self):
   """Return the cached, jitted and vmapped derivative of ``F_fx`` w.r.t. its first argument.

   The gradient comes from ``bm.vector_grad(self.F_fx, argnums=0)``. It is
   vectorized with ``bm.vmap``, compiled with ``bm.jit`` on
   ``self.jit_device``, and built only on the first call. It is stored in
   ``self.analyzed_results`` under ``C.F_vmap_dfxdx``.
   """
   key = C.F_vmap_dfxdx
   if key not in self.analyzed_results:
     dfx_dx = bm.vector_grad(self.F_fx, argnums=0)
     self.analyzed_results[key] = bm.jit(bm.vmap(dfx_dx), device=self.jit_device)
   return self.analyzed_results[key]
예제 #9
0
 def F_vmap_brentq_fy(self):
     """Return the cached, jitted and vmapped ``jax_brentq`` root solver for ``F_fy``.

     The solver is built only on the first call and stored in
     ``self.analyzed_results`` under ``C.F_vmap_brentq_fy``.
     """
     key = C.F_vmap_brentq_fy
     if key in self.analyzed_results:
         return self.analyzed_results[key]
     # NOTE(review): unlike the sibling cached builders, this one does not pass
     # ``device=self.jit_device`` to ``bm.jit``. Confirm whether that is intended.
     solver = utils.jax_brentq(self.F_fy)
     self.analyzed_results[key] = bm.jit(bm.vmap(solver))
     return self.analyzed_results[key]
예제 #10
0
def _run_and_animate(model, stimulus, duration):
    """Run ``model`` with ``stimulus`` as its ``input``, then animate the monitored ``r``.

    The input is passed to ``bp.StructRunner`` as ``['input', stimulus, 'iter']``
    and ``r`` is monitored. The runner is returned so its monitors stay
    accessible.
    """
    sim = bp.StructRunner(model,
                          inputs=['input', stimulus, 'iter'],
                          monitors=['r'],
                          dyn_vars=model.vars())
    sim.run(duration)
    bp.visualize.animate_2D(values=sim.mon.r,
                            net_size=(model.length, model.length))
    return sim


cann = CANN2D(length=512, k=0.1)
cann.show_conn()

# Encoding: two input sections, the stimulus at position (0, 0) for 10.,
# then 0. for 20.
Iext, length = bp.inputs.section_input(
    values=[cann.get_stimulus_by_pos([0., 0.]), 0.],
    durations=[10., 20.],
    return_length=True)
runner = _run_and_animate(cann, Iext, length)

# Tracking: the stimulus position ramps from -pi to pi on both coordinates.
length = 20
positions = bp.inputs.ramp_input(-bm.pi, bm.pi, duration=length, t_start=0)
positions = bm.stack([positions, positions]).T
Iext = bm.vmap(cann.get_stimulus_by_pos)(positions)
runner = _run_and_animate(cann, Iext, length)