def _run(self):
        """
        Alternate slice-to-volume registration and volumetric reconstruction.

        Every cycle registers the slices of the stacks to the current
        reference. After each cycle except the last one, the reconstruction
        method is run with the regularization value assigned to that cycle;
        its result becomes the reference of the next cycle and is inserted
        at the front of self._reconstructions. The computational times of
        both steps are accumulated on the instance.
        """
        ph.print_title("Two-step S2V-Registration and SRR Reconstruction")

        n_cycles = self._cycles

        # One linearly spaced regularization value per cycle. The value of
        # the last cycle is not applied in this method since the last cycle
        # does not run the reconstruction step.
        alphas = np.linspace(
            self._alpha_range[0], self._alpha_range[1], n_cycles)[0:n_cycles]

        registration = SliceToVolumeRegistration(
            stacks=self._stacks,
            reference=self._reference,
            registration_method=self._registration_method,
            verbose=self._verbose,
        )

        srr = self._reconstruction_method
        current_reference = self._reference

        for idx, alpha in enumerate(alphas):
            iteration = idx + 1

            # --- Registration of the slices to the current reference ---
            registration.set_reference(current_reference)
            registration.set_print_prefix(
                "Cycle %d/%d: " % (iteration, n_cycles))
            registration.run()

            self._computational_time_registration += \
                registration.get_computational_time()

            # The last cycle only registers; no reconstruction follows
            if iteration >= n_cycles:
                continue

            # --- Reconstruction with the alpha of this cycle ---
            srr.set_alpha(alpha)
            srr.run()

            self._computational_time_reconstruction += \
                srr.get_computational_time()

            current_reference = srr.get_reconstruction()

            # Store the intermediate reconstruction (newest entry first)
            label = "Iter%d_%s" % (
                iteration, srr.get_setting_specific_filename())
            self._reconstructions.insert(
                0, st.Stack.from_stack(current_reference, filename=label))

            if self._verbose:
                sitkh.show_stacks(
                    self._reconstructions, segmentation=self._reference)
    def _run(self, debug=1):
        """
        Run hierarchical slice-set-to-volume registration interleaved with
        volumetric reconstruction steps.

        For every cycle, the slice index sets obtained from
        self._get_slice_set_indices_per_cycle are registered to the current
        reference, the reconstruction method is run with the alpha of that
        cycle, and the result becomes the reference of the next cycle. Each
        reconstruction is inserted at the front of self._reconstructions and
        the computational times of both steps are accumulated on the
        instance.

        :param debug: if truthy, print the slice index sets of every stack
            and cycle before the registration starts
        """
        ph.print_title(
            "Hierarchical SliceSet2V-Registration and SRR Reconstruction")

        N_stacks = len(self._stacks)

        # Passed as N_min to _get_slice_set_indices_per_cycle; presumably the
        # minimum set size at which no further splitting is performed
        N_min = 1
        slice_sets_indices = [
            self._get_slice_set_indices_per_cycle(stack, N_min=N_min)
            for stack in self._stacks
        ]

        # Debug
        if debug:
            for i in range(N_stacks):
                print("Stack %d/%d:" % (i + 1, N_stacks))
                for k, v in six.iteritems(slice_sets_indices[i]):
                    print("\tCycle %d: arrays = %s" % (k + 1, str(v)))

        # The stack with the most cycles determines the number of cycles
        N_cycles = np.max(
            [len(slice_sets_indices[i]) for i in range(N_stacks)])

        reference = st.Stack.from_stack(self._reference)

        # N_cycles + 1 linearly spaced values of which the last one is
        # dropped, i.e. the upper end of the alpha range is not part of
        # alphas
        alphas = np.linspace(self._alpha_range[0], self._alpha_range[1],
                             N_cycles + 1)
        alphas = alphas[0:N_cycles]

        # Single-element list used as counter that ph.add_one is called on
        ctr_iter = [0]
        for i_cycle in range(0, N_cycles):
            self._registration_method.set_moving(reference)

            # Stacks without an entry for this cycle get an empty index set
            slice_index_sets_of_stacks = {
                i: (slice_sets_indices[i][i_cycle]
                    if i_cycle in slice_sets_indices[i] else [])
                for i in range(N_stacks)
            }

            ss2vreg = SliceSetToVolumeRegistration(
                print_prefix="Cycle %d/%d -- " % (i_cycle + 1, N_cycles),
                stacks=self._stacks,
                reference=reference,
                registration_method=self._registration_method,
                slice_index_sets_of_stacks=slice_index_sets_of_stacks,
                verbose=self._verbose,
            )
            ss2vreg.run()
            self._computational_time_registration += \
                ss2vreg.get_computational_time()

            # SRR step
            self._reconstruction_method.set_alpha(alphas[i_cycle])
            self._reconstruction_method.run()

            self._computational_time_reconstruction += \
                self._reconstruction_method.get_computational_time()

            reference = self._reconstruction_method.get_reconstruction()

            # Store SRR
            filename = "Iter%d_%s" % (
                ph.add_one(ctr_iter),
                self._reconstruction_method.get_setting_specific_filename())
            self._reconstructions.insert(
                0, st.Stack.from_stack(reference, filename=filename))
            if self._verbose:
                sitkh.show_stacks(self._reconstructions)

        # Run slice-to-volume registration in case last hierarchical run was
        # not based on individual slices
        if N_min > 1:
            s2vreg = SliceToVolumeRegistration(
                stacks=self._stacks,
                reference=reference,
                registration_method=self._registration_method,
                verbose=self._verbose)
            s2vreg.run()
            self._computational_time_registration += \
                s2vreg.get_computational_time()

            # SRR step; alphas[-1] is the value already used in the last
            # hierarchical cycle
            self._reconstruction_method.set_alpha(alphas[-1])
            self._reconstruction_method.run()

            self._computational_time_reconstruction += \
                self._reconstruction_method.get_computational_time()

            # Fetch the new reconstruction; otherwise the volume of the
            # previous cycle would be stored under the new iteration name
            reference = self._reconstruction_method.get_reconstruction()

            # Store SRR
            filename = "Iter%d_%s" % (
                ph.add_one(ctr_iter),
                self._reconstruction_method.get_setting_specific_filename())
            self._reconstructions.insert(
                0, st.Stack.from_stack(reference, filename=filename))

            if self._verbose:
                sitkh.show_stacks(self._reconstructions)
    def _run(self):
        """
        Alternate slice-to-volume registration and volumetric reconstruction
        with per-cycle regularization and outlier thresholds.

        Every cycle registers the slices to the current reference. After each
        cycle except the last one, the image is reconstructed with the value
        assigned to that cycle (set as sigma for a ScatteredDataApproximation
        reconstruction method, as alpha otherwise), followed by an SDA run
        with sda_mask=True whose result becomes the new reference. That
        reference is inserted at the front of self._reconstructions.
        """
        ph.print_title("Two-step S2V-Registration and SRR Reconstruction")

        n_cycles = self._cycles

        # One linearly spaced value per cycle for the reconstruction
        # parameter and for the outlier threshold. The reconstruction value
        # of the last cycle is not applied in this method since the last
        # cycle does not run the reconstruction step.
        alphas = np.linspace(
            self._alpha_range[0], self._alpha_range[1], n_cycles)[0:n_cycles]
        thresholds = np.linspace(
            self._threshold_range[0],
            self._threshold_range[1],
            n_cycles)[0:n_cycles]

        registration = SliceToVolumeRegistration(
            stacks=self._stacks,
            reference=self._reference,
            registration_method=self._registration_method,
            verbose=self._verbose,
            threshold_measure=self._threshold_measure,
            interleave=self._interleave,
        )

        current_reference = self._reference

        for cycle in range(n_cycles):
            iteration = cycle + 1

            # --- Registration of the slices to the current reference ---
            registration.set_reference(current_reference)
            registration.set_print_prefix(
                "Cycle %d/%d: " % (iteration, n_cycles))
            if self._outlier_rejection:
                registration.set_threshold(thresholds[cycle])

            # Smoothing is only passed on in the very first cycle and only
            # if robust registration is requested; None otherwise
            smoothing = (
                self._s2v_smoothing
                if self._use_robust_registration and cycle == 0
                else None
            )
            registration.set_s2v_smoothing(smoothing)
            registration.run()

            self._computational_time_registration += \
                registration.get_computational_time()

            # The last cycle only registers; no reconstruction follows
            if iteration >= n_cycles:
                continue

            # ---------------- Perform Image Reconstruction ---------------
            ph.print_subtitle("Volumetric Image Reconstruction")
            srr = self._reconstruction_method
            if isinstance(srr, sda.ScatteredDataApproximation):
                srr.set_sigma(alphas[cycle])
            else:
                srr.set_alpha(alphas[cycle])
            srr.run()

            self._computational_time_reconstruction += \
                srr.get_computational_time()

            current_reference = srr.get_reconstruction()

            # ------------------ Perform Image Mask SDA -------------------
            ph.print_subtitle("Volumetric Image Mask Reconstruction")
            mask_sda = sda.ScatteredDataApproximation(
                self._stacks,
                current_reference,
                sigma=self._sigma_sda_mask,
                sda_mask=True,
            )
            mask_sda.run()

            # The reference is replaced by the result of the mask SDA run
            current_reference = mask_sda.get_reconstruction()

            # -------------------- Store Reconstruction -------------------
            label = "Iter%d_%s" % (
                iteration, srr.get_setting_specific_filename())
            self._reconstructions.insert(
                0, st.Stack.from_stack(current_reference, filename=label))

            if self._verbose:
                sitkh.show_stacks(
                    self._reconstructions,
                    segmentation=self._reference,
                    viewer=self._viewer,
                )
# Example no. 4
0
    def _run(self):
        """
        Alternate slice-to-volume registration, optional slice outlier
        rejection and volumetric reconstruction.

        In the first cycle a hierarchical slice-set registration is used if
        self._use_hierarchical_registration is set; all other cycles use the
        plain slice-to-volume registration. After each cycle except the last
        one, the image is reconstructed with self._alphas[cycle] (set as
        sigma for a ScatteredDataApproximation reconstruction method, as
        alpha otherwise), followed by an SDA run with sda_mask=True whose
        result becomes the new reference and is inserted at the front of
        self._reconstructions.

        :raises RuntimeError: if outlier rejection leaves no stack
        """

        ph.print_title("Two-step S2V-Registration and SRR Reconstruction")

        s2vreg = SliceToVolumeRegistration(
            stacks=self._stacks,
            reference=self._reference,
            registration_method=self._registration_method,
            verbose=False,
            interleave=self._interleave,
        )

        reference = self._reference

        for cycle in range(0, self._cycles):

            if cycle == 0 and self._use_hierarchical_registration:
                hs2vreg = HieararchicalSliceSetRegistration(
                    stacks=self._stacks,
                    reference=reference,
                    registration_method=self._registration_method,
                    interleave=self._interleave,
                    viewer=self._viewer,
                    min_slices=1,
                    verbose=False,
                )
                hs2vreg.run()
                self._computational_time_registration += \
                    hs2vreg.get_computational_time()
            else:
                # Slice-to-volume registration step
                s2vreg.set_reference(reference)
                s2vreg.set_print_prefix("Cycle %d/%d: " %
                                        (cycle + 1, self._cycles))
                s2vreg.run()

                # Only account for s2vreg in cycles where it was actually
                # run; the hierarchical branch records its own time
                self._computational_time_registration += \
                    s2vreg.get_computational_time()

            # Reject misregistered slices
            if self._outlier_rejection:
                ph.print_subtitle("Slice Outlier Rejection (%s < %g)" % (
                    self._threshold_measure, self._thresholds[cycle]))
                # NOTE(review): the rejector is given the initial
                # self._reference, not the reference updated in the previous
                # cycle -- confirm this is intended
                outlier_rejector = outre.OutlierRejector(
                    stacks=self._stacks,
                    reference=self._reference,
                    threshold=self._thresholds[cycle],
                    measure=self._threshold_measure,
                    verbose=True,
                )
                outlier_rejector.run()
                stacks_kept = outlier_rejector.get_stacks()
                self._reconstruction_method.set_stacks(stacks_kept)

                # Test the stacks returned by the rejector so that the check
                # does not depend on self._stacks being modified in place
                if len(stacks_kept) == 0:
                    raise RuntimeError(
                        "All slices of all stacks were rejected "
                        "as outliers. Volumetric reconstruction is aborted.")

            # SRR step (skipped in the last cycle)
            if cycle < self._cycles - 1:
                # ---------------- Perform Image Reconstruction ---------------
                ph.print_subtitle("Volumetric Image Reconstruction")
                if isinstance(
                    self._reconstruction_method,
                    sda.ScatteredDataApproximation
                ):
                    self._reconstruction_method.set_sigma(self._alphas[cycle])
                else:
                    self._reconstruction_method.set_alpha(self._alphas[cycle])
                self._reconstruction_method.run()

                self._computational_time_reconstruction += \
                    self._reconstruction_method.get_computational_time()

                reference = self._reconstruction_method.get_reconstruction()

                # ------------------ Perform Image Mask SDA -------------------
                ph.print_subtitle("Volumetric Image Mask Reconstruction")
                SDA = sda.ScatteredDataApproximation(
                    self._stacks,
                    reference,
                    sigma=self._sigma_sda_mask,
                    sda_mask=True,
                )
                SDA.run()

                # reference contains updated mask based on SDA
                reference = SDA.get_reconstruction()

                # -------------------- Store Reconstruction -------------------
                filename = "Iter%d_%s" % (
                    cycle + 1,
                    self._reconstruction_method.get_setting_specific_filename()
                )
                self._reconstructions.insert(0, st.Stack.from_stack(
                    reference, filename=filename))

                if self._verbose:
                    sitkh.show_stacks(self._reconstructions,
                                      segmentation=self._reference,
                                      viewer=self._viewer)
