def _get_tiled_data(self):
    """
    If tile directory already exists, check which channels have been
    processed and only tile new channels.

    :return dataframe tiled_meta: Metadata with previously tiled channels
    :return list of lists tile_indices: Nbr tiles x 4 indices with row
        start + stop and column start + stop indices
    """
    if not self.tiles_exist:
        # Fresh tile dir: start with an empty dataframe and no indices
        return self._get_dataframe(), None

    prev_meta = aux_utils.read_meta(self.tile_dir)
    # Channels already present in the existing tile metadata
    done_channels = np.unique(prev_meta['channel_idx'])
    remaining = list(set(self.channel_ids) - set(done_channels))
    if not remaining:
        print('All channels in config have already been tiled')
        # NOTE: implicitly returns None when there is nothing left to tile
        return
    # Restrict tiling to channels that have not been processed yet
    self.channel_ids = remaining
    # Reuse tile indices from an already tiled channel so new tiles align
    indices = self._get_tile_indices(
        tiled_meta=prev_meta,
        time_idx=self.time_ids[0],
        channel_idx=done_channels[0],
        pos_idx=self.pos_ids[0],
        slice_idx=self.slice_ids[0],
    )
    return prev_meta, indices
def __init__(self,
             input_dir,
             output_dir,
             channel_ids,
             slice_ids,
             block_size=32):
    """
    Flatfield images are estimated once per channel for 2D data

    :param str input_dir: Directory with 2D image frames from dataset
    :param str output_dir: Base output directory
    :param int/list channel_ids: channel ids for flat field correction
    :param int/list slice_ids: Z slice indices for flatfield correction
    :param int block_size: Size of blocks image will be divided into
    """
    self.input_dir = input_dir
    self.output_dir = output_dir
    # Create flat_field_dir as a subdirectory of output_dir
    self.flat_field_dir = os.path.join(
        self.output_dir,
        'flat_field_images',
    )
    os.makedirs(self.flat_field_dir, exist_ok=True)
    self.frames_metadata = aux_utils.read_meta(self.input_dir)
    # Validate requested ids against frames metadata; the validated ids
    # are the ones stored on the instance (the raw slice_ids assignment
    # in the original was dead code, immediately overwritten below)
    metadata_ids, _ = aux_utils.validate_metadata_indices(
        frames_metadata=self.frames_metadata,
        channel_ids=channel_ids,
        slice_ids=slice_ids,
    )
    # NOTE(review): attribute is named channels_ids (not channel_ids);
    # kept as-is since other methods of this class may read it
    self.channels_ids = metadata_ids['channel_ids']
    self.slice_ids = metadata_ids['slice_ids']
    # Fall back to the default block size when None is passed explicitly
    if block_size is None:
        block_size = 32
    self.block_size = block_size
def test_read_meta():
    """Write a metadata csv to a temp dir and read it back."""
    with TempDirectory() as temp_dir:
        csv_name = 'test_meta.csv'
        csv_path = os.path.join(temp_dir.path, csv_name)
        meta_df.to_csv(csv_path)
        loaded_meta = aux_utils.read_meta(temp_dir.path, csv_name)
        # Only testing file name as writing df changes dtypes
        same_names = loaded_meta['file_name'].equals(meta_df['file_name'])
        nose.tools.assert_true(same_names)
def __init__(self,
             input_dir,
             output_dir,
             scale_factor,
             channel_ids=-1,
             time_ids=-1,
             slice_ids=-1,
             pos_ids=-1,
             int2str_len=3,
             num_workers=4,
             flat_field_dir=None):
    """
    :param str input_dir: Directory with image frames
    :param str output_dir: Base output directory
    :param float/list scale_factor: Scale factor for resizing frames.
    :param int/list channel_ids: Channel indices to resize
        (default -1 includes all slices)
    :param int/list time_ids: timepoints to use
    :param int/list slice_ids: Index of slice (z) indices to use
    :param int/list pos_ids: Position (FOV) indices to use
    :param int int2str_len: Length of str when converting ints
    :param int num_workers: number of workers for multiprocessing
    :param str flat_field_dir: dir with flat field images
    """
    self.input_dir = input_dir
    self.output_dir = output_dir
    # A list of per-dimension factors is converted to an array so the
    # positivity check below covers every element at once
    if isinstance(scale_factor, list):
        scale_factor = np.array(scale_factor)
    assert np.all(scale_factor > 0), \
        "Scale factor should be positive float, not {}".format(scale_factor)
    self.scale_factor = scale_factor

    self.frames_metadata = aux_utils.read_meta(self.input_dir)
    # Validate requested indices against the frames metadata
    id_dict, _ = aux_utils.validate_metadata_indices(
        frames_metadata=self.frames_metadata,
        time_ids=time_ids,
        channel_ids=channel_ids,
        slice_ids=slice_ids,
        pos_ids=pos_ids,
    )
    self.time_ids = id_dict['time_ids']
    self.channel_ids = id_dict['channel_ids']
    self.slice_ids = id_dict['slice_ids']
    self.pos_ids = id_dict['pos_ids']

    # Create resize_dir as a subdirectory of output_dir
    self.resize_dir = os.path.join(self.output_dir, 'resized_images')
    os.makedirs(self.resize_dir, exist_ok=True)

    self.int2str_len = int2str_len
    self.num_workers = num_workers
    self.flat_field_dir = flat_field_dir
