示例#1
0
def build_iscat_fluo_training(iscat_filepaths, fluo_filepaths):
    """Creates the iSCAT training images and their fluorescence-derived masks.

        Reference frames are written to DATA_PATH + 'iscat_fluo/REF_FRAMES/' and
        masks to DATA_PATH + 'iscat_fluo/MASKS/'. Each fluorescence stack is
        thresholded into a binary mask and paired, slice by slice, with a frame
        of the filtered iSCAT stack found at the same position in the other list.

        ARGS:
            iscat_filepaths (list(str)): filepaths of all the iSCAT images to input as returned by utilities.load_data_paths()
            fluo_filepaths (list(str)): filepaths of all the fluo images to input as returned by utilities.load_data_paths()

        NOTE: the two lists are paired by position (zip), so they must be in
        matching order; extra entries in the longer list are ignored.
    """

    print("Building iSCAT/fluo dataset...")

    OUT_PATH = DATA_PATH + 'iscat_fluo/'
    os.makedirs(os.path.join(OUT_PATH, 'REF_FRAMES/'), exist_ok=True)
    os.makedirs(os.path.join(OUT_PATH, 'MASKS/'), exist_ok=True)

    # Generators: each stack is only read from disk when the loop reaches it
    fluo_tirf_imgs = (tifffile.imread(fluo_tirf_img_path)
                      for fluo_tirf_img_path in fluo_filepaths)
    fluo_iscat_imgs = (tifffile.imread(fluo_iscat_img_path)
                       for fluo_iscat_img_path in iscat_filepaths)

    # Size bounds handed to processing.fft_filtering
    min_size, max_size = 1, 13

    for i, (tirf, iscat) in enumerate(zip(fluo_tirf_imgs, fluo_iscat_imgs)):
        # Fluorescence stack -> binary mask stored as float32 (0.0 / 1.0)
        tirf = processing.coregister(tirf, 1.38)
        tirf = ndimage.median_filter(tirf, 5)
        tirf = tirf > 300
        tirf = ndimage.binary_closing(
            tirf, processing.structural_element('circle',
                                                (1, 10, 10))).astype('float32')

        # iSCAT stack: correction, contrast stretch, FFT filtering, stretch again
        iscat = processing.image_correction(iscat)
        iscat = processing.enhance_contrast(iscat,
                                            'stretching',
                                            percentile=(1, 99))
        iscat = processing.fft_filtering(iscat, min_size, max_size, True)
        iscat = processing.enhance_contrast(iscat,
                                            'stretching',
                                            percentile=(3, 97))

        n_masks = tirf.shape[0]

        # Discards the first image
        for j in range(1, n_masks):
            # Only keep slices whose own mask is non-empty. The check is done
            # per slice: testing the whole stack would let empty masks through
            # as soon as any other slice contained signal.
            if not (tirf[j] != 0).any():
                continue

            try:
                # The stacks may differ in length: map mask slice j onto the
                # iSCAT frame at the same relative position in its stack
                iscat_index = j * iscat.shape[0] // n_masks
                imageio.imsave(
                    OUT_PATH + 'REF_FRAMES/' +
                    "iscat_{}_{}.png".format(i + 1, j + 1),
                    rescale(iscat[iscat_index]))
                imageio.imsave(
                    OUT_PATH +
                    'MASKS/' + "fluo_{}_{}.png".format(i + 1, j + 1),
                    rescale(tirf[j]))

            except IndexError:
                # In case the dimensions don't match -> goes to next slice
                print("IndexError")
                continue
