Ejemplo n.º 1
0
def run_integrator(method, show=False):
    """Integrate ``f_lorenz`` with the given integrator and record the trajectory.

    Parameters
    ----------
    method : callable
        Integrator factory. It is called as ``method(f_lorenz, dt=dt)`` and
        the result is jitted to obtain the one-step update function.
    show : bool
        When True, draw the recorded trajectory in a 3D matplotlib figure.

    Returns
    -------
    tuple
        ``(mon_x, mon_y, mon_z)``: the recorded x, y and z values, each
        converted to a flattened NumPy array.

    Notes
    -----
    ``f_lorenz``, ``dt`` and ``duration`` are read from the enclosing module.
    """
    f_integral = bm.jit(method(f_lorenz, dt=dt), auto_infer=False)
    x, y, z = bm.ones(1), bm.ones(1), bm.ones(1)

    def f(t):
        # Advance the three state variables in place by one step.
        x.value, y.value, z.value = f_integral(x, y, z, t)

    f_scan = bm.make_loop(f, dyn_vars=[x, y, z], out_vars=[x, y, z])

    times = np.arange(0, duration, dt)
    mon_x, mon_y, mon_z = f_scan(times)
    mon_x = np.array(mon_x).flatten()
    mon_y = np.array(mon_y).flatten()
    mon_z = np.array(mon_z).flatten()

    if show:
        fig = plt.figure()
        # ``fig.gca(projection=...)`` was removed in Matplotlib 3.6;
        # ``add_subplot`` creates the 3D axes on old and new versions alike.
        ax = fig.add_subplot(111, projection='3d')
        ax.plot(mon_x, mon_y, mon_z)
        # Label each axis once (the x label used to be overwritten twice).
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel('z')
        plt.show()

    return mon_x, mon_y, mon_z
Ejemplo n.º 2
0
 def F_vmap_jacobian(self):
   """Return the jitted, vmapped Jacobian of the stacked ``[F_fx, F_fy]`` function.

   The transformed function is built on first use and stored in
   ``self.analyzed_results`` under ``C.F_vmap_jacobian``; later calls
   return the stored object.
   """
   cache = self.analyzed_results
   key = C.F_vmap_jacobian
   if key in cache:
     return cache[key]

   def stacked(xy, *args):
     # Evaluate both components at the point (xy[0], xy[1]).
     return jnp.array([self.F_fx(xy[0], xy[1], *args),
                       self.F_fy(xy[0], xy[1], *args)])

   cache[key] = bm.jit(bm.vmap(bm.jacobian(stacked)), device=self.jit_device)
   return cache[key]
Ejemplo n.º 3
0
    def find_fps_with_opt_solver(self, candidates, opt_method=None):
        """Refine candidate points into fixed points with a nonlinear solver.

        ``opt_method(self.f_loss, x0)`` is vmapped over ``candidates`` and
        jitted. Only the entries whose result reports ``success`` are kept;
        their positions, loss values and indices are stored as NumPy arrays
        in ``self._fixed_points``, ``self._losses`` and
        ``self._selected_ids``.

        Parameters
        ----------
        candidates : bm.JaxArray or jax.numpy.ndarray
            Two-dimensional array of starting points.
        opt_method : callable, optional
            Called as ``opt_method(f, x0)``. When None, it defaults to
            ``minimize(f, x0, method='BFGS')``.
        """
        assert (bm.ndim(candidates) == 2
                and isinstance(candidates, (bm.JaxArray, jax.numpy.ndarray)))

        if opt_method is None:
            def opt_method(f, x0):
                return minimize(f, x0, method='BFGS')

        if self.verbose:
            print(f"Optimizing to find fixed points:")

        def solve_one(x0):
            return opt_method(self.f_loss, x0)

        batched_solver = bm.jit(bm.vmap(solve_one))
        outcome = batched_solver(bm.as_device_array(candidates))

        # Keep only the runs flagged as successful by the solver.
        converged = jax.numpy.where(outcome.success)[0]
        self._fixed_points = np.asarray(outcome.x[converged])
        self._losses = np.asarray(outcome.fun[converged])
        self._selected_ids = np.asarray(converged)

        if self.verbose:
            print(
                f'    '
                f'Found {len(converged)} fixed points from {len(candidates)} initial points.'
            )