def test_pre_process_resize2d(self):
    """Run preprocessing with 2D resizing and validate outputs."""
    cur_config = self.pp_config
    cur_config['resize'] = {
        'scale_factor': 2,
        'resize_3d': False,
    }
    cur_config['make_weight_map'] = False
    out_config, runtime = pp.pre_process(cur_config, self.base_config)
    # np.float was removed in NumPy 1.24; runtime is a builtin float
    self.assertIsInstance(runtime, float)
    self.assertEqual(
        out_config['resize']['resize_dir'],
        os.path.join(self.output_dir, 'resized_images')
    )
    resize_dir = out_config['resize']['resize_dir']
    # Check that all images have been resized
    resize_meta = aux_utils.read_meta(resize_dir)
    # 3 resized channels
    expected_rows = 3 * len(self.slice_ids) * len(self.pos_ids)
    self.assertEqual(resize_meta.shape[0], expected_rows)
    # Load an image and make sure it's twice as big
    im = cv2.imread(
        os.path.join(resize_dir, 'im_c003_z002_t000_p010.png'),
        cv2.IMREAD_ANYDEPTH,
    )
    self.assertTupleEqual(im.shape, (60, 40))
    # assertTrue(im.dtype, np.uint8) always passed (2nd arg is the msg);
    # assertEqual actually checks the dtype
    self.assertEqual(im.dtype, np.uint8)
    # There should now be 2*2 the amount of tiles, same shape
    tile_dir = out_config['tile']['tile_dir']
    tile_meta = aux_utils.read_meta(tile_dir)
    # 4 processed channels (0, 1, 3, 4), 24 tiles per image
    expected_rows = 4 * 24 * len(self.slice_ids) * len(self.pos_ids)
    self.assertEqual(tile_meta.shape[0], expected_rows)
    # Load a tile and assert shape
    im = np.load(os.path.join(
        tile_dir,
        'im_c001_z000_t000_p007_r40-50_c20-30_sl0-1.npy',
    ))
    self.assertTupleEqual(im.shape, (1, 10, 10))
    self.assertTrue(im.dtype == np.float64)
def _get_input_fnames(self,
                      time_idx,
                      channel_idx,
                      slice_idx,
                      pos_idx,
                      mask_dir=None):
    """Get input_fnames

    :param int time_idx: Time index
    :param int channel_idx: Channel index
    :param int slice_idx: Slice (z) index
    :param int pos_idx: Position (FOV) index
    :param str mask_dir: Directory containing masks
    :return: list of input fnames
    """
    if mask_dir is None:
        depth = self.channel_depth[channel_idx]
    else:
        depth = self.mask_depth
    margin = 0 if depth == 1 else depth // 2
    # Mask metadata is loop-invariant: read it once instead of once
    # per z slice as the original did
    mask_meta = None
    if mask_dir is not None:
        mask_meta = aux_utils.read_meta(mask_dir)
    im_fnames = []
    # Collect file names for the z-stack centered on slice_idx
    for z in range(slice_idx - margin, slice_idx + margin + 1):
        if mask_dir is not None:
            meta_idx = aux_utils.get_meta_idx(
                mask_meta,
                time_idx,
                channel_idx,
                z,
                pos_idx,
            )
            file_path = os.path.join(
                mask_dir,
                mask_meta.loc[meta_idx, 'file_name'],
            )
        else:
            meta_idx = aux_utils.get_meta_idx(
                self.frames_metadata,
                time_idx,
                channel_idx,
                z,
                pos_idx,
            )
            file_path = os.path.join(
                self.input_dir,
                self.frames_metadata.loc[meta_idx, 'file_name'],
            )
        # TODO: check if file_path exists
        im_fnames.append(file_path)
    return im_fnames
def test_validate_mask_meta_no_channel(self):
    """Masks without a given channel get assigned the next free index."""
    assigned_channel = preprocess_utils.validate_mask_meta(
        mask_dir=self.mask_dir,
        input_dir=self.input_dir,
        csv_name=self.csv_name,
    )
    self.assertEqual(assigned_channel, 4)
    written_meta = aux_utils.read_meta(self.mask_dir)
    # Each row should carry the fixture indices plus the new mask channel
    for i, row in written_meta.iterrows():
        self.assertEqual(row.time_idx, self.time_idx)
        self.assertEqual(row.slice_idx, self.slice_idx)
        self.assertEqual(row.pos_idx, i)
        self.assertEqual(row.channel_idx, 4)
        self.assertEqual(row.file_name, "mask_{}.png".format(i + 1))
def _get_split_ids(self, data_split='test'):
    """
    Get the indices for data_split

    :param str data_split: in [train, val, test]
    :return str split_col: Dataframe column name, which was split in training
    :return list inference_ids: Indices for inference given data split
    """
    split_col = self.config['dataset']['split_by_column']
    try:
        split_fname = os.path.join(self.model_dir, 'split_samples.json')
        split_samples = aux_utils.read_json(split_fname)
        inference_ids = split_samples[data_split]
    except FileNotFoundError:
        # No recorded split: fall back to predicting on every image
        print("No split_samples file. "
              "Will predict all images in dir.")
        frames_meta = aux_utils.read_meta(self.image_dir)
        inference_ids = np.unique(frames_meta[split_col]).tolist()
    return split_col, inference_ids
