def experiment(variant):
    """Train a ``ConvVAE`` with ``ConvVAETrainer`` on a generated dataset.

    Required ``variant`` keys: ``beta``, ``representation_size``,
    ``get_data_kwargs``, ``algo_kwargs``, ``save_period`` and ``num_epochs``.
    Optional keys: ``beta_schedule_kwargs`` (builds a
    ``PiecewiseLinearSchedule``) and ``gpu_id``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    kl_weight = variant["beta"]
    latent_dim = variant["representation_size"]
    train_set, test_set, dataset_info = generate_vae_dataset(
        **variant['get_data_kwargs'])
    logger.save_extra_data(dataset_info)
    logger.get_snapshot_dir()

    # A schedule is only built when the variant explicitly asks for one.
    schedule = (
        PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])
        if 'beta_schedule_kwargs' in variant
        else None
    )

    model = ConvVAE(latent_dim, input_channels=3)
    if ptu.gpu_enabled():
        model.to(ptu.device)
        # NOTE(review): the device is selected *after* the model has been
        # moved; confirm this ordering is intended.
        requested_gpu = variant.get("gpu_id")
        if requested_gpu is not None:
            ptu.set_device(requested_gpu)

    trainer = ConvVAETrainer(
        train_set,
        test_set,
        model,
        beta=kl_weight,
        beta_schedule=schedule,
        **variant['algo_kwargs']
    )

    save_every = variant['save_period']
    for epoch in range(variant['num_epochs']):
        # Images are written on every ``save_every``-th epoch (incl. epoch 0).
        write_images = epoch % save_every == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_reconstruction=write_images,
            save_scatterplot=write_images,
        )
        if write_images:
            trainer.dump_samples(epoch)
def experiment(variant):
    """Train a ``VAE`` or ``AutoEncoder`` with ``VAETrainer``.

    Required ``variant`` keys: ``beta``, ``representation_size``,
    ``algo_kwargs`` (must contain ``is_auto_encoder``), ``vae_kwargs`` and
    ``num_epochs``. If ``beta_schedule_kwargs`` is present, ``flat_x`` and
    ``ramp_x`` are required as well and the schedule kwargs are modified
    in place before the schedule is built.
    """
    from railrl.core import logger

    kl_weight = variant["beta"]
    latent_dim = variant["representation_size"]
    train_set, test_set, dataset_info = generate_dataset()
    logger.save_extra_data(dataset_info)
    logger.get_snapshot_dir()

    schedule = None
    if 'beta_schedule_kwargs' in variant:
        schedule_kwargs = variant['beta_schedule_kwargs']
        # Final schedule value is the target beta.
        schedule_kwargs['y_values'][2] = variant['beta']
        # Breakpoints: end of the flat segment, then end of the ramp.
        flat_x = variant['flat_x']
        breakpoints = schedule_kwargs['x_values']
        breakpoints[1] = flat_x
        breakpoints[2] = variant['ramp_x'] + flat_x
        schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    # Both model classes are constructed with the same arguments.
    model_cls = AutoEncoder if variant['algo_kwargs']['is_auto_encoder'] else VAE
    model = model_cls(
        latent_dim,
        train_set.shape[1],
        output_scale=1,
        **variant['vae_kwargs']
    )

    trainer = VAETrainer(
        train_set,
        test_set,
        model,
        beta=kl_weight,
        beta_schedule=schedule,
        **variant['algo_kwargs']
    )
    for epoch in range(variant['num_epochs']):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
def train_vae(variant, return_data=False):
    """Build a VAE via ``get_vae`` and train it, logging diagnostics per epoch.

    Args:
        variant: experiment configuration. Reads ``beta``,
            ``generate_vae_dataset_kwargs``, ``algo_kwargs``, ``save_period``
            and ``num_epochs``; optionally ``use_linear_dynamics``,
            ``generate_vae_data_fctn``, ``beta_schedule_kwargs``,
            ``vae_trainer_class`` and ``dump_skew_debug_plots``.
            ``variant['generate_vae_dataset_kwargs']`` is updated in place
            with the ``use_linear_dynamics`` flag.
        return_data: when true, also return the train and test datasets.

    Returns:
        The model, or ``(model, train_dataset, test_dataset)`` when
        ``return_data`` is true.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger

    kl_weight = variant["beta"]
    linear_dynamics = variant.get('use_linear_dynamics', False)
    make_dataset = variant.get('generate_vae_data_fctn', generate_vae_dataset)

    dataset_kwargs = variant['generate_vae_dataset_kwargs']
    dataset_kwargs['use_linear_dynamics'] = linear_dynamics
    train_dataset, test_dataset, dataset_info = make_dataset(
        variant['generate_vae_dataset_kwargs'])

    # The action dimension is taken from the dataset only for linear dynamics.
    action_dim = (
        train_dataset.data['actions'].shape[2] if linear_dynamics else 0
    )
    model = get_vae(variant, action_dim)

    logger.save_extra_data(dataset_info)
    logger.get_snapshot_dir()

    schedule = None
    if 'beta_schedule_kwargs' in variant:
        schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    trainer_cls = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = trainer_cls(
        model,
        beta=kl_weight,
        beta_schedule=schedule,
        **variant['algo_kwargs']
    )

    save_every = variant['save_period']
    skew_debug = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        write_images = epoch % save_every == 0
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if write_images:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if skew_debug:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        # Push this epoch's diagnostics to the tabular logger.
        for name, value in trainer.get_diagnostics().items():
            logger.record_tabular(name, value)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        # Periodic snapshot of the model every 50 epochs.
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)

    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if not return_data:
        return model
    return model, train_dataset, test_dataset
