def testUtilityClipGrads(self):
    """clip_grads is a no-op above the tree norm and rescales to the cap below it."""
    tree = (np.ones(2), (np.ones(3), np.ones(4)))
    tree_norm = optimizers.l2_norm(tree)
    # A cap larger than the current norm must leave the gradients untouched.
    self.assertAllClose(
        optimizers.clip_grads(tree, 1.1 * tree_norm), tree, check_dtypes=False)
    # A cap below the current norm must shrink the tree to exactly that cap.
    clipped_norm = optimizers.l2_norm(
        optimizers.clip_grads(tree, 0.9 * tree_norm))
    self.assertAllClose(clipped_norm, 0.9 * tree_norm, check_dtypes=False)
def update_w_gc(i, opt_state, hps, opt_hps, key, x_bxt, kl_warmup):
    """Update fun for gradients, includes gradient clipping.

    Takes one optimizer step on the LFADS training loss, clipping the
    gradient pytree to the configured max norm first.
    """
    current_params = get_params(opt_state)
    loss_grad_fn = grad(lfads.training_loss_jit)
    raw_grads = loss_grad_fn(current_params, hps, key, x_bxt, kl_warmup,
                             opt_hps['keep_rate'])
    # Bound the global gradient norm before applying the update.
    safe_grads = optimizers.clip_grads(raw_grads, opt_hps['max_grad_norm'])
    return opt_update(i, safe_grads, opt_state)
def fit_mixture(data, num_components=3, verbose=False, num_samples=5000) -> LogisticMixtureParams:
    """Fit a logistic mixture to `data` by SGD on the mixture log-pdf.

    Args:
      data: 1-D collection of observations; anything `np.array` accepts
        (e.g. a pandas DataFrame column).
      num_components: number of mixture components to fit.
      verbose: if True, pretty-print the components and log score every
        500 steps.
      num_samples: number of SGD steps to run.

    Returns:
      The fitted parameters wrapped via `structure_mixture_params`.
    """
    # the data might be something weird, like a pandas dataframe column;
    # turn it into a regular old numpy array
    data_as_np_array = np.array(data)
    step_size = 0.01
    components = initialize_components(num_components)
    (init_fun, update_fun, get_params) = sgd(step_size)
    opt_state = init_fun(components)
    for i in tqdm(range(num_samples)):
        components = get_params(opt_state)
        # Negate the gradient so plain SGD *maximizes* the log-pdf.
        grads = -grad_mixture_logpdf(data_as_np_array, components)
        if np.any(np.isnan(grads)):
            # Bail out rather than poisoning the optimizer state with NaNs;
            # the last finite `components` are returned below.
            print("Encountered nan gradient, stopping early")  # fixed typo: "Encoutered"
            print(grads)
            print(components)
            break
        grads = clip_grads(grads, 1.0)
        opt_state = update_fun(i, grads, opt_state)
        if i % 500 == 0 and verbose:
            pprint(components)
            score = mixture_logpdf(data_as_np_array, components)
            print(f"Log score: {score:.3f}")
    return structure_mixture_params(components)
def update_w_gc(i, opt_state, opt_update, get_params, x_bxt, f_bxt, f_mask_bxt,
                max_grad_norm, l2reg):
    """Update the parameters w/ gradient clipped, gradient descent updates.

    Arguments:
      i: batch number
      opt_state: parameters plus optimizer state
      opt_update: optimizer state update function
      get_params: function to extract parameters from optimizer state
      x_bxt: rnn inputs
      f_bxt: rnn targets
      f_mask_bxt: masks for when target is defined
      max_grad_norm: maximum norm value gradient is allowed to take
      l2reg: l2 regularization hyperparameter

    Returns:
      opt_state tuple (as above) that includes updated parameters and
      optimizer state.
    """

    def _total_loss(p, inputs, targets, reg):
        # Only the 'total' entry of the loss dict is differentiated.
        return loss(p, inputs, targets, f_mask_bxt, reg)['total']

    raw_grads = grad(_total_loss)(get_params(opt_state), x_bxt, f_bxt, l2reg)
    bounded_grads = optimizers.clip_grads(raw_grads, max_grad_norm)
    return opt_update(i, bounded_grads, opt_state)