def main():
    """
    Simulate stacks of slices from an obtained reconstruction.

    Parses the command line, reads the motion corrected stacks and the
    reconstruction, projects the reconstruction onto every kept slice and
    writes one simulated stack per input stack to the output directory. If a
    reconstruction mask is given, it is propagated to the simulated stacks
    and written alongside them. Optionally, the original stacks (including
    masks) are copied to the output directory as well.

    :raises IOError: if the mask interpolator is not an allowed one
    :return: 0 on completion
    """

    input_parser = InputArgparser(
        description="Simulate stacks from obtained reconstruction. "
        "Script simulates/projects the slices at estimated positions "
        "within reconstructed volume. Ideally, if motion correction was "
        "correct, the resulting stack of such obtained projected slices, "
        "corresponds to the originally acquired (motion corrupted) data.",
    )
    input_parser.add_filenames(required=True)
    input_parser.add_filenames_masks()
    input_parser.add_dir_input_mc(required=True)
    input_parser.add_reconstruction(required=True)
    input_parser.add_dir_output(required=True)
    input_parser.add_suffix_mask(default="_mask")
    input_parser.add_prefix_output(default="Simulated_")
    input_parser.add_option(
        option_string="--copy-data",
        type=int,
        help="Turn on/off copying of original data (including masks) to "
        "output folder.",
        default=0)
    input_parser.add_option(
        option_string="--reconstruction-mask",
        type=str,
        help="If given, reconstruction image mask is propagated to "
        "simulated stack(s) of slices as well",
        default=None)
    input_parser.add_interpolator(
        option_string="--interpolator-mask",
        help="Choose the interpolator type to propagate the reconstruction "
        "mask (%s)." % (INTERPOLATOR_TYPES),
        default="NearestNeighbor")
    input_parser.add_log_config(default=0)
    input_parser.add_verbose(default=0)
    input_parser.add_slice_thicknesses(default=None)

    args = input_parser.parse_args()
    input_parser.print_arguments(args)

    if args.interpolator_mask not in ALLOWED_INTERPOLATORS:
        raise IOError(
            "Unknown interpolator provided. Possible choices are %s" % (
                INTERPOLATOR_TYPES))

    if args.log_config:
        input_parser.log_config(os.path.abspath(__file__))

    # Read motion corrected data
    data_reader = dr.MultipleImagesReader(
        file_paths=args.filenames,
        file_paths_masks=args.filenames_masks,
        suffix_mask=args.suffix_mask,
        dir_motion_correction=args.dir_input_mc,
        stacks_slice_thicknesses=args.slice_thicknesses,
    )
    data_reader.read_data()
    stacks = data_reader.get_data()

    reconstruction = st.Stack.from_filename(
        args.reconstruction, args.reconstruction_mask, extract_slices=False)

    linear_operators = lin_op.LinearOperators()

    for stack in stacks:

        # Initialize image data array(s); positions without a kept slice
        # stay NaN in the image and zero in the mask
        nda = np.zeros_like(sitk.GetArrayFromImage(stack.sitk))
        nda[:] = np.nan

        if args.reconstruction_mask:
            nda_mask = np.zeros_like(sitk.GetArrayFromImage(stack.sitk_mask))

        # Look-up from slice number to slice; setdefault keeps the first
        # slice in case a slice number occurs more than once
        slice_by_number = {}
        for kept_slice in stack.get_slices():
            slice_by_number.setdefault(
                kept_slice.get_slice_number(), kept_slice)

        # Fill stack information "as if slice was acquired consecutively"
        # Therefore, simulated stack slices correspond to acquired slices
        # (in case motion correction was correct)
        for j in range(nda.shape[0]):
            kept_slice = slice_by_number.get(j)
            if kept_slice is None:
                continue

            simulated_slice = linear_operators.A(
                reconstruction,
                kept_slice,
                interpolator_mask=args.interpolator_mask
            )
            nda[j, :, :] = sitk.GetArrayFromImage(simulated_slice.sitk)

            if args.reconstruction_mask:
                nda_mask[j, :, :] = sitk.GetArrayFromImage(
                    simulated_slice.sitk_mask)

        # Create nifti image with same image header as original stack
        simulated_stack_sitk = sitk.GetImageFromArray(nda)
        simulated_stack_sitk.CopyInformation(stack.sitk)

        if args.reconstruction_mask:
            simulated_stack_sitk_mask = sitk.GetImageFromArray(nda_mask)
            simulated_stack_sitk_mask.CopyInformation(stack.sitk_mask)
        else:
            simulated_stack_sitk_mask = None

        simulated_stack = st.Stack.from_sitk_image(
            image_sitk=simulated_stack_sitk,
            image_sitk_mask=simulated_stack_sitk_mask,
            filename=args.prefix_output + stack.get_filename(),
            extract_slices=False,
            slice_thickness=stack.get_slice_thickness(),
        )

        if args.verbose:
            sitkh.show_stacks([
                stack, simulated_stack],
                segmentation=stack)

        # Write the propagated mask only if a reconstruction mask was given;
        # without writing it the --reconstruction-mask option has no effect
        # on the output
        simulated_stack.write(
            args.dir_output,
            write_mask=bool(args.reconstruction_mask),
            write_slices=False,
            suffix_mask=args.suffix_mask)

        if args.copy_data:
            stack.write(
                args.dir_output,
                write_mask=True,
                write_slices=False,
                suffix_mask="_mask")

    return 0
