def save_best_results(network_name, checkpoint_names, set_of_thresholds, set_of_mrs, set_of_homogeneity_scores,
                      set_of_completeness_scores, speaker_numbers):
    """
    Writes the result pickle named ``network_name + "_best"``.

    With exactly one result set, all inputs are passed to ``write_result_pickle`` unchanged.
    Otherwise only the checkpoints whose minimum MR equals the lowest minimum MR of all
    checkpoints are kept. On a tie, every tied checkpoint is kept, in input order.
    :param network_name: Base name; "_best" is appended for the output
    :param checkpoint_names: One name per checkpoint
    :param set_of_thresholds: Per-checkpoint thresholds
    :param set_of_mrs: Per-checkpoint misclassification rates (reduced with np.min)
    :param set_of_homogeneity_scores: Per-checkpoint homogeneity scores
    :param set_of_completeness_scores: Per-checkpoint completeness scores
    :param speaker_numbers: Per-checkpoint speaker numbers
    """
    result_name = network_name + "_best"
    if len(set_of_mrs) == 1:
        write_result_pickle(result_name, checkpoint_names, set_of_thresholds, set_of_mrs,
                            set_of_homogeneity_scores, set_of_completeness_scores, speaker_numbers)
        return

    # Lowest MR reached by each checkpoint, then the lowest across all checkpoints.
    per_checkpoint_min = [np.min(mrs) for mrs in set_of_mrs]
    overall_min = min(per_checkpoint_min)
    winners = [position for position, value in enumerate(per_checkpoint_min) if value == overall_min]

    def pick(values):
        # Keep only the entries that belong to a winning checkpoint.
        return [values[position] for position in winners]

    write_result_pickle(result_name, pick(checkpoint_names), pick(set_of_thresholds), pick(set_of_mrs),
                        pick(set_of_homogeneity_scores), pick(set_of_completeness_scores),
                        pick(speaker_numbers))
def _sort_curves_by_min_mr(curve_names, mrs, homogeneity_scores, completeness_scores, number_of_embeddings):
    """
    Reorders the per-curve data by ascending minimum misclassification rate.

    Ties on the minimum MR are broken by curve name. Curves that tie on both keep their input
    order, because the sort is stable. The sort uses an explicit key so the per-curve data rows
    are never compared with each other; ordering whole tuples would compare numpy arrays, and
    that raises ``ValueError``.
    :return: Tuple of six lists (min_mrs, curve_names, mrs, homogeneity_scores, completeness_scores,
             number_of_embeddings), all in the same sorted order. Empty input yields six empty lists.
    """
    rows = sorted(zip([np.min(mr) for mr in mrs], curve_names, mrs, homogeneity_scores,
                      completeness_scores, number_of_embeddings),
                  key=lambda row: (row[0], row[1]))
    if not rows:
        return [], [], [], [], [], []
    return tuple(list(column) for column in zip(*rows))


def plot_curves(plot_file_name, curve_names, mrs, homogeneity_scores, completeness_scores, number_of_embeddings):
    """
    Plots all specified curves and saves the plot into a file.

    Curves are drawn in order of ascending minimum MR, and each legend entry shows that minimum.
    The plot is saved as ``get_result_png(plot_file_name)`` and, as SVG, as
    ``get_result_png(plot_file_name + '.svg')``. The figure is closed afterwards so a later call
    starts from an empty figure.
    :param plot_file_name: String value of save file name
    :param curve_names: Set of names used in legend to describe this curve
    :param mrs: 2D Matrix, each row describes one dataset of misclassification rates for a curve
    :param homogeneity_scores: 2D Matrix, each row describes one dataset of homogeneity scores for a curve
    :param completeness_scores: 2D Matrix, each row describes one dataset of completeness scores for a curve
    :param number_of_embeddings: set of integers, each integer describes how many embeddings is in this curve
    """
    logger = get_logger('analysis', logging.INFO)
    logger.info('Plot results')
    logger.info(plot_file_name)

    min_mrs, curve_names, mrs, homogeneity_scores, completeness_scores, number_of_embeddings = \
        _sort_curves_by_min_mr(curve_names, mrs, homogeneity_scores, completeness_scores, number_of_embeddings)

    # How many lines to plot
    number_of_lines = len(curve_names)

    # Get various colors needed to plot
    color_map = plt.get_cmap('gist_rainbow')
    colors = [color_map(i) for i in np.linspace(0, 1, number_of_lines)]

    # Figure number 1 is reused by name; it is closed at the end so curves do not accumulate.
    fig1 = plt.figure(1)
    fig1.set_size_inches(16, 8)

    # Define Plots
    mr_plot = plt.subplot2grid((2, 3), (0, 0), colspan=2)
    mr_plot.set_ylabel('MR')
    mr_plot.set_xlabel('number of clusters')
    mr_plot.set_ylim([-0.02, 1.02])

    completeness_scores_plot = add_cluster_subplot(fig1, 234, 'completeness_scores')
    homogeneity_scores_plot = add_cluster_subplot(fig1, 235, 'homogeneity_scores')

    # Define curves and their values
    curves = [[mr_plot, mrs],
              [homogeneity_scores_plot, homogeneity_scores],
              [completeness_scores_plot, completeness_scores]]

    # Plot all curves
    for index in range(number_of_lines):
        label = curve_names[index] + '\n min MR: ' + str(min_mrs[index])
        color = colors[index]
        # x axis counts down from number_of_embeddings[index] to 1.
        number_of_clusters = np.arange(number_of_embeddings[index], 0, -1)

        for plot, value in curves:
            plot.plot(number_of_clusters, value[index], color=color, label=label)

    # Add legend and save the plot
    fig1.legend()
    fig1.savefig(get_result_png(plot_file_name))
    fig1.savefig(get_result_png(plot_file_name + '.svg'), format='svg')
    plt.close(fig1)
