Example #1
0
    def _run_discrete_shepard_reconstruction(self):
        """Reconstruct the HR volume by discrete Shepard interpolation (YVV).

        Every slice of every stack in ``self._stacks`` is resampled onto the
        grid of ``self._HR_volume.sitk`` with nearest-neighbour interpolation.
        HR voxels struck by a slice accumulate the slice intensity in a
        numerator array (N) and a hit count in a denominator array (D). Both
        arrays are smoothed with ITK's recursive YVV Gaussian filter using
        ``self._sigma_array``, and their voxel-wise ratio N/D is the result.

        If ``self._sda_mask`` is falsy, the result replaces
        ``self._HR_volume.sitk`` / ``.itk``. Otherwise it is turned into a
        binary mask via ``bm.BinaryMaskFromMaskSRREstimator`` and stored in
        ``self._HR_volume.sitk_mask`` / ``.itk_mask``.
        """
        hr_sitk = self._HR_volume.sitk
        shape = sitk.GetArrayFromImage(hr_sitk).shape
        helper_N_nda = np.zeros(shape)
        helper_D_nda = np.zeros(shape)

        default_pixel_value = 0.0

        # The slice accessor depends only on instance flags that this method
        # does not change, so select it once instead of once per slice.
        get_slice = self._get_slice[(bool(self._use_masks),
                                     bool(self._sda_mask))]

        for i in range(0, self._N_stacks):
            if self._verbose:
                ph.print_info("Stack %s/%s" % (i + 1, self._N_stacks))
            stack = self._stacks[i]
            slices = stack.get_slices()
            N_slices = stack.get_number_of_slices()

            for j in range(0, N_slices):
                slice_sitk = get_slice(slices[j])

                # Add an intensity offset so that a slice voxel of intensity
                # zero can be distinguished from the default value (0) that
                # resampling assigns outside the slice (assumes non-negative
                # slice intensities -- TODO confirm). A non-augmented addition
                # is used so the image returned by the slice accessor is
                # never modified in place.
                slice_sitk = slice_sitk + 1

                # Nearest-neighbour resampling of the slice onto the HR grid
                slice_resampled_sitk = sitk.Resample(
                    slice_sitk, hr_sitk, sitk.Euler3DTransform(),
                    sitk.sitkNearestNeighbor, default_pixel_value,
                    hr_sitk.GetPixelIDValue())
                nda_slice = sitk.GetArrayFromImage(slice_resampled_sitk)

                # HR voxels struck by the slice
                ind_nonzero = nda_slice > 0

                # Numerator: accumulate intensities with the offset removed
                helper_N_nda[ind_nonzero] += nda_slice[ind_nonzero] - 1
                # Denominator: count contributions per voxel
                helper_D_nda[ind_nonzero] += 1

        # Voxels hit by no slice have a zero denominator; they are set to one
        # before smoothing (kept from the original implementation).
        # TODO: Set zero entries to one; Otherwise results are very weird!?
        helper_D_nda[helper_D_nda == 0] = 1

        # Wrap numerator/denominator as double-precision 3D ITK images that
        # carry the HR volume's spacing, direction and origin
        pixel_type = itk.D
        dimension = 3
        image_type = itk.Image[pixel_type, dimension]
        itk2np = itk.PyBuffer[image_type]

        helper_N = itk2np.GetImageFromArray(helper_N_nda)
        helper_D = itk2np.GetImageFromArray(helper_D_nda)
        for helper in (helper_N, helper_D):
            helper.SetSpacing(hr_sitk.GetSpacing())
            helper.SetDirection(
                sitkh.get_itk_direction_from_sitk_image(hr_sitk))
            helper.SetOrigin(hr_sitk.GetOrigin())

        # Smooth numerator and denominator with the recursive YVV Gaussian
        gaussian = itk.SmoothingRecursiveYvvGaussianImageFilter[
            image_type, image_type].New()
        gaussian.SetSigmaArray(self._sigma_array)

        gaussian.SetInput(helper_N)
        gaussian.Update()
        HR_volume_update_N = gaussian.GetOutput()
        # Detach the output so the second Update() of the reused filter
        # leaves this result intact
        HR_volume_update_N.DisconnectPipeline()

        gaussian.SetInput(helper_D)
        gaussian.Update()
        HR_volume_update_D = gaussian.GetOutput()
        HR_volume_update_D.DisconnectPipeline()

        # Back to arrays and voxel-wise normalisation N / D
        nda_N = itk2np.GetArrayFromImage(HR_volume_update_N)
        nda_D = itk2np.GetArrayFromImage(HR_volume_update_D)
        nda = nda_N / nda_D.astype(float)

        HR_volume_update = sitk.GetImageFromArray(nda)
        HR_volume_update.CopyInformation(hr_sitk)

        if not self._sda_mask:
            self._HR_volume.sitk = HR_volume_update
            self._HR_volume.itk = sitkh.get_itk_from_sitk_image(
                HR_volume_update)
        else:
            # Approximate uint8 mask from float SDA outcome
            mask_estimator = bm.BinaryMaskFromMaskSRREstimator(
                HR_volume_update)
            mask_estimator.run()
            HR_volume_update = mask_estimator.get_mask_sitk()

            self._HR_volume.sitk_mask = HR_volume_update
            self._HR_volume.itk_mask = sitkh.get_itk_from_sitk_image(
                HR_volume_update)
