def train_vae(variant, other_variant, return_data=False):
    """Train a ConvVAE on a generated image dataset and pickle it.

    Args:
        variant: Experiment configuration dict. Keys read here:
            "beta", "representation_size", "generate_vae_dataset_kwargs",
            "vae_kwargs", "algo_kwargs", "save_period" and "num_epochs";
            optionally "generate_vae_data_fctn" (defaults to
            ``generate_vae_dataset``), "beta_schedule_kwargs",
            "decoder_activation", "imsize" and "dump_skew_debug_plots".
            ``variant['vae_kwargs']`` is modified in place: its
            "architecture" and "imsize" entries are overwritten.
        other_variant: Extra configuration passed positionally to
            ``ConvVAETrainer`` as its fourth argument.
        return_data: When true, the train/test datasets are returned too.

    Returns:
        The trained ConvVAE, or ``(vae, train_data, test_data)`` when
        ``return_data`` is true.
    """
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    latent_dim = variant["representation_size"]

    dataset_builder = variant.get('generate_vae_data_fctn',
                                  generate_vae_dataset)
    train_data, test_data, info = dataset_builder(
        variant['generate_vae_dataset_kwargs'])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()  # return value is not used here

    beta_schedule = (
        PiecewiseLinearSchedule(**variant['beta_schedule_kwargs'])
        if 'beta_schedule_kwargs' in variant
        else None
    )
    use_sigmoid = variant.get('decoder_activation', None) == 'sigmoid'
    decoder_activation = torch.nn.Sigmoid() if use_sigmoid else identity

    vae_kwargs = variant['vae_kwargs']
    architecture = vae_kwargs.get('architecture', None)
    imsize = variant.get('imsize')
    if not architecture:
        # Fall back to the library's default layouts for 84px / 48px images.
        if imsize == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif imsize == 48:
            architecture = conv_vae.imsize48_default_architecture
    vae_kwargs['architecture'] = architecture
    vae_kwargs['imsize'] = imsize

    vae = ConvVAE(latent_dim,
                  decoder_output_activation=decoder_activation,
                  **vae_kwargs)
    vae.to(ptu.device)
    trainer = ConvVAETrainer(train_data, test_data, vae, other_variant,
                             beta=beta,
                             beta_schedule=beta_schedule,
                             **variant['algo_kwargs'])

    save_period = variant['save_period']
    # NOTE(review): this flag is read but never used by this launcher.
    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        save_images = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch, save_reconstruction=save_images)
        if save_images:
            trainer.dump_samples(epoch)
        trainer.update_train_weights()

    logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    # NOTE: an earlier revision also did
    #   torch.save(vae, other_variant['vae_pkl_path'] + '/online_vae.pkl')
    # as an easy way to reload the model for the bonus computation.
    if return_data:
        return vae, train_data, test_data
    return vae
