def augment_random(image, label):
    """Randomly augment one 2-D sample, applying the same geometry to image and label.

    Args:
        image: array indexed ``[row, col, channel]`` (``slices * modalities``
            channels).
        label: array indexed ``[row, col, class]``. With more than one channel it is
            treated as one-hot over classes; with a single channel it is treated as a
            binary mask.

    Returns:
        ``(image, label)`` with the same spatial size. A multi-class label is
        re-encoded as a hard one-hot vector per pixel (all-zero pixels become class
        0). A single-channel label is re-binarised at 0.5.
    """
    # Elastic deformation happens first (probability 1/5), on image and label together.
    if np.random.randint(0, 5) == 3:
        image, label = elasticDeform2D(image,
                                       label,
                                       alpha=720,
                                       sigma=24,
                                       is_random=True)

    num_sm = image.shape[2]
    num_classes = label.shape[2]
    # Stack image and label channels so every transform below moves both identically.
    img_and_mask = np.concatenate((image, label), axis=2).astype(np.float64)

    # Each transform fires independently with probability 1/20, in this fixed order.
    # Regions exposed by a transform are filled with 0.
    transforms = (
        lambda x: random_rotation(x, rg=180, row_axis=0, col_axis=1,
                                  channel_axis=2, fill_mode='constant', cval=0.),
        lambda x: random_shift(x, wrg=0.2, hrg=0.2, row_axis=0, col_axis=1,
                               channel_axis=2, fill_mode='constant', cval=0.),
        lambda x: random_shear(x, intensity=16, row_axis=0, col_axis=1,
                               channel_axis=2, fill_mode='constant', cval=0.),
        # NOTE(review): zoom_range=(0.8, 0.8) always uses the same zoom factor;
        # confirm a range (e.g. (0.8, 1.2)) was not intended.
        lambda x: random_zoom(x, zoom_range=(0.8, 0.8), row_axis=0, col_axis=1,
                              channel_axis=2, fill_mode='constant', cval=0.),
        lambda x: flip_axis(x, axis=1),
        lambda x: flip_axis(x, axis=0),
    )
    for transform in transforms:
        if np.random.randint(0, 20) == 7:
            img_and_mask = transform(img_and_mask)

    image = img_and_mask[:, :, :num_sm]
    label = img_and_mask[:, :, num_sm:]

    if num_classes == 1:
        # Binary mask: np.eye(1) has a single row, so a one-hot round trip would turn
        # every pixel into 1. Re-binarise the interpolated mask instead.
        label = (label > 0.5).astype(np.float64)
    else:
        # Pixels with no class score at all (e.g. constant-filled borders) become
        # background (class 0).
        background = np.zeros(num_classes)
        background[0] = 1
        label[np.sum(label, axis=-1) == 0] = background
        # Snap interpolated scores back to a hard one-hot encoding.
        label = np.eye(num_classes)[oneHot2LabelMax(label)]

    return image, label
