def end_of_exp(self):
    """Wrap up a run: report the last MSPBE and export the results.

    Prints the value returned by ``mspbe.calc_mspbe_torch`` for the current
    parameters, fills ``self.result``, converts ``self.theta`` and
    ``self.omega`` to NumPy arrays, and finally calls ``self.delete_attrs()``.

    ``self.result`` is:
      * the raw MSPBE history when ``rho + rho_omega + rho_ac`` is positive;
      * otherwise ``log10`` of the final MSPBE when ``self.grid_search`` is
        truthy, else ``log10`` of the whole MSPBE history.
    """
    final_value = float(mspbe.calc_mspbe_torch(self, self.rho).cpu().numpy())
    print(self.name + ' last mspbe is %.5f' % (final_value))

    rho_total = self.rho + self.rho_omega + self.rho_ac
    if rho_total > 0:
        self.result = self.mspbe_history.cpu().numpy()
    elif self.grid_search:
        # Grid search only needs the final value, on a log scale.
        last_mspbe = mspbe.calc_mspbe_torch(self, self.rho).cpu().numpy()
        self.result = np.log10(last_mspbe)
    else:
        history = self.mspbe_history.cpu().numpy()
        self.result = np.log10(history)

    # Hand back plain NumPy arrays instead of torch tensors.
    self.theta = self.theta.cpu().numpy()
    self.omega = self.omega.cpu().numpy()
    self.delete_attrs()
def init_alg(self):
    """Set up parameters, bookkeeping and the initial MSPBE record.

    ``theta`` and ``omega`` are loaded from ``init_theta.npy`` and
    ``init_omega.npy`` inside ``self.saving_dir_path`` when both files
    exist; otherwise both start as zero vectors of length
    ``self.nFeatures``. In either case they are float64 tensors on
    ``self.device``.

    Also sets:
      * ``check_pt`` - ``num_data`` when ``num_checks == 0``, otherwise
        ``num_epoch / num_checks`` truncated to int and clamped to >= 1;
      * ``rho`` - overwritten with ``rho_multiplier`` when that is positive;
      * ``record_points_before_one_pass`` - ``[0]`` when
        ``record_before_one_pass`` is truthy;
      * ``mspbe_history`` - a one-element tensor holding the initial MSPBE;
      * ``one_over_num_data`` - ``1 / num_data`` as a float64 scalar tensor.
    """
    theta_path = os.path.join(self.saving_dir_path, 'init_theta.npy')
    omega_path = os.path.join(self.saving_dir_path, 'init_omega.npy')

    if os.path.exists(theta_path) and os.path.exists(omega_path):
        self.theta = torch.as_tensor(
            np.load(theta_path), dtype=torch.float64, device=self.device)
        self.omega = torch.as_tensor(
            np.load(omega_path), dtype=torch.float64, device=self.device)
    else:
        self.theta = torch.zeros(
            [self.nFeatures], dtype=torch.float64, device=self.device)
        self.omega = torch.zeros(
            [self.nFeatures], dtype=torch.float64, device=self.device)

    if self.num_checks == 0:
        self.check_pt = self.num_data
    else:
        # check_pt is used as a modulus elsewhere; without the clamp,
        # num_epoch < num_checks would give 0 and raise ZeroDivisionError.
        self.check_pt = max(1, int(self.num_epoch / self.num_checks))

    if self.rho_multiplier > 0:
        # Alternative definitions of rho, currently disabled:
        # self.rho = torch.mul(mspbe.calc_eig_max_AtCinvA(self), self.rho_multiplier)
        # self.rho = torch.tensor(0.01, dtype=torch.float32, device=self.device)
        self.rho = self.rho_multiplier
        print(self.rho)

    if self.record_before_one_pass:
        self.record_points_before_one_pass = [0]

    self.mspbe_history = torch.unsqueeze(
        mspbe.calc_mspbe_torch(self, self.rho), 0)

    # Same dtype as theta/omega so the scalar is not rounded to float32
    # before it scales float64 gradients.
    self.one_over_num_data = torch.tensor(
        1 / self.num_data, dtype=torch.float64, device=self.device)
