コード例 #1
0
    def __init__(
            self,
            data,
            roi_size=64,
            zoom_range=(0.8, 1.25),
            samples_per_epoch=100000,
    ):
        """Store the data and build the random patch-augmentation pipeline.

        Args:
            data: stored unchanged on ``self.data`` (its structure is not
                visible here — presumably what the transforms are applied to;
                confirm against the rest of the class).
            roi_size: side length of the patch left after ``RemovePadding``.
            zoom_range: forwarded unchanged to ``AffineTransform``.
            samples_per_epoch: stored on ``self.samples_per_epoch``.
        """
        self.data = data
        self.roi_size = roi_size

        # Enlarge the crop by ceil(roi_size * (sqrt(2) - 1) / 2) on every side,
        # i.e. to roughly roi_size * sqrt(2) (the ROI diagonal); RemovePadding
        # is then given the same margin to strip after the affine transform.
        pad = math.ceil(self.roi_size * (math.sqrt(2) - 1) / 2)
        crop_size = roi_size + 2 * pad

        pipeline = (
            RandomCrop(crop_size),
            AffineTransform(zoom_range),
            RemovePadding(pad),
            RandomVerticalFlip(),
            RandomHorizontalFlip(),
        )
        self.transforms = Compose(list(pipeline))

        self.samples_per_epoch = samples_per_epoch
コード例 #2
0
    def initialize(self, pyramid_images_output_path=None):
        """Prepare one symmetric alpha-AMD distance per pyramid level.

        If no pyramid level was configured, a single level (factor 1,
        sigma 0.0) is added; if no initial transform was configured, an
        ``AffineTransform(self.dim)`` is added. Then, for every entry of
        ``self.pyramid_factors`` (in order), the reference and floating
        images are Gaussian-smoothed with the matching entry of
        ``self.pyramid_sigmas``, downsampled by the level factor, normalized
        within their downsampled masks, quantized and wrapped in ``AlphaAMD``
        objects. The resulting ``SymmetricAMDDistance`` is initialized and
        appended to ``self.distances``.

        Args:
            pyramid_images_output_path: optional filename prefix. When not
                None and the resampled images are 2-D, every level's images
                are saved as ``<prefix>ref_resampled_<n>.png`` and
                ``<prefix>flo_resampled_<n>.png`` (n starting at 1).
        """
        if len(self.pyramid_factors) == 0:
            self.add_pyramid_level(1, 0.0)
        if len(self.initial_transforms) == 0:
            self.add_initial_transform(AffineTransform(self.dim))

        def physical_diagonal(shape, spacing):
            # Euclidean length of the diagonal, with per-axis extent size * spacing.
            return np.sqrt(np.square(np.array(shape) * spacing).sum())

        for level, factor in enumerate(self.pyramid_factors):
            sigma = self.pyramid_sigmas[level]

            # Smooth first, then downsample both images to this level.
            ref_level = filters.downsample(
                filters.gaussian_filter(self.ref_im, sigma), factor)
            flo_level = filters.downsample(
                filters.gaussian_filter(self.flo_im, sigma), factor)

            ref_mask = filters.downsample(self.ref_mask, factor)
            flo_mask = filters.downsample(self.flo_mask, factor)

            ref_level = filters.normalize(ref_level, 0.0, ref_mask)
            flo_level = filters.normalize(flo_level, 0.0, flo_mask)

            save_pyramid = (pyramid_images_output_path is not None
                            and ref_level.ndim == 2)
            if save_pyramid:
                # NOTE: scipy.misc.imsave was removed in SciPy 1.2; this
                # branch needs an older SciPy or a replacement image writer.
                scipy.misc.imsave(
                    '%sref_resampled_%d.png' %
                    (pyramid_images_output_path, level + 1), ref_level)
                scipy.misc.imsave(
                    '%sflo_resampled_%d.png' %
                    (pyramid_images_output_path, level + 1), flo_level)

            if self.ref_weights is not None:
                ref_w = filters.downsample(self.ref_weights, factor)
            else:
                # No explicit weights: weight 1 inside the mask, 0 elsewhere.
                ref_w = np.zeros(ref_level.shape)
                ref_w[ref_mask] = 1.0
            if self.flo_weights is not None:
                flo_w = filters.downsample(self.flo_weights, factor)
            else:
                flo_w = np.zeros(flo_level.shape)
                flo_w[flo_mask] = 1.0

            # NOTE(review): the downsampled shape is combined with the
            # unscaled spacing here, while the quantized images and AlphaAMD
            # below receive spacing * factor — confirm this is intended.
            ref_diag = physical_diagonal(ref_level.shape, self.ref_spacing)
            flo_diag = physical_diagonal(flo_level.shape, self.flo_spacing)

            q_ref = QuantizedImage(ref_level, self.alpha_levels, ref_w,
                                   self.ref_spacing * factor,
                                   remove_zero_weight_pnts=True)
            q_flo = QuantizedImage(flo_level, self.alpha_levels, flo_w,
                                   self.flo_spacing * factor,
                                   remove_zero_weight_pnts=True)

            tf_ref = alpha_amd.AlphaAMD(q_ref, self.alpha_levels, ref_diag,
                                        self.ref_spacing * factor,
                                        ref_mask, ref_mask,
                                        interpolator_mode='linear',
                                        dt_fun=None, mask_out_edges=True)
            tf_flo = alpha_amd.AlphaAMD(q_flo, self.alpha_levels, flo_diag,
                                        self.flo_spacing * factor,
                                        flo_mask, flo_mask,
                                        interpolator_mode='linear',
                                        dt_fun=None, mask_out_edges=True)

            # Every level uses a symmetric, non-squared distance.
            level_dist = symmetric_amd_distance.SymmetricAMDDistance(
                symmetric_measure=True, squared_measure=False)
            level_dist.set_ref_image_source(q_ref)
            level_dist.set_ref_image_target(tf_ref)
            level_dist.set_flo_image_source(q_flo)
            level_dist.set_flo_image_target(tf_flo)
            level_dist.set_sampling_fraction(self.sampling_fraction)
            level_dist.initialize()

            self.distances.append(level_dist)
