Esempio n. 1
0
    def __init__(self, root_dir, phase, transformer_config, mirror_padding=(0, 32, 32), expand_dims=True):
        """
        Load the images (and, outside the 'test' phase, the masks) stored under ``root_dir``.

        :param root_dir: existing directory containing an 'images' sub-directory and, for the
            'train' and 'val' phases, a 'masks' sub-directory
        :param phase: one of 'train', 'val' or 'test'
        :param transformer_config: configuration handed to ``transforms.get_transformer``
        :param mirror_padding: 3-tuple (z, y, x) giving the number of elements reflected onto both
            ends of each axis; only honoured in the 'test' phase, None disables the padding
        :param expand_dims: forwarded unchanged to ``self._load_files``
        """
        assert os.path.isdir(root_dir), 'root_dir is not a directory'
        assert phase in ['train', 'val', 'test']

        # mirror padding is a 'test'-only feature, so it is dropped for the other phases
        if phase in ('train', 'val'):
            mirror_padding = None
        elif mirror_padding is not None:
            assert len(mirror_padding) == 3, f"Invalid mirror_padding: {mirror_padding}"

        self.mirror_padding = mirror_padding
        self.phase = phase

        # the raw images live in <root_dir>/images
        images_dir = os.path.join(root_dir, 'images')
        assert os.path.isdir(images_dir)
        loaded_images, image_paths = self._load_files(images_dir, expand_dims)
        self.images = loaded_images
        self.paths = image_paths
        self.file_path = images_dir

        # global statistics of the raw data parametrize the transformer
        min_value, max_value, mean, std = calculate_stats(self.images)
        logger.info(f'Input stats: min={min_value}, max={max_value}, mean={mean}, std={std}')

        transformer = transforms.get_transformer(
            transformer_config, min_value=min_value, max_value=max_value, mean=mean, std=std)
        self.raw_transform = transformer.raw_transform()

        if phase == 'test':
            # no ground truth is loaded in the 'test' phase
            self.masks = None
            self.masks_transform = None

            if self.mirror_padding is not None:
                pad_z, pad_y, pad_x = self.mirror_padding
                pad_width = ((pad_z, pad_z), (pad_y, pad_y), (pad_x, pad_x))
                self.images = [np.pad(image, pad_width=pad_width, mode='reflect') for image in self.images]
        else:
            # the masks live in <root_dir>/masks and must pair up one-to-one with the images
            masks_dir = os.path.join(root_dir, 'masks')
            assert os.path.isdir(masks_dir)
            self.masks, _ = self._load_files(masks_dir, expand_dims)
            assert len(self.images) == len(self.masks)
            self.masks_transform = transformer.label_transform()
    def __init__(self,
                 file_path,
                 phase,
                 slice_builder_config,
                 transformer_config,
                 mirror_padding=(16, 32, 32),
                 raw_internal_path='raw',
                 label_internal_path='label',
                 weight_internal_path=None):
        """
        Load the volumes of an H5 file and pre-compute the patch slices.

        In the 'train' and 'val' phases every label volume is converted into two targets:

        * ``self.labels_seg``: float64 mask that is 1 for every element lying within a Euclidean
          distance of 2 of a non-zero label element and 0 elsewhere (label dilation);
        * ``self.labels_reg``: ``vigra.filters.distanceTransform`` of the first entry along axis 0
          of the label volume, with a leading singleton axis added back
          (the vector distance transform is currently not included).

        In the 'test' phase only the raw volumes are loaded (optionally mirror padded) and
        ``self.labels``, ``self.labels_seg``, ``self.labels_reg`` and ``self.weight_maps`` are None.

        :param file_path: path to H5 file containing raw data as well as labels and per pixel weights (optional)
        :param phase: 'train' for training, 'val' for validation, 'test' for testing; data augmentation is performed
            only during the 'train' phase
        :param slice_builder_config: configuration of the SliceBuilder
        :param transformer_config: data augmentation configuration
        :param mirror_padding (int or tuple): number of voxels padded to each axis ('test' phase only)
        :param raw_internal_path (str or list): H5 internal path to the raw dataset
        :param label_internal_path (str or list): H5 internal path to the label dataset
        :param weight_internal_path (str or list): H5 internal path to the per pixel weights
        """
        assert phase in ['train', 'val', 'test']
        if phase in ['train', 'val']:
            mirror_padding = None

        if mirror_padding is not None:
            if isinstance(mirror_padding, int):
                mirror_padding = (mirror_padding, ) * 3
            else:
                assert len(mirror_padding) == 3, f"Invalid mirror_padding: {mirror_padding}"

        self.mirror_padding = mirror_padding
        self.phase = phase
        self.file_path = file_path

        # convert raw_internal_path, label_internal_path and weight_internal_path to list for ease of computation
        if isinstance(raw_internal_path, str):
            raw_internal_path = [raw_internal_path]
        if isinstance(label_internal_path, str):
            label_internal_path = [label_internal_path]
        if isinstance(weight_internal_path, str):
            weight_internal_path = [weight_internal_path]

        internal_paths = list(raw_internal_path)
        if label_internal_path is not None:
            internal_paths.extend(label_internal_path)
        if weight_internal_path is not None:
            internal_paths.extend(weight_internal_path)

        input_file = self.create_h5_file(file_path, internal_paths)

        self.raws = self.fetch_and_check(input_file, raw_internal_path)

        # calculate global min, max, mean and std for normalization
        min_value, max_value, mean, std = calculate_stats(self.raws)
        logger.info(
            f'Input stats: min={min_value}, max={max_value}, mean={mean}, std={std}'
        )

        self.transformer = transforms.get_transformer(transformer_config,
                                                      min_value=min_value,
                                                      max_value=max_value,
                                                      mean=mean,
                                                      std=std)
        self.raw_transform = self.transformer.raw_transform()

        if phase != 'test':
            # create label/weight transform only in train/val phase
            self.label_transform = self.transformer.label_transform()

            # The segmentation and the regression target are both derived from the same label
            # data and neither derivation modifies its input, so the labels are read only once
            # and no defensive copy is needed.
            labels = self.fetch_and_check(input_file, label_internal_path)

            if weight_internal_path is not None:
                # look for the weight map in the raw file
                self.weight_maps = self.fetch_and_check(
                    input_file, weight_internal_path)
                self.weight_transform = self.transformer.weight_transform()
            else:
                self.weight_maps = None

            self._check_dimensionality(self.raws, labels)

            # Dilate the labels for segmentation. Intermediate arrays are kept in local
            # variables so that they are not retained on the instance after construction.
            self.dilated_labels = []
            for label_volume in labels:
                # distance of every element to the closest non-zero label element
                distance_to_label = ndimage.distance_transform_edt(
                    np.logical_not(label_volume))
                self.dilated_labels.append(
                    np.logical_not(distance_to_label > 2).astype(np.float64))
            self.labels_seg = self.dilated_labels

            # EDT of labels for regression (not including vector edt currently)
            self.edt_labels = []
            for label_volume in labels:
                # only the first entry along axis 0 is used
                # NOTE(review): presumably axis 0 is the channel axis - confirm
                edt = vigra.filters.distanceTransform(
                    label_volume[0].astype(np.float32))
                self.edt_labels.append(np.expand_dims(edt, axis=0))
            self.labels_reg = self.edt_labels

            # build slice indices for the two label data sets
            # NOTE(review): `label_slices_seg` / `label_slices_reg` must be attributes of the
            # object returned by get_slice_builder - confirm against the slice builder class
            slice_builder_seg = get_slice_builder(self.raws, self.labels_seg,
                                                  self.weight_maps,
                                                  slice_builder_config)
            slice_builder_reg = get_slice_builder(self.raws, self.labels_reg,
                                                  self.weight_maps,
                                                  slice_builder_config)
            self.label_slices_seg = slice_builder_seg.label_slices_seg
            self.label_slices_reg = slice_builder_reg.label_slices_reg
        else:
            # 'test' phase used only for predictions so ignore the label dataset
            self.labels = None
            self.labels_seg = None
            self.labels_reg = None
            self.weight_maps = None

            # add mirror padding if needed
            if self.mirror_padding is not None:
                z, y, x = self.mirror_padding
                pad_width = ((z, z), (y, y), (x, x))
                padded_volumes = []
                for raw in self.raws:
                    if raw.ndim == 4:
                        # pad every entry along axis 0 separately, leaving axis 0 itself unpadded
                        channels = [
                            np.pad(r, pad_width=pad_width, mode='reflect')
                            for r in raw
                        ]
                        padded_volume = np.stack(channels)
                    else:
                        padded_volume = np.pad(raw,
                                               pad_width=pad_width,
                                               mode='reflect')

                    padded_volumes.append(padded_volume)

                self.raws = padded_volumes

        # The raw (and weight) slices always come from a builder created without labels.
        # NOTE(review): if the configured slice builder selects patches based on label content,
        # these raw slices may not line up with the label slices above - confirm.
        slice_builder = get_slice_builder(self.raws, None, self.weight_maps,
                                          slice_builder_config)
        self.raw_slices = slice_builder.raw_slices
        self.weight_slices = slice_builder.weight_slices

        self.patch_count = len(self.raw_slices)
        logger.info(f'Number of patches: {self.patch_count}')
