Ejemplo n.º 1
0
    def tearDown_reg_f3d(self):
        # Report runtime, then optionally display fixed, moving and warped
        # moving images side by side.
        elapsed = self.registration_method.get_computational_time()
        print("Computational time = %s" % elapsed)

        warped = self.registration_method.get_warped_moving_sitk()
        if not self.show_fig:
            return

        images = [self.fixed_sitk, self.moving_sitk, warped]
        names = ["fixed", "moving", "warped_moving"]
        sitkh.show_sitk_image(images, label=names)
Ejemplo n.º 2
0
    def _get_best_transform(self, transformations, debug=False):
        """
        Select the candidate transform whose warped moving image best
        matches the fixed image.

        If `self._refine_pca_initializations` is set, the candidates are
        first passed through `self._run_registrations`. Each candidate is
        then used to resample the moving image onto the fixed image grid
        (linear interpolation), the results are wrapped as stacks and scored
        against the fixed stack with `self._similarity_measure` (reference
        mask enabled). The candidate with the highest score is returned.

        :param transformations: candidate sitk transforms
        :param debug: if True, display fixed image and all warped candidates
            and print each candidate's similarity score
        :return: the candidate sitk transform with the highest similarity
        """
        candidates = transformations
        if self._refine_pca_initializations:
            candidates = self._run_registrations(candidates)

        warped_stacks = []
        for candidate_sitk in candidates:
            resampled_sitk = sitk.Resample(
                self._moving.sitk,
                self._fixed.sitk,
                candidate_sitk,
                sitk.sitkLinear,
            )
            warped_stack = st.Stack.from_sitk_image(
                resampled_sitk,
                extract_slices=False,
                slice_thickness=self._fixed.get_slice_thickness(),
            )
            warped_stacks.append(warped_stack)

        evaluator = ise.ImageSimilarityEvaluator(
            stacks=warped_stacks,
            reference=self._fixed,
            measures=[self._similarity_measure],
            use_reference_mask=True,
            verbose=False,
        )
        ph.print_info("Find best aligning transform as measured by %s" %
                      self._similarity_measure)
        evaluator.compute_similarities()
        similarities = evaluator.get_similarities()

        # highest similarity wins
        scores = similarities[self._similarity_measure]
        best = np.argmax(scores)
        best_transform_sitk = candidates[best]

        if debug:
            attempt_labels = [
                "attempt%d" % (k + 1) for k in range(len(candidates))
            ]
            attempt_labels[best] = "best"
            warped_images_sitk = [w.sitk for w in warped_stacks]
            sitkh.show_sitk_image(
                [self._fixed.sitk] + warped_images_sitk,
                segmentation=self._fixed.sitk_mask,
                label=["fixed"] + attempt_labels,
            )
            for k, attempt_label in enumerate(attempt_labels):
                print("%s: %.6f" % (attempt_label, scores[k]))

        return best_transform_sitk
