Ejemplo n.º 1
0
    def test_grad_closure_with_vmap(self):
        """Check vmap(value_and_grad(...)) over a jitted function closing over its input.

        Regression test for https://github.com/google/jax/issues/2718.
        The batched value/gradient at one index is compared against a direct
        evaluation and a forward finite-difference estimate of the gradient.
        """
        probe_index = 11
        fd_step = 1e-5

        @jax.jit
        def experiment(x):
            # The ODE right-hand side closes over ``x``, the differentiated input.
            def dynamics(y, t):
                return -x * y

            trajectory = odeint(dynamics, 1., np.arange(0, 10, 0.1))
            return trajectory[-1]

        inputs = np.arange(0., 1., 0.01)
        batched_value_and_grad = jax.vmap(jax.value_and_grad(experiment))
        values, grads = batched_value_and_grad(inputs)  # doesn't crash
        actual = (values[probe_index], grads[probe_index])

        reference_value = experiment(inputs[probe_index])
        reference_grad = (experiment(inputs[probe_index] + fd_step) -
                          reference_value) / fd_step

        self.assertAllClose(actual, (reference_value, reference_grad),
                            check_dtypes=False,
                            atol=1e-2,
                            rtol=1e-2)
Ejemplo n.º 2
0
def train(network_def, target_params, optimizer, states, actions, next_states,
          rewards, terminals, cumulative_gamma):
    """Run one training step and return the updated optimizer and the loss.

    The target values are computed once with ``target_params`` (outside the
    differentiated function); the loss is the mean of ``huber_loss`` between
    those targets and the online Q-values selected by ``actions``.

    Returns:
      A pair ``(optimizer, loss)`` where ``optimizer`` is the result of
      ``optimizer.apply_gradient`` on the loss gradient.
    """

    def q_target(state):
        return network_def.apply(target_params, state)

    def loss_fn(params, target):
        def q_online(state):
            return network_def.apply(params, state)

        # Evaluate the online network per state, then drop singleton axes.
        batched_q = jnp.squeeze(jax.vmap(q_online)(states).q_values)
        # For each example pick the Q-value of the action that was taken.
        chosen_q = jax.vmap(lambda row, idx: row[idx])(batched_q, actions)
        return jnp.mean(jax.vmap(huber_loss)(target, chosen_q))

    target_values = target_q(q_target, next_states, rewards, terminals,
                             cumulative_gamma)
    loss, grad = jax.value_and_grad(loss_fn)(optimizer.target, target_values)
    return optimizer.apply_gradient(grad), loss
Ejemplo n.º 3
0
def _wrap_f(vs, names, f, jit, _convert):
    """Wrap ``f`` into a function of a latent vector returning value and gradient.

    Args:
        vs: Variable container; a copy of it receives the latent vector.
        names: Names forwarded to ``set_latent_vector``.
        f: Objective taking the variable container.
        jit: If truthy, the value-and-gradient function is passed through ``B.jit``.
        _convert: Callable applied to ``(value, gradient)`` before they are returned.

    Returns:
        A pair ``(f_evals, f_wrapped)``: the list that accumulates every
        successfully computed objective value, and the wrapped function.
    """
    # Differentiable assignments overwrite the variables, so work on a copy.
    working_vs = vs.copy()

    # Record of objective values, appended to on every successful evaluation.
    f_evals = []

    def f_vectorised(x):
        working_vs.set_latent_vector(x, *names, differentiable=True)
        return f(working_vs)

    objective_and_grad = value_and_grad(f_vectorised)
    objective_and_grad = B.jit(objective_and_grad) if jit else objective_and_grad

    def f_wrapped(x):
        x = B.cast(vs.dtype, x)

        # Any failure while evaluating is delegated to `exception`.
        try:
            raw_value, raw_grad = objective_and_grad(x)
        except Exception as e:
            return exception(x, e)

        value, gradient = _convert(raw_value, raw_grad)

        # The gradient may not have the right memory layout, which sometimes
        # cannot be adjusted; a fresh copy can always be freely manipulated.
        gradient = np.array(gradient)

        f_evals.append(value)
        return value, gradient

    return f_evals, f_wrapped
Ejemplo n.º 4
0
def train_step(model: models.NerfModel,
               rng_key: Callable[[int], jnp.ndarray],
               state,
               batch: Dict[str, Any],
               scalar_params: ScalarParams,
               use_elastic_loss: bool = False,
               elastic_reduce_method: str = 'median',
               use_background_loss: bool = False):
  """One optimization step.

  Args:
    model: the model module to evaluate.
    rng_key: The random number generator key; it is split four ways with
      `random.split` (next key, fine, coarse, regularization).
    state: model_utils.TrainState, state of model and optimizer.
    batch: dict. A mini-batch of data for training. This function reads
      `batch['rgb']` and, when `use_background_loss` is set,
      `batch['background_points']`.
    scalar_params: scalar-valued parameters.
    use_elastic_loss: if True use the elastic regularization loss. It is only
      forwarded to the loss of the 'coarse' output.
    elastic_reduce_method: which method to use to reduce the samples for the
      elastic loss. 'median' selects the median depth point sample while
      'weight' computes a weighted sum using the density weights.
    use_background_loss: if True use the background regularization loss.

  Returns:
    new_state: model_utils.TrainState, new training state.
    stats: dict of statistics, with a per-level dict under 'fine' / 'coarse'
      (when the model returns that level) and 'background_loss' when the
      background loss is enabled; averaged over the 'batch' axis.
    rng_key: the new random key produced by the split.
  """
  rng_key, fine_key, coarse_key, reg_key = random.split(rng_key, 4)

  # Computes the scalar loss and a stats dict for one model output level.
  # pylint: disable=unused-argument
  def _compute_loss_and_stats(params, model_out, use_elastic_loss=False):
    # Mean squared error against the first three channels of the batch RGB.
    rgb_loss = ((model_out['rgb'] - batch['rgb'][..., :3])**2).mean()
    stats = {
        'loss/rgb': rgb_loss,
    }
    loss = rgb_loss
    if use_elastic_loss:
      v_elastic_fn = jax.jit(vmap(vmap(compute_elastic_loss)))
      # The weights are treated as constants: no gradient flows through them.
      weights = lax.stop_gradient(model_out['weights'])
      jacobian = model_out['warp_jacobian']
      # Pick the median point Jacobian.
      if elastic_reduce_method == 'median':
        depth_indices = model_utils.compute_depth_index(weights)
        jacobian = jnp.take_along_axis(
            # Unsqueeze axes: sample axis, Jacobian row, Jacobian col.
            jacobian, depth_indices[..., None, None, None], axis=-3)
      # Compute loss using Jacobian.
      elastic_loss, elastic_residual = v_elastic_fn(jacobian)
      # Multiply weight if weighting by density.
      if elastic_reduce_method == 'weight':
        elastic_loss = weights * elastic_loss
      elastic_loss = elastic_loss.sum(axis=-1).mean()
      stats['loss/elastic'] = elastic_loss
      stats['residual/elastic'] = jnp.mean(elastic_residual)
      loss += scalar_params.elastic_loss_weight * elastic_loss

    # Jacobian diagnostics are logged only; they do not contribute to `loss`.
    if 'warp_jacobian' in model_out:
      jacobian = model_out['warp_jacobian']
      jacobian_det = jnp.linalg.det(jacobian)
      jacobian_div = utils.jacobian_to_div(jacobian)
      jacobian_curl = utils.jacobian_to_curl(jacobian)
      stats['metric/jacobian_det'] = jnp.mean(jacobian_det)
      stats['metric/jacobian_div'] = jnp.mean(jacobian_div)
      stats['metric/jacobian_curl'] = jnp.mean(
          jnp.linalg.norm(jacobian_curl, axis=-1))

    stats['loss/total'] = loss
    # PSNR is derived from the RGB loss alone, not from the total loss.
    stats['metric/psnr'] = utils.compute_psnr(rgb_loss)
    return loss, stats

  # Differentiated function: returns (total loss, stats) for `has_aux=True`.
  def _loss_fn(params):
    ret = model.apply({'params': params['model']},
                      batch,
                      warp_alpha=state.warp_alpha,
                      rngs={
                          'fine': fine_key,
                          'coarse': coarse_key
                      })

    losses = {}
    stats = {}
    # The fine level never uses the elastic loss (default argument False).
    if 'fine' in ret:
      losses['fine'], stats['fine'] = _compute_loss_and_stats(
          params, ret['fine'])
    if 'coarse' in ret:
      losses['coarse'], stats['coarse'] = _compute_loss_and_stats(
          params, ret['coarse'], use_elastic_loss=use_elastic_loss)

    if use_background_loss:
      background_loss = compute_background_loss(
          model,
          state=state,
          params=params['model'],
          key=reg_key,
          points=batch['background_points'],
          noise_std=scalar_params.background_noise_std)
      background_loss = background_loss.mean()
      # The weighted value enters the objective; the unweighted one is logged.
      losses['background'] = (
          scalar_params.background_loss_weight * background_loss)
      stats['background_loss'] = background_loss

    return sum(losses.values()), stats

  optimizer = state.optimizer
  grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
  (_, stats), grad = grad_fn(optimizer.target)
  # Average gradients and stats over the named 'batch' axis; this requires the
  # caller to run this function under a transform that defines that axis.
  grad = jax.lax.pmean(grad, axis_name='batch')
  stats = jax.lax.pmean(stats, axis_name='batch')
  new_optimizer = optimizer.apply_gradient(
      grad, learning_rate=scalar_params.learning_rate)
  new_state = state.replace(optimizer=new_optimizer)
  return new_state, stats, rng_key
