def _estimate_affine_transfo(landmark, template_landmark):
    """Estimate the affine transformation from subject to template landmarks.

    Both landmark images are converted to "cross" label images with
    sct_label_utils, then handed to
    isct_ANTSUseLandmarkImagesToGetAffineTransform. Every file is written in
    the current working directory.

    :param landmark: file name of the subject landmark image.
    :param template_landmark: file name of the template landmark image.
    :return: name of the transformation file ('n2t.txt').
    """
    print('Creating cross using input landmarks\n...')
    sct.run('sct_label_utils -i ' + landmark + ' -o ' + 'cross_native.nii.gz -t cross ')

    print('Creating cross using template landmarks\n...')
    sct.run('sct_label_utils -i ' + template_landmark + ' -o ' + 'cross_template.nii.gz -t cross ')

    print('Computing affine transformation between subject and destination landmarks\n...')
    # NOTE(review): this is the only external command launched with os.system,
    # so its exit status is ignored. Confirm whether that is deliberate before
    # switching it to sct.run like the others.
    os.system('isct_ANTSUseLandmarkImagesToGetAffineTransform cross_template.nii.gz cross_native.nii.gz affine n2t.txt')
    return 'n2t.txt'


def _nurbs_z_warp(matched_shifts, shape, pz):
    """Build a displacement field along z from landmark shifts fitted with NURBS.

    The (z, shift) pairs are approximated with msct_smooth.b_spline_nurbs. The
    fitted curve is then extended with its end values down to slice 0 and up to
    slice nz-1, and written into component 2 of the field, constant over x and y.

    :param matched_shifts: list of [z, shift] pairs, one per matched landmark.
    :param shape: (nx, ny, nz) size of the displacement field.
    :param pz: factor applied to every shift before it is stored (seventh
        element of Image.dim -- presumably the voxel size along z, TODO confirm).
    :return: numpy array of shape (nx, ny, nz, 1, 3).
    """
    from numpy import argmax, argmin, zeros
    from msct_smooth import b_spline_nurbs

    nx, ny, nz = shape
    z_positions = [pair[0] for pair in matched_shifts]
    z_shifts = [pair[1] for pair in matched_shifts]

    # The fit has always been run with verbose=1. The value is now passed
    # directly so that it no longer overwrites the verbosity chosen by the user.
    fitted_shift, fitted_z, _, _ = b_spline_nurbs(z_shifts, z_positions, None, nbControl=None, verbose=1, all_slices=True)

    arg_min_z, arg_max_z = argmin(fitted_z), argmax(fitted_z)
    min_z, max_z = int(fitted_z[arg_min_z]), int(fitted_z[arg_max_z])

    # Order matters: entries appended later overwrite earlier ones on the same slice.
    profile = [[iz, fitted_shift[index]] for index, iz in enumerate(fitted_z)]
    profile.extend([iz, fitted_shift[arg_min_z]] for iz in range(0, min_z))
    profile.extend([iz, fitted_shift[arg_max_z]] for iz in range(max_z, nz))

    data_warp = zeros((nx, ny, nz, 1, 3))
    for iz, shift in profile:
        # Skip positions outside the volume: a negative index would silently
        # wrap around to the last slices and a too-large one would raise.
        if 0 <= iz < nz:
            # int(): array indices must be integers, and the fitted positions
            # may be floats (NOTE(review): depends on b_spline_nurbs).
            data_warp[:, :, int(iz), 0, 2] = shift * pz
    return data_warp


def _estimate_nurbs_warp(landmark, template_landmark, verbose):
    """Generate z-only warping fields between subject and template landmarks.

    Landmarks carrying the same integer value in both images are paired, and
    their shift along z is smoothed with NURBS (see _nurbs_z_warp). The work is
    done in a temporary folder 'tmp.<timestamp>', which is removed at the end.
    Both warp_subject2template.nii.gz and warp_template2subject.nii.gz are
    copied into the starting directory.

    :param landmark: file name of the subject landmark image.
    :param template_landmark: file name of the template landmark image.
    :param verbose: verbosity forwarded to sct.printv.
    :return: name of the subject-to-template warping field.
    :raises ValueError: if no landmark value is present in both images.
    """
    from nibabel import Nifti1Image, save
    from msct_image import Image

    warping_subject2template = 'warp_subject2template.nii.gz'
    warping_template2subject = 'warp_template2subject.nii.gz'

    # Resolve the inputs before changing directory: relative paths would
    # otherwise be looked up inside the temporary folder, where the template
    # landmarks are never copied.
    landmark = os.path.abspath(landmark)
    template_landmark = os.path.abspath(template_landmark)

    tmp_name = 'tmp.' + time.strftime("%y%m%d%H%M%S")
    sct.run('mkdir ' + tmp_name)
    tmp_abs_path = os.path.abspath(tmp_name)
    sct.run('cp ' + landmark + ' ' + tmp_abs_path)
    path_start = os.getcwd()
    os.chdir(tmp_name)
    try:
        image_landmark = Image(landmark)
        image_template = Image(template_landmark)
        landmarks_input = image_landmark.getNonZeroCoordinates(sorting='value')
        landmarks_template = image_template.getNonZeroCoordinates(sorting='value')
        min_value = min([int(landmarks_input[0].value), int(landmarks_template[0].value)])
        max_value = max([int(landmarks_input[-1].value), int(landmarks_template[-1].value)])
        nx, ny, nz, nt, px, py, pz, pt = image_landmark.dim

        # First coordinate found for each integer label value, in each image.
        input_by_value, template_by_value = {}, {}
        for coord in landmarks_input:
            input_by_value.setdefault(int(coord.value), coord)
        for coord in landmarks_template:
            template_by_value.setdefault(int(coord.value), coord)

        # [z, shift] pairs for the labels present in both images.
        shifts_subject2template, shifts_template2subject = [], []
        for value in range(min_value, max_value + 1):
            if value in input_by_value and value in template_by_value:
                z_input = input_by_value[value].z
                z_template = template_by_value[value].z
                shifts_subject2template.append([z_input, z_template - z_input])
                shifts_template2subject.append([z_template, z_input - z_template])
        if not shifts_subject2template:
            raise ValueError('No landmark value is present in both ' + landmark + ' and ' + template_landmark)

        # The header is taken from the template; the matrix size comes from the
        # subject landmark image.
        hdr_warp = image_template.hdr.copy()
        hdr_warp.set_intent('vector', (), '')
        hdr_warp.set_data_dtype('float32')

        # Each field is filled from the shifts measured in the opposite
        # direction, exactly as before.
        # NOTE(review): presumably because a warping field is expressed in the
        # space of its destination image -- confirm.
        data_warp_template2subject = _nurbs_z_warp(shifts_subject2template, (nx, ny, nz), pz)
        data_warp_subject2template = _nurbs_z_warp(shifts_template2subject, (nx, ny, nz), pz)

        img = Nifti1Image(data_warp_template2subject, None, hdr_warp)
        save(img, warping_template2subject)
        sct.printv('\nDONE ! Warping field generated: ' + warping_template2subject, verbose)
        img = Nifti1Image(data_warp_subject2template, None, hdr_warp)
        save(img, warping_subject2template)
        sct.printv('\nDONE ! Warping field generated: ' + warping_subject2template, verbose)

        # Copy warping into parent folder
        sct.run('cp ' + warping_subject2template + ' ../' + warping_subject2template)
        sct.run('cp ' + warping_template2subject + ' ../' + warping_template2subject)
    finally:
        # Always return to the starting directory, even when a step fails.
        os.chdir(path_start)

    sct.run('rm -rf ' + tmp_name)
    return warping_subject2template


