def eval(x):
    """Evaluate the regularized ``f_hx_plain`` function at flat vector ``x``.

    ``self`` and ``inputs`` are free variables taken from the enclosing
    scope, which is not shown here.

    Args:
        x: Flat vector. It is converted to a tuple of parameter-shaped
            values with ``self.target.flat_to_params`` before the call,
            and is also used directly in the regularization term.

    Returns:
        The result of the ``f_hx_plain`` function called with
        ``(inputs, params)`` plus ``self.reg_coeff * x``.
    """
    params = tuple(self.target.flat_to_params(x, trainable=True))
    # presumably a Hessian-vector product evaluated slice by slice;
    # confirm against sliced_fun and the f_hx_plain definition
    hx_plain = sliced_fun(self.opt_fun["f_hx_plain"], self._num_slices)
    return hx_plain(inputs, params) + self.reg_coeff * x
def optimize(self, inputs, extra_inputs=None, subsample_grouped_inputs=None):
    """Apply one line-searched update to the target's trainable parameters.

    The gradient returned by ``f_grad`` and the callable built by
    ``self._hvp_approach.build_eval`` are handed to ``krylov.cg`` to get a
    descent direction. The direction is scaled using
    ``self._max_constraint_val`` and then tried at the ratios
    ``self._backtrack_ratio ** k`` for ``k`` in
    ``range(self._max_backtracks)``. The search stops at the first
    candidate whose loss is below the starting loss and whose constraint
    value does not exceed ``self._max_constraint_val``. If the last
    candidate tried fails that test and ``self._accept_violation`` is
    false, the previous parameters are restored.

    Args:
        inputs: Sequence of inputs; converted to a tuple and passed as the
            first argument of the sliced functions. When subsampling, each
            element is indexed with an index array (``x[inds]``), so the
            elements are presumably arrays sharing a leading sample axis.
        extra_inputs: Optional tuple passed as the second argument of the
            sliced functions. Defaults to an empty tuple.
        subsample_grouped_inputs: Optional iterable of input groups, only
            read when ``self._subsample_factor < 1``. Each group is
            subsampled with its own random indices. Defaults to
            ``[inputs]``.

    Returns:
        None. The result is the side effect on ``self._target``.
    """
    inputs = tuple(inputs)
    if extra_inputs is None:
        extra_inputs = tuple()

    if self._subsample_factor < 1:
        if subsample_grouped_inputs is None:
            subsample_grouped_inputs = [inputs]
        subsample_inputs = tuple()
        for inputs_grouped in subsample_grouped_inputs:
            # The first element of the group determines the sample count.
            n_samples = len(inputs_grouped[0])
            inds = np.random.choice(
                n_samples,
                int(n_samples * self._subsample_factor),
                replace=False)
            subsample_inputs += tuple(x[inds] for x in inputs_grouped)
    else:
        subsample_inputs = inputs

    logger.log("computing loss before")
    loss_before = sliced_fun(self._opt_fun["f_loss"],
                             self._num_slices)(inputs, extra_inputs)
    logger.log("performing update")
    logger.log("computing descent direction")

    flat_g = sliced_fun(self._opt_fun["f_grad"],
                        self._num_slices)(inputs, extra_inputs)
    # Only the (possibly subsampled) inputs feed the curvature callable;
    # loss and gradient above use the full inputs.
    hx = self._hvp_approach.build_eval(subsample_inputs + extra_inputs)
    descent_direction = krylov.cg(hx, flat_g, cg_iters=self._cg_iters)

    # step = sqrt(2 * max_constraint / (d . hx(d) + 1e-8)); the 1e-8 keeps
    # the denominator from being exactly zero when d . hx(d) is zero.
    initial_step_size = np.sqrt(
        2.0 * self._max_constraint_val *
        (1. / (descent_direction.dot(hx(descent_direction)) + 1e-8)))
    # A NaN step size (e.g. a negative value under the root) falls back to
    # a unit step.
    if np.isnan(initial_step_size):
        initial_step_size = 1.
    flat_descent_step = initial_step_size * descent_direction
    logger.log("descent direction computed")

    prev_param = np.copy(self._target.get_param_values(trainable=True))
    n_iter = 0
    # NOTE(review): if self._max_backtracks is 0 the loop body never runs
    # and `loss` / `constraint_val` are never assigned, so the check after
    # the loop raises NameError. Confirm callers always pass >= 1.
    for n_iter, ratio in enumerate(
            self._backtrack_ratio**np.arange(self._max_backtracks)):  # yapf: disable
        cur_step = ratio * flat_descent_step
        cur_param = prev_param - cur_step
        self._target.set_param_values(cur_param, trainable=True)
        loss, constraint_val = sliced_fun(
            self._opt_fun["f_loss_constraint"],
            self._num_slices)(inputs, extra_inputs)
        if loss < loss_before \
           and constraint_val <= self._max_constraint_val:
            break

    # The rejection test is the exact complement of the acceptance test
    # above. The constraint comparison is strict (`>`): a step that lands
    # exactly on the bound is accepted by the search and must not be
    # rolled back afterwards.
    violated = (np.isnan(loss) or np.isnan(constraint_val)
                or loss >= loss_before
                or constraint_val > self._max_constraint_val)
    if violated and not self._accept_violation:
        logger.log("Line search condition violated. Rejecting the step!")
        if np.isnan(loss):
            logger.log("Violated because loss is NaN")
        if np.isnan(constraint_val):
            logger.log("Violated because constraint %s is NaN" %
                       self._constraint_name)
        if loss >= loss_before:
            logger.log("Violated because loss not improving")
        if constraint_val > self._max_constraint_val:
            logger.log("Violated because constraint %s is violated" %
                       self._constraint_name)
        self._target.set_param_values(prev_param, trainable=True)

    logger.log("backtrack iters: %d" % n_iter)
    logger.log("computing loss after")
    logger.log("optimization finished")
