def collect_paths(self, idx, epoch, eval_task=False):
    self.task_idx = idx
    dprint('Task:', idx)
    self.env.reset_task(idx)
    # if eval_task:
    #     num_evals = self.num_evals
    # else:
    num_evals = 1
    paths = []
    for _ in range(num_evals):
        paths += self.obtain_eval_paths(idx, eval_task=eval_task,
                                        deterministic=True)
    # goal = self.env._goal
    # for path in paths:
    #     path['goal'] = goal  # goal

    # save the paths for visualization, only useful for point mass
    if self.dump_eval_paths:
        split = 'test' if eval_task else 'train'
        logger.save_extra_data(
            paths,
            path='eval_trajectories/{}-task{}-epoch{}'.format(
                split, idx, epoch))

    return paths
def _try_to_eval(self, epoch=0):
    if epoch % self.save_extra_data_interval == 0:
        logger.save_extra_data(self.get_extra_data_to_save(epoch), epoch)
    if self._can_evaluate():
        self.evaluate(epoch)

        params = self.get_epoch_snapshot(epoch)
        logger.save_itr_params(epoch, params)
        table_keys = logger.get_table_key_set()
        if self._old_table_keys is not None:
            assert table_keys == self._old_table_keys, (
                "Table keys cannot change from iteration to iteration.")
        self._old_table_keys = table_keys

        logger.record_tabular(
            "Number of train steps total",
            self._n_train_steps_total,
        )
        logger.record_tabular(
            "Number of env steps total",
            self._n_env_steps_total,
        )
        logger.record_tabular(
            "Number of rollouts total",
            self._n_rollouts_total,
        )
        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
    else:
        logger.log("Skipping eval for now.")
def train_vae_and_update_variant(variant):
    from rlkit.core import logger
    skewfit_variant = variant["skewfit_variant"]
    train_vae_variant = variant["train_vae_variant"]
    if skewfit_variant.get("vae_path", None) is None:
        logger.remove_tabular_output("progress.csv",
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output("vae_progress.csv",
                                  relative_to_snapshot_dir=True)
        vae, vae_train_data, vae_test_data = train_vae(train_vae_variant,
                                                       return_data=True)
        if skewfit_variant.get("save_vae_data", False):
            skewfit_variant["vae_train_data"] = vae_train_data
            skewfit_variant["vae_test_data"] = vae_test_data
        logger.save_extra_data(vae, "vae.pkl", mode="pickle")
        logger.remove_tabular_output("vae_progress.csv",
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output("progress.csv",
                                  relative_to_snapshot_dir=True)
        skewfit_variant["vae_path"] = vae  # just pass the VAE directly
    else:
        if skewfit_variant.get("save_vae_data", False):
            vae_train_data, vae_test_data, info = generate_vae_dataset(
                train_vae_variant["generate_vae_dataset_kwargs"])
            skewfit_variant["vae_train_data"] = vae_train_data
            skewfit_variant["vae_test_data"] = vae_test_data
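# Usage sketch (illustrative, not from the original source): the nested
# variant layout that train_vae_and_update_variant expects, inferred from
# the keys it reads above. All concrete values here are hypothetical.
example_variant = {
    "skewfit_variant": {
        "vae_path": None,       # None => train a fresh VAE first
        "save_vae_data": True,  # stash the datasets on the variant
    },
    "train_vae_variant": {
        "generate_vae_dataset_kwargs": {},  # forwarded to generate_vae_dataset
        # plus whatever train_vae itself requires (beta, num_epochs, ...)
    },
}
# train_vae_and_update_variant(example_variant)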
def _try_to_eval(self, epoch):
    if self._can_evaluate():
        # save if it's time to save
        if epoch % self.freq_saving == 0:
            logger.save_extra_data(self.get_extra_data_to_save(epoch))
            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)

        self.evaluate(epoch)

        logger.record_tabular(
            "Number of train calls total",
            self._n_train_steps_total,
        )
        times_itrs = gt.get_times().stamps.itrs
        train_time = times_itrs['train'][-1]
        sample_time = times_itrs['sample'][-1]
        eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
        epoch_time = train_time + sample_time + eval_time
        total_time = gt.get_times().total

        logger.record_tabular('Train Time (s)', train_time)
        logger.record_tabular('(Previous) Eval Time (s)', eval_time)
        logger.record_tabular('Sample Time (s)', sample_time)
        logger.record_tabular('Epoch Time (s)', epoch_time)
        logger.record_tabular('Total Train Time (s)', total_time)

        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
    else:
        logger.log("Skipping eval for now.")
def collect_paths(self, idx, epoch, run):
    self.task_idx = idx
    self.env.reset_task(idx)
    self.agent.clear_z()

    paths = []
    num_transitions = 0
    num_trajs = 0
    while num_transitions < self.num_steps_per_eval:
        path, num = self.sampler.obtain_samples(
            deterministic=self.eval_deterministic,
            max_samples=self.num_steps_per_eval - num_transitions,
            max_trajs=1,
            accum_context=True)
        paths += path
        num_transitions += num
        num_trajs += 1
        if num_trajs >= self.num_exp_traj_eval:
            self.agent.infer_posterior(self.agent.context)

    if self.sparse_rewards:
        for p in paths:
            # np.stack needs a sequence, not a generator
            sparse_rewards = np.stack(
                [e['sparse_reward'] for e in p['env_infos']]).reshape(-1, 1)
            p['rewards'] = sparse_rewards

    # save the paths for visualization, only useful for point mass
    if self.dump_eval_paths:
        logger.save_extra_data(
            paths,
            path='eval_trajectories/task{}-epoch{}-run{}'.format(
                idx, epoch, run))

    return paths
def train_vae_and_update_variant(variant):
    from rlkit.core import logger
    skewfit_variant = variant['skewfit_variant']
    train_vae_variant = variant['train_vae_variant']
    if skewfit_variant.get('vae_path', None) is None:
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('vae_progress.csv',
                                  relative_to_snapshot_dir=True)
        vae, vae_train_data, vae_test_data = train_vae(
            train_vae_variant, variant['other_variant'], return_data=True)
        if skewfit_variant.get('save_vae_data', False):
            skewfit_variant['vae_train_data'] = vae_train_data
            skewfit_variant['vae_test_data'] = vae_test_data
        logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
        logger.remove_tabular_output(
            'vae_progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        skewfit_variant['vae_path'] = vae  # just pass the VAE directly
    else:
        if skewfit_variant.get('save_vae_data', False):
            vae_train_data, vae_test_data, info = generate_vae_dataset(
                train_vae_variant['generate_vae_dataset_kwargs'])
            skewfit_variant['vae_train_data'] = vae_train_data
            skewfit_variant['vae_test_data'] = vae_test_data
def experiment(variant):
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset(
        **variant['get_data_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    m = ConvVAE(representation_size,
                input_channels=3,
                **variant['conv_vae_kwargs'])
    if ptu.gpu_enabled():
        m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(epoch,
                     save_reconstruction=should_save_imgs,
                     save_scatterplot=should_save_imgs)
        if should_save_imgs:
            t.dump_samples(epoch)
def train_vae(variant, return_data=False):
    from rlkit.misc.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger

    beta = variant["beta"]
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    variant['generate_vae_dataset_kwargs'][
        'use_linear_dynamics'] = use_linear_dynamics
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])

    if use_linear_dynamics:
        action_dim = train_dataset.data['actions'].shape[2]
    else:
        action_dim = 0
    model = get_vae(variant, action_dim)

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)

    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
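# Usage sketch (illustrative, not from the original source): a minimal
# variant for the train_vae above, listing only the keys the function
# actually reads. All values are hypothetical placeholders.
example_vae_variant = {
    'beta': 0.5,
    'num_epochs': 100,
    'save_period': 10,
    'generate_vae_dataset_kwargs': {},  # use_linear_dynamics is injected above
    'algo_kwargs': {},                  # forwarded to the trainer class
    # optional keys, with the defaults used above:
    # 'use_linear_dynamics': False,
    # 'generate_vae_data_fctn': generate_vae_dataset,
    # 'vae_trainer_class': ConvVAETrainer,
    # 'dump_skew_debug_plots': False,
    # 'beta_schedule_kwargs': {...},
}
# model = train_vae(example_vae_variant)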
def evaluate(self, epoch):
    statistics = OrderedDict()
    statistics.update(self.eval_statistics)
    self.eval_statistics = None

    # statistics.update(eval_util.get_generic_path_information(
    #     self._exploration_paths, stat_prefix="Exploration",
    # ))
    for mode in ['meta_train', 'meta_test']:
        logger.log("Collecting samples for evaluation")
        test_paths = self.obtain_eval_samples(epoch, mode=mode)
        statistics.update(
            eval_util.get_generic_path_information(
                test_paths,
                stat_prefix="Test " + mode,
            ))
        # print(statistics.keys())
        if hasattr(self.env, "log_diagnostics"):
            self.env.log_diagnostics(test_paths)
        if hasattr(self.env, "log_statistics"):
            log_stats = self.env.log_statistics(test_paths)
            new_log_stats = OrderedDict(
                (k + ' ' + mode, v) for k, v in log_stats.items())
            statistics.update(new_log_stats)

        average_returns = rlkit.core.eval_util.get_average_returns(test_paths)
        statistics['AverageReturn ' + mode] = average_returns

        if self.render_eval_paths:
            self.env.render_paths(test_paths)

    # meta_test_this_epoch = statistics['Percent_Solved meta_test']
    meta_test_this_epoch = statistics['AverageReturn meta_test']
    if meta_test_this_epoch >= self.best_meta_test:
        # make sure you set save_algorithm to true then call save_extra_data
        prev_save_alg = self.save_algorithm
        self.save_algorithm = True
        if self.save_best:
            if epoch > self.save_best_after_epoch:
                # drop the replay buffer from the snapshot to keep it small
                temp = self.replay_buffer
                self.replay_buffer = None
                logger.save_extra_data(self.get_extra_data_to_save(epoch),
                                       'best_meta_test.pkl')
                self.replay_buffer = temp
                self.best_meta_test = meta_test_this_epoch
                print('\n\nSAVED ALG AT EPOCH %d\n\n' % epoch)
        self.save_algorithm = prev_save_alg

    for key, value in statistics.items():
        logger.record_tabular(key, value)

    if self.plotter:
        self.plotter.draw()
def _try_to_eval(self, epoch):
    logger.save_extra_data(self.get_extra_data_to_save(epoch))
    if self._can_evaluate():
        if self.environment_farming:
            # Create a new eval_sampler for each evaluation in order to
            # avoid the released-environment problem.
            env_for_eval_sampler = self.farmer.force_acq_env()
            print(env_for_eval_sampler)
            self.eval_sampler = InPlacePathSampler(
                env=env_for_eval_sampler,
                policy=self.eval_policy,
                max_samples=self.num_steps_per_eval + self.max_path_length,
                max_path_length=self.max_path_length,
            )

        self.evaluate(epoch)

        if self.environment_farming:
            # Adding env back to free_env list (guarded so this only runs
            # when an env was actually acquired above)
            self.farmer.add_free_env(env_for_eval_sampler)

        params = self.get_epoch_snapshot(epoch)
        logger.save_itr_params(epoch, params)
        table_keys = logger.get_table_key_set()
        if self._old_table_keys is not None:
            assert table_keys == self._old_table_keys, (
                "Table keys cannot change from iteration to iteration.")
        self._old_table_keys = table_keys

        logger.record_tabular(
            "Number of train steps total",
            self._n_train_steps_total,
        )
        logger.record_tabular(
            "Number of env steps total",
            self._n_env_steps_total,
        )
        logger.record_tabular(
            "Number of rollouts total",
            self._n_rollouts_total,
        )

        times_itrs = gt.get_times().stamps.itrs
        train_time = times_itrs['train'][-1]
        sample_time = times_itrs['sample'][-1]
        eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
        epoch_time = train_time + sample_time + eval_time
        total_time = gt.get_times().total

        logger.record_tabular('Train Time (s)', train_time)
        logger.record_tabular('(Previous) Eval Time (s)', eval_time)
        logger.record_tabular('Sample Time (s)', sample_time)
        logger.record_tabular('Epoch Time (s)', epoch_time)
        logger.record_tabular('Total Train Time (s)', total_time)

        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
    else:
        logger.log("Skipping eval for now.")
def _start_epoch(self, epoch):
    self._epoch_start_time = time.time()
    self._exploration_paths = []
    self._do_train_time = 0
    logger.push_prefix('Iteration #%d | ' % epoch)
    if epoch in self.save_extra_manual_beginning_epoch_list:
        logger.save_extra_data(
            self.get_extra_data_to_save(epoch),
            file_name='extra_snapshot_beginning_itr{}'.format(epoch),
            mode='cloudpickle',
        )
def _try_to_eval(self, epoch):
    if epoch % self.freq_saving == 0:
        logger.save_extra_data(self.get_extra_data_to_save(epoch))
    if self._can_evaluate():
        self.evaluate(epoch)

        if epoch % self.freq_saving == 0:
            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)
        table_keys = logger.get_table_key_set()
        # The table-key consistency check is disabled in this variant:
        # if self._old_table_keys is not None:
        #     assert table_keys == self._old_table_keys, (
        #         "Table keys cannot change from iteration to iteration.")
        # self._old_table_keys = table_keys

        logger.record_tabular(
            "Number of train steps total",
            self._n_train_steps_total,
        )
        logger.record_tabular(
            "Number of env steps total",
            self._n_env_steps_total,
        )
        logger.record_tabular(
            "Number of rollouts total",
            self._n_rollouts_total,
        )

        times_itrs = gt.get_times().stamps.itrs
        train_time = times_itrs['train'][-1]
        sample_time = times_itrs['sample'][-1]
        eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
        epoch_time = train_time + sample_time + eval_time
        total_time = gt.get_times().total

        logger.record_tabular('Train Time (s)', train_time)
        logger.record_tabular('(Previous) Eval Time (s)', eval_time)
        logger.record_tabular('Sample Time (s)', sample_time)
        logger.record_tabular('Epoch Time (s)', epoch_time)
        logger.record_tabular('Total Train Time (s)', total_time)

        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
    else:
        logger.log("Skipping eval for now.")
def collect_paths(self, idx, epoch, run):
    self.agent.clear_z()
    paths = []
    num_transitions = 0
    num_trajs = 0
    init_context = None
    infer_posterior_at_start = False
    while num_transitions < self.num_steps_per_eval:
        # We follow the PEARL protocol and never update the posterior or
        # resample z within an episode during evaluation.
        if idx in self.fake_task_idx_to_z:
            initialized_z_reward = self.fake_task_idx_to_z[idx]
        else:
            initialized_z_reward = None
        loop_paths, num = self.sampler.obtain_samples(
            deterministic=self.eval_deterministic,
            max_samples=self.num_steps_per_eval - num_transitions,
            max_trajs=1,
            accum_context=True,
            initial_context=init_context,
            task_idx=idx,
            resample_latent_period=self.exploration_resample_latent_period,  # PEARL had this=0.
            update_posterior_period=0,  # following PEARL protocol
            infer_posterior_at_start=infer_posterior_at_start,
            initialized_z_reward=initialized_z_reward,
            use_predicted_reward=initialized_z_reward is not None,
        )
        paths += loop_paths
        num_transitions += num
        num_trajs += 1
        # accumulate contexts across rollouts
        init_context = paths[-1]['context']  # TODO clean hack
        if num_trajs >= self.num_exp_traj_eval:
            infer_posterior_at_start = True

    if self.sparse_rewards:
        for p in paths:
            # np.stack needs a sequence, not a generator
            sparse_rewards = np.stack(
                [e['sparse_reward'] for e in p['env_infos']]).reshape(-1, 1)
            p['rewards'] = sparse_rewards

    goal = self.env._goal
    for path in paths:
        path['goal'] = goal  # goal

    # save the paths for visualization, only useful for point mass
    if self.dump_eval_paths and epoch >= 0:
        logger.save_extra_data(
            paths,
            file_name='eval_trajectories/task{}-epoch{}-run{}'.format(
                idx, epoch, run))

    return paths
def evaluate(self, epoch):
    """
    Evaluate the policy, e.g. save/print progress.
    :param epoch:
    :return:
    """
    statistics = OrderedDict()
    try:
        statistics.update(self.eval_statistics)
        self.eval_statistics = None
    except TypeError:  # eval_statistics is None; nothing to merge yet
        print('No Stats to Eval')

    logger.log("Collecting samples for evaluation")
    test_paths = self.eval_sampler.obtain_samples()

    statistics.update(
        eval_util.get_generic_path_information(
            test_paths,
            stat_prefix="Test",
        ))
    statistics.update(
        eval_util.get_generic_path_information(
            self._exploration_paths,
            stat_prefix="Exploration",
        ))

    if hasattr(self.env, "log_diagnostics"):
        self.env.log_diagnostics(test_paths)
    if hasattr(self.env, "log_statistics"):
        statistics.update(self.env.log_statistics(test_paths))
    if epoch % self.freq_log_visuals == 0:
        if hasattr(self.env, "log_visuals"):
            self.env.log_visuals(test_paths, epoch,
                                 logger.get_snapshot_dir())

    average_returns = eval_util.get_average_returns(test_paths)
    statistics['AverageReturn'] = average_returns
    for key, value in statistics.items():
        logger.record_tabular(key, value)

    best_statistic = statistics[self.best_key]
    if best_statistic > self.best_statistic_so_far:
        self.best_statistic_so_far = best_statistic
        if self.save_best and epoch >= self.save_best_starting_from_epoch:
            data_to_save = {'epoch': epoch, 'statistics': statistics}
            data_to_save.update(self.get_epoch_snapshot(epoch))
            logger.save_extra_data(data_to_save, 'best.pkl')
            print('\n\nSAVED BEST\n\n')
def _try_to_eval(self, epoch, eval_all=False, eval_train_offline=True,
                 animated=False):
    logger.save_extra_data(self.get_extra_data_to_save(epoch))
    if self._can_evaluate():
        self.evaluate(epoch, eval_all, eval_train_offline, animated)

        params = self.get_epoch_snapshot(epoch)
        logger.save_itr_params(epoch, params)
        table_keys = logger.get_table_key_set()
        if self._old_table_keys is not None:
            assert table_keys == self._old_table_keys, (
                "Table keys cannot change from iteration to iteration.")
        self._old_table_keys = table_keys

        logger.record_tabular(
            "Number of train steps total",
            self._n_train_steps_total,
        )
        logger.record_tabular(
            "Number of env steps total",
            self._n_env_steps_total,
        )
        logger.record_tabular(
            "Number of rollouts total",
            self._n_rollouts_total,
        )

        times_itrs = gt.get_times().stamps.itrs
        train_time = times_itrs.get('train', [0])[-1]
        sample_time = times_itrs.get('sample', [0])[-1]
        eval_time = times_itrs.get('eval', [0])[-1] if epoch > 0 else 0
        epoch_time = train_time + sample_time + eval_time
        total_time = gt.get_times().total

        logger.record_tabular('Train Time (s)', train_time)
        logger.record_tabular('(Previous) Eval Time (s)', eval_time)
        logger.record_tabular('Sample Time (s)', sample_time)
        logger.record_tabular('Epoch Time (s)', epoch_time)
        logger.record_tabular('Total Train Time (s)', total_time)

        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
    else:
        logger.log("Skipping eval for now.")
def collect_paths(self, idx, epoch, run):
    self.task_idx = idx
    self.env.reset_task(idx)
    self.agent.clear_z()

    paths = []
    all_zs = []
    num_transitions = 0
    num_trajs = 0
    while num_transitions < self.num_steps_per_eval:
        path, num = self.sampler.obtain_samples(
            deterministic=self.eval_deterministic,
            max_samples=self.num_steps_per_eval - num_transitions,
            max_trajs=1,
            accum_context=True)
        paths += path
        num_transitions += num
        num_trajs += 1
        if num_trajs >= self.num_exp_traj_eval:
            self.agent.infer_posterior(self.agent.context)
        # record the posterior statistics for each collected trajectory
        all_zs.append({
            'z_mean': self.agent.z_means.detach().cpu().numpy(),
            'z_vars': self.agent.z_vars.detach().cpu().numpy(),
            'z_sample': self.agent.z.detach().cpu().numpy(),
        })

    if self.sparse_rewards:
        for p in paths:
            # np.stack needs a sequence, not a generator
            sparse_rewards = np.stack(
                [e['sparse_reward'] for e in p['env_infos']]).reshape(-1, 1)
            p['rewards'] = sparse_rewards

    if self.dump_eval_paths:
        logger.save_extra_data(
            {'paths': paths, 'zs': all_zs},
            _dir_annotation='inference/task_' + str(idx))

    return paths
def collect_paths(self, idx, epoch, run, animated=False):
    # print('enter collect path')
    self.task_idx = idx
    self.env.reset_task(idx)
    self.agent.clear_z()

    paths = []
    num_transitions = 0
    num_trajs = 0
    while num_transitions < self.num_steps_per_eval:
        path, num = self.sampler.obtain_samples(
            deterministic=self.eval_deterministic,
            max_samples=self.num_steps_per_eval - num_transitions,
            max_trajs=1,
            accum_context=self.glob,  # only accumulate context in glob mode
            animated=animated,
            glob=self.glob)
        paths += path
        num_transitions += num
        num_trajs += 1
        if num_trajs >= self.num_exp_traj_eval and self.glob:
            self.agent.infer_posterior(self.agent.context)

    if self.sparse_rewards:
        for p in paths:
            # np.stack needs a sequence, not a generator
            sparse_rewards = np.stack(
                [e['sparse_reward'] for e in p['env_infos']]).reshape(-1, 1)
            p['rewards'] = sparse_rewards

    goal = self.env._goal
    for path in paths:
        path['goal'] = goal  # goal

    if animated:
        for i in range(len(paths)):
            video_writer = imageio.get_writer(
                os.path.join(
                    logger.get_snapshot_dir(),
                    'task{}-epoch{}-run{}.mp4'.format(idx, epoch, i)),
                fps=20)
            for frame in paths[i]['frames']:
                video_writer.append_data(frame)
            video_writer.close()

    # save the paths for visualization, only useful for point mass
    if self.dump_eval_paths:
        logger.save_extra_data(
            paths,
            path='eval_trajectories/task{}-epoch{}-run{}'.format(
                idx, epoch, run))

    return paths
def train_vae(variant):
    # train_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_train.hdf5'
    # test_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_test.hdf5'
    # train_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorBigTwoBallSmall.h5'
    # test_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorBigTwoBallSmall.h5'
    train_path = '/home/jcoreyes/objects/RailResearch/BlocksGeneration/rendered/fiveBlock10kActions.h5'
    test_path = '/home/jcoreyes/objects/RailResearch/BlocksGeneration/rendered/fiveBlock10kActions.h5'

    train_feats, train_actions = load_dataset(train_path, train=True)
    test_feats, test_actions = load_dataset(test_path, train=False)

    K = variant['vae_kwargs']['K']
    rep_size = variant['vae_kwargs']['representation_size']

    logger.get_snapshot_dir()
    variant['vae_kwargs']['architecture'] = \
        iodine.imsize64_large_iodine_architecture
    variant['vae_kwargs']['decoder_class'] = BroadcastCNN

    refinement_net = RefinementNetwork(
        **iodine.imsize64_large_iodine_architecture['refine_args'],
        hidden_activation=nn.ELU())
    physics_net = PhysicsNetwork(K, rep_size, train_actions.shape[-1])
    m = IodineVAE(
        **variant['vae_kwargs'],
        refinement_net=refinement_net,
        dynamic=True,
        physics_net=physics_net,
    )
    m.to(ptu.device)

    t = IodineTrainer(train_feats, test_feats, m,
                      variant['train_seedsteps'], variant['test_seedsteps'],
                      train_actions=train_actions,
                      test_actions=test_actions,
                      **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(
            epoch,
            batches=train_feats.shape[0] // variant['algo_kwargs']['batch_size'])
        t.test_epoch(epoch, save_vae=True, train=False, record_stats=True,
                     batches=1, save_reconstruction=should_save_imgs)
        t.test_epoch(epoch, save_vae=False, train=True, record_stats=False,
                     batches=1, save_reconstruction=should_save_imgs)
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
def train_vae(variant, return_data=False):
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    train_data, test_data, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    m = ConvVAE(representation_size,
                decoder_output_activation=decoder_activation,
                **variant['vae_kwargs'])
    m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    if return_data:
        return m, train_data, test_data
    return m
def collect_paths(self, idx, epoch, run, wideeval=False):
    self.task_idx = idx
    if not wideeval:
        self.env.reset_task(idx)
    else:
        self.env_eval.reset_task(idx)
    self.agent.clear_z()

    paths = []
    num_transitions = 0
    num_trajs = 0
    test_suc = 0
    while num_transitions < self.num_steps_per_eval:
        sampler = self.sampler if not wideeval else self.sampler_eval
        path, num, info = sampler.obtain_samples(
            deterministic=self.eval_deterministic,
            max_samples=self.num_steps_per_eval - num_transitions,
            max_trajs=1,
            accum_context=True)
        paths += path
        num_transitions += num
        num_trajs += 1
        test_suc += info['n_success_num']
        if num_trajs >= self.num_exp_traj_eval:
            self.agent.infer_posterior(self.agent.context)
    suc_rate = test_suc / num_trajs

    if self.sparse_rewards:
        for p in paths:
            # np.stack needs a sequence, not a generator
            sparse_rewards = np.stack(
                [e['sparse_reward'] for e in p['env_infos']]).reshape(-1, 1)
            p['rewards'] = sparse_rewards

    goal = self.env._goal
    for path in paths:
        path['goal'] = goal  # goal

    # save the paths for visualization, only useful for point mass
    if self.dump_eval_paths:
        logger.save_extra_data(
            paths,
            path='eval_trajectories/task{}-epoch{}-run{}'.format(
                idx, epoch, run))

    return paths, suc_rate
def train_vae(variant):
    # train_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_train_10000.hdf5'
    # test_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_test.hdf5'
    train_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorTwoBallSmall.h5'
    test_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorTwoBallSmall.h5'

    train_data = load_dataset(train_path, train=True)
    test_data = load_dataset(test_path, train=False)
    train_data = train_data.reshape((train_data.shape[0], -1))
    test_data = test_data.reshape((test_data.shape[0], -1))

    # logger.save_extra_data(info)
    logger.get_snapshot_dir()
    variant['vae_kwargs']['architecture'] = \
        monet.imsize64_monet_architecture  # monet.imsize84_monet_architecture
    variant['vae_kwargs']['decoder_output_activation'] = identity
    variant['vae_kwargs']['decoder_class'] = BroadcastCNN

    attention_net = UNet(in_channels=4,
                         n_classes=1,
                         up_mode='upsample',
                         depth=3,
                         padding=True)
    m = MonetVAE(**variant['vae_kwargs'], attention_net=attention_net)
    m.to(ptu.device)

    t = MonetTrainer(train_data, test_data, m, **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
def experiment(variant):
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = get_data(**variant['get_data_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    beta_schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])
    m = ConvVAE(representation_size, input_channels=3)
    if ptu.gpu_enabled():
        m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    for epoch in range(variant['num_epochs']):
        t.train_epoch(epoch)
        t.test_epoch(epoch)
        t.dump_samples(epoch)
def train_vae(variant):
    train_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_train.hdf5'
    test_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_test.hdf5'
    # train_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorBigTwoBallSmall.h5'
    # test_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorBigTwoBallSmall.h5'

    train_data = load_dataset(train_path, train=True)
    test_data = load_dataset(test_path, train=False)
    train_data = train_data.reshape((train_data.shape[0], -1))[:500]
    # train_data = train_data.reshape((train_data.shape[0], -1))[0]
    # train_data = np.reshape(train_data[:2], (2, -1)).repeat(100, 0)
    test_data = test_data.reshape((test_data.shape[0], -1))[:10]

    # logger.save_extra_data(info)
    logger.get_snapshot_dir()
    variant['vae_kwargs']['architecture'] = iodine.imsize84_iodine_architecture
    variant['vae_kwargs']['decoder_class'] = BroadcastCNN

    refinement_net = RefinementNetwork(
        **iodine.imsize84_iodine_architecture['refine_args'],
        hidden_activation=nn.ELU())
    m = IodineVAE(**variant['vae_kwargs'], refinement_net=refinement_net)
    m.to(ptu.device)

    t = IodineTrainer(train_data, test_data, m, **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(
            epoch,
            batches=train_data.shape[0] // variant['algo_kwargs']['batch_size'])
        t.test_epoch(epoch,
                     save_reconstruction=should_save_imgs,
                     save_vae=False)
        if should_save_imgs:
            t.dump_samples(epoch)
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
def grill_her_full_experiment(variant, mode='td3'):
    train_vae_variant = variant['train_vae_variant']
    grill_variant = variant['grill_variant']
    env_class = variant['env_class']
    env_kwargs = variant['env_kwargs']
    init_camera = variant['init_camera']
    train_vae_variant['generate_vae_dataset_kwargs']['env_class'] = env_class
    train_vae_variant['generate_vae_dataset_kwargs']['env_kwargs'] = env_kwargs
    train_vae_variant['generate_vae_dataset_kwargs'][
        'init_camera'] = init_camera
    grill_variant['env_class'] = env_class
    grill_variant['env_kwargs'] = env_kwargs
    grill_variant['init_camera'] = init_camera
    if 'vae_paths' not in grill_variant:
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output('vae_progress.csv',
                                  relative_to_snapshot_dir=True)
        vae = train_vae(train_vae_variant)
        rdim = train_vae_variant['representation_size']
        vae_file = logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
        logger.remove_tabular_output(
            'vae_progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
        grill_variant['vae_paths'] = {
            str(rdim): vae_file,
        }
        grill_variant['rdim'] = str(rdim)
    if mode == 'td3':
        grill_her_td3_experiment(variant['grill_variant'])
    elif mode == 'twin-sac':
        grill_her_twin_sac_experiment(variant['grill_variant'])
    elif mode == 'sac':
        grill_her_sac_experiment(variant['grill_variant'])
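# Usage sketch (illustrative, not from the original source): the top-level
# variant shape grill_her_full_experiment consumes, inferred from the keys
# accessed above. The env class and kwargs are hypothetical placeholders.
example_grill_variant = {
    'train_vae_variant': {
        'representation_size': 16,
        'generate_vae_dataset_kwargs': {},  # env info is injected above
        # ... remaining train_vae settings (beta, num_epochs, ...)
    },
    'grill_variant': {},  # RL settings; 'vae_paths' is filled in above
    'env_class': None,    # e.g. a multiworld env class
    'env_kwargs': {},
    'init_camera': None,
}
# grill_her_full_experiment(example_grill_variant, mode='td3')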
def evaluate(self, epoch):
    if self.eval_statistics is None:
        self.eval_statistics = OrderedDict()

    ### sample trajectories from prior for debugging / visualization
    if self.dump_eval_paths:
        # 100 arbitrarily chosen for visualizations of point_robot trajectories
        # just want stochasticity of z, not the policy
        self.agent.clear_z()
        prior_paths, _ = self.sampler.obtain_samples(
            deterministic=self.eval_deterministic,
            max_samples=self.max_path_length * 20,
            accum_context=False,
            resample=1,
            testing=True)
        logger.save_extra_data(
            prior_paths,
            path='eval_trajectories/prior-epoch{}'.format(epoch))

    ### train tasks
    # eval on a subset of train tasks for speed
    indices = np.random.choice(self.train_tasks, len(self.eval_tasks))
    eval_util.dprint('evaluating on {} train tasks'.format(len(indices)))

    ### eval train tasks with posterior sampled from the training replay buffer
    train_returns = []
    for idx in indices:
        self.task_idx = idx
        self.env.reset_task(idx)
        paths = []
        for _ in range(self.num_steps_per_eval // self.max_path_length):
            context = self.sample_context(idx)
            self.agent.infer_posterior(context)
            p, _ = self.sampler.obtain_samples(
                deterministic=self.eval_deterministic,
                max_samples=self.max_path_length,
                accum_context=False,
                max_trajs=1,
                resample=np.inf,
                testing=True)
            paths += p

        if self.sparse_rewards:
            for p in paths:
                # np.stack needs a sequence, not a generator
                sparse_rewards = np.stack(
                    [e['sparse_reward'] for e in p['env_infos']]
                ).reshape(-1, 1)
                p['rewards'] = sparse_rewards

        train_returns.append(eval_util.get_average_returns(paths))
    train_returns = np.mean(train_returns)

    ### eval train tasks with on-policy data to match eval of test tasks
    train_final_returns, train_online_returns = self._do_eval(indices, epoch)
    eval_util.dprint('train online returns')
    eval_util.dprint(train_online_returns)

    ### test tasks
    eval_util.dprint('evaluating on {} test tasks'.format(
        len(self.eval_tasks)))
    test_final_returns, test_online_returns = self._do_eval(
        self.eval_tasks, epoch)
    eval_util.dprint('test online returns')
    eval_util.dprint(test_online_returns)

    # save the final posterior
    self.agent.log_diagnostics(self.eval_statistics)

    if hasattr(self.env, "log_diagnostics"):
        self.env.log_diagnostics(paths, prefix=None)

    avg_train_return = np.mean(train_final_returns)
    avg_test_return = np.mean(test_final_returns)
    avg_train_online_return = np.mean(np.stack(train_online_returns), axis=0)
    avg_test_online_return = np.mean(np.stack(test_online_returns), axis=0)
    self.eval_statistics['AverageTrainReturn_all_train_tasks'] = train_returns
    self.eval_statistics['AverageReturn_all_train_tasks'] = avg_train_return
    self.eval_statistics['AverageReturn_all_test_tasks'] = avg_test_return
    logger.save_extra_data(avg_train_online_return,
                           path='online-train-epoch{}'.format(epoch))
    logger.save_extra_data(avg_test_online_return,
                           path='online-test-epoch{}'.format(epoch))

    for key, value in self.eval_statistics.items():
        logger.record_tabular(key, value)
    self.eval_statistics = None

    if self.render_eval_paths:
        self.env.render_paths(paths)

    if self.plotter:
        self.plotter.draw()
def train_vae(variant):
    # train_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_train.hdf5'
    # test_path = '/home/jcoreyes/objects/rlkit/examples/monet/clevr_test.hdf5'
    # train_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorBigTwoBallSmall.h5'
    # test_path = '/home/jcoreyes/objects/RailResearch/DataGeneration/ColorBigTwoBallSmall.h5'
    train_path = '/home/jcoreyes/objects/RailResearch/BlocksGeneration/rendered/fiveBlock10k.h5'
    test_path = '/home/jcoreyes/objects/RailResearch/BlocksGeneration/rendered/fiveBlock10k.h5'

    train_data = load_dataset(train_path, train=True)
    test_data = load_dataset(test_path, train=False)

    n_frames = 2
    imsize = train_data.shape[-1]
    T = variant['vae_kwargs']['T']
    K = variant['vae_kwargs']['K']
    rep_size = variant['vae_kwargs']['representation_size']

    # t_sample = np.array([0, 0, 0, 0, 0, 10, 15, 20, 25, 30])
    # t_sample = np.array([0, 34, 34, 34, 34])
    t_sample = np.array([0, 0, 0, 0, 1])
    train_data = train_data.reshape(
        (n_frames, -1, 3, imsize, imsize)).swapaxes(0, 1)[:8000, t_sample]
    test_data = test_data.reshape(
        (n_frames, -1, 3, imsize, imsize)).swapaxes(0, 1)[:50, t_sample]

    # logger.save_extra_data(info)
    logger.get_snapshot_dir()
    variant['vae_kwargs']['architecture'] = \
        iodine.imsize64_large_iodine_architecture
    variant['vae_kwargs']['decoder_class'] = BroadcastCNN

    refinement_net = RefinementNetwork(
        **iodine.imsize64_large_iodine_architecture['refine_args'],
        hidden_activation=nn.ELU())
    physics_net = None
    if variant['physics']:
        physics_net = PhysicsNetwork(K, rep_size)
    m = IodineVAE(
        **variant['vae_kwargs'],
        refinement_net=refinement_net,
        dynamic=True,
        physics_net=physics_net,
    )
    m.to(ptu.device)

    t = IodineTrainer(train_data, test_data, m, **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(
            epoch,
            batches=train_data.shape[0] // variant['algo_kwargs']['batch_size'])
        t.test_epoch(epoch, save_vae=True, train=False, record_stats=True,
                     batches=1, save_reconstruction=should_save_imgs)
        t.test_epoch(epoch, save_vae=False, train=True, record_stats=False,
                     batches=1, save_reconstruction=should_save_imgs)
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
def train_vae(variant, other_variant, return_data=False):
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    train_data, test_data, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    m = ConvVAE(representation_size,
                decoder_output_activation=decoder_activation,
                **variant['vae_kwargs'])
    m.to(ptu.device)
    t = ConvVAETrainer(train_data,
                       test_data,
                       m,
                       other_variant,
                       beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
            # save_vae=False,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
        t.update_train_weights()
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    # torch.save(m, other_variant['vae_pkl_path'] + '/online_vae.pkl')
    # easy way: load the model for the VAE bonus
    if return_data:
        return m, train_data, test_data
    return m
def evaluate(self, epoch):
    """
    Evaluate the policy, e.g. save/print progress.
    :param epoch:
    :return:
    """
    statistics = OrderedDict()
    try:
        statistics.update(self.eval_statistics)
        self.eval_statistics = None
    except TypeError:  # eval_statistics is None; nothing to merge yet
        print('No Stats to Eval')

    logger.log("Collecting random samples for evaluation")
    eval_steps = self.num_steps_per_eval
    test_paths = self.eval_sampler.obtain_samples(eval_steps)

    obs = torch.Tensor(
        np.squeeze(np.vstack([path["observations"] for path in test_paths])))
    acts = torch.Tensor(
        np.squeeze(np.vstack([path["actions"] for path in test_paths])))
    if len(acts.shape) < 2:
        acts = torch.unsqueeze(acts, 1)
    random_input = torch.cat([obs, acts], dim=1).to(ptu.device)

    exp_batch = self.get_batch(eval_steps,
                               keys=['observations', 'actions'],
                               use_expert_buffer=True)
    # exp_batch = {'observations': torch.Tensor([[0.], [1.], [2.], [3.], [4.],
    #                                            [5.], [6.], [7.], [8.], [9.],
    #                                            [10.]]),
    #              'actions': torch.Tensor([[0.5]] * 11)}
    obs = exp_batch['observations']
    acts = exp_batch['actions']
    exp_input = torch.cat([obs, acts], dim=1).to(ptu.device)

    statistics['random_avg_energy'] = self.ebm(random_input).mean().item()
    statistics['expert_avg_energy'] = self.get_energy(exp_input).mean().item()
    statistics['expert*20_avg_energy'] = self.get_energy(
        exp_input * 20).mean().item()
    statistics["random_expert_diff"] = (statistics["random_avg_energy"] -
                                        statistics["expert_avg_energy"])

    for key, value in statistics.items():
        logger.record_tabular(key, value)

    best_statistic = statistics[self.best_key]
    if best_statistic > self.best_statistic_so_far:
        self.best_statistic_so_far = best_statistic
        self.best_epoch = epoch
        self.best_random_avg_energy = statistics['random_avg_energy']
        self.best_expert_avg_energy = statistics['expert_avg_energy']
        if self.save_best and epoch >= self.save_best_starting_from_epoch:
            data_to_save = {'epoch': epoch, 'statistics': statistics}
            data_to_save.update(self.get_epoch_snapshot(epoch))
            logger.save_extra_data(data_to_save, 'best.pkl')
            print('\n\nSAVED BEST\n\n')

    # record the best-so-far stats every epoch so the table keys stay stable
    logger.record_tabular("Best Model Epoch", self.best_epoch)
    logger.record_tabular("Best Random Energy", self.best_random_avg_energy)
    logger.record_tabular("Best Expert Energy", self.best_expert_avg_energy)
def train(self):
    '''
    meta-training loop
    '''
    params = self.get_epoch_snapshot(-1)
    logger.save_itr_params(-1, params)
    gt.reset()
    gt.set_def_unique(False)
    self._current_path_builder = PathBuilder()

    # at each iteration, we first collect data from tasks, perform
    # meta-updates, then try to evaluate
    for it_ in gt.timed_for(range(self.num_iterations), save_itrs=True):
        self._start_epoch(it_)
        self.training_mode(True)
        if it_ == 0:
            print('collecting initial pool of data for train and eval')
            # temp for evaluating
            for idx in self.train_tasks:
                self.task_idx = idx
                self.env.reset_task(idx)
                self.collect_data(self.num_initial_steps, 1, np.inf,
                                  buffer=self.train_buffer)
        # Sample data from train tasks.
        for i in range(self.num_tasks_sample):
            idx = np.random.choice(self.train_tasks, 1)[0]
            self.task_idx = idx
            self.env.reset_task(idx)
            self.enc_replay_buffer.task_buffers[idx].clear()

            # collect some trajectories with z ~ prior
            if self.num_steps_prior > 0:
                self.collect_data(self.num_steps_prior, 1, np.inf,
                                  buffer=self.train_buffer)
            # collect some trajectories with z ~ posterior
            if self.num_steps_posterior > 0:
                self.collect_data(self.num_steps_posterior, 1,
                                  self.update_post_train,
                                  buffer=self.train_buffer)
            # even if encoder is trained only on samples from the prior,
            # the policy needs to learn to handle z ~ posterior
            if self.num_extra_rl_steps_posterior > 0:
                self.collect_data(self.num_extra_rl_steps_posterior, 1,
                                  self.update_post_train,
                                  buffer=self.train_buffer,
                                  add_to_enc_buffer=False)

        indices_lst = []
        z_means_lst = []
        z_vars_lst = []
        # Sample train tasks and compute gradient updates on parameters.
        for train_step in range(self.num_train_steps_per_itr):
            indices = np.random.choice(self.train_tasks, self.meta_batch,
                                       replace=self.mb_replace)
            z_means, z_vars = self._do_training(indices, zloss=True)
            indices_lst.append(indices)
            z_means_lst.append(z_means)
            z_vars_lst.append(z_vars)
            self._n_train_steps_total += 1

        indices = np.concatenate(indices_lst)
        z_means = np.concatenate(z_means_lst)
        z_vars = np.concatenate(z_vars_lst)
        data_dict = self.data_dict(indices, z_means, z_vars)
        logger.save_itr_data(it_, **data_dict)

        gt.stamp('train')
        self.training_mode(False)

        # eval
        params = self.get_epoch_snapshot(it_)
        logger.save_itr_params(it_, params)
        if self.allow_eval:
            logger.save_extra_data(self.get_extra_data_to_save(it_))
            self._try_to_eval(it_)
            gt.stamp('eval')
        self._end_epoch()
def train_pixelcnn(
    vqvae=None,
    vqvae_path=None,
    num_epochs=100,
    batch_size=32,
    n_layers=15,
    dataset_path=None,
    save=True,
    save_period=10,
    cached_dataset_path=False,
    trainer_kwargs=None,
    model_kwargs=None,
    data_filter_fn=lambda x: x,
    debug=False,
    data_size=float('inf'),
    num_train_batches_per_epoch=None,
    num_test_batches_per_epoch=None,
    train_img_loader=None,
    test_img_loader=None,
):
    trainer_kwargs = {} if trainer_kwargs is None else trainer_kwargs
    model_kwargs = {} if model_kwargs is None else model_kwargs

    # Load VQ-VAE + define args
    if vqvae is None:
        vqvae = load_local_or_remote_file(vqvae_path)
        vqvae.to(ptu.device)
        vqvae.eval()

    root_len = vqvae.root_len
    num_embeddings = vqvae.num_embeddings
    embedding_dim = vqvae.embedding_dim
    cond_size = vqvae.num_embeddings
    imsize = vqvae.imsize
    discrete_size = root_len * root_len
    representation_size = embedding_dim * discrete_size
    input_channels = vqvae.input_channels
    imlength = imsize * imsize * input_channels

    log_dir = logger.get_snapshot_dir()

    # Define data loading info
    new_path = osp.join(log_dir, 'pixelcnn_data.npy')

    def prep_sample_data(cached_path):
        data = load_local_or_remote_file(cached_path).item()
        train_data = data['train']
        test_data = data['test']
        return train_data, test_data

    def encode_dataset(path, object_list):
        data = load_local_or_remote_file(path)
        data = data.item()
        data = data_filter_fn(data)

        all_data = []
        n = min(data["observations"].shape[0], data_size)
        for i in tqdm(range(n)):
            obs = ptu.from_numpy(data["observations"][i] / 255.0)
            latent = vqvae.encode(obs, cont=False)
            all_data.append(latent)
        encodings = ptu.get_numpy(torch.stack(all_data, dim=0))
        return encodings

    if train_img_loader:
        _, test_loader, test_batch_loader = create_conditional_data_loader(
            test_img_loader, 80, vqvae, "test2")  # 80
        _, train_loader, train_batch_loader = create_conditional_data_loader(
            train_img_loader, 2000, vqvae, "train2")  # 2000
    else:
        if cached_dataset_path:
            train_data, test_data = prep_sample_data(cached_dataset_path)
        else:
            train_data = encode_dataset(dataset_path['train'], None)  # object_list)
            test_data = encode_dataset(dataset_path['test'], None)
        dataset = {'train': train_data, 'test': test_data}
        np.save(new_path, dataset)

        _, _, train_loader, test_loader, _ = \
            rlkit.torch.vae.pixelcnn_utils.load_data_and_data_loaders(
                new_path, 'COND_LATENT_BLOCK', batch_size)
    # train_dataset = InfiniteBatchLoader(train_loader)
    # test_dataset = InfiniteBatchLoader(test_loader)

    print("Finished loading data")

    model = GatedPixelCNN(num_embeddings,
                          root_len ** 2,
                          n_classes=representation_size,
                          **model_kwargs).to(ptu.device)
    trainer = PixelCNNTrainer(
        model,
        vqvae,
        batch_size=batch_size,
        **trainer_kwargs,
    )

    print("Starting training")

    best_loss = float('inf')  # was BEST_LOSS = 999, which could skip saving
    for epoch in range(num_epochs):
        should_save = (epoch % save_period == 0) and (epoch > 0)
        trainer.train_epoch(epoch, train_loader, num_train_batches_per_epoch)
        trainer.test_epoch(epoch, test_loader, num_test_batches_per_epoch)

        if train_img_loader:
            # The batch loaders only exist on the conditional-data path;
            # `batch_size` is used here where the original referenced an
            # undefined `bz`.
            test_data = test_batch_loader.random_batch(batch_size)["x"]
            train_data = train_batch_loader.random_batch(batch_size)["x"]
            trainer.dump_samples(epoch, test_data, test=True)
            trainer.dump_samples(epoch, train_data, test=False)

        if should_save:
            logger.save_itr_params(epoch, model)

        stats = trainer.get_diagnostics()
        cur_loss = stats["test/loss"]
        if cur_loss < best_loss:
            best_loss = cur_loss
            vqvae.set_pixel_cnn(model)
            logger.save_extra_data(vqvae, 'best_vqvae', mode='torch')
        else:
            # early-stop as soon as the test loss stops improving
            return vqvae

        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

    return vqvae
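# Usage sketch (illustrative, not from the original source): invoking
# train_pixelcnn with a pre-trained VQ-VAE checkpoint and an encoded
# dataset. All paths and hyperparameters below are hypothetical.
# trained_vqvae = train_pixelcnn(
#     vqvae_path='path/to/vqvae.pkl',
#     dataset_path={'train': 'path/to/train_data.npy',
#                   'test': 'path/to/test_data.npy'},
#     num_epochs=100,
#     batch_size=32,
#     save_period=10,
#     model_kwargs={},
#     trainer_kwargs={},
# )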