def __init__(self, f_cell, f_type='continuous', f_loss_batch=None, verbose=True):
    """Set up the loss and Jacobian functions used to search for fixed points.

    Parameters
    ----------
    f_cell : callable
        The system function ``h -> f_cell(h)``. Its meaning depends on ``f_type``.
    f_type : str
        ``'continuous'`` if ``f_cell`` is a continuous derivative function
        (the residual minimized is ``f_cell(h)``), or ``'discrete'`` if
        ``f_cell`` is a discrete update function (the residual minimized is
        ``h - f_cell(h)``).
    f_loss_batch : callable, optional
        A user-provided batched loss. When omitted, a default batched loss
        (mean squared residual along axis 1) matching ``f_type`` is built.
    verbose : bool
        Whether progress messages are printed.

    Raises
    ------
    AnalyzerError
        If ``f_type`` is neither ``'discrete'`` nor ``'continuous'``.
    """
    self.verbose = verbose
    if f_type not in ['discrete', 'continuous']:
        raise AnalyzerError(
            f'Only support "continuous" (continuous derivative function) or '
            f'"discrete" (discrete update function), not {f_type}.')

    # functions
    self.f_cell = f_cell

    # The single-point loss is always derived from ``f_type``, even when a
    # custom batched loss is supplied: for a discrete map the fixed-point
    # residual is ``h - f_cell(h)``, not ``f_cell(h)``.
    if f_type == 'discrete':
        def loss(h):
            return bm.mean((h - f_cell(h)) ** 2)

        def loss_batch(h):
            return bm.mean((h - bm.vmap(f_cell, auto_infer=False)(h)) ** 2, axis=1)
    else:  # 'continuous'
        def loss(h):
            return bm.mean(f_cell(h) ** 2)

        def loss_batch(h):
            return bm.mean((bm.vmap(f_cell, auto_infer=False)(h)) ** 2, axis=1)

    self.f_loss = bm.jit(loss)
    self.f_loss_batch = bm.jit(loss_batch) if f_loss_batch is None else f_loss_batch
    self.f_jacob_batch = bm.jit(bm.vmap(bm.jacobian(f_cell)))

    # essential variables
    self._losses = None  # per-point losses of the kept fixed points (set by the solver)
    self._fixed_points = None  # fixed points kept after optimization
    self._selected_ids = None  # indices of kept points among the initial candidates
    self.opt_losses = None  # not assigned in the visible code — TODO confirm its producer
def F_vmap_jacobian(self):
    """Return the cached, jitted, vmapped Jacobian of the 2-D vector field.

    The vector field maps ``xy`` (and any extra positional args) to
    ``[F_fx(x, y, *args), F_fy(x, y, *args)]``. The compiled function is built
    once, stored in ``self.analyzed_results`` under ``C.F_vmap_jacobian``, and
    reused afterwards.
    """
    key = C.F_vmap_jacobian
    if key in self.analyzed_results:
        return self.analyzed_results[key]

    def vector_field(xy, *args):
        x, y = xy[0], xy[1]
        return jnp.array([self.F_fx(x, y, *args), self.F_fy(x, y, *args)])

    batched_jacobian = bm.jit(bm.vmap(bm.jacobian(vector_field)),
                              device=self.jit_device)
    self.analyzed_results[key] = batched_jacobian
    return self.analyzed_results[key]
def find_fps_with_opt_solver(self, candidates, opt_method=None):
    """Optimize fixed points with nonlinear optimization solvers.

    Every initial point is refined in parallel (via ``vmap``) by minimizing
    ``self.f_loss``. Only runs whose result reports ``success`` are kept.

    Parameters
    ----------
    candidates : bm.JaxArray or jax.numpy.ndarray
        2-D array of initial points, one per row.
    opt_method : callable, optional
        Called as ``opt_method(f, x0)``. It must return an object exposing
        ``x``, ``fun`` and ``success``. Defaults to
        ``minimize(f, x0, method='BFGS')``.

    Notes
    -----
    Results are written to ``self._fixed_points``, ``self._losses`` and
    ``self._selected_ids`` as NumPy arrays.
    """
    assert bm.ndim(candidates) == 2 and isinstance(
        candidates, (bm.JaxArray, jax.numpy.ndarray))

    if opt_method is not None:
        solver = opt_method
    else:
        solver = lambda f, x0: minimize(f, x0, method='BFGS')

    if self.verbose:
        print("Optimizing to find fixed points:")

    optimize_all = bm.jit(bm.vmap(lambda x0: solver(self.f_loss, x0)))
    result = optimize_all(bm.as_device_array(candidates))

    # Keep only the starting points whose optimization converged.
    kept = jax.numpy.where(result.success)[0]
    self._fixed_points = np.asarray(result.x[kept])
    self._losses = np.asarray(result.fun[kept])
    self._selected_ids = np.asarray(kept)

    if self.verbose:
        print(f' Found {len(kept)} fixed points from {len(candidates)} initial points.')
