Example #1
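This example implements LBFGSNoisyOptimizer, a BaseOptimizer subclass that runs stochastic L-BFGS steps against a noisy BaseConditionalGenerationOracle, with an optional Wolfe or Armijo line search and three step-length scaling rules ("None", "Grad", "Dim").
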
# BaseOptimizer, BaseConditionalGenerationOracle, SUCCESS, COMP_ERROR and an
# LBFGS class exposing two_loop_recursion() / curvature_update() are assumed
# to come from the surrounding project and are not shown here.
import copy
import time

import numpy as np
import torch
from scipy.stats import chi2


class LBFGSNoisyOptimizer(BaseOptimizer):
    def __init__(self,
                 oracle: BaseConditionalGenerationOracle,
                 x: torch.Tensor,
                 lr: float = 1e-1,
                 memory_size: int = 5,
                 line_search='Wolfe',
                 lr_algo='None',
                 *args,
                 **kwargs):
        super().__init__(oracle, x, *args, **kwargs)
        self._line_search = line_search
        self._lr = lr
        self._alpha_k = None
        self._lr_algo = lr_algo  # one of "None", "Grad", "Dim"
        if lr_algo not in ("None", "Grad", "Dim"):
            raise ValueError("lr_algo must be one of 'None', 'Grad' or 'Dim'")
        if self._x_step:
            self._optimizer = LBFGS(params=[self._x],
                                    lr=self._x_step / 10.,
                                    line_search=line_search,
                                    history_size=memory_size)
        else:
            self._optimizer = LBFGS(params=[self._x],
                                    lr=self._lr,
                                    line_search=line_search,
                                    history_size=memory_size)

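    # One optimization step: rebind the current point x_k, pick the step length
    # according to lr_algo, build the L-BFGS direction with the two-loop
    # recursion, run the selected line search, then update the curvature pairs
    # with the gradient at the new point.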
    def _step(self):
        x_k = self._x.detach().clone()
        x_k.requires_grad_(True)
        self._optimizer.param_groups[0]['params'][0] = x_k
        init_time = time.time()
        f_k = self._oracle.func(x_k, num_repetitions=self._num_repetitions)
        g_k = self._oracle.grad(x_k, num_repetitions=self._num_repetitions)
        grad_normed = g_k  # (g_k / g_k.norm())
        self._state_dict = copy.deepcopy(self._optimizer.state_dict())

        # Scale the step length according to lr_algo:
        #   "None" -- use self._x_step directly;
        #   "Grad" -- divide it by the gradient norm;
        #   "Dim"  -- divide it by sqrt(chi2.ppf(0.95, df=d)), the radius of a ball
        #            containing 95% of a standard Gaussian in d = len(g_k) dimensions.
        if self._lr_algo == "None":
            self._optimizer.param_groups[0]['lr'] = self._x_step
        elif self._lr_algo == "Grad":
            self._optimizer.param_groups[0]['lr'] = (
                self._x_step / g_k.norm().item())
        elif self._lr_algo == "Dim":
            self._optimizer.param_groups[0]['lr'] = (
                self._x_step / np.sqrt(chi2.ppf(0.95, df=len(g_k))))
        # Closure used by the line search to re-evaluate the noisy objective at x_k.
        def closure():
            self._optimizer.zero_grad()
            loss = self._oracle.func(x_k,
                                     num_repetitions=self._num_repetitions)
            return loss

        # two-loop recursion to compute search direction
        p = self._optimizer.two_loop_recursion(-grad_normed)
        options = {
            'closure': closure,
            'current_loss': f_k,
            'interpolate': False
        }
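        # Take the L-BFGS step with the chosen line search; step() returns a tuple
        # whose contents depend on the line-search mode, so it is unpacked per branch.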
        if self._line_search == 'Wolfe':
            lbfg_opt = self._optimizer.step(p, grad_normed, options=options)
            f_k, d_k, lr = lbfg_opt[0], lbfg_opt[1], lbfg_opt[2]
        elif self._line_search == 'Armijo':
            lbfg_opt = self._optimizer.step(p, grad_normed, options=options)
            f_k, lr = lbfg_opt[0], lbfg_opt[1]
            d_k = -g_k
        elif self._line_search == 'None':
            # self._optimizer.param_groups[0]['lr'] = 1.
            d_k = -g_k
            lbfg_opt = self._optimizer.step(p, grad_normed, options=options)
            lr = lbfg_opt
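        # Re-evaluate the gradient at the updated point and feed it back so that
        # the L-BFGS curvature pairs are refreshed.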
        g_k = self._oracle.grad(x_k, num_repetitions=self._num_repetitions)
        grad_normed = g_k  # (g_k / g_k.norm())
        self._optimizer.curvature_update(grad_normed, eps=0.2, damping=False)
        self._lbfg_opt = lbfg_opt
        grad_norm = d_k.norm().item()  # norm of d_k, used in the stopping test below
        self._x = x_k

        super()._post_step(init_time=init_time)

        if grad_norm < self._tolerance:
            return SUCCESS
        if not (torch.isfinite(x_k).all() and torch.isfinite(f_k).all()
                and torch.isfinite(d_k).all()):
            return COMP_ERROR

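    # Roll the optimizer back to the L-BFGS state snapshotted at the beginning
    # of the most recent _step call.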
    def reverse_optimizer(self, **kwargs):
        self._optimizer.load_state_dict(self._state_dict)
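
Below is a minimal, hypothetical usage sketch. SomeNoisyOracle, the 10-dimensional starting point and the manual _step() loop are assumptions made for illustration only; the real project may drive the optimizer through a different entry point defined on BaseOptimizer.

# Hypothetical usage sketch -- SomeNoisyOracle and the manual _step() loop are
# illustrative assumptions, not part of the original example.
oracle = SomeNoisyOracle()                     # any BaseConditionalGenerationOracle
x0 = torch.zeros(10, requires_grad=True)       # starting point

optimizer = LBFGSNoisyOptimizer(oracle=oracle,
                                x=x0,
                                lr=1e-1,
                                memory_size=5,
                                line_search='Wolfe',
                                lr_algo='Dim')

for _ in range(100):
    status = optimizer._step()   # SUCCESS, COMP_ERROR, or None to keep iterating
    if status is not None:
        break
    # optimizer.reverse_optimizer() would restore the pre-step L-BFGS state if a
    # step had to be rejected.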