    def setUp(self):
        """
        Set up a directory with some images to resample
        """
        self.tempdir = TempDirectory()
        self.temp_path = self.tempdir.path
        self.mask_dir = os.path.join(self.temp_path, 'mask_dir')
        self.tempdir.makedir('mask_dir')
        self.input_dir = os.path.join(self.temp_path, 'input_dir')
        self.tempdir.makedir('input_dir')
        self.mask_channel = 1
        self.slice_idx = 7
        self.time_idx = 8
        # Mask meta file
        self.csv_name = 'mask_image_matchup.csv'
        input_meta = aux_utils.make_dataframe()
        # Make input meta
        for c in range(4):
            for p in range(10):
                im_name = aux_utils.get_im_name(
                    channel_idx=c,
                    slice_idx=self.slice_idx,
                    time_idx=self.time_idx,
                    pos_idx=p,
                )
                input_meta = input_meta.append(
                    aux_utils.parse_idx_from_name(im_name),
                    ignore_index=True,
                )
        input_meta.to_csv(
            os.path.join(self.input_dir, 'frames_meta.csv'),
            sep=',',
        )
        # Make mask meta
        mask_meta = pd.DataFrame()
        for p in range(10):
            im_name = aux_utils.get_im_name(
                channel_idx=self.mask_channel,
                slice_idx=self.slice_idx,
                time_idx=self.time_idx,
                pos_idx=p,
            )
            # Indexing can be different
            mask_name = 'mask_{}.png'.format(p + 1)
            mask_meta = mask_meta.append(
                {'mask_name': mask_name, 'file_name': im_name},
                ignore_index=True,
            )
        mask_meta.to_csv(
            os.path.join(self.mask_dir, self.csv_name),
            sep=',',
        )
Example #2
    def setUp(self):
        """
        Set up a directory with some images to resample
        """
        self.tempdir = TempDirectory()
        self.temp_path = self.tempdir.path
        self.output_dir = os.path.join(self.temp_path, 'out_dir')
        # Start frames meta file
        self.meta_name = 'frames_meta.csv'
        self.frames_meta = aux_utils.make_dataframe()
        # Write images
        self.time_idx = 5
        self.slice_idx = 6
        self.pos_idx = 7
        self.im = 1500 * np.ones((30, 20), dtype=np.uint16)

        for c in range(4):
            for p in range(self.pos_idx, self.pos_idx + 2):
                im_name = aux_utils.get_im_name(
                    channel_idx=c,
                    slice_idx=self.slice_idx,
                    time_idx=self.time_idx,
                    pos_idx=p,
                )
                cv2.imwrite(os.path.join(self.temp_path, im_name),
                            self.im + c * 100)
                self.frames_meta = self.frames_meta.append(
                    aux_utils.parse_idx_from_name(im_name),
                    ignore_index=True,
                )
        # Write metadata
        self.frames_meta.to_csv(
            os.path.join(self.temp_path, self.meta_name),
            sep=',',
        )
Example #3
def write_tile(tile, save_dict, img_id):
    """
    Write tile function that can be called using threading.

    :param np.array tile: one tile
    :param dict save_dict: dict with keys: time_idx, channel_idx, slice_idx,
     pos_idx, image_format, int2str_len and save_dir for generating the
     output file name
    :param str img_id: tile related indices as string
    :return str file_name: name of the saved tile file (without the directory path)
    """

    file_name = aux_utils.get_im_name(
        time_idx=save_dict['time_idx'],
        channel_idx=save_dict['channel_idx'],
        slice_idx=save_dict['slice_idx'],
        pos_idx=save_dict['pos_idx'],
        int2str_len=save_dict['int2str_len'],
        extra_field=img_id,
        ext='.npy',
    )
    op_fname = os.path.join(save_dict['save_dir'], file_name)
    if save_dict['image_format'] == 'zyx' and len(tile.shape) > 2:
        tile = np.transpose(tile, (2, 0, 1))
    np.save(op_fname, tile, allow_pickle=True, fix_imports=True)
    return file_name
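
A minimal usage sketch (not part of the original source) showing how write_tile could be driven from a thread pool: the save_dict keys mirror the ones the function reads above, while the save directory, indices, tile array and img_id string are made up, and write_tile plus its aux_utils dependency are assumed to be importable.

# Hypothetical driver for write_tile; directory, indices and tile are made up.
import os
from concurrent.futures import ThreadPoolExecutor

import numpy as np

save_dict = {
    'time_idx': 0,
    'channel_idx': 1,
    'slice_idx': 2,
    'pos_idx': 3,
    'int2str_len': 3,
    'image_format': 'zyx',
    'save_dir': '/tmp/tile_dir',  # assumed output directory
}
os.makedirs(save_dict['save_dir'], exist_ok=True)
# Map tile id strings (row/col/slice ranges) to tile arrays
tiles = {'r0-5_c0-5_sl0-1': np.zeros((5, 5), dtype=np.float32)}

with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [
        executor.submit(write_tile, tile, save_dict, img_id)
        for img_id, tile in tiles.items()
    ]
    saved_names = [f.result() for f in futures]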
Example #4
    def setUp(self):
        """Set up a dictionary with images"""

        self.tempdir = TempDirectory()
        self.temp_path = self.tempdir.path
        meta_fname = 'frames_meta.csv'
        self.df_columns = ['channel_idx',
                           'slice_idx',
                           'time_idx',
                           'channel_name',
                           'file_name',
                           'pos_idx']
        self.frames_meta = pd.DataFrame(columns=self.df_columns)

        x = np.linspace(-4, 4, 32)
        y = x.copy()
        z = np.linspace(-3, 3, 8)
        xx, yy, zz = np.meshgrid(x, y, z)
        sph = (xx ** 2 + yy ** 2 + zz ** 2)
        sph = (sph <= 8) * (8 - sph)
        sph = (sph / sph.max()) * 255
        sph = sph.astype('uint8')
        self.sph = sph

        self.channel_idx = 1
        self.time_idx = 0
        self.pos_idx = 1
        self.int2str_len = 3

        for z in range(sph.shape[2]):
            im_name = aux_utils.get_im_name(
                channel_idx=1,
                slice_idx=z,
                time_idx=self.time_idx,
                pos_idx=self.pos_idx,
            )
            cv2.imwrite(os.path.join(self.temp_path, im_name), sph[:, :, z])
            self.frames_meta = self.frames_meta.append(
                aux_utils.parse_idx_from_name(im_name, self.df_columns),
                ignore_index=True,
            )

        # Write metadata
        self.frames_meta.to_csv(os.path.join(self.temp_path, meta_fname), sep=',')
        # Write 3D sphere data
        self.sph_fname = os.path.join(
            self.temp_path,
            'im_c001_z000_t000_p001_3d.npy',
        )
        np.save(self.sph_fname, self.sph, allow_pickle=True, fix_imports=True)
        meta_3d = pd.DataFrame.from_dict([{
            'channel_idx': 1,
            'slice_idx': 0,
            'time_idx': 0,
            'channel_name': '3d_test',
            'file_name': 'im_c001_z000_t000_p001_3d.npy',
            'pos_idx': 1,
        }])
        self.meta_3d = meta_3d
    def test_tile_mask_stack(self):
        """Test tile_mask_stack"""

        # create a mask
        mask_dir = os.path.join(self.temp_path, 'mask_dir')
        os.makedirs(mask_dir, exist_ok=True)
        mask_images = np.zeros((15, 11, 5), dtype='bool')
        mask_images[4:12, 4:9, 2:4] = 1

        # timepoints for testing
        mask_meta = []
        for z in range(5):
            for t in range(3):
                cur_im = mask_images[:, :, z]
                im_name = aux_utils.get_im_name(
                    channel_idx=3,
                    slice_idx=z,
                    time_idx=t,
                    pos_idx=self.pos_idx1,
                    ext='.npy',
                )
                np.save(os.path.join(mask_dir, im_name), cur_im)
                cur_meta = {'channel_idx': 3,
                            'slice_idx': z,
                            'time_idx': t,
                            'pos_idx': self.pos_idx1,
                            'file_name': im_name}
                mask_meta.append(cur_meta)
        mask_meta_df = pd.DataFrame.from_dict(mask_meta)
        mask_meta_df.to_csv(os.path.join(mask_dir, 'frames_meta.csv'), sep=',')

        self.tile_inst.pos_ids = [7]

        self.tile_inst.normalize_channels = [None, None, None, False]

        self.tile_inst.tile_mask_stack(mask_dir,
                                       mask_channel=3,
                                       min_fraction=0.5,
                                       mask_depth=3)
        nose.tools.assert_equal(self.tile_inst.mask_depth, 3)

        frames_meta = pd.read_csv(os.path.join(self.tile_inst.tile_dir,
                                               'frames_meta.csv'),
                                  sep=',')
        # only 4 tiles (2 row starts x 1 col start x 2 slices) have
        # >= min_fraction: 4 tiles x 3 channels x 3 tps = 36
        nose.tools.assert_equal(len(frames_meta), 36)
        nose.tools.assert_list_equal(
            frames_meta['row_start'].unique().tolist(),
            [4, 8])
        nose.tools.assert_equal(frames_meta['col_start'].unique().tolist(),
                                [4])
        nose.tools.assert_equal(frames_meta['slice_idx'].unique().tolist(),
                                [2, 3])
        self.assertSetEqual(set(frames_meta.channel_idx.tolist()), {1, 2, 3})
        self.assertSetEqual(set(frames_meta.time_idx.tolist()), {0, 1, 2})
        self.assertSetEqual(set(frames_meta.pos_idx.tolist()), {self.pos_idx1})
Example #6
def test_get_im_name():
    im_name = aux_utils.get_im_name(
        time_idx=1,
        channel_idx=2,
        slice_idx=3,
        pos_idx=4,
        extra_field='hej',
        int2str_len=1,
    )
    nose.tools.assert_equal(im_name, 'im_c2_z3_t1_p4_hej.png')
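
For readers decoding the file names used throughout these snippets, below is a hypothetical re-implementation of the naming convention that the assertion above implies (im_c<c>_z<z>_t<t>_p<p>[_extra]<ext>, zero-padded to int2str_len); the real aux_utils.get_im_name may differ in defaults and edge cases.

# Hypothetical re-implementation of the naming convention asserted above;
# the real aux_utils.get_im_name may differ in defaults and edge cases.
def get_im_name_sketch(time_idx, channel_idx, slice_idx, pos_idx,
                       extra_field=None, int2str_len=3, ext='.png'):
    name = 'im_c{}_z{}_t{}_p{}'.format(
        str(channel_idx).zfill(int2str_len),
        str(slice_idx).zfill(int2str_len),
        str(time_idx).zfill(int2str_len),
        str(pos_idx).zfill(int2str_len),
    )
    if extra_field is not None:
        name += '_{}'.format(extra_field)
    return name + ext

assert get_im_name_sketch(1, 2, 3, 4, extra_field='hej', int2str_len=1) == \
    'im_c2_z3_t1_p4_hej.png'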
Example #7
    def run_prediction(self):
        """Run prediction for entire 2D image or a 3D stack"""

        pos_ids = self.iteration_meta['pos_idx'].unique()
        for idx, pos_idx in enumerate(pos_ids):
            print('Inference idx {}/{}'.format(idx, len(pos_ids)))
            iteration_rows = self.iteration_meta.index[
                self.iteration_meta['pos_idx'] == pos_idx, ].values
            if self.tile_option is None:
                # 2D, 2.5D
                pred_image, target_image, mask_image = self.predict_2d(
                    iteration_rows, )
            else:  # 3D
                pred_image, target_image, mask_image = self.predict_3d(
                    iteration_rows, )
            pred_fnames = []
            for row_idx in iteration_rows:
                cur_row = self.iteration_meta.iloc[row_idx]
                pred_fname = aux_utils.get_im_name(
                    time_idx=cur_row['time_idx'],
                    channel_idx=cur_row['channel_idx'],
                    slice_idx=cur_row['slice_idx'],
                    pos_idx=cur_row['pos_idx'],
                    ext='',
                )
                pred_fnames.append(pred_fname)
            if self.metrics_inst is not None:
                if not self.mask_metrics:
                    mask_image = None
                self.estimate_metrics(
                    target=target_image,
                    prediction=pred_image,
                    pred_fnames=pred_fnames,
                    mask=mask_image,
                )
                del pred_image, target_image

        # Save metrics csv files
        if self.metrics_inst is not None:
            metrics_mapping = {
                'xy': self.df_xy,
                'xz': self.df_xz,
                'yz': self.df_yz,
                'xyz': self.df_xyz,
            }
            for orientation in self.metrics_orientations:
                metrics_df = metrics_mapping[orientation]
                df_name = 'metrics_{}.csv'.format(orientation)
                metrics_df.to_csv(
                    os.path.join(self.pred_dir, df_name),
                    sep=',',
                    index=False,
                )
Example #8
    def test_create_save_mask_otsu(self):
        """test create_save_mask otsu"""
        self.write_mask_data()
        for sl_idx in range(8):
            input_fnames = [
                'im_c001_z00{}_t000_p001.png'.format(sl_idx),
                'im_c002_z00{}_t000_p001.png'.format(sl_idx)
            ]
            input_fnames = [
                os.path.join(self.temp_path, fname) for fname in input_fnames
            ]
            cur_meta = mp_utils.create_save_mask(tuple(input_fnames),
                                                 None,
                                                 str_elem_radius=1,
                                                 mask_dir=self.output_dir,
                                                 mask_channel_idx=3,
                                                 time_idx=self.time_ids,
                                                 pos_idx=self.pos_ids,
                                                 slice_idx=sl_idx,
                                                 int2str_len=3,
                                                 mask_type='otsu',
                                                 mask_ext='.png')
            fname = aux_utils.get_im_name(
                time_idx=self.time_ids,
                channel_idx=3,
                slice_idx=sl_idx,
                pos_idx=self.pos_ids,
            )
            exp_meta = {
                'channel_idx': 3,
                'slice_idx': sl_idx,
                'time_idx': 0,
                'pos_idx': 1,
                'file_name': fname
            }
            nose.tools.assert_dict_equal(cur_meta, exp_meta)

            op_fname = os.path.join(self.output_dir, fname)
            nose.tools.assert_equal(os.path.exists(op_fname), True)

            mask_image = image_utils.read_image(op_fname)
            if mask_image.dtype != bool:
                mask_image = mask_image > 0
            input_image = (
                self.sph_object[:, :, sl_idx],
                self.rect_object[:, :, sl_idx],
            )
            mask_stack = np.stack([
                create_otsu_mask(input_image[0], str_elem_size=1),
                create_otsu_mask(input_image[1], str_elem_size=1)
            ])
            mask_exp = np.any(mask_stack, axis=0)
            numpy.testing.assert_array_equal(mask_image, mask_exp)
Example #9
    def test_create_save_mask_border_map(self):
        """test create_save_mask border weight map"""
        self.write_mask_data()
        for sl_idx in range(1):
            input_fnames = ['im_c001_z00{}_t000_p001.png'.format(sl_idx)]
            input_fnames = [
                os.path.join(self.temp_path, fname) for fname in input_fnames
            ]
            cur_meta = mp_utils.create_save_mask(
                tuple(input_fnames),
                None,
                str_elem_radius=1,
                mask_dir=self.output_dir,
                mask_channel_idx=2,
                time_idx=self.time_ids,
                pos_idx=self.pos_ids,
                slice_idx=sl_idx,
                int2str_len=3,
                mask_type='borders_weight_loss_map',
                mask_ext='.png')
            fname = aux_utils.get_im_name(
                time_idx=self.time_ids,
                channel_idx=2,
                slice_idx=sl_idx,
                pos_idx=self.pos_ids,
            )
            exp_meta = {
                'channel_idx': 2,
                'slice_idx': sl_idx,
                'time_idx': 0,
                'pos_idx': 1,
                'file_name': fname
            }
            nose.tools.assert_dict_equal(cur_meta, exp_meta)

            op_fname = os.path.join(self.output_dir, fname)
            nose.tools.assert_equal(os.path.exists(op_fname), True)
            weight_map = image_utils.read_image(op_fname)
            max_weight_map = np.max(weight_map)
            # weight map between 20, 16 and 44, 16 should be maximum
            # as there is more weight where two objects' boundaries overlap
            y_coord = self.params[0][1]
            for x_coord in range(self.params[0][0] + self.radius,
                                 self.params[1][0] - self.radius):
                distance_near_intersection = weight_map[x_coord, y_coord]
                nose.tools.assert_equal(max_weight_map,
                                        distance_near_intersection)
Example #10
    def setUp(self):
        """
        Set up a dataframe for training table
        """
        # Start frames meta file
        self.meta_name = 'frames_meta.csv'
        self.frames_meta = aux_utils.make_dataframe()
        self.time_ids = [3, 4, 5]
        self.pos_ids = [7, 8, 10, 12, 15]
        self.channel_ids = [0, 1, 2, 3]
        self.slice_ids = [0, 1, 2, 3, 4, 5]
        # Tiles will typically be split into image subsections
        # but it doesn't matter for testing
        for c in self.channel_ids:
            for p in self.pos_ids:
                for z in self.slice_ids:
                    for t in self.time_ids:
                        im_name = aux_utils.get_im_name(
                            channel_idx=c,
                            slice_idx=z,
                            time_idx=t,
                            pos_idx=p,
                        )
                        self.frames_meta = self.frames_meta.append(
                            aux_utils.parse_idx_from_name(im_name),
                            ignore_index=True,
                        )
        self.tiles_meta = aux_utils.sort_meta_by_channel(self.frames_meta)
        self.input_channels = [0, 2]
        self.target_channels = [3]
        self.mask_channels = [1]
        self.split_ratio = {
            'train': 0.6,
            'val': 0.2,
            'test': 0.2,
        }
        # Instantiate class
        self.table_inst = training_table.BaseTrainingTable(
            df_metadata=self.tiles_meta,
            input_channels=self.input_channels,
            target_channels=self.target_channels,
            split_by_column='pos_idx',
            split_ratio=self.split_ratio,
            mask_channels=[1],
            random_seed=42,
        )
Example #11
    def test_adjust_slice_indices(self):
        # First create new frames meta with more slices
        temp_meta = aux_utils.make_dataframe()
        for s in range(10):
            im_name = aux_utils.get_im_name(
                time_idx=2,
                channel_idx=4,
                slice_idx=s,
                pos_idx=6,
            )
            temp_meta = temp_meta.append(
                aux_utils.parse_idx_from_name(im_name, aux_utils.DF_NAMES),
                ignore_index=True,
            )
        self.data_inst.iteration_meta = temp_meta
        self.data_inst.depth = 5
        # This should remove first and last two slices
        self.data_inst.adjust_slice_indices()
        # Original slice ids are 0-9 so after removing margins should be 2-7
        self.assertListEqual(
            self.data_inst.iteration_meta.slice_idx.unique().tolist(),
            [2, 3, 4, 5, 6, 7])
Example #12
    def save_pred_image(self, predicted_image, time_idx, target_channel_idx,
                        pos_idx, slice_idx):
        """
        Save predicted images with image extension given in init.

        :param np.array predicted_image: 2D / 3D predicted image
        :param int time_idx: time index
        :param int target_channel_idx: target / predicted channel index
        :param int pos_idx: FOV / position index
        :param int slice_idx: slice index
        """
        # Write prediction image
        im_name = aux_utils.get_im_name(
            time_idx=time_idx,
            channel_idx=target_channel_idx,
            slice_idx=slice_idx,
            pos_idx=pos_idx,
            ext=self.image_ext,
        )
        file_name = os.path.join(self.pred_dir, im_name)
        im_pred = predicted_image.astype(np.float32)
        if self.image_ext == '.png':
            # Convert to uint16 for now
            if im_pred.max() > im_pred.min():
                im_pred = np.iinfo(np.uint16).max * \
                          (im_pred - im_pred.min()) / \
                          (im_pred.max() - im_pred.min())
            else:
                im_pred = im_pred / im_pred.max() * np.iinfo(np.uint16).max
            im_pred = im_pred.astype(np.uint16)
            cv2.imwrite(file_name, np.squeeze(im_pred))
        elif self.image_ext == '.tif':
            cv2.imwrite(file_name, np.squeeze(im_pred))
        elif self.image_ext == '.npy':
            np.save(file_name, im_pred, allow_pickle=True)
        else:
            raise ValueError(
                'Unsupported file extension: {}'.format(self.image_ext), )
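
As a standalone illustration of the '.png' branch above, this sketch reproduces the min-max rescaling to uint16 with NumPy only; the helper name rescale_to_uint16 is made up.

# Standalone sketch of the min-max rescaling to uint16 used for '.png' output.
import numpy as np

def rescale_to_uint16(im_pred):
    """Linearly map a float image onto [0, 65535] and cast to uint16."""
    im_pred = im_pred.astype(np.float32)
    if im_pred.max() > im_pred.min():
        im_pred = np.iinfo(np.uint16).max * \
                  (im_pred - im_pred.min()) / \
                  (im_pred.max() - im_pred.min())
    else:
        # Constant image: scale by its max (assumed nonzero, as above)
        im_pred = im_pred / im_pred.max() * np.iinfo(np.uint16).max
    return im_pred.astype(np.uint16)

print(rescale_to_uint16(np.array([[0.1, 0.5], [0.9, 1.0]])))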
Example #13
    def test_get_tiled_data(self):
        """Test get_tiled_indices"""

        # no tiles_exist
        tile_meta, tile_indices = self.tile_inst._get_tiled_data()
        nose.tools.assert_equal(tile_indices, None)
        init_df = pd.DataFrame(columns=[
            'channel_idx', 'slice_idx', 'time_idx', 'file_name', 'pos_idx',
            'row_start', 'col_start'
        ])
        pd.testing.assert_frame_equal(tile_meta, init_df)
        # tile exists
        self.tile_inst.tile_stack()
        self.tile_inst.tiles_exist = True
        self.tile_inst.channel_ids = [1, 2]
        tile_meta, _ = self.tile_inst._get_tiled_data()

        exp_tile_meta = []
        for exp_idx in self.exp_tile_indices:
            for z in [16, 17, 18]:
                cur_img_id = 'r{}-{}_c{}-{}_sl{}-{}'.format(
                    exp_idx[0], exp_idx[1], exp_idx[2], exp_idx[3], 0, 3)
                pos1_fname = aux_utils.get_im_name(
                    time_idx=self.time_idx,
                    channel_idx=self.channel_idx,
                    slice_idx=z,
                    pos_idx=self.pos_idx1,
                    extra_field=cur_img_id,
                    ext='.npy',
                )
                pos1_meta = {
                    'channel_idx': self.channel_idx,
                    'slice_idx': z,
                    'time_idx': self.time_idx,
                    'file_name': pos1_fname,
                    'pos_idx': self.pos_idx1,
                    'row_start': exp_idx[0],
                    'col_start': exp_idx[2]
                }
                exp_tile_meta.append(pos1_meta)
                pos2_fname = aux_utils.get_im_name(
                    time_idx=self.time_idx,
                    channel_idx=self.channel_idx,
                    slice_idx=z,
                    pos_idx=self.pos_idx2,
                    extra_field=cur_img_id,
                    ext='.npy',
                )
                pos2_meta = {
                    'channel_idx': self.channel_idx,
                    'slice_idx': z,
                    'time_idx': self.time_idx,
                    'file_name': pos2_fname,
                    'pos_idx': self.pos_idx2,
                    'row_start': exp_idx[0],
                    'col_start': exp_idx[2]
                }
                exp_tile_meta.append(pos2_meta)
        exp_tile_meta_df = pd.DataFrame.from_dict(exp_tile_meta)
        exp_tile_meta_df = exp_tile_meta_df.sort_values(by=['file_name'])
        exp_tile_meta_df.reset_index(drop=True, inplace=True)
        tile_meta = tile_meta.sort_values(by=['file_name'])
        tile_meta.reset_index(drop=True, inplace=True)
        pd.testing.assert_frame_equal(tile_meta, exp_tile_meta_df)
Example #14
    def test_resize_volumes(self):
        """Test resizing volumes"""

        # set up a volume with 5 slices, 2 channels
        slice_ids = [0, 1, 2, 3, 4]
        channel_ids = [2, 3]
        frames_meta = aux_utils.make_dataframe()
        exp_meta_dict = []
        for c in channel_ids:
            for s in slice_ids:
                im_name = aux_utils.get_im_name(
                    channel_idx=c,
                    slice_idx=s,
                    time_idx=self.time_idx,
                    pos_idx=self.pos_idx,
                )
                cv2.imwrite(os.path.join(self.temp_path, im_name),
                            self.im + c * 100)
                frames_meta = frames_meta.append(
                    aux_utils.parse_idx_from_name(im_name),
                    ignore_index=True,
                )
            op_fname = 'im_c00{}_z000_t005_p007_3.3-0.8-1.0.npy'.format(c)
            exp_meta_dict.append({'time_idx': self.time_idx,
                                  'pos_idx': self.pos_idx,
                                  'channel_idx': c,
                                  'slice_idx': 0,
                                  'file_name': op_fname})
        # Write metadata
        frames_meta.to_csv(
            os.path.join(self.temp_path, self.meta_name),
            sep=',',
        )

        scale_factor = [3.3, 0.8, 1.0]
        resize_inst = resize_images.ImageResizer(
            input_dir=self.temp_path,
            output_dir=self.output_dir,
            scale_factor=scale_factor,
        )

        # save all slices in one volume
        resize_inst.resize_volumes()
        saved_meta = pd.read_csv(os.path.join(self.output_dir,
                                              'resized_images',
                                              'frames_meta.csv'))
        del saved_meta['Unnamed: 0']
        exp_meta_df = pd.DataFrame.from_dict(exp_meta_dict)
        pd.testing.assert_frame_equal(saved_meta, exp_meta_df)

        # num_slices_subvolume = 3, save vol chunks
        exp_meta_dict = []
        for c in channel_ids:
            for s in [0, 2]:
                op_fname = 'im_c00{}_z00{}_t005_p007_3.3-0.8-1.0.npy'.format(c,
                                                                             s)
                exp_meta_dict.append({'time_idx': self.time_idx,
                                      'pos_idx': self.pos_idx,
                                      'channel_idx': c,
                                      'slice_idx': s,
                                      'file_name': op_fname})

        resize_inst.resize_volumes(num_slices_subvolume=3)
        saved_meta = pd.read_csv(os.path.join(self.output_dir,
                                              'resized_images',
                                              'frames_meta.csv'))
        del saved_meta['Unnamed: 0']
        exp_meta_df = pd.DataFrame.from_dict(exp_meta_dict)
        pd.testing.assert_frame_equal(saved_meta, exp_meta_df)
Example #15
    def resize_volumes(self, num_slices_subvolume=-1):
        """Down or up sample volumes

        Overlap of one slice across subvolumes

        :param int num_slices_subvolume: num of 2D slices to include in each
         volume. if -1, include all slices
        """

        # assuming slice_ids will be continuous
        num_total_slices = len(self.slice_ids)
        if not isinstance(self.scale_factor, float):
            sc_str = '-'.join(self.scale_factor.astype('str'))
        else:
            sc_str = self.scale_factor

        mp_args = []
        resized_metadata_list = []
        if num_slices_subvolume == -1:
            num_slices_subvolume = len(self.slice_ids)
        num_blocks = np.floor(
            num_total_slices / (num_slices_subvolume - 1)
        ).astype('int')
        for time_idx in self.time_ids:
            for pos_idx in self.pos_ids:
                for channel_idx in self.channel_ids:
                    ff_path = None
                    if self.flat_field_dir is not None:
                        ff_path = os.path.join(
                            self.flat_field_dir,
                            'flat-field_channel-{}.npy'.format(channel_idx)
                        )
                    for block_idx in range(num_blocks):
                        idx = self.slice_ids[0] + \
                              block_idx * (num_slices_subvolume - 1)
                        start_idx = np.maximum(self.slice_ids[0], idx)
                        end_idx = start_idx + num_slices_subvolume
                        if end_idx > self.slice_ids[-1]:
                            end_idx = self.slice_ids[-1] + 1
                            start_idx = end_idx - num_slices_subvolume
                        op_fname = aux_utils.get_im_name(
                            time_idx,
                            channel_idx,
                            start_idx,
                            pos_idx,
                            extra_field=sc_str,
                            ext='.npy',
                        )
                        write_fpath = os.path.join(self.resize_dir, op_fname)
                        mp_args.append((time_idx,
                                        pos_idx,
                                        channel_idx,
                                        start_idx,
                                        end_idx,
                                        self.frames_metadata,
                                        write_fpath,
                                        self.scale_factor,
                                        self.input_dir,
                                        ff_path))
                        cur_metadata = {'time_idx': time_idx,
                                        'pos_idx': pos_idx,
                                        'channel_idx': channel_idx,
                                        'slice_idx': start_idx,
                                        'file_name': op_fname}
                        resized_metadata_list.append(cur_metadata)

        # Multiprocessing of kwargs
        mp_utils.mp_rescale_vol(mp_args, self.num_workers)
        resized_metadata_df = pd.DataFrame.from_dict(resized_metadata_list)
        resized_metadata_df.to_csv(
            os.path.join(self.resize_dir, 'frames_meta.csv'),
            sep=',',
        )

        if num_slices_subvolume == -1:
            slice_ids = self.slice_ids[0]
        else:
            slice_ids = self.slice_ids[0: -1: num_slices_subvolume - 1]

        return slice_ids
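
A minimal sketch (with made-up slice ids and subvolume size) of the start/end index arithmetic used in the block loop above, showing how consecutive subvolumes overlap by one slice.

# Illustrative sketch of the start/end slice computation in resize_volumes;
# slice ids and subvolume size are made up.
import numpy as np

slice_ids = list(range(10))           # assumed contiguous slice ids
num_slices_subvolume = 4
num_blocks = np.floor(
    len(slice_ids) / (num_slices_subvolume - 1)
).astype('int')
for block_idx in range(num_blocks):
    start_idx = slice_ids[0] + block_idx * (num_slices_subvolume - 1)
    end_idx = start_idx + num_slices_subvolume
    if end_idx > slice_ids[-1]:
        end_idx = slice_ids[-1] + 1
        start_idx = end_idx - num_slices_subvolume
    # Consecutive blocks share one slice: (0, 4), (3, 7), (6, 10)
    print(block_idx, start_idx, end_idx)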
    def setUp(self):
        """
        Set up a directory with some images to resample
        """
        self.tempdir = TempDirectory()
        self.temp_path = self.tempdir.path
        self.image_dir = self.temp_path
        self.output_dir = os.path.join(self.temp_path, 'out_dir')
        self.tempdir.makedir(self.output_dir)
        # Start frames meta file
        self.meta_name = 'frames_meta.csv'
        self.frames_meta = aux_utils.make_dataframe()
        # Write images
        self.time_idx = 0
        self.pos_ids = [7, 8, 10]
        self.channel_ids = [0, 1, 2, 3]
        self.slice_ids = [0, 1, 2, 3, 4, 5]
        self.im = 1500 * np.ones((30, 20), dtype=np.uint16)
        self.im[10:20, 5:15] = 3000

        for c in self.channel_ids:
            for p in self.pos_ids:
                for z in self.slice_ids:
                    im_name = aux_utils.get_im_name(
                        channel_idx=c,
                        slice_idx=z,
                        time_idx=self.time_idx,
                        pos_idx=p,
                    )
                    cv2.imwrite(
                        os.path.join(self.image_dir, im_name),
                        self.im + c * 100,
                    )
                    self.frames_meta = self.frames_meta.append(
                        aux_utils.parse_idx_from_name(im_name),
                        ignore_index=True,
                    )
        # Write metadata
        self.frames_meta.to_csv(
            os.path.join(self.image_dir, self.meta_name),
            sep=',',
        )
        # Make input masks
        self.input_mask_channel = 111
        self.input_mask_dir = os.path.join(self.temp_path, 'input_mask_dir')
        self.tempdir.makedir(self.input_mask_dir)
        # Must have at least two foreground classes in mask for weight map to work
        mask = np.zeros((30, 20), dtype=np.uint16)
        mask[5:10, 5:15] = 1
        mask[20:25, 5:10] = 2
        mask_meta = aux_utils.make_dataframe()
        for p in self.pos_ids:
            for z in self.slice_ids:
                im_name = aux_utils.get_im_name(
                    channel_idx=self.input_mask_channel,
                    slice_idx=z,
                    time_idx=self.time_idx,
                    pos_idx=p,
                )
                cv2.imwrite(
                    os.path.join(self.input_mask_dir, im_name),
                    mask,
                )
                mask_meta = mask_meta.append(
                    aux_utils.parse_idx_from_name(im_name),
                    ignore_index=True,
                )
        mask_meta.to_csv(
            os.path.join(self.input_mask_dir, self.meta_name),
            sep=',',
        )
        # Create preprocessing config
        self.pp_config = {
            'output_dir': self.output_dir,
            'input_dir': self.image_dir,
            'channel_ids': [0, 1, 3],
            'num_workers': 4,
            'flat_field': {'estimate': True,
                           'block_size': 2,
                           'correct': True},
            'masks': {'channels': [3],
                      'str_elem_radius': 3,
                      'normalize_im': False},
            'tile': {'tile_size': [10, 10],
                     'step_size': [10, 10],
                     'depths': [1, 1, 1],
                     'mask_depth': 1,
                     'image_format': 'zyx',
                     'normalize_channels': [True, True, True]
                     },
        }
        # Create base config, generated partly from pp_config in script
        self.base_config = {
            'input_dir': self.image_dir,
            'output_dir': self.output_dir,
            'slice_ids': -1,
            'time_ids': -1,
            'pos_ids': -1,
            'channel_ids': self.pp_config['channel_ids'],
            'uniform_struct': True,
            'int2strlen': 3,
            'num_workers': 4,
            'normalize_channels': [True, True, True]
        }
Example #17
def create_save_mask(input_fnames,
                     flat_field_fname,
                     str_elem_radius,
                     mask_dir,
                     mask_channel_idx,
                     time_idx,
                     pos_idx,
                     slice_idx,
                     int2str_len,
                     mask_type,
                     mask_ext,
                     normalize_im=False):

    """
    Create and save mask.
    When more than one channel is used to generate the mask, a mask is
    created for each channel and the masks are combined with a logical OR.

    :param tuple input_fnames: tuple of input fnames with full path
    :param str flat_field_fname: fname of flat field image
    :param int str_elem_radius: size of structuring element used for binary
     opening. str_elem: disk or ball
    :param str mask_dir: dir to save masks
    :param int mask_channel_idx: channel number of mask
    :param int time_idx: time points to use for generating mask
    :param int pos_idx: generate masks for given position / sample ids
    :param int slice_idx: generate masks for given slice ids
    :param int int2str_len: Length of str when converting ints
    :param str mask_type: thresholding type used for masking or str to map to
     masking function
    :param str mask_ext: '.npy' or '.png'. Otsu and unimodal masks can be
     saved as uint8 PNG or NPY files; borders_weight_loss_map masks should be
     saved as float64 NPY to avoid precision loss from scaling to uint8.
    :param bool normalize_im: indicator to normalize image based on z-score or not
    :return dict cur_meta: metadata (indices and file name) for the saved mask
    """
    im_stack = image_utils.read_imstack(
        input_fnames,
        flat_field_fname,
        normalize_im=normalize_im,
    )
    masks = []
    for idx in range(im_stack.shape[-1]):
        im = im_stack[..., idx]
        if mask_type == 'otsu':
            mask = mask_utils.create_otsu_mask(im.astype('float32'), str_elem_radius)
        elif mask_type == 'unimodal':
            mask = mask_utils.create_unimodal_mask(im.astype('float32'), str_elem_radius)
        elif mask_type == 'borders_weight_loss_map':
            mask = mask_utils.get_unet_border_weight_map(im)
        masks += [mask]
    # Border weight map mask is a float mask not binary like otsu or unimodal,
    # so keep it as is (assumes only one image in stack)
    if mask_type == 'borders_weight_loss_map':
        mask = masks[0]
    else:
        masks = np.stack(masks, axis=-1)
        mask = np.any(masks, axis=-1)

    # Create mask name for given slice, time and position
    file_name = aux_utils.get_im_name(
        time_idx=time_idx,
        channel_idx=mask_channel_idx,
        slice_idx=slice_idx,
        pos_idx=pos_idx,
        int2str_len=int2str_len,
        ext=mask_ext,
    )
    if mask_ext == '.npy':
        # Save mask for given channels, mask is 2D
        np.save(os.path.join(mask_dir, file_name),
                mask,
                allow_pickle=True,
                fix_imports=True)
    elif mask_ext == '.png':
        # Convert mask to uint8
        # Border weight map mask is a float mask not binary like otsu or unimodal,
        # so keep it as is
        if mask_type == 'borders_weight_loss_map':
            assert im_stack.shape[-1] == 1
            # Note: Border weight map mask should only be generated from one binary image
        else:
            mask = mask.astype(np.uint8) * np.iinfo(np.uint8).max
        cv2.imwrite(os.path.join(mask_dir, file_name), mask)
    else:
        raise ValueError("mask_ext can be '.npy' or '.png', not {}".format(mask_ext))
    cur_meta = {'channel_idx': mask_channel_idx,
                'slice_idx': slice_idx,
                'time_idx': time_idx,
                'pos_idx': pos_idx,
                'file_name': file_name}
    return cur_meta
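
A hypothetical per-slice driver (not from the original source) showing how the metadata dicts returned by create_save_mask are typically collected into a frames_meta.csv; the directories, index values and the existence of the input images on disk are all assumptions.

# Hypothetical aggregation of create_save_mask outputs into frames_meta.csv;
# input images, directories and index values are assumed / made up.
import os
import pandas as pd

input_dir = '/tmp/input_dir'   # assumed to contain the source images
mask_dir = '/tmp/mask_dir'     # assumed output directory
os.makedirs(mask_dir, exist_ok=True)

mask_meta = []
for slice_idx in range(8):
    input_fnames = (
        os.path.join(input_dir,
                     'im_c001_z{}_t000_p001.png'.format(str(slice_idx).zfill(3))),
    )
    cur_meta = create_save_mask(
        input_fnames=input_fnames,
        flat_field_fname=None,
        str_elem_radius=3,
        mask_dir=mask_dir,
        mask_channel_idx=3,
        time_idx=0,
        pos_idx=1,
        slice_idx=slice_idx,
        int2str_len=3,
        mask_type='otsu',
        mask_ext='.png',
    )
    mask_meta.append(cur_meta)

pd.DataFrame.from_dict(mask_meta).to_csv(
    os.path.join(mask_dir, 'frames_meta.csv'),
    sep=',',
)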
Example #18
    def setUp(self):
        """Set up a directory for mask generation, no flatfield"""

        self.tempdir = TempDirectory()
        self.temp_path = self.tempdir.path
        self.meta_fname = 'frames_meta.csv'
        frames_meta = aux_utils.make_dataframe()

        # create an image with bimodal hist
        x = np.linspace(-4, 4, 32)
        y = x.copy()
        z = np.linspace(-3, 3, 8)
        xx, yy, zz = np.meshgrid(x, y, z)
        sph = (xx**2 + yy**2 + zz**2)
        fg = (sph <= 8) * (8 - sph)
        fg[fg > 1e-8] = (fg[fg > 1e-8] / np.max(fg)) * 127 + 128
        fg = np.around(fg).astype('uint8')
        bg = np.around((sph > 8) * sph).astype('uint8')
        object1 = fg + bg

        # create an image with a rect
        rec = np.zeros(sph.shape)
        rec[3:30, 14:18, 3:6] = 120
        rec[14:18, 3:30, 3:6] = 120

        self.sph_object = object1
        self.rec_object = rec

        self.channel_ids = [1, 2]
        self.time_ids = 0
        self.pos_ids = 1
        self.int2str_len = 3

        for z in range(sph.shape[2]):
            im_name = aux_utils.get_im_name(
                time_idx=self.time_ids,
                channel_idx=1,
                slice_idx=z,
                pos_idx=self.pos_ids,
            )
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                sk_im_io.imsave(
                    os.path.join(self.temp_path, im_name),
                    object1[:, :, z].astype('uint8'),
                )
            frames_meta = frames_meta.append(aux_utils.parse_idx_from_name(
                im_name, aux_utils.DF_NAMES),
                                             ignore_index=True)
        for z in range(rec.shape[2]):
            im_name = aux_utils.get_im_name(
                time_idx=self.time_ids,
                channel_idx=2,
                slice_idx=z,
                pos_idx=self.pos_ids,
            )
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                sk_im_io.imsave(
                    os.path.join(self.temp_path, im_name),
                    rec[:, :, z].astype('uint8'),
                )
            frames_meta = frames_meta.append(aux_utils.parse_idx_from_name(
                im_name, aux_utils.DF_NAMES),
                                             ignore_index=True)
        # Write metadata
        frames_meta.to_csv(os.path.join(self.temp_path, self.meta_fname),
                           sep=',')

        self.output_dir = os.path.join(self.temp_path, 'mask_dir')
        self.mask_gen_inst = MaskProcessor(input_dir=self.temp_path,
                                           output_dir=self.output_dir,
                                           channel_ids=self.channel_ids)
Example #19
    def setUp(self, mock_model):
        """
        Set up a directory with 3D images
        """
        mock_model.return_value = 'dummy_model'

        self.tempdir = TempDirectory()
        self.temp_path = self.tempdir.path
        self.tempdir.makedir('image_dir')
        self.tempdir.makedir('mask_dir')
        self.tempdir.makedir('model_dir')
        self.image_dir = os.path.join(self.temp_path, 'image_dir')
        self.mask_dir = os.path.join(self.temp_path, 'mask_dir')
        self.model_dir = os.path.join(self.temp_path, 'model_dir')
        # Create a temp image dir
        self.im = np.zeros((10, 10, 8), dtype=np.uint8)
        self.frames_meta = aux_utils.make_dataframe()
        self.time_idx = 2
        self.slice_idx = 0
        for p in range(5):
            for c in range(3):
                im_name = aux_utils.get_im_name(
                    time_idx=self.time_idx,
                    channel_idx=c,
                    slice_idx=self.slice_idx,
                    pos_idx=p,
                    ext='.npy',
                )
                np.save(os.path.join(self.image_dir, im_name),
                        self.im + c * 10,
                        allow_pickle=True,
                        fix_imports=True)
                self.frames_meta = self.frames_meta.append(
                    aux_utils.parse_idx_from_name(im_name, aux_utils.DF_NAMES),
                    ignore_index=True,
                )
        # Write frames meta to image dir too
        self.frames_meta.to_csv(os.path.join(self.image_dir,
                                             'frames_meta.csv'))
        # Save masks and mask meta
        self.mask_meta = aux_utils.make_dataframe()
        self.mask_channel = 50
        # Mask half the image
        mask = np.zeros_like(self.im)
        mask[:5, ...] = 1
        for p in range(5):
            im_name = aux_utils.get_im_name(
                time_idx=self.time_idx,
                channel_idx=self.mask_channel,
                slice_idx=self.slice_idx,
                pos_idx=p,
                ext='.npy',
            )
            np.save(os.path.join(self.mask_dir, im_name), mask)
            self.mask_meta = self.mask_meta.append(
                aux_utils.parse_idx_from_name(im_name, aux_utils.DF_NAMES),
                ignore_index=True,
            )
        # Write frames meta to mask dir too
        self.mask_meta.to_csv(os.path.join(self.mask_dir, 'frames_meta.csv'))
        # Setup model dir
        split_samples = {
            "train": [0, 1],
            "val": [2],
            "test": [3, 4],
        }
        aux_utils.write_json(
            split_samples,
            os.path.join(self.model_dir, 'split_samples.json'),
        )
        # Make configs with fields necessary for 3D regression inference
        self.train_config = {
            'network': {
                'class': 'UNet3D',
                'data_format': 'channels_first',
                'num_filters_per_block': [8, 16],
                'depth': 5,
                'width': 5,
                'height': 5
            },
            'dataset': {
                'split_by_column': 'pos_idx',
                'input_channels': [1],
                'target_channels': [2],
                'model_task': 'regression',
            },
        }
        self.inference_config = {
            'model_dir': self.model_dir,
            'model_fname': 'dummy_weights.hdf5',
            'image_dir': self.image_dir,
            'data_split': 'test',
            'images': {
                'image_format': 'zyx',
                'image_ext': '.png',
            },
            'metrics': {
                'metrics': ['mse'],
                'metrics_orientations': ['xyz'],
            },
            'masks': {
                'mask_dir': self.mask_dir,
                'mask_type': 'metrics',
                'mask_channel': 50,
            },
            'inference_3d': {
                'tile_shape': [5, 5, 5],
                'num_overlap': [1, 1, 1],
                'overlap_operation': 'mean',
            },
        }
        # Instantiate class
        self.infer_inst = image_inference.ImagePredictor(
            train_config=self.train_config,
            inference_config=self.inference_config,
        )
Example #20
    def setUp(self, mock_model):
        """
        Set up a directory with images
        """
        mock_model.return_value = 'dummy_model'

        self.tempdir = TempDirectory()
        self.temp_path = self.tempdir.path
        self.tempdir.makedir('image_dir')
        self.tempdir.makedir('mask_dir')
        self.tempdir.makedir('model_dir')
        self.image_dir = os.path.join(self.temp_path, 'image_dir')
        self.mask_dir = os.path.join(self.temp_path, 'mask_dir')
        self.model_dir = os.path.join(self.temp_path, 'model_dir')
        # Create a temp image dir
        self.im = np.zeros((10, 16), dtype=np.uint8)
        self.frames_meta = aux_utils.make_dataframe()
        self.time_idx = 2
        for p in range(5):
            for c in range(3):
                for z in range(6):
                    im_name = aux_utils.get_im_name(
                        time_idx=self.time_idx,
                        channel_idx=c,
                        slice_idx=z,
                        pos_idx=p,
                    )
                    cv2.imwrite(os.path.join(self.image_dir, im_name),
                                self.im + c * 10)
                    self.frames_meta = self.frames_meta.append(
                        aux_utils.parse_idx_from_name(im_name,
                                                      aux_utils.DF_NAMES),
                        ignore_index=True,
                    )
        # Write frames meta to image dir too
        self.frames_meta.to_csv(os.path.join(self.image_dir,
                                             'frames_meta.csv'))
        # Save masks and mask meta
        self.mask_meta = aux_utils.make_dataframe()
        self.mask_channel = 50
        for p in range(5):
            for z in range(6):
                im_name = aux_utils.get_im_name(
                    time_idx=self.time_idx,
                    channel_idx=self.mask_channel,
                    slice_idx=z,
                    pos_idx=p,
                )
                cv2.imwrite(os.path.join(self.mask_dir, im_name), self.im + 1)
                self.mask_meta = self.mask_meta.append(
                    aux_utils.parse_idx_from_name(im_name, aux_utils.DF_NAMES),
                    ignore_index=True,
                )
        # Write frames meta to mask dir too
        self.mask_meta.to_csv(os.path.join(self.mask_dir, 'frames_meta.csv'))
        # Setup model dir
        split_samples = {
            "train": [0, 1],
            "val": [2],
            "test": [3, 4],
        }
        aux_utils.write_json(
            split_samples,
            os.path.join(self.model_dir, 'split_samples.json'),
        )
        # Make configs with fields necessary for 2.5D segmentation inference
        self.train_config = {
            'network': {
                'class': 'UNetStackTo2D',
                'data_format': 'channels_first',
                'depth': 5,
                'width': 10,
                'height': 10
            },
            'dataset': {
                'split_by_column': 'pos_idx',
                'input_channels': [1],
                'target_channels': [self.mask_channel],
                'model_task': 'segmentation',
            },
        }
        self.inference_config = {
            'model_dir': self.model_dir,
            'model_fname': 'dummy_weights.hdf5',
            'image_dir': self.image_dir,
            'data_split': 'test',
            'images': {
                'image_format': 'zyx',
                'image_ext': '.png',
            },
            'metrics': {
                'metrics': ['dice'],
                'metrics_orientations': ['xy'],
            },
            'masks': {
                'mask_dir': self.mask_dir,
                'mask_type': 'target',
                'mask_channel': 50,
            }
        }
        # Instantiate class
        self.infer_inst = image_inference.ImagePredictor(
            train_config=self.train_config,
            inference_config=self.inference_config,
        )
    def setUp(self):
        """
        Set up a directory with some images to generate frames_meta.csv for
        """
        self.tempdir = TempDirectory()
        self.temp_dir = self.tempdir.path
        self.model_dir = os.path.join(self.temp_dir, 'model_dir')
        self.pred_dir = os.path.join(self.model_dir, 'predictions')
        self.image_dir = os.path.join(self.temp_dir, 'image_dir')
        self.tempdir.makedir(self.model_dir)
        self.tempdir.makedir(self.pred_dir)
        self.tempdir.makedir(self.image_dir)
        # Write images
        self.time_idx = 5
        self.pos_idx = 7
        self.im = 1500 * np.ones((30, 20), dtype=np.uint16)
        im_add = np.zeros((30, 20), dtype=np.uint16)
        im_add[15:, :] = 10
        self.ext = '.tif'
        # Start frames meta file
        self.meta_name = 'frames_meta.csv'
        self.frames_meta = aux_utils.make_dataframe()

        for c in range(3):
            for z in range(5, 10):
                im_name = aux_utils.get_im_name(
                    channel_idx=c,
                    slice_idx=z,
                    time_idx=self.time_idx,
                    pos_idx=self.pos_idx,
                    ext=self.ext,
                )
                cv2.imwrite(os.path.join(self.image_dir, im_name), self.im)
                if c == 2:
                    norm_im = normalize.zscore(self.im + im_add).astype(np.float32)
                    cv2.imwrite(
                        os.path.join(self.pred_dir, im_name),
                        norm_im,
                    )
                self.frames_meta = self.frames_meta.append(
                    aux_utils.parse_idx_from_name(im_name),
                    ignore_index=True,
                )
        # Write metadata
        self.frames_meta.to_csv(
            os.path.join(self.image_dir, self.meta_name),
            sep=',',
        )
        # Write as test metadata in model dir too
        self.frames_meta.to_csv(
            os.path.join(self.model_dir, 'test_metadata.csv'),
            sep=',',
        )
        # Write split samples
        split_idx_fname = os.path.join(self.model_dir, 'split_samples.json')
        split_samples = {'test': [5, 6, 7, 8, 9]}
        aux_utils.write_json(split_samples, split_idx_fname)
        # Write config in model dir
        config = {
            'dataset': {
                'input_channels': [0, 1],
                'target_channels': [2],
                'split_by_column': 'slice_idx'
            },
            'network': {}
        }
        config_name = os.path.join(self.model_dir, 'config.yml')
        with open(config_name, 'w') as outfile:
            yaml.dump(config, outfile, default_flow_style=False)
Example #22
    def setUp(self):
        """Set up a dir for tiling with flatfield"""

        self.tempdir = TempDirectory()
        self.temp_path = self.tempdir.path
        # Start frames meta file
        self.meta_name = 'frames_meta.csv'
        frames_meta = aux_utils.make_dataframe()
        # Write images
        self.im = 127 * np.ones((15, 11), dtype=np.uint8)
        self.im2 = 234 * np.ones((15, 11), dtype=np.uint8)
        self.channel_idx = 1
        self.time_idx = 5
        self.pos_idx1 = 7
        self.pos_idx2 = 8
        self.int2str_len = 3

        # Write test images with 5 z and 2 pos idx
        for z in range(15, 20):
            im_name = aux_utils.get_im_name(
                channel_idx=self.channel_idx,
                slice_idx=z,
                time_idx=self.time_idx,
                pos_idx=self.pos_idx1,
            )
            cv2.imwrite(
                os.path.join(self.temp_path, im_name),
                self.im,
            )
            frames_meta = frames_meta.append(
                aux_utils.parse_idx_from_name(im_name),
                ignore_index=True,
            )

        for z in range(15, 20):
            im_name = aux_utils.get_im_name(
                channel_idx=self.channel_idx,
                slice_idx=z,
                time_idx=self.time_idx,
                pos_idx=self.pos_idx2,
            )
            cv2.imwrite(
                os.path.join(self.temp_path, im_name),
                self.im2,
            )
            frames_meta = frames_meta.append(
                aux_utils.parse_idx_from_name(im_name),
                ignore_index=True,
            )

        # Write metadata
        frames_meta.to_csv(
            os.path.join(self.temp_path, self.meta_name),
            sep=',',
        )
        # Add flatfield
        self.flat_field_dir = os.path.join(self.temp_path, 'ff_dir')
        self.tempdir.makedir('ff_dir')
        self.ff_im = 4. * np.ones((15, 11))
        self.ff_name = os.path.join(
            self.flat_field_dir,
            'flat-field_channel-1.npy',
        )
        np.save(self.ff_name, self.ff_im, allow_pickle=True, fix_imports=True)
        # Instantiate tiler class
        self.output_dir = os.path.join(self.temp_path, 'tile_dir')
        self.tile_inst = tile_images.ImageTilerUniform(
            input_dir=self.temp_path,
            output_dir=self.output_dir,
            tile_size=[5, 5],
            step_size=[4, 4],
            depths=3,
            channel_ids=[1],
            normalize_channels=[True],
            flat_field_dir=self.flat_field_dir,
        )
        exp_fnames = [
            'im_c001_z015_t005_p007.png', 'im_c001_z016_t005_p007.png',
            'im_c001_z017_t005_p007.png'
        ]
        self.exp_fnames = [
            os.path.join(self.temp_path, fname) for fname in exp_fnames
        ]
        self.exp_tile_indices = [
            [0, 5, 0, 5],
            [0, 5, 4, 9],
            [0, 5, 6, 11],
            [10, 15, 0, 5],
            [10, 15, 4, 9],
            [10, 15, 6, 11],
            [4, 9, 0, 5],
            [4, 9, 4, 9],
            [4, 9, 6, 11],
            [8, 13, 0, 5],
            [8, 13, 4, 9],
            [8, 13, 6, 11],
        ]

        # create a mask
        mask_dir = os.path.join(self.temp_path, 'mask_dir')
        os.makedirs(mask_dir, exist_ok=True)
        mask_images = np.zeros((15, 11, 5), dtype='bool')
        mask_images[4:12, 4:9, 2:4] = 1

        # write mask images and add meta to frames_meta
        self.mask_channel = 3
        mask_meta = []
        for z in range(5):
            cur_im = mask_images[:, :, z]
            im_name = aux_utils.get_im_name(
                channel_idx=3,
                slice_idx=z + 15,
                time_idx=self.time_idx,
                pos_idx=self.pos_idx1,
                ext='.npy',
            )
            np.save(os.path.join(mask_dir, im_name), cur_im)
            cur_meta = {
                'channel_idx': 3,
                'slice_idx': z + 15,
                'time_idx': self.time_idx,
                'pos_idx': self.pos_idx1,
                'file_name': im_name
            }
            mask_meta.append(cur_meta)
        mask_meta_df = pd.DataFrame.from_dict(mask_meta)
        mask_meta_df.to_csv(os.path.join(mask_dir, 'frames_meta.csv'), sep=',')
        self.mask_dir = mask_dir

        exp_tile_indices = [[0, 5, 0, 5], [0, 5, 4, 9], [0, 5, 6, 11],
                            [10, 15, 0, 5], [10, 15, 4, 9], [10, 15, 6, 11],
                            [4, 9, 0, 5], [4, 9, 4, 9], [4, 9, 6, 11],
                            [8, 13, 0, 5], [8, 13, 4, 9], [8, 13, 6, 11]]
        self.exp_tile_indices = exp_tile_indices
Example #23
def run_prediction(model_dir,
                   image_dir,
                   gpu_ids,
                   gpu_mem_frac,
                   model_fname=None,
                   metrics=None,
                   test_data=True,
                   ext='.tif',
                   save_figs=False,
                   normalize_im=False):
    """
    Predict images given model + weights.
    If the test_data flag is set to True, the test indices in
    split_samples.json file in model directory will be predicted
    Otherwise, all images in image directory will be predicted.
    It will load the config.yml file save in model_dir to reconstruct the model.
    Predictions are converted to uint16 and saved as png as default, but can
    also be saved as is in .npy format.
    If saving figures, it assumes that input as well as target channels are
    present in image_dir.

    :param str model_dir: Model directory
    :param str image_dir: Directory containing images for inference
    :param int gpu_ids: GPU ID to use for session
    :param float gpu_mem_frac: What fraction of GPU memory to use
    :param str model_fname: Model weights file name (in model dir)
    :param str metrics: String or list thereof of train/metrics.py functions
        to be computed during inference
    :param bool test_data: Use test indices from metadata, else use all
    :param str ext: File extension for inference output
    :param bool save_figs: Save plots of input/target/prediction
    """
    if gpu_ids >= 0:
        sess = train_utils.set_keras_session(gpu_ids=gpu_ids,
                                             gpu_mem_frac=gpu_mem_frac)
    # Load config file
    config_name = os.path.join(model_dir, 'config.yml')
    with open(config_name, 'r') as f:
        config = yaml.safe_load(f)
    # Load frames metadata and determine indices
    network_config = config['network']
    dataset_config = config['dataset']
    trainer_config = config['trainer']
    frames_meta = pd.read_csv(
        os.path.join(image_dir, 'frames_meta.csv'),
        index_col=0,
    )
    test_tile_meta = pd.read_csv(
        os.path.join(model_dir, 'test_metadata.csv'),
        index_col=0,
    )
    # TODO: generate test_frames_meta.csv together with tile csv during training
    test_frames_meta_filename = os.path.join(
        model_dir,
        'test_frames_meta.csv',
    )
    if metrics is not None:
        if isinstance(metrics, str):
            metrics = [metrics]
        metrics_cls = train_utils.get_metrics(metrics)
    else:
        metrics_cls = metrics
    loss = trainer_config['loss']
    loss_cls = train_utils.get_loss(loss)
    split_idx_name = dataset_config['split_by_column']
    K.set_image_data_format(network_config['data_format'])
    if test_data:
        idx_fname = os.path.join(model_dir, 'split_samples.json')
        try:
            split_samples = aux_utils.read_json(idx_fname)
            test_ids = split_samples['test']
        except FileNotFoundError:
            print("No split_samples file. Will predict all images in dir.")
            test_ids = np.unique(frames_meta[split_idx_name])
    else:
        test_ids = np.unique(frames_meta[split_idx_name])

    # Find other indices to iterate over than split index name
    # E.g. if split is position, we also need to iterate over time and slice
    metadata_ids = {split_idx_name: test_ids}
    iter_ids = ['slice_idx', 'pos_idx', 'time_idx']
    for idx_name in iter_ids:
        if idx_name != split_idx_name:
            metadata_ids[idx_name] = np.unique(test_tile_meta[idx_name])

    # create empty dataframe for test image metadata
    if metrics is not None:
        test_frames_meta = pd.DataFrame(
            columns=frames_meta.columns.values.tolist() + metrics, )
    else:
        test_frames_meta = pd.DataFrame(
            columns=frames_meta.columns.values.tolist())
    # Get model weight file name, if none, load latest saved weights
    if model_fname is None:
        fnames = [f for f in os.listdir(model_dir) if f.endswith('.hdf5')]
        assert len(fnames) > 0, 'No weight files found in model dir'
        fnames = natsort.natsorted(fnames)
        model_fname = fnames[-1]
    weights_path = os.path.join(model_dir, model_fname)

    # Create image subdirectory to write predicted images
    pred_dir = os.path.join(model_dir, 'predictions')
    os.makedirs(pred_dir, exist_ok=True)
    target_channel = dataset_config['target_channels'][0]
    # If saving figures, create another subdirectory to predictions
    if save_figs:
        fig_dir = os.path.join(pred_dir, 'figures')
        os.makedirs(fig_dir, exist_ok=True)

    # If network depth is > 1, determine depth margins for +-z
    depth = 1
    if 'depth' in network_config:
        depth = network_config['depth']

    # Get input channel
    # TODO: Add multi channel support once such models are tested
    input_channel = dataset_config['input_channels'][0]
    assert isinstance(input_channel, int),\
        "Only supporting single input channel for now"
    # Get data format
    data_format = 'channels_first'
    if 'data_format' in network_config:
        data_format = network_config['data_format']
    # Load model with predict = True
    model = inference.load_model(
        network_config=network_config,
        model_fname=weights_path,
        predict=True,
    )
    model.summary()
    optimizer = trainer_config['optimizer']['name']
    model.compile(loss=loss_cls, optimizer=optimizer, metrics=metrics_cls)
    # Iterate over all indices for test data
    for time_idx in metadata_ids['time_idx']:
        for pos_idx in metadata_ids['pos_idx']:
            for slice_idx in metadata_ids['slice_idx']:
                # TODO: Add flatfield support
                im_stack = preprocess_imstack(frames_metadata=frames_meta,
                                              input_dir=image_dir,
                                              depth=depth,
                                              time_idx=time_idx,
                                              channel_idx=input_channel,
                                              slice_idx=slice_idx,
                                              pos_idx=pos_idx,
                                              normalize_im=normalize_im)
                # Crop image shape to nearest factor of two
                im_stack = image_utils.crop2base(im_stack)
                # Change image stack format to zyx
                im_stack = np.transpose(im_stack, [2, 0, 1])
                if depth == 1:
                    # Remove singular z dimension for 2D image
                    im_stack = np.squeeze(im_stack)
                # Add channel dimension
                if data_format == 'channels_first':
                    im_stack = im_stack[np.newaxis, ...]
                else:
                    im_stack = im_stack[..., np.newaxis]
                # add batch dimensions
                im_stack = im_stack[np.newaxis, ...]
                # Predict on large image
                start = time.time()
                im_pred = inference.predict_large_image(
                    model=model,
                    input_image=im_stack,
                )
                print("Inference time:", time.time() - start)
                # Write prediction image
                im_name = aux_utils.get_im_name(
                    time_idx=time_idx,
                    channel_idx=input_channel,
                    slice_idx=slice_idx,
                    pos_idx=pos_idx,
                    ext=ext,
                )
                file_name = os.path.join(pred_dir, im_name)
                if ext == '.png':
                    # Rescale to uint16 range
                    im_pred = (2 ** 16 - 1) * (im_pred - im_pred.min()) / \
                              (im_pred.max() - im_pred.min())
                    im_pred = im_pred.astype(np.uint16)
                    cv2.imwrite(file_name, np.squeeze(im_pred))
                elif ext == '.tif':
                    # Convert to float32 and remove batch dimension
                    im_pred = im_pred.astype(np.float32)
                    cv2.imwrite(file_name, np.squeeze(im_pred))
                elif ext == '.npy':
                    np.save(file_name, im_pred, allow_pickle=True)
                else:
                    raise ValueError('Unsupported file extension')

                # assuming target and predicted images are always 2D for now
                # Load target
                meta_idx = aux_utils.get_meta_idx(
                    frames_meta,
                    time_idx,
                    target_channel,
                    slice_idx,
                    pos_idx,
                )
                # get a single row of frame meta data
                test_frames_meta_row = frames_meta.loc[meta_idx].copy()
                im_target = preprocess_imstack(
                    frames_metadata=frames_meta,
                    input_dir=image_dir,
                    depth=1,
                    time_idx=time_idx,
                    channel_idx=target_channel,
                    slice_idx=slice_idx,
                    pos_idx=pos_idx,
                )
                im_target = image_utils.crop2base(im_target)
                # TODO: Add image_format option to network config
                # Change image stack format to zyx
                im_target = np.transpose(im_target, [2, 0, 1])
                if depth == 1:
                    # Remove singular z dimension for 2D image
                    im_target = np.squeeze(im_target)
                # Add channel dimension
                if data_format == 'channels_first':
                    im_target = im_target[np.newaxis, ...]
                else:
                    im_target = im_target[..., np.newaxis]
                # add batch dimensions
                im_target = im_target[np.newaxis, ...]

                if metrics is not None:
                    metric_vals = model.evaluate(x=im_pred, y=im_target)
                    for metric, metric_val in zip([loss] + metrics, metric_vals):
                        test_frames_meta_row[metric] = metric_val

                test_frames_meta = test_frames_meta.append(
                    test_frames_meta_row,
                    ignore_index=True,
                )
                # Save figures if specified
                if save_figs:
                    # save predicted images assumes 2D
                    if depth > 1:
                        im_stack = im_stack[..., depth // 2, :, :]
                        im_target = im_target[0, ...]
                    plot_utils.save_predicted_images(input_batch=im_stack,
                                                     target_batch=im_target,
                                                     pred_batch=im_pred,
                                                     output_dir=fig_dir,
                                                     output_fname=im_name[:-4],
                                                     ext='jpg',
                                                     clip_limits=1,
                                                     font_size=15)

    # Save metrics as csv
    test_frames_meta.to_csv(test_frames_meta_filename, sep=",")
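A minimal usage sketch for run_prediction; the paths and the metric name below are hypothetical placeholders rather than values taken from this repository:

# Hypothetical example call; paths and metric name are placeholders
run_prediction(
    model_dir='/path/to/model_dir',    # contains config.yml and .hdf5 weights
    image_dir='/path/to/image_dir',    # contains input images and frames_meta.csv
    gpu_ids=0,
    gpu_mem_frac=0.5,
    metrics=['mae'],                   # assumed to be defined in train/metrics.py
    test_data=True,
    ext='.tif',
    save_figs=False,
    normalize_im=False,
)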
Exemple #24
0
    def setUp(self):
        """Set up a dir for tiling with flatfield"""

        self.tempdir = TempDirectory()
        self.temp_path = self.tempdir.path
        # Start frames meta file
        self.meta_name = 'frames_meta.csv'
        frames_meta = aux_utils.make_dataframe()
        self.im = 127 * np.ones((15, 11), dtype=np.uint8)
        self.im2 = 234 * np.ones((15, 11), dtype=np.uint8)
        self.int2str_len = 3
        self.channel_idx = [1, 2]
        self.pos_idx1 = 7
        self.pos_idx2 = 8

        # write pos1 with 3 time points and 5 slices
        for z in range(5):
            for t in range(3):
                for c in self.channel_idx:
                    im_name = aux_utils.get_im_name(
                        channel_idx=c,
                        slice_idx=z,
                        time_idx=t,
                        pos_idx=self.pos_idx1,
                    )
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        sk_im_io.imsave(
                            os.path.join(self.temp_path, im_name),
                            self.im,
                        )
                    frames_meta = frames_meta.append(
                        aux_utils.parse_idx_from_name(im_name),
                        ignore_index=True,
                    )
        # write pos2 with 2 time points and 3 slices
        for z in range(3):
            for t in range(2):
                for c in self.channel_idx:
                    im_name = aux_utils.get_im_name(
                        channel_idx=c,
                        slice_idx=z,
                        time_idx=t,
                        pos_idx=self.pos_idx2,
                    )
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        sk_im_io.imsave(
                            os.path.join(self.temp_path, im_name),
                            self.im,
                        )
                    frames_meta = frames_meta.append(
                        aux_utils.parse_idx_from_name(im_name),
                        ignore_index=True,
                    )

        # Write metadata
        frames_meta.to_csv(os.path.join(self.temp_path, self.meta_name),
                           sep=',',)
        # Instantiate tiler class
        self.output_dir = os.path.join(self.temp_path, 'tile_dir')

        self.tile_inst = tile_images.ImageTilerNonUniform(
            input_dir=self.temp_path,
            output_dir=self.output_dir,
            tile_size=[5, 5],
            step_size=[4, 4],
            depths=3,
            channel_ids=[1, 2],
            normalize_channels=[False, True]
        )
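This fixture is deliberately non-uniform: position 7 has 3 time points x 5 slices, position 8 has 2 time points x 3 slices. The tiler's nested_id_dict is presumably keyed time -> channel -> position -> slice list, mirroring the exp_nested_id_dict used in the mask test below; a sketch of what that would contain here, stated as an assumption rather than a guaranteed property of ImageTilerNonUniform:

# Assumed layout: {time_idx: {channel_idx: {pos_idx: [slice indices]}}}
expected_nested_ids = {}
for t in range(3):
    expected_nested_ids[t] = {}
    for c in [1, 2]:
        pos_dict = {7: list(range(5))}    # pos 7 is present at all 3 time points
        if t < 2:
            pos_dict[8] = list(range(3))  # pos 8 only exists at time points 0 and 1
        expected_nested_ids[t][c] = pos_dict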
Exemple #25
0
    def test_generate_masks_nonuni(self):
        """Test generate_masks with non-uniform structure"""

        rec = self.rec_object[:, :, 3:6]
        channel_ids = 0
        time_ids = 0
        pos_ids = [1, 2]
        frames_meta = aux_utils.make_dataframe()

        for z in range(self.sph_object.shape[2]):
            im_name = aux_utils.get_im_name(
                time_idx=time_ids,
                channel_idx=channel_ids,
                slice_idx=z,
                pos_idx=pos_ids[0],
            )
            sk_im_io.imsave(os.path.join(self.temp_path, im_name),
                            self.sph_object[:, :, z].astype('uint8'))
            frames_meta = frames_meta.append(aux_utils.parse_idx_from_name(
                im_name, aux_utils.DF_NAMES),
                                             ignore_index=True)
        for z in range(rec.shape[2]):
            im_name = aux_utils.get_im_name(
                time_idx=time_ids,
                channel_idx=channel_ids,
                slice_idx=z,
                pos_idx=pos_ids[1],
            )
            sk_im_io.imsave(os.path.join(self.temp_path, im_name),
                            rec[:, :, z].astype('uint8'))
            frames_meta = frames_meta.append(aux_utils.parse_idx_from_name(
                im_name, aux_utils.DF_NAMES),
                                             ignore_index=True)
        # Write metadata
        frames_meta.to_csv(os.path.join(self.temp_path, self.meta_fname),
                           sep=',')

        self.output_dir = os.path.join(self.temp_path, 'mask_dir')
        mask_gen_inst = MaskProcessor(input_dir=self.temp_path,
                                      output_dir=self.output_dir,
                                      channel_ids=channel_ids,
                                      uniform_struct=False)
        exp_nested_id_dict = {
            0: {
                0: {
                    1: [0, 1, 2, 3, 4, 5, 6, 7],
                    2: [0, 1, 2]
                }
            }
        }
        numpy.testing.assert_array_equal(mask_gen_inst.nested_id_dict[0][0][1],
                                         exp_nested_id_dict[0][0][1])
        numpy.testing.assert_array_equal(mask_gen_inst.nested_id_dict[0][0][2],
                                         exp_nested_id_dict[0][0][2])

        mask_gen_inst.generate_masks(str_elem_radius=1)

        frames_meta = pd.read_csv(
            os.path.join(mask_gen_inst.get_mask_dir(), 'frames_meta.csv'),
            index_col=0,
        )
        # pos1: 8 slices, pos2: 3 slices
        exp_len = 8 + 3
        nose.tools.assert_equal(len(frames_meta), exp_len)
        mask_fnames = frames_meta['file_name'].tolist()
        exp_mask_fnames = [
            'im_c001_z000_t000_p001.npy', 'im_c001_z000_t000_p002.npy',
            'im_c001_z001_t000_p001.npy', 'im_c001_z001_t000_p002.npy',
            'im_c001_z002_t000_p001.npy', 'im_c001_z002_t000_p002.npy',
            'im_c001_z003_t000_p001.npy', 'im_c001_z004_t000_p001.npy',
            'im_c001_z005_t000_p001.npy', 'im_c001_z006_t000_p001.npy',
            'im_c001_z007_t000_p001.npy'
        ]
        nose.tools.assert_list_equal(mask_fnames, exp_mask_fnames)
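The hard-coded list above could equivalently be generated with the same naming helper used for the input images; a sketch, assuming get_im_name zero-pads indices to three digits as in the names listed (the generated masks use channel index 1 here):

exp_mask_fnames = sorted(
    aux_utils.get_im_name(
        channel_idx=1,
        slice_idx=z,
        time_idx=0,
        pos_idx=p,
        ext='.npy',
    )
    for p, n_slices in [(1, 8), (2, 3)]  # pos 1: 8 slices, pos 2: 3 slices
    for z in range(n_slices)
)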
Exemple #26
0
    def test_tile_remaining_channels(self):
        """Test tile_remaining_channels"""

        # tile channel 1
        nested_id_dict_copy = copy.deepcopy(self.tile_inst.nested_id_dict)
        ch0_ids = {}
        for tp_idx, tp_dict in self.tile_inst.nested_id_dict.items():
            for ch_idx, ch_dict in tp_dict.items():
                if ch_idx == 1:
                    ch0_dict = {ch_idx: ch_dict}
                    del nested_id_dict_copy[tp_idx][ch_idx]
            ch0_ids[tp_idx] = ch0_dict

        ch0_meta_df = self.tile_inst.tile_first_channel(ch0_ids, 3)
        # tile channel 2
        self.tile_inst.tile_remaining_channels(nested_id_dict_copy,
                                               tiled_ch_id=1,
                                               cur_meta_df=ch0_meta_df)
        frames_meta = pd.read_csv(os.path.join(self.tile_inst.tile_dir,
                                               'frames_meta.csv'),
                                  sep=',')
        # Build the expected meta df, which is a concat of the first-channel df
        # and the second-channel df. Note that the concatenated df retains the
        # original indices, so expected rows are matched by their index columns
        # below rather than by position.
        exp_meta = []
        for row in [0, 4, 8, 10]:
            for col in [0, 4, 6]:
                for z in [1, 2, 3]:
                    for t in [0, 1, 2]:
                        for c in self.channel_idx:
                            fname = aux_utils.get_im_name(
                                channel_idx=c,
                                slice_idx=z,
                                time_idx=t,
                                pos_idx=7,
                                ext='.npy',
                            )
                            tile_id = '_r{}-{}_c{}-{}_sl0-3'.format(row, row+5,
                                                                    col, col+5)
                            fname = fname.split('.')[0] + tile_id + '.npy'
                            cur_meta = {'channel_idx': c,
                                        'slice_idx': z,
                                        'time_idx': t,
                                        'file_name': fname,
                                        'pos_idx': 7,
                                        'row_start': row,
                                        'col_start': col}
                            exp_meta.append(cur_meta)
                for t in [0, 1]:
                    for c in self.channel_idx:
                        fname = aux_utils.get_im_name(
                            channel_idx=c,
                            slice_idx=1,
                            time_idx=t,
                            pos_idx=8,
                            ext='.npy',
                        )
                        tile_id = '_r{}-{}_c{}-{}_sl0-3'.format(row, row + 5,
                                                                col, col + 5)
                        fname = fname.split('.')[0] + tile_id + '.npy'
                        cur_meta = {'channel_idx': c,
                                    'slice_idx': 1,
                                    'time_idx': t,
                                    'file_name': fname,
                                    'pos_idx': 8,
                                    'row_start': row,
                                    'col_start': col}
                        exp_meta.append(cur_meta)
        exp_meta_df = pd.DataFrame.from_dict(exp_meta)
        frames_meta = frames_meta.sort_values(by=['file_name'])
        nose.tools.assert_equal(len(exp_meta_df), len(frames_meta))

        for i in range(len(frames_meta)):
            act_row = frames_meta.loc[i]
            row_idx = ((exp_meta_df['channel_idx'] == act_row['channel_idx']) &
                       (exp_meta_df['slice_idx'] == act_row['slice_idx']) &
                       (exp_meta_df['time_idx'] == act_row['time_idx']) &
                       (exp_meta_df['pos_idx'] == act_row['pos_idx']) &
                       (exp_meta_df['row_start'] == act_row['row_start']) &
                       (exp_meta_df['col_start'] == act_row['col_start']))
            exp_row = exp_meta_df.loc[row_idx]
            nose.tools.assert_equal(len(exp_row), 1)
            np.testing.assert_array_equal(act_row['file_name'],
                                          exp_row['file_name'])
Exemple #27
0
def compute_metrics(model_dir,
                    image_dir,
                    metrics_list,
                    orientations_list,
                    test_data=True):
    """
    Compute specified metrics for given orientations for predictions, which
    are assumed to be stored in model_dir/predictions. Targets are stored in
    image_dir.
    Writes metrics csv files for each orientation in model_dir/predictions.

    :param str model_dir: Assumed to contain config, split_samples.json and
        subdirectory predictions/
    :param str image_dir: Directory containing target images with frames_meta.csv
    :param list metrics_list: See inference/evaluation_metrics.py for options
    :param list orientations_list: Any subset of {xy, xz, yz, xyz}
        (see evaluation_metrics)
    :param bool test_data: Uses test indices in split_samples.json,
        otherwise all indices
    """
    # Load config file
    config_name = os.path.join(model_dir, 'config.yml')
    with open(config_name, 'r') as f:
        config = yaml.safe_load(f)
    # Load frames metadata and determine indices
    frames_meta = pd.read_csv(os.path.join(image_dir, 'frames_meta.csv'))

    if isinstance(metrics_list, str):
        metrics_list = [metrics_list]
    metrics_inst = metrics.MetricsEstimator(metrics_list=metrics_list)

    split_idx_name = config['dataset']['split_by_column']
    if test_data:
        idx_fname = os.path.join(model_dir, 'split_samples.json')
        try:
            split_samples = aux_utils.read_json(idx_fname)
            test_ids = split_samples['test']
        except FileNotFoundError:
            print("No split_samples file. Will compute metrics for all images in dir.")
            test_ids = np.unique(frames_meta[split_idx_name])
    else:
        test_ids = np.unique(frames_meta[split_idx_name])

    # Find other indices to iterate over than split index name
    # E.g. if split is position, we also need to iterate over time and slice
    test_meta = pd.read_csv(os.path.join(model_dir, 'test_metadata.csv'))
    metadata_ids = {split_idx_name: test_ids}
    iter_ids = ['slice_idx', 'pos_idx', 'time_idx']

    for idx_name in iter_ids:
        if idx_name != split_idx_name:
            metadata_ids[idx_name] = np.unique(test_meta[idx_name])

    # Create image subdirectory to write predicted images
    pred_dir = os.path.join(model_dir, 'predictions')

    target_channel = config['dataset']['target_channels'][0]

    # If network depth is > 1, determine depth margins for +-z
    depth = 1
    if 'depth' in config['network']:
        depth = config['network']['depth']

    # Get channel name and extension for predictions
    pred_fnames = [f for f in os.listdir(pred_dir) if f.startswith('im_')]
    meta_row = aux_utils.parse_idx_from_name(pred_fnames[0])
    pred_channel = meta_row['channel_idx']
    _, ext = os.path.splitext(pred_fnames[0])

    if isinstance(orientations_list, str):
        orientations_list = [orientations_list]
    available_orientations = {'xy', 'xz', 'yz', 'xyz'}
    assert set(orientations_list).issubset(available_orientations), \
        "Orientations must be subset of {}".format(available_orientations)

    fn_mapping = {
        'xy': metrics_inst.estimate_xy_metrics,
        'xz': metrics_inst.estimate_xz_metrics,
        'yz': metrics_inst.estimate_yz_metrics,
        'xyz': metrics_inst.estimate_xyz_metrics,
    }
    metrics_mapping = {
        'xy': metrics_inst.get_metrics_xy,
        'xz': metrics_inst.get_metrics_xz,
        'yz': metrics_inst.get_metrics_yz,
        'xyz': metrics_inst.get_metrics_xyz,
    }
    df_mapping = {
        'xy': pd.DataFrame(),
        'xz': pd.DataFrame(),
        'yz': pd.DataFrame(),
        'xyz': pd.DataFrame(),
    }

    # Iterate over all indices for test data
    for time_idx in metadata_ids['time_idx']:
        for pos_idx in metadata_ids['pos_idx']:
            target_fnames = []
            pred_fnames = []
            for slice_idx in metadata_ids['slice_idx']:
                im_idx = aux_utils.get_meta_idx(
                    frames_metadata=frames_meta,
                    time_idx=time_idx,
                    channel_idx=target_channel,
                    slice_idx=slice_idx,
                    pos_idx=pos_idx,
                )
                target_fname = os.path.join(
                    image_dir,
                    frames_meta.loc[im_idx, 'file_name'],
                )
                target_fnames.append(target_fname)
                pred_fname = aux_utils.get_im_name(
                    time_idx=time_idx,
                    channel_idx=pred_channel,
                    slice_idx=slice_idx,
                    pos_idx=pos_idx,
                    ext=ext,
                )
                pred_fname = os.path.join(pred_dir, pred_fname)
                pred_fnames.append(pred_fname)

            target_stack = image_utils.read_imstack(
                input_fnames=tuple(target_fnames),
            )
            pred_stack = image_utils.read_imstack(
                input_fnames=tuple(pred_fnames),
                normalize_im=False,
            )

            if depth == 1:
                # Remove singular z dimension for 2D image
                target_stack = np.squeeze(target_stack)
                pred_stack = np.squeeze(pred_stack)
            if target_stack.dtype == np.float64:
                target_stack = target_stack.astype(np.float32)
            pred_name = "t{}_p{}".format(time_idx, pos_idx)
            for orientation in orientations_list:
                metric_fn = fn_mapping[orientation]
                metric_fn(
                    target=target_stack,
                    prediction=pred_stack,
                    pred_name=pred_name,
                )
                df_mapping[orientation] = df_mapping[orientation].append(
                    metrics_mapping[orientation](),
                    ignore_index=True,
                )

    # Save non-empty dataframes
    for orientation in orientations_list:
        metrics_df = df_mapping[orientation]
        df_name = 'metrics_{}.csv'.format(orientation)
        metrics_name = os.path.join(pred_dir, df_name)
        metrics_df.to_csv(metrics_name, sep=",", index=False)
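A minimal usage sketch for compute_metrics; the paths and metric names below are hypothetical placeholders:

# Hypothetical example call; paths and metric names are placeholders
compute_metrics(
    model_dir='/path/to/model_dir',   # contains config.yml, split_samples.json, predictions/
    image_dir='/path/to/image_dir',   # contains target images and frames_meta.csv
    metrics_list=['ssim'],            # see inference/evaluation_metrics.py for options
    orientations_list=['xy', 'xyz'],
    test_data=True,
)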
Exemple #28
0
 def test_pre_process(self):
     out_config, runtime = pp.pre_process(self.pp_config, self.base_config)
     self.assertIsInstance(runtime, float)
     self.assertEqual(
         self.base_config['input_dir'],
         self.image_dir,
     )
     self.assertEqual(
         self.base_config['channel_ids'],
         self.pp_config['channel_ids'],
     )
     self.assertEqual(
         out_config['flat_field']['flat_field_dir'],
         os.path.join(self.output_dir, 'flat_field_images')
     )
     self.assertEqual(
         out_config['masks']['mask_dir'],
         os.path.join(self.output_dir, 'mask_channels_3')
     )
     self.assertEqual(
         out_config['tile']['tile_dir'],
         os.path.join(self.output_dir, 'tiles_10-10_step_10-10'),
     )
     # Make sure new mask channel assignment is correct
     self.assertEqual(out_config['masks']['mask_channel'], 4)
     # Check that masks are generated
     mask_dir = out_config['masks']['mask_dir']
     mask_meta = aux_utils.read_meta(mask_dir)
     mask_names = os.listdir(mask_dir)
     mask_names.pop(mask_names.index('frames_meta.csv'))
     # Validate that all masks are there
     self.assertEqual(
         len(mask_names),
         len(self.slice_ids) * len(self.pos_ids),
     )
     for p in self.pos_ids:
         for z in self.slice_ids:
             im_name = aux_utils.get_im_name(
                 channel_idx=out_config['masks']['mask_channel'],
                 slice_idx=z,
                 time_idx=self.time_idx,
                 pos_idx=p,
             )
             im = cv2.imread(
                 os.path.join(mask_dir, im_name),
                 cv2.IMREAD_ANYDEPTH,
             )
             self.assertTupleEqual(im.shape, (30, 20))
             self.assertTrue(im.dtype == 'uint8')
             self.assertTrue(im_name in mask_names)
             self.assertTrue(im_name in mask_meta['file_name'].tolist())
     # Check flatfield images
     ff_dir = out_config['flat_field']['flat_field_dir']
     ff_names = os.listdir(ff_dir)
     self.assertEqual(len(ff_names), 3)
     for processed_channel in [0, 1, 3]:
         expected_name = "flat-field_channel-{}.npy".format(processed_channel)
         self.assertTrue(expected_name in ff_names)
         im = np.load(os.path.join(ff_dir, expected_name))
         self.assertTrue(im.dtype == np.float64)
         self.assertTupleEqual(im.shape, (30, 20))
     # Check tiles
     tile_dir = out_config['tile']['tile_dir']
     tile_meta = aux_utils.read_meta(tile_dir)
     # 4 processed channels (0, 1, 3, 4), 6 tiles per image
     expected_rows = 4 * 6 * len(self.slice_ids) * len(self.pos_ids)
     self.assertEqual(tile_meta.shape[0], expected_rows)
     # Check indices
     self.assertListEqual(
         tile_meta.channel_idx.unique().tolist(),
         [0, 1, 3, 4],
     )
     self.assertListEqual(
         tile_meta.pos_idx.unique().tolist(),
         self.pos_ids,
     )
     self.assertListEqual(
         tile_meta.slice_idx.unique().tolist(),
         self.slice_ids,
     )
     self.assertListEqual(
         tile_meta.time_idx.unique().tolist(),
         [self.time_idx],
     )
     self.assertListEqual(
         list(tile_meta),
         ['channel_idx',
          'col_start',
          'file_name',
          'pos_idx',
          'row_start',
          'slice_idx',
          'time_idx']
     )
     self.assertListEqual(
         tile_meta.row_start.unique().tolist(),
         [0, 10, 20],
     )
     self.assertListEqual(
         tile_meta.col_start.unique().tolist(),
         [0, 10],
     )
     # Read one tile and check format
     # r = row start/end idx, c = column start/end, sl = slice start/end
     # sl0-1 signifies depth of 1
     im = np.load(os.path.join(
         tile_dir,
         'im_c001_z000_t000_p007_r10-20_c10-20_sl0-1.npy',
     ))
     self.assertTupleEqual(im.shape, (1, 10, 10))
     self.assertTrue(im.dtype == np.float64)
Exemple #29
0
 def setUp(self):
     """
     Set up a directory with images
     """
     self.tempdir = TempDirectory()
     self.temp_path = self.tempdir.path
     self.tempdir.makedir('image_dir')
     self.tempdir.makedir('model_dir')
     self.tempdir.makedir('mask_dir')
     self.image_dir = os.path.join(self.temp_path, 'image_dir')
     self.model_dir = os.path.join(self.temp_path, 'model_dir')
     self.mask_dir = os.path.join(self.temp_path, 'mask_dir')
     # Create a temp image dir
     im = np.zeros((10, 16), dtype=np.uint8)
     self.frames_meta = aux_utils.make_dataframe()
     self.time_idx = 2
     for p in range(5):
         for z in range(4):
             for c in range(3):
                 im_name = aux_utils.get_im_name(
                     time_idx=self.time_idx,
                     channel_idx=c,
                     slice_idx=z,
                     pos_idx=p,
                 )
                 cv2.imwrite(os.path.join(self.image_dir, im_name), im + c * 10)
                 self.frames_meta = self.frames_meta.append(
                     aux_utils.parse_idx_from_name(im_name, aux_utils.DF_NAMES),
                     ignore_index=True,
                 )
     # Write frames meta to image dir too
     self.frames_meta.to_csv(os.path.join(self.image_dir, 'frames_meta.csv'))
     # Save masks and mask meta
     self.mask_meta = aux_utils.make_dataframe()
     self.mask_channel = 50
     for p in range(5):
         for z in range(4):
             im_name = aux_utils.get_im_name(
                 time_idx=2,
                 channel_idx=self.mask_channel,
                 slice_idx=z,
                 pos_idx=p,
             )
             cv2.imwrite(os.path.join(self.mask_dir, im_name), im + 1)
             self.mask_meta = self.mask_meta.append(
                 aux_utils.parse_idx_from_name(im_name, aux_utils.DF_NAMES),
                 ignore_index=True,
         )
     # Write mask meta to mask dir
     self.mask_meta.to_csv(os.path.join(self.mask_dir, 'frames_meta.csv'))
     # Select inference split of dataset
     self.split_col_ids = ('pos_idx', [1, 3])
     # Make configs with fields necessary for inference dataset
     dataset_config = {
         'input_channels': [2],
         'target_channels': [self.mask_channel],
         'model_task': 'segmentation',
     }
     self.network_config = {
         'class': 'UNetStackTo2D',
         'depth': 3,
         'data_format': 'channels_first',
     }
     # Instantiate class
     self.data_inst = inference_dataset.InferenceDataSet(
         image_dir=self.image_dir,
         dataset_config=dataset_config,
         network_config=self.network_config,
         split_col_ids=self.split_col_ids,
         mask_dir=self.mask_dir,
     )
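The split_col_ids tuple presumably tells InferenceDataSet which metadata column and which values define the inference split; a sketch of the equivalent pandas filter, given as an assumption about that behavior rather than the class's documented API:

# frames_meta: the dataframe written to image_dir above
split_col, split_ids = ('pos_idx', [1, 3])
inference_meta = frames_meta[frames_meta[split_col].isin(split_ids)]
# 2 positions x 3 channels x 4 slices -> 24 rows in the inference split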
Exemple #30
0
    def test_tile_first_channel(self):
        """Test tile_first_channel"""

        ch0_ids = {}
        # get the indices for first channel
        for tp_idx, tp_dict in self.tile_inst.nested_id_dict.items():
            for ch_idx, ch_dict in tp_dict.items():
                if ch_idx == 1:
                    ch0_dict = {ch_idx: ch_dict}
            ch0_ids[tp_idx] = ch0_dict

        # get the expected meta df
        exp_meta = []
        for row in [0, 4, 8, 10]:
            for col in [0, 4, 6]:
                for z in [1, 2, 3]:
                    for t in [0, 1, 2]:
                        fname = aux_utils.get_im_name(
                            channel_idx=1,
                            slice_idx=z,
                            time_idx=t,
                            pos_idx=7,
                            ext='.npy',
                        )
                        tile_id = '_r{}-{}_c{}-{}_sl0-3'.format(row, row+5,
                                                               col, col+5)
                        fname = fname.split('.')[0] + tile_id + '.npy'
                        cur_meta = {'channel_idx': 1,
                                    'slice_idx': z,
                                    'time_idx': t,
                                    'file_name': fname,
                                    'pos_idx': 7,
                                    'row_start': row,
                                    'col_start': col}
                        exp_meta.append(cur_meta)
                for t in [0, 1]:
                    fname = aux_utils.get_im_name(
                        channel_idx=1,
                        slice_idx=1,
                        time_idx=t,
                        pos_idx=8,
                        ext='.npy',
                    )
                    tile_id = '_r{}-{}_c{}-{}_sl0-3'.format(row, row + 5,
                                                            col, col + 5)
                    fname = fname.split('.')[0] + tile_id + '.npy'
                    cur_meta = {'channel_idx': 1,
                                'slice_idx': 1,
                                'time_idx': t,
                                'file_name': fname,
                                'pos_idx': 8,
                                'row_start': row,
                                'col_start': col}
                    exp_meta.append(cur_meta)
        exp_meta_df = pd.DataFrame.from_dict(exp_meta)
        exp_meta_df = exp_meta_df.sort_values(by=['file_name'])

        ch0_meta_df = self.tile_inst.tile_first_channel(ch0_ids, 3)
        ch0_meta_df = ch0_meta_df.sort_values(by=['file_name'])
        # compare values of the returned and expected dfs
        np.testing.assert_array_equal(exp_meta_df.values, ch0_meta_df.values)