def compute_grads_and_update(self, batch, env_ids, max_grad_norm, new_rng,
                             train_loss_fn, train_state):
    """One training step: grads, cross-device mean, optional clip, apply.

    Returns the updated TrainState and the metrics collected for this step.
    """
    # Learning rate for this step:
    step_lr = self.get_learning_rate(train_state.global_step)
    # Forward + backward pass:
    grad_fn = jax.value_and_grad(train_loss_fn, has_aux=True)
    (_, (new_model_state, logits, logs)), grads = grad_fn(
        train_state.optimizer.target)
    # Average gradients across devices before updating:
    grads = jax.lax.pmean(grads, axis_name='batch')
    # Optionally bound the global gradient norm:
    if max_grad_norm is not None:
        grads = clip_grads(grads, max_grad_norm)
    updated_optimizer = train_state.optimizer.apply_gradient(
        grads, learning_rate=step_lr)
    # Build the new (updated) train state:
    updated_state = pipeline_utils.TrainState(
        global_step=train_state.global_step + 1,
        optimizer=updated_optimizer,
        model_state=new_model_state,
        rng=new_rng)
    step_metrics = self.collect_metrics(batch, env_ids, logits, logs, step_lr,
                                        train_state.optimizer.target)
    return updated_state, step_metrics
def optimizer_step(current_step, state, batch):
    """Run one gradient-clipped optimizer update and advance the step count."""
    step_loss, raw_grads = jax.value_and_grad(loss_fun)(get_params(state), batch)
    clipped = optimizers.clip_grads(raw_grads, gradient_clip)
    next_state = update_opt(current_step, clipped, state)
    return current_step + 1, next_state, step_loss
def update_w_gc(i, opt_state, opt_update, get_params, x_bxt, f_bxt,
                max_grad_norm, l2reg):
    """Update the parameters w/ gradient clipped, gradient descent updates."""

    def _total_loss(p, inputs, targets, reg):
        # Differentiate only the 'total' entry of the loss dict.
        return loss(p, inputs, targets, reg)['total']

    raw_grads = grad(_total_loss)(get_params(opt_state), x_bxt, f_bxt, l2reg)
    return opt_update(i, optimizers.clip_grads(raw_grads, max_grad_norm),
                      opt_state)
def optimizer_step_clip(current_step, state, batch):
    """Take a single optimization step with globally norm-clipped gradients."""
    params = get_params(state)
    step_loss, grads = jax.value_and_grad(loss_fun)(params, batch)
    grads = optimizers.clip_grads(grads, gradient_clip)
    next_state = update_opt(current_step, grads, state)
    return current_step + 1, next_state, step_loss
def update(batch_idx, __opt_state):
    """Update func for gradients, includes gradient clipping."""
    # Warm-up coefficient depends on the global (epoch-spanning) batch count.
    warmup = kl_warmup_fun(epoch_idx * num_batches + batch_idx)
    # Slice this minibatch out of the epoch's data and cast to float32.
    start = batch_idx * BATCH_SIZE
    minibatch = lax.dynamic_slice_in_dim(
        epoch_data, start, BATCH_SIZE, axis=0).astype(np.float32)
    raw_grads = grad(loss_fn)(get_params(__opt_state), minibatch,
                              next(batch_keys), BATCH_SIZE, ic_prior, VAR_MIN,
                              warmup, L2_REG)
    return opt_update(batch_idx,
                      optimizers.clip_grads(raw_grads, MAX_GRAD_NORM),
                      __opt_state)
def critic_step(
    optimizer: optim.Optimizer,
    state: jnp.ndarray,
    action: jnp.ndarray,
    target_Q: jnp.ndarray,
) -> optim.Optimizer:
    """
    The critic is optimized the same way as typical actor critic methods,
    minimizing the TD error.
    """

    def td_loss(critic_params):
        # Double critic: two Q-heads regressed against the same target.
        q1, q2 = apply_double_critic_model(critic_params, state, action, False)
        return double_mse(q1, q2, target_Q).mean()

    grads = jax.grad(td_loss)(optimizer.target)
    grads = clip_grads(grads, 40.0)
    return optimizer.apply_gradient(grads)
