コード例 #1
0
    def test_read_imstack(self):
        """Check read_imstack on a list of 2D frames and on 3D volumes."""

        # Full paths of the first three frames listed in the metadata
        frame_paths = [
            os.path.join(self.temp_path, name)
            for name in self.frames_meta['file_name'][:3]
        ]
        # Non-boolean images: with default arguments the stack is expected
        # to match the z-scored reference volume
        stack = image_utils.read_imstack(frame_paths)
        expected = normalize.zscore(self.sph[:, :, :3])
        np.testing.assert_equal(stack.shape, (32, 32, 3))
        np.testing.assert_array_equal(expected[:, :, :3], stack)

        # A single 3D image keeps its own shape
        volume = image_utils.read_imstack([self.sph_fname])
        np.testing.assert_equal(volume.shape, (32, 32, 8))

        # Several 3D images are stacked along a new last axis
        volume_pair = image_utils.read_imstack(
            (self.sph_fname, self.sph_fname),
        )
        np.testing.assert_equal(volume_pair.shape, (32, 32, 8, 2))
コード例 #2
0
def compute_metrics(model_dir,
                    image_dir,
                    metrics_list,
                    orientations_list,
                    test_data=True):
    """
    Compute specified metrics for given orientations for predictions, which
    are assumed to be stored in model_dir/predictions. Targets are stored in
    image_dir.
    Writes metrics csv files for each orientation in model_dir/predictions.

    :param str model_dir: Assumed to contain config.yml, test_metadata.csv,
        split_samples.json and subdirectory predictions/
    :param str image_dir: Directory containing target images with frames_meta.csv
    :param list metrics_list: See inference/evaluation_metrics.py for options.
        A single str is treated as a one-element list
    :param list orientations_list: Any subset of {xy, xz, yz, xyz}
        (see evaluation_metrics). A single str is treated as a one-element list
    :param bool test_data: Uses test indices in split_samples.json,
        otherwise all indices. If split_samples.json is missing, all indices
        found in frames_meta.csv are used instead
    """
    # Load config file
    config_name = os.path.join(model_dir, 'config.yml')
    with open(config_name, 'r') as f:
        config = yaml.safe_load(f)
    # Load frames metadata and determine indices
    frames_meta = pd.read_csv(os.path.join(image_dir, 'frames_meta.csv'))

    if isinstance(metrics_list, str):
        metrics_list = [metrics_list]
    metrics_inst = metrics.MetricsEstimator(metrics_list=metrics_list)

    split_idx_name = config['dataset']['split_by_column']
    test_ids = None
    if test_data:
        idx_fname = os.path.join(model_dir, 'split_samples.json')
        try:
            test_ids = aux_utils.read_json(idx_fname)['test']
        except FileNotFoundError:
            print("No split_samples file. Will predict all images in dir.")
    if test_ids is None:
        # Either all indices were requested, or the split file is missing and
        # we fall back to every index present in the frames metadata
        test_ids = np.unique(frames_meta[split_idx_name])

    # Find other indices to iterate over than split index name
    # E.g. if split is position, we also need to iterate over time and slice
    test_meta = pd.read_csv(os.path.join(model_dir, 'test_metadata.csv'))
    metadata_ids = {split_idx_name: test_ids}
    for idx_name in ('slice_idx', 'pos_idx', 'time_idx'):
        if idx_name != split_idx_name:
            metadata_ids[idx_name] = np.unique(test_meta[idx_name])

    # Directory holding the predicted images; metrics csv files go here too
    pred_dir = os.path.join(model_dir, 'predictions')

    target_channel = config['dataset']['target_channels'][0]

    # Network depth defaults to 1 when not given in the config; a depth of 1
    # is treated as 2D below (singleton dimensions are squeezed)
    depth = 1
    if 'depth' in config['network']:
        depth = config['network']['depth']

    # Get channel name and extension for predictions from one existing file
    existing_preds = [f for f in os.listdir(pred_dir) if f.startswith('im_')]
    meta_row = aux_utils.parse_idx_from_name(existing_preds[0])
    pred_channel = meta_row['channel_idx']
    _, ext = os.path.splitext(existing_preds[0])

    if isinstance(orientations_list, str):
        orientations_list = [orientations_list]
    available_orientations = {'xy', 'xz', 'yz', 'xyz'}
    assert set(orientations_list).issubset(available_orientations), \
        "Orientations must be subset of {}".format(available_orientations)

    fn_mapping = {
        'xy': metrics_inst.estimate_xy_metrics,
        'xz': metrics_inst.estimate_xz_metrics,
        'yz': metrics_inst.estimate_yz_metrics,
        'xyz': metrics_inst.estimate_xyz_metrics,
    }
    metrics_mapping = {
        'xy': metrics_inst.get_metrics_xy,
        'xz': metrics_inst.get_metrics_xz,
        'yz': metrics_inst.get_metrics_yz,
        'xyz': metrics_inst.get_metrics_xyz,
    }
    # Per-orientation results, concatenated once at the end.
    # DataFrame.append was removed in pandas 2.0 and copies on every call.
    results_mapping = {
        orientation: [] for orientation in available_orientations
    }

    # Iterate over all indices for test data
    for time_idx in metadata_ids['time_idx']:
        for pos_idx in metadata_ids['pos_idx']:
            target_fnames = []
            pred_fnames = []
            for slice_idx in metadata_ids['slice_idx']:
                im_idx = aux_utils.get_meta_idx(
                    frames_metadata=frames_meta,
                    time_idx=time_idx,
                    channel_idx=target_channel,
                    slice_idx=slice_idx,
                    pos_idx=pos_idx,
                )
                target_fname = os.path.join(
                    image_dir,
                    frames_meta.loc[im_idx, 'file_name'],
                )
                target_fnames.append(target_fname)
                pred_fname = aux_utils.get_im_name(
                    time_idx=time_idx,
                    channel_idx=pred_channel,
                    slice_idx=slice_idx,
                    pos_idx=pos_idx,
                    ext=ext,
                )
                pred_fname = os.path.join(pred_dir, pred_fname)
                pred_fnames.append(pred_fname)

            target_stack = image_utils.read_imstack(
                input_fnames=tuple(target_fnames),
            )
            # Predictions are read without normalization
            pred_stack = image_utils.read_imstack(
                input_fnames=tuple(pred_fnames),
                normalize_im=False,
            )

            if depth == 1:
                # Remove singular z dimension for 2D image
                target_stack = np.squeeze(target_stack)
                pred_stack = np.squeeze(pred_stack)
            if target_stack.dtype == np.float64:
                target_stack = target_stack.astype(np.float32)
            pred_name = "t{}_p{}".format(time_idx, pos_idx)
            for orientation in orientations_list:
                metric_fn = fn_mapping[orientation]
                metric_fn(
                    target=target_stack,
                    prediction=pred_stack,
                    pred_name=pred_name,
                )
                result = metrics_mapping[orientation]()
                if not isinstance(result, pd.DataFrame):
                    # DataFrame.append also accepted a single dict/Series
                    # row; wrap it so pd.concat can handle it the same way
                    result = pd.DataFrame([result])
                results_mapping[orientation].append(result)

    # Save one csv per requested orientation (empty if nothing was computed)
    for orientation in orientations_list:
        collected = results_mapping[orientation]
        if collected:
            metrics_df = pd.concat(collected, ignore_index=True)
        else:
            # pd.concat raises on an empty list
            metrics_df = pd.DataFrame()
        df_name = 'metrics_{}.csv'.format(orientation)
        metrics_name = os.path.join(pred_dir, df_name)
        metrics_df.to_csv(metrics_name, sep=",", index=False)
