def run_integrator(method, show=False):
    """Integrate the Lorenz system with ``method`` and return the trajectories.

    Uses the module-level ``f_lorenz``, ``dt`` and ``duration``. The state
    starts at ``x = y = z = 1`` and is advanced once per time point in
    ``np.arange(0, duration, dt)``.

    Parameters
    ----------
    method : callable
      Integrator factory, called as ``method(f_lorenz, dt=dt)``.
    show : bool
      If True, draw the trajectory in a 3-D plot.

    Returns
    -------
    tuple of np.ndarray
      Flattened monitored ``x``, ``y`` and ``z`` trajectories.
    """
    f_integral = bm.jit(method(f_lorenz, dt=dt), auto_infer=False)
    x, y, z = bm.ones(1), bm.ones(1), bm.ones(1)

    def f(t):
        # One integration step; the variables are updated in place so that
        # make_loop can carry them as dynamic state.
        x.value, y.value, z.value = f_integral(x, y, z, t)

    f_scan = bm.make_loop(f, dyn_vars=[x, y, z], out_vars=[x, y, z])
    times = np.arange(0, duration, dt)
    mon_x, mon_y, mon_z = f_scan(times)
    mon_x = np.array(mon_x).flatten()
    mon_y = np.array(mon_y).flatten()
    mon_z = np.array(mon_z).flatten()

    if show:
        fig = plt.figure()
        # `fig.gca(projection=...)` is deprecated/removed in recent matplotlib.
        ax = fig.add_subplot(projection='3d')
        ax.plot(mon_x, mon_y, mon_z)
        # Each axis gets its own label (previously all three went to x).
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel('z')
        plt.show()
    return mon_x, mon_y, mon_z
def F_vmap_jacobian(self):
    """Return the cached, jitted, batched Jacobian of the 2-D vector field.

    The vector field maps a point ``xy`` (plus extra args) to
    ``[F_fx(x, y, *args), F_fy(x, y, *args)]``. The Jacobian is vmapped over
    the leading batch axis and compiled on ``self.jit_device``. It is built
    once and stored in ``self.analyzed_results``.
    """
    cache_key = C.F_vmap_jacobian
    if cache_key in self.analyzed_results:
        return self.analyzed_results[cache_key]

    def stacked_field(xy, *args):
        return jnp.array([self.F_fx(xy[0], xy[1], *args),
                          self.F_fy(xy[0], xy[1], *args)])

    batched_jacobian = bm.jit(bm.vmap(bm.jacobian(stacked_field)), device=self.jit_device)
    self.analyzed_results[cache_key] = batched_jacobian
    return self.analyzed_results[cache_key]
def find_fps_with_opt_solver(self, candidates, opt_method=None):
    """Optimize fixed points with nonlinear optimization solvers.

    Every row of ``candidates`` is used as an initial point for minimizing
    ``self.f_loss``. Only the runs whose result reports ``success`` are kept.

    Parameters
    ----------
    candidates : bm.JaxArray or jax.numpy.ndarray
      Two-dimensional batch of initial points, one per row.
    opt_method : callable, optional
      Called as ``opt_method(f, x0)``. It must return an object exposing
      ``success``, ``x`` and ``fun``. Defaults to ``minimize(f, x0, method='BFGS')``.

    Side effects
    ------------
    Sets ``self._fixed_points``, ``self._losses`` and ``self._selected_ids``
    to NumPy arrays restricted to the successful runs.
    """
    assert bm.ndim(candidates) == 2 and isinstance(candidates, (bm.JaxArray, jax.numpy.ndarray))
    solver = opt_method if opt_method is not None else (lambda f, x0: minimize(f, x0, method='BFGS'))

    if self.verbose:
        print("Optimizing to find fixed points:")

    # One independent optimization per initial point, batched with vmap.
    batched_solve = bm.jit(bm.vmap(lambda x0: solver(self.f_loss, x0)))
    result = batched_solve(bm.as_device_array(candidates))

    converged = jax.numpy.where(result.success)[0]
    self._fixed_points = np.asarray(result.x[converged])
    self._losses = np.asarray(result.fun[converged])
    self._selected_ids = np.asarray(converged)

    if self.verbose:
        print(f' Found {len(converged)} fixed points from {len(candidates)} initial points.')
