Esempio n. 1
0
def ants_registration_affine_node(**kwargs):
    """Return an antsRegistration interface instance with default values.

    The defaults follow ``antsRegistrationSyN.sh`` with the ``a``
    transformation option: a ``Rigid`` stage followed by an ``Affine``
    stage, each using the ``MI`` metric over a four-level
    multi-resolution schedule.

    :param \\*\\*kwargs: parameters to override the default values
    :return: :py:obj:`Registration` node
    """
    # Every per-stage list holds one entry for the Rigid stage followed by
    # one entry for the Affine stage.
    settings = {
        'dimension': 3,
        'use_histogram_matching': False,
        'interpolation': 'Linear',
        'initial_moving_transform_com': 1,
        'metric': ['MI', 'MI'],
        'metric_weight': [1.0, 1.0],
        'radius_or_number_of_bins': [32, 32],
        'sampling_strategy': ['Regular', 'Regular'],
        'sampling_percentage': [0.25, 0.25],
        'transforms': ['Rigid', 'Affine'],
        'transform_parameters': [(0.1, ), (0.1, )],
        'smoothing_sigmas': [[3, 2, 1, 0], [3, 2, 1, 0]],
        'sigma_units': ['vox', 'vox'],
        'shrink_factors': [[8, 4, 2, 1], [8, 4, 2, 1]],
        'number_of_iterations': [[1000, 500, 250, 100],
                                 [1000, 500, 250, 100]],
        'convergence_threshold': [1e-6, 1e-6],
        'convergence_window_size': [10, 10],
        'winsorize_lower_quantile': 0.005,
        'winsorize_upper_quantile': 0.995,
        'write_composite_transform': True,
        'output_warped_image': True,
    }
    # Caller-supplied keyword arguments win over the defaults above.
    merged = {**settings, **kwargs}
    return Registration(**merged)
Esempio n. 2
0
def ants_registration_syn_no_affine_node(**kwargs):
    """Return an antsRegistration interface instance with default values.

    The defaults follow ``antsRegistrationSyN.sh`` with the ``s``
    transformation option, but with the rigid and affine stages removed:
    only a single ``SyN`` stage using the ``CC`` metric remains.

    :param \\*\\*kwargs: parameters to override the default values
    :return: :py:obj:`Registration` node
    """
    # Each per-stage list has exactly one entry: the SyN stage.
    settings = {
        'dimension': 3,
        'use_histogram_matching': False,
        'interpolation': 'Linear',
        'metric': ['CC'],
        'metric_weight': [1.0],
        'radius_or_number_of_bins': [4],
        'sampling_strategy': [None],
        'sampling_percentage': [None],
        'transforms': ['SyN'],
        'transform_parameters': [(0.1, 3, 0)],
        'smoothing_sigmas': [[3, 2, 1, 0]],
        'sigma_units': ['vox'],
        'shrink_factors': [[8, 4, 2, 1]],
        'number_of_iterations': [[100, 70, 50, 20]],
        'convergence_threshold': [1e-6],
        'convergence_window_size': [10],
        'winsorize_lower_quantile': 0.005,
        'winsorize_upper_quantile': 0.995,
        'write_composite_transform': True,
        'output_warped_image': True,
    }
    # Caller-supplied keyword arguments win over the defaults above.
    merged = {**settings, **kwargs}
    return Registration(**merged)
