def main(input_file_or_name,
         data_directory = "data",
         log_directory = "log",
         results_directory = "results",
         temporary_log_directory = None,
         map_features = False,
         feature_selection = [],
         example_filter = [],
         preprocessing_methods = [],
         noisy_preprocessing_methods = [],
         split_data_set = True,
         splitting_method = "default",
         splitting_fraction = 0.9,
         model_type = "VAE",
         latent_size = 50,
         hidden_sizes = [500],
         number_of_importance_samples = [5],
         number_of_monte_carlo_samples = [10],
         inference_architecture = "MLP",
         latent_distribution = "gaussian",
         number_of_classes = None,
         parameterise_latent_posterior = False,
         generative_architecture = "MLP",
         reconstruction_distribution = "poisson",
         number_of_reconstruction_classes = 0,
         prior_probabilities_method = "uniform",
         number_of_warm_up_epochs = 0,
         kl_weight = 1,
         proportion_of_free_KL_nats = 0.0,
         batch_normalisation = True,
         dropout_keep_probabilities = [],
         count_sum = True,
         number_of_epochs = 200,
         plotting_interval_during_training = None,
         batch_size = 100,
         learning_rate = 1e-4,
         run_id = None,
         new_run = False,
         prediction_method = None,
         prediction_training_set_name = "training",
         prediction_decomposition_method = None,
         prediction_decomposition_dimensionality = None,
         decomposition_methods = ["PCA"],
         highlight_feature_indices = [],
         reset_training = False,
         skip_modelling = False,
         model_versions = ["all"],
         analyse = True,
         evaluation_set_name = "test",
         analyse_data = False,
         analyses = ["default"],
         analysis_level = "normal",
         fast_analysis = False,
         export_options = []):

    # Setup

    model_versions = parseModelVersions(model_versions)

    ## Analyses

    if fast_analysis:
        analyse = True
        analyses = ["simple"]
        analysis_level = "limited"

    ## Distributions

    reconstruction_distribution = parseDistribution(
        reconstruction_distribution)
    latent_distribution = parseDistribution(latent_distribution)

    ## Model configuration validation

    if not skip_modelling:

        if run_id:
            run_id = checkRunID(run_id)

        model_valid, model_errors = validateModelParameters(
            model_type, latent_distribution,
            reconstruction_distribution, number_of_reconstruction_classes,
            parameterise_latent_posterior
        )

        if not model_valid:
            print("Model configuration is invalid:")
            for model_error in model_errors:
                print(" ", model_error)
            print()

            if analyse_data:
                print("Skipping modelling.")
                print("")
                skip_modelling = True
            else:
                print("Modelling cancelled.")
                return

    ## Binarisation

    binarise_values = False

    if reconstruction_distribution == "bernoulli":
        if noisy_preprocessing_methods:
            if noisy_preprocessing_methods[-1] != "binarise":
                noisy_preprocessing_methods.append("binarise")
                print("Appended binarisation method to noisy preprocessing,",
                      "because of the Bernoulli distribution.\n")
        else:
            binarise_values = True

    ## Data sets

    if not split_data_set or analyse_data or evaluation_set_name == "full" \
            or prediction_training_set_name == "full":
        full_data_set_needed = True
    else:
        full_data_set_needed = False

    # Data

    print(title("Data"))

    data_set = data.DataSet(
        input_file_or_name,
        directory = data_directory,
        map_features = map_features,
        feature_selection = feature_selection,
        example_filter = example_filter,
        preprocessing_methods = preprocessing_methods,
        binarise_values = binarise_values,
        noisy_preprocessing_methods = noisy_preprocessing_methods
    )

    if full_data_set_needed:
        data_set.load()

    if split_data_set:
        training_set, validation_set, test_set = data_set.split(
            splitting_method, splitting_fraction)
        all_data_sets = [data_set, training_set, validation_set, test_set]
    else:
        splitting_method = None
        training_set = data_set
        validation_set = None
        test_set = data_set
        all_data_sets = [data_set]
        evaluation_set_name = "full"
        prediction_training_set_name = "full"

    ## Setup of log and results directories

    log_directory = data.directory(log_directory, data_set,
                                   splitting_method, splitting_fraction)
    data_results_directory = data.directory(results_directory, data_set,
                                            splitting_method, splitting_fraction,
                                            preprocessing = False)
    results_directory = data.directory(results_directory, data_set,
                                       splitting_method, splitting_fraction)

    if temporary_log_directory:
        main_temporary_log_directory = temporary_log_directory
        temporary_log_directory = data.directory(temporary_log_directory, data_set,
                                                 splitting_method, splitting_fraction)

    ## Data analysis

    if analyse and analyse_data:
        print(subtitle("Analysing data"))
        analysis.analyseData(
            data_sets = all_data_sets,
            decomposition_methods = decomposition_methods,
            highlight_feature_indices = highlight_feature_indices,
            analyses = analyses,
            analysis_level = analysis_level,
            export_options = export_options,
            results_directory = data_results_directory
        )
        print()

    ## Full data set clean up

    if not full_data_set_needed:
        data_set.clear()

    # Modelling

    if skip_modelling:
        print("Modelling skipped.")
        return

    print(title("Modelling"))

    # Set the number of features for the model
    feature_size = training_set.number_of_features

    # Parse numbers of samples
    number_of_monte_carlo_samples = parseSampleLists(
        number_of_monte_carlo_samples)
    number_of_importance_samples = parseSampleLists(
        number_of_importance_samples)

    # Use analytical KL term for single-Gaussian-VAE
    if "VAE" in model_type:
        if latent_distribution == "gaussian":
            analytical_kl_term = True
        else:
            analytical_kl_term = False

    # Change latent distribution to Gaussian mixture if not already set
    if model_type == "GMVAE" and latent_distribution != "gaussian mixture":
        latent_distribution = "gaussian mixture"
        print("The latent distribution was changed to",
              "a Gaussian-mixture model, because of the model chosen.\n")

    # Set the number of classes if not already set
    if not number_of_classes:
        if training_set.has_labels:
            number_of_classes = training_set.number_of_classes \
                - training_set.number_of_excluded_classes
        elif "mixture" in latent_distribution:
            raise ValueError(
                "For a mixture model and a data set without labels, "
                "the number of classes has to be set."
            )
        else:
            number_of_classes = 1

    print(subtitle("Model setup"))

    if model_type == "VAE":
        model = VariationalAutoencoder(
            feature_size = feature_size,
            latent_size = latent_size,
            hidden_sizes = hidden_sizes,
            number_of_monte_carlo_samples = number_of_monte_carlo_samples,
            number_of_importance_samples = number_of_importance_samples,
            analytical_kl_term = analytical_kl_term,
            inference_architecture = inference_architecture,
            latent_distribution = latent_distribution,
            number_of_latent_clusters = number_of_classes,
            parameterise_latent_posterior = parameterise_latent_posterior,
            generative_architecture = generative_architecture,
            reconstruction_distribution = reconstruction_distribution,
            number_of_reconstruction_classes = number_of_reconstruction_classes,
            batch_normalisation = batch_normalisation,
            dropout_keep_probabilities = dropout_keep_probabilities,
            count_sum = count_sum,
            number_of_warm_up_epochs = number_of_warm_up_epochs,
            kl_weight = kl_weight,
            log_directory = log_directory,
            results_directory = results_directory
        )

    elif model_type == "GMVAE":

        if prior_probabilities_method == "uniform":
            prior_probabilities = None
        elif prior_probabilities_method == "infer":
            prior_probabilities = training_set.class_probabilities
        elif prior_probabilities_method == "literature":
            prior_probabilities = training_set.literature_probabilities
        else:
            prior_probabilities = None

        if not prior_probabilities:
            prior_probabilities_method = "uniform"
            prior_probabilities_values = None
        else:
            prior_probabilities_values = list(prior_probabilities.values())

        prior_probabilities = {
            "method": prior_probabilities_method,
            "values": prior_probabilities_values
        }

        model = GaussianMixtureVariationalAutoencoder(
            feature_size = feature_size,
            latent_size = latent_size,
            hidden_sizes = hidden_sizes,
            number_of_monte_carlo_samples = number_of_monte_carlo_samples,
            number_of_importance_samples = number_of_importance_samples,
            analytical_kl_term = analytical_kl_term,
            prior_probabilities = prior_probabilities,
            number_of_latent_clusters = number_of_classes,
            proportion_of_free_KL_nats = proportion_of_free_KL_nats,
            reconstruction_distribution = reconstruction_distribution,
            number_of_reconstruction_classes = number_of_reconstruction_classes,
            batch_normalisation = batch_normalisation,
            dropout_keep_probabilities = dropout_keep_probabilities,
            count_sum = count_sum,
            number_of_warm_up_epochs = number_of_warm_up_epochs,
            kl_weight = kl_weight,
            log_directory = log_directory,
            results_directory = results_directory
        )

    else:
        raise ValueError("Model type not found: `{}`.".format(model_type))

    print(model.description)
    print()

    print(model.parameters)
    print()

    ## Training

    print(subtitle("Model training"))

    status, run_id = model.train(
        training_set,
        validation_set,
        number_of_epochs = number_of_epochs,
        batch_size = batch_size,
        learning_rate = learning_rate,
        plotting_interval = plotting_interval_during_training,
        run_id = run_id,
        new_run = new_run,
        reset_training = reset_training,
        temporary_log_directory = temporary_log_directory
    )

    # Remove temporary directories created and emptied during training
    if temporary_log_directory and os.path.exists(main_temporary_log_directory):
        removeEmptyDirectories(main_temporary_log_directory)

    if not status["completed"]:
        print(status["message"])
        return

    status_filename = "status"
    if "epochs trained" in status:
        status_filename += "-" + status["epochs trained"]
    status_path = os.path.join(
        model.logDirectory(run_id = run_id),
        status_filename + ".log"
    )
    with open(status_path, "w") as status_file:
        for status_field, status_value in status.items():
            if status_value:
                status_file.write(
                    status_field + ": " + str(status_value) + "\n"
                )

    print()

    # Evaluation, prediction, and analysis

    ## Setup

    if analyse:
        if prediction_method:
            predict_labels_using_model = False
        elif "GM" in model.type:
            predict_labels_using_model = True
            prediction_method = "model"
        else:
            predict_labels_using_model = False
    else:
        predict_labels_using_model = False

    evaluation_title_parts = ["evaluation"]

    if analyse:
        if prediction_method:
            evaluation_title_parts.append("prediction")
        evaluation_title_parts.append("analysis")

    evaluation_title = enumerateListOfStrings(evaluation_title_parts)

    print(title(evaluation_title.capitalize()))

    ### Set selection

    for data_subset in all_data_sets:

        clear_subset = True

        if data_subset.kind == evaluation_set_name:
            evaluation_set = data_subset
            clear_subset = False

        if prediction_method \
                and data_subset.kind == prediction_training_set_name:
            prediction_training_set = data_subset
            clear_subset = False

        if clear_subset:
            data_subset.clear()

    ### Evaluation set

    evaluation_subset_indices = analysis.evaluationSubsetIndices(
        evaluation_set)

    print("Evaluation set: {} set.".format(evaluation_set.kind))

    ### Prediction method

    if prediction_method:

        prediction_method = properString(
            prediction_method,
            PREDICTION_METHOD_NAMES
        )

        prediction_method_specifications = PREDICTION_METHOD_SPECIFICATIONS \
            .get(prediction_method, {})
        prediction_method_inference = prediction_method_specifications.get(
            "inference", None)
        prediction_method_fixed_number_of_clusters \
            = prediction_method_specifications.get(
                "fixed number of clusters", None)
        prediction_method_cluster_kind = prediction_method_specifications.get(
            "cluster kind", None)

        if prediction_method_fixed_number_of_clusters:
            number_of_clusters = number_of_classes
        else:
            number_of_clusters = None

        if prediction_method_inference \
                and prediction_method_inference == "transductive":
            prediction_training_set = None
            prediction_training_set_name = None
        else:
            prediction_training_set_name = prediction_training_set.kind

        prediction_details = {
            "method": prediction_method,
            "number_of_classes": number_of_clusters,
            "training_set_name": prediction_training_set_name,
            "decomposition_method": prediction_decomposition_method,
            "decomposition_dimensionality": prediction_decomposition_dimensionality
        }

        print("Prediction method: {}.".format(prediction_method))
        if number_of_clusters:
            print("Number of clusters: {}.".format(number_of_clusters))
        if prediction_training_set:
            print("Prediction training set: {} set.".format(
                prediction_training_set.kind))

        prediction_id_parts = []

        if prediction_decomposition_method:

            prediction_decomposition_method = properString(
                prediction_decomposition_method,
                DECOMPOSITION_METHOD_NAMES
            )

            if not prediction_decomposition_dimensionality:
                prediction_decomposition_dimensionality \
                    = DEFAULT_DECOMPOSITION_DIMENSIONALITY

            prediction_id_parts += [
                prediction_decomposition_method,
                prediction_decomposition_dimensionality
            ]

            prediction_details.update({
                "decomposition_method": prediction_decomposition_method,
                "decomposition_dimensionality": prediction_decomposition_dimensionality
            })

            print("Decomposition method before prediction: {}-d {}.".format(
                prediction_decomposition_dimensionality,
                prediction_decomposition_method
            ))

        prediction_id_parts.append(prediction_method)

        if number_of_clusters:
            prediction_id_parts.append(number_of_clusters)

        if prediction_training_set \
                and prediction_training_set.kind != "training":
            prediction_id_parts.append(prediction_training_set.kind)

        prediction_id = "_".join(map(
            lambda s: normaliseString(str(s)).replace("_", ""),
            prediction_id_parts
        ))
        prediction_details["id"] = prediction_id

    else:
        prediction_details = {}

    ### Model parameter sets

    model_parameter_set_names = []

    if "end_of_training" in model_versions:
        model_parameter_set_names.append("end of training")

    if "best_model" in model_versions \
            and betterModelExists(model, run_id = run_id):
        model_parameter_set_names.append("best model")

    if "early_stopping" in model_versions \
            and modelStoppedEarly(model, run_id = run_id):
        model_parameter_set_names.append("early stopping")

    print("Model parameter sets: {}.".format(enumerateListOfStrings(
        model_parameter_set_names)))
    print()

    ## Model analysis

    if analyse:
        print(subtitle("Model analysis"))
        analysis.analyseModel(
            model = model,
            run_id = run_id,
            analyses = analyses,
            analysis_level = analysis_level,
            export_options = export_options,
            results_directory = results_directory
        )

    ## Results evaluation, prediction, and analysis

    for model_parameter_set_name in model_parameter_set_names:

        if model_parameter_set_name == "best model":
            use_best_model = True
        else:
            use_best_model = False

        if model_parameter_set_name == "early stopping":
            use_early_stopping_model = True
        else:
            use_early_stopping_model = False

        model_parameter_set_name = model_parameter_set_name.capitalize()
        print(subtitle(model_parameter_set_name))

        # Evaluation

        model_parameter_set_name = model_parameter_set_name.replace(" ", "-")

        print(heading("{} evaluation".format(model_parameter_set_name)))

        if "VAE" in model.type:
            transformed_evaluation_set, reconstructed_evaluation_set, \
                latent_evaluation_sets = model.evaluate(
                    evaluation_set = evaluation_set,
                    evaluation_subset_indices = evaluation_subset_indices,
                    batch_size = batch_size,
                    predict_labels = predict_labels_using_model,
                    run_id = run_id,
                    use_best_model = use_best_model,
                    use_early_stopping_model = use_early_stopping_model
                )
        else:
            transformed_evaluation_set, reconstructed_evaluation_set = \
                model.evaluate(
                    evaluation_set = evaluation_set,
                    evaluation_subset_indices = evaluation_subset_indices,
                    batch_size = batch_size,
                    run_id = run_id,
                    use_best_model = use_best_model,
                    use_early_stopping_model = use_early_stopping_model
                )
            latent_evaluation_sets = None

        print()

        # Prediction

        if analyse and "VAE" in model.type and prediction_method \
                and not transformed_evaluation_set.has_predictions:

            print(heading("{} prediction".format(model_parameter_set_name)))

            latent_prediction_evaluation_set = latent_evaluation_sets["z"]

            if prediction_method_inference \
                    and prediction_method_inference == "inductive":

                latent_prediction_training_sets = model.evaluate(
                    evaluation_set = prediction_training_set,
                    batch_size = batch_size,
                    run_id = run_id,
                    use_best_model = use_best_model,
                    use_early_stopping_model = use_early_stopping_model,
                    output_versions = "latent",
                    log_results = False
                )
                latent_prediction_training_set \
                    = latent_prediction_training_sets["z"]

                print()

            else:
                latent_prediction_training_set = None

            if prediction_decomposition_method:

                if latent_prediction_training_set:
                    latent_prediction_training_set, \
                        latent_prediction_evaluation_set \
                        = data.decomposeDataSubsets(
                            latent_prediction_training_set,
                            latent_prediction_evaluation_set,
                            method = prediction_decomposition_method,
                            number_of_components = prediction_decomposition_dimensionality,
                            random = True
                        )
                else:
                    latent_prediction_evaluation_set \
                        = data.decomposeDataSubsets(
                            latent_prediction_evaluation_set,
                            method = prediction_decomposition_method,
                            number_of_components = prediction_decomposition_dimensionality,
                            random = True
                        )

                print()

            cluster_ids, predicted_labels, predicted_superset_labels \
                = predict(
                    latent_prediction_training_set,
                    latent_prediction_evaluation_set,
                    prediction_method,
                    number_of_clusters
                )

            transformed_evaluation_set.updatePredictions(
                predicted_cluster_ids = cluster_ids,
                predicted_labels = predicted_labels,
                predicted_superset_labels = predicted_superset_labels
            )
            reconstructed_evaluation_set.updatePredictions(
                predicted_cluster_ids = cluster_ids,
                predicted_labels = predicted_labels,
                predicted_superset_labels = predicted_superset_labels
            )
            for variable in latent_evaluation_sets:
                latent_evaluation_sets[variable].updatePredictions(
                    predicted_cluster_ids = cluster_ids,
                    predicted_labels = predicted_labels,
                    predicted_superset_labels = predicted_superset_labels
                )

            print()

        # Analysis

        if analyse:
            print(heading("{} results analysis".format(
                model_parameter_set_name)))
            analysis.analyseResults(
                evaluation_set = transformed_evaluation_set,
                reconstructed_evaluation_set = reconstructed_evaluation_set,
                latent_evaluation_sets = latent_evaluation_sets,
                model = model,
                run_id = run_id,
                decomposition_methods = decomposition_methods,
                evaluation_subset_indices = evaluation_subset_indices,
                highlight_feature_indices = highlight_feature_indices,
                prediction_details = prediction_details,
                best_model = use_best_model,
                early_stopping = use_early_stopping_model,
                analyses = analyses,
                analysis_level = analysis_level,
                export_options = export_options,
                results_directory = results_directory
            )

        # Clean up

        if transformed_evaluation_set.version == "original":
            transformed_evaluation_set.resetPredictions()
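# The `analytical_kl_term` flag used in main() above switches the VAE to a
# closed-form KL divergence when the latent posterior is a diagonal Gaussian
# and the prior is a standard normal:
#
#     KL(N(mu, sigma^2) || N(0, I)) = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2)
#
# A minimal sketch of that term follows; it is illustrative only and is not
# the project's own implementation.
import torch


def analytical_kl_divergence(mu, log_var):
    """Closed-form KL between N(mu, exp(log_var)) and N(0, I), summed over latent dimensions."""
    return -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)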
if __name__ == "__main__":
    from itertools import repeat

    import numpy as np
    import torch
    from torch import nn
    from torch.autograd import Variable
    from torch.utils.data import DataLoader
    from torch.utils.data.sampler import SubsetRandomSampler

    from models import VariationalAutoencoder

    cuda = torch.cuda.is_available()

    # SpriteDataset (a dSprites wrapper) is assumed to be provided elsewhere in
    # the project; each image is flattened to a 64 * 64 vector.
    dset = SpriteDataset(transform=lambda x: x.reshape(-1), download=True)

    # A sampler and `shuffle=True` cannot be combined in a DataLoader, so the
    # random subset sampler alone provides the shuffling.
    unlabelled = DataLoader(dset, batch_size=16,
                            sampler=SubsetRandomSampler(np.arange(len(dset) // 3)))

    models = []

    model = VariationalAutoencoder([64**2, 10, [1200, 1200]])
    model.decoder = nn.Sequential(
        nn.Linear(10, 1200),
        nn.Tanh(),
        nn.Linear(1200, 1200),
        nn.Tanh(),
        nn.Linear(1200, 1200),
        nn.Tanh(),
        nn.Linear(1200, 64**2),  # input size must match the 1200 units of the previous layer
        nn.Sigmoid(),
    )

    if cuda:
        model = model.cuda()

    beta = repeat(4.0)
    optimizer = torch.optim.Adagrad(model.parameters(), lr=1e-2)
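# No training loop follows the setup above. The sketch below shows how the
# objects it creates (the `unlabelled` loader, the `beta` iterator and the
# optimiser) might be consumed. The assumptions that the loader yields
# flattened image tensors, that the model's forward pass returns the
# reconstruction, and that the KL term is exposed as `model.kl_divergence`
# are hypothetical, not taken from the project's `VariationalAutoencoder` API.
import torch
import torch.nn.functional as F


def train_one_epoch(model, unlabelled, optimizer, beta):
    for x in unlabelled:
        x = x.float()
        if next(model.parameters()).is_cuda:
            x = x.cuda()
        optimizer.zero_grad()
        reconstruction = model(x)  # assumed forward signature
        likelihood = -F.binary_cross_entropy(reconstruction, x, reduction="sum")
        # `beta` is an infinite iterator (itertools.repeat), so one KL weight is drawn per step
        elbo = likelihood - next(beta) * model.kl_divergence  # assumed attribute
        (-elbo).backward()
        optimizer.step()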
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Lambda, ToTensor

# `tmp_lambda`, `config`, `device`, `VariationalAutoencoder` and `train_vae`
# are assumed to be defined elsewhere in the script.

# Transform: convert images to tensors, then apply the binarisation lambda
data_transform = Compose([
    ToTensor(),
    Lambda(tmp_lambda)
])

# Download binarized MNIST data
train_data = MNIST('./', train=True, download=True, transform=data_transform)

# Split into training and validation sets
train_set, val_set = torch.utils.data.random_split(train_data, [50000, 10000])

# Set up data loaders
kwargs = {'num_workers': 4, 'pin_memory': True} if torch.cuda.is_available() else {}
train_loader = DataLoader(train_set, batch_size=config['batch_size'], shuffle=True, **kwargs)
val_loader = DataLoader(val_set, batch_size=config['batch_size'], shuffle=True, **kwargs)

# Instantiate model
model = VariationalAutoencoder(config, x_dim=784).to(device)

# Train model
train_vae(model, config, train_loader, val_loader, 'generative-project')
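# Hypothetical definitions for the three assumed names above (`tmp_lambda`,
# `config`, `device`), consistent with how they are used; only 'batch_size'
# is actually read from `config` in the snippet, and any further keys would be
# model- and training-specific.
import torch

tmp_lambda = lambda x: torch.bernoulli(x)  # stochastic binarisation of MNIST pixels in [0, 1]
config = {'batch_size': 128}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')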
import random
from typing import List

import numpy as np
import torch

# SentimentExample, read_word_embeddings, load_model and VariationalAutoencoder
# are assumed to come from the surrounding project.


def train_vae(train_exs: List[SentimentExample], continue_training):
    # Build an embedding weight matrix from pre-trained GloVe vectors
    matrix_len = 5020
    emb_dim = 300
    weights_matrix = torch.zeros(matrix_len, emb_dim)
    word_embeddings = read_word_embeddings("data/glove.6B.300d-relativized.txt")
    for i in range(len(word_embeddings.word_indexer.objs_to_ints)):
        word = word_embeddings.word_indexer.get_object(i)
        weights_matrix[i, :] = torch.from_numpy(word_embeddings.get_embedding(word)).float()

    max_sequence_length = 50
    num_epochs = 4
    lr = 1e-3

    if continue_training:
        model = load_model()
    else:
        model = VariationalAutoencoder(weights_matrix, max_sequence_length)

    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    for epoch in range(num_epochs):
        print("Epoch", epoch + 1)
        model.train()
        ex_indices = list(range(len(train_exs)))
        random.shuffle(ex_indices)
        losses = []
        # Iterate over the training examples in shuffled order
        for idx in ex_indices:
            ex = train_exs[idx]
            words = ex.words[:max_sequence_length]
            input = torch.LongTensor()
            target = torch.LongTensor()

            # Input sequence: shifted right by one position and prefixed with PAD
            input_words = map(lambda x: x if x != " " else "PAD", ["PAD"] + words[:-1])
            indexes = []
            for word in input_words:
                idx_word = word_embeddings.word_indexer.index_of(word)
                if idx_word == -1:
                    idx_word = 1  # fall back to the UNK index
                indexes.append(idx_word)
            if max_sequence_length - len(words) > 0:
                for i in range(max_sequence_length - len(words)):
                    indexes.append(word_embeddings.word_indexer.index_of("PAD"))
            indexes = torch.LongTensor(indexes)[None]
            input = torch.cat((input, indexes), dim=0)

            # Target sequence: the original words, padded to the maximum length
            target_words = map(lambda x: x if x != " " else "PAD", words)
            indexes = []
            for word in target_words:
                idx_word = word_embeddings.word_indexer.index_of(word)
                if idx_word == -1:
                    idx_word = 1  # fall back to the UNK index
                indexes.append(idx_word)
            if max_sequence_length - len(words) > 0:
                for i in range(max_sequence_length - len(words)):
                    indexes.append(word_embeddings.word_indexer.index_of("PAD"))
            indexes = torch.LongTensor(indexes)
            target = torch.cat((target, indexes), dim=0)

            output = model(input)
            l = loss(output, target)
            losses.append(l.item())

            optimizer.zero_grad()
            l.backward()
            optimizer.step()
        print(np.mean(losses))

    model.eval()
    return model
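# The index-building logic inside train_vae() above is duplicated for the
# input and target sequences. A hypothetical helper with the same behaviour
# (spaces mapped to PAD, index 1 as the unknown-word fallback, padding up to
# the maximum sequence length); it is not part of the original code:
import torch


def words_to_indices(words, word_indexer, max_sequence_length):
    indices = []
    for word in words:
        word = word if word != " " else "PAD"
        idx = word_indexer.index_of(word)
        indices.append(idx if idx != -1 else 1)  # 1 is used as the UNK fallback above
    pad_index = word_indexer.index_of("PAD")
    indices.extend([pad_index] * max(0, max_sequence_length - len(words)))
    return torch.LongTensor(indices)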
import os
import sys

import numpy as np
import torch
from tensorboardX import SummaryWriter

# SAVE_DIR, DATA_DIR, setup_datasets, image_loader, VariationalAutoencoder,
# DFC_VAE_Loss, VAE_Loss, plot_interpolation and plot_reconstructions are
# assumed to be defined elsewhere in the project.


def train_vae(args):
    # Set up experiment directory and save paths
    EXPERIMENT_DIR = os.path.join(SAVE_DIR, args.name)
    if not os.path.exists(EXPERIMENT_DIR):
        os.mkdir(EXPERIMENT_DIR)
    PARAM_PATH = os.path.join(EXPERIMENT_DIR, "epoch{}.pth.tar")
    LOGS_PATH = os.path.join(EXPERIMENT_DIR, "logs.json")
    IMG_DIR = os.path.join(EXPERIMENT_DIR, "img")
    if not os.path.exists(IMG_DIR):
        os.mkdir(IMG_DIR)
    RECONSTRUCTIONS_PATH = os.path.join(IMG_DIR, "reconstructions_epoch{}.png")
    INTERPOLATION_PATH = os.path.join(IMG_DIR, "interpolation_epoch{}.png")

    # Data loaders
    n_epochs = args.num_epochs
    train_data_loader, test_data_loader = setup_datasets(
        DATA_DIR, image_loader, args.batch_size)

    # Model
    ae = VariationalAutoencoder(args.z_dim).to(args.device)
    if args.device == "cuda":
        # Support for multiple GPUs (currently disabled)
        # ae = torch.nn.DataParallel(ae, args.gpu_ids)
        pass

    # Loss function: deep-feature-consistent loss by default,
    # plain pixel-space VAE loss for runs whose name contains "pix"
    loss_fn = DFC_VAE_Loss([1., 1., 1., 0., 0.], args.device)
    if "pix" in args.name:
        loss_fn = VAE_Loss()

    optimizer = torch.optim.Adam(ae.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)

    # Logs
    writer = SummaryWriter()
    min_loss = float('inf')

    # Improvement counter
    not_improved = 0.

    for epoch in range(n_epochs):
        torch.cuda.empty_cache()
        ae.train()

        for step, (images, y) in enumerate(train_data_loader):
            global_step = step + epoch * len(train_data_loader)

            # Train
            optimizer.zero_grad()

            # Compute VAE outputs and pass them into the loss function
            out, mu, logvar = ae(images.to(args.device))
            loss_comp = loss_fn(out, images.to(args.device), mu, logvar)

            # Backpropagate: the pixel-space loss is a single tensor, while the
            # DFC loss returns a dictionary of components that is summed first
            if "pix" in args.name:
                loss_comp.backward()
                loss = loss_comp
            else:
                loss = sum(loss_comp.values())
                loss.backward()
            optimizer.step()

            # Log losses
            writer.add_scalar('train/loss', loss.item(), global_step)
            if "pix" not in args.name:
                writer.add_scalars('train',
                                   {k: v.item() for k, v in loss_comp.items()},
                                   global_step)

            if step % 10 == 0:
                to_print = "Epoch [{}/{}]; Step [{}/{}]; Train Loss: {:.7f}".format(
                    epoch + 1, n_epochs, step, len(train_data_loader), loss.item())
                sys.stdout.write(to_print + '\n')
                sys.stdout.flush()

        # Test
        print("Evaluating...")
        torch.cuda.empty_cache()
        with torch.no_grad():
            ae.eval()
            losses = []
            for step, (images, y) in enumerate(test_data_loader):
                out, mu, logvar = ae(images.to(args.device))
                loss_i = loss_fn(out, images.to(args.device), mu, logvar)
                if "pix" in args.name:
                    losses.append(loss_i.item())
                else:
                    loss_i = sum(loss_i.values())
                    losses.append(loss_i.item())
            test_loss = np.array(losses).mean()

        scheduler.step(test_loss)
        writer.add_scalar('test/loss', test_loss, epoch * len(train_data_loader))

        to_print = "Epoch [{}/{}]; Test Loss: {:.7f}".format(
            epoch + 1, n_epochs, test_loss)
        sys.stdout.write(to_print + '\n')
        sys.stdout.flush()

        writer.export_scalars_to_json(LOGS_PATH)

        # End of epoch: save the model and plot reconstructions if the test loss decreased
        params = ae.parameters()
        if epoch > 0 and test_loss < min_loss:
            try:
                torch.save(ae.state_dict(), PARAM_PATH.format(epoch + 1))
                min_loss = test_loss
                not_improved = 0.
                plot_interpolation(test_data_loader, ae, args.device,
                                   INTERPOLATION_PATH, epoch)
                plot_reconstructions(test_data_loader, ae, args.device,
                                     RECONSTRUCTIONS_PATH, epoch)
                print("Saved model, plotted reconstruction and interpolation "
                      "and reset improvement counter")
            except Exception:
                print('Error occurred while saving after epoch {}'.format(epoch + 1))
        else:
            not_improved += 1
            print("Training has not improved on test set for {} epochs".format(
                not_improved))

    writer.close()
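# Hypothetical command-line entry point for train_vae() above; the argument
# names mirror the attributes the function reads from `args`, while the
# defaults are illustrative only.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", default="dfc-vae")
    parser.add_argument("--num_epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--z_dim", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()
    train_vae(args)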