# Example #1
# 0
def main(input_file_or_name, data_directory = "data",
    log_directory = "log", results_directory = "results",
    temporary_log_directory = None,
    map_features = False, feature_selection = [], example_filter = [],
    preprocessing_methods = [], noisy_preprocessing_methods = [],
    split_data_set = True,
    splitting_method = "default", splitting_fraction = 0.9,
    model_type = "VAE", latent_size = 50, hidden_sizes = [500],
    number_of_importance_samples = [5],
    number_of_monte_carlo_samples = [10],
    inference_architecture = "MLP",
    latent_distribution = "gaussian",
    number_of_classes = None,
    parameterise_latent_posterior = False,
    generative_architecture = "MLP",
    reconstruction_distribution = "poisson",
    number_of_reconstruction_classes = 0,
    prior_probabilities_method = "uniform",
    number_of_warm_up_epochs = 0,
    kl_weight = 1,
    proportion_of_free_KL_nats = 0.0,
    batch_normalisation = True,
    dropout_keep_probabilities = [],
    count_sum = True,
    number_of_epochs = 200, plotting_interval_during_training = None, 
    batch_size = 100, learning_rate = 1e-4,
    run_id = None, new_run = False,
    prediction_method = None, prediction_training_set_name = "training",
    prediction_decomposition_method = None,
    prediction_decomposition_dimensionality = None,
    decomposition_methods = ["PCA"], highlight_feature_indices = [],
    reset_training = False, skip_modelling = False,
    model_versions = ["all"],
    analyse = True, evaluation_set_name = "test", analyse_data = False,
    analyses = ["default"], analysis_level = "normal", fast_analysis = False,
    export_options = []):
    """Load a data set, then train, evaluate and analyse a (GM)VAE on it.

    The pipeline is:

    1. Parse model versions and distributions and, unless
       ``skip_modelling`` is set, validate the model configuration.
    2. Load (and optionally split) the data set, set up the log and results
       directories, and analyse the data if both ``analyse`` and
       ``analyse_data`` are set.
    3. Build a ``VariationalAutoencoder`` (``model_type == "VAE"``) or a
       ``GaussianMixtureVariationalAutoencoder`` (``model_type == "GMVAE"``)
       and train it on the training set, writing a status log afterwards.
    4. For each selected model parameter set ("end of training",
       "best model", "early stopping"), evaluate the model on the
       evaluation set, optionally predict labels from the latent
       representation, and analyse the results.

    Returns ``None``. It returns early when the model configuration is
    invalid (and ``analyse_data`` is not set), when modelling is skipped, or
    when training does not complete.

    Raises:
        ValueError: if ``model_type`` is unknown; if a mixture latent
            distribution is used on an unlabelled data set without
            ``number_of_classes``; if no data subset matches
            ``evaluation_set_name``; or if a non-transductive prediction
            method needs a prediction training set that was not found.

    Note:
        For a Bernoulli reconstruction distribution, ``"binarise"`` is added
        to a *copy* of ``noisy_preprocessing_methods``; the caller's list is
        not modified.
    """
    
    # Setup
    
    model_versions = parseModelVersions(model_versions)
    
    ## Analyses
    
    if fast_analysis:
        analyse = True
        analyses = ["simple"]
        analysis_level = "limited"
    
    ## Distributions
    
    reconstruction_distribution = parseDistribution(
        reconstruction_distribution)
    latent_distribution = parseDistribution(latent_distribution)
    
    ## Model configuration validation
    
    if not skip_modelling:
        
        if run_id:
            run_id = checkRunID(run_id)
        
        model_valid, model_errors = validateModelParameters(
            model_type, latent_distribution,
            reconstruction_distribution, number_of_reconstruction_classes,
            parameterise_latent_posterior
        )
        
        if not model_valid:
            print("Model configuration is invalid:")
            for model_error in model_errors:
                print("    ", model_error)
            print()
            if analyse_data:
                # The data can still be analysed without a valid model.
                print("Skipping modelling.")
                print("")
                skip_modelling = True
            else:
                print("Modelling cancelled.")
                return
    
    ## Binarisation
    
    binarise_values = False
    
    if reconstruction_distribution == "bernoulli":
        if noisy_preprocessing_methods:
            if noisy_preprocessing_methods[-1] != "binarise":
                # Build a new list instead of appending in place, so the
                # caller's list (or a shared default) is left unchanged.
                noisy_preprocessing_methods = \
                    list(noisy_preprocessing_methods) + ["binarise"]
                print("Appended binarisation method to noisy preprocessing,",
                    "because of the Bernoulli distribution.\n")
        else:
            # No noisy preprocessing: ask the data set to binarise instead.
            binarise_values = True
    
    ## Data sets
    
    # The full (unsplit) data set has to stay loaded when it is used
    # directly: unsplit modelling, data analysis, or "full" evaluation or
    # prediction training sets.
    full_data_set_needed = bool(
        not split_data_set
        or analyse_data
        or evaluation_set_name == "full"
        or prediction_training_set_name == "full"
    )
    
    # Data
    
    print(title("Data"))
    
    data_set = data.DataSet(
        input_file_or_name,
        directory = data_directory,
        map_features = map_features,
        feature_selection = feature_selection,
        example_filter = example_filter,
        preprocessing_methods = preprocessing_methods,
        binarise_values = binarise_values,
        noisy_preprocessing_methods = noisy_preprocessing_methods
    )
    
    if full_data_set_needed:
        data_set.load()
    
    if split_data_set:
        training_set, validation_set, test_set = data_set.split(
            splitting_method, splitting_fraction)
        all_data_sets = [data_set, training_set, validation_set, test_set]
    else:
        # Without a split, the full data set serves as training,
        # evaluation and prediction training set.
        splitting_method = None
        training_set = data_set
        validation_set = None
        test_set = data_set
        all_data_sets = [data_set]
        evaluation_set_name = "full"
        prediction_training_set_name = "full"
    
    ## Setup of log and results directories
    
    log_directory = data.directory(log_directory, data_set,
        splitting_method, splitting_fraction)
    data_results_directory = data.directory(results_directory, data_set,
        splitting_method, splitting_fraction, preprocessing = False)
    results_directory = data.directory(results_directory, data_set,
        splitting_method, splitting_fraction)
    
    if temporary_log_directory:
        # Keep the top-level path so emptied sub-directories can be
        # removed after training.
        main_temporary_log_directory = temporary_log_directory
        temporary_log_directory = data.directory(temporary_log_directory,
            data_set, splitting_method, splitting_fraction)
    
    ## Data analysis
    
    if analyse and analyse_data:
        print(subtitle("Analysing data"))
        analysis.analyseData(
            data_sets = all_data_sets,
            decomposition_methods = decomposition_methods,
            highlight_feature_indices = highlight_feature_indices,
            analyses = analyses,
            analysis_level = analysis_level,
            export_options = export_options,
            results_directory = data_results_directory
        )
        print()
    
    ## Full data set clean up
    
    if not full_data_set_needed:
        data_set.clear()
    
    # Modelling
    
    if skip_modelling:
        print("Modelling skipped.")
        return
    
    print(title("Modelling"))
    
    # Set the number of features for the model
    feature_size = training_set.number_of_features
    
    # Parse numbers of samples
    number_of_monte_carlo_samples = parseSampleLists(
        number_of_monte_carlo_samples)
    number_of_importance_samples = parseSampleLists(
        number_of_importance_samples)
    
    # Use analytical KL term for single-Gaussian-VAE
    # (for model types without "VAE" the unknown-type error below is raised
    # before this value would be needed)
    if "VAE" in model_type:
        if latent_distribution == "gaussian":
            analytical_kl_term = True
        else:
            analytical_kl_term = False
    
    # Change latent distribution to Gaussian mixture if not already set
    if model_type == "GMVAE" and latent_distribution != "gaussian mixture":
        latent_distribution = "gaussian mixture"
        print("The latent distribution was changed to",
            "a Gaussian-mixture model, because of the model chosen.\n")
    
    # Set the number of classes if not already set
    if not number_of_classes:
        if training_set.has_labels:
            number_of_classes = training_set.number_of_classes \
                - training_set.number_of_excluded_classes
        elif "mixture" in latent_distribution:
            raise ValueError(
                "For a mixture model and a data set without labels, "
                "the number of classes has to be set."
            )
        else:
            number_of_classes = 1
    
    print(subtitle("Model setup"))
    
    if model_type == "VAE":
        model = VariationalAutoencoder(
            feature_size = feature_size,
            latent_size = latent_size,
            hidden_sizes = hidden_sizes,
            number_of_monte_carlo_samples = number_of_monte_carlo_samples,
            number_of_importance_samples = number_of_importance_samples,
            analytical_kl_term = analytical_kl_term,
            inference_architecture = inference_architecture,
            latent_distribution = latent_distribution,
            number_of_latent_clusters = number_of_classes,
            parameterise_latent_posterior = parameterise_latent_posterior,
            generative_architecture = generative_architecture,
            reconstruction_distribution = reconstruction_distribution,
            number_of_reconstruction_classes = number_of_reconstruction_classes,
            batch_normalisation = batch_normalisation,
            dropout_keep_probabilities = dropout_keep_probabilities,
            count_sum = count_sum,
            number_of_warm_up_epochs = number_of_warm_up_epochs,
            kl_weight = kl_weight,
            log_directory = log_directory,
            results_directory = results_directory
        )

    elif model_type == "GMVAE":
        
        if prior_probabilities_method == "uniform":
            prior_probabilities = None
        elif prior_probabilities_method == "infer":
            prior_probabilities = training_set.class_probabilities
        elif prior_probabilities_method == "literature":
            prior_probabilities = training_set.literature_probabilities
        else:
            prior_probabilities = None
        
        # Fall back to uniform priors when no probabilities are available.
        if not prior_probabilities:
            prior_probabilities_method = "uniform"
            prior_probabilities_values = None
        else:
            prior_probabilities_values = list(prior_probabilities.values())
        
        prior_probabilities = {
            "method": prior_probabilities_method,
            "values": prior_probabilities_values
        }
        
        model = GaussianMixtureVariationalAutoencoder(
            feature_size = feature_size,
            latent_size = latent_size,
            hidden_sizes = hidden_sizes,
            number_of_monte_carlo_samples = number_of_monte_carlo_samples,
            number_of_importance_samples = number_of_importance_samples, 
            analytical_kl_term = analytical_kl_term,
            prior_probabilities = prior_probabilities,
            number_of_latent_clusters = number_of_classes,
            proportion_of_free_KL_nats = proportion_of_free_KL_nats,
            reconstruction_distribution = reconstruction_distribution,
            number_of_reconstruction_classes = number_of_reconstruction_classes,
            batch_normalisation = batch_normalisation,
            dropout_keep_probabilities = dropout_keep_probabilities,
            count_sum = count_sum,
            number_of_warm_up_epochs = number_of_warm_up_epochs,
            kl_weight = kl_weight,
            log_directory = log_directory,
            results_directory = results_directory
        )
    
    else:
        raise ValueError("Model type not found: `{}`.".format(model_type))
    
    print(model.description)
    print()
    
    print(model.parameters)
    print()
    
    ## Training
    
    print(subtitle("Model training"))
    
    status, run_id = model.train(
        training_set,
        validation_set,
        number_of_epochs = number_of_epochs,
        batch_size = batch_size,
        learning_rate = learning_rate,
        plotting_interval = plotting_interval_during_training,
        run_id = run_id,
        new_run = new_run,
        reset_training = reset_training,
        temporary_log_directory = temporary_log_directory
    )
    
    # Remove temporary directories created and emptied during training
    if temporary_log_directory and os.path.exists(main_temporary_log_directory):
        removeEmptyDirectories(main_temporary_log_directory)
    
    if not status["completed"]:
        print(status["message"])
        return
    
    # Write the non-empty status fields to a status log in the run's log
    # directory.
    status_filename = "status"
    if "epochs trained" in status:
        # str() so that non-string values (e.g. an int) can be appended.
        status_filename += "-" + str(status["epochs trained"])
    status_path = os.path.join(
        model.logDirectory(run_id = run_id),
        status_filename + ".log"
    )
    with open(status_path, "w") as status_file:
        for status_field, status_value in status.items():
            if status_value:
                status_file.write(
                    status_field + ": " + str(status_value) + "\n"
                )
    
    print()
    
    # Evaluation, prediction, and analysis
    
    ## Setup
    
    # Gaussian-mixture models can predict labels themselves when no other
    # prediction method is requested.
    if analyse:
        if prediction_method:
            predict_labels_using_model = False
        elif "GM" in model.type:
            predict_labels_using_model = True
            prediction_method = "model"
        else:
            predict_labels_using_model = False
    else:
        predict_labels_using_model = False
    
    evaluation_title_parts = ["evaluation"]
    
    if analyse:
        if prediction_method:
            evaluation_title_parts.append("prediction")
        evaluation_title_parts.append("analysis")
    
    evaluation_title = enumerateListOfStrings(evaluation_title_parts)
    
    print(title(evaluation_title.capitalize()))
    
    ### Set selection
    
    # Keep the evaluation set and, if needed, the prediction training set;
    # call clear() on every other data subset.
    evaluation_set = None
    prediction_training_set = None
    
    for data_subset in all_data_sets:
        
        clear_subset = True
        
        if data_subset.kind == evaluation_set_name:
            evaluation_set = data_subset
            clear_subset = False
            
        if prediction_method \
            and data_subset.kind == prediction_training_set_name:
                prediction_training_set = data_subset
                clear_subset = False
        
        if clear_subset:
            data_subset.clear()
    
    if evaluation_set is None:
        raise ValueError(
            "No data set of kind `{}` to evaluate on.".format(
                evaluation_set_name)
        )
    
    ### Evaluation set
    
    evaluation_subset_indices = analysis.evaluationSubsetIndices(
        evaluation_set)
    
    print("Evaluation set: {} set.".format(evaluation_set.kind))
    
    ### Prediction method
    
    if prediction_method:
        
        prediction_method = properString(
            prediction_method,
            PREDICTION_METHOD_NAMES
        )
        
        prediction_method_specifications = PREDICTION_METHOD_SPECIFICATIONS\
            .get(prediction_method, {})
        prediction_method_inference = prediction_method_specifications.get(
            "inference", None)
        prediction_method_fixed_number_of_clusters \
            = prediction_method_specifications.get(
                "fixed number of clusters", None)
        prediction_method_cluster_kind = prediction_method_specifications.get(
            "cluster kind", None)
        
        if prediction_method_fixed_number_of_clusters:
            number_of_clusters = number_of_classes
        else:
            number_of_clusters = None
        
        if prediction_method_inference \
            and prediction_method_inference == "transductive":
            
            # Transductive methods work on the evaluation set alone.
            prediction_training_set = None
            prediction_training_set_name = None
        
        elif prediction_training_set is None:
            raise ValueError(
                "No data set of kind `{}` to train the prediction method "
                "on.".format(prediction_training_set_name)
            )
        
        else:
            prediction_training_set_name = prediction_training_set.kind
        
        prediction_details = {
            "method": prediction_method,
            "number_of_classes": number_of_clusters,
            "training_set_name": prediction_training_set_name,
            "decomposition_method": prediction_decomposition_method,
            "decomposition_dimensionality":
                prediction_decomposition_dimensionality
        }
        
        print("Prediction method: {}.".format(prediction_method))
        
        if number_of_clusters:
            print("Number of clusters: {}.".format(number_of_clusters))
        
        if prediction_training_set:
            print("Prediction training set: {} set.".format(
                prediction_training_set.kind))
        
        prediction_id_parts = []
        
        if prediction_decomposition_method:
            
            prediction_decomposition_method = properString(
                prediction_decomposition_method,
                DECOMPOSITION_METHOD_NAMES
            )
            
            if not prediction_decomposition_dimensionality:
                prediction_decomposition_dimensionality \
                    = DEFAULT_DECOMPOSITION_DIMENSIONALITY
            
            prediction_id_parts += [
                prediction_decomposition_method,
                prediction_decomposition_dimensionality
            ]
            
            prediction_details.update({
                "decomposition_method": prediction_decomposition_method,
                "decomposition_dimensionality":
                    prediction_decomposition_dimensionality
            })
            
            print("Decomposition method before prediction: {}-d {}.".format(
                prediction_decomposition_dimensionality,
                prediction_decomposition_method
            ))
        
        prediction_id_parts.append(prediction_method)
        
        if number_of_clusters:
            prediction_id_parts.append(number_of_clusters)
        
        if prediction_training_set \
            and prediction_training_set.kind != "training":
                prediction_id_parts.append(prediction_training_set.kind)
        
        # Identifier such as "pca_2_kmeans_10": each part is normalised and
        # stripped of underscores so "_" only separates parts.
        prediction_id = "_".join(map(
            lambda s: normaliseString(str(s)).replace("_", ""),
            prediction_id_parts
        ))
        prediction_details["id"] = prediction_id
    
    else:
        prediction_details = {}
    
    ### Model parameter sets
    
    model_parameter_set_names = []
    
    if "end_of_training" in model_versions:
        model_parameter_set_names.append("end of training")
    
    if "best_model" in model_versions \
        and betterModelExists(model, run_id = run_id):
            model_parameter_set_names.append("best model")
    
    if "early_stopping" in model_versions \
        and modelStoppedEarly(model, run_id = run_id):
            model_parameter_set_names.append("early stopping")
    
    print("Model parameter sets: {}.".format(enumerateListOfStrings(
        model_parameter_set_names)))
    
    print()
    
    ## Model analysis
    
    if analyse:
        
        print(subtitle("Model analysis"))
        analysis.analyseModel(
            model = model,
            run_id = run_id,
            analyses = analyses,
            analysis_level = analysis_level,
            export_options = export_options,
            results_directory = results_directory
        )
    
    ## Results evaluation, prediction, and analysis
    
    for model_parameter_set_name in model_parameter_set_names:
        
        if model_parameter_set_name == "best model":
            use_best_model = True
        else:
            use_best_model = False
        
        if model_parameter_set_name == "early stopping":
            use_early_stopping_model = True
        else:
            use_early_stopping_model = False
        
        model_parameter_set_name = model_parameter_set_name.capitalize()
        print(subtitle(model_parameter_set_name))
        
        # Evaluation
        
        model_parameter_set_name = model_parameter_set_name.replace(" ", "-")
        
        print(heading("{} evaluation".format(model_parameter_set_name)))
        
        if "VAE" in model.type:
            transformed_evaluation_set, reconstructed_evaluation_set,\
                latent_evaluation_sets = model.evaluate(
                    evaluation_set = evaluation_set,
                    evaluation_subset_indices = evaluation_subset_indices,
                    batch_size = batch_size,
                    predict_labels = predict_labels_using_model,
                    run_id = run_id,
                    use_best_model = use_best_model,
                    use_early_stopping_model = use_early_stopping_model
                )
        else:
            transformed_evaluation_set, reconstructed_evaluation_set = \
                model.evaluate(
                    evaluation_set = evaluation_set,
                    evaluation_subset_indices = evaluation_subset_indices,
                    batch_size = batch_size,
                    run_id = run_id,
                    use_best_model = use_best_model,
                    use_early_stopping_model = use_early_stopping_model
                )
            latent_evaluation_sets = None
        
        print()
        
        # Prediction (skipped when the model already predicted labels)
        
        if analyse and "VAE" in model.type and prediction_method \
            and not transformed_evaluation_set.has_predictions:
            
            print(heading("{} prediction".format(model_parameter_set_name)))
            
            latent_prediction_evaluation_set = latent_evaluation_sets["z"]
            
            if prediction_method_inference \
                and prediction_method_inference == "inductive":
                
                # Inductive methods also need the latent representation of
                # the prediction training set.
                latent_prediction_training_sets = model.evaluate(
                    evaluation_set = prediction_training_set,
                    batch_size = batch_size,
                    run_id = run_id,
                    use_best_model = use_best_model,
                    use_early_stopping_model = use_early_stopping_model,
                    output_versions = "latent",
                    log_results = False
                )
                latent_prediction_training_set \
                    = latent_prediction_training_sets["z"]
                
                print()
            
            else:
                latent_prediction_training_set = None
            
            if prediction_decomposition_method:
                
                if latent_prediction_training_set:
                    latent_prediction_training_set, \
                        latent_prediction_evaluation_set \
                        = data.decomposeDataSubsets(
                            latent_prediction_training_set,
                            latent_prediction_evaluation_set,
                            method = prediction_decomposition_method,
                            number_of_components = 
                                prediction_decomposition_dimensionality,
                            random = True
                        )
                else:
                    latent_prediction_evaluation_set \
                        = data.decomposeDataSubsets(
                            latent_prediction_evaluation_set,
                            method = prediction_decomposition_method,
                            number_of_components = 
                                prediction_decomposition_dimensionality,
                            random = True
                        )
                
                print()
            
            cluster_ids, predicted_labels, predicted_superset_labels \
                = predict(
                    latent_prediction_training_set,
                    latent_prediction_evaluation_set,
                    prediction_method,
                    number_of_clusters
                )
            
            # Propagate the predictions to every version of the evaluation
            # set so that all of them can be analysed consistently.
            transformed_evaluation_set.updatePredictions(
                predicted_cluster_ids = cluster_ids,
                predicted_labels = predicted_labels,
                predicted_superset_labels = predicted_superset_labels
            )
            reconstructed_evaluation_set.updatePredictions(
                predicted_cluster_ids = cluster_ids,
                predicted_labels = predicted_labels,
                predicted_superset_labels = predicted_superset_labels
            )
            
            for variable in latent_evaluation_sets:
                latent_evaluation_sets[variable].updatePredictions(
                    predicted_cluster_ids = cluster_ids,
                    predicted_labels = predicted_labels,
                    predicted_superset_labels = predicted_superset_labels
                )
            
            print()
        
        # Analysis
        
        if analyse:
            
            print(heading("{} results analysis".format(model_parameter_set_name)))
            
            analysis.analyseResults(
                evaluation_set = transformed_evaluation_set,
                reconstructed_evaluation_set = reconstructed_evaluation_set,
                latent_evaluation_sets = latent_evaluation_sets,
                model = model,
                run_id = run_id,
                decomposition_methods = decomposition_methods,
                evaluation_subset_indices = evaluation_subset_indices,
                highlight_feature_indices = highlight_feature_indices,
                prediction_details = prediction_details,
                best_model = use_best_model,
                early_stopping = use_early_stopping_model,
                analyses = analyses, analysis_level = analysis_level,
                export_options = export_options,
                results_directory = results_directory
            )
        
        # Clean up
        
        # Reset predictions on the original evaluation set so the next
        # model parameter set starts without them.
        if transformed_evaluation_set.version == "original":
            transformed_evaluation_set.resetPredictions()