示例#2
0
def build_hand_segmentation_tirf_testval():
    """Exports the manually segmented tirf stacks as frame / mask image pairs for validation and testing

        Stacks ('tirf/*.tif') and segmentations ('cell_seg/*.txt') are read from
        data/hand-segmentation/, paired by sorted filename order, and every
        fourth frame is written with its mask to
        DATA_PATH + 'hand_seg_tirf/[REF_FRAMES / MASKS]'.
    """

    OUT_PATH = DATA_PATH + 'hand_seg_tirf/'
    frames_dir = os.path.join(OUT_PATH, 'REF_FRAMES/')
    masks_dir = os.path.join(OUT_PATH, 'MASKS/')
    os.makedirs(frames_dir, exist_ok=True)
    os.makedirs(masks_dir, exist_ok=True)

    ROOT_TEST_PATH = "data/hand-segmentation/"
    # Sorted so that the n-th stack lines up with the n-th segmentation file
    tirf_files = sorted(glob.glob(os.path.join(ROOT_TEST_PATH, 'tirf/*.tif')))
    cell_seg_files = sorted(
        glob.glob(os.path.join(ROOT_TEST_PATH, 'cell_seg/*.txt')))

    expected_frame_shape = (512, 512)

    for stack_no, (tirf_path, seg_path) in enumerate(zip(tirf_files,
                                                         cell_seg_files),
                                                     start=1):

        tirf_stack = tifffile.imread(tirf_path)
        mask_cell = utilities.cell_mask_from_segmentation(seg_path)

        # Only stacks whose frames are not already 512x512 are coregistered
        if tirf_stack.shape[1:] != expected_frame_shape:
            tirf_stack = processing.coregister(tirf_stack, 1.38,
                                               np.zeros((3, )), 0.0)
        if mask_cell.shape[1:] != expected_frame_shape:
            mask_cell = processing.coregister(mask_cell, 1.38,
                                              np.zeros((3, )), 0.0)

        # Keeps one frame out of four
        for j in range(0, tirf_stack.shape[0], 4):
            frame_no = j + 1
            print("\rSaving to stack_{}_{}.png".format(stack_no, frame_no),
                  end=' ' * 5)

            frame_file = os.path.join(
                frames_dir, "stack_{}_{}.png".format(stack_no, frame_no))
            tifffile.imsave(frame_file, rescale(tirf_stack[j]))

            mask_file = os.path.join(
                masks_dir, "mask_{}_{}.png".format(stack_no, frame_no))
            tifffile.imsave(mask_file, mask_cell[j].astype('uint8'))

        # Ends the '\r' progress line of this stack
        print('')
示例#3
0
def build_brigth_field_training(filepaths, sampling=4):
    """Creates bright field training data and target in DATA_PATH + 'bf_seg/[REF_FRAMES / MASKS]' for the bright field cell segmentation task

        Each stack is segmented with processing.bright_field_segmentation and
        every second frame with a non-empty mask is written with its mask.

        ARGS:
            filepaths (list(str)): filepaths of all the images to input as returned by utilities.load_data_paths()
            sampling (int): sampling interval of the saved images (lower storage footprint).
                NOTE: currently not used by the code, frames are taken with a fixed step of 2.
    """

    OUT_PATH = DATA_PATH + 'bf_seg/'
    # Generator: a stack is only loaded when the loop reaches it
    bf_stacks = (utilities.load_imgs(path) for path in filepaths)

    frames_dir = os.path.join(OUT_PATH, 'REF_FRAMES/')
    masks_dir = os.path.join(OUT_PATH, 'MASKS/')
    os.makedirs(frames_dir, exist_ok=True)
    os.makedirs(masks_dir, exist_ok=True)

    for stack_no, bf_stack in enumerate(bf_stacks, start=1):

        print(bf_stack.shape[1:])

        # Stacks whose frames are not 512x512 are coregistered first
        # (the original note mentions (384, 384) inputs)
        needs_rescaling = bf_stack.shape[1:] != (512, 512)
        if needs_rescaling:
            bf_stack = processing.coregister(bf_stack, 1.38, np.zeros((3, )),
                                             0.0)

        print(bf_stack.shape)
        mask = processing.bright_field_segmentation(bf_stack)

        for j in range(0, bf_stack.shape[0], 2):
            frame, frame_mask = bf_stack[j], mask[j]

            # A shape mismatch aborts the rest of this stack
            if frame.shape != frame_mask.shape:
                print("Error, shape: {}, {}".format(frame.shape,
                                                    frame_mask.shape))
                break

            # Frames without any segmented pixel are not saved
            if frame_mask.max() == 0:
                continue

            frame_no = j + 1
            print("\rSaving to stack_{}_{}.png".format(stack_no, frame_no),
                  end=' ' * 5)

            frame_file = os.path.join(
                frames_dir, "stack_{}_{}.png".format(stack_no, frame_no))
            tifffile.imsave(frame_file, rescale(frame))

            mask_file = os.path.join(
                masks_dir, "mask_{}_{}.png".format(stack_no, frame_no))
            tifffile.imsave(mask_file, frame_mask * 255)

        # Ends the '\r' progress line of this stack
        print('')