Ejemplo n.º 3
0
    def save_landmarks_to_image(self, path_to_file):
        """
        Write the estimated landmarks as a voxel image to `path_to_file`.

        The output has the same size and header information (copied via
        `CopyInformation`) as the image at `self._path_to_image_label`; its
        voxel values are produced by `self._get_array_with_landmarks` from
        `self._landmarks_voxel_space`. If `self._verbose` is set, a cropped
        view of the landmark array, overlaid with a dilated copy of the
        landmarks, is displayed.

        :param path_to_file: output path, written with
            `sitkh.write_nifti_image_sitk`
        :raises RuntimeError: if landmarks have not been estimated yet
            (`run` not executed)
        """
        if self._landmarks_image_space is None:
            raise RuntimeError("Execute 'run' first to estimate landmarks")

        ph.print_info("Save landmarks to image '%s' ... " % path_to_file,
                      newline=False)

        # The label image only provides the output grid (size + header);
        # its voxel values are not needed, so they are not read into numpy.
        image_label_sitk = sitk.ReadImage(self._path_to_image_label)

        # Rasterize landmarks on an array of the label image's size
        # (sitk size order reversed to match numpy array axis order).
        image_label_nda = self._get_array_with_landmarks(
            image_label_sitk.GetSize()[::-1], self._landmarks_voxel_space)

        image_landmarks_sitk = sitk.GetImageFromArray(image_label_nda)
        image_landmarks_sitk.CopyInformation(image_label_sitk)

        sitkh.write_nifti_image_sitk(image_landmarks_sitk, path_to_file)
        print("done")

        # show landmark estimate
        if self._verbose:
            # find bounding box for "zoomed in" visualization
            ran_x, ran_y, ran_z = self._get_bounding_box(image_label_nda)

            # get zoomed-in image
            image_label_nda_show = image_label_nda[ran_x[0]:ran_x[1],
                                                   ran_y[0]:ran_y[1],
                                                   ran_z[0]:ran_z[1]]
            landmarks_nda = self._get_array_with_landmarks(
                image_label_nda.shape, self._landmarks_voxel_space)
            show_mask_sitk = sitk.GetImageFromArray(image_label_nda_show)

            # get zoomed-in landmark estimate (dilated for visualization)
            landmarks_nda_show = landmarks_nda[ran_x[0]:ran_x[1],
                                               ran_y[0]:ran_y[1],
                                               ran_z[0]:ran_z[1]]
            # scipy.ndimage.morphology is a deprecated alias namespace;
            # binary_dilation lives directly in scipy.ndimage.
            landmarks_nda_show += scipy.ndimage.binary_dilation(
                landmarks_nda_show, iterations=10)
            show_landmarks_sitk = sitk.GetImageFromArray(landmarks_nda_show)

            sitkh.show_sitk_image(show_mask_sitk,
                                  segmentation=show_landmarks_sitk,
                                  label=os.path.basename(
                                      ph.strip_filename_extension(
                                          self._path_to_image_label)[0]))
Ejemplo n.º 4
0
    def test_02_brain_mask(self):
        # Run brain stripping on "stack0" and display the input image with
        # the resulting brain mask overlaid.
        stripper = bs.BrainStripping.from_filename(
            self.dir_test_data, "stack0")

        # Output selection: only the brain mask is requested (0/1 flags
        # presumably toggle each output).
        stripper.compute_brain_image(0)
        stripper.compute_brain_mask(1)
        stripper.compute_skull_image(0)
        # Optional tuning: stripper.set_bet_options("-f 0.3")

        stripper.run()

        input_sitk = stripper.get_input_image_sitk()
        mask_sitk = stripper.get_brain_mask_sitk()
        sitkh.show_sitk_image([input_sitk], segmentation=mask_sitk)