def experiment(variant):
    """Train a VAE while also logging its loss on a fixed "uniform" dataset.

    Required ``variant`` keys: ``beta``, ``representation_size``,
    ``generate_vae_dataset_fn``, ``generate_vae_dataset_kwargs``,
    ``uniform_dataset_path``, ``vae`` (model class), ``vae_kwargs``,
    ``algo_kwargs`` (must contain ``priority_function_kwargs``),
    ``save_period``, ``num_epochs`` and ``dump_skew_debug_plots``.
    Optional: ``beta_schedule_kwargs``, whose last ``y_values`` entry is
    overwritten in place with ``beta``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    kl_weight = variant["beta"]
    latent_dim = variant["representation_size"]
    make_dataset = variant['generate_vae_dataset_fn']
    train_set, test_set, dataset_info = make_dataset(
        variant['generate_vae_dataset_kwargs'])

    loaded = load_local_or_remote_file(variant['uniform_dataset_path']).item()
    uniform_dataset = unormalize_image(loaded['image_desired_goal'])

    logger.save_extra_data(dataset_info)
    logger.get_snapshot_dir()

    schedule = None
    if 'beta_schedule_kwargs' in variant:
        # The schedule always ends at the configured beta.
        variant['beta_schedule_kwargs']['y_values'][-1] = variant['beta']
        schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    model = variant['vae'](
        latent_dim,
        decoder_output_activation=nn.Sigmoid(),
        **variant['vae_kwargs']
    )
    model.to(ptu.device)
    trainer = ConvVAETrainer(
        train_set,
        test_set,
        model,
        beta=kl_weight,
        beta_schedule=schedule,
        **variant['algo_kwargs']
    )

    save_every = variant['save_period']
    for epoch in range(variant['num_epochs']):
        write_images = epoch % save_every == 0
        trainer.train_epoch(epoch)
        trainer.log_loss_under_uniform(
            model,
            uniform_dataset,
            variant['algo_kwargs']['priority_function_kwargs'],
        )
        trainer.test_epoch(
            epoch,
            save_reconstruction=write_images,
            save_scatterplot=write_images,
        )
        if write_images:
            trainer.dump_samples(epoch)
            # NOTE(review): nesting of the debug dumps under the image-save
            # branch is reconstructed from flattened source; confirm.
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
                trainer.dump_uniform_imgs_and_reconstructions(
                    dataset=uniform_dataset, epoch=epoch)
        trainer.update_train_weights()
def experiment(variant):
    """Train a ``ConvVAE`` on image arrays loaded from five ``.npy`` shards.

    The shards are read from a hard-coded directory, stacked into one
    ``(50000, 21168)`` array, and split 90/10 into train and test data.

    Required ``variant`` keys: ``beta``, ``representation_size``,
    ``conv_vae_kwargs``, ``algo_kwargs``, ``save_period`` and ``num_epochs``.
    If ``beta_schedule_kwargs`` is present, ``flat_x`` and ``ramp_x`` are
    required and the schedule kwargs are modified in place.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    kl_weight = variant["beta"]
    latent_dim = variant["representation_size"]

    num_divisions = 5
    images = np.zeros((num_divisions * 10000, 21168))
    for shard in range(num_divisions):
        shard_path = (
            '/home/murtaza/vae_data/sawyer_torque_control_images100000_'
            + str(shard + 1)
            + '.npy'
        )
        first_row = shard * 10000
        images[first_row:first_row + 10000] = np.load(shard_path)
        print(shard)

    split_at = int(num_divisions * 10000 * .9)
    train_set = images[:split_at]
    test_set = images[split_at:]

    # No dataset metadata is available for this hand-loaded data.
    logger.save_extra_data(dict())
    logger.get_snapshot_dir()

    schedule = None
    if 'beta_schedule_kwargs' in variant:
        schedule_kwargs = variant['beta_schedule_kwargs']
        schedule_kwargs['y_values'][2] = variant['beta']
        flat_x = variant['flat_x']
        breakpoints = schedule_kwargs['x_values']
        breakpoints[1] = flat_x
        breakpoints[2] = variant['ramp_x'] + flat_x
        schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    model = ConvVAE(latent_dim, input_channels=3, **variant['conv_vae_kwargs'])
    if ptu.gpu_enabled():
        model.cuda()

    trainer = ConvVAETrainer(
        train_set,
        test_set,
        model,
        beta=kl_weight,
        beta_schedule=schedule,
        **variant['algo_kwargs']
    )
    save_every = variant['save_period']
    for epoch in range(variant['num_epochs']):
        write_images = epoch % save_every == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_reconstruction=write_images,
            save_scatterplot=write_images,
        )
        if write_images:
            trainer.dump_samples(epoch)
