Пример #1
0
    def compute_texture(self):
        """
        Compute the GLCM texture metrics voxel by voxel and save one image per metric.

        A voxel is processed only when it belongs to the segmentation and when its square window
        (half-width = ``int(self.param_glcm.distance)``) lies entirely inside both the axial slice and the
        segmentation. For such a voxel, one grey-level co-occurrence matrix is computed per angle listed
        in ``self.param_glcm.angle`` (comma-separated), and every metric of ``self.metric_lst`` is written
        at the voxel position. A metric name is split on ``'_'``: the first field is the GLCM property and
        the third field is the angle key.

        The resulting images are saved next to the current working directory (file name of the input image
        plus ``_<metric>``) and their names are stored in ``self.fname_metric_lst``.
        """
        half_win = int(self.param_glcm.distance)
        printv('\nCompute texture metrics...', self.param.verbose, 'normal')

        # Load the input image and bring it to the extraction orientation when the two differ
        im_ref = Image(self.param.fname_im)
        if self.orientation_im != self.orientation_extraction:
            im_ref.change_orientation(self.orientation_extraction)

        # One float64 volume per metric, filled in below
        dct_metric = {m: zeros_like(im_ref, dtype='float64') for m in self.metric_lst}

        with sct_progress_bar() as pbar:
            slices_im = self.dct_im_seg['im']
            slices_seg = self.dct_im_seg['seg']
            n_slices = len(slices_im)
            for zz, (im_z, seg_z) in enumerate(zip(slices_im, slices_seg)):
                for xx in range(im_z.shape[0]):
                    for yy in range(im_z.shape[1]):
                        # voxel must belong to the segmentation
                        if not seg_z[xx, yy]:
                            continue

                        # the whole window must fit inside the axial slice
                        x_fits = half_win <= xx <= im_z.shape[0] - half_win - 1
                        y_fits = half_win <= yy <= im_z.shape[1] - half_win - 1
                        if not (x_fits and y_fits):
                            continue

                        # the whole window must fit inside the mask of the axial slice
                        rows = slice(xx - half_win, xx + half_win + 1)
                        cols = slice(yy - half_win, yy + half_win + 1)
                        if not np.all(seg_z[rows, cols]):
                            continue

                        glcm_window = im_z[rows, cols].astype(np.uint8)

                        # one GLCM per requested angle, all at distance self.param_glcm.distance
                        dct_glcm = {
                            a: greycomatrix(glcm_window,
                                            [self.param_glcm.distance], [np.radians(int(a))],
                                            symmetric=self.param_glcm.symmetric,
                                            normed=self.param_glcm.normed)
                            for a in self.param_glcm.angle.split(',')
                        }

                        # GLCM property (first field of the metric name) of the voxel xx,yy,zz,
                        # taken from the GLCM of the angle given by the third field
                        for m in self.metric_lst:
                            name_fields = m.split('_')
                            dct_metric[m].data[xx, yy, zz] = greycoprops(dct_glcm[name_fields[2]], name_fields[0])[0][0]

                        pbar.set_postfix(pos="{}/{}".format(zz, n_slices))
                        pbar.update(1)

        # Write each metric image and remember its file name
        for m in self.metric_lst:
            fname_out = add_suffix("".join(extract_fname(self.param.fname_im)[1:]), '_' + m)
            dct_metric[m].save(fname_out)
            self.fname_metric_lst[m] = fname_out
