Example #1
0
    def _get_tiled_data(self):
        """
        If tile directory already exists, check which channels have been
        processed and only tile new channels.

        When tiles exist, self.channel_ids is narrowed to the channels that
        have not been tiled yet, in their configured order. If
        self.normalize_channels is a list aligned with channel_ids, it is
        narrowed with the same positions so both lists stay index-aligned.

        :return dataframe tiled_meta: Metadata with previously tiled channels,
            or the frame from self._get_dataframe() if no tiles exist
        :return list of lists tile_indices: Nbr tiles x 4 indices with row
            start + stop and column start + stop indices, taken from the
            first already tiled channel; None if no tiles exist yet
        NOTE: if every configured channel has already been tiled, a message
            is printed and None is returned instead of a tuple.
        """
        if self.tiles_exist:
            tiled_meta = aux_utils.read_meta(self.tile_dir)
            # Find untiled channels
            tiled_channels = np.unique(tiled_meta['channel_idx'])
            tiled_set = set(tiled_channels)
            # Keep the configured channel order (a set difference has no
            # defined order) so per-channel settings indexed by position,
            # such as normalize_channels, still line up
            is_new = [
                channel_idx not in tiled_set
                for channel_idx in self.channel_ids
            ]
            new_channels = [
                channel_idx
                for channel_idx, new in zip(self.channel_ids, is_new)
                if new
            ]
            if len(new_channels) == 0:
                print('All channels in config have already been tiled')
                return
            normalize = getattr(self, 'normalize_channels', None)
            if isinstance(normalize, list) and len(normalize) == len(is_new):
                self.normalize_channels = [
                    norm for norm, new in zip(normalize, is_new) if new
                ]
            self.channel_ids = new_channels
            tile_indices = self._get_tile_indices(
                tiled_meta=tiled_meta,
                time_idx=self.time_ids[0],
                channel_idx=tiled_channels[0],
                pos_idx=self.pos_ids[0],
                slice_idx=self.slice_ids[0])
        else:
            tiled_meta = self._get_dataframe()
            tile_indices = None
        return tiled_meta, tile_indices
Example #2
0
    def __init__(self,
                 input_dir,
                 output_dir,
                 channel_ids,
                 slice_ids,
                 block_size=32):
        """
        Flatfield images are estimated once per channel for 2D data

        :param str input_dir: Directory with 2D image frames from dataset
        :param str output_dir: Base output directory; flatfield images go in
            its 'flat_field_images' subdirectory (created if missing)
        :param int/list channel_ids: channel ids for flat field_correction
        :param int/list slice_ids: Z slice indices for flatfield correction
        :param int block_size: Size of blocks image will be divided into
            (None falls back to 32)
        """
        self.input_dir = input_dir
        self.output_dir = output_dir
        # Create flat_field_dir as a subdirectory of output_dir
        self.flat_field_dir = os.path.join(self.output_dir,
                                           'flat_field_images')
        os.makedirs(self.flat_field_dir, exist_ok=True)
        self.frames_metadata = aux_utils.read_meta(self.input_dir)
        metadata_ids, _ = aux_utils.validate_metadata_indices(
            frames_metadata=self.frames_metadata,
            channel_ids=channel_ids,
            slice_ids=slice_ids,
        )
        # channel_ids matches the attribute name used by the other
        # preprocessing classes. channels_ids is kept as an alias because
        # existing code may still read the historical (misspelled) name.
        self.channel_ids = metadata_ids['channel_ids']
        self.channels_ids = self.channel_ids
        self.slice_ids = metadata_ids['slice_ids']
        if block_size is None:
            block_size = 32
        self.block_size = block_size
Example #3
0
def test_read_meta():
    # Write meta_df to a csv in a throwaway directory, read it back via
    # read_meta and compare the file names.
    meta_fname = 'test_meta.csv'
    with TempDirectory() as tempdir:
        csv_path = os.path.join(tempdir.path, meta_fname)
        meta_df.to_csv(csv_path)
        test_meta = aux_utils.read_meta(tempdir.path, meta_fname)
        # Only testing file name as writing df changes dtypes
        names_match = test_meta['file_name'].equals(meta_df['file_name'])
        nose.tools.assert_true(names_match)
Example #4
0
    def __init__(self,
                 input_dir,
                 output_dir,
                 scale_factor,
                 channel_ids=-1,
                 time_ids=-1,
                 slice_ids=-1,
                 pos_ids=-1,
                 int2str_len=3,
                 num_workers=4,
                 flat_field_dir=None):
        """
        :param str input_dir: Directory with image frames
        :param str output_dir: Base output directory; resized frames go in
            its 'resized_images' subdirectory (created if missing)
        :param float/list scale_factor: Scale factor for resizing frames.
            A list (or tuple) is converted to a numpy array; every value
            must be positive.
        :param int/list channel_ids: Channel indices to resize
            (default -1 includes all channels)
        :param int/list time_ids: timepoints to use
        :param int/list slice_ids: Index of slice (z) indices to use
        :param int/list pos_ids: Position (FOV) indices to use
        :param int int2str_len: Length of str when converting ints
        :param int num_workers: number of workers for multiprocessing
        :param str flat_field_dir: dir with flat field images
        :raises AssertionError: If any scale factor is not positive
        """
        self.input_dir = input_dir
        self.output_dir = output_dir
        # Sequences are converted so the elementwise positivity check works;
        # comparing a plain tuple with 0 would raise a TypeError
        if isinstance(scale_factor, (list, tuple)):
            scale_factor = np.array(scale_factor)
        assert np.all(scale_factor > 0), \
            "Scale factor should be positive float, not {}".format(scale_factor)
        self.scale_factor = scale_factor

        self.frames_metadata = aux_utils.read_meta(self.input_dir)
        metadata_ids, _ = aux_utils.validate_metadata_indices(
            frames_metadata=self.frames_metadata,
            time_ids=time_ids,
            channel_ids=channel_ids,
            slice_ids=slice_ids,
            pos_ids=pos_ids,
        )
        self.time_ids = metadata_ids['time_ids']
        self.channel_ids = metadata_ids['channel_ids']
        self.slice_ids = metadata_ids['slice_ids']
        self.pos_ids = metadata_ids['pos_ids']

        # Create resize_dir as a subdirectory of output_dir
        self.resize_dir = os.path.join(
            self.output_dir,
            'resized_images',
        )
        os.makedirs(self.resize_dir, exist_ok=True)

        self.int2str_len = int2str_len
        self.num_workers = num_workers
        self.flat_field_dir = flat_field_dir
    def test_pre_process_resize2d(self):
        """Run preprocessing with 2D resizing and check resized frames and tiles."""
        cur_config = self.pp_config
        cur_config['resize'] = {
            'scale_factor': 2,
            'resize_3d': False,
        }
        cur_config['make_weight_map'] = False
        out_config, runtime = pp.pre_process(cur_config, self.base_config)

        # np.float was only an alias of the builtin float and was removed in
        # NumPy 1.24, so check against float directly
        self.assertIsInstance(runtime, float)
        self.assertEqual(
            out_config['resize']['resize_dir'],
            os.path.join(self.output_dir, 'resized_images')
        )
        resize_dir = out_config['resize']['resize_dir']
        # Check that all images have been resized
        resize_meta = aux_utils.read_meta(resize_dir)
        # 3 resized channels
        expected_rows = 3 * len(self.slice_ids) * len(self.pos_ids)
        self.assertEqual(resize_meta.shape[0], expected_rows)
        # Load an image and make sure it's twice as big
        im = cv2.imread(
            os.path.join(resize_dir, 'im_c003_z002_t000_p010.png'),
            cv2.IMREAD_ANYDEPTH,
        )
        self.assertTupleEqual(im.shape, (60, 40))
        # NOTE(review): assertTrue treats its second argument as the failure
        # message, so this only checks that im.dtype is truthy. Confirm the
        # expected dtype of resized frames before tightening to assertEqual.
        self.assertTrue(im.dtype, np.uint8)
        # There should now be 2*2 the amount of tiles, same shape
        tile_dir = out_config['tile']['tile_dir']
        tile_meta = aux_utils.read_meta(tile_dir)
        # 4 processed channels (0, 1, 3, 4), 24 tiles per image
        expected_rows = 4 * 24 * len(self.slice_ids) * len(self.pos_ids)
        self.assertEqual(tile_meta.shape[0], expected_rows)
        # Load a tile and assert shape
        im = np.load(os.path.join(
            tile_dir,
            'im_c001_z000_t000_p007_r40-50_c20-30_sl0-1.npy',
        ))
        self.assertTupleEqual(im.shape, (1, 10, 10))
        self.assertTrue(im.dtype == np.float64)
