# The constructor
    def __init__(self):
        """Initialise default parameters (flags, template location, QC folder)."""
        # Run-time behaviour flags
        self.debug = 0
        self.verbose = 1
        self.remove_temp_files = 1  # remove temporary files when set
        # Fields read by register() in sct_register_multimodal
        self.fname_mask = ''
        self.padding = 10
        # Template location, and QC folder resolved against the cwd at construction time
        self.path_template = path_sct + '/data/PAM50'
        cwd_abs = os.path.abspath(os.curdir)
        self.path_qc = cwd_abs + '/qc/'
        # Remaining options (kept as strings)
        self.zsubsample = '0.25'
        self.param_straighten = ''


# Default multi-step registration parameters.
# step0 is the pre-registration, driven by labels. When ref=template the cord is already
# straight, so only translations and z-scaling are estimated (dof='Tx_Ty_Tz_Sz').
step0 = Paramreg(step='0',
                 type='label',
                 dof='Tx_Ty_Tz_Sz')
# step1: segmentation-based, 'centermassrot' algorithm
step1 = Paramreg(step='1',
                 type='seg',
                 algo='centermassrot',
                 smooth='2')
# step2: segmentation-based 'bsplinesyn' with a MeanSquares metric.
# Commented-out alternatives:
#   step2 = Paramreg(step='2', type='seg', algo='columnwise', smooth='0', smoothWarpXY='2')
#   step3 = Paramreg(step='3', type='im', algo='syn', metric='CC', iter='1')
step2 = Paramreg(step='2',
                 type='seg',
                 algo='bsplinesyn',
                 metric='MeanSquares',
                 iter='3',
                 smooth='1')
paramreg = ParamregMultiStep([step0, step1, step2])


