Code example #1
def main(input_file_or_name, data_directory = "data",
    log_directory = "log", results_directory = "results",
    temporary_log_directory = None,
    map_features = False, feature_selection = [], example_filter = [],
    preprocessing_methods = [], noisy_preprocessing_methods = [],
    split_data_set = True,
    splitting_method = "default", splitting_fraction = 0.9,
    model_type = "VAE", latent_size = 50, hidden_sizes = [500],
    number_of_importance_samples = [5],
    number_of_monte_carlo_samples = [10],
    inference_architecture = "MLP",
    latent_distribution = "gaussian",
    number_of_classes = None,
    parameterise_latent_posterior = False,
    generative_architecture = "MLP",
    reconstruction_distribution = "poisson",
    number_of_reconstruction_classes = 0,
    prior_probabilities_method = "uniform",
    number_of_warm_up_epochs = 0,
    kl_weight = 1,
    proportion_of_free_KL_nats = 0.0,
    batch_normalisation = True,
    dropout_keep_probabilities = [],
    count_sum = True,
    number_of_epochs = 200, plotting_interval_during_training = None, 
    batch_size = 100, learning_rate = 1e-4,
    run_id = None, new_run = False,
    prediction_method = None, prediction_training_set_name = "training",
    prediction_decomposition_method = None,
    prediction_decomposition_dimensionality = None,
    decomposition_methods = ["PCA"], highlight_feature_indices = [],
    reset_training = False, skip_modelling = False,
    model_versions = ["all"],
    analyse = True, evaluation_set_name = "test", analyse_data = False,
    analyses = ["default"], analysis_level = "normal", fast_analysis = False,
    export_options = []):
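    """Train and evaluate a variational auto-encoder model on a data set.

    The data set is loaded and optionally split, the model configuration is
    validated, the chosen model (VAE or GMVAE) is built and trained, and each
    requested model version is then evaluated, optionally with label
    prediction, followed by analyses of the data, the model, and the results.
    """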
    
    # Setup
    
    model_versions = parseModelVersions(model_versions)
    
    ## Analyses
    
    if fast_analysis:
        analyse = True
        analyses = ["simple"]
        analysis_level = "limited"
    
    ## Distributions
    
    reconstruction_distribution = parseDistribution(
        reconstruction_distribution)
    latent_distribution = parseDistribution(latent_distribution)
    
    ## Model configuration validation
    
    if not skip_modelling:
        
        if run_id:
            run_id = checkRunID(run_id)
        
        model_valid, model_errors = validateModelParameters(
            model_type, latent_distribution,
            reconstruction_distribution, number_of_reconstruction_classes,
            parameterise_latent_posterior
        )
        
        if not model_valid:
            print("Model configuration is invalid:")
            for model_error in model_errors:
                print("    ", model_error)
            print()
            if analyse_data:
                print("Skipping modelling.")
                print("")
                skip_modelling = True
            else:
                print("Modelling cancelled.")
                return
    
    ## Binarisation
    
    binarise_values = False
    
    if reconstruction_distribution == "bernoulli":
        if noisy_preprocessing_methods:
            if noisy_preprocessing_methods[-1] != "binarise":
                noisy_preprocessing_methods.append("binarise")
                print("Appended binarisation method to noisy preprocessing,",
                    "because of the Bernoulli distribution.\n")
        else:
            binarise_values = True
    
    ## Data sets
    
    full_data_set_needed = (
        not split_data_set
        or analyse_data
        or evaluation_set_name == "full"
        or prediction_training_set_name == "full"
    )
    
    # Data
    
    print(title("Data"))
    
    data_set = data.DataSet(
        input_file_or_name,
        directory = data_directory,
        map_features = map_features,
        feature_selection = feature_selection,
        example_filter = example_filter,
        preprocessing_methods = preprocessing_methods,
        binarise_values = binarise_values,
        noisy_preprocessing_methods = noisy_preprocessing_methods
    )
    
    if full_data_set_needed:
        data_set.load()
    
    if split_data_set:
        training_set, validation_set, test_set = data_set.split(
            splitting_method, splitting_fraction)
        all_data_sets = [data_set, training_set, validation_set, test_set]
    else:
        splitting_method = None
        training_set = data_set
        validation_set = None
        test_set = data_set
        all_data_sets = [data_set]
        evaluation_set_name = "full"
        prediction_training_set_name = "full"
    
    ## Setup of log and results directories
    
    log_directory = data.directory(log_directory, data_set,
        splitting_method, splitting_fraction)
    data_results_directory = data.directory(results_directory, data_set,
        splitting_method, splitting_fraction, preprocessing = False)
    results_directory = data.directory(results_directory, data_set,
        splitting_method, splitting_fraction)
    
    if temporary_log_directory:
        main_temporary_log_directory = temporary_log_directory
        temporary_log_directory = data.directory(temporary_log_directory,
            data_set, splitting_method, splitting_fraction)
    
    ## Data analysis
    
    if analyse and analyse_data:
        print(subtitle("Analysing data"))
        analysis.analyseData(
            data_sets = all_data_sets,
            decomposition_methods = decomposition_methods,
            highlight_feature_indices = highlight_feature_indices,
            analyses = analyses,
            analysis_level = analysis_level,
            export_options = export_options,
            results_directory = data_results_directory
        )
        print()
    
    ## Full data set clean up
    
    if not full_data_set_needed:
        data_set.clear()
    
    # Modelling
    
    if skip_modelling:
        print("Modelling skipped.")
        return
    
    print(title("Modelling"))
    
    # Set the number of features for the model
    feature_size = training_set.number_of_features
    
    # Parse numbers of samples
    number_of_monte_carlo_samples = parseSampleLists(
        number_of_monte_carlo_samples)
    number_of_importance_samples = parseSampleLists(
        number_of_importance_samples)
    
    # Use analytical KL term for single-Gaussian-VAE
    if "VAE" in model_type:
        if latent_distribution == "gaussian":
            analytical_kl_term = True
        else:
            analytical_kl_term = False
    
    # Change latent distribution to Gaussian mixture if not already set
    if model_type == "GMVAE" and latent_distribution != "gaussian mixture":
        latent_distribution = "gaussian mixture"
        print("The latent distribution was changed to",
            "a Gaussian-mixture model, because of the model chosen.\n")
    
    # Set the number of classes if not already set
    if not number_of_classes:
        if training_set.has_labels:
            number_of_classes = training_set.number_of_classes \
                - training_set.number_of_excluded_classes
        elif "mixture" in latent_distribution:
            raise ValueError(
                "For a mixture model and a data set without labels, "
                "the number of classes has to be set."
            )
        else:
            number_of_classes = 1
    
    print(subtitle("Model setup"))
    
    if model_type == "VAE":
        model = VariationalAutoencoder(
            feature_size = feature_size,
            latent_size = latent_size,
            hidden_sizes = hidden_sizes,
            number_of_monte_carlo_samples = number_of_monte_carlo_samples,
            number_of_importance_samples = number_of_importance_samples,
            analytical_kl_term = analytical_kl_term,
            inference_architecture = inference_architecture,
            latent_distribution = latent_distribution,
            number_of_latent_clusters = number_of_classes,
            parameterise_latent_posterior = parameterise_latent_posterior,
            generative_architecture = generative_architecture,
            reconstruction_distribution = reconstruction_distribution,
            number_of_reconstruction_classes = number_of_reconstruction_classes,
            batch_normalisation = batch_normalisation,
            dropout_keep_probabilities = dropout_keep_probabilities,
            count_sum = count_sum,
            number_of_warm_up_epochs = number_of_warm_up_epochs,
            kl_weight = kl_weight,
            log_directory = log_directory,
            results_directory = results_directory
        )

    elif model_type == "GMVAE":
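        # Determine the class prior over the latent clusters: uniform,
        # inferred from training-set label frequencies, or taken from
        # literature values, falling back to uniform when none are available.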
        
        if prior_probabilities_method == "uniform":
            prior_probabilities = None
        elif prior_probabilities_method == "infer":
            prior_probabilities = training_set.class_probabilities
        elif prior_probabilities_method == "literature":
            prior_probabilities = training_set.literature_probabilities
        else:
            prior_probabilities = None
        
        if not prior_probabilities:
            prior_probabilities_method = "uniform"
            prior_probabilities_values = None
        else:
            prior_probabilities_values = list(prior_probabilities.values())
        
        prior_probabilities = {
            "method": prior_probabilities_method,
            "values": prior_probabilities_values
        }
        
        model = GaussianMixtureVariationalAutoencoder(
            feature_size = feature_size,
            latent_size = latent_size,
            hidden_sizes = hidden_sizes,
            number_of_monte_carlo_samples = number_of_monte_carlo_samples,
            number_of_importance_samples = number_of_importance_samples, 
            analytical_kl_term = analytical_kl_term,
            prior_probabilities = prior_probabilities,
            number_of_latent_clusters = number_of_classes,
            proportion_of_free_KL_nats = proportion_of_free_KL_nats,
            reconstruction_distribution = reconstruction_distribution,
            number_of_reconstruction_classes = number_of_reconstruction_classes,
            batch_normalisation = batch_normalisation,
            dropout_keep_probabilities = dropout_keep_probabilities,
            count_sum = count_sum,
            number_of_warm_up_epochs = number_of_warm_up_epochs,
            kl_weight = kl_weight,
            log_directory = log_directory,
            results_directory = results_directory
        )
    
    else:
        raise ValueError("Model type not found: `{}`.".format(model_type))
    
    print(model.description)
    print()
    
    print(model.parameters)
    print()
    
    ## Training
    
    print(subtitle("Model training"))
    
    status, run_id = model.train(
        training_set,
        validation_set,
        number_of_epochs = number_of_epochs,
        batch_size = batch_size,
        learning_rate = learning_rate,
        plotting_interval = plotting_interval_during_training,
        run_id = run_id,
        new_run = new_run,
        reset_training = reset_training,
        temporary_log_directory = temporary_log_directory
    )
    
    # Remove temporary directories created and emptied during training
    if temporary_log_directory and os.path.exists(main_temporary_log_directory):
        removeEmptyDirectories(main_temporary_log_directory)
    
    if not status["completed"]:
        print(status["message"])
        return
    
    status_filename = "status"
    if "epochs trained" in status:
        status_filename += "-" + status["epochs trained"]
    status_path = os.path.join(
        model.logDirectory(run_id = run_id),
        status_filename + ".log"
    )
    with open(status_path, "w") as status_file:
        for status_field, status_value in status.items():
            if status_value:
                status_file.write(
                    status_field + ": " + str(status_value) + "\n"
                )
    
    print()
    
    # Evaluation, prediction, and analysis
    
    ## Setup
    
    if analyse:
        if prediction_method:
            predict_labels_using_model = False
        elif "GM" in model.type:
            predict_labels_using_model = True
            prediction_method = "model"
        else:
            predict_labels_using_model = False
    else:
        predict_labels_using_model = False
    
    evaluation_title_parts = ["evaluation"]
    
    if analyse:
        if prediction_method:
            evaluation_title_parts.append("prediction")
        evaluation_title_parts.append("analysis")
    
    evaluation_title = enumerateListOfStrings(evaluation_title_parts)
    
    print(title(evaluation_title.capitalize()))
    
    ### Set selection
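    # Keep only the evaluation set (and, when a prediction method is used,
    # the prediction training set); clear the remaining data subsets.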
    
    for data_subset in all_data_sets:
        
        clear_subset = True
        
        if data_subset.kind == evaluation_set_name:
            evaluation_set = data_subset
            clear_subset = False
            
        if prediction_method \
            and data_subset.kind == prediction_training_set_name:
                prediction_training_set = data_subset
                clear_subset = False
        
        if clear_subset:
            data_subset.clear()
    
    ### Evaluation set
    
    evaluation_subset_indices = analysis.evaluationSubsetIndices(
        evaluation_set)
    
    print("Evaluation set: {} set.".format(evaluation_set.kind))
    
    ### Prediction method
    
    if prediction_method:
        
        prediction_method = properString(
            prediction_method,
            PREDICTION_METHOD_NAMES
        )
        
        prediction_method_specifications = PREDICTION_METHOD_SPECIFICATIONS\
            .get(prediction_method, {})
        prediction_method_inference = prediction_method_specifications.get(
            "inference", None)
        prediction_method_fixed_number_of_clusters \
            = prediction_method_specifications.get(
                "fixed number of clusters", None)
        prediction_method_cluster_kind = prediction_method_specifications.get(
            "cluster kind", None)
        
        if prediction_method_fixed_number_of_clusters:
            number_of_clusters = number_of_classes
        else:
            number_of_clusters = None
        
        if prediction_method_inference \
            and prediction_method_inference == "transductive":
            
            prediction_training_set = None
            prediction_training_set_name = None
        
        else:
            prediction_training_set_name = prediction_training_set.kind
        
        prediction_details = {
            "method": prediction_method,
            "number_of_classes": number_of_clusters,
            "training_set_name": prediction_training_set_name,
            "decomposition_method": prediction_decomposition_method,
            "decomposition_dimensionality":
                prediction_decomposition_dimensionality
        }
        
        print("Prediction method: {}.".format(prediction_method))
        
        if number_of_clusters:
            print("Number of clusters: {}.".format(number_of_clusters))
        
        if prediction_training_set:
            print("Prediction training set: {} set.".format(
                prediction_training_set.kind))
        
        prediction_id_parts = []
        
        if prediction_decomposition_method:
            
            prediction_decomposition_method = properString(
                prediction_decomposition_method,
                DECOMPOSITION_METHOD_NAMES
            )
            
            if not prediction_decomposition_dimensionality:
                prediction_decomposition_dimensionality \
                    = DEFAULT_DECOMPOSITION_DIMENSIONALITY
            
            prediction_id_parts += [
                prediction_decomposition_method,
                prediction_decomposition_dimensionality
            ]
            
            prediction_details.update({
                "decomposition_method": prediction_decomposition_method,
                "decomposition_dimensionality":
                    prediction_decomposition_dimensionality
            })
            
            print("Decomposition method before prediction: {}-d {}.".format(
                prediction_decomposition_dimensionality,
                prediction_decomposition_method
            ))
        
        prediction_id_parts.append(prediction_method)
        
        if number_of_clusters:
            prediction_id_parts.append(number_of_clusters)
        
        if prediction_training_set \
            and prediction_training_set.kind != "training":
                prediction_id_parts.append(prediction_training_set.kind)
        
        prediction_id = "_".join(map(
            lambda s: normaliseString(str(s)).replace("_", ""),
            prediction_id_parts
        ))
        prediction_details["id"] = prediction_id
    
    else:
        prediction_details = {}
    
    ### Model parameter sets
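    # Determine which saved parameter sets (end of training, best model,
    # early stopping) to evaluate, depending on the requested model versions
    # and on which of them actually exist for this run.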
    
    model_parameter_set_names = []
    
    if "end_of_training" in model_versions:
        model_parameter_set_names.append("end of training")
    
    if "best_model" in model_versions \
        and betterModelExists(model, run_id = run_id):
            model_parameter_set_names.append("best model")
    
    if "early_stopping" in model_versions \
        and modelStoppedEarly(model, run_id = run_id):
            model_parameter_set_names.append("early stopping")
    
    print("Model parameter sets: {}.".format(enumerateListOfStrings(
        model_parameter_set_names)))
    
    print()
    
    ## Model analysis
    
    if analyse:
        
        print(subtitle("Model analysis"))
        analysis.analyseModel(
            model = model,
            run_id = run_id,
            analyses = analyses,
            analysis_level = analysis_level,
            export_options = export_options,
            results_directory = results_directory
        )
    
    ## Results evaluation, prediction, and analysis
    
    for model_parameter_set_name in model_parameter_set_names:
        
        if model_parameter_set_name == "best model":
            use_best_model = True
        else:
            use_best_model = False
        
        if model_parameter_set_name == "early stopping":
            use_early_stopping_model = True
        else:
            use_early_stopping_model = False
        
        model_parameter_set_name = model_parameter_set_name.capitalize()
        print(subtitle(model_parameter_set_name))
        
        # Evaluation
        
        model_parameter_set_name = model_parameter_set_name.replace(" ", "-")
        
        print(heading("{} evaluation".format(model_parameter_set_name)))
        
        if "VAE" in model.type:
            transformed_evaluation_set, reconstructed_evaluation_set,\
                latent_evaluation_sets = model.evaluate(
                    evaluation_set = evaluation_set,
                    evaluation_subset_indices = evaluation_subset_indices,
                    batch_size = batch_size,
                    predict_labels = predict_labels_using_model,
                    run_id = run_id,
                    use_best_model = use_best_model,
                    use_early_stopping_model = use_early_stopping_model
                )
        else:
            transformed_evaluation_set, reconstructed_evaluation_set = \
                model.evaluate(
                    evaluation_set = evaluation_set,
                    evaluation_subset_indices = evaluation_subset_indices,
                    batch_size = batch_size,
                    run_id = run_id,
                    use_best_model = use_best_model,
                    use_early_stopping_model = use_early_stopping_model
                )
            latent_evaluation_sets = None
        
        print()
        
        # Prediction
        
        if analyse and "VAE" in model.type and prediction_method \
            and not transformed_evaluation_set.has_predictions:
            
            print(heading("{} prediction".format(model_parameter_set_name)))
            
            latent_prediction_evaluation_set = latent_evaluation_sets["z"]
            
            if prediction_method_inference \
                and prediction_method_inference == "inductive":
                
                latent_prediction_training_sets = model.evaluate(
                    evaluation_set = prediction_training_set,
                    batch_size = batch_size,
                    run_id = run_id,
                    use_best_model = use_best_model,
                    use_early_stopping_model = use_early_stopping_model,
                    output_versions = "latent",
                    log_results = False
                )
                latent_prediction_training_set \
                    = latent_prediction_training_sets["z"]
                
                print()
            
            else:
                latent_prediction_training_set = None
            
            if prediction_decomposition_method:
                
                if latent_prediction_training_set:
                    latent_prediction_training_set, \
                        latent_prediction_evaluation_set \
                        = data.decomposeDataSubsets(
                            latent_prediction_training_set,
                            latent_prediction_evaluation_set,
                            method = prediction_decomposition_method,
                            number_of_components = 
                                prediction_decomposition_dimensionality,
                            random = True
                        )
                else:
                    latent_prediction_evaluation_set \
                        = data.decomposeDataSubsets(
                            latent_prediction_evaluation_set,
                            method = prediction_decomposition_method,
                            number_of_components = 
                                prediction_decomposition_dimensionality,
                            random = True
                        )
                
                print()
            
            cluster_ids, predicted_labels, predicted_superset_labels \
                = predict(
                    latent_prediction_training_set,
                    latent_prediction_evaluation_set,
                    prediction_method,
                    number_of_clusters
                )
            
            transformed_evaluation_set.updatePredictions(
                predicted_cluster_ids = cluster_ids,
                predicted_labels = predicted_labels,
                predicted_superset_labels = predicted_superset_labels
            )
            reconstructed_evaluation_set.updatePredictions(
                predicted_cluster_ids = cluster_ids,
                predicted_labels = predicted_labels,
                predicted_superset_labels = predicted_superset_labels
            )
            
            for variable in latent_evaluation_sets:
                latent_evaluation_sets[variable].updatePredictions(
                    predicted_cluster_ids = cluster_ids,
                    predicted_labels = predicted_labels,
                    predicted_superset_labels = predicted_superset_labels
                )
            
            print()
        
        # Analysis
        
        if analyse:
            
            print(heading("{} results analysis".format(model_parameter_set_name)))
            
            analysis.analyseResults(
                evaluation_set = transformed_evaluation_set,
                reconstructed_evaluation_set = reconstructed_evaluation_set,
                latent_evaluation_sets = latent_evaluation_sets,
                model = model,
                run_id = run_id,
                decomposition_methods = decomposition_methods,
                evaluation_subset_indices = evaluation_subset_indices,
                highlight_feature_indices = highlight_feature_indices,
                prediction_details = prediction_details,
                best_model = use_best_model,
                early_stopping = use_early_stopping_model,
                analyses = analyses, analysis_level = analysis_level,
                export_options = export_options,
                results_directory = results_directory
            )
        
        # Clean up
        
        if transformed_evaluation_set.version == "original":
            transformed_evaluation_set.resetPredictions()
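
A minimal usage sketch for this entry point (not part of the original listing): it assumes the function above is importable from the project's main module, and the data set name and hyperparameter values are illustrative only.

# Hypothetical invocation; "sample" stands in for any input file or data set
# name accepted by data.DataSet, and the hyperparameters are illustrative.
if __name__ == "__main__":
    main(
        "sample",
        data_directory = "data",
        results_directory = "results",
        model_type = "GMVAE",
        latent_size = 25,
        hidden_sizes = [250, 250],
        number_of_epochs = 50
    )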
Code example #2
File: cross_analysis.py  Project: chulaihunde/scVAE
def main(log_directory=None,
         results_directory=None,
         data_set_included_strings=[],
         data_set_excluded_strings=[],
         model_included_strings=[],
         model_excluded_strings=[],
         prediction_included_strings=[],
         prediction_excluded_strings=[],
         epoch_cut_off=inf,
         export_options=[],
         log_summary=False):
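    """Cross-analyse test metrics for models found in a results directory.

    For every data set matching the search strings, print per-model test
    metrics, ELBO--ARI correlations, and a comparison table of all models,
    save summary figures, and optionally write the summary to a log file.
    """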

    if log_directory:
        log_directory = os.path.normpath(log_directory) + os.sep

    if results_directory:

        results_directory = os.path.normpath(results_directory) + os.sep
        cross_analysis_directory = os.path.join(results_directory,
                                                "cross_analysis")

        explanation_string_parts = []

        def appendExplanationForSearchStrings(search_strings, inclusive, kind):
            if search_strings:
                explanation_string_parts.append("{} {} with: {}.".format(
                    "Including" if inclusive else "Excluding", kind,
                    ", ".join(search_strings)))

        appendExplanationForSearchStrings(data_set_included_strings,
                                          inclusive=True,
                                          kind="data sets")
        appendExplanationForSearchStrings(data_set_excluded_strings,
                                          inclusive=False,
                                          kind="data sets")
        appendExplanationForSearchStrings(model_included_strings,
                                          inclusive=True,
                                          kind="models")
        appendExplanationForSearchStrings(model_excluded_strings,
                                          inclusive=False,
                                          kind="models")
        appendExplanationForSearchStrings(prediction_included_strings,
                                          inclusive=True,
                                          kind="prediction methods")
        appendExplanationForSearchStrings(prediction_excluded_strings,
                                          inclusive=False,
                                          kind="prediction methods")

        explanation_string = "\n".join(explanation_string_parts)

        print(explanation_string)

        print()

        if log_summary:

            log_filename_parts = []

            def appendSearchStrings(search_strings, symbol):
                if search_strings:
                    log_filename_parts.append("{}_{}".format(
                        symbol, "_".join(search_strings)))

            appendSearchStrings(data_set_included_strings, "d")
            appendSearchStrings(data_set_excluded_strings, "D")
            appendSearchStrings(model_included_strings, "m")
            appendSearchStrings(model_excluded_strings, "M")
            appendSearchStrings(prediction_included_strings, "p")
            appendSearchStrings(prediction_excluded_strings, "P")

            if not log_filename_parts:
                log_filename_parts.append("all")

            log_filename = "-".join(log_filename_parts) + log_extension
            log_path = os.path.join(cross_analysis_directory, log_filename)

            log_string_parts = [explanation_string + "\n"]

        test_metrics_set = testMetricsInResultsDirectory(
            results_directory, data_set_included_strings,
            data_set_excluded_strings, model_included_strings,
            model_excluded_strings)

        model_IDs = modelID()

        for data_set_name, models in test_metrics_set.items():

            data_set_title = titleFromDataSetName(data_set_name)

            print(title(data_set_title))

            if log_summary:
                log_string_parts.append(title(data_set_title, plain=True))

            comparisons = {}
            correlation_sets = {}

            for model_name, test_metrics in models.items():

                model_title = titleFromModelName(model_name)

                metrics_string_parts = []

                # ID

                model_ID = next(model_IDs)
                metrics_string_parts.append("ID: {}".format(model_ID))

                # Time

                timestamp = test_metrics["timestamp"]
                metrics_string_parts.append("Timestamp: {}".format(
                    formatTime(timestamp)))

                # Epochs

                E = test_metrics["number of epochs trained"]
                metrics_string_parts.append("Epochs trained: {}".format(E))

                metrics_string_parts.append("")

                # Evaluation

                evaluation = test_metrics["evaluation"]

                losses = [
                    "log_likelihood", "lower_bound", "reconstruction_error",
                    "kl_divergence", "kl_divergence_z", "kl_divergence_z1",
                    "kl_divergence_z2", "kl_divergence_y"
                ]

                for loss in losses:
                    if loss in evaluation:
                        metrics_string_parts.append("{}: {:-.6g}".format(
                            loss, evaluation[loss][-1]))

                if "lower_bound" in evaluation:
                    model_lower_bound = evaluation["lower_bound"][-1]
                else:
                    model_lower_bound = None

                # Accuracies

                accuracies = ["accuracy", "superset_accuracy"]

                for accuracy in accuracies:
                    if accuracy in test_metrics and test_metrics[accuracy]:
                        metrics_string_parts.append("{}: {:6.2f} %".format(
                            accuracy, 100 * test_metrics[accuracy][-1]))

                metrics_string_parts.append("")

                # Statistics

                if isinstance(test_metrics["statistics"], list):
                    statistics_sets = test_metrics["statistics"]
                else:
                    statistics_sets = None

                reconstructed_statistics = None

                if statistics_sets:
                    for statistics_set in statistics_sets:
                        if "reconstructed" in statistics_set["name"]:
                            reconstructed_statistics = statistics_set

                if reconstructed_statistics:
                    metrics_string_parts.append(
                        formatStatistics(reconstructed_statistics))

                metrics_string_parts.append("")

                # Predictions

                model_ARIs = []

                if "predictions" in test_metrics:

                    for prediction in test_metrics["predictions"].values():

                        ARIs = {}

                        method = prediction["prediction method"]
                        number_of_classes = prediction["number of classes"]

                        if not method:
                            method = "model"

                        prediction_string = "{} ({} classes)".format(
                            method, number_of_classes)

                        for key, value in prediction.items():
                            key_match = matchString(
                                "; ".join([prediction_string,
                                           key]), prediction_included_strings,
                                prediction_excluded_strings)
                            if not key_match:
                                continue
                            if key.startswith("ARI") and value is not None:
                                ARIs[key] = value

                        if ARIs:

                            metrics_string_parts.append(prediction_string +
                                                        ":")

                            for ARI_name, ARI_value in ARIs.items():
                                metrics_string_parts.append(
                                    "    {}: {:.6g}".format(
                                        ARI_name, ARI_value))

                                if "clusters" in ARI_name and ARI_value > 0:
                                    correlation_set_name = "; ".join(
                                        [prediction_string, ARI_name])
                                    if correlation_set_name not in correlation_sets:
                                        correlation_sets[
                                            correlation_set_name] = {
                                                "ELBO": [],
                                                "ARI": []
                                            }
                                    correlation_sets[correlation_set_name]["ELBO"]\
                                        .append(model_lower_bound)
                                    correlation_sets[correlation_set_name]["ARI"]\
                                        .append(ARI_value)

                            metrics_string_parts.append("")

                        model_ARIs.extend(
                            [v for k, v in ARIs.items() if "clusters" in k])

                comparisons[model_title] = {
                    "ID": model_ID,
                    "ELBO": model_lower_bound,
                    "ARI": model_ARIs
                }

                metrics_string = "\n".join(metrics_string_parts)

                print(subtitle(model_title))
                print(metrics_string)

                if log_summary:
                    log_string_parts.append(subtitle(model_title, plain=True))
                    log_string_parts.append(metrics_string)

            if len(comparisons) <= 1:
                continue

            # Correlations
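            # For each prediction variant, correlate the models' ELBOs with
            # their adjusted Rand indices (ARIs) and plot the relationship.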

            if correlation_sets:

                correlation_string_parts = []
                correlation_table = {}

                for set_name in correlation_sets:
                    if len(correlation_sets[set_name]["ELBO"]) < 2:
                        continue
                    correlation_coefficient, _ = pearsonr(
                        correlation_sets[set_name]["ELBO"],
                        correlation_sets[set_name]["ARI"])
                    correlation_table[set_name] = {
                        "r": correlation_coefficient
                    }

                if correlation_table:
                    correlation_table = pandas.DataFrame(correlation_table).T
                    correlation_string_parts.append(str(correlation_table))

                correlation_string_parts.append("")
                correlation_string_parts.append("Plotting correlations.")
                figure, figure_name = plotCorrelations(
                    correlation_sets,
                    x_key="ELBO",
                    y_key="ARI",
                    x_label=r"$\mathcal{L}$",
                    y_label=r"$R_{\mathrm{adj}}$",
                    name=data_set_name.replace(os.sep, "-"))
                saveFigure(figure, figure_name, export_options,
                           cross_analysis_directory)

                correlation_string = "\n".join(correlation_string_parts)

                print(subtitle("ELBO--ARI correlations"))
                print(correlation_string + "\n")

                if log_summary:
                    log_string_parts.append(
                        subtitle("ELBO--ARI correlations", plain=True))
                    log_string_parts.append(correlation_string + "\n")

            # Comparison
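            # Build a plain-text table comparing all models on this data set,
            # one row per model, sorted by ELBO in descending order.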

            model_spec_names = [
                "ID", "type", "distribution", "sizes", "other", "epochs"
            ]

            model_spec_short_names = {
                "ID": "#",
                "type": "T",
                "distribution": "LD",
                "sizes": "S",
                "other": "O",
                "epochs": "E"
            }

            model_metric_names = ["ELBO", "ARI"]

            model_field_names = model_spec_names + model_metric_names

            for model_title in comparisons:
                model_title_parts = model_title.split("; ")
                comparisons[model_title].update({
                    "type":
                    model_title_parts.pop(0),
                    "distribution":
                    model_title_parts.pop(0),
                    "sizes":
                    model_title_parts.pop(0),
                    "epochs":
                    model_title_parts.pop(-1).replace(" epochs", ""),
                    "other":
                    "; ".join(model_title_parts)
                })

            sorted_comparison_items = sorted(
                comparisons.items(),
                key=lambda key_value_pair: key_value_pair[-1]["ELBO"],
                reverse=True)

            network_architecture_ELBOs = {}
            network_architecture_epochs = {}

            for model_title, model_fields in comparisons.items():
                if model_fields["type"] == "VAE(G)" \
                    and model_fields["distribution"] == "NB" \
                    and model_fields["other"] == "BN":

                    epochs = model_fields["epochs"]
                    architecture = model_fields["sizes"]
                    ELBO = model_fields["ELBO"]

                    if int(epochs.split()[0]) > epoch_cut_off:
                        continue

                    h, l = architecture.rsplit("×", maxsplit=1)

                    if l not in network_architecture_ELBOs:
                        network_architecture_ELBOs[l] = {}
                        network_architecture_epochs[l] = {}

                    if h not in network_architecture_ELBOs[l]:
                        network_architecture_ELBOs[l][h] = ELBO
                        network_architecture_epochs[l][h] = epochs
                    else:
                        best_model_version = bestModelVersion(
                            network_architecture_epochs[l][h], epochs)
                        if epochs == best_model_version:
                            network_architecture_ELBOs[l][h] = ELBO
                            network_architecture_epochs[l][h] = epochs

            if network_architecture_ELBOs:
                network_architecture_ELBOs = pandas.DataFrame(
                    network_architecture_ELBOs)
                network_architecture_ELBOs = network_architecture_ELBOs\
                    .reindex(
                        columns = sorted(
                            network_architecture_ELBOs.columns,
                            key = lambda s: int(s)
                        )
                    )
                network_architecture_ELBOs = network_architecture_ELBOs\
                    .reindex(
                        index = sorted(
                            network_architecture_ELBOs.index,
                            key = lambda s: prod(map(int, s.split("×")))
                        )
                    )

                if network_architecture_ELBOs.size > 1:
                    figure, figure_name = plotELBOHeatMap(
                        network_architecture_ELBOs,
                        x_label="Latent dimension",
                        y_label="Number of hidden units",
                        z_symbol=r"\mathcal{L}",
                        name=data_set_name.replace(os.sep, "-"))
                    saveFigure(figure, figure_name, export_options,
                               cross_analysis_directory)
                    print()

            for model_title, model_fields in comparisons.items():
                for field_name, field_value in model_fields.items():

                    if isinstance(field_value, str):
                        continue

                    elif not field_value:
                        string = ""

                    elif isinstance(field_value, float):
                        string = "{:-.6g}".format(field_value)

                    elif isinstance(field_value, int):
                        string = "{:d}".format(field_value)

                    elif isinstance(field_value, list):

                        minimum = min(field_value)
                        maximum = max(field_value)

                        if minimum == maximum:
                            string = "{:.6g}".format(maximum)
                        else:
                            string = "{:5.3f}–{:5.3f}".format(minimum, maximum)

                    else:
                        raise TypeError(
                            "Type `{}` not supported in comparison table.".
                            format(type(field_value)))

                    comparisons[model_title][field_name] = string

            comparison_table_rows = []
            table_column_spacing = "  "

            comparison_table_column_widths = {}

            for field_name in model_field_names:
                comparison_table_column_widths[field_name] = max([
                    len(metrics[field_name])
                    for metrics in comparisons.values()
                ])

            comparison_table_heading_parts = []

            for field_name in model_field_names:

                field_width = comparison_table_column_widths[field_name]

                if field_width == 0:
                    continue

                if field_name in model_spec_names:
                    if len(field_name) > field_width:
                        field_name = model_spec_short_names[field_name]
                    elif field_name == field_name.lower():
                        field_name = field_name.capitalize()

                comparison_table_heading_parts.append("{:{}}".format(
                    field_name, field_width))

            comparison_table_heading = table_column_spacing.join(
                comparison_table_heading_parts)
            comparison_table_toprule = "-" * len(comparison_table_heading)

            comparison_table_rows.append(comparison_table_heading)
            comparison_table_rows.append(comparison_table_toprule)

            for model_title, model_fields in sorted_comparison_items:

                sorted_model_field_items = sorted(
                    model_fields.items(),
                    key=lambda key_value_pair: model_field_names.index(
                        key_value_pair[0]))

                comparison_table_row_parts = [
                    "{:{}}".format(field_value,
                                   comparison_table_column_widths[field_name])
                    for field_name, field_value in sorted_model_field_items
                    if comparison_table_column_widths[field_name] > 0
                ]

                comparison_table_rows.append(
                    table_column_spacing.join(comparison_table_row_parts))

            comparison_table = "\n".join(comparison_table_rows)

            print(subtitle("Comparison"))
            print(comparison_table + "\n")

            if log_summary:
                log_string_parts.append(subtitle("Comparison", plain=True))
                log_string_parts.append(comparison_table + "\n")

        if log_summary:

            log_string = "\n".join(log_string_parts)

            with open(log_path, "w") as log_file:
                log_file.write(log_string)
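
A minimal usage sketch for this cross-analysis entry point (not part of the original listing): it assumes the function above is importable from cross_analysis.py, and the directory paths and search strings are illustrative only.

# Hypothetical invocation; paths and search strings are illustrative.
# Only data sets and models matching the included strings are summarised,
# and the summary is also written to a log file.
if __name__ == "__main__":
    main(log_directory="log",
         results_directory="results",
         data_set_included_strings=["sample"],
         model_included_strings=["GMVAE"],
         log_summary=True)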