コード例 #1
0
def visualize_embeddings(node, split, threshold, iter_no, phase=None):
    """Scatter-plot a 2-component PCA projection of a node's encoded data.

    Points are coloured by the label returned from
    ``node.gmm_predict_test``. The figure is logged through
    ``node.trainer.writer[split].add_figure`` and saved to the path returned
    by ``Paths.get_result_path``.

    Args:
        node: Object providing ``id``, ``name``, ``post_gmm_encode``,
            ``gmm_predict_test`` and ``trainer.writer``.
        split: ``'train'`` (reads ``dl_set[node.id].data['train']``) or
            ``'test'`` (reads ``x_seed``).
        threshold: Forwarded unchanged to ``node.gmm_predict_test``.
        iter_no: Step passed to ``add_figure`` and formatted into the
            output file name.
        phase: String tag inserted into the figure tag and the file name.
            ``None`` (the default) is treated as an empty tag; previously it
            raised ``TypeError`` during string concatenation.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'test'``.
    """
    # Fail early with a clear message instead of an UnboundLocalError on
    # ``data`` further down.
    if split not in ('train', 'test'):
        raise ValueError(
            "split must be 'train' or 'test', got %r" % (split,))

    phase_tag = '' if phase is None else phase

    with tr.no_grad():
        if split == 'train':
            data = dl_set[node.id].data[split]
        else:
            data = x_seed

        Z = node.post_gmm_encode(data)
        labels = node.gmm_predict_test(Z, threshold).tolist()

        z_transformed = PCA(n_components=2).fit_transform(Z)

        # Labels index this palette directly, so they must be 0, 1 or 2.
        # NOTE(review): presumably 0/1 are the two children and 2 means
        # "unassigned" -- confirm against gmm_predict_test.
        palette = ['r', 'b', 'g']
        colors = [palette[int(label)] for label in labels]

        bound = 20
        fig = plt.figure(figsize=(6.5, 6.5))
        try:
            ax = fig.add_subplot(111)
            ax.set_xlim(-bound, bound)
            ax.set_ylim(-bound, bound)
            ax.scatter(z_transformed[:, 0], z_transformed[:, 1],
                       s=0.5, c=colors)

            node.trainer.writer[split].add_figure(
                node.name + '_' + phase_tag + '_plots', fig, iter_no)

            path = Paths.get_result_path(node.name + '_' + split +
                                         '_embedding_plots/' + phase_tag +
                                         '_plot_%03d' % (iter_no))
            fig.savefig(path)
        finally:
            # Always release the figure, even if logging or saving fails.
            plt.close(fig)
コード例 #2
0
ファイル: gan_trainer.py プロジェクト: val-iisc/GANTree
def generate_and_save_image(plots_data,
                            iter_no,
                            image_label,
                            scatter_size=0.5,
                            log=False):
    """Render ``plots_data`` to a PNG file and return it as an image batch.

    The figure comes from ``viz_utils.get_figure``. It is saved as
    ``'<image_label>-<iter_no, 5 digits>.png'`` at the path returned by
    ``Paths.get_result_path``, then read back from disk.

    Args:
        plots_data: Forwarded to ``viz_utils.get_figure``.
        iter_no: Step number; used in the file name and returned unchanged.
        image_label: Prefix of the output file name.
        scatter_size: Forwarded to ``viz_utils.get_figure``.
        log: When true, progress lines are printed to stdout.

    Returns:
        A ``(image, iter_no)`` tuple. ``image`` is the re-read file as
        ``uint8`` with the last entry of its third axis dropped, that axis
        moved to the front, and a new leading axis of length 1.
    """
    if log:
        print('------------------------------------------------------------')
        print('%s: step %i: started generation' % (exp_name, iter_no))

    fig = viz_utils.get_figure(plots_data, scatter_size)
    if log:
        print('%s: step %i: got figure' % (exp_name, iter_no))

    out_path = Paths.get_result_path('%s-%05d.png' % (image_label, iter_no))
    fig.savefig(out_path)
    plt.close(fig)

    pixels = np.array(im.imread(out_path), dtype=np.uint8)
    # Drop the last channel (presumably alpha -- confirm), then go
    # channel-first and add a leading batch axis.
    channel_first = pixels[:, :, :-1].transpose(2, 0, 1)
    batch = np.expand_dims(channel_first, 0)

    if log:
        print('%s: step %i: visualization saved' % (exp_name, iter_no))
        print('------------------------------------------------------------')
    return batch, iter_no
