Esempio n. 1
0
def build_iscat_fluo_training(iscat_filepaths, fluo_filepaths):
    """Creates iSCAT training data and targets in DATA_PATH/iscat_fluo/[REF_FRAMES / MASKS] for the
    iSCAT cell segmentation task, using fluorescence (TIRF) images as the source of the targets.

        ARGS:
            iscat_filepaths (list(str)): filepaths of all the iSCAT images to input as returned by utilities.load_data_paths()
            fluo_filepaths (list(str)): filepaths of all the fluo images to input as returned by utilities.load_data_paths()

    The two lists are paired index-by-index. For each pair the fluorescence stack is
    thresholded into a binary mask and the iSCAT stack is contrast-enhanced and
    band-pass filtered. Only frames whose mask contains at least one foreground pixel
    are saved, and the first fluorescence frame of every stack is skipped.
    """

    print("Building iSCAT/fluo dataset...")

    OUT_PATH = DATA_PATH + 'iscat_fluo/'
    os.makedirs(os.path.join(OUT_PATH, 'REF_FRAMES/'), exist_ok=True)
    os.makedirs(os.path.join(OUT_PATH, 'MASKS/'), exist_ok=True)

    # Generators: stacks are loaded lazily, one pair at a time
    fluo_tirf_imgs = (tifffile.imread(path) for path in fluo_filepaths)
    iscat_imgs = (tifffile.imread(path) for path in iscat_filepaths)

    # Range of non-filtered element sizes [px] for the FFT filtering
    min_size, max_size = 1, 13

    for i, (tirf, iscat) in enumerate(zip(fluo_tirf_imgs, iscat_imgs)):
        # Fluorescence stack -> binary (float32) cell mask
        tirf = processing.coregister(tirf, 1.38)
        tirf = ndimage.median_filter(tirf, 5)
        tirf = tirf > 300
        tirf = ndimage.binary_closing(
            tirf, processing.structural_element('circle',
                                                (1, 10, 10))).astype('float32')

        # iSCAT preprocessing
        iscat = processing.image_correction(iscat)
        iscat = processing.enhance_contrast(iscat,
                                            'stretching',
                                            percentile=(1, 99))
        iscat = processing.fft_filtering(iscat, min_size, max_size, True)
        iscat = processing.enhance_contrast(iscat,
                                            'stretching',
                                            percentile=(3, 97))

        # Discards the first image
        for j in range(1, tirf.shape[0]):
            # Only frames whose own mask is non-empty are saved (the check must be
            # per frame, not on the whole stack)
            if not (tirf[j] != 0).any():
                continue

            try:
                # The stacks can have different frame counts: map the fluo frame
                # index proportionally onto the iSCAT stack
                iscat_idx = int(j * iscat.shape[0] / tirf.shape[0])
                imageio.imsave(
                    OUT_PATH + 'REF_FRAMES/' +
                    "iscat_{}_{}.png".format(i + 1, j + 1),
                    rescale(iscat[iscat_idx]))
                imageio.imsave(
                    OUT_PATH +
                    'MASKS/' + "fluo_{}_{}.png".format(i + 1, j + 1),
                    rescale(tirf[j]))

            except IndexError:
                # In case the dimensions don't match -> goes to next slice
                print("IndexError on stack {}, frame {}".format(i + 1, j + 1))
                continue