def _estimate_syn_warp(fname, landmark, template_landmark):
    """Generate a warping field from subject to template landmarks ('SyN' option).

    The labels of both images are moved to the centre of the xy plane, a
    B-spline displacement field is estimated between them, and its z component
    is extended to the whole template volume. The result is then checked with
    'sct_label_utils -t MSE'. The work is done in a temporary folder
    'tmp.<timestamp>', which is removed unless the file
    'error_log_label_moved.txt' is found there afterwards.

    :param fname: name of the input volume, only used in the error log message.
    :param landmark: file name of the subject landmark image.
    :param template_landmark: file name of the template landmark image.
    :return: name of the warping field, copied into the starting directory.
    """
    import nibabel
    from copy import copy
    from msct_image import Image

    warping = 'warp_subject2template.nii.gz'

    # Resolve the inputs before changing directory: relative paths would
    # otherwise be looked up inside the temporary folder, where the template
    # landmarks are never copied.
    landmark = os.path.abspath(landmark)
    template_landmark = os.path.abspath(template_landmark)

    tmp_name = 'tmp.' + time.strftime("%y%m%d%H%M%S")
    sct.run('mkdir ' + tmp_name)
    tmp_abs_path = os.path.abspath(tmp_name)
    sct.run('cp ' + landmark + ' ' + tmp_abs_path)
    path_start = os.getcwd()
    os.chdir(tmp_name)
    keep_tmp_files = False
    try:
        # put labels of the subject at the center of the image (for plan xOy)
        file_labels_input = nibabel.load(landmark)
        hdr_labels_input = file_labels_input.get_header()
        data_labels_input = file_labels_input.get_data()
        data_labels_middle = copy(data_labels_input)
        data_labels_middle *= 0
        nx, ny, nz, nt, px, py, pz, pt = Image(landmark).dim
        X, Y, Z = data_labels_input.nonzero()
        x_middle = int(round(nx / 2.0))
        y_middle = int(round(ny / 2.0))

        # put labels of the template at the center of the image (for plan xOy)
        # (probably not necessary as already done by average labels)
        file_labels_template = nibabel.load(template_landmark)
        hdr_labels_template = file_labels_template.get_header()
        data_labels_template = file_labels_template.get_data()
        data_template_middle = copy(data_labels_template)
        data_template_middle *= 0
        x, y, z = data_labels_template.nonzero()

        # Only as many labels as the image with the fewest labels are kept,
        # starting from the highest z in each image.
        max_num = min([len(z), len(Z)])
        order_input = np.argsort(Z)[::-1]
        X, Y, Z = X[order_input], Y[order_input], Z[order_input]
        order_template = np.argsort(z)[::-1]
        x, y, z = x[order_template], y[order_template], z[order_template]

        for i in range(max_num):
            data_labels_middle[x_middle, y_middle, Z[i]] = data_labels_input[X[i], Y[i], Z[i]]
        img = nibabel.Nifti1Image(data_labels_middle, None, hdr_labels_input)
        nibabel.save(img, 'labels_input_middle_xy.nii.gz')

        # NOTE(review): the template labels are centred with x_middle/y_middle
        # computed from the subject image size -- confirm both images share it.
        for i in range(max_num):
            data_template_middle[x_middle, y_middle, z[i]] = data_labels_template[x[i], y[i], z[i]]
        img_template = nibabel.Nifti1Image(data_template_middle, None, hdr_labels_template)
        nibabel.save(img_template, 'labels_template_middle_xy.nii.gz')

        # estimate Bspline transform to register to template
        sct.run('isct_ANTSUseLandmarkImagesToGetBSplineDisplacementField labels_template_middle_xy.nii.gz labels_input_middle_xy.nii.gz '+ warping+' 40x40x1 5 5 0')

        # select centerline of warping field according to z and extend it
        sct.run('isct_c3d -mcs '+warping+' -oo warp_vecx.nii.gz warp_vecy.nii.gz warp_vecz.nii.gz')
        sct.run('sct_crop_image -i warp_vecz.nii.gz -o warp_vecz_r_sm_line.nii.gz -start 0.5,0.5 -end 0.5,0.5 -dim 0,1 -b 0')
        sct.run('sct_label_utils -i warp_vecz_r_sm_line.nii.gz -t plan_ref -o warp_vecz_r_sm_line_extended.nii.gz -r '+template_landmark)
        sct.run('isct_c3d '+template_landmark+' warp_vecx.nii.gz -reslice-identity -o warp_vecx_res.nii.gz')
        sct.run('isct_c3d '+template_landmark+' warp_vecy.nii.gz -reslice-identity -o warp_vecy_res.nii.gz')
        sct.run('isct_c3d warp_vecx_res.nii.gz warp_vecy_res.nii.gz warp_vecz_r_sm_line_extended.nii.gz -omc 3 '+warping)

        # check results: dilate the labels, move them, bring them back to
        # single points and compare them with the template labels
        sct.run('fslmaths labels_input_middle_xy.nii.gz -dilF landmark_dilated.nii.gz')
        sct.run('sct_apply_transfo -i landmark_dilated.nii.gz -o label_moved.nii.gz -d labels_template_middle_xy.nii.gz -w '+warping+' -x nn')
        sct.run('sct_label_utils -i label_moved.nii.gz -t cubic-to-point -o label_moved_2point.nii.gz')
        sct.run('sct_label_utils -i labels_template_middle_xy.nii.gz -r label_moved_2point.nii.gz -o template_removed.nii.gz -t remove')

        status, output = sct.run('sct_label_utils -i label_moved_2point.nii.gz -r template_removed.nii.gz -t MSE')
        sct.printv(output,1,'info')

        # Keep the temporary folder (and the log written inside it) only when
        # an error log is found after the check. Both paths used to disable
        # the clean-up, so the folder was never removed.
        if os.path.isfile('error_log_label_moved.txt'):
            keep_tmp_files = True
            with open('log.txt', 'a') as log_file:
                log_file.write('Error for '+fname+'\n')

        # Copy warping into parent folder
        sct.run('cp '+ warping+' ../'+warping)
    finally:
        # Always return to the starting directory, even when a step fails.
        os.chdir(path_start)

    if not keep_tmp_files:
        sct.run('rm -rf '+tmp_name)
    return warping


def main():
    """Register an image to a template using landmarks (command-line entry point).

    Options read from sys.argv:
        -h  display usage
        -i  input volume (mandatory)
        -l  landmarks of the input volume (mandatory)
        -R  landmarks of the template (mandatory)
        -o  output file name (default: 'aligned.nii.gz')
        -t  transformation: 'affine' (default), 'bspline', 'SyN' or 'nurbs'
        -w  final interpolation: '', 'spline' or 'NN' (default: param.final_warp)
        -c  compose flag (parsed, currently not used)
        -v  verbose (default: param.verbose)

    The transformation is estimated by the helper matching -t. The input volume
    is resampled into the template space only when -w is 'spline'.

    :raises ValueError: if -w is 'spline' while -t selects 'bspline', which is
        accepted on the command line but has no implementation.
    """
    # Initialization
    fname = ''
    landmark = ''
    verbose = param.verbose
    output_name = 'aligned.nii.gz'
    template_landmark = ''
    final_warp = param.final_warp
    compose = param.compose
    transfo = 'affine'

    # NOTE(review): usage() presumably terminates the script; if it returned,
    # opts would be undefined in the loop below.
    try:
        opts, args = getopt.getopt(sys.argv[1:], 'hi:l:o:R:t:w:c:v:')
    except getopt.GetoptError:
        usage()
    for opt, arg in opts:
        if opt == '-h':
            usage()
        elif opt == '-i':
            fname = arg
        elif opt == '-l':
            landmark = arg
        elif opt == '-o':
            output_name = arg
        elif opt == '-R':
            template_landmark = arg
        elif opt == '-t':
            transfo = arg
        elif opt == '-w':
            final_warp = arg
        elif opt == '-c':
            # Parsed so that the option stays accepted; nothing uses it yet.
            compose = int(arg)
        elif opt == '-v':
            verbose = int(arg)

    # display usage if a mandatory argument is not provided
    if fname == '' or landmark == '' or template_landmark == '':
        usage()

    if final_warp not in ['', 'spline', 'NN']:
        usage()

    if transfo not in ['affine', 'bspline', 'SyN', 'nurbs']:
        usage()

    # check existence of input files
    print('\nCheck if file exists ...')
    sct.check_file_exist(fname)
    sct.check_file_exist(landmark)
    sct.check_file_exist(template_landmark)

    # Display arguments
    print('\nCheck input arguments...')
    print('  Input volume ...................... ' + fname)
    print('  Verbose ........................... ' + str(verbose))

    # Estimate the transformation. 'bspline' is accepted above but has no
    # implementation, so warping stays None for it.
    warping = None
    if transfo == 'affine':
        warping = _estimate_affine_transfo(landmark, template_landmark)
    elif transfo == 'nurbs':
        warping = _estimate_nurbs_warp(landmark, template_landmark, verbose)
    elif transfo == 'SyN':
        warping = _estimate_syn_warp(fname, landmark, template_landmark)

    # Only the 'spline' interpolation applies the transformation to the input.
    if final_warp == 'spline':
        if warping is None:
            # This used to fail with an unbound-variable error on the line below.
            raise ValueError('Transformation "' + transfo + '" is not implemented: no warping field to apply.')
        print('Apply transfo to input image\n...')
        sct.run('sct_apply_transfo -i ' + fname + ' -o ' + output_name + ' -d ' + template_landmark + ' -w ' + warping + ' -x spline')
        # Reported only here, because no output image is written otherwise.
        print('\nFile created : ' + output_name)