Example #6
0
    def _get_input_fnames(self,
                          time_idx,
                          channel_idx,
                          slice_idx,
                          pos_idx,
                          mask_dir=None):
        """Get input_fnames

        Returns the file paths for the z-slices centered on slice_idx. The
        number of slices is the channel depth (or self.mask_depth for masks);
        a depth of d spans slice_idx - d // 2 to slice_idx + d // 2.
        File names are looked up in self.frames_metadata / self.input_dir, or
        in mask_dir's metadata when mask_dir is given. Paths are not checked
        for existence.

        :param int time_idx: Time index
        :param int channel_idx: Channel index
        :param int slice_idx: Slice (z) index
        :param int pos_idx: Position (FOV) index
        :param str mask_dir: Directory containing masks
        :return: list of input fnames
        """
        if mask_dir is None:
            depth = self.channel_depth[channel_idx]
            meta = self.frames_metadata
            im_dir = self.input_dir
        else:
            depth = self.mask_depth
            # Read mask metadata once instead of once per slice
            meta = aux_utils.read_meta(mask_dir)
            im_dir = mask_dir
        margin = 0 if depth == 1 else depth // 2
        im_fnames = []
        for z in range(slice_idx - margin, slice_idx + margin + 1):
            meta_idx = aux_utils.get_meta_idx(
                meta,
                time_idx,
                channel_idx,
                z,
                pos_idx,
            )
            file_path = os.path.join(
                im_dir,
                meta.loc[meta_idx, 'file_name'],
            )
            im_fnames.append(file_path)
        return im_fnames
    def test_validate_mask_meta_no_channel(self):
        # Without an explicit mask_channel, masks get the next free index (4)
        expected_channel = 4
        mask_channel = preprocess_utils.validate_mask_meta(
            mask_dir=self.mask_dir,
            input_dir=self.input_dir,
            csv_name=self.csv_name,
        )
        self.assertEqual(mask_channel, expected_channel)

        # Each written mask row carries the matched image's indices
        out_meta = aux_utils.read_meta(self.mask_dir)
        for row_nbr, meta_row in out_meta.iterrows():
            self.assertEqual(meta_row.slice_idx, self.slice_idx)
            self.assertEqual(meta_row.time_idx, self.time_idx)
            self.assertEqual(meta_row.channel_idx, expected_channel)
            self.assertEqual(meta_row.pos_idx, row_nbr)
            expected_name = "mask_{}.png".format(row_nbr + 1)
            self.assertEqual(meta_row.file_name, expected_name)
Example #8
0
    def _get_split_ids(self, data_split='test'):
        """
        Get the indices for data_split

        Indices are read from split_samples.json in self.model_dir. If that
        file does not exist, every unique value of the split column in the
        metadata of self.image_dir is used instead.

        :param str data_split: in [train, val, test]
        :return str split_col: Dataframe column name, which was split in training
        :return list inference_ids: Indices for inference given data split
        """
        split_col = self.config['dataset']['split_by_column']
        split_fname = os.path.join(self.model_dir, 'split_samples.json')
        try:
            split_samples = aux_utils.read_json(split_fname)
        except FileNotFoundError:
            print("No split_samples file. Will predict all images in dir.")
            frames_meta = aux_utils.read_meta(self.image_dir)
            inference_ids = np.unique(frames_meta[split_col]).tolist()
        else:
            inference_ids = split_samples[data_split]
        return split_col, inference_ids
