def test_read_imstack(self):
    """Test read_imstack for 2D frame lists and single/multiple 3D images."""
    file_names = [
        os.path.join(self.temp_path, name)
        for name in self.frames_meta['file_name'][:3]
    ]
    # Three 2D frames are read and stacked along the last axis,
    # normalized (z-scored) by default
    stack = image_utils.read_imstack(file_names)
    expected = normalize.zscore(self.sph[:, :, :3])
    np.testing.assert_equal(stack.shape, (32, 32, 3))
    np.testing.assert_array_equal(expected[:, :, :3], stack)
    # A single 3D image keeps its z dimension
    stack = image_utils.read_imstack([self.sph_fname])
    np.testing.assert_equal(stack.shape, (32, 32, 8))
    # Multiple 3D images are stacked along a new trailing axis
    stack = image_utils.read_imstack((self.sph_fname, self.sph_fname))
    np.testing.assert_equal(stack.shape, (32, 32, 8, 2))
def compute_metrics(model_dir,
                    image_dir,
                    metrics_list,
                    orientations_list,
                    test_data=True):
    """
    Compute specified metrics for given orientations for predictions,
    which are assumed to be stored in model_dir/predictions. Targets
    are stored in image_dir.
    Writes metrics csv files for each orientation in model_dir/predictions.

    :param str model_dir: Assumed to contain config, split_samples.json and
        subdirectory predictions/
    :param str image_dir: Directory containing target images with
        frames_meta.csv
    :param list metrics_list: See inference/evaluation_metrics.py for options
    :param list orientations_list: Any subset of {xy, xz, yz, xyz}
        (see evaluation_metrics)
    :param bool test_data: Uses test indices in split_samples.json,
        otherwise all indices
    """
    # Load config file
    config_name = os.path.join(model_dir, 'config.yml')
    with open(config_name, 'r') as f:
        config = yaml.safe_load(f)
    # Load frames metadata and determine indices
    frames_meta = pd.read_csv(os.path.join(image_dir, 'frames_meta.csv'))
    if isinstance(metrics_list, str):
        metrics_list = [metrics_list]
    metrics_inst = metrics.MetricsEstimator(metrics_list=metrics_list)
    split_idx_name = config['dataset']['split_by_column']
    if test_data:
        idx_fname = os.path.join(model_dir, 'split_samples.json')
        try:
            split_samples = aux_utils.read_json(idx_fname)
            test_ids = split_samples['test']
        except FileNotFoundError:
            print("No split_samples file. Will predict all images in dir.")
            # Fall back to all indices so test_ids is always defined
            # (previously left unassigned, causing a NameError below)
            test_ids = np.unique(frames_meta[split_idx_name])
    else:
        test_ids = np.unique(frames_meta[split_idx_name])

    # Find other indices to iterate over than split index name
    # E.g. if split is position, we also need to iterate over time and slice
    test_meta = pd.read_csv(os.path.join(model_dir, 'test_metadata.csv'))
    metadata_ids = {split_idx_name: test_ids}
    iter_ids = ['slice_idx', 'pos_idx', 'time_idx']
    # 'idx_name' avoids shadowing the builtin 'id'
    for idx_name in iter_ids:
        if idx_name != split_idx_name:
            metadata_ids[idx_name] = np.unique(test_meta[idx_name])
    # Create image subdirectory to write predicted images
    pred_dir = os.path.join(model_dir, 'predictions')
    target_channel = config['dataset']['target_channels'][0]
    # If network depth is > 3 determine depth margins for +-z
    depth = 1
    if 'depth' in config['network']:
        depth = config['network']['depth']
    # Get channel name and extension for predictions
    pred_fnames = [f for f in os.listdir(pred_dir) if f.startswith('im_')]
    meta_row = aux_utils.parse_idx_from_name(pred_fnames[0])
    pred_channel = meta_row['channel_idx']
    _, ext = os.path.splitext(pred_fnames[0])
    if isinstance(orientations_list, str):
        orientations_list = [orientations_list]
    available_orientations = {'xy', 'xz', 'yz', 'xyz'}
    assert set(orientations_list).issubset(available_orientations), \
        "Orientations must be subset of {}".format(available_orientations)
    fn_mapping = {
        'xy': metrics_inst.estimate_xy_metrics,
        'xz': metrics_inst.estimate_xz_metrics,
        'yz': metrics_inst.estimate_yz_metrics,
        'xyz': metrics_inst.estimate_xyz_metrics,
    }
    metrics_mapping = {
        'xy': metrics_inst.get_metrics_xy,
        'xz': metrics_inst.get_metrics_xz,
        'yz': metrics_inst.get_metrics_yz,
        'xyz': metrics_inst.get_metrics_xyz,
    }
    df_mapping = {
        'xy': pd.DataFrame(),
        'xz': pd.DataFrame(),
        'yz': pd.DataFrame(),
        'xyz': pd.DataFrame(),
    }
    # Iterate over all indices for test data
    for time_idx in metadata_ids['time_idx']:
        for pos_idx in metadata_ids['pos_idx']:
            target_fnames = []
            pred_fnames = []
            for slice_idx in metadata_ids['slice_idx']:
                im_idx = aux_utils.get_meta_idx(
                    frames_metadata=frames_meta,
                    time_idx=time_idx,
                    channel_idx=target_channel,
                    slice_idx=slice_idx,
                    pos_idx=pos_idx,
                )
                target_fname = os.path.join(
                    image_dir,
                    frames_meta.loc[im_idx, 'file_name'],
                )
                target_fnames.append(target_fname)
                pred_fname = aux_utils.get_im_name(
                    time_idx=time_idx,
                    channel_idx=pred_channel,
                    slice_idx=slice_idx,
                    pos_idx=pos_idx,
                    ext=ext,
                )
                pred_fname = os.path.join(pred_dir, pred_fname)
                pred_fnames.append(pred_fname)
            target_stack = image_utils.read_imstack(
                input_fnames=tuple(target_fnames),
            )
            pred_stack = image_utils.read_imstack(
                input_fnames=tuple(pred_fnames),
                normalize_im=False,
            )
            if depth == 1:
                # Remove singular z dimension for 2D image
                target_stack = np.squeeze(target_stack)
                pred_stack = np.squeeze(pred_stack)
            if target_stack.dtype == np.float64:
                target_stack = target_stack.astype(np.float32)
            pred_name = "t{}_p{}".format(time_idx, pos_idx)
            for orientation in orientations_list:
                metric_fn = fn_mapping[orientation]
                metric_fn(
                    target=target_stack,
                    prediction=pred_stack,
                    pred_name=pred_name,
                )
                # pd.concat replaces DataFrame.append (removed in pandas 2.0);
                # assumes get_metrics_* returns a DataFrame — TODO confirm
                df_mapping[orientation] = pd.concat(
                    [df_mapping[orientation], metrics_mapping[orientation]()],
                    ignore_index=True,
                )
    # Save non-empty dataframes
    for orientation in orientations_list:
        metrics_df = df_mapping[orientation]
        df_name = 'metrics_{}.csv'.format(orientation)
        metrics_name = os.path.join(pred_dir, df_name)
        metrics_df.to_csv(metrics_name, sep=",", index=False)