def generate_train_batches(root_path,
                           train_list,
                           net_input_shape,
                           net,
                           batchSize=1,
                           numSlices=1,
                           subSampAmt=-1,
                           stride=1,
                           downSampAmt=1,
                           shuff=1,
                           aug_data=1,
                           dataset='brats',
                           num_output_classes=2):
    """Endlessly yield training batches of 2-D slices stacked with their neighbours.

    Each scan is read from ``<root_path>/np_files/<name>.npz``. If that file cannot
    be loaded, it is rebuilt with the dataset's converter. Slices without any
    foreground label are skipped.

    Args:
        root_path: dataset root (contains ``np_files``; debug figures go to ``logs``).
        train_list: rows whose first element is a scan file name. It is shuffled in
            place every pass when ``shuff`` is truthy.
        net_input_shape: ``(rows, cols, numSlices * modalities)``.
        net: network name. Names containing ``'caps'`` get capsule-style batches.
        batchSize: samples per full batch.
        numSlices: odd number of consecutive slices per sample: the centre slice plus
            ``numSlices // 2`` on each side. Neighbours outside the volume are zero.
        subSampAmt, downSampAmt: unused; kept for interface compatibility.
        stride: step between candidate slice indices.
        shuff: shuffle scans and slice order each pass when truthy.
        aug_data: apply ``augment_random`` to every sample when truthy.
        dataset: one of 'brats', 'heart', 'spleen', 'colon', 'hepatic'.
        num_output_classes: number of mask channels.

    Yields:
        ``(images, masks)``. For capsule nets it yields
        ``([images, fg_mask], [masks, fg_mask * centre_slice_images])``.
        At the end of each pass, any leftover samples are yielded as a smaller
        batch. Yielded arrays are copies of the internal buffers.

    Raises:
        ValueError: unknown ``dataset``, even ``numSlices`` > 1, or a batch buffer
            that is not 4-D.
    """
    print('train ' + str(dataset))
    modalities = net_input_shape[2] // numSlices
    input_slices = numSlices
    if numSlices > 1 and numSlices % 2 == 0:
        # With an even count there is no centre slice, and numSlices // 2 neighbours
        # per side would need numSlices + 1 channel blocks, overflowing the buffer.
        raise ValueError('numSlices must be odd, got {}'.format(numSlices))
    # Integer division: a float here breaks slicing and range() on Python 3.
    side_slices = numSlices // 2

    img_batch = np.zeros((batchSize, ) + tuple(net_input_shape),
                         dtype=np.float32)
    mask_shape = (net_input_shape[0], net_input_shape[1], num_output_classes)
    mask_batch = np.zeros((batchSize, ) + mask_shape, dtype=np.float32)
    if img_batch.ndim != 4:
        raise ValueError(
            'Only 4-D batches (batch, rows, cols, channels) are supported, got '
            '{}-D.'.format(img_batch.ndim))

    if dataset == 'brats':
        np_converter = convert_brats_data_to_numpy
        # 240x240 BraTS slices are placed 8 pixels in from each border of the input.
        frame_pixels_0, frame_pixels_1 = 8, -8
        empty_mask = np.array(
            [one_hot_max, 1 - one_hot_max, 1 - one_hot_max, 1 - one_hot_max])
        raw_x_shape = raw_y_shape = 240
    elif dataset in ('heart', 'spleen', 'colon', 'hepatic'):
        np_converter = get_np_converter(dataset)
        frame_pixels_0, frame_pixels_1 = 0, net_input_shape[0]
        if num_output_classes == 2:
            empty_mask = np.array([one_hot_max, 1 - one_hot_max])
        else:
            empty_mask = np.array([1 - one_hot_max])
        raw_x_shape = net_input_shape[0]
        raw_y_shape = net_input_shape[1]
    else:
        raise ValueError('Dataset not recognized')

    is_binary_classification = num_output_classes == 1

    def _fill_sample(slot, volume, mask, j, z_shape):
        """Write slice j, its neighbours and its mask into buffer position slot."""
        img_batch[slot] = 0
        first = j - side_slices
        lo = max(first, 0)
        hi = min(j + side_slices + 1, z_shape)
        stack = volume[:, :, lo:hi].reshape(raw_x_shape, raw_y_shape, -1)
        # Neighbours outside the volume keep their channel block at zero, so the
        # centre slice always lands in channel block side_slices.
        start = (lo - first) * modalities
        width = (hi - lo) * modalities
        img_batch[slot, frame_pixels_0:frame_pixels_1,
                  frame_pixels_0:frame_pixels_1,
                  start:start + width] = stack[:, :, :width]
        mask_batch[slot] = empty_mask
        mask_batch[slot, frame_pixels_0:frame_pixels_1,
                   frame_pixels_0:frame_pixels_1, :] = mask[:, :, j]

    def _save_debug_figures(j):
        """Save a few input channels and the first mask channel of sample 0 as PNGs."""
        plots = (
            (img_batch, 0, 'train_slice1', {'cmap': 'gray'}),
            (img_batch, 4, 'train_slice2', {'cmap': 'gray'}),
            (img_batch, 8, 'train_slice3_main', {'cmap': 'gray'}),
            (mask_batch, 0, 'train_label', {'alpha': 0.15}),
            (img_batch, 12, 'train_slice4', {'cmap': 'gray'}),
            (img_batch, 16, 'train_slice5', {'cmap': 'gray'}),
        )
        for batch, channel, suffix, style in plots:
            if channel >= batch.shape[-1]:
                continue  # this input configuration has fewer channels
            plt.imshow(np.squeeze(batch[0, :, :, channel]), **style)
            plt.savefig(join(root_path, 'logs', 'ex{}_{}.png'.format(j, suffix)),
                        format='png',
                        bbox_inches='tight')
            plt.close()

    def _make_batch(n):
        """Return the network inputs/targets built from the first n buffered samples.

        Copies are returned because the buffers are overwritten while the next
        batch is filled, possibly before a queued consumer has read this one.
        """
        imgs = img_batch[:n].copy()
        masks = mask_batch[:n].copy()
        if net.find('caps') == -1:
            return imgs, masks
        start_index = (input_slices // 2) * modalities
        centre_imgs = imgs[:, :, :, start_index:start_index + modalities]
        fg_mask = oneHot2LabelMax(masks)
        # Every non-background class becomes part of one foreground mask.
        # Full and trailing partial batches use this same convention.
        fg_mask[fg_mask > 0.5] = 1.0
        fg_mask = np.expand_dims(fg_mask, axis=-1)
        recon_target = np.repeat(fg_mask, modalities, axis=-1) * centre_imgs
        return [imgs, fg_mask], [masks, recon_target]

    while True:
        if shuff:
            shuffle(train_list)
        count = 0
        for scan_row in train_list:
            scan_name = scan_row[0]
            path_to_np = join(root_path, 'np_files',
                              basename(scan_name)[:-6] + 'npz')
            try:
                with np.load(path_to_np) as data:
                    train_img = data['img']
                    train_mask = data['mask']
            except Exception:
                # Missing or unreadable cache: rebuild it from the raw scan.
                train_img, train_mask = np_converter(
                    root_path, scan_name, num_classes=num_output_classes)
                if np.array_equal(train_img, np.zeros(1)):
                    # presumably the converter's "unusable scan" sentinel — confirm
                    continue
                print('\nFinished making npz file.')

            z_shape = train_img.shape[2]
            indices = np.arange(0, z_shape, stride)
            if shuff:
                shuffle(indices)

            for j in indices:
                if is_binary_classification:
                    foreground = train_mask[:, :, j]
                else:
                    foreground = train_mask[:, :, j, 1:]
                if np.sum(foreground) < 1:
                    continue  # no labelled structure on this slice

                _fill_sample(count, train_img, train_mask, j, z_shape)
                if aug_data:
                    img_batch[count], mask_batch[count] = augment_random(
                        img_batch[count], mask_batch[count])
                count += 1
                if count == batchSize:
                    count = 0
                    if debug:
                        _save_debug_figures(j)
                    yield _make_batch(batchSize)

        if count != 0:
            yield _make_batch(count)
def generate_val_batches(root_path,
                         val_list,
                         net_input_shape,
                         net,
                         batchSize=1,
                         numSlices=1,
                         subSampAmt=-1,
                         stride=1,
                         downSampAmt=1,
                         shuff=1,
                         dataset='brats',
                         num_output_classes=2):
    """Endlessly yield validation batches of 2-D slices stacked with their neighbours.

    Each scan is read from ``<root_path>/np_files/<name>.npz``. If that file cannot
    be loaded, it is rebuilt with the dataset's converter. Every slice is used: there
    is no foreground filtering and no augmentation.

    Args:
        root_path: dataset root (contains ``np_files``).
        val_list: rows whose first element is a scan file name. It is shuffled in
            place every pass when ``shuff`` is truthy.
        net_input_shape: ``(rows, cols, numSlices * modalities)``.
        net: network name. Names containing ``'caps'`` get capsule-style batches.
        batchSize: samples per full batch.
        numSlices: odd number of consecutive slices per sample: the centre slice plus
            ``numSlices // 2`` on each side. Neighbours outside the volume are zero.
        subSampAmt, downSampAmt: unused; kept for interface compatibility.
        stride: step between slice indices.
        shuff: shuffle scans and slice order each pass when truthy.
        dataset: one of 'brats', 'heart', 'spleen', 'colon', 'hepatic'.
        num_output_classes: number of mask channels.

    Yields:
        ``(images, masks)``. For capsule nets it yields
        ``([images, fg_mask], [masks, fg_mask * centre_slice_images])``.
        At the end of each pass, any leftover samples are yielded as a smaller
        batch. Yielded arrays are copies of the internal buffers.

    Raises:
        ValueError: unknown ``dataset``, even ``numSlices`` > 1, or a batch buffer
            that is not 4-D.
    """
    modalities = net_input_shape[2] // numSlices
    input_slices = numSlices
    if numSlices > 1 and numSlices % 2 == 0:
        # With an even count there is no centre slice, and numSlices // 2 neighbours
        # per side would need numSlices + 1 channel blocks, overflowing the buffer.
        raise ValueError('numSlices must be odd, got {}'.format(numSlices))
    # Integer division: a float here breaks slicing and range() on Python 3.
    side_slices = numSlices // 2

    img_batch = np.zeros((batchSize, ) + tuple(net_input_shape),
                         dtype=np.float32)
    mask_shape = (net_input_shape[0], net_input_shape[1], num_output_classes)
    mask_batch = np.zeros((batchSize, ) + mask_shape, dtype=np.float32)
    if img_batch.ndim != 4:
        raise ValueError(
            'Only 4-D batches (batch, rows, cols, channels) are supported, got '
            '{}-D.'.format(img_batch.ndim))

    if dataset == 'brats':
        np_converter = convert_brats_data_to_numpy
        # 240x240 BraTS slices are placed 8 pixels in from each border of the input.
        frame_pixels_0, frame_pixels_1 = 8, -8
        empty_mask = np.array(
            [one_hot_max, 1 - one_hot_max, 1 - one_hot_max, 1 - one_hot_max])
        raw_x_shape = raw_y_shape = 240
    elif dataset in ('heart', 'spleen', 'colon', 'hepatic'):
        np_converter = get_np_converter(dataset)
        frame_pixels_0, frame_pixels_1 = 0, net_input_shape[0]
        if num_output_classes == 2:
            empty_mask = np.array([one_hot_max, 1 - one_hot_max])
        else:
            empty_mask = np.array([1 - one_hot_max])
        raw_x_shape = net_input_shape[0]
        raw_y_shape = net_input_shape[1]
    else:
        raise ValueError('Dataset not recognized')

    def _fill_sample(slot, volume, mask, j, z_shape):
        """Write slice j, its neighbours and its mask into buffer position slot."""
        img_batch[slot] = 0
        first = j - side_slices
        lo = max(first, 0)
        hi = min(j + side_slices + 1, z_shape)
        stack = volume[:, :, lo:hi].reshape(raw_x_shape, raw_y_shape, -1)
        # Neighbours outside the volume keep their channel block at zero, so the
        # centre slice always lands in channel block side_slices.
        start = (lo - first) * modalities
        width = (hi - lo) * modalities
        img_batch[slot, frame_pixels_0:frame_pixels_1,
                  frame_pixels_0:frame_pixels_1,
                  start:start + width] = stack[:, :, :width]
        mask_batch[slot] = empty_mask
        mask_batch[slot, frame_pixels_0:frame_pixels_1,
                   frame_pixels_0:frame_pixels_1, :] = mask[:, :, j]

    def _make_batch(n):
        """Return the network inputs/targets built from the first n buffered samples.

        Copies are returned because the buffers are overwritten while the next
        batch is filled, possibly before a queued consumer has read this one.
        """
        imgs = img_batch[:n].copy()
        masks = mask_batch[:n].copy()
        if net.find('caps') == -1:
            return imgs, masks
        start_index = (input_slices // 2) * modalities
        centre_imgs = imgs[:, :, :, start_index:start_index + modalities]
        fg_mask = oneHot2LabelMax(masks)
        # Every non-background class becomes part of one foreground mask.
        # NOTE(review): the mask input uses the foreground convention (not 1 - mask),
        # so it matches the foreground-masked reconstruction target; confirm this
        # is what the capsule model expects.
        fg_mask[fg_mask > 0.5] = 1.0
        fg_mask = np.expand_dims(fg_mask, axis=-1)
        recon_target = np.repeat(fg_mask, modalities, axis=-1) * centre_imgs
        return [imgs, fg_mask], [masks, recon_target]

    while True:
        if shuff:
            shuffle(val_list)
        count = 0
        for scan_row in val_list:
            scan_name = scan_row[0]
            path_to_np = join(root_path, 'np_files',
                              basename(scan_name)[:-6] + 'npz')
            try:
                with np.load(path_to_np) as data:
                    val_img = data['img']
                    val_mask = data['mask']
            except Exception:
                # Missing or unreadable cache: rebuild it from the raw scan.
                print(
                    '\nPre-made numpy array not found for {}.\nCreating now...'
                    .format(scan_name[:-7]))
                val_img, val_mask = np_converter(
                    root_path, scan_name, num_classes=num_output_classes)
                if np.array_equal(val_img, np.zeros(1)):
                    # presumably the converter's "unusable scan" sentinel — confirm
                    continue
                print('\nFinished making npz file.')

            z_shape = val_img.shape[2]
            indices = np.arange(0, z_shape, stride)
            if shuff:
                shuffle(indices)

            for j in indices:
                _fill_sample(count, val_img, val_mask, j, z_shape)
                count += 1
                if count == batchSize:
                    count = 0
                    yield _make_batch(batchSize)

        if count != 0:
            yield _make_batch(count)
def test(args, test_list, model_list, net_input_shape):
    if args.weights_path == '':
        weights_path = join(args.check_dir, args.output_name + '_validation_best_model_' + args.time + '.hdf5')
    else:
        weights_path = join(args.data_root_dir, args.weights_path)

    if args.dataset == 'brats':
        RESOLUTION = 240
    elif args.dataset == 'heart':
        RESOLUTION = 320
    else:
        RESOLUTION = 512

    output_dir = join(args.data_root_dir, 'results', args.net, 'split_' + str(args.split_num))
    raw_out_dir = join(output_dir, 'raw_output')
    fin_out_dir = join(output_dir, 'final_output')
    fig_out_dir = join(output_dir, 'qual_figs')
    try:
        makedirs(raw_out_dir)
    except:
        pass
    try:
        makedirs(fin_out_dir)
    except:
        pass
    try:
        makedirs(fig_out_dir)
    except:
        pass

    if len(model_list) > 1:
        eval_model = model_list[1]
    else:
        eval_model = model_list[0]
    try:
        eval_model.load_weights(weights_path)
    except Exception as e:
        print(e)
        assert False, 'Unable to find weights path. Testing with random weights.'
    print_summary(model=eval_model, positions=[.38, .65, .75, 1.])

    # Set up placeholders
    outfile = ''
    if args.compute_dice:
        dice_arr = np.zeros((len(test_list)))
        outfile += 'dice_'
    if args.compute_jaccard:
        jacc_arr = np.zeros((len(test_list)))
        outfile += 'jacc_'
    if args.compute_assd:
        assd_arr = np.zeros((len(test_list)))
        outfile += 'assd_'
    surf_arr = np.zeros((len(test_list)), dtype=str)
    dice2_arr = np.zeros((len(test_list), args.out_classes-1))

    # Testing the network
    print('Testing... This will take some time...')

    with open(join(output_dir, args.save_prefix + outfile + 'scores.csv'), 'wb') as csvfile:
        writer = csv.writer(csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        
        dice_results_csv = open(join(output_dir, args.save_prefix + outfile + 'dice_scores.csv'), 'wb')
        dice_writer = csv.writer(dice_results_csv, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        row = ["Scan Name"]
        for i in range(1,args.out_classes):
            row.append("Dice_{}".format(i))
        dice_writer.writerow(row)

        row = ['Scan Name']
        if args.compute_dice:
            row.append('Dice Coefficient')
        if args.compute_jaccard:
            row.append('Jaccard Index')
        if args.compute_assd:
            row.append('Average Symmetric Surface Distance')

        writer.writerow(row)

        for i, img in enumerate(tqdm(test_list)):
            sitk_img = sitk.ReadImage(join(args.data_root_dir, 'imgs', img[0]))
            img_data = sitk.GetArrayFromImage(sitk_img)
            num_slices = img_data.shape[0]

            if args.dataset == 'brats':
                num_slices = img_data.shape[1]#brats
                img_data = np.rollaxis(img_data,0,4)

            print(args.dataset)

            output_array = eval_model.predict_generator(generate_test_batches(args.data_root_dir, [img],
                                                                              net_input_shape,
                                                                              batchSize=1,
                                                                              numSlices=args.slices,
                                                                              subSampAmt=0,
                                                                              stride=1, dataset = args.dataset, num_output_classes=args.out_classes),
                                                        steps=num_slices, max_queue_size=1, workers=1,
                                                        use_multiprocessing=False, verbose=1)

            print('out' + str(output_array[0].shape))
            if args.net.find('caps') != -1:
                if args.dataset == 'brats':
                    output = output_array[0][:,8:-8,8:-8]
                    recon = output_array[1][:,8:-8,8:-8]
                else:
                    output = output_array[0][:,:,:]
                    recon = output_array[1][:,:,:]
            else:
                if args.dataset == 'brats':
                    output = output_array[:,8:-8,8:-8,:]
                else:
                    output = output_array[:,:,:,:]


            if args.out_classes == 1:
                output_raw = output.reshape(-1,RESOLUTION,RESOLUTION,1) #binary
            else:
                output_raw = output
                output = oneHot2LabelMax(output)

            out = output.astype(np.int64)

            if args.out_classes == 1:
                outputOnehot = out.reshape(-1,RESOLUTION,RESOLUTION,1) #binary
            else:
                outputOnehot = np.eye(args.out_classes)[out].astype(np.uint8)


            output_img = sitk.GetImageFromArray(output)

            print('Segmenting Output')

            output_bin = threshold_mask(output, args.thresh_level)

            output_mask = sitk.GetImageFromArray(output_bin)

            slice_img = sitk.Image(RESOLUTION,RESOLUTION,num_slices, sitk.sitkUInt8)

            output_img.CopyInformation(slice_img)
            output_mask.CopyInformation(slice_img)


            #output_img.CopyInformation(sitk_img)
            #output_mask.CopyInformation(sitk_img)

            print('Saving Output')
            if args.dataset != 'luna':
                sitk.WriteImage(output_img, join(raw_out_dir, img[0][:-7] + '_raw_output' + img[0][-7:]))
                sitk.WriteImage(output_mask, join(fin_out_dir, img[0][:-7] + '_final_output' + img[0][-7:]))

                # Load gt mask
                sitk_mask = sitk.ReadImage(join(args.data_root_dir, 'masks', img[0]))
                gt_data = sitk.GetArrayFromImage(sitk_mask)
                label = gt_data.astype(np.int64)

                if args.out_classes == 1:
                    gtOnehot = label.reshape(-1,RESOLUTION,RESOLUTION,1) #binary
                    gt_label = label
                else:
                    gtOnehot = np.eye(args.out_classes)[label].astype(np.uint8)
                    gt_label = label

                if args.net.find('caps') != -1:
                    create_recon_image(args, recon, img_data, gtOnehot, i=i)
                
                create_activation_image(args, output_raw, gtOnehot, slice_num=output_raw.shape[0] // 2, index=i)
                # Plot Qual Figure
                print('Creating Qualitative Figure for Quick Reference')
                f, ax = plt.subplots(2, 3, figsize=(10, 5))

                colors = ['Greys', 'Greens', 'Reds', 'Blues']
                fileTypeLength = 7

                print(img_data.shape)
                print(outputOnehot.shape)
                if args.dataset == 'brats':
                    #img_data = img_data[3] #caps?
                    img_data = img_data[:,:,:,3] #isensee

                # Prediction plots
                ax[0,0].imshow(img_data[num_slices // 3, :, :], alpha=1, cmap='gray')
                for class_num in range(1, outputOnehot.shape[3]):
                    mask = outputOnehot[num_slices // 3, :, :, class_num]
                    mask = np.ma.masked_where(mask == 0, mask)
                    ax[0,0].imshow(mask, alpha=0.7, cmap=colors[class_num], vmin = 0, vmax = 1)
                ax[0,0].set_title('Slice {}/{}'.format(num_slices // 3, num_slices))
                ax[0,0].axis('off')

                ax[0,1].imshow(img_data[num_slices // 2, :, :], alpha=1, cmap='gray')
                for class_num in range(1, outputOnehot.shape[3]):
                    mask = outputOnehot[num_slices // 2, :, :, class_num]
                    mask = np.ma.masked_where(mask == 0, mask)
                    ax[0,1].imshow(mask, alpha=0.7, cmap=colors[class_num], vmin = 0, vmax = 1)
                ax[0,1].set_title('Slice {}/{}'.format(num_slices // 2, num_slices))
                ax[0,1].axis('off')

                ax[0,2].imshow(img_data[num_slices // 2 + num_slices // 4, :, :], alpha=1, cmap='gray')
                for class_num in range(1, outputOnehot.shape[3]):
                    mask = outputOnehot[num_slices // 2 + num_slices // 4, :, :, class_num]
                    mask = np.ma.masked_where(mask == 0, mask)
                    ax[0,2].imshow(mask, alpha=0.7, cmap=colors[class_num], vmin = 0, vmax = 1)
                ax[0,2].set_title(
                    'Slice {}/{}'.format(num_slices // 2 + num_slices // 4, num_slices))
                ax[0,2].axis('off')

                # Ground truth plots
                ax[1,0].imshow(img_data[num_slices // 3, :, :], alpha=1, cmap='gray')
                ax[1,0].set_title('Slice {}/{}'.format(num_slices // 3, num_slices))
                for class_num in range(1, gtOnehot.shape[3]):
                    mask = gtOnehot[num_slices // 3, :, :, class_num]
                    mask = np.ma.masked_where(mask == 0, mask)
                    ax[1,0].imshow(mask, alpha=0.7, cmap=colors[class_num], vmin = 0, vmax = 1)
                ax[1,0].axis('off')

                ax[1,1].imshow(img_data[num_slices // 2, :, :], alpha=1, cmap='gray')
                ax[1,1].set_title('Slice {}/{}'.format(num_slices // 2, num_slices))
                for class_num in range(1, gtOnehot.shape[3]):
                    mask = gtOnehot[num_slices // 2, :, :, class_num]
                    mask = np.ma.masked_where(mask == 0, mask)
                    ax[1,1].imshow(mask, alpha=0.7, cmap=colors[class_num], vmin = 0, vmax = 1)
                ax[1,1].axis('off')

                ax[1,2].imshow(img_data[num_slices // 2 + num_slices // 4, :, :], alpha=1, cmap='gray')
                ax[1,2].set_title(
                    'Slice {}/{}'.format(num_slices // 2 + num_slices // 4, num_slices))
                for class_num in range(1, gtOnehot.shape[3]):
                    mask = gtOnehot[num_slices // 2 + num_slices // 4, :, :, class_num]
                    mask = np.ma.masked_where(mask == 0, mask)
                    ax[1,2].imshow(mask, alpha=0.7, cmap=colors[class_num], vmin = 0, vmax = 1)
                ax[1,2].axis('off')

                fig = plt.gcf()
                fig.suptitle(img[0][:-fileTypeLength])

                plt.savefig(join(fig_out_dir, img[0][:-fileTypeLength] + '_qual_fig' + '.png'),
                            format='png', bbox_inches='tight')
                plt.close('all')
            else:
                sitk.WriteImage(output_img, join(raw_out_dir, img[0][:-4] + '_raw_output' + img[0][-4:]))
                sitk.WriteImage(output_mask, join(fin_out_dir, img[0][:-4] + '_final_output' + img[0][-4:]))

                # Load gt mask
                sitk_mask = sitk.ReadImage(join(args.data_root_dir, 'masks', img[0]))
                gt_data = sitk.GetArrayFromImage(sitk_mask)

                f, ax = plt.subplots(1, 3, figsize=(15, 5))

                ax[0].imshow(img_data[img_data.shape[0] // 3, :, :], alpha=1, cmap='gray')
                ax[0].imshow(output_bin[img_data.shape[0] // 3, :, :], alpha=0.5, cmap='Reds')
                #ax[0].imshow(gt_data[img_data.shape[0] // 3, :, :], alpha=0.2, cmap='Reds')
                ax[0].set_title('Slice {}/{}'.format(img_data.shape[0] // 3, img_data.shape[0]))
                ax[0].axis('off')

                ax[1].imshow(img_data[img_data.shape[0] // 2, :, :], alpha=1, cmap='gray')
                ax[1].imshow(output_bin[img_data.shape[0] // 2, :, :], alpha=0.5, cmap='Reds')
                #ax[1].imshow(gt_data[img_data.shape[0] // 2, :, :], alpha=0.2, cmap='Reds')
                ax[1].set_title('Slice {}/{}'.format(img_data.shape[0] // 2, img_data.shape[0]))
                ax[1].axis('off')

                ax[2].imshow(img_data[img_data.shape[0] // 2 + img_data.shape[0] // 4, :, :], alpha=1, cmap='gray')
                ax[2].imshow(output_bin[img_data.shape[0] // 2 + img_data.shape[0] // 4, :, :], alpha=0.5,
                             cmap='Reds')
                #ax[2].imshow(gt_data[img_data.shape[0] // 2 + img_data.shape[0] // 4, :, :], alpha=0.2,
                #             cmap='Reds')
                ax[2].set_title(
                    'Slice {}/{}'.format(img_data.shape[0] // 2 + img_data.shape[0] // 4, img_data.shape[0]))
                ax[2].axis('off')

                fig = plt.gcf()
                fig.suptitle(img[0][:-4])

                plt.savefig(join(fig_out_dir, img[0][:-4] + '_qual_fig' + '.png'),
                            format='png', bbox_inches='tight')

                
            output_label = oneHot2LabelMax(outputOnehot)

            row = [img[0][:-4]]
            dice_row = [img[0][:-4]]
            if args.compute_dice:
                print('Computing Dice')
                dice_arr[i] = dc(outputOnehot, gtOnehot)
                print('\tDice: {}'.format(dice_arr[i]))
                row.append(dice_arr[i])
            if args.compute_jaccard:
                print('Computing Jaccard')
                jacc_arr[i] = jaccard(outputOnehot, gtOnehot)
                print('\tJaccard: {}'.format(jacc_arr[i]))
                row.append(jacc_arr[i])
            if args.compute_assd:
                print('Computing ASSD')
                assd_arr[i] = assd(outputOnehot, gtOnehot, voxelspacing=sitk_img.GetSpacing(), connectivity=1)
                print('\tASSD: {}'.format(assd_arr[i]))
                row.append(assd_arr[i])
            try:
                spacing = np.array(sitk_img.GetSpacing())
                if args.dataset == 'brats':
                   spacing = spacing[1:]
                surf = compute_surface_distances(label, out, spacing)
                surf_arr[i] = str(surf)
                assd_score = calc_assd_scores(output_label, gt_label, args.out_classes, spacing)
                print(assd_score)
                print('\tSurface distance ' + str(surf_arr[i]))
            except:
                print("surf failed")
                pass
            #dice2_arr[i] = compute_dice_coefficient(gtOnehot, outputOnehot)
            dice2_arr[i] = calc_dice_scores(output_label, gt_label, args.out_classes)
            for score in dice2_arr[i]:
                dice_row.append(score)
            dice_writer.writerow(dice_row)
            
            print('\tMSD Dice: {}'.format(dice2_arr[i]))
            
            writer.writerow(row)

            
        dice_row = ['Average Scores']
        avgs = np.mean(dice2_arr, axis=0)
        for avg in avgs:
            dice_row.append(avg)
        dice_writer.writerow(dice_row)
            
        
        row = ['Average Scores']
        if args.compute_dice:
            row.append(np.mean(dice_arr))
        if args.compute_jaccard:
            row.append(np.mean(jacc_arr))
        if args.compute_assd:
            row.append(np.mean(assd_arr))
        row.append(surf_arr)
        row.append(np.mean(dice2_arr))
        
        writer.writerow(row)
        dice_results_csv.close()
      

    print('Done.')