def lin_xfm_to_elastix(xfm, elastix_par):
    """Convert MINC style xfm into elastix style registration parameters
    Assuming that xfm file is strictly linear

    Args:
        xfm - name of input MINC transformation file
        elastix_par - name of the elastix parameter file to write

    Returns:
        elastix_par

    Raises:
        mincError when the converted transform has no 'Parameters: ' or
        'FixedParameters: ' line
    """
    with minc_tools.mincTools() as minc:
        itk_txt = minc.tmp('input.txt')
        minc.command(['itk_convert_xfm', xfm, itk_txt],
                     inputs=[xfm],
                     outputs=[itk_txt])

        # parsing text transformation
        param = None
        fix_param = None

        with open(itk_txt, 'r') as f:
            for ln in f:
                # split() (no argument) also strips the trailing newline, so
                # it does not leak into the parameter list written below
                if ln.startswith('Parameters: '):
                    param = ln.split()[1:]
                elif ln.startswith('FixedParameters: '):
                    fix_param = ln.split()[1:]

        if param is None or fix_param is None:
            raise minc_tools.mincError(
                "lin_xfm_to_elastix: can't find Parameters/FixedParameters "
                "in converted transform {}".format(xfm))

        # write to the file requested by the caller, not to a temporary file
        # that disappears together with the temp directory
        with open(elastix_par, 'w') as f:
            f.write('''(Transform "AffineTransform")
(NumberOfParameters 12)
(TransformParameters {})
(InitialTransformParametersFileName "NoInitialTransform")
(HowToCombineTransforms "Compose")

// EulerTransform specific
(CenterOfRotationPoint {})
'''.format(' '.join(param), ' '.join(fix_param)))

    return elastix_par
def nl_elastix_to_xfm(elastix_par, xfm, downsample_grid=None, nl=True):
    """Convert elastix transformation file into minc XFM file

    Runs transformix on the given parameter file.

    Args:
        elastix_par - name of elastix transformation parameter file
        xfm - name of output XFM file
        downsample_grid - value passed after '-sub' (optional), only used
            when nl is True
        nl - when True, '-def all' is added to the command line

    Returns:
        xfm
    """
    with minc_tools.mincTools() as minc:
        # thread count comes from the ITK environment variable, default 1
        n_threads = os.environ.get('ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS', 1)

        transformix_cmd = ['transformix', '-tp', elastix_par]
        transformix_cmd += ['-out', minc.tempdir]
        transformix_cmd += ['-xfm', xfm, '-q']
        transformix_cmd += ['-threads', str(n_threads)]

        if nl:
            transformix_cmd += ['-def', 'all']
            if downsample_grid is not None:
                transformix_cmd += ['-sub', str(downsample_grid)]

        minc.command(transformix_cmd, inputs=[elastix_par], outputs=[xfm])
        return xfm
def nl_xfm_to_elastix(xfm, elastix_par):
    """Convert MINC style xfm into elastix style registration parameters
    Assuming that xfm file is strictly non-linear, with a single non-linear deformation field

    The XFM file itself is not parsed: the deformation grid is expected to be
    named after the xfm file, with '.xfm' replaced by '_grid_0.mnc'.

    Args:
        xfm - name of input MINC transformation file
        elastix_par - name of the elastix parameter file to write

    Returns:
        elastix_par

    Raises:
        mincError when the expected grid file does not exist
    """
    # TODO: make a proper parsing of XFM file
    grid = xfm.rsplit('.xfm', 1)[0] + '_grid_0.mnc'
    if not os.path.exists(grid):
        # name the missing file; there is no active exception here, so a
        # formatted traceback would add nothing useful
        raise minc_tools.mincError(
            "Unfortunately currently only a very primitive way of dealing with Minc XFM files is implemented\n"
            "nl_xfm_to_elastix: expected deformation grid {} does not exist"
            .format(grid))

    with open(elastix_par, 'w') as f:
        f.write("(Transform \"DeformationFieldTransform\")\n"
                "(DeformationFieldInterpolationOrder 0)\n"
                "(DeformationFieldFileName \"{}\")\n".format(grid))
    return elastix_par
# Example no. 4
# 0
def linear_register_to_self(source,
                            target,
                            output_xfm,
                            parameters=None,
                            mask=None,
                            target_talxfm=None,
                            init_xfm=None,
                            model=None,
                            modeldir=None,
                            close=False,
                            nocrop=False,
                            noautothreshold=False):
    """perform linear registration, wrapper around mritoself

    Args:
        source - name of source file
        target - name of target file
        output_xfm - name of output transformation file
        parameters - extra argument appended as-is (optional)
        mask - passed as '-mask' (optional)
        target_talxfm - passed as '-target_talxfm' (optional)
        init_xfm - passed as '-transform' (optional)
        model - passed as '-model' (optional)
        modeldir - passed as '-modeldir' (optional)
        close - adds '-close' when True
        nocrop - adds '-nocrop' when True
        noautothreshold - adds '-noautothreshold' and '-nothreshold' when True
    """

    # TODO convert mritoself to python (?)
    with minc_tools.mincTools() as minc:
        mritoself_cmd = ['mritoself', source, target, output_xfm]
        if parameters is not None:
            mritoself_cmd.append(parameters)

        # options that take a value, in command-line order
        valued_options = (
            ('-mask', mask),
            ('-target_talxfm', target_talxfm),
            ('-transform', init_xfm),
            ('-model', model),
            ('-modeldir', modeldir),
        )
        for flag, value in valued_options:
            if value is not None:
                mritoself_cmd.extend([flag, value])

        # boolean switches, in command-line order
        switches = (
            (close, ['-close']),
            (nocrop, ['-nocrop']),
            (noautothreshold, ['-noautothreshold', '-nothreshold']),
        )
        for enabled, flags in switches:
            if enabled:
                mritoself_cmd.extend(flags)

        minc.command(mritoself_cmd,
                     inputs=[source, target],
                     outputs=[output_xfm])
# Example no. 5
# 0
def _default_nonlinear_parameters():
    """Return the built-in minctracc settings used when the caller gives none."""
    # (step, blur_fwhm, iterations) for every level, coarse to fine
    schedule = [
        (32.0, 16.0, 20),
        (16.0, 8.0, 20),
        (12.0, 6.0, 20),
        (8.0, 4.0, 20),
        (6.0, 3.0, 20),
        (4.0, 2.0, 10),
        (2.0, 1.0, 10),
        (1.0, 1.0, 10),
        (1.0, 0.5, 10),
        (0.5, 0.25, 10),
    ]
    return {
        'cost': 'corrcoeff',
        'weight': 1,
        'stiffness': 1,
        'similarity': 0.3,
        'sub_lattice': 6,
        'conf': [{
            'step': step,
            'blur_fwhm': fwhm,
            'iterations': iterations,
            'blur': 'blur',
        } for (step, fwhm, iterations) in schedule],
    }