def main():
    
    # Initialization
    fname_anat = ''
    fname_centerline = ''
    gapxy = param.gapxy
    gapz = param.gapz
    padding = param.padding
    centerline_fitting = param.fitting_method
    remove_temp_files = param.remove_temp_files
    verbose = param.verbose
    interpolation_warp = param.interpolation_warp

    # get path of the toolbox
    status, path_sct = commands.getstatusoutput('echo $SCT_DIR')
    print path_sct
    # extract path of the script
    path_script = os.path.dirname(__file__)+'/'
    
    # Parameters for debug mode
    if param.debug == 1:
        print '\n*** WARNING: DEBUG MODE ON ***\n'
        # fname_anat = path_sct+'/testing/data/errsm_23/t2/t2.nii.gz'
        # fname_centerline = path_sct+'/testing/data/errsm_23/t2/t2_segmentation_PropSeg.nii.gz'
        fname_anat = '/Users/julien/code/spinalcordtoolbox/scripts/tmp.140713193417/data_rpi.nii'
        fname_centerline = '/Users/julien/code/spinalcordtoolbox/scripts/tmp.140713193417/segmentation_rpi.nii.gz'
        remove_temp_files = 0
        centerline_fitting = 'splines'
        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d import Axes3D
        verbose = 2
    
    # Check input param
    try:
        opts, args = getopt.getopt(sys.argv[1:],'hi:c:r:w:f:v:')
    except getopt.GetoptError as err:
        print str(err)
        usage()
    for opt, arg in opts:
        if opt == '-h':
            usage()
        elif opt in ('-i'):
            fname_anat = arg
        elif opt in ('-c'):
            fname_centerline = arg
        elif opt in ('-r'):
            remove_temp_files = int(arg)
        elif opt in ('-w'):
            interpolation_warp = str(arg)
        elif opt in ('-f'):
            centerline_fitting = str(arg)
        elif opt in ('-v'):
            verbose = int(arg)

    # display usage if a mandatory argument is not provided
    if fname_anat == '' or fname_centerline == '':
        usage()
    
    # Display usage if optional arguments are not correctly provided
    if centerline_fitting == '':
        centerline_fitting = 'splines'
    elif not centerline_fitting == '' and not centerline_fitting == 'splines' and not centerline_fitting == 'polynomial':
        print '\n \n -f argument is not valid \n \n'
        usage()
    
    # check existence of input files
    sct.check_file_exist(fname_anat)
    sct.check_file_exist(fname_centerline)

    # check interp method
    if interpolation_warp == 'spline':
        interpolation_warp_ants = '--use-BSpline'
    elif interpolation_warp == 'trilinear':
        interpolation_warp_ants = ''
    elif interpolation_warp == 'nearestneighbor':
        interpolation_warp_ants = '--use-NN'
    else:
        print '\WARNING: Interpolation method not recognized. Using: '+param.interpolation_warp
        interpolation_warp_ants = '--use-BSpline'

    # Display arguments
    print '\nCheck input arguments...'
    print '  Input volume ...................... '+fname_anat
    print '  Centerline ........................ '+fname_centerline
    print '  Centerline fitting option ......... '+centerline_fitting
    print '  Final interpolation ............... '+interpolation_warp
    print '  Verbose ........................... '+str(verbose)
    print ''

    # if verbose 2, import matplotlib
    if verbose == 2:
        import matplotlib.pyplot as plt

    # Extract path/file/extension
    path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)
    path_centerline, file_centerline, ext_centerline = sct.extract_fname(fname_centerline)
    
    # create temporary folder
    path_tmp = 'tmp.'+time.strftime("%y%m%d%H%M%S")
    sct.run('mkdir '+path_tmp)

    # copy files into tmp folder
    sct.run('cp '+fname_anat+' '+path_tmp)
    sct.run('cp '+fname_centerline+' '+path_tmp)

    # go to tmp folder
    os.chdir(path_tmp)

    # Open centerline
    #==========================================================================================
    # Change orientation of the input centerline into RPI
    print '\nOrient centerline to RPI orientation...'
    fname_centerline_orient = 'tmp.centerline_rpi' + ext_centerline
    sct.run('sct_orientation -i ' + file_centerline + ext_centerline + ' -o ' + fname_centerline_orient + ' -orientation RPI')
    
    print '\nGet dimensions of input centerline...'
    nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(fname_centerline_orient)
    print '.. matrix size: '+str(nx)+' x '+str(ny)+' x '+str(nz)
    print '.. voxel size:  '+str(px)+'mm x '+str(py)+'mm x '+str(pz)+'mm'
    
    print '\nOpen centerline volume...'
    file = nibabel.load(fname_centerline_orient)
    data = file.get_data()
    
    # loop across z and associate x,y coordinate with the point having maximum intensity
    x_centerline = [0 for iz in range(0, nz, 1)]
    y_centerline = [0 for iz in range(0, nz, 1)]
    z_centerline = [iz for iz in range(0, nz, 1)]
    x_centerline_deriv = [0 for iz in range(0, nz, 1)]
    y_centerline_deriv = [0 for iz in range(0, nz, 1)]
    z_centerline_deriv = [0 for iz in range(0, nz, 1)]
    
    # Two possible scenario:
    # 1. the centerline is probabilistic: each slice contains voxels with the probability of containing the centerline [0:...:1]
    # We only take the maximum value of the image to aproximate the centerline.
    # 2. The centerline/segmentation image contains many pixels per slice with values {0,1}.
    # We take all the points and approximate the centerline on all these points.
    
    x_seg_start, y_seg_start = (data[:,:,0]>0).nonzero()
    x_seg_end, y_seg_end = (data[:,:,-1]>0).nonzero()
# REMOVED: 2014-07-18
    # check if centerline covers all the image
#    if len(x_seg_start)==0 or len(x_seg_end)==0:
#        print '\nERROR: centerline/segmentation must cover all "z" slices of the input image.\n' \
#              'To solve the problem, you need to crop the input image (you can use \'sct_crop_image\') and generate one' \
#              'more time the spinal cord centerline/segmentation from this cropped image.\n'
#        usage()
    
    # X, Y, Z = ((data<1)*(data>0)).nonzero() # X is empty if binary image
    # if (len(X) > 0): # Scenario 1
    #     for iz in range(0, nz, 1):
    #         x_centerline[iz], y_centerline[iz] = numpy.unravel_index(data[:,:,iz].argmax(), data[:,:,iz].shape)
    # else: # Scenario 2
    #     for iz in range(0, nz, 1):
    #         x_seg, y_seg = (data[:,:,iz]>0).nonzero()
    #         x_centerline[iz] = numpy.mean(x_seg)
    #         y_centerline[iz] = numpy.mean(y_seg)
    # # TODO: find a way to do the previous loop with this, which is more neat:
    # # [numpy.unravel_index(data[:,:,iz].argmax(), data[:,:,iz].shape) for iz in range(0,nz,1)]

    # get center of mass of the centerline/segmentation
    print '\nGet center of mass of the centerline/segmentation...'
    for iz in range(0, nz, 1):
        x_centerline[iz], y_centerline[iz] = ndimage.measurements.center_of_mass(numpy.array(data[:,:,iz]))

    # clear variable
    del data

    # Fit the centerline points with the kind of curve given as argument of the script and return the new fitted coordinates
    if centerline_fitting == 'splines':
        x_centerline_fit, y_centerline_fit, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = msct_smooth.b_spline_nurbs(x_centerline,y_centerline,z_centerline)
        #x_centerline_fit, y_centerline_fit, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = b_spline_centerline(x_centerline,y_centerline,z_centerline)
    elif centerline_fitting == 'polynomial':
        x_centerline_fit, y_centerline_fit, polyx, polyy = polynome_centerline(x_centerline,y_centerline,z_centerline)

    if verbose == 2:
        # plot centerline
        ax = plt.subplot(1,2,1)
        plt.plot(x_centerline, z_centerline, 'b:', label='centerline')
        plt.plot(x_centerline_fit, z_centerline, 'r-', label='fit')
        plt.xlabel('x')
        plt.ylabel('z')
        ax = plt.subplot(1,2,2)
        plt.plot(y_centerline, z_centerline, 'b:', label='centerline')
        plt.plot(y_centerline_fit, z_centerline, 'r-', label='fit')
        plt.xlabel('y')
        plt.ylabel('z')
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles, labels)
        plt.show()

    
    # Get coordinates of landmarks along curved centerline
    #==========================================================================================
    print '\nGet coordinates of landmarks along curved centerline...'
    # landmarks are created along the curved centerline every z=gapz. They consist of a "cross" of size gapx and gapy.
    
    # find derivative of polynomial
    step_z = round(nz/gapz)
    #iz_curved = [i for i in range (0, nz, gapz)]
    iz_curved = [i*step_z for i in range (0, gapz)]
    iz_curved.append(nz-1)     
    #print iz_curved, len(iz_curved)
    n_iz_curved = len(iz_curved)
    #print n_iz_curved
    landmark_curved = [ [ [ 0 for i in range(0,3)] for i in range(0,5) ] for i in iz_curved ]
    # print x_centerline_deriv,len(x_centerline_deriv)
    # landmark[a][b][c]
    #   a: index along z. E.g., the first cross with have index=0, the next index=1, and so on...
    #   b: index of element on the cross. I.e., 0: center of the cross, 1: +x, 2 -x, 3: +y, 4: -y
    #   c: dimension, i.e., 0: x, 1: y, 2: z
    # loop across index, which corresponds to iz (points along the centerline)
    
    if centerline_fitting=='polynomial':
        for index in range(0, n_iz_curved, 1):
            # set coordinates for landmark at the center of the cross
            landmark_curved[index][0][0], landmark_curved[index][0][1], landmark_curved[index][0][2] = x_centerline_fit[iz_curved[index]], y_centerline_fit[iz_curved[index]], iz_curved[index]
            # set x and z coordinates for landmarks +x and -x
            landmark_curved[index][1][2], landmark_curved[index][1][0], landmark_curved[index][2][2], landmark_curved[index][2][0] = get_points_perpendicular_to_curve(polyx, polyx.deriv(), iz_curved[index], gapxy)
            # set y coordinate to y_centerline_fit[iz] for elements 1 and 2 of the cross
            for i in range(1,3):
                landmark_curved[index][i][1] = y_centerline_fit[iz_curved[index]]
            # set coordinates for landmarks +y and -y. Here, x coordinate is 0 (already initialized).
            landmark_curved[index][3][2], landmark_curved[index][3][1], landmark_curved[index][4][2], landmark_curved[index][4][1] = get_points_perpendicular_to_curve(polyy, polyy.deriv(), iz_curved[index], gapxy)
            # set x coordinate to x_centerline_fit[iz] for elements 3 and 4 of the cross
            for i in range(3,5):
                landmark_curved[index][i][0] = x_centerline_fit[iz_curved[index]]
    
    elif centerline_fitting=='splines':
        for index in range(0, n_iz_curved, 1):
            # calculate d (ax+by+cz+d=0)
            # print iz_curved[index]
            a=x_centerline_deriv[iz_curved[index]]
            b=y_centerline_deriv[iz_curved[index]]
            c=z_centerline_deriv[iz_curved[index]]
            x=x_centerline_fit[iz_curved[index]]
            y=y_centerline_fit[iz_curved[index]]
            z=iz_curved[index]
            d=-(a*x+b*y+c*z)
            #print a,b,c,d,x,y,z
            # set coordinates for landmark at the center of the cross
            landmark_curved[index][0][0], landmark_curved[index][0][1], landmark_curved[index][0][2] = x_centerline_fit[iz_curved[index]], y_centerline_fit[iz_curved[index]], iz_curved[index]
            
            # set y coordinate to y_centerline_fit[iz] for elements 1 and 2 of the cross
            for i in range(1,3):
                landmark_curved[index][i][1] = y_centerline_fit[iz_curved[index]]
            
            # set x and z coordinates for landmarks +x and -x, forcing de landmark to be in the orthogonal plan and the distance landmark/curve to be gapxy
            x_n=Symbol('x_n')
            landmark_curved[index][2][0],landmark_curved[index][1][0]=solve((x_n-x)**2+((-1/c)*(a*x_n+b*y+d)-z)**2-gapxy**2,x_n)  #x for -x and +x
            landmark_curved[index][1][2]=(-1/c)*(a*landmark_curved[index][1][0]+b*y+d)  #z for +x
            landmark_curved[index][2][2]=(-1/c)*(a*landmark_curved[index][2][0]+b*y+d)  #z for -x
            
            # set x coordinate to x_centerline_fit[iz] for elements 3 and 4 of the cross
            for i in range(3,5):
                landmark_curved[index][i][0] = x_centerline_fit[iz_curved[index]]
            
            # set coordinates for landmarks +y and -y. Here, x coordinate is 0 (already initialized).
            y_n=Symbol('y_n')
            landmark_curved[index][4][1],landmark_curved[index][3][1]=solve((y_n-y)**2+((-1/c)*(a*x+b*y_n+d)-z)**2-gapxy**2,y_n)  #y for -y and +y
            landmark_curved[index][3][2]=(-1/c)*(a*x+b*landmark_curved[index][3][1]+d)#z for +y
            landmark_curved[index][4][2]=(-1/c)*(a*x+b*landmark_curved[index][4][1]+d)#z for -y
    
    