Example #9
0
def validate_mask_meta(mask_dir,
                       input_dir,
                       csv_name=None,
                       mask_channel=None):
    """
    If user provides existing masks, the mask directory should also contain
    a csv file (not named frames_meta.csv which is reserved for output) with
    two column names: mask_name and file_name. Each row should describe the
    mask name and the corresponding file name. Each file_name should exist in
    input_dir and belong to the same channel.
    This function checks that all file names exist in input_dir and writes a
    frames_meta csv containing mask names with indices corresponding to the
    matched file_name. It also assigns a mask channel number for future
    preprocessing steps like tiling.

    :param str mask_dir: Mask directory
    :param str input_dir: Input image directory, to match masks with images
    :param str/None csv_name: Name of the csv file in mask_dir mapping
        mask_name to file_name. If None, or if it does not match exactly one
        file, mask_dir is searched for csv files instead.
    :param int/None mask_channel: Channel idx assigned to masks. If None, one
        more than the largest channel_idx in input_dir is used.
    :return int mask_channel: New channel index for masks for writing tiles.
        If the csv search finds an existing frames_meta.csv in mask_dir, its
        channel index is returned and nothing is written.
    :raises IOError: If no csv file is present in mask_dir
    :raises IOError: If more than one csv file exists in mask_dir
        and no csv_name is provided to resolve ambiguity
    :raises AssertionError: If an existing frames_meta.csv in mask_dir does
        not contain exactly one channel
    :raises AssertionError: If csv doesn't consist of two columns named
        'mask_name' and 'file_name'
    :raises IndexError: If unable to match file_name in mask_dir csv with
        file_name in input_dir for any given mask row
    :raises AssertionError: If masks match more than one input channel, or
        mask_channel already exists in input_dir
    """
    input_meta = aux_utils.read_meta(input_dir)
    if mask_channel is None:
        mask_channel = int(
            input_meta['channel_idx'].max() + 1
        )
    # Make sure there is a csv file
    if csv_name is not None:
        csv_name = glob.glob(os.path.join(mask_dir, csv_name))
        if len(csv_name) == 1:
            # Use the one existing csv name
            csv_name = csv_name[0]
        else:
            csv_name = None
    # No csv name given, search for it
    if csv_name is None:
        csv_name = glob.glob(os.path.join(mask_dir, '*.csv'))
        if len(csv_name) == 0:
            raise IOError("No csv file present in mask dir")
        # See if frames_meta is already present, if so, move on. Match the
        # exact file name so e.g. 'old_frames_meta.csv' doesn't count.
        has_meta = any(
            os.path.basename(s) == 'frames_meta.csv' for s in csv_name
        )
        if has_meta:
            # Return existing mask channel from frames_meta
            frames_meta = pd.read_csv(
                os.path.join(mask_dir, 'frames_meta.csv'),
            )
            # np.unique returns an ndarray, so check its length directly
            mask_channels = np.unique(frames_meta['channel_idx'])
            assert len(mask_channels) == 1, \
                "Expected exactly one mask channel, found: {}".format(
                    mask_channels)
            # .item() converts the numpy scalar to a Python int
            return mask_channels.item()
        elif len(csv_name) == 1:
            # Use the one existing csv name
            csv_name = csv_name[0]
        else:
            # More than one csv file in dir
            raise IOError("More than one csv file present in mask dir, "
                          "use csv_name to specify which one to use")

    # Read csv with masks and corresponding input file names. glob returned
    # a path that already includes mask_dir, so pass it relative to mask_dir
    # (re-joining a relative mask_dir would otherwise duplicate it)
    mask_meta = aux_utils.read_meta(
        input_dir=mask_dir,
        meta_fname=os.path.relpath(csv_name, mask_dir),
    )

    assert set(mask_meta) == {'file_name', 'mask_name'}, \
        "mask csv should have columns mask_name and file_name " + \
        "(corresponding to the file_name in input_dir)"
    # Check that file_name for each mask_name matches files in input_dir
    file_names = input_meta['file_name']
    # Create dataframe that will store all indices for masks
    out_meta = aux_utils.make_dataframe(nbr_rows=mask_meta.shape[0])
    file_name_col = out_meta.columns.get_loc('file_name')
    # Use positional row numbers: both frames are indexed with iloc below
    for i, (_, row) in enumerate(mask_meta.iterrows()):
        try:
            file_loc = np.flatnonzero(
                (file_names == row.file_name).to_numpy()
            )[0]
        except IndexError as e:
            msg = "Can't find image file name match for {}, error {}".format(
                row.file_name, e)
            raise IndexError(msg) from e
        # Fill dataframe with row indices from matched image in input dir
        out_meta.iloc[i] = input_meta.iloc[file_loc]
        # Write back the mask name. Use a single indexer:
        # out_meta.iloc[i]['file_name'] = ... may write into a temporary copy.
        out_meta.iloc[i, file_name_col] = row.mask_name

    assert len(out_meta.channel_idx.unique()) == 1, \
        "Masks should match one input channel only"
    assert mask_channel not in set(input_meta.channel_idx.unique()), \
        "Mask channel {} already exists in image dir".format(mask_channel)

    # Replace channel_idx new mask channel idx
    out_meta['channel_idx'] = mask_channel

    # Write mask metadata with indices that match input images
    meta_filename = os.path.join(mask_dir, 'frames_meta.csv')
    out_meta.to_csv(meta_filename, sep=",")

    return mask_channel
Example #10
0
    def __init__(self,
                 input_dir,
                 output_dir,
                 channel_ids,
                 flat_field_dir=None,
                 time_ids=-1,
                 slice_ids=-1,
                 pos_ids=-1,
                 int2str_len=3,
                 uniform_struct=True,
                 num_workers=4,
                 mask_type='otsu',
                 mask_channel=None,
                 mask_ext='.npy',
                 normalize_im=False):
        """
        :param str input_dir: Directory with image frames
        :param str output_dir: Base output directory
        :param int/list channel_ids: Channel indices to be masked (typically
            just one); the mask is generated from the sum of these
            (fluorophore) channels
        :param str flat_field_dir: Directory with flatfield images if
            flatfield correction is applied
        :param list/int time_ids: timepoints to consider
        :param int/list slice_ids: Index of which focal plane (z)
            acquisition to use (default -1 includes all slices)
        :param int/list pos_ids: Position (FOV) indices to use
        :param int int2str_len: Length of str when converting ints
        :param bool uniform_struct: bool indicator for same structure across
            pos and time points
        :param int num_workers: number of workers for multiprocessing
        :param str mask_type: method to use for generating mask. Needed for
            mapping to the masking function. One of 'otsu', 'unimodal' or
            'borders_weight_loss_map'
        :param int mask_channel: channel number assigned to the masks to be
            generated. If None, one more than the largest channel_idx in
            input_dir is used. If resizing images on a subset of channels,
            frames_meta is from resize dir, which could lead to wrong mask
            channel being assigned.
        :param str mask_ext: '.npy' or '.png'. Save the mask as uint8 PNG or
            NPY files
        :param bool normalize_im: indicator to normalize image based on z-score or not
        :raises AssertionError: If mask_type is not a supported method
        """
        # Validate the masking method before reading metadata or creating
        # the mask directory
        assert mask_type in ['otsu', 'unimodal', 'borders_weight_loss_map'], \
            "Masking method invalid, Otsu, borders_weight_loss_map, " + \
            "and unimodal are currently supported"

        self.input_dir = input_dir
        self.output_dir = output_dir
        self.flat_field_dir = flat_field_dir
        self.num_workers = num_workers
        self.normalize_im = normalize_im

        self.frames_metadata = aux_utils.read_meta(self.input_dir)
        # Create a unique mask channel number so masks can be treated
        # as a new channel
        if mask_channel is None:
            self.mask_channel = int(self.frames_metadata['channel_idx'].max() +
                                    1)
        else:
            self.mask_channel = int(mask_channel)

        metadata_ids, nested_id_dict = aux_utils.validate_metadata_indices(
            frames_metadata=self.frames_metadata,
            time_ids=time_ids,
            channel_ids=channel_ids,
            slice_ids=slice_ids,
            pos_ids=pos_ids,
            uniform_structure=uniform_struct,
        )
        self.time_ids = metadata_ids['time_ids']
        self.channel_ids = metadata_ids['channel_ids']
        self.slice_ids = metadata_ids['slice_ids']
        self.pos_ids = metadata_ids['pos_ids']
        # Create mask_dir as a subdirectory of output_dir
        self.mask_dir = os.path.join(
            self.output_dir,
            'mask_channels_' + '-'.join(map(str, self.channel_ids)),
        )
        os.makedirs(self.mask_dir, exist_ok=True)

        self.int2str_len = int2str_len
        self.uniform_struct = uniform_struct
        self.nested_id_dict = nested_id_dict

        self.mask_type = mask_type
        self.mask_ext = mask_ext