Example #2
0
    def _run_discrete_shepard_based_on_Deriche_reconstruction(self):
        """Reconstruct the HR volume by discrete Shepard interpolation.

        Smoothing uses SimpleITK's ``SmoothingRecursiveGaussianImageFilter``.
        Every slice of every stack in ``self._stacks`` is resampled onto the
        grid of ``self._HR_volume.sitk`` with nearest-neighbour interpolation.
        HR voxels where the resampled slice is positive accumulate the slice
        intensity in a numerator array (N) and a hit count in a denominator
        array (D). Both are smoothed with a single scalar sigma,
        ``self._sigma_array[1]``, and the voxel-wise ratio N/D is the result.

        If ``self._sda_mask`` is falsy, the result replaces
        ``self._HR_volume.sitk`` / ``.itk``. Otherwise it is turned into a
        binary mask via ``bm.BinaryMaskFromMaskSRREstimator`` and stored in
        ``self._HR_volume.sitk_mask`` / ``.itk_mask``.
        """
        hr_sitk = self._HR_volume.sitk
        shape = sitk.GetArrayFromImage(hr_sitk).shape
        helper_N_nda = np.zeros(shape)
        helper_D_nda = np.zeros(shape)

        default_pixel_value = 0.0

        # The slice accessor depends only on instance flags that this method
        # does not change, so select it once instead of once per slice.
        get_slice = self._get_slice[(bool(self._use_masks),
                                     bool(self._sda_mask))]

        for i in range(0, self._N_stacks):
            if self._verbose:
                ph.print_info("Stack %s/%s" % (i + 1, self._N_stacks))
            stack = self._stacks[i]
            slices = stack.get_slices()
            N_slices = stack.get_number_of_slices()

            for j in range(0, N_slices):
                slice_sitk = get_slice(slices[j])

                # Nearest-neighbour resampling of the slice onto the HR grid
                slice_resampled_sitk = sitk.Resample(
                    slice_sitk, hr_sitk, sitk.Euler3DTransform(),
                    sitk.sitkNearestNeighbor, default_pixel_value,
                    hr_sitk.GetPixelIDValue())
                nda_slice = sitk.GetArrayFromImage(slice_resampled_sitk)

                # HR voxels struck by the slice. No intensity offset is
                # applied here, so struck voxels whose resampled intensity
                # is <= 0 are not counted.
                ind_nonzero = nda_slice > 0

                # Update numerator (intensities) and denominator (hit count)
                helper_N_nda[ind_nonzero] += nda_slice[ind_nonzero]
                helper_D_nda[ind_nonzero] += 1

        # Voxels hit by no slice have a zero denominator; they are set to one
        # before smoothing (kept from the original implementation).
        # TODO: Set zero entries to one; Otherwise results are very weird!?
        helper_D_nda[helper_D_nda == 0] = 1

        # Wrap numerator/denominator as sitk images with the HR header
        helper_N = sitk.GetImageFromArray(helper_N_nda)
        helper_D = sitk.GetImageFromArray(helper_D_nda)
        helper_N.CopyInformation(hr_sitk)
        helper_D.CopyInformation(hr_sitk)

        # Recursive Gaussian smoothing; only self._sigma_array[1] is used as
        # the (scalar) standard deviation
        gaussian = sitk.SmoothingRecursiveGaussianImageFilter()
        gaussian.SetSigma(self._sigma_array[1])

        HR_volume_update_N = gaussian.Execute(helper_N)
        HR_volume_update_D = gaussian.Execute(helper_D)

        # Scattered data approximation N / D; re-apply the HR header (might
        # be redundant)
        HR_volume_update = HR_volume_update_N / HR_volume_update_D
        HR_volume_update.CopyInformation(hr_sitk)

        if not self._sda_mask:
            self._HR_volume.sitk = HR_volume_update
            self._HR_volume.itk = sitkh.get_itk_from_sitk_image(
                HR_volume_update)
        else:
            # Approximate uint8 mask from float SDA outcome
            mask_estimator = bm.BinaryMaskFromMaskSRREstimator(
                HR_volume_update)
            mask_estimator.run()
            HR_volume_update = mask_estimator.get_mask_sitk()

            self._HR_volume.sitk_mask = HR_volume_update
            self._HR_volume.itk_mask = sitkh.get_itk_from_sitk_image(
                HR_volume_update)

        # Additional info
        if self._verbose:
            nda = sitk.GetArrayFromImage(HR_volume_update)
            print("Minimum of data array = %s" % np.min(nda))