コード例 #3
0
def plot_mean_axis_distribution(node, split, iter_no, phase):
    """Histogram the embeddings projected onto the line through two means.

    Every encoded sample is projected onto the line that passes through
    ``node.kmeans.means[0]`` and ``node.kmeans.means[1]``. For each
    coordinate ``i`` of the projected points, a histogram is drawn together
    with the ``i``-th coordinate of both means. The projected coordinate is
    logged with ``add_histogram`` and the figure is saved to disk.

    Args:
        node: Object providing ``id``, ``name``, ``kmeans.means``,
            ``post_gmm_encode`` and ``trainer.writer``.
        split: ``'train'`` (reads ``dl_set[node.id].data['train']``) or
            ``'test'`` (reads ``x_seed``).
        iter_no: Step passed to ``add_histogram`` and formatted into the
            output file name.
        phase: String tag inserted into the histogram tag and file name.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'test'``.
    """
    # Fail early with a clear message instead of an UnboundLocalError on
    # ``data`` further down.
    if split not in ('train', 'test'):
        raise ValueError(
            "split must be 'train' or 'test', got %r" % (split,))

    mean0 = node.kmeans.means[0]
    mean1 = node.kmeans.means[1]

    # Unit vector pointing from mean0 to mean1.
    direction = (mean1 - mean0) / np.linalg.norm(mean1 - mean0)

    if split == 'train':
        data = dl_set[node.id].data['train']
    else:
        data = x_seed

    Z = node.post_gmm_encode(data)

    # Projection of every row onto the mean0 -> mean1 line, computed in one
    # vectorised step. Writing into a zeros buffer keeps the float64 dtype
    # the per-row loop used to produce.
    coefficients = np.dot(Z - mean0, direction)
    projection = np.zeros(Z.shape)
    projection[:] = mean0 + np.outer(coefficients, direction)

    for i in range(projection.shape[1]):
        axis_values = projection[:, i]
        # Three series: projected coordinate (green), mean0 (red),
        # mean1 (blue).
        plt.hist([axis_values, mean0[i], mean1[i]], color=['g', 'r', 'b'])

        fig_mean_axis_histogram = plt.gcf()
        try:
            node.trainer.writer[split].add_histogram(
                node.name + '_' + phase + '_mean_axis_' + str(i),
                axis_values, iter_no)
            path_mean_axis_hist = Paths.get_result_path(
                node.name + '_' + split + '_mean_axis_histogram/' + phase +
                '%03d_%01d' % (iter_no, i))
            fig_mean_axis_histogram.savefig(path_mean_axis_hist)
        finally:
            # Always release the figure so a failure does not leak it into
            # the next iteration's plot.
            plt.close(fig_mean_axis_histogram)
コード例 #4
0
def z_histogram_plot(node, split, iter_no, phase):
    """Plot and log one histogram per column of the node's encoded data.

    For every column ``i`` of ``node.post_gmm_encode(data)`` the values are
    drawn with ``plt.hist``, logged with
    ``node.trainer.writer[split].add_histogram`` and saved to the path
    returned by ``Paths.get_result_path``.

    Args:
        node: Object providing ``id``, ``name``, ``post_gmm_encode`` and
            ``trainer.writer``.
        split: ``'train'`` (reads ``dl_set[node.id].data['train']``) or
            ``'test'`` (reads ``x_seed``).
        iter_no: Step passed to ``add_histogram`` and formatted into the
            output file name.
        phase: String tag inserted into the histogram tag and file name.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'test'``.
    """
    # Fail early with a clear message instead of an UnboundLocalError on
    # ``data`` further down.
    if split not in ('train', 'test'):
        raise ValueError(
            "split must be 'train' or 'test', got %r" % (split,))

    with tr.no_grad():
        if split == 'train':
            data = dl_set[node.id].data[split]
        else:
            data = x_seed

        Z = node.post_gmm_encode(data)

        for i in range(Z.shape[1]):
            column = Z[:, i]
            plt.hist(column)

            fig_histogram = plt.gcf()
            try:
                node.trainer.writer[split].add_histogram(
                    node.name + '_' + phase + '_embedding_' + str(i),
                    column, iter_no)
                path_embedding_hist = Paths.get_result_path(
                    node.name + '_' + split + '_embedding_histogram/' +
                    phase + 'embedding_%03d_%01d' % (iter_no, i))
                fig_histogram.savefig(path_embedding_hist)
            finally:
                # Always release the figure so a failure does not leak it
                # into the next column's plot.
                plt.close(fig_histogram)