def experiment(variant):
    """Train a VAE on observations from a joblib file, with uniform-data logging.

    The file at ``variant['file']`` must contain ``obs`` and ``size``; the
    first ``size`` rows of ``obs`` are split 90/10 into train and test data.

    Other required ``variant`` keys: ``beta``, ``representation_size``,
    ``generate_uniform_dataset_kwargs``, ``vae`` (model class),
    ``vae_kwargs``, ``algo_kwargs``, ``save_period``, ``num_epochs`` and
    ``dump_skew_debug_plots``. Optional: ``beta_schedule_kwargs``, whose last
    ``y_values`` entry is overwritten in place with ``beta``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    kl_weight = variant["beta"]
    latent_dim = variant["representation_size"]

    saved = joblib.load(variant['file'])
    observations = saved['obs']
    size = int(saved['size'])
    valid_obs = observations[:size, :]
    num_train = int(size * .9)
    train_set = valid_obs[:num_train, :]
    test_set = valid_obs[num_train:, :]
    logger.get_snapshot_dir()
    print('SIZE: ', size)

    uniform_dataset = generate_uniform_dataset_door(
        **variant['generate_uniform_dataset_kwargs'])
    logger.get_snapshot_dir()

    schedule = None
    if 'beta_schedule_kwargs' in variant:
        # The schedule always ends at the configured beta.
        variant['beta_schedule_kwargs']['y_values'][-1] = variant['beta']
        schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    model = variant['vae'](
        latent_dim,
        decoder_output_activation=nn.Sigmoid(),
        **variant['vae_kwargs']
    )
    model.to(ptu.device)
    trainer = ConvVAETrainer(
        train_set,
        test_set,
        model,
        beta=kl_weight,
        beta_schedule=schedule,
        **variant['algo_kwargs']
    )

    save_every = variant['save_period']
    for epoch in range(variant['num_epochs']):
        write_images = epoch % save_every == 0
        trainer.train_epoch(epoch)
        trainer.log_loss_under_uniform(uniform_dataset)
        trainer.test_epoch(
            epoch,
            save_reconstruction=write_images,
            save_scatterplot=write_images,
        )
        if write_images:
            trainer.dump_samples(epoch)
            # NOTE(review): nesting of the debug dumps under the image-save
            # branch is reconstructed from flattened source; confirm.
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
                trainer.dump_uniform_imgs_and_reconstructions(
                    dataset=uniform_dataset, epoch=epoch)
        trainer.update_train_weights()
def experiment(variant):
    """Train a ``ConvVAE`` in ``state_sim_debug`` mode on paired images/states.

    Images and states are loaded from two hard-coded ``.npy`` files. State
    columns 7..13 are dropped, and the pairs are split 90/10 with
    ``train_test_split``.

    Required ``variant`` keys: ``beta``, ``representation_size``,
    ``conv_vae_kwargs``, ``algo_kwargs``, ``save_period`` and ``num_epochs``.
    Optional: ``beta_schedule_kwargs``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    kl_weight = variant["beta"]
    latent_dim = variant["representation_size"]

    # The data holds both states and images, so the usual VAE dataset
    # generator cannot be used here.
    images = np.load(
        '/home/murtaza/vae_data/sawyer_torque_control_ou_imgs_zoomed_out10000.npy'
    )
    states = np.load(
        '/home/murtaza/vae_data/sawyer_torque_control_ou_states_zoomed_out10000.npy'
    )
    # Keep columns [0, 7) and [14, end).
    states = np.concatenate((states[:, :7], states[:, 14:]), axis=1)
    img_train, img_test, state_train, state_test = train_test_split(
        images, states, test_size=.1)

    logger.save_extra_data(dict())
    logger.get_snapshot_dir()

    schedule = (
        PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])
        if 'beta_schedule_kwargs' in variant
        else None
    )

    model = ConvVAE(
        latent_dim,
        input_channels=3,
        state_sim_debug=True,
        state_size=states.shape[1],
        **variant['conv_vae_kwargs']
    )
    if ptu.gpu_enabled():
        model.cuda()

    trainer = ConvVAETrainer(
        (img_train, state_train),
        (img_test, state_test),
        model,
        beta=kl_weight,
        beta_schedule=schedule,
        state_sim_debug=True,
        **variant['algo_kwargs']
    )
    save_every = variant['save_period']
    for epoch in range(variant['num_epochs']):
        write_images = epoch % save_every == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_reconstruction=write_images,
            save_scatterplot=write_images,
        )
        if write_images:
            trainer.dump_samples(epoch)
def experiment(variant):
    """Train a ``ConvVAE`` for 1001 epochs with a fixed beta ramp.

    Required ``variant`` keys: ``use_gpu``, ``beta`` and
    ``representation_size``; ``gpu_id`` is required when ``use_gpu`` is
    truthy. Beta is held at 0.5 up to x=400 and then ramps linearly to
    ``variant['beta']`` at x=800.
    """
    if variant["use_gpu"]:
        device_index = variant["gpu_id"]
        ptu.set_gpu_mode(True)
        ptu.set_device(device_index)

    final_beta = variant["beta"]
    latent_dim = variant["representation_size"]
    train_set, test_set = get_data(10000)

    model = ConvVAE(latent_dim, input_channels=3)
    schedule = PiecewiseLinearSchedule([0, 400, 800], [0.5, 0.5, final_beta])
    trainer = ConvVAETrainer(train_set, test_set, model, beta_schedule=schedule)

    for epoch in range(1001):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
        trainer.dump_samples(epoch)
