def calc_dice_scores(pred, gt, num_classes):
    """Compute one Dice coefficient per non-zero class.

    Class 0 is skipped (presumably background -- confirm with the label
    encoding); classes ``1 .. num_classes - 1`` are evaluated in order and
    each score is printed as it is computed.

    Args:
        pred: Predicted label map, forwarded to ``create_class_mask``.
        gt: Ground-truth label map, forwarded to ``create_class_mask``.
        num_classes: Total number of classes, including class 0.

    Returns:
        numpy array of length ``num_classes - 1``; entry ``k`` holds the
        Dice coefficient of class ``k + 1``.
    """
    scores = np.zeros(num_classes - 1)
    for class_idx in range(1, num_classes):
        pred_class, gt_class = create_class_mask(pred, gt, class_idx)
        coeff = compute_dice_coefficient(gt_class, pred_class)
        print("Dice value: {}".format(coeff))
        scores[class_idx - 1] = coeff
    return scores
def test_acregnet_model(config):
    """Evaluate a trained AC-RegNet model on randomly paired test images.

    Loads the test images and label maps and restores the model from
    ``config['ckpt_dir']``. It then registers random moving/fixed pairs and
    warps the moving label maps with the predicted flow. For every pair it
    records Dice (``dc``), Hausdorff (``hd``) and average symmetric surface
    distance (``assd``). The per-pair metrics are pickled to
    ``<result_dir>/metrics.pkl``, and their mean/std is printed.

    Note: ``config`` is mutated in place -- ``batch_size`` and
    ``image_size`` are overwritten before the model is built.

    Args:
        config: Settings dict; reads ``result_dir``, ``test_ims_file``,
            ``test_lbs_file``, ``ckpt_dir``, ``n_labels`` and ``std_res``.
    """
    tf.reset_default_graph()
    sess = tf.Session()
    try:
        out_im_dir = config['result_dir'] + '/images'
        create_dir(out_im_dir)

        # Load test data
        test_ims, _ = DataHandler.load_images(config['test_ims_file'])
        test_lbs, _ = DataHandler.load_labels(config['test_lbs_file'])
        print('Loading test data...done')

        # All random pairs go through the network as a single batch of
        # twice the number of test images.
        config['batch_size'] = test_ims.shape[0] * 2
        config['image_size'] = [256, 256]

        # Load trained model
        acregnet = ACRegNet(sess, config, 'ACRegNet', is_train=False)
        print('Building AC-RegNet model...done')
        acregnet.restore(config['ckpt_dir'])
        print('Loading trained AC-RegNet model...done')

        data = random_pairs(test_ims, test_lbs, size=config['batch_size'])
        mov_ims, fix_ims, mov_lbs, fix_lbs = data

        print('Testing...')
        # Get deformation fields (second element of deploy's result)
        flow = acregnet.deploy(out_im_dir, mov_ims, fix_ims, True)[1]

        # Warp label maps
        sw = SimpleWarper(sess, config['batch_size'], config['image_size'],
                          config['n_labels'])
        warp_lbs = sw.warp_label(mov_lbs, flow)

        # Compute metrics; presumably std_res is the physical image extent,
        # so dividing by the width gives the size of one pixel.
        metrics = {'dc': [], 'hd': [], 'assd': []}
        spacing = config['std_res'] / config['image_size'][0]
        for warp_lb, fix_lb in zip(warp_lbs, fix_lbs):
            metrics['dc'].append(dc(warp_lb, fix_lb))
            metrics['hd'].append(hd(warp_lb, fix_lb, pixel_spacing=spacing))
            metrics['assd'].append(
                assd(warp_lb, fix_lb, pixel_spacing=spacing))

        # Save results to disk
        with open(config['result_dir'] + '/metrics.pkl', 'wb') as f:
            pickle.dump(metrics, f)

        # Report metric values
        print('Metrics:')
        for name, values in metrics.items():
            print('- {}: mean {:.3f}, std {:.3f}'.format(
                name, np.mean(values), np.std(values)))

        print('Testing done')
    finally:
        # Release the session's resources deterministically rather than
        # waiting for garbage collection.
        sess.close()