コード例 #5
0
def plot_cluster_graphs(node, split, threshold, iter_no, phase):
    """Plot three per-class bar charts describing a node's cluster assignment.

    The charts, each logged with ``add_figure`` and saved to disk, are:

    1. count of unassigned (purple) vs. assigned (green) samples per class,
       where a prediction of ``2`` counts as unassigned;
    2. per class, the fraction of assigned samples that fall in cluster 0
       (red) vs. cluster 1 (blue);
    3. per class, the count of samples in cluster 0 (red) / cluster 1 (blue)
       whose Mahalanobis distance to their cluster mean exceeds
       ``threshold``.

    Args:
        node: Object providing ``id``, ``name``, ``post_gmm_encode``,
            ``gmm_predict_test``, ``kmeans`` (``pred``, ``means``, ``covs``)
            and ``trainer.writer``.
        split: ``'train'`` (data and labels from ``dl_set[node.id]``,
            predictions from ``node.kmeans.pred``) or ``'test'`` (data and
            labels from ``x_seed`` / ``l_seed``, predictions from
            ``node.gmm_predict_test``).
        threshold: Forwarded to ``node.gmm_predict_test`` and used as the
            distance cut-off in the third chart.
        iter_no: Step passed to ``add_figure`` and formatted into the output
            file names.
        phase: String tag inserted into the figure tags and file names.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'test'``.
    """
    # Fail early with a clear message instead of an UnboundLocalError on
    # ``data`` / ``labels`` / ``p`` further down.
    if split not in ('train', 'test'):
        raise ValueError(
            "split must be 'train' or 'test', got %r" % (split,))

    no_of_classes = H.no_of_classes

    with tr.no_grad():
        if split == 'train':
            data = dl_set[node.id].data[split]
            labels = dl_set[node.id].labels[split]
        else:
            data = x_seed
            labels = l_seed

        Z = node.post_gmm_encode(data)

        if split == 'train':
            p = node.kmeans.pred
        else:
            p = node.gmm_predict_test(Z, threshold)

        # --- 1. unassigned (purple) vs. assigned (green) count per class ---
        unassigned_labels = [0] * no_of_classes
        assigned_labels = [0] * no_of_classes

        for i, cluster in enumerate(p):
            if cluster == 2:
                unassigned_labels[labels[i]] += 1
            else:
                assigned_labels[labels[i]] += 1

        _save_label_bar_pair(
            node, split, iter_no,
            unassigned_labels, assigned_labels, ('purple', 'green'), 'count',
            node.name + '_' + phase + '_assigned_labels_count',
            node.name + '_' + split + '_assigned/' + phase +
            ('assigned_%03d' % (iter_no)))

        # --- 2. share of assigned samples in cluster 0 (red) / 1 (blue) ---
        l_seed_ch0 = labels[np.where(p == 0)]
        l_seed_ch1 = labels[np.where(p == 1)]

        count_ch0 = [0] * no_of_classes
        count_ch1 = [0] * no_of_classes
        prob_ch0 = [0] * no_of_classes
        prob_ch1 = [0] * no_of_classes

        for label in l_seed_ch0:
            count_ch0[label] += 1

        for label in l_seed_ch1:
            count_ch1[label] += 1

        for i in range(no_of_classes):
            total = count_ch0[i] + count_ch1[i]
            # Classes with no assigned samples keep a share of 0 for both
            # clusters instead of dividing by zero.
            if total != 0:
                prob_ch0[i] = count_ch0[i] * 1.0 / total
                prob_ch1[i] = count_ch1[i] * 1.0 / total

        _save_label_bar_pair(
            node, split, iter_no,
            prob_ch0, prob_ch1, ('red', 'blue'), 'percentage',
            node.name + '_' + phase + '_confidence',
            node.name + '_' + split + '_confidence/' + phase +
            ('confidence_%03d' % (iter_no)))

        # --- 3. samples farther than ``threshold`` from their cluster mean ---
        # NOTE(review): scipy's ``distance.mahalanobis`` expects the INVERSE
        # covariance matrix as its third argument; confirm that
        # ``node.kmeans.covs`` really stores inverse covariances.
        aboveThresholdLabels_ch0 = [0] * no_of_classes
        aboveThresholdLabels_ch1 = [0] * no_of_classes

        for i, cluster in enumerate(p):
            if cluster == 0:
                if (distance.mahalanobis(Z[i], node.kmeans.means[0],
                                         node.kmeans.covs[0])) > threshold:
                    aboveThresholdLabels_ch0[labels[i]] += 1
            elif cluster == 1:
                if (distance.mahalanobis(Z[i], node.kmeans.means[1],
                                         node.kmeans.covs[1])) > threshold:
                    aboveThresholdLabels_ch1[labels[i]] += 1

        _save_label_bar_pair(
            node, split, iter_no,
            aboveThresholdLabels_ch0, aboveThresholdLabels_ch1,
            ('red', 'blue'), 'count',
            node.name + '_' + phase + '_above_threshold',
            node.name + '_' + split + '_above_threshold/' + phase +
            ('%03d' % (iter_no)))