def _run(self):
    """Run the variance-reduced solver that updates one variable per step.

    Each of the ``self.num_epoch`` epochs takes a snapshot of
    ``theta``/``omega``, computes the full gradients at the snapshot, and
    then performs up to ``self.inner_loop_epoch`` stochastic steps. At every
    step only the variable whose corrected gradient has the larger squared
    norm is updated (``theta`` on a strict ``>``, ``omega`` otherwise).

    Fixes relative to the previous version:
      * the snapshots are cloned; they used to alias ``self.theta`` /
        ``self.omega``, so the in-place updates moved the snapshot as well;
      * the MSPBE history is extended with ``torch.cat`` (it is a tensor, as
        created by ``svrg.init_alg``, and has no ``append``);
      * the inner range uses the real batch length, so a short last batch
        no longer over-indexes;
      * reaching ``inner_loop_epoch`` ends the epoch instead of leaving only
        the innermost loop;
      * the value handed to ``check_values_torch`` is the one just computed.

    Returns:
        dict with keys ``result``, ``sigma_theta``, ``sigma_omega``,
        ``name`` and ``msg``.
    """
    svrg.load_mdp_data(self)
    svrg.init_alg(self)
    print('before entering loop ' + str(datetime.datetime.now()))

    for i in range(self.num_epoch):
        # Independent copies: the updates below are in-place.
        theta_tilde = self.theta.clone()
        omega_tilde = self.omega.clone()
        theta_tilde_grad = mspbe.mspbe_grad_theta(
            self.theta, self.omega, self.A, rho=0)
        omega_tilde_grad = mspbe.mspbe_grad_omega(
            self.theta, self.omega, self.A, self.b, self.C, self.rho_omega)

        k = 0
        epoch_done = False
        for batch_A_t, batch_b_t, batch_C_t, batch_t_m in self.data_generator:
            for j in range(batch_t_m.shape[0]):
                A_t, b_t, C_t, t_m = svrg.get_stoc_data(
                    self, batch_A_t, batch_b_t, batch_C_t, batch_t_m, j)

                theta_grad = (
                    mspbe.mspbe_stoc_grad_theta_torch(self.omega, A_t)
                    + theta_tilde_grad
                    - mspbe.mspbe_stoc_grad_theta_torch(omega_tilde, A_t))
                omega_grad = (
                    mspbe.mspbe_stoc_grad_omega_torch(
                        self.theta, self.omega, A_t, b_t, C_t)
                    + omega_tilde_grad
                    - mspbe.mspbe_stoc_grad_omega_torch(
                        theta_tilde, omega_tilde, A_t, b_t, C_t))

                # Update only the variable with the larger squared norm.
                if torch.gt(torch.dot(theta_grad, theta_grad),
                            torch.dot(omega_grad, omega_grad)):
                    self.theta.sub_(torch.mul(theta_grad, self.sigma_theta))
                else:
                    self.omega.sub_(torch.mul(omega_grad, self.sigma_omega))

                k += 1
                if k == self.inner_loop_epoch:
                    mspbe_val = mspbe.calc_mspbe_torch(self, self.rho)
                    self.mspbe_history = torch.cat(
                        (self.mspbe_history, torch.unsqueeze(mspbe_val, 0)))
                    print('before checking')
                    if i % self.check_pt == 0:
                        self.check_values_torch(mspbe_val)
                    print('finish epoch ' + str(i))
                    epoch_done = True
                    break
            if epoch_done:
                break

    print('after loop ' + str(datetime.datetime.now()))
    svrg.end_of_exp(self)
    return {
        'result': self.result,
        'sigma_theta': self.sigma_theta,
        'sigma_omega': self.sigma_omega,
        'name': self.name,
        'msg': self.msg,
    }