def train_vae(variant, return_data=False):
    """Build, train and pickle the model described by ``variant``.

    Depending on the configuration the model is an ``AutoEncoder``, a
    ``SpatialAutoEncoder`` or ``variant.get('vae_class', ConvVAE)``. It is
    trained with ``variant.get('vae_trainer_class', ConvVAETrainer)``.

    Side effects on ``variant`` (all in place):
        * ``variant['algo_kwargs']['num_epochs']`` is set to
          ``variant['num_epochs']``, and ``'context_schedule'`` is added
          when ``variant`` has a ``'context_schedule'`` entry.
        * ``variant['generate_vae_dataset_kwargs']`` receives
          ``'use_linear_dynamics'`` and ``'batch_size'`` (copied from
          ``variant['algo_kwargs']['batch_size']``).
        * ``variant['vae_kwargs']`` receives the resolved ``'architecture'``
          and ``'imsize'``.

    Args:
        variant: Experiment configuration dict.
        return_data: When true, also return the train/test datasets.

    Returns:
        ``model`` or ``(model, train_dataset, test_dataset)``.
    """
    from rlkit.util.ml_util import PiecewiseLinearSchedule, ConstantSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch
    # NOTE(review): `gym` is not referenced in this function; the import is
    # kept so the function's import-time requirements are unchanged.
    import gym  # noqa: F401

    beta = variant["beta"]
    # Accept any of the three historical key names for the latent size.
    representation_size = variant.get(
        "representation_size",
        variant.get("latent_sizes", variant.get("embedding_dim", None)))
    use_linear_dynamics = variant.get('use_linear_dynamics', False)
    algo_kwargs = variant['algo_kwargs']
    algo_kwargs['num_epochs'] = variant['num_epochs']

    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    dataset_kwargs = variant['generate_vae_dataset_kwargs']
    dataset_kwargs['use_linear_dynamics'] = use_linear_dynamics
    dataset_kwargs['batch_size'] = algo_kwargs['batch_size']
    train_dataset, test_dataset, info = generate_vae_dataset_fctn(
        dataset_kwargs)

    if use_linear_dynamics:
        # Size of the third axis of the stored action array.
        action_dim = train_dataset.data['actions'].shape[2]

    logger.save_extra_data(info)
    logger.get_snapshot_dir()  # return value is not used here

    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    if 'context_schedule' in variant:
        schedule = variant['context_schedule']
        # isinstance (rather than `type(...) is dict`) so dict subclasses,
        # e.g. OrderedDict or attribute-dicts produced by config loaders,
        # are treated as piecewise-linear kwargs instead of being wrapped
        # as a constant value.
        if isinstance(schedule, dict):
            context_schedule = PiecewiseLinearSchedule(**schedule)
        else:
            context_schedule = ConstantSchedule(schedule)
        algo_kwargs['context_schedule'] = context_schedule

    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity

    vae_kwargs = variant['vae_kwargs']
    architecture = vae_kwargs.get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    vae_kwargs['architecture'] = architecture
    vae_kwargs['imsize'] = variant.get('imsize')

    # NOTE(review): `AutoEncoder` and `SpatialAutoEncoder` are not imported
    # in this function (their import from rlkit.torch.vae.conv_vae was
    # commented out upstream). These two branches presumably rely on a
    # module-level import; otherwise they raise NameError. Verify.
    if algo_kwargs.get('is_auto_encoder', False):
        model = AutoEncoder(representation_size,
                            decoder_output_activation=decoder_activation,
                            **vae_kwargs)
    elif variant.get('use_spatial_auto_encoder', False):
        model = SpatialAutoEncoder(
            representation_size,
            decoder_output_activation=decoder_activation,
            **vae_kwargs)
    elif variant.get('only_kwargs', False):
        # The model class is built purely from vae_kwargs.
        vae_class = variant.get('vae_class', ConvVAE)
        model = vae_class(**vae_kwargs)
    else:
        vae_class = variant.get('vae_class', ConvVAE)
        if use_linear_dynamics:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              action_dim=action_dim,
                              **vae_kwargs)
        else:
            model = vae_class(representation_size,
                              decoder_output_activation=decoder_activation,
                              **vae_kwargs)
    model.to(ptu.device)

    vae_trainer_class = variant.get('vae_trainer_class', ConvVAETrainer)
    trainer = vae_trainer_class(model, beta=beta,
                                beta_schedule=beta_schedule,
                                **algo_kwargs)
    save_period = variant['save_period']

    dump_skew_debug_plots = variant.get('dump_skew_debug_plots', False)
    for epoch in range(variant['num_epochs']):
        should_save_imgs = (epoch % save_period == 0)
        trainer.train_epoch(epoch, train_dataset)
        trainer.test_epoch(epoch, test_dataset)

        if should_save_imgs:
            trainer.dump_reconstructions(epoch)
            trainer.dump_samples(epoch)
            if dump_skew_debug_plots:
                trainer.dump_best_reconstruction(epoch)
                trainer.dump_worst_reconstruction(epoch)
                trainer.dump_sampling_histogram(epoch)

        stats = trainer.get_diagnostics()
        for k, v in stats.items():
            logger.record_tabular(k, v)
        logger.dump_tabular()
        trainer.end_epoch(epoch)

        # Periodic parameter snapshot every 50 epochs.
        if epoch % 50 == 0:
            logger.save_itr_params(epoch, model)
    logger.save_extra_data(model, 'model', mode='pickle')
    if return_data:
        return model, train_dataset, test_dataset
    return model
