def build_iscat_fluo_training(iscat_filepaths, fluo_filepaths):
    """Creates iSCAT training data and target in data/iscat_fluo/[REF_FRAMES / MASKS]
    for the iSCAT cell segmentation task with fluorescent images used as training target

    ARGS:
        iscat_filepaths (list(str)): filepaths of all the iSCAT images to input
            as returned by utilities.load_data_paths()
        fluo_filepaths (list(str)): filepaths of all the fluo images to input
            as returned by utilities.load_data_paths()
    """
    print("Building iSCAT/fluo dataset...")

    OUT_PATH = DATA_PATH + 'iscat_fluo/'
    os.makedirs(os.path.join(OUT_PATH, 'REF_FRAMES/'), exist_ok=True)
    os.makedirs(os.path.join(OUT_PATH, 'MASKS/'), exist_ok=True)

    # Lazy generators so only one stack pair is held in memory at a time
    fluo_tirf_imgs = (tifffile.imread(fluo_tirf_img_path) for fluo_tirf_img_path in fluo_filepaths)
    fluo_iscat_imgs = (tifffile.imread(fluo_iscat_img_path) for fluo_iscat_img_path in iscat_filepaths)

    # Range of non filtered elements [px] for the FFT band-pass filter
    min_size, max_size = 1, 13

    for i, (tirf, iscat) in enumerate(zip(fluo_tirf_imgs, fluo_iscat_imgs)):
        # Build a binary mask out of the fluorescence channel
        tirf = processing.coregister(tirf, 1.38)
        tirf = ndimage.median_filter(tirf, 5)
        # NOTE(review): hard-coded intensity threshold — confirm against acquisition settings
        tirf = tirf > 300
        tirf = ndimage.binary_closing(
            tirf, processing.structural_element('circle', (1, 10, 10))).astype('float32')

        # iSCAT preprocessing: correction, contrast stretching, FFT filtering
        iscat = processing.image_correction(iscat)
        iscat = processing.enhance_contrast(iscat, 'stretching', percentile=(1, 99))
        iscat = processing.fft_filtering(iscat, min_size, max_size, True)
        iscat = processing.enhance_contrast(iscat, 'stretching', percentile=(3, 97))

        # Discards the first image
        for j in range(1, tirf.shape[0]):
            try:
                # BUGFIX: test the current mask slice (tirf[j]) rather than the whole
                # stack, so individual frames with no detected cells are skipped —
                # consistent with the `mask[j].max() == 0` checks in the other builders
                if (tirf[j] != 0).any():
                    # iSCAT and fluo stacks may have different frame counts; map
                    # frame j proportionally onto the iSCAT stack
                    imageio.imsave(
                        OUT_PATH + 'REF_FRAMES/' + "iscat_{}_{}.png".format(i + 1, j + 1),
                        rescale(iscat[int(j * iscat.shape[0] / tirf.shape[0])]))
                    imageio.imsave(
                        OUT_PATH + 'MASKS/' + "fluo_{}_{}.png".format(i + 1, j + 1),
                        rescale(tirf[j]))
            except IndexError:
                # In case the dimensions don't match -> goes to next slice
                print("IndexError")
                continue
def build_hand_segmentation_tirf_testval():
    """Creates dataset of manually segmented tirf images for validation and testing"""
    OUT_PATH = DATA_PATH + 'hand_seg_tirf/'
    os.makedirs(os.path.join(OUT_PATH, 'REF_FRAMES/'), exist_ok=True)
    os.makedirs(os.path.join(OUT_PATH, 'MASKS/'), exist_ok=True)

    ROOT_TEST_PATH = "data/hand-segmentation/"
    # Sort both listings so TIRF stacks line up with their segmentation files
    tirf_files = sorted(glob.glob(os.path.join(ROOT_TEST_PATH, 'tirf/*.tif')))
    cell_seg_files = sorted(glob.glob(os.path.join(ROOT_TEST_PATH, 'cell_seg/*.txt')))

    for i, (tirf_path, seg_path) in enumerate(zip(tirf_files, cell_seg_files)):
        stack = tifffile.imread(tirf_path)
        cell_mask = utilities.cell_mask_from_segmentation(seg_path)

        # Coregister to 512x512 only when the frames are not already that size
        if stack.shape[1:] != (512, 512):
            stack = processing.coregister(stack, 1.38, np.zeros((3,)), 0.0)
        if cell_mask.shape[1:] != (512, 512):
            cell_mask = processing.coregister(cell_mask, 1.38, np.zeros((3,)), 0.0)

        # Keep every 4th frame to limit the dataset size
        for j in range(0, stack.shape[0], 4):
            print("\rSaving to stack_{}_{}.png".format(i + 1, j + 1), end=' ' * 5)
            tifffile.imsave(
                os.path.join(OUT_PATH, 'REF_FRAMES/', "stack_{}_{}.png".format(i + 1, j + 1)),
                rescale(stack[j]))
            tifffile.imsave(
                os.path.join(OUT_PATH, 'MASKS/', "mask_{}_{}.png".format(i + 1, j + 1)),
                cell_mask[j].astype('uint8'))
    print('')