def _run(self):
    """Run SVRG until ``check_termination_cond`` returns a falsy value.

    Every outer iteration snapshots ``theta``/``omega``, evaluates the full
    gradients at the snapshot, then sweeps ``self.data_generator`` applying
    one variance-reduced update per sample. Gradient evaluations are counted
    in ``self.num_grad_eval``; when ``record_per_dataset_pass`` is set,
    ``check_complete_data_pass`` is called after the full-gradient step and
    after every batch.

    Returns:
        dict with the final parameters, the recorded result and the
        hyper-parameters of the run.
    """
    svrg.load_mdp_data(self)
    svrg.init_alg(self)

    if self.terminate_if_less_than_epsilon == False:
        progress_bar = progressbar.ProgressBar(max_value=self.num_epoch*2)

    while self.check_termination_cond():
        # Snapshot and full gradient at the snapshot.
        theta_tilde = self.theta.clone()
        omega_tilde = self.omega.clone()
        theta_tilde_grad = mspbe.mspbe_grad_theta(
            self.theta, self.omega, self.A, rho=self.rho)
        omega_tilde_grad = mspbe.mspbe_grad_omega(
            self.theta, self.omega, self.A, self.b, self.C, self.rho_omega)
        self.num_grad_eval += self.num_data
        if self.record_per_dataset_pass:
            self.check_complete_data_pass()

        for batch_A_t, batch_b_t, batch_C_t, batch_t_m in self.data_generator:
            batch_size = batch_t_m.shape[0]
            for j in range(batch_size):
                A_t, b_t, C_t, _t_m = svrg.get_stoc_data(
                    self, batch_A_t, batch_b_t, batch_C_t, batch_t_m, j)

                theta_now = mspbe.mspbe_grad_theta(
                    self.theta, self.omega, A_t, rho=self.rho)
                theta_snap = mspbe.mspbe_grad_theta(
                    theta_tilde, omega_tilde, A_t, rho=self.rho)
                theta_grad = theta_now + theta_tilde_grad - theta_snap

                omega_now = mspbe.mspbe_grad_omega(
                    self.theta, self.omega, A_t, b_t, C_t, self.rho_omega)
                omega_snap = mspbe.mspbe_grad_omega(
                    theta_tilde, omega_tilde, A_t, b_t, C_t, self.rho_omega)
                omega_grad = omega_now + omega_tilde_grad - omega_snap

                self.theta.sub_(torch.mul(theta_grad, self.sigma_theta))
                self.omega.sub_(torch.mul(omega_grad, self.sigma_omega))

            self.num_grad_eval += batch_size
            if self.record_per_dataset_pass:
                self.check_complete_data_pass()

        # Temporary: per-epoch progress print.
        mspbe_at_epoch = float(
            mspbe.calc_mspbe_torch(self, self.rho).cpu().numpy())
        print('svrg mspbe = ' + "{0:.3e}".format(mspbe_at_epoch))
        self.end_of_epoch()

        if self.terminate_if_less_than_epsilon == False:
            if self.record_per_dataset_pass:
                progress_bar.update(self.num_pass)
            else:
                progress_bar.update(self.cur_epoch)

    svrg.end_of_exp(self)
    return {
        'theta': self.theta,
        'omega': self.omega,
        'result': self.result,
        'sigma_theta': self.sigma_theta,
        'sigma_omega': self.sigma_omega,
        'inner_loop_multiplier': self.inner_loop_multiplier,
        'name': self.name,
        'record_per_dataset_pass': self.record_per_dataset_pass,
        'record_per_epoch': self.record_per_epoch,
        'comp_cost': self.num_pass,
        'rho': self.rho,
        'rho_ac': self.rho_ac,
    }
