Пример #1
0
def apply_transfo(im_src, im_dest, warp, interp='spline', rm_tmp=True):
    """
    Warp an image onto the space of another image with a warping field.

    Both images and the warping field are placed in a temporary folder, where
    sct_apply_transfo is run. The warped source image is then loaded and returned.

    :param im_src: Image to warp. Side effect: its file name is set to 'src.nii.gz'.
    :param im_dest: destination Image (defines the output space). Side effect: its file
        name is set to 'dest.nii.gz'.
    :param warp: path to the warping field (copied into the temporary folder).
    :param interp: interpolation method, passed to sct_apply_transfo via '-x'.
    :param rm_tmp: if True, delete the temporary folder before returning.
    :return: Image of the warped source.
    """
    # create tmp dir
    tmp_dir = tmp_create()
    # copy warping field to tmp dir; from now on only its base name (file + extension) is used
    shutil.copy(warp, tmp_dir)
    warp = ''.join(extract_fname(warp)[1:])
    # remember where we come from: os.chdir('..') would only go back here if tmp_dir
    # happened to be a direct child of the current directory
    curdir = os.getcwd()
    os.chdir(tmp_dir)
    try:
        # save image and seg
        fname_src = 'src.nii.gz'
        im_src.setFileName(fname_src)
        im_src.save()
        fname_dest = 'dest.nii.gz'
        im_dest.setFileName(fname_dest)
        im_dest.save()
        # apply warping field; no '-o' is given, the result is expected at
        # add_suffix(fname_src, '_reg')
        fname_src_reg = add_suffix(fname_src, '_reg')
        sct_apply_transfo.main(args=['-i', fname_src,
                                     '-d', fname_dest,
                                     '-w', warp,
                                     '-x', interp])
        im_src_reg = Image(fname_src_reg)
    finally:
        # always restore the caller's working directory, even if warping failed
        os.chdir(curdir)
    if rm_tmp:
        # remove tmp dir
        shutil.rmtree(tmp_dir)
    # return res image
    return im_src_reg
def apply_transfo(im_src, im_dest, warp, interp='spline', rm_tmp=True):
    """
    Warp an image onto the space of another image with a warping field.

    Both images and the warping field are placed in a temporary folder, where
    sct_apply_transfo is run. The warped source image is then loaded and returned.

    :param im_src: Image to warp (saved as 'src.nii.gz' inside the temporary folder).
    :param im_dest: destination Image defining the output space (saved as 'dest.nii.gz').
    :param warp: path to the warping field (copied into the temporary folder).
    :param interp: interpolation method, passed to sct_apply_transfo via '-x'.
    :param rm_tmp: if True, delete the temporary folder before returning.
    :return: Image of the warped source.
    """
    # create tmp dir
    tmp_dir = sct.tmp_create()
    # copy warping field to tmp dir; from now on only its base name (file + extension) is used
    sct.copy(warp, tmp_dir)
    warp = ''.join(extract_fname(warp)[1:])
    # go to tmp dir, remembering where we come from
    curdir = os.getcwd()
    os.chdir(tmp_dir)
    try:
        # save image and seg
        fname_src = 'src.nii.gz'
        im_src.save(fname_src)
        fname_dest = 'dest.nii.gz'
        im_dest.save(fname_dest)
        # apply warping field; no '-o' is given, the result is expected at
        # add_suffix(fname_src, '_reg')
        fname_src_reg = add_suffix(fname_src, '_reg')
        sct_apply_transfo.main(
            args=['-i', fname_src, '-d', fname_dest, '-w', warp, '-x', interp])
        im_src_reg = Image(fname_src_reg)
    finally:
        # always restore the caller's working directory, even if warping raised
        os.chdir(curdir)
    if rm_tmp:
        # remove tmp dir
        sct.rmtree(tmp_dir)
    # return res image
    return im_src_reg
def test_integrity(param_test):
    """
    Test integrity of function: check registration quality in both directions.

    The ground-truth template segmentation matching the current argument set is warped
    onto the subject segmentation (template --> anat) and the subject segmentation onto
    the template (anat --> template), both with nearest-neighbour interpolation. A 3D Dice
    coefficient is computed for each direction and appended to ``param_test.output``.
    A Dice that is not strictly above ``param_test.dice_threshold`` sets
    ``param_test.status`` to 99.

    :param param_test: test parameter object. Reads args, default_args, fname_gt,
        file_seg and dice_threshold; writes output, status and results.
    :return: param_test, updated in place.
    """
    # the ground truth depends on which argument set is currently being tested
    fname_gt = param_test.fname_gt[param_test.default_args.index(param_test.args)]
    fname_seg = param_test.file_seg

    def _warp(fname_in, fname_ref, fname_warp, fname_out):
        # warp a binary mask with nearest-neighbour interpolation, quietly
        sct_apply_transfo.main(args=[
            '-i', fname_in, '-d', fname_ref, '-w', fname_warp,
            '-o', fname_out, '-x', 'nn', '-v', '0'
        ])

    def _dice(fname_a, fname_b):
        # 3D Dice between two mask files
        return msct_image.compute_dice(Image(fname_a), Image(fname_b),
                                       mode='3d', zboundaries=True)

    def _report(header, dice):
        # log the Dice value and flag the test as failed when below threshold
        param_test.output += header + str(dice)
        if dice > param_test.dice_threshold:
            param_test.output += '\n--> PASSED'
        else:
            param_test.status = 99
            param_test.output += '\n--> FAILED'

    # warp in both directions
    _warp(fname_gt, fname_seg, 'warp_template2anat.nii.gz', 'test_template2anat.nii.gz')
    _warp(fname_seg, fname_gt, 'warp_anat2template.nii.gz', 'test_anat2template.nii.gz')

    # template segmentation warped to anat vs. segmentation from anat
    dice_template2anat = _dice(fname_seg, 'test_template2anat.nii.gz')
    _report('Dice[seg,template_seg_reg]: ', dice_template2anat)

    # anat segmentation warped to template vs. segmentation from template
    dice_anat2template = _dice('test_anat2template.nii.gz', fname_gt)
    _report('\n\nDice[seg_reg,template_seg]: ', dice_anat2template)

    # update Panda structure
    param_test.results['dice_template2anat'] = dice_template2anat
    param_test.results['dice_anat2template'] = dice_anat2template

    return param_test
Пример #4
0
def merge_images(list_fname_src, fname_dest, list_fname_warp, param):
    """
    Merge multiple source images onto destination space. All images are warped to the destination space and then added.
    To deal with overlap during merging (e.g. one voxel in destination image is shared with two input images), the
    resulting voxel is divided by the sum of the partial volume of each image. For example, if src(x,y,z)=1 is mapped to
    dest(i,j,k) with a partial volume of 0.5 (because destination voxel is bigger), then its value after linear interpolation
    will be 0.5. To account for partial volume, the resulting voxel will be: dest(i,j,k) = 0.5*0.5/0.5 = 0.5.
    Now, if two voxels overlap in the destination space, let's say: src(x,y,z)=1 and src2'(x',y',z')=1, then the
    resulting value will be: dest(i,j,k) = (0.5*0.5 + 0.5*0.5) / (0.5+0.5) = 0.5. So this function acts like a weighted
    average operator, only in destination voxels that share multiple source voxels.

    Parameters
    ----------
    list_fname_src : list of str
        Source images to merge.
    fname_dest : str
        Destination image; defines the output grid and is used as output container.
    list_fname_warp : list of str
        Warping fields src --> dest, one per source image, in the same order as list_fname_src.
    param
        Parameter object; reads interp, verbose, almost_zero, fname_out and rm_tmp.

    Returns
    -------
    None. The merged image is written to ``param.fname_out``.
    """
    import os

    # create temporary folder; all intermediate files are written there (not in the current
    # working directory) so they are removed together with the folder
    path_tmp = sct.tmp_create()

    # get dimensions of destination file
    nii_dest = msct_image.Image(fname_dest)
    shape_4d = [nii_dest.dim[0], nii_dest.dim[1], nii_dest.dim[2], len(list_fname_src)]

    # warped intensities and warped partial-volume maps, one 3D volume per source image
    data = np.zeros(shape_4d)
    partial_volume = np.zeros(shape_4d)

    for i_file, fname_src in enumerate(list_fname_src):
        fname_warp = list_fname_warp[i_file]
        prefix = os.path.join(path_tmp, 'src_' + str(i_file))
        fname_src_template = prefix + '_template.nii.gz'
        fname_src_bin = prefix + 'native_bin.nii.gz'
        fname_src_pv = prefix + '_template_partialVolume.nii.gz'

        # apply transformation src --> dest
        # (arguments are command-line style strings; str() guards against an int verbose level)
        sct_apply_transfo.main(args=[
            '-i', fname_src, '-d', fname_dest, '-w', fname_warp,
            '-x', param.interp, '-o', fname_src_template, '-v', str(param.verbose)
        ])

        # create binary mask from input file by assigning one to all non-null voxels
        sct_maths.main(args=[
            '-i', fname_src, '-bin', str(param.almost_zero), '-o', fname_src_bin
        ])

        # apply transformation to binary mask to compute partial volume
        sct_apply_transfo.main(args=[
            '-i', fname_src_bin, '-d', fname_dest, '-w', fname_warp,
            '-x', param.interp, '-o', fname_src_pv
        ])

        # open data
        data[:, :, :, i_file] = msct_image.Image(fname_src_template).data
        partial_volume[:, :, :, i_file] = msct_image.Image(fname_src_pv).data

    # merge files using partial volume information. Voxels not covered by any source give 0/0;
    # the resulting nan is converted to zero (warnings for that expected case are silenced)
    with np.errstate(divide='ignore', invalid='ignore'):
        data_merge = np.divide(np.sum(data * partial_volume, axis=3),
                               np.sum(partial_volume, axis=3))
    data_merge = np.nan_to_num(data_merge)

    # write result in file
    nii_dest.data = data_merge
    nii_dest.save(param.fname_out)

    # remove temporary folder
    if param.rm_tmp:
        sct.rmtree(path_tmp)
