def train(self, start_epoch=0):
    self.pretrain()
    if start_epoch == 0:
        params = self.get_epoch_snapshot(-1)
        logger.save_itr_params(-1, params)
    self.training_mode(False)
    self._n_env_steps_total = start_epoch * self.num_env_steps_per_epoch
    gt.reset()
    gt.set_def_unique(False)
    if self.collection_mode == 'online':
        self.train_online(start_epoch=start_epoch)
    elif self.collection_mode == 'batch':
        self.train_batch(start_epoch=start_epoch)
    else:
        raise TypeError("Invalid collection_mode: {}".format(
            self.collection_mode))
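
# The gt.reset()/gt.set_def_unique(False) calls above configure the gtimer
# package; train_online/train_batch are expected to stamp 'sample' and
# 'train' once per iteration so that _try_to_eval below can read the timings
# back. A minimal, self-contained sketch of that pattern (the loop body here
# is a hypothetical stand-in, not the actual rlkit training loop):
import time
import gtimer as gt

def _gtimer_sketch():
    gt.reset()
    gt.set_def_unique(False)
    for _epoch in gt.timed_for(range(3), save_itrs=True):
        time.sleep(0.01)   # stand-in for environment sampling
        gt.stamp('sample')
        time.sleep(0.02)   # stand-in for gradient updates
        gt.stamp('train')
    # Per-iteration timings, keyed by stamp name, as _try_to_eval reads them.
    times_itrs = gt.get_times().stamps.itrs
    print(times_itrs['sample'][-1], times_itrs['train'][-1])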
def _try_to_eval(self, epoch, eval_paths=None):
    logger.save_extra_data(self.get_extra_data_to_save(epoch))
    if self._can_evaluate():
        self.evaluate(epoch, eval_paths=eval_paths)

        params = self.get_epoch_snapshot(epoch)
        logger.save_itr_params(epoch, params)
        table_keys = logger.get_table_key_set()
        if self._old_table_keys is not None:
            assert table_keys == self._old_table_keys, (
                "Table keys cannot change from iteration to iteration."
            )
        self._old_table_keys = table_keys

        logger.record_tabular(
            "Number of train steps total",
            self._n_train_steps_total,
        )
        logger.record_tabular(
            "Number of env steps total",
            self._n_env_steps_total,
        )
        logger.record_tabular(
            "Number of rollouts total",
            self._n_rollouts_total,
        )

        times_itrs = gt.get_times().stamps.itrs
        train_time = times_itrs['train'][-1]
        sample_time = times_itrs['sample'][-1]
        # The 'eval' stamp does not exist until the first evaluation has run.
        try:
            eval_time = times_itrs['eval'][-1]
        except KeyError:
            eval_time = 0
        epoch_time = train_time + sample_time + eval_time
        total_time = gt.get_times().total

        logger.record_tabular('Train Time (s)', train_time)
        logger.record_tabular('(Previous) Eval Time (s)', eval_time)
        logger.record_tabular('Sample Time (s)', sample_time)
        logger.record_tabular('Epoch Time (s)', epoch_time)
        logger.record_tabular('Total Train Time (s)', total_time)

        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
    else:
        logger.log("Skipping eval for now.")
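
# Why the key-set assertion above matters: a CSV-backed tabular logger writes
# its header once, so a later row with a different key set would silently
# misalign the columns. A minimal sketch of that failure mode (hypothetical
# MiniTabularLogger, not the rlkit logger):
import csv

class MiniTabularLogger:
    def __init__(self, path):
        self.path = path
        self.keys = None   # frozen after the first dump

    def dump_row(self, row):
        if self.keys is None:
            # First dump fixes the column set and writes the header.
            self.keys = sorted(row)
            with open(self.path, 'w', newline='') as f:
                csv.DictWriter(f, fieldnames=self.keys).writeheader()
        assert sorted(row) == self.keys, (
            "Table keys cannot change from iteration to iteration.")
        with open(self.path, 'a', newline='') as f:
            csv.DictWriter(f, fieldnames=self.keys).writerow(row)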
def test_epoch(
        self,
        epoch,
        save_reconstruction=True,
        save_vae=True,
        from_rl=False,
):
    self.model.eval()
    losses = []
    log_probs = []
    kles = []
    zs = []
    for batch_idx in range(10):
        next_obs = self.get_batch(train=False)
        reconstructions, obs_distribution_params, latent_distribution_params = self.model(
            next_obs)
        log_prob = self.model.logprob(next_obs, obs_distribution_params)
        kle = self.model.kl_divergence(latent_distribution_params)
        # Beta-weighted negative ELBO: KL penalty minus reconstruction
        # log-likelihood.
        loss = self.beta * kle - log_prob

        encoder_mean = latent_distribution_params[0]
        z_data = ptu.get_numpy(encoder_mean.cpu())
        for i in range(len(z_data)):
            zs.append(z_data[i, :])
        losses.append(loss.item())
        log_probs.append(log_prob.item())
        kles.append(kle.item())

        if batch_idx == 0 and save_reconstruction:
            # Save the first few inputs next to their reconstructions.
            n = min(next_obs.size(0), 8)
            comparison = torch.cat([
                next_obs[:n].narrow(start=0, length=self.imlength, dim=1)
                    .contiguous().view(
                        -1, self.input_channels, self.imsize, self.imsize
                    ),
                reconstructions.view(
                    self.batch_size,
                    self.input_channels,
                    self.imsize,
                    self.imsize,
                )[:n]
            ])
            save_dir = osp.join(logger.get_snapshot_dir(), 'r%d.png' % epoch)
            save_image(comparison.data.cpu(), save_dir, nrow=n)

    # Record the empirical latent distribution over the test batches.
    zs = np.array(zs)
    self.model.dist_mu = zs.mean(axis=0)
    self.model.dist_std = zs.std(axis=0)

    if from_rl:
        self.vae_logger_stats_for_rl['Test VAE Epoch'] = epoch
        self.vae_logger_stats_for_rl['Test VAE Log Prob'] = np.mean(log_probs)
        self.vae_logger_stats_for_rl['Test VAE KL'] = np.mean(kles)
        self.vae_logger_stats_for_rl['Test VAE loss'] = np.mean(losses)
        self.vae_logger_stats_for_rl['VAE Beta'] = self.beta
    else:
        for key, value in self.debug_statistics().items():
            logger.record_tabular(key, value)
        logger.record_tabular("test/Log Prob", np.mean(log_probs))
        logger.record_tabular("test/KL", np.mean(kles))
        logger.record_tabular("test/loss", np.mean(losses))
        logger.record_tabular("beta", self.beta)
        logger.dump_tabular()
        if save_vae:
            logger.save_itr_params(epoch, self.model)  # slow...
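
# The loss computed above is the beta-weighted negative ELBO,
# beta * KL - log_prob. A minimal, self-contained version for a
# diagonal-Gaussian encoder and a Bernoulli decoder (hypothetical helper;
# in the trainer these terms come from self.model.logprob and
# self.model.kl_divergence):
import torch
import torch.nn.functional as F

def beta_vae_loss(x, recon_logits, mu, logvar, beta=1.0):
    batch_size = x.size(0)
    # Bernoulli log-likelihood of the inputs under the decoder, per example.
    log_prob = -F.binary_cross_entropy_with_logits(
        recon_logits, x, reduction='sum') / batch_size
    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian, per example.
    kle = -0.5 * torch.sum(
        1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
    return beta * kle - log_prob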