예제 #1
0
def tile_and_save(input_fnames,
                  flat_field_fname,
                  hist_clip_limits,
                  time_idx,
                  channel_idx,
                  pos_idx,
                  slice_idx,
                  tile_size,
                  step_size,
                  min_fraction,
                  image_format,
                  save_dir,
                  int2str_len=3,
                  is_mask=False,
                  normalize_im=False):
    """Read an image stack, crop it into tiles and save the tiles

    The stack is read with image_utils.read_imstack and tiled with
    tile_utils.tile_image. The indices, save_dir, image_format and
    int2str_len are handed to tile_image bundled in its save_dict argument.

    :param tuple input_fnames: tuple of input fnames with full path
    :param str flat_field_fname: fname of flat field image
    :param tuple hist_clip_limits: limits for histogram clipping
    :param int time_idx: time point of input image
    :param int channel_idx: channel idx of input image
    :param int pos_idx: sample idx of input image
    :param int slice_idx: slice idx of input image
    :param list tile_size: size of tile along row, col (& slices)
    :param list step_size: step size along row, col (& slices)
    :param float min_fraction: min foreground volume fraction for keeping
     a tile
    :param str image_format: zyx / xyz
    :param str save_dir: output dir to save tiles
    :param int int2str_len: len of indices for creating file names
    :param bool is_mask: Indicates if files are masks
    :param bool normalize_im: Indicates if normalizing using z score is needed
    :return: pd.DataFrame from a list of dicts with metadata
    :raises Exception: any error raised while reading or tiling is printed
     together with the image indices and then re-raised unchanged
    """
    try:
        input_image = image_utils.read_imstack(
            input_fnames=input_fnames,
            flat_field_fname=flat_field_fname,
            hist_clip_limits=hist_clip_limits,
            is_mask=is_mask,
            normalize_im=normalize_im
        )
        save_dict = {'time_idx': time_idx,
                     'channel_idx': channel_idx,
                     'pos_idx': pos_idx,
                     'slice_idx': slice_idx,
                     'save_dir': save_dir,
                     'image_format': image_format,
                     'int2str_len': int2str_len}

        tile_meta_df = tile_utils.tile_image(
            input_image=input_image,
            tile_size=tile_size,
            step_size=step_size,
            min_fraction=min_fraction,
            save_dict=save_dict,
        )
    except Exception as e:
        # Separate the index context from the exception text; without the
        # ': ' the two run together (e.g. 'sl_3file not found').
        err_msg = 'error in t_{}, c_{}, pos_{}, sl_{}: {}'.format(
            time_idx, channel_idx, pos_idx, slice_idx, e
        )
        # TODO(Anitha) write to log instead
        print(err_msg)
        # Bare raise re-raises the active exception as is, so the caller
        # still sees the original type and traceback.
        raise
    return tile_meta_df
