def eval(self, set_name=None):
        """Evaluate the predicted connected segmentation on one or all subsets.

        For each subset, the ground-truth labels are read from
        ``self.exp_cfg.dataset`` and compared with the output of
        ``self.predict`` using ``adapted_rand`` and ``voi``.

        Args:
            set_name: Name of a single subset to evaluate. When falsy, every
                subset listed in ``self.exp_cfg.dataset.subset`` is evaluated.

        Returns:
            tuple(dict, dict): ``(arand_conn_eval, voi_conn_eval)``, both keyed
            by subset name.
        """
        arand_conn_eval = {}
        voi_conn_eval = {}
        eval_sets = [set_name] if set_name else self.exp_cfg.dataset.subset

        print('subset data = {}'.format(eval_sets))
        for d_set in eval_sets:
            print('predicting {} ...'.format(d_set))
            self.exp_cfg.dataset.set_current_subDataset(d_set)
            seg_lbs = self.exp_cfg.dataset.get_label()
            preds = self.predict(d_set)
            # NOTE(review): assumes self.predict returns the 3D connected
            # segmentation, or a (connected, watershed) pair whose first item
            # is it, with the same shape as seg_lbs -- TODO confirm.
            seg_conn = preds[0] if isinstance(preds, tuple) else preds
            arand_conn_eval[d_set] = adapted_rand(seg_conn, seg_lbs)
            voi_conn_eval[d_set] = voi(seg_conn, seg_lbs)
            print('arand,voi conn for {} = {},{}'.format(
                d_set, arand_conn_eval[d_set], voi_conn_eval[d_set]))

        return arand_conn_eval, voi_conn_eval
# Example #2
def segment_all_and_makeSubmission(hd5_file, data_name_for_seg='distance',
                                   im_dir=None,
                                   task_set='test',
                                   save_seg_to_img=False,
                                   rep_bad_slice=True,
                                   evaluation=True):
    """Segment every subset in ``set_names``; optionally evaluate and save it.

    Args:
        hd5_file: Source passed through to ``segment``.
        data_name_for_seg: Name of the data used for segmentation; also
            forwarded to ``save_seg3D_to_image``.
        im_dir: Prefix of the output image directories. Required when
            ``save_seg_to_img`` is True; one directory
            ``im_dir + '_' + <subset>`` is created per subset.
        task_set: Split name forwarded to ``get_rawim_or_labels``.
        save_seg_to_img: If True, write each segmentation as images.
        rep_bad_slice: If True, raw images go through ``replace_bad_slice``
            first; the flag is also forwarded to ``segment``.
        evaluation: If True and ``task_set == 'valid'``, print adapted Rand
            and VOI (split, merge) against the ground-truth labels.

    Returns:
        dict: subset name -> 3D segmentation returned by ``segment``.

    Raises:
        ValueError: if ``save_seg_to_img`` is True but ``im_dir`` is empty.
    """
    # Validate up-front so a missing output directory is reported before any
    # segmentation work is done (previously an assert, stripped under -O).
    if save_seg_to_img and not im_dir:
        raise ValueError('im_dir must be given when save_seg_to_img is True')

    seg_dict = {}
    for set_n in set_names:
        raw_im = get_rawim_or_labels(task_set=task_set,
                                     subset_name=set_n,
                                     data_set='image')
        if rep_bad_slice:
            raw_im = replace_bad_slice(raw_im, set_n)
        seg3D = segment(hd5_file, set_n, raw_im, data_name_for_seg,
                        rep_bad_slice)

        if evaluation and task_set == 'valid':
            gt_label = get_rawim_or_labels(task_set=task_set,
                                           subset_name=set_n,
                                           data_set='label')
            # Compare only as many slices as were segmented.
            gt_label = gt_label[0:len(seg3D)]
            arand = adapted_rand(seg3D, gt_label)
            split, merge = voi(seg3D, gt_label)
            print('arand {} ,(split,merge) =({},{})'.format(
                arand, split, merge))
        seg_dict[set_n] = seg3D

        if save_seg_to_img:
            im_save_dir = im_dir + '_' + set_n
            os.makedirs(im_save_dir, exist_ok=True)
            save_seg3D_to_image(seg3D, im_save_dir, data_name_for_seg)

    return seg_dict
