Esempio n. 1
0
    def learn(self,
              num_episodes=300,
              batch_size=None,
              print_progess_frequency=10,
              min_replay_samples=50,
              repeat_train=16,
              imshow=True):
        """Main learning loop of the agent.

        Runs in two phases:

        1. Warm-up: play episodes with ``select_action(..., model_only=False)``
           and push every transition into ``self.memory`` until it holds at
           least ``min_replay_samples`` transitions.
        2. Training: for each of ``num_episodes`` episodes, call
           ``train_model`` ``repeat_train`` times, periodically copy
           ``policy_net`` weights into ``target_net`` (and save the model),
           then play one episode with ``select_action(..., model_only=True)``
           and push a subsample of its transitions into ``self.memory``.

        Args:
            num_episodes: number of training episodes.
            batch_size: if not None, overrides ``self.batch_size``.
            print_progess_frequency: episodes between progress prints; every
                ``5 * print_progess_frequency`` episodes sample predictions are
                printed and loss/metric curves are drawn.
            min_replay_samples: replay-memory size required before training
                starts.
            repeat_train: number of ``train_model`` calls per episode.
            imshow: whether to call ``self.env.render()`` every step; also
                forwarded to ``loss_metric_curve``.

        Returns:
            True if training stopped early because 50 consecutive episodes
            reached ``t >= 200``; otherwise None once all episodes are done.
        """
        if batch_size is not None:
            self.batch_size = batch_size

        # Reset steps_done when learning starts so the random-action ratio is
        # gradually reduced again.
        self.steps_done = 0
        success_cnt = 0
        keep_success_cnt = 0
        start_train = False

        # ---- Phase 1: collect initial replay samples ----
        while not start_train:
            # Reset the environment and the accumulated reward.
            self.env.reset()
            total_rewards = 0
            state = self.get_observation()

            for t in count():
                # Pick an action for the current state (random actions allowed).
                action = self.select_action(state, model_only=False)
                # Apply the action; the environment has now advanced one step.
                reward, done = self.get_rewards(action)
                total_rewards += reward

                # Forced termination after 300 steps without failure.
                truncated = (not done and t + 1 >= 300)

                if imshow:
                    self.env.render()
                # Terminal transitions store None as the next state.
                # (`truncated` implies `not done`, so `done and not truncated`
                # is simply `done`.)
                next_state = None if done else self.get_observation()

                # Store the (state, action, next_state, reward) transition.
                self.memory.push(state, action, next_state, reward)
                if len(self.memory) % 100 == 0:
                    print("Replay Samples:{0}".format(len(self.memory)))
                # `>=` rather than `==`: if memory already held enough samples
                # (e.g. learn() called a second time) an equality test would
                # never fire and this warm-up loop would never terminate.
                if len(self.memory) >= min_replay_samples:
                    print('Start Train!!', flush=True)
                    start_train = True
                    break

                # Move on to the next state.
                state = copy.deepcopy(next_state)

                if done or truncated:
                    break

        # ---- Phase 2: training ----
        self.training_context['steps'] = 0
        self.steps_done = 0
        for i_episode in range(num_episodes):
            for i in range(repeat_train):
                # train_model receives no explicit data here; presumably it
                # pulls batches through self.output_fn (assumption - confirm).
                self.output_fn = self.experience_replay

                self.train_model(None,
                                 None,
                                 current_epoch=i_episode,
                                 current_batch=i,
                                 total_epoch=num_episodes,
                                 total_batch=repeat_train,
                                 is_collect_data=True,
                                 is_print_batch_progress=False,
                                 is_print_epoch_progress=False,
                                 log_gradients=False,
                                 log_weights=False,
                                 accumulate_grads=False)

            # Periodically copy policy_net weights into target_net and save.
            if i_episode % self.target_update == 0:
                self.target_net.load_state_dict(self.policy_net.state_dict(),
                                                strict=True)
                self.save_model(save_path=self.training_context['save_path'])

            # Reset the environment and the accumulated reward.
            self.env.reset()
            total_rewards = 0
            state = self.get_observation()
            episode_memory = []

            for t in count():
                # Pick an action using the model only.
                action = self.select_action(state, model_only=True)
                reward, done = self.get_rewards(action)
                total_rewards += reward

                # Forced termination after 300 steps without failure.
                truncated = (not done and t + 1 >= 300)

                if imshow:
                    self.env.render()
                next_state = None if done else self.get_observation()

                # Buffer the transition; a subsample is pushed to replay
                # memory after the episode ends.
                episode_memory.append((state, action, next_state, reward))

                state = next_state

                if done or truncated:
                    # Consecutive episodes ending at t >= 200.
                    # NOTE(review): task_complete below uses t + 1 >= 200;
                    # confirm which threshold is intended.
                    if t >= 200:
                        success_cnt += 1
                    else:
                        success_cnt = 0

                    # Two consecutive episodes reaching 300 steps set
                    # training_context['stop_update'] to 1, otherwise 0.
                    if t + 1 >= 300:
                        keep_success_cnt += 1
                    else:
                        keep_success_cnt = 0
                    self.training_context['stop_update'] = (
                        1 if keep_success_cnt >= 2 else 0)

                    # Record the accumulated reward and the episode length.
                    self.epoch_metric_history.collect('total_rewards',
                                                      i_episode,
                                                      float(total_rewards))
                    self.epoch_metric_history.collect('original_rewards',
                                                      i_episode, float(t))
                    # Record completion (1.0 when the episode took >= 200 steps).
                    self.epoch_metric_history.collect(
                        'task_complete', i_episode,
                        1.0 if t + 1 >= 200 else 0.0)
                    # Periodic progress print.
                    if i_episode > 0 and i_episode % print_progess_frequency == 0:
                        self.print_epoch_progress(print_progess_frequency)
                    # Periodically dump sample predictions and plot loss/metric
                    # curves over time.
                    if i_episode > 0 and i_episode % (
                            5 * print_progess_frequency) == 0:
                        train_data = self.training_context['train_data']
                        print('negative_reward_ratio:',
                              less(train_data['reward_batch'], 0).mean().item())
                        print('predict_rewards:',
                              train_data['predict_rewards'].copy()[:5, 0])
                        print('target_rewards:',
                              train_data['target_rewards'].copy()[:5, 0])
                        print('reward_batch:',
                              train_data['reward_batch'].copy()[:5])
                        loss_metric_curve(self.epoch_loss_history,
                                          self.epoch_metric_history,
                                          legend=['dqn'],
                                          calculate_base='epoch',
                                          imshow=imshow)

                    if success_cnt == 50:
                        self.save_model(
                            save_path=self.training_context['save_path'])
                        print('50 episodes success, training finish! ')
                        return True

                    break

            # Subsample the episode before adding it to replay memory. For
            # episodes longer than 10 steps, always keep the last 3
            # transitions plus int(sqrt(n)) randomly chosen earlier ones.
            n_steps = len(episode_memory)
            indexs = list(range(n_steps))
            sample_idx = set()
            if n_steps > 10:
                sample_idx.update(indexs[-3:])
                sample_idx.update(
                    random_choice(indexs[:-3], int(sqrt(n_steps))))

            # Additionally keep every transition with reward < 1, and those
            # whose successor has reward < 1 (only while k + 3 < n_steps).
            for k, transition in enumerate(episode_memory):
                if k in sample_idx or (k + 3 < n_steps and
                                       episode_memory[k + 1][3] < 1) \
                        or transition[3] < 1:
                    self.memory.push(*transition)

        print('Complete')
        self.env.render()
        self.env.close()
        plt.ioff()
        plt.show()