Ejemplo n.º 5
0
def value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False):
    r"""Creates a function which evaluates both ``fun`` and the grad of ``fun``.

  NOTE: You only need this in a very specific case that you want to take a
  gradient **inside** a :func:`transform`\ ed function and the function you are
  differentiating uses :func:`set_state`. For example:

  >>> class MyModule(hk.Module):
  ...   def __call__(self, x):
  ...     hk.set_state("last", jnp.sum(x))
  ...     return x ** 2

  >>> def f(x):
  ...   m = MyModule()
  ...   y, g = hk.value_and_grad(m)(x)
  ...   return y, g

  >>> f = hk.transform_with_state(f)
  >>> x = jnp.array(2.)
  >>> _ = jax.jit(f.init)(None, x)

  Args:
    fun: Function to be differentiated. Its arguments at positions specified by
      ``argnums`` should be arrays, scalars, or standard Python containers. It
      should return a scalar (which includes arrays with shape ``()`` but not
      arrays with shape ``(1,)`` etc.)
    argnums: Optional, integer or tuple of integers. Specifies which positional
      argument(s) to differentiate with respect to (default 0).
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
     first element is considered the output of the mathematical function to be
     differentiated and the second element is auxiliary data. Default False.
    holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
      holomorphic. Default False.

  Returns:
    A function with the same arguments as ``fun`` that evaluates both ``fun``
    and the gradient of ``fun`` and returns them as a pair (a two-element
    tuple). If ``argnums`` is an integer then the gradient has the same shape
    and type as the positional argument indicated by that integer. If argnums is
    a tuple of integers, the gradient is a tuple of values with the same shapes
    and types as the corresponding arguments.

  Raises:
    ValueError: If called outside of a transformed function.
  """
    if not base.inside_transform():
        raise ValueError(
            "hk.value_and_grad() should not be used outside of hk.transform(). "
            "Use jax.value_and_grad() instead.")

    @functools.wraps(fun)
    def stateful_fun(*args, **kwargs):
        # The internal state is threaded in as an explicit keyword argument so
        # that it is an input of the differentiated function.
        state_in = kwargs.pop("hk_state")
        with temporary_internal_state(state_in):
            out = fun(*args, **kwargs)
            out, aux = (out if has_aux else (out, None))
            # Only the state changes made while running ``fun`` are returned.
            state_out = difference(state_in, internal_state())
            return out, (aux, state_out)

    # ``has_aux=True`` is always used here because ``stateful_fun`` returns the
    # state delta (and the user's aux, if any) alongside the value.
    grad_fun = jax.value_and_grad(stateful_fun,
                                  argnums=argnums,
                                  has_aux=True,
                                  holomorphic=holomorphic)

    @functools.wraps(grad_fun)
    def wrapper(*args, **kwargs):
        kwargs["hk_state"] = internal_state()
        (value, (aux, hk_state)), grads = grad_fun(*args, **kwargs)
        # Write the state changes made inside ``fun`` back to the live state.
        update_internal_state(hk_state)
        if has_aux:
            return (value, aux), grads
        else:
            return value, grads

    return wrapper
Ejemplo n.º 6
0
def fista_step(data, loss_and_prox_op, model_param, options):
    """Fista optimization step for solving regularized problem.

    Parameters are flattened with ``ravel_pytree``; the proximal step, the
    backtracking condition and the acceleration all operate on the flat
    vector, and the loss is always evaluated on the unflattened parameters.

    Args:
      data: A tuple of inputs and labels passed to the loss function.
      loss_and_prox_op: Tuple of (loss_f, prox_g)
        loss_f is the loss function that takes in model_param, inputs, and
        labels. prox_g is the proximity operator for g; it is called as
        ``prox_g(flat_vector, step_size)``.
      model_param: Current model parameters to be passed to loss_f.
      options: A dictionary of optimizer specific hyper-parameters. Keys read
        here: 'step_size', 'acceleration', 't', 'verbose', 'reuse_last_step'
        and 'y'. The dictionary is copied and also forwarded to
        ``backtracking``.

    Returns:
      A pair ``(params, options)``: the updated model parameters (same
      structure as ``model_param``) and a copy of ``options`` updated with
      'y' and 't' (when accelerating) and 'step_size' (when
      ``reuse_last_step`` is set).
    """
    options = dict(options)
    step_size = options.get('step_size', 1.0)
    acceleration = options.get('acceleration', True)
    t = options.get('t', 1.0)
    verbose = options.get('verbose', False)
    reuse_last_step = options.get('reuse_last_step', False)

    loss_f, prox_g = loss_and_prox_op
    inputs, labels = data[0], data[1]
    fun_f = lambda param: loss_f(param, inputs, labels)
    value_and_grad_f = jax.value_and_grad(fun_f)
    x, unravel_fn = ravel_pytree(model_param)
    # The gradient is taken at the extrapolated point ``y`` when one was
    # stored by a previous accelerated step, otherwise at ``x``.
    y = options.get('y', x)
    value_f, grad_f = value_and_grad_f(unravel_fn(y))
    grad_f, unravel_fn = ravel_pytree(grad_f)

    def next_candidate(step_size):
        return prox_g(y - grad_f * step_size, step_size)

    def stop_cond(step_size, next_iter):
        diff = next_iter - y
        sqdist = jnp.sum(diff**2)

        # We do not compute the non-smooth term (g in the paper)
        # as it cancels out from value_F and value_Q.
        # ``next_iter`` is a flat vector, so it must be unflattened before it
        # is handed to the loss, exactly like ``y`` above.
        value_bigf = fun_f(unravel_fn(next_iter))
        value_bigq = value_f + jnp.sum(
            diff * grad_f) + 0.5 / step_size * sqdist
        return value_bigf <= value_bigq

    x_old = x

    step_size, x = backtracking(next_candidate, stop_cond, step_size, options)

    # Acceleration.
    if acceleration:
        t_next = (1 + jnp.sqrt(1 + 4 * t**2)) / 2.
        y = x + (t - 1) / t_next * (x - x_old)
        t = t_next
        options['y'] = y
        options['t'] = t
    else:
        y = x

    if reuse_last_step:
        options['step_size'] = step_size
    if verbose:
        logging.info('Step size: %f', step_size)

    return unravel_fn(x), options
Ejemplo n.º 7
0
 def __init__(self, f, argnums=0):
     """Store ``f`` and ``argnums`` and build the derivative transforms of ``f``.

     ``self.vg`` is ``jax.value_and_grad`` of ``f`` with respect to
     ``argnums``; ``self.hess`` is ``hessian`` applied to ``f``.
     """
     self.f, self.argnums = f, argnums
     self.vg = jax.value_and_grad(self.f, argnums=argnums)
     # NOTE(review): `hessian` is built without `argnums`, so it may
     # differentiate a different argument than `self.vg` when
     # `argnums != 0` -- confirm this is intended.
     self.hess = hessian(self.f)
Ejemplo n.º 8
0
def find_reasonable_step_size(potential_fn, kinetic_fn, momentum_generator,
                              inverse_mass_matrix, position, rng_key,
                              init_step_size):
    """
    Tunes `init_step_size` by repeated doubling or halving until the direction
    of the adjustment flips, so that HMC does not start from a step size that
    is far too large or far too small.

    **References:**

    1. *The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo*,
       Matthew D. Hoffman, Andrew Gelman

    :param potential_fn: A callable to compute potential energy.
    :param kinetic_fn: A callable to compute kinetic energy.
    :param momentum_generator: A generator to get a random momentum variable.
    :param inverse_mass_matrix: Inverse of mass matrix.
    :param position: Current position of the particle.
    :param jax.random.PRNGKey rng_key: Random key to be used as the source of randomness.
    :param float init_step_size: Initial step size to be tuned.
    :return: a reasonable value for step size.
    :rtype: float
    """
    # The acceptance probability exp(-delta_energy) is compared with the
    # target in log space: if it is above the target the step size grows,
    # otherwise it shrinks.
    log_target_accept = np.log(0.8)

    _, vv_update = velocity_verlet(potential_fn, kinetic_fn)
    z = position
    potential_energy, z_grad = value_and_grad(potential_fn)(z)
    float_limits = np.finfo(get_dtype(init_step_size))

    def _body_fn(state):
        # State layout: (step_size, previous direction, direction, rng_key).
        current_step, _, direction, key = state
        key, momentum_key = random.split(key)
        # Double the step size when direction is 1, halve it when it is -1.
        # A NaN delta_energy (e.g. a diverging trajectory) fails the
        # comparison below and therefore yields direction -1.
        current_step = (2.0**direction) * current_step
        r = momentum_generator(inverse_mass_matrix, momentum_key)
        _, r_new, potential_energy_new, _ = vv_update(
            current_step, inverse_mass_matrix,
            (z, r, potential_energy, z_grad))
        energy_before = kinetic_fn(inverse_mass_matrix, r) + potential_energy
        energy_after = kinetic_fn(inverse_mass_matrix,
                                  r_new) + potential_energy_new
        delta_energy = energy_after - energy_before
        next_direction = np.where(log_target_accept < -delta_energy, 1, -1)
        return current_step, direction, next_direction, key

    def _cond_fn(state):
        current_step, previous_direction, direction, _ = state
        # Keep going unless the step size underflowed while still shrinking ...
        can_shrink = (current_step > float_limits.tiny) | (direction >= 0)
        # ... or overflowed while still growing.
        can_grow = (current_step < float_limits.max) | (direction <= 0)
        within_limits = can_shrink & can_grow
        # Continue on the first iteration or while the direction is unchanged.
        same_direction = (previous_direction == 0) | (direction
                                                      == previous_direction)
        return within_limits & same_direction

    initial_state = (init_step_size, 0, 0, rng_key)
    tuned_step_size, _, _, _ = while_loop(_cond_fn, _body_fn, initial_state)
    return tuned_step_size