#    #display
#    fig = plt.figure()
#    ax = fig.add_subplot(111, projection='3d')
#    ax.plot(x_centerline_fit, y_centerline_fit,z_centerline, 'g')
#    ax.plot(x_centerline, y_centerline,z_centerline, 'r')
#    ax.plot([landmark_curved[i][j][0] for i in range(0, n_iz_curved) for j in range(0, 5)], \
#           [landmark_curved[i][j][1] for i in range(0, n_iz_curved) for j in range(0, 5)], \
#           [landmark_curved[i][j][2] for i in range(0, n_iz_curved) for j in range(0, 5)], '.')
#    ax.set_xlabel('x')
#    ax.set_ylabel('y')
#    ax.set_zlabel('z')
#    plt.show()

    # Get coordinates of landmarks along straight centerline
    #==========================================================================================
    print '\nGet coordinates of landmarks along straight centerline...'
    landmark_straight = [ [ [ 0 for i in range(0,3)] for i in range (0,5) ] for i in iz_curved ] # same structure as landmark_curved
    
    # calculate the z indices corresponding to the Euclidean distance between two consecutive points on the curved centerline (approximation curve --> line)
    iz_straight = [0 for i in range (0,gapz+1)]
    #print iz_straight,len(iz_straight)
    for index in range(1, n_iz_curved, 1):
        # compute vector between two consecutive points on the curved centerline
        vector_centerline = [x_centerline_fit[iz_curved[index]] - x_centerline_fit[iz_curved[index-1]], \
                             y_centerline_fit[iz_curved[index]] - y_centerline_fit[iz_curved[index-1]], \
                             iz_curved[index] - iz_curved[index-1]]
        # compute norm of this vector
        norm_vector_centerline = numpy.linalg.norm(vector_centerline, ord=2)
        # round to closest integer value
        norm_vector_centerline_rounded = int(round(norm_vector_centerline,0))
        # assign this value to the current z-coordinate on the straight centerline
        iz_straight[index] = iz_straight[index-1] + norm_vector_centerline_rounded
    
    # initialize x0 and y0 to be at the center of the FOV
    x0 = int(round(nx/2))
    y0 = int(round(ny/2))
    for index in range(0, n_iz_curved, 1):
        # set coordinates for landmark at the center of the cross
        landmark_straight[index][0][0], landmark_straight[index][0][1], landmark_straight[index][0][2] = x0, y0, iz_straight[index]
        # set x, y and z coordinates for landmarks +x
        landmark_straight[index][1][0], landmark_straight[index][1][1], landmark_straight[index][1][2] = x0 + gapxy, y0, iz_straight[index]
        # set x, y and z coordinates for landmarks -x
        landmark_straight[index][2][0], landmark_straight[index][2][1], landmark_straight[index][2][2] = x0-gapxy, y0, iz_straight[index]
        # set x, y and z coordinates for landmarks +y
        landmark_straight[index][3][0], landmark_straight[index][3][1], landmark_straight[index][3][2] = x0, y0+gapxy, iz_straight[index]
        # set x, y and z coordinates for landmarks -y
        landmark_straight[index][4][0], landmark_straight[index][4][1], landmark_straight[index][4][2] = x0, y0-gapxy, iz_straight[index]
    
    # # display
    # fig = plt.figure()
    # ax = fig.add_subplot(111, projection='3d')
    # #ax.plot(x_centerline_fit, y_centerline_fit,z_centerline, 'r')
    # ax.plot([landmark_straight[i][j][0] for i in range(0, n_iz_curved) for j in range(0, 5)], \
    #        [landmark_straight[i][j][1] for i in range(0, n_iz_curved) for j in range(0, 5)], \
    #        [landmark_straight[i][j][2] for i in range(0, n_iz_curved) for j in range(0, 5)], '.')
    # ax.set_xlabel('x')
    # ax.set_ylabel('y')
    # ax.set_zlabel('z')
    # plt.show()
    #
    
    # Create NIFTI volumes with landmarks
    #==========================================================================================
    # Pad input volume to deal with the fact that some landmarks on the curved centerline might be outside the FOV
    # N.B. IT IS VERY IMPORTANT TO PAD ALSO ALONG X and Y, OTHERWISE SOME LANDMARKS MIGHT GET OUT OF THE FOV!!!
    print '\nPad input volume to deal with the fact that some landmarks on the curved centerline might be outside the FOV...'
    sct.run('c3d '+fname_centerline_orient+' -pad '+str(padding)+'x'+str(padding)+'x'+str(padding)+'vox '+str(padding)+'x'+str(padding)+'x'+str(padding)+'vox 0 -o tmp.centerline_pad.nii.gz')
    
    # TODO: don't pad input volume: no need for that! instead, try to increase size of hdr when saving landmarks.
    
    # Open padded centerline for reading
    print '\nOpen padded centerline for reading...'
    file = nibabel.load('tmp.centerline_pad.nii.gz')
    data = file.get_data()
    hdr = file.get_header()
    
    # Create volumes containing curved and straight landmarks
    data_curved_landmarks = data * 0
    data_straight_landmarks = data * 0
    # initialize landmark value
    landmark_value = 1
    # Loop across cross index
    for index in range(0, n_iz_curved, 1):
        # loop across cross element index
        for i_element in range(0, 5, 1):
            # get x, y and z coordinates of curved landmark (rounded to closest integer)
            x, y, z = int(round(landmark_curved[index][i_element][0])), int(round(landmark_curved[index][i_element][1])), int(round(landmark_curved[index][i_element][2]))
            # attribute landmark_value to the voxel and its neighbours
            data_curved_landmarks[x+padding-1:x+padding+2, y+padding-1:y+padding+2, z+padding-1:z+padding+2] = landmark_value
            # get x, y and z coordinates of straight landmark (rounded to closest integer)
            x, y, z = int(round(landmark_straight[index][i_element][0])), int(round(landmark_straight[index][i_element][1])), int(round(landmark_straight[index][i_element][2]))
            # attribute landmark_value to the voxel and its neighbours
            data_straight_landmarks[x+padding-1:x+padding+2, y+padding-1:y+padding+2, z+padding-1:z+padding+2] = landmark_value
            # increment landmark value
            landmark_value = landmark_value + 1
    
    # Write NIFTI volumes
    hdr.set_data_dtype('uint32') # set imagetype to uint8 #TODO: maybe use int32
    print '\nWrite NIFTI volumes...'
    img = nibabel.Nifti1Image(data_curved_landmarks, None, hdr)
    nibabel.save(img, 'tmp.landmarks_curved.nii.gz')
    print '.. File created: tmp.landmarks_curved.nii.gz'
    img = nibabel.Nifti1Image(data_straight_landmarks, None, hdr)
    nibabel.save(img, 'tmp.landmarks_straight.nii.gz')
    print '.. File created: tmp.landmarks_straight.nii.gz'
    
    
    # Estimate deformation field by pairing landmarks
    #==========================================================================================
    
    # Dilate landmarks (because nearest neighbour interpolation will be used later, some landmarks may "disappear" if they are single points)
    #print '\nDilate landmarks...'
    #sct.run(fsloutput+'fslmaths tmp.landmarks_curved.nii -kernel box 3x3x3 -dilD tmp.landmarks_curved_dilated -odt short')
    #sct.run(fsloutput+'fslmaths tmp.landmarks_straight.nii -kernel box 3x3x3 -dilD tmp.landmarks_straight_dilated -odt short')
    
    # Estimate rigid transformation
    print '\nEstimate rigid transformation between paired landmarks...'
    sct.run('ANTSUseLandmarkImagesToGetAffineTransform tmp.landmarks_straight.nii.gz tmp.landmarks_curved.nii.gz rigid tmp.curve2straight_rigid.txt')
    
    # Apply rigid transformation
    print '\nApply rigid transformation to curved landmarks...'
    sct.run('WarpImageMultiTransform 3 tmp.landmarks_curved.nii.gz tmp.landmarks_curved_rigid.nii.gz -R tmp.landmarks_straight.nii.gz tmp.curve2straight_rigid.txt --use-NN')
    
    # Estimate b-spline transformation curve --> straight
    print '\nEstimate b-spline transformation: curve --> straight...'
    sct.run('ANTSUseLandmarkImagesToGetBSplineDisplacementField tmp.landmarks_straight.nii.gz tmp.landmarks_curved_rigid.nii.gz tmp.warp_curve2straight.nii.gz 5x5x5 3 2 0')
    
    # Concatenate rigid and non-linear transformations...
    print '\nConcatenate rigid and non-linear transformations...'
    #sct.run('ComposeMultiTransform 3 tmp.warp_rigid.nii -R tmp.landmarks_straight.nii tmp.warp.nii tmp.curve2straight_rigid.txt')
    # TODO: use sct.run() when output from the following command will be different from 0 (currently there seem to be a bug)
    cmd = 'ComposeMultiTransform 3 tmp.curve2straight.nii.gz -R tmp.landmarks_straight.nii.gz tmp.warp_curve2straight.nii.gz tmp.curve2straight_rigid.txt'
    print('>> '+cmd)
    commands.getstatusoutput(cmd)
    
    # Estimate b-spline transformation straight --> curve
    # TODO: invert warping field instead of estimating a new one
    print '\nEstimate b-spline transformation: straight --> curve...'
    sct.run('ANTSUseLandmarkImagesToGetBSplineDisplacementField tmp.landmarks_curved_rigid.nii.gz tmp.landmarks_straight.nii.gz tmp.warp_straight2curve.nii.gz 5x5x5 3 2 0')
    
    # Concatenate rigid and non-linear transformations...
    print '\nConcatenate rigid and non-linear transformations...'
    #sct.run('ComposeMultiTransform 3 tmp.warp_rigid.nii -R tmp.landmarks_straight.nii tmp.warp.nii tmp.curve2straight_rigid.txt')
    # TODO: use sct.run() when output from the following command will be different from 0 (currently there seem to be a bug)
    cmd = 'ComposeMultiTransform 3 tmp.straight2curve.nii.gz -R tmp.landmarks_straight.nii.gz -i tmp.curve2straight_rigid.txt tmp.warp_straight2curve.nii.gz'
    print('>> '+cmd)
    commands.getstatusoutput(cmd)
    
    #print '\nPad input image...'
    #sct.run('c3d '+fname_anat+' -pad '+str(padz)+'x'+str(padz)+'x'+str(padz)+'vox '+str(padz)+'x'+str(padz)+'x'+str(padz)+'vox 0 -o tmp.anat_pad.nii')
    
    # Unpad landmarks...
    # THIS WAS REMOVED ON 2014-06-03 because the output data was cropped at the edge, which caused landmarks to sometimes disappear
    # print '\nUnpad landmarks...'
    # sct.run('fslroi tmp.landmarks_straight.nii.gz tmp.landmarks_straight_crop.nii.gz '+str(padding)+' '+str(nx)+' '+str(padding)+' '+str(ny)+' '+str(padding)+' '+str(nz))
    
    # Apply deformation to input image
    print '\nApply transformation to input image...'
    sct.run('WarpImageMultiTransform 3 '+file_anat+ext_anat+' tmp.anat_rigid_warp.nii.gz -R tmp.landmarks_straight.nii.gz '+interpolation_warp+ ' tmp.curve2straight.nii.gz')
    # sct.run('WarpImageMultiTransform 3 '+fname_anat+' tmp.anat_rigid_warp.nii.gz -R tmp.landmarks_straight_crop.nii.gz '+interpolation_warp+ ' tmp.curve2straight.nii.gz')
    
    # come back to parent folder
    os.chdir('..')

    # Generate output file (in current folder)
    # TODO: do not uncompress the warping field, it is too time consuming!
    print '\nGenerate output file (in current folder)...'
    sct.generate_output_file(path_tmp+'/tmp.curve2straight.nii.gz','','warp_curve2straight','.nii.gz')  # warping field
    sct.generate_output_file(path_tmp+'/tmp.straight2curve.nii.gz','','warp_straight2curve','.nii.gz')  # warping field
    sct.generate_output_file(path_tmp+'/tmp.anat_rigid_warp.nii.gz','',file_anat+'_straight',ext_anat)  # straightened anatomic

    # Remove temporary files
    if remove_temp_files == 1:
        print('\nRemove temporary files...')
        sct.run('rm -rf '+path_tmp)
    
    print '\nDone!\n'