def compute_grads_and_update(self, batch, max_grad_norm, new_rng,
                             train_loss_fn, train_state):
    """Compute grads and updates parameters.

    Args:
      batch: dict; Batch of examples.
      max_grad_norm: float; Max value for grad norm (used for grad clipping).
      new_rng: Jax RNG key.
      train_loss_fn: fn(params)--> loss; Loss function (for which grad is
        computed).
      train_state: TrainState, the state of training including the current
        global_step, model_state, rng, and optimizer.

    Returns:
      Updated state of training and calculated metrics.
    """
    step_lr = self.get_learning_rate(train_state.global_step)
    value_and_grad_fn = jax.value_and_grad(train_loss_fn, has_aux=True)
    (_, (new_model_state, logits)), grads = value_and_grad_fn(
        train_state.optimizer.target)
    # re-use same axis_name as in the call to `pmap(...train_step...)` below
    grads = jax.lax.pmean(grads, axis_name='batch')
    if max_grad_norm is not None:
        grads = clip_grads(grads, max_grad_norm)
    next_optimizer = train_state.optimizer.apply_gradient(
        grads, learning_rate=step_lr)
    next_train_state = train_state.replace(
        global_step=train_state.global_step + 1,
        optimizer=next_optimizer,
        model_state=new_model_state,
        rng=new_rng)
    return next_train_state, self.collect_metrics(batch, logits, step_lr)