Esempio n. 2
0
    def learn(self,
              num_episodes=300,
              batch_size=None,
              print_progess_frequency=10,
              imshow=True):
        """Main learning loop of the agent (online-training variant).

        Plays ``num_episodes`` episodes with ``select_action``. Transitions
        are pushed into ``self.memory``, with "good" transitions
        (``reward == 1``) down-sampled after episode 20. Once the memory
        holds more than ``self.batch_size`` samples (checked at the start of
        each episode), ``train_model`` is called on an ``experience_replay``
        batch after every step.

        Args:
            num_episodes: number of episodes to play.
            batch_size: if not None, overrides ``self.batch_size``.
            print_progess_frequency: episodes between progress prints; every
                ``5 * print_progess_frequency`` episodes sample predictions
                are printed and loss/metric curves are drawn.
            imshow: whether to call ``self.env.render()`` every step; also
                forwarded to ``loss_metric_curve``.

        Returns:
            None.
        """
        if batch_size is not None:
            self.batch_size = batch_size

        self.steps_done = 0
        for i_episode in range(num_episodes):
            # Reset the environment and the accumulated reward.
            self.env.reset()
            total_rewards = 0
            state = self.get_observation()

            # Train only once memory holds more samples than one batch.
            start_train = (len(self.memory) > self.batch_size)
            for t in count():
                # Pick an action for the current state.
                action = self.select_action(state)
                # Apply the action; the environment has now advanced one step.
                reward, done = self.get_rewards(action)
                total_rewards += reward

                # Forced termination after 300 steps without failure.
                truncated = (not done and t + 1 >= 300)

                if imshow:
                    self.env.render()
                # NOTE(review): next_state is observed even when `done`;
                # confirm experience_replay treats terminal steps as intended.
                next_state = self.get_observation()

                # Down-sample "good" transitions (reward == 1): keep all of
                # them during the first 20 episodes, afterwards keep 10% of
                # those before step 100 and 20% from step 100 on.
                # Transitions with reward < 1 are always kept. (The original
                # single expression required both t < 100 and t >= 100, so
                # no good transition was ever stored after episode 20.)
                if reward < 1:
                    keep = True
                elif reward == 1:
                    keep = (i_episode < 20 or
                            random.random() < (0.1 if t < 100 else 0.2))
                else:
                    keep = False
                if keep:
                    self.memory.push(state, action, next_state, reward)

                # Move on to the next step.
                state = deepcopy(next_state)

                if start_train:
                    # Draw a batch from experience replay.
                    trainData = self.experience_replay(self.batch_size)
                    # Switch the model to training mode.
                    self.policy_net.train()
                    self.train_model(
                        trainData,
                        None,
                        current_epoch=i_episode,
                        current_batch=t,
                        total_epoch=num_episodes,
                        total_batch=t + 1 if done or truncated else t + 2,
                        is_collect_data=True if done or truncated else False,
                        is_print_batch_progress=False,
                        is_print_epoch_progress=False,
                        log_gradients=False,
                        log_weights=False,
                        accumulate_grads=False)

                if done or truncated:
                    if start_train:
                        # Record the accumulated reward.
                        self.epoch_metric_history.collect(
                            'total_rewards', i_episode, float(total_rewards))
                        # Record completion (1.0 when the episode took
                        # >= 200 steps).
                        self.epoch_metric_history.collect(
                            'task_complete', i_episode,
                            1.0 if t + 1 >= 200 else 0.0)
                        # Periodic progress print.
                        if i_episode % print_progess_frequency == 0:
                            self.print_epoch_progress(print_progess_frequency)
                        # Periodically dump sample predictions and plot
                        # loss/metric curves over time.
                        if i_episode > 0 and (i_episode + 1) % (
                                5 * print_progess_frequency) == 0:
                            train_data = self.training_context['train_data']
                            print('epsilon:', self.epsilon)
                            print('predict_rewards:',
                                  train_data['predict_rewards'][:5])
                            print('target_rewards:',
                                  train_data['target_rewards'][:5])
                            print('reward_batch:',
                                  train_data['reward_batch'][:5])
                            loss_metric_curve(self.epoch_loss_history,
                                              self.epoch_metric_history,
                                              legend=['dqn'],
                                              calculate_base='epoch',
                                              imshow=imshow)

                    break

            # Periodically copy policy_net weights into target_net and save.
            if start_train and i_episode % self.target_update == 0:
                self.target_net.load_state_dict(self.policy_net.state_dict(),
                                                strict=True)
                self.save_model(save_path=self.training_context['save_path'])

        print('Complete')
        self.env.render()
        self.env.close()
        plt.ioff()
        plt.show()