def evaluation(subset_name):
    """Score the saved 'valid' segmentation of ``subset_name`` against its labels.

    Prints adapted Rand and VOI (split, merge).

    Args:
        subset_name: Subset whose prediction and labels are read from the
            'valid' split.

    Returns:
        tuple: ``(arand, split, merge)``.
    """
    seg = read_predict_segmentation('valid', subset_name)
    _, lbs = read_raw_image_data('valid', subset_name)
    # np.squeeze returns a new array; the original call discarded its result,
    # so any singleton axes were never removed.
    seg = np.squeeze(seg)
    arand = adapted_rand(seg, lbs)
    split, merge = voi(seg, lbs)

    print('{} --- arand : (split,merge) {}:({},{})'.format(subset_name, arand, split, merge))
    return arand, split, merge


if __name__ == '__main__':
    # Sweep watershed thresholds on one subset and plot adapted Rand vs threshold.
    dataset = 'Set_A'
    z_start = 100  # only slices from index 100 onward are segmented/scored
    thresholds = np.linspace(16, 35, 15)
    arands = []

    # Keep all reads inside the context manager so the HDF5 file is closed.
    with h5py.File('tempdata/seg_final_plus_distance.h5', 'r') as hf5:
        _d_orig, d_combine, tg = get_data(hf5, dataset)
        t = tg[z_start:, :, :]
        # Read the input volume once instead of on every threshold.
        combine = d_combine[z_start:, :, :]

        print('test {}'.format(dataset))
        for th in thresholds:
            # np.int was removed in NumPy 1.24; builtin int is the dtype it aliased.
            d_seg = watershed_seg(combine, threshold=th).astype(int)
            arand = adapted_rand(d_seg, t)
            split, merge = voi(d_seg, t)
            arands.append(arand)
            print('arand, split, merge = {:.3f}, {:.3f}, {:.3f} for threshold = {:.3f}'.format(arand, split, merge, th))

    plt.plot(thresholds, arands)
    plt.xlabel('threshold')
    plt.ylabel('arand')
    # Avoid the doubled prefix ('Set_Set_A') when dataset already has it.
    plt.title(dataset if dataset.startswith('Set_') else 'Set_' + dataset)
    plt.show()



# Example #5
def evaluate(seg_gt, seg_pred):
    """Print adapted Rand and VOI (merge, split) of ``seg_pred`` vs ``seg_gt``.

    Args:
        seg_gt: Ground-truth label volume.
        seg_pred: Predicted label volume; cast to integer labels first.

    Returns:
        tuple: ``(arand, split, merge)``.
    """
    # np.int was removed in NumPy 1.24; builtin int is the dtype it aliased.
    # Cast once and reuse for both metrics.
    seg_int = seg_pred.astype(int)
    arand = adapted_rand(seg_int, seg_gt)
    split, merge = voi(seg_int, seg_gt)
    print('rand , voi (Merg, Split) ={:.3f}, ({:.3f},{:.3f})'.format(arand, merge, split))
    return arand, split, merge