split='train' if args.data_augment else 'test', seed=seed * args.seed_separator + i, num_classes=args.num_classes) X, Y = batch['image'], batch['label'] params_curr = get_params(state) loss_curr, grad_curr = value_and_grad_loss(params_curr, X, Y) # monitor gradient norm grad_norm = optimizers.l2_norm(grad_curr) writer.add_scalar(f'grad_norm/{tb_flag}', grad_norm.item(), global_step) if np.isnan(loss_curr): sys.exit() running_loss += loss_curr running_count += 1 if args.grad_norm_thresh > 0: grad_curr = optimizers.clip_grads(grad_curr, args.grad_norm_thresh) state = opt_apply(epoch, grad_curr, state) global_step += 1 print( f"Step {global_step}, training loss={loss_curr:.4f}, grad norm={grad_norm:.4f}" ) # Evaluate on the test set if global_step % args.save_steps == 0 \ or global_step % args.early_save_steps == 0 and global_step <= args.early_save_till_step: test_loader = tfds.as_numpy(test_data.batch(args.batch_size_test)) acc_f, loss_test = 0., 0. acc_g, loss_test_g = 0., 0. params_curr = get_params(state) start_ind = 0 for j, test_batch in enumerate(test_loader):
def correction_step(self) -> Tuple:
    """Given the current state optimize to the correct state.

    Launches several "wall ants" (perturbed starting points projected onto
    a sphere around the predicted state), runs each through a constrained
    descent, and keeps the cheapest one.

    Returns:
      (state: problem parameters, bparam: continuation parameter) Tuple
    """
    quality = 1.0
    if self.hparams["meta"]["dataset"] == "mnist":  # TODO: make it generic
        batch_data = next(self.data_loader)
    else:
        batch_data = None
    # Per-ant bookkeeping; 5.0 is just a large sentinel initial value.
    ants_norm_grads = [5.0 for _ in range(self.hparams["n_wall_ants"])]
    ants_loss_values = [5.0 for _ in range(self.hparams["n_wall_ants"])]
    ants_state = [self._state for _ in range(self.hparams["n_wall_ants"])]
    ants_bparam = [
        self._bparam for _ in range(self.hparams["n_wall_ants"])
    ]
    for i_n in range(self.hparams["n_wall_ants"]):
        corrector_omega = 1.0
        stop = False
        # Fresh PRNG key per ant; npr.randint injects host-side entropy.
        _, key = random.split(
            random.PRNGKey(self.key_state + i_n +
                           npr.randint(1, (i_n + 1) * 10)))
        del _
        # Perturb the predicted state by projection onto the secant sphere.
        self._parc_vec, self.state_stack = self._perform_perturb_by_projection(
            self._state_secant_vector,
            self._state_secant_c2,
            key,
            self.pred_prev_state,
            self._state,
            self._bparam,
            i_n,
            self.sphere_radius,
            batch_data,
        )
        if self.hparams["_evaluate_perturb"]:
            self._evaluate_perturb()  # does every time
        ants_state[i_n] = self.state_stack["state"]
        ants_bparam[i_n] = self.state_stack["bparam"]
        D_values = []  # rolling loss history used for plateau detection
        print(f"num_batches", self.num_batches)
        for j_epoch in range(self.descent_period):
            for b_j in range(self.num_batches):
                #alternate
                # grads = self.compute_grad_fn(self._state, self._bparam, batch_data)
                # self._state = self.opt.update_params(self._state, grads[0])
                # Gradients of the constrained (Lagrangian) objective w.r.t.
                # the problem parameters and the continuation parameter.
                state_grads, bparam_grads = self.compute_min_grad_fn(
                    ants_state[i_n],
                    ants_bparam[i_n],
                    self._lagrange_multiplier,
                    self._state_secant_c2,
                    self._state_secant_vector,
                    batch_data,
                    self.delta_s,
                )
                if self.hparams["adaptive"]:
                    # NOTE(review): the collapsed source is ambiguous here;
                    # assumed only the lr decay is guarded by "adaptive",
                    # matching the sibling correction_step — confirm.
                    self.opt.lr = self.exp_decay(
                        j_epoch, self.hparams["natural_lr"])
                quality = l2_norm(state_grads)  #l2_norm(bparam_grads)
                if self.hparams[
                        "local_test_measure"] == "norm_gradients":
                    # Stop this ant once the gradient norm is small enough.
                    if quality > self.hparams["quality_thresh"]:
                        pass
                        print(
                            f"quality {quality}, {self.opt.lr}, {bparam_grads} ,{j_epoch}"
                        )
                    else:
                        stop = True
                        print(
                            f"quality {quality} stopping at , {j_epoch}th step"
                        )
                else:
                    print(
                        f"quality {quality}, {bparam_grads} ,{j_epoch}"
                    )
                # Plateau detection: stop once the running-mean loss stalls.
                if len(D_values) >= 20:
                    tmp_means = running_mean(D_values, 10)
                    if (math.isclose(
                            tmp_means[-1],
                            tmp_means[-2],
                            abs_tol=self.hparams["loss_tol"])):
                        print(
                            f"stopping at , {j_epoch}th step, {ants_bparam[i_n]} bparam"
                        )
                        stop = True
                state_grads = clip_grads(state_grads,
                                         self.hparams["max_clip_grad"])
                bparam_grads = clip_grads(
                    bparam_grads, self.hparams["max_clip_grad"])
                # Relaxation factor: over-relax early, damp late.
                if self.hparams["guess_ant_steps"] >= (
                        j_epoch + 1):  # To get around folds slowly
                    corrector_omega = min(
                        self.hparams["guess_ant_steps"] / (j_epoch + 1), 1.5)
                else:
                    corrector_omega = max(
                        self.hparams["guess_ant_steps"] / (j_epoch + 1), 0.05)
                ants_state[i_n] = self.opt.update_params(
                    ants_state[i_n], state_grads, j_epoch)
                ants_bparam[i_n] = self.opt.update_params(
                    ants_bparam[i_n], bparam_grads, j_epoch)
                ants_loss_values[i_n] = self.value_fn(
                    ants_state[i_n], ants_bparam[i_n], batch_data)
                D_values.append(ants_loss_values[i_n])
                ants_norm_grads[i_n] = quality
                # if stop:
                #     break
                if (self.hparams["meta"]["dataset"] == "mnist"
                        ):  # TODO: make it generic
                    batch_data = next(self.data_loader)
            if stop:
                break
    # ants_group = dict(enumerate(grouper(ants_state, tolerence), 1))
    # print(f"Number of groups: {len(ants_group)}")
    # Keep the ant that is cheapest under the configured test measure.
    cheapest_index = get_cheapest_ant(
        ants_norm_grads,
        ants_loss_values,
        local_test=self.hparams["local_test_measure"])
    self._state = ants_state[cheapest_index]
    self._bparam = ants_bparam[cheapest_index]
    value = self.value_fn(self._state, self._bparam,
                          batch_data)  # Todo: why only final batch data
    _, _, test_images, test_labels = mnist(permute_train=False,
                                           resize=True,
                                           filter=self.hparams["filter"])
    del _
    val_loss = self.value_fn(self._state, self._bparam,
                             (test_images, test_labels))
    print(f"val loss: {val_loss}")
    return self._state, self._bparam, quality, value, val_loss, corrector_omega