def experiment(variant):
    """Train a configurable VAE class, refreshing train weights every epoch.

    Required ``variant`` keys: ``beta``, ``representation_size``,
    ``generate_vae_dataset_fn``, ``generate_vae_dataset_kwargs``, ``vae``
    (model class), ``vae_kwargs``, ``algo_kwargs``, ``save_period``,
    ``num_epochs`` and ``dump_skew_debug_plots``. Optional:
    ``beta_schedule_kwargs``, whose last ``y_values`` entry is overwritten in
    place with ``beta``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    kl_weight = variant["beta"]
    latent_dim = variant["representation_size"]
    make_dataset = variant['generate_vae_dataset_fn']
    train_set, test_set, dataset_info = make_dataset(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(dataset_info)
    logger.get_snapshot_dir()

    schedule = None
    if 'beta_schedule_kwargs' in variant:
        # The schedule always ends at the configured beta.
        variant['beta_schedule_kwargs']['y_values'][-1] = variant['beta']
        schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    model = variant['vae'](latent_dim, **variant['vae_kwargs'])
    model.to(ptu.device)
    trainer = ConvVAETrainer(
        train_set,
        test_set,
        model,
        beta=kl_weight,
        beta_schedule=schedule,
        **variant['algo_kwargs']
    )

    save_every = variant['save_period']
    for epoch in range(variant['num_epochs']):
        write_images = epoch % save_every == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_reconstruction=write_images,
            save_scatterplot=write_images,
        )
        if write_images:
            trainer.dump_samples(epoch)
            # NOTE(review): nesting of the debug dumps under the image-save
            # branch is reconstructed from flattened source; confirm.
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
        trainer.update_train_weights()
def experiment(variant):
    """Train a ``ConvVAE`` and dump samples after every epoch.

    Required ``variant`` keys: ``beta``, ``representation_size``,
    ``get_data_kwargs``, ``beta_schedule_kwargs``, ``algo_kwargs`` and
    ``num_epochs``. Unlike the sibling launchers, the beta schedule is
    mandatory here.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    kl_weight = variant["beta"]
    latent_dim = variant["representation_size"]
    train_set, test_set, dataset_info = get_data(**variant['get_data_kwargs'])
    logger.save_extra_data(dataset_info)
    logger.get_snapshot_dir()

    schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    model = ConvVAE(latent_dim, input_channels=3)
    if ptu.gpu_enabled():
        model.to(ptu.device)

    trainer = ConvVAETrainer(
        train_set,
        test_set,
        model,
        beta=kl_weight,
        beta_schedule=schedule,
        **variant['algo_kwargs']
    )
    for epoch in range(variant['num_epochs']):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
        trainer.dump_samples(epoch)
def experiment(variant):
    """Train an ``ACAI`` model with ``ACAITrainer`` for 6001 epochs.

    Required ``variant`` keys: ``lmbda``, ``gamma``, ``mu``, ``beta``,
    ``representation_size``, ``generate_vae_dataset_kwargs`` and
    ``save_period``. Beta stays at 0 up to x=2500 and ramps linearly to
    ``variant['beta']`` at x=3500.
    """
    lmbda = variant['lmbda']
    gamma = variant['gamma']
    mu = variant['mu']
    schedule = PiecewiseLinearSchedule([0, 2500, 3500], [0, 0, variant['beta']])
    latent_dim = variant["representation_size"]

    # The dataset's info dict is returned but not used by this launcher.
    train_set, test_set, _info = generate_vae_dataset(
        **variant['generate_vae_dataset_kwargs'])

    model = ACAI(latent_dim, input_channels=3)
    trainer = ACAITrainer(
        train_set,
        test_set,
        model,
        beta_schedule=schedule,
        gamma=gamma,
        mu=mu,
        lmbda=lmbda,
    )
    for epoch in range(6001):
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch)
        on_save_epoch = epoch % variant['save_period'] == 0
        if on_save_epoch:
            trainer.dump_samples(epoch)