Esempio n. 3
0
    def _config_ants(self, ants_settings):
        """Create ``self.norm``, an ANTs ``Registration`` node loaded from a settings file.

        Sources of the images:

        * The moving image (and optional moving mask) come from ``self.inputs``.
        * When ``self.inputs.reference_image`` is defined, the fixed image
          (and optional mask) come from ``reference_image`` / ``reference_mask``.
        * Otherwise they come from the template directory returned by
          ``getters.get_<template>()``, at ``template_resolution``. The
          resolution is forced to 2 when ``testing`` is set.

        When ``explicit_masking`` is set, an image is replaced by the result
        of ``mask(...)`` applied with its mask, instead of passing the mask
        to ANTs.

        :param ants_settings: settings file passed to ``Registration`` as
            ``from_file``
        :raises NotImplementedError: when falling back to the template and
            ``orientation`` is ``'LAS'``
        """
        NIWORKFLOWS_LOG.info('Loading settings from file %s.', ants_settings)
        self.norm = Registration(
            moving_image=self.inputs.moving_image,
            num_threads=self.inputs.num_threads,
            from_file=ants_settings,
            terminal_output='file',
            write_composite_transform=True
        )
        if isdefined(self.inputs.moving_mask):
            if self.inputs.explicit_masking:
                # NOTE(review): moving_image is indexed with [0], so it appears
                # to be a list of images -- confirm against the input spec.
                self.norm.inputs.moving_image = mask(
                    self.inputs.moving_image[0],
                    self.inputs.moving_mask,
                    "moving_masked.nii.gz")
            else:
                self.norm.inputs.moving_image_mask = self.inputs.moving_mask

        if isdefined(self.inputs.reference_image):
            self.norm.inputs.fixed_image = self.inputs.reference_image
            if isdefined(self.inputs.reference_mask):
                if self.inputs.explicit_masking:
                    self.norm.inputs.fixed_image = mask(
                        self.inputs.reference_image[0],
                        self.inputs.reference_mask,
                        "fixed_masked.nii.gz")
                else:
                    self.norm.inputs.fixed_image_mask = self.inputs.reference_mask
        else:
            get_template = getattr(getters, 'get_{}'.format(self.inputs.template))
            mni_template = get_template()

            if self.inputs.orientation == 'LAS':
                raise NotImplementedError

            resolution = self.inputs.template_resolution
            if self.inputs.testing:
                # presumably a coarser template to keep test runs cheap
                resolution = 2

            # Both branches below use the same template image and brain mask.
            template_image = op.join(
                mni_template, '%dmm_%s.nii.gz' % (resolution, self.inputs.reference))
            template_mask = op.join(
                mni_template, '%dmm_brainmask.nii.gz' % resolution)

            if self.inputs.explicit_masking:
                self.norm.inputs.fixed_image = mask(
                    template_image, template_mask, "fixed_masked.nii.gz")
            else:
                self.norm.inputs.fixed_image = template_image
                self.norm.inputs.fixed_image_mask = template_mask
def test_Registration_outputs():
    """Check the metadata of the ``Registration`` output traits.

    Each check is a plain ``assert`` instead of a nose-style ``yield``
    tuple, because pytest removed support for yield tests in 4.0.
    """
    # NOTE: every entry currently has empty metadata, so the inner loop
    # performs no checks for these outputs.
    output_map = dict(
        composite_transform=dict(),
        forward_invert_flags=dict(),
        forward_transforms=dict(),
        inverse_composite_transform=dict(),
        inverse_warped_image=dict(),
        reverse_invert_flags=dict(),
        reverse_transforms=dict(),
        warped_image=dict(),
    )
    output_traits = Registration.output_spec().traits()

    for key, metadata in output_map.items():
        for metakey, value in metadata.items():
            assert getattr(output_traits[key], metakey) == value
def test_Registration_outputs():
    """Check the metadata of the ``Registration`` output traits.

    Each check is a plain ``assert`` instead of a nose-style ``yield``
    tuple, because pytest removed support for yield tests in 4.0.
    """
    # NOTE: every entry currently has empty metadata, so the inner loop
    # performs no checks for these outputs.
    output_map = dict(
        composite_transform=dict(),
        forward_invert_flags=dict(),
        forward_transforms=dict(),
        inverse_composite_transform=dict(),
        inverse_warped_image=dict(),
        reverse_invert_flags=dict(),
        reverse_transforms=dict(),
        warped_image=dict(),
    )
    output_traits = Registration.output_spec().traits()

    for key, metadata in output_map.items():
        for metakey, value in metadata.items():
            assert getattr(output_traits[key], metakey) == value
