Example #1
 def pre_process(self, input_chunk, transformations_list=None):
     if transformations_list is None:
         return input_chunk
     else:
         mean = input_chunk.mean()
         std = input_chunk.std()
         transformer = transforms.get_transformer(transformations_list,
                                                  mean, std, "test")
         raw_transform = transformer.raw_transform()
         tarray = raw_transform(input_chunk.array)
         tarray = np.squeeze(tarray.numpy(), axis=0)
         transformed_chunk = Chunk(tarray,
                                   global_offset=input_chunk.global_offset)
         return transformed_chunk
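For context, here is a minimal NumPy-only sketch of the normalize-then-squeeze flow in Example #1; the project's transforms.get_transformer and Chunk helpers are assumed from the example and deliberately not reimplemented here.

import numpy as np

def standardize(array):
    # z-score normalization with the chunk's own mean/std, mirroring the example
    mean, std = array.mean(), array.std()
    return (array - mean) / max(float(std), 1e-8)

volume = np.random.rand(1, 8, 64, 64).astype(np.float32)  # toy CDHW chunk
normalized = standardize(volume)
# the example squeezes the singleton channel axis after the raw transform
print(np.squeeze(normalized, axis=0).shape)  # (8, 64, 64)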
Example #2
    def __init__(self, data, return_weight, transformer_config, phase='train'):
        if return_weight:
            assert (data['mr'].shape[0] == data['mask'].shape[0]
                    == data['weight'].shape[0])
            self.weight = data['weight']
        else:
            assert data['mr'].shape[0] == data['mask'].shape[0]
        self.raw = data['mr']
        self.labels = data['mask']

        self.phase = phase
        self.return_weight = return_weight
        if self.phase == 'train':
            self.transformer = transforms.get_transformer(transformer_config,
                                                          mean=0,
                                                          std=1,
                                                          phase=phase)
            self.raw_transform = self.transformer.raw_transform()
            self.label_transform = self.transformer.label_transform()
            if return_weight:
                self.weight_transform = self.transformer.weight_transform()
Example #3
    def __init__(self, file_path, patch_shape, stride_shape, phase, transformer_config,
                 raw_internal_path='raw', label_internal_path='label',
                 weight_internal_path=None, slice_builder_cls=SliceBuilder,
                 mirror_padding=False, pad_width=20):
        """
        :param file_path: path to H5 file containing raw data as well as labels and per pixel weights (optional)
        :param patch_shape: the shape of the patch DxHxW
        :param stride_shape: the shape of the stride DxHxW
        :param phase: 'train' for training, 'val' for validation, 'test' for testing; data augmentation is performed
            only during the 'train' phase
        :param transformer_config: data augmentation configuration
        :param raw_internal_path (str or list): H5 internal path to the raw dataset
        :param label_internal_path (str or list): H5 internal path to the label dataset
        :param weight_internal_path (str or list): H5 internal path to the per pixel weights
        :param slice_builder_cls: defines how to sample the patches from the volume
        :param mirror_padding (bool): pad with the reflection of the vector mirrored on the first and last values
            along each axis. Only applicable during the 'test' phase
        :param pad_width: number of voxels padded to the edges of each axis (only if `mirror_padding=True`)
        """
        assert phase in ['train', 'val', 'test']
        self._check_patch_shape(patch_shape)
        self.phase = phase
        self.file_path = file_path

        self.mirror_padding = mirror_padding
        self.pad_width = pad_width

        # convert raw_internal_path, label_internal_path and weight_internal_path to list for ease of computation
        if isinstance(raw_internal_path, str):
            raw_internal_path = [raw_internal_path]
        if isinstance(label_internal_path, str):
            label_internal_path = [label_internal_path]
        if isinstance(weight_internal_path, str):
            weight_internal_path = [weight_internal_path]

        with h5py.File(file_path, 'r') as input_file:
            # WARN: we load everything into memory due to hdf5 bug when reading H5 from multiple subprocesses, i.e.
            # File "h5py/_proxy.pyx", line 84, in h5py._proxy.H5PY_H5Dread
            # OSError: Can't read data (inflate() failed)
            self.raws = [input_file[internal_path][...] for internal_path in raw_internal_path]
            # calculate global mean and std for Normalization augmentation
            mean, std = self._calculate_mean_std(self.raws[0])

            self.transformer = transforms.get_transformer(transformer_config, mean, std, phase)
            self.raw_transform = self.transformer.raw_transform()

            if phase != 'test':
                # create label/weight transform only in train/val phase
                self.label_transform = self.transformer.label_transform()
                self.labels = [input_file[internal_path][...] for internal_path in label_internal_path]

                if weight_internal_path is not None:
                    # look for the weight map in the raw file
                    self.weight_maps = [input_file[internal_path][...] for internal_path in weight_internal_path]
                    self.weight_transform = self.transformer.weight_transform()
                else:
                    self.weight_maps = None

                self._check_dimensionality(self.raws, self.labels)
            else:
                # 'test' phase used only for predictions so ignore the label dataset
                self.labels = None
                self.weight_maps = None

                # add mirror padding if needed
                if self.mirror_padding:
                    padded_volumes = []
                    for raw in self.raws:
                        if raw.ndim == 4:
                            channels = [np.pad(r, pad_width=self.pad_width, mode='reflect') for r in raw]
                            padded_volume = np.stack(channels)
                        else:
                            padded_volume = np.pad(raw, pad_width=self.pad_width, mode='reflect')

                        padded_volumes.append(padded_volume)

                    self.raws = padded_volumes

            # build slice indices for raw and label data sets
            slice_builder = slice_builder_cls(self.raws, self.labels, self.weight_maps, patch_shape, stride_shape)
            self.raw_slices = slice_builder.raw_slices
            self.label_slices = slice_builder.label_slices
            self.weight_slices = slice_builder.weight_slices

            self.patch_count = len(self.raw_slices)
            logger.info(f'Number of patches: {self.patch_count}')
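A hypothetical instantiation sketch for the constructor in Example #3; the class name HDF5Dataset, the file name, and the empty transformer config are placeholders, while the keyword arguments follow the docstring above.

transformer_config = {}  # augmentation config consumed by transforms.get_transformer

train_dataset = HDF5Dataset(
    file_path='train_volume.h5',            # H5 file with 'raw' and 'label' datasets
    patch_shape=(32, 64, 64),               # DxHxW shape of each sampled patch
    stride_shape=(16, 32, 32),              # DxHxW stride between consecutive patches
    phase='train',                          # augmentation only runs in the 'train' phase
    transformer_config=transformer_config,
    mirror_padding=False)                   # reflection padding applies only to 'test'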
Example #4
    def __init__(self,
                 image_cv,
                 seg_cv,
                 id,
                 bounds,
                 mip_level,
                 phase,
                 patch_shape,
                 stride_shape,
                 transformer_config,
                 slice_builder_cls=SliceBuilder,
                 mirror_padding=False,
                 pad_width=20,
                 logfile=None):

        assert phase in ['train', 'val', 'test']
        self.phase = phase
        self._check_patch_shape(patch_shape)
        self.image_cv = image_cv
        self.seg_cv = seg_cv
        self.id = id
        assert isinstance(bounds, list)
        self.bounds = bounds
        self.mip_level = mip_level

        self.mirror_padding = mirror_padding
        self.pad_width = pad_width

        self.logger = get_logger('CloudVolumeDataset', logfile=logfile)

        minx, maxx, miny, maxy, minz, maxz = self.bounds
        self.raws = []
        img = np.squeeze(self.image_cv[minx:maxx, miny:maxy, minz:maxz, 0])
        # transpose (the data is always CDHW)
        img = np.transpose(img, (2, 0, 1))
        self.raws.append(img)

        mean, std = self._calculate_mean_std(self.raws[0])

        self.transformer = transforms.get_transformer(transformer_config, mean,
                                                      std, phase)
        self.raw_transform = self.transformer.raw_transform()

        if phase != 'test':
            # create label/weight transform only in train/val phase
            self.label_transform = self.transformer.label_transform()
            label = np.squeeze(self.seg_cv[minx:maxx, miny:maxy, minz:maxz, 0])
            label = np.where(label / label.max() >= 0.2, 1, 0)
            # transpose the label (the data is always CDHW)
            label = np.transpose(label, (2, 0, 1))
            self.labels = [label]

            self.weight_maps = None

            self._check_dimensionality(self.raws, self.labels)
        else:
            self.labels = None
            self.weight_maps = None

            # add mirror padding if needed
            if self.mirror_padding:
                padded_volumes = [
                    np.pad(raw, pad_width=self.pad_width, mode='reflect')
                    for raw in self.raws
                ]
                self.raws = padded_volumes

        #print(self.raws[0].shape, self.labels[0].shape)
        #print(np.min(self.labels[0][:,:,200]), np.max(self.labels[0][:,:,200]))
        #plt.imshow(self.labels[0][:,:,200])
        #plt.show()
        #plt.imshow(self.raws[0][:,:,200])
        #plt.show()

        # build slice indices for raw and label data sets
        slice_builder = slice_builder_cls(self.raws, self.labels,
                                          self.weight_maps, patch_shape,
                                          stride_shape)
        self.raw_slices = slice_builder.raw_slices
        self.label_slices = slice_builder.label_slices
        self.weight_slices = slice_builder.weight_slices

        self.patch_count = len(self.raw_slices)
        self.logger.info(f'Number of patches: {self.patch_count}')
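The bounds cutout in Example #4 is reordered so that depth comes first; below is a small NumPy sketch of that np.transpose(img, (2, 0, 1)) step, with a toy array standing in for the CloudVolume cutout.

import numpy as np

# toy stand-in for np.squeeze(self.image_cv[minx:maxx, miny:maxy, minz:maxz, 0])
cutout = np.zeros((128, 96, 64), dtype=np.float32)

# same reordering as the example: move the last axis to the front
img = np.transpose(cutout, (2, 0, 1))
print(img.shape)  # (64, 128, 96) -- depth-first, as the CDHW pipeline expects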
Example #5
    def __init__(self,
                 file_path,
                 patch_shape,
                 stride_shape,
                 phase,
                 transformer_config,
                 raw_internal_path='raw',
                 label_internal_path='label',
                 weight_internal_path=None,
                 slice_builder_cls=SliceBuilder):
        """
        :param file_path: path to H5 file containing raw data as well as labels and per pixel weights (optional)
        :param patch_shape: the shape of the patch DxHxW
        :param stride_shape: the shape of the stride DxHxW
        :param phase: 'train' for training, 'val' for validation, 'test' for testing; data augmentation is performed
            only during the 'train' phase
        :param transformer_config: data augmentation configuration
        :param raw_internal_path (str or list): H5 internal path to the raw dataset
        :param label_internal_path (str or list): H5 internal path to the label dataset
        :param weight_internal_path (str or list): H5 internal path to the per pixel weights
        :param slice_builder_cls: defines how to sample the patches from the volume
        """
        assert phase in ['train', 'val', 'test']
        self._check_patch_shape(patch_shape)
        self.phase = phase
        self.file_path = file_path

        # convert raw_internal_path, label_internal_path and weight_internal_path to list for ease of computation
        if isinstance(raw_internal_path, str):
            raw_internal_path = [raw_internal_path]
        if isinstance(label_internal_path, str):
            label_internal_path = [label_internal_path]
        if isinstance(weight_internal_path, str):
            weight_internal_path = [weight_internal_path]

        with h5py.File(file_path, 'r') as input_file:
            # WARN: we load everything into memory due to hdf5 bug when reading H5 from multiple subprocesses, i.e.
            # File "h5py/_proxy.pyx", line 84, in h5py._proxy.H5PY_H5Dread
            # OSError: Can't read data (inflate() failed)
            self.raws = [
                input_file[internal_path][...]
                for internal_path in raw_internal_path
            ]
            # calculate global mean and std for Normalization augmentation
            mean, std = self._calculate_mean_std(self.raws[0])

            self.transformer = transforms.get_transformer(
                transformer_config, mean, std, phase)
            self.raw_transform = self.transformer.raw_transform()

            if phase != 'test':
                # create label/weight transform only in train/val phase
                self.label_transform = self.transformer.label_transform()
                self.labels = [
                    input_file[internal_path][...]
                    for internal_path in label_internal_path
                ]

                if weight_internal_path is not None:
                    # look for the weight map in the raw file
                    self.weight_maps = [
                        input_file[internal_path][...]
                        for internal_path in weight_internal_path
                    ]
                    self.weight_transform = self.transformer.weight_transform()
                else:
                    self.weight_maps = None

                self._check_dimensionality(self.raws, self.labels)
            else:
                # 'test' phase used only for predictions so ignore the label dataset
                self.labels = None
                self.weight_maps = None

            # build slice indices for raw and label data sets
            slice_builder = slice_builder_cls(self.raws, self.labels,
                                              self.weight_maps, patch_shape,
                                              stride_shape)
            self.raw_slices = slice_builder.raw_slices
            self.label_slices = slice_builder.label_slices
            self.weight_slices = slice_builder.weight_slices

            self.patch_count = len(self.raw_slices)
Example #6
    def __init__(self,
                 raw_file_path,
                 patch_shape,
                 stride_shape,
                 phase,
                 transformer_config,
                 label_file_path=None,
                 raw_internal_path='raw',
                 label_internal_path='label',
                 slice_builder_cls=SliceBuilder,
                 pixel_wise_weight=False):
        """
        :param raw_file_path: path to H5 file containing raw data
        :param patch_shape: the shape of the patch DxHxW
        :param stride_shape: the shape of the stride DxHxW
        :param phase: 'train' for training, 'val' for validation, 'test' for testing; data augmentation is performed
            only during the 'train' phase
        :param transformer_config: data augmentation configuration
        :param label_file_path: path to the H5 file containing label data or 'None' if the labels are stored in the raw
            H5 file
        :param raw_internal_path: H5 internal path to the raw dataset
        :param label_internal_path: H5 internal path to the label dataset
        :param slice_builder_cls: defines how to sample the patches from the volume
        :param pixel_wise_weight: does the raw file contain per pixel weights
        """
        assert phase in ['train', 'val', 'test']
        self._check_patch_shape(patch_shape)
        self.phase = phase
        self.raw_file_path = raw_file_path
        self.raw_file = h5py.File(raw_file_path, 'r')
        self.raw = self.raw_file[raw_internal_path]

        # create raw and label transforms
        mean, std = self.calculate_mean_std()

        self.transformer = transforms.get_transformer(transformer_config, mean,
                                                      std, phase)

        self.raw_transform = self.transformer.raw_transform()

        if phase != 'test':
            # create label/weight transform only in train/val phase
            self.label_transform = self.transformer.label_transform()

            if pixel_wise_weight:
                # look for the weight map in the raw file
                self.weight_map = self.raw_file['weight_map']
                self.weight_transform = self.transformer.weight_transform()
            else:
                self.weight_map = None

            # if label_file_path is None assume that labels are stored in the raw_file_path as well
            if label_file_path is None:
                self.label_file = self.raw_file
            else:
                self.label_file = h5py.File(label_file_path, 'r')

            self.label = self.label_file[label_internal_path]
            self._check_dimensionality(self.raw, self.label)
        else:
            # 'test' phase used only for predictions so ignore the label dataset
            self.label = None

        slice_builder = slice_builder_cls(self.raw, self.label, patch_shape,
                                          stride_shape)
        self.raw_slices = slice_builder.raw_slices
        self.label_slices = slice_builder.label_slices

        self.patch_count = len(self.raw_slices)
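Unlike Examples #3 and #5, the constructor in Example #6 keeps the h5py.File handle open and indexes the datasets lazily instead of materializing them with [...]. A minimal sketch of that difference, using hypothetical file and dataset names:

import h5py

with h5py.File('volume.h5', 'r') as f:  # hypothetical file
    lazy = f['raw']                 # h5py.Dataset: slices are read from disk on demand
    in_memory = f['raw'][...]       # numpy.ndarray: the whole dataset is loaded at once
    patch = lazy[0:32, 0:64, 0:64]  # only this patch is read from disk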
Example #7
    def __init__(self,
                 file_path,
                 patch_shape,
                 phase,
                 transformer_config,
                 raw_internal_path='raw',
                 label_internal_path='label',
                 weight_internal_path=None,
                 skel_internal_path='skelpt',
                 p_uniform=0.2,
                 n_trail=64,
                 pskel_rand_sft=0.25):
        """
        :param file_path: path to H5 file containing raw data as well as labels and per pixel weights (optional)
        :param patch_shape: the shape of the patch DxHxW
        :param phase: 'train' for training, 'val' for validation, 'test' for testing; data augmentation is performed
            only during the 'train' phase
        :param transformer_config: data augmentation configuration
        :param raw_internal_path (str or list): H5 internal path to the raw dataset
        :param label_internal_path (str or list): H5 internal path to the label dataset
        :param weight_internal_path (str or list): H5 internal path to the per pixel weights
        :param skel_internal_path (str): H5 internal path to the skeleton point ('skelpt') dataset
        :param p_uniform: probability of sampling a patch uniformly from the whole volume (default 0.2)
        :param n_trail: number of patches/slices to generate per file (a.k.a. items)
        :param pskel_rand_sft: shift the patch centre (xc, yc, zc) by this fraction of patch_shape
        """
        assert phase in ['train', 'val', 'test']
        self._check_patch_shape(patch_shape)
        self.phase = phase
        self.file_path = file_path
        self.p_uniform = p_uniform
        self.n_trail = n_trail
        self.patch_shape = patch_shape
        self.pskel_sft = [int(ps * pskel_rand_sft) for ps in patch_shape]

        # unlike the other datasets, the internal paths must be single strings here (lists are not supported)
        assert isinstance(raw_internal_path, str), 'raw path must be a str'
        assert isinstance(label_internal_path, str), 'label path must be a str'
        assert isinstance(skel_internal_path, str), 'skel path must be a str'
        assert isinstance(
            weight_internal_path,
            str) or weight_internal_path is None, 'weight path must be a str or None'

        with h5py.File(file_path, 'r') as input_file:
            # WARN: we load everything into memory due to hdf5 bug when reading H5 from multiple subprocesses, i.e.
            # File "h5py/_proxy.pyx", line 84, in h5py._proxy.H5PY_H5Dread
            # OSError: Can't read data (inflate() failed)
            self.raws = input_file[raw_internal_path][...]

            if self.raws.ndim == 4:
                self.zz, self.yy, self.xx = self.raws.shape[1:]
            else:
                self.zz, self.yy, self.xx = self.raws.shape

            # calculate global mean and std for Normalization augmentation
            mean, std = 0.0, 0.5  #self._calculate_mean_std(self.raws[0])

            self.transformer = transforms.get_transformer(
                transformer_config, mean, std, phase)
            self.raw_transform = self.transformer.raw_transform()

            if phase != 'test':
                # create label/weight transform only in train/val phase
                self.label_transform = self.transformer.label_transform()
                self.skelxyz = input_file[skel_internal_path][...]

                self.labels = input_file[label_internal_path][...]

                if weight_internal_path is not None:
                    # look for the weight map in the raw file
                    self.weight_maps = input_file[weight_internal_path][...]
                    self.weight_transform = self.transformer.weight_transform()
                else:
                    self.weight_maps = None

                self._check_dimensionality(self.raws, self.labels)
            else:
                # No 'test' phase
                raise ValueError('rs cannot have a test phase')
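To make the pskel_rand_sft bookkeeping in Example #7 concrete, here is a short worked sketch of the self.pskel_sft computation with the default shift fraction:

patch_shape = (32, 96, 96)
pskel_rand_sft = 0.25  # default fraction from the signature above
pskel_sft = [int(ps * pskel_rand_sft) for ps in patch_shape]
print(pskel_sft)  # [8, 24, 24] -- maximum centre shift along each axis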
Example #8
    def __init__(self,
                 file_path,
                 patch_shape,
                 stride_shape,
                 phase,
                 transformer_config,
                 raw_internal_path='raw',
                 label_internal_path='label',
                 weight_internal_path=None,
                 slice_builder_cls=SliceBuilderGP):
        """
        :param file_path: path to H5 file containing raw data as well as labels and per pixel weights (optional)
        :param patch_shape: the shape of the patch DxHxW
        :param stride_shape: the shape of the stride DxHxW
        :param phase: 'train' for training, 'val' for validation, 'test' for testing; data augmentation is performed
            only during the 'train' phase
        :param transformer_config: data augmentation configuration
        :param raw_internal_path (str or list): H5 internal path to the raw dataset
        :param label_internal_path (str or list): H5 internal path to the label dataset
        :param weight_internal_path (str or list): H5 internal path to the per pixel weights
        :param slice_builder_cls: defines how to sample the patches from the volume
        """
        assert phase in ['train', 'val', 'test']
        self._check_patch_shape(patch_shape)
        self.phase = phase
        self.file_path = file_path

        #        # convert raw_internal_path, label_internal_path and weight_internal_path to list for ease of computation
        #        if isinstance(raw_internal_path, str):
        #            raw_internal_path = [raw_internal_path]
        #        if isinstance(label_internal_path, str):
        #            label_internal_path = [label_internal_path]
        #        if isinstance(weight_internal_path, str):
        #            weight_internal_path = [weight_internal_path]

        with h5py.File(file_path, 'r') as input_file:
            # WARN: we load everything into memory due to hdf5 bug when reading H5 from multiple subprocesses, i.e.
            # File "h5py/_proxy.pyx", line 84, in h5py._proxy.H5PY_H5Dread
            # OSError: Can't read data (inflate() failed)
            self.raws = [input_file[raw_internal_path][...]]
            # calculate global mean and std for Normalization augmentation
            mean, std = 0.0, 0.5  #self._calculate_mean_std(self.raws[0])

            self.transformer = transforms.get_transformer(
                transformer_config, mean, std, phase)
            self.raw_transform = self.transformer.raw_transform()
            self.GP_transform = self.transformer.GP_transform()

            xyz = input_file['xyz'][...]
            if len(xyz) == 6:  # xmin,xmax,ymin,ymax,zmin,zmax
                xmin, xmax, ymin, ymax, zmin, zmax = xyz
                xx, yy, zz = 512, 512, zmax + 1
            elif len(xyz) == 9:  # xmin,xmax,ymin,ymax,zmin,zmax,xx,yy,zz
                xmin, xmax, ymin, ymax, zmin, zmax, xx, yy, zz = xyz
            else:
                raise ValueError('xyz value error')

            # xyz (-1,1) crop get meshgrid
            x_st, x_ed = 2.0 * xmin / float(xx) - 1.0, 2.0 * xmax / float(
                xx) - 1.0
            y_st, y_ed = 2.0 * ymin / float(yy) - 1.0, 2.0 * ymax / float(
                yy) - 1.0
            z_st, z_ed = 2.0 * zmin / float(zz) - 1.0, 2.0 * zmax / float(
                zz) - 1.0
            x_lin = np.linspace(x_st, x_ed, self.raws[0].shape[3])
            y_lin = np.linspace(y_st, y_ed, self.raws[0].shape[2])
            z_lin = np.linspace(z_st, z_ed, self.raws[0].shape[1])

            z_grid, y_grid, x_grid = np.meshgrid(z_lin,
                                                 y_lin,
                                                 x_lin,
                                                 indexing='ij')
            self.GP = [np.stack((x_grid, y_grid, z_grid))]

            if phase != 'test':
                # create label/weight transform only in train/val phase
                self.label_transform = self.transformer.label_transform()
                self.labels = [input_file[label_internal_path][...]]

                if weight_internal_path is not None:
                    # look for the weight map in the raw file
                    self.weight_maps = [input_file[weight_internal_path][...]]
                    self.weight_transform = self.transformer.weight_transform()
                else:
                    self.weight_maps = None

                self._check_dimensionality(self.raws, self.labels)
            else:
                # 'test' phase used only for predictions so ignore the label dataset
                self.labels = None
                self.weight_maps = None

            # build slice indices for raw and label data sets
            slice_builder = slice_builder_cls(self.raws, self.labels,
                                              self.weight_maps, self.GP,
                                              patch_shape, stride_shape)
            self.raw_slices = slice_builder.raw_slices
            self.label_slices = slice_builder.label_slices
            self.weight_slices = slice_builder.weight_slices
            self.GP_slices = slice_builder.GP_slices

            self.patch_count = len(self.raw_slices)
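A small worked sketch of the (-1, 1) coordinate normalization used for the GP meshgrid in Example #8, with toy bounds standing in for the 'xyz' dataset (only the x and z axes are shown; y is analogous):

# toy bounds in place of input_file['xyz']: xmin, xmax, ymin, ymax, zmin, zmax
xmin, xmax, ymin, ymax, zmin, zmax = 100, 355, 200, 455, 0, 63
xx, yy, zz = 512, 512, zmax + 1  # full-volume extents, as in the 6-value branch

# same normalization as the example: map voxel coordinates into (-1, 1)
x_st, x_ed = 2.0 * xmin / xx - 1.0, 2.0 * xmax / xx - 1.0
z_st, z_ed = 2.0 * zmin / zz - 1.0, 2.0 * zmax / zz - 1.0
print(x_st, x_ed)  # -0.609375 0.38671875
print(z_st, z_ed)  # -1.0 0.96875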