コード例 #3
0
def main():
    """Register two gray-scale images named on the command line.

    ``sys.argv[1]`` is the reference image and ``sys.argv[2]`` the floating
    image. A three-level (factors 4, 2, 1) affine registration is run. The
    function prints the estimated transformation parameters and the summed
    absolute intensity difference. It writes ``registered.png`` and
    ``diff.png`` to ``./test_images/output/`` and creates that directory if
    it is missing.

    Uses the module-level settings ``spacing``, ``alpha_levels``,
    ``param_report_freq``, ``param_iterations`` and
    ``param_sampling_fraction``; none of them is defined in this function.

    Returns:
        False if fewer than two image paths were given, True otherwise.
    """
    np.random.seed(1000)  # fixed seed so repeated runs are reproducible

    if len(sys.argv) < 3:
        print('register_example.py: Too few parameters. Give the path to two gray-scale image files.')
        print('Example: python2 register_example.py reference_image floating_image')
        return False

    out_dir = './test_images/output/'

    # NOTE: scipy.misc.imread/imsave were removed in SciPy 1.2; this example
    # needs an older SciPy or replacement image I/O.
    ref_im = scipy.misc.imread(sys.argv[1], 'L')
    flo_im = scipy.misc.imread(sys.argv[2], 'L')
    # Unnormalized copies: used for warping and for the error image.
    ref_raw, flo_raw = ref_im.copy(), flo_im.copy()

    ref_im = filters.normalize(ref_im, 0.0, None)
    flo_im = filters.normalize(flo_im, 0.0, None)

    # Mean of the two image diagonals; used to scale the affine parameters.
    diag = 0.5 * (transforms.image_diagonal(ref_im, spacing) +
                  transforms.image_diagonal(flo_im, spacing))

    reg = Register(2)  # 2-D registration framework
    reg.set_report_freq(param_report_freq)
    reg.set_alpha_levels(alpha_levels)

    # Full masks and uniform weights: every pixel of both images is used.
    reg.set_reference_image(ref_im)
    reg.set_reference_mask(np.ones(ref_im.shape, 'bool'))
    reg.set_reference_weights(np.ones(ref_im.shape))

    reg.set_floating_image(flo_im)
    reg.set_floating_mask(np.ones(flo_im.shape, 'bool'))
    reg.set_floating_weights(np.ones(flo_im.shape))

    # Gaussian pyramid, coarse to fine: (downsampling factor, smoothing sigma).
    for factor, sigma in ((4, 5.0), (2, 3.0), (1, 0.0)):
        reg.add_pyramid_level(factor, sigma)

    # Step lengths [start, end] for each pyramid level, in the same order.
    step_lengths = np.array([[1.0, 1.0], [1.0, 0.5], [0.5, 0.1]])

    # Affine transform: the first four parameters (presumably the matrix
    # entries) are scaled by 1/diag, the last two by 1.0.
    reg.add_initial_transform(AffineTransform(2), np.array([1.0 / diag] * 4 + [1.0, 1.0]))
    # Rigid 2D alternative (use instead of the affine line above):
    #reg.add_initial_transform(Rigid2DTransform(2), np.array([1.0/diag, 1.0, 1.0]))

    reg.set_iterations(param_iterations)
    reg.set_gradient_magnitude_threshold(0.001)
    reg.set_sampling_fraction(param_sampling_fraction)
    reg.set_step_lengths(step_lengths)

    out_parent = os.path.dirname(out_dir)
    if not os.path.exists(out_parent):
        os.makedirs(out_parent)

    # Pre-processing; out_dir is passed as the pyramid-image output prefix.
    reg.initialize(out_dir)

    np.set_printoptions(suppress=True, linewidth=200)
    reg.run()

    transform, _ = reg.get_output(0)
    centered = transforms.make_image_centered_transform(transform, ref_im, flo_im, spacing, spacing)

    print('Transformation parameters: %s.' % str(transform.get_params()))

    # Warp the unnormalized floating image into the reference image grid.
    warped = np.zeros(ref_im.shape)
    centered.warp(In=flo_raw, Out=warped, in_spacing=spacing, out_spacing=spacing, mode='spline', bg_value=0.0)
    scipy.misc.imsave(out_dir + 'registered.png', warped)

    # Absolute difference between the reference and the registered image.
    abs_diff = np.abs(ref_raw - warped)
    print("Err: %f" % np.sum(abs_diff))
    scipy.misc.imsave(out_dir + 'diff.png', abs_diff)

    return True