# Example No. 3
# 0
def _save_best_results(network_name, checkpoint_names, metric_sets,
                       speaker_numbers):
    """
    Writes the result pickle named ``network_name + "_best"`` for the best checkpoint(s).

    The first metric in ``metric_sets`` decides which checkpoint wins. The module-level
    ``metric_min_values[0]`` (defined elsewhere) selects the direction: a value of 1 means lower
    is better (np.min / min); any other value means higher is better (np.max / max). Every
    checkpoint that reaches the best value is kept, so ties keep all tied checkpoints. If the
    first metric has exactly one entry, everything is written unchanged.
    :param network_name: Base name; "_best" is appended for the output
    :param checkpoint_names: One name per checkpoint
    :param metric_sets: List of per-metric lists, each holding one result series per checkpoint
    :param speaker_numbers: Per-checkpoint speaker numbers
    """
    output_name = network_name + "_best"
    if len(metric_sets[0]) == 1:
        _write_result_pickle(output_name, checkpoint_names,
                             metric_sets, speaker_numbers)
        return

    lower_is_better = metric_min_values[0] == 1
    reduce_one, reduce_all = (np.min, min) if lower_is_better else (np.max, max)
    scores = [reduce_one(results) for results in metric_sets[0]]
    target = reduce_all(scores)

    winners = [idx for idx, score in enumerate(scores) if score == target]
    best_checkpoint_name = [checkpoint_names[idx] for idx in winners]
    set_of_best_metrics = [[metric_set[idx] for idx in winners] for metric_set in metric_sets]
    best_speaker_numbers = [speaker_numbers[idx] for idx in winners]

    _write_result_pickle(output_name, best_checkpoint_name,
                         set_of_best_metrics, best_speaker_numbers)
def plot_curves(plot_file_name, curve_names, mrs, homogeneity_scores, completeness_scores, number_of_embeddings):
    """
    Plots all specified curves and saves the plot into a file.

    The MR curves go in a top panel spanning two grid columns, fixed to the x range 0..80 and
    the y range 0..1. Completeness and homogeneity go in the two panels below. Each curve's
    minimum MR is annotated at x=0. The x values run from number_of_embeddings[i] down to 2, so
    each data row presumably holds number_of_embeddings[i] - 1 values (TODO confirm).
    Only a PNG-style output (``get_result_png(plot_file_name)``) is written.
    """
    logger = get_logger('analysis', logging.INFO)
    logger.info('Plot results')

    line_count = len(curve_names)
    palette = plt.get_cmap('gist_rainbow')
    colors = [palette(position) for position in np.linspace(0, 1, line_count)]

    figure = plt.figure(1)
    figure.set_size_inches(32, 24)

    mr_axes = plt.subplot2grid((2, 2), (0, 0), colspan=2)
    mr_axes.set_title('MR')
    mr_axes.set_xlabel('number of clusters')
    mr_axes.axis([0, 80, 0, 1])

    completeness_axes = add_cluster_subplot(figure, 223, 'completeness_scores')
    homogeneity_axes = add_cluster_subplot(figure, 224, 'homogeneity_scores')

    panels = ((mr_axes, mrs),
              (homogeneity_axes, homogeneity_scores),
              (completeness_axes, completeness_scores))

    for idx, (name, color) in enumerate(zip(curve_names, colors)):
        x_values = np.arange(number_of_embeddings[idx], 1, -1)
        for axes, series in panels:
            axes.plot(x_values, series[idx], color=color, label=name)

        lowest_mr = np.min(mrs[idx])
        mr_axes.annotate(str(lowest_mr), xy=(0, lowest_mr))

    figure.legend()
    figure.savefig(get_result_png(plot_file_name))