Ejemplo n.º 9
0
def train_step(model, rng, state, batch, lr):
    """One optimization step.

    The objective is the MSE of the main prediction, plus the MSE of the
    coarse prediction when the model returns two outputs, plus
    ``FLAGS.weight_decay_mult`` times the mean squared parameter value.
    Gradients and stats are averaged over the ``"batch"`` axis, then
    optionally clipped by value and by global norm before the update.

    Args:
      model: The linen model.
      rng: jnp.ndarray, random number generator.
      state: utils.TrainState, state of the model/optimizer.
      batch: dict, a mini-batch of data for training. Reads ``batch["rays"]``
        and ``batch["pixels"]``.
      lr: float, real-time learning rate.

    Returns:
      new_state: utils.TrainState, new training state.
      stats: utils.Stats with fields loss, psnr, loss_c, psnr_c and weight_l2,
        averaged over the ``"batch"`` axis.
      rng: jnp.ndarray, updated random number generator.

    Raises:
      ValueError: If the model returns neither 1 nor 2 sets of outputs.
    """
    rng, key_0, key_1 = random.split(rng, 3)

    def loss_fn(variables):
        rays = batch["rays"]
        ret = model.apply(variables, key_0, key_1, rays, FLAGS.randomized)
        if len(ret) not in (1, 2):
            # The trailing space keeps the two literals from fusing into
            # "setsof" when they are concatenated.
            raise ValueError(
                "ret should contain either 1 set of output (coarse only), or 2 sets "
                "of output (coarse as ret[0] and fine as ret[1]).")
        # The main prediction is always at the end of the ret list.
        rgb, unused_disp, unused_acc = ret[-1]
        loss = ((rgb - batch["pixels"][Ellipsis, :3])**2).mean()
        psnr = utils.compute_psnr(loss)
        if len(ret) > 1:
            # If there are both coarse and fine predictions, we compute the loss for
            # the coarse prediction (ret[0]) as well.
            rgb_c, unused_disp_c, unused_acc_c = ret[0]
            loss_c = ((rgb_c - batch["pixels"][Ellipsis, :3])**2).mean()
            psnr_c = utils.compute_psnr(loss_c)
        else:
            loss_c = 0.
            psnr_c = 0.

        def tree_sum_fn(fn):
            # Sum ``fn(leaf)`` over every leaf of the parameter tree.
            return jax.tree_util.tree_reduce(lambda x, y: x + fn(y),
                                             variables,
                                             initializer=0)

        # Sum of squared parameters divided by the total parameter count.
        weight_l2 = (tree_sum_fn(lambda z: jnp.sum(z**2)) /
                     tree_sum_fn(lambda z: jnp.prod(jnp.array(z.shape))))

        stats = utils.Stats(loss=loss,
                            psnr=psnr,
                            loss_c=loss_c,
                            psnr_c=psnr_c,
                            weight_l2=weight_l2)
        return loss + loss_c + FLAGS.weight_decay_mult * weight_l2, stats

    (_,
     stats), grad = (jax.value_and_grad(loss_fn,
                                        has_aux=True)(state.optimizer.target))
    grad = jax.lax.pmean(grad, axis_name="batch")
    stats = jax.lax.pmean(stats, axis_name="batch")

    # Clip the gradient by value.
    if FLAGS.grad_max_val > 0:
        clip_fn = lambda z: jnp.clip(z, -FLAGS.grad_max_val,
                                     FLAGS.grad_max_val)
        grad = jax.tree_util.tree_map(clip_fn, grad)

    # Clip the (possibly value-clipped) gradient by norm.
    if FLAGS.grad_max_norm > 0:
        grad_norm = jnp.sqrt(
            jax.tree_util.tree_reduce(lambda x, y: x + jnp.sum(y**2),
                                      grad,
                                      initializer=0))
        # The small constant guards against division by a zero norm.
        mult = jnp.minimum(1, FLAGS.grad_max_norm / (1e-7 + grad_norm))
        grad = jax.tree_util.tree_map(lambda z: mult * z, grad)

    new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
    new_state = state.replace(optimizer=new_optimizer)
    return new_state, stats, rng
Ejemplo n.º 10
0
 def train_step(i, opt_state, batch):
     """Apply one optimizer update for ``batch``.

     Returns the pair ``(new optimizer state, loss)``, where the loss is the
     value of ``loss_fn`` at the parameters held by ``opt_state``.
     """
     current_params = get_params(opt_state)
     loss, grad = jax.value_and_grad(loss_fn)(current_params, batch)
     updated_state = opt_update(i, grad, opt_state)
     return updated_state, loss
Ejemplo n.º 11
0
def fmin_bfgs(func, x0, args=(), options=None):
    """
    Minimize ``func`` with the BFGS algorithm, following
        Algorithm 6.1 from Wright and Nocedal, 'Numerical Optimization', 1999, pg. 136-143

        Notes:
            We utilise boolean arithmetic to avoid jax.cond calls which don't work on accelerators.
            A side effect is that we perform more gradient evaluations than scipy's BFGS

        func: callable
            Function of the form f(x) where x is a flat ndarray and returns a real scalar. The function should be
            composed of operations with vjp defined. If func is jittable then fmin_bfgs is jittable.
            NOTE(review): the original text referred to a `_nojit` flag for non-jittable functions; no such
            parameter exists in this signature.

        x0: ndarray
            initial variable; ``x0.shape[0]`` is used as the problem dimension.
        args: tuple, optional
            Extra arguments to pass to func as func(x,*args)
        options: Optional dict of parameters
            maxiter: int
                Maximum number of iterations. Defaults to ``jnp.size(x0) * 200``.
            norm: float
                Order of norm for convergence check. Default inf.
            gtol: float
                Terminates minimization when |grad|_norm < g_tol. Default 1e-5.
            ls_maxiter: int
                Maximum number of linesearch iterations. Default 10.
            hess_inv: ndarray
                Initial inverse-Hessian estimate. Defaults to the identity matrix.

    Returns: BFGSResults
        The final state. ``status`` is 0 when converged, 1 when ``maxiter`` was reached,
        ``2 + ls_status`` when the line search failed, and -1 otherwise.

    """

    if options is None:
        options = dict()
    maxiter: Optional[int] = options.get('maxiter', None)
    norm: float = options.get('norm', jnp.inf)
    gtol: float = options.get('gtol', 1e-5)
    ls_maxiter: int = options.get('ls_maxiter', 10)

    # Placeholder state; f_k, g_k and H_k are filled in after the first evaluation below.
    state = BFGSResults(converged=False,
                        failed=False,
                        k=0,
                        nfev=0,
                        ngev=0,
                        nhev=0,
                        x_k=x0,
                        f_k=None,
                        g_k=None,
                        H_k=None,
                        status=None,
                        ls_status=jnp.array(0))

    if maxiter is None:
        maxiter = jnp.size(x0) * 200

    d = x0.shape[0]

    initial_H = jnp.eye(d)
    initial_H = options.get('hess_inv', initial_H)

    def func_with_args(x):
        return func(x, *args)

    value_and_grad = jax.value_and_grad(func_with_args)

    # Initial evaluation; the run may already be converged at x0.
    f_0, g_0 = value_and_grad(x0)
    state = state._replace(f_k=f_0, g_k=g_0, H_k=initial_H, nfev=state.nfev + 1, ngev=state.ngev + 1,
                           converged=jnp.linalg.norm(g_0, ord=norm) < gtol)

    def body(state):
        # Search direction from the current inverse-Hessian estimate.
        p_k = -(state.H_k @ state.g_k)
        line_search_results = line_search(value_and_grad, state.x_k, p_k, old_fval=state.f_k, gfk=state.g_k,
                                          maxiter=ls_maxiter)
        state = state._replace(nfev=state.nfev + line_search_results.nfev,
                               ngev=state.ngev + line_search_results.ngev,
                               failed=line_search_results.failed,
                               ls_status=line_search_results.status)
        s_k = line_search_results.a_k * p_k
        x_kp1 = state.x_k + s_k
        f_kp1 = line_search_results.f_k
        g_kp1 = line_search_results.g_k
        # print(g_kp1)
        y_k = g_kp1 - state.g_k
        rho_k = jnp.reciprocal(y_k @ s_k)

        sy_k = s_k[:, None] * y_k[None, :]
        w = jnp.eye(d) - rho_k * sy_k
        # Keep the previous estimate when rho_k is not finite (y_k @ s_k == 0 or NaN),
        # selected with jnp.where instead of a conditional branch.
        H_kp1 = jnp.where(jnp.isfinite(rho_k),
                          jnp.linalg.multi_dot([w, state.H_k, w.T]) + rho_k * s_k[:, None] * s_k[None, :], state.H_k)

        converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol

        state = state._replace(converged=converged,
                               k=state.k + 1,
                               x_k=x_kp1,
                               f_k=f_kp1,
                               g_k=g_kp1,
                               H_k=H_kp1
                               )

        return state

    # Iterate until converged, the line search fails, or maxiter is reached.
    state = while_loop(
        lambda state: (~ state.converged) & (~state.failed) & (state.k < maxiter),
        body,
        state)

    # Status is resolved in priority order: converged, then maxiter, then line-search failure.
    state = state._replace(status=jnp.where(state.converged, jnp.array(0),#converged
                                             jnp.where(state.k == maxiter, jnp.array(1),#max iters reached
                                                       jnp.where(state.failed, jnp.array(2)+state.ls_status,#ls failed (+ reason)
                                                                 jnp.array(-1)))))#undefined

    return state
