Example #1
 def update_fun(_, param_and_state):
   params, opt_state = param_and_state
   updates, next_opt_state = self._opt.update(grad_fn(params), opt_state)
   next_params = optax.apply_updates(params, updates)
   next_params = jax.tree_multimap(
       lambda relax, param: relax.project_params(param),
       relaxations, next_params)
   return next_params, next_opt_state
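Note: jax.tree_multimap has since been deprecated and removed from JAX; jax.tree_util.tree_map accepts multiple trees and expresses the same per-leaf projection. A minimal sketch of the equivalent call, assuming the same relaxations and next_params pytrees as above:

import jax

# Same projection step, written with the multi-tree form of tree_map.
next_params = jax.tree_util.tree_map(
    lambda relax, param: relax.project_params(param),
    relaxations, next_params)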
Example #2
 def step(params, opt_state, batch, labels):
     (loss_val, accuracy), grads = jax.value_and_grad(loss,
                                                      has_aux=True)(params,
                                                                    batch,
                                                                    labels)
     updates, opt_state = opt.update(grads, opt_state, params)
     params = optax.apply_updates(params, updates)
     return params, opt_state, loss_val, accuracy
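jax.value_and_grad(loss, has_aux=True) requires the loss to return a (value, aux) pair, which is why the result above unpacks into (loss_val, accuracy) alongside grads. A minimal sketch of a compatible loss, assuming a hypothetical haiku-style network net and integer class labels:

import jax.numpy as jnp
import optax

def loss(params, batch, labels):
    logits = net.apply(params, batch)  # `net` is an assumed model, not shown above
    loss_val = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    return loss_val, accuracy  # (value, aux), as required by has_aux=True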
Example #3
def _compute_update(curr_params, optimizer_state, gradient,
                    optimizer_update_fn):
    model_updates, _ = optimizer_update_fn(gradient,
                                           optimizer_state,
                                           params=curr_params)
    new_params = optax.apply_updates(curr_params, model_updates)
    diff = _tree_sub(curr_params, new_params)
    return diff
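_tree_sub is not shown here; judging by its use, it takes the leaf-wise difference of two pytrees so the function returns the step actually applied to the parameters. A minimal sketch of such a helper:

import jax

def _tree_sub(tree_a, tree_b):
    # Leaf-wise difference of two pytrees with identical structure.
    return jax.tree_util.tree_map(lambda a, b: a - b, tree_a, tree_b)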
Example #4
 def update(carry, rng):
     # TODO log the two losses separately?
     opt_state, params = carry
     xs_batch = vmap(sample_random_x)(random.split(rng, config.batch_size))
     batch_loss, g = value_and_grad(loss)(params, xs_batch)
     updates, opt_state = tx.update(g, opt_state)
     params = optax.apply_updates(params, updates)
     return (opt_state, params), batch_loss
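The (carry, rng) signature is the step-function shape expected by jax.lax.scan. A sketch of a possible driver loop, where num_steps is a hypothetical step count and rng, tx, params and opt_state come from the surrounding code:

import jax

rngs = jax.random.split(rng, num_steps)  # one key per update step
(opt_state, params), batch_losses = jax.lax.scan(
    update, (opt_state, params), rngs)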
Example #5
 def update(params_adv, opt_state, info_states, samp_regrets,
            iterations, masks, total_iterations):
     main_loss, grads = self._adv_grads(params_adv, info_states,
                                        samp_regrets, iterations, masks,
                                        total_iterations)
     updates, new_opt_state = self._opt_adv_update(grads, opt_state)
     new_params = optax.apply_updates(params_adv, updates)
     return new_params, new_opt_state, main_loss
Example #6
 def train_step(carry, step):
     params, opt_state = carry
     loss, grads = value_and_grad(objective)(params)
     updates, opt_state = optimizer.update(grads, opt_state)
     params = optax.apply_updates(params, updates)
     callback_result = callback(params,
                                step) if callback is not None else None
     return (params, opt_state), (loss, callback_result)
Example #7
    def _computeUpdate(self, params: hk.Params, target: hk.Params, opt: Any,
                       batch: Batch):
        delta, grad = jax.value_and_grad(self._loss)(params, target, batch)

        updates, state = self.optimizer.update(grad, opt, params)
        params = optax.apply_updates(params, updates)

        return jnp.sqrt(delta), state, params
Example #8
 def sgd_step(state: AgentState,
              trajectory: buffer.Trajectory) -> AgentState:
     """Performs a step of SGD over a trajectory."""
     gradients = jax.grad(loss_fn)(state.params, trajectory)
     updates, new_opt_state = optimizer.update(gradients,
                                               state.opt_state)
     new_params = optax.apply_updates(state.params, updates)
     return AgentState(params=new_params, opt_state=new_opt_state)
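AgentState is not defined in this snippet; from its use it is a small container holding the network parameters and optimizer state. A plausible sketch, with the field types assumed:

from typing import Any, NamedTuple
import optax

class AgentState(NamedTuple):
    # Fields inferred from the usage in sgd_step above.
    params: Any               # network parameters (a pytree)
    opt_state: optax.OptState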
Example #9
def update(state: TrainingState, batch: dataset.Batch) -> TrainingState:
    """Does a step of SGD given inputs & targets."""
    # make_optimizer() presumably returns an optax (init, update) pair and
    # hk.transform an (init, apply) pair; only the second element of each is used.
    _, optimizer = make_optimizer()
    _, loss_fn = hk.without_apply_rng(hk.transform(sequence_loss))
    gradients = jax.grad(loss_fn)(state.params, batch)
    updates, new_opt_state = optimizer(gradients, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)
    return TrainingState(params=new_params, opt_state=new_opt_state)
Example #10
 def train_step(params: hk.Params, rng_key: PRNGKey,
                opt_state: optax.OptState,
                batch: Batch) -> Tuple[hk.Params, optax.OptState]:
     """Single update step to maximize the ELBO."""
     grads = jax.grad(objective_fn)(params, rng_key, batch)
     updates, new_opt_state = optimizer.update(grads, opt_state)
     new_params = optax.apply_updates(params, updates)
     return new_params, new_opt_state
Example #11
 def update(self, minibatch):
     images = minibatch['image'].astype(jnp.float32) / 255.
     labels = minibatch['label']
     grads = jax.grad(self.loss_fn)(self.params, images, labels)
     updates, self.opt_state = self.optimizer.update(grads, self.opt_state)
     params = optax.apply_updates(self.params, updates)
     self.avg_params = ema_update(self.avg_params, params)
     self.log(net=self.network, net_params=params, datasets=self.datasets)
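ema_update is not shown; it presumably blends the freshly computed parameters into a running average of parameters. A minimal sketch, with the decay rate ema_decay being an assumed hyperparameter:

import jax

def ema_update(avg_params, new_params, ema_decay=0.999):
    # Exponential moving average of the parameters, leaf by leaf.
    return jax.tree_util.tree_map(
        lambda avg, new: ema_decay * avg + (1.0 - ema_decay) * new,
        avg_params, new_params)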
Example #12
 def update(rng_key, opt_state, online_params, target_params, transitions):
   """Computes learning update from batch of replay transitions."""
   rng_key, update_key = jax.random.split(rng_key)
   d_loss_d_params = jax.grad(loss_fn)(online_params, target_params,
                                       transitions, update_key)
   updates, new_opt_state = optimizer.update(d_loss_d_params, opt_state)
   new_online_params = optax.apply_updates(online_params, updates)
   return rng_key, new_opt_state, new_online_params
Example #13
 def update(params, opt_state, batch, target, weights,
            rng) -> Tuple[hk.Params, optax.OptState, jnp.ndarray]:
     batch_loss, grads = jax.value_and_grad(model_loss)(params, batch,
                                                        target, weights,
                                                        rng)
     updates, opt_state = optimizer.update(grads, opt_state)
     new_params = optax.apply_updates(params, updates)
     return new_params, opt_state, batch_loss
Example #14
 def body(carry, i):
     key, state, params = carry
     key, subkey = random.split(key)
     (lp_val, param_grad), _ = estimate_gradient(i, subkey, params)
     neg_param_grad = tree_map(lambda x: -x, param_grad)
     updates, state = optimizer.update(neg_param_grad, state)
     params = optax.apply_updates(params, updates)
     return (key, state, params), (lp_val, ravel_pytree(params)[0])
Example #15
 def body_fn(inputs):
     it, x, _, opt_state = inputs
     grad_x = grad_fn(x)
     updates, opt_state = opt.update(grad_x, opt_state, x)
     x = optax.apply_updates(x, updates)
     x = jnp.clip(x, l, u)
     it = it + 1
     return it, x, grad_x, opt_state
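body_fn takes and returns the same (it, x, grad_x, opt_state) tuple, which is the carry shape expected by jax.lax.while_loop. A sketch of a possible driver, where x0 is a hypothetical starting point and max_iter a hypothetical iteration cap; grad_fn, opt, l and u come from the surrounding code:

import jax
import jax.numpy as jnp

def cond_fn(inputs):
    it, _, _, _ = inputs
    return it < max_iter  # run a fixed number of clipped gradient steps

init = (0, x0, jnp.zeros_like(x0), opt.init(x0))
_, x_final, _, _ = jax.lax.while_loop(cond_fn, body_fn, init)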
Example #16
 def sgd_step(state: TrainingState,
              trajectory: sequence.Trajectory) -> TrainingState:
     """Does a step of SGD over a trajectory."""
     gradients = jax.grad(loss_fn)(state.params, trajectory)
     updates, new_opt_state = optimizer.update(gradients,
                                               state.opt_state)
     new_params = optax.apply_updates(state.params, updates)
     return TrainingState(params=new_params, opt_state=new_opt_state)
Example #17
 def update(params_policy, opt_state, info_states, action_probs, iterations,
            masks, total_iterations):
   main_loss, grads = self._policy_grads(params_policy, info_states,
                                         action_probs, iterations, masks,
                                         total_iterations)
   updates, new_opt_state = self._opt_policy_update(grads, opt_state)
   new_params = optax.apply_updates(params_policy, updates)
   return new_params, new_opt_state, main_loss
Example #18
def train(network_def, online_params, target_params, optimizer,
          optimizer_state, states, actions, next_states, rewards, terminals,
          kappa, num_atoms, cumulative_gamma, mico_weight, distance_fn):
    """Run a training step."""
    def loss_fn(params, bellman_target, target_r, target_next_r):
        def q_online(state):
            return network_def.apply(params, state)

        model_output = jax.vmap(q_online)(states)
        logits = model_output.logits
        logits = jnp.squeeze(logits)
        representations = model_output.representation
        representations = jnp.squeeze(representations)
        # Fetch the logits for its selected action. We use vmap to perform this
        # indexing across the batch.
        chosen_action_logits = jax.vmap(lambda x, y: x[y])(logits, actions)
        bellman_errors = (bellman_target[:, None, :] -
                          chosen_action_logits[:, :, None]
                          )  # Input `u' of Eq. 9.
        # Eq. 9 of paper.
        huber_loss = (
            (jnp.abs(bellman_errors) <= kappa).astype(jnp.float32) * 0.5 *
            bellman_errors**2 +
            (jnp.abs(bellman_errors) > kappa).astype(jnp.float32) * kappa *
            (jnp.abs(bellman_errors) - 0.5 * kappa))

        tau_hat = ((jnp.arange(num_atoms, dtype=jnp.float32) + 0.5) / num_atoms
                   )  # Quantile midpoints.  See Lemma 2 of paper.
        # Eq. 10 of paper.
        tau_bellman_diff = jnp.abs(tau_hat[None, :, None] -
                                   (bellman_errors < 0).astype(jnp.float32))
        quantile_huber_loss = tau_bellman_diff * huber_loss
        # Sum over tau dimension, average over target value dimension.
        quantile_loss = jnp.sum(jnp.mean(quantile_huber_loss, 2), 1)
        online_dist = metric_utils.representation_distances(
            representations, target_r, distance_fn)
        target_dist = metric_utils.target_distances(target_next_r, rewards,
                                                    distance_fn,
                                                    cumulative_gamma)
        metric_loss = jnp.mean(
            jax.vmap(losses.huber_loss)(online_dist, target_dist))
        loss = ((1. - mico_weight) * quantile_loss + mico_weight * metric_loss)
        return jnp.mean(loss), (loss, jnp.mean(quantile_loss), metric_loss)

    def q_target(state):
        return network_def.apply(target_params, state)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    bellman_target, target_r, target_next_r = target_distribution(
        q_target, states, next_states, rewards, terminals, cumulative_gamma)
    all_losses, grad = grad_fn(online_params, bellman_target, target_r,
                               target_next_r)
    mean_loss, component_losses = all_losses
    loss, quantile_loss, metric_loss = component_losses
    updates, optimizer_state = optimizer.update(grad, optimizer_state)
    online_params = optax.apply_updates(online_params, updates)
    return (optimizer_state, online_params, loss, mean_loss, quantile_loss,
            metric_loss)
Example #19
    def sgd_step(
        state: TrainingState,
        data: Tuple[types.Transition, types.Transition]
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
      replay_transitions, demo_transitions = data
      key, key_loss = jax.random.split(state.key)
      compute_losses_with_input = functools.partial(
          compute_losses,
          replay_o_tm1=replay_transitions.observation,
          replay_a_tm1=replay_transitions.action,
          replay_o_t=replay_transitions.next_observation,
          demo_o_tm1=demo_transitions.observation,
          demo_a_tm1=demo_transitions.action,
          demo_o_t=demo_transitions.next_observation,
          key=key_loss)
      (policy_loss_value, nu_loss_value), vjpfun = jax.vjp(
          compute_losses_with_input,
          state.policy_params, state.nu_params)
      policy_gradients, _ = vjpfun((1.0, 0.0))
      _, nu_gradients = vjpfun((0.0, 1.0))

      # Update optimizers.
      policy_update, policy_optimizer_state = policy_optimizer.update(
          policy_gradients, state.policy_optimizer_state)
      policy_params = optax.apply_updates(state.policy_params, policy_update)

      nu_update, nu_optimizer_state = nu_optimizer.update(
          nu_gradients, state.nu_optimizer_state)
      nu_params = optax.apply_updates(state.nu_params, nu_update)

      new_state = TrainingState(
          policy_optimizer_state=policy_optimizer_state,
          policy_params=policy_params,
          nu_optimizer_state=nu_optimizer_state,
          nu_params=nu_params,
          key=key,
          steps=state.steps + 1,
      )

      metrics = {
          'policy_loss': policy_loss_value,
          'nu_loss': nu_loss_value,
      }

      return new_state, metrics
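Instead of calling jax.grad twice, this step takes a single jax.vjp so the forward pass through compute_losses is shared; pulling back the cotangent (1.0, 0.0) yields the gradient of the policy loss alone, and (0.0, 1.0) the gradient of the nu loss alone. A toy illustration of the same idea:

import jax
import jax.numpy as jnp

def two_losses(a, b):
    return jnp.sum(a ** 2), jnp.sum(a * b)

(l1, l2), vjpfun = jax.vjp(two_losses, jnp.ones(3), jnp.ones(3))
dl1_da, dl1_db = vjpfun((1.0, 0.0))  # gradient of the first loss w.r.t. (a, b)
dl2_da, dl2_db = vjpfun((0.0, 1.0))  # gradient of the second loss w.r.t. (a, b)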
Example #20
def update(params, opt_state, x, y_true):
    # calc grads; summed across devices
    loss, grads = value_and_grad(mean_cross_entropy)(params, x, y_true)
    grads = tree_map(lambda v: psum(v, 'device'), grads)
    # apply update
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    # return new states & mean loss
    return params, opt_state, loss.mean()
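The psum over the 'device' axis implies this update runs under jax.pmap with a matching axis_name, so each device computes gradients on its own shard and applies the summed gradients. A sketch of a possible invocation, with the replicated parameters and sharded batch (x_shards, y_shards) assumed:

import jax

p_update = jax.pmap(update, axis_name='device')
params, opt_state, loss = p_update(params, opt_state, x_shards, y_shards)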
Example #21
def discriminator_step(trainer, state, rng, images):
    loss_fn = ft.partial(discriminator_loss, trainer.model, rng)
    val, grads = jax.value_and_grad(loss_fn)(state.model, images)
    # Use the discriminator's gradient transformation and optimizer state.
    update_d, opt_state_d = trainer.optim.d.update(grads.d, state.optim.d)
    state_d = optax.apply_updates(state.model.d, update_d)

    model = GAN(state.model.g, state_d, state.model.s)
    optim = GAN(state.optim.g, opt_state_d, state.optim.s)
    return val, Trainer(model, optim)
Example #22
def update(
    params: hk.Params,
    opt_state,
    x, l
    ):
    grads = jax.grad(loss_fn)(params, x, l)
    updates, opt_state = opt.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state
Example #23
 def update(
     self,
     params: hk.Params,
     opt_state: optax.OptState,
     trajs: Transition,
 ) -> Tuple[hk.Params, optax.OptState]:
     g = jax.grad(self._agent.loss)(params, trajs)
     updates, new_opt_state = self._opt_update(g, opt_state)
     return optax.apply_updates(params, updates), new_opt_state
Example #24
def update(
    param_dict,
    opt_state,
    signal
    ):
    grads = jax.grad(loss_fn)(param_dict, signal)
    updates, opt_state = opt.update(grads, opt_state)
    new_params = optax.apply_updates(param_dict, updates)
    return new_params, opt_state
Example #25
def train(*, data_folder, batch_size, epochs, learning_rate, weight_decay,
          seed, max_norm, text_vocab, text_dim, text_depth, text_heads,
          audio_dim, audio_depth, audio_heads):
    # rng

    rng_key = random.PRNGKey(seed)

    # data

    dataset = PairTextSpectrogramDataset(data_folder)
    dl = DataLoader(dataset,
                    batch_size=batch_size,
                    collate_fn=pair_text_spectrogram_dataset_collate_fn,
                    drop_last=True,
                    shuffle=True)

    # model

    model = CLAP(text_vocab=text_vocab,
                 text_dim=text_dim,
                 text_depth=text_depth,
                 text_heads=text_heads,
                 audio_dim=audio_dim,
                 audio_depth=audio_depth,
                 audio_heads=audio_heads)

    # optimizer

    exclude_bias = lambda params: tree_util.tree_map(lambda x: x.ndim != 1,
                                                     params)

    optim = chain(clip_by_global_norm(max_norm), scale_by_adam(eps=1e-4),
                  add_decayed_weights(weight_decay, exclude_bias),
                  scale(-learning_rate))

    # init

    audio, audio_mask, text, text_mask = next(iter(dl))

    params = model.init(rng_key, text, audio, text_mask, audio_mask)
    optim_state = optim.init(params)

    # loss function, for use with value_and_grad

    @jit
    @value_and_grad
    def loss_fn(params, text, audio, text_mask, audio_mask):
        return model.apply(params, text, audio, text_mask, audio_mask)

    # train loop

    for _ in range(epochs):
        for audio, audio_mask, text, text_mask in dl:
            loss, grads = loss_fn(params, text, audio, text_mask, audio_mask)
            updates, optim_state = optim.update(grads, optim_state, params)
            params = apply_updates(params, updates)
            print(f'loss: {loss}')
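The mask given to add_decayed_weights may be a callable returning a pytree of booleans matching the parameters; exclude_bias above marks every leaf with ndim != 1, so weight decay is applied to weight matrices (and other higher-rank tensors) but skipped for one-dimensional bias vectors.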
Example #26
 def learner_step(self, params, data, learner_state, unused_key):
     opt_state, pop_art_state = learner_state
     dloss_dtheta, pop_art_state = jax.grad(self._loss,
                                            has_aux=True)(params,
                                                          pop_art_state,
                                                          *data)
     updates, opt_state = self._optimizer.update(dloss_dtheta, opt_state)
     params = optax.apply_updates(params, updates)
     return params, (opt_state, pop_art_state)
Example #27
 def minimize(
     self, x: chex.Array, state: optax.OptState
 ) -> Tuple[chex.Array, chex.Array, optax.OptState]:
     """Performs a single minimization step."""
     g, loss = gradients_fn(self._loss_fn, x)
     if g is None:
         raise ValueError('loss_fn does not depend on input.')
     updates, state = self._gradient_transformation.update(g, state, x)
     return optax.apply_updates(x, updates), loss, state
Example #28
 def body_fn(it, inputs):
     del it  # unused
     x, prng_in, opt_state = inputs
     prng_out, prng_used = jax.random.split(prng_in)
     grad_x = grad_fn(x, prng_used)
     updates, opt_state = opt.update(grad_x, opt_state, x)
     x = optax.apply_updates(x, updates)
     x = projection_fn(x)
     return x, prng_out, opt_state
Example #29
 def update(params, state, sn_state, rng_key, opt_state, batch):
   (loss, state), grads = jax.value_and_grad(
       loss_fn, has_aux=True)(params, state, rng_key, batch)
   updates, new_opt_state = optimizer.update(grads, opt_state)
   new_params = optax.apply_updates(params, updates)
   if FLAGS.spectral_norm > 0:
     new_params, new_sn_state = sn_fn.apply(None, sn_state, None, new_params)
   else:
     new_sn_state = sn_state
   return loss, new_params, state, new_sn_state, new_opt_state
Example #30
    def test_add_decayed_weights(self):
        """Test no mask gets added for add_decayed_weights."""
        tx_no_mask = from_hparams(
            ml_collections.ConfigDict({
                '0': {
                    'element': 'nesterov',
                    'hps': {
                        'one_minus_decay': 0.1,
                    }
                },
                '1': {
                    'element': 'add_decayed_weights',
                    'hps': {
                        'weight_decay': 1e-4
                    }
                }
            }))
        tx_none_mask = from_hparams(
            ml_collections.ConfigDict({
                '0': {
                    'element': 'nesterov',
                    'hps': {
                        'one_minus_decay': 0.1,
                    }
                },
                '1': {
                    'element': 'add_decayed_weights',
                    'hps': {
                        'weight_decay': 1e-4,
                        'mask': None
                    }
                }
            }))

        params = {'a': 1.}
        state = tx_no_mask.init(params)
        updates, state = tx_no_mask.update(params, state, params)
        result_no_mask = optax.apply_updates(params, updates)

        state = tx_none_mask.init(params)
        updates, state = tx_none_mask.update(params, state, params)
        result_none_mask = optax.apply_updates(params, updates)

        chex.assert_trees_all_equal(result_no_mask, result_none_mask)