def experiment(variant):
    """Train a VAE on reacher data, logging its loss under a uniform dataset
    and periodically dumping samples and skew-fit debug plots."""
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]

    # Training/test data plus a uniform dataset used only for diagnostics.
    train_data, test_data, info = variant['generate_vae_dataset_fn'](
        variant['generate_vae_dataset_kwargs']
    )
    uniform_dataset = generate_uniform_dataset_reacher(
        **variant['generate_uniform_dataset_kwargs']
    )
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    beta_schedule = None  # constant beta for this experiment
    model = variant['vae'](
        representation_size,
        decoder_output_activation=nn.Sigmoid(),
        **variant['vae_kwargs']
    )
    model.to(ptu.device)
    trainer = ConvVAETrainer(
        train_data,
        test_data,
        model,
        beta=beta,
        beta_schedule=beta_schedule,
        **variant['algo_kwargs']
    )

    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch)
        trainer.log_loss_under_uniform(model, uniform_dataset)
        trainer.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
            save_scatterplot=should_save_imgs,
        )
        if should_save_imgs:
            trainer.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
                trainer.dump_uniform_imgs_and_reconstructions(
                    dataset=uniform_dataset,
                    epoch=epoch,
                )
        # Refresh the skewed sampling weights on the configured period.
        if epoch % variant['train_weight_update_period'] == 0:
            trainer.update_train_weights()
def experiment(variant):
    """Train a ConvVAE on a generated dataset.

    Fix: the GPU id from ``variant['gpu_id']`` is now selected with
    ``ptu.set_device`` *before* the model is moved with ``m.to(ptu.device)``.
    Previously the model was moved to the default device first and the device
    was only switched afterwards, so the model ended up on the wrong GPU.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset(
        **variant['get_data_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    m = ConvVAE(representation_size, input_channels=3)
    if ptu.gpu_enabled():
        # Select the requested GPU first so that .to() targets it.
        gpu_id = variant.get("gpu_id", None)
        if gpu_id is not None:
            ptu.set_device(gpu_id)
        m.to(ptu.device)
    t = ConvVAETrainer(train_data, test_data, m, beta=beta,
                       beta_schedule=beta_schedule, **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(epoch,
                     save_reconstruction=should_save_imgs,
                     save_scatterplot=should_save_imgs)
        if should_save_imgs:
            t.dump_samples(epoch)
def train_vae(variant, return_data=False):
    """Train a (possibly action-conditioned) VAE; optionally return datasets.

    Returns the trained model, or ``(model, train_dataset, test_dataset)``
    when ``return_data`` is True.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger

    beta = variant["beta"]
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    variant['generate_vae_dataset_kwargs']['use_linear_dynamics'] = (
        use_linear_dynamics)
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])

    # A linear-dynamics VAE also models actions, so it needs the action dim.
    action_dim = (train_dataset.data['actions'].shape[2]
                  if use_linear_dynamics else 0)
    model = get_vae(variant, action_dim)

    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    beta_schedule = None
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model, beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)

    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)
        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
        # Push trainer diagnostics to the tabular logger every epoch.
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)

    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
def dump_latent_plots(vae_env, epoch):
    """Save one grid image visualizing each latent dimension over a 2D sweep.

    For every latent dimension, three heatmaps are rendered over an
    ``nx`` x ``ny`` sweep of environment states: posterior mean, posterior
    std, and one reparameterized sample.  The three groups are stacked and
    saved as ``z_<epoch>.png`` (or ``z.png`` when epoch is None) in the
    snapshot dir.
    """
    from railrl.core import logger
    import os.path as osp
    from torchvision.utils import save_image
    # Environments without a state-sweep helper cannot be visualized here.
    if getattr(vae_env, "get_states_sweep", None) is None:
        return
    nx, ny = (vae_env.vis_granularity, vae_env.vis_granularity)
    states_sweep = vae_env.get_states_sweep(nx, ny)
    sweep_latents_mu, sweep_latents_logvar = vae_env.encode_states(
        states_sweep, clip_std=False)
    # std = exp(logvar / 2)
    sweep_latents_std = np.exp(0.5*sweep_latents_logvar)
    sweep_latents_sample = vae_env.reparameterize(
        sweep_latents_mu, sweep_latents_logvar, noisy=True)
    images_mu_sc, images_std_sc, images_sample_sc = [], [], []
    imsize = 84
    # One heatmap per latent dimension for mean, std and sample.
    # NOTE(review): the display ranges (±2 for mu, [0, 2] for std, ±3 for
    # samples) appear hand-tuned — confirm before reusing elsewhere.
    for i in range(sweep_latents_mu.shape[1]):
        image_mu_sc = vae_env.transform_image(vae_env.get_image_plt(
            sweep_latents_mu[:,i].reshape((nx, ny)),
            vmin=-2.0, vmax=2.0, draw_state=False, imsize=imsize))
        images_mu_sc.append(image_mu_sc)
        image_std_sc = vae_env.transform_image(vae_env.get_image_plt(
            sweep_latents_std[:,i].reshape((nx, ny)),
            vmin=0.0, vmax=2.0, draw_state=False, imsize=imsize))
        images_std_sc.append(image_std_sc)
        image_sample_sc = vae_env.transform_image(vae_env.get_image_plt(
            sweep_latents_sample[:,i].reshape((nx, ny)),
            vmin=-3.0, vmax=3.0, draw_state=False, imsize=imsize))
        images_sample_sc.append(image_sample_sc)
    # Stack the three groups: all means, then all stds, then all samples.
    images = images_mu_sc + images_std_sc + images_sample_sc
    images = np.array(images)
    # Cap the grid width at 16 columns.
    if vae_env.representation_size > 16:
        nrow = 16
    else:
        nrow = vae_env.representation_size
    if epoch is not None:
        save_dir = osp.join(logger.get_snapshot_dir(), 'z_%d.png' % epoch)
    else:
        save_dir = osp.join(logger.get_snapshot_dir(), 'z.png')
    save_image(
        ptu.FloatTensor(
            ptu.from_numpy(
                images.reshape(
                    (vae_env.representation_size*3, -1, imsize, imsize)
                ))),
        save_dir,
        nrow=nrow,
    )