def loss(self, inputs, extra_inputs=None):
    """Evaluate the ``f_loss`` function on the given inputs.

    Args:
        inputs: Sequence of inputs; converted to a tuple and passed as the
            first argument of the sliced function.
        extra_inputs: Optional second argument of the sliced function.
            Defaults to an empty tuple.

    Returns:
        Whatever ``sliced_fun(self._opt_fun["f_loss"], self._num_slices)``
        returns for ``(inputs, extra_inputs)``.
    """
    sliced_args = tuple(inputs)
    fixed_args = tuple() if extra_inputs is None else extra_inputs
    loss_fn = sliced_fun(self._opt_fun["f_loss"], self._num_slices)
    return loss_fn(sliced_args, fixed_args)
def constraint_val(self, inputs, extra_inputs=None):
    """Evaluate the ``f_constraint`` function on the given inputs.

    Args:
        inputs: Sequence of inputs; converted to a tuple and passed as the
            first argument of the sliced function.
        extra_inputs: Optional second argument of the sliced function.
            Defaults to an empty tuple.

    Returns:
        Whatever
        ``sliced_fun(self._opt_fun["f_constraint"], self._num_slices)``
        returns for ``(inputs, extra_inputs)``.
    """
    sliced_args = tuple(inputs)
    fixed_args = tuple() if extra_inputs is None else extra_inputs
    constraint_fn = sliced_fun(self._opt_fun["f_constraint"],
                               self._num_slices)
    return constraint_fn(sliced_args, fixed_args)