def validate_mask_meta(mask_dir,
                       input_dir,
                       csv_name=None,
                       mask_channel=None):
    """
    If user provides existing masks, the mask directory should also contain
    a csv file (not named frames_meta.csv which is reserved for output) with
    two column names: mask_name and file_name. Each row should describe the
    mask name and the corresponding file name. Each file_name should exist
    in input_dir and belong to the same channel.
    This function checks that all file names exist in input_dir and writes a
    frames_meta csv containing mask names with indices corresponding to the
    matched file_name. It also assigns a mask channel number for future
    preprocessing steps like tiling.

    :param str mask_dir: Mask directory
    :param str input_dir: Input image directory, to match masks with images
    :param str/None csv_name: Name of the mask csv in mask_dir; used to
        resolve ambiguity when several csv files are present
    :param int/None mask_channel: Channel idx assigned to masks
    :return int mask_channel: New channel index for masks for writing tiles
    :raises IOError: If no csv file is present in mask_dir
    :raises IOError: If more than one csv file exists in mask_dir
        and no csv_name is provided to resolve ambiguity
    :raises AssertionError: If csv doesn't consist of two columns named
        'mask_name' and 'file_name'
    :raises IndexError: If unable to match file_name in mask_dir csv with
        file_name in input_dir for any given mask row
    """
    input_meta = aux_utils.read_meta(input_dir)
    if mask_channel is None:
        # Assign the next free channel index by default
        mask_channel = int(
            input_meta['channel_idx'].max() + 1
        )

    # Make sure there is a csv file
    if csv_name is not None:
        csv_name = glob.glob(os.path.join(mask_dir, csv_name))
        if len(csv_name) == 1:
            # Use the one existing csv name
            csv_name = csv_name[0]
        else:
            csv_name = None
    # No csv name given, search for it
    if csv_name is None:
        csv_name = glob.glob(os.path.join(mask_dir, '*.csv'))
        if len(csv_name) == 0:
            raise IOError("No csv file present in mask dir")
        else:
            # See if frames_meta is already present, if so, move on
            has_meta = next(
                (s for s in csv_name if 'frames_meta.csv' in s), None)
            if isinstance(has_meta, str):
                # Return existing mask channel from frames_meta
                frames_meta = pd.read_csv(
                    os.path.join(mask_dir, 'frames_meta.csv'),
                )
                # BUG FIX: np.unique returns an ndarray, never a list, so
                # the old isinstance(..., list) guard never fired; check
                # the number of channels directly
                mask_channel = np.unique(frames_meta['channel_idx'])
                assert len(mask_channel) == 1, \
                    "Found more than one mask channel: {}".format(mask_channel)
                mask_channel = mask_channel[0]
                # Convert numpy scalars to native Python ints
                if type(mask_channel).__module__ == 'numpy':
                    mask_channel = mask_channel.item()
                return mask_channel
            elif len(csv_name) == 1:
                # Use the one existing csv name
                csv_name = csv_name[0]
            else:
                # More than one csv file in dir
                raise IOError("More than one csv file present in mask dir",
                              "use csv_name to specify which one to use")

    # Read csv with masks and corresponding input file names
    mask_meta = aux_utils.read_meta(input_dir=mask_dir, meta_fname=csv_name)
    assert len(set(mask_meta).difference({'file_name', 'mask_name'})) == 0, \
        "mask csv should have columns mask_name and file_name " + \
        "(corresponding to the file_name in input_dir)"
    # Check that file_name for each mask_name matches files in input_dir
    file_names = input_meta['file_name']
    # Create dataframe that will store all indices for masks
    out_meta = aux_utils.make_dataframe(nbr_rows=mask_meta.shape[0])
    fname_col = out_meta.columns.get_loc('file_name')
    for i, row in mask_meta.iterrows():
        try:
            file_loc = file_names[file_names == row.file_name].index[0]
        except IndexError as e:
            msg = "Can't find image file name match for {}, error {}".format(
                row.file_name, e)
            raise IndexError(msg)
        # Fill dataframe with row indices from matched image in input dir
        out_meta.iloc[i] = input_meta.iloc[file_loc]
        # BUG FIX: chained indexing (out_meta.iloc[i]['file_name'] = ...)
        # can assign to a temporary copy; use a single positional setter
        out_meta.iat[i, fname_col] = row.mask_name
    assert len(out_meta.channel_idx.unique()) == 1, \
        "Masks should match one input channel only"
    assert mask_channel not in set(input_meta.channel_idx.unique()), \
        "Mask channel {} already exists in image dir".format(mask_channel)

    # Replace channel_idx with new mask channel idx
    out_meta['channel_idx'] = mask_channel
    # Write mask metadata with indices that match input images
    meta_filename = os.path.join(mask_dir, 'frames_meta.csv')
    out_meta.to_csv(meta_filename, sep=",")
    return mask_channel
def __init__(self,
             input_dir,
             output_dir,
             channel_ids,
             flat_field_dir=None,
             time_ids=-1,
             slice_ids=-1,
             pos_ids=-1,
             int2str_len=3,
             uniform_struct=True,
             num_workers=4,
             mask_type='otsu',
             mask_channel=None,
             mask_ext='.npy',
             normalize_im=False):
    """
    :param str input_dir: Directory with image frames
    :param str output_dir: Base output directory
    :param int/list channel_ids: generate mask from the sum of these
        (fluorophore) channel indices (typically just one)
    :param str flat_field_dir: Directory with flatfield images if
        flatfield correction is applied
    :param list/int time_ids: timepoints to consider
    :param int slice_ids: Index of which focal plane (z)
        acquisition to use (default -1 includes all slices)
    :param int pos_ids: Position (FOV) indices to use
    :param int int2str_len: Length of str when converting ints
    :param bool uniform_struct: bool indicator for same structure across
        pos and time points
    :param int num_workers: number of workers for multiprocessing
    :param str mask_type: method to use for generating mask. Needed for
        mapping to the masking function
    :param int mask_channel: channel number assigned to the generated
        masks. If resizing images on a subset of channels, frames_meta is
        from resize dir, which could lead to wrong mask channel being
        assigned.
    :param str mask_ext: '.npy' or '.png'. Save the mask as uint8 PNG or
        NPY files
    :param bool normalize_im: indicator to normalize image based on
        z-score or not
    """
    self.input_dir = input_dir
    self.output_dir = output_dir
    self.flat_field_dir = flat_field_dir
    self.num_workers = num_workers
    self.normalize_im = normalize_im

    self.frames_metadata = aux_utils.read_meta(self.input_dir)
    # Create a unique mask channel number so masks can be treated
    # as a new channel
    if mask_channel is None:
        self.mask_channel = int(
            self.frames_metadata['channel_idx'].max() + 1)
    else:
        self.mask_channel = int(mask_channel)

    metadata_ids, nested_id_dict = aux_utils.validate_metadata_indices(
        frames_metadata=self.frames_metadata,
        time_ids=time_ids,
        channel_ids=channel_ids,
        slice_ids=slice_ids,
        pos_ids=pos_ids,
        uniform_structure=uniform_struct,
    )
    self.time_ids = metadata_ids['time_ids']
    self.channel_ids = metadata_ids['channel_ids']
    self.slice_ids = metadata_ids['slice_ids']
    self.pos_ids = metadata_ids['pos_ids']

    # Create mask_dir as a subdirectory of output_dir
    self.mask_dir = os.path.join(
        self.output_dir,
        'mask_channels_' + '-'.join(map(str, self.channel_ids)),
    )
    os.makedirs(self.mask_dir, exist_ok=True)

    self.int2str_len = int2str_len
    self.uniform_struct = uniform_struct
    self.nested_id_dict = nested_id_dict

    # BUG FIX: the original assertion message was one broken literal with
    # stray quote/concatenation characters embedded in it
    assert mask_type in ['otsu', 'unimodal', 'borders_weight_loss_map'], \
        "Masking method invalid, 'otsu', 'unimodal' and " \
        "'borders_weight_loss_map' are currently supported"
    self.mask_type = mask_type
    self.mask_ext = mask_ext