Ejemplo n.º 5
0
    def tearDown(self):
        # Report runtime; apply the estimated registration transform to the
        # fixed image and optionally display it next to moving and fixed.
        print("Computational time = %s" %
              self.registration_method.get_computational_time())

        transform_sitk = \
            self.registration_method.get_registration_transform_sitk()
        fixed_transformed_sitk = sitkh.get_transformed_sitk_image(
            self.fixed_sitk, transform_sitk)

        if not self.show_fig:
            return
        images_sitk = [self.moving_sitk, self.fixed_sitk,
                       fixed_transformed_sitk]
        sitkh.show_sitk_image(
            images_sitk,
            label=["moving_itk", "fixed_itk", "warped_fixed_itk"])
    def tearDown(self):
        # Report runtime and optionally display moving, fixed and the
        # transformed fixed image returned by the registration method.
        elapsed = self.registration_method.get_computational_time()
        print("Computational time = %s" % elapsed)

        transformed_fixed_sitk = \
            self.registration_method.get_transformed_fixed_sitk()
        # The warped moving image is still queried, although only the
        # fixed-to-moving comparison is displayed below.
        self.registration_method.get_warped_moving_sitk()

        if not self.show_fig:
            return
        images_sitk = [self.moving_sitk, self.fixed_sitk,
                       transformed_fixed_sitk]
        sitkh.show_sitk_image(images_sitk,
                              label=["moving", "fixed", "warped_fixed"])
    def run(self, debug=False):
        """
        Estimate an initial rigid transform from PCAs of the fixed and
        moving masks.

        Eigenvectors (used as matrix columns) and means are obtained from
        `self.get_pca_from_mask` for both masks. Since eigenvector signs are
        arbitrary, all four sign combinations of the first two moving
        eigenvectors are tried; the third axis is replaced by their cross
        product to obtain a right-handed basis. Each combination yields a
        rigid transform x -> R x + t with R mapping the fixed eigenbasis to
        the moving one and t chosen so the fixed mean maps onto the moving
        mean. The best candidate according to `_get_best_transform` is
        stored in `self._initial_transform_sitk`.

        :param debug: if True, display the fixed image next to the moving
            image resampled with the selected initial transform
        """
        # perform PCAs for fixed and moving images
        pca_moving = self.get_pca_from_mask(self._moving.sitk_mask)
        eigvec_moving = pca_moving.get_eigvec()
        mean_moving = pca_moving.get_mean()

        pca_fixed = self.get_pca_from_mask(self._fixed.sitk_mask)
        eigvec_fixed = pca_fixed.get_eigvec()
        mean_fixed = pca_fixed.get_mean()

        # test different initializations based on eigenvector orientations
        orientations = [
            [1, 1],
            [1, -1],
            [-1, 1],
            [-1, -1],
        ]
        transformations = []
        for orientation in orientations:
            # copy, so the sign flips do not modify eigvec_moving itself
            eigvec_moving_o = np.array(eigvec_moving)
            eigvec_moving_o[:, 0] *= orientation[0]
            eigvec_moving_o[:, 1] *= orientation[1]

            # get right-handed coordinate system
            eigvec_moving_o[:, 2] = np.cross(eigvec_moving_o[:, 0],
                                             eigvec_moving_o[:, 1])

            # transformation to align fixed with moving eigenbasis
            R = eigvec_moving_o.dot(eigvec_fixed.transpose())
            t = mean_moving - R.dot(mean_fixed)

            # build rigid transformation as sitk object
            rigid_transform_sitk = sitk.Euler3DTransform()
            rigid_transform_sitk.SetMatrix(R.flatten())
            rigid_transform_sitk.SetTranslation(t)
            transformations.append(rigid_transform_sitk)

        # get best transformation according to selected similarity measure
        self._initial_transform_sitk = self._get_best_transform(
            transformations)

        if debug:
            warped_moving_sitk = sitk.Resample(
                self._moving.sitk,
                self._fixed.sitk,
                self._initial_transform_sitk,
            )
            # the fixed image is held by self._fixed (there is no local
            # `fixed` in scope)
            sitkh.show_sitk_image([self._fixed.sitk, warped_moving_sitk])
Ejemplo n.º 8
0
    def show(self, show_segmentation=0, label=None, viewer=VIEWER, verbose=True):
        """
        Display this image via `sitkh.show_sitk_image`.

        :param show_segmentation: if truthy, overlay `self.sitk_mask`
        :param label: display label; defaults to
            "<self._filename>_<self._slice_number>"
        :param viewer: forwarded to `sitkh.show_sitk_image`
        :param verbose: forwarded to `sitkh.show_sitk_image`
        """
        label = (self._filename + "_" + str(self._slice_number)
                 if label is None else label)
        segmentation = self.sitk_mask if show_segmentation else None

        sitkh.show_sitk_image(self.sitk,
                              segmentation=segmentation,
                              label=label,
                              viewer=viewer,
                              verbose=verbose)
