Ejemplo n.º 1
0
def compute_length(fname_segmentation, remove_temp_files, verbose=0):
    from math import sqrt

    # Extract path, file and extension
    fname_segmentation = os.path.abspath(fname_segmentation)
    path_data, file_data, ext_data = sct.extract_fname(fname_segmentation)

    # create temporary folder
    path_tmp = 'tmp.' + time.strftime("%y%m%d%H%M%S")
    sct.run('mkdir ' + path_tmp)

    # copy files into tmp folder
    sct.run('cp ' + fname_segmentation + ' ' + path_tmp)

    # go to tmp folder
    os.chdir(path_tmp)

    # Change orientation of the input centerline into RPI
    sct.printv('\nOrient centerline to RPI orientation...', param.verbose)
    fname_segmentation_orient = 'segmentation_rpi' + ext_data
    set_orientation(file_data + ext_data, 'RPI', fname_segmentation_orient)

    # Get dimension
    sct.printv('\nGet dimensions...', param.verbose)
    nx, ny, nz, nt, px, py, pz, pt = Iamge(fname_segmentation_orient).dim
    sct.printv(
        '.. matrix size: ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz),
        param.verbose)
    sct.printv(
        '.. voxel size:  ' + str(px) + 'mm x ' + str(py) + 'mm x ' + str(pz) +
        'mm', param.verbose)

    # smooth segmentation/centerline
    #x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv,y_centerline_deriv,z_centerline_deriv = smooth_centerline(fname_segmentation_orient, param, 'hanning', 1)
    x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
        fname_segmentation_orient,
        type_window='hanning',
        window_length=80,
        algo_fitting='hanning',
        verbose=verbose)
    # compute length of centerline
    result_length = 0.0
    for i in range(len(x_centerline_fit) - 1):
        result_length += sqrt(
            ((x_centerline_fit[i + 1] - x_centerline_fit[i]) * px)**2 +
            ((y_centerline_fit[i + 1] - y_centerline_fit[i]) * py)**2 +
            ((z_centerline[i + 1] - z_centerline[i]) * pz)**2)

    return result_length
def compute_length(fname_segmentation, remove_temp_files, verbose=0):
    from math import sqrt

    # Extract path, file and extension
    fname_segmentation = os.path.abspath(fname_segmentation)
    path_data, file_data, ext_data = sct.extract_fname(fname_segmentation)

    # create temporary folder
    path_tmp = "tmp." + time.strftime("%y%m%d%H%M%S")
    sct.run("mkdir " + path_tmp)

    # copy files into tmp folder
    sct.run("cp " + fname_segmentation + " " + path_tmp)

    # go to tmp folder
    os.chdir(path_tmp)

    # Change orientation of the input centerline into RPI
    sct.printv("\nOrient centerline to RPI orientation...", param.verbose)
    fname_segmentation_orient = "segmentation_rpi" + ext_data
    set_orientation(file_data + ext_data, "RPI", fname_segmentation_orient)

    # Get dimension
    sct.printv("\nGet dimensions...", param.verbose)
    nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(fname_segmentation_orient)
    sct.printv(".. matrix size: " + str(nx) + " x " + str(ny) + " x " + str(nz), param.verbose)
    sct.printv(".. voxel size:  " + str(px) + "mm x " + str(py) + "mm x " + str(pz) + "mm", param.verbose)

    # smooth segmentation/centerline
    # x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv,y_centerline_deriv,z_centerline_deriv = smooth_centerline(fname_segmentation_orient, param, 'hanning', 1)
    x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
        fname_segmentation_orient, type_window="hanning", window_length=80, algo_fitting="hanning", verbose=verbose
    )
    # compute length of centerline
    result_length = 0.0
    for i in range(len(x_centerline_fit) - 1):
        result_length += sqrt(
            ((x_centerline_fit[i + 1] - x_centerline_fit[i]) * px) ** 2
            + ((y_centerline_fit[i + 1] - y_centerline_fit[i]) * py) ** 2
            + ((z_centerline[i + 1] - z_centerline[i]) * pz) ** 2
        )

    return result_length
def extract_centerline(
    fname_segmentation, remove_temp_files, verbose=0, algo_fitting="hanning", type_window="hanning", window_length=80
):

    # Extract path, file and extension
    fname_segmentation = os.path.abspath(fname_segmentation)
    path_data, file_data, ext_data = sct.extract_fname(fname_segmentation)

    # create temporary folder
    path_tmp = "tmp." + time.strftime("%y%m%d%H%M%S")
    sct.run("mkdir " + path_tmp)

    # copy files into tmp folder
    sct.run("cp " + fname_segmentation + " " + path_tmp)

    # go to tmp folder
    os.chdir(path_tmp)

    # Change orientation of the input centerline into RPI
    sct.printv("\nOrient centerline to RPI orientation...", verbose)
    fname_segmentation_orient = "segmentation_rpi" + ext_data
    set_orientation(file_data + ext_data, "RPI", fname_segmentation_orient)

    # Get dimension
    sct.printv("\nGet dimensions...", verbose)
    nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(fname_segmentation_orient)
    sct.printv(".. matrix size: " + str(nx) + " x " + str(ny) + " x " + str(nz), verbose)
    sct.printv(".. voxel size:  " + str(px) + "mm x " + str(py) + "mm x " + str(pz) + "mm", verbose)

    # Extract orientation of the input segmentation
    orientation = get_orientation(file_data + ext_data)
    sct.printv("\nOrientation of segmentation image: " + orientation, verbose)

    sct.printv("\nOpen segmentation volume...", verbose)
    file = nibabel.load(fname_segmentation_orient)
    data = file.get_data()
    hdr = file.get_header()

    # Extract min and max index in Z direction
    X, Y, Z = (data > 0).nonzero()
    min_z_index, max_z_index = min(Z), max(Z)
    x_centerline = [0 for i in range(0, max_z_index - min_z_index + 1)]
    y_centerline = [0 for i in range(0, max_z_index - min_z_index + 1)]
    z_centerline = [iz for iz in range(min_z_index, max_z_index + 1)]
    # Extract segmentation points and average per slice
    for iz in range(min_z_index, max_z_index + 1):
        x_seg, y_seg = (data[:, :, iz] > 0).nonzero()
        x_centerline[iz - min_z_index] = np.mean(x_seg)
        y_centerline[iz - min_z_index] = np.mean(y_seg)
    for k in range(len(X)):
        data[X[k], Y[k], Z[k]] = 0

    # extract centerline and smooth it
    x_centerline_fit, y_centerline_fit, z_centerline_fit, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
        fname_segmentation_orient,
        type_window=type_window,
        window_length=window_length,
        algo_fitting=algo_fitting,
        verbose=verbose,
    )

    if verbose == 2:
        import matplotlib.pyplot as plt

        # Creation of a vector x that takes into account the distance between the labels
        nz_nonz = len(z_centerline)
        x_display = [0 for i in range(x_centerline_fit.shape[0])]
        y_display = [0 for i in range(y_centerline_fit.shape[0])]
        for i in range(0, nz_nonz, 1):
            x_display[int(z_centerline[i] - z_centerline[0])] = x_centerline[i]
            y_display[int(z_centerline[i] - z_centerline[0])] = y_centerline[i]

        plt.figure(1)
        plt.subplot(2, 1, 1)
        plt.plot(z_centerline_fit, x_display, "ro")
        plt.plot(z_centerline_fit, x_centerline_fit)
        plt.xlabel("Z")
        plt.ylabel("X")
        plt.title("x and x_fit coordinates")

        plt.subplot(2, 1, 2)
        plt.plot(z_centerline_fit, y_display, "ro")
        plt.plot(z_centerline_fit, y_centerline_fit)
        plt.xlabel("Z")
        plt.ylabel("Y")
        plt.title("y and y_fit coordinates")
        plt.show()

    # Create an image with the centerline
    for iz in range(min_z_index, max_z_index + 1):
        data[
            round(x_centerline_fit[iz - min_z_index]), round(y_centerline_fit[iz - min_z_index]), iz
        ] = (
            1
        )  # if index is out of bounds here for hanning: either the segmentation has holes or labels have been added to the file
    # Write the centerline image in RPI orientation
    hdr.set_data_dtype("uint8")  # set imagetype to uint8
    sct.printv("\nWrite NIFTI volumes...", verbose)
    img = nibabel.Nifti1Image(data, None, hdr)
    nibabel.save(img, "centerline.nii.gz")
    sct.generate_output_file("centerline.nii.gz", file_data + "_centerline" + ext_data, verbose)

    # create a txt file with the centerline
    file_name = file_data + "_centerline" + ".txt"
    sct.printv("\nWrite text file...", verbose)
    file_results = open(file_name, "w")
    for i in range(min_z_index, max_z_index + 1):
        file_results.write(
            str(int(i))
            + " "
            + str(x_centerline_fit[i - min_z_index])
            + " "
            + str(y_centerline_fit[i - min_z_index])
            + "\n"
        )
    file_results.close()

    # Copy result into parent folder
    sct.run("cp " + file_name + " ../")

    del data

    # come back to parent folder
    os.chdir("..")

    # Change orientation of the output centerline into input orientation
    sct.printv("\nOrient centerline image to input orientation: " + orientation, verbose)
    fname_segmentation_orient = "tmp.segmentation_rpi" + ext_data
    set_orientation(
        path_tmp + "/" + file_data + "_centerline" + ext_data, orientation, file_data + "_centerline" + ext_data
    )

    # Remove temporary files
    if remove_temp_files:
        sct.printv("\nRemove temporary files...", verbose)
        sct.run("rm -rf " + path_tmp, verbose)

    return file_data + "_centerline" + ext_data
Ejemplo n.º 4
0
def main():
    """Register an anatomical image to the template (command-line entry point).

    Pipeline, as implemented below:
      1. Parse the command line from ``sys.argv`` (or, when ``param.debug`` is
         set, use hard-coded local test paths). Build the multi-step
         registration parameters: two default steps, updated by ``-p``.
      2. Check that the template image, template labels and template
         segmentation exist. Check that the input labels are all different and
         that the largest input label value does not exceed the largest template
         label value.
      3. Copy the inputs and template files into a timestamped temporary folder
         and work from inside it. Resample the data and segmentation to 1 mm
         isotropic and reorient them to RPI.
      4. Straighten the cord using the cropped segmentation. Then estimate a
         landmark-based 'translation-scaling-z' transform from the straightened
         labels to the template labels and write it as an ITK transform file.
      5. Bring the data and segmentation into template space. Crop them, and the
         template, in z around the thresholded segmentation, sub-sample in z, and
         run each registration step of ``paramreg``.
      6. Concatenate all transforms into warp_anat2template.nii.gz and
         warp_template2anat.nii.gz, and apply them when output_type == 1. Copy
         the results to the current directory and delete the temporary folder
         if requested.

    Reads the module-level ``param`` object. Returns None.
    """

    # get default parameters
    # Default two-step registration: step 1 on segmentations, step 2 on images.
    # Values are strings because they are concatenated into help/log text below.
    step1 = Paramreg(step='1',
                     type='seg',
                     algo='slicereg',
                     metric='MeanSquares',
                     iter='10')
    step2 = Paramreg(step='2', type='im', algo='syn', metric='MI', iter='3')
    # step1 = Paramreg()
    paramreg = ParamregMultiStep([step1, step2])

    # step1 = Paramreg_step(step='1', type='seg', algo='bsplinesyn', metric='MeanSquares', iter='10', shrink='1', smooth='0', gradStep='0.5')
    # step2 = Paramreg_step(step='2', type='im', algo='syn', metric='MI', iter='10', shrink='1', smooth='0', gradStep='0.5')
    # paramreg = ParamregMultiStep([step1, step2])

    # Initialize the parser
    parser = Parser(__file__)
    parser.usage.set_description('Register anatomical image to the template.')
    parser.add_option(name="-i",
                      type_value="file",
                      description="Anatomical image.",
                      mandatory=True,
                      example="anat.nii.gz")
    parser.add_option(name="-s",
                      type_value="file",
                      description="Spinal cord segmentation.",
                      mandatory=True,
                      example="anat_seg.nii.gz")
    # NOTE(review): -l is declared mandatory=True but also given default_value='';
    # the default looks unused — confirm against the Parser implementation.
    parser.add_option(
        name="-l",
        type_value="file",
        description=
        "Labels. See: http://sourceforge.net/p/spinalcordtoolbox/wiki/create_labels/",
        mandatory=True,
        default_value='',
        example="anat_labels.nii.gz")
    parser.add_option(name="-t",
                      type_value="folder",
                      description="Path to MNI-Poly-AMU template.",
                      mandatory=False,
                      default_value=param.path_template)
    # The help text of -p is built from the default registration steps defined
    # above, so it always shows the current defaults.
    parser.add_option(
        name="-p",
        type_value=[[':'], 'str'],
        description=
        """Parameters for registration (see sct_register_multimodal). Default:\n--\nstep=1\ntype="""
        + paramreg.steps['1'].type + """\nalgo=""" + paramreg.steps['1'].algo +
        """\nmetric=""" + paramreg.steps['1'].metric + """\npoly=""" +
        paramreg.steps['1'].poly + """\n--\nstep=2\ntype=""" +
        paramreg.steps['2'].type + """\nalgo=""" + paramreg.steps['2'].algo +
        """\nmetric=""" + paramreg.steps['2'].metric + """\niter=""" +
        paramreg.steps['2'].iter + """\nshrink=""" +
        paramreg.steps['2'].shrink + """\nsmooth=""" +
        paramreg.steps['2'].smooth + """\ngradStep=""" +
        paramreg.steps['2'].gradStep + """\n--""",
        mandatory=False,
        example=
        "step=2,type=seg,algo=bsplinesyn,metric=MeanSquares,iter=5,shrink=2:step=3,type=im,algo=syn,metric=MI,iter=5,shrink=1,gradStep=0.3"
    )
    parser.add_option(name="-r",
                      type_value="multiple_choice",
                      description="""Remove temporary files.""",
                      mandatory=False,
                      default_value='1',
                      example=['0', '1'])
    # Choices are strings; the parsed value is converted with int() below.
    parser.add_option(
        name="-v",
        type_value="multiple_choice",
        description="""Verbose. 0: nothing. 1: basic. 2: extended.""",
        mandatory=False,
        default_value=param.verbose,
        example=['0', '1', '2'])
    # Debug mode: hard-coded paths on a developer machine; the command line is ignored.
    if param.debug:
        print '\n*** WARNING: DEBUG MODE ON ***\n'
        fname_data = '/Users/julien/data/temp/sct_example_data/t2/t2.nii.gz'
        fname_landmarks = '/Users/julien/data/temp/sct_example_data/t2/labels.nii.gz'
        fname_seg = '/Users/julien/data/temp/sct_example_data/t2/t2_seg.nii.gz'
        path_template = param.path_template
        remove_temp_files = 0
        verbose = 2
        # speed = 'superfast'
        #param_reg = '2,BSplineSyN,0.6,MeanSquares'
    else:
        arguments = parser.parse(sys.argv[1:])

        # get arguments
        fname_data = arguments['-i']
        fname_seg = arguments['-s']
        fname_landmarks = arguments['-l']
        path_template = arguments['-t']
        remove_temp_files = int(arguments['-r'])
        verbose = int(arguments['-v'])
        if '-p' in arguments:
            paramreg_user = arguments['-p']
            # update registration parameters (each ':'-separated entry is passed
            # to paramreg.addStep, presumably overriding or adding a step)
            for paramStep in paramreg_user:
                paramreg.addStep(paramStep)

    # initialize other parameters
    file_template = param.file_template
    file_template_label = param.file_template_label
    file_template_seg = param.file_template_seg
    output_type = param.output_type
    # zsubsample is concatenated to strings below, so it must be a str
    zsubsample = param.zsubsample
    # smoothing_sigma = param.smoothing_sigma

    # start timer
    start_time = time.time()

    # get absolute path - TO DO: remove! NEVER USE ABSOLUTE PATH...
    path_template = os.path.abspath(path_template)

    # get fname of the template + template objects
    fname_template = sct.slash_at_the_end(path_template, 1) + file_template
    fname_template_label = sct.slash_at_the_end(path_template,
                                                1) + file_template_label
    fname_template_seg = sct.slash_at_the_end(path_template,
                                              1) + file_template_seg

    # check file existence
    # (this printv call passes no verbose level, so printv's default is used)
    sct.printv('\nCheck template files...')
    sct.check_file_exist(fname_template, verbose)
    sct.check_file_exist(fname_template_label, verbose)
    sct.check_file_exist(fname_template_seg, verbose)

    # print arguments
    sct.printv('\nCheck parameters:', verbose)
    sct.printv('.. Data:                 ' + fname_data, verbose)
    sct.printv('.. Landmarks:            ' + fname_landmarks, verbose)
    sct.printv('.. Segmentation:         ' + fname_seg, verbose)
    sct.printv('.. Path template:        ' + path_template, verbose)
    sct.printv('.. Output type:          ' + str(output_type), verbose)
    sct.printv('.. Remove temp files:    ' + str(remove_temp_files), verbose)

    # Registration steps are keyed by the strings '1', '2', ...; this loop
    # assumes they are numbered contiguously from 1. Step attributes are
    # concatenated to strings, so they must be str.
    sct.printv('\nParameters for registration:')
    for pStep in range(1, len(paramreg.steps) + 1):
        sct.printv('Step #' + paramreg.steps[str(pStep)].step, verbose)
        sct.printv('.. Type #' + paramreg.steps[str(pStep)].type, verbose)
        sct.printv(
            '.. Algorithm................ ' + paramreg.steps[str(pStep)].algo,
            verbose)
        sct.printv(
            '.. Metric................... ' +
            paramreg.steps[str(pStep)].metric, verbose)
        sct.printv(
            '.. Number of iterations..... ' + paramreg.steps[str(pStep)].iter,
            verbose)
        sct.printv(
            '.. Shrink factor............ ' +
            paramreg.steps[str(pStep)].shrink, verbose)
        sct.printv(
            '.. Smoothing factor......... ' +
            paramreg.steps[str(pStep)].smooth, verbose)
        sct.printv(
            '.. Gradient step............ ' +
            paramreg.steps[str(pStep)].gradStep, verbose)
        sct.printv(
            '.. Degree of polynomial..... ' + paramreg.steps[str(pStep)].poly,
            verbose)

    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    sct.printv('\nCheck input labels...')
    # check if label image contains coherent labels
    image_label = Image(fname_landmarks)
    # -> all labels must be different
    labels = image_label.getNonZeroCoordinates(sorting='value')
    # Pairwise O(n^2) check. The `break` only leaves the inner loop; the outer
    # loop keeps running, which is harmless but redundant.
    hasDifferentLabels = True
    for lab in labels:
        for otherlabel in labels:
            if lab != otherlabel and lab.hasEqualValue(otherlabel):
                hasDifferentLabels = False
                break
    if not hasDifferentLabels:
        sct.printv(
            'ERROR: Wrong landmarks input. All labels must be different.',
            verbose, 'error')
    # all labels must be available in template
    # labels[-1] / labels_template[-1] are taken as the largest label values,
    # assuming sorting='value' sorts ascending — TODO confirm. An empty label
    # image would raise IndexError here.
    image_label_template = Image(fname_template_label)
    labels_template = image_label_template.getNonZeroCoordinates(
        sorting='value')
    if labels[-1].value > labels_template[-1].value:
        sct.printv(
            'ERROR: Wrong landmarks input. Labels must have correspondance in tempalte space. \nLabel max '
            'provided: ' + str(labels[-1].value) +
            '\nLabel max from template: ' + str(labels_template[-1].value),
            verbose, 'error')

    # create temporary folder (status and output of sct.run are not used)
    sct.printv('\nCreate temporary folder...', verbose)
    path_tmp = 'tmp.' + time.strftime("%y%m%d%H%M%S")
    status, output = sct.run('mkdir ' + path_tmp)

    # copy files to temporary folder, under fixed names used by the rest of the pipeline
    sct.printv('\nCopy files...', verbose)
    sct.run('isct_c3d ' + fname_data + ' -o ' + path_tmp + '/data.nii')
    sct.run('isct_c3d ' + fname_landmarks + ' -o ' + path_tmp +
            '/landmarks.nii.gz')
    sct.run('isct_c3d ' + fname_seg + ' -o ' + path_tmp +
            '/segmentation.nii.gz')
    sct.run('isct_c3d ' + fname_template + ' -o ' + path_tmp + '/template.nii')
    sct.run('isct_c3d ' + fname_template_label + ' -o ' + path_tmp +
            '/template_labels.nii.gz')
    sct.run('isct_c3d ' + fname_template_seg + ' -o ' + path_tmp +
            '/template_seg.nii.gz')

    # go to tmp folder (all file names below are relative to it until os.chdir('..'))
    os.chdir(path_tmp)

    # resample data to 1mm isotropic
    sct.printv('\nResample data to 1mm isotropic...', verbose)
    sct.run(
        'isct_c3d data.nii -resample-mm 1.0x1.0x1.0mm -interpolation Linear -o datar.nii'
    )
    sct.run(
        'isct_c3d segmentation.nii.gz -resample-mm 1.0x1.0x1.0mm -interpolation NearestNeighbor -o segmentationr.nii.gz'
    )
    # N.B. resampling of labels is more complicated, because they are single-point labels, therefore resampling with nearest neighbour can make them disappear. Therefore a more clever approach is required.
    resample_labels('landmarks.nii.gz', 'datar.nii', 'landmarksr.nii.gz')
    # # TODO
    # sct.run('sct_label_utils -i datar.nii -t create -x 124,186,19,2:129,98,23,8 -o landmarksr.nii.gz')

    # Change orientation of input images to RPI
    sct.printv('\nChange orientation of input images to RPI...', verbose)
    set_orientation('datar.nii', 'RPI', 'data_rpi.nii')
    set_orientation('landmarksr.nii.gz', 'RPI', 'landmarks_rpi.nii.gz')
    set_orientation('segmentationr.nii.gz', 'RPI', 'segmentation_rpi.nii.gz')

    # # Change orientation of input images to RPI
    # sct.printv('\nChange orientation of input images to RPI...', verbose)
    # set_orientation('data.nii', 'RPI', 'data_rpi.nii')
    # set_orientation('landmarks.nii.gz', 'RPI', 'landmarks_rpi.nii.gz')
    # set_orientation('segmentation.nii.gz', 'RPI', 'segmentation_rpi.nii.gz')

    # get landmarks in native space
    # crop segmentation along z (dim 2); see sct_crop_image for the meaning of -bzmax
    # output: segmentation_rpi_crop.nii.gz
    sct.run(
        'sct_crop_image -i segmentation_rpi.nii.gz -o segmentation_rpi_crop.nii.gz -dim 2 -bzmax'
    )

    # straighten segmentation (the cropped segmentation is both input and centerline)
    sct.printv('\nStraighten the spinal cord using centerline/segmentation...',
               verbose)
    sct.run(
        'sct_straighten_spinalcord -i segmentation_rpi_crop.nii.gz -c segmentation_rpi_crop.nii.gz -r 0 -v '
        + str(verbose), verbose)
    # re-define warping field using non-cropped space (to avoid issue #367)
    sct.run(
        'sct_concat_transfo -w warp_straight2curve.nii.gz -d data_rpi.nii -o warp_straight2curve.nii.gz'
    )

    # Label preparation:
    # --------------------------------------------------------------------------------
    # Remove unused label on template. Keep only label present in the input label image
    sct.printv(
        '\nRemove unused label on template. Keep only label present in the input label image...',
        verbose)
    sct.run(
        'sct_label_utils -t remove -i template_labels.nii.gz -o template_label.nii.gz -r landmarks_rpi.nii.gz'
    )

    # Make sure landmarks are INT
    sct.printv('\nConvert landmarks to INT...', verbose)
    sct.run(
        'isct_c3d template_label.nii.gz -type int -o template_label.nii.gz',
        verbose)

    # Create a cross for the template labels - 5 mm
    sct.printv('\nCreate a 5 mm cross for the template labels...', verbose)
    sct.run(
        'sct_label_utils -t cross -i template_label.nii.gz -o template_label_cross.nii.gz -c 5'
    )

    # Create a cross for the input labels and dilate for straightening preparation - 5 mm
    sct.printv(
        '\nCreate a 5mm cross for the input labels and dilate for straightening preparation...',
        verbose)
    sct.run(
        'sct_label_utils -t cross -i landmarks_rpi.nii.gz -o landmarks_rpi_cross3x3.nii.gz -c 5 -d'
    )

    # Apply straightening to labels (nearest-neighbour interpolation, -x nn)
    sct.printv('\nApply straightening to labels...', verbose)
    sct.run(
        'sct_apply_transfo -i landmarks_rpi_cross3x3.nii.gz -o landmarks_rpi_cross3x3_straight.nii.gz -d segmentation_rpi_crop_straight.nii.gz -w warp_curve2straight.nii.gz -x nn'
    )

    # Convert landmarks from FLOAT32 to INT
    sct.printv('\nConvert landmarks from FLOAT32 to INT...', verbose)
    sct.run(
        'isct_c3d landmarks_rpi_cross3x3_straight.nii.gz -type int -o landmarks_rpi_cross3x3_straight.nii.gz'
    )

    # Remove labels that do not correspond with each other.
    sct.printv('\nRemove labels that do not correspond with each others.',
               verbose)
    sct.run(
        'sct_label_utils -t remove-symm -i landmarks_rpi_cross3x3_straight.nii.gz -o landmarks_rpi_cross3x3_straight.nii.gz,template_label_cross.nii.gz -r template_label_cross.nii.gz'
    )

    # Estimate affine transfo: straight --> template (landmark-based)'
    sct.printv(
        '\nEstimate affine transfo: straight anat --> template (landmark-based)...',
        verbose)
    # converting landmarks straight and curved to physical coordinates
    image_straight = Image('landmarks_rpi_cross3x3_straight.nii.gz')
    landmark_straight = image_straight.getNonZeroCoordinates(sorting='value')
    image_template = Image('template_label_cross.nii.gz')
    landmark_template = image_template.getNonZeroCoordinates(sorting='value')
    # Reorganize landmarks: average all straightened voxels that share a label
    # value, giving one point per label.
    # NOTE(review): temp_landmark starts as an alias of coord, so whether `+=`
    # modifies the entry of landmark_straight in place depends on the Coordinate
    # class (not visible here) — confirm.
    points_fixed, points_moving = [], []
    landmark_straight_mean = []
    for coord in landmark_straight:
        if coord.value not in [c.value for c in landmark_straight_mean]:
            temp_landmark = coord
            temp_number = 1
            for other_coord in landmark_straight:
                if coord.hasEqualValue(other_coord) and coord != other_coord:
                    temp_landmark += other_coord
                    temp_number += 1
            landmark_straight_mean.append(temp_landmark / temp_number)

    # Convert voxel coordinates to physical coordinates.
    # NOTE(review): template landmarks are not averaged per label like the
    # straightened ones, so if the cross image has several voxels per label,
    # points_fixed and points_moving may not pair one-to-one — confirm against
    # getRigidTransformFromLandmarks.
    for coord in landmark_straight_mean:
        point_straight = image_straight.transfo_pix2phys(
            [[coord.x, coord.y, coord.z]])
        points_moving.append(
            [point_straight[0][0], point_straight[0][1], point_straight[0][2]])
    for coord in landmark_template:
        point_template = image_template.transfo_pix2phys(
            [[coord.x, coord.y, coord.z]])
        points_fixed.append(
            [point_template[0][0], point_template[0][1], point_template[0][2]])

    # Register curved landmarks on straight landmarks based on python implementation
    # (despite the "rigid" name, the constraint used is translation + scaling along z)
    sct.printv(
        '\nComputing rigid transformation (algo=translation-scaling-z) ...',
        verbose)
    import msct_register_landmarks
    (rotation_matrix, translation_array, points_moving_reg, points_moving_barycenter) = \
        msct_register_landmarks.getRigidTransformFromLandmarks(
            points_fixed, points_moving, constraints='translation-scaling-z', show=False)

    # writing rigid transformation file, in ITK "Insight Transform File" text format
    # The diagonal terms are inverted and the z translation negated; this
    # presumably converts the estimated transform into the direction and
    # convention expected by ITK — TODO confirm.
    text_file = open("straight2templateAffine.txt", "w")
    text_file.write("#Insight Transform File V1.0\n")
    text_file.write("#Transform 0\n")
    text_file.write(
        "Transform: FixedCenterOfRotationAffineTransform_double_3_3\n")
    text_file.write(
        "Parameters: %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f\n"
        % (1.0 / rotation_matrix[0, 0], rotation_matrix[0, 1],
           rotation_matrix[0, 2], rotation_matrix[1, 0],
           1.0 / rotation_matrix[1, 1], rotation_matrix[1, 2],
           rotation_matrix[2, 0], rotation_matrix[2, 1],
           1.0 / rotation_matrix[2, 2], translation_array[0, 0],
           translation_array[0, 1], -translation_array[0, 2]))
    # The center of rotation is the barycenter of the moving points
    text_file.write("FixedParameters: %.9f %.9f %.9f\n" %
                    (points_moving_barycenter[0], points_moving_barycenter[1],
                     points_moving_barycenter[2]))
    text_file.close()

    # Apply affine transformation: straight --> template
    sct.printv('\nApply affine transformation: straight --> template...',
               verbose)
    sct.run(
        'sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt -d template.nii -o warp_curve2straightAffine.nii.gz'
    )
    sct.run(
        'sct_apply_transfo -i data_rpi.nii -o data_rpi_straight2templateAffine.nii -d template.nii -w warp_curve2straightAffine.nii.gz'
    )
    sct.run(
        'sct_apply_transfo -i segmentation_rpi.nii.gz -o segmentation_rpi_straight2templateAffine.nii.gz -d template.nii -w warp_curve2straightAffine.nii.gz -x linear'
    )

    # threshold to 0.5: values below 0.5 are set to 0; values >= 0.5 keep their
    # (linearly interpolated) value and are not binarized
    nii = Image('segmentation_rpi_straight2templateAffine.nii.gz')
    data = nii.data
    data[data < 0.5] = 0
    nii.data = data
    nii.setFileName('segmentation_rpi_straight2templateAffine_th.nii.gz')
    nii.save()
    # find min-max of anat2template (for subsequent cropping)
    zmin_template, zmax_template = find_zmin_zmax(
        'segmentation_rpi_straight2templateAffine_th.nii.gz')

    # crop template in z-direction (for faster processing)
    # Note: the non-thresholded segmentation is cropped, using the z range
    # found on the thresholded one.
    sct.printv('\nCrop data in template space (for faster processing)...',
               verbose)
    sct.run(
        'sct_crop_image -i template.nii -o template_crop.nii -dim 2 -start ' +
        str(zmin_template) + ' -end ' + str(zmax_template))
    sct.run(
        'sct_crop_image -i template_seg.nii.gz -o template_seg_crop.nii.gz -dim 2 -start '
        + str(zmin_template) + ' -end ' + str(zmax_template))
    sct.run(
        'sct_crop_image -i data_rpi_straight2templateAffine.nii -o data_rpi_straight2templateAffine_crop.nii -dim 2 -start '
        + str(zmin_template) + ' -end ' + str(zmax_template))
    sct.run(
        'sct_crop_image -i segmentation_rpi_straight2templateAffine.nii.gz -o segmentation_rpi_straight2templateAffine_crop.nii.gz -dim 2 -start '
        + str(zmin_template) + ' -end ' + str(zmax_template))
    # sub-sample in z-direction
    sct.printv('\nSub-sample in z-direction (for faster processing)...',
               verbose)
    sct.run(
        'sct_resample -i template_crop.nii -o template_crop_r.nii -f 1x1x' +
        zsubsample, verbose)
    sct.run(
        'sct_resample -i template_seg_crop.nii.gz -o template_seg_crop_r.nii.gz -f 1x1x'
        + zsubsample, verbose)
    sct.run(
        'sct_resample -i data_rpi_straight2templateAffine_crop.nii -o data_rpi_straight2templateAffine_crop_r.nii -f 1x1x'
        + zsubsample, verbose)
    sct.run(
        'sct_resample -i segmentation_rpi_straight2templateAffine_crop.nii.gz -o segmentation_rpi_straight2templateAffine_crop_r.nii.gz -f 1x1x'
        + zsubsample, verbose)

    # Registration straight spinal cord to template
    sct.printv('\nRegister straight spinal cord to template...', verbose)

    # loop across registration steps
    # warp_forward / warp_inverse collect, in step order, the warping fields
    # returned by register() for each step.
    warp_forward = []
    warp_inverse = []
    for i_step in range(1, len(paramreg.steps) + 1):
        sct.printv(
            '\nEstimate transformation for step #' + str(i_step) + '...',
            verbose)
        # identify which is the src and dest
        if paramreg.steps[str(i_step)].type == 'im':
            src = 'data_rpi_straight2templateAffine_crop_r.nii'
            dest = 'template_crop_r.nii'
            interp_step = 'linear'
        elif paramreg.steps[str(i_step)].type == 'seg':
            src = 'segmentation_rpi_straight2templateAffine_crop_r.nii.gz'
            dest = 'template_seg_crop_r.nii.gz'
            interp_step = 'nn'
        else:
            # NOTE(review): if printv(..., 'error') does not exit, src/dest are
            # undefined below — confirm that it exits.
            sct.printv('ERROR: Wrong image type.', 1, 'error')
        # if step>1, apply warp_forward_concat to the src image to be used
        if i_step > 1:
            # sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+sct.add_suffix(src, '_reg')+' -x '+interp_step, verbose)
            sct.run(
                'sct_apply_transfo -i ' + src + ' -d ' + dest + ' -w ' +
                ','.join(warp_forward) + ' -o ' + sct.add_suffix(src, '_reg') +
                ' -x ' + interp_step, verbose)
            src = sct.add_suffix(src, '_reg')
        # register src --> dest
        warp_forward_out, warp_inverse_out = register(src, dest, paramreg,
                                                      param, str(i_step))
        warp_forward.append(warp_forward_out)
        warp_inverse.append(warp_inverse_out)

    # Concatenate transformations:
    sct.printv('\nConcatenate transformations: anat --> template...', verbose)
    sct.run(
        'sct_concat_transfo -w warp_curve2straightAffine.nii.gz,' +
        ','.join(warp_forward) +
        ' -d template.nii -o warp_anat2template.nii.gz', verbose)
    # sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
    # Inverse chain: the inverse warps are applied in reverse step order. The
    # leading '-' on straight2templateAffine.txt presumably asks
    # sct_concat_transfo to invert that affine — confirm.
    warp_inverse.reverse()
    sct.run(
        'sct_concat_transfo -w ' + ','.join(warp_inverse) +
        ',-straight2templateAffine.txt,warp_straight2curve.nii.gz -d data.nii -o warp_template2anat.nii.gz',
        verbose)

    # Apply warping fields to anat and template
    if output_type == 1:
        sct.run(
            'sct_apply_transfo -i template.nii -o template2anat.nii.gz -d data.nii -w warp_template2anat.nii.gz -c 1',
            verbose)
        sct.run(
            'sct_apply_transfo -i data.nii -o anat2template.nii.gz -d template.nii -w warp_anat2template.nii.gz -c 1',
            verbose)

    # come back to parent folder
    os.chdir('..')

    # Generate output files (written to the current working directory)
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(path_tmp + '/warp_template2anat.nii.gz',
                             'warp_template2anat.nii.gz', verbose)
    sct.generate_output_file(path_tmp + '/warp_anat2template.nii.gz',
                             'warp_anat2template.nii.gz', verbose)
    if output_type == 1:
        sct.generate_output_file(path_tmp + '/template2anat.nii.gz',
                                 'template2anat' + ext_data, verbose)
        sct.generate_output_file(path_tmp + '/anat2template.nii.gz',
                                 'anat2template' + ext_data, verbose)

    # Delete temporary files
    if remove_temp_files:
        sct.printv('\nDelete temporary files...', verbose)
        sct.run('rm -rf ' + path_tmp)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv(
        '\nFinished! Elapsed time: ' + str(int(round(elapsed_time))) + 's',
        verbose)

    # to view results (the suggested commands refer to the output_type == 1 outputs;
    # file names are given without extension, presumably resolved by fslview)
    sct.printv('\nTo view results, type:', verbose)
    sct.printv('fslview ' + fname_data + ' template2anat -b 0,4000 &', verbose,
               'info')
    sct.printv('fslview ' + fname_template + ' -b 0,5000 anat2template &\n',
               verbose, 'info')