예제 #2
0
    def tile_stack(self):
        """
        Tile the images of all selected channels and save the tile metadata.

        Tile indices come from self._get_tiled_data(). If none are
        available, the first image visited is tiled directly to obtain
        them; every other image is queued and cropped at those indices
        through mp_utils.mp_crop_save.

        https://research.wmz.ninja/articles/2018/03/
        on-sharing-large-arrays-when-using-pythons-multiprocessing.html

        Saves a csv (frames_meta.csv in self.tile_dir) with columns
        ['time_idx', 'channel_idx', 'pos_idx','slice_idx', 'file_name']
        for all the tiles
        """
        # Metadata of earlier tiling runs (if any) and the tile indices
        previous_meta, tile_indices = self._get_tiled_data()

        first_image_meta = None
        crop_jobs = []
        for channel_idx in self.channel_ids:
            # normalize_channels is indexed by the channel's position in
            # channel_ids, not by the channel id itself
            channel_pos = self.channel_ids.index(channel_idx)
            # Flat field image for this channel, fetched once per channel
            flat_field_im = self._get_flat_field(channel_idx=channel_idx)
            for slice_idx in self.slice_ids:
                for time_idx in self.time_ids:
                    for pos_idx in self.pos_ids:
                        if tile_indices is not None:
                            # Indices known: only queue the crop job
                            crop_jobs.append(
                                self.get_crop_tile_args(
                                    channel_idx,
                                    time_idx,
                                    slice_idx,
                                    pos_idx,
                                    task_type='crop',
                                    tile_indices=tile_indices,
                                    normalize_im=self.normalize_channels[
                                        channel_pos],
                                )
                            )
                            continue

                        # No indices yet: tile and save this image here
                        # to get its metadata and the tile indices
                        first_image = image_utils.preprocess_imstack(
                            frames_metadata=self.frames_metadata,
                            input_dir=self.input_dir,
                            depth=self.channel_depth[channel_idx],
                            time_idx=time_idx,
                            channel_idx=channel_idx,
                            slice_idx=slice_idx,
                            pos_idx=pos_idx,
                            flat_field_im=flat_field_im,
                            hist_clip_limits=self.hist_clip_limits,
                            normalize_im=self.normalize_channels[channel_pos],
                        )
                        save_dict = dict(
                            time_idx=time_idx,
                            channel_idx=channel_idx,
                            pos_idx=pos_idx,
                            slice_idx=slice_idx,
                            save_dir=self.tile_dir,
                            image_format=self.image_format,
                            int2str_len=self.int2str_len,
                        )
                        tiling_result = tile_utils.tile_image(
                            input_image=first_image,
                            tile_size=self.tile_size,
                            step_size=self.step_size,
                            return_index=True,
                            save_dict=save_dict,
                        )
                        first_image_meta, tile_indices = tiling_result

        meta_frames = mp_utils.mp_crop_save(
            crop_jobs,
            workers=self.num_workers,
        )
        if first_image_meta is not None:
            meta_frames.append(first_image_meta)
        tiled_metadata = pd.concat(meta_frames, ignore_index=True)

        if self.tiles_exist:
            # Merge with the metadata of tiles created by previous runs
            tiled_metadata.reset_index(drop=True, inplace=True)
            previous_meta.reset_index(drop=True, inplace=True)
            tiled_metadata = pd.concat(
                [previous_meta, tiled_metadata],
                ignore_index=True,
            )

        # Finally, save all the metadata ordered by file name
        tiled_metadata = tiled_metadata.sort_values(by=['file_name'])
        csv_path = os.path.join(self.tile_dir, "frames_meta.csv")
        tiled_metadata.to_csv(csv_path, sep=",")
예제 #3
0
    def tile_mask_stack(self,
                        input_mask_dir,
                        tile_index_fname=None,
                        tile_size=None,
                        step_size=None,
                        isotropic=False):
        """
        Tiles a stack of masks and saves the tiles and their metadata

        The time point and channel of each mask directory are parsed from
        the integer after the last '_' in its last two path components.
        The sample number of each mask is parsed from the second
        '_'-separated field of its file name, with the first character
        dropped. Tiles are written as .npy files under
        self.output_dir/timepoint_<tp>/channel_<self.output_channel_id[i]>.

        :param str/list input_mask_dir: input_mask_dir with full path
        :param str tile_index_fname: fname with full path for the pickle file
         which contains a dict with fname as keys and crop indices as values.
         Needed when tiling using a volume fraction constraint (i.e. check for
         minimum foreground in tile)
        :param list/tuple tile_size: as named, required when tile_index_fname
         is not given
        :param list/tuple step_size: as named, required when tile_index_fname
         is not given
        :param bool isotropic: indicator for making the tiles have isotropic
         shape (only for 3D)
        :raises AssertionError: if tile_index_fname is given but is not an
         existing file, or if it is not given and tile_size / step_size are
         missing or differ in length
        """

        if tile_index_fname:
            msg = 'tile index file does not exist'
            # isfile() is already False for paths that do not exist
            assert os.path.isfile(tile_index_fname), msg
            # NOTE: unpickling can execute arbitrary code; only use index
            # files that come from a trusted source.
            with open(tile_index_fname, 'rb') as f:
                crop_indices_dict = pickle.load(f)
        else:
            msg = 'tile_size and step_size are needed'
            assert tile_size is not None and step_size is not None, msg
            msg = 'tile and step sizes should have same length'
            assert len(tile_size) == len(step_size), msg

        if not isinstance(input_mask_dir, list):
            input_mask_dir = [input_mask_dir]

        for ch_idx, cur_dir in enumerate(input_mask_dir):
            # Split dir name and remove last / if present
            sep_strs = cur_dir.split(os.sep)
            if not sep_strs[-1]:
                sep_strs.pop(-1)
            cur_tp = int(sep_strs[-2].split('_')[-1])
            cur_ch = int(sep_strs[-1].split('_')[-1])
            # Read all mask npy files. Sort file names, the assumption is
            # that the csv is sorted
            mask_fnames = natsort.natsorted(
                glob.glob(os.path.join(cur_dir, '*.npy')))
            cropped_meta = []
            output_dir = os.path.join(
                self.output_dir, 'timepoint_{}'.format(cur_tp),
                'channel_{}'.format(self.output_channel_id[ch_idx]))
            os.makedirs(output_dir, exist_ok=True)
            for cur_mask_fname in mask_fnames:
                fname = os.path.basename(cur_mask_fname)
                sample_num = int(fname.split('_')[1][1:])
                cur_mask = np.load(cur_mask_fname)
                if tile_index_fname:
                    cropped_image_data = tile_utils.crop_at_indices(
                        input_image=cur_mask,
                        crop_indices=crop_indices_dict[fname],
                        isotropic=isotropic)
                else:
                    cropped_image_data = tile_utils.tile_image(
                        input_image=cur_mask,
                        tile_size=tile_size,
                        step_size=step_size,
                        isotropic=isotropic)
                # save the stack; each entry is indexed as (tile id, tile)
                for id_img_tuple in cropped_image_data:
                    rcsl_idx = id_img_tuple[0]
                    img_fname = 'n{}_{}.npy'.format(sample_num, rcsl_idx)
                    cropped_img = id_img_tuple[1]
                    cropped_img_fname = os.path.join(output_dir, img_fname)
                    np.save(
                        cropped_img_fname,
                        cropped_img,
                        allow_pickle=True,
                        fix_imports=True,
                    )
                    cropped_meta.append(
                        (cur_tp, cur_ch, sample_num, self.focal_plane_idx,
                         cropped_img_fname))
            # Save the metadata once per directory. It used to be saved after
            # every single tile, each time with the whole list collected so
            # far. Nothing is saved when no tile was produced, as before.
            # NOTE(review): assumes save_tile_meta replaces what it saved
            # earlier for this channel rather than appending - confirm.
            if cropped_meta:
                aux_utils.save_tile_meta(
                    cropped_meta,
                    cur_channel=cur_ch,
                    tiled_dir=self.output_dir,
                )
