def test_epoch(self, epoch, save_network=True, batches=100):
    """Evaluate the reprojection model on held-out batches and log results.

    Runs ``batches`` test batches, records mean MSE/loss under the
    ``test/`` prefix, and optionally snapshots the network.

    :param epoch: Epoch index used for logging and the snapshot name.
    :param save_network: When True, save the model via the logger.
    :param batches: Number of evaluation batches to average over.
    """
    self.model.eval()
    mse_values = []
    loss_values = []
    for _ in range(batches):
        batch = self.get_batch(train=False)
        prediction = self.model(batch['z'])
        batch_mse = self.mse_loss(prediction, batch['z_proj'])
        # The loss is currently just the MSE term.
        mse_values.append(batch_mse.data[0])
        loss_values.append(batch_mse.data[0])
    logger.record_tabular("test/epoch", epoch)
    logger.record_tabular("test/MSE", np.mean(mse_values))
    logger.record_tabular("test/loss", np.mean(loss_values))
    logger.dump_tabular()
    if save_network:
        logger.save_itr_params(epoch, self.model, prefix='reproj',
                               save_anyway=True)
def train(self, start_epoch=0):
    """Run the full training loop, dispatching on ``self.collection_mode``.

    :param start_epoch: Epoch to resume from. When 0, a pre-training
        snapshot is saved under iteration -1.
    :raises TypeError: If ``self.collection_mode`` is not one of
        'online', 'online-parallel', 'batch', or 'offline'.
    """
    self.pretrain()
    if start_epoch == 0:
        params = self.get_epoch_snapshot(-1)
        logger.save_itr_params(-1, params)
    self.training_mode(False)
    self._n_env_steps_total = start_epoch * self.num_env_steps_per_epoch
    gt.reset()
    gt.set_def_unique(False)
    if self.collection_mode == 'online':
        self.train_online(start_epoch=start_epoch)
    elif self.collection_mode == 'online-parallel':
        try:
            self.train_parallel(start_epoch=start_epoch)
        # FIX: was a bare ``except:`` that also swallowed KeyboardInterrupt
        # and SystemExit. Ordinary errors are still printed and the parallel
        # workers shut down; interrupts now shut the workers down and then
        # propagate instead of being silently discarded.
        except Exception:
            import traceback
            traceback.print_exc()
            self.parallel_env.shutdown()
        except BaseException:
            self.parallel_env.shutdown()
            raise
    elif self.collection_mode == 'batch':
        self.train_batch(start_epoch=start_epoch)
    elif self.collection_mode == 'offline':
        self.train_offline(start_epoch=start_epoch)
    else:
        raise TypeError("Invalid collection_mode: {}".format(
            self.collection_mode
        ))
    self.cleanup()