Example #11
0
    def tile_mask_stack(self,
                        mask_dir,
                        mask_channel,
                        min_fraction,
                        mask_depth=1):
        """
        Tiles images in the specified channels assuming there are masks
        already created in mask_dir. Only tiles above a certain fraction
        of foreground in mask tile will be saved and added to metadata.

        The mask channel is tiled first. If tiles already exist and
        mask_channel is in self.channel_ids, the existing tile metadata is
        reused instead. The mask channel is then removed from
        self.channel_ids, together with the flag at the same position in
        self.normalize_channels. Every remaining channel is cropped using
        the tile indices found for the mask channel.

        Saves a csv with columns ['time_idx', 'channel_idx', 'pos_idx',
        'slice_idx', 'file_name'] for all the tiles

        :param str mask_dir: Directory containing masks
        :param int mask_channel: Channel number assigned to mask
        :param float min_fraction: Minimum fraction of foreground in tiled masks
        :param int mask_depth: Depth for mask channel
        :raises AssertionError: If mask_depth exceeds the largest channel depth
        """

        # mask depth has to match input or output channel depth
        assert mask_depth <= max(self.channel_depth.values())
        self.mask_depth = mask_depth

        # tile and save masks
        # if mask channel is already tiled
        if self.tiles_exist and mask_channel in self.channel_ids:
            mask_meta_df = pd.read_csv(
                os.path.join(self.tile_dir, 'frames_meta.csv'))
        else:
            # TODO: different masks across timepoints (but MaskProcessor
            # generates mask for tp=0 only)
            mask_fn_args = []
            for slice_idx in self.slice_ids:
                for time_idx in self.time_ids:
                    for pos_idx in self.pos_ids:
                        # Evaluate mask, then channels. The masks will influence
                        # tiling indices, so it's not allowed to add masks to
                        # existing tiled data sets (indices will be retrieved
                        # from existing meta)
                        cur_args = self.get_crop_tile_args(
                            channel_idx=mask_channel,
                            time_idx=time_idx,
                            slice_idx=slice_idx,
                            pos_idx=pos_idx,
                            task_type='tile',
                            mask_dir=mask_dir,
                            min_fraction=min_fraction,
                            normalize_im=False,
                        )
                        mask_fn_args.append(cur_args)

            # tile_image uses min_fraction assuming input_image is a bool
            mask_meta_df_list = mp_utils.mp_tile_save(
                mask_fn_args,
                workers=self.num_workers,
            )
            mask_meta_df = pd.concat(mask_meta_df_list, ignore_index=True)
            # Save the mask tile metadata; it is merged with the image
            # channel tile metadata below
            mask_meta_df = mask_meta_df.sort_values(by=['file_name'])
            mask_meta_df.to_csv(
                os.path.join(self.tile_dir, 'frames_meta.csv'),
                sep=',',
            )
        # Remove mask_channel from self.channel_ids (in place), and the flag
        # at the same position from self.normalize_channels, so the two lists
        # stay aligned. Pop from the back so earlier positions remain valid.
        mask_positions = [
            idx for idx, val in enumerate(self.channel_ids)
            if val == mask_channel
        ]
        for idx in reversed(mask_positions):
            self.channel_ids.pop(idx)
            if idx < len(self.normalize_channels):
                self.normalize_channels.pop(idx)

        fn_args = []
        for slice_idx in self.slice_ids:
            for time_idx in self.time_ids:
                # NOTE(review): this iterates every position in the metadata,
                # not just self.pos_ids as the mask loop above does; confirm
                # that is intended
                for pos_idx in np.unique(self.frames_metadata["pos_idx"]):
                    # Loop through all channels and tile from indices
                    cur_tile_indices = self._get_tile_indices(
                        tiled_meta=mask_meta_df,
                        time_idx=time_idx,
                        channel_idx=mask_channel,
                        pos_idx=pos_idx,
                        slice_idx=slice_idx)
                    if np.any(cur_tile_indices):
                        for i, channel_idx in enumerate(self.channel_ids):
                            cur_args = self.get_crop_tile_args(
                                channel_idx,
                                time_idx,
                                slice_idx,
                                pos_idx,
                                task_type='crop',
                                tile_indices=cur_tile_indices,
                                normalize_im=self.normalize_channels[i],
                            )
                            fn_args.append(cur_args)
        if not fn_args:
            # Nothing to crop (no image channels left, or no mask tiles);
            # pd.concat below would fail on an empty list
            return
        tiled_meta_df_list = mp_utils.mp_crop_save(
            fn_args,
            workers=self.num_workers,
        )
        tiled_metadata = pd.concat(tiled_meta_df_list, ignore_index=True)
        # If there's been tiling done already, add to existing metadata
        prev_tiled_metadata = aux_utils.read_meta(self.tile_dir)
        tiled_metadata = pd.concat(
            [
                prev_tiled_metadata.reset_index(drop=True),
                tiled_metadata.reset_index(drop=True)
            ],
            axis=0,
            ignore_index=True,
        )
        # Finally, save all the metadata
        tiled_metadata = tiled_metadata.sort_values(by=['file_name'])
        tiled_metadata.to_csv(
            os.path.join(self.tile_dir, "frames_meta.csv"),
            sep=',',
        )
 def test_pre_process(self):
     """Run the full preprocessing pipeline and check flatfield, mask and tile outputs."""
     out_config, runtime = pp.pre_process(self.pp_config, self.base_config)
     # np.float was only an alias of the builtin float and was removed in
     # NumPy 1.24, so check against float directly
     self.assertIsInstance(runtime, float)
     self.assertEqual(
         self.base_config['input_dir'],
         self.image_dir,
     )
     self.assertEqual(
         self.base_config['channel_ids'],
         self.pp_config['channel_ids'],
     )
     self.assertEqual(
         out_config['flat_field']['flat_field_dir'],
         os.path.join(self.output_dir, 'flat_field_images')
     )
     self.assertEqual(
         out_config['masks']['mask_dir'],
         os.path.join(self.output_dir, 'mask_channels_3')
     )
     self.assertEqual(
         out_config['tile']['tile_dir'],
         os.path.join(self.output_dir, 'tiles_10-10_step_10-10'),
     )
     # Make sure new mask channel assignment is correct
     self.assertEqual(out_config['masks']['mask_channel'], 4)
     # Check that masks are generated
     mask_dir = out_config['masks']['mask_dir']
     mask_meta = aux_utils.read_meta(mask_dir)
     mask_names = os.listdir(mask_dir)
     mask_names.pop(mask_names.index('frames_meta.csv'))
     # Validate that all masks are there
     self.assertEqual(
         len(mask_names),
         len(self.slice_ids) * len(self.pos_ids),
     )
     for p in self.pos_ids:
         for z in self.slice_ids:
             im_name = aux_utils.get_im_name(
                 channel_idx=out_config['masks']['mask_channel'],
                 slice_idx=z,
                 time_idx=self.time_idx,
                 pos_idx=p,
             )
             im = cv2.imread(
                 os.path.join(mask_dir, im_name),
                 cv2.IMREAD_ANYDEPTH,
             )
             self.assertTupleEqual(im.shape, (30, 20))
             self.assertTrue(im.dtype == 'uint8')
             self.assertIn(im_name, mask_names)
             self.assertIn(im_name, mask_meta['file_name'].tolist())
     # Check flatfield images
     ff_dir = out_config['flat_field']['flat_field_dir']
     ff_names = os.listdir(ff_dir)
     self.assertEqual(len(ff_names), 3)
     for processed_channel in [0, 1, 3]:
         expected_name = "flat-field_channel-{}.npy".format(processed_channel)
         self.assertIn(expected_name, ff_names)
         im = np.load(os.path.join(ff_dir, expected_name))
         self.assertTrue(im.dtype == np.float64)
         self.assertTupleEqual(im.shape, (30, 20))
     # Check tiles
     tile_dir = out_config['tile']['tile_dir']
     tile_meta = aux_utils.read_meta(tile_dir)
     # 4 processed channels (0, 1, 3, 4), 6 tiles per image
     expected_rows = 4 * 6 * len(self.slice_ids) * len(self.pos_ids)
     self.assertEqual(tile_meta.shape[0], expected_rows)
     # Check indices
     self.assertListEqual(
         tile_meta.channel_idx.unique().tolist(),
         [0, 1, 3, 4],
     )
     self.assertListEqual(
         tile_meta.pos_idx.unique().tolist(),
         self.pos_ids,
     )
     self.assertListEqual(
         tile_meta.slice_idx.unique().tolist(),
         self.slice_ids,
     )
     self.assertListEqual(
         tile_meta.time_idx.unique().tolist(),
         [self.time_idx],
     )
     self.assertListEqual(
         list(tile_meta),
         ['channel_idx',
          'col_start',
          'file_name',
          'pos_idx',
          'row_start',
          'slice_idx',
          'time_idx']
     )
     self.assertListEqual(
         tile_meta.row_start.unique().tolist(),
         [0, 10, 20],
     )
     self.assertListEqual(
         tile_meta.col_start.unique().tolist(),
         [0, 10],
     )
     # Read one tile and check format
     # r = row start/end idx, c = column start/end, sl = slice start/end
     # sl0-1 signifies depth of 1
     im = np.load(os.path.join(
         tile_dir,
         'im_c001_z000_t000_p007_r10-20_c10-20_sl0-1.npy',
     ))
     self.assertTupleEqual(im.shape, (1, 10, 10))
     self.assertTrue(im.dtype == np.float64)