# Example no. 6
0
def main():

    time_start = ph.start_timing()

    # Set print options for numpy
    np.set_printoptions(precision=3)

    input_parser = InputArgparser(
        description="Volumetric MRI reconstruction framework to reconstruct "
        "an isotropic, high-resolution 3D volume from multiple stacks of 2D "
        "slices with motion correction. The resolution of the computed "
        "Super-Resolution Reconstruction (SRR) is given by the in-plane "
        "spacing of the selected target stack. A region of interest can be "
        "specified by providing a mask for the selected target stack. Only "
        "this region will then be reconstructed by the SRR algorithm which "
        "can substantially reduce the computational time.",
    )
    input_parser.add_filenames(required=True)
    input_parser.add_filenames_masks()
    input_parser.add_output(required=True)
    input_parser.add_suffix_mask(default="_mask")
    input_parser.add_target_stack(default=None)
    input_parser.add_search_angle(default=45)
    input_parser.add_multiresolution(default=0)
    input_parser.add_shrink_factors(default=[3, 2, 1])
    input_parser.add_smoothing_sigmas(default=[1.5, 1, 0])
    input_parser.add_sigma(default=1)
    input_parser.add_reconstruction_type(default="TK1L2")
    input_parser.add_iterations(default=15)
    input_parser.add_alpha(default=0.015)
    input_parser.add_alpha_first(default=0.2)
    input_parser.add_iter_max(default=10)
    input_parser.add_iter_max_first(default=5)
    input_parser.add_dilation_radius(default=3)
    input_parser.add_extra_frame_target(default=10)
    input_parser.add_bias_field_correction(default=0)
    input_parser.add_intensity_correction(default=1)
    input_parser.add_isotropic_resolution(default=1)
    input_parser.add_log_config(default=1)
    input_parser.add_subfolder_motion_correction()
    input_parser.add_write_motion_correction(default=1)
    input_parser.add_verbose(default=0)
    input_parser.add_two_step_cycles(default=3)
    input_parser.add_use_masks_srr(default=0)
    input_parser.add_boundary_stacks(default=[10, 10, 0])
    input_parser.add_metric(default="Correlation")
    input_parser.add_metric_radius(default=10)
    input_parser.add_reference()
    input_parser.add_reference_mask()
    input_parser.add_outlier_rejection(default=1)
    input_parser.add_threshold_first(default=0.5)
    input_parser.add_threshold(default=0.8)
    input_parser.add_interleave(default=3)
    input_parser.add_slice_thicknesses(default=None)
    input_parser.add_viewer(default="itksnap")
    input_parser.add_v2v_method(default="RegAladin")
    input_parser.add_argument(
        "--v2v-robust", "-v2v-robust",
        action='store_true',
        help="If given, a more robust volume-to-volume registration step is "
        "performed, i.e. four rigid registrations are performed using four "
        "rigid transform initializations based on "
        "principal component alignment of associated masks."
    )
    input_parser.add_argument(
        "--s2v-hierarchical", "-s2v-hierarchical",
        action='store_true',
        help="If given, a hierarchical approach for the first slice-to-volume "
        "registration cycle is used, i.e. sub-packages defined by the "
        "specified interleave (--interleave) are registered until each "
        "slice is registered independently."
    )
    input_parser.add_argument(
        "--sda", "-sda",
        action='store_true',
        help="If given, the volumetric reconstructions are performed using "
        "Scattered Data Approximation (Vercauteren et al., 2006). "
        "'alpha' is considered the final 'sigma' for the "
        "iterative adjustment. "
        "Recommended value is, e.g., --alpha 0.8"
    )
    input_parser.add_option(
        option_string="--transforms-history",
        type=int,
        help="Write entire history of applied slice motion correction "
        "transformations to motion correction output directory",
        default=0,
    )

    args = input_parser.parse_args()
    input_parser.print_arguments(args)

    rejection_measure = "NCC"
    threshold_v2v = -2  # 0.3
    debug = False

    if args.v2v_method not in V2V_METHOD_OPTIONS:
        raise ValueError("v2v-method must be in {%s}" % (
            ", ".join(V2V_METHOD_OPTIONS)))

    if np.alltrue([not args.output.endswith(t) for t in ALLOWED_EXTENSIONS]):
        raise ValueError(
            "output filename invalid; allowed extensions are: %s" %
            ", ".join(ALLOWED_EXTENSIONS))

    if args.alpha_first < args.alpha and not args.sda:
        raise ValueError("It must hold alpha-first >= alpha")

    if args.threshold_first > args.threshold:
        raise ValueError("It must hold threshold-first <= threshold")

    dir_output = os.path.dirname(args.output)
    ph.create_directory(dir_output)

    if args.log_config:
        input_parser.log_config(os.path.abspath(__file__))

    # --------------------------------Read Data--------------------------------
    ph.print_title("Read Data")
    data_reader = dr.MultipleImagesReader(
        file_paths=args.filenames,
        file_paths_masks=args.filenames_masks,
        suffix_mask=args.suffix_mask,
        stacks_slice_thicknesses=args.slice_thicknesses,
    )

    if len(args.boundary_stacks) is not 3:
        raise IOError(
            "Provide exactly three values for '--boundary-stacks' to define "
            "cropping in i-, j-, and k-dimension of the input stacks")

    data_reader.read_data()
    stacks = data_reader.get_data()
    ph.print_info("%d input stacks read for further processing" % len(stacks))

    if all(s.is_unity_mask() is True for s in stacks):
        ph.print_warning("No mask is provided! "
                         "Generated reconstruction space may be very big!")
        ph.print_warning("Consider using a mask to speed up computations")

        # args.extra_frame_target = 0
        # ph.wrint_warning("Overwritten: extra-frame-target set to 0")

    # Specify target stack for intensity correction and reconstruction space
    if args.target_stack is None:
        target_stack_index = 0
    else:
        try:
            target_stack_index = args.filenames.index(args.target_stack)
        except ValueError as e:
            raise ValueError(
                "--target-stack must correspond to an image as provided by "
                "--filenames")

    # ---------------------------Data Preprocessing---------------------------
    ph.print_title("Data Preprocessing")

    segmentation_propagator = segprop.SegmentationPropagation(
        # registration_method=regflirt.FLIRT(use_verbose=args.verbose),
        # registration_method=niftyreg.RegAladin(use_verbose=False),
        dilation_radius=args.dilation_radius,
        dilation_kernel="Ball",
    )

    data_preprocessing = dp.DataPreprocessing(
        stacks=stacks,
        segmentation_propagator=segmentation_propagator,
        use_cropping_to_mask=True,
        use_N4BiasFieldCorrector=args.bias_field_correction,
        target_stack_index=target_stack_index,
        boundary_i=args.boundary_stacks[0],
        boundary_j=args.boundary_stacks[1],
        boundary_k=args.boundary_stacks[2],
        unit="mm",
    )
    data_preprocessing.run()
    time_data_preprocessing = data_preprocessing.get_computational_time()

    # Get preprocessed stacks
    stacks = data_preprocessing.get_preprocessed_stacks()

    # Define reference/target stack for registration and reconstruction
    if args.reference is not None:
        reference = st.Stack.from_filename(
            file_path=args.reference,
            file_path_mask=args.reference_mask,
            extract_slices=False)

    else:
        reference = st.Stack.from_stack(stacks[target_stack_index])

    # ------------------------Volume-to-Volume Registration--------------------
    if len(stacks) > 1:

        if args.v2v_method == "FLIRT":
            # Define search angle ranges for FLIRT in all three dimensions
            search_angles = ["-searchr%s -%d %d" %
                             (x, args.search_angle, args.search_angle)
                             for x in ["x", "y", "z"]]
            options = (" ").join(search_angles)
            # options += " -noresample"

            vol_registration = regflirt.FLIRT(
                registration_type="Rigid",
                use_fixed_mask=True,
                use_moving_mask=True,
                options=options,
                use_verbose=False,
            )
        else:
            vol_registration = niftyreg.RegAladin(
                registration_type="Rigid",
                use_fixed_mask=True,
                use_moving_mask=True,
                # options="-ln 2 -voff",
                use_verbose=False,
            )
        v2vreg = pipeline.VolumeToVolumeRegistration(
            stacks=stacks,
            reference=reference,
            registration_method=vol_registration,
            verbose=debug,
            robust=args.v2v_robust,
        )
        v2vreg.run()
        stacks = v2vreg.get_stacks()
        time_registration = v2vreg.get_computational_time()

    else:
        time_registration = ph.get_zero_time()

    # ---------------------------Intensity Correction--------------------------
    if args.intensity_correction:
        ph.print_title("Intensity Correction")
        intensity_corrector = ic.IntensityCorrection()
        intensity_corrector.use_individual_slice_correction(False)
        intensity_corrector.use_reference_mask(True)
        intensity_corrector.use_stack_mask(True)
        intensity_corrector.use_verbose(False)

        for i, stack in enumerate(stacks):
            if i == target_stack_index:
                ph.print_info("Stack %d (%s): Reference image. Skipped." % (
                    i + 1, stack.get_filename()))
                continue
            else:
                ph.print_info("Stack %d (%s): Intensity Correction ... " % (
                    i + 1, stack.get_filename()), newline=False)
            intensity_corrector.set_stack(stack)
            intensity_corrector.set_reference(
                stacks[target_stack_index].get_resampled_stack(
                    resampling_grid=stack.sitk,
                    interpolator="NearestNeighbor",
                ))
            intensity_corrector.run_linear_intensity_correction()
            stacks[i] = intensity_corrector.get_intensity_corrected_stack()
            print("done (c1 = %g) " %
                  intensity_corrector.get_intensity_correction_coefficients())

    # ---------------------------Create first volume---------------------------
    time_tmp = ph.start_timing()

    # Isotropic resampling to define HR target space
    ph.print_title("Reconstruction Space Generation")
    HR_volume = reference.get_isotropically_resampled_stack(
        resolution=args.isotropic_resolution)
    ph.print_info(
        "Isotropic reconstruction space with %g mm resolution is created" %
        HR_volume.sitk.GetSpacing()[0])

    if args.reference is None:
        # Create joint image mask in target space
        joint_image_mask_builder = imb.JointImageMaskBuilder(
            stacks=stacks,
            target=HR_volume,
            dilation_radius=1,
        )
        joint_image_mask_builder.run()
        HR_volume = joint_image_mask_builder.get_stack()
        ph.print_info(
            "Isotropic reconstruction space is centered around "
            "joint stack masks. ")

        # Crop to space defined by mask (plus extra margin)
        HR_volume = HR_volume.get_cropped_stack_based_on_mask(
            boundary_i=args.extra_frame_target,
            boundary_j=args.extra_frame_target,
            boundary_k=args.extra_frame_target,
            unit="mm",
        )

        # Create first volume
        # If outlier rejection is activated, eliminate obvious outliers early
        # from stack and re-run SDA to get initial volume without them
        ph.print_title("First Estimate of HR Volume")
        if args.outlier_rejection and threshold_v2v > -1:
            ph.print_subtitle("SDA Approximation")
            SDA = sda.ScatteredDataApproximation(
                stacks, HR_volume, sigma=args.sigma)
            SDA.run()
            HR_volume = SDA.get_reconstruction()

            # Identify and reject outliers
            ph.print_subtitle("Eliminate slice outliers (%s < %g)" % (
                rejection_measure, threshold_v2v))
            outlier_rejector = outre.OutlierRejector(
                stacks=stacks,
                reference=HR_volume,
                threshold=threshold_v2v,
                measure=rejection_measure,
                verbose=True,
            )
            outlier_rejector.run()
            stacks = outlier_rejector.get_stacks()

        ph.print_subtitle("SDA Approximation Image")
        SDA = sda.ScatteredDataApproximation(
            stacks, HR_volume, sigma=args.sigma)
        SDA.run()
        HR_volume = SDA.get_reconstruction()

        ph.print_subtitle("SDA Approximation Image Mask")
        SDA = sda.ScatteredDataApproximation(
            stacks, HR_volume, sigma=args.sigma, sda_mask=True)
        SDA.run()
        # HR volume contains updated mask based on SDA
        HR_volume = SDA.get_reconstruction()

        HR_volume.set_filename(SDA.get_setting_specific_filename())

    time_reconstruction = ph.stop_timing(time_tmp)

    if args.verbose:
        tmp = list(stacks)
        tmp.insert(0, HR_volume)
        sitkh.show_stacks(tmp, segmentation=HR_volume, viewer=args.viewer)

    # -----------Two-step Slice-to-Volume Registration-Reconstruction----------
    if args.two_step_cycles > 0:

        # Slice-to-volume registration set-up
        if args.metric == "ANTSNeighborhoodCorrelation":
            metric_params = {"radius": args.metric_radius}
        else:
            metric_params = None
        registration = regsitk.SimpleItkRegistration(
            moving=HR_volume,
            use_fixed_mask=True,
            use_moving_mask=True,
            interpolator="Linear",
            metric=args.metric,
            metric_params=metric_params,
            use_multiresolution_framework=args.multiresolution,
            shrink_factors=args.shrink_factors,
            smoothing_sigmas=args.smoothing_sigmas,
            initializer_type="SelfGEOMETRY",
            optimizer="ConjugateGradientLineSearch",
            optimizer_params={
                "learningRate": 1,
                "numberOfIterations": 100,
                "lineSearchUpperLimit": 2,
            },
            scales_estimator="Jacobian",
            use_verbose=debug,
        )

        # Volumetric reconstruction set-up
        if args.sda:
            recon_method = sda.ScatteredDataApproximation(
                stacks,
                HR_volume,
                sigma=args.sigma,
                use_masks=args.use_masks_srr,
            )
            alpha_range = [args.sigma, args.alpha]
        else:
            recon_method = tk.TikhonovSolver(
                stacks=stacks,
                reconstruction=HR_volume,
                reg_type="TK1",
                minimizer="lsmr",
                alpha=args.alpha_first,
                iter_max=np.min([args.iter_max_first, args.iter_max]),
                verbose=True,
                use_masks=args.use_masks_srr,
            )
            alpha_range = [args.alpha_first, args.alpha]

        # Define the regularization parameters for the individual
        # reconstruction steps in the two-step cycles
        alphas = np.linspace(
            alpha_range[0], alpha_range[1], args.two_step_cycles)

        # Define outlier rejection threshold after each S2V-reg step
        thresholds = np.linspace(
            args.threshold_first, args.threshold, args.two_step_cycles)

        two_step_s2v_reg_recon = \
            pipeline.TwoStepSliceToVolumeRegistrationReconstruction(
                stacks=stacks,
                reference=HR_volume,
                registration_method=registration,
                reconstruction_method=recon_method,
                cycles=args.two_step_cycles,
                alphas=alphas[0:args.two_step_cycles - 1],
                outlier_rejection=args.outlier_rejection,
                threshold_measure=rejection_measure,
                thresholds=thresholds,
                interleave=args.interleave,
                viewer=args.viewer,
                verbose=args.verbose,
                use_hierarchical_registration=args.s2v_hierarchical,
            )
        two_step_s2v_reg_recon.run()
        HR_volume_iterations = \
            two_step_s2v_reg_recon.get_iterative_reconstructions()
        time_registration += \
            two_step_s2v_reg_recon.get_computational_time_registration()
        time_reconstruction += \
            two_step_s2v_reg_recon.get_computational_time_reconstruction()
        stacks = two_step_s2v_reg_recon.get_stacks()

    # no two-step s2v-registration/reconstruction iterations
    else:
        HR_volume_iterations = []

    # Write motion-correction results
    ph.print_title("Write Motion Correction Results")
    if args.write_motion_correction:
        dir_output_mc = os.path.join(
            dir_output, args.subfolder_motion_correction)
        ph.clear_directory(dir_output_mc)

        for stack in stacks:
            stack.write(
                dir_output_mc,
                write_stack=False,
                write_mask=False,
                write_slices=False,
                write_transforms=True,
                write_transforms_history=args.transforms_history,
            )

        if args.outlier_rejection:
            deleted_slices_dic = {}
            for i, stack in enumerate(stacks):
                deleted_slices = stack.get_deleted_slice_numbers()
                deleted_slices_dic[stack.get_filename()] = deleted_slices

            # check whether any stack was removed entirely
            stacks0 = data_preprocessing.get_preprocessed_stacks()
            if len(stacks) != len(stacks0):
                stacks_remain = [s.get_filename() for s in stacks]
                for stack in stacks0:
                    if stack.get_filename() in stacks_remain:
                        continue

                    # add info that all slices of this stack were rejected
                    deleted_slices = [
                        slice.get_slice_number()
                        for slice in stack.get_slices()
                    ]
                    deleted_slices_dic[stack.get_filename()] = deleted_slices
                    ph.print_info(
                        "All slices of stack '%s' were rejected entirely. "
                        "Information added." % stack.get_filename())

            ph.write_dictionary_to_json(
                deleted_slices_dic,
                os.path.join(
                    dir_output,
                    args.subfolder_motion_correction,
                    "rejected_slices.json"
                )
            )

    # ---------------------Final Volumetric Reconstruction---------------------
    ph.print_title("Final Volumetric Reconstruction")
    if args.sda:
        recon_method = sda.ScatteredDataApproximation(
            stacks,
            HR_volume,
            sigma=args.alpha,
            use_masks=args.use_masks_srr,
        )
    else:
        if args.reconstruction_type in ["TVL2", "HuberL2"]:
            recon_method = pd.PrimalDualSolver(
                stacks=stacks,
                reconstruction=HR_volume,
                reg_type="TV" if args.reconstruction_type == "TVL2" else "huber",
                iterations=args.iterations,
                use_masks=args.use_masks_srr,
            )
        else:
            recon_method = tk.TikhonovSolver(
                stacks=stacks,
                reconstruction=HR_volume,
                reg_type="TK1" if args.reconstruction_type == "TK1L2" else "TK0",
                use_masks=args.use_masks_srr,
            )
        recon_method.set_alpha(args.alpha)
        recon_method.set_iter_max(args.iter_max)
        recon_method.set_verbose(True)
    recon_method.run()
    time_reconstruction += recon_method.get_computational_time()
    HR_volume_final = recon_method.get_reconstruction()

    ph.print_subtitle("Final SDA Approximation Image Mask")
    SDA = sda.ScatteredDataApproximation(
        stacks, HR_volume_final, sigma=args.sigma, sda_mask=True)
    SDA.run()
    # HR volume contains updated mask based on SDA
    HR_volume_final = SDA.get_reconstruction()
    time_reconstruction += SDA.get_computational_time()

    elapsed_time_total = ph.stop_timing(time_start)

    # Write SRR result
    filename = recon_method.get_setting_specific_filename()
    HR_volume_final.set_filename(filename)
    dw.DataWriter.write_image(
        HR_volume_final.sitk,
        args.output,
        description=filename)
    dw.DataWriter.write_mask(
        HR_volume_final.sitk_mask,
        ph.append_to_filename(args.output, "_mask"),
        description=SDA.get_setting_specific_filename())

    HR_volume_iterations.insert(0, HR_volume_final)
    for stack in stacks:
        HR_volume_iterations.append(stack)

    if args.verbose:
        sitkh.show_stacks(
            HR_volume_iterations,
            segmentation=HR_volume_final,
            viewer=args.viewer,
        )

    # Summary
    ph.print_title("Summary")
    exe_file_info = os.path.basename(os.path.abspath(__file__)).split(".")[0]
    print("%s | Computational Time for Data Preprocessing: %s" %
          (exe_file_info, time_data_preprocessing))
    print("%s | Computational Time for Registrations: %s" %
          (exe_file_info, time_registration))
    print("%s | Computational Time for Reconstructions: %s" %
          (exe_file_info, time_reconstruction))
    print("%s | Computational Time for Entire Reconstruction Pipeline: %s" %
          (exe_file_info, elapsed_time_total))

    ph.print_line_separator()

    return 0