# Example No. 5
# 0
def _mr_series(*mr_dicts):
    """Return the values of each dict as a list, in that dict's own iteration order."""
    return [list(mr_dict.values()) for mr_dict in mr_dicts]


def _cluster_counts(mr_dict):
    """
    Return the keys of ``mr_dict`` converted to ``int``, in iteration order.

    :raises ValueError: if ``mr_dict`` is empty, because then no x-axis range can be derived.
    """
    counts = [int(key) for key in mr_dict]
    if not counts:
        raise ValueError('loaded_dict must contain at least one cluster count')
    return counts


def plot_curves(plot_file_name, curve_names, mrs, homogeneity_scores, completeness_scores, number_of_embeddings,
                loaded_dict, loaded_dict2, loaded_dict3):
    """
    Plots the MR-versus-cluster-count curves of three clustering algorithms and saves the plot into a file.

    The MR values come from ``loaded_dict`` (agglomerative hierarchical), ``loaded_dict2`` (k-means)
    and ``loaded_dict3`` (dominant sets). The x values are the keys of ``loaded_dict`` converted to
    ``int``. The values of the other two dicts are paired with those keys by iteration order, so
    all three dicts are assumed to list the same cluster counts in the same order (TODO confirm
    with the caller). The series and cluster counts are also printed to stdout.
    :param plot_file_name: String value of save file name
    :param curve_names: Accepted for signature compatibility; not used by this variant
    :param mrs: Accepted for signature compatibility; not used by this variant
    :param homogeneity_scores: Accepted for signature compatibility; not used by this variant
    :param completeness_scores: Accepted for signature compatibility; not used by this variant
    :param number_of_embeddings: Accepted for signature compatibility; not used by this variant
    :param loaded_dict: Mapping cluster count -> MR for hierarchical clustering (keys must convert to int)
    :param loaded_dict2: Mapping cluster count -> MR for k-means clustering
    :param loaded_dict3: Mapping cluster count -> MR for dominant-sets clustering
    :raises ValueError: if ``loaded_dict`` is empty
    """
    logger = get_logger('analysis', logging.INFO)
    logger.info('Plot results')
    logger.info(plot_file_name)

    algorithm = ["Agglomerative_Hierachial_Clustering",
                 "K_Means_Clustering",
                 "DominantSets_Clustering"]

    hierach_mr, kmeans_mr, ds_mr = _mr_series(loaded_dict, loaded_dict2, loaded_dict3)
    for title, series in (("Hierachial MR", hierach_mr), ("Kmeans MR", kmeans_mr), ("DS MR", ds_mr)):
        print("\n")
        print(title)
        print(series)

    ks = _cluster_counts(loaded_dict)
    minc, maxc = min(ks), max(ks)
    print("\n")
    print("Cluster Count")
    print(ks)
    print("Minimum Cluster : " + str(minc))
    print("Maximum CLuster : " + str(maxc) + "\n")

    # One color per algorithm. The loop below always draws three lines, so the palette must not
    # depend on len(curve_names); fewer than three names previously raised IndexError.
    color_map = plt.get_cmap('gist_rainbow')
    colors = [color_map(i) for i in np.linspace(0, 1, len(algorithm))]

    # Figure number 1 is reused by name; it is closed at the end so curves do not accumulate.
    fig1 = plt.figure(1)
    fig1.set_size_inches(16, 8)

    mr_plot = plt.subplot2grid((2, 3), (0, 0), colspan=2)
    mr_plot.set_ylabel('MR')
    mr_plot.set_xlabel('Number of clusters')
    mr_plot.grid(True)
    mr_plot.axis([minc, maxc, -0.02, 1.2])

    # These two panels are created but nothing is drawn on them in this variant.
    add_cluster_subplot(fig1, 234, 'completeness_scores')
    add_cluster_subplot(fig1, 235, 'homogeneity_scores')

    for label, color, series in zip(algorithm, colors, (hierach_mr, kmeans_mr, ds_mr)):
        print(series)
        mr_plot.plot(ks, series, color=color, label=label)

    # Add legend and save the plot
    fig1.legend()
    result_path = get_result_png(plot_file_name)
    fig1.savefig(result_path)
    print("Plot File saved in " + result_path)
    fig1.savefig(get_result_png(plot_file_name + '.svg'), format='svg')
    plt.close(fig1)