Ejemplo n.º 5
0
def main():
    # Initialization
    fname_data = ''
    suffix_out = '_crop'
    remove_temp_files = param.remove_temp_files
    verbose = param.verbose
    fsloutput = 'export FSLOUTPUTTYPE=NIFTI; '  # for faster processing, all outputs are in NIFTI
    remove_temp_files = param.remove_temp_files

    # Parameters for debug mode
    if param.debug:
        print '\n*** WARNING: DEBUG MODE ON ***\n'
        fname_data = path_sct + '/testing/data/errsm_23/t2/t2.nii.gz'
        remove_temp_files = 0
    else:
        # Check input parameters
        try:
            opts, args = getopt.getopt(sys.argv[1:], 'hi:r:v:')
        except getopt.GetoptError:
            usage()
        if not opts:
            usage()
        for opt, arg in opts:
            if opt == '-h':
                usage()
            elif opt in ('-i'):
                fname_data = arg
            elif opt in ('-r'):
                remove_temp_files = int(arg)
            elif opt in ('-v'):
                verbose = int(arg)

    # display usage if a mandatory argument is not provided
    if fname_data == '':
        usage()

    # Check file existence
    sct.printv('\nCheck file existence...', verbose)
    sct.check_file_exist(fname_data, verbose)

    # Get dimensions of data
    sct.printv('\nGet dimensions of data...', verbose)
    nx, ny, nz, nt, px, py, pz, pt = Image(fname_data).dim
    sct.printv('.. ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), verbose)
    # check if 4D data
    if not nt == 1:
        sct.printv(
            '\nERROR in ' + os.path.basename(__file__) +
            ': Data should be 3D.\n', 1, 'error')
        sys.exit(2)

    # print arguments
    print '\nCheck parameters:'
    print '  data ................... ' + fname_data
    print

    # Extract path/file/extension
    path_data, file_data, ext_data = sct.extract_fname(fname_data)
    path_out, file_out, ext_out = '', file_data + suffix_out, ext_data

    # create temporary folder
    path_tmp = 'tmp.' + time.strftime("%y%m%d%H%M%S") + '/'
    sct.run('mkdir ' + path_tmp)

    # copy files into tmp folder
    sct.run('isct_c3d ' + fname_data + ' -o ' + path_tmp + 'data.nii')

    # go to tmp folder
    os.chdir(path_tmp)

    # change orientation
    sct.printv('\nChange orientation to RPI...', verbose)
    set_orientation('data.nii', 'RPI', 'data_rpi.nii')

    # get image of medial slab
    sct.printv('\nGet image of medial slab...', verbose)
    image_array = nibabel.load('data_rpi.nii').get_data()
    nx, ny, nz = image_array.shape
    scipy.misc.imsave('image.jpg', image_array[math.floor(nx / 2), :, :])

    # Display the image
    sct.printv('\nDisplay image and get cropping region...', verbose)
    fig = plt.figure()
    # fig = plt.gcf()
    # ax = plt.gca()
    ax = fig.add_subplot(111)
    img = mpimg.imread("image.jpg")
    implot = ax.imshow(img.T)
    implot.set_cmap('gray')
    plt.gca().invert_yaxis()
    # mouse callback
    ax.set_title(
        'Left click on the top and bottom of your cropping field.\n Right click to remove last point.\n Close window when your done.'
    )
    line, = ax.plot([], [], 'ro')  # empty line
    cropping_coordinates = LineBuilder(line)
    plt.show()
    # disconnect callback
    # fig.canvas.mpl_disconnect(line)

    # check if user clicked two times
    if len(cropping_coordinates.xs) != 2:
        sct.printv('\nERROR: You have to select two points. Exit program.\n',
                   1, 'error')
        sys.exit(2)

    # convert coordinates to integer
    zcrop = [int(i) for i in cropping_coordinates.ys]

    # sort coordinates
    zcrop.sort()

    # crop image
    sct.printv('\nCrop image...', verbose)
    nii = Image('data_rpi.nii')
    data_crop = nii.data[:, :, zcrop[0]:zcrop[1]]
    nii.data = data_crop
    nii.setFileName('data_rpi_crop.nii')
    nii.save()

    # come back to parent folder
    os.chdir('..')

    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(path_tmp + 'data_rpi_crop.nii',
                             path_out + file_out + ext_out)

    # Remove temporary files
    if remove_temp_files == 1:
        print('\nRemove temporary files...')
        sct.run('rm -rf ' + path_tmp)

    # to view results
    print '\nDone! To view results, type:'
    print 'fslview ' + path_out + file_out + ext_out + ' &'
    print
def main():

    # get default parameters
    step1 = Paramreg(step='1', type='seg', algo='slicereg', metric='MeanSquares', iter='10')
    step2 = Paramreg(step='2', type='im', algo='syn', metric='MI', iter='3')
    # step1 = Paramreg()
    paramreg = ParamregMultiStep([step1, step2])

    # step1 = Paramreg_step(step='1', type='seg', algo='bsplinesyn', metric='MeanSquares', iter='10', shrink='1', smooth='0', gradStep='0.5')
    # step2 = Paramreg_step(step='2', type='im', algo='syn', metric='MI', iter='10', shrink='1', smooth='0', gradStep='0.5')
    # paramreg = ParamregMultiStep([step1, step2])

    # Initialize the parser
    parser = Parser(__file__)
    parser.usage.set_description('Register anatomical image to the template.')
    parser.add_option(name="-i",
                      type_value="file",
                      description="Anatomical image.",
                      mandatory=True,
                      example="anat.nii.gz")
    parser.add_option(name="-s",
                      type_value="file",
                      description="Spinal cord segmentation.",
                      mandatory=True,
                      example="anat_seg.nii.gz")
    parser.add_option(name="-l",
                      type_value="file",
                      description="Labels. See: http://sourceforge.net/p/spinalcordtoolbox/wiki/create_labels/",
                      mandatory=True,
                      default_value='',
                      example="anat_labels.nii.gz")
    parser.add_option(name="-t",
                      type_value="folder",
                      description="Path to MNI-Poly-AMU template.",
                      mandatory=False,
                      default_value=param.path_template)
    parser.add_option(name="-p",
                      type_value=[[':'], 'str'],
                      description="""Parameters for registration (see sct_register_multimodal). Default:\n--\nstep=1\ntype="""+paramreg.steps['1'].type+"""\nalgo="""+paramreg.steps['1'].algo+"""\nmetric="""+paramreg.steps['1'].metric+"""\npoly="""+paramreg.steps['1'].poly+"""\n--\nstep=2\ntype="""+paramreg.steps['2'].type+"""\nalgo="""+paramreg.steps['2'].algo+"""\nmetric="""+paramreg.steps['2'].metric+"""\niter="""+paramreg.steps['2'].iter+"""\nshrink="""+paramreg.steps['2'].shrink+"""\nsmooth="""+paramreg.steps['2'].smooth+"""\ngradStep="""+paramreg.steps['2'].gradStep+"""\n--""",
                      mandatory=False,
                      example="step=2,type=seg,algo=bsplinesyn,metric=MeanSquares,iter=5,shrink=2:step=3,type=im,algo=syn,metric=MI,iter=5,shrink=1,gradStep=0.3")
    parser.add_option(name="-r",
                      type_value="multiple_choice",
                      description="""Remove temporary files.""",
                      mandatory=False,
                      default_value='1',
                      example=['0', '1'])
    parser.add_option(name="-v",
                      type_value="multiple_choice",
                      description="""Verbose. 0: nothing. 1: basic. 2: extended.""",
                      mandatory=False,
                      default_value=param.verbose,
                      example=['0', '1', '2'])
    if param.debug:
        print '\n*** WARNING: DEBUG MODE ON ***\n'
        fname_data = '/Users/julien/data/temp/sct_example_data/t2/t2.nii.gz'
        fname_landmarks = '/Users/julien/data/temp/sct_example_data/t2/labels.nii.gz'
        fname_seg = '/Users/julien/data/temp/sct_example_data/t2/t2_seg.nii.gz'
        path_template = param.path_template
        remove_temp_files = 0
        verbose = 2
        # speed = 'superfast'
        #param_reg = '2,BSplineSyN,0.6,MeanSquares'
    else:
        arguments = parser.parse(sys.argv[1:])

        # get arguments
        fname_data = arguments['-i']
        fname_seg = arguments['-s']
        fname_landmarks = arguments['-l']
        path_template = arguments['-t']
        remove_temp_files = int(arguments['-r'])
        verbose = int(arguments['-v'])
        if '-p' in arguments:
            paramreg_user = arguments['-p']
            # update registration parameters
            for paramStep in paramreg_user:
                paramreg.addStep(paramStep)

    # initialize other parameters
    file_template = param.file_template
    file_template_label = param.file_template_label
    file_template_seg = param.file_template_seg
    output_type = param.output_type
    zsubsample = param.zsubsample
    # smoothing_sigma = param.smoothing_sigma

    # start timer
    start_time = time.time()

    # get absolute path - TO DO: remove! NEVER USE ABSOLUTE PATH...
    path_template = os.path.abspath(path_template)

    # get fname of the template + template objects
    fname_template = sct.slash_at_the_end(path_template, 1)+file_template
    fname_template_label = sct.slash_at_the_end(path_template, 1)+file_template_label
    fname_template_seg = sct.slash_at_the_end(path_template, 1)+file_template_seg

    # check file existence
    sct.printv('\nCheck template files...')
    sct.check_file_exist(fname_template, verbose)
    sct.check_file_exist(fname_template_label, verbose)
    sct.check_file_exist(fname_template_seg, verbose)

    # print arguments
    sct.printv('\nCheck parameters:', verbose)
    sct.printv('.. Data:                 '+fname_data, verbose)
    sct.printv('.. Landmarks:            '+fname_landmarks, verbose)
    sct.printv('.. Segmentation:         '+fname_seg, verbose)
    sct.printv('.. Path template:        '+path_template, verbose)
    sct.printv('.. Output type:          '+str(output_type), verbose)
    sct.printv('.. Remove temp files:    '+str(remove_temp_files), verbose)

    sct.printv('\nParameters for registration:')
    for pStep in range(1, len(paramreg.steps)+1):
        sct.printv('Step #'+paramreg.steps[str(pStep)].step, verbose)
        sct.printv('.. Type #'+paramreg.steps[str(pStep)].type, verbose)
        sct.printv('.. Algorithm................ '+paramreg.steps[str(pStep)].algo, verbose)
        sct.printv('.. Metric................... '+paramreg.steps[str(pStep)].metric, verbose)
        sct.printv('.. Number of iterations..... '+paramreg.steps[str(pStep)].iter, verbose)
        sct.printv('.. Shrink factor............ '+paramreg.steps[str(pStep)].shrink, verbose)
        sct.printv('.. Smoothing factor......... '+paramreg.steps[str(pStep)].smooth, verbose)
        sct.printv('.. Gradient step............ '+paramreg.steps[str(pStep)].gradStep, verbose)
        sct.printv('.. Degree of polynomial..... '+paramreg.steps[str(pStep)].poly, verbose)

    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    sct.printv('\nCheck input labels...')
    # check if label image contains coherent labels
    image_label = Image(fname_landmarks)
    # -> all labels must be different
    labels = image_label.getNonZeroCoordinates(sorting='value')
    hasDifferentLabels = True
    for lab in labels:
        for otherlabel in labels:
            if lab != otherlabel and lab.hasEqualValue(otherlabel):
                hasDifferentLabels = False
                break
    if not hasDifferentLabels:
        sct.printv('ERROR: Wrong landmarks input. All labels must be different.', verbose, 'error')
    # all labels must be available in tempalte
    image_label_template = Image(fname_template_label)
    labels_template = image_label_template.getNonZeroCoordinates(sorting='value')
    if labels[-1].value > labels_template[-1].value:
        sct.printv('ERROR: Wrong landmarks input. Labels must have correspondance in tempalte space. \nLabel max '
                   'provided: ' + str(labels[-1].value) + '\nLabel max from template: ' +
                   str(labels_template[-1].value), verbose, 'error')


    # create temporary folder
    sct.printv('\nCreate temporary folder...', verbose)
    path_tmp = 'tmp.'+time.strftime("%y%m%d%H%M%S")
    status, output = sct.run('mkdir '+path_tmp)

    # copy files to temporary folder
    sct.printv('\nCopy files...', verbose)
    sct.run('isct_c3d '+fname_data+' -o '+path_tmp+'/data.nii')
    sct.run('isct_c3d '+fname_landmarks+' -o '+path_tmp+'/landmarks.nii.gz')
    sct.run('isct_c3d '+fname_seg+' -o '+path_tmp+'/segmentation.nii.gz')
    sct.run('isct_c3d '+fname_template+' -o '+path_tmp+'/template.nii')
    sct.run('isct_c3d '+fname_template_label+' -o '+path_tmp+'/template_labels.nii.gz')
    sct.run('isct_c3d '+fname_template_seg+' -o '+path_tmp+'/template_seg.nii.gz')

    # go to tmp folder
    os.chdir(path_tmp)

    # resample data to 1mm isotropic
    sct.printv('\nResample data to 1mm isotropic...', verbose)
    sct.run('isct_c3d data.nii -resample-mm 1.0x1.0x1.0mm -interpolation Linear -o datar.nii')
    sct.run('isct_c3d segmentation.nii.gz -resample-mm 1.0x1.0x1.0mm -interpolation NearestNeighbor -o segmentationr.nii.gz')
    # N.B. resampling of labels is more complicated, because they are single-point labels, therefore resampling with neighrest neighbour can make them disappear. Therefore a more clever approach is required.
    resample_labels('landmarks.nii.gz', 'datar.nii', 'landmarksr.nii.gz')
    # # TODO
    # sct.run('sct_label_utils -i datar.nii -t create -x 124,186,19,2:129,98,23,8 -o landmarksr.nii.gz')

    # Change orientation of input images to RPI
    sct.printv('\nChange orientation of input images to RPI...', verbose)
    set_orientation('datar.nii', 'RPI', 'data_rpi.nii')
    set_orientation('landmarksr.nii.gz', 'RPI', 'landmarks_rpi.nii.gz')
    set_orientation('segmentationr.nii.gz', 'RPI', 'segmentation_rpi.nii.gz')

    # # Change orientation of input images to RPI
    # sct.printv('\nChange orientation of input images to RPI...', verbose)
    # set_orientation('data.nii', 'RPI', 'data_rpi.nii')
    # set_orientation('landmarks.nii.gz', 'RPI', 'landmarks_rpi.nii.gz')
    # set_orientation('segmentation.nii.gz', 'RPI', 'segmentation_rpi.nii.gz')

    # get landmarks in native space
    # crop segmentation
    # output: segmentation_rpi_crop.nii.gz
    sct.run('sct_crop_image -i segmentation_rpi.nii.gz -o segmentation_rpi_crop.nii.gz -dim 2 -bzmax')

    # straighten segmentation
    sct.printv('\nStraighten the spinal cord using centerline/segmentation...', verbose)
    sct.run('sct_straighten_spinalcord -i segmentation_rpi_crop.nii.gz -c segmentation_rpi_crop.nii.gz -r 0 -v '+str(verbose), verbose)
    # re-define warping field using non-cropped space (to avoid issue #367)
    sct.run('sct_concat_transfo -w warp_straight2curve.nii.gz -d data_rpi.nii -o warp_straight2curve.nii.gz')

    # Label preparation:
    # --------------------------------------------------------------------------------
    # Remove unused label on template. Keep only label present in the input label image
    sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
    sct.run('sct_label_utils -t remove -i template_labels.nii.gz -o template_label.nii.gz -r landmarks_rpi.nii.gz')

    # Make sure landmarks are INT
    sct.printv('\nConvert landmarks to INT...', verbose)
    sct.run('isct_c3d template_label.nii.gz -type int -o template_label.nii.gz', verbose)

    # Create a cross for the template labels - 5 mm
    sct.printv('\nCreate a 5 mm cross for the template labels...', verbose)
    sct.run('sct_label_utils -t cross -i template_label.nii.gz -o template_label_cross.nii.gz -c 5')

    # Create a cross for the input labels and dilate for straightening preparation - 5 mm
    sct.printv('\nCreate a 5mm cross for the input labels and dilate for straightening preparation...', verbose)
    sct.run('sct_label_utils -t cross -i landmarks_rpi.nii.gz -o landmarks_rpi_cross3x3.nii.gz -c 5 -d')

    # Apply straightening to labels
    sct.printv('\nApply straightening to labels...', verbose)
    sct.run('sct_apply_transfo -i landmarks_rpi_cross3x3.nii.gz -o landmarks_rpi_cross3x3_straight.nii.gz -d segmentation_rpi_crop_straight.nii.gz -w warp_curve2straight.nii.gz -x nn')

    # Convert landmarks from FLOAT32 to INT
    sct.printv('\nConvert landmarks from FLOAT32 to INT...', verbose)
    sct.run('isct_c3d landmarks_rpi_cross3x3_straight.nii.gz -type int -o landmarks_rpi_cross3x3_straight.nii.gz')

    # Remove labels that do not correspond with each others.
    sct.printv('\nRemove labels that do not correspond with each others.', verbose)
    sct.run('sct_label_utils -t remove-symm -i landmarks_rpi_cross3x3_straight.nii.gz -o landmarks_rpi_cross3x3_straight.nii.gz,template_label_cross.nii.gz -r template_label_cross.nii.gz')

    # Estimate affine transfo: straight --> template (landmark-based)'
    sct.printv('\nEstimate affine transfo: straight anat --> template (landmark-based)...', verbose)
    # converting landmarks straight and curved to physical coordinates
    image_straight = Image('landmarks_rpi_cross3x3_straight.nii.gz')
    landmark_straight = image_straight.getNonZeroCoordinates(sorting='value')
    image_template = Image('template_label_cross.nii.gz')
    landmark_template = image_template.getNonZeroCoordinates(sorting='value')
    # Reorganize landmarks
    points_fixed, points_moving = [], []
    landmark_straight_mean = []
    for coord in landmark_straight:
        if coord.value not in [c.value for c in landmark_straight_mean]:
            temp_landmark = coord
            temp_number = 1
            for other_coord in landmark_straight:
                if coord.hasEqualValue(other_coord) and coord != other_coord:
                    temp_landmark += other_coord
                    temp_number += 1
            landmark_straight_mean.append(temp_landmark / temp_number)

    for coord in landmark_straight_mean:
        point_straight = image_straight.transfo_pix2phys([[coord.x, coord.y, coord.z]])
        points_moving.append([point_straight[0][0], point_straight[0][1], point_straight[0][2]])
    for coord in landmark_template:
        point_template = image_template.transfo_pix2phys([[coord.x, coord.y, coord.z]])
        points_fixed.append([point_template[0][0], point_template[0][1], point_template[0][2]])

    # Register curved landmarks on straight landmarks based on python implementation
    sct.printv('\nComputing rigid transformation (algo=translation-scaling-z) ...', verbose)
    import msct_register_landmarks
    (rotation_matrix, translation_array, points_moving_reg, points_moving_barycenter) = \
        msct_register_landmarks.getRigidTransformFromLandmarks(
            points_fixed, points_moving, constraints='translation-scaling-z', show=False)

    # writing rigid transformation file
    text_file = open("straight2templateAffine.txt", "w")
    text_file.write("#Insight Transform File V1.0\n")
    text_file.write("#Transform 0\n")
    text_file.write("Transform: FixedCenterOfRotationAffineTransform_double_3_3\n")
    text_file.write("Parameters: %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f\n" % (
        1.0/rotation_matrix[0, 0], rotation_matrix[0, 1],     rotation_matrix[0, 2],
        rotation_matrix[1, 0],     1.0/rotation_matrix[1, 1], rotation_matrix[1, 2],
        rotation_matrix[2, 0],     rotation_matrix[2, 1],     1.0/rotation_matrix[2, 2],
        translation_array[0, 0],   translation_array[0, 1],   -translation_array[0, 2]))
    text_file.write("FixedParameters: %.9f %.9f %.9f\n" % (points_moving_barycenter[0],
                                                           points_moving_barycenter[1],
                                                           points_moving_barycenter[2]))
    text_file.close()

    # Apply affine transformation: straight --> template
    sct.printv('\nApply affine transformation: straight --> template...', verbose)
    sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt -d template.nii -o warp_curve2straightAffine.nii.gz')
    sct.run('sct_apply_transfo -i data_rpi.nii -o data_rpi_straight2templateAffine.nii -d template.nii -w warp_curve2straightAffine.nii.gz')
    sct.run('sct_apply_transfo -i segmentation_rpi.nii.gz -o segmentation_rpi_straight2templateAffine.nii.gz -d template.nii -w warp_curve2straightAffine.nii.gz -x linear')

    # threshold to 0.5
    nii = Image('segmentation_rpi_straight2templateAffine.nii.gz')
    data = nii.data
    data[data < 0.5] = 0
    nii.data = data
    nii.setFileName('segmentation_rpi_straight2templateAffine_th.nii.gz')
    nii.save()
    # find min-max of anat2template (for subsequent cropping)
    zmin_template, zmax_template = find_zmin_zmax('segmentation_rpi_straight2templateAffine_th.nii.gz')

    # crop template in z-direction (for faster processing)
    sct.printv('\nCrop data in template space (for faster processing)...', verbose)
    sct.run('sct_crop_image -i template.nii -o template_crop.nii -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    sct.run('sct_crop_image -i template_seg.nii.gz -o template_seg_crop.nii.gz -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    sct.run('sct_crop_image -i data_rpi_straight2templateAffine.nii -o data_rpi_straight2templateAffine_crop.nii -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    sct.run('sct_crop_image -i segmentation_rpi_straight2templateAffine.nii.gz -o segmentation_rpi_straight2templateAffine_crop.nii.gz -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    # sub-sample in z-direction
    sct.printv('\nSub-sample in z-direction (for faster processing)...', verbose)
    sct.run('sct_resample -i template_crop.nii -o template_crop_r.nii -f 1x1x'+zsubsample, verbose)
    sct.run('sct_resample -i template_seg_crop.nii.gz -o template_seg_crop_r.nii.gz -f 1x1x'+zsubsample, verbose)
    sct.run('sct_resample -i data_rpi_straight2templateAffine_crop.nii -o data_rpi_straight2templateAffine_crop_r.nii -f 1x1x'+zsubsample, verbose)
    sct.run('sct_resample -i segmentation_rpi_straight2templateAffine_crop.nii.gz -o segmentation_rpi_straight2templateAffine_crop_r.nii.gz -f 1x1x'+zsubsample, verbose)

    # Registration straight spinal cord to template
    sct.printv('\nRegister straight spinal cord to template...', verbose)

    # loop across registration steps
    warp_forward = []
    warp_inverse = []
    for i_step in range(1, len(paramreg.steps)+1):
        sct.printv('\nEstimate transformation for step #'+str(i_step)+'...', verbose)
        # identify which is the src and dest
        if paramreg.steps[str(i_step)].type == 'im':
            src = 'data_rpi_straight2templateAffine_crop_r.nii'
            dest = 'template_crop_r.nii'
            interp_step = 'linear'
        elif paramreg.steps[str(i_step)].type == 'seg':
            src = 'segmentation_rpi_straight2templateAffine_crop_r.nii.gz'
            dest = 'template_seg_crop_r.nii.gz'
            interp_step = 'nn'
        else:
            sct.printv('ERROR: Wrong image type.', 1, 'error')
        # if step>1, apply warp_forward_concat to the src image to be used
        if i_step > 1:
            # sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+sct.add_suffix(src, '_reg')+' -x '+interp_step, verbose)
            sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+sct.add_suffix(src, '_reg')+' -x '+interp_step, verbose)
            src = sct.add_suffix(src, '_reg')
        # register src --> dest
        warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
        warp_forward.append(warp_forward_out)
        warp_inverse.append(warp_inverse_out)

    # Concatenate transformations:
    sct.printv('\nConcatenate transformations: anat --> template...', verbose)
    sct.run('sct_concat_transfo -w warp_curve2straightAffine.nii.gz,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
    # sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
    warp_inverse.reverse()
    sct.run('sct_concat_transfo -w '+','.join(warp_inverse)+',-straight2templateAffine.txt,warp_straight2curve.nii.gz -d data.nii -o warp_template2anat.nii.gz', verbose)

    # Apply warping fields to anat and template
    if output_type == 1:
        sct.run('sct_apply_transfo -i template.nii -o template2anat.nii.gz -d data.nii -w warp_template2anat.nii.gz -c 1', verbose)
        sct.run('sct_apply_transfo -i data.nii -o anat2template.nii.gz -d template.nii -w warp_anat2template.nii.gz -c 1', verbose)

    # come back to parent folder
    os.chdir('..')

   # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(path_tmp+'/warp_template2anat.nii.gz', 'warp_template2anat.nii.gz', verbose)
    sct.generate_output_file(path_tmp+'/warp_anat2template.nii.gz', 'warp_anat2template.nii.gz', verbose)
    if output_type == 1:
        sct.generate_output_file(path_tmp+'/template2anat.nii.gz', 'template2anat'+ext_data, verbose)
        sct.generate_output_file(path_tmp+'/anat2template.nii.gz', 'anat2template'+ext_data, verbose)

    # Delete temporary files
    if remove_temp_files:
        sct.printv('\nDelete temporary files...', verbose)
        sct.run('rm -rf '+path_tmp)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: '+str(int(round(elapsed_time)))+'s', verbose)

    # to view results
    sct.printv('\nTo view results, type:', verbose)
    sct.printv('fslview '+fname_data+' template2anat -b 0,4000 &', verbose, 'info')
    sct.printv('fslview '+fname_template+' -b 0,5000 anat2template &\n', verbose, 'info')
def compute_csa(fname_segmentation, name_method, volume_output, verbose, remove_temp_files, step, smoothing_param, figure_fit, name_output, slices, vert_levels, path_to_template, algo_fitting = 'hanning', type_window = 'hanning', window_length = 80):
    """Compute the cord cross-sectional area (CSA) slice by slice.

    The segmentation is copied to a time-stamped temporary folder and
    reoriented to RPI. For each z slice the CSA is computed as
    sum(segmentation values) * px * py * cos(angle). The angle is measured
    between the smoothed centerline tangent and the z axis.

    Outputs:
      - ``path_data + param.fname_csa``: one "z,csa" line per slice.
      - If ``volume_output``: a volume where each segmented voxel holds its
        slice's CSA, written to ``path_data + name_output``.
      - If ``slices`` or ``vert_levels``: ``csa_mean.txt``, produced by
        sct_extract_metric in the parent of the temporary folder.

    ``name_method``, ``step`` and ``figure_fit`` are accepted but not used
    here.

    Args:
        smoothing_param: width in mm of the Hanning window used to smooth
            CSA across slices; falsy disables smoothing.
        vert_levels: vertebral levels to average over; requires
            ``path_to_template``.
    """
    # Extract path, file and extension
    fname_segmentation = os.path.abspath(fname_segmentation)
    path_data, file_data, ext_data = sct.extract_fname(fname_segmentation)

    # create temporary folder
    sct.printv('\nCreate temporary folder...', verbose)
    path_tmp = sct.slash_at_the_end('tmp.'+time.strftime("%y%m%d%H%M%S"), 1)
    sct.run('mkdir '+path_tmp, verbose)

    # Copying input data to tmp folder and convert to nii
    sct.printv('\nCopying input data to tmp folder and convert to nii...', verbose)
    sct.run('isct_c3d '+fname_segmentation+' -o '+path_tmp+'segmentation.nii')

    # go to tmp folder
    os.chdir(path_tmp)

    # Change orientation of the input segmentation into RPI
    sct.printv('\nChange orientation of the input segmentation into RPI...', verbose)
    fname_segmentation_orient = set_orientation('segmentation.nii', 'RPI', 'segmentation_orient.nii')

    # Get size of data
    sct.printv('\nGet data dimensions...', verbose)
    nx, ny, nz, nt, px, py, pz, pt = Image(fname_segmentation_orient).dim
    sct.printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), verbose)

    # Open segmentation volume
    sct.printv('\nOpen segmentation volume...', verbose)
    file_seg = nibabel.load(fname_segmentation_orient)
    data_seg = file_seg.get_data()
    hdr_seg = file_seg.get_header()

    # Extract min and max index in Z direction
    _, _, Z = (data_seg > 0).nonzero()
    if len(Z) == 0:
        # min()/max() of an empty sequence would raise an opaque ValueError
        sct.printv('\nERROR: Segmentation is empty. Cannot compute CSA.\n', 1, 'error')
        sys.exit(2)
    min_z_index, max_z_index = Z.min(), Z.max()

    # extract centerline and smooth it
    x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(fname_segmentation_orient, algo_fitting=algo_fitting, type_window=type_window, window_length=window_length, verbose=verbose)
    z_centerline_scaled = [x*pz for x in z_centerline]

    # Compute CSA
    sct.printv('\nCompute CSA...', verbose)

    # one CSA value per slice in [min_z_index, max_z_index]
    # NOTE(review): assumes len(z_centerline) == max_z_index-min_z_index+1; confirm
    csa = np.zeros(max_z_index-min_z_index+1)

    for iz in xrange(0, len(z_centerline)):

        # compute the vector normal to the plane
        normal = normalize(np.array([x_centerline_deriv[iz], y_centerline_deriv[iz], z_centerline_deriv[iz]]))

        # compute the angle between the normal vector of the plane and the vector z
        angle = np.arccos(np.dot(normal, [0, 0, 1]))

        # number of voxels, assuming the segmentation encodes partial volume
        # between 0 and 1. np.sum upcasts small integer dtypes, whereas the
        # builtin sum(sum(...)) accumulates in the array dtype and can overflow.
        number_voxels = np.sum(data_seg[:, :, iz+min_z_index])

        # compute CSA, by scaling with voxel size (in mm) and adjusting for oblique plane
        csa[iz] = number_voxels * px * py * np.cos(angle)

    if smoothing_param:
        from msct_smooth import smoothing_window
        sct.printv('\nSmooth CSA across slices...', verbose)
        sct.printv('.. Hanning window: '+str(smoothing_param)+' mm', verbose)
        # convert window width from mm to number of slices
        csa_smooth = smoothing_window(csa, window_len=smoothing_param/pz, window='hanning', verbose=0)
        # display figure
        if verbose == 2:
            import matplotlib.pyplot as plt
            plt.figure()
            pltx, = plt.plot(z_centerline_scaled, csa, 'bo')
            pltx_fit, = plt.plot(z_centerline_scaled, csa_smooth, 'r', linewidth=2)
            plt.title("Cross-sectional area (CSA)")
            plt.xlabel('z (mm)')
            plt.ylabel('CSA (mm^2)')
            plt.legend([pltx, pltx_fit], ['Raw', 'Smoothed'])
            plt.show()
        # update variable
        csa = csa_smooth

    # Create output text file (one "z,csa" line per slice)
    sct.printv('\nWrite text file...', verbose)
    with open('csa.txt', 'w') as file_results:
        for i in range(min_z_index, max_z_index+1):
            file_results.write(str(int(i)) + ',' + str(csa[i-min_z_index])+'\n')
            # Display results
            sct.printv('z='+str(i-min_z_index)+': '+str(csa[i-min_z_index])+' mm^2', verbose, 'bold')

    # output volume of csa values
    if volume_output:
        sct.printv('\nCreate volume of CSA values...', verbose)
        # get orientation of the input data
        orientation = get_orientation('segmentation.nii')
        data_seg = data_seg.astype(np.float32, copy=False)
        # replace every segmented voxel of each slice with that slice's CSA
        # (basic slicing returns a view, so the assignment writes into data_seg)
        for iz in range(min_z_index, max_z_index+1):
            slice_seg = data_seg[:, :, iz]
            slice_seg[slice_seg > 0] = csa[iz-min_z_index]
        # store CSA values as float32
        hdr_seg.set_data_dtype('float32')
        # save volume
        img = nibabel.Nifti1Image(data_seg, None, hdr_seg)
        nibabel.save(img, 'csa_RPI.nii')
        # Change orientation of the output volume back into the input orientation
        fname_csa_volume = set_orientation('csa_RPI.nii', orientation, 'csa_RPI_orient.nii')

    # come back to parent folder
    os.chdir('..')

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    from shutil import copyfile
    # extension already included in param.fname_csa
    copyfile(path_tmp+'csa.txt', path_data+param.fname_csa)
    if volume_output:
        sct.generate_output_file(fname_csa_volume, path_data+name_output)  # extension already included in name_output

    # average csa across vertebral levels or slices if asked (flag -z or -l)
    if slices or vert_levels:

        if vert_levels and not path_to_template:
            sct.printv('\nERROR: Path to template is missing. See usage.\n', 1, 'error')
            sys.exit(2)
        elif vert_levels and path_to_template:
            abs_path_to_template = os.path.abspath(path_to_template)

        # go to tmp folder
        os.chdir(path_tmp)

        # create temporary folder
        sct.printv('\nCreate temporary folder to average csa...', verbose)
        path_tmp_extract_metric = sct.slash_at_the_end('label_temp', 1)
        sct.run('mkdir '+path_tmp_extract_metric, verbose)

        # Copy the input segmentation into the label folder
        sct.printv('\nCopy data to tmp folder...', verbose)
        sct.run('cp '+fname_segmentation+' '+path_tmp_extract_metric)

        # create file info_label
        path_fname_seg, file_fname_seg, ext_fname_seg = sct.extract_fname(fname_segmentation)
        create_info_label('info_label.txt', path_tmp_extract_metric, file_fname_seg+ext_fname_seg)

        # average CSA
        if slices:
            os.system("sct_extract_metric -i "+path_data+name_output+" -f "+path_tmp_extract_metric+" -m wa -o ../csa_mean.txt -z "+slices)
        if vert_levels:
            sct.run('cp -R '+abs_path_to_template+' .')
            os.system("sct_extract_metric -i "+path_data+name_output+" -f "+path_tmp_extract_metric+" -m wa -o ../csa_mean.txt -v "+vert_levels)

        os.chdir('..')

        # Remove temporary files
        print('\nRemove temporary folder used to average CSA...')
        sct.run('rm -rf '+path_tmp_extract_metric)

    # Remove temporary files
    if remove_temp_files:
        print('\nRemove temporary files...')
        sct.run('rm -rf '+path_tmp)
Ejemplo n.º 8
0
def compute_csa(fname_segmentation,
                name_method,
                volume_output,
                verbose,
                remove_temp_files,
                step,
                smoothing_param,
                figure_fit,
                name_output,
                slices,
                vert_levels,
                path_to_template,
                algo_fitting='hanning',
                type_window='hanning',
                window_length=80):
    """Compute the cross-sectional area (CSA) of a segmentation, slice by slice.

    The segmentation is converted to NIfTI in a temporary folder, reoriented
    to RPI, and a smoothed centerline is fitted through it with
    ``smooth_centerline``. For each axial slice spanned by the segmentation,
    the CSA is the sum of the voxel values in that slice (so partial-volume
    segmentations with values in [0, 1] are supported) multiplied by the
    in-plane voxel area ``px * py``. It is then multiplied by the cosine of
    the angle between the centerline direction and the z axis.

    Outputs, written into ``path_data`` (the folder of the input):
      - ``param.fname_csa`` (``param`` is a module-level object): one
        ``z,csa`` line per slice.
      - ``name_output`` (only if ``volume_output``): a volume where every
        segmented voxel holds the CSA of its slice, in the input orientation.
    If ``slices`` or ``vert_levels`` is given, ``sct_extract_metric`` is run
    on ``path_data + name_output`` and writes ``csa_mean.txt`` into the
    parent of the temporary folder.

    Args:
        fname_segmentation: path to the segmentation image.
        name_method, step, figure_fit: not used by this implementation; kept
            so the call signature stays unchanged.
        volume_output: if truthy, also write the CSA volume.
        verbose: verbosity level; 2 also displays raw vs smoothed CSA.
        remove_temp_files: if truthy, delete the temporary folder at the end.
        smoothing_param: length (mm) of the Hanning window used to smooth the
            CSA across slices; falsy disables smoothing.
        name_output: file name of the CSA volume (extension included).
        slices: value passed to ``sct_extract_metric -z``.
        vert_levels: value passed to ``sct_extract_metric -v``; requires
            ``path_to_template``.
        path_to_template: template folder copied into the temporary folder
            when averaging across vertebral levels.
        algo_fitting, type_window, window_length: forwarded to
            ``smooth_centerline``.

    Raises:
        ValueError: if the segmentation contains no voxel > 0.
    """
    # Extract path, file and extension
    fname_segmentation = os.path.abspath(fname_segmentation)
    path_data, file_data, ext_data = sct.extract_fname(fname_segmentation)

    # create temporary folder
    sct.printv('\nCreate temporary folder...', verbose)
    path_tmp = sct.slash_at_the_end('tmp.' + time.strftime("%y%m%d%H%M%S"), 1)
    sct.run('mkdir ' + path_tmp, verbose)

    # Copying input data to tmp folder and convert to nii
    sct.printv('\nCopying input data to tmp folder and convert to nii...',
               verbose)
    sct.run('isct_c3d ' + fname_segmentation + ' -o ' + path_tmp +
            'segmentation.nii')

    # go to tmp folder: the relative file names below all live in path_tmp
    os.chdir(path_tmp)

    # Change orientation of the input segmentation into RPI
    sct.printv('\nChange orientation of the input segmentation into RPI...',
               verbose)
    fname_segmentation_orient = set_orientation('segmentation.nii', 'RPI',
                                                'segmentation_orient.nii')

    # Get size of data
    sct.printv('\nGet data dimensions...', verbose)
    nx, ny, nz, nt, px, py, pz, pt = Image(fname_segmentation_orient).dim
    sct.printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), verbose)

    # Open segmentation volume
    sct.printv('\nOpen segmentation volume...', verbose)
    file_seg = nibabel.load(fname_segmentation_orient)
    data_seg = file_seg.get_data()
    hdr_seg = file_seg.get_header()

    # Extract min and max index in Z direction
    Z = (data_seg > 0).nonzero()[2]
    if Z.size == 0:
        # without this check min()/max() fail with an opaque message
        raise ValueError('Segmentation is empty: ' + fname_segmentation)
    min_z_index, max_z_index = Z.min(), Z.max()

    # extract centerline and smooth it
    (x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv,
     y_centerline_deriv, z_centerline_deriv) = smooth_centerline(
         fname_segmentation_orient,
         algo_fitting=algo_fitting,
         type_window=type_window,
         window_length=window_length,
         verbose=verbose)
    z_centerline_scaled = [z_vox * pz for z_vox in z_centerline]

    # Compute CSA
    sct.printv('\nCompute CSA...', verbose)

    # one CSA value per slice from min_z_index to max_z_index
    csa = np.zeros(max_z_index - min_z_index + 1)

    for iz in range(0, len(z_centerline)):

        # vector normal to the cross-section plane (centerline direction)
        normal = normalize(
            np.array([
                x_centerline_deriv[iz], y_centerline_deriv[iz],
                z_centerline_deriv[iz]
            ]))

        # angle between the plane normal and z; the dot product is clipped
        # because rounding can push it slightly outside [-1, 1], where
        # arccos returns NaN
        angle = np.arccos(np.clip(np.dot(normal, [0, 0, 1]), -1.0, 1.0))

        # number of voxels, assuming the segmentation is coded for partial
        # volume effect between 0 and 1. np.sum promotes small integer dtypes
        # (e.g. uint8), so the accumulation cannot wrap around.
        number_voxels = np.sum(data_seg[:, :, iz + min_z_index])

        # compute CSA, by scaling with voxel size (in mm) and adjusting for
        # oblique plane
        csa[iz] = number_voxels * px * py * np.cos(angle)

    if smoothing_param:
        from msct_smooth import smoothing_window
        sct.printv('\nSmooth CSA across slices...', verbose)
        sct.printv('.. Hanning window: ' + str(smoothing_param) + ' mm',
                   verbose)
        csa_smooth = smoothing_window(csa,
                                      window_len=smoothing_param / pz,
                                      window='hanning',
                                      verbose=0)
        # display figure
        if verbose == 2:
            import matplotlib.pyplot as plt
            plt.figure()
            pltx, = plt.plot(z_centerline_scaled, csa, 'bo')
            pltx_fit, = plt.plot(z_centerline_scaled,
                                 csa_smooth,
                                 'r',
                                 linewidth=2)
            plt.title("Cross-sectional area (CSA)")
            plt.xlabel('z (mm)')
            plt.ylabel('CSA (mm^2)')
            plt.legend([pltx, pltx_fit], ['Raw', 'Smoothed'])
            plt.show()
        # update variable
        csa = csa_smooth

    # Create output text file (the with-block closes it even on error)
    sct.printv('\nWrite text file...', verbose)
    with open('csa.txt', 'w') as file_results:
        for i in range(min_z_index, max_z_index + 1):
            file_results.write(
                str(int(i)) + ',' + str(csa[i - min_z_index]) + '\n')
            # Display results
            sct.printv(
                'z=' + str(i - min_z_index) + ': ' +
                str(csa[i - min_z_index]) + ' mm^2', verbose, 'bold')

    # output volume of csa values
    if volume_output:
        sct.printv('\nCreate volume of CSA values...', verbose)
        # get orientation of the input data
        orientation = get_orientation('segmentation.nii')
        data_seg = data_seg.astype(np.float32, copy=False)
        # replace every segmented voxel of each slice by the slice's CSA
        for iz in range(min_z_index, max_z_index + 1):
            slice_seg = data_seg[:, :, iz]  # a view: writes go to data_seg
            slice_seg[slice_seg > 0] = csa[iz - min_z_index]
        # create header: values are stored as float32
        hdr_seg.set_data_dtype('float32')
        # save volume
        img = nibabel.Nifti1Image(data_seg, None, hdr_seg)
        nibabel.save(img, 'csa_RPI.nii')
        # Change orientation of the output volume back to input orientation
        fname_csa_volume = set_orientation('csa_RPI.nii', orientation,
                                           'csa_RPI_orient.nii')

    # come back to parent folder
    os.chdir('..')

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    from shutil import copyfile
    copyfile(path_tmp + 'csa.txt', path_data + param.fname_csa)
    if volume_output:
        sct.generate_output_file(
            fname_csa_volume, path_data +
            name_output)  # extension already included in name_output

    # average csa across vertebral levels or slices if asked (flag -z or -l)
    # NOTE(review): sct_extract_metric reads path_data+name_output, which is
    # only produced when volume_output is truthy.
    if slices or vert_levels:

        if vert_levels and not path_to_template:
            sct.printv('\nERROR: Path to template is missing. See usage.\n', 1,
                       'error')
            sys.exit(2)
        elif vert_levels and path_to_template:
            abs_path_to_template = os.path.abspath(path_to_template)

        # go to tmp folder
        os.chdir(path_tmp)

        # create temporary folder
        sct.printv('\nCreate temporary folder to average csa...', verbose)
        path_tmp_extract_metric = sct.slash_at_the_end('label_temp', 1)
        sct.run('mkdir ' + path_tmp_extract_metric, verbose)

        # Copy the input segmentation into the label folder
        sct.printv('\nCopy data to tmp folder...', verbose)
        sct.run('cp ' + fname_segmentation + ' ' + path_tmp_extract_metric)

        # create file info_label referencing the copied segmentation
        path_fname_seg, file_fname_seg, ext_fname_seg = sct.extract_fname(
            fname_segmentation)
        create_info_label('info_label.txt', path_tmp_extract_metric,
                          file_fname_seg + ext_fname_seg)

        # average CSA; output goes to the parent of path_tmp
        if slices:
            os.system("sct_extract_metric -i " + path_data + name_output +
                      " -f " + path_tmp_extract_metric +
                      " -m wa -o ../csa_mean.txt -z " + slices)
        if vert_levels:
            sct.run('cp -R ' + abs_path_to_template + ' .')
            os.system("sct_extract_metric -i " + path_data + name_output +
                      " -f " + path_tmp_extract_metric +
                      " -m wa -o ../csa_mean.txt -v " + vert_levels)

        os.chdir('..')

        # Remove temporary files
        print('\nRemove temporary folder used to average CSA...')
        sct.run('rm -rf ' + path_tmp_extract_metric)

    # Remove temporary files
    if remove_temp_files:
        print('\nRemove temporary files...')
        sct.run('rm -rf ' + path_tmp)
    def crop_with_gui(self):
        """Crop the input volume along z between two points clicked by the user.

        The input (``self.input_filename``) is converted to NIfTI in a
        temporary folder and reoriented to RPI. Its medial sagittal slice
        (x index ``nx // 2``) is then displayed. The user left-clicks the top
        and bottom of the cropping field. The RPI volume is cropped with
        ``fslroi`` between the two clicked z indices (inclusive). The result
        is written to the current directory as
        ``<input name>_crop<input extension>``.

        ``self.rm_tmp_files == 1`` deletes the temporary folder afterwards, and
        ``self.verbose`` controls logging. Calls ``sys.exit(2)`` if the data
        is not 3D or if the user did not select exactly two points.
        """
        # Initialization
        fname_data = self.input_filename
        suffix_out = '_crop'
        remove_temp_files = self.rm_tmp_files
        verbose = self.verbose

        # for faster processing, all outputs are in NIFTI
        fsloutput = 'export FSLOUTPUTTYPE=NIFTI; '

        # Check file existence
        sct.printv('\nCheck file existence...', verbose)
        sct.check_file_exist(fname_data, verbose)

        # Get dimensions of data
        sct.printv('\nGet dimensions of data...', verbose)
        nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(fname_data)
        sct.printv('.. '+str(nx)+' x '+str(ny)+' x '+str(nz), verbose)
        # check if 4D data
        if nt != 1:
            sct.printv('\nERROR in '+os.path.basename(__file__)+': Data should be 3D.\n', 1, 'error')
            sys.exit(2)

        # print arguments (single-argument print(...) behaves the same in
        # Python 2 and Python 3)
        print('\nCheck parameters:')
        print('  data ................... '+fname_data)
        print('')

        # Extract path/file/extension
        path_data, file_data, ext_data = sct.extract_fname(fname_data)
        path_out, file_out, ext_out = '', file_data+suffix_out, ext_data

        # create temporary folder
        path_tmp = 'tmp.'+time.strftime("%y%m%d%H%M%S")+'/'
        sct.run('mkdir '+path_tmp)

        # copy files into tmp folder
        sct.run('isct_c3d '+fname_data+' -o '+path_tmp+'data.nii')

        # go to tmp folder; the finally clause guarantees we come back even
        # if the user aborts (sys.exit raises SystemExit)
        os.chdir(path_tmp)
        try:
            # change orientation
            sct.printv('\nChange orientation to RPI...', verbose)
            set_orientation('data.nii', 'RPI', 'data_rpi.nii')

            # get image of medial slab
            sct.printv('\nGet image of medial slab...', verbose)
            image_array = nibabel.load('data_rpi.nii').get_data()
            nx, ny, nz = image_array.shape
            # integer index: math.floor returns a float in Python 2 and
            # NumPy rejects float indices
            scipy.misc.imsave('image.jpg', image_array[nx // 2, :, :])

            # Display the image (transposed so that z is the vertical axis)
            sct.printv('\nDisplay image and get cropping region...', verbose)
            fig = plt.figure()
            ax = fig.add_subplot(111)
            img = mpimg.imread("image.jpg")
            implot = ax.imshow(img.T)
            implot.set_cmap('gray')
            plt.gca().invert_yaxis()
            # mouse callback
            ax.set_title('Left click on the top and bottom of your cropping field.\n Right click to remove last point.\n Close window when your done.')
            line, = ax.plot([], [], 'ro')  # empty line
            cropping_coordinates = LineBuilder(line)
            plt.show()

            # check if user clicked two times
            if len(cropping_coordinates.xs) != 2:
                sct.printv('\nERROR: You have to select two points. Exit program.\n', 1, 'error')
                sys.exit(2)

            # vertical click coordinates are z indices; convert and sort them
            zcrop = sorted(int(i) for i in cropping_coordinates.ys)

            # crop image: fslroi takes <zmin> <zsize>
            sct.printv('\nCrop image...', verbose)
            sct.run(fsloutput+'fslroi data_rpi.nii data_rpi_crop.nii 0 -1 0 -1 '+str(zcrop[0])+' '+str(zcrop[1]-zcrop[0]+1))
        finally:
            # come back to parent folder
            os.chdir('..')

        sct.printv('\nGenerate output files...', verbose)
        sct.generate_output_file(path_tmp+'data_rpi_crop.nii', path_out+file_out+ext_out)

        # Remove temporary files
        if remove_temp_files == 1:
            print('\nRemove temporary files...')
            sct.run('rm -rf '+path_tmp)

        # to view results
        print('\nDone! To view results, type:')
        print('fslview '+path_out+file_out+ext_out+' &')
        print('')
def compute_csa(fname_segmentation, name_method, volume_output, verbose,
                remove_temp_files, spline_smoothing, step, smoothing_param,
                figure_fit, name_output, slices, vert_levels,
                path_to_template, algo_fitting='hanning',
                type_window='hanning', window_length=80):
    """Compute the cross-sectional area (CSA) of a segmentation along its centerline.

    The segmentation is converted to NIfTI in a temporary folder, reoriented
    to RPI, and a smoothed centerline is fitted through it with
    ``smooth_centerline``. For each slice ``iz`` of the centerline, the plane
    orthogonal to the centerline has its normal along the centerline
    derivatives. The CSA is computed with ``name_method``:

      - 'counting_z_plane': sum of the voxel values > 0 of the axial slice,
        times ``px * py``, times cos(angle between that normal and z).
      - 'ellipse_z_plane': an ellipse is fitted (``edge_detection`` +
        ``Ellipse_fit``) to the axial slice saved as JPG. Its area is then
        multiplied by the same cosine.
      - 'counting_ortho_plane': the orthogonal plane is sampled on a grid of
        step ``min(px, py)``. The CSA is the number of samples that fall in a
        segmented voxel, times the cell area.
      - 'ellipse_ortho_plane': an ellipse is fitted to that sampled plane
        saved as JPG.
    Any other value leaves every CSA at 0. JPG files are written to
    ``JPG_Results`` inside the temporary folder.

    Outputs, written into ``path_data`` (the folder of the input):
      - ``param.fname_csa`` (``param`` is a module-level object): one
        ``z,csa`` line per slice.
      - ``name_output`` (only if ``volume_output``): a volume where every
        segmented voxel holds the CSA of its slice, in the input orientation.
    If ``slices`` or ``vert_levels`` is given, ``sct_extract_metric`` writes
    ``mean_csa`` into ``path_data``.

    Args:
        fname_segmentation: path to the segmentation image.
        name_method: one of the method names listed above.
        volume_output: if truthy, also write the CSA volume.
        verbose: verbosity level passed to ``sct`` helpers.
        remove_temp_files: if truthy, delete the temporary folder at the end.
        spline_smoothing: if 1, smooth the CSA with a smoothing spline.
        step: ignored; the orthogonal-plane methods always sample with
            ``min(px, py)``.
        smoothing_param: smoothing factor ``s`` for
            ``scipy.interpolate.splrep``.
        figure_fit: if 1 (with spline smoothing), save ``Spline_fit.png`` in
            the temporary folder.
        name_output: file name of the CSA volume (extension included).
        slices: value passed to ``sct_extract_metric -z``.
        vert_levels: value passed to ``sct_extract_metric -v``; requires
            ``path_to_template``.
        path_to_template: template folder copied into the temporary folder
            when averaging across vertebral levels.
        algo_fitting, type_window, window_length: forwarded to
            ``smooth_centerline``.

    Raises:
        ValueError: if the segmentation contains no voxel > 0.
    """
    # Extract path, file and extension
    fname_segmentation = os.path.abspath(fname_segmentation)
    path_data, file_data, ext_data = sct.extract_fname(fname_segmentation)

    # create temporary folder
    sct.printv('\nCreate temporary folder...', verbose)
    path_tmp = sct.slash_at_the_end('tmp.'+time.strftime("%y%m%d%H%M%S"), 1)
    sct.run('mkdir '+path_tmp, verbose)

    # Copying input data to tmp folder and convert to nii
    sct.printv('\nCopying input data to tmp folder and convert to nii...', verbose)
    sct.run('isct_c3d '+fname_segmentation+' -o '+path_tmp+'segmentation.nii')

    # go to tmp folder: the working directory stays path_tmp until the
    # os.chdir('..') below (JPG files are addressed by path, never by chdir)
    os.chdir(path_tmp)

    # Change orientation of the input segmentation into RPI
    sct.printv('\nChange orientation of the input segmentation into RPI...', verbose)
    fname_segmentation_orient = set_orientation('segmentation.nii', 'RPI', 'segmentation_orient.nii')

    # Get size of data
    sct.printv('\nGet data dimensions...', verbose)
    nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(fname_segmentation_orient)
    sct.printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), verbose)

    # Open segmentation volume
    sct.printv('\nOpen segmentation volume...', verbose)
    file_seg = nibabel.load(fname_segmentation_orient)
    data_seg = file_seg.get_data()
    hdr_seg = file_seg.get_header()

    # Coordinates of segmented voxels; min and max index in Z direction
    X, Y, Z = (data_seg > 0).nonzero()
    if Z.size == 0:
        # without this check min()/max() fail with an opaque message
        raise ValueError('Segmentation is empty: ' + fname_segmentation)
    min_z_index, max_z_index = Z.min(), Z.max()

    # extract centerline and smooth it
    (x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv,
     y_centerline_deriv, z_centerline_deriv) = smooth_centerline(
         fname_segmentation_orient, algo_fitting=algo_fitting,
         type_window=type_window, window_length=window_length,
         verbose=verbose)
    z_centerline_scaled = [z_vox*pz for z_vox in z_centerline]

    # Cartesian basis vectors (named so they are not shadowed by loop locals)
    x_axis = np.array([1, 0, 0])
    z_axis = np.array([0, 0, 1])

    # Creating folder in which JPG files will be stored
    sct.printv('\nCreating folder in which JPG files will be stored...', verbose)
    sct.create_folder('JPG_Results')
    dir_jpg = 'JPG_Results'  # relative to path_tmp

    is_ortho_method = name_method in ('counting_ortho_plane', 'ellipse_ortho_plane')
    is_z_method = name_method in ('counting_z_plane', 'ellipse_z_plane')

    if name_method in ('ellipse_ortho_plane', 'ellipse_z_plane'):
        # import scipy stuff (only needed to save the JPG planes)
        from scipy.misc import imsave

    if is_ortho_method:
        # set of (x, y, z) indices of segmented voxels: O(1) membership tests
        coords_seg = set(zip(X.tolist(), Y.tolist(), Z.tolist()))
        # Forcing the step to be the min of x and y scale (overrides `step`)
        step = min([px, py])
        # largest in-plane extent of the segmentation, in mm
        extent_seg = max([(X.max()-X.min())*px, (Y.max()-Y.min())*py])

    # Compute CSA
    sct.printv('\nCompute CSA...', verbose)

    # one CSA value per slice from min_z_index to max_z_index
    csa = [0.0 for i in range(0, max_z_index-min_z_index+1)]
    sct.printv('\nCross-Section Area:', verbose, 'bold')

    for iz in range(0, len(z_centerline)):

        # vector normal to the plane orthogonal to the centerline at z=iz
        normal = normalize(np.array([x_centerline_deriv[iz],
                                     y_centerline_deriv[iz],
                                     z_centerline_deriv[iz]]))

        # angle between normal vector and z; clipped because rounding can push
        # the dot product slightly outside [-1, 1], where arccos returns NaN
        angle = np.arccos(np.clip(np.dot(normal, z_axis), -1.0, 1.0))

        if is_ortho_method:

            center_mm = np.array([x_centerline_fit[iz]*px,
                                  y_centerline_fit[iz]*py,
                                  z_centerline[iz]*pz])

            # in-plane orthonormal basis; x_axis fixes the plane orientation
            basis_1 = normalize(np.cross(normal, x_axis))
            basis_2 = normalize(np.cross(normal, basis_1))

            # maximum dimension of the tilted plane.
            # TODO: try multiplying the numerator by sqrt(2)?
            max_diameter = extent_seg/np.cos(angle)
            n_grid = int(max_diameter/step)

            # discretized plane which will be filled with 0/1
            plane_seg = np.zeros((n_grid, n_grid))

            # how the plane will be skimmed through
            plane_grid = np.linspace(-int(max_diameter/2), int(max_diameter/2), n_grid)

            # enumerate gives grid indices directly (robust to repeated values)
            for i1, i_b1 in enumerate(plane_grid):
                for i2, i_b2 in enumerate(plane_grid):
                    point = center_mm + i_b1*basis_1 + i_b2*basis_2
                    # voxel to which this point of the plane belongs
                    coord_voxel = (int(point[0]/px), int(point[1]/py), int(point[2]/pz))
                    if coord_voxel in coords_seg:
                        plane_seg[i1, i2] = 1

            if name_method == 'counting_ortho_plane':
                # number of hit cells times the area of one cell
                csa[iz] = len((plane_seg > 0).nonzero()[0])*step*step

            elif name_method == 'ellipse_ortho_plane':
                fname_jpg = os.path.join(dir_jpg, 'plane_ortho_' + str(iz) + '.jpg')
                imsave(fname_jpg, plane_seg)

                # Thresholded gradient image
                mag = edge_detection(fname_jpg)

                # Coordinates of the contour, in mm
                x_contour, y_contour = (mag > 0).nonzero()
                x_contour = x_contour*step
                y_contour = y_contour*step

                # Fitting an ellipse; semi-axes a, b -> area = pi*a*b
                fit = Ellipse_fit(x_contour, y_contour)
                a_ellipse, b_ellipse = ellipse_dim(fit)
                csa[iz] = a_ellipse*b_ellipse*np.pi

        elif is_z_method:

            # axial slice, zero outside the segmentation (float64, as before,
            # so imsave scales it the same way)
            slice_seg = data_seg[:, :, iz+min_z_index]
            inside = slice_seg > 0
            plane = np.zeros(slice_seg.shape)
            plane[inside] = slice_seg[inside]

            if name_method == 'counting_z_plane':
                csa[iz] = np.sum(plane)*px*py*np.cos(angle)

            elif name_method == 'ellipse_z_plane':
                fname_jpg = os.path.join(dir_jpg, 'plane_z_' + str(iz) + '.jpg')
                imsave(fname_jpg, plane)

                # Thresholded gradient image
                mag = edge_detection(fname_jpg)

                # Coordinates of the contour, in mm
                x_contour, y_contour = (mag > 0).nonzero()
                x_contour = x_contour*px
                y_contour = y_contour*py

                # Fitting an ellipse
                fit = Ellipse_fit(x_contour, y_contour)
                a_ellipse, b_ellipse = ellipse_dim(fit)
                csa[iz] = a_ellipse*b_ellipse*np.pi*np.cos(angle)

    if spline_smoothing == 1:
        sct.printv('\nSmoothing results with spline...', verbose)
        tck = scipy.interpolate.splrep(z_centerline_scaled, csa, s=smoothing_param)
        csa_smooth = scipy.interpolate.splev(z_centerline_scaled, tck)
        if figure_fit == 1:
            import matplotlib.pyplot as plt
            plt.figure()
            plt.plot(z_centerline_scaled, csa)
            plt.plot(z_centerline_scaled, csa_smooth)
            # loc must be a keyword: a 2nd positional arg means "labels"
            plt.legend(['CSA values', 'Smoothed values'], loc=2)
            plt.savefig('Spline_fit.png')
            plt.close()
        csa = csa_smooth  # update variable

    # Create output text file (the with-block closes it even on error)
    sct.printv('\nWrite text file...', verbose)
    with open('csa.txt', 'w') as file_results:
        for i in range(min_z_index, max_z_index+1):
            file_results.write(str(int(i)) + ',' + str(csa[i-min_z_index])+'\n')
            # Display results
            sct.printv('z='+str(i-min_z_index)+': '+str(csa[i-min_z_index])+' mm^2', verbose, 'bold')

    # output volume of csa values
    if volume_output:
        sct.printv('\nCreate volume of CSA values...', verbose)
        # get orientation of the input data
        orientation = get_orientation('segmentation.nii')
        data_seg = data_seg.astype(np.float32, copy=False)
        # replace every segmented voxel of each slice by the slice's CSA
        for iz in range(min_z_index, max_z_index+1):
            slice_seg = data_seg[:, :, iz]  # a view: writes go to data_seg
            slice_seg[slice_seg > 0] = csa[iz-min_z_index]
        # create header: values are stored as float32
        hdr_seg.set_data_dtype('float32')
        # save volume
        img = nibabel.Nifti1Image(data_seg, None, hdr_seg)
        nibabel.save(img, 'csa_RPI.nii')
        # Change orientation of the output volume back to input orientation
        fname_csa_volume = set_orientation('csa_RPI.nii', orientation, 'csa_RPI_orient.nii')

    # come back to parent folder
    os.chdir('..')

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(path_tmp+'csa.txt', path_data+param.fname_csa)  # extension already included in param.fname_csa
    if volume_output:
        sct.generate_output_file(fname_csa_volume, path_data+name_output)  # extension already included in name_output

    # average csa across vertebral levels or slices if asked (flag -z or -l)
    # NOTE(review): sct_extract_metric reads path_data+name_output, which is
    # only produced when volume_output is truthy.
    if slices or vert_levels:

        if vert_levels and not path_to_template:
            sct.printv('\nERROR: Path to template is missing. See usage.\n', 1, 'error')
            sys.exit(2)
        elif vert_levels and path_to_template:
            abs_path_to_template = os.path.abspath(path_to_template)

        # go to tmp folder
        os.chdir(path_tmp)

        # create temporary folder
        sct.printv('\nCreate temporary folder to average csa...', verbose)
        path_tmp_extract_metric = sct.slash_at_the_end('label_temp', 1)
        sct.run('mkdir '+path_tmp_extract_metric, verbose)

        # Copy the input segmentation into the label folder
        sct.printv('\nCopy data to tmp folder...', verbose)
        sct.run('cp '+fname_segmentation+' '+path_tmp_extract_metric)

        # create file info_label referencing the copied segmentation
        path_fname_seg, file_fname_seg, ext_fname_seg = sct.extract_fname(fname_segmentation)
        create_info_label('info_label.txt', path_tmp_extract_metric, file_fname_seg+ext_fname_seg)

        if slices:
            # average CSA
            os.system("sct_extract_metric -i "+path_data+name_output+" -f "+path_tmp_extract_metric+" -m wa -o "+sct.slash_at_the_end(path_data)+"mean_csa -z "+slices)
        if vert_levels:
            sct.run('cp -R '+abs_path_to_template+' .')
            # average CSA
            os.system("sct_extract_metric -i "+path_data+name_output+" -f "+path_tmp_extract_metric+" -m wa -o "+sct.slash_at_the_end(path_data)+"mean_csa -v "+vert_levels)

        os.chdir('..')

        # Remove temporary files
        print('\nRemove temporary folder used to average CSA...')
        sct.run('rm -rf '+path_tmp_extract_metric)

    # Remove temporary files
    if remove_temp_files:
        print('\nRemove temporary files...')
        sct.run('rm -rf '+path_tmp)
