def __init__(self,
             input_dir,
             output_dir,
             channel_ids,
             slice_ids,
             block_size=32):
    """
    Flatfield images are estimated once per channel for 2D data.

    :param str input_dir: Directory with 2D image frames from dataset
    :param str output_dir: Base output directory
    :param int/list channel_ids: Channel ids for flatfield correction
    :param int/list slice_ids: Z slice indices for flatfield correction
    :param int block_size: Size of blocks the image will be divided into
    """
    self.input_dir = input_dir
    self.output_dir = output_dir
    # Create flat_field_dir as a subdirectory of output_dir
    self.flat_field_dir = os.path.join(self.output_dir, 'flat_field_images')
    os.makedirs(self.flat_field_dir, exist_ok=True)

    self.frames_metadata = aux_utils.read_meta(self.input_dir)
    metadata_ids, _ = aux_utils.validate_metadata_indices(
        frames_metadata=self.frames_metadata,
        channel_ids=channel_ids,
        slice_ids=slice_ids,
    )
    self.channels_ids = metadata_ids['channel_ids']
    self.slice_ids = metadata_ids['slice_ids']

    if block_size is None:
        block_size = 32
    self.block_size = block_size
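# A minimal, self-contained sketch of what block-wise sampling for flatfield
# estimation could look like: the image is divided into block_size x block_size
# blocks and one summary value (here the median) is taken per block. This is an
# assumption for illustration only, not this class's actual estimation routine.
import numpy as np


def sample_block_medians_sketch(im, block_size=32):
    """Return per-block medians and block-center coordinates (illustrative)."""
    rows, cols = im.shape
    sample_values = []
    sample_coords = []
    for r in range(0, rows - block_size + 1, block_size):
        for c in range(0, cols - block_size + 1, block_size):
            block = im[r:r + block_size, c:c + block_size]
            sample_values.append(np.median(block))
            sample_coords.append((r + block_size // 2, c + block_size // 2))
    return np.array(sample_values), np.array(sample_coords)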
def __init__(self,
             input_dir,
             output_dir,
             tile_size=[256, 256],
             step_size=[64, 64],
             depths=1,
             time_ids=-1,
             channel_ids=-1,
             normalize_channels=-1,
             slice_ids=-1,
             pos_ids=-1,
             hist_clip_limits=None,
             flat_field_dir=None,
             image_format='zyx',
             num_workers=4,
             int2str_len=3,
             tile_3d=False):
    """Init

    Assumes the same structure across channels and the same number of samples
    across channels. The dataset could have a varying number of time points
    and/or a varying number of slices / sizes for each sample / position.

    Please refer to the init of ImageTilerUniform for parameter descriptions.
    """
    super().__init__(input_dir=input_dir,
                     output_dir=output_dir,
                     tile_size=tile_size,
                     step_size=step_size,
                     depths=depths,
                     time_ids=time_ids,
                     channel_ids=channel_ids,
                     normalize_channels=normalize_channels,
                     slice_ids=slice_ids,
                     pos_ids=pos_ids,
                     hist_clip_limits=hist_clip_limits,
                     flat_field_dir=flat_field_dir,
                     image_format=image_format,
                     num_workers=num_workers,
                     int2str_len=int2str_len,
                     tile_3d=tile_3d)

    # Get metadata indices
    metadata_ids, nested_id_dict = aux_utils.validate_metadata_indices(
        frames_metadata=self.frames_metadata,
        time_ids=time_ids,
        channel_ids=channel_ids,
        slice_ids=slice_ids,
        pos_ids=pos_ids,
        uniform_structure=False,
    )
    self.nested_id_dict = nested_id_dict

    # self.tile_dir is already created in super(). Check if frames_meta
    # exists in self.tile_dir
    meta_path = os.path.join(self.tile_dir, 'frames_meta.csv')
    assert not os.path.exists(meta_path), \
        'Tile dir exists. Cannot add to an existing dir.'
def __init__(self,
             input_dir,
             output_dir,
             scale_factor,
             channel_ids=-1,
             time_ids=-1,
             slice_ids=-1,
             pos_ids=-1,
             int2str_len=3,
             num_workers=4,
             flat_field_dir=None):
    """
    :param str input_dir: Directory with image frames
    :param str output_dir: Base output directory
    :param float/list scale_factor: Scale factor for resizing frames
    :param int/list channel_ids: Channel indices to resize
        (default -1 includes all channels)
    :param int/list time_ids: Timepoint indices to use
    :param int/list slice_ids: Slice (z) indices to use
    :param int/list pos_ids: Position (FOV) indices to use
    :param int int2str_len: Length of str when converting ints
    :param int num_workers: Number of workers for multiprocessing
    :param str flat_field_dir: Directory with flatfield images
    """
    self.input_dir = input_dir
    self.output_dir = output_dir
    if isinstance(scale_factor, list):
        scale_factor = np.array(scale_factor)
    assert np.all(scale_factor > 0), \
        "Scale factor should be positive, not {}".format(scale_factor)
    self.scale_factor = scale_factor

    self.frames_metadata = aux_utils.read_meta(self.input_dir)
    metadata_ids, _ = aux_utils.validate_metadata_indices(
        frames_metadata=self.frames_metadata,
        time_ids=time_ids,
        channel_ids=channel_ids,
        slice_ids=slice_ids,
        pos_ids=pos_ids,
    )
    self.time_ids = metadata_ids['time_ids']
    self.channel_ids = metadata_ids['channel_ids']
    self.slice_ids = metadata_ids['slice_ids']
    self.pos_ids = metadata_ids['pos_ids']

    # Create resize_dir as a subdirectory of output_dir
    self.resize_dir = os.path.join(
        self.output_dir,
        'resized_images',
    )
    os.makedirs(self.resize_dir, exist_ok=True)

    self.int2str_len = int2str_len
    self.num_workers = num_workers
    self.flat_field_dir = flat_field_dir
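# A minimal sketch of what applying scale_factor to a single frame could look
# like, assuming scipy is available. A scalar scale_factor scales all axes
# equally; a list gives one factor per axis. This is an illustration only and
# not necessarily the resize routine used by the class above.
import numpy as np
from scipy import ndimage


def resize_frame_sketch(frame, scale_factor):
    """Resize a 2D/3D frame by a scalar or per-axis scale factor."""
    scale_factor = np.atleast_1d(scale_factor).astype(float)
    assert np.all(scale_factor > 0), "Scale factor should be positive"
    if scale_factor.size == 1:
        zoom = float(scale_factor[0])   # same factor for every axis
    else:
        assert scale_factor.size == frame.ndim, \
            "Need one scale factor per axis"
        zoom = scale_factor.tolist()
    return ndimage.zoom(frame, zoom)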
def test_validate_metadata_indices():
    metadata_ids, tp_dict = aux_utils.validate_metadata_indices(
        meta_df,
        time_ids=-1,
        channel_ids=-1,
        slice_ids=-1,
        pos_ids=-1,
    )
    nose.tools.assert_list_equal(metadata_ids['channel_ids'].tolist(), [5])
    nose.tools.assert_list_equal(metadata_ids['slice_ids'].tolist(), [0, 1, 2])
    nose.tools.assert_list_equal(metadata_ids['time_ids'].tolist(), [6])
    nose.tools.assert_list_equal(metadata_ids['pos_ids'].tolist(), [0, 1, 2, 3])
    nose.tools.assert_is_none(tp_dict)
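# The test above expects channel 5, slices 0-2, timepoint 6 and positions 0-3
# in the metadata. A meta_df with those indices could be built like the sketch
# below (column names follow the frames_meta convention used elsewhere in this
# code; the exact fixture used by the test is not shown here, so this is an
# assumed reconstruction).
import itertools
import pandas as pd

rows = []
for slice_idx, pos_idx in itertools.product(range(3), range(4)):
    rows.append({
        'channel_idx': 5,
        'slice_idx': slice_idx,
        'time_idx': 6,
        'pos_idx': pos_idx,
        'file_name': 'im_c005_z{:03d}_t006_p{:03d}.png'.format(
            slice_idx, pos_idx),
    })
meta_df = pd.DataFrame(rows)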
def __init__(self,
             input_dir,
             output_dir,
             channel_ids,
             flat_field_dir=None,
             time_ids=-1,
             slice_ids=-1,
             pos_ids=-1,
             int2str_len=3,
             uniform_struct=True,
             num_workers=4,
             mask_type='otsu',
             mask_channel=None,
             mask_ext='.npy',
             normalize_im=False):
    """
    :param str input_dir: Directory with image frames
    :param str output_dir: Base output directory
    :param int/list channel_ids: Generate masks from the sum of these
        (fluorophore) channel indices (typically just one)
    :param str flat_field_dir: Directory with flatfield images if
        flatfield correction is applied
    :param int/list time_ids: Timepoints to consider
    :param int/list slice_ids: Indices of which focal planes (z) to use
        (default -1 includes all slices)
    :param int/list pos_ids: Position (FOV) indices to use
    :param int int2str_len: Length of str when converting ints
    :param bool uniform_struct: Indicator for same structure across
        positions and time points
    :param int num_workers: Number of workers for multiprocessing
    :param str mask_type: Method to use for generating masks. Needed for
        mapping to the masking function
    :param int mask_channel: Channel number assigned to the generated masks.
        If resizing images on a subset of channels, frames_meta is from the
        resize dir, which could lead to the wrong mask channel being assigned.
    :param str mask_ext: '.npy' or '.png'. Save the masks as uint8 PNG or
        NPY files
    :param bool normalize_im: Indicator to normalize images based on
        z-score or not
    """
    self.input_dir = input_dir
    self.output_dir = output_dir
    self.flat_field_dir = flat_field_dir
    self.num_workers = num_workers
    self.normalize_im = normalize_im

    self.frames_metadata = aux_utils.read_meta(self.input_dir)
    # Create a unique mask channel number so masks can be treated
    # as a new channel
    if mask_channel is None:
        self.mask_channel = int(self.frames_metadata['channel_idx'].max() + 1)
    else:
        self.mask_channel = int(mask_channel)

    metadata_ids, nested_id_dict = aux_utils.validate_metadata_indices(
        frames_metadata=self.frames_metadata,
        time_ids=time_ids,
        channel_ids=channel_ids,
        slice_ids=slice_ids,
        pos_ids=pos_ids,
        uniform_structure=uniform_struct,
    )
    self.time_ids = metadata_ids['time_ids']
    self.channel_ids = metadata_ids['channel_ids']
    self.slice_ids = metadata_ids['slice_ids']
    self.pos_ids = metadata_ids['pos_ids']

    # Create mask_dir as a subdirectory of output_dir
    self.mask_dir = os.path.join(
        self.output_dir,
        'mask_channels_' + '-'.join(map(str, self.channel_ids)),
    )
    os.makedirs(self.mask_dir, exist_ok=True)

    self.int2str_len = int2str_len
    self.uniform_struct = uniform_struct
    self.nested_id_dict = nested_id_dict

    assert mask_type in ['otsu', 'unimodal', 'borders_weight_loss_map'], \
        "Masking method invalid: otsu, unimodal and " \
        "borders_weight_loss_map are currently supported"
    self.mask_type = mask_type
    self.mask_ext = mask_ext
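# A minimal sketch of the kind of binary mask the 'otsu' option presumably
# produces: threshold the (summed) fluorescence image with Otsu's method and
# keep pixels above the threshold. This is an illustration only, not the
# masking utility this class actually dispatches to.
import numpy as np
from skimage.filters import threshold_otsu


def otsu_mask_sketch(summed_channels_im):
    """Return a uint8 foreground mask from a summed fluorescence image."""
    thresh = threshold_otsu(summed_channels_im)
    mask = summed_channels_im > thresh
    return mask.astype(np.uint8)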
def __init__(self,
             input_dir,
             output_dir,
             tile_size=[256, 256],
             step_size=[64, 64],
             depths=1,
             time_ids=-1,
             channel_ids=-1,
             normalize_channels=-1,
             slice_ids=-1,
             pos_ids=-1,
             hist_clip_limits=None,
             flat_field_dir=None,
             image_format='zyx',
             num_workers=4,
             int2str_len=3,
             tile_3d=False):
    """
    Tiles images. If tile_dir already exists, it checks which channels have
    already been tiled, gets the tile indices from them, and tiles from those
    indices only the channels that are not already present.

    :param str input_dir: Directory with frames to be tiled
    :param str output_dir: Base output directory
    :param list tile_size: Size of the blocks to be cropped from the image
    :param list step_size: Size of the window shift. With no overlap,
        step_size equals tile_size; with overlap, step_size < tile_size
    :param int/list depths: The z depth for generating stack training data.
        Default 1 assumes 2D data for all channels to be tiled.
        For cases where input and target shapes are not the same (e.g. stack
        to 2D), specify depths for each channel in tile.channels.
    :param list/int time_ids: Tile given timepoint indices
    :param list/int channel_ids: Tile images in the given channel indices.
        default=-1, tile all channels
    :param list/int normalize_channels: List of booleans matching channel_ids
        indicating if a channel should be normalized or not
    :param int/list slice_ids: Index of which focal plane acquisition to
        use (for 2D). default=-1 for the whole z-stack
    :param int/list pos_ids: Position (FOV) indices to use
    :param list hist_clip_limits: Lower and upper percentiles used for
        histogram clipping
    :param str flat_field_dir: Flatfield directory. None if no flatfield
        correction
    :param str image_format: zyx (preferred) or xyz
    :param int num_workers: Number of workers for multiprocessing
    :param int int2str_len: Number of characters for each idx to be used
        in file names
    :param bool tile_3d: Whether tiling is 3D or 2D
    """
    self.input_dir = input_dir
    self.output_dir = output_dir
    self.depths = depths
    self.tile_size = tile_size
    self.step_size = step_size
    self.hist_clip_limits = hist_clip_limits
    self.image_format = image_format
    assert self.image_format in {'zyx', 'xyz'}, \
        'Data format must be zyx or xyz'
    self.num_workers = num_workers
    self.int2str_len = int2str_len
    self.tile_3d = tile_3d

    self.str_tile_step = 'tiles_{}_step_{}'.format(
        '-'.join([str(val) for val in tile_size]),
        '-'.join([str(val) for val in step_size]),
    )
    self.tile_dir = os.path.join(
        output_dir,
        self.str_tile_step,
    )
    # If tile dir already exists, only tile channels not already present
    self.tiles_exist = False
    # If tile dir already exists, things could get messy because we don't
    # have any checks in place for how to add to existing tiles
    try:
        os.makedirs(self.tile_dir, exist_ok=False)
        # Make dir for saving individual meta per image, could be used for
        # tracking job success / fail
        os.makedirs(os.path.join(self.tile_dir, 'meta_dir'), exist_ok=False)
    except FileExistsError:
        print("Tile dir exists. Only adding untiled channels.")
        self.tiles_exist = True
        # Make dir for saving individual meta per image, could be used for
        # tracking job success / fail
        os.makedirs(os.path.join(self.tile_dir, 'meta_dir'), exist_ok=True)

    self.flat_field_dir = flat_field_dir
    self.frames_metadata = aux_utils.read_meta(self.input_dir)
    # Get metadata indices
    metadata_ids, _ = aux_utils.validate_metadata_indices(
        frames_metadata=self.frames_metadata,
        time_ids=time_ids,
        channel_ids=channel_ids,
        slice_ids=slice_ids,
        pos_ids=pos_ids,
        uniform_structure=True,
    )
    self.channel_ids = metadata_ids['channel_ids']
    self.time_ids = metadata_ids['time_ids']
    self.slice_ids = metadata_ids['slice_ids']
    self.pos_ids = metadata_ids['pos_ids']
    self.normalize_channels = normalize_channels
    # Determine which channels should be normalized in tiling
    if self.normalize_channels == -1:
        self.normalize_channels = [True] * len(self.channel_ids)
    else:
        assert len(self.normalize_channels) == len(self.channel_ids), \
            "Channel ids {} and normalization list {} mismatch".format(
                self.channel_ids,
                self.normalize_channels,
            )

    # If more than one depth is specified, length must match channel ids
    if isinstance(self.depths, list):
        assert len(self.depths) == len(self.channel_ids), \
            "depths ({}) and channels ({}) length mismatch".format(
                self.depths,
                self.channel_ids,
            )
        # Get max of all specified depths
        max_depth = max(self.depths)
        # Convert channels + depths to dict for lookup
        self.channel_depth = dict(zip(self.channel_ids, self.depths))
    else:
        # If depth is scalar, make depth the same for all channels
        max_depth = self.depths
        self.channel_depth = dict(
            zip(self.channel_ids, [self.depths] * len(self.channel_ids)),
        )
    # Adjust slice margins
    self.slice_ids = aux_utils.adjust_slice_margins(
        slice_ids=self.slice_ids,
        depth=max_depth,
    )
def tile_mask_stack(self,
                    mask_dir,
                    mask_channel,
                    min_fraction,
                    mask_depth=1):
    """
    Tiles images in the specified channels assuming the masks have already
    been created in mask_dir. Only tiles whose mask tile contains at least
    min_fraction foreground are saved and added to the metadata.

    Saves a csv with columns
    ['time_idx', 'channel_idx', 'pos_idx', 'slice_idx', 'file_name']
    for all the tiles.

    :param str mask_dir: Directory containing masks
    :param int mask_channel: Channel number assigned to mask
    :param float min_fraction: Min fraction of foreground in tiled masks
    :param int mask_depth: Depth for mask channel
    """
    # Mask depth has to match input or output channel depth
    assert mask_depth <= max(self.channel_depth.values())
    self.mask_depth = mask_depth

    # Mask meta is stored in mask dir. If channel_ids=-1, frames_meta will
    # not contain any rows for the mask channel. Assuming structure is the
    # same across channels, get time, pos and slice indices for mask channel
    mask_meta_df = aux_utils.read_meta(mask_dir)
    # TODO: different masks across timepoints (but MaskProcessor generates
    # masks for tp=0 only)
    _, mask_nested_id_dict = aux_utils.validate_metadata_indices(
        frames_metadata=mask_meta_df,
        time_ids=self.time_ids,
        channel_ids=mask_channel,
        slice_ids=self.slice_ids,
        pos_ids=self.pos_ids,
        uniform_structure=False,
    )

    # Get t, z, p indices for mask_channel
    mask_ch_ids = {}
    for tp_idx, tp_dict in mask_nested_id_dict.items():
        for ch_idx, ch_dict in tp_dict.items():
            if ch_idx == mask_channel:
                ch0_dict = {mask_channel: ch_dict}
                mask_ch_ids[tp_idx] = ch0_dict

    # Tile mask channel and use the tile indices to tile the rest
    meta_df = self.tile_first_channel(
        channel0_ids=mask_ch_ids,
        channel0_depth=mask_depth,
        cur_mask_dir=mask_dir,
        min_fraction=min_fraction,
        is_mask=True,
    )
    # Tile the remaining channels
    self.tile_remaining_channels(
        nested_id_dict=self.nested_id_dict,
        tiled_ch_id=mask_channel,
        cur_meta_df=meta_df,
    )
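# A minimal sketch of the min_fraction criterion described above: a mask tile
# is kept only if the fraction of foreground pixels meets the threshold. This
# mirrors the docstring's description and is not the tiling code itself.
import numpy as np


def keep_mask_tile_sketch(mask_tile, min_fraction):
    """Return True if the binary mask tile has enough foreground."""
    foreground_fraction = np.mean(mask_tile > 0)
    return foreground_fraction >= min_fraction


# Example: a 64x64 tile with a 16x16 foreground square -> fraction 0.0625
tile = np.zeros((64, 64), dtype=np.uint8)
tile[:16, :16] = 1
assert keep_mask_tile_sketch(tile, min_fraction=0.05)
assert not keep_mask_tile_sketch(tile, min_fraction=0.25)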
def __init__(self,
             input_dir,
             input_channel_id,
             output_dir,
             output_channel_id,
             timepoint_id=0,
             correct_flat_field=True,
             focal_plane_idx=0,
             plot_masks=False):
    """Init

    :param str input_dir: Base input dir at the level of individual sample
        images (or the level above timepoint dirs)
    :param int/list/tuple input_channel_id: Channel ids for which masks
        have to be generated
    :param str output_dir: Output dir with full path. It is the base dir
        for tiled images
    :param int/list/tuple output_channel_id: Channel ids to be assigned to
        the created masks. Must match len(input_channel_id), i.e.
        mask(input_channel_id[0]) -> output_channel_id[0]
    :param int/list/tuple timepoint_id: Timepoints to consider
    :param bool correct_flat_field: Indicator to apply flatfield correction
    :param int focal_plane_idx: Focal plane acquisition to use
    :param bool plot_masks: Plot input, masks and overlays
    """
    assert os.path.exists(input_dir), 'input_dir does not exist'
    assert os.path.exists(output_dir), 'output_dir does not exist'
    self.input_dir = input_dir
    self.output_dir = output_dir
    self.correct_flat_field = correct_flat_field

    meta_fname = glob.glob(os.path.join(input_dir, "*info.csv"))
    assert len(meta_fname) == 1, \
        "Can't find info.csv file in {}".format(input_dir)
    study_metadata = pd.read_csv(meta_fname[0])
    self.study_metadata = study_metadata
    self.plot_masks = plot_masks

    avail_tp_channels = aux_utils.validate_metadata_indices(
        study_metadata,
        timepoint_ids=timepoint_id,
        channel_ids=input_channel_id,
    )
    msg = 'timepoint_id is not available'
    assert timepoint_id in avail_tp_channels['timepoints'], msg
    if isinstance(timepoint_id, int):
        timepoint_id = [timepoint_id]
    self.timepoint_id = timepoint_id

    # Convert channel to int if there's only one value present
    if isinstance(input_channel_id, (list, tuple)):
        if len(input_channel_id) == 1:
            input_channel_id = input_channel_id[0]
    msg = 'input_channel_id is not available'
    assert input_channel_id in avail_tp_channels['channels'], msg
    msg = 'output_channel_id is already present'
    assert output_channel_id not in avail_tp_channels['channels'], msg

    if isinstance(input_channel_id, (list, tuple)):
        msg = 'input and output channel ids are not of same length'
        assert len(input_channel_id) == len(output_channel_id), msg
    else:
        input_channel_id = [input_channel_id]
        output_channel_id = [output_channel_id]
    self.input_channel_id = input_channel_id
    self.output_channel_id = output_channel_id
    self.focal_plane_idx = focal_plane_idx
def predict_on_full_image(self,
                          image_meta,
                          test_samples,
                          focal_plane_idx=None,
                          depth=None,
                          per_tile_overlap=1 / 8,
                          flat_field_correct=False,
                          base_image_dir=None,
                          place_operation='mean'):
    """Tile the full image, run inference on the tiles and assemble the
    full prediction.

    :param pd.DataFrame image_meta: Df with individual image info:
        'timepoint', 'channel_num', 'sample_num', 'slice_num', 'fname',
        'size_x_microns', 'size_y_microns', 'size_z_microns'
    :param list test_samples: List of sample numbers to be used in the
        test set
    :param int focal_plane_idx: Focal plane to be used
    :param int depth: If 3D, number of slices used for tiling
    :param float per_tile_overlap: Fraction of overlap between successive
        tiles
    :param bool flat_field_correct: Indicator for applying flatfield
        correction
    :param str base_image_dir: Base directory where images are stored
    :param str place_operation: One of ['mean', 'max']. 'mean' for
        regression tasks, 'max' for segmentation tasks
    """
    assert place_operation in ['mean', 'max'], \
        'only mean and max are allowed: %s' % place_operation
    if 'timepoints' not in self.config['dataset']:
        timepoint_ids = -1
    else:
        timepoint_ids = self.config['dataset']['timepoints']
    ip_channel_ids = self.config['dataset']['input_channels']
    op_channel_ids = self.config['dataset']['target_channels']
    tp_channel_ids = aux_utils.validate_metadata_indices(
        image_meta,
        time_ids=timepoint_ids,
    )
    tp_idx = tp_channel_ids['timepoints']

    tile_size = [
        self.config['network']['height'],
        self.config['network']['width'],
    ]
    if depth is not None:
        assert 'depth' in self.config['network']
        tile_size.insert(0, depth)
    step_size = (1 - per_tile_overlap) * np.array(tile_size)
    step_size = step_size.astype('int')
    step_size[step_size < 1] = 1
    overlap_size = tile_size - step_size
    batch_size = self.config['trainer']['batch_size']

    if flat_field_correct:
        assert base_image_dir is not None
        ff_dir = os.path.join(base_image_dir, 'flat_field_images')
    else:
        ff_dir = None

    for tp in tp_idx:
        # Get the meta for all images in tp_dir and channel_dir
        row_idx_ip0 = aux_utils.get_row_idx(
            image_meta, tp, ip_channel_ids[0],
            slice_idx=focal_plane_idx,
        )
        ip0_meta = image_meta[row_idx_ip0]
        # Get rows corresponding to test_samples from this DF
        test_row_ip0 = ip0_meta.loc[
            ip0_meta['sample_num'].isin(test_samples)
        ]
        test_ip0_fnames = test_row_ip0['fname'].tolist()
        test_image_fnames = [
            fname.split(os.sep)[-1] for fname in test_ip0_fnames
        ]
        tp_dir = str(os.sep).join(test_ip0_fnames[0].split(os.sep)[:-2])
        test_image = np.load(test_ip0_fnames[0])
        _, crop_indices = tile_utils.tile_image(
            test_image,
            tile_size,
            step_size,
            return_index=True,
        )
        pred_dir = os.path.join(
            self.config['trainer']['model_dir'],
            'predicted_images',
            'tp_{}'.format(tp),
        )
        for fname in test_image_fnames:
            target_image = self._read_one(
                tp_dir, op_channel_ids, fname, ff_dir,
            )
            input_image = self._read_one(
                tp_dir, ip_channel_ids, fname, ff_dir,
            )
            pred_tiles = self._pred_image(
                input_image,
                crop_indices,
                batch_size,
            )
            pred_image = self._stitch_image(
                pred_tiles,
                crop_indices,
                input_image.shape,
                batch_size,
                tile_size,
                overlap_size,
                place_operation,
            )
            pred_fname = '{}.npy'.format(fname.split('.')[0])
            for idx, op_ch in enumerate(op_channel_ids):
                op_dir = os.path.join(pred_dir, 'channel_{}'.format(op_ch))
                os.makedirs(op_dir, exist_ok=True)
                np.save(os.path.join(op_dir, pred_fname), pred_image[idx])
                save_predicted_images(
                    [input_image], [target_image], [pred_image],
                    os.path.join(op_dir, 'collage'),
                    output_fname=fname.split('.')[0],
                )
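# A quick worked example of the tile/step/overlap arithmetic used above,
# assuming a 256x256 network input and the default per_tile_overlap of 1/8.
import numpy as np

tile_size = [256, 256]
per_tile_overlap = 1 / 8
step_size = ((1 - per_tile_overlap) * np.array(tile_size)).astype('int')
step_size[step_size < 1] = 1
overlap_size = np.array(tile_size) - step_size

print(step_size)     # [224 224]  -> tiles advance by 224 px
print(overlap_size)  # [32 32]    -> successive tiles overlap by 32 px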