# Example #2
# 0
if __name__ == "__main__":
    from itertools import repeat
    # NOTE(review): `Variable` is deprecated in modern PyTorch (tensors
    # support autograd directly); kept for any code relying on the name.
    from torch.autograd import Variable

    # Images are flattened to 1-D vectors; the model below expects inputs
    # of size 64**2.
    dset = SpriteDataset(transform=lambda x: x.reshape(-1), download=True)
    # Sample randomly from the first third of the data set. The sampler
    # already randomises the order; DataLoader rejects `shuffle=True`
    # together with a sampler (they are mutually exclusive).
    unlabelled = DataLoader(dset,
                            batch_size=16,
                            sampler=SubsetRandomSampler(
                                np.arange(len(dset) // 3)))

    models = []

    from models import VariationalAutoencoder
    model = VariationalAutoencoder([64**2, 10, [1200, 1200]])
    # Decoder: 10-d latent -> three 1200-unit tanh layers -> 64**2 pixel
    # probabilities. The output layer must take the 1200 features produced
    # by the preceding hidden layer, not the 10-d latent size.
    model.decoder = nn.Sequential(
        nn.Linear(10, 1200),
        nn.Tanh(),
        nn.Linear(1200, 1200),
        nn.Tanh(),
        nn.Linear(1200, 1200),
        nn.Tanh(),
        nn.Linear(1200, 64**2),
        nn.Sigmoid(),
    )

    if cuda:
        model = model.cuda()

    # Constant KL weight of 4.0 for every step.
    beta = repeat(4.0)
    optimizer = torch.optim.Adagrad(model.parameters(), lr=1e-2)
# Example #3
# 0
    ToTensor(),
    Lambda(tmp_lambda)
])

# Download binarized MNIST data
train_data = MNIST('./', train=True, download=True, transform=data_transform)

# Split into training and validation sets. The sizes are derived from the
# data set length so the split stays valid if that length changes
# (50000 / 10000 for the standard MNIST training set).
num_validation_examples = 10000
train_set, val_set = torch.utils.data.random_split(
    train_data,
    [len(train_data) - num_validation_examples, num_validation_examples]
)

# Setup data loaders (extra workers and pinned memory only with CUDA)
kwargs = {'num_workers': 4, 'pin_memory': True} if torch.cuda.is_available() else {}
train_loader = DataLoader(
    train_set,
    batch_size=config['batch_size'],
    shuffle=True,
    **kwargs
)
# Validation batches do not need shuffling; a fixed order keeps the
# per-epoch validation batches comparable.
val_loader = DataLoader(
    val_set,
    batch_size=config['batch_size'],
    shuffle=False,
    **kwargs
)

# Instantiate model (28 x 28 = 784 input pixels)
model = VariationalAutoencoder(config, x_dim=784).to(device)

# Train model
train_vae(model, config, train_loader, val_loader, 'generative-project')
# Example #4
# 0
def train_vae(train_exs: List[SentimentExample], continue_training):
    """Train a sequence VAE on the words of ``train_exs`` and return it.

    Each example's words (truncated to ``max_sequence_length``) are mapped to
    word-embedding indices; blank tokens (``" "``) become ``"PAD"`` and words
    missing from the indexer map to index 1. The model input is the word
    sequence shifted right by one position with a leading ``"PAD"``, and
    the target is the unshifted sequence; both are padded with the
    ``"PAD"`` index to exactly ``max_sequence_length``. Examples are visited
    in a new random order every epoch.

    Args:
        train_exs: training examples; only their ``words`` attribute is used.
        continue_training: if true, resume from ``load_model()`` instead of
            building a new ``VariationalAutoencoder``.

    Returns:
        The trained model, left in evaluation mode.
    """
    matrix_len = 5020
    emb_dim = 300
    weights_matrix = torch.zeros(matrix_len, emb_dim)
    word_embeddings = read_word_embeddings("data/glove.6B.300d-relativized.txt")
    word_indexer = word_embeddings.word_indexer

    # Copy the pretrained embedding of every indexed word into the matrix;
    # remaining rows stay zero. Assumes the vocabulary has at most
    # `matrix_len` words (otherwise the row assignment fails).
    for i in range(len(word_indexer.objs_to_ints)):
        word = word_indexer.get_object(i)
        weights_matrix[i, :] = torch.from_numpy(
            word_embeddings.get_embedding(word)).float()

    max_sequence_length = 50
    num_epochs = 4
    lr = 1e-3
    pad_index = word_indexer.index_of("PAD")

    def to_padded_indices(tokens):
        """Map tokens to indices and pad them to `max_sequence_length`."""
        indexes = []
        for token in tokens:
            idx = word_indexer.index_of(token if token != " " else "PAD")
            # Unknown words (index -1) map to index 1.
            indexes.append(idx if idx != -1 else 1)
        # Pad up to exactly max_sequence_length so input and target always
        # have the same length (the shifted input of an empty example would
        # otherwise be one token longer than its target).
        indexes.extend([pad_index] * (max_sequence_length - len(indexes)))
        return indexes

    if continue_training:
        model = load_model()
    else:
        model = VariationalAutoencoder(weights_matrix, max_sequence_length)

    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    for epoch in range(num_epochs):
        print("Epoch", epoch+1)
        model.train()

        # Visit the examples in a new random order every epoch.
        ex_indices = list(range(len(train_exs)))
        random.shuffle(ex_indices)

        losses = []

        for ex_idx in ex_indices:
            words = train_exs[ex_idx].words[:max_sequence_length]

            # Input: the target shifted right by one, starting with "PAD";
            # shape (1, max_sequence_length). Target: shape
            # (max_sequence_length,).
            input_ids = torch.LongTensor(
                to_padded_indices(["PAD"] + words[:-1]))[None]
            target = torch.LongTensor(to_padded_indices(words))

            output = model(input_ids)

            l = loss(output, target)
            losses.append(l.item())
            optimizer.zero_grad()
            l.backward()
            optimizer.step()

        print(np.mean(losses))
        model.eval()

    return model
# Example #5
# 0
def train_vae(args):
    """Train a ``VariationalAutoencoder`` on images, evaluating every epoch.

    Reads ``args.name``, ``args.num_epochs``, ``args.batch_size``,
    ``args.z_dim``, ``args.device`` and ``args.lr``. Outputs go to
    ``SAVE_DIR/<args.name>``:

    * checkpoints ``epoch<N>.pth.tar``,
    * TensorBoard scalars, exported to ``logs.json``,
    * reconstruction and interpolation plots under ``img/``.

    If ``"pix"`` occurs in ``args.name``, ``VAE_Loss`` is used and returns a
    single loss tensor. Otherwise ``DFC_VAE_Loss`` is used and returns a
    dict of loss components that are summed into the total loss.

    After each epoch the mean test loss drives a ``ReduceLROnPlateau``
    scheduler. From the second epoch on, a new minimum test loss triggers a
    checkpoint and plots; otherwise a "not improved" counter is incremented.
    """

    # Set up the experiment directories and output paths
    # (makedirs also creates SAVE_DIR if it does not exist yet).
    EXPERIMENT_DIR = os.path.join(SAVE_DIR, args.name)
    os.makedirs(EXPERIMENT_DIR, exist_ok=True)
    PARAM_PATH = os.path.join(EXPERIMENT_DIR, "epoch{}.pth.tar")
    LOGS_PATH = os.path.join(EXPERIMENT_DIR, "logs.json")
    IMG_DIR = os.path.join(EXPERIMENT_DIR, "img")
    os.makedirs(IMG_DIR, exist_ok=True)
    RECONSTRUCTIONS_PATH = os.path.join(IMG_DIR, "reconstructions_epoch{}.png")
    INTERPOLATION_PATH = os.path.join(IMG_DIR, "interpolation_epoch{}.png")

    # Data loaders
    n_epochs = args.num_epochs
    train_data_loader, test_data_loader = setup_datasets(
        DATA_DIR, image_loader, args.batch_size)

    # Model
    # NOTE: multi-GPU support (torch.nn.DataParallel) is currently disabled.
    ae = VariationalAutoencoder(args.z_dim).to(args.device)

    # Loss function: only build the loss that is actually used.
    use_pixel_loss = "pix" in args.name
    if use_pixel_loss:
        loss_fn = VAE_Loss()
    else:
        loss_fn = DFC_VAE_Loss([1., 1., 1., 0., 0.], args.device)
    optimizer = torch.optim.Adam(ae.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3)

    def total_loss(loss_output):
        """Return the scalar loss: the tensor itself, or the summed dict."""
        return loss_output if use_pixel_loss else sum(loss_output.values())

    # Logs
    writer = SummaryWriter()
    min_loss = float('inf')

    # Number of epochs since the last improvement of the test loss
    not_improved = 0

    for epoch in range(n_epochs):
        torch.cuda.empty_cache()
        ae.train()
        for step, (images, _) in enumerate(train_data_loader):

            global_step = step + epoch * len(train_data_loader)
            images = images.to(args.device)

            # Train step
            optimizer.zero_grad()
            out, mu, logvar = ae(images)
            loss_comp = loss_fn(out, images, mu, logvar)
            loss = total_loss(loss_comp)
            loss.backward()
            optimizer.step()

            # Log the total loss and, for the DFC loss, each component
            writer.add_scalar('train/loss', loss.item(), global_step)
            if not use_pixel_loss:
                writer.add_scalars('train',
                                   {k: v.item()
                                    for k, v in loss_comp.items()},
                                   global_step)

            if step % 10 == 0:
                to_print = "Epoch [{}/{}]; Step [{}/{}]; Train Loss: {:.7f}".format(
                    epoch + 1, n_epochs, step, len(train_data_loader),
                    loss.item())
                sys.stdout.write(to_print + '\n')
                sys.stdout.flush()

        # Evaluate on the test set
        print("Evaluating...")

        torch.cuda.empty_cache()
        with torch.no_grad():
            ae.eval()
            losses = []
            for images, _ in test_data_loader:
                images = images.to(args.device)
                out, mu, logvar = ae(images)
                losses.append(
                    total_loss(loss_fn(out, images, mu, logvar)).item())

        test_loss = np.array(losses).mean()
        scheduler.step(test_loss)
        writer.add_scalar('test/loss', test_loss,
                          epoch * len(train_data_loader))
        to_print = "Epoch [{}/{}]; Test Loss: {:.7f}".format(
            epoch + 1, n_epochs, test_loss)
        sys.stdout.write(to_print + '\n')
        sys.stdout.flush()
        writer.export_scalars_to_json(LOGS_PATH)

        # End of epoch: save the model and plot if the test loss decreased.
        # The first epoch is never checkpointed.
        if epoch > 0 and test_loss < min_loss:
            try:
                torch.save(ae.state_dict(), PARAM_PATH.format(epoch + 1))
                min_loss = test_loss
                not_improved = 0
                plot_interpolation(test_data_loader, ae, args.device,
                                   INTERPOLATION_PATH, epoch)
                plot_reconstructions(test_data_loader, ae, args.device,
                                     RECONSTRUCTIONS_PATH, epoch)
                print(
                    "Saved model, plotted reconstruction and interpolation and reset improvement counter"
                )
            except Exception as err:
                # Best effort: report the failure and keep training, but
                # let KeyboardInterrupt/SystemExit propagate.
                print(
                    'Error occurred while saving after epoch {}: {}'.format(
                        epoch + 1, err))
        else:
            not_improved += 1

        print("Training has not improved on test set for {} epochs".format(
            not_improved))

    writer.close()