예제 #4
0
    def predict_3d(self, iteration_rows):
        """
        Run prediction in 3D on images with 3D shape.

        The single entry of iteration_rows is used both as the index into
        self.dataset_inst and as the positional row of self.iteration_meta.
        The prediction is saved with self.save_pred_image before being
        returned.

        :param list iteration_rows: Inference meta rows, must hold exactly
         one entry
        :return np.array pred_image: Prediction
        :return np.array target_image: Target
        :return np.array/None mask_image: Mask for metrics, None when
         self.masks_dict is None
        :raises AssertionError: if iteration_rows does not hold exactly one
         entry
        :raises ValueError: if self.tile_option is not one of
         'infer_on_center', 'tile_z', 'tile_xyz'
        """
        # Format the rows themselves: a list has no pos_idx attribute, so
        # building the message from it would raise AttributeError and hide
        # the failed check.
        assert len(iteration_rows) == 1, \
            'expected exactly one matching row, got {}: {}'.format(
                len(iteration_rows), iteration_rows)
        cur_input, cur_target = self.dataset_inst[iteration_rows[0]]
        # If crop shape is defined in images dict
        if self.crop_shape is not None:
            cur_input = image_utils.center_crop_to_shape(
                cur_input,
                self.crop_shape,
            )
            cur_target = image_utils.center_crop_to_shape(
                cur_target,
                self.crop_shape,
            )
        # Only set for 'infer_on_center'; reused below to crop the mask
        inf_shape = None
        if self.tile_option == 'infer_on_center':
            inf_shape = self.params_3d['inf_shape']
            center_block = image_utils.center_crop_to_shape(
                cur_input, inf_shape)
            cur_target = image_utils.center_crop_to_shape(
                cur_target, inf_shape)
            pred_image = inference.predict_large_image(
                model=self.model,
                input_image=center_block,
            )
        elif self.tile_option == 'tile_z':
            pred_block_list, start_end_idx = \
                self._predict_sub_block_z(cur_input)
            pred_image = self.stitch_inst.stitch_predictions(
                np.squeeze(cur_input).shape, pred_block_list, start_end_idx)
        elif self.tile_option == 'tile_xyz':
            step_size = (np.array(self.params_3d['tile_shape']) -
                         np.array(self.num_overlap))
            # TODO tile_image works for 2D/3D imgs, modify for multichannel
            _, crop_indices = tile_utils.tile_image(
                input_image=np.squeeze(cur_input),
                tile_size=self.params_3d['tile_shape'],
                step_size=step_size,
                return_index=True)
            pred_block_list = self._predict_sub_block_xyz(
                cur_input,
                crop_indices,
            )
            pred_image = self.stitch_inst.stitch_predictions(
                np.squeeze(cur_input).shape,
                pred_block_list,
                crop_indices,
            )
        else:
            # Without this branch pred_image would be unbound below and the
            # method would fail with an unrelated UnboundLocalError.
            raise ValueError(
                'unsupported tile_option: {}'.format(self.tile_option))
        pred_image = np.squeeze(pred_image).astype(np.float32)
        target_image = np.squeeze(cur_target).astype(np.float32)
        # save prediction
        cur_row = self.iteration_meta.iloc[iteration_rows[0]]
        self.save_pred_image(
            predicted_image=pred_image,
            time_idx=cur_row['time_idx'],
            target_channel_idx=cur_row['channel_idx'],
            pos_idx=cur_row['pos_idx'],
            slice_idx=cur_row['slice_idx'],
        )
        # 3D uses zyx, estimate metrics expects xyz
        if self.image_format == 'zyx':
            pred_image = np.transpose(pred_image, [1, 2, 0])
            target_image = np.transpose(target_image, [1, 2, 0])
        # get mask
        mask_image = None
        if self.masks_dict is not None:
            mask_image = self.get_mask(cur_row, transpose=True)
            if inf_shape is not None:
                mask_image = image_utils.center_crop_to_shape(
                    mask_image,
                    inf_shape,
                )
            if self.image_format == 'zyx':
                mask_image = np.transpose(mask_image, [1, 2, 0])
        return pred_image, target_image, mask_image