コード例 #4
0
    def __init__(
        self,
        data_src,
        folds2include=None,
        num_folds=5,
        samples_per_epoch=2000,
        roi_size=96,
        scale_int=(0, 255),
        norm_mean=0.,
        norm_sd=1.,
        zoom_range=(0.90, 1.1),
        prob_unseeded_patch=0.2,
        int_aug_offset=None,
        int_aug_expansion=None,
        valid_labels=None,  # if this is None it will include all the available labels
        is_preloaded=False,
        max_input_size=2048  # images larger than this are split (only works with folds)
    ):
        """Store the dataset options, build the transforms and load the data.

        Args:
            data_src: path to the data. A directory is read with
                ``load_data_from_dir``; anything else is read with
                ``load_data_from_file``, which requires ``is_preloaded``.
            folds2include: must be None when ``data_src`` is a directory.
            num_folds: stored on ``self.num_folds``.
            samples_per_epoch: stored on ``self.samples_per_epoch``.
            roi_size: side of the patches produced by ``self.transforms_random``.
            scale_int, norm_mean, norm_sd: intensity normalization passed to
                ``NormalizeIntensity`` in both transform pipelines.
            zoom_range: forwarded to ``AffineTransform``.
            prob_unseeded_patch: forwarded to ``RandomCropWithSeeds``.
            int_aug_offset, int_aug_expansion: forwarded to the random
                intensity augmentations.
            valid_labels: stored; if None all available labels are included.
            is_preloaded: forwarded to ``load_data_from_dir``; must be True
                when ``data_src`` is a file.
            max_input_size: stored on ``self.max_input_size``.

        Raises:
            ValueError: if ``data_src`` does not exist.
            AssertionError: if the directory/file preconditions above are
                violated or no samples were loaded.
        """
        # Snapshot the attribute names before storing the options, so the
        # option names can be recovered below as the set difference.
        _dum = set(dir(self))

        self.data_src = Path(data_src)
        if not self.data_src.exists():
            raise ValueError(f'`data_src` : `{data_src}` does not exist.')

        self.folds2include = folds2include
        self.num_folds = num_folds

        self.samples_per_epoch = samples_per_epoch
        self.roi_size = roi_size
        self.scale_int = scale_int

        self.norm_mean = norm_mean
        self.norm_sd = norm_sd

        self.zoom_range = zoom_range
        self.prob_unseeded_patch = prob_unseeded_patch
        self.int_aug_offset = int_aug_offset
        self.int_aug_expansion = int_aug_expansion
        self.valid_labels = valid_labels
        self.is_preloaded = is_preloaded
        self.max_input_size = max_input_size

        # Names of the options stored above, so they can be accessed later.
        self._input_names = list(set(dir(self)) - _dum)

        # Pad the crop by ceil(roi_size * (sqrt(2) - 1) / 2) per side (about
        # roi_size * sqrt(2) in total); RemovePadding strips the same margin
        # after the affine transform.
        rotation_pad_size = math.ceil(self.roi_size * (math.sqrt(2) - 1) / 2)
        padded_roi_size = roi_size + 2 * rotation_pad_size

        transforms_random = [
            RandomCropWithSeeds(padded_roi_size, rotation_pad_size,
                                prob_unseeded_patch),
            AffineTransform(zoom_range),
            RemovePadding(rotation_pad_size),
            RandomVerticalFlip(),
            RandomHorizontalFlip(),
            NormalizeIntensity(scale_int, norm_mean, norm_sd),
            RandomIntensityOffset(int_aug_offset),
            RandomIntensityExpansion(int_aug_expansion),
            OutContours2Segmask(),
            FixDTypes()
            # ToTensor is left out here: per the author, it crashes the
            # dataloader when the batch size is large (>256).
        ]
        self.transforms_random = Compose(transforms_random)

        # Full images must be normalized exactly like the training patches,
        # so the same mean/sd are applied here as in transforms_random.
        transforms_full = [
            NormalizeIntensity(scale_int, norm_mean, norm_sd),
            OutContours2Segmask(),
            FixDTypes(),
            ToTensor()
        ]
        self.transforms_full = Compose(transforms_full)
        self.hard_neg_data = None

        if self.data_src.is_dir():
            assert self.folds2include is None, \
                '`folds2include` must be None when `data_src` is a directory.'
            self.data = self.load_data_from_dir(self.data_src, padded_roi_size,
                                                self.is_preloaded)
        else:
            assert self.is_preloaded, \
                '`is_preloaded` must be True when `data_src` is a file.'
            self.data = self.load_data_from_file(self.data_src)

        self.type_ids = sorted(self.data)
        # Labels start at 1, following the sorted type ids.
        self.types2label = {k: ii for ii, k in enumerate(self.type_ids, start=1)}

        self.num_clases = len(self.type_ids)  # (sic) name kept for existing callers

        # Flatten the nested data so every sample has a single index:
        # (type id, file name, position within that file's data).
        self.data_indexes = [(_type, _fname, ii)
                             for _type, type_data in self.data.items()
                             for _fname, file_data in type_data.items()
                             for ii in range(len(file_data))]

        assert len(self.data_indexes) > 0, 'No valid samples were loaded from `data_src`.'