Esempio n. 3
0
def plot_spectrum(result, correct=True, interactive=False):
    """Draw a multi-panel figure for one spectrum stored in a FITS file.

    Panels: the object thumbnail, the 2D spectrum, the full smoothed 1D
    spectrum with reference line markers, two zoomed sub-ranges (pixels
    300:750 and 1400:1500), and a text box with fit results.

    Args:
        result: record read both by key (``result['VREL']``) and by
            attribute (``result.ID``) -- presumably a pandas row; must
            provide ORIGINALFILE, VREL, VERR, DETECT, ID, SN1, SN2, TDR, SG.
        correct: if True, divide the wavelengths by ``1 + VREL / 299792.458``
            (299792.458 is c in km/s, so VREL is presumably km/s) to shift
            them to the rest frame.
        interactive: if True, turn matplotlib interactive mode on.

    Returns:
        The created matplotlib Figure.
    """
    plt.close('all')
    plt.ioff()
    if interactive:
        plt.ion()

    # The context manager closes the file even if a read below fails.
    with fits.open(result['ORIGINALFILE']) as hdu:
        galaxy = gaussian_filter(hdu[1].data, 1)
        thumbnail = hdu['THUMBNAIL'].data
        twoD = hdu['2D'].data
        header1 = hdu[1].header

    # Linear wavelength grid from the CRVAL1 / CD1_1 / NAXIS1 keywords.
    lamRange = header1['CRVAL1'] + np.array(
        [0., header1['CD1_1'] * (header1['NAXIS1'] - 1)])
    zp = (1. + (result['VREL'] / 299792.458)) if correct else 1.
    wavelength = np.linspace(lamRange[0], lamRange[1], header1['NAXIS1']) / zp

    ymin, ymax = np.min(galaxy), np.max(galaxy)
    ylim = [ymin, ymax] + np.array([-0.02, 0.1]) * (ymax - ymin)
    ylim[0] = 0.
    xmin, xmax = np.min(wavelength), np.max(wavelength)

    # Multipanel layout.
    fig = plt.figure(figsize=(8, 6))
    gs = gridspec.GridSpec(200, 130, bottom=0.10, left=0.10, right=0.95)

    # Object thumbnail.
    ax_obj = fig.add_subplot(gs[0:70, 105:130])
    ax_obj.imshow(thumbnail, cmap='gray', interpolation='nearest')
    ax_obj.set_xticks([])
    ax_obj.set_yticks([])

    # 2D spectrum. The 'spectral' colormap was removed in matplotlib 2.2;
    # 'nipy_spectral' is its replacement.
    ax_2d = fig.add_subplot(gs[0:11, 0:100])
    ax_2d.imshow(twoD, cmap='nipy_spectral',
                 aspect="auto", origin='lower', extent=[xmin, xmax, 0, 1],
                 vmin=-0.2, vmax=0.2)
    ax_2d.set_xticks([])
    ax_2d.set_yticks([])

    # Spectrum panels.
    ax_spectrum = fig.add_subplot(gs[11:85, 0:100])
    ax_blue = fig.add_subplot(gs[110:200, 0:50])
    ax_red = fig.add_subplot(gs[110:200, 51:100])

    # Reference lines: Hbeta, Mgb, NaD, Halpha, CaT, CaT, CaT.
    line_wave = [4861., 5175., 5892., 6562.8, 8498., 8542., 8662.]
    for wave in line_wave:
        x = [wave, wave]
        y = [ylim[0], ylim[1]]
        for ax in (ax_spectrum, ax_blue, ax_red):
            ax.plot(x, y, c='gray', linewidth=1.0)

    # Full spectrum.
    ax_spectrum.plot(wavelength, galaxy, 'k', linewidth=1.3)
    ax_spectrum.set_ylim(ylim)
    ax_spectrum.set_xlim([xmin, xmax])
    ax_spectrum.set_ylabel(r'Arbitrary Flux')
    ax_spectrum.set_xlabel(r'Restframe Wavelength [ $\AA$ ]')

    # Zoomed blue and red parts of the spectrum (fixed pixel ranges).
    for ax, (x1, x2) in ((ax_blue, (300, 750)), (ax_red, (1400, 1500))):
        ax.plot(wavelength[x1:x2], galaxy[x1:x2], 'k', linewidth=1.3)
        ax.set_xlim(wavelength[x1], wavelength[x2])
        ax.set_ylim(galaxy[x1:x2].min(), galaxy[x1:x2].max())
        ax.set_yticks([])

    # Text summary. `\,` is a mathtext thin space and only works inside $...$.
    textplot = fig.add_subplot(gs[80:200, 105:130])
    kwarg = {'va': 'center', 'ha': 'left', 'size': 'medium'}
    textplot.text(0.1, 1.0, r'ID = {} $\,$ {}'.format(result.ID, int(result.DETECT)), **kwarg)
    textplot.text(0.1, 0.9, r'$v =$ {}'.format(int(result.VREL)), **kwarg)
    textplot.text(0.1, 0.8, r'$\delta \, v = $ {}'.format(int(result.VERR)), **kwarg)
    textplot.text(0.1, 0.7, r'SN1 = {0:.2f}'.format(result.SN1), **kwarg)
    textplot.text(0.1, 0.6, r'SN2 = {0:.2f}'.format(result.SN2), **kwarg)
    textplot.text(0.1, 0.5, r'TDR = {0:.2f}'.format(result.TDR), **kwarg)
    textplot.text(0.1, 0.4, r'SG = {}'.format(result.SG), **kwarg)
    textplot.axis('off')

    return fig