def experiment(variant):
    """Train a plain AutoEncoder on a generated image dataset.

    Args:
        variant: config dict with keys ``beta``, ``representation_size``,
            ``get_data_kwargs``, ``algo_kwargs``, ``save_period``,
            ``num_epochs`` and optionally ``beta_schedule_kwargs``.
    """
    from railrl.core import logger
    # These were referenced but never imported here; sibling functions in this
    # module import them locally, so do the same for consistency.
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.conv_vae import AutoEncoder
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    import railrl.torch.pytorch_util as ptu

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    # NOTE(review): generate_vae_dataset is presumably a module-level helper
    # (it is also referenced unqualified elsewhere in this file) — confirm.
    train_data, test_data, info = generate_vae_dataset(
        **variant['get_data_kwargs']
    )
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    m = AutoEncoder(representation_size, input_channels=3)
    if ptu.gpu_enabled():
        m.to(ptu.device)
    t = ConvVAETrainer(train_data, test_data, m, beta=beta,
                       beta_schedule=beta_schedule, **variant['algo_kwargs'])
    save_period = variant['save_period']
    for epoch in range(variant['num_epochs']):
        # Dump images only every `save_period` epochs to limit disk churn.
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(epoch,
                     save_reconstruction=should_save_imgs,
                     save_scatterplot=should_save_imgs)
        if should_save_imgs:
            t.dump_samples(epoch)
def get_vae(variant, action_dim):
    """Construct the (auto)encoder model described by ``variant``.

    Args:
        variant: config dict; reads ``representation_size``, ``vae_kwargs``,
            ``algo_kwargs`` and several optional switches.
        action_dim: action dimensionality, forwarded to the VAE only when
            ``use_linear_dynamics`` is set.

    Returns:
        The model, moved to the active torch device.
    """
    from railrl.torch.vae.conv_vae import (
        ConvVAE,
        SpatialAutoEncoder,
        AutoEncoder,
    )
    import railrl.torch.vae.conv_vae as conv_vae
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch

    representation_size = variant["representation_size"]
    use_linear_dynamics = variant.get('use_linear_dynamics', False)

    # Decoder output squashing: sigmoid when requested, otherwise identity.
    decoder_activation = (
        torch.nn.Sigmoid()
        if variant.get('decoder_activation', None) == 'sigmoid'
        else identity
    )

    vae_kwargs = variant['vae_kwargs']
    imsize = variant.get('imsize')
    architecture = vae_kwargs.get('architecture', None)
    if not architecture:
        # No explicit architecture: fall back to the stock one for this imsize.
        stock_architectures = {
            84: conv_vae.imsize84_default_architecture,
            48: conv_vae.imsize48_default_architecture,
        }
        architecture = stock_architectures.get(imsize, architecture)
    vae_kwargs['architecture'] = architecture
    vae_kwargs['imsize'] = imsize

    if variant['algo_kwargs'].get('is_auto_encoder', False):
        model = AutoEncoder(representation_size,
                            decoder_output_activation=decoder_activation,
                            **vae_kwargs)
    elif variant.get('use_spatial_auto_encoder', False):
        model = SpatialAutoEncoder(
            representation_size,
            decoder_output_activation=decoder_activation,
            **vae_kwargs)
    else:
        vae_class = variant.get('vae_class', ConvVAE)
        # Only a dynamics-aware VAE needs to know the action dimension.
        extra_kwargs = {'action_dim': action_dim} if use_linear_dynamics else {}
        model = vae_class(representation_size,
                          decoder_output_activation=decoder_activation,
                          **extra_kwargs,
                          **vae_kwargs)
    model.to(ptu.device)
    return model
