def augment_random(image, label):
    """Randomly augment a 2D multi-channel image together with its mask.

    Image and label are stacked along the channel axis so every spatial
    transform (rotation, shift, shear, zoom, flips) is applied with the same
    random parameters to both. Afterwards the label is re-discretised:

    * multi-class labels (``label.shape[2] > 1``): pixels whose class vector
      sums to zero (e.g. regions filled with ``cval=0.`` by the transforms)
      are assigned to class 0, then every pixel is re-encoded as the one-hot
      vector of its ``oneHot2LabelMax`` class.
    * single-channel (binary) labels: the channel is thresholded at 0.5.
      The one-hot repair above cannot be used there: it would write ``[1]``
      into every background pixel and ``np.eye(1)`` has a single row, so the
      whole mask would become foreground.

    Args:
        image: array indexed ``[x, y, slices*modalities]``.
        label: array indexed ``[x, y, classes]``.

    Returns:
        ``(image, label)``; spatial size and channel counts are unchanged.
    """
    # Elastic deformation runs first, on roughly 1 in 5 calls.
    if np.random.randint(0, 5) == 3:
        image, label = elasticDeform2D(image, label, alpha=720, sigma=24, is_random=True)

    num_sm = image.shape[2]  # slices * modalities
    num_classes = label.shape[2]

    # Stack image and mask so the transforms below move them together.
    img_and_mask = np.zeros((image.shape[0], image.shape[1], num_sm + num_classes))
    img_and_mask[:, :, :num_sm] = image
    img_and_mask[:, :, num_sm:] = label

    # Each transform fires independently on roughly 1 in 20 calls.
    if np.random.randint(0, 20) == 7:
        img_and_mask = random_rotation(img_and_mask, rg=180, row_axis=0, col_axis=1, channel_axis=2,
                                       fill_mode='constant', cval=0.)
    if np.random.randint(0, 20) == 7:
        img_and_mask = random_shift(img_and_mask, wrg=0.2, hrg=0.2, row_axis=0, col_axis=1, channel_axis=2,
                                    fill_mode='constant', cval=0.)
    if np.random.randint(0, 20) == 7:
        img_and_mask = random_shear(img_and_mask, intensity=16, row_axis=0, col_axis=1, channel_axis=2,
                                    fill_mode='constant', cval=0.)
    if np.random.randint(0, 20) == 7:
        # NOTE(review): the zoom range is degenerate, so the factor is always 0.8;
        # widen it (e.g. (0.8, 1.2)) if a random zoom was intended.
        img_and_mask = random_zoom(img_and_mask, zoom_range=(0.8, 0.8), row_axis=0, col_axis=1, channel_axis=2,
                                   fill_mode='constant', cval=0.)
    if np.random.randint(0, 20) == 7:
        img_and_mask = flip_axis(img_and_mask, axis=1)
    if np.random.randint(0, 20) == 7:
        img_and_mask = flip_axis(img_and_mask, axis=0)

    image = img_and_mask[:, :, :num_sm]
    label = img_and_mask[:, :, num_sm:]

    if num_classes == 1:
        # Binary mask: interpolation may have produced fractional values.
        label = (label > 0.5).astype(label.dtype)
        return image, label

    # Pixels left without any class (zero-filled by the transforms) become background.
    default_one_hot_label = [0] * num_classes
    default_one_hot_label[0] = 1
    label[np.where(np.sum(label, axis=-1) == 0)] = default_one_hot_label
    # Integer cast so the indices are valid for the np.eye lookup even if
    # oneHot2LabelMax returns a float array.
    label = np.asarray(oneHot2LabelMax(label)).astype(np.int64)
    label = np.eye(num_classes)[label]
    return image, label