示例#4
0
def import_test_data():
    """Imports the test data from one of the 4 listed folders as a generator yielding the stacks one at a time

        The folder is hard-coded: the last entry of the list below is used.
        Matching files are sorted before being paired, so the n-th iSCAT stack
        always goes with the n-th tirf stack regardless of the order in which
        the file system lists them.

        YIELDS:
            (iscat, tirf): two float32 stacks, each shifted by its own minimum
            and divided by its resulting maximum.
    """

    candidate_dirs = ['/mnt/plabNAS/Lorenzo/iSCAT/iSCAT Data/Lorenzo/2019/novembre/20/fliC-_PaQa_Gasket_0/',
                      '/mnt/plabNAS/Lorenzo/iSCAT/iSCAT Data/Lorenzo/2020/janvier/16/fliC-_PaQa_solid_4/',
                      '/mnt/plabNAS/Lorenzo/iSCAT/iSCAT Data/Lorenzo/2018/décembre/04/PilH-FliC-_Agar_Microchannel_0_Flo/',
                      '/mnt/plabNAS/Lorenzo/iSCAT/iSCAT Data/Lorenzo/2020/janvier/17/pilH-fliC-_PaQa_solid_0/']

    data_dir = candidate_dirs[-1]

    # glob returns matches in arbitrary order: sort both lists so that zip()
    # pairs files belonging to the same event
    tirf_paths = sorted(glob.glob(data_dir + 'cam1/event[0-9]_tirf/*.tif'))
    iscat_paths = sorted(glob.glob(data_dir + 'cam1/event[0-9]/*PreNbin*.tif'))

    fps_iscat, fps_tirf = 100, 50
    # Keeps one iSCAT frame out of `frame_step` to match the tirf frame rate
    frame_step = fps_iscat // fps_tirf

    for iscat_path, tirf_path in zip(iscat_paths, tirf_paths):

        iscat = tifffile.imread(iscat_path)[::frame_step]
        tirf = tifffile.imread(tirf_path)

        iscat = processing.image_correction(iscat)
        iscat = processing.enhance_contrast(iscat, 'stretching', percentile=(1, 99))
        iscat = processing.fft_filtering(iscat, 1, 13, True)
        iscat = processing.enhance_contrast(iscat, 'stretching', percentile=(3, 97))
        iscat = iscat.astype('float32')

        tirf = processing.coregister(tirf, 1.38)

        # For fluo tirf images
        # tirf = ndimage.median_filter(tirf, 5)
        # tirf = (tirf > .9 *skimage.filters.threshold_triangle(tirf.ravel())).astype('uint8')
        # tirf = ndimage.binary_opening(tirf, processing.structural_element('circle', (1,10,10))).astype('uint8')
        # tirf = predict_cell_detection(tirf)

        tirf = tirf.astype('float32')

        # In-place min-max scaling of each stack
        iscat -= iscat.min()
        iscat /= iscat.max()

        tirf -= tirf.min()
        tirf /= tirf.max()

        yield iscat, tirf
