Example #1
0
def convert_data_to_numpy(root_path,
                          img_name,
                          no_masks=False,
                          overwrite=False):
    """Load an image (and optionally its mask) as NumPy arrays, caching to .npz.

    Reads ``<root_path>/imgs/<img_name>`` and, unless ``no_masks`` is set, the
    same-named file under ``<root_path>/masks``, converts them with
    ``convert_img_data`` / ``convert_mask_data`` and caches the result in
    ``<root_path>/np_files/<stem>.npz``. An existing cache file is reused
    unless ``overwrite`` is True.

    Args:
        root_path: dataset root directory containing ``imgs`` and ``masks``.
        img_name: image file name; its last four characters (e.g. ``.png``)
            are stripped to build the cache file name.
        no_masks: if True, only the image is loaded, cached and returned.
        overwrite: if True, ignore any existing cache file and rebuild it.

    Returns:
        ``img`` when ``no_masks`` is True, otherwise ``(img, mask)``. If the
        source files cannot be processed, the same shape of result is
        returned with ``np.zeros(1)`` in place of each array.
    """
    fname = img_name[:-4]
    numpy_path = join(root_path, 'np_files')
    img_path = join(root_path, 'imgs')
    mask_path = join(root_path, 'masks')
    fig_path = join(root_path, 'figs')
    for path in (numpy_path, fig_path):
        try:
            makedirs(path)
        except OSError:
            # Best effort: usually the directory already exists.
            pass
    npz_file = join(numpy_path, fname + '.npz')

    if not overwrite:
        try:
            with np.load(npz_file) as data:
                # Honour no_masks so cached and fresh results have one shape.
                if no_masks:
                    return data['img']
                return data['img'], data['mask']
        except Exception:
            # Missing/unreadable cache, or one written by a no_masks call
            # (no 'mask' entry): rebuild from the source images below.
            pass

    try:
        img = np.array(Image.open(join(img_path, img_name)))
        # Convert image to 3 dimensions
        img = convert_img_data(img, 3)

        if no_masks:
            np.savez_compressed(npz_file, img=img)
            return img

        # Replace SimpleITK to PILLOW for 2D image support on Raspberry Pi
        mask = np.array(Image.open(join(mask_path, img_name)))
        mask = convert_mask_data(mask)
        np.savez_compressed(npz_file, img=img, mask=mask)
        return img, mask

    except Exception as e:
        print('\n' + '-' * 100)
        print('Unable to load img or masks for {}'.format(fname))
        print(e)
        print('Skipping file')
        print('-' * 100 + '\n')

        # Keep the failure result the same shape as the success result.
        if no_masks:
            return np.zeros(1)
        return np.zeros(1), np.zeros(1)
Example #2
0
def generate_test_image(test_img,
                        net_input_shape,
                        batchSize=1,
                        numSlices=1,
                        subSampAmt=0,
                        stride=1,
                        downSampAmt=1):
    """Yield one test image, converted to 4 dimensions, for prediction.

    This is a single-item generator. Only ``test_img`` is used. The other
    parameters are accepted and ignored; NOTE(review): presumably so this
    matches the signature of the other batch generators, confirm against
    callers.

    Args:
        test_img: numpy.array of image data, (height, width, channels).

    Yields:
        The result of ``convert_img_data(test_img, 4)``.
    """
    logging.info('\nload_2D_data.generate_test_image')
    yield convert_img_data(test_img, 4)
Example #3
0
def _ensure_output_dirs(*paths):
    """Create each directory in ``paths``, tolerating ones that already exist.

    Best effort: ``OSError`` (typically "already exists") is ignored, but
    non-OS errors and ``KeyboardInterrupt`` are no longer swallowed.
    """
    for path in paths:
        try:
            makedirs(path)
        except OSError:
            pass


def _itk_spacing_in_array_order(sitk_img, ndim):
    """Return ``sitk_img``'s spacing ordered like the axes of its NumPy array.

    SimpleITK's ``GetSpacing()`` is ordered (x, y[, z]), whereas
    ``GetArrayFromImage()`` indexes (z, y, x), so the tuple is reversed
    before being paired with array axes. If the evaluated array has more axes
    than the image (e.g. a 2D image evaluated as a ``(1, H, W)`` stack),
    unit spacing is prepended for the extra leading axes.
    """
    spacing = tuple(reversed(sitk_img.GetSpacing()))
    return (1.0,) * max(0, ndim - len(spacing)) + spacing


