def l_method(ax, X, method):
    """Plot ``X`` on ``ax``, one colour per cluster suggested by the L-method.

    ``agglomerative_l_method`` is fitted on ``X``. The number of clusters is
    taken from ``len(model.cluster_centers_)``. Every sample whose entry in
    ``model.labels_`` equals a cluster index is drawn in that cluster's colour.

    Args:
        ax: axes object forwarded to the ``plot`` helper.
        X: samples, paired element-wise with ``model.labels_``.
        method: forwarded as ``method=`` to ``agglomerative_l_method``.

    Returns:
        None; drawing happens as a side effect on ``ax``.
    """
    model = agglomerative_l_method(X, method=method)
    suggest_n = len(model.cluster_centers_)
    cmap = get_cmap(suggest_n + 1)
    for label in range(suggest_n):
        # Samples assigned to this cluster, in their original order.
        members = [x for x, lbl in zip(X, model.labels_) if lbl == label]
        # ``plot`` must resolve to the module-level helper (assumed to be
        # defined elsewhere in this file). Do not bind any local named
        # ``plot`` in this function: that would make this call raise
        # UnboundLocalError.
        plot(ax, members, c=cmap(label), edgecolors='none')


def _plot_sorted_results(ax, sort_fn, name=''):
    """Plot the evaluation curves stored in the global ``result``.

    The rows of ``result`` are reordered with ``sort_fn`` and drawn on ``ax``.
    ``sort_fn`` receives one dict per row, keyed like ``data`` below. Three
    kinds of curve are drawn:

    * the k-means accuracy;
    * the naive badness;
    * one hierarchical-voronoid-filling curve per sigmoid.

    The x tick labels are shortened row names.

    Deliberately not named ``plot``: that name belongs to the helper used by
    ``l_method``.

    Reads the module-level names ``result``, ``dataset``, ``avgs``,
    ``get_cmap``, ``decreasing_penalty`` and ``plt``. ``avgs[i]`` is
    incremented by ``decreasing_penalty`` of the i-th sigmoid's curve.

    Args:
        ax: matplotlib axes to draw on.
        sort_fn: ``key`` function for ``list.sort`` over the row dicts.
        name: accepted for signature compatibility; unused.
    """
    data = {
        'acc_kmeans_3':
            [row[0] / row[1] for row in result['evaluation_kmeans_3']],
        'badness_naive':
            [row['md'] for row in result['badness_naive']],
        'badness_hierarchical_voronoid_filling':
            result['badness_hierarchical_voronoid_filling'],
        'voronoid_sigmoid': result['voronoid_sigmoid'],
        'names': result['name'],
    }

    # sigmoids are the same for all rows
    sigmoids = data['voronoid_sigmoid'][0]

    # Transpose to one dict per row and sort the rows. Then transpose back to
    # one tuple per column, in the sorted row order.
    keys = list(data)
    rows = [dict(zip(keys, values)) for values in zip(*data.values())]
    rows.sort(key=sort_fn)
    col = {key: tuple(row[key] for row in rows) for key in keys}

    cnt = len(col['names'])
    x = range(cnt)

    # Colour is given only via ``color=``. Putting a colour letter in the fmt
    # string as well triggers matplotlib's "redundantly defined" warning.
    ax.plot(x, col['acc_kmeans_3'], '--', color='black', label='acc c*3')
    ax.plot(x, col['badness_naive'], '-', color='grey', label='naive')

    # Build one series per sigmoid from the *sorted* rows. Using the sorted
    # rows keeps these curves aligned with ``x`` and the tick labels, like the
    # two curves above.
    hvf_by_sigmoid = [[] for _ in sigmoids]
    for row in col['badness_hierarchical_voronoid_filling']:
        for i, value in enumerate(row):
            hvf_by_sigmoid[i].append(value)

    cmap = get_cmap(len(sigmoids) + 1)

    print('dataset:', dataset)
    for i, (sigmoid, each_hvf) in enumerate(zip(sigmoids, hvf_by_sigmoid)):
        ax.plot(x, each_hvf, '-', color=cmap(i), label=sigmoid)
        penalty = decreasing_penalty(each_hvf)
        if i == 2:
            # the best now
            print('penalty:', penalty)
        avgs[i] += penalty

    # remove y axis tick labels
    ax.yaxis.set_major_formatter(plt.NullFormatter())

    # scale y to [0, 1]
    ax.set_ylim([0, 1])

    plt.sca(ax)
    plt.title(dataset.replace('_with_test', ''))

    # increase space between rows
    plt.subplots_adjust(hspace=.5)

    # Shorten the row names used as x tick labels.
    col_names = []
    for col_name in col['names']:
        if col_name.startswith('some'):
            col_name = col_name.replace('some-', '').split('-prob')[0]
        elif col_name.startswith('prob'):
            col_name = col_name.replace('prob-', '')
        col_names.append(col_name)
    plt.xticks(range(cnt), col_names, rotation=90)