def get_sign2(f, *xyz, args=()):
    """Evaluate ``sign(f)`` on the Cartesian grid spanned by the 1-D arrays ``xyz``.

    Parameters
    ----------
    f : callable
      Called as ``f(*grid_points, *args)``. It is vmapped over the flattened grid.
    *xyz : arrays
      One 1-D coordinate array per variable (JaxArrays are unwrapped).
    args : tuple
      Extra arguments passed unbatched to every evaluation of ``f``.

    Returns
    -------
    array
      Signs of ``f`` with shape ``(len(xyz[0]), len(xyz[1]), ...)``, where
      axis ``i`` indexes ``xyz[i]``.
    """
    args = tuple(args)
    # Every flattened grid array is 1-D, so each one is mapped along axis 0.
    # Previously this used ``range(len(xyz))``, which asked vmap to map the
    # second grid along a non-existent axis 1.
    in_axes = (0,) * len(xyz) + (None,) * len(args)
    f = bm.jit(bm.vmap(f_without_jaxarray_return(f), in_axes=in_axes))
    xyz = tuple((v.value if isinstance(v, bm.JaxArray) else v) for v in xyz)
    # 'ij' indexing gives axis i <-> xyz[i] directly. This is equivalent to the
    # previous 'xy' meshgrid + moveaxis(1, 0), and also works for a single axis.
    XYZ = tuple(v.flatten() for v in jnp.meshgrid(*xyz, indexing='ij'))
    # A concrete tuple, not a one-shot generator, so reshape sees every dimension.
    shape = tuple(len(v) for v in xyz)
    return jnp.sign(f(*(XYZ + args))).reshape(shape)
def roots_of_1d_by_xy(f, starts, ends, args):
    """Solve one bracketed root problem per ``(starts[i], ends[i], args[i])``.

    The Brent solver is vmapped over all three inputs, so ``args`` is batched
    together with the brackets.

    Returns
    -------
    tuple
      ``(roots, args_subset)`` for the entries whose status equals ``ECONVERGED``.
    """
    batched_brentq = bm.jit(bm.vmap(jax_brentq(f_without_jaxarray_return(f))))
    outcome = batched_brentq(starts, ends, (args,))
    converged = jnp.where(outcome['status'] == ECONVERGED)[0]
    return outcome['root'][converged], args[converged]
def __init__(self, target, monitors=None, inputs=(), dyn_vars=None, jit=False,
             dt=None, numpy_mon_after_run=True, progress_bar=True):
    """Build a runner that steps ``target`` inside a compiled loop.

    Parameters
    ----------
    target, monitors, inputs, dyn_vars, jit, dt, numpy_mon_after_run
      Forwarded to the parent runner's ``__init__``.
    progress_bar : bool
      If True, every step reports progress to ``self._pbar`` through ``id_tap``.
    """
    # Default: no iterable input array. The parent __init__ may change this
    # while processing ``inputs``, so it must be set before the super() call.
    self._has_iter_array = False
    super(StructRunner, self).__init__(target=target,
                                       inputs=inputs,
                                       monitors=monitors,
                                       jit=jit,
                                       dt=dt,
                                       dyn_vars=dyn_vars,
                                       numpy_mon_after_run=numpy_mon_after_run)

    # intrinsic parameters
    self._i = math.Variable(math.asarray([0]))
    self._pbar = None  # progress bar
    self.progress_bar = progress_bar

    # JAX does not support iterator in fori_loop, scan, etc.
    #   https://github.com/google/jax/issues/3567
    # We use Variable i to index the current input data.
    if self._has_iter_array:
        self.dyn_vars.update({'_i': self._i})
    else:
        self._i = None

    # setup step function
    def _step(t_and_dt):
        _t, _dt = t_and_dt[0], t_and_dt[1]
        self._input_step(_t=_t, _dt=_dt)
        for step in self.target.steps.values():
            step(_t=_t, _dt=_dt)
        if progress_bar:
            # Host callback so the progress bar advances once per simulated step.
            id_tap(lambda *args: self._pbar.update(), ())
        return self._monitor_step(_t=_t, _dt=_dt)

    # build the update step
    self._step = math.make_loop(_step, dyn_vars=self.dyn_vars, has_return=True)
    if jit:
        # Use the collected ``self.dyn_vars`` (which may include '_i'), the same
        # set the loop was built with. The raw ``dyn_vars`` argument defaults to
        # None, in which case the jitted loop would not track any state.
        self._step = math.jit(self._step, dyn_vars=self.dyn_vars)