def _save_qualitative_figure(img_data, output_bin, gt_data, title, out_path):
    """Save a 3-panel overlay of image, prediction and ground truth.

    Panels show slices ``n//3``, ``n//2`` and ``n//2 + n//4`` of the first
    axis (``n = img_data.shape[0]``). Each panel shows the image in gray, the
    binary prediction in blue (alpha 0.5) and the ground truth in red
    (alpha 0.2). The figure is closed after saving so repeated calls do not
    accumulate open figures.
    """
    n_slices = img_data.shape[0]
    slice_ids = (n_slices // 3, n_slices // 2, n_slices // 2 + n_slices // 4)
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, idx in zip(axes, slice_ids):
        ax.imshow(img_data[idx, :, :], alpha=1, cmap='gray')
        ax.imshow(output_bin[idx, :, :], alpha=0.5, cmap='Blues')
        ax.imshow(gt_data[idx, :, :], alpha=0.2, cmap='Reds')
        ax.set_title('Slice {}/{}'.format(idx, n_slices))
        ax.axis('off')
    fig.suptitle(title)
    fig.savefig(out_path, format='png', bbox_inches='tight')
    plt.close(fig)


def test(args, test_list, model_list, net_input_shape):
    """Evaluate a network on ``test_list`` and write predictions and scores.

    For every scan this function does the following:
      * runs prediction through the dataset's test generator;
      * thresholds the output;
      * saves the raw and binarized predictions;
      * saves a qualitative overlay figure (datasets other than ``mscoco17``);
      * appends the requested metrics (Dice / Jaccard / ASSD) to a CSV file.
    A final row holds the per-metric averages. Everything is written under
    ``<data_root_dir>/results/<net>/split_<split_num>/``.

    Args:
        args: run configuration. Attributes read here: ``weights_path``,
            ``check_dir``, ``output_name``, ``time``, ``data_root_dir``,
            ``net``, ``split_num``, ``compute_dice``, ``compute_jaccard``,
            ``compute_assd``, ``save_prefix``, ``dataset``, ``batch_size``,
            ``slices``, ``use_multiprocessing`` and ``thresh_level``.
        test_list: sequence of rows whose first element is a file name present
            under both ``<data_root_dir>/imgs`` and ``<data_root_dir>/masks``.
        model_list: models; the second one is evaluated when more than one is
            given, otherwise the first.
        net_input_shape: forwarded to the test batch generator.
    """
    if args.weights_path == '':
        weights_path = join(args.check_dir,
                            args.output_name + '_model_' + args.time + '.hdf5')
    else:
        weights_path = join(args.data_root_dir, args.weights_path)

    output_dir = join(args.data_root_dir, 'results', args.net,
                      'split_' + str(args.split_num))
    raw_out_dir = join(output_dir, 'raw_output')
    fin_out_dir = join(output_dir, 'final_output')
    fig_out_dir = join(output_dir, 'qual_figs')
    _ensure_output_dirs(raw_out_dir, fin_out_dir, fig_out_dir)

    eval_model = model_list[1] if len(model_list) > 1 else model_list[0]
    logging.info('\nWeights_path=%s', weights_path)
    try:
        eval_model.load_weights(weights_path)
    except Exception:
        # Best effort, as before: evaluation continues with current weights.
        logging.warning(
            '\nUnable to find weights path. Testing with random weights.')
    print_summary(model=eval_model, positions=[.38, .65, .75, 1.])

    # Set up per-scan metric arrays and the CSV file-name prefix.
    n_scans = len(test_list)
    outfile = ''
    if args.compute_dice:
        dice_arr = np.zeros(n_scans)
        outfile += 'dice_'
    if args.compute_jaccard:
        jacc_arr = np.zeros(n_scans)
        outfile += 'jacc_'
    if args.compute_assd:
        assd_arr = np.zeros(n_scans)
        outfile += 'assd_'

    logging.info('\nTesting... This will take some time...')

    # newline='' is required by the csv module to avoid blank rows on Windows.
    with open(join(output_dir, args.save_prefix + outfile + 'scores.csv'),
              'w', newline='') as csvfile:
        writer = csv.writer(csvfile,
                            delimiter=',',
                            quotechar='|',
                            quoting=csv.QUOTE_MINIMAL)

        header = ['Scan Name']
        if args.compute_dice:
            header.append('Dice Coefficient')
        if args.compute_jaccard:
            header.append('Jaccard Index')
        if args.compute_assd:
            header.append('Average Symmetric Surface Distance')
        writer.writerow(header)

        # The generator factory depends only on the dataset; look it up once.
        _, _, generate_test_batches = get_generator(args.dataset)

        for i, scan in enumerate(test_list):
            scan_name = scan[0]
            stem, ext = scan_name[:-4], scan_name[-4:]

            sitk_img = sitk.ReadImage(join(args.data_root_dir, 'imgs', scan_name))
            # 3d: (slices, 512, 512), 2d: (512, 512, channels=4)
            img_data = sitk.GetArrayFromImage(sitk_img)

            # Change RGB to single slice of grayscale image for MS COCO 17 dataset.
            if args.dataset == 'mscoco17':
                img_data = convert_img_data(img_data, 3)

            # NOTE(review): only one generator step is predicted per scan, yet
            # the 3D figure below indexes up to img_data.shape[0] slices of the
            # output; confirm the 3D generator yields a whole volume per step.
            num_slices = 1
            logging.info('\ntest.test: eval_model.predict_generator')
            output_array = eval_model.predict_generator(
                generate_test_batches(args.data_root_dir, [scan],
                                      net_input_shape,
                                      batchSize=args.batch_size,
                                      numSlices=args.slices,
                                      subSampAmt=0,
                                      stride=1),
                steps=num_slices,
                max_queue_size=1,
                workers=4,
                use_multiprocessing=args.use_multiprocessing,
                verbose=1)
            logging.info('\ntest.test: output_array=%s', output_array)

            if 'caps' in args.net:
                # Capsule nets return [mask, recon]; keep the mask image.
                output = output_array[0][:, :, :, 0]
            else:
                output = output_array[:, :, :, 0]
            # output: (slices, H, W)

            output_img = sitk.GetImageFromArray(output)
            print('Segmenting Output')
            output_bin = threshold_mask(output, args.thresh_level)
            output_mask = sitk.GetImageFromArray(output_bin)

            if args.dataset == 'luna16':
                output_img.CopyInformation(sitk_img)
                output_mask.CopyInformation(sitk_img)

                print('Saving Output')
                sitk.WriteImage(output_img,
                                join(raw_out_dir, stem + '_raw_output' + ext))
                sitk.WriteImage(output_mask,
                                join(fin_out_dir, stem + '_final_output' + ext))
            else:  # MS COCO 17: save the single slice as an image file.
                # Saved directly with a gray colormap. The old plt.imshow()
                # calls only opened (and leaked) figures and never affected
                # what imsave wrote.
                plt.imsave(join(raw_out_dir, stem + '_raw_output' + ext),
                           output[0, :, :], cmap='gray')
                plt.imsave(join(fin_out_dir, stem + '_final_output' + ext),
                           output_bin[0, :, :], cmap='gray')

            # Load ground-truth mask; 3d gt_data is (slices, 512, 512).
            sitk_mask = sitk.ReadImage(join(args.data_root_dir, 'masks', scan_name))
            gt_data = sitk.GetArrayFromImage(sitk_mask)

            if args.dataset == 'mscoco17':
                gt_data = convert_mask_data(gt_data)
                # Reshape from 2 to 3 dimensions (slices, height, width).
                gt_data = gt_data.reshape(
                    [1, gt_data.shape[0], gt_data.shape[1]])
            else:  # 3D data: qualitative figure for quick reference.
                print('Creating Qualitative Figure for Quick Reference')
                _save_qualitative_figure(
                    img_data, output_bin, gt_data, stem,
                    join(fig_out_dir, stem + '_qual_fig' + '.png'))

            # Compute metrics
            row = [stem]
            if args.compute_dice:
                logging.info('\nComputing Dice')
                dice_arr[i] = dc(output_bin, gt_data)
                logging.info('\tDice: %s', dice_arr[i])
                row.append(dice_arr[i])
            if args.compute_jaccard:
                logging.info('\nComputing Jaccard')
                jacc_arr[i] = jc(output_bin, gt_data)
                logging.info('\tJaccard: %s', jacc_arr[i])
                row.append(jacc_arr[i])
            if args.compute_assd:
                logging.info('\nComputing ASSD')
                assd_arr[i] = assd(
                    output_bin,
                    gt_data,
                    voxelspacing=_itk_spacing_in_array_order(
                        sitk_img, np.ndim(output_bin)),
                    connectivity=1)
                logging.info('\tASSD: %s', assd_arr[i])
                row.append(assd_arr[i])

            writer.writerow(row)

        row = ['Average Scores']
        if args.compute_dice:
            row.append(np.mean(dice_arr))
        if args.compute_jaccard:
            row.append(np.mean(jacc_arr))
        if args.compute_assd:
            row.append(np.mean(assd_arr))
        writer.writerow(row)

    print('Done.')
def convert_data_to_numpy(root_path,
                          img_name,
                          no_masks=False,
                          overwrite=False):
    """Load an image (and optionally its mask) as NumPy arrays, caching to .npz.

    Reads ``<root_path>/imgs/<img_name>`` and, unless ``no_masks`` is set, the
    same-named file under ``<root_path>/masks``. The image goes through
    ``convert_img_data(..., 3)``. The mask goes through ``convert_mask_data``
    followed by ``custom_background`` (both defined elsewhere). The result is
    cached in ``<root_path>/np_files/<stem>.npz``, and an existing cache file
    is reused unless ``overwrite`` is True.

    Args:
        root_path: dataset root directory containing ``imgs`` and ``masks``.
        img_name: image file name; its last four characters (e.g. ``.png``)
            are stripped to build the cache file name.
        no_masks: if True, only the image is loaded, cached and returned.
        overwrite: if True, ignore any existing cache file and rebuild it.

    Returns:
        ``img`` when ``no_masks`` is True, otherwise ``(img, mask)``. If the
        source files cannot be processed, the same shape of result is
        returned with ``np.zeros(1)`` in place of each array.
    """
    fname = img_name[:-4]
    numpy_path = join(root_path, 'np_files')
    img_path = join(root_path, 'imgs')
    mask_path = join(root_path, 'masks')
    fig_path = join(root_path, 'figs')
    for path in (numpy_path, fig_path):
        try:
            makedirs(path)
        except OSError:
            # Best effort: usually the directory already exists.
            pass
    npz_file = join(numpy_path, fname + '.npz')

    if not overwrite:
        try:
            with np.load(npz_file) as data:
                # Honour no_masks so cached and fresh results have one shape.
                if no_masks:
                    return data['img']
                return data['img'], data['mask']
        except Exception:
            # Missing/unreadable cache, or one written by a no_masks call
            # (no 'mask' entry): rebuild from the source images below.
            pass

    try:
        img = np.array(Image.open(join(img_path, img_name)))
        # Convert image to 3 dimensions
        img = convert_img_data(img, 3)

        if no_masks:
            np.savez_compressed(npz_file, img=img)
            return img

        # Replace SimpleITK to PILLOW for 2D image support on Raspberry Pi
        mask = np.array(Image.open(join(mask_path, img_name)))
        mask = convert_mask_data(mask)
        mask = custom_background(mask)
        np.savez_compressed(npz_file, img=img, mask=mask)
        return img, mask

    except Exception as e:
        print('\n' + '-' * 100)
        print('Unable to load img or masks for {}'.format(fname))
        print(e)
        print('Skipping file')
        print('-' * 100 + '\n')

        # Keep the failure result the same shape as the success result.
        if no_masks:
            return np.zeros(1)
        return np.zeros(1), np.zeros(1)