Esempio n. 3
0
    def __init__(self, file_path,
                 phase,
                 slice_builder_config,
                 transformer_config,
                 mirror_padding=(16, 32, 32),
                 raw_internal_path='raw',
                 label_internal_path='label',
                 weight_internal_path=None):
        """
        Load the volumes of an H5 file and pre-compute the patch slices.

        :param file_path: path to H5 file containing raw data as well as labels and per pixel weights (optional)
        :param phase: 'train' for training, 'val' for validation, 'test' for testing; data augmentation is performed
            only during the 'train' phase
        :param slice_builder_config: configuration of the SliceBuilder
        :param transformer_config: data augmentation configuration
        :param mirror_padding (int or tuple): number of voxels padded to each axis; only applied in the
            'test' phase, an int is used for all three axes
        :param raw_internal_path (str or list): H5 internal path to the raw dataset
        :param label_internal_path (str or list): H5 internal path to the label dataset
        :param weight_internal_path (str or list): H5 internal path to the per pixel weights
        """
        assert phase in ['train', 'val', 'test']

        # normalize the padding: disabled outside 'test', a single int applies to all three axes
        if phase in ('train', 'val'):
            mirror_padding = None
        elif isinstance(mirror_padding, int):
            mirror_padding = (mirror_padding,) * 3
        elif mirror_padding is not None:
            assert len(mirror_padding) == 3, f"Invalid mirror_padding: {mirror_padding}"

        self.file_path = file_path
        self.phase = phase
        self.mirror_padding = mirror_padding

        # a single internal path may be given as a plain string; work with lists from here on
        if isinstance(raw_internal_path, str):
            raw_internal_path = [raw_internal_path]
        if isinstance(label_internal_path, str):
            label_internal_path = [label_internal_path]
        if isinstance(weight_internal_path, str):
            weight_internal_path = [weight_internal_path]

        # every internal path that was configured, raw paths first
        internal_paths = list(raw_internal_path)
        for optional_paths in (label_internal_path, weight_internal_path):
            if optional_paths is not None:
                internal_paths.extend(optional_paths)

        input_file = self.create_h5_file(file_path, internal_paths)

        self.raws = self.fetch_datasets(input_file, raw_internal_path)

        # global min, max, mean and std of the raw data parametrize the transformer
        min_value, max_value, mean, std = calculate_stats(self.raws)
        logger.info(f'Input stats: min={min_value}, max={max_value}, mean={mean}, std={std}')

        self.transformer = transforms.get_transformer(
            transformer_config, min_value=min_value, max_value=max_value, mean=mean, std=std)
        self.raw_transform = self.transformer.raw_transform()

        if phase == 'test':
            # predictions only: neither labels nor weights are loaded
            self.labels = None
            self.weight_maps = None

            if self.mirror_padding is not None:
                pad_z, pad_y, pad_x = self.mirror_padding
                pad_width = ((pad_z, pad_z), (pad_y, pad_y), (pad_x, pad_x))

                def reflect_pad(volume):
                    # 4D input is padded entry by entry along axis 0, which itself stays unpadded
                    if volume.ndim == 4:
                        return np.stack([np.pad(c, pad_width=pad_width, mode='reflect') for c in volume])
                    return np.pad(volume, pad_width=pad_width, mode='reflect')

                self.raws = [reflect_pad(raw) for raw in self.raws]
        else:
            # label and weight transforms are only needed for training/validation
            self.label_transform = self.transformer.label_transform()
            self.labels = self.fetch_datasets(input_file, label_internal_path)

            if weight_internal_path is None:
                self.weight_maps = None
            else:
                # the weight maps are read from the same file as the raw data
                self.weight_maps = self.fetch_datasets(input_file, weight_internal_path)
                self.weight_transform = self.transformer.weight_transform()

            self._check_dimensionality(self.raws, self.labels)

        # pre-compute the slices of all patches in the raw, label and weight data sets
        builder = get_slice_builder(self.raws, self.labels, self.weight_maps, slice_builder_config)
        self.raw_slices = builder.raw_slices
        self.label_slices = builder.label_slices
        self.weight_slices = builder.weight_slices

        self.patch_count = len(self.raw_slices)
        logger.info(f'Number of patches: {self.patch_count}')