def experiment(variant):
    """Train a ``ConvVAE`` whose beta schedule ends at the configured beta.

    Required ``variant`` keys: ``beta``, ``representation_size``,
    ``generate_vae_dataset_kwargs``, ``conv_vae_kwargs``, ``algo_kwargs``,
    ``save_period`` and ``num_epochs``. Optional: ``beta_schedule_kwargs``,
    whose last ``y_values`` entry is overwritten in place with ``beta``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    kl_weight = variant["beta"]
    latent_dim = variant["representation_size"]
    train_set, test_set, dataset_info = generate_vae_dataset(
        **variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(dataset_info)
    logger.get_snapshot_dir()

    schedule = None
    if 'beta_schedule_kwargs' in variant:
        # The schedule always ends at the configured beta.
        variant['beta_schedule_kwargs']['y_values'][-1] = variant['beta']
        schedule = PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])

    model = ConvVAE(latent_dim, input_channels=3, **variant['conv_vae_kwargs'])
    if ptu.gpu_enabled():
        model.cuda()

    trainer = ConvVAETrainer(
        train_set,
        test_set,
        model,
        beta=kl_weight,
        beta_schedule=schedule,
        **variant['algo_kwargs']
    )
    save_every = variant['save_period']
    for epoch in range(variant['num_epochs']):
        write_images = epoch % save_every == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_reconstruction=write_images,
            save_scatterplot=write_images,
        )
        if write_images:
            trainer.dump_samples(epoch)
def train_vae(variant, return_data=False):
    """Train a convolutional VAE (or auto-encoder) on a dataset built from demos.

    Args:
        variant: experiment configuration. Reads ``beta``,
            ``representation_size``, ``generate_vae_dataset_kwargs``,
            ``vae_kwargs``, ``algo_kwargs``, ``save_period`` and
            ``num_epochs``; optionally ``beta_schedule_kwargs``,
            ``decoder_activation``, ``imsize``, ``use_spatial_auto_encoder``,
            ``vae_class`` and ``dump_skew_debug_plots``.
            ``variant['vae_kwargs']`` is updated in place with the resolved
            ``architecture`` and ``imsize``.
        return_data: when true, also return the train and test data.

    Returns:
        The trained model, or ``(model, train_data, test_data)`` when
        ``return_data`` is true.

    Raises:
        NotImplementedError: if ``use_spatial_auto_encoder`` is requested
            (and ``algo_kwargs['is_auto_encoder']`` is not set).
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.conv_vae import (
        ConvVAE,
        AutoEncoder,
    )
    import railrl.torch.vae.conv_vae as conv_vae
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset_from_demos(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity

    # Fall back to a default architecture for the known image sizes when the
    # variant does not supply one.
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    if variant['algo_kwargs'].get('is_auto_encoder', False):
        m = AutoEncoder(representation_size,
                        decoder_output_activation=decoder_activation,
                        **variant['vae_kwargs'])
    elif variant.get('use_spatial_auto_encoder', False):
        # The spatial auto-encoder path is disabled; the construction that
        # used to follow this raise was unreachable and has been dropped.
        raise NotImplementedError(
            'This is currently broken, please update SpatialAutoEncoder then remove this line'
        )
    else:
        vae_class = variant.get('vae_class', ConvVAE)
        m = vae_class(representation_size,
                      decoder_output_activation=decoder_activation,
                      **variant['vae_kwargs'])
    m.to(ptu.device)

    t = ConvVAETrainer(train_data, test_data, m, beta=beta,
                       beta_schedule=beta_schedule,
                       **variant['algo_kwargs'])
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
            save_scatterplot=should_save_imgs,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
            if dump_skew_debug_plots:
                t.dump_best_reconstruction(epoch)
                t.dump_worst_reconstruction(epoch)
                t.dump_sampling_histogram(epoch)
        t.update_train_weights()

    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    if return_data:
        return m, train_data, test_data
    return m
def train_vae(variant):
    """Train a VAE on a pre-collected dataset and periodically visualize it.

    Tabular logging is redirected from ``progress.csv`` to
    ``vae_progress.csv`` for the duration of training and restored at the
    end.

    Args:
        variant: experiment configuration. Reads
            ``generate_vae_dataset_kwargs`` (``env_id`` or
            ``env_class``/``env_kwargs``, ``dataset_path``, ``test_p``,
            ``imsize``, ``init_camera``), ``representation_size``, ``beta``,
            ``vae_kwargs``, ``algo_kwargs``, ``num_epochs``,
            ``snapshot_gap`` and ``vis_kwargs['save_period']``; ``imsize``
            unless ``vae_type`` is ``"VAE-state"``; optionally
            ``beta_schedule_kwargs``, ``vae_type``, ``render``,
            ``reward_params``, ``vae_wrapped_env_kwargs`` and
            ``dump_video_kwargs``. ``variant['vae_kwargs']`` is updated in
            place.

    Returns:
        The trained model.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.conv_vae import ConvVAE
    from railrl.torch.vae.conv_vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from multiworld.core.image_env import ImageEnv
    from railrl.envs.vae_wrappers import VAEWrappedEnv
    from railrl.misc.asset_loader import local_path_from_s3_or_local_path

    logger.remove_tabular_output('progress.csv', relative_to_snapshot_dir=True)
    logger.add_tabular_output('vae_progress.csv', relative_to_snapshot_dir=True)

    env_id = variant['generate_vae_dataset_kwargs'].get('env_id', None)
    if env_id is not None:
        import gym
        env = gym.make(env_id)
    else:
        env_class = variant['generate_vae_dataset_kwargs']['env_class']
        env_kwargs = variant['generate_vae_dataset_kwargs']['env_kwargs']
        env = env_class(**env_kwargs)

    representation_size = variant["representation_size"]
    beta = variant["beta"]
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    # obtain training and testing data
    dataset_path = variant['generate_vae_dataset_kwargs'].get(
        'dataset_path', None)
    # Despite its name, ``test_p`` is the fraction of rows used for training:
    # the first ``int(N * test_p)`` rows become the train split.
    test_p = variant['generate_vae_dataset_kwargs'].get('test_p', 0.9)
    filename = local_path_from_s3_or_local_path(dataset_path)
    dataset = np.load(filename, allow_pickle=True).item()

    N = dataset['obs'].shape[0]
    n = int(N * test_p)
    train_data = {k: v[:n, :] for k, v in dataset.items()}
    test_data = {k: v[n:, :] for k, v in dataset.items()}

    # setup vae
    vae_type = variant.get('vae_type', None)
    variant['vae_kwargs']['action_dim'] = train_data['actions'].shape[1]
    if vae_type == "VAE-state":
        from railrl.torch.vae.vae import VAE
        input_size = train_data['obs'].shape[1]
        variant['vae_kwargs']['input_size'] = input_size
        m = VAE(representation_size, **variant['vae_kwargs'])
    elif vae_type == "VAE2":
        from railrl.torch.vae.conv_vae2 import ConvVAE2
        variant['vae_kwargs']['imsize'] = variant['imsize']
        m = ConvVAE2(representation_size, **variant['vae_kwargs'])
    else:
        variant['vae_kwargs']['imsize'] = variant['imsize']
        m = ConvVAE(representation_size, **variant['vae_kwargs'])
    if ptu.gpu_enabled():
        m.cuda()

    # setup vae trainer
    if vae_type == "VAE-state":
        from railrl.torch.vae.vae_trainer import VAETrainer
        t = VAETrainer(train_data, test_data, m, beta=beta,
                       beta_schedule=beta_schedule, **variant['algo_kwargs'])
    else:
        t = ConvVAETrainer(train_data, test_data, m, beta=beta,
                           beta_schedule=beta_schedule,
                           **variant['algo_kwargs'])

    # visualization
    vis_variant = variant.get('vis_kwargs', {})
    save_video = vis_variant.get('save_video', False)
    if isinstance(env, ImageEnv):
        image_env = env
    else:
        image_env = ImageEnv(
            env,
            variant['generate_vae_dataset_kwargs'].get('imsize'),
            init_camera=variant['generate_vae_dataset_kwargs'].get(
                'init_camera'),
            transpose=True,
            normalize=True,
        )
    render = variant.get('render', False)
    reward_params = variant.get("reward_params", dict())
    vae_env = VAEWrappedEnv(image_env, m, imsize=image_env.imsize,
                            decode_goals=render, render_goals=render,
                            render_rollouts=render,
                            reward_params=reward_params,
                            **variant.get('vae_wrapped_env_kwargs', {}))
    vae_env.reset()
    vae_env.add_mode("video_env", 'video_env')
    vae_env.add_mode("video_vae", 'video_vae')

    if save_video:
        import railrl.samplers.rollout_functions as rf
        from railrl.policies.simple import RandomPolicy
        random_policy = RandomPolicy(vae_env.action_space)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=100,
            observation_key='latent_observation',
            desired_goal_key='latent_desired_goal',
            vis_list=vis_variant.get('vis_list', []),
            dont_terminate=True,
        )
        dump_video_kwargs = variant.get("dump_video_kwargs", dict())
        dump_video_kwargs['imsize'] = vae_env.imsize
        dump_video_kwargs['vis_list'] = [
            'image_observation',
            'reconstr_image_observation',
            'image_latent_histogram_2d',
            'image_latent_histogram_mu_2d',
            'image_plt',
            'image_rew',
            'image_rew_euclidean',
            'image_rew_mahalanobis',
            'image_rew_logp',
            'image_rew_kl',
            'image_rew_kl_rev',
        ]

    def visualization_post_processing(save_vis, save_video, epoch):
        """Dump the requested VAE visualizations (and videos) for ``epoch``."""
        vis_list = vis_variant.get('vis_list', [])

        if save_vis:
            # NOTE(review): only the reconstruction dump is taken to be
            # guarded by the 'state' check; nesting was reconstructed from
            # flattened source -- confirm.
            if vae_env.vae_input_key_prefix == 'state':
                vae_env.dump_reconstructions(
                    epoch, n_recon=vis_variant.get('n_recon', 16))
            vae_env.dump_samples(
                epoch, n_samples=vis_variant.get('n_samples', 64))
            if 'latent_representation' in vis_list:
                vae_env.dump_latent_plots(epoch)
            if any(elem in vis_list for elem in [
                'latent_histogram', 'latent_histogram_mu',
                'latent_histogram_2d', 'latent_histogram_mu_2d'
            ]):
                vae_env.compute_latent_histogram()
            if not save_video and ('latent_histogram' in vis_list):
                vae_env.dump_latent_histogram(epoch=epoch, noisy=True,
                                              use_true_prior=True)
            if not save_video and ('latent_histogram_mu' in vis_list):
                vae_env.dump_latent_histogram(epoch=epoch, noisy=False,
                                              use_true_prior=True)

        if save_video and save_vis:
            from railrl.envs.vae_wrappers import temporary_mode
            from railrl.misc.video_gen import dump_video

            vae_env.compute_goal_encodings()
            logdir = logger.get_snapshot_dir()
            filename = osp.join(logdir,
                                'video_{epoch}.mp4'.format(epoch=epoch))
            # Use the dict prepared above rather than indexing ``variant``:
            # when the variant has no 'dump_video_kwargs' entry, indexing
            # raises KeyError and would drop the imsize/vis_list settings.
            # When the entry exists this is the very same dict object.
            dump_video_kwargs['epoch'] = epoch
            temporary_mode(vae_env, mode='video_env', func=dump_video,
                           args=(vae_env, random_policy, filename,
                                 rollout_function),
                           kwargs=dump_video_kwargs)
            if not vis_variant.get('save_video_env_only', True):
                filename = osp.join(
                    logdir, 'video_{epoch}_vae.mp4'.format(epoch=epoch))
                temporary_mode(vae_env, mode='video_vae', func=dump_video,
                               args=(vae_env, random_policy, filename,
                                     rollout_function),
                               kwargs=dump_video_kwargs)

    # train vae
    last_epoch = variant['num_epochs'] - 1
    for epoch in range(variant['num_epochs']):
        save_vis = (epoch % vis_variant['save_period'] == 0
                    or epoch == last_epoch)
        save_vae = (epoch % variant['snapshot_gap'] == 0
                    or epoch == last_epoch)

        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=save_vis,
            save_interpolation=save_vis,
            save_vae=save_vae,
        )
        # Post-processing is attempted every 300 epochs and on the last one.
        if epoch % 300 == 0 or epoch == last_epoch:
            visualization_post_processing(save_vis, save_video, epoch)

    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    logger.remove_tabular_output(
        'vae_progress.csv',
        relative_to_snapshot_dir=True,
    )
    logger.add_tabular_output(
        'progress.csv',
        relative_to_snapshot_dir=True,
    )
    print("finished --------------------!!!!!!!!!!!!!!!")
    return m
def train_vae(variant, return_data=False):
    """Build and train a VAE, logging diagnostics and snapshots per epoch.

    Args:
        variant: experiment configuration. Reads ``beta``,
            ``generate_vae_dataset_kwargs``, ``vae_kwargs``, ``algo_kwargs``
            (must contain ``batch_size``), ``save_period`` and
            ``num_epochs``; optionally ``representation_size`` (falling back
            to ``latent_sizes``), ``use_linear_dynamics``,
            ``generate_vae_data_fctn``, ``beta_schedule_kwargs``,
            ``context_schedule``, ``decoder_activation``, ``imsize``,
            ``use_spatial_auto_encoder``, ``vae_class``,
            ``vae_trainer_class`` and ``dump_skew_debug_plots``.
            ``generate_vae_dataset_kwargs``, ``vae_kwargs`` and
            ``algo_kwargs`` are updated in place.
        return_data: when true, also return the train and test datasets.

    Returns:
        The model, or ``(model, train_dataset, test_dataset)`` when
        ``return_data`` is true.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule, ConstantSchedule
    from railrl.torch.vae.conv_vae import (
        ConvVAE,
        SpatialAutoEncoder,
        AutoEncoder,
    )
    import railrl.torch.vae.conv_vae as conv_vae
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    representation_size = variant.get("representation_size",
                                      variant.get("latent_sizes", None))
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    variant['generate_vae_dataset_kwargs'][
        'use_linear_dynamics'] = use_linear_dynamics
    variant['generate_vae_dataset_kwargs']['batch_size'] = variant[
        'algo_kwargs']['batch_size']
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])

    # ``action_dim`` is only needed (and only defined) for linear dynamics.
    if use_linear_dynamics:
        action_dim = train_dataset.data['actions'].shape[2]

    logger.save_extra_data(info)
    logger.get_snapshot_dir()

    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    if 'context_schedule' in variant:
        schedule = variant['context_schedule']
        # A mapping holds PiecewiseLinearSchedule kwargs; anything else is
        # treated as a constant value. ``isinstance`` also accepts dict
        # subclasses, which an exact ``type(...) is dict`` check would
        # wrongly hand to ConstantSchedule.
        if isinstance(schedule, dict):
            context_schedule = PiecewiseLinearSchedule(**schedule)
        else:
            context_schedule = ConstantSchedule(schedule)
        variant['algo_kwargs']['context_schedule'] = context_schedule

    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity

    # Fall back to a default architecture for the known image sizes when the
    # variant does not supply one.
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    if variant['algo_kwargs'].get('is_auto_encoder', False):
        model = AutoEncoder(representation_size,
                            decoder_output_activation=decoder_activation,
                            **variant['vae_kwargs'])
    elif variant.get('use_spatial_auto_encoder', False):
        model = SpatialAutoEncoder(
            representation_size,
            decoder_output_activation=decoder_activation,
            **variant['vae_kwargs'])
    else:
        vae_class = variant.get('vae_class', ConvVAE)
        if use_linear_dynamics:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              action_dim=action_dim,
                              **variant['vae_kwargs'])
        else:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              **variant['vae_kwargs'])
    model.to(ptu.device)

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model, beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        # Periodic snapshot of the model every 50 epochs.
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)

    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