Пример #2
0
def moco(param):
    """
    Main function that performs motion correction.

    The input data are split along Z when the scan is sagittal (otherwise handled as a single Z entry),
    then along T. Each volume is passed to ``register()`` together with the target; a transformation
    reported as failed is replaced by the successful one with the closest time index. Unless
    ``todo == 'estimate'``, the corrected volumes are merged back along T, and along Z for sagittal data.

    :param param: parameter object. Attributes read in this function: file_data, file_target, mat_moco,
        todo, suffix, verbose, poly, smooth, gradStep, metric, sampling, fname_mask, is_sagittal, iterAvg,
        interp. The object is also forwarded to ``register()``.
    :return: tuple ``(file_mat, im_out)``. ``file_mat`` is a numpy object array of shape ``(nz, nt)`` for
        sagittal data and ``(1, nt)`` otherwise, holding the transformation file prefix of each volume.
        ``im_out`` is the last merged image, or ``None`` when no merge happened (non-sagittal data with
        ``todo == 'estimate'``).
    """
    # retrieve parameters
    file_data = param.file_data
    file_target = param.file_target
    folder_mat = param.mat_moco  # output folder of mat file
    todo = param.todo
    suffix = param.suffix
    verbose = param.verbose

    # other parameters
    file_mask = 'mask.nii'
    has_mask = param.fname_mask != ''

    printv('\nInput parameters:', param.verbose)
    printv('  Input file ............ ' + file_data, param.verbose)
    printv('  Reference file ........ ' + file_target, param.verbose)
    printv('  Polynomial degree ..... ' + param.poly, param.verbose)
    printv('  Smoothing kernel ...... ' + param.smooth, param.verbose)
    printv('  Gradient step ......... ' + param.gradStep, param.verbose)
    printv('  Metric ................ ' + param.metric, param.verbose)
    printv('  Sampling .............. ' + param.sampling, param.verbose)
    printv('  Todo .................. ' + todo, param.verbose)
    printv('  Mask  ................. ' + param.fname_mask, param.verbose)
    printv('  Output mat folder ..... ' + folder_mat, param.verbose)

    # The output folder may already exist (e.g. from a previous call): that is not an error
    try:
        os.makedirs(folder_mat)
    except FileExistsError:
        pass

    # Get size of data
    printv('\nData dimensions:', verbose)
    im_data = Image(param.file_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    printv(
        ('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt)),
        verbose)

    # copy file_target to a temporary file
    printv('\nCopy file_target to a temporary file...', verbose)
    file_target = "target.nii.gz"
    convert(param.file_target, file_target, verbose=0)

    # Check if user specified a mask
    if has_mask:
        # Check if this mask is soft (i.e., non-binary, such as a Gaussian mask)
        im_mask = Image(param.fname_mask)
        if not np.array_equal(im_mask.data, im_mask.data.astype(bool)):
            # If it is a soft mask, multiply the target by the soft mask.
            im = Image(file_target)
            im_masked = im.copy()
            im_masked.data = im.data * im_mask.data
            im_masked.save(
                verbose=0)  # silence warning about file overwritting

    # If scan is sagittal, split src and target along Z (slice)
    if param.is_sagittal:
        dim_sag = 2  # TODO: find it
        # z-split data (time series)
        im_z_list = split_data(im_data, dim=dim_sag, squeeze_data=False)
        file_data_splitZ = []
        for im_z in im_z_list:
            im_z.save(verbose=0)
            file_data_splitZ.append(im_z.absolutepath)
        # z-split target
        im_targetz_list = split_data(Image(file_target),
                                     dim=dim_sag,
                                     squeeze_data=False)
        file_target_splitZ = []
        for im_targetz in im_targetz_list:
            im_targetz.save(verbose=0)
            file_target_splitZ.append(im_targetz.absolutepath)
        # z-split mask (if exists)
        if has_mask:
            # NOTE(review): this reads `file_mask` ('mask.nii') although nothing in this branch creates it;
            # presumably the caller provides the mask under that name in the working directory -- confirm.
            im_maskz_list = split_data(Image(file_mask),
                                       dim=dim_sag,
                                       squeeze_data=False)
            # the split masks are written to disk; only the in-memory images are used afterwards
            for im_maskz in im_maskz_list:
                im_maskz.save(verbose=0)
        # initialize file list for output matrices
        file_mat = np.empty((nz, nt), dtype=object)

    # axial orientation
    else:
        file_data_splitZ = [file_data]  # TODO: make it absolute like above
        file_target_splitZ = [file_target]  # TODO: make it absolute like above
        # initialize file list for output matrices
        file_mat = np.empty((1, nt), dtype=object)

        # deal with mask
        if has_mask:
            convert(param.fname_mask, file_mask, squeeze_data=False, verbose=0)
            im_maskz_list = [Image(file_mask)
                             ]  # use a list with single element

    # Loop across file list, where each file is either a 2D volume (if sagittal) or a 3D volume (otherwise)
    file_data_splitZ_moco = []
    # Stays None when no merged image is produced, so that the final `return` is always valid
    im_out = None
    printv(
        '\nRegister. Loop across Z (note: there is only one Z if orientation is axial)'
    )
    for iz, fname_z in enumerate(file_data_splitZ):
        # Split data along T dimension
        im_z = Image(fname_z)
        list_im_zt = split_data(im_z, dim=3)
        file_data_splitZ_splitT = []
        for im_zt in list_im_zt:
            im_zt.save(verbose=0)
            file_data_splitZ_splitT.append(im_zt.absolutepath)

        # Motion correction: initialization
        file_data_splitZ_splitT_moco = []
        failed_transfo = [0] * nt  # one flag per time point, as returned by register()

        # Motion correction: Loop across T
        for it in sct_progress_bar(range(nt),
                                   unit='iter',
                                   unit_scale=False,
                                   desc="Z=" + str(iz) + "/" +
                                   str(len(file_data_splitZ) - 1),
                                   ascii=False,
                                   ncols=80):

            # transformation file prefix: <folder_mat>/mat.Z<iz>T<it>, both indices zero-padded to 4 digits
            file_mat[iz][it] = os.path.join(
                folder_mat,
                "mat.Z") + str(iz).zfill(4) + 'T' + str(it).zfill(4)
            file_data_splitZ_splitT_moco.append(
                add_suffix(file_data_splitZ_splitT[it], '_moco'))
            # deal with masking (except in the 'apply' case, where masking is irrelevant)
            input_mask = None
            if has_mask and param.todo != 'apply':
                # Check if mask is binary
                if np.array_equal(im_maskz_list[iz].data,
                                  im_maskz_list[iz].data.astype(bool)):
                    # If it is, pass this mask into register() to be used
                    input_mask = im_maskz_list[iz]
                else:
                    # If not, do not pass this mask into register() because ANTs cannot handle non-binary masks.
                    #  Instead, multiply the input data by the Gaussian mask.
                    im = Image(file_data_splitZ_splitT[it])
                    im_masked = im.copy()
                    im_masked.data = im.data * im_maskz_list[iz].data
                    im_masked.save(
                        verbose=0)  # silence warning about file overwritting

            # run 3D registration
            failed_transfo[it] = register(param,
                                          file_data_splitZ_splitT[it],
                                          file_target_splitZ[iz],
                                          file_mat[iz][it],
                                          file_data_splitZ_splitT_moco[it],
                                          im_mask=input_mask)

            # average registered volume with target image
            # N.B. use weighted averaging: (target * nb_it + moco) / (nb_it + 1)
            # Only done for the first 10 time points, and only when the registration succeeded.
            if param.iterAvg and it < 10 and failed_transfo[
                    it] == 0 and param.todo != 'apply':
                im_targetz = Image(file_target_splitZ[iz])
                data_targetz = im_targetz.data
                data_mocoz = Image(file_data_splitZ_splitT_moco[it]).data
                data_targetz = (data_targetz * (it + 1) +
                                data_mocoz) / (it + 2)
                im_targetz.data = data_targetz
                im_targetz.save(verbose=0)

        # Replace failed transformation with the closest good one
        fT = [i for i, j in enumerate(failed_transfo) if j == 1]
        gT = [i for i, j in enumerate(failed_transfo) if j == 0]
        for it_failed in fT:
            if not gT:
                # exit program if no transformation exists.
                printv(
                    '\nERROR in ' + os.path.basename(__file__) +
                    ': No good transformation exist. Exit program.\n', verbose,
                    'error')
                sys.exit(2)
            # first good index with the smallest distance to the failed one
            it_good = min(gT, key=lambda i_good: abs(i_good - it_failed))
            printv(
                '  transfo #' + str(it_failed) + ' --> use transfo #' +
                str(it_good), verbose)
            # copy transformation
            # NOTE(review): the 'Warp.nii.gz' suffix and the un-split `file_target` destination are
            # hard-coded here; whether that is right for sagittal (Z-split) data should be confirmed.
            copy(file_mat[iz][it_good] + 'Warp.nii.gz',
                 file_mat[iz][it_failed] + 'Warp.nii.gz')
            # apply transformation
            sct_apply_transfo.main(argv=[
                '-i', file_data_splitZ_splitT[it_failed], '-d', file_target,
                '-w', file_mat[iz][it_failed] + 'Warp.nii.gz', '-o',
                file_data_splitZ_splitT_moco[it_failed], '-x', param.interp
            ])

        # Merge data along T
        file_data_splitZ_moco.append(add_suffix(fname_z, suffix))
        if todo != 'estimate':
            im_data_splitZ_splitT_moco = [
                Image(fname) for fname in file_data_splitZ_splitT_moco
            ]
            im_out = concat_data(im_data_splitZ_splitT_moco, 3)
            im_out.absolutepath = file_data_splitZ_moco[iz]
            im_out.save(verbose=0)

    # If sagittal, merge along Z
    if param.is_sagittal:
        # TODO: im_out.dim is incorrect: Z value is one
        im_data_splitZ_moco = [Image(fname) for fname in file_data_splitZ_moco]
        im_out = concat_data(im_data_splitZ_moco, 2)
        dirname, basename, ext = extract_fname(file_data)
        path_out = os.path.join(dirname, basename + suffix + ext)
        im_out.absolutepath = path_out
        im_out.save(verbose=0)

    return file_mat, im_out
Пример #3
0
def moco_wrapper(param):
    """
    Wrapper that performs motion correction.

    Overview of the steps, in order:
      1. Copy the input data (and mask, if any) to a temporary folder and change the working directory to it.
      2. Decide between sagittal and axial handling from the third letter of the image orientation.
      3. For diffusion data, identify b=0 and DWI volumes and average the b=0 volumes.
      4. Split the data along T, build groups of ``param.group_size`` volumes, average each group and
         concatenate the group averages.
      5. Estimate motion with ``moco()`` (on b=0 volumes for diffusion data, then across groups).
      6. Copy the transformations to ``mat_final/`` and apply them to the full data with ``moco()``.
      7. Compute time averages, optionally export the motion parameters, and write the output files.
      8. Optionally delete the temporary folder, restore the working directory, display the viewer syntax.

    Side effects on ``param``: ``fname_mask`` (when a mask is given), ``is_sagittal``, ``suffix_mat`` and
    possibly ``group_size`` are overwritten.

    :param param: ParamMoco class
    :return: None
    :raises NotImplementedError: if ``param.spline_fitting`` is set
    """
    # File names below are relative: they are resolved inside the temporary folder after os.chdir()
    file_data = 'data.nii'  # corresponds to the full input data (e.g. dmri or fmri)
    file_data_dirname, file_data_basename, file_data_ext = extract_fname(
        file_data)
    file_b0 = 'b0.nii'
    file_datasub = 'datasub.nii'  # corresponds to the full input data minus the b=0 scans (if param.is_diffusion=True)
    file_datasubgroup = 'datasub-groups.nii'  # concatenation of the average of each file_datasub
    file_mask = 'mask.nii'
    file_moco_params_csv = 'moco_params.tsv'
    file_moco_params_x = 'moco_params_x.nii.gz'
    file_moco_params_y = 'moco_params_y.nii.gz'
    ext_data = '.nii.gz'  # workaround "too many open files" by slurping the data
    # TODO: check if .nii can be used
    mat_final = 'mat_final/'
    # ext_mat = 'Warp.nii.gz'  # warping field

    # Start timer
    start_time = time.time()

    printv('\nInput parameters:', param.verbose)
    printv('  Input file ............ ' + param.fname_data, param.verbose)
    printv('  Group size ............ {}'.format(param.group_size),
           param.verbose)

    # Get full path
    # param.fname_data = os.path.abspath(param.fname_data)
    # param.fname_bvecs = os.path.abspath(param.fname_bvecs)
    # if param.fname_bvals != '':
    #     param.fname_bvals = os.path.abspath(param.fname_bvals)

    # Extract path, file and extension
    # path_data, file_data, ext_data = extract_fname(param.fname_data)
    # path_mask, file_mask, ext_mask = extract_fname(param.fname_mask)

    path_tmp = tmp_create(basename="moco")

    # Copying input data to tmp folder
    printv('\nCopying input data to tmp folder and convert to nii...',
           param.verbose)
    convert(param.fname_data, os.path.join(path_tmp, file_data))
    if param.fname_mask != '':
        convert(param.fname_mask,
                os.path.join(path_tmp, file_mask),
                verbose=param.verbose)
        # Update field in param (because used later in another function, and param class will be passed)
        param.fname_mask = file_mask

    # Build absolute output path and go to tmp folder
    # (path_out must be made absolute BEFORE changing directory, otherwise it would resolve inside path_tmp)
    # NOTE(review): the working directory is only restored at the very end of this function; an exception
    # raised in between leaves the process inside path_tmp.
    curdir = os.getcwd()
    path_out_abs = os.path.abspath(param.path_out)
    os.chdir(path_tmp)

    # Get dimensions of data
    printv('\nGet dimensions of data...', param.verbose)
    im_data = Image(file_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), param.verbose)

    # Get orientation
    # The third letter of the orientation string decides how the data are treated
    printv('\nData orientation: ' + im_data.orientation, param.verbose)
    if im_data.orientation[2] in 'LR':
        param.is_sagittal = True
        printv('  Treated as sagittal')
    elif im_data.orientation[2] in 'IS':
        param.is_sagittal = False
        printv('  Treated as axial')
    else:
        param.is_sagittal = False
        printv(
            'WARNING: Orientation seems to be neither axial nor sagittal. Treated as axial.'
        )

    printv(
        "\nSet suffix of transformation file name, which depends on the orientation:"
    )
    if param.is_sagittal:
        param.suffix_mat = '0GenericAffine.mat'
        printv(
            "Orientation is sagittal, suffix is '{}'. The image is split across the R-L direction, and the "
            "estimated transformation is a 2D affine transfo.".format(
                param.suffix_mat))
    else:
        param.suffix_mat = 'Warp.nii.gz'
        printv(
            "Orientation is axial, suffix is '{}'. The estimated transformation is a 3D warping field, which is "
            "composed of a stack of 2D Tx-Ty transformations".format(
                param.suffix_mat))

    # Adjust group size in case of sagittal scan
    if param.is_sagittal and param.group_size != 1:
        printv(
            'For sagittal data group_size should be one for more robustness. Forcing group_size=1.',
            1, 'warning')
        param.group_size = 1

    if param.is_diffusion:
        # Identify b=0 and DWI images
        index_b0, index_dwi, nb_b0, nb_dwi = \
            sct_dmri_separate_b0_and_dwi.identify_b0(param.fname_bvecs, param.fname_bvals, param.bval_min,
                                                     param.verbose)

        # check if dmri and bvecs are the same size
        if not nb_b0 + nb_dwi == nt:
            printv(
                '\nERROR in ' + os.path.basename(__file__) +
                ': Size of data (' + str(nt) + ') and size of bvecs (' +
                str(nb_b0 + nb_dwi) +
                ') are not the same. Check your bvecs file.\n', 1, 'error')
            sys.exit(2)

    # ==================================================================================================================
    # Prepare data (mean/groups...)
    # ==================================================================================================================

    # Split into T dimension
    # Each split volume is saved with a ".nii.gz" extension, whatever extension split_data() assigned
    printv('\nSplit along T dimension...', param.verbose)
    im_data_split_list = split_data(im_data, 3)
    for im in im_data_split_list:
        x_dirname, x_basename, x_ext = extract_fname(im.absolutepath)
        im.absolutepath = os.path.join(x_dirname, x_basename + ".nii.gz")
        im.save()

    if param.is_diffusion:
        # Merge and average b=0 images
        printv('\nMerge and average b=0 data...', param.verbose)
        im_b0_list = []
        for it in range(nb_b0):
            im_b0_list.append(im_data_split_list[index_b0[it]])
        im_b0 = concat_data(im_b0_list, 3).save(file_b0, verbose=0)
        # Average across time
        im_b0.mean(dim=3).save(add_suffix(file_b0, '_mean'))

        n_moco = nb_dwi  # set number of data to perform moco on (using grouping)
        index_moco = index_dwi

    # If not a diffusion scan, we will motion-correct all volumes
    else:
        n_moco = nt
        index_moco = list(range(0, nt))

    # Number of complete groups; an incomplete trailing group is added further below
    nb_groups = int(math.floor(n_moco / param.group_size))

    # Generate groups indexes
    group_indexes = []
    for iGroup in range(nb_groups):
        group_indexes.append(index_moco[(iGroup *
                                         param.group_size):((iGroup + 1) *
                                                            param.group_size)])

    # add the remaining images to a new last group (in case the total number of image is not divisible by group_size)
    nb_remaining = n_moco % param.group_size  # number of remaining images
    if nb_remaining > 0:
        nb_groups += 1
        group_indexes.append(index_moco[len(index_moco) -
                                        nb_remaining:len(index_moco)])

    _, file_dwi_basename, file_dwi_ext = extract_fname(file_datasub)
    # Group data
    list_file_group = []  # file name of the time-average of each group
    for iGroup in sct_progress_bar(range(nb_groups),
                                   unit='iter',
                                   unit_scale=False,
                                   desc="Merge within groups",
                                   ascii=False,
                                   ncols=80):
        # get index
        index_moco_i = group_indexes[iGroup]
        n_moco_i = len(index_moco_i)
        # concatenate images across time, within this group
        file_dwi_merge_i = os.path.join(file_dwi_basename + '_' + str(iGroup) +
                                        ext_data)
        im_dwi_list = []
        for it in range(n_moco_i):
            im_dwi_list.append(im_data_split_list[index_moco_i[it]])
        im_dwi_out = concat_data(im_dwi_list, 3).save(file_dwi_merge_i,
                                                      verbose=0)
        # Average across time
        list_file_group.append(
            os.path.join(file_dwi_basename + '_' + str(iGroup) + '_mean' +
                         ext_data))
        im_dwi_out.mean(dim=3).save(list_file_group[-1])

    # Merge across groups
    printv('\nMerge across groups...', param.verbose)
    # file_dwi_groups_means_merge = 'dwi_averaged_groups'
    fname_dw_list = []
    for iGroup in range(nb_groups):
        fname_dw_list.append(list_file_group[iGroup])
    im_dw_list = [Image(fname) for fname in fname_dw_list]
    concat_data(im_dw_list, 3).save(file_datasubgroup, verbose=0)

    # Cleanup
    # `im` is the loop variable left over from the "Split into T dimension" loop above
    del im, im_data_split_list

    # ==================================================================================================================
    # Estimate moco
    # ==================================================================================================================

    # Initialize another class instance that will be passed on to the moco() function
    param_moco = deepcopy(param)

    if param.is_diffusion:
        # Estimate moco on b0 groups
        printv(
            '\n-------------------------------------------------------------------------------',
            param.verbose)
        printv('  Estimating motion on b=0 images...', param.verbose)
        printv(
            '-------------------------------------------------------------------------------',
            param.verbose)
        param_moco.file_data = 'b0.nii'
        # Identify target image
        if index_moco[0] != 0:
            # If first DWI is not the first volume (most common), then there is a least one b=0 image before. In that
            # case select it as the target image for registration of all b=0
            param_moco.file_target = os.path.join(
                file_data_dirname, file_data_basename + '_T' +
                str(index_b0[index_moco[0] - 1]).zfill(4) + ext_data)
        else:
            # If first DWI is the first volume, then the target b=0 is the first b=0 from the index_b0.
            param_moco.file_target = os.path.join(
                file_data_dirname, file_data_basename + '_T' +
                str(index_b0[0]).zfill(4) + ext_data)
        # Run moco
        param_moco.path_out = ''
        param_moco.todo = 'estimate_and_apply'
        param_moco.mat_moco = 'mat_b0groups'
        file_mat_b0, _ = moco(param_moco)

    # Estimate moco across groups
    printv(
        '\n-------------------------------------------------------------------------------',
        param.verbose)
    printv('  Estimating motion across groups...', param.verbose)
    printv(
        '-------------------------------------------------------------------------------',
        param.verbose)
    param_moco.file_data = file_datasubgroup
    param_moco.file_target = list_file_group[
        0]  # target is the first volume (closest to the first b=0 if DWI scan)
    param_moco.path_out = ''
    param_moco.todo = 'estimate_and_apply'
    param_moco.mat_moco = 'mat_groups'
    file_mat_datasub_group, _ = moco(param_moco)

    # Spline Regularization along T
    if param.spline_fitting:
        # TODO: fix this scenario (haven't touched that code for a while-- it is probably buggy)
        raise NotImplementedError()
        # spline(mat_final, nt, nz, param.verbose, np.array(index_b0), param.plot_graph)

    # ==================================================================================================================
    # Apply moco
    # ==================================================================================================================

    # If group_size>1, assign transformation to each individual ungrouped 3d volume
    if param.group_size > 1:
        file_mat_datasub = []
        for iz in range(len(file_mat_datasub_group)):
            # duplicate by factor group_size the transformation file for each it
            #  example: [mat.Z0000T0001Warp.nii] --> [mat.Z0000T0001Warp.nii, mat.Z0000T0001Warp.nii] for group_size=2
            # NOTE(review): the last (possibly incomplete) group is also repeated group_size times, so the
            # list can be longer than the number of volumes; presumably copy_mat_files() copes -- confirm.
            file_mat_datasub.append(
                functools.reduce(operator.iconcat,
                                 [[i] * param.group_size
                                  for i in file_mat_datasub_group[iz]], []))
    else:
        file_mat_datasub = file_mat_datasub_group

    # Copy transformations to mat_final folder and rename them appropriately
    copy_mat_files(nt, file_mat_datasub, index_moco, mat_final, param)
    if param.is_diffusion:
        copy_mat_files(nt, file_mat_b0, index_b0, mat_final, param)

    # Apply moco on all dmri data
    printv(
        '\n-------------------------------------------------------------------------------',
        param.verbose)
    printv('  Apply moco', param.verbose)
    printv(
        '-------------------------------------------------------------------------------',
        param.verbose)
    param_moco.file_data = file_data
    param_moco.file_target = list_file_group[
        0]  # reference for reslicing into proper coordinate system
    param_moco.path_out = ''  # TODO not used in moco()
    param_moco.mat_moco = mat_final
    param_moco.todo = 'apply'
    file_mat_data, im_moco = moco(param_moco)

    # copy geometric information from header
    # NB: this is required because WarpImageMultiTransform in 2D mode wrongly sets pixdim(3) to "1".
    im_moco.header = im_data.header
    im_moco.save(verbose=0)

    # Average across time
    if param.is_diffusion:
        # generate b0_moco_mean and dwi_moco_mean
        args = [
            '-i', im_moco.absolutepath, '-bvec', param.fname_bvecs, '-a', '1',
            '-v', '0'
        ]
        if not param.fname_bvals == '':
            # if bvals file is provided
            args += ['-bval', param.fname_bvals]
        fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean = sct_dmri_separate_b0_and_dwi.main(
            argv=args)
    else:
        fname_moco_mean = add_suffix(im_moco.absolutepath, '_mean')
        im_moco.mean(dim=3).save(fname_moco_mean)

    # Extract and output the motion parameters (doesn't work for sagittal orientation)
    # Note: the message below is printed even when param.output_motion_param is not set
    printv('Extract motion parameters...')
    if param.output_motion_param:
        if param.is_sagittal:
            printv(
                'Motion parameters cannot be generated for sagittal images.',
                1, 'warning')
        else:
            files_warp_X, files_warp_Y = [], []
            moco_param = []  # one [mean X, mean Y] pair per transformation file
            for fname_warp in file_mat_data[0]:
                # Cropping the image to keep only one voxel in the XY plane
                im_warp = Image(fname_warp + param.suffix_mat)
                im_warp.data = np.expand_dims(np.expand_dims(
                    im_warp.data[0, 0, :, :, :], axis=0),
                                              axis=0)

                # These three lines allow to generate one file instead of two, containing X, Y and Z moco parameters
                #fname_warp_crop = fname_warp + '_crop_' + ext_mat
                # files_warp.append(fname_warp_crop)
                # im_warp.save(fname_warp_crop)

                # Separating the three components and saving X and Y only (Z is equal to 0 by default).
                im_warp_XYZ = multicomponent_split(im_warp)

                fname_warp_crop_X = fname_warp + '_crop_X_' + param.suffix_mat
                im_warp_XYZ[0].save(fname_warp_crop_X)
                files_warp_X.append(fname_warp_crop_X)

                fname_warp_crop_Y = fname_warp + '_crop_Y_' + param.suffix_mat
                im_warp_XYZ[1].save(fname_warp_crop_Y)
                files_warp_Y.append(fname_warp_crop_Y)

                # Calculating the slice-wise average moco estimate to provide a QC file
                moco_param.append([
                    np.mean(np.ravel(im_warp_XYZ[0].data)),
                    np.mean(np.ravel(im_warp_XYZ[1].data))
                ])

            # These two lines allow to generate one file instead of two, containing X, Y and Z moco parameters
            # im_warp = [Image(fname) for fname in files_warp]
            # im_warp_concat = concat_data(im_warp, dim=3)
            # im_warp_concat.save('fmri_moco_params.nii')

            # Concatenating the moco parameters into a time series for X and Y components.
            im_warp_X = [Image(fname) for fname in files_warp_X]
            im_warp_concat = concat_data(im_warp_X, dim=3)
            im_warp_concat.save(file_moco_params_x)

            im_warp_Y = [Image(fname) for fname in files_warp_Y]
            im_warp_concat = concat_data(im_warp_Y, dim=3)
            im_warp_concat.save(file_moco_params_y)

            # Writing a TSV file with the slicewise average estimate of the moco parameters. Useful for QC
            with open(file_moco_params_csv, 'wt') as out_file:
                tsv_writer = csv.writer(out_file, delimiter='\t')
                tsv_writer.writerow(['X', 'Y'])
                for mocop in moco_param:
                    tsv_writer.writerow([mocop[0], mocop[1]])

    # Generate output files
    printv('\nGenerate output files...', param.verbose)
    fname_moco = os.path.join(
        path_out_abs,
        add_suffix(os.path.basename(param.fname_data), param.suffix))
    generate_output_file(im_moco.absolutepath, fname_moco)
    if param.is_diffusion:
        generate_output_file(fname_b0_mean, add_suffix(fname_moco, '_b0_mean'))
        generate_output_file(fname_dwi_mean,
                             add_suffix(fname_moco, '_dwi_mean'))
    else:
        generate_output_file(fname_moco_mean, add_suffix(fname_moco, '_mean'))
    # The presence of the TSV file (written above only when motion parameters were extracted) is used
    # to decide whether the three motion-parameter files are exported
    if os.path.exists(file_moco_params_csv):
        generate_output_file(file_moco_params_x,
                             os.path.join(path_out_abs, file_moco_params_x),
                             squeeze_data=False)
        generate_output_file(file_moco_params_y,
                             os.path.join(path_out_abs, file_moco_params_y),
                             squeeze_data=False)
        generate_output_file(file_moco_params_csv,
                             os.path.join(path_out_abs, file_moco_params_csv))

    # Delete temporary files
    # NOTE(review): at this point the working directory is still path_tmp; removing the current directory
    # may fail on some platforms -- consider restoring the working directory first.
    if param.remove_temp_files == 1:
        printv('\nDelete temporary files...', param.verbose)
        rmtree(path_tmp, verbose=param.verbose)

    # come back to working directory
    os.chdir(curdir)

    # display elapsed time
    elapsed_time = time.time() - start_time
    printv(
        '\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's',
        param.verbose)

    display_viewer_syntax([
        os.path.join(
            param.path_out,
            add_suffix(os.path.basename(param.fname_data), param.suffix)),
        param.fname_data
    ],
                          mode='ortho,ortho')
Пример #4
0
    def straighten(self):
        """
        Straighten the spinal cord and generate the associated warping fields.

        Steps (everything is done in physical space):
        1. open input image and centerline image
        2. extract bspline fitting of the centerline, and its derivatives
        3. compute length of centerline
        4. compute and generate straight space
        5. compute transformations
            for each voxel of one space: (done using matrices --> improves speed by a factor x300)
                a. determine which plane of spinal cord centerline it is included in
                b. compute the position of the voxel in the plane (X and Y distance from centerline, along the plane)
                c. find the corresponding centerline point in the other space
                d. find the correspondence of the voxel in the corresponding plane
        6. generate warping fields for each transformation
        7. write warping fields and apply them

        step 5.b: how to find the corresponding plane?
            The centerline plane corresponding to a voxel corresponds to the nearest point of the centerline.
            However, we need to compute the distance between the voxel position and the plane to be sure it is part
            of the plane and not too distant.
            If it is farther than a threshold, the warping value is set to a sentinel (see the 100000.0 assignments).

        step 5.d: how to make the correspondence between centerline points in both images?
            Both centerlines have the same length. Therefore, we can map centerline points via their position along
            the curve.
            If we use the same number of points uniformly along the spinal cord (1000 for example), the
            correspondence is straightforward.

        Side effects visible in this method: the process working directory is temporarily changed to a temporary
        folder, output files are written under self.path_output, and self.mse_straightening,
        self.max_distance_straightening, self.elapsed_time_accuracy and self.elapsed_time are updated.

        :return: the value returned by generate_output_file() for the straightened image, or None when
            self.curved2straight is false (no straightened image is generated in that case).
        """
        # Initialization
        fname_anat = self.input_filename
        fname_centerline = self.centerline_filename
        fname_output = self.output_filename
        remove_temp_files = self.remove_temp_files
        verbose = self.verbose
        # Only assigned further down when the curved->straight warp is applied; initialised here so that the
        # final `return` never hits an unbound local.
        fname_straight = None

        # start timer
        start_time = time.time()

        # Extract path/file/extension
        path_anat, file_anat, ext_anat = extract_fname(fname_anat)

        path_tmp = tmp_create(basename="straighten_spinalcord")

        # Copying input data to tmp folder
        logger.info('Copy files to tmp folder...')
        Image(fname_anat,
              check_sform=True).save(os.path.join(path_tmp, "data.nii"))
        Image(fname_centerline, check_sform=True).save(
            os.path.join(path_tmp, "centerline.nii.gz"))

        if self.use_straight_reference:
            Image(self.centerline_reference_filename, check_sform=True).save(
                os.path.join(path_tmp, "centerline_ref.nii.gz"))
        if self.discs_input_filename != '':
            Image(self.discs_input_filename, check_sform=True).save(
                os.path.join(path_tmp, "labels_input.nii.gz"))
        if self.discs_ref_filename != '':
            Image(self.discs_ref_filename, check_sform=True).save(
                os.path.join(path_tmp, "labels_ref.nii.gz"))

        # go to tmp folder (all relative file names below are resolved inside path_tmp)
        curdir = os.getcwd()
        os.chdir(path_tmp)

        # Change orientation of the input centerline into RPI
        image_centerline = Image("centerline.nii.gz").change_orientation(
            "RPI").save("centerline_rpi.nii.gz", mutable=True)

        # Get dimension
        nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim
        if self.speed_factor != 1.0:
            # Work on a centerline resampled to a voxel size scaled by speed_factor.
            intermediate_resampling = True
            px_r, py_r, pz_r = px * self.speed_factor, py * self.speed_factor, pz * self.speed_factor
        else:
            intermediate_resampling = False

        if intermediate_resampling:
            # Keep the native-resolution centerline aside; it is used later to build the reference space.
            mv('centerline_rpi.nii.gz', 'centerline_rpi_native.nii.gz')
            pz_native = pz
            # TODO: remove system call
            run_proc([
                'sct_resample', '-i', 'centerline_rpi_native.nii.gz', '-mm',
                str(px_r) + 'x' + str(py_r) + 'x' + str(pz_r), '-o',
                'centerline_rpi.nii.gz'
            ])
            image_centerline = Image('centerline_rpi.nii.gz')
            nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim

        # Clamp the centerline data to [0, 1] and write it back if any value falls outside that range.
        if np.min(image_centerline.data) < 0 or np.max(
                image_centerline.data) > 1:
            image_centerline.data[image_centerline.data < 0] = 0
            image_centerline.data[image_centerline.data > 1] = 1
            image_centerline.save()

        # 2. extract bspline fitting of the centerline, and its derivatives
        img_ctl = Image('centerline_rpi.nii.gz')
        centerline = _get_centerline(img_ctl, self.param_centerline, verbose)
        number_of_points = centerline.number_of_points

        # ==========================================================================================
        logger.info('Create the straight space and the safe zone')
        # 3. compute length of centerline
        # compute the length of the spinal cord based on fitted centerline and size of centerline in z direction

        # Computation of the safe zone.
        # The safe zone is defined as the length of the spinal cord for which an axial segmentation will be complete
        # The safe length (to remove) is computed using the safe radius (given as parameter) and the angle of the
        # last centerline point with the inferior-superior direction. Formula: Ls = Rs * sin(angle)
        # Calculate Ls for both edges and remove appropriate number of centerline points
        # NOTE: radius_safe is hard-coded to 0 here, so both safe lengths evaluate to 0 and the
        # `if radius_safe > 0` masking further down is never executed.
        radius_safe = 0.0  # mm

        # inferior edge
        u = centerline.derivatives[0]
        v = np.array([0, 0, -1])

        # Angle between u and v, computed from the cross-product norm and the dot product.
        angle_inferior = np.arctan2(np.linalg.norm(np.cross(u, v)),
                                    np.dot(u, v))
        length_safe_inferior = radius_safe * np.sin(angle_inferior)

        # superior edge
        u = centerline.derivatives[-1]
        v = np.array([0, 0, 1])
        angle_superior = np.arctan2(np.linalg.norm(np.cross(u, v)),
                                    np.dot(u, v))
        length_safe_superior = radius_safe * np.sin(angle_superior)

        # remove points
        inferior_bound = bisect.bisect(centerline.progressive_length,
                                       length_safe_inferior) - 1
        superior_bound = centerline.number_of_points - bisect.bisect(
            centerline.progressive_length_inverse, length_safe_superior)

        z_centerline = centerline.points[:, 2]
        length_centerline = centerline.length
        size_z_centerline = z_centerline[-1] - z_centerline[0]

        # compute the size factor between initial centerline and straight bended centerline
        factor_curved_straight = length_centerline / size_z_centerline
        middle_slice = (z_centerline[0] + z_centerline[-1]) / 2.0

        bound_curved = [
            z_centerline[inferior_bound], z_centerline[superior_bound]
        ]
        # Straight-space bounds: curved bounds stretched around the middle slice by the length/size ratio.
        bound_straight = [(z_centerline[inferior_bound] - middle_slice) *
                          factor_curved_straight + middle_slice,
                          (z_centerline[superior_bound] - middle_slice) *
                          factor_curved_straight + middle_slice]

        logger.info('Length of spinal cord: {}'.format(length_centerline))
        logger.info(
            'Size of spinal cord in z direction: {}'.format(size_z_centerline))
        logger.info('Ratio length/size: {}'.format(factor_curved_straight))
        logger.info(
            'Safe zone boundaries (curved space): {}'.format(bound_curved))
        logger.info(
            'Safe zone boundaries (straight space): {}'.format(bound_straight))

        # 4. compute and generate straight space
        # points along curved centerline are already regularly spaced.
        # calculate position of points along straight centerline

        # Create straight NIFTI volumes.
        # ==========================================================================================
        # TODO: maybe this if case is not needed?
        if self.use_straight_reference:
            # The straight space is provided by the user through a reference centerline image.
            image_centerline_pad = Image('centerline_rpi.nii.gz')
            nx, ny, nz, nt, px, py, pz, pt = image_centerline_pad.dim

            fname_ref = 'centerline_ref_rpi.nii.gz'
            image_centerline_straight = Image('centerline_ref.nii.gz') \
                .change_orientation("RPI") \
                .save(fname_ref, mutable=True)
            centerline_straight = _get_centerline(image_centerline_straight,
                                                  self.param_centerline,
                                                  verbose)
            nx_s, ny_s, nz_s, nt_s, px_s, py_s, pz_s, pt_s = image_centerline_straight.dim

            # Prepare warping fields headers
            hdr_warp = image_centerline_pad.hdr.copy()
            hdr_warp.set_data_dtype('float32')
            hdr_warp_s = image_centerline_straight.hdr.copy()
            hdr_warp_s.set_data_dtype('float32')

            if self.discs_input_filename != "" and self.discs_ref_filename != "":
                # Convert the input disc labels to physical coordinates (+ label value) and attach them to the
                # curved centerline.
                discs_input_image = Image('labels_input.nii.gz')
                coord = discs_input_image.getNonZeroCoordinates(
                    sorting='z', reverse_coord=True)
                coord_physical = []
                for c in coord:
                    c_p = discs_input_image.transfo_pix2phys([[c.x, c.y, c.z]
                                                              ]).tolist()[0]
                    c_p.append(c.value)
                    coord_physical.append(c_p)
                centerline.compute_vertebral_distribution(coord_physical)
                centerline.save_centerline(
                    image=discs_input_image,
                    fname_output='discs_input_image.nii.gz')

                # Same processing for the reference disc labels, attached to the straight centerline.
                discs_ref_image = Image('labels_ref.nii.gz')
                coord = discs_ref_image.getNonZeroCoordinates(
                    sorting='z', reverse_coord=True)
                coord_physical = []
                for c in coord:
                    c_p = discs_ref_image.transfo_pix2phys([[c.x, c.y,
                                                             c.z]]).tolist()[0]
                    c_p.append(c.value)
                    coord_physical.append(c_p)
                centerline_straight.compute_vertebral_distribution(
                    coord_physical)
                centerline_straight.save_centerline(
                    image=discs_ref_image,
                    fname_output='discs_ref_image.nii.gz')

        else:
            logger.info(
                'Pad input volume to account for spinal cord length...')

            start_point, end_point = bound_straight[0], bound_straight[1]
            offset_z = 0

            # if the destination image is resampled, we still create the straight reference space with the native
            # resolution.
            # TODO: Maybe this if case is not needed?
            if intermediate_resampling:
                padding_z = int(
                    np.ceil(1.5 *
                            ((length_centerline - size_z_centerline) / 2.0) /
                            pz_native))
                run_proc([
                    'sct_image', '-i', 'centerline_rpi_native.nii.gz', '-o',
                    'tmp.centerline_pad_native.nii.gz', '-pad',
                    '0,0,' + str(padding_z)
                ])
                # NOTE(review): the un-padded native file is loaded here to compute the crop bounds, while the
                # padded file is the one cropped below -- confirm this is intended.
                image_centerline_pad = Image('centerline_rpi_native.nii.gz')
                nx, ny, nz, nt, px, py, pz, pt = image_centerline_pad.dim
                start_point_coord_native = image_centerline_pad.transfo_phys2pix(
                    [[0, 0, start_point]])[0]
                end_point_coord_native = image_centerline_pad.transfo_phys2pix(
                    [[0, 0, end_point]])[0]
                straight_size_x = int(self.xy_size / px)
                straight_size_y = int(self.xy_size / py)
                warp_space_x = [
                    int(np.round(nx / 2)) - straight_size_x,
                    int(np.round(nx / 2)) + straight_size_x
                ]
                warp_space_y = [
                    int(np.round(ny / 2)) - straight_size_y,
                    int(np.round(ny / 2)) + straight_size_y
                ]
                # When the lower bound falls outside the volume, shrink the upper bound accordingly and
                # clamp the lower bound to 0.
                if warp_space_x[0] < 0:
                    warp_space_x[1] += warp_space_x[0] - 2
                    warp_space_x[0] = 0
                if warp_space_y[0] < 0:
                    warp_space_y[1] += warp_space_y[0] - 2
                    warp_space_y[0] = 0

                spec = dict((
                    (0, warp_space_x),
                    (1, warp_space_y),
                    (2, (0, end_point_coord_native[2] -
                         start_point_coord_native[2])),
                ))
                spatial_crop(
                    Image("tmp.centerline_pad_native.nii.gz"),
                    spec).save("tmp.centerline_pad_crop_native.nii.gz")

                # NOTE(review): image_centerline_straight is saved to fname_ref further down, which writes to
                # this same file name -- confirm that overwriting the native crop is intended.
                fname_ref = 'tmp.centerline_pad_crop_native.nii.gz'
                offset_z = 4
            else:
                fname_ref = 'tmp.centerline_pad_crop.nii.gz'

            # Build the padded curved space and the cropped straight space at the working resolution.
            nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim
            padding_z = int(
                np.ceil(1.5 * ((length_centerline - size_z_centerline) / 2.0) /
                        pz)) + offset_z
            image_centerline_pad = pad_image(image_centerline,
                                             pad_z_i=padding_z,
                                             pad_z_f=padding_z)
            nx, ny, nz = image_centerline_pad.data.shape
            hdr_warp = image_centerline_pad.hdr.copy()
            hdr_warp.set_data_dtype('float32')
            start_point_coord = image_centerline_pad.transfo_phys2pix(
                [[0, 0, start_point]])[0]
            end_point_coord = image_centerline_pad.transfo_phys2pix(
                [[0, 0, end_point]])[0]

            straight_size_x = int(self.xy_size / px)
            straight_size_y = int(self.xy_size / py)
            warp_space_x = [
                int(np.round(nx / 2)) - straight_size_x,
                int(np.round(nx / 2)) + straight_size_x
            ]
            warp_space_y = [
                int(np.round(ny / 2)) - straight_size_y,
                int(np.round(ny / 2)) + straight_size_y
            ]

            # Keep the in-plane crop window inside the padded volume.
            if warp_space_x[0] < 0:
                warp_space_x[1] += warp_space_x[0] - 2
                warp_space_x[0] = 0
            if warp_space_x[1] >= nx:
                warp_space_x[1] = nx - 1
            if warp_space_y[0] < 0:
                warp_space_y[1] += warp_space_y[0] - 2
                warp_space_y[0] = 0
            if warp_space_y[1] >= ny:
                warp_space_y[1] = ny - 1

            spec = dict((
                (0, warp_space_x),
                (1, warp_space_y),
                (2, (0, end_point_coord[2] - start_point_coord[2] + offset_z)),
            ))
            image_centerline_straight = spatial_crop(image_centerline_pad,
                                                     spec)

            nx_s, ny_s, nz_s, nt_s, px_s, py_s, pz_s, pt_s = image_centerline_straight.dim
            hdr_warp_s = image_centerline_straight.hdr.copy()
            hdr_warp_s.set_data_dtype('float32')

            if self.template_orientation == 1:
                raise NotImplementedError()

            start_point_coord = image_centerline_pad.transfo_phys2pix(
                [[0, 0, start_point]])[0]
            end_point_coord = image_centerline_pad.transfo_phys2pix(
                [[0, 0, end_point]])[0]

            number_of_voxel = nx * ny * nz
            logger.debug('Number of voxels: {}'.format(number_of_voxel))

            time_centerlines = time.time()

            # Synthetic straight centerline: points at the in-plane centre of the straight volume, evenly
            # spaced along z, with a constant (0, 0, 1) derivative.
            coord_straight = np.empty((number_of_points, 3))
            coord_straight[..., 0] = int(np.round(nx_s / 2))
            coord_straight[..., 1] = int(np.round(ny_s / 2))
            coord_straight[..., 2] = np.linspace(
                0, end_point_coord[2] - start_point_coord[2], number_of_points)
            coord_phys_straight = image_centerline_straight.transfo_pix2phys(
                coord_straight)
            derivs_straight = np.empty((number_of_points, 3))
            derivs_straight[..., 0] = derivs_straight[..., 1] = 0
            derivs_straight[..., 2] = 1
            dx_straight, dy_straight, dz_straight = derivs_straight.T
            centerline_straight = Centerline(coord_phys_straight[:, 0],
                                             coord_phys_straight[:, 1],
                                             coord_phys_straight[:, 2],
                                             dx_straight, dy_straight,
                                             dz_straight)

            time_centerlines = time.time() - time_centerlines
            logger.info('Time to generate centerline: {} ms'.format(
                np.round(time_centerlines * 1000.0)))

        if verbose == 2:
            # Debug figure: cumulative normalised progressive length of both centerlines, saved as a
            # time-stamped PNG in the current (temporary) folder.
            # TODO: use OO
            import matplotlib.pyplot as plt
            from datetime import datetime
            curved_points = centerline.progressive_length
            straight_points = centerline_straight.progressive_length
            range_points = np.linspace(0, 1, number_of_points)
            dist_curved = np.zeros(number_of_points)
            dist_straight = np.zeros(number_of_points)
            for i in range(1, number_of_points):
                dist_curved[i] = dist_curved[
                    i - 1] + curved_points[i - 1] / centerline.length
                dist_straight[i] = dist_straight[i - 1] + straight_points[
                    i - 1] / centerline_straight.length
            plt.plot(range_points, dist_curved)
            plt.plot(range_points, dist_straight)
            plt.grid(True)
            plt.savefig('fig_straighten_' +
                        datetime.now().strftime("%y%m%d%H%M%S%f") + '.png')
            plt.close()

        # alignment_mode = 'length'
        alignment_mode = 'levels'

        # Look-up table: index of a curved centerline point -> index of the matching straight centerline point.
        # Identity by default; refined with the disc labels when they are provided. An entry equal to 0 is
        # treated as "no correspondence" when the warps are built (see the `lookup == 0` tests below).
        lookup_curved2straight = list(range(centerline.number_of_points))
        if self.discs_input_filename != "":
            # create look-up table curved to straight
            for index in range(centerline.number_of_points):
                disc_label = centerline.l_points[index]
                if alignment_mode == 'length':
                    relative_position = centerline.dist_points[index]
                else:
                    relative_position = centerline.dist_points_rel[index]
                idx_closest = centerline_straight.get_closest_to_absolute_position(
                    disc_label,
                    relative_position,
                    backup_index=index,
                    backup_centerline=centerline_straight,
                    mode=alignment_mode)
                if idx_closest is not None:
                    lookup_curved2straight[index] = idx_closest
                else:
                    lookup_curved2straight[index] = 0
        # Zero the leading run of entries equal to their successor (scanning the first half only) ...
        for p in range(0, len(lookup_curved2straight) // 2):
            if lookup_curved2straight[p] == lookup_curved2straight[p + 1]:
                lookup_curved2straight[p] = 0
            else:
                break
        # ... and the trailing run of entries equal to their predecessor (scanning the second half only).
        for p in range(
                len(lookup_curved2straight) - 1,
                len(lookup_curved2straight) // 2, -1):
            if lookup_curved2straight[p] == lookup_curved2straight[p - 1]:
                lookup_curved2straight[p] = 0
            else:
                break
        lookup_curved2straight = np.array(lookup_curved2straight)

        # Reverse look-up table: straight centerline index -> curved centerline index, built the same way
        # (here an unmatched entry keeps its identity value instead of being set to 0).
        lookup_straight2curved = list(
            range(centerline_straight.number_of_points))
        if self.discs_input_filename != "":
            for index in range(centerline_straight.number_of_points):
                disc_label = centerline_straight.l_points[index]
                if alignment_mode == 'length':
                    relative_position = centerline_straight.dist_points[index]
                else:
                    relative_position = centerline_straight.dist_points_rel[
                        index]
                idx_closest = centerline.get_closest_to_absolute_position(
                    disc_label,
                    relative_position,
                    backup_index=index,
                    backup_centerline=centerline_straight,
                    mode=alignment_mode)
                if idx_closest is not None:
                    lookup_straight2curved[index] = idx_closest
        for p in range(0, len(lookup_straight2curved) // 2):
            if lookup_straight2curved[p] == lookup_straight2curved[p + 1]:
                lookup_straight2curved[p] = 0
            else:
                break
        for p in range(
                len(lookup_straight2curved) - 1,
                len(lookup_straight2curved) // 2, -1):
            if lookup_straight2curved[p] == lookup_straight2curved[p - 1]:
                lookup_straight2curved[p] = 0
            else:
                break
        lookup_straight2curved = np.array(lookup_straight2curved)

        # Create volumes containing curved and straight warping fields
        # (5D arrays: x, y, z, a singleton axis, and the 3 displacement components).
        data_warp_curved2straight = np.zeros((nx_s, ny_s, nz_s, 1, 3))
        data_warp_straight2curved = np.zeros((nx, ny, nz, 1, 3))

        # 5. compute transformations
        # Curved and straight images and the same dimensions, so we compute both warping fields at the same time.
        # b. determine which plane of spinal cord centreline it is included

        if self.curved2straight:
            # The curved->straight warp is defined on the straight grid: for every voxel of the straight
            # space (one axial slice at a time), compute where it lies in the curved space.
            for u in sct_progress_bar(range(nz_s)):
                x_s, y_s, z_s = np.mgrid[0:nx_s, 0:ny_s, u:u + 1]
                indexes_straight = np.array(
                    list(zip(x_s.ravel(), y_s.ravel(), z_s.ravel())))
                physical_coordinates_straight = image_centerline_straight.transfo_pix2phys(
                    indexes_straight)
                nearest_indexes_straight = centerline_straight.find_nearest_indexes(
                    physical_coordinates_straight)
                distances_straight = centerline_straight.get_distances_from_planes(
                    physical_coordinates_straight, nearest_indexes_straight)
                lookup = lookup_straight2curved[nearest_indexes_straight]
                # Voxels farther than threshold_distance from their plane, or without correspondence.
                indexes_out_distance_straight = np.logical_or(
                    np.logical_or(
                        distances_straight > self.threshold_distance,
                        distances_straight < -self.threshold_distance),
                    lookup == 0)
                projected_points_straight = centerline_straight.get_projected_coordinates_on_planes(
                    physical_coordinates_straight, nearest_indexes_straight)
                coord_in_planes_straight = centerline_straight.get_in_plans_coordinates(
                    projected_points_straight, nearest_indexes_straight)

                coord_straight2curved = centerline.get_inverse_plans_coordinates(
                    coord_in_planes_straight, lookup)
                displacements_straight = coord_straight2curved - physical_coordinates_straight
                # Invert Z coordinate as ITK & ANTs physical coordinate system is LPS- (RAI+)
                # while ours is LPI-
                # Refs: https://sourceforge.net/p/advants/discussion/840261/thread/2a1e9307/#fb5a
                #  https://www.slicer.org/wiki/Coordinate_systems
                displacements_straight[:, 2] = -displacements_straight[:, 2]
                # Sentinel displacement for invalid voxels (presumably large enough to map them outside of
                # the image -- TODO confirm).
                displacements_straight[indexes_out_distance_straight] = [
                    100000.0, 100000.0, 100000.0
                ]

                data_warp_curved2straight[indexes_straight[:, 0], indexes_straight[:, 1], indexes_straight[:, 2], 0, :]\
                    = -displacements_straight

        if self.straight2curved:
            # The straight->curved warp is defined on the (padded) curved grid.
            for u in sct_progress_bar(range(nz)):
                x, y, z = np.mgrid[0:nx, 0:ny, u:u + 1]
                indexes = np.array(list(zip(x.ravel(), y.ravel(), z.ravel())))
                physical_coordinates = image_centerline_pad.transfo_pix2phys(
                    indexes)
                nearest_indexes_curved = centerline.find_nearest_indexes(
                    physical_coordinates)
                distances_curved = centerline.get_distances_from_planes(
                    physical_coordinates, nearest_indexes_curved)
                lookup = lookup_curved2straight[nearest_indexes_curved]
                indexes_out_distance_curved = np.logical_or(
                    np.logical_or(distances_curved > self.threshold_distance,
                                  distances_curved < -self.threshold_distance),
                    lookup == 0)
                projected_points_curved = centerline.get_projected_coordinates_on_planes(
                    physical_coordinates, nearest_indexes_curved)
                coord_in_planes_curved = centerline.get_in_plans_coordinates(
                    projected_points_curved, nearest_indexes_curved)

                # Straight-space position: matching straight centerline point, shifted by the in-plane
                # coordinates (x, y) and by the distance to the plane (z).
                coord_curved2straight = centerline_straight.points[lookup]
                coord_curved2straight[:, 0:2] += coord_in_planes_curved[:, 0:2]
                coord_curved2straight[:, 2] += distances_curved

                displacements_curved = coord_curved2straight - physical_coordinates

                # Same Z inversion and sentinel value as for the curved->straight warp.
                displacements_curved[:, 2] = -displacements_curved[:, 2]
                displacements_curved[indexes_out_distance_curved] = [
                    100000.0, 100000.0, 100000.0
                ]

                data_warp_straight2curved[indexes[:, 0], indexes[:, 1],
                                          indexes[:, 2],
                                          0, :] = -displacements_curved

        # Creation of the safe zone based on pre-calculated safe boundaries
        coord_bound_curved_inf, coord_bound_curved_sup = image_centerline_pad.transfo_phys2pix(
            [[0, 0, bound_curved[0]]]), image_centerline_pad.transfo_phys2pix(
                [[0, 0, bound_curved[1]]])
        coord_bound_straight_inf, coord_bound_straight_sup = image_centerline_straight.transfo_phys2pix(
            [[0, 0,
              bound_straight[0]]]), image_centerline_straight.transfo_phys2pix(
                  [[0, 0, bound_straight[1]]])

        # Never true with the current hard-coded radius_safe = 0.0 (see above).
        if radius_safe > 0:
            data_warp_curved2straight[:, :, 0:coord_bound_straight_inf[0][2],
                                      0, :] = 100000.0
            data_warp_curved2straight[:, :, coord_bound_straight_sup[0][2]:,
                                      0, :] = 100000.0
            data_warp_straight2curved[:, :, 0:coord_bound_curved_inf[0][2],
                                      0, :] = 100000.0
            data_warp_straight2curved[:, :, coord_bound_curved_sup[0][2]:,
                                      0, :] = 100000.0

        # Generate warp files as a warping fields
        hdr_warp_s.set_intent('vector', (), '')
        hdr_warp_s.set_data_dtype('float32')
        hdr_warp.set_intent('vector', (), '')
        hdr_warp.set_data_dtype('float32')
        if self.curved2straight:
            img = Nifti1Image(data_warp_curved2straight, None, hdr_warp_s)
            save(img, 'tmp.curve2straight.nii.gz')
            logger.info('Warping field generated: tmp.curve2straight.nii.gz')

        if self.straight2curved:
            img = Nifti1Image(data_warp_straight2curved, None, hdr_warp)
            save(img, 'tmp.straight2curve.nii.gz')
            logger.info('Warping field generated: tmp.straight2curve.nii.gz')

        # Write the straight reference space; it is the '-r' reference of the transformations below.
        image_centerline_straight.save(fname_ref)
        if self.curved2straight:
            logger.info('Apply transformation to input image...')
            run_proc([
                'isct_antsApplyTransforms', '-d', '3', '-r', fname_ref, '-i',
                'data.nii', '-o', 'tmp.anat_rigid_warp.nii.gz', '-t',
                'tmp.curve2straight.nii.gz', '-n', 'BSpline[3]'
            ],
                     is_sct_binary=True,
                     verbose=verbose)

        if self.accuracy_results:
            # NOTE(review): this section reads tmp.curve2straight.nii.gz, which is only written when
            # self.curved2straight is true -- confirm callers never enable accuracy_results without it.
            time_accuracy_results = time.time()
            # compute the error between the straightened centerline/segmentation and the central vertical line.
            # Ideally, the error should be zero.
            # Apply deformation to input image
            logger.info('Apply transformation to centerline image...')
            run_proc([
                'isct_antsApplyTransforms', '-d', '3', '-r', fname_ref, '-i',
                'centerline.nii.gz', '-o', 'tmp.centerline_straight.nii.gz',
                '-t', 'tmp.curve2straight.nii.gz', '-n', 'NearestNeighbor'
            ],
                     is_sct_binary=True,
                     verbose=verbose)
            file_centerline_straight = Image('tmp.centerline_straight.nii.gz',
                                             verbose=verbose)
            nx, ny, nz, nt, px, py, pz, pt = file_centerline_straight.dim
            coordinates_centerline = file_centerline_straight.getNonZeroCoordinates(
                sorting='z')
            # Per-slice (x, y) centroid of the straightened centerline, each voxel being weighted by its
            # value relative to the mean value of the slice.
            mean_coord = []
            for z in range(coordinates_centerline[0].z,
                           coordinates_centerline[-1].z):
                temp_mean = [
                    coord.value for coord in coordinates_centerline
                    if coord.z == z
                ]
                if temp_mean:
                    mean_value = np.mean(temp_mean)
                    mean_coord.append(
                        np.mean([[
                            coord.x * coord.value / mean_value,
                            coord.y * coord.value / mean_value
                        ] for coord in coordinates_centerline if coord.z == z],
                                axis=0))

            # compute error between the straightened centerline and the straight line.
            x0 = file_centerline_straight.data.shape[0] / 2.0
            y0 = file_centerline_straight.data.shape[1] / 2.0
            count_mean = 0
            if number_of_points >= 10:
                # we don't include the four extrema because they are usually messy.
                mean_c = mean_coord[2:-2]
            else:
                mean_c = mean_coord
            for coord_z in mean_c:
                if not np.isnan(np.sum(coord_z)):
                    # Squared in-plane distance (scaled by the voxel size) to the centre of the volume.
                    dist = ((x0 - coord_z[0]) * px)**2 + (
                        (y0 - coord_z[1]) * py)**2
                    self.mse_straightening += dist
                    dist = np.sqrt(dist)
                    if dist > self.max_distance_straightening:
                        self.max_distance_straightening = dist
                    count_mean += 1
            if count_mean > 0:
                self.mse_straightening = np.sqrt(self.mse_straightening /
                                                 float(count_mean))
            else:
                # No valid slice centroid was accumulated: leave the metric untouched rather than dividing
                # by zero.
                logger.warning(
                    'No valid centerline point found to compute the straightening accuracy.'
                )

            self.elapsed_time_accuracy = time.time() - time_accuracy_results

        # come back to the original working directory
        os.chdir(curdir)

        # Generate output file (in current folder)
        # TODO: do not uncompress the warping field, it is too time consuming!
        logger.info('Generate output files...')
        if self.curved2straight:
            generate_output_file(
                os.path.join(path_tmp, "tmp.curve2straight.nii.gz"),
                os.path.join(self.path_output, "warp_curve2straight.nii.gz"),
                verbose)
        if self.straight2curved:
            generate_output_file(
                os.path.join(path_tmp, "tmp.straight2curve.nii.gz"),
                os.path.join(self.path_output, "warp_straight2curve.nii.gz"),
                verbose)

        # create ref_straight.nii.gz file that can be used by other SCT functions that need a straight reference space
        if self.curved2straight:
            copy(os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
                 os.path.join(self.path_output, "straight_ref.nii.gz"))
            # move straightened input file
            if fname_output == '':
                # Default output name: <input file name>_straight<input extension>
                fname_straight = generate_output_file(
                    os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
                    os.path.join(self.path_output,
                                 file_anat + "_straight" + ext_anat), verbose)
            else:
                fname_straight = generate_output_file(
                    os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
                    os.path.join(self.path_output, fname_output),
                    verbose)  # straightened anatomic

        # Remove temporary files
        if remove_temp_files:
            logger.info('Remove temporary files...')
            rmtree(path_tmp)

        if self.accuracy_results:
            logger.info('Maximum x-y error: {} mm'.format(
                self.max_distance_straightening))
            logger.info('Accuracy of straightening (MSE): {} mm'.format(
                self.mse_straightening))

        # display elapsed time
        self.elapsed_time = int(np.round(time.time() - start_time))

        return fname_straight