def optimize(self,
             inputs,
             gradient,
             extra_inputs=None,
             subsample_grouped_inputs=None,
             name=None):
    """Apply one line-searched update using a caller-supplied gradient.

    ``gradient`` and the callable built by
    ``self._hvp_approach.build_eval`` are handed to ``krylov.cg`` to get a
    descent direction. The direction is scaled using
    ``self._max_constraint_val`` and then tried at the ratios
    ``self._backtrack_ratio ** k`` for ``k`` in
    ``range(self._max_backtracks)``. The search stops at the first
    candidate whose loss is below the starting loss and whose constraint
    value does not exceed ``self._max_constraint_val``. If the last
    candidate tried violates the conditions checked after the loop and
    ``self._accept_violation`` is false, the previous parameters are
    restored.

    Args:
        inputs: Sequence of inputs; converted to a tuple. ``len(inputs[0])``
            is reported in the start-up log line.
        gradient: Flat gradient passed to ``krylov.cg`` as the right-hand
            side.
        extra_inputs: Optional tuple concatenated after ``inputs`` for the
            loss and loss/constraint evaluations. Defaults to an empty
            tuple. It is not passed to ``build_eval``.
        subsample_grouped_inputs: Accepted but not read by this method.
        name: Accepted but not read by this method.

    Returns:
        None. The result is the side effect on ``self._target``.
    """
    prev_param = np.copy(self._target.get_param_values(trainable=True))
    inputs = tuple(inputs)
    if extra_inputs is None:
        extra_inputs = tuple()

    # No subsampling here: the curvature callable sees the full inputs.
    sample_inputs = inputs

    logger.log(("Start CG optimization: "
                "#parameters: %d, #inputs: %d, #subsample_inputs: %d") %
               (len(prev_param), len(inputs[0]), len(sample_inputs[0])))
    logger.log("performing update")
    logger.log("computing descent direction")

    hx = self._hvp_approach.build_eval(sample_inputs)
    descent_direction = krylov.cg(hx, gradient, cg_iters=self._cg_iters)

    # step = sqrt(2 * max_constraint / (d . hx(d) + 1e-8)); the 1e-8 keeps
    # the denominator from being exactly zero when d . hx(d) is zero.
    initial_step_size = np.sqrt(
        2.0 * self._max_constraint_val *
        (1. / (descent_direction.dot(hx(descent_direction)) + 1e-8)))
    # A NaN step size (e.g. a negative value under the root) falls back to
    # a unit step.
    if np.isnan(initial_step_size):
        initial_step_size = 1.
    flat_descent_step = initial_step_size * descent_direction
    logger.log("descent direction computed")

    n_iter = 0
    # The starting loss is evaluated with a direct (unsliced) call.
    loss_before = self._opt_fun["f_loss"](*(inputs + extra_inputs))
    # NOTE(review): if self._max_backtracks is 0 the loop body never runs
    # and `loss` / `constraint_val` are never assigned, so the check after
    # the loop raises NameError. Confirm callers always pass >= 1.
    for n_iter, ratio in enumerate(
            self._backtrack_ratio**np.arange(self._max_backtracks)):  # yapf: disable
        cur_step = ratio * flat_descent_step
        cur_param = prev_param - cur_step
        self._target.set_param_values(cur_param, trainable=True)
        loss, constraint_val = sliced_fun(
            self._opt_fun["f_loss_constraint"],
            self._num_slices)(inputs + extra_inputs)
        # With the debug flag set, stop searching on the first NaN
        # constraint; the check after the loop then handles it.
        if self._debug_nan and np.isnan(constraint_val):
            break
        if loss < loss_before and \
           constraint_val <= self._max_constraint_val:
            break

    # The constraint comparison is strict (`>`) so that it agrees with the
    # `<=` acceptance test above: a step that lands exactly on the bound
    # is accepted by the search and must not be rolled back afterwards.
    # The loss comparison is kept as in the original (`>`), so a search
    # that ends with loss exactly equal to loss_before is not rejected on
    # the loss criterion.
    violated = (np.isnan(loss) or np.isnan(constraint_val)
                or loss > loss_before
                or constraint_val > self._max_constraint_val)
    if violated and not self._accept_violation:
        logger.log("Line search condition violated. Rejecting the step!")
        if np.isnan(loss):
            logger.log("Violated because loss is NaN")
        if np.isnan(constraint_val):
            logger.log("Violated because constraint %s is NaN" %
                       self._constraint_name)
        if loss > loss_before:
            logger.log("Violated because loss not improving")
        if constraint_val > self._max_constraint_val:
            logger.log("Violated because constraint %s is violated" %
                       self._constraint_name)
        self._target.set_param_values(prev_param, trainable=True)

    logger.log("backtrack iters: %d" % n_iter)
    logger.log("computing loss after")
    logger.log("optimization finished")