Ejemplo n.º 12
0
from .. import get_backend, default_backend
from ..tensor.common import _TensorViewer
from .autodiff import AutoDiffOptimizerMixin
import jax


def _final_objective(pars, data, fixed_vals, model, objective, fixed_idx,
                     variable_idx):
    """Evaluate ``objective`` on the full parameter set and return its first element.

    ``pars`` holds only the variable parameters; it is converted to a backend
    tensor and stitched together with ``fixed_vals`` according to
    ``fixed_idx`` / ``variable_idx`` before ``objective`` is called.
    """
    tensorlib, _ = get_backend()
    viewer = _TensorViewer([fixed_idx, variable_idx])
    variable_pars = tensorlib.astensor(pars)
    full_pars = viewer.stitch([fixed_vals, variable_pars])
    result = objective(full_pars, data, model)
    return result[0]


# Jit-compiled value-and-gradient of `_final_objective`. The gradient is taken
# with respect to the first positional argument (`pars`, the default
# `argnums` of `jax.value_and_grad`). Positional arguments 3-6 (`model`,
# `objective`, `fixed_idx`, `variable_idx`) are declared static for `jax.jit`.
_jitted_objective_and_grad = jax.jit(jax.value_and_grad(_final_objective),
                                     static_argnums=(3, 4, 5, 6))


class jax_optimizer(AutoDiffOptimizerMixin):
    """JAX Optimizer Backend."""
    def setup_minimize(self,
                       objective,
                       data,
                       pdf,
                       init_pars,
                       par_bounds,
                       fixed_vals=None):
        """
        Prepare Minimization for AutoDiff-Optimizer.
Ejemplo n.º 13
0
    def value_and_grad_f(*args, **kwargs):
        """Evaluate ``fun`` and its gradient, dispatching on real/complex types.

        The differentiated arguments and the output of ``fun`` are inspected
        and one of four cases is selected:

        * complex -> complex: ``jax.value_and_grad`` with ``holomorphic=True``;
        * complex -> real: not supported, raises ``RuntimeError``;
        * real -> complex: the real and imaginary parts of the output are
          differentiated separately and recombined as ``re + 1j * im``;
        * real -> real: plain ``jax.value_and_grad``.

        Returns:
            ``(value, grad)``, or ``((value, aux), grad)`` when ``has_aux`` is
            set -- the same layout as ``jax.value_and_grad`` in every case.

        Raises:
            RuntimeError: For a complex -> real function.
        """
        out_shape = eval_shape(fun, *args, has_aux=has_aux, **kwargs)

        _args_iterable = (argnums, ) if isinstance(argnums, int) else argnums

        # only check if derivable arguments are complex
        if tree_leaf_iscomplex([args[i] for i in _args_iterable]):
            if is_complex(out_shape):  # C -> C
                return jax.value_and_grad(
                    fun,
                    argnums=argnums,
                    has_aux=has_aux,
                    allow_int=allow_int,
                    holomorphic=True,
                )(*args, **kwargs)
            # C -> R
            raise RuntimeError("C->R function detected, but not supported.")

        if not is_complex(out_shape):  # R -> R
            return jax.value_and_grad(fun,
                                      argnums=argnums,
                                      has_aux=has_aux,
                                      allow_int=allow_int)(*args, **kwargs)

        # R -> C: differentiate the real and imaginary parts separately.
        if has_aux:

            def real_fun(*f_args, **f_kwargs):
                val, aux = fun(*f_args, **f_kwargs)
                return val.real, aux

            def imag_fun(*f_args, **f_kwargs):
                val, aux = fun(*f_args, **f_kwargs)
                return val.imag, aux

        else:

            def real_fun(*f_args, **f_kwargs):
                return fun(*f_args, **f_kwargs).real

            def imag_fun(*f_args, **f_kwargs):
                return fun(*f_args, **f_kwargs).imag

        real_value_and_grad = jax.value_and_grad(real_fun,
                                                 argnums=argnums,
                                                 has_aux=has_aux,
                                                 allow_int=allow_int)
        imag_value_and_grad = jax.value_and_grad(imag_fun,
                                                 argnums=argnums,
                                                 has_aux=has_aux,
                                                 allow_int=allow_int)

        if has_aux:
            # With has_aux=True jax returns ((value, aux), grad); unpacking it
            # into three names would raise.
            (out_r, aux), grad_r = real_value_and_grad(*args, **kwargs)
            (out_j, _), grad_j = imag_value_and_grad(*args, **kwargs)
        else:
            out_r, grad_r = real_value_and_grad(*args, **kwargs)
            out_j, grad_j = imag_value_and_grad(*args, **kwargs)

        out = out_r + 1j * out_j
        grad = jax.tree_util.tree_map(lambda re, im: re + 1j * im, grad_r,
                                      grad_j)

        if has_aux:
            return (out, aux), grad
        return out, grad
    def test_freeze_feature_encoder(self):
        """Check that freezing the feature encoder only zeroes its own gradients.

        The loss, the model outputs and the gradients of all other layers must
        be identical with and without ``freeze_feature_encoder``; the
        ``feature_extractor`` gradients must be exactly zero when frozen and
        contain non-zero entries when not frozen.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
        )

        input_values = inputs_dict["input_values"]
        attention_mask = inputs_dict["attention_mask"]

        model = FlaxWav2Vec2ForPreTraining(config)
        params = model.params

        # dummy loss function
        def compute_loss(params,
                         input_values,
                         attention_mask,
                         freeze_feature_encoder: bool = False,
                         epsilon: float = 1e-8):
            outputs = model(
                input_values,
                attention_mask=attention_mask,
                freeze_feature_encoder=freeze_feature_encoder,
                params=params,
            )
            # compute cosine similarity of projected and projected_quantized states
            cosine_sim = optax.cosine_similarity(
                outputs.projected_states,
                outputs.projected_quantized_states,
                epsilon=epsilon)
            loss = cosine_sim.sum()
            return loss, outputs.to_tuple()

        # transform the loss function to get the gradients
        grad_fn = jax.value_and_grad(compute_loss, has_aux=True)

        # compute loss, outputs and gradients for unfrozen model
        (loss, outputs), grads = grad_fn(params,
                                         input_values,
                                         attention_mask,
                                         freeze_feature_encoder=False)

        # compare to loss, outputs and gradients for frozen model
        (loss_frozen,
         outputs_frozen), grads_frozen = grad_fn(params,
                                                 input_values,
                                                 attention_mask,
                                                 freeze_feature_encoder=True)

        # ensure that the outputs and losses remain precisely equal
        for output, output_frozen in zip(outputs, outputs_frozen):
            self.assertTrue((output == output_frozen).all())
        self.assertEqual(loss, loss_frozen)

        grads = flatten_dict(grads)
        grads_frozen = flatten_dict(grads_frozen)

        # ensure that the dicts of gradients contain the same keys
        self.assertEqual(grads.keys(), grads_frozen.keys())

        # ensure that the gradients of the feature extractor layers are precisely zero when frozen and contain non-zero entries when unfrozen
        feature_extractor_grads = tuple(grads[k] for k in grads
                                        if "feature_extractor" in k)
        feature_extractor_grads_frozen = tuple(grads_frozen[k]
                                               for k in grads_frozen
                                               if "feature_extractor" in k)

        for feature_extractor_grad, feature_extractor_grad_frozen in zip(
                feature_extractor_grads, feature_extractor_grads_frozen):
            self.assertTrue((feature_extractor_grad_frozen == 0.0).all())
            # "non-zero" rather than "positive": a gradient whose entries are
            # all negative is still a live (unfrozen) gradient.
            self.assertTrue((feature_extractor_grad != 0.0).any())

        # ensure that the gradients of all unfrozen layers remain equal, i.e. all layers excluding the frozen 'feature_extractor'
        grads = tuple(grads[k] for k in grads if "feature_extractor" not in k)
        grads_frozen = tuple(grads_frozen[k] for k in grads_frozen
                             if "feature_extractor" not in k)

        for grad, grad_frozen in zip(grads, grads_frozen):
            self.assertTrue((grad == grad_frozen).all())
