    def train_from_paths(self, paths):
        # Concatenate from all the trajectories
        observations = np.concatenate([path["observations"] for path in paths])
        actions = np.concatenate([path["actions"] for path in paths])
        advantages = np.concatenate([path["advantages"] for path in paths])
        # Advantage whitening
        advantages = (advantages - np.mean(advantages)) / (np.std(advantages) +
                                                           1e-6)
        # NOTE : advantage should be zero mean in expectation
        # normalized step size invariant to advantage scaling,
        # but scaling can help with least squares
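        # Concretely: scaling advantages by c scales the policy gradient g by c while the
        # Fisher matrix F is unchanged, so the normalized step size
        # alpha = sqrt(2*delta / (g^T F^-1 g)) scales by 1/c and the update alpha * F^-1 g
        # is left unchanged.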

        self.n_steps += len(advantages)

        # cache return distributions for the paths
        path_returns = [sum(p["rewards"]) for p in paths]
        mean_return = np.mean(path_returns)
        std_return = np.std(path_returns)
        min_return = np.amin(path_returns)
        max_return = np.amax(path_returns)
        base_stats = [mean_return, std_return, min_return, max_return]
        self.running_score = mean_return if self.running_score is None else \
                             0.9*self.running_score + 0.1*mean_return  # approx avg of last 10 iters
        if self.save_logs: self.log_rollout_statistics(paths)

        # Keep track of times for various computations
        t_gLL = 0.0
        t_FIM = 0.0

        self.optim.zero_grad()

        # Optimization. Negate gradient since the optimizer is minimizing.
        vpg_grad = -self.flat_vpg(observations, actions, advantages)
        vector_to_gradients(Variable(torch.from_numpy(vpg_grad).float()),
                            self.policy.trainable_params)

        closure = self.kl_closure(self.policy, observations, actions,
                                  self.policy_kl_fn)
        info = self.optim.step(closure)
        self.policy.set_param_values(self.policy.get_param_values())

        # Log information
        if self.save_logs:
            self.logger.log_kv('alpha', info['alpha'])
            self.logger.log_kv('delta', info['delta'])
            # self.logger.log_kv('time_vpg', t_gLL)
            # self.logger.log_kv('time_npg', t_FIM)
            # self.logger.log_kv('kl_dist', kl_dist)
            # self.logger.log_kv('surr_improvement', surr_after - surr_before)
            self.logger.log_kv('running_score', self.running_score)
            self.logger.log_kv('steps', self.n_steps)

            try:
                success_rate = self.env.env.env.evaluate_success(paths)
                self.logger.log_kv('success_rate', success_rate)
            except Exception:
                # not all envs implement evaluate_success
                pass

        return base_stats
Example #2
    def step(self, closure, execute_update=True):
        """Performs a single optimization step.

        Arguments:
            closure (callable): A closure that accepts a vector of parameters and a vector of
                length equal to the number of model parameters and returns the Fisher-vector
                product.
            execute_update (bool): If True, apply the computed natural-gradient step to the
                parameters in place; otherwise only compute and return it.
        """
        state = self.state
        # State initialization
        if len(state) == 0:
            state['step'] = 0
            # Set shrinkage to defaults, i.e. no shrinkage
            state['rho'] = 0.0
            state['diag_shrunk'] = 1.0

        state['step'] += 1

        # Get flat grad
        g = gradients_to_vector(self._params)

        if 'ng_prior' not in state:
            state['ng_prior'] = torch.zeros_like(g)

        curv_type = self._param_group['curv_type']
        if curv_type not in self.valid_curv_types:
            raise ValueError("Invalid curv_type.")

        # Create closure to pass to Lanczos and CG
        if curv_type == 'fisher':
            Fvp_theta_fn = make_fvp_fun(closure, self._params)
        elif curv_type == 'gauss_newton':
            Fvp_theta_fn = make_gnvp_fun(closure, self._params)
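        # Both Lanczos and CG only require matrix-vector products with the curvature matrix,
        # so the Fisher / Gauss-Newton matrix is never formed explicitly.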

        shrinkage_method = self._param_group['shrinkage_method']
        lanczos_amortization = self._param_group['lanczos_amortization']
        if shrinkage_method == 'lanczos' and (state['step'] -
                                              1) % lanczos_amortization == 0:
            # print ("Computing Lanczos shrinkage at step ", state['step'])
            w = lanczos_iteration(Fvp_theta_fn,
                                  self._numel(),
                                  k=self._param_group['lanczos_iters'])
            rho, diag_shrunk = estimate_shrinkage(
                w, self._numel(), self._param_group['batch_size'])
            state['rho'] = rho
            state['diag_shrunk'] = diag_shrunk
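            # rho and diag_shrunk presumably parameterize a Ledoit-Wolf style shrinkage of the
            # curvature toward a scaled identity, i.e. CG effectively sees something like
            # (1 - rho) * F + rho * diag_shrunk * I (see the shrunk/rho/Dshrunk arguments to
            # cg_solve below).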

        M = None
        if self._param_group['cg_precondition_empirical']:
            # Empirical Fisher is g * g
            M = (g * g + self._param_group['cg_precondition_regu_coef'] *
                 torch.ones_like(g))**self._param_group['cg_precondition_exp']
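            # This is a Jacobi (diagonal) preconditioner for CG, built from the regularized
            # empirical Fisher diagonal g * g and raised to a tunable exponent.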

        # Do CG solve with hvp fn closure
        extract_tridiag = self._param_group['shrinkage_method'] == 'cg'
        cg_result = cg_solve(
            Fvp_theta_fn,
            g.data.clone(),
            x_0=self._param_group['cg_prev_init_coef'] * state['ng_prior'],
            M=M,
            cg_iters=self._param_group['cg_iters'],
            cg_residual_tol=self._param_group['cg_residual_tol'],
            shrunk=self._param_group['shrinkage_method'] is not None,
            rho=state['rho'],
            Dshrunk=state['diag_shrunk'],
            extract_tridiag=extract_tridiag)

        if extract_tridiag:
            # print ("Computing CG shrinkage at step ", state['step'])
            ng, (diag_elems, off_diag_elems) = cg_result
            w = eigvalsh_tridiagonal(diag_elems, off_diag_elems)
            rho, diag_shrunk = estimate_shrinkage(
                w, self._numel(), self._param_group['batch_size'])
            state['rho'] = rho
            state['diag_shrunk'] = diag_shrunk
        else:
            ng = cg_result

        state['ng_prior'] = ng.data.clone()

        # Normalize NG
        lr = self._param_group['lr']
        alpha = torch.sqrt(torch.abs(lr / (torch.dot(g, ng) + 1e-20)))
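        # With ng ~= F^-1 g, alpha = sqrt(lr / (g^T F^-1 g)) is the KL-normalized step size:
        # the update alpha * ng changes the policy by roughly 0.5 * alpha^2 * g^T F^-1 g = lr / 2
        # in KL divergence, so lr acts as a trust-region radius rather than a raw learning rate.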

        # Unflatten grad
        vector_to_gradients(ng, self._params)

        if execute_update:
            # Apply step
            for p in self._params:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                p.data.add_(-alpha, d_p)

        return dict(alpha=alpha, delta=lr, natural_grad=ng)
    def step(self, closure, execute_update=True):
        """Performs a single optimization step.

        Arguments:
            closure (callable): A closure that accepts a vector of length equal to the number of
                model parameters and returns the Fisher-vector product.
        """
        state = self.state
        param_vec = parameters_to_vector(self._params)
        # State initialization
        if len(state) == 0:
            state['step'] = 0
            # Exponential moving average of gradient values
            state['m'] = torch.zeros_like(param_vec.data)
            # Maintain adaptive preconditioner if needed
            if self._param_group['cg_precondition_empirical']:
                state['M'] = torch.zeros_like(param_vec.data)
            # Set shrinkage to defaults, i.e. no shrinkage
            state['rho'] = 0.0
            state['diag_shrunk'] = 1.0

        m = state['m']
        beta1, beta2 = self._param_group['betas']
        state['step'] += 1

        bias_correction1 = 1 - beta1**state['step']
        bias_correction2 = 1 - beta2**state['step']

        # Get flat grad
        g = gradients_to_vector(self._params)

        # Update moving average mean
        m.mul_(beta1).add_(1 - beta1, g)
        g_hat = m / bias_correction1
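        # Adam-style first moment: m is an exponential moving average of the gradient, and
        # dividing by (1 - beta1^t) corrects its bias toward zero during early steps.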

        theta = parameters_to_vector(self._params)
        theta_old = parameters_to_vector(self._params_old)

        if 'ng_prior' not in state:
            state['ng_prior'] = torch.zeros_like(g_hat)  #g_hat.data.clone()
        if 'max_fisher_spectral_norm' not in state:
            state['max_fisher_spectral_norm'] = 0.0

        curv_type = self._param_group['curv_type']
        if curv_type not in self.valid_curv_types:
            raise ValueError("Invalid curv_type.")

        if curv_type == 'fisher':
            weighted_fvp_fn_div_beta2 = self._make_combined_fvp_fun(
                closure,
                self._params,
                self._params_old,
                bias_correction2=bias_correction2)
        elif curv_type == 'gauss_newton':
            weighted_fvp_fn_div_beta2 = self._make_combined_gnvp_fun(
                closure,
                self._params,
                self._params_old,
                bias_correction2=bias_correction2)
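        # The combined closure presumably forms an exponentially weighted curvature-vector
        # product, mixing curvature at the lagged parameters (theta_old) and the current
        # parameters with weights beta2 and (1 - beta2), divided by the beta2 bias correction.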

        fisher_norm = lanczos_iteration(weighted_fvp_fn_div_beta2,
                                        self._numel(),
                                        k=1)[0]
        is_max_norm = fisher_norm > state['max_fisher_spectral_norm'] or state[
            'step'] == 1
        if is_max_norm:
            state['max_fisher_spectral_norm'] = fisher_norm

        if is_max_norm:
            if self._param_group['assume_locally_linear']:
                # Update theta_old beta2 portion towards theta
                theta_old = beta2 * theta_old + (1 - beta2) * theta
            else:
                # Do linesearch first to update theta_old. Then can do CG with only one HVP at each itr.
                ng = self.state['ng_prior'].clone(
                ) if state['step'] > 1 else g_hat.data.clone()
                if curv_type == 'fisher':
                    weighted_fvp_fn = self._make_combined_fvp_fun(
                        closure, self._params, self._params_old)
                    f = make_fvp_obj_fun(closure, weighted_fvp_fn, ng)
                elif curv_type == 'gauss_newton':
                    weighted_fvp_fn = self._make_combined_gnvp_fun(
                        closure, self._params, self._params_old)
                    f = make_gnvp_obj_fun(closure, weighted_fvp_fn, ng)
                xmin, fmin, alpha = randomized_linesearch(
                    f, theta_old.data, theta.data)
                theta_old = Variable(xmin.float())
            vector_to_parameters(theta_old, self._params_old)
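        # theta_old is a lagged copy of the parameters at which curvature-vector products are
        # evaluated. It is refreshed only when the k=1 Lanczos estimate of the Fisher spectral
        # norm reaches a new maximum: either nudged toward theta by an EMA (assume_locally_linear)
        # or repositioned by a randomized line search between theta_old and theta.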

        # Now that theta_old has been updated, do CG with only theta old
        # If not max norm, then this will remain the old params.
        if curv_type == 'fisher':
            fvp_fn_div_beta2 = make_fvp_fun(closure,
                                            self._params_old,
                                            bias_correction2=bias_correction2)
        elif curv_type == 'gauss_newton':
            fvp_fn_div_beta2 = make_gnvp_fun(closure,
                                             self._params_old,
                                             bias_correction2=bias_correction2)

        shrinkage_method = self._param_group['shrinkage_method']
        lanczos_amortization = self._param_group['lanczos_amortization']
        if shrinkage_method == 'lanczos' and (state['step'] -
                                              1) % lanczos_amortization == 0:
            # print ("Computing Lanczos shrinkage at step ", state['step'])
            w = lanczos_iteration(fvp_fn_div_beta2,
                                  self._numel(),
                                  k=self._param_group['lanczos_iters'])
            rho, diag_shrunk = estimate_shrinkage(
                w, self._numel(), self._param_group['batch_size'])
            state['rho'] = rho
            state['diag_shrunk'] = diag_shrunk

        M = None
        if self._param_group['cg_precondition_empirical']:
            # Empirical Fisher diagonal is g * g
            V = state['M']
            Mt = (g * g + self._param_group['cg_precondition_regu_coef'] *
                  torch.ones_like(g))**self._param_group['cg_precondition_exp']
            Vhat = V.mul(beta2).add(1 - beta2, Mt) / bias_correction2
            # Keep the element-wise maximum (AMSGrad-style) and store it back so the
            # preconditioner actually accumulates across steps
            state['M'] = torch.max(V, Vhat)
            M = state['M']

        extract_tridiag = self._param_group['shrinkage_method'] == 'cg'
        cg_result = cg_solve(
            fvp_fn_div_beta2,
            g_hat.data.clone(),
            x_0=self._param_group['cg_prev_init_coef'] * state['ng_prior'],
            M=M,
            cg_iters=self._param_group['cg_iters'],
            cg_residual_tol=self._param_group['cg_residual_tol'],
            shrunk=self._param_group['shrinkage_method'] is not None,
            rho=state['rho'],
            Dshrunk=state['diag_shrunk'],
            extract_tridiag=extract_tridiag)

        if extract_tridiag:
            # print ("Computing CG shrinkage at step ", state['step'])
            ng, (diag_elems, off_diag_elems) = cg_result
            w = eigvalsh_tridiagonal(diag_elems, off_diag_elems)
            rho, diag_shrunk = estimate_shrinkage(
                w, self._numel(), self._param_group['batch_size'])
            state['rho'] = rho
            state['diag_shrunk'] = diag_shrunk
        else:
            ng = cg_result

        self.state['ng_prior'] = ng.data.clone()

        # Normalize NG
        lr = self._param_group['lr']
        alpha = torch.sqrt(torch.abs(lr / (torch.dot(g_hat, ng) + 1e-20)))

        # Unflatten grad
        vector_to_gradients(ng, self._params)

        if execute_update:
            # Apply step
            for p in self._params:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                p.data.add_(-alpha, d_p)

        return dict(alpha=alpha, delta=lr, natural_grad=ng)
Example #4
    def step(self, closure, execute_update=True):
        """Performs a single optimization step.

        Arguments:
            closure (callable): A closure that accepts a vector of parameters and a vector of
                length equal to the number of model parameters and returns the Fisher-vector product.
        """

        # Update theta old for all blocks first, only approx update is supported
        params_i = 0
        params_j = 0

        for gi, group in enumerate(self.param_groups):
            params = group['params']
            params_j += len(params)

            num_params = self._numel(gi, params)
            # print ("num_params: ", num_params, params_i, params_j)

            state = self.state[gi]
            if len(state) == 0:
                state['step'] = 0
                # Exponential moving average of gradient values
                state['m'] = torch.zeros(num_params)
                # Maintain adaptive preconditioner if needed
                if group['cg_precondition_empirical']:
                    state['M'] = torch.zeros(num_params)
                # Set shrinkage to defaults, i.e. no shrinkage
                state['rho'] = 0.0
                state['diag_shrunk'] = 1.0
                state['lagged'] = []
                for i in range(len(params)):
                    state['lagged'].append(params[i] +
                                           torch.randn(params[i].shape) *
                                           0.0001)

            beta1, beta2 = group['betas']

            theta = parameters_to_vector(params)
            theta_old = parameters_to_vector(state['lagged'])

            # Update theta_old beta2 portion towards theta
            theta_old = beta2 * theta_old + (1 - beta2) * theta
            vector_to_parameters(theta_old, state['lagged'])
            # print (theta_old)
            # input("")

        info = {}

        # If doing block diag, perform the update for each param group
        params_i = 0
        params_j = 0

        for gi, group in enumerate(self.param_groups):
            params = group['params']
            params_j += len(params)

            num_params = self._numel(gi, params)

            # NOTE: state is initialized above
            state = self.state[gi]

            m = state['m']
            beta1, beta2 = group['betas']
            state['step'] += 1
            params_old = state['lagged']  #

            bias_correction1 = 1 - beta1**state['step']
            bias_correction2 = 1 - beta2**state['step']

            # Get flat grad
            g = gradients_to_vector(params)

            # Update moving average mean
            m.mul_(beta1).add_(1 - beta1, g)
            g_hat = m / bias_correction1

            if 'ng_prior' not in state:
                state['ng_prior'] = torch.zeros_like(
                    g)  #g_hat) #g_hat.data.clone()

            curv_type = group['curv_type']
            if curv_type not in self.valid_curv_types:
                raise ValueError("Invalid curv_type.")

            # Now that theta_old has been updated, do CG with only theta old
            if curv_type == 'fisher':
                fvp_fn_div_beta2 = make_fvp_fun_idx(
                    closure,
                    params_old,
                    params_i,
                    params_j,
                    bias_correction2=bias_correction2)
            elif curv_type == 'gauss_newton':
                fvp_fn_div_beta2 = make_gnvp_fun(
                    closure, params_old, bias_correction2=bias_correction2)

            shrinkage_method = group['shrinkage_method']
            lanczos_amortization = group['lanczos_amortization']
            if shrinkage_method == 'lanczos' and (
                    state['step'] - 1) % lanczos_amortization == 0:
                # print ("Computing Lanczos shrinkage at step ", state['step'])
                w = lanczos_iteration(fvp_fn_div_beta2,
                                      num_params,
                                      k=group['lanczos_iters'])
                rho, diag_shrunk = estimate_shrinkage(w, num_params,
                                                      group['batch_size'])
                state['rho'] = rho
                state['diag_shrunk'] = diag_shrunk

            M = None
            if group['cg_precondition_empirical']:
                # Empirical Fisher diagonal is g * g
                V = state['M']
                Mt = (g * g + group['cg_precondition_regu_coef'] *
                      torch.ones_like(g))**group['cg_precondition_exp']
                Vhat = V.mul(beta2).add(1 - beta2, Mt) / bias_correction2
                # Keep the element-wise maximum (AMSGrad-style) and store it back so the
                # preconditioner actually accumulates across steps
                state['M'] = torch.max(V, Vhat)
                M = state['M']

            extract_tridiag = group['shrinkage_method'] == 'cg'
            cg_result = cg_solve(fvp_fn_div_beta2,
                                 g_hat.data.clone(),
                                 x_0=group['cg_prev_init_coef'] *
                                 state['ng_prior'],
                                 M=M,
                                 cg_iters=group['cg_iters'],
                                 cg_residual_tol=group['cg_residual_tol'],
                                 shrunk=group['shrinkage_method'] is not None,
                                 rho=state['rho'],
                                 Dshrunk=state['diag_shrunk'],
                                 extract_tridiag=extract_tridiag)

            if extract_tridiag:
                # print ("Computing CG shrinkage at step ", state['step'])
                ng, (diag_elems, off_diag_elems) = cg_result
                w = eigvalsh_tridiagonal(diag_elems, off_diag_elems)
                rho, diag_shrunk = estimate_shrinkage(w, num_params,
                                                      group['batch_size'])
                state['rho'] = rho
                state['diag_shrunk'] = diag_shrunk
            else:
                ng = cg_result
            # print ("NG: ", ng)

            state['ng_prior'] = ng.data.clone()

            # Normalize NG
            lr = group['lr']
            alpha = torch.sqrt(torch.abs(lr / (torch.dot(g_hat, ng) + 1e-20)))

            # Unflatten grad
            vector_to_gradients(ng, params)

            if execute_update:
                # Apply step
                for p in params:
                    if p.grad is None:
                        continue
                    d_p = p.grad.data
                    p.data.add_(-alpha, d_p)

            params_i = params_j
            info[gi] = dict(alpha=alpha, delta=lr, natural_grad=ng)

        return info
Example #5
    def step(self, loss):
        """Performs a single optimization step.

        Arguments:
            loss (torch.Tensor): Scalar loss from which the curvature matrix is computed
                via self.H.
        """
        state = self.state
        # State initialization
        if len(state) == 0:
            state['step'] = 0
            state['m'] = torch.zeros((self._numel(),))
            state['Ft'] = torch.zeros((self._numel(), self._numel()))

        state['step'] += 1

        # Get flat grad
        g = gradients_to_vector(self._params)

        # shrunk = self._param_group['shrunk']

        # Compute Fisher
        Gt = self.H(self._params, loss)

        if self.adaptive:
            m = state['m']
            Ft = state['Ft']

            beta1, beta2 = self._param_group['betas']
            bias_correction1 = 1 - beta1 ** state['step']
            bias_correction2 = 1 - beta2 ** state['step']

            m.mul_(beta1).add_(1 - beta1, g)
            g_hat = m / bias_correction1

            Ft.mul_(beta2).add_(1 - beta2, Gt)
            Ft_hat = Ft / bias_correction2

            ng = torch.pinverse(Ft_hat) @ g_hat
            H = Ft_hat
            # alpha = float(torch.sqrt(ng.dot(ng)) / (ng.view(-1, 1).t() @ Ft_hat @ ng.view(-1, 1)))
        else:
            ng = torch.pinverse(Gt) @ g
            H = Gt
            # alpha = float(torch.sqrt(ng.dot(ng)) / (ng.view(-1, 1).t() @ Gt @ ng.view(-1, 1)))
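        # Unlike the CG-based variants above, this solves for the natural gradient exactly via a
        # pseudoinverse of the dense curvature matrix, which costs O(n^3) time and O(n^2) memory
        # and is therefore only practical for small parameter counts.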

        lr = self._param_group['lr']
        alpha = torch.sqrt(torch.abs(lr / (torch.dot(g, ng) + 1e-20)))

        # alpha *= 0.1
        # Unflatten grad
        vector_to_gradients(ng, self._params)

        # Apply step
        for p in self._params:
            if p.grad is None:
                continue
            d_p = p.grad.data
            p.data.add_(-alpha, d_p)

        return dict(alpha=alpha, H=H.clone()) #, delta=lr)
Example #6
    def train_from_paths(self, paths):

        # Concatenate from all the trajectories
        observations = np.concatenate([path["observations"] for path in paths])
        actions = np.concatenate([path["actions"] for path in paths])
        advantages = np.concatenate([path["advantages"] for path in paths])
        # Advantage whitening
        advantages = (advantages - np.mean(advantages)) / (np.std(advantages) +
                                                           1e-6)
        # NOTE : advantage should be zero mean in expectation
        # normalized step size invariant to advantage scaling,
        # but scaling can help with least squares

        self.n_steps += len(advantages)

        # cache return distributions for the paths
        path_returns = [sum(p["rewards"]) for p in paths]
        mean_return = np.mean(path_returns)
        std_return = np.std(path_returns)
        min_return = np.amin(path_returns)
        max_return = np.amax(path_returns)
        base_stats = [mean_return, std_return, min_return, max_return]
        self.running_score = mean_return if self.running_score is None else \
                             0.9*self.running_score + 0.1*mean_return  # approx avg of last 10 iters
        if self.save_logs: self.log_rollout_statistics(paths)

        # Keep track of times for various computations
        t_gLL = 0.0
        t_FIM = 0.0

        # Optimization algorithm
        # --------------------------
        self.optim.zero_grad()

        surr_before = self.CPI_surrogate(observations, actions,
                                         advantages).data.numpy().ravel()[0]

        # VPG
        ts = timer.time()
        vpg_grad = self.flat_vpg(observations, actions, advantages)
        vector_to_gradients(Variable(torch.from_numpy(vpg_grad).float()),
                            self.policy.trainable_params)
        t_gLL += timer.time() - ts

        # NPG
        # Note: unlike the standard NPG, negation is not needed here since the optimizer does not
        # apply the update step.
        ts = timer.time()
        closure = self.kl_closure(self.policy, observations, actions,
                                  self.policy_kl_fn)
        info = self.optim.step(closure, execute_update=False)
        npg_grad = info['natural_grad'].data.numpy()
        t_FIM += timer.time() - ts

        # Step size computation
        # --------------------------
        n_step_size = 2.0 * self.kl_dist
        alpha = np.sqrt(
            np.abs(n_step_size / (np.dot(vpg_grad.T, npg_grad) + 1e-20)))
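        # Standard NPG step size: with KL trust region delta = self.kl_dist,
        # alpha = sqrt(2 * delta / (g^T F^-1 g)), where g^T F^-1 g is approximated by
        # np.dot(vpg_grad, npg_grad).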

        # Policy update
        # --------------------------
        curr_params = self.policy.get_param_values()
        for k in range(100):
            new_params = curr_params + alpha * npg_grad
            self.policy.set_param_values(new_params,
                                         set_new=True,
                                         set_old=False)
            kl_dist = self.kl_old_new(observations,
                                      actions).data.numpy().ravel()[0]
            surr_after = self.CPI_surrogate(observations, actions,
                                            advantages).data.numpy().ravel()[0]
            if kl_dist < self.kl_dist:
                break
            else:
                alpha = 0.9 * alpha  # backtrack
                # print("Step size too high. Backtracking. | kl = %f | surr diff = %f" % \
                # (kl_dist, surr_after-surr_before) )
            if k == 99:
                alpha = 0.0
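                # If the KL constraint cannot be met after 100 backtracking steps, fall back to
                # alpha = 0 and effectively skip the update for this iteration.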

        new_params = curr_params + alpha * npg_grad
        self.policy.set_param_values(new_params, set_new=True, set_old=False)
        kl_dist = self.kl_old_new(observations,
                                  actions).data.numpy().ravel()[0]
        surr_after = self.CPI_surrogate(observations, actions,
                                        advantages).data.numpy().ravel()[0]
        self.policy.set_param_values(new_params, set_new=True, set_old=True)

        # Log information
        if self.save_logs:
            self.logger.log_kv('alpha', alpha)
            self.logger.log_kv('delta', n_step_size)
            self.logger.log_kv('time_vpg', t_gLL)
            self.logger.log_kv('time_npg', t_FIM)
            self.logger.log_kv('kl_dist', kl_dist)
            self.logger.log_kv('surr_improvement', surr_after - surr_before)
            self.logger.log_kv('running_score', self.running_score)
            self.logger.log_kv('steps', self.n_steps)

        return base_stats
Example #7
    def step(self, closure, execute_update=True):
        """Performs a single optimization step.

        Arguments:
            closure (callable): A closure that accepts a vector of parameters and a vector of
                length equal to the number of model parameters and returns the Fisher-vector product.
        """
        info = {}

        # If doing block diag, perform the update for each param group
        params_i = 0
        params_j = 0

        for gi, group in enumerate(self.param_groups):
            params = group['params']
            params_j += len(params)

            state = self.state[gi]
            if len(state) == 0:
                state['step'] = 0
                # Set shrinkage to defaults, i.e. no shrinkage
                state['rho'] = 0.0
                state['diag_shrunk'] = 1.0

            state['step'] += 1

            g = gradients_to_vector(params)

            if 'ng_prior' not in state:
                state['ng_prior'] = torch.zeros_like(g)

            curv_type = group['curv_type']
            if curv_type not in self.valid_curv_types:
                raise ValueError("Invalid curv_type.")

            # Create closure to pass to Lanczos and CG
            if curv_type == 'fisher':
                Fvp_theta_fn = make_fvp_fun_idx(closure, params, params_i,
                                                params_j)
            elif curv_type == 'gauss_newton':
                # Pass indices instead of actual params, since these params should be the same as
                # the model params anyway. Then the closure should set only the subset of params
                # and only return the tmp_params from that subset.
                # This would require that the param groups are ordered in a specific manner?
                Fvp_theta_fn = make_gnvp_fun_idx(closure, params, params_i,
                                                 params_j)

            num_params = self._numel(gi, params)

            shrinkage_method = group['shrinkage_method']
            lanczos_amortization = group['lanczos_amortization']
            if shrinkage_method == 'lanczos' and (
                    state['step'] - 1) % lanczos_amortization == 0:
                # print ("Computing Lanczos shrinkage at step ", state['step'])
                w = lanczos_iteration(Fvp_theta_fn,
                                      num_params,
                                      k=group['lanczos_iters'])
                rho, diag_shrunk = estimate_shrinkage(w, num_params,
                                                      group['batch_size'])
                state['rho'] = rho
                state['diag_shrunk'] = diag_shrunk

            M = None
            if group['cg_precondition_empirical']:
                # Empirical Fisher is g * g
                M = (g * g + group['cg_precondition_regu_coef'] *
                     torch.ones_like(g))**group['cg_precondition_exp']

            # Do CG solve with hvp fn closure
            extract_tridiag = group['shrinkage_method'] == 'cg'
            cg_result = cg_solve(Fvp_theta_fn,
                                 g.data.clone(),
                                 x_0=group['cg_prev_init_coef'] *
                                 state['ng_prior'],
                                 M=M,
                                 cg_iters=group['cg_iters'],
                                 cg_residual_tol=group['cg_residual_tol'],
                                 shrunk=group['shrinkage_method'] is not None,
                                 rho=state['rho'],
                                 Dshrunk=state['diag_shrunk'],
                                 extract_tridiag=extract_tridiag)

            if extract_tridiag:
                # print ("Computing CG shrinkage at step ", state['step'])
                ng, (diag_elems, off_diag_elems) = cg_result
                w = eigvalsh_tridiagonal(diag_elems, off_diag_elems)
                rho, diag_shrunk = estimate_shrinkage(w, num_params,
                                                      group['batch_size'])
                state['rho'] = rho
                state['diag_shrunk'] = diag_shrunk
            else:
                ng = cg_result

            state['ng_prior'] = ng.data.clone()

            # Normalize NG
            lr = group['lr']
            alpha = torch.sqrt(torch.abs(lr / (torch.dot(g, ng) + 1e-20)))

            # Unflatten grad
            vector_to_gradients(ng, params)

            if execute_update:
                # Apply step
                for p in params:
                    if p.grad is None:
                        continue
                    d_p = p.grad.data
                    p.data.add_(-alpha, d_p)

            params_i = params_j
            info[gi] = dict(alpha=alpha, delta=lr, natural_grad=ng)

        return info