コード例 #3
0
def create_save_mask(input_fnames,
                     flat_field_fname,
                     str_elem_radius,
                     mask_dir,
                     mask_channel_idx,
                     time_idx,
                     pos_idx,
                     slice_idx,
                     int2str_len,
                     mask_type,
                     mask_ext,
                     normalize_im=False):

    """
    Create and save mask.
    When >1 channel are used to generate the mask, mask of each channel is
    generated then combined (logical OR) into one mask.

    :param tuple input_fnames: tuple of input fnames with full path
    :param str flat_field_fname: fname of flat field image
    :param int str_elem_radius: size of structuring element used for binary
     opening. str_elem: disk or ball
    :param str mask_dir: dir to save masks
    :param int mask_channel_idx: channel number of mask
    :param int time_idx: time points to use for generating mask
    :param int pos_idx: generate masks for given position / sample ids
    :param int slice_idx: generate masks for given slice ids
    :param int int2str_len: Length of str when converting ints
    :param str mask_type: one of 'otsu', 'unimodal' or
     'borders_weight_loss_map'
    :param str mask_ext: '.npy' or '.png'. Save the mask as uint8 PNG or
     NPY files for otsu, unimodal masks, recommended to save as npy
     float64 for borders_weight_loss_map masks to avoid loss due to scaling it
     to uint8.
    :param bool normalize_im: indicator to normalize image based on z-score or not
    :return dict cur_meta for each mask
    :raises ValueError: if mask_ext or mask_type is not a supported value
    """
    # Fail fast on unsupported options, before any image is read.
    if mask_ext not in ('.npy', '.png'):
        raise ValueError("mask_ext can be '.npy' or '.png', not {}".format(mask_ext))
    # Without this check an unknown mask_type would leave `mask` unbound
    # inside the loop below.
    if mask_type not in ('otsu', 'unimodal', 'borders_weight_loss_map'):
        raise ValueError(
            "mask_type can be 'otsu', 'unimodal' or "
            "'borders_weight_loss_map', not {}".format(mask_type)
        )

    im_stack = image_utils.read_imstack(
        input_fnames,
        flat_field_fname,
        normalize_im=normalize_im,
    )
    # One mask per image along the last axis of the stack
    masks = []
    for idx in range(im_stack.shape[-1]):
        im = im_stack[..., idx]
        if mask_type == 'otsu':
            mask = mask_utils.create_otsu_mask(im.astype('float32'), str_elem_radius)
        elif mask_type == 'unimodal':
            mask = mask_utils.create_unimodal_mask(im.astype('float32'), str_elem_radius)
        else:
            # 'borders_weight_loss_map'
            mask = mask_utils.get_unet_border_weight_map(im)
        masks.append(mask)
    # Border weight map mask is a float mask not binary like otsu or unimodal,
    # so keep it as is (assumes only one image in stack)
    if mask_type == 'borders_weight_loss_map':
        mask = masks[0]
    else:
        masks = np.stack(masks, axis=-1)
        mask = np.any(masks, axis=-1)

    # Create mask name for given slice, time and position
    file_name = aux_utils.get_im_name(
        time_idx=time_idx,
        channel_idx=mask_channel_idx,
        slice_idx=slice_idx,
        pos_idx=pos_idx,
        int2str_len=int2str_len,
        ext=mask_ext,
    )
    if mask_ext == '.npy':
        # Save mask for given channels
        np.save(os.path.join(mask_dir, file_name),
                mask,
                allow_pickle=True,
                fix_imports=True)
    else:
        # '.png': convert binary masks to uint8
        # Border weight map mask is a float mask not binary like otsu or unimodal,
        # so keep it as is
        if mask_type == 'borders_weight_loss_map':
            assert im_stack.shape[-1] == 1
            # Note: Border weight map mask should only be generated from one binary image
        else:
            mask = mask.astype(np.uint8) * np.iinfo(np.uint8).max
        cv2.imwrite(os.path.join(mask_dir, file_name), mask)
    cur_meta = {'channel_idx': mask_channel_idx,
                'slice_idx': slice_idx,
                'time_idx': time_idx,
                'pos_idx': pos_idx,
                'file_name': file_name}
    return cur_meta