def main():
    """Run linear intensity correction of images against a reference image.

    Parses the command line, reads the images and the reference (the
    reference is removed from the list of images if it is listed there),
    and processes the images one by one. If '--registration' is enabled,
    a rigid FLIRT registration is run first with the image as fixed and
    the reference as moving image, and the resulting transform is applied
    to the image as motion-correction update. The reference is then
    resampled to the image grid and the linear intensity correction is
    run. Each corrected image is written to the output directory; the
    reference is written there as well, without intensity correction.

    Returns:
        0 after all images have been processed.
    """
    t_begin = ph.start_timing()

    np.set_printoptions(precision=3)

    parser = InputArgparser(
        description="Perform (linear) intensity correction across "
        "stacks/images given a reference stack/image", )
    parser.add_filenames(required=True)
    parser.add_dir_output(required=True)
    parser.add_reference(required=True)
    parser.add_suffix_mask(default="_mask")
    parser.add_search_angle(default=180)
    parser.add_prefix_output(default="IC_")
    parser.add_log_config(default=1)
    parser.add_option(
        option_string="--registration",
        type=int,
        help="Turn on/off registration from image to reference prior to "
        "intensity correction.",
        default=0)
    parser.add_verbose(default=0)

    args = parser.parse_args()
    parser.print_arguments(args)

    if args.log_config:
        parser.log_config(os.path.abspath(__file__))

    # The reference itself is not among the images to be corrected
    if args.reference in args.filenames:
        args.filenames.remove(args.reference)

    # Images to be corrected
    image_reader = dr.MultipleImagesReader(
        args.filenames,
        suffix_mask=args.suffix_mask,
        extract_slices=False,
    )
    image_reader.read_data()
    stacks = image_reader.get_data()

    # Reference image
    reference_reader = dr.MultipleImagesReader(
        [args.reference],
        suffix_mask=args.suffix_mask,
        extract_slices=False,
    )
    reference_reader.read_data()
    reference = reference_reader.get_data()[0]

    if args.registration:
        # Same symmetric FLIRT search range for all three axes
        angle = args.search_angle
        flirt_options = " ".join(
            "-searchr%s -%d %d" % (axis, angle, angle) for axis in "xyz")
        registration = regflirt.FLIRT(
            moving=reference,
            registration_type="Rigid",
            use_fixed_mask=True,
            use_moving_mask=True,
            options=flirt_options,
            use_verbose=False,
        )

    ph.print_title("Perform Intensity Correction")
    intensity_corrector = ic.IntensityCorrection(
        use_reference_mask=True,
        use_individual_slice_correction=False,
        prefix_corrected=args.prefix_output,
        use_verbose=False,
    )

    n_stacks = len(stacks)
    stacks_corrected = []
    for number, stack in enumerate(stacks, 1):
        if args.registration:
            ph.print_info(
                "Image %d/%d: Registration ... " % (number, n_stacks),
                newline=False)
            registration.set_fixed(stack)
            registration.run()
            stack.update_motion_correction(
                registration.get_registration_transform_sitk())
            print("done")

        ph.print_info(
            "Image %d/%d: Intensity Correction ... " % (number, n_stacks),
            newline=False)

        # Reference on the grid of the current image; its mask is the
        # product of the image mask and the resampled reference mask
        resampled = reference.get_resampled_stack(stack.sitk)
        joint_mask = stack.sitk_mask * resampled.sitk_mask
        resampled = st.Stack.from_sitk_image(
            image_sitk=resampled.sitk,
            image_sitk_mask=joint_mask,
            filename=reference.get_filename(),
        )

        intensity_corrector.set_stack(stack)
        intensity_corrector.set_reference(resampled)
        intensity_corrector.run_linear_intensity_correction()
        corrected = intensity_corrector.get_intensity_corrected_stack()
        stacks_corrected.append(corrected)
        print("done (c1 = %g) " %
              intensity_corrector.get_intensity_correction_coefficients())

        # Write the corrected image straight away
        corrected.write(
            args.dir_output,
            write_mask=True,
            suffix_mask=args.suffix_mask,
        )

        if args.verbose:
            sitkh.show_stacks([reference, corrected], segmentation=corrected)

    # The reference is written too, although it is not intensity corrected
    reference.write(
        args.dir_output,
        filename=args.prefix_output + reference.get_filename(),
        write_mask=True,
        suffix_mask=args.suffix_mask,
    )

    elapsed_time = ph.stop_timing(t_begin)

    ph.print_title("Summary")
    print("Computational Time for Intensity Correction(s): %s" %
          (elapsed_time))

    return 0
Esempio n. 8
0
def main():
    """Compute similarity measures between images and a reference image.

    Parses the command line, reads the images and the reference (with an
    optional reference mask) and verifies that every image can be
    subtracted from the reference, i.e. that all images share its image
    space. Each requested measure is evaluated between the image and the
    reference on the voxels selected by the reference mask; without a
    mask, every voxel of the reference that is not +inf is used. The
    resulting table is printed and, if '--dir-output' was given, written
    to 'similarities.txt' together with the list of image filenames in
    'filenames.txt'.

    Returns:
        0 on success.

    Raises:
        IOError: If an image does not share the image space of the
            reference.
    """
    # Set print options
    np.set_printoptions(precision=3)
    pd.set_option('display.width', 1000)

    input_parser = InputArgparser(description=".", )
    input_parser.add_filenames(required=True)
    input_parser.add_reference(required=True)
    input_parser.add_reference_mask()
    input_parser.add_dir_output(required=False)
    input_parser.add_measures(
        default=["PSNR", "RMSE", "MAE", "SSIM", "NCC", "NMI"])
    input_parser.add_verbose(default=0)
    args = input_parser.parse_args()
    input_parser.print_arguments(args)

    ph.print_title("Image similarity")
    data_reader = dr.MultipleImagesReader(args.filenames)
    data_reader.read_data()
    stacks = data_reader.get_data()

    reference = st.Stack.from_filename(args.reference, args.reference_mask)

    # The subtraction is only used as a check: it raises a RuntimeError
    # if image and reference cannot be combined voxel by voxel
    for stack in stacks:
        try:
            stack.sitk - reference.sitk
        except RuntimeError:
            raise IOError(
                "All provided images must be at the same image space")

    x_ref = sitk.GetArrayFromImage(reference.sitk)

    # Voxels on which the measures are evaluated
    if args.reference_mask is None:
        indices = np.where(x_ref != np.inf)
    else:
        x_ref_mask = sitk.GetArrayFromImage(reference.sitk_mask)
        indices = np.where(x_ref_mask > 0)

    # 'm=m' binds the measure name at definition time; otherwise every
    # lambda would refer to the last measure of the loop
    measures_dic = {
        m: lambda x, m=m: SimilarityMeasures.similarity_measures[m]
        (x[indices], x_ref[indices])
        for m in args.measures
    }

    observer = obs.Observer()
    observer.set_measures(measures_dic)
    for stack in stacks:
        nda = sitk.GetArrayFromImage(stack.sitk)
        observer.add_x(nda)

    if args.verbose:
        stacks_comparison = [s for s in stacks]
        stacks_comparison.insert(0, reference)
        sitkh.show_stacks(
            stacks_comparison,
            segmentation=reference,
        )

    observer.compute_measures()
    measures = observer.get_measures()

    # Store information in array (one row per image, one column per measure)
    error = np.zeros((len(stacks), len(measures)))
    cols = measures
    rows = []
    for i_stack, stack in enumerate(stacks):
        error[i_stack, :] = np.array([measures[m][i_stack] for m in measures])
        rows.append(stack.get_filename())

    # '--dir-output' is optional; result files are only written if it was
    # provided (os.path.join would fail on a missing directory otherwise)
    if args.dir_output is not None:
        # Ref-Mask flag is 1 if a reference mask was used and 0 otherwise
        header = "# Ref: %s, Ref-Mask: %d, %s \n" % (
            reference.get_filename(),
            args.reference_mask is not None,
            ph.get_time_stamp(),
        )
        header += "# %s\n" % ("\t").join(measures)

        path_to_file_filenames = os.path.join(
            args.dir_output, "filenames.txt")
        path_to_file_similarities = os.path.join(
            args.dir_output, "similarities.txt")

        # Write to files
        ph.write_to_file(path_to_file_similarities, header)
        ph.write_array_to_file(
            path_to_file_similarities, error, verbose=False)
        text = header
        text += "%s\n" % "\n".join(rows)
        ph.write_to_file(path_to_file_filenames, text)

    # Print to screen
    ph.print_subtitle("Computed Similarities")
    df = pd.DataFrame(error, rows, cols)
    print(df)

    return 0