# Example #3
# 0
def main():

    # Initialization
    fname_anat = ''
    fname_centerline = ''
    gapxy = param.gapxy
    gapz = param.gapz
    padding = param.padding
    centerline_fitting = param.fitting_method
    remove_temp_files = param.remove_temp_files
    verbose = param.verbose
    interpolation_warp = param.interpolation_warp

    # get path of the toolbox
    status, path_sct = commands.getstatusoutput('echo $SCT_DIR')
    print path_sct
    # extract path of the script
    path_script = os.path.dirname(__file__) + '/'

    # Parameters for debug mode
    if param.debug == 1:
        print '\n*** WARNING: DEBUG MODE ON ***\n'
        # fname_anat = path_sct+'/testing/data/errsm_23/t2/t2.nii.gz'
        # fname_centerline = path_sct+'/testing/data/errsm_23/t2/t2_segmentation_PropSeg.nii.gz'
        fname_anat = '/home/django/jtouati/data/cover_z_slices/errsm13_t2.nii.gz'
        fname_centerline = '/home/django/jtouati/data/cover_z_slices/segmentation_centerline_binary.nii.gz'
        remove_temp_files = 0
        centerline_fitting = 'splines'
        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d import Axes3D
        verbose = 2

    # Check input param
    try:
        opts, args = getopt.getopt(sys.argv[1:], 'hi:c:r:w:f:v:')
    except getopt.GetoptError as err:
        print str(err)
        usage()
    for opt, arg in opts:
        if opt == '-h':
            usage()
        elif opt in ('-i'):
            fname_anat = arg
        elif opt in ('-c'):
            fname_centerline = arg
        elif opt in ('-r'):
            remove_temp_files = int(arg)
        elif opt in ('-w'):
            interpolation_warp = str(arg)
        elif opt in ('-f'):
            centerline_fitting = str(arg)
        elif opt in ('-v'):
            verbose = int(arg)

    # display usage if a mandatory argument is not provided
    if fname_anat == '' or fname_centerline == '':
        usage()

    # Display usage if optional arguments are not correctly provided
    if centerline_fitting == '':
        centerline_fitting = 'splines'
    elif not centerline_fitting == '' and not centerline_fitting == 'splines' and not centerline_fitting == 'polynomial':
        print '\n \n -f argument is not valid \n \n'
        usage()

    # check existence of input files
    sct.check_file_exist(fname_anat)
    sct.check_file_exist(fname_centerline)

    # check interp method
    if interpolation_warp == 'spline':
        interpolation_warp_ants = '--use-BSpline'
    elif interpolation_warp == 'trilinear':
        interpolation_warp_ants = ''
    elif interpolation_warp == 'nearestneighbor':
        interpolation_warp_ants = '--use-NN'
    else:
        print '\WARNING: Interpolation method not recognized. Using: ' + param.interpolation_warp
        interpolation_warp_ants = '--use-BSpline'

    # Display arguments
    print '\nCheck input arguments...'
    print '  Input volume ...................... ' + fname_anat
    print '  Centerline ........................ ' + fname_centerline
    print '  Centerline fitting option ......... ' + centerline_fitting
    print '  Final interpolation ............... ' + interpolation_warp
    print '  Verbose ........................... ' + str(verbose)
    print ''

    # if verbose 2, import matplotlib
    if verbose == 2:
        import matplotlib.pyplot as plt

    # Extract path/file/extension
    path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)
    path_centerline, file_centerline, ext_centerline = sct.extract_fname(
        fname_centerline)

    # create temporary folder
    path_tmp = 'tmp.' + time.strftime("%y%m%d%H%M%S")
    sct.run('mkdir ' + path_tmp)

    # copy files into tmp folder
    sct.run('cp ' + fname_anat + ' ' + path_tmp)
    sct.run('cp ' + fname_centerline + ' ' + path_tmp)

    # go to tmp folder
    os.chdir(path_tmp)

    # Open centerline
    #==========================================================================================
    # Change orientation of the input centerline into RPI
    print '\nOrient centerline to RPI orientation...'
    fname_centerline_orient = 'tmp.centerline_rpi' + ext_centerline
    sct.run('sct_orientation -i ' + file_centerline + ext_centerline + ' -o ' +
            fname_centerline_orient + ' -orientation RPI')

    print '\nGet dimensions of input centerline...'
    nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(fname_centerline_orient)
    print '.. matrix size: ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz)
    print '.. voxel size:  ' + str(px) + 'mm x ' + str(py) + 'mm x ' + str(
        pz) + 'mm'

    print '\nOpen centerline volume...'
    file = nibabel.load(fname_centerline_orient)
    data = file.get_data()

    # loop across z and associate x,y coordinate with the point having maximum intensity
    x_centerline = [0 for iz in range(0, nz, 1)]
    y_centerline = [0 for iz in range(0, nz, 1)]
    z_centerline = [iz for iz in range(0, nz, 1)]
    x_centerline_deriv = [0 for iz in range(0, nz, 1)]
    y_centerline_deriv = [0 for iz in range(0, nz, 1)]
    z_centerline_deriv = [0 for iz in range(0, nz, 1)]

    # Two possible scenario:
    # 1. the centerline is probabilistic: each slice contains voxels with the probability of containing the centerline [0:...:1]
    # We only take the maximum value of the image to approximate the centerline.
    # 2. The centerline/segmentation image contains many pixels per slice with values {0,1}.
    # We take all the points and approximate the centerline on all these points.
    #
    # x_seg_start, y_seg_start = (data[:,:,0]>0).nonzero()
    # x_seg_end, y_seg_end = (data[:,:,-1]>0).nonzero()
    # REMOVED: 2014-07-18
    # check if centerline covers all the image
    #    if len(x_seg_start)==0 or len(x_seg_end)==0:
    #        print '\nERROR: centerline/segmentation must cover all "z" slices of the input image.\n' \
    #              'To solve the problem, you need to crop the input image (you can use \'sct_crop_image\') and generate one' \
    #              'more time the spinal cord centerline/segmentation from this cropped image.\n'
    #        usage()
    #
    # X, Y, Z = ((data<1)*(data>0)).nonzero() # X is empty if binary image
    # if (len(X) > 0): # Scenario 1
    #     for iz in range(0, nz, 1):
    #         x_centerline[iz], y_centerline[iz] = numpy.unravel_index(data[:,:,iz].argmax(), data[:,:,iz].shape)
    # else: # Scenario 2
    #     for iz in range(0, nz, 1):
    #         print (data[:,:,iz]>0).nonzero()
    #         x_seg, y_seg = (data[:,:,iz]>0).nonzero()
    #         x_centerline[iz] = numpy.mean(x_seg)
    #         y_centerline[iz] = numpy.mean(y_seg)
    # # TODO: find a way to do the previous loop with this, which is more neat:
    # # [numpy.unravel_index(data[:,:,iz].argmax(), data[:,:,iz].shape) for iz in range(0,nz,1)]

    # get center of mass of the centerline/segmentation
    print '\nGet center of mass of the centerline/segmentation...'
    for iz in range(0, nz, 1):
        x_centerline[iz], y_centerline[
            iz] = ndimage.measurements.center_of_mass(
                numpy.array(data[:, :, iz]))

    #print len(x_centerline),len(y_centerline)
    #print len((numpy.array(x_centerline)>=0).nonzero()[0]),len((numpy.array(y_centerline)>=0).nonzero()[0])

    x_seg_start, y_seg_start = (data[:, :, 0] > 0).nonzero()
    x_seg_end, y_seg_end = (data[:, :, -1] > 0).nonzero()

    #check if centerline covers all the image
    if len(x_seg_start) == 0 or len(x_seg_end) == 0:
        sct.printv(
            '\nWARNING : the centerline/segmentation you gave does not cover all "z" slices of the input image. Results should be improved if you crop the input image (you can use \'sct_crop_image\') and generate a new spinalcord centerline/segmentation from this cropped image.\n',
            1, 'warning')
        # print '\nWARNING : the centerline/segmentation you gave does not cover all "z" slices of the input image.\n' \
        #       'Results should be improved if you crop the input image (you can use \'sct_crop_image\') and generate\n'\
        #       'a new spinalcord centerline/segmentation from this cropped image.\n'
        #print len((numpy.array(x_centerline)>=0).nonzero()[0]),len((numpy.array(y_centerline)>=0).nonzero()[0])
        min_centerline = min((numpy.array(x_centerline) >= 0).nonzero()[0])
        max_centerline = max((numpy.array(x_centerline) >= 0).nonzero()[0])
        z_centerline = z_centerline[(min_centerline):(max_centerline + 1)]
        #print len(z_centerline)
        nz = len(z_centerline)
        x_centerline = [x for x in x_centerline if not isnan(x)]
        y_centerline = [y for y in y_centerline if not isnan(y)]
        #print len(x_centerline),len(y_centerline)

    # clear variable
    del data

    # Fit the centerline points with the kind of curve given as argument of the script and return the new fitted coordinates
    if centerline_fitting == 'splines':
        x_centerline_fit, y_centerline_fit, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = msct_smooth.b_spline_nurbs(
            x_centerline, y_centerline, z_centerline)
        #x_centerline_fit, y_centerline_fit, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = b_spline_centerline(x_centerline,y_centerline,z_centerline)
    elif centerline_fitting == 'polynomial':
        x_centerline_fit, y_centerline_fit, polyx, polyy = polynome_centerline(
            x_centerline, y_centerline, z_centerline)
        #numpy.interp([i for i in xrange(0,min_centerline+1)],
        #y_centerline_fit

    #print z_centerline

    if verbose == 2:
        # plot centerline
        ax = plt.subplot(1, 2, 1)
        plt.plot(x_centerline, z_centerline, 'b:', label='centerline')
        plt.plot(x_centerline_fit, z_centerline, 'r-', label='fit')
        plt.xlabel('x')
        plt.ylabel('z')
        ax = plt.subplot(1, 2, 2)
        plt.plot(y_centerline, z_centerline, 'b:', label='centerline')
        plt.plot(y_centerline_fit, z_centerline, 'r-', label='fit')
        plt.xlabel('y')
        plt.ylabel('z')
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles, labels)
        plt.show()

    # Get coordinates of landmarks along curved centerline
    #==========================================================================================
    print '\nGet coordinates of landmarks along curved centerline...'
    # landmarks are created along the curved centerline every z=gapz. They consist of a "cross" of size gapx and gapy.
    # find derivative of polynomial
    step_z = round(nz / gapz)
    #iz_curved = [i for i in range (0, nz, gapz)]
    iz_curved = [(min(z_centerline) + i * step_z) for i in range(0, gapz)]
    iz_curved.append(max(z_centerline))
    #print iz_curved, len(iz_curved)
    n_iz_curved = len(iz_curved)
    #print n_iz_curved
    landmark_curved = [[[0 for i in range(0, 3)] for i in range(0, 5)]
                       for i in iz_curved]
    # print x_centerline_deriv,len(x_centerline_deriv)
    # landmark[a][b][c]
    #   a: index along z. E.g., the first cross will have index=0, the next index=1, and so on...
    #   b: index of element on the cross. I.e., 0: center of the cross, 1: +x, 2 -x, 3: +y, 4: -y
    #   c: dimension, i.e., 0: x, 1: y, 2: z
    # loop across index, which corresponds to iz (points along the centerline)

    if centerline_fitting == 'polynomial':
        for index in range(0, n_iz_curved, 1):
            # set coordinates for landmark at the center of the cross
            landmark_curved[index][0][0], landmark_curved[index][0][
                1], landmark_curved[index][0][2] = x_centerline_fit[
                    iz_curved[index]], y_centerline_fit[
                        iz_curved[index]], iz_curved[index]
            # set x and z coordinates for landmarks +x and -x
            landmark_curved[index][1][2], landmark_curved[index][1][
                0], landmark_curved[index][2][2], landmark_curved[index][2][
                    0] = get_points_perpendicular_to_curve(
                        polyx, polyx.deriv(), iz_curved[index], gapxy)
            # set y coordinate to y_centerline_fit[iz] for elements 1 and 2 of the cross
            for i in range(1, 3):
                landmark_curved[index][i][1] = y_centerline_fit[
                    iz_curved[index]]
            # set coordinates for landmarks +y and -y. Here, x coordinate is 0 (already initialized).
            landmark_curved[index][3][2], landmark_curved[index][3][
                1], landmark_curved[index][4][2], landmark_curved[index][4][
                    1] = get_points_perpendicular_to_curve(
                        polyy, polyy.deriv(), iz_curved[index], gapxy)
            # set x coordinate to x_centerline_fit[iz] for elements 3 and 4 of the cross
            for i in range(3, 5):
                landmark_curved[index][i][0] = x_centerline_fit[
                    iz_curved[index]]

    elif centerline_fitting == 'splines':
        for index in range(0, n_iz_curved, 1):
            # calculate d (ax+by+cz+d=0)
            # print iz_curved[index]
            a = x_centerline_deriv[iz_curved[index] - min(z_centerline)]
            b = y_centerline_deriv[iz_curved[index] - min(z_centerline)]
            c = z_centerline_deriv[iz_curved[index] - min(z_centerline)]
            x = x_centerline_fit[iz_curved[index] - min(z_centerline)]
            y = y_centerline_fit[iz_curved[index] - min(z_centerline)]
            z = iz_curved[index]
            d = -(a * x + b * y + c * z)
            #print a,b,c,d,x,y,z
            # set coordinates for landmark at the center of the cross
            landmark_curved[index][0][0], landmark_curved[index][0][
                1], landmark_curved[index][0][2] = x_centerline_fit[
                    iz_curved[index] - min(z_centerline)], y_centerline_fit[
                        iz_curved[index] - min(z_centerline)], iz_curved[index]

            # set y coordinate to y_centerline_fit[iz] for elements 1 and 2 of the cross
            for i in range(1, 3):
                landmark_curved[index][i][1] = y_centerline_fit[
                    iz_curved[index] - min(z_centerline)]

            # set x and z coordinates for landmarks +x and -x, forcing the landmark to be in the orthogonal plane and the distance landmark/curve to be gapxy
            x_n = Symbol('x_n')
            landmark_curved[index][2][0], landmark_curved[index][1][0] = solve(
                (x_n - x)**2 + ((-1 / c) * (a * x_n + b * y + d) - z)**2 -
                gapxy**2, x_n)  #x for -x and +x
            landmark_curved[index][1][2] = (-1 / c) * (
                a * landmark_curved[index][1][0] + b * y + d)  #z for +x
            landmark_curved[index][2][2] = (-1 / c) * (
                a * landmark_curved[index][2][0] + b * y + d)  #z for -x

            # set x coordinate to x_centerline_fit[iz] for elements 3 and 4 of the cross
            for i in range(3, 5):
                landmark_curved[index][i][0] = x_centerline_fit[
                    iz_curved[index] - min(z_centerline)]

            # set y and z coordinates for landmarks +y and -y (their x coordinate was set to the fitted centerline value just above).
            y_n = Symbol('y_n')
            landmark_curved[index][4][1], landmark_curved[index][3][1] = solve(
                (y_n - y)**2 + ((-1 / c) * (a * x + b * y_n + d) - z)**2 -
                gapxy**2, y_n)  #y for -y and +y
            landmark_curved[index][3][2] = (-1 / c) * (
                a * x + b * landmark_curved[index][3][1] + d)  #z for +y
            landmark_curved[index][4][2] = (-1 / c) * (
                a * x + b * landmark_curved[index][4][1] + d)  #z for -y