def get_sign2(f, *xyz, args=()):
    """Evaluate ``sign(f)`` on the Cartesian grid spanned by ``xyz``.

    Parameters
    ----------
    f : callable
        Called as ``f(x, y, ..., *args)`` with one scalar coordinate per grid
        dimension. It is vmapped over all grid points.
    *xyz : 1-D arrays
        Coordinate vectors, one per grid dimension.
    args : tuple
        Extra arguments passed unchanged (not mapped) to every call of ``f``.

    Returns
    -------
    jax.numpy.ndarray
        Array of shape ``(len(xyz[0]), len(xyz[1]), ...)``. Entry
        ``[i, j, ...]`` is ``sign(f(xyz[0][i], xyz[1][j], ..., *args))``.
    """
    args = tuple(args)
    # Every grid coordinate is passed as a flat 1-D vector, so each one is
    # mapped along axis 0. The shared ``args`` are broadcast (``None``).
    in_axes = (0,) * len(xyz) + (None,) * len(args)
    f = bm.jit(bm.vmap(f_without_jaxarray_return(f), in_axes=in_axes))

    xyz = tuple((v.value if isinstance(v, bm.JaxArray) else v) for v in xyz)
    # 'ij' indexing gives grids of shape (len(x), len(y), ...) directly. This
    # matches the 'xy'-meshgrid-plus-moveaxis layout and also works in 1-D.
    grids = jnp.meshgrid(*xyz, indexing='ij')
    flat_points = tuple(g.flatten() for g in grids)
    shape = tuple(len(v) for v in xyz)
    return jnp.sign(f(*(flat_points + args))).reshape(shape)
def roots_of_1d_by_xy(f, starts, ends, args):
    """Find one root of ``f`` per bracket ``[starts[i], ends[i]]`` with Brent's method.

    ``args[i]`` is passed as the single extra argument for the i-th bracket.
    Only brackets whose solve reports status ``ECONVERGED`` are kept.

    Returns
    -------
    xs : the converged roots.
    ys : the entries of ``args`` belonging to those converged brackets.
    """
    solve_all = bm.jit(bm.vmap(jax_brentq(f_without_jaxarray_return(f))))
    solution = solve_all(starts, ends, (args,))
    converged = jnp.where(solution['status'] == ECONVERGED)[0]
    return solution['root'][converged], args[converged]
def brentq_roots(f, starts, ends, *vmap_args, args=()):
    """Find roots of ``f`` in each bracket ``[starts[i], ends[i]]`` with Brent's method.

    Parameters
    ----------
    f : callable
        Called as ``f(x, *vmap_args_i, *args)``.
    starts, ends : arrays
        Bracket bounds, mapped along axis 0.
    *vmap_args : arrays
        Per-bracket extra arguments, mapped along axis 0 together with the brackets.
    args : tuple
        Extra arguments shared by every bracket (broadcast, not mapped).

    Returns
    -------
    roots : the converged roots (status ``ECONVERGED``).
    vmap_args : tuple with each ``vmap_args`` entry filtered to the converged brackets.
    """
    args = tuple(args)
    all_args = vmap_args + args
    # The solver's third positional parameter is the args tuple, so its in_axes
    # entry is a matching tuple (0 per-bracket, None shared). The tuple is
    # always passed, even when empty, so the number of in_axes entries equals
    # the number of positional arguments, which ``vmap`` requires.
    in_axes = (0, 0, (0,) * len(vmap_args) + (None,) * len(args))
    vmap_f_opt = bm.jit(bm.vmap(jax_brentq(f), in_axes=in_axes))
    res = vmap_f_opt(starts, ends, all_args)
    valid_idx = jnp.where(res['status'] == ECONVERGED)[0]
    roots = res['root'][valid_idx]
    vmap_args = tuple(a[valid_idx] for a in vmap_args)
    return roots, vmap_args