def main(segmentation_file=None,
         label_file=None,
         output_file_name=None,
         parameter="binary_centerline",
         remove_temp_files=1,
         verbose=0):
    """Extract a centerline from a segmentation and/or a label file.

    The mode is chosen from the flags found in the module-level ``arguments``
    sequence (NOT from which parameters are None):
    ``"-i"`` only -> segmentation, ``"-l"`` only -> label file, both -> the
    label centerline is merged with the segmentation centerline.
    NOTE(review): ``arguments`` is not defined in this function; presumably it
    holds the command-line flags set by the script entry point -- confirm.

    Parameters
    ----------
    segmentation_file : str
        Path of the segmentation image (used when "-i" is present).
    label_file : str
        Path of the label image (used when "-l" is present); its non-zero
        voxels are interpolated with a spline along z.
    output_file_name : str or None
        Name of the output file; when None a default derived from the input
        name is used (``<input>_centerline<ext>`` / ``<input>_centerline.txt``,
        or ``centerline_total_from_label_and_seg.*`` in combined mode).
    parameter : str
        ``"binary_centerline"`` writes a NIfTI volume whose centerline voxels
        are 1; ``"text_file"`` writes one ``"z x y"`` line per slice.
    remove_temp_files : int
        If truthy, the temporary working folder is deleted at the end.
    verbose : int
        1 shows matplotlib debug plots; also forwarded to ``sct.printv``.

    Returns
    -------
    str or None
        Name of the written output file, or None when no mode matches.
    """
    if parameter == "binary_centerline":
        has_seg = "-i" in arguments
        has_label = "-l" in arguments
        if has_seg and not has_label:
            return _binary_centerline_from_segmentation(
                segmentation_file, output_file_name, remove_temp_files)
        if has_label and not has_seg:
            return _binary_centerline_from_label(
                label_file, output_file_name, verbose)
        if has_seg and has_label:
            return _binary_centerline_from_label_and_segmentation(
                segmentation_file, label_file, output_file_name,
                remove_temp_files, verbose)

    if parameter == "text_file":
        print('\nText file process')
        has_seg = "-i" in arguments
        has_label = "-l" in arguments
        if has_seg and not has_label:
            return _text_centerline_from_segmentation(
                segmentation_file, output_file_name, remove_temp_files,
                verbose)
        if has_label and not has_seg:
            return _text_centerline_from_label(
                label_file, output_file_name, verbose)
        if has_seg and has_label:
            return _text_centerline_from_label_and_segmentation(
                segmentation_file, label_file, output_file_name,
                remove_temp_files, verbose)

    return None


def _create_tmp_folder():
    """Create a timestamped temporary folder in the cwd and return its name."""
    path_tmp = 'tmp.' + time.strftime("%y%m%d%H%M%S")
    sct.run('mkdir ' + path_tmp)
    return path_tmp


def _remove_tmp_folder(path_tmp, remove_temp_files):
    """Delete ``path_tmp`` when ``remove_temp_files`` is truthy."""
    if remove_temp_files:
        print('\nRemove temporary files...')
        sct.run('rm -rf ' + path_tmp)


def _output_name(output_file_name, default):
    """Return the user-supplied output name, or ``default`` when it is None."""
    return default if output_file_name is None else output_file_name


def _fit_label_centerline(data, verbose=0):
    """Interpolate a centerline through the non-zero voxels of a label volume.

    The non-zero voxel coordinates are sorted by z; x(z) and y(z) are each
    fitted with ``scipy.interpolate.UnivariateSpline`` (default degree and
    smoothing) and evaluated at every integer slice between the lowest and
    the highest labelled slice.

    Returns
    -------
    (X_fit, Y_fit, Z_new) : float arrays of equal length, with
        Z_new = z_min, z_min + 1, ..., z_max.

    Raises
    ------
    ValueError
        If the volume has no non-zero voxel.
    """
    X, Y, Z = (data > 0).nonzero()
    if Z.size == 0:
        raise ValueError('Label volume contains no non-zero voxel.')
    z_min, z_max = int(Z.min()), int(Z.max())
    Z_new = np.linspace(z_min, z_max, z_max - z_min + 1)

    # UnivariateSpline needs its abscissa in increasing order.
    order = Z.argsort()
    X = np.asarray(X[order], dtype=float)
    Y = np.asarray(Y[order], dtype=float)
    Z = np.asarray(Z[order], dtype=float)

    X_fit = interpolate.UnivariateSpline(Z, X)(Z_new)
    Y_fit = interpolate.UnivariateSpline(Z, Y)(Z_new)

    if verbose == 1:
        import matplotlib.pyplot as plt
        for coords, coords_fit in ((X, X_fit), (Y, Y_fit)):
            plt.figure()
            plt.plot(Z_new, coords_fit)
            plt.plot(Z, coords, 'o', linestyle='None')
            plt.show()

    return X_fit, Y_fit, Z_new


def _label_centerline_volume(data, X_fit, Y_fit, Z_new):
    """Return a zero array shaped/typed like ``data`` with 1 at each fitted point."""
    volume = np.zeros(data.shape, dtype=data.dtype)
    for x, y, z in zip(X_fit, Y_fit, Z_new):
        # numpy rejects float indices: snap each fitted point to nearest voxel.
        volume[int(round(x)), int(round(y)), int(round(z))] = 1
    return volume


def _save_label_centerline(data, hdr, X_fit, Y_fit, Z_new, file_name):
    """Write the fitted label centerline as a NIfTI volume to ``file_name``."""
    volume = _label_centerline_volume(data, X_fit, Y_fit, Z_new)
    # NOTE(review): stored as float32 whereas segmentation centerlines are
    # written as uint8 -- confirm which is intended.
    hdr.set_data_dtype('float32')
    nibabel.save(nibabel.Nifti1Image(volume, None, hdr), file_name)


def _segmentation_centerline_rpi(file_data, ext_data, keep_original=False):
    """Compute a spline-fitted centerline of a segmentation in RPI orientation.

    Must be called from the folder containing ``file_data + ext_data``. The
    image is reoriented to RPI into ``tmp.segmentation_rpi<ext>``; for each
    slice between the first and last non-empty one, the mean in-plane index
    of the non-zero voxels is the raw centerline point, which is then
    smoothed with ``b_spline_centerline``. A slice with no voxel inside that
    range yields NaN (mean of an empty array).

    Returns a dict with keys:
        orientation    orientation of the input image (from get_orientation)
        nz             number of slices of the RPI volume
        data           RPI volume with every segmentation voxel set to 0
        hdr            header of the RPI volume
        z_min, z_max   first / last non-empty slice
        x_fit, y_fit   fitted coordinates, indexed by ``z - z_min``
                       (presumably one per slice; b_spline_centerline is
                       defined elsewhere)
        data_original  copy of the RPI volume before zeroing, or None
                       unless ``keep_original`` is true

    Raises ValueError if the segmentation has no non-zero voxel.
    """
    print('\nOrient segmentation image to RPI orientation...')
    fname_segmentation_orient = 'tmp.segmentation_rpi' + ext_data
    set_orientation(file_data + ext_data, 'RPI', fname_segmentation_orient)

    orientation = get_orientation(file_data + ext_data)
    print('\nOrientation of segmentation image: ' + orientation)

    print('\nGet dimensions data...')
    nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(
        fname_segmentation_orient)
    print('.. ' + str(nx) + ' x ' + str(ny) + ' y ' + str(nz) + ' z ' +
          str(nt))

    print('\nOpen segmentation volume...')
    img = nibabel.load(fname_segmentation_orient)
    data = img.get_data()
    hdr = img.get_header()
    data_original = data.copy() if keep_original else None

    X, Y, Z = (data > 0).nonzero()
    if Z.size == 0:
        raise ValueError('Segmentation volume contains no non-zero voxel.')
    z_min, z_max = min(Z), max(Z)
    z_centerline = [iz for iz in range(z_min, z_max + 1)]
    x_centerline = []
    y_centerline = []
    for iz in z_centerline:
        x_seg, y_seg = (data[:, :, iz] > 0).nonzero()
        x_centerline.append(np.mean(x_seg))
        y_centerline.append(np.mean(y_seg))

    # Clear the segmentation so the volume can be reused to hold only the
    # centerline voxels.
    data[X, Y, Z] = 0

    x_fit, y_fit, _x_deriv, _y_deriv, _z_deriv = b_spline_centerline(
        x_centerline, y_centerline, z_centerline)

    return {
        'orientation': orientation,
        'nz': nz,
        'data': data,
        'hdr': hdr,
        'z_min': z_min,
        'z_max': z_max,
        'x_fit': x_fit,
        'y_fit': y_fit,
        'data_original': data_original,
    }


def _write_segmentation_centerline_image(seg, file_name):
    """Burn the fitted centerline of ``seg`` into its volume and save it.

    Writes ``tmp.centerline.nii`` in the cwd, then hands it to
    ``sct.generate_output_file`` under ``file_name`` (still RPI-oriented).
    """
    data = seg['data']
    for k, iz in enumerate(range(seg['z_min'], seg['z_max'] + 1)):
        # int() because numpy rejects float indices.
        data[int(round(seg['x_fit'][k])), int(round(seg['y_fit'][k])), iz] = 1
    seg['hdr'].set_data_dtype('uint8')
    print('\nWrite NIFTI volumes...')
    nibabel.save(nibabel.Nifti1Image(data, None, seg['hdr']),
                 'tmp.centerline.nii')
    sct.generate_output_file('tmp.centerline.nii', file_name)


def _write_centerline_txt(file_name, z_min, z_max, x_fit, y_fit):
    """Write one ``"z x y"`` line per slice z in [z_min, z_max].

    ``x_fit`` / ``y_fit`` are indexed by ``z - z_min``.
    """
    with open(file_name, 'w') as file_results:
        for iz in range(z_min, z_max + 1):
            file_results.write(
                str(int(iz)) + ' ' + str(x_fit[iz - z_min]) + ' ' +
                str(y_fit[iz - z_min]) + '\n')


def _concatenate_text_files(input_names, output_name):
    """Write the contents of ``input_names``, in order, into ``output_name``."""
    # Read everything first so output_name may safely be one of the inputs.
    contents = []
    for name in input_names:
        with open(name, 'r') as f_in:
            contents.append(f_in.read())
    with open(output_name, 'w') as f_out:
        f_out.write(''.join(contents))


def _plot_combined_centerline(data_seg_orig, data_label_orig, nz, file_name,
                              x_seg_fit, y_seg_fit, x_label_fit, y_label_fit):
    """Debug plot of the combined label + segmentation centerline.

    Compares the per-slice centre of mass of the raw (segmentation + label)
    volumes with the concatenated fitted coordinates. ``file_name`` (the
    concatenated centerline image) is loaded from the cwd.
    NOTE(review): the two volumes are added voxel-wise, so they must have the
    same shape; the segmentation was RPI-reoriented, the label volume was
    not -- confirm this is intended.
    """
    import matplotlib.pyplot as plt
    from scipy import ndimage

    data_concatenate = data_seg_orig + data_label_orig
    z_centerline = [
        iz for iz in range(nz) if data_concatenate[:, :, iz].any()
    ]
    x_centerline = []
    y_centerline = []
    for iz in z_centerline:
        x_cm, y_cm = ndimage.center_of_mass(data_concatenate[:, :, iz])
        x_centerline.append(x_cm)
        y_centerline.append(y_cm)

    data_centerline_fit = nibabel.load(file_name).get_data()
    z_centerline_fit = [
        iz for iz in range(nz) if data_centerline_fit[:, :, iz].any()
    ]
    n_fit = len(z_centerline_fit)

    # Float arrays so sub-voxel fitted coordinates are not truncated.
    x_fit_total = np.zeros(n_fit)
    y_fit_total = np.zeros(n_fit)
    n_seg = len(x_seg_fit)
    # Number of label-fit points already covered by the segmentation fit.
    length_overlap = len(x_label_fit) + n_seg - n_fit
    for i in range(n_seg):
        x_fit_total[i] = x_seg_fit[i]
        y_fit_total[i] = y_seg_fit[i]
    for i in range(len(x_label_fit) - length_overlap):
        x_fit_total[n_seg + i] = x_label_fit[i + length_overlap]
        y_fit_total[n_seg + i] = y_label_fit[i + length_overlap]

    # Raw centre-of-mass points placed at their slice offset.
    x_display = [0] * n_fit
    y_display = [0] * n_fit
    for iz, x_cm, y_cm in zip(z_centerline, x_centerline, y_centerline):
        x_display[iz - z_centerline[0]] = x_cm
        y_display[iz - z_centerline[0]] = y_cm

    plt.figure(1)
    plt.subplot(2, 1, 1)
    plt.plot(z_centerline_fit, x_display, 'ro')
    plt.plot(z_centerline_fit, x_fit_total)
    plt.xlabel("Z")
    plt.ylabel("X")
    plt.title("x and x_fit coordinates")

    plt.subplot(2, 1, 2)
    plt.plot(z_centerline_fit, y_display, 'ro')
    plt.plot(z_centerline_fit, y_fit_total)
    plt.xlabel("Z")
    plt.ylabel("Y")
    plt.title("y and y_fit coordinates")
    plt.show()


def _binary_centerline_from_segmentation(segmentation_file, output_file_name,
                                         remove_temp_files):
    """Binary mode, "-i" only: centerline image from a segmentation.

    The result is reoriented back to the input orientation and written to
    the cwd. Returns the output file name.
    """
    segmentation_file = os.path.abspath(segmentation_file)
    path_data, file_data, ext_data = sct.extract_fname(segmentation_file)
    file_name = _output_name(output_file_name,
                             file_data + '_centerline' + ext_data)

    path_tmp = _create_tmp_folder()
    sct.run('cp ' + segmentation_file + ' ' + path_tmp)

    parent_dir = os.getcwd()
    os.chdir(path_tmp)
    try:
        seg = _segmentation_centerline_rpi(file_data, ext_data)
        _write_segmentation_centerline_image(seg, file_name)
    finally:
        os.chdir(parent_dir)

    print('\nOrient centerline image to input orientation: ' +
          seg['orientation'])
    set_orientation(path_tmp + '/' + file_name, seg['orientation'], file_name)

    _remove_tmp_folder(path_tmp, remove_temp_files)
    return file_name


def _binary_centerline_from_label(label_file, output_file_name, verbose):
    """Binary mode, "-l" only: centerline image interpolated from labels.

    Written to the cwd in the label image's own orientation. Returns the
    output file name.
    """
    path_data, file_data, ext_data = sct.extract_fname(
        os.path.abspath(label_file))
    img = nibabel.load(label_file)
    data = img.get_data()
    hdr = img.get_header()

    X_fit, Y_fit, Z_new = _fit_label_centerline(data, verbose)

    print('\nSave volume ...')
    file_name = _output_name(output_file_name,
                             file_data + '_centerline' + ext_data)
    _save_label_centerline(data, hdr, X_fit, Y_fit, Z_new, file_name)
    print('\nFile created : ' + file_name)
    return file_name


def _binary_centerline_from_label_and_segmentation(
        segmentation_file, label_file, output_file_name, remove_temp_files,
        verbose):
    """Binary mode, "-i" and "-l": merge label and segmentation centerlines.

    Overlap between the two centerline images is removed with
    ``remove_overlap`` and the result is summed with ``fslmaths -add``. The
    output is copied to the cwd. Returns the output file name.
    """
    path_tmp = _create_tmp_folder()

    # Part 1: label centerline, written straight into the temporary folder
    # so no file in the user's folder is created or overwritten.
    print('\nPROCESS PART 1: From label file create centerline image.')
    path_data_label, file_data_label, ext_data_label = sct.extract_fname(
        os.path.abspath(label_file))
    img_label = nibabel.load(label_file)
    sct.run('cp ' + label_file + ' ' + path_tmp)
    data_label = img_label.get_data()
    hdr_label = img_label.get_header()

    X_fit, Y_fit, Z_new = _fit_label_centerline(data_label, verbose)

    print('\nSave volume ...')
    file_name_label = file_data_label + '_centerline' + ext_data_label
    _save_label_centerline(data_label, hdr_label, X_fit, Y_fit, Z_new,
                           os.path.join(path_tmp, file_name_label))
    print('\nFile created : ' + file_name_label)

    # Part 2: segmentation centerline, inside the temporary folder.
    print('\nPROCESS PART 2: From segmentation file create centerline image.')
    segmentation_file = os.path.abspath(segmentation_file)
    path_data_seg, file_data_seg, ext_data_seg = sct.extract_fname(
        segmentation_file)
    sct.run('cp ' + segmentation_file + ' ' + path_tmp)

    parent_dir = os.getcwd()
    os.chdir(path_tmp)
    try:
        seg = _segmentation_centerline_rpi(file_data_seg, ext_data_seg,
                                           keep_original=(verbose == 1))
        file_name_seg = file_data_seg + '_centerline' + ext_data_seg
        _write_segmentation_centerline_image(seg, file_name_seg)

        print('\nOrient centerline image to input orientation: ' +
              seg['orientation'])
        set_orientation(file_name_seg, seg['orientation'], file_name_seg)

        print('\nRemoving overlap of the centerline obtain with label file if there are any:')
        remove_overlap(file_name_label, file_name_seg,
                       "generated_centerline_without_overlap.nii.gz")

        print('\nConcatenation of the two centerline files:')
        # fslmaths appends its own extension to an extension-less output
        # name, which made the later cp/load miss the file; '.nii.gz'
        # assumes the default FSLOUTPUTTYPE=NIFTI_GZ -- confirm.
        file_name = _output_name(output_file_name,
                                 'centerline_total_from_label_and_seg.nii.gz')
        sct.run('fslmaths generated_centerline_without_overlap.nii.gz -add ' +
                file_name_seg + ' ' + file_name)

        if verbose == 1:
            # data_label is untouched by the fit, so it is the raw labels.
            _plot_combined_centerline(seg['data_original'], data_label,
                                      seg['nz'], file_name, seg['x_fit'],
                                      seg['y_fit'], X_fit, Y_fit)

        # Copy result into parent folder.
        sct.run('cp ' + file_name + ' ../')
    finally:
        os.chdir(parent_dir)

    _remove_tmp_folder(path_tmp, remove_temp_files)
    return file_name