Пример #5
0
def moco(param):
    """
    Estimate and/or apply motion correction on a 4D time series.

    Each volume is registered to the target image with register(). For sagittal data,
    each 2D sagittal slice of each volume is registered instead. Volumes whose
    registration failed reuse the transformation of the temporally closest successful
    volume. Unless param.todo == 'estimate', the motion-corrected volumes are
    concatenated back along T, and along Z for sagittal data.

    :param param: parameter object. Read here: file_data, file_target, mat_moco, todo,
        suffix, verbose, poly, smooth, gradStep, metric, sampling, fname_mask,
        is_sagittal, iterAvg, interp. It is also passed whole to register().
    :return: numpy object array of shape (nz, nt) for sagittal data, (1, nt) otherwise,
        holding the transformation file prefix of each slice/volume.
    """
    # retrieve parameters
    file_data = param.file_data
    file_target = param.file_target
    folder_mat = param.mat_moco  # output folder of mat file
    todo = param.todo
    suffix = param.suffix
    verbose = param.verbose

    # other parameters
    file_mask = 'mask.nii'
    use_mask = not param.fname_mask == ''

    sct.printv('\nInput parameters:', param.verbose)
    sct.printv('  Input file ............' + file_data, param.verbose)
    sct.printv('  Reference file ........' + file_target, param.verbose)
    sct.printv('  Polynomial degree .....' + param.poly, param.verbose)
    sct.printv('  Smoothing kernel ......' + param.smooth, param.verbose)
    sct.printv('  Gradient step .........' + param.gradStep, param.verbose)
    sct.printv('  Metric ................' + param.metric, param.verbose)
    sct.printv('  Sampling ..............' + param.sampling, param.verbose)
    sct.printv('  Todo ..................' + todo, param.verbose)
    sct.printv('  Mask  .................' + param.fname_mask, param.verbose)
    sct.printv('  Output mat folder .....' + folder_mat, param.verbose)

    # create folder for mat files
    sct.create_folder(folder_mat)

    # Get size of data
    sct.printv('\nData dimensions:', verbose)
    im_data = Image(param.file_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    sct.printv(
        ('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt)),
        verbose)

    # copy file_target to a temporary file
    sct.printv('\nCopy file_target to a temporary file...', verbose)
    file_target = "target.nii.gz"
    convert(param.file_target, file_target)

    # Both the sagittal and the axial branch below read file_mask, so create it first
    if use_mask:
        convert(param.fname_mask, file_mask, squeeze_data=False)

    # If scan is sagittal, split src and target along Z (slice)
    if param.is_sagittal:
        dim_sag = 2  # TODO: find it
        # z-split data (time series)
        file_data_splitZ = []
        for im_z in split_data(im_data, dim=dim_sag, squeeze_data=False):
            im_z.save()
            file_data_splitZ.append(im_z.absolutepath)
        # z-split target
        file_target_splitZ = []
        for im_targetz in split_data(Image(file_target), dim=dim_sag, squeeze_data=False):
            im_targetz.save()
            file_target_splitZ.append(im_targetz.absolutepath)
        # z-split mask (if exists); each piece is saved because register() passes
        # its absolutepath to ANTs
        if use_mask:
            im_maskz_list = split_data(Image(file_mask), dim=dim_sag, squeeze_data=False)
            for im_maskz in im_maskz_list:
                im_maskz.save()
        # initialize file list for output matrices
        file_mat = np.empty((nz, nt), dtype=object)

    # axial orientation
    else:
        file_data_splitZ = [file_data]  # TODO: make it absolute like above
        file_target_splitZ = [file_target]  # TODO: make it absolute like above
        # initialize file list for output matrices
        file_mat = np.empty((1, nt), dtype=object)
        # deal with mask: use a list with single element
        if use_mask:
            im_maskz_list = [Image(file_mask)]

    # Loop across file list, where each file is either a 2D volume (if sagittal) or a 3D volume (otherwise)
    file_data_splitZ_moco = []
    sct.printv(
        '\nRegister. Loop across Z (note: there is only one Z if orientation is axial'
    )
    for iz, fname_z in enumerate(file_data_splitZ):
        # Split data along T dimension
        file_data_splitZ_splitT = []
        for im_zt in split_data(Image(fname_z), dim=3):
            im_zt.save(verbose=0)
            file_data_splitZ_splitT.append(im_zt.absolutepath)

        # Motion correction: initialization
        file_data_splitZ_splitT_moco = []
        failed_transfo = [0] * nt
        input_mask = im_maskz_list[iz] if use_mask else None

        # Motion correction: Loop across T
        for it in tqdm(range(nt),
                       unit='iter',
                       unit_scale=False,
                       desc="Z=" + str(iz) + "/" + str(len(file_data_splitZ) - 1),
                       ascii=True,
                       ncols=80):
            file_mat[iz][it] = os.path.join(
                folder_mat,
                "mat.Z") + str(iz).zfill(4) + 'T' + str(it).zfill(4)
            file_data_splitZ_splitT_moco.append(
                sct.add_suffix(file_data_splitZ_splitT[it], '_moco'))
            # run 3D registration
            failed_transfo[it] = register(param,
                                          file_data_splitZ_splitT[it],
                                          file_target_splitZ[iz],
                                          file_mat[iz][it],
                                          file_data_splitZ_splitT_moco[it],
                                          im_mask=input_mask)

            # average registered volume with target image (first 10 volumes only)
            # N.B. use weighted averaging: (target * nb_it + moco) / (nb_it + 1)
            if param.iterAvg and it < 10 and failed_transfo[it] == 0 \
                    and not param.todo == 'apply':
                im_targetz = Image(file_target_splitZ[iz])
                data_mocoz = Image(file_data_splitZ_splitT_moco[it]).data
                im_targetz.data = (im_targetz.data * (it + 1) + data_mocoz) / (it + 2)
                im_targetz.save(verbose=0)

        # Replace failed transformation with the closest good one
        # (ties resolve to the earliest good volume)
        good_T = [i for i, status in enumerate(failed_transfo) if status == 0]
        failed_T = [i for i, status in enumerate(failed_transfo) if status == 1]
        for it_failed in failed_T:
            if not good_T:
                # exit program if no transformation exists.
                sct.printv(
                    '\nERROR in ' + os.path.basename(__file__) +
                    ': No good transformation exist. Exit program.\n', verbose,
                    'error')
                sys.exit(2)
            it_good = min(good_T, key=lambda i: abs(i - it_failed))
            sct.printv(
                '  transfo #' + str(it_failed) + ' --> use transfo #' +
                str(it_good), verbose)
            # copy transformation
            # NOTE(review): in sagittal mode register() runs isct_antsRegistration with an
            # Affine transform, which may not write a '<prefix>Warp.nii.gz' file — confirm
            # this fallback works for sagittal data.
            sct.copy(file_mat[iz][it_good] + 'Warp.nii.gz',
                     file_mat[iz][it_failed] + 'Warp.nii.gz')
            # apply transformation, using the same (per-Z) destination as register()
            sct_apply_transfo.main(args=[
                '-i', file_data_splitZ_splitT[it_failed], '-d', file_target_splitZ[iz],
                '-w', file_mat[iz][it_failed] + 'Warp.nii.gz', '-o',
                file_data_splitZ_splitT_moco[it_failed], '-x', param.interp
            ])

        # Merge data along T
        file_data_splitZ_moco.append(sct.add_suffix(fname_z, suffix))
        if todo != 'estimate':
            im_out = concat_data(file_data_splitZ_splitT_moco, 3)
            im_out.save(file_data_splitZ_moco[iz])

    # If sagittal, merge along Z (the per-Z files only exist when they were written above)
    if param.is_sagittal and todo != 'estimate':
        im_out = concat_data(file_data_splitZ_moco, 2)
        dirname, basename, ext = sct.extract_fname(file_data)
        path_out = os.path.join(dirname, basename + suffix + ext)
        im_out.save(path_out)

    return file_mat
Пример #6
0
def register(param, file_src, file_dest, file_mat, file_out, im_mask=None):
    """
    Register file_src to file_dest, or apply an existing transformation, depending on param.todo.

    - param.todo in ('estimate', 'estimate_and_apply'):
        * sagittal data (3rd axis of file_src oriented L or R): 2D affine registration with
          isct_antsRegistration;
        * otherwise: slice-wise Tx and Ty translations regularized along Z (polynomial degree
          param.poly) with isct_antsSliceRegularizedRegistration.
    - param.todo == 'apply': warp file_src with file_mat + 'Warp.nii.gz' using sct_apply_transfo.

    :param param: parameter object (reads metric, todo, gradStep, sampling, iter, smooth, poly,
        interp, verbose).
    :param file_src: path to the moving image.
    :param file_dest: path to the fixed image.
    :param file_mat: output prefix of the transformation file(s).
    :param file_out: path of the registered output image.
    :param im_mask: Image of mask, could be 2D or 3D; passed to ANTs by its absolutepath.
    :return: 1 if no output image was produced (failed registration), 0 otherwise.
    """

    # TODO: deal with mask

    # initialization
    failed_transfo = 0  # by default, failed matrix is 0 (i.e., no failure)
    do_registration = True

    # get metric radius (if MeanSquares, CC) or nb bins (if MI)
    if param.metric == 'MI':
        metric_radius = '16'
    else:
        metric_radius = '4'
    file_out_concat = file_out

    im_data = Image(
        file_src
    )  # TODO: pass argument to use antsReg instead of opening Image each time
    is_sagittal = im_data.orientation[2] in 'LR'

    # register file_src to file_dest
    if param.todo == 'estimate' or param.todo == 'estimate_and_apply':
        # If orientation is sagittal, use antsRegistration in 2D mode
        # Note: the parameter --restrict-deformation is irrelevant with affine transfo
        if is_sagittal:
            cmd = [
                'isct_antsRegistration', '-d', '2', '--transform',
                'Affine[%s]' % param.gradStep, '--metric',
                param.metric + '[' + file_dest + ',' + file_src + ',1,' +
                metric_radius + ',Regular,' + param.sampling + ']',
                '--convergence', param.iter, '--shrink-factors', '1',
                '--smoothing-sigmas', param.smooth, '--verbose', '1',
                '--output', '[' + file_mat + ',' + file_out_concat + ']'
            ]
            cmd += sct.get_interpolation('isct_antsRegistration', param.interp)
            if im_mask is not None:
                # if user specified a mask, make sure there are non-null voxels in the image before running the registration
                if np.count_nonzero(im_mask.data):
                    cmd += ['--masks', im_mask.absolutepath]
                else:
                    # Mask only contains zeros. Copying the image instead of estimating registration.
                    sct.copy(file_src, file_out_concat, verbose=0)
                    do_registration = False
                    # TODO: create affine mat file with identity, in case used by -g 2
        # 3D mode
        else:
            cmd = [
                'isct_antsSliceRegularizedRegistration', '--polydegree',
                param.poly, '--transform',
                'Translation[%s]' % param.gradStep, '--metric',
                param.metric + '[' + file_dest + ',' + file_src + ',1,' +
                metric_radius + ',Regular,' + param.sampling + ']',
                '--iterations', param.iter, '--shrinkFactors', '1',
                '--smoothingSigmas', param.smooth, '--verbose', '1',
                '--output', '[' + file_mat + ',' + file_out_concat + ']'
            ]
            cmd += sct.get_interpolation(
                'isct_antsSliceRegularizedRegistration', param.interp)
            if im_mask is not None:
                cmd += ['--mask', im_mask.absolutepath]
        # run command
        if do_registration:
            # reducing the number of CPU used for moco (see issue #201): run in a copy of the
            # current environment with ITK limited to one thread. The environment has to be
            # passed to sct.run, otherwise the limit is never applied.
            env = dict(os.environ)
            env["ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"] = "1"
            sct.run(cmd, verbose=0, is_sct_binary=True, env=env)

    elif param.todo == 'apply':
        sct_apply_transfo.main(args=[
            '-i', file_src, '-d', file_dest, '-w', file_mat +
            'Warp.nii.gz', '-o', file_out_concat, '-x', param.interp, '-v', '0'
        ])

    # check if output file exists
    if not os.path.isfile(file_out_concat):
        sct.printv(
            'WARNING in ' + os.path.basename(__file__) +
            ': No output. Maybe related to improper calculation of '
            'mutual information. Either the mask you provided is '
            'too small, or the subject moved a lot. If you see too '
            'many messages like this try with a bigger mask. '
            'Using previous transformation for this volume (if it'
            'exists).', param.verbose, 'warning')
        failed_transfo = 1

    # TODO: if sagittal, copy header (because ANTs screws it) and add singleton in 3rd dimension (for z-concatenation)
    # Skipped when no output was produced: there is nothing to load, and the caller handles
    # the failure through the returned flag.
    if is_sagittal and do_registration and not failed_transfo:
        im_out = Image(file_out_concat)
        im_out.header = im_data.header
        im_out.data = np.expand_dims(im_out.data, 2)
        im_out.save(file_out, verbose=0)

    # return status of failure
    return failed_transfo
Пример #7
0
def main(args=None):

    # initializations
    param = Param()

    # check user arguments
    if not args:
        args = sys.argv[1:]

    # Get parser info
    parser = get_parser()
    arguments = parser.parse(args)
    fname_data = arguments['-i']
    fname_seg = arguments['-s']
    if '-l' in arguments:
        fname_landmarks = arguments['-l']
        label_type = 'body'
    elif '-ldisc' in arguments:
        fname_landmarks = arguments['-ldisc']
        label_type = 'disc'
    else:
        sct.printv('ERROR: Labels should be provided.', 1, 'error')
    if '-ofolder' in arguments:
        path_output = arguments['-ofolder']
    else:
        path_output = ''

    param.path_qc = arguments.get("-qc", None)

    path_template = arguments['-t']
    contrast_template = arguments['-c']
    ref = arguments['-ref']
    param.remove_temp_files = int(arguments.get('-r'))
    verbose = int(arguments.get('-v'))
    sct.init_sct(log_level=verbose, update=True)  # Update log level
    param.verbose = verbose  # TODO: not clean, unify verbose or param.verbose in code, but not both
    param_centerline = ParamCenterline(
        algo_fitting=arguments['-centerline-algo'],
        smooth=arguments['-centerline-smooth'])
    # registration parameters
    if '-param' in arguments:
        # reset parameters but keep step=0 (might be overwritten if user specified step=0)
        paramreg = ParamregMultiStep([step0])
        if ref == 'subject':
            paramreg.steps['0'].dof = 'Tx_Ty_Tz_Rx_Ry_Rz_Sz'
        # add user parameters
        for paramStep in arguments['-param']:
            paramreg.addStep(paramStep)
    else:
        paramreg = ParamregMultiStep([step0, step1, step2])
        # if ref=subject, initialize registration using different affine parameters
        if ref == 'subject':
            paramreg.steps['0'].dof = 'Tx_Ty_Tz_Rx_Ry_Rz_Sz'

    # initialize other parameters
    zsubsample = param.zsubsample

    # retrieve template file names
    file_template_vertebral_labeling = get_file_label(os.path.join(path_template, 'template'), 'vertebral labeling')
    file_template = get_file_label(os.path.join(path_template, 'template'), contrast_template.upper() + '-weighted template')
    file_template_seg = get_file_label(os.path.join(path_template, 'template'), 'spinal cord')

    # start timer
    start_time = time.time()

    # get fname of the template + template objects
    fname_template = os.path.join(path_template, 'template', file_template)
    fname_template_vertebral_labeling = os.path.join(path_template, 'template', file_template_vertebral_labeling)
    fname_template_seg = os.path.join(path_template, 'template', file_template_seg)
    fname_template_disc_labeling = os.path.join(path_template, 'template', 'PAM50_label_disc.nii.gz')

    # check file existence
    # TODO: no need to do that!
    sct.printv('\nCheck template files...')
    sct.check_file_exist(fname_template, verbose)
    sct.check_file_exist(fname_template_vertebral_labeling, verbose)
    sct.check_file_exist(fname_template_seg, verbose)
    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    # sct.printv(arguments)
    sct.printv('\nCheck parameters:', verbose)
    sct.printv('  Data:                 ' + fname_data, verbose)
    sct.printv('  Landmarks:            ' + fname_landmarks, verbose)
    sct.printv('  Segmentation:         ' + fname_seg, verbose)
    sct.printv('  Path template:        ' + path_template, verbose)
    sct.printv('  Remove temp files:    ' + str(param.remove_temp_files), verbose)

    # check input labels
    labels = check_labels(fname_landmarks, label_type=label_type)

    vertebral_alignment = False
    if len(labels) > 2 and label_type == 'disc':
        vertebral_alignment = True

    path_tmp = sct.tmp_create(basename="register_to_template", verbose=verbose)

    # set temporary file names
    ftmp_data = 'data.nii'
    ftmp_seg = 'seg.nii.gz'
    ftmp_label = 'label.nii.gz'
    ftmp_template = 'template.nii'
    ftmp_template_seg = 'template_seg.nii.gz'
    ftmp_template_label = 'template_label.nii.gz'

    # copy files to temporary folder
    sct.printv('\nCopying input data to tmp folder and convert to nii...', verbose)
    Image(fname_data).save(os.path.join(path_tmp, ftmp_data))
    Image(fname_seg).save(os.path.join(path_tmp, ftmp_seg))
    Image(fname_landmarks).save(os.path.join(path_tmp, ftmp_label))
    Image(fname_template).save(os.path.join(path_tmp, ftmp_template))
    Image(fname_template_seg).save(os.path.join(path_tmp, ftmp_template_seg))
    Image(fname_template_vertebral_labeling).save(os.path.join(path_tmp, ftmp_template_label))
    if label_type == 'disc':
        Image(fname_template_disc_labeling).save(os.path.join(path_tmp, ftmp_template_label))

    # go to tmp folder
    curdir = os.getcwd()
    os.chdir(path_tmp)

    # Generate labels from template vertebral labeling
    if label_type == 'body':
        sct.printv('\nGenerate labels from template vertebral labeling', verbose)
        ftmp_template_label_, ftmp_template_label = ftmp_template_label, sct.add_suffix(ftmp_template_label, "_body")
        sct_label_utils.main(args=['-i', ftmp_template_label_, '-vert-body', '0', '-o', ftmp_template_label])

    # check if provided labels are available in the template
    sct.printv('\nCheck if provided labels are available in the template', verbose)
    image_label_template = Image(ftmp_template_label)
    labels_template = image_label_template.getNonZeroCoordinates(sorting='value')
    if labels[-1].value > labels_template[-1].value:
        sct.printv('ERROR: Wrong landmarks input. Labels must have correspondence in template space. \nLabel max '
                   'provided: ' + str(labels[-1].value) + '\nLabel max from template: ' +
                   str(labels_template[-1].value), verbose, 'error')

    # if only one label is present, force affine transformation to be Tx,Ty,Tz only (no scaling)
    if len(labels) == 1:
        paramreg.steps['0'].dof = 'Tx_Ty_Tz'
        sct.printv('WARNING: Only one label is present. Forcing initial transformation to: ' + paramreg.steps['0'].dof,
                   1, 'warning')

    # Project labels onto the spinal cord centerline because later, an affine transformation is estimated between the
    # template's labels (centered in the cord) and the subject's labels (assumed to be centered in the cord).
    # If labels are not centered, mis-registration errors are observed (see issue #1826)
    ftmp_label = project_labels_on_spinalcord(ftmp_label, ftmp_seg, param_centerline)

    # binarize segmentation (in case it has values below 0 caused by manual editing)
    sct.printv('\nBinarize segmentation', verbose)
    ftmp_seg_, ftmp_seg = ftmp_seg, sct.add_suffix(ftmp_seg, "_bin")
    sct_maths.main(['-i', ftmp_seg_,
                    '-bin', '0.5',
                    '-o', ftmp_seg])

    # Switch between modes: subject->template or template->subject
    if ref == 'template':

        # resample data to 1mm isotropic
        sct.printv('\nResample data to 1mm isotropic...', verbose)
        resample_file(ftmp_data, add_suffix(ftmp_data, '_1mm'), '1.0x1.0x1.0', 'mm', 'linear', verbose)
        ftmp_data = add_suffix(ftmp_data, '_1mm')
        resample_file(ftmp_seg, add_suffix(ftmp_seg, '_1mm'), '1.0x1.0x1.0', 'mm', 'linear', verbose)
        ftmp_seg = add_suffix(ftmp_seg, '_1mm')
        # N.B. resampling of labels is more complicated, because they are single-point labels, therefore resampling
        # with nearest neighbour can make them disappear.
        resample_labels(ftmp_label, ftmp_data, add_suffix(ftmp_label, '_1mm'))
        ftmp_label = add_suffix(ftmp_label, '_1mm')

        # Change orientation of input images to RPI
        sct.printv('\nChange orientation of input images to RPI...', verbose)

        ftmp_data = Image(ftmp_data).change_orientation("RPI", generate_path=True).save().absolutepath
        ftmp_seg = Image(ftmp_seg).change_orientation("RPI", generate_path=True).save().absolutepath
        ftmp_label = Image(ftmp_label).change_orientation("RPI", generate_path=True).save().absolutepath


        ftmp_seg_, ftmp_seg = ftmp_seg, add_suffix(ftmp_seg, '_crop')
        if vertebral_alignment:
            # cropping the segmentation based on the label coverage to ensure good registration with vertebral alignment
            # See https://github.com/neuropoly/spinalcordtoolbox/pull/1669 for details
            image_labels = Image(ftmp_label)
            coordinates_labels = image_labels.getNonZeroCoordinates(sorting='z')
            nx, ny, nz, nt, px, py, pz, pt = image_labels.dim
            offset_crop = 10.0 * pz  # cropping the image 10 mm above and below the highest and lowest label
            cropping_slices = [coordinates_labels[0].z - offset_crop, coordinates_labels[-1].z + offset_crop]
            # make sure that the cropping slices do not extend outside of the slice range (issue #1811)
            if cropping_slices[0] < 0:
                cropping_slices[0] = 0
            if cropping_slices[1] > nz:
                cropping_slices[1] = nz
            msct_image.spatial_crop(Image(ftmp_seg_), dict(((2, np.int32(np.round(cropping_slices))),))).save(ftmp_seg)
        else:
            # if we do not align the vertebral levels, we crop the segmentation from top to bottom
            im_seg_rpi = Image(ftmp_seg_)
            bottom = 0
            for data in msct_image.SlicerOneAxis(im_seg_rpi, "IS"):
                if (data != 0).any():
                    break
                bottom += 1
            top = im_seg_rpi.data.shape[2]
            for data in msct_image.SlicerOneAxis(im_seg_rpi, "SI"):
                if (data != 0).any():
                    break
                top -= 1
            msct_image.spatial_crop(im_seg_rpi, dict(((2, (bottom, top)),))).save(ftmp_seg)


        # straighten segmentation
        sct.printv('\nStraighten the spinal cord using centerline/segmentation...', verbose)

        # check if warp_curve2straight and warp_straight2curve already exist (i.e. no need to do it another time)
        fn_warp_curve2straight = os.path.join(curdir, "warp_curve2straight.nii.gz")
        fn_warp_straight2curve = os.path.join(curdir, "warp_straight2curve.nii.gz")
        fn_straight_ref = os.path.join(curdir, "straight_ref.nii.gz")

        cache_input_files=[ftmp_seg]
        if vertebral_alignment:
            cache_input_files += [
             ftmp_template_seg,
             ftmp_label,
             ftmp_template_label,
            ]
        cache_sig = sct.cache_signature(
         input_files=cache_input_files,
        )
        cachefile = os.path.join(curdir, "straightening.cache")
        if sct.cache_valid(cachefile, cache_sig) and os.path.isfile(fn_warp_curve2straight) and os.path.isfile(fn_warp_straight2curve) and os.path.isfile(fn_straight_ref):
            sct.printv('Reusing existing warping field which seems to be valid', verbose, 'warning')
            sct.copy(fn_warp_curve2straight, 'warp_curve2straight.nii.gz')
            sct.copy(fn_warp_straight2curve, 'warp_straight2curve.nii.gz')
            sct.copy(fn_straight_ref, 'straight_ref.nii.gz')
            # apply straightening
            sct_apply_transfo.main(args=[
                '-i', ftmp_seg,
                '-w', 'warp_curve2straight.nii.gz',
                '-d', 'straight_ref.nii.gz',
                '-o', add_suffix(ftmp_seg, '_straight')])
        else:
            from spinalcordtoolbox.straightening import SpinalCordStraightener
            sc_straight = SpinalCordStraightener(ftmp_seg, ftmp_seg)
            sc_straight.param_centerline = param_centerline
            sc_straight.output_filename = add_suffix(ftmp_seg, '_straight')
            sc_straight.path_output = './'
            sc_straight.qc = '0'
            sc_straight.remove_temp_files = param.remove_temp_files
            sc_straight.verbose = verbose

            if vertebral_alignment:
                sc_straight.centerline_reference_filename = ftmp_template_seg
                sc_straight.use_straight_reference = True
                sc_straight.discs_input_filename = ftmp_label
                sc_straight.discs_ref_filename = ftmp_template_label

            sc_straight.straighten()
            sct.cache_save(cachefile, cache_sig)

        # N.B. DO NOT UPDATE VARIABLE ftmp_seg BECAUSE TEMPORARY USED LATER
        # re-define warping field using non-cropped space (to avoid issue #367)
        sct_concat_transfo.main(args=[
            '-w', 'warp_straight2curve.nii.gz',
            '-d', ftmp_data,
            '-o', 'warp_straight2curve.nii.gz'])

        if vertebral_alignment:
            sct.copy('warp_curve2straight.nii.gz', 'warp_curve2straightAffine.nii.gz')
        else:
            # Label preparation:
            # --------------------------------------------------------------------------------
            # Remove unused label on template. Keep only label present in the input label image
            sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
            sct.run(['sct_label_utils', '-i', ftmp_template_label, '-o', ftmp_template_label, '-remove-reference', ftmp_label])

            # Dilating the input label so they can be straighten without losing them
            sct.printv('\nDilating input labels using 3vox ball radius')
            sct_maths.main(['-i', ftmp_label,
                            '-dilate', '3',
                            '-o', add_suffix(ftmp_label, '_dilate')])
            ftmp_label = add_suffix(ftmp_label, '_dilate')

            # Apply straightening to labels
            sct.printv('\nApply straightening to labels...', verbose)
            sct_apply_transfo.main(args=[
                '-i', ftmp_label,
                '-o', add_suffix(ftmp_label, '_straight'),
                '-d', add_suffix(ftmp_seg, '_straight'),
                '-w', 'warp_curve2straight.nii.gz',
                '-x', 'nn'])
            ftmp_label = add_suffix(ftmp_label, '_straight')

            # Compute rigid transformation straight landmarks --> template landmarks
            sct.printv('\nEstimate transformation for step #0...', verbose)
            try:
                register_landmarks(ftmp_label, ftmp_template_label, paramreg.steps['0'].dof,
                                   fname_affine='straight2templateAffine.txt', verbose=verbose)
            except RuntimeError:
                raise('Input labels do not seem to be at the right place. Please check the position of the labels. '
                      'See documentation for more details: https://www.slideshare.net/neuropoly/sct-course-20190121/42')

            # Concatenate transformations: curve --> straight --> affine
            sct.printv('\nConcatenate transformations: curve --> straight --> affine...', verbose)
            sct_concat_transfo.main(args=[
                '-w', ['warp_curve2straight.nii.gz', 'straight2templateAffine.txt'],
                '-d', 'template.nii',
                '-o', 'warp_curve2straightAffine.nii.gz'])

        # Apply transformation
        sct.printv('\nApply transformation...', verbose)
        sct_apply_transfo.main(args=[
            '-i', ftmp_data,
            '-o', add_suffix(ftmp_data, '_straightAffine'),
            '-d', ftmp_template,
            '-w', 'warp_curve2straightAffine.nii.gz'])
        ftmp_data = add_suffix(ftmp_data, '_straightAffine')
        sct_apply_transfo.main(args=[
            '-i', ftmp_seg,
            '-o', add_suffix(ftmp_seg, '_straightAffine'),
            '-d', ftmp_template,
            '-w', 'warp_curve2straightAffine.nii.gz',
            '-x', 'linear'])
        ftmp_seg = add_suffix(ftmp_seg, '_straightAffine')

        """
        # Benjamin: Issue from Allan Martin, about the z=0 slice that is screwed up, caused by the affine transform.
        # Solution found: remove slices below and above landmarks to avoid rotation effects
        points_straight = []
        for coord in landmark_template:
            points_straight.append(coord.z)
        min_point, max_point = int(np.round(np.min(points_straight))), int(np.round(np.max(points_straight)))
        ftmp_seg_, ftmp_seg = ftmp_seg, add_suffix(ftmp_seg, '_black')
        msct_image.spatial_crop(Image(ftmp_seg_), dict(((2, (min_point,max_point)),))).save(ftmp_seg)

        """
        # open segmentation
        im = Image(ftmp_seg)
        im_new = msct_image.empty_like(im)
        # binarize
        im_new.data = im.data > 0.5
        # find min-max of anat2template (for subsequent cropping)
        zmin_template, zmax_template = msct_image.find_zmin_zmax(im_new, threshold=0.5)
        # save binarized segmentation
        im_new.save(add_suffix(ftmp_seg, '_bin')) # unused?
        # crop template in z-direction (for faster processing)
        # TODO: refactor to use python module instead of doing i/o
        sct.printv('\nCrop data in template space (for faster processing)...', verbose)
        ftmp_template_, ftmp_template = ftmp_template, add_suffix(ftmp_template, '_crop')
        msct_image.spatial_crop(Image(ftmp_template_), dict(((2, (zmin_template,zmax_template)),))).save(ftmp_template)

        ftmp_template_seg_, ftmp_template_seg = ftmp_template_seg, add_suffix(ftmp_template_seg, '_crop')
        msct_image.spatial_crop(Image(ftmp_template_seg_), dict(((2, (zmin_template,zmax_template)),))).save(ftmp_template_seg)

        ftmp_data_, ftmp_data = ftmp_data, add_suffix(ftmp_data, '_crop')
        msct_image.spatial_crop(Image(ftmp_data_), dict(((2, (zmin_template,zmax_template)),))).save(ftmp_data)

        ftmp_seg_, ftmp_seg = ftmp_seg, add_suffix(ftmp_seg, '_crop')
        msct_image.spatial_crop(Image(ftmp_seg_), dict(((2, (zmin_template,zmax_template)),))).save(ftmp_seg)

        # sub-sample in z-direction
        # TODO: refactor to use python module instead of doing i/o
        sct.printv('\nSub-sample in z-direction (for faster processing)...', verbose)
        sct.run(['sct_resample', '-i', ftmp_template, '-o', add_suffix(ftmp_template, '_sub'), '-f', '1x1x' + zsubsample], verbose)
        ftmp_template = add_suffix(ftmp_template, '_sub')
        sct.run(['sct_resample', '-i', ftmp_template_seg, '-o', add_suffix(ftmp_template_seg, '_sub'), '-f', '1x1x' + zsubsample], verbose)
        ftmp_template_seg = add_suffix(ftmp_template_seg, '_sub')
        sct.run(['sct_resample', '-i', ftmp_data, '-o', add_suffix(ftmp_data, '_sub'), '-f', '1x1x' + zsubsample], verbose)
        ftmp_data = add_suffix(ftmp_data, '_sub')
        sct.run(['sct_resample', '-i', ftmp_seg, '-o', add_suffix(ftmp_seg, '_sub'), '-f', '1x1x' + zsubsample], verbose)
        ftmp_seg = add_suffix(ftmp_seg, '_sub')

        # Registration straight spinal cord to template
        sct.printv('\nRegister straight spinal cord to template...', verbose)

        # loop across registration steps
        warp_forward = []
        warp_inverse = []
        for i_step in range(1, len(paramreg.steps)):
            sct.printv('\nEstimate transformation for step #' + str(i_step) + '...', verbose)
            # identify which is the src and dest
            if paramreg.steps[str(i_step)].type == 'im':
                src = ftmp_data
                dest = ftmp_template
                interp_step = 'linear'
            elif paramreg.steps[str(i_step)].type == 'seg':
                src = ftmp_seg
                dest = ftmp_template_seg
                interp_step = 'nn'
            else:
                sct.printv('ERROR: Wrong image type.', 1, 'error')

            if paramreg.steps[str(i_step)].algo == 'centermassrot' and paramreg.steps[str(i_step)].rot_method == 'hog':
                src_seg = ftmp_seg
                dest_seg = ftmp_template_seg
            # if step>1, apply warp_forward_concat to the src image to be used
            if i_step > 1:
                # apply transformation from previous step, to use as new src for registration
                sct_apply_transfo.main(args=[
                    '-i', src,
                    '-d', dest,
                    '-w', warp_forward,
                    '-o', add_suffix(src, '_regStep' + str(i_step - 1)),
                    '-x', interp_step])
                src = add_suffix(src, '_regStep' + str(i_step - 1))
                if paramreg.steps[str(i_step)].algo == 'centermassrot' and paramreg.steps[str(i_step)].rot_method == 'hog':  # also apply transformation to the seg
                    sct_apply_transfo.main(args=[
                        '-i', src_seg,
                        '-d', dest_seg,
                        '-w', warp_forward,
                        '-o', add_suffix(src, '_regStep' + str(i_step - 1)),
                        '-x', interp_step])
                    src_seg = add_suffix(src_seg, '_regStep' + str(i_step - 1))
            # register src --> dest
            # TODO: display param for debugging
            if paramreg.steps[str(i_step)].algo == 'centermassrot' and paramreg.steps[str(i_step)].rot_method == 'hog': # im_seg case
                warp_forward_out, warp_inverse_out = register([src, src_seg], [dest, dest_seg], paramreg, param, str(i_step))
            else:
                warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
            warp_forward.append(warp_forward_out)
            warp_inverse.append(warp_inverse_out)

        # Concatenate transformations: anat --> template
        sct.printv('\nConcatenate transformations: anat --> template...', verbose)
        warp_forward.insert(0, 'warp_curve2straightAffine.nii.gz')
        sct_concat_transfo.main(args=[
            '-w', warp_forward,
            '-d', 'template.nii',
            '-o', 'warp_anat2template.nii.gz'])

        # Concatenate transformations: template --> anat
        sct.printv('\nConcatenate transformations: template --> anat...', verbose)
        warp_inverse.reverse()
        if vertebral_alignment:
            warp_inverse.append('warp_straight2curve.nii.gz')
            sct_concat_transfo.main(args=[
                '-w', warp_inverse,
                '-d', 'data.nii',
                '-o', 'warp_template2anat.nii.gz'])
        else:
            warp_inverse.append('straight2templateAffine.txt')
            warp_inverse.append('warp_straight2curve.nii.gz')
            sct_concat_transfo.main(args=[
                '-w', warp_inverse,
                '-winv', ['straight2templateAffine.txt'],
                '-d', 'data.nii',
                '-o', 'warp_template2anat.nii.gz'])

    # register template->subject
    elif ref == 'subject':

        # Change orientation of input images to RPI
        sct.printv('\nChange orientation of input images to RPI...', verbose)
        ftmp_data = Image(ftmp_data).change_orientation("RPI", generate_path=True).save().absolutepath
        ftmp_seg = Image(ftmp_seg).change_orientation("RPI", generate_path=True).save().absolutepath
        ftmp_label = Image(ftmp_label).change_orientation("RPI", generate_path=True).save().absolutepath

        # Remove unused label on template. Keep only label present in the input label image
        sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
        sct.run(['sct_label_utils', '-i', ftmp_template_label, '-o', ftmp_template_label, '-remove-reference', ftmp_label])

        # Add one label because at least 3 orthogonal labels are required to estimate an affine transformation. This
        # new label is added at the level of the upper most label (lowest value), at 1cm to the right.
        for i_file in [ftmp_label, ftmp_template_label]:
            im_label = Image(i_file)
            coord_label = im_label.getCoordinatesAveragedByValue()  # N.B. landmarks are sorted by value
            # Create new label
            from copy import deepcopy
            new_label = deepcopy(coord_label[0])
            # move it 5mm to the left (orientation is RAS)
            nx, ny, nz, nt, px, py, pz, pt = im_label.dim
            new_label.x = np.round(coord_label[0].x + 5.0 / px)
            # assign value 99
            new_label.value = 99
            # Add to existing image
            im_label.data[int(new_label.x), int(new_label.y), int(new_label.z)] = new_label.value
            # Overwrite label file
            # im_label.absolutepath = 'label_rpi_modif.nii.gz'
            im_label.save()

        # Bring template to subject space using landmark-based transformation
        sct.printv('\nEstimate transformation for step #0...', verbose)
        warp_forward = ['template2subjectAffine.txt']
        warp_inverse = ['template2subjectAffine.txt']
        try:
            register_landmarks(ftmp_template_label, ftmp_label, paramreg.steps['0'].dof, fname_affine=warp_forward[0], verbose=verbose, path_qc="./")
        except Exception:
            sct.printv('ERROR: input labels do not seem to be at the right place. Please check the position of the labels. See documentation for more details: https://www.slideshare.net/neuropoly/sct-course-20190121/42', verbose=verbose, type='error')

        # loop across registration steps
        for i_step in range(1, len(paramreg.steps)):
            sct.printv('\nEstimate transformation for step #' + str(i_step) + '...', verbose)
            # identify which is the src and dest
            if paramreg.steps[str(i_step)].type == 'im':
                src = ftmp_template
                dest = ftmp_data
                interp_step = 'linear'
            elif paramreg.steps[str(i_step)].type == 'seg':
                src = ftmp_template_seg
                dest = ftmp_seg
                interp_step = 'nn'
            else:
                sct.printv('ERROR: Wrong image type.', 1, 'error')
            # apply transformation from previous step, to use as new src for registration
            sct_apply_transfo.main(args=[
                '-i', src,
                '-d', dest,
                '-w', warp_forward,
                '-o', add_suffix(src, '_regStep' + str(i_step - 1)),
                '-x', interp_step])
            src = add_suffix(src, '_regStep' + str(i_step - 1))
            # register src --> dest
            # TODO: display param for debugging
            warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
            warp_forward.append(warp_forward_out)
            warp_inverse.insert(0, warp_inverse_out)

        # Concatenate transformations:
        sct.printv('\nConcatenate transformations: template --> subject...', verbose)
        sct_concat_transfo.main(args=[
            '-w', warp_forward,
            '-d', 'data.nii',
            '-o', 'warp_template2anat.nii.gz'])
        sct.printv('\nConcatenate transformations: subject --> template...', verbose)
        sct_concat_transfo.main(args=[
            '-w', warp_inverse,
            '-winv', ['template2subjectAffine.txt'],
            '-d', 'template.nii',
            '-o', 'warp_anat2template.nii.gz'])

    # Apply warping fields to anat and template
    sct.run(['sct_apply_transfo', '-i', 'template.nii', '-o', 'template2anat.nii.gz', '-d', 'data.nii', '-w', 'warp_template2anat.nii.gz', '-crop', '1'], verbose)
    sct.run(['sct_apply_transfo', '-i', 'data.nii', '-o', 'anat2template.nii.gz', '-d', 'template.nii', '-w', 'warp_anat2template.nii.gz', '-crop', '1'], verbose)

    # come back
    os.chdir(curdir)

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    fname_template2anat = os.path.join(path_output, 'template2anat' + ext_data)
    fname_anat2template = os.path.join(path_output, 'anat2template' + ext_data)
    sct.generate_output_file(os.path.join(path_tmp, "warp_template2anat.nii.gz"), os.path.join(path_output, "warp_template2anat.nii.gz"), verbose)
    sct.generate_output_file(os.path.join(path_tmp, "warp_anat2template.nii.gz"), os.path.join(path_output, "warp_anat2template.nii.gz"), verbose)
    sct.generate_output_file(os.path.join(path_tmp, "template2anat.nii.gz"), fname_template2anat, verbose)
    sct.generate_output_file(os.path.join(path_tmp, "anat2template.nii.gz"), fname_anat2template, verbose)
    if ref == 'template':
        # copy straightening files in case subsequent SCT functions need them
        sct.generate_output_file(os.path.join(path_tmp, "warp_curve2straight.nii.gz"), os.path.join(path_output, "warp_curve2straight.nii.gz"), verbose)
        sct.generate_output_file(os.path.join(path_tmp, "warp_straight2curve.nii.gz"), os.path.join(path_output, "warp_straight2curve.nii.gz"), verbose)
        sct.generate_output_file(os.path.join(path_tmp, "straight_ref.nii.gz"), os.path.join(path_output, "straight_ref.nii.gz"), verbose)

    # Delete temporary files
    if param.remove_temp_files:
        sct.printv('\nDelete temporary files...', verbose)
        sct.rmtree(path_tmp, verbose=verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's', verbose)

    qc_dataset = arguments.get("-qc-dataset", None)
    qc_subject = arguments.get("-qc-subject", None)
    if param.path_qc is not None:
        generate_qc(fname_data, fname_in2=fname_template2anat, fname_seg=fname_seg, args=args,
                    path_qc=os.path.abspath(param.path_qc), dataset=qc_dataset, subject=qc_subject,
                    process='sct_register_to_template')
    sct.display_viewer_syntax([fname_data, fname_template2anat], verbose=verbose)
    sct.display_viewer_syntax([fname_template, fname_anat2template], verbose=verbose)
def test_integrity(param_test):
    """
    Check the integrity of the registration result for one test case.

    The binary mask ``param_test.fname_gt`` and the segmentation passed with ``-s`` are warped
    in both directions using the warping fields written in ``param_test.path_output``.
    Each warped mask is compared with its counterpart using a 3D Dice coefficient.
    A Dice value not above ``param_test.dice_threshold`` marks the test as failed (status 99).

    :param param_test: test parameter object; its ``output``, ``status`` and ``results`` are updated in place.
    :return: the same ``param_test`` object.
    """
    # Test cases launched with the second default argument set have no integrity check.
    if param_test.args.startswith(default_args[1]):
        return param_test

    path_out = param_test.path_output
    fname_template2anat = os.path.join(path_out, 'test_template2anat.nii.gz')
    fname_anat2template = os.path.join(path_out, 'test_anat2template.nii.gz')

    # Warp the binary masks template --> anat, then anat --> template (nearest-neighbour keeps them binary).
    warp_jobs = (
        (param_test.fname_gt, param_test.dict_args_with_path['-s'],
         'warp_template2anat.nii.gz', fname_template2anat),
        (param_test.dict_args_with_path['-s'], param_test.fname_gt,
         'warp_anat2template.nii.gz', fname_anat2template),
    )
    for fname_moving, fname_ref, warp_name, fname_result in warp_jobs:
        sct_apply_transfo.main(args=[
            '-i', fname_moving,
            '-d', fname_ref,
            '-w', os.path.join(path_out, warp_name),
            '-o', fname_result,
            '-x', 'nn',
            '-v', '0',
        ])

    def _report(header, dice):
        """Append the Dice value and a PASSED/FAILED verdict to the test output."""
        param_test.output += header + str(dice)
        if not dice > param_test.dice_threshold:
            param_test.status = 99
            param_test.output += '\n--> FAILED'
        else:
            param_test.output += '\n--> PASSED'

    # Template segmentation warped to anat vs. segmentation from anat.
    seg_native = Image(param_test.dict_args_with_path['-s'])
    template_seg_in_anat = Image(fname_template2anat)
    dice_template2anat = msct_image.compute_dice(
        seg_native, template_seg_in_anat, mode='3d', zboundaries=True)
    _report('Dice[seg,template_seg_reg]: ', dice_template2anat)

    # Anat segmentation warped to template vs. template segmentation.
    seg_in_template = Image(fname_anat2template)
    template_seg_native = Image(param_test.fname_gt)
    dice_anat2template = msct_image.compute_dice(
        seg_in_template, template_seg_native, mode='3d', zboundaries=True)
    _report('\n\nDice[seg_reg,template_seg]: ', dice_anat2template)

    # update Panda structure
    param_test.results['dice_template2anat'] = dice_template2anat
    param_test.results['dice_anat2template'] = dice_anat2template

    return param_test
Пример #9
0
def register(param, file_src, file_dest, file_mat, file_out, im_mask=None):
    """
    Register two images by estimating slice-wise Tx and Ty transformations, which are regularized along Z. This function
    uses ANTs' isct_antsSliceRegularizedRegistration.

    When the third axis of ``file_src`` is oriented L or R (sagittal slice), a 2D affine registration is run with
    isct_antsRegistration instead.

    :param param: parameter object. Reads ``metric``, ``todo``, ``sampling``, ``gradStep``, ``iter``, ``smooth``,
        ``poly``, ``interp`` and ``verbose``. When ``todo == 'apply'`` it also reads ``suffix_mat``.
    :param file_src: path of the moving image.
    :param file_dest: path of the fixed (destination) image.
    :param file_mat: output transformation prefix passed to ANTs (estimate modes). In 'apply' mode it is the
        transformation to apply, with ``param.suffix_mat`` appended.
    :param file_out: path of the registered output image.
    :param im_mask: Image of mask, could be 2D or 3D
    :return: 1 if no output file was produced (failed registration), 0 otherwise.
    """

    # TODO: deal with mask

    # initialization
    failed_transfo = 0  # by default, failed matrix is 0 (i.e., no failure)
    do_registration = True

    # get metric radius (if MeanSquares, CC) or nb bins (if MI)
    metric_radius = '16' if param.metric == 'MI' else '4'
    file_out_concat = file_out

    # TODO: pass argument to use antsReg instead of opening Image each time
    im_data = Image(file_src)
    # A third axis along L/R means the input is a sagittal slice -> 2D registration mode
    is_sagittal = im_data.orientation[2] in 'LR'

    # register file_src to file_dest
    if param.todo == 'estimate' or param.todo == 'estimate_and_apply':
        # If orientation is sagittal, use antsRegistration in 2D mode
        # Note: the parameter --restrict-deformation is irrelevant with affine transfo

        if param.sampling == 'None':
            # 'None' sampling means 'fully dense' sampling
            # see https://github.com/ANTsX/ANTs/wiki/antsRegistration-reproducibility-issues
            sampling = param.sampling
        else:
            # param.sampling should be a float in [0,1], and means the
            # samplingPercentage that chooses a subset of points to
            # estimate from. We always use 'Regular' (evenly-spaced)
            # mode, though antsRegistration offers 'Random' as well.
            # Be aware: even 'Regular' is not fully deterministic:
            # > Regular includes a random perturbation on the grid sampling
            # - https://github.com/ANTsX/ANTs/issues/976#issuecomment-602313884
            sampling = 'Regular,' + param.sampling

        metric = (param.metric + '[' + file_dest + ',' + file_src + ',1,' +
                  metric_radius + ',' + sampling + ']')
        output = '[' + file_mat + ',' + file_out_concat + ']'

        if is_sagittal:
            cmd = [
                'isct_antsRegistration', '-d', '2',
                '--transform', 'Affine[%s]' % param.gradStep,
                '--metric', metric,
                '--convergence', param.iter,
                '--shrink-factors', '1',
                '--smoothing-sigmas', param.smooth,
                '--verbose', '1',
                '--output', output,
            ]
            cmd += get_interpolation('isct_antsRegistration', param.interp)
            if im_mask is not None:
                # if user specified a mask, make sure there are non-null voxels in the image before running the registration
                if np.count_nonzero(im_mask.data):
                    cmd += ['--masks', im_mask.absolutepath]
                else:
                    # Mask only contains zeros. Copying the image instead of estimating registration.
                    copy(file_src, file_out_concat, verbose=0)
                    do_registration = False
                    # TODO: create affine mat file with identity, in case used by -g 2
        # 3D mode
        else:
            cmd = [
                'isct_antsSliceRegularizedRegistration',
                '--polydegree', param.poly,
                '--transform', 'Translation[%s]' % param.gradStep,
                '--metric', metric,
                '--iterations', param.iter,
                '--shrinkFactors', '1',
                '--smoothingSigmas', param.smooth,
                '--verbose', '1',
                '--output', output,
            ]
            cmd += get_interpolation('isct_antsSliceRegularizedRegistration',
                                     param.interp)
            if im_mask is not None:
                cmd += ['--mask', im_mask.absolutepath]
        # run command
        if do_registration:
            # reducing the number of CPU used for moco (see issue #201 and #2642)
            env = {**os.environ, "ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS": "1"}
            run_proc(cmd,
                     verbose=1 if param.verbose == 2 else 0,
                     env=env,
                     is_sct_binary=True)

    elif param.todo == 'apply':
        sct_apply_transfo.main(args=[
            '-i', file_src, '-d', file_dest, '-w', file_mat + param.suffix_mat,
            '-o', file_out_concat, '-x', param.interp, '-v', '0'
        ])

    # check if output file exists
    # Note (from JCA): In the past, i've tried to catch non-zero output from ANTs function (via the 'status' variable),
    # but in some OSs, the function can fail while outputing zero. So as a pragmatic approach, I decided to go with
    # the "output file checking" approach, which is 100% sensitive.
    if not os.path.isfile(file_out_concat):
        printv(
            'WARNING in ' + os.path.basename(__file__) +
            ': No output. Maybe related to improper calculation of '
            'mutual information. Either the mask you provided is '
            'too small, or the subject moved a lot. If you see too '
            'many messages like this try with a bigger mask. '
            'Using previous transformation for this volume (if it '
            'exists).', param.verbose, 'warning')
        failed_transfo = 1

    # If sagittal, copy header (because ANTs screws it) and add singleton in 3rd dimension (for z-concatenation).
    # Skipped when registration failed: there is no output file to open, and the caller relies on the
    # returned flag to fall back to the previous transformation instead of crashing here.
    if is_sagittal and do_registration and not failed_transfo:
        im_out = Image(file_out_concat)
        im_out.header = im_data.header
        im_out.data = np.expand_dims(im_out.data, 2)
        im_out.save(file_out, verbose=0)

    # return status of failure
    return failed_transfo
Пример #10
0
def moco(param):
    """
    Main function that performs motion correction.

    Each volume of the input time series is registered to a target image with
    ``register()``. If the scan is sagittal (``param.is_sagittal``), data, target and
    mask are first split along dim 2 and every slice is processed independently;
    otherwise the whole 3D volume is handled as a single "Z" entry.

    :param param: motion-correction parameters. Attributes read here include
        ``file_data``, ``file_target``, ``mat_moco`` (output folder for transforms),
        ``todo`` ('estimate', 'apply' or 'estimate_and_apply'), ``suffix``, ``verbose``,
        ``fname_mask``, ``is_sagittal``, ``iterAvg``, ``interp`` and, when set,
        ``suffix_mat``.
    :return: ``(file_mat, im_out)``. ``file_mat`` is an object array of shape (nz, nt)
        (nz = 1 when not sagittal) holding the transformation file prefixes. ``im_out``
        is the motion-corrected Image, or None when ``todo == 'estimate'`` (no corrected
        series is assembled in that mode).
    """
    # retrieve parameters
    file_data = param.file_data
    file_target = param.file_target
    folder_mat = param.mat_moco  # output folder of mat file
    todo = param.todo
    suffix = param.suffix
    verbose = param.verbose
    # register() reads transforms back as file_mat + param.suffix_mat in 'apply' mode, so the
    # failed-volume fallback below must copy/apply that same file. Fall back to the
    # historical hard-coded suffix when param does not define it.
    suffix_mat = getattr(param, 'suffix_mat', None) or 'Warp.nii.gz'

    # other parameters
    file_mask = 'mask.nii'

    printv('\nInput parameters:', param.verbose)
    printv('  Input file ............ ' + file_data, param.verbose)
    printv('  Reference file ........ ' + file_target, param.verbose)
    printv('  Polynomial degree ..... ' + param.poly, param.verbose)
    printv('  Smoothing kernel ...... ' + param.smooth, param.verbose)
    printv('  Gradient step ......... ' + param.gradStep, param.verbose)
    printv('  Metric ................ ' + param.metric, param.verbose)
    printv('  Sampling .............. ' + param.sampling, param.verbose)
    printv('  Todo .................. ' + todo, param.verbose)
    printv('  Mask  ................. ' + param.fname_mask, param.verbose)
    printv('  Output mat folder ..... ' + folder_mat, param.verbose)

    # create the output folder for transformation files (no error if it already exists)
    os.makedirs(folder_mat, exist_ok=True)

    # Get size of data
    printv('\nData dimensions:', verbose)
    im_data = Image(param.file_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    printv(
        ('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt)),
        verbose)

    # copy file_target to a temporary file
    printv('\nCopy file_target to a temporary file...', verbose)
    file_target = "target.nii.gz"
    convert(param.file_target, file_target, verbose=0)

    # Check if user specified a mask
    if param.fname_mask != '':
        # Check if this mask is soft (i.e., non-binary, such as a Gaussian mask)
        im_mask = Image(param.fname_mask)
        if not np.array_equal(im_mask.data, im_mask.data.astype(bool)):
            # If it is a soft mask, multiply the target by the soft mask.
            im = Image(file_target)
            im_masked = im.copy()
            im_masked.data = im.data * im_mask.data
            im_masked.save(verbose=0)  # silence warning about file overwriting

    # If scan is sagittal, split src and target along Z (slice)
    if param.is_sagittal:
        dim_sag = 2  # TODO: find it
        # z-split data (time series)
        file_data_splitZ = []
        for im_z in split_data(im_data, dim=dim_sag, squeeze_data=False):
            im_z.save(verbose=0)
            file_data_splitZ.append(im_z.absolutepath)
        # z-split target
        file_target_splitZ = []
        for im_targetz in split_data(Image(file_target), dim=dim_sag, squeeze_data=False):
            im_targetz.save(verbose=0)
            file_target_splitZ.append(im_targetz.absolutepath)
        # z-split mask (if exists)
        # NOTE(review): 'mask.nii' is only created by this function in the axial branch;
        # here it is presumably provided by the caller -- confirm.
        if param.fname_mask != '':
            im_maskz_list = split_data(Image(file_mask), dim=dim_sag, squeeze_data=False)
            for im_maskz in im_maskz_list:
                im_maskz.save(verbose=0)
        # initialize file list for output matrices
        file_mat = np.empty((nz, nt), dtype=object)

    # axial orientation
    else:
        file_data_splitZ = [file_data]  # TODO: make it absolute like above
        file_target_splitZ = [file_target]  # TODO: make it absolute like above
        # initialize file list for output matrices
        file_mat = np.empty((1, nt), dtype=object)

        # deal with mask
        if param.fname_mask != '':
            convert(param.fname_mask, file_mask, squeeze_data=False, verbose=0)
            im_maskz_list = [Image(file_mask)]  # use a list with single element

    # Loop across file list, where each file is either a 2D volume (if sagittal) or a 3D volume (otherwise)
    file_data_splitZ_moco = []
    im_out = None  # stays None when todo == 'estimate' (no corrected series is merged)
    printv(
        '\nRegister. Loop across Z (note: there is only one Z if orientation is axial)'
    )
    for iz, file in enumerate(file_data_splitZ):
        # Split data along T dimension
        file_data_splitZ_splitT = []
        for im_zt in split_data(Image(file), dim=3):
            im_zt.save(verbose=0)
            file_data_splitZ_splitT.append(im_zt.absolutepath)

        # Motion correction: initialization
        file_data_splitZ_splitT_moco = []
        failed_transfo = [0] * nt

        # Motion correction: Loop across T
        for it in sct_progress_bar(range(nt),
                                   unit='iter',
                                   unit_scale=False,
                                   desc="Z=" + str(iz) + "/" +
                                   str(len(file_data_splitZ) - 1),
                                   ascii=False,
                                   ncols=80):

            # prefix of the transformation file for this (z, t)
            file_mat[iz][it] = os.path.join(
                folder_mat,
                "mat.Z") + str(iz).zfill(4) + 'T' + str(it).zfill(4)
            file_data_splitZ_splitT_moco.append(
                add_suffix(file_data_splitZ_splitT[it], '_moco'))
            # deal with masking (except in the 'apply' case, where masking is irrelevant)
            input_mask = None
            if param.fname_mask != '' and param.todo != 'apply':
                # Check if mask is binary
                if np.array_equal(im_maskz_list[iz].data,
                                  im_maskz_list[iz].data.astype(bool)):
                    # If it is, pass this mask into register() to be used
                    input_mask = im_maskz_list[iz]
                else:
                    # If not, do not pass this mask into register() because ANTs cannot handle non-binary masks.
                    #  Instead, multiply the input data by the Gaussian mask.
                    im = Image(file_data_splitZ_splitT[it])
                    im_masked = im.copy()
                    im_masked.data = im.data * im_maskz_list[iz].data
                    im_masked.save(verbose=0)  # silence warning about file overwriting

            # run 3D registration
            failed_transfo[it] = register(param,
                                          file_data_splitZ_splitT[it],
                                          file_target_splitZ[iz],
                                          file_mat[iz][it],
                                          file_data_splitZ_splitT_moco[it],
                                          im_mask=input_mask)

            # average registered volume with target image
            # N.B. use weighted averaging: (target * nb_it + moco) / (nb_it + 1)
            if (param.iterAvg and it < 10 and failed_transfo[it] == 0
                    and param.todo != 'apply'):
                im_targetz = Image(file_target_splitZ[iz])
                data_mocoz = Image(file_data_splitZ_splitT_moco[it]).data
                im_targetz.data = (im_targetz.data * (it + 1) +
                                   data_mocoz) / (it + 2)
                im_targetz.save(verbose=0)

        # Replace each failed transformation with the closest (in time) good one.
        # min() returns the first minimum, i.e. the earlier good volume on ties.
        failed_idx = [t for t, flag in enumerate(failed_transfo) if flag == 1]
        good_idx = [t for t, flag in enumerate(failed_transfo) if flag == 0]
        for t_bad in failed_idx:
            if not good_idx:
                # exit program if no transformation exists.
                printv(
                    '\nERROR in ' + os.path.basename(__file__) +
                    ': No good transformation exist. Exit program.\n', verbose,
                    'error')
                sys.exit(2)
            t_good = min(good_idx, key=lambda t: abs(t - t_bad))
            printv(
                '  transfo #' + str(t_bad) + ' --> use transfo #' +
                str(t_good), verbose)
            # copy transformation
            copy(file_mat[iz][t_good] + suffix_mat,
                 file_mat[iz][t_bad] + suffix_mat)
            # apply transformation; the destination is the target of the *current* Z
            # (identical to file_target when not sagittal)
            sct_apply_transfo.main(args=[
                '-i', file_data_splitZ_splitT[t_bad], '-d', file_target_splitZ[iz],
                '-w', file_mat[iz][t_bad] + suffix_mat, '-o',
                file_data_splitZ_splitT_moco[t_bad], '-x', param.interp
            ])

        # Merge data along T
        file_data_splitZ_moco.append(add_suffix(file, suffix))
        if todo != 'estimate':
            im_out = concat_data(file_data_splitZ_splitT_moco, 3)
            im_out.absolutepath = file_data_splitZ_moco[iz]
            im_out.save(verbose=0)

    # If sagittal, merge along Z (only possible when the per-Z series were written above)
    if param.is_sagittal and todo != 'estimate':
        # TODO: im_out.dim is incorrect: Z value is one
        im_out = concat_data(file_data_splitZ_moco, 2)
        dirname, basename, ext = extract_fname(file_data)
        path_out = os.path.join(dirname, basename + suffix + ext)
        im_out.absolutepath = path_out
        im_out.save(verbose=0)

    return file_mat, im_out
Пример #11
0
def apply_transfo_sct(i, d, w, o):
    """
    Warp image ``i`` into the space of ``d`` using transformation ``w``, writing ``o``.

    Thin wrapper around ``sct_apply_transfo.main`` that forces nearest-neighbour
    interpolation (``-x nn``) and verbosity 0.
    """
    cli_args = [
        '-i', i,
        '-d', d,
        '-w', w,
        '-o', o,
        '-x', 'nn',
        '-v', '0',
    ]
    sct_apply_transfo.main(args=cli_args)
def main(args=None):
    """
    Register a source image onto a destination image (command-line entry point).

    Parses ``args``, copies the inputs into a temporary folder, estimates each
    registration step of ``paramreg`` in turn, concatenates the per-step warping fields
    into src->dest and dest->src fields, applies them, then writes the registered images
    and warping fields to the output folder. Optionally generates QC and prints viewer
    syntax.

    :param args: list of command-line arguments; ``None`` means ``sys.argv[1:]``.
    """
    if args is None:
        args = sys.argv[1:]

    # initialize parameters
    param = Param()

    # Initialization
    fname_output = ''
    path_out = ''
    fname_src_seg = ''
    fname_dest_seg = ''
    fname_src_label = ''
    fname_dest_label = ''
    paramreg_user = []  # registration steps given with -param (empty if not provided)
    generate_warpinv = 1

    start_time = time.time()

    # get default registration parameters
    step0 = Paramreg(
        step='0',
        type='im',
        algo='syn',
        metric='MI',
        iter='0',
        shrink='1',
        smooth='0',
        gradStep='0.5',
        slicewise='0',
        dof='Tx_Ty_Tz_Rx_Ry_Rz')  # only used to put src into dest space
    step1 = Paramreg(step='1', type='im')
    paramreg = ParamregMultiStep([step0, step1])

    parser = get_parser(paramreg=paramreg)

    arguments = parser.parse(args)

    # get arguments
    fname_src = arguments['-i']
    fname_dest = arguments['-d']
    if '-iseg' in arguments:
        fname_src_seg = arguments['-iseg']
    if '-dseg' in arguments:
        fname_dest_seg = arguments['-dseg']
    if '-ilabel' in arguments:
        fname_src_label = arguments['-ilabel']
    if '-dlabel' in arguments:
        fname_dest_label = arguments['-dlabel']
    if '-o' in arguments:
        fname_output = arguments['-o']
    if '-ofolder' in arguments:
        path_out = arguments['-ofolder']
    if '-owarp' in arguments:
        fname_output_warp = arguments['-owarp']
    else:
        fname_output_warp = ''
    if '-initwarp' in arguments:
        fname_initwarp = os.path.abspath(arguments['-initwarp'])
    else:
        fname_initwarp = ''
    if '-initwarpinv' in arguments:
        fname_initwarpinv = os.path.abspath(arguments['-initwarpinv'])
    else:
        fname_initwarpinv = ''
    if '-m' in arguments:
        fname_mask = arguments['-m']
    else:
        fname_mask = ''
    padding = arguments['-z']
    if "-param" in arguments:
        paramreg_user = arguments['-param']
        # update registration parameters
        for paramStep in paramreg_user:
            paramreg.addStep(paramStep)
    path_qc = arguments.get("-qc", None)
    qc_dataset = arguments.get("-qc-dataset", None)
    qc_subject = arguments.get("-qc-subject", None)

    identity = int(arguments['-identity'])
    interp = arguments['-x']
    remove_temp_files = int(arguments['-r'])
    verbose = int(arguments.get('-v'))
    sct.init_sct(log_level=verbose, update=True)  # Update log level

    sct.printv('\nInput parameters:')
    sct.printv('  Source .............. ' + fname_src)
    sct.printv('  Destination ......... ' + fname_dest)
    sct.printv('  Init transfo ........ ' + fname_initwarp)
    sct.printv('  Mask ................ ' + fname_mask)
    sct.printv('  Output name ......... ' + fname_output)
    sct.printv('  Remove temp files ... ' + str(remove_temp_files))
    sct.printv('  Verbose ............. ' + str(verbose))

    # update param
    param.verbose = verbose
    param.padding = padding
    param.fname_mask = fname_mask
    param.remove_temp_files = remove_temp_files

    # Get if input is 3D
    sct.printv('\nCheck if input data are 3D...', verbose)
    sct.check_if_3d(fname_src)
    sct.check_if_3d(fname_dest)

    # Check if user selected type=seg, but did not input segmentation data
    if any('type=seg' in step for step in paramreg_user):
        if fname_src_seg == '' or fname_dest_seg == '':
            sct.printv(
                '\nERROR: if you select type=seg you must specify -iseg and -dseg flags.\n',
                1, 'error')

    # Extract path, file and extension
    path_src, file_src, ext_src = sct.extract_fname(fname_src)
    path_dest, file_dest, ext_dest = sct.extract_fname(fname_dest)

    # check if source and destination images have the same name (related to issue #373)
    # If so, change names to avoid conflict of result files and warns the user
    suffix_src, suffix_dest = '_reg', '_reg'
    if file_src == file_dest:
        suffix_src, suffix_dest = '_src_reg', '_dest_reg'

    # define output folder and file name
    if fname_output == '':
        path_out = '' if not path_out else path_out  # output in user's current directory
        file_out = file_src + suffix_src
        file_out_inv = file_dest + suffix_dest
        ext_out = ext_src
    else:
        path, file_out, ext_out = sct.extract_fname(fname_output)
        path_out = path if not path_out else path_out
        file_out_inv = file_out + '_inv'

    # create temporary folder
    path_tmp = sct.tmp_create()

    sct.printv('\nCopying input data to tmp folder and convert to nii...',
               verbose)
    Image(fname_src).save(os.path.join(path_tmp, "src.nii"))
    Image(fname_dest).save(os.path.join(path_tmp, "dest.nii"))

    if fname_src_seg:
        Image(fname_src_seg).save(os.path.join(path_tmp, "src_seg.nii"))

    if fname_dest_seg:
        Image(fname_dest_seg).save(os.path.join(path_tmp, "dest_seg.nii"))

    # Copy each label image independently: the destination label is reoriented below
    # whenever -dlabel is given, so it must not depend on -ilabel being present.
    if fname_src_label:
        Image(fname_src_label).save(os.path.join(path_tmp, "src_label.nii"))

    if fname_dest_label:
        Image(fname_dest_label).save(os.path.join(path_tmp, "dest_label.nii"))

    if fname_mask != '':
        Image(fname_mask).save(os.path.join(path_tmp, "mask.nii.gz"))

    # go to tmp folder; always come back, even if a step fails
    curdir = os.getcwd()
    os.chdir(path_tmp)
    try:
        # reorient destination to RPI
        Image('dest.nii').change_orientation("RPI").save('dest_RPI.nii')
        if fname_dest_seg:
            Image('dest_seg.nii').change_orientation("RPI").save(
                'dest_seg_RPI.nii')
        if fname_dest_label:
            Image('dest_label.nii').change_orientation("RPI").save(
                'dest_label_RPI.nii')

        if identity:
            # overwrite paramreg and only do one identity transformation
            step0 = Paramreg(step='0',
                             type='im',
                             algo='syn',
                             metric='MI',
                             iter='0',
                             shrink='1',
                             smooth='0',
                             gradStep='0.5')
            paramreg = ParamregMultiStep([step0])

        # initialize list of warping fields
        warp_forward = []
        warp_forward_winv = []
        warp_inverse = []
        warp_inverse_winv = []

        # initial warping is specified, update list of warping fields and skip step=0
        if fname_initwarp:
            sct.printv('\nSkip step=0 and replace with initial transformations: ',
                       param.verbose)
            sct.printv('  ' + fname_initwarp, param.verbose)
            warp_forward = [fname_initwarp]
            start_step = 1
            if fname_initwarpinv:
                warp_inverse = [fname_initwarpinv]
            else:
                sct.printv(
                    '\nWARNING: No initial inverse warping field was specified, therefore the inverse warping field '
                    'will NOT be generated.', param.verbose, 'warning')
                generate_warpinv = 0
        else:
            start_step = 0

        # loop across registration steps
        for i_step in range(start_step, len(paramreg.steps)):
            sct.printv('\n--\nESTIMATE TRANSFORMATION FOR STEP #' + str(i_step),
                       param.verbose)
            # identify which is the src and dest
            if paramreg.steps[str(i_step)].type == 'im':
                src = 'src.nii'
                dest = 'dest_RPI.nii'
                interp_step = 'spline'
            elif paramreg.steps[str(i_step)].type == 'seg':
                src = 'src_seg.nii'
                dest = 'dest_seg_RPI.nii'
                interp_step = 'nn'
            elif paramreg.steps[str(i_step)].type == 'label':
                src = 'src_label.nii'
                dest = 'dest_label_RPI.nii'
                interp_step = 'nn'
            else:
                sct.printv('ERROR: Wrong image type.', 1, 'error')
            # if step>0, apply warp_forward_concat to the src image to be used
            if i_step > 0:
                sct.printv('\nApply transformation from previous step',
                           param.verbose)
                sct_apply_transfo.main(args=[
                    '-i', src, '-d', dest, '-w', warp_forward, '-o',
                    sct.add_suffix(src, '_reg'), '-x', interp_step
                ])
                src = sct.add_suffix(src, '_reg')
            # register src --> dest
            warp_forward_out, warp_inverse_out = register(src, dest, paramreg,
                                                          param, str(i_step))
            # deal with transformations with "-" as prefix. They should be inverted with calling sct_concat_transfo.
            if warp_forward_out[0] == "-":
                warp_forward_out = warp_forward_out[1:]
                warp_forward_winv.append(warp_forward_out)
            if warp_inverse_out[0] == "-":
                warp_inverse_out = warp_inverse_out[1:]
                warp_inverse_winv.append(warp_inverse_out)
            # update list of forward/inverse transformations
            warp_forward.append(warp_forward_out)
            warp_inverse.insert(0, warp_inverse_out)

        # Concatenate transformations
        sct.printv('\nConcatenate transformations...', verbose)
        sct_concat_transfo.main(args=[
            '-w', warp_forward, '-winv', warp_forward_winv, '-d', 'dest.nii', '-o',
            'warp_src2dest.nii.gz'
        ])
        sct_concat_transfo.main(args=[
            '-w', warp_inverse, '-winv', warp_inverse_winv, '-d', 'src.nii', '-o',
            'warp_dest2src.nii.gz'
        ])

        # Apply warping field to src data
        sct.printv('\nApply transfo source --> dest...', verbose)
        sct_apply_transfo.main(args=[
            '-i', 'src.nii', '-d', 'dest.nii', '-w', 'warp_src2dest.nii.gz', '-o',
            'src_reg.nii', '-x', interp
        ])
        sct.printv('\nApply transfo dest --> source...', verbose)
        sct_apply_transfo.main(args=[
            '-i', 'dest.nii', '-d', 'src.nii', '-w', 'warp_dest2src.nii.gz', '-o',
            'dest_reg.nii', '-x', interp
        ])
    finally:
        # come back to the caller's working directory
        os.chdir(curdir)

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    # generate: src_reg
    fname_src2dest = sct.generate_output_file(
        os.path.join(path_tmp, "src_reg.nii"),
        os.path.join(path_out, file_out + ext_out), verbose)
    # generate: forward warping field
    if fname_output_warp == '':
        fname_output_warp = os.path.join(
            path_out, 'warp_' + file_src + '2' + file_dest + '.nii.gz')
    sct.generate_output_file(os.path.join(path_tmp, "warp_src2dest.nii.gz"),
                             fname_output_warp, verbose)
    if generate_warpinv:
        # generate: dest_reg
        fname_dest2src = sct.generate_output_file(
            os.path.join(path_tmp, "dest_reg.nii"),
            os.path.join(path_out, file_out_inv + ext_dest), verbose)
        # generate: inverse warping field
        sct.generate_output_file(
            os.path.join(path_tmp, "warp_dest2src.nii.gz"),
            os.path.join(path_out,
                         'warp_' + file_dest + '2' + file_src + '.nii.gz'),
            verbose)

    # Delete temporary files
    if remove_temp_files:
        sct.printv('\nRemove temporary files...', verbose)
        sct.rmtree(path_tmp, verbose=verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv(
        '\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's',
        verbose)

    if path_qc is not None:
        if fname_dest_seg:
            generate_qc(fname_src2dest,
                        fname_in2=fname_dest,
                        fname_seg=fname_dest_seg,
                        args=args,
                        path_qc=os.path.abspath(path_qc),
                        dataset=qc_dataset,
                        subject=qc_subject,
                        process='sct_register_multimodal')
        else:
            sct.printv(
                'WARNING: Cannot generate QC because it requires destination segmentation.',
                1, 'warning')

    if generate_warpinv:
        sct.display_viewer_syntax([fname_src, fname_dest2src], verbose=verbose)
    sct.display_viewer_syntax([fname_dest, fname_src2dest], verbose=verbose)
Пример #13
0
def moco(param):
    """
    Perform motion correction of a time series against a target image.

    Each volume is registered to the target with ``register()``. If the scan is
    sagittal (``param.is_sagittal``), data, target and mask are first split along dim 2
    and every slice is processed independently; otherwise the whole 3D volume is
    handled as a single "Z" entry.

    :param param: motion-correction parameters. Attributes read here include
        ``file_data``, ``file_target``, ``mat_moco`` (output folder for transforms),
        ``todo`` ('estimate', 'apply' or 'estimate_and_apply'), ``suffix``, ``verbose``,
        ``fname_mask``, ``is_sagittal``, ``iterAvg``, ``interp`` and, when set,
        ``suffix_mat``.
    :return: object array of shape (nz, nt) (nz = 1 when not sagittal) holding the
        transformation file prefixes.
    """
    # retrieve parameters
    file_data = param.file_data
    file_target = param.file_target
    folder_mat = param.mat_moco  # output folder of mat file
    todo = param.todo
    suffix = param.suffix
    verbose = param.verbose
    # Suffix of the transformation file written for each file_mat prefix. Use the value
    # on param when it is defined (so the failed-volume fallback copies the same file that
    # register() uses); otherwise keep the historical default.
    suffix_mat = getattr(param, 'suffix_mat', None) or 'Warp.nii.gz'

    # other parameters
    file_mask = 'mask.nii'

    sct.printv('\nInput parameters:', param.verbose)
    sct.printv('  Input file ............' + file_data, param.verbose)
    sct.printv('  Reference file ........' + file_target, param.verbose)
    sct.printv('  Polynomial degree .....' + param.poly, param.verbose)
    sct.printv('  Smoothing kernel ......' + param.smooth, param.verbose)
    sct.printv('  Gradient step .........' + param.gradStep, param.verbose)
    sct.printv('  Metric ................' + param.metric, param.verbose)
    sct.printv('  Sampling ..............' + param.sampling, param.verbose)
    sct.printv('  Todo ..................' + todo, param.verbose)
    sct.printv('  Mask  .................' + param.fname_mask, param.verbose)
    sct.printv('  Output mat folder .....' + folder_mat, param.verbose)

    # create folder for mat files
    sct.create_folder(folder_mat)

    # Get size of data
    sct.printv('\nData dimensions:', verbose)
    im_data = Image(param.file_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    sct.printv(('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt)), verbose)

    # copy file_target to a temporary file
    sct.printv('\nCopy file_target to a temporary file...', verbose)
    file_target = "target.nii.gz"
    convert(param.file_target, file_target)

    # If scan is sagittal, split src and target along Z (slice)
    if param.is_sagittal:
        dim_sag = 2  # TODO: find it
        # z-split data (time series)
        file_data_splitZ = []
        for im_z in split_data(im_data, dim=dim_sag, squeeze_data=False):
            im_z.save()
            file_data_splitZ.append(im_z.absolutepath)
        # z-split target
        file_target_splitZ = []
        for im_targetz in split_data(Image(file_target), dim=dim_sag, squeeze_data=False):
            im_targetz.save()
            file_target_splitZ.append(im_targetz.absolutepath)
        # z-split mask (if exists)
        # NOTE(review): 'mask.nii' is only created by this function in the axial branch;
        # here it is presumably provided by the caller -- confirm.
        if param.fname_mask != '':
            im_maskz_list = split_data(Image(file_mask), dim=dim_sag, squeeze_data=False)
            for im_maskz in im_maskz_list:
                im_maskz.save()
        # initialize file list for output matrices
        file_mat = np.empty((nz, nt), dtype=object)

    # axial orientation
    else:
        file_data_splitZ = [file_data]  # TODO: make it absolute like above
        file_target_splitZ = [file_target]  # TODO: make it absolute like above
        # initialize file list for output matrices
        file_mat = np.empty((1, nt), dtype=object)

        # deal with mask
        if param.fname_mask != '':
            convert(param.fname_mask, file_mask, squeeze_data=False)
            im_maskz_list = [Image(file_mask)]  # use a list with single element

    # Loop across file list, where each file is either a 2D volume (if sagittal) or a 3D volume (otherwise)
    file_data_splitZ_moco = []
    sct.printv('\nRegister. Loop across Z (note: there is only one Z if orientation is axial')
    for iz, file in enumerate(file_data_splitZ):
        # Split data along T dimension
        file_data_splitZ_splitT = []
        for im_zt in split_data(Image(file), dim=3):
            im_zt.save(verbose=0)
            file_data_splitZ_splitT.append(im_zt.absolutepath)

        # Motion correction: initialization
        file_data_splitZ_splitT_moco = []
        failed_transfo = [0] * nt

        # Motion correction: Loop across T
        for it in tqdm(range(nt), unit='iter', unit_scale=False,
                       desc="Z=" + str(iz) + "/" + str(len(file_data_splitZ) - 1), ascii=True, ncols=80):

            # prefix of the transformation file for this (z, t)
            file_mat[iz][it] = os.path.join(folder_mat, "mat.Z") + str(iz).zfill(4) + 'T' + str(it).zfill(4)
            file_data_splitZ_splitT_moco.append(sct.add_suffix(file_data_splitZ_splitT[it], '_moco'))
            # deal with masking
            input_mask = im_maskz_list[iz] if param.fname_mask != '' else None
            # run 3D registration
            failed_transfo[it] = register(param, file_data_splitZ_splitT[it], file_target_splitZ[iz], file_mat[iz][it],
                                          file_data_splitZ_splitT_moco[it], im_mask=input_mask)

            # average registered volume with target image
            # N.B. use weighted averaging: (target * nb_it + moco) / (nb_it + 1)
            if param.iterAvg and it < 10 and failed_transfo[it] == 0 and param.todo != 'apply':
                im_targetz = Image(file_target_splitZ[iz])
                data_mocoz = Image(file_data_splitZ_splitT_moco[it]).data
                im_targetz.data = (im_targetz.data * (it + 1) + data_mocoz) / (it + 2)
                im_targetz.save(verbose=0)

        # Replace each failed transformation with the closest (in time) good one.
        # min() returns the first minimum, i.e. the earlier good volume on ties.
        failed_idx = [t for t, flag in enumerate(failed_transfo) if flag == 1]
        good_idx = [t for t, flag in enumerate(failed_transfo) if flag == 0]
        for t_bad in failed_idx:
            if not good_idx:
                # exit program if no transformation exists.
                sct.printv('\nERROR in ' + os.path.basename(__file__) + ': No good transformation exist. Exit program.\n', verbose, 'error')
                sys.exit(2)
            t_good = min(good_idx, key=lambda t: abs(t - t_bad))
            sct.printv('  transfo #' + str(t_bad) + ' --> use transfo #' + str(t_good), verbose)
            # copy transformation
            sct.copy(file_mat[iz][t_good] + suffix_mat, file_mat[iz][t_bad] + suffix_mat)
            # apply transformation; the destination is the target of the *current* Z
            # (identical to file_target when not sagittal)
            sct_apply_transfo.main(args=['-i', file_data_splitZ_splitT[t_bad],
                                         '-d', file_target_splitZ[iz],
                                         '-w', file_mat[iz][t_bad] + suffix_mat,
                                         '-o', file_data_splitZ_splitT_moco[t_bad],
                                         '-x', param.interp])

        # Merge data along T
        file_data_splitZ_moco.append(sct.add_suffix(file, suffix))
        if todo != 'estimate':
            im_out = concat_data(file_data_splitZ_splitT_moco, 3)
            im_out.save(file_data_splitZ_moco[iz])

    # If sagittal, merge along Z (only possible when the per-Z series were written above)
    if param.is_sagittal and todo != 'estimate':
        im_out = concat_data(file_data_splitZ_moco, 2)
        dirname, basename, ext = sct.extract_fname(file_data)
        path_out = os.path.join(dirname, basename + suffix + ext)
        im_out.save(path_out)

    return file_mat
Пример #14
0
def register(param, file_src, file_dest, file_mat, file_out, im_mask=None):
    """
    Register two images by estimating slice-wise Tx and Ty transformations, which are regularized along Z. This function
    uses ANTs' isct_antsSliceRegularizedRegistration. If the third axis of the source image is along L/R (sagittal
    slice), isct_antsRegistration is used instead, in 2D mode with an affine transformation.

    Behavior depends on ``param.todo``:
      - 'estimate' or 'estimate_and_apply': run the ANTs registration; the transformation is written with prefix
        ``file_mat`` and the registered image to ``file_out``.
      - 'apply': apply the existing warping field ``file_mat + 'Warp.nii.gz'`` to ``file_src``.

    :param param: parameters object; reads metric, gradStep, sampling, iter, smooth, poly, interp, todo, verbose.
    :param file_src: path of the image to register (moving image).
    :param file_dest: path of the target image (fixed image).
    :param file_mat: prefix of the transformation file(s).
    :param file_out: path of the registered output image.
    :param im_mask: Image of mask, could be 2D or 3D. In 2D mode, a mask that only contains zeros disables the
        registration and ``file_src`` is copied to ``file_out`` instead.
    :return: 0 if the output image was produced, 1 otherwise (failed transformation).
    """

    # TODO: deal with mask

    # initialization
    failed_transfo = 0  # by default, failed matrix is 0 (i.e., no failure)
    do_registration = True

    # get metric radius (if MeanSquares, CC) or nb bins (if MI)
    if param.metric == 'MI':
        metric_radius = '16'
    else:
        metric_radius = '4'
    file_out_concat = file_out
    # ANTs metric and output specifications, shared by the 2D and 3D commands
    arg_metric = (param.metric + '[' + file_dest + ',' + file_src + ',1,' + metric_radius + ',Regular,' +
                  param.sampling + ']')
    arg_output = '[' + file_mat + ',' + file_out_concat + ']'

    im_data = Image(file_src)  # TODO: pass argument to use antsReg instead of opening Image each time
    # Third axis along L/R: the image is treated as a sagittal slice and registered in 2D mode
    is_sagittal = im_data.orientation[2] in 'LR'

    # register file_src to file_dest
    if param.todo in ('estimate', 'estimate_and_apply'):
        if is_sagittal:
            # 2D mode. Note: the parameter --restrict-deformation is irrelevant with affine transfo
            cmd = ['isct_antsRegistration',
                   '-d', '2',
                   '--transform', 'Affine[%s]' % param.gradStep,
                   '--metric', arg_metric,
                   '--convergence', param.iter,
                   '--shrink-factors', '1',
                   '--smoothing-sigmas', param.smooth,
                   '--verbose', '1',
                   '--output', arg_output]
            cmd += sct.get_interpolation('isct_antsRegistration', param.interp)
            if im_mask is not None:
                # if user specified a mask, make sure there are non-null voxels in the image before running the
                # registration
                if np.count_nonzero(im_mask.data):
                    cmd += ['--masks', im_mask.absolutepath]
                else:
                    # Mask only contains zeros. Copying the image instead of estimating registration.
                    sct.copy(file_src, file_out_concat, verbose=0)
                    do_registration = False
                    # TODO: create affine mat file with identity, in case used by -g 2
        else:
            # 3D mode
            cmd = ['isct_antsSliceRegularizedRegistration',
                   '--polydegree', param.poly,
                   '--transform', 'Translation[%s]' % param.gradStep,
                   '--metric', arg_metric,
                   '--iterations', param.iter,
                   '--shrinkFactors', '1',
                   '--smoothingSigmas', param.smooth,
                   '--verbose', '1',
                   '--output', arg_output]
            cmd += sct.get_interpolation('isct_antsSliceRegularizedRegistration', param.interp)
            if im_mask is not None:
                cmd += ['--mask', im_mask.absolutepath]

        # run command
        if do_registration:
            # reducing the number of CPU used for moco (see issue #201). The modified environment must be handed
            # to sct.run explicitly, otherwise the setting never reaches the spawned ANTs process.
            env = dict(os.environ)
            env["ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS"] = "1"
            sct.run(cmd, verbose=0, is_sct_binary=True, env=env)

    elif param.todo == 'apply':
        sct_apply_transfo.main(args=['-i', file_src,
                                     '-d', file_dest,
                                     '-w', file_mat + 'Warp.nii.gz',
                                     '-o', file_out_concat,
                                     '-x', param.interp,
                                     '-v', '0'])

    # check if output file exists
    if not os.path.isfile(file_out_concat):
        sct.printv('WARNING in ' + os.path.basename(__file__) + ': No output. Maybe related to improper calculation of '
                                                                'mutual information. Either the mask you provided is '
                                                                'too small, or the subject moved a lot. If you see too '
                                                                'many messages like this try with a bigger mask. '
                                                                'Using previous transformation for this volume (if it '
                                                                'exists).', param.verbose, 'warning')
        failed_transfo = 1

    # TODO: if sagittal, copy header (because ANTs screws it) and add singleton in 3rd dimension (for z-concatenation)
    # Skipped when no output was produced: there is no file to open.
    if is_sagittal and do_registration and not failed_transfo:
        im_out = Image(file_out_concat)
        im_out.header = im_data.header
        im_out.data = np.expand_dims(im_out.data, 2)
        im_out.save(file_out, verbose=0)

    # return status of failure
    return failed_transfo
def merge_images(list_fname_src, fname_dest, list_fname_warp, param):
    """
    Merge multiple source images onto destination space. All images are warped to the destination space and then added.
    To deal with overlap during merging (e.g. one voxel in destination image is shared with two input images), the
    resulting voxel is divided by the sum of the partial volume of each image. For example, if src(x,y,z)=1 is mapped to
    dest(i,j,k) with a partial volume of 0.5 (because destination voxel is bigger), then its value after linear interpolation
    will be 0.5. To account for partial volume, the resulting voxel will be: dest(i,j,k) = 0.5*0.5/0.5 = 0.5.
    Now, if two voxels overlap in the destination space, let's say: src(x,y,z)=1 and src2'(x',y',z')=1, then the
    resulting value will be: dest(i,j,k) = (0.5*0.5 + 0.5*0.5) / (0.5+0.5) = 0.5. So this function acts like a weighted
    average operator, only in destination voxels that share multiple source voxels.

    Parameters
    ----------
    list_fname_src : list of str
        Paths of the source images to merge.
    fname_dest : str
        Path of the destination image; its first three dimensions define the output grid.
    list_fname_warp : list of str
        Warping field for each source image, in the same order as ``list_fname_src``.
    param
        Parameters object; reads ``interp``, ``verbose``, ``almost_zero``, ``fname_out`` and ``rm_tmp``.

    Returns
    -------
    None. The merged image is written to ``param.fname_out``. Intermediate images are written to a temporary
    folder, which is removed at the end if ``param.rm_tmp`` is set.
    """
    import os

    # create temporary folder; all intermediate images go there instead of the current working directory
    path_tmp = sct.tmp_create()
    try:
        # get dimensions of destination file
        nii_dest = msct_image.Image(fname_dest)
        shape_dest = (nii_dest.dim[0], nii_dest.dim[1], nii_dest.dim[2])

        # initialize variables: one 3D volume per source image, stacked along the 4th axis
        data = np.zeros(shape_dest + (len(list_fname_src),))
        partial_volume = np.zeros(shape_dest + (len(list_fname_src),))

        # loop across files
        for i_file, fname_src in enumerate(list_fname_src):
            fname_warp = list_fname_warp[i_file]
            prefix = os.path.join(path_tmp, 'src_' + str(i_file))
            fname_src_template = prefix + '_template.nii.gz'
            fname_src_bin = prefix + '_native_bin.nii.gz'
            fname_src_partial_volume = prefix + '_template_partialVolume.nii.gz'

            # apply transformation src --> dest
            sct_apply_transfo.main(args=[
                '-i', fname_src,
                '-d', fname_dest,
                '-w', fname_warp,
                '-x', param.interp,
                '-o', fname_src_template,
                '-v', str(param.verbose)])

            # create binary mask from input file by assigning one to all non-null voxels
            sct_maths.main(args=[
                '-i', fname_src,
                '-bin', str(param.almost_zero),
                '-o', fname_src_bin])

            # apply transformation to binary mask to compute partial volume
            sct_apply_transfo.main(args=[
                '-i', fname_src_bin,
                '-d', fname_dest,
                '-w', fname_warp,
                '-x', param.interp,
                '-o', fname_src_partial_volume])

            # open data
            data[:, :, :, i_file] = msct_image.Image(fname_src_template).data
            partial_volume[:, :, :, i_file] = msct_image.Image(fname_src_partial_volume).data

        # merge files using partial volume information. Voxels not covered by any source give 0/0 = nan, which is
        # expected: silence the warning and convert nan to zeros.
        with np.errstate(divide='ignore', invalid='ignore'):
            data_merge = np.divide(np.sum(data * partial_volume, axis=3), np.sum(partial_volume, axis=3))
        data_merge = np.nan_to_num(data_merge)

        # write result in file
        nii_dest.data = data_merge
        nii_dest.save(param.fname_out)
    finally:
        # remove temporary folder (also on failure)
        if param.rm_tmp:
            sct.rmtree(path_tmp)