Esempio n. 9
0
def main():
    """Reconstruct a volume from stacks of low-resolution slices.

    Parses the command line, reads the stacks from either '--dir-input'
    or '--filenames' and defines the reconstruction space: the
    isotropically resampled target stack, or a user-provided image onto
    which the target stack is resampled. A first-order Tikhonov
    reconstruction is always computed and written; for the
    reconstruction types 'TVL2' and 'HuberL2' it serves as initial value
    for a subsequent ADMM or primal-dual solver whose result is written
    as well. Results are optionally shown in the viewer.

    Returns:
        0 on success.

    Raises:
        IOError: If both or none of '--dir-input' and '--filenames' are
            given, or if the reconstruction type is unknown.
    """
    t_begin = ph.start_timing()

    # Set print options for numpy
    np.set_printoptions(precision=3)

    # Read input
    parser = InputArgparser(
        description="Volumetric MRI reconstruction framework to reconstruct "
        "an isotropic, high-resolution 3D volume from multiple "
        "motion-corrected (or static) stacks of low-resolution slices.", )
    parser.add_dir_input()
    parser.add_filenames()
    parser.add_image_selection()
    parser.add_dir_output(required=True)
    parser.add_prefix_output(default="SRR_")
    parser.add_suffix_mask(default="_mask")
    parser.add_target_stack_index(default=0)
    parser.add_extra_frame_target(default=10)
    parser.add_isotropic_resolution(default=None)
    parser.add_reconstruction_space(default=None)
    parser.add_minimizer(default="lsmr")
    parser.add_iter_max(default=10)
    parser.add_reconstruction_type(default="TK1L2")
    parser.add_data_loss(default="linear")
    parser.add_data_loss_scale(default=1)
    # Default alpha is chosen for TK1L2
    parser.add_alpha(default=0.02)
    parser.add_rho(default=0.5)
    parser.add_tv_solver(default="PD")
    parser.add_pd_alg_type(default="ALG2")
    parser.add_iterations(default=15)
    parser.add_subfolder_comparison()
    parser.add_provide_comparison(default=0)
    parser.add_log_script_execution(default=1)
    parser.add_verbose(default=0)
    args = parser.parse_args()
    parser.print_arguments(args)

    # Write script execution call
    if args.log_script_execution:
        parser.write_performed_script_execution(os.path.abspath(__file__))

    # --------------------------------Read Data--------------------------------
    ph.print_title("Read Data")

    from_directory = args.dir_input is not None
    from_filenames = args.filenames is not None

    # Both '--dir-input' and '--filenames' were specified
    if from_directory and from_filenames:
        raise IOError("Provide input by either '--dir-input' or '--filenames' "
                      "but not both together")

    # Neither '--dir-input' nor '--filenames' was specified
    if not (from_directory or from_filenames):
        raise IOError("Provide input by either '--dir-input' or '--filenames'")

    if from_directory:
        data_reader = dr.ImageSlicesDirectoryReader(
            path_to_directory=args.dir_input,
            suffix_mask=args.suffix_mask,
            image_selection=args.image_selection,
        )
    else:
        data_reader = dr.MultipleImagesReader(
            args.filenames,
            suffix_mask=args.suffix_mask,
        )

    if args.reconstruction_type not in ["TK1L2", "TVL2", "HuberL2"]:
        raise IOError("Reconstruction type unknown")

    # TVL2 and HuberL2 run a second solver on top of the Tikhonov result
    needs_second_solver = args.reconstruction_type in ["TVL2", "HuberL2"]

    data_reader.read_data()
    stacks = data_reader.get_data()
    ph.print_info("%d input stacks read for further processing" % len(stacks))

    target_stack = stacks[args.target_stack_index]
    if args.reconstruction_space is None:
        # Reconstruction space is the isotropically resampled target stack
        recon0 = target_stack.get_isotropically_resampled_stack(
            resolution=args.isotropic_resolution,
            extra_frame=args.extra_frame_target)
    else:
        # Reconstruction space was provided by the user
        recon0 = st.Stack.from_filename(
            args.reconstruction_space, extract_slices=False)

        # Change to isotropic resolution if provided by the user
        if args.isotropic_resolution is not None:
            recon0 = recon0.get_isotropically_resampled_stack(
                args.isotropic_resolution)

        # recon0 serves as initial value for the reconstruction, hence it
        # takes its image information from the selected target stack
        recon0 = target_stack.get_resampled_stack(recon0.sitk)
        recon0 = recon0.get_stack_multiplied_with_mask()

    if needs_second_solver:
        ph.print_title("Compute Initial value for %s" %
                       args.reconstruction_type)
    SRR0 = tk.TikhonovSolver(
        stacks=stacks,
        reconstruction=recon0,
        alpha=args.alpha,
        iter_max=args.iter_max,
        reg_type="TK1",
        minimizer=args.minimizer,
        data_loss=args.data_loss,
        data_loss_scale=args.data_loss_scale,
    )
    SRR0.run()

    recon = SRR0.get_reconstruction()
    recon.set_filename(SRR0.get_setting_specific_filename(args.prefix_output))
    recon.write(args.dir_output)

    # Reconstructions (most recent first) followed by the input stacks
    recons = [recon] + [stacks[i] for i in range(len(stacks))]

    if needs_second_solver:
        ph.print_title("Compute %s reconstruction" % args.reconstruction_type)
        initial_value = st.Stack.from_stack(SRR0.get_reconstruction())

        if args.tv_solver == "ADMM":
            SRR = admm.ADMMSolver(
                stacks=stacks,
                reconstruction=initial_value,
                minimizer=args.minimizer,
                alpha=args.alpha,
                iter_max=args.iter_max,
                rho=args.rho,
                data_loss=args.data_loss,
                iterations=args.iterations,
                verbose=args.verbose,
            )
        else:
            if args.reconstruction_type == "TVL2":
                reg_type = "TV"
            else:
                reg_type = "huber"
            SRR = pd.PrimalDualSolver(
                stacks=stacks,
                reconstruction=initial_value,
                minimizer=args.minimizer,
                alpha=args.alpha,
                iter_max=args.iter_max,
                iterations=args.iterations,
                alg_type=args.pd_alg_type,
                reg_type=reg_type,
                data_loss=args.data_loss,
                verbose=args.verbose,
            )

        # Both solvers are run and stored in the same way
        SRR.run()
        recon = SRR.get_reconstruction()
        recon.set_filename(
            SRR.get_setting_specific_filename(args.prefix_output))
        recons.insert(0, recon)
        recon.write(args.dir_output)

    if args.provide_comparison:
        # Show the results via the comparison option; the output
        # directory of the comparison is a subfolder of the output
        sitkh.show_stacks(
            recons,
            show_comparison_file=args.provide_comparison,
            dir_output=os.path.join(args.dir_output,
                                    args.subfolder_comparison),
        )
    elif args.verbose:
        sitkh.show_stacks(recons)

    ph.print_line_separator()

    elapsed_time = ph.stop_timing(t_begin)
    ph.print_title("Summary")
    print("Computational Time for Volumetric Reconstruction: %s" %
          (elapsed_time))

    return 0
Esempio n. 10
0
def main():
    """Run N4ITK bias field correction on the provided images.

    Parses the command line, reads the images, runs the bias field
    correction on each of them and writes every corrected image
    (including its mask) to the output directory. With '--verbose', the
    original and the corrected image are shown after each correction.

    Returns:
        0 after all images have been processed.
    """
    t_begin = ph.start_timing()

    np.set_printoptions(precision=3)

    parser = InputArgparser(
        description="Perform Bias Field correction on images based on N4ITK.",
    )
    parser.add_filenames(required=True)
    parser.add_dir_output(required=True)
    parser.add_suffix_mask(default="_mask")
    parser.add_prefix_output(default="N4ITK_")

    # Script-specific options: (option string, type, help text, default)
    n4itk_options = (
        ("--convergence-threshold",
         float,
         "Specify the convergence threshold.",
         1e-6),
        ("--spline-order",
         int,
         "Specify the spline order defining the bias field estimate.",
         3),
        ("--wiener-filter-noise",
         float,
         "Specify the noise estimate defining the Wiener filter.",
         0.11),
        ("--bias-field-fwhm",
         float,
         "Specify the full width at half maximum parameter characterizing "
         "the width of the Gaussian deconvolution.",
         0.15),
    )
    for option_string, option_type, option_help, option_default in \
            n4itk_options:
        parser.add_option(
            option_string=option_string,
            type=option_type,
            help=option_help,
            default=option_default,
        )

    parser.add_log_script_execution(default=1)
    parser.add_verbose(default=0)

    args = parser.parse_args()
    parser.print_arguments(args)

    # Write script execution call
    if args.log_script_execution:
        parser.write_performed_script_execution(os.path.abspath(__file__))

    # Read data
    data_reader = dr.MultipleImagesReader(
        args.filenames, suffix_mask=args.suffix_mask)
    data_reader.read_data()
    stacks = data_reader.get_data()

    # Perform Bias Field Correction
    ph.print_title("Perform Bias Field Correction")
    bias_field_corrector = n4itk.N4BiasFieldCorrection(
        convergence_threshold=args.convergence_threshold,
        spline_order=args.spline_order,
        wiener_filter_noise=args.wiener_filter_noise,
        bias_field_fwhm=args.bias_field_fwhm,
        prefix_corrected=args.prefix_output,
    )

    n_stacks = len(stacks)
    stacks_corrected = []
    for number, stack in enumerate(stacks, 1):
        ph.print_info(
            "Image %d/%d: N4ITK Bias Field Correction ... " %
            (number, n_stacks),
            newline=False)
        bias_field_corrector.set_stack(stack)
        bias_field_corrector.run_bias_field_correction()
        corrected = bias_field_corrector.get_bias_field_corrected_stack()
        stacks_corrected.append(corrected)
        print("done")
        ph.print_info(
            "Image %d/%d: Computational time = %s" %
            (number, n_stacks,
             bias_field_corrector.get_computational_time()))

        # Write the corrected image straight away
        corrected.write(
            args.dir_output,
            write_mask=True,
            suffix_mask=args.suffix_mask,
        )

        if args.verbose:
            sitkh.show_stacks([stack, corrected], segmentation=stack)

    elapsed_time = ph.stop_timing(t_begin)

    ph.print_title("Summary")
    print("Computational Time for Bias Field Correction(s): %s" %
          (elapsed_time))

    return 0