Ejemplo n.º 4
0
def get_sign2(f, *xyz, args=()):
  """Evaluate the sign of ``f`` on the grid spanned by the vectors ``xyz``.

  Parameters
  ----------
  f : callable
      Function called with one value per coordinate vector, followed by
      ``*args``.
  *xyz : array-like or bm.JaxArray
      Coordinate vectors; ``bm.JaxArray`` inputs are unwrapped to ``.value``.
  args : sequence
      Extra arguments passed to ``f`` unchanged (not mapped over).

  Returns
  -------
  Array of signs with shape ``(len(xyz[0]), len(xyz[1]), ...)``.
  """
  args = tuple(args)  # accept any sequence, not only tuples
  # Every coordinate array is flattened to 1-D below, so each one is mapped
  # along axis 0; the extra arguments are broadcast (None).
  in_axes = (0,) * len(xyz) + (None,) * len(args)
  f = bm.jit(bm.vmap(f_without_jaxarray_return(f), in_axes=in_axes))
  xyz = tuple((v.value if isinstance(v, bm.JaxArray) else v) for v in xyz)
  XYZ = jnp.meshgrid(*xyz)
  # meshgrid defaults to 'xy' indexing; swapping the first two axes puts the
  # grid in (len(x), len(y), ...) order before flattening.
  XYZ = tuple(jnp.moveaxis(v, 1, 0).flatten() for v in XYZ)
  # A concrete tuple, not a one-shot generator, is what reshape expects.
  shape = tuple(len(v) for v in xyz)
  return jnp.sign(f(*(XYZ + args))).reshape(shape)
Ejemplo n.º 5
0
def roots_of_1d_by_xy(f, starts, ends, args):
  """Run the vmapped Brent solver on ``f`` and keep the converged roots.

  The solver is called with ``starts``, ``ends`` and ``(args,)``. Entries
  whose status equals ``ECONVERGED`` are selected.

  Returns
  -------
  tuple
      ``(xs, ys)``: the converged roots and the matching entries of ``args``.
  """
  solver = bm.jit(bm.vmap(jax_brentq(f_without_jaxarray_return(f))))
  result = solver(starts, ends, (args,))
  converged = jnp.where(result['status'] == ECONVERGED)[0]
  return result['root'][converged], args[converged]
