def resample_labels(fname_labels, fname_dest, fname_output):
    """
    Re-create point labels in the grid of a resampled image.

    Each label voxel coordinate is divided by the per-axis ratio between the matrix size of the label image and the
    matrix size of the destination image, then rounded; the label value is truncated to an integer. The rescaled
    labels are written into the space of ``fname_dest`` with ``sct_label_utils``.

    :param fname_labels: label image whose labels are transferred
    :param fname_dest: image defining the destination (resampled) grid
    :param fname_output: file name of the label image to create
    """
    # matrix sizes of the source label image and of the destination grid (only x, y, z are used)
    nx, ny, nz, _, _, _, _, _ = sct.get_dimension(fname_labels)
    nx_dest, ny_dest, nz_dest, _, _, _, _, _ = sct.get_dimension(fname_dest)
    ratios = [float(nx) / nx_dest, float(ny) / ny_dest, float(nz) / nz_dest]

    from sct_label_utils import ProcessLabels
    # display_voxel() output is parsed here as "x,y,z,value:x,y,z,value:..."
    # TODO: modify sct_label_utils to output list of coordinates instead of string.
    voxels_str = ProcessLabels(fname_labels).display_voxel()

    resampled = []
    for entry in voxels_str.split(':'):
        fields = entry.split(',')
        new_fields = [str(int(round(int(fields[axis]) / ratios[axis]))) for axis in (0, 1, 2)]
        new_fields.append(str(int(float(fields[3]))))
        resampled.append(','.join(new_fields))
    coordinates = ':'.join(resampled)

    # create the rescaled labels in the destination space
    cmd = 'sct_label_utils -i ' + fname_dest + ' -t create -x ' + coordinates + ' -v 1 -o ' + fname_output
    sct.run(cmd)
Пример #2
0
def resample_labels(fname_labels, fname_dest, fname_output):
    """
    Re-create point labels in the grid of a resampled image.

    Each label coordinate is divided by the per-axis ratio (source matrix size / destination matrix size) and rounded
    to the nearest voxel; the label value is truncated to an integer. The rescaled labels are then created in the
    space of ``fname_dest`` through the ``sct_label_utils -create`` command.

    :param fname_labels: label image in the original space
    :param fname_dest: image that defines the resampled space
    :param fname_output: output label image
    """
    # matrix sizes of source and destination images (only x, y, z are used)
    nx, ny, nz, _, _, _, _, _ = Image(fname_labels).dim
    nx_dest, ny_dest, nz_dest, _, _, _, _, _ = Image(fname_dest).dim
    ratio_x, ratio_y, ratio_z = float(nx) / nx_dest, float(ny) / ny_dest, float(nz) / nz_dest

    from sct_label_utils import ProcessLabels
    coords = ProcessLabels(fname_labels).display_voxel()

    def _rescale(value, ratio):
        # map one coordinate into the destination grid, as a string
        return str(int(round(int(value) / ratio)))

    entries = [
        ','.join([_rescale(c.x, ratio_x),
                  _rescale(c.y, ratio_y),
                  _rescale(c.z, ratio_z),
                  str(int(float(c.value)))])
        for c in coords
    ]

    # create the rescaled labels in the destination space
    sct.run('sct_label_utils -i ' + fname_dest + ' -create ' + ':'.join(entries) +
            ' -v 1 -o ' + fname_output)
Пример #3
0
def add_label(brainstem_file, segmented_file, output_file_name, label_depth_compared_to_zmax=10, label_value=1):
    """
    Add one label at the in-plane centre of mass of a segmentation, a given number of slices below its top.

    The target slice is ``zmax - label_depth_compared_to_zmax``, where ``zmax`` is the ``Zmax`` attribute computed by
    ``ComputeZMinMax`` on the segmentation. The in-plane position is the rounded mean (x, y) of the voxels > 0 in that
    slice of the image held by ``ComputeZMinMax``. The new label is added to the labels already in ``brainstem_file``
    and the sum is saved to ``output_file_name``.

    :param brainstem_file: existing label image to which the new label is added
    :param segmented_file: segmentation image (e.g. data_RPI_seg.nii.gz)
    :param output_file_name: path of the output label image
    :param label_depth_compared_to_zmax: number of slices below zmax at which the label is placed
    :param label_value: value given to the new label
    :raises ValueError: if the target slice index is negative, or the segmentation is empty in that slice
    """
    # compute zmax of the segmented file (e.g. data_RPI_seg.nii.gz)
    image_seg = Image(segmented_file)
    z_test = ComputeZMinMax(image_seg)
    zmax = z_test.Zmax
    print("Zmax: ", zmax)

    # sanity check: number of labels before adding the new one
    brainstem_image = Image(brainstem_file)
    print("nb_label_before=", np.sum(brainstem_image.data))

    z_label = zmax - label_depth_compared_to_zmax
    if z_label < 0:
        # a negative index would silently wrap around to the other end of the volume
        raise ValueError("Target slice zmax - label_depth_compared_to_zmax = %d is outside the image" % z_label)

    # centre of mass of the segmentation in the target slice
    X, Y = np.nonzero(z_test.image.data[:, :, z_label] > 0)
    if X.size == 0:
        raise ValueError("Segmentation is empty in slice z=%d; cannot compute its center of mass" % z_label)
    # np.mean always divides in floating point, so round() really rounds (integer '/' floors under Python 2)
    x_bar = int(round(np.mean(X)))
    y_bar = int(round(np.mean(Y)))

    # place the new label at coordinates (x_bar, y_bar, z_label)
    coordi = Coordinate([x_bar, y_bar, z_label, label_value])
    object_for_process = ProcessLabels(brainstem_file, coordinates=[coordi])
    im_new_label = object_for_process.create_label()

    # output = existing labels + new label, with the geometry of the input label image
    im_output = object_for_process.image_input.copy()
    im_output.data = brainstem_image.data + im_new_label.data

    # sanity check: number of labels after adding the new one
    print("nb_label_after=", np.sum(im_output.data))

    # save output file
    im_output.setFileName(output_file_name)
    im_output.save('minimize')
def resample_labels(fname_labels, fname_dest, fname_output):
    """
    Transfer point labels onto the grid of a resampled image.

    Every label coordinate is divided by the ratio (source matrix size / destination matrix size) along its axis and
    rounded; the label value is truncated to an integer. The rescaled labels are then created in the space of
    ``fname_dest`` with the ``sct_label_utils -t create -x`` command.

    :param fname_labels: label image in the original space
    :param fname_dest: image that defines the resampled space
    :param fname_output: output label image
    """
    # matrix sizes (only the spatial ones are needed)
    nx, ny, nz, _, _, _, _, _ = Image(fname_labels).dim
    nxd, nyd, nzd, _, _, _, _, _ = Image(fname_dest).dim
    factors = (float(nx) / nxd, float(ny) / nyd, float(nz) / nzd)

    from sct_label_utils import ProcessLabels
    labels = ProcessLabels(fname_labels).display_voxel()

    encoded = []
    for lab in labels:
        new_x = str(int(round(int(lab.x) / factors[0])))
        new_y = str(int(round(int(lab.y) / factors[1])))
        new_z = str(int(round(int(lab.z) / factors[2])))
        new_value = str(int(float(lab.value)))
        encoded.append(','.join([new_x, new_y, new_z, new_value]))
    coordinates = ':'.join(encoded)

    # create the rescaled labels in the destination space
    cmd = 'sct_label_utils -i ' + fname_dest + ' -t create -x ' + coordinates + ' -v 1 -o ' + fname_output
    sct.run(cmd)