#    #display
#    fig = plt.figure()
#    ax = fig.add_subplot(111, projection='3d')
#    ax.plot(x_centerline_fit, y_centerline_fit,z_centerline, 'g')
#    ax.plot(x_centerline, y_centerline,z_centerline, 'r')
#    ax.plot([landmark_curved[i][j][0] for i in range(0, n_iz_curved) for j in range(0, 5)], \
#           [landmark_curved[i][j][1] for i in range(0, n_iz_curved) for j in range(0, 5)], \
#           [landmark_curved[i][j][2] for i in range(0, n_iz_curved) for j in range(0, 5)], '.')
#    ax.set_xlabel('x')
#    ax.set_ylabel('y')
#    ax.set_zlabel('z')
#    plt.show()

# Get coordinates of landmarks along straight centerline
#==========================================================================================
    print '\nGet coordinates of landmarks along straight centerline...'
    landmark_straight = [[[0 for i in range(0, 3)] for i in range(0, 5)]
                         for i in iz_curved
                         ]  # same structure as landmark_curved

    # calculate the z indices corresponding to the Euclidean distance between two consecutive points on the curved centerline (approximation curve --> line)
    iz_straight = [(min(z_centerline) + 0) for i in range(0, gapz + 1)]
    #print iz_straight,len(iz_straight)
    for index in range(1, n_iz_curved, 1):
        # compute vector between two consecutive points on the curved centerline
        vector_centerline = [x_centerline_fit[iz_curved[index]-min(z_centerline)] - x_centerline_fit[iz_curved[index-1]-min(z_centerline)], \
                             y_centerline_fit[iz_curved[index]-min(z_centerline)] - y_centerline_fit[iz_curved[index-1]-min(z_centerline)], \
                             iz_curved[index] - iz_curved[index-1]]
        # compute norm of this vector
        norm_vector_centerline = numpy.linalg.norm(vector_centerline, ord=2)
        # round to closest integer value
        norm_vector_centerline_rounded = int(round(norm_vector_centerline, 0))
        # assign this value to the current z-coordinate on the straight centerline
        iz_straight[index] = iz_straight[index -
                                         1] + norm_vector_centerline_rounded

    # initialize x0 and y0 to be at the center of the FOV
    x0 = int(round(nx / 2))
    y0 = int(round(ny / 2))
    for index in range(0, n_iz_curved, 1):
        # set coordinates for landmark at the center of the cross
        landmark_straight[index][0][0], landmark_straight[index][0][
            1], landmark_straight[index][0][2] = x0, y0, iz_straight[index]
        # set x, y and z coordinates for landmarks +x
        landmark_straight[index][1][0], landmark_straight[index][1][
            1], landmark_straight[index][1][2] = x0 + gapxy, y0, iz_straight[
                index]
        # set x, y and z coordinates for landmarks -x
        landmark_straight[index][2][0], landmark_straight[index][2][
            1], landmark_straight[index][2][2] = x0 - gapxy, y0, iz_straight[
                index]
        # set x, y and z coordinates for landmarks +y
        landmark_straight[index][3][0], landmark_straight[index][3][
            1], landmark_straight[index][3][2] = x0, y0 + gapxy, iz_straight[
                index]
        # set x, y and z coordinates for landmarks -y
        landmark_straight[index][4][0], landmark_straight[index][4][
            1], landmark_straight[index][4][2] = x0, y0 - gapxy, iz_straight[
                index]

    # # display
    # fig = plt.figure()
    # ax = fig.add_subplot(111, projection='3d')
    # #ax.plot(x_centerline_fit, y_centerline_fit,z_centerline, 'r')
    # ax.plot([landmark_straight[i][j][0] for i in range(0, n_iz_curved) for j in range(0, 5)], \
    #        [landmark_straight[i][j][1] for i in range(0, n_iz_curved) for j in range(0, 5)], \
    #        [landmark_straight[i][j][2] for i in range(0, n_iz_curved) for j in range(0, 5)], '.')
    # ax.set_xlabel('x')
    # ax.set_ylabel('y')
    # ax.set_zlabel('z')
    # plt.show()
    #

    # Create NIFTI volumes with landmarks
    #==========================================================================================
    # Pad input volume to deal with the fact that some landmarks on the curved centerline might be outside the FOV
    # N.B. IT IS VERY IMPORTANT TO PAD ALSO ALONG X and Y, OTHERWISE SOME LANDMARKS MIGHT GET OUT OF THE FOV!!!
    print '\nPad input volume to deal with the fact that some landmarks on the curved centerline might be outside the FOV...'
    sct.run('isct_c3d ' + fname_centerline_orient + ' -pad ' + str(padding) +
            'x' + str(padding) + 'x' + str(padding) + 'vox ' + str(padding) +
            'x' + str(padding) + 'x' + str(padding) +
            'vox 0 -o tmp.centerline_pad.nii.gz')

    # TODO: don't pad input volume: no need for that! instead, try to increase size of hdr when saving landmarks.

    # Open padded centerline for reading
    print '\nOpen padded centerline for reading...'
    file = nibabel.load('tmp.centerline_pad.nii.gz')
    data = file.get_data()
    hdr = file.get_header()

    # Create volumes containing curved and straight landmarks
    data_curved_landmarks = data * 0
    data_straight_landmarks = data * 0
    # initialize landmark value
    landmark_value = 1
    # Loop across cross index
    for index in range(0, n_iz_curved, 1):
        # loop across cross element index
        for i_element in range(0, 5, 1):
            # get x, y and z coordinates of curved landmark (rounded to closest integer)
            x, y, z = int(round(landmark_curved[index][i_element][0])), int(
                round(landmark_curved[index][i_element][1])), int(
                    round(landmark_curved[index][i_element][2]))
            # attribute landmark_value to the voxel and its neighbours
            data_curved_landmarks[x + padding - 1:x + padding + 2,
                                  y + padding - 1:y + padding + 2, z +
                                  padding - 1:z + padding + 2] = landmark_value
            # get x, y and z coordinates of straight landmark (rounded to closest integer)
            x, y, z = int(round(landmark_straight[index][i_element][0])), int(
                round(landmark_straight[index][i_element][1])), int(
                    round(landmark_straight[index][i_element][2]))
            # attribute landmark_value to the voxel and its neighbours
            data_straight_landmarks[x + padding - 1:x + padding + 2,
                                    y + padding - 1:y + padding + 2,
                                    z + padding - 1:z + padding +
                                    2] = landmark_value
            # increment landmark value
            landmark_value = landmark_value + 1

    # Write NIFTI volumes
    hdr.set_data_dtype(
        'uint32')  # set imagetype to uint8 #TODO: maybe use int32
    print '\nWrite NIFTI volumes...'
    img = nibabel.Nifti1Image(data_curved_landmarks, None, hdr)
    nibabel.save(img, 'tmp.landmarks_curved.nii.gz')
    print '.. File created: tmp.landmarks_curved.nii.gz'
    img = nibabel.Nifti1Image(data_straight_landmarks, None, hdr)
    nibabel.save(img, 'tmp.landmarks_straight.nii.gz')
    print '.. File created: tmp.landmarks_straight.nii.gz'

    # Estimate deformation field by pairing landmarks
    #==========================================================================================

    # Dilate landmarks (because nearest neighbour interpolation will be later used, therefore some landmarks may "disappear" if they are single points)
    #print '\nDilate landmarks...'
    #sct.run(fsloutput+'fslmaths tmp.landmarks_curved.nii -kernel box 3x3x3 -dilD tmp.landmarks_curved_dilated -odt short')
    #sct.run(fsloutput+'fslmaths tmp.landmarks_straight.nii -kernel box 3x3x3 -dilD tmp.landmarks_straight_dilated -odt short')

    # Estimate rigid transformation
    print '\nEstimate rigid transformation between paired landmarks...'
    sct.run(
        'isct_ANTSUseLandmarkImagesToGetAffineTransform tmp.landmarks_straight.nii.gz tmp.landmarks_curved.nii.gz rigid tmp.curve2straight_rigid.txt'
    )

    # Apply rigid transformation
    print '\nApply rigid transformation to curved landmarks...'
    sct.run(
        'sct_WarpImageMultiTransform 3 tmp.landmarks_curved.nii.gz tmp.landmarks_curved_rigid.nii.gz -R tmp.landmarks_straight.nii.gz tmp.curve2straight_rigid.txt --use-NN'
    )

    # Estimate b-spline transformation curve --> straight
    print '\nEstimate b-spline transformation: curve --> straight...'
    sct.run(
        'isct_ANTSUseLandmarkImagesToGetBSplineDisplacementField tmp.landmarks_straight.nii.gz tmp.landmarks_curved_rigid.nii.gz tmp.warp_curve2straight.nii.gz 5x5x5 3 2 0'
    )

    # Concatenate rigid and non-linear transformations...
    print '\nConcatenate rigid and non-linear transformations...'
    #sct.run('isct_ComposeMultiTransform 3 tmp.warp_rigid.nii -R tmp.landmarks_straight.nii tmp.warp.nii tmp.curve2straight_rigid.txt')
    # TODO: use sct.run() when output from the following command will be different from 0 (currently there seem to be a bug)
    cmd = 'isct_ComposeMultiTransform 3 tmp.curve2straight.nii.gz -R tmp.landmarks_straight.nii.gz tmp.warp_curve2straight.nii.gz tmp.curve2straight_rigid.txt'
    print('>> ' + cmd)
    commands.getstatusoutput(cmd)

    # Estimate b-spline transformation straight --> curve
    # TODO: invert warping field instead of estimating a new one
    print '\nEstimate b-spline transformation: straight --> curve...'
    sct.run(
        'isct_ANTSUseLandmarkImagesToGetBSplineDisplacementField tmp.landmarks_curved_rigid.nii.gz tmp.landmarks_straight.nii.gz tmp.warp_straight2curve.nii.gz 5x5x5 3 2 0'
    )

    # Concatenate rigid and non-linear transformations...
    print '\nConcatenate rigid and non-linear transformations...'
    #sct.run('isct_ComposeMultiTransform 3 tmp.warp_rigid.nii -R tmp.landmarks_straight.nii tmp.warp.nii tmp.curve2straight_rigid.txt')
    # TODO: use sct.run() when output from the following command will be different from 0 (currently there seem to be a bug)
    cmd = 'isct_ComposeMultiTransform 3 tmp.straight2curve.nii.gz -R tmp.landmarks_straight.nii.gz -i tmp.curve2straight_rigid.txt tmp.warp_straight2curve.nii.gz'
    print('>> ' + cmd)
    commands.getstatusoutput(cmd)

    #print '\nPad input image...'
    #sct.run('isct_c3d '+fname_anat+' -pad '+str(padz)+'x'+str(padz)+'x'+str(padz)+'vox '+str(padz)+'x'+str(padz)+'x'+str(padz)+'vox 0 -o tmp.anat_pad.nii')

    # Unpad landmarks...
    # THIS WAS REMOVED ON 2014-06-03 because the output data was cropped at the edge, which caused landmarks to sometimes disappear
    # print '\nUnpad landmarks...'
    # sct.run('fslroi tmp.landmarks_straight.nii.gz tmp.landmarks_straight_crop.nii.gz '+str(padding)+' '+str(nx)+' '+str(padding)+' '+str(ny)+' '+str(padding)+' '+str(nz))

    # Apply deformation to input image
    print '\nApply transformation to input image...'
    sct.run('sct_WarpImageMultiTransform 3 ' + file_anat + ext_anat +
            ' tmp.anat_rigid_warp.nii.gz -R tmp.landmarks_straight.nii.gz ' +
            interpolation_warp + ' tmp.curve2straight.nii.gz')
    # sct.run('sct_WarpImageMultiTransform 3 '+fname_anat+' tmp.anat_rigid_warp.nii.gz -R tmp.landmarks_straight_crop.nii.gz '+interpolation_warp+ ' tmp.curve2straight.nii.gz')

    # come back to parent folder
    os.chdir('..')

    # Generate output file (in current folder)
    # TODO: do not uncompress the warping field, it is too time consuming!
    print '\nGenerate output file (in current folder)...'
    sct.generate_output_file(path_tmp + '/tmp.curve2straight.nii.gz', '',
                             'warp_curve2straight', '.nii.gz')  # warping field
    sct.generate_output_file(path_tmp + '/tmp.straight2curve.nii.gz', '',
                             'warp_straight2curve', '.nii.gz')  # warping field
    sct.generate_output_file(path_tmp + '/tmp.anat_rigid_warp.nii.gz', '',
                             file_anat + '_straight',
                             ext_anat)  # straightened anatomic

    # Remove temporary files
    if remove_temp_files == 1:
        print('\nRemove temporary files...')
        sct.run('rm -rf ' + path_tmp)

    print '\nDone!\n'