Example #13
0
    def __init__(self,
                 input_dir,
                 output_dir,
                 tile_size=[256, 256],
                 step_size=[64, 64],
                 depths=1,
                 time_ids=-1,
                 channel_ids=-1,
                 normalize_channels=-1,
                 slice_ids=-1,
                 pos_ids=-1,
                 hist_clip_limits=None,
                 flat_field_dir=None,
                 image_format='zyx',
                 num_workers=4,
                 int2str_len=3,
                 tile_3d=False):
        """
        Tiles images.
        If tile_dir already exist, it will check which channels are already
        tiled, get indices from them and tile from indices only on the channels
        not already present.

        :param str input_dir: Directory with frames to be tiled
        :param str output_dir: Base output directory
        :param list tile_size: size of the blocks to be cropped
            from the image
        :param list step_size: size of the window shift. In case
            of no overlap, the step size is tile_size. If overlap, step_size <
            tile_size
        :param int/list depths: The z depth for generating stack training data
            Default 1 assumes 2D data for all channels to be tiled.
            For cases where input and target shapes are not the same (e.g. stack
            to 2D) you should specify depths for each channel in tile.channels.
        :param list/int time_ids: Tile given timepoint indices
        :param list/int channel_ids: Tile images in the given channel indices
            default=-1, tile all channels.
        :param list/int normalize_channels: list of booleans matching channel_ids
            indicating if channel should be normalized or not.
        :param int slice_ids: Index of which focal plane acquisition to
            use (for 2D). default=-1 for the whole z-stack
        :param int pos_ids: Position (FOV) indices to use
        :param list hist_clip_limits: lower and upper percentiles used for
            histogram clipping.
        :param str flat_field_dir: Flatfield directory. None if no flatfield
            correction
        :param str image_format: zyx (preferred) or xyz
        :param int num_workers: number of workers for multiprocessing
        :param int int2str_len: number of characters for each idx to be used
            in file names
        :param bool tile_3d: Whether tiling is 3D or 2D
        """
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.normalize_channels = normalize_channels
        self.depths = depths
        self.tile_size = tile_size
        self.step_size = step_size
        self.hist_clip_limits = hist_clip_limits
        self.image_format = image_format
        assert self.image_format in {'zyx', 'xyz'}, \
            'Data format must be zyx or xyz'
        self.num_workers = num_workers
        self.int2str_len = int2str_len
        self.tile_3d = tile_3d

        self.str_tile_step = 'tiles_{}_step_{}'.format(
            '-'.join([str(val) for val in tile_size]),
            '-'.join([str(val) for val in step_size]),
        )
        self.tile_dir = os.path.join(
            output_dir,
            self.str_tile_step,
        )

        # If tile dir already exist, only tile channels not already present
        self.tiles_exist = False
        # If tile dir already exist, things could get messy because we don't
        # have any checks in place for how to add to existing tiles
        try:
            os.makedirs(self.tile_dir, exist_ok=False)
            # make dir for saving indiv meta per image, could be used for
            # tracking job success / fail
            os.makedirs(os.path.join(self.tile_dir, 'meta_dir'),
                        exist_ok=False)
        except FileExistsError as e:
            print("Tile dir exists. Only add untiled channels.")
            self.tiles_exist = True

        # make dir for saving individual meta per image, could be used for
        # tracking job success / fail
        os.makedirs(os.path.join(self.tile_dir, 'meta_dir'), exist_ok=True)

        self.flat_field_dir = flat_field_dir
        self.frames_metadata = aux_utils.read_meta(self.input_dir)
        # Get metadata indices
        metadata_ids, _ = aux_utils.validate_metadata_indices(
            frames_metadata=self.frames_metadata,
            time_ids=time_ids,
            channel_ids=channel_ids,
            slice_ids=slice_ids,
            pos_ids=pos_ids,
            uniform_structure=True)
        self.channel_ids = metadata_ids['channel_ids']
        self.normalize_channels = normalize_channels
        self.time_ids = metadata_ids['time_ids']
        self.slice_ids = metadata_ids['slice_ids']
        self.pos_ids = metadata_ids['pos_ids']
        self.normalize_channels = normalize_channels
        # Determine which channels should be normalized in tiling
        if self.normalize_channels == -1:
            self.normalize_channels = [True] * len(self.channel_ids)
        else:
            assert len(self.normalize_channels) == len(self.channel_ids),\
                "Channel ids {} and normalization list {} mismatch".format(
                    self.channel_ids,
                    self.normalize_channels,
                )
        # If more than one depth is specified, length must match channel ids
        if isinstance(self.depths, list):
            assert len(self.depths) == len(self.channel_ids),\
             "depths ({}) and channels ({}) length mismatch".format(
                self.depths, self.channel_ids,
            )
            # Get max of all specified depths
            max_depth = max(self.depths)
            # Convert channels + depths to dict for lookup
            self.channel_depth = dict(zip(self.channel_ids, self.depths))
        else:
            # If depth is scalar, make depth the same for all channels
            max_depth = self.depths
            self.channel_depth = dict(
                zip(self.channel_ids, [self.depths] * len(self.channel_ids)), )

        # Adjust slice margins
        self.slice_ids = aux_utils.adjust_slice_margins(
            slice_ids=self.slice_ids,
            depth=max_depth,
        )