def brentq_roots(f, starts, ends, *vmap_args, args=()):
    """Find one root of ``f`` per bracket ``[starts[i], ends[i]]`` with a batched Brent solver.

    Parameters
    ----------
    f : callable
      Function whose roots are sought; wrapped by ``jax_brentq``.
    starts, ends : array
      Left/right bracket endpoints, batched along axis 0.
    *vmap_args : arrays
      Extra arguments for ``f``, batched along axis 0 together with the brackets.
    args : tuple
      Extra arguments for ``f`` shared by every bracket (not batched).

    Returns
    -------
    roots : array
      Roots of the brackets whose solver status equals ``ECONVERGED``.
    vmap_args : tuple
      The batched extra arguments restricted to those same converged entries.
    """
    args = tuple(args)
    all_args = vmap_args + args
    solver = jax_brentq(f)
    if all_args:
        # One in_axes entry per positional argument: the two brackets plus the
        # args tuple (0 for batched entries, None for shared ones).
        in_axes = (0, 0, (0,) * len(vmap_args) + (None,) * len(args))
        res = bm.jit(bm.vmap(solver, in_axes=in_axes))(starts, ends, all_args)
    else:
        # Only two positional arguments are passed here. in_axes must have two
        # entries too, otherwise vmap rejects the length mismatch.
        res = bm.jit(bm.vmap(solver, in_axes=(0, 0)))(starts, ends)
    valid_idx = jnp.where(res['status'] == ECONVERGED)[0]
    roots = res['root'][valid_idx]
    vmap_args = tuple(a[valid_idx] for a in vmap_args)
    return roots, vmap_args
def __init__(self, f_cell, f_type='continuous', f_loss_batch=None, verbose=True):
    """Set up the loss functions used to search for fixed points of ``f_cell``.

    Parameters
    ----------
    f_cell : callable
      Continuous derivative function (``f_type='continuous'``) or discrete
      update function (``f_type='discrete'``).
    f_type : str
      ``'continuous'`` or ``'discrete'``; anything else raises ``AnalyzerError``.
    f_loss_batch : callable, optional
      Custom batched loss. If None, a mean-squared residual over axis 1 is
      built from ``f_cell`` with ``bm.vmap``.
    verbose : bool
      Stored on ``self.verbose``.

    Raises
    ------
    AnalyzerError
      If ``f_type`` is not supported.
    """
    self.verbose = verbose
    if f_type not in ['discrete', 'continuous']:
        raise AnalyzerError(f'Only support "continuous" (continuous derivative function) or '
                            f'"discrete" (discrete update function), not {f_type}.')

    # functions
    self.f_cell = f_cell
    # The single-sample loss always follows f_type. Previously, supplying
    # f_loss_batch silently switched a discrete system to the continuous loss.
    if f_type == 'discrete':
        # A fixed point of h <- f(h) satisfies h - f(h) = 0.
        self.f_loss = bm.jit(lambda h: bm.mean((h - f_cell(h)) ** 2))
        default_loss_batch = lambda h: bm.mean(
            (h - bm.vmap(f_cell, auto_infer=False)(h)) ** 2, axis=1)
    else:
        # A fixed point of dh/dt = f(h) satisfies f(h) = 0.
        self.f_loss = bm.jit(lambda h: bm.mean(f_cell(h) ** 2))
        default_loss_batch = lambda h: bm.mean(
            (bm.vmap(f_cell, auto_infer=False)(h)) ** 2, axis=1)

    if f_loss_batch is None:
        self.f_loss_batch = bm.jit(default_loss_batch)
    else:
        # User-provided batched loss is stored as given (not jitted here).
        self.f_loss_batch = f_loss_batch
    self.f_jacob_batch = bm.jit(bm.vmap(bm.jacobian(f_cell)))

    # essential variables
    self._losses = None
    self._fixed_points = None
    self._selected_ids = None
    self.opt_losses = None
def roots_of_1d_by_x(f, candidates, args=()):
    """Find the roots of the given function by numerical methods.

    ``f`` is evaluated on the ordered 1-D grid ``candidates``. Grid points
    where the sign is exactly zero are returned directly. Every adjacent pair
    whose signs differ is refined with a vectorized Brent solver.

    Parameters
    ----------
    f : callable
      Called as ``f(x, *args)``.
    candidates : array
      1-D grid of points (a JaxArray is unwrapped).
    args : tuple
      Extra arguments shared by every evaluation (JaxArrays are unwrapped).

    Returns
    -------
    array
      Exact grid zeros followed by the converged Brent roots.
    """
    f = f_without_jaxarray_return(f)
    candidates = candidates.value if isinstance(candidates, bm.JaxArray) else candidates
    # Check each argument's own type. The old check tested ``candidates``,
    # which is already unwrapped at this point, so JaxArray args were never unwrapped.
    args = tuple(a.value if isinstance(a, bm.JaxArray) else a for a in args)
    signs = jnp.sign(f(candidates, *args))
    fps = candidates[jnp.where(signs == 0)[0]]

    # Indices i such that f changes sign strictly between candidates[i] and candidates[i + 1].
    candidate_ids = jnp.where(signs[:-1] * signs[1:] < 0)[0]
    if len(candidate_ids) == 0:
        return fps
    starts = candidates[candidate_ids]
    ends = candidates[candidate_ids + 1]
    f_opt = bm.jit(bm.vmap(jax_brentq(f), in_axes=(0, 0, None)))
    res = f_opt(starts, ends, args)
    valid_idx = jnp.where(res['status'] == ECONVERGED)[0]
    fps2 = res['root'][valid_idx]
    return jnp.concatenate([fps, fps2])
