def experiment(variant):
    """Train a VAE and log its loss on a uniform reacher dataset each epoch.

    The train/test data come from ``variant['generate_vae_dataset_fn']`` and
    the model class from ``variant['vae']``.  A second dataset, built by
    ``generate_uniform_dataset_reacher``, is passed to the trainer for
    per-epoch loss logging and for the uniform image dumps.  No beta
    schedule is used.

    Args:
        variant: Experiment configuration dict.  Keys read here: ``beta``,
            ``representation_size``, ``generate_vae_dataset_fn``,
            ``generate_vae_dataset_kwargs``,
            ``generate_uniform_dataset_kwargs``, ``vae``, ``vae_kwargs``,
            ``algo_kwargs``, ``save_period``, ``num_epochs``,
            ``dump_skew_debug_plots`` and ``train_weight_update_period``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    latent_size = variant["representation_size"]

    build_dataset = variant['generate_vae_dataset_fn']
    train_data, test_data, info = build_dataset(
        variant['generate_vae_dataset_kwargs']
    )
    uniform_dataset = generate_uniform_dataset_reacher(
        **variant['generate_uniform_dataset_kwargs']
    )
    logger.save_extra_data(info)
    # Return value is intentionally unused, as in the original script.
    logger.get_snapshot_dir()

    vae_cls = variant['vae']
    model = vae_cls(
        latent_size,
        decoder_output_activation=nn.Sigmoid(),
        **variant['vae_kwargs']
    )
    model.to(ptu.device)
    trainer = ConvVAETrainer(
        train_data,
        test_data,
        model,
        beta=beta,
        beta_schedule=None,
        **variant['algo_kwargs']
    )

    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        # Images are only written every ``save_period`` epochs.
        save_images = epoch % save_period == 0

        trainer.train_epoch(epoch)
        trainer.log_loss_under_uniform(model, uniform_dataset)
        trainer.test_epoch(
            epoch,
            save_reconstruction=save_images,
            save_scatterplot=save_images,
        )

        if save_images:
            trainer.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
                trainer.dump_uniform_imgs_and_reconstructions(
                    dataset=uniform_dataset,
                    epoch=epoch,
                )

        # Training weights are refreshed only on multiples of the period.
        if epoch % variant['train_weight_update_period'] == 0:
            trainer.update_train_weights()
def experiment(variant):
    """Train a VAE, logging its loss on a uniform dataset loaded from disk.

    The uniform dataset is read with ``load_local_or_remote_file`` from
    ``variant['uniform_dataset_path']``; its ``'image_desired_goal'`` entry
    is passed through ``unormalize_image`` before use.

    When ``variant`` contains ``'beta_schedule_kwargs'``, the last entry of
    its ``'y_values'`` is overwritten **in place** with ``variant['beta']``
    and a ``PiecewiseLinearSchedule`` is built from the result; otherwise no
    schedule is used.

    Args:
        variant: Experiment configuration dict.  Keys read here: ``beta``,
            ``representation_size``, ``generate_vae_dataset_fn``,
            ``generate_vae_dataset_kwargs``, ``uniform_dataset_path``,
            ``vae``, ``vae_kwargs``, ``algo_kwargs`` (including its
            ``priority_function_kwargs`` entry), ``save_period``,
            ``num_epochs``, ``dump_skew_debug_plots`` and, optionally,
            ``beta_schedule_kwargs``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    latent_size = variant["representation_size"]

    build_dataset = variant['generate_vae_dataset_fn']
    train_data, test_data, info = build_dataset(
        variant['generate_vae_dataset_kwargs']
    )

    uniform_payload = load_local_or_remote_file(
        variant['uniform_dataset_path']
    ).item()
    uniform_dataset = unormalize_image(uniform_payload['image_desired_goal'])

    logger.save_extra_data(info)
    # Return value is intentionally unused, as in the original script.
    logger.get_snapshot_dir()

    beta_schedule = None
    if 'beta_schedule_kwargs' in variant:
        schedule_kwargs = variant['beta_schedule_kwargs']
        # The schedule's final value is tied to the configured beta; this
        # writes into the caller's variant.
        schedule_kwargs['y_values'][-1] = variant['beta']
        beta_schedule = PiecewiseLinearSchedule(**schedule_kwargs)

    vae_cls = variant['vae']
    model = vae_cls(
        latent_size,
        decoder_output_activation=nn.Sigmoid(),
        **variant['vae_kwargs']
    )
    model.to(ptu.device)
    trainer = ConvVAETrainer(
        train_data,
        test_data,
        model,
        beta=beta,
        beta_schedule=beta_schedule,
        **variant['algo_kwargs']
    )

    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        # Images are only written every ``save_period`` epochs.
        save_images = epoch % save_period == 0

        trainer.train_epoch(epoch)
        trainer.log_loss_under_uniform(
            model,
            uniform_dataset,
            variant['algo_kwargs']['priority_function_kwargs'],
        )
        trainer.test_epoch(
            epoch,
            save_reconstruction=save_images,
            save_scatterplot=save_images,
        )

        if save_images:
            trainer.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
                trainer.dump_uniform_imgs_and_reconstructions(
                    dataset=uniform_dataset,
                    epoch=epoch,
                )

        trainer.update_train_weights()