def test_Registration_inputs():
    """Check the metadata (argstr, usedefault, requires, ...) of every
    ``Registration`` input trait.

    Each check is a plain ``assert`` instead of a nose-style ``yield``
    tuple, because pytest removed support for yield tests in 4.0.
    """
    input_map = dict(
        args=dict(argstr='%s'),
        collapse_linear_transforms_to_fixed_image_header=dict(
            argstr='%s',
            usedefault=True,
        ),
        collapse_output_transforms=dict(
            argstr='--collapse-output-transforms %d',
            usedefault=True,
        ),
        convergence_threshold=dict(
            requires=['number_of_iterations'],
            usedefault=True,
        ),
        convergence_window_size=dict(
            requires=['convergence_threshold'],
            usedefault=True,
        ),
        dimension=dict(
            argstr='--dimensionality %d',
            usedefault=True,
        ),
        environ=dict(
            nohash=True,
            usedefault=True,
        ),
        fixed_image=dict(mandatory=True),
        fixed_image_mask=dict(argstr='%s'),
        ignore_exception=dict(
            nohash=True,
            usedefault=True,
        ),
        initial_moving_transform=dict(
            argstr='%s',
            xor=['initial_moving_transform_com'],
        ),
        initial_moving_transform_com=dict(
            argstr='%s',
            xor=['initial_moving_transform'],
        ),
        interpolation=dict(
            argstr='%s',
            usedefault=True,
        ),
        invert_initial_moving_transform=dict(
            requires=['initial_moving_transform'],
            xor=['initial_moving_transform_com'],
        ),
        metric=dict(mandatory=True),
        metric_item_trait=dict(),
        metric_stage_trait=dict(),
        metric_weight=dict(
            mandatory=True,
            requires=['metric'],
            usedefault=True,
        ),
        metric_weight_item_trait=dict(),
        metric_weight_stage_trait=dict(),
        moving_image=dict(mandatory=True),
        moving_image_mask=dict(requires=['fixed_image_mask']),
        num_threads=dict(
            nohash=True,
            usedefault=True,
        ),
        number_of_iterations=dict(),
        output_inverse_warped_image=dict(
            hash_files=False,
            requires=['output_warped_image'],
        ),
        output_transform_prefix=dict(
            argstr='%s',
            usedefault=True,
        ),
        output_warped_image=dict(hash_files=False),
        radius_bins_item_trait=dict(),
        radius_bins_stage_trait=dict(),
        radius_or_number_of_bins=dict(
            requires=['metric_weight'],
            usedefault=True,
        ),
        sampling_percentage=dict(requires=['sampling_strategy']),
        sampling_percentage_item_trait=dict(),
        sampling_percentage_stage_trait=dict(),
        sampling_strategy=dict(requires=['metric_weight']),
        sampling_strategy_item_trait=dict(),
        sampling_strategy_stage_trait=dict(),
        shrink_factors=dict(mandatory=True),
        sigma_units=dict(requires=['smoothing_sigmas']),
        smoothing_sigmas=dict(mandatory=True),
        terminal_output=dict(
            mandatory=True,
            nohash=True,
        ),
        transform_parameters=dict(),
        transforms=dict(
            argstr='%s',
            mandatory=True,
        ),
        use_estimate_learning_rate_once=dict(),
        use_histogram_matching=dict(usedefault=True),
        winsorize_lower_quantile=dict(
            argstr='%s',
            usedefault=True,
        ),
        winsorize_upper_quantile=dict(
            argstr='%s',
            usedefault=True,
        ),
        write_composite_transform=dict(
            argstr='--write-composite-transform %d',
            usedefault=True,
        ),
    )
    input_traits = Registration.input_spec().traits()

    for key, metadata in input_map.items():
        for metakey, value in metadata.items():
            assert getattr(input_traits[key], metakey) == value
    def fit(
        dwdata,
        *,
        n_iter=1,
        align_kwargs=None,
        model="b0",
        omp_nthreads=None,
        seed=None,
        **kwargs,
    ):
        r"""
        Estimate head-motion and Eddy currents.

        Parameters
        ----------
        dwdata : :obj:`~eddymotion.dmri.DWI`
            The target DWI dataset, represented by this tool's internal
            type. The object is used in-place, and will contain the estimated
            parameters in its ``em_affines`` property, as well as the rotated
            *b*-vectors within its ``gradients`` property.
        n_iter : :obj:`int`
            Number of iterations this particular model is going to be repeated.
            Pass ``k`` (1-based) loads the registration settings from
            ``config/dwi-to-<target>_level<k>.json`` in the ``eddymotion``
            package.
        align_kwargs : :obj:`dict`
            Parameters to configure the image registration process. The
            dictionary is copied; the caller's object is never modified.
        model : :obj:`str`
            Selects the diffusion model that will generate the registration target
            corresponding to each gradient map.
            See :obj:`~eddymotion.model.ModelFactory` for allowed models (and corresponding
            keywords). For ``"b0"``, ``"s0"``, ``"avg"``, ``"average"`` and
            ``"mean"`` (case-insensitive), the ``b0`` registration settings are
            used; for any other model, the ``dwi`` settings are used.
        omp_nthreads : :obj:`int`
            Forwarded to :obj:`~eddymotion.model.ModelFactory` when each model
            is initialized.
        seed : :obj:`int` or :obj:`bool`
            Seed the random number generator (necessary when we want deterministic
            estimation). ``True`` uses the fixed seed 20210324, and an integer
            (including 0) is used as the seed. ``None`` or ``False`` leaves the
            generator unseeded.
        **kwargs
            Extra keyword arguments forwarded to the model factory. If
            ``n_threads`` is given, it is also used as ``num_threads`` for the
            registration.

        Returns
        -------
        affines : :obj:`list` of :obj:`numpy.ndarray`
            A list of :math:`4 \times 4` affine matrices encoding the estimated
            parameters of the deformations caused by head-motion and eddy-currents.

        """
        import os

        # Copy so that injecting ``num_threads`` below never mutates the
        # dictionary the caller passed in.
        align_kwargs = dict(align_kwargs or {})
        reg_target_type = ("dwi" if model.lower() not in ("b0", "s0", "avg",
                                                          "average",
                                                          "mean") else "b0")

        # Test identities explicitly: ``False == 0`` holds in Python, so an
        # equality-based test would wrongly seed the generator for seed=False.
        if seed is True:
            np.random.seed(20210324)
        elif seed is not None and seed is not False:
            np.random.seed(seed)

        bmask_img = None
        try:
            if dwdata.brainmask is not None:
                bmask_fd, bmask_img = mkstemp(suffix="_bmask.nii.gz")
                # Only the path is needed (nibabel writes by file name); close
                # the descriptor mkstemp opened so it does not leak.
                os.close(bmask_fd)
                nb.Nifti1Image(dwdata.brainmask.astype("uint8"), dwdata.affine,
                               None).to_filename(bmask_img)
                kwargs["mask"] = dwdata.brainmask

            kwargs["S0"] = _advanced_clip(dwdata.bzero)

            if "n_threads" in kwargs:
                align_kwargs["num_threads"] = kwargs["n_threads"]

            for i_iter in range(1, n_iter + 1):
                index_order = np.arange(len(dwdata))
                np.random.shuffle(index_order)
                with tqdm(total=len(index_order), unit="dwi") as pbar:
                    for i in index_order:
                        # run a original-to-synthetic affine registration
                        with TemporaryDirectory() as tmpdir:
                            pbar.write(
                                f"Pass {i_iter}/{n_iter} | Processing b-index <{i}> in <{tmpdir}>"
                            )
                            data_train, data_test = dwdata.logo_split(i,
                                                                      with_b0=True)

                            # Factory creates the appropriate model and pipes arguments
                            dwmodel = ModelFactory.init(gtab=data_train[1],
                                                        model=model,
                                                        omp_nthreads=omp_nthreads,
                                                        **kwargs)

                            # fit the model
                            dwmodel.fit(data_train[0])

                            # generate a synthetic dw volume for the test gradient
                            predicted = dwmodel.predict(data_test[1])

                            tmpdir = Path(tmpdir)
                            moving = tmpdir / "moving.nii.gz"
                            fixed = tmpdir / "fixed.nii.gz"
                            _to_nifti(data_test[0], dwdata.affine, moving)
                            _to_nifti(
                                predicted,
                                dwdata.affine,
                                fixed,
                                clip=reg_target_type == "dwi",
                            )

                            registration = Registration(
                                terminal_output="file",
                                from_file=pkg_fn(
                                    "eddymotion",
                                    f"config/dwi-to-{reg_target_type}_level{i_iter}.json",
                                ),
                                fixed_image=str(fixed.absolute()),
                                moving_image=str(moving.absolute()),
                                **align_kwargs,
                            )
                            if bmask_img:
                                registration.inputs.fixed_image_masks = [
                                    "NULL", bmask_img
                                ]

                            # Initialize from a previous estimate for this volume, if any.
                            if dwdata.em_affines and dwdata.em_affines[
                                    i] is not None:
                                mat_file = tmpdir / f"init{i_iter}.mat"
                                dwdata.em_affines[i].to_filename(mat_file,
                                                                 fmt="itk")
                                registration.inputs.initial_moving_transform = str(
                                    mat_file)

                            # execute ants command line
                            result = registration.run(cwd=str(tmpdir)).outputs

                            # read output transform
                            xform = nt.io.itk.ITKLinearTransform.from_filename(
                                result.forward_transforms[0]).to_ras(
                                    reference=fixed, moving=moving)

                        # update
                        dwdata.set_transform(i, xform)
                        pbar.update()
        finally:
            # The mask file lives outside the per-volume temporary
            # directories, so remove it explicitly once registration is done.
            if bmask_img is not None and os.path.exists(bmask_img):
                os.remove(bmask_img)

        return dwdata.em_affines
