def dummy_segmentation_4d(vol_num=10, create_bvecs=False, size_arr=(256, 256, 256), pixdim=(1, 1, 1), dtype=np.float64,
                          orientation='LPI', shape='rectangle', angle_RL=0, angle_AP=0, angle_IS=0, radius_RL=5.0,
                          radius_AP=3.0, degree=2, interleaved=False, zeroslice=[], debug=False):
    """
    Create a dummy 4D segmentation (dMRI/fMRI) and dummy bvecs file (optional)

    The 4D data is obtained by generating `vol_num` 3D volumes with dummy_segmentation() and concatenating them along
    the 4th dimension.

    :param vol_num: int: number of volumes in 4D data
    :param create_bvecs: bool: create dummy bvecs file (necessary e.g. for sct_dmri_moco). The bvecs are only written
        to disk when debug=True.
    :param debug: bool: save the 4D image (and the bvecs, when create_bvecs=True) into timestamped 'tmp_dummy_4d_*'
        files in the current directory
    other parameters are same as in dummy_segmentation function
    :return: Image object
    """
    # NB: `zeroslice` keeps its historical list default; it is only forwarded, never modified, by this function.

    # Generate the individual volumes of the 4D data
    # set debug=True in the call below for saving individual volumes into individual nii files
    img_list = [
        dummy_segmentation(size_arr=size_arr, pixdim=pixdim, dtype=dtype, orientation=orientation,
                           shape=shape, angle_RL=angle_RL, angle_AP=angle_AP, angle_IS=angle_IS,
                           radius_RL=radius_RL, radius_AP=radius_AP, degree=degree, zeroslice=zeroslice,
                           interleaved=interleaved, debug=False)
        for _ in range(vol_num)
    ]

    # Concatenate individual 3D images into 4D data
    img_4d = concat_data(img_list, 3)
    if debug:
        out_name = datetime.now().strftime("%Y%m%d%H%M%S%f")
        file_4d_data = 'tmp_dummy_4d_' + out_name + '.nii.gz'
        img_4d.save(file_4d_data, verbose=0)

    # Create a dummy bvecs file (necessary e.g. for sct_dmri_moco)
    if create_bvecs:
        n_b0 = 1                          # number of b0
        n_dwi = max(vol_num - n_b0, 0)    # number of dwi
        bvec_b0 = np.zeros((n_b0, 3))
        # A single random direction is drawn and shared by all DWI volumes (as before). np.tile keeps the (n_dwi, 3)
        # shape even when n_dwi == 0, so the concatenation below also works for vol_num=1.
        dwi_direction = np.array([uniform(0, 1), uniform(0, 1), uniform(0, 1)])
        bvec_dwi = np.tile(dwi_direction, (n_dwi, 1))
        bvec = np.concatenate((bvec_b0, bvec_dwi), axis=0)
        # One text line per gradient component (x, y, z); every value is followed by a space
        bvecs_dummy = [' '.join('%.16f' % n for n in bvec[:, i]) + ' ' for i in (0, 1, 2)]
        bvecs_concat = '\n'.join(bvecs_dummy)  # transform list into lines of strings
        if debug:
            # The context manager guarantees the content is flushed and the file handle released
            with open('tmp_dummy_4d_' + out_name + '.bvec', 'w') as new_f:
                new_f.write(bvecs_concat)

    return img_4d
