def _eval(vector):
    """The evaluation function.

    Args:
        vector (torch.Tensor): The vector to be multiplied with the
            Hessian.

    Returns:
        torch.Tensor: The product of the Hessian of function f and the
            given vector.

    """
    unflatten_vector = unflatten_tensors(vector, param_shapes)

    assert len(f_grads) == len(unflatten_vector)

    # Scalar g . v, built so that differentiating it w.r.t. the parameters
    # yields the Hessian-vector product H v without forming H.
    grad_vector_product = torch.sum(
        torch.stack(
            [torch.sum(g * x) for g, x in zip(f_grads, unflatten_vector)]))

    # allow_unused=True so parameters that do not affect the gradient come
    # back as None (zero-filled below) instead of raising.
    hvp = list(
        torch.autograd.grad(grad_vector_product, params,
                            retain_graph=True, allow_unused=True))
    for i, (hx, p) in enumerate(zip(hvp, params)):
        if hx is None:
            hvp[i] = torch.zeros_like(p)

    flat_output = torch.cat([h.reshape(-1) for h in hvp])
    # Regularize: return (H + reg_coeff * I) v for numerical stability.
    return flat_output + reg_coeff * vector
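
# A minimal, self-contained sketch of the Hessian-vector-product trick that
# `_eval` relies on (all names below are illustrative, not from this module):
# with g = grad(f) built via create_graph=True, differentiating the scalar
# g . v yields H v without ever materializing the Hessian H.
import torch


def _hvp_sketch():
    param = torch.randn(3, requires_grad=True)
    f = (param ** 2).sum()  # toy objective with known Hessian H = 2 I
    (g,) = torch.autograd.grad(f, [param], create_graph=True)
    v = torch.randn(3)
    grad_vector_product = torch.sum(g * v)  # scalar g . v
    (hv,) = torch.autograd.grad(grad_vector_product, [param])
    assert torch.allclose(hv, 2 * v)  # H v = 2 v for this objective
    return hv
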
def set_param_values(self, param_values):
    """Set param values.

    Args:
        param_values (np.ndarray): A numpy array of parameter values.

    """
    param_values = unflatten_tensors(param_values,
                                     self.get_param_shapes())
    for param, value in zip(self.get_params(), param_values):
        param.load(value)
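
# A minimal, self-contained sketch (assuming the TF1-style API that
# `param.load(value)` above suggests): `tf.Variable.load` copies a numpy
# value into an existing variable inside a session, adding no new graph ops.
import numpy as np
import tensorflow as tf


def _load_sketch():
    tf.compat.v1.disable_eager_execution()
    v = tf.compat.v1.get_variable('load_sketch_v', shape=(2, 2))
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        v.load(np.ones((2, 2)), session=sess)
        assert (sess.run(v) == 1.0).all()
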
def flat_to_params(self, flattened_params):
    """Unflatten tensors according to their respective shapes.

    Args:
        flattened_params (np.ndarray): A numpy array of flattened params.

    Returns:
        List[np.ndarray]: A list of parameters reshaped to the shapes
            specified.

    """
    return unflatten_tensors(flattened_params, self.get_param_shapes())
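
# A minimal sketch of what the `unflatten_tensors` helper used above is
# assumed to do (the real implementation lives elsewhere in this codebase):
# split a flat array into pieces and reshape each to its original shape.
import numpy as np


def _unflatten_tensors_sketch(flattened, tensor_shapes):
    sizes = [int(np.prod(shape)) for shape in tensor_shapes]
    indices = np.cumsum(sizes)[:-1]
    pieces = np.split(flattened, indices)
    return [np.reshape(piece, shape)
            for piece, shape in zip(pieces, tensor_shapes)]


# Round trip: flatten two arrays, then recover their shapes.
_flat = np.arange(10.0)
_a, _b = _unflatten_tensors_sketch(_flat, [(2, 3), (4,)])
assert _a.shape == (2, 3) and _b.shape == (4,)
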
def _backtracking_line_search(self, params, descent_step, f_loss,
                              f_constraint):
    """Backtracking line search.

    Shrinks the descent step geometrically until the loss improves and
    the constraint is satisfied, or the backtrack budget is exhausted.

    Args:
        params (list[torch.Tensor]): The parameters to update in place.
        descent_step (torch.Tensor): Flattened full descent step.
        f_loss (callable): Returns the current loss.
        f_constraint (callable): Returns the current constraint value.

    """
    prev_params = [p.clone() for p in params]
    ratio_list = self._backtrack_ratio**np.arange(self._max_backtracks)
    loss_before = f_loss()

    # Scalar parameters report an empty shape; treat them as size-1.
    param_shapes = [p.shape or torch.Size([1]) for p in params]
    descent_step = unflatten_tensors(descent_step, param_shapes)
    assert len(descent_step) == len(params)

    for ratio in ratio_list:
        for step, prev_param, param in zip(descent_step, prev_params,
                                           params):
            param.data = prev_param.data - ratio * step

        loss = f_loss()
        constraint_val = f_constraint()
        if (loss < loss_before
                and constraint_val <= self._max_constraint_value):
            break

    if ((torch.isnan(loss) or torch.isnan(constraint_val)
         or loss >= loss_before
         or constraint_val >= self._max_constraint_value)
            and not self._accept_violation):
        logger.log('Line search condition violated. Rejecting the step!')
        if torch.isnan(loss):
            logger.log('Violated because loss is NaN')
        if torch.isnan(constraint_val):
            logger.log('Violated because constraint is NaN')
        if loss >= loss_before:
            logger.log('Violated because loss not improving')
        if constraint_val >= self._max_constraint_value:
            logger.log('Violated because constraint is violated')
        # Restore the parameters from before the line search.
        for prev, cur in zip(prev_params, params):
            cur.data = prev.data
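
# A minimal, self-contained sketch (toy problem; all names assumed) of the
# backtracking scheme implemented above: try geometrically shrinking steps
# until the loss improves and the constraint holds, otherwise keep the old
# parameters.
import numpy as np


def _backtracking_sketch(x, full_step, f_loss, f_constraint,
                         max_constraint_value, backtrack_ratio=0.8,
                         max_backtracks=15):
    loss_before = f_loss(x)
    for ratio in backtrack_ratio**np.arange(max_backtracks):
        x_new = x - ratio * full_step
        if (f_loss(x_new) < loss_before
                and f_constraint(x_new) <= max_constraint_value):
            return x_new
    return x  # reject the step, mirroring the restore branch above


# Toy quadratic loss with a bound on |x| as the constraint.
_x = _backtracking_sketch(
    np.array([2.0]), np.array([3.0]),
    f_loss=lambda x: float(x @ x),
    f_constraint=lambda x: float(abs(x[0])),
    max_constraint_value=2.0)
assert float(_x[0]) == -1.0
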