Esempio n. 8
0
def test_Registration_inputs():
    """Check the metadata (argstr, usedefault, requires, ...) of every
    ``Registration`` input trait.

    Each check is a plain ``assert`` instead of a nose-style ``yield``
    tuple, because pytest removed support for yield tests in 4.0.
    """
    input_map = dict(
        args=dict(argstr='%s', ),
        collapse_output_transforms=dict(
            argstr='--collapse-output-transforms %d',
            usedefault=True,
        ),
        convergence_threshold=dict(
            requires=['number_of_iterations'],
            usedefault=True,
        ),
        convergence_window_size=dict(
            requires=['convergence_threshold'],
            usedefault=True,
        ),
        dimension=dict(
            argstr='--dimensionality %d',
            usedefault=True,
        ),
        environ=dict(
            nohash=True,
            usedefault=True,
        ),
        fixed_image=dict(mandatory=True, ),
        fixed_image_mask=dict(argstr='%s', ),
        float=dict(argstr='--float %d', ),
        ignore_exception=dict(
            nohash=True,
            usedefault=True,
        ),
        initial_moving_transform=dict(
            argstr='%s',
            xor=['initial_moving_transform_com'],
        ),
        initial_moving_transform_com=dict(
            argstr='%s',
            xor=['initial_moving_transform'],
        ),
        initialize_transforms_per_stage=dict(
            argstr='--initialize-transforms-per-stage %d',
            usedefault=True,
        ),
        interpolation=dict(
            argstr='%s',
            usedefault=True,
        ),
        invert_initial_moving_transform=dict(
            requires=['initial_moving_transform'],
            xor=['initial_moving_transform_com'],
        ),
        metric=dict(mandatory=True, ),
        metric_item_trait=dict(),
        metric_stage_trait=dict(),
        metric_weight=dict(
            mandatory=True,
            requires=['metric'],
            usedefault=True,
        ),
        metric_weight_item_trait=dict(),
        metric_weight_stage_trait=dict(),
        moving_image=dict(mandatory=True, ),
        moving_image_mask=dict(requires=['fixed_image_mask'], ),
        num_threads=dict(
            nohash=True,
            usedefault=True,
        ),
        number_of_iterations=dict(),
        output_inverse_warped_image=dict(
            hash_files=False,
            requires=['output_warped_image'],
        ),
        output_transform_prefix=dict(
            argstr='%s',
            usedefault=True,
        ),
        output_warped_image=dict(hash_files=False, ),
        radius_bins_item_trait=dict(),
        radius_bins_stage_trait=dict(),
        radius_or_number_of_bins=dict(
            requires=['metric_weight'],
            usedefault=True,
        ),
        restore_state=dict(argstr='--restore-state %s', ),
        sampling_percentage=dict(requires=['sampling_strategy'], ),
        sampling_percentage_item_trait=dict(),
        sampling_percentage_stage_trait=dict(),
        sampling_strategy=dict(requires=['metric_weight'], ),
        sampling_strategy_item_trait=dict(),
        sampling_strategy_stage_trait=dict(),
        save_state=dict(argstr='--save-state %s', ),
        shrink_factors=dict(mandatory=True, ),
        sigma_units=dict(requires=['smoothing_sigmas'], ),
        smoothing_sigmas=dict(mandatory=True, ),
        terminal_output=dict(nohash=True, ),
        transform_parameters=dict(),
        transforms=dict(
            argstr='%s',
            mandatory=True,
        ),
        use_estimate_learning_rate_once=dict(),
        use_histogram_matching=dict(usedefault=True, ),
        winsorize_lower_quantile=dict(
            argstr='%s',
            usedefault=True,
        ),
        winsorize_upper_quantile=dict(
            argstr='%s',
            usedefault=True,
        ),
        write_composite_transform=dict(
            argstr='--write-composite-transform %d',
            usedefault=True,
        ),
    )
    input_traits = Registration.input_spec().traits()

    for key, metadata in input_map.items():
        for metakey, value in metadata.items():
            assert getattr(input_traits[key], metakey) == value