def roots_of_1d_by_x(f, candidates, args=()):
    """Find the roots of the given function by numerical methods.

    ``f(candidates, *args)`` is evaluated once on the whole candidate array.
    Candidates where the sign is exactly zero are returned as roots directly.
    Each pair of neighbouring candidates with a strict sign change brackets a
    root, which is refined with Brent's method. Only solves reporting
    ``ECONVERGED`` are kept.

    Parameters
    ----------
    f : callable
        Called as ``f(x, *args)``; must accept the full candidate array.
    candidates : array
        1-D array of sample points; neighbours form the brackets.
    args : tuple
        Extra arguments passed to ``f``.

    Returns
    -------
    jax.numpy.ndarray
        The exact-zero candidates followed by the refined roots.
    """
    f = f_without_jaxarray_return(f)
    candidates = candidates.value if isinstance(candidates, bm.JaxArray) else candidates
    # Check each argument's own type. ``candidates`` was already unwrapped above,
    # so testing it here would never unwrap a JaxArray argument.
    args = tuple(a.value if isinstance(a, bm.JaxArray) else a for a in args)

    signs = jnp.sign(f(candidates, *args))
    fps = candidates[jnp.where(signs == 0)[0]]

    # A strict sign change between neighbours brackets a root.
    candidate_ids = jnp.where(signs[:-1] * signs[1:] < 0)[0]
    if len(candidate_ids) == 0:
        return fps

    starts = candidates[candidate_ids]
    ends = candidates[candidate_ids + 1]
    f_opt = bm.jit(bm.vmap(jax_brentq(f), in_axes=(0, 0, None)))
    res = f_opt(starts, ends, args)
    valid_idx = jnp.where(res['status'] == ECONVERGED)[0]
    return jnp.concatenate([fps, res['root'][valid_idx]])
def F_vmap_dfxdx(self):
    """Return the cached, jitted, vmapped derivative of ``F_fx`` w.r.t. its first argument.

    The compiled function is built on first access and stored in
    ``self.analyzed_results`` under ``C.F_vmap_dfxdx``.
    """
    key = C.F_vmap_dfxdx
    cache = self.analyzed_results
    if key not in cache:
        dfx_dx = bm.vector_grad(self.F_fx, argnums=0)
        cache[key] = bm.jit(bm.vmap(dfx_dx), device=self.jit_device)
    return cache[key]
def F_vmap_brentq_fy(self):
    """Return the cached, jitted, vmapped Brent root finder applied to ``F_fy``.

    The solver is built on first access and stored in
    ``self.analyzed_results`` under ``C.F_vmap_brentq_fy``.
    """
    key = C.F_vmap_brentq_fy
    cache = self.analyzed_results
    if key not in cache:
        brentq_fy = utils.jax_brentq(self.F_fy)
        cache[key] = bm.jit(bm.vmap(brentq_fy))
    return cache[key]
# Build the two-dimensional CANN and display its connection profile.
cann = CANN2D(length=512, k=0.1)
cann.show_conn()


def _simulate_and_animate(external_input, duration):
    """Run ``cann`` with ``external_input`` for ``duration``, then animate the monitored ``r``.

    Returns the runner so its monitors stay accessible.
    """
    struct_runner = bp.StructRunner(cann,
                                    inputs=['input', external_input, 'iter'],
                                    monitors=['r'],
                                    dyn_vars=cann.vars())
    struct_runner.run(duration)
    bp.visualize.animate_2D(values=struct_runner.mon.r,
                            net_size=(cann.length, cann.length))
    return struct_runner


# encoding: a stimulus at position (0, 0) for a duration of 10., then zero input for 20.
Iext, length = bp.inputs.section_input(
    values=[cann.get_stimulus_by_pos([0., 0.]), 0.],
    durations=[10., 20.],
    return_length=True)
runner = _simulate_and_animate(Iext, length)

# tracking: the stimulus position is ramped from -pi to pi on both coordinates.
length = 20
positions = bp.inputs.ramp_input(-bm.pi, bm.pi, duration=length, t_start=0)
positions = bm.stack([positions, positions]).T
Iext = bm.vmap(cann.get_stimulus_by_pos)(positions)
runner = _simulate_and_animate(Iext, length)