Example #14
0
    def tile_mask_stack(self,
                        mask_dir,
                        mask_channel,
                        min_fraction,
                        mask_depth=1):
        """
        Tile the configured channels, using pre-computed masks in mask_dir
        to decide which tiles to keep.

        The mask channel is tiled first; only tiles with a foreground
        fraction above min_fraction are saved and added to metadata. The
        tile indices found for the mask are then used to tile the remaining
        channels. A csv with columns ['time_idx', 'channel_idx', 'pos_idx',
        'slice_idx', 'file_name'] for all the tiles is saved.

        :param str mask_dir: Directory containing masks (with their own
            frames_meta)
        :param int mask_channel: Channel number assigned to mask
        :param float min_fraction: Min fraction of foreground in tiled masks
        :param int mask_depth: Depth for mask channel; must not exceed the
            largest depth among the tiled channels
        """
        # The mask may not be deeper than the deepest input/target channel
        assert mask_depth <= max(self.channel_depth.values())
        self.mask_depth = mask_depth

        # Mask metadata lives in mask_dir: if channel_ids = -1, frames_meta
        # has no rows for the mask channel. Time, position and slice indices
        # for the mask are assumed to follow the same structure as the other
        # channels.
        # TODO: different masks across timepoints (but MaskProcessor generates
        # mask for tp=0 only)
        mask_frames_meta = aux_utils.read_meta(mask_dir)
        _, mask_nested_ids = aux_utils.validate_metadata_indices(
            frames_metadata=mask_frames_meta,
            time_ids=self.time_ids,
            channel_ids=mask_channel,
            slice_ids=self.slice_ids,
            pos_ids=self.pos_ids,
            uniform_structure=False,
        )

        # Keep, per time index, only the entry belonging to the mask channel
        mask_ch_ids = {}
        for time_idx, chan_dict in mask_nested_ids.items():
            for chan_idx, chan_ids in chan_dict.items():
                if chan_idx != mask_channel:
                    continue
                mask_ch_ids[time_idx] = {mask_channel: chan_ids}

        # Tile the mask channel first; its tile indices drive the rest
        mask_tiles_meta = self.tile_first_channel(
            channel0_ids=mask_ch_ids,
            channel0_depth=mask_depth,
            cur_mask_dir=mask_dir,
            min_fraction=min_fraction,
            is_mask=True,
        )
        self.tile_remaining_channels(
            nested_id_dict=self.nested_id_dict,
            tiled_ch_id=mask_channel,
            cur_meta_df=mask_tiles_meta,
        )