def check_complete_data_pass(self):
    """Record one MSPBE value once a dataset-worth of gradients was used.

    Does nothing while ``self.num_grad_eval`` is below ``self.num_data``.
    Otherwise it evaluates the MSPBE, runs ``check_values_torch`` on it when
    ``num_pass`` is a multiple of ``check_pt``, appends it to
    ``self.mspbe_history``, increments ``num_pass``, optionally logs the
    evaluation count in ``record_points_before_one_pass``, and subtracts
    ``num_data`` from ``num_grad_eval``.
    """
    if self.num_grad_eval < self.num_data:
        return

    current = mspbe.calc_mspbe_torch(self, self.rho)
    if self.num_pass % self.check_pt == 0:
        self.check_values_torch(current)

    new_entry = torch.unsqueeze(current, 0)
    self.mspbe_history = torch.cat((self.mspbe_history, new_entry))
    self.num_pass += 1

    if self.record_before_one_pass:
        # Logged before the counter is reduced below.
        self.record_points_before_one_pass.append(self.num_grad_eval)

    # Carry the surplus evaluations over to the next pass.
    self.num_grad_eval = self.num_grad_eval - self.num_data
def _run(self):
    """Run SCSG until ``check_termination_cond`` returns a falsy value.

    Each outer iteration:
      1. snapshots ``theta``/``omega`` and estimates the snapshot gradients
         from a random subset of ``int(num_data * scsg_batch_size_ratio)``
         sample indices;
      2. picks the number of inner steps, either drawn from a geometric
         distribution with ``p = 1 / (subset size + 1)`` (when
         ``use_geometric_dist``) or equal to the subset size;
      3. draws that many samples with replacement and applies one
         variance-reduced update per sample.

    Returns:
        dict with the final parameters, the recorded result and the
        hyper-parameters of the run; ``record_points_before_one_pass`` is
        included (as first key) only when ``record_before_one_pass`` is set.
    """
    svrg.load_mdp_data(self)
    svrg.init_alg(self)
    full_dataset = mdp_dataset(self)
    scsg_batch_size = int(self.num_data * self.scsg_batch_size_ratio)
    geom_dist_p = 1/(scsg_batch_size+1)

    if self.terminate_if_less_than_epsilon==False:
        progress_bar = progressbar.ProgressBar(max_value=self.num_epoch+50)

    while self.check_termination_cond():
        theta_tilde = self.theta.clone()
        omega_tilde = self.omega.clone()

        # Snapshot gradients from a random subset of the data.
        subset_idx = torch.randperm(self.num_data)[:scsg_batch_size]
        theta_tilde_grad, omega_tilde_grad = self.get_grad_theta_omega_from_batch_abc(
            self.theta, self.omega, full_dataset, subset_idx,
            scsg_batch_size, self.rho)
        torch.cuda.empty_cache()
        self.num_grad_eval += scsg_batch_size
        if self.record_per_dataset_pass:
            self.check_complete_data_pass()

        if self.use_geometric_dist:
            inner_loop_epoch = np.random.geometric(geom_dist_p)
        else:
            inner_loop_epoch = int(self.num_data * self.scsg_batch_size_ratio)

        sampler = data.RandomSampler(
            torch.arange(self.num_data),
            replacement=True,
            num_samples=inner_loop_epoch)
        data_generator = data.DataLoader(
            full_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            drop_last=False)

        for batch_A_t, batch_b_t, batch_C_t, batch_t_m in data_generator:
            batch_size = batch_t_m.shape[0]
            for j in range(batch_size):
                A_t, b_t, C_t, _t_m = svrg.get_stoc_data(
                    self, batch_A_t, batch_b_t, batch_C_t, batch_t_m, j)

                theta_now = mspbe.mspbe_grad_theta(
                    self.theta, self.omega, A_t, rho=self.rho)
                theta_snap = mspbe.mspbe_grad_theta(
                    theta_tilde, omega_tilde, A_t, rho=self.rho)
                theta_grad = theta_now + theta_tilde_grad - theta_snap

                omega_now = mspbe.mspbe_grad_omega(
                    self.theta, self.omega, A_t, b_t, C_t, self.rho_omega)
                omega_snap = mspbe.mspbe_grad_omega(
                    theta_tilde, omega_tilde, A_t, b_t, C_t, self.rho_omega)
                omega_grad = omega_now + omega_tilde_grad - omega_snap

                self.theta.sub_(torch.mul(theta_grad, self.sigma_theta))
                self.omega.sub_(torch.mul(omega_grad, self.sigma_omega))

        # The inner loop is accounted for once, after all of its batches.
        self.num_grad_eval += inner_loop_epoch
        if self.record_per_dataset_pass:
            self.check_complete_data_pass()
        if self.record_before_one_pass:
            self.record_value_before_one_pass()

        # Temporary: per-epoch progress print.
        mspbe_at_epoch = float(
            mspbe.calc_mspbe_torch(self, self.rho).cpu().numpy())
        print('scsg ratio = '+ str(self.scsg_batch_size_ratio) + ' sigma_theta =' + str(self.sigma_theta) + ' sigma_omega = ' + str(self.sigma_omega) + ' scsg mspbe = %.5f' % (mspbe_at_epoch))
        self.end_of_epoch()

        if self.terminate_if_less_than_epsilon==False:
            if self.record_per_dataset_pass:
                progress_bar.update(self.num_pass)
            else:
                progress_bar.update(self.cur_epoch)

    svrg.end_of_exp(self)

    # Temporary: the extra key goes first so the key order of both
    # variants of the summary stays what it was.
    summary = {}
    if self.record_before_one_pass:
        summary['record_points_before_one_pass'] = self.record_points_before_one_pass
    summary.update({
        'use_geom_dist': self.use_geometric_dist,
        'theta': self.theta,
        'omega': self.omega,
        'result': self.result,
        'sigma_theta': self.sigma_theta,
        'sigma_omega': self.sigma_omega,
        'name': self.name,
        'scsg_batch_size_ratio': self.scsg_batch_size_ratio,
        'record_per_dataset_pass': self.record_per_dataset_pass,
        'record_per_epoch': self.record_per_epoch,
        'comp_cost': self.num_pass,
        'rho': self.rho,
        'rho_ac': self.rho_ac,
    })
    return summary
