Ejemplo n.º 1
0
def moco(param):
    """
    Main function that performs motion correction.

    The input series is split along Z when the scan is flagged as sagittal (otherwise it is processed as a
    single block), then split along T. Every volume is passed to ``register()`` together with the target.
    Volumes for which ``register()`` reports a failure reuse the transformation of the closest volume that
    succeeded. Unless ``param.todo`` is ``'estimate'``, the corrected volumes are merged back along T (and
    along Z for sagittal data).

    :param param: parameter object. Attributes read here: ``file_data``, ``file_target``, ``mat_moco``,
        ``todo``, ``suffix``, ``verbose``, ``poly``, ``smooth``, ``gradStep``, ``metric``, ``sampling``,
        ``fname_mask``, ``is_sagittal``, ``iterAvg`` and ``interp``. The object is also forwarded to
        ``register()``.
    :return: tuple ``(file_mat, im_out)``. ``file_mat`` is a 2D object array (indexed ``[iz][it]``) holding
        the transformation file prefix of every volume. ``im_out`` is the merged motion-corrected image, or
        ``None`` when ``param.todo == 'estimate'`` (nothing is merged in that mode).
    """
    # retrieve parameters
    file_data = param.file_data
    folder_mat = param.mat_moco  # output folder of mat file
    todo = param.todo
    suffix = param.suffix
    verbose = param.verbose

    # other parameters
    file_mask = 'mask.nii'
    has_mask = param.fname_mask != ''

    printv('\nInput parameters:', verbose)
    printv('  Input file ............ ' + file_data, verbose)
    printv('  Reference file ........ ' + param.file_target, verbose)
    printv('  Polynomial degree ..... ' + param.poly, verbose)
    printv('  Smoothing kernel ...... ' + param.smooth, verbose)
    printv('  Gradient step ......... ' + param.gradStep, verbose)
    printv('  Metric ................ ' + param.metric, verbose)
    printv('  Sampling .............. ' + param.sampling, verbose)
    printv('  Todo .................. ' + todo, verbose)
    printv('  Mask  ................. ' + param.fname_mask, verbose)
    printv('  Output mat folder ..... ' + folder_mat, verbose)

    # The output folder may already exist (moco() is called several times with the same folder)
    os.makedirs(folder_mat, exist_ok=True)

    # Get size of data
    printv('\nData dimensions:', verbose)
    im_data = Image(file_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    printv(
        ('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt)),
        verbose)

    # copy file_target to a temporary file
    printv('\nCopy file_target to a temporary file...', verbose)
    file_target = "target.nii.gz"
    convert(param.file_target, file_target, verbose=0)

    # Check if user specified a mask
    if has_mask:
        # Check if this mask is soft (i.e., non-binary, such as a Gaussian mask)
        im_mask = Image(param.fname_mask)
        if not np.array_equal(im_mask.data, im_mask.data.astype(bool)):
            # If it is a soft mask, multiply the target by the soft mask.
            im = Image(file_target)
            im_masked = im.copy()
            im_masked.data = im.data * im_mask.data
            im_masked.save(verbose=0)  # silence warning about file overwriting

    # If scan is sagittal, split src and target along Z (slice)
    if param.is_sagittal:
        dim_sag = 2  # TODO: find it
        # z-split data (time series)
        file_data_splitZ = []
        for im_z in split_data(im_data, dim=dim_sag, squeeze_data=False):
            im_z.save(verbose=0)
            file_data_splitZ.append(im_z.absolutepath)
        # z-split target
        file_target_splitZ = []
        for im_targetz in split_data(Image(file_target), dim=dim_sag, squeeze_data=False):
            im_targetz.save(verbose=0)
            file_target_splitZ.append(im_targetz.absolutepath)
        # z-split mask (if exists)
        if has_mask:
            im_maskz_list = split_data(Image(file_mask), dim=dim_sag, squeeze_data=False)
            for im_maskz in im_maskz_list:
                im_maskz.save(verbose=0)
        # initialize file list for output matrices
        file_mat = np.empty((nz, nt), dtype=object)

    # axial orientation
    else:
        file_data_splitZ = [file_data]  # TODO: make it absolute like above
        file_target_splitZ = [file_target]  # TODO: make it absolute like above
        # initialize file list for output matrices
        file_mat = np.empty((1, nt), dtype=object)

        # deal with mask
        if has_mask:
            convert(param.fname_mask, file_mask, squeeze_data=False, verbose=0)
            im_maskz_list = [Image(file_mask)]  # use a list with single element

    # Loop across file list, where each file is either a 2D volume (if sagittal) or a 3D volume (otherwise)
    file_data_splitZ_moco = []
    # Stays None in 'estimate' mode, where no corrected data is merged; without this default the final
    # `return` would fail on an unbound name.
    im_out = None
    printv(
        '\nRegister. Loop across Z (note: there is only one Z if orientation is axial)'
    )
    last_z = len(file_data_splitZ) - 1
    for iz, fname_z in enumerate(file_data_splitZ):
        # Split data along T dimension
        file_data_splitZ_splitT = []
        for im_zt in split_data(Image(fname_z), dim=3):
            im_zt.save(verbose=0)
            file_data_splitZ_splitT.append(im_zt.absolutepath)

        # Motion correction: initialization
        file_data_splitZ_splitT_moco = []
        failed_transfo = [0] * nt

        # Motion correction: Loop across T
        for it in sct_progress_bar(range(nt),
                                   unit='iter',
                                   unit_scale=False,
                                   desc="Z=" + str(iz) + "/" + str(last_z),
                                   ascii=False,
                                   ncols=80):

            # prefix of the transformation file for this (Z, T) volume
            file_mat[iz][it] = os.path.join(
                folder_mat,
                "mat.Z") + str(iz).zfill(4) + 'T' + str(it).zfill(4)
            file_data_splitZ_splitT_moco.append(
                add_suffix(file_data_splitZ_splitT[it], '_moco'))
            # deal with masking (except in the 'apply' case, where masking is irrelevant)
            input_mask = None
            if has_mask and todo != 'apply':
                # Check if mask is binary
                if np.array_equal(im_maskz_list[iz].data,
                                  im_maskz_list[iz].data.astype(bool)):
                    # If it is, pass this mask into register() to be used
                    input_mask = im_maskz_list[iz]
                else:
                    # If not, do not pass this mask into register() because ANTs cannot handle non-binary masks.
                    #  Instead, multiply the input data by the Gaussian mask.
                    im = Image(file_data_splitZ_splitT[it])
                    im_masked = im.copy()
                    im_masked.data = im.data * im_maskz_list[iz].data
                    im_masked.save(verbose=0)  # silence warning about file overwriting

            # run 3D registration
            failed_transfo[it] = register(param,
                                          file_data_splitZ_splitT[it],
                                          file_target_splitZ[iz],
                                          file_mat[iz][it],
                                          file_data_splitZ_splitT_moco[it],
                                          im_mask=input_mask)

            # average registered volume with target image
            # N.B. use weighted averaging: (target * nb_it + moco) / (nb_it + 1)
            # Only the first 10 volumes contribute to the running target.
            if param.iterAvg and it < 10 and failed_transfo[it] == 0 and todo != 'apply':
                im_targetz = Image(file_target_splitZ[iz])
                data_mocoz = Image(file_data_splitZ_splitT_moco[it]).data
                im_targetz.data = (im_targetz.data * (it + 1) + data_mocoz) / (it + 2)
                im_targetz.save(verbose=0)

        # Replace failed transformation with the closest good one
        fT = [i for i, failed in enumerate(failed_transfo) if failed == 1]
        gT = [i for i, failed in enumerate(failed_transfo) if failed == 0]
        for it_failed in fT:
            if not gT:
                # exit program if no transformation exists.
                printv(
                    '\nERROR in ' + os.path.basename(__file__) +
                    ': No good transformation exist. Exit program.\n', verbose,
                    'error')
                sys.exit(2)
            # closest successful volume in time (the earliest one wins on ties)
            it_good = min(gT, key=lambda i: abs(i - it_failed))
            printv(
                '  transfo #' + str(it_failed) + ' --> use transfo #' +
                str(it_good), verbose)
            # copy transformation
            copy(file_mat[iz][it_good] + 'Warp.nii.gz',
                 file_mat[iz][it_failed] + 'Warp.nii.gz')
            # apply transformation
            sct_apply_transfo.main(argv=[
                '-i', file_data_splitZ_splitT[it_failed], '-d', file_target,
                '-w', file_mat[iz][it_failed] + 'Warp.nii.gz', '-o',
                file_data_splitZ_splitT_moco[it_failed], '-x', param.interp
            ])

        # Merge data along T
        file_data_splitZ_moco.append(add_suffix(fname_z, suffix))
        if todo != 'estimate':
            im_data_splitZ_splitT_moco = [
                Image(fname) for fname in file_data_splitZ_splitT_moco
            ]
            im_out = concat_data(im_data_splitZ_splitT_moco, 3)
            im_out.absolutepath = file_data_splitZ_moco[iz]
            im_out.save(verbose=0)

    # If sagittal, merge along Z. The per-Z files are only written above when todo != 'estimate',
    # so there is nothing to merge in estimate-only mode.
    if param.is_sagittal and todo != 'estimate':
        # TODO: im_out.dim is incorrect: Z value is one
        im_data_splitZ_moco = [Image(fname) for fname in file_data_splitZ_moco]
        im_out = concat_data(im_data_splitZ_moco, 2)
        dirname, basename, ext = extract_fname(file_data)
        path_out = os.path.join(dirname, basename + suffix + ext)
        im_out.absolutepath = path_out
        im_out.save(verbose=0)

    return file_mat, im_out
Ejemplo n.º 2
0
def moco_wrapper(param):
    """
    Wrapper that performs motion correction.

    Workflow: copy the input data (and optional mask) into a temporary folder, split the series along T,
    average volumes within groups, estimate motion across the group averages (and across the b=0 volumes
    for diffusion data) with ``moco()``, then apply the resulting transformations to every volume and
    write the outputs to ``param.path_out``.

    Side effects on ``param``: ``fname_mask`` (replaced by the name of the copy in the temporary folder
    when a mask is given), ``is_sagittal``, ``suffix_mat`` and possibly ``group_size`` are overwritten.
    The process working directory is changed to the temporary folder during processing and is restored
    before returning, including when an error occurs.

    :param param: ParamMoco class
    :return: None
    """
    file_data = 'data.nii'  # corresponds to the full input data (e.g. dmri or fmri)
    file_data_dirname, file_data_basename, _ = extract_fname(file_data)
    file_b0 = 'b0.nii'
    file_datasub = 'datasub.nii'  # corresponds to the full input data minus the b=0 scans (if param.is_diffusion=True)
    file_datasubgroup = 'datasub-groups.nii'  # concatenation of the average of each file_datasub
    file_mask = 'mask.nii'
    file_moco_params_csv = 'moco_params.tsv'
    file_moco_params_x = 'moco_params_x.nii.gz'
    file_moco_params_y = 'moco_params_y.nii.gz'
    ext_data = '.nii.gz'  # workaround "too many open files" by slurping the data
    # TODO: check if .nii can be used
    mat_final = 'mat_final/'

    # Start timer
    start_time = time.time()

    printv('\nInput parameters:', param.verbose)
    printv('  Input file ............ ' + param.fname_data, param.verbose)
    printv('  Group size ............ {}'.format(param.group_size),
           param.verbose)

    path_tmp = tmp_create(basename="moco")

    # Copying input data to tmp folder
    printv('\nCopying input data to tmp folder and convert to nii...',
           param.verbose)
    convert(param.fname_data, os.path.join(path_tmp, file_data))
    if param.fname_mask != '':
        convert(param.fname_mask,
                os.path.join(path_tmp, file_mask),
                verbose=param.verbose)
        # Update field in param (because used later in another function, and param class will be passed)
        param.fname_mask = file_mask

    # Build absolute output path and go to tmp folder
    curdir = os.getcwd()
    path_out_abs = os.path.abspath(param.path_out)
    os.chdir(path_tmp)
    try:
        # Get dimensions of data
        printv('\nGet dimensions of data...', param.verbose)
        im_data = Image(file_data)
        nx, ny, nz, nt, px, py, pz, pt = im_data.dim
        printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), param.verbose)

        # Get orientation
        printv('\nData orientation: ' + im_data.orientation, param.verbose)
        if im_data.orientation[2] in 'LR':
            param.is_sagittal = True
            printv('  Treated as sagittal')
        elif im_data.orientation[2] in 'IS':
            param.is_sagittal = False
            printv('  Treated as axial')
        else:
            param.is_sagittal = False
            printv(
                'WARNING: Orientation seems to be neither axial nor sagittal. Treated as axial.'
            )

        printv(
            "\nSet suffix of transformation file name, which depends on the orientation:"
        )
        if param.is_sagittal:
            param.suffix_mat = '0GenericAffine.mat'
            printv(
                "Orientation is sagittal, suffix is '{}'. The image is split across the R-L direction, and the "
                "estimated transformation is a 2D affine transfo.".format(
                    param.suffix_mat))
        else:
            param.suffix_mat = 'Warp.nii.gz'
            printv(
                "Orientation is axial, suffix is '{}'. The estimated transformation is a 3D warping field, which is "
                "composed of a stack of 2D Tx-Ty transformations".format(
                    param.suffix_mat))

        # Adjust group size in case of sagittal scan
        if param.is_sagittal and param.group_size != 1:
            printv(
                'For sagittal data group_size should be one for more robustness. Forcing group_size=1.',
                1, 'warning')
            param.group_size = 1

        if param.is_diffusion:
            # Identify b=0 and DWI images
            index_b0, index_dwi, nb_b0, nb_dwi = \
                sct_dmri_separate_b0_and_dwi.identify_b0(param.fname_bvecs, param.fname_bvals, param.bval_min,
                                                         param.verbose)

            # check if dmri and bvecs are the same size
            if nb_b0 + nb_dwi != nt:
                printv(
                    '\nERROR in ' + os.path.basename(__file__) +
                    ': Size of data (' + str(nt) + ') and size of bvecs (' +
                    str(nb_b0 + nb_dwi) +
                    ') are not the same. Check your bvecs file.\n', 1, 'error')
                sys.exit(2)

        # ==============================================================================================================
        # Prepare data (mean/groups...)
        # ==============================================================================================================

        # Split into T dimension
        printv('\nSplit along T dimension...', param.verbose)
        im_data_split_list = split_data(im_data, 3)
        for im in im_data_split_list:
            x_dirname, x_basename, x_ext = extract_fname(im.absolutepath)
            im.absolutepath = os.path.join(x_dirname, x_basename + ".nii.gz")
            im.save()

        if param.is_diffusion:
            # Merge and average b=0 images
            printv('\nMerge and average b=0 data...', param.verbose)
            im_b0_list = [im_data_split_list[index_b0[it]] for it in range(nb_b0)]
            im_b0 = concat_data(im_b0_list, 3).save(file_b0, verbose=0)
            # Average across time
            im_b0.mean(dim=3).save(add_suffix(file_b0, '_mean'))

            n_moco = nb_dwi  # set number of data to perform moco on (using grouping)
            index_moco = index_dwi

        # If not a diffusion scan, we will motion-correct all volumes
        else:
            n_moco = nt
            index_moco = list(range(0, nt))

        nb_groups = n_moco // param.group_size

        # Generate groups indexes
        group_indexes = []
        for iGroup in range(nb_groups):
            group_indexes.append(index_moco[(iGroup * param.group_size):((iGroup + 1) * param.group_size)])

        # add the remaining images to a new last group (in case the total number of image is not divisible by
        # group_size)
        nb_remaining = n_moco % param.group_size  # number of remaining images
        if nb_remaining > 0:
            nb_groups += 1
            group_indexes.append(index_moco[len(index_moco) - nb_remaining:len(index_moco)])

        _, file_dwi_basename, _ = extract_fname(file_datasub)
        # Group data
        list_file_group = []
        for iGroup in sct_progress_bar(range(nb_groups),
                                       unit='iter',
                                       unit_scale=False,
                                       desc="Merge within groups",
                                       ascii=False,
                                       ncols=80):
            # get index
            index_moco_i = group_indexes[iGroup]
            # concatenate images across time, within this group
            file_dwi_merge_i = file_dwi_basename + '_' + str(iGroup) + ext_data
            im_dwi_list = [im_data_split_list[it] for it in index_moco_i]
            im_dwi_out = concat_data(im_dwi_list, 3).save(file_dwi_merge_i, verbose=0)
            # Average across time
            list_file_group.append(file_dwi_basename + '_' + str(iGroup) + '_mean' + ext_data)
            im_dwi_out.mean(dim=3).save(list_file_group[-1])

        # Merge across groups
        printv('\nMerge across groups...', param.verbose)
        im_dw_list = [Image(fname) for fname in list_file_group]
        concat_data(im_dw_list, 3).save(file_datasubgroup, verbose=0)

        # Cleanup: drop the references to the split volumes, which are no longer needed
        del im, im_data_split_list

        # ==============================================================================================================
        # Estimate moco
        # ==============================================================================================================

        # Initialize another class instance that will be passed on to the moco() function
        param_moco = deepcopy(param)

        if param.is_diffusion:
            # Estimate moco on b0 groups
            printv(
                '\n-------------------------------------------------------------------------------',
                param.verbose)
            printv('  Estimating motion on b=0 images...', param.verbose)
            printv(
                '-------------------------------------------------------------------------------',
                param.verbose)
            param_moco.file_data = file_b0
            # Identify target image
            if index_moco[0] != 0:
                # If first DWI is not the first volume (most common), then there is a least one b=0 image before.
                # In that case select it as the target image for registration of all b=0
                param_moco.file_target = os.path.join(
                    file_data_dirname, file_data_basename + '_T' +
                    str(index_b0[index_moco[0] - 1]).zfill(4) + ext_data)
            else:
                # If first DWI is the first volume, then the target b=0 is the first b=0 from the index_b0.
                param_moco.file_target = os.path.join(
                    file_data_dirname, file_data_basename + '_T' +
                    str(index_b0[0]).zfill(4) + ext_data)
            # Run moco
            param_moco.path_out = ''
            param_moco.todo = 'estimate_and_apply'
            param_moco.mat_moco = 'mat_b0groups'
            file_mat_b0, _ = moco(param_moco)

        # Estimate moco across groups
        printv(
            '\n-------------------------------------------------------------------------------',
            param.verbose)
        printv('  Estimating motion across groups...', param.verbose)
        printv(
            '-------------------------------------------------------------------------------',
            param.verbose)
        param_moco.file_data = file_datasubgroup
        # target is the first volume (closest to the first b=0 if DWI scan)
        param_moco.file_target = list_file_group[0]
        param_moco.path_out = ''
        param_moco.todo = 'estimate_and_apply'
        param_moco.mat_moco = 'mat_groups'
        file_mat_datasub_group, _ = moco(param_moco)

        # Spline Regularization along T
        if param.spline_fitting:
            # TODO: fix this scenario (haven't touched that code for a while-- it is probably buggy)
            raise NotImplementedError()

        # ==============================================================================================================
        # Apply moco
        # ==============================================================================================================

        # If group_size>1, assign transformation to each individual ungrouped 3d volume
        if param.group_size > 1:
            file_mat_datasub = []
            for iz in range(len(file_mat_datasub_group)):
                # duplicate by factor group_size the transformation file for each it
                #  example: [mat.Z0000T0001Warp.nii] --> [mat.Z0000T0001Warp.nii, mat.Z0000T0001Warp.nii] for
                #  group_size=2
                file_mat_datasub.append(
                    functools.reduce(operator.iconcat,
                                     [[i] * param.group_size
                                      for i in file_mat_datasub_group[iz]], []))
        else:
            file_mat_datasub = file_mat_datasub_group

        # Copy transformations to mat_final folder and rename them appropriately
        copy_mat_files(nt, file_mat_datasub, index_moco, mat_final, param)
        if param.is_diffusion:
            copy_mat_files(nt, file_mat_b0, index_b0, mat_final, param)

        # Apply moco on all dmri data
        printv(
            '\n-------------------------------------------------------------------------------',
            param.verbose)
        printv('  Apply moco', param.verbose)
        printv(
            '-------------------------------------------------------------------------------',
            param.verbose)
        param_moco.file_data = file_data
        # reference for reslicing into proper coordinate system
        param_moco.file_target = list_file_group[0]
        param_moco.path_out = ''  # TODO not used in moco()
        param_moco.mat_moco = mat_final
        param_moco.todo = 'apply'
        file_mat_data, im_moco = moco(param_moco)

        # copy geometric information from header
        # NB: this is required because WarpImageMultiTransform in 2D mode wrongly sets pixdim(3) to "1".
        im_moco.header = im_data.header
        im_moco.save(verbose=0)

        # Average across time
        if param.is_diffusion:
            # generate b0_moco_mean and dwi_moco_mean
            args = [
                '-i', im_moco.absolutepath, '-bvec', param.fname_bvecs, '-a', '1',
                '-v', '0'
            ]
            if param.fname_bvals != '':
                # if bvals file is provided
                args += ['-bval', param.fname_bvals]
            fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean = sct_dmri_separate_b0_and_dwi.main(
                argv=args)
        else:
            fname_moco_mean = add_suffix(im_moco.absolutepath, '_mean')
            im_moco.mean(dim=3).save(fname_moco_mean)

        # Extract and output the motion parameters (doesn't work for sagittal orientation)
        printv('Extract motion parameters...')
        if param.output_motion_param:
            if param.is_sagittal:
                printv(
                    'Motion parameters cannot be generated for sagittal images.',
                    1, 'warning')
            else:
                files_warp_X, files_warp_Y = [], []
                moco_param = []
                for fname_warp in file_mat_data[0]:
                    # Cropping the image to keep only one voxel in the XY plane
                    im_warp = Image(fname_warp + param.suffix_mat)
                    im_warp.data = np.expand_dims(np.expand_dims(
                        im_warp.data[0, 0, :, :, :], axis=0),
                                                  axis=0)

                    # Separating the three components and saving X and Y only (Z is equal to 0 by default).
                    im_warp_XYZ = multicomponent_split(im_warp)

                    fname_warp_crop_X = fname_warp + '_crop_X_' + param.suffix_mat
                    im_warp_XYZ[0].save(fname_warp_crop_X)
                    files_warp_X.append(fname_warp_crop_X)

                    fname_warp_crop_Y = fname_warp + '_crop_Y_' + param.suffix_mat
                    im_warp_XYZ[1].save(fname_warp_crop_Y)
                    files_warp_Y.append(fname_warp_crop_Y)

                    # Calculating the slice-wise average moco estimate to provide a QC file
                    moco_param.append([
                        np.mean(np.ravel(im_warp_XYZ[0].data)),
                        np.mean(np.ravel(im_warp_XYZ[1].data))
                    ])

                # Concatenating the moco parameters into a time series for X and Y components.
                im_warp_X = [Image(fname) for fname in files_warp_X]
                im_warp_concat = concat_data(im_warp_X, dim=3)
                im_warp_concat.save(file_moco_params_x)

                im_warp_Y = [Image(fname) for fname in files_warp_Y]
                im_warp_concat = concat_data(im_warp_Y, dim=3)
                im_warp_concat.save(file_moco_params_y)

                # Writing a TSV file with the slicewise average estimate of the moco parameters. Useful for QC
                with open(file_moco_params_csv, 'wt') as out_file:
                    tsv_writer = csv.writer(out_file, delimiter='\t')
                    tsv_writer.writerow(['X', 'Y'])
                    for mocop in moco_param:
                        tsv_writer.writerow([mocop[0], mocop[1]])

        # Generate output files
        printv('\nGenerate output files...', param.verbose)
        fname_moco = os.path.join(
            path_out_abs,
            add_suffix(os.path.basename(param.fname_data), param.suffix))
        generate_output_file(im_moco.absolutepath, fname_moco)
        if param.is_diffusion:
            generate_output_file(fname_b0_mean, add_suffix(fname_moco, '_b0_mean'))
            generate_output_file(fname_dwi_mean,
                                 add_suffix(fname_moco, '_dwi_mean'))
        else:
            generate_output_file(fname_moco_mean, add_suffix(fname_moco, '_mean'))
        # The TSV file only exists if motion parameters were extracted above
        if os.path.exists(file_moco_params_csv):
            generate_output_file(file_moco_params_x,
                                 os.path.join(path_out_abs, file_moco_params_x),
                                 squeeze_data=False)
            generate_output_file(file_moco_params_y,
                                 os.path.join(path_out_abs, file_moco_params_y),
                                 squeeze_data=False)
            generate_output_file(file_moco_params_csv,
                                 os.path.join(path_out_abs, file_moco_params_csv))
    finally:
        # come back to working directory, even if processing was interrupted by an error or sys.exit()
        os.chdir(curdir)

    # Delete temporary files. This is done after leaving the temporary folder, because a directory that is
    # the current working directory cannot be removed on some platforms (e.g. Windows).
    if param.remove_temp_files == 1:
        printv('\nDelete temporary files...', param.verbose)
        rmtree(path_tmp, verbose=param.verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    printv(
        '\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's',
        param.verbose)

    display_viewer_syntax([
        os.path.join(
            param.path_out,
            add_suffix(os.path.basename(param.fname_data), param.suffix)),
        param.fname_data
    ],
                          mode='ortho,ortho')
Ejemplo n.º 3
0
    def apply(self):
        # Initialization
        fname_src = self.input_filename  # source image (moving)
        list_warp = self.list_warp  # list of warping fields
        fname_out = self.output_filename  # output
        fname_dest = self.fname_dest  # destination image (fix)
        verbose = self.verbose
        remove_temp_files = self.remove_temp_files
        crop_reference = self.crop  # if = 1, put 0 everywhere around warping field, if = 2, real crop

        islabel = False
        if self.interp == 'label':
            islabel = True
            self.interp = 'nn'

        interp = get_interpolation('isct_antsApplyTransforms', self.interp)

        # Parse list of warping fields
        printv('\nParse list of warping fields...', verbose)
        use_inverse = []
        fname_warp_list_invert = []
        # list_warp = list_warp.replace(' ', '')  # remove spaces
        # list_warp = list_warp.split(",")  # parse with comma
        for idx_warp, path_warp in enumerate(self.list_warp):
            # Check if this transformation should be inverted
            if path_warp in self.list_warpinv:
                use_inverse.append('-i')
                # list_warp[idx_warp] = path_warp[1:]  # remove '-'
                fname_warp_list_invert += [[
                    use_inverse[idx_warp], list_warp[idx_warp]
                ]]
            else:
                use_inverse.append('')
                fname_warp_list_invert += [[path_warp]]
            path_warp = list_warp[idx_warp]
            if path_warp.endswith((".nii", ".nii.gz")) \
                    and Image(list_warp[idx_warp]).header.get_intent()[0] != 'vector':
                raise ValueError(
                    "Displacement field in {} is invalid: should be encoded"
                    " in a 5D file with vector intent code"
                    " (see https://nifti.nimh.nih.gov/pub/dist/src/niftilib/nifti1.h"
                    .format(path_warp))
        # need to check if last warping field is an affine transfo
        isLastAffine = False
        path_fname, file_fname, ext_fname = extract_fname(
            fname_warp_list_invert[-1][-1])
        if ext_fname in ['.txt', '.mat']:
            isLastAffine = True

        # check if destination file is 3d
        # check_dim(fname_dest, dim_lst=[3]) # PR 2598: we decided to skip this line.

        # N.B. Here we take the inverse of the warp list, because sct_WarpImageMultiTransform concatenates in the reverse order
        fname_warp_list_invert.reverse()
        fname_warp_list_invert = functools.reduce(lambda x, y: x + y,
                                                  fname_warp_list_invert)

        # Extract path, file and extension
        path_src, file_src, ext_src = extract_fname(fname_src)
        path_dest, file_dest, ext_dest = extract_fname(fname_dest)

        # Get output folder and file name
        if fname_out == '':
            path_out = ''  # output in user's current directory
            file_out = file_src + '_reg'
            ext_out = ext_src
            fname_out = os.path.join(path_out, file_out + ext_out)

        # Get dimensions of data
        printv('\nGet dimensions of data...', verbose)
        img_src = Image(fname_src)
        nx, ny, nz, nt, px, py, pz, pt = img_src.dim
        # nx, ny, nz, nt, px, py, pz, pt = get_dimension(fname_src)
        printv(
            '  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' +
            str(nt), verbose)

        # if 3d
        if nt == 1:
            # Apply transformation
            printv('\nApply transformation...', verbose)
            if nz in [0, 1]:
                dim = '2'
            else:
                dim = '3'
            # if labels, dilate before resampling
            if islabel:
                printv("\nDilate labels before warping...")
                path_tmp = tmp_create(basename="apply_transfo")
                fname_dilated_labels = os.path.join(path_tmp,
                                                    "dilated_data.nii")
                # dilate points
                dilate(Image(fname_src), 4, 'ball').save(fname_dilated_labels)
                fname_src = fname_dilated_labels

            printv(
                "\nApply transformation and resample to destination space...",
                verbose)
            run_proc([
                'isct_antsApplyTransforms', '-d', dim, '-i', fname_src, '-o',
                fname_out, '-t'
            ] + fname_warp_list_invert + ['-r', fname_dest] + interp,
                     is_sct_binary=True)

        # if 4d, loop across the T dimension
        else:
            if islabel:
                raise NotImplementedError

            dim = '4'
            path_tmp = tmp_create(basename="apply_transfo")

            # convert to nifti into temp folder
            printv('\nCopying input data to tmp folder and convert to nii...',
                   verbose)
            img_src.save(os.path.join(path_tmp, "data.nii"))
            copy(fname_dest, os.path.join(path_tmp, file_dest + ext_dest))
            fname_warp_list_tmp = []
            for fname_warp in list_warp:
                path_warp, file_warp, ext_warp = extract_fname(fname_warp)
                copy(fname_warp, os.path.join(path_tmp, file_warp + ext_warp))
                fname_warp_list_tmp.append(file_warp + ext_warp)
            fname_warp_list_invert_tmp = fname_warp_list_tmp[::-1]

            curdir = os.getcwd()
            os.chdir(path_tmp)

            # split along T dimension
            printv('\nSplit along T dimension...', verbose)

            im_dat = Image('data.nii')
            im_header = im_dat.hdr
            data_split_list = sct_image.split_data(im_dat, 3)
            for im in data_split_list:
                im.save()

            # apply transfo
            printv('\nApply transformation to each 3D volume...', verbose)
            for it in range(nt):
                file_data_split = 'data_T' + str(it).zfill(4) + '.nii'
                file_data_split_reg = 'data_reg_T' + str(it).zfill(4) + '.nii'

                status, output = run_proc([
                    'isct_antsApplyTransforms',
                    '-d',
                    '3',
                    '-i',
                    file_data_split,
                    '-o',
                    file_data_split_reg,
                    '-t',
                ] + fname_warp_list_invert_tmp + [
                    '-r',
                    file_dest + ext_dest,
                ] + interp,
                                          verbose,
                                          is_sct_binary=True)

            # Merge files back
            printv('\nMerge file back...', verbose)
            import glob
            path_out, name_out, ext_out = extract_fname(fname_out)
            # im_list = [Image(file_name) for file_name in glob.glob('data_reg_T*.nii')]
            # concat_data use to take a list of image in input, now takes a list of file names to open the files one by one (see issue #715)
            fname_list = glob.glob('data_reg_T*.nii')
            fname_list.sort()
            im_list = [Image(fname) for fname in fname_list]
            im_out = sct_image.concat_data(im_list, 3, im_header['pixdim'])
            im_out.save(name_out + ext_out)

            os.chdir(curdir)
            generate_output_file(os.path.join(path_tmp, name_out + ext_out),
                                 fname_out)
            # Delete temporary folder if specified
            if remove_temp_files:
                printv('\nRemove temporary files...', verbose)
                rmtree(path_tmp, verbose=verbose)

        # Copy affine matrix from destination space to make sure qform/sform are the same
        printv(
            "Copy affine matrix from destination space to make sure qform/sform are the same.",
            verbose)
        im_src_reg = Image(fname_out)
        im_src_reg.copy_qform_from_ref(Image(fname_dest))
        im_src_reg.save(
            verbose=0
        )  # set verbose=0 to avoid warning message about rewriting file

        if islabel:
            printv(
                "\nTake the center of mass of each registered dilated labels..."
            )
            labeled_img = cubic_to_point(im_src_reg)
            labeled_img.save(path=fname_out)
            if remove_temp_files:
                printv('\nRemove temporary files...', verbose)
                rmtree(path_tmp, verbose=verbose)

        # Crop the resulting image using dimensions from the warping field
        warping_field = fname_warp_list_invert[-1]
        # If the last transformation is not an affine transfo, we need to compute the matrix space of the concatenated
        # warping field
        if not isLastAffine and crop_reference in [1, 2]:
            printv('Last transformation is not affine.')
            if crop_reference in [1, 2]:
                # Extract only the first ndim of the warping field
                img_warp = Image(warping_field)
                if dim == '2':
                    img_warp_ndim = Image(img_src.data[:, :], hdr=img_warp.hdr)
                elif dim in ['3', '4']:
                    img_warp_ndim = Image(img_src.data[:, :, :],
                                          hdr=img_warp.hdr)
                # Set zero to everything outside the warping field
                cropper = ImageCropper(Image(fname_out))
                cropper.get_bbox_from_ref(img_warp_ndim)
                if crop_reference == 1:
                    printv(
                        'Cropping strategy is: keep same matrix size, put 0 everywhere around warping field'
                    )
                    img_out = cropper.crop(background=0)
                elif crop_reference == 2:
                    printv(
                        'Cropping strategy is: crop around warping field (the size of warping field will '
                        'change)')
                    img_out = cropper.crop()
                img_out.save(fname_out)

        display_viewer_syntax([fname_dest, fname_out], verbose=verbose)
Ejemplo n.º 4
0
def main(args=None):
    """
    Separate the b=0 and DWI volumes of a 4D diffusion dataset.

    The input data are converted to NIfTI inside a temporary folder and split
    along the T dimension. The volumes listed by ``identify_b0`` are then
    merged into one b=0 file and one DWI file, each optionally averaged along
    T with ``sct_maths``. The results are finally copied to the output folder.

    The current working directory is temporarily changed to the temporary
    folder; it is restored even if one of the processing steps fails.

    :param args: list of command-line arguments. If empty or None,
        ``sys.argv[1:]`` is parsed instead, and the help is displayed when no
        argument was given at all.
    :return: fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean (absolute
        paths). The two ``*_mean`` paths are returned even when averaging was
        not requested, in which case the corresponding files are not generated.
    """
    # initialize parameters
    param = Param()
    # call main function
    parser = get_parser()
    if args:
        arguments = parser.parse_args(args)
    else:
        # no explicit arguments: parse sys.argv, or show the help if it is empty
        arguments = parser.parse_args(
            args=None if sys.argv[1:] else ['--help'])

    fname_data = arguments.i
    fname_bvecs = arguments.bvec
    average = arguments.a
    verbose = int(arguments.v)
    init_sct(log_level=verbose, update=True)  # Update log level
    remove_temp_files = arguments.r
    path_out = arguments.ofolder

    fname_bvals = arguments.bval
    if arguments.bvalmin:
        param.bval_min = arguments.bvalmin

    # Initialization
    start_time = time.time()

    # printv(arguments)
    printv('\nInput parameters:', verbose)
    printv('  input file ............' + fname_data, verbose)
    printv('  bvecs file ............' + fname_bvecs, verbose)
    printv('  bvals file ............' + fname_bvals, verbose)
    printv('  average ...............' + str(average), verbose)

    # Get full path
    fname_data = os.path.abspath(fname_data)
    fname_bvecs = os.path.abspath(fname_bvecs)
    if fname_bvals:
        fname_bvals = os.path.abspath(fname_bvals)

    # Extract file name and extension (the folder of the input is not needed)
    _, file_data, ext_data = extract_fname(fname_data)

    # create temporary folder
    path_tmp = tmp_create(basename="dmri_separate")

    # copy files into tmp folder and convert to nifti
    printv('\nCopy files into temporary folder...', verbose)
    ext = '.nii'
    dmri_name = 'dmri'
    b0_name = file_data + '_b0'
    b0_mean_name = b0_name + '_mean'
    dwi_name = file_data + '_dwi'
    dwi_mean_name = dwi_name + '_mean'

    if not convert(fname_data, os.path.join(path_tmp, dmri_name + ext)):
        printv('ERROR in convert.', 1, 'error')
    copy(fname_bvecs, os.path.join(path_tmp, "bvecs"), verbose=verbose)

    # go to tmp folder
    curdir = os.getcwd()
    os.chdir(path_tmp)
    try:
        # Get size of data
        im_dmri = Image(dmri_name + ext)
        printv('\nGet dimensions data...', verbose)
        nx, ny, nz, nt, px, py, pz, pt = im_dmri.dim
        printv(
            '.. ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' +
            str(nt), verbose)

        # Identify b=0 and DWI images
        printv(fname_bvals)
        index_b0, index_dwi, nb_b0, nb_dwi = identify_b0(
            fname_bvecs, fname_bvals, param.bval_min, verbose)

        # Split into T dimension
        printv('\nSplit along T dimension...', verbose)
        im_dmri_split_list = split_data(im_dmri, 3)
        for im_d in im_dmri_split_list:
            im_d.save()

        # Merge b=0 images
        printv('\nMerge b=0...', verbose)
        fname_b0_list = [
            dmri_name + '_T' + str(index_b0[it]).zfill(4) + ext
            for it in range(nb_b0)
        ]
        concat_data(fname_b0_list, 3).save(b0_name + ext)

        # Average b=0 images
        if average:
            printv('\nAverage b=0...', verbose)
            run_proc([
                'sct_maths', '-i', b0_name + ext, '-o', b0_mean_name + ext,
                '-mean', 't'
            ], verbose)

        # Merge DWI
        fname_dwi_list = [
            dmri_name + '_T' + str(index_dwi[it]).zfill(4) + ext
            for it in range(nb_dwi)
        ]
        concat_data(fname_dwi_list, 3).save(dwi_name + ext)

        # Average DWI images
        if average:
            printv('\nAverage DWI...', verbose)
            run_proc([
                'sct_maths', '-i', dwi_name + ext, '-o', dwi_mean_name + ext,
                '-mean', 't'
            ], verbose)
    finally:
        # come back, even if one of the steps above failed, so that the
        # caller is not left inside the temporary folder
        os.chdir(curdir)

    # Generate output files
    fname_b0 = os.path.abspath(os.path.join(path_out, b0_name + ext_data))
    fname_dwi = os.path.abspath(os.path.join(path_out, dwi_name + ext_data))
    fname_b0_mean = os.path.abspath(
        os.path.join(path_out, b0_mean_name + ext_data))
    fname_dwi_mean = os.path.abspath(
        os.path.join(path_out, dwi_mean_name + ext_data))
    printv('\nGenerate output files...', verbose)
    generate_output_file(os.path.join(path_tmp, b0_name + ext),
                         fname_b0,
                         verbose=verbose)
    generate_output_file(os.path.join(path_tmp, dwi_name + ext),
                         fname_dwi,
                         verbose=verbose)
    if average:
        generate_output_file(os.path.join(path_tmp, b0_mean_name + ext),
                             fname_b0_mean,
                             verbose=verbose)
        generate_output_file(os.path.join(path_tmp, dwi_mean_name + ext),
                             fname_dwi_mean,
                             verbose=verbose)

    # Remove temporary files
    if remove_temp_files == 1:
        printv('\nRemove temporary files...', verbose)
        rmtree(path_tmp, verbose=verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    printv(
        '\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's',
        verbose)

    return fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean
Ejemplo n.º 5
0
def main(argv=None):
    """
    Separate the b=0 and DWI volumes of a 4D diffusion dataset.

    The input data are converted to NIfTI inside a temporary folder and split
    along the T dimension. The volumes listed by ``identify_b0`` are then
    merged into one b=0 file and one DWI file, each optionally averaged along
    the 4th dimension. The results are finally copied to the output folder.

    The current working directory is temporarily changed to the temporary
    folder; it is restored even if one of the processing steps fails.

    :param argv: list of command-line arguments, passed as is to the parser.
    :return: fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean (absolute
        paths). The two ``*_mean`` paths are returned even when averaging was
        not requested, in which case the corresponding files are not generated.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    # initialize parameters
    param = Param()

    fname_data = arguments.i
    fname_bvecs = arguments.bvec
    average = arguments.a
    remove_temp_files = arguments.r
    path_out = arguments.ofolder

    fname_bvals = arguments.bval
    if arguments.bvalmin:
        param.bval_min = arguments.bvalmin

    # Initialization
    start_time = time.time()

    # printv(arguments)
    printv('\nInput parameters:', verbose)
    printv('  input file ............' + fname_data, verbose)
    printv('  bvecs file ............' + fname_bvecs, verbose)
    printv('  bvals file ............' + fname_bvals, verbose)
    printv('  average ...............' + str(average), verbose)

    # Get full path
    fname_data = os.path.abspath(fname_data)
    fname_bvecs = os.path.abspath(fname_bvecs)
    if fname_bvals:
        fname_bvals = os.path.abspath(fname_bvals)

    # Extract file name and extension (the folder of the input is not needed)
    _, file_data, ext_data = extract_fname(fname_data)

    # create temporary folder
    path_tmp = tmp_create(basename="dmri_separate")

    # copy files into tmp folder and convert to nifti
    printv('\nCopy files into temporary folder...', verbose)
    ext = '.nii'
    dmri_name = 'dmri'
    b0_name = file_data + '_b0'
    b0_mean_name = b0_name + '_mean'
    dwi_name = file_data + '_dwi'
    dwi_mean_name = dwi_name + '_mean'

    if not convert(fname_data, os.path.join(path_tmp, dmri_name + ext)):
        printv('ERROR in convert.', 1, 'error')
    copy(fname_bvecs, os.path.join(path_tmp, "bvecs"), verbose=verbose)

    # go to tmp folder
    curdir = os.getcwd()
    os.chdir(path_tmp)
    try:
        # Get size of data
        im_dmri = Image(dmri_name + ext)
        printv('\nGet dimensions data...', verbose)
        nx, ny, nz, nt, px, py, pz, pt = im_dmri.dim
        printv(
            '.. ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' +
            str(nt), verbose)

        # Identify b=0 and DWI images
        printv(fname_bvals)
        index_b0, index_dwi, nb_b0, nb_dwi = identify_b0(
            fname_bvecs, fname_bvals, param.bval_min, verbose)

        # Split into T dimension
        printv('\nSplit along T dimension...', verbose)
        im_dmri_split_list = split_data(im_dmri, 3)
        for im_d in im_dmri_split_list:
            im_d.save()

        # Merge b=0 images
        printv('\nMerge b=0...', verbose)
        fname_in_list_b0 = [
            dmri_name + '_T' + str(index_b0[it]).zfill(4) + ext
            for it in range(nb_b0)
        ]
        im_in_list_b0 = [Image(fname) for fname in fname_in_list_b0]
        concat_data(im_in_list_b0, 3).save(b0_name + ext)

        # Average b=0 images
        if average:
            printv('\nAverage b=0...', verbose)
            _save_mean_along_t(b0_name + ext, b0_mean_name + ext)

        # Merge DWI
        fname_in_list_dwi = [
            dmri_name + '_T' + str(index_dwi[it]).zfill(4) + ext
            for it in range(nb_dwi)
        ]
        im_in_list_dwi = [Image(fname) for fname in fname_in_list_dwi]
        concat_data(im_in_list_dwi, 3).save(dwi_name + ext)

        # Average DWI images
        if average:
            printv('\nAverage DWI...', verbose)
            _save_mean_along_t(dwi_name + ext, dwi_mean_name + ext)
    finally:
        # come back, even if one of the steps above failed, so that the
        # caller is not left inside the temporary folder
        os.chdir(curdir)

    # Generate output files
    fname_b0 = os.path.abspath(os.path.join(path_out, b0_name + ext_data))
    fname_dwi = os.path.abspath(os.path.join(path_out, dwi_name + ext_data))
    fname_b0_mean = os.path.abspath(
        os.path.join(path_out, b0_mean_name + ext_data))
    fname_dwi_mean = os.path.abspath(
        os.path.join(path_out, dwi_mean_name + ext_data))
    printv('\nGenerate output files...', verbose)
    generate_output_file(os.path.join(path_tmp, b0_name + ext),
                         fname_b0,
                         verbose=verbose)
    generate_output_file(os.path.join(path_tmp, dwi_name + ext),
                         fname_dwi,
                         verbose=verbose)
    if average:
        generate_output_file(os.path.join(path_tmp, b0_mean_name + ext),
                             fname_b0_mean,
                             verbose=verbose)
        generate_output_file(os.path.join(path_tmp, dwi_mean_name + ext),
                             fname_dwi_mean,
                             verbose=verbose)

    # Remove temporary files
    if remove_temp_files == 1:
        printv('\nRemove temporary files...', verbose)
        rmtree(path_tmp, verbose=verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    printv(
        '\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's',
        verbose)

    return fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean


def _save_mean_along_t(fname_in, fname_out):
    """
    Average an image along its 4th dimension and save the result.

    :param fname_in: file name of the image to average.
    :param fname_out: file name of the averaged image to write.
    :raises ValueError: if the data of the input image has fewer than 4
        dimensions.
    """
    dim_idx = 3
    img = Image(fname_in)
    if len(np.shape(img.data)) < dim_idx + 1:
        raise ValueError("Expecting image with 4 dimensions!")
    # work on a copy so that the header of the input image is carried over
    out = img.copy()
    out.data = np.mean(out.data, dim_idx)
    out.save(path=fname_out)