Ejemplo n.º 6
0
    def __init__(self,
                 target,
                 monitors=None,
                 inputs=(),
                 dyn_vars=None,
                 jit=False,
                 dt=None,
                 numpy_mon_after_run=True,
                 progress_bar=True):
        """Set up the runner and build its looped step function.

        ``target``, ``inputs``, ``monitors``, ``jit``, ``dt``, ``dyn_vars``
        and ``numpy_mon_after_run`` are forwarded unchanged to the parent
        constructor.

        Parameters
        ----------
        progress_bar : bool
            When True, the step function additionally calls
            ``self._pbar.update()`` through ``id_tap`` on every step.
        jit : bool
            When True, the looped step function is wrapped with ``math.jit``.
        """
        # Default: no iterable input array.
        # NOTE(review): set before the parent constructor runs; presumably the
        # parent's input setup may flip it to True -- not visible here.
        self._has_iter_array = False  # default do not have iterable input array

        super(StructRunner,
              self).__init__(target=target,
                             inputs=inputs,
                             monitors=monitors,
                             jit=jit,
                             dt=dt,
                             dyn_vars=dyn_vars,
                             numpy_mon_after_run=numpy_mon_after_run)

        # intrinsic parameters
        self._i = math.Variable(math.asarray([0]))
        self._pbar = None  # progress bar
        self.progress_bar = progress_bar

        # JAX does not support iterator in fori_loop, scan, etc.
        #   https://github.com/google/jax/issues/3567
        # We use Variable i to index the current input data.
        # The index variable is registered as a dynamic variable only when an
        # iterable input array is present; otherwise it is dropped.
        if self._has_iter_array:
            self.dyn_vars.update({'_i': self._i})
        else:
            self._i = None

        # setup step function
        # Both variants apply the inputs, run every step function of the
        # target, and return the monitor output; the first also ticks the
        # progress bar via id_tap.
        if progress_bar:

            def _step(t_and_dt):
                _t, _dt = t_and_dt[0], t_and_dt[1]
                self._input_step(_t=_t, _dt=_dt)
                for step in self.target.steps.values():
                    step(_t=_t, _dt=_dt)
                # id_tap(lambda *args: self._pbar.update(round(self.dt, 4)), ())
                id_tap(lambda *args: self._pbar.update(), ())
                return self._monitor_step(_t=_t, _dt=_dt)
        else:

            def _step(t_and_dt):
                _t, _dt = t_and_dt[0], t_and_dt[1]
                self._input_step(_t=_t, _dt=_dt)
                for step in self.target.steps.values():
                    step(_t=_t, _dt=_dt)
                return self._monitor_step(_t=_t, _dt=_dt)

        # build the update step
        self._step = math.make_loop(_step,
                                    dyn_vars=self.dyn_vars,
                                    has_return=True)
        # NOTE(review): the jit call receives the constructor argument
        # `dyn_vars` (possibly None), while make_loop above receives
        # `self.dyn_vars` -- confirm this difference is intended.
        if jit: self._step = math.jit(self._step, dyn_vars=dyn_vars)
Ejemplo n.º 7
0
def brentq_roots(f, starts, ends, *vmap_args, args=()):
  """Solve for roots of ``f`` with the vmapped Brent solver.

  ``starts``, ``ends`` and every entry of ``vmap_args`` are mapped along
  axis 0; the entries of ``args`` are passed with ``in_axes`` of None.

  Returns
  -------
  tuple
      ``(roots, vmap_args)`` restricted to the entries whose status equals
      ``ECONVERGED``.
  """
  extra_axes = (0,) * len(vmap_args) + (None,) * len(args)
  solver = bm.jit(bm.vmap(jax_brentq(f), in_axes=(0, 0, extra_axes)))
  extras = vmap_args + args
  # The extra-argument tuple is only passed when it is non-empty.
  result = solver(starts, ends, extras) if len(extras) else solver(starts, ends)
  converged = jnp.where(result['status'] == ECONVERGED)[0]
  kept_args = tuple(a[converged] for a in vmap_args)
  return result['root'][converged], kept_args
Ejemplo n.º 8
0
    def __init__(self,
                 f_cell,
                 f_type='continuous',
                 f_loss_batch=None,
                 verbose=True):
        """Build the loss and Jacobian functions used to search for fixed points.

        Parameters
        ----------
        f_cell : callable
            The function whose fixed points are sought.
        f_type : str
            ``'continuous'`` (``f_cell`` is a derivative function; the loss is
            ``mean(f_cell(h)**2)``) or ``'discrete'`` (``f_cell`` is an update
            function; the loss is ``mean((h - f_cell(h))**2)``).
        f_loss_batch : callable, optional
            Custom batched loss. When None, a batched version of the loss
            matching ``f_type`` is built.
        verbose : bool
            Stored on ``self.verbose``.

        Raises
        ------
        AnalyzerError
            If ``f_type`` is neither ``'discrete'`` nor ``'continuous'``.
        """
        self.verbose = verbose
        if f_type not in ['discrete', 'continuous']:
            raise AnalyzerError(
                f'Only support "continuous" (continuous derivative function) or '
                f'"discrete" (discrete update function), not {f_type}.')

        # functions
        self.f_cell = f_cell
        # The single-point loss always follows ``f_type``. Previously a custom
        # ``f_loss_batch`` forced the continuous form even for discrete maps.
        if f_type == 'discrete':
            self.f_loss = bm.jit(lambda h: bm.mean((h - f_cell(h))**2))
            default_loss_batch = lambda h: bm.mean(
                (h - bm.vmap(f_cell, auto_infer=False)(h))**2, axis=1)
        else:
            self.f_loss = bm.jit(lambda h: bm.mean(f_cell(h)**2))
            default_loss_batch = lambda h: bm.mean(
                (bm.vmap(f_cell, auto_infer=False)(h))**2, axis=1)

        if f_loss_batch is None:
            self.f_loss_batch = bm.jit(default_loss_batch)
        else:
            self.f_loss_batch = f_loss_batch
        self.f_jacob_batch = bm.jit(bm.vmap(bm.jacobian(f_cell)))

        # essential variables
        self._losses = None
        self._fixed_points = None
        self._selected_ids = None
        self.opt_losses = None
