class LBFGSNoisyOptimizer(BaseOptimizer):
    """L-BFGS optimizer driven by an oracle that exposes ``func`` and ``grad``.

    Each step evaluates the oracle objective and gradient at the current point
    (averaged over ``self._num_repetitions`` -- NOTE(review): repetition
    semantics live in the oracle, confirm), builds a search direction with the
    wrapped ``LBFGS`` object's two-loop recursion, steps with the configured
    line search ('Wolfe', 'Armijo' or 'None'), and updates the curvature pairs.
    """

    # Accepted values for ``lr_algo`` and ``line_search``.
    _LR_ALGOS = ("None", "Grad", "Dim")
    _LINE_SEARCHES = ("Wolfe", "Armijo", "None")

    def __init__(self,
                 oracle: BaseConditionalGenerationOracle,
                 x: torch.Tensor,
                 lr: float = 1e-1,
                 memory_size: int = 5,
                 line_search='Wolfe',
                 lr_algo='None',
                 *args, **kwargs):
        """Create the optimizer and its underlying ``LBFGS`` instance.

        Args:
            oracle: object providing ``func(x, num_repetitions=...)`` and
                ``grad(x, num_repetitions=...)``.
            x: starting point; registered as the single ``LBFGS`` parameter.
            lr: fallback step size, used whenever ``self._x_step`` is falsy.
            memory_size: ``LBFGS`` history size.
            line_search: 'Wolfe', 'Armijo' or 'None'; forwarded to ``LBFGS``.
            lr_algo: how the per-step learning rate is derived from the base
                step size: 'None' (unchanged), 'Grad' (divided by the gradient
                norm) or 'Dim' (divided by sqrt of the 0.95 chi-square quantile
                with df = number of gradient elements).

        Raises:
            ValueError: if ``lr_algo`` is not one of 'None', 'Grad', 'Dim'.
        """
        super().__init__(oracle, x, *args, **kwargs)
        self._line_search = line_search
        self._lr = lr
        self._alpha_k = None
        self._lr_algo = lr_algo
        if lr_algo not in self._LR_ALGOS:
            raise ValueError("lr_algo is not right")
        # ``self._x_step`` is presumably set by BaseOptimizer.__init__ -- the
        # initial LBFGS step is a tenth of it when set, otherwise ``lr``.
        init_lr = self._x_step / 10. if self._x_step else self._lr
        self._optimizer = LBFGS(params=[self._x],
                                lr=init_lr,
                                line_search=line_search,
                                history_size=memory_size)

    def _base_step_size(self):
        """Return ``self._x_step`` when set (truthy), else ``self._lr``.

        Mirrors the fallback used in ``__init__`` so an unset ``_x_step`` never
        turns into a ``None`` learning rate.
        """
        return self._x_step if self._x_step else self._lr

    def _step_lr(self, g_k):
        """Compute this step's learning rate from the gradient ``g_k``."""
        base = self._base_step_size()
        if self._lr_algo == "Grad":
            g_norm = g_k.norm().item()
            # A zero gradient gives a zero direction; avoid dividing by zero.
            return base / g_norm if g_norm > 0 else base
        if self._lr_algo == "Dim":
            # numel() counts every element, also for multi-dim gradients.
            return base / np.sqrt(chi2.ppf(0.95, df=g_k.numel()))
        return base

    def _step(self):
        """Perform one L-BFGS iteration.

        Returns:
            COMP_ERROR if the iterate, objective or convergence vector is
            non-finite; SUCCESS if the convergence measure (norm of ``d_k``)
            is below ``self._tolerance``; otherwise None.

        Raises:
            ValueError: if ``line_search`` is not 'Wolfe', 'Armijo' or 'None'.
        """
        if self._line_search not in self._LINE_SEARCHES:
            # Fail before touching optimizer state.
            raise ValueError(
                "Unknown line_search: {!r}".format(self._line_search))

        x_k = self._x.detach().clone()
        x_k.requires_grad_(True)
        # Swap the fresh leaf tensor in; the LBFGS step presumably updates it
        # in place (it is re-evaluated below and stored as the new iterate).
        self._optimizer.param_groups[0]['params'][0] = x_k
        init_time = time.time()
        f_k = self._oracle.func(x_k, num_repetitions=self._num_repetitions)
        g_k = self._oracle.grad(x_k, num_repetitions=self._num_repetitions)

        # Snapshot the optimizer so reverse_optimizer() can undo this step.
        self._state_dict = copy.deepcopy(self._optimizer.state_dict())
        self._optimizer.param_groups[0]['lr'] = self._step_lr(g_k)

        def closure():
            """Re-evaluate the objective at x_k for the line search."""
            self._optimizer.zero_grad()
            return self._oracle.func(x_k,
                                     num_repetitions=self._num_repetitions)

        # Two-loop recursion turns the steepest-descent direction into the
        # L-BFGS search direction.
        p = self._optimizer.two_loop_recursion(-g_k)
        options = {
            'closure': closure,
            'current_loss': f_k,
            'interpolate': False,
        }
        lbfg_opt = self._optimizer.step(p, g_k, options=options)
        if self._line_search == 'Wolfe':
            # Assumes the Wolfe step returns (new loss, new gradient, ...)
            # -- TODO confirm against the LBFGS implementation.
            f_k, d_k = lbfg_opt[0], lbfg_opt[1]
        elif self._line_search == 'Armijo':
            # Assumes the Armijo step returns (new loss, ...) -- TODO confirm.
            f_k = lbfg_opt[0]
            d_k = -g_k
        else:  # 'None'
            d_k = -g_k

        # Gradient at the updated point feeds the curvature (s, y) pair.
        g_k = self._oracle.grad(x_k, num_repetitions=self._num_repetitions)
        self._optimizer.curvature_update(g_k, eps=0.2, damping=False)
        self._lbfg_opt = lbfg_opt
        grad_norm = d_k.norm().item()
        self._x = x_k
        super()._post_step(init_time=init_time)

        # Numerical failure takes precedence over a (spurious) convergence.
        if not (torch.isfinite(x_k).all() and
                torch.isfinite(f_k).all() and
                torch.isfinite(d_k).all()):
            return COMP_ERROR
        if grad_norm < self._tolerance:
            return SUCCESS

    def reverse_optimizer(self, **kwargs):
        """Restore the optimizer state saved at the start of the last step."""
        self._optimizer.load_state_dict(self._state_dict)