def save_paths(algo, epoch):
    """Pickle the epoch's exploration and evaluation paths to the snapshot dir.

    Fix: output files are now opened with a context manager so the handles
    are flushed and closed deterministically — previously the handles from
    ``open(...)`` inside the ``pickle.dump`` calls were leaked.
    """
    expl_paths = algo.expl_data_collector.get_epoch_paths()
    filename = osp.join(logger.get_snapshot_dir(),
                        'video_{epoch}_vae.p'.format(epoch=epoch))
    with open(filename, "wb") as f:
        pickle.dump(expl_paths, f)
    print("saved", filename)
    eval_paths = algo.eval_data_collector.get_epoch_paths()
    filename = osp.join(logger.get_snapshot_dir(),
                        'video_{epoch}_env.p'.format(epoch=epoch))
    with open(filename, "wb") as f:
        pickle.dump(eval_paths, f)
    print("saved", filename)
def experiment(variant):
    """Train a VAE, logging its loss under a fixed uniform goal-image dataset
    and dumping skew-fit debug plots on the save period."""
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = variant['generate_vae_dataset_fn'](
        variant['generate_vae_dataset_kwargs'])
    # Uniform dataset of desired-goal images, used only for diagnostics.
    uniform_dataset = load_local_or_remote_file(
        variant['uniform_dataset_path']).item()
    uniform_dataset = unormalize_image(uniform_dataset['image_desired_goal'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        # Pin the schedule's final value to the configured beta.
        variant['beta_schedule_kwargs']['y_values'][-1] = variant['beta']
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    model = variant['vae'](
        representation_size,
        decoder_output_activation=nn.Sigmoid(),
        **variant['vae_kwargs'])
    model.to(ptu.device)
    trainer = ConvVAETrainer(
        train_data, test_data, model, beta=beta,
        beta_schedule=beta_schedule, **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch)
        trainer.log_loss_under_uniform(
            model, uniform_dataset,
            variant['algo_kwargs']['priority_function_kwargs'])
        trainer.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
            save_scatterplot=should_save_imgs)
        if should_save_imgs:
            trainer.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
                trainer.dump_uniform_imgs_and_reconstructions(
                    dataset=uniform_dataset, epoch=epoch)
        trainer.update_train_weights()
def experiment(variant):
    """Train a ConvVAE on pre-collected sawyer torque-control image shards."""
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    # Load 10k-image shards from disk and concatenate them into one array.
    num_divisions = 5
    images = np.zeros((num_divisions * 10000, 21168))
    for i in range(num_divisions):
        imgs = np.load(
            '/home/murtaza/vae_data/sawyer_torque_control_images100000_'
            + str(i + 1) + '.npy')
        images[i * 10000:(i + 1) * 10000] = imgs
        print(i)
    # 90/10 train/test split.
    mid = int(num_divisions * 10000 * .9)
    train_data, test_data = images[:mid], images[mid:]
    info = dict()
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        # Rewrite the schedule's knot points from the variant's beta/ramp
        # settings before constructing it.
        kwargs = variant['beta_schedule_kwargs']
        kwargs['y_values'][2] = variant['beta']
        kwargs['x_values'][1] = variant['flat_x']
        kwargs['x_values'][2] = variant['ramp_x'] + variant['flat_x']
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    model = ConvVAE(representation_size, input_channels=3,
                    **variant['conv_vae_kwargs'])
    if ptu.gpu_enabled():
        model.cuda()
    trainer = ConvVAETrainer(train_data, test_data, model, beta=beta,
                             beta_schedule=beta_schedule,
                             **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch,
                           save_reconstruction=should_save_imgs,
                           save_scatterplot=should_save_imgs)
        if should_save_imgs:
            trainer.dump_samples(epoch)
def dump_reconstructions(vae_env, epoch, n_recon=16):
    """Save a grid of dataset images above their VAE reconstructions.

    Samples ``n_recon`` observations from the VAE's dataset file, runs them
    through the VAE, and writes ``r_<epoch>.png`` (or ``r.png``) to the
    snapshot dir.  Does nothing when no dataset path is configured.
    """
    from railrl.core import logger
    import os.path as osp
    from torchvision.utils import save_image
    if vae_env.use_vae_dataset and vae_env.vae_dataset_path is not None:
        from multiworld.core.image_env import normalize_image
        from railrl.misc.asset_loader import local_path_from_s3_or_local_path
        filename = local_path_from_s3_or_local_path(vae_env.vae_dataset_path)
        dataset = np.load(filename).item()
        sampled_idx = np.random.choice(dataset['next_obs'].shape[0], n_recon)
        if vae_env.vae_input_key_prefix == 'state':
            # State VAEs: render both the inputs and the reconstructed
            # states as images for visual comparison.
            states = dataset['next_obs'][sampled_idx]
            imgs = ptu.np_to_var(
                vae_env.wrapped_env.states_to_images(states)
            )
            recon_samples, _, _ = vae_env.vae(ptu.np_to_var(states))
            recon_imgs = ptu.np_to_var(
                vae_env.wrapped_env.states_to_images(ptu.get_numpy(recon_samples))
            )
        else:
            # Image VAEs: reconstruct normalized images directly.
            imgs = ptu.np_to_var(
                normalize_image(dataset['next_obs'][sampled_idx])
            )
            recon_imgs, _, _, _ = vae_env.vae(imgs)
        # Free the (potentially large) dataset before building the grid.
        del dataset
    else:
        return
    # Top rows: originals (first image_length entries of each flat obs);
    # bottom rows: reconstructions reshaped to (C, H, W).
    comparison = torch.cat([
        imgs.narrow(start=0, length=vae_env.wrapped_env.image_length,
                    dimension=1).contiguous().view(
            -1,
            vae_env.wrapped_env.channels,
            vae_env.wrapped_env.imsize,
            vae_env.wrapped_env.imsize
        ),
        recon_imgs.contiguous().view(
            n_recon,
            vae_env.wrapped_env.channels,
            vae_env.wrapped_env.imsize,
            vae_env.wrapped_env.imsize
        )[:n_recon]
    ])
    if epoch is not None:
        save_dir = osp.join(logger.get_snapshot_dir(), 'r_%d.png' % epoch)
    else:
        save_dir = osp.join(logger.get_snapshot_dir(), 'r.png')
    save_image(comparison.data.cpu(), save_dir, nrow=n_recon)
def experiment(variant):
    """Train a spatial auto-encoder on freshly collected image data."""
    from railrl.core import logger

    num_feat_points = variant['feat_points']
    beta = variant["beta"]
    print('collecting data')
    train_data, test_data, info = get_data(**variant['get_data_kwargs'])
    print('finish collecting data')
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    # Each feature point contributes an (x, y) pair to the representation.
    model = SpatialAutoEncoder(2 * num_feat_points, num_feat_points,
                               input_channels=3)
    trainer = ConvVAETrainer(train_data, test_data, model,
                             lr=variant['lr'], beta=beta)
    for epoch in range(variant['num_epochs']):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
        trainer.dump_samples(epoch)
def __init__(self, variant):
    """Configure the buffer-dump helper and ensure its output dir exists."""
    kwargs = variant.get("dump_buffer_kwargs", dict())
    # Popped here so it is not forwarded to the dump call itself.
    self.save_period = kwargs.pop('dump_buffer_period', 50)
    self.dump_buffer_kwargs = kwargs
    self.logdir = logger.get_snapshot_dir()
    self.buffer_dir = osp.join(self.logdir, 'buffers')
    if not osp.exists(self.buffer_dir):
        os.makedirs(self.buffer_dir)
def run_task(variant):
    """Minimal smoke-test task: print/log the variant and one tabular value.

    Fix: ``logger.log`` takes a single message string; the extra positional
    argument previously passed the snapshot dir as the *second parameter* of
    ``logger.log`` instead of including it in the message, so the directory
    was never printed.  The message is now formatted into one string.
    """
    from railrl.core import logger
    print(variant)
    logger.log("Hello from script")
    logger.log("variant: " + str(variant))
    logger.record_tabular("value", 1)
    logger.dump_tabular()
    logger.log("snapshot_dir: " + str(logger.get_snapshot_dir()))
def experiment(variant):
    """Train a VAE on a pre-saved buffer of door observations.

    Fix: ``ConvVAETrainer.log_loss_under_uniform`` takes the model as its
    first argument — every other caller in this file passes
    ``(m, uniform_dataset)`` — but here the uniform dataset was passed in the
    model's position.  The model is now passed explicitly.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    beta = variant["beta"]
    representation_size = variant["representation_size"]
    data = joblib.load(variant['file'])
    obs = data['obs']
    size = int(data['size'])
    dataset = obs[:size, :]
    # 90/10 train/test split.
    n = int(size * .9)
    train_data = dataset[:n, :]
    test_data = dataset[n:, :]
    logger.get_snapshot_dir()
    print('SIZE: ', size)
    uniform_dataset = generate_uniform_dataset_door(
        **variant['generate_uniform_dataset_kwargs']
    )
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        # Pin the schedule's final value to the configured beta.
        variant['beta_schedule_kwargs']['y_values'][-1] = variant['beta']
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    m = variant['vae'](representation_size,
                       decoder_output_activation=nn.Sigmoid(),
                       **variant['vae_kwargs'])
    m.to(ptu.device)
    t = ConvVAETrainer(train_data, test_data, m, beta=beta,
                       beta_schedule=beta_schedule, **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        # Bug fix: pass the model as the first argument.
        t.log_loss_under_uniform(m, uniform_dataset)
        t.test_epoch(epoch,
                     save_reconstruction=should_save_imgs,
                     save_scatterplot=should_save_imgs)
        if should_save_imgs:
            t.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                t.dump_best_reconstruction(epoch)
                t.dump_worst_reconstruction(epoch)
                t.dump_sampling_histogram(epoch)
                t.dump_uniform_imgs_and_reconstructions(
                    dataset=uniform_dataset, epoch=epoch)
        t.update_train_weights()
def pretrain_policy_with_bc(self):
    """Pretrain the policy with behavioral cloning on the demo buffers.

    Temporarily redirects tabular output from ``progress.csv`` to
    ``pretrain_policy.csv``, runs ``bc_num_pretrain_steps`` BC gradient
    steps, periodically evaluates/logs, then restores ``progress.csv``.

    Fix: the ``bc.pkl`` snapshot is now written via a context manager —
    previously the handle from ``open(...)`` inside ``pickle.dump`` was
    leaked every logging period.
    """
    logger.remove_tabular_output(
        'progress.csv', relative_to_snapshot_dir=True
    )
    logger.add_tabular_output(
        'pretrain_policy.csv', relative_to_snapshot_dir=True
    )
    if self.do_pretrain_rollouts:
        total_ret = self.do_rollouts()
        print("INITIAL RETURN", total_ret/20)
    prev_time = time.time()
    for i in range(self.bc_num_pretrain_steps):
        # One BC gradient step on the training demos.
        train_policy_loss, train_logp_loss, train_mse_loss, train_log_std = self.run_bc_batch(self.demo_train_buffer, self.policy)
        train_policy_loss = train_policy_loss * self.bc_weight
        self.policy_optimizer.zero_grad()
        train_policy_loss.backward()
        self.policy_optimizer.step()
        # Held-out BC losses (no gradient step).
        test_policy_loss, test_logp_loss, test_mse_loss, test_log_std = self.run_bc_batch(self.demo_test_buffer, self.policy)
        test_policy_loss = test_policy_loss * self.bc_weight
        if self.do_pretrain_rollouts and i % self.pretraining_env_logging_period == 0:
            total_ret = self.do_rollouts()
            print("Return at step {} : {}".format(i, total_ret/20))
        if i % self.pretraining_logging_period == 0:
            stats = {
                "pretrain_bc/batch": i,
                "pretrain_bc/Train Logprob Loss": ptu.get_numpy(train_logp_loss),
                "pretrain_bc/Test Logprob Loss": ptu.get_numpy(test_logp_loss),
                "pretrain_bc/Train MSE": ptu.get_numpy(train_mse_loss),
                "pretrain_bc/Test MSE": ptu.get_numpy(test_mse_loss),
                "pretrain_bc/train_policy_loss": ptu.get_numpy(train_policy_loss),
                "pretrain_bc/test_policy_loss": ptu.get_numpy(test_policy_loss),
                "pretrain_bc/epoch_time": time.time()-prev_time,
            }
            if self.do_pretrain_rollouts:
                stats["pretrain_bc/avg_return"] = total_ret / 20
            logger.record_dict(stats)
            logger.dump_tabular(with_prefix=True, with_timestamp=False)
            # Bug fix: close the snapshot file deterministically.
            with open(logger.get_snapshot_dir() + '/bc.pkl', "wb") as f:
                pickle.dump(self.policy, f)
            prev_time = time.time()
    # Restore the normal progress log.
    logger.remove_tabular_output(
        'pretrain_policy.csv', relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv', relative_to_snapshot_dir=True,
    )
    if self.post_bc_pretrain_hyperparams:
        self.set_algorithm_weights(**self.post_bc_pretrain_hyperparams)
def experiment(variant):
    """Train a ConvVAE with a state-similarity debug head on sawyer data."""
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    # This dataset pairs images with states, so the generic VAE dataset
    # generator cannot be used here.
    X = np.load(
        '/home/murtaza/vae_data/sawyer_torque_control_ou_imgs_zoomed_out10000.npy'
    )
    Y = np.load(
        '/home/murtaza/vae_data/sawyer_torque_control_ou_states_zoomed_out10000.npy'
    )
    # Keep columns [:7] and [14:]; NOTE(review): presumably drops a
    # velocity block from the state vector — confirm against the data dump.
    Y = np.concatenate((Y[:, :7], Y[:, 14:]), axis=1)
    X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.1)
    info = dict()
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    model = ConvVAE(representation_size, input_channels=3,
                    state_sim_debug=True, state_size=Y.shape[1],
                    **variant['conv_vae_kwargs'])
    if ptu.gpu_enabled():
        model.cuda()
    trainer = ConvVAETrainer((X_train, Y_train), (X_test, Y_test), model,
                             beta=beta, beta_schedule=beta_schedule,
                             state_sim_debug=True, **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch,
                           save_reconstruction=should_save_imgs,
                           save_scatterplot=should_save_imgs)
        if should_save_imgs:
            trainer.dump_samples(epoch)
def plot_encoder_function(variant, encoder, tag=""):
    """Return a post-epoch hook that animates the encoder's image of the axes.

    The hook encodes points swept along the x-axis (y=0) and y-axis (x=0)
    and saves an animated gif of the encoded points every
    ``save_video_period`` epochs.

    Fix: the matplotlib figure created per invocation is now closed after
    the animation is saved.  Previously each call leaked a figure, which
    accumulates across epochs (matplotlib warns after 20 open figures).
    """
    import matplotlib.pyplot as plt
    from matplotlib.animation import FuncAnimation
    from railrl.core import logger
    logdir = logger.get_snapshot_dir()

    def plot_encoder(algo, epoch, is_x=False):
        save_period = variant.get('save_video_period', 50)
        if epoch % save_period == 0 or epoch == algo.num_epochs:
            filename = osp.join(
                logdir,
                'encoder_{}_{}_{epoch}_env.gif'.format(
                    tag, "x" if is_x else "y", epoch=epoch))
            vary = np.arange(-4, 4, .1)
            static = np.zeros(len(vary))
            # Sweep along one axis while holding the other at zero.
            points_x = np.c_[vary.reshape(-1, 1), static.reshape(-1, 1)]
            points_y = np.c_[static.reshape(-1, 1), vary.reshape(-1, 1)]
            encoded_points_x = ptu.get_numpy(
                encoder.forward(ptu.from_numpy(points_x)))
            encoded_points_y = ptu.get_numpy(
                encoder.forward(ptu.from_numpy(points_y)))
            plt.clf()
            fig = plt.figure()
            # Axis limits fit both encoded point clouds.
            plt.xlim(
                min(min(encoded_points_x[:, 0]), min(encoded_points_y[:, 0])),
                max(max(encoded_points_x[:, 0]), max(encoded_points_y[:, 0])))
            plt.ylim(
                min(min(encoded_points_x[:, 1]), min(encoded_points_y[:, 1])),
                max(max(encoded_points_x[:, 1]), max(encoded_points_y[:, 1])))
            colors = ["red", "blue"]
            lines = [
                plt.plot([], [], 'o', color=colors[i], alpha=0.4)[0]
                for i in range(2)
            ]

            def animate(i):
                # Reveal one more encoded point per frame.
                lines[0].set_data(encoded_points_x[:i + 1, 0],
                                  encoded_points_x[:i + 1, 1])
                lines[1].set_data(encoded_points_y[:i + 1, 0],
                                  encoded_points_y[:i + 1, 1])
                return lines

            ani = FuncAnimation(fig, animate, frames=len(vary), interval=40)
            ani.save(filename, writer='imagemagick', fps=60)
            plt.close(fig)  # bug fix: do not leak one figure per call

    return plot_encoder
def experiment(variant):
    """Supervised regression from sawyer images to joint angles with a CNN."""
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    ptu.set_gpu_mode(True)
    info = dict()
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    net = CNN(**variant['cnn_kwargs'])
    net.cuda()
    num_divisions = variant['num_divisions']
    images = np.zeros((num_divisions * 10000, 21168))
    states = np.zeros((num_divisions * 10000, 7))
    # Load paired image/state shards; wrap joint angles into [0, 2*pi).
    for i in range(num_divisions):
        imgs = np.load(
            '/home/murtaza/vae_data/sawyer_torque_control_images100000_'
            + str(i + 1) + '.npy')
        state = np.load(
            '/home/murtaza/vae_data/sawyer_torque_control_states100000_'
            + str(i + 1) + '.npy')[:, :7] % (2 * np.pi)
        images[i * 10000:(i + 1) * 10000] = imgs
        states[i * 10000:(i + 1) * 10000] = state
        print(i)
    if variant['normalize']:
        # Standardize the regression targets.
        std = np.std(states, axis=0)
        mu = np.mean(states, axis=0)
        states = np.divide((states - mu), std)
        print(mu, std)
    # 90/10 train/test split.
    mid = int(num_divisions * 10000 * .9)
    train_images, test_images = images[:mid], images[mid:]
    train_labels, test_labels = states[:mid], states[mid:]
    algo = SupervisedAlgorithm(
        train_images, test_images, train_labels, test_labels, net,
        batch_size=variant['batch_size'], lr=variant['lr'],
        weight_decay=variant['weight_decay'])
    for epoch in range(variant['num_epochs']):
        algo.train_epoch(epoch)
        algo.test_epoch(epoch)
def dump_samples(vae_env, epoch, n_samples=64):
    """Decode prior samples and save them as an image grid in the snapshot dir."""
    from railrl.core import logger
    from torchvision.utils import save_image
    import os.path as osp

    vae_env.vae.eval()
    # Sample latents from the standard-normal prior and decode them.
    sample = ptu.Variable(torch.randn(n_samples, vae_env.representation_size))
    sample = vae_env.vae.decode(sample).cpu()
    if vae_env.vae_input_key_prefix == 'state':
        # State VAEs decode to states; render those as images first.
        sample = ptu.np_to_var(
            vae_env.wrapped_env.states_to_images(ptu.get_numpy(sample)))
        if sample is None:
            # The wrapped env cannot render these states; nothing to save.
            return
    save_dir = osp.join(
        logger.get_snapshot_dir(),
        's.png' if epoch is None else 's_%d.png' % epoch)
    save_image(
        sample.data.view(n_samples, -1,
                         vae_env.wrapped_env.imsize,
                         vae_env.wrapped_env.imsize),
        save_dir,
        nrow=int(np.sqrt(n_samples))
    )
def experiment(variant):
    """Train a skew-fit VAE, periodically dumping samples and debug plots."""
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = variant['generate_vae_dataset_fn'](
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        # Pin the schedule's final value to the configured beta.
        variant['beta_schedule_kwargs']['y_values'][-1] = variant['beta']
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    model = variant['vae'](representation_size, **variant['vae_kwargs'])
    model.to(ptu.device)
    trainer = ConvVAETrainer(train_data, test_data, model, beta=beta,
                             beta_schedule=beta_schedule,
                             **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch,
                           save_reconstruction=should_save_imgs,
                           save_scatterplot=should_save_imgs)
        if should_save_imgs:
            trainer.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
        # Refresh the skewed sampling weights every epoch.
        trainer.update_train_weights()
def get_video_save_func(rollout_function, env, policy, variant):
    """Build a per-epoch hook that saves rollout videos to the snapshot dir."""
    from multiworld.core.image_env import ImageEnv
    from railrl.core import logger
    from railrl.envs.vae_wrappers import temporary_mode
    from railrl.visualization.video import dump_video

    logdir = logger.get_snapshot_dir()
    save_period = variant.get('save_video_period', 50)
    do_state_exp = variant.get("do_state_exp", False)
    dump_video_kwargs = variant.get("dump_video_kwargs", dict())
    dump_video_kwargs['horizon'] = variant['max_path_length']

    if do_state_exp:
        # State-based experiments need an image wrapper for rendering.
        imsize = variant.get('imsize')
        dump_video_kwargs['imsize'] = imsize
        image_env = ImageEnv(
            env,
            imsize,
            init_camera=variant.get('init_camera', None),
            transpose=True,
            normalize=True,
        )

        def save_video(algo, epoch):
            if epoch % save_period == 0 or epoch == algo.num_epochs:
                filename = osp.join(
                    logdir, 'video_{epoch}_env.mp4'.format(epoch=epoch))
                dump_video(image_env, policy, filename, rollout_function,
                           **dump_video_kwargs)
    else:
        image_env = env
        dump_video_kwargs['imsize'] = env.imsize

        def save_video(algo, epoch):
            if epoch % save_period == 0 or epoch == algo.num_epochs:
                # One video rendered from the raw environment...
                filename = osp.join(
                    logdir, 'video_{epoch}_env.mp4'.format(epoch=epoch))
                temporary_mode(image_env, mode='video_env', func=dump_video,
                               args=(image_env, policy, filename,
                                     rollout_function),
                               kwargs=dump_video_kwargs)
                # ...and one rendered through the VAE.
                filename = osp.join(
                    logdir, 'video_{epoch}_vae.mp4'.format(epoch=epoch))
                temporary_mode(image_env, mode='video_vae', func=dump_video,
                               args=(image_env, policy, filename,
                                     rollout_function),
                               kwargs=dump_video_kwargs)
    return save_video
def visualization_post_processing(save_vis, save_video, epoch):
    """Dump reconstructions/samples/latent plots and optionally rollout videos.

    NOTE(review): this reads ``vis_variant``, ``vae_env``, ``variant``,
    ``random_policy`` and ``rollout_function`` which are not parameters —
    it looks like a closure whose enclosing scope is not visible here;
    confirm before reusing standalone.
    """
    vis_list = vis_variant.get('vis_list', [])
    if save_vis:
        # State-input VAEs additionally get reconstruction dumps.
        if vae_env.vae_input_key_prefix == 'state':
            vae_env.dump_reconstructions(epoch, n_recon=vis_variant.get(
                'n_recon', 16))
        vae_env.dump_samples(epoch, n_samples=vis_variant.get('n_samples', 64))
        if 'latent_representation' in vis_list:
            vae_env.dump_latent_plots(epoch)
        # Any histogram variant requires the histogram to be computed first.
        if any(elem in vis_list for elem in [
            'latent_histogram', 'latent_histogram_mu',
            'latent_histogram_2d', 'latent_histogram_mu_2d'
        ]):
            vae_env.compute_latent_histogram()
        # When a video is also saved, histogram dumping is skipped here —
        # presumably handled by the video path; hence the guards.
        if not save_video and ('latent_histogram' in vis_list):
            vae_env.dump_latent_histogram(epoch=epoch, noisy=True,
                                          use_true_prior=True)
        if not save_video and ('latent_histogram_mu' in vis_list):
            vae_env.dump_latent_histogram(epoch=epoch, noisy=False,
                                          use_true_prior=True)
    if save_video and save_vis:
        from railrl.envs.vae_wrappers import temporary_mode
        from railrl.misc.video_gen import dump_video
        from railrl.core import logger
        vae_env.compute_goal_encodings()
        logdir = logger.get_snapshot_dir()
        filename = osp.join(logdir, 'video_{epoch}.mp4'.format(epoch=epoch))
        variant['dump_video_kwargs']['epoch'] = epoch
        temporary_mode(vae_env, mode='video_env', func=dump_video,
                       args=(vae_env, random_policy, filename,
                             rollout_function),
                       kwargs=variant['dump_video_kwargs'])
        if not vis_variant.get('save_video_env_only', True):
            # Optional second video rendered through the VAE.
            filename = osp.join(
                logdir, 'video_{epoch}_vae.mp4'.format(epoch=epoch))
            temporary_mode(vae_env, mode='video_vae', func=dump_video,
                           args=(vae_env, random_policy, filename,
                                 rollout_function),
                           kwargs=variant['dump_video_kwargs'])
def experiment(variant):
    """Train a ConvVAE; this variant requires a piecewise-linear beta schedule."""
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = get_data(**variant['get_data_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    # Unlike the other experiments here, the schedule kwargs are mandatory.
    beta_schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])
    model = ConvVAE(representation_size, input_channels=3)
    if ptu.gpu_enabled():
        model.to(ptu.device)
    trainer = ConvVAETrainer(train_data, test_data, model, beta=beta,
                             beta_schedule=beta_schedule,
                             **variant['algo_kwargs'])
    for epoch in range(variant['num_epochs']):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
        trainer.dump_samples(epoch)
def experiment(variant):
    """Train a ConvVAE on a generated dataset with an optional beta schedule."""
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset(
        **variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        # Pin the schedule's final value to the configured beta.
        variant['beta_schedule_kwargs']['y_values'][-1] = variant['beta']
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    model = ConvVAE(representation_size, input_channels=3,
                    **variant['conv_vae_kwargs'])
    if ptu.gpu_enabled():
        model.cuda()
    trainer = ConvVAETrainer(train_data, test_data, model, beta=beta,
                             beta_schedule=beta_schedule,
                             **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch,
                           save_reconstruction=should_save_imgs,
                           save_scatterplot=should_save_imgs)
        if should_save_imgs:
            trainer.dump_samples(epoch)
def train_reprojection_network(variant):
    """Train and return a network that reprojects latents for a given VAE."""
    from railrl.torch.vae.reprojection_network import (
        ReprojectionNetwork,
        ReprojectionNetworkTrainer,
    )
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    logger.get_snapshot_dir()
    vae = variant['vae']
    dataset_kwargs = variant.get(
        "generate_reprojection_network_dataset_kwargs", {})
    dataset_kwargs['vae'] = vae
    train_data, test_data = generate_reprojection_network_dataset(
        dataset_kwargs)
    network_kwargs = variant.get("reprojection_network_kwargs", {})
    model = ReprojectionNetwork(vae, **network_kwargs)
    if ptu.gpu_enabled():
        model.cuda()
    trainer = ReprojectionNetworkTrainer(train_data, test_data, model,
                                         **variant.get("algo_kwargs", {}))
    num_epochs = variant.get('num_epochs', 5000)
    for epoch in range(num_epochs):
        # Persist the network every 250 epochs and on the final epoch.
        should_save_network = (epoch % 250 == 0 or epoch == num_epochs - 1)
        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_network=should_save_network,
        )
    logger.save_extra_data(model, 'reproj_network.pkl', mode='pickle')
    return model
def dump_latent_histogram(vae_env, epoch, noisy=False, reproj=False, use_true_prior=None, draw_dots=False):
    """Render the latent-space histogram and save it as an image grid."""
    from railrl.core import logger
    import os.path as osp
    from torchvision.utils import save_image

    images = vae_env.get_image_latent_histogram(
        noisy=noisy,
        reproj=reproj,
        draw_dots=draw_dots,
        use_true_prior=use_true_prior,
    )
    # File prefix encodes which latent statistic was histogrammed.
    if noisy:
        prefix = 'h'
    elif reproj:
        prefix = 'h_r'
    else:
        prefix = 'h_mu'
    if epoch is None:
        name = prefix + '.png'
    else:
        name = prefix + '_%d.png' % epoch
    save_dir = osp.join(logger.get_snapshot_dir(), name)
    save_image(
        ptu.FloatTensor(ptu.from_numpy(images)),
        save_dir,
        nrow=int(np.sqrt(images.shape[0])),
    )
def plot_buffer_function(save_period, buffer_key):
    """Return a hook that scatter-plots buffered next-observations each period."""
    import matplotlib.pyplot as plt
    from railrl.core import logger
    logdir = logger.get_snapshot_dir()

    def plot_buffer(algo, epoch):
        # Only plot on the save period or at the final epoch.
        if epoch % save_period != 0 and epoch != algo.num_epochs:
            return
        replay_buffer = algo.replay_buffer
        filename = osp.join(
            logdir,
            '{}_buffer_{epoch}_env.png'.format(buffer_key, epoch=epoch))
        goals = replay_buffer._next_obs[buffer_key][:replay_buffer._size]
        plt.clf()
        plt.scatter(goals[:, 0], goals[:, 1], alpha=0.2)
        plt.savefig(filename)

    return plot_buffer
def __init__(self, env, variant, expl_path_collector=None, eval_path_collector=None):
    """Store video-dump configuration, filling in defaults for the grid layout."""
    self.env = env
    self.logdir = logger.get_snapshot_dir()
    video_kwargs = variant.get("dump_video_kwargs", dict())
    if 'imsize' not in video_kwargs:
        video_kwargs['imsize'] = env.imsize
    # Default grid: 2 rows x 5 columns of unnormalized frames.
    video_kwargs.setdefault("rows", 2)
    video_kwargs.setdefault("columns", 5)
    video_kwargs.setdefault("unnormalize", True)
    # Pop the keys that configure this helper rather than dump_video itself.
    self.save_period = video_kwargs.pop('save_video_period', 50)
    self.exploration_goal_image_key = video_kwargs.pop(
        "exploration_goal_image_key", "decoded_goal_image")
    self.evaluation_goal_image_key = video_kwargs.pop(
        "evaluation_goal_image_key", "image_desired_goal")
    self.dump_video_kwargs = video_kwargs
    self.expl_path_collector = expl_path_collector
    self.eval_path_collector = eval_path_collector
    self.variant = variant
def get_save_video_function(rollout_function, env, policy, save_video_period=10, imsize=48, tag="", video_image_env_kwargs=None, **dump_video_kwargs):
    """Return a per-epoch hook that records a rollout video of ``policy``."""
    logdir = logger.get_snapshot_dir()

    # Wrap plain envs so they can render images; image/VAE envs pass through.
    if isinstance(env, (ImageEnv, VAEWrappedEnv)):
        image_env = env
    else:
        image_env = ImageEnv(env, imsize, transpose=True, normalize=True,
                             **(video_image_env_kwargs or {}))
    assert image_env.imsize == imsize, "Imsize must match env imsize"

    def save_video(algo, epoch):
        if epoch % save_video_period == 0 or epoch == algo.num_epochs:
            filename = osp.join(
                logdir,
                'video_{}_{epoch}_env.mp4'.format(tag, epoch=epoch),
            )
            dump_video(image_env, policy, filename, rollout_function,
                       imsize=imsize, **dump_video_kwargs)

    return save_video