def _train(self, env, policy, pool):
    """Perform RL training.

    Args:
        env (`rllab.Env`): Environment used for training
        policy (`Policy`): Policy used for training
        pool (`PoolBase`): Sample pool to add samples to
    """
    self._init_training()
    self.sampler.initialize(env, policy, pool)

    evaluation_env = deep_clone(env) if self._eval_n_episodes else None
    # TODO: use Ezpickle to deep_clone???
    # evaluation_env = env

    with tf_utils.get_default_session().as_default():
        gt.rename_root('RLAlgorithm')
        gt.reset()
        gt.set_def_unique(False)

        for epoch in gt.timed_for(
                range(self._n_epochs + 1), save_itrs=True):
            logger.push_prefix('Epoch #%d | ' % epoch)

            for t in range(self._epoch_length):
                self.sampler.sample()
                if not self.sampler.batch_ready():
                    continue
                gt.stamp('sample')

                for i in range(self._n_train_repeat):
                    self._do_training(
                        iteration=t + epoch * self._epoch_length,
                        batch=self.sampler.random_batch())
                gt.stamp('train')

            self._evaluate(policy, evaluation_env)
            gt.stamp('eval')

            params = self.get_snapshot(epoch)
            logger.save_itr_params(epoch, params)

            time_itrs = gt.get_times().stamps.itrs
            time_eval = time_itrs['eval'][-1]
            time_total = gt.get_times().total
            time_train = time_itrs.get('train', [0])[-1]
            time_sample = time_itrs.get('sample', [0])[-1]

            logger.record_tabular('time-train', time_train)
            logger.record_tabular('time-eval', time_eval)
            logger.record_tabular('time-sample', time_sample)
            logger.record_tabular('time-total', time_total)
            logger.record_tabular('epoch', epoch)

            self.sampler.log_diagnostics()

            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()

        self.sampler.terminate()

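# The TODO above asks whether `deep_clone` could be replaced by pickling.
# Below is a minimal sketch of that idea; it assumes the environment is
# fully picklable, which is not guaranteed for rllab environments wrapping
# native simulators. `pickle_clone` is a hypothetical helper for
# illustration, not the repository's actual `deep_clone` implementation.
import pickle


def pickle_clone(env):
    """Hypothetical helper: clone an environment via a pickle round-trip."""
    return pickle.loads(pickle.dumps(env))
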
def log_diagnostics(self, paths):
    progs = [
        path["observations"][-1][-3] - path["observations"][0][-3]
        for path in paths
    ]
    logger.record_tabular('AverageForwardProgress', np.mean(progs))
    logger.record_tabular('MaxForwardProgress', np.max(progs))
    logger.record_tabular('MinForwardProgress', np.min(progs))
    logger.record_tabular('StdForwardProgress', np.std(progs))

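# A minimal, self-contained check of the forward-progress diagnostic above,
# using synthetic paths. It assumes observation index -3 holds the forward
# (x) coordinate of the agent, as in the locomotion task this method targets.
import numpy as np

if __name__ == '__main__':
    fake_paths = [
        {"observations": np.array([[1.0, 0.0, 0.0], [4.0, 0.0, 0.0]])},
        {"observations": np.array([[2.0, 0.0, 0.0], [3.0, 0.0, 0.0]])},
    ]
    progs = [p["observations"][-1][-3] - p["observations"][0][-3]
             for p in fake_paths]
    assert np.isclose(np.mean(progs), 2.0)  # (3.0 + 1.0) / 2
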
def log_diagnostics(self, batch):
    """Record diagnostic information.

    Records the mean and standard deviation of the Q-values and the
    squared Bellman residual (mean squared Bellman error) for a
    sample batch.

    Also calls the `draw` method of the plotter, if a plotter is defined.
    """
    feeds = self._get_feed_dict(batch)
    qf, bellman_residual = self._sess.run(
        [self._q_values, self._bellman_residual], feeds)

    logger.record_tabular('qf-avg', np.mean(qf))
    logger.record_tabular('qf-std', np.std(qf))
    logger.record_tabular('mean-sq-bellman-error', bellman_residual)

    self.policy.log_diagnostics(batch)
    if self.plotter:
        self.plotter.draw()

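# For reference, a hedged sketch of how the two tensors fetched above could
# be defined in TensorFlow 1.x. The placeholder names and the stand-in dense
# Q-network are assumptions for illustration; the repository builds
# `self._q_values` and `self._bellman_residual` elsewhere from its own
# Q-function and Bellman targets.
import tensorflow as tf

observations_ph = tf.placeholder(tf.float32, shape=(None, 4))
q_values = tf.layers.dense(observations_ph, 1)       # stand-in Q-network
ys_ph = tf.placeholder(tf.float32, shape=(None, 1))  # Bellman targets
bellman_residual = tf.reduce_mean(
    tf.squared_difference(ys_ph, q_values))          # mean squared error
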
def _evaluate(self, policy, evaluation_env):
    """Perform evaluation for the current policy."""
    if self._eval_n_episodes < 1:
        return

    # TODO: max_path_length should be a property of environment.
    paths = rollouts(evaluation_env, policy,
                     self.sampler._max_path_length, self._eval_n_episodes)

    total_returns = [path['rewards'].sum() for path in paths]
    episode_lengths = [len(p['rewards']) for p in paths]

    logger.record_tabular('return-average', np.mean(total_returns))
    logger.record_tabular('return-min', np.min(total_returns))
    logger.record_tabular('return-max', np.max(total_returns))
    logger.record_tabular('return-std', np.std(total_returns))
    logger.record_tabular('episode-length-avg', np.mean(episode_lengths))
    logger.record_tabular('episode-length-min', np.min(episode_lengths))
    logger.record_tabular('episode-length-max', np.max(episode_lengths))
    logger.record_tabular('episode-length-std', np.std(episode_lengths))

    # TODO: figure out how to pass log_diagnostics through
    evaluation_env.log_diagnostics(paths)
    if self._eval_render:
        evaluation_env.render(paths)

    if self.sampler.batch_ready():
        batch = self.sampler.random_batch()
        self.log_diagnostics(batch)

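# A minimal sketch of what the `rollouts` helper used above might look like,
# assuming an rllab-style `env.reset()` / `env.step(action)` interface and a
# policy exposing `get_action(observation)`. The real helper may also record
# actions and agent/env infos; only the fields read by `_evaluate`
# ('observations' and 'rewards') are filled in here.
import numpy as np


def rollouts_sketch(env, policy, path_length, n_paths):
    paths = []
    for _ in range(n_paths):
        observation = env.reset()
        observations, rewards = [], []
        for _ in range(path_length):
            action, _ = policy.get_action(observation)
            next_observation, reward, done, _ = env.step(action)
            observations.append(observation)
            rewards.append(reward)
            observation = next_observation
            if done:
                break
        paths.append({'observations': np.array(observations),
                      'rewards': np.array(rewards)})
    return paths
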
def log_diagnostics(self):
    logger.record_tabular('pool-size', self.pool.size)

def log_diagnostics(self):
    super(SimpleSampler, self).log_diagnostics()
    logger.record_tabular('max-path-return', self._max_path_return)
    logger.record_tabular('last-path-return', self._last_path_return)
    logger.record_tabular('episodes', self._n_episodes)
    logger.record_tabular('total-samples', self._total_samples)
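
# For context, a hedged sketch of how the counters logged above could be
# maintained inside a single `sample()` step. The attribute names mirror the
# ones this class logs, but the body is an assumed illustration, not the
# repository's actual sampler loop.
def _sample_step_sketch(self):
    action, _ = self.policy.get_action(self._current_observation)
    next_observation, reward, terminal, _ = self.env.step(action)
    self._path_return += reward
    self._total_samples += 1
    if terminal:
        self._max_path_return = max(self._max_path_return,
                                    self._path_return)
        self._last_path_return = self._path_return
        self._path_return = 0
        self._n_episodes += 1
        self._current_observation = self.env.reset()
    else:
        self._current_observation = next_observation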