def build_brigth_field_training(filepaths, sampling=4):
    """Creates bright field training data and target in data/bf_seg/[REF_FRAMES / MASKS]
    for the bright field cell segmentation task

    ARGS:
        filepaths (list(str)): filepaths of all the images to input
            as returned by utilities.load_data_paths()
        sampling (int): sampling interval of the saved images (lower storage footprint)
    """
    OUT_PATH = DATA_PATH + 'bf_seg/'
    # Lazy loading so only one stack is in memory at a time
    bf_stacks = (utilities.load_imgs(path) for path in filepaths)
    os.makedirs(os.path.join(OUT_PATH, 'REF_FRAMES/'), exist_ok=True)
    os.makedirs(os.path.join(OUT_PATH, 'MASKS/'), exist_ok=True)

    for i, bf_stack in enumerate(bf_stacks):
        # Change scale from (384, 384) to (512, 512)
        if bf_stack.shape[1:] != (512, 512):
            bf_stack = processing.coregister(bf_stack, 1.38, np.zeros((3,)), 0.0)

        mask = processing.bright_field_segmentation(bf_stack)

        # BUGFIX: honor the documented `sampling` parameter (was silently ignored;
        # the frame step was hard-coded to 2)
        for j in range(0, bf_stack.shape[0], sampling):
            if bf_stack[j].shape == mask[j].shape:
                # Doesn't save images without detected cells
                if mask[j].max() == 0:
                    continue
                print("\rSaving to stack_{}_{}.png".format(i + 1, j + 1), end=' ' * 5)
                tifffile.imsave(
                    os.path.join(OUT_PATH, 'REF_FRAMES/', "stack_{}_{}.png".format(i + 1, j + 1)),
                    rescale(bf_stack[j]))
                tifffile.imsave(
                    os.path.join(OUT_PATH, 'MASKS/', "mask_{}_{}.png".format(i + 1, j + 1)),
                    mask[j] * 255)
            else:
                print("Error, shape: {}, {}".format(bf_stack[j].shape, mask[j].shape))
                break
    print('')
def import_test_data():
    """Imports the test data in one of the 4 given folders as a generator
    yielding the (iscat, tirf) stack pairs one at a time, each normalized
    to float32 in [0, 1]
    """
    DATA_PATH = ['/mnt/plabNAS/Lorenzo/iSCAT/iSCAT Data/Lorenzo/2019/novembre/20/fliC-_PaQa_Gasket_0/',
                 '/mnt/plabNAS/Lorenzo/iSCAT/iSCAT Data/Lorenzo/2020/janvier/16/fliC-_PaQa_solid_4/',
                 '/mnt/plabNAS/Lorenzo/iSCAT/iSCAT Data/Lorenzo/2018/décembre/04/PilH-FliC-_Agar_Microchannel_0_Flo/',
                 '/mnt/plabNAS/Lorenzo/iSCAT/iSCAT Data/Lorenzo/2020/janvier/17/pilH-fliC-_PaQa_solid_0/']
    DATA_PATH = DATA_PATH[-1]

    # BUGFIX: glob.glob returns paths in arbitrary (filesystem-dependent) order;
    # sort both listings so each iSCAT stack is paired with its matching TIRF event
    tirf_paths = sorted(glob.glob(DATA_PATH + 'cam1/event[0-9]_tirf/*.tif'))
    iscat_paths = sorted(glob.glob(DATA_PATH + 'cam1/event[0-9]/*PreNbin*.tif'))

    fps_iscat, fps_tirf = 100, 50
    for iscat_path, tirf_path in zip(iscat_paths, tirf_paths):
        # Subsample iSCAT frames to match the TIRF frame rate
        iscat = tifffile.imread(iscat_path)[::int(fps_iscat / fps_tirf)]
        tirf = tifffile.imread(tirf_path)

        # iSCAT preprocessing: correction, contrast stretching, FFT filtering
        iscat = processing.image_correction(iscat)
        iscat = processing.enhance_contrast(iscat, 'stretching', percentile=(1, 99))
        iscat = processing.fft_filtering(iscat, 1, 13, True)
        iscat = processing.enhance_contrast(iscat, 'stretching', percentile=(3, 97))
        iscat = iscat.astype('float32')

        tirf = processing.coregister(tirf, 1.38)

        # For fluo tirf images
        # tirf = ndimage.median_filter(tirf, 5)
        # tirf = (tirf > .9 *skimage.filters.threshold_triangle(tirf.ravel())).astype('uint8')
        # tirf = ndimage.binary_opening(tirf, processing.structural_element('circle', (1,10,10))).astype('uint8')
        # tirf = predict_cell_detection(tirf)

        tirf = tirf.astype('float32')

        # Min-max normalization to [0, 1]
        # NOTE(review): divides by max after shifting — assumes non-constant stacks
        iscat -= iscat.min()
        iscat /= iscat.max()
        tirf -= tirf.min()
        tirf /= tirf.max()

        yield iscat, tirf