def main(args=None):
    """
    Separate the b=0 and the DWI volumes of a 4D diffusion dataset.

    The input data and bvecs are copied into a temporary folder, the data is split along T, the b=0 and DWI volumes
    (as identified by identify_b0) are merged into two separate 4D files and, when averaging is requested, each of
    them is also averaged along T with sct_maths. Results are then moved to the output folder.

    :param args: list of command-line arguments. When empty/None, sys.argv is parsed instead (and the help is
        displayed if the command line carries no argument).
    :return: fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean: absolute paths of the outputs (the two "_mean" files
        are only generated when averaging is requested)
    """
    # initialize parameters
    param = Param()
    # call main function
    parser = get_parser()
    if args:
        arguments = parser.parse_args(args)
    else:
        # No explicit argument list: parse sys.argv, or show the help when the command line is empty
        arguments = parser.parse_args(args=None if sys.argv[1:] else ['--help'])

    fname_data = arguments.i
    fname_bvecs = arguments.bvec
    average = arguments.a
    verbose = int(arguments.v)
    init_sct(log_level=verbose, update=True)  # Update log level
    remove_temp_files = arguments.r
    path_out = arguments.ofolder

    fname_bvals = arguments.bval
    if arguments.bvalmin:
        param.bval_min = arguments.bvalmin

    # Initialization
    start_time = time.time()

    # printv(arguments)
    printv('\nInput parameters:', verbose)
    printv('  input file ............' + fname_data, verbose)
    printv('  bvecs file ............' + fname_bvecs, verbose)
    printv('  bvals file ............' + fname_bvals, verbose)
    printv('  average ...............' + str(average), verbose)

    # Get full path (required because the working directory is changed below)
    fname_data = os.path.abspath(fname_data)
    fname_bvecs = os.path.abspath(fname_bvecs)
    if fname_bvals:
        fname_bvals = os.path.abspath(fname_bvals)

    # Extract path, file and extension
    _, file_data, ext_data = extract_fname(fname_data)

    # create temporary folder
    path_tmp = tmp_create(basename="dmri_separate")

    # copy files into tmp folder and convert to nifti
    printv('\nCopy files into temporary folder...', verbose)
    ext = '.nii'
    dmri_name = 'dmri'
    b0_name = file_data + '_b0'
    b0_mean_name = b0_name + '_mean'
    dwi_name = file_data + '_dwi'
    dwi_mean_name = dwi_name + '_mean'

    if not convert(fname_data, os.path.join(path_tmp, dmri_name + ext)):
        printv('ERROR in convert.', 1, 'error')
    copy(fname_bvecs, os.path.join(path_tmp, "bvecs"), verbose=verbose)

    # go to tmp folder: every file name used until "come back" is relative to it
    curdir = os.getcwd()
    os.chdir(path_tmp)
    try:
        # Get size of data
        im_dmri = Image(dmri_name + ext)
        printv('\nGet dimensions data...', verbose)
        nx, ny, nz, nt, px, py, pz, pt = im_dmri.dim
        printv('.. ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt), verbose)

        # Identify b=0 and DWI images
        index_b0, index_dwi, nb_b0, nb_dwi = identify_b0(fname_bvecs, fname_bvals, param.bval_min, verbose)

        # Split into T dimension
        printv('\nSplit along T dimension...', verbose)
        im_dmri_split_list = split_data(im_dmri, 3)
        for im_d in im_dmri_split_list:
            # Write each volume to disk: the merge steps below load them back by file name (<dmri_name>_TXXXX<ext>)
            im_d.save()

        # Merge b=0 images
        printv('\nMerge b=0...', verbose)
        from sct_image import concat_data
        fname_b0_volumes = [dmri_name + '_T' + str(index_b0[it]).zfill(4) + ext for it in range(nb_b0)]
        concat_data(fname_b0_volumes, 3).save(b0_name + ext)

        # Average b=0 images
        if average:
            printv('\nAverage b=0...', verbose)
            run_proc(['sct_maths', '-i', b0_name + ext, '-o', b0_mean_name + ext, '-mean', 't'], verbose)

        # Merge DWI
        fname_dwi_volumes = [dmri_name + '_T' + str(index_dwi[it]).zfill(4) + ext for it in range(nb_dwi)]
        concat_data(fname_dwi_volumes, 3).save(dwi_name + ext)

        # Average DWI images
        if average:
            printv('\nAverage DWI...', verbose)
            run_proc(['sct_maths', '-i', dwi_name + ext, '-o', dwi_mean_name + ext, '-mean', 't'], verbose)
    finally:
        # come back (even if a step above failed, so that the caller's working directory is left untouched)
        os.chdir(curdir)

    # Generate output files
    fname_b0 = os.path.abspath(os.path.join(path_out, b0_name + ext_data))
    fname_dwi = os.path.abspath(os.path.join(path_out, dwi_name + ext_data))
    fname_b0_mean = os.path.abspath(os.path.join(path_out, b0_mean_name + ext_data))
    fname_dwi_mean = os.path.abspath(os.path.join(path_out, dwi_mean_name + ext_data))
    printv('\nGenerate output files...', verbose)
    generate_output_file(os.path.join(path_tmp, b0_name + ext), fname_b0, verbose=verbose)
    generate_output_file(os.path.join(path_tmp, dwi_name + ext), fname_dwi, verbose=verbose)
    if average:
        generate_output_file(os.path.join(path_tmp, b0_mean_name + ext), fname_b0_mean, verbose=verbose)
        generate_output_file(os.path.join(path_tmp, dwi_mean_name + ext), fname_dwi_mean, verbose=verbose)

    # Remove temporary files
    if remove_temp_files == 1:
        printv('\nRemove temporary files...', verbose)
        rmtree(path_tmp, verbose=verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    printv('\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's', verbose)

    return fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean
def dmri_moco(param):
    # Motion correction of the diffusion data 'dmri.nii' (relative path, i.e. read from the current directory).
    # Visible steps:
    #   1. identify b=0 and DWI volumes from 'bvecs.txt' (and param.fname_bvals),
    #   2. split the data along T, merge + average the b=0 volumes, merge + average the DWI volumes by groups of
    #      param.group_size (remaining volumes form one extra, smaller group),
    #   3. estimate motion on the b=0 volumes, then on the averaged DWI groups (moco.moco),
    #   4. gather the per-volume transformations under 'mat_final/',
    #   5. set up the application of the final transformations to the whole dataset.
    # :param param: parameter object; fields read here: verbose, fname_bvals, bval_min, group_size, otsu,
    #               spline_fitting, run_eddy, suffix. NB: it is also modified in place (param_moco is an alias).
    # :return: absolute path of the motion-corrected data ('dmri' + param.suffix + '.nii')
    #
    # NOTE(review): this copy of the function is damaged: a number of statements are truncated (unclosed calls,
    # loops without body, missing `else:`). They are flagged individually below and must be restored from the
    # upstream source before this function can run.

    file_data = 'dmri.nii'
    file_data_dirname, file_data_basename, file_data_ext = sct.extract_fname(
    # NOTE(review): call above is truncated (argument and closing parenthesis missing).
    file_b0 = 'b0.nii'
    file_dwi = 'dwi.nii'
    ext_data = '.nii.gz'  # workaround "too many open files" by slurping the data
    mat_final = 'mat_final/'
    file_dwi_group = 'dwi_averaged_groups.nii'
    ext_mat = 'Warp.nii.gz'  # warping field

    # Get dimensions of data
    sct.printv('\nGet dimensions of data...', param.verbose)
    im_data = Image(file_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    sct.printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz),
    # NOTE(review): call above is truncated (last argument missing).

    # Identify b=0 and DWI images
    index_b0, index_dwi, nb_b0, nb_dwi = sct_dmri_separate_b0_and_dwi.identify_b0(
        'bvecs.txt', param.fname_bvals, param.bval_min, param.verbose)

    # check if dmri and bvecs are the same size
    if not nb_b0 + nb_dwi == nt:
        # NOTE(review): the opening of the error-reporting call is missing here.
            '\nERROR in ' + os.path.basename(__file__) + ': Size of data (' +
            str(nt) + ') and size of bvecs (' + str(nb_b0 + nb_dwi) +
            ') are not the same. Check your bvecs file.\n', 1, 'error')

    # Prepare NIFTI (mean/groups...)
    # Split into T dimension
    sct.printv('\nSplit along T dimension...', param.verbose)
    im_data_split_list = split_data(im_data, 3)
    for im in im_data_split_list:
        # Point each split volume to a '.nii.gz' path
        # NOTE(review): no save is visible after the path update — confirm against upstream.
        x_dirname, x_basename, x_ext = sct.extract_fname(im.absolutepath)
        im.absolutepath = os.path.join(x_dirname, x_basename + ".nii.gz")

    # Merge b=0 images
    sct.printv('\nMerge b=0...', param.verbose)
    im_b0_list = []
    for it in range(nb_b0):
    # NOTE(review): loop body missing (im_b0_list is never filled in this copy).
    im_b0_out = concat_data(im_b0_list, 3).save(file_b0)
    sct.printv(('  File created: ' + file_b0), param.verbose)

    # Average b=0 images
    sct.printv('\nAverage b=0...', param.verbose)
    file_b0_mean = sct.add_suffix(file_b0, '_mean')
    sct.run(['sct_maths', '-i', file_b0, '-o', file_b0_mean, '-mean', 't'],
    # NOTE(review): call above is truncated (last argument missing).

    # Number of DWI groups
    nb_groups = int(math.floor(nb_dwi / param.group_size))

    # Generate groups indexes
    group_indexes = []
    for iGroup in range(nb_groups):
        group_indexes.append(index_dwi[(iGroup *
                                        param.group_size):((iGroup + 1) *
        # NOTE(review): slice expression above is truncated.

    # add the remaining images to the last DWI group
    nb_remaining = nb_dwi % param.group_size  # number of remaining images
    if nb_remaining > 0:
        nb_groups += 1
        group_indexes.append(index_dwi[len(index_dwi) -
        # NOTE(review): slice expression above is truncated.

    file_dwi_dirname, file_dwi_basename, file_dwi_ext = sct.extract_fname(
    # NOTE(review): call above is truncated (argument missing).
    # DWI groups
    file_dwi_mean = []
    for iGroup in tqdm(range(nb_groups),
                       desc="Merge within groups",
        # NOTE(review): tqdm(...) call is truncated (remaining arguments and closing of the `for` header missing).
        # get index
        index_dwi_i = group_indexes[iGroup]
        nb_dwi_i = len(index_dwi_i)
        # Merge DW Images
        file_dwi_merge_i = os.path.join(
            file_dwi_dirname, file_dwi_basename + '_' + str(iGroup) + ext_data)
        im_dwi_list = []
        for it in range(nb_dwi_i):
        # NOTE(review): loop body missing (im_dwi_list is never filled in this copy).
        im_dwi_out = concat_data(im_dwi_list, 3).save(file_dwi_merge_i)
        # Average DW Images
        file_dwi_mean_i = os.path.join(
            file_dwi_basename + '_mean_' + str(iGroup) + ext_data)
        # NOTE(review): lines are missing around here: file_dwi_mean is indexed below but never appended to, and the
        # opening of the command list that follows is absent.
            "sct_maths", "-i", file_dwi_merge_i, "-o", file_dwi_mean[iGroup],
            "-mean", "t"
        ], 0)

    # Merge DWI groups means
    sct.printv('\nMerging DW files...', param.verbose)
    # file_dwi_groups_means_merge = 'dwi_averaged_groups'
    im_dw_list = []
    for iGroup in range(nb_groups):
    # NOTE(review): loop body missing (im_dw_list is never filled in this copy).
    im_dw_out = concat_data(im_dw_list, 3).save(file_dwi_group)

    # Average DW Images
    # TODO: USEFULL ???
    sct.printv('\nAveraging all DW images...', param.verbose)
    # NOTE(review): opening of the command list below is missing.
        "sct_maths", "-i", file_dwi_group, "-o",
        file_dwi_group + '_mean' + ext_data, "-mean", "t"
    ], param.verbose)

    # segment dwi images using otsu algorithm
    if param.otsu:
        sct.printv('\nSegment group DWI using OTSU algorithm...',
        # NOTE(review): call above is truncated (last argument missing).
        # import module
        otsu = importlib.import_module('sct_otsu')
        # get class from module
        param_otsu = otsu.param()  #getattr(otsu, param)
        param_otsu.fname_data = file_dwi_group
        param_otsu.threshold = param.otsu
        param_otsu.file_suffix = '_seg'
        # run otsu
        # NOTE(review): no call is visible here although the segmented file name is used afterwards.
        file_dwi_group = file_dwi_group + '_seg.nii'


    # Estimate moco on b0 groups
    sct.printv('  Estimating motion on b=0 images...', param.verbose)
    param_moco = param  # alias, not a copy: the assignments below also modify `param`
    param_moco.file_data = 'b0.nii'
    # identify target image
    if index_dwi[0] != 0:
        # If first DWI is not the first volume (most common), then there is a least one b=0 image before. In that case
        # select it as the target image for registration of all b=0
        param_moco.file_target = os.path.join(
            file_data_dirname, file_data_basename + '_T' +
            str(index_b0[index_dwi[0] - 1]).zfill(4) + ext_data)
        # NOTE(review): an `else:` appears to be missing here; as written the assignment below always overrides the
        # one above.
        # If first DWI is the first volume, then the target b=0 is the first b=0 from the index_b0.
        param_moco.file_target = os.path.join(
            file_data_basename + '_T' + str(index_b0[0]).zfill(4) + ext_data)

    param_moco.path_out = ''
    param_moco.todo = 'estimate'
    param_moco.mat_moco = 'mat_b0groups'
    file_mat_b0 = moco.moco(param_moco)

    # Estimate moco on dwi groups
    sct.printv('  Estimating motion on DW images...', param.verbose)
    param_moco.file_data = file_dwi_group
    param_moco.file_target = file_dwi_mean[
        0]  # target is the first DW image (closest to the first b=0)
    param_moco.path_out = ''
    param_moco.todo = 'estimate_and_apply'
    param_moco.mat_moco = 'mat_dwigroups'
    file_mat_dwi = moco.moco(param_moco)

    # create final mat folder
    # NOTE(review): no statement is visible for this step.

    # Copy b=0 registration matrices
    # TODO: use file_mat_b0 and file_mat_dwi instead of the hardcoding below
    sct.printv('\nCopy b=0 registration matrices...', param.verbose)
    for it in range(nb_b0):
        # Source: transformation #it of the b=0 series; destination: index of that volume in the full dataset
        # NOTE(review): opening of the copy call is missing.
            'mat_b0groups/' + 'mat.Z0000T' + str(it).zfill(4) + ext_mat,
            mat_final + 'mat.Z0000T' + str(index_b0[it]).zfill(4) + ext_mat)

    # Copy DWI registration matrices
    sct.printv('\nCopy DWI registration matrices...', param.verbose)
    for iGroup in range(nb_groups):
        for dwi in range(
        ):  # we cannot use enumerate because group_indexes has 2 dim.
            # Each volume of a group receives the transformation estimated for its group
            # NOTE(review): the range() argument above and the opening of the copy call below are missing.
                'mat_dwigroups/' + 'mat.Z0000T' + str(iGroup).zfill(4) +
                ext_mat, mat_final + 'mat.Z0000T' +
                str(group_indexes[iGroup][dwi]).zfill(4) + ext_mat)

    # Spline Regularization along T
    if param.spline_fitting:
        moco.spline(mat_final, nt, nz, param.verbose, np.array(index_b0),
        # NOTE(review): call above is truncated (last argument missing).

    # combine Eddy Matrices
    if param.run_eddy:
        param.mat_2_combine = 'mat_eddy'
        param.mat_final = mat_final
        # NOTE(review): only parameters are set here; no combining call is visible.

    # Apply moco on all dmri data
    sct.printv('  Apply moco', param.verbose)
    param_moco.file_data = file_data
    param_moco.file_target = os.path.join(
        file_dwi_dirname, file_dwi_basename + '_mean_' + str(0) +
        ext_data)  # reference for reslicing into proper coordinate system
    param_moco.path_out = ''
    param_moco.mat_moco = mat_final
    param_moco.todo = 'apply'
    # NOTE(review): parameters are prepared for the 'apply' step but no moco call is visible after them.

    # copy geometric information from header
    # NB: this is required because WarpImageMultiTransform in 2D mode wrongly sets pixdim(3) to "1".
    im_dmri = Image(file_data)

    fname_data_moco = os.path.join(file_data_dirname,
                                   file_data_basename + param.suffix + '.nii')
    im_dmri_moco = Image(fname_data_moco)
    im_dmri_moco.header = im_dmri.header
    # NOTE(review): the header is replaced in memory only; no save is visible before returning.

    return os.path.abspath(fname_data_moco)
def find_centerline(algo, image_fname, contrast_type, brain_bool,
                    folder_output, remove_temp_files, centerline_fname):
    # Locate the spinal cord centerline of `image_fname` with the method selected by `algo`
    # ('svm', 'cnn', 'viewer' or 'file') and resample the image to 0.5 x 0.5 mm in-plane resolution.
    # Returns (fname_res, centerline_filename): file names of the resampled image and of the centerline image.
    #
    # NOTE(review): this copy of the function is damaged: the docstring delimiters are gone (the six description
    # lines below are bare text), several `else:` lines are missing, dict literals are not closed and several calls
    # are truncated. Restore from the upstream source before use. Another definition of find_centerline exists
    # further down in this file and, being defined later, is the one that would take effect.
    Assumes RPI orientation
    :param algo:
    :param image_fname:
    :param contrast_type:
    :param brain_bool:
    :param folder_output:
    :param remove_temp_files:
    :param centerline_fname:

    # TODO: remove unnecessary i/o
    if Image(image_fname).dim[2] == 1:  # isct_spine_detect requires nz > 1
        # Single-slice input: duplicate the slice along dim 2 and work on the '_concat' file
        im_concat = concat_data([image_fname, image_fname], dim=2)
        im_concat.save(sct.add_suffix(image_fname, '_concat'))
        image_fname = sct.add_suffix(image_fname, '_concat')
        bool_2d = True
        # NOTE(review): an `else:` appears to be missing before the next line; as written bool_2d is reset to False
        # right away and is undefined when the image has more than one slice.
        bool_2d = False

    # TODO: maybe change 'svm' for 'optic', because this is how we call it in sct_get_centerline
    if algo == 'svm':
        # run optic on a heatmap computed by a trained SVM+HoG algorithm
        # optic_models_fname = os.path.join(path_sct, 'data', 'optic_models', '{}_model'.format(contrast_type))
        # # TODO: replace with get_centerline(method=optic)
        img_ctl, arr_ctl, _ = get_centerline(Image(image_fname),
        # NOTE(review): call above is truncated (remaining arguments missing); img_ctl is not saved in this copy.
        centerline_filename = sct.add_suffix(image_fname, "_ctr")

    elif algo == 'cnn':
        # CNN parameters
        # Per-contrast patch size and intensity normalization constants (mean/std)
        # NOTE(review): the closing braces of the nested dicts are missing in both literals below.
        dct_patch_ctr = {
            't2': {
                'size': (80, 80),
                'mean': 51.1417,
                'std': 57.4408
            't2s': {
                'size': (80, 80),
                'mean': 68.8591,
                'std': 71.4659
            't1': {
                'size': (80, 80),
                'mean': 55.7359,
                'std': 64.3149
            'dwi': {
                'size': (80, 80),
                'mean': 55.744,
                'std': 45.003
        # Per-contrast network hyper-parameters
        dct_params_ctr = {
            't2': {
                'features': 16,
                'dilation_layers': 2
            't2s': {
                'features': 8,
                'dilation_layers': 3
            't1': {
                'features': 24,
                'dilation_layers': 3
            'dwi': {
                'features': 8,
                'dilation_layers': 2

        # load model
        ctr_model_fname = os.path.join(sct.__sct_dir__, 'data',
        ctr_model = nn_architecture_ctr(
        # NOTE(review): both statements above are truncated (arguments missing).

        logger.info("Resample the image to 0.5x0.5 mm in-plane resolution...")
        fname_res = sct.add_suffix(image_fname, '_resampled')
        input_resolution = Image(image_fname).dim[4:7]
        # Target resolution string: 0.5 x 0.5 in-plane, third-axis resolution left unchanged
        new_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
        # NOTE(review): no resampling call is visible here although fname_res is used just below.


        # compute the heatmap
        fname_heatmap = sct.add_suffix(image_fname, "_heatmap")
        img_filename = ''.join(sct.extract_fname(fname_heatmap)[:2])
        fname_heatmap_nii = img_filename + '.nii'
        z_max = heatmap(filename_in=fname_res,
        # NOTE(review): call above is truncated (remaining arguments missing).

        # run optic on the heatmap
        centerline_filename = sct.add_suffix(fname_heatmap, "_ctr")
        # NOTE(review): the two keyword arguments below belong to a call whose opening line is missing.
                      lambda_value=7 if contrast_type == 't2s' else 1,
                      z_max=z_max if brain_bool else None)

    elif algo == 'viewer':
        # Centerline derived from labels returned by the viewer helper
        im_labels = _call_viewer_centerline(Image(image_fname))
        im_centerline, arr_centerline, _ = get_centerline(im_labels)
        centerline_filename = sct.add_suffix(image_fname, "_ctr")
        labels_filename = sct.add_suffix(image_fname, "_labels-centerline")
        # NOTE(review): neither im_centerline nor im_labels is saved in this copy.

    elif algo == 'file':
        centerline_filename = sct.add_suffix(image_fname, "_ctr")
        # Re-orient the manual centerline
        # NOTE(review): no statement is visible for the step above, and the error message below has lost both its
        # `else:` branch and the opening of its call.

            'The parameter "-centerline" is incorrect. Please try again.')

    if algo != 'cnn':
        # The 'cnn' branch already computed these values; do it here for the other methods
        logger.info("Resample the image to 0.5x0.5 mm in-plane resolution...")
        fname_res = sct.add_suffix(image_fname, '_resampled')
        input_resolution = Image(image_fname).dim[4:7]
        new_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
        # NOTE(review): no resampling call is visible in this copy.



    if bool_2d:
        im_split_lst = split_data(Image(centerline_filename), dim=2)
        # NOTE(review): the split result is not used in this copy.

    return fname_res, centerline_filename
def moco(param):
    # Volume-wise motion correction of the 4D file param.file_data against param.file_target.
    # Visible steps: print the parameters, copy the target to 'target.nii.gz', optionally split data/target/mask
    # along Z (when param.is_sagittal), then for each Z chunk split along T and process every volume, replacing the
    # transformations flagged as failed by the closest valid one, and finally merge the corrected volumes.
    # :param param: parameter object; fields read here: file_data, file_target, mat_moco, todo, suffix, verbose,
    #               poly, smooth, gradStep, metric, sampling, fname_mask, is_sagittal, iterAvg, interp
    # :return: file_mat: numpy object array (n_z_chunks x nt) holding the transformation file prefix per chunk/volume
    #
    # NOTE(review): this copy of the function is damaged: several statements are truncated (calls without opening
    # or closing, loops without body, missing `else:`). They are flagged individually below and must be restored
    # from the upstream source before this function can run.

    # retrieve parameters
    file_data = param.file_data
    file_target = param.file_target
    folder_mat = param.mat_moco  # output folder of mat file
    todo = param.todo
    suffix = param.suffix
    verbose = param.verbose

    # other parameters
    file_mask = 'mask.nii'

    # NB: the parameters below are concatenated to the labels with `+`, which requires them to be strings
    sct.printv('\nInput parameters:', param.verbose)
    sct.printv('  Input file ............' + file_data, param.verbose)
    sct.printv('  Reference file ........' + file_target, param.verbose)
    sct.printv('  Polynomial degree .....' + param.poly, param.verbose)
    sct.printv('  Smoothing kernel ......' + param.smooth, param.verbose)
    sct.printv('  Gradient step .........' + param.gradStep, param.verbose)
    sct.printv('  Metric ................' + param.metric, param.verbose)
    sct.printv('  Sampling ..............' + param.sampling, param.verbose)
    sct.printv('  Todo ..................' + todo, param.verbose)
    sct.printv('  Mask  .................' + param.fname_mask, param.verbose)
    sct.printv('  Output mat folder .....' + folder_mat, param.verbose)

    # create folder for mat files
    # NOTE(review): no statement is visible for this step.

    # Get size of data
    sct.printv('\nData dimensions:', verbose)
    im_data = Image(param.file_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    # NOTE(review): the line below is the argument of a call whose opening and closing are missing.
        ('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt)),

    # copy file_target to a temporary file
    sct.printv('\nCopy file_target to a temporary file...', verbose)
    file_target = "target.nii.gz"
    convert(param.file_target, file_target)

    # If scan is sagittal, split src and target along Z (slice)
    if param.is_sagittal:
        dim_sag = 2  # TODO: find it
        # z-split data (time series)
        im_z_list = split_data(im_data, dim=dim_sag, squeeze_data=False)
        file_data_splitZ = []
        for im_z in im_z_list:
        # NOTE(review): loop body missing (file_data_splitZ is never filled in this copy).
        # z-split target
        im_targetz_list = split_data(Image(file_target),
        # NOTE(review): call above is truncated (remaining arguments missing).
        file_target_splitZ = []
        for im_targetz in im_targetz_list:
        # NOTE(review): loop body missing.
        # z-split mask (if exists)
        if not param.fname_mask == '':
            im_maskz_list = split_data(Image(file_mask),
            # NOTE(review): call above is truncated (remaining arguments missing).
            file_mask_splitZ = []
            for im_maskz in im_maskz_list:
            # NOTE(review): loop body missing.
        # initialize file list for output matrices
        file_mat = np.empty((nz, nt), dtype=object)

    # axial orientation
    # NOTE(review): the `else:` introducing this branch is missing.
        file_data_splitZ = [file_data]  # TODO: make it absolute like above
        file_target_splitZ = [file_target]  # TODO: make it absolute like above
        # initialize file list for output matrices
        file_mat = np.empty((1, nt), dtype=object)

        # deal with mask
        if not param.fname_mask == '':
            convert(param.fname_mask, file_mask, squeeze_data=False)
            im_maskz_list = [Image(file_mask)
                             ]  # use a list with single element

    # Loop across file list, where each file is either a 2D volume (if sagittal) or a 3D volume (otherwise)
    # file_mat = tuple([[[] for i in range(nt)] for i in range(nz)])

    file_data_splitZ_moco = []
    # NOTE(review): the string below is the argument of a call whose opening and closing are missing.
        '\nRegister. Loop across Z (note: there is only one Z if orientation is axial'
    for file in file_data_splitZ:
        iz = file_data_splitZ.index(file)
        # Split data along T dimension
        # sct.printv('\nSplit data along T dimension.', verbose)
        im_z = Image(file)
        list_im_zt = split_data(im_z, dim=3)
        file_data_splitZ_splitT = []
        for im_zt in list_im_zt:
        # NOTE(review): loop body missing (file_data_splitZ_splitT is never filled in this copy).
        # file_data_splitT = file_data + '_T'

        # Motion correction: initialization
        index = np.arange(nt)
        file_data_splitT_num = []
        file_data_splitZ_splitT_moco = []
        failed_transfo = [0 for i in range(nt)]  # one flag per volume; entries equal to 1 are treated as failed below

        # Motion correction: Loop across T
        for indice_index in tqdm(range(nt),
                                 desc="Z=" + str(iz) + "/" +
                                 str(len(file_data_splitZ) - 1),
            # NOTE(review): tqdm(...) call is truncated (remaining arguments and closing of the `for` header missing).

            # create indices and display stuff
            it = index[indice_index]
            # Transformation file prefix for chunk iz / volume it: '...mat.Z<iz>T<it>' (4-digit, zero-padded indices)
            file_mat[iz][it] = os.path.join(
                "mat.Z") + str(iz).zfill(4) + 'T' + str(it).zfill(4)
            # NOTE(review): the line below has lost the opening of its statement.
                sct.add_suffix(file_data_splitZ_splitT[it], '_moco'))
            # deal with masking
            if not param.fname_mask == '':
                input_mask = im_maskz_list[iz]
                # NOTE(review): an `else:` appears to be missing; as written the mask is always reset to None.
                input_mask = None
            # run 3D registration
            failed_transfo[it] = register(param,
            # NOTE(review): call above is truncated (remaining arguments missing).

            # average registered volume with target image
            # N.B. use weighted averaging: (target * nb_it + moco) / (nb_it + 1)
            # Only done for the first 10 volumes, when the registration did not fail and todo is not 'apply'
            if param.iterAvg and indice_index < 10 and failed_transfo[
                    it] == 0 and not param.todo == 'apply':
                im_targetz = Image(file_target_splitZ[iz])
                data_targetz = im_targetz.data
                data_mocoz = Image(file_data_splitZ_splitT_moco[it]).data
                data_targetz = (data_targetz * (indice_index + 1) +
                                data_mocoz) / (indice_index + 2)
                im_targetz.data = data_targetz
                # NOTE(review): the updated target is not saved in this copy.

        # Replace failed transformation with the closest good one
        fT = [i for i, j in enumerate(failed_transfo) if j == 1]  # indices of failed transformations
        gT = [i for i, j in enumerate(failed_transfo) if j == 0]  # indices of good transformations
        for it in range(len(fT)):
            # Distance (in volume index) between the failed volume and every good one
            abs_dist = [np.abs(gT[i] - fT[it]) for i in range(len(gT))]
            if not abs_dist == []:
                index_good = abs_dist.index(min(abs_dist))
                # NOTE(review): opening of the message call below is missing.
                    '  transfo #' + str(fT[it]) + ' --> use transfo #' +
                    str(gT[index_good]), verbose)
                # copy transformation
                sct.copy(file_mat[iz][gT[index_good]] + 'Warp.nii.gz',
                         file_mat[iz][fT[it]] + 'Warp.nii.gz')
                # apply transformation
                # NOTE(review): opening and closing of the command below are missing.
                    '-i', file_data_splitZ_splitT[fT[it]], '-d', file_target,
                    '-w', file_mat[iz][fT[it]] + 'Warp.nii.gz', '-o',
                    file_data_splitZ_splitT_moco[fT[it]], '-x', param.interp
                # exit program if no transformation exists.
                # NOTE(review): the `else:` and the opening/closing of the error call below are missing.
                    '\nERROR in ' + os.path.basename(__file__) +
                    ': No good transformation exist. Exit program.\n', verbose,

        # Merge data along T
        file_data_splitZ_moco.append(sct.add_suffix(file, suffix))
        if todo != 'estimate':
            im_out = concat_data(file_data_splitZ_splitT_moco, 3)
            # NOTE(review): the merged image is not saved in this copy.

    # If sagittal, merge along Z
    if param.is_sagittal:
        im_out = concat_data(file_data_splitZ_moco, 2)
        dirname, basename, ext = sct.extract_fname(file_data)
        path_out = os.path.join(dirname, basename + suffix + ext)
        # NOTE(review): path_out is computed but the merged image is not saved in this copy.

    return file_mat
def find_centerline(algo, image_fname, contrast_type, brain_bool, folder_output, remove_temp_files, centerline_fname):
    # Locate the spinal cord centerline of `image_fname` with the method selected by `algo`
    # ('svm', 'cnn', 'viewer' or 'file') and resample the image to 0.5 x 0.5 mm in-plane resolution.
    # Returns (fname_res, centerline_filename): file names of the resampled image and of the centerline image.
    #
    # NOTE(review): this copy of the function is damaged: the docstring delimiters are gone (the six description
    # lines below are bare text), several `else:` lines are missing and several calls are truncated. Restore from
    # the upstream source before use. An earlier definition with the same name exists in this file; this later one
    # is the one that would take effect.
    Assumes RPI orientation
    :param algo:
    :param image_fname:
    :param contrast_type:
    :param brain_bool:
    :param folder_output:
    :param remove_temp_files:
    :param centerline_fname:

    # TODO: remove unnecessary i/o
    if Image(image_fname).dim[2] == 1:  # isct_spine_detect requires nz > 1
        # Single-slice input: duplicate the slice along dim 2 and work on the '_concat' file
        im_concat = concat_data([image_fname, image_fname], dim=2)
        im_concat.save(sct.add_suffix(image_fname, '_concat'))
        image_fname = sct.add_suffix(image_fname, '_concat')
        bool_2d = True
        # NOTE(review): an `else:` appears to be missing before the next line; as written bool_2d is reset to False
        # right away and is undefined when the image has more than one slice.
        bool_2d = False

    # TODO: maybe change 'svm' for 'optic', because this is how we call it in sct_get_centerline
    if algo == 'svm':
        # run optic on a heatmap computed by a trained SVM+HoG algorithm
        # optic_models_fname = os.path.join(path_sct, 'data', 'optic_models', '{}_model'.format(contrast_type))
        # # TODO: replace with get_centerline(method=optic)
        img_ctl, arr_ctl, _ = get_centerline(Image(image_fname), algo_fitting='optic', contrast=contrast_type)
        centerline_filename = sct.add_suffix(image_fname, "_ctr")
        # NOTE(review): img_ctl is not saved to centerline_filename in this copy.

    elif algo == 'cnn':
        # CNN parameters
        # Per-contrast patch size and intensity normalization constants (mean/std)
        dct_patch_ctr = {'t2': {'size': (80, 80), 'mean': 51.1417, 'std': 57.4408},
                         't2s': {'size': (80, 80), 'mean': 68.8591, 'std': 71.4659},
                         't1': {'size': (80, 80), 'mean': 55.7359, 'std': 64.3149},
                         'dwi': {'size': (80, 80), 'mean': 55.744, 'std': 45.003}}
        # Per-contrast network hyper-parameters
        dct_params_ctr = {'t2': {'features': 16, 'dilation_layers': 2},
                          't2s': {'features': 8, 'dilation_layers': 3},
                          't1': {'features': 24, 'dilation_layers': 3},
                          'dwi': {'features': 8, 'dilation_layers': 2}}

        # load model
        ctr_model_fname = os.path.join(sct.__sct_dir__, 'data', 'deepseg_sc_models', '{}_ctr.h5'.format(contrast_type))
        ctr_model = nn_architecture_ctr(height=dct_patch_ctr[contrast_type]['size'][0],
        # NOTE(review): call above is truncated (remaining arguments missing); ctr_model_fname is not used in this copy.

        logger.info("Resample the image to 0.5x0.5 mm in-plane resolution...")
        fname_res = sct.add_suffix(image_fname, '_resampled')
        input_resolution = Image(image_fname).dim[4:7]
        # Target resolution string: 0.5 x 0.5 in-plane, third-axis resolution left unchanged
        new_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])

        resampling.resample_file(image_fname, fname_res, new_resolution, 'mm', 'linear', verbose=0)

        # compute the heatmap
        fname_heatmap = sct.add_suffix(image_fname, "_heatmap")
        img_filename = ''.join(sct.extract_fname(fname_heatmap)[:2])
        fname_heatmap_nii = img_filename + '.nii'
        z_max = heatmap(filename_in=fname_res,
        # NOTE(review): call above is truncated (remaining arguments missing).

        # run optic on the heatmap
        centerline_filename = sct.add_suffix(fname_heatmap, "_ctr")
        # NOTE(review): the two keyword arguments below belong to a call whose opening line is missing.
                      lambda_value=7 if contrast_type == 't2s' else 1,
                      z_max=z_max if brain_bool else None)

    elif algo == 'viewer':
        # Centerline derived from labels returned by the viewer helper
        im_labels = _call_viewer_centerline(Image(image_fname))
        im_centerline, arr_centerline, _ = get_centerline(im_labels)
        centerline_filename = sct.add_suffix(image_fname, "_ctr")
        labels_filename = sct.add_suffix(image_fname, "_labels-centerline")
        # NOTE(review): neither im_centerline nor im_labels is saved in this copy.

    elif algo == 'file':
        centerline_filename = sct.add_suffix(image_fname, "_ctr")
        # Re-orient the manual centerline
        # NOTE(review): no statement is visible for the step above, and the `else:` that should guard the error
        # below is missing, so the error is logged whenever algo == 'file'.

        logger.error('The parameter "-centerline" is incorrect. Please try again.')

    if algo != 'cnn':
        # The 'cnn' branch already resampled the image; here both the image and the centerline are resampled
        logger.info("Resample the image to 0.5x0.5 mm in-plane resolution...")
        fname_res = sct.add_suffix(image_fname, '_resampled')
        input_resolution = Image(image_fname).dim[4:7]
        new_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])

        resampling.resample_file(image_fname, fname_res, new_resolution, 'mm', 'linear', verbose=0)

        # The centerline file is resampled in place (same input and output name)
        resampling.resample_file(centerline_filename, centerline_filename, new_resolution, 'mm', 'linear', verbose=0)

    if bool_2d:
        im_split_lst = split_data(Image(centerline_filename), dim=2)
        # NOTE(review): the split result is not used in this copy.

    return fname_res, centerline_filename
def moco_wrapper(param):
    # Purpose (from the visible code): copy the input data into a temporary folder, split it along T,
    # build groups of volumes, estimate motion with moco() on the group averages, propagate the
    # resulting transformations to every individual volume, apply them, and write the outputs.
    # NOTE(review): several statements of this function appear cut off in this view (unterminated
    # calls, branch bodies without their `else:` header, loop headers without a body). The comments
    # added below describe only what the visible lines do.
    Wrapper that performs motion correction.

    :param param: ParamMoco class
    :return: None
    # Fixed file names used for intermediate data.
    file_data = 'data.nii'  # corresponds to the full input data (e.g. dmri or fmri)
    # NOTE(review): the extract_fname( call below is unterminated in this view; its arguments are not visible.
    file_data_dirname, file_data_basename, file_data_ext = extract_fname(
    file_b0 = 'b0.nii'
    file_datasub = 'datasub.nii'  # corresponds to the full input data minus the b=0 scans (if param.is_diffusion=True)
    file_datasubgroup = 'datasub-groups.nii'  # concatenation of the average of each file_datasub
    file_mask = 'mask.nii'
    file_moco_params_csv = 'moco_params.tsv'
    file_moco_params_x = 'moco_params_x.nii.gz'
    file_moco_params_y = 'moco_params_y.nii.gz'
    ext_data = '.nii.gz'  # workaround "too many open files" by slurping the data
    # TODO: check if .nii can be used
    mat_final = 'mat_final/'
    # ext_mat = 'Warp.nii.gz'  # warping field

    # Start timer
    start_time = time.time()

    printv('\nInput parameters:', param.verbose)
    printv('  Input file ............ ' + param.fname_data, param.verbose)
    printv('  Group size ............ {}'.format(param.group_size),

    # Get full path
    # param.fname_data = os.path.abspath(param.fname_data)
    # param.fname_bvecs = os.path.abspath(param.fname_bvecs)
    # if param.fname_bvals != '':
    #     param.fname_bvals = os.path.abspath(param.fname_bvals)

    # Extract path, file and extension
    # path_data, file_data, ext_data = extract_fname(param.fname_data)
    # path_mask, file_mask, ext_mask = extract_fname(param.fname_mask)

    path_tmp = tmp_create(basename="moco")

    # Copying input data to tmp folder
    printv('\nCopying input data to tmp folder and convert to nii...',
    convert(param.fname_data, os.path.join(path_tmp, file_data))
    if param.fname_mask != '':
                os.path.join(path_tmp, file_mask),
        # Update field in param (because used later in another function, and param class will be passed)
        param.fname_mask = file_mask

    # Build absolute output path and go to tmp folder
    # NOTE(review): no change of directory is visible after these two lines, although the following
    # code opens file_data by its bare name -- confirm against the full file.
    curdir = os.getcwd()
    path_out_abs = os.path.abspath(param.path_out)

    # Get dimensions of data
    printv('\nGet dimensions of data...', param.verbose)
    im_data = Image(file_data)
    # im_data.dim is unpacked into eight values; only nx, ny, nz and nt are used in the visible code.
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), param.verbose)

    # Get orientation
    # The third letter of the orientation string selects sagittal (L/R) or axial (I/S) handling.
    printv('\nData orientation: ' + im_data.orientation, param.verbose)
    if im_data.orientation[2] in 'LR':
        param.is_sagittal = True
        printv('  Treated as sagittal')
    elif im_data.orientation[2] in 'IS':
        param.is_sagittal = False
        printv('  Treated as axial')
        # NOTE(review): the header of the fallback branch is not visible here; the two lines below
        # look like its body (treat any other orientation as axial and warn).
        param.is_sagittal = False
            'WARNING: Orientation seems to be neither axial nor sagittal. Treated as axial.'

        "\nSet suffix of transformation file name, which depends on the orientation:"
    # The suffix stored in param.suffix_mat is '0GenericAffine.mat' for sagittal data and
    # 'Warp.nii.gz' otherwise (the `else:` header is not visible in this view).
    if param.is_sagittal:
        param.suffix_mat = '0GenericAffine.mat'
            "Orientation is sagittal, suffix is '{}'. The image is split across the R-L direction, and the "
            "estimated transformation is a 2D affine transfo.".format(
        param.suffix_mat = 'Warp.nii.gz'
            "Orientation is axial, suffix is '{}'. The estimated transformation is a 3D warping field, which is "
            "composed of a stack of 2D Tx-Ty transformations".format(

    # Adjust group size in case of sagittal scan
    if param.is_sagittal and param.group_size != 1:
            'For sagittal data group_size should be one for more robustness. Forcing group_size=1.',
            1, 'warning')
        param.group_size = 1

    if param.is_diffusion:
        # Identify b=0 and DWI images
        index_b0, index_dwi, nb_b0, nb_dwi = \
            sct_dmri_separate_b0_and_dwi.identify_b0(param.fname_bvecs, param.fname_bvals, param.bval_min,

        # check if dmri and bvecs are the same size
        if not nb_b0 + nb_dwi == nt:
                '\nERROR in ' + os.path.basename(__file__) +
                ': Size of data (' + str(nt) + ') and size of bvecs (' +
                str(nb_b0 + nb_dwi) +
                ') are not the same. Check your bvecs file.\n', 1, 'error')

    # ==================================================================================================================
    # Prepare data (mean/groups...)
    # ==================================================================================================================

    # Split into T dimension
    printv('\nSplit along T dimension...', param.verbose)
    im_data_split_list = split_data(im_data, 3)
    # Rewrite the path of every split volume so that it ends with ".nii.gz".
    for im in im_data_split_list:
        x_dirname, x_basename, x_ext = extract_fname(im.absolutepath)
        im.absolutepath = os.path.join(x_dirname, x_basename + ".nii.gz")

    if param.is_diffusion:
        # Merge and average b=0 images
        printv('\nMerge and average b=0 data...', param.verbose)
        im_b0_list = []
        # NOTE(review): the body of this loop (filling im_b0_list) is not visible in this view.
        for it in range(nb_b0):
        im_b0 = concat_data(im_b0_list, 3).save(file_b0, verbose=0)
        # Average across time
        im_b0.mean(dim=3).save(add_suffix(file_b0, '_mean'))

        n_moco = nb_dwi  # set number of data to perform moco on (using grouping)
        index_moco = index_dwi

    # If not a diffusion scan, we will motion-correct all volumes
        n_moco = nt
        index_moco = list(range(0, nt))

    # Number of complete groups of param.group_size volumes.
    nb_groups = int(math.floor(n_moco / param.group_size))

    # Generate groups indexes
    group_indexes = []
    for iGroup in range(nb_groups):
        group_indexes.append(index_moco[(iGroup *
                                         param.group_size):((iGroup + 1) *

    # add the remaining images to a new last group (in case the total number of image is not divisible by group_size)
    nb_remaining = n_moco % param.group_size  # number of remaining images
    if nb_remaining > 0:
        nb_groups += 1
        group_indexes.append(index_moco[len(index_moco) -

    _, file_dwi_basename, file_dwi_ext = extract_fname(file_datasub)
    # Group data
    list_file_group = []
    for iGroup in sct_progress_bar(range(nb_groups),
                                   desc="Merge within groups",
        # get index
        index_moco_i = group_indexes[iGroup]
        n_moco_i = len(index_moco_i)
        # concatenate images across time, within this group
        file_dwi_merge_i = os.path.join(file_dwi_basename + '_' + str(iGroup) +
        im_dwi_list = []
        # NOTE(review): the body of this loop (filling im_dwi_list) is not visible in this view.
        for it in range(n_moco_i):
        im_dwi_out = concat_data(im_dwi_list, 3).save(file_dwi_merge_i,
        # Average across time
            os.path.join(file_dwi_basename + '_' + str(iGroup) + '_mean' +

    # Merge across groups
    printv('\nMerge across groups...', param.verbose)
    # file_dwi_groups_means_merge = 'dwi_averaged_groups'
    im_dw_list = []
    # NOTE(review): the body of this loop (filling im_dw_list) is not visible in this view.
    for iGroup in range(nb_groups):
    concat_data(im_dw_list, 3).save(file_datasubgroup, verbose=0)

    # Cleanup
    del im, im_data_split_list

    # ==================================================================================================================
    # Estimate moco
    # ==================================================================================================================

    # Initialize another class instance that will be passed on to the moco() function
    param_moco = deepcopy(param)

    if param.is_diffusion:
        # Estimate moco on b0 groups
        printv('  Estimating motion on b=0 images...', param.verbose)
        param_moco.file_data = 'b0.nii'
        # Identify target image
        if index_moco[0] != 0:
            # If first DWI is not the first volume (most common), then there is a least one b=0 image before. In that
            # case select it as the target image for registration of all b=0
            param_moco.file_target = os.path.join(
                file_data_dirname, file_data_basename + '_T' +
                str(index_b0[index_moco[0] - 1]).zfill(4) + ext_data)
            # If first DWI is the first volume, then the target b=0 is the first b=0 from the index_b0.
            param_moco.file_target = os.path.join(
                file_data_dirname, file_data_basename + '_T' +
                str(index_b0[0]).zfill(4) + ext_data)
        # Run moco
        param_moco.path_out = ''
        param_moco.todo = 'estimate_and_apply'
        param_moco.mat_moco = 'mat_b0groups'
        file_mat_b0, _ = moco(param_moco)

    # Estimate moco across groups
    printv('  Estimating motion across groups...', param.verbose)
    param_moco.file_data = file_datasubgroup
    param_moco.file_target = list_file_group[
        0]  # target is the first volume (closest to the first b=0 if DWI scan)
    param_moco.path_out = ''
    param_moco.todo = 'estimate_and_apply'
    param_moco.mat_moco = 'mat_groups'
    file_mat_datasub_group, _ = moco(param_moco)

    # Spline Regularization along T
    if param.spline_fitting:
        # TODO: fix this scenario (haven't touched that code for a while-- it is probably buggy)
        raise NotImplementedError()
        # spline(mat_final, nt, nz, param.verbose, np.array(index_b0), param.plot_graph)

    # ==================================================================================================================
    # Apply moco
    # ==================================================================================================================

    # If group_size>1, assign transformation to each individual ungrouped 3d volume
    if param.group_size > 1:
        file_mat_datasub = []
        for iz in range(len(file_mat_datasub_group)):
            # duplicate by factor group_size the transformation file for each it
            #  example: [mat.Z0000T0001Warp.nii] --> [mat.Z0000T0001Warp.nii, mat.Z0000T0001Warp.nii] for group_size=2
                                 [[i] * param.group_size
                                  for i in file_mat_datasub_group[iz]], []))
        # NOTE(review): an `else:` header appears to be missing before the next line in this view
        # (with group_size == 1 the per-group transformations are used as they are).
        file_mat_datasub = file_mat_datasub_group

    # Copy transformations to mat_final folder and rename them appropriately
    copy_mat_files(nt, file_mat_datasub, index_moco, mat_final, param)
    if param.is_diffusion:
        copy_mat_files(nt, file_mat_b0, index_b0, mat_final, param)

    # Apply moco on all dmri data
    printv('  Apply moco', param.verbose)
    param_moco.file_data = file_data
    param_moco.file_target = list_file_group[
        0]  # reference for reslicing into proper coordinate system
    param_moco.path_out = ''  # TODO not used in moco()
    param_moco.mat_moco = mat_final
    param_moco.todo = 'apply'
    file_mat_data, im_moco = moco(param_moco)

    # copy geometric information from header
    # NB: this is required because WarpImageMultiTransform in 2D mode wrongly sets pixdim(3) to "1".
    im_moco.header = im_data.header

    # Average across time
    if param.is_diffusion:
        # generate b0_moco_mean and dwi_moco_mean
        args = [
            '-i', im_moco.absolutepath, '-bvec', param.fname_bvecs, '-a', '1',
            '-v', '0'
        if not param.fname_bvals == '':
            # if bvals file is provided
            args += ['-bval', param.fname_bvals]
        fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean = sct_dmri_separate_b0_and_dwi.main(
        fname_moco_mean = add_suffix(im_moco.absolutepath, '_mean')

    # Extract and output the motion parameters (doesn't work for sagittal orientation)
    printv('Extract motion parameters...')
    if param.output_motion_param:
        if param.is_sagittal:
                'Motion parameters cannot be generated for sagittal images.',
                1, 'warning')
            files_warp_X, files_warp_Y = [], []
            moco_param = []
            # Only the first row of file_mat_data is iterated here.
            for fname_warp in file_mat_data[0]:
                # Cropping the image to keep only one voxel in the XY plane
                im_warp = Image(fname_warp + param.suffix_mat)
                im_warp.data = np.expand_dims(np.expand_dims(
                    im_warp.data[0, 0, :, :, :], axis=0),

                # These three lines allow to generate one file instead of two, containing X, Y and Z moco parameters
                #fname_warp_crop = fname_warp + '_crop_' + ext_mat
                # files_warp.append(fname_warp_crop)
                # im_warp.save(fname_warp_crop)

                # Separating the three components and saving X and Y only (Z is equal to 0 by default).
                im_warp_XYZ = multicomponent_split(im_warp)

                fname_warp_crop_X = fname_warp + '_crop_X_' + param.suffix_mat

                fname_warp_crop_Y = fname_warp + '_crop_Y_' + param.suffix_mat

                # Calculating the slice-wise average moco estimate to provide a QC file

            # These two lines allow to generate one file instead of two, containing X, Y and Z moco parameters
            #im_warp_concat = concat_data(files_warp, dim=3)
            # im_warp_concat.save('fmri_moco_params.nii')

            # Concatenating the moco parameters into a time series for X and Y components.
            im_warp_concat = concat_data(files_warp_X, dim=3)

            im_warp_concat = concat_data(files_warp_Y, dim=3)

            # Writing a TSV file with the slicewise average estimate of the moco parameters. Useful for QC
            with open(file_moco_params_csv, 'wt') as out_file:
                tsv_writer = csv.writer(out_file, delimiter='\t')
                tsv_writer.writerow(['X', 'Y'])
                for mocop in moco_param:
                    tsv_writer.writerow([mocop[0], mocop[1]])

    # Generate output files
    printv('\nGenerate output files...', param.verbose)
    fname_moco = os.path.join(
        add_suffix(os.path.basename(param.fname_data), param.suffix))
    generate_output_file(im_moco.absolutepath, fname_moco)
    if param.is_diffusion:
        generate_output_file(fname_b0_mean, add_suffix(fname_moco, '_b0_mean'))
                             add_suffix(fname_moco, '_dwi_mean'))
        generate_output_file(fname_moco_mean, add_suffix(fname_moco, '_mean'))
    # The motion-parameter files are only handled when the TSV file was actually written above.
    if os.path.exists(file_moco_params_csv):
                             os.path.join(path_out_abs, file_moco_params_x),
                             os.path.join(path_out_abs, file_moco_params_y),
                             os.path.join(path_out_abs, file_moco_params_csv))

    # Delete temporary files
    if param.remove_temp_files == 1:
        printv('\nDelete temporary files...', param.verbose)
        rmtree(path_tmp, verbose=param.verbose)

    # come back to working directory

    # display elapsed time
    elapsed_time = time.time() - start_time
        '\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's',

            add_suffix(os.path.basename(param.fname_data), param.suffix)),
def main(args=None):
    # Purpose (from the visible code): parse the command-line style arguments, split the 4D input
    # along T, merge the b=0 volumes and the DWI volumes into two separate files, optionally average
    # each of them across time, copy the results to the output folder and return the four output
    # paths (b0, b0 mean, dwi, dwi mean).
    # NOTE(review): some statements appear to be missing from this view (see the notes below).
    # Fall back to the process arguments when none are passed in.
    if not args:
        args = sys.argv[1:]

    # initialize parameters
    param = Param()
    # call main function
    parser = get_parser()
    arguments = parser.parse(args)

    fname_data = arguments['-i']
    fname_bvecs = arguments['-bvec']
    average = arguments['-a']
    verbose = int(arguments.get('-v'))
    sct.init_sct(log_level=verbose, update=True)  # Update log level
    remove_temp_files = int(arguments['-r'])
    path_out = arguments['-ofolder']

    # NOTE(review): an `else:` header appears to be missing before `fname_bvals = ''`; as shown,
    # the value read from '-bval' is immediately overwritten -- confirm against the full file.
    if '-bval' in arguments:
        fname_bvals = arguments['-bval']
        fname_bvals = ''
    if '-bvalmin' in arguments:
        param.bval_min = arguments['-bvalmin']

    # Initialization
    start_time = time.time()

    # sct.printv(arguments)
    sct.printv('\nInput parameters:', verbose)
    sct.printv('  input file ............' + fname_data, verbose)
    sct.printv('  bvecs file ............' + fname_bvecs, verbose)
    sct.printv('  bvals file ............' + fname_bvals, verbose)
    sct.printv('  average ...............' + str(average), verbose)

    # Get full path
    fname_data = os.path.abspath(fname_data)
    fname_bvecs = os.path.abspath(fname_bvecs)
    if fname_bvals:
        fname_bvals = os.path.abspath(fname_bvals)

    # Extract path, file and extension
    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    # create temporary folder
    path_tmp = sct.tmp_create(basename="dmri_separate", verbose=verbose)

    # copy files into tmp folder and convert to nifti
    sct.printv('\nCopy files into temporary folder...', verbose)
    # Base names of the intermediate files; `ext` is the extension used inside the temporary folder.
    ext = '.nii'
    dmri_name = 'dmri'
    b0_name = file_data + '_b0'
    b0_mean_name = b0_name + '_mean'
    dwi_name = file_data + '_dwi'
    dwi_mean_name = dwi_name + '_mean'

    if not convert(fname_data, os.path.join(path_tmp, dmri_name + ext)):
        sct.printv('ERROR in convert.', 1, 'error')
    sct.copy(fname_bvecs, os.path.join(path_tmp, "bvecs"), verbose=verbose)

    # go to tmp folder
    # NOTE(review): no change of directory is visible after this line, although the following code
    # opens files by their bare name -- confirm against the full file.
    curdir = os.getcwd()

    # Get size of data
    im_dmri = Image(dmri_name + ext)
    sct.printv('\nGet dimensions data...', verbose)
    nx, ny, nz, nt, px, py, pz, pt = im_dmri.dim
    sct.printv('.. ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt), verbose)

    # Identify b=0 and DWI images
    index_b0, index_dwi, nb_b0, nb_dwi = identify_b0(fname_bvecs, fname_bvals, param.bval_min, verbose)

    # Split into T dimension
    sct.printv('\nSplit along T dimension...', verbose)
    im_dmri_split_list = split_data(im_dmri, 3)
    # NOTE(review): the body of this loop is not visible in this view.
    for im_d in im_dmri_split_list:

    # Merge b=0 images
    sct.printv('\nMerge b=0...', verbose)
    from sct_image import concat_data
    # File names of the b=0 volumes: '<dmri_name>_T<index zero-padded to 4 digits><ext>'.
    l = []
    for it in range(nb_b0):
        l.append(dmri_name + '_T' + str(index_b0[it]).zfill(4) + ext)
    im_out = concat_data(l, 3).save(b0_name + ext)

    # Average b=0 images
    if average:
        sct.printv('\nAverage b=0...', verbose)
        sct.run(['sct_maths', '-i', b0_name + ext, '-o', b0_mean_name + ext, '-mean', 't'], verbose)

    # Merge DWI
    # Same naming scheme as above, using the DWI indices.
    l = []
    for it in range(nb_dwi):
        l.append(dmri_name + '_T' + str(index_dwi[it]).zfill(4) + ext)
    im_out = concat_data(l, 3).save(dwi_name + ext)

    # Average DWI images
    if average:
        sct.printv('\nAverage DWI...', verbose)
        sct.run(['sct_maths', '-i', dwi_name + ext, '-o', dwi_mean_name + ext, '-mean', 't'], verbose)

    # come back

    # Generate output files
    # Output names reuse the extension of the input file (ext_data), not the temporary `ext`.
    fname_b0 = os.path.abspath(os.path.join(path_out, b0_name + ext_data))
    fname_dwi = os.path.abspath(os.path.join(path_out, dwi_name + ext_data))
    fname_b0_mean = os.path.abspath(os.path.join(path_out, b0_mean_name + ext_data))
    fname_dwi_mean = os.path.abspath(os.path.join(path_out, dwi_mean_name + ext_data))
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(os.path.join(path_tmp, b0_name + ext), fname_b0, verbose)
    sct.generate_output_file(os.path.join(path_tmp, dwi_name + ext), fname_dwi, verbose)
    if average:
        sct.generate_output_file(os.path.join(path_tmp, b0_mean_name + ext), fname_b0_mean, verbose)
        sct.generate_output_file(os.path.join(path_tmp, dwi_mean_name + ext), fname_dwi_mean, verbose)

    # Remove temporary files
    if remove_temp_files == 1:
        sct.printv('\nRemove temporary files...', verbose)
        sct.rmtree(path_tmp, verbose=verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's', verbose)

    # The mean file names are returned even when `average` is false, in which case the visible code
    # does not generate those files.
    return fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean
def moco(param):
    # Purpose (from the visible code): register every time point of param.file_data to
    # param.file_target with register(), slice by slice when param.is_sagittal is set, replace
    # failed transformations by the closest successful one, and return the array of transformation
    # file prefixes (file_mat).
    # NOTE(review): several loop bodies and `else:` headers are not visible in this view.
    # NOTE(review): only file_mat is returned here, while moco_wrapper() in this file unpacks two
    # values from moco() -- confirm which contract is intended.

    # retrieve parameters
    file_data = param.file_data
    file_target = param.file_target
    folder_mat = param.mat_moco  # output folder of mat file
    todo = param.todo
    suffix = param.suffix
    verbose = param.verbose

    # other parameters
    file_mask = 'mask.nii'

    # The parameters below are concatenated with `+`, so they must already be strings.
    sct.printv('\nInput parameters:', param.verbose)
    sct.printv('  Input file ............' + file_data, param.verbose)
    sct.printv('  Reference file ........' + file_target, param.verbose)
    sct.printv('  Polynomial degree .....' + param.poly, param.verbose)
    sct.printv('  Smoothing kernel ......' + param.smooth, param.verbose)
    sct.printv('  Gradient step .........' + param.gradStep, param.verbose)
    sct.printv('  Metric ................' + param.metric, param.verbose)
    sct.printv('  Sampling ..............' + param.sampling, param.verbose)
    sct.printv('  Todo ..................' + todo, param.verbose)
    sct.printv('  Mask  .................' + param.fname_mask, param.verbose)
    sct.printv('  Output mat folder .....' + folder_mat, param.verbose)

    # create folder for mat files

    # Get size of data
    sct.printv('\nData dimensions:', verbose)
    im_data = Image(param.file_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    sct.printv(('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt)), verbose)

    # copy file_target to a temporary file
    sct.printv('\nCopy file_target to a temporary file...', verbose)
    # From here on file_target names the local copy, not param.file_target.
    file_target = "target.nii.gz"
    convert(param.file_target, file_target)

    # If scan is sagittal, split src and target along Z (slice)
    if param.is_sagittal:
        dim_sag = 2  # TODO: find it
        # z-split data (time series)
        im_z_list = split_data(im_data, dim=dim_sag, squeeze_data=False)
        file_data_splitZ = []
        # NOTE(review): the bodies of the three loops below are not visible in this view.
        for im_z in im_z_list:
        # z-split target
        im_targetz_list = split_data(Image(file_target), dim=dim_sag, squeeze_data=False)
        file_target_splitZ = []
        for im_targetz in im_targetz_list:
        # z-split mask (if exists)
        if not param.fname_mask == '':
            im_maskz_list = split_data(Image(file_mask), dim=dim_sag, squeeze_data=False)
            file_mask_splitZ = []
            for im_maskz in im_maskz_list:
        # initialize file list for output matrices
        # One entry per (slice, time point).
        file_mat = np.empty((nz, nt), dtype=object)

    # axial orientation
        file_data_splitZ = [file_data]  # TODO: make it absolute like above
        file_target_splitZ = [file_target]  # TODO: make it absolute like above
        # initialize file list for output matrices
        # A single row: one entry per time point.
        file_mat = np.empty((1, nt), dtype=object)

        # deal with mask
        if not param.fname_mask == '':
            convert(param.fname_mask, file_mask, squeeze_data=False)
            im_maskz_list = [Image(file_mask)]  # use a list with single element

    # Loop across file list, where each file is either a 2D volume (if sagittal) or a 3D volume (otherwise)
    # file_mat = tuple([[[] for i in range(nt)] for i in range(nz)])

    file_data_splitZ_moco = []
    sct.printv('\nRegister. Loop across Z (note: there is only one Z if orientation is axial')
    for file in file_data_splitZ:
        # Position of the current file in the list, used as the slice index.
        iz = file_data_splitZ.index(file)
        # Split data along T dimension
        # sct.printv('\nSplit data along T dimension.', verbose)
        im_z = Image(file)
        list_im_zt = split_data(im_z, dim=3)
        file_data_splitZ_splitT = []
        # NOTE(review): the body of this loop (filling file_data_splitZ_splitT) is not visible in this view.
        for im_zt in list_im_zt:
        # file_data_splitT = file_data + '_T'

        # Motion correction: initialization
        index = np.arange(nt)
        file_data_splitT_num = []
        file_data_splitZ_splitT_moco = []
        # 0 marks a successful registration; the value stored per time point is what register() returns.
        failed_transfo = [0 for i in range(nt)]

        # Motion correction: Loop across T
        for indice_index in tqdm(range(nt), unit='iter', unit_scale=False,
                                 desc="Z=" + str(iz) + "/" + str(len(file_data_splitZ)-1), ascii=True, ncols=80):

            # create indices and display stuff
            it = index[indice_index]
            # Transformation file prefix: <folder_mat>/mat.Z<iz>T<it>, both indices zero-padded to 4 digits.
            file_mat[iz][it] = os.path.join(folder_mat, "mat.Z") + str(iz).zfill(4) + 'T' + str(it).zfill(4)
            file_data_splitZ_splitT_moco.append(sct.add_suffix(file_data_splitZ_splitT[it], '_moco'))
            # deal with masking
            # NOTE(review): an `else:` header appears to be missing before `input_mask = None`.
            if not param.fname_mask == '':
                input_mask = im_maskz_list[iz]
                input_mask = None
            # run 3D registration
            failed_transfo[it] = register(param, file_data_splitZ_splitT[it], file_target_splitZ[iz], file_mat[iz][it],
                                          file_data_splitZ_splitT_moco[it], im_mask=input_mask)

            # average registered volume with target image
            # N.B. use weighted averaging: (target * nb_it + moco) / (nb_it + 1)
            # Only done for the first 10 iterations, for successful registrations, and when todo is not 'apply'.
            if param.iterAvg and indice_index < 10 and failed_transfo[it] == 0 and not param.todo == 'apply':
                im_targetz = Image(file_target_splitZ[iz])
                data_targetz = im_targetz.data
                data_mocoz = Image(file_data_splitZ_splitT_moco[it]).data
                data_targetz = (data_targetz * (indice_index + 1) + data_mocoz) / (indice_index + 2)
                im_targetz.data = data_targetz

        # Replace failed transformation with the closest good one
        # fT: time indices flagged 1 (failed); gT: time indices flagged 0 (good).
        fT = [i for i, j in enumerate(failed_transfo) if j == 1]
        gT = [i for i, j in enumerate(failed_transfo) if j == 0]
        for it in range(len(fT)):
            abs_dist = [np.abs(gT[i] - fT[it]) for i in range(len(gT))]
            if not abs_dist == []:
                # On ties, list.index() picks the first (lowest) good index.
                index_good = abs_dist.index(min(abs_dist))
                sct.printv('  transfo #' + str(fT[it]) + ' --> use transfo #' + str(gT[index_good]), verbose)
                # copy transformation
                # The 'Warp.nii.gz' suffix is hard-coded here rather than taken from param.
                sct.copy(file_mat[iz][gT[index_good]] + 'Warp.nii.gz', file_mat[iz][fT[it]] + 'Warp.nii.gz')
                # apply transformation
                sct_apply_transfo.main(args=['-i', file_data_splitZ_splitT[fT[it]],
                                             '-d', file_target,
                                             '-w', file_mat[iz][fT[it]] + 'Warp.nii.gz',
                                             '-o', file_data_splitZ_splitT_moco[fT[it]],
                                             '-x', param.interp])
                # exit program if no transformation exists.
                # NOTE(review): an `else:` header appears to be missing before this error message.
                sct.printv('\nERROR in ' + os.path.basename(__file__) + ': No good transformation exist. Exit program.\n', verbose, 'error')

        # Merge data along T
        file_data_splitZ_moco.append(sct.add_suffix(file, suffix))
        if todo != 'estimate':
            im_out = concat_data(file_data_splitZ_splitT_moco, 3)

    # If sagittal, merge along Z
    if param.is_sagittal:
        im_out = concat_data(file_data_splitZ_moco, 2)
        dirname, basename, ext = sct.extract_fname(file_data)
        path_out = os.path.join(dirname, basename + suffix + ext)

    return file_mat
def create_mask(param):
    # Purpose (from the visible code): build a mask slice by slice along Z around a centre defined
    # by param.process[0] ('center', 'coord', 'point' or 'centerline'), using create_mask2d() for
    # each slice where the centerline is non-zero, then concatenate the slices into 'mask_RPI.nii.gz'.
    # NOTE(review): several `else:` headers and the directory changes are not visible in this view.

    # parse argument for method
    method_type = param.process[0]
    # check method val
    # 'center' takes no value; every other method reads its value from param.process[1].
    if not method_type == 'center':
        method_val = param.process[1]

    # check existence of input files
    if method_type == 'centerline':
        sct.check_file_exist(method_val, param.verbose)

    # Extract path/file/extension
    path_data, file_data, ext_data = sct.extract_fname(param.fname_data)

    # Get output folder and file name
    if param.fname_out == '':
        param.fname_out = os.path.abspath(param.file_prefix + file_data + ext_data)

    path_tmp = sct.tmp_create(basename="create_mask", verbose=param.verbose)

    sct.printv('\nOrientation:', param.verbose)
    orientation_input = Image(param.fname_data).orientation
    sct.printv('  ' + orientation_input, param.verbose)

    # copy input data to tmp folder and re-orient to RPI
    Image(param.fname_data).change_orientation("RPI").save(os.path.join(path_tmp, "data_RPI.nii"))
    if method_type == 'centerline':
        Image(method_val).change_orientation("RPI").save(os.path.join(path_tmp, "centerline_RPI.nii"))
    if method_type == 'point':
        Image(method_val).change_orientation("RPI").save(os.path.join(path_tmp, "point_RPI.nii"))

    # go to tmp folder
    # NOTE(review): no change of directory is visible after this line, although the following code
    # opens 'data_RPI.nii' by its bare name -- confirm against the full file.
    curdir = os.getcwd()

    # Get dimensions of data
    im_data = Image('data_RPI.nii')
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    sct.printv('\nDimensions:', param.verbose)
    sct.printv(im_data.dim, param.verbose)
    # in case user input 4d data
    if nt != 1:
        sct.printv('WARNING in ' + os.path.basename(__file__) + ': Input image is 4d but output mask will be 3D from first time slice.', param.verbose, 'warning')
        # extract first volume to have 3d reference
        # Builds an image with msct_image.empty_like() and keeps index 0 of its 4th axis.
        nii = msct_image.empty_like(Image('data_RPI.nii'))
        data3d = nii.data[:, :, :, 0]
        nii.data = data3d

    if method_type == 'coord':
        # parse to get coordinate
        # method_val is split on 'x' and every part converted to int.
        coord = [x for x in map(int, method_val.split('x'))]

    if method_type == 'point':
        # get file name
        # extract coordinate of point
        sct.printv('\nExtract coordinate of point...', param.verbose)
        # TODO: change this way to remove dependence to sct.run. ProcessLabels.display_voxel returns list of coordinates
        status, output = sct.run(['sct_label_utils', '-i', 'point_RPI.nii', '-display'], verbose=param.verbose)
        # parse to get coordinate
        # TODO fixup... this is quite magic
        # The slice offsets (+10, -17) depend on the exact text printed by sct_label_utils.
        coord = output[output.find('Position=') + 10:-17].split(',')

    if method_type == 'center':
        # set coordinate at center of FOV
        coord = np.round(float(nx) / 2), np.round(float(ny) / 2)

    if method_type == 'centerline':
        # get name of centerline from user argument
        fname_centerline = 'centerline_RPI.nii'
        # NOTE(review): an `else:` header appears to be missing before the next lines; as shown,
        # create_line() would overwrite fname_centerline inside the 'centerline' branch, where
        # `coord` is not assigned by the visible code.
        # generate volume with line along Z at coordinates 'coord'
        sct.printv('\nCreate line...', param.verbose)
        fname_centerline = create_line(param, 'data_RPI.nii', coord, nz)

    # create mask
    sct.printv('\nCreate mask...', param.verbose)
    centerline = nibabel.load(fname_centerline)  # open centerline
    hdr = centerline.get_header()  # get header
    hdr.set_data_dtype('uint8')  # set imagetype to uint8
    spacing = hdr.structarr['pixdim']
    data_centerline = centerline.get_data()  # get centerline
    # if data is 2D, reshape with empty third dimension
    # NOTE(review): as shown, the reshape uses the unchanged shape list; a line extending it with a
    # third dimension may be missing from this view.
    if len(data_centerline.shape) == 2:
        data_centerline_shape = list(data_centerline.shape)
        data_centerline = data_centerline.reshape(data_centerline_shape)
    # Indices of the slices that contain at least one non-zero centerline voxel.
    z_centerline_not_null = [iz for iz in range(0, nz, 1) if data_centerline[:, :, iz].any()]
    # get center of mass of the centerline
    # Slices without centerline keep the initial value 0 for both coordinates.
    cx = [0] * nz
    cy = [0] * nz
    for iz in range(0, nz, 1):
        if iz in z_centerline_not_null:
            cx[iz], cy[iz] = ndimage.measurements.center_of_mass(np.array(data_centerline[:, :, iz]))
    # create 2d masks
    file_mask = 'data_mask'
    for iz in range(nz):
        if iz not in z_centerline_not_null:
            # write an empty nifty volume
            img = nibabel.Nifti1Image(data_centerline[:, :, iz], None, hdr)
            nibabel.save(img, (file_mask + str(iz) + '.nii'))
            # NOTE(review): an `else:` header appears to be missing before the next lines, which
            # build the actual 2D mask for slices that do contain centerline voxels.
            center = np.array([cx[iz], cy[iz]])
            mask2d = create_mask2d(param, center, param.shape, param.size, im_data=im_data)
            # Write NIFTI volumes
            img = nibabel.Nifti1Image(mask2d, None, hdr)
            nibabel.save(img, (file_mask + str(iz) + '.nii'))

    # Stack the per-slice files along the third dimension.
    fname_list = [file_mask + str(iz) + '.nii' for iz in range(nz)]
    im_out = concat_data(fname_list, dim=2).save('mask_RPI.nii.gz')

    # The header of the original input image is assigned to the output image.
    im_out.header = Image(param.fname_data).header

    # come back

    # Remove temporary files
    # NOTE(review): only the message is visible here; the call removing path_tmp is not shown.
    if param.remove_temp_files == 1:
        sct.printv('\nRemove temporary files...', param.verbose)

    sct.display_viewer_syntax([param.fname_data, param.fname_out], colormaps=['gray', 'red'], opacities=['', '0.5'])
# Compute a centerline for `image_fname` with the method selected by `algo`
# ('svm', 'cnn', 'viewer' or 'file') and return the tuple
# ("dummy_file_name", im_ctl, im_labels); im_labels is None unless algo == 'viewer'.
# NOTE(review): several structural lines (docstring quotes, `else:` branches, closing
# parentheses of multi-line calls) appear to be missing from this view of the function;
# the inline NOTE(review) comments mark each spot -- confirm against the full file.
def find_centerline(algo, image_fname, contrast_type, brain_bool, folder_output, remove_temp_files, centerline_fname):
    # NOTE(review): the next lines read as the body of a docstring whose triple-quote
    # delimiters are not visible here.
    Assumes RPI orientation

    :param algo:
    :param image_fname:
    :param contrast_type:
    :param brain_bool:
    :param folder_output:
    :param remove_temp_files:
    :param centerline_fname:

    im = Image(image_fname)
    # Path assigned to the output centerline image at the end of the function
    ctl_absolute_path = sct.add_suffix(im.absolutepath, "_ctr")

    # isct_spine_detect requires nz > 1
    if im.dim[2] == 1:
        # Single-slice input: duplicate the slice along dim 2 so that nz == 2
        im = concat_data([im, im], dim=2)
        im.hdr['dim'][3] = 2  # Needs to be changed manually since dim not updated during concat_data
        bool_2d = True
        # NOTE(review): `bool_2d = False` directly follows `bool_2d = True` at the same
        # indentation; an `else:` line appears to be missing -- confirm.
        bool_2d = False

    # TODO: maybe change 'svm' for 'optic', because this is how we call it in sct_get_centerline
    if algo == 'svm':
        # run optic on a heatmap computed by a trained SVM+HoG algorithm
        # optic_models_fname = os.path.join(path_sct, 'data', 'optic_models', '{}_model'.format(contrast_type))
        # # TODO: replace with get_centerline(method=optic)
        im_ctl, _, _, _ = get_centerline(im,
                                        ParamCenterline(algo_fitting='optic', contrast=contrast_type))

    elif algo == 'cnn':
        # CNN parameters
        # Per-contrast patch size and normalisation values, keyed by contrast_type
        dct_patch_ctr = {'t2': {'size': (80, 80), 'mean': 51.1417, 'std': 57.4408},
                         't2s': {'size': (80, 80), 'mean': 68.8591, 'std': 71.4659},
                         't1': {'size': (80, 80), 'mean': 55.7359, 'std': 64.3149},
                         'dwi': {'size': (80, 80), 'mean': 55.744, 'std': 45.003}}
        # Per-contrast network hyper-parameters, keyed by contrast_type
        dct_params_ctr = {'t2': {'features': 16, 'dilation_layers': 2},
                          't2s': {'features': 8, 'dilation_layers': 3},
                          't1': {'features': 24, 'dilation_layers': 3},
                          'dwi': {'features': 8, 'dilation_layers': 2}}

        # load model
        ctr_model_fname = os.path.join(sct.__sct_dir__, 'data', 'deepseg_sc_models', '{}_ctr.h5'.format(contrast_type))
        # NOTE(review): the call below is not closed in this view; its remaining
        # arguments appear to be missing.
        ctr_model = nn_architecture_ctr(height=dct_patch_ctr[contrast_type]['size'][0],

        # compute the heatmap
        # NOTE(review): the heatmap(...) call is also unclosed in this view.
        im_heatmap, z_max = heatmap(im=im,
        im_ctl, _, _, _ = get_centerline(im_heatmap,
                                        ParamCenterline(algo_fitting='optic', contrast=contrast_type))

        if z_max is not None:
            sct.printv('Cropping brain section.')
            # Zero the centerline from slice z_max upwards along the third axis
            im_ctl.data[:, :, z_max:] = 0

    elif algo == 'viewer':
        im_labels = _call_viewer_centerline(im)
        im_ctl, _, _, _ = get_centerline(im_labels, param=ParamCenterline())

    elif algo == 'file':
        im_ctl = Image(centerline_fname)

        # NOTE(review): as written this error is logged inside the 'file' branch; it
        # looks like it belongs to a final `else:` that is not visible here -- confirm.
        logger.error('The parameter "-centerline" is incorrect. Please try again.')

    # TODO: for some reason, when algo == 'file', the absolutepath is changed to None out of the method find_centerline
    im_ctl.absolutepath = ctl_absolute_path

    if bool_2d:
        # Undo the slice duplication done above: keep only the first slice
        im_ctl = split_data(im_ctl, dim=2)[0]

    if algo != 'viewer':
        # im_labels is only produced by the 'viewer' branch
        im_labels = None

    # TODO: remove unnecessary return params
    return "dummy_file_name", im_ctl, im_labels
    # Apply the warping fields listed in self.warp_input to self.input_filename by
    # shelling out to isct_antsApplyTransforms, with self.fname_dest passed as the
    # `-r` argument. The output goes to self.output_filename, or to
    # `<source name>_reg<ext>` in the current directory when that attribute is ''.
    # When nt == 1 a single command is built; the remaining code splits the data
    # along T, transforms each volume and merges the results back.
    # NOTE(review): several structural lines (`else:` branches, the opening of some
    # sct.printv(/sct.run( calls, closing parentheses, loop bodies) appear to be
    # missing from this view; inline NOTE(review) comments mark the spots.
    def apply(self):
        # Initialization
        fname_src = self.input_filename  # source image (moving)
        fname_warp_list = self.warp_input  # list of warping fields
        fname_out = self.output_filename  # output
        fname_dest = self.fname_dest  # destination image (fix)
        verbose = self.verbose
        remove_temp_files = self.remove_temp_files
        crop_reference = self.crop  # if = 1, put 0 everywhere around warping field, if = 2, real crop

        interp = sct.get_interpolation('isct_antsApplyTransforms', self.interp)

        # Parse list of warping fields
        sct.printv('\nParse list of warping fields...', verbose)
        use_inverse = []
        fname_warp_list_invert = []
        # fname_warp_list = fname_warp_list.replace(' ', '')  # remove spaces
        # fname_warp_list = fname_warp_list.split(",")  # parse with comma
        for i in range(len(fname_warp_list)):
            # Check if inverse matrix is specified with '-' at the beginning of file name
            if fname_warp_list[i].find('-') == 0:
                use_inverse.append('-i ')
                # NOTE: this strips the '-' in place, i.e. it mutates self.warp_input
                fname_warp_list[i] = fname_warp_list[i][1:]  # remove '-'
                # NOTE(review): an `else:` branch (appending to use_inverse for the
                # non-inverted case) and the opening `sct.printv(` of the call below
                # appear to be missing in this view -- confirm.
                '  Transfo #' + str(i) + ': ' + use_inverse[i] +
                fname_warp_list[i], verbose)
            fname_warp_list_invert.append(use_inverse[i] + fname_warp_list[i])

        # need to check if last warping field is an affine transfo
        isLastAffine = False
        # NOTE(review): the extract_fname( call below is not closed in this view, so
        # which file name is inspected cannot be told from here.
        path_fname, file_fname, ext_fname = sct.extract_fname(
        if ext_fname in ['.txt', '.mat']:
            isLastAffine = True

        # check if destination file is 3d
        if not sct.check_if_3d(fname_dest):
            # NOTE(review): as written this only prints the message; no exit/raise and
            # no message type argument are visible here.
            sct.printv('ERROR: Destination data must be 3d')

        # N.B. Here we take the inverse of the warp list, because sct_WarpImageMultiTransform concatenates in the reverse order
        # NOTE(review): no statement follows the comment above in this view; the
        # reversal itself is not visible.

        # Extract path, file and extension
        path_src, file_src, ext_src = sct.extract_fname(fname_src)
        path_dest, file_dest, ext_dest = sct.extract_fname(fname_dest)

        # Get output folder and file name
        if fname_out == '':
            path_out = ''  # output in user's current directory
            file_out = file_src + '_reg'
            ext_out = ext_src
            fname_out = path_out + file_out + ext_out

        # Get dimensions of data
        sct.printv('\nGet dimensions of data...', verbose)
        from msct_image import Image
        nx, ny, nz, nt, px, py, pz, pt = Image(fname_src).dim
        # nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(fname_src)
        # NOTE(review): the opening `sct.printv(` of the call below appears to be missing.
            '  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' +
            str(nt), verbose)

        # if 3d
        if nt == 1:
            # Apply transformation
            sct.printv('\nApply transformation...', verbose)
            # Choose the `-d` (dimensionality) argument from the number of slices
            if nz in [0, 1]:
                dim = '2'
                # NOTE(review): `dim = '3'` directly overrides `dim = '2'` as written;
                # an `else:` line appears to be missing, as does the opening
                # `sct.run(` of the command below.
                dim = '3'
                'isct_antsApplyTransforms -d ' + dim + ' -i ' + fname_src +
                ' -o ' + fname_out + ' -t ' +
                ' '.join(fname_warp_list_invert) + ' -r ' + fname_dest +
                interp, verbose)

        # if 4d, loop across the T dimension
        # NOTE(review): the `else:` that would open this branch is not visible here.
            # create temporary folder
            sct.printv('\nCreate temporary folder...', verbose)
            # NOTE(review): the slash_at_the_end( call below is not closed in this view.
            path_tmp = sct.slash_at_the_end(
                'tmp.' + time.strftime("%y%m%d%H%M%S"), 1)
            # sct.run('mkdir '+path_tmp, verbose)
            sct.run('mkdir ' + path_tmp, verbose)

            # convert to nifti into temp folder
            # NOTE(review): the opening `sct.printv(` and the end of this call are missing.
                '\nCopying input data to tmp folder and convert to nii...',
            from sct_convert import convert
            convert(fname_src, path_tmp + 'data.nii', squeeze_data=False)
            sct.run('cp ' + fname_dest + ' ' + path_tmp + file_dest + ext_dest)
            fname_warp_list_tmp = []
            # Copy every warping field into the temporary folder, keeping base names only
            for fname_warp in fname_warp_list:
                path_warp, file_warp, ext_warp = sct.extract_fname(fname_warp)
                # NOTE(review): the command below is truncated in this view.
                sct.run('cp ' + fname_warp + ' ' + path_tmp + file_warp +
                fname_warp_list_tmp.append(file_warp + ext_warp)
            # Reversed order of the copied warping fields, used for the per-volume calls
            fname_warp_list_invert_tmp = fname_warp_list_tmp[::-1]

            # split along T dimension
            sct.printv('\nSplit along T dimension...', verbose)
            from sct_image import split_data
            # NOTE(review): the file names from here on are relative ('data.nii',
            # 'data_T*.nii'), but no change of working directory into path_tmp is
            # visible in this view -- confirm.
            im_dat = Image('data.nii')
            im_header = im_dat.hdr
            data_split_list = split_data(im_dat, 3)
            # NOTE(review): the body of this loop is not visible here.
            for im in data_split_list:

            # apply transfo
            sct.printv('\nApply transformation to each 3D volume...', verbose)
            for it in range(nt):
                # Per-volume input/output names, zero-padded to 4 digits
                file_data_split = 'data_T' + str(it).zfill(4) + '.nii'
                file_data_split_reg = 'data_reg_T' + str(it).zfill(4) + '.nii'
                status, output = sct.run(
                    'isct_antsApplyTransforms -d 3 -i ' + file_data_split +
                    ' -o ' + file_data_split_reg + ' -t ' +
                    ' '.join(fname_warp_list_invert_tmp) + ' -r ' + file_dest +
                    ext_dest + interp, verbose)

            # Merge files back
            sct.printv('\nMerge file back...', verbose)
            from sct_image import concat_data
            import glob
            path_out, name_out, ext_out = sct.extract_fname(fname_out)
            # im_list = [Image(file_name) for file_name in glob.glob('data_reg_T*.nii')]
            # concat_data use to take a list of image in input, now takes a list of file names to open the files one by one (see issue #715)
            # NOTE(review): glob.glob() does not guarantee any ordering; presumably the
            # volumes must be concatenated in T order -- confirm whether sorting is needed.
            fname_list = glob.glob('data_reg_T*.nii')
            im_out = concat_data(fname_list, 3, im_header['pixdim'])
            im_out.setFileName(name_out + ext_out)

            sct.generate_output_file(path_tmp + name_out + ext_out, fname_out)
            # Delete temporary folder if specified
            if int(remove_temp_files):
                sct.printv('\nRemove temporary files...', verbose)
                sct.run('rm -rf ' + path_tmp, verbose, error_exit='warning')

        # 2. crop the resulting image using dimensions from the warping field
        warping_field = fname_warp_list_invert[-1]
        # if last warping field is an affine transfo, we need to compute the space of the concatenate warping field:
        if isLastAffine:
            # NOTE(review): the opening `sct.printv(` is missing in this view. Also, this
            # branch is taken when the inspected transform IS affine (.txt/.mat), yet the
            # message recommends an affine last transformation -- confirm the wording.
                'WARNING: the resulting image could have wrong apparent results. You should use an affine transformation as last transformation...',
                verbose, 'warning')
        elif crop_reference == 1:
            # NOTE(review): no executable statement is visible in this branch.
            # sct.run('sct_crop_image -i '+fname_out+' -o '+fname_out+' -ref '+warping_field+' -b 0')
        elif crop_reference == 2:
            # NOTE(review): no executable statement is visible in this branch.
            # sct.run('sct_crop_image -i '+fname_out+' -o '+fname_out+' -ref '+warping_field)

        # display elapsed time
        sct.printv('\nDone! To view results, type:', verbose)
        # NOTE(review): the call below is truncated in this view (no closing argument/parenthesis).
        sct.printv('fslview ' + fname_dest + ' ' + fname_out + ' &\n', verbose,
# Motion-correction pipeline for diffusion data stored as 'dmri.nii' with 'bvecs.txt'
# in the current directory. Visible steps: identify b=0 and DWI volumes, split along T,
# merge and average the b=0 images, build groups of `param.group_size` DWI volumes and
# average each group, set up `param_moco` for motion estimation on the b=0 images and on
# the DWI groups, copy the resulting per-volume warping fields into 'mat_final/',
# optionally regularise them with a spline along T, apply them to the whole series, copy
# the header from the input to the corrected file, and finally call
# sct_dmri_separate_b0_and_dwi on the corrected data.
# NOTE(review): several loop bodies, `else:` lines and calls appear to be missing from
# this view of the function; inline NOTE(review) comments mark the spots.
def dmri_moco(param):

    file_data = 'dmri'
    ext_data = '.nii'
    file_b0 = 'b0'
    file_dwi = 'dwi'
    mat_final = 'mat_final/'
    file_dwi_group = 'dwi_averaged_groups'  # no extension
    fsloutput = 'export FSLOUTPUTTYPE=NIFTI; '  # for faster processing, all outputs are in NIFTI
    ext_mat = 'Warp.nii.gz'  # warping field

    # Get dimensions of data
    sct.printv('\nGet dimensions of data...', param.verbose)
    im_data = Image(file_data + ext_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    sct.printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), param.verbose)

    # Identify b=0 and DWI images
    sct.printv('\nIdentify b=0 and DWI images...', param.verbose)
    index_b0, index_dwi, nb_b0, nb_dwi = identify_b0('bvecs.txt', param.fname_bvals, param.bval_min, param.verbose)

    # check if dmri and bvecs are the same size
    if not nb_b0 + nb_dwi == nt:
        sct.printv('\nERROR in '+os.path.basename(__file__)+': Size of data ('+str(nt)+') and size of bvecs ('+str(nb_b0+nb_dwi)+') are not the same. Check your bvecs file.\n', 1, 'error')

    # Prepare NIFTI (mean/groups...)
    # Split into T dimension
    sct.printv('\nSplit along T dimension...', param.verbose)
    im_data_split_list = split_data(im_data, 3)
    # NOTE(review): the body of this loop is not visible here.
    for im in im_data_split_list:

    # Merge b=0 images
    sct.printv('\nMerge b=0...', param.verbose)
    # cmd = fsloutput + 'fslmerge -t ' + file_b0
    # for it in range(nb_b0):
    #     cmd = cmd + ' ' + file_data + '_T' + str(index_b0[it]).zfill(4)
    im_b0_list = []
    # NOTE(review): the body of this loop (presumably filling im_b0_list) is not visible.
    for it in range(nb_b0):
    im_b0_out = concat_data(im_b0_list, 3)
    im_b0_out.setFileName(file_b0 + ext_data)
    sct.printv(('  File created: ' + file_b0), param.verbose)

    # Average b=0 images
    sct.printv('\nAverage b=0...', param.verbose)
    file_b0_mean = file_b0+'_mean'
    sct.run('sct_maths -i '+file_b0+ext_data+' -o '+file_b0_mean+ext_data+' -mean t', param.verbose)
    # if not average_data_across_dimension(file_b0+'.nii', file_b0_mean+'.nii', 3):
    #     sct.printv('ERROR in average_data_across_dimension', 1, 'error')
    # cmd = fsloutput + 'fslmaths ' + file_b0 + ' -Tmean ' + file_b0_mean
    # status, output = sct.run(cmd, param.verbose)

    # Number of DWI groups
    nb_groups = int(math.floor(nb_dwi/param.group_size))
    # Generate groups indexes
    group_indexes = []
    # NOTE(review): the body of this loop (presumably appending to group_indexes) is not visible.
    for iGroup in range(nb_groups):
    # add the remaining images to the last DWI group
    nb_remaining = nb_dwi%param.group_size  # number of remaining images
    if nb_remaining > 0:
        nb_groups += 1
        # NOTE(review): nb_groups is incremented but no matching append to
        # group_indexes is visible in this view -- confirm.

    # DWI groups
    file_dwi_mean = []
    for iGroup in range(nb_groups):
        sct.printv('\nDWI group: ' +str((iGroup+1))+'/'+str(nb_groups), param.verbose)

        # get index
        index_dwi_i = group_indexes[iGroup]
        nb_dwi_i = len(index_dwi_i)

        # Merge DW Images
        sct.printv('Merge DW images...', param.verbose)
        file_dwi_merge_i = file_dwi + '_' + str(iGroup)

        im_dwi_list = []
        # NOTE(review): the body of this loop (presumably filling im_dwi_list) is not visible.
        for it in range(nb_dwi_i):
        im_dwi_out = concat_data(im_dwi_list, 3)
        im_dwi_out.setFileName(file_dwi_merge_i + ext_data)

        # Average DW Images
        sct.printv('Average DW images...', param.verbose)
        file_dwi_mean.append(file_dwi + '_mean_' + str(iGroup))
        sct.run('sct_maths -i '+file_dwi_merge_i+ext_data+' -o '+file_dwi_mean[iGroup]+ext_data+' -mean t', param.verbose)

    # Merge DWI groups means
    sct.printv('\nMerging DW files...', param.verbose)
    # file_dwi_groups_means_merge = 'dwi_averaged_groups'
    im_dw_list = []
    for iGroup in range(nb_groups):
        im_dw_list.append(Image(file_dwi_mean[iGroup] + ext_data))
    im_dw_out = concat_data(im_dw_list, 3)
    im_dw_out.setFileName(file_dwi_group + ext_data)
    # cmd = fsloutput + 'fslmerge -t ' + file_dwi_group
    # for iGroup in range(nb_groups):
    #     cmd = cmd + ' ' + file_dwi + '_mean_' + str(iGroup)

    # Average DW Images
    # TODO: USEFUL ???
    sct.printv('\nAveraging all DW images...', param.verbose)
    # NOTE(review): fname_dwi_mean is not read again in the visible part of this function.
    fname_dwi_mean = file_dwi+'_mean'
    sct.run('sct_maths -i '+file_dwi_group+ext_data+' -o '+file_dwi_group+'_mean'+ext_data+' -mean t', param.verbose)

    # segment dwi images using otsu algorithm
    if param.otsu:
        sct.printv('\nSegment group DWI using OTSU algorithm...', param.verbose)
        # import module
        otsu = importlib.import_module('sct_otsu')
        # get class from module
        param_otsu = otsu.param()  #getattr(otsu, param)
        param_otsu.fname_data = file_dwi_group+ext_data
        param_otsu.threshold = param.otsu
        param_otsu.file_suffix = '_seg'
        # run otsu
        # NOTE(review): the call that actually runs otsu is not visible in this view.
        # From here on file_dwi_group refers to the '_seg' file
        file_dwi_group = file_dwi_group+'_seg'

    # extract first DWI volume as target for registration
    nii = Image(file_dwi_group+ext_data)
    # Keep a single volume along the 4th axis (slice of width 1, so the array stays 4D)
    data_crop = nii.data[:, :, :, index_dwi[0]:index_dwi[0]+1]
    nii.data = data_crop
    target_dwi_name = 'target_dwi'
    # NOTE(review): no statement naming/saving `nii` as target_dwi_name is visible here.


    # Estimate moco on b0 groups
    sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
    sct.printv('  Estimating motion on b=0 images...', param.verbose)
    sct.printv('-------------------------------------------------------------------------------', param.verbose)
    # NOTE: this is an alias, not a copy -- every assignment to param_moco below also
    # modifies the caller's `param` object.
    param_moco = param
    param_moco.file_data = 'b0'
    if index_dwi[0] != 0:
        # If first DWI is not the first volume (most common), then there is at least one b=0 image before. In that case
        # select it as the target image for registration of all b=0
        param_moco.file_target = file_data + '_T' + str(index_b0[index_dwi[0]-1]).zfill(4)
        # NOTE(review): an `else:` line appears to be missing here; as written the
        # assignment below always overrides the one above -- confirm.
        # If first DWI is the first volume, then the target b=0 is the first b=0 from the index_b0.
        param_moco.file_target = file_data + '_T' + str(index_b0[0]).zfill(4)
    param_moco.path_out = ''
    param_moco.todo = 'estimate'
    param_moco.mat_moco = 'mat_b0groups'
    # NOTE(review): the call that runs motion correction with param_moco is not visible here.

    # Estimate moco on dwi groups
    sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
    sct.printv('  Estimating motion on DW images...', param.verbose)
    sct.printv('-------------------------------------------------------------------------------', param.verbose)
    param_moco.file_data = file_dwi_group
    param_moco.file_target = target_dwi_name  # target is the first DW image (closest to the first b=0)
    param_moco.path_out = ''
    # param_moco.todo = 'estimate'
    param_moco.todo = 'estimate_and_apply'
    param_moco.mat_moco = 'mat_dwigroups'
    # NOTE(review): the call that runs motion correction with param_moco is not visible here.

    # create final mat folder
    # NOTE(review): the statement creating the folder is not visible in this view.

    # Copy b=0 registration matrices
    sct.printv('\nCopy b=0 registration matrices...', param.verbose)

    # The it-th b=0 warping field is stored under the index of that volume in the full series
    for it in range(nb_b0):
        sct.run('cp '+'mat_b0groups/'+'mat.T'+str(it)+ext_mat+' '+mat_final+'mat.T'+str(index_b0[it])+ext_mat, param.verbose)

    # Copy DWI registration matrices
    sct.printv('\nCopy DWI registration matrices...', param.verbose)
    # Every volume of a group receives a copy of that group's warping field
    for iGroup in range(nb_groups):
        for dwi in range(len(group_indexes[iGroup])):
            sct.run('cp '+'mat_dwigroups/'+'mat.T'+str(iGroup)+ext_mat+' '+mat_final+'mat.T'+str(group_indexes[iGroup][dwi])+ext_mat, param.verbose)

    # Spline Regularization along T
    if param.spline_fitting:
        moco.spline(mat_final, nt, nz, param.verbose, np.array(index_b0), param.plot_graph)

    # combine Eddy Matrices
    if param.run_eddy:
        param.mat_2_combine = 'mat_eddy'
        param.mat_final = mat_final
        # NOTE(review): the call that combines the matrices is not visible here.

    # Apply moco on all dmri data
    sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
    sct.printv('  Apply moco', param.verbose)
    sct.printv('-------------------------------------------------------------------------------', param.verbose)
    param_moco.file_data = file_data
    param_moco.file_target = file_dwi+'_mean_'+str(0)  # reference for reslicing into proper coordinate system
    param_moco.path_out = ''
    param_moco.mat_moco = mat_final
    param_moco.todo = 'apply'
    # NOTE(review): the call that runs motion correction with param_moco is not visible here.

    # copy geometric information from header
    # NB: this is required because WarpImageMultiTransform in 2D mode wrongly sets pixdim(3) to "1".
    im_dmri = Image(file_data+ext_data)
    im_dmri_moco = Image(file_data+param.suffix+ext_data)
    im_dmri_moco = copy_header(im_dmri, im_dmri_moco)
    # NOTE(review): no save of im_dmri_moco is visible after the header copy.

    # generate b0_moco_mean and dwi_moco_mean
    cmd = 'sct_dmri_separate_b0_and_dwi -i '+file_data+param.suffix+ext_data+' -bvec bvecs.txt -a 1'
    # Pass the bvals file along only when one was provided
    if not param.fname_bvals == '':
        cmd = cmd+' -m '+param.fname_bvals
    sct.run(cmd, param.verbose)
# File-name based centerline detection: select a method with `algo`
# ('svm', 'cnn', 'viewer' or 'manual') and return (fname_res, centerline_filename).
# NOTE(review): this function shares its name with the other find_centerline visible in
# this chunk but takes an extra `path_sct` parameter; in a single module the later
# definition would shadow the earlier one -- presumably they come from different
# files, confirm.
# NOTE(review): many closing brackets, `else:` lines and call arguments appear to be
# missing from this view; inline NOTE(review) comments mark the spots.
def find_centerline(algo, image_fname, path_sct, contrast_type, brain_bool,
                    folder_output, remove_temp_files, centerline_fname):

    if Image(image_fname).dim[2] == 1:  # isct_spine_detect requires nz > 1
        from sct_image import concat_data
        # Single-slice input: duplicate the slice along dim 2 and work on the '_concat' file
        im_concat = concat_data([image_fname, image_fname], dim=2)
        im_concat.save(sct.add_suffix(image_fname, '_concat'))
        image_fname = sct.add_suffix(image_fname, '_concat')
        bool_2d = True
        # NOTE(review): `bool_2d = False` directly follows `bool_2d = True` at the same
        # indentation; an `else:` line appears to be missing -- confirm.
        bool_2d = False

    if algo == 'svm':
        # run optic on a heatmap computed by a trained SVM+HoG algorithm
        # NOTE(review): both calls below are unclosed in this view.
        optic_models_fname = os.path.join(path_sct, 'data', 'optic_models',
        _, centerline_filename = optic.detect_centerline(
    elif algo == 'cnn':
        # CNN parameters
        # Per-contrast patch size and normalisation values.
        # NOTE(review): the closing `},` / `}` lines of these dict literals are not
        # visible in this view.
        dct_patch_ctr = {
            't2': {
                'size': (80, 80),
                'mean': 51.1417,
                'std': 57.4408
            't2s': {
                'size': (80, 80),
                'mean': 68.8591,
                'std': 71.4659
            't1': {
                'size': (80, 80),
                'mean': 55.7359,
                'std': 64.3149
            'dwi': {
                'size': (80, 80),
                'mean': 55.744,
                'std': 45.003
        # Per-contrast network hyper-parameters
        dct_params_ctr = {
            't2': {
                'features': 16,
                'dilation_layers': 2
            't2s': {
                'features': 8,
                'dilation_layers': 3
            't1': {
                'features': 24,
                'dilation_layers': 3
            'dwi': {
                'features': 8,
                'dilation_layers': 2

        # load model
        # NOTE(review): both calls below are unclosed in this view.
        ctr_model_fname = os.path.join(path_sct, 'data', 'deepseg_sc_models',
        ctr_model = nn_architecture_ctr(

        sct.log.info("Resample the image to 0.5 mm isotropic resolution...")
        fname_res = sct.add_suffix(image_fname, '_resampled')
        # Elements 4..6 of Image.dim are read here as the input resolution
        input_resolution = Image(image_fname).dim[4:7]
        # Target resolution string: 0.5 for the first two axes, third axis left at its input value
        new_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
        # NOTE(review): the resampling call that would use new_resolution is not visible here.


        # compute the heatmap
        fname_heatmap = sct.add_suffix(image_fname, "_heatmap")
        # Path + base name of the heatmap file (extension dropped), then '.nii' appended
        img_filename = ''.join(sct.extract_fname(fname_heatmap)[:2])
        fname_heatmap_nii = img_filename + '.nii'
        # NOTE(review): the heatmap(...) call below is unclosed in this view.
        z_max = heatmap(filename_in=fname_res,

        # run optic on the heatmap
        centerline_filename = sct.add_suffix(fname_heatmap, "_ctr")
        # NOTE(review): the opening of the call these keyword arguments belong to is not visible.
                      lambda_value=7 if contrast_type == 't2s' else 1,
                      z_max=z_max if brain_bool else None)

    elif algo == 'viewer':
        centerline_filename = sct.add_suffix(image_fname, "_ctr")
        fname_labels_viewer = _call_viewer_centerline(fname_in=image_fname)
        # NOTE(review): the extract_centerline(...) call below is unclosed in this view.
        centerline_filename = extract_centerline(fname_labels_viewer,
    elif algo == 'manual':
        centerline_filename = sct.add_suffix(image_fname, "_ctr")
        image_manual_centerline = Image(centerline_fname)
        # Re-orient and Re-sample the manual centerline
        # NOTE(review): the re-orientation code, the final `else:` and the opening of the
        # error call this message belongs to are not visible in this view.
            'The parameter "-centerline" is incorrect. Please try again.')

    if algo != 'cnn':
        # Same resampling set-up as in the 'cnn' branch, for all the other methods
        sct.log.info("Resample the image to 0.5 mm isotropic resolution...")
        fname_res = sct.add_suffix(image_fname, '_resampled')
        input_resolution = Image(image_fname).dim[4:7]
        new_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
        # NOTE(review): the resampling call that would use new_resolution is not visible here.



    if bool_2d:
        from sct_image import split_data
        im_split_lst = split_data(Image(centerline_filename), dim=2)
        # NOTE(review): im_split_lst is not used further in the visible lines.

    return fname_res, centerline_filename
# Flatten an anatomical volume along a centerline: reorient both images to RPI, extract
# one (x, y) centerline point per slice between the first and last non-zero slice, fit
# it ('nurbs' or 'polynome'), write one 4x4 transformation matrix per slice that
# translates the slice along x by the fitted displacement relative to slice
# min_z_index, apply the matrices with FSL flirt, merge the slices back and reorient.
# NOTE: written with Python 2 `print` statements (and `commands`), so it does not parse
# under Python 3 as is.
# NOTE(review): `param` and `fsloutput` are used but not defined in the visible part of
# this function (presumably module-level names -- confirm). Several structural lines
# (`try:`, `else:`, closing brackets, loop bodies, fid.close()) also appear to be
# missing from this view; inline NOTE(review) comments mark the spots.
def main(fname_anat, fname_centerline, degree_poly, centerline_fitting, interp,
         remove_temp_files, verbose):

    # extract path of the script
    path_script = os.path.dirname(__file__) + '/'

    # Parameters for debug mode
    if param.debug == 1:
        print '\n*** WARNING: DEBUG MODE ON ***\n'
        # Debug mode overrides both input file names with the testing data set
        status, path_sct_data = commands.getstatusoutput(
            'echo $SCT_TESTING_DATA_DIR')
        fname_anat = path_sct_data + '/t2/t2.nii.gz'
        fname_centerline = path_sct_data + '/t2/t2_seg.nii.gz'

    # extract path/file/extension
    path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)

    # Display arguments
    print '\nCheck input arguments...'
    print '  Input volume ...................... ' + fname_anat
    print '  Centerline ........................ ' + fname_centerline
    print ''

    # Get input image orientation
    im_anat = Image(fname_anat)
    input_image_orientation = get_orientation_3d(im_anat)

    # Reorient input data into RL PA IS orientation
    im_centerline = Image(fname_centerline)
    im_anat_orient = set_orientation(im_anat, 'RPI')
    im_centerline_orient = set_orientation(im_centerline, 'RPI')

    # Open centerline
    print '\nGet dimensions of input centerline...'
    nx, ny, nz, nt, px, py, pz, pt = im_centerline_orient.dim
    print '.. matrix size: ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz)
    print '.. voxel size:  ' + str(px) + 'mm x ' + str(py) + 'mm x ' + str(
        pz) + 'mm'

    print '\nOpen centerline volume...'
    data = im_centerline_orient.data

    # First and last slice index (third axis) that contain at least one non-zero voxel
    X, Y, Z = (data > 0).nonzero()
    min_z_index, max_z_index = min(Z), max(Z)

    # loop across z and associate x,y coordinate with the point having maximum intensity
    # The three lists are indexed by (iz - min_z_index), not by iz
    x_centerline = [0 for iz in range(min_z_index, max_z_index + 1, 1)]
    y_centerline = [0 for iz in range(min_z_index, max_z_index + 1, 1)]
    z_centerline = [iz for iz in range(min_z_index, max_z_index + 1, 1)]

    # Two possible scenario:
    # 1. the centerline is probabilistic: each slices contains voxels with the probability of containing the centerline [0:...:1]
    # We only take the maximum value of the image to approximate the centerline.
    # 2. The centerline/segmentation image contains many pixels per slice with values {0,1}.
    # We take all the points and approximate the centerline on all these points.

    X, Y, Z = ((data < 1) * (data > 0)).nonzero()  # X is empty if binary image
    if (len(X) > 0):  # Scenario 1
        for iz in range(min_z_index, max_z_index + 1, 1):
            # (x, y) position of the maximum-intensity voxel of slice iz
            x_centerline[iz - min_z_index], y_centerline[
                iz - min_z_index] = numpy.unravel_index(
                    data[:, :, iz].argmax(), data[:, :, iz].shape)
    else:  # Scenario 2
        for iz in range(min_z_index, max_z_index + 1, 1):
            # Mean (x, y) of the non-zero voxels of slice iz; empty slices keep (0, 0)
            x_seg, y_seg = (data[:, :, iz] > 0).nonzero()
            if len(x_seg) > 0:
                x_centerline[iz - min_z_index] = numpy.mean(x_seg)
                y_centerline[iz - min_z_index] = numpy.mean(y_seg)

    # TODO: find a way to do the previous loop with this, which is more neat:
    # [numpy.unravel_index(data[:,:,iz].argmax(), data[:,:,iz].shape) for iz in range(0,nz,1)]

    # clear variable
    del data

    # Fit the centerline points with the kind of curve given as argument of the script and return the new smoothed coordinates
    # NOTE(review): if centerline_fitting is neither 'nurbs' nor 'polynome', no fit is
    # computed in the visible code and x_centerline_fit stays undefined.
    if centerline_fitting == 'nurbs':
        # NOTE(review): the `try:` line matching the `except ValueError:` below is not
        # visible in this view.
            x_centerline_fit, y_centerline_fit = b_spline_centerline(
                x_centerline, y_centerline, z_centerline)
        except ValueError:
            # Fall back to the polynomial fit when the spline fit raises ValueError
            print "splines fitting doesn't work, trying with polynomial fitting...\n"
            x_centerline_fit, y_centerline_fit = polynome_centerline(
                x_centerline, y_centerline, z_centerline)
    elif centerline_fitting == 'polynome':
        x_centerline_fit, y_centerline_fit = polynome_centerline(
            x_centerline, y_centerline, z_centerline)

    # Split input volume
    print '\nSplit input volume...'
    im_anat_orient_split_list = split_data(im_anat_orient, 2)
    file_anat_split = []
    # NOTE(review): the body of this loop (presumably filling file_anat_split) is not visible.
    for im in im_anat_orient_split_list:

    # initialize variables
    # NOTE(review): the closing `]` of this list comprehension is not visible here.
    file_mat_inv_cumul = [
        'tmp.mat_inv_cumul_Z' + str(z).zfill(4) for z in range(0, nz, 1)
    z_init = min_z_index
    # Displacement of the last centerline slice; reused below for all slices above max_z_index
    displacement_max_z_index = x_centerline_fit[
        z_init - min_z_index] - x_centerline_fit[max_z_index - min_z_index]

    # write centerline as text file
    print '\nGenerate fitted transformation matrices...'
    # NOTE(review): the closing `]` of this list comprehension is not visible here.
    file_mat_inv_cumul_fit = [
        'tmp.mat_inv_cumul_fit_Z' + str(z).zfill(4) for z in range(0, nz, 1)
    for iz in range(min_z_index, max_z_index + 1, 1):
        # compute inverse cumulative fitted transformation matrix
        # NOTE(review): no fid.close() is visible for any of the files opened in the
        # three loops below -- confirm they are closed in the full file.
        fid = open(file_mat_inv_cumul_fit[iz], 'w')
        # A slice whose raw centerline point is still (0, 0) gets no displacement
        if (x_centerline[iz - min_z_index] == 0
                and y_centerline[iz - min_z_index] == 0):
            displacement = 0
            # NOTE(review): an `else:` line appears to be missing here; as written the
            # assignment below always overrides `displacement = 0` -- confirm.
            displacement = x_centerline_fit[
                z_init - min_z_index] - x_centerline_fit[iz - min_z_index]
        # 4x4 matrix: identity plus a translation of `displacement` in the first row
        fid.write('%i %i %i %f\n' % (1, 0, 0, displacement))
        fid.write('%i %i %i %f\n' % (0, 1, 0, 0))
        fid.write('%i %i %i %i\n' % (0, 0, 1, 0))
        fid.write('%i %i %i %i\n' % (0, 0, 0, 1))

    # we complete the displacement matrix in z direction
    # Slices below the centerline: identity matrix (zero displacement)
    for iz in range(0, min_z_index, 1):
        fid = open(file_mat_inv_cumul_fit[iz], 'w')
        fid.write('%i %i %i %f\n' % (1, 0, 0, 0))
        fid.write('%i %i %i %f\n' % (0, 1, 0, 0))
        fid.write('%i %i %i %i\n' % (0, 0, 1, 0))
        fid.write('%i %i %i %i\n' % (0, 0, 0, 1))
    # Slices above the centerline: reuse the displacement of the last centerline slice
    for iz in range(max_z_index + 1, nz, 1):
        fid = open(file_mat_inv_cumul_fit[iz], 'w')
        fid.write('%i %i %i %f\n' % (1, 0, 0, displacement_max_z_index))
        fid.write('%i %i %i %f\n' % (0, 1, 0, 0))
        fid.write('%i %i %i %i\n' % (0, 0, 1, 0))
        fid.write('%i %i %i %i\n' % (0, 0, 0, 1))

    # apply transformations to data
    print '\nApply fitted transformation matrices...'
    # NOTE(review): the closing `]` of this list comprehension is not visible here.
    file_anat_split_fit = [
        'tmp.anat_orient_fit_Z' + str(z).zfill(4) for z in range(0, nz, 1)
    for iz in range(0, nz, 1):
        # forward cumulative transformation to data
        # Each slice is used as both input and reference of flirt -applyxfm
        sct.run(fsloutput + 'flirt -in ' + file_anat_split[iz] + ' -ref ' +
                file_anat_split[iz] + ' -applyxfm -init ' +
                file_mat_inv_cumul_fit[iz] + ' -out ' +
                file_anat_split_fit[iz] + ' -interp ' + interp)

    # Merge into 4D volume
    print '\nMerge into 4D volume...'
    from glob import glob
    # NOTE(review): glob() does not guarantee any ordering; presumably the slices must be
    # concatenated in Z order -- confirm whether sorting is needed. The closing `]` of
    # this list comprehension is also not visible here.
    im_to_concat_list = [
        Image(fname) for fname in glob('tmp.anat_orient_fit_Z*.nii')
    im_concat_out = concat_data(im_to_concat_list, 2)
    # sct.run(fsloutput+'fslmerge -z tmp.anat_orient_fit tmp.anat_orient_fit_z*')

    # Reorient data as it was before
    print '\nReorient data back into native orientation...'
    # NOTE(review): the set_orientation( call below is unclosed in this view.
    fname_anat_fit_orient = set_orientation(im_concat_out.absolutepath,
    move(fname_anat_fit_orient, 'tmp.anat_orient_fit_reorient.nii')

    # Generate output file (in current folder)
    print '\nGenerate output file (in current folder)...'
    # NOTE(review): the opening of the call this argument belongs to is not visible here.
                             file_anat + '_flatten' + ext_anat)

    # Delete temporary files
    if remove_temp_files == 1:
        print '\nDelete temporary files...'
        # Removes everything matching tmp.* in the current working directory
        sct.run('rm -rf tmp.*')

    # to view results
    print '\nDone! To view results, type:'
    print 'fslview ' + file_anat + ext_anat + ' ' + file_anat + '_flatten' + ext_anat + ' &\n'
# Volume-by-volume motion correction of `param.file_data` (a '.nii' series) against
# `param.file_target`. Each volume along T is registered with register(), and the
# per-volume results are written under the folder `param.mat_moco`. Volumes whose
# registration failed reuse the warping field of the closest successful volume.
# Unless param.todo == 'estimate', the corrected volumes are merged back into
# `<file_data><suffix>.nii`.
# NOTE(review): a loop body, an `else:` line and the folder-creation statement appear
# to be missing from this view; inline NOTE(review) comments mark the spots.
def moco(param):

    # retrieve parameters
    fsloutput = 'export FSLOUTPUTTYPE=NIFTI; '  # for faster processing, all outputs are in NIFTI
    file_data = param.file_data
    file_target = param.file_target
    folder_mat = sct.slash_at_the_end(param.mat_moco, 1)  # output folder of mat file
    todo = param.todo
    suffix = param.suffix
    #file_schedule = param.file_schedule
    verbose = param.verbose
    ext = '.nii'

    # get path of the toolbox
    # NOTE: `commands` is a Python 2-only module; path_sct is not read again in the
    # visible part of this function.
    status, path_sct = commands.getstatusoutput('echo $SCT_DIR')

    # print arguments
    # NOTE: param.param[0..3] are concatenated to str below, so they must be strings
    sct.printv('\nInput parameters:', param.verbose)
    sct.printv('  Input file ............'+file_data, param.verbose)
    sct.printv('  Reference file ........'+file_target, param.verbose)
    sct.printv('  Polynomial degree .....'+param.param[0], param.verbose)
    sct.printv('  Smoothing kernel ......'+param.param[1], param.verbose)
    sct.printv('  Gradient step .........'+param.param[2], param.verbose)
    sct.printv('  Metric ................'+param.param[3], param.verbose)
    sct.printv('  Todo ..................'+todo, param.verbose)
    sct.printv('  Mask  .................'+param.fname_mask, param.verbose)
    sct.printv('  Output mat folder .....'+folder_mat, param.verbose)

    # create folder for mat files
    # NOTE(review): the statement creating folder_mat is not visible in this view.

    # Get size of data
    sct.printv('\nGet dimensions data...', verbose)
    data_im = Image(file_data+ext)
    nx, ny, nz, nt, px, py, pz, pt = data_im.dim
    sct.printv(('.. '+str(nx)+' x '+str(ny)+' x '+str(nz)+' x '+str(nt)), verbose)

    # copy file_target to a temporary file
    sct.printv('\nCopy file_target to a temporary file...', verbose)
    sct.run('cp '+file_target+ext+' target.nii')
    # From here on the working copy 'target.nii' is used (and modified by the
    # iterative averaging below) instead of the caller's target file
    file_target = 'target'

    # Split data along T dimension
    sct.printv('\nSplit data along T dimension...', verbose)
    data_split_list = split_data(data_im, dim=3)
    # NOTE(review): the body of this loop is not visible here.
    for im in data_split_list:
    file_data_splitT = file_data + '_T'

    # Motion correction: initialization
    index = np.arange(nt)
    file_data_splitT_num = []
    file_data_splitT_moco_num = []
    # failed_transfo[it] is compared to 0 (success) and 1 (failure) further below
    failed_transfo = [0 for i in range(nt)]
    file_mat = [[] for i in range(nt)]

    # Motion correction: Loop across T
    for indice_index in range(nt):

        # create indices and display stuff
        it = index[indice_index]
        file_data_splitT_num.append(file_data_splitT + str(it).zfill(4))
        file_data_splitT_moco_num.append(file_data + suffix + '_T' + str(it).zfill(4))
        sct.printv(('\nVolume '+str((it))+'/'+str(nt-1)+':'), verbose)
        file_mat[it] = folder_mat + 'mat.T' + str(it)

        # run 3D registration
        failed_transfo[it] = register(param, file_data_splitT_num[it], file_target, file_mat[it], file_data_splitT_moco_num[it])

        # average registered volume with target image
        # N.B. use weighted averaging: (target * nb_it + moco) / (nb_it + 1)
        # Only done for the first 10 volumes, and only when the registration succeeded
        if param.iterative_averaging and indice_index < 10 and failed_transfo[it] == 0:
            sct.run('sct_maths -i '+file_target+ext+' -mul '+str(indice_index+1)+' -o '+file_target+ext)
            sct.run('sct_maths -i '+file_target+ext+' -add '+file_data_splitT_moco_num[it]+ext+' -o '+file_target+ext)
            sct.run('sct_maths -i '+file_target+ext+' -div '+str(indice_index+2)+' -o '+file_target+ext)

    # Replace failed transformation with the closest good one
    sct.printv(('\nReplace failed transformations...'), verbose)
    # fT: indices of failed volumes; gT: indices of successfully registered volumes
    fT = [i for i, j in enumerate(failed_transfo) if j == 1]
    gT = [i for i, j in enumerate(failed_transfo) if j == 0]
    for it in range(len(fT)):
        # Distance (in volume index) from this failed volume to every good one
        abs_dist = [abs(gT[i]-fT[it]) for i in range(len(gT))]
        if not abs_dist == []:
            # Nearest good volume; on ties list.index() picks the first (lowest index)
            index_good = abs_dist.index(min(abs_dist))
            sct.printv('  transfo #'+str(fT[it])+' --> use transfo #'+str(gT[index_good]), verbose)
            # copy transformation
            sct.run('cp '+file_mat[gT[index_good]]+'Warp.nii.gz'+' '+file_mat[fT[it]]+'Warp.nii.gz')
            # apply transformation
            sct.run('sct_apply_transfo -i '+file_data_splitT_num[fT[it]]+'.nii -d '+file_target+'.nii -w '+file_mat[fT[it]]+'Warp.nii.gz'+' -o '+file_data_splitT_moco_num[fT[it]]+'.nii'+' -x '+param.interp, verbose)
            # NOTE(review): an `else:` line appears to be missing here; as written the
            # error below is raised right after a successful replacement -- confirm.
            # exit program if no transformation exists.
            sct.printv('\nERROR in '+os.path.basename(__file__)+': No good transformation exist. Exit program.\n', verbose, 'error')

    # Merge data along T
    file_data_moco = file_data+suffix
    if todo != 'estimate':
        sct.printv('\nMerge data back along T...', verbose)
        from sct_image import concat_data
        # im_list = []
        fname_list = []
        for indice_index in range(len(index)):
            # im_list.append(Image(file_data_splitT_moco_num[indice_index] + ext))
            fname_list.append(file_data_splitT_moco_num[indice_index] + ext)
        im_out = concat_data(fname_list, 3)
        im_out.setFileName(file_data_moco + ext)
        # NOTE(review): no save of im_out is visible after setFileName.

    # delete file target.nii (to avoid conflict if this function is run another time)
    sct.printv('\nRemove temporary file...', verbose)
    sct.run('rm target.nii')
def fmri_moco(param):
    """Run group-wise motion correction on the 4D file ``fmri.nii`` of the current directory.

    The series is split along T, consecutive volumes are gathered into groups of ``param.group_size``
    and averaged, motion is estimated (and applied) on the 4D stack of group means through
    ``moco.moco``, and, unless ``group_size`` is 1, the warping fields of each group are copied to
    every volume of that group in ``mat_final/`` and applied to the full series. Finally the header
    of ``fmri.nii`` is copied onto ``fmri_moco.nii`` and the temporal mean is written to
    ``fmri_moco_mean.nii``.

    :param param: parameter object. ``verbose``, ``group_size`` and ``num_target`` are read;
        ``is_sagittal``, ``group_size`` (forced to 1 for sagittal data), ``file_data``, ``file_target``,
        ``path_out``, ``todo`` and ``mat_moco`` are overwritten (``param_moco`` is an alias of ``param``,
        not a copy).
    :return: None
    """
    import os  # local import: only needed to create the output matrix folder

    file_data = 'fmri'
    ext_data = '.nii'
    mat_final = 'mat_final/'
    ext_mat = 'Warp.nii.gz'  # warping field

    # Get dimensions of data
    sct.printv('\nGet dimensions of data...', param.verbose)
    im_data = Image(file_data + ext_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    sct.printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt), param.verbose)

    # Get orientation: the third orientation letter decides between sagittal and axial processing
    sct.printv('\nData orientation: ' + im_data.orientation, param.verbose)
    if im_data.orientation[2] in 'LR':
        param.is_sagittal = True
        sct.printv('  Treated as sagittal')
    elif im_data.orientation[2] in 'IS':
        param.is_sagittal = False
        sct.printv('  Treated as axial')
    else:
        # neither case matched: process as axial, but tell the user
        param.is_sagittal = False
        sct.printv('WARNING: Orientation seems to be neither axial nor sagittal.')

    # Adjust group size in case of sagittal scan
    if param.is_sagittal and param.group_size != 1:
        sct.printv('For sagittal data group_size should be one for more robustness. Forcing group_size=1.', 1, 'warning')
        param.group_size = 1

    # Split into T dimension
    sct.printv('\nSplit along T dimension...', param.verbose)
    im_data_split_list = split_data(im_data, 3)
    for im in im_data_split_list:
        # NOTE(review): the loop body was absent from this copy of the file; writing each split
        # volume to disk is the presumed intent -- confirm against the upstream implementation.
        im.save()

    # Chunk the volume indexes into consecutive groups of group_size; the last group holds the
    # remaining volumes when nt is not a multiple of group_size.
    index_fmri = list(range(nt))
    group_indexes = [index_fmri[start:start + param.group_size] for start in range(0, nt, param.group_size)]
    nb_groups = len(group_indexes)

    # Merge and average the volumes of each group
    for iGroup in tqdm(range(nb_groups), unit='iter', unit_scale=False, desc="Merge within groups", ascii=True, ncols=80):
        index_fmri_i = group_indexes[iGroup]

        # Merge Images
        file_data_merge_i = file_data + '_' + str(iGroup)
        im_fmri_list = [im_data_split_list[it] for it in index_fmri_i]
        concat_data(im_fmri_list, 3, squeeze_data=True).save(file_data_merge_i + ext_data)

        file_data_mean = file_data + '_mean_' + str(iGroup)
        if param.group_size == 1:
            # copy to new file name instead of averaging (faster)
            # note: this is a bandage. Ideally we should skip this entire for loop if g=1
            sct.copy(file_data_merge_i + ext_data, file_data_mean + ext_data)
        else:
            # Average Images
            sct.run(['sct_maths', '-i', file_data_merge_i + ext_data, '-o', file_data_mean + ext_data, '-mean', 't'], verbose=0)

    # Merge groups means. The output 4D volume will be used for motion correction.
    sct.printv('\nMerging volumes...', param.verbose)
    file_data_groups_means_merge = 'fmri_averaged_groups'
    im_mean_list = [Image(file_data + '_mean_' + str(iGroup) + ext_data) for iGroup in range(nb_groups)]
    concat_data(im_mean_list, 3).save(file_data_groups_means_merge + ext_data)

    # Estimate moco
    sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
    sct.printv('  Estimating motion...', param.verbose)
    sct.printv('-------------------------------------------------------------------------------', param.verbose)
    param_moco = param  # alias: the assignments below also modify `param`
    param_moco.file_data = file_data_groups_means_merge
    # str() keeps the concatenation valid whether num_target is stored as an int or as a string
    param_moco.file_target = file_data + '_mean_' + str(param.num_target)
    param_moco.path_out = ''
    param_moco.todo = 'estimate_and_apply'
    param_moco.mat_moco = 'mat_groups'
    file_mat = moco.moco(param_moco)

    # TODO: if g=1, no need to run the block below (already applied)
    if param.group_size == 1:
        # if flag g=1, it means that all images have already been corrected, so we just need to rename the file
        sct.mv('fmri_averaged_groups_moco.nii', 'fmri_moco.nii')
    else:
        # create final mat folder
        # NOTE(review): the original statement was absent from this copy; a plain directory creation
        # is assumed here.
        if not os.path.isdir(mat_final):
            os.makedirs(mat_final)

        # Copy registration matrices
        sct.printv('\nCopy transformations...', param.verbose)
        for iGroup in range(nb_groups):
            # fetch all file_mat_z for given t-group (same for every volume of the group)
            list_file_mat_z = file_mat[:, iGroup]
            for t_index in group_indexes[iGroup]:
                for file_mat_z in list_file_mat_z:
                    # we want to copy 'mat_groups/mat.ZXXXXTYYYYWarp.nii.gz' --> 'mat_final/mat.ZXXXXTYYYZWarp.nii.gz'
                    # i.e. keep the 'mat.ZXXXX' part (characters 11 to 19 of the path) and replace the group
                    # index after 'T' by the index of the volume: the single matrix of each group is
                    # applied to all images belonging to that group.
                    sct.copy(file_mat_z + ext_mat,
                             mat_final + file_mat_z[11:20] + 'T' + str(t_index).zfill(4) + ext_mat)

        # Apply moco on all fmri data
        sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
        sct.printv('  Apply moco', param.verbose)
        sct.printv('-------------------------------------------------------------------------------', param.verbose)
        param_moco.file_data = 'fmri'
        param_moco.file_target = file_data + '_mean_' + str(0)
        param_moco.path_out = ''
        param_moco.mat_moco = mat_final
        param_moco.todo = 'apply'
        # the 'apply' configuration above only takes effect once moco is actually run
        moco.moco(param_moco)

    # copy geometric information from header
    # NB: this is required because WarpImageMultiTransform in 2D mode wrongly sets pixdim(3) to "1".
    im_fmri = Image('fmri.nii')
    im_fmri_moco = Image('fmri_moco.nii')
    im_fmri_moco.header = im_fmri.header
    # NOTE(review): without writing the image back, the header copy stays in memory only.
    im_fmri_moco.save('fmri_moco.nii')

    # Average volumes
    sct.printv('\nAveraging data...', param.verbose)
    sct_maths.main(args=['-i', 'fmri_moco.nii',
                         '-o', 'fmri_moco_mean.nii',
                         '-mean', 't',
                         '-v', '0'])
def fmri_moco(param):
    """Group-wise motion-correction driver operating on the 4D file 'fmri.nii'.

    The visible code splits the series along T, builds groups of ``param.group_size`` volumes, averages
    each group with ``sct_maths -mean t``, configures a motion estimation on the merged group means,
    copies one transformation per group to every volume of that group in 'mat_final/', configures an
    'apply' pass, copies the header of 'fmri.nii' onto 'fmri_moco.nii' and averages the result over time.

    NOTE(review): this definition reuses the name of an earlier ``fmri_moco`` in the same file, so the
    later definition replaces the earlier one at import time.
    NOTE(review): several statements appear to be missing from this copy (empty loop bodies, images that
    are named but never saved, moco configurations that are never run); the function does not parse as
    written. The individual spots are flagged below.

    :param param: parameter object; ``verbose``, ``group_size`` and ``num_target`` are read, and
        ``file_data``, ``file_target``, ``path_out``, ``todo`` and ``mat_moco`` are overwritten through
        the ``param_moco`` alias.
    :return: None (no return statement)
    """

    file_data = 'fmri'
    ext_data = '.nii'
    mat_final = 'mat_final/'
    # only referenced by the commented-out fsl commands further down
    fsloutput = 'export FSLOUTPUTTYPE=NIFTI; '  # for faster processing, all outputs are in NIFTI
    ext_mat = 'Warp.nii.gz'  # warping field

    # Get dimensions of data
    sct.printv('\nGet dimensions of data...', param.verbose)
    nx, ny, nz, nt, px, py, pz, pt = Image(file_data+'.nii').dim
    sct.printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt), param.verbose)

    # Split into T dimension
    sct.printv('\nSplit along T dimension...', param.verbose)
    im_data = Image(file_data + ext_data)
    im_data_split_list = split_data(im_data, 3)
    # NOTE(review): the body of this loop is missing (presumably each split volume was saved) -- confirm.
    for im in im_data_split_list:

    # assign an index to each volume
    index_fmri = range(0, nt)

    # Number of groups
    nb_groups = int(math.floor(nt/param.group_size))

    # Generate groups indexes
    group_indexes = []
    # NOTE(review): the body of this loop is missing; group_indexes is never filled in the visible code,
    # although it is indexed per group below.
    for iGroup in range(nb_groups):

    # add the remaining images to the last DWI group
    nb_remaining = nt%param.group_size  # number of remaining images
    if nb_remaining > 0:
        # NOTE(review): only the counter is incremented here; no group is appended for the remainder.
        nb_groups += 1

    # groups
    for iGroup in range(nb_groups):
        sct.printv('\nGroup: ' +str((iGroup+1))+'/'+str(nb_groups), param.verbose)

        # get index
        index_fmri_i = group_indexes[iGroup]
        nt_i = len(index_fmri_i)

        # Merge Images
        sct.printv('Merge consecutive volumes...', param.verbose)
        file_data_merge_i = file_data + '_' + str(iGroup)
        # cmd = fsloutput + 'fslmerge -t ' + file_data_merge_i
        # for it in range(nt_i):
        #     cmd = cmd + ' ' + file_data + '_T' + str(index_fmri_i[it]).zfill(4)

        im_fmri_list = []
        # NOTE(review): the body of this loop is missing; im_fmri_list stays empty in the visible code.
        for it in range(nt_i):
        im_fmri_concat = concat_data(im_fmri_list, 3)
        # NOTE(review): the image is given a file name but no save call is visible -- confirm.
        im_fmri_concat.setFileName(file_data_merge_i + ext_data)

        # Average Images
        sct.printv('Average volumes...', param.verbose)
        file_data_mean = file_data + '_mean_' + str(iGroup)
        sct.run('sct_maths -i '+file_data_merge_i+'.nii -o '+file_data_mean+'.nii -mean t')
        # if not average_data_across_dimension(file_data_merge_i+'.nii', file_data_mean+'.nii', 3):
        #     sct.printv('ERROR in average_data_across_dimension', 1, 'error')
        # cmd = fsloutput + 'fslmaths ' + file_data_merge_i + ' -Tmean ' + file_data_mean
        # sct.run(cmd, param.verbose)

    # Merge groups means
    sct.printv('\nMerging volumes...', param.verbose)
    file_data_groups_means_merge = 'fmri_averaged_groups'
    # cmd = fsloutput + 'fslmerge -t ' + file_data_groups_means_merge
    # for iGroup in range(nb_groups):
    #     cmd = cmd + ' ' + file_data + '_mean_' + str(iGroup)
    im_mean_list = []
    for iGroup in range(nb_groups):
        im_mean_list.append(Image(file_data + '_mean_' + str(iGroup) + ext_data))
    im_mean_concat = concat_data(im_mean_list, 3)
    # NOTE(review): named but not saved in the visible code -- confirm.
    im_mean_concat.setFileName(file_data_groups_means_merge + ext_data)

    # Estimate moco on dwi groups
    sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
    sct.printv('  Estimating motion...', param.verbose)
    sct.printv('-------------------------------------------------------------------------------', param.verbose)
    # param_moco is an alias of param (not a copy): the assignments below modify param as well
    param_moco = param
    param_moco.file_data = 'fmri_averaged_groups'
    param_moco.file_target = file_data + '_mean_' + str(param.num_target)
    param_moco.path_out = ''
    param_moco.todo = 'estimate_and_apply'
    param_moco.mat_moco = 'mat_groups'
    # NOTE(review): the estimation is configured but no call that runs it is visible here.

    # create final mat folder
    # NOTE(review): no statement follows this comment; the folder creation seems to be missing.

    # Copy registration matrices
    sct.printv('\nCopy transformations...', param.verbose)
    for iGroup in range(nb_groups):
        for data in range(len(group_indexes[iGroup])):
            # if param.slicewise:
            #     for iz in range(nz):
            #         sct.run('cp '+'mat_dwigroups/'+'mat.T'+str(iGroup)+'_Z'+str(iz)+ext_mat+' '+mat_final+'mat.T'+str(group_indexes[iGroup][dwi])+'_Z'+str(iz)+ext_mat, param.verbose)
            # else:
            # the single transformation of group iGroup is duplicated for every volume index of that group
            sct.run('cp '+'mat_groups/'+'mat.T'+str(iGroup)+ext_mat+' '+mat_final+'mat.T'+str(group_indexes[iGroup][data])+ext_mat, param.verbose)

    # Apply moco on all fmri data
    sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
    sct.printv('  Apply moco', param.verbose)
    sct.printv('-------------------------------------------------------------------------------', param.verbose)
    param_moco.file_data = 'fmri'
    param_moco.file_target = file_data+'_mean_'+str(0)
    param_moco.path_out = ''
    param_moco.mat_moco = mat_final
    param_moco.todo = 'apply'
    # NOTE(review): as above, the 'apply' pass is configured but never run in the visible code.

    # copy geometric information from header
    # NB: this is required because WarpImageMultiTransform in 2D mode wrongly sets pixdim(3) to "1".
    im_fmri = Image('fmri.nii')
    im_fmri_moco = Image('fmri_moco.nii')
    # NOTE(review): the result of copy_header is not saved in the visible code -- confirm.
    im_fmri_moco = copy_header(im_fmri, im_fmri_moco)

    # Average volumes
    sct.printv('\nAveraging data...', param.verbose)
    sct.run('sct_maths -i fmri_moco.nii -o fmri_moco_mean.nii -mean t')
def get_centerline_from_point(input_image, point_file, gap=4, gaussian_kernel=4, remove_tmp_files=1):
    """Propagate a centerline through an anatomical volume, starting from a single labelled point.

    Outline of the visible code: the inputs are copied to a temporary folder and reoriented to RPI, the
    volume is split slice by slice, a 2D gaussian mask is centred on the input point, slices located
    ``gap`` apart are registered to each other going up and then down from the point, the in-plane
    translations are accumulated into centerline coordinates, polynomials are fitted in the Z-X and Z-Y
    planes, and fitted transformation matrices / coordinates are written to 'tmp.*' files.

    NOTE(review): this is Python 2 code (``print`` statements, ``commands`` module).
    NOTE(review): a number of statements are missing from this copy (shell commands cut in half, empty
    loop bodies, unclosed calls); the function does not parse as written. The spots are flagged below.

    :param input_image: file name of the anatomical image
    :param point_file: file name of the image holding the starting point
    :param gap: distance, in slices, between two registered slices
    :param gaussian_kernel: used to derive the radius (and sigma) of the 2D gaussian mask
    :param remove_tmp_files: when equal to 1, the temporary folder is removed at the end
    :return: no return statement is visible

    Besides its arguments, the function reads a ``param`` object that is not one of its parameters
    (``param.debug``, ``param.schedule_file``, ``param.deg_poly``).
    """

    # Initialization
    fname_anat = input_image
    fname_point = point_file
    slice_gap = gap
    # the next two self-assignments have no effect
    remove_tmp_files = remove_tmp_files
    gaussian_kernel = gaussian_kernel
    start_time = time()
    verbose = 1

    # get path of the toolbox
    status, path_sct = commands.getstatusoutput("echo $SCT_DIR")
    path_sct = sct.slash_at_the_end(path_sct, 1)

    # Parameters for debug mode
    # in debug mode the input file names and the slice gap are overridden with testing data
    if param.debug == 1:
        sct.printv("\n*** WARNING: DEBUG MODE ON ***\n\t\t\tCurrent working directory: " + os.getcwd(), "warning")
        status, path_sct_testing_data = commands.getstatusoutput("echo $SCT_TESTING_DATA_DIR")
        fname_anat = path_sct_testing_data + "/t2/t2.nii.gz"
        fname_point = path_sct_testing_data + "/t2/t2_centerline_init.nii.gz"
        slice_gap = 5

    # check existence of input files
    # NOTE(review): no check follows this comment in the visible code.

    # extract path/file/extension
    path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)
    path_point, file_point, ext_point = sct.extract_fname(fname_point)

    # extract path of schedule file
    # TODO: include schedule file in sct
    # TODO: check existence of schedule file
    file_schedule = path_sct + param.schedule_file

    # Get input image orientation
    input_image_orientation = get_orientation_3d(fname_anat, filename=True)

    # Display arguments
    print "\nCheck input arguments..."
    print "  Anatomical image:     " + fname_anat
    print "  Orientation:          " + input_image_orientation
    print "  Point in spinal cord: " + fname_point
    print "  Slice gap:            " + str(slice_gap)
    print "  Gaussian kernel:      " + str(gaussian_kernel)
    print "  Degree of polynomial: " + str(param.deg_poly)

    # create temporary folder
    print ("\nCreate temporary folder...")
    path_tmp = "tmp." + strftime("%y%m%d%H%M%S")
    # NOTE(review): the folder name is built but no statement creating the folder is visible.
    print "\nCopy input data..."
    sct.run("cp " + fname_anat + " " + path_tmp + "/tmp.anat" + ext_anat)
    sct.run("cp " + fname_point + " " + path_tmp + "/tmp.point" + ext_point)

    # go to temporary folder
    # NOTE(review): no directory change follows, although os.chdir("..") is called near the end.

    # convert to nii
    im_anat = convert("tmp.anat" + ext_anat, "tmp.anat.nii")
    im_point = convert("tmp.point" + ext_point, "tmp.point.nii")

    # Reorient input anatomical volume into RL PA IS orientation
    print "\nReorient input volume to RL PA IS orientation..."
    set_orientation(im_anat, "RPI")
    # Reorient binary point into RL PA IS orientation
    print "\nReorient binary point into RL PA IS orientation..."
    # sct.run(sct.fsloutput + 'fslswapdim tmp.point RL PA IS tmp.point_orient')
    set_orientation(im_point, "RPI")
    # NOTE(review): the return values of set_orientation are discarded here, while the code below opens
    # 'tmp.anat_orient.nii' / 'tmp.point_orient.nii'; nothing visible writes those files -- confirm.

    # Get image dimensions
    print "\nGet image dimensions..."
    nx, ny, nz, nt, px, py, pz, pt = Image("tmp.anat_orient.nii").dim
    print ".. matrix size: " + str(nx) + " x " + str(ny) + " x " + str(nz)
    print ".. voxel size:  " + str(px) + "mm x " + str(py) + "mm x " + str(pz) + "mm"

    # Split input volume
    print "\nSplit input volume..."
    im_anat_split_list = split_data(im_anat, 2)
    file_anat_split = []
    # NOTE(review): loop body missing; file_anat_split is indexed by slice below but never filled here.
    for im in im_anat_split_list:

    im_point_split_list = split_data(im_point, 2)
    file_point_split = []
    # NOTE(review): loop body missing; file_point_split is never filled in the visible code.
    for im in im_point_split_list:

    # Extract coordinates of input point
    # the point is taken as the voxel holding the maximum value of the image
    data_point = Image("tmp.point_orient.nii").data
    x_init, y_init, z_init = unravel_index(data_point.argmax(), data_point.shape)
    sct.printv("Coordinates of input point: (" + str(x_init) + ", " + str(y_init) + ", " + str(z_init) + ")", verbose)

    # Create 2D gaussian mask
    sct.printv("\nCreate gaussian mask from point...", verbose)
    xx, yy = mgrid[:nx, :ny]
    mask2d = zeros((nx, ny))
    radius = round(float(gaussian_kernel + 1) / 2)  # add 1 because the radius includes the center.
    sigma = float(radius)
    # isotropic 2D gaussian centred on (x_init, y_init); overwrites the zero array allocated above
    mask2d = exp(-(((xx - x_init) ** 2) / (2 * (sigma ** 2)) + ((yy - y_init) ** 2) / (2 * (sigma ** 2))))

    # Save mask to 2d file
    file_mask_split = ["tmp.mask_orient_Z" + str(z).zfill(4) for z in range(0, nz, 1)]
    # the first anatomical slice is loaded as a container and its data replaced by the mask
    nii_mask2d = Image("tmp.anat_orient_Z0000.nii")
    nii_mask2d.data = mask2d
    # NOTE(review): the image is named but no save call is visible.
    nii_mask2d.setFileName(file_mask_split[z_init] + ".nii")

    # initialize variables
    file_mat = ["tmp.mat_Z" + str(z).zfill(4) for z in range(0, nz, 1)]
    file_mat_inv = ["tmp.mat_inv_Z" + str(z).zfill(4) for z in range(0, nz, 1)]
    file_mat_inv_cumul = ["tmp.mat_inv_cumul_Z" + str(z).zfill(4) for z in range(0, nz, 1)]

    # create identity matrix for initial transformation matrix
    fid = open(file_mat_inv_cumul[z_init], "w")
    fid.write("%i %i %i %i\n" % (1, 0, 0, 0))
    fid.write("%i %i %i %i\n" % (0, 1, 0, 0))
    fid.write("%i %i %i %i\n" % (0, 0, 1, 0))
    fid.write("%i %i %i %i\n" % (0, 0, 0, 1))
    # NOTE(review): no fid.close() is visible after these writes.

    # initialize centerline: give value corresponding to initial point
    x_centerline = [x_init]
    y_centerline = [y_init]
    z_centerline = [z_init]
    warning_count = 0

    # go up (1), then down (2) in reference to the binary point
    for iUpDown in range(1, 3):

        if iUpDown == 1:
            # z increases
            slice_gap_signed = slice_gap
        elif iUpDown == 2:
            # z decreases
            slice_gap_signed = -slice_gap
            # reverse centerline (because values will be appended at the end)
            # NOTE(review): no reversing statement follows this comment in the visible code.

        # initialization before looping
        z_dest = z_init  # point given by user
        z_src = z_dest + slice_gap_signed

        # continue looping if 0 <= z < nz
        while 0 <= z_src < nz:

            # print current z:
            print "z=" + str(z_src) + ":"

            # estimate transformation
            # NOTE(review): the lines below are the middle of a string-built `flirt` command (slice z_src
            # registered to slice z_dest, output matrix file_mat[z_src]); the opening of the call and its
            # closing parenthesis are missing.
                + "flirt -in "
                + file_anat_split[z_src]
                + " -ref "
                + file_anat_split[z_dest]
                + " -schedule "
                + file_schedule
                + " -verbose 0 -omat "
                + file_mat[z_src]
                + " -cost normcorr -forcescaling -inweight "
                + file_mask_split[z_dest]
                + " -refweight "
                + file_mask_split[z_dest]

            # display transfo
            status, output = sct.run("cat " + file_mat[z_src])
            print output

            # check if transformation is bigger than 1.5x slice_gap
            # entries 3 and 7 of the whitespace-split matrix text are read as the x and y translations
            tx = float(output.split()[3])
            ty = float(output.split()[7])
            norm_txy = linalg.norm([tx, ty], ord=2)
            if norm_txy > 1.5 * slice_gap:
                print "WARNING: Transformation is too large --> using previous one."
                warning_count = warning_count + 1
                # if previous transformation exists, replace current one with previous one
                if os.path.isfile(file_mat[z_dest]):
                    sct.run("cp " + file_mat[z_dest] + " " + file_mat[z_src])

            # estimate inverse transformation matrix
            sct.run("convert_xfm -omat " + file_mat_inv[z_src] + " -inverse " + file_mat[z_src])

            # compute cumulative transformation
            # NOTE(review): fragment of a `convert_xfm -concat` command; the enclosing call is missing.
                "convert_xfm -omat "
                + file_mat_inv_cumul[z_src]
                + " -concat "
                + file_mat_inv[z_src]
                + " "
                + file_mat_inv_cumul[z_dest]

            # apply inverse cumulative transformation to initial gaussian mask (to put it in src space)
            # NOTE(review): fragment of a `flirt -applyxfm` command; the enclosing call is missing.
                + "flirt -in "
                + file_mask_split[z_init]
                + " -ref "
                + file_mask_split[z_init]
                + " -applyxfm -init "
                + file_mat_inv_cumul[z_src]
                + " -out "
                + file_mask_split[z_src]

            # open inverse cumulative transformation file and generate centerline
            fid = open(file_mat_inv_cumul[z_src])
            mat = fid.read().split()
            x_centerline.append(x_init + float(mat[3]))
            y_centerline.append(y_init + float(mat[7]))
            # NOTE(review): only x and y are appended here; no append to z_centerline is visible, although
            # the three lists are fitted against each other below -- confirm.
            # z_index = z_index+1

            # define new z_dest (target slice) and new z_src (moving slice)
            z_dest = z_dest + slice_gap_signed
            z_src = z_src + slice_gap_signed

    # Reconstruct centerline
    # ====================================================================================================

    # reverse back centerline (because it's been reversed once, so now all values are in the right order)
    # NOTE(review): no reversing statement follows this comment in the visible code.

    # fit centerline in the Z-X plane using polynomial function
    print "\nFit centerline in the Z-X plane using polynomial function..."
    coeffsx = polyfit(z_centerline, x_centerline, deg=param.deg_poly)
    polyx = poly1d(coeffsx)
    x_centerline_fit = polyval(polyx, z_centerline)
    # calculate RMSE
    rmse = linalg.norm(x_centerline_fit - x_centerline) / sqrt(len(x_centerline))
    # calculate max absolute error
    max_abs = max(abs(x_centerline_fit - x_centerline))
    # errors are computed in voxels and scaled by the voxel size px for display
    print ".. RMSE (in mm): " + str(rmse * px)
    print ".. Maximum absolute error (in mm): " + str(max_abs * px)

    # fit centerline in the Z-Y plane using polynomial function
    print "\nFit centerline in the Z-Y plane using polynomial function..."
    coeffsy = polyfit(z_centerline, y_centerline, deg=param.deg_poly)
    polyy = poly1d(coeffsy)
    y_centerline_fit = polyval(polyy, z_centerline)
    # calculate RMSE
    rmse = linalg.norm(y_centerline_fit - y_centerline) / sqrt(len(y_centerline))
    # calculate max absolute error
    max_abs = max(abs(y_centerline_fit - y_centerline))
    print ".. RMSE (in mm): " + str(rmse * py)
    print ".. Maximum absolute error (in mm): " + str(max_abs * py)

    # display
    if param.debug == 1:
        import matplotlib.pyplot as plt

        plt.plot(z_centerline, x_centerline, ".", z_centerline, x_centerline_fit, "r")
        plt.legend(["Data", "Polynomial Fit"])
        plt.title("Z-X plane polynomial interpolation")

        plt.plot(z_centerline, y_centerline, ".", z_centerline, y_centerline_fit, "r")
        plt.legend(["Data", "Polynomial Fit"])
        plt.title("Z-Y plane polynomial interpolation")

    # generate full range z-values for centerline
    z_centerline_full = [iz for iz in range(0, nz, 1)]

    # calculate X and Y values for the full centerline
    x_centerline_fit_full = polyval(polyx, z_centerline_full)
    y_centerline_fit_full = polyval(polyy, z_centerline_full)

    # Generate fitted transformation matrices and write centerline coordinates in text file
    print "\nGenerate fitted transformation matrices and write centerline coordinates in text file..."
    file_mat_inv_cumul_fit = ["tmp.mat_inv_cumul_fit_z" + str(z).zfill(4) for z in range(0, nz, 1)]
    file_mat_cumul_fit = ["tmp.mat_cumul_fit_z" + str(z).zfill(4) for z in range(0, nz, 1)]
    fid_centerline = open("tmp.centerline_coordinates.txt", "w")
    for iz in range(0, nz, 1):
        # compute inverse cumulative fitted transformation matrix
        # pure in-plane translation from the initial point to the fitted centerline position of slice iz
        fid = open(file_mat_inv_cumul_fit[iz], "w")
        fid.write("%i %i %i %f\n" % (1, 0, 0, x_centerline_fit_full[iz] - x_init))
        fid.write("%i %i %i %f\n" % (0, 1, 0, y_centerline_fit_full[iz] - y_init))
        fid.write("%i %i %i %i\n" % (0, 0, 1, 0))
        fid.write("%i %i %i %i\n" % (0, 0, 0, 1))
        # compute forward cumulative fitted transformation matrix
        sct.run("convert_xfm -omat " + file_mat_cumul_fit[iz] + " -inverse " + file_mat_inv_cumul_fit[iz])
        # write centerline coordinates in x, y, z format
        # NOTE(review): only the formatted string remains; the write call on fid_centerline is missing.
            "%f %f %f\n" % (x_centerline_fit_full[iz], y_centerline_fit_full[iz], z_centerline_full[iz])

    # Prepare output data
    # ====================================================================================================

    # write centerline as text file
    # NOTE(review): despite the comment above, this loop rewrites the same matrix files as the previous
    # loop; it does not write centerline coordinates.
    for iz in range(0, nz, 1):
        # compute inverse cumulative fitted transformation matrix
        fid = open(file_mat_inv_cumul_fit[iz], "w")
        fid.write("%i %i %i %f\n" % (1, 0, 0, x_centerline_fit_full[iz] - x_init))
        fid.write("%i %i %i %f\n" % (0, 1, 0, y_centerline_fit_full[iz] - y_init))
        fid.write("%i %i %i %i\n" % (0, 0, 1, 0))
        fid.write("%i %i %i %i\n" % (0, 0, 0, 1))

    # write polynomial coefficients
    savetxt("tmp.centerline_polycoeffs_x.txt", coeffsx)
    savetxt("tmp.centerline_polycoeffs_y.txt", coeffsy)

    # apply transformations to data
    print "\nApply fitted transformation matrices..."
    file_anat_split_fit = ["tmp.anat_orient_fit_z" + str(z).zfill(4) for z in range(0, nz, 1)]
    file_mask_split_fit = ["tmp.mask_orient_fit_z" + str(z).zfill(4) for z in range(0, nz, 1)]
    file_point_split_fit = ["tmp.point_orient_fit_z" + str(z).zfill(4) for z in range(0, nz, 1)]
    for iz in range(0, nz, 1):
        # NOTE(review): the three groups of lines below are fragments of `flirt -applyxfm` commands; the
        # enclosing calls are missing.
        # forward cumulative transformation to data
            + "flirt -in "
            + file_anat_split[iz]
            + " -ref "
            + file_anat_split[iz]
            + " -applyxfm -init "
            + file_mat_cumul_fit[iz]
            + " -out "
            + file_anat_split_fit[iz]
        # inverse cumulative transformation to mask
            + "flirt -in "
            + file_mask_split[z_init]
            + " -ref "
            + file_mask_split[z_init]
            + " -applyxfm -init "
            + file_mat_inv_cumul_fit[iz]
            + " -out "
            + file_mask_split_fit[iz]
        # inverse cumulative transformation to point
            + "flirt -in "
            + file_point_split[z_init]
            + " -ref "
            + file_point_split[z_init]
            + " -applyxfm -init "
            + file_mat_inv_cumul_fit[iz]
            + " -out "
            + file_point_split_fit[iz]
            + " -interp nearestneighbour"

    # Merge into 4D volume
    # NOTE(review): the concatenation below is along dimension 2; the three results are not named or
    # saved in the visible code, while files 'tmp.*_orient_fit.nii' are opened just after.
    print "\nMerge into 4D volume..."
    # im_anat_list = [Image(fname) for fname in glob.glob('tmp.anat_orient_fit_z*.nii')]
    fname_anat_list = glob.glob("tmp.anat_orient_fit_z*.nii")
    im_anat_concat = concat_data(fname_anat_list, 2)

    # im_mask_list = [Image(fname) for fname in glob.glob('tmp.mask_orient_fit_z*.nii')]
    fname_mask_list = glob.glob("tmp.mask_orient_fit_z*.nii")
    im_mask_concat = concat_data(fname_mask_list, 2)

    # im_point_list = [Image(fname) for fname in glob.glob('tmp.point_orient_fit_z*.nii')]
    fname_point_list = glob.glob("tmp.point_orient_fit_z*.nii")
    im_point_concat = concat_data(fname_point_list, 2)

    # Copy header geometry from input data
    print "\nCopy header geometry from input data..."
    im_anat = Image("tmp.anat_orient.nii")
    im_anat_orient_fit = Image("tmp.anat_orient_fit.nii")
    im_mask_orient_fit = Image("tmp.mask_orient_fit.nii")
    im_point_orient_fit = Image("tmp.point_orient_fit.nii")
    im_anat_orient_fit = copy_header(im_anat, im_anat_orient_fit)
    im_mask_orient_fit = copy_header(im_anat, im_mask_orient_fit)
    im_point_orient_fit = copy_header(im_anat, im_point_orient_fit)
    # NOTE(review): loop body missing (presumably the three images were saved) -- confirm.
    for im in [im_anat_orient_fit, im_mask_orient_fit, im_point_orient_fit]:

    # Reorient outputs into the initial orientation of the input image
    print "\nReorient the centerline into the initial orientation of the input image..."
    set_orientation("tmp.point_orient_fit.nii", input_image_orientation, "tmp.point_orient_fit.nii")
    set_orientation("tmp.mask_orient_fit.nii", input_image_orientation, "tmp.mask_orient_fit.nii")

    # Generate output file (in current folder)
    print "\nGenerate output file (in current folder)..."
    os.chdir("..")  # come back to parent folder
    # NOTE(review): the closing parenthesis of this call is missing.
    fname_output_centerline = sct.generate_output_file(
        path_tmp + "/tmp.point_orient_fit.nii", file_anat + "_centerline" + ext_anat

    # Delete temporary files
    if remove_tmp_files == 1:
        print "\nRemove temporary files..."
        sct.run("rm -rf " + path_tmp, error_exit="warning")

    # print number of warnings
    # NOTE(review): the argument of str() is missing (warning_count is the counter maintained above).
    print "\nNumber of warnings: " + str(
    ) + " (if >10, you should probably reduce the gap and/or increase the kernel size"

    # display elapsed time
    elapsed_time = time() - start_time
    # NOTE(review): the argument of str() is missing (elapsed_time is computed just above).
    print "\nFinished! \n\tGenerated file: " + fname_output_centerline + "\n\tElapsed time: " + str(
    ) + "s\n"
# Code example #20
def dmri_moco(param):
    """Group-wise motion-correction driver for the diffusion series 'dmri.nii' (with 'bvecs.txt').

    Outline of the visible code: the b=0 and DWI volumes are identified from 'bvecs.txt', the b=0 volumes
    are merged and averaged, the DWI volumes are gathered into groups of ``param.group_size`` and averaged,
    motion estimation is configured separately for the b=0 volumes and for the DWI group means, the
    resulting transformations are copied to 'mat_final/' under the index of every original volume, an
    'apply' pass is configured on the full series, the header is copied onto the corrected file and
    ``sct_dmri_separate_b0_and_dwi`` is run on it.

    NOTE(review): many statements are cut in half or missing in this copy (unclosed calls, empty loop
    bodies, moco configurations that are never run); the function does not parse as written. The spots
    are flagged below.

    :param param: parameter object; ``verbose``, ``group_size``, ``otsu``, ``spline_fitting``,
        ``run_eddy``, ``suffix`` and ``fname_bvals`` are read, and ``file_data``, ``file_target``,
        ``path_out``, ``todo``, ``mat_moco``, ``mat_2_combine`` and ``mat_final`` are overwritten
        (``param_moco`` is an alias of ``param``, not a copy).
    :return: None (no return statement)
    """

    file_data = 'dmri'
    ext_data = '.nii'
    file_b0 = 'b0'
    file_dwi = 'dwi'
    mat_final = 'mat_final/'
    file_dwi_group = 'dwi_averaged_groups'  # no extension
    ext_mat = 'Warp.nii.gz'  # warping field

    # Get dimensions of data
    sct.printv('\nGet dimensions of data...', param.verbose)
    im_data = Image(file_data + ext_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    # NOTE(review): the end of this call is missing.
    sct.printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz),

    # Identify b=0 and DWI images
    sct.printv('\nIdentify b=0 and DWI images...', param.verbose)
    # NOTE(review): the remaining arguments of identify_b0 are missing.
    index_b0, index_dwi, nb_b0, nb_dwi = identify_b0('bvecs.txt',

    # check if dmri and bvecs are the same size
    if not nb_b0 + nb_dwi == nt:
        # NOTE(review): the opening of the call that reports this error is missing.
            '\nERROR in ' + os.path.basename(__file__) + ': Size of data (' +
            str(nt) + ') and size of bvecs (' + str(nb_b0 + nb_dwi) +
            ') are not the same. Check your bvecs file.\n', 1, 'error')

    # Prepare NIFTI (mean/groups...)
    # Split into T dimension
    sct.printv('\nSplit along T dimension...', param.verbose)
    im_data_split_list = split_data(im_data, 3)
    # NOTE(review): loop body missing (presumably each split volume was saved) -- confirm.
    for im in im_data_split_list:

    # Merge b=0 images
    sct.printv('\nMerge b=0...', param.verbose)
    im_b0_list = []
    # NOTE(review): loop body missing; im_b0_list stays empty in the visible code.
    for it in range(nb_b0):
    im_b0_out = concat_data(im_b0_list, 3)
    # NOTE(review): the image is named but no save call is visible.
    im_b0_out.setFileName(file_b0 + ext_data)
    sct.printv(('  File created: ' + file_b0), param.verbose)

    # Average b=0 images
    sct.printv('\nAverage b=0...', param.verbose)
    file_b0_mean = file_b0 + '_mean'
    # NOTE(review): the opening of the call running this sct_maths command is missing.
        'sct_maths -i ' + file_b0 + ext_data + ' -o ' + file_b0_mean +
        ext_data + ' -mean t', param.verbose)

    # Number of DWI groups
    nb_groups = int(math.floor(nb_dwi / param.group_size))

    # Generate groups indexes
    group_indexes = []
    for iGroup in range(nb_groups):
        # NOTE(review): the end of this slice expression is missing.
        group_indexes.append(index_dwi[(iGroup *
                                        param.group_size):((iGroup + 1) *

    # add the remaining images to the last DWI group
    nb_remaining = nb_dwi % param.group_size  # number of remaining images
    if nb_remaining > 0:
        nb_groups += 1
        # NOTE(review): the end of this slice expression is missing.
        group_indexes.append(index_dwi[len(index_dwi) -

    # DWI groups
    file_dwi_mean = []
    for iGroup in range(nb_groups):
        # NOTE(review): the end of this call is missing.
        sct.printv('\nDWI group: ' + str((iGroup + 1)) + '/' + str(nb_groups),

        # get index
        index_dwi_i = group_indexes[iGroup]
        nb_dwi_i = len(index_dwi_i)

        # Merge DW Images
        sct.printv('Merge DW images...', param.verbose)
        file_dwi_merge_i = file_dwi + '_' + str(iGroup)

        im_dwi_list = []
        # NOTE(review): loop body missing; im_dwi_list stays empty in the visible code.
        for it in range(nb_dwi_i):
        im_dwi_out = concat_data(im_dwi_list, 3)
        # NOTE(review): the image is named but no save call is visible.
        im_dwi_out.setFileName(file_dwi_merge_i + ext_data)

        # Average DW Images
        sct.printv('Average DW images...', param.verbose)
        file_dwi_mean.append(file_dwi + '_mean_' + str(iGroup))
        # NOTE(review): the opening of the call running this sct_maths command is missing.
            'sct_maths -i ' + file_dwi_merge_i + ext_data + ' -o ' +
            file_dwi_mean[iGroup] + ext_data + ' -mean t', param.verbose)

    # Merge DWI groups means
    sct.printv('\nMerging DW files...', param.verbose)
    # file_dwi_groups_means_merge = 'dwi_averaged_groups'
    # here concat_data receives file names rather than Image objects
    im_dw_list = []
    for iGroup in range(nb_groups):
        im_dw_list.append(file_dwi_mean[iGroup] + ext_data)
    im_dw_out = concat_data(im_dw_list, 3)
    # NOTE(review): the image is named but no save call is visible.
    im_dw_out.setFileName(file_dwi_group + ext_data)

    # Average DW Images
    # TODO: USEFULL ???
    sct.printv('\nAveraging all DW images...', param.verbose)
    # NOTE(review): the opening of the call running this sct_maths command is missing.
        'sct_maths -i ' + file_dwi_group + ext_data + ' -o ' + file_dwi_group +
        '_mean' + ext_data + ' -mean t', param.verbose)

    # segment dwi images using otsu algorithm
    if param.otsu:
        # NOTE(review): the end of this call is missing.
        sct.printv('\nSegment group DWI using OTSU algorithm...',
        # import module
        otsu = importlib.import_module('sct_otsu')
        # get class from module
        param_otsu = otsu.param()  #getattr(otsu, param)
        param_otsu.fname_data = file_dwi_group + ext_data
        param_otsu.threshold = param.otsu
        param_otsu.file_suffix = '_seg'
        # run otsu
        # NOTE(review): no call running otsu is visible; only the file name is switched to the '_seg' one.
        file_dwi_group = file_dwi_group + '_seg'


    # Estimate moco on b0 groups
    sct.printv('  Estimating motion on b=0 images...', param.verbose)
    # param_moco is an alias of param (not a copy): the assignments below modify param as well
    param_moco = param
    param_moco.file_data = 'b0'
    # identify target image
    if index_dwi[0] != 0:
        # If first DWI is not the first volume (most common), then there is a least one b=0 image before. In that case
        # select it as the target image for registration of all b=0
        param_moco.file_target = file_data + '_T' + str(
            index_b0[index_dwi[0] - 1]).zfill(4)
        # NOTE(review): an `else:` appears to be missing here; as written the next assignment would
        # always override the one above.
        # If first DWI is the first volume, then the target b=0 is the first b=0 from the index_b0.
        param_moco.file_target = file_data + '_T' + str(index_b0[0]).zfill(4)
    param_moco.path_out = ''
    param_moco.todo = 'estimate'
    param_moco.mat_moco = 'mat_b0groups'
    # NOTE(review): the estimation is configured but no call that runs it is visible here.

    # Estimate moco on dwi groups
    sct.printv('  Estimating motion on DW images...', param.verbose)
    param_moco.file_data = file_dwi_group
    param_moco.file_target = file_dwi_mean[
        0]  # target is the first DW image (closest to the first b=0)
    param_moco.path_out = ''
    param_moco.todo = 'estimate_and_apply'
    param_moco.mat_moco = 'mat_dwigroups'
    # NOTE(review): as above, no call running this configuration is visible.

    # create final mat folder
    # NOTE(review): no statement follows this comment; the folder creation seems to be missing.

    # Copy b=0 registration matrices
    sct.printv('\nCopy b=0 registration matrices...', param.verbose)

    for it in range(nb_b0):
        # NOTE(review): the opening of the call running this `cp` command is missing.
        # the matrix of the it-th b=0 is copied under the index of that volume in the full series
            'cp ' + 'mat_b0groups/' + 'mat.T' + str(it) + ext_mat + ' ' +
            mat_final + 'mat.T' + str(index_b0[it]) + ext_mat, param.verbose)

    # Copy DWI registration matrices
    sct.printv('\nCopy DWI registration matrices...', param.verbose)
    for iGroup in range(nb_groups):
        for dwi in range(len(group_indexes[iGroup])):
            # NOTE(review): the opening of the call running this `cp` command is missing.
            # the single matrix of group iGroup is duplicated for every volume index of that group
                'cp ' + 'mat_dwigroups/' + 'mat.T' + str(iGroup) + ext_mat +
                ' ' + mat_final + 'mat.T' + str(group_indexes[iGroup][dwi]) +
                ext_mat, param.verbose)

    # Spline Regularization along T
    if param.spline_fitting:
        # NOTE(review): the end of this call is missing.
        moco.spline(mat_final, nt, nz, param.verbose, np.array(index_b0),

    # combine Eddy Matrices
    if param.run_eddy:
        param.mat_2_combine = 'mat_eddy'
        param.mat_final = mat_final
        # NOTE(review): only the two attributes are set; no combining call is visible.

    # Apply moco on all dmri data
    sct.printv('  Apply moco', param.verbose)
    param_moco.file_data = file_data
    param_moco.file_target = file_dwi + '_mean_' + str(
        0)  # reference for reslicing into proper coordinate system
    param_moco.path_out = ''
    param_moco.mat_moco = mat_final
    param_moco.todo = 'apply'
    # NOTE(review): the 'apply' pass is configured but never run in the visible code.

    # copy geometric information from header
    # NB: this is required because WarpImageMultiTransform in 2D mode wrongly sets pixdim(3) to "1".
    im_dmri = Image(file_data + ext_data)
    im_dmri_moco = Image(file_data + param.suffix + ext_data)
    # NOTE(review): the result of copy_header is not saved in the visible code -- confirm.
    im_dmri_moco = copy_header(im_dmri, im_dmri_moco)

    # generate b0_moco_mean and dwi_moco_mean
    cmd = 'sct_dmri_separate_b0_and_dwi -i ' + file_data + param.suffix + ext_data + ' -bvec bvecs.txt -a 1'
    # the bvals file is passed on only when one was provided
    if not param.fname_bvals == '':
        cmd = cmd + ' -m ' + param.fname_bvals
    sct.run(cmd, param.verbose)
# File: moco.py  Project: cinnnnn/spinalcordtoolbox
def moco(param):
    # Motion-correction driver: for every time point of param.file_data it
    # either estimates a transformation against param.file_target or applies an
    # existing one (depending on param.todo), slice by slice when
    # param.is_sagittal is set, and returns the tuple (file_mat, im_out).
    # NOTE(review): this listing is truncated in many places (orphaned argument
    # lines, loops without bodies, missing try/else headers). Each gap is
    # flagged below; restore the missing lines from the upstream moco.py before
    # relying on this function.
    Main function that performs motion correction.

    :param param:
    # retrieve parameters
    file_data = param.file_data
    file_target = param.file_target
    folder_mat = param.mat_moco  # output folder of mat file
    todo = param.todo
    suffix = param.suffix
    verbose = param.verbose

    # other parameters
    file_mask = 'mask.nii'

    # The fields below are concatenated to str directly, so they must already
    # be strings (a numeric value would raise TypeError here).
    printv('\nInput parameters:', param.verbose)
    printv('  Input file ............ ' + file_data, param.verbose)
    printv('  Reference file ........ ' + file_target, param.verbose)
    printv('  Polynomial degree ..... ' + param.poly, param.verbose)
    printv('  Smoothing kernel ...... ' + param.smooth, param.verbose)
    printv('  Gradient step ......... ' + param.gradStep, param.verbose)
    printv('  Metric ................ ' + param.metric, param.verbose)
    printv('  Sampling .............. ' + param.sampling, param.verbose)
    printv('  Todo .................. ' + todo, param.verbose)
    printv('  Mask  ................. ' + param.fname_mask, param.verbose)
    printv('  Output mat folder ..... ' + folder_mat, param.verbose)

    # NOTE(review): the `try:` block this handler belongs to (presumably the
    # creation of folder_mat) and the handler body are missing from this listing.
    except FileExistsError:

    # Get size of data
    printv('\nData dimensions:', verbose)
    im_data = Image(param.file_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    # NOTE(review): orphaned argument; the enclosing printv( call is missing.
        ('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt)),

    # copy file_target to a temporary file
    # From here on the local file_target names the working copy, while
    # param.file_target still names the user-supplied reference.
    printv('\nCopy file_target to a temporary file...', verbose)
    file_target = "target.nii.gz"
    convert(param.file_target, file_target, verbose=0)

    # Check if user specified a mask
    if not param.fname_mask == '':
        # Check if this mask is soft (i.e., non-binary, such as a Gaussian mask)
        im_mask = Image(param.fname_mask)
        # A mask is treated as soft when its data differ from their boolean cast.
        if not np.array_equal(im_mask.data, im_mask.data.astype(bool)):
            # If it is a soft mask, multiply the target by the soft mask.
            im = Image(file_target)
            im_masked = im.copy()
            im_masked.data = im.data * im_mask.data
            # NOTE(review): orphaned argument; the call that writes im_masked
            # back to disk is missing from this listing.
                verbose=0)  # silence warning about file overwritting

    # If scan is sagittal, split src and target along Z (slice)
    if param.is_sagittal:
        dim_sag = 2  # TODO: find it
        # z-split data (time series)
        im_z_list = split_data(im_data, dim=dim_sag, squeeze_data=False)
        file_data_splitZ = []
        # NOTE(review): the bodies of the three `for` loops in this branch and
        # the closing arguments of the two split_data( calls are missing.
        for im_z in im_z_list:
        # z-split target
        im_targetz_list = split_data(Image(file_target),
        file_target_splitZ = []
        for im_targetz in im_targetz_list:
        # z-split mask (if exists)
        if not param.fname_mask == '':
            im_maskz_list = split_data(Image(file_mask),
            file_mask_splitZ = []
            for im_maskz in im_maskz_list:
        # initialize file list for output matrices
        # One transformation name per (slice, time point).
        file_mat = np.empty((nz, nt), dtype=object)

    # axial orientation
    # NOTE(review): the `else:` header for this branch is missing.
        file_data_splitZ = [file_data]  # TODO: make it absolute like above
        file_target_splitZ = [file_target]  # TODO: make it absolute like above
        # initialize file list for output matrices
        # A single row: the whole volume is handled as one "Z" entry.
        file_mat = np.empty((1, nt), dtype=object)

        # deal with mask
        if not param.fname_mask == '':
            convert(param.fname_mask, file_mask, squeeze_data=False, verbose=0)
            im_maskz_list = [Image(file_mask)
                             ]  # use a list with single element

    # Loop across file list, where each file is either a 2D volume (if sagittal) or a 3D volume (otherwise)
    # file_mat = tuple([[[] for i in range(nt)] for i in range(nz)])

    file_data_splitZ_moco = []
    # NOTE(review): orphaned string; the enclosing printv( call is missing.
        '\nRegister. Loop across Z (note: there is only one Z if orientation is axial)'
    for file in file_data_splitZ:
        # list.index returns the first match, so iz is only correct while the
        # entries of file_data_splitZ are unique.
        iz = file_data_splitZ.index(file)
        # Split data along T dimension
        # printv('\nSplit data along T dimension.', verbose)
        im_z = Image(file)
        list_im_zt = split_data(im_z, dim=3)
        file_data_splitZ_splitT = []
        # NOTE(review): loop body missing (file_data_splitZ_splitT is never
        # filled in the visible code although it is indexed below).
        for im_zt in list_im_zt:
        # file_data_splitT = file_data + '_T'

        # Motion correction: initialization
        index = np.arange(nt)
        file_data_splitT_num = []
        file_data_splitZ_splitT_moco = []
        # 0 = transformation OK, 1 = failed (see the fT/gT split further down).
        failed_transfo = [0 for i in range(nt)]

        # Motion correction: Loop across T
        # NOTE(review): the sct_progress_bar( call is not closed in this listing.
        for indice_index in sct_progress_bar(range(nt),
                                             desc="Z=" + str(iz) + "/" +
                                             str(len(file_data_splitZ) - 1),

            # create indices and display stuff
            it = index[indice_index]
            # NOTE(review): the os.path.join( arguments look truncated (no folder
            # component is visible), and the next line is an orphaned argument
            # of a missing call that builds the '_moco' output file name.
            file_mat[iz][it] = os.path.join(
                "mat.Z") + str(iz).zfill(4) + 'T' + str(it).zfill(4)
                add_suffix(file_data_splitZ_splitT[it], '_moco'))
            # deal with masking (except in the 'apply' case, where masking is irrelevant)
            input_mask = None
            if not param.fname_mask == '' and not param.todo == 'apply':
                # Check if mask is binary
                # NOTE(review): the condition is truncated and the `else:`
                # header separating the two cases below is missing.
                if np.array_equal(im_maskz_list[iz].data,
                    # If it is, pass this mask into register() to be used
                    input_mask = im_maskz_list[iz]
                    # If not, do not pass this mask into register() because ANTs cannot handle non-binary masks.
                    #  Instead, multiply the input data by the Gaussian mask.
                    im = Image(file_data_splitZ_splitT[it])
                    im_masked = im.copy()
                    im_masked.data = im.data * im_maskz_list[iz].data
                    # NOTE(review): orphaned argument of a missing save call.
                        verbose=0)  # silence warning about file overwritting

            # run 3D registration
            # NOTE(review): the argument list of register( is missing.
            failed_transfo[it] = register(param,

            # average registered volume with target image
            # N.B. use weighted averaging: (target * nb_it + moco) / (nb_it + 1)
            # Only the first 10 time points contribute, and only when the
            # registration succeeded and we are not merely applying transforms.
            if param.iterAvg and indice_index < 10 and failed_transfo[
                    it] == 0 and not param.todo == 'apply':
                im_targetz = Image(file_target_splitZ[iz])
                data_targetz = im_targetz.data
                data_mocoz = Image(file_data_splitZ_splitT_moco[it]).data
                data_targetz = (data_targetz * (indice_index + 1) +
                                data_mocoz) / (indice_index + 2)
                im_targetz.data = data_targetz
                # NOTE(review): no save of im_targetz is visible after this update.

        # Replace failed transformation with the closest good one
        fT = [i for i, j in enumerate(failed_transfo) if j == 1]
        gT = [i for i, j in enumerate(failed_transfo) if j == 0]
        for it in range(len(fT)):
            # Distance (in time index) from this failed point to every good one.
            abs_dist = [np.abs(gT[i] - fT[it]) for i in range(len(gT))]
            if not abs_dist == []:
                index_good = abs_dist.index(min(abs_dist))
                # NOTE(review): orphaned arguments; the printv( header is missing.
                    '  transfo #' + str(fT[it]) + ' --> use transfo #' +
                    str(gT[index_good]), verbose)
                # copy transformation
                copy(file_mat[iz][gT[index_good]] + 'Warp.nii.gz',
                     file_mat[iz][fT[it]] + 'Warp.nii.gz')
                # apply transformation
                # NOTE(review): orphaned arguments; the call that applies the
                # copied warp and the following `else:` header are missing.
                    '-i', file_data_splitZ_splitT[fT[it]], '-d', file_target,
                    '-w', file_mat[iz][fT[it]] + 'Warp.nii.gz', '-o',
                    file_data_splitZ_splitT_moco[fT[it]], '-x', param.interp
                # exit program if no transformation exists.
                    '\nERROR in ' + os.path.basename(__file__) +
                    ': No good transformation exist. Exit program.\n', verbose,

        # Merge data along T
        file_data_splitZ_moco.append(add_suffix(file, suffix))
        if todo != 'estimate':
            im_out = concat_data(file_data_splitZ_splitT_moco, 3)
            im_out.absolutepath = file_data_splitZ_moco[iz]
            # NOTE(review): no save of im_out is visible here.

    # If sagittal, merge along Z
    if param.is_sagittal:
        # TODO: im_out.dim is incorrect: Z value is one
        im_out = concat_data(file_data_splitZ_moco, 2)
        dirname, basename, ext = extract_fname(file_data)
        path_out = os.path.join(dirname, basename + suffix + ext)
        im_out.absolutepath = path_out

    # NOTE(review): in the visible code im_out is only bound when
    # todo != 'estimate' or param.is_sagittal is true; otherwise this return
    # would fail on an unbound name unless a missing line assigns it.
    return file_mat, im_out
def create_mask():
    # Build a mask volume around a centerline and write it to param.fname_out.
    # The centerline comes from a file ('centerline'), a coordinate ('coord'),
    # a label file ('point') or the centre of the field of view ('center'),
    # as selected by param.process[0].
    # `param` is not an argument: it is read from the enclosing/module scope.
    # NOTE(review): this listing is truncated (missing chdir calls, a missing
    # branch header before create_line, an empty loop); gaps are flagged below.
    fsloutput = 'export FSLOUTPUTTYPE=NIFTI; '  # for faster processing, all outputs are in NIFTI

    # parse argument for method
    method_type = param.process[0]
    # check method val
    # method_val stays undefined for 'center', which needs no extra value.
    if not method_type == 'center':
        method_val = param.process[1]

    # check existence of input files
    if method_type == 'centerline':
        sct.check_file_exist(method_val, param.verbose)

    # check if orientation is RPI
    sct.printv('\nCheck if orientation is RPI...', param.verbose)
    ori = get_orientation(param.fname_data, filename=True)
    if not ori == 'RPI':
        sct.printv('\nERROR in '+os.path.basename(__file__)+': Orientation of input image should be RPI. Use sct_image -setorient to put your image in RPI.\n', 1, 'error')

    # Extract path/file/extension
    path_data, file_data, ext_data = sct.extract_fname(param.fname_data)

    # Get output folder and file name
    # Default output name: prefix + input file name, in the current directory.
    if param.fname_out == '':
        param.fname_out = param.file_prefix+file_data+ext_data
    #fname_out = os.path.abspath(path_out+file_out+ext_out)

    # create temporary folder
    sct.printv('\nCreate temporary folder...', param.verbose)
    path_tmp = sct.slash_at_the_end('tmp.'+time.strftime("%y%m%d%H%M%S"), 1)
    sct.run('mkdir '+path_tmp, param.verbose)

    # Copying input data to tmp folder and convert to nii
    sct.printv('\nCopying input data to tmp folder and convert to nii...', param.verbose)
    convert(param.fname_data, path_tmp+'data.nii')
    # sct.run('cp '+param.fname_data+' '+path_tmp+'data'+ext_data, param.verbose)
    if method_type == 'centerline':
        convert(method_val, path_tmp+'centerline.nii.gz')

    # go to tmp folder
    # NOTE(review): the directory change is missing from this listing; every
    # relative path below ('data.nii', 'mask.nii.gz', ...) assumes the working
    # directory is path_tmp.

    # Get dimensions of data
    sct.printv('\nGet dimensions of data...', param.verbose)
    nx, ny, nz, nt, px, py, pz, pt = Image('data.nii').dim
    sct.printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz)+ ' x ' + str(nt), param.verbose)
    # in case user input 4d data
    if nt != 1:
        sct.printv('WARNING in '+os.path.basename(__file__)+': Input image is 4d but output mask will 3D.', param.verbose, 'warning')
        # extract first volume to have 3d reference
        nii = Image('data.nii')
        data3d = nii.data[:,:,:,0]
        nii.data = data3d
        # NOTE(review): no save of nii is visible, so as listed the 3D volume
        # is never written back to 'data.nii'.

    if method_type == 'coord':
        # parse to get coordinate
        # method_val is expected in the form "XxY"; under Python 3 map() yields
        # a one-shot iterator rather than a list.
        coord = map(int, method_val.split('x'))

    if method_type == 'point':
        # get file name
        fname_point = method_val
        # extract coordinate of point
        sct.printv('\nExtract coordinate of point...', param.verbose)
        status, output = sct.run('sct_label_utils -i '+fname_point+' -p display-voxel', param.verbose)
        # parse to get coordinate
        # Fixed-offset slicing: depends on the exact text printed by
        # sct_label_utils around 'Position='; the items stay strings.
        coord = output[output.find('Position=')+10:-17].split(',')

    if method_type == 'center':
        # set coordinate at center of FOV
        coord = round(float(nx)/2), round(float(ny)/2)

    if method_type == 'centerline':
        # get name of centerline from user argument
        fname_centerline = 'centerline.nii.gz'
        # NOTE(review): a branch header (presumably `else:`) is missing here;
        # as listed, create_line would run for 'centerline', where coord is
        # never defined.
        # generate volume with line along Z at coordinates 'coord'
        sct.printv('\nCreate line...', param.verbose)
        fname_centerline = create_line('data.nii', coord, nz)

    # create mask
    sct.printv('\nCreate mask...', param.verbose)
    centerline = nibabel.load(fname_centerline)  # open centerline
    hdr = centerline.get_header()  # get header
    hdr.set_data_dtype('uint8')  # set imagetype to uint8
    data_centerline = centerline.get_data()  # get centerline
    # Indices of the slices that contain at least one non-zero centerline voxel.
    z_centerline = [iz for iz in range(0, nz, 1) if data_centerline[:, :, iz].any()]
    # nz is rebound: from here on it counts only slices that have a centerline.
    nz = len(z_centerline)
    # get center of mass of the centerline
    cx = [0] * nz
    cy = [0] * nz
    for iz in range(0, nz, 1):
        cx[iz], cy[iz] = ndimage.measurements.center_of_mass(numpy.array(data_centerline[:, :, z_centerline[iz]]))
    # create 2d masks
    file_mask = 'data_mask'
    for iz in range(nz):
        center = numpy.array([cx[iz], cy[iz]])
        mask2d = create_mask2d(center, param.shape, param.size, nx, ny, param.even)
        # Write NIFTI volumes
        # One file per slice: data_mask<iz>.nii, reusing the centerline header.
        img = nibabel.Nifti1Image(mask2d, None, hdr)
        nibabel.save(img, (file_mask+str(iz)+'.nii'))
    # merge along Z
    # cmd = 'fslmerge -z mask '
    im_list = []
    # NOTE(review): the loop body that fills im_list and the save of im_out to
    # 'mask.nii.gz' are missing from this listing.
    for iz in range(nz):
    im_out = concat_data(im_list, 2)

    # copy geometry
    im_dat = Image('data.nii')
    im_mask = Image('mask.nii.gz')
    im_mask = copy_header(im_dat, im_mask)
    # NOTE(review): no save of im_mask is visible after the header copy.

    # come back to parent folder
    # NOTE(review): the directory change back is missing from this listing.

    # Generate output files
    sct.printv('\nGenerate output files...', param.verbose)
    sct.generate_output_file(path_tmp+'mask.nii.gz', param.fname_out)

    # Remove temporary files
    if param.remove_tmp_files == 1:
        sct.printv('\nRemove temporary files...', param.verbose)
        # A failed cleanup is reported as a warning rather than aborting.
        sct.run('rm -rf '+path_tmp, param.verbose, error_exit='warning')

    # to view results
    sct.printv('\nDone! To view results, type:', param.verbose)
    sct.printv('fslview '+param.fname_data+' '+param.fname_out+' -l Red -t 0.5 &', param.verbose, 'info')
def fmri_moco(param):
    # fMRI motion correction in two passes: volumes of 'fmri.nii' are grouped
    # (param.group_size consecutive volumes per group) and averaged, motion is
    # estimated on the group means, then each group's transformation is copied
    # to all of its volumes and applied to the full series.
    # All file names are relative, so the caller is expected to have moved
    # into the working folder that contains 'fmri.nii'.
    # NOTE(review): this listing is truncated (orphaned arguments, loops
    # without bodies, missing moco() calls); gaps are flagged below.

    file_data = 'fmri'
    ext_data = '.nii'
    mat_final = 'mat_final/'
    fsloutput = 'export FSLOUTPUTTYPE=NIFTI; '  # for faster processing, all outputs are in NIFTI
    ext_mat = 'Warp.nii.gz'  # warping field

    # Get dimensions of data
    sct.printv('\nGet dimensions of data...', param.verbose)
    nx, ny, nz, nt, px, py, pz, pt = Image(file_data + '.nii').dim
    # NOTE(review): orphaned argument; the enclosing sct.printv( is missing.
        '  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt),

    # Split into T dimension
    sct.printv('\nSplit along T dimension...', param.verbose)
    im_data = Image(file_data + ext_data)
    im_data_split_list = split_data(im_data, 3)
    # NOTE(review): loop body (presumably saving each volume) is missing.
    for im in im_data_split_list:

    # assign an index to each volume
    index_fmri = range(0, nt)

    # Number of groups
    # Only complete groups are counted here; leftovers are handled below.
    nb_groups = int(math.floor(nt / param.group_size))

    # Generate groups indexes
    group_indexes = []
    for iGroup in range(nb_groups):
        # NOTE(review): the slice expression is truncated in this listing.
        group_indexes.append(index_fmri[(iGroup *
                                         param.group_size):((iGroup + 1) *

    # add the remaining images to the last DWI group
    # In the visible code the leftover volumes form one extra group
    # (nb_groups is incremented) rather than being merged into the last one.
    nb_remaining = nt % param.group_size  # number of remaining images
    if nb_remaining > 0:
        nb_groups += 1
        # NOTE(review): the slice expression is truncated in this listing.
        group_indexes.append(index_fmri[len(index_fmri) -

    # groups
    for iGroup in range(nb_groups):
        # NOTE(review): the sct.printv( call is not closed in this listing.
        sct.printv('\nGroup: ' + str((iGroup + 1)) + '/' + str(nb_groups),

        # get index
        index_fmri_i = group_indexes[iGroup]
        nt_i = len(index_fmri_i)

        # Merge Images
        sct.printv('Merge consecutive volumes...', param.verbose)
        file_data_merge_i = file_data + '_' + str(iGroup)
        # cmd = fsloutput + 'fslmerge -t ' + file_data_merge_i
        # for it in range(nt_i):
        #     cmd = cmd + ' ' + file_data + '_T' + str(index_fmri_i[it]).zfill(4)

        im_fmri_list = []
        # NOTE(review): loop body that fills im_fmri_list is missing, as is the
        # save of im_fmri_concat after setFileName.
        for it in range(nt_i):
        im_fmri_concat = concat_data(im_fmri_list, 3)
        im_fmri_concat.setFileName(file_data_merge_i + ext_data)

        # Average Images
        # Produces fmri_mean_<iGroup>.nii, the temporal mean of the group.
        sct.printv('Average volumes...', param.verbose)
        file_data_mean = file_data + '_mean_' + str(iGroup)
        sct.run('sct_maths -i ' + file_data_merge_i + '.nii -o ' +
                file_data_mean + '.nii -mean t')
        # if not average_data_across_dimension(file_data_merge_i+'.nii', file_data_mean+'.nii', 3):
        #     sct.printv('ERROR in average_data_across_dimension', 1, 'error')
        # cmd = fsloutput + 'fslmaths ' + file_data_merge_i + ' -Tmean ' + file_data_mean
        # sct.run(cmd, param.verbose)

    # Merge groups means
    sct.printv('\nMerging volumes...', param.verbose)
    file_data_groups_means_merge = 'fmri_averaged_groups'
    im_mean_list = []
    for iGroup in range(nb_groups):
        # NOTE(review): orphaned argument; the im_mean_list.append( header is
        # missing.
            Image(file_data + '_mean_' + str(iGroup) + ext_data))
    im_mean_concat = concat_data(im_mean_list, 3)
    im_mean_concat.setFileName(file_data_groups_means_merge + ext_data)

    # Estimate moco
    sct.printv('  Estimating motion...', param.verbose)
    # param_moco is an alias of param (plain assignment, no copy): every field
    # set below also changes the caller's param object.
    param_moco = param
    param_moco.file_data = 'fmri_averaged_groups'
    # param.num_target is concatenated to a str, so it must itself be a string.
    param_moco.file_target = file_data + '_mean_' + param.num_target
    param_moco.path_out = ''
    param_moco.todo = 'estimate_and_apply'
    param_moco.mat_moco = 'mat_groups'
    # NOTE(review): the call that runs the estimation with param_moco is
    # missing from this listing.

    # create final mat folder
    # NOTE(review): the folder creation itself is missing from this listing.

    # Copy registration matrices
    # Every volume of a group receives a copy of that group's transformation.
    sct.printv('\nCopy transformations...', param.verbose)
    for iGroup in range(nb_groups):
        for data in range(len(group_indexes[iGroup])):
            # NOTE(review): orphaned arguments; the sct.run( header is missing.
                'cp ' + 'mat_groups/' + 'mat.T' + str(iGroup) + ext_mat + ' ' +
                mat_final + 'mat.T' + str(group_indexes[iGroup][data]) +
                ext_mat, param.verbose)

    # Apply moco on all fmri data
    sct.printv('  Apply moco', param.verbose)
    param_moco.file_data = 'fmri'
    param_moco.file_target = file_data + '_mean_' + str(0)
    param_moco.path_out = ''
    param_moco.mat_moco = mat_final
    param_moco.todo = 'apply'
    # NOTE(review): the call that applies the transformations is missing from
    # this listing.

    # copy geometric information from header
    # NB: this is required because WarpImageMultiTransform in 2D mode wrongly sets pixdim(3) to "1".
    im_fmri = Image('fmri.nii')
    im_fmri_moco = Image('fmri_moco.nii')
    im_fmri_moco = copy_header(im_fmri, im_fmri_moco)
    # NOTE(review): no save of im_fmri_moco is visible after the header copy.

    # Average volumes
    sct.printv('\nAveraging data...', param.verbose)
    sct.run('sct_maths -i fmri_moco.nii -o fmri_moco_mean.nii -mean t')
    def apply(self):
        # Apply the transformations in self.list_warp to self.input_filename,
        # resampling into the space of self.fname_dest, and write the result to
        # self.output_filename (or '<source>_reg<ext>' when that is empty).
        # 3D inputs are transformed in one call; 4D inputs are split along T,
        # transformed volume by volume in a temporary folder, then merged.
        # NOTE(review): this listing is truncated (missing else headers,
        # missing call headers, missing chdir calls); gaps are flagged below.
        # Initialization
        fname_src = self.input_filename  # source image (moving)
        list_warp = self.list_warp  # list of warping fields
        fname_out = self.output_filename  # output
        fname_dest = self.fname_dest  # destination image (fix)
        verbose = self.verbose
        remove_temp_files = self.remove_temp_files
        crop_reference = self.crop  # if = 1, put 0 everywhere around warping field, if = 2, real crop

        # 'label' is not an interpolation mode of its own: it is mapped to
        # nearest-neighbour plus label-specific pre/post-processing below.
        # Side effect: self.interp is overwritten on the instance.
        islabel = False
        if self.interp == 'label':
            islabel = True
            self.interp = 'nn'

        interp = sct.get_interpolation('isct_antsApplyTransforms', self.interp)

        # Parse list of warping fields
        sct.printv('\nParse list of warping fields...', verbose)
        use_inverse = []
        fname_warp_list_invert = []
        # list_warp = list_warp.replace(' ', '')  # remove spaces
        # list_warp = list_warp.split(",")  # parse with comma
        for idx_warp, path_warp in enumerate(self.list_warp):
            # Check if this transformation should be inverted
            if path_warp in self.list_warpinv:
                # list_warp[idx_warp] = path_warp[1:]  # remove '-'
                # NOTE(review): the statements that fill use_inverse and the
                # `else:` header between the next two lines are missing; as
                # listed, use_inverse is still empty when it is indexed.
                fname_warp_list_invert += [[use_inverse[idx_warp], list_warp[idx_warp]]]
                fname_warp_list_invert += [[path_warp]]
            path_warp = list_warp[idx_warp]
            # Reject NIfTI warps whose header intent is not 'vector'.
            if path_warp.endswith((".nii", ".nii.gz")) \
             and Image(list_warp[idx_warp]).header.get_intent()[0] != 'vector':
                # NOTE(review): the end of this message and the closing of the
                # ValueError( call are missing from this listing.
                raise ValueError("Displacement field in {} is invalid: should be encoded" \
                 " in a 5D file with vector intent code" \
                 " (see https://nifti.nimh.nih.gov/pub/dist/src/niftilib/nifti1.h" \
        # need to check if last warping field is an affine transfo
        # Decided purely from the file extension of the last entry.
        isLastAffine = False
        path_fname, file_fname, ext_fname = sct.extract_fname(fname_warp_list_invert[-1][-1])
        if ext_fname in ['.txt', '.mat']:
            isLastAffine = True

        ## check if destination file is 3d
        # sct.check_dim(fname_dest, dim_lst=[3]) # PR 2598: we decided to skip this line.

        # N.B. Here we take the inverse of the warp list, because sct_WarpImageMultiTransform concatenates in the reverse order
        # The reduce flattens the list of per-warp argument lists into a single
        # flat list of command-line arguments.
        fname_warp_list_invert = functools.reduce(lambda x, y: x + y, fname_warp_list_invert)

        # Extract path, file and extension
        path_src, file_src, ext_src = sct.extract_fname(fname_src)
        path_dest, file_dest, ext_dest = sct.extract_fname(fname_dest)

        # Get output folder and file name
        if fname_out == '':
            path_out = ''  # output in user's current directory
            file_out = file_src + '_reg'
            ext_out = ext_src
            fname_out = os.path.join(path_out, file_out + ext_out)

        # Get dimensions of data
        sct.printv('\nGet dimensions of data...', verbose)
        img_src = Image(fname_src)
        nx, ny, nz, nt, px, py, pz, pt = img_src.dim
        # nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(fname_src)
        sct.printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt), verbose)

        # if 3d
        if nt == 1:
            # Apply transformation
            sct.printv('\nApply transformation...', verbose)
            # NOTE(review): the `else:` header between the two dim assignments
            # is missing; as listed, dim always ends up '3'.
            if nz in [0, 1]:
                dim = '2'
                dim = '3'
            # if labels, dilate before resampling
            if islabel:
                sct.printv("\nDilate labels before warping...")
                path_tmp = sct.tmp_create(basename="apply_transfo", verbose=verbose)
                fname_dilated_labels = os.path.join(path_tmp, "dilated_data.nii")
                # dilate points
                dilate(Image(fname_src), 2, 'ball').save(fname_dilated_labels)
                # The dilated copy replaces the source for the resampling step.
                fname_src = fname_dilated_labels

            sct.printv("\nApply transformation and resample to destination space...", verbose)
            # NOTE(review): orphaned arguments; the header of the call that
            # runs the transformation binary is missing from this listing.
                     '-d', dim,
                     '-i', fname_src,
                     '-o', fname_out,
                     ] + fname_warp_list_invert + ['-r', fname_dest] + interp, verbose=verbose, is_sct_binary=True)

        # if 4d, loop across the T dimension
        # NOTE(review): the `else:` header for the 4D branch is missing.
            if islabel:
                raise NotImplementedError

            dim = '4'
            path_tmp = sct.tmp_create(basename="apply_transfo", verbose=verbose)

            # convert to nifti into temp folder
            sct.printv('\nCopying input data to tmp folder and convert to nii...', verbose)
            img_src.save(os.path.join(path_tmp, "data.nii"))
            sct.copy(fname_dest, os.path.join(path_tmp, file_dest + ext_dest))
            fname_warp_list_tmp = []
            for fname_warp in list_warp:
                path_warp, file_warp, ext_warp = sct.extract_fname(fname_warp)
                sct.copy(fname_warp, os.path.join(path_tmp, file_warp + ext_warp))
                fname_warp_list_tmp.append(file_warp + ext_warp)
            # Reversed copy of the warp file names (bare names, relative to
            # path_tmp), appended as-is to the per-volume command below.
            fname_warp_list_invert_tmp = fname_warp_list_tmp[::-1]

            curdir = os.getcwd()
            # NOTE(review): the change into path_tmp (and the later change back
            # to curdir) is missing; the relative names below assume it.

            # split along T dimension
            sct.printv('\nSplit along T dimension...', verbose)

            im_dat = Image('data.nii')
            im_header = im_dat.hdr
            data_split_list = sct_image.split_data(im_dat, 3)
            # NOTE(review): loop body (presumably saving each volume as
            # data_T####.nii) is missing from this listing.
            for im in data_split_list:

            # apply transfo
            sct.printv('\nApply transformation to each 3D volume...', verbose)
            for it in range(nt):
                file_data_split = 'data_T' + str(it).zfill(4) + '.nii'
                file_data_split_reg = 'data_reg_T' + str(it).zfill(4) + '.nii'

                status, output = run_proc(['isct_antsApplyTransforms',
                                          '-d', '3',
                                          '-i', file_data_split,
                                          '-o', file_data_split_reg,
                                          ] + fname_warp_list_invert_tmp + [
                    '-r', file_dest + ext_dest,
                ] + interp, verbose, is_sct_binary=True)

            # Merge files back
            sct.printv('\nMerge file back...', verbose)
            import glob
            path_out, name_out, ext_out = sct.extract_fname(fname_out)
            # im_list = [Image(file_name) for file_name in glob.glob('data_reg_T*.nii')]
            # concat_data use to take a list of image in input, now takes a list of file names to open the files one by one (see issue #715)
            # NOTE(review): glob.glob does not guarantee any ordering, so the
            # volume order along T depends on the file system unless
            # concat_data sorts its input; confirm or sort explicitly.
            fname_list = glob.glob('data_reg_T*.nii')
            im_out = sct_image.concat_data(fname_list, 3, im_header['pixdim'])
            im_out.save(name_out + ext_out)

            sct.generate_output_file(os.path.join(path_tmp, name_out + ext_out), fname_out)
            # Delete temporary folder if specified
            if remove_temp_files:
                sct.printv('\nRemove temporary files...', verbose)
                sct.rmtree(path_tmp, verbose=verbose)

        # Copy affine matrix from destination space to make sure qform/sform are the same
        sct.printv("Copy affine matrix from destination space to make sure qform/sform are the same.", verbose)
        im_src_reg = Image(fname_out)
        # NOTE(review): the statement that actually copies the affine from the
        # destination image is missing; as listed the file is re-saved as is.
        im_src_reg.save(verbose=0)  # set verbose=0 to avoid warning message about rewriting file

        if islabel:
            sct.printv("\nTake the center of mass of each registered dilated labels...")
            labeled_img = cubic_to_point(im_src_reg)
            # NOTE(review): no save of labeled_img is visible in this listing.
            if remove_temp_files:
                sct.printv('\nRemove temporary files...', verbose)
                sct.rmtree(path_tmp, verbose=verbose)

        # Crop the resulting image using dimensions from the warping field
        warping_field = fname_warp_list_invert[-1]
        # If the last transformation is not an affine transfo, we need to compute the matrix space of the concatenated
        # warping field
        if not isLastAffine and crop_reference in [1, 2]:
            sct.printv('Last transformation is not affine.')
            # This inner test repeats the outer condition and is always true here.
            if crop_reference in [1, 2]:
                # Extract only the first ndim of the warping field
                img_warp = Image(warping_field)
                # NOTE(review): the data come from img_src, only the header
                # from img_warp, and img_warp_ndim is not used again in the
                # visible code; statements configuring the cropper appear to be
                # missing.
                if dim == '2':
                    img_warp_ndim = Image(img_src.data[:, :], hdr=img_warp.hdr)
                elif dim in ['3', '4']:
                    img_warp_ndim = Image(img_src.data[:, :, :], hdr=img_warp.hdr)
                # Set zero to everything outside the warping field
                cropper = ImageCropper(Image(fname_out))
                if crop_reference == 1:
                    sct.printv('Cropping strategy is: keep same matrix size, put 0 everywhere around warping field')
                    img_out = cropper.crop(background=0)
                elif crop_reference == 2:
                    # NOTE(review): the end of this message and the closing of
                    # the sct.printv( call are missing from this listing.
                    sct.printv('Cropping strategy is: crop around warping field (the size of warping field will '
                    img_out = cropper.crop()
                # NOTE(review): img_out is not saved in the visible code.

        sct.display_viewer_syntax([fname_dest, fname_out], verbose=verbose)
def main(args=None):
    # Command-line entry point: splits a 4D diffusion data set into its b=0
    # and DWI volumes, writes 'b0' and 'dwi' (plus their temporal means when
    # -a is set) into the -ofolder directory.
    # NOTE(review): this listing is truncated (missing else headers, missing
    # call headers, missing chdir calls); gaps are flagged below.
    if not args:
        args = sys.argv[1:]

    # initialize parameters
    param = Param()
    # call main function
    parser = get_parser()
    arguments = parser.parse(args)

    fname_data = arguments['-i']
    fname_bvecs = arguments['-bvec']
    average = arguments['-a']
    verbose = int(arguments['-v'])
    remove_temp_files = int(arguments['-r'])
    path_out = arguments['-ofolder']

    # NOTE(review): the `else:` header between the two fname_bvals assignments
    # is missing; as listed, fname_bvals is always reset to ''.
    if '-bval' in arguments:
        fname_bvals = arguments['-bval']
        fname_bvals = ''
    if '-bvalmin' in arguments:
        param.bval_min = arguments['-bvalmin']

    # Initialization
    start_time = time.time()

    # sct.printv(arguments)
    sct.printv('\nInput parameters:', verbose)
    sct.printv('  input file ............' + fname_data, verbose)
    sct.printv('  bvecs file ............' + fname_bvecs, verbose)
    sct.printv('  bvals file ............' + fname_bvals, verbose)
    sct.printv('  average ...............' + str(average), verbose)

    # Get full path
    # Absolute paths are resolved before any change of working directory.
    fname_data = os.path.abspath(fname_data)
    fname_bvecs = os.path.abspath(fname_bvecs)
    if fname_bvals:
        fname_bvals = os.path.abspath(fname_bvals)

    # Extract path, file and extension
    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    # # get output folder
    # if path_out == '':
    #     path_out = ''

    # create temporary folder
    path_tmp = sct.tmp_create(basename="dmri_separate", verbose=verbose)

    # copy files into tmp folder and convert to nifti
    sct.printv('\nCopy files into temporary folder...', verbose)
    ext = '.nii'
    dmri_name = 'dmri'
    b0_name = 'b0'
    b0_mean_name = b0_name + '_mean'
    dwi_name = 'dwi'
    dwi_mean_name = dwi_name + '_mean'

    from sct_convert import convert
    # A falsy return value from convert is reported as an error.
    if not convert(fname_data, os.path.join(path_tmp, dmri_name + ext)):
        sct.printv('ERROR in convert.', 1, 'error')
    sct.copy(fname_bvecs, os.path.join(path_tmp, "bvecs"), verbose=verbose)

    # go to tmp folder
    curdir = os.getcwd()
    # NOTE(review): the change into path_tmp is missing from this listing; the
    # relative file names below assume it.

    # Get size of data
    im_dmri = Image(dmri_name + ext)
    sct.printv('\nGet dimensions data...', verbose)
    nx, ny, nz, nt, px, py, pz, pt = im_dmri.dim
    # NOTE(review): orphaned argument; the enclosing sct.printv( is missing.
        '.. ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt),

    # Identify b=0 and DWI images
    # Uses the original (absolute) bvecs/bvals paths, not the copy in path_tmp.
    index_b0, index_dwi, nb_b0, nb_dwi = identify_b0(fname_bvecs, fname_bvals,
                                                     param.bval_min, verbose)

    # Split into T dimension
    sct.printv('\nSplit along T dimension...', verbose)
    im_dmri_split_list = split_data(im_dmri, 3)
    # NOTE(review): loop body (presumably saving each volume as
    # dmri_T####.nii, the names used below) is missing from this listing.
    for im_d in im_dmri_split_list:

    # Merge b=0 images
    sct.printv('\nMerge b=0...', verbose)
    from sct_image import concat_data
    l = []
    for it in range(nb_b0):
        l.append(dmri_name + '_T' + str(index_b0[it]).zfill(4) + ext)
    # im_out receives whatever save() returns, not the concatenated image
    # itself; it is not used afterwards in the visible code.
    im_out = concat_data(l, 3).save(b0_name + ext)

    # Average b=0 images
    if average:
        sct.printv('\nAverage b=0...', verbose)
        # NOTE(review): orphaned arguments; the header of the call that runs
        # sct_maths is missing from this listing.
            'sct_maths', '-i', b0_name + ext, '-o', b0_mean_name + ext,
            '-mean', 't'
        ], verbose)

    # Merge DWI
    l = []
    for it in range(nb_dwi):
        l.append(dmri_name + '_T' + str(index_dwi[it]).zfill(4) + ext)
    im_out = concat_data(l, 3).save(dwi_name + ext)

    # Average DWI images
    if average:
        sct.printv('\nAverage DWI...', verbose)
        # NOTE(review): orphaned arguments; the header of the call that runs
        # sct_maths is missing from this listing.
            'sct_maths', '-i', dwi_name + ext, '-o', dwi_mean_name + ext,
            '-mean', 't'
        ], verbose)
        # if not average_data_across_dimension('dwi.nii', 'dwi_mean.nii', 3):
        #     sct.printv('ERROR in average_data_across_dimension', 1, 'error')
        # sct.run(fsloutput + 'fslmaths dwi -Tmean dwi_mean', verbose)

    # come back
    # NOTE(review): the change back to curdir is missing from this listing.

    # Generate output files
    # Outputs keep the extension of the input file (ext_data), not '.nii'.
    sct.printv('\nGenerate output files...', verbose)
    # NOTE(review): both generate_output_file( calls below are not closed, and
    # the call headers inside `if average:` are missing.
    sct.generate_output_file(os.path.join(path_tmp, b0_name + ext),
                             os.path.join(path_out, b0_name + ext_data),
    sct.generate_output_file(os.path.join(path_tmp, dwi_name + ext),
                             os.path.join(path_out, dwi_name + ext_data),
    if average:
            os.path.join(path_tmp, b0_mean_name + ext),
            os.path.join(path_out, b0_mean_name + ext_data), verbose)
            os.path.join(path_tmp, dwi_mean_name + ext),
            os.path.join(path_out, dwi_mean_name + ext_data), verbose)

    # Remove temporary files
    if remove_temp_files == 1:
        sct.printv('\nRemove temporary files...', verbose)
        sct.rmtree(path_tmp, verbose=verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    # NOTE(review): orphaned argument; the enclosing sct.printv( is missing.
        '\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's',

    # to view results
    sct.printv('\nTo view results, type: ', verbose)
    # NOTE(review): the `else:` header between the two calls below is missing;
    # as listed, both run when `average` is set.
    if average:
        sct.display_viewer_syntax(['b0', 'b0_mean', 'dwi', 'dwi_mean'])
        sct.display_viewer_syntax(['b0', 'dwi'])
def main(args=None):
    """
    Command-line entry point: set or display the orientation of a 3D or 4D image.

    The input image (-i) is converted to ``data.nii`` inside a freshly created temporary folder and processed
    there. When an orientation is requested (-s), the re-oriented image is written to the -o file name, or to
    ``<input>_<orientation><ext>`` when -o is not given. Otherwise the current orientation is printed.
    4D data are split along T, processed volume by volume and merged back.

    :param args: list of command-line arguments; ``sys.argv[1:]`` is used when empty or None
    """
    import os  # local import: needed for the directory changes below

    orientation = ''
    change_header = ''
    fname_out = ''

    if not args:
        args = sys.argv[1:]

    # Building the command, do sanity checks
    parser = get_parser()
    # Parse the arguments that were actually passed in (the previous code always re-read sys.argv and
    # silently ignored `args`).
    arguments = parser.parse(args)
    fname_in = arguments['-i']
    if '-o' in arguments:
        fname_out = arguments['-o']
    if '-s' in arguments:
        orientation = arguments['-s']
    if '-a' in arguments:
        change_header = arguments['-a']
    remove_tmp_files = int(arguments['-r'])
    verbose = int(arguments['-v'])

    # create temporary folder
    sct.printv('\nCreate temporary folder...', verbose)
    path_tmp = 'tmp.'+time.strftime("%y%m%d%H%M%S/")
    sct.run('mkdir '+path_tmp, verbose)

    # copy file in temp folder
    sct.printv('\nCopy files to tmp folder...', verbose)
    convert(fname_in, path_tmp+'data.nii', verbose=0)

    # go to temp folder; always come back, even if a step below fails
    path_parent = os.getcwd()
    os.chdir(path_tmp)
    try:
        # Get dimensions of data
        sct.printv('\nGet dimensions of data...', verbose)
        nx, ny, nz, nt, px, py, pz, pt = Image('data.nii').dim
        sct.printv(str(nx) + ' x ' + str(ny) + ' x ' + str(nz)+ ' x ' + str(nt), verbose)

        # if data are 3d, directly set or get orientation
        if nt == 1:
            if orientation != '':
                # set orientation
                sct.printv('\nChange orientation...', verbose)
                if change_header == '':
                    set_orientation('data.nii', orientation, 'data_orient.nii')
                else:
                    set_orientation('data.nii', change_header, 'data_orient.nii', True)
            else:
                # get orientation
                sct.printv('\nGet orientation...', verbose)
                print(get_orientation('data.nii'))
        else:
            # split along T dimension
            sct.printv('\nSplit along T dimension...', verbose)
            im = Image('data.nii')
            im_split_list = split_data(im, 3)
            for im_s in im_split_list:
                # NOTE(review): assumes split_data() names the volumes data_T0000.nii, data_T0001.nii, ...
                # (the names read back below) -- confirm.
                im_s.save()
            if orientation != '':
                # set orientation
                # NOTE(review): the '-a' (change_header) option is only honoured for 3D data.
                sct.printv('\nChange orientation...', verbose)
                files_orient = []
                for it in range(nt):
                    file_data_split = 'data_T'+str(it).zfill(4)+'.nii'
                    file_data_split_orient = 'data_orient_T'+str(it).zfill(4)+'.nii'
                    set_orientation(file_data_split, orientation, file_data_split_orient)
                    files_orient.append(file_data_split_orient)
                # Merge files back, in volume order (glob() gives no ordering guarantee)
                sct.printv('\nMerge file back...', verbose)
                im_data_list = [Image(fname) for fname in files_orient]
                im_concat = concat_data(im_data_list, 3)
                # write the merged volume where the output step below expects it
                im_concat.setFileName('data_orient.nii')
                im_concat.save()
            else:
                sct.printv('\nGet orientation...', verbose)
                print(get_orientation('data_T0000.nii'))
    finally:
        # come back to parent folder
        os.chdir(path_parent)

    # Generate output files
    if orientation != '':
        # Build fname_out
        if fname_out == '':
            path_data, file_data, ext_data = sct.extract_fname(fname_in)
            fname_out = path_data+file_data+'_'+orientation+ext_data
        sct.printv('\nGenerate output files...', verbose)
        sct.generate_output_file(path_tmp+'data_orient.nii', fname_out)

    # Remove temporary files
    if remove_tmp_files == 1:
        sct.printv('\nRemove temporary files...', verbose)
        sct.run('rm -rf '+path_tmp, verbose)
    def apply(self):
        """
        Apply the transformations listed in ``self.list_warp`` to ``self.input_filename``.

        The result is resampled in the space of ``self.fname_dest`` and written to ``self.output_filename``
        (``<source>_reg<ext>`` in the current directory when that attribute is empty). Transformations that also
        appear in ``self.list_warpinv`` are passed to ``isct_antsApplyTransforms`` with the inversion flag.
        3D data are processed in a single call; 4D data are split along T in a temporary folder, each volume is
        transformed separately and the volumes are merged back. Depending on ``self.crop`` (1: zero outside the
        warping field, 2: real crop) the output is finally cropped with ``ImageCropper``.

        :raises ValueError: if a NIfTI warping field does not carry the 'vector' intent code
        """
        # Initialization
        fname_src = self.input_filename  # source image (moving)
        list_warp = self.list_warp  # list of warping fields
        fname_out = self.output_filename  # output
        fname_dest = self.fname_dest  # destination image (fix)
        verbose = self.verbose
        remove_temp_files = self.remove_temp_files
        crop_reference = self.crop  # if = 1, put 0 everywhere around warping field, if = 2, real crop

        interp = sct.get_interpolation('isct_antsApplyTransforms', self.interp)

        # Parse list of warping fields: one argument group per transformation
        sct.printv('\nParse list of warping fields...', verbose)
        warp_groups = []
        for path_warp in list_warp:
            # Check if this transformation should be inverted
            if path_warp in self.list_warpinv:
                # NOTE(review): '-i' is the inversion flag expected by isct_antsApplyTransforms -- confirm.
                warp_groups.append(['-i', path_warp])
            else:
                warp_groups.append([path_warp])
            if path_warp.endswith((".nii", ".nii.gz")) \
             and msct_image.Image(path_warp).header.get_intent()[0] != 'vector':
                raise ValueError(("Displacement field in {} is invalid: should be encoded"
                                  " in a 5D file with vector intent code"
                                  " (see https://nifti.nimh.nih.gov/pub/dist/src/niftilib/nifti1.h"
                                  ).format(path_warp))

        # need to check if last warping field is an affine transfo
        ext_fname = sct.extract_fname(warp_groups[-1][-1])[2]
        isLastAffine = ext_fname in ['.txt', '.mat']

        # check if destination file is 3d
        if not sct.check_if_3d(fname_dest):
            sct.printv('ERROR: Destination data must be 3d')

        # N.B. Here we take the inverse of the warp list, because sct_WarpImageMultiTransform concatenates in the
        # reverse order (same ordering as the [::-1] used for the 4D case below). The groups are then flattened
        # into a single argument list.
        fname_warp_list_invert = [arg for group in reversed(warp_groups) for arg in group]

        # Extract path, file and extension
        path_src, file_src, ext_src = sct.extract_fname(fname_src)
        path_dest, file_dest, ext_dest = sct.extract_fname(fname_dest)

        # Get output folder and file name
        if fname_out == '':
            path_out = ''  # output in user's current directory
            file_out = file_src + '_reg'
            ext_out = ext_src
            fname_out = os.path.join(path_out, file_out + ext_out)

        # Get dimensions of data
        sct.printv('\nGet dimensions of data...', verbose)
        img_src = msct_image.Image(fname_src)
        nx, ny, nz, nt, px, py, pz, pt = img_src.dim
        sct.printv(
            '  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' +
            str(nt), verbose)

        # if 3d
        if nt == 1:
            # Apply transformation
            sct.printv('\nApply transformation...', verbose)
            dim = '2' if nz in [0, 1] else '3'
            sct.run([
                'isct_antsApplyTransforms', '-d', dim, '-i', fname_src, '-o',
                fname_out, '-t'
            ] + fname_warp_list_invert + ['-r', fname_dest] + interp,
                    verbose=verbose,
                    is_sct_binary=True)

        # if 4d, loop across the T dimension
        else:
            path_tmp = sct.tmp_create(basename="apply_transfo",
                                      verbose=verbose)

            # convert to nifti into temp folder
            sct.printv(
                '\nCopying input data to tmp folder and convert to nii...',
                verbose)
            img_src.save(os.path.join(path_tmp, "data.nii"))
            sct.copy(fname_dest, os.path.join(path_tmp, file_dest + ext_dest))
            fname_warp_list_tmp = []
            for fname_warp in list_warp:
                path_warp, file_warp, ext_warp = sct.extract_fname(fname_warp)
                sct.copy(fname_warp,
                         os.path.join(path_tmp, file_warp + ext_warp))
                fname_warp_list_tmp.append(file_warp + ext_warp)
            # NOTE(review): unlike the 3D case, this list carries no inversion flag, so entries of
            # self.list_warpinv are applied un-inverted to 4D data -- confirm whether this is intended.
            fname_warp_list_invert_tmp = fname_warp_list_tmp[::-1]

            path_out, name_out, ext_out = sct.extract_fname(fname_out)

            # work inside the temporary folder; always come back, even on failure
            curdir = os.getcwd()
            os.chdir(path_tmp)
            try:
                # split along T dimension
                sct.printv('\nSplit along T dimension...', verbose)
                im_dat = msct_image.Image('data.nii')
                im_header = im_dat.hdr
                data_split_list = sct_image.split_data(im_dat, 3)
                for im in data_split_list:
                    # NOTE(review): assumes split_data() names the volumes data_T0000.nii, ... -- confirm.
                    im.save()

                # apply transfo
                sct.printv('\nApply transformation to each 3D volume...', verbose)
                fname_list = []
                for it in range(nt):
                    file_data_split = 'data_T' + str(it).zfill(4) + '.nii'
                    file_data_split_reg = 'data_reg_T' + str(it).zfill(4) + '.nii'

                    sct.run([
                        'isct_antsApplyTransforms', '-d', '3', '-i',
                        file_data_split, '-o', file_data_split_reg, '-t'
                    ] + fname_warp_list_invert_tmp + [
                        '-r',
                        file_dest + ext_dest,
                    ] + interp,
                            verbose,
                            is_sct_binary=True)
                    fname_list.append(file_data_split_reg)

                # Merge files back, in volume order (glob() gives no ordering guarantee).
                # concat_data takes a list of file names and opens the files one by one (see issue #715)
                sct.printv('\nMerge file back...', verbose)
                im_out = sct_image.concat_data(fname_list, 3, im_header['pixdim'])
                im_out.save(name_out + ext_out)
            finally:
                os.chdir(curdir)

            sct.generate_output_file(
                os.path.join(path_tmp, name_out + ext_out), fname_out)
            # Delete temporary folder if specified
            if int(remove_temp_files):
                sct.printv('\nRemove temporary files...', verbose)
                sct.rmtree(path_tmp, verbose=verbose)

        # 2. crop the resulting image using dimensions from the warping field
        warping_field = fname_warp_list_invert[-1]
        # if last warping field is an affine transfo, we need to compute the space of the concatenate warping field:
        if isLastAffine:
            sct.printv(
                'WARNING: the resulting image could have wrong apparent results. You should use an affine transformation as last transformation...',
                verbose, 'warning')
        elif crop_reference == 1:
            ImageCropper(input_file=fname_out, output_file=fname_out, ref=warping_field, background=0).crop()
        elif crop_reference == 2:
            ImageCropper(input_file=fname_out, output_file=fname_out, ref=warping_field).crop()

        sct.display_viewer_syntax([fname_dest, fname_out], verbose=verbose)
def dmri_moco(param):
    # Motion correction of the 4D diffusion series 'dmri.nii' located in the current working directory.
    #
    # Steps visible in this function:
    #   1. split the series along T;
    #   2. merge and average the b=0 volumes;
    #   3. merge and average the DWI volumes in groups of param.group_size;
    #   4. estimate motion on the b=0 series and on the DWI group means through moco.moco();
    #   5. copy the resulting warping fields to 'mat_final/' (one per original volume);
    #   6. configure an 'apply' pass on the whole series.
    #
    # :param param: parameter object. Attributes read here: verbose, fname_bvals, bval_min, group_size, otsu,
    #               spline_fitting, plot_graph, run_eddy, suffix. The object is also modified (see param_moco).
    # :return: absolute path of '<basename of file_data><param.suffix>.nii'
    #
    # NOTE(review): several statements appear to be missing from this function (loops without a body, an
    # if-branch without its `else:`, announced steps without a call). Each spot is flagged below and has to be
    # restored before the function can run.

    file_data = 'dmri.nii'
    file_data_dirname, file_data_basename, file_data_ext = sct.extract_fname(file_data)
    file_b0 = 'b0.nii'
    file_dwi = 'dwi.nii'
    ext_data = '.nii.gz' # workaround "too many open files" by slurping the data
    mat_final = 'mat_final/'
    file_dwi_group = 'dwi_averaged_groups.nii'
    ext_mat = 'Warp.nii.gz'  # warping field

    # Get dimensions of data
    sct.printv('\nGet dimensions of data...', param.verbose)
    im_data = Image(file_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    sct.printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), param.verbose)

    # Identify b=0 and DWI images
    # (the bvecs file name is hard-coded: 'bvecs.txt' must exist in the working directory)
    index_b0, index_dwi, nb_b0, nb_dwi = sct_dmri_separate_b0_and_dwi.identify_b0('bvecs.txt', param.fname_bvals, param.bval_min, param.verbose)

    # check if dmri and bvecs are the same size
    if not nb_b0 + nb_dwi == nt:
        sct.printv('\nERROR in ' + os.path.basename(__file__) + ': Size of data (' + str(nt) + ') and size of bvecs (' + str(nb_b0 + nb_dwi) + ') are not the same. Check your bvecs file.\n', 1, 'error')

    # Prepare NIFTI (mean/groups...)
    # Split into T dimension
    sct.printv('\nSplit along T dimension...', param.verbose)
    im_data_split_list = split_data(im_data, 3)
    for im in im_data_split_list:
        # rename each split volume so that it gets the '.nii.gz' extension
        x_dirname, x_basename, x_ext = sct.extract_fname(im.absolutepath)
        im.absolutepath = os.path.join(x_dirname, x_basename + ".nii.gz")
        # NOTE(review): nothing visible here writes the renamed volume to disk -- a save call looks missing.

    # Merge b=0 images
    sct.printv('\nMerge b=0...', param.verbose)
    im_b0_list = []
    for it in range(nb_b0):
        # NOTE(review): loop body is missing; im_b0_list is never filled before being concatenated below.
    im_b0_out = concat_data(im_b0_list, 3).save(file_b0)
    sct.printv(('  File created: ' + file_b0), param.verbose)

    # Average b=0 images
    sct.printv('\nAverage b=0...', param.verbose)
    file_b0_mean = sct.add_suffix(file_b0, '_mean')
    sct.run(['sct_maths', '-i', file_b0, '-o', file_b0_mean, '-mean', 't'], param.verbose)

    # Number of DWI groups
    # (number of complete groups; a possible incomplete last group is added further down)
    nb_groups = int(math.floor(nb_dwi / param.group_size))

    # Generate groups indexes
    group_indexes = []
    for iGroup in range(nb_groups):
        group_indexes.append(index_dwi[(iGroup * param.group_size):((iGroup + 1) * param.group_size)])

    # add the remaining images to the last DWI group
    # (as written, the remaining images form an additional, smaller group)
    nb_remaining = nb_dwi%param.group_size  # number of remaining images
    if nb_remaining > 0:
        nb_groups += 1
        group_indexes.append(index_dwi[len(index_dwi) - nb_remaining:len(index_dwi)])

    file_dwi_dirname, file_dwi_basename, file_dwi_ext = sct.extract_fname(file_dwi)
    # DWI groups
    file_dwi_mean = []
    for iGroup in tqdm(range(nb_groups), unit='iter', unit_scale=False, desc="Merge within groups", ascii=True, ncols=80):
        # get index
        index_dwi_i = group_indexes[iGroup]
        nb_dwi_i = len(index_dwi_i)
        # Merge DW Images
        file_dwi_merge_i = os.path.join(file_dwi_dirname, file_dwi_basename + '_' + str(iGroup) + ext_data)
        im_dwi_list = []
        for it in range(nb_dwi_i):
            # NOTE(review): loop body is missing; im_dwi_list is never filled before being concatenated below.
        im_dwi_out = concat_data(im_dwi_list, 3).save(file_dwi_merge_i)
        # Average DW Images
        file_dwi_mean_i = os.path.join(file_dwi_dirname, file_dwi_basename + '_mean_' + str(iGroup) + ext_data)
        # NOTE(review): file_dwi_mean is indexed on the next line but nothing visible appends to it; an
        # append of file_dwi_mean_i looks missing.
        sct.run(["sct_maths", "-i", file_dwi_merge_i, "-o", file_dwi_mean[iGroup], "-mean", "t"], 0)

    # Merge DWI groups means
    sct.printv('\nMerging DW files...', param.verbose)
    # file_dwi_groups_means_merge = 'dwi_averaged_groups'
    im_dw_list = []
    for iGroup in range(nb_groups):
        # NOTE(review): loop body is missing; im_dw_list is never filled before being concatenated below.
    im_dw_out = concat_data(im_dw_list, 3).save(file_dwi_group)

    # Average DW Images
    # TODO: USEFULL ???
    sct.printv('\nAveraging all DW images...', param.verbose)
    sct.run(["sct_maths", "-i", file_dwi_group, "-o", file_dwi_group + '_mean' + ext_data, "-mean", "t"], param.verbose)

    # segment dwi images using otsu algorithm
    if param.otsu:
        sct.printv('\nSegment group DWI using OTSU algorithm...', param.verbose)
        # import module
        otsu = importlib.import_module('sct_otsu')
        # get class from module
        param_otsu = otsu.param()  #getattr(otsu, param)
        param_otsu.fname_data = file_dwi_group
        param_otsu.threshold = param.otsu
        param_otsu.file_suffix = '_seg'
        # run otsu
        # NOTE(review): the call that actually runs the segmentation is not visible here.
        # From here on the segmented file is used in place of the DWI group file.
        file_dwi_group = file_dwi_group + '_seg.nii'


    # Estimate moco on b0 groups
    sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
    sct.printv('  Estimating motion on b=0 images...', param.verbose)
    sct.printv('-------------------------------------------------------------------------------', param.verbose)
    # param_moco is an alias of param (no copy): every assignment below also modifies the caller's object
    param_moco = param
    param_moco.file_data = 'b0.nii'
    # identify target image
    if index_dwi[0] != 0:
        # If first DWI is not the first volume (most common), then there is a least one b=0 image before. In that case
        # select it as the target image for registration of all b=0
        param_moco.file_target = os.path.join(file_data_dirname, file_data_basename + '_T' + str(index_b0[index_dwi[0] - 1]).zfill(4) + ext_data)
        # NOTE(review): an `else:` looks missing here; as written the assignment below always overrides the
        # one above.
        # If first DWI is the first volume, then the target b=0 is the first b=0 from the index_b0.
        param_moco.file_target = os.path.join(file_data_dirname, file_data_basename + '_T' + str(index_b0[0]).zfill(4) + ext_data)

    param_moco.path_out = ''
    param_moco.todo = 'estimate'
    param_moco.mat_moco = 'mat_b0groups'
    file_mat_b0 = moco.moco(param_moco)

    # Estimate moco on dwi groups
    sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
    sct.printv('  Estimating motion on DW images...', param.verbose)
    sct.printv('-------------------------------------------------------------------------------', param.verbose)
    param_moco.file_data = file_dwi_group
    param_moco.file_target = file_dwi_mean[0]  # target is the first DW image (closest to the first b=0)
    param_moco.path_out = ''
    param_moco.todo = 'estimate_and_apply'
    param_moco.mat_moco = 'mat_dwigroups'
    file_mat_dwi = moco.moco(param_moco)

    # create final mat folder
    # NOTE(review): no statement visible here creates the mat_final folder that the copies below write into.

    # Copy b=0 registration matrices
    # TODO: use file_mat_b0 and file_mat_dwi instead of the hardcoding below
    sct.printv('\nCopy b=0 registration matrices...', param.verbose)
    for it in range(nb_b0):
        # the it-th b=0 transformation is renamed after the index of that volume in the original series
        sct.copy('mat_b0groups/' + 'mat.Z0000T' + str(it).zfill(4) + ext_mat,
                 mat_final + 'mat.Z0000T' + str(index_b0[it]).zfill(4) + ext_mat)

    # Copy DWI registration matrices
    sct.printv('\nCopy DWI registration matrices...', param.verbose)
    for iGroup in range(nb_groups):
        # every volume of a group receives a copy of the transformation estimated for that group
        for dwi in range(len(group_indexes[iGroup])):  # we cannot use enumerate because group_indexes has 2 dim.
            sct.copy('mat_dwigroups/' + 'mat.Z0000T' + str(iGroup).zfill(4) + ext_mat,
                     mat_final + 'mat.Z0000T' + str(group_indexes[iGroup][dwi]).zfill(4) + ext_mat)

    # Spline Regularization along T
    if param.spline_fitting:
        moco.spline(mat_final, nt, nz, param.verbose, np.array(index_b0), param.plot_graph)

    # combine Eddy Matrices
    if param.run_eddy:
        param.mat_2_combine = 'mat_eddy'
        param.mat_final = mat_final
        # NOTE(review): only the parameters are set here; the call that combines the matrices is not visible.

    # Apply moco on all dmri data
    sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
    sct.printv('  Apply moco', param.verbose)
    sct.printv('-------------------------------------------------------------------------------', param.verbose)
    param_moco.file_data = file_data
    param_moco.file_target = os.path.join(file_dwi_dirname, file_dwi_basename + '_mean_' + str(0) + ext_data)  # reference for reslicing into proper coordinate system
    param_moco.path_out = ''
    param_moco.mat_moco = mat_final
    param_moco.todo = 'apply'
    # NOTE(review): the parameters are configured for an 'apply' pass but no moco.moco() call is visible here.

    # copy geometric information from header
    # NB: this is required because WarpImageMultiTransform in 2D mode wrongly sets pixdim(3) to "1".
    im_dmri = Image(file_data)

    fname_data_moco = os.path.join(file_data_dirname, file_data_basename + param.suffix + '.nii')
    im_dmri_moco = Image(fname_data_moco)
    im_dmri_moco.header = im_dmri.header
    # NOTE(review): the header is replaced in memory only; no save of im_dmri_moco is visible here.

    return os.path.abspath(fname_data_moco)
def moco(param):
    # Volume-by-volume motion correction of a 4D series.
    #
    # The series param.file_data is split along T and each volume is passed to register() together with the
    # target image param.file_target. Volumes whose registration failed get the transformation of the closest
    # successfully registered volume. Unless param.todo is "estimate", the corrected volumes are concatenated
    # back along T.
    #
    # :param param: parameter object. Attributes read here: file_data, file_target, mat_moco, todo, suffix,
    #               verbose, param (list of 4 strings), fname_mask, iterative_averaging, interp.
    #
    # NOTE(review): several statements appear to be missing from this function (call openers, closing
    # parentheses, an `else:`, loop bodies); each spot is flagged below. No return statement is visible either,
    # although another function in this file assigns the result of moco.moco(...).

    # retrieve parameters
    fsloutput = "export FSLOUTPUTTYPE=NIFTI; "  # for faster processing, all outputs are in NIFTI
    file_data = param.file_data
    file_target = param.file_target
    folder_mat = sct.slash_at_the_end(param.mat_moco, 1)  # output folder of mat file
    todo = param.todo
    suffix = param.suffix
    # file_schedule = param.file_schedule
    verbose = param.verbose
    ext = ".nii"

    # get path of the toolbox
    # (`commands` is a Python 2 module; neither status nor path_sct is used in the visible code)
    status, path_sct = commands.getstatusoutput("echo $SCT_DIR")

    # print arguments
    sct.printv("\nInput parameters:", param.verbose)
    sct.printv("  Input file ............" + file_data, param.verbose)
    sct.printv("  Reference file ........" + file_target, param.verbose)
    sct.printv("  Polynomial degree ....." + param.param[0], param.verbose)
    sct.printv("  Smoothing kernel ......" + param.param[1], param.verbose)
    sct.printv("  Gradient step ........." + param.param[2], param.verbose)
    sct.printv("  Metric ................" + param.param[3], param.verbose)
    sct.printv("  Todo .................." + todo, param.verbose)
    sct.printv("  Mask  ................." + param.fname_mask, param.verbose)
    sct.printv("  Output mat folder ....." + folder_mat, param.verbose)

    # create folder for mat files
    # NOTE(review): no statement visible here creates folder_mat.

    # Get size of data
    # (file_data and file_target are used as names without extension: ext is appended when files are opened)
    sct.printv("\nGet dimensions data...", verbose)
    data_im = Image(file_data + ext)
    nx, ny, nz, nt, px, py, pz, pt = data_im.dim
    sct.printv((".. " + str(nx) + " x " + str(ny) + " x " + str(nz) + " x " + str(nt)), verbose)

    # copy file_target to a temporary file
    # (the copy is needed because the target is overwritten when iterative averaging is enabled, see below)
    sct.printv("\nCopy file_target to a temporary file...", verbose)
    sct.run("cp " + file_target + ext + " target.nii")
    file_target = "target"

    # Split data along T dimension
    sct.printv("\nSplit data along T dimension...", verbose)
    data_split_list = split_data(data_im, dim=3)
    for im in data_split_list:
        # NOTE(review): loop body is missing (presumably the split volumes are written to disk here).
    file_data_splitT = file_data + "_T"

    # Motion correction: initialization
    index = np.arange(nt)
    file_data_splitT_num = []
    file_data_splitT_moco_num = []
    failed_transfo = [0 for i in range(nt)]  # 1 marks a volume whose registration failed
    file_mat = [[] for i in range(nt)]  # prefix of the transformation file of each volume

    # Motion correction: Loop across T
    for indice_index in range(nt):

        # create indices and display stuff
        it = index[indice_index]
        file_data_splitT_num.append(file_data_splitT + str(it).zfill(4))
        file_data_splitT_moco_num.append(file_data + suffix + "_T" + str(it).zfill(4))
        sct.printv(("\nVolume " + str((it)) + "/" + str(nt - 1) + ":"), verbose)
        file_mat[it] = folder_mat + "mat.T" + str(it)

        # run 3D registration
        # NOTE(review): the closing parenthesis of this call is missing.
        failed_transfo[it] = register(
            param, file_data_splitT_num[it], file_target, file_mat[it], file_data_splitT_moco_num[it]

        # average registered volume with target image
        # N.B. use weighted averaging: (target * nb_it + moco) / (nb_it + 1)
        # Only done for the first 10 volumes, and only when the registration succeeded. The target file is
        # updated in place in three steps: multiply, add the registered volume, divide.
        if param.iterative_averaging and indice_index < 10 and failed_transfo[it] == 0:
            sct.run("sct_maths -i " + file_target + ext + " -mul " + str(indice_index + 1) + " -o " + file_target + ext)
            # NOTE(review): the call opener and the closing parenthesis around the following command string
            # are missing.
                "sct_maths -i "
                + file_target
                + ext
                + " -add "
                + file_data_splitT_moco_num[it]
                + ext
                + " -o "
                + file_target
                + ext
            sct.run("sct_maths -i " + file_target + ext + " -div " + str(indice_index + 2) + " -o " + file_target + ext)

    # Replace failed transformation with the closest good one
    sct.printv(("\nReplace failed transformations..."), verbose)
    fT = [i for i, j in enumerate(failed_transfo) if j == 1]  # indices of failed registrations
    gT = [i for i, j in enumerate(failed_transfo) if j == 0]  # indices of successful registrations
    for it in range(len(fT)):
        # distance (in volumes) between the failed volume and every successfully registered one
        abs_dist = [abs(gT[i] - fT[it]) for i in range(len(gT))]
        if not abs_dist == []:
            index_good = abs_dist.index(min(abs_dist))
            sct.printv("  transfo #" + str(fT[it]) + " --> use transfo #" + str(gT[index_good]), verbose)
            # copy transformation
            sct.run("cp " + file_mat[gT[index_good]] + "Warp.nii.gz" + " " + file_mat[fT[it]] + "Warp.nii.gz")
            # apply transformation
            # NOTE(review): the call opener and the end of the call around the following command string are
            # missing.
                "sct_apply_transfo -i "
                + file_data_splitT_num[fT[it]]
                + ".nii -d "
                + file_target
                + ".nii -w "
                + file_mat[fT[it]]
                + "Warp.nii.gz"
                + " -o "
                + file_data_splitT_moco_num[fT[it]]
                + ".nii"
                + " -x "
                + param.interp,
            # NOTE(review): the `else:` branch and the call that prints the message below are missing.
            # exit program if no transformation exists.
                "\nERROR in " + os.path.basename(__file__) + ": No good transformation exist. Exit program.\n",

    # Merge data along T
    file_data_moco = file_data + suffix
    if todo != "estimate":
        sct.printv("\nMerge data back along T...", verbose)
        # cmd = fsloutput + 'fslmerge -t ' + file_data_moco
        # for indice_index in range(len(index)):
        #     cmd = cmd + ' ' + file_data_splitT_moco_num[indice_index]
        from sct_image import concat_data

        im_list = []
        for indice_index in range(len(index)):
            im_list.append(Image(file_data_splitT_moco_num[indice_index] + ext))
        im_out = concat_data(im_list, 3)
        # NOTE(review): only the output file name is set; no save of im_out is visible here.
        im_out.setFileName(file_data_moco + ext)
    def apply(self):
        """
        Apply the transformations listed in ``self.warp_input`` to ``self.input_filename``.

        The result is resampled in the space of ``self.fname_dest`` and written to ``self.output_filename``
        (``<source>_reg<ext>`` in the current directory when that attribute is empty). A transformation whose file
        name starts with '-' is passed to ``isct_antsApplyTransforms`` with the inversion flag, the '-' being
        removed from the name. 3D data are processed in a single call; 4D data are split along T in a temporary
        folder, each volume is transformed separately and the volumes are merged back. Depending on ``self.crop``
        (1: zero outside the warping field, 2: real crop) the output is finally cropped with ``ImageCropper``.

        ``self.warp_input`` is left untouched, so the method can be called more than once.

        :raises ValueError: if a NIfTI warping field does not carry the 'vector' intent code
        """
        # Initialization
        fname_src = self.input_filename  # source image (moving)
        fname_out = self.output_filename  # output
        fname_dest = self.fname_dest  # destination image (fix)
        verbose = self.verbose
        remove_temp_files = self.remove_temp_files
        crop_reference = self.crop  # if = 1, put 0 everywhere around warping field, if = 2, real crop

        interp = sct.get_interpolation('isct_antsApplyTransforms', self.interp)

        # Parse list of warping fields: one argument group per transformation
        sct.printv('\nParse list of warping fields...', verbose)
        fname_warp_list = []  # file names with the leading '-' removed
        warp_groups = []
        for path_warp in self.warp_input:
            # Check if inverse matrix is specified with '-' at the beginning of file name
            if path_warp.startswith("-"):
                path_warp = path_warp[1:]  # remove '-'
                # NOTE(review): '-i' is the inversion flag expected by isct_antsApplyTransforms -- confirm.
                warp_groups.append(['-i', path_warp])
            else:
                warp_groups.append([path_warp])
            fname_warp_list.append(path_warp)
            if path_warp.endswith((".nii", ".nii.gz")) \
             and msct_image.Image(path_warp).header.get_intent()[0] != 'vector':
                raise ValueError(("Displacement field in {} is invalid: should be encoded"
                                  " in a 5D file with vector intent code"
                                  " (see https://nifti.nimh.nih.gov/pub/dist/src/niftilib/nifti1.h"
                                  ).format(path_warp))

        # need to check if last warping field is an affine transfo
        ext_fname = sct.extract_fname(warp_groups[-1][-1])[2]
        isLastAffine = ext_fname in ['.txt', '.mat']

        # check if destination file is 3d
        if not sct.check_if_3d(fname_dest):
            sct.printv('ERROR: Destination data must be 3d')

        # N.B. Here we take the inverse of the warp list, because sct_WarpImageMultiTransform concatenates in the
        # reverse order (same ordering as the [::-1] used for the 4D case below). The groups are then flattened
        # into a single argument list.
        fname_warp_list_invert = [arg for group in reversed(warp_groups) for arg in group]

        # Extract path, file and extension
        path_src, file_src, ext_src = sct.extract_fname(fname_src)
        path_dest, file_dest, ext_dest = sct.extract_fname(fname_dest)

        # Get output folder and file name
        if fname_out == '':
            path_out = ''  # output in user's current directory
            file_out = file_src + '_reg'
            ext_out = ext_src
            fname_out = os.path.join(path_out, file_out + ext_out)

        # Get dimensions of data
        sct.printv('\nGet dimensions of data...', verbose)
        img_src = msct_image.Image(fname_src)
        nx, ny, nz, nt, px, py, pz, pt = img_src.dim
        sct.printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt), verbose)

        # if 3d
        if nt == 1:
            # Apply transformation
            sct.printv('\nApply transformation...', verbose)
            dim = '2' if nz in [0, 1] else '3'
            sct.run(['isct_antsApplyTransforms',
              '-d', dim,
              '-i', fname_src,
              '-o', fname_out,
              '-t',
             ] + fname_warp_list_invert + [
             '-r', fname_dest,
             ] + interp, verbose=verbose, is_sct_binary=True)

        # if 4d, loop across the T dimension
        else:
            path_tmp = sct.tmp_create(basename="apply_transfo", verbose=verbose)

            # convert to nifti into temp folder
            sct.printv('\nCopying input data to tmp folder and convert to nii...', verbose)
            img_src.save(os.path.join(path_tmp, "data.nii"))
            sct.copy(fname_dest, os.path.join(path_tmp, file_dest + ext_dest))
            fname_warp_list_tmp = []
            for fname_warp in fname_warp_list:
                path_warp, file_warp, ext_warp = sct.extract_fname(fname_warp)
                sct.copy(fname_warp, os.path.join(path_tmp, file_warp + ext_warp))
                fname_warp_list_tmp.append(file_warp + ext_warp)
            # NOTE(review): unlike the 3D case, this list carries no inversion flag, so '-'-prefixed
            # transformations are applied un-inverted to 4D data -- confirm whether this is intended.
            fname_warp_list_invert_tmp = fname_warp_list_tmp[::-1]

            path_out, name_out, ext_out = sct.extract_fname(fname_out)

            # work inside the temporary folder; always come back, even on failure
            curdir = os.getcwd()
            os.chdir(path_tmp)
            try:
                # split along T dimension
                sct.printv('\nSplit along T dimension...', verbose)
                im_dat = msct_image.Image('data.nii')
                im_header = im_dat.hdr
                data_split_list = sct_image.split_data(im_dat, 3)
                for im in data_split_list:
                    # NOTE(review): assumes split_data() names the volumes data_T0000.nii, ... -- confirm.
                    im.save()

                # apply transfo
                sct.printv('\nApply transformation to each 3D volume...', verbose)
                fname_list = []
                for it in range(nt):
                    file_data_split = 'data_T' + str(it).zfill(4) + '.nii'
                    file_data_split_reg = 'data_reg_T' + str(it).zfill(4) + '.nii'

                    sct.run(['isct_antsApplyTransforms',
                      '-d', '3',
                      '-i', file_data_split,
                      '-o', file_data_split_reg,
                      '-t',
                     ] + fname_warp_list_invert_tmp + [
                      '-r', file_dest + ext_dest,
                     ] + interp, verbose, is_sct_binary=True)
                    fname_list.append(file_data_split_reg)

                # Merge files back, in volume order (glob() gives no ordering guarantee).
                # concat_data takes a list of file names and opens the files one by one (see issue #715)
                sct.printv('\nMerge file back...', verbose)
                im_out = sct_image.concat_data(fname_list, 3, im_header['pixdim'])
                im_out.save(name_out + ext_out)
            finally:
                os.chdir(curdir)

            sct.generate_output_file(os.path.join(path_tmp, name_out + ext_out), fname_out)
            # Delete temporary folder if specified
            if int(remove_temp_files):
                sct.printv('\nRemove temporary files...', verbose)
                sct.rmtree(path_tmp, verbose=verbose)

        # 2. crop the resulting image using dimensions from the warping field
        warping_field = fname_warp_list_invert[-1]
        # if last warping field is an affine transfo, we need to compute the space of the concatenate warping field:
        if isLastAffine:
            sct.printv('WARNING: the resulting image could have wrong apparent results. You should use an affine transformation as last transformation...', verbose, 'warning')
        elif crop_reference == 1:
            ImageCropper(input_file=fname_out, output_file=fname_out, ref=warping_field, background=0).crop()
        elif crop_reference == 2:
            ImageCropper(input_file=fname_out, output_file=fname_out, ref=warping_field).crop()

        sct.display_viewer_syntax([fname_dest, fname_out], verbose=verbose)
def fmri_moco(param):
    # Group-wise motion correction of a 4D fMRI series.
    #
    # Visible flow: open the image at param.fname_data, decide sagittal vs. axial from the third
    # orientation letter, split the data along T, build groups of param.group_size consecutive
    # volumes, write one mean volume per group, merge the means into 'fmri_averaged_groups.nii',
    # run moco.moco() on that file with todo='estimate_and_apply', then copy the per-group
    # transformations into 'mat_final/' for every volume and run moco.moco() again with todo='apply'
    # on 'fmri.nii'. If param.output_motion_param is set, X/Y motion estimates are written to
    # 'fmri_moco_params.tsv'.
    #
    # :param param: parameter object. Fields read here: fname_data, verbose, group_size, num_target,
    #               output_motion_param. Fields written here: is_sagittal, group_size (may be forced
    #               to 1), file_data, file_target, path_out, todo, mat_moco.
    #
    # NOTE(review): many statements in this block look truncated (the line that opens or closes a
    # multi-line call is absent), so the block does not parse as written. Each spot is flagged
    # inline; the intended text should be restored from the upstream version of this function.

    file_data = "fmri.nii"
    mat_final = 'mat_final/'
    ext_mat = 'Warp.nii.gz'  # warping field

    # Get dimensions of data
    sct.printv('\nGet dimensions of data...', param.verbose)
    im_data = Image(param.fname_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    # NOTE(review): the opening of the call that prints the dimensions seems to be missing
    # (presumably sct.printv(...)) - confirm.
        '  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt),

    # Get orientation
    sct.printv('\nData orientation: ' + im_data.orientation, param.verbose)
    if im_data.orientation[2] in 'LR':
        param.is_sagittal = True
        sct.printv('  Treated as sagittal')
    elif im_data.orientation[2] in 'IS':
        param.is_sagittal = False
        sct.printv('  Treated as axial')
        # NOTE(review): an 'else:' branch and the start of a warning call appear to be missing
        # here; as written the two lines below belong to the 'IS' branch - confirm.
        param.is_sagittal = False
            'WARNING: Orientation seems to be neither axial nor sagittal.')

    # Adjust group size in case of sagittal scan
    if param.is_sagittal and param.group_size != 1:
        # NOTE(review): the opening of the warning call seems to be missing - confirm.
            'For sagittal data group_size should be one for more robustness. Forcing group_size=1.',
            1, 'warning')
        param.group_size = 1

    # Split into T dimension
    sct.printv('\nSplit along T dimension...', param.verbose)
    im_data_split_list = split_data(im_data, 3)
    for im in im_data_split_list:
        x_dirname, x_basename, x_ext = sct.extract_fname(im.absolutepath)
        # Make further steps slurp the data to avoid too many open files (#2149)
        # Only the path attribute is changed here; no save call is visible in this loop.
        im.absolutepath = os.path.join(x_dirname, x_basename + ".nii.gz")

    # assign an index to each volume
    index_fmri = list(range(0, nt))

    # Number of groups
    nb_groups = int(math.floor(nt / param.group_size))

    # Generate groups indexes
    group_indexes = []
    for iGroup in range(nb_groups):
        # NOTE(review): the end of this slice expression is missing (brackets are left open).
        group_indexes.append(index_fmri[(iGroup *
                                         param.group_size):((iGroup + 1) *

    # add the remaining images to the last fMRI group
    nb_remaining = nt % param.group_size  # number of remaining images
    if nb_remaining > 0:
        nb_groups += 1
        # NOTE(review): the end of this slice expression is missing (brackets are left open).
        group_indexes.append(index_fmri[len(index_fmri) -

    # groups
    # NOTE(review): the remaining tqdm arguments and the closing '):' are missing.
    for iGroup in tqdm(range(nb_groups),
                       desc="Merge within groups",
        # get index
        index_fmri_i = group_indexes[iGroup]
        nt_i = len(index_fmri_i)

        # Merge Images
        file_data_merge_i = sct.add_suffix(file_data, '_' + str(iGroup))
        # cmd = fsloutput + 'fslmerge -t ' + file_data_merge_i
        # for it in range(nt_i):
        #     cmd = cmd + ' ' + file_data + '_T' + str(index_fmri_i[it]).zfill(4)

        im_fmri_list = []
        # NOTE(review): this loop has no body and the concat_data call below is cut off;
        # presumably the loop filled im_fmri_list with the volumes of the group - confirm.
        for it in range(nt_i):
        im_fmri_concat = concat_data(im_fmri_list, 3,

        # Name of the per-group mean volume; '.nii' names are turned into '.nii.gz'.
        file_data_mean = sct.add_suffix(file_data, '_mean_' + str(iGroup))
        if file_data_mean.endswith(".nii"):
            file_data_mean += ".gz"  # #2149
        if param.group_size == 1:
            # copy to new file name instead of averaging (faster)
            # note: this is a bandage. Ideally we should skip this entire for loop if g=1
            convert(file_data_merge_i, file_data_mean)
            # NOTE(review): an 'else:' and the opening of the averaging call appear to be
            # missing before the argument list below - confirm.
            # Average Images
                'sct_maths', '-i', file_data_merge_i, '-o', file_data_mean,
                '-mean', 't'
        # if not average_data_across_dimension(file_data_merge_i+'.nii', file_data_mean+'.nii', 3):
        #     sct.printv('ERROR in average_data_across_dimension', 1, 'error')
        # cmd = fsloutput + 'fslmaths ' + file_data_merge_i + ' -Tmean ' + file_data_mean
        # sct.run(cmd, param.verbose)

    # Merge groups means. The output 4D volume will be used for motion correction.
    sct.printv('\nMerging volumes...', param.verbose)
    file_data_groups_means_merge = 'fmri_averaged_groups.nii'
    im_mean_list = []
    for iGroup in range(nb_groups):
        file_data_mean = sct.add_suffix(file_data, '_mean_' + str(iGroup))
        if file_data_mean.endswith(".nii"):
            file_data_mean += ".gz"  # #2149
    # NOTE(review): nothing visible appends to im_mean_list, and the call below is cut off.
    im_mean_concat = concat_data(im_mean_list,

    # Estimate moco
    sct.printv('  Estimating motion...', param.verbose)
    # param_moco is the same object as param (plain assignment, no copy), so the fields set
    # below are also visible through param.
    param_moco = param
    param_moco.file_data = 'fmri_averaged_groups.nii'
    param_moco.file_target = sct.add_suffix(file_data,
                                            '_mean_' + param.num_target)
    if param_moco.file_target.endswith(".nii"):
        param_moco.file_target += ".gz"  # #2149
    param_moco.path_out = ''
    param_moco.todo = 'estimate_and_apply'
    param_moco.mat_moco = 'mat_groups'
    file_mat = moco.moco(param_moco)

    # TODO: if g=1, no need to run the block below (already applied)
    if param.group_size == 1:
        # if flag g=1, it means that all images have already been corrected, so we just need to rename the file
        sct.mv('fmri_averaged_groups_moco.nii', 'fmri_moco.nii')
        # NOTE(review): an 'else:' and the statement creating the folder appear to be missing
        # after the comment below; as written the rest runs inside the g=1 branch - confirm.
        # create final mat folder

        # Copy registration matrices
        sct.printv('\nCopy transformations...', param.verbose)
        for iGroup in range(nb_groups):
            # NOTE(review): the argument of range() is missing.
            for data in range(
            ):  # we cannot use enumerate because group_indexes has 2 dim.
                # fetch all file_mat_z for given t-group
                list_file_mat_z = file_mat[:, iGroup]
                # loop across file_mat_z and copy to mat_final folder
                for file_mat_z in list_file_mat_z:
                    # we want to copy 'mat_groups/mat.ZXXXXTYYYYWarp.nii.gz' --> 'mat_final/mat.ZXXXXTYYYZWarp.nii.gz'
                    # Notice the Y->Z in the under the T index: the idea here is to use the single matrix from each group,
                    # and apply it to all images belonging to the same group.
                    # NOTE(review): the opening of the copy call is missing before these arguments.
                        file_mat_z + ext_mat,
                        mat_final + file_mat_z[11:20] + 'T' +
                        str(group_indexes[iGroup][data]).zfill(4) + ext_mat)

        # Apply moco on all fmri data
        sct.printv('  Apply moco', param.verbose)
        param_moco.file_data = 'fmri.nii'
        param_moco.file_target = sct.add_suffix(file_data, '_mean_' + str(0))
        if param_moco.file_target.endswith(".nii"):
            param_moco.file_target += ".gz"
        param_moco.path_out = ''
        param_moco.mat_moco = mat_final
        param_moco.todo = 'apply'
        file_mat = moco.moco(param_moco)

    # copy geometric information from header
    # NB: this is required because WarpImageMultiTransform in 2D mode wrongly sets pixdim(3) to "1".
    im_fmri = Image('fmri.nii')
    im_fmri_moco = Image('fmri_moco.nii')
    # NOTE(review): the header is assigned in memory only; no save call is visible afterwards.
    im_fmri_moco.header = im_fmri.header

    # Extract and output the motion parameters
    if param.output_motion_param:
        from sct_image import multicomponent_split
        import csv
        #files_warp = []
        files_warp_X, files_warp_Y = [], []
        moco_param = []
        # Iterates over the first row of file_mat; each entry is used as a warp file name prefix.
        for fname_warp in file_mat[0]:
            # Cropping the image to keep only one voxel in the XY plane
            im_warp = Image(fname_warp + ext_mat)
            # NOTE(review): the end of this expression is missing (second axis argument / ')').
            im_warp.data = np.expand_dims(np.expand_dims(
                im_warp.data[0, 0, :, :, :], axis=0),

            # These three lines allow to generate one file instead of two, containing X, Y and Z moco parameters
            #fname_warp_crop = fname_warp + '_crop_' + ext_mat

            # Separating the three components and saving X and Y only (Z is equal to 0 by default).
            im_warp_XYZ = multicomponent_split(im_warp)

            # NOTE(review): the two file names below are computed but never used in the visible
            # code; nothing appends to files_warp_X / files_warp_Y or moco_param - the save/append
            # statements appear to be missing.
            fname_warp_crop_X = fname_warp + '_crop_X_' + ext_mat

            fname_warp_crop_Y = fname_warp + '_crop_Y_' + ext_mat

            # Calculating the slice-wise average moco estimate to provide a QC file

        # These two lines allow to generate one file instead of two, containing X, Y and Z moco parameters
        #im_warp_concat = concat_data(files_warp, dim=3)

        # Concatenating the moco parameters into a time series for X and Y components.
        # NOTE(review): the X result is overwritten by the Y result without being saved in the
        # visible code - the save calls appear to be missing.
        im_warp_concat = concat_data(files_warp_X, dim=3)

        im_warp_concat = concat_data(files_warp_Y, dim=3)

        # Writing a TSV file with the slicewise average estimate of the moco parameters, as it is a useful QC file.
        with open('fmri_moco_params.tsv', 'wt') as out_file:
            tsv_writer = csv.writer(out_file, delimiter='\t')
            tsv_writer.writerow(['X', 'Y'])
            for mocop in moco_param:
                tsv_writer.writerow([mocop[0], mocop[1]])

    # Average volumes
    sct.printv('\nAveraging data...', param.verbose)
    # NOTE(review): the opening and the end of the averaging call are missing around this
    # argument list.
        '-i', 'fmri_moco.nii', '-o', 'fmri_moco_mean.nii', '-mean', 't', '-v',
def main(fname_anat, fname_centerline, degree_poly, centerline_fitting, interp, remove_temp_files, verbose):
    # Shift every z-slice of an anatomical volume along x so that the fitted centerline keeps
    # the x-position it has at the lowest slice containing centerline, and write the result as
    # '<input name>_flatten<ext>' in the current folder.
    #
    # :param fname_anat: path of the input volume
    # :param fname_centerline: path of the centerline (or segmentation) volume
    # :param degree_poly: not used in the visible code
    # :param centerline_fitting: 'nurbs' or 'polynome' (selects the fitting helper)
    # :param interp: interpolation name passed to flirt's '-interp'
    # :param remove_temp_files: if equal to 1, 'tmp.*' files are deleted at the end
    # :param verbose: not used in the visible code
    #
    # NOTE(review): Python 2 code (print statements, 'commands' module). It also reads names
    # that are not defined in this function ('param', 'fsloutput', 'numpy', 'move'); they are
    # presumably module-level - confirm. A few statements appear truncated, flagged inline.

    # extract path of the script
    path_script = os.path.dirname(__file__)+'/'
    # Parameters for debug mode
    if param.debug == 1:
        print '\n*** WARNING: DEBUG MODE ON ***\n'
        status, path_sct_data = commands.getstatusoutput('echo $SCT_TESTING_DATA_DIR')
        fname_anat = path_sct_data+'/t2/t2.nii.gz'
        fname_centerline = path_sct_data+'/t2/t2_seg.nii.gz'
    # extract path/file/extension
    path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)
    # Display arguments
    print '\nCheck input arguments...'
    print '  Input volume ...................... '+fname_anat
    print '  Centerline ........................ '+fname_centerline
    print ''
    # Get input image orientation
    im_anat = Image(fname_anat)
    input_image_orientation = get_orientation(im_anat)

    # Reorient input data into RL PA IS orientation
    im_centerline = Image(fname_centerline)
    im_anat_orient = set_orientation(im_anat, 'RPI')
    im_centerline_orient = set_orientation(im_centerline, 'RPI')

    # Open centerline
    print '\nGet dimensions of input centerline...'
    nx, ny, nz, nt, px, py, pz, pt = im_centerline_orient.dim
    print '.. matrix size: '+str(nx)+' x '+str(ny)+' x '+str(nz)
    print '.. voxel size:  '+str(px)+'mm x '+str(py)+'mm x '+str(pz)+'mm'
    print '\nOpen centerline volume...'
    data = im_centerline_orient.data

    # z-extent of the non-zero voxels; the centerline lists below are indexed by (iz - min_z_index)
    X, Y, Z = (data>0).nonzero()
    min_z_index, max_z_index = min(Z), max(Z)
    # loop across z and associate x,y coordinate with the point having maximum intensity
    x_centerline = [0 for iz in range(min_z_index, max_z_index+1, 1)]
    y_centerline = [0 for iz in range(min_z_index, max_z_index+1, 1)]
    z_centerline = [iz for iz in range(min_z_index, max_z_index+1, 1)]

    # Two possible scenario:
    # 1. the centerline is probabilistic: each slices contains voxels with the probability of containing the centerline [0:...:1]
    # We only take the maximum value of the image to aproximate the centerline.
    # 2. The centerline/segmentation image contains many pixels per slice with values {0,1}.
    # We take all the points and approximate the centerline on all these points.

    X, Y, Z = ((data<1)*(data>0)).nonzero() # X is empty if binary image
    if (len(X) > 0): # Scenario 1
        for iz in range(min_z_index, max_z_index+1, 1):
            x_centerline[iz-min_z_index], y_centerline[iz-min_z_index] = numpy.unravel_index(data[:,:,iz].argmax(), data[:,:,iz].shape)
    else: # Scenario 2
        for iz in range(min_z_index, max_z_index+1, 1):
            x_seg, y_seg = (data[:,:,iz]>0).nonzero()
            # slices without any non-zero voxel keep the initial (0, 0) entry
            if len(x_seg) > 0:
                x_centerline[iz-min_z_index] = numpy.mean(x_seg)
                y_centerline[iz-min_z_index] = numpy.mean(y_seg)

    # TODO: find a way to do the previous loop with this, which is more neat:
    # [numpy.unravel_index(data[:,:,iz].argmax(), data[:,:,iz].shape) for iz in range(0,nz,1)]
    # clear variable
    del data
    # Fit the centerline points with the kind of curve given as argument of the script and return the new smoothed coordinates
    # NOTE(review): the fitted lists are only assigned for 'nurbs' and 'polynome'; any other value
    # leaves x_centerline_fit undefined for the code further down.
    if centerline_fitting == 'nurbs':
        # NOTE(review): a 'try:' line appears to be missing before the statement below (the
        # 'except ValueError' has no matching 'try' as written) - confirm.
            x_centerline_fit, y_centerline_fit = b_spline_centerline(x_centerline,y_centerline,z_centerline)
        except ValueError:
            print "splines fitting doesn't work, trying with polynomial fitting...\n"
            x_centerline_fit, y_centerline_fit = polynome_centerline(x_centerline,y_centerline,z_centerline)
    elif centerline_fitting == 'polynome':
        x_centerline_fit, y_centerline_fit = polynome_centerline(x_centerline,y_centerline,z_centerline)

    # Split input volume
    print '\nSplit input volume...'
    im_anat_orient_split_list = split_data(im_anat_orient, 2)
    file_anat_split = []
    # NOTE(review): this loop has no body; presumably it filled file_anat_split (indexed by
    # slice in the flirt call further down) - confirm.
    for im in im_anat_orient_split_list:

    # initialize variables
    file_mat_inv_cumul = ['tmp.mat_inv_cumul_Z'+str(z).zfill(4) for z in range(0,nz,1)]
    z_init = min_z_index
    # displacement at the top of the centerline, reused for all slices above max_z_index
    displacement_max_z_index = x_centerline_fit[z_init-min_z_index]-x_centerline_fit[max_z_index-min_z_index]

    # write centerline as text file
    print '\nGenerate fitted transformation matrices...'
    file_mat_inv_cumul_fit = ['tmp.mat_inv_cumul_fit_Z'+str(z).zfill(4) for z in range(0,nz,1)]
    for iz in range(min_z_index, max_z_index+1, 1):
        # compute inverse cumulative fitted transformation matrix
        # NOTE(review): no fid.close() is visible for any of the files opened in these loops.
        fid = open(file_mat_inv_cumul_fit[iz], 'w')
        if (x_centerline[iz-min_z_index] == 0 and y_centerline[iz-min_z_index] == 0):
            displacement = 0
            # NOTE(review): an 'else:' appears to be missing here; as written the assignment
            # below overrides the 0 inside the 'if' and 'displacement' is unset otherwise.
            displacement = x_centerline_fit[z_init-min_z_index]-x_centerline_fit[iz-min_z_index]
        # 4x4 matrix written row by row; only the x-translation (row 1, last column) varies
        fid.write('%i %i %i %f\n' %(1, 0, 0, displacement) )
        fid.write('%i %i %i %f\n' %(0, 1, 0, 0) )
        fid.write('%i %i %i %i\n' %(0, 0, 1, 0) )
        fid.write('%i %i %i %i\n' %(0, 0, 0, 1) )

    # we complete the displacement matrix in z direction
    # slices below the centerline: identity matrix
    for iz in range(0, min_z_index, 1):
        fid = open(file_mat_inv_cumul_fit[iz], 'w')
        fid.write('%i %i %i %f\n' %(1, 0, 0, 0) )
        fid.write('%i %i %i %f\n' %(0, 1, 0, 0) )
        fid.write('%i %i %i %i\n' %(0, 0, 1, 0) )
        fid.write('%i %i %i %i\n' %(0, 0, 0, 1) )
    # slices above the centerline: same displacement as the last centerline slice
    for iz in range(max_z_index+1, nz, 1):
        fid = open(file_mat_inv_cumul_fit[iz], 'w')
        fid.write('%i %i %i %f\n' %(1, 0, 0, displacement_max_z_index) )
        fid.write('%i %i %i %f\n' %(0, 1, 0, 0) )
        fid.write('%i %i %i %i\n' %(0, 0, 1, 0) )
        fid.write('%i %i %i %i\n' %(0, 0, 0, 1) )

    # apply transformations to data
    print '\nApply fitted transformation matrices...'
    file_anat_split_fit = ['tmp.anat_orient_fit_Z'+str(z).zfill(4) for z in range(0,nz,1)]
    for iz in range(0, nz, 1):
        # forward cumulative transformation to data
        sct.run(fsloutput+'flirt -in '+file_anat_split[iz]+' -ref '+file_anat_split[iz]+' -applyxfm -init '+file_mat_inv_cumul_fit[iz]+' -out '+file_anat_split_fit[iz]+' -interp '+interp)

    # Merge into 4D volume
    print '\nMerge into 4D volume...'
    from glob import glob
    # NOTE(review): glob() does not promise any ordering; confirm the slices are concatenated
    # in z order.
    im_to_concat_list = [Image(fname) for fname in glob('tmp.anat_orient_fit_Z*.nii')]
    im_concat_out = concat_data(im_to_concat_list, 2)
    # sct.run(fsloutput+'fslmerge -z tmp.anat_orient_fit tmp.anat_orient_fit_z*')

    # Reorient data as it was before
    print '\nReorient data back into native orientation...'
    # NOTE(review): im_concat_out is not saved in the visible code before its path is used here.
    fname_anat_fit_orient = set_orientation(im_concat_out.absolutepath, input_image_orientation, filename=True)
    move(fname_anat_fit_orient, 'tmp.anat_orient_fit_reorient.nii')

    # Generate output file (in current folder)
    print '\nGenerate output file (in current folder)...'
    sct.generate_output_file('tmp.anat_orient_fit_reorient.nii', file_anat+'_flatten'+ext_anat)

    # Delete temporary files
    if remove_temp_files == 1:
        print '\nDelete temporary files...'
        sct.run('rm -rf tmp.*')

    # to view results
    print '\nDone! To view results, type:'
    print 'fslview '+file_anat+ext_anat+' '+file_anat+'_flatten'+ext_anat+' &\n'
def create_mask(param):
    # Build a mask volume around a centerline, slice by slice.
    #
    # The centerline comes from one of four methods selected by param.process[0]:
    # 'coord' (x/y given as 'XxY' text), 'point' (label file), 'center' (middle of the field of
    # view) or 'centerline' (centerline file). For each z-slice that contains centerline voxels a
    # 2D mask centred on the centre of mass is produced with create_mask2d(); other slices are
    # written unchanged from the centerline data. Slices are then concatenated along z.
    #
    # :param param: parameter object. Fields read here: process, fname_data, fname_out,
    #               file_prefix, verbose, remove_temp_files. param.fname_out is written when empty.
    #
    # NOTE(review): many statements in this block look truncated (call heads/tails and some
    # 'else:' lines are absent), so it does not parse as written. Flagged inline.

    # parse argument for method
    method_type = param.process[0]
    # check method val
    # 'center' is the only method without a second process argument
    if not method_type == 'center':
        method_val = param.process[1]

    # check existence of input files
    if method_type == 'centerline':
        sct.check_file_exist(method_val, param.verbose)

    # Extract path/file/extension
    path_data, file_data, ext_data = sct.extract_fname(param.fname_data)

    # Get output folder and file name
    if param.fname_out == '':
        # NOTE(review): the end of this expression is missing (parenthesis left open).
        param.fname_out = os.path.abspath(param.file_prefix + file_data +

    path_tmp = sct.tmp_create(basename="create_mask", verbose=param.verbose)

    sct.printv('\nOrientation:', param.verbose)
    orientation_input = Image(param.fname_data).orientation
    sct.printv('  ' + orientation_input, param.verbose)

    # copy input data to tmp folder and re-orient to RPI
    # NOTE(review): the openings of the three copy/re-orient calls are missing; only their
    # destination arguments remain.
        os.path.join(path_tmp, "data_RPI.nii"))
    if method_type == 'centerline':
            os.path.join(path_tmp, "centerline_RPI.nii"))
    if method_type == 'point':
            os.path.join(path_tmp, "point_RPI.nii"))

    # go to tmp folder
    # NOTE(review): the current directory is recorded but no chdir call is visible, although
    # the files below are opened by bare name - presumably a chdir(path_tmp) is missing.
    curdir = os.getcwd()

    # Get dimensions of data
    im_data = Image('data_RPI.nii')
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    sct.printv('\nDimensions:', param.verbose)
    sct.printv(im_data.dim, param.verbose)
    # in case user input 4d data
    if nt != 1:
        # NOTE(review): the opening of the warning call is missing before these arguments.
            'WARNING in ' + os.path.basename(__file__) +
            ': Input image is 4d but output mask will be 3D from first time slice.',
            param.verbose, 'warning')
        # extract first volume to have 3d reference
        nii = msct_image.empty_like(Image('data_RPI.nii'))
        data3d = nii.data[:, :, :, 0]
        # NOTE(review): nii is modified in memory only; no save call is visible.
        nii.data = data3d

    if method_type == 'coord':
        # parse to get coordinate
        coord = [x for x in map(int, method_val.split('x'))]

    if method_type == 'point':
        # get file name
        # extract coordinate of point
        sct.printv('\nExtract coordinate of point...', param.verbose)
        # TODO: change this way to remove dependence to sct.run. ProcessLabels.display_voxel returns list of coordinates
        # NOTE(review): the end of this call is missing (parenthesis left open).
        status, output = sct.run(
            ['sct_label_utils', '-i', 'point_RPI.nii', '-display'],
        # parse to get coordinate
        # TODO fixup... this is quite magic
        # Slices the text between 'Position=' (+10 characters) and the last 17 characters of
        # the output, then splits on ','; the entries stay strings.
        coord = output[output.find('Position=') + 10:-17].split(',')

    if method_type == 'center':
        # set coordinate at center of FOV
        coord = np.round(float(nx) / 2), np.round(float(ny) / 2)

    if method_type == 'centerline':
        # get name of centerline from user argument
        fname_centerline = 'centerline_RPI.nii'
        # NOTE(review): an 'else:' appears to be missing here; as written the line below runs
        # only for 'centerline', where 'coord' is never assigned - confirm.
        # generate volume with line along Z at coordinates 'coord'
        sct.printv('\nCreate line...', param.verbose)
        fname_centerline = create_line(param, 'data_RPI.nii', coord, nz)

    # create mask
    sct.printv('\nCreate mask...', param.verbose)
    centerline = nibabel.load(fname_centerline)  # open centerline
    hdr = centerline.get_header()  # get header
    hdr.set_data_dtype('uint8')  # set imagetype to uint8
    spacing = hdr.structarr['pixdim']
    data_centerline = centerline.get_data()  # get centerline
    # if data is 2D, reshape with empty third dimension
    if len(data_centerline.shape) == 2:
        data_centerline_shape = list(data_centerline.shape)
        # NOTE(review): the shape list is not extended in the visible code, so this reshape
        # keeps the data 2D; a statement adding the third dimension appears to be missing.
        data_centerline = data_centerline.reshape(data_centerline_shape)
    # NOTE(review): the closing ']' of this list comprehension is missing.
    z_centerline_not_null = [
        iz for iz in range(0, nz, 1) if data_centerline[:, :, iz].any()
    # get center of mass of the centerline
    cx = [0] * nz
    cy = [0] * nz
    for iz in range(0, nz, 1):
        if iz in z_centerline_not_null:
            cx[iz], cy[iz] = ndimage.measurements.center_of_mass(
                np.array(data_centerline[:, :, iz]))
    # create 2d masks
    file_mask = 'data_mask'
    for iz in range(nz):
        if iz not in z_centerline_not_null:
            # write an empty nifty volume
            img = nibabel.Nifti1Image(data_centerline[:, :, iz], None, hdr)
            nibabel.save(img, (file_mask + str(iz) + '.nii'))
            # NOTE(review): an 'else:' appears to be missing here, and the create_mask2d call
            # below is cut off after its first argument.
            center = np.array([cx[iz], cy[iz]])
            mask2d = create_mask2d(param,
            # Write NIFTI volumes
            img = nibabel.Nifti1Image(mask2d, None, hdr)
            nibabel.save(img, (file_mask + str(iz) + '.nii'))

    # Concatenate the per-slice files (passed as file names) along z and save the result.
    fname_list = [file_mask + str(iz) + '.nii' for iz in range(nz)]
    im_out = concat_data(fname_list, dim=2).save('mask_RPI.nii.gz')

    # NOTE(review): the header is assigned in memory only; the statements that re-orient and
    # write the mask to param.fname_out are not visible.
    im_out.header = Image(param.fname_data).header

    # come back
    # NOTE(review): no statement follows this comment; presumably a chdir(curdir) is missing.

    # Remove temporary files
    if param.remove_temp_files == 1:
        # NOTE(review): only the message is printed; no removal call is visible.
        sct.printv('\nRemove temporary files...', param.verbose)

    sct.display_viewer_syntax([param.fname_data, param.fname_out],
                              colormaps=['gray', 'red'],
                              opacities=['', '0.5'])
def fmri_moco(param):
    """Run group-wise motion correction on a 4D fMRI series.

    Steps:
      1. Split the input along T and build groups of ``param.group_size`` consecutive volumes
         (a shorter last group receives the remaining volumes).
      2. Write one mean volume per group and merge the means into 'fmri_averaged_groups.nii'.
      3. Estimate (and apply) motion correction on the merged means with ``moco.moco``.
      4. If the group size is 1 the result is simply renamed; otherwise each group's
         transformation is copied for every volume of the group and applied to 'fmri.nii'.
      5. Copy the header of 'fmri.nii' onto 'fmri_moco.nii' and compute the temporal mean.

    NOTE(review): the previous text of this function had lost several lines (two ``else:``
    branches, two loop bodies, the folder creation, the final ``moco.moco`` call and the save of
    the corrected image). They are restored here from the surrounding logic and from the sibling
    copy of this function in the same file; please confirm against upstream.

    :param param: parameter object. Reads fname_data, verbose, group_size and num_target (a
                  string). Writes is_sagittal, may force group_size to 1, and overwrites
                  file_data, file_target, path_out, todo and mat_moco (the object handed to
                  ``moco.moco`` is ``param`` itself, not a copy).
    """
    file_data = "fmri.nii"
    mat_final = 'mat_final/'
    ext_mat = 'Warp.nii.gz'  # warping field

    def _mean_fname(suffix):
        # File name of a group mean; '.nii' names become '.nii.gz' (#2149).
        fname = sct.add_suffix(file_data, '_mean_' + suffix)
        if fname.endswith(".nii"):
            fname += ".gz"
        return fname

    # Get dimensions of data
    sct.printv('\nGet dimensions of data...', param.verbose)
    im_data = Image(param.fname_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    sct.printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt), param.verbose)

    # Get orientation: the third orientation letter decides sagittal vs. axial
    sct.printv('\nData orientation: ' + im_data.orientation, param.verbose)
    if im_data.orientation[2] in 'LR':
        param.is_sagittal = True
        sct.printv('  Treated as sagittal')
    elif im_data.orientation[2] in 'IS':
        param.is_sagittal = False
        sct.printv('  Treated as axial')
    else:
        # Without this branch the warning was printed for axial data and is_sagittal was never
        # set for any other orientation.
        param.is_sagittal = False
        sct.printv('WARNING: Orientation seems to be neither axial nor sagittal.')

    # Adjust group size in case of sagittal scan
    if param.is_sagittal and param.group_size != 1:
        sct.printv('For sagittal data group_size should be one for more robustness. Forcing group_size=1.', 1, 'warning')
        param.group_size = 1

    # Split into T dimension
    sct.printv('\nSplit along T dimension...', param.verbose)
    im_data_split_list = split_data(im_data, 3)
    for im in im_data_split_list:
        x_dirname, x_basename, _ = sct.extract_fname(im.absolutepath)
        # Make further steps slurp the data to avoid too many open files (#2149)
        im.absolutepath = os.path.join(x_dirname, x_basename + ".nii.gz")

    # assign an index to each volume
    index_fmri = list(range(0, nt))

    # Number of complete groups
    nb_groups = int(math.floor(nt / param.group_size))

    # Generate groups indexes
    group_indexes = [
        index_fmri[(iGroup * param.group_size):((iGroup + 1) * param.group_size)]
        for iGroup in range(nb_groups)
    ]

    # add the remaining images as a last, shorter fMRI group
    nb_remaining = nt % param.group_size  # number of remaining images
    if nb_remaining > 0:
        nb_groups += 1
        group_indexes.append(index_fmri[len(index_fmri) - nb_remaining:])

    # Merge the volumes of each group and write the group mean
    for iGroup in tqdm(range(nb_groups), unit='iter', unit_scale=False, desc="Merge within groups", ascii=True, ncols=80):
        file_data_merge_i = sct.add_suffix(file_data, '_' + str(iGroup))

        # NOTE(review): restored loop body - collect the split volumes belonging to this group.
        im_fmri_list = [im_data_split_list[it] for it in group_indexes[iGroup]]
        concat_data(im_fmri_list, 3, squeeze_data=True).save(file_data_merge_i)

        file_data_mean = _mean_fname(str(iGroup))
        if param.group_size == 1:
            # copy to new file name instead of averaging (faster)
            # note: this is a bandage. Ideally we should skip this entire for loop if g=1
            convert(file_data_merge_i, file_data_mean)
        else:
            # Average Images
            sct.run(['sct_maths', '-i', file_data_merge_i, '-o', file_data_mean, '-mean', 't'], verbose=0)

    # Merge groups means. The output 4D volume will be used for motion correction.
    sct.printv('\nMerging volumes...', param.verbose)
    file_data_groups_means_merge = 'fmri_averaged_groups.nii'
    # NOTE(review): restored - the list was previously never filled before being concatenated.
    im_mean_list = [Image(_mean_fname(str(iGroup))) for iGroup in range(nb_groups)]
    concat_data(im_mean_list, 3).save(file_data_groups_means_merge)

    # Estimate moco
    sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
    sct.printv('  Estimating motion...', param.verbose)
    sct.printv('-------------------------------------------------------------------------------', param.verbose)
    param_moco = param  # alias, not a copy: the fields below are set on param itself
    param_moco.file_data = 'fmri_averaged_groups.nii'
    param_moco.file_target = _mean_fname(param.num_target)
    param_moco.path_out = ''
    param_moco.todo = 'estimate_and_apply'
    param_moco.mat_moco = 'mat_groups'
    file_mat = moco.moco(param_moco)

    # TODO: if g=1, no need to run the block below (already applied)
    if param.group_size == 1:
        # if flag g=1, it means that all images have already been corrected, so we just need to rename the file
        sct.mv('fmri_averaged_groups_moco.nii', 'fmri_moco.nii')
    else:
        # create final mat folder
        # NOTE(review): restored with the standard library; upstream may use a project helper.
        if not os.path.isdir(mat_final):
            os.makedirs(mat_final)

        # Copy registration matrices
        sct.printv('\nCopy transformations...', param.verbose)
        for iGroup in range(nb_groups):
            # all file_mat_z for the given t-group
            list_file_mat_z = file_mat[:, iGroup]
            for volume_index in group_indexes[iGroup]:
                for file_mat_z in list_file_mat_z:
                    # we want to copy 'mat_groups/mat.ZXXXXTYYYYWarp.nii.gz' --> 'mat_final/mat.ZXXXXTYYYZWarp.nii.gz'
                    # Notice the Y->Z in the under the T index: the idea here is to use the single matrix from each group,
                    # and apply it to all images belonging to the same group.
                    sct.copy(file_mat_z + ext_mat,
                             mat_final + file_mat_z[11:20] + 'T' + str(volume_index).zfill(4) + ext_mat)

        # Apply moco on all fmri data
        sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
        sct.printv('  Apply moco', param.verbose)
        sct.printv('-------------------------------------------------------------------------------', param.verbose)
        param_moco.file_data = 'fmri.nii'
        param_moco.file_target = _mean_fname(str(0))
        param_moco.path_out = ''
        param_moco.mat_moco = mat_final
        param_moco.todo = 'apply'
        # Restored: the 'apply' settings above were configured but never executed.
        file_mat = moco.moco(param_moco)

    # copy geometric information from header
    # NB: this is required because WarpImageMultiTransform in 2D mode wrongly sets pixdim(3) to "1".
    im_fmri = Image('fmri.nii')
    im_fmri_moco = Image('fmri_moco.nii')
    im_fmri_moco.header = im_fmri.header
    # Persist the corrected header; without this the assignment above had no effect on disk.
    im_fmri_moco.save('fmri_moco.nii')

    # Average volumes
    sct.printv('\nAveraging data...', param.verbose)
    sct_maths.main(args=['-i', 'fmri_moco.nii',
                         '-o', 'fmri_moco_mean.nii',
                         '-mean', 't',
                         '-v', '0'])
def create_mask():
    # Build a mask volume around a centerline, slice by slice (variant that takes no argument
    # and reads a 'param' object that is not defined in this function - presumably module-level).
    #
    # The centerline comes from one of four methods selected by param.process[0]:
    # 'coord', 'point', 'center' or 'centerline'. For each z-slice containing centerline voxels a
    # 2D mask centred on the centre of mass is produced with create_mask2d(); the slices are
    # concatenated along z, re-oriented to the input orientation and written to param.fname_out.
    #
    # Fields of 'param' read here: process, fname_data, fname_out, file_prefix, verbose, shape,
    # size, even, remove_tmp_files. param.fname_out is written when empty.
    #
    # NOTE(review): several statements appear to be missing (chdir calls, 'else:' branches, save
    # calls) and one section references undefined names. Flagged inline.

    fsloutput = 'export FSLOUTPUTTYPE=NIFTI; '  # for faster processing, all outputs are in NIFTI

    # parse argument for method
    method_type = param.process[0]
    # check method val
    # 'center' is the only method without a second process argument
    if not method_type == 'center':
        method_val = param.process[1]

    # check existence of input files
    if method_type == 'centerline':
        sct.check_file_exist(method_val, param.verbose)

    # Extract path/file/extension
    path_data, file_data, ext_data = sct.extract_fname(param.fname_data)

    # Get output folder and file name
    if param.fname_out == '':
        param.fname_out = param.file_prefix+file_data+ext_data

    # create temporary folder
    sct.printv('\nCreate temporary folder...', param.verbose)
    path_tmp = sct.tmp_create(param.verbose)
    # )sct.slash_at_the_end('tmp.'+time.strftime("%y%m%d%H%M%S"), 1)
    # sct.run('mkdir '+path_tmp, param.verbose)

    sct.printv('\nCheck orientation...', param.verbose)
    orientation_input = get_orientation(Image(param.fname_data))
    sct.printv('.. '+orientation_input, param.verbose)
    reorient_coordinates = False

    # copy input data to tmp folder
    # Paths are built by plain concatenation, so path_tmp is expected to end with a separator.
    convert(param.fname_data, path_tmp+'data.nii')
    if method_type == 'centerline':
        convert(method_val, path_tmp+'centerline.nii.gz')
    if method_type == 'point':
        convert(method_val, path_tmp+'point.nii.gz')

    # go to tmp folder
    # NOTE(review): no statement follows this comment, although the files below are addressed by
    # bare name - presumably a chdir(path_tmp) is missing.

    # reorient to RPI
    sct.printv('\nReorient to RPI...', param.verbose)
    # if not orientation_input == 'RPI':
    sct.run('sct_image -i data.nii -o data_RPI.nii -setorient RPI -v 0', verbose=False)
    if method_type == 'centerline':
        sct.run('sct_image -i centerline.nii.gz -o centerline_RPI.nii.gz -setorient RPI -v 0', verbose=False)
    if method_type == 'point':
        sct.run('sct_image -i point.nii.gz -o point_RPI.nii.gz -setorient RPI -v 0', verbose=False)
    # if method_type == 'centerline':
    #     orientation_centerline = get_orientation_3d(method_val, filename=True)
    #     if not orientation_centerline == 'RPI':
    #         sct.run('sct_image -i ' + method_val + ' -o ' + path_tmp + 'centerline.nii.gz' + ' -setorient RPI -v 0', verbose=False)
    #     else:
    #         convert(method_val, path_tmp+'centerline.nii.gz')

    # Get dimensions of data
    sct.printv('\nGet dimensions of data...', param.verbose)
    nx, ny, nz, nt, px, py, pz, pt = Image('data_RPI.nii').dim
    sct.printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz)+ ' x ' + str(nt), param.verbose)
    # in case user input 4d data
    if nt != 1:
        sct.printv('WARNING in '+os.path.basename(__file__)+': Input image is 4d but output mask will 3D.', param.verbose, 'warning')
        # extract first volume to have 3d reference
        nii = Image('data_RPI.nii')
        data3d = nii.data[:,:,:,0]
        # NOTE(review): nii is modified in memory only; no save call is visible.
        nii.data = data3d

    if method_type == 'coord':
        # parse to get coordinate
        coord = map(int, method_val.split('x'))

    if method_type == 'point':
        # get file name
        fname_point = method_val
        # extract coordinate of point
        sct.printv('\nExtract coordinate of point...', param.verbose)
        # TODO: change this way to remove dependence to sct.run. ProcessLabels.display_voxel returns list of coordinates
        status, output = sct.run('sct_label_utils -i point_RPI.nii.gz -display', param.verbose)
        # parse to get coordinate
        # Slices the text between 'Position=' (+10 characters) and the last 17 characters of
        # the output, then splits on ','; the entries stay strings.
        coord = output[output.find('Position=')+10:-17].split(',')

    if method_type == 'center':
        # set coordinate at center of FOV
        coord = round(float(nx)/2), round(float(ny)/2)

    if method_type == 'centerline':
        # get name of centerline from user argument
        fname_centerline = 'centerline_RPI.nii.gz'
        # NOTE(review): an 'else:' appears to be missing here; as written the line below runs
        # only for 'centerline', where 'coord' is never assigned - confirm.
        # generate volume with line along Z at coordinates 'coord'
        sct.printv('\nCreate line...', param.verbose)
        fname_centerline = create_line('data_RPI.nii', coord, nz)

    # create mask
    sct.printv('\nCreate mask...', param.verbose)
    centerline = nibabel.load(fname_centerline)  # open centerline
    hdr = centerline.get_header()  # get header
    hdr.set_data_dtype('uint8')  # set imagetype to uint8
    spacing = hdr.structarr['pixdim']
    data_centerline = centerline.get_data()  # get centerline
    # indexes of the z-slices that contain at least one non-zero centerline voxel
    z_centerline_not_null = [iz for iz in range(0, nz, 1) if data_centerline[:, :, iz].any()]
    # get center of mass of the centerline
    cx = [0] * nz
    cy = [0] * nz
    for iz in range(0, nz, 1):
        if iz in z_centerline_not_null:
            cx[iz], cy[iz] = ndimage.measurements.center_of_mass(numpy.array(data_centerline[:, :, iz]))
    # create 2d masks
    file_mask = 'data_mask'
    for iz in range(nz):
        if iz not in z_centerline_not_null:
            # write an empty nifty volume
            img = nibabel.Nifti1Image(data_centerline[:, :, iz], None, hdr)
            nibabel.save(img, (file_mask + str(iz) + '.nii'))
            # NOTE(review): an 'else:' appears to be missing here; as written the mask below is
            # computed only for empty slices and overwrites the file just saved - confirm.
            center = numpy.array([cx[iz], cy[iz]])
            mask2d = create_mask2d(center, param.shape, param.size, nx, ny, even=param.even, spacing=spacing)
            # Write NIFTI volumes
            img = nibabel.Nifti1Image(mask2d, None, hdr)
            nibabel.save(img, (file_mask+str(iz)+'.nii'))
    # merge along Z
    # cmd = 'fslmerge -z mask '

    # related to issue #755, we cannot open more than 256 files at one time.
    # to solve this issue, we do not open more than 100 files
    # NOTE(review): the section down to the second 'im_out =' assignment looks like a formerly
    # disabled alternative: 'nz_not_null' is not defined anywhere in this function, no 'else:'
    # lines are present, and its result is overwritten by the file-name based concat_data call
    # that follows - confirm and remove or restore.
    im_list = []
    im_temp = []
    for iz in range(nz_not_null):
        if iz != 0 and iz % 100 == 0:
            im_temp.append(concat_data(im_list, 2))
            im_list = [Image(file_mask + str(iz) + '.nii')]

    if im_temp:
        im_temp.append(concat_data(im_list, 2))
        im_out = concat_data(im_temp, 2, no_expand=True)
        im_out = concat_data(im_list, 2)
    fname_list = [file_mask + str(iz) + '.nii' for iz in range(nz)]
    im_out = concat_data(fname_list, dim=2)
    # NOTE(review): im_out is not saved in the visible code, yet 'mask_RPI.nii.gz' is read by
    # the command below - a save call appears to be missing.

    # reorient if necessary
    # if not orientation_input == 'RPI':
    sct.run('sct_image -i mask_RPI.nii.gz -o mask.nii.gz -setorient ' + orientation_input, param.verbose)

    # copy header input --> mask
    im_dat = Image('data.nii')
    im_mask = Image('mask.nii.gz')
    # NOTE(review): the returned image is not saved in the visible code.
    im_mask = copy_header(im_dat, im_mask)

    # come back to parent folder
    # NOTE(review): no statement follows this comment; presumably a chdir back is missing.

    # Generate output files
    sct.printv('\nGenerate output files...', param.verbose)
    sct.generate_output_file(path_tmp+'mask.nii.gz', param.fname_out)

    # Remove temporary files
    if param.remove_tmp_files == 1:
        sct.printv('\nRemove temporary files...', param.verbose)
        sct.run('rm -rf '+path_tmp, param.verbose, error_exit='warning')

    # to view results
    sct.printv('\nDone! To view results, type:', param.verbose)
    sct.printv('fslview '+param.fname_data+' '+param.fname_out+' -l Red -t 0.5 &', param.verbose, 'info')
def moco(param):
    """Motion-correct a 4D volume by registering each 3D volume to a target.

    The 4D input is split along T, every volume is registered to the target
    with ``register()``, volumes whose registration failed reuse the
    transformation of the closest successful volume, and (unless
    ``param.todo == 'estimate'``) the corrected volumes are merged back into
    ``<file_data><suffix>.nii``.

    :param param: parameter object. Attributes read here: ``file_data`` and
        ``file_target`` (file names without the '.nii' extension),
        ``mat_moco`` (output folder for the transformations), ``todo``,
        ``suffix``, ``verbose``, ``interp``, ``iterative_averaging``,
        ``fname_mask`` and ``param`` (sequence of four strings, printed as
        polynomial degree, smoothing kernel, gradient step and metric).
    :return: None. All results are written to disk relative to the current
        working directory.
    """
    # retrieve parameters
    file_data = param.file_data
    file_target = param.file_target
    folder_mat = sct.slash_at_the_end(param.mat_moco, 1)  # output folder of mat file
    todo = param.todo
    suffix = param.suffix
    verbose = param.verbose
    ext = '.nii'

    # print arguments
    sct.printv('\nInput parameters:', param.verbose)
    sct.printv('  Input file ............' + file_data, param.verbose)
    sct.printv('  Reference file ........' + file_target, param.verbose)
    sct.printv('  Polynomial degree .....' + param.param[0], param.verbose)
    sct.printv('  Smoothing kernel ......' + param.param[1], param.verbose)
    sct.printv('  Gradient step .........' + param.param[2], param.verbose)
    sct.printv('  Metric ................' + param.param[3], param.verbose)
    sct.printv('  Todo ..................' + todo, param.verbose)
    sct.printv('  Mask  .................' + param.fname_mask, param.verbose)
    sct.printv('  Output mat folder .....' + folder_mat, param.verbose)

    # create folder for mat files (the transformations are written into it below)
    if folder_mat and not os.path.isdir(folder_mat):
        os.makedirs(folder_mat)

    # Get size of data
    sct.printv('\nGet dimensions data...', verbose)
    data_im = Image(file_data + ext)
    nx, ny, nz, nt, px, py, pz, pt = data_im.dim
    sct.printv(('.. ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt)), verbose)

    # copy file_target to a temporary file
    sct.printv('\nCopy file_target to a temporary file...', verbose)
    sct.run('cp ' + file_target + ext + ' target.nii')
    file_target = 'target'

    # Split data along T dimension
    sct.printv('\nSplit data along T dimension...', verbose)
    from sct_image import split_data
    data_split_list = split_data(data_im, dim=3)
    # The registration below addresses each volume by file name
    # (<file_data>_T0000.nii, ...), so every split volume has to be on disk.
    # NOTE(review): assumes split_data() names its outputs <file_data>_T%04d,
    # matching the names built in the loop below -- confirm.
    for im in data_split_list:
        im.save()
    file_data_splitT = file_data + '_T'

    # Motion correction: initialization
    file_data_splitT_num = []
    file_data_splitT_moco_num = []
    failed_transfo = [0] * nt
    file_mat = [None] * nt

    # Motion correction: Loop across T
    for it in range(nt):

        # create indices and display stuff
        vol_tag = str(it).zfill(4)
        file_data_splitT_num.append(file_data_splitT + vol_tag)
        file_data_splitT_moco_num.append(file_data + suffix + '_T' + vol_tag)
        sct.printv(('\nVolume ' + str(it) + '/' + str(nt - 1) + ':'), verbose)
        file_mat[it] = folder_mat + 'mat.T' + str(it)

        # run 3D registration
        failed_transfo[it] = register(param, file_data_splitT_num[it], file_target, file_mat[it], file_data_splitT_moco_num[it])

        # average registered volume with target image
        # N.B. use weighted averaging: (target * nb_it + moco) / (nb_it + 1)
        # Only the first 10 volumes contribute to the running target.
        if param.iterative_averaging and it < 10 and failed_transfo[it] == 0:
            sct.run('sct_maths -i ' + file_target + ext + ' -mul ' + str(it + 1) + ' -o ' + file_target + ext)
            sct.run('sct_maths -i ' + file_target + ext + ' -add ' + file_data_splitT_moco_num[it] + ext + ' -o ' + file_target + ext)
            sct.run('sct_maths -i ' + file_target + ext + ' -div ' + str(it + 2) + ' -o ' + file_target + ext)

    # Replace failed transformation with the closest good one
    sct.printv(('\nReplace failed transformations...'), verbose)
    fT = [i for i, j in enumerate(failed_transfo) if j == 1]
    gT = [i for i, j in enumerate(failed_transfo) if j == 0]
    for i_failed in fT:
        if gT:
            # closest successful volume in time; ties resolve to the earliest one
            i_good = min(gT, key=lambda g: abs(g - i_failed))
            sct.printv('  transfo #' + str(i_failed) + ' --> use transfo #' + str(i_good), verbose)
            # copy transformation
            sct.run('cp ' + file_mat[i_good] + 'Warp.nii.gz' + ' ' + file_mat[i_failed] + 'Warp.nii.gz')
            # apply transformation
            sct.run('sct_apply_transfo -i ' + file_data_splitT_num[i_failed] + '.nii -d ' + file_target + '.nii -w ' + file_mat[i_failed] + 'Warp.nii.gz' + ' -o ' + file_data_splitT_moco_num[i_failed] + '.nii' + ' -x ' + param.interp, verbose)
        else:
            # exit program if no transformation exists.
            sct.printv('\nERROR in ' + os.path.basename(__file__) + ': No good transformation exist. Exit program.\n', verbose, 'error')

    # Merge data along T
    file_data_moco = file_data + suffix
    if todo != 'estimate':
        sct.printv('\nMerge data back along T...', verbose)
        from sct_image import concat_data
        fname_list = [fname + ext for fname in file_data_splitT_moco_num]
        im_out = concat_data(fname_list, 3)
        im_out.setFileName(file_data_moco + ext)
        # write the merged 4D volume; without this the file name set above is never used
        im_out.save()

    # delete file target.nii (to avoid conflict if this function is run another time)
    sct.printv('\nRemove temporary file...', verbose)
    sct.run('rm target.nii')
def create_mask():
    """Create a mask volume around a reference position and write it to param.fname_out.

    Takes no arguments: every setting is read from a `param` object that is not
    defined in this function (presumably a module-level global -- TODO confirm).
    Attributes read here: process, fname_data, fname_out, file_prefix, verbose,
    remove_tmp_files.

    param.process[0] selects the method ('coord', 'point', 'center' or
    'centerline'); param.process[1] carries its value for every method except
    'center'.

    Visible steps: copy the inputs into a temporary folder, reorient to RPI,
    derive a centerline (or a straight line at a fixed coordinate), write one
    2D mask per slice, concatenate the slices along Z, reorient back to the
    input orientation, copy the input header and generate the output file.

    NOTE(review): several statements in this function are truncated (unclosed
    calls, dangling string arguments, missing `else:` branches); each is marked
    below. The function does not parse as written and must be restored from a
    known-good revision before use.
    """
    # NOTE(review): fsloutput is not referenced anywhere else in this function.
    fsloutput = 'export FSLOUTPUTTYPE=NIFTI; '  # for faster processing, all outputs are in NIFTI

    # parse argument for method
    method_type = param.process[0]
    # check method val
    # method_val is only bound when the method is not 'center'
    if not method_type == 'center':
        method_val = param.process[1]

    # check existence of input files
    if method_type == 'centerline':
        sct.check_file_exist(method_val, param.verbose)

    # Extract path/file/extension
    path_data, file_data, ext_data = sct.extract_fname(param.fname_data)

    # Get output folder and file name
    # default output name: prefix + input file name, in the current directory
    if param.fname_out == '':
        param.fname_out = param.file_prefix + file_data + ext_data

    # create temporary folder
    sct.printv('\nCreate temporary folder...', param.verbose)
    path_tmp = sct.tmp_create(param.verbose)
    # )sct.slash_at_the_end('tmp.'+time.strftime("%y%m%d%H%M%S"), 1)
    # sct.run('mkdir '+path_tmp, param.verbose)

    sct.printv('\nCheck orientation...', param.verbose)
    # remembered so the final mask can be put back into the input orientation
    orientation_input = get_orientation(Image(param.fname_data))
    sct.printv('.. ' + orientation_input, param.verbose)
    # NOTE(review): reorient_coordinates is not read again in this function.
    reorient_coordinates = False

    # copy input data to tmp folder
    convert(param.fname_data, path_tmp + 'data.nii')
    if method_type == 'centerline':
        convert(method_val, path_tmp + 'centerline.nii.gz')
    if method_type == 'point':
        convert(method_val, path_tmp + 'point.nii.gz')

    # go to tmp folder
    # NOTE(review): no statement follows this comment, yet the code below opens
    # 'data.nii', 'data_RPI.nii', ... by relative name -- a change of working
    # directory into path_tmp appears to be missing.

    # reorient to RPI
    sct.printv('\nReorient to RPI...', param.verbose)
    # if not orientation_input == 'RPI':
    # NOTE(review): the sct.run( call below is never closed, and the two string
    # literals under the `if` lines look like arguments of further sct.run(
    # calls whose opening lines are missing.
    sct.run('sct_image -i data.nii -o data_RPI.nii -setorient RPI -v 0',
    if method_type == 'centerline':
            'sct_image -i centerline.nii.gz -o centerline_RPI.nii.gz -setorient RPI -v 0',
    if method_type == 'point':
            'sct_image -i point.nii.gz -o point_RPI.nii.gz -setorient RPI -v 0',
    # if method_type == 'centerline':
    #     orientation_centerline = get_orientation_3d(method_val, filename=True)
    #     if not orientation_centerline == 'RPI':
    #         sct.run('sct_image -i ' + method_val + ' -o ' + path_tmp + 'centerline.nii.gz' + ' -setorient RPI -v 0', verbose=False)
    #     else:
    #         convert(method_val, path_tmp+'centerline.nii.gz')

    # Get dimensions of data
    sct.printv('\nGet dimensions of data...', param.verbose)
    nx, ny, nz, nt, px, py, pz, pt = Image('data_RPI.nii').dim
    # NOTE(review): dangling argument list -- presumably the body of an
    # sct.printv( call whose opening line is missing.
        '  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt),
    # in case user input 4d data
    if nt != 1:
        # NOTE(review): dangling warning text -- the call that emits it is missing.
            'WARNING in ' + os.path.basename(__file__) +
            ': Input image is 4d but output mask will 3D.', param.verbose,
        # extract first volume to have 3d reference
        nii = Image('data_RPI.nii')
        data3d = nii.data[:, :, :, 0]
        nii.data = data3d
        # NOTE(review): the 3D data is assigned but no save of `nii` is visible.

    if method_type == 'coord':
        # parse to get coordinate
        # method_val is split on 'x' and each part converted to int
        coord = map(int, method_val.split('x'))

    if method_type == 'point':
        # get file name
        fname_point = method_val
        # extract coordinate of point
        sct.printv('\nExtract coordinate of point...', param.verbose)
        # TODO: change this way to remove dependence to sct.run. ProcessLabels.display_voxel returns list of coordinates
        status, output = sct.run(
            'sct_label_utils -i point_RPI.nii.gz -display', param.verbose)
        # parse to get coordinate
        # relies on fixed character offsets around 'Position=' in the tool output;
        # the resulting items are strings
        coord = output[output.find('Position=') + 10:-17].split(',')

    if method_type == 'center':
        # set coordinate at center of FOV
        coord = round(float(nx) / 2), round(float(ny) / 2)

    if method_type == 'centerline':
        # get name of centerline from user argument
        fname_centerline = 'centerline_RPI.nii.gz'
        # NOTE(review): as written, the 'centerline' branch immediately
        # overwrites fname_centerline with create_line(...) using `coord`,
        # which no branch above sets for this method; the other methods never
        # set fname_centerline at all. An `else:` before the lines below
        # appears to be missing.
        # generate volume with line along Z at coordinates 'coord'
        sct.printv('\nCreate line...', param.verbose)
        fname_centerline = create_line('data_RPI.nii', coord, nz)

    # create mask
    sct.printv('\nCreate mask...', param.verbose)
    centerline = nibabel.load(fname_centerline)  # open centerline
    hdr = centerline.get_header()  # get header
    hdr.set_data_dtype('uint8')  # set imagetype to uint8
    spacing = hdr.structarr['pixdim']
    data_centerline = centerline.get_data()  # get centerline
    # if data is 2D, reshape with empty third dimension
    # NOTE(review): the shape list is passed to reshape() unchanged, so this is
    # a no-op as written; the step adding the third dimension appears missing.
    if len(data_centerline.shape) == 2:
        data_centerline_shape = list(data_centerline.shape)
        data_centerline = data_centerline.reshape(data_centerline_shape)
    # slices in which the centerline volume has at least one non-zero voxel
    # NOTE(review): the closing bracket of this list comprehension is missing.
    z_centerline_not_null = [
        iz for iz in range(0, nz, 1) if data_centerline[:, :, iz].any()
    # get center of mass of the centerline
    # slices without centerline keep the default (0, 0)
    cx = [0] * nz
    cy = [0] * nz
    for iz in range(0, nz, 1):
        if iz in z_centerline_not_null:
            cx[iz], cy[iz] = ndimage.measurements.center_of_mass(
                numpy.array(data_centerline[:, :, iz]))
    # create 2d masks
    # one file per slice, named data_mask<iz>.nii
    file_mask = 'data_mask'
    for iz in range(nz):
        if iz not in z_centerline_not_null:
            # write an empty nifty volume
            img = nibabel.Nifti1Image(data_centerline[:, :, iz], None, hdr)
            nibabel.save(img, (file_mask + str(iz) + '.nii'))
            # NOTE(review): an `else:` appears to be missing here -- as
            # indented, the mask below is computed only for slices WITHOUT
            # centerline and overwrites the empty volume just written. The
            # create_mask2d( call is also truncated after its first argument.
            center = numpy.array([cx[iz], cy[iz]])
            mask2d = create_mask2d(center,
            # Write NIFTI volumes
            img = nibabel.Nifti1Image(mask2d, None, hdr)
            nibabel.save(img, (file_mask + str(iz) + '.nii'))
    # merge along Z
    # cmd = 'fslmerge -z mask '

    # related to issue #755, we cannot open more than 256 files at one time.
    # to solve this issue, we do not open more than 100 files
    # NOTE(review): nz_not_null is not defined in this function, im_list only
    # ever receives every 100th slice, and the branch below assigns im_out
    # twice (an `else:` appears to be missing). Whatever this section computes
    # is overwritten by the file-name based concat_data call that follows it.
    im_list = []
    im_temp = []
    for iz in range(nz_not_null):
        if iz != 0 and iz % 100 == 0:
            im_temp.append(concat_data(im_list, 2))
            im_list = [Image(file_mask + str(iz) + '.nii')]

    if im_temp:
        im_temp.append(concat_data(im_list, 2))
        im_out = concat_data(im_temp, 2, no_expand=True)
        im_out = concat_data(im_list, 2)
    # concatenate all per-slice files along dimension 2
    fname_list = [file_mask + str(iz) + '.nii' for iz in range(nz)]
    im_out = concat_data(fname_list, dim=2)
    # NOTE(review): im_out is not written to disk in the visible code, although
    # 'mask_RPI.nii.gz' is used as an input just below.

    # reorient if necessary
    # if not orientation_input == 'RPI':
    # NOTE(review): dangling arguments -- the opening of the call that runs this
    # command is missing.
        'sct_image -i mask_RPI.nii.gz -o mask.nii.gz -setorient ' +
        orientation_input, param.verbose)

    # copy header input --> mask
    im_dat = Image('data.nii')
    im_mask = Image('mask.nii.gz')
    im_mask = copy_header(im_dat, im_mask)
    # NOTE(review): im_mask is not saved after the header copy in the visible code.

    # come back to parent folder
    # NOTE(review): no statement follows this comment; the output path below is
    # built as path_tmp + 'mask.nii.gz', i.e. relative to the parent folder.

    # Generate output files
    sct.printv('\nGenerate output files...', param.verbose)
    sct.generate_output_file(path_tmp + 'mask.nii.gz', param.fname_out)

    # Remove temporary files
    if param.remove_tmp_files == 1:
        sct.printv('\nRemove temporary files...', param.verbose)
        sct.run('rm -rf ' + path_tmp, param.verbose, error_exit='warning')

    # to view results
    sct.printv('\nDone! To view results, type:', param.verbose)
    # NOTE(review): dangling arguments -- the opening of the call that prints
    # this command is missing.
        'fslview ' + param.fname_data + ' ' + param.fname_out +
        ' -l Red -t 0.5 &', param.verbose, 'info')
    def apply(self):
        """Apply a list of warping fields to the source image.

        Instance attributes read: input_filename (source/moving image),
        warp_input (list of warping-field file names; a leading '-' requests
        the inverse), output_filename, fname_dest (destination/fixed image),
        verbose, remove_temp_files, crop and interp.

        The transformation is run through `isct_antsApplyTransforms`. The
        result is then optionally cropped with ImageCropper against the last
        warping field (crop == 1: background set to 0, crop == 2: real crop).

        NOTE(review): several statements appear to be missing from this method
        (marked below); it does not parse as written and must be restored from
        a known-good revision before use.
        """
        # Initialization
        fname_src = self.input_filename  # source image (moving)
        fname_warp_list = self.warp_input  # list of warping fields
        fname_out = self.output_filename  # output
        fname_dest = self.fname_dest  # destination image (fix)
        verbose = self.verbose
        remove_temp_files = self.remove_temp_files
        crop_reference = self.crop  # if = 1, put 0 everywhere around warping field, if = 2, real crop

        # translate the requested interpolation into the option string appended to the command
        interp = sct.get_interpolation('isct_antsApplyTransforms', self.interp)

        # Parse list of warping fields
        sct.printv('\nParse list of warping fields...', verbose)
        use_inverse = []
        fname_warp_list_invert = []
        # fname_warp_list = fname_warp_list.replace(' ', '')  # remove spaces
        # fname_warp_list = fname_warp_list.split(",")  # parse with comma
        # NOTE(review): use_inverse only receives an entry for names starting
        # with '-', so use_inverse[i] raises IndexError for any other name; and
        # fname_warp_list_invert is never filled in the visible code, although
        # it is indexed and joined further down.
        for i in range(len(fname_warp_list)):
            # Check if inverse matrix is specified with '-' at the beginning of file name
            if fname_warp_list[i].find('-') == 0:
                use_inverse.append('-i ')
                fname_warp_list[i] = fname_warp_list[i][1:]  # remove '-'
            sct.printv('  Transfo #'+str(i)+': '+use_inverse[i]+fname_warp_list[i], verbose)

        # need to check if last warping field is an affine transfo
        # affine transformations are recognised by their '.txt' / '.mat' extension
        isLastAffine = False
        path_fname, file_fname, ext_fname = sct.extract_fname(fname_warp_list_invert[-1])
        if ext_fname in ['.txt', '.mat']:
            isLastAffine = True

        # Check file existence
        sct.printv('\nCheck file existence...', verbose)
        sct.check_file_exist(fname_src, self.verbose)
        sct.check_file_exist(fname_dest, self.verbose)
        for i in range(len(fname_warp_list)):
            # check if file exist
            sct.check_file_exist(fname_warp_list[i], self.verbose)

        # check if destination file is 3d
        # NOTE(review): the message is printed without a message type and
        # nothing here stops execution -- confirm whether that is intended.
        if not sct.check_if_3d(fname_dest):
            sct.printv('ERROR: Destination data must be 3d')

        # N.B. Here we take the inverse of the warp list, because sct_WarpImageMultiTransform concatenates in the reverse order
        # NOTE(review): no statement follows the remark above; the reversal it
        # describes is not present in the visible code.

        # Extract path, file and extension
        path_src, file_src, ext_src = sct.extract_fname(fname_src)
        path_dest, file_dest, ext_dest = sct.extract_fname(fname_dest)

        # Get output folder and file name
        # default output: <source name>_reg<ext> in the current directory
        if fname_out == '':
            path_out = ''  # output in user's current directory
            file_out = file_src+'_reg'
            ext_out = ext_src
            fname_out = path_out+file_out+ext_out

        # Get dimensions of data
        sct.printv('\nGet dimensions of data...', verbose)
        from msct_image import Image
        nx, ny, nz, nt, px, py, pz, pt = Image(fname_src).dim
        # nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(fname_src)
        sct.printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz)+ ' x ' + str(nt), verbose)

        # if 3d
        if nt == 1:
            # Apply transformation
            sct.printv('\nApply transformation...', verbose)
            sct.run('isct_antsApplyTransforms -d 3 -i '+fname_src+' -o '+fname_out+' -t '+' '.join(fname_warp_list_invert)+' -r '+fname_dest+interp, verbose)

        # if 4d, loop across the T dimension
            # NOTE(review): an `else:` appears to be missing above -- as
            # indented, the whole 4D procedure below belongs to the
            # `if nt == 1:` body.
            # create temporary folder
            sct.printv('\nCreate temporary folder...', verbose)
            path_tmp = sct.slash_at_the_end('tmp.'+time.strftime("%y%m%d%H%M%S"), 1)
            # sct.run('mkdir '+path_tmp, verbose)
            sct.run('mkdir '+path_tmp, verbose)

            # convert to nifti into temp folder
            sct.printv('\nCopying input data to tmp folder and convert to nii...', verbose)
            from sct_convert import convert
            convert(fname_src, path_tmp+'data.nii')
            sct.run('cp '+fname_dest+' '+path_tmp+file_dest+ext_dest)
            # NOTE(review): fname_warp_list_tmp is never appended to in the
            # visible code, so the reversed list built from it is empty.
            fname_warp_list_tmp = []
            for fname_warp in fname_warp_list:
                path_warp, file_warp, ext_warp = sct.extract_fname(fname_warp)
                sct.run('cp '+fname_warp+' '+path_tmp+file_warp+ext_warp)
            fname_warp_list_invert_tmp = fname_warp_list_tmp[::-1]

            # split along T dimension
            # NOTE(review): 'data.nii' and the per-volume files below are
            # opened by relative name, but no change of working directory into
            # path_tmp is visible; the `for` loop below also has no body.
            sct.printv('\nSplit along T dimension...', verbose)
            from sct_image import split_data
            im_dat = Image('data.nii')
            data_split_list = split_data(im_dat, 3)
            for im in data_split_list:

            # apply transfo
            # one isct_antsApplyTransforms call per volume: data_T%04d.nii -> data_reg_T%04d.nii
            sct.printv('\nApply transformation to each 3D volume...', verbose)
            for it in range(nt):
                file_data_split = 'data_T'+str(it).zfill(4)+'.nii'
                file_data_split_reg = 'data_reg_T'+str(it).zfill(4)+'.nii'
                sct.run('isct_antsApplyTransforms -d 3 -i '+file_data_split+' -o '+file_data_split_reg+' -t '+' '.join(fname_warp_list_invert_tmp)+' -r '+file_dest+ext_dest+interp, verbose)

            # Merge files back
            sct.printv('\nMerge file back...', verbose)
            from sct_image import concat_data
            import glob
            path_out, name_out, ext_out = sct.extract_fname(fname_out)
            # NOTE(review): glob.glob() returns names in arbitrary order, so
            # the volume order of the merge is not guaranteed; and im_out is
            # neither named nor saved before generate_output_file() is asked to
            # pick up path_tmp+name_out+ext_out.
            im_list = [Image(file_name) for file_name in glob.glob('data_reg_T*.nii')]
            im_out = concat_data(im_list, 3)
            sct.generate_output_file(path_tmp+name_out+ext_out, fname_out)
            # Delete temporary folder if specified
            if int(remove_temp_files):
                sct.printv('\nRemove temporary files...', verbose)
                sct.run('rm -rf '+path_tmp, verbose, error_exit='warning')

        # 2. crop the resulting image using dimensions from the warping field
        warping_field = fname_warp_list_invert[-1]
        # if last warping field is an affine transfo, we need to compute the space of the concatenate warping field:
        # when the last transformation is affine no cropping is done; only a warning is printed
        if isLastAffine:
            sct.printv('WARNING: the resulting image could have wrong apparent results. You should use an affine transformation as last transformation...',verbose,'warning')
        elif crop_reference == 1:
            ImageCropper(input_file=fname_out, output_file=fname_out, ref=warping_field, background=0).crop()
            # sct.run('sct_crop_image -i '+fname_out+' -o '+fname_out+' -ref '+warping_field+' -b 0')
        elif crop_reference == 2:
            ImageCropper(input_file=fname_out, output_file=fname_out, ref=warping_field).crop()
            # sct.run('sct_crop_image -i '+fname_out+' -o '+fname_out+' -ref '+warping_field)

        # display elapsed time
        sct.printv('\nDone! To view results, type:', verbose)
        sct.printv('fslview '+fname_dest+' '+fname_out+' &\n', verbose, 'info')