# PARSER
# ==========================================================================================
def get_parser():
    param = Param()
    parser = Parser(__file__)
    parser.usage.set_description('Register anatomical image to the template.\n\n'
      'To register a subject to the template, try the default command:\n'
      'sct_register_to_template -i data.nii.gz -s data_seg.nii.gz -l data_labels.nii.gz\n'
# theta = [0.57]#[1.57079] #10 degres
# x_ = [0]
# y_ = [0]
# generate_warping_field(im_d_3, x_, y_, theta, center_rotation=None, fname='warping_field_15transx.nii.gz')
# sct.run('sct_apply_transfo -i '+im_d_3+' -d '+im_d_3+' -w warping_field_15transx.nii.gz -o ' + im_d_5 + ' -x nn')

# Rigid slice-by-slice registration of im_T1 (moving image) onto im_T2 (fixed image).
# NOTE(review): im_T1 / im_T2 are not defined in this fragment; presumably set earlier in the script.
im_i = im_T1
im_d = im_T2
# Length of the smoothing window applied below to the per-slice parameters
# (presumably a number of slices -- TODO confirm against smoothing_window).
window_size = 31
# With algo='Rigid', register_images is expected to return three per-slice lists:
# x displacement, y displacement and rotation angle theta.
x_disp, y_disp, theta = register_images(im_i,
                                        im_d,
                                        paramreg=Paramreg(step='0',
                                                          type='im',
                                                          algo='Rigid',
                                                          metric='MI',
                                                          iter='100',
                                                          shrink='1',
                                                          smooth='0',
                                                          gradStep='3'),
                                        remove_tmp_folder=1)

# Convert the per-slice lists to arrays before smoothing
x_disp_a = asarray(x_disp)
y_disp_a = asarray(y_disp)
theta_a = asarray(theta)

# Smooth the x displacement across slices with a Hanning window
x_disp_smooth = smoothing_window(x_disp_a,
                                 window_len=window_size,
                                 window='hanning',
                                 verbose=2)
y_disp_smooth = smoothing_window(y_disp_a,
                                 window_len=window_size,
def register_images(im_input,
                    im_dest,
                    mask='',
                    paramreg=Paramreg(step='0',
                                      type='im',
                                      algo='Translation',
                                      metric='MI',
                                      iter='5',
                                      shrink='1',
                                      smooth='0',
                                      gradStep='0.5'),
                    remove_tmp_folder=1):
    """Slice-by-slice 2D registration of im_input (moving) onto im_dest (fixed).

    Both volumes (and the optional mask) are copied into a time-stamped temporary
    folder and split along z with fslsplit. Each destination slice i is then
    registered with isct_antsRegistration against input slice i + z_o, where z_o
    is derived from the difference between the z coordinates of the two image
    origins.

    input:
        im_input: file name of the moving image (type: string)
        im_dest: file name of the fixed image; its nz gives the number of slices (type: string)
        mask[optional]: file name of a mask, passed slice-wise to antsRegistration with -x (type: string)
        paramreg[optional]: registration parameters (type: Paramreg)
        remove_tmp_folder[optional]: if true, the temporary folder is deleted after a successful run

    output (one entry per slice of im_dest):
        algo == 'Rigid':       x_displacement, y_displacement, theta_rotation (radians)
        algo == 'Translation': x_displacement, y_displacement
        algo == 'Affine':      x_displacement, y_displacement, matrix_def (2x2 lists)
        any other algo:        None

    If registration of a slice fails, that slice reuses the values of slice i - 1
    (for slice 0 this index wraps to the last entry, which is still 0 at that point).
    The caller's working directory is restored even if an error occurs.
    """
    path_i, root_i, ext_i = sct.extract_fname(im_input)
    path_d, root_d, ext_d = sct.extract_fname(im_dest)
    # Names of the local copies made inside the temporary folder. Every file access
    # after os.chdir() must use these: the caller's paths may be relative to the
    # original working directory and would no longer resolve.
    fname_input_local = root_i + ext_i
    fname_dest_local = root_d + ext_d

    # set metricSize
    if paramreg.metric == 'MI':
        metricSize = '32'  # corresponds to number of bins
    else:
        metricSize = '4'  # corresponds to radius (for CC, MeanSquares...)

    # algorithm-specific suffix appended inside antsRegistration's --transform option
    ants_registration_params = {
        'rigid': '',
        'affine': '',
        'compositeaffine': '',
        'similarity': '',
        'translation': '',
        'bspline': ',10',
        'gaussiandisplacementfield': ',3,0',
        'bsplinedisplacementfield': ',5,10',
        'syn': ',3,0',
        'bsplinesyn': ',3,32'
    }

    # Get image dimensions and retrieve nz
    print('\nGet image dimensions of destination image...')
    nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(im_dest)
    print('.. matrix size: ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz))
    print('.. voxel size:  ' + str(px) + 'mm x ' + str(py) + 'mm x ' + str(pz) + 'mm')

    # One entry per destination slice
    x_displacement = [0] * nz
    y_displacement = [0] * nz
    theta_rotation = [0] * nz
    matrix_def = [0] * nz

    # create temporary folder and copy the inputs into it
    print('\nCreate temporary folder...')
    path_tmp = 'tmp.' + time.strftime("%y%m%d%H%M%S")
    sct.create_folder(path_tmp)
    print('\nCopy input data...')
    sct.run('cp ' + im_input + ' ' + path_tmp + '/' + fname_input_local)
    sct.run('cp ' + im_dest + ' ' + path_tmp + '/' + fname_dest_local)
    if mask:
        sct.run('cp ' + mask + ' ' + path_tmp + '/mask.nii.gz')

    # Work inside the temporary folder; always come back to the caller's directory.
    path_orig = os.getcwd()
    os.chdir(path_tmp)
    try:
        # Split input, destination (and mask) volumes along z
        print('\nSplit input volume...')
        sct.run(sct.fsloutput + 'fslsplit ' + fname_input_local + ' ' + root_i + '_z -z')
        print('\nSplit destination volume...')
        sct.run(sct.fsloutput + 'fslsplit ' + fname_dest_local + ' ' + root_d + '_z -z')
        if mask:
            print('\nSplit mask volume...')
            sct.run(sct.fsloutput + 'fslsplit mask.nii.gz mask_z -z')

        # z_o is added to each destination slice index to select the paired input slice
        im_dest_img = Image(fname_dest_local)
        im_input_img = Image(fname_input_local)
        coord_origin_dest = im_dest_img.transfo_pix2phys([[0, 0, 0]])
        coord_origin_input = im_input_img.transfo_pix2phys([[0, 0, 0]])
        coord_diff_origin_z = coord_origin_dest[0][2] - coord_origin_input[0][2]
        [[x_o, y_o, z_o]] = im_input_img.transfo_phys2pix([[0, 0, coord_diff_origin_z]])
        # slice numbers are used to build file names, so they must be integers
        z_offset = int(z_o)

        # loop across slices
        for i in range(nz):
            num = numerotation(i)
            num_2 = numerotation(int(num) + z_offset)
            if mask:
                masking = '-x mask_z' + num + '.nii'
            else:
                masking = ''

            cmd = ('isct_antsRegistration '
                   '--dimensionality 2 '
                   '--transform ' + paramreg.algo + '[' + paramreg.gradStep +
                   ants_registration_params[paramreg.algo.lower()] + '] '
                   # [fixedImage,movingImage,metricWeight,nb_of_bins (MI) or radius (other)]
                   '--metric ' + paramreg.metric + '[' + root_d + '_z' + num + '.nii' + ',' +
                   root_i + '_z' + num_2 + '.nii' + ',1,' + metricSize + '] '
                   '--convergence ' + paramreg.iter + ' '
                   '--shrink-factors ' + paramreg.shrink + ' '
                   '--smoothing-sigmas ' + paramreg.smooth + 'mm '
                   # linear transforms are read back from transform_<num>0GenericAffine.mat
                   '--output [transform_' + num + '] '
                   '--interpolation BSpline[3] ' + masking)

            try:
                sct.run(cmd)
                if paramreg.algo in ('Rigid', 'Translation', 'Affine'):
                    matfile = loadmat('transform_' + num + '0GenericAffine.mat', struct_as_record=True)
                    array_transfo = matfile['AffineTransform_double_2_2']
                    # NOTE(review): sign/axis convention of Tx, Ty kept from the original; not verified
                    x_displacement[i] = -array_transfo[4][0]
                    y_displacement[i] = array_transfo[5][0]
                    if paramreg.algo == 'Affine':
                        # 2x2 linear part of the affine transform (element ordering not verified)
                        matrix_def[i] = [[array_transfo[0][0], array_transfo[1][0]],
                                         [array_transfo[2][0], array_transfo[3][0]]]
                    else:
                        # rotation angle, in radians
                        theta_rotation[i] = asin(array_transfo[2][0])
            # SystemExit is caught because sct.run may terminate through sys.exit on a
            # failing command (NOTE(review): confirm); the previous bare except also caught it.
            # KeyboardInterrupt is no longer swallowed.
            except (Exception, SystemExit):
                print('WARNING: registration of slice ' + str(i) +
                      ' failed; reusing the parameters of the previous slice.')
                x_displacement[i] = x_displacement[i - 1]
                y_displacement[i] = y_displacement[i - 1]
                if paramreg.algo in ('Rigid', 'Translation'):
                    theta_rotation[i] = theta_rotation[i - 1]
                if paramreg.algo == 'Affine':
                    matrix_def[i] = matrix_def[i - 1]
    finally:
        os.chdir(path_orig)

    # Delete tmp folder (kept on failure, to ease debugging)
    if remove_tmp_folder:
        print('\nRemove temporary files...')
        sct.run('rm -rf ' + path_tmp)

    if paramreg.algo == 'Rigid':
        return x_displacement, y_displacement, theta_rotation
    if paramreg.algo == 'Translation':
        return x_displacement, y_displacement
    if paramreg.algo == 'Affine':
        return x_displacement, y_displacement, matrix_def
    return None
def register2d(fname_src, fname_dest, fname_mask='', fname_warp='warp_forward.nii.gz', fname_warp_inv='warp_inverse.nii.gz', paramreg=Paramreg(step='0', type='im', algo='Translation', metric='MI', iter='5', shrink='1', smooth='0', gradStep='0.5'),
                    ants_registration_params={'rigid': '', 'affine': '', 'compositeaffine': '', 'similarity': '', 'translation': '', 'bspline': ',10', 'gaussiandisplacementfield': ',3,0',
                                              'bsplinedisplacementfield': ',5,10', 'syn': ',3,0', 'bsplinesyn': ',1,3'}, verbose=0):
    """Slice-by-slice 2D registration of a source image onto a destination image.

    The source and destination volumes (and the mask, if any) are split along z into 2D slices. Each source slice
    is registered onto the destination slice with the same index using isct_antsRegistration in 2D. The per-slice
    results are then assembled into a forward and an inverse 3D warping field, which are written to disk.

    N.B.: the files actually processed are 'src.nii', 'dest.nii' and (when fname_mask is not empty) 'mask.nii.gz'
    in the current working directory. fname_src is not read, fname_dest is only used to get the image dimensions,
    and fname_mask is only tested for emptiness. Slices are paired by index; physical origins are not used.
    If a mask is given, it must be 3D and in the same space as the destination image.

    input:
        fname_src: name of moving image (type: string) -- see N.B. above
        fname_dest: name of fixed image, used to read the number of slices nz (type: string)
        fname_mask[optional]: if not empty, 'mask.nii.gz' is split and passed slice-wise to antsRegistration (-x)
        fname_warp: name of output 3d forward warping field
        fname_warp_inv: name of output 3d inverse warping field
        paramreg[optional]: parameters of antsRegistration (type: Paramreg class from sct_register_multimodal).
            paramreg.algo must be one of 'Translation', 'Rigid', 'Affine', 'BSplineSyN', 'SyN'.
        ants_registration_params[optional]: specific algorithm's parameters for antsRegistration (type: dictionary)
        verbose[optional]: verbosity passed to sct.printv

    output:
        None. Writes fname_warp and fname_warp_inv:
        - Translation: fields generated from the per-slice Tx, Ty (the inverse uses -Tx, -Ty).
        - Rigid, Affine, BSplineSyN, SyN: concatenation along z of the per-slice 2D warping fields. For Rigid and
          Affine, each 2D field is obtained by composing the affine .mat transform with a null 2D SyN field.

    raises:
        ValueError: if paramreg.algo is not one of the supported algorithms.
    """
    # Algorithms whose per-slice result is a translation read from the .mat file,
    # vs. algorithms whose per-slice result is a 2D warping field
    algos_translation = ['Translation']
    algos_warp = ['Rigid', 'Affine', 'BSplineSyN', 'SyN']
    # Fail early: any other algo would run every registration and then silently write no warping field
    if paramreg.algo not in algos_translation + algos_warp:
        raise ValueError('register2d: unsupported algorithm: ' + str(paramreg.algo))

    # set metricSize
    if paramreg.metric == 'MI':
        metricSize = '32'  # corresponds to number of bins
    else:
        metricSize = '4'  # corresponds to radius (for CC, MeanSquares...)

    # Get image dimensions and retrieve nz
    sct.printv('\nGet image dimensions of destination image...', verbose)
    nx, ny, nz, nt, px, py, pz, pt = Image(fname_dest).dim
    sct.printv('.. matrix size: ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), verbose)
    sct.printv('.. voxel size:  ' + str(px) + 'mm x ' + str(py) + 'mm x ' + str(pz) + 'mm', verbose)

    # Split input volume along z (slices are referenced below as src_Z<num>.nii)
    sct.printv('\nSplit input volume...', verbose)
    from sct_image import split_data
    im_src = Image('src.nii')
    split_source_list = split_data(im_src, 2)
    for im in split_source_list:
        im.save()

    # Split destination volume along z (slices are referenced below as dest_Z<num>.nii)
    sct.printv('\nSplit destination volume...', verbose)
    im_dest = Image('dest.nii')
    split_dest_list = split_data(im_dest, 2)
    for im in split_dest_list:
        im.save()

    # Split mask volume along z (slices are referenced below as mask_Z<num>.nii.gz)
    if fname_mask != '':
        sct.printv('\nSplit mask volume...', verbose)
        im_mask = Image('mask.nii.gz')
        split_mask_list = split_data(im_mask, 2)
        for im in split_mask_list:
            im.save()

    # initialization
    if paramreg.algo in algos_translation:
        x_displacement = [0] * nz
        y_displacement = [0] * nz
    else:
        list_warp = []
        list_warp_inv = []

    # antsRegistration -r codes for the initial transform
    init_dict = {'geometric': '0', 'centermass': '1', 'origin': '2'}

    # loop across slices
    for i in range(nz):
        sct.printv('Registering slice ' + str(i) + '/' + str(nz - 1) + '...', verbose)
        num = numerotation(i)
        prefix_warp2d = 'warp2d_' + num
        # if mask is used, prepare command for ANTs
        if fname_mask != '':
            masking = ['-x', 'mask_Z' + num + '.nii.gz']
        else:
            masking = []
        # main command for registration
        # TODO fixup isct_ants* parsers
        cmd = ['isct_antsRegistration',
               '--dimensionality', '2',
               '--transform', paramreg.algo + '[' + str(paramreg.gradStep) + ants_registration_params[paramreg.algo.lower()] + ']',
               # [fixedImage,movingImage,metricWeight,nb_of_bins (MI) or radius (other)]
               '--metric', paramreg.metric + '[dest_Z' + num + '.nii' + ',src_Z' + num + '.nii' + ',1,' + metricSize + ']',
               '--convergence', str(paramreg.iter),
               '--shrink-factors', str(paramreg.shrink),
               '--smoothing-sigmas', str(paramreg.smooth) + 'mm',
               '--output', '[' + prefix_warp2d + ',src_Z' + num + '_reg.nii]',  # --> file.mat (contains Tx, Ty, theta)
               '--interpolation', 'BSpline[3]',
               '--verbose', '1',
               ] + masking
        # add init translation
        if paramreg.init:
            cmd += ['-r', '[dest_Z' + num + '.nii' + ',src_Z' + num + '.nii,' + init_dict[paramreg.init] + ']']

        try:
            # run registration
            sct.run(cmd)

            if paramreg.algo in algos_translation:
                file_mat = prefix_warp2d + '0GenericAffine.mat'
                matfile = loadmat(file_mat, struct_as_record=True)
                array_transfo = matfile['AffineTransform_double_2_2']
                x_displacement[i] = array_transfo[4][0]  # Tx in ITK'S coordinate system
                y_displacement[i] = array_transfo[5][0]  # Ty in ITK'S and fslview's coordinate systems
            else:
                # List names of 2d warping fields for subsequent merge along Z
                file_warp2d = prefix_warp2d + '0Warp.nii.gz'
                file_warp2d_inv = prefix_warp2d + '0InverseWarp.nii.gz'
                list_warp.append(file_warp2d)
                list_warp_inv.append(file_warp2d_inv)

                if paramreg.algo in ['Rigid', 'Affine']:
                    # Generating null 2d warping field (0 iterations), for subsequent concatenation with the affine transformation
                    # TODO fixup isct_ants* parsers
                    sct.run(['isct_antsRegistration',
                             '-d', '2',
                             '-t', 'SyN[1,1,1]',
                             '-c', '0',
                             '-m', 'MI[dest_Z' + num + '.nii,src_Z' + num + '.nii,1,32]',
                             '-o', 'warp2d_null',
                             '-f', '1',
                             '-s', '0',
                             ])
                    # --> outputs: warp2d_null0Warp.nii.gz, warp2d_null0InverseWarp.nii.gz
                    file_mat = prefix_warp2d + '0GenericAffine.mat'
                    # Concatenating mat transfo and null 2d warping field to obtain 2d warping field of affine transformation
                    sct.run(['isct_ComposeMultiTransform', '2', file_warp2d, '-R', 'dest_Z' + num + '.nii', 'warp2d_null0Warp.nii.gz', file_mat])
                    sct.run(['isct_ComposeMultiTransform', '2', file_warp2d_inv, '-R', 'src_Z' + num + '.nii', 'warp2d_null0InverseWarp.nii.gz', '-i', file_mat])

        # report failures through sct.printv with the 'error' type
        except Exception as e:
            sct.printv('ERROR: Exception occurred.\n' + str(e), 1, 'error')

    # Merge warping field along z
    sct.printv('\nMerge warping fields along z...', verbose)

    if paramreg.algo in algos_translation:
        x_disp_a = np.asarray(x_displacement)
        y_disp_a = np.asarray(y_displacement)
        # forward field uses dest.nii as reference; inverse uses src.nii with negated translations
        generate_warping_field('dest.nii', x_disp_a, y_disp_a, fname_warp=fname_warp)
        generate_warping_field('src.nii', -x_disp_a, -y_disp_a, fname_warp=fname_warp_inv)
    else:
        from sct_image import concat_warp2d
        # concatenate 2d warping fields along z
        concat_warp2d(list_warp, fname_warp, 'dest.nii')
        concat_warp2d(list_warp_inv, fname_warp_inv, 'src.nii')
def register_images(
        fname_source,
        fname_dest,
        mask='',
        paramreg=Paramreg(step='0',
                          type='im',
                          algo='Translation',
                          metric='MI',
                          iter='5',
                          shrink='1',
                          smooth='0',
                          gradStep='0.5'),
        ants_registration_params={
            'rigid': '',
            'affine': '',
            'compositeaffine': '',
            'similarity': '',
            'translation': '',
            'bspline': ',10',
            'gaussiandisplacementfield': ',3,0',
            'bsplinedisplacementfield': ',5,10',
            'syn': ',3,0',
            'bsplinesyn': ',1,3'
        },
        remove_tmp_folder=1):
    """Slice-by-slice registration of two images.

    We first split the 3D images into 2D images (and the mask if inputted). Then we register slices of the two images
    that physically correspond to one another looking at the physical origin of each image. The images can be of
    different sizes but the destination image must be smaller thant the input image. We do that using antsRegistration
    in 2D. Once this has been done for each slices, we gather the results and return them.
    Algorithms implemented: translation, rigid, affine, syn and BsplineSyn.
    N.B.: If the mask is inputted, it must also be 3D and it must be in the same space as the destination image.

    input:
        fname_source: name of moving image (type: string)
        fname_dest: name of fixed image (type: string)
        mask[optional]: name of mask file (type: string) (parameter -x of antsRegistration)
        paramreg[optional]: parameters of antsRegistration (type: Paramreg class from sct_register_multimodal)
        ants_registration_params[optional]: specific algorithm's parameters for antsRegistration (type: dictionary)

    output:
        if algo==translation:
            x_displacement: list of translation along x axis for each slice (type: list)
            y_displacement: list of translation along y axis for each slice (type: list)
        if algo==rigid:
            x_displacement: list of translation along x axis for each slice (type: list)
            y_displacement: list of translation along y axis for each slice (type: list)
            theta_rotation: list of rotation angle in radian (and in ITK's coordinate system) for each slice (type: list)
        if algo==affine or algo==syn or algo==bsplinesyn:
            creation of two 3D warping fields (forward and inverse) that are the concatenations of the slice-by-slice
            warps.
    """
    # Extracting names
    path_i, root_i, ext_i = sct.extract_fname(fname_source)
    path_d, root_d, ext_d = sct.extract_fname(fname_dest)

    # set metricSize
    if paramreg.metric == 'MI':
        metricSize = '32'  # corresponds to number of bins
    else:
        metricSize = '4'  # corresponds to radius (for CC, MeanSquares...)

    # Get image dimensions and retrieve nz
    print '\nGet image dimensions of destination image...'
    nx, ny, nz, nt, px, py, pz, pt = Image(fname_dest).dim
    print '.. matrix size: ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz)
    print '.. voxel size:  ' + str(px) + 'mm x ' + str(py) + 'mm x ' + str(
        pz) + 'mm'

    # Define x and y displacement as list
    x_displacement = [0 for i in range(nz)]
    y_displacement = [0 for i in range(nz)]
    theta_rotation = [0 for i in range(nz)]

    # create temporary folder
    print('\nCreate temporary folder...')
    path_tmp = 'tmp.' + time.strftime("%y%m%d%H%M%S")
    sct.create_folder(path_tmp)
    print '\nCopy input data...'
    sct.run('cp ' + fname_source + ' ' + path_tmp + '/' + root_i + ext_i)
    sct.run('cp ' + fname_dest + ' ' + path_tmp + '/' + root_d + ext_d)
    if mask:
        sct.run('cp ' + mask + ' ' + path_tmp + '/mask.nii.gz')

    # go to temporary folder
    os.chdir(path_tmp)

    # Split input volume along z
    print '\nSplit input volume...'
    from sct_split_data import split_data
    split_data(fname_source, 2, '_z')

    # Split destination volume along z
    print '\nSplit destination volume...'
    split_data(fname_dest, 2, '_z')

    # Split mask volume along z
    if mask:
        print '\nSplit mask volume...'
        split_data('mask.nii.gz', 2, '_z')

    im_dest_img = Image(fname_dest)
    im_input_img = Image(fname_source)
    coord_origin_dest = im_dest_img.transfo_pix2phys([[0, 0, 0]])
    coord_origin_input = im_input_img.transfo_pix2phys([[0, 0, 0]])
    coord_diff_origin = (asarray(coord_origin_dest[0]) -
                         asarray(coord_origin_input[0])).tolist()
    [x_o, y_o, z_o] = [
        coord_diff_origin[0] * 1.0 / px, coord_diff_origin[1] * 1.0 / py,
        coord_diff_origin[2] * 1.0 / pz
    ]

    if paramreg.algo == 'BSplineSyN' or paramreg.algo == 'SyN' or paramreg.algo == 'Affine':
        list_warp_x = []
        list_warp_x_inv = []
        list_warp_y = []
        list_warp_y_inv = []
        name_warp_final = 'Warp_total'  #if modified, name should also be modified in msct_register (algo slicereg2d_bsplinesyn and slicereg2d_syn)

    # loop across slices
    for i in range(nz):
        # set masking
        num = numerotation(i)
        num_2 = numerotation(int(num) + int(z_o))
        if mask:
            masking = '-x mask_z' + num + '.nii'
        else:
            masking = ''

        cmd = (
            'isct_antsRegistration '
            '--dimensionality 2 '
            '--transform ' + paramreg.algo + '[' + str(paramreg.gradStep) +
            ants_registration_params[paramreg.algo.lower()] + '] '
            '--metric ' + paramreg.metric + '[' + root_d + '_z' + num +
            '.nii' + ',' + root_i + '_z' + num_2 + '.nii' + ',1,' +
            metricSize +
            '] '  #[fixedImage,movingImage,metricWeight +nb_of_bins (MI) or radius (other)
            '--convergence ' + str(paramreg.iter) + ' '
            '--shrink-factors ' + str(paramreg.shrink) + ' '
            '--smoothing-sigmas ' + str(paramreg.smooth) + 'mm '
            #'--restrict-deformation 1x1x0 '    # how to restrict? should not restrict here, if transform is precised...?
            '--output [transform_' + num + ',' + root_i + '_z' + num_2 +
            'reg.nii] '  #--> file.mat (contains Tx,Ty, theta)
            '--interpolation BSpline[3] ' + masking)

        try:
            sct.run(cmd)

            if paramreg.algo == 'Rigid' or paramreg.algo == 'Translation':
                f = 'transform_' + num + '0GenericAffine.mat'
                matfile = loadmat(f, struct_as_record=True)
                array_transfo = matfile['AffineTransform_double_2_2']
                x_displacement[i] = array_transfo[4][
                    0]  # Tx in ITK'S coordinate system
                y_displacement[i] = array_transfo[5][
                    0]  # Ty  in ITK'S and fslview's coordinate systems
                theta_rotation[i] = asin(
                    array_transfo[2]
                )  # angle of rotation theta in ITK'S coordinate system (minus theta for fslview)

            if paramreg.algo == 'Affine':
                # New process added for generating total nifti warping field from mat warp
                name_dest = root_d + '_z' + num + '.nii'
                name_reg = root_i + '_z' + num + 'reg.nii'
                name_output_warp = 'warp_from_mat_' + num_2 + '.nii.gz'
                name_output_warp_inverse = 'warp_from_mat_' + num + '_inverse.nii.gz'
                name_warp_null = 'warp_null_' + num + '.nii.gz'
                name_warp_null_dest = 'warp_null_dest' + num + '.nii.gz'
                name_warp_mat = 'transform_' + num + '0GenericAffine.mat'
                # Generating null nifti warping fields
                nx, ny, nz, nt, px, py, pz, pt = Image(name_reg).dim
                nx_d, ny_d, nz_d, nt_d, px_d, py_d, pz_d, pt_d = Image(
                    name_dest).dim
                x_trans = [0 for i in range(nz)]
                x_trans_d = [0 for i in range(nz_d)]
                y_trans = [0 for i in range(nz)]
                y_trans_d = [0 for i in range(nz_d)]
                generate_warping_field(name_reg,
                                       x_trans=x_trans,
                                       y_trans=y_trans,
                                       fname=name_warp_null,
                                       verbose=0)
                generate_warping_field(name_dest,
                                       x_trans=x_trans_d,
                                       y_trans=y_trans_d,
                                       fname=name_warp_null_dest,
                                       verbose=0)
                # Concatenating mat wrp and null nifti warp to obtain equivalent nifti warp to mat warp
                sct.run('isct_ComposeMultiTransform 2 ' + name_output_warp +
                        ' -R ' + name_reg + ' ' + name_warp_null + ' ' +
                        name_warp_mat)
                sct.run('isct_ComposeMultiTransform 2 ' +
                        name_output_warp_inverse + ' -R ' + name_dest + ' ' +
                        name_warp_null_dest + ' -i ' + name_warp_mat)
                # Split the warping fields into two for displacement along x and y before merge
                sct.run('isct_c3d -mcs ' + name_output_warp +
                        ' -oo transform_' + num + '0Warp_x.nii.gz transform_' +
                        num + '0Warp_y.nii.gz')
                sct.run('isct_c3d -mcs ' + name_output_warp_inverse +
                        ' -oo transform_' + num +
                        '0InverseWarp_x.nii.gz transform_' + num +
                        '0InverseWarp_y.nii.gz')
                # List names of warping fields for futur merge
                list_warp_x.append('transform_' + num + '0Warp_x.nii.gz')
                list_warp_x_inv.append('transform_' + num +
                                       '0InverseWarp_x.nii.gz')
                list_warp_y.append('transform_' + num + '0Warp_y.nii.gz')
                list_warp_y_inv.append('transform_' + num +
                                       '0InverseWarp_y.nii.gz')

            if paramreg.algo == 'BSplineSyN' or paramreg.algo == 'SyN':
                # Split the warping fields into two for displacement along x and y before merge
                # Need to separate the merge for x and y displacement as merge of 3d warping fields does not work properly
                sct.run('isct_c3d -mcs transform_' + num +
                        '0Warp.nii.gz -oo transform_' + num +
                        '0Warp_x.nii.gz transform_' + num + '0Warp_y.nii.gz')
                sct.run('isct_c3d -mcs transform_' + num +
                        '0InverseWarp.nii.gz -oo transform_' + num +
                        '0InverseWarp_x.nii.gz transform_' + num +
                        '0InverseWarp_y.nii.gz')
                # List names of warping fields for futur merge
                list_warp_x.append('transform_' + num + '0Warp_x.nii.gz')
                list_warp_x_inv.append('transform_' + num +
                                       '0InverseWarp_x.nii.gz')
                list_warp_y.append('transform_' + num + '0Warp_y.nii.gz')
                list_warp_y_inv.append('transform_' + num +
                                       '0InverseWarp_y.nii.gz')
        # if an exception occurs with ants, take the last value for the transformation
        except:
            if paramreg.algo == 'Rigid' or paramreg.algo == 'Translation':
                x_displacement[i] = x_displacement[i - 1]
                y_displacement[i] = y_displacement[i - 1]
                theta_rotation[i] = theta_rotation[i - 1]

            if paramreg.algo == 'BSplineSyN' or paramreg.algo == 'SyN' or paramreg.algo == 'Affine':
                print 'Problem with ants for slice ' + str(
                    i) + '. Copy of the last warping field.'
                sct.run('cp transform_' + numerotation(i - 1) +
                        '0Warp.nii.gz transform_' + num + '0Warp.nii.gz')
                sct.run('cp transform_' + numerotation(i - 1) +
                        '0InverseWarp.nii.gz transform_' + num +
                        '0InverseWarp.nii.gz')
                # Split the warping fields into two for displacement along x and y before merge
                sct.run('isct_c3d -mcs transform_' + num +
                        '0Warp.nii.gz -oo transform_' + num +
                        '0Warp_x.nii.gz transform_' + num + '0Warp_y.nii.gz')
                sct.run('isct_c3d -mcs transform_' + num +
                        '0InverseWarp.nii.gz -oo transform_' + num +
                        '0InverseWarp_x.nii.gz transform_' + num +
                        '0InverseWarp_y.nii.gz')
                # List names of warping fields for futur merge
                list_warp_x.append('transform_' + num + '0Warp_x.nii.gz')
                list_warp_x_inv.append('transform_' + num +
                                       '0InverseWarp_x.nii.gz')
                list_warp_y.append('transform_' + num + '0Warp_y.nii.gz')
                list_warp_y_inv.append('transform_' + num +
                                       '0InverseWarp_y.nii.gz')

    if paramreg.algo == 'BSplineSyN' or paramreg.algo == 'SyN' or paramreg.algo == 'Affine':
        print '\nMerge along z of the warping fields...'
        # from sct_concat_data import concat_data
        sct.run('sct_concat_data -i ' + ','.join(list_warp_x) + ' -o ' +
                name_warp_final + '_x.nii.gz -dim z')
        sct.run('sct_concat_data -i ' + ','.join(list_warp_x_inv) + ' -o ' +
                name_warp_final + '_x_inverse.nii.gz -dim z')
        sct.run('sct_concat_data -i ' + ','.join(list_warp_y) + ' -o ' +
                name_warp_final + '_y.nii.gz -dim z')
        sct.run('sct_concat_data -i ' + ','.join(list_warp_y_inv) + ' -o ' +
                name_warp_final + '_y_inverse.nii.gz -dim z')
        # concat_data(','.join(list_warp_x), name_warp_final+'_x.nii.gz', 2)
        # concat_data(','.join(list_warp_x_inv), name_warp_final+'_x_inverse.nii.gz', 2)
        # concat_data(','.join(list_warp_y), name_warp_final+'_y.nii.gz', 2)
        # concat_data(','.join(list_warp_y_inv), name_warp_final+'_y_inverse.nii.gz', 2)
        # sct.run('fslmerge -z ' + name_warp_final + '_x ' + " ".join(list_warp_x))
        # sct.run('fslmerge -z ' + name_warp_final + '_x_inverse ' + " ".join(list_warp_x_inv))
        # sct.run('fslmerge -z ' + name_warp_final + '_y ' + " ".join(list_warp_y))
        # sct.run('fslmerge -z ' + name_warp_final + '_y_inverse ' + " ".join(list_warp_y_inv))
        print '\nChange resolution of warping fields to match the resolution of the destination image...'
        from sct_copy_header import copy_header
        copy_header(fname_dest, name_warp_final + '_x.nii.gz')
        copy_header(fname_source, name_warp_final + '_x_inverse.nii.gz')
        copy_header(fname_dest, name_warp_final + '_y.nii.gz')
        copy_header(fname_source, name_warp_final + '_y_inverse.nii.gz')
        print '\nMerge translation fields along x and y into one global warping field '
        sct.run('isct_c3d ' + name_warp_final + '_x.nii.gz ' +
                name_warp_final + '_y.nii.gz -omc 2 ' + name_warp_final +
                '.nii.gz')
        sct.run('isct_c3d ' + name_warp_final + '_x_inverse.nii.gz ' +
                name_warp_final + '_y_inverse.nii.gz -omc 2 ' +
                name_warp_final + '_inverse.nii.gz')
        print '\nCopy to parent folder...'
        sct.run('cp ' + name_warp_final + '.nii.gz ../')
        sct.run('cp ' + name_warp_final + '_inverse.nii.gz ../')

    #Delete tmp folder
    os.chdir('../')
    if remove_tmp_folder:
        print('\nRemove temporary files...')
        sct.run('rm -rf ' + path_tmp)
    if paramreg.algo == 'Rigid':
        return x_displacement, y_displacement, theta_rotation
    if paramreg.algo == 'Translation':
        return x_displacement, y_displacement
# Example #6
# 0
def _check_input_labels(fname_landmarks, fname_template_label, verbose):
    """Sanity-check the landmark image provided by the user.

    Each failed check is reported with ``sct.printv(..., 'error')``
    (presumably aborting the program -- confirm against ``sct.printv``):

    * the landmark image must contain at least one non-zero label;
    * no two landmarks may share the same value;
    * the highest landmark value must not exceed the highest template label.

    :param fname_landmarks: path to the subject landmark image.
    :param fname_template_label: path to the template label image.
    :param verbose: verbosity level forwarded to ``sct.printv``.
    """
    labels = Image(fname_landmarks).getNonZeroCoordinates(sorting='value')
    # Without this guard an empty label image crashed below with an
    # IndexError on labels[-1] instead of a readable message.
    if not labels:
        sct.printv(
            'ERROR: Wrong landmarks input. No label found in ' +
            fname_landmarks + '.', 1, 'error')

    # -> all labels must be different
    has_duplicate = any(lab != other_label and lab.hasEqualValue(other_label)
                        for lab in labels for other_label in labels)
    if has_duplicate:
        sct.printv(
            'ERROR: Wrong landmarks input. All labels must be different.',
            verbose, 'error')

    # -> all labels must be available in template
    labels_template = Image(fname_template_label).getNonZeroCoordinates(
        sorting='value')
    if labels[-1].value > labels_template[-1].value:
        sct.printv(
            'ERROR: Wrong landmarks input. Labels must have correspondence in template space. \nLabel max '
            'provided: ' + str(labels[-1].value) +
            '\nLabel max from template: ' + str(labels_template[-1].value),
            verbose, 'error')


def _average_landmarks_by_value(landmarks):
    """Collapse landmarks that share the same value into a single landmark.

    For every value (in order of first appearance) the coordinates carrying
    that value are summed and divided by their count -- presumably the
    coordinate-wise mean, e.g. the centre of one dilated cross; the exact
    semantics of ``+``/``/`` are defined by the coordinate class.

    :param landmarks: coordinates exposing ``value`` and ``hasEqualValue``.
    :return: list with one averaged coordinate per distinct value.
    """
    averaged = []
    for coord in landmarks:
        if coord.value not in [c.value for c in averaged]:
            temp_landmark = coord
            temp_number = 1
            for other_coord in landmarks:
                if coord.hasEqualValue(other_coord) and coord != other_coord:
                    temp_landmark += other_coord
                    temp_number += 1
            averaged.append(temp_landmark / temp_number)
    return averaged


def _landmarks_to_physical(image, landmarks):
    """Convert voxel landmarks of ``image`` into ``[x, y, z]`` physical points.

    :param image: ``Image`` providing ``transfo_pix2phys``.
    :param landmarks: coordinates exposing ``x``, ``y`` and ``z``.
    :return: list of 3-element lists, in the same order as ``landmarks``.
    """
    points = []
    for coord in landmarks:
        point = image.transfo_pix2phys([[coord.x, coord.y, coord.z]])
        points.append([point[0][0], point[0][1], point[0][2]])
    return points


def _write_itk_affine_transform(fname, rotation_matrix, translation_array,
                                center):
    """Write a landmark-based transform as an ITK text transform file.

    The diagonal of ``rotation_matrix`` is written inverted and the z
    translation negated. The reason for this convention is not visible here
    (presumably it matches the orientation of the result returned by
    ``getRigidTransformFromLandmarks`` -- TODO confirm).

    :param fname: output path of the ``.txt`` transform.
    :param rotation_matrix: 3x3 matrix indexable as ``m[row, col]``
        (presumably a numpy array).
    :param translation_array: 1x3 array indexable as ``t[0, col]``.
    :param center: 3-element sequence used as the fixed centre of rotation.
    """
    # 'with' guarantees the file is closed even if a write fails.
    with open(fname, "w") as text_file:
        text_file.write("#Insight Transform File V1.0\n")
        text_file.write("#Transform 0\n")
        text_file.write(
            "Transform: FixedCenterOfRotationAffineTransform_double_3_3\n")
        text_file.write(
            "Parameters: %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f\n"
            % (1.0 / rotation_matrix[0, 0], rotation_matrix[0, 1],
               rotation_matrix[0, 2], rotation_matrix[1, 0],
               1.0 / rotation_matrix[1, 1], rotation_matrix[1, 2],
               rotation_matrix[2, 0], rotation_matrix[2, 1],
               1.0 / rotation_matrix[2, 2], translation_array[0, 0],
               translation_array[0, 1], -translation_array[0, 2]))
        text_file.write("FixedParameters: %.9f %.9f %.9f\n" %
                        (center[0], center[1], center[2]))


def main():
    """Register an anatomical image to the spinal cord template.

    Pipeline (intermediate files live in a time-stamped ``tmp.*`` folder):

    1. parse the command line (or use hard-coded paths when ``param.debug``);
    2. check the template files and the input landmarks;
    3. resample the inputs to 1 mm isotropic and reorient them to RPI;
    4. straighten the cord using its segmentation;
    5. estimate a landmark-based transform (translation + z scaling)
       from the straightened cord to the template;
    6. crop / sub-sample in z and run the multi-step registration described
       by ``paramreg`` (defaults below, overridable with ``-p``);
    7. concatenate all transforms into ``warp_anat2template.nii.gz`` and
       ``warp_template2anat.nii.gz`` and copy them (plus the warped images
       when ``param.output_type == 1``) to the current folder.

    Relies on the module-level ``param`` object for default values.
    """
    # get default parameters
    step1 = Paramreg(step='1',
                     type='seg',
                     algo='slicereg',
                     metric='MeanSquares',
                     iter='10')
    step2 = Paramreg(step='2', type='im', algo='syn', metric='MI', iter='3')
    paramreg = ParamregMultiStep([step1, step2])

    # Initialize the parser
    parser = Parser(__file__)
    parser.usage.set_description('Register anatomical image to the template.')
    parser.add_option(name="-i",
                      type_value="file",
                      description="Anatomical image.",
                      mandatory=True,
                      example="anat.nii.gz")
    parser.add_option(name="-s",
                      type_value="file",
                      description="Spinal cord segmentation.",
                      mandatory=True,
                      example="anat_seg.nii.gz")
    parser.add_option(
        name="-l",
        type_value="file",
        description=
        "Labels. See: http://sourceforge.net/p/spinalcordtoolbox/wiki/create_labels/",
        mandatory=True,
        default_value='',
        example="anat_labels.nii.gz")
    parser.add_option(name="-t",
                      type_value="folder",
                      description="Path to MNI-Poly-AMU template.",
                      mandatory=False,
                      default_value=param.path_template)
    parser.add_option(
        name="-p",
        type_value=[[':'], 'str'],
        description=
        """Parameters for registration (see sct_register_multimodal). Default:\n--\nstep=1\ntype="""
        + paramreg.steps['1'].type + """\nalgo=""" + paramreg.steps['1'].algo +
        """\nmetric=""" + paramreg.steps['1'].metric + """\npoly=""" +
        paramreg.steps['1'].poly + """\n--\nstep=2\ntype=""" +
        paramreg.steps['2'].type + """\nalgo=""" + paramreg.steps['2'].algo +
        """\nmetric=""" + paramreg.steps['2'].metric + """\niter=""" +
        paramreg.steps['2'].iter + """\nshrink=""" +
        paramreg.steps['2'].shrink + """\nsmooth=""" +
        paramreg.steps['2'].smooth + """\ngradStep=""" +
        paramreg.steps['2'].gradStep + """\n--""",
        mandatory=False,
        example=
        "step=2,type=seg,algo=bsplinesyn,metric=MeanSquares,iter=5,shrink=2:step=3,type=im,algo=syn,metric=MI,iter=5,shrink=1,gradStep=0.3"
    )
    parser.add_option(name="-r",
                      type_value="multiple_choice",
                      description="""Remove temporary files.""",
                      mandatory=False,
                      default_value='1',
                      example=['0', '1'])
    parser.add_option(
        name="-v",
        type_value="multiple_choice",
        description="""Verbose. 0: nothing. 1: basic. 2: extended.""",
        mandatory=False,
        default_value=param.verbose,
        example=['0', '1', '2'])

    if param.debug:
        # Parenthesized print: identical output in Python 2, valid in Python 3.
        print('\n*** WARNING: DEBUG MODE ON ***\n')
        fname_data = '/Users/julien/data/temp/sct_example_data/t2/t2.nii.gz'
        fname_landmarks = '/Users/julien/data/temp/sct_example_data/t2/labels.nii.gz'
        fname_seg = '/Users/julien/data/temp/sct_example_data/t2/t2_seg.nii.gz'
        path_template = param.path_template
        remove_temp_files = 0
        verbose = 2
    else:
        arguments = parser.parse(sys.argv[1:])

        # get arguments
        fname_data = arguments['-i']
        fname_seg = arguments['-s']
        fname_landmarks = arguments['-l']
        path_template = arguments['-t']
        remove_temp_files = int(arguments['-r'])
        verbose = int(arguments['-v'])
        if '-p' in arguments:
            # update registration parameters with the user-defined steps
            for param_step in arguments['-p']:
                paramreg.addStep(param_step)

    # initialize other parameters
    file_template = param.file_template
    file_template_label = param.file_template_label
    file_template_seg = param.file_template_seg
    output_type = param.output_type
    zsubsample = param.zsubsample

    # start timer
    start_time = time.time()

    # get absolute path - TO DO: remove! NEVER USE ABSOLUTE PATH...
    path_template = os.path.abspath(path_template)

    # get fname of the template + template objects
    fname_template = sct.slash_at_the_end(path_template, 1) + file_template
    fname_template_label = sct.slash_at_the_end(path_template,
                                                1) + file_template_label
    fname_template_seg = sct.slash_at_the_end(path_template,
                                              1) + file_template_seg

    # check file existence
    sct.printv('\nCheck template files...', verbose)
    for fname in (fname_template, fname_template_label, fname_template_seg):
        sct.check_file_exist(fname, verbose)

    # print arguments
    sct.printv('\nCheck parameters:', verbose)
    sct.printv('.. Data:                 ' + fname_data, verbose)
    sct.printv('.. Landmarks:            ' + fname_landmarks, verbose)
    sct.printv('.. Segmentation:         ' + fname_seg, verbose)
    sct.printv('.. Path template:        ' + path_template, verbose)
    sct.printv('.. Output type:          ' + str(output_type), verbose)
    sct.printv('.. Remove temp files:    ' + str(remove_temp_files), verbose)

    sct.printv('\nParameters for registration:', verbose)
    for i_step in range(1, len(paramreg.steps) + 1):
        step = paramreg.steps[str(i_step)]
        sct.printv('Step #' + step.step, verbose)
        sct.printv('.. Type #' + step.type, verbose)
        sct.printv('.. Algorithm................ ' + step.algo, verbose)
        sct.printv('.. Metric................... ' + step.metric, verbose)
        sct.printv('.. Number of iterations..... ' + step.iter, verbose)
        sct.printv('.. Shrink factor............ ' + step.shrink, verbose)
        sct.printv('.. Smoothing factor......... ' + step.smooth, verbose)
        sct.printv('.. Gradient step............ ' + step.gradStep, verbose)
        sct.printv('.. Degree of polynomial..... ' + step.poly, verbose)

    # only the extension is needed, to name the warped output images
    _, _, ext_data = sct.extract_fname(fname_data)

    sct.printv('\nCheck input labels...', verbose)
    _check_input_labels(fname_landmarks, fname_template_label, verbose)

    # create temporary folder
    sct.printv('\nCreate temporary folder...', verbose)
    path_tmp = 'tmp.' + time.strftime("%y%m%d%H%M%S")
    sct.run('mkdir ' + path_tmp)

    # copy files to temporary folder (isct_c3d also converts the format)
    sct.printv('\nCopy files...', verbose)
    for fname_in, fname_out in ((fname_data, 'data.nii'),
                                (fname_landmarks, 'landmarks.nii.gz'),
                                (fname_seg, 'segmentation.nii.gz'),
                                (fname_template, 'template.nii'),
                                (fname_template_label, 'template_labels.nii.gz'),
                                (fname_template_seg, 'template_seg.nii.gz')):
        sct.run('isct_c3d ' + fname_in + ' -o ' + path_tmp + '/' + fname_out)

    # go to tmp folder
    os.chdir(path_tmp)

    # resample data to 1mm isotropic
    sct.printv('\nResample data to 1mm isotropic...', verbose)
    sct.run(
        'isct_c3d data.nii -resample-mm 1.0x1.0x1.0mm -interpolation Linear -o datar.nii'
    )
    sct.run(
        'isct_c3d segmentation.nii.gz -resample-mm 1.0x1.0x1.0mm -interpolation NearestNeighbor -o segmentationr.nii.gz'
    )
    # Labels are single-point labels: nearest-neighbour resampling could make
    # them disappear, hence the dedicated helper.
    resample_labels('landmarks.nii.gz', 'datar.nii', 'landmarksr.nii.gz')

    # Change orientation of input images to RPI
    sct.printv('\nChange orientation of input images to RPI...', verbose)
    set_orientation('datar.nii', 'RPI', 'data_rpi.nii')
    set_orientation('landmarksr.nii.gz', 'RPI', 'landmarks_rpi.nii.gz')
    set_orientation('segmentationr.nii.gz', 'RPI', 'segmentation_rpi.nii.gz')

    # crop segmentation -> segmentation_rpi_crop.nii.gz
    sct.run(
        'sct_crop_image -i segmentation_rpi.nii.gz -o segmentation_rpi_crop.nii.gz -dim 2 -bzmax'
    )

    # straighten segmentation
    sct.printv('\nStraighten the spinal cord using centerline/segmentation...',
               verbose)
    sct.run(
        'sct_straighten_spinalcord -i segmentation_rpi_crop.nii.gz -c segmentation_rpi_crop.nii.gz -r 0 -v '
        + str(verbose), verbose)
    # re-define warping field using non-cropped space (to avoid issue #367)
    sct.run(
        'sct_concat_transfo -w warp_straight2curve.nii.gz -d data_rpi.nii -o warp_straight2curve.nii.gz'
    )

    # Label preparation:
    # --------------------------------------------------------------------------------
    # Remove unused label on template. Keep only label present in the input label image
    sct.printv(
        '\nRemove unused label on template. Keep only label present in the input label image...',
        verbose)
    sct.run(
        'sct_label_utils -t remove -i template_labels.nii.gz -o template_label.nii.gz -r landmarks_rpi.nii.gz'
    )

    # Make sure landmarks are INT
    sct.printv('\nConvert landmarks to INT...', verbose)
    sct.run(
        'isct_c3d template_label.nii.gz -type int -o template_label.nii.gz',
        verbose)

    # Create a cross for the template labels - 5 mm
    sct.printv('\nCreate a 5 mm cross for the template labels...', verbose)
    sct.run(
        'sct_label_utils -t cross -i template_label.nii.gz -o template_label_cross.nii.gz -c 5'
    )

    # Create a cross for the input labels and dilate for straightening preparation - 5 mm
    sct.printv(
        '\nCreate a 5mm cross for the input labels and dilate for straightening preparation...',
        verbose)
    sct.run(
        'sct_label_utils -t cross -i landmarks_rpi.nii.gz -o landmarks_rpi_cross3x3.nii.gz -c 5 -d'
    )

    # Apply straightening to labels
    sct.printv('\nApply straightening to labels...', verbose)
    sct.run(
        'sct_apply_transfo -i landmarks_rpi_cross3x3.nii.gz -o landmarks_rpi_cross3x3_straight.nii.gz -d segmentation_rpi_crop_straight.nii.gz -w warp_curve2straight.nii.gz -x nn'
    )

    # Convert landmarks from FLOAT32 to INT
    sct.printv('\nConvert landmarks from FLOAT32 to INT...', verbose)
    sct.run(
        'isct_c3d landmarks_rpi_cross3x3_straight.nii.gz -type int -o landmarks_rpi_cross3x3_straight.nii.gz'
    )

    # Remove labels that do not correspond with each others.
    sct.printv('\nRemove labels that do not correspond with each others.',
               verbose)
    sct.run(
        'sct_label_utils -t remove-symm -i landmarks_rpi_cross3x3_straight.nii.gz -o landmarks_rpi_cross3x3_straight.nii.gz,template_label_cross.nii.gz -r template_label_cross.nii.gz'
    )

    # Estimate affine transfo: straight --> template (landmark-based)
    sct.printv(
        '\nEstimate affine transfo: straight anat --> template (landmark-based)...',
        verbose)
    # converting landmarks straight and template to physical coordinates
    image_straight = Image('landmarks_rpi_cross3x3_straight.nii.gz')
    landmark_straight = image_straight.getNonZeroCoordinates(sorting='value')
    image_template = Image('template_label_cross.nii.gz')
    landmark_template = image_template.getNonZeroCoordinates(sorting='value')
    # the straightened labels were dilated: reduce each label to one point
    landmark_straight_mean = _average_landmarks_by_value(landmark_straight)
    points_moving = _landmarks_to_physical(image_straight,
                                           landmark_straight_mean)
    points_fixed = _landmarks_to_physical(image_template, landmark_template)

    # Register straight landmarks on template landmarks (python implementation)
    sct.printv(
        '\nComputing rigid transformation (algo=translation-scaling-z) ...',
        verbose)
    import msct_register_landmarks
    rotation_matrix, translation_array, _, points_moving_barycenter = \
        msct_register_landmarks.getRigidTransformFromLandmarks(
            points_fixed, points_moving, constraints='translation-scaling-z', show=False)

    # writing rigid transformation file
    _write_itk_affine_transform('straight2templateAffine.txt', rotation_matrix,
                                translation_array, points_moving_barycenter)

    # Apply affine transformation: straight --> template
    sct.printv('\nApply affine transformation: straight --> template...',
               verbose)
    sct.run(
        'sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt -d template.nii -o warp_curve2straightAffine.nii.gz'
    )
    sct.run(
        'sct_apply_transfo -i data_rpi.nii -o data_rpi_straight2templateAffine.nii -d template.nii -w warp_curve2straightAffine.nii.gz'
    )
    sct.run(
        'sct_apply_transfo -i segmentation_rpi.nii.gz -o segmentation_rpi_straight2templateAffine.nii.gz -d template.nii -w warp_curve2straightAffine.nii.gz -x linear'
    )

    # threshold the linearly-interpolated segmentation at 0.5; the thresholded
    # copy is only used to find the z extent for cropping
    nii = Image('segmentation_rpi_straight2templateAffine.nii.gz')
    data = nii.data
    data[data < 0.5] = 0
    nii.data = data
    nii.setFileName('segmentation_rpi_straight2templateAffine_th.nii.gz')
    nii.save()
    # find min-max of anat2template (for subsequent cropping)
    zmin_template, zmax_template = find_zmin_zmax(
        'segmentation_rpi_straight2templateAffine_th.nii.gz')

    # crop template in z-direction (for faster processing)
    sct.printv('\nCrop data in template space (for faster processing)...',
               verbose)
    crop_range = ' -dim 2 -start ' + str(zmin_template) + ' -end ' + str(
        zmax_template)
    for fname_in, fname_out in (
            ('template.nii', 'template_crop.nii'),
            ('template_seg.nii.gz', 'template_seg_crop.nii.gz'),
            ('data_rpi_straight2templateAffine.nii',
             'data_rpi_straight2templateAffine_crop.nii'),
            ('segmentation_rpi_straight2templateAffine.nii.gz',
             'segmentation_rpi_straight2templateAffine_crop.nii.gz')):
        sct.run('sct_crop_image -i ' + fname_in + ' -o ' + fname_out +
                crop_range)

    # sub-sample in z-direction
    sct.printv('\nSub-sample in z-direction (for faster processing)...',
               verbose)
    for fname_in, fname_out in (
            ('template_crop.nii', 'template_crop_r.nii'),
            ('template_seg_crop.nii.gz', 'template_seg_crop_r.nii.gz'),
            ('data_rpi_straight2templateAffine_crop.nii',
             'data_rpi_straight2templateAffine_crop_r.nii'),
            ('segmentation_rpi_straight2templateAffine_crop.nii.gz',
             'segmentation_rpi_straight2templateAffine_crop_r.nii.gz')):
        sct.run('sct_resample -i ' + fname_in + ' -o ' + fname_out +
                ' -f 1x1x' + zsubsample, verbose)

    # Registration straight spinal cord to template
    sct.printv('\nRegister straight spinal cord to template...', verbose)

    # loop across registration steps
    warp_forward = []
    warp_inverse = []
    for i_step in range(1, len(paramreg.steps) + 1):
        sct.printv(
            '\nEstimate transformation for step #' + str(i_step) + '...',
            verbose)
        # identify which is the src and dest
        if paramreg.steps[str(i_step)].type == 'im':
            src = 'data_rpi_straight2templateAffine_crop_r.nii'
            dest = 'template_crop_r.nii'
            interp_step = 'linear'
        elif paramreg.steps[str(i_step)].type == 'seg':
            src = 'segmentation_rpi_straight2templateAffine_crop_r.nii.gz'
            dest = 'template_seg_crop_r.nii.gz'
            interp_step = 'nn'
        else:
            sct.printv('ERROR: Wrong image type.', 1, 'error')
        # if step>1, apply the warps estimated so far to the src image
        if i_step > 1:
            sct.run(
                'sct_apply_transfo -i ' + src + ' -d ' + dest + ' -w ' +
                ','.join(warp_forward) + ' -o ' + sct.add_suffix(src, '_reg') +
                ' -x ' + interp_step, verbose)
            src = sct.add_suffix(src, '_reg')
        # register src --> dest
        warp_forward_out, warp_inverse_out = register(src, dest, paramreg,
                                                      param, str(i_step))
        warp_forward.append(warp_forward_out)
        warp_inverse.append(warp_inverse_out)

    # Concatenate transformations:
    sct.printv('\nConcatenate transformations: anat --> template...', verbose)
    sct.run(
        'sct_concat_transfo -w warp_curve2straightAffine.nii.gz,' +
        ','.join(warp_forward) +
        ' -d template.nii -o warp_anat2template.nii.gz', verbose)
    # inverse warps are applied in reverse order of estimation
    sct.run(
        'sct_concat_transfo -w ' + ','.join(reversed(warp_inverse)) +
        ',-straight2templateAffine.txt,warp_straight2curve.nii.gz -d data.nii -o warp_template2anat.nii.gz',
        verbose)

    # Apply warping fields to anat and template
    if output_type == 1:
        sct.run(
            'sct_apply_transfo -i template.nii -o template2anat.nii.gz -d data.nii -w warp_template2anat.nii.gz -c 1',
            verbose)
        sct.run(
            'sct_apply_transfo -i data.nii -o anat2template.nii.gz -d template.nii -w warp_anat2template.nii.gz -c 1',
            verbose)

    # come back to parent folder
    os.chdir('..')

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(path_tmp + '/warp_template2anat.nii.gz',
                             'warp_template2anat.nii.gz', verbose)
    sct.generate_output_file(path_tmp + '/warp_anat2template.nii.gz',
                             'warp_anat2template.nii.gz', verbose)
    if output_type == 1:
        sct.generate_output_file(path_tmp + '/template2anat.nii.gz',
                                 'template2anat' + ext_data, verbose)
        sct.generate_output_file(path_tmp + '/anat2template.nii.gz',
                                 'anat2template' + ext_data, verbose)

    # Delete temporary files
    if remove_temp_files:
        sct.printv('\nDelete temporary files...', verbose)
        sct.run('rm -rf ' + path_tmp)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv(
        '\nFinished! Elapsed time: ' + str(int(round(elapsed_time))) + 's',
        verbose)

    # to view results
    sct.printv('\nTo view results, type:', verbose)
    sct.printv('fslview ' + fname_data + ' template2anat -b 0,4000 &', verbose,
               'info')
    sct.printv('fslview ' + fname_template + ' -b 0,5000 anat2template &\n',
               verbose, 'info')