Esempio n. 2
0
def import_test_data(data_path=None):
    """Imports the test data as a generator yielding the (iscat, tirf) stacks one pair at a time.

        ARGS:
            data_path (str or None): acquisition folder (ending with '/') containing the
                cam1/event[0-9]/ and cam1/event[0-9]_tirf/ subfolders. Defaults to the last
                of the four hard-coded acquisitions below.

        YIELDS:
            (iscat, tirf): float32 stacks, each min-max normalized to [0, 1]. A constant
                stack is left at 0 instead of becoming NaN.
    """

    test_paths = ['/mnt/plabNAS/Lorenzo/iSCAT/iSCAT Data/Lorenzo/2019/novembre/20/fliC-_PaQa_Gasket_0/',
                  '/mnt/plabNAS/Lorenzo/iSCAT/iSCAT Data/Lorenzo/2020/janvier/16/fliC-_PaQa_solid_4/',
                  '/mnt/plabNAS/Lorenzo/iSCAT/iSCAT Data/Lorenzo/2018/décembre/04/PilH-FliC-_Agar_Microchannel_0_Flo/',
                  '/mnt/plabNAS/Lorenzo/iSCAT/iSCAT Data/Lorenzo/2020/janvier/17/pilH-fliC-_PaQa_solid_0/']

    data_root = test_paths[-1] if data_path is None else data_path

    # glob returns paths in arbitrary order: sort both lists so that the i-th iSCAT
    # stack is paired with the i-th tirf stack
    tirf_paths = sorted(glob.glob(data_root + 'cam1/event[0-9]_tirf/*.tif'))
    iscat_paths = sorted(glob.glob(data_root + 'cam1/event[0-9]/*PreNbin*.tif'))

    # iSCAT is acquired faster than tirf: keep one iSCAT frame per tirf frame
    fps_iscat, fps_tirf = 100, 50
    step = int(fps_iscat / fps_tirf)

    def _normalize(stack):
        """Returns a float32 copy of stack min-max scaled to [0, 1] (0 if constant)."""
        stack = stack.astype('float32')
        stack -= stack.min()
        peak = stack.max()
        if peak > 0:
            stack /= peak
        return stack

    for iscat_path, tirf_path in zip(iscat_paths, tirf_paths):

        iscat = tifffile.imread(iscat_path)[::step]
        tirf = tifffile.imread(tirf_path)

        iscat = processing.image_correction(iscat)
        iscat = processing.enhance_contrast(iscat, 'stretching', percentile=(1, 99))
        iscat = processing.fft_filtering(iscat, 1, 13, True)
        iscat = processing.enhance_contrast(iscat, 'stretching', percentile=(3, 97))

        tirf = processing.coregister(tirf, 1.38)

        # For fluo tirf images
        # tirf = ndimage.median_filter(tirf, 5)
        # tirf = (tirf > .9 *skimage.filters.threshold_triangle(tirf.ravel())).astype('uint8')
        # tirf = ndimage.binary_opening(tirf, processing.structural_element('circle', (1,10,10))).astype('uint8')
        # tirf = predict_cell_detection(tirf)

        yield _normalize(iscat), _normalize(tirf)
Esempio n. 3
0
def build_iscat_pili_training(iscat_filepaths, pili_coords_filepaths):
    """Creates the training set and target for the pili training task in DATA_PATH/iscat_pili/[REF_FRAMES / MASKS]

        ARGS:
                iscat_filepaths (list(str)): filepaths of all the iSCAT images to input as returned by utilities.load_data_paths()
                pili_coords_filepaths (list(str)): filepaths of all the pili coordinates as computed by the included imageJ plugin

    iSCAT stacks and pili masks are paired index-by-index. For each pair, frames are
    saved only if their pili mask is non-empty, excluding the first two and last two
    frames of the mask stack.
    """

    OUT_PATH = DATA_PATH + 'iscat_pili/'
    os.makedirs(os.path.join(OUT_PATH, 'REF_FRAMES/'), exist_ok=True)
    os.makedirs(os.path.join(OUT_PATH, 'MASKS/'), exist_ok=True)

    ref_iscat_imgs = (tifffile.imread(filepath)
                      for filepath in iscat_filepaths)
    pili_masks = get_pili_masks(pili_coords_filepaths)

    # Preprocessing (lazy generator pipeline, one stack at a time)
    min_size, max_size = 1, 13
    iscat_stacks = (processing.image_correction(iscat_stack)
                    for iscat_stack in ref_iscat_imgs)
    iscat_stacks = (processing.enhance_contrast(iscat_stack,
                                                'stretching',
                                                percentile=(1, 99))
                    for iscat_stack in iscat_stacks)
    iscat_stacks = (processing.fft_filtering(iscat_stack, min_size, max_size,
                                             True)
                    for iscat_stack in iscat_stacks)
    iscat_stacks = (processing.enhance_contrast(iscat_stack,
                                                'stretching',
                                                percentile=(3, 97))
                    for iscat_stack in iscat_stacks)

    for i, (iscat, pili) in enumerate(zip(iscat_stacks, pili_masks)):
        # Skip the first two and the last two frames of the stack (the bound must use
        # the number of frames, pili.shape[0], not a frame dimension), and never
        # index past the end of the iSCAT stack
        last_frame = min(pili.shape[0] - 2, iscat.shape[0])
        for j in range(2, last_frame):
            if not (pili[j] != 0).any():
                continue

            imageio.imsave(
                OUT_PATH + 'REF_FRAMES/' +
                "ref_iscat_{}_{}.png".format(i + 1, j + 1),
                rescale(iscat[j]))
            imageio.imsave(
                OUT_PATH + 'MASKS/' +
                "pili_{}_{}.png".format(i + 1, j + 1), pili[j])