def train_vae(variant, return_data=False):
    """Train a ConvVAE on a generated image dataset and pickle it.

    Args:
        variant: Experiment configuration dict. Required keys: "beta",
            "representation_size", "generate_vae_dataset_kwargs",
            "vae_kwargs", "algo_kwargs", "save_period", "num_epochs".
            Optional keys: "generate_vae_data_fctn" (defaults to
            ``generate_vae_dataset``), "beta_schedule_kwargs",
            "decoder_activation", "imsize", "dump_skew_debug_plots".
            ``variant["vae_kwargs"]`` is modified in place: its
            "architecture" and "imsize" entries are overwritten.
        return_data: When true, the datasets are returned as well.

    Returns:
        ``model`` or ``(model, train_data, test_data)``.
    """
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    beta = variant["beta"]
    latent_size = variant["representation_size"]

    make_datasets = variant.get("generate_vae_data_fctn",
                                generate_vae_dataset)
    train_data, test_data, info = make_datasets(
        variant["generate_vae_dataset_kwargs"])
    logger.save_extra_data(info)
    logger.get_snapshot_dir()  # result is not used here

    if "beta_schedule_kwargs" not in variant:
        beta_schedule = None
    else:
        beta_schedule = PiecewiseLinearSchedule(
            **variant["beta_schedule_kwargs"])

    if variant.get("decoder_activation", None) != "sigmoid":
        decoder_activation = identity
    else:
        decoder_activation = torch.nn.Sigmoid()

    vae_kwargs = variant["vae_kwargs"]
    architecture = vae_kwargs.get("architecture", None)
    imsize = variant.get("imsize")
    if not architecture:
        # Only 84px and 48px images have a built-in default layout.
        if imsize == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif imsize == 48:
            architecture = conv_vae.imsize48_default_architecture
    vae_kwargs["architecture"] = architecture
    vae_kwargs["imsize"] = imsize

    model = ConvVAE(latent_size,
                    decoder_output_activation=decoder_activation,
                    **vae_kwargs)
    model.to(ptu.device)
    trainer = ConvVAETrainer(train_data, test_data, model,
                             beta=beta,
                             beta_schedule=beta_schedule,
                             **variant["algo_kwargs"])

    save_period = variant["save_period"]
    # NOTE(review): read but not used anywhere in this function.
    dump_skew_debug_plots = variant.get("dump_skew_debug_plots", False)
    num_epochs = variant["num_epochs"]
    for epoch in range(num_epochs):
        is_save_epoch = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch, save_reconstruction=is_save_epoch)
        if is_save_epoch:
            trainer.dump_samples(epoch)
        trainer.update_train_weights()

    logger.save_extra_data(model, "vae.pkl", mode="pickle")
    return (model, train_data, test_data) if return_data else model