def non_linear_register_full(source,
                             target,
                             output_xfm,
                             source_mask=None,
                             target_mask=None,
                             init_xfm=None,
                             level=4,
                             start=32,
                             parameters=None,
                             work_dir=None,
                             downsample=None):
    """perform non-linear registration, multiple levels
    Args:
        source - name of source minc file, or list of names (multiple channels)
        target - name of target minc file, or list of names (multiple channels)
        output_xfm - name of output transformation file
        source_mask - name of source mask file (optional)
        target_mask - name of target mask file (optional)
        init_xfm - name of initial transformation file (optional)
        parameters - configuration for iterative algorithm dict (optional)
        work_dir - working directory (optional) , default create one in temp
        start - initial step size, default 32mm 
        level - final step size, default 4mm
        downsample - downsample initial files to this step size, default None

    Returns:
        resulting XFM file, or None when minc.checkfiles() returns a false
        value for the given inputs/outputs

    Raises:
        mincError when tool fails, when the number of sources and targets
        differ, or when no level of the schedule was executed
    """
    # accept both a single file name and a list of file names
    sources = list(source) if isinstance(source, list) else [source]
    targets = list(target) if isinstance(target, list) else [target]

    with minc_tools.mincTools() as minc:

        # pass a flat list of file names, also in multi-channel mode
        if not minc.checkfiles(inputs=sources + targets,
                               outputs=[output_xfm]):
            return

        if parameters is None:
            #print("Using default parameters")
            parameters = _default_nonlinear_parameters()

        prev_xfm = None
        prev_grid = None

        if len(sources) != len(targets):
            raise minc_tools.mincError(
                ' ** Error: Different number of inputs ')

        s_base = os.path.basename(sources[0]).rsplit('.gz',
                                                     1)[0].rsplit('.mnc', 1)[0]
        t_base = os.path.basename(targets[0]).rsplit('.gz',
                                                     1)[0].rsplit('.mnc', 1)[0]

        # figure out what to do here:
        with minc_tools.cache_files(work_dir=work_dir, context='reg') as tmp:
            # a fitting we shall go...
            (sources_lr, targets_lr, source_mask_lr,
             target_mask_lr) = minc.downsample_registration_files(
                 sources, targets, source_mask, target_mask, downsample)

            for (i, c) in enumerate(parameters['conf']):

                # skip levels coarser than `start`, stop at the first level
                # finer than `level`
                if c['step'] > start:
                    continue
                elif c['step'] < level:
                    break

                # set up intermediate files
                tmp_ = tmp.tmp(s_base + '_' + t_base + '_' + str(i))

                tmp_xfm = tmp_ + '.xfm'
                tmp_grid = tmp_ + '_grid_0.mnc'

                tmp_sources = sources_lr
                tmp_targets = targets_lr

                if c['blur_fwhm'] > 0:
                    tmp_sources = []
                    tmp_targets = []

                    for s_, _ in enumerate(sources_lr):
                        # blurred files live in the cache and are reused when
                        # they already exist
                        tmp_source = tmp.cache(s_base + '_' + c['blur'] + '_' +
                                               str(c['blur_fwhm']) + '_' +
                                               str(s_) + '.mnc')
                        if not os.path.exists(tmp_source):
                            minc.blur(sources_lr[s_],
                                      tmp_source,
                                      gmag=(c['blur'] == 'dxyz'),
                                      fwhm=c['blur_fwhm'])
                        tmp_target = tmp.cache(t_base + '_' + c['blur'] + '_' +
                                               str(c['blur_fwhm']) + '_' +
                                               str(s_) + '.mnc')
                        if not os.path.exists(tmp_target):
                            minc.blur(targets_lr[s_],
                                      tmp_target,
                                      gmag=(c['blur'] == 'dxyz'),
                                      fwhm=c['blur_fwhm'])
                        tmp_sources.append(tmp_source)
                        tmp_targets.append(tmp_target)

                # set up registration
                args = [
                    'minctracc',
                    tmp_sources[0],
                    tmp_targets[0],
                    '-clobber',
                    '-nonlinear',
                    parameters['cost'],
                    '-weight',
                    parameters['weight'],
                    '-stiffness',
                    parameters['stiffness'],
                    '-similarity',
                    parameters['similarity'],
                    '-sub_lattice',
                    parameters['sub_lattice'],
                ]

                args.extend(['-iterations', c['iterations']])
                args.extend([
                    '-lattice_diam', c['step'] * 3.0, c['step'] * 3.0,
                    c['step'] * 3.0
                ])
                args.extend(['-step', c['step'], c['step'], c['step']])

                if c['step'] < 4:  #TODO: check if it's 4*minc_step ?
                    args.append('-no_super')

                # additional channels
                for s_ in range(len(tmp_targets) - 1):
                    args.extend([
                        '-feature_vol', tmp_sources[s_ + 1],
                        tmp_targets[s_ + 1], parameters['cost'], 1.0
                    ])

                # Current transformation at this step
                if prev_xfm is not None:
                    args.extend(['-transformation', prev_xfm])
                elif init_xfm is not None:
                    args.extend(['-transformation', init_xfm])
                else:
                    args.append('-identity')

                # masks (even if the blurred image is masked, it's still preferable
                # to use the mask in minctracc)
                if source_mask is not None:
                    args.extend(['-source_mask', source_mask_lr])
                if target_mask is not None:
                    args.extend(['-model_mask', target_mask_lr])

                # add files and run registration
                args.append(tmp_xfm)

                # every file handed to minctracc is an input; the per-channel
                # loop variables of the blur step are not defined when no
                # blurring was requested
                minc.command([str(ii) for ii in args],
                             inputs=list(tmp_sources) + list(tmp_targets),
                             outputs=[tmp_xfm])

                prev_xfm = tmp_xfm
                prev_grid = tmp_grid

            # done
            if prev_xfm is None:
                raise minc_tools.mincError("No iterations were performed!")

            # STOP-gap measure to save space for now
            # TODO: fix minctracc?
            # TODO: fix mincreshape too!
            minc.calc([prev_grid],
                      'A[0]',
                      tmp.tmp('final_grid_0.mnc'),
                      datatype='-float')
            shutil.move(tmp.tmp('final_grid_0.mnc'), prev_grid)

            # NOTE(review): concatenating with an identity transform is
            # presumably done so that the grid file is written next to
            # output_xfm - confirm against xfmconcat behaviour
            minc.param2xfm(tmp.tmp('identity.xfm'))
            minc.xfmconcat([tmp.tmp('identity.xfm'), prev_xfm], output_xfm)
            return output_xfm