コード例 #4
0
def crop_at_indices_save(input_fnames,
                         flat_field_fname,
                         hist_clip_limits,
                         time_idx,
                         channel_idx,
                         pos_idx,
                         slice_idx,
                         crop_indices,
                         image_format,
                         save_dir,
                         int2str_len=3,
                         is_mask=False,
                         tile_3d=False,
                         normalize_im=False):
    """Crop image into tiles at given indices and save

    Any exception raised while reading or cropping is printed together with
    the time/channel/position/slice indices and then re-raised unchanged.

    :param tuple input_fnames: tuple of input fnames with full path
    :param str flat_field_fname: fname of flat field image
    :param tuple hist_clip_limits: limits for histogram clipping
    :param int time_idx: time point of input image
    :param int channel_idx: channel idx of input image
    :param int pos_idx: sample idx of input image
    :param int slice_idx: slice idx of input image
    :param tuple crop_indices: tuple of indices for cropping
    :param str image_format: zyx or xyz
    :param str save_dir: output dir to save tiles
    :param int int2str_len: len of indices for creating file names
    :param bool is_mask: Indicates if files are masks
    :param bool tile_3d: indicator for tiling in 3D
    :param bool normalize_im: Indicates if normalizing using z score is needed
    :return: pd.DataFrame from a list of dicts with metadata
    """
    # Naming / saving information handed to the cropping helper
    save_dict = {'time_idx': time_idx,
                 'channel_idx': channel_idx,
                 'pos_idx': pos_idx,
                 'slice_idx': slice_idx,
                 'save_dir': save_dir,
                 'image_format': image_format,
                 'int2str_len': int2str_len}
    try:
        input_image = image_utils.read_imstack(
            input_fnames=input_fnames,
            flat_field_fname=flat_field_fname,
            hist_clip_limits=hist_clip_limits,
            is_mask=is_mask,
            normalize_im=normalize_im
        )
        tile_meta_df = tile_utils.crop_at_indices(
            input_image=input_image,
            crop_indices=crop_indices,
            save_dict=save_dict,
            tile_3d=tile_3d,
        )
    except Exception as e:
        # Separate the index prefix from the exception text so the message
        # stays readable
        err_msg = 'error in t_{}, c_{}, pos_{}, sl_{}: {}'.format(
            time_idx, channel_idx, pos_idx, slice_idx, e
        )
        # TODO(Anitha) write to log instead
        print(err_msg)
        # Bare raise re-raises the active exception as is
        raise

    return tile_meta_df