def _text_centerline_from_segmentation(segmentation_file, output_file_name,
                                       remove_temp_files, verbose):
    """Text mode, "-i" only: write the fitted segmentation centerline.

    Coordinates are those of the RPI-reoriented image (no reorientation
    back). The file is copied to the cwd. Returns the output file name.
    """
    segmentation_file = os.path.abspath(segmentation_file)
    path_data, file_data, ext_data = sct.extract_fname(segmentation_file)
    file_name = _output_name(output_file_name,
                             file_data + '_centerline' + '.txt')

    path_tmp = _create_tmp_folder()
    sct.run('cp ' + segmentation_file + ' ' + path_tmp)

    parent_dir = os.getcwd()
    os.chdir(path_tmp)
    try:
        seg = _segmentation_centerline_rpi(file_data, ext_data)
        sct.printv('\nWrite text file...', verbose)
        _write_centerline_txt(file_name, seg['z_min'], seg['z_max'],
                              seg['x_fit'], seg['y_fit'])
        # Copy result into parent folder.
        sct.run('cp ' + file_name + ' ../')
    finally:
        os.chdir(parent_dir)

    _remove_tmp_folder(path_tmp, remove_temp_files)
    return file_name


def _text_centerline_from_label(label_file, output_file_name, verbose):
    """Text mode, "-l" only: write the spline centerline through the labels.

    Written to the cwd; coordinates are in the label image's own voxel space.
    Returns the output file name.
    """
    path_data, file_data, ext_data = sct.extract_fname(
        os.path.abspath(label_file))
    data = nibabel.load(label_file).get_data()

    X_fit, Y_fit, Z_new = _fit_label_centerline(data, verbose)

    sct.printv('\nWrite text file...', verbose)
    # Text output: '.txt', not the image extension of the label file.
    file_name = _output_name(output_file_name,
                             file_data + '_centerline' + '.txt')
    _write_centerline_txt(file_name, int(Z_new[0]), int(Z_new[-1]), X_fit,
                          Y_fit)
    return file_name


def _text_centerline_from_label_and_segmentation(
        segmentation_file, label_file, output_file_name, remove_temp_files,
        verbose):
    """Text mode, "-i" and "-l": merge label and segmentation centerlines.

    The segmentation lines come first, followed by the label lines left after
    ``remove_overlap``. The result is copied to the cwd. Returns the output
    file name.
    """
    path_tmp = _create_tmp_folder()

    print('\nPROCESS PART 1: From label file create centerline text file.')
    path_data_label, file_data_label, ext_data_label = sct.extract_fname(
        os.path.abspath(label_file))
    data_label = nibabel.load(label_file).get_data()
    sct.run('cp ' + label_file + ' ' + path_tmp)

    X_fit, Y_fit, Z_new = _fit_label_centerline(data_label, verbose)

    sct.printv('\nWrite text file...', verbose)
    file_name_label = file_data_label + '_centerline' + '.txt'
    _write_centerline_txt(os.path.join(path_tmp, file_name_label),
                          int(Z_new[0]), int(Z_new[-1]), X_fit, Y_fit)

    print('\nPROCESS PART 2: From segmentation file create centerline image.')
    segmentation_file = os.path.abspath(segmentation_file)
    path_data_seg, file_data_seg, ext_data_seg = sct.extract_fname(
        segmentation_file)
    sct.run('cp ' + segmentation_file + ' ' + path_tmp)

    parent_dir = os.getcwd()
    os.chdir(path_tmp)
    try:
        seg = _segmentation_centerline_rpi(file_data_seg, ext_data_seg)
        file_name_seg = file_data_seg + '_centerline' + '.txt'
        sct.printv('\nWrite text file...', verbose)
        _write_centerline_txt(file_name_seg, seg['z_min'], seg['z_max'],
                              seg['x_fit'], seg['y_fit'])

        print('\nRemoving overlap of the centerline obtain with label file if there are any:')
        remove_overlap(file_name_label,
                       file_name_seg,
                       "generated_centerline_without_overlap1.txt",
                       parameter=1)

        print('\nConcatenation of the two centerline files:')
        file_name = _output_name(output_file_name,
                                 'centerline_total_from_label_and_seg.txt')
        _concatenate_text_files(
            [file_name_seg, "generated_centerline_without_overlap1.txt"],
            file_name)

        # Copy result into parent folder.
        sct.run('cp ' + file_name + ' ../')
    finally:
        os.chdir(parent_dir)

    _remove_tmp_folder(path_tmp, remove_temp_files)
    return file_name
def main():
    """Straighten an image along a spinal-cord centerline (command-line entry point).

    Parses the command line (or uses hard-coded paths when ``param.debug == 1``), then:
      1. copies the input image and centerline into a timestamped temporary folder
         and reorients the centerline to RPI;
      2. smooths the centerline and, every ``gapz`` points, builds a 5-point "cross"
         of landmarks lying in the plane orthogonal to the centerline, at distance
         ``gapxy`` from it (centre, +x, -x, +y, -y);
      3. builds the matching crosses on a vertical straight line at the centre of
         the FOV, spaced by the rounded Euclidean distance between consecutive
         curved crosses;
      4. writes both landmark sets as NIFTI label volumes, estimates rigid +
         b-spline transforms between them (both directions) with the external
         isct_/sct_ tools, and warps the input image curve -> straight;
      5. warps the centerline the same way and reports the maximum and the mean
         squared in-plane distance between the warped centerline and the central
         vertical line of the output grid;
      6. writes warp_curve2straight.nii.gz, warp_straight2curve.nii.gz and
         <input>_straight<ext> in the current folder.

    Options: -i input image and -c centerline (both mandatory), -p padding (voxels),
    -r remove temporary files, -x final interpolation, -a fitting algorithm
    ('hanning' or 'nurbs'), -f crop, -v verbose, -h help (calls usage()).
    """
    # Initialization
    fname_anat = ''
    fname_centerline = ''
    gapxy = param.gapxy
    gapz = param.gapz
    padding = param.padding
    remove_temp_files = param.remove_temp_files
    verbose = param.verbose
    interpolation_warp = param.interpolation_warp
    algo_fitting = param.algo_fitting
    window_length = param.window_length
    type_window = param.type_window
    crop = param.crop

    # start timer
    start_time = time.time()

    # get path of the toolbox
    status, path_sct = commands.getstatusoutput('echo $SCT_DIR')
    print(path_sct)

    # Parameters for debug mode
    if param.debug == 1:
        print('\n*** WARNING: DEBUG MODE ON ***\n')
        fname_anat = '/Users/julien/data/temp/sct_example_data/t2/tmp.150401221259/anat_rpi.nii'  #path_sct+'/testing/sct_testing_data/data/t2/t2.nii.gz'
        fname_centerline = '/Users/julien/data/temp/sct_example_data/t2/tmp.150401221259/centerline_rpi.nii'  # path_sct+'/testing/sct_testing_data/data/t2/t2_seg.nii.gz'
        remove_temp_files = 0
        type_window = 'hanning'
        verbose = 2
    else:
        # Check input param
        try:
            opts, args = getopt.getopt(sys.argv[1:], 'hi:c:p:r:v:x:a:f:')
        except getopt.GetoptError as err:
            print(str(err))
            usage()
        if not opts:
            usage()
        # Compare with '==': "opt in ('-i')" is a substring test on a plain
        # string (the parentheses do not make a tuple).
        for opt, arg in opts:
            if opt == '-h':
                usage()
            elif opt == '-i':
                fname_anat = arg
            elif opt == '-c':
                fname_centerline = arg
            elif opt == '-r':
                remove_temp_files = int(arg)
            elif opt == '-p':
                padding = int(arg)
            elif opt == '-x':
                interpolation_warp = str(arg)
            elif opt == '-a':
                algo_fitting = str(arg)
            elif opt == '-f':
                crop = int(arg)
            elif opt == '-v':
                verbose = int(arg)

    # display usage if a mandatory argument is not provided
    if fname_anat == '' or fname_centerline == '':
        usage()

    # check if algorithm for fitting is correct
    if algo_fitting not in ['hanning', 'nurbs']:
        sct.printv('ERROR: wrong fitting algorithm', 1, 'warning')
        usage()

    # update field
    param.verbose = verbose

    # check existence of input files
    sct.check_file_exist(fname_anat)
    sct.check_file_exist(fname_centerline)

    # Display arguments
    print('\nCheck input arguments...')
    print('  Input volume ...................... '+fname_anat)
    print('  Centerline ........................ '+fname_centerline)
    print('  Final interpolation ............... '+interpolation_warp)
    print('  Verbose ........................... '+str(verbose))
    print('')

    # Extract path/file/extension
    path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)
    path_centerline, file_centerline, ext_centerline = sct.extract_fname(fname_centerline)

    # create temporary folder
    path_tmp = 'tmp.'+time.strftime("%y%m%d%H%M%S")
    sct.run('mkdir '+path_tmp, verbose)

    # copy files into tmp folder
    sct.run('cp '+fname_anat+' '+path_tmp)
    sct.run('cp '+fname_centerline+' '+path_tmp)

    # go to tmp folder
    os.chdir(path_tmp)

    # Change orientation of the input centerline into RPI
    sct.printv('\nOrient centerline to RPI orientation...', verbose)
    fname_centerline_orient = file_centerline+'_rpi.nii.gz'
    set_orientation(fname_centerline, 'RPI', fname_centerline_orient)

    # Get dimension
    sct.printv('\nGet dimensions...', verbose)
    nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(fname_centerline_orient)
    sct.printv('.. matrix size: '+str(nx)+' x '+str(ny)+' x '+str(nz), verbose)
    sct.printv('.. voxel size:  '+str(px)+'mm x '+str(py)+'mm x '+str(pz)+'mm', verbose)

    # smooth centerline
    x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(fname_centerline_orient, algo_fitting=algo_fitting, type_window=type_window, window_length=window_length, verbose=verbose)

    # Get coordinates of landmarks along curved centerline
    #==========================================================================================
    sct.printv('\nGet coordinates of landmarks along curved centerline...', verbose)
    # landmarks are created along the curved centerline every z=gapz. They consist of a "cross" of size gapx and gapy. In voxel space!!!

    # find indices along the centerline given a specific gap: iz_curved
    nz_nonz = len(z_centerline)
    nb_landmark = int(round(float(nz_nonz)/gapz))
    if nb_landmark == 0:
        nb_landmark = 1

    if nb_landmark == 1:
        iz_curved = [0]
    else:
        iz_curved = [i*gapz for i in range(0, nb_landmark-1)]
    # always put the last cross on the last centerline point
    iz_curved.append(nz_nonz-1)
    n_iz_curved = len(iz_curved)

    # landmark_curved[cross][element][coordinate]; elements: 0 centre, 1 +x, 2 -x, 3 +y, 4 -y
    landmark_curved = [[[0, 0, 0] for _ in range(5)] for _ in iz_curved]

    ### TODO: THIS PART IS SLOW AND CAN BE MADE FASTER
    ### >>==============================================================================================================
    for index in range(0, n_iz_curved, 1):
        iz = iz_curved[index]
        # plane orthogonal to the centerline at iz: a*x + b*y + c*z + d = 0,
        # where (a, b, c) is the centerline derivative at that point
        a = x_centerline_deriv[iz]
        b = y_centerline_deriv[iz]
        c = z_centerline_deriv[iz]
        x = x_centerline_fit[iz]
        y = y_centerline_fit[iz]
        z = z_centerline[iz]
        d = -(a*x + b*y + c*z)
        # On that plane z = (-1/c)*(a*x + b*y + d). Float numerator so the factor is
        # never truncated by Python 2 integer division.
        # NOTE(review): c == 0 (tangent lying in the axial plane) is not handled.
        inv_c = -1.0/c
        # set coordinates for landmark at the center of the cross
        landmark_curved[index][0][0], landmark_curved[index][0][1], landmark_curved[index][0][2] = x, y, z

        # elements 1 (+x) and 2 (-x) keep the centerline y coordinate
        for i in range(1, 3):
            landmark_curved[index][i][1] = y

        # x and z for +x/-x: on the orthogonal plane, at distance gapxy from the centerline point.
        # The two roots are assigned to -x then +x (presumes solve returns them in increasing order).
        x_n = Symbol('x_n')
        landmark_curved[index][2][0], landmark_curved[index][1][0] = solve((x_n-x)**2+(inv_c*(a*x_n+b*y+d)-z)**2-gapxy**2, x_n)  # x for -x and +x
        landmark_curved[index][1][2] = inv_c*(a*landmark_curved[index][1][0]+b*y+d)  # z for +x
        landmark_curved[index][2][2] = inv_c*(a*landmark_curved[index][2][0]+b*y+d)  # z for -x

        # elements 3 (+y) and 4 (-y) keep the centerline x coordinate
        for i in range(3, 5):
            landmark_curved[index][i][0] = x

        # y and z for +y/-y, same construction as above with x fixed
        y_n = Symbol('y_n')
        landmark_curved[index][4][1], landmark_curved[index][3][1] = solve((y_n-y)**2+(inv_c*(a*x+b*y_n+d)-z)**2-gapxy**2, y_n)  # y for -y and +y
        landmark_curved[index][3][2] = inv_c*(a*x+b*landmark_curved[index][3][1]+d)  # z for +y
        landmark_curved[index][4][2] = inv_c*(a*x+b*landmark_curved[index][4][1]+d)  # z for -y
    ### <<==============================================================================================================

    if verbose == 2:
        from mpl_toolkits.mplot3d import Axes3D
        import matplotlib.pyplot as plt
        fig = plt.figure()
        ax = Axes3D(fig)
        ax.plot(x_centerline_fit, y_centerline_fit, z_centerline, zdir='z')
        ax.plot([landmark_curved[i][j][0] for i in range(0, n_iz_curved) for j in range(0, 5)],
                [landmark_curved[i][j][1] for i in range(0, n_iz_curved) for j in range(0, 5)],
                [landmark_curved[i][j][2] for i in range(0, n_iz_curved) for j in range(0, 5)], '.')
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel('z')
        plt.show()

    # Get coordinates of landmarks along straight centerline
    #==========================================================================================
    sct.printv('\nGet coordinates of landmarks along straight centerline...', verbose)
    landmark_straight = [[[0, 0, 0] for _ in range(5)] for _ in iz_curved]  # same structure as landmark_curved

    # calculate the z indices corresponding to the Euclidean distance between two consecutive points on the curved centerline (approximation curve --> line)
    # TODO: DO NOT APPROXIMATE CURVE --> LINE
    # one straight z index per curved cross
    iz_straight = [0] * n_iz_curved
    iz_straight[0] = iz_curved[0]
    for index in range(1, n_iz_curved, 1):
        # compute vector between two consecutive points on the curved centerline
        vector_centerline = [x_centerline_fit[iz_curved[index]] - x_centerline_fit[iz_curved[index-1]],
                             y_centerline_fit[iz_curved[index]] - y_centerline_fit[iz_curved[index-1]],
                             z_centerline[iz_curved[index]] - z_centerline[iz_curved[index-1]]]
        # compute norm of this vector
        norm_vector_centerline = linalg.norm(vector_centerline, ord=2)
        # round to closest integer value
        norm_vector_centerline_rounded = int(round(norm_vector_centerline, 0))
        # assign this value to the current z-coordinate on the straight centerline
        iz_straight[index] = iz_straight[index-1] + norm_vector_centerline_rounded

    # initialize x0 and y0 to be at the center of the FOV
    x0 = int(round(nx/2))
    y0 = int(round(ny/2))
    for index in range(0, n_iz_curved, 1):
        # set coordinates for landmark at the center of the cross
        landmark_straight[index][0][0], landmark_straight[index][0][1], landmark_straight[index][0][2] = x0, y0, iz_straight[index]
        # set x, y and z coordinates for landmarks +x
        landmark_straight[index][1][0], landmark_straight[index][1][1], landmark_straight[index][1][2] = x0 + gapxy, y0, iz_straight[index]
        # set x, y and z coordinates for landmarks -x
        landmark_straight[index][2][0], landmark_straight[index][2][1], landmark_straight[index][2][2] = x0-gapxy, y0, iz_straight[index]
        # set x, y and z coordinates for landmarks +y
        landmark_straight[index][3][0], landmark_straight[index][3][1], landmark_straight[index][3][2] = x0, y0+gapxy, iz_straight[index]
        # set x, y and z coordinates for landmarks -y
        landmark_straight[index][4][0], landmark_straight[index][4][1], landmark_straight[index][4][2] = x0, y0-gapxy, iz_straight[index]

    # Create NIFTI volumes with landmarks
    #==========================================================================================
    # Pad input volume to deal with the fact that some landmarks on the curved centerline might be outside the FOV
    # N.B. IT IS VERY IMPORTANT TO PAD ALSO ALONG X and Y, OTHERWISE SOME LANDMARKS MIGHT GET OUT OF THE FOV!!!
    sct.printv('\nPad input volume to account for landmarks that fall outside the FOV...', verbose)
    sct.run('isct_c3d '+fname_centerline_orient+' -pad '+str(padding)+'x'+str(padding)+'x'+str(padding)+'vox '+str(padding)+'x'+str(padding)+'x'+str(padding)+'vox 0 -o tmp.centerline_pad.nii.gz')

    # Open padded centerline for reading
    sct.printv('\nOpen padded centerline for reading...', verbose)
    nii_centerline_pad = load('tmp.centerline_pad.nii.gz')
    data = nii_centerline_pad.get_data()
    hdr = nii_centerline_pad.get_header()

    # Create volumes containing curved and straight landmarks
    data_curved_landmarks = data * 0
    data_straight_landmarks = data * 0
    # initialize landmark value (each landmark gets its own label, shared by its curved and straight copies)
    landmark_value = 1
    # Loop across cross index
    for index in range(0, n_iz_curved, 1):
        # loop across cross element index
        for i_element in range(0, 5, 1):
            # get x, y and z coordinates of curved landmark (rounded to closest integer)
            x, y, z = int(round(landmark_curved[index][i_element][0])), int(round(landmark_curved[index][i_element][1])), int(round(landmark_curved[index][i_element][2]))
            # attribute landmark_value to the voxel and its neighbours (3x3x3 block, shifted by the padding)
            data_curved_landmarks[x+padding-1:x+padding+2, y+padding-1:y+padding+2, z+padding-1:z+padding+2] = landmark_value
            # get x, y and z coordinates of straight landmark (rounded to closest integer)
            x, y, z = int(round(landmark_straight[index][i_element][0])), int(round(landmark_straight[index][i_element][1])), int(round(landmark_straight[index][i_element][2]))
            # attribute landmark_value to the voxel and its neighbours
            data_straight_landmarks[x+padding-1:x+padding+2, y+padding-1:y+padding+2, z+padding-1:z+padding+2] = landmark_value
            # increment landmark value
            landmark_value = landmark_value + 1

    # Write NIFTI volumes
    sct.printv('\nWrite NIFTI volumes...', verbose)
    hdr.set_data_dtype('uint32')  # uint32 so that one distinct label per landmark fits
    img = Nifti1Image(data_curved_landmarks, None, hdr)
    save(img, 'tmp.landmarks_curved.nii.gz')
    sct.printv('.. File created: tmp.landmarks_curved.nii.gz', verbose)
    img = Nifti1Image(data_straight_landmarks, None, hdr)
    save(img, 'tmp.landmarks_straight.nii.gz')
    sct.printv('.. File created: tmp.landmarks_straight.nii.gz', verbose)

    # Estimate deformation field by pairing landmarks
    #==========================================================================================

    # This stands to avoid overlapping between landmarks
    sct.printv('\nMake sure all labels between landmark_straight and landmark_curved match...', verbose)
    sct.run('sct_label_utils -t remove -i tmp.landmarks_straight.nii.gz -o tmp.landmarks_straight.nii.gz -r tmp.landmarks_curved.nii.gz', verbose)

    # convert landmarks to INT
    sct.printv('\nConvert landmarks to INT...', verbose)
    sct.run('isct_c3d tmp.landmarks_straight.nii.gz -type int -o tmp.landmarks_straight.nii.gz', verbose)
    sct.run('isct_c3d tmp.landmarks_curved.nii.gz -type int -o tmp.landmarks_curved.nii.gz', verbose)

    # Estimate rigid transformation
    sct.printv('\nEstimate rigid transformation between paired landmarks...', verbose)
    sct.run('isct_ANTSUseLandmarkImagesToGetAffineTransform tmp.landmarks_straight.nii.gz tmp.landmarks_curved.nii.gz rigid tmp.curve2straight_rigid.txt', verbose)

    # Apply rigid transformation
    sct.printv('\nApply rigid transformation to curved landmarks...', verbose)
    sct.run('sct_apply_transfo -i tmp.landmarks_curved.nii.gz -o tmp.landmarks_curved_rigid.nii.gz -d tmp.landmarks_straight.nii.gz -w tmp.curve2straight_rigid.txt -x nn', verbose)

    # Estimate b-spline transformation curve --> straight
    sct.printv('\nEstimate b-spline transformation: curve --> straight...', verbose)
    sct.run('isct_ANTSUseLandmarkImagesToGetBSplineDisplacementField tmp.landmarks_straight.nii.gz tmp.landmarks_curved_rigid.nii.gz tmp.warp_curve2straight.nii.gz 5x5x10 3 2 0', verbose)

    # remove padding for straight labels
    if crop == 1:
        sct.run('sct_crop_image -i tmp.landmarks_straight.nii.gz -o tmp.landmarks_straight_crop.nii.gz -dim 0 -bzmax', verbose)
        sct.run('sct_crop_image -i tmp.landmarks_straight_crop.nii.gz -o tmp.landmarks_straight_crop.nii.gz -dim 1 -bzmax', verbose)
        sct.run('sct_crop_image -i tmp.landmarks_straight_crop.nii.gz -o tmp.landmarks_straight_crop.nii.gz -dim 2 -bzmax', verbose)
    else:
        sct.run('cp tmp.landmarks_straight.nii.gz tmp.landmarks_straight_crop.nii.gz', verbose)

    # Concatenate rigid and non-linear transformations...
    sct.printv('\nConcatenate rigid and non-linear transformations...', verbose)
    # !!! DO NOT USE sct.run HERE BECAUSE isct_ComposeMultiTransform OUTPUTS A NON-NULL STATUS !!!
    cmd = 'isct_ComposeMultiTransform 3 tmp.curve2straight.nii.gz -R tmp.landmarks_straight_crop.nii.gz tmp.warp_curve2straight.nii.gz tmp.curve2straight_rigid.txt'
    sct.printv(cmd, verbose, 'code')
    commands.getstatusoutput(cmd)

    # Estimate b-spline transformation straight --> curve
    # TODO: invert warping field instead of estimating a new one
    sct.printv('\nEstimate b-spline transformation: straight --> curve...', verbose)
    sct.run('isct_ANTSUseLandmarkImagesToGetBSplineDisplacementField tmp.landmarks_curved_rigid.nii.gz tmp.landmarks_straight.nii.gz tmp.warp_straight2curve.nii.gz 5x5x10 3 2 0', verbose)

    # Concatenate rigid and non-linear transformations...
    sct.printv('\nConcatenate rigid and non-linear transformations...', verbose)
    cmd = 'isct_ComposeMultiTransform 3 tmp.straight2curve.nii.gz -R '+file_anat+ext_anat+' -i tmp.curve2straight_rigid.txt tmp.warp_straight2curve.nii.gz'
    sct.printv(cmd, verbose, 'code')
    commands.getstatusoutput(cmd)

    # Apply transformation to input image
    sct.printv('\nApply transformation to input image...', verbose)
    sct.run('sct_apply_transfo -i '+file_anat+ext_anat+' -o tmp.anat_rigid_warp.nii.gz -d tmp.landmarks_straight_crop.nii.gz -x '+interpolation_warp+' -w tmp.curve2straight.nii.gz', verbose)

    # compute the error between the straightened centerline/segmentation and the central vertical line.
    # Ideally, the error should be zero.
    # Apply the same deformation to the centerline
    print('\nApply transformation to input image...')
    sct.run('sct_apply_transfo -i '+fname_centerline_orient+' -o tmp.centerline_straight.nii.gz -d tmp.landmarks_straight_crop.nii.gz -x nn -w tmp.curve2straight.nii.gz')
    from msct_image import Image
    file_centerline_straight = Image('tmp.centerline_straight.nii.gz')
    coordinates_centerline = file_centerline_straight.getNonZeroCoordinates(sorting='z')
    # Group the non-zero voxels by slice in one pass (instead of rescanning every voxel
    # for each slice), then average (x*value, y*value) per slice. Iterating over every
    # occupied slice also includes the last one, which range(first_z, last_z) skipped.
    # (With nn interpolation of a binary centerline, value is presumably 1.)
    coords_by_z = {}
    for coord in coordinates_centerline:
        coords_by_z.setdefault(coord.z, []).append([coord.x*coord.value, coord.y*coord.value])
    mean_coord = [mean(coords_by_z[z_slice], axis=0) for z_slice in sorted(coords_by_z)]

    # compute error between the straightened centerline and the central vertical line
    from math import sqrt
    mse_curve = 0.0
    max_dist = 0.0
    x0 = int(round(file_centerline_straight.data.shape[0]/2.0))
    y0 = int(round(file_centerline_straight.data.shape[1]/2.0))
    count_mean = 0
    for coord_z in mean_coord:
        if not isnan(sum(coord_z)):
            dist = ((x0-coord_z[0])*px)**2 + ((y0-coord_z[1])*py)**2
            mse_curve += dist
            dist = sqrt(dist)
            if dist > max_dist:
                max_dist = dist
            count_mean += 1
    if count_mean > 0:
        mse_curve = mse_curve/float(count_mean)
    else:
        # empty straightened centerline: report NaN instead of crashing on a division by zero
        sct.printv('WARNING: straightened centerline is empty, cannot compute straightening accuracy.', 1, 'warning')
        mse_curve = float('nan')

    # come back to parent folder
    os.chdir('..')

    # Generate output file (in current folder)
    # TODO: do not uncompress the warping field, it is too time consuming!
    sct.printv('\nGenerate output file (in current folder)...', verbose)
    sct.generate_output_file(path_tmp+'/tmp.curve2straight.nii.gz', 'warp_curve2straight.nii.gz', verbose)  # warping field
    sct.generate_output_file(path_tmp+'/tmp.straight2curve.nii.gz', 'warp_straight2curve.nii.gz', verbose)  # warping field
    fname_straight = sct.generate_output_file(path_tmp+'/tmp.anat_rigid_warp.nii.gz', file_anat+'_straight'+ext_anat, verbose)  # straightened anatomic

    # Remove temporary files
    if remove_temp_files:
        sct.printv('\nRemove temporary files...', verbose)
        sct.run('rm -rf '+path_tmp, verbose)

    print('\nDone!\n')

    sct.printv('Maximum x-y error = '+str(round(max_dist, 2))+' mm', verbose, 'bold')
    # NOTE(review): mse_curve is a mean of squared mm distances (mm^2) although the message says mm
    sct.printv('Accuracy of straightening (MSE) = '+str(round(mse_curve, 2))+' mm', verbose, 'bold')
    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: '+str(int(round(elapsed_time)))+'s', verbose)
    sct.printv('\nTo view results, type:', verbose)
    sct.printv('fslview '+fname_straight+' &\n', verbose, 'info')
def main():
    # Initialization to defaults parameters
    fname_data = ''  # data is empty by default
    path_label = ''  # empty by default
    method = param.method # extraction mode by default
    labels_of_interest = param.labels_of_interest
    slices_of_interest = param.slices_of_interest
    vertebral_levels = param.vertebral_levels
    average_all_labels = param.average_all_labels
    fname_output = param.fname_output
    fname_vertebral_labeling = param.fname_vertebral_labeling
    fname_normalizing_label = ''  # optional then default is empty
    normalization_method = ''  # optional then default is empty
    actual_vert_levels = None  # variable used in case the vertebral levels asked by the user don't correspond exactly to the vertebral levels available in the metric data
    warning_vert_levels = None  # variable used to warn the user in case the vertebral levels he asked don't correspond exactly to the vertebral levels available in the metric data
    verbose = param.verbose
    flag_h = 0
    ml_clusters = param.ml_clusters
    adv_param = param.adv_param
    adv_param_user = ''

    # Parameters for debug mode
    if param.debug:
        print '\n*** WARNING: DEBUG MODE ON ***\n'
        status, path_sct_data = commands.getstatusoutput('echo $SCT_TESTING_DATA_DIR')
        fname_data = '/Users/julien/data/temp/sct_example_data/mt/mtr.nii.gz'
        path_label = '/Users/julien/data/temp/sct_example_data/mt/label/atlas/'
        method = 'map'
        ml_clusters = '0:29,30,31'
        labels_of_interest = '0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29'
        slices_of_interest = ''
        vertebral_levels = ''
        average_all_labels = 1
        fname_normalizing_label = ''  #path_sct+'/testing/data/errsm_23/mt/label/template/MNI-Poly-AMU_CSF.nii.gz'
        normalization_method = ''  #'whole'
    else:
        # Check input parameters
        try:
            opts, args = getopt.getopt(sys.argv[1:], 'haf:i:l:m:n:o:p:v:w:z:') # define flags
        except getopt.GetoptError as err: # check if the arguments are defined
            print str(err) # error
            usage() # display usage
        if not opts:
            usage()
        for opt, arg in opts: # explore flags
            if opt in '-a':
                average_all_labels = 1
            elif opt in '-f':
                path_label = os.path.abspath(arg)  # save path of labels folder
            elif opt == '-h':  # help option
                flag_h = 1
            elif opt in '-i':
                fname_data = arg
            elif opt in '-l':
                labels_of_interest = arg
            elif opt in '-m':  # method for metric extraction
                method = arg
            elif opt in '-n':  # filename of the label by which the user wants to normalize
                fname_normalizing_label = arg
            elif opt in '-o': # output option
                fname_output = arg  # fname of output file
            elif opt in '-p':
                adv_param_user = arg
            elif opt in '-v':
                # vertebral levels option, if the user wants to average the metric across specific vertebral levels
                 vertebral_levels = arg
            elif opt in '-w':
                # method used for the normalization by the metric estimation into the normalizing label (see flag -n): 'sbs' for slice-by-slice or 'whole' for normalization after estimation in the whole labels
                normalization_method = arg
            elif opt in '-z':  # slices numbers option
                slices_of_interest = arg # save labels numbers

    # Display usage with tract parameters by default in case files aren't chosen in arguments inputs
    if fname_data == '' or path_label == '' or flag_h:
        param.path_label = path_label
        usage()

    # Check existence of data file
    sct.printv('\ncheck existence of input files...', verbose)
    sct.check_file_exist(fname_data)
    sct.check_folder_exist(path_label)
    if fname_normalizing_label:
        sct.check_folder_exist(fname_normalizing_label)

    # add slash at the end
    path_label = sct.slash_at_the_end(path_label, 1)

    # Find path to the vertebral labeling file if vertebral levels were specified by the user
    if vertebral_levels:
        if slices_of_interest:  # impossible to select BOTH specific slices and specific vertebral levels
            print '\nERROR: You cannot select BOTH vertebral levels AND slice numbers.'
            usage()
        else:
            fname_vertebral_labeling_list = sct.find_file_within_folder(fname_vertebral_labeling, path_label + '..')
            if len(fname_vertebral_labeling_list) > 1:
                print color.red + 'ERROR: More than one file named \'' + fname_vertebral_labeling + ' were found in ' + path_label + '. Exit program.' + color.end
                sys.exit(2)
            elif len(fname_vertebral_labeling_list) == 0:
                print color.red + 'ERROR: No file named \'' + fname_vertebral_labeling + ' were found in ' + path_label + '. Exit program.' + color.end
                sys.exit(2)
            else:
                fname_vertebral_labeling = os.path.abspath(fname_vertebral_labeling_list[0])

    # Check input parameters
    check_method(method, fname_normalizing_label, normalization_method)

    # parse argument for param
    if not adv_param_user == '':
        adv_param = adv_param_user.replace(' ', '').split(',')  # remove spaces and parse with comma
        del adv_param_user  # clean variable
        # TODO: check integrity of input

    # Extract label info
    label_id, label_name, label_file = read_label_file(path_label, param.file_info_label)
    nb_labels_total = len(label_id)

    # check consistency of label input parameter.
    label_id_user, average_all_labels = check_labels(labels_of_interest, nb_labels_total, average_all_labels, method)  # If 'labels_of_interest' is empty, then
    # 'label_id_user' contains the index of all labels in the file info_label.txt

    # print parameters
    print '\nChecked parameters:'
    print '  data ...................... '+fname_data
    print '  folder label .............. '+path_label
    print '  selected labels ........... '+str(label_id_user)
    print '  estimation method ......... '+method
    print '  slices of interest ........ '+slices_of_interest
    print '  vertebral levels .......... '+vertebral_levels
    print '  vertebral labeling file.... '+fname_vertebral_labeling
    print '  advanced parameters ....... '+str(adv_param)

    # Check if the orientation of the data is RPI
    orientation_data = get_orientation(fname_data)

    # If orientation is not RPI, change to RPI
    if orientation_data != 'RPI':
        sct.printv('\nCreate temporary folder to change the orientation of the NIFTI files into RPI...', verbose)
        path_tmp = sct.slash_at_the_end('tmp.'+time.strftime("%y%m%d%H%M%S"), 1)
        sct.create_folder(path_tmp)
        # change orientation and load data
        sct.printv('\nChange image orientation and load it...', verbose)
        data = nib.load(set_orientation(fname_data, 'RPI', path_tmp+'orient_data.nii')).get_data()
        # Do the same for labels
        sct.printv('\nChange labels orientation and load them...', verbose)
        labels = np.empty([nb_labels_total], dtype=object)  # labels(nb_labels_total, x, y, z)
        for i_label in range(0, nb_labels_total):
            labels[i_label] = nib.load(set_orientation(path_label+label_file[i_label], 'RPI', path_tmp+'orient_'+label_file[i_label])).get_data()
        if fname_normalizing_label:  # if the "normalization" option is wanted,
            normalizing_label = np.empty([1], dtype=object)  # choose this kind of structure so as to keep easily the
            # compatibility with the rest of the code (dimensions: (1, x, y, z))
            normalizing_label[0] = nib.load(set_orientation(fname_normalizing_label, 'RPI', path_tmp+'orient_normalizing_volume.nii')).get_data()
        if vertebral_levels:  # if vertebral levels were selected,
            data_vertebral_labeling = nib.load(set_orientation(fname_vertebral_labeling, 'RPI', path_tmp+'orient_vertebral_labeling.nii.gz')).get_data()
        # Remove the temporary folder used to change the NIFTI files orientation into RPI
        sct.printv('\nRemove the temporary folder...', verbose)
        status, output = commands.getstatusoutput('rm -rf ' + path_tmp)
    else:
        # Load image
        sct.printv('\nLoad image...', verbose)
        data = nib.load(fname_data).get_data()

        # Load labels
        sct.printv('\nLoad labels...', verbose)
        labels = np.empty([nb_labels_total], dtype=object)  # labels(nb_labels_total, x, y, z)
        for i_label in range(0, nb_labels_total):
            labels[i_label] = nib.load(path_label+label_file[i_label]).get_data()
        if fname_normalizing_label:  # if the "normalization" option is wanted,
            normalizing_label = np.empty([1], dtype=object)  # choose this kind of structure so as to keep easily the
            # compatibility with the rest of the code (dimensions: (1, x, y, z))
            normalizing_label[0] = nib.load(fname_normalizing_label).get_data()  # load the data of the normalizing label
        if vertebral_levels:  # if vertebral levels were selected,
            data_vertebral_labeling = nib.load(fname_vertebral_labeling).get_data()

    # Change metric data type into floats for future manipulations (normalization)
    data = np.float64(data)

    # Get dimensions of data
    sct.printv('\nGet dimensions of data...', verbose)
    nx, ny, nz = data.shape
    sct.printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), verbose)

    # Get dimensions of labels
    sct.printv('\nGet dimensions of label...', verbose)
    nx_atlas, ny_atlas, nz_atlas = labels[0].shape
    sct.printv('.. '+str(nx_atlas)+' x '+str(ny_atlas)+' x '+str(nz_atlas)+' x '+str(nb_labels_total), verbose)

    # Check dimensions consistency between atlas and data
    if (nx, ny, nz) != (nx_atlas, ny_atlas, nz_atlas):
        print '\nERROR: Metric data and labels DO NOT HAVE SAME DIMENSIONS.'
        sys.exit(2)

    # Update the flag "slices_of_interest" according to the vertebral levels selected by user (if it's the case)
    if vertebral_levels:
        slices_of_interest, actual_vert_levels, warning_vert_levels = \
            get_slices_matching_with_vertebral_levels(data, vertebral_levels, data_vertebral_labeling)

    # select slice of interest by cropping data and labels
    if slices_of_interest:
        data = remove_slices(data, slices_of_interest)
        for i_label in range(0, nb_labels_total):
            labels[i_label] = remove_slices(labels[i_label], slices_of_interest)
        if fname_normalizing_label:  # if the "normalization" option was selected,
            normalizing_label[0] = remove_slices(normalizing_label[0], slices_of_interest)

    # if user wants to get unique value across labels, then combine all labels together
    if average_all_labels == 1:
        sum_labels_user = np.sum(labels[label_id_user])  # sum the labels selected by user
        if method == 'ml' or method == 'map':  # in case the maximum likelihood and the average across different labels are wanted
            labels_tmp = np.empty([nb_labels_total - len(label_id_user) + 1], dtype=object)
            labels = np.delete(labels, label_id_user)  # remove the labels selected by user
            labels_tmp[0] = sum_labels_user  # put the sum of the labels selected by user in first position of the tmp
            # variable
            for i_label in range(1, len(labels_tmp)):
                labels_tmp[i_label] = labels[i_label-1]  # fill the temporary array with the values of the non-selected labels
            labels = labels_tmp  # replace the initial labels value by the updated ones (with the summed labels)
            del labels_tmp  # delete the temporary labels
        else:  # in other cases than the maximum likelihood, we can remove other labels (not needed for estimation)
            labels = np.empty(1, dtype=object)
            labels[0] = sum_labels_user  # we create a new label array that includes only the summed labels

    if fname_normalizing_label:  # if the "normalization" option is wanted
        sct.printv('\nExtract normalization values...', verbose)
        if normalization_method == 'sbs':  # case: the user wants to normalize slice-by-slice
            for z in range(0, data.shape[-1]):
                normalizing_label_slice = np.empty([1], dtype=object)  # in order to keep compatibility with the function
                # 'extract_metric_within_tract', define a new array for the slice z of the normalizing labels
                normalizing_label_slice[0] = normalizing_label[0][..., z]
                metric_normalizing_label = extract_metric_within_tract(data[..., z], normalizing_label_slice, method, 0)
                # estimate the metric mean in the normalizing label for the slice z
                if metric_normalizing_label[0][0] != 0:
                    data[..., z] = data[..., z]/metric_normalizing_label[0][0]  # divide all the slice z by this value

        elif normalization_method == 'whole':  # case: the user wants to normalize after estimations in the whole labels
            metric_mean_norm_label, metric_std_norm_label = extract_metric_within_tract(data, normalizing_label, method, param.verbose)  # mean and std are lists

    # identify cluster for each tract (for use with robust ML)
    ml_clusters_array = get_clusters(ml_clusters, labels)

    # extract metrics within labels
    sct.printv('\nExtract metric within labels...', verbose)
    metric_mean, metric_std = extract_metric_within_tract(data, labels, method, verbose, ml_clusters_array, adv_param)  # mean and std are lists

    if fname_normalizing_label and normalization_method == 'whole':  # case: user wants to normalize after estimations in the whole labels
        metric_mean, metric_std = np.divide(metric_mean, metric_mean_norm_label), np.divide(metric_std, metric_std_norm_label)

    # update label name if average
    if average_all_labels == 1:
        label_name[0] = 'AVERAGED'+' -'.join(label_name[i] for i in label_id_user)  # concatenate the names of the
        # labels selected by the user if the average tag was asked
        label_id_user = [0]  # update "label_id_user" to select the "averaged" label (which is in first position)

    metric_mean = metric_mean[label_id_user]
    metric_std = metric_std[label_id_user]

    # display metrics
    sct.printv('\nEstimation results:', 1)
    for i in range(0, metric_mean.size):
        sct.printv(str(label_id_user[i])+', '+str(label_name[label_id_user[i]])+':    '+str(metric_mean[i])+' +/- '+str(metric_std[i]), 1, 'info')

    # save and display metrics
    save_metrics(label_id_user, label_name, slices_of_interest, metric_mean, metric_std, fname_output, fname_data,
                 method, fname_normalizing_label, actual_vert_levels, warning_vert_levels)