Ejemplo n.º 9
0
    def show(self, show_segmentation=0, label=None, viewer=VIEWER, verbose=True):
        """
        Display this image via `sitkh.show_sitk_image`.

        :param show_segmentation: if truthy, overlay `self.sitk_mask`
        :param label: display label; defaults to `self._filename`
        :param viewer: forwarded to `sitkh.show_sitk_image`
        :param verbose: forwarded to `sitkh.show_sitk_image`
        """
        label = self._filename if label is None else label
        sitk_mask = self.sitk_mask if show_segmentation else None

        sitkh.show_sitk_image(self.sitk,
                              label=label,
                              segmentation=sitk_mask,
                              viewer=viewer,
                              verbose=verbose)
Ejemplo n.º 10
0
    def tearDown(self):
        """
        Report runtime, optionally display results, and cross-check the
        outputs of the registration method.

        1) The transformed fixed image returned by the registration method
           must match `sitkh.get_transformed_sitk_image` applied with the
           returned registration transform (hard assertion).
        2) The warped moving image must match `sitk.Resample` of the moving
           image onto the fixed grid with that transform. A mismatch here is
           only reported (message plus visual diff), not raised.

        Both checks compare the L2 norm of the voxel-wise difference,
        rounded to `self.accuracy` decimals, against 0.
        """
        def _cast_and_diff_norm(reference_sitk, candidate_sitk):
            # Cast both images to float64 so they share a pixel type and the
            # difference is computed in floating point; return the cast
            # images and the norm of (candidate - reference).
            reference_sitk = sitk.Cast(reference_sitk, sitk.sitkFloat64)
            candidate_sitk = sitk.Cast(candidate_sitk, sitk.sitkFloat64)
            diff_nda = sitk.GetArrayFromImage(candidate_sitk - reference_sitk)
            return reference_sitk, candidate_sitk, np.linalg.norm(diff_nda)

        print("Computational time = %s" %
              (self.registration_method.get_computational_time()))

        transformed_fixed_sitk = \
            self.registration_method.get_transformed_fixed_sitk()
        warped_moving_sitk = self.registration_method.get_warped_moving_sitk()

        if self.show_fig:
            sitkh.show_sitk_image(
                [self.fixed_sitk, self.moving_sitk, warped_moving_sitk],
                label=["fixed", "moving", "warped_moving"])
            sitkh.show_sitk_image([self.moving_sitk, transformed_fixed_sitk],
                                  label=["moving", "warped_fixed"])

        registration_transform_sitk = \
            self.registration_method.get_registration_transform_sitk()

        # Check transformed fixed
        transformed_fixed_sitk_2 = sitkh.get_transformed_sitk_image(
            self.fixed_sitk, registration_transform_sitk)
        _, _, diff_norm = _cast_and_diff_norm(transformed_fixed_sitk,
                                              transformed_fixed_sitk_2)
        self.assertEqual(np.round(diff_norm, decimals=self.accuracy), 0)

        # Check warped moving
        warped_moving_sitk_2 = sitk.Resample(self.moving_sitk, self.fixed_sitk,
                                             registration_transform_sitk)
        warped_moving_sitk, warped_moving_sitk_2, diff_norm = \
            _cast_and_diff_norm(warped_moving_sitk, warped_moving_sitk_2)
        try:
            self.assertEqual(np.round(diff_norm, decimals=self.accuracy), 0)
        except AssertionError:
            # Deliberately reported rather than re-raised: print the failure
            # and show the images plus their difference for inspection.
            print("FAIL: " + self.id() +
                  " failed given norm of difference = %.2e > 1e-%s" %
                  (diff_norm, self.accuracy))
            sitkh.show_sitk_image(
                [
                    warped_moving_sitk,
                    warped_moving_sitk_2,
                    warped_moving_sitk_2 - warped_moving_sitk,
                ],
                label=["warped_moving", "warped_moving_2", "diff"])