# Example no. 6
# 0
def linear_register(source,
                    target,
                    output_xfm,
                    parameters=None,
                    source_mask=None,
                    target_mask=None,
                    init_xfm=None,
                    objective=None,
                    conf=None,
                    debug=False,
                    close=False,
                    norot=False,
                    noshear=False,
                    noshift=False,
                    noscale=False,
                    work_dir=None,
                    start=None,
                    downsample=None,
                    verbose=0):
    """Perform linear registration, replacement for bestlinreg.pl script
    
    Args:
        source - name of source minc file, or list of names (multiple channels)
        target - name of target minc file, or list of names (multiple channels)
        output_xfm - name of output transformation file
        parameters - registration parameters (optional), can be 
            '-lsq6', '-lsq9', '-lsq12', default '-lsq9'
        source_mask - name of source mask file (optional)
        target_mask - name of target mask file (optional)
        init_xfm - name of initial transformation file (optional)
        objective - name of objective function (optional), could be 
            '-xcorr' (default), '-nmi','-mi'; a list gives one objective
            per channel
        conf - configuration for iterative algorithm (optional) 
               array of dict, or a string describing a flavor
               bestlinreg (default)
               bestlinreg_s
               bestlinreg_s2
               bestlinreg_new - Claude's latest and greatest
               any other string is run as an external program
        debug - debug flag (optional) , default False
        close - closeness flag (optional) , default False
        norot - disable rotation flag (optional) , default False
        noshear - disable shear flag (optional) , default False
        noshift - disable shift flag (optional) , default False
        noscale - disable scale flag (optional) , default False
        work_dir - working directory (optional) , default create one in temp
        start - initial blurring level, default 16mm from configuration
        downsample - downsample initial files to this step size, default None
        verbose  - verbosity level
    Returns:
        resulting XFM file, or None when minc.checkfiles() returns a false
        value for the given inputs/outputs

    Raises:
        mincError when tool fails, when the number of sources and targets
        differ, or when every level of the configuration was skipped
    """
    print("linear_register s:{} s_m:{} t:{} t_m:{} i:{} ".format(
        source, source_mask, target, target_mask, init_xfm))

    # accept both a single file name and a list of file names
    sources = list(source) if isinstance(source, list) else [source]
    targets = list(target) if isinstance(target, list) else [target]

    with minc_tools.mincTools(verbose=verbose) as minc:
        # pass a flat list of file names, also in multi-channel mode
        if not minc.checkfiles(inputs=sources + targets,
                               outputs=[output_xfm]):
            return

        if len(sources) != len(targets):
            raise minc_tools.mincError(
                ' ** Error: Different number of inputs ')

        # python version
        if conf is None:
            conf = linear_registration_config['bestlinreg']  # bestlinreg_new ?
        elif not isinstance(conf, list):  # assume that it is a string
            if conf in linear_registration_config:
                conf = linear_registration_config[conf]

        if parameters is None:
            parameters = '-lsq9'

        if objective is None:
            objective = '-xcorr'

        if not isinstance(conf, list):  # assume that it is a string
            # assume it's external program's name
            # TODO: check if we are given multiple sources/targets?
            #
            with minc_tools.mincTools() as m:
                cmd = [conf, source, target, output_xfm]
                if source_mask is not None:
                    cmd.extend(['-source_mask', source_mask])
                if target_mask is not None:
                    cmd.extend(['-target_mask', target_mask])
                # parameters and objective always have a value at this point
                cmd.append(parameters)
                cmd.append(objective)
                if init_xfm is not None:
                    cmd.extend(['-init_xfm', init_xfm])

                m.command(cmd,
                          inputs=[source, target],
                          outputs=[output_xfm],
                          verbose=2)
            return output_xfm

        # run internally
        prev_xfm = None

        s_base = os.path.basename(sources[0]).rsplit('.gz', 1)[0].rsplit(
            '.mnc', 1)[0]
        t_base = os.path.basename(targets[0]).rsplit('.gz', 1)[0].rsplit(
            '.mnc', 1)[0]

        # figure out what to do here:
        with minc_tools.cache_files(work_dir=work_dir, context='reg') as tmp:

            (sources_lr, targets_lr, source_mask_lr,
             target_mask_lr) = minc.downsample_registration_files(
                 sources, targets, source_mask, target_mask, downsample)

            # a fitting we shall go...
            for (i, c) in enumerate(conf):
                _parameters = parameters

                if 'parameters' in c and parameters != '-lsq6':  # emulate Claude's approach
                    _parameters = c.get('parameters')  #'-lsq7'

                _reverse = c.get('reverse', False)  # swap target and source

                # NOTE(review): this skips levels whose blur is SMALLER than
                # `start`, which looks inverted relative to the docstring
                # ("initial blurring level") - confirm before changing
                if start is not None and start > c['blur_fwhm']:
                    continue
                elif close and c['blur_fwhm'] > 8:
                    continue

                # set up intermediate files
                tmp_xfm = tmp.tmp(s_base + '_' + t_base + '_' + str(i) +
                                  '.xfm')

                tmp_sources = sources_lr
                tmp_targets = targets_lr

                if c['blur_fwhm'] > 0:
                    tmp_sources = []
                    tmp_targets = []

                    for s_, _ in enumerate(sources_lr):
                        # blurred files live in the cache and are reused when
                        # they already exist
                        tmp_source = tmp.cache(s_base + '_' + c['blur'] +
                                               '_' + str(c['blur_fwhm']) +
                                               '_' + str(s_) + '.mnc')
                        if not os.path.exists(tmp_source):
                            minc.blur(sources_lr[s_],
                                      tmp_source,
                                      gmag=(c['blur'] == 'dxyz'),
                                      fwhm=c['blur_fwhm'])

                        tmp_target = tmp.cache(t_base + '_' + c['blur'] +
                                               '_' + str(c['blur_fwhm']) +
                                               '_' + str(s_) + '.mnc')
                        if not os.path.exists(tmp_target):
                            minc.blur(targets_lr[s_],
                                      tmp_target,
                                      gmag=(c['blur'] == 'dxyz'),
                                      fwhm=c['blur_fwhm'])

                        tmp_sources.append(tmp_source)
                        tmp_targets.append(tmp_target)

                objective_ = objective

                if isinstance(objective, list):
                    objective_ = objective[0]

                # in reverse mode the target is registered to the source and
                # the result is inverted afterwards
                if _reverse:
                    first, second = tmp_targets, tmp_sources
                else:
                    first, second = tmp_sources, tmp_targets

                # set up registration
                args = [
                    'minctracc', first[0], second[0], '-clobber',
                    _parameters, objective_, '-simplex', c['simplex'],
                    '-tol', c['tolerance']
                ]

                # additional modalities
                for s_ in range(len(tmp_targets) - 1):
                    if isinstance(objective, list):
                        objective_ = objective[s_ + 1]
                    args.extend([
                        '-feature_vol', first[s_ + 1], second[s_ + 1],
                        objective_.lstrip('-'), 1.0
                    ])

                args.append('-step')
                args.extend(c['steps'])

                # Current transformation at this step
                if prev_xfm is not None:
                    if _reverse:
                        inv_prev_xfm = tmp.tmp(s_base + '_' + t_base + '_' +
                                               str(i) + '_init.xfm')
                        minc.xfminvert(prev_xfm, inv_prev_xfm)
                        args.extend(['-transformation', inv_prev_xfm])
                    else:
                        args.extend(['-transformation', prev_xfm])
                elif init_xfm is not None:
                    # _reverse should not be first?
                    args.extend(['-transformation', init_xfm, '-est_center'])
                elif close:
                    args.append('-identity')
                else:
                    # _reverse should not be first?
                    # Initial transformation will be computed from the Principal axis
                    # transformation (PAT).
                    if c['trans'] is not None and c['trans'][
                            0] != '-est_translations':
                        args.extend(c['trans'])
                    else:
                        # will use manual transformation based on shift of CoM, should be identical to '-est_translations' , but it's not
                        # the first channel is used, so this also works when
                        # source/target are given as lists
                        com_src = minc.stats(sources[0],
                                             ['-com', '-world_only'],
                                             single_value=False)
                        com_trg = minc.stats(targets[0],
                                             ['-com', '-world_only'],
                                             single_value=False)
                        diff = [com_trg[k] - com_src[k] for k in range(3)]
                        xfm = tmp.cache(s_base + '_init.xfm')
                        minc.param2xfm(xfm, translation=diff)
                        args.extend(['-transformation', xfm])

                # masks (even if the blurred image is masked, it's still preferable
                # to use the mask in minctracc)
                if _reverse:
                    if source_mask is not None:
                        args.extend(['-model_mask', source_mask_lr])
                    #disable one mask in this mode
                    #if target_mask is not None:
                    #args.extend(['-source_mask',  target_mask_lr])
                else:
                    if source_mask is not None:
                        args.extend(['-source_mask', source_mask_lr])
                    if target_mask is not None:
                        args.extend(['-model_mask', target_mask_lr])

                if noshear:
                    args.extend(['-w_shear', 0, 0, 0])
                if noscale:
                    args.extend(['-w_scales', 0, 0, 0])
                if noshift:
                    args.extend(['-w_translations', 0, 0, 0])
                if norot:
                    args.extend(['-w_rotations', 0, 0, 0])

                # add files and run registration
                args.append(tmp_xfm)
                # every file handed to minctracc is an input; the per-channel
                # loop variables of the blur step are not defined when no
                # blurring was requested
                minc.command([str(ii) for ii in args],
                             inputs=list(tmp_sources) + list(tmp_targets),
                             outputs=[tmp_xfm])

                if _reverse:
                    inv_tmp_xfm = tmp.tmp(s_base + '_' + t_base + '_' +
                                          str(i) + '_sol.xfm')
                    minc.xfminvert(tmp_xfm, inv_tmp_xfm)
                    prev_xfm = inv_tmp_xfm
                else:
                    prev_xfm = tmp_xfm

            # every level may have been skipped by the start/close filters
            if prev_xfm is None:
                raise minc_tools.mincError("No iterations were performed!")

            shutil.copyfile(prev_xfm, output_xfm)
            return output_xfm
