def update_fun(_, param_and_state):
  params, opt_state = param_and_state
  updates, next_opt_state = self._opt.update(grad_fn(params), opt_state)
  next_params = optax.apply_updates(params, updates)
  # jax.tree_multimap was removed from JAX; tree_map accepts multiple trees.
  next_params = jax.tree_util.tree_map(
      lambda relax, param: relax.project_params(param),
      relaxations, next_params)
  return next_params, next_opt_state
def step(params, opt_state, batch, labels):
  (loss_val, accuracy), grads = jax.value_and_grad(loss, has_aux=True)(
      params, batch, labels)
  updates, opt_state = opt.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss_val, accuracy
def _compute_update(curr_params, optimizer_state, gradient, optimizer_update_fn):
  model_updates, _ = optimizer_update_fn(gradient, optimizer_state,
                                         params=curr_params)
  new_params = optax.apply_updates(curr_params, model_updates)
  diff = _tree_sub(curr_params, new_params)
  return diff
def update(carry, rng):
  # TODO log the two losses separately?
  opt_state, params = carry
  xs_batch = vmap(sample_random_x)(random.split(rng, config.batch_size))
  batch_loss, g = value_and_grad(loss)(params, xs_batch)
  updates, opt_state = tx.update(g, opt_state)
  params = optax.apply_updates(params, updates)
  return (opt_state, params), batch_loss
def update(params_adv, opt_state, info_states, samp_regrets, iterations, masks,
           total_iterations):
  main_loss, grads = self._adv_grads(params_adv, info_states, samp_regrets,
                                     iterations, masks, total_iterations)
  updates, new_opt_state = self._opt_adv_update(grads, opt_state)
  new_params = optax.apply_updates(params_adv, updates)
  return new_params, new_opt_state, main_loss
def train_step(carry, step):
  params, opt_state = carry
  loss, grads = value_and_grad(objective)(params)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  callback_result = callback(params, step) if callback is not None else None
  return (params, opt_state), (loss, callback_result)
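# train_step above follows the (carry, x) -> (carry, y) contract of
# jax.lax.scan. A hedged sketch of driving it that way; objective, optimizer,
# callback, num_steps, and the initial params/opt_state are assumed to be
# defined in the enclosing scope, with callback returning arrays (or None):
(params, opt_state), (losses, _) = jax.lax.scan(
    train_step, (params, opt_state), jnp.arange(num_steps))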
def _computeUpdate(self, params: hk.Params, target: hk.Params, opt: Any,
                   batch: Batch):
  delta, grad = jax.value_and_grad(self._loss)(params, target, batch)
  updates, state = self.optimizer.update(grad, opt, params)
  params = optax.apply_updates(params, updates)
  return jnp.sqrt(delta), state, params
def sgd_step(state: AgentState, trajectory: buffer.Trajectory) -> AgentState:
  """Performs a step of SGD over a trajectory."""
  gradients = jax.grad(loss_fn)(state.params, trajectory)
  updates, new_opt_state = optimizer.update(gradients, state.opt_state)
  new_params = optax.apply_updates(state.params, updates)
  return AgentState(params=new_params, opt_state=new_opt_state)
def update(state: TrainingState, batch: dataset.Batch) -> TrainingState:
  """Does a step of SGD given inputs & targets."""
  # Unpack the optax GradientTransformation; only its update fn is needed here.
  _, optimizer = make_optimizer()
  # Likewise unpack the hk.Transformed pair; keep only the apply fn.
  _, loss_fn = hk.without_apply_rng(hk.transform(sequence_loss))
  gradients = jax.grad(loss_fn)(state.params, batch)
  updates, new_opt_state = optimizer(gradients, state.opt_state)
  new_params = optax.apply_updates(state.params, updates)
  return TrainingState(params=new_params, opt_state=new_opt_state)
def train_step(params: hk.Params, rng_key: PRNGKey, opt_state: optax.OptState,
               batch: Batch) -> Tuple[hk.Params, optax.OptState]:
  """Single update step to maximize the ELBO."""
  grads = jax.grad(objective_fn)(params, rng_key, batch)
  updates, new_opt_state = optimizer.update(grads, opt_state)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state
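# All of these snippets share the same optax pattern: init the optimizer state
# once, then repeat grad -> optimizer.update -> optax.apply_updates. A minimal
# self-contained example of that pattern (toy model and data, not taken from
# any snippet above):
import jax
import jax.numpy as jnp
import optax

def toy_loss(params, x, y):
  return jnp.mean((params['w'] * x + params['b'] - y) ** 2)

params = {'w': jnp.zeros(()), 'b': jnp.zeros(())}
optimizer = optax.sgd(learning_rate=0.1)
opt_state = optimizer.init(params)
x = jnp.linspace(-1., 1., 32)
y = 3. * x + 1.
for _ in range(100):
  grads = jax.grad(toy_loss)(params, x, y)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)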
def update(self, minibatch):
  images = minibatch['image'].astype(jnp.float32) / 255.
  labels = minibatch['label']
  grads = jax.grad(self.loss_fn)(self.params, images, labels)
  updates, self.opt_state = self.optimizer.update(grads, self.opt_state)
  params = optax.apply_updates(self.params, updates)
  self.avg_params = ema_update(self.avg_params, params)
  self.log(net=self.network, net_params=params, datasets=self.datasets)
def update(rng_key, opt_state, online_params, target_params, transitions):
  """Computes learning update from batch of replay transitions."""
  rng_key, update_key = jax.random.split(rng_key)
  d_loss_d_params = jax.grad(loss_fn)(online_params, target_params,
                                      transitions, update_key)
  updates, new_opt_state = optimizer.update(d_loss_d_params, opt_state)
  new_online_params = optax.apply_updates(online_params, updates)
  return rng_key, new_opt_state, new_online_params
def update(params, opt_state, batch, target, weights,
           rng) -> Tuple[hk.Params, optax.OptState, jnp.ndarray]:
  batch_loss, grads = jax.value_and_grad(model_loss)(params, batch, target,
                                                     weights, rng)
  updates, opt_state = optimizer.update(grads, opt_state)
  new_params = optax.apply_updates(params, updates)
  return new_params, opt_state, batch_loss
def body(carry, i):
  key, state, params = carry
  key, subkey = random.split(key)
  (lp_val, param_grad), _ = estimate_gradient(i, subkey, params)
  neg_param_grad = tree_map(lambda x: -x, param_grad)
  updates, state = optimizer.update(neg_param_grad, state)
  params = optax.apply_updates(params, updates)
  return (key, state, params), (lp_val, ravel_pytree(params)[0])
def body_fn(inputs):
  it, x, _, opt_state = inputs
  grad_x = grad_fn(x)
  updates, opt_state = opt.update(grad_x, opt_state, x)
  x = optax.apply_updates(x, updates)
  x = jnp.clip(x, l, u)
  it = it + 1
  return it, x, grad_x, opt_state
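# body_fn above carries (iteration, x, last_grad, opt_state) and returns the
# same structure, which matches jax.lax.while_loop. A hedged sketch; max_steps,
# x0, grad_fn, opt, and the clip bounds l/u are assumed defined elsewhere:
init_val = (0, x0, jnp.zeros_like(x0), opt.init(x0))
_, x_final, _, _ = jax.lax.while_loop(
    lambda inputs: inputs[0] < max_steps, body_fn, init_val)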
def sgd_step(state: TrainingState,
             trajectory: sequence.Trajectory) -> TrainingState:
  """Does a step of SGD over a trajectory."""
  gradients = jax.grad(loss_fn)(state.params, trajectory)
  updates, new_opt_state = optimizer.update(gradients, state.opt_state)
  new_params = optax.apply_updates(state.params, updates)
  return TrainingState(params=new_params, opt_state=new_opt_state)
def update(params_policy, opt_state, info_states, action_probs, iterations,
           masks, total_iterations):
  main_loss, grads = self._policy_grads(params_policy, info_states,
                                        action_probs, iterations, masks,
                                        total_iterations)
  updates, new_opt_state = self._opt_policy_update(grads, opt_state)
  new_params = optax.apply_updates(params_policy, updates)
  return new_params, new_opt_state, main_loss
def train(network_def, online_params, target_params, optimizer,
          optimizer_state, states, actions, next_states, rewards, terminals,
          kappa, num_atoms, cumulative_gamma, mico_weight, distance_fn):
  """Run a training step."""

  def loss_fn(params, bellman_target, target_r, target_next_r):
    def q_online(state):
      return network_def.apply(params, state)

    model_output = jax.vmap(q_online)(states)
    logits = model_output.logits
    logits = jnp.squeeze(logits)
    representations = model_output.representation
    representations = jnp.squeeze(representations)
    # Fetch the logits for its selected action. We use vmap to perform this
    # indexing across the batch.
    chosen_action_logits = jax.vmap(lambda x, y: x[y])(logits, actions)
    bellman_errors = (bellman_target[:, None, :] -
                      chosen_action_logits[:, :, None])  # Input `u' of Eq. 9.
    # Eq. 9 of paper.
    huber_loss = (
        (jnp.abs(bellman_errors) <= kappa).astype(jnp.float32) *
        0.5 * bellman_errors**2 +
        (jnp.abs(bellman_errors) > kappa).astype(jnp.float32) *
        kappa * (jnp.abs(bellman_errors) - 0.5 * kappa))
    tau_hat = ((jnp.arange(num_atoms, dtype=jnp.float32) + 0.5) /
               num_atoms)  # Quantile midpoints. See Lemma 2 of paper.
    # Eq. 10 of paper.
    tau_bellman_diff = jnp.abs(
        tau_hat[None, :, None] - (bellman_errors < 0).astype(jnp.float32))
    quantile_huber_loss = tau_bellman_diff * huber_loss
    # Sum over tau dimension, average over target value dimension.
    quantile_loss = jnp.sum(jnp.mean(quantile_huber_loss, 2), 1)
    online_dist = metric_utils.representation_distances(
        representations, target_r, distance_fn)
    target_dist = metric_utils.target_distances(
        target_next_r, rewards, distance_fn, cumulative_gamma)
    metric_loss = jnp.mean(
        jax.vmap(losses.huber_loss)(online_dist, target_dist))
    loss = ((1. - mico_weight) * quantile_loss + mico_weight * metric_loss)
    return jnp.mean(loss), (loss, jnp.mean(quantile_loss), metric_loss)

  def q_target(state):
    return network_def.apply(target_params, state)

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  bellman_target, target_r, target_next_r = target_distribution(
      q_target, states, next_states, rewards, terminals, cumulative_gamma)
  all_losses, grad = grad_fn(online_params, bellman_target, target_r,
                             target_next_r)
  mean_loss, component_losses = all_losses
  loss, quantile_loss, metric_loss = component_losses
  updates, optimizer_state = optimizer.update(grad, optimizer_state)
  online_params = optax.apply_updates(online_params, updates)
  return (optimizer_state, online_params, loss, mean_loss, quantile_loss,
          metric_loss)
def sgd_step(
    state: TrainingState,
    data: Tuple[types.Transition, types.Transition]
) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
  replay_transitions, demo_transitions = data
  key, key_loss = jax.random.split(state.key)
  compute_losses_with_input = functools.partial(
      compute_losses,
      replay_o_tm1=replay_transitions.observation,
      replay_a_tm1=replay_transitions.action,
      replay_o_t=replay_transitions.next_observation,
      demo_o_tm1=demo_transitions.observation,
      demo_a_tm1=demo_transitions.action,
      demo_o_t=demo_transitions.next_observation,
      key=key_loss)
  (policy_loss_value, nu_loss_value), vjpfun = jax.vjp(
      compute_losses_with_input, state.policy_params, state.nu_params)
  policy_gradients, _ = vjpfun((1.0, 0.0))
  _, nu_gradients = vjpfun((0.0, 1.0))

  # Update optimizers.
  policy_update, policy_optimizer_state = policy_optimizer.update(
      policy_gradients, state.policy_optimizer_state)
  policy_params = optax.apply_updates(state.policy_params, policy_update)
  nu_update, nu_optimizer_state = nu_optimizer.update(
      nu_gradients, state.nu_optimizer_state)
  nu_params = optax.apply_updates(state.nu_params, nu_update)
  new_state = TrainingState(
      policy_optimizer_state=policy_optimizer_state,
      policy_params=policy_params,
      nu_optimizer_state=nu_optimizer_state,
      nu_params=nu_params,
      key=key,
      steps=state.steps + 1,
  )
  metrics = {
      'policy_loss': policy_loss_value,
      'nu_loss': nu_loss_value,
  }
  return new_state, metrics
def update(params, opt_state, x, y_true):
  # calc grads; summed across devices
  loss, grads = value_and_grad(mean_cross_entropy)(params, x, y_true)
  grads = tree_map(lambda v: psum(v, 'device'), grads)
  # apply update
  updates, opt_state = opt.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  # return new states & mean loss
  return params, opt_state, loss.mean()
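# update above sums gradients with psum over an axis named 'device', so it is
# written to run under jax.pmap with that axis_name. A hedged sketch; params
# and opt_state are assumed to be replicated across devices, and x_shards /
# y_shards to carry a leading device axis:
p_update = jax.pmap(update, axis_name='device')
params, opt_state, loss = p_update(params, opt_state, x_shards, y_shards)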
def discriminator_step(trainer, state, rng, images):
  loss_fn = ft.partial(discriminator_loss, trainer.model, rng)
  val, grads = jax.value_and_grad(loss_fn)(state.model, images)
  # Apply the discriminator's optimizer (not the generator's) to its gradients.
  update_d, opt_state_d = trainer.optim.d.update(grads.d, state.optim.d)
  state_d = optax.apply_updates(state.model.d, update_d)
  model = GAN(state.model.g, state_d, state.model.s)
  optim = GAN(state.optim.g, opt_state_d, state.optim.s)
  return val, Trainer(model, optim)
def update(params: hk.Params, opt_state, x, l):
  grads = jax.grad(loss_fn)(params, x, l)
  updates, opt_state = opt.update(grads, opt_state)
  new_params = optax.apply_updates(params, updates)
  return new_params, opt_state
def update(
    self,
    params: hk.Params,
    opt_state: optax.OptState,
    trajs: Transition,
) -> Tuple[hk.Params, optax.OptState]:
  g = jax.grad(self._agent.loss)(params, trajs)
  updates, new_opt_state = self._opt_update(g, opt_state)
  return optax.apply_updates(params, updates), new_opt_state
def update(param_dict, opt_state, signal):
  grads = jax.grad(loss_fn)(param_dict, signal)
  updates, opt_state = opt.update(grads, opt_state)
  new_params = optax.apply_updates(param_dict, updates)
  return new_params, opt_state
def train(*, data_folder, batch_size, epochs, learning_rate, weight_decay,
          seed, max_norm, text_vocab, text_dim, text_depth, text_heads,
          audio_dim, audio_depth, audio_heads):
  # rng
  rng_key = random.PRNGKey(seed)

  # data
  dataset = PairTextSpectrogramDataset(data_folder)
  dl = DataLoader(dataset, batch_size=batch_size,
                  collate_fn=pair_text_spectrogram_dataset_collate_fn,
                  drop_last=True, shuffle=True)

  # model
  model = CLAP(text_vocab=text_vocab, text_dim=text_dim, text_depth=text_depth,
               text_heads=text_heads, audio_dim=audio_dim,
               audio_depth=audio_depth, audio_heads=audio_heads)

  # optimizer
  exclude_bias = lambda params: tree_util.tree_map(lambda x: x.ndim != 1, params)
  optim = chain(clip_by_global_norm(max_norm),
                scale_by_adam(eps=1e-4),
                add_decayed_weights(weight_decay, exclude_bias),
                scale(-learning_rate))

  # init
  audio, audio_mask, text, text_mask = next(iter(dl))
  params = model.init(rng_key, text, audio, text_mask, audio_mask)
  optim_state = optim.init(params)

  # loss function, for use with value_and_grad
  @jit
  @value_and_grad
  def loss_fn(params, text, audio, text_mask, audio_mask):
    return model.apply(params, text, audio, text_mask, audio_mask)

  # train loop
  for _ in range(epochs):
    for audio, audio_mask, text, text_mask in dl:
      loss, grads = loss_fn(params, text, audio, text_mask, audio_mask)
      updates, optim_state = optim.update(grads, optim_state, params)
      params = apply_updates(params, updates)
      print(f'loss: {loss}')
def learner_step(self, params, data, learner_state, unused_key):
  opt_state, pop_art_state = learner_state
  dloss_dtheta, pop_art_state = jax.grad(self._loss, has_aux=True)(
      params, pop_art_state, *data)
  updates, opt_state = self._optimizer.update(dloss_dtheta, opt_state)
  params = optax.apply_updates(params, updates)
  return params, (opt_state, pop_art_state)
def minimize(
    self, x: chex.Array, state: optax.OptState
) -> Tuple[chex.Array, chex.Array, optax.OptState]:
  """Performs a single minimization step."""
  g, loss = gradients_fn(self._loss_fn, x)
  if g is None:
    raise ValueError('loss_fn does not depend on input.')
  updates, state = self._gradient_transformation.update(g, state, x)
  return optax.apply_updates(x, updates), loss, state
def body_fn(it, inputs):
  del it  # unused
  x, prng_in, opt_state = inputs
  prng_out, prng_used = jax.random.split(prng_in)
  grad_x = grad_fn(x, prng_used)
  updates, opt_state = opt.update(grad_x, opt_state, x)
  x = optax.apply_updates(x, updates)
  x = projection_fn(x)
  return x, prng_out, opt_state
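# body_fn(it, inputs) above matches jax.lax.fori_loop's body signature. A
# hedged sketch; num_steps, x0, grad_fn, projection_fn, and opt are assumed to
# be defined in the enclosing scope:
x_final, _, _ = jax.lax.fori_loop(
    0, num_steps, body_fn, (x0, jax.random.PRNGKey(0), opt.init(x0)))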
def update(params, state, sn_state, rng_key, opt_state, batch):
  (loss, state), grads = jax.value_and_grad(loss_fn, has_aux=True)(
      params, state, rng_key, batch)
  updates, new_opt_state = optimizer.update(grads, opt_state)
  new_params = optax.apply_updates(params, updates)
  if FLAGS.spectral_norm > 0:
    new_params, new_sn_state = sn_fn.apply(None, sn_state, None, new_params)
  else:
    new_sn_state = sn_state
  return loss, new_params, state, new_sn_state, new_opt_state
def test_add_decayed_weights(self):
  """Test no mask gets added for add_decayed_weights."""
  tx_no_mask = from_hparams(
      ml_collections.ConfigDict({
          '0': {
              'element': 'nesterov',
              'hps': {
                  'one_minus_decay': 0.1,
              }
          },
          '1': {
              'element': 'add_decayed_weights',
              'hps': {
                  'weight_decay': 1e-4
              }
          }
      }))
  tx_none_mask = from_hparams(
      ml_collections.ConfigDict({
          '0': {
              'element': 'nesterov',
              'hps': {
                  'one_minus_decay': 0.1,
              }
          },
          '1': {
              'element': 'add_decayed_weights',
              'hps': {
                  'weight_decay': 1e-4,
                  'mask': None
              }
          }
      }))
  params = {'a': 1.}
  state = tx_no_mask.init(params)
  updates, state = tx_no_mask.update(params, state, params)
  result_no_mask = optax.apply_updates(params, updates)
  state = tx_none_mask.init(params)
  updates, state = tx_none_mask.update(params, state, params)
  result_none_mask = optax.apply_updates(params, updates)
  chex.assert_trees_all_equal(result_no_mask, result_none_mask)