def tile_mask_stack(self,
                    mask_dir,
                    mask_channel,
                    min_fraction,
                    mask_depth=1):
    """
    Tiles images in the specified channels assuming there are masks
    already created in mask_dir. Only tiles above a certain fraction
    of foreground in mask tile will be saved and added to metadata.

    Saves a csv with columns ['time_idx', 'channel_idx', 'pos_idx',
    'slice_idx', 'file_name'] for all the tiles

    :param str mask_dir: Directory containing masks
    :param int mask_channel: Channel number assigned to mask
    :param float min_fraction: Minimum fraction of foreground in tiled masks
    :param int mask_depth: Depth for mask channel
    """
    # mask depth has to match input or output channel depth
    assert mask_depth <= max(self.channel_depth.values())
    self.mask_depth = mask_depth

    # tile and save masks
    # if mask channel is already tiled
    if self.tiles_exist and mask_channel in self.channel_ids:
        mask_meta_df = pd.read_csv(
            os.path.join(self.tile_dir, 'frames_meta.csv'))
    else:
        # TODO: different masks across timepoints (but MaskProcessor
        # generates mask for tp=0 only)
        mask_fn_args = []
        for slice_idx in self.slice_ids:
            for time_idx in self.time_ids:
                for pos_idx in self.pos_ids:
                    # Evaluate mask, then channels. The masks will influence
                    # tiling indices, so it's not allowed to add masks to
                    # existing tiled data sets (indices will be retrieved
                    # from existing meta)
                    cur_args = self.get_crop_tile_args(
                        channel_idx=mask_channel,
                        time_idx=time_idx,
                        slice_idx=slice_idx,
                        pos_idx=pos_idx,
                        task_type='tile',
                        mask_dir=mask_dir,
                        min_fraction=min_fraction,
                        normalize_im=False,
                    )
                    mask_fn_args.append(cur_args)
        # tile_image uses min_fraction assuming input_image is a bool
        mask_meta_df_list = mp_utils.mp_tile_save(
            mask_fn_args,
            workers=self.num_workers,
        )
        mask_meta_df = pd.concat(mask_meta_df_list, ignore_index=True)
        # Finally, save all the metadata
        mask_meta_df = mask_meta_df.sort_values(by=['file_name'])
        mask_meta_df.to_csv(
            os.path.join(self.tile_dir, 'frames_meta.csv'),
            sep=',',
        )

    # Remove mask_channel from self.channel_ids if included.
    # BUG FIX: the old pop-inside-comprehension pair enumerated
    # channel_ids a second time AFTER the mask channel had already been
    # removed, so the matching normalize_channels entry was never popped,
    # leaving the two lists misaligned
    if mask_channel in self.channel_ids:
        ch_pos = self.channel_ids.index(mask_channel)
        self.channel_ids.pop(ch_pos)
        self.normalize_channels.pop(ch_pos)

    fn_args = []
    for slice_idx in self.slice_ids:
        for time_idx in self.time_ids:
            for pos_idx in np.unique(self.frames_metadata["pos_idx"]):
                # Loop through all channels and tile from indices
                cur_tile_indices = self._get_tile_indices(
                    tiled_meta=mask_meta_df,
                    time_idx=time_idx,
                    channel_idx=mask_channel,
                    pos_idx=pos_idx,
                    slice_idx=slice_idx)
                if np.any(cur_tile_indices):
                    for i, channel_idx in enumerate(self.channel_ids):
                        cur_args = self.get_crop_tile_args(
                            channel_idx,
                            time_idx,
                            slice_idx,
                            pos_idx,
                            task_type='crop',
                            tile_indices=cur_tile_indices,
                            normalize_im=self.normalize_channels[i],
                        )
                        fn_args.append(cur_args)
    tiled_meta_df_list = mp_utils.mp_crop_save(
        fn_args,
        workers=self.num_workers,
    )
    tiled_metadata = pd.concat(tiled_meta_df_list, ignore_index=True)
    # If there's been tiling done already, add to existing metadata
    prev_tiled_metadata = aux_utils.read_meta(self.tile_dir)
    tiled_metadata = pd.concat(
        [
            prev_tiled_metadata.reset_index(drop=True),
            tiled_metadata.reset_index(drop=True)
        ],
        axis=0,
        ignore_index=True,
    )
    # Finally, save all the metadata
    tiled_metadata = tiled_metadata.sort_values(by=['file_name'])
    tiled_metadata.to_csv(
        os.path.join(self.tile_dir, "frames_meta.csv"),
        sep=',',
    )