def create_save_mask(input_fnames,
                     flat_field_fname,
                     str_elem_radius,
                     mask_dir,
                     mask_channel_idx,
                     time_idx,
                     pos_idx,
                     slice_idx,
                     int2str_len,
                     mask_type,
                     mask_ext,
                     normalize_im=False):
    """
    Create and save mask.
    When >1 channel are used to generate the mask, mask of each channel is
    generated then added together.

    :param tuple input_fnames: tuple of input fnames with full path
    :param str flat_field_fname: fname of flat field image
    :param int str_elem_radius: size of structuring element used for binary
        opening. str_elem: disk or ball
    :param str mask_dir: dir to save masks
    :param int mask_channel_idx: channel number of mask
    :param int time_idx: time points to use for generating mask
    :param int pos_idx: generate masks for given position / sample ids
    :param int slice_idx: generate masks for given slice ids
    :param int int2str_len: Length of str when converting ints
    :param str mask_type: thresholding type used for masking or str to map to
        masking function
    :param str mask_ext: '.npy' or '.png'. Save the mask as uint8 PNG or
        NPY files. For otsu, unimodal masks, recommended to save as npy
        float64 for borders_weight_loss_map masks to avoid loss due to
        scaling it to uint8.
    :param bool normalize_im: indicator to normalize image based on z-score
        or not
    :return dict cur_meta: metadata for the saved mask (channel, slice,
        time, position indices and file name)
    :raises ValueError: if mask_type or mask_ext is not one of the
        supported options
    """
    im_stack = image_utils.read_imstack(
        input_fnames,
        flat_field_fname,
        normalize_im=normalize_im,
    )
    masks = []
    for idx in range(im_stack.shape[-1]):
        im = im_stack[..., idx]
        if mask_type == 'otsu':
            mask = mask_utils.create_otsu_mask(
                im.astype('float32'), str_elem_radius)
        elif mask_type == 'unimodal':
            mask = mask_utils.create_unimodal_mask(
                im.astype('float32'), str_elem_radius)
        elif mask_type == 'borders_weight_loss_map':
            mask = mask_utils.get_unet_border_weight_map(im)
        else:
            # Fail loudly instead of hitting a NameError on an
            # unsupported mask type
            raise ValueError(
                "mask_type must be 'otsu', 'unimodal' or "
                "'borders_weight_loss_map', not {}".format(mask_type))
        masks += [mask]
    # Border weight map mask is a float mask not binary like otsu or unimodal,
    # so keep it as is (assumes only one image in stack)
    if mask_type == 'borders_weight_loss_map':
        mask = masks[0]
    else:
        # Combine channel masks: a pixel is foreground if it is
        # foreground in any channel
        masks = np.stack(masks, axis=-1)
        mask = np.any(masks, axis=-1)

    # Create mask name for given slice, time and position
    file_name = aux_utils.get_im_name(
        time_idx=time_idx,
        channel_idx=mask_channel_idx,
        slice_idx=slice_idx,
        pos_idx=pos_idx,
        int2str_len=int2str_len,
        ext=mask_ext,
    )
    if mask_ext == '.npy':
        # Save mask for given channels, mask is 2D
        np.save(os.path.join(mask_dir, file_name),
                mask,
                allow_pickle=True,
                fix_imports=True)
    elif mask_ext == '.png':
        # Convert mask to uint8
        # Border weight map mask is a float mask not binary like otsu or
        # unimodal, so keep it as is
        if mask_type == 'borders_weight_loss_map':
            # Border weight map should only be generated from one binary image
            assert im_stack.shape[-1] == 1
        else:
            mask = mask.astype(np.uint8) * np.iinfo(np.uint8).max
        cv2.imwrite(os.path.join(mask_dir, file_name), mask)
    else:
        raise ValueError(
            "mask_ext can be '.npy' or '.png', not {}".format(mask_ext))
    cur_meta = {'channel_idx': mask_channel_idx,
                'slice_idx': slice_idx,
                'time_idx': time_idx,
                'pos_idx': pos_idx,
                'file_name': file_name}
    return cur_meta