def _run(self):
    """Run batching SVRG until ``check_termination_cond`` is falsy.

    The snapshot gradient of every outer iteration uses
    ``outer_loop_batch_size`` samples: the exact full gradient once that
    size reaches ``num_data``, otherwise an estimate from a random subset.
    The size starts at ``int(num_data * batch_svrg_init_ratio)`` and is
    multiplied by ``batch_svrg_increment_ratio`` (truncated to int) after
    each outer iteration. The inner loop sweeps ``self.data_generator`` with
    one variance-reduced update per sample.

    Returns:
        dict with the final parameters, the recorded result and the
        hyper-parameters of the run; ``record_points_before_one_pass`` is
        included (as first key) only when ``record_before_one_pass`` is set.
    """
    svrg.load_mdp_data(self)
    svrg.init_alg(self)
    outer_loop_batch_size = int(self.num_data * self.batch_svrg_init_ratio)
    full_dataset = mdp_dataset(self)

    if self.terminate_if_less_than_epsilon == False:
        progress_bar = progressbar.ProgressBar(max_value=self.num_epoch*2)

    while self.check_termination_cond():
        theta_tilde = self.theta.clone()
        omega_tilde = self.omega.clone()

        if outer_loop_batch_size>=self.num_data:
            # Snapshot batch covers everything: use the exact gradient.
            theta_tilde_grad = mspbe.mspbe_grad_theta(
                self.theta, self.omega, self.A, rho=self.rho)
            omega_tilde_grad = mspbe.mspbe_grad_omega(
                self.theta, self.omega, self.A, self.b, self.C, self.rho_omega)
            self.num_grad_eval += self.num_data
        else:
            subset_idx = torch.randperm(self.num_data)[:outer_loop_batch_size]
            theta_tilde_grad, omega_tilde_grad = self.get_grad_theta_omega_from_batch_abc(
                self.theta, self.omega, full_dataset, subset_idx,
                outer_loop_batch_size, rho=self.rho)
            torch.cuda.empty_cache()
            self.num_grad_eval += outer_loop_batch_size
        if self.record_per_dataset_pass:
            self.check_complete_data_pass()

        for batch_A_t, batch_b_t, batch_C_t, batch_t_m in self.data_generator:
            batch_size = batch_t_m.shape[0]
            for j in range(batch_size):
                A_t, b_t, C_t, _t_m = svrg.get_stoc_data(
                    self, batch_A_t, batch_b_t, batch_C_t, batch_t_m, j)

                theta_now = mspbe.mspbe_grad_theta(
                    self.theta, self.omega, A_t, rho=self.rho)
                theta_snap = mspbe.mspbe_grad_theta(
                    theta_tilde, omega_tilde, A_t, rho=self.rho)
                theta_grad = theta_now + theta_tilde_grad - theta_snap

                omega_now = mspbe.mspbe_grad_omega(
                    self.theta, self.omega, A_t, b_t, C_t, self.rho_omega)
                omega_snap = mspbe.mspbe_grad_omega(
                    theta_tilde, omega_tilde, A_t, b_t, C_t, self.rho_omega)
                omega_grad = omega_now + omega_tilde_grad - omega_snap

                self.theta.sub_(torch.mul(theta_grad, self.sigma_theta))
                self.omega.sub_(torch.mul(omega_grad, self.sigma_omega))

            # Per-batch accounting and recording.
            # NOTE(review): nesting of the record_before_one_pass call is
            # assumed to match the accounting lines - confirm.
            self.num_grad_eval += batch_size
            if self.record_per_dataset_pass:
                self.check_complete_data_pass()
            if self.record_before_one_pass:
                self.record_value_before_one_pass()

        # Temporary: per-epoch progress print.
        mspbe_at_epoch = float(
            mspbe.calc_mspbe_torch(self, self.rho).cpu().numpy())
        print('batch svrg mspbe = ' + "{0:.3e}".format(mspbe_at_epoch))
        self.end_of_epoch()
        outer_loop_batch_size = int(
            outer_loop_batch_size * self.batch_svrg_increment_ratio)

        if self.terminate_if_less_than_epsilon == False:
            if self.record_per_dataset_pass:
                progress_bar.update(self.num_pass)
            else:
                progress_bar.update(self.cur_epoch)

    svrg.end_of_exp(self)

    # The extra key goes first so the key order of both variants of the
    # summary stays what it was.
    summary = {}
    if self.record_before_one_pass:
        summary['record_points_before_one_pass'] = self.record_points_before_one_pass
    summary.update({
        'theta': self.theta,
        'omega': self.omega,
        'result': self.result,
        'sigma_theta': self.sigma_theta,
        'sigma_omega': self.sigma_omega,
        'name': self.name,
        'record_per_dataset_pass': self.record_per_dataset_pass,
        'record_per_epoch': self.record_per_epoch,
        'comp_cost': self.num_pass,
        'batch_svrg_init_ratio': self.batch_svrg_init_ratio,
        'batch_svrg_increment_ratio': self.batch_svrg_increment_ratio,
        'rho': self.rho,
    })
    return summary