def main():
    """Command-line entry point for volumetric MRI reconstruction.

    Parses the command-line arguments and reads the input stacks (and
    masks). It then optionally applies linear intensity correction against
    a target stack and defines the reconstruction space. Next it runs either
    a scattered data approximation (``args.sda``) or a Tikhonov (TK1)
    reconstruction; for ``TVL2``/``HuberL2`` that Tikhonov result is then
    refined with an ADMM or primal-dual solver. Results are written to
    ``args.output`` (as a mask when ``args.mask`` is set).

    Returns:
        int: 0 on completion.

    Raises:
        IOError: if ``args.reconstruction_type`` is not one of TK1L2, TVL2,
            HuberL2.
        ValueError: if ``args.output`` does not end in an allowed image
            extension, or if the target stack does not match any of the
            input filenames.
    """

    def write_reconstruction(recon_stack, path, as_mask):
        # Write either a binary mask estimated from the reconstruction or the
        # reconstruction itself.
        if as_mask:
            mask_estimator = bm.BinaryMaskFromMaskSRREstimator(
                recon_stack.sitk)
            mask_estimator.run()
            dw.DataWriter.write_mask(mask_estimator.get_mask_sitk(), path)
        else:
            dw.DataWriter.write_image(recon_stack.sitk, path)

    time_start = ph.start_timing()

    # Set print options for numpy
    np.set_printoptions(precision=3)

    # Read input
    input_parser = InputArgparser(
        description="Volumetric MRI reconstruction framework to reconstruct "
        "an isotropic, high-resolution 3D volume from multiple "
        "motion-corrected (or static) stacks of low-resolution slices.", )
    input_parser.add_filenames(required=True)
    input_parser.add_filenames_masks()
    input_parser.add_dir_input_mc()
    input_parser.add_output(required=True)
    input_parser.add_suffix_mask(default="_mask")
    input_parser.add_target_stack(default=None)
    input_parser.add_extra_frame_target(default=10)
    input_parser.add_isotropic_resolution(default=None)
    input_parser.add_intensity_correction(default=1)
    input_parser.add_reconstruction_space(default=None)
    input_parser.add_minimizer(default="lsmr")
    input_parser.add_iter_max(default=10)
    input_parser.add_reconstruction_type(default="TK1L2")
    input_parser.add_data_loss(default="linear")
    input_parser.add_data_loss_scale(default=1)
    input_parser.add_alpha(default=0.01  # TK1L2
                           # default=0.006  #TVL2, HuberL2
                           )
    input_parser.add_rho(default=0.5)
    input_parser.add_tv_solver(default="PD")
    input_parser.add_pd_alg_type(default="ALG2")
    input_parser.add_iterations(default=15)
    input_parser.add_log_config(default=1)
    input_parser.add_use_masks_srr(default=0)
    input_parser.add_slice_thicknesses(default=None)
    input_parser.add_verbose(default=0)
    input_parser.add_viewer(default="itksnap")
    input_parser.add_argument(
        "--mask",
        "-mask",
        action='store_true',
        help="If given, input images are interpreted as image masks. "
        "Obtained volumetric reconstruction will be exported in uint8 format.")
    input_parser.add_argument(
        "--sda",
        "-sda",
        action='store_true',
        help="If given, the volume is reconstructed using "
        "Scattered Data Approximation (Vercauteren et al., 2006). "
        "--alpha is considered the value for the standard deviation then. "
        "Recommended value is, e.g., --alpha 0.8")

    args = input_parser.parse_args()
    input_parser.print_arguments(args)

    if args.reconstruction_type not in ["TK1L2", "TVL2", "HuberL2"]:
        raise IOError("Reconstruction type unknown")

    # np.alltrue was removed in NumPy 2.0; str.endswith accepts a tuple of
    # suffixes, which expresses the same check directly.
    if not args.output.endswith(tuple(ALLOWED_EXTENSIONS)):
        raise ValueError("output filename '%s' invalid; "
                         "allowed image extensions are: %s" %
                         (args.output, ", ".join(ALLOWED_EXTENSIONS)))

    dir_output = os.path.dirname(args.output)
    ph.create_directory(dir_output)

    if args.log_config:
        input_parser.log_config(os.path.abspath(__file__))

    # Files to open in the viewer at the end (only filled/used if verbose)
    show_niftis = []

    # --------------------------------Read Data--------------------------------
    ph.print_title("Read Data")

    if args.mask:
        filenames_masks = args.filenames
    else:
        filenames_masks = args.filenames_masks

    data_reader = dr.MultipleImagesReader(
        file_paths=args.filenames,
        file_paths_masks=filenames_masks,
        suffix_mask=args.suffix_mask,
        dir_motion_correction=args.dir_input_mc,
        stacks_slice_thicknesses=args.slice_thicknesses,
    )
    data_reader.read_data()
    stacks = data_reader.get_data()

    ph.print_info("%d input stacks read for further processing" % len(stacks))

    # Specify target stack for intensity correction and reconstruction space
    if args.target_stack is None:
        target_stack_index = 0
    else:
        filenames = ["%s.nii.gz" % s.get_filename() for s in stacks]
        filename_target_stack = os.path.basename(args.target_stack)
        try:
            target_stack_index = filenames.index(filename_target_stack)
        except ValueError as e:
            raise ValueError(
                "--target-stack must correspond to an image as provided by "
                "--filenames") from e

    # ---------------------------Intensity Correction--------------------------
    if args.intensity_correction and not args.mask:
        ph.print_title("Intensity Correction")
        intensity_corrector = ic.IntensityCorrection()
        intensity_corrector.use_individual_slice_correction(False)
        intensity_corrector.use_stack_mask(True)
        intensity_corrector.use_reference_mask(True)
        intensity_corrector.use_verbose(False)

        for i, stack in enumerate(stacks):
            if i == target_stack_index:
                ph.print_info("Stack %d (%s): Reference image. Skipped." %
                              (i + 1, stack.get_filename()))
                continue

            ph.print_info("Stack %d (%s): Intensity Correction ... " %
                          (i + 1, stack.get_filename()),
                          newline=False)
            intensity_corrector.set_stack(stack)
            intensity_corrector.set_reference(
                stacks[target_stack_index].get_resampled_stack(
                    resampling_grid=stack.sitk,
                    interpolator="NearestNeighbor",
                ))
            intensity_corrector.run_linear_intensity_correction()
            stacks[i] = intensity_corrector.get_intensity_corrected_stack()
            print("done (c1 = %g) " %
                  intensity_corrector.get_intensity_correction_coefficients())

    # -------------------------Volumetric Reconstruction-----------------------
    ph.print_title("Volumetric Reconstruction")

    # Reconstruction space is given isotropically resampled target stack
    if args.reconstruction_space is None:
        recon0 = stacks[target_stack_index].get_isotropically_resampled_stack(
            resolution=args.isotropic_resolution,
            extra_frame=args.extra_frame_target)
        recon0 = recon0.get_cropped_stack_based_on_mask(
            boundary_i=args.extra_frame_target,
            boundary_j=args.extra_frame_target,
            boundary_k=args.extra_frame_target,
            unit="mm",
        )

    # Reconstruction space was provided by user
    else:
        recon0 = st.Stack.from_filename(args.reconstruction_space,
                                        extract_slices=False)

        # Change resolution for isotropic resolution if provided by user
        if args.isotropic_resolution is not None:
            recon0 = recon0.get_isotropically_resampled_stack(
                args.isotropic_resolution)

        # Use image information of selected target stack as recon0 serves
        # as initial value for reconstruction
        recon0 = stacks[target_stack_index].get_resampled_stack(recon0.sitk)
        recon0 = recon0.get_stack_multiplied_with_mask()

    ph.print_info("Reconstruction space defined with %s mm3 resolution" %
                  " x ".join(["%.2f" % s for s in recon0.sitk.GetSpacing()]))

    if args.sda:
        ph.print_title("Compute SDA reconstruction")
        SDA = sda.ScatteredDataApproximation(stacks,
                                             recon0,
                                             sigma=args.alpha,
                                             sda_mask=args.mask)
        SDA.run()
        recon = SDA.get_reconstruction()
        if args.mask:
            dw.DataWriter.write_mask(recon.sitk_mask, args.output)
        else:
            dw.DataWriter.write_image(recon.sitk, args.output)

        if args.verbose:
            show_niftis.insert(0, args.output)

    else:
        # TVL2/HuberL2 first compute a TK1L2 solution used as initial value
        refine_with_tv_solver = args.reconstruction_type in ["TVL2", "HuberL2"]

        if refine_with_tv_solver:
            ph.print_title("Compute Initial value for %s" %
                           args.reconstruction_type)
            SRR0 = tk.TikhonovSolver(
                stacks=stacks,
                reconstruction=recon0,
                alpha=args.alpha,
                iter_max=np.min([5, args.iter_max]),
                reg_type="TK1",
                minimizer="lsmr",
                data_loss="linear",
                use_masks=args.use_masks_srr,
            )
        else:
            ph.print_title("Compute %s reconstruction" %
                           args.reconstruction_type)
            SRR0 = tk.TikhonovSolver(
                stacks=stacks,
                reconstruction=recon0,
                alpha=args.alpha,
                iter_max=args.iter_max,
                reg_type="TK1",
                minimizer=args.minimizer,
                data_loss=args.data_loss,
                data_loss_scale=args.data_loss_scale,
                use_masks=args.use_masks_srr,
            )
        SRR0.run()

        recon = SRR0.get_reconstruction()

        if refine_with_tv_solver:
            output = ph.append_to_filename(args.output, "_initTK1L2")
        else:
            output = args.output

        write_reconstruction(recon, output, args.mask)

        if args.verbose:
            show_niftis.insert(0, output)

        if refine_with_tv_solver:
            ph.print_title("Compute %s reconstruction" %
                           args.reconstruction_type)
            if args.tv_solver == "ADMM":
                SRR = admm.ADMMSolver(
                    stacks=stacks,
                    reconstruction=st.Stack.from_stack(
                        SRR0.get_reconstruction()),
                    minimizer=args.minimizer,
                    alpha=args.alpha,
                    iter_max=args.iter_max,
                    rho=args.rho,
                    data_loss=args.data_loss,
                    iterations=args.iterations,
                    use_masks=args.use_masks_srr,
                    verbose=args.verbose,
                )

            else:
                SRR = pd.PrimalDualSolver(
                    stacks=stacks,
                    reconstruction=st.Stack.from_stack(
                        SRR0.get_reconstruction()),
                    minimizer=args.minimizer,
                    alpha=args.alpha,
                    iter_max=args.iter_max,
                    iterations=args.iterations,
                    alg_type=args.pd_alg_type,
                    reg_type="TV"
                    if args.reconstruction_type == "TVL2" else "huber",
                    data_loss=args.data_loss,
                    use_masks=args.use_masks_srr,
                    verbose=args.verbose,
                )
            SRR.run()
            recon = SRR.get_reconstruction()

            write_reconstruction(recon, args.output, args.mask)

            if args.verbose:
                show_niftis.insert(0, args.output)

    if args.verbose:
        ph.show_niftis(show_niftis, viewer=args.viewer)

    ph.print_line_separator()

    elapsed_time = ph.stop_timing(time_start)
    ph.print_title("Summary")
    print("Computational Time for Volumetric Reconstruction: %s" %
          (elapsed_time))

    return 0