def crop_at_indices_save(input_fnames,
                         flat_field_fname,
                         hist_clip_limits,
                         time_idx,
                         channel_idx,
                         pos_idx,
                         slice_idx,
                         crop_indices,
                         image_format,
                         save_dir,
                         int2str_len=3,
                         is_mask=False,
                         tile_3d=False,
                         normalize_im=False):
    """Crop image into tiles at given indices and save

    :param tuple input_fnames: tuple of input fnames with full path
    :param str flat_field_fname: fname of flat field image
    :param tuple hist_clip_limits: limits for histogram clipping
    :param int time_idx: time point of input image
    :param int channel_idx: channel idx of input image
    :param int pos_idx: sample idx of input image
    :param int slice_idx: slice idx of input image
    :param tuple crop_indices: tuple of indices for cropping
    :param str image_format: zyx or xyz
    :param str save_dir: output dir to save tiles
    :param int int2str_len: len of indices for creating file names
    :param bool is_mask: Indicates if files are masks
    :param bool tile_3d: indicator for tiling in 3D
    :param bool normalize_im: Indicates if normalizing using z score is needed
    :return: pd.DataFrame from a list of dicts with metadata
    """
    try:
        input_image = image_utils.read_imstack(
            input_fnames=input_fnames,
            flat_field_fname=flat_field_fname,
            hist_clip_limits=hist_clip_limits,
            is_mask=is_mask,
            normalize_im=normalize_im
        )
        save_dict = {'time_idx': time_idx,
                     'channel_idx': channel_idx,
                     'pos_idx': pos_idx,
                     'slice_idx': slice_idx,
                     'save_dir': save_dir,
                     'image_format': image_format,
                     'int2str_len': int2str_len}
        tile_meta_df = tile_utils.crop_at_indices(
            input_image=input_image,
            crop_indices=crop_indices,
            save_dict=save_dict,
            tile_3d=tile_3d,
        )
    except Exception as e:
        err_msg = 'error in t_{}, c_{}, pos_{}, sl_{}'.format(
            time_idx, channel_idx, pos_idx, slice_idx
        )
        err_msg = err_msg + str(e)
        # TODO(Anitha) write to log instead
        print(err_msg)
        # Bare raise preserves the original traceback
        raise
    return tile_meta_df