Esempio n. 4
0
def build_hand_segmentation_iscat_testval():
    """Creates dataset of manually segmented iSCAT images for validation and testing.

    For every (iSCAT stack, cell segmentation, pili segmentation) triple found under
    data/hand-segmentation/ (matched by sorted filename order), writes:
        * DATA_PATH/hand_seg_iscat_cell/[REF_FRAMES / MASKS]: every 8th iSCAT frame
          together with cell mask frame j // 2;
        * DATA_PATH/hand_seg_iscat_pili/[REF_FRAMES / MASKS]: every iSCAT frame whose
          own pili mask frame contains at least one non-zero pixel.
    """

    OUT_PATH_CELL = DATA_PATH + 'hand_seg_iscat_cell/'
    os.makedirs(os.path.join(OUT_PATH_CELL, 'REF_FRAMES/'), exist_ok=True)
    os.makedirs(os.path.join(OUT_PATH_CELL, 'MASKS/'), exist_ok=True)

    OUT_PATH_PILI = DATA_PATH + 'hand_seg_iscat_pili/'
    os.makedirs(os.path.join(OUT_PATH_PILI, 'REF_FRAMES/'), exist_ok=True)
    os.makedirs(os.path.join(OUT_PATH_PILI, 'MASKS/'), exist_ok=True)

    ROOT_TEST_PATH = "data/hand-segmentation/"
    # Sorted so that files of the same acquisition line up across the three lists
    iscat_files = sorted(glob.glob(os.path.join(ROOT_TEST_PATH, 'iSCAT/*.tif')))
    cell_seg_files = sorted(glob.glob(os.path.join(ROOT_TEST_PATH, 'cell_seg/*.txt')))
    pili_seg_files = sorted(glob.glob(os.path.join(ROOT_TEST_PATH, 'pili_seg/*.txt')))

    for i, (iscat, cell_seg, pili_seg) in enumerate(
            zip(iscat_files, cell_seg_files, pili_seg_files)):

        # Loading iSCAT images
        iscat_stack = tifffile.imread(iscat)

        # iSCAT preprocessing
        iscat_stack = processing.image_correction(iscat_stack)
        iscat_stack = processing.enhance_contrast(iscat_stack,
                                                  'stretching',
                                                  percentile=(1, 99))
        iscat_stack = processing.fft_filtering(iscat_stack, 1, 13, True)
        iscat_stack = processing.enhance_contrast(iscat_stack,
                                                  'stretching',
                                                  percentile=(3, 97))

        # Loading ground truth masks
        mask_cell = utilities.cell_mask_from_segmentation(cell_seg)
        mask_pili = utilities.pili_mask_from_segmentation(pili_seg)

        # NOTE(review): tifffile.imsave writes TIFF data regardless of the '.png'
        # extension used in the file names below.

        # The cell mask is indexed at half the iSCAT frame index (j // 2); stop before
        # running past the end of either stack
        n_cell_frames = min(iscat_stack.shape[0], 2 * mask_cell.shape[0])
        for j in range(0, n_cell_frames, 8):
            print("\rSaving to stack_{}_{}.png".format(i + 1, j + 1),
                  end=' ' * 5)
            tifffile.imsave(
                os.path.join(OUT_PATH_CELL, 'REF_FRAMES/',
                             "stack_{}_{}.png".format(i + 1, j + 1)),
                rescale(iscat_stack[j]))
            tifffile.imsave(
                os.path.join(OUT_PATH_CELL, 'MASKS/',
                             "mask_{}_{}.png".format(i + 1, j + 1)),
                mask_cell[j // 2].astype('uint8'))

        print('')
        n_pili_frames = min(iscat_stack.shape[0], mask_pili.shape[0])
        for j in range(n_pili_frames):
            # Skip frames whose own pili mask is empty (per frame, not per stack)
            if not (mask_pili[j] != 0).any():
                continue

            print("\rSaving to stack_{}_{}.png".format(i + 1, j + 1),
                  end=' ' * 5)
            tifffile.imsave(
                os.path.join(OUT_PATH_PILI, 'REF_FRAMES/',
                             "stack_{}_{}.png".format(i + 1, j + 1)),
                rescale(iscat_stack[j]))
            tifffile.imsave(
                os.path.join(OUT_PATH_PILI, 'MASKS/',
                             "mask_{}_{}.png".format(i + 1, j + 1)),
                mask_pili[j].astype('uint8'))

        print('')
Esempio n. 5
0
def build_iscat_training(bf_filepaths, iscat_filepaths, sampling=4, start_index=0):
    """Creates iscat training data and target in DATA_PATH/iscat_seg/[REF_FRAMES / MASKS] for the iSCAT cell segmentation task

        ARGS:
            bf_filepaths (list(str)): filepaths of all the bright field images to input as returned by utilities.load_data_paths()
            iscat_filepaths (list(str)): filepaths of all the iscat images to input as returned by utilities.load_data_paths()
            sampling (int): sampling interval, in bright-field frames, of the saved images (lower storage footprint)
            start_index (int): index of the first stack to process; earlier stacks are skipped
                (useful to resume an interrupted run)

    Cell masks are predicted on the bright-field stacks with a pretrained UNet and paired
    with the temporally matching preprocessed iSCAT frames. Frames without detected cells
    are not saved. Exits the program if no CUDA device is available.
    """

    OUT_PATH = DATA_PATH + 'iscat_seg/'
    os.makedirs(os.path.join(OUT_PATH, 'REF_FRAMES/'), exist_ok=True)
    os.makedirs(os.path.join(OUT_PATH, 'MASKS/'), exist_ok=True)

    # Range of non filtered elements [px]
    min_size, max_size = 1, 13

    iscat_stacks = (utilities.load_imgs(path) for path in iscat_filepaths)
    bf_stacks = (utilities.load_imgs(path) for path in bf_filepaths)

    # Returns the metadata of the experiments such as frame rate
    metadatas = get_experiments_metadata(iscat_filepaths)

    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        torch.cuda.set_device(device)
        print("Running on: {:s}".format(torch.cuda.get_device_name(device)))
        cuda = torch.device('cuda')
    else:
        # Doesn't run on CPU only machines comment if no GPU
        print("No CUDA device found")
        sys.exit(1)

    unet = UNetCell(1, 1, device=cuda, bilinear_upsampling=False)
    unet.load_state_dict(torch.load('outputs/saved_models/bf_unet.pth'))
    # Inference mode (e.g. fixed normalization statistics / no dropout)
    unet.eval()

    for i, (bf_stack, iscat_stack,
            metadata) in enumerate(zip(bf_stacks, iscat_stacks, metadatas)):
        if i < start_index:
            continue

        bf_stack = bf_stack.astype('float32')
        if bf_stack.shape[1:] != iscat_stack.shape[1:]:
            bf_stack = processing.coregister(bf_stack, 1.38)

        normalize(bf_stack)

        # iSCAT is acquired faster than bright field: keep one iSCAT frame every
        # `sampling` bright-field frames so that iscat_stack[k] matches bright-field
        # frame k * sampling
        fps_ratio = int(metadata['iscat_fps'] / metadata['tirf_fps'])
        iscat_stack = iscat_stack[::sampling * fps_ratio]

        with torch.no_grad():
            torch_stack = torch.from_numpy(bf_stack).unsqueeze(1).cuda()
            mask = unet.predict_stack(
                torch_stack).detach().squeeze().cpu().numpy() > 0.05
        mask = morphology.grey_erosion(mask * 255,
                                       structure=processing.structural_element(
                                           'circle', (3, 5, 5)))
        mask = morphology.grey_closing(mask,
                                       structure=processing.structural_element(
                                           'circle', (3, 7, 7)))
        mask = (mask > 50).astype('uint8')

        # Same temporal sampling as the iSCAT stack (done after the 3D morphology,
        # which uses neighbouring frames)
        mask = mask[::sampling]

        # Median filtering and normalization
        iscat_stack = processing.image_correction(iscat_stack)

        # Contrast enhancement
        iscat_stack = processing.enhance_contrast(iscat_stack,
                                                  'stretching',
                                                  percentile=(1, 99))

        # Fourier filtering of image
        iscat_stack = processing.fft_filtering(iscat_stack, min_size, max_size,
                                               True)
        iscat_stack = processing.enhance_contrast(iscat_stack,
                                                  'stretching',
                                                  percentile=(3, 97))

        for k in range(min(iscat_stack.shape[0], mask.shape[0])):
            # Bright-field frame index, used in the file names
            j = k * sampling

            if iscat_stack[k].shape != mask[k].shape:
                print("Error, shape: {}, {}".format(iscat_stack[k].shape,
                                                    mask[k].shape))
                break

            # Doesn't save images without detected cells
            if mask[k].max() == 0:
                continue

            print("\rSaving to stack_{}_{}.png".format(i + 1, j + 1),
                  end=' ' * 5)
            tifffile.imsave(
                os.path.join(OUT_PATH, 'REF_FRAMES/',
                             "stack_{}_{}.png".format(i + 1, j + 1)),
                rescale(iscat_stack[k]))
            tifffile.imsave(
                os.path.join(OUT_PATH, 'MASKS/',
                             "mask_{}_{}.png".format(i + 1, j + 1)),
                mask[k] * 255)

        print('')
Esempio n. 6
0
def main():
    """Evaluates the tirf cell, iSCAT cell and iSCAT pili UNets on the hand-segmented test set.

    For each (iSCAT, tirf, cell segmentation, pili segmentation) file group under
    data/hand-segmentation/ (matched by sorted filename order), prints the metrics
    returned by compute_metrics for every model. It also saves side-by-side GIFs with
    the ground truth on the left and the network prediction on the right:
    outputs/tirf_truth_pred_<i>.gif and outputs/iscat_truth_pred_<i>.gif.
    """
    # Checks for an available graphics card
    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        torch.cuda.set_device(device)
        print("Running on GPU {:d}: {:s}".format(
            device, torch.cuda.get_device_name(device)))
        map_location = 'cuda:{:d}'.format(device)
    else:
        print("No CUDA device found")
        device = 'cpu'
        # Allows loading checkpoints that were saved from a GPU
        map_location = 'cpu'

    # Loading paths to test images and ground truth
    ROOT_TEST_PATH = "data/hand-segmentation/"

    iscat_files = sorted(glob.glob(os.path.join(ROOT_TEST_PATH, 'iSCAT/*.tif')))
    tirf_files = sorted(glob.glob(os.path.join(ROOT_TEST_PATH, 'tirf/*.tif')))
    cell_seg_files = sorted(glob.glob(os.path.join(ROOT_TEST_PATH, 'cell_seg/*.txt')))
    pili_seg_files = sorted(glob.glob(os.path.join(ROOT_TEST_PATH, 'pili_seg/*.txt')))

    # Loading UNet models
    unet_tirf = UNetCell(1, 1, device=device, bilinear_upsampling=False)
    unet_tirf.load_state_dict(
        torch.load('saved_models/bf_unet.pth', map_location=map_location))
    unet_tirf.eval()

    unet_iscat = UNetCell(1, 1, device=device)
    unet_iscat.load_state_dict(
        torch.load('saved_models/iscat_unet_augment_before_fluo.pth',
                   map_location=map_location))
    unet_iscat.eval()

    unet_pili = UNetPili(1, 1, device=device)
    unet_pili.load_state_dict(
        torch.load('saved_models/pili_unet_augment_16_channels_170.pth',
                   map_location=map_location))
    unet_pili.eval()

    # Iterating over the test files
    for i, (iscat, tirf, cell_seg, pili_seg) in enumerate(
            zip(iscat_files, tirf_files, cell_seg_files, pili_seg_files)):

        # Loading tirf and iSCAT images
        iscat_stack = tifffile.imread(iscat)
        tirf_stack = tifffile.imread(tirf)

        # iSCAT preprocessing
        iscat_stack = processing.image_correction(iscat_stack)
        iscat_stack = processing.enhance_contrast(iscat_stack,
                                                  'stretching',
                                                  percentile=(1, 99))
        iscat_stack = processing.fft_filtering(iscat_stack, 1, 13, True)
        iscat_stack = processing.enhance_contrast(iscat_stack,
                                                  'stretching',
                                                  percentile=(3, 97))

        # Loading ground truth masks
        mask_cell = cell_mask_from_segmentation(cell_seg).astype('bool')
        mask_pili = pili_mask_from_segmentation(pili_seg).astype('bool')

        # Predicting stacks
        with torch.no_grad():
            torch_tirf = torch.from_numpy(
                (tirf_stack /
                 tirf_stack.max()).astype('float32')).to(device=device)
            torch_iscat = torch.from_numpy(
                (iscat_stack /
                 iscat_stack.max()).astype('float32')).to(device=device)

            pred_cell_tirf = unet_tirf.predict_stack(
                torch_tirf.unsqueeze(1)).squeeze().cpu().numpy()
            pred_cell_iscat = unet_iscat.predict_stack(
                torch_iscat.unsqueeze(1)).squeeze().cpu().numpy()
            pred_pili_iscat = unet_pili.predict_stack(
                torch_iscat.unsqueeze(1)).squeeze().cpu().numpy()

        # Computing metrics of models
        print(f"Image {i+1} metrics:")
        print(
            "Cell_detect (tirf): accuracy={:.3f}, recall={:.3f}, precision={:.3e}, F1 score={:.3e}, IoU={:.3f}"
            .format(*compute_metrics(pred_cell_tirf >= .6, mask_cell)))
        # The cell ground truth is compared with every other iSCAT frame
        print(
            "Cell_detect (iSCAT): accuracy={:.3f}, recall={:.3f}, precision={:.3e}, F1 score={:.3e}, IoU={:.3f}"
            .format(*compute_metrics(pred_cell_iscat[::2] >= .6, mask_cell)))
        # Pili metrics must use the pili network's prediction
        print(
            "Pili_detect (iSCAT): accuracy={:.3f}, recall={:.3f}, precision={:.3e}, F1 score={:.3e}, IoU={:.3e}"
            .format(*compute_metrics(pred_pili_iscat >= .55, mask_pili)))

        # Saving prediction and ground truth: each stack is duplicated side by side
        # along the width axis and expanded to 3 (RGB) channels
        out_tirf = np.stack([
            np.concatenate([tirf_stack / tirf_stack.max() * 255] * 2, axis=2)
        ] * 3,
                            axis=-1).astype('uint8')
        out_iscat = np.stack([
            np.concatenate([iscat_stack / iscat_stack.max() * 255] * 2, axis=2)
        ] * 3,
                             axis=-1).astype('uint8')

        # Cells are drawn in the green channel (1), pili in the red channel (0)
        out_tirf[..., :out_tirf.shape[2] // 2, 1][mask_cell != 0] = 255
        out_tirf[..., out_tirf.shape[2] // 2:, 1][pred_cell_tirf >= .6] = 255

        out_iscat[::2, :, :out_iscat.shape[2] // 2, 1][mask_cell != 0] = 255
        out_iscat[..., :out_iscat.shape[2] // 2, 0][mask_pili != 0] = 255

        out_iscat[..., :, out_iscat.shape[2] // 2:,
                  1][pred_cell_iscat >= .6] = 255
        out_iscat[..., out_iscat.shape[2] // 2:,
                  0][pred_pili_iscat >= .55] = 255

        # Ground truth on the left, net prediction on the right
        mimsave(f'outputs/tirf_truth_pred_{i+1}.gif',
                out_tirf.astype('uint8'),
                fps=20)
        mimsave(f'outputs/iscat_truth_pred_{i+1}.gif',
                out_iscat.astype('uint8'),
                fps=20)