def tile_and_save(input_fnames,
                  flat_field_fname,
                  hist_clip_limits,
                  time_idx,
                  channel_idx,
                  pos_idx,
                  slice_idx,
                  tile_size,
                  step_size,
                  min_fraction,
                  image_format,
                  save_dir,
                  int2str_len=3,
                  is_mask=False,
                  normalize_im=False):
    """Tile image and save the resulting tiles

    :param tuple input_fnames: tuple of input fnames with full path
    :param str flat_field_fname: fname of flat field image
    :param tuple hist_clip_limits: limits for histogram clipping
    :param int time_idx: time point of input image
    :param int channel_idx: channel idx of input image
    :param int pos_idx: sample idx of input image
    :param int slice_idx: slice idx of input image
    :param list tile_size: size of tile along row, col (& slices)
    :param list step_size: step size along row, col (& slices)
    :param float min_fraction: min foreground volume fraction for keeping a tile
    :param str image_format: zyx / xyz
    :param str save_dir: output dir to save tiles
    :param int int2str_len: len of indices for creating file names
    :param bool is_mask: indicates if files are masks
    :param bool normalize_im: indicates if z-score normalization is needed
    :return: pd.DataFrame from a list of dicts with metadata
    """
    try:
        input_image = image_utils.read_imstack(
            input_fnames=input_fnames,
            flat_field_fname=flat_field_fname,
            hist_clip_limits=hist_clip_limits,
            is_mask=is_mask,
            normalize_im=normalize_im,
        )
        save_dict = {
            'time_idx': time_idx,
            'channel_idx': channel_idx,
            'pos_idx': pos_idx,
            'slice_idx': slice_idx,
            'save_dir': save_dir,
            'image_format': image_format,
            'int2str_len': int2str_len,
        }
        tile_meta_df = tile_utils.tile_image(
            input_image=input_image,
            tile_size=tile_size,
            step_size=step_size,
            min_fraction=min_fraction,
            save_dict=save_dict,
        )
    except Exception as e:
        err_msg = 'error in t_{}, c_{}, pos_{}, sl_{}: {}'.format(
            time_idx, channel_idx, pos_idx, slice_idx, str(e),
        )
        # TODO(Anitha): write to log instead
        print(err_msg)
        raise e
    return tile_meta_df
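
# Usage sketch (not part of the module): one plausible way to dispatch
# tile_and_save over several positions with the standard library's
# multiprocessing.Pool. The file names, index values and tile/step sizes
# below are hypothetical placeholders; only the argument order follows the
# signature above.
from multiprocessing import Pool


def example_mp_tile_save(num_pos=4, workers=4):
    """Tile hypothetical images at positions 0..num_pos-1 in parallel."""
    fn_args = [
        (('/data/im_c001_z010_t000_p{:03d}.png'.format(pos_idx),),  # input_fnames
         None,            # flat_field_fname: no flat field correction
         None,            # hist_clip_limits: no histogram clipping
         0,               # time_idx
         1,               # channel_idx
         pos_idx,         # pos_idx
         10,              # slice_idx
         [256, 256],      # tile_size
         [128, 128],      # step_size
         0.0,             # min_fraction: keep all tiles
         'zyx',           # image_format
         '/data/tiles')   # save_dir
        for pos_idx in range(num_pos)
    ]
    with Pool(processes=workers) as pool:
        tile_meta_df_list = pool.starmap(tile_and_save, fn_args)
    return tile_meta_df_list
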
def tile_stack(self):
    """
    Tiles images in the specified channels.

    https://research.wmz.ninja/articles/2018/03/
    on-sharing-large-arrays-when-using-pythons-multiprocessing.html

    Saves a csv with columns
    ['time_idx', 'channel_idx', 'pos_idx', 'slice_idx', 'file_name']
    for all the tiles
    """
    # Get or create tiled metadata and tile indices
    prev_tiled_metadata, tile_indices = self._get_tiled_data()
    tiled_meta0 = None
    fn_args = []
    for channel_idx in self.channel_ids:
        # Find channel index position in channel_ids list
        list_idx = self.channel_ids.index(channel_idx)
        # Perform flatfield correction if flatfield dir is specified
        flat_field_im = self._get_flat_field(channel_idx=channel_idx)
        for slice_idx in self.slice_ids:
            for time_idx in self.time_ids:
                for pos_idx in self.pos_ids:
                    if tile_indices is None:
                        # Tile and save first image
                        # to get metadata and tile_indices
                        im = image_utils.preprocess_imstack(
                            frames_metadata=self.frames_metadata,
                            input_dir=self.input_dir,
                            depth=self.channel_depth[channel_idx],
                            time_idx=time_idx,
                            channel_idx=channel_idx,
                            slice_idx=slice_idx,
                            pos_idx=pos_idx,
                            flat_field_im=flat_field_im,
                            hist_clip_limits=self.hist_clip_limits,
                            normalize_im=self.normalize_channels[list_idx],
                        )
                        save_dict = {
                            'time_idx': time_idx,
                            'channel_idx': channel_idx,
                            'pos_idx': pos_idx,
                            'slice_idx': slice_idx,
                            'save_dir': self.tile_dir,
                            'image_format': self.image_format,
                            'int2str_len': self.int2str_len,
                        }
                        tiled_meta0, tile_indices = tile_utils.tile_image(
                            input_image=im,
                            tile_size=self.tile_size,
                            step_size=self.step_size,
                            return_index=True,
                            save_dict=save_dict,
                        )
                    else:
                        cur_args = self.get_crop_tile_args(
                            channel_idx,
                            time_idx,
                            slice_idx,
                            pos_idx,
                            task_type='crop',
                            tile_indices=tile_indices,
                            normalize_im=self.normalize_channels[list_idx],
                        )
                        fn_args.append(cur_args)

    tiled_meta_df_list = mp_utils.mp_crop_save(
        fn_args,
        workers=self.num_workers,
    )
    if tiled_meta0 is not None:
        tiled_meta_df_list.append(tiled_meta0)
    tiled_metadata = pd.concat(tiled_meta_df_list, ignore_index=True)
    if self.tiles_exist:
        tiled_metadata.reset_index(drop=True, inplace=True)
        prev_tiled_metadata.reset_index(drop=True, inplace=True)
        tiled_metadata = pd.concat(
            [prev_tiled_metadata, tiled_metadata],
            ignore_index=True,
        )
    # Finally, save all the metadata
    tiled_metadata = tiled_metadata.sort_values(by=['file_name'])
    tiled_metadata.to_csv(
        os.path.join(self.tile_dir, "frames_meta.csv"),
        sep=",",
    )
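
# Sketch of the reuse pattern above (outside the class): the first image is
# tiled with return_index=True, and the returned tile_indices are then used
# to crop every subsequent image at identical locations. This is a minimal
# sketch assuming crop_at_indices accepts the index tuples produced by
# tile_image; the random arrays stand in for preprocessed image stacks, and
# no files are written because no save_dict is passed.
import numpy as np


def example_reuse_tile_indices():
    """Tile one image, then crop a second image at the same indices."""
    first_im = np.random.rand(64, 64)
    other_im = np.random.rand(64, 64)
    # Tile once to obtain both tiles and their start/end indices
    _, tile_indices = tile_utils.tile_image(
        input_image=first_im,
        tile_size=[16, 16],
        step_size=[16, 16],
        return_index=True,
    )
    # Crop the next image at the same indices
    cropped_tiles = tile_utils.crop_at_indices(
        input_image=other_im,
        crop_indices=tile_indices,
        isotropic=False,
    )
    return cropped_tiles
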
def tile_mask_stack(self,
                    input_mask_dir,
                    tile_index_fname=None,
                    tile_size=None,
                    step_size=None,
                    isotropic=False):
    """
    Tiles a stack of masks

    :param str/list input_mask_dir: input_mask_dir with full path
    :param str tile_index_fname: fname with full path for the pickle file
     which contains a dict with fname as keys and crop indices as values.
     Needed when tiling using a volume fraction constraint
     (i.e. check for minimum foreground in tile)
    :param list/tuple tile_size: as named
    :param list/tuple step_size: as named
    :param bool isotropic: indicator for making the tiles have isotropic
     shape (only for 3D)
    """
    if tile_index_fname:
        msg = 'tile index file does not exist'
        assert (os.path.exists(tile_index_fname) and
                os.path.isfile(tile_index_fname)), msg
        with open(tile_index_fname, 'rb') as f:
            crop_indices_dict = pickle.load(f)
    else:
        msg = 'tile_size and step_size are needed'
        assert tile_size is not None and step_size is not None, msg
        msg = 'tile and step sizes should have same length'
        assert len(tile_size) == len(step_size), msg

    if not isinstance(input_mask_dir, list):
        input_mask_dir = [input_mask_dir]

    for ch_idx, cur_dir in enumerate(input_mask_dir):
        # Split dir name and remove last / if present
        sep_strs = cur_dir.split(os.sep)
        if len(sep_strs[-1]) == 0:
            sep_strs.pop(-1)
        cur_tp = int(sep_strs[-2].split('_')[-1])
        cur_ch = int(sep_strs[-1].split('_')[-1])
        # Read all mask npy files
        mask_fnames = glob.glob(os.path.join(cur_dir, '*.npy'))
        # Sort file names, the assumption is that the csv is sorted
        mask_fnames = natsort.natsorted(mask_fnames)
        cropped_meta = []
        output_dir = os.path.join(
            self.output_dir,
            'timepoint_{}'.format(cur_tp),
            'channel_{}'.format(self.output_channel_id[ch_idx]),
        )
        os.makedirs(output_dir, exist_ok=True)
        for cur_mask_fname in mask_fnames:
            _, fname = os.path.split(cur_mask_fname)
            sample_num = int(fname.split('_')[1][1:])
            cur_mask = np.load(cur_mask_fname)
            if tile_index_fname:
                cropped_image_data = tile_utils.crop_at_indices(
                    input_image=cur_mask,
                    crop_indices=crop_indices_dict[fname],
                    isotropic=isotropic,
                )
            else:
                cropped_image_data = tile_utils.tile_image(
                    input_image=cur_mask,
                    tile_size=tile_size,
                    step_size=step_size,
                    isotropic=isotropic,
                )
            # Save the stack
            for id_img_tuple in cropped_image_data:
                rcsl_idx = id_img_tuple[0]
                img_fname = 'n{}_{}.npy'.format(sample_num, rcsl_idx)
                cropped_img = id_img_tuple[1]
                cropped_img_fname = os.path.join(output_dir, img_fname)
                np.save(
                    cropped_img_fname,
                    cropped_img,
                    allow_pickle=True,
                    fix_imports=True,
                )
                cropped_meta.append(
                    (cur_tp, cur_ch, sample_num, self.focal_plane_idx,
                     cropped_img_fname)
                )
        aux_utils.save_tile_meta(
            cropped_meta,
            cur_channel=cur_ch,
            tiled_dir=self.output_dir,
        )
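
# Sketch of the directory-name convention tile_mask_stack relies on: each mask
# dir is expected to end in .../timepoint_<t>/channel_<c>, and the timepoint
# and channel indices are recovered from those last two path components. The
# path below is a hypothetical example.
import os

example_mask_dir = os.path.join('/data', 'masks', 'timepoint_0', 'channel_3')
example_sep_strs = example_mask_dir.split(os.sep)
example_tp = int(example_sep_strs[-2].split('_')[-1])   # -> 0
example_ch = int(example_sep_strs[-1].split('_')[-1])   # -> 3
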
def predict_3d(self, iteration_rows):
    """
    Run prediction in 3D on images with 3D shape.

    :param list iteration_rows: Inference meta rows
    :return np.array pred_stack: Prediction
    :return np.array target_stack: Target
    :return np.array/list mask_stack: Mask for metrics
    """
    crop_indices = None
    assert len(iteration_rows) == 1, \
        'more than one matching row found for position ' \
        '{}'.format(iteration_rows.pos_idx)
    cur_input, cur_target = \
        self.dataset_inst.__getitem__(iteration_rows[0])
    # If crop shape is defined in images dict
    if self.crop_shape is not None:
        cur_input = image_utils.center_crop_to_shape(
            cur_input,
            self.crop_shape,
        )
        cur_target = image_utils.center_crop_to_shape(
            cur_target,
            self.crop_shape,
        )
    inf_shape = None
    if self.tile_option == 'infer_on_center':
        inf_shape = self.params_3d['inf_shape']
        center_block = image_utils.center_crop_to_shape(cur_input, inf_shape)
        cur_target = image_utils.center_crop_to_shape(cur_target, inf_shape)
        pred_image = inference.predict_large_image(
            model=self.model,
            input_image=center_block,
        )
    elif self.tile_option == 'tile_z':
        pred_block_list, start_end_idx = \
            self._predict_sub_block_z(cur_input)
        pred_image = self.stitch_inst.stitch_predictions(
            np.squeeze(cur_input).shape,
            pred_block_list,
            start_end_idx,
        )
    elif self.tile_option == 'tile_xyz':
        step_size = (np.array(self.params_3d['tile_shape']) -
                     np.array(self.num_overlap))
        if crop_indices is None:
            # TODO tile_image works for 2D/3D imgs, modify for multichannel
            _, crop_indices = tile_utils.tile_image(
                input_image=np.squeeze(cur_input),
                tile_size=self.params_3d['tile_shape'],
                step_size=step_size,
                return_index=True,
            )
        pred_block_list = self._predict_sub_block_xyz(
            cur_input,
            crop_indices,
        )
        pred_image = self.stitch_inst.stitch_predictions(
            np.squeeze(cur_input).shape,
            pred_block_list,
            crop_indices,
        )
    pred_image = np.squeeze(pred_image).astype(np.float32)
    target_image = np.squeeze(cur_target).astype(np.float32)
    # Save prediction
    cur_row = self.iteration_meta.iloc[iteration_rows[0]]
    self.save_pred_image(
        predicted_image=pred_image,
        time_idx=cur_row['time_idx'],
        target_channel_idx=cur_row['channel_idx'],
        pos_idx=cur_row['pos_idx'],
        slice_idx=cur_row['slice_idx'],
    )
    # 3D uses zyx, estimate metrics expects xyz
    if self.image_format == 'zyx':
        pred_image = np.transpose(pred_image, [1, 2, 0])
        target_image = np.transpose(target_image, [1, 2, 0])
    # Get mask
    mask_image = None
    if self.masks_dict is not None:
        mask_image = self.get_mask(cur_row, transpose=True)
        if inf_shape is not None:
            mask_image = image_utils.center_crop_to_shape(
                mask_image,
                inf_shape,
            )
        if self.image_format == 'zyx':
            mask_image = np.transpose(mask_image, [1, 2, 0])
    return pred_image, target_image, mask_image
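
# Sketch of the 'tile_xyz' step-size arithmetic and the final axis reorder
# above, with hypothetical shapes: a tile shape of [64, 64, 64] and an overlap
# of [16, 16, 16] give a stride of [48, 48, 48] between sub-blocks, and
# np.transpose(..., [1, 2, 0]) moves the leading z axis of a zyx prediction to
# the end before metrics are estimated.
import numpy as np

example_tile_shape = np.array([64, 64, 64])
example_num_overlap = np.array([16, 16, 16])
example_step_size = example_tile_shape - example_num_overlap   # -> [48, 48, 48]

example_pred_zyx = np.zeros((8, 32, 64), dtype=np.float32)      # (z, y, x)
example_pred_last = np.transpose(example_pred_zyx, [1, 2, 0])   # shape (32, 64, 8)
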
def predict_on_full_image(self,
                          image_meta,
                          test_samples,
                          focal_plane_idx=None,
                          depth=None,
                          per_tile_overlap=1 / 8,
                          flat_field_correct=False,
                          base_image_dir=None,
                          place_operation='mean'):
    """Tile and run inference on tiles and assemble the full image

    :param pd.DataFrame image_meta: DF with individual image info:
     'timepoint', 'channel_num', 'sample_num', 'slice_num', 'fname',
     'size_x_microns', 'size_y_microns', 'size_z_microns'
    :param list test_samples: list of sample numbers to be used in the
     test set
    :param int focal_plane_idx: focal plane to be used
    :param int depth: if 3D - num of slices used for tiling
    :param float per_tile_overlap: fraction of overlap between successive
     tiles
    :param bool flat_field_correct: indicator for applying flat field
     correction
    :param str base_image_dir: base directory where images are stored
    :param str place_operation: in ['mean', 'max']. mean for regression
     tasks, max for segmentation tasks
    """
    assert place_operation in ['mean', 'max'], \
        'only mean and max are allowed: %s' % place_operation
    if 'timepoints' not in self.config['dataset']:
        timepoint_ids = -1
    else:
        timepoint_ids = self.config['dataset']['timepoints']
    ip_channel_ids = self.config['dataset']['input_channels']
    op_channel_ids = self.config['dataset']['target_channels']
    tp_channel_ids = aux_utils.validate_metadata_indices(
        image_meta,
        time_ids=timepoint_ids,
    )
    tp_idx = tp_channel_ids['timepoints']
    tile_size = [
        self.config['network']['height'],
        self.config['network']['width'],
    ]
    if depth is not None:
        assert 'depth' in self.config['network']
        tile_size.insert(0, depth)
    step_size = (1 - per_tile_overlap) * np.array(tile_size)
    step_size = step_size.astype('int')
    step_size[step_size < 1] = 1
    overlap_size = tile_size - step_size
    batch_size = self.config['trainer']['batch_size']

    if flat_field_correct:
        assert base_image_dir is not None
        ff_dir = os.path.join(base_image_dir, 'flat_field_images')
    else:
        ff_dir = None

    for tp in tp_idx:
        # Get the meta for all images in tp_dir and channel_dir
        row_idx_ip0 = aux_utils.get_row_idx(
            image_meta, tp, ip_channel_ids[0],
            slice_idx=focal_plane_idx,
        )
        ip0_meta = image_meta[row_idx_ip0]
        # Get rows corresponding to test_samples from this DF
        test_row_ip0 = ip0_meta.loc[
            ip0_meta['sample_num'].isin(test_samples)
        ]
        test_ip0_fnames = test_row_ip0['fname'].tolist()
        test_image_fnames = [
            fname.split(os.sep)[-1] for fname in test_ip0_fnames
        ]
        tp_dir = str(os.sep).join(test_ip0_fnames[0].split(os.sep)[:-2])
        test_image = np.load(test_ip0_fnames[0])
        _, crop_indices = tile_utils.tile_image(
            test_image,
            tile_size,
            step_size,
            return_index=True,
        )
        pred_dir = os.path.join(
            self.config['trainer']['model_dir'],
            'predicted_images',
            'tp_{}'.format(tp),
        )
        for fname in test_image_fnames:
            target_image = self._read_one(
                tp_dir, op_channel_ids, fname, ff_dir,
            )
            input_image = self._read_one(
                tp_dir, ip_channel_ids, fname, ff_dir,
            )
            pred_tiles = self._pred_image(
                input_image,
                crop_indices,
                batch_size,
            )
            pred_image = self._stitch_image(
                pred_tiles, crop_indices, input_image.shape, batch_size,
                tile_size, overlap_size, place_operation,
            )
            pred_fname = '{}.npy'.format(fname.split('.')[0])
            for idx, op_ch in enumerate(op_channel_ids):
                op_dir = os.path.join(pred_dir, 'channel_{}'.format(op_ch))
                if not os.path.exists(op_dir):
                    os.makedirs(op_dir)
                np.save(os.path.join(op_dir, pred_fname), pred_image[idx])
                save_predicted_images(
                    [input_image], [target_image], [pred_image],
                    os.path.join(op_dir, 'collage'),
                    output_fname=fname.split('.')[0],
                )
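
# Worked example of the overlap arithmetic above, assuming a 2D network input
# of 256 x 256 and the default per_tile_overlap of 1/8: the stride between
# tiles is (1 - 1/8) * 256 = 224 pixels, so adjacent tiles overlap by 32
# pixels along each axis.
import numpy as np

example_tile_size = [256, 256]
example_overlap_frac = 1 / 8
example_step_size = ((1 - example_overlap_frac) *
                     np.array(example_tile_size)).astype('int')
example_step_size[example_step_size < 1] = 1
example_overlap_size = example_tile_size - example_step_size   # -> [32, 32]
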
def test_tile_image(self):
    """Test tile_image"""
    input_image = self.sph[:, :, 3:6]
    tile_size = [16, 16]
    step_size = [8, 8]
    # Returns a list of tuples (img_id, tile)
    tiled_image_list = tile_utils.tile_image(
        input_image,
        tile_size=tile_size,
        step_size=step_size,
    )
    nose.tools.assert_equal(len(tiled_image_list), 9)
    c = 0
    for row in range(0, 17, 8):
        for col in range(0, 17, 8):
            id_str = 'r{}-{}_c{}-{}_sl{}-{}'.format(
                row, row + tile_size[0], col, col + tile_size[1], 0, 3)
            nose.tools.assert_equal(id_str, tiled_image_list[c][0])
            tile = input_image[row:row + tile_size[0],
                               col:col + tile_size[1], ...]
            numpy.testing.assert_array_equal(tile, tiled_image_list[c][1])
            c += 1
    # Returns (tuple_list, cropping_index) when return_index=True
    _, tile_index = tile_utils.tile_image(
        input_image,
        tile_size=tile_size,
        step_size=step_size,
        return_index=True,
    )
    exp_tile_index = [(0, 16, 0, 16), (0, 16, 8, 24), (0, 16, 16, 32),
                      (8, 24, 0, 16), (8, 24, 8, 24), (8, 24, 16, 32),
                      (16, 32, 0, 16), (16, 32, 8, 24), (16, 32, 16, 32)]
    numpy.testing.assert_equal(exp_tile_index, tile_index)
    # Save tiles in place and return meta_df
    tile_dir = os.path.join(self.temp_path, 'tile_dir')
    os.makedirs(tile_dir, exist_ok=True)
    meta_dir = os.path.join(tile_dir, 'meta_dir')
    os.makedirs(meta_dir, exist_ok=True)
    save_dict = {
        'time_idx': self.time_idx,
        'channel_idx': self.channel_idx,
        'slice_idx': 4,
        'pos_idx': self.pos_idx,
        'image_format': 'zyx',
        'int2str_len': 3,
        'save_dir': tile_dir,
    }
    tile_meta_df = tile_utils.tile_image(
        input_image,
        tile_size=tile_size,
        step_size=step_size,
        save_dict=save_dict,
    )
    tile_meta = []
    for row in range(0, 17, 8):
        for col in range(0, 17, 8):
            id_str = 'r{}-{}_c{}-{}_sl{}-{}'.format(
                row, row + tile_size[0], col, col + tile_size[1], 0, 3)
            cur_fname = aux_utils.get_im_name(
                time_idx=self.time_idx,
                channel_idx=self.channel_idx,
                slice_idx=4,
                pos_idx=self.pos_idx,
                int2str_len=3,
                extra_field=id_str,
                ext='.npy',
            )
            cur_path = os.path.join(tile_dir, cur_fname)
            nose.tools.assert_equal(os.path.exists(cur_path), True)
            cur_meta = {
                'channel_idx': self.channel_idx,
                'slice_idx': 4,
                'time_idx': self.time_idx,
                'file_name': cur_fname,
                'pos_idx': self.pos_idx,
                'row_start': row,
                'col_start': col,
            }
            tile_meta.append(cur_meta)
    exp_tile_meta_df = pd.DataFrame.from_dict(tile_meta)
    exp_tile_meta_df = exp_tile_meta_df.sort_values(by=['file_name'])
    pd.testing.assert_frame_equal(tile_meta_df, exp_tile_meta_df)
    # Use mask and min_fraction to select tiles to retain
    input_image_bool = input_image > 128
    _, tile_index = tile_utils.tile_image(
        input_image_bool,
        tile_size=tile_size,
        step_size=step_size,
        min_fraction=0.3,
        return_index=True,
    )
    exp_tile_index = [(0, 16, 8, 24), (8, 24, 0, 16), (8, 24, 8, 24),
                      (8, 24, 16, 32), (16, 32, 8, 24)]
    numpy.testing.assert_array_equal(tile_index, exp_tile_index)
    # Tile in 3D
    input_image = self.sph
    tile_size = [16, 16, 6]
    step_size = [8, 8, 4]
    # Returns a list of tuples (img_id, tile)
    tiled_image_list = tile_utils.tile_image(
        input_image,
        tile_size=tile_size,
        step_size=step_size,
    )
    nose.tools.assert_equal(len(tiled_image_list), 18)
    c = 0
    for row in range(0, 17, 8):
        for col in range(0, 17, 8):
            for sl in range(0, 8, 6):
                if sl == 0:
                    sl_start_end = [0, 6]
                else:
                    sl_start_end = [2, 8]
                id_str = 'r{}-{}_c{}-{}_sl{}-{}'.format(
                    row, row + tile_size[0], col, col + tile_size[1],
                    sl_start_end[0], sl_start_end[1])
                nose.tools.assert_equal(id_str, tiled_image_list[c][0])
                tile = input_image[row:row + tile_size[0],
                                   col:col + tile_size[1],
                                   sl_start_end[0]:sl_start_end[1]]
                numpy.testing.assert_array_equal(tile, tiled_image_list[c][1])
                c += 1
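
# Sketch of where the expected index grids in the test above come from: with a
# 32 x 32 image, tile size [16, 16] and step size [8, 8], the row and column
# starts are 0, 8 and 16, giving 3 x 3 = 9 tiles; the 3D case adds a depth-8
# axis with tile depth 6 and step 4, whose last window is shifted back so it
# still fits (slices 0-6 and 2-8), giving 9 x 2 = 18 tiles. The helper below
# is an illustration, not the library implementation.
def example_tile_starts(im_len, tile_len, step_len):
    """Return tile start indices along one axis, clamping the last tile."""
    starts = []
    for start in range(0, im_len - tile_len + step_len, step_len):
        starts.append(min(start, im_len - tile_len))
    return sorted(set(starts))


assert example_tile_starts(32, 16, 8) == [0, 8, 16]   # row/col starts
assert example_tile_starts(8, 6, 4) == [0, 2]         # slice starts in 3D
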