def main(args=None):
    """
    Command-line entry point of sct_label_vertebrae: label vertebral levels on a spinal cord segmentation.

    Workflow:
    1. Copy the inputs to a temporary folder.
    2. Straighten the cord, or reuse cached warping fields found in the current directory.
    3. Resample to 0.5 mm isotropic.
    4. Label vertebrae, either from a user-provided disc-label file or from a single initial disc label.
       The initial label comes from -initz/-initcenter/-initfile, -initlabel, or automatic C2-C3 detection, and is
       passed to ``vertebral_detection``.
    5. Un-straighten and clean the labeled segmentation, then derive disc labels.
    6. Write the output files, and optionally a QC report.

    :param args: list of command-line arguments; ``sys.argv[1:]`` is used when empty or None.
    """
    # initializations
    initz = ''
    initcenter = ''
    fname_initlabel = ''
    file_labelz = 'labelz.nii.gz'
    param = Param()

    # check user arguments
    if not args:
        args = sys.argv[1:]

    # Get parser info
    parser = get_parser()
    arguments = parser.parse(args)
    fname_in = os.path.abspath(arguments["-i"])
    fname_seg = os.path.abspath(arguments['-s'])
    contrast = arguments['-c']
    path_template = arguments['-t']
    scale_dist = arguments['-scale-dist']
    if '-ofolder' in arguments:
        path_output = arguments['-ofolder']
    else:
        path_output = os.curdir
    param.path_qc = arguments.get("-qc", None)
    if '-discfile' in arguments:
        fname_disc = os.path.abspath(arguments['-discfile'])
    else:
        fname_disc = None
    if '-initz' in arguments:
        initz = arguments['-initz']
    if '-initcenter' in arguments:
        initcenter = arguments['-initcenter']
    # if user provided text file, parse and overwrite arguments
    if '-initfile' in arguments:
        # split on any whitespace so that arguments spread over several lines are tokenized correctly;
        # the context manager guarantees the file is closed
        with open(arguments['-initfile'], 'r') as f_initfile:
            arg_initfile = f_initfile.read().split()
        for idx_arg, arg in enumerate(arg_initfile):
            if arg == '-initz':
                initz = [int(x) for x in arg_initfile[idx_arg + 1].split(',')]
            if arg == '-initcenter':
                initcenter = int(arg_initfile[idx_arg + 1])
    if '-initlabel' in arguments:
        # get absolute path of label
        fname_initlabel = os.path.abspath(arguments['-initlabel'])
    if '-param' in arguments:
        param.update(arguments['-param'][0])
    verbose = int(arguments.get('-v'))
    sct.init_sct(log_level=verbose, update=True)  # Update log level
    remove_temp_files = int(arguments['-r'])
    denoise = int(arguments['-denoise'])
    laplacian = int(arguments['-laplacian'])

    path_tmp = sct.tmp_create(basename="label_vertebrae", verbose=verbose)

    # Copying input data to tmp folder
    sct.printv('\nCopying input data to tmp folder...', verbose)
    Image(fname_in).save(os.path.join(path_tmp, "data.nii"))
    Image(fname_seg).save(os.path.join(path_tmp, "segmentation.nii"))

    # Go to temp folder (relative file names below are resolved inside it)
    curdir = os.getcwd()
    os.chdir(path_tmp)

    # Straighten spinal cord
    sct.printv('\nStraighten spinal cord...', verbose)
    # check if warp_curve2straight and warp_straight2curve already exist (i.e. no need to do it another time)
    cache_sig = sct.cache_signature(
        input_files=[fname_in, fname_seg],
    )
    cachefile = os.path.join(curdir, "straightening.cache")
    cached_files = ["warp_curve2straight.nii.gz", "warp_straight2curve.nii.gz", "straight_ref.nii.gz"]
    if sct.cache_valid(cachefile, cache_sig) and all(os.path.isfile(os.path.join(curdir, f)) for f in cached_files):
        # if they exist, copy them into current folder
        sct.printv('Reusing existing warping field which seems to be valid', verbose, 'warning')
        for fname_cached in cached_files:
            sct.copy(os.path.join(curdir, fname_cached), fname_cached)
        # apply straightening
        sct.run(['sct_apply_transfo', '-i', 'data.nii', '-w', 'warp_curve2straight.nii.gz', '-d', 'straight_ref.nii.gz', '-o', 'data_straight.nii'])
    else:
        cmd = ['sct_straighten_spinalcord',
               '-i', 'data.nii',
               '-s', 'segmentation.nii',
               '-r', str(remove_temp_files)]
        if param.path_qc is not None and os.environ.get("SCT_RECURSIVE_QC", None) == "1":
            cmd += ['-qc', param.path_qc]
        sct.run(cmd)
        sct.cache_save(cachefile, cache_sig)

    # resample to 0.5mm isotropic to match template resolution
    sct.printv('\nResample to 0.5mm isotropic...', verbose)
    sct.run(['sct_resample', '-i', 'data_straight.nii', '-mm', '0.5x0.5x0.5', '-x', 'linear', '-o', 'data_straightr.nii'], verbose=verbose)

    # Apply straightening to segmentation
    # N.B. Output is RPI
    sct.printv('\nApply straightening to segmentation...', verbose)
    sct.run('isct_antsApplyTransforms -d 3 -i %s -r %s -t %s -o %s -n %s' %
            ('segmentation.nii',
             'data_straightr.nii',
             'warp_curve2straight.nii.gz',
             'segmentation_straight.nii',
             'Linear'),
            verbose=verbose,
            is_sct_binary=True,
            )
    # Threshold segmentation at 0.5
    sct.run(['sct_maths', '-i', 'segmentation_straight.nii', '-thr', '0.5', '-o', 'segmentation_straight.nii'], verbose)

    # If disc label file is provided, label vertebrae using that file instead of automatically
    if fname_disc:
        # Apply straightening to disc-label
        sct.printv('\nApply straightening to disc labels...', verbose)
        sct.run('isct_antsApplyTransforms -d 3 -i %s -r %s -t %s -o %s -n %s' %
                (fname_disc,
                 'data_straightr.nii',
                 'warp_curve2straight.nii.gz',
                 'labeldisc_straight.nii.gz',
                 'NearestNeighbor'),
                verbose=verbose,
                is_sct_binary=True,
                )
        label_vert('segmentation_straight.nii', 'labeldisc_straight.nii.gz', verbose=1)

    else:
        # create label to identify disc
        sct.printv('\nCreate label to identify disc...', verbose)
        fname_labelz = os.path.join(path_tmp, file_labelz)
        if initz or initcenter:
            if initcenter:
                # find z centered in FOV
                nii = Image('segmentation.nii').change_orientation("RPI")
                nx, ny, nz, nt, px, py, pz, pt = nii.dim  # Get dimensions
                z_center = int(np.round(nz / 2))  # get z_center
                initz = [z_center, initcenter]
            # create single label and output as labels.nii.gz
            label = ProcessLabels('segmentation.nii', fname_output='tmp.labelz.nii.gz',
                                  coordinates=['{},{}'.format(initz[0], initz[1])])
            im_label = label.process('create-seg')
            im_label.data = sct_maths.dilate(im_label.data, [3])  # TODO: create a dilation method specific to labels,
            # which does not apply a convolution across all voxels (highly inefficient)
            im_label.save(fname_labelz)
        elif fname_initlabel:
            import sct_label_utils
            # subtract "1" to label value because due to legacy, in this code the disc C2-C3 has value "2", whereas in the
            # recent version of SCT it is defined as "3". Therefore, when asking the user to define a label, we point to the
            # new definition of labels (i.e., C2-C3 = 3).
            sct_label_utils.main(['-i', fname_initlabel, '-add', '-1', '-o', fname_labelz])
        else:
            # automatically finds C2-C3 disc
            im_data = Image('data.nii')
            im_seg = Image('segmentation.nii')
            im_label_c2c3 = detect_c2c3(im_data, im_seg, contrast)
            ind_label = np.where(im_label_c2c3.data)
            if not np.size(ind_label) == 0:
                # subtract "1" to label value because due to legacy, in this code the disc C2-C3 has value "2", whereas in the
                # recent version of SCT it is defined as "3".
                im_label_c2c3.data[ind_label] = 2
            else:
                sct.printv('Automatic C2-C3 detection failed. Please provide manual label with sct_label_utils', 1, 'error')
                # stop with a failure status instead of continuing with an empty label
                sys.exit(1)
            im_label_c2c3.save(fname_labelz)

        # dilate label so it is not lost when applying warping
        sct_maths.main(['-i', fname_labelz, '-dilate', '3', '-o', fname_labelz])

        # Apply straightening to z-label
        sct.printv('\nAnd apply straightening to label...', verbose)
        sct.run('isct_antsApplyTransforms -d 3 -i %s -r %s -t %s -o %s -n %s' %
                (file_labelz,
                 'data_straightr.nii',
                 'warp_curve2straight.nii.gz',
                 'labelz_straight.nii.gz',
                 'NearestNeighbor'),
                verbose=verbose,
                is_sct_binary=True,
                )
        # get z value and disk value to initialize labeling
        sct.printv('\nGet z and disc values from straight label...', verbose)
        init_disc = get_z_and_disc_values_from_label('labelz_straight.nii.gz')
        sct.printv('.. ' + str(init_disc), verbose)

        # denoise data
        if denoise:
            sct.printv('\nDenoise data...', verbose)
            sct.run(['sct_maths', '-i', 'data_straightr.nii', '-denoise', 'h=0.05', '-o', 'data_straightr.nii'], verbose)

        # apply laplacian filtering
        if laplacian:
            sct.printv('\nApply Laplacian filter...', verbose)
            sct.run(['sct_maths', '-i', 'data_straightr.nii', '-laplacian', '1', '-o', 'data_straightr.nii'], verbose)

        # detect vertebral levels on straight spinal cord
        vertebral_detection('data_straightr.nii', 'segmentation_straight.nii', contrast, param, init_disc=init_disc,
                            verbose=verbose, path_template=path_template, path_output=path_output, scale_dist=scale_dist)

    # un-straighten labeled spinal cord
    sct.printv('\nUn-straighten labeling...', verbose)
    sct.run('isct_antsApplyTransforms -d 3 -i %s -r %s -t %s -o %s -n %s' %
            ('segmentation_straight_labeled.nii',
             'segmentation.nii',
             'warp_straight2curve.nii.gz',
             'segmentation_labeled.nii',
             'NearestNeighbor'),
            verbose=verbose,
            is_sct_binary=True,
            )
    # Clean labeled segmentation
    sct.printv('\nClean labeled segmentation (correct interpolation errors)...', verbose)
    clean_labeled_segmentation('segmentation_labeled.nii', 'segmentation.nii', 'segmentation_labeled.nii')

    # label discs
    sct.printv('\nLabel discs...', verbose)
    label_discs('segmentation_labeled.nii', verbose=verbose)

    # come back
    os.chdir(curdir)

    # Generate output files
    path_seg, file_seg, ext_seg = sct.extract_fname(fname_seg)
    fname_seg_labeled = os.path.join(path_output, file_seg + '_labeled' + ext_seg)
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(os.path.join(path_tmp, "segmentation_labeled.nii"), fname_seg_labeled)
    sct.generate_output_file(os.path.join(path_tmp, "segmentation_labeled_disc.nii"), os.path.join(path_output, file_seg + '_labeled_discs' + ext_seg))
    # copy straightening files in case subsequent SCT functions need them
    sct.generate_output_file(os.path.join(path_tmp, "warp_curve2straight.nii.gz"), os.path.join(path_output, "warp_curve2straight.nii.gz"), verbose)
    sct.generate_output_file(os.path.join(path_tmp, "warp_straight2curve.nii.gz"), os.path.join(path_output, "warp_straight2curve.nii.gz"), verbose)
    sct.generate_output_file(os.path.join(path_tmp, "straight_ref.nii.gz"), os.path.join(path_output, "straight_ref.nii.gz"), verbose)

    # Remove temporary files
    if remove_temp_files == 1:
        sct.printv('\nRemove temporary files...', verbose)
        sct.rmtree(path_tmp)

    # Generate QC report
    if param.path_qc is not None:
        path_qc = os.path.abspath(param.path_qc)
        qc_dataset = arguments.get("-qc-dataset", None)
        qc_subject = arguments.get("-qc-subject", None)
        generate_qc(fname_in, fname_seg=fname_seg_labeled, args=args, path_qc=path_qc,
                    dataset=qc_dataset, subject=qc_subject, process='sct_label_vertebrae')

    sct.display_viewer_syntax([fname_in, fname_seg_labeled], colormaps=['', 'subcortical'], opacities=['1', '0.5'])