def test_aenet_model(config):
    """Evaluate a trained AE-Net model by reconstructing the test label maps.

    Restores the model from ``config['ckpt_dir']`` and reconstructs every
    test label map. It compares each reconstruction with its input using
    Dice (``dc``), Hausdorff (``hd``) and average symmetric surface distance
    (``assd``). The per-map metrics are pickled to
    ``<result_dir>/metrics.pkl``, and their mean/std is printed.

    Note: ``config['batch_size']`` is overwritten in place.

    Args:
        config: Settings dict; reads ``result_dir``, ``test_lbs_file``,
            ``ckpt_dir``, ``image_size`` and ``std_res``.
    """
    tf.reset_default_graph()
    sess = tf.Session()
    try:
        out_lb_dir = config['result_dir'] + '/labels'
        create_dir(out_lb_dir)

        # Load test data
        test_lbs, _ = DataHandler.load_labels(config['test_lbs_file'])
        print('Loading test data...done')

        # All label maps are reconstructed in a single batch.
        config['batch_size'] = test_lbs.shape[0]

        # Load trained model
        aenet = AENet(sess, config, 'AENet', is_train=False)
        print('Building AE-Net model...done')
        aenet.restore(config['ckpt_dir'])
        print('Loading trained AE-Net model...done')

        print('Testing...')
        # Get reconstructed label maps
        rec_lbs = aenet.deploy(out_lb_dir, test_lbs)

        # Compute metrics; presumably std_res is the physical image extent,
        # so dividing by the width gives the size of one pixel.
        metrics = {'dc': [], 'hd': [], 'assd': []}
        spacing = config['std_res'] / config['image_size'][0]
        for rec_lb, fix_lb in zip(rec_lbs, test_lbs):
            metrics['dc'].append(dc(rec_lb, fix_lb))
            metrics['hd'].append(hd(rec_lb, fix_lb, pixel_spacing=spacing))
            metrics['assd'].append(
                assd(rec_lb, fix_lb, pixel_spacing=spacing))

        # Save results to disk
        with open(config['result_dir'] + '/metrics.pkl', 'wb') as f:
            pickle.dump(metrics, f)

        # Report metric values
        print('Metrics:')
        for name, values in metrics.items():
            print('- {}: mean {:.3f}, std {:.3f}'.format(
                name, np.mean(values), np.std(values)))

        print('Testing done')
    finally:
        # Release the session's resources deterministically rather than
        # waiting for garbage collection.
        sess.close()