Ejemplo n.º 11
0
    def test_inplane_similarity_alignment_to_reference(self):
        """
        In-plane similarity registration of a synthetically corrupted stack
        to its uncorrupted reference.

        Slice 5 of 'fetal_brain_0' (image and mask) is replicated along the
        first array axis so that every slice has the same content. The
        resulting stack is corrupted with `get_inplane_corrupted_stack`
        (in-plane rotation, translation, scaling and intensity change) and
        registered back to the reference with `IntraStackRegistration`
        (similarity transform, affine intensity correction). Original,
        corrupted and recovered stacks are displayed.

        Finally it asserts that applying the returned slice transforms to a
        copy of the corrupted stack reproduces the corrected stack returned
        by the registration (difference norm rounded to 8 decimals is 0).
        """
        filename_stack = "fetal_brain_0"

        stack = st.Stack.from_filename(
            os.path.join(self.dir_test_data, filename_stack + ".nii.gz"),
            os.path.join(self.dir_test_data, filename_stack + "_mask.nii.gz"))

        # Replicate a single slice through the whole volume. The right-hand
        # side is copied first, so the broadcast assignment does not alias.
        i_slice = 5
        nda = sitk.GetArrayFromImage(stack.sitk)
        nda_mask = sitk.GetArrayFromImage(stack.sitk_mask)
        nda[:] = np.array(nda[i_slice, :, :])
        nda_mask[:] = np.array(nda_mask[i_slice, :, :])

        stack_sitk = sitk.GetImageFromArray(nda)
        stack_sitk_mask = sitk.GetImageFromArray(nda_mask)
        stack_sitk.CopyInformation(stack.sitk)
        stack_sitk_mask.CopyInformation(stack.sitk_mask)

        stack = st.Stack.from_sitk_image(stack_sitk, stack.get_filename(),
                                         stack_sitk_mask)

        # Create in-plane motion corruption
        scale = 1.2
        angle_z = 0.05
        center_2D = (0, 0)
        translation_2D = np.array([1, -1])

        intensity_scale = 10
        intensity_bias = 50

        # Get corrupted stack (the returned motion transforms are not used)
        stack_corrupted, _, _ = get_inplane_corrupted_stack(
            stack,
            angle_z,
            center_2D,
            translation_2D,
            scale=scale,
            intensity_scale=intensity_scale,
            intensity_bias=intensity_bias,
            debug=0)

        # Perform in-plane registrations
        inplane_registration = inplanereg.IntraStackRegistration(
            stack=stack_corrupted, reference=stack)
        inplane_registration.set_transform_initializer_type("geometry")
        inplane_registration.set_intensity_correction_initializer_type(
            "affine")
        inplane_registration.set_transform_type("similarity")
        inplane_registration.set_interpolator("Linear")
        inplane_registration.set_optimizer_loss("linear")
        inplane_registration.use_stack_mask(True)
        inplane_registration.use_parameter_normalization(True)
        # priors: inverse of the applied scaling and the applied intensity
        # scale/bias
        inplane_registration.set_prior_scale(1 / scale)
        inplane_registration.set_prior_intensity_coefficients(
            (intensity_scale, intensity_bias))
        inplane_registration.set_intensity_correction_type_slice_neighbour_fit(
            "affine")
        inplane_registration.set_intensity_correction_type_reference_fit(
            "affine")
        inplane_registration.use_verbose(True)
        inplane_registration.set_alpha_reference(1)
        inplane_registration.set_alpha_neighbour(0)
        inplane_registration.set_alpha_parameter(1e10)
        inplane_registration.set_optimizer_iter_max(20)
        inplane_registration.run()
        inplane_registration.print_statistics()

        stack_registered = inplane_registration.get_corrected_stack()

        sitkh.show_sitk_image([
            stack.sitk,
            stack_corrupted.get_resampled_stack_from_slices(
                interpolator="Linear", resampling_grid=stack.sitk).sitk,
            stack_registered.get_resampled_stack_from_slices(
                interpolator="Linear", resampling_grid=stack.sitk).sitk
        ],
                              label=["original", "corrupted", "recovered"])

        # Test slice transforms: re-applying them to the corrupted stack
        # must reproduce the corrected stack
        slice_transforms_sitk = inplane_registration.get_slice_transforms_sitk(
        )

        stack_tmp = st.Stack.from_stack(stack_corrupted)
        stack_tmp.update_motion_correction_of_slices(slice_transforms_sitk)

        stack_diff_sitk = stack_tmp.get_resampled_stack_from_slices(
            resampling_grid=stack.sitk
        ).sitk - stack_registered.get_resampled_stack_from_slices(
            resampling_grid=stack.sitk).sitk

        stack_diff_nda = sitk.GetArrayFromImage(stack_diff_sitk)

        self.assertEqual(np.round(np.linalg.norm(stack_diff_nda), decimals=8),
                         0)