def test_pre_process(self):
    """Run full preprocessing and validate masks, flatfields and tiles."""
    out_config, runtime = pp.pre_process(self.pp_config, self.base_config)
    # np.float was removed in NumPy 1.24; runtime is a builtin float
    self.assertIsInstance(runtime, float)
    self.assertEqual(
        self.base_config['input_dir'],
        self.image_dir,
    )
    self.assertEqual(
        self.base_config['channel_ids'],
        self.pp_config['channel_ids'],
    )
    self.assertEqual(
        out_config['flat_field']['flat_field_dir'],
        os.path.join(self.output_dir, 'flat_field_images')
    )
    self.assertEqual(
        out_config['masks']['mask_dir'],
        os.path.join(self.output_dir, 'mask_channels_3')
    )
    self.assertEqual(
        out_config['tile']['tile_dir'],
        os.path.join(self.output_dir, 'tiles_10-10_step_10-10'),
    )
    # Make sure new mask channel assignment is correct
    self.assertEqual(out_config['masks']['mask_channel'], 4)
    # Check that masks are generated
    mask_dir = out_config['masks']['mask_dir']
    mask_meta = aux_utils.read_meta(mask_dir)
    mask_names = os.listdir(mask_dir)
    mask_names.pop(mask_names.index('frames_meta.csv'))
    # Validate that all masks are there
    self.assertEqual(
        len(mask_names),
        len(self.slice_ids) * len(self.pos_ids),
    )
    for p in self.pos_ids:
        for z in self.slice_ids:
            im_name = aux_utils.get_im_name(
                channel_idx=out_config['masks']['mask_channel'],
                slice_idx=z,
                time_idx=self.time_idx,
                pos_idx=p,
            )
            im = cv2.imread(
                os.path.join(mask_dir, im_name),
                cv2.IMREAD_ANYDEPTH,
            )
            self.assertTupleEqual(im.shape, (30, 20))
            self.assertTrue(im.dtype == 'uint8')
            self.assertTrue(im_name in mask_names)
            self.assertTrue(im_name in mask_meta['file_name'].tolist())
    # Check flatfield images
    ff_dir = out_config['flat_field']['flat_field_dir']
    ff_names = os.listdir(ff_dir)
    self.assertEqual(len(ff_names), 3)
    for processed_channel in [0, 1, 3]:
        expected_name = "flat-field_channel-{}.npy".format(processed_channel)
        self.assertTrue(expected_name in ff_names)
        im = np.load(os.path.join(ff_dir, expected_name))
        self.assertTrue(im.dtype == np.float64)
        self.assertTupleEqual(im.shape, (30, 20))
    # Check tiles
    tile_dir = out_config['tile']['tile_dir']
    tile_meta = aux_utils.read_meta(tile_dir)
    # 4 processed channels (0, 1, 3, 4), 6 tiles per image
    expected_rows = 4 * 6 * len(self.slice_ids) * len(self.pos_ids)
    self.assertEqual(tile_meta.shape[0], expected_rows)
    # Check indices
    self.assertListEqual(
        tile_meta.channel_idx.unique().tolist(),
        [0, 1, 3, 4],
    )
    self.assertListEqual(
        tile_meta.pos_idx.unique().tolist(),
        self.pos_ids,
    )
    self.assertListEqual(
        tile_meta.slice_idx.unique().tolist(),
        self.slice_ids,
    )
    self.assertListEqual(
        tile_meta.time_idx.unique().tolist(),
        [self.time_idx],
    )
    self.assertListEqual(
        list(tile_meta),
        ['channel_idx', 'col_start', 'file_name',
         'pos_idx', 'row_start', 'slice_idx', 'time_idx']
    )
    self.assertListEqual(
        tile_meta.row_start.unique().tolist(),
        [0, 10, 20],
    )
    self.assertListEqual(
        tile_meta.col_start.unique().tolist(),
        [0, 10],
    )
    # Read one tile and check format
    # r = row start/end idx, c = column start/end, sl = slice start/end
    # sl0-1 signifies depth of 1
    im = np.load(os.path.join(
        tile_dir,
        'im_c001_z000_t000_p007_r10-20_c10-20_sl0-1.npy',
    ))
    self.assertTupleEqual(im.shape, (1, 10, 10))
    self.assertTrue(im.dtype == np.float64)
def __init__(self,
             input_dir,
             output_dir,
             tile_size=[256, 256],
             step_size=[64, 64],
             depths=1,
             time_ids=-1,
             channel_ids=-1,
             normalize_channels=-1,
             slice_ids=-1,
             pos_ids=-1,
             hist_clip_limits=None,
             flat_field_dir=None,
             image_format='zyx',
             num_workers=4,
             int2str_len=3,
             tile_3d=False):
    """
    Tiles images.
    If tile_dir already exist, it will check which channels are already
    tiled, get indices from them and tile from indices only on the
    channels not already present.

    :param str input_dir: Directory with frames to be tiled
    :param str output_dir: Base output directory
    :param list tile_size: size of the blocks to be cropped
        from the image
    :param list step_size: size of the window shift. In case of no
        overlap, the step size is tile_size. If overlap, step_size <
        tile_size
    :param int/list depths: The z depth for generating stack training data
        Default 1 assumes 2D data for all channels to be tiled.
        For cases where input and target shapes are not the same (e.g. stack
        to 2D) you should specify depths for each channel in tile.channels.
    :param list/int time_ids: Tile given timepoint indices
    :param list/int channel_ids: Tile images in the given channel indices
        default=-1, tile all channels.
    :param list/int normalize_channels: list of booleans matching channel_ids
        indicating if channel should be normalized or not.
    :param int slice_ids: Index of which focal plane acquisition to
        use (for 2D). default=-1 for the whole z-stack
    :param int pos_ids: Position (FOV) indices to use
    :param list hist_clip_limits: lower and upper percentiles used for
        histogram clipping.
    :param str flat_field_dir: Flatfield directory. None if no flatfield
        correction
    :param str image_format: zyx (preferred) or xyz
    :param int num_workers: number of workers for multiprocessing
    :param int int2str_len: number of characters for each idx to be used
        in file names
    :param bool tile_3d: Whether tiling is 3D or 2D
    """
    self.input_dir = input_dir
    self.output_dir = output_dir
    self.depths = depths
    self.tile_size = tile_size
    self.step_size = step_size
    self.hist_clip_limits = hist_clip_limits
    self.image_format = image_format
    assert self.image_format in {'zyx', 'xyz'}, \
        'Data format must be zyx or xyz'
    self.num_workers = num_workers
    self.int2str_len = int2str_len
    self.tile_3d = tile_3d

    self.str_tile_step = 'tiles_{}_step_{}'.format(
        '-'.join([str(val) for val in tile_size]),
        '-'.join([str(val) for val in step_size]),
    )
    self.tile_dir = os.path.join(
        output_dir,
        self.str_tile_step,
    )
    # If tile dir already exist, only tile channels not already present
    self.tiles_exist = False
    # If tile dir already exist, things could get messy because we don't
    # have any checks in place for how to add to existing tiles
    try:
        os.makedirs(self.tile_dir, exist_ok=False)
        # make dir for saving indiv meta per image, could be used for
        # tracking job success / fail
        os.makedirs(os.path.join(self.tile_dir, 'meta_dir'),
                    exist_ok=False)
    except FileExistsError:
        print("Tile dir exists. Only add untiled channels.")
        self.tiles_exist = True
        # make dir for saving individual meta per image, could be used for
        # tracking job success / fail
        os.makedirs(os.path.join(self.tile_dir, 'meta_dir'),
                    exist_ok=True)

    self.flat_field_dir = flat_field_dir
    self.frames_metadata = aux_utils.read_meta(self.input_dir)
    # Get metadata indices
    metadata_ids, _ = aux_utils.validate_metadata_indices(
        frames_metadata=self.frames_metadata,
        time_ids=time_ids,
        channel_ids=channel_ids,
        slice_ids=slice_ids,
        pos_ids=pos_ids,
        uniform_structure=True)
    self.channel_ids = metadata_ids['channel_ids']
    self.time_ids = metadata_ids['time_ids']
    self.slice_ids = metadata_ids['slice_ids']
    self.pos_ids = metadata_ids['pos_ids']
    # The original assigned normalize_channels three times; once is enough
    self.normalize_channels = normalize_channels
    # Determine which channels should be normalized in tiling
    if self.normalize_channels == -1:
        self.normalize_channels = [True] * len(self.channel_ids)
    else:
        assert len(self.normalize_channels) == len(self.channel_ids), \
            "Channel ids {} and normalization list {} mismatch".format(
                self.channel_ids,
                self.normalize_channels,
            )
    # If more than one depth is specified, length must match channel ids
    if isinstance(self.depths, list):
        assert len(self.depths) == len(self.channel_ids), \
            "depths ({}) and channels ({}) length mismatch".format(
                self.depths,
                self.channel_ids,
            )
        # Get max of all specified depths
        max_depth = max(self.depths)
        # Convert channels + depths to dict for lookup
        self.channel_depth = dict(zip(self.channel_ids, self.depths))
    else:
        # If depth is scalar, make depth the same for all channels
        max_depth = self.depths
        self.channel_depth = dict(
            zip(self.channel_ids, [self.depths] * len(self.channel_ids)),
        )
    # Adjust slice margins
    self.slice_ids = aux_utils.adjust_slice_margins(
        slice_ids=self.slice_ids,
        depth=max_depth,
    )