def generate_train_batches(root_path, train_list, net_input_shape, net, batchSize=1, numSlices=1,
                           subSampAmt=-1, stride=1, downSampAmt=1, shuff=1, aug_data=1,
                           dataset='brats', num_output_classes=2):
    """Yield training batches forever, slice by slice, from the scans in ``train_list``.

    Each scan is loaded from ``<root_path>/np_files/<name>npz`` (keys ``img`` and
    ``mask``). If that fails, the dataset's numpy converter is called to produce
    the arrays; a converter result equal to ``np.zeros(1)`` skips the scan.
    For each candidate slice index ``j``, the network input is built from slices
    ``j - numSlices // 2 .. j + numSlices // 2``. Each slice contributes
    ``modalities`` channels, and neighbours outside the volume stay zero.
    Slices whose mask has no foreground (sum < 1 over the non-background
    channels, or over the single channel for binary masks) are skipped.

    Args:
        root_path: dataset root directory.
        train_list: rows whose first element is the scan file name; shuffled in
            place each pass when ``shuff`` is truthy.
        net_input_shape: ``(x, y, numSlices * modalities)``.
        net: network name. If it contains ``'caps'``, batches also carry the
            capsule mask input and masked-image target.
        batchSize: slices per yielded batch.
        numSlices: adjacent slices stacked into the input channels.
        subSampAmt, downSampAmt: unused here; kept for interface compatibility.
        stride: step between candidate slice indices.
        shuff: shuffle scans and slice order every pass when truthy.
        aug_data: pass every slice through ``augment_random`` when truthy.
        dataset: ``'brats'`` or one of ``'heart'``, ``'spleen'``, ``'colon'``, ``'hepatic'``.
        num_output_classes: mask channels; ``1`` means binary segmentation.

    Yields:
        ``(img_batch, mask_batch)``, or for capsule networks
        ``([img_batch, mask_input], [mask_batch, masked_img])``. A shorter final
        batch is yielded at the end of a pass when the kept slice count is not a
        multiple of ``batchSize``. Yielded arrays are copies, so batches queued
        by a background enqueuer are not overwritten while the next batch is filled.

    Raises:
        AssertionError: if ``dataset`` is not recognised.
    """
    print('train ' + str(dataset))
    modalities = net_input_shape[2] // numSlices
    input_slices = numSlices
    # Neighbouring slices on each side of the centre slice. Floor division keeps
    # this an int on Python 3 (``/`` produced a float and broke range()/slicing).
    # For both odd and even numSlices it equals the old decrement-then-halve result.
    side_slices = numSlices // 2
    is_binary_classification = num_output_classes == 1

    # Reusable placeholders, filled in place for every batch.
    img_batch = np.zeros((np.concatenate(((batchSize, ), net_input_shape))), dtype=np.float32)
    mask_shape = [net_input_shape[0], net_input_shape[1], num_output_classes]
    mask_batch = np.zeros((np.concatenate(((batchSize, ), mask_shape))), dtype=np.float32)

    if dataset == 'brats':
        np_converter = convert_brats_data_to_numpy
        # BraTS slices are written 8 pixels in from every border of the input.
        frame_pixels_0 = 8
        frame_pixels_1 = -8
        empty_mask = np.array([one_hot_max, 1 - one_hot_max, 1 - one_hot_max, 1 - one_hot_max])
        raw_x_shape = 240
        raw_y_shape = 240
    elif dataset in ['heart', 'spleen', 'colon', 'hepatic']:
        np_converter = get_np_converter(dataset)
        frame_pixels_0 = 0
        frame_pixels_1 = net_input_shape[0]
        if num_output_classes == 2:
            empty_mask = np.array([one_hot_max, 1 - one_hot_max])
        else:
            empty_mask = np.array([1 - one_hot_max])
        raw_x_shape = net_input_shape[0]
        raw_y_shape = net_input_shape[1]
    else:
        # Explicit raise so the check survives ``python -O``.
        raise AssertionError('Dataset not recognized')

    def caps_mask_and_target():
        """Return ``(mask_input, masked_img)`` for the capsule network.

        ``mask_input`` is ``oneHot2LabelMax(mask_batch)`` with values > 0.5 set to
        1 and a trailing channel axis added. Assuming integer class indices, this
        is a foreground indicator. ``masked_img`` is that mask repeated over
        ``modalities`` channels times the centre slice's input channels.
        """
        start_index = (input_slices // 2) * modalities
        img_batch_mid_slice = img_batch[:, :, :, start_index:start_index + modalities]
        mask_batch_masked = oneHot2LabelMax(mask_batch)
        mask_batch_masked[mask_batch_masked > 0.5] = 1.0  # Setting all other classes than background to mask
        mask_batch_masked = np.expand_dims(mask_batch_masked, axis=-1)
        masked_img = np.repeat(mask_batch_masked, modalities, axis=-1) * img_batch_mid_slice
        return mask_batch_masked, masked_img

    while True:
        if shuff:
            shuffle(train_list)
        count = 0
        for scan_name in train_list:
            try:
                scan_name = scan_name[0]
                path_to_np = join(root_path, 'np_files', basename(scan_name)[:-6] + 'npz')
                with np.load(path_to_np) as data:
                    train_img = data['img']
                    train_mask = data['mask']
            except Exception:
                # Pre-made npz missing or unreadable: build the arrays instead.
                train_img, train_mask = np_converter(root_path, scan_name, num_classes=num_output_classes)
                if np.array_equal(train_img, np.zeros(1)):
                    continue
                print('\nFinished making npz file.')

            z_shape = train_img.shape[2]
            indicies = np.arange(0, z_shape, stride)
            if shuff:
                shuffle(indicies)
            for j in indicies:
                if is_binary_classification:
                    foreground = np.sum(train_mask[:, :, j])
                else:
                    foreground = np.sum(train_mask[:, :, j, 1:])
                if foreground < 1:
                    continue  # nothing but background in this slice

                if img_batch.ndim == 4:
                    img_batch[count] = 0
                    next_img = train_img[:, :, max(j - side_slices, 0):min(j + side_slices + 1, z_shape)].reshape(
                        raw_x_shape, raw_y_shape, -1)
                    # Each neighbour k owns a fixed channel window, so slices
                    # outside the volume leave their window zero.
                    insertion_index = -modalities
                    img_index = 0
                    for k in range(j - side_slices, j + side_slices + 1):
                        insertion_index += modalities
                        if k < 0:
                            continue
                        if k >= z_shape:
                            break
                        img_batch[count, frame_pixels_0:frame_pixels_1, frame_pixels_0:frame_pixels_1,
                                  insertion_index:insertion_index + modalities] = \
                            next_img[:, :, img_index:img_index + modalities]
                        img_index += modalities
                    # Frame pixels outside the scan get the background label.
                    mask_batch[count] = empty_mask
                    mask_batch[count, frame_pixels_0:frame_pixels_1, frame_pixels_0:frame_pixels_1, :] = \
                        train_mask[:, :, j]
                else:
                    print('\nError this function currently only supports 2D and 3D data.')
                    exit(0)

                if aug_data:
                    img_batch[count], mask_batch[count] = augment_random(img_batch[count], mask_batch[count])

                count += 1
                if count % batchSize == 0:
                    count = 0
                    if debug and img_batch.ndim == 4:
                        debug_views = (
                            (img_batch, 0, 'ex{}_train_slice1.png', {'cmap': 'gray'}),
                            (img_batch, 4, 'ex{}_train_slice2.png', {'cmap': 'gray'}),
                            (img_batch, 8, 'ex{}_train_slice3_main.png', {'cmap': 'gray'}),
                            (mask_batch, 0, 'ex{}_train_label.png', {'alpha': 0.15}),
                            (img_batch, 12, 'ex{}_train_slice4.png', {'cmap': 'gray'}),
                            (img_batch, 16, 'ex{}_train_slice5.png', {'cmap': 'gray'}),
                        )
                        for source, channel, file_name, imshow_kwargs in debug_views:
                            plt.imshow(np.squeeze(source[0, :, :, channel]), **imshow_kwargs)
                            plt.savefig(join(root_path, 'logs', file_name.format(j)), format='png',
                                        bbox_inches='tight')
                            plt.close()
                    if net.find('caps') != -1:  # capsule/segcaps structure
                        mask_input, masked_img = caps_mask_and_target()
                        yield ([img_batch.copy(), mask_input], [mask_batch.copy(), masked_img])
                    else:
                        yield (img_batch.copy(), mask_batch.copy())

        if count != 0:
            # Partial final batch of this pass.
            if net.find('caps') != -1:  # TODO: This is not correct for several slices
                mask_input, masked_img = caps_mask_and_target()
                # Same mask input as the full batches above (the tail batch used
                # to feed ``1 - mask`` instead). NOTE(review): generate_val_batches
                # feeds ``1 - mask`` for this input; confirm which one the model expects.
                yield ([img_batch[:count].copy(), mask_input[:count]],
                       [mask_batch[:count].copy(), masked_img[:count]])
            else:
                yield (img_batch[:count].copy(), mask_batch[:count].copy())
def generate_val_batches(root_path, val_list, net_input_shape, net, batchSize=1, numSlices=1,
                         subSampAmt=-1, stride=1, downSampAmt=1, shuff=1, dataset='brats',
                         num_output_classes=2):
    """Yield validation batches forever, slice by slice, from the scans in ``val_list``.

    Each scan is loaded from ``<root_path>/np_files/<name>npz`` (keys ``img`` and
    ``mask``). If that fails, the dataset's numpy converter is called to produce
    the arrays; a converter result equal to ``np.zeros(1)`` skips the scan.
    Every slice index ``j`` (step ``stride``) is used. There is no augmentation
    and no empty-slice filtering. The input holds slices
    ``j - numSlices // 2 .. j + numSlices // 2``, each contributing ``modalities``
    channels; neighbours outside the volume stay zero.

    Args:
        root_path: dataset root directory.
        val_list: rows whose first element is the scan file name; shuffled in
            place each pass when ``shuff`` is truthy.
        net_input_shape: ``(x, y, numSlices * modalities)``.
        net: network name. If it contains ``'caps'``, batches also carry the
            capsule mask input and masked-image target.
        batchSize: slices per yielded batch.
        numSlices: adjacent slices stacked into the input channels.
        subSampAmt, downSampAmt: unused here; kept for interface compatibility.
        stride: step between slice indices.
        shuff: shuffle scans and slice order every pass when truthy.
        dataset: ``'brats'`` or one of ``'heart'``, ``'spleen'``, ``'colon'``, ``'hepatic'``.
        num_output_classes: mask channels.

    Yields:
        ``(img_batch, mask_batch)``, or for capsule networks
        ``([img_batch, 1 - mask_input], [mask_batch, masked_img])``. A shorter final
        batch is yielded at the end of a pass when needed. Yielded arrays are
        copies, so queued batches are not overwritten by the next one.

    Raises:
        AssertionError: if ``dataset`` is not recognised.
    """
    # Create placeholders for validation
    modalities = net_input_shape[2] // numSlices
    input_slices = numSlices
    # Floor division keeps this an int on Python 3 (``/`` produced a float and
    # broke range()/slicing); equals the old decrement-then-halve result.
    side_slices = numSlices // 2
    img_batch = np.zeros((np.concatenate(((batchSize, ), net_input_shape))), dtype=np.float32)
    mask_shape = [net_input_shape[0], net_input_shape[1], num_output_classes]
    mask_batch = np.zeros((np.concatenate(((batchSize, ), mask_shape))), dtype=np.float32)

    if dataset == 'brats':
        np_converter = convert_brats_data_to_numpy
        # BraTS slices are written 8 pixels in from every border of the input.
        frame_pixels_0 = 8
        frame_pixels_1 = -8
        empty_mask = np.array([one_hot_max, 1 - one_hot_max, 1 - one_hot_max, 1 - one_hot_max])
        raw_x_shape = 240
        raw_y_shape = 240
    elif dataset in ['heart', 'spleen', 'colon', 'hepatic']:
        np_converter = get_np_converter(dataset)
        frame_pixels_0 = 0
        frame_pixels_1 = net_input_shape[0]
        if num_output_classes == 2:
            empty_mask = np.array([one_hot_max, 1 - one_hot_max])
        else:
            empty_mask = np.array([1 - one_hot_max])
        raw_x_shape = net_input_shape[0]
        raw_y_shape = net_input_shape[1]
    else:
        # Explicit raise so the check survives ``python -O``.
        raise AssertionError('Dataset not recognized')

    def caps_mask_and_target():
        """Return ``(mask_input, masked_img)`` for the capsule network.

        ``mask_input`` is ``oneHot2LabelMax(mask_batch)`` with values > 0.5 set to
        1 and a trailing channel axis added. ``masked_img`` is that mask repeated
        over ``modalities`` channels times the centre slice's input channels.
        """
        start_index = (input_slices // 2) * modalities
        img_batch_mid_slice = img_batch[:, :, :, start_index:start_index + modalities]
        mask_batch_masked = oneHot2LabelMax(mask_batch)
        mask_batch_masked[mask_batch_masked > 0.5] = 1.0  # Setting all other classes than background to mask
        mask_batch_masked = np.expand_dims(mask_batch_masked, axis=-1)
        masked_img = np.repeat(mask_batch_masked, modalities, axis=-1) * img_batch_mid_slice
        return mask_batch_masked, masked_img

    while True:
        if shuff:
            shuffle(val_list)
        count = 0
        for scan_name in val_list:
            try:
                scan_name = scan_name[0]
                path_to_np = join(root_path, 'np_files', basename(scan_name)[:-6] + 'npz')
                with np.load(path_to_np) as data:
                    val_img = data['img']
                    val_mask = data['mask']
            except Exception:
                print('\nPre-made numpy array not found for {}.\nCreating now...'.format(scan_name[:-7]))
                val_img, val_mask = np_converter(root_path, scan_name, num_classes=num_output_classes)
                if np.array_equal(val_img, np.zeros(1)):
                    continue
                print('\nFinished making npz file.')

            z_shape = val_img.shape[2]
            indicies = np.arange(0, z_shape, stride)
            if shuff:
                shuffle(indicies)
            for j in indicies:
                if img_batch.ndim == 4:
                    img_batch[count] = 0
                    next_img = val_img[:, :, max(j - side_slices, 0):min(j + side_slices + 1, z_shape)].reshape(
                        raw_x_shape, raw_y_shape, -1)
                    # Each neighbour k owns a fixed channel window; out-of-volume
                    # neighbours leave their window zero.
                    insertion_index = -modalities
                    img_index = 0
                    for k in range(j - side_slices, j + side_slices + 1):
                        insertion_index += modalities
                        if k < 0:
                            continue
                        if k >= z_shape:
                            break
                        img_batch[count, frame_pixels_0:frame_pixels_1, frame_pixels_0:frame_pixels_1,
                                  insertion_index:insertion_index + modalities] = \
                            next_img[:, :, img_index:img_index + modalities]
                        img_index += modalities
                    mask_batch[count] = empty_mask
                    mask_batch[count, frame_pixels_0:frame_pixels_1, frame_pixels_0:frame_pixels_1, :] = \
                        val_mask[:, :, j]
                else:
                    print('\nError this function currently only supports 2D and 3D data.')
                    exit(0)

                count += 1
                if count % batchSize == 0:
                    count = 0
                    if net.find('caps') != -1:  # capsule/segcaps structure
                        mask_input, masked_img = caps_mask_and_target()
                        yield ([img_batch.copy(), 1 - mask_input], [mask_batch.copy(), masked_img])
                    else:
                        yield (img_batch.copy(), mask_batch.copy())

        if count != 0:
            # Partial final batch of this pass: same structure as the full
            # batches. The old branch fed the one-hot mask as the mask input and
            # multiplied tensors whose channel counts need not match.
            if net.find('caps') != -1:
                mask_input, masked_img = caps_mask_and_target()
                yield ([img_batch[:count].copy(), 1 - mask_input[:count]],
                       [mask_batch[:count].copy(), masked_img[:count]])
            else:
                yield (img_batch[:count].copy(), mask_batch[:count].copy())
def _open_csv_for_writing(path):
    """Open ``path`` for writing with ``csv.writer`` on either Python major version.

    Python 2's csv module needs a binary-mode file. Python 3's needs a text-mode
    file opened with ``newline=''``; there, ``'wb'`` makes every ``writerow`` fail.
    """
    import sys
    if sys.version_info[0] < 3:
        return open(path, 'wb')
    return open(path, 'w', newline='')


def _overlay_onehot(axis, img_slice, onehot_slice, colors, title):
    """Draw ``img_slice`` in grayscale and overlay each non-background class of
    ``onehot_slice`` (indexed ``[x, y, class]``) using ``colors[class]``."""
    axis.imshow(img_slice, alpha=1, cmap='gray')
    for class_num in range(1, onehot_slice.shape[-1]):
        mask = onehot_slice[:, :, class_num]
        mask = np.ma.masked_where(mask == 0, mask)
        axis.imshow(mask, alpha=0.7, cmap=colors[class_num], vmin=0, vmax=1)
    axis.set_title(title)
    axis.axis('off')


def test(args, test_list, model_list, net_input_shape):
    """Evaluate a trained model on ``test_list`` and write outputs and scores.

    For every scan, the function does the following:

    * predicts slice by slice with ``generate_test_batches``;
    * writes the raw prediction and the thresholded mask (``threshold_mask`` with
      ``args.thresh_level``) under ``<data_root_dir>/results/<net>/split_<n>/``;
    * saves a qualitative figure;
    * computes the requested metrics against ``<data_root_dir>/masks/<name>``;
    * appends a row to ``<save_prefix><metrics>scores.csv`` and
      ``<save_prefix><metrics>dice_scores.csv``.

    A final "Average Scores" row is written to both CSV files.

    Args:
        args: parsed options; reads weights_path, check_dir, output_name, time,
            data_root_dir, dataset, net, split_num, compute_dice, compute_jaccard,
            compute_assd, out_classes, save_prefix, slices and thresh_level.
        test_list: rows whose first element is an image file name under ``imgs``.
        model_list: models; the second one is evaluated when present, else the first.
        net_input_shape: input shape forwarded to ``generate_test_batches``.

    Raises:
        AssertionError: if the weights cannot be loaded.
    """
    if args.weights_path == '':
        weights_path = join(args.check_dir, args.output_name + '_validation_best_model_' + args.time + '.hdf5')
    else:
        weights_path = join(args.data_root_dir, args.weights_path)

    if args.dataset == 'brats':
        RESOLUTION = 240
    elif args.dataset == 'heart':
        RESOLUTION = 320
    else:
        RESOLUTION = 512

    output_dir = join(args.data_root_dir, 'results', args.net, 'split_' + str(args.split_num))
    raw_out_dir = join(output_dir, 'raw_output')
    fin_out_dir = join(output_dir, 'final_output')
    fig_out_dir = join(output_dir, 'qual_figs')
    for directory in (raw_out_dir, fin_out_dir, fig_out_dir):
        try:
            makedirs(directory)
        except OSError:
            pass  # best effort, typically the directory already exists

    eval_model = model_list[1] if len(model_list) > 1 else model_list[0]
    try:
        eval_model.load_weights(weights_path)
    except Exception as e:
        print(e)
        # Explicit raise so ``python -O`` cannot silently evaluate random weights.
        raise AssertionError('Unable to find weights path. Testing with random weights.')
    print_summary(model=eval_model, positions=[.38, .65, .75, 1.])

    # Set up placeholders
    n_scans = len(test_list)
    outfile = ''
    if args.compute_dice:
        dice_arr = np.zeros(n_scans)
        outfile += 'dice_'
    if args.compute_jaccard:
        jacc_arr = np.zeros(n_scans)
        outfile += 'jacc_'
    if args.compute_assd:
        assd_arr = np.zeros(n_scans)
        outfile += 'assd_'
    # Object dtype: ``dtype=str`` creates one-character strings and truncated
    # every stored surface-distance summary to its first character.
    surf_arr = np.zeros(n_scans, dtype=object)
    dice2_arr = np.zeros((n_scans, args.out_classes - 1))

    # Testing the network
    print('Testing... This will take some time...')
    scores_path = join(output_dir, args.save_prefix + outfile + 'scores.csv')
    dice_scores_path = join(output_dir, args.save_prefix + outfile + 'dice_scores.csv')
    with _open_csv_for_writing(scores_path) as csvfile, \
            _open_csv_for_writing(dice_scores_path) as dice_results_csv:
        writer = csv.writer(csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        dice_writer = csv.writer(dice_results_csv, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)

        dice_writer.writerow(['Scan Name'] + ['Dice_{}'.format(c) for c in range(1, args.out_classes)])
        row = ['Scan Name']
        if args.compute_dice:
            row.append('Dice Coefficient')
        if args.compute_jaccard:
            row.append('Jaccard Index')
        if args.compute_assd:
            row.append('Average Symmetric Surface Distance')
        writer.writerow(row)

        for i, img in enumerate(tqdm(test_list)):
            sitk_img = sitk.ReadImage(join(args.data_root_dir, 'imgs', img[0]))
            img_data = sitk.GetArrayFromImage(sitk_img)
            num_slices = img_data.shape[0]
            if args.dataset == 'brats':
                # BraTS arrays are 4-D: slices are counted on axis 1 and axis 0
                # (presumably modalities) is moved to the end.
                num_slices = img_data.shape[1]
                img_data = np.rollaxis(img_data, 0, 4)
            print(args.dataset)

            output_array = eval_model.predict_generator(
                generate_test_batches(args.data_root_dir, [img], net_input_shape, batchSize=1,
                                      numSlices=args.slices, subSampAmt=0, stride=1,
                                      dataset=args.dataset, num_output_classes=args.out_classes),
                steps=num_slices, max_queue_size=1, workers=1, use_multiprocessing=False, verbose=1)
            print('out' + str(output_array[0].shape))

            # For BraTS, drop the 8-pixel frame (presumably the padding the batch
            # generator adds around 240x240 slices).
            if args.net.find('caps') != -1:
                # Capsule models return [segmentation, reconstruction].
                if args.dataset == 'brats':
                    output = output_array[0][:, 8:-8, 8:-8]
                    recon = output_array[1][:, 8:-8, 8:-8]
                else:
                    output = output_array[0][:, :, :]
                    recon = output_array[1][:, :, :]
            else:
                if args.dataset == 'brats':
                    output = output_array[:, 8:-8, 8:-8, :]
                else:
                    output = output_array[:, :, :, :]

            if args.out_classes == 1:
                # Binary: keep the single-channel prediction as-is.
                output_raw = output.reshape(-1, RESOLUTION, RESOLUTION, 1)
                out = output.astype(np.int64)
                outputOnehot = out.reshape(-1, RESOLUTION, RESOLUTION, 1)
            else:
                output_raw = output
                output = oneHot2LabelMax(output)
                out = output.astype(np.int64)
                outputOnehot = np.eye(args.out_classes)[out].astype(np.uint8)

            output_img = sitk.GetImageFromArray(output)
            print('Segmenting Output')
            output_bin = threshold_mask(output, args.thresh_level)
            output_mask = sitk.GetImageFromArray(output_bin)
            # Geometry comes from a blank RESOLUTION x RESOLUTION x num_slices image,
            # not from the input scan.
            slice_img = sitk.Image(RESOLUTION, RESOLUTION, num_slices, sitk.sitkUInt8)
            output_img.CopyInformation(slice_img)
            output_mask.CopyInformation(slice_img)

            print('Saving Output')
            if args.dataset != 'luna':
                # File names end in a 7-character extension (e.g. '.nii.gz').
                sitk.WriteImage(output_img, join(raw_out_dir, img[0][:-7] + '_raw_output' + img[0][-7:]))
                sitk.WriteImage(output_mask, join(fin_out_dir, img[0][:-7] + '_final_output' + img[0][-7:]))

                # Load gt mask
                sitk_mask = sitk.ReadImage(join(args.data_root_dir, 'masks', img[0]))
                gt_data = sitk.GetArrayFromImage(sitk_mask)
                label = gt_data.astype(np.int64)
                gt_label = label
                if args.out_classes == 1:
                    gtOnehot = label.reshape(-1, RESOLUTION, RESOLUTION, 1)  # binary
                else:
                    gtOnehot = np.eye(args.out_classes)[label].astype(np.uint8)

                if args.net.find('caps') != -1:
                    create_recon_image(args, recon, img_data, gtOnehot, i=i)
                # NOTE(review): nesting of this call relative to the caps check was
                # reconstructed from flattened source; confirm it should run for all nets.
                create_activation_image(args, output_raw, gtOnehot, slice_num=output_raw.shape[0] // 2, index=i)

                # Plot Qual Figure
                print('Creating Qualitative Figure for Quick Reference')
                _, ax = plt.subplots(2, 3, figsize=(10, 5))
                colors = ['Greys', 'Greens', 'Reds', 'Blues']
                fileTypeLength = 7
                print(img_data.shape)
                print(outputOnehot.shape)
                if args.dataset == 'brats':
                    img_data = img_data[:, :, :, 3]  # isensee
                # Row 0: prediction overlays; row 1: ground-truth overlays.
                slice_ids = (num_slices // 3, num_slices // 2, num_slices // 2 + num_slices // 4)
                for col, slice_idx in enumerate(slice_ids):
                    title = 'Slice {}/{}'.format(slice_idx, num_slices)
                    _overlay_onehot(ax[0, col], img_data[slice_idx, :, :], outputOnehot[slice_idx], colors, title)
                    _overlay_onehot(ax[1, col], img_data[slice_idx, :, :], gtOnehot[slice_idx], colors, title)
                fig = plt.gcf()
                fig.suptitle(img[0][:-fileTypeLength])
                plt.savefig(join(fig_out_dir, img[0][:-fileTypeLength] + '_qual_fig' + '.png'),
                            format='png', bbox_inches='tight')
                plt.close('all')
            else:
                # File names end in a 4-character extension.
                sitk.WriteImage(output_img, join(raw_out_dir, img[0][:-4] + '_raw_output' + img[0][-4:]))
                sitk.WriteImage(output_mask, join(fin_out_dir, img[0][:-4] + '_final_output' + img[0][-4:]))

                # Load gt mask
                sitk_mask = sitk.ReadImage(join(args.data_root_dir, 'masks', img[0]))
                gt_data = sitk.GetArrayFromImage(sitk_mask)

                _, ax = plt.subplots(1, 3, figsize=(15, 5))
                depth = img_data.shape[0]
                for col, slice_idx in enumerate((depth // 3, depth // 2, depth // 2 + depth // 4)):
                    ax[col].imshow(img_data[slice_idx, :, :], alpha=1, cmap='gray')
                    ax[col].imshow(output_bin[slice_idx, :, :], alpha=0.5, cmap='Reds')
                    ax[col].set_title('Slice {}/{}'.format(slice_idx, depth))
                    ax[col].axis('off')
                fig = plt.gcf()
                fig.suptitle(img[0][:-4])
                plt.savefig(join(fig_out_dir, img[0][:-4] + '_qual_fig' + '.png'), format='png', bbox_inches='tight')
                # Previously missing: every luna scan leaked an open figure.
                plt.close('all')

            # NOTE(review): gtOnehot, label and gt_label are only set in the
            # non-luna branch; the metrics below presumably fail for 'luna'.
            output_label = oneHot2LabelMax(outputOnehot)
            row = [img[0][:-4]]
            dice_row = [img[0][:-4]]
            if args.compute_dice:
                print('Computing Dice')
                dice_arr[i] = dc(outputOnehot, gtOnehot)
                print('\tDice: {}'.format(dice_arr[i]))
                row.append(dice_arr[i])
            if args.compute_jaccard:
                print('Computing Jaccard')
                jacc_arr[i] = jaccard(outputOnehot, gtOnehot)
                print('\tJaccard: {}'.format(jacc_arr[i]))
                row.append(jacc_arr[i])
            if args.compute_assd:
                print('Computing ASSD')
                assd_arr[i] = assd(outputOnehot, gtOnehot, voxelspacing=sitk_img.GetSpacing(), connectivity=1)
                print('\tASSD: {}'.format(assd_arr[i]))
                row.append(assd_arr[i])

            try:
                spacing = np.array(sitk_img.GetSpacing())
                if args.dataset == 'brats':
                    spacing = spacing[1:]
                surf = compute_surface_distances(label, out, spacing)
                surf_arr[i] = str(surf)
                assd_score = calc_assd_scores(output_label, gt_label, args.out_classes, spacing)
                print(assd_score)
                print('\tSurface distance ' + str(surf_arr[i]))
            except Exception:
                # Best effort: a failed surface-distance computation must not abort the run.
                print("surf failed")

            dice2_arr[i] = calc_dice_scores(output_label, gt_label, args.out_classes)
            dice_row.extend(dice2_arr[i])
            dice_writer.writerow(dice_row)
            print('\tMSD Dice: {}'.format(dice2_arr[i]))
            writer.writerow(row)

        dice_writer.writerow(['Average Scores'] + list(np.mean(dice2_arr, axis=0)))
        row = ['Average Scores']
        if args.compute_dice:
            row.append(np.mean(dice_arr))
        if args.compute_jaccard:
            row.append(np.mean(jacc_arr))
        if args.compute_assd:
            row.append(np.mean(assd_arr))
        row.append(surf_arr)
        row.append(np.mean(dice2_arr))
        writer.writerow(row)

    print('Done.')