# Example No. 6
# 0
def _slice_metric_sets(metric_sets, plot_width):
    """
    Return a copy of ``metric_sets`` with every curve cut to its last ``plot_width`` values.

    The caller's nested lists are left untouched; the previous implementation sliced them in place.
    """
    return [[series[-plot_width:] for series in metric_set] for metric_set in metric_sets]


def _compute_best_results(metric_sets, min_values):
    """
    Return, per metric, the best value of every curve.

    A metric whose ``min_values`` entry is 0 is maximised (np.max); any other entry means it is
    minimised (np.min).
    """
    return [[np.max(results) if min_value == 0 else np.min(results) for results in metric_sets[m]]
            for m, min_value in enumerate(min_values)]


def _subplot_position(metric_index, columns=2):
    """Return the (row, column) grid cell of the metric at ``metric_index``, filling each row left to right."""
    return divmod(metric_index, columns)


def _plot_curves(plot_file_name, curve_names, metric_sets,
                 number_of_embeddings):
    """
    Plots all specified curves and saves the plot into a file.

    Plot width and figure size are read from the 'common' section of config.cfg. Each curve is cut
    to its last ``plot_width`` values before plotting; the caller's ``metric_sets`` is not modified.
    The module-level ``metric_names`` and ``metric_min_values`` (defined elsewhere) name the metrics
    and choose, per metric, whether the best value is a max (0) or a min (otherwise). The figure
    is closed after saving so repeated calls do not leak open figures.
    :param plot_file_name: String value of save file name
    :param curve_names: Set of names used in legend to describe this curve
    :param metric_sets: A list of 2D matrices, each row of a metrics 2D matrix describes one dataset for a curve
    :param number_of_embeddings: Accepted for signature compatibility; not used by this function
    """
    logger = get_logger('analysis', logging.INFO)
    logger.info('Plot results')

    config = load_config(None, join(get_common(), 'config.cfg'))
    plot_width = config.getint('common', 'plot_width')
    fig_width = config.getint('common', 'fig_width')
    fig_height = config.getint('common', 'fig_height')

    # Keep only the last plot_width values of every curve (the original note said "1-80 clusters";
    # presumably plot_width is 80 in the shipped config).
    metric_sets = _slice_metric_sets(metric_sets, plot_width)
    best_results = _compute_best_results(metric_sets, metric_min_values)

    # Lines are deliberately not sorted: with several metrics there is no single ranking criterion.

    # How many lines to plot
    number_of_lines = len(curve_names)

    # Get various colors needed to plot
    color_map = plt.get_cmap('gist_rainbow')
    colors = [color_map(i) for i in np.linspace(0, 1, number_of_lines)]

    # Set fontsize for all plots
    plt.rcParams.update({'font.size': 12})

    fig1 = plt.figure(figsize=(fig_width, fig_height))

    # Metric panels fill a 3x2 grid row by row; with four metrics the third row stays free for the legend.
    plot_grid = (3, 2)
    plots = [_add_cluster_subplot(plot_grid, _subplot_position(m), metric_name, 1)
             for m, metric_name in enumerate(metric_names)]

    # Set the vertical space between subplots
    plt.subplots_adjust(hspace=0.3)

    # Pair each panel with the per-curve data of its metric
    curves = list(zip(plots, metric_sets))

    # Plot all curves
    for index in range(number_of_lines):
        label = curve_names[index]
        for m, metric_name in enumerate(metric_names):
            label = label + '\n {} {}: {}'.format(
                'Max' if metric_min_values[m] == 0 else 'Min', metric_name,
                str(best_results[m][index]))
        color = colors[index]

        for plot, value in curves:
            series = value[index]
            # x runs from len(series) down to 1. This matches np.arange(plot_width, 0, -1) whenever a
            # curve has at least plot_width points, and still lines up for shorter curves.
            number_of_clusters = np.arange(len(series), 0, -1)
            plot.plot(number_of_clusters,
                      series,
                      color=color,
                      label=label)

    # Add legend and save the plot
    fig1.legend(loc='upper center', bbox_to_anchor=(0.5, 0.33), ncol=4)
    fig1.savefig(get_result_png(plot_file_name))
    fig1.savefig(get_result_png(plot_file_name + '.svg'), format='svg')
    plt.close(fig1)