Code example #1
    def _train(self, env, policy, pool):
        """Perform RL training.

        Args:
            env (`rllab.Env`): Environment used for training
            policy (`Policy`): Policy used for training
            pool (`PoolBase`): Sample pool to add samples to
        """
        self._init_training()
        self.sampler.initialize(env, policy, pool)

        evaluation_env = deep_clone(env) if self._eval_n_episodes else None
        # TODO: use Ezpickle to deep_clone???
        # evaluation_env = env

        with tf_utils.get_default_session().as_default():
            gt.rename_root('RLAlgorithm')
            gt.reset()
            gt.set_def_unique(False)

            for epoch in gt.timed_for(
                    range(self._n_epochs + 1), save_itrs=True):
                logger.push_prefix('Epoch #%d | ' % epoch)

                for t in range(self._epoch_length):
                    self.sampler.sample()
                    if not self.sampler.batch_ready():
                        continue
                    gt.stamp('sample')

                    for i in range(self._n_train_repeat):
                        self._do_training(
                            iteration=t + epoch * self._epoch_length,
                            batch=self.sampler.random_batch())
                    gt.stamp('train')

                self._evaluate(policy, evaluation_env)
                gt.stamp('eval')

                params = self.get_snapshot(epoch)
                logger.save_itr_params(epoch, params)

                time_itrs = gt.get_times().stamps.itrs
                time_eval = time_itrs['eval'][-1]
                time_total = gt.get_times().total
                time_train = time_itrs.get('train', [0])[-1]
                time_sample = time_itrs.get('sample', [0])[-1]

                logger.record_tabular('time-train', time_train)
                logger.record_tabular('time-eval', time_eval)
                logger.record_tabular('time-sample', time_sample)
                logger.record_tabular('time-total', time_total)
                logger.record_tabular('epoch', epoch)

                self.sampler.log_diagnostics()

                logger.dump_tabular(with_prefix=False)
                logger.pop_prefix()

            self.sampler.terminate()
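
The record_tabular / dump_tabular calls above only produce output if the logger has a tabular output attached. Below is a minimal sketch of that wiring, assuming the rllab-style logger (rllab.misc.logger) used throughout these examples; the file and directory names are placeholders.

from rllab.misc import logger

# Rows written by dump_tabular() go to this CSV file.
logger.add_tabular_output('progress.csv')
# Directory and mode used by save_itr_params() for the per-epoch snapshots.
logger.set_snapshot_dir('./snapshots')
logger.set_snapshot_mode('last')

# ... algorithm._train(env, policy, pool) would run here ...

logger.remove_tabular_output('progress.csv')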
Code example #2
File: hex_env.py Project: jinparksj/hex
    def log_diagnostics(self, paths):
        progs = [
            path["observations"][-1][-3] - path["observations"][0][-3]
            for path in paths
        ]
        logger.record_tabular('AverageForwardProgress', np.mean(progs))
        logger.record_tabular('MaxForwardProgress', np.max(progs))
        logger.record_tabular('MinForwardProgress', np.min(progs))
        logger.record_tabular('StdForwardProgress', np.std(progs))
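
The forward-progress metrics above assume each path is a dict whose "observations" rows store the agent's forward position at index -3, so the expressions measure displacement along the walking direction. A hedged usage sketch with synthetic paths follows; the observation layout is an assumption made only for illustration.

import numpy as np

# Two fake paths; column -3 of each observation row plays the role of the
# forward position, which is an assumption for this illustration.
paths = [
    {"observations": np.array([[0.0, 0.0, 1.0, 0.0, 0.0],
                               [0.0, 0.0, 2.5, 0.0, 0.0]])},
    {"observations": np.array([[0.0, 0.0, 0.5, 0.0, 0.0],
                               [0.0, 0.0, 1.0, 0.0, 0.0]])},
]

progs = [p["observations"][-1][-3] - p["observations"][0][-3] for p in paths]
print(np.mean(progs), np.max(progs), np.min(progs), np.std(progs))  # 1.0 1.5 0.5 0.5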
Code example #3
    def log_diagnostics(self, batch):
        """Record diagnostic information.

        Records the mean and standard deviation of the Q-function and the
        squared Bellman residual (mean squared Bellman error) for a sample
        batch.

        Also calls the `draw` method of the plotter, if a plotter is defined.
        """

        feeds = self._get_feed_dict(batch)
        qf, bellman_residual = self._sess.run(
            [self._q_values, self._bellman_residual], feeds)

        logger.record_tabular('qf-avg', np.mean(qf))
        logger.record_tabular('qf-std', np.std(qf))
        logger.record_tabular('mean-sq-bellman-error', bellman_residual)

        self.policy.log_diagnostics(batch)
        if self.plotter:
            self.plotter.draw()
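
The "mean squared Bellman error" reported here is, in the usual Q-learning sense, the mean of (r + gamma * max_a' Q(s', a') - Q(s, a))^2 over the batch. Below is a rough numpy sketch under that assumption; the exact residual computed by self._bellman_residual in the TF graph may differ.

import numpy as np

gamma = 0.99
q_sa = np.array([1.2, 0.8, 1.5])                          # Q(s, a) for a sample batch
q_next = np.array([[1.0, 1.3], [0.9, 0.7], [1.4, 1.6]])   # Q(s', a') for each next action
rewards = np.array([0.5, 0.0, 1.0])
terminals = np.array([0.0, 0.0, 1.0])                     # 1.0 marks a terminal transition

# Standard Q-learning target; terminal transitions do not bootstrap.
targets = rewards + (1.0 - terminals) * gamma * q_next.max(axis=1)
mean_sq_bellman_error = np.mean((targets - q_sa) ** 2)
print(mean_sq_bellman_error)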
Code example #4
    def _evaluate(self, policy, evaluation_env):
        """Perform evaluation for the current policy."""

        if self._eval_n_episodes < 1:
            return

        # TODO: max_path_length should be a property of environment.
        paths = rollouts(evaluation_env, policy, self.sampler._max_path_length,
                         self._eval_n_episodes)

        total_returns = [path['rewards'].sum() for path in paths]
        episode_lengths = [len(p['rewards']) for p in paths]

        logger.record_tabular('return-average', np.mean(total_returns))
        logger.record_tabular('return-min', np.min(total_returns))
        logger.record_tabular('return-max', np.max(total_returns))
        logger.record_tabular('return-std', np.std(total_returns))
        logger.record_tabular('episode-length-avg', np.mean(episode_lengths))
        logger.record_tabular('episode-length-min', np.min(episode_lengths))
        logger.record_tabular('episode-length-max', np.max(episode_lengths))
        logger.record_tabular('episode-length-std', np.std(episode_lengths))

        # TODO: figure out how to pass log_diagnostics through
        evaluation_env.log_diagnostics(paths)
        if self._eval_render:
            evaluation_env.render(paths)

        if self.sampler.batch_ready():
            batch = self.sampler.random_batch()
            self.log_diagnostics(batch)
Code example #5
File: sampler.py Project: jinparksj/hex
    def log_diagnostics(self):
        logger.record_tabular('pool-size', self.pool.size)
Code example #6
File: sampler.py Project: jinparksj/hex
    def log_diagnostics(self):
        super(SimpleSampler, self).log_diagnostics()
        logger.record_tabular('max-path-return', self._max_path_return)
        logger.record_tabular('last-path-return', self._last_path_return)
        logger.record_tabular('episodes', self._n_episodes)
        logger.record_tabular('total-samples', self._total_samples)
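
For context, here is a hedged sketch of how counters like the ones logged above might be maintained as a sampler steps the environment. This is illustrative only, not SimpleSampler's actual implementation.

class _CounterSketch:
    """Illustrative bookkeeping for max/last path return, episodes, and samples."""

    def __init__(self):
        self._path_return = 0.0
        self._max_path_return = float('-inf')
        self._last_path_return = 0.0
        self._n_episodes = 0
        self._total_samples = 0

    def observe(self, reward, terminal):
        # Called once per environment step with the reward and terminal flag.
        self._total_samples += 1
        self._path_return += reward
        if terminal:
            self._n_episodes += 1
            self._max_path_return = max(self._max_path_return, self._path_return)
            self._last_path_return = self._path_return
            self._path_return = 0.0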