def _cl_make_tmp_dir():
    # Create a time-stamped scratch folder in the current directory and return its name.
    path_tmp = 'tmp.' + time.strftime("%y%m%d%H%M%S")
    sct.run('mkdir ' + path_tmp)
    return path_tmp


def _cl_remove_tmp_dir(path_tmp, remove_temp_files):
    # Delete the scratch folder unless the caller asked to keep it.
    if remove_temp_files:
        print('\nRemove temporary files...')
        sct.run('rm -rf ' + path_tmp)


def _cl_output_name(output_file_name, default_name):
    # Return the caller-supplied output name when given, otherwise the mode's default name.
    if output_file_name is not None:
        return output_file_name
    return default_name


def _cl_draw_points(data, x_values, y_values, z_values):
    # Set to 1 the voxel nearest to each (x, y, z) point. Explicit int() because float
    # indices (including Python 2 round() results) are rejected by modern numpy.
    for x, y, z in zip(x_values, y_values, z_values):
        data[int(round(x)), int(round(y)), int(round(z))] = 1


def _cl_write_text(fname, z_start, z_stop, x_values, y_values):
    # Write one "z x y" line per slice z in [z_start, z_stop]; x/y are indexed from z_start.
    with open(fname, 'w') as file_results:
        for iz in range(z_start, z_stop + 1):
            file_results.write(str(int(iz)) + ' ' + str(x_values[iz - z_start]) + ' '
                               + str(y_values[iz - z_start]) + '\n')


def _cl_fit_label_points(data, verbose):
    """Fit a smoothing spline x(z), y(z) through the non-zero voxels of a label volume.

    Returns (z_min, z_max, z_new, x_fit, y_fit), where z_new holds every integer slice
    from z_min to z_max (as floats) and x_fit / y_fit are the spline values on z_new.
    When verbose == 1, the fits are shown with matplotlib.
    """
    X, Y, Z = (data > 0).nonzero()
    z_new = np.linspace(min(Z), max(Z), int(max(Z) - min(Z)) + 1)

    # Sort the label points by slice; UnivariateSpline expects increasing abscissae.
    # NOTE(review): two labels on the same slice give duplicate z values, which recent
    # scipy versions reject -- confirm inputs hold at most one label per slice.
    order = Z.argsort()
    X = [X[i] for i in order]
    Y = [Y[i] for i in order]
    Z = [Z[i] for i in order]

    x_fit = interpolate.UnivariateSpline(Z, X)(z_new)
    y_fit = interpolate.UnivariateSpline(Z, Y)(z_new)

    if verbose == 1:
        import matplotlib.pyplot as plt

        for fitted, points in ((x_fit, X), (y_fit, Y)):
            plt.figure()
            plt.plot(z_new, fitted)
            plt.plot(Z, points, 'o', linestyle='None')
            plt.show()

    return Z[0], Z[-1], z_new, x_fit, y_fit


def _cl_load_segmentation(file_data, ext_data):
    """Reorient the segmentation (already copied into the cwd) to RPI and load it.

    Returns (orientation, nz, data, hdr): the input's original orientation string, the
    number of slices reported by sct.get_dimension, and the RPI data and header.
    """
    print('\nOrient segmentation image to RPI orientation...')
    fname_segmentation_orient = 'tmp.segmentation_rpi' + ext_data
    set_orientation(file_data + ext_data, 'RPI', fname_segmentation_orient)

    orientation = get_orientation(file_data + ext_data)
    print('\nOrientation of segmentation image: ' + orientation)

    print('\nGet dimensions data...')
    nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(fname_segmentation_orient)
    print('.. ' + str(nx) + ' x ' + str(ny) + ' y ' + str(nz) + ' z ' + str(nt))

    print('\nOpen segmentation volume...')
    image = nibabel.load(fname_segmentation_orient)
    return orientation, nz, image.get_data(), image.get_header()


def _cl_fit_segmentation(data):
    """Average the segmentation per slice, fit it with b_spline_centerline, clear the mask.

    `data` is modified in place: every voxel > 0 is set to 0 so that only the
    centerline drawn afterwards remains in the volume.
    Returns (z_min, z_max, x_fit, y_fit).
    """
    X, Y, Z = (data > 0).nonzero()
    z_min, z_max = min(Z), max(Z)
    z_centerline = list(range(z_min, z_max + 1))
    x_centerline = []
    y_centerline = []
    for iz in z_centerline:
        # A slice without voxels inside [z_min, z_max] yields NaN here (mean of empty).
        x_seg, y_seg = (data[:, :, iz] > 0).nonzero()
        x_centerline.append(np.mean(x_seg))
        y_centerline.append(np.mean(y_seg))

    # Erase the segmentation so the output image only contains the centerline.
    data[data > 0] = 0

    x_fit, y_fit, _x_deriv, _y_deriv, _z_deriv = b_spline_centerline(x_centerline, y_centerline, z_centerline)
    return z_min, z_max, x_fit, y_fit


def _cl_plot_combined(nz, data_seg_orig, data_label_orig, fname_centerline,
                      x_fit_seg, y_fit_seg, x_fit_label, y_fit_label):
    """Debug plot: raw seg+label centroids against the merged fitted centerline."""
    import matplotlib.pyplot as plt
    from scipy import ndimage

    # Centroid of the raw (segmentation + label) voxels on every non-empty slice.
    data_concatenate = data_seg_orig + data_label_orig
    z_raw = [iz for iz in range(nz) if data_concatenate[:, :, iz].any()]
    x_raw = []
    y_raw = []
    for iz in z_raw:
        x_c, y_c = ndimage.center_of_mass(data_concatenate[:, :, iz])
        x_raw.append(x_c)
        y_raw.append(y_c)

    # Slices covered by the merged centerline image.
    data_centerline_fit = nibabel.load(fname_centerline).get_data()
    z_fit = [iz for iz in range(nz) if data_centerline_fit[:, :, iz].any()]
    n_fit = len(z_fit)
    n_seg = len(x_fit_seg)
    n_label = len(x_fit_label)
    # Number of slices shared by the segmentation fit and the label fit.
    length_overlap = n_label + n_seg - n_fit

    # Float buffers so fitted coordinates are not truncated to integers.
    x_fit_total = np.zeros(n_fit)
    y_fit_total = np.zeros(n_fit)
    for i in range(n_seg):
        x_fit_total[i] = x_fit_seg[i]
        y_fit_total[i] = y_fit_seg[i]
    # Append the label fit after the segmentation fit, skipping the overlapping part.
    for i in range(n_label - length_overlap):
        x_fit_total[n_seg + i] = x_fit_label[i + length_overlap]
        y_fit_total[n_seg + i] = y_fit_label[i + length_overlap]

    x_display = [0] * n_fit
    y_display = [0] * n_fit
    for iz, x_c, y_c in zip(z_raw, x_raw, y_raw):
        x_display[iz - z_raw[0]] = x_c
        y_display[iz - z_raw[0]] = y_c

    plt.figure(1)
    plt.subplot(2, 1, 1)
    plt.plot(z_fit, x_display, 'ro')
    plt.plot(z_fit, x_fit_total)
    plt.xlabel("Z")
    plt.ylabel("X")
    plt.title("x and x_fit coordinates")

    plt.subplot(2, 1, 2)
    plt.plot(z_fit, y_display, 'ro')
    plt.plot(z_fit, y_fit_total)
    plt.xlabel("Z")
    plt.ylabel("Y")
    plt.title("y and y_fit coordinates")
    plt.show()


def _cl_binary_from_segmentation(segmentation_file, output_file_name, remove_temp_files):
    # Binary centerline image from a segmentation only; returns the output file name.
    segmentation_file = os.path.abspath(segmentation_file)
    path_data, file_data, ext_data = sct.extract_fname(segmentation_file)

    path_tmp = _cl_make_tmp_dir()
    sct.run('cp ' + segmentation_file + ' ' + path_tmp)
    os.chdir(path_tmp)

    orientation, nz, data, hdr = _cl_load_segmentation(file_data, ext_data)
    z_min, z_max, x_fit, y_fit = _cl_fit_segmentation(data)
    _cl_draw_points(data, x_fit, y_fit, range(z_min, z_max + 1))

    # Write the centerline image (still in RPI orientation)
    hdr.set_data_dtype('uint8')
    print('\nWrite NIFTI volumes...')
    img = nibabel.Nifti1Image(data, None, hdr)
    file_name = _cl_output_name(output_file_name, file_data + '_centerline' + ext_data)
    nibabel.save(img, 'tmp.centerline.nii')
    sct.generate_output_file('tmp.centerline.nii', file_name)

    os.chdir('..')

    # Bring the centerline back to the input orientation, written in the parent folder.
    print('\nOrient centerline image to input orientation: ' + orientation)
    set_orientation(path_tmp + '/' + file_name, orientation, file_name)

    _cl_remove_tmp_dir(path_tmp, remove_temp_files)
    return file_name


def _cl_binary_from_label(label_file, output_file_name, verbose):
    # Binary centerline image from a label file only (spline through the labels).
    path_data, file_data, ext_data = sct.extract_fname(os.path.abspath(label_file))

    image = nibabel.load(label_file)
    data = image.get_data()
    hdr = image.get_header()

    z_min, z_max, z_new, x_fit, y_fit = _cl_fit_label_points(data, verbose)

    data = np.zeros(data.shape, dtype=data.dtype)
    _cl_draw_points(data, x_fit, y_fit, z_new)

    print('\nSave volume ...')
    hdr.set_data_dtype('float32')
    img = nibabel.Nifti1Image(data, None, hdr)
    file_name = _cl_output_name(output_file_name, file_data + '_centerline' + ext_data)
    nibabel.save(img, file_name)
    print('\nFile created : ' + file_name)


def _cl_binary_from_label_and_segmentation(segmentation_file, label_file, output_file_name,
                                           remove_temp_files, verbose):
    # Binary centerline image merging a label-based and a segmentation-based centerline.
    path_tmp = _cl_make_tmp_dir()

    # --- Part 1: centerline from the label file ---
    print('\nPROCESS PART 1: From label file create centerline image.')
    path_data_label, file_data_label, ext_data_label = sct.extract_fname(os.path.abspath(label_file))
    image_label = nibabel.load(label_file)
    sct.run('cp ' + label_file + ' ' + path_tmp)
    data_label = image_label.get_data()
    hdr_label = image_label.get_header()
    if verbose == 1:
        data_label_to_show = data_label.copy()

    z_min_label, z_max_label, z_new, x_fit_label, y_fit_label = _cl_fit_label_points(data_label, verbose)
    data_label = np.zeros(data_label.shape, dtype=data_label.dtype)
    _cl_draw_points(data_label, x_fit_label, y_fit_label, z_new)

    print('\nSave volume ...')
    hdr_label.set_data_dtype('float32')
    img = nibabel.Nifti1Image(data_label, None, hdr_label)
    file_name_label = file_data_label + '_centerline' + ext_data_label
    # Saved straight into the temp folder so no file in the working directory is
    # overwritten or deleted.
    nibabel.save(img, path_tmp + '/' + file_name_label)
    print('\nFile created : ' + file_name_label)
    del data_label

    # --- Part 2: centerline from the segmentation ---
    print('\nPROCESS PART 2: From segmentation file create centerline image.')
    segmentation_file = os.path.abspath(segmentation_file)
    path_data_seg, file_data_seg, ext_data_seg = sct.extract_fname(segmentation_file)
    sct.run('cp ' + segmentation_file + ' ' + path_tmp)
    os.chdir(path_tmp)

    orientation, nz, data_seg, hdr_seg = _cl_load_segmentation(file_data_seg, ext_data_seg)
    if verbose == 1:
        data_seg_to_show = data_seg.copy()

    z_min_seg, z_max_seg, x_fit_seg, y_fit_seg = _cl_fit_segmentation(data_seg)
    _cl_draw_points(data_seg, x_fit_seg, y_fit_seg, range(z_min_seg, z_max_seg + 1))

    hdr_seg.set_data_dtype('uint8')
    print('\nWrite NIFTI volumes...')
    img = nibabel.Nifti1Image(data_seg, None, hdr_seg)
    nibabel.save(img, 'tmp.centerline.nii')
    file_name_seg = file_data_seg + '_centerline' + ext_data_seg
    sct.generate_output_file('tmp.centerline.nii', file_name_seg)
    del data_seg

    print('\nOrient centerline image to input orientation: ' + orientation)
    set_orientation(file_name_seg, orientation, file_name_seg)

    # --- Part 3: merge both centerlines ---
    print('\nRemoving overlap of the centerline obtain with label file if there are any:')
    remove_overlap(file_name_label, file_name_seg, "generated_centerline_without_overlap.nii.gz")

    print('\nConcatenation of the two centerline files:')
    # NOTE(review): the default name has no extension; confirm that the name fslmaths
    # actually writes matches what the cp below (and the verbose plot) look for.
    file_name = _cl_output_name(output_file_name, 'centerline_total_from_label_and_seg')
    sct.run('fslmaths generated_centerline_without_overlap.nii.gz -add ' + file_name_seg + ' ' + file_name)

    if verbose == 1:
        _cl_plot_combined(nz, data_seg_to_show, data_label_to_show, file_name,
                          x_fit_seg, y_fit_seg, x_fit_label, y_fit_label)

    # Copy result into parent folder
    sct.run('cp ' + file_name + ' ../')
    os.chdir('..')

    _cl_remove_tmp_dir(path_tmp, remove_temp_files)


def _cl_text_from_segmentation(segmentation_file, output_file_name, remove_temp_files, verbose):
    # "z x y" text centerline from a segmentation only; returns the output file name.
    segmentation_file = os.path.abspath(segmentation_file)
    path_data, file_data, ext_data = sct.extract_fname(segmentation_file)

    path_tmp = _cl_make_tmp_dir()
    sct.run('cp ' + segmentation_file + ' ' + path_tmp)
    os.chdir(path_tmp)

    orientation, nz, data, hdr = _cl_load_segmentation(file_data, ext_data)
    z_min, z_max, x_fit, y_fit = _cl_fit_segmentation(data)

    file_name = _cl_output_name(output_file_name, file_data + '_centerline' + '.txt')
    sct.printv('\nWrite text file...', verbose)
    _cl_write_text(file_name, z_min, z_max, x_fit, y_fit)

    # Copy result into parent folder
    sct.run('cp ' + file_name + ' ../')
    os.chdir('..')

    _cl_remove_tmp_dir(path_tmp, remove_temp_files)
    return file_name


def _cl_text_from_label(label_file, output_file_name, verbose):
    # "z x y" text centerline from a label file only (spline through the labels).
    path_data, file_data, ext_data = sct.extract_fname(os.path.abspath(label_file))
    data = nibabel.load(label_file).get_data()

    z_min, z_max, z_new, x_fit, y_fit = _cl_fit_label_points(data, verbose)

    sct.printv('\nWrite text file...', verbose)
    # '.txt' like the other text-file modes (the image extension was used before,
    # producing e.g. a plain-text "*.nii.gz").
    file_name = _cl_output_name(output_file_name, file_data + '_centerline' + '.txt')
    _cl_write_text(file_name, z_min, z_max, x_fit, y_fit)


def _cl_text_from_label_and_segmentation(segmentation_file, label_file, output_file_name,
                                         remove_temp_files, verbose):
    # "z x y" text centerline merging a label-based and a segmentation-based centerline.
    path_tmp = _cl_make_tmp_dir()

    # --- Part 1: centerline from the label file ---
    print('\nPROCESS PART 1: From label file create centerline text file.')
    path_data_label, file_data_label, ext_data_label = sct.extract_fname(os.path.abspath(label_file))
    image_label = nibabel.load(label_file)
    sct.run('cp ' + label_file + ' ' + path_tmp)

    z_min_label, z_max_label, z_new, x_fit_label, y_fit_label = \
        _cl_fit_label_points(image_label.get_data(), verbose)

    sct.printv('\nWrite text file...', verbose)
    file_name_label = file_data_label + '_centerline' + '.txt'
    _cl_write_text(path_tmp + '/' + file_name_label, z_min_label, z_max_label, x_fit_label, y_fit_label)

    # --- Part 2: centerline from the segmentation ---
    print('\nPROCESS PART 2: From segmentation file create centerline image.')
    segmentation_file = os.path.abspath(segmentation_file)
    path_data_seg, file_data_seg, ext_data_seg = sct.extract_fname(segmentation_file)
    sct.run('cp ' + segmentation_file + ' ' + path_tmp)
    os.chdir(path_tmp)

    orientation, nz, data_seg, hdr_seg = _cl_load_segmentation(file_data_seg, ext_data_seg)
    z_min_seg, z_max_seg, x_fit_seg, y_fit_seg = _cl_fit_segmentation(data_seg)

    file_name_seg = file_data_seg + '_centerline' + '.txt'
    sct.printv('\nWrite text file...', verbose)
    _cl_write_text(file_name_seg, z_min_seg, z_max_seg, x_fit_seg, y_fit_seg)
    del data_seg

    # --- Part 3: merge both centerlines ---
    print('\nRemoving overlap of the centerline obtain with label file if there are any:')
    remove_overlap(file_name_label, file_name_seg, "generated_centerline_without_overlap1.txt", parameter=1)

    print('\nConcatenation of the two centerline files:')
    file_name = _cl_output_name(output_file_name, 'centerline_total_from_label_and_seg.txt')
    # Segmentation lines first, then the label lines left after overlap removal.
    with open(file_name, "w") as f_output:
        for part in (file_name_seg, "generated_centerline_without_overlap1.txt"):
            with open(part, "r") as f_part:
                f_output.writelines(f_part.readlines())

    # Copy result into parent folder
    sct.run('cp ' + file_name + ' ../')
    os.chdir('..')

    _cl_remove_tmp_dir(path_tmp, remove_temp_files)


def main(segmentation_file=None, label_file=None, output_file_name=None, parameter="binary_centerline", remove_temp_files=1, verbose=0):
    """Build a spinal-cord centerline from a segmentation and/or a label file.

    The inputs used are chosen from the module-level ``arguments`` container:
    "-i" alone uses the segmentation, "-l" alone uses the label file, and both
    together merge the two centerlines.  NOTE(review): ``arguments`` is a global
    not defined in this block (presumably the parsed command line) -- confirm.

    :param segmentation_file: path to the segmentation image ("-i").
    :param label_file: path to the label image ("-l").
    :param output_file_name: output name; a per-mode default is used when None.
    :param parameter: "binary_centerline" writes an image with the fitted
        centerline voxels set to 1; "text_file" writes one "z x y" line per slice.
    :param remove_temp_files: when truthy, delete the temporary folder used.
    :param verbose: forwarded to sct.printv; when 1, matplotlib plots of the
        fits are also shown.
    :returns: the output file name in the segmentation-only modes, otherwise None.
    """
    if parameter == "binary_centerline":
        use_seg = "-i" in arguments
        use_label = "-l" in arguments
        if use_seg and not use_label:
            return _cl_binary_from_segmentation(segmentation_file, output_file_name, remove_temp_files)
        if use_label and not use_seg:
            _cl_binary_from_label(label_file, output_file_name, verbose)
        elif use_label and use_seg:
            # Previously written `"-l" and "-i" in arguments`, which only tests "-i".
            _cl_binary_from_label_and_segmentation(segmentation_file, label_file, output_file_name,
                                                   remove_temp_files, verbose)
        return None

    if parameter == "text_file":
        print("\nText file process")
        use_seg = "-i" in arguments
        use_label = "-l" in arguments
        if use_seg and not use_label:
            return _cl_text_from_segmentation(segmentation_file, output_file_name, remove_temp_files, verbose)
        if use_label and not use_seg:
            _cl_text_from_label(label_file, output_file_name, verbose)
        elif use_label and use_seg:
            _cl_text_from_label_and_segmentation(segmentation_file, label_file, output_file_name,
                                                 remove_temp_files, verbose)
    return None