def m_step(
    rngs: PRNGSequence,
    actor_optimizer: optim.Optimizer,
    actor_target_params: FrozenDict,
    eps_mu: float,
    eps_sig: float,
    mu_lagrange_optimizer: optim.Optimizer,
    sig_lagrange_optimizer: optim.Optimizer,
    max_action: float,
    action_dim: int,
    state: jnp.ndarray,
    weights: jnp.ndarray,
    sampled_actions: jnp.ndarray,
) -> Tuple[optim.Optimizer, optim.Optimizer, optim.Optimizer]:
    """
    The 'M-step' from the MPO paper. We optimize our policy network to
    maximize the lower bound on the probability of obtaining the maximum
    reward given that we act according to our policy (i.e. weighted
    according to our sampled actions).

    Returns the updated (mu Lagrange, sigma Lagrange, actor) optimizers.

    NOTE(review): `rngs` is not referenced in this body — presumably kept
    for signature parity with callers; confirm before removing.
    """

    def loss_fn(mlo, slo, actor_params):
        # get the distribution of the actor network (current policy)
        mu, log_sig = apply_gaussian_policy_model(
            actor_params, action_dim, max_action, state, None, False, True
        )
        sig = jnp.exp(log_sig)

        # get the distribution of the target network (old policy)
        target_mu, target_log_sig = apply_gaussian_policy_model(
            actor_target_params, action_dim, max_action, state, None, False, True
        )
        target_mu = jax.lax.stop_gradient(target_mu)
        target_log_sig = jax.lax.stop_gradient(target_log_sig)
        target_sig = jnp.exp(target_log_sig)

        # get the log likelihooods of the sampled actions according to the
        # decoupled distributions. described in section 4.2.1 of
        # Relative Entropy Regularized Policy Iteration
        # this ensures that the nonparametric policy won't collapse to give
        # a probability of 1 to the best action, which is a risk when we use
        # the on-policy distribution to calculate the likelihood.
        # Decoupling: one term varies sigma with mu frozen, the other varies
        # mu with sigma frozen.
        actor_log_prob = gaussian_likelihood(sampled_actions, target_mu, log_sig)
        actor_log_prob += gaussian_likelihood(sampled_actions, mu, target_log_sig)
        actor_log_prob = actor_log_prob.transpose((0, 1))

        # KL between old and new policy, split into mean-only and
        # covariance-only divergences (per the decoupled constraints).
        mu_kl = kl_mvg_diag(target_mu, target_sig, mu, target_sig).mean()
        sig_kl = kl_mvg_diag(target_mu, target_sig, target_mu, sig).mean()

        # Update the Lagrange multipliers on the stop-gradient'ed KLs so the
        # multiplier updates don't feed back into the actor gradient.
        mlo = mu_lagrange_step(mlo, eps_mu - jax.lax.stop_gradient(mu_kl))
        slo = sig_lagrange_step(slo, eps_sig - jax.lax.stop_gradient(sig_kl))

        # maximize the log likelihood, regularized by the divergence between
        # the target policy and the current policy. the goal here is to fit
        # the parametric policy to have the minimum divergence with the nonparametric
        # distribution based on the sampled actions.
        actor_loss = -(actor_log_prob * weights).sum(axis=1).mean()
        actor_loss -= jax.lax.stop_gradient(
            apply_constant_model(mlo.target, 1.0, True)
        ) * (eps_mu - mu_kl)
        actor_loss -= jax.lax.stop_gradient(
            apply_constant_model(slo.target, 100.0, True)
        ) * (eps_sig - sig_kl)
        # Updated Lagrange optimizers ride along as the aux output.
        return actor_loss.mean(), (mlo, slo)

    grad, (mu_lagrange_optimizer, sig_lagrange_optimizer) = jax.grad(
        partial(loss_fn, mu_lagrange_optimizer, sig_lagrange_optimizer), has_aux=True
    )(actor_optimizer.target)
    # Hard cap on the actor gradient norm.
    grad = clip_grads(grad, 40.0)

    actor_optimizer = actor_optimizer.apply_gradient(grad)

    return mu_lagrange_optimizer, sig_lagrange_optimizer, actor_optimizer