def register_elastix(source,
                     target,
                     output_par=None,
                     output_xfm=None,
                     source_mask=None,
                     target_mask=None,
                     init_xfm=None,
                     init_par=None,
                     parameters=None,
                     work_dir=None,
                     downsample=None,
                     downsample_grid=None,
                     nl=True,
                     output_log=None,
                     tags=None,
                     verbose=0,
                     iterations=None):
    """Run elastix with given parameters
    Arguments:
    source -- source image (fixed image in Elastix notation)
    target -- target, or reference image (moving image in Elastix notation)
    
    Keyword arguments:
    output_par -- output transformation in elastix format
    output_xfm -- output transformation in MINC XFM format
    source_mask -- source mask
    target_mask -- target mask
    init_xfm    -- initial transform in XFM format
    init_par    -- initial transform in Elastix format
    parameters  -- parameters for transformation
                   if it is a string starting with @ it's a text file name that contains 
                   parameters in elastix format
                   if it is any other string - it is treated as transformation parameters in elastix format
                   if it is a list - elastix is run once per entry, each run being
                   initialized with the result of the previous one
                   if it is a dictionary:
                   for non-linear mode (nl==True):
                        "optimizer" , "AdaptiveStochasticGradientDescent" (default for nonlinear)
                                      "CMAEvolutionStrategy" (default for linear)
                                      "ConjugateGradient"
                                      "ConjugateGradientFRPR"
                                      "FiniteDifferenceGradientDescent"
                                      "QuasiNewtonLBFGS"
                                      "RegularStepGradientDescent"
                                      "RSGDEachParameterApart"
                                      
                        "transform", "BSplineTransform" (default for nonlinear mode)
                                     "SimilarityTransform" (default for linear)
                                     "AffineTransform"
                                     "AffineDTITransform"
                                     "EulerTransform"
                                     "MultiBSplineTransformWithNormal"
                                     "TranslationTransform"
                                     
                        "metric"   , "AdvancedNormalizedCorrelation"  (default)
                                     "AdvancedMattesMutualInformation"
                                     "NormalizedMutualInformation"
                                     "AdvancedKappaStatistic"
                                     "KNNGraphAlphaMutualInformation"
                                     
                        "resolutions", 3   - number of resolution steps
                        "pyramid","8 8 8 4 4 4 2 2 2" - downsampling schedule
                        "iterations",4000 - number of iterations
                        "samples",4096  - number of samples
                        "sampler", "Random" (default)
                                   "Full"
                                   "RandomCoordinate"
                                   "Grid"  TODO: add SampleGridSpacing
                                   "RandomSparseMask"
                                   
                        "grid_spacing",10  - grid spacing in mm
                        "max_step","1.0" - maximum step (mm)
                        
                   for linear mode (nl==False):
                        "optimizer","CMAEvolutionStrategy" - optimizer
                        "transform","SimilarityTransform"  - transform
                        "metric","AdvancedNormalizedCorrelation" - cost function
                        "resolutions", 3  - number of resolutions
                        "pyramid","8 8 8  4 4 4 2 2 2" - resampling schedule
                        "iterations",4000  - number of iterations
                        "samples",4096 - number of samples
                        "sampler","Random" - sampler
                        "max_step","1.0" - max step
                        "automatic_transform_init",True - perform automatic transform initialization
                        "automatic_transform_init_method", - type of automatic transform initialization method, 
                                                          "CenterOfGravity" (default)
                                                          "GeometricalCenter" - center of the image based
                   in both modes a dictionary may also contain:
                        "use_mask", True - pass the masks to elastix
                        "measure", False - only evaluate the metric (elastix -M, one iteration)
    work_dir    -- Work directory
    downsample  -- Downsample input images
    downsample_grid -- Downsample output nl-deformation
    nl          -- flag to show that non-linear version is running
    output_log  -- output log
    tags        -- tag file handed to tag2elx; the resulting point files are
                   passed to elastix with -fp (and -mp when tag2elx reports
                   more than one volume)
    verbose     -- elastix is run with -q when verbose < 1
    iterations  -- accepted for interface compatibility, not used here; several
                   elastix runs are performed only when parameters is a list

    Returns:
    None when minc.checkfiles reports there is nothing to do, otherwise the
    metric value (float) reported by elastix if the last run was in "measure"
    mode, and None for a regular registration run.

    Raises:
    minc_tools.mincError -- both init_xfm and init_par are given, parameters is
                            an empty list, or elastix did not report the metric
                            in "measure" mode
    """
    with minc_tools.mincTools() as minc:

        def_iterations = 4000

        s_base = os.path.basename(source).rsplit('.gz',
                                                 1)[0].rsplit('.mnc', 1)[0]
        t_base = os.path.basename(target).rsplit('.gz',
                                                 1)[0].rsplit('.mnc', 1)[0]

        source_lr = source
        target_lr = target

        source_mask_lr = source_mask
        target_mask_lr = target_mask
        use_mask = True

        if (init_par is not None) and (init_xfm is not None):
            print("register_elastix: init_xfm={} init_par={}".format(
                repr(init_xfm), repr(init_par)))
            raise minc_tools.mincError("Specify either init_xfm or init_par")

        outputs = []
        if output_par is not None: outputs.append(output_par)
        if output_xfm is not None: outputs.append(output_xfm)

        if len(outputs) > 0 and (not minc.checkfiles(inputs=[source, target],
                                                     outputs=outputs)):
            return

        threads = os.environ.get('ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS', 1)

        if parameters is None:
            parameters = {}

        # one elastix run per entry; a single (non-list) value means one run
        if isinstance(parameters, list):
            par_list = parameters
        else:
            par_list = [parameters]

        if not par_list:
            # nothing would be run and there would be no result to report
            raise minc_tools.mincError(
                "register_elastix: empty list of parameters")

        #print("Running elastix with parameters:{}".format(repr(parameters)))
        # figure out what to do here:
        with minc_tools.cache_files(work_dir=work_dir,
                                    context='elastix') as tmp:

            if init_xfm is not None:
                # NOTE(review): lin_xfm_to_elastix, as defined earlier in this
                # file, has no return statement, so init_par becomes None in
                # linear mode and the initial transform is not passed on -
                # confirm and fix there.
                if nl:
                    init_par = nl_xfm_to_elastix(init_xfm,
                                                 tmp.cache('init.txt'))
                else:
                    init_par = lin_xfm_to_elastix(init_xfm,
                                                  tmp.cache('init.txt'))

            # a fitting we shall go...
            if downsample is not None:
                source_lr = tmp.cache(s_base + '_' + str(downsample) + '.mnc')
                target_lr = tmp.cache(t_base + '_' + str(downsample) + '.mnc')

                minc.resample_smooth(source, source_lr, unistep=downsample)
                minc.resample_smooth(target, target_lr, unistep=downsample)

                if source_mask is not None:
                    source_mask_lr = tmp.cache(s_base + '_mask_' +
                                               str(downsample) + '.mnc')
                    minc.resample_labels(source_mask,
                                         source_mask_lr,
                                         unistep=downsample,
                                         datatype='byte')

                if target_mask is not None:
                    # named after the target: using the source base name here
                    # would make both masks share one cache file
                    target_mask_lr = tmp.cache(t_base + '_mask_' +
                                               str(downsample) + '.mnc')
                    minc.resample_labels(target_mask,
                                         target_mask_lr,
                                         unistep=downsample,
                                         datatype='byte')

            outcome = None
            # stays None until the first run directory is created, so that the
            # finally-clause below never fails on an undefined name
            it_output_dir = None

            try:
                for it, _par in enumerate(par_list):

                    par_file = tmp.cache('parameters_{}.txt'.format(it))
                    measure_mode = False
                    # parameters could be stored in a file

                    if isinstance(_par, dict):
                        use_mask = _par.get('use_mask', True)
                        measure_mode = _par.get('measure', False)
                        it_def_iterations = def_iterations

                        if measure_mode:
                            # work on a copy: the caller's dictionary must not
                            # be modified, and the single-iteration default
                            # must not leak into the following runs
                            _par = dict(_par)
                            _par['iterations'] = 1
                            it_def_iterations = 1

                        if nl:
                            gen_config_nl(_par,
                                          par_file,
                                          def_iterations=it_def_iterations)
                        else:
                            gen_config_lin(_par,
                                           par_file,
                                           def_iterations=it_def_iterations)
                    else:
                        if _par.startswith("@"):
                            # "@file" - use an existing elastix parameter file
                            par_file = _par.split("@", 1)[1]
                        else:
                            # raw elastix parameters given as text
                            with open(par_file, 'w') as p:
                                p.write(_par)

                    it_output_dir = tmp.tempdir + os.sep + str(it)
                    if not os.path.exists(it_output_dir):
                        os.makedirs(it_output_dir)

                    cmd = [
                        'elastix', '-f', source_lr, '-m', target_lr, '-out',
                        it_output_dir + os.sep, '-p', par_file, '-threads',
                        str(threads)
                    ]  # , '-q'

                    if measure_mode:
                        cmd.append('-M')

                    if verbose < 1:
                        cmd.append('-q')

                    inputs = [source_lr, target_lr]

                    if init_par is not None:
                        cmd.extend(['-t0', init_par])
                        inputs.append(init_par)

                    if source_mask is not None and use_mask:
                        cmd.extend(['-fMask', source_mask_lr])
                        inputs.append(source_mask_lr)

                    if target_mask is not None and use_mask:
                        cmd.extend(['-mMask', target_mask_lr])
                        inputs.append(target_mask_lr)

                    if tags is not None:
                        source_tags = tmp.cache(s_base + '_tags.txt')
                        target_tags = tmp.cache(t_base + '_tags.txt')

                        vols = tag2elx(tags, source_tags, target_tags)
                        inputs.append(source_tags)
                        cmd.extend(['-fp', source_tags])
                        # NOTE(review): copies into the current working
                        # directory look like a debugging leftover - confirm
                        # before removing
                        shutil.copyfile(source_tags, "source.tag")

                        if vols > 1:
                            inputs.append(target_tags)
                            cmd.extend(['-mp', target_tags])
                            shutil.copyfile(target_tags, "target.tag")

                    it_par = it_output_dir + os.sep + 'TransformParameters.0.txt'
                    outputs = [it_par]

                    # only the result of the last run is reported
                    outcome = None

                    if measure_mode:
                        # going to read the output of iterations: the value is
                        # on the line following the table header
                        out_ = minc.execute_w_output(cmd).split("\n")
                        for l, j in enumerate(out_):
                            if re.match(r"^1\:ItNr\s2\:Metric\s.*", j):
                                outcome = float(out_[l + 1].split("\t")[1])
                                break
                        else:
                            # header was not found in the output
                            print("Elastix output:\n{}".format(
                                "\n".join(out_)))
                            raise minc_tools.mincError(
                                "Elastix didn't report measure")
                    else:
                        minc.command(cmd,
                                     inputs=inputs,
                                     outputs=outputs,
                                     verbose=verbose)

                    # the next run starts from the result of this one
                    init_par = it_par
                    # end of iterations

                final_par = it_output_dir + os.sep + 'TransformParameters.0.txt'

                if output_par is not None:
                    shutil.copyfile(final_par, output_par)

                if output_xfm is not None:
                    nl_elastix_to_xfm(final_par,
                                      output_xfm,
                                      downsample_grid=downsample_grid,
                                      nl=nl)

            finally:
                # best effort: keep the log of the last run even after a
                # failure, but never let a missing log hide the real error
                if output_log is not None and it_output_dir is not None:
                    log_file = it_output_dir + os.sep + 'elastix.log'
                    if os.path.exists(log_file):
                        shutil.copyfile(log_file, output_log)

        return outcome