Beispiel #3
0
	# NOTE(review): the header of this function is missing from this chunk.
	# `output`, `holder` and `p2n` are presumably defined in the lost part;
	# confirm against the original source.
	# Map each element of `output` through the `p2n` lookup and store the
	# result in `holder` at the same (i, j) position. Only the first two axes
	# of `output` are iterated.
	for i in range(output.shape[0]):
		for j in range(output.shape[1]):
			holder[i, j] = p2n[output[i, j]]

	return holder

def performance(actual, predicted):
	"""Return ``np.array([accuracy, precision, recall])`` for a prediction.

	Both inputs are flattened with ``ravel()`` before scoring. Precision and
	recall use ``average='weighted'`` over ``np.unique(actual)[1:]``. That is
	every label present in ``actual`` except the smallest one (presumably a
	background label — confirm).
	"""
	y_true = actual.ravel()
	y_pred = predicted.ravel()
	# The same label subset is used for both weighted scores.
	scored_labels = np.unique(actual)[1:]
	scores = [
		accuracy_score(y_true, y_pred),
		precision_score(y_true, y_pred, average='weighted', labels=scored_labels),
		recall_score(y_true, y_pred, average='weighted', labels=scored_labels),
	]
	return np.array(scores)

if __name__ == "__main__":
	# Script entry point. Loops over test subjects 1..20 and loads each
	# subject's volumes. This chunk ends part-way through that loop.

	# NOTE(review): unused in the visible part of this block; presumably
	# consumed further down — confirm.
	slide_cam_labelled_path = 'tiff/slide_cam_2_extra_slices/OASIS-TRT-20-'
	# NOTE(review): get_cmap() is called here with no argument and unpacked
	# into two values. Earlier in this file it is called as get_cmap(n) and
	# used as a callable. Confirm both refer to the same helper.
	n2id, cmap = get_cmap()
	# Inverse of n2id (value -> key).
	id2n = {v: k for k, v in n2id.items()}
	print(len(n2id))
	# Subject IDs 1 through 20 inclusive.
	test_sub_IDs = range(1, 21)

	# One row per test subject and 3 columns. Presumably the columns hold
	# the (accuracy, precision, recall) array returned by performance() —
	# confirm.
	fs_perf = np.zeros((len(test_sub_IDs), 3))
	sc_perf = np.zeros((len(test_sub_IDs), 3))

	for j, sub_ID in enumerate(test_sub_IDs):
		# subject_volumes returns three objects, unpacked as the unlabelled,
		# manually labelled and FS-labelled volumes. Their types are not
		# visible here.
		unlabelled, man_labelled, FS_labelled = subject_volumes(sub_ID, '')
		# NOTE(review): if these are nibabel images, get_data() is deprecated
		# (removed in nibabel 5.0). np.asanyarray(img.dataobj) keeps the
		# on-disk dtype; get_fdata() returns floats. Confirm.
		uld = unlabelled.get_data()
		mld = man_labelled.get_data()

		# Number of distinct values in the manual labelling.
		print(len(np.unique(mld)))

# 		fsld = FS_labelled.get_data()