Esempio n. 4
0
    def __init__(self,
                 root_dir,
                 phase,
                 transformer_config,
                 mirror_padding=(0, 32, 32),
                 expand_dims=True,
                 instance_ratio=None,
                 random_seed=0):
        """
        Load the images (and, outside the 'test' phase, the masks) stored under ``root_dir``.

        :param root_dir: existing directory containing an 'images' sub-directory and, for the
            'train' and 'val' phases, a 'masks' sub-directory
        :param phase: one of 'train', 'val' or 'test'
        :param transformer_config: configuration handed to ``transforms.get_transformer``
        :param mirror_padding: 3-tuple (z, y, x) giving the number of elements reflected onto both
            ends of each axis; only honoured in the 'test' phase, None disables the padding
        :param expand_dims: forwarded unchanged to ``self._load_files``
        :param instance_ratio: value in (0, 1] passed to ``sample_instances`` for every mask;
            only used in the 'train' phase, None disables the sampling
        :param random_seed: seed of the ``np.random.RandomState`` shared by all ``sample_instances`` calls
        """
        assert os.path.isdir(root_dir), f'{root_dir} is not a directory'
        assert phase in ['train', 'val', 'test']

        # mirror padding is a 'test'-only feature, so it is dropped for the other phases
        if phase in ('train', 'val'):
            mirror_padding = None
        elif mirror_padding is not None:
            assert len(mirror_padding) == 3, f"Invalid mirror_padding: {mirror_padding}"

        self.mirror_padding = mirror_padding
        self.phase = phase

        # the raw images live in <root_dir>/images
        images_dir = os.path.join(root_dir, 'images')
        assert os.path.isdir(images_dir)
        loaded_images, image_paths = self._load_files(images_dir, expand_dims)
        self.images = loaded_images
        self.paths = image_paths
        self.file_path = images_dir
        self.instance_ratio = instance_ratio

        # global statistics of the raw data parametrize the transformer
        min_value, max_value, mean, std = calculate_stats(self.images)
        logger.info(f'Input stats: min={min_value}, max={max_value}, mean={mean}, std={std}')

        transformer = transforms.get_transformer(
            transformer_config, min_value=min_value, max_value=max_value, mean=mean, std=std)
        self.raw_transform = transformer.raw_transform()

        if phase == 'test':
            # no ground truth is loaded in the 'test' phase
            self.masks = None
            self.masks_transform = None

            if self.mirror_padding is not None:
                pad_z, pad_y, pad_x = self.mirror_padding
                pad_width = ((pad_z, pad_z), (pad_y, pad_y), (pad_x, pad_x))
                self.images = [np.pad(image, pad_width=pad_width, mode='reflect') for image in self.images]
        else:
            # the masks live in <root_dir>/masks
            masks_dir = os.path.join(root_dir, 'masks')
            assert os.path.isdir(masks_dir)
            self.masks, _ = self._load_files(masks_dir, expand_dims)

            # sparse object supervision is only allowed while training
            if phase == 'train' and self.instance_ratio is not None:
                assert 0 < self.instance_ratio <= 1
                rng = np.random.RandomState(random_seed)
                sampled_masks = []
                for mask in self.masks:
                    sampled_masks.append(sample_instances(mask, self.instance_ratio, rng))
                self.masks = sampled_masks

            # images and masks must pair up one-to-one
            assert len(self.images) == len(self.masks)
            self.masks_transform = transformer.label_transform()