Ejemplo n.º 15
0
def find_connected_curve(w1,
                         w2,
                         loss_fn,
                         n_bends=3,
                         train_steps=100,
                         lr=0.05,
                         scheduler_name='constant',
                         batch_size=4,
                         log_every=1,
                         seed=43,
                         use_jit=False):
    """Train the bends of a curve connecting two local minima.

    The curve is built by ``get_beizer_curve`` from the control points
    ``[w1, bend_1, ..., bend_n, w2]``; only the bends are optimized (with
    Adam) to minimize the average of ``loss_fn`` over points sampled on the
    curve at ``t ~ Unif(0, 1)``.

    Args:
        w1: jnp.array, the first local minimum (fixed endpoint).
        w2: jnp.array, the second local minimum (fixed endpoint).
        loss_fn: callable, maps one point on the curve to a scalar loss.
        n_bends: int, number of trainable bends between the endpoints.
        train_steps: int, number of training steps.
        lr: float, initial learning rate.
        scheduler_name: str, scheduler name passed to qnnops.get_scheduler.
        batch_size: int, number of t samples drawn per step.
        log_every: int, log and save the 'last' checkpoint every this many
            steps.
        seed: int, random seed for sampling t.
        use_jit: bool, whether to jit-compile the gradient function.

    Returns:
        None. Progress is reported through ``expmgr.log`` and persisted via
        ``save_checkpoints`` ('best' and 'last').
    """

    print('Find a curve connecting two local minimas')
    curve = get_beizer_curve(n_bends + 2)  # includes the endpoints

    def curve_loss(t_, params_):  # t_ should be (m,) vector.
        # Re-attach the fixed endpoints so only the bends receive gradients.
        params_ = jnp.vstack([w1, params_, w2])
        c = curve(t_, params_)
        loss_sum = 0
        for c_ in c:
            loss_sum += loss_fn(c_)
        return loss_sum / float(len(t_))

    # differentiate the 2nd argument
    print('Get the gradient of the loss function')
    grad_loss_fn = jax.value_and_grad(curve_loss, argnums=1)
    if use_jit:
        print('Use jit compilation')
        grad_loss_fn = jax.jit(grad_loss_fn)

    start_step = 0
    history = {'loss': [], 'grad': [], 'params': []}
    # Kept as a Python float so it can always be logged, even if no finite
    # loss has been observed yet.
    min_loss = float('inf')

    scheduler = qnnops.get_scheduler(lr, train_steps, name=scheduler_name)
    init_fun, update_fun, get_params = qnnops.get_optimizer(
        'adam', None, scheduler)
    # Pick evenly divided points from the line segment between w1 and w2.
    alpha = jnp.linspace(0, 1, n_bends + 2)

    # Bends must run from w1 towards w2, matching the [w1, bends, w2] stacking
    # used in curve_loss; the first bend is therefore the one closest to w1.
    init_bends = jnp.vstack(
        [(1 - alpha[i]) * w1 + alpha[i] * w2 for i in range(1, n_bends + 1)])

    print('State initialization')
    optimizer_state = init_fun(init_bends)
    params = get_params(optimizer_state)
    rng = jax.random.PRNGKey(seed)

    for step in range(start_step, train_steps):
        rng, key = jax.random.split(rng)

        # Loss = E[ L(\phi_\theta(t))]  where t ~ Unif(0, 1)
        t = jax.random.uniform(key, (batch_size, ))
        loss, grad = grad_loss_fn(t, params)
        optimizer_state = update_fun(step, grad, optimizer_state)
        params = get_params(optimizer_state)

        loss_value = loss.item()
        history['loss'].append(loss_value)
        history['grad'].append(onp.array(grad))
        history['params'].append(onp.array(params))

        if loss_value < min_loss:
            min_loss = loss_value
            save_checkpoints('best', step, w1, w2, params, history,
                             optimizer_state)

        if step % log_every == 0:
            grad_norm = jnp.linalg.norm(grad).item()
            logging_output = OrderedDict(loss=loss_value,
                                         lr=scheduler(step),
                                         grad_norm=grad_norm)
            logging_output['min_loss'] = min_loss
            expmgr.log(step, logging_output)
            save_checkpoints('last', step, w1, w2, params, history,
                             optimizer_state)

        # Drop device buffers eagerly to keep memory flat across steps.
        del loss, grad
        gc.collect()
Ejemplo n.º 16
0
 def optimise(epoch, m, s, params, xs, ys):
     """Run one optimisation step and return the new state plus the loss.

     The loss and its gradient w.r.t. ``params`` are evaluated on
     ``(xs, ys)``; the gradient is then handed to ``jb.adabeliefʹ`` together
     with the current ``m``, ``s`` and ``params``.

     Returns:
         Tuple ``(m, s, params, loss)`` with the values returned by the
         optimiser and the loss measured before the update.
     """
     loss_and_grad = jax.value_and_grad(loss)
     current_loss, gradient = loss_and_grad(params, xs, ys)
     new_m, new_s, new_params = jb.adabeliefʹ(epoch, gradient, m, s, params)
     return new_m, new_s, new_params, current_loss
Ejemplo n.º 17
0
def linear_model_loss_regularized_dw_lp(w, x, y, reg_coeff, norm_type):
    """Linear-model loss penalized by the norm of its gradient w.r.t. ``w``.

    Computes ``linear_model_loss(w, x, y)`` and adds
    ``reg_coeff * norm_f(dL/dw, norm_type)``.
    """
    value_and_dw = jax.value_and_grad(linear_model_loss, argnums=0)
    base_loss, grad_w = value_and_dw(w, x, y)
    penalty = reg_coeff * norm_f(grad_w, norm_type)
    return base_loss + penalty
Ejemplo n.º 18
0
 def fun(flat_nb_params):
     """Evaluate ``flat_loss`` and its gradient at ``flat_nb_params``.

     Returns:
         ``(value, gradient)`` where the value is converted to a Python float
         and the gradient to an ``np.array``.
     """
     loss_and_grad = value_and_grad(flat_loss)
     objective, gradient = loss_and_grad(flat_nb_params)
     objective_as_float = float(objective)
     gradient_as_array = np.array(gradient)
     return objective_as_float, gradient_as_array