Пример #6
0
def main(args=None):
    """
    Command-line entry point of sct_label_vertebrae: label vertebral levels on a spinal cord segmentation.

    Workflow:
    1. Copy the inputs to a temporary folder.
    2. Straighten the cord, or reuse cached warping fields found in the current directory.
    3. Resample to 0.5 mm isotropic.
    4. Label vertebrae, either from a user-provided disc-label file or from a single initial disc label.
       The initial label comes from -initz/-initcenter/-initfile, -initlabel, or automatic C2-C3 detection, and is
       passed to ``vertebral_detection`` after its disc value is decreased by one.
    5. Un-straighten and clean the labeled segmentation, then derive disc labels.
    6. Write the output files, and optionally a QC report.

    :param args: list of command-line arguments; ``sys.argv[1:]`` is used when empty or None.
    """
    # initializations
    initz = ''
    initcenter = ''
    fname_initlabel = ''
    file_labelz = 'labelz.nii.gz'
    param = Param()

    # check user arguments
    if not args:
        args = sys.argv[1:]

    # Get parser info
    parser = get_parser()
    arguments = parser.parse(args)
    fname_in = os.path.abspath(arguments["-i"])
    fname_seg = os.path.abspath(arguments['-s'])
    contrast = arguments['-c']
    path_template = os.path.abspath(arguments['-t'])
    scale_dist = arguments['-scale-dist']
    if '-ofolder' in arguments:
        path_output = arguments['-ofolder']
    else:
        path_output = os.curdir
    param.path_qc = arguments.get("-qc", None)
    if '-discfile' in arguments:
        fname_disc = os.path.abspath(arguments['-discfile'])
    else:
        fname_disc = None
    if '-initz' in arguments:
        initz = arguments['-initz']
    if '-initcenter' in arguments:
        initcenter = arguments['-initcenter']
    # if user provided text file, parse and overwrite arguments
    if '-initfile' in arguments:
        # split on any whitespace so that arguments spread over several lines are tokenized correctly;
        # the context manager guarantees the file is closed
        with open(arguments['-initfile'], 'r') as f_initfile:
            arg_initfile = f_initfile.read().split()
        for idx_arg, arg in enumerate(arg_initfile):
            if arg == '-initz':
                initz = [int(x) for x in arg_initfile[idx_arg + 1].split(',')]
            if arg == '-initcenter':
                initcenter = int(arg_initfile[idx_arg + 1])
    if '-initlabel' in arguments:
        # get absolute path of label
        fname_initlabel = os.path.abspath(arguments['-initlabel'])
    if '-param' in arguments:
        param.update(arguments['-param'][0])
    verbose = int(arguments.get('-v'))
    sct.init_sct(log_level=verbose, update=True)  # Update log level
    remove_temp_files = int(arguments['-r'])
    denoise = int(arguments['-denoise'])
    laplacian = int(arguments['-laplacian'])

    path_tmp = sct.tmp_create(basename="label_vertebrae", verbose=verbose)

    # Copying input data to tmp folder
    sct.printv('\nCopying input data to tmp folder...', verbose)
    Image(fname_in).save(os.path.join(path_tmp, "data.nii"))
    Image(fname_seg).save(os.path.join(path_tmp, "segmentation.nii"))

    # Go to temp folder (relative file names below are resolved inside it)
    curdir = os.getcwd()
    os.chdir(path_tmp)

    # Straighten spinal cord
    sct.printv('\nStraighten spinal cord...', verbose)
    # check if warp_curve2straight and warp_straight2curve already exist (i.e. no need to do it another time)
    cache_sig = sct.cache_signature(input_files=[fname_in, fname_seg], )
    cachefile = os.path.join(curdir, "straightening.cache")
    cached_files = ["warp_curve2straight.nii.gz", "warp_straight2curve.nii.gz", "straight_ref.nii.gz"]
    if sct.cache_valid(cachefile, cache_sig) and all(
            os.path.isfile(os.path.join(curdir, f)) for f in cached_files):
        # if they exist, copy them into current folder
        sct.printv('Reusing existing warping field which seems to be valid',
                   verbose, 'warning')
        for fname_cached in cached_files:
            sct.copy(os.path.join(curdir, fname_cached), fname_cached)
        # apply straightening
        sct.run([
            'sct_apply_transfo', '-i', 'data.nii', '-w',
            'warp_curve2straight.nii.gz', '-d', 'straight_ref.nii.gz', '-o',
            'data_straight.nii'
        ])
    else:
        sct_straighten_spinalcord.main(args=[
            '-i',
            'data.nii',
            '-s',
            'segmentation.nii',
            '-r',
            str(remove_temp_files),
            '-v',
            str(verbose),
        ])
        sct.cache_save(cachefile, cache_sig)

    # resample to 0.5mm isotropic to match template resolution
    sct.printv('\nResample to 0.5mm isotropic...', verbose)
    sct.run([
        'sct_resample', '-i', 'data_straight.nii', '-mm', '0.5x0.5x0.5', '-x',
        'linear', '-o', 'data_straightr.nii'
    ], verbose=verbose)

    # Apply straightening to segmentation
    # N.B. Output is RPI
    sct.printv('\nApply straightening to segmentation...', verbose)
    sct.run(
        'isct_antsApplyTransforms -d 3 -i %s -r %s -t %s -o %s -n %s' %
        ('segmentation.nii', 'data_straightr.nii',
         'warp_curve2straight.nii.gz', 'segmentation_straight.nii', 'Linear'),
        verbose=verbose,
        is_sct_binary=True,
    )
    # Threshold segmentation at 0.5
    sct.run([
        'sct_maths', '-i', 'segmentation_straight.nii', '-thr', '0.5', '-o',
        'segmentation_straight.nii'
    ], verbose)

    # If disc label file is provided, label vertebrae using that file instead of automatically
    if fname_disc:
        # Apply straightening to disc-label
        sct.printv('\nApply straightening to disc labels...', verbose)
        sct.run(
            'isct_antsApplyTransforms -d 3 -i %s -r %s -t %s -o %s -n %s' %
            (fname_disc, 'data_straightr.nii', 'warp_curve2straight.nii.gz',
             'labeldisc_straight.nii.gz', 'NearestNeighbor'),
            verbose=verbose,
            is_sct_binary=True,
        )
        label_vert('segmentation_straight.nii',
                   'labeldisc_straight.nii.gz',
                   verbose=1)

    else:
        # create label to identify disc
        sct.printv('\nCreate label to identify disc...', verbose)
        fname_labelz = os.path.join(path_tmp, file_labelz)
        if initz or initcenter:
            if initcenter:
                # find z centered in FOV
                nii = Image('segmentation.nii').change_orientation("RPI")
                nx, ny, nz, nt, px, py, pz, pt = nii.dim  # Get dimensions
                z_center = int(np.round(nz / 2))  # get z_center
                initz = [z_center, initcenter]
            # create single label and output as labels.nii.gz
            label = ProcessLabels(
                'segmentation.nii',
                fname_output='tmp.labelz.nii.gz',
                coordinates=['{},{}'.format(initz[0], initz[1])])
            im_label = label.process('create-seg')
            im_label.data = dilate(
                im_label.data, 3,
                'ball')  # TODO: create a dilation method specific to labels,
            # which does not apply a convolution across all voxels (highly inefficient)
            im_label.save(fname_labelz)
        elif fname_initlabel:
            Image(fname_initlabel).save(fname_labelz)
        else:
            # automatically finds C2-C3 disc
            im_data = Image('data.nii')
            im_seg = Image('segmentation.nii')
            if not remove_temp_files:  # because verbose is here also used for keeping temp files
                verbose_detect_c2c3 = 2
            else:
                verbose_detect_c2c3 = 0
            im_label_c2c3 = detect_c2c3(im_data,
                                        im_seg,
                                        contrast,
                                        verbose=verbose_detect_c2c3)
            ind_label = np.where(im_label_c2c3.data)
            if not np.size(ind_label) == 0:
                im_label_c2c3.data[ind_label] = 3
            else:
                sct.printv(
                    'Automatic C2-C3 detection failed. Please provide manual label with sct_label_utils',
                    1, 'error')
                # exit with a non-zero status: a bare sys.exit() would report success (status 0)
                sys.exit(1)
            im_label_c2c3.save(fname_labelz)

        # dilate label so it is not lost when applying warping
        dilate(Image(fname_labelz), 3, 'ball').save(fname_labelz)

        # Apply straightening to z-label
        sct.printv('\nAnd apply straightening to label...', verbose)
        sct.run(
            'isct_antsApplyTransforms -d 3 -i %s -r %s -t %s -o %s -n %s' %
            (file_labelz, 'data_straightr.nii', 'warp_curve2straight.nii.gz',
             'labelz_straight.nii.gz', 'NearestNeighbor'),
            verbose=verbose,
            is_sct_binary=True,
        )
        # get z value and disk value to initialize labeling
        sct.printv('\nGet z and disc values from straight label...', verbose)
        init_disc = get_z_and_disc_values_from_label('labelz_straight.nii.gz')
        sct.printv('.. ' + str(init_disc), verbose)

        # denoise data
        if denoise:
            sct.printv('\nDenoise data...', verbose)
            sct.run([
                'sct_maths', '-i', 'data_straightr.nii', '-denoise', 'h=0.05',
                '-o', 'data_straightr.nii'
            ], verbose)

        # apply laplacian filtering
        if laplacian:
            sct.printv('\nApply Laplacian filter...', verbose)
            sct.run([
                'sct_maths', '-i', 'data_straightr.nii', '-laplacian', '1',
                '-o', 'data_straightr.nii'
            ], verbose)

        # detect vertebral levels on straight spinal cord
        # (the disc value is decreased by one before being passed to vertebral_detection)
        init_disc[1] = init_disc[1] - 1
        vertebral_detection('data_straightr.nii',
                            'segmentation_straight.nii',
                            contrast,
                            param,
                            init_disc=init_disc,
                            verbose=verbose,
                            path_template=path_template,
                            path_output=path_output,
                            scale_dist=scale_dist)

    # un-straighten labeled spinal cord
    sct.printv('\nUn-straighten labeling...', verbose)
    sct.run(
        'isct_antsApplyTransforms -d 3 -i %s -r %s -t %s -o %s -n %s' %
        ('segmentation_straight_labeled.nii', 'segmentation.nii',
         'warp_straight2curve.nii.gz', 'segmentation_labeled.nii',
         'NearestNeighbor'),
        verbose=verbose,
        is_sct_binary=True,
    )
    # Clean labeled segmentation
    sct.printv(
        '\nClean labeled segmentation (correct interpolation errors)...',
        verbose)
    clean_labeled_segmentation('segmentation_labeled.nii', 'segmentation.nii',
                               'segmentation_labeled.nii')

    # label discs
    sct.printv('\nLabel discs...', verbose)
    label_discs('segmentation_labeled.nii', verbose=verbose)

    # come back
    os.chdir(curdir)

    # Generate output files
    path_seg, file_seg, ext_seg = sct.extract_fname(fname_seg)
    fname_seg_labeled = os.path.join(path_output,
                                     file_seg + '_labeled' + ext_seg)
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(
        os.path.join(path_tmp, "segmentation_labeled.nii"), fname_seg_labeled)
    sct.generate_output_file(
        os.path.join(path_tmp, "segmentation_labeled_disc.nii"),
        os.path.join(path_output, file_seg + '_labeled_discs' + ext_seg))
    # copy straightening files in case subsequent SCT functions need them
    sct.generate_output_file(
        os.path.join(path_tmp, "warp_curve2straight.nii.gz"),
        os.path.join(path_output, "warp_curve2straight.nii.gz"), verbose)
    sct.generate_output_file(
        os.path.join(path_tmp, "warp_straight2curve.nii.gz"),
        os.path.join(path_output, "warp_straight2curve.nii.gz"), verbose)
    sct.generate_output_file(os.path.join(path_tmp, "straight_ref.nii.gz"),
                             os.path.join(path_output, "straight_ref.nii.gz"),
                             verbose)

    # Remove temporary files
    if remove_temp_files == 1:
        sct.printv('\nRemove temporary files...', verbose)
        sct.rmtree(path_tmp)

    # Generate QC report
    if param.path_qc is not None:
        path_qc = os.path.abspath(param.path_qc)
        qc_dataset = arguments.get("-qc-dataset", None)
        qc_subject = arguments.get("-qc-subject", None)
        generate_qc(fname_in,
                    fname_seg=fname_seg_labeled,
                    args=args,
                    path_qc=path_qc,
                    dataset=qc_dataset,
                    subject=qc_subject,
                    process='sct_label_vertebrae')

    sct.display_viewer_syntax([fname_in, fname_seg_labeled],
                              colormaps=['', 'subcortical'],
                              opacities=['1', '0.5'])
    def straighten(self):
        """Straighten the input volume along its centerline and write the warping fields.

        Pipeline (intermediate files live in a timestamped ``tmp.*`` folder):
          1. Copy the input image and centerline into the temp folder and reorient the
             centerline to RPI.
          2. Fit the centerline with ``smooth_centerline`` and build "cross" landmarks
             (center, +x, -x, +y, -y at distance ``gapxy``) every ``gapz`` slices along the
             curved centerline, plus matching landmarks along a straight vertical line at
             the center of the FOV.
          3. Rasterize both landmark sets into label volumes on a padded grid, then
             estimate a rigid transform. This uses ``msct_register_landmarks`` when
             ``self.algo_landmark_rigid`` is set, and
             ``isct_ANTSUseLandmarkImagesToGetAffineTransform`` otherwise.
          4. Estimate B-spline displacement fields in both directions and compose them
             with the rigid transform into ``warp_curve2straight.nii.gz`` and
             ``warp_straight2curve.nii.gz``.
          5. Apply the forward warp to the input image, and measure the distance between
             the warped centerline and the central vertical line. The results are stored
             in ``self.max_distance_straightening`` and ``self.mse_straightening``,
             scaled by the voxel size (reported in mm).

        Any exception raised during the processing steps is printed as a warning and
        not re-raised. Output files are then generated in the current folder, and the
        temp folder is removed when ``remove_temp_files`` is truthy (forced to 0 in
        debug mode).
        """
        # Initialization
        fname_anat = self.input_filename
        fname_centerline = self.centerline_filename
        fname_output = self.output_filename
        gapxy = self.gapxy
        gapz = self.gapz
        padding = self.padding
        remove_temp_files = self.remove_temp_files
        verbose = self.verbose
        interpolation_warp = self.interpolation_warp
        algo_fitting = self.algo_fitting
        window_length = self.window_length
        type_window = self.type_window
        crop = self.crop

        # start timer
        start_time = time.time()

        # get path of the toolbox (only displayed; the command status is not used)
        _status, path_sct = commands.getstatusoutput("echo $SCT_DIR")
        sct.printv(path_sct, verbose)

        if self.debug == 1:
            print("\n*** WARNING: DEBUG MODE ON ***\n")
            fname_anat = (
                "/Users/julien/data/temp/sct_example_data/t2/tmp.150401221259/anat_rpi.nii"
            )  # path_sct+'/testing/sct_testing_data/data/t2/t2.nii.gz'
            fname_centerline = (
                "/Users/julien/data/temp/sct_example_data/t2/tmp.150401221259/centerline_rpi.nii"
            )  # path_sct+'/testing/sct_testing_data/data/t2/t2_seg.nii.gz'
            remove_temp_files = 0
            type_window = "hanning"
            verbose = 2

        # check existence of input files
        sct.check_file_exist(fname_anat, verbose)
        sct.check_file_exist(fname_centerline, verbose)

        # Display arguments
        sct.printv("\nCheck input arguments...", verbose)
        sct.printv("  Input volume ...................... " + fname_anat, verbose)
        sct.printv("  Centerline ........................ " + fname_centerline, verbose)
        sct.printv("  Final interpolation ............... " + interpolation_warp, verbose)
        sct.printv("  Verbose ........................... " + str(verbose), verbose)
        sct.printv("", verbose)

        # Extract path/file/extension
        path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)
        path_centerline, file_centerline, ext_centerline = sct.extract_fname(fname_centerline)

        # create temporary folder (timestamped, relative to the current folder)
        path_tmp = "tmp." + time.strftime("%y%m%d%H%M%S")
        sct.run("mkdir " + path_tmp, verbose)

        # copy files into tmp folder
        sct.run("cp " + fname_anat + " " + path_tmp, verbose)
        sct.run("cp " + fname_centerline + " " + path_tmp, verbose)

        # go to tmp folder: every relative "tmp.*" file name below refers to this folder
        os.chdir(path_tmp)

        try:
            # Change orientation of the input centerline into RPI
            sct.printv("\nOrient centerline to RPI orientation...", verbose)
            fname_centerline_orient = file_centerline + "_rpi.nii.gz"
            set_orientation(file_centerline + ext_centerline, "RPI", fname_centerline_orient)

            # Get dimension
            sct.printv("\nGet dimensions...", verbose)
            nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(fname_centerline_orient)
            sct.printv(".. matrix size: " + str(nx) + " x " + str(ny) + " x " + str(nz), verbose)
            sct.printv(".. voxel size:  " + str(px) + "mm x " + str(py) + "mm x " + str(pz) + "mm", verbose)

            # smooth centerline
            x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
                fname_centerline_orient,
                algo_fitting=algo_fitting,
                type_window=type_window,
                window_length=window_length,
                verbose=verbose,
            )

            # Get coordinates of landmarks along curved centerline
            # ==========================================================================================
            sct.printv("\nGet coordinates of landmarks along curved centerline...", verbose)
            # Landmarks are created along the curved centerline every z=gapz. Each one is a "cross"
            # whose arms are gapxy away from the centerline. In voxel space!!!

            # find z indices along centerline given a specific gap: iz_curved
            nz_nonz = len(z_centerline)
            nb_landmark = int(round(float(nz_nonz) / gapz))

            if nb_landmark == 0:
                nb_landmark = 1

            if nb_landmark == 1:
                iz_curved = [0]
            else:
                iz_curved = [i * gapz for i in range(0, nb_landmark - 1)]

            # always put a landmark on the last centerline point
            iz_curved.append(nz_nonz - 1)
            n_iz_curved = len(iz_curved)

            from msct_types import Coordinate

            landmark_curved = []
            landmark_curved_value = 1

            # TODO: THIS PART IS SLOW AND CAN BE MADE FASTER (presumably dominated by the sympy solve() calls)
            for iz in range(min(iz_curved), max(iz_curved) + 1, 1):
                if iz in iz_curved:
                    # Plane orthogonal to the centerline at iz: a*x + b*y + c*z + d = 0, with (a, b, c)
                    # the centerline derivative and (x, y, z) the fitted centerline point.
                    a = x_centerline_deriv[iz]
                    b = y_centerline_deriv[iz]
                    c = z_centerline_deriv[iz]
                    x = x_centerline_fit[iz]
                    y = y_centerline_fit[iz]
                    z = z_centerline[iz]
                    d = -(a * x + b * y + c * z)
                    # -1.0 (not -1) so that an integer c cannot trigger Python 2 floor division.
                    # NOTE(review): c == 0 (centerline lying in an axial plane) is not handled.
                    inv_c = -1.0 / c

                    # landmark at the center of the cross
                    coord = Coordinate([0, 0, 0, landmark_curved_value])
                    coord.x, coord.y, coord.z = x_centerline_fit[iz], y_centerline_fit[iz], z_centerline[iz]
                    landmark_curved.append(coord)

                    # the four arms of the cross: +x, -x, +y, -y (label values +1..+4)
                    cross_coordinates = [
                        Coordinate([0, 0, 0, landmark_curved_value + 1]),
                        Coordinate([0, 0, 0, landmark_curved_value + 2]),
                        Coordinate([0, 0, 0, landmark_curved_value + 3]),
                        Coordinate([0, 0, 0, landmark_curved_value + 4]),
                    ]

                    # the +x / -x arms keep the centerline y coordinate
                    cross_coordinates[0].y = y_centerline_fit[iz]
                    cross_coordinates[1].y = y_centerline_fit[iz]

                    # x and z of the +x / -x arms: force the landmark into the orthogonal plane,
                    # at distance gapxy from the curve
                    x_n = Symbol("x_n")
                    cross_coordinates[1].x, cross_coordinates[0].x = solve(
                        (x_n - x) ** 2 + (inv_c * (a * x_n + b * y + d) - z) ** 2 - gapxy ** 2, x_n
                    )  # x for -x and +x
                    cross_coordinates[0].z = inv_c * (a * cross_coordinates[0].x + b * y + d)  # z for +x
                    cross_coordinates[1].z = inv_c * (a * cross_coordinates[1].x + b * y + d)  # z for -x

                    # the +y / -y arms keep the centerline x coordinate
                    cross_coordinates[2].x = x_centerline_fit[iz]
                    cross_coordinates[3].x = x_centerline_fit[iz]

                    # y and z of the +y / -y arms (same constraint as above)
                    y_n = Symbol("y_n")
                    cross_coordinates[3].y, cross_coordinates[2].y = solve(
                        (y_n - y) ** 2 + (inv_c * (a * x + b * y_n + d) - z) ** 2 - gapxy ** 2, y_n
                    )  # y for -y and +y
                    cross_coordinates[2].z = inv_c * (a * x + b * cross_coordinates[2].y + d)  # z for +y
                    cross_coordinates[3].z = inv_c * (a * x + b * cross_coordinates[3].y + d)  # z for -y

                    landmark_curved.extend(cross_coordinates)
                    landmark_curved_value += 5
                elif self.all_labels == 1:
                    # optionally add a single landmark on every centerline slice between two crosses
                    landmark_curved.append(
                        Coordinate(
                            [x_centerline_fit[iz], y_centerline_fit[iz], z_centerline[iz], landmark_curved_value],
                            mode="continuous",
                        )
                    )
                    landmark_curved_value += 1

            # Get coordinates of landmarks along straight centerline
            # ==========================================================================================
            sct.printv("\nGet coordinates of landmarks along straight centerline...", verbose)
            landmark_straight = []

            # z indices on the straight centerline: cumulative (rounded) Euclidean distance between
            # consecutive curved landmarks (approximation curve --> line)
            # TODO: DO NOT APPROXIMATE CURVE --> LINE
            iz_straight = [0] * n_iz_curved
            iz_straight[0] = iz_curved[0]
            for index in range(1, n_iz_curved):
                # vector between two consecutive points on the curved centerline
                vector_centerline = [
                    x_centerline_fit[iz_curved[index]] - x_centerline_fit[iz_curved[index - 1]],
                    y_centerline_fit[iz_curved[index]] - y_centerline_fit[iz_curved[index - 1]],
                    z_centerline[iz_curved[index]] - z_centerline[iz_curved[index - 1]],
                ]
                norm_vector_centerline = linalg.norm(vector_centerline, ord=2)
                # round to closest integer value and accumulate
                iz_straight[index] = iz_straight[index - 1] + int(round(norm_vector_centerline, 0))

            # the straight centerline goes through the center of the FOV
            x0 = int(round(nx / 2))
            y0 = int(round(ny / 2))
            landmark_curved_value = 1
            for iz in range(min(iz_curved), max(iz_curved) + 1, 1):
                if iz in iz_curved:
                    index = iz_curved.index(iz)
                    # same label values (and order) as the curved cross: center, +x, -x, +y, -y
                    landmark_straight.append(Coordinate([x0, y0, iz_straight[index], landmark_curved_value]))
                    landmark_straight.append(
                        Coordinate([x0 + gapxy, y0, iz_straight[index], landmark_curved_value + 1])
                    )
                    landmark_straight.append(
                        Coordinate([x0 - gapxy, y0, iz_straight[index], landmark_curved_value + 2])
                    )
                    landmark_straight.append(
                        Coordinate([x0, y0 + gapxy, iz_straight[index], landmark_curved_value + 3])
                    )
                    landmark_straight.append(
                        Coordinate([x0, y0 - gapxy, iz_straight[index], landmark_curved_value + 4])
                    )
                    landmark_curved_value += 5
                elif self.all_labels == 1:
                    landmark_straight.append(Coordinate([x0, y0, iz, landmark_curved_value]))
                    landmark_curved_value += 1

            # Create NIFTI volumes with landmarks
            # ==========================================================================================
            # Pad input volume to deal with the fact that some landmarks on the curved centerline might be outside the FOV
            # N.B. IT IS VERY IMPORTANT TO PAD ALSO ALONG X and Y, OTHERWISE SOME LANDMARKS MIGHT GET OUT OF THE FOV!!!
            sct.printv("\nPad input volume to account for landmarks that fall outside the FOV...", verbose)
            pad = str(padding) + "x" + str(padding) + "x" + str(padding) + "vox"
            sct.run(
                "isct_c3d " + fname_centerline_orient + " -pad " + pad + " " + pad + " 0 -o tmp.centerline_pad.nii.gz",
                verbose,
            )

            # Open padded centerline for reading (its grid and header are reused for the label volumes)
            sct.printv("\nOpen padded centerline for reading...", verbose)
            nii_centerline_pad = load("tmp.centerline_pad.nii.gz")
            data = nii_centerline_pad.get_data()
            hdr = nii_centerline_pad.get_header()

            if self.algo_landmark_rigid is not None and self.algo_landmark_rigid != "None":
                # Reorganize landmarks as plain [x, y, z] lists
                points_fixed = [[coord.x, coord.y, coord.z] for coord in landmark_straight]
                points_moving = [[coord.x, coord.y, coord.z] for coord in landmark_curved]

                # Register curved landmarks on straight landmarks based on python implementation
                sct.printv("\nComputing rigid transformation (algo=" + self.algo_landmark_rigid + ") ...", verbose)
                import msct_register_landmarks

                (
                    rotation_matrix,
                    translation_array,
                    points_moving_reg,
                ) = msct_register_landmarks.getRigidTransformFromLandmarks(
                    points_fixed, points_moving, constraints=self.algo_landmark_rigid, show=False
                )

                # reorganize registered points; label values are 1..N in registration output order
                landmark_curved_rigid = []
                for ind, point in enumerate(points_moving_reg):
                    coord = Coordinate()
                    coord.x, coord.y, coord.z, coord.value = point[0], point[1], point[2], ind + 1
                    landmark_curved_rigid.append(coord)

                # Create volumes containing curved and straight landmarks
                data_curved_landmarks = data * 0
                data_curved_rigid_landmarks = data * 0
                data_straight_landmarks = data * 0

                # Loop across cross index
                for index in range(0, len(landmark_curved_rigid)):
                    x, y, z = (
                        int(round(landmark_curved[index].x)),
                        int(round(landmark_curved[index].y)),
                        int(round(landmark_curved[index].z)),
                    )

                    # attribute landmark_value to the voxel and its neighbours (3x3x3, shifted by padding)
                    data_curved_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_curved[index].value

                    # get x, y and z coordinates of rigidly registered curved landmark (rounded)
                    x, y, z = (
                        int(round(landmark_curved_rigid[index].x)),
                        int(round(landmark_curved_rigid[index].y)),
                        int(round(landmark_curved_rigid[index].z)),
                    )

                    data_curved_rigid_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_curved_rigid[index].value

                    # get x, y and z coordinates of straight landmark (rounded to closest integer)
                    x, y, z = (
                        int(round(landmark_straight[index].x)),
                        int(round(landmark_straight[index].y)),
                        int(round(landmark_straight[index].z)),
                    )

                    data_straight_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_straight[index].value

                # Write NIFTI volumes (labels stored as uint32)
                sct.printv("\nWrite NIFTI volumes...", verbose)
                hdr.set_data_dtype("uint32")
                img = Nifti1Image(data_curved_landmarks, None, hdr)
                save(img, "tmp.landmarks_curved.nii.gz")
                sct.printv(".. File created: tmp.landmarks_curved.nii.gz", verbose)
                hdr.set_data_dtype("uint32")
                img = Nifti1Image(data_curved_rigid_landmarks, None, hdr)
                save(img, "tmp.landmarks_curved_rigid.nii.gz")
                sct.printv(".. File created: tmp.landmarks_curved_rigid.nii.gz", verbose)
                img = Nifti1Image(data_straight_landmarks, None, hdr)
                save(img, "tmp.landmarks_straight.nii.gz")
                sct.printv(".. File created: tmp.landmarks_straight.nii.gz", verbose)

                # write the rigid transformation as an ITK "Insight Transform File V1.0" text file
                # NOTE(review): only the x and z translations are negated; presumably a coordinate
                # convention conversion — confirm against msct_register_landmarks.
                with open("tmp.curve2straight_rigid.txt", "w") as text_file:
                    text_file.write("#Insight Transform File V1.0\n")
                    text_file.write("#Transform 0\n")
                    text_file.write("Transform: AffineTransform_double_3_3\n")
                    text_file.write(
                        "Parameters: %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f\n"
                        % (
                            rotation_matrix[0, 0],
                            rotation_matrix[0, 1],
                            rotation_matrix[0, 2],
                            rotation_matrix[1, 0],
                            rotation_matrix[1, 1],
                            rotation_matrix[1, 2],
                            rotation_matrix[2, 0],
                            rotation_matrix[2, 1],
                            rotation_matrix[2, 2],
                            -translation_array[0, 0],
                            translation_array[0, 1],
                            -translation_array[0, 2],
                        )
                    )
                    text_file.write("FixedParameters: 0 0 0\n")

            else:
                # Create volumes containing curved and straight landmarks
                data_curved_landmarks = data * 0
                data_straight_landmarks = data * 0

                # Loop across cross index
                for index in range(0, len(landmark_curved)):
                    x, y, z = (
                        int(round(landmark_curved[index].x)),
                        int(round(landmark_curved[index].y)),
                        int(round(landmark_curved[index].z)),
                    )

                    # attribute landmark_value to the voxel and its neighbours (3x3x3, shifted by padding)
                    data_curved_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_curved[index].value

                    # get x, y and z coordinates of straight landmark (rounded to closest integer)
                    x, y, z = (
                        int(round(landmark_straight[index].x)),
                        int(round(landmark_straight[index].y)),
                        int(round(landmark_straight[index].z)),
                    )

                    data_straight_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_straight[index].value

                # Write NIFTI volumes (labels stored as uint32)
                sct.printv("\nWrite NIFTI volumes...", verbose)
                hdr.set_data_dtype("uint32")
                img = Nifti1Image(data_curved_landmarks, None, hdr)
                save(img, "tmp.landmarks_curved.nii.gz")
                sct.printv(".. File created: tmp.landmarks_curved.nii.gz", verbose)
                img = Nifti1Image(data_straight_landmarks, None, hdr)
                save(img, "tmp.landmarks_straight.nii.gz")
                sct.printv(".. File created: tmp.landmarks_straight.nii.gz", verbose)

                # Estimate deformation field by pairing landmarks
                # ==========================================================================================
                # convert landmarks to INT
                sct.printv("\nConvert landmarks to INT...", verbose)
                sct.run("isct_c3d tmp.landmarks_straight.nii.gz -type int -o tmp.landmarks_straight.nii.gz", verbose)
                sct.run("isct_c3d tmp.landmarks_curved.nii.gz -type int -o tmp.landmarks_curved.nii.gz", verbose)

                # This stands to avoid overlapping between landmarks
                sct.printv("\nMake sure all labels between landmark_curved and landmark_curved match...", verbose)
                label_process_straight = ProcessLabels(
                    fname_label="tmp.landmarks_straight.nii.gz",
                    fname_output="tmp.landmarks_straight.nii.gz",
                    fname_ref="tmp.landmarks_curved.nii.gz",
                    verbose=verbose,
                )
                label_process_straight.process("remove")
                label_process_curved = ProcessLabels(
                    fname_label="tmp.landmarks_curved.nii.gz",
                    fname_output="tmp.landmarks_curved.nii.gz",
                    fname_ref="tmp.landmarks_straight.nii.gz",
                    verbose=verbose,
                )
                label_process_curved.process("remove")

                # Estimate rigid transformation
                sct.printv("\nEstimate rigid transformation between paired landmarks...", verbose)
                sct.run(
                    "isct_ANTSUseLandmarkImagesToGetAffineTransform tmp.landmarks_straight.nii.gz tmp.landmarks_curved.nii.gz rigid tmp.curve2straight_rigid.txt",
                    verbose,
                )

                # Apply rigid transformation
                sct.printv("\nApply rigid transformation to curved landmarks...", verbose)
                # equivalent command: sct_apply_transfo -i tmp.landmarks_curved.nii.gz -o tmp.landmarks_curved_rigid.nii.gz
                #                     -d tmp.landmarks_straight.nii.gz -w tmp.curve2straight_rigid.txt -x nn
                Transform(
                    input_filename="tmp.landmarks_curved.nii.gz",
                    source_reg="tmp.landmarks_curved_rigid.nii.gz",
                    output_filename="tmp.landmarks_straight.nii.gz",
                    warp="tmp.curve2straight_rigid.txt",
                    interp="nn",
                    verbose=verbose,
                ).apply()

            if verbose == 2:
                from mpl_toolkits.mplot3d import Axes3D
                import matplotlib.pyplot as plt

                fig = plt.figure()
                ax = Axes3D(fig)
                ax.plot(x_centerline_fit, y_centerline_fit, z_centerline, zdir="z")
                ax.plot(
                    [coord.x for coord in landmark_curved],
                    [coord.y for coord in landmark_curved],
                    [coord.z for coord in landmark_curved],
                    ".",
                )
                ax.plot(
                    [coord.x for coord in landmark_straight],
                    [coord.y for coord in landmark_straight],
                    [coord.z for coord in landmark_straight],
                    "r.",
                )
                if self.algo_landmark_rigid is not None and self.algo_landmark_rigid != "None":
                    ax.plot(
                        [coord.x for coord in landmark_curved_rigid],
                        [coord.y for coord in landmark_curved_rigid],
                        [coord.z for coord in landmark_curved_rigid],
                        "b.",
                    )
                ax.set_xlabel("x")
                ax.set_ylabel("y")
                ax.set_zlabel("z")
                plt.show()

            # This stands to avoid overlapping between landmarks
            sct.printv("\nMake sure all labels between landmark_curved and landmark_curved match...", verbose)
            label_process = ProcessLabels(
                fname_label="tmp.landmarks_straight.nii.gz",
                fname_output="tmp.landmarks_straight.nii.gz",
                fname_ref="tmp.landmarks_curved_rigid.nii.gz",
                verbose=verbose,
            )
            label_process.process("remove")
            label_process = ProcessLabels(
                fname_label="tmp.landmarks_curved_rigid.nii.gz",
                fname_output="tmp.landmarks_curved_rigid.nii.gz",
                fname_ref="tmp.landmarks_straight.nii.gz",
                verbose=verbose,
            )
            label_process.process("remove")

            # Estimate b-spline transformation curve --> straight
            sct.printv("\nEstimate b-spline transformation: curve --> straight...", verbose)
            sct.run(
                "isct_ANTSUseLandmarkImagesToGetBSplineDisplacementField tmp.landmarks_straight.nii.gz tmp.landmarks_curved_rigid.nii.gz tmp.warp_curve2straight.nii.gz "
                + self.bspline_meshsize
                + " "
                + self.bspline_numberOfLevels
                + " "
                + self.bspline_order
                + " 0",
                verbose,
            )

            # remove padding for straight labels
            if crop == 1:
                ImageCropper(
                    input_file="tmp.landmarks_straight.nii.gz",
                    output_file="tmp.landmarks_straight_crop.nii.gz",
                    dim="0,1,2",
                    bmax=True,
                    verbose=verbose,
                ).crop()
            else:
                sct.run("cp tmp.landmarks_straight.nii.gz tmp.landmarks_straight_crop.nii.gz", verbose)

            # Concatenate rigid and non-linear transformations...
            sct.printv("\nConcatenate rigid and non-linear transformations...", verbose)
            # NOTE(review): an older comment warned that isct_ComposeMultiTransform may return a
            # non-null status (it used to be run through commands.getstatusoutput); it is run with
            # sct.run here — confirm this does not abort on that status.
            cmd = "isct_ComposeMultiTransform 3 tmp.curve2straight.nii.gz -R tmp.landmarks_straight_crop.nii.gz tmp.warp_curve2straight.nii.gz tmp.curve2straight_rigid.txt"
            sct.printv(cmd, verbose, "code")
            sct.run(cmd, verbose)

            # Estimate b-spline transformation straight --> curve
            # TODO: invert warping field instead of estimating a new one
            sct.printv("\nEstimate b-spline transformation: straight --> curve...", verbose)
            sct.run(
                "isct_ANTSUseLandmarkImagesToGetBSplineDisplacementField tmp.landmarks_curved_rigid.nii.gz tmp.landmarks_straight.nii.gz tmp.warp_straight2curve.nii.gz "
                + self.bspline_meshsize
                + " "
                + self.bspline_numberOfLevels
                + " "
                + self.bspline_order
                + " 0",
                verbose,
            )

            # Concatenate rigid and non-linear transformations (inverse rigid + straight->curve field)
            sct.printv("\nConcatenate rigid and non-linear transformations...", verbose)
            cmd = (
                "isct_ComposeMultiTransform 3 tmp.straight2curve.nii.gz -R "
                + file_anat
                + ext_anat
                + " -i tmp.curve2straight_rigid.txt tmp.warp_straight2curve.nii.gz"
            )
            sct.printv(cmd, verbose, "code")
            sct.run(cmd, verbose)

            # Apply transformation to input image
            sct.printv("\nApply transformation to input image...", verbose)
            Transform(
                input_filename=str(file_anat + ext_anat),
                source_reg="tmp.anat_rigid_warp.nii.gz",
                output_filename="tmp.landmarks_straight_crop.nii.gz",
                interp=interpolation_warp,
                warp="tmp.curve2straight.nii.gz",
                verbose=verbose,
            ).apply()

            # Compute the error between the straightened centerline/segmentation and the central
            # vertical line. Ideally, the error should be zero.
            sct.printv("\nApply transformation to centerline image...", verbose)
            # equivalent command: sct_apply_transfo -i <centerline_rpi> -o tmp.centerline_straight.nii.gz
            #                     -d tmp.landmarks_straight_crop.nii.gz -x nn -w tmp.curve2straight.nii.gz
            Transform(
                input_filename=fname_centerline_orient,
                source_reg="tmp.centerline_straight.nii.gz",
                output_filename="tmp.landmarks_straight_crop.nii.gz",
                interp="nn",
                warp="tmp.curve2straight.nii.gz",
                verbose=verbose,
            ).apply()
            from msct_image import Image

            file_centerline_straight = Image("tmp.centerline_straight.nii.gz", verbose=verbose)
            coordinates_centerline = file_centerline_straight.getNonZeroCoordinates(sorting="z")
            # NOTE(review): range() excludes the last z slice. The per-slice mean is value-weighted
            # but divided by the voxel count, so it is exact only when all values are 1.
            # Confirm both are intended.
            mean_coord = []
            for z in range(coordinates_centerline[0].z, coordinates_centerline[-1].z):
                mean_coord.append(
                    mean(
                        [
                            [coord.x * coord.value, coord.y * coord.value]
                            for coord in coordinates_centerline
                            if coord.z == z
                        ],
                        axis=0,
                    )
                )

            # compute error (in mm) between the straightened centerline and the center of the FOV
            from math import sqrt

            x0 = file_centerline_straight.data.shape[0] / 2.0
            y0 = file_centerline_straight.data.shape[1] / 2.0
            count_mean = 0
            for coord_z in mean_coord:
                # slices without centerline voxels yield a NaN mean and are skipped
                if not isnan(sum(coord_z)):
                    dist = ((x0 - coord_z[0]) * px) ** 2 + ((y0 - coord_z[1]) * py) ** 2
                    self.mse_straightening += dist
                    dist = sqrt(dist)
                    if dist > self.max_distance_straightening:
                        self.max_distance_straightening = dist
                    count_mean += 1
            if count_mean > 0:
                self.mse_straightening = sqrt(self.mse_straightening / float(count_mean))
            else:
                sct.printv(
                    "WARNING: no centerline slice available to compute the straightening accuracy.",
                    1,
                    "warning",
                )

        except Exception as e:
            # best-effort: report and continue so that output generation below is still attempted
            sct.printv("WARNING: Exception during Straightening:", 1, "warning")
            print(e)
        finally:
            # always leave the temporary folder, even on KeyboardInterrupt
            os.chdir("..")

        # Generate output file (in current folder)
        # TODO: do not uncompress the warping field, it is too time consuming!
        sct.printv("\nGenerate output file (in current folder)...", verbose)
        sct.generate_output_file(
            path_tmp + "/tmp.curve2straight.nii.gz", "warp_curve2straight.nii.gz", verbose
        )  # warping field
        sct.generate_output_file(
            path_tmp + "/tmp.straight2curve.nii.gz", "warp_straight2curve.nii.gz", verbose
        )  # warping field
        if fname_output == "":
            fname_straight = sct.generate_output_file(
                path_tmp + "/tmp.anat_rigid_warp.nii.gz", file_anat + "_straight" + ext_anat, verbose
            )  # straightened anatomic
        else:
            fname_straight = sct.generate_output_file(
                path_tmp + "/tmp.anat_rigid_warp.nii.gz", fname_output, verbose
            )  # straightened anatomic
        # Remove temporary files
        if remove_temp_files:
            sct.printv("\nRemove temporary files...", verbose)
            sct.run("rm -rf " + path_tmp, verbose)

        sct.printv("\nDone!\n", verbose)

        sct.printv("Maximum x-y error = " + str(round(self.max_distance_straightening, 2)) + " mm", verbose, "bold")
        sct.printv(
            "Accuracy of straightening (MSE) = " + str(round(self.mse_straightening, 2)) + " mm", verbose, "bold"
        )
        # display elapsed time
        elapsed_time = time.time() - start_time
        sct.printv("\nFinished! Elapsed time: " + str(int(round(elapsed_time))) + "s", verbose)
        sct.printv("\nTo view results, type:", verbose)
        sct.printv("fslview " + fname_straight + " &\n", verbose, "info")