def check_termination_cond(self):
    """Tell the solver loop whether it should keep running.

    Without ``terminate_if_less_than_epsilon`` the run is bounded by
    ``num_epoch``: the counter compared is ``num_pass`` when
    ``record_per_dataset_pass`` is set, else ``cur_epoch`` when
    ``record_per_epoch`` is set.

    With ``terminate_if_less_than_epsilon`` the MSPBE is evaluated only when
    the active counter is a multiple of 100; at that point ``theta`` and
    ``omega`` are saved to ``theta.npy`` / ``omega.npy`` in
    ``saving_dir_path`` and the run stops once the MSPBE is below
    ``policy_eval_epsilon``.

    Returns:
        bool: ``True`` to continue, ``False`` to stop.

    Raises:
        ValueError: neither recording mode is enabled and epsilon-based
            termination is off.
    """
    if not self.terminate_if_less_than_epsilon:
        if self.record_per_dataset_pass:
            return self.num_pass < self.num_epoch
        if self.record_per_epoch:
            return self.cur_epoch < self.num_epoch
        raise ValueError('invalid option')

    at_checkpoint = (
        (self.record_per_epoch and self.cur_epoch % 100 == 0)
        or (self.record_per_dataset_pass and self.num_pass % 100 == 0)
    )
    if not at_checkpoint:
        return True

    mspbe_val = float(mspbe.calc_mspbe_torch(self, self.rho).cpu().numpy())
    # Persist the current iterate at every checkpoint.
    np.save(os.path.join(self.saving_dir_path, 'theta.npy'),
            self.theta.cpu().numpy())
    np.save(os.path.join(self.saving_dir_path, 'omega.npy'),
            self.omega.cpu().numpy())

    if mspbe_val < self.policy_eval_epsilon:
        print(self.name + ' terminate in ' + str(self.cur_epoch) + ' epochs.')
        return False

    counter = self.cur_epoch if self.record_per_epoch else self.num_pass
    print('epoch ' + str(counter) + '. mspbe = %.5f' % (mspbe_val))
    return True