def tile_mask_stack(self,
                    mask_dir,
                    mask_channel,
                    min_fraction,
                    mask_depth=1):
    """
    Tiles images in the specified channels assuming there are masks
    already created in mask_dir. Only mask tiles with at least
    min_fraction foreground are saved and added to the metadata.

    Saves a csv with columns
    ['time_idx', 'channel_idx', 'pos_idx', 'slice_idx', 'file_name']
    for all the tiles

    :param str mask_dir: Directory containing masks
    :param int mask_channel: Channel number assigned to mask
    :param float min_fraction: Min fraction of foreground in tiled masks
    :param int mask_depth: Depth for mask channel
    """
    # mask depth has to match input or output channel depth
    assert mask_depth <= max(self.channel_depth.values())
    self.mask_depth = mask_depth

    # Mask meta is stored in mask dir. If channel_ids= -1, frames_meta
    # will not contain any rows for mask channel. Assuming structure is
    # same across channels. Get time, pos and slice indices for mask
    # channel
    mask_meta_df = aux_utils.read_meta(mask_dir)
    # TODO: different masks across timepoints (but MaskProcessor generates
    # mask for tp=0 only)
    _, mask_nested_id_dict = aux_utils.validate_metadata_indices(
        frames_metadata=mask_meta_df,
        time_ids=self.time_ids,
        channel_ids=mask_channel,
        slice_ids=self.slice_ids,
        pos_ids=self.pos_ids,
        uniform_structure=False
    )
    # Collect t -> {mask_channel: ids} entries for the mask channel only
    mask_ch_ids = {
        tp_idx: {mask_channel: tp_dict[mask_channel]}
        for tp_idx, tp_dict in mask_nested_id_dict.items()
        if mask_channel in tp_dict
    }
    # Tile the mask channel first; its tile indices drive the rest
    first_channel_meta = self.tile_first_channel(
        channel0_ids=mask_ch_ids,
        channel0_depth=mask_depth,
        cur_mask_dir=mask_dir,
        min_fraction=min_fraction,
        is_mask=True,
    )
    # Tile the remaining channels from the mask tile indices
    self.tile_remaining_channels(
        nested_id_dict=self.nested_id_dict,
        tiled_ch_id=mask_channel,
        cur_meta_df=first_channel_meta,
    )