def train_vae(cfgs, return_data=False):
    """Train a ConvVAE from a sectioned config object and pickle it.

    Args:
        cfgs: Config object with ``VAE``, ``ENV`` and ``VAE_TRAINER``
            sections. The sections are read both by attribute
            (``cfgs.VAE.representation_size``) and with ``.get`` / ``in``.
            The whole object is also passed to ``generate_vae_dataset``.
        return_data: When true, also return the train/test datasets.

    Returns:
        ``vae_model`` or ``(vae_model, train_data, test_data)``.
    """
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    import rlkit.torch.vae.conv_vae as conv_vae
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.core import logger
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    train_data, test_data, info = generate_vae_dataset(cfgs)
    logger.save_extra_data(info)
    logger.get_snapshot_dir()  # return value is not used here

    vae_cfg = cfgs.VAE

    # FIXME (original note: "default gaussian"): the decoder output is the
    # identity unless 'sigmoid' is configured.
    wants_sigmoid = vae_cfg.get('decoder_activation', None) == 'sigmoid'
    decoder_activation = torch.nn.Sigmoid() if wants_sigmoid else identity

    architecture = vae_cfg.get('architecture', None)
    if not architecture:
        # NOTE(review): the default layout is chosen from cfgs.ENV.img_size
        # while the model below is built with cfgs.VAE.img_size; confirm
        # the two always agree.
        env_img_size = cfgs.ENV.get('img_size')
        if env_img_size == 84:
            architecture = conv_vae.imsize84_default_architecture
        elif env_img_size == 48:
            architecture = conv_vae.imsize48_default_architecture

    vae_model = ConvVAE(
        representation_size=vae_cfg.representation_size,
        architecture=architecture,
        decoder_output_activation=decoder_activation,
        input_channels=vae_cfg.input_channels,
        decoder_distribution=vae_cfg.decoder_distribution,
        imsize=vae_cfg.img_size,
    )
    vae_model.to(ptu.device)

    trainer_cfg = cfgs.VAE_TRAINER
    # FIXME (original note): the role of beta_schedule is undocumented; it
    # is only forwarded to ConvVAETrainer here.
    beta_schedule = None
    if 'beta_schedule_kwargs' in trainer_cfg:
        beta_schedule = PiecewiseLinearSchedule(
            **trainer_cfg.beta_schedule_kwargs)

    trainer = ConvVAETrainer(train_data, test_data, vae_model,
                             lr=trainer_cfg.lr,
                             beta=trainer_cfg.beta,
                             beta_schedule=beta_schedule)

    save_period = trainer_cfg.save_period
    for epoch in range(trainer_cfg.num_epochs):
        dump_this_epoch = epoch % save_period == 0
        trainer.train_epoch(epoch)
        trainer.test_epoch(epoch, save_reconstruction=dump_this_epoch)
        if dump_this_epoch:
            trainer.dump_samples(epoch)
        trainer.update_train_weights()

    logger.save_extra_data(vae_model, 'vae.pkl', mode='pickle')

    if not return_data:
        return vae_model
    return vae_model, train_data, test_data
def _load_lstm_eval_data(env_id, train_data_lstm, test_data_lstm, info_lstm):
    """Return ``(all_data, obj_states)`` for the LSTM latent-distance plots.

    For ``SawyerPushHurdle-v0``, if the pre-generated file
    ``$PJHOME/data/local/pre-train-lstm/SawyerPushHurdle-v0-seg-color-500-0.3-0.5.npy``
    exists, the images come from that file. The object states then come
    from the matching ``...-0.3-0.5-puck-pos.npy`` file.

    Otherwise the LSTM train and test images are concatenated along axis 0,
    and the object states are taken from ``info_lstm['obj_state']``.

    In both cases the images are passed through ``normalize_image``.

    Raises:
        KeyError: if the ``PJHOME`` environment variable is not set.
    """
    pjhome = os.environ['PJHOME']
    data_dir = osp.join(pjhome, 'data/local/pre-train-lstm')
    data_file_path = osp.join(
        data_dir, '{}-{}-{}-0.3-0.5.npy'.format(env_id, 'seg-color', 500))
    if env_id == 'SawyerPushHurdle-v0' and osp.exists(data_file_path):
        puck_pos_path = osp.join(
            data_dir,
            '{}-{}-{}-0.3-0.5-puck-pos.npy'.format(env_id, 'seg-color', 500))
        all_data = np.load(data_file_path)
        obj_states = np.load(puck_pos_path)
        all_data = normalize_image(all_data.copy())
    else:
        all_data = np.concatenate([train_data_lstm, test_data_lstm], axis=0)
        all_data = normalize_image(all_data.copy())
        obj_states = info_lstm['obj_state']
    return all_data, obj_states