def test_slice_connector_on_GTdata(net_model, image, gt_label, test_z_len=10):
    """Score a slice-connector network on ground-truth 2D objects.

    The 3D ground truth is re-labelled into per-slice 2D object ids. For each
    object, the network's probability of connecting it to each object in the
    next slice is written (symmetrically) into the adjacency matrix. The
    matrix is then converted back into a 3D labelling and compared with the
    ground truth on the first ``test_z_len`` slices.

    Args:
        net_model: Network passed to ``compute_next_masked_slice_connect_probs``.
        image: Image volume indexed by slice on its first axis; 127.0 is
            subtracted before it is fed to the network.
        gt_label: 3D ground-truth label volume, slices on the first axis.
        test_z_len: Number of leading slices to connect and score
            (default 10, the previously hard-coded value).

    Returns:
        tuple: ``(arand, split, merge)`` on the first ``test_z_len`` slices.
    """
    z_slice = len(gt_label)

    gt_slice_diff_lable = build_2Dslice_ids_from_3Dseg(gt_label)
    adj_matrix, slice_ids_and_counts, (flat_obj_ids_list, _flat_obj_ids_counts) \
        = build_adjecency_objMatrix(gt_slice_diff_lable)

    # Map each flat object id to its row/column index in adj_matrix.
    ids_idx_dict = {sid: idx for idx, sid in enumerate(flat_obj_ids_list)}

    # Every step reads slice_idx + 1, so stop at the second-to-last slice;
    # previously a volume with <= test_z_len slices raised IndexError.
    n_pairs = max(0, min(test_z_len, z_slice - 1))
    for slice_idx in range(n_pairs):  # first axis is assumed to be z
        gt_slice = gt_slice_diff_lable[slice_idx]
        next_slice = gt_slice_diff_lable[slice_idx + 1]
        im_input = image[slice_idx:slice_idx + 2] - 127.0
        for s_id, _count in slice_ids_and_counts[slice_idx]:
            mask = (gt_slice == s_id).astype(int)
            probs, ids = compute_next_masked_slice_connect_probs(net_model,
                                                                 mask,
                                                                 s_id,
                                                                 im_input,
                                                                 next_slice)
            # Column 1 of the network output is stored as the connection
            # probability, in both directions of the adjacency matrix.
            for prob, next_id in zip(probs.data.cpu().numpy()[:, 1], ids):
                adj_matrix[ids_idx_dict[s_id], ids_idx_dict[next_id]] = prob
                adj_matrix[ids_idx_dict[next_id], ids_idx_dict[s_id]] = prob
        print('sclie {} of {}'.format(slice_idx, z_slice - 1))

    connected_3D_seglabel = adj_matrix_to_3D_segLabel(adj_matrix,
                                                      gt_slice_diff_lable,
                                                      flat_obj_ids_list)

    # np.int was removed in NumPy 1.24; builtin int is the dtype it aliased.
    pred = connected_3D_seglabel.astype(int)[0:test_z_len]
    gt = gt_label[0:test_z_len]
    arand = adapted_rand(pred, gt)
    split, merge = voi(pred, gt)

    print('arand : {} (split, merge) : ({},{})'.format(arand, split, merge))
    return arand, split, merge
        net_model = build_network(model_file)
        adj_matrix_prob,gt_slice_diff_lable,flat_obj_ids_list= compute_adj_matrix(net_model, data, seg_label,use_true_label=False)
        #test_slice_connector_on_GTdata(net_model, data, seg_label)

    

    '''connecting'''
    
    #connected_3D_seglabel=adj_matrix_to_3D_segLabel((adj_matrix_prob>prob_threshold).astype(np.int),gt_slice_diff_lable,flat_obj_ids_list)
    prob_threshold =0.8
    connected_3D_seglabel=adj_matrix_to_3D_segLabel(adj_matrix_prob,
                                                   gt_slice_diff_lable,
                                                   flat_obj_ids_list,
                                                   threshold=prob_threshold)

    arand = adapted_rand(connected_3D_seglabel.astype(np.int)[0:test_z_len], seg_label[0:test_z_len])
    split, merge = voi(connected_3D_seglabel.astype(np.int)[0:test_z_len], seg_label[0:test_z_len])

    print('prob_threshodl : {} for {},  arand : {} (split, merge) : ({},{})'.format(prob_threshold, set_name, arand,split,merge))

    # gt_slice_diff_lable = build_2Dslice_ids_from_3Dseg(seg_label)
    # adj_matrix, slice_ids_and_count, (obj_ids_list,obj_ids_counts) \
    #                     = build_adjecency_objMatrix(gt_slice_diff_lable)
    # ids_idx_dict = {sid:idx for idx,sid in enumerate(obj_ids_list)}



    # thresholds = np.linspace(16,35,15)
    # arands = []
    # print ('test {}'.format(dataset))
    # for th in thresholds:
print('dt = {}'.format(distance.shape))
# Gradient of the first slice (index 0 along axis 0) of the distance map.
gx, gy = np.gradient(distance[0], 1, 1, edge_order=1)

# np.int was removed in NumPy 1.24; builtin int is the dtype it aliased.
arand = adapted_rand(seg_labels.astype(int), gt_seg)
split, merge = voi(seg_labels.astype(int), gt_seg)

print('rand , voi Merg, Split ={:.3f}, ({:.3f},{:.3f})'.format(
    arand, merge, split))
slice_idx = 0

fig, axs = plt.subplots(2, 2, figsize=(8, 8))
axs[0, 0].imshow(seg_labels[slice_idx],
                 cmap=plt.cm.spectral,
                 interpolation='nearest')
axs[0, 0].title.set_text('seg_waterseg')
axs[0, 0].axis('off')

axs[0, 1].imshow(gt_seg[slice_idx],
                 cmap=plt.cm.spectral,