예제 #5
0
    def predict_on_full_image(self,
                              image_meta,
                              test_samples,
                              focal_plane_idx=None,
                              depth=None,
                              per_tile_overlap=1 / 8,
                              flat_field_correct=False,
                              base_image_dir=None,
                              place_operation='mean'):
        """Tile and run inference on tiles and assemble the full image

        For every time point the tile indices are computed once, from the
        first test image, and reused for all test images of that time point.
        Predictions are saved as .npy files, one per target channel, under
        <model_dir>/predicted_images/tp_<tp>/channel_<ch>, and a collage is
        saved next to them.

        :param pd.DataFrame image_meta: Df with individual image info,
         'timepoint', 'channel_num', 'sample_num', 'slice_num', 'fname',
         'size_x_microns', 'size_y_microns', 'size_z_microns'
        :param list test_samples: list of sample numbers to be used in the
         test set
        :param int focal_plane_idx: focal plane to be used
        :param int depth: if 3D - num of slices used for tiling
        :param float per_tile_overlap: percent overlap between successive tiles
        :param bool flat_field_correct: indicator for applying flat field
         correction
        :param str base_image_dir: base directory where images are stored,
         required when flat_field_correct is set
        :param str place_operation: in ['mean', 'max']. mean for regression tasks,
         max for segmentation tasks
        :raises AssertionError: if place_operation is not 'mean' or 'max', if
         depth is given but the network config has no 'depth', or if
         flat_field_correct is set without base_image_dir
        """

        assert place_operation in ['mean', 'max'], \
            'only mean and max are allowed: %s' % place_operation
        if 'timepoints' not in self.config['dataset']:
            timepoint_ids = -1
        else:
            timepoint_ids = self.config['dataset']['timepoints']

        ip_channel_ids = self.config['dataset']['input_channels']
        op_channel_ids = self.config['dataset']['target_channels']
        tp_channel_ids = aux_utils.validate_metadata_indices(
            image_meta, time_ids=timepoint_ids)
        tp_idx = tp_channel_ids['timepoints']
        tile_size = [
            self.config['network']['height'], self.config['network']['width']
        ]

        if depth is not None:
            assert 'depth' in self.config['network']
            tile_size.insert(0, depth)

        # Step is the non-overlapping part of a tile, truncated to int and
        # kept at a minimum of 1 along every axis
        step_size = (1 - per_tile_overlap) * np.array(tile_size)
        step_size = step_size.astype('int')
        step_size[step_size < 1] = 1

        overlap_size = tile_size - step_size
        batch_size = self.config['trainer']['batch_size']

        if flat_field_correct:
            assert base_image_dir is not None
            ff_dir = os.path.join(base_image_dir, 'flat_field_images')
        else:
            ff_dir = None

        for tp in tp_idx:
            # get the meta for all images in tp_dir and channel_dir
            row_idx_ip0 = aux_utils.get_row_idx(image_meta,
                                                tp,
                                                ip_channel_ids[0],
                                                slice_idx=focal_plane_idx)
            ip0_meta = image_meta[row_idx_ip0]

            # get rows corr. to test_samples from this DF
            test_row_ip0 = ip0_meta.loc[ip0_meta['sample_num'].isin(
                test_samples)]
            test_ip0_fnames = test_row_ip0['fname'].tolist()
            test_image_fnames = ([
                fname.split(os.sep)[-1] for fname in test_ip0_fnames
            ])
            # NOTE: indexing [0] raises IndexError when no test sample
            # matches this time point
            tp_dir = str(os.sep).join(test_ip0_fnames[0].split(os.sep)[:-2])
            test_image = np.load(test_ip0_fnames[0])
            _, crop_indices = tile_utils.tile_image(test_image,
                                                    tile_size,
                                                    step_size,
                                                    return_index=True)
            pred_dir = os.path.join(self.config['trainer']['model_dir'],
                                    'predicted_images', 'tp_{}'.format(tp))
            for fname in test_image_fnames:
                target_image = self._read_one(tp_dir, op_channel_ids, fname,
                                              ff_dir)
                input_image = self._read_one(tp_dir, ip_channel_ids, fname,
                                             ff_dir)
                pred_tiles = self._pred_image(input_image, crop_indices,
                                              batch_size)
                pred_image = self._stitch_image(pred_tiles, crop_indices,
                                                input_image.shape, batch_size,
                                                tile_size, overlap_size,
                                                place_operation)
                # Name up to the first '.', shared by the prediction file
                # and the collage
                base_fname = fname.split('.')[0]
                pred_fname = '{}.npy'.format(base_fname)
                for idx, op_ch in enumerate(op_channel_ids):
                    op_dir = os.path.join(pred_dir, 'channel_{}'.format(op_ch))
                    # exist_ok avoids the check-then-create race of
                    # testing os.path.exists() first
                    os.makedirs(op_dir, exist_ok=True)
                    np.save(os.path.join(op_dir, pred_fname), pred_image[idx])
                    save_predicted_images([input_image], [target_image],
                                          [pred_image],
                                          os.path.join(op_dir, 'collage'),
                                          output_fname=base_fname)
