Example #1
    def testUtilityClipGrads(self):
        g = (np.ones(2), (np.ones(3), np.ones(4)))
        norm = optimizers.l2_norm(g)

        ans = optimizers.clip_grads(g, 1.1 * norm)
        expected = g
        self.assertAllClose(ans, expected, check_dtypes=False)

        ans = optimizers.l2_norm(optimizers.clip_grads(g, 0.9 * norm))
        expected = 0.9 * norm
        self.assertAllClose(ans, expected, check_dtypes=False)
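
The test above pins down the clipping rule: when the threshold exceeds the gradients' global L2 norm, the tree comes back unchanged; otherwise every leaf is rescaled so the global norm equals the threshold. A minimal sketch of that rule (an illustration, not the library implementation):

import jax.numpy as jnp
from jax import tree_util

def clip_by_global_norm(grad_tree, max_norm):
    """Rescale a gradient pytree so its global L2 norm is at most max_norm.

    Sketch of the behaviour the test asserts; optimizers.clip_grads may
    handle edge cases (e.g. an all-zero gradient tree) differently.
    """
    leaves = tree_util.tree_leaves(grad_tree)
    global_norm = jnp.sqrt(sum(jnp.vdot(g, g) for g in leaves))
    scale = jnp.minimum(1.0, max_norm / global_norm)
    return tree_util.tree_map(lambda g: g * scale, grad_tree)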
Example #2
def update_w_gc(i, opt_state, hps, opt_hps, key, x_bxt, kl_warmup):
    """Update fun for gradients, includes gradient clipping."""
    params = get_params(opt_state)
    grads = grad(lfads.training_loss_jit)(params, hps, key, x_bxt,
                                          kl_warmup, opt_hps['keep_rate'])
    clipped_grads = optimizers.clip_grads(grads, opt_hps['max_grad_norm'])
    return opt_update(i, clipped_grads, opt_state)
Example #3
def fit_mixture(data,
                num_components=3,
                verbose=False,
                num_samples=5000) -> LogisticMixtureParams:
    # the data might be something weird, like a pandas dataframe column;
    # turn it into a regular old numpy array
    data_as_np_array = np.array(data)
    step_size = 0.01
    components = initialize_components(num_components)
    (init_fun, update_fun, get_params) = sgd(step_size)
    opt_state = init_fun(components)
    for i in tqdm(range(num_samples)):
        components = get_params(opt_state)
        grads = -grad_mixture_logpdf(data_as_np_array, components)
        if np.any(np.isnan(grads)):
            print("Encoutered nan gradient, stopping early")
            print(grads)
            print(components)
            break
        grads = clip_grads(grads, 1.0)
        opt_state = update_fun(i, grads, opt_state)
        if i % 500 == 0 and verbose:
            pprint(components)
            score = mixture_logpdf(data_as_np_array, components)
            print(f"Log score: {score:.3f}")
    return structure_mixture_params(components)
Example #4
def update_w_gc(i, opt_state, opt_update, get_params, x_bxt, f_bxt, f_mask_bxt,
                max_grad_norm, l2reg):
    """Update the parameters w/ gradient clipped, gradient descent updates.

  Arguments:
    i: batch number
    opt_state: parameters plus optimizer state
    opt_update: optimizer state update function
    get_params: function to extract parameters from optimizer state
    x_bxt: rnn inputs
    f_bxt: rnn targets
    f_mask_bxt: masks for when target is defined
    max_grad_norm: maximum norm value gradient is allowed to take
    l2reg: l2 regularization hyperparameter

  Returns:
    opt_state tuple (as above) that includes updated parameters and optimizer
      state.
  """
    params = get_params(opt_state)

    def training_loss(params, x_bxt, f_bxt, l2reg):
        return loss(params, x_bxt, f_bxt, f_mask_bxt, l2reg)['total']

    grads = grad(training_loss)(params, x_bxt, f_bxt, l2reg)
    clipped_grads = optimizers.clip_grads(grads, max_grad_norm)
    return opt_update(i, clipped_grads, opt_state)
Example #5
    def compute_grads_and_update(self, batch, env_ids, max_grad_norm, new_rng,
                                 train_loss_fn, train_state):

        # Compute learning rate:
        lr = self.get_learning_rate(train_state.global_step)

        # Compute gradients:
        compute_gradient_fn = jax.value_and_grad(train_loss_fn, has_aux=True)
        (_, (new_model_state, logits,
             logs)), grad = compute_gradient_fn(train_state.optimizer.target)

        # Update parameters:
        grad = jax.lax.pmean(grad, axis_name='batch')
        # Clip gradients:
        if max_grad_norm is not None:
            grad = clip_grads(grad, max_grad_norm)

        new_optimizer = train_state.optimizer.apply_gradient(grad,
                                                             learning_rate=lr)

        # Get the new (updated) train_state:
        new_train_state = pipeline_utils.TrainState(
            global_step=train_state.global_step + 1,
            optimizer=new_optimizer,
            model_state=new_model_state,
            rng=new_rng)

        metrics = self.collect_metrics(batch, env_ids, logits, logs, lr,
                                       train_state.optimizer.target)

        return new_train_state, metrics
Example #6
    def optimizer_step(current_step, state, batch):
        """Takes a single optimization step."""
        p = get_params(state)
        loss, gradients = jax.value_and_grad(loss_fun)(p, batch)

        gradients = optimizers.clip_grads(gradients, gradient_clip)

        new_state = update_opt(current_step, gradients, state)
        return current_step + 1, new_state, loss
Example #7
def update_w_gc(i, opt_state, opt_update, get_params, x_bxt, f_bxt,
                max_grad_norm, l2reg):
  """Update the parameters w/ gradient clipped, gradient descent updates."""
  params = get_params(opt_state)

  def training_loss(params, x_bxt, f_bxt, l2reg):
    return loss(params, x_bxt, f_bxt, l2reg)['total']
  
  grads = grad(training_loss)(params, x_bxt, f_bxt, l2reg)
  clipped_grads = optimizers.clip_grads(grads, max_grad_norm)
  return opt_update(i, clipped_grads, opt_state)
Example #8
    def optimizer_step_clip(current_step, state, batch):
        """Takes a single optimization step."""
        p = get_params(state)
        loss, gradients = jax.value_and_grad(loss_fun)(p, batch)

        gradients = optimizers.clip_grads(gradients, gradient_clip)
        # Sets readout gradients to zero
        # rnn_grads, ro_grads = gradients
        # ro_grads = optimizers.clip_grads(ro_grads, 0.0)
        # gradients = rnn_grads, ro_grads

        new_state = update_opt(current_step, gradients, state)
        return current_step + 1, new_state, loss
Example #9
        def update(batch_idx, __opt_state):
            """Update func for gradients, includes gradient clipping."""
            kl_warmup = kl_warmup_fun(epoch_idx * num_batches + batch_idx)

            batch_data = lax.dynamic_slice_in_dim(epoch_data,
                                                  batch_idx * BATCH_SIZE,
                                                  BATCH_SIZE,
                                                  axis=0)
            batch_data = batch_data.astype(np.float32)

            params = get_params(__opt_state)
            grads = grad(loss_fn)(params, batch_data, next(batch_keys),
                                  BATCH_SIZE, ic_prior, VAR_MIN, kl_warmup,
                                  L2_REG)
            clipped_grads = optimizers.clip_grads(grads, MAX_GRAD_NORM)

            return opt_update(batch_idx, clipped_grads, __opt_state)
Example #10
def critic_step(
    optimizer: optim.Optimizer,
    state: jnp.ndarray,
    action: jnp.ndarray,
    target_Q: jnp.ndarray,
) -> optim.Optimizer:
    """
    The critic is optimized in the same way as in typical actor-critic
    methods, by minimizing the TD error.
    """
    def loss_fn(critic_params):
        current_Q1, current_Q2 = apply_double_critic_model(
            critic_params, state, action, False)
        critic_loss = double_mse(current_Q1, current_Q2, target_Q)
        return critic_loss.mean()

    grad = jax.grad(loss_fn)(optimizer.target)
    grad = clip_grads(grad, 40.0)
    return optimizer.apply_gradient(grad)
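
Note that double_mse above is a helper from the surrounding project rather than a JAX builtin. Assuming it returns the per-sample squared TD errors summed over the two critic heads (the caller then reduces with .mean()), a plausible sketch is:

import jax.numpy as jnp

def double_mse(q1, q2, target_q):
    # Per-sample squared TD errors for both critic heads; the caller
    # averages over the batch. Assumed behaviour, not taken from the source.
    return jnp.square(target_q - q1) + jnp.square(target_q - q2)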
Example #11
    def compute_grads_and_update(self, batch, max_grad_norm, new_rng,
                                 train_loss_fn, train_state):
        """Compute grads and updates parameters.

    Args:
      batch: dict; Batch of examples.
      max_grad_norm: float; Max value for grad norm (used for grad clipping).
      new_rng: Jax RNG key.
      train_loss_fn: fn(params)--> loss; Loss function (for which grad is
        computed).
      train_state: TrainState, the state of training including the current
        global_step, model_state, rng, and optimizer.

    Returns:
      Updated state of training and calculated metrics.
    """

        # Compute learning rate:
        lr = self.get_learning_rate(train_state.global_step)

        compute_gradient_fn = jax.value_and_grad(train_loss_fn, has_aux=True)
        (_, (new_model_state,
             logits)), grad = compute_gradient_fn(train_state.optimizer.target)
        # re-use same axis_name as in the call to `pmap(...train_step...)` below
        grad = jax.lax.pmean(grad, axis_name='batch')

        if max_grad_norm is not None:
            grad = clip_grads(grad, max_grad_norm)

        new_optimizer = train_state.optimizer.apply_gradient(grad,
                                                             learning_rate=lr)
        new_train_state = train_state.replace(
            global_step=train_state.global_step + 1,
            optimizer=new_optimizer,
            model_state=new_model_state,
            rng=new_rng)

        metric_dict = self.collect_metrics(batch, logits, lr)

        return new_train_state, metric_dict
Example #12
                             split='train' if args.data_augment else 'test',
                             seed=seed * args.seed_separator + i,
                             num_classes=args.num_classes)
        X, Y = batch['image'], batch['label']
        params_curr = get_params(state)
        loss_curr, grad_curr = value_and_grad_loss(params_curr, X, Y)
        # monitor gradient norm
        grad_norm = optimizers.l2_norm(grad_curr)
        writer.add_scalar(f'grad_norm/{tb_flag}', grad_norm.item(),
                          global_step)
        if np.isnan(loss_curr):
            sys.exit()
        running_loss += loss_curr
        running_count += 1
        if args.grad_norm_thresh > 0:
            grad_curr = optimizers.clip_grads(grad_curr, args.grad_norm_thresh)
        state = opt_apply(epoch, grad_curr, state)
        global_step += 1
        print(
            f"Step {global_step}, training loss={loss_curr:.4f}, grad norm={grad_norm:.4f}"
        )

        # Evaluate on the test set
        if global_step % args.save_steps == 0 \
                or global_step % args.early_save_steps == 0 and global_step <= args.early_save_till_step:
            test_loader = tfds.as_numpy(test_data.batch(args.batch_size_test))
            acc_f, loss_test = 0., 0.
            acc_g, loss_test_g = 0., 0.
            params_curr = get_params(state)
            start_ind = 0
            for j, test_batch in enumerate(test_loader):
Example #13
    def correction_step(self) -> Tuple:
        """Given the current state optimize to the correct state.

        Returns:
          (state: problem parameters, bparam: continuation parameter) Tuple
        """

        quality = 1.0
        if self.hparams["meta"]["dataset"] == "mnist":  # TODO: make it generic
            batch_data = next(self.data_loader)
        else:
            batch_data = None

        ants_norm_grads = [5.0 for _ in range(self.hparams["n_wall_ants"])]
        ants_loss_values = [5.0 for _ in range(self.hparams["n_wall_ants"])]
        ants_state = [self._state for _ in range(self.hparams["n_wall_ants"])]
        ants_bparam = [
            self._bparam for _ in range(self.hparams["n_wall_ants"])
        ]
        for i_n in range(self.hparams["n_wall_ants"]):
            corrector_omega = 1.0
            stop = False
            _, key = random.split(
                random.PRNGKey(self.key_state + i_n +
                               npr.randint(1, (i_n + 1) * 10)))
            del _
            self._parc_vec, self.state_stack = self._perform_perturb_by_projection(
                self._state_secant_vector,
                self._state_secant_c2,
                key,
                self.pred_prev_state,
                self._state,
                self._bparam,
                i_n,
                self.sphere_radius,
                batch_data,
            )
            if self.hparams["_evaluate_perturb"]:
                self._evaluate_perturb()  # does every time

            ants_state[i_n] = self.state_stack["state"]
            ants_bparam[i_n] = self.state_stack["bparam"]
            D_values = []
            print(f"num_batches", self.num_batches)
            for j_epoch in range(self.descent_period):
                for b_j in range(self.num_batches):

                    #alternate
                    # grads = self.compute_grad_fn(self._state, self._bparam, batch_data)
                    # self._state = self.opt.update_params(self._state, grads[0])
                    state_grads, bparam_grads = self.compute_min_grad_fn(
                        ants_state[i_n],
                        ants_bparam[i_n],
                        self._lagrange_multiplier,
                        self._state_secant_c2,
                        self._state_secant_vector,
                        batch_data,
                        self.delta_s,
                    )

                    if self.hparams["adaptive"]:
                        self.opt.lr = self.exp_decay(
                            j_epoch, self.hparams["natural_lr"])
                        quality = l2_norm(state_grads)  #l2_norm(bparam_grads)
                        if self.hparams[
                                "local_test_measure"] == "norm_gradients":
                            if quality > self.hparams["quality_thresh"]:
                                pass
                                print(
                                    f"quality {quality}, {self.opt.lr}, {bparam_grads} ,{j_epoch}"
                                )
                            else:
                                stop = True
                                print(
                                    f"quality {quality} stopping at , {j_epoch}th step"
                                )
                        else:
                            print(
                                f"quality {quality}, {bparam_grads} ,{j_epoch}"
                            )
                            if len(D_values) >= 20:
                                tmp_means = running_mean(D_values, 10)
                                if (math.isclose(
                                        tmp_means[-1],
                                        tmp_means[-2],
                                        abs_tol=self.hparams["loss_tol"])):
                                    print(
                                        f"stopping at , {j_epoch}th step, {ants_bparam[i_n]} bparam"
                                    )
                                    stop = True

                        state_grads = clip_grads(state_grads,
                                                 self.hparams["max_clip_grad"])
                        bparam_grads = clip_grads(
                            bparam_grads, self.hparams["max_clip_grad"])

                    if self.hparams["guess_ant_steps"] >= (
                            j_epoch + 1):  # To get around folds slowly
                        corrector_omega = min(
                            self.hparams["guess_ant_steps"] / (j_epoch + 1),
                            1.5)
                    else:
                        corrector_omega = max(
                            self.hparams["guess_ant_steps"] / (j_epoch + 1),
                            0.05)

                    ants_state[i_n] = self.opt.update_params(
                        ants_state[i_n], state_grads, j_epoch)
                    ants_bparam[i_n] = self.opt.update_params(
                        ants_bparam[i_n], bparam_grads, j_epoch)
                    ants_loss_values[i_n] = self.value_fn(
                        ants_state[i_n], ants_bparam[i_n], batch_data)
                    D_values.append(ants_loss_values[i_n])
                    ants_norm_grads[i_n] = quality
                    # if stop:
                    #     break
                    if (self.hparams["meta"]["dataset"] == "mnist"
                        ):  # TODO: make it generic
                        batch_data = next(self.data_loader)
                if stop:
                    break

        # ants_group = dict(enumerate(grouper(ants_state, tolerence), 1))
        # print(f"Number of groups: {len(ants_group)}")
        cheapest_index = get_cheapest_ant(
            ants_norm_grads,
            ants_loss_values,
            local_test=self.hparams["local_test_measure"])
        self._state = ants_state[cheapest_index]
        self._bparam = ants_bparam[cheapest_index]
        value = self.value_fn(self._state, self._bparam,
                              batch_data)  # Todo: why only final batch data

        _, _, test_images, test_labels = mnist(permute_train=False,
                                               resize=True,
                                               filter=self.hparams["filter"])
        del _
        val_loss = self.value_fn(self._state, self._bparam,
                                 (test_images, test_labels))
        print(f"val loss: {val_loss}")

        return self._state, self._bparam, quality, value, val_loss, corrector_omega
Example #14
def m_step(
    rngs: PRNGSequence,
    actor_optimizer: optim.Optimizer,
    actor_target_params: FrozenDict,
    eps_mu: float,
    eps_sig: float,
    mu_lagrange_optimizer: optim.Optimizer,
    sig_lagrange_optimizer: optim.Optimizer,
    max_action: float,
    action_dim: int,
    state: jnp.ndarray,
    weights: jnp.ndarray,
    sampled_actions: jnp.ndarray,
) -> Tuple[optim.Optimizer, optim.Optimizer, optim.Optimizer]:
    """
    The 'M-step' from the MPO paper. We optimize our policy network to maximize
    the lower bound on the probability of obtaining the maximum reward given
    that we act according to our policy (i.e. weighted according to our sampled actions).
    """

    def loss_fn(mlo, slo, actor_params):
        # get the distribution of the actor network (current policy)
        mu, log_sig = apply_gaussian_policy_model(
            actor_params, action_dim, max_action, state, None, False, True
        )
        sig = jnp.exp(log_sig)
        # get the distribution of the target network (old policy)
        target_mu, target_log_sig = apply_gaussian_policy_model(
            actor_target_params, action_dim, max_action, state, None, False, True
        )
        target_mu = jax.lax.stop_gradient(target_mu)
        target_log_sig = jax.lax.stop_gradient(target_log_sig)
        target_sig = jnp.exp(target_log_sig)

        # get the log likelihooods of the sampled actions according to the
        # decoupled distributions. described in section 4.2.1 of
        # Relative Entropy Regularized Policy Iteration
        # this ensures that the nonparametric policy won't collapse to give
        # a probability of 1 to the best action, which is a risk when we use
        # the on-policy distribution to calculate the likelihood.
        actor_log_prob = gaussian_likelihood(sampled_actions, target_mu, log_sig)
        actor_log_prob += gaussian_likelihood(sampled_actions, mu, target_log_sig)
        actor_log_prob = actor_log_prob.transpose((0, 1))

        mu_kl = kl_mvg_diag(target_mu, target_sig, mu, target_sig).mean()
        sig_kl = kl_mvg_diag(target_mu, target_sig, target_mu, sig).mean()

        mlo = mu_lagrange_step(mlo, eps_mu - jax.lax.stop_gradient(mu_kl))
        slo = sig_lagrange_step(slo, eps_sig - jax.lax.stop_gradient(sig_kl))

        # maximize the log likelihood, regularized by the divergence between
        # the target policy and the current policy. the goal here is to fit
        # the parametric policy to have the minimum divergence with the nonparametric
        # distribution based on the sampled actions.
        actor_loss = -(actor_log_prob * weights).sum(axis=1).mean()
        actor_loss -= jax.lax.stop_gradient(
            apply_constant_model(mlo.target, 1.0, True)
        ) * (eps_mu - mu_kl)
        actor_loss -= jax.lax.stop_gradient(
            apply_constant_model(slo.target, 100.0, True)
        ) * (eps_sig - sig_kl)
        return actor_loss.mean(), (mlo, slo)

    grad, (mu_lagrange_optimizer, sig_lagrange_optimizer) = jax.grad(
        partial(loss_fn, mu_lagrange_optimizer, sig_lagrange_optimizer), has_aux=True
    )(actor_optimizer.target)
    grad = clip_grads(grad, 40.0)

    actor_optimizer = actor_optimizer.apply_gradient(grad)

    return mu_lagrange_optimizer, sig_lagrange_optimizer, actor_optimizer
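
The decoupled KL terms above call kl_mvg_diag, another project helper. Assuming it computes the KL divergence between two diagonal Gaussians summed over the action dimensions (the caller averages over the batch), a sketch consistent with its usage is:

import jax.numpy as jnp

def kl_mvg_diag(mu_p, sig_p, mu_q, sig_q):
    # KL(N(mu_p, diag(sig_p**2)) || N(mu_q, diag(sig_q**2))) per sample,
    # summed over action dimensions. Assumed helper, not verified against
    # the original project.
    var_p, var_q = jnp.square(sig_p), jnp.square(sig_q)
    return jnp.sum(
        jnp.log(sig_q / sig_p)
        + (var_p + jnp.square(mu_p - mu_q)) / (2.0 * var_q)
        - 0.5,
        axis=-1)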
Example #15
def postprocess_gradients(gradients):
    return optimizers.clip_grads(gradients, 1.0)
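
Example #15 is a thin wrapper around clip_grads with a fixed threshold of 1.0. One hypothetical way to wire such a hook into a training step (loss_fn, get_params, and opt_update are assumed names, not taken from the original project; optimizers lives in jax.example_libraries, or jax.experimental in older JAX releases):

import jax
from jax.example_libraries import optimizers

def postprocess_gradients(gradients):
    return optimizers.clip_grads(gradients, 1.0)

def step(i, opt_state, batch, loss_fn, get_params, opt_update):
    params = get_params(opt_state)
    grads = jax.grad(loss_fn)(params, batch)  # assumes loss_fn(params, batch) -> scalar
    grads = postprocess_gradients(grads)      # clip global norm to 1.0
    return opt_update(i, grads, opt_state)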
Example #16
    def correction_step(self) -> Tuple:
        """Given the current state optimize to the correct state.

        Returns:
          (state: problem parameters, bparam: continuation parameter) Tuple
        """
        _, key = random.split(random.PRNGKey(self.key_state + npr.randint(1, 100)))
        del _
        quality = 1.0
        N_opt = 10
        stop = False
        corrector_omega = 1.0
        # bparam_grads = pytree_zeros_like(self._bparam)
        print("the radius", self.sphere_radius)
        self._parc_vec, self.state_stack = self._perform_perturb_by_projection(
            self._state_secant_vector,
            self._state_secant_c2,
            key,
            self.pred_prev_state,
            self._state,
            self._bparam,
            self.sphere_radius,
        )
        if self.hparams["_evaluate_perturb"]:
            self._evaluate_perturb()  # does every time

        for j in range(self.descent_period):
            for b_j in range(self.num_batches):
                if self.hparams["meta"]["dataset"] == "mnist":  # TODO: make it generic
                    batch_data = next(self.data_loader)
                else:
                    batch_data = None
                # grads = self.compute_grad_fn(self._state, self._bparam, batch_data)
                # self._state = self.opt.update_params(self._state, grads[0])
                state_grads, bparam_grads = self.compute_min_grad_fn(
                    self._state,
                    self._bparam,
                    self._lagrange_multiplier,
                    self._state_secant_c2,
                    self._state_secant_vector,
                    batch_data,
                    self.delta_s,
                )

                if self.hparams["adaptive"]:
                    self.opt.lr = self.exp_decay(j, self.hparams["natural_lr"])
                    quality = l2_norm(state_grads)  # +l2_norm(bparam_grads)
                    if quality > self.hparams["quality_thresh"]:
                        pass
                        # print(f"quality {quality}, {self.opt.lr}, {bparam_grads} ,{j}")
                    else:
                        if N_opt > (j + 1):  # To get around folds slowly
                            corrector_omega = min(N_opt / (j + 1), 2.0)
                        else:
                            corrector_omega = max(N_opt / (j + 1), 0.5)
                        stop = True
                        print(f"quality {quality} stopping at , {j}th step")
                    state_grads = clip_grads(state_grads, self.hparams["max_clip_grad"])
                    bparam_grads = clip_grads(
                        bparam_grads, self.hparams["max_clip_grad"]
                    )

                self._bparam = self.opt.update_params(self._bparam, bparam_grads, j)
                self._state = self.opt.update_params(self._state, state_grads, j)
                if stop:
                    break
            if stop:
                break

        value = self.value_fn(
            self._state, self._bparam, batch_data
        )  # Todo: why only final batch data
        return self._state, self._bparam, quality, value, corrector_omega