def non_linear_register_ldd(source,
                            target,
                            output_velocity,
                            output_xfm=None,
                            source_mask=None,
                            target_mask=None,
                            init_xfm=None,
                            init_velocity=None,
                            level=2,
                            start=32,
                            parameters=None,
                            work_dir=None,
                            downsample=None):
    """Use log-diffeomorphic demons to run registration

    Arguments:
    source -- source image (passed as fixed image, -f)
    target -- target image (passed as moving image, -m)
    output_velocity -- output velocity field

    Keyword arguments:
    output_xfm    -- output deformation field (only used when LCC is off)
    source_mask   -- source (fixed) mask
    target_mask   -- target (moving) mask (only used when LCC is off)
    init_xfm      -- initial transformation (only used when LCC is off)
    init_velocity -- initial velocity field
    level         -- finest resolution step to run, coarser-or-equal steps get
                     iterations, finer ones get 0
    start         -- coarsest resolution step
    parameters    -- dictionary of options:
                     'conf' - {resolution: number of iterations}, default 20
                     'smooth_update', 'smooth_field', 'update_rule',
                     'grad_type', 'max_step', 'hist_match'
                     'LCC' - run rpiLCClogDemons instead of
                             LogDomainDemonsRegistration; in that mode
                             'tradeoff', 'smooth_similarity', 'bending_weight'
                             and 'harmonic_weight' are used too
    work_dir      -- not used by this function
    downsample    -- downsample input images (and masks) to this step size

    Returns:
    None
    """

    with minc_tools.mincTools() as minc:
        if not minc.checkfiles(inputs=[source, target],
                               outputs=[output_velocity]):
            return
        if parameters is None:
            parameters = {
                'conf': {},
                'smooth_update': 2,
                'smooth_field': 2,
                'update_rule': 1,
                'grad_type': 0,
                'max_step': 2.0,
                'hist_match': True,
                'LCC': False
            }

        LCC = parameters.get('LCC', False)
        # tolerate caller dictionaries without a 'conf' entry
        conf = parameters.get('conf', {})

        source_lr = source
        target_lr = target
        source_mask_lr = source_mask
        target_mask_lr = target_mask

        if downsample is not None:
            s_base = os.path.basename(source).rsplit('.gz',
                                                     1)[0].rsplit('.mnc', 1)[0]
            t_base = os.path.basename(target).rsplit('.gz',
                                                     1)[0].rsplit('.mnc', 1)[0]
            source_lr = minc.tmp(s_base + '_' + str(downsample) + '.mnc')
            target_lr = minc.tmp(t_base + '_' + str(downsample) + '.mnc')

            minc.resample_smooth(source, source_lr, unistep=downsample)
            minc.resample_smooth(target, target_lr, unistep=downsample)

            # each mask gets its own file, named after its own image
            if source_mask is not None:
                source_mask_lr = minc.tmp(s_base + '_mask_' + str(downsample) +
                                          '.mnc')
                minc.resample_labels(source_mask,
                                     source_mask_lr,
                                     unistep=downsample,
                                     datatype='byte')
            if target_mask is not None:
                target_mask_lr = minc.tmp(t_base + '_mask_' + str(downsample) +
                                          '.mnc')
                minc.resample_labels(target_mask,
                                     target_mask_lr,
                                     unistep=downsample,
                                     datatype='byte')

        # iteration schedule, coarse to fine, e.g. "20x20x20x20x20x0":
        # one entry per power-of-two step from start down to 1
        top = int(math.log(start) / math.log(2))
        prog = 'x'.join(
            str(conf.get(2**i, 20)) if 2**i >= level else '0'
            for i in range(top, -1, -1))

        # the files actually handed to the command
        inputs = [source_lr, target_lr]
        # defined before the branches so that optional outputs can be added
        outputs = [output_velocity]

        if LCC:
            cmd = [
                'rpiLCClogDemons', '-f', source_lr, '-m', target_lr,
                '--output-transform', output_velocity, '-S',
                str(parameters.get('tradeoff', 0.15)), '-u',
                str(parameters.get('smooth_update', 2)), '-d',
                str(parameters.get('smooth_field', 2)), '-C',
                str(parameters.get('smooth_similarity', 3)), '-b',
                str(parameters.get('bending_weight', 1)), '-x',
                str(parameters.get('harmonic_weight', 0)), '-r',
                str(parameters.get('update_rule', 2)), '-g',
                str(parameters.get('grad_type', 0)), '-l',
                str(parameters.get('max_step', 2.0)), '-a', prog
            ]

            if parameters.get('hist_match', True):
                cmd.append('--use-histogram-matching')

            if source_mask_lr is not None:
                cmd.extend(['--mask-image', source_mask_lr])
                inputs.append(source_mask_lr)

            if init_velocity is not None:
                cmd.extend(['--initial-transform', init_velocity])
                inputs.append(init_velocity)
        else:
            cmd = [
                'LogDomainDemonsRegistration', '-f', source_lr, '-m',
                target_lr, '--outputVel-field', output_velocity, '-g',
                str(parameters.get('smooth_update', 2)), '-s',
                str(parameters.get('smooth_field', 2)), '-a',
                str(parameters.get('update_rule', 1)), '-t',
                str(parameters.get('grad_type', 0)), '-l',
                str(parameters.get('max_step', 2.0)), '-i', prog
            ]

            if parameters.get('hist_match', True):
                cmd.append('--use-histogram-matching')

            if source_mask_lr is not None:
                cmd.extend(['--fixed-mask', source_mask_lr])
                inputs.append(source_mask_lr)

            if target_mask_lr is not None:
                cmd.extend(['--moving-mask', target_mask_lr])
                inputs.append(target_mask_lr)

            if init_velocity is not None:
                cmd.extend(['--input-field', init_velocity])
                inputs.append(init_velocity)

            if init_xfm is not None:
                cmd.extend(['--input-transform', init_xfm])
                inputs.append(init_xfm)

            if output_xfm is not None:
                cmd.extend(['--outputDef-field', output_xfm])
                outputs.append(output_xfm)

        minc.command(cmd, inputs=inputs, outputs=outputs)