def _run(self):
    """Run SAGA until ``check_termination_cond`` returns a falsy value.

    ``create_grad_pool`` supplies one stored gradient per sample (the
    columns of ``g_t_theta`` / ``g_t_omega``) together with the running
    terms ``B_theta`` / ``B_omega``. For every sample ``t_m`` the step is::

        theta -= sigma_theta * (B_theta + h_theta - g_theta[:, t_m])
        omega -= sigma_omega * (B_omega + h_omega - g_omega[:, t_m])

    after which ``B_*`` is moved by ``(h - g) / num_data`` and the stored
    column is replaced by the fresh gradient ``h``.

    Fix relative to the previous version: the step size was applied twice
    (the direction was pre-multiplied by sigma and then multiplied by sigma
    again inside ``sub_``), i.e. the effective step was ``sigma ** 2``.

    Returns:
        dict with the final parameters, the recorded result and the
        hyper-parameters of the run.
    """
    svrg.load_mdp_data(self)
    svrg.init_alg(self)
    g_t_theta, g_t_omega, B_theta, B_omega = self.create_grad_pool()
    print('finish generating grad pool' + str(datetime.datetime.now()))
    self.num_grad_eval += self.num_data
    if self.record_per_dataset_pass:
        self.check_complete_data_pass()

    if self.terminate_if_less_than_epsilon==False:
        progress_bar = progressbar.ProgressBar(max_value=self.num_epoch+50)

    while self.check_termination_cond():
        for batch_A_t, batch_b_t, batch_C_t, batch_t_m in self.data_generator:
            batch_size = batch_t_m.shape[0]
            for j in range(batch_size):
                A_t, b_t, C_t, t_m = svrg.get_stoc_data(
                    self, batch_A_t, batch_b_t, batch_C_t, batch_t_m, j)

                # Fresh gradients at the current iterate.
                h_tm_theta = mspbe.mspbe_grad_theta(
                    self.theta, self.omega, A_t, self.rho)
                h_tm_omega = mspbe.mspbe_grad_omega(
                    self.theta, self.omega, A_t, b_t, C_t, self.rho_omega)
                # Gradients currently stored for this sample.
                g_tm_theta = g_t_theta[:, t_m]
                g_tm_omega = g_t_omega[:, t_m]

                # Unscaled directions; the step size is applied exactly
                # once, in the in-place updates below.
                theta_grad = B_theta + h_tm_theta - g_tm_theta
                omega_grad = B_omega + h_tm_omega - g_tm_omega
                self.theta.sub_(torch.mul(theta_grad, self.sigma_theta))
                self.omega.sub_(torch.mul(omega_grad, self.sigma_omega))

                # Refresh the running terms before the stored columns are
                # overwritten (they still need the old g values).
                B_theta = B_theta + self.one_over_num_data * (h_tm_theta - g_tm_theta)
                B_omega = B_omega + self.one_over_num_data * (h_tm_omega - g_tm_omega)
                g_t_theta[:, t_m] = h_tm_theta
                g_t_omega[:, t_m] = h_tm_omega

            self.num_grad_eval += batch_size
            self.check_complete_data_pass()

        if self.terminate_if_less_than_epsilon == False:
            if self.record_per_dataset_pass:
                progress_bar.update(self.num_pass)
            else:
                progress_bar.update(self.cur_epoch)

        # Temporary: progress print.
        mspbe_at_epoch = float(
            mspbe.calc_mspbe_torch(self, self.rho).cpu().numpy())
        print('saga sigma_theta =' + str(self.sigma_theta) + ' sigma_omega = ' + str(self.sigma_omega) + ' saga mspbe = %.5f' % (mspbe_at_epoch))

    svrg.end_of_exp(self)
    return {
        'theta': self.theta,
        'omega': self.omega,
        'result': self.result,
        'sigma_theta': self.sigma_theta,
        'sigma_omega': self.sigma_omega,
        'name': self.name,
        'record_per_dataset_pass': self.record_per_dataset_pass,
        'record_per_epoch': self.record_per_epoch,
        'comp_cost': self.num_pass,
        'rho': self.rho,
        'rho_ac': self.rho_ac,
    }
