def optimize(self, inputs, extra_inputs=None, callback=None):
    """Run minibatch gradient-descent training until max epochs or convergence.

    Args:
        inputs: tuple/list of arrays that are sliced into minibatches.
        extra_inputs: inputs passed whole (unsliced) to every evaluation;
            defaults to an empty tuple.
        callback: optional function invoked each epoch with keyword args
            loss, params, itr, elapsed (in addition to self._callback,
            which receives the same dict positionally).

    Raises:
        NotImplementedError: if ``inputs`` is empty.
    """
    if len(inputs) == 0:
        # Assumes that we should always sample mini-batches
        raise NotImplementedError
    f_loss = self._opt_fun["f_loss"]
    if extra_inputs is None:
        extra_inputs = tuple()
    last_loss = sliced_fun(f_loss, self._num_slices)(inputs, extra_inputs)
    #last_loss = f_loss(*(tuple(inputs) + extra_inputs))
    start_time = time.time()
    dataset = BatchDataset(inputs, self._batch_size,
                           extra_inputs=extra_inputs)
    sess = tf.get_default_session()
    for epoch in range(self._max_epochs):
        if self._verbose:
            logger.log("Epoch %d" % (epoch))
            progbar = pyprind.ProgBar(len(inputs[0]))
        for batch in dataset.iterate(update=True):
            # Optionally skip a trailing batch smaller than _batch_size.
            if (self._ignore_last and len(batch[0]) != self._batch_size):
                continue
            sess.run(self._train_op,
                     dict(list(zip(self._input_vars, batch))))
            if self._verbose:
                progbar.update(len(batch[0]))
        if self._verbose:
            if progbar.active:
                progbar.stop()
        # Full-dataset loss after each epoch, used for the stopping test.
        new_loss = sliced_fun(f_loss, self._num_slices)(inputs, extra_inputs)
        #new_loss = f_loss(*(tuple(inputs) + extra_inputs))
        if self._verbose:
            logger.log("Epoch: %d | Loss: %f" % (epoch, new_loss))
        if self._callback or callback:
            elapsed = time.time() - start_time
            callback_args = dict(
                loss=new_loss,
                params=self._target.get_param_values(trainable=True)
                if self._target else None,
                itr=epoch,
                elapsed=elapsed,
            )
            if self._callback:
                self._callback(callback_args)
            if callback:
                callback(**callback_args)
        # Stop when per-epoch loss improvement drops below the tolerance.
        if abs(last_loss - new_loss) < self._tolerance:
            break
        last_loss = new_loss
def eval(x):
    """Regularized Hessian-vector product for CG: H x + reg_coeff * x.

    Closure: uses ``inputs`` and ``self`` from the enclosing scope.
    """
    # NOTE(review): interactive-debugger hook gated on a config flag —
    # looks like a leftover debugging aid; confirm it should stay.
    if config.TF_NN_SETTRACE:
        ipdb.set_trace()
    # Unflatten the direction vector into per-parameter arrays.
    xs = tuple(self.target.flat_to_params(x, trainable=True))
    ret = sliced_fun(self.opt_fun["f_Hx_plain"], self._num_slices)(
        inputs, xs) + self.reg_coeff * x
    return ret
def line_search(self, descent_step, inputs, extra_inputs=()):
    """Backtracking line search along ``-descent_step``.

    Accepts the first backtracked step that both improves the loss and
    keeps the constraint within ``self._max_constraint_val``; if the final
    candidate still violates the conditions, parameters are restored.

    Returns:
        bool: True iff an acceptable step was found.
    """
    f_loss = self._opt_fun["f_loss"]
    f_loss_constraint = self._opt_fun["f_loss_constraint"]
    prev_w = np.copy(self._target.get_param_values(trainable=True))
    loss_before = f_loss(*(inputs + extra_inputs))
    n_iter = 0
    succ_line_search = False
    for n_iter, ratio in enumerate(
            self._backtrack_ratio ** np.arange(self._max_backtracks)):
        cur_step = ratio * descent_step
        cur_w = prev_w - cur_step
        self._target.set_param_values(cur_w, trainable=True)
        loss, constraint_val = sliced_fun(
            f_loss_constraint, self._num_slices)(inputs, extra_inputs)
        if loss < loss_before and constraint_val <= self._max_constraint_val:
            succ_line_search = True
            break
    # If the last evaluated candidate is unacceptable, revert to prev_w.
    if (np.isnan(loss) or np.isnan(constraint_val) or loss >= loss_before
            or constraint_val >= self._max_constraint_val):
        logger.log("Line search condition violated. Rejecting the step!")
        if np.isnan(loss):
            logger.log("Violated because loss is NaN")
        if np.isnan(constraint_val):
            logger.log("Violated because constraint is NaN")
        if loss >= loss_before:
            logger.log("Violated because loss not improving")
        if constraint_val >= self._max_constraint_val:
            logger.log("Violated because constraint {:} is violated".format(
                constraint_val))
        self._target.set_param_values(prev_w, trainable=True)
    logger.log("backtrack iters: %d" % n_iter)
    return succ_line_search
def line_search(check_loss=True, check_quad=True, check_lin=True):
    """Backtracking search over loss / quadratic / linear criteria (quiet).

    Closure: uses ``flat_descent_step``, ``prev_param``, ``loss_before``,
    ``lin_reject_threshold``, ``inputs``, ``extra_inputs`` and ``self``
    from the enclosing optimize(). Each ``check_*`` flag enables the
    corresponding acceptance criterion.

    Returns:
        (loss, quad_constraint_val, lin_constraint_val, n_iter)
    """
    loss_rejects = 0
    quad_rejects = 0
    lin_rejects = 0
    n_iter = 0
    for n_iter, ratio in enumerate(
            self._backtrack_ratio ** np.arange(self._max_backtracks)):
        cur_step = ratio * flat_descent_step
        cur_param = prev_param - cur_step
        self._target.set_param_values(cur_param, trainable=True)
        loss, quad_constraint_val, lin_constraint_val = sliced_fun(
            self._opt_fun["f_loss_constraint"],
            self._num_slices)(inputs, extra_inputs)
        loss_flag = loss < loss_before
        quad_flag = quad_constraint_val <= self._max_quad_constraint_val
        lin_flag = lin_constraint_val <= lin_reject_threshold
        if check_loss and not (loss_flag):
            loss_rejects += 1
        if check_quad and not (quad_flag):
            quad_rejects += 1
        if check_lin and not (lin_flag):
            lin_rejects += 1
        # Accept as soon as every enabled criterion is satisfied.
        if (loss_flag or not (check_loss)) and (
                quad_flag or not (check_quad)) and (lin_flag
                                                    or not (check_lin)):
            break
    return loss, quad_constraint_val, lin_constraint_val, n_iter
def mean_kl(self, samples_data):
    """Evaluate every KL constraint on inputs built from samples_data.

    Returns a list with one value per entry of ``self.f_constraints``.
    """
    input_values = self.construct_inputs(samples_data)
    return [sliced_fun(f, 1)(input_values) for f in self.f_constraints]
def constraint_val(self, inputs, extra_inputs=None):
    """Evaluate the constraint function over the (sliced) inputs."""
    # Optional interactive debugging hook, gated on a config flag.
    if config.TF_NN_SETTRACE:
        ipdb.set_trace()
    extras = tuple() if extra_inputs is None else extra_inputs
    evaluator = sliced_fun(self._opt_fun["f_constraint"], self._num_slices)
    return evaluator(tuple(inputs), extras)
def _constraint_val(self, inputs, extra_inputs):
    """ Parallelized: returns the same value in all workers. """
    shareds, barriers = self._par_objs
    # Evaluate locally, publish this worker's weighted share, then sync
    # so every worker sees all contributions before reducing.
    local_val = sliced_fun(self._opt_fun["f_constraint"],
                           self._num_slices)(inputs, extra_inputs)
    shareds.constraint_val[self.rank] = self.avg_fac * local_val
    barriers.cnstr.wait()
    return sum(shareds.constraint_val)
def get_gradient(algo, samples_data, flat=False):
    """Evaluate the policy gradient for ``samples_data``.

    With ``flat=True`` the flattened gradient ("f_grad") is returned;
    otherwise the per-variable gradients ("f_grads").
    """
    agent_infos = samples_data["agent_infos"]
    values = list(
        ext.extract(samples_data, "observations", "actions", "advantages"))
    values += [agent_infos[k] for k in algo.policy.state_info_keys]
    values += [agent_infos[k] for k in algo.policy.distribution.dist_info_keys]
    fun_name = "f_grad" if flat else "f_grads"
    return sliced_fun(algo.optimizer._opt_fun[fun_name], 1)(tuple(values),
                                                            tuple())
def flat_g(self, inputs, extra_inputs=None):
    """Parallelized flat gradient; every worker returns the summed result."""
    shareds, barriers = self._par_objs
    # Each worker writes its weighted gradient into its own column.
    local_grad = ext.sliced_fun(self._opt_fun["f_grad"],
                                self._num_slices)(inputs, extra_inputs)
    shareds.grads_2d[:, self.rank] = self.avg_fac * local_grad
    barriers.flat_g[0].wait()
    # Rank 0 reduces across workers; everyone else waits, then reads.
    if self.rank == 0:
        shareds.flat_g = np.sum(shareds.grads_2d, axis=1)
    barriers.flat_g[1].wait()
    return shareds.flat_g