示例#5
0
def build_iscat_training(bf_filepaths, iscat_filepaths, sampling=4):
    """Creates iscat training data and target in DATA_PATH + 'iscat_seg/[REF_FRAMES / MASKS]' for the iSCAT cell segmentation task

        The masks are predicted from the bright field stacks by the UNet stored
        in 'outputs/saved_models/bf_unet.pth'; the reference frames are the
        filtered iSCAT stacks. A CUDA device is required: without one the
        process exits with status 1.

        ARGS:
            bf_filepaths (list(str)): filepaths of all the bright field images to input as returned by utilities.load_data_paths()
            iscat_filepaths (list(str)): filepaths of all the iscat images to input as returned by utilities.load_data_paths()
            sampling (int): sampling interval of the saved images (lower storage footprint)
    """

    OUT_PATH = DATA_PATH + 'iscat_seg/'
    frames_dir = os.path.join(OUT_PATH, 'REF_FRAMES/')
    masks_dir = os.path.join(OUT_PATH, 'MASKS/')
    os.makedirs(frames_dir, exist_ok=True)
    os.makedirs(masks_dir, exist_ok=True)

    # Range of non filtered elements [px]
    min_size, max_size = 1, 13

    # Generators: a stack is only loaded when the loop reaches it
    iscat_stacks = (utilities.load_imgs(path) for path in iscat_filepaths)
    bf_stacks = (utilities.load_imgs(path) for path in bf_filepaths)

    # Returns the metadata of the experiments such as frame rate
    metadatas = get_experiments_metadata(iscat_filepaths)

    if not torch.cuda.is_available():
        # Doesn't run on CPU only machines
        print("No CUDA device found")
        sys.exit(1)

    device = torch.cuda.current_device()
    torch.cuda.set_device(device)
    print("Running on: {:s}".format(torch.cuda.get_device_name(device)))
    cuda = torch.device('cuda')

    unet = UNetCell(1, 1, device=cuda, bilinear_upsampling=False)
    unet.load_state_dict(torch.load('outputs/saved_models/bf_unet.pth'))

    experiments = zip(bf_stacks, iscat_stacks, metadatas)
    for i, (bf_stack, iscat_stack, metadata) in enumerate(experiments):
        # The first 45 experiments are skipped (their stacks are still loaded)
        # NOTE(review): looks like a leftover from resuming a run - confirm
        if i < 45:
            continue

        stack_no = i + 1

        bf_stack = bf_stack.astype('float32')
        print(bf_stack.shape)
        if bf_stack.shape[1:] != iscat_stack.shape[1:]:
            bf_stack = processing.coregister(bf_stack, 1.38)
            print(bf_stack.shape)

        # Return value is not used - presumably normalizes in place, TODO confirm
        normalize(bf_stack)

        # Samples iscat image to correct for the difference in framerate
        frame_step = sampling * int(metadata['iscat_fps'] /
                                    metadata['tirf_fps'])
        iscat_stack = iscat_stack[::frame_step]

        # Mask prediction on the GPU, thresholded at 0.05
        torch_stack = torch.from_numpy(bf_stack).unsqueeze(1).cuda()
        prediction = unet.predict_stack(torch_stack)
        mask = prediction.detach().squeeze().cpu().numpy() > 0.05

        # Morphological clean-up on a 0/255 image, then back to a 0/1 mask
        mask = morphology.grey_erosion(mask * 255,
                                       structure=processing.structural_element(
                                           'circle', (3, 5, 5)))
        mask = morphology.grey_closing(mask,
                                       structure=processing.structural_element(
                                           'circle', (3, 7, 7)))
        mask = (mask > 50).astype('uint8')

        # Median filtering and normalization
        iscat_stack = processing.image_correction(iscat_stack)

        # Contrast enhancement
        iscat_stack = processing.enhance_contrast(iscat_stack,
                                                  'stretching',
                                                  percentile=(1, 99))

        # Fourier filtering of image
        iscat_stack = processing.fft_filtering(iscat_stack, min_size, max_size,
                                               True)
        iscat_stack = processing.enhance_contrast(iscat_stack,
                                                  'stretching',
                                                  percentile=(3, 97))

        # Never index past the end of the shorter of the two stacks
        n_frames = min(iscat_stack.shape[0], mask.shape[0])
        for j in range(0, n_frames, sampling):
            frame, frame_mask = iscat_stack[j], mask[j]

            # A shape mismatch aborts the rest of this experiment
            if frame.shape != frame_mask.shape:
                print("Error, shape: {}, {}".format(frame.shape,
                                                    frame_mask.shape))
                break

            # Doesn't save images without detected cells
            if frame_mask.max() == 0:
                continue

            frame_no = j + 1
            print("\rSaving to stack_{}_{}.png".format(stack_no, frame_no),
                  end=' ' * 5)

            frame_file = os.path.join(
                frames_dir, "stack_{}_{}.png".format(stack_no, frame_no))
            tifffile.imsave(frame_file, rescale(frame))

            mask_file = os.path.join(
                masks_dir, "mask_{}_{}.png".format(stack_no, frame_no))
            tifffile.imsave(mask_file, frame_mask * 255)

        # Ends the '\r' progress line of this experiment
        print('')