Example #15
0
    def __init__(self,
                 image_dir,
                 dataset_config,
                 network_config,
                 split_col_ids,
                 image_format='zyx',
                 mask_dir=None,
                 flat_field_dir=None):
        """Init

        Read image (and optionally mask) metadata, keep only the rows
        selected for inference and derive the settings used to iterate over
        the dataset.

        :param str image_dir: dir containing images AND NOT TILES!
        :param dict dataset_config: dict with dataset related params. Reads
         'input_channels', 'target_channels' and optionally 'model_task'
         ('regression' (default) or 'segmentation')
        :param dict network_config: dict with network related params. Reads
         'class', 'depth' (for UNetStackTo2D / UNetStackToStack) and
         optionally 'data_format' (default 'channels_first')
        :param tuple split_col_ids: How to split up the dataset for inference:
         for frames_meta: (str split column name, list split row indices)
        :param str image_format: xyz or zyx format
        :param str/None mask_dir: If inference targets are masks stored in a
         different directory than the image dir. Assumes the directory contains
         a frames_meta.csv containing mask channels (which will be target channels
          in the inference config) z, t, p indices matching the ones in image_dir
        :param str flat_field_dir: Directory with flat field images
        :raises AssertionError: if image_format or model_task is invalid
        """
        import pandas as pd

        # Validate cheap arguments before touching the file system
        assert image_format in {'xyz', 'zyx'}, \
            "Image format should be xyz or zyx, not {}".format(image_format)
        self.image_format = image_format

        self.image_dir = image_dir
        self.target_dir = image_dir
        self.frames_meta = aux_utils.read_meta(self.image_dir)
        self.flat_field_dir = flat_field_dir
        if mask_dir is not None:
            self.target_dir = mask_dir
            # Append mask meta to frames meta. DataFrame.append was removed
            # in pandas 2.0; concat with ignore_index is the equivalent.
            mask_meta = aux_utils.read_meta(mask_dir)
            self.frames_meta = pd.concat(
                [self.frames_meta, mask_meta],
                ignore_index=True,
            )
        # Use only indices selected for inference
        split_col, split_ids = split_col_ids
        keep_rows = self.frames_meta[split_col].isin(split_ids)
        self.frames_meta = self.frames_meta[keep_rows]

        # Model task (regression or segmentation), regression by default
        self.model_task = dataset_config.get('model_task', 'regression')
        assert self.model_task in {'regression', 'segmentation'}, \
            "Model task must be either 'segmentation' or 'regression'"
        # the raw input images have to be normalized (z-score typically)
        self.normalize = self.model_task == 'regression'

        self.input_channels = dataset_config['input_channels']
        self.target_channels = dataset_config['target_channels']
        # get a subset of frames meta for only one channel to easily
        # extract indices (pos, time, slice) to iterate over
        df_idx = self.frames_meta['channel_idx'] == self.target_channels[0]
        self.iteration_meta = self.frames_meta[df_idx].copy()

        self.depth = 1
        self.target_depth = 1
        # adjust slice margins if stacktostack or stackto2d
        network_cls = network_config['class']
        if network_cls in ['UNetStackTo2D', 'UNetStackToStack']:
            self.depth = network_config['depth']
            self.adjust_slice_indices()

        # if Unet2D 4D tensor, remove the singleton dimension, else 5D
        self.squeeze = network_cls == 'UNet2D'
        self.im_3d = network_cls == 'UNet3D'

        self.data_format = network_config.get('data_format', 'channels_first')

        # Iterate position by position, slice by slice
        self.iteration_meta = self.iteration_meta.sort_values(
            ['pos_idx', 'slice_idx'],
            ascending=[True, True],
        ).reset_index(drop=True)
        self.num_samples = len(self.iteration_meta)