def non_linear_register_dd(source,
                           target,
                           output_xfm,
                           source_mask=None,
                           target_mask=None,
                           init_xfm=None,
                           level=4,
                           start=32,
                           parameters=None,
                           work_dir=None,
                           downsample=None):
    """perform incremental non-linear registration with diffeomorphic demons

    Arguments:
    source -- source image (passed as fixed image, -f)
    target -- target image (passed as moving image, -m)
    output_xfm -- output deformation field

    Keyword arguments:
    source_mask -- source (fixed) mask
    target_mask -- target (moving) mask
    init_xfm    -- initial transformation
    level       -- finest resolution step to run, coarser-or-equal steps get
                   iterations, finer ones get 0
    start       -- coarsest resolution step
    parameters  -- dictionary of options:
                   'conf' - {resolution: number of iterations}, default 20
                   'smooth_update', 'smooth_field', 'update_rule',
                   'grad_type', 'max_step', 'hist_match'
    work_dir    -- not used by this function
    downsample  -- downsample input images (and masks) to this step size

    Returns:
    None
    """

    with minc_tools.mincTools() as minc:
        if not minc.checkfiles(inputs=[source, target], outputs=[output_xfm]):
            return

        if parameters is None:
            parameters = {
                'conf': {},
                'smooth_update': 2,
                'smooth_field': 2,
                'update_rule': 0,
                'grad_type': 0,
                'max_step': 2.0,
                'hist_match': True
            }

        # tolerate caller dictionaries without a 'conf' entry
        conf = parameters.get('conf', {})

        source_lr = source
        target_lr = target
        source_mask_lr = source_mask
        target_mask_lr = target_mask

        if downsample is not None:
            s_base = os.path.basename(source).rsplit('.gz',
                                                     1)[0].rsplit('.mnc', 1)[0]
            t_base = os.path.basename(target).rsplit('.gz',
                                                     1)[0].rsplit('.mnc', 1)[0]
            source_lr = minc.tmp(s_base + '_' + str(downsample) + '.mnc')
            target_lr = minc.tmp(t_base + '_' + str(downsample) + '.mnc')

            minc.resample_smooth(source, source_lr, unistep=downsample)
            minc.resample_smooth(target, target_lr, unistep=downsample)

            # each mask gets its own file, named after its own image
            if source_mask is not None:
                source_mask_lr = minc.tmp(s_base + '_mask_' + str(downsample) +
                                          '.mnc')
                minc.resample_labels(source_mask,
                                     source_mask_lr,
                                     unistep=downsample,
                                     datatype='byte')
            if target_mask is not None:
                target_mask_lr = minc.tmp(t_base + '_mask_' + str(downsample) +
                                          '.mnc')
                minc.resample_labels(target_mask,
                                     target_mask_lr,
                                     unistep=downsample,
                                     datatype='byte')

        # iteration schedule, coarse to fine, e.g. "20x20x20x20x0x0":
        # one entry per power-of-two step from start down to 1
        top = int(math.log(start) / math.log(2))
        prog = 'x'.join(
            str(conf.get(2**i, 20)) if 2**i >= level else '0'
            for i in range(top, -1, -1))

        inputs = [source_lr, target_lr]
        cmd = [
            'DemonsRegistration', '-f', source_lr, '-m', target_lr,
            '--outputDef-field', output_xfm, '-g',
            str(parameters.get('smooth_update', 2)), '-s',
            str(parameters.get('smooth_field', 2)), '-a',
            str(parameters.get('update_rule', 0)), '-t',
            str(parameters.get('grad_type', 0)), '-l',
            str(parameters.get('max_step', 2.0)), '-i', prog
        ]

        if parameters.get('hist_match', True):
            cmd.append('--use-histogram-matching')

        if source_mask_lr is not None:
            cmd.extend(['--fixed-mask', source_mask_lr])
            inputs.append(source_mask_lr)

        if target_mask_lr is not None:
            cmd.extend(['--moving-mask', target_mask_lr])
            inputs.append(target_mask_lr)

        if init_xfm is not None:
            cmd.extend(['--input-transform', init_xfm])
            inputs.append(init_xfm)

        outputs = [output_xfm]

        minc.command(cmd, inputs=inputs, outputs=outputs)
# Example no. 10
# 0
def ants_linear_register(source,
                         target,
                         output_xfm,
                         parameters=None,
                         source_mask=None,
                         target_mask=None,
                         init_xfm=None,
                         objective=None,
                         conf=None,
                         debug=False,
                         close=False,
                         work_dir=None,
                         downsample=None,
                         verbose=0):
    """perform linear registration with ANTs

    Arguments:
    source -- source image
    target -- target image
    output_xfm -- output transformation (passed to ANTS as -o)

    Keyword arguments:
    parameters  -- dictionary of options:
                   'affine-iterations' - default '10000x10000x10000x10000x10000'
                   'gradient_descent_option' - default depends on close
                   'mi-option' - default '32x16000'
                   'use_mask' - pass the target mask to ANTS, default True
                   'use_histogram_matching' - default False
                   'metric_type' - affine metric, default 'MI'
                   'rigid' - restrict to rigid transformation, default False
                   'winsorize-image-intensities' - True, or a dictionary
                        with 'low' and 'high' entries (defaults 5 and 95)
    source_mask -- not used: only one mask (-x) is given to ANTS
    target_mask -- mask passed to ANTS with -x
    init_xfm    -- initial affine transformation
    objective, conf, debug -- not used by this function
    close       -- use a smaller default gradient descent step
    work_dir    -- work directory for cached intermediate files
    downsample  -- downsample input images (and target mask) to this step size
    verbose     -- verbosity passed to mincTools

    Returns:
    None
    """

    # TODO: make use of parameters

    if parameters is None:
        parameters = {}

    with minc_tools.mincTools(verbose=verbose) as minc:
        if not minc.checkfiles(inputs=[source, target], outputs=[output_xfm]):
            return

        s_base = os.path.basename(source).rsplit('.gz',
                                                 1)[0].rsplit('.mnc', 1)[0]
        t_base = os.path.basename(target).rsplit('.gz',
                                                 1)[0].rsplit('.mnc', 1)[0]

        source_lr = source
        target_lr = target
        target_mask_lr = target_mask

        # figure out what to do here:
        with minc_tools.cache_files(work_dir=work_dir, context='reg') as tmp:

            # downsampling is done once; the source mask is not resampled
            # because it is never handed to ANTS
            if downsample is not None:
                source_lr = tmp.cache(s_base + '_' + str(downsample) + '.mnc')
                target_lr = tmp.cache(t_base + '_' + str(downsample) + '.mnc')

                minc.resample_smooth(source, source_lr, unistep=downsample)
                minc.resample_smooth(target, target_lr, unistep=downsample)

                if target_mask is not None:
                    # named after the target image it belongs to
                    target_mask_lr = tmp.cache(t_base + '_mask_' +
                                               str(downsample) + '.mnc')
                    minc.resample_labels(target_mask,
                                         target_mask_lr,
                                         unistep=downsample,
                                         datatype='byte')

            iterations = parameters.get('affine-iterations',
                                        '10000x10000x10000x10000x10000')

            default_gradient_descent_option = '0.5x0.95x1.e-5x1.e-4'
            if close: default_gradient_descent_option = '0.05x0.5x1.e-4x1.e-4'
            gradient_descent_option = parameters.get(
                'gradient_descent_option', default_gradient_descent_option)

            mi_option = parameters.get('mi-option', '32x16000')
            use_mask = parameters.get('use_mask', True)
            use_histogram_matching = parameters.get('use_histogram_matching',
                                                    False)
            affine_metric = parameters.get('metric_type', 'MI')
            affine_rigid = parameters.get('rigid', False)
            # this name was read below without ever being assigned;
            # NOTE(review): the key mirrors the ANTS flag - confirm with callers
            winsorize_intensity = parameters.get('winsorize-image-intensities',
                                                 None)

            cost_function_par = '1,4'

            cmd = ['ANTS', '3']

            cmd.extend([
                '-m', '{}[{},{},{}]'.format('CC', source_lr, target_lr,
                                            cost_function_par)
            ])
            cmd.extend(['-i', '0'])
            cmd.extend(['--number-of-affine-iterations', iterations])
            cmd.extend(
                ['--affine-gradient-descent-option', gradient_descent_option])
            cmd.extend(['--MI-option', mi_option])
            cmd.extend(['--affine-metric-type', affine_metric])

            if affine_rigid:
                cmd.append('--rigid-affine')

            cmd.extend(['-o', output_xfm])

            inputs = [source_lr, target_lr]
            if target_mask_lr is not None and use_mask:
                inputs.append(target_mask_lr)
                cmd.extend(['-x', target_mask_lr])

            if use_histogram_matching:
                cmd.append('--use-Histogram-Matching')

            if winsorize_intensity is not None:
                if isinstance(winsorize_intensity, dict):
                    cmd.extend([
                        '--winsorize-image-intensities',
                        winsorize_intensity.get('low', 5),
                        winsorize_intensity.get('high', 95)
                    ])
                else:
                    cmd.append('--winsorize-image-intensities')

            if init_xfm is not None:
                cmd.extend(['--initial-affine', init_xfm])
                inputs.append(init_xfm)

            outputs = [output_xfm]  # TODO: add inverse xfm ?
            # every argument is converted to text: some values (e.g. the
            # winsorize limits) may be numbers
            minc.command([str(c) for c in cmd],
                         inputs=inputs,
                         outputs=outputs)