Esempio n. 11
0
def main():
    """Run the volumetric MRI reconstruction pipeline from the command line.

    The steps are executed in this order:

    1. Parse and print the command line arguments.
    2. Read the input stacks (from '--dir-input' or '--filenames').
    3. Preprocess the stacks.
    4. If '--two-step-cycles' is positive, run a volume-to-volume
       registration followed by the two-step slice-to-volume
       registration/reconstruction cycles.
    5. Run the final Super-Resolution Reconstruction (SRR), write the result
       to '--dir-output' and print a summary of the computational times.

    Returns:
        int: 0 once the pipeline has completed.

    Raises:
        IOError: If both or neither of '--dir-input' and '--filenames' are
            provided, or if '--boundary-stacks' does not hold exactly three
            values.
    """

    time_start = ph.start_timing()

    # Set print options for numpy
    np.set_printoptions(precision=3)

    # Read input
    input_parser = InputArgparser(
        description="Volumetric MRI reconstruction framework to reconstruct "
        "an isotropic, high-resolution 3D volume from multiple stacks of 2D "
        "slices with motion correction. The resolution of the computed "
        "Super-Resolution Reconstruction (SRR) is given by the in-plane "
        "spacing of the selected target stack. A region of interest can be "
        "specified by providing a mask for the selected target stack. Only "
        "this region will then be reconstructed by the SRR algorithm which "
        "can substantially reduce the computational time.", )
    input_parser.add_dir_input()
    input_parser.add_filenames()
    input_parser.add_dir_output(required=True)
    input_parser.add_suffix_mask(default="_mask")
    input_parser.add_target_stack_index(default=0)
    input_parser.add_search_angle(default=90)
    input_parser.add_multiresolution(default=0)
    input_parser.add_shrink_factors(default=[2, 1])
    input_parser.add_smoothing_sigmas(default=[1, 0])
    input_parser.add_sigma(default=0.9)
    input_parser.add_reconstruction_type(default="TK1L2")
    input_parser.add_iterations(default=15)
    input_parser.add_alpha(default=0.02)
    input_parser.add_alpha_first(default=0.05)
    input_parser.add_iter_max(default=10)
    input_parser.add_iter_max_first(default=5)
    input_parser.add_dilation_radius(default=3)
    input_parser.add_extra_frame_target(default=10)
    input_parser.add_bias_field_correction(default=0)
    input_parser.add_intensity_correction(default=0)
    input_parser.add_isotropic_resolution(default=None)
    input_parser.add_log_script_execution(default=1)
    input_parser.add_subfolder_motion_correction()
    input_parser.add_provide_comparison(default=0)
    input_parser.add_subfolder_comparison()
    input_parser.add_write_motion_correction(default=1)
    input_parser.add_verbose(default=0)
    input_parser.add_two_step_cycles(default=3)
    input_parser.add_use_masks_srr(default=1)
    input_parser.add_boundary_stacks(default=[10, 10, 0])
    input_parser.add_reference()
    input_parser.add_reference_mask()

    args = input_parser.parse_args()
    input_parser.print_arguments(args)

    # Write script execution call
    if args.log_script_execution:
        input_parser.write_performed_script_execution(
            os.path.abspath(__file__))

    # Use FLIRT for volume-to-volume reg. step. Otherwise, RegAladin is used.
    use_flirt_for_v2v_registration = True

    # --------------------------------Read Data--------------------------------
    ph.print_title("Read Data")

    # Both '--dir-input' and '--filenames' were specified
    if args.filenames is not None and args.dir_input is not None:
        raise IOError("Provide input by either '--dir-input' or '--filenames' "
                      "but not both together")

    # '--dir-input' specified
    elif args.dir_input is not None:
        data_reader = dr.ImageDirectoryReader(args.dir_input,
                                              suffix_mask=args.suffix_mask)

    # '--filenames' specified
    elif args.filenames is not None:
        data_reader = dr.MultipleImagesReader(args.filenames,
                                              suffix_mask=args.suffix_mask)

    # Neither '--dir-input' nor '--filenames' was specified
    else:
        raise IOError("Provide input by either '--dir-input' or '--filenames'")

    # Compare by value: an identity check ('is not') against an int literal
    # only works by accident of integer caching and triggers a SyntaxWarning
    # on recent Python versions.
    if len(args.boundary_stacks) != 3:
        raise IOError(
            "Provide exactly three values for '--boundary-stacks' to define "
            "cropping in i-, j-, and k-dimension of the input stacks")

    data_reader.read_data()
    stacks = data_reader.get_data()
    ph.print_info("%d input stacks read for further processing" % len(stacks))

    if all(s.is_unity_mask() is True for s in stacks):
        ph.print_warning("No mask is provided! "
                         "Generated reconstruction space may be very big!")

    # ---------------------------Data Preprocessing---------------------------
    ph.print_title("Data Preprocessing")

    segmentation_propagator = segprop.SegmentationPropagation(
        # registration_method=regflirt.FLIRT(use_verbose=args.verbose),
        dilation_radius=args.dilation_radius,
        dilation_kernel="Ball",
    )

    data_preprocessing = dp.DataPreprocessing(
        stacks=stacks,
        segmentation_propagator=segmentation_propagator,
        use_cropping_to_mask=True,
        use_N4BiasFieldCorrector=args.bias_field_correction,
        use_intensity_correction=args.intensity_correction,
        target_stack_index=args.target_stack_index,
        boundary_i=args.boundary_stacks[0],
        boundary_j=args.boundary_stacks[1],
        boundary_k=args.boundary_stacks[2],
        unit="mm",
    )
    data_preprocessing.run()
    time_data_preprocessing = data_preprocessing.get_computational_time()

    # Get preprocessed stacks
    stacks = data_preprocessing.get_preprocessed_stacks()

    # Define reference/target stack for registration and reconstruction:
    # either the externally provided '--reference' image or a copy of the
    # selected target stack
    if args.reference is not None:
        reference = st.Stack.from_filename(file_path=args.reference,
                                           file_path_mask=args.reference_mask,
                                           extract_slices=False)

    else:
        reference = st.Stack.from_stack(stacks[args.target_stack_index])

    # ------------------------Volume-to-Volume Registration--------------------
    if args.two_step_cycles > 0:
        # Define search angle ranges for FLIRT in all three dimensions
        search_angles = [
            "-searchr%s -%d %d" % (x, args.search_angle, args.search_angle)
            for x in ["x", "y", "z"]
        ]
        search_angles = (" ").join(search_angles)

        if use_flirt_for_v2v_registration:
            vol_registration = regflirt.FLIRT(
                registration_type="Rigid",
                use_fixed_mask=True,
                use_moving_mask=True,
                options=search_angles,
                use_verbose=False,
            )
        else:
            vol_registration = niftyreg.RegAladin(
                registration_type="Rigid",
                use_fixed_mask=True,
                use_moving_mask=True,
                use_verbose=False,
            )
        v2vreg = pipeline.VolumeToVolumeRegistration(
            stacks=stacks,
            reference=reference,
            registration_method=vol_registration,
            verbose=args.verbose,
        )
        v2vreg.run()
        stacks = v2vreg.get_stacks()
        time_registration = v2vreg.get_computational_time()

    else:
        # No registration is performed at all; start accumulating from zero
        time_registration = ph.get_zero_time()

    # ---------------------------Create first volume---------------------------
    time_tmp = ph.start_timing()

    # Isotropic resampling to define HR target space
    ph.print_title("Reconstruction Space Generation")
    HR_volume = reference.get_isotropically_resampled_stack(
        resolution=args.isotropic_resolution)
    ph.print_info(
        "Isotropic reconstruction space with %g mm resolution is created" %
        HR_volume.sitk.GetSpacing()[0])

    # Without an external reference, the reconstruction space is derived
    # from the stacks and a first HR estimate is computed. With an external
    # reference, the isotropically resampled reference is used as is.
    if args.reference is None:
        # Create joint image mask in target space
        joint_image_mask_builder = imb.JointImageMaskBuilder(
            stacks=stacks,
            target=HR_volume,
            dilation_radius=1,
        )
        joint_image_mask_builder.run()
        HR_volume = joint_image_mask_builder.get_stack()
        ph.print_info("Isotropic reconstruction space is centered around "
                      "joint stack masks. ")

        # Crop to space defined by mask (plus extra margin)
        HR_volume = HR_volume.get_cropped_stack_based_on_mask(
            boundary_i=args.extra_frame_target,
            boundary_j=args.extra_frame_target,
            boundary_k=args.extra_frame_target,
            unit="mm",
        )

        # Scattered Data Approximation to get first estimate of HR volume
        ph.print_title("First Estimate of HR Volume")
        SDA = sda.ScatteredDataApproximation(stacks,
                                             HR_volume,
                                             sigma=args.sigma)
        SDA.run()
        HR_volume = SDA.get_reconstruction()

    time_reconstruction = ph.stop_timing(time_tmp)

    if args.verbose:
        tmp = list(stacks)
        tmp.insert(0, HR_volume)
        sitkh.show_stacks(tmp, segmentation=HR_volume)

    # ----------------Two-step Slice-to-Volume Registration SRR----------------
    # Solver used for the intermediate reconstructions within the two-step
    # cycles; it starts from the '--alpha-first'/'--iter-max-first' settings
    SRR = tk.TikhonovSolver(
        stacks=stacks,
        reconstruction=HR_volume,
        reg_type="TK1",
        minimizer="lsmr",
        alpha=args.alpha_first,
        iter_max=args.iter_max_first,
        verbose=True,
        use_masks=args.use_masks_srr,
    )

    if args.two_step_cycles > 0:

        registration = regsitk.SimpleItkRegistration(
            moving=HR_volume,
            use_fixed_mask=True,
            use_moving_mask=True,
            use_verbose=args.verbose,
            interpolator="Linear",
            metric="Correlation",
            use_multiresolution_framework=args.multiresolution,
            shrink_factors=args.shrink_factors,
            smoothing_sigmas=args.smoothing_sigmas,
            initializer_type="SelfGEOMETRY",
            optimizer="ConjugateGradientLineSearch",
            optimizer_params={
                "learningRate": 1,
                "numberOfIterations": 100,
                "lineSearchUpperLimit": 2,
            },
            scales_estimator="Jacobian",
        )
        two_step_s2v_reg_recon = \
            pipeline.TwoStepSliceToVolumeRegistrationReconstruction(
                stacks=stacks,
                reference=HR_volume,
                registration_method=registration,
                reconstruction_method=SRR,
                cycles=args.two_step_cycles,
                alpha_range=[args.alpha_first, args.alpha],
                verbose=args.verbose,
            )
        two_step_s2v_reg_recon.run()
        HR_volume_iterations = \
            two_step_s2v_reg_recon.get_iterative_reconstructions()
        time_registration += \
            two_step_s2v_reg_recon.get_computational_time_registration()
        time_reconstruction += \
            two_step_s2v_reg_recon.get_computational_time_reconstruction()
    else:
        HR_volume_iterations = []

    # Write motion-correction results
    if args.write_motion_correction:
        for stack in stacks:
            stack.write(
                os.path.join(args.dir_output,
                             args.subfolder_motion_correction),
                write_mask=True,
                write_slices=True,
                write_transforms=True,
                suffix_mask=args.suffix_mask,
            )

    # ------------------Final Super-Resolution Reconstruction------------------
    ph.print_title("Final Super-Resolution Reconstruction")
    # "TVL2"/"HuberL2" select the primal-dual solver; any other
    # reconstruction type falls back to the Tikhonov solver
    if args.reconstruction_type in ["TVL2", "HuberL2"]:
        SRR = pd.PrimalDualSolver(
            stacks=stacks,
            reconstruction=HR_volume,
            reg_type="TV" if args.reconstruction_type == "TVL2" else "huber",
            iterations=args.iterations,
        )
    else:
        SRR = tk.TikhonovSolver(
            stacks=stacks,
            reconstruction=HR_volume,
            reg_type="TK1" if args.reconstruction_type == "TK1L2" else "TK0",
            use_masks=args.use_masks_srr,
        )
    SRR.set_alpha(args.alpha)
    SRR.set_iter_max(args.iter_max)
    SRR.set_verbose(True)
    SRR.run()
    time_reconstruction += SRR.get_computational_time()

    elapsed_time_total = ph.stop_timing(time_start)

    # Write SRR result
    HR_volume_final = SRR.get_reconstruction()
    HR_volume_final.set_filename(SRR.get_setting_specific_filename())
    HR_volume_final.write(args.dir_output,
                          write_mask=True,
                          suffix_mask=args.suffix_mask)

    # Assemble everything to be visualized: final SRR first, then the
    # intermediate reconstructions, then the (motion-corrected) stacks
    HR_volume_iterations.insert(0, HR_volume_final)
    for stack in stacks:
        HR_volume_iterations.append(stack)

    if args.verbose and not args.provide_comparison:
        sitkh.show_stacks(HR_volume_iterations, segmentation=HR_volume)
    # HR_volume_final.show()

    # Show SRR together with linearly resampled input data.
    # Additionally, a script is generated to open files
    if args.provide_comparison:
        sitkh.show_stacks(
            HR_volume_iterations,
            segmentation=HR_volume,
            show_comparison_file=args.provide_comparison,
            dir_output=os.path.join(args.dir_output,
                                    args.subfolder_comparison),
        )

    # Summary
    ph.print_title("Summary")
    print("Computational Time for Data Preprocessing: %s" %
          (time_data_preprocessing))
    print("Computational Time for Registrations: %s" % (time_registration))
    print("Computational Time for Reconstructions: %s" % (time_reconstruction))
    print("Computational Time for Entire Reconstruction Pipeline: %s" %
          (elapsed_time_total))

    ph.print_line_separator()

    return 0
Esempio n. 12
0
    def test_inplane_rigid_alignment_to_reference_with_intensity_correction_affine(
            self):
        """Check in-plane rigid registration to a reference with affine
        intensity correction.

        A reconstruction resampled onto the grid of a test stack serves as
        reference. Its corrupted counterpart is registered back to it and two
        properties are asserted:

        1) the last two columns of the estimated parameters match the applied
           intensity scale and bias (after rounding to zero decimals);
        2) applying the returned slice transforms to the corrupted stack
           yields the same resampled image as the corrected stack.
        """

        name_stack = "fetal_brain_0"
        name_recon = "FetalBrain_reconstruction_3stacks_myAlg"

        # Reference: reconstruction resampled onto the grid of the test stack
        stack_sitk = sitk.ReadImage(self.dir_test_data + name_stack +
                                    ".nii.gz")
        recon_sitk = sitk.ReadImage(self.dir_test_data + name_recon +
                                    ".nii.gz")
        stack = st.Stack.from_sitk_image(
            sitk.Resample(recon_sitk, stack_sitk), "original")

        # Parameters of the in-plane motion and intensity corruption
        rotation_z = 0.01
        rotation_center_2D = (0, 0)
        shift_2D = np.array([1, 0])
        intensity_scale = 5
        intensity_bias = 5

        # Only the corrupted stack is needed; the returned motions are unused
        stack_corrupted, _, _ = get_inplane_corrupted_stack(
            stack,
            rotation_z,
            rotation_center_2D,
            shift_2D,
            intensity_scale=intensity_scale,
            intensity_bias=intensity_bias)

        # Configure and run the in-plane rigid registration
        registration = inplanereg.IntraStackRegistration(
            stack_corrupted, stack)
        registration.set_transform_type("rigid")
        registration.set_transform_initializer_type("identity")
        registration.set_optimizer_loss("linear")
        registration.set_intensity_correction_initializer_type("affine")
        registration.set_intensity_correction_type_slice_neighbour_fit(
            "affine")
        registration.use_parameter_normalization(True)
        registration.use_verbose(True)
        registration.use_stack_mask(True)
        # Prior deliberately offset from the applied scale/bias
        prior_coefficients = (intensity_scale - 0.4, intensity_bias + 0.7)
        registration.set_prior_intensity_coefficients(prior_coefficients)
        registration.set_alpha_reference(1)
        registration.set_alpha_neighbour(1)
        registration.set_alpha_parameter(1e3)
        registration.set_optimizer_iter_max(15)
        registration.use_verbose(True)
        registration.run()
        registration.print_statistics()

        stack_registered = registration.get_corrected_stack()
        parameters = registration.get_parameters()

        # Visual comparison of reference, corrupted and corrected stack
        registered_resampled = \
            stack_registered.get_resampled_stack_from_slices(
                resampling_grid=None, interpolator="Linear")
        sitkh.show_stacks([stack, stack_corrupted, registered_resampled])

        # 1) Test estimated intensity correction coefficients
        applied_coefficients = np.array([intensity_scale, intensity_bias])
        error_coefficients = np.linalg.norm(
            parameters[:, -2:] - applied_coefficients)
        self.assertEqual(np.round(error_coefficients, decimals=0), 0)

        # 2) Test slice transforms
        transforms_sitk = registration.get_slice_transforms_sitk()

        stack_updated = st.Stack.from_stack(stack_corrupted)
        stack_updated.update_motion_correction_of_slices(transforms_sitk)

        updated_sitk = stack_updated.get_resampled_stack_from_slices(
            resampling_grid=stack.sitk).sitk
        registered_sitk = stack_registered.get_resampled_stack_from_slices(
            resampling_grid=stack.sitk).sitk

        diff_nda = sitk.GetArrayFromImage(updated_sitk - registered_sitk)
        error_transforms = np.linalg.norm(diff_nda)

        self.assertEqual(np.round(error_transforms, decimals=8), 0)