Esempio n. 5
0
    def __init__(self,
                 file_path,
                 phase,
                 slice_builder_config,
                 transformer_config,
                 mirror_padding=(16, 32, 32),
                 raw_internal_path='raw',
                 label_internal_path='label',
                 weight_internal_path=None):
        """
        Load the volumes of an H5 file, dilate the labels and pre-compute the patch slices.

        In the 'train' and 'val' phases every label volume is replaced by a float64 mask that is 1
        for every element lying within a Euclidean distance of 2 of a non-zero label element and
        0 elsewhere. The dilation used to be guarded by ``phase == '!test'``, a condition that can
        never hold for a valid phase, so it silently never ran; it is now applied whenever labels
        are loaded.

        :param file_path: path to H5 file containing raw data as well as labels and per pixel weights (optional)
        :param phase: 'train' for training, 'val' for validation, 'test' for testing; data augmentation is performed
            only during the 'train' phase
        :param slice_builder_config: configuration of the SliceBuilder
        :param transformer_config: data augmentation configuration
        :param mirror_padding (int or tuple): number of voxels padded to each axis ('test' phase only)
        :param raw_internal_path (str or list): H5 internal path to the raw dataset
        :param label_internal_path (str or list): H5 internal path to the label dataset
        :param weight_internal_path (str or list): H5 internal path to the per pixel weights
        """
        assert phase in ['train', 'val', 'test']
        if phase in ['train', 'val']:
            mirror_padding = None

        if mirror_padding is not None:
            if isinstance(mirror_padding, int):
                mirror_padding = (mirror_padding, ) * 3
            else:
                assert len(mirror_padding) == 3, f"Invalid mirror_padding: {mirror_padding}"

        self.mirror_padding = mirror_padding
        self.phase = phase
        self.file_path = file_path

        # convert raw_internal_path, label_internal_path and weight_internal_path to list for ease of computation
        if isinstance(raw_internal_path, str):
            raw_internal_path = [raw_internal_path]
        if isinstance(label_internal_path, str):
            label_internal_path = [label_internal_path]
        if isinstance(weight_internal_path, str):
            weight_internal_path = [weight_internal_path]

        internal_paths = list(raw_internal_path)
        if label_internal_path is not None:
            internal_paths.extend(label_internal_path)
        if weight_internal_path is not None:
            internal_paths.extend(weight_internal_path)

        input_file = self.create_h5_file(file_path, internal_paths)

        self.raws = self.fetch_and_check(input_file, raw_internal_path)

        # calculate global min, max, mean and std for normalization
        min_value, max_value, mean, std = calculate_stats(self.raws)
        logger.info(
            f'Input stats: min={min_value}, max={max_value}, mean={mean}, std={std}'
        )

        self.transformer = transforms.get_transformer(transformer_config,
                                                      min_value=min_value,
                                                      max_value=max_value,
                                                      mean=mean,
                                                      std=std)
        self.raw_transform = self.transformer.raw_transform()

        if phase != 'test':
            # create label/weight transform only in train/val phase
            self.label_transform = self.transformer.label_transform()
            self.labels = self.fetch_and_check(input_file, label_internal_path)

            if weight_internal_path is not None:
                # look for the weight map in the raw file
                self.weight_maps = self.fetch_and_check(
                    input_file, weight_internal_path)
                self.weight_transform = self.transformer.weight_transform()
            else:
                self.weight_maps = None

            self._check_dimensionality(self.raws, self.labels)

            # Dilate the ground truth. The input volumes are not modified, so no copy is needed,
            # and the intermediate arrays stay local so they are not retained on the instance.
            dilated_labels = []
            for label_volume in self.labels:
                # distance of every element to the closest non-zero label element
                distance_to_label = ndimage.distance_transform_edt(
                    np.logical_not(label_volume))
                dilated_labels.append(
                    np.logical_not(distance_to_label > 2).astype(np.float64))
            self.labels = dilated_labels
        else:
            # 'test' phase used only for predictions so ignore the label dataset
            self.labels = None
            self.weight_maps = None

            # add mirror padding if needed
            if self.mirror_padding is not None:
                z, y, x = self.mirror_padding
                pad_width = ((z, z), (y, y), (x, x))
                padded_volumes = []
                for raw in self.raws:
                    if raw.ndim == 4:
                        # pad every entry along axis 0 separately, leaving axis 0 itself unpadded
                        channels = [
                            np.pad(r, pad_width=pad_width, mode='reflect')
                            for r in raw
                        ]
                        padded_volume = np.stack(channels)
                    else:
                        padded_volume = np.pad(raw,
                                               pad_width=pad_width,
                                               mode='reflect')

                    padded_volumes.append(padded_volume)

                self.raws = padded_volumes

        # build slice indices for raw and label data sets
        slice_builder = get_slice_builder(self.raws, self.labels,
                                          self.weight_maps,
                                          slice_builder_config)
        self.raw_slices = slice_builder.raw_slices
        self.label_slices = slice_builder.label_slices
        self.weight_slices = slice_builder.weight_slices

        self.patch_count = len(self.raw_slices)
        logger.info(f'Number of patches: {self.patch_count}')
