def experiment2(directory_name, dataset='MNIST', direction='vae_to_iwae'):
    '''Train a model of one type starting from a trained model of the other type.

    Loads checkpoint 8 of a previously trained one-layer model, lowers the
    learning rate to 1e-4, trains for 3**7 epochs under the new objective,
    saves the result as checkpoint 0 of ``directory_name`` and runs
    ``post_experiment`` on it.

    directory_name: directory where the new checkpoint is saved.
    dataset: dataset name; it is used both to load the data and to locate
        the checkpoint of the model that training starts from.
    direction: 'vae_to_iwae' starts from a VAE trained with k=1 and trains it
        as an IWAE with k=50; 'iwae_to_vae' starts from an IWAE trained with
        k=50 and trains it as a VAE with k=1.

    Raises ValueError for an unknown ``direction`` and RuntimeError when the
    starting checkpoint cannot be loaded.
    '''
    # direction -> (model type of the starting checkpoint, its k,
    #               model type to continue training as, its k)
    direction_settings = {
        'vae_to_iwae': ('vae', 1, 'iwae', 50),
        'iwae_to_vae': ('iwae', 50, 'vae', 1),
    }
    # Validate before doing any expensive work. Previously an unknown direction
    # surfaced later as an UnboundLocalError.
    if direction not in direction_settings:
        raise ValueError("Unknown direction {!r}; expected one of {}".format(
            direction, sorted(direction_settings)))
    previous_model_type, previous_k, new_model_type, new_k = direction_settings[direction]

    # Keep the name: the loaded dataset object replaces it below, but the
    # checkpoint lookup needs the name, not the hard-coded 'MNIST'.
    dataset_name = dataset
    dataset = datasets.load_dataset_from_name(dataset_name)

    previous_directory_name = directory_to_store(layers=1, model=previous_model_type, k=previous_k,
                                                 dataset=dataset_name, exp='train')
    loaded_checkpoint, model, optimizer, srng = load_checkpoint(previous_directory_name, 8)
    # training_experiment treats -1 from load_checkpoint as "unable to load".
    # Fail loudly here instead of training whatever came back.
    if loaded_checkpoint < 0:
        raise RuntimeError("Unable to load checkpoint 8 from {}".format(previous_directory_name))

    optimizer.learning_rate = 1e-4
    model = train.train(model=model, dataset=dataset, optimizer=optimizer, minibatch_size=20,
                        n_epochs=3**7, srng=srng, num_samples=new_k, model_type=new_model_type)
    save_checkpoint(directory_name, 0, model, optimizer, srng)
    post_experiment(directory_name, dataset, model)
def training_experiment(directory_name, latent_units, hidden_units_q, hidden_units_p, k, model_type, dataset, checkpoint=-1):
    '''Train a model with the given parameters through checkpoints 0..8.

    Checkpoint 0 is a freshly initialised model. Checkpoint i (1 <= i <= 8)
    is produced from checkpoint i-1 by training for 3**(i-1) epochs with
    learning rate 1e-4 * round(10**(1 - (i-1)/7), 1), which is 1e-3 at i=1
    and 1e-4 at i=8. Every checkpoint is saved, so an interrupted run can be
    resumed by passing ``checkpoint``. ``post_experiment`` runs on the final
    model.

    directory_name: directory where checkpoints are saved/loaded.
    latent_units: sequence of latent layer sizes; the data dimension is
        prepended before building the model.
    hidden_units_q, hidden_units_p: forwarded to iwae.random_iwae.
    k: forwarded to train.train as num_samples.
    model_type: forwarded to train.train (e.g. 'vae' or 'iwae').
    dataset: dataset name, resolved with datasets.load_dataset_from_name.
    checkpoint: if >= 0, try to resume from this checkpoint; if loading fails
        the experiment restarts from the beginning.
    '''
    def checkpoint0(dataset):
        # Build a fresh model, random stream and Adam optimizer (lr 1e-3).
        data_dimension = dataset.get_data_dim()
        # list() so any sequence (e.g. a tuple) of latent sizes is accepted.
        model = iwae.random_iwae(latent_units=[data_dimension] + list(latent_units),
                                 hidden_units_q=hidden_units_q,
                                 hidden_units_p=hidden_units_p,
                                 dataset=dataset)
        srng = utils.srng()
        optimizer = optimizers.Adam(model=model, learning_rate=1e-3)
        return model, optimizer, srng

    def checkpoint1to8(i, dataset, model, optimizer, srng):
        # Stage i: learning rate decays from 1e-3 (i=1) to 1e-4 (i=8);
        # the number of epochs triples with each stage.
        optimizer.learning_rate = 1e-4*round(10.**(1-(i-1)/7.), 1)
        model = train.train(model=model, dataset=dataset, optimizer=optimizer,
                            minibatch_size=20, n_epochs=3**(i-1), srng=srng,
                            num_samples=k, model_type=model_type)
        return model, optimizer, srng

    dataset = datasets.load_dataset_from_name(dataset)

    loaded_checkpoint = -1
    if checkpoint >= 0:
        loaded_checkpoint, model, optimizer, srng = load_checkpoint(directory_name, checkpoint)
        if loaded_checkpoint == -1:
            # Single-argument print(...) behaves the same in Python 2 and 3.
            print("Unable to load checkpoint {} from {}, starting the experiment from the beginning".format(checkpoint, directory_name))

    if loaded_checkpoint < 0:
        model, optimizer, srng = checkpoint0(dataset)
        save_checkpoint(directory_name, 0, model, optimizer, srng)
        loaded_checkpoint = 0

    # Resume from the stage after the last checkpoint we have.
    for i in range(loaded_checkpoint+1, 9):
        model, optimizer, srng = checkpoint1to8(i, dataset, model, optimizer, srng)
        save_checkpoint(directory_name, i, model, optimizer, srng)

    post_experiment(directory_name, dataset, model)