Ejemplo n.º 19
0
            return cql_term

        def total_critic_loss(q_params, policy_params, target_q_params, alpha,
                              cql_alpha, transitions, key):
            """Combine the critic loss with the CQL term weighted by cql_alpha.

            Returns:
                ``(total, aux)`` where ``aux`` is a dict holding the two
                individual terms under 'critic_loss' and 'cql_loss'.
            """
            bellman_part = critic_loss(q_params, policy_params,
                                       target_q_params, alpha, transitions,
                                       key)
            conservative_part = cql_loss(q_params, policy_params, transitions,
                                         key)
            metrics = {
                'critic_loss': bellman_part,
                'cql_loss': conservative_part
            }
            return bellman_part + cql_alpha * conservative_part, metrics

        # Differentiates total_critic_loss w.r.t. its first argument
        # (q_params); has_aux=True because it returns (loss, metrics dict).
        critic_grad = jax.value_and_grad(total_critic_loss, has_aux=True)

        def critic_update_step(q_params, target_q_params, optim_state,
                               policy_params, alpha, cql_alpha, transitions,
                               key):
            total_critic_loss_and_aux, critic_grads = critic_grad(
                q_params, policy_params, target_q_params, alpha, cql_alpha,
                transitions, key)
            critic_grads = jax.lax.pmean(critic_grads, 'devices')
            # Apply critic gradients
            critic_update, optim_state = q_optimizer.update(
                critic_grads, optim_state)
            q_params = optax.apply_updates(q_params, critic_update)

            target_q_params = jax.tree_multimap(
                lambda x, y: x * (1 - tau) + y * tau, target_q_params,
Ejemplo n.º 20
0
def unroll(x_init, theta, t_current, T, K):
    """Scan ``update`` for K steps starting from ``x_init``.

    The scan carry is ``(L_sum, x, theta, t_current, T, K)`` with the running
    loss starting at 0.0.

    Returns:
        ``(L, (x_current, x_trajectory, t_current))``: the accumulated loss
        from the final carry, the final x, the stacked per-step scan outputs
        (transposed), and the final time.
    """
    # Step indices 0..K-1; lax.iota always starts counting at 0.
    step_ids = jax.lax.iota(jnp.int32, K)
    carry0 = (0.0, x_init, theta, t_current, T, K)
    final_carry, per_step = jax.lax.scan(update, carry0, step_ids)
    total_loss, x_final, _, t_final, _, _ = final_carry
    return total_loss, (x_final, jnp.stack(per_step).T, t_final)


# Gradient of unroll's loss output w.r.t. theta (argnums=1). has_aux=True
# because unroll returns (L, aux). T and K (positional args 3 and 4) are
# marked static for jit.
grad_unroll = jax.jit(jax.grad(unroll, argnums=1, has_aux=True),
                      static_argnums=(3, 4))
# Same setup, but also returns the (L, aux) value alongside the gradient.
loss_and_grad_unroll = jax.jit(jax.value_and_grad(unroll,
                                                  argnums=1,
                                                  has_aux=True),
                               static_argnums=(3, 4))


# Functions for RTRL/UORO
# ----------------------------------------------------------------------------------
@jax.jit
def single_step(theta, x, t_current, T):
    """Take one gradient step on ``x`` with a time-interpolated step size.

    The step size blends ``exp(theta[0])`` (weight ``(T - t_current) / T``)
    and ``exp(theta[1])`` (weight ``t_current / T``); the gradient comes from
    ``loss_grad(x)``.
    """
    start_lr = jnp.exp(theta[0]) * (T - t_current) / T
    end_lr = jnp.exp(theta[1]) * t_current / T
    step = (start_lr + end_lr) * loss_grad(x)
    return x - step

Ejemplo n.º 21
0
def optimize_params(c,
                    incidence_scenarios,
                    parameterizer,
                    loss_fn=None,
                    optimization_params=None,
                    verbose=True):
    """Modifies c in place subject to parameterizer to minimize a loss function.

  Picks params and runs parameterizer.xr_apply_params(c, params) to minimize
  loss_fn on the control arm event curves given incidence_scenarios. The final
  optimized params are stored in c['final_params'].

  Optimization stops when, after more than min_steps steps, the loss improved
  by less than epsilon over the last smoothing_window recorded steps, or when
  max_steps is reached.

  Args:
    c: xr.Dataset specifying the trial with all data_vars required to call
      sim.recruitment and sim.control_arm_events.
    incidence_scenarios: xr.DataArray of shape [scenario, location, time], the
      forecast incidence.
    parameterizer: a Parameterizer specifying what the trial planner can
      control.
    loss_fn: optional function which takes trajectories of control arm events
      (an jnp.array of shape [scenario, time]) and returns a jnp.array of losses
      of shape [scenario]. Defaults to negative_mean_successiness.
    optimization_params: optional dict with any of the keys 'min_steps',
      'max_steps', 'epsilon', 'smoothing_window', 'learning_rate' and
      'optimizer'.
    verbose: if True, print the loss every 10 steps.
  """
    # Run non-differentiable simulation once, because they do more error checking.
    participants = sim.recruitment(c)
    sim.control_arm_events(c, participants, incidence_scenarios)

    c_ = sim.JaxDataset(c)
    incidence_scenarios_ = jnp.asarray(
        incidence_scenarios.transpose('scenario', 'location', 'time'))
    _, recruitment_fn, _ = sim.RECRUITMENT_REGISTRY[sim.get_recruitment_type(
        c)]
    historical_events = c.historical_control_arm_events
    if 'location' in historical_events.dims:
        historical_events = c.historical_control_arm_events.sum('location')
    historical_events_ = jnp.array(historical_events.values)

    if loss_fn is None:
        loss_fn = negative_mean_successiness(c)

    def loss(params):
        parameterizer.apply_params(c_, params)
        participants_ = recruitment_fn(c_)
        control_arm_events_ = sim.differentiable_control_arm_events(
            c_, participants_, incidence_scenarios_)
        # Prepend the (scenario-independent) historical events to every
        # scenario's simulated trajectory along the time axis.
        historical_ = jnp.broadcast_to(
            historical_events_,
            control_arm_events_.shape[:1] + historical_events_.shape)
        control_arm_events_ = jnp.concatenate(
            [historical_, control_arm_events_], axis=1)
        return loss_fn(control_arm_events_).mean()

    if optimization_params is None:
        optimization_params = dict()
    min_steps = optimization_params.get('min_steps', 40)
    max_steps = optimization_params.get('max_steps', 200)
    epsilon = optimization_params.get('epsilon', 1e-3)
    smoothing_window = optimization_params.get('smoothing_window', 20)
    learning_rate = optimization_params.get('learning_rate', 0.1)
    optimizer = optimization_params.get('optimizer', optim.Adam(learning_rate))
    initial_params = parameterizer.init_params()
    optimizer = optimizer.create(jnp.array(initial_params.values))

    # Build the differentiated function once instead of on every iteration.
    loss_and_grad = jax.value_and_grad(loss)

    loss_curve = []
    while True:
        loss_value, grad = loss_and_grad(optimizer.target)
        loss_curve.append(float(loss_value))
        optimizer = optimizer.apply_gradient(grad)
        step = len(loss_curve)
        # Only look back smoothing_window steps once that many have been
        # recorded; otherwise a small min_steps would raise IndexError.
        if (step > min_steps and step >= smoothing_window and
            (loss_curve[-smoothing_window] - loss_curve[-1]) < epsilon):
            # Not much progress recently. Call it a day.
            break
        if step >= max_steps:
            print('Hit max step limit. You can control this exit condition by '
                  'setting max_steps in optimization_params.')
            break
        if verbose and (step % 10 == 0):
            print(f'step {step}, loss value {loss_curve[-1]}')

    final_params = np.array(optimizer.target)
    c['final_params'] = xr.DataArray(final_params,
                                     coords=initial_params.coords)
    parameterizer.xr_apply_params(c, c.final_params)
Ejemplo n.º 22
0
    def test_implicit_differentiation_versus_autodiff(self, lse_mode):
        """Compare implicit-differentiation gradients with autodiff ones.

        For a Geometry-based and a PointCloud-based Sinkhorn loss, checks that
        both differentiation modes give the same loss value, that the implicit
        gradients w.r.t. ``a`` and ``x`` agree with central finite differences
        along a random direction, and that they agree with the autodiff
        gradients.
        """
        epsilon = 0.05
        tolerances = dict(rtol=1e-02, atol=1e-02)

        def loss_g(a, x, implicit=True):
            # Squared Euclidean cost matrix between x and self.y.
            cost = (jnp.sum(x**2, axis=1)[:, jnp.newaxis] +
                    jnp.sum(self.y**2, axis=1)[jnp.newaxis, :] -
                    2 * jnp.dot(x, self.y.T))
            geom = geometry.Geometry(cost_matrix=cost, epsilon=epsilon)
            result = sinkhorn.sinkhorn(geom,
                                       a=a,
                                       b=self.b,
                                       tau_a=0.8,
                                       tau_b=0.87,
                                       threshold=1e-4,
                                       lse_mode=lse_mode,
                                       implicit_differentiation=implicit)
            return result.reg_ot_cost

        def loss_pcg(a, x, implicit=True):
            geom = pointcloud.PointCloud(x, self.y, epsilon=epsilon)
            result = sinkhorn.sinkhorn(geom,
                                       a=a,
                                       b=self.b,
                                       tau_a=1.0,
                                       tau_b=0.95,
                                       threshold=1e-4,
                                       lse_mode=lse_mode,
                                       implicit_differentiation=implicit)
            return result.reg_ot_cost

        for loss in [loss_g, loss_pcg]:
            loss_and_grad_imp = jax.jit(
                jax.value_and_grad(lambda a, x: loss(a, x, True),
                                   argnums=(0, 1)))
            loss_and_grad_auto = jax.jit(
                jax.value_and_grad(lambda a, x: loss(a, x, False),
                                   argnums=(0, 1)))

            value_imp, grads_imp = loss_and_grad_imp(self.a, self.x)
            value_auto, grads_auto = loss_and_grad_auto(self.a, self.x)
            grad_a_imp, grad_x_imp = grads_imp[0], grads_imp[1]
            grad_a_auto, grad_x_auto = grads_auto[0], grads_auto[1]

            self.assertAllClose(value_imp, value_auto)
            eps = 1e-3

            def central_difference(plus, minus):
                return (plus - minus) / (2 * eps)

            # test gradient w.r.t. a works and gradient implicit ~= gradient autodiff
            direction_a = jax.random.uniform(self.rngs[4], (self.n, )) / 10
            direction_a = direction_a - jnp.mean(
                direction_a)  # center perturbation
            cost_plus = loss(self.a + eps * direction_a, self.x)
            cost_minus = loss(self.a - eps * direction_a, self.x)
            self.assertAllClose(jnp.sum(direction_a * grad_a_imp),
                                central_difference(cost_plus, cost_minus),
                                **tolerances)
            # The means are subtracted below because gradients w.r.t. a are
            # only determined up to an additive constant (the primal variable
            # is in the simplex).
            self.assertAllClose(grad_a_imp - jnp.mean(grad_a_imp),
                                grad_a_auto - jnp.mean(grad_a_auto),
                                **tolerances)

            # test gradient w.r.t. x works and gradient implicit ~= gradient autodiff
            direction_x = jax.random.uniform(self.rngs[4], (self.n, self.dim))
            cost_plus = loss(self.a, self.x + eps * direction_x)
            cost_minus = loss(self.a, self.x - eps * direction_x)
            self.assertAllClose(jnp.sum(direction_x * grad_x_imp),
                                central_difference(cost_plus, cost_minus),
                                **tolerances)
            self.assertAllClose(grad_x_imp, grad_x_auto, **tolerances)
Ejemplo n.º 23
0
  masked_loss = jnp.abs(durations - x.durations) * mask
  loss = jnp.sum(masked_loss) / jnp.sum(mask)
  return loss, aux


# Jitted apply function of the Haiku-transformed DurationModel, built with
# is_training=False (used for prediction).
forward_fn = jax.jit(hk.transform_with_state(lambda x: DurationModel(is_training=False)(x)).apply)


def predict_duration(params, aux, rng, x: DurationInput):
  """Run ``forward_fn`` on ``x`` and pair its output with the true durations.

  Returns:
    ``(predicted, x.durations)``; the second value returned by ``forward_fn``
    is discarded.
  """
  predicted, _unused = forward_fn(params, aux, rng, x)
  reference = x.durations
  return predicted, reference


# Jitted validation loss: loss_fn with is_training fixed to False.
val_loss_fn = jax.jit(partial(loss_fn, is_training=False))

# Returns ((loss, aux), grads); gradients are taken w.r.t. loss_fn's first
# argument and has_aux=True passes the auxiliary output through.
loss_vag = jax.value_and_grad(loss_fn, has_aux=True)

# Clip gradients by global norm, then apply AdamW with weight decay.
optimizer = optax.chain(
    optax.clip_by_global_norm(FLAGS.max_grad_norm),
    optax.adamw(FLAGS.duration_learning_rate, weight_decay=FLAGS.weight_decay)
)


@jax.jit
def update(params, aux, rng, optim_state, inputs: DurationInput):
  """Perform one jitted training step.

  Splits ``rng``, evaluates loss and gradients with ``loss_vag``, and applies
  the update produced by the module-level ``optimizer``.

  Returns:
    ``(loss, (new_params, new_aux, new_rng, new_optim_state))``.
  """
  step_rng, next_rng = jax.random.split(rng)
  loss_and_aux, grads = loss_vag(params, aux, step_rng, inputs)
  loss, next_aux = loss_and_aux
  updates, next_optim_state = optimizer.update(grads, optim_state, params)
  next_params = optax.apply_updates(params, updates)
  next_state = (next_params, next_aux, next_rng, next_optim_state)
  return loss, next_state
Ejemplo n.º 24
0
def _find_sigmoid_upper_bound_tangent(range_lb: Tensor, range_ub: Tensor,
                                      tol: float) -> Tensor:
    """Locate where the concave upper bound of the sigmoid leaves its chord.

  Over an interval [lb, ub] the concave upper bound of the sigmoid is either
  the sigmoid itself, a straight line, or a line from lb that touches the
  sigmoid at a tangent point and follows the sigmoid afterwards. This function
  searches for that tangent point t, i.e. the root of

      sigmoid'(t) - (sigmoid(t) - sigmoid(lb)) / (t - lb).

  The search combines bisection with Newton steps: a Newton step is taken when
  it stays strictly inside the current bracket and the derivative is not
  (nearly) zero, otherwise the bracket midpoint is used. When the search
  bracket is already degenerate (e.g. range_lb == range_ub >= 0) the loop
  exits immediately and that value is returned.

  Args:
    range_lb: Lower bound of the domain on which to define the concave bound.
    range_ub: Upper bound of the domain on which to define the concave bound.
    tol: Tolerance criterion for convergence.
  Returns:
    final_t: Tangent point, with the same shape as range_lb.
  """
    lb_flat = jnp.reshape(range_lb, (-1, ))
    ub_flat = jnp.reshape(range_ub, (-1, ))

    sigmoid = jax.nn.sigmoid

    def sigmoid_slope(x):
        return sigmoid(x) * (1 - sigmoid(x))

    def tangent_residual(x, lb):
        # Zero exactly when the tangent at x passes through (lb, sigmoid(lb)).
        # The denominator is clamped at tol to avoid dividing by zero.
        return sigmoid_slope(x) - (sigmoid(x) - sigmoid(lb)) / jnp.maximum(
            x - lb, tol)

    # Per-element residual and its derivative w.r.t. x (the first argument).
    residual_and_slope = jax.vmap(jax.value_and_grad(tangent_residual))

    bracket_lo = jnp.maximum(lb_flat, 0.)
    # For a very negative lb, sigmoid(lb) ~ 0 and the root equation reduces to
    #   exp(x) - x = -lb - 1,
    # so, approximating exp(x) - x by exp(x), the solution is close to
    # log(-lb - 1). Padding by +1 keeps a valid solution inside the bracket
    # while letting the search converge much faster.
    large_lb_cap = jnp.where(
        lb_flat < -1e3,
        jnp.log(jnp.maximum(-lb_flat - 1, 1.)) + 1, float('inf'))
    bracket_hi = jnp.minimum(ub_flat, large_lb_cap)
    t_start = 0.5 * (bracket_lo + bracket_hi)
    n_start = jnp.array(0)

    def refine(state):
        n_iter, t, lo, hi = state
        f, df = residual_and_slope(t, lb_flat)
        # Shrink the bracket using the sign of the residual at t.
        lo_next = jnp.where(f >= 0., jnp.maximum(lo, t), lo)
        hi_next = jnp.where(f <= 0., jnp.minimum(hi, t), hi)
        t_newton = t - f / df
        escaped = (t_newton <= lo_next) | (t_newton >= hi_next)
        use_bisection = (jnp.abs(df) <= tol) | escaped
        t_next = jnp.where(use_bisection, 0.5 * (lo_next + hi_next), t_newton)
        return n_iter + 1, t_next, lo_next, hi_next

    def keep_searching(state):
        n_iter, t, lo, hi = state
        # A point is still unconverged only if BOTH its residual and its
        # bracket width are at least tol; failing either one means converged.
        residual_large = jnp.abs(tangent_residual(t, lb_flat)) >= tol
        bracket_wide = jnp.abs(hi - lo) >= tol
        unconverged = residual_large & bracket_wide
        # Continue while the iteration counter is at most 100 and at least
        # one point has not converged.
        return jnp.logical_and(n_iter <= 100, jnp.any(unconverged))

    final_state = jax.lax.while_loop(
        keep_searching, refine, (n_start, t_start, bracket_lo, bracket_hi))
    return jnp.reshape(final_state[1], range_lb.shape)
Ejemplo n.º 25
0
def gradient_descent_line_search_step(data, loss_f, model_param, options):
    """Gradient Descent optimization with line search step.

  Args:
    data: A tuple of inputs and labels passed to the loss function.
    loss_f: The loss function that takes in model_param, inputs, and labels.
    model_param: Current model parameters to be passed to loss_f.
    options: A dictionary of optimizer specific hyper-parameters. Recognized
      keys (all optional): 'beta', 'beta_prime', 'step_size', 'verbose',
      'reuse_last_step' and 'bound_step'. The dict is copied, not mutated, and
      is also forwarded to backtracking.

  Returns:
    Updated model parameters and the (copied) options dict; when
    'reuse_last_step' is set, its 'step_size' entry holds the step size that
    the line search selected.
  """
    options = dict(options)
    beta = options.get('beta', 0.9)
    beta_prime = options.get('beta_prime', 1e-4)
    step_size = options.get('step_size', 10000.0)
    verbose = options.get('verbose', False)
    reuse_last_step = options.get('reuse_last_step', False)
    # Optional like every other option; previously a missing key raised
    # KeyError.
    bound_step = options.get('bound_step', False)

    inputs, labels = data[0], data[1]
    loss_with_data_f = lambda param: loss_f(param, inputs, labels)
    value_and_grad_f = jax.value_and_grad(loss_with_data_f)
    value, grad = value_and_grad_f(model_param)

    # Maximum learning rate allowed from Theorem 5 in Gunasekar et al. 2017
    if bound_step:
        # Bound by dual of L2
        b_const = jnp.max(jnp.linalg.norm(inputs, ord=2, axis=0))
        step_size = min(step_size, 1 / (b_const * b_const * value))

    grad, unravel_fn = ravel_pytree(grad)
    x, unravel_fn = ravel_pytree(model_param)

    # If we normalize step_size will be harder to tune.
    direction = -grad

    # TODO(fartash): consider using the condition in FISTA
    def next_candidate(step_size):
        next_iter = x + step_size * direction
        next_value, next_grad = value_and_grad_f(unravel_fn(next_iter))
        next_grad, _ = ravel_pytree(next_grad)
        return next_iter, next_value, next_grad

    def stop_cond(step_size, res):
        _, next_value, next_grad = res
        gd = jnp.sum(grad * direction)

        # Strong Wolfe condition.
        cond1 = next_value <= value + beta_prime * step_size * gd
        # NOTE(review): gd = -||grad||^2 <= 0 while the left-hand side is a
        # sum of absolute values, so cond2 always holds and only the
        # sufficient-decrease test (cond1) is effective. Left unchanged
        # because backtracking only shrinks the step.
        cond2 = jnp.sum(jnp.abs(next_grad * direction)) >= beta * gd
        return cond1 and cond2

    step_size, res = backtracking(next_candidate,
                                  stop_cond,
                                  step_size,
                                  options=options)
    next_param = res[0]

    if reuse_last_step:
        options['step_size'] = step_size
    if verbose:
        logging.info('Step size: %f', step_size)

    return unravel_fn(next_param), options
Ejemplo n.º 26
0
def update(params, batch, opt_state):
    """Apply one optimizer update computed from ``batch``.

    Returns:
        ``(new_params, new_opt_state, loss_value)`` where the loss is the
        value measured at the incoming ``params``.
    """
    loss_and_grads = value_and_grad(loss)
    loss_value, gradients = loss_and_grads(params, batch)
    # The step index passed to opt_update is always 0 here.
    new_state = opt_update(0, gradients, opt_state)
    return get_params(new_state), new_state, loss_value
Ejemplo n.º 27
0
import copt.penalty

# Regression data from datasets.make_regression() with its default arguments;
# X is two-dimensional, (n_samples, n_features).
X, y = datasets.make_regression()
n_samples, n_features = X.shape


def loss(w):
    """Sum of squared residuals of ``X @ w`` against ``y``, divided by n_samples."""
    residual = np.dot(X, w) - y
    squared = residual * residual
    return np.sum(squared) / n_samples


# .. use JAX's value_and_grad to differentiate the loss: the resulting ..
# .. function returns both the objective value and its gradient, ..
# .. which is the format that COPT accepts (see jac=True below) ..
f_grad = jax.value_and_grad(loss)

# .. start the optimization from the all-zeros vector ..
w0 = onp.zeros(n_features)

# .. L1 penalty object; its prox is handed to the solver below ..
l1_ball = copt.penalty.L1Norm(0.1)
# .. callback tracing the penalized objective (loss + L1 term) ..
cb = cp.utils.Trace(lambda x: loss(x) + l1_ball(x))
sol = cp.minimize_proximal_gradient(f_grad,
                                    w0,
                                    prox=l1_ball.prox,
                                    callback=cb,
                                    jac=True)
# .. plot the traced objective values on a log scale ..
plt.plot(cb.trace_fx, lw=3)
plt.yscale("log")
plt.xlabel("# Iterations")
plt.ylabel("Objective value")
plt.grid()
Ejemplo n.º 28
0
 def __post_init__(self):
     """Cache the value-and-gradient version of ``self.fun``.

     ``self.has_aux`` is forwarded so that auxiliary outputs of ``self.fun``
     are passed through by the differentiated function.
     """
     differentiate = jax.value_and_grad
     self._value_and_grad_fun = differentiate(self.fun, has_aux=self.has_aux)
def train(target_network, optimizer, states, actions, next_states, rewards,
          terminals, num_tau_samples, num_tau_prime_samples,
          num_quantile_samples, cumulative_gamma, double_dqn, kappa, rng,
          mico_weight, distance_fn, tau, alpha, clip_value_min):
  """Run one training step combining a quantile Huber loss and a metric loss.

  The loss is ``(1 - mico_weight) * quantile_huber_loss + mico_weight *
  metric_loss``. Targets come from ``target_quantile_values`` when ``tau`` is
  None and from ``munchausen_target_quantile_values`` otherwise.

  Returns:
    ``(rng, optimizer, loss, quantile_loss, metric_loss)`` with the advanced
    rng and the optimizer after applying the gradient.
  """
  def loss_fn(model, rng_input, target_quantile_vals, target_r, target_next_r):
    def run_model(m, x, y, z):
      return m(x=x, num_quantiles=y, rng=z)

    # Only the states are mapped over; model, quantile count and rng are shared.
    batched_model = jax.vmap(run_model, in_axes=(None, 0, None, None))
    model_output = batched_model(model, states, num_tau_samples, rng_input)
    representations = jnp.squeeze(model_output.representation)

    def pick_action(x, y):
      return x[:, y][:, None]

    chosen_action_quantile_values = jax.vmap(pick_action)(
        model_output.quantile_values, actions)
    # Shape of bellman_errors and huber_loss:
    # batch_size x num_tau_prime_samples x num_tau_samples x 1.
    bellman_errors = (target_quantile_vals[:, :, None, :] -
                      chosen_action_quantile_values[:, None, :, :])
    abs_errors = jnp.abs(bellman_errors)
    # The huber loss (see Section 2.3 of the paper) has a quadratic branch for
    # |bellman_errors| <= kappa and a linear branch for |bellman_errors| > kappa.
    quadratic_branch = (
        (abs_errors <= kappa).astype(jnp.float32) *
        0.5 * bellman_errors ** 2)
    linear_branch = (
        (abs_errors > kappa).astype(jnp.float32) *
        kappa * (abs_errors - 0.5 * kappa))
    huber_loss = quadratic_branch + linear_branch
    # Tile the quantiles along a new num_tau_prime_samples dimension so they
    # line up with huber_loss:
    # batch_size x num_tau_prime_samples x num_tau_samples x 1.
    quantiles = jnp.tile(model_output.quantiles[:, None, :, :],
                         [1, num_tau_prime_samples, 1, 1]).astype(jnp.float32)
    negative_error = jax.lax.stop_gradient(
        (bellman_errors < 0).astype(jnp.float32))
    # Shape: batch_size x num_tau_prime_samples x num_tau_samples x 1.
    quantile_huber_loss = (jnp.abs(quantiles - negative_error) *
                           huber_loss) / kappa
    # Sum over the current quantile (num_tau_samples) dimension, then average
    # over the target quantile (num_tau_prime_samples) dimension.
    quantile_huber_loss = jnp.mean(jnp.sum(quantile_huber_loss, axis=2),
                                   axis=1)
    online_dist = metric_utils.representation_distances(
        representations, target_r, distance_fn)
    target_dist = metric_utils.target_distances(
        target_next_r, rewards, distance_fn, cumulative_gamma)
    metric_loss = jnp.mean(
        jax.vmap(losses.huber_loss)(online_dist, target_dist))
    loss = ((1. - mico_weight) * quantile_huber_loss +
            mico_weight * metric_loss)
    return jnp.mean(loss), (jnp.mean(quantile_huber_loss), metric_loss)

  if tau is None:
    targets = target_quantile_values(
        optimizer.target,
        target_network,
        states,
        next_states,
        rewards,
        terminals,
        num_tau_prime_samples,
        num_quantile_samples,
        cumulative_gamma,
        double_dqn,
        rng)
  else:
    targets = munchausen_target_quantile_values(
        target_network,
        states,
        actions,
        next_states,
        rewards,
        terminals,
        num_tau_prime_samples,
        num_quantile_samples,
        cumulative_gamma,
        rng,
        tau,
        alpha,
        clip_value_min)
  rng, target_quantile_vals, target_r, target_next_r = targets

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  rng, rng_input = jax.random.split(rng)
  (loss, (quantile_loss, metric_loss)), grad = grad_fn(
      optimizer.target, rng_input, target_quantile_vals, target_r,
      target_next_r)
  optimizer = optimizer.apply_gradient(grad)
  return rng, optimizer, loss, quantile_loss, metric_loss
Ejemplo n.º 30
0
def main():
    """Train a toy embedding model on word-context pairs from Reuters.

    Loads the vocabulary from 'reuters_vocab.pkl', sanity-checks the Embedding
    model on a single random one-hot word, builds (word, context) id pairs
    with a window of 2 from nltk's Reuters corpus, and runs one epoch of
    mini-batch gradient descent on 1% of the pairs. A pdb breakpoint is
    entered after training so the result can be inspected.
    """
    # NOTE: pickle.load can execute arbitrary code; only use a trusted vocab
    # file here.
    with open('reuters_vocab.pkl', 'rb') as f:
        vocab = pickle.load(f)

    v_dim = len(vocab['num_to_word'])
    e_dim = 1024

    prng_key = random.PRNGKey(0xdeadbeef)
    words = jnp.zeros(v_dim, dtype=jnp.float32)
    word_ix = random.uniform(prng_key, (1, )) * v_dim
    word_ix = int(jnp.floor(word_ix)[0])
    words = jops.index_update(words, word_ix, 1.0)

    # first create the architecture of the model
    mdl = Embedding(v_dim, e_dim)
    # then complete the model spec by giving an example input tensor
    params = mdl.init(prng_key, words)
    # now apply the params to the model with the input
    out = mdl.apply(params, words)
    print(f'out: {out}')
    print(f'shape: {out.shape}')

    # let's train the model on nltk's reuters dataset
    from nltk.corpus import reuters
    train_texts = []
    for fname in reuters.fileids():
        text = reuters.words(fname)
        train_texts.append(text)

    # now generate word-context elements
    window_size = 2
    word_pairs = []
    for tokens in tqdm(train_texts, desc='make train set'):
        for token_ix, word in enumerate(tokens):
            for offset in range(1, window_size + 1):
                back_context = token_ix - offset
                if back_context >= 0:
                    word_pairs.append((word, tokens[back_context]))
                fwd_context = token_ix + offset
                if fwd_context < len(tokens):
                    word_pairs.append((word, tokens[fwd_context]))

    # convert words to vocab IDs, dropping pairs with out-of-vocabulary words
    w2n = vocab['word_to_num']

    id_pairs = []
    for word_pair in tqdm(word_pairs, desc='gen word pairs'):
        word = word_pair[0]
        context = word_pair[1]
        if word in w2n and context in w2n:
            w_id, c_id = w2n[word], w2n[context]
            id_pairs.append((w_id, c_id))
    id_pairs = jnp.array(id_pairs)
    print(f'train pairs: {len(id_pairs)}')

    # run grad desc on the first 1% of the pairs
    id_pairs = id_pairs[0:len(id_pairs) // 100]
    lr = 0.3
    batch_size = 2500

    batches = jnp.split(id_pairs,
                        jnp.arange(batch_size, len(id_pairs), batch_size))

    for epoch in trange(1):
        # TODO: shuffle & batch id_pairs
        # The context manager closes the progress bar at the end of the epoch.
        with trange(len(batches), desc='epoch:--- - loss:------') as pbar:
            for batch in batches:
                x_vals, y_vals = __id_to_one_hot(batch, v_dim)
                loss_fn = nll_loss_fn(mdl, params, x_vals, y_vals)
                grad_fn = jax.value_and_grad(loss_fn)
                loss_val, grad = grad_fn(params)
                params = jax.tree_multimap(lambda old, grad: old - lr * grad,
                                           params, grad)
                pbar.set_description(
                    f'epoch:{epoch:03d} - loss:{loss_val:0.4f}')
                pbar.update()

    # Intentional breakpoint for inspecting the trained params interactively.
    import pdb
    pdb.set_trace()
    print('done!')