Ejemplo n.º 9
0
def roots_of_1d_by_x(f, candidates, args=()):
  """Find the roots of ``f`` along ``candidates`` by numerical methods.

  ``f`` is evaluated on all candidates at once. Candidates where the value
  is exactly zero are taken as roots directly. Each pair of neighbouring
  candidates with a strict sign change is used as a bracket for the vmapped
  Brent solver, and the roots whose status equals ``ECONVERGED`` are kept.

  Parameters
  ----------
  f : callable
      Function called as ``f(x, *args)``.
  candidates : array or bm.JaxArray
      One-dimensional grid of x values.
  args : sequence
      Extra arguments for ``f``; ``bm.JaxArray`` entries are unwrapped to
      ``.value``.

  Returns
  -------
  Array of the roots found (exact zeros first, then solver roots).
  """
  f = f_without_jaxarray_return(f)
  candidates = candidates.value if isinstance(candidates, bm.JaxArray) else candidates
  # Test each argument itself. The old check looked at ``candidates``, which
  # is already unwrapped here, so no argument was ever converted.
  args = tuple(a.value if isinstance(a, bm.JaxArray) else a for a in args)
  vals = f(candidates, *args)
  signs = jnp.sign(vals)
  zero_sign_idx = jnp.where(signs == 0)[0]
  fps = candidates[zero_sign_idx]
  # Neighbouring candidates with opposite signs bracket a root.
  candidate_ids = jnp.where(signs[:-1] * signs[1:] < 0)[0]
  if len(candidate_ids) == 0:
    return fps
  starts = candidates[candidate_ids]
  ends = candidates[candidate_ids + 1]
  f_opt = bm.jit(bm.vmap(jax_brentq(f), in_axes=(0, 0, None)))
  res = f_opt(starts, ends, args)
  valid_idx = jnp.where(res['status'] == ECONVERGED)[0]
  fps2 = res['root'][valid_idx]
  return jnp.concatenate([fps, fps2])
Ejemplo n.º 10
0
    def __init__(self,
                 target,
                 inputs=(),
                 monitors=None,
                 dyn_vars=None,
                 jit=False,
                 dt=None,
                 numpy_mon_after_run=True):
        """Initialise the runner and build its update function.

        All arguments are forwarded unchanged to the parent constructor.
        The update function calls every step function of ``self.target``
        with ``_t`` and ``_dt`` and returns their results as a list; when
        ``jit`` is true it is wrapped with ``math.jit`` using
        ``self.dyn_vars``.
        """
        parent_kwargs = dict(target=target,
                             inputs=inputs,
                             monitors=monitors,
                             jit=jit,
                             dt=dt,
                             dyn_vars=dyn_vars,
                             numpy_mon_after_run=numpy_mon_after_run)
        super(ReportRunner, self).__init__(**parent_kwargs)

        # Build the update function
        run_all_steps = lambda _t, _dt: [
            _step(_t=_t, _dt=_dt) for _step in self.target.steps.values()
        ]
        if jit:
            run_all_steps = math.jit(run_all_steps, dyn_vars=self.dyn_vars)
        self._update_step = run_all_steps