# Example no. 11
# 0
def linear_register_ants2(source,
                          target,
                          output_xfm,
                          target_mask=None,
                          source_mask=None,
                          init_xfm=None,
                          parameters=None,
                          downsample=None,
                          close=False,
                          verbose=0):
    """Perform linear registration using ANTs (``antsRegistration``).

    Builds an ``antsRegistration`` command line from ``parameters`` and runs
    it through ``minc.command``.

    Args:
        source: path, or list of paths (one per modality), listed first in
            each ``--metric`` specification.
        target: path, or list of paths; must contain as many entries as
            ``source``.
        output_xfm: output transform file; its name with the trailing
            ``.xfm`` stripped is passed to ``--output``.
        target_mask: optional mask, forwarded to
            ``minc.downsample_registration_files``.
        source_mask: optional mask, forwarded to
            ``minc.downsample_registration_files``. Masks are only put on
            the command line when both masks are available and ``use_mask``
            is true.
        init_xfm: optional transform passed as ``--initial-fixed-transform``.
        parameters: optional dict. Recognised keys: ``conf``, ``blur``,
            ``shrink`` (per-level tables keyed by level number, as int or
            str), ``levels``, ``convergence``, ``cost_function``,
            ``cost_function_par``, ``transformation``, ``use_mask``,
            ``use_histogram_matching``, ``winsorize-image-intensities``,
            ``use_float``, ``initialize_fixed`` and ``initialize_moving``
            (the historical misspelling ``intialize_moving`` is still
            honoured). The dict is not modified.
        downsample: forwarded to ``minc.downsample_registration_files``.
        close: when false and no explicit initialisation is given, initial
            transforms of the form ``[source,target,0]`` are requested.
        verbose: passed to ``mincTools`` and ``minc.command``; a value above
            zero also adds ``--verbose 1``.

    Returns:
        None. Returns early, without running anything, when
        ``minc.checkfiles`` reports a false value.

    Raises:
        mincError: if ``source`` and ``target`` differ in number of entries.
    """
    #TODO:implement close

    with minc_tools.mincTools(verbose=verbose) as minc:

        if parameters is None:
            #TODO add more options here
            parameters = {}

        # Look the per-level tables up without inserting defaults into the
        # dictionary owned by the caller.
        conf = parameters.get('conf', {})
        blur_conf = parameters.get('blur', {})
        shrink_conf = parameters.get('shrink', {})

        levels = parameters.get('levels', 3)

        # One entry per level, coarsest first; keys may be int or str.
        prog_steps = []
        shrink_steps = []
        blur_steps = []
        for i in range(levels, 0, -1):
            _i = str(i)
            prog_steps.append(str(conf.get(i, conf.get(_i, 10000))))
            shrink_steps.append(
                str(shrink_conf.get(i, shrink_conf.get(_i, 2**i))))
            blur_steps.append(str(blur_conf.get(i, blur_conf.get(_i, 2**i))))

        shrink = 'x'.join(shrink_steps)
        blur = 'x'.join(blur_steps)
        # TODO: make it a parameter?
        prog = 'x'.join(prog_steps) + ',' + parameters.get(
            'convergence', '1.e-8,20')

        sources = []
        targets = []

        if isinstance(source, list):
            sources.extend(source)
        else:
            sources.append(source)

        if isinstance(target, list):
            targets.extend(target)
        else:
            targets.append(target)

        if len(sources) != len(targets):
            raise mincError(' ** Error: Different number of inputs ')

        modalities = len(sources)

        if not minc.checkfiles(inputs=sources + targets, outputs=[output_xfm]):
            return

        output_base = output_xfm.rsplit('.xfm', 1)[0]

        cost_function = parameters.get('cost_function', 'Mattes')
        cost_function_par = parameters.get('cost_function_par',
                                           '1,32,regular,0.3')

        transformation = parameters.get('transformation', 'affine[ 0.1 ]')
        use_mask = parameters.get('use_mask', True)
        use_histogram_matching = parameters.get('use_histogram_matching',
                                                False)
        winsorize_intensity = parameters.get('winsorize-image-intensities',
                                             None)
        use_float = parameters.get('use_float', False)
        initialize_fixed = parameters.get('initialize_fixed', None)
        # Prefer the correctly spelled key; fall back to the misspelled one
        # that earlier versions read, so existing configurations keep working.
        initialize_moving = parameters.get(
            'initialize_moving', parameters.get('intialize_moving', None))

        cmd = [
            'antsRegistration', '--collapse-output-transforms', '0', '--minc',
            '-a', '--dimensionality', '3'
        ]

        (sources_lr, targets_lr, source_mask_lr,
         target_mask_lr) = minc.downsample_registration_files(
             sources, targets, source_mask, target_mask, downsample)

        # generate modalities: one --metric per source/target pair, with
        # per-modality cost function settings when lists are supplied
        for _s in range(modalities):
            if isinstance(cost_function, list):
                cost_function_ = cost_function[_s]
            else:
                cost_function_ = cost_function
            #
            if isinstance(cost_function_par, list):
                cost_function_par_ = cost_function_par[_s]
            else:
                cost_function_par_ = cost_function_par
            #
            cmd.extend([
                '--metric',
                '{}[{},{},{}]'.format(cost_function_, sources_lr[_s],
                                      targets_lr[_s], cost_function_par_)
            ])
        #
        #
        cmd.extend(['--convergence', '[{}]'.format(prog)])
        cmd.extend(['--shrink-factors', shrink])
        cmd.extend(['--smoothing-sigmas', blur])
        cmd.extend(['--transform', transformation])
        cmd.extend(['--output', output_base])
        #cmd.extend(['--save-state',output_xfm])

        if init_xfm is not None:
            cmd.extend(['--initial-fixed-transform', init_xfm])
            # this is  a hack in attempt to make initial linear transform to work as expected
            # currently, it looks like the center of the transform (i.e center of rotation) is messed up :(
            # and it causes lots of problems
            cmd.extend(['--initialize-transforms-per-stage', '1'])
        elif initialize_fixed is not None:
            cmd.extend([
                '--initial-fixed-transform',
                "[{},{},{}]".format(sources_lr[0], targets_lr[0],
                                    str(initialize_fixed))
            ])
        elif not close:
            cmd.extend([
                '--initial-fixed-transform',
                "[{},{},{}]".format(sources_lr[0], targets_lr[0], '0')
            ])

        if initialize_moving is not None:
            cmd.extend([
                '--initial-moving-transform',
                "[{},{},{}]".format(sources_lr[0], targets_lr[0],
                                    str(initialize_moving))
            ])
        elif not close:
            cmd.extend([
                '--initial-moving-transform',
                "[{},{},{}]".format(sources_lr[0], targets_lr[0], '0')
            ])
        #
        inputs = sources_lr + targets_lr
        #
        # masks are only used when both of them are available
        if target_mask_lr is not None and source_mask_lr is not None and use_mask:
            inputs.extend([source_mask_lr, target_mask_lr])
            cmd.extend(
                ['-x', '[{},{}]'.format(source_mask_lr, target_mask_lr)])

        if use_histogram_matching:
            cmd.append('--use-histogram-matching')

        if winsorize_intensity is not None:
            if isinstance(winsorize_intensity, dict):
                # Every command-line element has to be a string, as in
                # non_linear_register_ants2.
                # NOTE(review): the limits are passed as two separate
                # arguments - confirm against the antsRegistration help.
                cmd.extend([
                    '--winsorize-image-intensities',
                    str(winsorize_intensity.get('low', 1)),
                    str(winsorize_intensity.get('high', 99))
                ])
            else:
                cmd.append('--winsorize-image-intensities')

        if use_float:
            cmd.append('--float')

        if verbose > 0:
            cmd.extend(['--verbose', '1'])

        outputs = [output_xfm]  # TODO: add inverse xfm ?
        minc.command(cmd, inputs=inputs, outputs=outputs, verbose=verbose)