def experiment(variant):
    """Train a VAE on observations loaded from a joblib file.

    The file at ``variant['file']`` must provide ``'obs'`` and ``'size'``.
    The first ``size`` rows of ``obs`` are split 90% / 10% into train and
    test data.  A uniform dataset from ``generate_uniform_dataset_door`` is
    used for per-epoch loss logging and for the uniform image dumps.

    When ``variant`` contains ``'beta_schedule_kwargs'``, the last entry of
    its ``'y_values'`` is overwritten **in place** with ``variant['beta']``
    and a ``PiecewiseLinearSchedule`` is built from the result; otherwise no
    schedule is used.

    Args:
        variant: Experiment configuration dict.  Keys read here: ``beta``,
            ``representation_size``, ``file``,
            ``generate_uniform_dataset_kwargs``, ``vae``, ``vae_kwargs``,
            ``algo_kwargs``, ``save_period``, ``num_epochs``,
            ``dump_skew_debug_plots`` and, optionally,
            ``beta_schedule_kwargs``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    latent_size = variant["representation_size"]

    # NOTE: joblib.load unpickles its input; only use with trusted files.
    payload = joblib.load(variant['file'])
    observations = payload['obs']
    size = int(payload['size'])
    valid_rows = observations[:size, :]
    split_at = int(size * 0.9)
    train_data = valid_rows[:split_at, :]
    test_data = valid_rows[split_at:, :]

    # Return value is intentionally unused, as in the original script.
    logger.get_snapshot_dir()
    print('SIZE: ', size)
    uniform_dataset = generate_uniform_dataset_door(
        **variant['generate_uniform_dataset_kwargs']
    )
    logger.get_snapshot_dir()

    beta_schedule = None
    if 'beta_schedule_kwargs' in variant:
        schedule_kwargs = variant['beta_schedule_kwargs']
        # The schedule's final value is tied to the configured beta; this
        # writes into the caller's variant.
        schedule_kwargs['y_values'][-1] = variant['beta']
        beta_schedule = PiecewiseLinearSchedule(**schedule_kwargs)

    vae_cls = variant['vae']
    model = vae_cls(
        latent_size,
        decoder_output_activation=nn.Sigmoid(),
        **variant['vae_kwargs']
    )
    model.to(ptu.device)
    trainer = ConvVAETrainer(
        train_data,
        test_data,
        model,
        beta=beta,
        beta_schedule=beta_schedule,
        **variant['algo_kwargs']
    )

    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        # Images are only written every ``save_period`` epochs.
        save_images = epoch % save_period == 0

        trainer.train_epoch(epoch)
        trainer.log_loss_under_uniform(uniform_dataset)
        trainer.test_epoch(
            epoch,
            save_reconstruction=save_images,
            save_scatterplot=save_images,
        )

        if save_images:
            trainer.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
                trainer.dump_uniform_imgs_and_reconstructions(
                    dataset=uniform_dataset,
                    epoch=epoch,
                )

        trainer.update_train_weights()
def experiment(variant):
    """Train a VAE with an optional piecewise-linear beta schedule.

    The train/test data come from ``variant['generate_vae_dataset_fn']`` and
    the model class from ``variant['vae']``.

    When ``variant`` contains ``'beta_schedule_kwargs'``, the last entry of
    its ``'y_values'`` is overwritten **in place** with ``variant['beta']``
    and a ``PiecewiseLinearSchedule`` is built from the result; otherwise no
    schedule is used.

    Args:
        variant: Experiment configuration dict.  Keys read here: ``beta``,
            ``representation_size``, ``generate_vae_dataset_fn``,
            ``generate_vae_dataset_kwargs``, ``vae``, ``vae_kwargs``,
            ``algo_kwargs``, ``save_period``, ``num_epochs``,
            ``dump_skew_debug_plots`` and, optionally,
            ``beta_schedule_kwargs``.
    """
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    latent_size = variant["representation_size"]

    build_dataset = variant['generate_vae_dataset_fn']
    train_data, test_data, info = build_dataset(
        variant['generate_vae_dataset_kwargs']
    )
    logger.save_extra_data(info)
    # Return value is intentionally unused, as in the original script.
    logger.get_snapshot_dir()

    beta_schedule = None
    if 'beta_schedule_kwargs' in variant:
        schedule_kwargs = variant['beta_schedule_kwargs']
        # The schedule's final value is tied to the configured beta; this
        # writes into the caller's variant.
        schedule_kwargs['y_values'][-1] = variant['beta']
        beta_schedule = PiecewiseLinearSchedule(**schedule_kwargs)

    vae_cls = variant['vae']
    model = vae_cls(latent_size, **variant['vae_kwargs'])
    model.to(ptu.device)
    trainer = ConvVAETrainer(
        train_data,
        test_data,
        model,
        beta=beta,
        beta_schedule=beta_schedule,
        **variant['algo_kwargs']
    )

    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        # Images are only written every ``save_period`` epochs.
        save_images = epoch % save_period == 0

        trainer.train_epoch(epoch)
        trainer.test_epoch(
            epoch,
            save_reconstruction=save_images,
            save_scatterplot=save_images,
        )

        if save_images:
            trainer.dump_samples(epoch)
            if variant['dump_skew_debug_plots']:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        trainer.update_train_weights()
def train_vae(variant, return_data=False):
    """Train a VAE (or plain auto-encoder) on demo data and save it.

    The dataset is produced by ``generate_vae_dataset_from_demos``.  After
    training, the model is stored through ``logger.save_extra_data`` under
    the name ``'vae.pkl'``.

    Args:
        variant: Experiment configuration dict.
            Required keys: ``beta``, ``representation_size``,
            ``generate_vae_dataset_kwargs``, ``vae_kwargs``,
            ``algo_kwargs``, ``save_period``, ``num_epochs``.
            Optional keys: ``beta_schedule_kwargs``, ``decoder_activation``
            (``'sigmoid'`` selects ``torch.nn.Sigmoid``; anything else uses
            ``identity``), ``imsize``, ``use_spatial_auto_encoder``,
            ``vae_class`` (defaults to ``ConvVAE``) and
            ``dump_skew_debug_plots`` (defaults to ``False``).
            ``variant['vae_kwargs']`` is updated in place with the resolved
            ``'architecture'`` and ``'imsize'`` entries.
        return_data: When true, also return the train and test data.

    Returns:
        The trained model, or ``(model, train_data, test_data)`` when
        ``return_data`` is true.

    Raises:
        NotImplementedError: If ``use_spatial_auto_encoder`` is set and
            ``algo_kwargs['is_auto_encoder']`` is not.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.conv_vae import (
        ConvVAE,
        AutoEncoder,
    )
    import railrl.torch.vae.conv_vae as conv_vae
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset_from_demos(
        variant['generate_vae_dataset_kwargs']
    )
    logger.save_extra_data(info)
    # Return value is intentionally unused, as in the original code.
    logger.get_snapshot_dir()

    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs']
        )
    else:
        beta_schedule = None

    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity

    # Fall back to a default architecture for the two known image sizes
    # when the caller did not provide one.
    vae_kwargs = variant['vae_kwargs']
    imsize = variant.get('imsize')
    architecture = vae_kwargs.get('architecture', None)
    if not architecture and imsize == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and imsize == 48:
        architecture = conv_vae.imsize48_default_architecture
    vae_kwargs['architecture'] = architecture
    vae_kwargs['imsize'] = imsize

    if variant['algo_kwargs'].get('is_auto_encoder', False):
        m = AutoEncoder(
            representation_size,
            decoder_output_activation=decoder_activation,
            **vae_kwargs
        )
    elif variant.get('use_spatial_auto_encoder', False):
        # TODO: once SpatialAutoEncoder is fixed, import it here and build
        # it with (representation_size, representation_size // 2) instead
        # of raising.  The unreachable construction that used to follow
        # this raise has been removed.
        raise NotImplementedError(
            'This is currently broken, please update SpatialAutoEncoder then remove this line'
        )
    else:
        vae_class = variant.get('vae_class', ConvVAE)
        m = vae_class(
            representation_size,
            decoder_output_activation=decoder_activation,
            **vae_kwargs
        )
    m.to(ptu.device)

    t = ConvVAETrainer(
        train_data,
        test_data,
        m,
        beta=beta,
        beta_schedule=beta_schedule,
        **variant['algo_kwargs']
    )
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        # Images are only written every ``save_period`` epochs.
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
            save_scatterplot=should_save_imgs,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
            if dump_skew_debug_plots:
                t.dump_best_reconstruction(epoch)
                t.dump_worst_reconstruction(epoch)
                t.dump_sampling_histogram(epoch)
        t.update_train_weights()

    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    if return_data:
        return m, train_data, test_data
    return m