def train(self):
    '''meta-training loop'''
    self.pretrain()
    params = self.get_epoch_snapshot(-1)
    logger.save_itr_params(-1, params)
    gt.reset()
    gt.set_def_unique(False)
    self._current_path_builder = PathBuilder()

    # At each iteration, we first collect data from tasks, perform
    # meta-updates, then try to evaluate.
    for it_ in gt.timed_for(
            range(self.num_iterations),
            save_itrs=True,
    ):
        self._start_epoch(it_)
        self.training_mode(True)
        if it_ == 0:
            print('collecting initial pool of data for train and eval')
            # temp for evaluating
            for idx in self.train_tasks:
                self.task_idx = idx
                self.env.reset_task(idx)
                self.collect_data(self.num_initial_steps, 1, np.inf)

        # Sample data from train tasks.
        for i in range(self.num_tasks_sample):
            idx = np.random.randint(len(self.train_tasks))
            self.task_idx = idx
            self.env.reset_task(idx)
            self.enc_replay_buffer.task_buffers[idx].clear()

            # Collect some trajectories with z ~ prior.
            if self.num_steps_prior > 0:
                self.collect_data(self.num_steps_prior, 1, np.inf)
            # Collect some trajectories with z ~ posterior.
            if self.num_steps_posterior > 0:
                self.collect_data(self.num_steps_posterior, 1,
                                  self.update_post_train)
            # Even if the encoder is trained only on samples from the prior,
            # the policy needs to learn to handle z ~ posterior.
            if self.num_extra_rl_steps_posterior > 0:
                self.collect_data(self.num_extra_rl_steps_posterior, 1,
                                  self.update_post_train,
                                  add_to_enc_buffer=False)

        # Sample train tasks and compute gradient updates on parameters.
        for train_step in range(self.num_train_steps_per_itr):
            indices = np.random.choice(self.train_tasks, self.meta_batch)
            self._do_training(indices)
            self._n_train_steps_total += 1
        gt.stamp('train')

        self.training_mode(False)

        # eval
        self._try_to_eval(it_)
        gt.stamp('eval')

        self._end_epoch()
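# The loop above relies on gtimer for per-iteration timing: gt.timed_for()
# times each loop iteration, and gt.stamp() marks named sub-intervals that
# the _try_to_eval() variants below read back via gt.get_times(). A minimal,
# self-contained sketch of that pattern (the stamp names 'train'/'eval'
# match the ones used above; the sleeps are stand-ins for real work):
import time

import gtimer as gt

gt.reset()
gt.set_def_unique(False)
for i in gt.timed_for(range(3), save_itrs=True):
    time.sleep(0.01)   # stand-in for the training phase
    gt.stamp('train')
    time.sleep(0.005)  # stand-in for the eval phase
    gt.stamp('eval')
times_itrs = gt.get_times().stamps.itrs
print(times_itrs['train'][-1])  # duration of the most recent 'train' interval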
def run(self):
    if self.progress_csv_file_name != 'progress.csv':
        logger.remove_tabular_output('progress.csv',
                                     relative_to_snapshot_dir=True)
        logger.add_tabular_output(
            self.progress_csv_file_name,
            relative_to_snapshot_dir=True,
        )
    timer.return_global_times = True
    for _ in range(self.num_iters):
        self._begin_epoch()
        timer.start_timer('saving')
        logger.save_itr_params(self.epoch, self._get_snapshot())
        timer.stop_timer('saving')
        log_dict, _ = self._train()
        logger.record_dict(log_dict)
        logger.dump_tabular(with_prefix=True, with_timestamp=False)
        self._end_epoch()
    logger.save_itr_params(self.epoch, self._get_snapshot())

    if self.progress_csv_file_name != 'progress.csv':
        logger.remove_tabular_output(
            self.progress_csv_file_name,
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )
def _end_epoch(self, epoch):
    print('in _end_epoch, epoch is: {}'.format(epoch))
    snapshot = self._get_snapshot()
    logger.save_itr_params(epoch, snapshot)
    # trainer_obj = self.trainer
    # ckpt_path = 'ckpt.pkl'
    # logger.save_ckpt(epoch, trainer_obj, ckpt_path)
    # gt.stamp('saving')

    # Note: `epoch % 1 == 0` is always true, so the secondary snapshot is
    # saved every epoch.
    if epoch % 1 == 0:
        self.save_snapshot_2(epoch)

    expl_paths = self.expl_data_collector.get_epoch_paths()
    d = eval_util.get_generic_path_information(expl_paths)
    # print(d.keys())
    metric_val = d['Rewards Mean']
    cur_best_metric_val = self.get_cur_best_metric_val()
    if epoch != 0:
        self.save_snapshot_2_best_only(
            metric_val=metric_val,
            cur_best_metric_val=cur_best_metric_val,
            min_or_max='max')

    self._log_stats(epoch)

    self.expl_data_collector.end_epoch(epoch)
    self.eval_data_collector.end_epoch(epoch)
    self.replay_buffer.end_epoch(epoch)
    self.trainer.end_epoch(epoch)

    for post_epoch_func in self.post_epoch_funcs:
        post_epoch_func(self, epoch)
def _try_to_eval(self, epoch):
    if self._can_evaluate():
        # Save if it's time to save.
        if epoch % self.freq_saving == 0:
            logger.save_extra_data(self.get_extra_data_to_save(epoch))
            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)

        self.evaluate(epoch)

        logger.record_tabular(
            "Number of train calls total",
            self._n_train_steps_total,
        )

        times_itrs = gt.get_times().stamps.itrs
        train_time = times_itrs['train'][-1]
        sample_time = times_itrs['sample'][-1]
        eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
        epoch_time = train_time + sample_time + eval_time
        total_time = gt.get_times().total

        logger.record_tabular('Train Time (s)', train_time)
        logger.record_tabular('(Previous) Eval Time (s)', eval_time)
        logger.record_tabular('Sample Time (s)', sample_time)
        logger.record_tabular('Epoch Time (s)', epoch_time)
        logger.record_tabular('Total Train Time (s)', total_time)

        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
    else:
        logger.log("Skipping eval for now.")
def _try_to_eval(self, epoch=0):
    if epoch % self.save_extra_data_interval == 0:
        logger.save_extra_data(self.get_extra_data_to_save(epoch), epoch)
    if self._can_evaluate():
        self.evaluate(epoch)

        params = self.get_epoch_snapshot(epoch)
        logger.save_itr_params(epoch, params)
        table_keys = logger.get_table_key_set()
        if self._old_table_keys is not None:
            assert table_keys == self._old_table_keys, (
                "Table keys cannot change from iteration to iteration."
            )
        self._old_table_keys = table_keys

        logger.record_tabular(
            "Number of train steps total",
            self._n_train_steps_total,
        )
        logger.record_tabular(
            "Number of env steps total",
            self._n_env_steps_total,
        )
        logger.record_tabular(
            "Number of rollouts total",
            self._n_rollouts_total,
        )
        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
    else:
        logger.log("Skipping eval for now.")
def test_epoch(
        self,
        epoch,
        save_reconstruction=True,
        save_vae=True,
        from_rl=False,
):
    self.model.eval()
    losses = []
    log_probs = []
    kles = []
    zs = []
    beta = float(self.beta_schedule.get_value(epoch))
    for batch_idx in range(10):
        next_obs = self.get_batch(train=False)
        reconstructions, obs_distribution_params, latent_distribution_params = \
            self.model(next_obs)
        log_prob = self.model.logprob(next_obs, obs_distribution_params)
        kle = self.model.kl_divergence(latent_distribution_params)
        loss = -1 * log_prob + beta * kle

        encoder_mean = latent_distribution_params[0]
        z_data = ptu.get_numpy(encoder_mean.cpu())
        for i in range(len(z_data)):
            zs.append(z_data[i, :])
        losses.append(loss.item())
        log_probs.append(log_prob.item())
        kles.append(kle.item())

        if batch_idx == 0 and save_reconstruction:
            n = min(next_obs.size(0), 8)
            comparison = torch.cat([
                next_obs[:n].narrow(start=0, length=self.imlength, dim=1)
                .contiguous().view(
                    -1, self.input_channels, self.imsize, self.imsize
                ).transpose(2, 3),
                reconstructions.view(
                    self.batch_size,
                    self.input_channels,
                    self.imsize,
                    self.imsize,
                )[:n].transpose(2, 3)
            ])
            save_dir = osp.join(logger.get_snapshot_dir(),
                                'r%d.png' % epoch)
            save_image(comparison.data.cpu(), save_dir, nrow=n)

    zs = np.array(zs)

    self.eval_statistics['epoch'] = epoch
    self.eval_statistics['test/log prob'] = np.mean(log_probs)
    self.eval_statistics['test/KL'] = np.mean(kles)
    self.eval_statistics['test/loss'] = np.mean(losses)
    self.eval_statistics['beta'] = beta
    if not from_rl:
        for k, v in self.eval_statistics.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        if save_vae:
            logger.save_itr_params(epoch, self.model)
def train_vae(variant, return_data=False):
    from rlkit.misc.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger

    beta = variant["beta"]
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    variant['generate_vae_dataset_kwargs']['use_linear_dynamics'] = \
        use_linear_dynamics
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])

    if use_linear_dynamics:
        action_dim = train_dataset.data['actions'].shape[2]
    else:
        action_dim = 0
    model = get_vae(variant, action_dim)

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
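# A hypothetical `variant` for train_vae() above, showing only the keys the
# function itself reads. get_vae() and the dataset generator may require
# additional keys, and the inner values here are illustrative guesses, not
# a tested configuration:
example_variant = dict(
    beta=20.0,
    num_epochs=500,
    save_period=25,
    generate_vae_dataset_kwargs=dict(),  # forwarded to the dataset generator
    algo_kwargs=dict(),                  # forwarded to the trainer class
    # Optional keys:
    # use_linear_dynamics=False,
    # beta_schedule_kwargs=dict(x_values=[0, 100], y_values=[0.0, 20.0]),
    # dump_skew_debug_plots=False,
)
# model = train_vae(example_variant)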
def _try_to_eval(self, epoch):
    logger.save_extra_data(self.get_extra_data_to_save(epoch))
    if self._can_evaluate():
        if self.environment_farming:
            # Create a new eval_sampler for each evaluation in order to
            # avoid the released-environment problem.
            env_for_eval_sampler = self.farmer.force_acq_env()
            print(env_for_eval_sampler)
            self.eval_sampler = InPlacePathSampler(
                env=env_for_eval_sampler,
                policy=self.eval_policy,
                max_samples=self.num_steps_per_eval + self.max_path_length,
                max_path_length=self.max_path_length,
            )

        self.evaluate(epoch)

        if self.environment_farming:
            # Add the env back to the free_env list. (The guard is needed:
            # env_for_eval_sampler only exists when farming is enabled.)
            self.farmer.add_free_env(env_for_eval_sampler)

        params = self.get_epoch_snapshot(epoch)
        logger.save_itr_params(epoch, params)
        table_keys = logger.get_table_key_set()
        if self._old_table_keys is not None:
            assert table_keys == self._old_table_keys, (
                "Table keys cannot change from iteration to iteration."
            )
        self._old_table_keys = table_keys

        logger.record_tabular(
            "Number of train steps total",
            self._n_train_steps_total,
        )
        logger.record_tabular(
            "Number of env steps total",
            self._n_env_steps_total,
        )
        logger.record_tabular(
            "Number of rollouts total",
            self._n_rollouts_total,
        )

        times_itrs = gt.get_times().stamps.itrs
        train_time = times_itrs['train'][-1]
        sample_time = times_itrs['sample'][-1]
        eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
        epoch_time = train_time + sample_time + eval_time
        total_time = gt.get_times().total

        logger.record_tabular('Train Time (s)', train_time)
        logger.record_tabular('(Previous) Eval Time (s)', eval_time)
        logger.record_tabular('Sample Time (s)', sample_time)
        logger.record_tabular('Epoch Time (s)', epoch_time)
        logger.record_tabular('Total Train Time (s)', total_time)

        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
    else:
        logger.log("Skipping eval for now.")
def train(self, start_epoch=0):
    self.pretrain()
    if start_epoch == 0:
        params = self.get_epoch_snapshot(-1)
        logger.save_itr_params(-1, params)
    self.training_mode(False)
    self._n_env_steps_total = start_epoch * self.num_env_steps_per_epoch
    gt.reset()
    gt.set_def_unique(False)
    self.train_online(start_epoch=start_epoch)
def _end_epoch(self, epoch):
    snapshot = self._get_snapshot()
    logger.save_itr_params(epoch, snapshot)
    gt.stamp('saving')
    self._log_stats(epoch)

    self.expl_data_collector.end_epoch(epoch)
    self.eval_data_collector.end_epoch(epoch)
    self.replay_buffer.end_epoch(epoch)
    self.trainer.end_epoch(epoch)
def _end_epoch(self, epoch):
    snapshot = self._get_snapshot()
    logger.save_itr_params(epoch, snapshot)
    gt.stamp('saving')
    self._log_stats(epoch)

    self.eval_data_collector.end_epoch(epoch)
    self.trainer.end_epoch(epoch)

    for post_epoch_func in self.post_epoch_funcs:
        post_epoch_func(self, epoch)
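# The `post_epoch_funcs` loop above is the extension point of these
# _end_epoch() methods: any callable taking (algorithm, epoch) runs after
# the per-epoch bookkeeping. A hypothetical hook, not part of the original
# code:
def save_epoch_marker(algorithm, epoch):
    # Illustrative only: append the finished epoch number to a local file.
    with open('epochs_done.txt', 'a') as f:
        f.write('{}\n'.format(epoch))

# algorithm.post_epoch_funcs.append(save_epoch_marker)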
def _try_to_eval(self, epoch):
    if epoch % self.freq_saving == 0:
        logger.save_extra_data(self.get_extra_data_to_save(epoch))
    if self._can_evaluate():
        self.evaluate(epoch)

        if epoch % self.freq_saving == 0:
            params = self.get_epoch_snapshot(epoch)
            logger.save_itr_params(epoch, params)

        table_keys = logger.get_table_key_set()
        # The key-consistency check is disabled here; the set of table keys
        # may change between iterations in this setup:
        # if self._old_table_keys is not None:
        #     assert table_keys == self._old_table_keys, (
        #         "Table keys cannot change from iteration to iteration."
        #     )
        # self._old_table_keys = table_keys

        logger.record_tabular(
            "Number of train steps total",
            self._n_train_steps_total,
        )
        logger.record_tabular(
            "Number of env steps total",
            self._n_env_steps_total,
        )
        logger.record_tabular(
            "Number of rollouts total",
            self._n_rollouts_total,
        )

        times_itrs = gt.get_times().stamps.itrs
        train_time = times_itrs['train'][-1]
        sample_time = times_itrs['sample'][-1]
        eval_time = times_itrs['eval'][-1] if epoch > 0 else 0
        epoch_time = train_time + sample_time + eval_time
        total_time = gt.get_times().total

        logger.record_tabular('Train Time (s)', train_time)
        logger.record_tabular('(Previous) Eval Time (s)', eval_time)
        logger.record_tabular('Sample Time (s)', sample_time)
        logger.record_tabular('Epoch Time (s)', epoch_time)
        logger.record_tabular('Total Train Time (s)', total_time)

        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
    else:
        logger.log("Skipping eval for now.")
def train(self):
    timer.return_global_times = True
    for _ in range(self.num_epochs):
        self._begin_epoch()
        timer.start_timer('saving')
        logger.save_itr_params(self.epoch, self._get_snapshot())
        timer.stop_timer('saving')
        log_dict, _ = self._train()
        logger.record_dict(log_dict)
        logger.dump_tabular(with_prefix=True, with_timestamp=False)
        self._end_epoch()
    logger.save_itr_params(self.epoch, self._get_snapshot())
def train(self):
    '''meta-training loop'''
    self.pretrain()
    gt.reset()
    gt.set_def_unique(False)
    self._current_path_builder = PathBuilder()

    # At each iteration, we first collect data from tasks, perform
    # meta-updates, then try to evaluate.
    for it_ in gt.timed_for(
            range(self.num_iterations),
            save_itrs=True,
    ):
        self._start_epoch(it_)
        self.training_mode(True)

        # Sample train tasks and compute gradient updates on parameters.
        batch_idxes = np.random.randint(0, len(self.train_goals),
                                        size=self.meta_batch_size)
        train_batch_obj_id = self.replay_buffers.sample_training_data(
            batch_idxes, self.use_same_context)

        for _ in range(self.num_train_steps_per_itr):
            train_raw_batch = ray.get(train_batch_obj_id)
            gt.stamp('sample_training_data', unique=False)

            batch_idxes = np.random.randint(0, len(self.train_goals),
                                            size=self.meta_batch_size)
            # This way, the data-sampling job for the next training step
            # runs while we train on the current batch.
            train_batch_obj_id = self.replay_buffers.sample_training_data(
                batch_idxes, self.use_same_context)
            gt.stamp('set_up_sampling', unique=False)

            train_data = self.construct_training_batch(train_raw_batch)
            gt.stamp('construct_training_batch', unique=False)

            self._do_training(train_data)
            self._n_train_steps_total += 1
        gt.stamp('train')

        self.training_mode(False)

        # eval
        self._try_to_eval(it_)
        gt.stamp('eval')

        self._end_epoch()

        # Save a final snapshot on the last iteration. (The original
        # condition `it_ == self.num_iterations` could never be true,
        # since `it_` ranges over 0 .. num_iterations - 1.)
        if it_ == self.num_iterations - 1:
            logger.save_itr_params(it_, self.agent.get_snapshot())
def _end_epoch(self, epoch, num_epochs_per_eval=0):
    snapshot = self._get_snapshot()
    logger.save_itr_params(epoch, snapshot)
    gt.stamp('saving')
    self._log_stats(epoch, num_epochs_per_eval)

    self.expl_data_collector.end_epoch(epoch)
    # self.eval_data_collector.end_epoch(epoch)
    self.replay_buffer.end_epoch(epoch)
    self.trainer.end_epoch(epoch)

    for post_epoch_func in self.post_epoch_funcs:
        post_epoch_func(self, epoch)
def _end_epoch(self, epoch, solved=False):
    snapshot = self._get_snapshot()
    logger.save_itr_params(epoch, snapshot)
    gt.stamp('saving')
    self._log_stats(epoch, solved=solved)

    self.eval_data_collector.end_epoch(epoch)
    if not solved:
        self.expl_data_collector.end_epoch(epoch)
        self.replay_buffer.end_epoch(epoch)
        self.trainer.end_epoch(epoch)

    for post_epoch_func in self.post_epoch_funcs:
        post_epoch_func(self, epoch)
def _end_epoch(self, epoch):
    # print("core/rl_algorithm, _end_epoch(): ", "epoch: ", epoch)
    snapshot = self._get_snapshot()
    # print("core/rl_algorithm, _end_epoch(): ", "snapshot: ", snapshot)
    logger.save_itr_params(epoch, snapshot)
    gt.stamp('saving')
    self._log_stats(epoch)

    self.expl_data_collector.end_epoch(epoch)
    self.eval_data_collector.end_epoch(epoch)
    self.replay_buffer.end_epoch(epoch)
    self.trainer.end_epoch(epoch)

    for post_epoch_func in self.post_epoch_funcs:
        post_epoch_func(self, epoch)
def _end_epoch(self, epoch):
    snapshot = self._get_snapshot()
    # Only save params on the first GPU (rank 0).
    if ptu.dist_rank == 0:
        logger.save_itr_params(epoch, snapshot)
    gt.stamp("saving")
    self._log_stats(epoch)

    self.expl_data_collector.end_epoch(epoch)
    self.eval_data_collector.end_epoch(epoch)
    self.replay_buffer.end_epoch(epoch)
    self.trainer.end_epoch(epoch)

    for post_epoch_func in self.post_epoch_funcs:
        post_epoch_func(self, epoch)
def train(self):
    self.fix_data_set()
    logger.log("Done creating dataset.")
    num_batches_total = 0
    for epoch in range(self.num_epochs):
        for _ in range(self.num_batches_per_epoch):
            self.qf.train(True)
            self._do_training()
            num_batches_total += 1
        logger.push_prefix('Iteration #%d | ' % epoch)
        self.qf.train(False)
        self.evaluate(epoch)
        params = self.get_epoch_snapshot(epoch)
        logger.save_itr_params(epoch, params)
        logger.log("Done evaluating")
        logger.pop_prefix()
def _end_epoch(self, epoch):
    snapshot = self._get_snapshot()
    logger.save_itr_params(epoch, snapshot)
    gt.stamp('saving')
    self._log_stats(epoch)
    if self.collect_actions and epoch % self.collect_actions_every == 0:
        self._log_actions(epoch)

    self.expl_data_collector.end_epoch(epoch)
    self.eval_data_collector.end_epoch(epoch)
    self.replay_buffer.end_epoch(epoch)
    self.trainer.end_epoch(epoch)

    for post_epoch_func in self.post_epoch_funcs:
        post_epoch_func(self, epoch)
def train(self, start_epoch=0):
    self.pretrain()
    if start_epoch == 0:
        params = self.get_epoch_snapshot(-1)
        logger.save_itr_params(-1, params)
    self.training_mode(False)
    self._n_env_steps_total = start_epoch * self.num_env_steps_per_epoch
    gt.reset()
    gt.set_def_unique(False)
    if self.collection_mode == 'online':
        self.train_online(start_epoch=start_epoch)
    elif self.collection_mode == 'batch':
        self.train_batch(start_epoch=start_epoch)
    else:
        raise TypeError("Invalid collection_mode: {}".format(
            self.collection_mode))
def _try_to_eval(self, epoch, eval_all=False, eval_train_offline=True,
                 animated=False):
    logger.save_extra_data(self.get_extra_data_to_save(epoch))
    if self._can_evaluate():
        self.evaluate(epoch, eval_all, eval_train_offline, animated)

        params = self.get_epoch_snapshot(epoch)
        logger.save_itr_params(epoch, params)
        table_keys = logger.get_table_key_set()
        if self._old_table_keys is not None:
            assert table_keys == self._old_table_keys, (
                "Table keys cannot change from iteration to iteration."
            )
        self._old_table_keys = table_keys

        logger.record_tabular(
            "Number of train steps total",
            self._n_train_steps_total,
        )
        logger.record_tabular(
            "Number of env steps total",
            self._n_env_steps_total,
        )
        logger.record_tabular(
            "Number of rollouts total",
            self._n_rollouts_total,
        )

        times_itrs = gt.get_times().stamps.itrs
        # Use .get() with a default so missing stamps don't raise.
        train_time = times_itrs.get('train', [0])[-1]
        sample_time = times_itrs.get('sample', [0])[-1]
        eval_time = times_itrs.get('eval', [0])[-1] if epoch > 0 else 0
        epoch_time = train_time + sample_time + eval_time
        total_time = gt.get_times().total

        logger.record_tabular('Train Time (s)', train_time)
        logger.record_tabular('(Previous) Eval Time (s)', eval_time)
        logger.record_tabular('Sample Time (s)', sample_time)
        logger.record_tabular('Epoch Time (s)', epoch_time)
        logger.record_tabular('Total Train Time (s)', total_time)

        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
    else:
        logger.log("Skipping eval for now.")
def _end_epoch(self, epoch):
    if not self.trainer.discrete:
        snapshot = self._get_snapshot()
        logger.save_itr_params(epoch, snapshot)
        # if snapshot['evaluation/Average Returns'] >= self.best_reward:
        #     self.best_reward = snapshot['evaluation/Average Returns']
    gt.stamp('saving')
    self._log_stats(epoch)

    self.expl_data_collector.end_epoch(epoch)
    self.eval_data_collector.end_epoch(epoch)
    self.replay_buffer.end_epoch(epoch)
    self.trainer.end_epoch(epoch)

    for post_epoch_func in self.post_epoch_funcs:
        post_epoch_func(self, epoch)
def create_policy(variant):
    bottom_snapshot = joblib.load(variant['bottom_path'])
    column_snapshot = joblib.load(variant['column_path'])
    policy = variant['combiner_class'](
        policy1=bottom_snapshot['naf_policy'],
        policy2=column_snapshot['naf_policy'],
    )
    env = bottom_snapshot['env']
    logger.save_itr_params(0, dict(
        policy=policy,
        env=env,
    ))
    path = rollout(
        env,
        policy,
        max_path_length=variant['max_path_length'],
        animated=variant['render'],
    )
    env.log_diagnostics([path])
    logger.dump_tabular()
def train(self):
    for epoch in range(self.num_epochs):
        logger.push_prefix('Iteration #%d | ' % epoch)

        start_time = time.time()
        for _ in range(self.num_steps_per_epoch):
            batch = self.get_batch()
            train_dict = self.get_train_dict(batch)

            self.policy_optimizer.zero_grad()
            policy_loss = train_dict['Policy Loss']
            policy_loss.backward()
            self.policy_optimizer.step()
        logger.log("Train time: {}".format(time.time() - start_time))

        start_time = time.time()
        self.evaluate(epoch)
        logger.log("Eval time: {}".format(time.time() - start_time))

        params = self.get_epoch_snapshot(epoch)
        logger.save_itr_params(epoch, params)
        logger.pop_prefix()
def train(dataset_generator,
          n_start_samples,
          projection=project_samples_square_np,
          n_samples_to_add_per_epoch=1000,
          n_epochs=100,
          z_dim=1,
          hidden_size=32,
          save_period=10,
          append_all_data=True,
          full_variant=None,
          dynamics_noise=0,
          decoder_output_var='learned',
          num_bins=5,
          skew_config=None,
          use_perfect_samples=False,
          use_perfect_density=False,
          vae_reset_period=0,
          vae_kwargs=None,
          use_dataset_generator_first_epoch=True,
          **kwargs):
    """
    Sanitize inputs.
    """
    assert skew_config is not None
    if not (use_perfect_density and use_perfect_samples):
        assert vae_kwargs is not None
    if vae_kwargs is None:
        vae_kwargs = {}

    report = HTMLReport(
        logger.get_snapshot_dir() + '/report.html',
        images_per_row=10,
    )
    dynamics = Dynamics(projection, dynamics_noise)
    if full_variant:
        report.add_header("Variant")
        report.add_text(
            json.dumps(
                ppp.dict_to_safe_json(full_variant, sort=True),
                indent=2,
            ))

    vae, decoder, decoder_opt, encoder, encoder_opt = get_vae(
        decoder_output_var,
        hidden_size,
        z_dim,
        vae_kwargs,
    )
    vae.to(ptu.device)

    epochs = []
    losses = []
    kls = []
    log_probs = []
    hist_heatmap_imgs = []
    vae_heatmap_imgs = []
    sample_imgs = []
    entropies = []
    tvs_to_uniform = []
    entropy_gains_from_reweighting = []
    p_theta = Histogram(num_bins)
    p_new = Histogram(num_bins)

    orig_train_data = dataset_generator(n_start_samples)
    train_data = orig_train_data
    start = time.time()
    for epoch in progressbar(range(n_epochs)):
        p_theta = Histogram(num_bins)
        if epoch == 0 and use_dataset_generator_first_epoch:
            vae_samples = dataset_generator(n_samples_to_add_per_epoch)
        else:
            if use_perfect_samples and epoch != 0:
                # Ideally the VAE = p_new, but in practice, it won't be...
                vae_samples = p_new.sample(n_samples_to_add_per_epoch)
            else:
                vae_samples = vae.sample(n_samples_to_add_per_epoch)
        projected_samples = dynamics(vae_samples)
        if append_all_data:
            train_data = np.vstack((train_data, projected_samples))
        else:
            train_data = np.vstack((orig_train_data, projected_samples))

        p_theta.fit(train_data)
        if use_perfect_density:
            prob = p_theta.compute_density(train_data)
        else:
            prob = vae.compute_density(train_data)
        all_weights = prob_to_weight(prob, skew_config)
        p_new.fit(train_data, weights=all_weights)

        if epoch == 0 or (epoch + 1) % save_period == 0:
            epochs.append(epoch)
            report.add_text("Epoch {}".format(epoch))
            hist_heatmap_img = visualize_histogram(p_theta, skew_config,
                                                   report)
            vae_heatmap_img = visualize_vae(
                vae,
                skew_config,
                report,
                resolution=num_bins,
            )
            sample_img = visualize_vae_samples(
                epoch,
                train_data,
                vae,
                report,
                dynamics,
            )
            visualize_samples(
                p_theta.sample(n_samples_to_add_per_epoch),
                report,
                title="P Theta/RB Samples",
            )
            visualize_samples(
                p_new.sample(n_samples_to_add_per_epoch),
                report,
                title="P Adjusted Samples",
            )
            hist_heatmap_imgs.append(hist_heatmap_img)
            vae_heatmap_imgs.append(vae_heatmap_img)
            sample_imgs.append(sample_img)
            report.save()
            Image.fromarray(hist_heatmap_img).save(
                logger.get_snapshot_dir() +
                '/hist_heatmap{}.png'.format(epoch))
            # Fixed: the original saved the VAE heatmap over the histogram
            # heatmap file ('hist_heatmap{}.png' was used for both).
            Image.fromarray(vae_heatmap_img).save(
                logger.get_snapshot_dir() +
                '/vae_heatmap{}.png'.format(epoch))
            Image.fromarray(sample_img).save(
                logger.get_snapshot_dir() + '/samples{}.png'.format(epoch))

        """
        Train VAE to look like p_new.
        """
        if sum(all_weights) == 0:
            all_weights[:] = 1
        if vae_reset_period > 0 and epoch % vae_reset_period == 0:
            vae, decoder, decoder_opt, encoder, encoder_opt = get_vae(
                decoder_output_var,
                hidden_size,
                z_dim,
                vae_kwargs,
            )
            vae.to(ptu.device)
        vae.fit(train_data, weights=all_weights)
        epoch_stats = vae.get_epoch_stats()

        losses.append(np.mean(epoch_stats['losses']))
        kls.append(np.mean(epoch_stats['kls']))
        log_probs.append(np.mean(epoch_stats['log_probs']))
        entropies.append(p_theta.entropy())
        tvs_to_uniform.append(p_theta.tv_to_uniform())
        entropy_gain = p_new.entropy() - p_theta.entropy()
        entropy_gains_from_reweighting.append(entropy_gain)

        for k in sorted(epoch_stats.keys()):
            logger.record_tabular(k, epoch_stats[k])
        logger.record_tabular("Epoch", epoch)
        logger.record_tabular('Entropy ', p_theta.entropy())
        logger.record_tabular('KL from uniform', p_theta.kl_from_uniform())
        logger.record_tabular('TV to uniform', p_theta.tv_to_uniform())
        logger.record_tabular('Entropy gain from reweight', entropy_gain)
        logger.record_tabular('Total Time (s)', time.time() - start)
        logger.dump_tabular()
        logger.save_itr_params(
            epoch, {
                'vae': vae,
                'train_data': train_data,
                'vae_samples': vae_samples,
                'dynamics': dynamics,
            })

    report.add_header("Training Curves")
    plot_curves(
        [
            ("Training Loss", losses),
            ("KL", kls),
            ("Log Probs", log_probs),
            ("Entropy Gain from Reweighting",
             entropy_gains_from_reweighting),
        ],
        report,
    )
    plot_curves(
        [
            ("Entropy", entropies),
            ("TV to Uniform", tvs_to_uniform),
        ],
        report,
    )
    report.add_text("Max entropy: {}".format(p_theta.max_entropy()))
    report.save()

    for filename, imgs in [
        ("hist_heatmaps", hist_heatmap_imgs),
        ("vae_heatmaps", vae_heatmap_imgs),
        ("samples", sample_imgs),
    ]:
        video = np.stack(imgs)
        vwrite(
            logger.get_snapshot_dir() + '/{}.mp4'.format(filename),
            video,
        )
        local_gif_file_path = '{}.gif'.format(filename)
        gif_file_path = '{}/{}'.format(logger.get_snapshot_dir(),
                                       local_gif_file_path)
        gif(gif_file_path, video)
        report.add_image(local_gif_file_path, txt=filename, is_url=True)

    report.save()
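# The reweighting step above (`all_weights = prob_to_weight(prob,
# skew_config)`) skews the empirical distribution toward rare samples. A
# minimal sketch of what such a function might look like, assuming
# `skew_config` carries a power `alpha` (typically negative, so low-density
# points get more weight); the actual implementation may differ:
def prob_to_weight_sketch(prob, skew_config):
    alpha = skew_config['alpha']  # e.g. -1.0
    weights = prob ** alpha
    return weights / weights.sum()  # normalize to sum to 1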
def train(self):
    '''meta-training loop'''
    self.pretrain()
    params = self.get_epoch_snapshot(-1)
    logger.save_itr_params(-1, params)
    gt.reset()
    gt.set_def_unique(False)
    self._current_path_builder = PathBuilder()
    self.train_obs = self._start_new_rollout()

    # At each iteration, we first collect data from tasks, perform
    # meta-updates, then try to evaluate.
    for it_ in gt.timed_for(
            range(self.num_iterations),
            save_itrs=True,
    ):
        self._start_epoch(it_)
        self.training_mode(True)
        if it_ == 0:
            print('collecting initial pool of data for train and eval')
            # temp for evaluating
            for idx in self.train_tasks:
                print('train task', idx)
                self.task_idx = idx
                self.env.reset_task(idx)
                self.collect_data_sampling_from_prior(
                    num_samples=self.max_path_length * 10,
                    resample_z_every_n=self.max_path_length,
                    eval_task=False)
            """
            for idx in self.eval_tasks:
                self.task_idx = idx
                self.env.reset_task(idx)
                # TODO: make number of initial trajectories a parameter
                self.collect_data_sampling_from_prior(
                    num_samples=self.max_path_length * 20,
                    resample_z_every_n=self.max_path_length,
                    eval_task=True)
            """
        # Sample data from train tasks.
        for i in range(self.num_tasks_sample):
            idx = np.random.randint(len(self.train_tasks))
            self.task_idx = idx
            self.env.reset_task(idx)

            # TODO: there may be more permutations of sampling/adding to
            # the encoding buffer we may wish to try.
            if self.train_embedding_source == 'initial_pool':
                # Embeddings are computed using only the initial pool of
                # data; sample data from the posterior to train the RL
                # algorithm.
                self.collect_data_from_task_posterior(
                    idx=idx,
                    num_samples=self.num_steps_per_task,
                    add_to_enc_buffer=False)
            elif self.train_embedding_source == 'posterior_only':
                self.collect_data_from_task_posterior(
                    idx=idx,
                    num_samples=self.num_steps_per_task,
                    eval_task=False,
                    add_to_enc_buffer=True)
            elif self.train_embedding_source == 'online_exploration_trajectories':
                # Embeddings are computed using only data collected with
                # the prior; sample data from the posterior to train the
                # RL algorithm.
                self.enc_replay_buffer.task_buffers[idx].clear()
                # Resample using the current policy, conditioned on the
                # prior.
                self.collect_data_sampling_from_prior(
                    num_samples=self.num_steps_per_task,
                    resample_z_every_n=self.max_path_length,
                    add_to_enc_buffer=True)
                self.env.reset_task(idx)
                self.collect_data_from_task_posterior(
                    idx=idx,
                    num_samples=self.num_steps_per_task,
                    add_to_enc_buffer=False,
                    viz=True)
            elif self.train_embedding_source == 'online_on_policy_trajectories':
                # Sample from the prior, then sample more from the
                # posterior; embeddings are computed from both prior and
                # posterior data.
                self.enc_replay_buffer.task_buffers[idx].clear()
                self.collect_data_online(
                    idx=idx,
                    num_samples=self.num_steps_per_task,
                    add_to_enc_buffer=True)
            else:
                raise Exception(
                    "Invalid option for computing train embedding {}".format(
                        self.train_embedding_source))

        # Sample train tasks and compute gradient updates on parameters.
        for train_step in range(self.num_train_steps_per_itr):
            indices = np.random.choice(self.train_tasks, self.meta_batch)
            self._do_training(indices, train_step)
            self._n_train_steps_total += 1
        gt.stamp('train')

        # self.training_mode(False)

        # eval
        self._try_to_eval(it_)
        gt.stamp('eval')

        self._end_epoch()
def train_pixelcnn(
        vqvae=None,
        vqvae_path=None,
        num_epochs=100,
        batch_size=32,
        n_layers=15,
        dataset_path=None,
        save=True,
        save_period=10,
        cached_dataset_path=False,
        trainer_kwargs=None,
        model_kwargs=None,
        data_filter_fn=lambda x: x,
        debug=False,
        data_size=float('inf'),
        num_train_batches_per_epoch=None,
        num_test_batches_per_epoch=None,
        train_img_loader=None,
        test_img_loader=None,
):
    trainer_kwargs = {} if trainer_kwargs is None else trainer_kwargs
    model_kwargs = {} if model_kwargs is None else model_kwargs

    # Load the VQ-VAE and derive dimensions from it.
    if vqvae is None:
        vqvae = load_local_or_remote_file(vqvae_path)
        vqvae.to(ptu.device)
        vqvae.eval()

    root_len = vqvae.root_len
    num_embeddings = vqvae.num_embeddings
    embedding_dim = vqvae.embedding_dim
    cond_size = vqvae.num_embeddings
    imsize = vqvae.imsize
    discrete_size = root_len * root_len
    representation_size = embedding_dim * discrete_size
    input_channels = vqvae.input_channels
    imlength = imsize * imsize * input_channels

    log_dir = logger.get_snapshot_dir()

    # Define data-loading info.
    new_path = osp.join(log_dir, 'pixelcnn_data.npy')

    def prep_sample_data(cached_path):
        data = load_local_or_remote_file(cached_path).item()
        train_data = data['train']
        test_data = data['test']
        return train_data, test_data

    def encode_dataset(path, object_list):
        data = load_local_or_remote_file(path)
        data = data.item()
        data = data_filter_fn(data)

        all_data = []
        n = min(data["observations"].shape[0], data_size)
        for i in tqdm(range(n)):
            obs = ptu.from_numpy(data["observations"][i] / 255.0)
            latent = vqvae.encode(obs, cont=False)
            all_data.append(latent)
        encodings = ptu.get_numpy(torch.stack(all_data, dim=0))
        return encodings

    if train_img_loader:
        _, test_loader, test_batch_loader = create_conditional_data_loader(
            test_img_loader, 80, vqvae, "test2")  # 80
        _, train_loader, train_batch_loader = create_conditional_data_loader(
            train_img_loader, 2000, vqvae, "train2")  # 2000
    else:
        if cached_dataset_path:
            train_data, test_data = prep_sample_data(cached_dataset_path)
        else:
            train_data = encode_dataset(dataset_path['train'],
                                        None)  # object_list)
            test_data = encode_dataset(dataset_path['test'], None)
        dataset = {'train': train_data, 'test': test_data}
        np.save(new_path, dataset)

        _, _, train_loader, test_loader, _ = \
            rlkit.torch.vae.pixelcnn_utils.load_data_and_data_loaders(
                new_path, 'COND_LATENT_BLOCK', batch_size)

    # train_dataset = InfiniteBatchLoader(train_loader)
    # test_dataset = InfiniteBatchLoader(test_loader)
    print("Finished loading data")

    model = GatedPixelCNN(num_embeddings,
                          root_len ** 2,
                          n_classes=representation_size,
                          **model_kwargs).to(ptu.device)
    trainer = PixelCNNTrainer(
        model,
        vqvae,
        batch_size=batch_size,
        **trainer_kwargs,
    )

    print("Starting training")
    # The original initialized the best loss to 999, which would trigger an
    # immediate early return if the first test loss exceeded it.
    best_loss = float('inf')
    for epoch in range(num_epochs):
        should_save = (epoch % save_period == 0) and (epoch > 0)
        trainer.train_epoch(epoch, train_loader, num_train_batches_per_epoch)
        trainer.test_epoch(epoch, test_loader, num_test_batches_per_epoch)

        # Note: these batch loaders only exist when `train_img_loader` is
        # given; `bz` in the original was undefined and is assumed to mean
        # `batch_size`.
        test_data = test_batch_loader.random_batch(batch_size)["x"]
        train_data = train_batch_loader.random_batch(batch_size)["x"]
        trainer.dump_samples(epoch, test_data, test=True)
        trainer.dump_samples(epoch, train_data, test=False)

        if should_save:
            logger.save_itr_params(epoch, model)

        stats = trainer.get_diagnostics()
        cur_loss = stats["test/loss"]
        if cur_loss < best_loss:
            best_loss = cur_loss
            vqvae.set_pixel_cnn(model)
            logger.save_extra_data(vqvae, 'best_vqvae', mode='torch')
        else:
            # Stop as soon as the test loss worsens.
            return vqvae

        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

    return vqvae
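# A hypothetical invocation of train_pixelcnn() above, training from a saved
# VQ-VAE and pre-collected datasets; every path here is an illustrative
# placeholder, not a real file:
# vqvae = train_pixelcnn(
#     vqvae_path='path/to/vqvae.pt',
#     num_epochs=100,
#     batch_size=32,
#     dataset_path={'train': 'path/to/train_data.npy',
#                   'test': 'path/to/test_data.npy'},
# )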
def train(self):
    '''meta-training loop'''
    self.pretrain()
    params = self.get_epoch_snapshot(-1)
    logger.save_itr_params(-1, params)
    gt.reset()
    gt.set_def_unique(False)
    self._current_path_builder = PathBuilder()

    # At each iteration, we first collect data from tasks, perform
    # meta-updates, then try to evaluate.
    for it_ in gt.timed_for(
            range(self.num_iterations),
            save_itrs=True,
    ):
        self._start_epoch(it_)
        self.training_mode(True)
        print("\nIteration:{}".format(it_ + 1))
        if it_ == 0:
            # First step of the algorithm: initialize each task's buffer.
            print('\nCollecting initial pool of data for train and eval')
            # Before training starts, collect 2000 transitions for each task.
            for idx in self.train_tasks:
                self.task_idx = idx  # switch the current task index
                self.env.reset_task(idx)  # reset the task
                # Collect num_initial_steps of context c and update self.z
                # using q(z|c).
                self.collect_data(self.num_initial_steps, 1, np.inf)
            # print("task id:", self.task_idx, " env:", self.replay_buffer.env)
            # print("buffer ", self.task_idx, ":",
            #       self.replay_buffer.task_buffers[self.task_idx].__dict__.items())
            print("\nFinishing collecting initial pool of data")

        # Sample data from train tasks: randomly pick 5 of the train_tasks,
        # then collect num_steps_prior + num_extra_rl_steps_posterior
        # transitions into each picked task's buffer.
        print("\nSampling data from train tasks for Meta-training")
        for i in range(self.num_tasks_sample):
            # Collect num_steps_prior transitions for each task's
            # encoder buffer.
            print("\nSample data, round {}".format(i + 1))
            # Pick a random task from train_tasks.
            idx = np.random.randint(len(self.train_tasks))
            self.task_idx = idx
            self.env.reset_task(idx)  # reset the task
            # Clear the corresponding encoder buffer.
            self.enc_replay_buffer.task_buffers[idx].clear()

            # Collect some trajectories with z ~ prior.
            if self.num_steps_prior > 0:
                print("\ncollect some trajectories with z ~ prior")
                # Collect num_steps_prior transitions using the prior over z.
                self.collect_data(self.num_steps_prior, 1, np.inf)
            # Collect some trajectories with z ~ posterior.
            if self.num_steps_posterior > 0:
                print("\ncollect some trajectories with z ~ posterior")
                # Collect trajectories using the posterior z.
                self.collect_data(self.num_steps_posterior, 1,
                                  self.update_post_train)
            # Even if the encoder is trained only on samples from the prior,
            # the policy needs to learn to handle z ~ posterior.
            if self.num_extra_rl_steps_posterior > 0:
                print("\ncollect some trajectories for policy update only")
                # Collect num_extra_rl_steps_posterior transitions with the
                # posterior z, used only for the policy.
                self.collect_data(self.num_extra_rl_steps_posterior, 1,
                                  self.update_post_train,
                                  add_to_enc_buffer=False)
        print("\nFinishing sample data from train tasks")

        # Sample train tasks and compute gradient updates on parameters:
        # num_train_steps_per_itr gradient steps per iteration
        # (500 x 2000 = 1,000,000 in total).
        print("\nStarting Meta-training, Episode {}".format(it_))
        for train_step in range(self.num_train_steps_per_itr):
            # Randomly choose meta_batch tasks from train_tasks; sample an
            # RL batch b ~ B.
            indices = np.random.choice(self.train_tasks, self.meta_batch)
            if (train_step + 1) % 500 == 0:
                print("\nTraining step {}".format(train_step + 1))
                print("Indices: {}".format(indices))
                print("alpha:{}".format(self.alpha))
            self._do_training(indices)  # gradient step
            self._n_train_steps_total += 1
        gt.stamp('train')

        self.training_mode(False)

        # eval
        self._try_to_eval(it_)
        gt.stamp('eval')

        self._end_epoch()
def test_epoch(
        self,
        epoch,
        sample_batch=None,
        key=None,
        save_reconstruction=True,
        save_vae=True,
        from_rl=False,
        save_prefix='r',
        only_train_vae=False,
):
    self.model.eval()
    losses = []
    log_probs = []
    triplet_losses = []
    matching_losses = []
    vae_matching_losses = []
    kles = []
    lstm_kles = []
    ae_losses = []
    contrastive_losses = []
    beta = float(self.beta_schedule.get_value(epoch))
    for batch_idx in range(10):
        if sample_batch is not None:
            data = sample_batch(self.batch_size, key=key)
            next_obs = data['next_obs']
        else:
            next_obs = self.get_batch(epoch=epoch)

        reconstructions, obs_distribution_params, \
            vae_latent_distribution_params, lstm_latent_encodings = \
            self.model(next_obs)
        latent_encodings = lstm_latent_encodings
        vae_mu = vae_latent_distribution_params[0]  # these are the LSTM inputs
        latent_distribution_params = vae_latent_distribution_params

        triplet_loss = ptu.zeros(1)
        for tri_idx, triplet_type in enumerate(self.triplet_loss_type):
            if triplet_type == 1 and not only_train_vae:
                triplet_loss += self.triplet_loss_coef[tri_idx] * \
                    self.triplet_loss(latent_encodings)
            elif triplet_type == 2 and not only_train_vae:
                triplet_loss += self.triplet_loss_coef[tri_idx] * \
                    self.triplet_loss_2(next_obs)
            elif triplet_type == 3 and not only_train_vae:
                triplet_loss += self.triplet_loss_coef[tri_idx] * \
                    self.triplet_loss_3(next_obs)

        if self.matching_loss_coef > 0 and not only_train_vae:
            matching_loss = self.matching_loss(next_obs)
        else:
            matching_loss = ptu.zeros(1)

        if self.vae_matching_loss_coef > 0:
            matching_loss_vae = self.matching_loss_vae(next_obs)
        else:
            matching_loss_vae = ptu.zeros(1)

        if self.contrastive_loss_coef > 0 and not only_train_vae:
            contrastive_loss = self.contrastive_loss(next_obs)
        else:
            contrastive_loss = ptu.zeros(1)

        log_prob = self.model.logprob(next_obs, obs_distribution_params)
        kle = self.model.kl_divergence(latent_distribution_params)
        lstm_kle = ptu.zeros(1)

        ae_loss = F.mse_loss(
            latent_encodings.view((-1, self.model.representation_size)),
            vae_mu.detach())
        ae_losses.append(ae_loss.item())

        loss = -self.recon_loss_coef * log_prob + beta * kle + \
            self.matching_loss_coef * matching_loss + \
            self.ae_loss_coef * ae_loss + triplet_loss + \
            self.vae_matching_loss_coef * matching_loss_vae + \
            self.contrastive_loss_coef * contrastive_loss

        losses.append(loss.item())
        log_probs.append(log_prob.item())
        triplet_losses.append(triplet_loss.item())
        matching_losses.append(matching_loss.item())
        vae_matching_losses.append(matching_loss_vae.item())
        kles.append(kle.item())
        lstm_kles.append(lstm_kle.item())
        contrastive_losses.append(contrastive_loss.item())

        if batch_idx == 0 and save_reconstruction:
            seq_len, batch_size, feature_size = next_obs.shape
            show_obs = next_obs[0][:8]
            reconstructions = reconstructions.view(
                (seq_len, batch_size, feature_size))[0][:8]
            comparison = torch.cat([
                show_obs.narrow(start=0, length=self.imlength,
                                dim=1).contiguous().view(
                    -1, self.input_channels, self.imsize,
                    self.imsize).transpose(2, 3),
                reconstructions.view(
                    -1,
                    self.input_channels,
                    self.imsize,
                    self.imsize,
                ).transpose(2, 3)
            ])
            save_dir = osp.join(logger.get_snapshot_dir(),
                                '{}{}.png'.format(save_prefix, epoch))
            save_image(comparison.data.cpu(), save_dir, nrow=8)

    self.eval_statistics['epoch'] = epoch
    self.eval_statistics['test/log prob'] = np.mean(log_probs)
    self.eval_statistics['test/triplet loss'] = np.mean(triplet_losses)
    self.eval_statistics['test/vae matching loss'] = np.mean(
        vae_matching_losses)
    self.eval_statistics['test/matching loss'] = np.mean(matching_losses)
    self.eval_statistics['test/KL'] = np.mean(kles)
    self.eval_statistics['test/lstm KL'] = np.mean(lstm_kles)
    self.eval_statistics['test/loss'] = np.mean(losses)
    self.eval_statistics['test/contrastive loss'] = np.mean(
        contrastive_losses)
    self.eval_statistics['beta'] = beta
    self.eval_statistics['test/ae loss'] = np.mean(ae_losses)

    if not from_rl:
        for k, v in self.eval_statistics.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        if save_vae:
            logger.save_itr_params(epoch, self.model)

    torch.cuda.empty_cache()