def main():
    """Command-line entry point: estimate a metric within atlas labels.

    Reads its options from ``sys.argv`` (or uses hard-coded paths when
    ``param.debug`` is set). It then:

    1. loads the metric volume and the label volumes returned by
       ``read_label_file``, reorienting them to RPI in a temporary folder when
       the metric data is not already RPI;
    2. restricts data and labels to the requested slices or vertebral levels;
    3. optionally merges the selected labels into one averaged label;
    4. optionally normalizes the metric by a normalizing label, either
       slice-by-slice ('sbs') or after estimation ('whole');
    5. estimates mean/std per label with ``extract_metric_within_tract``;
    6. prints the results and writes them with ``save_metrics``.

    Command-line flags:
      -i  metric file                 -f  label folder
      -l  label ids of interest       -a  average all selected labels
      -m  estimation method           -p  advanced params (comma-separated)
      -z  slices of interest          -v  vertebral levels
      -n  normalizing label file      -w  normalization method (sbs/whole)
      -o  output file name            -h  help

    Exits via ``sys.exit(2)`` on ambiguous/missing vertebral labeling files or
    on a data/label dimension mismatch. ``usage()`` is called on bad input;
    presumably it exits as well -- not visible here.
    """
    import shutil  # local import: used only for temp-folder cleanup below

    # Initialization to defaults parameters
    fname_data = ''  # data is empty by default
    path_label = ''  # empty by default
    method = param.method  # extraction mode by default
    labels_of_interest = param.labels_of_interest
    slices_of_interest = param.slices_of_interest
    vertebral_levels = param.vertebral_levels
    average_all_labels = param.average_all_labels
    fname_output = param.fname_output
    fname_vertebral_labeling = param.fname_vertebral_labeling
    fname_normalizing_label = ''  # optional then default is empty
    normalization_method = ''  # optional then default is empty
    # Filled by get_slices_matching_with_vertebral_levels when the vertebral
    # levels requested by the user don't exactly match those available in the
    # metric data; forwarded to save_metrics.
    actual_vert_levels = None
    warning_vert_levels = None
    verbose = param.verbose
    flag_h = 0
    ml_clusters = param.ml_clusters
    adv_param = param.adv_param
    adv_param_user = ''

    # Parameters for debug mode
    if param.debug:
        print('\n*** WARNING: DEBUG MODE ON ***\n')
        fname_data = '/Users/julien/data/temp/sct_example_data/mt/mtr.nii.gz'
        path_label = '/Users/julien/data/temp/sct_example_data/mt/label/atlas/'
        method = 'map'
        ml_clusters = '0:29,30,31'
        labels_of_interest = '0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29'
        slices_of_interest = ''
        vertebral_levels = ''
        average_all_labels = 1
        fname_normalizing_label = ''
        normalization_method = ''
    else:
        # Check input parameters
        try:
            opts, _ = getopt.getopt(
                sys.argv[1:], 'haf:i:l:m:n:o:p:v:w:z:')  # define flags
        except getopt.GetoptError as err:  # check if the arguments are defined
            print(str(err))  # error
            usage()  # display usage
        if not opts:
            usage()
        # Compare flags with '==': the former "opt in '-a'" was a substring
        # test that would also accept '-' or 'a'.
        for opt, arg in opts:  # explore flags
            if opt == '-a':
                average_all_labels = 1
            elif opt == '-f':
                path_label = os.path.abspath(arg)  # save path of labels folder
            elif opt == '-h':  # help option
                flag_h = 1
            elif opt == '-i':
                fname_data = arg
            elif opt == '-l':
                labels_of_interest = arg
            elif opt == '-m':  # method for metric extraction
                method = arg
            elif opt == '-n':  # filename of the label by which the user wants to normalize
                fname_normalizing_label = arg
            elif opt == '-o':  # output option
                fname_output = arg  # fname of output file
            elif opt == '-p':
                adv_param_user = arg
            elif opt == '-v':
                # vertebral levels option, if the user wants to average the metric across specific vertebral levels
                vertebral_levels = arg
            elif opt == '-w':
                # method used for the normalization by the metric estimation into the normalizing label (see flag -n):
                # 'sbs' for slice-by-slice or 'whole' for normalization after estimation in the whole labels
                normalization_method = arg
            elif opt == '-z':  # slices numbers option
                slices_of_interest = arg  # save labels numbers

    # Display usage with tract parameters by default in case files aren't chosen in arguments inputs
    if fname_data == '' or path_label == '' or flag_h:
        param.path_label = path_label
        usage()

    # Check existence of input files
    sct.printv('\ncheck existence of input files...', verbose)
    sct.check_file_exist(fname_data)
    sct.check_folder_exist(path_label)
    if fname_normalizing_label:
        # The normalizing label is a single image file (it is nib.load-ed
        # below), so check it as a file, not as a folder.
        sct.check_file_exist(fname_normalizing_label)

    # add slash at the end
    path_label = sct.slash_at_the_end(path_label, 1)

    # Find path to the vertebral labeling file if vertebral levels were specified by the user
    if vertebral_levels:
        if slices_of_interest:  # impossible to select BOTH specific slices and specific vertebral levels
            print('\nERROR: You cannot select BOTH vertebral levels AND slice numbers.')
            usage()
        else:
            # The labeling file is searched in the parent of the label folder.
            fname_vertebral_labeling_list = sct.find_file_within_folder(
                fname_vertebral_labeling, path_label + '..')
            if len(fname_vertebral_labeling_list) > 1:
                print(color.red + 'ERROR: More than one file named \'' + fname_vertebral_labeling +
                      '\' were found in ' + path_label + '. Exit program.' + color.end)
                sys.exit(2)
            elif not fname_vertebral_labeling_list:
                print(color.red + 'ERROR: No file named \'' + fname_vertebral_labeling +
                      '\' were found in ' + path_label + '. Exit program.' + color.end)
                sys.exit(2)
            else:
                fname_vertebral_labeling = os.path.abspath(
                    fname_vertebral_labeling_list[0])

    # Check input parameters
    check_method(method, fname_normalizing_label, normalization_method)

    # parse argument for param
    if adv_param_user:
        # remove spaces and parse with comma
        adv_param = adv_param_user.replace(' ', '').split(',')
        del adv_param_user  # clean variable
        # TODO: check integrity of input

    # Extract label info
    label_id, label_name, label_file = read_label_file(path_label,
                                                       param.file_info_label)
    nb_labels_total = len(label_id)

    # check consistency of label input parameter. If 'labels_of_interest' is
    # empty, then 'label_id_user' contains the index of all labels in the
    # file info_label.txt
    label_id_user, average_all_labels = check_labels(
        labels_of_interest, nb_labels_total, average_all_labels, method)

    # print parameters
    print('\nChecked parameters:')
    print('  data ...................... ' + fname_data)
    print('  folder label .............. ' + path_label)
    print('  selected labels ........... ' + str(label_id_user))
    print('  estimation method ......... ' + method)
    print('  slices of interest ........ ' + slices_of_interest)
    print('  vertebral levels .......... ' + vertebral_levels)
    print('  vertebral labeling file.... ' + fname_vertebral_labeling)
    print('  advanced parameters ....... ' + str(adv_param))

    # Check if the orientation of the data is RPI
    orientation_data = get_orientation(fname_data)

    # If orientation is not RPI, change to RPI
    if orientation_data != 'RPI':
        sct.printv(
            '\nCreate temporary folder to change the orientation of the NIFTI files into RPI...',
            verbose)
        path_tmp = sct.slash_at_the_end('tmp.' + time.strftime("%y%m%d%H%M%S"),
                                        1)
        sct.create_folder(path_tmp)
        # change orientation and load data
        sct.printv('\nChange image orientation and load it...', verbose)
        data = nib.load(
            set_orientation(fname_data, 'RPI',
                            path_tmp + 'orient_data.nii')).get_data()
        # Do the same for labels
        sct.printv('\nChange labels orientation and load them...', verbose)
        labels = np.empty([nb_labels_total],
                          dtype=object)  # labels(nb_labels_total, x, y, z)
        for i_label in range(0, nb_labels_total):
            labels[i_label] = nib.load(
                set_orientation(path_label + label_file[i_label], 'RPI',
                                path_tmp + 'orient_' +
                                label_file[i_label])).get_data()
        if fname_normalizing_label:  # if the "normalization" option is wanted,
            # A one-element object array keeps compatibility with the rest of
            # the code, which indexes labels as (n_labels, x, y, z).
            normalizing_label = np.empty([1], dtype=object)
            normalizing_label[0] = nib.load(
                set_orientation(fname_normalizing_label, 'RPI', path_tmp +
                                'orient_normalizing_volume.nii')).get_data()
        if vertebral_levels:  # if vertebral levels were selected,
            data_vertebral_labeling = nib.load(
                set_orientation(
                    fname_vertebral_labeling, 'RPI',
                    path_tmp + 'orient_vertebral_labeling.nii.gz')).get_data()
        # Remove the temporary folder used to change the NIFTI files orientation
        # into RPI. ignore_errors mirrors the previous silent 'rm -rf'.
        sct.printv('\nRemove the temporary folder...', verbose)
        shutil.rmtree(path_tmp, ignore_errors=True)
    else:
        # Load image
        sct.printv('\nLoad image...', verbose)
        data = nib.load(fname_data).get_data()

        # Load labels
        sct.printv('\nLoad labels...', verbose)
        labels = np.empty([nb_labels_total],
                          dtype=object)  # labels(nb_labels_total, x, y, z)
        for i_label in range(0, nb_labels_total):
            labels[i_label] = nib.load(path_label +
                                       label_file[i_label]).get_data()
        if fname_normalizing_label:  # if the "normalization" option is wanted,
            # One-element object array: same (n_labels, x, y, z) convention
            # as 'labels'.
            normalizing_label = np.empty([1], dtype=object)
            # load the data of the normalizing label
            normalizing_label[0] = nib.load(fname_normalizing_label).get_data()
        if vertebral_levels:  # if vertebral levels were selected,
            data_vertebral_labeling = nib.load(
                fname_vertebral_labeling).get_data()

    # Change metric data type into floats for future manipulations (normalization)
    data = np.float64(data)

    # Get dimensions of data
    sct.printv('\nGet dimensions of data...', verbose)
    nx, ny, nz = data.shape
    sct.printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), verbose)

    # Get dimensions of labels
    sct.printv('\nGet dimensions of label...', verbose)
    nx_atlas, ny_atlas, nz_atlas = labels[0].shape
    sct.printv(
        '.. ' + str(nx_atlas) + ' x ' + str(ny_atlas) + ' x ' + str(nz_atlas) +
        ' x ' + str(nb_labels_total), verbose)

    # Check dimensions consistency between atlas and data
    if (nx, ny, nz) != (nx_atlas, ny_atlas, nz_atlas):
        print('\nERROR: Metric data and labels DO NOT HAVE SAME DIMENSIONS.')
        sys.exit(2)

    # Update the flag "slices_of_interest" according to the vertebral levels selected by user (if it's the case)
    if vertebral_levels:
        slices_of_interest, actual_vert_levels, warning_vert_levels = \
            get_slices_matching_with_vertebral_levels(data, vertebral_levels, data_vertebral_labeling)

    # select slice of interest by cropping data and labels
    if slices_of_interest:
        data = remove_slices(data, slices_of_interest)
        for i_label in range(0, nb_labels_total):
            labels[i_label] = remove_slices(labels[i_label],
                                            slices_of_interest)
        if fname_normalizing_label:  # if the "normalization" option was selected,
            normalizing_label[0] = remove_slices(normalizing_label[0],
                                                 slices_of_interest)

    # if user wants to get unique value across labels, then combine all labels together
    if average_all_labels == 1:
        sum_labels_user = np.sum(
            labels[label_id_user])  # sum the labels selected by user
        if method == 'ml' or method == 'map':
            # Maximum likelihood / MAP need the non-selected labels too:
            # put the summed selected labels first, followed by every label
            # the user did not select.
            labels_tmp = np.empty([nb_labels_total - len(label_id_user) + 1],
                                  dtype=object)
            labels = np.delete(
                labels, label_id_user)  # remove the labels selected by user
            labels_tmp[0] = sum_labels_user
            for i_label in range(1, len(labels_tmp)):
                # fill the temporary array with the values of the non-selected labels
                labels_tmp[i_label] = labels[i_label - 1]
            labels = labels_tmp  # replace the initial labels by the updated ones (with the summed labels)
            del labels_tmp  # delete the temporary labels
        else:
            # in other cases than the maximum likelihood, other labels are not
            # needed for estimation: keep only the summed labels
            labels = np.empty(1, dtype=object)
            labels[0] = sum_labels_user

    if fname_normalizing_label:  # if the "normalization" option is wanted
        sct.printv('\nExtract normalization values...', verbose)
        if normalization_method == 'sbs':  # case: the user wants to normalize slice-by-slice
            for z in range(0, data.shape[-1]):
                # One-element object array for the slice z of the normalizing
                # label, to match the input convention of
                # 'extract_metric_within_tract'.
                normalizing_label_slice = np.empty([1], dtype=object)
                normalizing_label_slice[0] = normalizing_label[0][..., z]
                # estimate the metric mean in the normalizing label for the slice z
                metric_normalizing_label = extract_metric_within_tract(
                    data[..., z], normalizing_label_slice, method, 0)
                # divide the whole slice z by this value (skip zero to avoid
                # division by zero)
                if metric_normalizing_label[0][0] != 0:
                    data[..., z] = data[..., z] / metric_normalizing_label[0][0]

        elif normalization_method == 'whole':  # case: the user wants to normalize after estimations in the whole labels
            metric_mean_norm_label, metric_std_norm_label = extract_metric_within_tract(
                data, normalizing_label, method,
                verbose)  # mean and std are lists

    # identify cluster for each tract (for use with robust ML)
    ml_clusters_array = get_clusters(ml_clusters, labels)

    # extract metrics within labels
    sct.printv('\nExtract metric within labels...', verbose)
    metric_mean, metric_std = extract_metric_within_tract(
        data, labels, method, verbose, ml_clusters_array,
        adv_param)  # mean and std are lists

    if fname_normalizing_label and normalization_method == 'whole':  # case: user wants to normalize after estimations in the whole labels
        metric_mean = np.divide(metric_mean, metric_mean_norm_label)
        metric_std = np.divide(metric_std, metric_std_norm_label)

    # update label name if average
    if average_all_labels == 1:
        # concatenate the names of the labels selected by the user
        label_name[0] = 'AVERAGED' + ' -'.join(
            label_name[i] for i in label_id_user)
        # select the "averaged" label (which is in first position)
        label_id_user = [0]

    metric_mean = metric_mean[label_id_user]
    metric_std = metric_std[label_id_user]

    # display metrics
    sct.printv('\nEstimation results:', 1)
    for i in range(0, metric_mean.size):
        sct.printv(
            str(label_id_user[i]) + ', ' + str(label_name[label_id_user[i]]) +
            ':    ' + str(metric_mean[i]) + ' +/- ' + str(metric_std[i]), 1,
            'info')

    # save and display metrics
    save_metrics(label_id_user, label_name, slices_of_interest, metric_mean,
                 metric_std, fname_output, fname_data, method,
                 fname_normalizing_label, actual_vert_levels,
                 warning_vert_levels)
    def straighten(self):
        # Initialization
        fname_anat = self.input_filename
        fname_centerline = self.centerline_filename
        fname_output = self.output_filename
        gapxy = self.gapxy
        gapz = self.gapz
        padding = self.padding
        remove_temp_files = self.remove_temp_files
        verbose = self.verbose
        interpolation_warp = self.interpolation_warp
        algo_fitting = self.algo_fitting
        window_length = self.window_length
        type_window = self.type_window
        crop = self.crop

        # start timer
        start_time = time.time()

        # get path of the toolbox
        status, path_sct = commands.getstatusoutput("echo $SCT_DIR")
        sct.printv(path_sct, verbose)

        if self.debug == 1:
            print "\n*** WARNING: DEBUG MODE ON ***\n"
            fname_anat = (
                "/Users/julien/data/temp/sct_example_data/t2/tmp.150401221259/anat_rpi.nii"
            )  # path_sct+'/testing/sct_testing_data/data/t2/t2.nii.gz'
            fname_centerline = (
                "/Users/julien/data/temp/sct_example_data/t2/tmp.150401221259/centerline_rpi.nii"
            )  # path_sct+'/testing/sct_testing_data/data/t2/t2_seg.nii.gz'
            remove_temp_files = 0
            type_window = "hanning"
            verbose = 2

        # check existence of input files
        sct.check_file_exist(fname_anat, verbose)
        sct.check_file_exist(fname_centerline, verbose)

        # Display arguments
        sct.printv("\nCheck input arguments...", verbose)
        sct.printv("  Input volume ...................... " + fname_anat, verbose)
        sct.printv("  Centerline ........................ " + fname_centerline, verbose)
        sct.printv("  Final interpolation ............... " + interpolation_warp, verbose)
        sct.printv("  Verbose ........................... " + str(verbose), verbose)
        sct.printv("", verbose)

        # Extract path/file/extension
        path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)
        path_centerline, file_centerline, ext_centerline = sct.extract_fname(fname_centerline)

        # create temporary folder
        path_tmp = "tmp." + time.strftime("%y%m%d%H%M%S")
        sct.run("mkdir " + path_tmp, verbose)

        # copy files into tmp folder
        sct.run("cp " + fname_anat + " " + path_tmp, verbose)
        sct.run("cp " + fname_centerline + " " + path_tmp, verbose)

        # go to tmp folder
        os.chdir(path_tmp)

        try:
            # Change orientation of the input centerline into RPI
            sct.printv("\nOrient centerline to RPI orientation...", verbose)
            fname_centerline_orient = file_centerline + "_rpi.nii.gz"
            set_orientation(file_centerline + ext_centerline, "RPI", fname_centerline_orient)

            # Get dimension
            sct.printv("\nGet dimensions...", verbose)
            nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(fname_centerline_orient)
            sct.printv(".. matrix size: " + str(nx) + " x " + str(ny) + " x " + str(nz), verbose)
            sct.printv(".. voxel size:  " + str(px) + "mm x " + str(py) + "mm x " + str(pz) + "mm", verbose)

            # smooth centerline
            x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
                fname_centerline_orient,
                algo_fitting=algo_fitting,
                type_window=type_window,
                window_length=window_length,
                verbose=verbose,
            )

            # Get coordinates of landmarks along curved centerline
            # ==========================================================================================
            sct.printv("\nGet coordinates of landmarks along curved centerline...", verbose)
            # landmarks are created along the curved centerline every z=gapz. They consist of a "cross" of size gapx and gapy. In voxel space!!!

            # find z indices along centerline given a specific gap: iz_curved
            nz_nonz = len(z_centerline)
            nb_landmark = int(round(float(nz_nonz) / gapz))

            if nb_landmark == 0:
                nb_landmark = 1

            if nb_landmark == 1:
                iz_curved = [0]
            else:
                iz_curved = [i * gapz for i in range(0, nb_landmark - 1)]

            iz_curved.append(nz_nonz - 1)
            # print iz_curved, len(iz_curved)
            n_iz_curved = len(iz_curved)
            # print n_iz_curved

            # landmark_curved initialisation
            # landmark_curved = [ [ [ 0 for i in range(0, 3)] for i in range(0, 5) ] for i in iz_curved ]

            from msct_types import Coordinate

            landmark_curved = []
            landmark_curved_value = 1

            ### TODO: THIS PART IS SLOW AND CAN BE MADE FASTER
            ### >>==============================================================================================================
            for iz in range(min(iz_curved), max(iz_curved) + 1, 1):
                if iz in iz_curved:
                    index = iz_curved.index(iz)
                    # calculate d (ax+by+cz+d=0)
                    # print iz_curved[index]
                    a = x_centerline_deriv[iz]
                    b = y_centerline_deriv[iz]
                    c = z_centerline_deriv[iz]
                    x = x_centerline_fit[iz]
                    y = y_centerline_fit[iz]
                    z = z_centerline[iz]
                    d = -(a * x + b * y + c * z)
                    # print a,b,c,d,x,y,z
                    # set coordinates for landmark at the center of the cross
                    coord = Coordinate([0, 0, 0, landmark_curved_value])
                    coord.x, coord.y, coord.z = x_centerline_fit[iz], y_centerline_fit[iz], z_centerline[iz]
                    landmark_curved.append(coord)

                    # set y coordinate to y_centerline_fit[iz] for elements 1 and 2 of the cross
                    cross_coordinates = [
                        Coordinate([0, 0, 0, landmark_curved_value + 1]),
                        Coordinate([0, 0, 0, landmark_curved_value + 2]),
                        Coordinate([0, 0, 0, landmark_curved_value + 3]),
                        Coordinate([0, 0, 0, landmark_curved_value + 4]),
                    ]

                    cross_coordinates[0].y = y_centerline_fit[iz]
                    cross_coordinates[1].y = y_centerline_fit[iz]

                    # set x and z coordinates for landmarks +x and -x, forcing de landmark to be in the orthogonal plan and the distance landmark/curve to be gapxy
                    x_n = Symbol("x_n")
                    cross_coordinates[1].x, cross_coordinates[0].x = solve(
                        (x_n - x) ** 2 + ((-1 / c) * (a * x_n + b * y + d) - z) ** 2 - gapxy ** 2, x_n
                    )  # x for -x and +x
                    cross_coordinates[0].z = (-1 / c) * (a * cross_coordinates[0].x + b * y + d)  # z for +x
                    cross_coordinates[1].z = (-1 / c) * (a * cross_coordinates[1].x + b * y + d)  # z for -x

                    # set x coordinate to x_centerline_fit[iz] for elements 3 and 4 of the cross
                    cross_coordinates[2].x = x_centerline_fit[iz]
                    cross_coordinates[3].x = x_centerline_fit[iz]

                    # set coordinates for landmarks +y and -y. Here, x coordinate is 0 (already initialized).
                    y_n = Symbol("y_n")
                    cross_coordinates[3].y, cross_coordinates[2].y = solve(
                        (y_n - y) ** 2 + ((-1 / c) * (a * x + b * y_n + d) - z) ** 2 - gapxy ** 2, y_n
                    )  # y for -y and +y
                    cross_coordinates[2].z = (-1 / c) * (a * x + b * cross_coordinates[2].y + d)  # z for +y
                    cross_coordinates[3].z = (-1 / c) * (a * x + b * cross_coordinates[3].y + d)  # z for -y

                    for coord in cross_coordinates:
                        landmark_curved.append(coord)
                    landmark_curved_value += 5
                else:
                    if self.all_labels == 1:
                        landmark_curved.append(
                            Coordinate(
                                [x_centerline_fit[iz], y_centerline_fit[iz], z_centerline[iz], landmark_curved_value],
                                mode="continuous",
                            )
                        )
                        landmark_curved_value += 1
            ### <<==============================================================================================================

            # Get coordinates of landmarks along straight centerline
            # ==========================================================================================
            sct.printv("\nGet coordinates of landmarks along straight centerline...", verbose)
            # landmark_straight = [ [ [ 0 for i in range(0,3)] for i in range (0,5) ] for i in iz_curved ] # same structure as landmark_curved

            landmark_straight = []

            # calculate the z indices corresponding to the Euclidean distance between two consecutive points on the curved centerline (approximation curve --> line)
            # TODO: DO NOT APPROXIMATE CURVE --> LINE
            if nb_landmark == 1:
                iz_straight = [0 for i in range(0, nb_landmark + 1)]
            else:
                iz_straight = [0 for i in range(0, nb_landmark)]

            # print iz_straight,len(iz_straight)
            iz_straight[0] = iz_curved[0]
            for index in range(1, n_iz_curved, 1):
                # compute vector between two consecutive points on the curved centerline
                vector_centerline = [
                    x_centerline_fit[iz_curved[index]] - x_centerline_fit[iz_curved[index - 1]],
                    y_centerline_fit[iz_curved[index]] - y_centerline_fit[iz_curved[index - 1]],
                    z_centerline[iz_curved[index]] - z_centerline[iz_curved[index - 1]],
                ]
                # compute norm of this vector
                norm_vector_centerline = linalg.norm(vector_centerline, ord=2)
                # round to closest integer value
                norm_vector_centerline_rounded = int(round(norm_vector_centerline, 0))
                # assign this value to the current z-coordinate on the straight centerline
                iz_straight[index] = iz_straight[index - 1] + norm_vector_centerline_rounded

            # initialize x0 and y0 to be at the center of the FOV
            x0 = int(round(nx / 2))
            y0 = int(round(ny / 2))
            landmark_curved_value = 1
            for iz in range(min(iz_curved), max(iz_curved) + 1, 1):
                if iz in iz_curved:
                    index = iz_curved.index(iz)
                    # set coordinates for landmark at the center of the cross
                    landmark_straight.append(Coordinate([x0, y0, iz_straight[index], landmark_curved_value]))
                    # set x, y and z coordinates for landmarks +x
                    landmark_straight.append(
                        Coordinate([x0 + gapxy, y0, iz_straight[index], landmark_curved_value + 1])
                    )
                    # set x, y and z coordinates for landmarks -x
                    landmark_straight.append(
                        Coordinate([x0 - gapxy, y0, iz_straight[index], landmark_curved_value + 2])
                    )
                    # set x, y and z coordinates for landmarks +y
                    landmark_straight.append(
                        Coordinate([x0, y0 + gapxy, iz_straight[index], landmark_curved_value + 3])
                    )
                    # set x, y and z coordinates for landmarks -y
                    landmark_straight.append(
                        Coordinate([x0, y0 - gapxy, iz_straight[index], landmark_curved_value + 4])
                    )
                    landmark_curved_value += 5
                else:
                    if self.all_labels == 1:
                        landmark_straight.append(Coordinate([x0, y0, iz, landmark_curved_value]))
                        landmark_curved_value += 1

            # Create NIFTI volumes with landmarks
            # ==========================================================================================
            # Pad input volume to deal with the fact that some landmarks on the curved centerline might be outside the FOV
            # N.B. IT IS VERY IMPORTANT TO PAD ALSO ALONG X and Y, OTHERWISE SOME LANDMARKS MIGHT GET OUT OF THE FOV!!!
            # sct.run('fslview ' + fname_centerline_orient)
            sct.printv("\nPad input volume to account for landmarks that fall outside the FOV...", verbose)
            sct.run(
                "isct_c3d "
                + fname_centerline_orient
                + " -pad "
                + str(padding)
                + "x"
                + str(padding)
                + "x"
                + str(padding)
                + "vox "
                + str(padding)
                + "x"
                + str(padding)
                + "x"
                + str(padding)
                + "vox 0 -o tmp.centerline_pad.nii.gz",
                verbose,
            )

            # Open padded centerline for reading
            sct.printv("\nOpen padded centerline for reading...", verbose)
            file = load("tmp.centerline_pad.nii.gz")
            data = file.get_data()
            hdr = file.get_header()

            if self.algo_landmark_rigid is not None and self.algo_landmark_rigid != "None":
                # Reorganize landmarks
                points_fixed, points_moving = [], []
                for coord in landmark_straight:
                    points_fixed.append([coord.x, coord.y, coord.z])
                for coord in landmark_curved:
                    points_moving.append([coord.x, coord.y, coord.z])

                # Register curved landmarks on straight landmarks based on python implementation
                sct.printv("\nComputing rigid transformation (algo=" + self.algo_landmark_rigid + ") ...", verbose)
                import msct_register_landmarks

                (
                    rotation_matrix,
                    translation_array,
                    points_moving_reg,
                ) = msct_register_landmarks.getRigidTransformFromLandmarks(
                    points_fixed, points_moving, constraints=self.algo_landmark_rigid, show=False
                )

                # reorganize registered points
                landmark_curved_rigid = []
                for index_curved, ind in enumerate(range(0, len(points_moving_reg), 1)):
                    coord = Coordinate()
                    coord.x, coord.y, coord.z, coord.value = (
                        points_moving_reg[ind][0],
                        points_moving_reg[ind][1],
                        points_moving_reg[ind][2],
                        index_curved + 1,
                    )
                    landmark_curved_rigid.append(coord)

                # Create volumes containing curved and straight landmarks
                data_curved_landmarks = data * 0
                data_curved_rigid_landmarks = data * 0
                data_straight_landmarks = data * 0

                # Loop across cross index
                for index in range(0, len(landmark_curved_rigid)):
                    x, y, z = (
                        int(round(landmark_curved[index].x)),
                        int(round(landmark_curved[index].y)),
                        int(round(landmark_curved[index].z)),
                    )

                    # attribute landmark_value to the voxel and its neighbours
                    data_curved_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_curved[index].value

                    # get x, y and z coordinates of curved landmark (rounded to closest integer)
                    x, y, z = (
                        int(round(landmark_curved_rigid[index].x)),
                        int(round(landmark_curved_rigid[index].y)),
                        int(round(landmark_curved_rigid[index].z)),
                    )

                    # attribute landmark_value to the voxel and its neighbours
                    data_curved_rigid_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_curved_rigid[index].value

                    # get x, y and z coordinates of straight landmark (rounded to closest integer)
                    x, y, z = (
                        int(round(landmark_straight[index].x)),
                        int(round(landmark_straight[index].y)),
                        int(round(landmark_straight[index].z)),
                    )

                    # attribute landmark_value to the voxel and its neighbours
                    data_straight_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_straight[index].value

                # Write NIFTI volumes
                sct.printv("\nWrite NIFTI volumes...", verbose)
                hdr.set_data_dtype("uint32")  # set imagetype to uint8 #TODO: maybe use int32
                img = Nifti1Image(data_curved_landmarks, None, hdr)
                save(img, "tmp.landmarks_curved.nii.gz")
                sct.printv(".. File created: tmp.landmarks_curved.nii.gz", verbose)
                hdr.set_data_dtype("uint32")  # set imagetype to uint8 #TODO: maybe use int32
                img = Nifti1Image(data_curved_rigid_landmarks, None, hdr)
                save(img, "tmp.landmarks_curved_rigid.nii.gz")
                sct.printv(".. File created: tmp.landmarks_curved_rigid.nii.gz", verbose)
                img = Nifti1Image(data_straight_landmarks, None, hdr)
                save(img, "tmp.landmarks_straight.nii.gz")
                sct.printv(".. File created: tmp.landmarks_straight.nii.gz", verbose)

                # writing rigid transformation file
                text_file = open("tmp.curve2straight_rigid.txt", "w")
                text_file.write("#Insight Transform File V1.0\n")
                text_file.write("#Transform 0\n")
                text_file.write("Transform: AffineTransform_double_3_3\n")
                text_file.write(
                    "Parameters: %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f\n"
                    % (
                        rotation_matrix[0, 0],
                        rotation_matrix[0, 1],
                        rotation_matrix[0, 2],
                        rotation_matrix[1, 0],
                        rotation_matrix[1, 1],
                        rotation_matrix[1, 2],
                        rotation_matrix[2, 0],
                        rotation_matrix[2, 1],
                        rotation_matrix[2, 2],
                        -translation_array[0, 0],
                        translation_array[0, 1],
                        -translation_array[0, 2],
                    )
                )
                text_file.write("FixedParameters: 0 0 0\n")
                text_file.close()

            else:
                # Create volumes containing curved and straight landmarks
                data_curved_landmarks = data * 0
                data_straight_landmarks = data * 0

                # Loop across cross index
                for index in range(0, len(landmark_curved)):
                    x, y, z = (
                        int(round(landmark_curved[index].x)),
                        int(round(landmark_curved[index].y)),
                        int(round(landmark_curved[index].z)),
                    )

                    # attribute landmark_value to the voxel and its neighbours
                    data_curved_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_curved[index].value

                    # get x, y and z coordinates of straight landmark (rounded to closest integer)
                    x, y, z = (
                        int(round(landmark_straight[index].x)),
                        int(round(landmark_straight[index].y)),
                        int(round(landmark_straight[index].z)),
                    )

                    # attribute landmark_value to the voxel and its neighbours
                    data_straight_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_straight[index].value

                # Write NIFTI volumes
                sct.printv("\nWrite NIFTI volumes...", verbose)
                hdr.set_data_dtype("uint32")  # set imagetype to uint8 #TODO: maybe use int32
                img = Nifti1Image(data_curved_landmarks, None, hdr)
                save(img, "tmp.landmarks_curved.nii.gz")
                sct.printv(".. File created: tmp.landmarks_curved.nii.gz", verbose)
                img = Nifti1Image(data_straight_landmarks, None, hdr)
                save(img, "tmp.landmarks_straight.nii.gz")
                sct.printv(".. File created: tmp.landmarks_straight.nii.gz", verbose)

                # Estimate deformation field by pairing landmarks
                # ==========================================================================================
                # convert landmarks to INT
                sct.printv("\nConvert landmarks to INT...", verbose)
                sct.run("isct_c3d tmp.landmarks_straight.nii.gz -type int -o tmp.landmarks_straight.nii.gz", verbose)
                sct.run("isct_c3d tmp.landmarks_curved.nii.gz -type int -o tmp.landmarks_curved.nii.gz", verbose)

                # This stands to avoid overlapping between landmarks
                sct.printv("\nMake sure all labels between landmark_curved and landmark_curved match...", verbose)
                label_process_straight = ProcessLabels(
                    fname_label="tmp.landmarks_straight.nii.gz",
                    fname_output="tmp.landmarks_straight.nii.gz",
                    fname_ref="tmp.landmarks_curved.nii.gz",
                    verbose=verbose,
                )
                label_process_straight.process("remove")
                label_process_curved = ProcessLabels(
                    fname_label="tmp.landmarks_curved.nii.gz",
                    fname_output="tmp.landmarks_curved.nii.gz",
                    fname_ref="tmp.landmarks_straight.nii.gz",
                    verbose=verbose,
                )
                label_process_curved.process("remove")

                # Estimate rigid transformation
                sct.printv("\nEstimate rigid transformation between paired landmarks...", verbose)
                sct.run(
                    "isct_ANTSUseLandmarkImagesToGetAffineTransform tmp.landmarks_straight.nii.gz tmp.landmarks_curved.nii.gz rigid tmp.curve2straight_rigid.txt",
                    verbose,
                )

                # Apply rigid transformation
                sct.printv("\nApply rigid transformation to curved landmarks...", verbose)
                # sct.run('sct_apply_transfo -i tmp.landmarks_curved.nii.gz -o tmp.landmarks_curved_rigid.nii.gz -d tmp.landmarks_straight.nii.gz -w tmp.curve2straight_rigid.txt -x nn', verbose)
                Transform(
                    input_filename="tmp.landmarks_curved.nii.gz",
                    source_reg="tmp.landmarks_curved_rigid.nii.gz",
                    output_filename="tmp.landmarks_straight.nii.gz",
                    warp="tmp.curve2straight_rigid.txt",
                    interp="nn",
                    verbose=verbose,
                ).apply()

            if verbose == 2:
                from mpl_toolkits.mplot3d import Axes3D
                import matplotlib.pyplot as plt

                fig = plt.figure()
                ax = Axes3D(fig)
                ax.plot(x_centerline_fit, y_centerline_fit, z_centerline, zdir="z")
                ax.plot(
                    [coord.x for coord in landmark_curved],
                    [coord.y for coord in landmark_curved],
                    [coord.z for coord in landmark_curved],
                    ".",
                )
                ax.plot(
                    [coord.x for coord in landmark_straight],
                    [coord.y for coord in landmark_straight],
                    [coord.z for coord in landmark_straight],
                    "r.",
                )
                if self.algo_landmark_rigid is not None and self.algo_landmark_rigid != "None":
                    ax.plot(
                        [coord.x for coord in landmark_curved_rigid],
                        [coord.y for coord in landmark_curved_rigid],
                        [coord.z for coord in landmark_curved_rigid],
                        "b.",
                    )
                ax.set_xlabel("x")
                ax.set_ylabel("y")
                ax.set_zlabel("z")
                plt.show()

            # This stands to avoid overlapping between landmarks
            sct.printv("\nMake sure all labels between landmark_curved and landmark_curved match...", verbose)
            label_process = ProcessLabels(
                fname_label="tmp.landmarks_straight.nii.gz",
                fname_output="tmp.landmarks_straight.nii.gz",
                fname_ref="tmp.landmarks_curved_rigid.nii.gz",
                verbose=verbose,
            )
            label_process.process("remove")
            label_process = ProcessLabels(
                fname_label="tmp.landmarks_curved_rigid.nii.gz",
                fname_output="tmp.landmarks_curved_rigid.nii.gz",
                fname_ref="tmp.landmarks_straight.nii.gz",
                verbose=verbose,
            )
            label_process.process("remove")

            # Estimate b-spline transformation curve --> straight
            sct.printv("\nEstimate b-spline transformation: curve --> straight...", verbose)
            sct.run(
                "isct_ANTSUseLandmarkImagesToGetBSplineDisplacementField tmp.landmarks_straight.nii.gz tmp.landmarks_curved_rigid.nii.gz tmp.warp_curve2straight.nii.gz "
                + self.bspline_meshsize
                + " "
                + self.bspline_numberOfLevels
                + " "
                + self.bspline_order
                + " 0",
                verbose,
            )

            # remove padding for straight labels
            if crop == 1:
                ImageCropper(
                    input_file="tmp.landmarks_straight.nii.gz",
                    output_file="tmp.landmarks_straight_crop.nii.gz",
                    dim="0,1,2",
                    bmax=True,
                    verbose=verbose,
                ).crop()
                pass
            else:
                sct.run("cp tmp.landmarks_straight.nii.gz tmp.landmarks_straight_crop.nii.gz", verbose)

            # Concatenate rigid and non-linear transformations...
            sct.printv("\nConcatenate rigid and non-linear transformations...", verbose)
            # sct.run('isct_ComposeMultiTransform 3 tmp.warp_rigid.nii -R tmp.landmarks_straight.nii tmp.warp.nii tmp.curve2straight_rigid.txt')
            # !!! DO NOT USE sct.run HERE BECAUSE isct_ComposeMultiTransform OUTPUTS A NON-NULL STATUS !!!
            cmd = "isct_ComposeMultiTransform 3 tmp.curve2straight.nii.gz -R tmp.landmarks_straight_crop.nii.gz tmp.warp_curve2straight.nii.gz tmp.curve2straight_rigid.txt"
            sct.printv(cmd, verbose, "code")
            sct.run(cmd, self.verbose)
            # commands.getstatusoutput(cmd)

            # Estimate b-spline transformation straight --> curve
            # TODO: invert warping field instead of estimating a new one
            sct.printv("\nEstimate b-spline transformation: straight --> curve...", verbose)
            sct.run(
                "isct_ANTSUseLandmarkImagesToGetBSplineDisplacementField tmp.landmarks_curved_rigid.nii.gz tmp.landmarks_straight.nii.gz tmp.warp_straight2curve.nii.gz "
                + self.bspline_meshsize
                + " "
                + self.bspline_numberOfLevels
                + " "
                + self.bspline_order
                + " 0",
                verbose,
            )

            # Concatenate rigid and non-linear transformations...
            sct.printv("\nConcatenate rigid and non-linear transformations...", verbose)
            cmd = (
                "isct_ComposeMultiTransform 3 tmp.straight2curve.nii.gz -R "
                + file_anat
                + ext_anat
                + " -i tmp.curve2straight_rigid.txt tmp.warp_straight2curve.nii.gz"
            )
            sct.printv(cmd, verbose, "code")
            # commands.getstatusoutput(cmd)
            sct.run(cmd, self.verbose)

            # Apply transformation to input image
            sct.printv("\nApply transformation to input image...", verbose)
            Transform(
                input_filename=str(file_anat + ext_anat),
                source_reg="tmp.anat_rigid_warp.nii.gz",
                output_filename="tmp.landmarks_straight_crop.nii.gz",
                interp=interpolation_warp,
                warp="tmp.curve2straight.nii.gz",
                verbose=verbose,
            ).apply()

            # compute the error between the straightened centerline/segmentation and the central vertical line.
            # Ideally, the error should be zero.
            # Apply deformation to input image
            sct.printv("\nApply transformation to centerline image...", verbose)
            # sct.run('sct_apply_transfo -i '+fname_centerline_orient+' -o tmp.centerline_straight.nii.gz -d tmp.landmarks_straight_crop.nii.gz -x nn -w tmp.curve2straight.nii.gz')
            Transform(
                input_filename=fname_centerline_orient,
                source_reg="tmp.centerline_straight.nii.gz",
                output_filename="tmp.landmarks_straight_crop.nii.gz",
                interp="nn",
                warp="tmp.curve2straight.nii.gz",
                verbose=verbose,
            ).apply()
            # c = sct.run('sct_crop_image -i tmp.centerline_straight.nii.gz -o tmp.centerline_straight_crop.nii.gz -dim 2 -bzmax')
            from msct_image import Image

            file_centerline_straight = Image("tmp.centerline_straight.nii.gz", verbose=verbose)
            coordinates_centerline = file_centerline_straight.getNonZeroCoordinates(sorting="z")
            mean_coord = []
            for z in range(coordinates_centerline[0].z, coordinates_centerline[-1].z):
                mean_coord.append(
                    mean(
                        [
                            [coord.x * coord.value, coord.y * coord.value]
                            for coord in coordinates_centerline
                            if coord.z == z
                        ],
                        axis=0,
                    )
                )

            # compute error between the input data and the nurbs
            from math import sqrt

            x0 = file_centerline_straight.data.shape[0] / 2.0
            y0 = file_centerline_straight.data.shape[1] / 2.0
            count_mean = 0
            for coord_z in mean_coord:
                if not isnan(sum(coord_z)):
                    dist = ((x0 - coord_z[0]) * px) ** 2 + ((y0 - coord_z[1]) * py) ** 2
                    self.mse_straightening += dist
                    dist = sqrt(dist)
                    if dist > self.max_distance_straightening:
                        self.max_distance_straightening = dist
                    count_mean += 1
            self.mse_straightening = sqrt(self.mse_straightening / float(count_mean))

        except Exception as e:
            sct.printv("WARNING: Exception during Straightening:", 1, "warning")
            print e

        os.chdir("..")

        # Generate output file (in current folder)
        # TODO: do not uncompress the warping field, it is too time consuming!
        sct.printv("\nGenerate output file (in current folder)...", verbose)
        sct.generate_output_file(
            path_tmp + "/tmp.curve2straight.nii.gz", "warp_curve2straight.nii.gz", verbose
        )  # warping field
        sct.generate_output_file(
            path_tmp + "/tmp.straight2curve.nii.gz", "warp_straight2curve.nii.gz", verbose
        )  # warping field
        if fname_output == "":
            fname_straight = sct.generate_output_file(
                path_tmp + "/tmp.anat_rigid_warp.nii.gz", file_anat + "_straight" + ext_anat, verbose
            )  # straightened anatomic
        else:
            fname_straight = sct.generate_output_file(
                path_tmp + "/tmp.anat_rigid_warp.nii.gz", fname_output, verbose
            )  # straightened anatomic
        # Remove temporary files
        if remove_temp_files:
            sct.printv("\nRemove temporary files...", verbose)
            sct.run("rm -rf " + path_tmp, verbose)

        sct.printv("\nDone!\n", verbose)

        sct.printv("Maximum x-y error = " + str(round(self.max_distance_straightening, 2)) + " mm", verbose, "bold")
        sct.printv(
            "Accuracy of straightening (MSE) = " + str(round(self.mse_straightening, 2)) + " mm", verbose, "bold"
        )
        # display elapsed time
        elapsed_time = time.time() - start_time
        sct.printv("\nFinished! Elapsed time: " + str(int(round(elapsed_time))) + "s", verbose)
        sct.printv("\nTo view results, type:", verbose)
        sct.printv("fslview " + fname_straight + " &\n", verbose, "info")
Ejemplo n.º 17
0
def extract_centerline(fname_segmentation,
                       remove_temp_files,
                       name_output='',
                       verbose=0,
                       algo_fitting='hanning',
                       type_window='hanning',
                       window_length=80):
    """Extract a smoothed centerline from a segmentation volume.

    The segmentation is copied into a timestamped temporary folder and
    reoriented to RPI. Its centerline is smoothed with ``smooth_centerline``
    and written twice:
      * as a volume in which one voxel per slice (the rounded fitted x/y
        position) is set to 1;
      * as a text file with one ``z x y`` line per slice, copied next to the
        temporary folder.
    Finally the volume is reoriented back to the input orientation.

    :param fname_segmentation: path to the segmentation volume.
    :param remove_temp_files: if truthy, the temporary folder is deleted at
        the end.
    :param name_output: file name of the centerline volume. When empty (or
        equal to 'csa_volume.nii.gz') it defaults to
        ``<input name>_centerline<input extension>``.
    :param verbose: verbosity level passed to the sct helpers. When equal to
        2, raw and fitted coordinates are also plotted with matplotlib.
    :param algo_fitting: fitting algorithm forwarded to ``smooth_centerline``.
    :param type_window: window type forwarded to ``smooth_centerline``.
    :param window_length: window length forwarded to ``smooth_centerline``.
    :return: ``name_output``, the centerline volume written (in the input
        orientation) relative to the caller's working directory.
    :raises ValueError: if the segmentation contains no non-zero voxel.
    """
    # Extract path, file and extension
    fname_segmentation = os.path.abspath(fname_segmentation)
    path_data, file_data, ext_data = sct.extract_fname(fname_segmentation)

    # create temporary folder
    path_tmp = 'tmp.' + time.strftime("%y%m%d%H%M%S")
    sct.run('mkdir ' + path_tmp)

    # copy files into tmp folder
    sct.run('cp ' + fname_segmentation + ' ' + path_tmp)

    # go to tmp folder; the finally clause below guarantees we come back even
    # if a processing step fails, so the caller's cwd is never left changed.
    os.chdir(path_tmp)
    try:
        # Change orientation of the input centerline into RPI
        sct.printv('\nOrient centerline to RPI orientation...', verbose)
        fname_segmentation_orient = 'segmentation_rpi' + ext_data
        set_orientation(file_data + ext_data, 'RPI', fname_segmentation_orient)

        # Get dimension
        sct.printv('\nGet dimensions...', verbose)
        nx, ny, nz, nt, px, py, pz, pt = Image(fname_segmentation_orient).dim
        sct.printv(
            '.. matrix size: ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz),
            verbose)
        sct.printv(
            '.. voxel size:  ' + str(px) + 'mm x ' + str(py) + 'mm x ' + str(pz) +
            'mm', verbose)

        # Extract orientation of the input segmentation (used to restore it at the end)
        orientation = get_orientation(file_data + ext_data)
        sct.printv('\nOrientation of segmentation image: ' + orientation, verbose)

        sct.printv('\nOpen segmentation volume...', verbose)
        nii = nibabel.load(fname_segmentation_orient)
        data = nii.get_data()
        hdr = nii.get_header()

        # Extract min and max index in Z direction
        X, Y, Z = (data > 0).nonzero()
        if len(Z) == 0:
            raise ValueError('Segmentation is empty: ' + fname_segmentation)
        min_z_index, max_z_index = min(Z), max(Z)
        z_centerline = list(range(min_z_index, max_z_index + 1))

        # Barycenter of the segmentation in each slice. A slice without any
        # segmented voxel yields NaN (np.mean of an empty array).
        x_centerline = []
        y_centerline = []
        for iz in z_centerline:
            x_seg, y_seg = (data[:, :, iz] > 0).nonzero()
            x_centerline.append(np.mean(x_seg))
            y_centerline.append(np.mean(y_seg))

        # Clear the segmentation: the same array (and header) is reused below
        # to store the centerline image.
        data[X, Y, Z] = 0

        # extract centerline and smooth it
        x_centerline_fit, y_centerline_fit, z_centerline_fit, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
            fname_segmentation_orient,
            type_window=type_window,
            window_length=window_length,
            algo_fitting=algo_fitting,
            verbose=verbose)

        if verbose == 2:
            import matplotlib.pyplot as plt

            # Place the raw barycenters on the same z grid as the fitted curve
            nz_nonz = len(z_centerline)
            x_display = [0 for i in range(x_centerline_fit.shape[0])]
            y_display = [0 for i in range(y_centerline_fit.shape[0])]
            for i in range(0, nz_nonz, 1):
                x_display[int(z_centerline[i] - z_centerline[0])] = x_centerline[i]
                y_display[int(z_centerline[i] - z_centerline[0])] = y_centerline[i]

            plt.figure(1)
            plt.subplot(2, 1, 1)
            plt.plot(z_centerline_fit, x_display, 'ro')
            plt.plot(z_centerline_fit, x_centerline_fit)
            plt.xlabel("Z")
            plt.ylabel("X")
            plt.title("x and x_fit coordinates")

            plt.subplot(2, 1, 2)
            plt.plot(z_centerline_fit, y_display, 'ro')
            plt.plot(z_centerline_fit, y_centerline_fit)
            plt.xlabel("Z")
            plt.ylabel("Y")
            plt.title("y and y_fit coordinates")
            plt.show()

        # Create an image with the centerline. Indices must be integers:
        # round() returns a float under Python 2, which NumPy rejects as an index.
        for iz in z_centerline:
            data[
                int(round(x_centerline_fit[iz - min_z_index])),
                int(round(y_centerline_fit[iz - min_z_index])),
                iz] = 1  # if index is out of bounds here for hanning: either the segmentation has holes or labels have been added to the file
        # Write the centerline image in RPI orientation
        hdr.set_data_dtype('uint8')  # set imagetype to uint8
        sct.printv('\nWrite NIFTI volumes...', verbose)
        img = nibabel.Nifti1Image(data, None, hdr)
        nibabel.save(img, 'centerline.nii.gz')
        # Define name if output name is not specified
        if name_output == 'csa_volume.nii.gz' or name_output == '':
            name_output = file_data + '_centerline' + ext_data
        sct.generate_output_file('centerline.nii.gz', name_output, verbose)

        # create a txt file with the centerline: one "z x y" line per slice
        _, rad_output, _ = sct.extract_fname(name_output)
        name_output_txt = rad_output + '.txt'
        sct.printv('\nWrite text file...', verbose)
        with open(name_output_txt, 'w') as file_results:
            for iz in z_centerline:
                file_results.write(
                    str(int(iz)) + ' ' + str(x_centerline_fit[iz - min_z_index]) + ' ' +
                    str(y_centerline_fit[iz - min_z_index]) + '\n')

        # Copy result into parent folder
        sct.run('cp ' + name_output_txt + ' ../')

        del data
    finally:
        # come back to parent folder
        os.chdir('..')

    # Change orientation of the output centerline into input orientation
    sct.printv(
        '\nOrient centerline image to input orientation: ' + orientation,
        verbose)
    set_orientation(path_tmp + '/' + name_output, orientation, name_output)

    # Remove temporary files
    if remove_temp_files:
        sct.printv('\nRemove temporary files...', verbose)
        sct.run('rm -rf ' + path_tmp, verbose)

    return name_output