Esempio n. 6
0
    def __init__(self,
                 file_path,
                 phase,
                 slice_builder_config,
                 transformer_config,
                 mirror_padding=(16, 32, 32),
                 raw_internal_path='raw',
                 label_internal_path='label',
                 weight_internal_path=None,
                 instance_ratio=None,
                 random_seed=0):
        """
        Load a single NIfTI volume for prediction and pre-compute the patch slices.

        Only the 'test' phase is implemented. The 'train' and 'val' phases, as well as the
        'AroundCenterSliceBuilder' (which needs patch centers that this loader has no file to
        read from), raise ``NotImplementedError``.

        :param file_path: path to nifti file from dicom file
        :param phase: 'train' for training, 'val' for validation, 'test' for testing; only 'test' is supported
        :param slice_builder_config: configuration of the SliceBuilder; must contain the key 'name'
        :param transformer_config: data augmentation configuration
        :param mirror_padding (int or tuple): number of voxels padded to each axis
        :param raw_internal_path (str or list): accepted but not used by this implementation
        :param label_internal_path (str or list): accepted but not used by this implementation
        :param weight_internal_path (str or list): accepted but not used by this implementation
        :param instance_ratio: a number between (0, 1]: specifies a fraction of ground truth instances to be
            sampled from the dense ground truth labels; stored on the instance but not applied, because
            no labels are loaded in the 'test' phase
        :param random_seed: accepted but not used by this implementation
        :raises NotImplementedError: for any phase other than 'test' and for 'AroundCenterSliceBuilder'
        """
        assert phase in ['train', 'val', 'test']
        if phase != 'test':
            # NotImplementedError is a RuntimeError, i.e. the type the former bare `raise`
            # produced when no exception was active, so existing handlers keep working
            raise NotImplementedError(
                f"phase '{phase}' is not supported yet, only 'test' is implemented")

        if mirror_padding is not None:
            if isinstance(mirror_padding, int):
                mirror_padding = (mirror_padding, ) * 3
            else:
                assert len(mirror_padding) == 3, f"Invalid mirror_padding: {mirror_padding}"

        self.mirror_padding = mirror_padding
        self.phase = phase
        self.file_path = file_path

        self.instance_ratio = instance_ratio

        nifti = nib.load(file_path)
        arr = nifti.get_fdata()
        # flip along axis 1, then reverse the axis order
        # NOTE(review): presumably converts (x, y, z) to the (z, y, x) layout used elsewhere - confirm
        self.raws = [np.fliplr(arr).transpose((2, 1, 0))]
        self.affine = [nifti.affine]
        self.header = [nifti.header]

        # no global statistics are computed for this dataset
        min_value, max_value, mean, std = None, None, None, None

        self.transformer = transforms.get_transformer(transformer_config,
                                                      min_value=min_value,
                                                      max_value=max_value,
                                                      mean=mean,
                                                      std=std)
        self.raw_transform = self.transformer.raw_transform()

        # 'test' phase used only for predictions so there is no label dataset
        self.labels = None
        self.weight_maps = None

        # add mirror padding if needed
        if self.mirror_padding is not None:
            z, y, x = self.mirror_padding
            pad_width = ((z, z), (y, y), (x, x))
            # the 3-axis transpose above only succeeds for 3D data, so every volume is 3D here
            self.raws = [
                np.pad(raw, pad_width=pad_width, mode='reflect')
                for raw in self.raws
            ]

        # build slice indices for the raw data set
        assert 'name' in slice_builder_config
        if slice_builder_config['name'] == "AroundCenterSliceBuilder":
            # the centers used to be fetched from an undefined `input_file`, which failed
            # with a NameError; fail with an explicit message instead
            raise NotImplementedError(
                "AroundCenterSliceBuilder is not supported for NIfTI input: "
                "there is no file to read the patch centers from")

        slice_builder = get_slice_builder(self.raws, self.labels,
                                          self.weight_maps,
                                          slice_builder_config)
        self.raw_slices = slice_builder.raw_slices
        self.label_slices = slice_builder.label_slices
        self.weight_slices = slice_builder.weight_slices

        self.patch_count = len(self.raw_slices)
        logger.info(f'Number of patches: {self.patch_count}')
