def learn(self, num_episodes=300, batch_size=None, print_progess_frequency=10, min_replay_samples=50, repeat_train=16, imshow=True):
    # The main learning method for the agent.
    # Assumes module-level imports: count (itertools), copy, sqrt (math), plt (matplotlib.pyplot),
    # plus the helpers less, random_choice and loss_metric_curve used below.
    if batch_size is not None:
        self.batch_size = batch_size
    # Reset steps_done at the start of learning so the random-action ratio decays gradually.
    self.steps_done = 0
    train_cnt = 0
    success_cnt = 0
    keep_success_cnt = 0
    start_train_episode = 0
    start_train = False

    # Collect initial data.
    while not start_train:
        # Reset the environment.
        self.env.reset()
        # Clear the accumulated reward.
        total_rewards = 0
        state = self.get_observation()
        for t in count():
            # Choose an action based on the current state.
            action = self.select_action(state, model_only=False)
            # Get the reward for the action and whether the episode is done
            # (the environment has already advanced to the next time step).
            reward, done = self.get_rewards(action)
            # Accumulate rewards.
            total_rewards += reward
            # Force termination once the task is complete (300 steps as the threshold).
            complete = (not done and t + 1 >= 300)
            if imshow:
                # Refresh the rendered screen.
                self.env.render()
            # Get the observation for the next time step.
            next_state = None if done and not complete else self.get_observation()
            # Store the 4-tuple in replay memory.
            # (To reduce the proportion of "good" transitions stored, switch to conditional storage here.)
            self.memory.push(state, action, next_state, reward)
            if len(self.memory) % 100 == 0:
                print("Replay Samples:{0}".format(len(self.memory)))
            if len(self.memory) == min_replay_samples:
                print('Start Train!!', flush=True)
                # Training starts only once the memory holds enough samples.
                start_train = (len(self.memory) >= min_replay_samples)
                break
            # Switch to the next state.
            state = copy.deepcopy(next_state)
            if done or complete:
                break

    # Start training mode.
    self.training_context['steps'] = 0
    self.steps_done = 0
    for i_episode in range(num_episodes):
        for i in range(repeat_train):
            # Experience replay provides the training batches.
            self.output_fn = self.experience_replay
            # Train the model.
            self.train_model(None, None,
                             current_epoch=i_episode,
                             current_batch=i,
                             total_epoch=num_episodes,
                             total_batch=repeat_train,
                             is_collect_data=True if t >= 0 else False,
                             is_print_batch_progress=False,
                             is_print_epoch_progress=False,
                             log_gradients=False,
                             log_weights=False,
                             accumulate_grads=False)
        # Periodically update the target_net weights.
        if i_episode % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict(), strict=True)
            self.save_model(save_path=self.training_context['save_path'])

        # Reset the environment.
        self.env.reset()
        # Clear the accumulated reward.
        total_rewards = 0
        state = self.get_observation()
        tmp_memory = []
        for t in count():
            # Choose an action based on the current state (policy network only, no random exploration).
            action = self.select_action(state, model_only=True)
            # Get the reward for the action and whether the episode is done
            # (the environment has already advanced to the next time step).
            reward, done = self.get_rewards(action)
            # Accumulate rewards.
            total_rewards += reward
            # Force termination once the task is complete (300 steps as the threshold).
            complete = (not done and t + 1 >= 300)
            if imshow:
                # Refresh the rendered screen.
                self.env.render()
            # Get the observation for the next time step.
            next_state = None if done else self.get_observation()
            # Keep the 4-tuple in a temporary buffer for this episode.
            tmp_memory.append((state, action, next_state, reward))
            # Switch to the next state.
            state = next_state
            if done or complete:
                if t >= 200:
                    success_cnt += 1
                else:
                    success_cnt = 0
                # If the agent reaches 300 points in consecutive episodes, stop updating.
                if t + 1 >= 300:
                    keep_success_cnt += 1
                else:
                    keep_success_cnt = 0
                if keep_success_cnt >= 2:
                    self.training_context['stop_update'] = 1
                else:
                    self.training_context['stop_update'] = 0
                # Record the accumulated reward.
                self.epoch_metric_history.collect('total_rewards', i_episode, float(total_rewards))
                self.epoch_metric_history.collect('original_rewards', i_episode, float(t))
                # Record the completion ratio (200 steps as the threshold).
                self.epoch_metric_history.collect('task_complete', i_episode, 1.0 if t + 1 >= 200 else 0.0)
                # Print the learning progress periodically.
                if i_episode > 0 and i_episode % print_progess_frequency == 0:
                    self.print_epoch_progress(print_progess_frequency)
                # Periodically plot the loss and metric trends over time.
                if i_episode > 0 and i_episode % (5 * print_progess_frequency) == 0:
                    print('negative_reward_ratio:',
                          less(self.training_context['train_data']['reward_batch'], 0).mean().item())
                    print('predict_rewards:',
                          self.training_context['train_data']['predict_rewards'].copy()[:5, 0])
                    print('target_rewards:',
                          self.training_context['train_data']['target_rewards'].copy()[:5, 0])
                    print('reward_batch:',
                          self.training_context['train_data']['reward_batch'].copy()[:5])
                    loss_metric_curve(self.epoch_loss_history, self.epoch_metric_history,
                                      legend=['dqn'], calculate_base='epoch', imshow=imshow)
                if success_cnt == 50:
                    self.save_model(save_path=self.training_context['save_path'])
                    print('50 episodes success, training finish! ')
                    return True
                break

        # print([item[3] for item in tmp_memory])
        sample_idx = []
        indices = list(range(len(tmp_memory)))
        if len(tmp_memory) > 10:
            # Keep only the last 3 transitions before failure plus roughly sqrt(len(tmp_memory)) random samples.
            sample_idx.extend(indices[-1 * min(3, len(tmp_memory)):])
            sample_idx.extend(random_choice(indices[:-3], int(sqrt(len(tmp_memory)))))
            sample_idx = list(set(sample_idx))
        for k in range(len(tmp_memory)):
            state, action, next_state, reward = tmp_memory[k]
            if k in sample_idx or (k + 3 < len(tmp_memory) and tmp_memory[k + 1][3] < 1) or reward < 1:
                self.memory.push(state, action, next_state, reward)

    print('Complete')
    self.env.render()
    self.env.close()
    plt.ioff()
    plt.show()
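# --- Illustrative sketch (an assumption, not the original library code) ---
# Both `learn` implementations only require `self.memory` to support
# `push(state, action, next_state, reward)` and `len(...)`.  A minimal
# ring-buffer replay memory that satisfies this interface could look like the
# class below; the name `SimpleReplayMemory` and the default capacity are
# illustrative choices.
import random
from collections import deque, namedtuple

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


class SimpleReplayMemory:
    def __init__(self, capacity=10000):
        # Oldest transitions are discarded automatically once capacity is reached.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, next_state, reward):
        self.buffer.append(Transition(state, action, next_state, reward))

    def sample(self, batch_size):
        # Uniform random mini-batch, as used by standard experience replay.
        return random.sample(list(self.buffer), batch_size)

    def __len__(self):
        return len(self.buffer)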
def learn(self, num_episodes=300, batch_size=None, print_progess_frequency=10, imshow=True):
    """The main method for the agent to learn.

    Returns:
        object:
    """
    if batch_size is not None:
        self.batch_size = batch_size
    self.steps_done = 0
    for i_episode in range(num_episodes):
        # Reset the environment.
        self.env.reset()
        # Clear the accumulated reward.
        total_rewards = 0
        state = self.get_observation()
        # Training starts only once the memory holds more samples than one batch.
        start_train = (len(self.memory) > self.batch_size)
        for t in count():
            # Choose an action based on the current state.
            action = self.select_action(state)
            # Get the reward for the action and whether the episode is done
            # (the environment has already advanced to the next time step).
            reward, done = self.get_rewards(action)
            # Accumulate rewards.
            total_rewards += reward
            # Force termination once the task is complete (300 steps as the threshold).
            complete = (not done and t + 1 >= 300)
            if imshow:
                # Refresh the rendered screen.
                self.env.render()
            # Get the observation for the next time step.
            next_state = self.get_observation()
            # Store the 4-tuple in memory; storing a reduced proportion of "good" transitions is recommended.
            if reward < 1 or (reward == 1 and i_episode < 20) or (
                    reward == 1 and i_episode >= 20 and t < 100 and random.random() < 0.1) or (
                    reward == 1 and i_episode >= 20 and t >= 100 and random.random() < 0.2):
                self.memory.push(state, action, next_state, reward)
            # Switch to the next state.
            state = deepcopy(next_state)
            if start_train:
                # Get a training batch via experience replay.
                trainData = self.experience_replay(self.batch_size)
                # Switch the model to training mode.
                self.policy_net.train()
                self.train_model(trainData, None,
                                 current_epoch=i_episode,
                                 current_batch=t,
                                 total_epoch=num_episodes,
                                 total_batch=t + 1 if done or complete else t + 2,
                                 is_collect_data=True if done or complete else False,
                                 is_print_batch_progress=False,
                                 is_print_epoch_progress=False,
                                 log_gradients=False,
                                 log_weights=False,
                                 accumulate_grads=False)
            if done or complete:
                if start_train:
                    # self.epoch_metric_history.collect('episode_durations', i_episode, float(t))
                    # Record the accumulated reward.
                    self.epoch_metric_history.collect('total_rewards', i_episode, float(total_rewards))
                    # Record the completion ratio (200 steps as the threshold).
                    self.epoch_metric_history.collect('task_complete', i_episode, 1.0 if t + 1 >= 200 else 0.0)
                    # Print the learning progress periodically.
                    if i_episode % print_progess_frequency == 0:
                        self.print_epoch_progress(print_progess_frequency)
                    # Periodically plot the loss and metric trends over time.
                    if i_episode > 0 and (i_episode + 1) % (5 * print_progess_frequency) == 0:
                        print('epsilon:', self.epsilon)
                        print('predict_rewards:', self.training_context['train_data']['predict_rewards'][:5])
                        print('target_rewards:', self.training_context['train_data']['target_rewards'][:5])
                        print('reward_batch:', self.training_context['train_data']['reward_batch'][:5])
                        loss_metric_curve(self.epoch_loss_history, self.epoch_metric_history,
                                          legend=['dqn'], calculate_base='epoch', imshow=imshow)
                break
        # Periodically update the target_net weights.
        if start_train and i_episode % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict(), strict=True)
            self.save_model(save_path=self.training_context['save_path'])
    print('Complete')
    self.env.render()
    self.env.close()
    plt.ioff()
    plt.show()
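# --- Illustrative sketch (an assumption, not the original implementation) ---
# Both `learn` variants delegate exploration to `self.select_action`, and the
# second variant also prints `self.epsilon`.  The standalone helper below
# shows one common epsilon-greedy scheme for a PyTorch policy network; the
# decay schedule and parameter names (eps_start, eps_end, eps_decay) are
# illustrative assumptions.
import math
import random
import torch


def epsilon_greedy_action(policy_net, state, steps_done, num_actions,
                          eps_start=0.9, eps_end=0.05, eps_decay=200):
    """Return a random action with probability epsilon, otherwise the greedy action."""
    epsilon = eps_end + (eps_start - eps_end) * math.exp(-steps_done / eps_decay)
    if random.random() < epsilon:
        return random.randrange(num_actions)
    with torch.no_grad():
        q_values = policy_net(state)              # expected shape: [1, num_actions]
        return int(q_values.argmax(dim=1).item())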
# Imports required by plot_spectrum.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from astropy.io import fits
from scipy.ndimage import gaussian_filter


def plot_spectrum(result, correct=True, interactive=False):
    plt.close('all')
    plt.ioff()
    if interactive:
        plt.ion()

    # Read the spectrum, thumbnail and 2D cut-out from the FITS file.
    hdu = fits.open(result['ORIGINALFILE'])
    galaxy = gaussian_filter(hdu[1].data, 1)
    thumbnail = hdu['THUMBNAIL'].data
    twoD = hdu['2D'].data
    header = hdu[0].header
    header1 = hdu[1].header
    hdu.close()

    # Build the wavelength axis from the FITS WCS keywords.
    lamRange = header1['CRVAL1'] + np.array([0., header1['CD1_1'] * (header1['NAXIS1'] - 1)])
    if correct:
        # Shift to the restframe using the measured relative velocity (km/s).
        zp = 1. + (result['VREL'] / 299792.458)
    else:
        zp = 1.
    wavelength = np.linspace(lamRange[0], lamRange[1], header1['NAXIS1']) / zp

    ymin, ymax = np.min(galaxy), np.max(galaxy)
    ylim = [ymin, ymax] + np.array([-0.02, 0.1]) * (ymax - ymin)
    ylim[0] = 0.
    xmin, xmax = np.min(wavelength), np.max(wavelength)

    ### Define multipanel size and properties
    fig = plt.figure(figsize=(8, 6))
    gs = gridspec.GridSpec(200, 130, bottom=0.10, left=0.10, right=0.95)

    ### Plot the object in the sky
    ax_obj = fig.add_subplot(gs[0:70, 105:130])
    ax_obj.imshow(thumbnail, cmap='gray', interpolation='nearest')
    ax_obj.set_xticks([])
    ax_obj.set_yticks([])

    ### Plot the 2D spectrum
    ax_2d = fig.add_subplot(gs[0:11, 0:100])
    ix_start = header['START_{}'.format(int(result['DETECT']))]
    ix_end = header['END_{}'.format(int(result['DETECT']))]
    # 'spectral' was renamed 'nipy_spectral' in matplotlib >= 2.2.
    ax_2d.imshow(twoD, cmap='nipy_spectral', aspect='auto', origin='lower',
                 extent=[xmin, xmax, 0, 1], vmin=-0.2, vmax=0.2)
    ax_2d.set_xticks([])
    ax_2d.set_yticks([])

    ### Add spectra subpanels
    ax_spectrum = fig.add_subplot(gs[11:85, 0:100])
    ax_blue = fig.add_subplot(gs[110:200, 0:50])
    ax_red = fig.add_subplot(gs[110:200, 51:100])

    ### Plot some atomic lines
    line_wave = [4861., 5175., 5892., 6562.8, 8498., 8542., 8662.]
    # ['Hbeta', 'Mgb', 'NaD', 'Halpha', 'CaT', 'CaT', 'CaT']
    for i in range(len(line_wave)):
        x = [line_wave[i], line_wave[i]]
        y = [ylim[0], ylim[1]]
        ax_spectrum.plot(x, y, c='gray', linewidth=1.0)
        ax_blue.plot(x, y, c='gray', linewidth=1.0)
        ax_red.plot(x, y, c='gray', linewidth=1.0)

    ### Plot the spectrum
    ax_spectrum.plot(wavelength, galaxy, 'k', linewidth=1.3)
    ax_spectrum.set_ylim(ylim)
    ax_spectrum.set_xlim([xmin, xmax])
    ax_spectrum.set_ylabel(r'Arbitrary Flux')
    ax_spectrum.set_xlabel(r'Restframe Wavelength [ $\AA$ ]')

    ### Plot blue part of the spectrum
    x1, x2 = 300, 750
    ax_blue.plot(wavelength[x1:x2], galaxy[x1:x2], 'k', linewidth=1.3)
    ax_blue.set_xlim(wavelength[x1], wavelength[x2])
    ax_blue.set_ylim(galaxy[x1:x2].min(), galaxy[x1:x2].max())
    ax_blue.set_yticks([])

    ### Plot red part of the spectrum
    x1, x2 = 1400, 1500
    ax_red.plot(wavelength[x1:x2], galaxy[x1:x2], 'k', linewidth=1.3)
    ax_red.set_xlim(wavelength[x1], wavelength[x2])
    ax_red.set_ylim(galaxy[x1:x2].min(), galaxy[x1:x2].max())
    ax_red.set_yticks([])

    ### Plot text
    # if interactive:
    textplot = fig.add_subplot(gs[80:200, 105:130])
    kwarg = {'va': 'center', 'ha': 'left', 'size': 'medium'}
    textplot.text(0.1, 1.0, r'ID = {} \, {}'.format(result.ID, int(result.DETECT)), **kwarg)
    textplot.text(0.1, 0.9, r'$v =$ {}'.format(int(result.VREL)), **kwarg)
    textplot.text(0.1, 0.8, r'$\delta \, v = $ {}'.format(int(result.VERR)), **kwarg)
    textplot.text(0.1, 0.7, r'SN1 = {0:.2f}'.format(result.SN1), **kwarg)
    textplot.text(0.1, 0.6, r'SN2 = {0:.2f}'.format(result.SN2), **kwarg)
    textplot.text(0.1, 0.5, r'TDR = {0:.2f}'.format(result.TDR), **kwarg)
    textplot.text(0.1, 0.4, r'SG = {}'.format(result.SG), **kwarg)
    textplot.axis('off')
    return fig
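# --- Hypothetical usage sketch (not part of the original source) ---
# `result` is indexed both by key (result['VREL']) and by attribute
# (result.VREL), which matches e.g. a pandas Series holding one row of a
# results table; the file name 'results.csv' below is an illustrative
# assumption.
if __name__ == '__main__':
    import pandas as pd

    results = pd.read_csv('results.csv')
    fig = plot_spectrum(results.iloc[0], correct=True, interactive=False)
    fig.savefig('spectrum_0.png', dpi=150)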