def train_vae(variant, return_data=False):
    """Train a (V)AE on demo-generated data.

    NOTE(review): a later ``train_vae`` definition in this module shadows
    this one — confirm which definition callers are meant to get.

    Args:
        variant: experiment config dict (``beta``, ``representation_size``,
            ``generate_vae_dataset_kwargs``, ``vae_kwargs``, ``algo_kwargs``,
            ``save_period``, ``num_epochs``, ...).
        return_data: if True, also return the train/test data.

    Returns:
        The trained model, or ``(model, train_data, test_data)`` when
        ``return_data`` is True.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule
    from railrl.torch.vae.conv_vae import (
        ConvVAE,
        AutoEncoder,
    )
    import railrl.torch.vae.conv_vae as conv_vae
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    representation_size = variant["representation_size"]
    train_data, test_data, info = generate_vae_dataset_from_demos(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['vae_kwargs'].get('architecture', None)
    # Default architectures exist only for 84px and 48px images.
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')
    if variant['algo_kwargs'].get('is_auto_encoder', False):
        m = AutoEncoder(representation_size,
                        decoder_output_activation=decoder_activation,
                        **variant['vae_kwargs'])
    elif variant.get('use_spatial_auto_encoder', False):
        # Fix: removed the dead `m = SpatialAutoEncoder(...)` assignment that
        # followed this raise — it was unreachable.
        raise NotImplementedError(
            'This is currently broken, please update SpatialAutoEncoder then remove this line'
        )
    else:
        vae_class = variant.get('vae_class', ConvVAE)
        m = vae_class(representation_size,
                      decoder_output_activation=decoder_activation,
                      **variant['vae_kwargs'])
    m.to(ptu.device)
    t = ConvVAETrainer(train_data, test_data, m, beta=beta,
                       beta_schedule=beta_schedule, **variant['algo_kwargs'])
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        # Dump images only every `save_period` epochs to limit disk churn.
        should_save_imgs = (epoch % save_period == 0)
        t.train_epoch(epoch)
        t.test_epoch(
            epoch,
            save_reconstruction=should_save_imgs,
            save_scatterplot=should_save_imgs,
        )
        if should_save_imgs:
            t.dump_samples(epoch)
            if dump_skew_debug_plots:
                t.dump_best_reconstruction(epoch)
                t.dump_worst_reconstruction(epoch)
                t.dump_sampling_histogram(epoch)
        t.update_train_weights()
    logger.save_extra_data(m, 'vae.pkl', mode='pickle')
    if return_data:
        return m, train_data, test_data
    return m
def train_vae(variant, return_data=False):
    """Train a (V)AE with the newer trainer API (per-epoch diagnostics).

    NOTE(review): this redefines the earlier ``train_vae`` in this module and
    shadows it — confirm the earlier definition is obsolete.

    Args:
        variant: experiment config dict.
        return_data: if True, also return the train/test datasets.

    Returns:
        The trained model, or ``(model, train_dataset, test_dataset)`` when
        ``return_data`` is True.
    """
    from railrl.misc.ml_util import PiecewiseLinearSchedule, ConstantSchedule
    from railrl.torch.vae.conv_vae import (
        ConvVAE,
        SpatialAutoEncoder,
        AutoEncoder,
    )
    import railrl.torch.vae.conv_vae as conv_vae
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.core import logger
    import railrl.torch.pytorch_util as ptu
    from railrl.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    # Fall back to the legacy 'latent_sizes' key when present.
    representation_size = variant.get(
        "representation_size", variant.get("latent_sizes", None))
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    variant['generate_vae_dataset_kwargs'][
        'use_linear_dynamics'] = use_linear_dynamics
    variant['generate_vae_dataset_kwargs']['batch_size'] = variant[
        'algo_kwargs']['batch_size']
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])
    if use_linear_dynamics:
        # Actions are only needed when fitting a linear dynamics model.
        action_dim = train_dataset.data['actions'].shape[2]
    logger.save_extra_data(info)
    logger.get_snapshot_dir()
    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None
    if 'context_schedule' in variant:
        schedule = variant['context_schedule']
        # A dict describes a piecewise-linear schedule; a scalar is constant.
        # Fix: use isinstance() instead of `type(...) is dict`.
        if isinstance(schedule, dict):
            context_schedule = PiecewiseLinearSchedule(**schedule)
        else:
            context_schedule = ConstantSchedule(schedule)
        variant['algo_kwargs']['context_schedule'] = context_schedule
    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity
    architecture = variant['vae_kwargs'].get('architecture', None)
    # Default architectures exist only for 84px and 48px images.
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')
    if variant['algo_kwargs'].get('is_auto_encoder', False):
        model = AutoEncoder(representation_size,
                            decoder_output_activation=decoder_activation,
                            **variant['vae_kwargs'])
    elif variant.get('use_spatial_auto_encoder', False):
        model = SpatialAutoEncoder(
            representation_size,
            decoder_output_activation=decoder_activation,
            **variant['vae_kwargs'])
    else:
        vae_class = variant.get('vae_class', ConvVAE)
        if use_linear_dynamics:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              action_dim=action_dim,
                              **variant['vae_kwargs'])
        else:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              **variant['vae_kwargs'])
    model.to(ptu.device)
    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model, beta=beta,
                                beta_schedule=beta_schedule,
                                **variant['algo_kwargs'])
    save_period = variant['save_period']
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        # Dump images only every `save_period` epochs to limit disk churn.
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)
        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)
        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)
        # Periodic model snapshot every 50 epochs.
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'vae.pkl', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model