Ejemplo n.º 12
0
def get_inplane_corrupted_stack(stack,
                                angle_z,
                                center_2D,
                                translation_2D,
                                scale=1,
                                intensity_scale=1,
                                intensity_bias=0,
                                debug=0,
                                random=False):
    """
    Create a copy of `stack` with synthetic in-plane motion, intensity
    change and in-plane scaling.

    Two moved versions of the stack (image and mask) are resampled with
    linear interpolation:
    - the first uses rotation `angle_z_1` about z and translation
      `+translation_2D`;
    - the second uses rotation `angle_z_2` and translation
      `-translation_2D`.
    Both motions are about `center_2D` and are conjugated with the stack's
    direction/origin transform. Even-indexed slices (0, 2, ...) come from
    the second version and odd-indexed slices from the first.

    Intensities are then mapped to (I - intensity_bias) / intensity_scale,
    and the in-plane spacing is divided by `scale`.

    :param stack: Stack providing `sitk` and `sitk_mask`
    :param angle_z: rotation angle passed to
        `sitk.Euler3DTransform.SetRotation` (radians); both motions use
        `-angle_z`, scaled by a uniform random factor in [0, 1) if `random`
    :param center_2D: in-plane rotation center (z component set to 0)
    :param translation_2D: in-plane translation (z component set to 0)
    :param scale: in-plane spacing is divided by this factor
    :param intensity_scale: divisor of the intensity mapping
    :param intensity_bias: offset subtracted before scaling
    :param debug: if truthy, display intermediate images
    :param random: if True, randomize the rotation magnitudes with
        `np.random.rand`
    :return: (stack_corrupted, motion_sitk, motion_2_sitk)
    """
    # Convert to 3D:
    translation_3D = np.zeros(3)
    translation_3D[0:-1] = translation_2D

    center_3D = np.zeros(3)
    center_3D[0:-1] = center_2D

    # Transform to align physical coordinate system with stack-coordinate
    # system
    affine_centering_sitk = sitk.AffineTransform(3)
    affine_centering_sitk.SetMatrix(stack.sitk.GetDirection())
    affine_centering_sitk.SetTranslation(stack.sitk.GetOrigin())

    # NOTE(review): both rotations use -angle_z while only the translation
    # is mirrored between the two motions; confirm this is intended.

    # Corrupt first stack
    if random:
        angle_z_1 = -angle_z * np.random.rand(1)[0]
    else:
        angle_z_1 = -angle_z
    (motion_sitk,
     stack_corrupted_resampled_sitk,
     stack_corrupted_resampled_sitk_mask) = _get_inplane_motion_resampled(
         stack, affine_centering_sitk, angle_z_1, center_3D, translation_3D)

    # Corrupt second stack (opposite translation)
    if random:
        angle_z_2 = -angle_z * np.random.rand(1)[0]
    else:
        angle_z_2 = -angle_z
    (motion_2_sitk,
     stack_corrupted_2_resampled_sitk,
     stack_corrupted_2_resampled_sitk_mask) = _get_inplane_motion_resampled(
         stack, affine_centering_sitk, angle_z_2, center_3D, -translation_3D)

    # Create stack based on those two corrupted stacks: even slices from
    # the second, odd slices from the first
    nda = sitk.GetArrayFromImage(stack_corrupted_resampled_sitk)
    nda_mask = sitk.GetArrayFromImage(stack_corrupted_resampled_sitk_mask)
    nda_neg = sitk.GetArrayFromImage(stack_corrupted_2_resampled_sitk)
    nda_neg_mask = sitk.GetArrayFromImage(
        stack_corrupted_2_resampled_sitk_mask)
    nda[::2] = nda_neg[::2]
    nda_mask[::2] = nda_neg_mask[::2]

    stack_corrupted_sitk = sitk.GetImageFromArray(
        (nda - intensity_bias) / intensity_scale)
    stack_corrupted_sitk_mask = sitk.GetImageFromArray(nda_mask)
    stack_corrupted_sitk.CopyInformation(stack.sitk)
    stack_corrupted_sitk_mask.CopyInformation(stack.sitk_mask)

    # Debug: Show corrupted stacks (before scaling)
    if debug:
        sitkh.show_sitk_image([
            stack.sitk, stack_corrupted_resampled_sitk,
            stack_corrupted_2_resampled_sitk, stack_corrupted_sitk
        ],
                              label=[
                                  "original", "corrupted_1", "corrupted_2",
                                  "corrupted_final_from_1_and_2"
                              ])

    # Update in-plane scaling
    spacing = np.array(stack.sitk.GetSpacing())
    spacing[0:-1] /= scale
    stack_corrupted_sitk.SetSpacing(spacing)
    stack_corrupted_sitk_mask.SetSpacing(spacing)

    # Create Stack object
    stack_corrupted = st.Stack.from_sitk_image(stack_corrupted_sitk,
                                               "stack_corrupted",
                                               stack_corrupted_sitk_mask)

    # Debug: Show corrupted stacks (after scaling)
    if debug:
        stack_corrupted_resampled_sitk = sitk.Resample(stack_corrupted.sitk,
                                                       stack.sitk)
        sitkh.show_sitk_image([stack.sitk, stack_corrupted_resampled_sitk],
                              label=["original", "corrupted"])

    return stack_corrupted, motion_sitk, motion_2_sitk