Example #16
0
    def __init__(self,
                 train_config,
                 inference_config,
                 gpu_id=-1,
                 gpu_mem_frac=None):
        """Init

        Resolve the model weights file, build the inference dataset and set
        up mask, metrics and 3D inference options from the configs.

        :param dict train_config: Training config dict with params related
            to dataset, trainer and network
        :param dict inference_config: Read yaml file with following parameters:
            str model_dir: Path to model directory. If left out, the trainer
             model_dir from train_config is used.
            str/None model_fname: File name of weights in model dir (.hdf5).
             If left out or None, the last .hdf5 file in model_dir in natural
             sort order is selected (presumably the latest weights).
            str image_dir: dir containing input images AND NOT TILES!
            str data_split: Which data (train/test/val) to run inference on.
             (default = test)
            dict images (optional; every key has a default):
             str image_format: 'zyx' or 'xyz' (default = 'zyx')
             str/None flat_field_dir: flatfield directory
             str image_ext: For writing predictions e.g. '.png' or '.npy' or
             '.tiff' (default = '.png'; set to '.npy' when inference_3d is given)
             FOR 3D IMAGES USE NPY AS PNG AND TIFF ARE CURRENTLY NOT SUPPORTED.
             list crop_shape: center crop the image to a specified shape before
             tiling for inference
            dict metrics:
             list metrics: list of metrics to estimate. available
             metrics: [ssim, corr, r2, mse, mae]
             list metrics_orientations: xy, xyz, xz or yz (default = ['xy'])
              (see evaluation_metrics.py for description of orientations)
            dict masks: dict with keys
             str mask_dir: Mask directory containing a frames_meta.csv containing
             mask channels (which will be target channels in the inference config)
             z, t, p indices matching the ones in image_dir. Mask dirs are often
             generated or have frames_meta added to them during preprocessing.
             str mask_type: 'metrics' for weighted metrics; any other value
             (e.g. 'target' for segmentation) uses the masks as targets
             int mask_channel: mask channel as in training
            dict inference_3d: dict with params for 3D inference with keys:
             num_slices, inf_shape, tile_shape, num_overlap, overlap_operation.
             int num_slices: in case of 3D, the full volume will not fit in GPU
              memory, specify the number of slices to use and this will depend on
              the network depth, for ex 8 for a network of depth 4.
             list inf_shape: inference on a center sub volume.
             list tile_shape: shape of tile for tiling along xyz.
             int/list num_overlap: int for tile_z, list for tile_xyz
             str overlap_operation: e.g. 'mean'
        :param int gpu_id: GPU number to use. -1 for debugging (no GPU)
        :param float/None gpu_mem_frac: Memory fractions to use corresponding
            to gpu_ids
        """
        # Use model_dir from inference config if present, otherwise use train
        if 'model_dir' in inference_config:
            model_dir = inference_config['model_dir']
        else:
            model_dir = train_config['trainer']['model_dir']
        model_fname = inference_config.get('model_fname')
        if model_fname is None:
            # If model filename not listed, grab latest one. Search the
            # resolved model_dir: inference_config may not contain one.
            fnames = [
                f for f in os.listdir(model_dir)
                if f.endswith('.hdf5')
            ]
            assert len(fnames) > 0, 'No weight files found in model dir'
            fnames = natsort.natsorted(fnames)
            model_fname = fnames[-1]
        self.config = train_config
        self.model_dir = model_dir
        self.image_dir = inference_config['image_dir']

        # Set default for data split, determine column name and indices
        data_split = inference_config.get('data_split', 'test')
        assert data_split in ['train', 'val', 'test'], \
            'data_split not in [train, val, test]'
        split_col_ids = self._get_split_ids(data_split)

        self.data_format = self.config['network']['data_format']
        assert self.data_format in {'channels_first', 'channels_last'}, \
            "Data format should be channels_first/last"

        # Image settings; every key falls back to its documented default
        images_dict = inference_config.get('images', {})
        flat_field_dir = images_dict.get('flat_field_dir')
        self.image_format = images_dict.get('image_format', 'zyx')
        self.image_ext = images_dict.get('image_ext', '.png')

        # Create image subdirectory to write predicted images
        self.pred_dir = os.path.join(self.model_dir, 'predictions')
        os.makedirs(self.pred_dir, exist_ok=True)

        # Handle masks as either targets or for masked metrics
        self.masks_dict = inference_config.get('masks')
        self.mask_metrics = False
        self.mask_dir = None
        self.mask_meta = None
        target_dir = None
        if self.masks_dict is not None:
            assert 'mask_channel' in self.masks_dict, 'mask_channel is needed'
            assert 'mask_dir' in self.masks_dict, 'mask_dir is needed'
            self.mask_dir = self.masks_dict['mask_dir']
            self.mask_meta = aux_utils.read_meta(self.mask_dir)
            assert 'mask_type' in self.masks_dict, \
                'mask_type (target/metrics) is needed'
            if self.masks_dict['mask_type'] == 'metrics':
                # Compute weighted metrics
                self.mask_metrics = True
            else:
                target_dir = self.mask_dir

        # Create dataset instance. Pass the resolved image format so the
        # 'zyx' default applies when images.image_format is omitted.
        self.dataset_inst = InferenceDataSet(
            image_dir=self.image_dir,
            dataset_config=self.config['dataset'],
            network_config=self.config['network'],
            split_col_ids=split_col_ids,
            image_format=self.image_format,
            mask_dir=target_dir,
            flat_field_dir=flat_field_dir,
        )
        self.iteration_meta = self.dataset_inst.get_iteration_meta()

        # Handle metrics config settings
        self.metrics_inst = None
        self.metrics_dict = inference_config.get('metrics')
        if self.metrics_dict is not None:
            assert 'metrics' in self.metrics_dict,\
                'Must specify with metrics to use'
            # create an instance of MetricsEstimator
            self.metrics_inst = MetricsEstimator(
                metrics_list=self.metrics_dict['metrics'],
                masked_metrics=self.mask_metrics,
            )
            self.metrics_orientations = ['xy']
            available_orientations = ['xy', 'xyz', 'xz', 'yz']
            if 'metrics_orientations' in self.metrics_dict:
                self.metrics_orientations = \
                    self.metrics_dict['metrics_orientations']
                assert set(self.metrics_orientations).\
                    issubset(available_orientations),\
                    'orientation not in [xy, xyz, xz, yz]'
            self.df_xy = pd.DataFrame()
            self.df_xyz = pd.DataFrame()
            self.df_xz = pd.DataFrame()
            self.df_yz = pd.DataFrame()

        # Handle 3D volume inference settings
        self.num_overlap = 0
        self.stitch_inst = None
        self.tile_option = None
        self.z_dim = 2
        self.crop_shape = images_dict.get('crop_shape')
        if 'inference_3d' in inference_config:
            self.params_3d = inference_config['inference_3d']
            self._assign_3d_inference()
            # Make image ext npy default for 3D
            self.image_ext = '.npy'

        # Set session if not debug
        if gpu_id >= 0:
            self.sess = set_keras_session(
                gpu_ids=gpu_id,
                gpu_mem_frac=gpu_mem_frac,
            )
        # create model and load weights
        self.model = inference.load_model(
            network_config=self.config['network'],
            model_fname=os.path.join(self.model_dir, model_fname),
            predict=True,
        )
    def test_pre_process_weight_maps(self):
        """Check weights dir, weights meta/file and tiles from pre_process
        run with make_weight_map=True."""
        cur_config = self.pp_config
        # Use preexisting masks with more than one class, otherwise
        # weight map generation doesn't work
        cur_config['masks'] = {
            'mask_dir': self.input_mask_dir,
            'mask_channel': self.input_mask_channel,
        }
        cur_config['make_weight_map'] = True
        out_config, _ = pp.pre_process(cur_config, self.base_config)

        # Check weights dir
        weights_dir = out_config['weights']['weights_dir']
        self.assertEqual(
            weights_dir,
            os.path.join(self.output_dir, 'mask_channels_111'),
        )
        weights_meta = aux_utils.read_meta(weights_dir)
        # Check indices (msg names the column on failure)
        expected_weight_ids = {
            'channel_idx': [112],
            'pos_idx': self.pos_ids,
            'slice_idx': self.slice_ids,
            'time_idx': [self.time_idx],
        }
        for col, expected in expected_weight_ids.items():
            self.assertListEqual(
                weights_meta[col].unique().tolist(),
                expected,
                msg=col,
            )
        # Load one weights file and check contents
        im = np.load(os.path.join(weights_dir, 'im_c112_z002_t000_p007.npy'))
        self.assertTupleEqual(im.shape, (30, 20))
        # assertEqual (not assertTrue(a == b)) reports the actual dtype
        self.assertEqual(im.dtype, np.float64)

        # Check tiles
        tile_dir = out_config['tile']['tile_dir']
        tile_meta = aux_utils.read_meta(tile_dir)
        # 5 processed channels (0, 1, 3, 111, 112), 6 tiles per image
        expected_rows = 5 * 6 * len(self.slice_ids) * len(self.pos_ids)
        self.assertEqual(tile_meta.shape[0], expected_rows)
        # Check indices
        expected_tile_ids = {
            'channel_idx': [0, 1, 3, 111, 112],
            'pos_idx': self.pos_ids,
            'slice_idx': self.slice_ids,
            'time_idx': [self.time_idx],
        }
        for col, expected in expected_tile_ids.items():
            self.assertListEqual(
                tile_meta[col].unique().tolist(),
                expected,
                msg=col,
            )
        # Load one tile
        im = np.load(os.path.join(
            tile_dir,
            'im_c111_z002_t000_p008_r0-10_c10-20_sl0-1.npy',
        ))
        self.assertTupleEqual(im.shape, (1, 10, 10))
        self.assertEqual(im.dtype, bool)