def conjugate_grad(self, deltaW, Hx, inputs, extra_inputs=()):
    # s = H^-1 g
    descent_direction = krylov.cg(Hx, deltaW, cg_iters=self._cg_iters)
    init_step = np.sqrt(
        2.0 * self._max_constraint_val *
        (1. / (descent_direction.dot(Hx(descent_direction)) + 1e-8)))
    # s' H s = g' s, as s = H^-1 g
    # init_step = np.sqrt(2.0 * self._max_constraint_val *
    #                     (1. / (descent_direction.dot(deltaW) + 1e-8)))
    if np.isnan(init_step):
        init_step = 1.
    descent_step = init_step * descent_direction
    return self.line_search(descent_step, inputs, extra_inputs)
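
# Why init_step = sqrt(2 * delta / (s^T H s)): the quadratic constraint is
# (1/2) step^T H step <= delta, with delta = self._max_constraint_val, so
# scaling the CG solution s by that factor places the step exactly on the
# trust-region boundary. A minimal standalone sketch of this fact, assuming
# numpy only; the helper name and all values are illustrative, not part of
# the optimizer.
def _max_step_size_sketch():
    np.random.seed(0)
    M = np.random.randn(5, 5)
    H = M.dot(M.T) + 5.0 * np.eye(5)  # random SPD stand-in for the Hessian/FIM
    g = np.random.randn(5)            # stand-in gradient
    delta = 0.01                      # stand-in for self._max_constraint_val
    s = np.linalg.solve(H, g)         # s = H^-1 g, what krylov.cg approximates
    init_step = np.sqrt(2.0 * delta / s.dot(H.dot(s)))
    step = init_step * s
    # the scaled step saturates the constraint: (1/2) step^T H step == delta
    assert np.isclose(0.5 * step.dot(H.dot(step)), delta)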
def optimize(self, inputs, extra_inputs=None):
    inputs = tuple(inputs)
    if extra_inputs is None:
        extra_inputs = tuple()

    if self._subsample_factor < 1:
        n_samples = len(inputs[0])
        inds = np.random.choice(
            n_samples, int(n_samples * self._subsample_factor),
            replace=False)
        subsample_inputs = tuple([x[inds] for x in inputs])
    else:
        subsample_inputs = inputs

    logger.log("computing loss before")
    loss_before = self._opt_fun["f_loss"](*(inputs + extra_inputs))
    logger.log("performing update")
    logger.log("computing descent direction")

    flat_g = self._opt_fun["f_grad"](*(inputs + extra_inputs))

    def Hx(x):
        xs = tuple(self._target.flat_to_params(x, trainable=True))
        # rop = f_Hx_rop(*(inputs + xs))
        plain = self._opt_fun["f_Hx_plain"](*(
            subsample_inputs + extra_inputs + xs)) + self._reg_coeff * x
        # assert np.allclose(rop, plain)
        return plain
        # alternatively we can do finite difference on flat_grad

    descent_direction = krylov.cg(Hx, flat_g, cg_iters=self._cg_iters)

    initial_step_size = np.sqrt(
        2.0 * self._max_constraint_val *
        (1. / (descent_direction.dot(Hx(descent_direction)) + 1e-8)))
    flat_descent_step = initial_step_size * descent_direction

    logger.log("descent direction computed")

    prev_param = self._target.get_param_values(trainable=True)
    n_iter = 0
    for n_iter, ratio in enumerate(
            self._backtrack_ratio ** np.arange(self._max_backtracks)):
        cur_step = ratio * flat_descent_step
        cur_param = prev_param - cur_step
        self._target.set_param_values(cur_param, trainable=True)
        loss, constraint_val = self._opt_fun["f_loss_constraint"](*(
            inputs + extra_inputs))
        if self._debug_nan and np.isnan(constraint_val):
            import ipdb
            ipdb.set_trace()
        if loss < loss_before and constraint_val <= self._max_constraint_val:
            break

    logger.log("backtrack iters: %d" % n_iter)
    logger.log("computing loss after")
    logger.log("optimization finished")
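
# The Hx closure above supplies Hessian-vector products without ever forming
# the Hessian, and krylov.cg needs only that matvec to approximately solve
# H x = g. A minimal standalone sketch of the same conjugate-gradient
# pattern, assuming numpy only; this plain-vanilla CG is illustrative, not
# the actual krylov.cg implementation.
def _cg_sketch(f_Ax, b, cg_iters=10, residual_tol=1e-10):
    x = np.zeros_like(b)
    r = b.copy()          # residual b - A x, for x = 0
    p = r.copy()          # current search direction
    rdotr = r.dot(r)
    for _ in range(cg_iters):
        Ap = f_Ax(p)      # the only access to A is through this matvec
        alpha = rdotr / p.dot(Ap)
        x += alpha * p
        r -= alpha * Ap
        new_rdotr = r.dot(r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x

# usage check: matches a dense solve on a small SPD system
#   np.random.seed(0)
#   M = np.random.randn(8, 8)
#   A = M.dot(M.T) + 8.0 * np.eye(8)
#   b = np.random.randn(8)
#   assert np.allclose(_cg_sketch(lambda v: A.dot(v), b, cg_iters=50),
#                      np.linalg.solve(A, b))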
def optimize(self,
             inputs,
             extra_inputs=None,
             subsample_grouped_inputs=None,
             precomputed_eval=None,
             precomputed_threshold=None,
             diff_threshold=False,
             inputs2=None,
             extra_inputs2=None):
    """
    precomputed_eval : The value of the safety constraint at theta = theta_old.
        Provide this when the lin_constraint function is a surrogate, and
        evaluating it at theta_old will not give you the correct value.

    precomputed_threshold & diff_threshold : These relate to the linesearch
        that is used to ensure constraint satisfaction. If the lin_constraint
        function is indeed the safety constraint function, then it suffices
        to check that lin_constraint < max_lin_constraint_val to ensure
        satisfaction. But if the lin_constraint function is a surrogate -
        i.e., it only has the same /gradient/ as the safety constraint - then
        the threshold we check it against has to be adjusted. You can provide
        a fixed adjusted threshold via "precomputed_threshold." When
        "diff_threshold" == True, instead of checking
        lin_constraint < threshold, it will check
        lin_constraint - old_lin_constraint < threshold.
    """
    inputs = tuple(inputs)
    if extra_inputs is None:
        extra_inputs = tuple()

    # inputs2 and extra_inputs2 are for calculation of the linearized
    # constraint. This functionality - of having separate inputs for that
    # constraint - is intended to allow a "learning without forgetting"
    # setup.
    if inputs2 is None:
        inputs2 = inputs
    if extra_inputs2 is None:
        extra_inputs2 = tuple()

    def subsampled_inputs(inputs, subsample_grouped_inputs):
        if self._subsample_factor < 1:
            if subsample_grouped_inputs is None:
                subsample_grouped_inputs = [inputs]
            subsample_inputs = tuple()
            for inputs_grouped in subsample_grouped_inputs:
                n_samples = len(inputs_grouped[0])
                inds = np.random.choice(
                    n_samples, int(n_samples * self._subsample_factor),
                    replace=False)
                subsample_inputs += tuple([x[inds] for x in inputs_grouped])
        else:
            subsample_inputs = inputs
        return subsample_inputs

    subsample_inputs = subsampled_inputs(inputs, subsample_grouped_inputs)
    if self._resample_inputs:
        subsample_inputs2 = subsampled_inputs(inputs,
                                              subsample_grouped_inputs)

    logger.log("computing loss before")
    loss_before = sliced_fun(self._opt_fun["f_loss"],
                             self._num_slices)(inputs, extra_inputs)
    logger.log("performing update")
    logger.log("computing descent direction")

    flat_g = sliced_fun(self._opt_fun["f_grad"],
                        self._num_slices)(inputs, extra_inputs)
    flat_b = sliced_fun(self._opt_fun["f_lin_constraint_grad"],
                        self._num_slices)(inputs2, extra_inputs2)

    Hx = self._hvp_approach.build_eval(subsample_inputs + extra_inputs)

    v = krylov.cg(Hx, flat_g, cg_iters=self._cg_iters,
                  verbose=self._verbose_cg)
    approx_g = Hx(v)
    q = v.dot(approx_g)  # approx = g^T H^{-1} g
    delta = 2 * self._max_quad_constraint_val
    eps = 1e-8

    residual = np.sqrt((approx_g - flat_g).dot(approx_g - flat_g))
    rescale = q / (v.dot(v))
    logger.record_tabular("OptimDiagnostic_Residual", residual)
    logger.record_tabular("OptimDiagnostic_Rescale", rescale)

    if self.precompute:
        S = precomputed_eval
        assert np.ndim(S) == 0  # please be a scalar
    else:
        S = sliced_fun(self._opt_fun["lin_constraint"],
                       self._num_slices)(inputs, extra_inputs)

    c = S - self._max_lin_constraint_val
    if c > 0:
        logger.log("warning! safety constraint is already violated")
    else:
        # the current parameters constitute a feasible point: save it as
        # "last good point"
        self.last_safe_point = np.copy(
            self._target.get_param_values(trainable=True))

    # can't stop won't stop (unless something in the conditional checks /
    # calculations that follow requires premature stopping of the
    # optimization process)
    stop_flag = False

    if flat_b.dot(flat_b) <= eps:
        # if the safety gradient is zero, the linear constraint is not
        # present; ignore it.
        lam = np.sqrt(q / delta)
        nu = 0
        w = 0
        r, s, A, B = 0, 0, 0, 0
        optim_case = 4
    else:
        if self._resample_inputs:
            Hx = self._hvp_approach.build_eval(subsample_inputs2 +
                                               extra_inputs)

        norm_b = np.sqrt(flat_b.dot(flat_b))
        unit_b = flat_b / norm_b
        w = norm_b * krylov.cg(Hx, unit_b, cg_iters=self._cg_iters,
                               verbose=self._verbose_cg)
        r = w.dot(approx_g)  # approx = b^T H^{-1} g
        s = w.dot(Hx(w))     # approx = b^T H^{-1} b

        # figure out lambda coeff (lagrange multiplier for trust region)
        # and nu coeff (lagrange multiplier for linear constraint)
        A = q - r**2 / s  # this should always be positive by Cauchy-Schwarz
        B = delta - c**2 / s  # this one says whether or not the closest
                              # point on the plane is feasible

        # if (B < 0), that means the trust region plane doesn't intersect
        # the safety boundary
        if c < 0 and B < 0:
            # point in trust region is feasible and safety boundary doesn't
            # intersect ==> entire trust region is feasible
            optim_case = 3
        elif c < 0 and B > 0:
            # x = 0 is feasible and safety boundary intersects
            # ==> most of trust region is feasible
            optim_case = 2
        elif c > 0 and B > 0:
            # x = 0 is infeasible (bad! unsafe!) and safety boundary
            # intersects ==> part of trust region is feasible
            # ==> this is 'recovery mode'
            optim_case = 1
            if self.attempt_feasible_recovery:
                logger.log("alert! conjugate constraint optimizer is "
                           "attempting feasible recovery")
            else:
                logger.log("alert! problem is feasible but needs recovery, "
                           "and we were instructed not to attempt recovery")
                stop_flag = True
        else:
            # x = 0 infeasible (bad! unsafe!) and safety boundary doesn't
            # intersect ==> whole trust region infeasible
            # ==> optimization problem infeasible!!!
            optim_case = 0
            if self.attempt_infeasible_recovery:
                logger.log("alert! conjugate constraint optimizer is "
                           "attempting infeasible recovery")
            else:
                logger.log("alert! problem is infeasible, and we were "
                           "instructed not to attempt recovery")
                stop_flag = True

        # default dual vars, which assume safety constraint inactive
        # (this corresponds to either optim_case == 3,
        # or optim_case == 2 under certain conditions)
        lam = np.sqrt(q / delta)
        nu = 0

        if optim_case == 2 or optim_case == 1:
            # dual function is piecewise continuous
            # on region (a):
            #
            #   L(lam) = -1/2 (A / lam + B * lam) - r * c / s
            #
            # on region (b):
            #
            #   L(lam) = -1/2 (q / lam + delta * lam)
            #
            lam_mid = r / c
            L_mid = -0.5 * (q / lam_mid + lam_mid * delta)

            lam_a = np.sqrt(A / (B + eps))
            L_a = -np.sqrt(A * B) - r * c / (s + eps)
            # note that for optim_case == 1 or 2, B > 0, so this calculation
            # should never be an issue

            lam_b = np.sqrt(q / delta)
            L_b = -np.sqrt(q * delta)

            # those lam's are solns to the pieces of the piecewise continuous
            # dual function. the domains of the pieces depend on whether or
            # not c < 0 (x = 0 feasible), and so projection back onto those
            # domains is determined appropriately.
            if lam_mid > 0:
                if c < 0:
                    # here, domain of (a) is [0, lam_mid)
                    # and domain of (b) is (lam_mid, infty)
                    if lam_a > lam_mid:
                        lam_a = lam_mid
                        L_a = L_mid
                    if lam_b < lam_mid:
                        lam_b = lam_mid
                        L_b = L_mid
                else:
                    # here, domain of (a) is (lam_mid, infty)
                    # and domain of (b) is [0, lam_mid)
                    if lam_a < lam_mid:
                        lam_a = lam_mid
                        L_a = L_mid
                    if lam_b > lam_mid:
                        lam_b = lam_mid
                        L_b = L_mid

                if L_a >= L_b:
                    lam = lam_a
                else:
                    lam = lam_b
            else:
                if c < 0:
                    lam = lam_b
                else:
                    lam = lam_a

            nu = max(0, lam * c - r) / (s + eps)

    # 4 / 3 : trust region totally in safe region
    # 2     : trust region partly intersects safe region, and current point
    #         is feasible
    # 1     : trust region partly intersects safe region, and current point
    #         is infeasible
    # 0     : trust region does not intersect safe region
    logger.record_tabular("OptimCase", optim_case)
    logger.record_tabular("LagrangeLamda", lam)  # dual var for trust region
    logger.record_tabular("LagrangeNu", nu)  # dual var for safety constraint
    logger.record_tabular("OptimDiagnostic_q", q)  # approx = g^T H^{-1} g
    logger.record_tabular("OptimDiagnostic_r", r)  # approx = b^T H^{-1} g
    logger.record_tabular("OptimDiagnostic_s", s)  # approx = b^T H^{-1} b
    logger.record_tabular("OptimDiagnostic_c", c)  # if > 0, constraint is
                                                   # violated
    logger.record_tabular("OptimDiagnostic_A", A)
    logger.record_tabular("OptimDiagnostic_B", B)
    logger.record_tabular("OptimDiagnostic_S", S)
    if nu == 0:
        logger.log("safety constraint is not active!")

    # Predict worst-case next S
    nextS = S + np.sqrt(delta * s)
    logger.record_tabular("OptimDiagnostic_WorstNextS", nextS)

    # for cases where we will not attempt recovery, we stop here. we didn't
    # stop earlier because first we wanted to record the various critical
    # quantities for understanding the failure mode (such as optim_case, B,
    # c, S). Also, the logger gets angry if you are inconsistent about
    # recording a given quantity from iteration to iteration. That's why we
    # have to record a BacktrackIters here.
    def record_zeros():
        logger.record_tabular("BacktrackIters", 0)
        logger.record_tabular("LossRejects", 0)
        logger.record_tabular("QuadRejects", 0)
        logger.record_tabular("LinRejects", 0)

    if optim_case > 0:
        flat_descent_step = (1. / (lam + eps)) * (v + nu * w)
    else:
        # current default behavior for attempting infeasible recovery:
        # take a step on natural safety gradient
        flat_descent_step = np.sqrt(delta / (s + eps)) * w

    logger.log("descent direction computed")

    prev_param = np.copy(self._target.get_param_values(trainable=True))

    prev_lin_constraint_val = sliced_fun(
        self._opt_fun["f_lin_constraint"],
        self._num_slices)(inputs, extra_inputs)
    logger.record_tabular("PrevLinConstVal", prev_lin_constraint_val)

    lin_reject_threshold = self._max_lin_constraint_val
    if precomputed_threshold is not None:
        lin_reject_threshold = precomputed_threshold
    if diff_threshold:
        lin_reject_threshold += prev_lin_constraint_val
    logger.record_tabular("LinRejectThreshold", lin_reject_threshold)

    def check_nan():
        loss, quad_constraint_val, lin_constraint_val = sliced_fun(
            self._opt_fun["f_loss_constraint"],
            self._num_slices)(inputs, extra_inputs)
        if (np.isnan(loss) or np.isnan(quad_constraint_val) or
                np.isnan(lin_constraint_val)):
            logger.log("Something is NaN. Rejecting the step!")
            if np.isnan(loss):
                logger.log("Violated because loss is NaN")
            if np.isnan(quad_constraint_val):
                logger.log("Violated because quad_constraint %s is NaN" %
                           self._constraint_name_1)
            if np.isnan(lin_constraint_val):
                logger.log("Violated because lin_constraint %s is NaN" %
                           self._constraint_name_2)
            self._target.set_param_values(prev_param, trainable=True)

    def line_search(check_loss=True, check_quad=True, check_lin=True):
        loss_rejects = 0
        quad_rejects = 0
        lin_rejects = 0
        n_iter = 0
        for n_iter, ratio in enumerate(
                self._backtrack_ratio ** np.arange(self._max_backtracks)):
            cur_step = ratio * flat_descent_step
            cur_param = prev_param - cur_step
            self._target.set_param_values(cur_param, trainable=True)
            loss, quad_constraint_val, lin_constraint_val = sliced_fun(
                self._opt_fun["f_loss_constraint"],
                self._num_slices)(inputs, extra_inputs)
            loss_flag = loss < loss_before
            quad_flag = quad_constraint_val <= self._max_quad_constraint_val
            lin_flag = lin_constraint_val <= lin_reject_threshold
            if check_loss and not loss_flag:
                logger.log("At backtrack itr %i, loss failed to improve." %
                           n_iter)
                loss_rejects += 1
            if check_quad and not quad_flag:
                logger.log("At backtrack itr %i, quad constraint violated." %
                           n_iter)
                logger.log("Quad constraint violation was %.3f %%." %
                           (100 * (quad_constraint_val /
                                   self._max_quad_constraint_val) - 100))
                quad_rejects += 1
            if check_lin and not lin_flag:
                logger.log("At backtrack itr %i, expression for lin "
                           "constraint failed to improve." % n_iter)
                logger.log("Lin constraint violation was %.3f %%." %
                           (100 * (lin_constraint_val /
                                   lin_reject_threshold) - 100))
                lin_rejects += 1
            if ((loss_flag or not check_loss) and
                    (quad_flag or not check_quad) and
                    (lin_flag or not check_lin)):
                logger.log("Accepted step at backtrack itr %i." % n_iter)
                break

        logger.record_tabular("BacktrackIters", n_iter)
        logger.record_tabular("LossRejects", loss_rejects)
        logger.record_tabular("QuadRejects", quad_rejects)
        logger.record_tabular("LinRejects", lin_rejects)
        return loss, quad_constraint_val, lin_constraint_val, n_iter

    def wrap_up():
        if optim_case < 4:
            lin_constraint_val = sliced_fun(
                self._opt_fun["f_lin_constraint"],
                self._num_slices)(inputs, extra_inputs)
            lin_constraint_delta = (lin_constraint_val -
                                    prev_lin_constraint_val)
            logger.record_tabular("LinConstraintDelta", lin_constraint_delta)

            cur_param = self._target.get_param_values()
            next_linear_S = S + flat_b.dot(cur_param - prev_param)
            next_surrogate_S = S + lin_constraint_delta
            lin_surrogate_acc = (100. * (next_linear_S - next_surrogate_S) /
                                 next_surrogate_S)
            logger.record_tabular("PredictedLinearS", next_linear_S)
            logger.record_tabular("PredictedSurrogateS", next_surrogate_S)
            logger.record_tabular("LinearSurrogateErr", lin_surrogate_acc)

            lin_pred_err = (self._last_lin_pred_S - S)  # / (S + eps)
            surr_pred_err = (self._last_surr_pred_S - S)  # / (S + eps)
            logger.record_tabular("PredictionErrorLinearS", lin_pred_err)
            logger.record_tabular("PredictionErrorSurrogateS", surr_pred_err)

            self._last_lin_pred_S = next_linear_S
            self._last_surr_pred_S = next_surrogate_S
        else:
            logger.record_tabular("LinConstraintDelta", 0)
            logger.record_tabular("PredictedLinearS", 0)
            logger.record_tabular("PredictedSurrogateS", 0)
            logger.record_tabular("LinearSurrogateErr", 0)

            lin_pred_err = (self._last_lin_pred_S - 0)  # / (S + eps)
            surr_pred_err = (self._last_surr_pred_S - 0)  # / (S + eps)
            logger.record_tabular("PredictionErrorLinearS", lin_pred_err)
            logger.record_tabular("PredictionErrorSurrogateS", surr_pred_err)

            self._last_lin_pred_S = 0
            self._last_surr_pred_S = 0

    if stop_flag:
        record_zeros()
        wrap_up()
        return

    if optim_case == 1 and not self.revert_to_last_safe_point:
        if self._linesearch_infeasible_recovery:
            logger.log("feasible recovery mode: constrained natural gradient "
                       "step. performing linesearch on constraints.")
            line_search(False, True, True)
        else:
            self._target.set_param_values(prev_param - flat_descent_step,
                                          trainable=True)
            logger.log("feasible recovery mode: constrained natural gradient "
                       "step. no linesearch performed.")
            check_nan()
            record_zeros()
        wrap_up()
        return
    elif optim_case == 0 and not self.revert_to_last_safe_point:
        if self._linesearch_infeasible_recovery:
            logger.log("infeasible recovery mode: natural safety step. "
                       "performing linesearch on constraints.")
            line_search(False, True, True)
        else:
            self._target.set_param_values(prev_param - flat_descent_step,
                                          trainable=True)
            logger.log("infeasible recovery mode: natural safety gradient "
                       "step. no linesearch performed.")
            check_nan()
            record_zeros()
        wrap_up()
        return
    elif ((optim_case == 0 or optim_case == 1) and
            self.revert_to_last_safe_point):
        if self.last_safe_point is not None:
            self._target.set_param_values(self.last_safe_point,
                                          trainable=True)
            logger.log("infeasible recovery mode: reverted to last safe "
                       "point!")
        else:
            logger.log("alert! infeasible recovery mode failed: no last "
                       "safe point to revert to.")
        record_zeros()
        wrap_up()
        return

    loss, quad_constraint_val, lin_constraint_val, n_iter = line_search()

    if (np.isnan(loss) or np.isnan(quad_constraint_val) or
            np.isnan(lin_constraint_val) or loss >= loss_before or
            quad_constraint_val >= self._max_quad_constraint_val or
            lin_constraint_val > lin_reject_threshold) and \
            not self._accept_violation:
        logger.log("Line search condition violated. Rejecting the step!")
        if np.isnan(loss):
            logger.log("Violated because loss is NaN")
        if np.isnan(quad_constraint_val):
            logger.log("Violated because quad_constraint %s is NaN" %
                       self._constraint_name_1)
        if np.isnan(lin_constraint_val):
            logger.log("Violated because lin_constraint %s is NaN" %
                       self._constraint_name_2)
        if loss >= loss_before:
            logger.log("Violated because loss not improving")
        if quad_constraint_val >= self._max_quad_constraint_val:
            logger.log("Violated because constraint %s is violated" %
                       self._constraint_name_1)
        if lin_constraint_val > lin_reject_threshold:
            logger.log("Violated because constraint %s exceeded threshold" %
                       self._constraint_name_2)
        self._target.set_param_values(prev_param, trainable=True)

    logger.log("backtrack iters: %d" % n_iter)
    logger.log("computing loss after")
    logger.log("optimization finished")
    wrap_up()
def optimize(self, inputs, extra_inputs=None, subsample_grouped_inputs=None):
    prev_param = np.copy(self._target.get_param_values(trainable=True))
    inputs = tuple(inputs)
    if extra_inputs is None:
        extra_inputs = tuple()

    if self._subsample_factor < 1:
        if subsample_grouped_inputs is None:
            subsample_grouped_inputs = [inputs]
        subsample_inputs = tuple()
        for inputs_grouped in subsample_grouped_inputs:
            n_samples = len(inputs_grouped[0])
            inds = np.random.choice(
                n_samples, int(n_samples * self._subsample_factor),
                replace=False)
            subsample_inputs += tuple([x[inds] for x in inputs_grouped])
    else:
        subsample_inputs = inputs

    logger.log(
        "Start CG optimization: #parameters: %d, #inputs: %d, "
        "#subsample_inputs: %d" %
        (len(prev_param), len(inputs[0]), len(subsample_inputs[0])))

    logger.log("computing loss before")
    loss_before = sliced_fun(self._opt_fun["f_loss"],
                             self._num_slices)(inputs, extra_inputs)
    logger.log("performing update")

    logger.log("computing gradient")
    flat_g = sliced_fun(self._opt_fun["f_grad"],
                        self._num_slices)(inputs, extra_inputs)
    logger.log("gradient computed")

    logger.log("computing descent direction")
    Hx = self._hvp_approach.build_eval(subsample_inputs + extra_inputs)

    descent_direction = krylov.cg(Hx, flat_g, cg_iters=self._cg_iters)

    initial_step_size = np.sqrt(
        2.0 * self._max_constraint_val *
        (1. / (descent_direction.dot(Hx(descent_direction)) + 1e-8)))
    if np.isnan(initial_step_size):
        initial_step_size = 1.
    flat_descent_step = initial_step_size * descent_direction

    logger.log("descent direction computed")

    n_iter = 0
    for n_iter, ratio in enumerate(
            self._backtrack_ratio ** np.arange(self._max_backtracks)):
        cur_step = ratio * flat_descent_step
        cur_param = prev_param - cur_step
        self._target.set_param_values(cur_param, trainable=True)
        loss, constraint_val = sliced_fun(
            self._opt_fun["f_loss_constraint"],
            self._num_slices)(inputs, extra_inputs)
        if self._debug_nan and np.isnan(constraint_val):
            import ipdb
            ipdb.set_trace()
        if loss < loss_before and constraint_val <= self._max_constraint_val:
            break

    if (np.isnan(loss) or np.isnan(constraint_val) or
            loss >= loss_before or
            constraint_val >= self._max_constraint_val) and \
            not self._accept_violation:
        logger.log("Line search condition violated. Rejecting the step!")
        if np.isnan(loss):
            logger.log("Violated because loss is NaN")
        if np.isnan(constraint_val):
            logger.log("Violated because constraint %s is NaN" %
                       self._constraint_name)
        if loss >= loss_before:
            logger.log("Violated because loss not improving")
        if constraint_val >= self._max_constraint_val:
            logger.log("Violated because constraint %s is violated" %
                       self._constraint_name)
        self._target.set_param_values(prev_param, trainable=True)

    logger.log("backtrack iters: %d" % n_iter)
    logger.log("computing loss after")
    logger.log("optimization finished")
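
# The loop above is a geometric backtracking line search: shrink the full
# step by _backtrack_ratio**k and accept the first candidate that both
# improves the loss and satisfies the constraint, otherwise fall back to the
# old parameters. A minimal standalone sketch of that accept/reject rule,
# assuming numpy only; f, constraint, and all other names are illustrative.
def _backtracking_line_search_sketch(f, constraint, prev_param,
                                     flat_descent_step, max_constraint_val,
                                     backtrack_ratio=0.8, max_backtracks=15):
    loss_before = f(prev_param)
    for ratio in backtrack_ratio ** np.arange(max_backtracks):
        cur_param = prev_param - ratio * flat_descent_step
        if (f(cur_param) < loss_before and
                constraint(prev_param, cur_param) <= max_constraint_val):
            return cur_param  # first acceptable candidate wins
    return prev_param         # every candidate rejected: keep old params

# usage check on a toy quadratic loss with a squared-step-size constraint:
#   f = lambda x: float(x.dot(x))
#   constraint = lambda old, new: float((new - old).dot(new - old))
#   x = _backtracking_line_search_sketch(f, constraint, np.ones(3),
#                                        1.5 * np.ones(3),
#                                        max_constraint_val=0.5)
#   assert f(x) < 3.0 and (x - np.ones(3)).dot(x - np.ones(3)) <= 0.5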