def main():

    # get default parameters
    step1 = Paramreg(step='1', type='seg', algo='slicereg', metric='MeanSquares', iter='10')
    step2 = Paramreg(step='2', type='im', algo='syn', metric='MI', iter='3')
    # step1 = Paramreg()
    paramreg = ParamregMultiStep([step1, step2])

    # step1 = Paramreg_step(step='1', type='seg', algo='bsplinesyn', metric='MeanSquares', iter='10', shrink='1', smooth='0', gradStep='0.5')
    # step2 = Paramreg_step(step='2', type='im', algo='syn', metric='MI', iter='10', shrink='1', smooth='0', gradStep='0.5')
    # paramreg = ParamregMultiStep([step1, step2])

    # Initialize the parser
    parser = Parser(__file__)
    parser.usage.set_description('Register anatomical image to the template.')
    parser.add_option(name="-i",
                      type_value="file",
                      description="Anatomical image.",
                      mandatory=True,
                      example="anat.nii.gz")
    parser.add_option(name="-s",
                      type_value="file",
                      description="Spinal cord segmentation.",
                      mandatory=True,
                      example="anat_seg.nii.gz")
    parser.add_option(name="-l",
                      type_value="file",
                      description="Labels. See: http://sourceforge.net/p/spinalcordtoolbox/wiki/create_labels/",
                      mandatory=True,
                      default_value='',
                      example="anat_labels.nii.gz")
    parser.add_option(name="-t",
                      type_value="folder",
                      description="Path to MNI-Poly-AMU template.",
                      mandatory=False,
                      default_value=param.path_template)
    parser.add_option(name="-p",
                      type_value=[[':'], 'str'],
                      description="""Parameters for registration (see sct_register_multimodal). Default:\n--\nstep=1\ntype="""+paramreg.steps['1'].type+"""\nalgo="""+paramreg.steps['1'].algo+"""\nmetric="""+paramreg.steps['1'].metric+"""\npoly="""+paramreg.steps['1'].poly+"""\n--\nstep=2\ntype="""+paramreg.steps['2'].type+"""\nalgo="""+paramreg.steps['2'].algo+"""\nmetric="""+paramreg.steps['2'].metric+"""\niter="""+paramreg.steps['2'].iter+"""\nshrink="""+paramreg.steps['2'].shrink+"""\nsmooth="""+paramreg.steps['2'].smooth+"""\ngradStep="""+paramreg.steps['2'].gradStep+"""\n--""",
                      mandatory=False,
                      example="step=2,type=seg,algo=bsplinesyn,metric=MeanSquares,iter=5,shrink=2:step=3,type=im,algo=syn,metric=MI,iter=5,shrink=1,gradStep=0.3")
    parser.add_option(name="-r",
                      type_value="multiple_choice",
                      description="""Remove temporary files.""",
                      mandatory=False,
                      default_value='1',
                      example=['0', '1'])
    parser.add_option(name="-v",
                      type_value="multiple_choice",
                      description="""Verbose. 0: nothing. 1: basic. 2: extended.""",
                      mandatory=False,
                      default_value=param.verbose,
                      example=['0', '1', '2'])
    if param.debug:
        print '\n*** WARNING: DEBUG MODE ON ***\n'
        fname_data = '/Users/julien/data/temp/sct_example_data/t2/t2.nii.gz'
        fname_landmarks = '/Users/julien/data/temp/sct_example_data/t2/labels.nii.gz'
        fname_seg = '/Users/julien/data/temp/sct_example_data/t2/t2_seg.nii.gz'
        path_template = param.path_template
        remove_temp_files = 0
        verbose = 2
        # speed = 'superfast'
        #param_reg = '2,BSplineSyN,0.6,MeanSquares'
    else:
        arguments = parser.parse(sys.argv[1:])

        # get arguments
        fname_data = arguments['-i']
        fname_seg = arguments['-s']
        fname_landmarks = arguments['-l']
        path_template = arguments['-t']
        remove_temp_files = int(arguments['-r'])
        verbose = int(arguments['-v'])
        if '-p' in arguments:
            paramreg_user = arguments['-p']
            # update registration parameters
            for paramStep in paramreg_user:
                paramreg.addStep(paramStep)

    # initialize other parameters
    file_template = param.file_template
    file_template_label = param.file_template_label
    file_template_seg = param.file_template_seg
    output_type = param.output_type
    zsubsample = param.zsubsample
    # smoothing_sigma = param.smoothing_sigma

    # start timer
    start_time = time.time()

    # get absolute path - TO DO: remove! NEVER USE ABSOLUTE PATH...
    path_template = os.path.abspath(path_template)

    # get fname of the template + template objects
    fname_template = sct.slash_at_the_end(path_template, 1)+file_template
    fname_template_label = sct.slash_at_the_end(path_template, 1)+file_template_label
    fname_template_seg = sct.slash_at_the_end(path_template, 1)+file_template_seg

    # check file existence
    sct.printv('\nCheck template files...')
    sct.check_file_exist(fname_template, verbose)
    sct.check_file_exist(fname_template_label, verbose)
    sct.check_file_exist(fname_template_seg, verbose)

    # print arguments
    sct.printv('\nCheck parameters:', verbose)
    sct.printv('.. Data:                 '+fname_data, verbose)
    sct.printv('.. Landmarks:            '+fname_landmarks, verbose)
    sct.printv('.. Segmentation:         '+fname_seg, verbose)
    sct.printv('.. Path template:        '+path_template, verbose)
    sct.printv('.. Output type:          '+str(output_type), verbose)
    sct.printv('.. Remove temp files:    '+str(remove_temp_files), verbose)

    sct.printv('\nParameters for registration:')
    for pStep in range(1, len(paramreg.steps)+1):
        sct.printv('Step #'+paramreg.steps[str(pStep)].step, verbose)
        sct.printv('.. Type #'+paramreg.steps[str(pStep)].type, verbose)
        sct.printv('.. Algorithm................ '+paramreg.steps[str(pStep)].algo, verbose)
        sct.printv('.. Metric................... '+paramreg.steps[str(pStep)].metric, verbose)
        sct.printv('.. Number of iterations..... '+paramreg.steps[str(pStep)].iter, verbose)
        sct.printv('.. Shrink factor............ '+paramreg.steps[str(pStep)].shrink, verbose)
        sct.printv('.. Smoothing factor......... '+paramreg.steps[str(pStep)].smooth, verbose)
        sct.printv('.. Gradient step............ '+paramreg.steps[str(pStep)].gradStep, verbose)
        sct.printv('.. Degree of polynomial..... '+paramreg.steps[str(pStep)].poly, verbose)

    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    sct.printv('\nCheck input labels...')
    # check if label image contains coherent labels
    image_label = Image(fname_landmarks)
    # -> all labels must be different
    labels = image_label.getNonZeroCoordinates()
    hasDifferentLabels = True
    for lab in labels:
        for otherlabel in labels:
            if lab != otherlabel and lab.hasEqualValue(otherlabel):
                hasDifferentLabels = False
                break
    if not hasDifferentLabels:
        sct.printv('ERROR: Wrong landmarks input. All labels must be different.', verbose, 'error')

    # create temporary folder
    sct.printv('\nCreate temporary folder...', verbose)
    path_tmp = 'tmp.'+time.strftime("%y%m%d%H%M%S")
    status, output = sct.run('mkdir '+path_tmp)

    # copy files to temporary folder
    sct.printv('\nCopy files...', verbose)
    sct.run('isct_c3d '+fname_data+' -o '+path_tmp+'/data.nii')
    sct.run('isct_c3d '+fname_landmarks+' -o '+path_tmp+'/landmarks.nii.gz')
    sct.run('isct_c3d '+fname_seg+' -o '+path_tmp+'/segmentation.nii.gz')
    sct.run('isct_c3d '+fname_template+' -o '+path_tmp+'/template.nii')
    sct.run('isct_c3d '+fname_template_label+' -o '+path_tmp+'/template_labels.nii.gz')
    sct.run('isct_c3d '+fname_template_seg+' -o '+path_tmp+'/template_seg.nii.gz')

    # go to tmp folder
    os.chdir(path_tmp)

    # Change orientation of input images to RPI
    sct.printv('\nChange orientation of input images to RPI...', verbose)
    set_orientation('data.nii', 'RPI', 'data_rpi.nii')
    set_orientation('landmarks.nii.gz', 'RPI', 'landmarks_rpi.nii.gz')
    set_orientation('segmentation.nii.gz', 'RPI', 'segmentation_rpi.nii.gz')

    # crop segmentation
    # output: segmentation_rpi_crop.nii.gz
    sct.run('sct_crop_image -i segmentation_rpi.nii.gz -o segmentation_rpi_crop.nii.gz -dim 2 -bzmax')

    # straighten segmentation
    sct.printv('\nStraighten the spinal cord using centerline/segmentation...', verbose)
    sct.run('sct_straighten_spinalcord -i segmentation_rpi_crop.nii.gz -c segmentation_rpi_crop.nii.gz -r 0')

    # Label preparation:
    # --------------------------------------------------------------------------------
    # Remove unused label on template. Keep only label present in the input label image
    sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
    sct.run('sct_label_utils -t remove -i template_labels.nii.gz -o template_label.nii.gz -r landmarks_rpi.nii.gz')

    # Make sure landmarks are INT
    sct.printv('\nConvert landmarks to INT...', verbose)
    sct.run('isct_c3d template_label.nii.gz -type int -o template_label.nii.gz', verbose)

    # Create a cross for the template labels - 5 mm
    sct.printv('\nCreate a 5 mm cross for the template labels...', verbose)
    sct.run('sct_label_utils -t cross -i template_label.nii.gz -o template_label_cross.nii.gz -c 5')

    # Create a cross for the input labels and dilate for straightening preparation - 5 mm
    sct.printv('\nCreate a 5mm cross for the input labels and dilate for straightening preparation...', verbose)
    sct.run('sct_label_utils -t cross -i landmarks_rpi.nii.gz -o landmarks_rpi_cross3x3.nii.gz -c 5 -d')

    # Apply straightening to labels
    sct.printv('\nApply straightening to labels...', verbose)
    sct.run('sct_apply_transfo -i landmarks_rpi_cross3x3.nii.gz -o landmarks_rpi_cross3x3_straight.nii.gz -d segmentation_rpi_crop_straight.nii.gz -w warp_curve2straight.nii.gz -x nn')

    # Convert landmarks from FLOAT32 to INT
    sct.printv('\nConvert landmarks from FLOAT32 to INT...', verbose)
    sct.run('isct_c3d landmarks_rpi_cross3x3_straight.nii.gz -type int -o landmarks_rpi_cross3x3_straight.nii.gz')

    # Estimate affine transfo: straight --> template (landmark-based)'
    sct.printv('\nEstimate affine transfo: straight anat --> template (landmark-based)...', verbose)
    sct.run('isct_ANTSUseLandmarkImagesToGetAffineTransform template_label_cross.nii.gz landmarks_rpi_cross3x3_straight.nii.gz affine straight2templateAffine.txt')

    # Apply affine transformation: straight --> template
    sct.printv('\nApply affine transformation: straight --> template...', verbose)
    sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt -d template.nii -o warp_curve2straightAffine.nii.gz')
    sct.run('sct_apply_transfo -i data_rpi.nii -o data_rpi_straight2templateAffine.nii -d template.nii -w warp_curve2straightAffine.nii.gz')
    sct.run('sct_apply_transfo -i segmentation_rpi.nii.gz -o segmentation_rpi_straight2templateAffine.nii.gz -d template.nii -w warp_curve2straightAffine.nii.gz -x linear')

    # find min-max of anat2template (for subsequent cropping)
    sct.run('export FSLOUTPUTTYPE=NIFTI; fslmaths segmentation_rpi_straight2templateAffine.nii.gz -thr 0.5 segmentation_rpi_straight2templateAffine_th.nii.gz', param.verbose)
    zmin_template, zmax_template = find_zmin_zmax('segmentation_rpi_straight2templateAffine_th.nii.gz')

    # crop template in z-direction (for faster processing)
    sct.printv('\nCrop data in template space (for faster processing)...', verbose)
    sct.run('sct_crop_image -i template.nii -o template_crop.nii -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    sct.run('sct_crop_image -i template_seg.nii.gz -o template_seg_crop.nii.gz -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    sct.run('sct_crop_image -i data_rpi_straight2templateAffine.nii -o data_rpi_straight2templateAffine_crop.nii -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    sct.run('sct_crop_image -i segmentation_rpi_straight2templateAffine.nii.gz -o segmentation_rpi_straight2templateAffine_crop.nii.gz -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    # sub-sample in z-direction
    sct.printv('\nSub-sample in z-direction (for faster processing)...', verbose)
    sct.run('sct_resample -i template_crop.nii -o template_crop_r.nii -f 1x1x'+zsubsample, verbose)
    sct.run('sct_resample -i template_seg_crop.nii.gz -o template_seg_crop_r.nii.gz -f 1x1x'+zsubsample, verbose)
    sct.run('sct_resample -i data_rpi_straight2templateAffine_crop.nii -o data_rpi_straight2templateAffine_crop_r.nii -f 1x1x'+zsubsample, verbose)
    sct.run('sct_resample -i segmentation_rpi_straight2templateAffine_crop.nii.gz -o segmentation_rpi_straight2templateAffine_crop_r.nii.gz -f 1x1x'+zsubsample, verbose)

    # Registration straight spinal cord to template
    sct.printv('\nRegister straight spinal cord to template...', verbose)

    # loop across registration steps
    warp_forward = []
    warp_inverse = []
    for i_step in range(1, len(paramreg.steps)+1):
        sct.printv('\nEstimate transformation for step #'+str(i_step)+'...', verbose)
        # identify which is the src and dest
        if paramreg.steps[str(i_step)].type == 'im':
            src = 'data_rpi_straight2templateAffine_crop_r.nii'
            dest = 'template_crop_r.nii'
            interp_step = 'linear'
        elif paramreg.steps[str(i_step)].type == 'seg':
            src = 'segmentation_rpi_straight2templateAffine_crop_r.nii.gz'
            dest = 'template_seg_crop_r.nii.gz'
            interp_step = 'nn'
        else:
            sct.run('ERROR: Wrong image type.', 1, 'error')
        # if step>1, apply warp_forward_concat to the src image to be used
        if i_step > 1:
            # sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+sct.add_suffix(src, '_reg')+' -x '+interp_step, verbose)
            sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+sct.add_suffix(src, '_reg')+' -x '+interp_step, verbose)
            src = sct.add_suffix(src, '_reg')
        # register src --> dest
        warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
        warp_forward.append(warp_forward_out)
        warp_inverse.append(warp_inverse_out)

    # Concatenate transformations:
    sct.printv('\nConcatenate transformations: anat --> template...', verbose)
    sct.run('sct_concat_transfo -w warp_curve2straightAffine.nii.gz,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
    # sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
    warp_inverse.reverse()
    sct.run('sct_concat_transfo -w '+','.join(warp_inverse)+',-straight2templateAffine.txt,warp_straight2curve.nii.gz -d data.nii -o warp_template2anat.nii.gz', verbose)

    # Apply warping fields to anat and template
    if output_type == 1:
        sct.run('sct_apply_transfo -i template.nii -o template2anat.nii.gz -d data.nii -w warp_template2anat.nii.gz -c 1', verbose)
        sct.run('sct_apply_transfo -i data.nii -o anat2template.nii.gz -d template.nii -w warp_anat2template.nii.gz -c 1', verbose)

    # come back to parent folder
    os.chdir('..')

   # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(path_tmp+'/warp_template2anat.nii.gz', 'warp_template2anat.nii.gz', verbose)
    sct.generate_output_file(path_tmp+'/warp_anat2template.nii.gz', 'warp_anat2template.nii.gz', verbose)
    if output_type == 1:
        sct.generate_output_file(path_tmp+'/template2anat.nii.gz', 'template2anat'+ext_data, verbose)
        sct.generate_output_file(path_tmp+'/anat2template.nii.gz', 'anat2template'+ext_data, verbose)

    # Delete temporary files
    if remove_temp_files:
        sct.printv('\nDelete temporary files...', verbose)
        sct.run('rm -rf '+path_tmp)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: '+str(int(round(elapsed_time)))+'s', verbose)

    # to view results
    sct.printv('\nTo view results, type:', verbose)
    sct.printv('fslview '+fname_data+' template2anat -b 0,4000 &', verbose, 'info')
    sct.printv('fslview '+fname_template+' -b 0,5000 anat2template &\n', verbose, 'info')
def extract_centerline(fname_segmentation, remove_temp_files, name_output='', verbose = 0, algo_fitting = 'hanning', type_window = 'hanning', window_length = 80):
    """Extract a smoothed centerline from a spinal cord segmentation.

    The segmentation is copied into a timestamped temporary folder, reoriented
    to RPI and smoothed with ``smooth_centerline``. The smoothed centerline is
    written as a binary image (reoriented back to the input orientation) and as
    a text file containing one ``z x y`` line per slice.

    :param fname_segmentation: path to the segmentation image.
    :param remove_temp_files: if truthy, the temporary folder is deleted at the end.
    :param name_output: output image file name. When empty or equal to
        'csa_volume.nii.gz', it defaults to '<input>_centerline<ext>'.
    :param verbose: verbosity level; when 2, the raw and fitted curves are plotted.
    :param algo_fitting: fitting algorithm forwarded to ``smooth_centerline``.
    :param type_window: smoothing window type forwarded to ``smooth_centerline``.
    :param window_length: smoothing window length forwarded to ``smooth_centerline``.
    :return: the output image file name.
    """
    # Extract path, file and extension
    fname_segmentation = os.path.abspath(fname_segmentation)
    path_data, file_data, ext_data = sct.extract_fname(fname_segmentation)

    # create temporary folder
    path_tmp = 'tmp.'+time.strftime("%y%m%d%H%M%S")
    sct.run('mkdir '+path_tmp)

    # copy files into tmp folder
    sct.run('cp '+fname_segmentation+' '+path_tmp)

    # go to tmp folder
    os.chdir(path_tmp)

    # Change orientation of the input centerline into RPI
    sct.printv('\nOrient centerline to RPI orientation...', verbose)
    fname_segmentation_orient = 'segmentation_rpi' + ext_data
    set_orientation(file_data+ext_data, 'RPI', fname_segmentation_orient)

    # Get dimension
    sct.printv('\nGet dimensions...', verbose)
    nx, ny, nz, nt, px, py, pz, pt = Image(fname_segmentation_orient).dim
    sct.printv('.. matrix size: '+str(nx)+' x '+str(ny)+' x '+str(nz), verbose)
    sct.printv('.. voxel size:  '+str(px)+'mm x '+str(py)+'mm x '+str(pz)+'mm', verbose)

    # Extract orientation of the input segmentation (restored on the output image)
    orientation = get_orientation(file_data+ext_data)
    sct.printv('\nOrientation of segmentation image: ' + orientation, verbose)

    sct.printv('\nOpen segmentation volume...', verbose)
    nii = nibabel.load(fname_segmentation_orient)
    data = nii.get_data()
    hdr = nii.get_header()

    # Extract min and max index in Z direction
    X, Y, Z = (data > 0).nonzero()
    if len(Z) == 0:
        sct.printv('ERROR: Segmentation image is empty: '+fname_segmentation, 1, 'error')
    min_z_index, max_z_index = min(Z), max(Z)
    nb_slices = max_z_index - min_z_index + 1
    x_centerline = [0] * nb_slices
    y_centerline = [0] * nb_slices
    z_centerline = list(range(min_z_index, max_z_index+1))
    # Average the segmented voxels of each slice (raw centerline, used for display)
    for iz in range(min_z_index, max_z_index+1):
        x_seg, y_seg = (data[:, :, iz] > 0).nonzero()
        x_centerline[iz-min_z_index] = np.mean(x_seg)
        y_centerline[iz-min_z_index] = np.mean(y_seg)
    # Clear the segmentation: the same array is reused below to hold the centerline mask
    data[X, Y, Z] = 0

    # extract centerline and smooth it
    x_centerline_fit, y_centerline_fit, z_centerline_fit, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(fname_segmentation_orient, type_window = type_window, window_length = window_length, algo_fitting = algo_fitting, verbose = verbose)

    if verbose == 2:
        import matplotlib.pyplot as plt

        # Place each raw centerline point at its z offset so it lines up with the fit
        nz_nonz = len(z_centerline)
        x_display = [0 for i in range(x_centerline_fit.shape[0])]
        y_display = [0 for i in range(y_centerline_fit.shape[0])]
        for i in range(0, nz_nonz, 1):
            x_display[int(z_centerline[i]-z_centerline[0])] = x_centerline[i]
            y_display[int(z_centerline[i]-z_centerline[0])] = y_centerline[i]

        plt.figure(1)
        plt.subplot(2,1,1)
        plt.plot(z_centerline_fit, x_display, 'ro')
        plt.plot(z_centerline_fit, x_centerline_fit)
        plt.xlabel("Z")
        plt.ylabel("X")
        plt.title("x and x_fit coordinates")

        plt.subplot(2,1,2)
        plt.plot(z_centerline_fit, y_display, 'ro')
        plt.plot(z_centerline_fit, y_centerline_fit)
        plt.xlabel("Z")
        plt.ylabel("Y")
        plt.title("y and y_fit coordinates")
        plt.show()

    # Create an image with the centerline
    for iz in range(min_z_index, max_z_index+1):
        # int() is required: round() returns a float in Python 2 and NumPy rejects float indices.
        # If the index is out of bounds here for hanning: either the segmentation has holes
        # or labels have been added to the file.
        x_voxel = int(round(x_centerline_fit[iz-min_z_index]))
        y_voxel = int(round(y_centerline_fit[iz-min_z_index]))
        data[x_voxel, y_voxel, iz] = 1
    # Write the centerline image in RPI orientation
    hdr.set_data_dtype('uint8') # set imagetype to uint8
    sct.printv('\nWrite NIFTI volumes...', verbose)
    img = nibabel.Nifti1Image(data, None, hdr)
    nibabel.save(img, 'centerline.nii.gz')
    # Define name if output name is not specified
    if name_output=='csa_volume.nii.gz' or name_output=='':
        name_output = file_data+'_centerline'+ext_data
    sct.generate_output_file('centerline.nii.gz', name_output, verbose)

    # create a txt file with the centerline ("z x y" per slice)
    path, rad_output, ext = sct.extract_fname(name_output)
    name_output_txt = rad_output + '.txt'
    sct.printv('\nWrite text file...', verbose)
    with open(name_output_txt, 'w') as file_results:
        for i in range(min_z_index, max_z_index+1):
            file_results.write(str(int(i)) + ' ' + str(x_centerline_fit[i-min_z_index]) + ' ' + str(y_centerline_fit[i-min_z_index]) + '\n')

    # Copy result into parent folder
    sct.run('cp '+name_output_txt+' ../')

    del data

    # come back to parent folder
    os.chdir('..')

    # Change orientation of the output centerline into input orientation
    sct.printv('\nOrient centerline image to input orientation: ' + orientation, verbose)
    set_orientation(path_tmp+'/'+name_output, orientation, name_output)

    # Remove temporary files
    if remove_temp_files:
        sct.printv('\nRemove temporary files...', verbose)
        sct.run('rm -rf '+path_tmp, verbose)

    return name_output
def main():
    """Smooth an anatomical volume along the spinal cord centerline.

    The volume is straightened using the centerline, smoothed along z with a
    Gaussian kernel (isct_c3d), then warped back to the curved space. The
    result is written as '<anat>_smooth<ext>' in the current folder.

    Command-line options (read from sys.argv):
      -i  volume to smooth (mandatory)
      -c  centerline image (mandatory)
      -r  remove temporary files (1 = yes)
      -s  standard deviation of the Gaussian smoothing along z, in voxels
      -v  verbosity level
    """
    # Initialization
    fname_anat = ''
    fname_centerline = ''
    sigma = 3 # default value of the standard deviation for the Gaussian smoothing (in terms of number of voxels)
    remove_temp_files = param.remove_temp_files
    verbose = param.verbose
    start_time = time.time()

    # Check input param
    try:
        opts, args = getopt.getopt(sys.argv[1:], 'hi:c:r:s:v:')
    except getopt.GetoptError as err:
        print(str(err))
        usage()
    if not opts:
        usage()
    for opt, arg in opts:
        if opt == '-h':
            usage()
        elif opt == '-c':
            fname_centerline = arg
        elif opt == '-i':
            fname_anat = arg
        elif opt == '-r':
            # int(): the flag is compared with 1 below; a str would never match
            remove_temp_files = int(arg)
        elif opt == '-s':
            sigma = arg
        elif opt == '-v':
            verbose = int(arg)

    # Display usage if a mandatory argument is not provided
    if fname_anat == '' or fname_centerline == '':
        usage()

    # Display arguments
    print('\nCheck input arguments...')
    print('  Volume to smooth .................. ' + fname_anat)
    print('  Centerline ........................ ' + fname_centerline)
    print('  FWHM .............................. '+str(sigma))
    print('  Verbose ........................... '+str(verbose))

    # Check existence of input files
    print('\nCheck existence of input files...')
    sct.check_file_exist(fname_anat)
    sct.check_file_exist(fname_centerline)

    # Extract path/file/extension (used to name the output file)
    path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)

    # create temporary folder
    print('\nCreate temporary folder...')
    path_tmp = 'tmp.'+time.strftime("%y%m%d%H%M%S")
    sct.run('mkdir '+path_tmp)

    # copy files to temporary folder
    print('\nCopy files...')
    sct.run('isct_c3d '+fname_anat+' -o '+path_tmp+'/anat.nii')
    sct.run('isct_c3d '+fname_centerline+' -o '+path_tmp+'/centerline.nii')

    # go to tmp folder
    os.chdir(path_tmp)

    # Change orientation of the input image into RPI
    print('\nOrient input volume to RPI orientation...')
    set_orientation('anat.nii', 'RPI', 'anat_rpi.nii')
    # Change orientation of the centerline into RPI
    print('\nOrient centerline to RPI orientation...')
    set_orientation('centerline.nii', 'RPI', 'centerline_rpi.nii')

    # Straighten the spinal cord
    print('\nStraighten the spinal cord...')
    sct.run('sct_straighten_spinalcord -i anat_rpi.nii -c centerline_rpi.nii -x spline -v '+str(verbose))

    # Smooth the straightened image along z only
    print('\nSmooth the straightened image along z...')
    sct.run('isct_c3d anat_rpi_straight.nii -smooth 0x0x'+str(sigma)+'vox -o anat_rpi_straight_smooth.nii', verbose)

    # Apply the reversed warping field to get back the curved spinal cord
    print('\nApply the reversed warping field to get back the curved spinal cord...')
    sct.run('sct_apply_transfo -i anat_rpi_straight_smooth.nii -o anat_rpi_straight_smooth_curved.nii -d anat.nii -w warp_straight2curve.nii.gz -x spline', verbose)

    # come back to parent folder
    os.chdir('..')

    # Generate output file
    print('\nGenerate output file...')
    sct.generate_output_file(path_tmp+'/anat_rpi_straight_smooth_curved.nii', file_anat+'_smooth'+ext_anat)

    # Remove temporary files
    if remove_temp_files == 1:
        print('\nRemove temporary files...')
        sct.run('rm -rf '+path_tmp)

    # Display elapsed time
    elapsed_time = time.time() - start_time
    print('\nFinished! Elapsed time: '+str(int(round(elapsed_time)))+'s\n')

    # to view results
    sct.printv('Done! To view results, type:', verbose)
    sct.printv('fslview '+file_anat+' '+file_anat+'_smooth &\n', verbose, 'info')
def main():
    """Estimate the spinal cord centerline from a single seed point.

    Starting from a binary point placed in the cord, axial slices that are
    ``slice_gap`` apart are registered slice-to-slice with FLIRT. The
    registration goes up, then down in z, and is weighted by a Gaussian mask
    around the point. The translations of the inverse cumulative transforms
    give a raw centerline. This is fitted with a polynomial of degree
    ``param.deg_poly`` in the Z-X and Z-Y planes. The fitted centerline is
    written as '<anat>_centerline<ext>' in the current folder.

    Command-line options (ignored when ``param.debug == 1``):
      -i  anatomical image (mandatory)
      -p  binary image containing the seed point (mandatory)
      -g  slice gap
      -r  remove temporary files (1 = yes)
      -k  Gaussian kernel passed to ``fslmaths -s`` to build the mask
    """
    # Initialization
    fname_anat = ''
    fname_point = ''
    slice_gap = param.gap
    remove_tmp_files = param.remove_tmp_files
    gaussian_kernel = param.gaussian_kernel
    start_time = time.time()

    # get path of the toolbox
    status, path_sct = commands.getstatusoutput('echo $SCT_DIR')
    path_sct = sct.slash_at_the_end(path_sct, 1)

    # Parameters for debug mode
    if param.debug == 1:
        # printv(message, verbose, type): pass verbose explicitly so that
        # 'warning' is used as the message type, not as the verbose level
        sct.printv('\n*** WARNING: DEBUG MODE ON ***\n\t\t\tCurrent working directory: '+os.getcwd(), 1, 'warning')
        status, path_sct_testing_data = commands.getstatusoutput('echo $SCT_TESTING_DATA_DIR')
        fname_anat = path_sct_testing_data+'/t2/t2.nii.gz'
        fname_point = path_sct_testing_data+'/t2/t2_centerline_init.nii.gz'
        slice_gap = 5

    else:
        # Check input param
        try:
            opts, args = getopt.getopt(sys.argv[1:],'hi:p:g:r:k:')
        except getopt.GetoptError as err:
            print(str(err))
            usage()
        if not opts:
            usage()
        for opt, arg in opts:
            if opt == '-h':
                usage()
            elif opt == '-i':
                fname_anat = arg
            elif opt == '-p':
                fname_point = arg
            elif opt == '-g':
                slice_gap = int(arg)
            elif opt == '-r':
                remove_tmp_files = int(arg)
            elif opt == '-k':
                gaussian_kernel = int(arg)

    # display usage if a mandatory argument is not provided
    if fname_anat == '' or fname_point == '':
        usage()

    # check existence of input files
    sct.check_file_exist(fname_anat)
    sct.check_file_exist(fname_point)

    # extract path/file/extension
    path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)
    path_point, file_point, ext_point = sct.extract_fname(fname_point)

    # extract path of schedule file
    # TODO: include schedule file in sct
    # TODO: check existence of schedule file
    file_schedule = path_sct + param.schedule_file

    # Get input image orientation (restored on the outputs at the end)
    input_image_orientation = get_orientation(fname_anat)

    # Display arguments
    print('\nCheck input arguments...')
    print('  Anatomical image:     '+fname_anat)
    print('  Orientation:          '+input_image_orientation)
    print('  Point in spinal cord: '+fname_point)
    print('  Slice gap:            '+str(slice_gap))
    print('  Gaussian kernel:      '+str(gaussian_kernel))
    print('  Degree of polynomial: '+str(param.deg_poly))

    # create temporary folder
    print('\nCreate temporary folder...')
    path_tmp = 'tmp.'+time.strftime("%y%m%d%H%M%S")
    sct.create_folder(path_tmp)
    print('\nCopy input data...')
    sct.run('cp '+fname_anat+ ' '+path_tmp+'/tmp.anat'+ext_anat)
    sct.run('cp '+fname_point+ ' '+path_tmp+'/tmp.point'+ext_point)

    # go to temporary folder
    os.chdir(path_tmp)

    # convert to nii
    sct.run('fslchfiletype NIFTI tmp.anat')
    sct.run('fslchfiletype NIFTI tmp.point')

    # Reorient the anatomical volume and the point with the same function and
    # the same target orientation, so that both are reoriented identically
    print('\nReorient input volume to RL PA IS orientation...')
    set_orientation('tmp.anat.nii', 'RPI', 'tmp.anat_orient.nii')
    print('\nReorient binary point into RL PA IS orientation...')
    set_orientation('tmp.point.nii', 'RPI', 'tmp.point_orient.nii')

    # Get image dimensions
    print('\nGet image dimensions...')
    nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension('tmp.anat_orient')
    print('.. matrix size: '+str(nx)+' x '+str(ny)+' x '+str(nz))
    print('.. voxel size:  '+str(px)+'mm x '+str(py)+'mm x '+str(pz)+'mm')

    # Split input volume into one file per axial slice
    print('\nSplit input volume...')
    sct.run(sct.fsloutput + 'fslsplit tmp.anat_orient tmp.anat_orient_z -z')
    file_anat_split = ['tmp.anat_orient_z'+str(z).zfill(4) for z in range(0,nz,1)]

    # Get the coordinates of the input point (first non-zero voxel)
    print('\nGet the coordinates of the input point...')
    nii_point = nibabel.load('tmp.point_orient.nii')
    data = nii_point.get_data()
    x_init, y_init, z_init = (data > 0).nonzero()
    if len(x_init) == 0:
        sct.printv('ERROR: No non-zero voxel found in the point image: '+fname_point, 1, 'error')
    x_init = x_init[0]
    y_init = y_init[0]
    z_init = z_init[0]
    print('('+str(x_init)+', '+str(y_init)+', '+str(z_init)+')')

    # Extract the slice corresponding to z=z_init
    print('\nExtract the slice corresponding to z='+str(z_init)+'...')
    file_point_split = ['tmp.point_orient_z'+str(z).zfill(4) for z in range(0,nz,1)]
    sct.run(sct.fsloutput+'fslroi tmp.point_orient '+file_point_split[z_init]+' 0 -1 0 -1 '+str(z_init)+' 1')

    # Create gaussian mask from point
    print('\nCreate gaussian mask from point...')
    file_mask_split = ['tmp.mask_orient_z'+str(z).zfill(4) for z in range(0,nz,1)]
    sct.run(sct.fsloutput+'fslmaths '+file_point_split[z_init]+' -s '+str(gaussian_kernel)+' '+file_mask_split[z_init])

    # Obtain max value from mask
    print('\nFind maximum value from mask...')
    nii_mask = nibabel.load(file_mask_split[z_init]+'.nii')
    data = nii_mask.get_data()
    max_value_mask = numpy.max(data)
    print('..'+str(max_value_mask))

    # Normalize mask beween 0 and 1
    print('\nNormalize mask beween 0 and 1...')
    sct.run(sct.fsloutput+'fslmaths '+file_mask_split[z_init]+' -div '+str(max_value_mask)+' '+file_mask_split[z_init])

    # initialize variables
    file_mat = ['tmp.mat_z'+str(z).zfill(4) for z in range(0,nz,1)]
    file_mat_inv = ['tmp.mat_inv_z'+str(z).zfill(4) for z in range(0,nz,1)]
    file_mat_inv_cumul = ['tmp.mat_inv_cumul_z'+str(z).zfill(4) for z in range(0,nz,1)]

    # create identity matrix for initial transformation matrix
    with open(file_mat_inv_cumul[z_init], 'w') as fid:
        fid.write('%i %i %i %i\n' %(1, 0, 0, 0) )
        fid.write('%i %i %i %i\n' %(0, 1, 0, 0) )
        fid.write('%i %i %i %i\n' %(0, 0, 1, 0) )
        fid.write('%i %i %i %i\n' %(0, 0, 0, 1) )

    # initialize centerline: give value corresponding to initial point
    x_centerline = [x_init]
    y_centerline = [y_init]
    z_centerline = [z_init]
    warning_count = 0

    # go up (1), then down (2) in reference to the binary point
    for iUpDown in range(1, 3):

        if iUpDown == 1:
            # z increases
            slice_gap_signed = slice_gap
        elif iUpDown == 2:
            # z decreases
            slice_gap_signed = -slice_gap
            # reverse centerline (because values will be appended at the end)
            x_centerline.reverse()
            y_centerline.reverse()
            z_centerline.reverse()

        # initialization before looping
        z_dest = z_init # point given by user
        z_src = z_dest + slice_gap_signed

        # continue looping while 0 <= z_src <= nz-1
        while 0 <= z_src and z_src <= nz-1:

            # print current z:
            print('z='+str(z_src)+':')

            # estimate transformation
            sct.run(sct.fsloutput+'flirt -in '+file_anat_split[z_src]+' -ref '+file_anat_split[z_dest]+' -schedule '+file_schedule+ ' -verbose 0 -omat '+file_mat[z_src]+' -cost normcorr -forcescaling -inweight '+file_mask_split[z_dest]+' -refweight '+file_mask_split[z_dest])

            # display transfo
            status, output = sct.run('cat '+file_mat[z_src])
            print(output)

            # check if transformation is bigger than 1.5x slice_gap
            tx = float(output.split()[3])
            ty = float(output.split()[7])
            norm_txy = numpy.linalg.norm([tx, ty],ord=2)
            if norm_txy > 1.5*slice_gap:
                print('WARNING: Transformation is too large --> using previous one.')
                warning_count = warning_count + 1
                # if previous transformation exists, replace current one with previous one
                if os.path.isfile(file_mat[z_dest]):
                    sct.run('cp '+file_mat[z_dest]+' '+file_mat[z_src])

            # estimate inverse transformation matrix
            sct.run('convert_xfm -omat '+file_mat_inv[z_src]+' -inverse '+file_mat[z_src])

            # compute cumulative transformation
            sct.run('convert_xfm -omat '+file_mat_inv_cumul[z_src]+' -concat '+file_mat_inv[z_src]+' '+file_mat_inv_cumul[z_dest])

            # apply inverse cumulative transformation to initial gaussian mask (to put it in src space)
            sct.run(sct.fsloutput+'flirt -in '+file_mask_split[z_init]+' -ref '+file_mask_split[z_init]+' -applyxfm -init '+file_mat_inv_cumul[z_src]+' -out '+file_mask_split[z_src])

            # open inverse cumulative transformation file and generate centerline
            # (elements 3 and 7 of the 4x4 matrix are the x and y translations)
            with open(file_mat_inv_cumul[z_src]) as fid:
                mat = fid.read().split()
            x_centerline.append(x_init + float(mat[3]))
            y_centerline.append(y_init + float(mat[7]))
            z_centerline.append(z_src)

            # define new z_dest (target slice) and new z_src (moving slice)
            z_dest = z_dest + slice_gap_signed
            z_src = z_src + slice_gap_signed


    # Reconstruct centerline
    # ====================================================================================================

    # reverse back centerline (because it's been reversed once, so now all values are in the right order)
    x_centerline.reverse()
    y_centerline.reverse()
    z_centerline.reverse()

    # fit centerline in the Z-X plane using polynomial function
    print('\nFit centerline in the Z-X plane using polynomial function...')
    coeffsx = numpy.polyfit(z_centerline, x_centerline, deg=param.deg_poly)
    polyx = numpy.poly1d(coeffsx)
    x_centerline_fit = numpy.polyval(polyx, z_centerline)
    # calculate RMSE
    rmse = numpy.linalg.norm(x_centerline_fit-x_centerline)/numpy.sqrt( len(x_centerline) )
    # calculate max absolute error
    max_abs = numpy.max( numpy.abs(x_centerline_fit-x_centerline) )
    print('.. RMSE (in mm): '+str(rmse*px))
    print('.. Maximum absolute error (in mm): '+str(max_abs*px))

    # fit centerline in the Z-Y plane using polynomial function
    print('\nFit centerline in the Z-Y plane using polynomial function...')
    coeffsy = numpy.polyfit(z_centerline, y_centerline, deg=param.deg_poly)
    polyy = numpy.poly1d(coeffsy)
    y_centerline_fit = numpy.polyval(polyy, z_centerline)
    # calculate RMSE
    rmse = numpy.linalg.norm(y_centerline_fit-y_centerline)/numpy.sqrt( len(y_centerline) )
    # calculate max absolute error
    max_abs = numpy.max( numpy.abs(y_centerline_fit-y_centerline) )
    print('.. RMSE (in mm): '+str(rmse*py))
    print('.. Maximum absolute error (in mm): '+str(max_abs*py))

    # display
    if param.debug == 1:
        import matplotlib.pyplot as plt
        plt.figure()
        plt.plot(z_centerline,x_centerline,'.',z_centerline,x_centerline_fit,'r')
        plt.legend(['Data','Polynomial Fit'])
        plt.title('Z-X plane polynomial interpolation')
        plt.show()

        plt.figure()
        plt.plot(z_centerline,y_centerline,'.',z_centerline,y_centerline_fit,'r')
        plt.legend(['Data','Polynomial Fit'])
        plt.title('Z-Y plane polynomial interpolation')
        plt.show()

    # generate full range z-values for centerline
    z_centerline_full = [iz for iz in range(0, nz, 1)]

    # calculate X and Y values for the full centerline
    x_centerline_fit_full = numpy.polyval(polyx, z_centerline_full)
    y_centerline_fit_full = numpy.polyval(polyy, z_centerline_full)

    # Generate fitted transformation matrices and write centerline coordinates in text file
    print('\nGenerate fitted transformation matrices and write centerline coordinates in text file...')
    file_mat_inv_cumul_fit = ['tmp.mat_inv_cumul_fit_z'+str(z).zfill(4) for z in range(0,nz,1)]
    file_mat_cumul_fit = ['tmp.mat_cumul_fit_z'+str(z).zfill(4) for z in range(0,nz,1)]
    with open('tmp.centerline_coordinates.txt', 'w') as fid_centerline:
        for iz in range(0, nz, 1):
            # compute inverse cumulative fitted transformation matrix (pure in-plane translation)
            with open(file_mat_inv_cumul_fit[iz], 'w') as fid:
                fid.write('%i %i %i %f\n' %(1, 0, 0, x_centerline_fit_full[iz]-x_init) )
                fid.write('%i %i %i %f\n' %(0, 1, 0, y_centerline_fit_full[iz]-y_init) )
                fid.write('%i %i %i %i\n' %(0, 0, 1, 0) )
                fid.write('%i %i %i %i\n' %(0, 0, 0, 1) )
            # compute forward cumulative fitted transformation matrix
            sct.run('convert_xfm -omat '+file_mat_cumul_fit[iz]+' -inverse '+file_mat_inv_cumul_fit[iz])
            # write centerline coordinates in x, y, z format
            fid_centerline.write('%f %f %f\n' %(x_centerline_fit_full[iz], y_centerline_fit_full[iz], z_centerline_full[iz]) )


    # Prepare output data
    # ====================================================================================================

    # write polynomial coefficients
    numpy.savetxt('tmp.centerline_polycoeffs_x.txt',coeffsx)
    numpy.savetxt('tmp.centerline_polycoeffs_y.txt',coeffsy)

    # apply transformations to data
    print('\nApply fitted transformation matrices...')
    file_anat_split_fit = ['tmp.anat_orient_fit_z'+str(z).zfill(4) for z in range(0,nz,1)]
    file_mask_split_fit = ['tmp.mask_orient_fit_z'+str(z).zfill(4) for z in range(0,nz,1)]
    file_point_split_fit = ['tmp.point_orient_fit_z'+str(z).zfill(4) for z in range(0,nz,1)]
    for iz in range(0, nz, 1):
        # forward cumulative transformation to data
        sct.run(sct.fsloutput+'flirt -in '+file_anat_split[iz]+' -ref '+file_anat_split[iz]+' -applyxfm -init '+file_mat_cumul_fit[iz]+' -out '+file_anat_split_fit[iz])
        # inverse cumulative transformation to mask
        sct.run(sct.fsloutput+'flirt -in '+file_mask_split[z_init]+' -ref '+file_mask_split[z_init]+' -applyxfm -init '+file_mat_inv_cumul_fit[iz]+' -out '+file_mask_split_fit[iz])
        # inverse cumulative transformation to point
        sct.run(sct.fsloutput+'flirt -in '+file_point_split[z_init]+' -ref '+file_point_split[z_init]+' -applyxfm -init '+file_mat_inv_cumul_fit[iz]+' -out '+file_point_split_fit[iz]+' -interp nearestneighbour')

    # Merge the per-slice results back into volumes along z
    print('\nMerge into 4D volume...')
    sct.run(sct.fsloutput+'fslmerge -z tmp.anat_orient_fit tmp.anat_orient_fit_z*')
    sct.run(sct.fsloutput+'fslmerge -z tmp.mask_orient_fit tmp.mask_orient_fit_z*')
    sct.run(sct.fsloutput+'fslmerge -z tmp.point_orient_fit tmp.point_orient_fit_z*')

    # Copy header geometry from input data
    print('\nCopy header geometry from input data...')
    sct.run(sct.fsloutput+'fslcpgeom tmp.anat_orient.nii tmp.anat_orient_fit.nii ')
    sct.run(sct.fsloutput+'fslcpgeom tmp.anat_orient.nii tmp.mask_orient_fit.nii ')
    sct.run(sct.fsloutput+'fslcpgeom tmp.anat_orient.nii tmp.point_orient_fit.nii ')

    # Reorient outputs into the initial orientation of the input image
    print('\nReorient the centerline into the initial orientation of the input image...')
    set_orientation('tmp.point_orient_fit.nii', input_image_orientation, 'tmp.point_orient_fit.nii')
    set_orientation('tmp.mask_orient_fit.nii', input_image_orientation, 'tmp.mask_orient_fit.nii')

    # Generate output file (in current folder); only the centerline image is exported
    print('\nGenerate output file (in current folder)...')
    os.chdir('..')  # come back to parent folder
    fname_output_centerline = sct.generate_output_file(path_tmp+'/tmp.point_orient_fit.nii', file_anat+'_centerline'+ext_anat)

    # Delete temporary files
    if remove_tmp_files == 1:
        print('\nRemove temporary files...')
        sct.run('rm -rf '+path_tmp)

    # print number of warnings
    print('\nNumber of warnings: '+str(warning_count)+' (if >10, you should probably reduce the gap and/or increase the kernel size)')

    # display elapsed time
    elapsed_time = time.time() - start_time
    print('\nFinished! \n\tGenerated file: '+fname_output_centerline+'\n\tElapsed time: '+str(int(round(elapsed_time)))+'s\n')