def _loss_constraint(self, inputs, extra_inputs):
    """ Parallelized: returns the same values in all workers. """
    shareds, barriers = self._par_objs
    evaluator = sliced_fun(self._opt_fun["f_loss_constraint"],
                           self._num_slices)
    local_loss, local_cnstr = evaluator(inputs, extra_inputs)
    # Publish this worker's weighted contributions, sync, then reduce.
    shareds.loss[self.rank] = self.avg_fac * local_loss
    shareds.constraint_val[self.rank] = self.avg_fac * local_cnstr
    barriers.loss_cnstr.wait()
    return sum(shareds.loss), sum(shareds.constraint_val)
def check_nan():
    """Revert to ``prev_param`` if any evaluated quantity is NaN.

    Closure: uses ``inputs``, ``extra_inputs``, ``prev_param`` and
    ``self`` from the enclosing optimize().
    """
    loss, quad_constraint_val, lin_constraint_val = sliced_fun(
        self._opt_fun["f_loss_constraint"],
        self._num_slices)(inputs, extra_inputs)
    if np.isnan(loss) or np.isnan(quad_constraint_val) or np.isnan(
            lin_constraint_val):
        if np.isnan(loss):
            logger.log("Violated because loss is NaN")
        if np.isnan(quad_constraint_val):
            logger.log("Violated because quad_constraint %s is NaN" %
                       self._constraint_name_1)
        if np.isnan(lin_constraint_val):
            logger.log("Violated because lin_constraint %s is NaN" %
                       self._constraint_name_2)
        self._target.set_param_values(prev_param, trainable=True)