Esempio n. 7
0
    def __init__(self,
                 file_path,
                 phase,
                 slice_builder_config,
                 transformer_config,
                 raw_internal_path='raw',
                 label_internal_path='label',
                 weight_internal_path=None,
                 mirror_padding=False,
                 pad_width=20):
        """
        Load the raw (and, outside of the 'test' phase, label and optional weight) datasets from an H5 file
        into memory, create the corresponding transforms and build the patch slices.

        :param file_path: path to H5 file containing raw data as well as labels and per pixel weights (optional)
        :param phase: 'train' for training, 'val' for validation, 'test' for testing; data augmentation is performed
            only during the 'train' phase
        :param slice_builder_config: configuration of the SliceBuilder
        :param transformer_config: data augmentation configuration
        :param raw_internal_path (str or list): H5 internal path to the raw dataset
        :param label_internal_path (str or list): H5 internal path to the label dataset
        :param weight_internal_path (str or list): H5 internal path to the per pixel weights; ignored during the
            'test' phase
        :param mirror_padding (bool): pad with the reflection of the vector mirrored on the first and last values
            along each axis. Only applicable during the 'test' phase
        :param pad_width: number of voxels padded to the edges of each axis (only if `mirror_padding=True`)

        After construction `labels`, `weight_maps`, `label_transform` and `weight_transform` always exist;
        each of them is `None` when it does not apply to the given phase/configuration.
        """
        assert phase in ['train', 'val', 'test']
        self.phase = phase
        self.file_path = file_path
        self.mirror_padding = mirror_padding
        self.pad_width = pad_width

        # Define every optional attribute up front so that it exists (as None) in every phase;
        # otherwise accessing e.g. `label_transform` in the 'test' phase raises AttributeError.
        self.labels = None
        self.weight_maps = None
        self.label_transform = None
        self.weight_transform = None

        # convert raw_internal_path, label_internal_path and weight_internal_path to list for ease of computation
        if isinstance(raw_internal_path, str):
            raw_internal_path = [raw_internal_path]
        if isinstance(label_internal_path, str):
            label_internal_path = [label_internal_path]
        if isinstance(weight_internal_path, str):
            weight_internal_path = [weight_internal_path]

        with h5py.File(file_path, 'r') as input_file:
            # WARN: we load everything into memory due to hdf5 bug when reading H5 from multiple subprocesses, i.e.
            # File "h5py/_proxy.pyx", line 84, in h5py._proxy.H5PY_H5Dread
            # OSError: Can't read data (inflate() failed)
            self.raws = [
                input_file[internal_path][...]
                for internal_path in raw_internal_path
            ]
            # calculate global min, max, mean and std for normalization
            # (computed on the raw data as loaded, i.e. before any mirror padding is added)
            min_value, max_value, mean, std = self._calculate_stats(self.raws)
            logger.info(
                f'Input stats: min={min_value}, max={max_value}, mean={mean}, std={std}'
            )

            self.transformer = transforms.get_transformer(transformer_config,
                                                          min_value=min_value,
                                                          max_value=max_value,
                                                          mean=mean,
                                                          std=std)
            self.raw_transform = self.transformer.raw_transform()

            if phase != 'test':
                # create label/weight transform only in train/val phase
                self.label_transform = self.transformer.label_transform()
                self.labels = [
                    input_file[internal_path][...]
                    for internal_path in label_internal_path
                ]

                if weight_internal_path is not None:
                    # look for the weight map in the raw file
                    self.weight_maps = [
                        input_file[internal_path][...]
                        for internal_path in weight_internal_path
                    ]
                    self.weight_transform = self.transformer.weight_transform()

                self._check_dimensionality(self.raws, self.labels)
            elif self.mirror_padding:
                # 'test' phase is used only for predictions, so the label and weight datasets are
                # ignored (they stay None); add mirror padding to the raw volumes if requested
                padded_volumes = []
                for raw in self.raws:
                    if raw.ndim == 4:
                        # 4D input: pad every sub-volume along the first axis separately, so that
                        # the first axis itself is left unpadded, then stack them back together
                        channels = [
                            np.pad(r,
                                   pad_width=self.pad_width,
                                   mode='reflect') for r in raw
                        ]
                        padded_volume = np.stack(channels)
                    else:
                        padded_volume = np.pad(raw,
                                               pad_width=self.pad_width,
                                               mode='reflect')

                    padded_volumes.append(padded_volume)

                self.raws = padded_volumes

            # build slice indices for raw and label data sets
            slice_builder = _get_slice_builder(self.raws, self.labels,
                                               self.weight_maps,
                                               slice_builder_config)
            self.raw_slices = slice_builder.raw_slices
            self.label_slices = slice_builder.label_slices
            self.weight_slices = slice_builder.weight_slices

            self.patch_count = len(self.raw_slices)
            logger.info(f'Number of patches: {self.patch_count}')