예제 #6
0
    def test_tile_image(self):
        """Test tile_utils.tile_image in its different calling modes

        Covers, in order: plain tiling of a 3-slice sub-stack, tiling with
        return_index=True, tiling with save_dict (tiles written to disk and
        a metadata frame returned), tiling a boolean image with
        min_fraction, and tiling the full stack with a 3-element tile size.
        """

        # Sub-stack with three slices (3, 4 and 5) of the fixture image
        input_image = self.sph[:, :, 3:6]
        tile_size = [16, 16]
        step_size = [8, 8]
        # The result is indexed below as a sequence of (img_id, tile) pairs
        tiled_image_list = tile_utils.tile_image(
            input_image,
            tile_size=tile_size,
            step_size=step_size,
        )
        nose.tools.assert_equal(len(tiled_image_list), 9)
        # c walks the returned list; tiles are expected row by row, with
        # start positions 0, 8 and 16 along each axis
        c = 0
        for row in range(0, 17, 8):
            for col in range(0, 17, 8):
                id_str = 'r{}-{}_c{}-{}_sl{}-{}'.format(
                    row, row + tile_size[0], col, col + tile_size[1], 0, 3)
                nose.tools.assert_equal(id_str, tiled_image_list[c][0])
                tile = input_image[row:row + tile_size[0],
                                   col:col + tile_size[1], ...]
                numpy.testing.assert_array_equal(tile, tiled_image_list[c][1])
                c += 1

        # With return_index=True a pair is returned; its second element is
        # compared to the expected crop indices
        _, tile_index = tile_utils.tile_image(
            input_image,
            tile_size=tile_size,
            step_size=step_size,
            return_index=True,
        )
        # Each entry reads as (row_start, row_end, col_start, col_end)
        exp_tile_index = [(0, 16, 0, 16), (0, 16, 8, 24), (0, 16, 16, 32),
                          (8, 24, 0, 16), (8, 24, 8, 24), (8, 24, 16, 32),
                          (16, 32, 0, 16), (16, 32, 8, 24), (16, 32, 16, 32)]

        numpy.testing.assert_equal(exp_tile_index, tile_index)

        # save tiles in place and return meta_df
        tile_dir = os.path.join(self.temp_path, 'tile_dir')
        os.makedirs(tile_dir, exist_ok=True)
        meta_dir = os.path.join(tile_dir, 'meta_dir')
        os.makedirs(meta_dir, exist_ok=True)
        save_dict = {
            'time_idx': self.time_idx,
            'channel_idx': self.channel_idx,
            'slice_idx': 4,
            'pos_idx': self.pos_idx,
            'image_format': 'zyx',
            'int2str_len': 3,
            'save_dir': tile_dir
        }
        tile_meta_df = tile_utils.tile_image(
            input_image,
            tile_size=tile_size,
            step_size=step_size,
            save_dict=save_dict,
        )
        # Build the expected metadata and check that one file per tile
        # exists in tile_dir
        tile_meta = []
        for row in range(0, 17, 8):
            for col in range(0, 17, 8):
                id_str = 'r{}-{}_c{}-{}_sl{}-{}'.format(
                    row, row + tile_size[0], col, col + tile_size[1], 0, 3)
                cur_fname = aux_utils.get_im_name(
                    time_idx=self.time_idx,
                    channel_idx=self.channel_idx,
                    slice_idx=4,
                    pos_idx=self.pos_idx,
                    int2str_len=3,
                    extra_field=id_str,
                    ext='.npy',
                )
                cur_path = os.path.join(tile_dir, cur_fname)
                nose.tools.assert_equal(os.path.exists(cur_path), True)
                cur_meta = {
                    'channel_idx': self.channel_idx,
                    'slice_idx': 4,
                    'time_idx': self.time_idx,
                    'file_name': cur_fname,
                    'pos_idx': self.pos_idx,
                    'row_start': row,
                    'col_start': col
                }
                tile_meta.append(cur_meta)
        exp_tile_meta_df = pd.DataFrame.from_dict(tile_meta)
        # The expected frame is sorted by file name before the comparison
        # NOTE(review): presumably tile_image returns its frame sorted the
        # same way - confirm
        exp_tile_meta_df = exp_tile_meta_df.sort_values(by=['file_name'])
        pd.testing.assert_frame_equal(tile_meta_df, exp_tile_meta_df)

        # use mask and min_fraction to select tiles to retain
        input_image_bool = input_image > 128
        _, tile_index = tile_utils.tile_image(
            input_image_bool,
            tile_size=tile_size,
            step_size=step_size,
            min_fraction=0.3,
            return_index=True,
        )
        # Only five of the nine tiles are expected to be kept
        exp_tile_index = [(0, 16, 8, 24), (8, 24, 0, 16), (8, 24, 8, 24),
                          (8, 24, 16, 32), (16, 32, 8, 24)]
        numpy.testing.assert_array_equal(tile_index, exp_tile_index)

        # tile_3d: full stack, tile and step sizes have a third element
        input_image = self.sph
        tile_size = [16, 16, 6]
        step_size = [8, 8, 4]
        # The result is again indexed as a sequence of (img_id, tile) pairs
        tiled_image_list = tile_utils.tile_image(
            input_image,
            tile_size=tile_size,
            step_size=step_size,
        )
        nose.tools.assert_equal(len(tiled_image_list), 18)
        c = 0
        for row in range(0, 17, 8):
            for col in range(0, 17, 8):
                # range(0, 8, 6) only yields 0 and 6, i.e. two slice tiles
                # per (row, col); the second one is expected at slices 2-8
                # NOTE(review): presumably the last slice tile is shifted
                # back to end at the stack boundary - confirm in tile_image
                for sl in range(0, 8, 6):
                    if sl == 0:
                        sl_start_end = [0, 6]
                    else:
                        sl_start_end = [2, 8]

                    id_str = 'r{}-{}_c{}-{}_sl{}-{}'.format(
                        row, row + tile_size[0], col, col + tile_size[1],
                        sl_start_end[0], sl_start_end[1])
                    nose.tools.assert_equal(id_str, tiled_image_list[c][0])
                    tile = input_image[row:row + tile_size[0],
                                       col:col + tile_size[1],
                                       sl_start_end[0]:sl_start_end[1]]
                    numpy.testing.assert_array_equal(tile,
                                                     tiled_image_list[c][1])
                    c += 1