def test(args, test_list, model_list, net_input_shape):
    """Segment every scan in ``test_list`` and score it against its mask.

    For each scan it does the following:
      * writes the raw network output to ``raw_output``;
      * writes the thresholded mask to ``final_output``, with both images
        copying the input scan's metadata;
      * saves a three-slice qualitative figure;
      * appends the requested metrics (Dice / Jaccard / ASSD) to a CSV
        whose last row holds the per-metric averages.

    Args:
        args: Parsed command-line options. Reads ``weights_path``,
            ``check_dir``, ``output_name``, ``time``, ``data_root_dir``,
            ``net``, ``split_num``, ``compute_dice``, ``compute_jaccard``,
            ``compute_assd``, ``save_prefix``, ``batch_size``, ``slices`` and
            ``thresh_level``.
        test_list: Records whose first element is the scan file name,
            resolved under ``<data_root_dir>/imgs`` and
            ``<data_root_dir>/masks``.
        model_list: Built models; the second one is evaluated when present,
            otherwise the first.
        net_input_shape: Forwarded to ``generate_test_batches``.
    """
    from os.path import isdir

    if args.weights_path == '':
        weights_path = join(args.check_dir,
                            args.output_name + '_model_' + args.time + '.hdf5')
    else:
        weights_path = join(args.data_root_dir, args.weights_path)

    output_dir = join(args.data_root_dir, 'results', args.net,
                      'split_' + str(args.split_num))
    raw_out_dir = join(output_dir, 'raw_output')
    fin_out_dir = join(output_dir, 'final_output')
    fig_out_dir = join(output_dir, 'qual_figs')
    for out_dir in (raw_out_dir, fin_out_dir, fig_out_dir):
        try:
            makedirs(out_dir)
        except OSError:
            # An already-existing directory is fine; anything else
            # (permissions, a regular file in the way, ...) is a real error.
            if not isdir(out_dir):
                raise

    eval_model = model_list[1] if len(model_list) > 1 else model_list[0]
    try:
        eval_model.load_weights(weights_path)
    except (IOError, OSError):
        # Only an unreadable/missing weights file falls back to random
        # weights; other failures now propagate instead of being hidden.
        print('Unable to find weights path. Testing with random weights.')
    print_summary(model=eval_model, positions=[.38, .65, .75, 1.])

    # Set up placeholders
    outfile = ''
    if args.compute_dice:
        dice_arr = np.zeros((len(test_list)))
        outfile += 'dice_'
    if args.compute_jaccard:
        jacc_arr = np.zeros((len(test_list)))
        outfile += 'jacc_'
    if args.compute_assd:
        assd_arr = np.zeros((len(test_list)))
        outfile += 'assd_'

    # Testing the network
    print('Testing... This will take some time...')

    # Python 3's csv module needs a text-mode handle opened with
    # newline=''; with a 'wb' handle writer.writerow raises TypeError.
    with open(join(output_dir, args.save_prefix + outfile + 'scores.csv'),
              'w', newline='') as csvfile:
        writer = csv.writer(csvfile,
                            delimiter=',',
                            quotechar='|',
                            quoting=csv.QUOTE_MINIMAL)

        row = ['Scan Name']
        if args.compute_dice:
            row.append('Dice Coefficient')
        if args.compute_jaccard:
            row.append('Jaccard Index')
        if args.compute_assd:
            row.append('Average Symmetric Surface Distance')

        writer.writerow(row)

        for i, img in enumerate(tqdm(test_list)):
            scan_name = img[0]
            base_name, ext = scan_name[:-4], scan_name[-4:]

            sitk_img = sitk.ReadImage(
                join(args.data_root_dir, 'imgs', scan_name))
            img_data = sitk.GetArrayFromImage(sitk_img)
            num_slices = img_data.shape[0]

            output_array = eval_model.predict_generator(
                generate_test_batches(args.data_root_dir, [img],
                                      net_input_shape,
                                      batchSize=args.batch_size,
                                      numSlices=args.slices,
                                      subSampAmt=0,
                                      stride=1),
                steps=num_slices,
                max_queue_size=1,
                workers=1,
                use_multiprocessing=False,
                verbose=1)

            # Capsule nets presumably return [segmentation, reconstruction];
            # only the first channel of the segmentation output is used.
            if args.net.find('caps') != -1:
                output = output_array[0][:, :, :, 0]
            else:
                output = output_array[:, :, :, 0]

            output_img = sitk.GetImageFromArray(output)
            print('Segmenting Output')
            output_bin = threshold_mask(output, args.thresh_level)
            output_mask = sitk.GetImageFromArray(output_bin)

            output_img.CopyInformation(sitk_img)
            output_mask.CopyInformation(sitk_img)

            print('Saving Output')
            sitk.WriteImage(
                output_img, join(raw_out_dir, base_name + '_raw_output' + ext))
            sitk.WriteImage(
                output_mask,
                join(fin_out_dir, base_name + '_final_output' + ext))

            # Load gt mask
            sitk_mask = sitk.ReadImage(
                join(args.data_root_dir, 'masks', scan_name))
            gt_data = sitk.GetArrayFromImage(sitk_mask)

            # Plot Qual Figure: image, prediction (blue) and ground truth
            # (red) at roughly 1/3, 1/2 and 3/4 of the slice axis.
            print('Creating Qualitative Figure for Quick Reference')
            f, ax = plt.subplots(1, 3, figsize=(15, 5))
            slice_ids = (num_slices // 3, num_slices // 2,
                         num_slices // 2 + num_slices // 4)
            for axis, slice_idx in zip(ax, slice_ids):
                axis.imshow(img_data[slice_idx, :, :], alpha=1, cmap='gray')
                axis.imshow(output_bin[slice_idx, :, :],
                            alpha=0.5,
                            cmap='Blues')
                axis.imshow(gt_data[slice_idx, :, :], alpha=0.2, cmap='Reds')
                axis.set_title('Slice {}/{}'.format(slice_idx, num_slices))
                axis.axis('off')

            f.suptitle(base_name)
            plt.savefig(join(fig_out_dir, base_name + '_qual_fig' + '.png'),
                        format='png',
                        bbox_inches='tight')
            plt.close('all')

            row = [base_name]
            if args.compute_dice:
                print('Computing Dice')
                dice_arr[i] = dc(output_bin, gt_data)
                print('\tDice: {}'.format(dice_arr[i]))
                row.append(dice_arr[i])
            if args.compute_jaccard:
                print('Computing Jaccard')
                jacc_arr[i] = jc(output_bin, gt_data)
                print('\tJaccard: {}'.format(jacc_arr[i]))
                row.append(jacc_arr[i])
            if args.compute_assd:
                print('Computing ASSD')
                # SimpleITK reports spacing as (x, y, z), but
                # GetArrayFromImage returns arrays indexed (z, y, x).
                # Reverse it so each spacing value matches its array axis.
                assd_arr[i] = assd(output_bin,
                                   gt_data,
                                   voxelspacing=sitk_img.GetSpacing()[::-1],
                                   connectivity=1)
                print('\tASSD: {}'.format(assd_arr[i]))
                row.append(assd_arr[i])

            writer.writerow(row)

        row = ['Average Scores']
        if args.compute_dice:
            row.append(np.mean(dice_arr))
        if args.compute_jaccard:
            row.append(np.mean(jacc_arr))
        if args.compute_assd:
            row.append(np.mean(assd_arr))
        writer.writerow(row)

    print('Done.')
def main():
    """Score predicted cardiac segmentations against ground-truth masks.

    Command line (``sys.argv``):
        (no arguments)      label "Myocardium", dirs "prediction" and "GT"
        <label>             dirs "prediction" and "GT"
        <label> <pred> <gt> explicit prediction / ground-truth directories
    Any other argument count prints a usage hint and exits with status 1.

    ``<label>`` is one of "Myocardium", "Infarction" or "NoReflow"; any
    other value raises ``NameError('Unknown class name')``.

    For every case ``<pred>/<case>/Contours/<case>.nii.gz`` is compared with
    the identically named file under ``<gt>``. The Dice index and absolute
    volume difference are collected for every label. In addition, the
    Hausdorff distance is collected for "Myocardium", and for the other
    labels the volume difference is also collected relative to the GT
    myocardium volume. Averages are printed, and the per-case values are
    written to ``csv/*.csv`` (relative to the working directory, created if
    missing).
    """
    args = sys.argv[1:]
    if len(args) == 0:
        label, pathPrediction, pathGT = "Myocardium", "prediction", "GT"
    elif len(args) == 1:
        label, pathPrediction, pathGT = args[0], "prediction", "GT"
    elif len(args) == 3:
        label, pathPrediction, pathGT = args
    else:
        print("Parameter error. Please check How To Use in the Readme.")
        # Non-zero status so shells and CI see the failure.
        sys.exit(1)

    # Class index in the GT contour NIfTI:
    # {"background": 0, "cavity": 1, "normal_myocardium": 2,
    #  "infarction": 3, "NoReflow": 4}.
    # The myocardium includes both the normal myocardium and scar tissue.
    label_classes = {
        "Myocardium": (2, 3, 4),
        "Infarction": (3, 4),
        "NoReflow": (4,),
    }
    # Validate before doing any I/O.
    if label not in label_classes:
        raise NameError('Unknown class name')
    classes = label_classes[label]
    myo_classes = label_classes["Myocardium"]

    dice = []
    HD = []
    volumeDifference = []
    volumeDifferenceRate = []
    volumePrediction = []

    pred_files = sorted(os.listdir(pathPrediction))
    if not pred_files:
        # Averaging an empty list would otherwise raise ZeroDivisionError.
        print("No prediction found in {}".format(pathPrediction))
        sys.exit(1)

    for filePrediction in pred_files:
        # The prediction mask must use the same categorical values (0-4) as
        # the GT mask.
        prediction = sitk.ReadImage(
            os.path.join(pathPrediction, filePrediction, 'Contours',
                         filePrediction + '.nii.gz'), sitk.sitkInt16)
        predArray = sitk.GetArrayFromImage(prediction)

        # The GT file is expected to share the prediction's name.
        GT = sitk.ReadImage(
            os.path.join(pathGT, filePrediction, 'Contours',
                         filePrediction + '.nii.gz'), sitk.sitkInt16)
        GTArray = sitk.GetArrayFromImage(GT)
        spacing = GT.GetSpacing()

        # Binary masks of the voxels belonging to the selected label.
        aGTArray = np.isin(GTArray, classes)
        aPredArray = np.isin(predArray, classes)

        # Common metrics
        dice.append(metrics.dc(aPredArray, aGTArray))
        aVolumePred = metrics.volume(aPredArray, spacing)
        aVolumeGT = metrics.volume(aGTArray, spacing)
        volumePrediction.append(round(aVolumePred, 2))
        volumeDifference.append(round(abs(aVolumePred - aVolumeGT), 2))

        if label == "Myocardium":
            # NOTE(review): GetSpacing() is (x, y, z) while the arrays are
            # indexed (z, y, x); confirm metrics.hd expects that order.
            HD.append(metrics.hd(aPredArray, aGTArray, spacing))
        else:
            # Scar-tissue labels: volume error relative to GT myocardium.
            aVolumeMyo = metrics.volume(np.isin(GTArray, myo_classes),
                                        spacing)
            volumeDifferenceRate.append(
                abs(aVolumePred - aVolumeGT) / aVolumeMyo)

    # np.savetxt does not create missing directories.
    os.makedirs('csv', exist_ok=True)

    avgDice = float(sum(dice)) / len(dice)
    print("Average Dice index: ", "{:.2%}".format(avgDice))
    np.savetxt('csv/Dice.csv', dice, delimiter=',', fmt='%f')
    avgVD = float(sum(volumeDifference)) / len(volumeDifference)
    print("Average volume difference: ", round(avgVD, 2),
          "mm\N{SUPERSCRIPT THREE}")
    np.savetxt('csv/volumeDif.csv', volumeDifference, delimiter=',', fmt='%f')
    if label == "Myocardium":
        avgHd = float(sum(HD)) / len(HD)
        print("Average Hausdorff distance: ", round(avgHd, 2), "mm")
        np.savetxt('csv/HD.csv', HD, delimiter=',', fmt='%f')
    else:
        avgVDR = float(sum(volumeDifferenceRate)) / len(volumeDifferenceRate)
        print(
            "Average volume difference ratio according to volume of myocardium: ",
            "{:.2%}".format(avgVDR))
        np.savetxt('csv/volumeDifRatio.csv',
                   volumeDifferenceRate,
                   delimiter=',',
                   fmt='%f')
Beispiel #6
0
def test(args, test_list, model_list, net_input_shape):
    """Evaluate the trained network on pre-extracted ``.npz`` test slices.

    For each entry in ``test_list`` it does the following:
      * runs one prediction step;
      * thresholds the first output slice;
      * loads the matching ``np_files/<name>.npz`` (keys ``'img'`` and
        ``'mask'``);
      * saves a prediction / ground-truth comparison figure;
      * appends the requested metrics to a CSV whose last row holds the
        per-metric averages.

    Args:
        args: Parsed command-line options. Reads ``weights_path``,
            ``check_dir``, ``output_name``, ``time``, ``data_root_dir``,
            ``net``, ``split_num``, ``compute_dice``, ``compute_jaccard``,
            ``compute_assd``, ``save_prefix``, ``batch_size``, ``slices`` and
            ``thresh_level``.
        test_list: Records whose first element is the scan file name.
        model_list: Built models; the second one is evaluated when present,
            otherwise the first.
        net_input_shape: Forwarded to ``generate_test_batches``.
    """
    if args.weights_path == '':
        weights_path = join(args.check_dir,
                            args.output_name + '_model_' + args.time + '.hdf5')
    else:
        weights_path = join(args.data_root_dir, args.weights_path)

    output_dir = join(args.data_root_dir, 'results', args.net,
                      'split_' + str(args.split_num))
    raw_out_dir = join(output_dir, 'raw_output')
    fin_out_dir = join(output_dir, 'final_output')
    fig_out_dir = join(output_dir, 'qual_figs')
    # Figures go into a per-prefix sub-directory. It does not depend on the
    # scan, so it is created once here instead of on every iteration.
    indiv_fig_dir = join(fig_out_dir, args.save_prefix)
    for out_dir in (raw_out_dir, fin_out_dir, fig_out_dir, indiv_fig_dir):
        makedirs(out_dir, exist_ok=True)

    eval_model = model_list[1] if len(model_list) > 1 else model_list[0]
    try:
        eval_model.load_weights(weights_path)
    except OSError:
        # FileNotFoundError is a subclass of OSError. h5py (at least 2.x)
        # reports a missing file as a plain OSError, which catching only
        # FileNotFoundError would miss.
        print('Unable to find weights path. Testing with random weights.')
    print_summary(model=eval_model, positions=[.38, .65, .75, 1.])

    # Set up placeholders
    outfile = ''
    if args.compute_dice:
        dice_arr = np.zeros((len(test_list)))
        outfile += 'dice_'
    if args.compute_jaccard:
        jacc_arr = np.zeros((len(test_list)))
        outfile += 'jacc_'
    if args.compute_assd:
        assd_arr = np.zeros((len(test_list)))
        outfile += 'assd_'

    # Testing the network
    print('Testing... This will take some time...')

    # newline='' is what the csv module expects for its file handles.
    with open(join(output_dir, args.save_prefix + outfile + 'scores.csv'),
              'w', newline='') as csvfile:
        writer = csv.writer(csvfile,
                            delimiter=',',
                            quotechar='|',
                            quoting=csv.QUOTE_MINIMAL)

        row = ['Scan Name']
        if args.compute_dice:
            row.append('Dice Coefficient')
        if args.compute_jaccard:
            row.append('Jaccard Index')
        if args.compute_assd:
            row.append('Average Symmetric Surface Distance')

        writer.writerow(row)

        for i, img in enumerate(tqdm(test_list)):
            output_array = eval_model.predict_generator(
                generate_test_batches(args.data_root_dir, [img],
                                      net_input_shape,
                                      batchSize=args.batch_size,
                                      numSlices=args.slices,
                                      subSampAmt=0,
                                      stride=1),
                steps=1,
                max_queue_size=0,
                workers=1,
                use_multiprocessing=False,
                verbose=1)

            # Capsule nets presumably return [segmentation, reconstruction];
            # only the first channel of the segmentation output is used.
            if args.net.find('caps') != -1:
                output = output_array[0][:, :, :, 0]
            else:
                output = output_array[:, :, :, 0]
            print('output')
            print(output.shape)
            print('Segmenting Output')
            output_bin = threshold_mask(output, args.thresh_level)
            # Only the first predicted slice is evaluated.
            output_bin = output_bin[0, :, :]

            path_to_np = join(args.data_root_dir, 'np_files',
                              img[0][:-3] + 'npz')
            sitk_mask = np.load(path_to_np)
            print('mask')
            gt_data = sitk_mask['mask'][:, :, 0]
            intn_data = sitk_mask['img'][:, :, 0]
            print(gt_data.shape)

            print('Saving Output')

            # Generate image: prediction overlay, GT overlay, and both.
            f, ax = plt.subplots(1, 3, figsize=(15, 5))
            ax[0].imshow(intn_data, alpha=1, cmap='gray')
            ax[0].imshow(output_bin, alpha=0.2, cmap='Reds')
            ax[0].set_title('Predict Mask')
            ax[1].imshow(intn_data, alpha=1, cmap='gray')
            ax[1].imshow(gt_data, alpha=0.2, cmap='Blues')
            ax[1].set_title('True Mask')
            ax[2].imshow(output_bin, alpha=0.3, cmap='Reds')
            ax[2].imshow(gt_data, alpha=0.3, cmap='Blues')
            ax[2].set_title('Comparison')
            f.suptitle(img[0][:-4])
            plt.savefig(join(indiv_fig_dir,
                             img[0][:-4] + '_qual_fig' + '.png'),
                        format='png',
                        bbox_inches='tight')
            plt.close('all')

            row = [img[0][:-4]]
            if args.compute_dice:
                print('Computing Dice')
                dice_arr[i] = dc(output_bin, gt_data)
                print('\tDice: {}'.format(dice_arr[i]))
                row.append(dice_arr[i])
            if args.compute_jaccard:
                print('Computing Jaccard')
                jacc_arr[i] = jc(output_bin, gt_data)
                print('\tJaccard: {}'.format(jacc_arr[i]))
                row.append(jacc_arr[i])
            if args.compute_assd:
                # Previously the ASSD header and average were written but the
                # per-scan value never was, misaligning every row and
                # averaging zeros. Only 'img' and 'mask' are read from the
                # .npz, so no voxel spacing is known and the distance is in
                # pixels.
                # NOTE(review): assumes `assd` is imported alongside dc/jc.
                print('Computing ASSD')
                assd_arr[i] = assd(output_bin, gt_data, connectivity=1)
                print('\tASSD: {}'.format(assd_arr[i]))
                row.append(assd_arr[i])

            writer.writerow(row)

        row = ['Average Scores']
        if args.compute_dice:
            row.append(np.mean(dice_arr))
        if args.compute_jaccard:
            row.append(np.mean(jacc_arr))
        if args.compute_assd:
            row.append(np.mean(assd_arr))
        writer.writerow(row)

    print('Done.')
def test(args, test_list, model_list, net_input_shape):
    if args.weights_path == '':
        weights_path = join(args.check_dir, args.output_name + '_validation_best_model_' + args.time + '.hdf5')
    else:
        weights_path = join(args.data_root_dir, args.weights_path)

    if args.dataset == 'brats':
        RESOLUTION = 240
    elif args.dataset == 'heart':
        RESOLUTION = 320
    else:
        RESOLUTION = 512

    output_dir = join(args.data_root_dir, 'results', args.net, 'split_' + str(args.split_num))
    raw_out_dir = join(output_dir, 'raw_output')
    fin_out_dir = join(output_dir, 'final_output')
    fig_out_dir = join(output_dir, 'qual_figs')
    try:
        makedirs(raw_out_dir)
    except:
        pass
    try:
        makedirs(fin_out_dir)
    except:
        pass
    try:
        makedirs(fig_out_dir)
    except:
        pass

    if len(model_list) > 1:
        eval_model = model_list[1]
    else:
        eval_model = model_list[0]
    try:
        eval_model.load_weights(weights_path)
    except Exception as e:
        print(e)
        assert False, 'Unable to find weights path. Testing with random weights.'
    print_summary(model=eval_model, positions=[.38, .65, .75, 1.])

    # Set up placeholders
    outfile = ''
    if args.compute_dice:
        dice_arr = np.zeros((len(test_list)))
        outfile += 'dice_'
    if args.compute_jaccard:
        jacc_arr = np.zeros((len(test_list)))
        outfile += 'jacc_'
    if args.compute_assd:
        assd_arr = np.zeros((len(test_list)))
        outfile += 'assd_'
    surf_arr = np.zeros((len(test_list)), dtype=str)
    dice2_arr = np.zeros((len(test_list), args.out_classes-1))

    # Testing the network
    print('Testing... This will take some time...')

    with open(join(output_dir, args.save_prefix + outfile + 'scores.csv'), 'wb') as csvfile:
        writer = csv.writer(csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        
        dice_results_csv = open(join(output_dir, args.save_prefix + outfile + 'dice_scores.csv'), 'wb')
        dice_writer = csv.writer(dice_results_csv, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        row = ["Scan Name"]
        for i in range(1,args.out_classes):
            row.append("Dice_{}".format(i))
        dice_writer.writerow(row)

        row = ['Scan Name']
        if args.compute_dice:
            row.append('Dice Coefficient')
        if args.compute_jaccard:
            row.append('Jaccard Index')
        if args.compute_assd:
            row.append('Average Symmetric Surface Distance')

        writer.writerow(row)

        for i, img in enumerate(tqdm(test_list)):
            sitk_img = sitk.ReadImage(join(args.data_root_dir, 'imgs', img[0]))
            img_data = sitk.GetArrayFromImage(sitk_img)
            num_slices = img_data.shape[0]

            if args.dataset == 'brats':
                num_slices = img_data.shape[1]#brats
                img_data = np.rollaxis(img_data,0,4)

            print(args.dataset)

            output_array = eval_model.predict_generator(generate_test_batches(args.data_root_dir, [img],
                                                                              net_input_shape,
                                                                              batchSize=1,
                                                                              numSlices=args.slices,
                                                                              subSampAmt=0,
                                                                              stride=1, dataset = args.dataset, num_output_classes=args.out_classes),
                                                        steps=num_slices, max_queue_size=1, workers=1,
                                                        use_multiprocessing=False, verbose=1)

            print('out' + str(output_array[0].shape))
            if args.net.find('caps') != -1:
                if args.dataset == 'brats':
                    output = output_array[0][:,8:-8,8:-8]
                    recon = output_array[1][:,8:-8,8:-8]
                else:
                    output = output_array[0][:,:,:]
                    recon = output_array[1][:,:,:]
            else:
                if args.dataset == 'brats':
                    output = output_array[:,8:-8,8:-8,:]
                else:
                    output = output_array[:,:,:,:]


            if args.out_classes == 1:
                output_raw = output.reshape(-1,RESOLUTION,RESOLUTION,1) #binary
            else:
                output_raw = output
                output = oneHot2LabelMax(output)

            out = output.astype(np.int64)

            if args.out_classes == 1:
                outputOnehot = out.reshape(-1,RESOLUTION,RESOLUTION,1) #binary
            else:
                outputOnehot = np.eye(args.out_classes)[out].astype(np.uint8)


            output_img = sitk.GetImageFromArray(output)

            print('Segmenting Output')

            output_bin = threshold_mask(output, args.thresh_level)

            output_mask = sitk.GetImageFromArray(output_bin)

            slice_img = sitk.Image(RESOLUTION,RESOLUTION,num_slices, sitk.sitkUInt8)

            output_img.CopyInformation(slice_img)
            output_mask.CopyInformation(slice_img)


            #output_img.CopyInformation(sitk_img)
            #output_mask.CopyInformation(sitk_img)

            print('Saving Output')
            if args.dataset != 'luna':
                sitk.WriteImage(output_img, join(raw_out_dir, img[0][:-7] + '_raw_output' + img[0][-7:]))
                sitk.WriteImage(output_mask, join(fin_out_dir, img[0][:-7] + '_final_output' + img[0][-7:]))

                # Load gt mask
                sitk_mask = sitk.ReadImage(join(args.data_root_dir, 'masks', img[0]))
                gt_data = sitk.GetArrayFromImage(sitk_mask)
                label = gt_data.astype(np.int64)

                if args.out_classes == 1:
                    gtOnehot = label.reshape(-1,RESOLUTION,RESOLUTION,1) #binary
                    gt_label = label
                else:
                    gtOnehot = np.eye(args.out_classes)[label].astype(np.uint8)
                    gt_label = label

                if args.net.find('caps') != -1:
                    create_recon_image(args, recon, img_data, gtOnehot, i=i)
                
                create_activation_image(args, output_raw, gtOnehot, slice_num=output_raw.shape[0] // 2, index=i)
                # Plot Qual Figure
                print('Creating Qualitative Figure for Quick Reference')
                f, ax = plt.subplots(2, 3, figsize=(10, 5))

                colors = ['Greys', 'Greens', 'Reds', 'Blues']
                fileTypeLength = 7

                print(img_data.shape)
                print(outputOnehot.shape)
                if args.dataset == 'brats':
                    #img_data = img_data[3] #caps?
                    img_data = img_data[:,:,:,3] #isensee

                # Prediction plots
                ax[0,0].imshow(img_data[num_slices // 3, :, :], alpha=1, cmap='gray')
                for class_num in range(1, outputOnehot.shape[3]):
                    mask = outputOnehot[num_slices // 3, :, :, class_num]
                    mask = np.ma.masked_where(mask == 0, mask)
                    ax[0,0].imshow(mask, alpha=0.7, cmap=colors[class_num], vmin = 0, vmax = 1)
                ax[0,0].set_title('Slice {}/{}'.format(num_slices // 3, num_slices))
                ax[0,0].axis('off')

                ax[0,1].imshow(img_data[num_slices // 2, :, :], alpha=1, cmap='gray')
                for class_num in range(1, outputOnehot.shape[3]):
                    mask = outputOnehot[num_slices // 2, :, :, class_num]
                    mask = np.ma.masked_where(mask == 0, mask)
                    ax[0,1].imshow(mask, alpha=0.7, cmap=colors[class_num], vmin = 0, vmax = 1)
                ax[0,1].set_title('Slice {}/{}'.format(num_slices // 2, num_slices))
                ax[0,1].axis('off')

                ax[0,2].imshow(img_data[num_slices // 2 + num_slices // 4, :, :], alpha=1, cmap='gray')
                for class_num in range(1, outputOnehot.shape[3]):
                    mask = outputOnehot[num_slices // 2 + num_slices // 4, :, :, class_num]
                    mask = np.ma.masked_where(mask == 0, mask)
                    ax[0,2].imshow(mask, alpha=0.7, cmap=colors[class_num], vmin = 0, vmax = 1)
                ax[0,2].set_title(
                    'Slice {}/{}'.format(num_slices // 2 + num_slices // 4, num_slices))
                ax[0,2].axis('off')

                # Ground truth plots
                ax[1,0].imshow(img_data[num_slices // 3, :, :], alpha=1, cmap='gray')
                ax[1,0].set_title('Slice {}/{}'.format(num_slices // 3, num_slices))
                for class_num in range(1, gtOnehot.shape[3]):
                    mask = gtOnehot[num_slices // 3, :, :, class_num]
                    mask = np.ma.masked_where(mask == 0, mask)
                    ax[1,0].imshow(mask, alpha=0.7, cmap=colors[class_num], vmin = 0, vmax = 1)
                ax[1,0].axis('off')

                ax[1,1].imshow(img_data[num_slices // 2, :, :], alpha=1, cmap='gray')
                ax[1,1].set_title('Slice {}/{}'.format(num_slices // 2, num_slices))
                for class_num in range(1, gtOnehot.shape[3]):
                    mask = gtOnehot[num_slices // 2, :, :, class_num]
                    mask = np.ma.masked_where(mask == 0, mask)
                    ax[1,1].imshow(mask, alpha=0.7, cmap=colors[class_num], vmin = 0, vmax = 1)
                ax[1,1].axis('off')

                ax[1,2].imshow(img_data[num_slices // 2 + num_slices // 4, :, :], alpha=1, cmap='gray')
                ax[1,2].set_title(
                    'Slice {}/{}'.format(num_slices // 2 + num_slices // 4, num_slices))
                for class_num in range(1, gtOnehot.shape[3]):
                    mask = gtOnehot[num_slices // 2 + num_slices // 4, :, :, class_num]
                    mask = np.ma.masked_where(mask == 0, mask)
                    ax[1,2].imshow(mask, alpha=0.7, cmap=colors[class_num], vmin = 0, vmax = 1)
                ax[1,2].axis('off')

                fig = plt.gcf()
                fig.suptitle(img[0][:-fileTypeLength])

                plt.savefig(join(fig_out_dir, img[0][:-fileTypeLength] + '_qual_fig' + '.png'),
                            format='png', bbox_inches='tight')
                plt.close('all')
            else:
                sitk.WriteImage(output_img, join(raw_out_dir, img[0][:-4] + '_raw_output' + img[0][-4:]))
                sitk.WriteImage(output_mask, join(fin_out_dir, img[0][:-4] + '_final_output' + img[0][-4:]))

                # Load gt mask
                sitk_mask = sitk.ReadImage(join(args.data_root_dir, 'masks', img[0]))
                gt_data = sitk.GetArrayFromImage(sitk_mask)

                f, ax = plt.subplots(1, 3, figsize=(15, 5))

                ax[0].imshow(img_data[img_data.shape[0] // 3, :, :], alpha=1, cmap='gray')
                ax[0].imshow(output_bin[img_data.shape[0] // 3, :, :], alpha=0.5, cmap='Reds')
                #ax[0].imshow(gt_data[img_data.shape[0] // 3, :, :], alpha=0.2, cmap='Reds')
                ax[0].set_title('Slice {}/{}'.format(img_data.shape[0] // 3, img_data.shape[0]))
                ax[0].axis('off')

                ax[1].imshow(img_data[img_data.shape[0] // 2, :, :], alpha=1, cmap='gray')
                ax[1].imshow(output_bin[img_data.shape[0] // 2, :, :], alpha=0.5, cmap='Reds')
                #ax[1].imshow(gt_data[img_data.shape[0] // 2, :, :], alpha=0.2, cmap='Reds')
                ax[1].set_title('Slice {}/{}'.format(img_data.shape[0] // 2, img_data.shape[0]))
                ax[1].axis('off')

                ax[2].imshow(img_data[img_data.shape[0] // 2 + img_data.shape[0] // 4, :, :], alpha=1, cmap='gray')
                ax[2].imshow(output_bin[img_data.shape[0] // 2 + img_data.shape[0] // 4, :, :], alpha=0.5,
                             cmap='Reds')
                #ax[2].imshow(gt_data[img_data.shape[0] // 2 + img_data.shape[0] // 4, :, :], alpha=0.2,
                #             cmap='Reds')
                ax[2].set_title(
                    'Slice {}/{}'.format(img_data.shape[0] // 2 + img_data.shape[0] // 4, img_data.shape[0]))
                ax[2].axis('off')

                fig = plt.gcf()
                fig.suptitle(img[0][:-4])

                plt.savefig(join(fig_out_dir, img[0][:-4] + '_qual_fig' + '.png'),
                            format='png', bbox_inches='tight')

                
            output_label = oneHot2LabelMax(outputOnehot)

            row = [img[0][:-4]]
            dice_row = [img[0][:-4]]
            if args.compute_dice:
                print('Computing Dice')
                dice_arr[i] = dc(outputOnehot, gtOnehot)
                print('\tDice: {}'.format(dice_arr[i]))
                row.append(dice_arr[i])
            if args.compute_jaccard:
                print('Computing Jaccard')
                jacc_arr[i] = jaccard(outputOnehot, gtOnehot)
                print('\tJaccard: {}'.format(jacc_arr[i]))
                row.append(jacc_arr[i])
            if args.compute_assd:
                print('Computing ASSD')
                assd_arr[i] = assd(outputOnehot, gtOnehot, voxelspacing=sitk_img.GetSpacing(), connectivity=1)
                print('\tASSD: {}'.format(assd_arr[i]))
                row.append(assd_arr[i])
            try:
                spacing = np.array(sitk_img.GetSpacing())
                if args.dataset == 'brats':
                   spacing = spacing[1:]
                surf = compute_surface_distances(label, out, spacing)
                surf_arr[i] = str(surf)
                assd_score = calc_assd_scores(output_label, gt_label, args.out_classes, spacing)
                print(assd_score)
                print('\tSurface distance ' + str(surf_arr[i]))
            except:
                print("surf failed")
                pass
            #dice2_arr[i] = compute_dice_coefficient(gtOnehot, outputOnehot)
            dice2_arr[i] = calc_dice_scores(output_label, gt_label, args.out_classes)
            for score in dice2_arr[i]:
                dice_row.append(score)
            dice_writer.writerow(dice_row)
            
            print('\tMSD Dice: {}'.format(dice2_arr[i]))
            
            writer.writerow(row)

            
        dice_row = ['Average Scores']
        avgs = np.mean(dice2_arr, axis=0)
        for avg in avgs:
            dice_row.append(avg)
        dice_writer.writerow(dice_row)
            
        
        row = ['Average Scores']
        if args.compute_dice:
            row.append(np.mean(dice_arr))
        if args.compute_jaccard:
            row.append(np.mean(jacc_arr))
        if args.compute_assd:
            row.append(np.mean(assd_arr))
        row.append(surf_arr)
        row.append(np.mean(dice2_arr))
        
        writer.writerow(row)
        dice_results_csv.close()
      

    print('Done.')