def get_grad(self, samples_data):
    """Return per-variable policy gradients ("f_grads") for samples_data."""
    agent_infos = samples_data["agent_infos"]
    values = list(
        ext.extract(samples_data, "observations", "actions", "advantages"))
    values += [agent_infos[k] for k in self.policy.state_info_keys]
    values += [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
    if self.policy.recurrent:
        values.append(samples_data["valids"])
    return sliced_fun(self.optimizer._opt_fun["f_grads"], 1)(tuple(values))
def line_search(check_loss=True, check_quad=True, check_lin=True):
    """Backtracking search over loss / quadratic / linear criteria (verbose).

    Closure: uses ``flat_descent_step``, ``prev_param``, ``loss_before``,
    ``lin_reject_threshold``, ``inputs``, ``extra_inputs`` and ``self``
    from the enclosing optimize(). Logs each rejection and records
    rejection counts to the tabular logger.

    Returns:
        (loss, quad_constraint_val, lin_constraint_val, n_iter)
    """
    loss_rejects = 0
    quad_rejects = 0
    lin_rejects = 0
    n_iter = 0
    for n_iter, ratio in enumerate(self._backtrack_ratio**np.arange(
            self._max_backtracks)):
        cur_step = ratio * flat_descent_step
        cur_param = prev_param - cur_step
        self._target.set_param_values(cur_param, trainable=True)
        loss, quad_constraint_val, lin_constraint_val = sliced_fun(
            self._opt_fun["f_loss_constraint"],
            self._num_slices)(inputs, extra_inputs)
        loss_flag = loss < loss_before
        quad_flag = quad_constraint_val <= self._max_quad_constraint_val
        lin_flag = lin_constraint_val <= lin_reject_threshold
        if check_loss and not (loss_flag):
            logger.log("At backtrack itr %i, loss failed to improve." %
                       n_iter)
            loss_rejects += 1
        if check_quad and not (quad_flag):
            logger.log("At backtrack itr %i, quad constraint violated." %
                       n_iter)
            logger.log("Quad constraint violation was %.3f %%." %
                       (100 * (quad_constraint_val /
                               self._max_quad_constraint_val) - 100))
            quad_rejects += 1
        if check_lin and not (lin_flag):
            logger.log(
                "At backtrack itr %i, expression for lin constraint failed to improve."
                % n_iter)
            logger.log("Lin constraint violation was %.3f %%." %
                       (100 * (lin_constraint_val / lin_reject_threshold) -
                        100))
            lin_rejects += 1
        # Accept as soon as every enabled criterion is satisfied.
        if (loss_flag or not (check_loss)) and (
                quad_flag or not (check_quad)) and (lin_flag
                                                    or not (check_lin)):
            logger.log("Accepted step at backtrack itr %i." % n_iter)
            break
    logger.record_tabular("BacktrackIters", n_iter)
    logger.record_tabular("LossRejects", loss_rejects)
    logger.record_tabular("QuadRejects", quad_rejects)
    logger.record_tabular("LinRejects", lin_rejects)
    return loss, quad_constraint_val, lin_constraint_val, n_iter
def _flat_g(self, inputs, extra_inputs):
    """ Parallelized: returns the same values in all workers. """
    # NOTE(review): despite the docstring, this variant has no return
    # statement; presumably callers read shareds.flat_g directly after the
    # final barrier — confirm against call sites.
    shareds, barriers = self._par_objs
    # Each worker records result available to all.
    shareds.grads_2d[:, self.rank] = self.avg_fac * \
        sliced_fun(self._opt_fun["f_grad"],
                   self._num_slices)(inputs, extra_inputs)
    barriers.flat_g[0].wait()
    # Each worker sums over an equal share of the grad elements across
    # workers (row major storage--sum along rows).
    shareds.flat_g[self.vb[0]:self.vb[1]] = \
        np.sum(shareds.grads_2d[self.vb[0]:self.vb[1], :], axis=1)
    barriers.flat_g[1].wait()
def get_gradient(self, samples_data):
    """Evaluate multitask policy gradients ("f_grads") on samples_data.

    Builds the standard (obs, actions, advantages, state infos, dist
    infos[, valids]) inputs, then appends per-task observations, per-task
    old distribution infos, and the KL weights in the layout the
    multitask optimizer expects.
    """
    all_input_values = tuple(
        ext.extract(samples_data, "observations", "actions", "advantages"))
    agent_infos = samples_data["agent_infos"]
    state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
    dist_info_list = [
        agent_infos[k] for k in self.policy.distribution.dist_info_keys
    ]
    all_input_values += tuple(state_info_list) + tuple(dist_info_list)
    if self.policy.recurrent:
        all_input_values += (samples_data["valids"], )
    # multitask related
    # Pre-allocate one bucket per task for observations and old dist infos.
    task_obs = []
    task_old_dist_info_list = []
    task_old_dist_info = []
    for i in range(self.task_num):
        task_obs.append([])
        task_old_dist_info_list.append([])
        task_old_dist_info.append([])
        for k in self.policy.distribution.dist_info_keys:
            task_old_dist_info_list[i].append([])
    for i in range(len(samples_data["observations"])):
        taskid = np.random.randint(
            self.task_num
        )  # fake the taskid to satisfy the calculation requirement, very ugly
        task_obs[taskid].append(samples_data["observations"][i])
        for j, k in enumerate(self.policy.distribution.dist_info_keys):
            task_old_dist_info_list[taskid][j].append(
                samples_data["agent_infos"][k][i])
    # Convert the per-task Python lists into arrays.
    for i in range(self.task_num):
        for j, k in enumerate(self.policy.distribution.dist_info_keys):
            task_old_dist_info[i].append(
                np.array(task_old_dist_info_list[i][j]))
        task_obs[i] = np.array(task_obs[i])
    for i in range(self.task_num):
        all_input_values += tuple([task_obs[i]])
    for i in range(self.task_num):
        all_input_values += tuple(task_old_dist_info[i])
    all_input_values += tuple([self.kl_weights])
    grad = sliced_fun(self.optimizer._opt_fun["f_grads"],
                      1)((all_input_values))
    return grad
def mean_kl(self, samples_data):
    """Evaluate each KL constraint on multitask inputs built from samples_data.

    Builds the standard inputs plus per-task observations and old
    distribution infos (partitioned by env_infos["state_index"]) and the
    KL weights, then evaluates every function in ``self.f_constraints``.

    Returns:
        list: one KL value per constraint function.
    """
    all_input_values = tuple(
        ext.extract(samples_data, "observations", "actions", "advantages"))
    agent_infos = samples_data["agent_infos"]
    state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
    dist_info_list = [
        agent_infos[k] for k in self.policy.distribution.dist_info_keys
    ]
    all_input_values += tuple(state_info_list) + tuple(dist_info_list)
    if self.policy.recurrent:
        all_input_values += (samples_data["valids"], )
    # multitask related
    # Pre-allocate one bucket per task for observations and old dist infos.
    task_obs = []
    task_old_dist_info_list = []
    task_old_dist_info = []
    for i in range(self.task_num):
        task_obs.append([])
        task_old_dist_info_list.append([])
        task_old_dist_info.append([])
        for k in self.policy.distribution.dist_info_keys:
            task_old_dist_info_list[i].append([])
    # Partition samples by their true task id.
    for i in range(len(samples_data["observations"])):
        taskid = samples_data["env_infos"]["state_index"][i]
        task_obs[taskid].append(samples_data["observations"][i])
        for j, k in enumerate(self.policy.distribution.dist_info_keys):
            task_old_dist_info_list[taskid][j].append(
                samples_data["agent_infos"][k][i])
    # Convert the per-task Python lists into arrays.
    for i in range(self.task_num):
        for j, k in enumerate(self.policy.distribution.dist_info_keys):
            task_old_dist_info[i].append(
                np.array(task_old_dist_info_list[i][j]))
        task_obs[i] = np.array(task_obs[i])
    for i in range(self.task_num):
        all_input_values += tuple([task_obs[i]])
    for i in range(self.task_num):
        all_input_values += tuple(task_old_dist_info[i])
    all_input_values += tuple([self.kl_weights])
    kl_divs = []
    for constraint in self.f_constraints:
        kl_divs.append(sliced_fun(constraint, 1)(all_input_values))
    return kl_divs
def parallel_eval(x):
    """ Parallelized. """
    # Cooperative regularized Hessian-vector product: each worker writes
    # its slice-averaged HVP into its own column of a shared 2-D buffer;
    # after the first barrier, every worker reduces its assigned row range
    # and adds reg_coeff * x, producing (H + reg_coeff * I) x in shareds.Hx.
    # Closure: uses ``inputs`` and ``self`` from the enclosing scope.
    shareds, barriers = self._par_objs
    xs = tuple(self.target.flat_to_params(x, trainable=True))
    shareds.grads_2d[:, self.pd.rank] = self.pd.avg_fac * \
        sliced_fun(self.opt_fun["f_Hx_plain"], self._num_slices)(inputs, xs)
    barriers.Hx[0].wait()
    shareds.Hx[self.pd.vb[0]:self.pd.vb[1]] = \
        self.reg_coeff * x[self.pd.vb[0]:self.pd.vb[1]] + \
        np.sum(shareds.grads_2d[self.pd.vb[0]:self.pd.vb[1], :], axis=1)
    barriers.Hx[1].wait()
    return shareds.Hx  # (or can just access this persistent var elsewhere)
def get_gradient(self, samples_data):
    """Evaluate the task-specific gradient for the last sample's task.

    Partitions samples by env_infos["state_index"], then calls the
    gradient function for ``task_id`` on that task's (obs, actions,
    advantages, old dist infos) plus the shared state infos.
    """
    all_input_values = tuple(ext.extract(
        samples_data,
        "observations", "actions", "advantages"
    ))
    agent_infos = samples_data["agent_infos"]
    # The task of the final sample selects which gradient function to use.
    task_id = samples_data["env_infos"]["state_index"][-1]
    state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
    dist_info_list = [agent_infos[k]
                      for k in self.policy.distribution.dist_info_keys]
    # Pre-allocate one bucket per task.
    task_obs = []
    task_actions = []
    task_advantages = []
    task_old_dist_info_list = []
    task_old_dist_info = []
    for i in range(self.task_num):
        task_obs.append([])
        task_actions.append([])
        task_advantages.append([])
        task_old_dist_info_list.append([])
        task_old_dist_info.append([])
        for k in self.policy.distribution.dist_info_keys:
            task_old_dist_info_list[i].append([])
    # Route every sample into its task's bucket.
    for i in range(len(samples_data["observations"])):
        taskid = samples_data["env_infos"]["state_index"][i]
        task_obs[taskid].append(samples_data["observations"][i])
        task_actions[taskid].append(samples_data["actions"][i])
        task_advantages[taskid].append(samples_data["advantages"][i])
        for j, k in enumerate(self.policy.distribution.dist_info_keys):
            task_old_dist_info_list[taskid][j].append(
                samples_data["agent_infos"][k][i])
    # Convert the per-task Python lists into arrays.
    for i in range(self.task_num):
        for j, k in enumerate(self.policy.distribution.dist_info_keys):
            task_old_dist_info[i].append(
                np.array(task_old_dist_info_list[i][j]))
        task_obs[i] = np.array(task_obs[i])
    # NOTE(review): all_input_values and dist_info_list are built but not
    # used past this point — confirm whether that is intentional.
    input_values = tuple([task_obs[task_id]]) + tuple(
        [task_actions[task_id]]) + tuple([task_advantages[task_id]]) + tuple(
            task_old_dist_info[task_id]) + tuple(state_info_list)
    grad = sliced_fun(self.f_task_grads[task_id], 1)(
        (input_values))
    return grad
def wrap_up():
    """Record safety-constraint prediction diagnostics after the update.

    Closure: uses ``optim_case``, ``S``, ``flat_b``, ``prev_param``,
    ``prev_lin_constraint_val``, ``inputs``, ``extra_inputs`` and ``self``
    from the enclosing optimize(). Compares the linear-model prediction
    of the next constraint value against the surrogate's prediction, and
    logs how well last iteration's predictions matched the current S.
    """
    if optim_case < 4:
        lin_constraint_val = sliced_fun(
            self._opt_fun["f_lin_constraint"],
            self._num_slices)(inputs, extra_inputs)
        lin_constraint_delta = lin_constraint_val - prev_lin_constraint_val
        logger.record_tabular("LinConstraintDelta", lin_constraint_delta)
        cur_param = self._target.get_param_values()
        # Linear prediction uses the constraint gradient; surrogate uses
        # the measured change in the (possibly surrogate) constraint fn.
        next_linear_S = S + flat_b.dot(cur_param - prev_param)
        next_surrogate_S = S + lin_constraint_delta
        lin_surrogate_acc = 100. * (
            next_linear_S - next_surrogate_S) / next_surrogate_S
        logger.record_tabular("PredictedLinearS", next_linear_S)
        logger.record_tabular("PredictedSurrogateS", next_surrogate_S)
        logger.record_tabular("LinearSurrogateErr", lin_surrogate_acc)
        lin_pred_err = (self._last_lin_pred_S - S)  #/ (S + eps)
        surr_pred_err = (self._last_surr_pred_S - S)  #/ (S + eps)
        logger.record_tabular("PredictionErrorLinearS", lin_pred_err)
        logger.record_tabular("PredictionErrorSurrogateS", surr_pred_err)
        self._last_lin_pred_S = next_linear_S
        self._last_surr_pred_S = next_surrogate_S
    else:
        # optim_case == 4: no safety gradient; record zeros so the tabular
        # logger sees a consistent set of keys every iteration.
        logger.record_tabular("LinConstraintDelta", 0)
        logger.record_tabular("PredictedLinearS", 0)
        logger.record_tabular("PredictedSurrogateS", 0)
        logger.record_tabular("LinearSurrogateErr", 0)
        lin_pred_err = (self._last_lin_pred_S - 0)  #/ (S + eps)
        surr_pred_err = (self._last_surr_pred_S - 0)  #/ (S + eps)
        logger.record_tabular("PredictionErrorLinearS", lin_pred_err)
        logger.record_tabular("PredictionErrorSurrogateS", surr_pred_err)
        self._last_lin_pred_S = 0
        self._last_surr_pred_S = 0
def optimize_expert_policies(self, itr, all_samples_data):
    """Run each expert's optimizer on its own task's samples.

    For expert n the input tuple is (obs, actions, advantages), the old
    distribution infos, and every task's observations appended twice —
    presumably matching the optimizer's expected input layout; TODO
    confirm the duplication is intentional.
    """
    dist_info_keys = self.policy.distribution.dist_info_keys
    for n, optimizer in enumerate(self.optimizers):
        obs_act_adv_values = tuple(
            ext.extract(all_samples_data[n], "observations", "actions",
                        "advantages"))
        dist_info_list = tuple([
            all_samples_data[n]["agent_infos"][k] for k in dist_info_keys
        ])
        all_task_obs_values = tuple([
            samples_data["observations"]
            for samples_data in all_samples_data
        ])
        all_input_values = obs_act_adv_values + dist_info_list + \
            all_task_obs_values + all_task_obs_values
        optimizer.optimize(all_input_values)
        # NOTE(review): kl_penalty is computed but never used or logged in
        # this method — confirm whether the call is needed at all.
        kl_penalty = sliced_fun(self.metrics[n], 1)(all_input_values)
def eval(x):
    """Regularized Hessian-vector product for CG: (H + reg_coeff * I) x."""
    param_values = tuple(self.target.flat_to_params(x, trainable=True))
    hx = sliced_fun(self.opt_fun["f_Hx_plain"],
                    self._num_slices)(inputs, param_values)
    return hx + self.reg_coeff * x
def optimize(
        self,
        inputs,
        extra_inputs=None,
        subsample_grouped_inputs=None,
        precomputed_eval=None,
        precomputed_threshold=None,
        diff_threshold=False,
        inputs2=None,
        extra_inputs2=None,
):
    """Constrained natural-gradient update with a linear safety constraint.

    precomputed_eval : The value of the safety constraint at
    theta = theta_old. Provide this when the lin_constraint function is a
    surrogate, and evaluating it at theta_old will not give you the
    correct value.

    precomputed_threshold & diff_threshold : These relate to the
    linesearch that is used to ensure constraint satisfaction. If the
    lin_constraint function is indeed the safety constraint function,
    then it suffices to check that
    lin_constraint < max_lin_constraint_val to ensure satisfaction. But
    if the lin_constraint function is a surrogate - ie, it only has the
    same /gradient/ as the safety constraint - then the threshold we
    check it against has to be adjusted. You can provide a fixed adjusted
    threshold via "precomputed_threshold." When "diff_threshold" == True,
    instead of checking lin_constraint < threshold, it will check
    lin_constraint - old_lin_constraint < threshold.
    """
    inputs = tuple(inputs)
    if extra_inputs is None:
        extra_inputs = tuple()

    # inputs2 and extra_inputs2 are for calculation of the linearized
    # constraint. This functionality - of having separate inputs for that
    # constraint - is intended to allow a "learning without forgetting"
    # setup.
    if inputs2 is None:
        inputs2 = inputs
    if extra_inputs2 is None:
        extra_inputs2 = tuple()

    def subsampled_inputs(inputs, subsample_grouped_inputs):
        # Draw a random _subsample_factor fraction from each input group
        # (used to cheapen the Hessian-vector products).
        if self._subsample_factor < 1:
            if subsample_grouped_inputs is None:
                subsample_grouped_inputs = [inputs]
            subsample_inputs = tuple()
            for inputs_grouped in subsample_grouped_inputs:
                n_samples = len(inputs_grouped[0])
                inds = np.random.choice(
                    n_samples,
                    int(n_samples * self._subsample_factor),
                    replace=False)
                subsample_inputs += tuple(
                    [x[inds] for x in inputs_grouped])
        else:
            subsample_inputs = inputs
        return subsample_inputs

    subsample_inputs = subsampled_inputs(inputs, subsample_grouped_inputs)
    if self._resample_inputs:
        subsample_inputs2 = subsampled_inputs(inputs,
                                              subsample_grouped_inputs)

    logger.log("computing loss before")
    loss_before = sliced_fun(self._opt_fun["f_loss"],
                             self._num_slices)(inputs, extra_inputs)
    logger.log("performing update")
    logger.log("computing descent direction")

    # g: objective gradient; b: gradient of the linearized safety constraint.
    flat_g = sliced_fun(self._opt_fun["f_grad"],
                        self._num_slices)(inputs, extra_inputs)
    flat_b = sliced_fun(self._opt_fun["f_lin_constraint_grad"],
                        self._num_slices)(inputs2, extra_inputs2)

    # v approximates H^{-1} g via conjugate gradient.
    Hx = self._hvp_approach.build_eval(subsample_inputs + extra_inputs)
    v = krylov.cg(Hx, flat_g, cg_iters=self._cg_iters,
                  verbose=self._verbose_cg)
    approx_g = Hx(v)
    q = v.dot(approx_g)  # approx = g^T H^{-1} g
    delta = 2 * self._max_quad_constraint_val
    eps = 1e-8

    residual = np.sqrt((approx_g - flat_g).dot(approx_g - flat_g))
    rescale = q / (v.dot(v))
    logger.record_tabular("OptimDiagnostic_Residual", residual)
    logger.record_tabular("OptimDiagnostic_Rescale", rescale)

    if self.precompute:
        S = precomputed_eval
        assert (np.ndim(S) == 0)  # please be a scalar
    else:
        S = sliced_fun(self._opt_fun["lin_constraint"],
                       self._num_slices)(inputs, extra_inputs)

    # c > 0 means the safety constraint is currently violated.
    c = S - self._max_lin_constraint_val
    if c > 0:
        logger.log("warning! safety constraint is already violated")
    else:
        # the current parameters constitute a feasible point: save it as
        # "last good point"
        self.last_safe_point = np.copy(
            self._target.get_param_values(trainable=True))

    # can't stop won't stop (unless something in the conditional checks /
    # calculations that follow require premature stopping of optimization
    # process)
    stop_flag = False

    if flat_b.dot(flat_b) <= eps:
        # if safety gradient is zero, linear constraint is not present;
        # ignore its implementation.
        lam = np.sqrt(q / delta)
        nu = 0
        w = 0
        r, s, A, B = 0, 0, 0, 0
        optim_case = 4
    else:
        if self._resample_inputs:
            Hx = self._hvp_approach.build_eval(subsample_inputs2 +
                                               extra_inputs)
        norm_b = np.sqrt(flat_b.dot(flat_b))
        unit_b = flat_b / norm_b
        w = norm_b * krylov.cg(
            Hx, unit_b, cg_iters=self._cg_iters, verbose=self._verbose_cg)
        r = w.dot(approx_g)  # approx = b^T H^{-1} g
        s = w.dot(Hx(w))  # approx = b^T H^{-1} b

        # figure out lambda coeff (lagrange multiplier for trust region)
        # and nu coeff (lagrange multiplier for linear constraint)
        A = q - r**2 / s  # this should always be positive by Cauchy-Schwarz
        B = delta - c**2 / s  # this one says whether or not the closest point on the plane is feasible

        # if (B < 0), that means the trust region plane doesn't intersect
        # the safety boundary
        if c < 0 and B < 0:
            # point in trust region is feasible and safety boundary doesn't
            # intersect ==> entire trust region is feasible
            optim_case = 3
        elif c < 0 and B > 0:
            # x = 0 is feasible and safety boundary intersects
            # ==> most of trust region is feasible
            optim_case = 2
        elif c > 0 and B > 0:
            # x = 0 is infeasible (bad! unsafe!) and safety boundary
            # intersects ==> part of trust region is feasible
            # ==> this is 'recovery mode'
            optim_case = 1
            if self.attempt_feasible_recovery:
                logger.log(
                    "alert! conjugate constraint optimizer is attempting feasible recovery"
                )
            else:
                logger.log(
                    "alert! problem is feasible but needs recovery, and we were instructed not to attempt recovery"
                )
                stop_flag = True
        else:
            # x = 0 infeasible (bad! unsafe!) and safety boundary doesn't
            # intersect ==> whole trust region infeasible
            # ==> optimization problem infeasible!!!
            optim_case = 0
            if self.attempt_infeasible_recovery:
                logger.log(
                    "alert! conjugate constraint optimizer is attempting infeasible recovery"
                )
            else:
                logger.log(
                    "alert! problem is infeasible, and we were instructed not to attempt recovery"
                )
                stop_flag = True

        # default dual vars, which assume safety constraint inactive
        # (this corresponds to either optim_case == 3,
        # or optim_case == 2 under certain conditions)
        lam = np.sqrt(q / delta)
        nu = 0

        if optim_case == 2 or optim_case == 1:
            # dual function is piecewise continuous
            # on region (a):
            #
            #   L(lam) = -1/2 (A / lam + B * lam) - r * c / s
            #
            # on region (b):
            #
            #   L(lam) = -1/2 (q / lam + delta * lam)
            #
            lam_mid = r / c
            L_mid = -0.5 * (q / lam_mid + lam_mid * delta)

            lam_a = np.sqrt(A / (B + eps))
            L_a = -np.sqrt(A * B) - r * c / (s + eps)
            # note that for optim_case == 1 or 2, B > 0, so this
            # calculation should never be an issue

            lam_b = np.sqrt(q / delta)
            L_b = -np.sqrt(q * delta)

            # those lam's are solns to the pieces of piecewise continuous
            # dual function. the domains of the pieces depend on whether
            # or not c < 0 (x=0 feasible), and so projection back on to
            # those domains is determined appropriately.
            if lam_mid > 0:
                if c < 0:
                    # here, domain of (a) is [0, lam_mid)
                    # and domain of (b) is (lam_mid, infty)
                    if lam_a > lam_mid:
                        lam_a = lam_mid
                        L_a = L_mid
                    if lam_b < lam_mid:
                        lam_b = lam_mid
                        L_b = L_mid
                else:
                    # here, domain of (a) is (lam_mid, infty)
                    # and domain of (b) is [0, lam_mid)
                    if lam_a < lam_mid:
                        lam_a = lam_mid
                        L_a = L_mid
                    if lam_b > lam_mid:
                        lam_b = lam_mid
                        L_b = L_mid

                if L_a >= L_b:
                    lam = lam_a
                else:
                    lam = lam_b
            else:
                if c < 0:
                    lam = lam_b
                else:
                    lam = lam_a

            nu = max(0, lam * c - r) / (s + eps)

    logger.record_tabular("OptimCase", optim_case)
    # 4 / 3: trust region totally in safe region;
    # 2 : trust region partly intersects safe region, and current point is feasible
    # 1 : trust region partly intersects safe region, and current point is infeasible
    # 0 : trust region does not intersect safe region
    logger.record_tabular("LagrangeLamda",
                          lam)  # dual variable for trust region
    logger.record_tabular("LagrangeNu",
                          nu)  # dual variable for safety constraint
    logger.record_tabular("OptimDiagnostic_q", q)  # approx = g^T H^{-1} g
    logger.record_tabular("OptimDiagnostic_r", r)  # approx = b^T H^{-1} g
    logger.record_tabular("OptimDiagnostic_s", s)  # approx = b^T H^{-1} b
    logger.record_tabular("OptimDiagnostic_c",
                          c)  # if > 0, constraint is violated
    logger.record_tabular("OptimDiagnostic_A", A)
    logger.record_tabular("OptimDiagnostic_B", B)
    logger.record_tabular("OptimDiagnostic_S", S)
    if nu == 0:
        logger.log("safety constraint is not active!")

    # Predict worst-case next S
    nextS = S + np.sqrt(delta * s)
    logger.record_tabular("OptimDiagnostic_WorstNextS", nextS)

    # for cases where we will not attempt recovery, we stop here. we
    # didn't stop earlier because first we wanted to record the various
    # critical quantities for understanding the failure mode (such as
    # optim_case, B, c, S). Also, the logger gets angry if you are
    # inconsistent about recording a given quantity from iteration to
    # iteration. That's why we have to record a BacktrackIters here.
    def record_zeros():
        # Keeps the tabular logger's key set consistent on early exits.
        logger.record_tabular("BacktrackIters", 0)
        logger.record_tabular("LossRejects", 0)
        logger.record_tabular("QuadRejects", 0)
        logger.record_tabular("LinRejects", 0)

    if optim_case > 0:
        flat_descent_step = (1. / (lam + eps)) * (v + nu * w)
    else:
        # current default behavior for attempting infeasible recovery:
        # take a step on natural safety gradient
        flat_descent_step = np.sqrt(delta / (s + eps)) * w

    logger.log("descent direction computed")

    prev_param = np.copy(self._target.get_param_values(trainable=True))
    prev_lin_constraint_val = sliced_fun(
        self._opt_fun["f_lin_constraint"],
        self._num_slices)(inputs, extra_inputs)
    logger.record_tabular("PrevLinConstVal", prev_lin_constraint_val)

    # Threshold for accepting the linear constraint in the line search;
    # see the docstring for precomputed_threshold / diff_threshold.
    lin_reject_threshold = self._max_lin_constraint_val
    if precomputed_threshold is not None:
        lin_reject_threshold = precomputed_threshold
    if diff_threshold:
        lin_reject_threshold += prev_lin_constraint_val
    logger.record_tabular("LinRejectThreshold", lin_reject_threshold)

    def check_nan():
        # Revert to prev_param if any evaluated quantity is NaN.
        loss, quad_constraint_val, lin_constraint_val = sliced_fun(
            self._opt_fun["f_loss_constraint"],
            self._num_slices)(inputs, extra_inputs)
        if np.isnan(loss) or np.isnan(quad_constraint_val) or np.isnan(
                lin_constraint_val):
            logger.log("Something is NaN. Rejecting the step!")
            if np.isnan(loss):
                logger.log("Violated because loss is NaN")
            if np.isnan(quad_constraint_val):
                logger.log("Violated because quad_constraint %s is NaN" %
                           self._constraint_name_1)
            if np.isnan(lin_constraint_val):
                logger.log("Violated because lin_constraint %s is NaN" %
                           self._constraint_name_2)
            self._target.set_param_values(prev_param, trainable=True)

    def line_search(check_loss=True, check_quad=True, check_lin=True):
        # Backtracking search; each check_* flag enables one acceptance
        # criterion. Returns (loss, quad, lin, n_iter).
        loss_rejects = 0
        quad_rejects = 0
        lin_rejects = 0
        n_iter = 0
        for n_iter, ratio in enumerate(self._backtrack_ratio**np.arange(
                self._max_backtracks)):
            cur_step = ratio * flat_descent_step
            cur_param = prev_param - cur_step
            self._target.set_param_values(cur_param, trainable=True)
            loss, quad_constraint_val, lin_constraint_val = sliced_fun(
                self._opt_fun["f_loss_constraint"],
                self._num_slices)(inputs, extra_inputs)
            loss_flag = loss < loss_before
            quad_flag = quad_constraint_val <= self._max_quad_constraint_val
            lin_flag = lin_constraint_val <= lin_reject_threshold
            if check_loss and not (loss_flag):
                logger.log("At backtrack itr %i, loss failed to improve." %
                           n_iter)
                loss_rejects += 1
            if check_quad and not (quad_flag):
                logger.log("At backtrack itr %i, quad constraint violated."
                           % n_iter)
                logger.log("Quad constraint violation was %.3f %%." %
                           (100 * (quad_constraint_val /
                                   self._max_quad_constraint_val) - 100))
                quad_rejects += 1
            if check_lin and not (lin_flag):
                logger.log(
                    "At backtrack itr %i, expression for lin constraint failed to improve."
                    % n_iter)
                logger.log("Lin constraint violation was %.3f %%." %
                           (100 * (lin_constraint_val /
                                   lin_reject_threshold) - 100))
                lin_rejects += 1
            if (loss_flag or not (check_loss)) and (
                    quad_flag or not (check_quad)) and (
                        lin_flag or not (check_lin)):
                logger.log("Accepted step at backtrack itr %i." % n_iter)
                break
        logger.record_tabular("BacktrackIters", n_iter)
        logger.record_tabular("LossRejects", loss_rejects)
        logger.record_tabular("QuadRejects", quad_rejects)
        logger.record_tabular("LinRejects", lin_rejects)
        return loss, quad_constraint_val, lin_constraint_val, n_iter

    def wrap_up():
        # Record diagnostics comparing the linear and surrogate
        # predictions of the safety constraint at the new parameters.
        if optim_case < 4:
            lin_constraint_val = sliced_fun(
                self._opt_fun["f_lin_constraint"],
                self._num_slices)(inputs, extra_inputs)
            lin_constraint_delta = lin_constraint_val - prev_lin_constraint_val
            logger.record_tabular("LinConstraintDelta",
                                  lin_constraint_delta)
            cur_param = self._target.get_param_values()
            next_linear_S = S + flat_b.dot(cur_param - prev_param)
            next_surrogate_S = S + lin_constraint_delta
            lin_surrogate_acc = 100. * (
                next_linear_S - next_surrogate_S) / next_surrogate_S
            logger.record_tabular("PredictedLinearS", next_linear_S)
            logger.record_tabular("PredictedSurrogateS", next_surrogate_S)
            logger.record_tabular("LinearSurrogateErr", lin_surrogate_acc)
            lin_pred_err = (self._last_lin_pred_S - S)  #/ (S + eps)
            surr_pred_err = (self._last_surr_pred_S - S)  #/ (S + eps)
            logger.record_tabular("PredictionErrorLinearS", lin_pred_err)
            logger.record_tabular("PredictionErrorSurrogateS",
                                  surr_pred_err)
            self._last_lin_pred_S = next_linear_S
            self._last_surr_pred_S = next_surrogate_S
        else:
            logger.record_tabular("LinConstraintDelta", 0)
            logger.record_tabular("PredictedLinearS", 0)
            logger.record_tabular("PredictedSurrogateS", 0)
            logger.record_tabular("LinearSurrogateErr", 0)
            lin_pred_err = (self._last_lin_pred_S - 0)  #/ (S + eps)
            surr_pred_err = (self._last_surr_pred_S - 0)  #/ (S + eps)
            logger.record_tabular("PredictionErrorLinearS", lin_pred_err)
            logger.record_tabular("PredictionErrorSurrogateS",
                                  surr_pred_err)
            self._last_lin_pred_S = 0
            self._last_surr_pred_S = 0

    if stop_flag == True:
        record_zeros()
        wrap_up()
        return

    # Recovery-mode branches: constrained step / safety step / revert.
    if optim_case == 1 and not (self.revert_to_last_safe_point):
        if self._linesearch_infeasible_recovery:
            logger.log(
                "feasible recovery mode: constrained natural gradient step. performing linesearch on constraints."
            )
            line_search(False, True, True)
        else:
            self._target.set_param_values(prev_param - flat_descent_step,
                                          trainable=True)
            logger.log(
                "feasible recovery mode: constrained natural gradient step. no linesearch performed."
            )
        check_nan()
        record_zeros()
        wrap_up()
        return
    elif optim_case == 0 and not (self.revert_to_last_safe_point):
        if self._linesearch_infeasible_recovery:
            logger.log(
                "infeasible recovery mode: natural safety step. performing linesearch on constraints."
            )
            line_search(False, True, True)
        else:
            self._target.set_param_values(prev_param - flat_descent_step,
                                          trainable=True)
            logger.log(
                "infeasible recovery mode: natural safety gradient step. no linesearch performed."
            )
        check_nan()
        record_zeros()
        wrap_up()
        return
    elif (optim_case == 0
          or optim_case == 1) and self.revert_to_last_safe_point:
        if self.last_safe_point:
            self._target.set_param_values(self.last_safe_point,
                                          trainable=True)
            logger.log(
                "infeasible recovery mode: reverted to last safe point!")
        else:
            logger.log(
                "alert! infeasible recovery mode failed: no last safe point to revert to."
            )
        record_zeros()
        wrap_up()
        return

    # Normal case: full line search over all three criteria.
    loss, quad_constraint_val, lin_constraint_val, n_iter = line_search()
    if (np.isnan(loss) or np.isnan(quad_constraint_val)
            or np.isnan(lin_constraint_val) or loss >= loss_before
            or quad_constraint_val >= self._max_quad_constraint_val
            or lin_constraint_val > lin_reject_threshold
        ) and not self._accept_violation:
        logger.log("Line search condition violated. Rejecting the step!")
        if np.isnan(loss):
            logger.log("Violated because loss is NaN")
        if np.isnan(quad_constraint_val):
            logger.log("Violated because quad_constraint %s is NaN" %
                       self._constraint_name_1)
        if np.isnan(lin_constraint_val):
            logger.log("Violated because lin_constraint %s is NaN" %
                       self._constraint_name_2)
        if loss >= loss_before:
            logger.log("Violated because loss not improving")
        if quad_constraint_val >= self._max_quad_constraint_val:
            logger.log("Violated because constraint %s is violated" %
                       self._constraint_name_1)
        if lin_constraint_val > lin_reject_threshold:
            logger.log("Violated because constraint %s exceeded threshold"
                       % self._constraint_name_2)
        self._target.set_param_values(prev_param, trainable=True)
    logger.log("backtrack iters: %d" % n_iter)
    logger.log("computing loss after")
    logger.log("optimization finished")
    wrap_up()
def constraint_val(self, inputs, extra_inputs=None):
    """Return the constraint value evaluated over ``inputs`` in memory-bounded slices."""
    extras = tuple() if extra_inputs is None else extra_inputs
    evaluator = sliced_fun(self._opt_fun["f_constraint"], self._num_slices)
    return evaluator(tuple(inputs), extras)
def loss(self, inputs, extra_inputs=None):
    """Return the loss evaluated over ``inputs`` in memory-bounded slices."""
    if extra_inputs is None:
        extra_inputs = tuple()
    loss_fn = sliced_fun(self._opt_fun["f_loss"], self._num_slices)
    return loss_fn(tuple(inputs), extra_inputs)
def optimize(self, inputs, extra_inputs=None, subsample_grouped_inputs=None):
    """Take one constrained natural-gradient step with a backtracking line search.

    Computes the policy gradient on the full ``inputs``, solves for the
    descent direction with conjugate gradient against a (possibly
    subsampled) Hessian-vector product, then backtracks until both the
    loss improves and the KL-style constraint is satisfied.

    :param inputs: tuple of arrays fed to the compiled loss/grad functions.
    :param extra_inputs: additional inputs that are NOT subsampled/sliced.
    :param subsample_grouped_inputs: optional list of input groups; each
        group is subsampled independently for the HVP estimate.
    """
    prev_param = np.copy(self._target.get_param_values(trainable=True))
    inputs = tuple(inputs)
    if extra_inputs is None:
        extra_inputs = tuple()
    # Subsample the data used for Hessian-vector products only; gradients
    # and losses below still use the full inputs.
    if self._subsample_factor < 1:
        if subsample_grouped_inputs is None:
            subsample_grouped_inputs = [inputs]
        subsample_inputs = tuple()
        for inputs_grouped in subsample_grouped_inputs:
            n_samples = len(inputs_grouped[0])
            inds = np.random.choice(n_samples,
                                    int(n_samples * self._subsample_factor),
                                    replace=False)
            subsample_inputs += tuple([x[inds] for x in inputs_grouped])
    else:
        subsample_inputs = inputs
    logger.log(
        "Start CG optimization: #parameters: %d, #inputs: %d, #subsample_inputs: %d"
        % (len(prev_param), len(inputs[0]), len(subsample_inputs[0])))
    logger.log("computing loss before")
    loss_before = sliced_fun(self._opt_fun["f_loss"],
                             self._num_slices)(inputs, extra_inputs)
    logger.log("performing update")
    logger.log("computing gradient")
    flat_g = sliced_fun(self._opt_fun["f_grad"],
                        self._num_slices)(inputs, extra_inputs)
    logger.log("gradient computed")
    logger.log("computing descent direction")
    # Solve H x = g approximately via conjugate gradient.
    Hx = self._hvp_approach.build_eval(subsample_inputs + extra_inputs)
    descent_direction = krylov.cg(Hx, flat_g, cg_iters=self._cg_iters)
    # Scale the step so the quadratic constraint model equals its bound:
    # sqrt(2 * max_constraint / (d^T H d)); 1e-8 guards division by zero.
    initial_step_size = np.sqrt(
        2.0 * self._max_constraint_val *
        (1. / (descent_direction.dot(Hx(descent_direction)) + 1e-8)))
    if np.isnan(initial_step_size):
        initial_step_size = 1.
    flat_descent_step = initial_step_size * descent_direction
    logger.log("descent direction computed")
    n_iter = 0
    # Backtracking line search: shrink the step geometrically until both
    # the loss improves and the constraint holds.
    for n_iter, ratio in enumerate(self._backtrack_ratio**np.arange(
            self._max_backtracks)):
        cur_step = ratio * flat_descent_step
        cur_param = prev_param - cur_step
        self._target.set_param_values(cur_param, trainable=True)
        loss, constraint_val = sliced_fun(
            self._opt_fun["f_loss_constraint"],
            self._num_slices)(inputs, extra_inputs)
        if self._debug_nan and np.isnan(constraint_val):
            import ipdb
            ipdb.set_trace()
        if loss < loss_before and constraint_val <= self._max_constraint_val:
            break
    # NOTE: `loss`/`constraint_val` below are the values from the last
    # line-search iterate (loop variables reused after the loop).
    if (np.isnan(loss) or np.isnan(constraint_val) or loss >= loss_before
            or constraint_val >= self._max_constraint_val
            ) and not self._accept_violation:
        logger.log("Line search condition violated. Rejecting the step!")
        if np.isnan(loss):
            logger.log("Violated because loss is NaN")
        if np.isnan(constraint_val):
            logger.log("Violated because constraint %s is NaN" %
                       self._constraint_name)
        if loss >= loss_before:
            logger.log("Violated because loss not improving")
        if constraint_val >= self._max_constraint_val:
            logger.log("Violated because constraint %s is violated" %
                       self._constraint_name)
        # Revert to the pre-update parameters on rejection.
        self._target.set_param_values(prev_param, trainable=True)
    logger.log("backtrack iters: %d" % n_iter)
    logger.log("computing loss after")
    logger.log("optimization finished")
def eval(x):
    # Regularized Hessian-vector product: H x + reg_coeff * x, evaluated
    # in slices to bound memory use.
    param_tuple = tuple(self.target.flat_to_params(x, trainable=True))
    hvp = sliced_fun(self.opt_fun["f_Hx_plain"], self._num_slices)(
        inputs, param_tuple)
    return hvp + self.reg_coeff * x
def optimize(self, inputs, extra_inputs=None, subsample_grouped_inputs=None):
    """Constrained natural-gradient update with backtracking line search.

    Same algorithm as the other CG optimizer in this file: full-data
    gradient, CG solve against a (possibly subsampled) HVP for the
    descent direction, then geometric backtracking until the loss
    improves and the constraint is met.

    :param inputs: tuple of arrays for the compiled loss/grad functions.
    :param extra_inputs: additional inputs that are not subsampled/sliced.
    :param subsample_grouped_inputs: optional list of input groups, each
        subsampled independently for the HVP estimate.
    """
    prev_param = np.copy(self._target.get_param_values(trainable=True))
    inputs = tuple(inputs)
    if extra_inputs is None:
        extra_inputs = tuple()
    # Subsampling affects only the curvature (HVP) estimate.
    if self._subsample_factor < 1:
        if subsample_grouped_inputs is None:
            subsample_grouped_inputs = [inputs]
        subsample_inputs = tuple()
        for inputs_grouped in subsample_grouped_inputs:
            n_samples = len(inputs_grouped[0])
            inds = np.random.choice(
                n_samples, int(n_samples * self._subsample_factor),
                replace=False)
            subsample_inputs += tuple([x[inds] for x in inputs_grouped])
    else:
        subsample_inputs = inputs
    logger.log("Start CG optimization: #parameters: %d, #inputs: %d, #subsample_inputs: %d"%(len(prev_param),len(inputs[0]), len(subsample_inputs[0])))
    logger.log("computing loss before")
    loss_before = sliced_fun(self._opt_fun["f_loss"], self._num_slices)(inputs, extra_inputs)
    logger.log("performing update")
    logger.log("computing gradient")
    flat_g = sliced_fun(self._opt_fun["f_grad"], self._num_slices)(inputs, extra_inputs)
    logger.log("gradient computed")
    logger.log("computing descent direction")
    # Approximately solve H x = g with conjugate gradient.
    Hx = self._hvp_approach.build_eval(subsample_inputs + extra_inputs)
    descent_direction = krylov.cg(Hx, flat_g, cg_iters=self._cg_iters)
    # Step size that saturates the quadratic constraint model; the 1e-8
    # guards against division by zero when curvature is ~0.
    initial_step_size = np.sqrt(
        2.0 * self._max_constraint_val *
        (1. / (descent_direction.dot(Hx(descent_direction)) + 1e-8))
    )
    if np.isnan(initial_step_size):
        initial_step_size = 1.
    flat_descent_step = initial_step_size * descent_direction
    logger.log("descent direction computed")
    n_iter = 0
    # Backtrack geometrically until loss improves and constraint holds.
    for n_iter, ratio in enumerate(self._backtrack_ratio ** np.arange(self._max_backtracks)):
        cur_step = ratio * flat_descent_step
        cur_param = prev_param - cur_step
        self._target.set_param_values(cur_param, trainable=True)
        loss, constraint_val = sliced_fun(self._opt_fun["f_loss_constraint"], self._num_slices)(inputs, extra_inputs)
        if self._debug_nan and np.isnan(constraint_val):
            import ipdb; ipdb.set_trace()
        if loss < loss_before and constraint_val <= self._max_constraint_val:
            break
    # `loss`/`constraint_val` are reused from the final line-search iterate.
    if (np.isnan(loss) or np.isnan(constraint_val) or loss >= loss_before
            or constraint_val >= self._max_constraint_val) and not self._accept_violation:
        logger.log("Line search condition violated. Rejecting the step!")
        if np.isnan(loss):
            logger.log("Violated because loss is NaN")
        if np.isnan(constraint_val):
            logger.log("Violated because constraint %s is NaN" % self._constraint_name)
        if loss >= loss_before:
            logger.log("Violated because loss not improving")
        if constraint_val >= self._max_constraint_val:
            logger.log("Violated because constraint %s is violated" % self._constraint_name)
        # Roll back to the pre-update parameters.
        self._target.set_param_values(prev_param, trainable=True)
    logger.log("backtrack iters: %d" % n_iter)
    logger.log("computing loss after")
    logger.log("optimization finished")
def constraint_val(self, inputs, extra_inputs=None):
    """Evaluate the constraint term over the whole dataset, slice by slice."""
    all_inputs = tuple(inputs)
    if extra_inputs is None:
        extra_inputs = tuple()
    return sliced_fun(self._opt_fun["f_constraint"],
                      self._num_slices)(all_inputs, extra_inputs)
def loss(self, inputs, extra_inputs=None):
    """Evaluate the surrogate loss over the whole dataset, slice by slice."""
    extras = extra_inputs if extra_inputs is not None else tuple()
    return sliced_fun(self._opt_fun["f_loss"],
                      self._num_slices)(tuple(inputs), extras)
def optimize(self, inputs, extra_inputs=None, subsample_grouped_inputs=None,
             precomputed_eval=None, precomputed_threshold=None,
             diff_threshold=False, inputs2=None, extra_inputs2=None, ):
    """One step of constrained optimization with a quadratic (KL-style) trust
    region AND a linear safety constraint, solved via a dual formulation.

    The step direction combines the loss gradient ``flat_g`` and the safety
    constraint gradient ``flat_b`` through two CG solves against the HVP
    ``Hx``, with dual variables ``lam`` (trust region) and ``nu`` (safety).
    ``optim_case`` classifies feasibility (4 = safety gradient ~0, 3/2 =
    feasible, 1 = feasible-but-needs-recovery, 0 = infeasible) and selects
    between a normal line-searched step and a recovery step.

    :param precomputed_eval: scalar safety-constraint value S supplied by the
        caller when ``self.precompute`` is set (skips re-evaluation).
    :param precomputed_threshold: overrides the linear-constraint rejection
        threshold when given.
    :param diff_threshold: when True the threshold is interpreted as a delta
        and the previous constraint value is added to it.
    :param inputs2 / extra_inputs2: data used for the safety-constraint
        gradient; default to ``inputs`` / empty tuple.
    """
    inputs = tuple(inputs)
    if extra_inputs is None:
        extra_inputs = tuple()
    if inputs2 is None:
        inputs2 = inputs
    if extra_inputs2 is None:
        extra_inputs2 = tuple()

    def subsampled_inputs(inputs, subsample_grouped_inputs):
        # Subsample each input group independently; used only for HVPs.
        if self._subsample_factor < 1:
            if subsample_grouped_inputs is None:
                subsample_grouped_inputs = [inputs]
            subsample_inputs = tuple()
            for inputs_grouped in subsample_grouped_inputs:
                n_samples = len(inputs_grouped[0])
                inds = np.random.choice(
                    n_samples, int(n_samples * self._subsample_factor),
                    replace=False)
                subsample_inputs += tuple([x[inds] for x in inputs_grouped])
        else:
            subsample_inputs = inputs
        return subsample_inputs

    subsample_inputs = subsampled_inputs(inputs, subsample_grouped_inputs)
    if self._resample_inputs:
        # Draw a second, independent subsample for the second CG solve.
        subsample_inputs2 = subsampled_inputs(inputs, subsample_grouped_inputs)
    loss_before = sliced_fun(self._opt_fun["f_loss"], self._num_slices)(
        inputs, extra_inputs)
    flat_g = sliced_fun(self._opt_fun["f_grad"], self._num_slices)(
        inputs, extra_inputs)
    flat_b = sliced_fun(self._opt_fun["f_lin_constraint_grad"],
                        self._num_slices)(inputs2, extra_inputs2)
    Hx = self._hvp_approach.build_eval(subsample_inputs + extra_inputs)
    # v ~= H^{-1} g; q = g^T H^{-1} g (approximately, via the CG solution).
    v = krylov.cg(Hx, flat_g, cg_iters=self._cg_iters,
                  verbose=self._verbose_cg)
    approx_g = Hx(v)
    q = v.dot(approx_g)
    delta = 2 * self._max_quad_constraint_val
    eps = 1e-8
    # NOTE(review): `residual` and `rescale` are computed but never used
    # below — presumably diagnostics left over; confirm before removing.
    residual = np.sqrt((approx_g - flat_g).dot(approx_g - flat_g))
    rescale = q / (v.dot(v))
    if self.precompute:
        # Caller supplies the scalar safety value S.
        S = precomputed_eval
        assert (np.ndim(S) == 0)
    else:
        S = sliced_fun(self._opt_fun["lin_constraint"],
                       self._num_slices)(inputs, extra_inputs)
    # c > 0 means the safety constraint is already violated at the current
    # parameters; otherwise remember this point as a safe fallback.
    c = S - self._max_lin_constraint_val
    if c > 0:
        logger.log("warning! safety constraint is already violated")
    else:
        self.last_safe_point = np.copy(
            self._target.get_param_values(trainable=True))
    stop_flag = False
    if flat_b.dot(flat_b) <= eps:
        # Safety gradient is numerically zero: plain trust-region step.
        lam = np.sqrt(q / delta)
        nu = 0
        w = 0
        r, s, A, B = 0, 0, 0, 0
        optim_case = 4
    else:
        if self._resample_inputs:
            Hx = self._hvp_approach.build_eval(subsample_inputs2 + extra_inputs)
        norm_b = np.sqrt(flat_b.dot(flat_b))
        unit_b = flat_b / norm_b
        # w ~= H^{-1} b, solved on the normalized direction for stability.
        w = norm_b * krylov.cg(Hx, unit_b, cg_iters=self._cg_iters,
                               verbose=self._verbose_cg)
        r = w.dot(approx_g)  # approx = b^T H^{-1} g
        s = w.dot(Hx(w))  # approx = b^T H^{-1} b
        A = q - r**2 / s  # this should always be positive by Cauchy-Schwarz
        B = delta - c**2 / s  # this one says whether or not the closest point on the plane is feasible
        if c < 0 and B < 0:
            optim_case = 3
        elif c < 0 and B > 0:
            optim_case = 2
        elif c > 0 and B > 0:
            optim_case = 1
            if self.attempt_feasible_recovery:
                logger.log("alert! conjugate constraint optimizer is attempting feasible recovery")
            else:
                logger.log("alert! problem is feasible but needs recovery, and we were instructed not to attempt recovery")
                stop_flag = True
        else:
            optim_case = 0
            if self.attempt_infeasible_recovery:
                logger.log("alert! conjugate constraint optimizer is attempting infeasible recovery")
            else:
                logger.log("alert! problem is infeasible, and we were instructed not to attempt recovery")
                stop_flag = True
    # Default dual solution (active trust region only); refined below when
    # the safety constraint is also active.
    lam = np.sqrt(q / delta)
    nu = 0
    if optim_case == 2 or optim_case == 1:
        # Piecewise-optimal dual variable lam: compare the candidates from
        # each region of the dual objective, clipped at the breakpoint lam_mid.
        lam_mid = r / c
        L_mid = - 0.5 * (q / lam_mid + lam_mid * delta)
        lam_a = np.sqrt(A / (B + eps))
        L_a = -np.sqrt(A * B) - r * c / (s + eps)
        lam_b = np.sqrt(q / delta)
        L_b = -np.sqrt(q * delta)
        if lam_mid > 0:
            if c < 0:
                if lam_a > lam_mid:
                    lam_a = lam_mid
                    L_a = L_mid
                if lam_b < lam_mid:
                    lam_b = lam_mid
                    L_b = L_mid
            else:
                if lam_a < lam_mid:
                    lam_a = lam_mid
                    L_a = L_mid
                if lam_b > lam_mid:
                    lam_b = lam_mid
                    L_b = L_mid
            if L_a >= L_b:
                lam = lam_a
            else:
                lam = lam_b
        else:
            if c < 0:
                lam = lam_b
            else:
                lam = lam_a
        nu = max(0, lam * c - r) / (s + eps)
    # NOTE(review): `nextS` is computed but unused below — confirm intent.
    nextS = S + np.sqrt(delta * s)

    def record_zeros():
        # Tabular placeholders for runs that skip the full line search.
        logger.record_tabular("BacktrackIters", 0)
        logger.record_tabular("LossRejects", 0)
        logger.record_tabular("QuadRejects", 0)
        logger.record_tabular("LinRejects", 0)

    if optim_case > 0:
        # Combined step from the dual solution.
        flat_descent_step = (1. / (lam + eps)) * (v + nu * w)
    else:
        # Infeasible: pure safety step that saturates the trust region.
        flat_descent_step = np.sqrt(delta / (s + eps)) * w
    prev_param = np.copy(self._target.get_param_values(trainable=True))
    prev_lin_constraint_val = sliced_fun(
        self._opt_fun["f_lin_constraint"], self._num_slices)(inputs,
                                                             extra_inputs)
    lin_reject_threshold = self._max_lin_constraint_val
    if precomputed_threshold is not None:
        lin_reject_threshold = precomputed_threshold
    if diff_threshold:
        # Threshold given as a delta relative to the current value.
        # NOTE(review): nesting relative to the branch above is ambiguous in
        # the collapsed source — confirm against version history.
        lin_reject_threshold += prev_lin_constraint_val

    def check_nan():
        # Reject (restore prev_param) if any post-step quantity is NaN.
        loss, quad_constraint_val, lin_constraint_val = sliced_fun(
            self._opt_fun["f_loss_constraint"], self._num_slices)(
                inputs, extra_inputs)
        if np.isnan(loss) or np.isnan(quad_constraint_val) or np.isnan(
                lin_constraint_val):
            if np.isnan(loss):
                logger.log("Violated because loss is NaN")
            if np.isnan(quad_constraint_val):
                logger.log("Violated because quad_constraint %s is NaN" %
                           self._constraint_name_1)
            if np.isnan(lin_constraint_val):
                logger.log("Violated because lin_constraint %s is NaN" %
                           self._constraint_name_2)
            self._target.set_param_values(prev_param, trainable=True)

    def line_search(check_loss=True, check_quad=True, check_lin=True):
        # Backtrack until every enabled condition (loss improvement, quad
        # constraint, linear constraint) is satisfied.
        loss_rejects = 0
        quad_rejects = 0
        lin_rejects = 0
        n_iter = 0
        for n_iter, ratio in enumerate(
                self._backtrack_ratio ** np.arange(self._max_backtracks)):
            cur_step = ratio * flat_descent_step
            cur_param = prev_param - cur_step
            self._target.set_param_values(cur_param, trainable=True)
            loss, quad_constraint_val, lin_constraint_val = sliced_fun(
                self._opt_fun["f_loss_constraint"], self._num_slices)(
                    inputs, extra_inputs)
            loss_flag = loss < loss_before
            quad_flag = quad_constraint_val <= self._max_quad_constraint_val
            lin_flag = lin_constraint_val <= lin_reject_threshold
            if check_loss and not (loss_flag):
                loss_rejects += 1
            if check_quad and not (quad_flag):
                quad_rejects += 1
            if check_lin and not (lin_flag):
                lin_rejects += 1
            if (loss_flag or not (check_loss)) and (
                    quad_flag or not (check_quad)) and (
                    lin_flag or not (check_lin)):
                break
        return loss, quad_constraint_val, lin_constraint_val, n_iter

    def wrap_up():
        # Record diagnostics comparing the linear model of the safety
        # constraint against its actual post-step value.
        if optim_case < 4:
            lin_constraint_val = sliced_fun(
                self._opt_fun["f_lin_constraint"], self._num_slices)(
                    inputs, extra_inputs)
            lin_constraint_delta = lin_constraint_val - prev_lin_constraint_val
            logger.record_tabular("LinConstraintDelta", lin_constraint_delta)
            cur_param = self._target.get_param_values()
            next_linear_S = S + flat_b.dot(cur_param - prev_param)
            next_surrogate_S = S + lin_constraint_delta
            # NOTE(review): lin_surrogate_acc / lin_pred_err / surr_pred_err
            # are computed but never recorded here — confirm intent.
            lin_surrogate_acc = 100. * (next_linear_S -
                                        next_surrogate_S) / next_surrogate_S
            lin_pred_err = (self._last_lin_pred_S - S)  # / (S + eps)
            surr_pred_err = (self._last_surr_pred_S - S)  # / (S + eps)
            self._last_lin_pred_S = next_linear_S
            self._last_surr_pred_S = next_surrogate_S
        else:
            lin_pred_err = (self._last_lin_pred_S - 0)  # / (S + eps)
            surr_pred_err = (self._last_surr_pred_S - 0)  # / (S + eps)
            self._last_lin_pred_S = 0
            self._last_surr_pred_S = 0

    if stop_flag == True:
        record_zeros()
        wrap_up()
        return
    # Recovery branches: feasible (case 1) / infeasible (case 0) recovery,
    # either via a constraint-only line search or a raw full step, or by
    # reverting to the last recorded safe point.
    if optim_case == 1 and not (self.revert_to_last_safe_point):
        if self._linesearch_infeasible_recovery:
            logger.log(
                "feasible recovery mode: constrained natural gradient step. performing linesearch on constraints."
            )
            line_search(False, True, True)
        else:
            self._target.set_param_values(prev_param - flat_descent_step,
                                          trainable=True)
            logger.log(
                "feasible recovery mode: constrained natural gradient step. no linesearch performed."
            )
        check_nan()
        record_zeros()
        wrap_up()
        return
    elif optim_case == 0 and not (self.revert_to_last_safe_point):
        if self._linesearch_infeasible_recovery:
            logger.log(
                "infeasible recovery mode: natural safety step. performing linesearch on constraints."
            )
            line_search(False, True, True)
        else:
            self._target.set_param_values(prev_param - flat_descent_step,
                                          trainable=True)
            logger.log(
                "infeasible recovery mode: natural safety gradient step. no linesearch performed."
            )
        check_nan()
        record_zeros()
        wrap_up()
        return
    elif (optim_case == 0 or optim_case == 1) and self.revert_to_last_safe_point:
        if self.last_safe_point:
            self._target.set_param_values(self.last_safe_point, trainable=True)
            logger.log("infeasible recovery mode: reverted to last safe point!")
        else:
            logger.log(
                "alert! infeasible recovery mode failed: no last safe point to revert to."
            )
        record_zeros()
        wrap_up()
        return
    # Normal path: full line search on all three conditions.
    loss, quad_constraint_val, lin_constraint_val, n_iter = line_search()
    if (np.isnan(loss) or np.isnan(quad_constraint_val) or
            np.isnan(lin_constraint_val) or loss >= loss_before or
            quad_constraint_val >= self._max_quad_constraint_val or
            lin_constraint_val > lin_reject_threshold
            ) and not self._accept_violation:
        # Reject the step: restore pre-update parameters.
        self._target.set_param_values(prev_param, trainable=True)
    wrap_up()
def f_opt_wrapper(flat_params):
    # Objective wrapper: load the flat parameter vector into the target,
    # then evaluate the sliced objective on the full inputs.
    self._target.set_param_values(flat_params, trainable=True)
    objective = sliced_fun(f_opt, self._n_slices)
    return objective(inputs)
def optimize(self, inputs, extra_inputs=None, subsample_grouped_inputs=None):
    """Run SVRG-style optimization with CG-based natural-gradient updates.

    Each epoch computes a full-batch reference gradient ``g_mean_tilde`` at
    the snapshot parameters, then performs up to ``self._max_batch``
    variance-reduced mini-batch updates via ``self.conjugate_grad``, and
    finally refreshes the snapshot (``self._target_tilde``).

    Fix: removed a leftover ``pdb.set_trace()`` debugger breakpoint inside
    the mini-batch loop, which unconditionally halted every run.

    :param inputs: tuple of arrays fed to the compiled functions.
    :param extra_inputs: additional non-subsampled inputs.
    :param subsample_grouped_inputs: accepted for interface parity; not used
        by this implementation.
    :raises NotImplementedError: if ``inputs`` is empty (mini-batch sampling
        would be impossible).
    """
    if len(inputs) == 0:
        raise NotImplementedError
    f_loss = self._opt_fun["f_loss"]
    f_grad = self._opt_fun["f_grad"]
    f_grad_tilde = self._opt_fun["f_grad_tilde"]
    inputs = tuple(inputs)
    if extra_inputs is None:
        extra_inputs = tuple()
    else:
        extra_inputs = tuple(extra_inputs)
    param = np.copy(self._target.get_param_values(trainable=True))
    logger.log(
        "Start SVRG CG subsample optimization: #parameters: %d, #inputs: %d, #subsample_inputs: %d"
        % (len(param), len(inputs[0]), self._subsample_factor * len(inputs[0])))
    # Two samplers: one for curvature (HVP) subsamples, one for mini-batches.
    subsamples = BatchDataset(inputs,
                              int(self._subsample_factor * len(inputs[0])),
                              extra_inputs=extra_inputs)
    dataset = BatchDataset(inputs, self._batch_size, extra_inputs=extra_inputs)
    for epoch in range(self._max_epochs):
        if self._verbose:
            logger.log("Epoch %d" % (epoch))
            progbar = pyprind.ProgBar(len(inputs[0]))
        # g_u = 1/n \sum_{b} \partial{loss(w_tidle, b)} {w_tidle}
        grad_sum = np.zeros_like(param)
        g_mean_tilde = sliced_fun(f_grad_tilde,
                                  self._num_slices)(inputs, extra_inputs)
        logger.record_tabular('g_mean_tilde', LA.norm(g_mean_tilde))
        print("-------------mini-batch-------------------")
        num_batch = 0
        while num_batch < self._max_batch:
            batch = dataset.random_batch()
            # todo, pick mini-batch with weighted prob.
            if self._use_SGD:
                g = f_grad(*(batch))
            else:
                # SVRG control variate: batch gradient minus snapshot batch
                # gradient plus full-batch snapshot gradient.
                g = f_grad(*(batch)) - \
                    f_grad_tilde(*(batch)) + g_mean_tilde
            grad_sum += g
            subsample_inputs = subsamples.random_batch()
            Hx = self._hvp_approach.build_eval(subsample_inputs)
            self.conjugate_grad(g, Hx, inputs, extra_inputs)
            num_batch += 1
        print("max batch achieved {:}".format(num_batch))
        grad_sum /= 1.0 * num_batch
        if self._verbose:
            progbar.update(batch[0].shape[0])
        logger.record_tabular('gdist', LA.norm(grad_sum - g_mean_tilde))
        # Refresh the snapshot: w_tilde <- current weights.
        cur_w = np.copy(self._target.get_param_values(trainable=True))
        w_tilde = self._target_tilde.get_param_values(trainable=True)
        self._target_tilde.set_param_values(cur_w, trainable=True)
        logger.record_tabular('wnorm', LA.norm(cur_w))
        logger.record_tabular('w_dist',
                              LA.norm(cur_w - w_tilde) / LA.norm(cur_w))
        if self._verbose:
            if progbar.active:
                progbar.stop()
        # Converged when the relative snapshot movement is tiny.
        if abs(LA.norm(cur_w - w_tilde) / LA.norm(cur_w)) < self._tolerance:
            break
def loss(self, inputs, extra_inputs=None):
    """Return the sliced loss over ``inputs`` (and optional extras)."""
    extras = list() if extra_inputs is None else extra_inputs
    return sliced_fun(self._opt_fun["f_loss"], self._n_slices)(inputs, extras)
def loss(self, inputs, extra_inputs=None):
    """Compute the loss averaged across all parallel workers.

    Each rank writes its weighted local loss into shared memory, waits at
    the barrier so every rank has written, then sums all contributions.
    """
    shareds, barriers = self._par_objs
    # avg_fac presumably weights this rank's share of the global average
    # (e.g. 1/n_parallel) — TODO confirm against the parallel setup code.
    shareds.loss[self.rank] = self.avg_fac * ext.sliced_fun(
        self._opt_fun["f_loss"], self._num_slices)(inputs, extra_inputs)
    # Barrier: all ranks must write before any rank reads the sum.
    barriers.loss.wait()
    return sum(shareds.loss)