def smooth_centerline(fname_centerline, algo_fitting="hanning", type_window="hanning", window_length=80, verbose=0):
    """
    Fit a smooth curve through the per-slice centers of mass of a centerline/segmentation.

    For every z slice holding at least one non-zero voxel, the (x, y) center of
    mass is computed; the resulting points are then fitted with the selected
    algorithm.

    :param fname_centerline: centerline in RPI orientation
    :param algo_fitting: "hanning" (x and y smoothed with smoothing_window) or
        "nurbs" (fit delegated to msct_smooth.b_spline_nurbs)
    :param type_window: window type forwarded to smoothing_window ("hanning" fitting only)
    :param window_length: window length, divided by pz (7th value returned by
        sct.get_dimension, presumably the voxel size along z) before being
        forwarded to smoothing_window ("hanning" fitting only)
    :param verbose: verbosity forwarded to sct.printv and to the fitting routines
    :return: tuple (x_centerline_fit, y_centerline_fit, z_centerline_fit,
        x_centerline_deriv, y_centerline_deriv, z_centerline_deriv)
    """
    sct.printv("\nSmooth centerline/segmentation...", verbose)

    # only the number of slices and the three spacing values are used below
    _, _, nz, _, px, py, pz, _ = sct.get_dimension(fname_centerline)

    # read the centerline volume
    image = load(fname_centerline)
    volume = image.get_data()

    # indices of the z slices that contain at least one non-zero voxel
    # N.B. there can be fewer of them than nz when the centerline does not span the whole volume
    z_centerline = [z_index for z_index in range(nz) if volume[:, :, z_index].any()]

    # center of mass of each non-empty slice
    sct.printv(".. Get center of mass of the centerline/segmentation...", verbose)
    x_centerline = []
    y_centerline = []
    for z_index in z_centerline:
        x_com, y_com = ndimage.measurements.center_of_mass(array(volume[:, :, z_index]))
        x_centerline.append(x_com)
        y_centerline.append(y_com)

    sct.printv(".. Smoothing algo = " + algo_fitting, verbose)

    if algo_fitting == "nurbs":
        from msct_smooth import b_spline_nurbs

        fitted = b_spline_nurbs(x_centerline, y_centerline, z_centerline, nbControl=None, verbose=verbose)
        (x_centerline_fit, y_centerline_fit, z_centerline_fit,
         x_centerline_deriv, y_centerline_deriv, z_centerline_deriv) = fitted

    elif algo_fitting == "hanning":
        sct.printv(".. Windows length = " + str(window_length), verbose)

        # same window for both in-plane coordinates
        window_len = window_length / pz
        x_smoothed = smoothing_window(
            asarray(x_centerline), window_len=window_len, window=type_window, verbose=verbose
        )
        y_smoothed = smoothing_window(
            asarray(y_centerline), window_len=window_len, window=type_window, verbose=verbose
        )

        # the derivative routine is fed plain lists
        x_smoothed_list = x_smoothed.tolist()
        y_smoothed_list = y_smoothed.tolist()

        # the voxel data is not needed any more
        del volume

        x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = evaluate_derivative_3D(
            x_smoothed_list, y_smoothed_list, z_centerline, px, py, pz
        )

        # fitted coordinates are handed back as arrays
        x_centerline_fit = asarray(x_smoothed_list)
        y_centerline_fit = asarray(y_smoothed_list)
        z_centerline_fit = asarray(z_centerline)

    else:
        # NOTE(review): presumably sct.printv aborts on type "error"; if it returns,
        # the return statement below fails because the fitted values are unbound - confirm
        sct.printv("ERROR: wrong algorithm for fitting", 1, "error")

    return (
        x_centerline_fit,
        y_centerline_fit,
        z_centerline_fit,
        x_centerline_deriv,
        y_centerline_deriv,
        z_centerline_deriv,
    )