def __init__(
        self,
        vae,
        *args,
        decoded_obs_key='image_observation',
        decoded_achieved_goal_key='image_achieved_goal',
        decoded_desired_goal_key='image_desired_goal',
        exploration_rewards_type='None',
        exploration_rewards_scale=1.0,
        vae_priority_type='None',
        start_skew_epoch=0,
        power=1.0,
        internal_keys=None,
        exploration_schedule_kwargs=None,
        priority_function_kwargs=None,
        exploration_counter_kwargs=None,
        relabeling_goal_sampling_mode='vae_prior',
        decode_vae_goals=False,
        **kwargs
):
    """Set up VAE-related bookkeeping on top of the parent buffer.

    Args:
        vae: model stored on ``self.vae``.
        *args: forwarded to the parent constructor.
        decoded_obs_key, decoded_achieved_goal_key, decoded_desired_goal_key:
            keys that are added to ``internal_keys`` if missing.
        exploration_rewards_type: name selecting the exploration reward
            function; ``'None'`` disables the bonus.
        exploration_rewards_scale: constant scale used when no
            ``exploration_schedule_kwargs`` is given; ``0`` disables the
            bonus.
        vae_priority_type: name selecting the prioritization function;
            ``'None'`` disables prioritized VAE sampling.
        start_skew_epoch: stored on ``self.start_skew_epoch``.
        power: stored on ``self.power``; ``0`` disables prioritized VAE
            sampling.
        internal_keys: optional iterable of keys passed to the parent. It
            is copied, so the caller's object is not modified.
        exploration_schedule_kwargs: if given, kwargs for a
            ``PiecewiseLinearSchedule`` of the exploration reward scale.
        priority_function_kwargs: stored on the instance (empty dict if
            ``None``).
        exploration_counter_kwargs: kwargs for ``CountExploration``; only
            used when ``exploration_rewards_type == 'hash_count'``.
        relabeling_goal_sampling_mode: stored on
            ``self._relabeling_goal_sampling_mode``.
        decode_vae_goals: stored on ``self.decode_vae_goals``.
        **kwargs: forwarded to the parent constructor.

    Raises:
        KeyError: if ``exploration_rewards_type`` or ``vae_priority_type``
            is not one of the known function names.
    """
    # Work on a private copy so a list supplied by the caller is never
    # mutated by the appends below.
    internal_keys = [] if internal_keys is None else list(internal_keys)

    for key in [
        decoded_obs_key,
        decoded_achieved_goal_key,
        decoded_desired_goal_key
    ]:
        if key not in internal_keys:
            internal_keys.append(key)
    super().__init__(*args, internal_keys=internal_keys, **kwargs)
    # assert isinstance(self.env, VAEWrappedEnv)
    self.vae = vae
    self.decoded_obs_key = decoded_obs_key
    self.decoded_desired_goal_key = decoded_desired_goal_key
    self.decoded_achieved_goal_key = decoded_achieved_goal_key
    self.exploration_rewards_type = exploration_rewards_type
    self.exploration_rewards_scale = exploration_rewards_scale
    self.start_skew_epoch = start_skew_epoch
    self.vae_priority_type = vae_priority_type
    self.power = power
    self._relabeling_goal_sampling_mode = relabeling_goal_sampling_mode
    self.decode_vae_goals = decode_vae_goals

    if exploration_schedule_kwargs is None:
        self.explr_reward_scale_schedule = \
            ConstantSchedule(self.exploration_rewards_scale)
    else:
        self.explr_reward_scale_schedule = \
            PiecewiseLinearSchedule(**exploration_schedule_kwargs)

    # The bonus is active only for a real reward type with non-zero scale.
    self._give_explr_reward_bonus = (
        exploration_rewards_type != 'None'
        and exploration_rewards_scale != 0.
    )
    self._exploration_rewards = np.zeros((self.max_size, 1), dtype=np.float32)
    # Prioritization is active only for a real priority type and power != 0.
    self._prioritize_vae_samples = (
        vae_priority_type != 'None'
        and power != 0.
    )
    self._vae_sample_priorities = np.zeros((self.max_size, 1), dtype=np.float32)
    self._vae_sample_probs = None

    self.use_dynamics_model = (
        self.exploration_rewards_type == 'forward_model_error'
    )
    if self.use_dynamics_model:
        self.initialize_dynamics_model()

    # Shared name -> method table for both the exploration reward and the
    # VAE prioritization function.
    type_to_function = {
        'reconstruction_error': self.reconstruction_mse,
        'bce': self.binary_cross_entropy,
        'latent_distance': self.latent_novelty,
        'latent_distance_true_prior': self.latent_novelty_true_prior,
        'forward_model_error': self.forward_model_error,
        'gaussian_inv_prob': self.gaussian_inv_prob,
        'bernoulli_inv_prob': self.bernoulli_inv_prob,
        'vae_prob': self.vae_prob,
        'hash_count': self.hash_count_reward,
        'None': self.no_reward,
    }

    self.exploration_reward_func = (
        type_to_function[self.exploration_rewards_type]
    )
    self.vae_prioritization_func = (
        type_to_function[self.vae_priority_type]
    )

    if priority_function_kwargs is None:
        self.priority_function_kwargs = dict()
    else:
        self.priority_function_kwargs = priority_function_kwargs

    if self.exploration_rewards_type == 'hash_count':
        if exploration_counter_kwargs is None:
            exploration_counter_kwargs = dict()
        self.exploration_counter = CountExploration(
            env=self.env, **exploration_counter_kwargs)
    self.epoch = 0