def train(self):
     """
     Performs startup, evaluates the initial agent, then loops by
     alternating between ``sampler.obtain_samples()`` and
     ``algo.optimize_agent()``.  Pauses to evaluate the agent at the
     specified log interval.
     """
     n_itr = self.startup()
     with logger.prefix(f"itr #0 "):
         eval_traj_infos, eval_time = self.evaluate_agent(0)
         self.log_diagnostics(0, eval_traj_infos, eval_time)
     if self.pretrain != 'None':
         # Recover the last completed iteration recorded by a previous run.
         with open(self.log_dir + '/last_itr.txt', 'r') as status_file_read:
             starting_itr = int(status_file_read.read().split('\n')[-2])
     else:
         starting_itr = 0
     with open(self.log_dir + '/last_itr.txt',
               'a') as status_file:  # for restart purposes
         for itr in range(starting_itr, n_itr):
             logger.set_iteration(itr)
             with logger.prefix(f"itr #{itr} "):
                 self.agent.sample_mode(itr)
                 samples, traj_infos = self.sampler.obtain_samples(itr)
                 self.agent.train_mode(itr)
                 opt_info = self.algo.optimize_agent(itr, samples)
                 self.store_diagnostics(itr, traj_infos, opt_info)
                 if (itr + 1) % self.log_interval_itrs == 0:
                     eval_traj_infos, eval_time = self.evaluate_agent(itr)
                     self.log_diagnostics(itr, eval_traj_infos, eval_time)
     self.shutdown()
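All of these snippets are `train()` loops from rlpyt-style runner classes (startup, sample, optimize, periodic evaluation). For orientation, here is a minimal launch sketch assuming rlpyt's published serial components (`SerialSampler`, `MinibatchRlEval`, `logger_context`); the constructor arguments are illustrative and may differ in a given fork.

# Minimal rlpyt launch sketch (assumed standard components; adjust to your fork).
from rlpyt.samplers.serial.sampler import SerialSampler
from rlpyt.envs.atari.atari_env import AtariEnv
from rlpyt.algos.pg.ppo import PPO
from rlpyt.agents.pg.atari import AtariFfAgent
from rlpyt.runners.minibatch_rl import MinibatchRlEval
from rlpyt.utils.logging.context import logger_context


def build_and_train(game="pong", run_ID=0):
    sampler = SerialSampler(
        EnvCls=AtariEnv,
        env_kwargs=dict(game=game),
        eval_env_kwargs=dict(game=game),
        batch_T=64,            # Time steps per sampler iteration.
        batch_B=8,             # Parallel environment instances.
        eval_n_envs=4,
        eval_max_steps=int(25e3),
    )
    algo = PPO()
    agent = AtariFfAgent()
    runner = MinibatchRlEval(
        algo=algo, agent=agent, sampler=sampler,
        n_steps=1e6, log_interval_steps=1e4,
        affinity=dict(cuda_idx=None),
    )
    with logger_context("example_logs", run_ID, "ppo_" + game, dict(game=game)):
        runner.train()  # Runs a startup/sample/optimize/evaluate loop like the ones above.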
Example #2
 def train(self):
     """
     Performs startup, evaluates the initial agent, then loops by
     alternating between ``sampler.obtain_samples()`` and
     ``algo.optimize_agent()``.  Pauses to evaluate the agent at the
     specified log interval.
     """
     n_itr = self.startup()
     with logger.prefix(f"itr #0 "):
         eval_traj_infos, eval_time = self.evaluate_agent(0)
         self.log_diagnostics(0, eval_traj_infos, eval_time)
     for itr in range(n_itr):
         logger.set_iteration(itr)
         with logger.prefix(f"itr #{itr} "):
             if self.transfer and self.transfer_iter == itr:
                 self.sampler.transfer(
                     self.transfer_arg)  # Transfer if doing
                 eval_traj_infos, eval_time = self.evaluate_agent(
                     itr)  # Eval
                 self.log_diagnostics(itr, eval_traj_infos, eval_time)
             self.agent.sample_mode(itr)
             samples, traj_infos = self.sampler.obtain_samples(itr)
             self.agent.train_mode(itr)
             opt_info = self.algo.optimize_agent(itr, samples)
             self.store_diagnostics(itr, traj_infos, opt_info)
             if (itr + 1) % self.log_interval_itrs == 0:
                 eval_traj_infos, eval_time = self.evaluate_agent(itr)
                 self.log_diagnostics(itr, eval_traj_infos, eval_time)
     self.shutdown()
Example #3
    def train(self):
        """
        Performs startup, evaluates the initial agent, then loops by
        alternating between ``sampler.obtain_samples()`` and
        ``algo.optimize_agent()``.  Pauses to evaluate the agent at the
        specified log interval.
        """
        n_itr = self.startup()
        with logger.prefix(f"itr #0 "):
            eval_traj_infos, eval_time = self.evaluate_agent(0)
            self.log_diagnostics(0, eval_traj_infos, eval_time)
        for itr in range(n_itr):
            logger.set_iteration(itr)
            with logger.prefix(f"itr #{itr} "):
                self.agent.sample_mode(itr)
                samples, traj_infos = self.sampler.obtain_samples(itr)
                self.agent.train_mode(itr)
                opt_info = self.algo.optimize_agent(itr, samples)
                self.store_diagnostics(itr, traj_infos, opt_info)
                if self.prioritized_level_replay:
                    for traj_info, value_error in zip(traj_infos,
                                                      opt_info.valueLoss):
                        self.level_replay.update_seed_score(
                            traj_info.seed, value_error)
                    seeds = [
                        self.level_replay.sample()
                        for _ in range(self.sampler.n_worker)
                    ]
                    self.sampler.set_env_seeds(seeds)

                if (itr + 1) % self.log_interval_itrs == 0:
                    eval_traj_infos, eval_time = self.evaluate_agent(itr)
                    self.log_diagnostics(itr, eval_traj_infos, eval_time)
        self.shutdown()
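Example #3 drives a `level_replay` helper through `update_seed_score(seed, score)` and `sample()`, plus a sampler hook `set_env_seeds(seeds)`. Below is a minimal, hypothetical sketch of that seed-replay interface (the names mirror the calls above; an actual prioritized level replay implementation would typically add staleness and rank-based weighting).

import numpy as np


class LevelReplay:
    """Hypothetical seed-replay helper matching the calls used in Example #3."""

    def __init__(self, seeds, temperature=1.0):
        self.seeds = list(seeds)
        self.scores = {s: 1e-3 for s in self.seeds}  # Small prior so every seed can be drawn.
        self.temperature = temperature

    def update_seed_score(self, seed, score):
        # Keep the most recent |value error| as the seed's priority score.
        self.scores[seed] = abs(float(score))

    def sample(self):
        # Draw a seed with probability proportional to its tempered score.
        scores = np.array([self.scores[s] for s in self.seeds], dtype=float)
        probs = scores ** (1.0 / self.temperature)
        probs /= probs.sum()
        return int(np.random.choice(self.seeds, p=probs))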
Example #4
 def train(self):
     n_itr = self.startup()
     with logger.prefix(f"itr #0 "):
         eval_traj_infos, eval_time = self.evaluate_agent(0)
         self.log_diagnostics(0, eval_traj_infos, eval_time)
     for itr in range(n_itr):
         with logger.prefix(f"itr #{itr} "):
             self.agent.sample_mode(itr)
             samples, traj_infos = self.sampler.obtain_samples(itr)
             self.agent.train_mode(itr)
Example #5
    def train(self):
        """
        Performs startup, evaluates the initial agent, then loops by
        alternating between ``sampler.obtain_samples()`` and
        ``algo.optimize_agent()``.  Pauses to evaluate the agent at the
        specified log interval.
        """
        n_itr = self.startup()
        with logger.prefix(f"itr #0 "):
            player_eval_traj_infos, observer_eval_traj_infos, eval_time = self.evaluate_agent(
                0)
            self.log_diagnostics(0, player_eval_traj_infos,
                                 observer_eval_traj_infos, eval_time)
        for itr in range(n_itr):
            logger.set_iteration(itr)
            with logger.prefix(f"itr #{itr} "):
                self.agent.sample_mode(itr)
                player_samples, player_traj_infos, observer_samples, observer_traj_infos = self.sampler.obtain_samples(
                    itr)
                self.agent.train_mode(itr)
                player_opt_info = ()
                observer_opt_info = ()
                if self.alt_train:
                    if self.agent.train_mask[0] and (itr % 2 == 0):
                        player_opt_info = self.player_algo.optimize_agent(
                            itr // 2, player_samples)
                    elif self.agent.train_mask[1]:
                        observer_opt_info = self.observer_algo.optimize_agent(
                            itr // 2, observer_samples)
                else:
                    if self.agent.train_mask[0]:
                        player_opt_info = self.player_algo.optimize_agent(
                            itr, player_samples)
                    if self.agent.train_mask[1]:
                        observer_opt_info = self.observer_algo.optimize_agent(
                            itr, observer_samples)
                self.store_diagnostics(itr, player_traj_infos,
                                       observer_traj_infos, player_opt_info,
                                       observer_opt_info)
                if (itr + 1) % self.log_interval_itrs == 0:
                    player_eval_traj_infos, observer_eval_traj_infos, eval_time = self.evaluate_agent(
                        itr)
                    if self.wandb_log:
                        self.wandb_logging(
                            itr,
                            player_traj_infos=player_eval_traj_infos,
                            observer_traj_infos=observer_eval_traj_infos)
                    self.log_diagnostics(itr, player_eval_traj_infos,
                                         observer_eval_traj_infos, eval_time)

        self.shutdown()
Example #6
 def optimize(self):
     throttle_itr, delta_throttle_itr = self.startup()
     throttle_time = 0.
     itr = 0
     if self._log_itr0:
         while self.ctrl.sample_itr.value < 1:
             time.sleep(THROTTLE_WAIT)
         traj_infos = list()
         while self.traj_infos_queue.qsize():
             traj_infos.append(self.traj_infos_queue.get())
         self.store_diagnostics(0, 0, traj_infos, ())
         self.log_diagnostics(0, 0, 0)
     log_counter = 0
     while True:
         if self.ctrl.quit.value:
             break
         with logger.prefix(f"opt_itr #{itr} "):
             while self.ctrl.sample_itr.value < throttle_itr:
                 time.sleep(THROTTLE_WAIT)
                 throttle_time += THROTTLE_WAIT
             if self.ctrl.opt_throttle is not None:
                 self.ctrl.opt_throttle.wait()
             throttle_itr += delta_throttle_itr
             opt_info = self.algo.optimize_agent(itr)
             self.agent.send_shared_memory()
             traj_infos = list()
             sample_itr = self.ctrl.sample_itr.value  # Check before queue.
             while self.traj_infos_queue.qsize():
                 traj_infos.append(self.traj_infos_queue.get())
             self.store_diagnostics(itr, sample_itr, traj_infos, opt_info)
             if (sample_itr // self.log_interval_itrs > log_counter):
                 self.log_diagnostics(itr, sample_itr, throttle_time)
                 log_counter += 1
                 throttle_time = 0.
         itr += 1
     self.shutdown()
Example #7
 def train(self):
     n_itr = self.startup()
     with logger.prefix(f"itr #0 "):
         eval_traj_infos, eval_time = self.evaluate_agent(0)
         self.log_diagnostics(0, eval_traj_infos, eval_time)
     for itr in range(n_itr):
         with logger.prefix(f"itr #{itr} "):
             self.agent.sample_mode(itr)
             samples, traj_infos = self.sampler.obtain_samples(itr)
             self.agent.train_mode(itr)
             opt_info = self.algo.optimize_agent(itr, samples)
             self.store_diagnostics(itr, traj_infos, opt_info)
             if (itr + 1) % self.log_interval_itrs == 0:
                 eval_traj_infos, eval_time = self.evaluate_agent(itr)
                 self.log_diagnostics(itr, eval_traj_infos, eval_time)
     self.shutdown()
Example #8
 def train(self):
     """
     Performs startup, then loops by alternating between
     ``sampler.obtain_samples()`` and ``algo.optimize_agent()``, logging
     diagnostics at the specified interval.
     """
     n_itr = self.startup()
     if self.pretrain != 'None':
         with open(self.log_dir + '/last_itr.txt', 'r') as status_file_read:
             starting_itr = int(status_file_read.read().split('\n')[-2])
     else:
         starting_itr = 0
     with open(self.log_dir + '/last_itr.txt',
               'a') as status_file:  # for restart purposes
         for itr in range(starting_itr, n_itr):
             logger.set_iteration(itr)
             with logger.prefix(f"itr #{itr} "):
                 self.agent.sample_mode(
                     itr)  # Might not be this agent sampling.
                 samples, traj_infos = self.sampler.obtain_samples(itr)
                 self.agent.train_mode(itr)
                 opt_info, layer_info = self.algo.optimize_agent(
                     itr, samples)
                 self.store_diagnostics(itr, traj_infos, opt_info)
                 if (itr + 1) % self.log_interval_itrs == 0:
                     status_file.write(str(itr) + '\n')
                     self.log_diagnostics(itr)
                     self.log_weights(layer_info)
     self.shutdown()
Example #9
 def train(self):
     """
     Performs startup, then loops by alternating between
     ``sampler.obtain_samples()`` and ``algo.optimize_agent()``, logging
     diagnostics at the specified interval.
     """
     n_itr = self.startup()
     for itr in range(n_itr):
         logger.set_iteration(itr)
         with logger.prefix(f"itr #{itr} "):
             if self.transfer and self.transfer_iter == itr:
                 self.sampler.transfer(
                     self.transfer_arg)  # Transfer if doing
                 self._traj_infos.clear()  # Clear trajectory information
                 self._transfer_start(itr, opt_info)
             self.agent.sample_mode(
                 itr)  # Might not be this agent sampling.
             samples, traj_infos = self.sampler.obtain_samples(itr)
             self.agent.train_mode(itr)
             opt_info = self.algo.optimize_agent(itr, samples)
             self.store_diagnostics(itr, traj_infos, opt_info)
             if (itr + 1) % self.log_interval_itrs == 0:
                 self.log_diagnostics(itr)
             if self.n_episodes is not None and self._cum_completed_trajs >= self.n_episodes:
                 break
     self.shutdown()
Example #10
File: mtgail.py  Project: qxcv/mtil
    def train(self, cb_startup=None):
        # copied from MinibatchRl.train() & extended to support GAIL update
        n_itr = self.startup()
        if cb_startup:
            # post-startup callback (cb)
            cb_startup(self)
        for itr in range(n_itr):
            with logger.prefix(f"itr #{itr} "):
                self.agent.sample_mode(
                    itr)  # Might not be this agent sampling.
                samples, traj_infos = self.sampler.obtain_samples(itr)
                # label traj_infos with env IDs (this is specific to
                # magical/my multi-task thing)
                traj_infos = _label_traj_infos(traj_infos, self.variant_groups)
                self.agent.train_mode(itr)
                opt_info = self.algo.optimize_agent(itr, samples)

                # run GAIL & combine its output with RL algorithm output
                gail_info = self.gail_optim.optim_disc(itr, n_itr, samples)
                if self.joint_info_cls is None:
                    self.joint_info_cls = namedtuple(
                        'joint_info_cls', gail_info._fields + opt_info._fields)
                opt_info = self.joint_info_cls(**gail_info._asdict(),
                                               **opt_info._asdict())

                self.store_diagnostics(itr, traj_infos, opt_info)
                if (itr + 1) % self.log_interval_itrs == 0:
                    self.log_diagnostics(itr)
        self.shutdown()
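Example #10 merges the GAIL discriminator's info tuple with the RL algorithm's info tuple by building a joint namedtuple class once from the two `_fields` lists and splatting both `_asdict()` results. A standalone illustration of the pattern, with made-up field names:

from collections import namedtuple

GailInfo = namedtuple("GailInfo", ["discLoss", "discAcc"])
OptInfo = namedtuple("OptInfo", ["loss", "entropy"])

gail_info = GailInfo(discLoss=0.4, discAcc=0.9)
opt_info = OptInfo(loss=1.2, entropy=0.05)

# Build the joint class once from the combined field names, then merge values.
JointInfo = namedtuple("JointInfo", GailInfo._fields + OptInfo._fields)
joint = JointInfo(**gail_info._asdict(), **opt_info._asdict())
print(joint)  # JointInfo(discLoss=0.4, discAcc=0.9, loss=1.2, entropy=0.05)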
Example #11
    def train(self):
        """
        Performs startup, evaluates the initial agent, then loops by
        alternating between ``sampler.obtain_samples()`` and
        ``algo.optimize_agent()``.  Pauses to evaluate the agent at the
        specified log interval.
        """
        n_itr = self.startup()
        with logger.prefix(f"itr #0 "):
            eval_traj_infos, eval_time = self.evaluate_agent(0)
            self.log_diagnostics(0, eval_traj_infos, eval_time)
        for itr in range(n_itr):
            logger.set_iteration(self.get_cum_steps(itr))
            with logger.prefix(f"itr #{itr} "):
                self.agent.sample_mode(itr)
                samples, traj_infos = self.sampler.obtain_samples(itr)
                self.agent.train_mode(itr)
                opt_info = self.algo.optimize_agent(itr, samples)
                self.store_diagnostics(itr, traj_infos, opt_info)
                if (itr + 1) % self.log_interval_itrs == 0:
                    eval_traj_infos, eval_time = self.evaluate_agent(itr)
                    p_error = self.sampler.env_kwargs['error_rate']
                    for traj_info in eval_traj_infos:
                        traj_info['p_error'] = p_error
                    self.log_diagnostics(itr, eval_traj_infos, eval_time)
                    avg_lifetime = np.nanmean(
                        np.array([x['lifetime'] for x in eval_traj_infos]))
                    if avg_lifetime > (1 / p_error):
                        if p_error < 0.010:
                            p_error = 0.011
                            self.sampler.env_kwargs['error_rate'] = p_error
                            self.sampler.eval_env_kwargs[
                                'error_rate'] = p_error
                            print(f'new p error is {p_error}', flush=True)
                            self.shutdown()
                            self.startup()
                        else:
                            print(
                                f'didnt change p_error - currently at {p_error}'
                            )

                        # for env in self.sampler.collector.envs + self.sampler.eval_collector.envs:
                        #     env.p_phys = new_p_error
                        #     env.p_meas = new_p_error
        print(f'training end due to n_itr', flush=True)
        self.shutdown()
Example #12
    def train(self):
        """
        Run the optimizer in a loop.  Check whether enough new samples have
        been generated, and throttle down if necessary at each iteration.  Log
        at an interval in the number of sampler iterations, not optimizer
        iterations.
        """
        throttle_itr, delta_throttle_itr = self.startup()
        throttle_time = 0.
        sampler_itr = itr = 0
        if self._eval:
            while self.ctrl.sampler_itr.value < 1:  # Sampler does eval first.
                time.sleep(THROTTLE_WAIT)
            traj_infos = drain_queue(self.traj_infos_queue, n_sentinel=1)
            self.store_diagnostics(0, 0, traj_infos, ())
            self.log_diagnostics(0, 0, 0)

        log_counter = 0
        while True:  # Run until sampler hits n_steps and sets ctrl.quit=True.
            logger.set_iteration(itr)
            with logger.prefix(f"opt_itr #{itr} "):
                while self.ctrl.sampler_itr.value < throttle_itr:
                    if self.ctrl.quit.value:
                        break
                    time.sleep(THROTTLE_WAIT)
                    throttle_time += THROTTLE_WAIT
                if self.ctrl.quit.value:
                    break
                if self.ctrl.opt_throttle is not None:
                    self.ctrl.opt_throttle.wait()
                throttle_itr += delta_throttle_itr
                opt_info = self.algo.optimize_agent(
                    itr, sampler_itr=self.ctrl.sampler_itr.value)
                self.agent.send_shared_memory()  # To sampler.
                sampler_itr = self.ctrl.sampler_itr.value
                traj_infos = (list() if self._eval else drain_queue(
                    self.traj_infos_queue))
                self.store_diagnostics(itr, sampler_itr, traj_infos, opt_info)
                if (sampler_itr // self.log_interval_itrs > log_counter):
                    if self._eval:
                        with self.ctrl.sampler_itr.get_lock():
                            traj_infos = drain_queue(self.traj_infos_queue,
                                                     n_sentinel=1)
                        self.store_diagnostics(itr, sampler_itr, traj_infos,
                                               ())
                    self.log_diagnostics(itr, sampler_itr, throttle_time)
                    log_counter += 1
                    throttle_time = 0.
            itr += 1
        # Final log:
        sampler_itr = self.ctrl.sampler_itr.value
        traj_infos = drain_queue(self.traj_infos_queue)
        if traj_infos or not self._eval:
            self.store_diagnostics(itr, sampler_itr, traj_infos, ())
            self.log_diagnostics(itr, sampler_itr, throttle_time)
        self.shutdown()
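Examples #12 and #17 rely on a `drain_queue()` helper; rlpyt provides one in `rlpyt.utils.synchronize`. Below is a sketch consistent with how it is used here, where `n_sentinel > 0` means block until that many `None` sentinels arrive (the library version may differ in details such as sentinel guarding).

import queue


def drain_queue(queue_obj, n_sentinel=0):
    """Empty a (multiprocessing) queue into a list.

    With n_sentinel > 0, block until that many None sentinels have been
    received; otherwise drain whatever is currently available and return.
    """
    contents = list()
    if n_sentinel > 0:  # Blocking mode: producers mark completion with None.
        sentinels = 0
        while sentinels < n_sentinel:
            obj = queue_obj.get()
            if obj is None:
                sentinels += 1
            else:
                contents.append(obj)
    else:  # Non-blocking mode: take what is there right now.
        while True:
            try:
                contents.append(queue_obj.get_nowait())
            except queue.Empty:
                break
    return contents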
Example #13
 def train(self):
     n_itr = self.startup()
     for itr in range(n_itr):
         with logger.prefix(f"itr #{itr} "):
             self.agent.sample_mode(itr)  # Might not be this agent sampling.
             samples, traj_infos = self.sampler.obtain_samples(itr)
             self.agent.train_mode(itr)
             opt_info = self.algo.optimize_agent(itr, samples)
             self.store_diagnostics(itr, traj_infos, opt_info)
             if (itr + 1) % self.log_interval_itrs == 0:
                 self.log_diagnostics(itr)
     self.shutdown()
Example #14
 def train(self):
     n_itr = self.startup()
     with logger.prefix(f"itr #0 "):
         eval_traj_infos, eval_time = self.evaluate_agent(0)
         self.log_diagnostics(0, eval_traj_infos, eval_time)
     for itr in range(n_itr):
         # print("{}/{}".format(itr, n_itr))
         with logger.prefix(f"itr #{itr} "):
             self.agent.sample_mode(
                 itr)  # Might not be this agent sampling.
             # st = time.time()
             samples, traj_infos = self.sampler.obtain_samples(itr)
             # st2 = time.time()
             # print('sample time:', st2 - st)
             self.agent.train_mode(itr)
             opt_info = self.algo.optimize_agent(itr, samples)
             # print('train time:', time.time() - st2)
             self.store_diagnostics(itr, traj_infos, opt_info)
             if (itr + 1) % self.log_interval_itrs == 0:
                 eval_traj_infos, eval_time = self.evaluate_agent(itr)
                 self.log_diagnostics(itr, eval_traj_infos, eval_time)
     self.shutdown()
Example #15
 def train(self):
     n_itr = self.startup(
     )  # startup() runs the parent-class initialization, wiring in the algo, agent, and sampler passed from outside.
     with logger.prefix(f"itr #0 "):
         eval_traj_infos, eval_time = self.evaluate_agent(
             0)  # Evaluate once before training starts.
         self.log_diagnostics(0, eval_traj_infos, eval_time)  # Record diagnostics (write logs).
     for itr in range(n_itr):  # Repeat training for n_itr iterations.
         with logger.prefix(f"itr #{itr} "):
             self.agent.sample_mode(itr)  # Switch the agent to sampling mode.
             samples, traj_infos = self.sampler.obtain_samples(
                 itr)  # Collect a batch of samples.
             self.agent.train_mode(itr)  # Put the network modules in train mode (the itr argument is effectively unused).
             opt_info = self.algo.optimize_agent(
                 itr, samples)  # Optimize the model; backprop and related work happen in here.
             self.store_diagnostics(itr, traj_infos,
                                    opt_info)  # Update in-memory statistics.
             if (itr + 1) % self.log_interval_itrs == 0:  # At every logging interval:
                 eval_traj_infos, eval_time = self.evaluate_agent(
                     itr)  # Evaluate the model.
                 self.log_diagnostics(itr, eval_traj_infos,
                                      eval_time)  # Record diagnostics (write logs).
     self.shutdown()  # Clean up when finished.
Example #16
 def train(self):
     self.startup()
     self.algo.train()
     for itr in range(self.n_updates):
         logger.set_iteration(itr)
         with logger.prefix(f"itr #{itr} "):
             opt_info = self.algo.optimize(itr)  # perform one update
             self.store_diagnostics(itr, opt_info)
             if (itr + 1) % self.log_interval_updates == 0:
                 self.algo.eval()
                 val_info = self.algo.validation(itr)
                 self.log_diagnostics(itr, val_info)
                 self.algo.train()
     self.shutdown()
Example #17
 def train(self):
     throttle_itr, delta_throttle_itr = self.startup()
     throttle_time = 0.
     sampler_itr = itr = 0
     if self._eval:
         while self.ctrl.sampler_itr.value < 1:  # Sampler does eval first.
             time.sleep(THROTTLE_WAIT)
         traj_infos = drain_queue(self.traj_infos_queue, n_sentinel=1)
         self.store_diagnostics(0, 0, traj_infos, ())
         self.log_diagnostics(0, 0, 0)
     log_counter = 0
     while True:  # Run until sampler hits n_steps and sets ctrl.quit=True.
         with logger.prefix(f"opt_itr #{itr} "):
             while self.ctrl.sampler_itr.value < throttle_itr:
                 if self.ctrl.quit.value:
                     break
                 time.sleep(THROTTLE_WAIT)
                 throttle_time += THROTTLE_WAIT
             if self.ctrl.quit.value:
                 break
             if self.ctrl.opt_throttle is not None:
                 self.ctrl.opt_throttle.wait()
             throttle_itr += delta_throttle_itr
             opt_info = self.algo.optimize_agent(itr,
                 sampler_itr=self.ctrl.sampler_itr.value)
             self.agent.send_shared_memory()  # To sampler.
             sampler_itr = self.ctrl.sampler_itr.value
             traj_infos = (list() if self._eval else
                 drain_queue(self.traj_infos_queue))
             self.store_diagnostics(itr, sampler_itr, traj_infos, opt_info)
             if (sampler_itr // self.log_interval_itrs > log_counter):
                 if self._eval:
                     with self.ctrl.sampler_itr.get_lock():
                         traj_infos = drain_queue(self.traj_infos_queue, n_sentinel=1)
                     self.store_diagnostics(itr, sampler_itr, traj_infos, ())
                 self.log_diagnostics(itr, sampler_itr, throttle_time)
                 log_counter += 1
                 throttle_time = 0.
         itr += 1
     # Final log:
     sampler_itr = self.ctrl.sampler_itr.value
     traj_infos = drain_queue(self.traj_infos_queue)
     if traj_infos or not self._eval:
         self.store_diagnostics(itr, sampler_itr, traj_infos, ())
         self.log_diagnostics(itr, sampler_itr, throttle_time)
     self.shutdown()
Example #18
 def train(self):
     """
     Performs startup, then loops by alternating between
     ``sampler.obtain_samples()`` and ``algo.optimize_agent()``, logging
     diagnostics at the specified interval.
     """
     n_itr = self.startup()
     for itr in range(n_itr):
         with logger.prefix(f"itr #{itr} "):
             self.agent.sample_mode(itr)  # Might not be this agent sampling.
             samples, traj_infos = self.sampler.obtain_samples(itr)
             self.agent.train_mode(itr)
             opt_info = self.algo.optimize_agent(itr, samples)
             self.store_diagnostics(itr, traj_infos, opt_info)
             if (itr + 1) % self.log_interval_itrs == 0:
                 self.log_diagnostics(itr)
     self.shutdown()
Example #19
 def train(self):
     """
     Performs startup, then loops by alternating between
     ``sampler.obtain_samples()`` and ``algo.optimize_agent()``, logging
     diagnostics at the specified interval.
     """
     n_itr = self.startup()
     for itr in range(n_itr):
         logger.set_iteration(itr)
         with logger.prefix(f"itr #{itr} "):
             self.agent.sample_mode(
                 itr)  # Might not be this agent sampling.
             player_samples, player_traj_infos, observer_samples, observer_traj_infos = self.sampler.obtain_samples(
                 itr)
             self.agent.train_mode(itr)
             player_opt_info = ()
             observer_opt_info = ()
             if self.alt_train:
                 if self.agent.train_mask[0] and (itr % 2 == 0):
                     player_opt_info = self.player_algo.optimize_agent(
                         itr // 2, player_samples)
                 elif self.agent.train_mask[1]:
                     observer_opt_info = self.observer_algo.optimize_agent(
                         itr // 2, observer_samples)
             else:
                 if self.agent.train_mask[0]:
                     player_opt_info = self.player_algo.optimize_agent(
                         itr, player_samples)
                 if self.agent.train_mask[1]:
                     observer_opt_info = self.observer_algo.optimize_agent(
                         itr, observer_samples)
             self.store_diagnostics(itr, player_traj_infos,
                                    observer_traj_infos, player_opt_info,
                                    observer_opt_info)
             if (itr + 1) % self.log_interval_itrs == 0:
                 if self.wandb_log:
                     self.wandb_logging(itr)
                 self.log_diagnostics(itr)
     self.shutdown()
Example #20
    def train(self, return_buffer=False, check_running=True):
        """
        Performs startup, evaluates the initial agent, then loops by
        alternating between ``sampler.obtain_samples()`` and
        ``algo.optimize_agent()``.  Pauses to evaluate the agent at the
        specified log interval.
        """
        best_itr = 0
        # initial_pi_loss = None
        # initial_q_loss = None
        # min_save_itr = self.min_save_args['min_save_itr']
        # min_save_pi_loss_ratio = self.min_save_args['min_save_pi_loss_ratio']
        # min_save_q_loss_ratio = self.min_save_args['min_save_q_loss_ratio']
        eval_reward_avg_all = []

        n_itr = self.startup()

        # Evaluate first - initialize initial and best eval reward (before running random exploration) - load policy first, and then reset
        with logger.prefix(f"itr eval"):
            self.agent.load_state_dict()
            eval_traj_infos, eval_time = self.evaluate_agent(itr=0)
            ini_eval_reward_avg = self.get_eval_reward(eval_traj_infos)
            # self.log_diagnostics(0, eval_traj_infos, eval_time)
            self.agent.reset_model()
            print(f'Initial eval reward: {ini_eval_reward_avg}')
            logger.log(f'Initial eval reward: {ini_eval_reward_avg}')

            # Initialize best reward
            best_eval_reward_avg = ini_eval_reward_avg
            # best_eval_reward_avg = -1000	# dummy

        for itr in range(n_itr):
            logger.set_iteration(itr)

            with logger.prefix(f"itr #{itr} "):
                self.agent.sample_mode(itr)
                samples, traj_infos = self.sampler.obtain_samples(itr)
                self.agent.train_mode(itr)
                opt_info = self.algo.optimize_agent(itr, samples)

                save_cur = False
                # Find if in min_itr_learn (random exploration using bad policy)
                if len(opt_info.piLoss) == 0:
                    min_itr_learn = True
                else:
                    min_itr_learn = False
                self.store_diagnostics(itr, traj_infos, opt_info)

                # It is possible that save_cur never satisfied in all itrs, then do not update policy for this retrain
                if (itr + 1) % self.log_interval_itrs == 0:
                    eval_traj_infos, eval_time = self.evaluate_agent(itr)
                    eval_reward_avg = self.get_eval_reward(eval_traj_infos)

                    # Do not save at initial itrs
                    if not min_itr_learn:
                        eval_reward_avg_all += [eval_reward_avg]
                        eval_reward_window = eval_reward_avg_all[
                            -self.running_window_size:]

                        # Get running average
                        if len(eval_reward_avg_all
                               ) >= self.running_window_size:
                            running_avg = np.mean(eval_reward_window)
                        else:
                            running_avg = -1000  # dummy

                        # Get running std
                        s0 = sum(1 for a in eval_reward_window)
                        s1 = sum(a for a in eval_reward_window)
                        s2 = sum(a * a for a in eval_reward_window)
                        running_std = np.sqrt(
                            (s0 * s2 - s1 * s1) / (s0 * (s0 - 1)))

                        # Determine if saving current snapshot
                        if check_running and (
                                running_avg - ini_eval_reward_avg
                        ) > 0 and eval_reward_avg > best_eval_reward_avg and running_std < self.running_std_thres:
                            best_eval_reward_avg = eval_reward_avg
                            best_itr = itr
                            save_cur = True

                        elif not check_running and eval_reward_avg > best_eval_reward_avg:
                            best_eval_reward_avg = eval_reward_avg
                            best_itr = itr
                            save_cur = True

                    self.log_diagnostics(itr, eval_traj_infos, eval_time,
                                         save_cur)
                    if (itr + 1) % 10 == 0:
                        logger.log(f'Average eval reward: {eval_reward_avg}')
                        print(
                            f'Average eval reward at itr {itr}: {eval_reward_avg}'
                        )
        self.shutdown()

        if return_buffer:
            return best_itr, self.algo.replay_buffer_dict()
        else:
            return best_itr
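Example #20 computes the running standard deviation of the eval-reward window from the power sums `s0`, `s1`, `s2`; that formula equals the sample (ddof=1) standard deviation. A quick numerical check of the identity:

import numpy as np

window = [1.0, 2.0, 4.0, 7.0]  # Stand-in for eval_reward_window.
s0 = len(window)
s1 = sum(window)
s2 = sum(x * x for x in window)
running_std = np.sqrt((s0 * s2 - s1 * s1) / (s0 * (s0 - 1)))

assert np.isclose(running_std, np.std(window, ddof=1))  # Both give sqrt(7) here.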
Example #21
    def train(self):
        """
        Performs startup, then loops by alternating between
        ``sampler.obtain_samples()`` and ``algo.optimize_agent()``, logging
        diagnostics at the specified interval.
        """
        n_itr = self.startup()
        for itr in range(n_itr):
            logger.set_iteration(itr)
            with logger.prefix(f"itr #{itr} "):
                if itr % 200 == 0:
                    # try to log distribution gradient norm of agent
                    # gradients = []
                    policy_gradients = []
                    value_gradients = []
                    all_value_diffs = []
                    all_ratios = []
                    all_rewards = []
                    all_unnorm_rewards = []

                    num_iters = 100
                    all_returns = []
                    for i in range(num_iters):
                        samples, traj_infos = self.sampler.obtain_samples(itr)
                        returns = [ti.Return for ti in traj_infos]
                        all_returns.extend(returns)
                        # mb_grads, value_diffs, ratios, norm_rewards, unnorm_rewards = self.algo.compute_minibatch_gradients(samples)
                        # gradients.extend(mb_grads)
                        p_grads, v_grads, value_diffs, ratios, norm_rewards, unnorm_rewards = self.algo.compute_minibatch_gradients(
                            samples)
                        policy_gradients.extend(p_grads)
                        value_gradients.extend(v_grads)

                        all_value_diffs.extend(value_diffs)
                        all_rewards.extend(norm_rewards.numpy())
                        all_unnorm_rewards.extend(unnorm_rewards.numpy())
                        # all_ratios.extend(ratios)
                        # print('ratios', all_ratios)
                        if i % 10 == 0:
                            print('grad', i)
                    # average all gradients
                    # mean_gradient = np.mean(np.array(gradients), axis=0)
                    mean_p_grad = np.mean(np.array(policy_gradients), axis=0)
                    mean_v_grad = np.mean(np.array(value_gradients), axis=0)

                    # compute gradient noises
                    p_cosines = []
                    p_grad_noise_norms = []
                    v_cosines = []
                    v_grad_noise_norms = []
                    for i in range(num_iters):
                        p_noise = np.linalg.norm(policy_gradients[i] -
                                                 mean_p_grad)
                        v_noise = np.linalg.norm(value_gradients[i] -
                                                 mean_v_grad)
                        p_grad_noise_norms.append(p_noise)
                        v_grad_noise_norms.append(v_noise)
                        p_cos = -1 * (
                            cosine(policy_gradients[i], mean_p_grad) - 1)
                        v_cos = -1 * (cosine(value_gradients[i], mean_v_grad) -
                                      1)
                        p_cosines.append(p_cos)
                        v_cosines.append(v_cos)
                    print(p_grad_noise_norms)
                    print(v_grad_noise_norms)
                    np.save(
                        '/home/vincent/repos/rlpyt/log/policy_gradnoisenorms' +
                        str(itr) + '_' + str(self.seed), p_grad_noise_norms)
                    np.save(
                        '/home/vincent/repos/rlpyt/log/value_gradnoisenorms' +
                        str(itr) + '_' + str(self.seed), v_grad_noise_norms)

                    np.save(
                        '/home/vincent/repos/rlpyt/log/policy_cosines' +
                        str(itr) + '_' + str(self.seed), p_cosines)
                    np.save(
                        '/home/vincent/repos/rlpyt/log/value_cosines' +
                        str(itr) + '_' + str(self.seed), v_cosines)

                    np.save(
                        '/home/vincent/repos/rlpyt/log/value_diffs' +
                        str(itr) + '_' + str(self.seed), value_diffs)
                    np.save(
                        '/home/vincent/repos/rlpyt/log/norm_rewards' +
                        str(itr) + '_' + str(self.seed), all_rewards)
                    np.save(
                        '/home/vincent/repos/rlpyt/log/unnorm_rewards' +
                        str(itr) + '_' + str(self.seed), all_unnorm_rewards)
                    np.save(
                        '/home/vincent/repos/rlpyt/log/returns' + str(itr) +
                        '_' + str(self.seed), all_returns)

                self.agent.sample_mode(
                    itr)  # Might not be this agent sampling.
                samples, traj_infos = self.sampler.obtain_samples(itr)
                self.agent.train_mode(itr)
                opt_info, ratios, rews = self.algo.optimize_agent(itr, samples)
                if itr % 100 == 0:
                    np.save(
                        '/home/vincent/repos/rlpyt/log/adv_ratios' + str(itr) +
                        '_' + str(self.seed), ratios)
                    # np.save('/home/vincent/repos/rlpyt/log/rewards' + str(itr) + '_' + str(self.seed),
                    #         rews.numpy())
                self.store_diagnostics(itr, traj_infos, opt_info)
                if (itr + 1) % self.log_interval_itrs == 0:
                    self.log_diagnostics(itr)
        self.shutdown()
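Example #21 turns SciPy's cosine distance into a cosine similarity with `-1 * (cosine(a, b) - 1)`, i.e. `1 - distance` (the method appears to assume `scipy.spatial.distance.cosine`). A small check of that identity:

import numpy as np
from scipy.spatial.distance import cosine

a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 4.0, 7.0])

cos_sim = -1 * (cosine(a, b) - 1)  # As written in the example above.
expected = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))
assert np.isclose(cos_sim, expected)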