def build_iscat_training(bf_filepaths, iscat_filepaths, sampling=4, start_index=45):
    """Creates iscat training data and target in data/iscat_seg/[REF_FRAMES / MASKS]
    for the iSCAT cell segmentation task; masks are predicted from the bright
    field stacks with a pretrained UNet

    ARGS:
        bf_filepaths (list(str)): filepaths of all the bright field images to input
            as returned by utilities.load_data_paths()
        iscat_filepaths (list(str)): filepaths of all the iscat images to input
            as returned by utilities.load_data_paths()
        sampling (int): sampling interval of the saved images (lower storage footprint)
        start_index (int): index of the first stack pair to process; defaults to 45
            to preserve the previously hard-coded resume point
    """
    OUT_PATH = DATA_PATH + 'iscat_seg/'
    os.makedirs(os.path.join(OUT_PATH, 'REF_FRAMES/'), exist_ok=True)
    os.makedirs(os.path.join(OUT_PATH, 'MASKS/'), exist_ok=True)

    # Range of non filtered elements [px]
    min_size, max_size = 1, 13

    # Lazy loading so only one stack pair is in memory at a time
    iscat_stacks = (utilities.load_imgs(path) for path in iscat_filepaths)
    bf_stacks = (utilities.load_imgs(path) for path in bf_filepaths)

    # Returns the metadata of the experiments such as frame rate
    metadatas = get_experiments_metadata(iscat_filepaths)

    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        torch.cuda.set_device(device)
        print("Running on: {:s}".format(torch.cuda.get_device_name(device)))
        cuda = torch.device('cuda')
    else:
        # Doesn't run on CPU only machines comment if no GPU
        print("No CUDA device found")
        sys.exit(1)

    # Pretrained bright-field cell segmentation network
    unet = UNetCell(1, 1, device=cuda, bilinear_upsampling=False)
    unet.load_state_dict(torch.load('outputs/saved_models/bf_unet.pth'))

    for i, (bf_stack, iscat_stack, metadata) in enumerate(zip(bf_stacks, iscat_stacks, metadatas)):
        # Skip already-processed stacks (generalized from the hard-coded `i < 45`)
        if i < start_index:
            continue

        bf_stack = bf_stack.astype('float32')
        if bf_stack.shape[1:] != iscat_stack.shape[1:]:
            bf_stack = processing.coregister(bf_stack, 1.38)

        # NOTE(review): presumably normalizes in place — return value is discarded
        normalize(bf_stack)

        # Samples iscat image to correct for the difference in framerate
        iscat_stack = iscat_stack[::sampling * int(metadata['iscat_fps'] / metadata['tirf_fps'])]

        # Predict cell masks from the bright field stack, then clean them up
        # with greyscale morphology before binarizing
        torch_stack = torch.from_numpy(bf_stack).unsqueeze(1).cuda()
        mask = unet.predict_stack(torch_stack).detach().squeeze().cpu().numpy() > 0.05
        mask = morphology.grey_erosion(mask * 255, structure=processing.structural_element('circle', (3, 5, 5)))
        mask = morphology.grey_closing(mask, structure=processing.structural_element('circle', (3, 7, 7)))
        mask = (mask > 50).astype('uint8')

        # Median filtering and normalization
        iscat_stack = processing.image_correction(iscat_stack)

        # Contrast enhancement
        iscat_stack = processing.enhance_contrast(iscat_stack, 'stretching', percentile=(1, 99))

        # Fourier filtering of image
        iscat_stack = processing.fft_filtering(iscat_stack, min_size, max_size, True)
        iscat_stack = processing.enhance_contrast(iscat_stack, 'stretching', percentile=(3, 97))

        for j in range(0, min(iscat_stack.shape[0], mask.shape[0]), sampling):
            if iscat_stack[j].shape == mask[j].shape:
                # Doesn't save images without detected cells
                if mask[j].max() == 0:
                    continue
                print("\rSaving to stack_{}_{}.png".format(i + 1, j + 1), end=' ' * 5)
                tifffile.imsave(
                    os.path.join(OUT_PATH, 'REF_FRAMES/', "stack_{}_{}.png".format(i + 1, j + 1)),
                    rescale(iscat_stack[j]))
                tifffile.imsave(
                    os.path.join(OUT_PATH, 'MASKS/', "mask_{}_{}.png".format(i + 1, j + 1)),
                    mask[j] * 255)
            else:
                print("Error, shape: {}, {}".format(iscat_stack[j].shape, mask[j].shape))
                break
    print('')