def main():
    """Smooth an anatomical volume along the spinal cord centerline.

    Options are read from ``sys.argv``:
      -i  volume to smooth (mandatory)
      -c  centerline or segmentation of the cord (mandatory)
      -r  1 to remove the temporary folder at the end, 0 to keep it
      -s  standard deviation of the Gaussian smoothing along z, in voxels
          (default: 3)
      -v  verbosity level

    Steps: the centerline is reduced to one point per slice (center of mass),
    empty slices are filled in, the volume is straightened along that
    centerline with ``sct_straighten_spinalcord``, smoothed along z in the
    straightened space, then warped back with ``warp_straight2curve.nii.gz``.
    The result is written to ``<file_anat>_smooth<ext_anat>`` in the current
    folder; intermediate files live in a time-stamped ``tmp.*`` folder.
    """
    # Initialization
    fname_anat = ''
    fname_centerline = ''
    sigma = 3  # default standard deviation of the Gaussian smoothing (in number of voxels)
    remove_temp_files = param.remove_temp_files
    verbose = param.verbose
    start_time = time.time()

    # Check input param
    try:
        opts, args = getopt.getopt(sys.argv[1:], 'hi:c:r:s:v:')
    except getopt.GetoptError as err:
        print(str(err))
        usage()
    if not opts:
        usage()
    for opt, arg in opts:
        if opt == '-h':
            usage()
        elif opt == '-c':
            fname_centerline = arg
        elif opt == '-i':
            fname_anat = arg
        elif opt == '-r':
            # converted to int: the value is compared with 1 before cleaning up,
            # so keeping the raw string meant "-r 1" never removed anything
            remove_temp_files = int(arg)
        elif opt == '-s':
            sigma = arg  # only ever used through str(), so kept as given
        elif opt == '-v':
            verbose = int(arg)

    # Display usage if a mandatory argument is not provided
    if fname_anat == '' or fname_centerline == '':
        usage()

    # Display arguments
    print('\nCheck input arguments...')
    print('  Volume to smooth .................. ' + fname_anat)
    print('  Centerline ........................ ' + fname_centerline)
    print('  Sigma (voxels) .................... ' + str(sigma))
    print('  Verbose ........................... ' + str(verbose))

    # Check existence of input files
    print('\nCheck existence of input files...')
    sct.check_file_exist(fname_anat, verbose)
    sct.check_file_exist(fname_centerline, verbose)

    # Extract path/file/extension (only the anatomical name is needed, for the output)
    path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)

    # create temporary folder
    print('\nCreate temporary folder...')
    path_tmp = 'tmp.' + time.strftime("%y%m%d%H%M%S")
    sct.run('mkdir ' + path_tmp)

    # copy files to temporary folder
    print('\nCopy files...')
    sct.run('isct_c3d ' + fname_anat + ' -o ' + path_tmp + '/anat.nii')
    sct.run('isct_c3d ' + fname_centerline + ' -o ' + path_tmp + '/centerline.nii')

    # go to tmp folder
    os.chdir(path_tmp)

    # Change orientation of the input image and of the centerline into RPI
    print('\nOrient input volume to RPI orientation...')
    set_orientation('anat.nii', 'RPI', 'anat_rpi.nii')
    print('\nOrient centerline to RPI orientation...')
    set_orientation('centerline.nii', 'RPI', 'centerline_rpi.nii')

    # Load the centerline; data_temp and data_output start as zeroed copies of it
    file_c = load('centerline_rpi.nii')
    data_c = file_c.get_data()
    hdr_c = file_c.get_header()
    data_temp = copy(data_c)
    data_temp *= 0
    data_output = copy(data_c)
    data_output *= 0
    nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension('centerline_rpi.nii')

    # Reduce a segmentation to a centerline: one point (center of mass) per non-empty slice
    sct.printv('\nChange segmentation to centerline if it is a segmentation...\n')
    z_centerline = [iz for iz in range(0, nz, 1) if data_c[:, :, iz].any()]
    if len(z_centerline) == 0:
        print('\nERROR: Centerline is empty')
        sys.exit(2)
    print('\nGet center of mass of the centerline ...')
    for iz in z_centerline:
        x_com, y_com = ndimage.measurements.center_of_mass(array(data_c[:, :, iz]))
        # center_of_mass returns floats, which are not valid array indices:
        # round to the nearest voxel
        data_temp[int(round(x_com)), int(round(y_com)), iz] = 1

    # Complete the centerline so that every slice gets an (x, y) point
    sct.printv('\nComplete the holes of the centerline if there are any...\n')
    X, Y, Z = data_temp.nonzero()
    x_centerline_extended = [0] * nz
    y_centerline_extended = [0] * nz
    for x, y, z in zip(X, Y, Z):
        x_centerline_extended[z] = x
        y_centerline_extended[z] = y

    # Slices holding a non-zero coordinate. Taking [0] keeps the index arrays
    # 1-D, so their elements are integer scalars usable as list indices.
    X_centerline_extended = nonzero(x_centerline_extended)[0]
    Y_centerline_extended = nonzero(y_centerline_extended)[0]

    # initialization: set the extreme values to avoid edge effects
    x_centerline_extended[0] = x_centerline_extended[X_centerline_extended[0]]
    x_centerline_extended[-1] = x_centerline_extended[X_centerline_extended[-1]]
    y_centerline_extended[0] = y_centerline_extended[Y_centerline_extended[0]]
    y_centerline_extended[-1] = y_centerline_extended[Y_centerline_extended[-1]]

    # The first and last slices are now non-zero: add their indices to the
    # index arrays (one before, one after)
    X_centerline_extended = append(X_centerline_extended, len(x_centerline_extended) - 1)
    X_centerline_extended = insert(X_centerline_extended, 0, 0)
    Y_centerline_extended = append(Y_centerline_extended, len(y_centerline_extended) - 1)
    Y_centerline_extended = insert(Y_centerline_extended, 0, 0)

    # Recurrence: each empty slice is set to the mean of two entries selected
    # through the index arrays; count_zeros_* shifts that lookup by the number
    # of slices already filled.
    count_zeros_x = 0
    count_zeros_y = 0
    for i in range(1, nz - 1):
        if x_centerline_extended[i] == 0:
            x_centerline_extended[i] = 0.5 * (x_centerline_extended[X_centerline_extended[i - 1 - count_zeros_x]] + x_centerline_extended[X_centerline_extended[i - count_zeros_x]])
            count_zeros_x += 1
        if y_centerline_extended[i] == 0:
            y_centerline_extended[i] = 0.5 * (y_centerline_extended[Y_centerline_extended[i - 1 - count_zeros_y]] + y_centerline_extended[Y_centerline_extended[i - count_zeros_y]])
            count_zeros_y += 1

    # Save the completed centerline image, used by the straightening step
    sct.printv('\nSave image completed: centerline_rpi_completed.nii...\n')
    for iz in range(nz):
        # filled-in values may be half-integers: round them to voxel indices
        data_output[int(round(x_centerline_extended[iz])), int(round(y_centerline_extended[iz])), iz] = 1
    img = Nifti1Image(data_output, None, hdr_c)
    save(img, 'centerline_rpi_completed.nii')

    # Straighten the spinal cord
    print('\nStraighten the spinal cord...')
    sct.run('sct_straighten_spinalcord -i anat_rpi.nii -c centerline_rpi_completed.nii -x spline -v ' + str(verbose))

    # Smooth the straightened image along z
    print('\nSmooth the straightened image along z...')
    sct.run('isct_c3d anat_rpi_straight.nii -smooth 0x0x' + str(sigma) + 'vox -o anat_rpi_straight_smooth.nii', verbose)

    # Apply the reversed warping field to get back the curved spinal cord
    print('\nApply the reversed warping field to get back the curved spinal cord...')
    sct.run('sct_apply_transfo -i anat_rpi_straight_smooth.nii -o anat_rpi_straight_smooth_curved.nii -d anat.nii -w warp_straight2curve.nii.gz -x spline', verbose)

    # come back to parent folder
    os.chdir('..')

    # Generate output file
    print('\nGenerate output file...')
    sct.generate_output_file(path_tmp + '/anat_rpi_straight_smooth_curved.nii', file_anat + '_smooth' + ext_anat)

    # Remove temporary files
    if remove_temp_files == 1:
        print('\nRemove temporary files...')
        sct.run('rm -rf ' + path_tmp)

    # Display elapsed time
    elapsed_time = time.time() - start_time
    print('\nFinished! Elapsed time: ' + str(int(round(elapsed_time))) + 's\n')

    # to view results
    sct.printv('Done! To view results, type:', verbose)
    sct.printv('fslview ' + file_anat + ' ' + file_anat + '_smooth &\n', verbose, 'info')
Ejemplo n.º 23
0
def main():
    """Interactively crop a 3D volume along z.

    Options are read from ``sys.argv``. In debug mode they are ignored: a test
    image under ``path_sct`` is used and temporary files are kept.
      -i  input image (mandatory)
      -r  1 to remove the temporary folder at the end, 0 to keep it
      -v  verbosity level

    The image is reoriented to RPI and its middle slice along x is shown.
    The user clicks two points; all slices between the two clicked z
    positions, both included, are kept. The result is written to
    ``<file_data>_crop<ext_data>`` in the current folder.
    """
    # Initialization
    fname_data = ''
    suffix_out = '_crop'
    remove_temp_files = param.remove_temp_files
    verbose = param.verbose

    # Parameters for debug mode
    if param.debug:
        print('\n*** WARNING: DEBUG MODE ON ***\n')
        fname_data = path_sct + '/testing/data/errsm_23/t2/t2.nii.gz'
        remove_temp_files = 0
    else:
        # Check input parameters
        try:
            opts, args = getopt.getopt(sys.argv[1:], 'hi:r:v:')
        except getopt.GetoptError:
            usage()
        if not opts:
            usage()
        for opt, arg in opts:
            if opt == '-h':
                usage()
            elif opt == '-i':
                fname_data = arg
            elif opt == '-r':
                remove_temp_files = int(arg)
            elif opt == '-v':
                verbose = int(arg)

    # display usage if a mandatory argument is not provided
    if fname_data == '':
        usage()

    # Check file existence
    sct.printv('\nCheck file existence...', verbose)
    sct.check_file_exist(fname_data, verbose)

    # Get dimensions of data; only 3D data is supported
    sct.printv('\nGet dimensions of data...', verbose)
    nx, ny, nz, nt, px, py, pz, pt = Image(fname_data).dim
    sct.printv('.. ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), verbose)
    if nt != 1:
        sct.printv('\nERROR in ' + os.path.basename(__file__) + ': Data should be 3D.\n', 1, 'error')
        sys.exit(2)

    # print arguments
    print('\nCheck parameters:')
    print('  data ................... ' + fname_data)
    print('')

    # Extract path/file/extension; the output goes to the current folder
    path_data, file_data, ext_data = sct.extract_fname(fname_data)
    path_out, file_out, ext_out = '', file_data + suffix_out, ext_data

    # create temporary folder (path_tmp ends with '/')
    path_tmp = 'tmp.' + time.strftime("%y%m%d%H%M%S") + '/'
    sct.run('mkdir ' + path_tmp)

    # copy files into tmp folder
    sct.run('isct_c3d ' + fname_data + ' -o ' + path_tmp + 'data.nii')

    # go to tmp folder
    os.chdir(path_tmp)

    # change orientation
    sct.printv('\nChange orientation to RPI...', verbose)
    set_orientation('data.nii', 'RPI', 'data_rpi.nii')

    # get image of medial slab
    sct.printv('\nGet image of medial slab...', verbose)
    image_array = nibabel.load('data_rpi.nii').get_data()
    nx, ny, nz = image_array.shape
    # integer division: math.floor() returns a float on Python 2, and float
    # indices are rejected by recent NumPy versions
    scipy.misc.imsave('image.jpg', image_array[nx // 2, :, :])

    # Display the image (transposed, with the y axis inverted)
    sct.printv('\nDisplay image and get cropping region...', verbose)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    img = mpimg.imread("image.jpg")
    implot = ax.imshow(img.T)
    implot.set_cmap('gray')
    plt.gca().invert_yaxis()
    # mouse callback collects the clicked points
    ax.set_title('Left click on the top and bottom of your cropping field.\n Right click to remove last point.\n Close window when your done.')
    line, = ax.plot([], [], 'ro')  # empty line
    cropping_coordinates = LineBuilder(line)
    plt.show()

    # check if user clicked two times
    if len(cropping_coordinates.xs) != 2:
        sct.printv('\nERROR: You have to select two points. Exit program.\n', 1, 'error')
        sys.exit(2)

    # Convert the clicked vertical coordinates to slice indices. With imshow's
    # default extent pixel k is centered on coordinate k, so round to the
    # nearest pixel (int() truncated, picking the wrong slice for half of the
    # pixel), then clamp into [0, nz-1] so a click just outside the image is
    # not turned into a negative (wrap-around) index.
    zcrop = sorted(min(max(int(round(z)), 0), nz - 1) for z in cropping_coordinates.ys)

    # crop image; +1 keeps the upper clicked slice as well
    sct.printv('\nCrop image...', verbose)
    nii = Image('data_rpi.nii')
    nii.data = nii.data[:, :, zcrop[0]:zcrop[1] + 1]
    nii.setFileName('data_rpi_crop.nii')
    nii.save()

    # come back to parent folder
    os.chdir('..')

    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(path_tmp + 'data_rpi_crop.nii', path_out + file_out + ext_out)

    # Remove temporary files
    if remove_temp_files == 1:
        print('\nRemove temporary files...')
        sct.run('rm -rf ' + path_tmp)

    # to view results
    print('\nDone! To view results, type:')
    print('fslview ' + path_out + file_out + ext_out + ' &')
    print('')
def main():
    """Flatten the spinal cord in the right-left direction.

    Options are read from ``sys.argv``. In debug mode they are ignored and
    test data from ``$SCT_TESTING_DATA_DIR`` is used.
      -i  anatomical volume (mandatory)
      -c  centerline or segmentation of the cord (mandatory)
      -r  1 to delete the temporary files at the end, 0 to keep them
      -d  polynomial degree (parsed, but not used by this function)
      -f  centerline fitting: 'polynome' (default) or 'splines'
      -s  interpolation method passed to ``flirt -interp``

    Steps:
      1. Both images are reoriented to RPI.
      2. One (x, y) point per slice is taken from the centerline and fitted.
      3. Every z slice is translated along x so that the fitted centerline
         lines up with its position at the first (lowest) centerline slice.
         Slices below the centerline are left in place; slices above it get
         the shift of the last centerline slice.
      4. The result, reoriented back to the input orientation, is written to
         ``<file_anat>_flatten<ext_anat>`` in the current folder.
    """
    # Initialization
    fname_anat = ''
    fname_centerline = ''
    centerline_fitting = 'polynome'
    remove_temp_files = param.remove_temp_files
    interp = param.interp
    # NOTE(review): degree_poly is parsed but never passed to the fitting functions below
    degree_poly = param.deg_poly

    # Parameters for debug mode
    if param.debug == 1:
        print('\n*** WARNING: DEBUG MODE ON ***\n')
        status, path_sct_data = commands.getstatusoutput('echo $SCT_TESTING_DATA_DIR')
        fname_anat = path_sct_data + '/t2/t2.nii.gz'
        fname_centerline = path_sct_data + '/t2/t2_seg.nii.gz'
    else:
        # Check input param
        try:
            opts, args = getopt.getopt(sys.argv[1:], 'hi:c:r:d:f:s:')
        except getopt.GetoptError as err:
            print(str(err))
            usage()
        if not opts:
            usage()
        for opt, arg in opts:
            if opt == '-h':
                usage()
            elif opt == '-i':
                fname_anat = arg
            elif opt == '-c':
                fname_centerline = arg
            elif opt == '-r':
                remove_temp_files = int(arg)
            elif opt == '-d':
                degree_poly = int(arg)
            elif opt == '-f':
                centerline_fitting = str(arg)
            elif opt == '-s':
                interp = str(arg)

    # display usage if a mandatory argument is not provided
    if fname_anat == '' or fname_centerline == '':
        usage()

    # Fail early on an unknown fitting method (it used to surface much later as
    # a NameError on x_centerline_fit)
    if centerline_fitting not in ('polynome', 'splines'):
        sct.printv("\nERROR: unknown centerline fitting '" + centerline_fitting + "'. Use 'polynome' or 'splines'.\n", 1, 'error')
        sys.exit(2)

    # check existence of input files
    sct.check_file_exist(fname_anat)
    sct.check_file_exist(fname_centerline)

    # extract path/file/extension
    path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)

    # Display arguments
    print('\nCheck input arguments...')
    print('  Input volume ...................... ' + fname_anat)
    print('  Centerline ........................ ' + fname_centerline)
    print('')

    # Get input image orientation, to restore it at the end
    input_image_orientation = get_orientation(fname_anat)

    # Reorient input data into RL PA IS orientation
    set_orientation(fname_anat, 'RPI', 'tmp.anat_orient.nii')
    set_orientation(fname_centerline, 'RPI', 'tmp.centerline_orient.nii')

    # Open centerline
    #==========================================================================================
    print('\nGet dimensions of input centerline...')
    nx, ny, nz, nt, px, py, pz, pt = Image('tmp.centerline_orient.nii').dim
    print('.. matrix size: ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz))
    print('.. voxel size:  ' + str(px) + 'mm x ' + str(py) + 'mm x ' + str(pz) + 'mm')

    print('\nOpen centerline volume...')
    nii_centerline = nibabel.load('tmp.centerline_orient.nii')
    data = nii_centerline.get_data()

    X, Y, Z = (data > 0).nonzero()
    if len(Z) == 0:
        # min()/max() below would raise an unhelpful ValueError
        sct.printv('\nERROR: the centerline is empty.\n', 1, 'error')
        sys.exit(2)
    min_z_index, max_z_index = min(Z), max(Z)

    # one (x, y) point per slice between the first and last centerline slice;
    # (0, 0) is left for slices where no point is found
    x_centerline = [0] * (max_z_index - min_z_index + 1)
    y_centerline = [0] * (max_z_index - min_z_index + 1)
    z_centerline = list(range(min_z_index, max_z_index + 1))

    # Two possible scenarios:
    # 1. The centerline is probabilistic: each slice contains voxels with the
    #    probability of containing the centerline [0:...:1]. We only take the
    #    maximum value of the slice to approximate the centerline.
    # 2. The centerline/segmentation image contains many pixels per slice with
    #    values {0,1}. We take the mean of all these points.
    X, Y, Z = ((data < 1) * (data > 0)).nonzero()  # X is empty if binary image
    if len(X) > 0:  # Scenario 1
        for iz in range(min_z_index, max_z_index + 1):
            x_centerline[iz - min_z_index], y_centerline[iz - min_z_index] = numpy.unravel_index(
                data[:, :, iz].argmax(), data[:, :, iz].shape)
    else:  # Scenario 2
        for iz in range(min_z_index, max_z_index + 1):
            x_seg, y_seg = (data[:, :, iz] > 0).nonzero()
            if len(x_seg) > 0:
                x_centerline[iz - min_z_index] = numpy.mean(x_seg)
                y_centerline[iz - min_z_index] = numpy.mean(y_seg)

    # clear variable
    del data

    # Fit the centerline points with the requested kind of curve (validated above)
    if centerline_fitting == 'splines':
        try:
            x_centerline_fit, y_centerline_fit = b_spline_centerline(x_centerline, y_centerline, z_centerline)
        except ValueError:
            print("splines fitting doesn't work, trying with polynomial fitting...\n")
            x_centerline_fit, y_centerline_fit = polynome_centerline(x_centerline, y_centerline, z_centerline)
    else:
        x_centerline_fit, y_centerline_fit = polynome_centerline(x_centerline, y_centerline, z_centerline)

    #==========================================================================================
    # Split input volume into z slices
    print('\nSplit input volume...')
    from sct_split_data import split_data
    if not split_data('tmp.anat_orient.nii', 2, '_z'):
        sct.printv('ERROR in split_data.', 1, 'error')
    file_anat_split = ['tmp.anat_orient_z' + str(z).zfill(4) for z in range(0, nz, 1)]

    def write_x_translation_matrix(fname_mat, displacement):
        # Write a 4x4 text matrix (used with flirt -applyxfm -init): identity
        # plus a translation of `displacement` along x.
        # NOTE(review): displacement comes from voxel indices while flirt
        # matrices are usually expressed in mm; this only matches when the x
        # voxel size (px) is 1 mm - confirm.
        with open(fname_mat, 'w') as fid:
            fid.write('%i %i %i %f\n' % (1, 0, 0, displacement))
            fid.write('%i %i %i %f\n' % (0, 1, 0, 0))
            fid.write('%i %i %i %i\n' % (0, 0, 1, 0))
            fid.write('%i %i %i %i\n' % (0, 0, 0, 1))

    # Compute one x translation per slice, relative to the first centerline slice
    print('\nGenerate fitted transformation matrices...')
    file_mat_inv_cumul_fit = ['tmp.mat_inv_cumul_fit_z' + str(z).zfill(4) for z in range(0, nz, 1)]
    x_fit_ref = x_centerline_fit[0]  # fitted x at z_init = min_z_index
    displacement_max_z_index = x_fit_ref - x_centerline_fit[max_z_index - min_z_index]
    for iz in range(0, nz, 1):
        if iz < min_z_index:
            # below the centerline: no shift
            displacement = 0
        elif iz > max_z_index:
            # above the centerline: same shift as the last centerline slice
            displacement = displacement_max_z_index
        elif x_centerline[iz - min_z_index] == 0 and y_centerline[iz - min_z_index] == 0:
            # (0, 0) marks a slice without a centerline point: no shift
            displacement = 0
        else:
            displacement = x_fit_ref - x_centerline_fit[iz - min_z_index]
        write_x_translation_matrix(file_mat_inv_cumul_fit[iz], displacement)

    # apply transformations to data
    # NOTE(review): fsloutput is not defined in this function; presumably a
    # module-level global - confirm.
    print('\nApply fitted transformation matrices...')
    file_anat_split_fit = ['tmp.anat_orient_fit_z' + str(z).zfill(4) for z in range(0, nz, 1)]
    for iz in range(0, nz, 1):
        sct.run(fsloutput + 'flirt -in ' + file_anat_split[iz] + ' -ref ' +
                file_anat_split[iz] + ' -applyxfm -init ' +
                file_mat_inv_cumul_fit[iz] + ' -out ' +
                file_anat_split_fit[iz] + ' -interp ' + interp)

    # Merge the translated slices back along z
    print('\nMerge into 4D volume...')
    from sct_concat_data import concat_data
    from glob import glob
    # glob() returns files in arbitrary order: sort so the zero-padded slice
    # names are concatenated in z order
    concat_data(sorted(glob('tmp.anat_orient_fit_z*.nii')), 'tmp.anat_orient_fit.nii', dim=2)

    # Reorient data as it was before
    print('\nReorient data back into native orientation...')
    set_orientation('tmp.anat_orient_fit.nii', input_image_orientation, 'tmp.anat_orient_fit_reorient.nii')

    # Generate output file (in current folder)
    print('\nGenerate output file (in current folder)...')
    sct.generate_output_file('tmp.anat_orient_fit_reorient.nii', file_anat + '_flatten' + ext_anat)

    # Delete temporary files: only those created by this function, not every
    # tmp.* entry of the current folder
    if remove_temp_files == 1:
        print('\nDelete temporary files...')
        sct.run('rm -rf tmp.anat_orient* tmp.centerline_orient* tmp.mat_inv_cumul_fit_z*')

    # to view results
    print('\nDone! To view results, type:')
    print('fslview ' + file_anat + ext_anat + ' ' + file_anat + '_flatten' + ext_anat + ' &\n')
def main():
    """Flatten the spinal cord in the right-left direction.

    Options are read from ``sys.argv``. In debug mode they are ignored and
    test data from ``$SCT_TESTING_DATA_DIR`` is used.
      -i  anatomical volume (mandatory)
      -c  centerline or segmentation of the cord (mandatory)
      -r  1 to delete the temporary files at the end, 0 to keep them
      -d  polynomial degree (parsed, but not used by this function)
      -f  centerline fitting: 'polynome' (default) or 'splines'
      -s  interpolation method passed to ``flirt -interp``

    Steps:
      1. Both images are reoriented to RPI.
      2. One (x, y) point per slice is taken from the centerline and fitted.
      3. The volume is split into z slices with ``fslsplit``.
      4. Each slice is translated along x so that the fitted centerline lines
         up with its position at the first (lowest) centerline slice, then
         the slices are merged with ``fslmerge``. Slices below the centerline
         are left in place; slices above it get the shift of the last
         centerline slice.
      5. The result, reoriented back to the input orientation, is written to
         ``<file_anat>_flatten<ext_anat>`` in the current folder.
    """
    # Initialization
    fname_anat = ''
    fname_centerline = ''
    centerline_fitting = 'polynome'
    remove_temp_files = param.remove_temp_files
    interp = param.interp
    # NOTE(review): degree_poly is parsed but never passed to the fitting functions below
    degree_poly = param.deg_poly

    # Parameters for debug mode
    if param.debug == 1:
        print('\n*** WARNING: DEBUG MODE ON ***\n')
        status, path_sct_data = commands.getstatusoutput('echo $SCT_TESTING_DATA_DIR')
        fname_anat = path_sct_data + '/t2/t2.nii.gz'
        fname_centerline = path_sct_data + '/t2/t2_seg.nii.gz'
    else:
        # Check input param
        try:
            opts, args = getopt.getopt(sys.argv[1:], 'hi:c:r:d:f:s:')
        except getopt.GetoptError as err:
            print(str(err))
            usage()
        if not opts:
            usage()
        for opt, arg in opts:
            if opt == '-h':
                usage()
            elif opt == '-i':
                fname_anat = arg
            elif opt == '-c':
                fname_centerline = arg
            elif opt == '-r':
                remove_temp_files = int(arg)
            elif opt == '-d':
                degree_poly = int(arg)
            elif opt == '-f':
                centerline_fitting = str(arg)
            elif opt == '-s':
                interp = str(arg)

    # display usage if a mandatory argument is not provided
    if fname_anat == '' or fname_centerline == '':
        usage()

    # Fail early on an unknown fitting method (it used to surface much later as
    # a NameError on x_centerline_fit)
    if centerline_fitting not in ('polynome', 'splines'):
        sct.printv("\nERROR: unknown centerline fitting '" + centerline_fitting + "'. Use 'polynome' or 'splines'.\n", 1, 'error')
        sys.exit(2)

    # check existence of input files
    sct.check_file_exist(fname_anat)
    sct.check_file_exist(fname_centerline)

    # extract path/file/extension
    path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)

    # Display arguments
    print('\nCheck input arguments...')
    print('  Input volume ...................... ' + fname_anat)
    print('  Centerline ........................ ' + fname_centerline)
    print('')

    # Get input image orientation, to restore it at the end
    input_image_orientation = get_orientation(fname_anat)

    # Reorient input data into RL PA IS orientation
    set_orientation(fname_anat, 'RPI', 'tmp.anat_orient.nii')
    set_orientation(fname_centerline, 'RPI', 'tmp.centerline_orient.nii')

    # Open centerline
    #==========================================================================================
    print('\nGet dimensions of input centerline...')
    nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension('tmp.centerline_orient.nii')
    print('.. matrix size: ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz))
    print('.. voxel size:  ' + str(px) + 'mm x ' + str(py) + 'mm x ' + str(pz) + 'mm')

    print('\nOpen centerline volume...')
    nii_centerline = nibabel.load('tmp.centerline_orient.nii')
    data = nii_centerline.get_data()

    X, Y, Z = (data > 0).nonzero()
    if len(Z) == 0:
        # min()/max() below would raise an unhelpful ValueError
        sct.printv('\nERROR: the centerline is empty.\n', 1, 'error')
        sys.exit(2)
    min_z_index, max_z_index = min(Z), max(Z)

    # one (x, y) point per slice between the first and last centerline slice;
    # (0, 0) is left for slices where no point is found
    x_centerline = [0] * (max_z_index - min_z_index + 1)
    y_centerline = [0] * (max_z_index - min_z_index + 1)
    z_centerline = list(range(min_z_index, max_z_index + 1))

    # Two possible scenarios:
    # 1. The centerline is probabilistic: each slice contains voxels with the
    #    probability of containing the centerline [0:...:1]. We only take the
    #    maximum value of the slice to approximate the centerline.
    # 2. The centerline/segmentation image contains many pixels per slice with
    #    values {0,1}. We take the mean of all these points.
    X, Y, Z = ((data < 1) * (data > 0)).nonzero()  # X is empty if binary image
    if len(X) > 0:  # Scenario 1
        for iz in range(min_z_index, max_z_index + 1):
            x_centerline[iz - min_z_index], y_centerline[iz - min_z_index] = numpy.unravel_index(
                data[:, :, iz].argmax(), data[:, :, iz].shape)
    else:  # Scenario 2
        for iz in range(min_z_index, max_z_index + 1):
            x_seg, y_seg = (data[:, :, iz] > 0).nonzero()
            if len(x_seg) > 0:
                x_centerline[iz - min_z_index] = numpy.mean(x_seg)
                y_centerline[iz - min_z_index] = numpy.mean(y_seg)

    # clear variable
    del data

    # Fit the centerline points with the requested kind of curve (validated above)
    if centerline_fitting == 'splines':
        try:
            x_centerline_fit, y_centerline_fit = b_spline_centerline(x_centerline, y_centerline, z_centerline)
        except ValueError:
            print("splines fitting doesn't work, trying with polynomial fitting...\n")
            x_centerline_fit, y_centerline_fit = polynome_centerline(x_centerline, y_centerline, z_centerline)
    else:
        x_centerline_fit, y_centerline_fit = polynome_centerline(x_centerline, y_centerline, z_centerline)

    #==========================================================================================
    # Split input volume into z slices
    print('\nSplit input volume...')
    sct.run(sct.fsloutput + 'fslsplit tmp.anat_orient.nii tmp.anat_z -z')
    file_anat_split = ['tmp.anat_z' + str(z).zfill(4) for z in range(0, nz, 1)]

    def write_x_translation_matrix(fname_mat, displacement):
        # Write a 4x4 text matrix (used with flirt -applyxfm -init): identity
        # plus a translation of `displacement` along x.
        # NOTE(review): displacement comes from voxel indices while flirt
        # matrices are usually expressed in mm; this only matches when the x
        # voxel size (px) is 1 mm - confirm.
        with open(fname_mat, 'w') as fid:
            fid.write('%i %i %i %f\n' % (1, 0, 0, displacement))
            fid.write('%i %i %i %f\n' % (0, 1, 0, 0))
            fid.write('%i %i %i %i\n' % (0, 0, 1, 0))
            fid.write('%i %i %i %i\n' % (0, 0, 0, 1))

    # Compute one x translation per slice, relative to the first centerline slice
    print('\nGenerate fitted transformation matrices...')
    file_mat_inv_cumul_fit = ['tmp.mat_inv_cumul_fit_z' + str(z).zfill(4) for z in range(0, nz, 1)]
    x_fit_ref = x_centerline_fit[0]  # fitted x at z_init = min_z_index
    displacement_max_z_index = x_fit_ref - x_centerline_fit[max_z_index - min_z_index]
    for iz in range(0, nz, 1):
        if iz < min_z_index:
            # below the centerline: no shift
            displacement = 0
        elif iz > max_z_index:
            # above the centerline: same shift as the last centerline slice
            displacement = displacement_max_z_index
        elif x_centerline[iz - min_z_index] == 0 and y_centerline[iz - min_z_index] == 0:
            # (0, 0) marks a slice without a centerline point: no shift
            displacement = 0
        else:
            displacement = x_fit_ref - x_centerline_fit[iz - min_z_index]
        write_x_translation_matrix(file_mat_inv_cumul_fit[iz], displacement)

    # apply transformations to data
    # NOTE(review): fsloutput is not defined in this function (unlike
    # sct.fsloutput used for fslsplit); presumably a module-level global - confirm.
    print('\nApply fitted transformation matrices...')
    file_anat_split_fit = ['tmp.anat_orient_fit_z' + str(z).zfill(4) for z in range(0, nz, 1)]
    for iz in range(0, nz, 1):
        sct.run(fsloutput + 'flirt -in ' + file_anat_split[iz] + ' -ref ' + file_anat_split[iz] +
                ' -applyxfm -init ' + file_mat_inv_cumul_fit[iz] + ' -out ' + file_anat_split_fit[iz] +
                ' -interp ' + interp)

    # Merge the translated slices back along z
    print('\nMerge into 4D volume...')
    sct.run(fsloutput + 'fslmerge -z tmp.anat_orient_fit tmp.anat_orient_fit_z*')

    # Reorient data as it was before
    print('\nReorient data back into native orientation...')
    set_orientation('tmp.anat_orient_fit.nii', input_image_orientation, 'tmp.anat_orient_fit_reorient.nii')

    # Generate output file (in current folder)
    print('\nGenerate output file (in current folder)...')
    sct.generate_output_file('tmp.anat_orient_fit_reorient.nii', file_anat + '_flatten' + ext_anat)

    # Delete temporary files: only those created by this function, not every
    # tmp.* entry of the current folder
    if remove_temp_files == 1:
        print('\nDelete temporary files...')
        sct.run('rm -rf tmp.anat_orient* tmp.anat_z* tmp.centerline_orient* tmp.mat_inv_cumul_fit_z*')

    # to view results
    print('\nDone! To view results, type:')
    print('fslview ' + file_anat + ext_anat + ' ' + file_anat + '_flatten' + ext_anat + ' &\n')
Ejemplo n.º 26
0
def main():
    """Smooth an anatomical volume along the spinal cord centerline.

    Command-line entry point. Options are read from ``sys.argv``:
      -i <file>  anatomical volume to smooth (mandatory)
      -c <file>  centerline image of the spinal cord (mandatory)
      -s <num>   standard deviation of the Gaussian smoothing along z,
                 in voxels (default: 3)
      -r <int>   1 to remove the temporary folder at the end
                 (default: param.remove_temp_files)
      -v <int>   verbosity (default: param.verbose)
      -h         call usage()

    Pipeline:
      1. Copy both inputs into a time-stamped temporary folder.
      2. Convert them to .nii and reorient them to RPI.
      3. Straighten the cord with sct_straighten_spinalcord.
      4. Smooth the straightened volume along z only with isct_c3d.
      5. Warp the result back to the curved space with sct_apply_transfo,
         using warp_straight2curve.nii.gz.
      6. Write ``<file_anat>_smooth<ext_anat>`` in the current folder.
    """
    # Initialization
    fname_anat = ''
    fname_centerline = ''
    sigma = 3  # default standard deviation of the Gaussian smoothing (in voxels)
    remove_temp_files = param.remove_temp_files
    verbose = param.verbose
    start_time = time.time()

    # Check input param
    try:
        opts, _ = getopt.getopt(sys.argv[1:], 'hi:c:r:s:v:')
    except getopt.GetoptError as err:
        print(str(err))
        usage()
    if not opts:
        usage()
    for opt, arg in opts:
        if opt == '-h':
            usage()
        elif opt == '-c':
            fname_centerline = arg
        elif opt == '-i':
            fname_anat = arg
        elif opt == '-r':
            # Must be an int: it is compared with `== 1` before cleanup below.
            remove_temp_files = int(arg)
        elif opt == '-s':
            # Kept as given; only used through str() when building the c3d command.
            sigma = arg
        elif opt == '-v':
            verbose = int(arg)

    # Display usage if a mandatory argument is not provided
    if fname_anat == '' or fname_centerline == '':
        usage()

    # Display arguments
    print('\nCheck input arguments...')
    print('  Volume to smooth .................. ' + fname_anat)
    print('  Centerline ........................ ' + fname_centerline)
    print('  Sigma (voxels) .................... ' + str(sigma))
    print('  Verbose ........................... ' + str(verbose))

    # Check existence of input files
    print('\nCheck existence of input files...')
    sct.check_file_exist(fname_anat, verbose)
    sct.check_file_exist(fname_centerline, verbose)

    # Extract file name and extension (paths are not needed: the files are copied)
    _, file_anat, ext_anat = sct.extract_fname(fname_anat)
    _, _, ext_centerline = sct.extract_fname(fname_centerline)

    # Create temporary folder. path_tmp is used as a prefix below
    # (path_tmp+'anat'...), so it is expected to end with a slash.
    sct.printv('\nCreate temporary folder...', verbose)
    path_tmp = sct.slash_at_the_end('tmp.'+time.strftime("%y%m%d%H%M%S"), 1)
    sct.run('mkdir '+path_tmp, verbose)

    # Copying input data to tmp folder
    sct.printv('\nCopying input data to tmp folder and convert to nii...', verbose)
    sct.run('cp '+fname_anat+' '+path_tmp+'anat'+ext_anat, verbose)
    sct.run('cp '+fname_centerline+' '+path_tmp+'centerline'+ext_centerline, verbose)

    # Work inside the tmp folder; remember the starting folder to come back to it.
    path_current = os.getcwd()
    os.chdir(path_tmp)

    # convert to nii format
    convert('anat'+ext_anat, 'anat.nii')
    convert('centerline'+ext_centerline, 'centerline.nii')

    # Change orientation of the input image into RPI
    print('\nOrient input volume to RPI orientation...')
    set_orientation('anat.nii', 'RPI', 'anat_rpi.nii')
    # Change orientation of the centerline into RPI
    print('\nOrient centerline to RPI orientation...')
    set_orientation('centerline.nii', 'RPI', 'centerline_rpi.nii')

    # Straighten the spinal cord (produces anat_rpi_straight.nii and the
    # warp_straight2curve.nii.gz field used below)
    print('\nStraighten the spinal cord...')
    sct.run('sct_straighten_spinalcord -i anat_rpi.nii -c centerline_rpi.nii -x spline -v '+str(verbose))

    # Smooth the straightened image along z only: kernel is 0 x 0 x sigma voxels
    print('\nSmooth the straightened image along z...')
    sct.run('isct_c3d anat_rpi_straight.nii -smooth 0x0x'+str(sigma)+'vox -o anat_rpi_straight_smooth.nii', verbose)

    # Apply the reversed warping field to get back the curved spinal cord
    print('\nApply the reversed warping field to get back the curved spinal cord...')
    sct.run('sct_apply_transfo -i anat_rpi_straight_smooth.nii -o anat_rpi_straight_smooth_curved.nii -d anat.nii -w warp_straight2curve.nii.gz -x spline', verbose)

    # come back to the starting folder
    os.chdir(path_current)

    # Generate output file
    print('\nGenerate output file...')
    sct.generate_output_file(path_tmp+'/anat_rpi_straight_smooth_curved.nii', file_anat+'_smooth'+ext_anat)

    # Remove temporary files
    if remove_temp_files == 1:
        print('\nRemove temporary files...')
        sct.run('rm -rf '+path_tmp)

    # Display elapsed time
    elapsed_time = time.time() - start_time
    print('\nFinished! Elapsed time: '+str(int(round(elapsed_time)))+'s\n')

    # to view results: the input keeps its original path, the output is in the current folder
    sct.printv('Done! To view results, type:', verbose)
    sct.printv('fslview '+fname_anat+' '+file_anat+'_smooth'+ext_anat+' &\n', verbose, 'info')