def train_vae(variant, return_data=False):
    """Build a VAE from ``variant``, train it for ``num_epochs``, and return it.

    :param variant: Dict of hyperparameters. Required keys: ``beta``,
        ``generate_vae_dataset_kwargs``, ``algo_kwargs``, ``save_period``,
        ``num_epochs``. Optional: ``use_linear_dynamics``,
        ``generate_vae_data_fctn``, ``beta_schedule_kwargs``,
        ``vae_trainer_class``, ``dump_skew_debug_plots``.
    :param return_data: When True, also return the train/test datasets.
    :return: ``model`` or ``(model, train_dataset, test_dataset)``.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    beta = variant["beta"]
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    # Propagate the dynamics flag into the dataset generator's kwargs.
    variant['generate_vae_dataset_kwargs'][
        'use_linear_dynamics'] = use_linear_dynamics
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    if use_linear_dynamics:
        # Action dimensionality comes from the dataset's action array.
        action_dim = train_dataset.data['actions'].shape[2]
    else:
        action_dim = 0
    model = get_vae(variant, action_dim)
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)
        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)
        # Periodic model snapshot (every 50 epochs, including epoch 0).
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
def test_epoch(self, epoch, save_vae=True, **kwargs):
    """Evaluate the VAE on 100 held-out batches and log diagnostics.

    Records per-key reconstruction losses, KL, total loss, beta, RAM
    usage, and the number of "active" latent dimensions; also refreshes
    ``self.model.dist_mu`` / ``dist_std`` from the encoded test latents.

    :param epoch: Epoch index, used for logging and snapshots.
    :param save_vae: When True, save the model via the logger (slow).
    :raises ValueError: If ``self.recon_loss_type`` is not 'mse' or 'wse'.
    """
    self.model.eval()
    losses = []
    kles = []
    zs = []
    recon_logging_dict = {
        'MSE': [],
        'WSE': [],
    }
    for k in self.extra_recon_logging:
        recon_logging_dict[k] = []
    beta = self.beta_schedule.get_value(epoch)
    for batch_idx in range(100):
        data = self.get_batch(train=False)
        next_obs = data['next_obs']
        recon_batch, mu, logvar = self.model(next_obs)
        mse = self.logprob(recon_batch, next_obs)
        wse = self.logprob(recon_batch, next_obs,
                           unorm_weights=self.recon_weights)
        for k, idx in self.extra_recon_logging.items():
            recon_loss = self.logprob(recon_batch, next_obs, idx=idx)
            recon_logging_dict[k].append(recon_loss.data[0])
        kle = self.kl_divergence(mu, logvar)
        if self.recon_loss_type == 'mse':
            loss = mse + beta * kle
        elif self.recon_loss_type == 'wse':
            loss = wse + beta * kle
        else:
            # FIX: previously ``loss`` was silently left unbound here,
            # producing a confusing NameError on the append below.
            raise ValueError(
                "Invalid recon_loss_type: {}".format(self.recon_loss_type))
        z_data = ptu.get_numpy(mu.cpu())
        for i in range(len(z_data)):
            zs.append(z_data[i, :])
        losses.append(loss.data[0])
        recon_logging_dict['WSE'].append(wse.data[0])
        recon_logging_dict['MSE'].append(mse.data[0])
        kles.append(kle.data[0])
    zs = np.array(zs)
    # Latent statistics are reused elsewhere (e.g. for sampling priors).
    self.model.dist_mu = zs.mean(axis=0)
    self.model.dist_std = zs.std(axis=0)
    for k in recon_logging_dict:
        logger.record_tabular("/".join(["test", k]),
                              np.mean(recon_logging_dict[k]))
    logger.record_tabular("test/KL", np.mean(kles))
    logger.record_tabular("test/loss", np.mean(losses))
    logger.record_tabular("beta", beta)
    process = psutil.Process(os.getpid())
    logger.record_tabular("RAM Usage (Mb)",
                          int(process.memory_info().rss / 1000000))
    # A latent dimension is considered "active" if its std exceeds 0.15.
    num_active_dims = 0
    for std in self.model.dist_std:
        if std > 0.15:
            num_active_dims += 1
    logger.record_tabular("num_active_dims", num_active_dims)
    logger.dump_tabular()
    if save_vae:
        logger.save_itr_params(epoch, self.model, prefix='vae',
                               save_anyway=True)  # slow...
def train_offline(self, start_epoch=0):
    """Train purely from pre-collected data, evaluating every epoch.

    :param start_epoch: First epoch index to run.
    """
    self.training_mode(False)
    # Save an initial snapshot under iteration -1 before any training.
    logger.save_itr_params(-1, self.get_epoch_snapshot(-1))
    epoch = start_epoch
    while epoch < self.num_epochs:
        self._start_epoch(epoch)
        self._try_to_train()
        self._try_to_offline_eval(epoch)
        self._end_epoch()
        epoch += 1
def _try_to_eval(self, epoch, eval_paths=None):
    """Save a snapshot and, if allowed, evaluate and dump epoch diagnostics.

    Extra data and the epoch snapshot are always saved; evaluation and
    tabular logging only happen when ``self._can_evaluate()`` is True.

    :param epoch: Current epoch index.
    :param eval_paths: Optional pre-collected rollouts forwarded to
        ``self.evaluate``.
    """
    logger.save_extra_data(self.get_extra_data_to_save(epoch))
    params = self.get_epoch_snapshot(epoch)
    logger.save_itr_params(epoch, params)
    if self._can_evaluate():
        self.evaluate(epoch, eval_paths=eval_paths)
        # params = self.get_epoch_snapshot(epoch)
        # logger.save_itr_params(epoch, params)
        # The CSV logger requires a stable column set across iterations.
        table_keys = logger.get_table_key_set()
        if self._old_table_keys is not None:
            assert table_keys == self._old_table_keys, (
                "Table keys cannot change from iteration to iteration."
            )
        self._old_table_keys = table_keys
        logger.record_tabular(
            "Number of train steps total",
            self._n_train_steps_total,
        )
        logger.record_tabular(
            "Number of env steps total",
            self._n_env_steps_total,
        )
        logger.record_tabular(
            "Number of rollouts total",
            self._n_rollouts_total,
        )
        if self.collection_mode != 'online-parallel':
            # Per-iteration timings come from gtimer stamps recorded
            # elsewhere in the training loop.
            times_itrs = gt.get_times().stamps.itrs
            train_time = times_itrs['train'][-1]
            sample_time = times_itrs['sample'][-1]
            if 'eval' in times_itrs:
                # Epoch 0 has no previous eval, hence the -1 sentinel.
                eval_time = times_itrs['eval'][-1] if epoch > 0 else -1
            else:
                eval_time = -1
            epoch_time = train_time + sample_time + eval_time
            total_time = gt.get_times().total
            logger.record_tabular('Train Time (s)', train_time)
            logger.record_tabular('(Previous) Eval Time (s)', eval_time)
            logger.record_tabular('Sample Time (s)', sample_time)
            logger.record_tabular('Epoch Time (s)', epoch_time)
            logger.record_tabular('Total Train Time (s)', total_time)
        else:
            # Parallel mode: gtimer stamps are unavailable, use wall clock.
            logger.record_tabular('Epoch Time (s)',
                                  time.time() - self._epoch_start_time)
        logger.record_tabular("Epoch", epoch)
        logger.dump_tabular(with_prefix=False, with_timestamp=False)
    else:
        logger.log("Skipping eval for now.")
def train(self):
    """Run ``num_epochs`` cycles of train -> log -> snapshot."""
    timer.return_global_times = True
    remaining = self.num_epochs
    while remaining > 0:
        self._begin_epoch()
        diagnostics, _ = self._train()
        logger.record_dict(diagnostics)
        logger.dump_tabular(with_prefix=True, with_timestamp=False)
        # Snapshot after logging so the saved params match the logged epoch.
        logger.save_itr_params(self.epoch, self._get_snapshot())
        self._end_epoch()
        remaining -= 1
def _try_to_offline_eval(self, epoch):
    """Run offline evaluation, snapshot params, and dump the tabular log.

    :param epoch: Current epoch index.
    """
    eval_start = time.time()
    logger.save_extra_data(self.get_extra_data_to_save(epoch))
    self.offline_evaluate(epoch)
    logger.save_itr_params(epoch, self.get_epoch_snapshot(epoch))
    # The CSV logger requires a stable column set across iterations.
    current_keys = logger.get_table_key_set()
    assert (self._old_table_keys is None
            or current_keys == self._old_table_keys), (
        "Table keys cannot change from iteration to iteration.")
    self._old_table_keys = current_keys
    logger.dump_tabular(with_prefix=False, with_timestamp=False)
    logger.log("Eval Time: {0}".format(time.time() - eval_start))
def train(self):
    """Fit the Q-function on a fixed dataset, evaluating once per epoch."""
    self.fix_data_set()
    logger.log("Done creating dataset.")
    batches_seen = 0
    for epoch in range(self.num_epochs):
        # Gradient phase: Q-function in training mode.
        for _ in range(self.num_batches_per_epoch):
            self.qf.train(True)
            self._do_training()
            batches_seen += 1
        logger.push_prefix('Iteration #%d | ' % epoch)
        # Evaluation phase: Q-function in eval mode.
        self.qf.train(False)
        self.evaluate(epoch)
        logger.save_itr_params(epoch, self.get_epoch_snapshot(epoch))
        logger.log("Done evaluating")
        logger.pop_prefix()
def create_policy(variant):
    """Combine two saved NAF policies into one policy and roll it out once.

    Loads the 'bottom' and 'column' snapshots, builds the combined policy
    via ``variant['combiner_class']``, saves it as iteration 0, then runs
    a single rollout and logs the environment diagnostics.

    :param variant: Dict with 'bottom_path', 'column_path',
        'combiner_class', 'max_path_length', and 'render'.
    """
    bottom = joblib.load(variant['bottom_path'])
    column = joblib.load(variant['column_path'])
    combined_policy = variant['combiner_class'](
        policy1=bottom['naf_policy'],
        policy2=column['naf_policy'],
    )
    env = bottom['env']
    # Persist the combined policy with its environment as iteration 0.
    logger.save_itr_params(0, dict(
        policy=combined_policy,
        env=env,
    ))
    sampled_path = rollout(
        env,
        combined_policy,
        max_path_length=variant['max_path_length'],
        animated=variant['render'],
    )
    env.log_diagnostics([sampled_path])
    logger.dump_tabular()
def train(self):
    """Optimize the policy for ``num_epochs`` epochs, evaluating after each."""
    for epoch in range(self.num_epochs):
        logger.push_prefix('Iteration #%d | ' % epoch)
        phase_start = time.time()
        for _ in range(self.num_steps_per_epoch):
            minibatch = self.get_batch()
            losses = self.get_train_dict(minibatch)
            # One gradient step on the policy loss.
            self.policy_optimizer.zero_grad()
            losses['Policy Loss'].backward()
            self.policy_optimizer.step()
        logger.log("Train time: {}".format(time.time() - phase_start))
        phase_start = time.time()
        self.evaluate(epoch)
        logger.log("Eval time: {}".format(time.time() - phase_start))
        logger.save_itr_params(epoch, self.get_epoch_snapshot(epoch))
        logger.pop_prefix()
def train_rfeatures_model(variant, return_data=False):
    """Train a timestep-prediction ("rfeatures") model and return it.

    Builds the datasets/loaders, model, and trainer from ``variant``,
    trains for ``variant['num_epochs']`` epochs, periodically dumping
    reconstructions and trajectory-reward plots, and pickles the final
    model as ``vae.pkl``.

    :param variant: Dict of hyperparameters. Required keys:
        ``output_classes``, ``representation_size``, ``batch_size``,
        ``dataset_kwargs``, ``model_kwargs``, ``trainer_kwargs``,
        ``save_period``, ``num_epochs``.
    :param return_data: When True, also return the train/test datasets.
    :return: ``model`` or ``(model, train_dataset, test_dataset)``.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    # from railrl.torch.vae.conv_vae import (
    #     ConvVAE, ConvResnetVAE
    # )
    import railrl.torch.vae.conv_vae as conv_vae
    # from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.launchers.experiments.ashvin.rfeatures.rfeatures_model import TimestepPredictionModel
    from railrl.launchers.experiments.ashvin.rfeatures.rfeatures_trainer import TimePredictionTrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch
    output_classes = variant["output_classes"]
    representation_size = variant["representation_size"]
    batch_size = variant["batch_size"]
    variant['dataset_kwargs']["output_classes"] = output_classes
    train_dataset, test_dataset, info = get_data(variant['dataset_kwargs'])
    num_train_workers = variant.get("num_train_workers", 0)  # 0 uses main process (good for pdb)
    train_dataset_loader = InfiniteBatchLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_train_workers,
    )
    test_dataset_loader = InfiniteBatchLoader(
        test_dataset,
        batch_size=batch_size,
    )
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    # Fall back to a default conv architecture keyed on image size when
    # the variant does not specify one explicitly.
    architecture = variant['model_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['model_kwargs']['architecture'] = architecture
    model_class = variant.get('model_class', TimestepPredictionModel)
    model = model_class(
        representation_size,
        decoder_output_activation=decoder_activation,
        output_classes=output_classes,
        **variant['model_kwargs'],
    )
    # model = torch.nn.DataParallel(model)
    model.to(ptu.device)
    variant['trainer_kwargs']['batch_size'] = batch_size
    trainer_class = variant.get('trainer_class', TimePredictionTrainer)
    trainer = trainer_class(
        model,
        **variant['trainer_kwargs'],
    )
    save_period = variant['save_period']
    # Baseline reward plots before any training.
    trainer.dump_trajectory_rewards(
        "initial",
        dict(train=train_dataset.dataset, test=test_dataset.dataset))
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset_loader, batches=10)
        trainer.test_epoch(epoch, test_dataset_loader, batches=1)
        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
        trainer.dump_trajectory_rewards(
            epoch,
            dict(train=train_dataset.dataset, test=test_dataset.dataset),
            should_save_imgs)
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)
        # Periodic model snapshot (every 50 epochs, including epoch 0).
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
def get_n_train_vae(latent_dim,
                    env,
                    vae_train_epochs,
                    num_image_examples,
                    vae_kwargs,
                    vae_trainer_kwargs,
                    vae_architecture,
                    vae_save_period=10,
                    vae_test_p=.9,
                    decoder_activation='sigmoid',
                    vae_class='VAE',
                    **kwargs):
    """Sample image goals from ``env``, train a VAE on them, and return it.

    Tabular output is temporarily redirected to ``vae_progress.csv`` for
    the duration of VAE training, then restored to ``progress.csv``.

    :param latent_dim: Size of the VAE latent space.
    :param env: Environment used to sample goal images.
    :param vae_train_epochs: Number of training epochs.
    :param num_image_examples: Total number of goal images to sample.
    :param vae_kwargs: Extra kwargs for the VAE constructor.
    :param vae_trainer_kwargs: Kwargs for ``ConvVAETrainer``.
    :param vae_architecture: Architecture spec passed to the VAE.
    :param vae_save_period: Epoch interval for dumping images.
    :param vae_test_p: Fraction of examples used for training
        (the remainder is the test split).
    :param decoder_activation: 'sigmoid' for a Sigmoid output activation.
    :param vae_class: 'VAE' or 'SpatialVAE' (case-insensitive).
    :raises RuntimeError: If ``vae_class`` is not a recognized name.
    :return: The trained VAE.
    """
    env.goal_sampling_mode = 'test'
    image_examples = unnormalize_image(
        env.sample_goals(num_image_examples)['desired_goal'])
    # Despite the name, vae_test_p is the *train* fraction here.
    n = int(num_image_examples * vae_test_p)
    train_dataset = ImageObservationDataset(image_examples[:n, :])
    test_dataset = ImageObservationDataset(image_examples[n:, :])
    if decoder_activation == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    # Resolve the class name (case-insensitively) to an actual class.
    vae_class = vae_class.lower()
    if vae_class == 'VAE'.lower():
        vae_class = ConvVAE
    elif vae_class == 'SpatialVAE'.lower():
        vae_class = SpatialAutoEncoder
    else:
        raise RuntimeError("Invalid VAE Class: {}".format(vae_class))
    vae = vae_class(latent_dim,
                    architecture=vae_architecture,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)
    trainer = ConvVAETrainer(vae, **vae_trainer_kwargs)
    # Log VAE pretraining to its own CSV instead of the main progress file.
    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv', relative_to_snapshot_dir=True)
    for epoch in range(vae_train_epochs):
        should_save_imgs = (epoch % vae_save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)
        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)
        # Periodic model snapshot (every 50 epochs, including epoch 0).
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, vae)
    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    # Restore the main progress CSV as the tabular output.
    logger.remove_tabular_output('vae_progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    return vae
def test_epoch(self,
               epoch,
               save_reconstruction=True,
               save_interpolation=True,
               save_vae=True):
    """Evaluate the IWAE/VAE model on held-out batches and log diagnostics.

    Computes VAE and IWAE losses over 10 test batches (with 25 importance
    samples each), optionally saves reconstruction and latent-interpolation
    image grids for the first batch, refreshes ``self.model.dist_mu`` /
    ``dist_std`` from the encoded latents, and dumps tabular stats.
    Intermediate tensors are aggressively detached/deleted to keep memory
    bounded during evaluation.

    :param epoch: Epoch index used for logging and output filenames.
    :param save_reconstruction: Save a reconstruction grid ``r_<epoch>.png``.
    :param save_interpolation: Save an interpolation grid ``i_<epoch>.png``.
    :param save_vae: Save the model via the logger (slow).
    """
    self.model.eval()
    vae_losses = []
    iwae_losses = []
    losses = []
    des = []
    kles = []
    linear_losses = []
    zs = []
    beta = float(self.beta_schedule.get_value(epoch))
    for batch_idx in range(10):
        data = self.get_batch(train=False)
        obs = data['obs']
        obs = obs.detach()
        next_obs = data['next_obs']
        next_obs = next_obs.detach()
        actions = data['actions']
        actions = actions.detach()
        # n_imp=25: number of importance samples per datapoint.
        x_recon, z_mu, z_logvar, z = self.model(next_obs, n_imp=25)
        x_recon = x_recon.detach()
        z_mu = z_mu.detach()
        z = z.detach()
        batch_size = x_recon.shape[0]
        k = x_recon.shape[1]
        # Tile the target so it matches the k importance samples.
        x = next_obs.view((batch_size, 1, -1)).repeat(torch.Size([1, k, 1]))
        x = x.detach()
        vae_loss, de, kle = self.compute_vae_loss(x_recon, x, z_mu,
                                                  z_logvar, z, beta)
        vae_loss, de, kle = vae_loss.detach(), de.detach(), kle.detach()
        iwae_loss = self.compute_iwae_loss(x_recon, x, z_mu, z_logvar, z,
                                           beta)
        iwae_loss = iwae_loss.detach()
        loss = vae_loss
        if self.use_linear_dynamics:
            linear_dynamics_loss = self.state_linearity_loss(
                obs, next_obs, actions)
            linear_dynamics_loss = linear_dynamics_loss.detach()
            loss += self.linearity_weight * linear_dynamics_loss
            linear_losses.append(float(
                linear_dynamics_loss.data[0]))  # here too
        # Only the first importance sample's mean is used for latent stats.
        z_data = ptu.get_numpy(z_mu[:, 0].cpu())
        for i in range(len(z_data)):
            zs.append(z_data[i, :].copy())
        vae_losses.append(float(vae_loss.data[0]))
        iwae_losses.append(float(iwae_loss.data[0]))
        losses.append(float(loss.data[0]))
        des.append(float(de.data[0]))
        kles.append(float(kle.data[0]))
        if batch_idx == 0 and save_reconstruction:
            # Top rows: ground truth; bottom rows: reconstructions.
            n = min(data['next_obs'].size(0), 16)
            comparison = torch.cat([
                data['next_obs'][:n].narrow(start=0, length=self.imlength,
                                            dimension=1).contiguous().view(
                    -1, self.input_channels, self.imsize, self.imsize),
                x_recon[:, 0].contiguous().view(
                    self.batch_size,
                    self.input_channels,
                    self.imsize,
                    self.imsize,
                )[:n]
            ])
            save_dir = osp.join(logger.get_snapshot_dir(),
                                'r_%d.png' % epoch)
            save_image(comparison.data.cpu(), save_dir, nrow=n)
            del comparison
        if batch_idx == 0 and save_interpolation:
            # Linearly interpolate between pairs of latent means and
            # decode the path; one row per pair, num_steps images wide.
            n = min(data['next_obs'].size(0), 10)
            z1 = z_mu[:n, 0]
            z2 = z_mu[n:2 * n, 0]
            num_steps = 8
            z_interp = []
            for i in np.linspace(0.0, 1.0, num_steps):
                z_interp.append(float(i) * z1 + float(1 - i) * z2)
            z_interp = torch.cat(z_interp)
            imgs = self.model.decode(z_interp)
            imgs = imgs.view((num_steps, n, 3, self.imsize, self.imsize))
            imgs = imgs.permute([1, 0, 2, 3, 4])
            imgs = imgs.contiguous().view(
                (n * num_steps, 3, self.imsize, self.imsize))
            save_dir = osp.join(logger.get_snapshot_dir(),
                                'i_%d.png' % epoch)
            save_image(
                imgs.data.cpu(),
                save_dir,
                nrow=num_steps,
            )
            del imgs
            del z_interp
        # Free per-batch tensors promptly to keep peak memory down.
        del obs, next_obs, actions, x_recon, z_mu, z_logvar, \
            z, x, vae_loss, de, kle, loss
    zs = np.array(zs)
    self.model.dist_mu = zs.mean(axis=0)
    self.model.dist_std = zs.std(axis=0)
    del zs
    logger.record_tabular("test/decoder_loss", np.mean(des))
    logger.record_tabular("test/KL", np.mean(kles))
    if self.use_linear_dynamics:
        logger.record_tabular("test/linear_loss", np.mean(linear_losses))
    logger.record_tabular("test/loss", np.mean(losses))
    logger.record_tabular("test/vae_loss", np.mean(vae_losses))
    logger.record_tabular("test/iwae_loss", np.mean(iwae_losses))
    logger.record_tabular(
        "test/iwae_vae_diff",
        np.mean(np.array(iwae_losses) - np.array(vae_losses)))
    logger.record_tabular("beta", beta)
    process = psutil.Process(os.getpid())
    logger.record_tabular("RAM Usage (Mb)",
                          int(process.memory_info().rss / 1000000))
    # Count latent dims whose std exceeds two different thresholds.
    num_active_dims = 0
    num_active_dims2 = 0
    for std in self.model.dist_std:
        if std > 0.15:
            num_active_dims += 1
        if std > 0.05:
            num_active_dims2 += 1
    logger.record_tabular("num_active_dims", num_active_dims)
    logger.record_tabular("num_active_dims2", num_active_dims2)
    logger.dump_tabular()
    if save_vae:
        logger.save_itr_params(epoch, self.model, prefix='vae',
                               save_anyway=True)  # slow...
def train(
        dataset_generator,
        n_start_samples,
        projection=project_samples_square_np,
        n_samples_to_add_per_epoch=1000,
        n_epochs=100,
        z_dim=1,
        hidden_size=32,
        save_period=10,
        append_all_data=True,
        full_variant=None,
        dynamics_noise=0,
        decoder_output_var='learned',
        num_bins=5,
        skew_config=None,
        use_perfect_samples=False,
        use_perfect_density=False,
        vae_reset_period=0,
        vae_kwargs=None,
        use_dataset_generator_first_epoch=True,
        **kwargs
):
    """Skewed-data VAE training loop with HTML report and visualizations.

    Each epoch: sample new data (from the dataset generator, the VAE, or a
    histogram oracle), project it through the dynamics, fit histograms
    ``p_theta`` (raw) and ``p_new`` (reweighted by ``skew_config``), fit the
    VAE to the reweighted data, and log entropy/TV diagnostics. Heatmaps,
    sample plots, mp4s and gifs are written into the snapshot dir.

    FIX: the VAE heatmap image was previously saved to
    ``hist_heatmap{epoch}.png``, overwriting the histogram heatmap written
    one statement earlier; it now goes to ``vae_heatmap{epoch}.png``,
    matching the ``vae_heatmaps`` video/gif emitted at the end.

    :param dataset_generator: Callable n -> array of n samples.
    :param n_start_samples: Size of the initial training set.
    :param projection: Projection used by the dynamics model.
    :param skew_config: Required; controls ``prob_to_weight`` reweighting.
    :raises AssertionError: If ``skew_config`` is None, or if
        ``vae_kwargs`` is None when the VAE is actually needed.
    """
    """
    Sanitize Inputs
    """
    assert skew_config is not None
    if not (use_perfect_density and use_perfect_samples):
        assert vae_kwargs is not None
    if vae_kwargs is None:
        vae_kwargs = {}
    report = HTMLReport(
        logger.get_snapshot_dir() + '/report.html',
        images_per_row=10,
    )
    dynamics = Dynamics(projection, dynamics_noise)
    if full_variant:
        report.add_header("Variant")
        report.add_text(
            json.dumps(
                ppp.dict_to_safe_json(
                    full_variant,
                    sort=True),
                indent=2,
            )
        )
    vae, decoder, decoder_opt, encoder, encoder_opt = get_vae(
        decoder_output_var,
        hidden_size,
        z_dim,
        vae_kwargs,
    )
    vae.to(ptu.device)
    epochs = []
    losses = []
    kls = []
    log_probs = []
    hist_heatmap_imgs = []
    vae_heatmap_imgs = []
    sample_imgs = []
    entropies = []
    tvs_to_uniform = []
    entropy_gains_from_reweighting = []
    p_theta = Histogram(num_bins)
    p_new = Histogram(num_bins)
    orig_train_data = dataset_generator(n_start_samples)
    train_data = orig_train_data
    start = time.time()
    for epoch in progressbar(range(n_epochs)):
        # Refit the raw-data histogram from scratch each epoch.
        p_theta = Histogram(num_bins)
        if epoch == 0 and use_dataset_generator_first_epoch:
            vae_samples = dataset_generator(n_samples_to_add_per_epoch)
        else:
            if use_perfect_samples and epoch != 0:
                # Ideally the VAE = p_new, but in practice, it won't be...
                vae_samples = p_new.sample(n_samples_to_add_per_epoch)
            else:
                vae_samples = vae.sample(n_samples_to_add_per_epoch)
        projected_samples = dynamics(vae_samples)
        if append_all_data:
            train_data = np.vstack((train_data, projected_samples))
        else:
            train_data = np.vstack((orig_train_data, projected_samples))
        p_theta.fit(train_data)
        if use_perfect_density:
            prob = p_theta.compute_density(train_data)
        else:
            prob = vae.compute_density(train_data)
        all_weights = prob_to_weight(prob, skew_config)
        p_new.fit(train_data, weights=all_weights)
        if epoch == 0 or (epoch + 1) % save_period == 0:
            epochs.append(epoch)
            report.add_text("Epoch {}".format(epoch))
            hist_heatmap_img = visualize_histogram(p_theta, skew_config,
                                                   report)
            vae_heatmap_img = visualize_vae(
                vae,
                skew_config,
                report,
                resolution=num_bins,
            )
            sample_img = visualize_vae_samples(
                epoch,
                train_data,
                vae,
                report,
                dynamics,
            )
            visualize_samples(
                p_theta.sample(n_samples_to_add_per_epoch),
                report,
                title="P Theta/RB Samples",
            )
            visualize_samples(
                p_new.sample(n_samples_to_add_per_epoch),
                report,
                title="P Adjusted Samples",
            )
            hist_heatmap_imgs.append(hist_heatmap_img)
            vae_heatmap_imgs.append(vae_heatmap_img)
            sample_imgs.append(sample_img)
            report.save()
            Image.fromarray(hist_heatmap_img).save(
                logger.get_snapshot_dir() + '/hist_heatmap{}.png'.format(
                    epoch)
            )
            # FIX: previously written to hist_heatmap{}.png, clobbering the
            # histogram heatmap saved just above.
            Image.fromarray(vae_heatmap_img).save(
                logger.get_snapshot_dir() + '/vae_heatmap{}.png'.format(
                    epoch)
            )
            Image.fromarray(sample_img).save(
                logger.get_snapshot_dir() + '/samples{}.png'.format(epoch)
            )
        """
        train VAE to look like p_new
        """
        # Guard against a degenerate all-zero weighting.
        if sum(all_weights) == 0:
            all_weights[:] = 1
        if vae_reset_period > 0 and epoch % vae_reset_period == 0:
            vae, decoder, decoder_opt, encoder, encoder_opt = get_vae(
                decoder_output_var,
                hidden_size,
                z_dim,
                vae_kwargs,
            )
            vae.to(ptu.device)
        vae.fit(train_data, weights=all_weights)
        epoch_stats = vae.get_epoch_stats()
        losses.append(np.mean(epoch_stats['losses']))
        kls.append(np.mean(epoch_stats['kls']))
        log_probs.append(np.mean(epoch_stats['log_probs']))
        entropies.append(p_theta.entropy())
        tvs_to_uniform.append(p_theta.tv_to_uniform())
        entropy_gain = p_new.entropy() - p_theta.entropy()
        entropy_gains_from_reweighting.append(entropy_gain)
        for k in sorted(epoch_stats.keys()):
            logger.record_tabular(k, epoch_stats[k])
        logger.record_tabular("Epoch", epoch)
        logger.record_tabular('Entropy ', p_theta.entropy())
        logger.record_tabular('KL from uniform', p_theta.kl_from_uniform())
        logger.record_tabular('TV to uniform', p_theta.tv_to_uniform())
        logger.record_tabular('Entropy gain from reweight', entropy_gain)
        logger.record_tabular('Total Time (s)', time.time() - start)
        logger.dump_tabular()
        logger.save_itr_params(epoch, {
            'vae': vae,
            'train_data': train_data,
            'vae_samples': vae_samples,
            'dynamics': dynamics,
        })
    report.add_header("Training Curves")
    plot_curves(
        [
            ("Training Loss", losses),
            ("KL", kls),
            ("Log Probs", log_probs),
            ("Entropy Gain from Reweighting",
             entropy_gains_from_reweighting),
        ],
        report,
    )
    plot_curves(
        [
            ("Entropy", entropies),
            ("TV to Uniform", tvs_to_uniform),
        ],
        report,
    )
    report.add_text("Max entropy: {}".format(p_theta.max_entropy()))
    report.save()
    # Assemble the per-epoch images into videos and gifs.
    for filename, imgs in [
        ("hist_heatmaps", hist_heatmap_imgs),
        ("vae_heatmaps", vae_heatmap_imgs),
        ("samples", sample_imgs),
    ]:
        video = np.stack(imgs)
        vwrite(
            logger.get_snapshot_dir() + '/{}.mp4'.format(filename),
            video,
        )
        local_gif_file_path = '{}.gif'.format(filename)
        gif_file_path = '{}/{}'.format(
            logger.get_snapshot_dir(),
            local_gif_file_path
        )
        gif(gif_file_path, video)
        report.add_image(local_gif_file_path, txt=filename, is_url=True)
    report.save()
def experiment(variant):
    """Train an amortized goal chooser on a saved Q-function, then evaluate.

    Loads a snapshot (qf/env/policy) and its replay buffer, trains a
    ``UniversalGoalChooser``, wraps everything into an ``AmortizedPolicy``,
    saves it as iteration 0, and runs multitask rollouts with freshly
    sampled goals, logging diagnostics at the end.

    :param variant: Dict with 'num_rollouts', 'qf_path', 'goal_chooser_params',
        'tau', 'train_params', 'goal', and 'rollout_params'.
    """
    num_rollouts = variant['num_rollouts']
    path = variant['qf_path']
    data = joblib.load(path)
    goal_conditioned_model = data['qf']
    env = data['env']
    argmax_qf_policy = data['policy']
    # The replay buffer lives next to the snapshot in extra_data.pkl.
    extra_data_path = Path(path).parent / 'extra_data.pkl'
    extra_data = joblib.load(str(extra_data_path))
    replay_buffer = extra_data['replay_buffer']
    """
    Train amortized policy
    """
    # input_goal_dim=7 is hard-coded for this experiment's goal space —
    # TODO(review): confirm against the env used.
    goal_chooser = UniversalGoalChooser(
        input_goal_dim=7,
        output_goal_dim=env.goal_dim,
        obs_dim=int(env.observation_space.flat_dim),
        **variant['goal_chooser_params'])
    tau = variant['tau']
    if ptu.gpu_enabled():
        goal_chooser.to(ptu.device)
        goal_conditioned_model.to(ptu.device)
        argmax_qf_policy.to(ptu.device)
    train_amortized_goal_chooser(
        goal_chooser,
        goal_conditioned_model,
        argmax_qf_policy,
        tau,
        replay_buffer,
        **variant['train_params'])
    policy = AmortizedPolicy(argmax_qf_policy, goal_chooser)
    # Note: this fixed goal is only stored in the snapshot; the rollouts
    # below resample a new goal per rollout.
    goal = np.array(variant['goal'])
    logger.save_itr_params(
        0,
        dict(
            env=env,
            policy=policy,
            goal_chooser=goal_chooser,
            goal=goal,
        ))
    """
    Eval policy.
    """
    paths = []
    for _ in range(num_rollouts):
        # ``goal`` and ``path`` are rebound here (path previously held the
        # snapshot file path).
        goal = env.sample_goal_for_rollout()
        path = multitask_rollout(env, policy, goal,
                                 **variant['rollout_params'])
        paths.append(path)
    env.log_diagnostics(paths)
    logger.dump_tabular(with_timestamp=False)
def pretrain_policy_with_bc(self):
    """Pretrain the policy with behavioral cloning on demonstration buffers.

    Runs ``self.bc_num_pretrain_steps`` gradient steps of MSE behavioral
    cloning on ``demo_train_buffer``, evaluating the same loss on
    ``demo_test_buffer`` each step. Tabular output is redirected to
    ``pretrain_policy.csv`` for the duration and restored afterwards;
    the policy is snapshotted every 1000 steps.
    """
    # Log BC pretraining to its own CSV instead of the main progress file.
    logger.remove_tabular_output('progress.csv',
                                 relative_to_snapshot_dir=True)
    logger.add_tabular_output('pretrain_policy.csv',
                              relative_to_snapshot_dir=True)
    for i in range(self.bc_num_pretrain_steps):
        train_batch = self.get_batch_from_buffer(self.demo_train_buffer)
        train_o = train_batch["observations"]
        train_u = train_batch["actions"]
        if self.goal_conditioned:
            # Goal-conditioned policies see [observation, goal].
            train_g = train_batch["resampled_goals"]
            train_o = torch.cat((train_o, train_g), dim=1)
        train_pred_u = self.policy(train_o)
        train_error = (train_pred_u - train_u)**2
        train_bc_loss = train_error.mean()
        policy_loss = self.bc_weight * train_bc_loss.mean()
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()
        # Held-out BC loss, computed every step for logging.
        test_batch = self.get_batch_from_buffer(self.demo_test_buffer)
        test_o = test_batch["observations"]
        test_u = test_batch["actions"]
        if self.goal_conditioned:
            test_g = test_batch["resampled_goals"]
            test_o = torch.cat((test_o, test_g), dim=1)
        test_pred_u = self.policy(test_o)
        test_error = (test_pred_u - test_u)**2
        test_bc_loss = test_error.mean()
        train_loss_mean = np.mean(ptu.get_numpy(train_bc_loss))
        test_loss_mean = np.mean(ptu.get_numpy(test_bc_loss))
        stats = {
            "Train BC Loss": train_loss_mean,
            "Test BC Loss": test_loss_mean,
            "policy_loss": ptu.get_numpy(policy_loss),
            "batch": i,
        }
        logger.record_dict(stats)
        logger.dump_tabular(with_prefix=True, with_timestamp=False)
        # Periodic policy snapshot (every 1000 steps, including step 0).
        if i % 1000 == 0:
            logger.save_itr_params(
                i, {
                    "evaluation/policy": self.policy,
                    "evaluation/env": self.env.wrapped_env,
                })
    # Restore the main progress CSV as the tabular output.
    logger.remove_tabular_output(
        'pretrain_policy.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )
def train_vae(variant, return_data=False):
    """Build and train a (conv/spatial/auto-encoder) VAE from ``variant``.

    Selects the model class from the variant (AutoEncoder,
    SpatialAutoEncoder, or a ConvVAE subclass), trains it for
    ``variant['num_epochs']`` epochs with periodic image dumps and
    snapshots, and pickles the final model as ``vae.pkl``.

    :param variant: Dict of hyperparameters. Required keys: ``beta``,
        ``generate_vae_dataset_kwargs``, ``algo_kwargs``, ``vae_kwargs``,
        ``save_period``, ``num_epochs``; plus one of
        ``representation_size`` / ``latent_sizes``.
    :param return_data: When True, also return the train/test datasets.
    :return: ``model`` or ``(model, train_dataset, test_dataset)``.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule, ConstantSchedule
    from railrl.torch.vae.conv_vae import (
        ConvVAE,
        SpatialAutoEncoder,
        AutoEncoder,
    )
    import railrl.torch.vae.conv_vae as conv_vae
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch
    beta = variant["beta"]
    # NOTE(review): this is None when neither key is in the variant —
    # presumably the model classes reject that; confirm upstream.
    representation_size = variant.get("representation_size",
                                      variant.get("latent_sizes", None))
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    # Propagate shared settings into the dataset generator's kwargs.
    variant['generate_vae_dataset_kwargs'][
        'use_linear_dynamics'] = use_linear_dynamics
    variant['generate_vae_dataset_kwargs']['batch_size'] = variant[
        'algo_kwargs']['batch_size']
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    if use_linear_dynamics:
        # Action dimensionality comes from the dataset's action array.
        action_dim = train_dataset.data['actions'].shape[2]
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if 'context_schedule' in variant:
        # A dict describes a piecewise-linear schedule; a scalar a constant.
        schedule = variant['context_schedule']
        if type(schedule) is dict:
            context_schedule = PiecewiseLinearSchedule(**schedule)
        else:
            context_schedule = ConstantSchedule(schedule)
        variant['algo_kwargs']['context_schedule'] = context_schedule
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    # Fall back to a default conv architecture keyed on image size when
    # the variant does not specify one explicitly.
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')
    if variant['algo_kwargs'].get('is_auto_encoder', False):
        model = AutoEncoder(representation_size,
                            decoder_output_activation=decoder_activation,
                            **variant['vae_kwargs'])
    elif variant.get('use_spatial_auto_encoder', False):
        model = SpatialAutoEncoder(
            representation_size,
            decoder_output_activation=decoder_activation,
            **variant['vae_kwargs'])
    else:
        vae_class = variant.get('vae_class', ConvVAE)
        if use_linear_dynamics:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              action_dim=action_dim,
                              **variant['vae_kwargs'])
        else:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              **variant['vae_kwargs'])
    model.to(ptu.device)
    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model,
                                beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)
        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)
        # Periodic model snapshot (every 50 epochs, including epoch 0).
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model