Esempio n. 13
0
    def test_inplane_rigid_alignment_to_reference(self):
        """Check in-plane rigid registration of a corrupted stack to its
        uncorrupted original.

        The estimated parameters are only printed. The single assertion
        verifies that applying the returned slice transforms to the corrupted
        stack yields the same resampled image as the corrected stack.
        """

        name_stack = "fetal_brain_0"

        # Reference: the test stack together with its mask
        path_image = os.path.join(self.dir_test_data, name_stack + ".nii.gz")
        path_mask = os.path.join(self.dir_test_data,
                                 name_stack + "_mask.nii.gz")
        stack = st.Stack.from_filename(path_image, path_mask)

        # Parameters of the in-plane motion corruption
        rotation_z = 0.1
        rotation_center_2D = (0, 0)
        shift_2D = np.array([1, -2])

        # Only the corrupted stack is needed; the returned motions are unused
        stack_corrupted, _, _ = get_inplane_corrupted_stack(
            stack, rotation_z, rotation_center_2D, shift_2D)

        # Configure and run the in-plane registration
        registration = inplanereg.IntraStackRegistration(
            stack_corrupted, stack)
        registration.set_transform_initializer_type("moments")
        registration.set_optimizer_iter_max(10)
        registration.set_alpha_neighbour(0)
        registration.set_alpha_parameter(0)
        registration.use_stack_mask(1)
        registration.use_reference_mask(0)
        registration.set_optimizer_loss("linear")
        registration.run()
        registration.print_statistics()

        stack_registered = registration.get_corrected_stack()
        parameters = registration.get_parameters()

        # Visual comparison of reference, corrupted and corrected stack
        registered_resampled = \
            stack_registered.get_resampled_stack_from_slices(
                interpolator="Linear", resampling_grid=stack.sitk)
        sitkh.show_stacks([stack, stack_corrupted, registered_resampled])

        # NOTE: the estimated parameters are not asserted, only shown
        print(parameters)

        # 2) Test slice transforms
        transforms_sitk = registration.get_slice_transforms_sitk()

        stack_updated = st.Stack.from_stack(stack_corrupted)
        stack_updated.update_motion_correction_of_slices(transforms_sitk)

        updated_sitk = stack_updated.get_resampled_stack_from_slices(
            resampling_grid=stack.sitk).sitk
        registered_sitk = stack_registered.get_resampled_stack_from_slices(
            resampling_grid=stack.sitk).sitk

        diff_nda = sitk.GetArrayFromImage(updated_sitk - registered_sitk)
        error_transforms = np.linalg.norm(diff_nda)

        self.assertEqual(np.round(error_transforms, decimals=8), 0)
Esempio n. 14
0
    def test_inplane_rigid_alignment_to_reference_multimodal(self):
        """Check in-plane registration to a reference using the
        "partial_derivative" reference fit term and no intensity correction.

        A reconstruction resampled onto the grid of a test stack (carrying
        that stack's mask) serves as reference. The single assertion verifies
        that applying the returned slice transforms to the corrupted stack
        yields the same resampled image as the corrected stack; the estimated
        parameters themselves are not asserted.
        """

        name_stack = "fetal_brain_0"
        name_recon = "FetalBrain_reconstruction_3stacks_myAlg"

        # Test stack; provides the resampling grid and the mask
        path_image = os.path.join(self.dir_test_data, name_stack + ".nii.gz")
        path_mask = os.path.join(self.dir_test_data,
                                 name_stack + "_mask.nii.gz")
        stack_grid = st.Stack.from_filename(path_image, path_mask)

        # Reference: reconstruction resampled onto the test stack grid
        recon = st.Stack.from_filename(
            os.path.join(self.dir_test_data, name_recon))
        recon_on_grid_sitk = recon.get_resampled_stack_from_slices(
            resampling_grid=stack_grid.sitk, interpolator="Linear").sitk
        stack = st.Stack.from_sitk_image(recon_on_grid_sitk, "original",
                                         stack_grid.sitk_mask)

        # Parameters of the in-plane motion corruption; intensities are left
        # unchanged (scale 1, bias 0)
        inplane_scale = 1.05
        rotation_z = 0.05
        rotation_center_2D = (0, 0)
        shift_2D = np.array([1, -2])
        intensity_scale = 1
        intensity_bias = 0

        # Only the corrupted stack is needed; the returned motions are unused
        stack_corrupted, _, _ = get_inplane_corrupted_stack(
            stack,
            rotation_z,
            rotation_center_2D,
            shift_2D,
            intensity_scale=intensity_scale,
            scale=inplane_scale,
            intensity_bias=intensity_bias)

        # Configure and run the in-plane registration
        registration = inplanereg.IntraStackRegistration(
            stack_corrupted, stack)
        registration.set_image_transform_reference_fit_term(
            "partial_derivative")
        registration.set_transform_initializer_type("moments")
        # Switch off every intensity correction
        registration.set_intensity_correction_initializer_type(None)
        registration.set_intensity_correction_type_slice_neighbour_fit(None)
        registration.set_intensity_correction_type_reference_fit(None)
        registration.use_parameter_normalization(True)
        registration.use_verbose(True)
        registration.set_optimizer_loss("linear")
        registration.set_alpha_reference(100)
        registration.set_alpha_neighbour(0)
        registration.set_alpha_parameter(1)
        registration.set_optimizer_iter_max(10)
        registration.run()
        registration.print_statistics()

        stack_registered = registration.get_corrected_stack()
        parameters = registration.get_parameters()

        # Visual comparison of reference, corrupted and corrected stack
        registered_resampled = \
            stack_registered.get_resampled_stack_from_slices(
                resampling_grid=None, interpolator="Linear")
        sitkh.show_stacks([stack, stack_corrupted, registered_resampled])

        # 2) Test slice transforms
        transforms_sitk = registration.get_slice_transforms_sitk()

        stack_updated = st.Stack.from_stack(stack_corrupted)
        stack_updated.update_motion_correction_of_slices(transforms_sitk)

        updated_sitk = stack_updated.get_resampled_stack_from_slices(
            resampling_grid=stack.sitk).sitk
        registered_sitk = stack_registered.get_resampled_stack_from_slices(
            resampling_grid=stack.sitk).sitk

        diff_nda = sitk.GetArrayFromImage(updated_sitk - registered_sitk)
        error_transforms = np.linalg.norm(diff_nda)

        self.assertEqual(np.round(error_transforms, decimals=8), 0)
Esempio n. 15
0
def main():
    """Rigidly register a reconstruction (moving) to a template (fixed).

    The registration is estimated with FLIRT on the first moving image and
    applied to all moving images, each followed by a RegAladin refinement.
    Unless '--transform-only' is set, the aligned images are resampled to
    the template space and written to '--dir-output'. If '--dir-input' is
    given, the slices read from that directory are updated with the obtained
    transform and written to the 'motion_correction' subfolder.

    Returns:
        int: 0 once all outputs have been written.
    """

    time_start = ph.start_timing()

    np.set_printoptions(precision=3)

    # ------------------------------Parse Input--------------------------------
    input_parser = InputArgparser(
        description="Register an obtained reconstruction (moving) "
        "to a template image/space (fixed) using rigid registration. "
        "The resulting registration can optionally be applied to previously "
        "obtained motion correction slice transforms so that a volumetric "
        "reconstruction is possible in the (standard anatomical) space "
        "defined by the fixed.", )
    input_parser.add_fixed(required=True)
    input_parser.add_moving(
        required=True,
        nargs="+",
        help="Specify moving image to be warped to fixed space. "
        "If multiple images are provided, all images will be transformed "
        "uniformly according to the registration obtained for the first one.")
    input_parser.add_dir_output(required=True)
    input_parser.add_dir_input()
    input_parser.add_suffix_mask(default="_mask")
    input_parser.add_search_angle(default=180)
    input_parser.add_option(
        option_string="--transform-only",
        type=int,
        help="Turn on/off functionality to transform moving image(s) to fixed "
        "image only, i.e. no resampling to fixed image space",
        default=0)
    input_parser.add_option(
        option_string="--write-transform",
        type=int,
        help="Turn on/off functionality to write registration transform",
        default=0)
    input_parser.add_verbose(default=0)

    args = input_parser.parse_args()
    input_parser.print_arguments(args)

    use_reg_aladin_for_refinement = True

    # --------------------------------Read Data--------------------------------
    ph.print_title("Read Data")

    # NOTE(review): the mask suffix is hard-coded for both readers although
    # '--suffix-mask' is parsed above -- confirm whether this is intended
    reader_moving = dr.MultipleImagesReader(args.moving, suffix_mask="_mask")
    reader_moving.read_data()
    moving = reader_moving.get_data()

    reader_fixed = dr.MultipleImagesReader([args.fixed], suffix_mask="_mask")
    reader_fixed.read_data()
    fixed = reader_fixed.get_data()[0]

    # -------------------Register Reconstruction to Template-------------------
    ph.print_title("Register Reconstruction to Template")

    # Search angle ranges for FLIRT in all three dimensions
    search_angles = " ".join(
        "-searchr%s -%d %d" % (axis, args.search_angle, args.search_angle)
        for axis in ["x", "y", "z"])
    flirt_options = [search_angles]

    # The first moving image takes FLIRT's 'fixed' role and the template the
    # 'moving' role; masks are not used for this step
    flirt = regflirt.FLIRT(
        fixed=moving[0],
        moving=fixed,
        registration_type="Rigid",
        use_verbose=False,
        options=" ".join(flirt_options),
    )
    ph.print_info("Run Registration (FLIRT) ... ", newline=False)
    flirt.run()
    print("done")
    transform_sitk = flirt.get_registration_transform_sitk()

    if args.write_transform:
        sitk.WriteTransform(
            transform_sitk,
            os.path.join(args.dir_output, "registration_transform_sitk.txt"))

    # Rigidly align each reconstruction (moving) with the template (fixed)
    for image in moving:
        image.update_motion_correction(transform_sitk)

        if not use_reg_aladin_for_refinement:
            continue

        # Refine with RegAladin. Rationale: FLIRT has the better capture
        # range, but RegAladin seems to find the better alignment once it is
        # within its capture range.
        # NOTE(review): 'transform_sitk' accumulates the refinement of every
        # moving image, so later images start from the composite of all
        # previous ones -- confirm this is intended
        reg_aladin = niftyreg.RegAladin(
            fixed=image,
            use_fixed_mask=True,
            moving=fixed,
            registration_type="Rigid",
            use_verbose=False,
        )
        ph.print_info("Run Registration (RegAladin) ... ", newline=False)
        reg_aladin.run()
        print("done")
        refinement_sitk = reg_aladin.get_registration_transform_sitk()
        image.update_motion_correction(refinement_sitk)
        transform_sitk = sitkh.get_composite_sitk_affine_transform(
            refinement_sitk, transform_sitk)

    # Transform-only mode: write the aligned images without resampling
    if args.transform_only:
        for image in moving:
            image.write(args.dir_output, write_mask=False)
        ph.exit()

    # Resample reconstruction (moving) to template space (fixed)
    warped_moving = []
    for image in moving:
        warped_moving.append(
            image.get_resampled_stack(fixed.sitk, interpolator="Linear"))

    for warped in warped_moving:
        warped.set_filename(
            warped.get_filename() + "ResamplingToTemplateSpace")

        if args.verbose:
            sitkh.show_stacks([fixed, warped], segmentation=fixed)

        # Write resampled reconstruction (moving)
        warped.write(args.dir_output, write_mask=False)

    # Optionally propagate the obtained transform to the slices read from
    # '--dir-input'
    if args.dir_input is not None:
        reader_slices = dr.ImageSlicesDirectoryReader(
            path_to_directory=args.dir_input, suffix_mask=args.suffix_mask)
        reader_slices.read_data()
        stacks = reader_slices.get_data()

        dir_motion_correction = "motion_correction"
        for index, stack in enumerate(stacks, start=1):
            stack.update_motion_correction(transform_sitk)
            ph.print_info("Stack %d/%d: All slice transforms updated" %
                          (index, len(stacks)))

            # Write transformed slices
            stack.write(
                os.path.join(args.dir_output, dir_motion_correction),
                write_mask=True,
                write_slices=True,
                write_transforms=True,
                suffix_mask=args.suffix_mask,
            )

    elapsed_time_total = ph.stop_timing(time_start)

    # Summary
    ph.print_title("Summary")
    print("Computational Time: %s" % (elapsed_time_total))

    return 0