Ejemplo n.º 11
0
# %% [markdown]
# ## Train the recurrent network on the decision-making task

# %%
# Instantiate the recurrent network. The input/output sizes, batch size and
# time step come from names defined earlier (input_size, output_size,
# batch_size, env).
hidden_size = 64
net = RNN(num_input=input_size,
          num_hidden=hidden_size,
          num_output=output_size,
          num_batch=batch_size,
          dt=env.dt)

# %%
# Jitted prediction method; all variables of the network are passed as
# dyn_vars.
predict = bm.jit(net.predict, dyn_vars=net.vars())

# Adam optimizer over the unique trainable variables of the network
# (learning rate 0.001).
opt = bm.optimizers.Adam(lr=0.001, train_vars=net.train_vars().unique())

# Gradient function of net.loss with respect to the unique trainable
# variables. return_value=True and has_aux=True are set -- presumably the
# call then also returns the loss value and auxiliary outputs (confirm
# against the bm.grad documentation).
grad_f = bm.grad(net.loss,
                 dyn_vars=net.vars(),
                 grad_vars=net.train_vars().unique(),
                 return_value=True,
                 has_aux=True)


# training function
@bm.jit
@bm.function(nodes=(net, opt))
Ejemplo n.º 12
0
 def F_vmap_dfxdx(self):
   """Return the jitted, vmapped ``bm.vector_grad`` of ``self.F_fx`` (argnums=0).

   The function is built on first use and stored in
   ``self.analyzed_results`` under ``C.F_vmap_dfxdx``.
   """
   key = C.F_vmap_dfxdx
   if key in self.analyzed_results:
     return self.analyzed_results[key]
   dfxdx = bm.vector_grad(self.F_fx, argnums=0)
   self.analyzed_results[key] = bm.jit(bm.vmap(dfxdx), device=self.jit_device)
   return self.analyzed_results[key]
Ejemplo n.º 13
0
 def F_vmap_brentq_fy(self):
     """Return the jitted, vmapped Brent solver built from ``self.F_fy``.

     The solver is built on first use and stored in
     ``self.analyzed_results`` under ``C.F_vmap_brentq_fy``.
     """
     key = C.F_vmap_brentq_fy
     if key in self.analyzed_results:
         return self.analyzed_results[key]
     solver = utils.jax_brentq(self.F_fy)
     self.analyzed_results[key] = bm.jit(bm.vmap(solver))
     return self.analyzed_results[key]
Ejemplo n.º 14
0
  def test_jacfwd_and_aux_nested(self):
    """Differentiate functions that call ``_jacfwd(..., has_aux=True)`` internally.

    Each nested function is compared with its closed-form reference under
    three wrappings (plain ``_jacfwd``, jitted ``_jacfwd``, jitted
    ``_jacfwd`` of the jitted function), for both a Python float and a
    ``bm.asarray`` input.
    """

    def cube_with_aux(x):
      return x ** 3, [x ** 3]

    def nested_plain(x):
      _, aux = _jacfwd(cube_with_aux, has_aux=True)(x)
      return aux[0]

    def nested_scaled(x):
      _, aux = _jacfwd(cube_with_aux, has_aux=True)(x)
      return aux[0] * bm.sin(x)

    pairs = (
      (nested_plain, lambda x: x ** 3),
      (nested_scaled, lambda x: x ** 3 * bm.sin(x)),
    )
    wrappers = (
      lambda g: _jacfwd(g),
      lambda g: bm.jit(_jacfwd(g)),
      lambda g: bm.jit(_jacfwd(bm.jit(g))),
    )
    # Fresh input per call, as in the hand-written assertions.
    make_inputs = (lambda: 4., lambda: bm.asarray(4.))

    for nested, reference in pairs:
      for make_input in make_inputs:
        for wrap in wrappers:
          self.assertEqual(wrap(nested)(make_input()),
                           _jacfwd(reference)(make_input()))