def _save_label_bar_pair(node, split, iter_no, first, second, colors, ylabel,
                         tag, result_name):
    """Draw two side-by-side bar series, log the figure and save it.

    ``first`` and ``second`` hold one value per class. ``colors`` is a pair
    of bar colours for the two series. The figure is logged under ``tag``
    and saved to ``Paths.get_result_path(result_name)``. It is closed even
    when logging or saving fails.
    """
    bar_width = 0.3
    left = np.arange(len(first))
    right = [x + bar_width for x in left]

    plt.bar(left,
            first,
            width=bar_width,
            color=colors[0],
            edgecolor='black',
            capsize=7)
    plt.bar(right,
            second,
            width=bar_width,
            color=colors[1],
            edgecolor='black',
            capsize=7)
    plt.xticks([r + bar_width for r in range(len(first))],
               [str(i) for i in range(len(first))])
    plt.ylabel(ylabel)

    fig = plt.gcf()
    try:
        node.trainer.writer[split].add_figure(tag, fig, iter_no)
        fig.savefig(Paths.get_result_path(result_name))
    finally:
        plt.close(fig)
コード例 #6
0
def get_labels_distribution(node, split):
    """Save and plot how each class is split between a node's two children.

    Samples are encoded with ``node.post_gmm_encode`` and assigned with
    ``node.gmm_predict``. The labels falling in child 0 and child 1 are
    written with ``np.savez`` to ``'<node.name>_<split>_child_labels'``.
    Two per-class bar charts (red = child 0, blue = child 1) are then logged
    with ``add_figure`` and saved to disk: the share of each class per child
    and the raw counts. Both charts use step 0.

    Args:
        node: Object providing ``id``, ``name``, ``post_gmm_encode``,
            ``gmm_predict`` and ``trainer.writer``.
        split: ``'train'`` (data and labels from ``dl_set[node.id]``) or
            ``'test'`` (data and labels from ``x_seed`` / ``l_seed``).

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'test'``.
    """
    # Fail early with a clear message instead of an UnboundLocalError on
    # ``data`` / ``labels`` further down.
    if split not in ('train', 'test'):
        raise ValueError(
            "split must be 'train' or 'test', got %r" % (split,))

    iter_no = 0
    no_of_classes = H.no_of_classes
    bar_width = 0.3

    def save_bar_pair(first, second, ylabel, tag, result_name):
        # Draw child 0 (red) next to child 1 (blue), log the figure and
        # save it. The figure is closed even when logging or saving fails.
        left = np.arange(len(first))
        right = [x + bar_width for x in left]

        plt.bar(left,
                first,
                width=bar_width,
                color='red',
                edgecolor='black',
                capsize=7)
        plt.bar(right,
                second,
                width=bar_width,
                color='blue',
                edgecolor='black',
                capsize=7)
        plt.xticks([r + bar_width for r in range(len(first))],
                   [str(i) for i in range(no_of_classes)])
        plt.ylabel(ylabel)

        fig = plt.gcf()
        try:
            node.trainer.writer[split].add_figure(tag, fig, iter_no)
            fig.savefig(Paths.get_result_path(result_name))
        finally:
            plt.close(fig)

    with tr.no_grad():
        if split == 'train':
            data = dl_set[node.id].data[split]
            labels = dl_set[node.id].labels[split]
        else:
            data = x_seed
            labels = l_seed

        Z = node.post_gmm_encode(data)
        pred = node.gmm_predict(Z)

        labels_ch0 = labels[np.where(pred == 0)]
        labels_ch1 = labels[np.where(pred == 1)]

        # NOTE(review): the archive name is not routed through
        # Paths.get_result_path like the figures are -- confirm that is
        # intended.
        np.savez(node.name + '_' + split + '_child_labels',
                 labels_ch0=labels_ch0,
                 labels_ch1=labels_ch1)

        count_ch0 = [0] * no_of_classes
        count_ch1 = [0] * no_of_classes
        prob_ch0 = [0] * no_of_classes
        prob_ch1 = [0] * no_of_classes

        for label in labels_ch0:
            count_ch0[label] += 1

        for label in labels_ch1:
            count_ch1[label] += 1

        for i in range(no_of_classes):
            total = count_ch0[i] + count_ch1[i]
            # Classes absent from both children keep a share of 0 instead of
            # dividing by zero.
            if total != 0:
                prob_ch0[i] = count_ch0[i] * 1.0 / total
                prob_ch1[i] = count_ch1[i] * 1.0 / total

        save_bar_pair(
            prob_ch0, prob_ch1, 'percentage',
            node.name + '_labels_prob',
            node.name + '_' + split +
            ('_labels_distribution/probability_%03d' % (iter_no)))

        save_bar_pair(
            count_ch0, count_ch1, 'count',
            node.name + '_labels_distribution',
            node.name + '_' + split +
            ('_labels_distribution/count_%03d' % (iter_no)))