def _test_error_result(status, output, path_data, fname_log=None, extra=None):
    """Build the (status, output, DataFrame) triple that test() returns on an early failure.

    :param status: status code reported to the caller.
    :param output: accumulated log text.
    :param path_data: subject folder, used as the DataFrame row index.
    :param fname_log: if given, ``output`` is first written (mode 'w') to this log file.
    :param extra: optional dict of additional DataFrame columns.
    :return: tuple (status, output, DataFrame)
    """
    if fname_log is not None:
        write_to_log_file(fname_log, output, 'w')
    data = {'status': int(status), 'output': output}
    if extra:
        data.update(extra)
    return status, output, DataFrame(data=data, index=[path_data])


def test(path_data='', parameters=''):
    """Run sct_label_vertebrae on one testing subject and compare the result with manual labels.

    :param path_data: path to the subject's testing-data folder; it is prefixed to the input
        file paths and used as the row index of the returned DataFrame.
    :param parameters: sct_label_vertebrae command-line flags as a single string. A t2 default
        is used when empty.
    :return: tuple (status, output, results).
        status: 0 on success; the command's status (or 1 if it crashed) when
        sct_label_vertebrae fails; 1 on a flag-parsing / contrast / label-extraction error;
        200 if an input file is missing; 201 if the manual labels cannot be read;
        99 when a quality threshold is exceeded.
        output: the accumulated log text.
        results: a one-row DataFrame with the status, output and, when available,
        rmse / max_dist / diff_man / duration.
    """
    # initializations: metrics stay NaN if the test fails before they are computed
    output = ''
    rmse = float('NaN')
    max_dist = float('NaN')
    diff_manual_result = float('NaN')

    if not parameters:
        parameters = '-i t2/t2.nii.gz -s t2/t2_seg.nii.gz -c t2 -initfile t2/init_label_vertebrae.txt'

    # retrieve flags
    try:
        parser = sct_label_vertebrae.get_parser()
        dict_param = parser.parse(parameters.split(), check_file_exist=False)
        dict_param_with_path = parser.add_path_to_file(deepcopy(dict_param),
                                                       path_data,
                                                       input_file=True)
        # update template path because the previous command wrongly adds path to testing data
        dict_param_with_path['-t'] = dict_param['-t']
        param_with_path = parser.dictionary_to_string(dict_param_with_path)
    # in case not all mandatory flags are filled
    except SyntaxError as err:
        print(err)
        # return a string (not the exception object) like every other code path
        return _test_error_result(1, str(err), path_data)

    # create output folder to deal with multithreading (i.e., we don't want to have outputs
    # from several subjects in the current directory)
    import time, random
    subject_folder = path_data.split('/')
    if subject_folder[-1] == '' and len(subject_folder) > 1:
        subject_folder = subject_folder[-2]
    else:
        subject_folder = subject_folder[-1]
    path_output = sct.slash_at_the_end(
        'sct_label_vertebrae_' + subject_folder + '_' +
        time.strftime("%y%m%d%H%M%S") + '_' + str(random.randint(1, 1000000)),
        slash=1)
    os.mkdir(path_output)
    param_with_path += ' -ofolder ' + path_output
    # log file
    fname_log = path_output + 'output.log'

    # Extract contrast: the input is expected to look like '<contrast>/<file>'
    contrast = ''
    if dict_param['-i'].startswith('/'):
        dict_param['-i'] = dict_param['-i'][1:]
    input_split = dict_param['-i'].split('/')
    if len(input_split) == 2:
        contrast = input_split[0]
    if not contrast:  # if no contrast folder, send error.
        output += '\nERROR: when extracting the contrast folder from input file in command line: ' + dict_param[
            '-i'] + ' for ' + path_data
        # NOTE(review): 'dice_segmentation' looks copied from another test; kept so the
        # returned columns are unchanged.
        return _test_error_result(1, output, path_data, fname_log,
                                  extra={'dice_segmentation': float('nan')})

    # Check if input files exist
    for flag in ('-i', '-s'):
        if not os.path.isfile(dict_param_with_path[flag]):
            output += '\nERROR: This file does not exist: ' + dict_param_with_path[flag]
            return _test_error_result(200, output, path_data, fname_log)

    # open ground truth
    fname_labels_manual = path_data + contrast + '/' + contrast + '_labeled_center_manual.nii.gz'
    try:
        label_manual = ProcessLabels(fname_labels_manual)
        list_label_manual = label_manual.image_input.getNonZeroCoordinates(
            sorting='value')
    # SystemExit is caught too: sct helpers may exit the interpreter on failure
    # (NOTE(review): confirm); KeyboardInterrupt is deliberately not swallowed.
    except (Exception, SystemExit):
        output += '\nERROR: cannot find file: ' + fname_labels_manual
        return _test_error_result(201, output, path_data, fname_log)

    cmd = 'sct_label_vertebrae ' + param_with_path
    output = '\n====================================================================================================\n' + cmd + '\n====================================================================================================\n\n'  # copy command
    time_start = time.time()
    try:
        status, o = sct.run(cmd, 0)
    except (Exception, SystemExit):
        status, o = 1, '\nERROR: Function crashed!'
    output += o
    duration = time.time() - time_start

    if status == 0:
        # copy input data (for easier debugging)
        sct.run('cp ' + dict_param_with_path['-i'] + ' ' + path_output,
                verbose=0)
        # extract center of vertebral labels
        path_seg, file_seg, ext_seg = sct.extract_fname(dict_param['-s'])
        fname_seg_labeled = path_output + file_seg + '_labeled.nii.gz'
        fname_labels_center = path_output + contrast + '_seg_labeled_center.nii.gz'
        try:
            sct.run('sct_label_utils -i ' + fname_seg_labeled +
                    ' -vert-body 0 -o ' + fname_labels_center,
                    verbose=0)
            label_results = ProcessLabels(fname_labels_center)
            list_label_results = label_results.image_input.getNonZeroCoordinates(
                sorting='value')
            # voxel sizes, used to convert voxel offsets into mm
            _, _, _, _, px, py, pz, _ = label_results.image_input.dim
        except (Exception, SystemExit):
            # report the files actually involved, not a path derived from the contrast name
            output += ('\nERROR: cannot open file: ' + fname_labels_center +
                       ' (extracted from ' + fname_seg_labeled + ')')
            return _test_error_result(1, output, path_data, fname_log)

        # pair each manual label with the first result label having the same (rounded) value
        results_by_value = {}
        for coord in list_label_results:
            results_by_value.setdefault(round(coord.value), coord)

        mse = 0.0
        max_dist = 0.0
        for coord_manual in list_label_manual:
            coord = results_by_value.get(round(coord_manual.value))
            if coord is None:
                continue
            delta = [(coord_manual.x - coord.x) * px,
                     (coord_manual.y - coord.y) * py,
                     (coord_manual.z - coord.z) * pz]
            # Calculate MSE (averaged over the three axes)
            mse += (delta[0]**2 + delta[1]**2 + delta[2]**2) / float(3)
            # Calculate Euclidean distance between the paired labels
            dist = linalg.norm(delta)
            if dist > max_dist:
                max_dist = dist

        if list_label_manual:
            rmse = sqrt(mse / len(list_label_manual))
        else:
            # avoid ZeroDivisionError: keep rmse NaN and flag the subject
            status = 99
            output += '\nWARNING: no manual label found in ' + fname_labels_manual
        # calculate number of label mismatch
        diff_manual_result = len(list_label_manual) - len(list_label_results)

        # check if metrics exceed their thresholds
        th_rmse = 2
        if rmse > th_rmse:
            status = 99
            output += '\nWARNING: RMSE = ' + str(rmse) + ' > ' + str(th_rmse)
        th_max_dist = 4
        if max_dist > th_max_dist:
            status = 99
            output += '\nWARNING: Max distance = ' + str(
                max_dist) + ' > ' + str(th_max_dist)
        th_diff_manual_result = 3
        if abs(diff_manual_result) > th_diff_manual_result:
            status = 99
            output += '\nWARNING: Diff manual-result = ' + str(
                diff_manual_result) + ' > ' + str(th_diff_manual_result)

    # transform results into Pandas structure
    results = DataFrame(data={
        'status': int(status),
        'output': output,
        'rmse': rmse,
        'max_dist': max_dist,
        'diff_man': diff_manual_result,
        'duration [s]': duration
    },
                        index=[path_data])

    # write log file
    write_to_log_file(fname_log, output, 'w')

    return status, output, results