コード例 #5
0
def tile_and_save(input_fnames,
                  flat_field_fname,
                  hist_clip_limits,
                  time_idx,
                  channel_idx,
                  pos_idx,
                  slice_idx,
                  tile_size,
                  step_size,
                  min_fraction,
                  image_format,
                  save_dir,
                  int2str_len=3,
                  is_mask=False,
                  normalize_im=False):
    """Tile image using the given tile and step sizes and save

    Any exception raised while reading or tiling is printed together with
    the time/channel/position/slice indices and then re-raised unchanged.

    :param tuple input_fnames: tuple of input fnames with full path
    :param str flat_field_fname: fname of flat field image
    :param tuple hist_clip_limits: limits for histogram clipping
    :param int time_idx: time point of input image
    :param int channel_idx: channel idx of input image
    :param int pos_idx: sample idx of input image
    :param int slice_idx: slice idx of input image
    :param list tile_size: size of tile along row, col (& slices)
    :param list step_size: step size along row, col (& slices)
    :param float min_fraction: min foreground volume fraction for keep tile
    :param str image_format: zyx / xyz
    :param str save_dir: output dir to save tiles
    :param int int2str_len: len of indices for creating file names
    :param bool is_mask: Indicates if files are masks
    :param bool normalize_im: Indicates if normalizing using z score is needed
    :return: pd.DataFrame from a list of dicts with metadata
    """
    # Naming / saving information handed to the tiling helper
    save_dict = {'time_idx': time_idx,
                 'channel_idx': channel_idx,
                 'pos_idx': pos_idx,
                 'slice_idx': slice_idx,
                 'save_dir': save_dir,
                 'image_format': image_format,
                 'int2str_len': int2str_len}
    try:
        input_image = image_utils.read_imstack(
            input_fnames=input_fnames,
            flat_field_fname=flat_field_fname,
            hist_clip_limits=hist_clip_limits,
            is_mask=is_mask,
            normalize_im=normalize_im
        )
        tile_meta_df = tile_utils.tile_image(
            input_image=input_image,
            tile_size=tile_size,
            step_size=step_size,
            min_fraction=min_fraction,
            save_dict=save_dict,
        )
    except Exception as e:
        # Separate the index prefix from the exception text so the message
        # stays readable
        err_msg = 'error in t_{}, c_{}, pos_{}, sl_{}: {}'.format(
            time_idx, channel_idx, pos_idx, slice_idx, e
        )
        # TODO(Anitha) write to log instead
        print(err_msg)
        # Bare raise re-raises the active exception as is
        raise
    return tile_meta_df