def __init__(self, target, inputs=(), monitors=None, dyn_vars=None, jit=False, dt=None,
             numpy_mon_after_run=True):
    """Build a runner whose update function calls every step of ``target``.

    All arguments are forwarded to the parent runner's ``__init__``. When
    ``jit`` is True, the update function is compiled with ``self.dyn_vars``.
    """
    super(ReportRunner, self).__init__(target=target,
                                       inputs=inputs,
                                       monitors=monitors,
                                       jit=jit,
                                       dt=dt,
                                       dyn_vars=dyn_vars,
                                       numpy_mon_after_run=numpy_mon_after_run)

    # Build the update function: call each registered step with the current time and dt.
    def run_all_steps(_t, _dt):
        return [step_fn(_t=_t, _dt=_dt) for step_fn in self.target.steps.values()]

    self._update_step = math.jit(run_all_steps, dyn_vars=self.dyn_vars) if jit else run_all_steps
# %% [markdown] # ## Train the recurrent network on the decision-making task # %% # Instantiate the network and print information hidden_size = 64 net = RNN(num_input=input_size, num_hidden=hidden_size, num_output=output_size, num_batch=batch_size, dt=env.dt) # %% # prediction method predict = bm.jit(net.predict, dyn_vars=net.vars()) # Adam optimizer opt = bm.optimizers.Adam(lr=0.001, train_vars=net.train_vars().unique()) # gradient function grad_f = bm.grad(net.loss, dyn_vars=net.vars(), grad_vars=net.train_vars().unique(), return_value=True, has_aux=True) # training function @bm.jit @bm.function(nodes=(net, opt))
def F_vmap_dfxdx(self):
    """Return the cached, jitted, batched derivative of ``F_fx`` w.r.t. its first argument.

    It is built with ``bm.vector_grad(self.F_fx, argnums=0)``, vmapped,
    compiled on ``self.jit_device``, and stored once in ``self.analyzed_results``.
    """
    cache_key = C.F_vmap_dfxdx
    if cache_key not in self.analyzed_results:
        dfx_dx = bm.vector_grad(self.F_fx, argnums=0)
        self.analyzed_results[cache_key] = bm.jit(bm.vmap(dfx_dx), device=self.jit_device)
    return self.analyzed_results[cache_key]
def F_vmap_brentq_fy(self):
    """Return the cached, jitted, batched Brent root finder for ``self.F_fy``.

    Built once and stored in ``self.analyzed_results``. Like the other cached
    ``F_vmap_*`` builders, it is compiled on ``self.jit_device``.
    """
    if C.F_vmap_brentq_fy not in self.analyzed_results:
        # Pass device=self.jit_device like the sibling builders do. It was
        # previously omitted here, so this solver ignored the configured device.
        f_opt = bm.jit(bm.vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device)
        self.analyzed_results[C.F_vmap_brentq_fy] = f_opt
    return self.analyzed_results[C.F_vmap_brentq_fy]
def test_jacfwd_and_aux_nested(self):
    """The aux output of an inner ``_jacfwd(has_aux=True)`` must differentiate like the plain expression.

    Each case is checked with a Python float and a ``bm.asarray`` input, both
    plain and under ``bm.jit``, including jitting the inner function.
    """
    def assert_matches(fn, reference):
        for point in (4., bm.asarray(4.)):
            expected = _jacfwd(reference)(point)
            self.assertEqual(_jacfwd(fn)(point), expected)
            self.assertEqual(bm.jit(_jacfwd(fn))(point), expected)
            self.assertEqual(bm.jit(_jacfwd(bm.jit(fn)))(point), expected)

    def cube_via_aux(x):
        _, aux = _jacfwd(lambda x: (x ** 3, [x ** 3]), has_aux=True)(x)
        return aux[0]

    assert_matches(cube_via_aux, lambda x: x ** 3)

    def cube_sin_via_aux(x):
        _, aux = _jacfwd(lambda x: (x ** 3, [x ** 3]), has_aux=True)(x)
        return aux[0] * bm.sin(x)

    assert_matches(cube_sin_via_aux, lambda x: x ** 3 * bm.sin(x))