def main():
    """Command-line entry point for the volumetric reconstruction.

    Parses the command-line arguments, reads the input stacks (and their
    masks), optionally runs a linear intensity correction of every stack
    against the target stack, defines the reconstruction space and then
    computes the volume either by Scattered Data Approximation (``--sda``)
    or by a TK1L2 / TVL2 / HuberL2 reconstruction. The result is written to
    ``args.output``; with ``--mask`` it is written as a mask instead of an
    image.

    Returns:
        0 after the reconstruction has been written and the summary printed.

    Raises:
        IOError: if ``args.reconstruction_type`` is not one of "TK1L2",
            "TVL2" or "HuberL2".
        ValueError: if ``args.output`` does not end with one of
            ALLOWED_EXTENSIONS, or if ``--target-stack`` does not match any
            of the stacks that were read.
    """

    time_start = ph.start_timing()

    # Set print options for numpy
    np.set_printoptions(precision=3)

    # Read input
    input_parser = InputArgparser(
        description="Volumetric MRI reconstruction framework to reconstruct "
        "an isotropic, high-resolution 3D volume from multiple "
        "motion-corrected (or static) stacks of low-resolution slices.", )
    input_parser.add_filenames(required=True)
    input_parser.add_filenames_masks()
    input_parser.add_dir_input_mc()
    input_parser.add_output(required=True)
    input_parser.add_suffix_mask(default="_mask")
    input_parser.add_target_stack(default=None)
    input_parser.add_extra_frame_target(default=10)
    input_parser.add_isotropic_resolution(default=None)
    input_parser.add_intensity_correction(default=1)
    input_parser.add_reconstruction_space(default=None)
    input_parser.add_minimizer(default="lsmr")
    input_parser.add_iter_max(default=10)
    input_parser.add_reconstruction_type(default="TK1L2")
    input_parser.add_data_loss(default="linear")
    input_parser.add_data_loss_scale(default=1)
    input_parser.add_alpha(
        default=0.01  # TK1L2 @ isotropic_resolution = 0.8
        # default=0.006  #TVL2, HuberL2 @ isotropic_resolution = 0.8
    )
    input_parser.add_rho(default=0.1)
    input_parser.add_tv_solver(default="PD")
    input_parser.add_pd_alg_type(default="ALG2")
    input_parser.add_iterations(default=15)
    input_parser.add_log_config(default=1)
    input_parser.add_use_masks_srr(default=0)
    input_parser.add_slice_thicknesses(default=None)
    input_parser.add_verbose(default=0)
    input_parser.add_viewer(default="itksnap")
    input_parser.add_argument(
        "--mask",
        "-mask",
        action='store_true',
        help="If given, input images are interpreted as image masks. "
        "Obtained volumetric reconstruction will be exported in uint8 format.")
    input_parser.add_argument(
        "--sda",
        "-sda",
        action='store_true',
        help="If given, the volume is reconstructed using "
        "Scattered Data Approximation (Vercauteren et al., 2006). "
        "--alpha is considered the value for the standard deviation then. "
        "Recommended value is, e.g., --alpha 0.8")

    args = input_parser.parse_args()
    input_parser.print_arguments(args)

    if args.reconstruction_type not in ["TK1L2", "TVL2", "HuberL2"]:
        raise IOError("Reconstruction type unknown")

    # Built-in any() instead of np.alltrue, which was removed in NumPy 2.0
    if not any(args.output.endswith(t) for t in ALLOWED_EXTENSIONS):
        raise ValueError("output filename '%s' invalid; "
                         "allowed image extensions are: %s" %
                         (args.output, ", ".join(ALLOWED_EXTENSIONS)))

    # os.path.dirname returns "" for a bare filename; the output then goes
    # to the current working directory and there is nothing to create
    dir_output = os.path.dirname(args.output)
    if dir_output:
        ph.create_directory(dir_output)

    # Developer switch: set to 1 to visualize the stacks next to recon0
    debug = 0

    if args.log_config:
        input_parser.log_config(os.path.abspath(__file__))

    # Paths of written images to be shown at the end; only filled and
    # displayed if args.verbose is set
    show_niftis = []

    # --------------------------------Read Data--------------------------------
    ph.print_title("Read Data")

    # With --mask the input images themselves are the masks
    if args.mask:
        filenames_masks = args.filenames
    else:
        filenames_masks = args.filenames_masks

    data_reader = dr.MultipleImagesReader(
        file_paths=args.filenames,
        file_paths_masks=filenames_masks,
        suffix_mask=args.suffix_mask,
        dir_motion_correction=args.dir_input_mc,
        stacks_slice_thicknesses=args.slice_thicknesses,
    )
    data_reader.read_data()
    stacks = data_reader.get_data()

    ph.print_info("%d input stacks read for further processing" % len(stacks))

    # Specify target stack for intensity correction and reconstruction space
    if args.target_stack is None:
        target_stack_index = 0
    else:
        # TODO: deal with case when target stack got rejected in previous step
        filenames = ["%s.nii.gz" % s.get_filename() for s in stacks]
        filename_target_stack = os.path.basename(args.target_stack)
        try:
            target_stack_index = filenames.index(filename_target_stack)
        except ValueError:
            raise ValueError(
                "--target-stack must correspond to an image as provided by "
                "--filenames")

    # ---------------------------Intensity Correction--------------------------
    if args.intensity_correction and not args.mask:
        ph.print_title("Intensity Correction")
        intensity_corrector = ic.IntensityCorrection()
        intensity_corrector.use_individual_slice_correction(False)
        intensity_corrector.use_stack_mask(True)
        intensity_corrector.use_reference_mask(True)
        intensity_corrector.use_verbose(False)

        for i, stack in enumerate(stacks):
            # The target stack is the reference and is left untouched
            if i == target_stack_index:
                ph.print_info("Stack %d (%s): Reference image. Skipped." %
                              (i + 1, stack.get_filename()))
                continue
            else:
                ph.print_info("Stack %d (%s): Intensity Correction ... " %
                              (i + 1, stack.get_filename()),
                              newline=False)
            intensity_corrector.set_stack(stack)
            # Reference is the target stack resampled onto the grid of the
            # stack that is being corrected
            intensity_corrector.set_reference(
                stacks[target_stack_index].get_resampled_stack(
                    resampling_grid=stack.sitk,
                    interpolator="NearestNeighbor",
                ))
            intensity_corrector.run_linear_intensity_correction()
            stacks[i] = intensity_corrector.get_intensity_corrected_stack()
            print("done (c1 = %g) " %
                  intensity_corrector.get_intensity_correction_coefficients())

    # -------------------------Volumetric Reconstruction-----------------------
    ph.print_title("Volumetric Reconstruction")

    # Reconstruction space defined by isotropically resampled,
    # bounding box-cropped target stack
    if args.reconstruction_space is None:
        recon0 = stacks[target_stack_index].get_isotropically_resampled_stack(
            resolution=args.isotropic_resolution,
            extra_frame=args.extra_frame_target,
        )
        recon0 = recon0.get_cropped_stack_based_on_mask(
            boundary_i=args.extra_frame_target,
            boundary_j=args.extra_frame_target,
            boundary_k=args.extra_frame_target,
            unit="mm",
        )

    # Reconstruction space was provided by user
    else:
        recon0 = st.Stack.from_filename(args.reconstruction_space,
                                        extract_slices=False)

        # Change resolution for isotropic resolution if provided by user
        if args.isotropic_resolution is not None:
            recon0 = recon0.get_isotropically_resampled_stack(
                args.isotropic_resolution)

        # Use image information of selected target stack as recon0 serves
        # as initial value for reconstruction
        recon0 = stacks[target_stack_index].get_resampled_stack(recon0.sitk)
        recon0 = recon0.get_stack_multiplied_with_mask()

    ph.print_info("Reconstruction space defined with %s mm3 resolution" %
                  " x ".join(["%.2f" % s for s in recon0.sitk.GetSpacing()]))

    if debug:
        # visualize (intensity corrected) data alongside recon0 init
        show = [st.Stack.from_stack(s) for s in stacks]
        show.insert(0, recon0)
        sitkh.show_stacks(show)

    if args.sda:
        # args.alpha is used as the standard deviation (sigma) of the SDA
        ph.print_title("Compute SDA reconstruction")
        SDA = sda.ScatteredDataApproximation(stacks,
                                             recon0,
                                             sigma=args.alpha,
                                             sda_mask=args.mask)
        SDA.run()
        recon = SDA.get_reconstruction()
        filename = SDA.get_setting_specific_filename()
        if args.mask:
            dw.DataWriter.write_mask(recon.sitk_mask,
                                     args.output,
                                     description=filename)
        else:
            dw.DataWriter.write_image(recon.sitk,
                                      args.output,
                                      description=filename)

        if args.verbose:
            show_niftis.insert(0, args.output)

    else:
        # TVL2/HuberL2 start from an SDA reconstruction (sigma fixed to 0.8)
        # as initial value; TK1L2 is solved directly by the Tikhonov solver
        if args.reconstruction_type in ["TVL2", "HuberL2"]:
            ph.print_title("Compute Initial value for %s" %
                           args.reconstruction_type)
            SRR0 = sda.ScatteredDataApproximation(stacks, recon0, sigma=0.8)
        else:
            ph.print_title("Compute %s reconstruction" %
                           args.reconstruction_type)
            SRR0 = tk.TikhonovSolver(
                stacks=stacks,
                reconstruction=recon0,
                alpha=args.alpha,
                iter_max=args.iter_max,
                reg_type="TK1",
                minimizer=args.minimizer,
                data_loss=args.data_loss,
                data_loss_scale=args.data_loss_scale,
                use_masks=args.use_masks_srr,
                # verbose=args.verbose,
            )
        SRR0.run()

        recon = SRR0.get_reconstruction()
        filename = SRR0.get_setting_specific_filename()

        # In verbose mode additionally write the initial value to
        # "<output>_init" so that it can be shown next to the final result
        if args.verbose and args.reconstruction_type in ["TVL2", "HuberL2"]:
            output = ph.append_to_filename(args.output, "_init")

            if args.mask:
                mask_estimator = bm.BinaryMaskFromMaskSRREstimator(recon.sitk)
                mask_estimator.run()
                mask_sitk = mask_estimator.get_mask_sitk()
                dw.DataWriter.write_mask(mask_sitk,
                                         output,
                                         description=filename)
            else:
                dw.DataWriter.write_image(recon.sitk,
                                          output,
                                          description=filename)

            show_niftis.insert(0, output)

        if args.reconstruction_type in ["TVL2", "HuberL2"]:
            ph.print_title("Compute %s reconstruction" %
                           args.reconstruction_type)
            # Solver selection: ADMM if requested, primal-dual otherwise
            if args.tv_solver == "ADMM":
                SRR = admm.ADMMSolver(
                    stacks=stacks,
                    reconstruction=st.Stack.from_stack(
                        SRR0.get_reconstruction()),
                    minimizer=args.minimizer,
                    alpha=args.alpha,
                    iter_max=args.iter_max,
                    rho=args.rho,
                    data_loss=args.data_loss,
                    iterations=args.iterations,
                    use_masks=args.use_masks_srr,
                    verbose=args.verbose,
                )

            else:
                SRR = pd.PrimalDualSolver(
                    stacks=stacks,
                    reconstruction=st.Stack.from_stack(
                        SRR0.get_reconstruction()),
                    minimizer=args.minimizer,
                    alpha=args.alpha,
                    iter_max=args.iter_max,
                    iterations=args.iterations,
                    alg_type=args.pd_alg_type,
                    reg_type="TV"
                    if args.reconstruction_type == "TVL2" else "huber",
                    data_loss=args.data_loss,
                    use_masks=args.use_masks_srr,
                    verbose=args.verbose,
                )
            SRR.run()
            recon = SRR.get_reconstruction()
            filename = SRR.get_setting_specific_filename()

        # Write final result: as a mask estimated from the reconstruction
        # if --mask was given, as an image otherwise
        if args.mask:
            mask_estimator = bm.BinaryMaskFromMaskSRREstimator(recon.sitk)
            mask_estimator.run()
            mask_sitk = mask_estimator.get_mask_sitk()
            dw.DataWriter.write_mask(mask_sitk,
                                     args.output,
                                     description=filename)

        else:
            dw.DataWriter.write_image(recon.sitk,
                                      args.output,
                                      description=filename)

        if args.verbose:
            show_niftis.insert(0, args.output)

    if args.verbose:
        ph.show_niftis(show_niftis, viewer=args.viewer)

    ph.print_line_separator()

    elapsed_time = ph.stop_timing(time_start)
    ph.print_title("Summary")
    exe_file_info = os.path.basename(os.path.abspath(__file__)).split(".")[0]
    print("%s | Computational Time for Volumetric Reconstruction: %s" %
          (exe_file_info, elapsed_time))

    return 0