def _get_inplane_motion_resampled(stack, affine_centering_sitk, angle_z,
                                  center_3D, translation_3D):
    """
    Build one in-plane rigid motion and resample the stack image and mask
    with it (linear interpolation).

    The Euler transform (rotation `angle_z` about z, center `center_3D`,
    translation `translation_3D`) is composed with the inverse of
    `affine_centering_sitk` and then with `affine_centering_sitk` itself,
    via `sitkh.get_composite_sitk_affine_transform`.

    :return: (motion_sitk, resampled_image_sitk, resampled_mask_sitk)
    """
    in_plane_motion_sitk = sitk.Euler3DTransform()
    in_plane_motion_sitk.SetRotation(0, 0, angle_z)
    in_plane_motion_sitk.SetCenter(center_3D)
    in_plane_motion_sitk.SetTranslation(translation_3D)
    motion_sitk = sitkh.get_composite_sitk_affine_transform(
        in_plane_motion_sitk,
        sitk.AffineTransform(affine_centering_sitk.GetInverse()))
    motion_sitk = sitkh.get_composite_sitk_affine_transform(
        affine_centering_sitk, motion_sitk)
    resampled_sitk = sitk.Resample(stack.sitk, motion_sitk, sitk.sitkLinear)
    resampled_sitk_mask = sitk.Resample(
        stack.sitk_mask, motion_sitk, sitk.sitkLinear)
    return motion_sitk, resampled_sitk, resampled_sitk_mask