# Example no. 12
# 0
def non_linear_register_ants2(source,
                              target,
                              output_xfm,
                              target_mask=None,
                              source_mask=None,
                              init_xfm=None,
                              parameters=None,
                              downsample=None,
                              start=None,
                              level=32.0,
                              verbose=0):
    """Perform non-linear registration using ANTs (``antsRegistration``).

    WARNING (from the original author): will create inverted xfm, which will
    be named output_invert.xfm.

    Args:
        source: path, or list of paths (one per modality), listed first in
            each ``--metric`` specification.
        target: path, or list of paths; must contain as many entries as
            ``source``.
        output_xfm: output transform file; its name with the trailing
            ``.xfm`` stripped is passed to ``--output``.
        target_mask: optional mask, forwarded to
            ``minc.downsample_registration_files``.
        source_mask: optional mask, forwarded to
            ``minc.downsample_registration_files``. Masks are only put on
            the command line when both masks are available and ``use_mask``
            is true.
        init_xfm: optional transform passed as ``--initial-fixed-transform``.
        parameters: optional dict. Recognised keys: ``conf``, ``blur``,
            ``shrink`` (per-stage tables keyed by the stage value ``2**i``,
            as int or str), ``convergence``, ``cost_function``,
            ``cost_function_par``, ``transformation``, ``use_mask``,
            ``use_histogram_matching``, ``use_float`` and
            ``winsorize-image-intensities``. The dict is not modified.
        downsample: forwarded to ``minc.downsample_registration_files``.
        start: coarsest stage value; defaults to ``level``.
        level: finest stage value; stages are the powers of two from
            ``start`` downwards that are not smaller than ``level``.
        verbose: passed to ``mincTools``; a value above zero also adds
            ``--verbose 1``.

    Returns:
        None. Returns early, without running anything, when
        ``minc.checkfiles`` reports a false value.

    Raises:
        mincError: if ``source`` and ``target`` differ in number of entries.
    """
    if start is None:
        start = level

    with minc_tools.mincTools(verbose=verbose) as minc:

        sources = []
        targets = []

        if isinstance(source, list):
            sources.extend(source)
        else:
            sources.append(source)

        if isinstance(target, list):
            targets.extend(target)
        else:
            targets.append(target)
        if len(sources) != len(targets):
            raise mincError(' ** Error: Different number of inputs ')

        modalities = len(sources)

        if parameters is None:
            #TODO add more options here
            parameters = {}

        # Look the per-stage tables up without inserting defaults into the
        # dictionary owned by the caller.
        conf = parameters.get('conf', {})
        blur_conf = parameters.get('blur', {})
        shrink_conf = parameters.get('shrink', {})

        # Collect one entry per stage and join afterwards, so that no
        # dangling 'x' separator is left behind when `level` is not itself
        # a power of two.
        prog_steps = []
        shrink_steps = []
        blur_steps = []
        for i in range(int(math.log(start) / math.log(2)), -1, -1):
            res = 2**i
            if res < level:
                continue
            prog_steps.append(str(conf.get(res, conf.get(str(res), 20))))
            shrink_steps.append(
                str(shrink_conf.get(res, shrink_conf.get(str(res), res))))
            blur_steps.append(
                str(blur_conf.get(res, blur_conf.get(str(res), res))))

        prog = 'x'.join(prog_steps)
        shrink = 'x'.join(shrink_steps)
        blur = 'x'.join(blur_steps)

        if not minc.checkfiles(inputs=sources + targets, outputs=[output_xfm]):
            return

        prog += ',' + parameters.get('convergence', '1.e-6,10')

        output_base = output_xfm.rsplit('.xfm', 1)[0]

        cost_function = parameters.get('cost_function', 'CC')
        cost_function_par = parameters.get('cost_function_par',
                                           '1,2,Regular,1.0')

        transformation = parameters.get('transformation', 'SyN[ .25, 2, 0.5 ]')
        use_mask = parameters.get('use_mask', True)
        use_histogram_matching = parameters.get('use_histogram_matching',
                                                False)
        use_float = parameters.get('use_float', False)

        winsorize_intensity = parameters.get('winsorize-image-intensities',
                                             None)

        cmd = ['antsRegistration', '--minc', '-a', '--dimensionality', '3']

        (sources_lr, targets_lr, source_mask_lr,
         target_mask_lr) = minc.downsample_registration_files(
             sources, targets, source_mask, target_mask, downsample)

        # generate modalities: one --metric per source/target pair, with
        # per-modality cost function settings when lists are supplied
        for _s in range(modalities):
            if isinstance(cost_function, list):
                cost_function_ = cost_function[_s]
            else:
                cost_function_ = cost_function
            #
            if isinstance(cost_function_par, list):
                cost_function_par_ = cost_function_par[_s]
            else:
                cost_function_par_ = cost_function_par
            #
            cmd.extend([
                '--metric',
                '{}[{},{},{}]'.format(cost_function_, sources_lr[_s],
                                      targets_lr[_s], cost_function_par_)
            ])

        cmd.extend(['--convergence', '[{}]'.format(prog)])
        cmd.extend(['--shrink-factors', shrink])
        cmd.extend(['--smoothing-sigmas', blur])
        cmd.extend(['--transform', transformation])

        cmd.extend(['--output', output_base])
        #cmd.extend(['--save-state',output_xfm])

        if init_xfm is not None:
            cmd.extend(['--initial-fixed-transform', init_xfm])

        inputs = sources_lr + targets_lr

        # masks are only used when both of them are available
        if target_mask_lr is not None and source_mask_lr is not None and use_mask:
            inputs.extend([source_mask_lr, target_mask_lr])
            cmd.extend(
                ['-x', '[{},{}]'.format(source_mask_lr, target_mask_lr)])

        if use_histogram_matching:
            cmd.append('--use-histogram-matching')

        if winsorize_intensity is not None:
            if isinstance(winsorize_intensity, dict):
                # NOTE(review): the limits are passed as two separate
                # arguments - confirm against the antsRegistration help.
                cmd.extend([
                    '--winsorize-image-intensities',
                    str(winsorize_intensity.get('low', 1)),
                    str(winsorize_intensity.get('high', 99))
                ])
            else:
                cmd.append('--winsorize-image-intensities')

        if use_float:
            cmd.append('--float')

        if verbose > 0:
            cmd.extend(['--verbose', '1'])

        outputs = [output_xfm]  # TODO: add inverse xfm ?

        # the assembled command line is always echoed to stdout
        print(">>>\n{}\n>>>>".format(' '.join(cmd)))

        minc.command(cmd, inputs=inputs, outputs=outputs)
# Example no. 13
# 0
def non_linear_register_ants(source,
                             target,
                             output_xfm,
                             target_mask=None,
                             init_xfm=None,
                             parameters=None,
                             downsample=None,
                             verbose=0):
    """Perform non-linear registration using the legacy ``ANTS`` executable.

    WARNING (from the original author): will create inverted xfm, which will
    be named output_invert.xfm.

    Args:
        source: image path, listed first in the ``-m`` metric specification.
        target: image path, listed second in the ``-m`` metric specification.
        output_xfm: value passed to ``-o``.
        target_mask: optional mask passed to ``-x`` when ``use_mask`` is
            true.
        init_xfm: accepted for interface compatibility; not used by this
            function.
        parameters: optional dict. Recognised keys: ``cost_function``,
            ``cost_function_par``, ``regularization``, ``iter``,
            ``transformation``, ``affine-iterations``, ``use_mask`` and
            ``use_histogram_matching``.
        downsample: when not None, the inputs (and the mask) are first
            resampled with ``unistep=downsample`` into temporary files.
        verbose: passed to ``mincTools``.

    Returns:
        None. Returns early, without running anything, when
        ``minc.checkfiles`` reports a false value.
    """

    with minc_tools.mincTools(verbose=verbose) as minc:

        if parameters is None:
            #print("Using default  ANTS parameters")
            parameters = {}

        if not minc.checkfiles(inputs=[source, target], outputs=[output_xfm]):
            return

        cost_function = parameters.get('cost_function', 'CC')
        cost_function_par = parameters.get('cost_function_par', '1,2')

        reg = parameters.get('regularization', 'Gauss[2,0.5]')
        iterations = parameters.get('iter', '20x20x0')
        transformation = parameters.get('transformation', 'SyN[0.25]')
        affine_iterations = parameters.get('affine-iterations', '0x0x0')
        use_mask = parameters.get('use_mask', True)
        use_histogram_matching = parameters.get('use_histogram_matching',
                                                False)

        cmd = ['ANTS', '3']

        # file names without directory and without .mnc / .mnc.gz extension
        s_base = os.path.basename(source).rsplit('.gz',
                                                 1)[0].rsplit('.mnc', 1)[0]
        t_base = os.path.basename(target).rsplit('.gz',
                                                 1)[0].rsplit('.mnc', 1)[0]

        source_lr = source
        target_lr = target

        target_mask_lr = target_mask

        if downsample is not None:
            # No `tmp` object is bound in this function, so the scratch files
            # come from the mincTools context instead. The role prefixes keep
            # the names apart when source and target share a base name.
            suffix = '_' + str(downsample) + '.mnc'
            source_lr = minc.tmp('source_' + s_base + suffix)
            target_lr = minc.tmp('target_' + t_base + suffix)

            minc.resample_smooth(source, source_lr, unistep=downsample)
            minc.resample_smooth(target, target_lr, unistep=downsample)

            if target_mask is not None:
                target_mask_lr = minc.tmp('target_' + t_base + '_mask' +
                                          suffix)
                minc.resample_labels(target_mask,
                                     target_mask_lr,
                                     unistep=downsample,
                                     datatype='byte')

        cmd.extend([
            '-m', '{}[{},{},{}]'.format(cost_function, source_lr, target_lr,
                                        cost_function_par)
        ])
        cmd.extend(['-i', iterations])
        cmd.extend(['-t', transformation])
        cmd.extend(['-r', reg])
        cmd.extend(['--number-of-affine-iterations', affine_iterations])
        cmd.extend(['-o', output_xfm])

        inputs = [source_lr, target_lr]
        if target_mask_lr is not None and use_mask:
            inputs.append(target_mask_lr)
            cmd.extend(['-x', target_mask_lr])

        if use_histogram_matching:
            cmd.append('--use-Histogram-Matching')

        outputs = [output_xfm]  # TODO: add inverse xfm ?

        #print(repr(cmd))

        minc.command(cmd, inputs=inputs, outputs=outputs)