def __init__(self,
             image_dir,
             dataset_config,
             network_config,
             split_col_ids,
             image_format='zyx',
             mask_dir=None,
             flat_field_dir=None):
    """Init

    :param str image_dir: dir containing images AND NOT TILES!
    :param dict dataset_config: dict with dataset related params
    :param dict network_config: dict with network related params
    :param tuple split_col_ids: How to split up the dataset for inference:
        for frames_meta: (str split column name, list split row indices)
    :param str image_format: xyz or zyx format
    :param str/None mask_dir: If inference targets are masks stored in a
        different directory than the image dir. Assumes the directory
        contains a frames_meta.csv containing mask channels (which will be
        target channels in the inference config) z, t, p indices matching
        the ones in image_dir
    :param str flat_field_dir: Directory with flat field images
    """
    self.image_dir = image_dir
    self.target_dir = image_dir
    self.frames_meta = aux_utils.read_meta(self.image_dir)
    self.flat_field_dir = flat_field_dir
    if mask_dir is not None:
        self.target_dir = mask_dir
        # Append mask meta to frames meta.
        # DataFrame.append was deprecated and removed in pandas 2.0;
        # pd.concat is the supported equivalent
        mask_meta = aux_utils.read_meta(mask_dir)
        self.frames_meta = pd.concat(
            [self.frames_meta, mask_meta],
            ignore_index=True,
        )
    # Use only indices selected for inference
    (split_col, split_ids) = split_col_ids
    meta_ids = self.frames_meta[split_col].isin(split_ids)
    self.frames_meta = self.frames_meta[meta_ids]

    assert image_format in {'xyz', 'zyx'}, \
        "Image format should be xyz or zyx, not {}".format(image_format)
    self.image_format = image_format

    # Check if model task (regression or segmentation) is specified
    self.model_task = 'regression'
    if 'model_task' in dataset_config:
        self.model_task = dataset_config['model_task']
    assert self.model_task in {'regression', 'segmentation'}, \
        "Model task must be either 'segmentation' or 'regression'"
    # the raw input images have to be normalized (z-score typically)
    self.normalize = True if self.model_task == 'regression' else False

    self.input_channels = dataset_config['input_channels']
    self.target_channels = dataset_config['target_channels']
    # get a subset of frames meta for only one channel to easily
    # extract indices (pos, time, slice) to iterate over
    df_idx = (self.frames_meta['channel_idx'] == self.target_channels[0])
    self.iteration_meta = self.frames_meta.copy()
    self.iteration_meta = self.iteration_meta[df_idx]

    self.depth = 1
    self.target_depth = 1
    # adjust slice margins if stacktostack or stackto2d
    network_cls = network_config['class']
    if network_cls in ['UNetStackTo2D', 'UNetStackToStack']:
        self.depth = network_config['depth']
        self.adjust_slice_indices()

    # if Unet2D 4D tensor, remove the singleton dimension, else 5D
    self.squeeze = False
    if network_cls == 'UNet2D':
        self.squeeze = True
    self.im_3d = False
    if network_cls == 'UNet3D':
        self.im_3d = True
    self.data_format = 'channels_first'
    if 'data_format' in network_config:
        self.data_format = network_config['data_format']
    # check if sorted values look right
    self.iteration_meta = self.iteration_meta.sort_values(
        ['pos_idx', 'slice_idx'],
        ascending=[True, True],
    )
    self.iteration_meta = self.iteration_meta.reset_index(drop=True)
    self.num_samples = len(self.iteration_meta)
def __init__(self,
             train_config,
             inference_config,
             gpu_id=-1,
             gpu_mem_frac=None):
    """Init

    :param dict train_config: Training config dict with params related
        to dataset, trainer and network
    :param dict inference_config: Read yaml file with following parameters:
        str model_dir: Path to model directory
        str/None model_fname: File name of weights in model dir (.hdf5).
            If left out, latest weights file will be selected.
        str image_dir: dir containing input images AND NOT TILES!
        str data_split: Which data (train/test/val) to run inference on.
            (default = test)
        dict images:
            str image_format: 'zyx' or 'xyz'
            str/None flat_field_dir: flatfield directory
            str im_ext: For writing predictions e.g. '.png' or '.npy' or
                '.tiff' FOR 3D IMAGES USE NPY AS PNG AND TIFF ARE CURRENTLY
                NOT SUPPORTED.
            list crop_shape: center crop the image to a specified shape
                before tiling for inference
        dict metrics:
            list metrics_list: list of metrics to estimate. available
                metrics: [ssim, corr, r2, mse, mae}]
            list metrics_orientations: xy, xyz, xz or yz
                (see evaluation_metrics.py for description of orientations)
        dict masks: dict with keys
            str mask_dir: Mask directory containing a frames_meta.csv
                containing mask channels (which will be target channels in
                the inference config) z, t, p indices matching the ones in
                image_dir. Mask dirs are often generated or have frames_meta
                added to them during preprocessing.
            str mask_type: 'target' for segmentation, 'metrics' for weighted
            int mask_channel: mask channel as in training
        dict inference_3d: dict with params for 3D inference with keys:
            num_slices, inf_shape, tile_shape, num_overlap,
            overlap_operation.
            int num_slices: in case of 3D, the full volume will not fit in
                GPU memory, specify the number of slices to use and this
                will depend on the network depth, for ex 8 for a network
                of depth 4.
            list inf_shape: inference on a center sub volume.
            list tile_shape: shape of tile for tiling along xyz.
            int/list num_overlap: int for tile_z, list for tile_xyz
            str overlap_operation: e.g. 'mean'
    :param int gpu_id: GPU number to use. -1 for debugging (no GPU)
    :param float/None gpu_mem_frac: Memory fractions to use corresponding
        to gpu_ids
    """
    # Use model_dir from inference config if present, otherwise use train
    if 'model_dir' in inference_config:
        model_dir = inference_config['model_dir']
    else:
        model_dir = train_config['trainer']['model_dir']
    if 'model_fname' in inference_config:
        model_fname = inference_config['model_fname']
    else:
        # If model filename not listed, grab latest one.
        # BUG FIX: list the resolved model_dir, not
        # inference_config['model_dir'], which raises KeyError exactly when
        # model_dir fell back to the train config above.
        fnames = [
            f for f in os.listdir(model_dir)
            if f.endswith('.hdf5')
        ]
        assert len(fnames) > 0, 'No weight files found in model dir'
        fnames = natsort.natsorted(fnames)
        model_fname = fnames[-1]
    self.config = train_config
    self.model_dir = model_dir
    self.image_dir = inference_config['image_dir']

    # Set default for data split, determine column name and indices
    data_split = 'test'
    if 'data_split' in inference_config:
        data_split = inference_config['data_split']
    assert data_split in ['train', 'val', 'test'], \
        'data_split not in [train, val, test]'
    split_col_ids = self._get_split_ids(data_split)

    self.data_format = self.config['network']['data_format']
    assert self.data_format in {'channels_first', 'channels_last'}, \
        "Data format should be channels_first/last"

    flat_field_dir = None
    images_dict = inference_config['images']
    if 'flat_field_dir' in images_dict:
        flat_field_dir = images_dict['flat_field_dir']
    # Set defaults
    self.image_format = 'zyx'
    if 'image_format' in images_dict:
        self.image_format = images_dict['image_format']
    self.image_ext = '.png'
    if 'image_ext' in images_dict:
        self.image_ext = images_dict['image_ext']

    # Create image subdirectory to write predicted images
    self.pred_dir = os.path.join(self.model_dir, 'predictions')
    os.makedirs(self.pred_dir, exist_ok=True)

    # Handle masks as either targets or for masked metrics
    self.masks_dict = None
    self.mask_metrics = False
    self.mask_dir = None
    self.mask_meta = None
    target_dir = None
    if 'masks' in inference_config:
        self.masks_dict = inference_config['masks']
    if self.masks_dict is not None:
        assert 'mask_channel' in self.masks_dict, 'mask_channel is needed'
        assert 'mask_dir' in self.masks_dict, 'mask_dir is needed'
        self.mask_dir = self.masks_dict['mask_dir']
        self.mask_meta = aux_utils.read_meta(self.mask_dir)
        assert 'mask_type' in self.masks_dict, \
            'mask_type (target/metrics) is needed'
        if self.masks_dict['mask_type'] == 'metrics':
            # Compute weighted metrics
            self.mask_metrics = True
        else:
            target_dir = self.mask_dir

    # Create dataset instance.
    # BUG FIX: pass self.image_format (which has a 'zyx' default) instead
    # of indexing images_dict directly, which raised KeyError when the
    # optional 'image_format' key was omitted from the config.
    self.dataset_inst = InferenceDataSet(
        image_dir=self.image_dir,
        dataset_config=self.config['dataset'],
        network_config=self.config['network'],
        split_col_ids=split_col_ids,
        image_format=self.image_format,
        mask_dir=target_dir,
        flat_field_dir=flat_field_dir,
    )
    # create an instance of MetricsEstimator
    self.iteration_meta = self.dataset_inst.get_iteration_meta()

    # Handle metrics config settings
    self.metrics_inst = None
    self.metrics_dict = None
    if 'metrics' in inference_config:
        self.metrics_dict = inference_config['metrics']
    if self.metrics_dict is not None:
        assert 'metrics' in self.metrics_dict, \
            'Must specify which metrics to use'
        self.metrics_inst = MetricsEstimator(
            metrics_list=self.metrics_dict['metrics'],
            masked_metrics=self.mask_metrics,
        )
        self.metrics_orientations = ['xy']
        available_orientations = ['xy', 'xyz', 'xz', 'yz']
        if 'metrics_orientations' in self.metrics_dict:
            self.metrics_orientations = \
                self.metrics_dict['metrics_orientations']
        assert set(self.metrics_orientations).\
            issubset(available_orientations),\
            'orientation not in [xy, xyz, xz, yz]'
        self.df_xy = pd.DataFrame()
        self.df_xyz = pd.DataFrame()
        self.df_xz = pd.DataFrame()
        self.df_yz = pd.DataFrame()

    # Handle 3D volume inference settings
    self.num_overlap = 0
    self.stitch_inst = None
    self.tile_option = None
    self.z_dim = 2
    self.crop_shape = None
    if 'crop_shape' in images_dict:
        self.crop_shape = images_dict['crop_shape']
    if 'inference_3d' in inference_config:
        self.params_3d = inference_config['inference_3d']
        self._assign_3d_inference()
        # Make image ext npy default for 3D
        self.image_ext = '.npy'

    # Set session if not debug
    if gpu_id >= 0:
        self.sess = set_keras_session(
            gpu_ids=gpu_id,
            gpu_mem_frac=gpu_mem_frac,
        )
    # create model and load weights
    self.model = inference.load_model(
        network_config=self.config['network'],
        model_fname=os.path.join(self.model_dir, model_fname),
        predict=True,
    )