コード例 #5
0
    def initialize(self, pyramid_images_output_path=None):
        """Prepare per-channel symmetric alpha-AMD distances for every pyramid level.

        Adds a default pyramid level (factor 1, sigma 0.0) and a default
        ``AffineTransform(self.dim)`` when none were configured. Channel
        handling depends on ``self.channel_mode``:

        * ``'decompose_pre'``: ``filters.fidt`` is applied once to both
          full-resolution inputs and its outputs are used as the channels;
        * ``'decompose'``: ``filters.fidt`` is applied to each level's
          smoothed, downsampled channels, and every resulting component is
          normalized within the level's mask;
        * ``'sum'``: the smoothed, downsampled channels are used unchanged here.

        For every level, one initialized ``SymmetricAMDDistance`` is built per
        resulting channel, and that list is appended to ``self.distances``.

        Args:
            pyramid_images_output_path: accepted but not used by this method.

        Raises:
            AssertionError: if the reference and floating images have a
                different number of channels.
        """
        if len(self.pyramid_factors) == 0:
            self.add_pyramid_level(1, 0.0)
        if len(self.initial_transforms) == 0:
            self.add_initial_transform(AffineTransform(self.dim))

        ch = len(self.ref_im)
        assert ch == len(self.flo_im), (
            'reference and floating images must have the same number of channels')

        ref_input = self.ref_im
        flo_input = self.flo_im
        if self.channel_mode == 'decompose_pre':
            # Decompose once at full resolution (no explicit level count);
            # the decomposed components become the channels.
            ref_input = filters.fidt(ref_input, None)
            flo_input = filters.fidt(flo_input, None)
            ch = len(ref_input)

        # Settings that do not change between levels or channels.
        percentile = 0.01
        dt_fun = alpha_amd.edt_sq if self.squared_measure else None
        symmetric_measure = True
        # NOTE(review): SymmetricAMDDistance always gets squared_measure=False;
        # only dt_fun follows self.squared_measure. Confirm this is intended.
        squared_measure = False

        for i, factor in enumerate(self.pyramid_factors):
            sigma = self.pyramid_sigmas[i]

            ref_mask_resampled = filters.downsample(self.ref_mask, factor)
            flo_mask_resampled = filters.downsample(self.flo_mask, factor)

            # Smooth, then downsample every channel to this level.
            ref_resampled = [
                filters.downsample(filters.gaussian_filter(ref_input[k], sigma), factor)
                for k in range(ch)
            ]
            flo_resampled = [
                filters.downsample(filters.gaussian_filter(flo_input[k], sigma), factor)
                for k in range(ch)
            ]

            if self.channel_mode == 'decompose':
                ref_resampled = filters.fidt(ref_resampled, self.alpha_levels)
                flo_resampled = filters.fidt(flo_resampled, self.alpha_levels)
                for k in range(len(ref_resampled)):
                    ref_resampled[k] = filters.normalize(
                        ref_resampled[k], percentile, ref_mask_resampled)
                    flo_resampled[k] = filters.normalize(
                        flo_resampled[k], percentile, flo_mask_resampled)
            # 'sum' and 'decompose_pre' use the channels unchanged.
            # NOTE(review): any other channel_mode value also falls through
            # unchanged here; confirm it is validated where it is set.

            if self.ref_weights is None:
                # No explicit weights: weight 1 inside the mask, 0 elsewhere.
                ref_weights = np.zeros(ref_resampled[0].shape)
                ref_weights[ref_mask_resampled] = 1.0
            else:
                ref_weights = filters.downsample(self.ref_weights, factor)
            if self.flo_weights is None:
                flo_weights = np.zeros(flo_resampled[0].shape)
                flo_weights[flo_mask_resampled] = 1.0
            else:
                flo_weights = filters.downsample(self.flo_weights, factor)

            # NOTE(review): the downsampled shape is combined with the
            # unscaled spacing, while the images below get spacing * factor.
            ref_diag = np.sqrt(
                np.square(np.array(ref_resampled[0].shape) *
                          self.ref_spacing).sum())
            flo_diag = np.sqrt(
                np.square(np.array(flo_resampled[0].shape) *
                          self.flo_spacing).sum())

            dists = []
            for k, ref_k in enumerate(ref_resampled):
                q_ref = QuantizedImage(ref_k,
                                       self.alpha_levels,
                                       ref_weights,
                                       self.ref_spacing * factor,
                                       remove_zero_weight_pnts=True)
                q_flo = QuantizedImage(flo_resampled[k],
                                       self.alpha_levels,
                                       flo_weights,
                                       self.flo_spacing * factor,
                                       remove_zero_weight_pnts=True)

                tf_ref = alpha_amd.AlphaAMD(q_ref,
                                            self.alpha_levels,
                                            ref_diag,
                                            self.ref_spacing * factor,
                                            ref_mask_resampled,
                                            ref_mask_resampled,
                                            interpolator_mode='linear',
                                            dt_fun=dt_fun,
                                            mask_out_edges=True)
                tf_flo = alpha_amd.AlphaAMD(q_flo,
                                            self.alpha_levels,
                                            flo_diag,
                                            self.flo_spacing * factor,
                                            flo_mask_resampled,
                                            flo_mask_resampled,
                                            interpolator_mode='linear',
                                            dt_fun=dt_fun,
                                            mask_out_edges=True)

                sym_dist = symmetric_amd_distance.SymmetricAMDDistance(
                    symmetric_measure=symmetric_measure,
                    squared_measure=squared_measure)

                sym_dist.set_ref_image_source(q_ref)
                sym_dist.set_ref_image_target(tf_ref)

                sym_dist.set_flo_image_source(q_flo)
                sym_dist.set_flo_image_target(tf_flo)

                sym_dist.set_sampling_fraction(self.sampling_fraction)

                sym_dist.initialize()

                dists.append(sym_dist)

            self.distances.append(dists)