def tile_and_save(input_fnames,
                  flat_field_fname,
                  hist_clip_limits,
                  time_idx,
                  channel_idx,
                  pos_idx,
                  slice_idx,
                  tile_size,
                  step_size,
                  min_fraction,
                  image_format,
                  save_dir,
                  int2str_len=3,
                  is_mask=False,
                  normalize_im=False):
    """Tile image with given tile and step size and save the tiles

    :param tuple input_fnames: tuple of input fnames with full path
    :param str flat_field_fname: fname of flat field image
    :param tuple hist_clip_limits: limits for histogram clipping
    :param int time_idx: time point of input image
    :param int channel_idx: channel idx of input image
    :param int pos_idx: sample idx of input image
    :param int slice_idx: slice idx of input image
    :param list tile_size: size of tile along row, col (& slices)
    :param list step_size: step size along row, col (& slices)
    :param float min_fraction: min foreground volume fraction for keep tile
    :param str image_format: zyx / xyz
    :param str save_dir: output dir to save tiles
    :param int int2str_len: len of indices for creating file names
    :param bool is_mask: Indicates if files are masks
    :param bool normalize_im: Indicates if normalizing using z score is needed
    :return: pd.DataFrame from a list of dicts with metadata
    """
    try:
        input_image = image_utils.read_imstack(
            input_fnames=input_fnames,
            flat_field_fname=flat_field_fname,
            hist_clip_limits=hist_clip_limits,
            is_mask=is_mask,
            normalize_im=normalize_im
        )
        save_dict = {'time_idx': time_idx,
                     'channel_idx': channel_idx,
                     'pos_idx': pos_idx,
                     'slice_idx': slice_idx,
                     'save_dir': save_dir,
                     'image_format': image_format,
                     'int2str_len': int2str_len}
        tile_meta_df = tile_utils.tile_image(
            input_image=input_image,
            tile_size=tile_size,
            step_size=step_size,
            min_fraction=min_fraction,
            save_dict=save_dict,
        )
    except Exception as e:
        err_msg = 'error in t_{}, c_{}, pos_{}, sl_{}'.format(
            time_idx, channel_idx, pos_idx, slice_idx
        )
        err_msg = err_msg + str(e)
        # TODO(Anitha) write to log instead
        print(err_msg)
        # Bare raise preserves the original traceback
        raise
    return tile_meta_df