def test_pre_process_weight_maps(self):
    """Run preprocessing with weight maps enabled and check all outputs."""
    cur_config = self.pp_config
    # Use preexisiting masks with more than one class, otherwise
    # weight map generation doesn't work
    cur_config['masks'] = {
        'mask_dir': self.input_mask_dir,
        'mask_channel': self.input_mask_channel,
    }
    cur_config['make_weight_map'] = True
    out_config, runtime = pp.pre_process(cur_config, self.base_config)

    # Check weights dir
    weights_dir = out_config['weights']['weights_dir']
    self.assertEqual(
        weights_dir,
        os.path.join(self.output_dir, 'mask_channels_111'),
    )
    weights_meta = aux_utils.read_meta(weights_dir)
    # Check indices, column by column
    expected_weight_ids = {
        'channel_idx': [112],
        'pos_idx': self.pos_ids,
        'slice_idx': self.slice_ids,
        'time_idx': [self.time_idx],
    }
    for col_name, expected_ids in expected_weight_ids.items():
        self.assertListEqual(
            weights_meta[col_name].unique().tolist(),
            expected_ids,
        )
    # Load one weights file and check contents
    weights_im = np.load(os.path.join(
        weights_dir,
        'im_c112_z002_t000_p007.npy',
    ))
    self.assertTupleEqual(weights_im.shape, (30, 20))
    self.assertTrue(weights_im.dtype == np.float64)

    # Check tiles
    tile_dir = out_config['tile']['tile_dir']
    tile_meta = aux_utils.read_meta(tile_dir)
    # 5 processed channels (0, 1, 3, 111, 112), 6 tiles per image
    expected_rows = 5 * 6 * len(self.slice_ids) * len(self.pos_ids)
    self.assertEqual(tile_meta.shape[0], expected_rows)
    # Check indices, column by column
    expected_tile_ids = {
        'channel_idx': [0, 1, 3, 111, 112],
        'pos_idx': self.pos_ids,
        'slice_idx': self.slice_ids,
        'time_idx': [self.time_idx],
    }
    for col_name, expected_ids in expected_tile_ids.items():
        self.assertListEqual(
            tile_meta[col_name].unique().tolist(),
            expected_ids,
        )
    # Load one tile
    tile_im = np.load(os.path.join(
        tile_dir,
        'im_c111_z002_t000_p008_r0-10_c10-20_sl0-1.npy',
    ))
    self.assertTupleEqual(tile_im.shape, (1, 10, 10))
    self.assertTrue(tile_im.dtype == bool)