def train_vae(variant, return_data=False, skewfit_variant=None):
    """Pre-train the segmentation LSTM, then the original-image VAE.

    Despite the name, this function builds two models in sequence:

    * index 0, ``'lstm_seg_pretrain'``: a ``ConvLSTM2`` trained by
      ``ConvLSTMTrainer``. Its data is the segmented dataset returned by
      ``variant['generate_lstm_data_fctn']``.
    * index 1, ``'vae_ori_pretrain'``: a ``ConvVAE`` trained by
      ``ConvVAETrainer``. Its data comes from
      ``variant.get('generate_vae_data_fctn', generate_vae_dataset)``.

    Both models are always constructed. Each one is trained, and pickled as
    ``'<name>.pkl'``, only when ``variant['seg_pretrain']`` (LSTM) or
    ``variant['ori_pretrain']`` (VAE) is truthy. Each stage writes its
    tabular logs to ``'<name>_progress.csv'``.

    ``variant['vae_kwargs']`` and ``variant['lstm_kwargs']`` are updated in
    place with the resolved ``'architecture'`` and ``'imsize'``.

    Args:
        variant: Experiment configuration dict.
        return_data: When true, return the models and datasets (see Returns).
        skewfit_variant: Mapping that provides ``'segmentation_method'``.
            It is required in practice despite the ``None`` default.

    Returns:
        When ``return_data`` is true, ``(vaes, train_datas, test_datas)``,
        each a list ordered [LSTM, VAE]. Otherwise only the last model
        built, i.e. the original-image ConvVAE.

    Raises:
        ValueError: if ``variant['lstm_version']`` is set to anything but 2.
        SystemExit: right after the LSTM stage when
            ``variant['only_train_lstm']`` is truthy.
    """
    import sys
    from rlkit.util.ml_util import PiecewiseLinearSchedule
    from rlkit.torch.vae.conv_vae import ConvVAE
    import rlkit.torch.vae.conv_vae as conv_vae
    import ROLL.LSTM_model as LSTM_model
    from ROLL.LSTM_model import ConvLSTM2
    from ROLL.LSTM_trainer import ConvLSTMTrainer
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    import rlkit.torch.pytorch_util as ptu
    from rlkit.pythonplusplus import identity
    import torch

    seg_pretrain = variant['seg_pretrain']
    ori_pretrain = variant['ori_pretrain']
    generate_vae_dataset_fctn = variant.get('generate_vae_data_fctn',
                                            generate_vae_dataset)
    generate_lstm_dataset_fctn = variant.get('generate_lstm_data_fctn')
    assert generate_lstm_dataset_fctn is not None, \
        "Must provide a custom generate lstm pretraining dataset function!"

    train_data_lstm, test_data_lstm, info_lstm = generate_lstm_dataset_fctn(
        variant['generate_lstm_dataset_kwargs'], segmented=True,
        segmentation_method=skewfit_variant['segmentation_method'])
    train_data_ori, test_data_ori, info_ori = generate_vae_dataset_fctn(
        variant['generate_vae_dataset_kwargs'])

    if 'beta_schedule_kwargs' in variant:
        beta_schedule = PiecewiseLinearSchedule(
            **variant['beta_schedule_kwargs'])
    else:
        beta_schedule = None

    if variant.get('decoder_activation', None) == 'sigmoid':
        decoder_activation = torch.nn.Sigmoid()
    else:
        decoder_activation = identity

    # Resolve the VAE architecture.
    architecture = variant['vae_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = conv_vae.imsize84_default_architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = conv_vae.imsize48_default_architecture
    variant['vae_kwargs']['architecture'] = architecture
    variant['vae_kwargs']['imsize'] = variant.get('imsize')

    # Resolve the LSTM architecture. Only a 48px default exists.
    architecture = variant['lstm_kwargs'].get('architecture', None)
    if not architecture and variant.get('imsize') == 84:
        architecture = None  # TODO LSTM: wrap an 84px lstm architecture
    elif not architecture and variant.get('imsize') == 48:
        architecture = LSTM_model.imsize48_default_architecture
    variant['lstm_kwargs']['architecture'] = architecture
    variant['lstm_kwargs']['imsize'] = variant.get('imsize')

    train_datas = [train_data_lstm, train_data_ori]
    test_datas = [test_data_lstm, test_data_ori]
    names = ['lstm_seg_pretrain', 'vae_ori_pretrain']

    vaes = []
    env_id = variant['generate_lstm_dataset_kwargs'].get('env_id')
    assert env_id is not None
    lstm_pretrain_vae_only = variant.get('lstm_pretrain_vae_only', False)
    obj = 'door' if 'Door' in env_id else 'puck'
    save_period = variant['save_period']

    for idx in range(2):
        train_data, test_data, name = (
            train_datas[idx], test_datas[idx], names[idx])
        is_lstm = idx == 0

        logger.add_tabular_output('{}_progress.csv'.format(name),
                                  relative_to_snapshot_dir=True)

        if is_lstm:  # the segmentation LSTM
            lstm_version = variant.get('lstm_version', 2)
            if lstm_version != 2:
                # Previously this surfaced as an UnboundLocalError on
                # `lstm_class`; fail with an explicit message instead.
                raise ValueError(
                    'Unsupported lstm_version {!r}; only version 2 '
                    '(ConvLSTM2) is available.'.format(lstm_version))
            lstm_class = ConvLSTM2
            representation_size = variant.get(
                "lstm_representation_size", variant.get('representation_size'))
            beta = variant.get('lstm_beta', variant.get('beta'))
            model = lstm_class(representation_size,
                               decoder_output_activation=decoder_activation,
                               **variant['lstm_kwargs'])
            trainer = ConvLSTMTrainer(train_data, test_data, model,
                                      beta=beta,
                                      beta_schedule=beta_schedule,
                                      **variant['algo_kwargs'])
        else:  # the VAE on original images
            representation_size = variant.get(
                "vae_representation_size", variant.get('representation_size'))
            beta = variant.get('vae_beta', variant.get('beta'))
            model = ConvVAE(representation_size,
                            decoder_output_activation=decoder_activation,
                            **variant['vae_kwargs'])
            trainer = ConvVAETrainer(train_data, test_data, model,
                                     beta=beta,
                                     beta_schedule=beta_schedule,
                                     **variant['algo_kwargs'])

        model.to(ptu.device)
        vaes.append(model)

        print("test data len: ", len(test_data))
        print("train data len: ", len(train_data))

        num_epochs = (variant['num_lstm_epochs'] if is_lstm
                      else variant['num_vae_epochs'])
        should_train = seg_pretrain if is_lstm else ori_pretrain

        if should_train:
            if is_lstm:
                # Only the LSTM stage produces the latent-distance plots, so
                # load the (possibly large) evaluation data only here.
                all_data, obj_states = _load_lstm_eval_data(
                    env_id, train_data_lstm, test_data_lstm, info_lstm)

            for epoch in range(num_epochs):
                should_save_imgs = (epoch % save_period == 0)
                if is_lstm:
                    # Only the LSTM trainer accepts `only_train_vae`.
                    trainer.train_epoch(epoch,
                                        only_train_vae=lstm_pretrain_vae_only)
                    trainer.test_epoch(epoch,
                                       save_reconstruction=should_save_imgs,
                                       save_prefix='r_' + name,
                                       only_train_vae=lstm_pretrain_vae_only)
                else:
                    trainer.train_epoch(epoch)
                    trainer.test_epoch(epoch,
                                       save_reconstruction=should_save_imgs,
                                       save_prefix='r_' + name)

                if should_save_imgs:
                    trainer.dump_samples(epoch, save_prefix='s_' + name)
                    if is_lstm:
                        snapshot_dir = logger.get_snapshot_dir()
                        compare_latent_distance(
                            model, all_data, obj_states, obj_name=obj,
                            save_dir=snapshot_dir,
                            save_name='lstm_latent_distance_{}.png'.format(
                                epoch))
                        test_lstm_traj(
                            env_id, model, save_path=snapshot_dir,
                            save_name='lstm_test_traj_{}.png'.format(epoch))
                        test_masked_traj_lstm(
                            env_id, model, save_dir=snapshot_dir,
                            save_name='masked_test_{}.png'.format(epoch))

                trainer.update_train_weights()

            logger.save_extra_data(model, '{}.pkl'.format(name),
                                   mode='pickle')

        logger.remove_tabular_output('{}_progress.csv'.format(name),
                                     relative_to_snapshot_dir=True)

        if is_lstm and variant.get("only_train_lstm", False):
            # sys.exit rather than the site-provided `exit`, which is meant
            # for the interactive shell and may be absent (e.g. `python -S`).
            sys.exit()

    if return_data:
        return vaes, train_datas, test_datas
    return model