Example #1
    def __init__(self,
                 input_dir,
                 output_dir,
                 channel_ids,
                 slice_ids,
                 block_size=32):
        """
        Flatfield images are estimated once per channel for 2D data

        :param str input_dir: Directory with 2D image frames from dataset
        :param str output_dir: Base output directory
        :param int/list channel_ids: channel ids for flat field correction
        :param int/list slice_ids: Z slice indices for flatfield correction
        :param int block_size: Size of blocks image will be divided into
        """
        self.input_dir = input_dir
        self.output_dir = output_dir
        # Create flat_field_dir as a subdirectory of output_dir
        self.flat_field_dir = os.path.join(self.output_dir,
                                           'flat_field_images')
        os.makedirs(self.flat_field_dir, exist_ok=True)
        self.frames_metadata = aux_utils.read_meta(self.input_dir)
        metadata_ids, _ = aux_utils.validate_metadata_indices(
            frames_metadata=self.frames_metadata,
            channel_ids=channel_ids,
            slice_ids=slice_ids,
        )
        self.channels_ids = metadata_ids['channel_ids']
        self.slice_ids = metadata_ids['slice_ids']
        if block_size is None:
            block_size = 32
        self.block_size = block_size
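
A note on the shared pattern: each __init__ in these examples reads the frames metadata and passes the requested ids through aux_utils.validate_metadata_indices, which resolves -1 to "all available" and validates explicit ids. As a rough, standalone illustration of that contract (not the actual implementation), the sketch below resolves requested channel ids against a frames table; resolve_ids and the toy DataFrame are assumptions made for the sketch.

import pandas as pd

def resolve_ids(frames_meta, requested, col='channel_idx'):
    """Resolve requested ids against the ids present in frames_meta."""
    available = sorted(frames_meta[col].unique())
    if requested == -1:
        # -1 means "use every index available in the metadata"
        return available
    requested = [requested] if isinstance(requested, int) else list(requested)
    assert set(requested).issubset(available), \
        'Requested {} not in available {}'.format(requested, available)
    return requested

frames_meta = pd.DataFrame({'channel_idx': [0, 0, 1, 2]})
print(resolve_ids(frames_meta, -1))      # [0, 1, 2]
print(resolve_ids(frames_meta, [0, 2]))  # [0, 2]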
Example #2
    def __init__(self,
                 input_dir,
                 output_dir,
                 tile_size=[256, 256],
                 step_size=[64, 64],
                 depths=1,
                 time_ids=-1,
                 channel_ids=-1,
                 normalize_channels=-1,
                 slice_ids=-1,
                 pos_ids=-1,
                 hist_clip_limits=None,
                 flat_field_dir=None,
                 image_format='zyx',
                 num_workers=4,
                 int2str_len=3,
                 tile_3d=False):
        """Init

        Assuming same structure across channels and same number of samples
        across channels. The dataset could have varying number of time points
        and / or varying number of slices / size for each sample / position
        Please ref to init of ImageTilerUniform.
        """

        super().__init__(input_dir=input_dir,
                         output_dir=output_dir,
                         tile_size=tile_size,
                         step_size=step_size,
                         depths=depths,
                         time_ids=time_ids,
                         channel_ids=channel_ids,
                         normalize_channels=normalize_channels,
                         slice_ids=slice_ids,
                         pos_ids=pos_ids,
                         hist_clip_limits=hist_clip_limits,
                         flat_field_dir=flat_field_dir,
                         image_format=image_format,
                         num_workers=num_workers,
                         int2str_len=int2str_len,
                         tile_3d=tile_3d)
        # Get metadata indices
        metadata_ids, nested_id_dict = aux_utils.validate_metadata_indices(
            frames_metadata=self.frames_metadata,
            time_ids=time_ids,
            channel_ids=channel_ids,
            slice_ids=slice_ids,
            pos_ids=pos_ids,
            uniform_structure=False
        )
        self.nested_id_dict = nested_id_dict
        # self.tile_dir is already created in super(). Check if frames_meta
        # exists in self.tile_dir
        meta_path = os.path.join(self.tile_dir, 'frames_meta.csv')
        assert not os.path.exists(meta_path), \
            'Tile dir exists, cannot add to existing dir'
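
The assert at the end refuses to add tiles to a directory that already holds tile metadata. A minimal standalone sketch of the same guard, using only the standard library; the helper name and directory path are made up for illustration.

import os

def assert_fresh_tile_dir(tile_dir):
    """Refuse to append tiles to a directory that already has frames_meta.csv."""
    meta_path = os.path.join(tile_dir, 'frames_meta.csv')
    assert not os.path.exists(meta_path), \
        'Tile dir exists, cannot add to existing dir'

assert_fresh_tile_dir('/tmp/tiles_demo')  # passes while no frames_meta.csv exists there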
Example #3
    def __init__(self,
                 input_dir,
                 output_dir,
                 scale_factor,
                 channel_ids=-1,
                 time_ids=-1,
                 slice_ids=-1,
                 pos_ids=-1,
                 int2str_len=3,
                 num_workers=4,
                 flat_field_dir=None):
        """
        :param str input_dir: Directory with image frames
        :param str output_dir: Base output directory
        :param float/list scale_factor: Scale factor for resizing frames.
        :param int/list channel_ids: Channel indices to resize
            (default -1 includes all channels)
        :param int/list time_ids: timepoints to use
        :param int/list slice_ids: Index of slice (z) indices to use
        :param int/list pos_ids: Position (FOV) indices to use
        :param int int2str_len: Length of str when converting ints
        :param int num_workers: number of workers for multiprocessing
        :param str flat_field_dir: dir with flat field images
        """
        self.input_dir = input_dir
        self.output_dir = output_dir
        if isinstance(scale_factor, list):
            scale_factor = np.array(scale_factor)
        assert np.all(scale_factor > 0), \
            "Scale factor should be positive float, not {}".format(scale_factor)
        self.scale_factor = scale_factor

        self.frames_metadata = aux_utils.read_meta(self.input_dir)
        metadata_ids, _ = aux_utils.validate_metadata_indices(
            frames_metadata=self.frames_metadata,
            time_ids=time_ids,
            channel_ids=channel_ids,
            slice_ids=slice_ids,
            pos_ids=pos_ids,
        )
        self.time_ids = metadata_ids['time_ids']
        self.channel_ids = metadata_ids['channel_ids']
        self.slice_ids = metadata_ids['slice_ids']
        self.pos_ids = metadata_ids['pos_ids']

        # Create resize_dir as a subdirectory of output_dir
        self.resize_dir = os.path.join(
            self.output_dir,
            'resized_images',
        )
        os.makedirs(self.resize_dir, exist_ok=True)

        self.int2str_len = int2str_len
        self.num_workers = num_workers
        self.flat_field_dir = flat_field_dir
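
The scale-factor handling above accepts either a scalar or a per-axis list and rejects non-positive values. A self-contained sketch of that check, with illustrative values; check_scale_factor is a name invented for the sketch.

import numpy as np

def check_scale_factor(scale_factor):
    """Accept a positive float or a list of positive floats (one per axis)."""
    if isinstance(scale_factor, list):
        scale_factor = np.array(scale_factor)
    assert np.all(scale_factor > 0), \
        "Scale factor should be positive float, not {}".format(scale_factor)
    return scale_factor

print(check_scale_factor(0.5))          # 0.5
print(check_scale_factor([2.0, 0.5]))   # [2.  0.5]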
Example #4
def validate_metadata_indices():
    metadata_ids, tp_dict = aux_utils.validate_metadata_indices(
        meta_df,
        time_ids=-1,
        channel_ids=-1,
        slice_ids=-1,
        pos_ids=-1,
    )
    nose.tools.assert_list_equal(metadata_ids['channel_ids'].tolist(), [5])
    nose.tools.assert_list_equal(metadata_ids['slice_ids'].tolist(), [0, 1, 2])
    nose.tools.assert_list_equal(metadata_ids['time_ids'].tolist(), [6])
    nose.tools.assert_list_equal(metadata_ids['pos_ids'].tolist(),
                                 [0, 1, 2, 3])
    nose.tools.assert_is_none(tp_dict)
Example #5
    def __init__(self,
                 input_dir,
                 output_dir,
                 channel_ids,
                 flat_field_dir=None,
                 time_ids=-1,
                 slice_ids=-1,
                 pos_ids=-1,
                 int2str_len=3,
                 uniform_struct=True,
                 num_workers=4,
                 mask_type='otsu',
                 mask_channel=None,
                 mask_ext='.npy',
                 normalize_im=False):
        """
        :param str input_dir: Directory with image frames
        :param str output_dir: Base output directory
        :param int/list channel_ids: Channel indices to generate masks from;
            masks are generated from the sum of these (fluorophore) channels
            (typically just one)
        :param str flat_field_dir: Directory with flatfield images if
            flatfield correction is applied
        :param list/int time_ids: timepoints to consider
        :param int/list slice_ids: Index of which focal plane (z)
            acquisition to use (default -1 includes all slices)
        :param int/list pos_ids: Position (FOV) indices to use
        :param int int2str_len: Length of str when converting ints
        :param bool uniform_struct: bool indicator for same structure across
            pos and time points
        :param int num_workers: number of workers for multiprocessing
        :param str mask_type: method to use for generating mask. Needed for
            mapping to the masking function
        :param int mask_channel: channel number assigned to the generated masks.
            If resizing images on a subset of channels, frames_meta is from the
            resize dir, which could lead to the wrong mask channel being assigned.
        :param str mask_ext: '.npy' or '.png'. Save the mask as uint8 PNG or
            NPY files
        :param bool normalize_im: indicator to normalize image based on z-score or not
        """

        self.input_dir = input_dir
        self.output_dir = output_dir
        self.flat_field_dir = flat_field_dir
        self.num_workers = num_workers
        self.normalize_im = normalize_im

        self.frames_metadata = aux_utils.read_meta(self.input_dir)
        # Create a unique mask channel number so masks can be treated
        # as a new channel
        if mask_channel is None:
            self.mask_channel = int(self.frames_metadata['channel_idx'].max() +
                                    1)
        else:
            self.mask_channel = int(mask_channel)

        metadata_ids, nested_id_dict = aux_utils.validate_metadata_indices(
            frames_metadata=self.frames_metadata,
            time_ids=time_ids,
            channel_ids=channel_ids,
            slice_ids=slice_ids,
            pos_ids=pos_ids,
            uniform_structure=uniform_struct,
        )
        self.time_ids = metadata_ids['time_ids']
        self.channel_ids = metadata_ids['channel_ids']
        self.slice_ids = metadata_ids['slice_ids']
        self.pos_ids = metadata_ids['pos_ids']
        # Create mask_dir as a subdirectory of output_dir
        self.mask_dir = os.path.join(
            self.output_dir,
            'mask_channels_' + '-'.join(map(str, self.channel_ids)),
        )
        os.makedirs(self.mask_dir, exist_ok=True)

        self.int2str_len = int2str_len
        self.uniform_struct = uniform_struct
        self.nested_id_dict = nested_id_dict

        assert mask_type in ['otsu', 'unimodal', 'borders_weight_loss_map'], \
            "Masking method invalid; 'otsu', 'unimodal' and " \
            "'borders_weight_loss_map' are currently supported"

        self.mask_type = mask_type
        self.mask_ext = mask_ext
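
When no mask_channel is given, the init above treats masks as a brand-new channel, one past the highest channel index already in the metadata. A tiny sketch of that convention; the frames table here is a made-up placeholder.

import pandas as pd

frames_metadata = pd.DataFrame({'channel_idx': [0, 0, 1, 3]})
mask_channel = None
if mask_channel is None:
    # next free channel index, so masks can be treated as a new channel
    mask_channel = int(frames_metadata['channel_idx'].max() + 1)
print(mask_channel)  # 4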
Example #6
    def __init__(self,
                 input_dir,
                 output_dir,
                 tile_size=[256, 256],
                 step_size=[64, 64],
                 depths=1,
                 time_ids=-1,
                 channel_ids=-1,
                 normalize_channels=-1,
                 slice_ids=-1,
                 pos_ids=-1,
                 hist_clip_limits=None,
                 flat_field_dir=None,
                 image_format='zyx',
                 num_workers=4,
                 int2str_len=3,
                 tile_3d=False):
        """
        Tiles images.
        If tile_dir already exist, it will check which channels are already
        tiled, get indices from them and tile from indices only on the channels
        not already present.

        :param str input_dir: Directory with frames to be tiled
        :param str output_dir: Base output directory
        :param list tile_size: size of the blocks to be cropped
            from the image
        :param list step_size: size of the window shift. In case
            of no overlap, the step size is tile_size. If overlap, step_size <
            tile_size
        :param int/list depths: The z depth for generating stack training data.
            Default 1 assumes 2D data for all channels to be tiled.
            For cases where input and target shapes are not the same (e.g. stack
            to 2D) you should specify depths for each channel in tile.channels.
        :param list/int time_ids: Tile given timepoint indices
        :param list/int channel_ids: Tile images in the given channel indices
            default=-1, tile all channels.
        :param list/int normalize_channels: list of booleans matching channel_ids
            indicating if channel should be normalized or not.
        :param int slice_ids: Index of which focal plane acquisition to
            use (for 2D). default=-1 for the whole z-stack
        :param int pos_ids: Position (FOV) indices to use
        :param list hist_clip_limits: lower and upper percentiles used for
            histogram clipping.
        :param str flat_field_dir: Flatfield directory. None if no flatfield
            correction
        :param str image_format: zyx (preferred) or xyz
        :param int num_workers: number of workers for multiprocessing
        :param int int2str_len: number of characters for each idx to be used
            in file names
        :param bool tile_3d: Whether tiling is 3D or 2D
        """
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.normalize_channels = normalize_channels
        self.depths = depths
        self.tile_size = tile_size
        self.step_size = step_size
        self.hist_clip_limits = hist_clip_limits
        self.image_format = image_format
        assert self.image_format in {'zyx', 'xyz'}, \
            'Data format must be zyx or xyz'
        self.num_workers = num_workers
        self.int2str_len = int2str_len
        self.tile_3d = tile_3d

        self.str_tile_step = 'tiles_{}_step_{}'.format(
            '-'.join([str(val) for val in tile_size]),
            '-'.join([str(val) for val in step_size]),
        )
        self.tile_dir = os.path.join(
            output_dir,
            self.str_tile_step,
        )

        # If the tile dir already exists, only tile channels not already
        # present. There are no checks in place for how to add to existing
        # tiles, so things could get messy otherwise.
        self.tiles_exist = False
        try:
            os.makedirs(self.tile_dir, exist_ok=False)
        except FileExistsError:
            print("Tile dir exists. Only add untiled channels.")
            self.tiles_exist = True

        # Make dir for saving individual meta per image, could be used for
        # tracking job success / fail
        os.makedirs(os.path.join(self.tile_dir, 'meta_dir'), exist_ok=True)

        self.flat_field_dir = flat_field_dir
        self.frames_metadata = aux_utils.read_meta(self.input_dir)
        # Get metadata indices
        metadata_ids, _ = aux_utils.validate_metadata_indices(
            frames_metadata=self.frames_metadata,
            time_ids=time_ids,
            channel_ids=channel_ids,
            slice_ids=slice_ids,
            pos_ids=pos_ids,
            uniform_structure=True)
        self.channel_ids = metadata_ids['channel_ids']
        self.time_ids = metadata_ids['time_ids']
        self.slice_ids = metadata_ids['slice_ids']
        self.pos_ids = metadata_ids['pos_ids']
        # Determine which channels should be normalized in tiling
        if self.normalize_channels == -1:
            self.normalize_channels = [True] * len(self.channel_ids)
        else:
            assert len(self.normalize_channels) == len(self.channel_ids),\
                "Channel ids {} and normalization list {} mismatch".format(
                    self.channel_ids,
                    self.normalize_channels,
                )
        # If more than one depth is specified, length must match channel ids
        if isinstance(self.depths, list):
            assert len(self.depths) == len(self.channel_ids), \
                "depths ({}) and channels ({}) length mismatch".format(
                    self.depths, self.channel_ids,
                )
            # Get max of all specified depths
            max_depth = max(self.depths)
            # Convert channels + depths to dict for lookup
            self.channel_depth = dict(zip(self.channel_ids, self.depths))
        else:
            # If depth is scalar, make depth the same for all channels
            max_depth = self.depths
            self.channel_depth = dict(
                zip(self.channel_ids, [self.depths] * len(self.channel_ids)))

        # Adjust slice margins
        self.slice_ids = aux_utils.adjust_slice_margins(
            slice_ids=self.slice_ids,
            depth=max_depth,
        )
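
The depth bookkeeping at the end of this init expands a scalar depth to all channels, or pairs per-channel depths with their channel ids. A short sketch with illustrative ids and depths.

channel_ids = [0, 1, 2]
depths = [5, 1, 1]  # one depth per channel; a scalar would apply to all
if isinstance(depths, list):
    assert len(depths) == len(channel_ids), 'depths and channels length mismatch'
    max_depth = max(depths)
    channel_depth = dict(zip(channel_ids, depths))
else:
    max_depth = depths
    channel_depth = dict(zip(channel_ids, [depths] * len(channel_ids)))
print(max_depth)      # 5
print(channel_depth)  # {0: 5, 1: 1, 2: 1}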
Example #7
    def tile_mask_stack(self,
                        mask_dir,
                        mask_channel,
                        min_fraction,
                        mask_depth=1):
        """
        Tiles images in the specified channels assuming there are masks
        already created in mask_dir. Only tiles above a certain fraction
        of foreground in mask tile will be saved and added to metadata.


        Saves a csv with columns ['time_idx', 'channel_idx', 'pos_idx',
        'slice_idx', 'file_name'] for all the tiles

        :param str mask_dir: Directory containing masks
        :param int mask_channel: Channel number assigned to mask
        :param float min_fraction: Min fraction of foreground in tiled masks
        :param int mask_depth: Depth for mask channel
        """

        # Mask depth cannot exceed the largest input/output channel depth
        assert mask_depth <= max(self.channel_depth.values())
        self.mask_depth = mask_depth

        # Mask meta is stored in mask dir. If channel_ids=-1, frames_meta will
        # not contain any rows for the mask channel. Assuming structure is the
        # same across channels. Get time, pos and slice indices for mask channel

        mask_meta_df = aux_utils.read_meta(mask_dir)
        # TODO: different masks across timepoints (but MaskProcessor generates
        # mask for tp=0 only)
        _, mask_nested_id_dict = aux_utils.validate_metadata_indices(
            frames_metadata=mask_meta_df,
            time_ids=self.time_ids,
            channel_ids=mask_channel,
            slice_ids=self.slice_ids,
            pos_ids=self.pos_ids,
            uniform_structure=False
        )

        # get t, z, p indices for mask_channel
        mask_ch_ids = {}
        for tp_idx, tp_dict in mask_nested_id_dict.items():
            for ch_idx, ch_dict in tp_dict.items():
                if ch_idx == mask_channel:
                    ch0_dict = {mask_channel: ch_dict}
                    mask_ch_ids[tp_idx] = ch0_dict

        # tile mask channel and use the tile indices to tile the rest
        meta_df = self.tile_first_channel(
            channel0_ids=mask_ch_ids,
            channel0_depth=mask_depth,
            cur_mask_dir=mask_dir,
            min_fraction=min_fraction,
            is_mask=True,
        )
        # tile the rest
        self.tile_remaining_channels(
            nested_id_dict=self.nested_id_dict,
            tiled_ch_id=mask_channel,
            cur_meta_df=meta_df,
        )
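
The loop above keeps only the mask channel's entries from the nested time -> channel mapping returned with uniform_structure=False. A self-contained sketch of the same filtering, using placeholder contents (the real per-channel entries hold position and slice ids, omitted here).

mask_channel = 3
nested_id_dict = {
    0: {0: 'ids for t0/ch0', 3: 'ids for t0/ch3'},
    1: {0: 'ids for t1/ch0', 3: 'ids for t1/ch3'},
}
mask_ch_ids = {}
for tp_idx, tp_dict in nested_id_dict.items():
    for ch_idx, ch_dict in tp_dict.items():
        if ch_idx == mask_channel:
            mask_ch_ids[tp_idx] = {mask_channel: ch_dict}
print(mask_ch_ids)
# {0: {3: 'ids for t0/ch3'}, 1: {3: 'ids for t1/ch3'}}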
Example #8
    def __init__(self,
                 input_dir,
                 input_channel_id,
                 output_dir,
                 output_channel_id,
                 timepoint_id=0,
                 correct_flat_field=True,
                 focal_plane_idx=0,
                 plot_masks=False):
        """Init

        :param str input_dir: base input dir at the level of individual sample
         images (or the level above timepoint dirs)
        :param int/list/tuple input_channel_id: channel_ids for which masks
         have to be generated
        :param str output_dir: output dir with full path. It is the base dir
         for tiled images
        :param int/list/tuple output_channel_id: channel_ids to be assigned to
         the created masks. Must match len(input_channel_id), i.e.
         mask(input_channel_id[0]) -> output_channel_id[0]
        :param int/list/tuple timepoint_id: timepoints to consider
        :param bool correct_flat_field: indicator to apply flat field
         correction
        :param int focal_plane_idx: focal plane acquisition to use
        :param bool plot_masks: Plot input, masks and overlays
        """

        assert os.path.exists(input_dir), 'input_dir does not exist'
        assert os.path.exists(output_dir), 'output_dir does not exist'
        self.input_dir = input_dir
        self.output_dir = output_dir

        self.correct_flat_field = correct_flat_field

        meta_fname = glob.glob(os.path.join(input_dir, "*info.csv"))
        assert len(meta_fname) == 1,\
            "Can't find info.csv file in {}".format(input_dir)
        study_metadata = pd.read_csv(meta_fname[0])

        self.study_metadata = study_metadata
        self.plot_masks = plot_masks

        avail_tp_channels = aux_utils.validate_metadata_indices(
            study_metadata,
            timepoint_ids=timepoint_id,
            channel_ids=input_channel_id)

        msg = 'timepoint_id is not available'
        assert timepoint_id in avail_tp_channels['timepoints'], msg
        if isinstance(timepoint_id, int):
            timepoint_id = [timepoint_id]
        self.timepoint_id = timepoint_id
        # Convert channel to int if there's only one value present
        if isinstance(input_channel_id, (list, tuple)):
            if len(input_channel_id) == 1:
                input_channel_id = input_channel_id[0]
        msg = 'input_channel_id is not available'
        assert input_channel_id in avail_tp_channels['channels'], msg

        msg = 'output_channel_id is already present'
        assert output_channel_id not in avail_tp_channels['channels'], msg
        if isinstance(input_channel_id, (list, tuple)):
            msg = 'input and output channel ids are not of same length'
            assert len(input_channel_id) == len(output_channel_id), msg
        else:
            input_channel_id = [input_channel_id]
            output_channel_id = [output_channel_id]

        self.input_channel_id = input_channel_id
        self.output_channel_id = output_channel_id
        self.focal_plane_idx = focal_plane_idx
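
This init normalizes scalar and sequence channel ids into equal-length lists so that mask(input_channel_id[i]) maps onto output_channel_id[i]. A compact sketch of that normalization; the helper name and the example ids are illustrative.

def to_matching_lists(input_channel_id, output_channel_id):
    """Return both ids as equal-length lists, mirroring the checks above."""
    if isinstance(input_channel_id, (list, tuple)):
        assert len(input_channel_id) == len(output_channel_id), \
            'input and output channel ids are not of same length'
        return list(input_channel_id), list(output_channel_id)
    return [input_channel_id], [output_channel_id]

print(to_matching_lists(1, 4))            # ([1], [4])
print(to_matching_lists([1, 2], [4, 5]))  # ([1, 2], [4, 5])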
Example #9
    def predict_on_full_image(self,
                              image_meta,
                              test_samples,
                              focal_plane_idx=None,
                              depth=None,
                              per_tile_overlap=1 / 8,
                              flat_field_correct=False,
                              base_image_dir=None,
                              place_operation='mean'):
        """Tile and run inference on tiles and assemble the full image

        :param pd.DataFrame image_meta: Df with individual image info:
         'timepoint', 'channel_num', 'sample_num', 'slice_num', 'fname',
         'size_x_microns', 'size_y_microns', 'size_z_microns'
        :param list test_samples: list of sample numbers to be used in the
         test set
        :param int focal_plane_idx: focal plane to be used
        :param int depth: if 3D - num of slices used for tiling
        :param float per_tile_overlap: percent overlap between successive tiles
        :param bool flat_field_correct: indicator for applying flat field
         correction
        :param str base_image_dir: base directory where images are stored
        :param str place_operation: in ['mean', 'max']. mean for regression tasks,
         max for segmentation tasks
        """

        assert place_operation in ['mean', 'max'], \
            'only mean and max are allowed: %s' % place_operation
        if 'timepoints' not in self.config['dataset']:
            timepoint_ids = -1
        else:
            timepoint_ids = self.config['dataset']['timepoints']

        ip_channel_ids = self.config['dataset']['input_channels']
        op_channel_ids = self.config['dataset']['target_channels']
        tp_channel_ids = aux_utils.validate_metadata_indices(
            image_meta, time_ids=timepoint_ids)
        tp_idx = tp_channel_ids['timepoints']
        tile_size = [
            self.config['network']['height'], self.config['network']['width']
        ]

        if depth is not None:
            assert 'depth' in self.config['network']
            tile_size.insert(0, depth)

        step_size = (1 - per_tile_overlap) * np.array(tile_size)
        step_size = step_size.astype('int')
        step_size[step_size < 1] = 1

        overlap_size = tile_size - step_size
        batch_size = self.config['trainer']['batch_size']

        if flat_field_correct:
            assert base_image_dir is not None
            ff_dir = os.path.join(base_image_dir, 'flat_field_images')
        else:
            ff_dir = None

        for tp in tp_idx:
            # get the meta for all images in tp_dir and channel_dir
            row_idx_ip0 = aux_utils.get_row_idx(image_meta,
                                                tp,
                                                ip_channel_ids[0],
                                                slice_idx=focal_plane_idx)
            ip0_meta = image_meta[row_idx_ip0]

            # get rows corr. to test_samples from this DF
            test_row_ip0 = ip0_meta.loc[ip0_meta['sample_num'].isin(
                test_samples)]
            test_ip0_fnames = test_row_ip0['fname'].tolist()
            test_image_fnames = ([
                fname.split(os.sep)[-1] for fname in test_ip0_fnames
            ])
            tp_dir = str(os.sep).join(test_ip0_fnames[0].split(os.sep)[:-2])
            test_image = np.load(test_ip0_fnames[0])
            _, crop_indices = tile_utils.tile_image(test_image,
                                                    tile_size,
                                                    step_size,
                                                    return_index=True)
            pred_dir = os.path.join(self.config['trainer']['model_dir'],
                                    'predicted_images', 'tp_{}'.format(tp))
            for fname in test_image_fnames:
                target_image = self._read_one(tp_dir, op_channel_ids, fname,
                                              ff_dir)
                input_image = self._read_one(tp_dir, ip_channel_ids, fname,
                                             ff_dir)
                pred_tiles = self._pred_image(input_image, crop_indices,
                                              batch_size)
                pred_image = self._stitch_image(pred_tiles, crop_indices,
                                                input_image.shape, batch_size,
                                                tile_size, overlap_size,
                                                place_operation)
                pred_fname = '{}.npy'.format(fname.split('.')[0])
                for idx, op_ch in enumerate(op_channel_ids):
                    op_dir = os.path.join(pred_dir, 'channel_{}'.format(op_ch))
                    if not os.path.exists(op_dir):
                        os.makedirs(op_dir)
                    np.save(os.path.join(op_dir, pred_fname), pred_image[idx])
                    save_predicted_images([input_image], [target_image],
                                          [pred_image],
                                          os.path.join(op_dir, 'collage'),
                                          output_fname=fname.split('.')[0])
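
A quick numeric check of the step/overlap arithmetic used above, with the default per_tile_overlap of 1/8 and an assumed 256 x 256 network tile size.

import numpy as np

tile_size = [256, 256]
per_tile_overlap = 1 / 8
step_size = ((1 - per_tile_overlap) * np.array(tile_size)).astype('int')
step_size[step_size < 1] = 1
overlap_size = tile_size - step_size
print(step_size)     # [224 224]
print(overlap_size)  # [32 32]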