def handle_epoch_result(self, i, batch_j):
    """Store the current MSPBE and run the periodic value check.

    The MSPBE is evaluated once per call and reused; it used to be
    evaluated a second time for the check.

    Args:
        i: iteration index; the check runs when ``i`` is a multiple of
            ``self.check_pt``.
        batch_j: position in ``self.batch_result`` that receives the value.
    """
    mspbe_val = mspbe.calc_mspbe_torch(self, self.rho)
    self.batch_result[batch_j] = mspbe_val
    if i % self.check_pt == 0:
        self.check_values_torch(float(mspbe_val))
def record_value_before_one_pass(self):
    """Append the current MSPBE and the evaluation count to the records.

    Extends ``self.mspbe_history`` by one entry and pushes the current
    ``self.num_grad_eval`` onto ``self.record_points_before_one_pass``.
    """
    current = mspbe.calc_mspbe_torch(self, self.rho)
    new_entry = torch.unsqueeze(current, 0)
    self.mspbe_history = torch.cat((self.mspbe_history, new_entry))
    self.record_points_before_one_pass.append(self.num_grad_eval)
def end_of_epoch(self):
    """Close an epoch: optionally record the MSPBE, then bump the counter.

    With ``record_per_epoch`` set, the current MSPBE is appended to
    ``self.mspbe_history`` and passed to ``check_values_torch`` whenever
    ``cur_epoch`` is a multiple of ``check_pt``. ``cur_epoch`` is
    incremented in every case.
    """
    if self.record_per_epoch:
        current = mspbe.calc_mspbe_torch(self, self.rho)
        if self.cur_epoch % self.check_pt == 0:
            self.check_values_torch(current)
        new_entry = torch.unsqueeze(current, 0)
        self.mspbe_history = torch.cat((self.mspbe_history, new_entry))

    self.cur_epoch += 1