def postprocess_gradients(gradients):
    """Clip the gradient pytree so its global L2 norm is at most 1.0."""
    max_norm = 1.0
    clipped = optimizers.clip_grads(gradients, max_norm)
    return clipped
def correction_step(self) -> Tuple:
    """Given the current state optimize to the correct state.

    Perturbs the predicted state onto the secant sphere, then runs a
    constrained gradient descent until the gradient norm drops below the
    quality threshold or the descent period is exhausted.

    Returns:
      (state: problem parameters, bparam: continuation parameter) Tuple
    """
    # Fresh PRNG key; npr.randint injects host-side entropy.
    _, key = random.split(random.PRNGKey(self.key_state + npr.randint(1, 100)))
    del _
    quality = 1.0
    N_opt = 10  # step count used to scale the relaxation factor below
    stop = False
    corrector_omega = 1.0
    # bparam_grads = pytree_zeros_like(self._bparam)
    print("the radius", self.sphere_radius)
    self._parc_vec, self.state_stack = self._perform_perturb_by_projection(
        self._state_secant_vector,
        self._state_secant_c2,
        key,
        self.pred_prev_state,
        self._state,
        self._bparam,
        self.sphere_radius,
    )
    if self.hparams["_evaluate_perturb"]:
        self._evaluate_perturb()  # does every time
    for j in range(self.descent_period):
        for b_j in range(self.num_batches):
            if self.hparams["meta"]["dataset"] == "mnist":  # TODO: make it generic
                batch_data = next(self.data_loader)
            else:
                batch_data = None
            # grads = self.compute_grad_fn(self._state, self._bparam, batch_data)
            # self._state = self.opt.update_params(self._state, grads[0])
            # Gradients of the constrained (Lagrangian) objective w.r.t. the
            # problem parameters and the continuation parameter.
            state_grads, bparam_grads = self.compute_min_grad_fn(
                self._state,
                self._bparam,
                self._lagrange_multiplier,
                self._state_secant_c2,
                self._state_secant_vector,
                batch_data,
                self.delta_s,
            )
            if self.hparams["adaptive"]:
                # NOTE(review): collapsed source is ambiguous; assumed only
                # the lr decay is guarded by "adaptive" — confirm.
                self.opt.lr = self.exp_decay(j, self.hparams["natural_lr"])
            quality = l2_norm(state_grads)  # +l2_norm(bparam_grads)
            if quality > self.hparams["quality_thresh"]:
                pass
                # print(f"quality {quality}, {self.opt.lr}, {bparam_grads} ,{j}")
            else:
                # Converged: pick a relaxation factor and stop after this
                # final (clipped) update.
                if N_opt > (j + 1):  # To get around folds slowly
                    corrector_omega = min(N_opt / (j + 1), 2.0)
                else:
                    corrector_omega = max(N_opt / (j + 1), 0.5)
                stop = True
                print(f"quality {quality} stopping at , {j}th step")
            state_grads = clip_grads(state_grads, self.hparams["max_clip_grad"])
            bparam_grads = clip_grads(
                bparam_grads, self.hparams["max_clip_grad"]
            )
            self._bparam = self.opt.update_params(self._bparam, bparam_grads, j)
            self._state = self.opt.update_params(self._state, state_grads, j)
            if stop:
                break
        if stop:
            break
    value = self.value_fn(
        self._state, self._bparam, batch_data
    )  # Todo: why only final batch data
    return self._state, self._bparam, quality, value, corrector_omega