def plan_ref(self):
    """
    Generate a plane in the reference space for each label present in the input image.

    For every non-zero voxel of ``self.image_input``, the whole slice of the
    output located at the voxel's z index (third axis) is filled with the
    voxel value. Negative labels keep their sign in the output.

    :return: float32 image built from ``self.image_ref`` holding the planes
    """
    # the output starts as an all-zero image built from the reference image
    image_output = Image(self.image_ref, self.verbose)
    image_output.data *= 0

    # Negative and positive labels are separated into two zeroed images so
    # that getNonZeroCoordinates can be applied to each of them.
    image_input_neg = Image(self.image_input, self.verbose).copy()
    image_input_pos = Image(self.image_input, self.verbose).copy()
    image_input_neg.data *= 0
    image_input_pos.data *= 0

    # Boolean-mask assignment replaces the former voxel-by-voxel index loops.
    data_input = self.image_input.data
    mask_neg = data_input < 0
    mask_pos = data_input > 0
    # negative labels are stored with a flipped sign (restored further down)
    image_input_neg.data[mask_neg] = -data_input[mask_neg]
    image_input_pos.data[mask_pos] = data_input[mask_pos]

    coordinates_input_neg = image_input_neg.getNonZeroCoordinates()
    coordinates_input_pos = image_input_pos.getNonZeroCoordinates()

    image_output.changeType('float32')
    for coord in coordinates_input_neg:
        # PB (noted by the original author): takes the int value of coord.value
        image_output.data[:, :, int(coord.z)] = -coord.value
    for coord in coordinates_input_pos:
        image_output.data[:, :, int(coord.z)] = coord.value

    return image_output
Esempio n. 2
0
    def plan_ref(self):
        """
        Build, in the reference space, one plane per label of the input image.

        Each non-zero voxel of ``self.image_input`` fills the complete slice of
        the output found at its z index (third axis) with the voxel value;
        negative labels end up negative in the output as well.

        :return: float32 image built from ``self.image_ref``
        """
        # blank output image built from the reference image
        image_output = Image(self.image_ref, self.verbose)
        image_output.data *= 0

        # two blank images: one receives the negative labels, one the positive
        image_input_neg = Image(self.image_input, self.verbose).copy()
        image_input_pos = Image(self.image_input, self.verbose).copy()
        image_input_neg.data *= 0
        image_input_pos.data *= 0

        # negative labels are written with a flipped sign so that
        # getNonZeroCoordinates can be applied to them
        X, Y, Z = (self.image_input.data < 0).nonzero()
        for x, y, z in zip(X, Y, Z):
            image_input_neg.data[x, y, z] = -self.image_input.data[x, y, z]

        X_pos, Y_pos, Z_pos = (self.image_input.data > 0).nonzero()
        for x, y, z in zip(X_pos, Y_pos, Z_pos):
            image_input_pos.data[x, y, z] = self.image_input.data[x, y, z]

        coordinates_input_neg = image_input_neg.getNonZeroCoordinates()
        coordinates_input_pos = image_input_pos.getNonZeroCoordinates()

        image_output.changeType('float32')
        for coord in coordinates_input_neg:
            # PB (noted by the original author): takes the int value of coord.value
            slice_index = int(coord.z)
            image_output.data[:, :, slice_index] = -coord.value
        for coord in coordinates_input_pos:
            slice_index = int(coord.z)
            image_output.data[:, :, slice_index] = coord.value

        return image_output
def check_labels(fname_landmarks):
    """
    Validate the labels stored in a label file.

    An error is reported through sct.printv (type 'error') when the file does
    not hold exactly two labels, when a label value is not an integer, or when
    both labels carry the same value.

    Parameters
    ----------
    fname_landmarks: file name of input labels

    Returns
    -------
    labels: non-zero coordinates of the label image, requested with sorting='value'
    """
    sct.printv('\nCheck input labels...')
    # read the label file and collect its non-zero voxels
    labels = Image(fname_landmarks).getNonZeroCoordinates(sorting='value')
    n_labels = len(labels)
    # exactly two labels are expected
    if n_labels != 2:
        sct.printv('ERROR: Label file has ' + str(n_labels) + ' label(s). It must contain exactly two labels.', 1, 'error')
    # every label value must be a whole number
    for label in labels:
        if int(label.value) != label.value:
            sct.printv('ERROR: Label should be integer.', 1, 'error')
    # the two labels must not share the same value
    if labels[0].value == labels[1].value:
        sct.printv('ERROR: The two labels must be different.', 1, 'error')
    return labels
Esempio n. 4
0
def compute_ICBM152_centerline(dataset_info):
    """
    Extract the centerline from the ICBM152 brain template.

    The template is downloaded into ``path_data`` when its 'icbm152/' folder is
    absent. Labels of the manual disks image are converted to physical
    coordinates and used to compute the vertebral distribution of the
    centerline fitted on the manual centerline image.

    :param dataset_info: dictionary containing dataset information (only 'path_data' is read here)
    :return: Centerline object with its vertebral distribution computed (label_reference='PMG')
    """
    path_data = dataset_info['path_data']

    # fetch the template only when its folder is not already present
    template_missing = not os.path.isdir(path_data + 'icbm152/')
    if template_missing:
        download_data_template(path_data=path_data, name='icbm152', force=False)

    image_disks = Image(path_data + 'icbm152/mni_icbm152_t1_tal_nlin_sym_09c_disks_manual.nii.gz')

    # keep labels up to 22 plus a few specific values, expressed as [x, y, z, value]
    extra_values = [48, 49, 50, 51, 52]
    coord_physical = []
    for c in image_disks.getNonZeroCoordinates(sorting='z', reverse_coord=True):
        if not (c.value <= 22 or c.value in extra_values):  # 22 corresponds to L2
            continue
        point = image_disks.transfo_pix2phys([[c.x, c.y, c.z]])[0]
        point.append(c.value)
        coord_physical.append(point)

    # fit the centerline (points and derivatives) on the manual centerline image
    fit = smooth_centerline(
        path_data + 'icbm152/mni_icbm152_t1_centerline_manual.nii.gz', algo_fitting='nurbs',
        verbose=0, nurbs_pts_number=300, all_slices=False, phys_coordinates=True, remove_outliers=False)
    x_fit, y_fit, z_fit, x_deriv, y_deriv, z_deriv = fit

    centerline = Centerline(x_fit, y_fit, z_fit, x_deriv, y_deriv, z_deriv)
    centerline.compute_vertebral_distribution(coord_physical, label_reference='PMG')
    return centerline
Esempio n. 5
0
def check_labels(fname_landmarks, label_type='body'):
    """
    Make sure input labels are consistent.

    An error is reported through sct.printv (type 'error') when:
    - label_type is 'body' and the file does not contain exactly two labels
      (the count is not checked for any other label_type),
    - a label value is not an integer,
    - two labels share the same value.

    Parameters
    ----------
    fname_landmarks: file name of input labels
    label_type: 'body', 'disc'

    Returns
    -------
    labels: non-zero coordinates of the label image, requested with sorting='value'
    """
    sct.printv('\nCheck input labels...')
    # open label file
    image_label = Image(fname_landmarks)
    labels = image_label.getNonZeroCoordinates(sorting='value')
    # check if there is two labels
    if label_type == 'body' and len(labels) != 2:
        sct.printv(
            'ERROR: Label file has ' + str(len(labels)) +
            ' label(s). It must contain exactly two labels.', 1, 'error')
    # check if labels are integer
    for label in labels:
        if int(label.value) != label.value:
            sct.printv('ERROR: Label should be integer.', 1, 'error')
    # check if there are duplicates in label values: a set drops repeated
    # values, so a shorter set means at least one value appears twice
    list_values = [label.value for label in labels]
    if len(set(list_values)) != len(list_values):
        sct.printv('ERROR: Found two labels with same value.', 1, 'error')
    return labels
Esempio n. 6
0
def generate_centerline(dataset_info, contrast='t1', regenerate=False):
    """
    This function generates spinal cord centerline from binary images (either an image of centerline or segmentation)

    For each subject, a previously saved 'centerline.npz' found in the subject
    folder is loaded unless ``regenerate`` is set; otherwise the centerline is
    fitted, its vertebral distribution is computed from the disks image, and
    the result is saved in the subject folder.

    The working directory is changed to each subject folder during processing
    and is restored on exit, including when an exception is raised.
    NOTE(review): paths are built by concatenating dataset_info['path_data'];
    this presumably expects an absolute path ending with '/', since relative
    paths would be resolved from the previous subject folder - confirm.

    :param dataset_info: dictionary containing dataset information
        (keys read here: 'path_data', 'subjects', 'suffix_centerline', 'suffix_disks')
    :param contrast: {'t1', 't2'}
    :param regenerate: if True, recompute the centerline even when a saved one exists
    :return list of centerline objects
    """
    path_data = dataset_info['path_data']
    list_subjects = dataset_info['subjects']
    list_centerline = []

    current_path = os.getcwd()

    timer_centerline = sct.Timer(len(list_subjects))
    timer_centerline.start()
    try:
        for subject_name in list_subjects:
            path_data_subject = path_data + subject_name + '/' + contrast + '/'
            fname_image_centerline = path_data_subject + contrast + dataset_info['suffix_centerline'] + '.nii.gz'
            fname_image_disks = path_data_subject + contrast + dataset_info['suffix_disks'] + '.nii.gz'

            # go to output folder
            sct.printv('\nExtracting centerline from ' + path_data_subject)
            os.chdir(path_data_subject)

            fname_centerline = 'centerline'
            # if centerline exists, we load it, if not, we compute it
            if os.path.isfile(fname_centerline + '.npz') and not regenerate:
                centerline = Centerline(fname=path_data_subject + fname_centerline + '.npz')
            else:
                # extracting intervertebral disks
                im = Image(fname_image_disks)
                coord = im.getNonZeroCoordinates(sorting='z', reverse_coord=True)
                coord_physical = []
                for c in coord:
                    if c.value <= 22 or c.value in [48, 49, 50, 51, 52]:  # 22 corresponds to L2
                        c_p = im.transfo_pix2phys([[c.x, c.y, c.z]])[0]
                        c_p.append(c.value)
                        coord_physical.append(c_p)

                # extracting centerline from binary image and create centerline object with vertebral distribution
                x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
                    fname_image_centerline, algo_fitting='nurbs',
                    verbose=0, nurbs_pts_number=4000, all_slices=False, phys_coordinates=True, remove_outliers=False)
                centerline = Centerline(x_centerline_fit, y_centerline_fit, z_centerline,
                                        x_centerline_deriv, y_centerline_deriv, z_centerline_deriv)
                centerline.compute_vertebral_distribution(coord_physical)
                centerline.save_centerline(fname_output=fname_centerline)

            list_centerline.append(centerline)
            timer_centerline.add_iteration()
        timer_centerline.stop()
    finally:
        # always come back to the folder the caller was in, even on failure
        os.chdir(current_path)

    return list_centerline
Esempio n. 7
0
def centerline2roi(fname_image, folder_output='./', verbose=0):
    """
    This method converts a binary centerline image to a .roi centerline file

    One "Marker ROI" entry is written per non-zero voxel of the image. The X/Y
    position of an entry is the physical offset between the in-plane center of
    the voxel's slice and the voxel itself. The file is written in the current
    working directory, then copied to ``folder_output`` when that folder is a
    different one.

    Args:
        fname_image: filename of the binary centerline image, in RPI orientation
        folder_output: path to output folder where to copy .roi centerline
        verbose: adjusts the verbosity of the logging.

    Returns: filename of the .roi centerline that has been created

    """
    path_data, file_data, ext_data = sct.extract_fname(fname_image)
    fname_output = file_data + '.roi'

    ROI_TEMPLATE = 'Begin Marker ROI\n' \
                   '  Build version="7.0_33"\n' \
                   '  Annotation=""\n' \
                   '  Colour=0\n' \
                   '  Image source="{fname_segmentation}"\n' \
                   '  Created  "{creation_date}" by Operator ID="SCT"\n' \
                   '  Slice={slice_num}\n' \
                   '  Begin Shape\n' \
                   '    X={position_x}; Y={position_y}\n' \
                   '  End Shape\n' \
                   'End Marker ROI\n'

    # the creation date is the same for every entry: format it once
    creation_date = datetime.now().strftime("%d %B %Y %H:%M:%S.%f %Z")

    im = Image(fname_image)
    nx, ny, nz, nt, px, py, pz, pt = im.dim
    coordinates_centerline = im.getNonZeroCoordinates(sorting='z')

    sct.printv('\nWriting ROI file...', verbose)
    # the context manager closes the file even if a conversion below fails
    with open(fname_output, "w") as f:
        for coord in coordinates_centerline:
            # physical position of the in-plane center of the current slice
            coord_phys_center = im.transfo_pix2phys([[(nx - 1) / 2.0,
                                                      (ny - 1) / 2.0, coord.z]])[0]
            coord_phys = im.transfo_pix2phys([[coord.x, coord.y, coord.z]])[0]
            f.write(
                ROI_TEMPLATE.format(
                    fname_segmentation=fname_image,
                    creation_date=creation_date,
                    slice_num=coord.z + 1,  # slice numbers are written 1-based
                    position_x=coord_phys_center[0] - coord_phys[0],
                    position_y=coord_phys_center[1] - coord_phys[1]))

    # copy only when the output folder is not the current working directory
    if os.path.abspath(folder_output) != os.getcwd():
        shutil.copy(fname_output, folder_output)

    return fname_output
def main():
    """
    Register an anatomical image to the template (command-line entry point).

    Steps, as performed below:
    1. build the default two-step registration parameters and parse the
       command-line options (or use hard-coded paths when param.debug is set),
    2. check that the template files exist and that all input labels differ,
    3. copy the inputs to a temporary folder and set their orientation to RPI,
    4. straighten the cropped segmentation and estimate a landmark-based
       affine transformation from the straightened space to the template,
    5. crop and sub-sample the data in z, then run every registration step,
    6. concatenate the transformations in both directions, optionally apply
       them, and copy the results next to the temporary folder.

    External commands are executed through sct.run; nothing is returned.
    """

    # get default parameters
    step1 = Paramreg(step='1', type='seg', algo='slicereg', metric='MeanSquares', iter='10')
    step2 = Paramreg(step='2', type='im', algo='syn', metric='MI', iter='3')
    # step1 = Paramreg()
    paramreg = ParamregMultiStep([step1, step2])

    # step1 = Paramreg_step(step='1', type='seg', algo='bsplinesyn', metric='MeanSquares', iter='10', shrink='1', smooth='0', gradStep='0.5')
    # step2 = Paramreg_step(step='2', type='im', algo='syn', metric='MI', iter='10', shrink='1', smooth='0', gradStep='0.5')
    # paramreg = ParamregMultiStep([step1, step2])

    # Initialize the parser
    parser = Parser(__file__)
    parser.usage.set_description('Register anatomical image to the template.')
    parser.add_option(name="-i",
                      type_value="file",
                      description="Anatomical image.",
                      mandatory=True,
                      example="anat.nii.gz")
    parser.add_option(name="-s",
                      type_value="file",
                      description="Spinal cord segmentation.",
                      mandatory=True,
                      example="anat_seg.nii.gz")
    parser.add_option(name="-l",
                      type_value="file",
                      description="Labels. See: http://sourceforge.net/p/spinalcordtoolbox/wiki/create_labels/",
                      mandatory=True,
                      default_value='',
                      example="anat_labels.nii.gz")
    parser.add_option(name="-t",
                      type_value="folder",
                      description="Path to MNI-Poly-AMU template.",
                      mandatory=False,
                      default_value=param.path_template)
    parser.add_option(name="-p",
                      type_value=[[':'], 'str'],
                      description="""Parameters for registration (see sct_register_multimodal). Default:\n--\nstep=1\ntype="""+paramreg.steps['1'].type+"""\nalgo="""+paramreg.steps['1'].algo+"""\nmetric="""+paramreg.steps['1'].metric+"""\npoly="""+paramreg.steps['1'].poly+"""\n--\nstep=2\ntype="""+paramreg.steps['2'].type+"""\nalgo="""+paramreg.steps['2'].algo+"""\nmetric="""+paramreg.steps['2'].metric+"""\niter="""+paramreg.steps['2'].iter+"""\nshrink="""+paramreg.steps['2'].shrink+"""\nsmooth="""+paramreg.steps['2'].smooth+"""\ngradStep="""+paramreg.steps['2'].gradStep+"""\n--""",
                      mandatory=False,
                      example="step=2,type=seg,algo=bsplinesyn,metric=MeanSquares,iter=5,shrink=2:step=3,type=im,algo=syn,metric=MI,iter=5,shrink=1,gradStep=0.3")
    parser.add_option(name="-r",
                      type_value="multiple_choice",
                      description="""Remove temporary files.""",
                      mandatory=False,
                      default_value='1',
                      example=['0', '1'])
    parser.add_option(name="-v",
                      type_value="multiple_choice",
                      description="""Verbose. 0: nothing. 1: basic. 2: extended.""",
                      mandatory=False,
                      default_value=param.verbose,
                      example=['0', '1', '2'])
    if param.debug:
        # parenthesized so that the line parses under both Python 2 and Python 3
        print('\n*** WARNING: DEBUG MODE ON ***\n')
        fname_data = '/Users/julien/data/temp/sct_example_data/t2/t2.nii.gz'
        fname_landmarks = '/Users/julien/data/temp/sct_example_data/t2/labels.nii.gz'
        fname_seg = '/Users/julien/data/temp/sct_example_data/t2/t2_seg.nii.gz'
        path_template = param.path_template
        remove_temp_files = 0
        verbose = 2
        # speed = 'superfast'
        #param_reg = '2,BSplineSyN,0.6,MeanSquares'
    else:
        arguments = parser.parse(sys.argv[1:])

        # get arguments
        fname_data = arguments['-i']
        fname_seg = arguments['-s']
        fname_landmarks = arguments['-l']
        path_template = arguments['-t']
        remove_temp_files = int(arguments['-r'])
        verbose = int(arguments['-v'])
        if '-p' in arguments:
            paramreg_user = arguments['-p']
            # update registration parameters
            for paramStep in paramreg_user:
                paramreg.addStep(paramStep)

    # initialize other parameters
    file_template = param.file_template
    file_template_label = param.file_template_label
    file_template_seg = param.file_template_seg
    output_type = param.output_type
    zsubsample = param.zsubsample
    # smoothing_sigma = param.smoothing_sigma

    # start timer
    start_time = time.time()

    # get absolute path - TO DO: remove! NEVER USE ABSOLUTE PATH...
    path_template = os.path.abspath(path_template)

    # get fname of the template + template objects
    fname_template = sct.slash_at_the_end(path_template, 1)+file_template
    fname_template_label = sct.slash_at_the_end(path_template, 1)+file_template_label
    fname_template_seg = sct.slash_at_the_end(path_template, 1)+file_template_seg

    # check file existence
    sct.printv('\nCheck template files...')
    sct.check_file_exist(fname_template, verbose)
    sct.check_file_exist(fname_template_label, verbose)
    sct.check_file_exist(fname_template_seg, verbose)

    # print arguments
    sct.printv('\nCheck parameters:', verbose)
    sct.printv('.. Data:                 '+fname_data, verbose)
    sct.printv('.. Landmarks:            '+fname_landmarks, verbose)
    sct.printv('.. Segmentation:         '+fname_seg, verbose)
    sct.printv('.. Path template:        '+path_template, verbose)
    sct.printv('.. Output type:          '+str(output_type), verbose)
    sct.printv('.. Remove temp files:    '+str(remove_temp_files), verbose)

    sct.printv('\nParameters for registration:')
    for pStep in range(1, len(paramreg.steps)+1):
        sct.printv('Step #'+paramreg.steps[str(pStep)].step, verbose)
        sct.printv('.. Type #'+paramreg.steps[str(pStep)].type, verbose)
        sct.printv('.. Algorithm................ '+paramreg.steps[str(pStep)].algo, verbose)
        sct.printv('.. Metric................... '+paramreg.steps[str(pStep)].metric, verbose)
        sct.printv('.. Number of iterations..... '+paramreg.steps[str(pStep)].iter, verbose)
        sct.printv('.. Shrink factor............ '+paramreg.steps[str(pStep)].shrink, verbose)
        sct.printv('.. Smoothing factor......... '+paramreg.steps[str(pStep)].smooth, verbose)
        sct.printv('.. Gradient step............ '+paramreg.steps[str(pStep)].gradStep, verbose)
        sct.printv('.. Degree of polynomial..... '+paramreg.steps[str(pStep)].poly, verbose)

    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    sct.printv('\nCheck input labels...')
    # check if label image contains coherent labels
    image_label = Image(fname_landmarks)
    # -> all labels must be different
    labels = image_label.getNonZeroCoordinates()
    hasDifferentLabels = True
    for lab in labels:
        for otherlabel in labels:
            if lab != otherlabel and lab.hasEqualValue(otherlabel):
                hasDifferentLabels = False
                break
    if not hasDifferentLabels:
        sct.printv('ERROR: Wrong landmarks input. All labels must be different.', verbose, 'error')

    # create temporary folder
    sct.printv('\nCreate temporary folder...', verbose)
    path_tmp = 'tmp.'+time.strftime("%y%m%d%H%M%S")
    status, output = sct.run('mkdir '+path_tmp)

    # copy files to temporary folder
    sct.printv('\nCopy files...', verbose)
    sct.run('isct_c3d '+fname_data+' -o '+path_tmp+'/data.nii')
    sct.run('isct_c3d '+fname_landmarks+' -o '+path_tmp+'/landmarks.nii.gz')
    sct.run('isct_c3d '+fname_seg+' -o '+path_tmp+'/segmentation.nii.gz')
    sct.run('isct_c3d '+fname_template+' -o '+path_tmp+'/template.nii')
    sct.run('isct_c3d '+fname_template_label+' -o '+path_tmp+'/template_labels.nii.gz')
    sct.run('isct_c3d '+fname_template_seg+' -o '+path_tmp+'/template_seg.nii.gz')

    # go to tmp folder
    os.chdir(path_tmp)

    # Change orientation of input images to RPI
    sct.printv('\nChange orientation of input images to RPI...', verbose)
    set_orientation('data.nii', 'RPI', 'data_rpi.nii')
    set_orientation('landmarks.nii.gz', 'RPI', 'landmarks_rpi.nii.gz')
    set_orientation('segmentation.nii.gz', 'RPI', 'segmentation_rpi.nii.gz')

    # crop segmentation
    # output: segmentation_rpi_crop.nii.gz
    sct.run('sct_crop_image -i segmentation_rpi.nii.gz -o segmentation_rpi_crop.nii.gz -dim 2 -bzmax')

    # straighten segmentation
    sct.printv('\nStraighten the spinal cord using centerline/segmentation...', verbose)
    sct.run('sct_straighten_spinalcord -i segmentation_rpi_crop.nii.gz -c segmentation_rpi_crop.nii.gz -r 0')

    # Label preparation:
    # --------------------------------------------------------------------------------
    # Remove unused label on template. Keep only label present in the input label image
    sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
    sct.run('sct_label_utils -t remove -i template_labels.nii.gz -o template_label.nii.gz -r landmarks_rpi.nii.gz')

    # Make sure landmarks are INT
    sct.printv('\nConvert landmarks to INT...', verbose)
    sct.run('isct_c3d template_label.nii.gz -type int -o template_label.nii.gz', verbose)

    # Create a cross for the template labels - 5 mm
    sct.printv('\nCreate a 5 mm cross for the template labels...', verbose)
    sct.run('sct_label_utils -t cross -i template_label.nii.gz -o template_label_cross.nii.gz -c 5')

    # Create a cross for the input labels and dilate for straightening preparation - 5 mm
    sct.printv('\nCreate a 5mm cross for the input labels and dilate for straightening preparation...', verbose)
    sct.run('sct_label_utils -t cross -i landmarks_rpi.nii.gz -o landmarks_rpi_cross3x3.nii.gz -c 5 -d')

    # Apply straightening to labels
    sct.printv('\nApply straightening to labels...', verbose)
    sct.run('sct_apply_transfo -i landmarks_rpi_cross3x3.nii.gz -o landmarks_rpi_cross3x3_straight.nii.gz -d segmentation_rpi_crop_straight.nii.gz -w warp_curve2straight.nii.gz -x nn')

    # Convert landmarks from FLOAT32 to INT
    sct.printv('\nConvert landmarks from FLOAT32 to INT...', verbose)
    sct.run('isct_c3d landmarks_rpi_cross3x3_straight.nii.gz -type int -o landmarks_rpi_cross3x3_straight.nii.gz')

    # Estimate affine transfo: straight --> template (landmark-based)'
    sct.printv('\nEstimate affine transfo: straight anat --> template (landmark-based)...', verbose)
    sct.run('isct_ANTSUseLandmarkImagesToGetAffineTransform template_label_cross.nii.gz landmarks_rpi_cross3x3_straight.nii.gz affine straight2templateAffine.txt')

    # Apply affine transformation: straight --> template
    sct.printv('\nApply affine transformation: straight --> template...', verbose)
    sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt -d template.nii -o warp_curve2straightAffine.nii.gz')
    sct.run('sct_apply_transfo -i data_rpi.nii -o data_rpi_straight2templateAffine.nii -d template.nii -w warp_curve2straightAffine.nii.gz')
    sct.run('sct_apply_transfo -i segmentation_rpi.nii.gz -o segmentation_rpi_straight2templateAffine.nii.gz -d template.nii -w warp_curve2straightAffine.nii.gz -x linear')

    # find min-max of anat2template (for subsequent cropping)
    sct.run('export FSLOUTPUTTYPE=NIFTI; fslmaths segmentation_rpi_straight2templateAffine.nii.gz -thr 0.5 segmentation_rpi_straight2templateAffine_th.nii.gz', param.verbose)
    zmin_template, zmax_template = find_zmin_zmax('segmentation_rpi_straight2templateAffine_th.nii.gz')

    # crop template in z-direction (for faster processing)
    sct.printv('\nCrop data in template space (for faster processing)...', verbose)
    sct.run('sct_crop_image -i template.nii -o template_crop.nii -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    sct.run('sct_crop_image -i template_seg.nii.gz -o template_seg_crop.nii.gz -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    sct.run('sct_crop_image -i data_rpi_straight2templateAffine.nii -o data_rpi_straight2templateAffine_crop.nii -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    sct.run('sct_crop_image -i segmentation_rpi_straight2templateAffine.nii.gz -o segmentation_rpi_straight2templateAffine_crop.nii.gz -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    # sub-sample in z-direction
    sct.printv('\nSub-sample in z-direction (for faster processing)...', verbose)
    sct.run('sct_resample -i template_crop.nii -o template_crop_r.nii -f 1x1x'+zsubsample, verbose)
    sct.run('sct_resample -i template_seg_crop.nii.gz -o template_seg_crop_r.nii.gz -f 1x1x'+zsubsample, verbose)
    sct.run('sct_resample -i data_rpi_straight2templateAffine_crop.nii -o data_rpi_straight2templateAffine_crop_r.nii -f 1x1x'+zsubsample, verbose)
    sct.run('sct_resample -i segmentation_rpi_straight2templateAffine_crop.nii.gz -o segmentation_rpi_straight2templateAffine_crop_r.nii.gz -f 1x1x'+zsubsample, verbose)

    # Registration straight spinal cord to template
    sct.printv('\nRegister straight spinal cord to template...', verbose)

    # loop across registration steps
    warp_forward = []
    warp_inverse = []
    for i_step in range(1, len(paramreg.steps)+1):
        sct.printv('\nEstimate transformation for step #'+str(i_step)+'...', verbose)
        # identify which is the src and dest
        if paramreg.steps[str(i_step)].type == 'im':
            src = 'data_rpi_straight2templateAffine_crop_r.nii'
            dest = 'template_crop_r.nii'
            interp_step = 'linear'
        elif paramreg.steps[str(i_step)].type == 'seg':
            src = 'segmentation_rpi_straight2templateAffine_crop_r.nii.gz'
            dest = 'template_seg_crop_r.nii.gz'
            interp_step = 'nn'
        else:
            # Report the problem with printv, like the landmark check earlier in
            # this function; the message used to be handed to sct.run, which
            # executes its argument as a command.
            # NOTE(review): presumably printv with type 'error' aborts the
            # program (src/dest are undefined past this point) - confirm.
            sct.printv('ERROR: Wrong image type.', 1, 'error')
        # if step>1, apply warp_forward_concat to the src image to be used
        if i_step > 1:
            # sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+sct.add_suffix(src, '_reg')+' -x '+interp_step, verbose)
            sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+sct.add_suffix(src, '_reg')+' -x '+interp_step, verbose)
            src = sct.add_suffix(src, '_reg')
        # register src --> dest
        warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
        warp_forward.append(warp_forward_out)
        warp_inverse.append(warp_inverse_out)

    # Concatenate transformations:
    sct.printv('\nConcatenate transformations: anat --> template...', verbose)
    sct.run('sct_concat_transfo -w warp_curve2straightAffine.nii.gz,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
    # sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
    # the inverse warps are concatenated in the reverse order of the forward ones
    warp_inverse.reverse()
    sct.run('sct_concat_transfo -w '+','.join(warp_inverse)+',-straight2templateAffine.txt,warp_straight2curve.nii.gz -d data.nii -o warp_template2anat.nii.gz', verbose)

    # Apply warping fields to anat and template
    if output_type == 1:
        sct.run('sct_apply_transfo -i template.nii -o template2anat.nii.gz -d data.nii -w warp_template2anat.nii.gz -c 1', verbose)
        sct.run('sct_apply_transfo -i data.nii -o anat2template.nii.gz -d template.nii -w warp_anat2template.nii.gz -c 1', verbose)

    # come back to parent folder
    os.chdir('..')

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(path_tmp+'/warp_template2anat.nii.gz', 'warp_template2anat.nii.gz', verbose)
    sct.generate_output_file(path_tmp+'/warp_anat2template.nii.gz', 'warp_anat2template.nii.gz', verbose)
    if output_type == 1:
        sct.generate_output_file(path_tmp+'/template2anat.nii.gz', 'template2anat'+ext_data, verbose)
        sct.generate_output_file(path_tmp+'/anat2template.nii.gz', 'anat2template'+ext_data, verbose)

    # Delete temporary files
    if remove_temp_files:
        sct.printv('\nDelete temporary files...', verbose)
        sct.run('rm -rf '+path_tmp)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: '+str(int(round(elapsed_time)))+'s', verbose)

    # to view results
    sct.printv('\nTo view results, type:', verbose)
    sct.printv('fslview '+fname_data+' template2anat -b 0,4000 &', verbose, 'info')
    sct.printv('fslview '+fname_template+' -b 0,5000 anat2template &\n', verbose, 'info')
Esempio n. 9
0
class ProcessLabels(object):
    def __init__(self, fname_label, fname_output=None, fname_ref=None, cross_radius=5, dilate=False,
                 coordinates=None, verbose=1, vertebral_levels=None, value=None):
        """
        Load the label image (and the optional reference image) and store the processing options.

        :param fname_label: passed to Image() to build self.image_input
        :param fname_output: output file name or list of file names; a one-element list is unwrapped to its
                             single item, anything else is stored as given
        :param fname_ref: if not None, passed to Image() to build self.image_ref (None otherwise)
        :param cross_radius: radius used by cross(); also given as plane width to plan() by process()
        :param dilate: forwarded to get_crosses_coordinates() by cross()
        :param coordinates: labels written by create_label()
        :param verbose: verbosity, forwarded to Image() and to the printing helpers
        :param vertebral_levels: levels forwarded to label_vertebrae() by process()
        :param value: value forwarded to add() by process()
        """
        self.image_input = Image(fname_label, verbose=verbose)
        self.image_ref = None if fname_ref is None else Image(fname_ref, verbose=verbose)

        # a list holding a single name is reduced to that name; longer lists are kept as lists
        single_item_list = isinstance(fname_output, list) and len(fname_output) == 1
        self.fname_output = fname_output[0] if single_item_list else fname_output

        self.cross_radius = cross_radius
        self.vertebral_levels = vertebral_levels
        self.dilate = dilate
        self.coordinates = coordinates
        self.verbose = verbose
        self.value = value

    def process(self, type_process):
        """
        Run the process named type_process and, if an output file name is set, save the resulting image.

        The 'MSE', 'display-voxel', 'diff' and 'dist-inter' processes only report results and reset
        self.fname_output to None. The 'centerline' process writes its own text file to self.fname_output
        and does not produce an image. For these processes (and for unrecognized names) nothing is saved.

        :param type_process: name of the process to run (see the branches below for the supported names)
        """
        # for some processes, change orientation of input image to RPI
        change_orientation = type_process in ['vert-body', 'vert-disc', 'vert-continuous']
        if change_orientation:
            # remember the orientation of the input image so that the output can be put back into it
            input_orientation = self.image_input.orientation
            self.image_input.change_orientation('RPI')

        # image produced by the selected process; stays None for processes that do not build an image
        produced_image = None
        if type_process == 'add':
            produced_image = self.add(self.value)
        elif type_process == 'cross':
            produced_image = self.cross()
        elif type_process == 'plan':
            produced_image = self.plan(self.cross_radius, 100, 5)
        elif type_process == 'plan_ref':
            produced_image = self.plan_ref()
        elif type_process == 'increment':
            produced_image = self.increment_z_inverse()
        elif type_process == 'disks':
            produced_image = self.labelize_from_disks()
        elif type_process == 'MSE':
            self.MSE()
            self.fname_output = None
        elif type_process == 'remove':
            produced_image = self.remove_label()
        elif type_process == 'remove-symm':
            produced_image = self.remove_label(symmetry=True)
        elif type_process == 'centerline':
            self.extract_centerline()
        elif type_process == 'create':
            produced_image = self.create_label()
        elif type_process == 'create-add':
            produced_image = self.create_label(add=True)
        elif type_process == 'display-voxel':
            self.display_voxel()
            self.fname_output = None
        elif type_process == 'diff':
            self.diff()
            self.fname_output = None
        elif type_process == 'dist-inter':  # second argument is in pixel distance
            self.distance_interlabels(5)
            self.fname_output = None
        elif type_process == 'cubic-to-point':
            produced_image = self.cubic_to_point()
        elif type_process == 'vert-body':
            produced_image = self.label_vertebrae(self.vertebral_levels)
        # elif type_process == 'vert-disc':
        #     produced_image = self.label_disc(self.vertebral_levels)
        # elif type_process == 'label-vertebrae-from-disks':
        #     produced_image = self.label_vertebrae_from_disks(self.vertebral_levels)
        elif type_process == 'vert-continuous':
            produced_image = self.continuous_vertebral_levels()

        if produced_image is not None:
            self.output_image = produced_image

        # Save only when this call actually produced an image: processes such as 'centerline' keep
        # self.fname_output set but never define self.output_image, which used to raise AttributeError here.
        if self.fname_output is not None and produced_image is not None:
            self.output_image.setFileName(self.fname_output)
            if change_orientation:
                self.output_image.change_orientation(input_orientation)
            if type_process == 'vert-continuous':
                # continuous levels are saved as float32
                self.output_image.save('float32')
            elif type_process != 'plan_ref':
                # save the output image as minimized integers
                self.output_image.save('minimize_int')
            else:
                # plan_ref output is saved with the default type of save()
                self.output_image.save()

    def add(self, value):
        """
        Add a value to every non-zero voxel of the input image.

        :param value: amount to add; converted with float() for each voxel
        :return: new Image built from self.image_input in which each non-zero voxel has been incremented
        """
        image_output = Image(self.image_input, self.verbose)

        # shift each labelled voxel by the requested amount, leaving zero voxels untouched
        for coord in self.image_input.getNonZeroCoordinates():
            voxel = (int(coord.x), int(coord.y), int(coord.z))
            image_output.data[voxel] = image_output.data[voxel] + float(value)
        return image_output


    def create_label(self, add=False):
        """
        Create an image containing the labels listed in self.coordinates.
        This method works only if the user provided correct coordinates.

        self.coordinates is a list of coordinates (class in msct_types); each one provides x, y, z and value.
        A single label must still be wrapped in a list.
        examples:
        For one label:  object_define=ProcessLabels( fname_label, coordinates=[coordi]) where coordi is a 'Coordinate' object from msct_types
        For two labels: object_define=ProcessLabels( fname_label, coordinates=[coordi1, coordi2]) where coordi1 and coordi2 are 'Coordinate' objects from msct_types

        :param add: if True, labels are written on top of the existing content of the input image;
                    otherwise the image is emptied first
        :return: copy of the input image holding the labels
        """
        image_output = self.image_input.copy()
        if not add:
            # start from an empty image
            image_output.data *= 0

        # loop across labels
        for index, label in enumerate(self.coordinates):
            # display info
            position = ','.join([str(label.x), str(label.y), str(label.z)])
            sct.printv('Label #' + str(index) + ': ' + position + ' --> ' + str(label.value), 1)
            image_output.data[int(label.x), int(label.y), int(label.z)] = label.value

        return image_output


    def cross(self):
        """
        Create an image containing a cross around each label of the input image.

        The cross points are computed by get_crosses_coordinates(); each point is written at its rounded
        position with the value computed there.

        :return: new Image built from self.image_input, emptied, then filled with the cross points
        """
        output_image = Image(self.image_input, self.verbose)
        # only the fifth entry of dim (px) is needed; dim is read from the image file on disk
        px = Image(self.image_input.absolutepath).dim[4]

        coordinates_input = self.image_input.getNonZeroCoordinates()
        # NOTE(review): the original comments described cross_radius as "in pixel" and the quotient as
        # "in mm"; dividing by px (presumably the voxel size along x) suggests the opposite - TODO confirm
        gap = self.cross_radius / px

        # clean output_image
        output_image.data *= 0

        cross_coordinates = self.get_crosses_coordinates(coordinates_input, gap, self.image_ref, self.dilate)

        for coord in cross_coordinates:
            output_image.data[int(round(coord.x)), int(round(coord.y)), int(round(coord.z))] = coord.value

        return output_image


    @staticmethod
    def get_crosses_coordinates(coordinates_input, gapxy=15, image_ref=None, dilate=False, verbose=1):
        """
        Compute the points of a cross around each input label.

        Without reference image, the cross is computed by compute_cross(); with a reference image
        (segmentation), the centerline is smoothed and the cross is computed by compute_cross_centerline()
        using the centerline derivative found at the z position of the label.
        The i-th point (starting at 0) of the cross of a label receives the value label.value * 10 + i + 1.

        :param coordinates_input: labels around which crosses are computed
        :param gapxy: forwarded to compute_cross() / compute_cross_centerline()
        :param image_ref: optional reference image passed to smooth_centerline()
        :param dilate: if True, the 26 neighbours of every cross point are added with the same value
        :param verbose: verbosity forwarded to smooth_centerline()
        :return: list of cross points sorted by value
        """
        from msct_types import Coordinate

        # if reference image is provided (segmentation), we draw the cross perpendicular to the centerline
        if image_ref is not None:
            # smooth centerline
            # (a static method has no access to "self": use the image_ref and verbose arguments)
            from sct_straighten_spinalcord import smooth_centerline
            x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(image_ref, verbose=verbose)

        # offsets to the 26 neighbours of a point (3x3x3 cube without its center)
        neighbour_offsets = [(off_x, off_y, off_z)
                             for off_x in (0.0, 1.0, -1.0)
                             for off_y in (0.0, 1.0, -1.0)
                             for off_z in (0.0, 1.0, -1.0)
                             if (off_x, off_y, off_z) != (0.0, 0.0, 0.0)]

        # compute crosses
        cross_coordinates = []
        for coord in coordinates_input:
            if image_ref is None:
                from sct_straighten_spinalcord import compute_cross
                cross_coordinates_temp = compute_cross(coord, gapxy)
            else:
                from sct_straighten_spinalcord import compute_cross_centerline
                from numpy import where
                # derivative of the centerline at the slice of the label
                index_z = where(z_centerline == coord.z)
                deriv = Coordinate([x_centerline_deriv[index_z][0], y_centerline_deriv[index_z][0], z_centerline_deriv[index_z][0], 0.0])
                cross_coordinates_temp = compute_cross_centerline(coord, deriv, gapxy)

            for i, coord_cross in enumerate(cross_coordinates_temp):
                coord_cross.value = coord.value * 10 + i + 1

            # dilate cross to 3x3x3
            if dilate:
                additional_coordinates = []
                for coord_temp in cross_coordinates_temp:
                    for off_x, off_y, off_z in neighbour_offsets:
                        additional_coordinates.append(Coordinate([coord_temp.x + off_x, coord_temp.y + off_y, coord_temp.z + off_z, coord_temp.value]))

                cross_coordinates_temp.extend(additional_coordinates)

            cross_coordinates.extend(cross_coordinates_temp)

        cross_coordinates = sorted(cross_coordinates, key=lambda obj: obj.value)
        return cross_coordinates


    def plan(self, width, offset=0, gap=1):
        """
        Create a plane around each label of the input image.

        For a label at slice z, slices z - width (included) to z + width (excluded) are filled with
        offset + gap * label value. The lower bound is clamped to 0 so that labels closer than "width"
        slices to the first slice still produce a plane.

        :param width: number of slices filled below z, and (including z itself) from z upwards
        :param offset: constant added to the plane value
        :param gap: factor applied to the label value
        :return: new Image built from self.image_input, emptied, then filled with the planes
        """
        image_output = Image(self.image_input, self.verbose)
        image_output.data *= 0

        for coord in self.image_input.getNonZeroCoordinates():
            z = int(coord.z)
            # a negative start would be read by numpy as an index counted from the end of the axis,
            # which yields an empty slice and silently drops the plane
            z_start = max(z - width, 0)
            image_output.data[:, :, z_start:z + width] = offset + gap * coord.value

        return image_output


    def plan_ref(self):
        """
        Generate a plane in the reference space for each label present in the input image.

        Negative and positive labels are handled separately: negative values are negated so that
        getNonZeroCoordinates() can be applied, and the sign is restored when the plane is written.
        Planes of positive labels are written last.

        :return: new Image built from self.image_ref, converted to float32, holding one plane per label
        """
        image_output = Image(self.image_ref, self.verbose)
        image_output.data *= 0

        data_input = self.image_input.data
        image_input_neg = Image(self.image_input, self.verbose).copy()
        image_input_pos = Image(self.image_input, self.verbose).copy()
        image_input_neg.data *= 0
        image_input_pos.data *= 0

        # copy all negative voxels at once, with the sign flipped (in order to apply getNonZeroCoordinates)
        X, Y, Z = (data_input < 0).nonzero()
        image_input_neg.data[X, Y, Z] = -data_input[X, Y, Z]
        # copy all positive voxels at once
        X_pos, Y_pos, Z_pos = (data_input > 0).nonzero()
        image_input_pos.data[X_pos, Y_pos, Z_pos] = data_input[X_pos, Y_pos, Z_pos]

        coordinates_input_neg = image_input_neg.getNonZeroCoordinates()
        coordinates_input_pos = image_input_pos.getNonZeroCoordinates()

        image_output.changeType('float32')
        for coord in coordinates_input_neg:
            # PB (note from the original author): takes the int value of coord.value
            image_output.data[:, :, int(coord.z)] = -coord.value
        for coord in coordinates_input_pos:
            image_output.data[:, :, int(coord.z)] = coord.value

        return image_output


    def cubic_to_point(self):
        """
        Reduce each group of voxels sharing the same value to a single voxel located at their center of mass.
        It is to be used after applying homothetic warping field to a label file as the labels will be dilated.
        Be careful: voxels are grouped by value only, so two groups of voxels with the same value that are
        separated in space are merged and yield a single center of mass.
        :return: copy of the input image containing one voxel per value
        """
        # start from an empty copy of the input image
        output_image = self.image_input.copy()
        output_image.data *= 0

        # gather the non-null voxels by value (coordinates are requested sorted by value)
        groups = dict()
        for coord in self.image_input.getNonZeroCoordinates(sorting='value'):
            groups.setdefault(coord.value, []).append(coord)

        # write the center of mass of each group into the output image
        for list_coord in groups.values():
            center_of_mass = sum(list_coord)/float(len(list_coord))
            exact_position = str(center_of_mass.x) + ", "+str(center_of_mass.y) + ", " + str(center_of_mass.z)
            rounded_position = str(round(center_of_mass.x)) + ", " + str(round(center_of_mass.y)) + ", " + str(round(center_of_mass.z))
            sct.printv("Value = " + str(center_of_mass.value) + " : (" + exact_position + ") --> ( " + rounded_position + ")", verbose=self.verbose)
            output_image.data[int(round(center_of_mass.x)), int(round(center_of_mass.y)), int(round(center_of_mass.z))] = center_of_mass.value

        return output_image


    def increment_z_inverse(self):
        """
        Renumber the labels 1, 2, 3, etc. in the order returned by
        getNonZeroCoordinates(sorting='z', reverse_coord=True), i.e. along the inverse z direction.
        This function assumes RPI orientation.

        :return: new Image built from self.image_input, emptied, then filled with the renumbered labels
        """
        image_output = Image(self.image_input, self.verbose)
        image_output.data *= 0
        coordinates_input = self.image_input.getNonZeroCoordinates(sorting='z', reverse_coord=True)

        # the first label of the sorted list gets 1, the next one 2, and so on
        for rank, coord in enumerate(coordinates_input, 1):
            image_output.data[int(coord.x), int(coord.y), int(coord.z)] = rank

        return image_output


    def labelize_from_disks(self):
        """
        Create an image in which each non-zero voxel of the input takes a value from the reference labels.
        Typically, user inputs a segmentation image, and labels with disks position, and this function produces
        a segmentation image with vertebral levels labelized.
        Labels are assumed to be non-zero and incremented from top to bottom, assuming a RPI orientation.

        A voxel at height z receives the value of reference label j when
        z(label j+1) < z <= z(label j), reference labels being sorted by value.

        :return: new Image built from self.image_input, emptied, then filled with the level values
        """
        image_output = Image(self.image_input, self.verbose)
        image_output.data *= 0
        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates(sorting='value')

        # for all points in input, find the value that has to be set up, depending on the vertebral level
        for coord in coordinates_input:
            for j in range(len(coordinates_ref) - 1):
                upper, lower = coordinates_ref[j], coordinates_ref[j + 1]
                if lower.z < coord.z <= upper.z:
                    image_output.data[int(coord.x), int(coord.y), int(coord.z)] = upper.value

        return image_output


    def label_vertebrae(self, levels_user=None):
        """
        Find the center of mass of vertebral levels specified by the user.

        :param levels_user: list of levels to keep. None, an empty list, or a list whose first item is 0
                            means that all levels are kept.
        :return: image_output: Image with one label per kept level, located at its center of mass.
        """
        # get center of mass of each vertebral level
        image_cubic2point = self.cubic_to_point()
        # get list of coordinates for each label
        list_coordinates = image_cubic2point.getNonZeroCoordinates(sorting='value')
        # if user did not specify levels, include all
        # (the default None used to crash on levels_user[0])
        if levels_user is None or len(levels_user) == 0 or levels_user[0] == 0:
            levels_user = [int(coord.value) for coord in list_coordinates]
        # loop across labels and remove those that are not listed by the user
        for coord in list_coordinates:
            if int(coord.value) not in levels_user:
                # this level was not requested: set value to zero
                image_cubic2point.data[int(coord.x), int(coord.y), int(coord.z)] = 0
        return image_cubic2point


    # FUNCTION BELOW REMOVED BY JULIEN ON 2016-07-04 BECAUSE SEEMS NOT TO BE USED (AND DUPLICATION WITH ABOVE)
    # def label_vertebrae_from_disks(self, levels_user):
    #     """
    #     Find the center of mass of vertebral levels specified by the user.
    #     :param levels_user:
    #     :return:
    #     """
    #     image_cubic2point = self.cubic_to_point()
    #     # get list of coordinates for each label
    #     list_coordinates_disks = image_cubic2point.getNonZeroCoordinates(sorting='value')
    #     image_cubic2point.data *= 0
    #     # compute vertebral labels from disk labels
    #     list_coordinates_vertebrae = []
    #     for i_label in range(len(list_coordinates_disks)-1):
    #         list_coordinates_vertebrae.append((list_coordinates_disks[i_label] + list_coordinates_disks[i_label+1]) / 2.0)
    #     # loop across labels and remove those that are not listed by the user
    #     for i_label in range(len(list_coordinates_vertebrae)):
    #         # check if this level is NOT in levels_user
    #         if levels_user.count(int(list_coordinates_vertebrae[i_label].value)):
    #             image_cubic2point.data[int(list_coordinates_vertebrae[i_label].x), int(list_coordinates_vertebrae[i_label].y), int(list_coordinates_vertebrae[i_label].z)] = list_coordinates_vertebrae[i_label].value
    #
    #     return image_cubic2point


    # BELOW: UNFINISHED BUSINESS (JULIEN)
    # def label_disc(self, levels_user=None):
    #     """
    #     Find the edge of vertebral labeling file and assign value corresponding to middle coordinate between two levels.
    #     Assumes RPI orientation.
    #     :return: image_output: Image with labels.
    #     """
    #     from msct_types import Coordinate
    #     # get dim
    #     nx, ny, nz, nt, px, py, pz, pt = self.image_input.dim
    #     # initialize disc as a coordinate variable
    #     disc = []
    #     # get center of mass of each vertebral level
    #     image_cubic2point = self.cubic_to_point()
    #     # get list of coordinates for each label
    #     list_centermass = image_cubic2point.getNonZeroCoordinates(sorting='value')
    #     # if user did not specify levels, include all:
    #     if levels_user[0] == 0:
    #         levels_user = [int(i.value) for i in list_centermass]
    #     # get list of all coordinates
    #     list_coordinates = self.display_voxel()
    #     # loop across labels and remove those that are not listed by the user
    #     # for i_label in range(len(list_centermass)):
    #
    #     # TOP DISC
    #     # get coordinates for value i_level
    #     list_i_level = [list_coordinates[i] for i in xrange(len(list_coordinates)) if int(list_coordinates[i].value) == levels_user[0]]
    #     # get max z-value
    #     zmax = max([list_i_level[i].z for i in xrange(len(list_i_level))])
    #     # get coordinates corresponding to bottom voxels
    #     list_i_level_top = [list_i_level[i] for i in xrange(len(list_i_level)) if list_i_level[i].z == zmax]
    #     # get center of mass of the top and bottom voxels
    #     arr_voxels_around_disc = np.array([[list_i_level_top[i].x, list_i_level_top[i].y, list_i_level_top[i].z] for i in range(len(list_i_level_top))])
    #     centermass = list(np.mean(arr_voxels_around_disc, 0))
    #     centermass.append(levels_user[0]-1)
    #     disc.append(Coordinate(centermass))
    #     # if minimum level corresponds to z=nz, then remove it (likely corresponds to top edge of the FOV)
    #     if disc[0].z == nz:
    #         sct.printv('WARNING: Maximum level corresponds to z=0. Removing it (likely corresponds to edge of the FOV)', 1, 'warning')
    #         # remove last element of the list
    #         disc.pop()
    #
    #     # ALL DISCS
    #     # loop across values
    #     for i_level in levels_user:
    #         # get coordinates for value i_level
    #         list_i_level = [list_coordinates[i] for i in xrange(len(list_coordinates)) if int(list_coordinates[i].value) == i_level]
    #         # get min z-value
    #         zmin = min([list_i_level[i].z for i in xrange(len(list_i_level))])
    #         # get coordinates corresponding to bottom voxels
    #         list_i_level_bottom = [list_i_level[i] for i in xrange(len(list_i_level)) if list_i_level[i].z == zmin]
    #         # get center of mass
    #         # arr_i_level_bottom = np.array([[list_i_level_bottom[i].x, list_i_level_bottom[i].y] for i in range(len(list_i_level_bottom))])
    #         # centermass_i_level = ndimage.measurements.center_of_mass()
    #         try:
    #             # get coordinates for value i_level+1
    #             list_i_level_plus_one = [list_coordinates[i] for i in xrange(len(list_coordinates)) if int(list_coordinates[i].value) == i_level+1]
    #             # get max z-value
    #             zmax = max([list_i_level_plus_one[i].z for i in xrange(len(list_i_level_plus_one))])
    #             # get coordinates corresponding to top voxels
    #             list_i_level_plus_one_top = [list_i_level_plus_one[i] for i in xrange(len(list_i_level_plus_one)) if list_i_level_plus_one[i].z == zmax]
    #         except:
    #             # if maximum level was reached, ignore it and disc will be located at the centermass of the bottom z.
    #             list_i_level_plus_one_top = []
    #         # stack bottom and top voxels
    #         list_voxels_around_disc = list_i_level_bottom + list_i_level_plus_one_top
    #         # get center of mass of the top and bottom voxels
    #         arr_voxels_around_disc = np.array([[list_voxels_around_disc[i].x, list_voxels_around_disc[i].y, list_voxels_around_disc[i].z] for i in range(len(list_voxels_around_disc))])
    #         centermass = list(np.mean(arr_voxels_around_disc, 0))
    #         centermass.append(i_level)
    #         disc.append(Coordinate(centermass))
    #     # if maximum level corresponds to z=0, then remove it (likely corresponds to edge of the FOV)
    #     if disc[-1].z == 0.0:
    #         sct.printv('WARNING: Maximum level corresponds to z=0. Removing it (likely corresponds to edge of the FOV)', 1, 'warning')
    #         # remove last element of the list
    #         disc.pop()
    #
    #     # loop across labels and assign voxels in image
    #     image_cubic2point.data[:, :, :] = 0
    #     for i_label in range(len(disc)):
    #         image_cubic2point.data[int(round(disc[i_label].x)),
    #                                int(round(disc[i_label].y)),
    #                                int(round(disc[i_label].z))] = disc[i_label].value
    #
    #     # return image of labels
    #     return image_cubic2point


    def MSE(self, threshold_mse=0):
        """
        Compute the error along z between two sets of labels (input and ref).

        Labels are matched on their rounded value. The returned quantity is the square root of the sum of
        squared z differences of the matched labels, divided by the number of input labels.
        A warning is printed for each label mismatch.
        If the result is above threshold_mse, a log file named 'error_log_<input file name>.txt' is written
        at self.image_input.path.

        :param threshold_mse: threshold above which the log file is written (default 0)
        :return: the error, or NaN when the input image contains no label
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates()

        # rounded label values of each image, computed once
        values_input = set(round(coord.value) for coord in coordinates_input)
        values_ref = set(round(coord_ref.value) for coord_ref in coordinates_ref)

        # check if all the labels in both the images match
        if len(coordinates_input) != len(coordinates_ref):
            sct.printv('ERROR: labels mismatch', 1, 'warning')
        for coord in coordinates_input:
            if round(coord.value) not in values_ref:
                sct.printv('ERROR: labels mismatch', 1, 'warning')
        for coord_ref in coordinates_ref:
            if round(coord_ref.value) not in values_input:
                sct.printv('ERROR: labels mismatch', 1, 'warning')

        # without any input label the mean is undefined (this used to raise ZeroDivisionError)
        if len(coordinates_input) == 0:
            sct.printv('ERROR: no label found in input image, MSE cannot be computed', 1, 'warning')
            return float('nan')

        # z position of the first reference label found for each rounded value
        z_ref_by_value = {}
        for coord_ref in coordinates_ref:
            z_ref_by_value.setdefault(round(coord_ref.value), coord_ref.z)

        result = 0.0
        for coord in coordinates_input:
            value = round(coord.value)
            if value in z_ref_by_value:
                result += (z_ref_by_value[value] - coord.z) ** 2
        result = math.sqrt(result / len(coordinates_input))
        sct.printv('MSE error in Z direction = ' + str(result) + ' mm')

        if result > threshold_mse:
            fname_log = self.image_input.path + 'error_log_' + self.image_input.file_name + '.txt'
            # the context manager closes the file even if the write fails
            with open(fname_log, 'w') as f:
                f.write('The labels error (MSE) between ' + self.image_input.file_name + ' and ' + self.image_ref.file_name + ' is: ' + str(result))

        return result


    @staticmethod
    def remove_label_coord(coord_input, coord_ref, symmetry=False):
        """
        Keep the input labels that also exist in the reference and, with symmetry, the reference labels
        that also exist in the input.

        When symmetry is requested and both lists start with a CoordinateValue, labels are matched through
        set intersection (i.e. with the equality/hash of CoordinateValue); otherwise labels are matched on
        their "value" attribute. Empty lists are accepted.

        :param coord_input: list of labels of the input image (indexable sequence)
        :param coord_ref: list of labels of the reference image (indexable sequence)
        :param symmetry: boolean, if True the reference labels are filtered as well
        :return: tuple (kept input labels, kept reference labels); without symmetry the second item is
                 coord_ref itself
        """
        from msct_types import CoordinateValue
        # guard the [0] accesses: empty lists used to raise IndexError
        both_non_empty = len(coord_input) > 0 and len(coord_ref) > 0
        if symmetry and both_non_empty and isinstance(coord_input[0], CoordinateValue) and isinstance(coord_ref[0], CoordinateValue):
            coord_intersection = set(coord_input).intersection(set(coord_ref))
            result_coord_input = [coord for coord in coord_input if coord in coord_intersection]
            result_coord_ref = [coord for coord in coord_ref if coord in coord_intersection]
        else:
            # values are collected once in a set instead of scanning the other list for every label
            values_ref = set(coord.value for coord in coord_ref)
            result_coord_input = [coord for coord in coord_input if coord.value in values_ref]
            result_coord_ref = coord_ref
            if symmetry:
                values_kept = set(coord.value for coord in result_coord_input)
                result_coord_ref = [coord for coord in coord_ref if coord.value in values_kept]

        return result_coord_input, result_coord_ref


    def remove_label(self, symmetry=False):
        """
        Compare two label images and remove any labels in input image that are not in reference image.
        The symmetry option enables to remove labels from reference image that are not in input image.

        With symmetry, self.fname_output must hold two file names: the filtered reference image is saved
        here under the second one, and self.fname_output is then reduced to the first one.

        :param symmetry: if True, also filter and save the reference image
        :return: new Image built from self.image_input containing only the kept labels
        """
        # image_output = Image(self.image_input.dim, orientation=self.image_input.orientation, hdr=self.image_input.hdr, verbose=self.verbose)
        image_output = Image(self.image_input, verbose=self.verbose)
        image_output.data *= 0  # put all voxels to 0

        result_coord_input, result_coord_ref = self.remove_label_coord(self.image_input.getNonZeroCoordinates(coordValue=True),
                                                                       self.image_ref.getNonZeroCoordinates(coordValue=True), symmetry)

        for coord in result_coord_input:
            image_output.data[int(coord.x), int(coord.y), int(coord.z)] = int(round(coord.value))

        if symmetry:
            # image_output_ref = Image(self.image_ref.dim, orientation=self.image_ref.orientation, hdr=self.image_ref.hdr, verbose=self.verbose)
            image_output_ref = Image(self.image_ref, verbose=self.verbose)
            # put all voxels to 0, as done for the input image: otherwise the labels that are missing
            # from the input image would remain in the saved reference image
            image_output_ref.data *= 0
            for coord in result_coord_ref:
                image_output_ref.data[int(coord.x), int(coord.y), int(coord.z)] = int(round(coord.value))
            image_output_ref.setFileName(self.fname_output[1])
            image_output_ref.save('minimize_int')

            self.fname_output = self.fname_output[0]

        return image_output


    def extract_centerline(self):
        """
        Write a text file with the coordinates of the centerline.
        The image is supposed to be RPI.

        One line "x y z" (integers) is written to self.fname_output for each non-zero voxel of the input
        image, voxels being sorted along z.
        """
        coordinates_input = self.image_input.getNonZeroCoordinates(sorting='z')

        # the context manager closes the file even if formatting or writing a line fails
        with open(self.fname_output, "wb") as fo:
            for coord in coordinates_input:
                fo.write("%i %i %i\n" % (coord.x, coord.y, coord.z))


    def display_voxel(self):
        """
        Print the position and value of every label contained in the input image, followed by a compact
        one-line notation "x,y,z,value:x,y,z,value:...".
        The image is supposed to be RPI to display voxels, but this also works for other orientations.

        :return: list of the labels of the input image, sorted along z
        """
        coordinates_input = self.image_input.getNonZeroCoordinates(sorting='z')
        entries = []
        for coord in coordinates_input:
            position = str(coord.x) + ',' + str(coord.y) + ',' + str(coord.z)
            print('Position=(' + position + ') -- Value= ' + str(coord.value))
            entries.append(position + ',' + str(coord.value))
        print('All labels (useful syntax):')
        # labels are separated by ':' in the compact notation
        print(':'.join(entries))
        return coordinates_input


    def diff(self):
        """
        Print the label values found in only one of the two images (input and reference).
        Labels are compared on their value only; each unmatched label is printed on its own line.
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates()

        print("Label in input image that are not in reference image:")
        for coord in coordinates_input:
            if not any(coord.value == coord_ref.value for coord_ref in coordinates_ref):
                print(coord.value)

        print("Label in ref image that are not in input image:")
        for coord_ref in coordinates_ref:
            if not any(coord.value == coord_ref.value for coord in coordinates_input):
                print(coord_ref.value)


    def distance_interlabels(self, max_dist):
        """
        Compute the Euclidean distance between each pair of consecutive labels of the input image (in the
        order returned by getNonZeroCoordinates()) and print a warning when the test on max_dist succeeds.

        NOTE(review): the warning is printed when the distance is SMALLER than max_dist, whereas the
        parameter name and the message text say "larger" - confirm which one is intended.

        :param max_dist: distance threshold, in the same unit as the label coordinates
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()

        def describe(index, label):
            # "<index>[x,y,z]=<value>", as used in the warning message
            return str(index) + '[' + str(label.x) + ',' + str(label.y) + ',' + str(label.z) + ']=' + str(label.value)

        # compare each label with the following one
        for i in range(len(coordinates_input) - 1):
            first, second = coordinates_input[i], coordinates_input[i+1]
            dist = math.sqrt((first.x - second.x)**2 + (first.y - second.y)**2 + (first.z - second.z)**2)
            if dist < max_dist:
                print('Warning: the distance between label ' + describe(i, first) + ' and label ' + describe(i+1, second) + ' is larger than ' + str(max_dist) + '. Distance=' + str(dist))


    def continuous_vertebral_levels(self):
        """
        Transform the vertebral levels file from the template into a continuous file.

        Instead of an integer representing the vertebral level on each slice, every non-zero voxel receives
        the value level + d/Di, where Di is the length of the centerline inside the vertebral level and d is
        the length of the centerline between the current centerline point and the last centerline point of
        that level. Both lengths apply the voxel size once per axis, so the ratio does not depend on the
        image resolution.
        The image must be RPI.

        :return: image built from the input image, converted to float32, holding the continuous values
        """
        im_input = Image(self.image_input, self.verbose)
        im_output = Image(self.image_input, self.verbose)
        im_output.data *= 0

        # 1. extract vertebral levels from input image
        #   a. extract centerline
        #   b. for each centerline point, extract corresponding level
        _, _, _, _, px, py, pz, _ = im_input.dim
        from sct_straighten_spinalcord import smooth_centerline
        x_fit, y_fit, z_fit, _, _, _ = smooth_centerline(self.image_input, algo_fitting='nurbs', verbose=0)
        value_centerline = np.array([im_input.data[int(x_fit[it]), int(y_fit[it]), int(z_fit[it])]
                                     for it in range(len(z_fit))])

        def centerline_length(indexes):
            """Sum of the distances between successive centerline points listed in indexes."""
            return np.sum([math.sqrt(((x_fit[following] - x_fit[current]) * px) ** 2 +
                                     ((y_fit[following] - y_fit[current]) * py) ** 2 +
                                     ((z_fit[following] - z_fit[current]) * pz) ** 2)
                           for current, following in zip(indexes[:-1], indexes[1:])])

        # 2. compute the centerline length of each vertebral level --> Di for i being the vertebral levels
        length_levels = {}
        for level in value_centerline:
            if level not in length_levels:
                length_levels[level] = centerline_length(np.where(value_centerline == level)[0])

        # 3. for each centerline point:
        #   a. identify corresponding vertebral level --> i
        #   b. calculate distance of the point from the last point of the vertebral level --> d
        #   c. compute relative distance in the vertebral level coordinate system --> d/Di
        continuous_values = {}
        for it, iz in enumerate(z_fit):
            level = value_centerline[it]
            indexes_level = np.where(value_centerline == level)[0]
            indexes_level = indexes_level[indexes_level >= it]
            distance_from_level = centerline_length(indexes_level)
            length_level = float(length_levels[level])
            # a level covered by a single centerline point has no length: keep the plain level value
            relative_distance = distance_from_level / length_level if length_level > 0 else 0.0
            continuous_values[iz] = level + relative_distance

        # 4. saving data
        # for each non-zero voxel of the input, write the continuous value of its slice
        coordinates_input = self.image_input.getNonZeroCoordinates()
        im_output.changeType('float32')
        # NOTE(review): the lookup requires coord.z to be equal to one of the z values of the fitted
        # centerline; a slice without matching centerline point raises KeyError -- confirm upstream
        for coord in coordinates_input:
            im_output.data[int(coord.x), int(coord.y), int(coord.z)] = continuous_values[coord.z]

        return im_output
def scadMRValidation(algorithm, isPython=False, verbose=True):
    """
    Run a spinal cord detection algorithm on the t1 validation data and check the resulting centerlines.

    The 'data' folder is copied into a new experiment folder named after the algorithm and the current
    time. For every subject folder of 't1/', each image that has a matching 'manual_segmentation' file is
    processed by the algorithm, which is expected to produce '<image name>_centerline<ext>'. The centerline
    is then checked voxel by voxel against the manual segmentation and the outcome is printed.
    The process exits with status 2 as soon as the algorithm fails on an image.
    The current working directory is changed and is not restored on return.

    :param algorithm: name of the algorithm: a command name, or the name of a Python callable if isPython
    :param isPython: if True, the algorithm is called as a Python function instead of a shell command
    :param verbose: verbosity forwarded to the algorithm
    """
    if not isinstance(algorithm, str) or not algorithm:
        print('ERROR: You must provide the name of your algorithm as a string.')
        usage()
        # nothing can be validated without a valid algorithm name
        return

    import time
    import sct_utils as sct
    from msct_image import Image

    # creating a new folder with the experiment
    path_experiment = 'scad-experiment.'+algorithm+'.'+time.strftime("%y%m%d%H%M%S")
    #status, output = sct.run('mkdir '+path_experiment, verbose)

    # copying images from "data" folder into experiment folder
    sct.copyDirectory('data', path_experiment)

    # Starting validation
    os.chdir(path_experiment)
    # t1
    os.chdir('t1/')
    for subject_dir in os.listdir('./'):
        if not os.path.isdir(subject_dir):
            continue
        os.chdir(subject_dir)

        # creating list of images and corresponding manual segmentation
        file_names = os.listdir('./')
        list_images = dict()
        for file_name in file_names:
            if 'manual_segmentation' in file_name:
                continue
            file_base = sct.extract_fname(file_name)[1]
            for file_name_corr in file_names:
                if 'manual_segmentation' in file_name_corr and file_base in file_name_corr:
                    list_images[file_name] = file_name_corr

        # running the proposed algorithm on images in the folder and analyzing the results
        for image, image_manual_seg in list_images.items():
            print(image)
            path_in, file_in, ext_in = sct.extract_fname(image)
            image_output = file_in+'_centerline'+ext_in
            if isPython:
                try:
                    # NOTE(review): eval executes arbitrary code; 'algorithm' must come from a trusted caller.
                    # The file name and the contrast are passed as quoted string literals.
                    eval(algorithm+'('+repr(image)+", 't1', verbose="+str(verbose)+')')
                except Exception as e:
                    print('Error during spinal cord detection on line {}:'.format(sys.exc_info()[-1].tb_lineno))
                    print('Subject: t1/'+subject_dir+'/'+image)
                    print(e)
                    sys.exit(2)
            else:
                cmd = algorithm+' -i '+image+' -t t1'
                if verbose:
                    cmd += ' -v'
                status, output = sct.run(cmd, verbose=verbose)
                if status != 0:
                    print('Error during spinal cord detection on Subject: t1/'+subject_dir+'/'+image)
                    print(output)
                    sys.exit(2)

            # analyzing the resulting centerline
            manual_segmentation_image = Image(image_manual_seg)
            manual_segmentation_image.change_orientation()
            centerline_image = Image(image_output)
            centerline_image.change_orientation()

            # coord_manseg = manual_segmentation_image.getNonZeroCoordinates()
            coord_centerline = centerline_image.getNonZeroCoordinates()

            # check if centerline is in manual segmentation
            result_centerline_in = True
            for coord in coord_centerline:
                # voxel indices are cast to int, as done by the other label utilities of this file
                if manual_segmentation_image.data[int(coord.x), int(coord.y), int(coord.z)] == 0:
                    result_centerline_in = False
                    print('failed on slice #' + str(coord.z))
                    break
            if result_centerline_in:
                print('OK: Centerline is inside manual segmentation.')
            else:
                print('FAIL: Centerline is outside manual segmentation.')

            # check the length of centerline compared to manual segmentation
            # import sct_process_segmentation as sct_seg
            # length_manseg = sct_seg.compute_length(image_manual_seg)
            # length_centerline = sct_seg.compute_length(image_output)
            # if length_manseg*0.9 <= length_centerline <= length_manseg*1.1:
            #     print 'OK: Length of centerline correspond to length of manual segmentation.'
            # else:
            #     print 'FAIL: Length of centerline does not correspond to length of manual segmentation.'
        os.chdir('..')
def main():
    """
    Register an anatomical image to the template and generate the forward and inverse warping fields.

    Steps, all driven by the command-line arguments parsed from sys.argv:
    1. locate the template, template segmentation and template label files and check the input labels;
    2. copy everything into a temporary folder, smooth the segmentation, resample to 1mm isotropic and
       reorient to RPI;
    3. straighten the spinal cord and compute a landmark-based affine transformation to the template;
    4. run the registration steps listed in the module-level 'paramreg' object;
    5. concatenate the transformations, write the output files and optionally delete the temporary folder.

    This function relies on names defined outside of it (get_parser, Param, paramreg, sct, Image,
    add_suffix, resample_labels, find_zmin_zmax, register).
    """
    parser = get_parser()
    param = Param()

    arguments = parser.parse(sys.argv[1:])

    # get arguments
    fname_data = arguments['-i']
    fname_seg = arguments['-s']
    fname_landmarks = arguments['-l']
    if '-ofolder' in arguments:
        path_output = arguments['-ofolder']
    else:
        path_output = ''
    path_template = sct.slash_at_the_end(arguments['-t'], 1)
    contrast_template = arguments['-c']
    remove_temp_files = int(arguments['-r'])
    verbose = int(arguments['-v'])
    if '-param-straighten' in arguments:
        param.param_straighten = arguments['-param-straighten']
    # the key tested must be the one that is read ('-cpu-nb'), otherwise the option is never forwarded
    if '-cpu-nb' in arguments:
        arg_cpu = ' -cpu-nb '+arguments['-cpu-nb']
    else:
        arg_cpu = ''
    if '-param' in arguments:
        paramreg_user = arguments['-param']
        # update registration parameters
        for paramStep in paramreg_user:
            paramreg.addStep(paramStep)

    # initialize other parameters
    file_template_label = param.file_template_label
    output_type = param.output_type
    zsubsample = param.zsubsample
    # smoothing_sigma = param.smoothing_sigma

    # capitalize letters for contrast
    if contrast_template == 't1':
        contrast_template = 'T1'
    elif contrast_template == 't2':
        contrast_template = 'T2'

    # retrieve file_template based on contrast
    fname_template_list = glob(path_template+param.folder_template+'*'+contrast_template+'.nii.gz')
    # report a missing template explicitly instead of failing on the indexing below
    if not fname_template_list:
        sct.printv('ERROR: No template file found for contrast '+contrast_template+' in '+path_template+param.folder_template, 1, 'error')
    # TODO: make sure there is only one file
    fname_template = fname_template_list[0]

    # retrieve file_template_seg
    fname_template_seg_list = glob(path_template+param.folder_template+'*cord.nii.gz')
    if not fname_template_seg_list:
        sct.printv('ERROR: No template segmentation file (*cord.nii.gz) found in '+path_template+param.folder_template, 1, 'error')
    # TODO: make sure there is only one file
    fname_template_seg = fname_template_seg_list[0]

    # start timer
    start_time = time.time()

    # get absolute path - TO DO: remove! NEVER USE ABSOLUTE PATH...
    path_template = os.path.abspath(path_template+param.folder_template)

    # get fname of the template + template objects
    # fname_template = sct.slash_at_the_end(path_template, 1)+file_template
    fname_template_label = sct.slash_at_the_end(path_template, 1)+file_template_label
    # fname_template_seg = sct.slash_at_the_end(path_template, 1)+file_template_seg

    # check file existence
    sct.printv('\nCheck template files...')
    sct.check_file_exist(fname_template, verbose)
    sct.check_file_exist(fname_template_label, verbose)
    sct.check_file_exist(fname_template_seg, verbose)

    # print arguments
    sct.printv('\nCheck parameters:', verbose)
    sct.printv('.. Data:                 '+fname_data, verbose)
    sct.printv('.. Landmarks:            '+fname_landmarks, verbose)
    sct.printv('.. Segmentation:         '+fname_seg, verbose)
    sct.printv('.. Path template:        '+path_template, verbose)
    sct.printv('.. Path output:          '+path_output, verbose)
    sct.printv('.. Output type:          '+str(output_type), verbose)
    sct.printv('.. Remove temp files:    '+str(remove_temp_files), verbose)

    sct.printv('\nParameters for registration:')
    for pStep in range(1, len(paramreg.steps)+1):
        sct.printv('Step #'+paramreg.steps[str(pStep)].step, verbose)
        sct.printv('.. Type #'+paramreg.steps[str(pStep)].type, verbose)
        sct.printv('.. Algorithm................ '+paramreg.steps[str(pStep)].algo, verbose)
        sct.printv('.. Metric................... '+paramreg.steps[str(pStep)].metric, verbose)
        sct.printv('.. Number of iterations..... '+paramreg.steps[str(pStep)].iter, verbose)
        sct.printv('.. Shrink factor............ '+paramreg.steps[str(pStep)].shrink, verbose)
        sct.printv('.. Smoothing factor......... '+paramreg.steps[str(pStep)].smooth, verbose)
        sct.printv('.. Gradient step............ '+paramreg.steps[str(pStep)].gradStep, verbose)
        sct.printv('.. Degree of polynomial..... '+paramreg.steps[str(pStep)].poly, verbose)

    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    sct.printv('\nCheck input labels...')
    # check if label image contains coherent labels
    image_label = Image(fname_landmarks)
    # -> all labels must be different
    labels = image_label.getNonZeroCoordinates(sorting='value')
    hasDifferentLabels = True
    for lab in labels:
        for otherlabel in labels:
            if lab != otherlabel and lab.hasEqualValue(otherlabel):
                hasDifferentLabels = False
                break
    if not hasDifferentLabels:
        sct.printv('ERROR: Wrong landmarks input. All labels must be different.', verbose, 'error')
    # all labels must be available in template
    image_label_template = Image(fname_template_label)
    labels_template = image_label_template.getNonZeroCoordinates(sorting='value')
    if labels[-1].value > labels_template[-1].value:
        sct.printv('ERROR: Wrong landmarks input. Labels must have correspondence in template space. \nLabel max '
                   'provided: ' + str(labels[-1].value) + '\nLabel max from template: ' +
                   str(labels_template[-1].value), verbose, 'error')

    # create temporary folder
    path_tmp = sct.tmp_create(verbose=verbose)

    # set temporary file names
    ftmp_data = 'data.nii'
    ftmp_seg = 'seg.nii.gz'
    ftmp_label = 'label.nii.gz'
    ftmp_template = 'template.nii'
    ftmp_template_seg = 'template_seg.nii.gz'
    ftmp_template_label = 'template_label.nii.gz'

    # copy files to temporary folder
    sct.printv('\nCopying input data to tmp folder and convert to nii...', verbose)
    sct.run('sct_convert -i '+fname_data+' -o '+path_tmp+ftmp_data)
    sct.run('sct_convert -i '+fname_seg+' -o '+path_tmp+ftmp_seg)
    sct.run('sct_convert -i '+fname_landmarks+' -o '+path_tmp+ftmp_label)
    sct.run('sct_convert -i '+fname_template+' -o '+path_tmp+ftmp_template)
    sct.run('sct_convert -i '+fname_template_seg+' -o '+path_tmp+ftmp_template_seg)
    sct.run('sct_convert -i '+fname_template_label+' -o '+path_tmp+ftmp_template_label)

    # go to tmp folder
    os.chdir(path_tmp)

    # smooth segmentation (jcohenadad, issue #613)
    sct.printv('\nSmooth segmentation...', verbose)
    sct.run('sct_maths -i '+ftmp_seg+' -smooth 1.5 -o '+add_suffix(ftmp_seg, '_smooth'))
    ftmp_seg = add_suffix(ftmp_seg, '_smooth')

    # resample data to 1mm isotropic
    sct.printv('\nResample data to 1mm isotropic...', verbose)
    sct.run('sct_resample -i '+ftmp_data+' -mm 1.0x1.0x1.0 -x linear -o '+add_suffix(ftmp_data, '_1mm'))
    ftmp_data = add_suffix(ftmp_data, '_1mm')
    sct.run('sct_resample -i '+ftmp_seg+' -mm 1.0x1.0x1.0 -x linear -o '+add_suffix(ftmp_seg, '_1mm'))
    ftmp_seg = add_suffix(ftmp_seg, '_1mm')
    # N.B. resampling of labels is more complicated, because they are single-point labels, therefore resampling with nearest neighbour can make them disappear. Therefore a more clever approach is required.
    resample_labels(ftmp_label, ftmp_data, add_suffix(ftmp_label, '_1mm'))
    ftmp_label = add_suffix(ftmp_label, '_1mm')

    # Change orientation of input images to RPI
    sct.printv('\nChange orientation of input images to RPI...', verbose)
    sct.run('sct_image -i '+ftmp_data+' -setorient RPI -o '+add_suffix(ftmp_data, '_rpi'))
    ftmp_data = add_suffix(ftmp_data, '_rpi')
    sct.run('sct_image -i '+ftmp_seg+' -setorient RPI -o '+add_suffix(ftmp_seg, '_rpi'))
    ftmp_seg = add_suffix(ftmp_seg, '_rpi')
    sct.run('sct_image -i '+ftmp_label+' -setorient RPI -o '+add_suffix(ftmp_label, '_rpi'))
    ftmp_label = add_suffix(ftmp_label, '_rpi')

    # get landmarks in native space
    # crop segmentation
    # output: segmentation_rpi_crop.nii.gz
    status_crop, output_crop = sct.run('sct_crop_image -i '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_crop')+' -dim 2 -bzmax', verbose)
    ftmp_seg = add_suffix(ftmp_seg, '_crop')
    # NOTE(review): cropping_slices is not used further down in this function
    cropping_slices = output_crop.split('Dimension 2: ')[1].split('\n')[0].split(' ')

    # straighten segmentation
    sct.printv('\nStraighten the spinal cord using centerline/segmentation...', verbose)
    sct.run('sct_straighten_spinalcord -i '+ftmp_seg+' -s '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_straight')+' -qc 0 -r 0 -v '+str(verbose)+' '+param.param_straighten+arg_cpu, verbose)
    # N.B. DO NOT UPDATE VARIABLE ftmp_seg BECAUSE TEMPORARY USED LATER
    # re-define warping field using non-cropped space (to avoid issue #367)
    sct.run('sct_concat_transfo -w warp_straight2curve.nii.gz -d '+ftmp_data+' -o warp_straight2curve.nii.gz')

    # Label preparation:
    # --------------------------------------------------------------------------------
    # Remove unused label on template. Keep only label present in the input label image
    sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
    sct.run('sct_label_utils -p remove -i '+ftmp_template_label+' -o '+ftmp_template_label+' -r '+ftmp_label)

    # Dilating the input label so they can be straighten without losing them
    sct.printv('\nDilating input labels using 3vox ball radius')
    sct.run('sct_maths -i '+ftmp_label+' -o '+add_suffix(ftmp_label, '_dilate')+' -dilate 3')
    ftmp_label = add_suffix(ftmp_label, '_dilate')

    # Apply straightening to labels
    sct.printv('\nApply straightening to labels...', verbose)
    sct.run('sct_apply_transfo -i '+ftmp_label+' -o '+add_suffix(ftmp_label, '_straight')+' -d '+add_suffix(ftmp_seg, '_straight')+' -w warp_curve2straight.nii.gz -x nn')
    ftmp_label = add_suffix(ftmp_label, '_straight')

    # Create crosses for the template labels and get coordinates
    sct.printv('\nCreate a 15 mm cross for the template labels...', verbose)
    template_image = Image(ftmp_template_label)
    coordinates_input = template_image.getNonZeroCoordinates(sorting='value')
    # jcohenadad, issue #628 <<<<<
    # landmark_template = ProcessLabels.get_crosses_coordinates(coordinates_input, gapxy=15)
    landmark_template = coordinates_input
    # >>>>>
    if verbose == 2:
        # TODO: assign cross to image before saving
        template_image.setFileName(add_suffix(ftmp_template_label, '_cross'))
        template_image.save(type='minimize_int')

    # Create crosses for the input labels into straight space and get coordinates
    sct.printv('\nCreate a 15 mm cross for the input labels...', verbose)
    label_straight_image = Image(ftmp_label)
    coordinates_input = label_straight_image.getCoordinatesAveragedByValue()  # landmarks are sorted by value
    # jcohenadad, issue #628 <<<<<
    # landmark_straight = ProcessLabels.get_crosses_coordinates(coordinates_input, gapxy=15)
    landmark_straight = coordinates_input
    # >>>>>
    if verbose == 2:
        # TODO: assign cross to image before saving
        label_straight_image.setFileName(add_suffix(ftmp_label, '_cross'))
        label_straight_image.save(type='minimize_int')

    # Reorganize landmarks
    points_fixed, points_moving = [], []
    for coord in landmark_straight:
        point_straight = label_straight_image.transfo_pix2phys([[coord.x, coord.y, coord.z]])
        points_moving.append([point_straight[0][0], point_straight[0][1], point_straight[0][2]])

    for coord in landmark_template:
        point_template = template_image.transfo_pix2phys([[coord.x, coord.y, coord.z]])
        points_fixed.append([point_template[0][0], point_template[0][1], point_template[0][2]])

    # Register curved landmarks on straight landmarks based on python implementation
    sct.printv('\nComputing rigid transformation (algo=translation-scaling-z) ...', verbose)

    import msct_register_landmarks
    # for some reason, the moving and fixed points are inverted between ITK transform and our python-based transform.
    # and for another unknown reason, x and y dimensions have a negative sign (at least for translation and center of rotation).
    show_transfo = (verbose == 2)
    (rotation_matrix, translation_array, points_moving_reg, points_moving_barycenter) = msct_register_landmarks.getRigidTransformFromLandmarks(points_moving, points_fixed, constraints='translation-scaling-z', show=show_transfo)
    # writing rigid transformation file (the context manager closes the file even if a write fails)
    with open("straight2templateAffine.txt", "w") as text_file:
        text_file.write("#Insight Transform File V1.0\n")
        text_file.write("#Transform 0\n")
        text_file.write("Transform: AffineTransform_double_3_3\n")
        text_file.write("Parameters: %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f\n" % (
            rotation_matrix[0, 0], rotation_matrix[0, 1], rotation_matrix[0, 2],
            rotation_matrix[1, 0], rotation_matrix[1, 1], rotation_matrix[1, 2],
            rotation_matrix[2, 0], rotation_matrix[2, 1], rotation_matrix[2, 2],
            -translation_array[0, 0], -translation_array[0, 1], translation_array[0, 2]))
        text_file.write("FixedParameters: %.9f %.9f %.9f\n" % (-points_moving_barycenter[0],
                                                               -points_moving_barycenter[1],
                                                               points_moving_barycenter[2]))

    # Concatenate transformations: curve --> straight --> affine
    sct.printv('\nConcatenate transformations: curve --> straight --> affine...', verbose)
    sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt -d template.nii -o warp_curve2straightAffine.nii.gz')

    # Apply transformation
    sct.printv('\nApply transformation...', verbose)
    sct.run('sct_apply_transfo -i '+ftmp_data+' -o '+add_suffix(ftmp_data, '_straightAffine')+' -d '+ftmp_template+' -w warp_curve2straightAffine.nii.gz')
    ftmp_data = add_suffix(ftmp_data, '_straightAffine')
    sct.run('sct_apply_transfo -i '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_straightAffine')+' -d '+ftmp_template+' -w warp_curve2straightAffine.nii.gz -x linear')
    ftmp_seg = add_suffix(ftmp_seg, '_straightAffine')

    # threshold and binarize
    sct.printv('\nBinarize segmentation...', verbose)
    sct.run('sct_maths -i '+ftmp_seg+' -thr 0.4 -o '+add_suffix(ftmp_seg, '_thr'))
    sct.run('sct_maths -i '+add_suffix(ftmp_seg, '_thr')+' -bin -o '+add_suffix(ftmp_seg, '_thr_bin'))
    ftmp_seg = add_suffix(ftmp_seg, '_thr_bin')

    # find min-max of anat2template (for subsequent cropping)
    zmin_template, zmax_template = find_zmin_zmax(ftmp_seg)

    # crop template in z-direction (for faster processing)
    sct.printv('\nCrop data in template space (for faster processing)...', verbose)
    sct.run('sct_crop_image -i '+ftmp_template+' -o '+add_suffix(ftmp_template, '_crop')+' -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    ftmp_template = add_suffix(ftmp_template, '_crop')
    sct.run('sct_crop_image -i '+ftmp_template_seg+' -o '+add_suffix(ftmp_template_seg, '_crop')+' -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    ftmp_template_seg = add_suffix(ftmp_template_seg, '_crop')
    sct.run('sct_crop_image -i '+ftmp_data+' -o '+add_suffix(ftmp_data, '_crop')+' -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    ftmp_data = add_suffix(ftmp_data, '_crop')
    sct.run('sct_crop_image -i '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_crop')+' -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    ftmp_seg = add_suffix(ftmp_seg, '_crop')

    # sub-sample in z-direction
    sct.printv('\nSub-sample in z-direction (for faster processing)...', verbose)
    sct.run('sct_resample -i '+ftmp_template+' -o '+add_suffix(ftmp_template, '_sub')+' -f 1x1x'+zsubsample, verbose)
    ftmp_template = add_suffix(ftmp_template, '_sub')
    sct.run('sct_resample -i '+ftmp_template_seg+' -o '+add_suffix(ftmp_template_seg, '_sub')+' -f 1x1x'+zsubsample, verbose)
    ftmp_template_seg = add_suffix(ftmp_template_seg, '_sub')
    sct.run('sct_resample -i '+ftmp_data+' -o '+add_suffix(ftmp_data, '_sub')+' -f 1x1x'+zsubsample, verbose)
    ftmp_data = add_suffix(ftmp_data, '_sub')
    sct.run('sct_resample -i '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_sub')+' -f 1x1x'+zsubsample, verbose)
    ftmp_seg = add_suffix(ftmp_seg, '_sub')

    # Registration straight spinal cord to template
    sct.printv('\nRegister straight spinal cord to template...', verbose)

    # loop across registration steps
    warp_forward = []
    warp_inverse = []
    for i_step in range(1, len(paramreg.steps)+1):
        sct.printv('\nEstimate transformation for step #'+str(i_step)+'...', verbose)
        # identify which is the src and dest
        if paramreg.steps[str(i_step)].type == 'im':
            src = ftmp_data
            dest = ftmp_template
            interp_step = 'linear'
        elif paramreg.steps[str(i_step)].type == 'seg':
            src = ftmp_seg
            dest = ftmp_template_seg
            interp_step = 'nn'
        else:
            sct.printv('ERROR: Wrong image type.', 1, 'error')
        # if step>1, apply warp_forward_concat to the src image to be used
        if i_step > 1:
            # sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+sct.add_suffix(src, '_reg')+' -x '+interp_step, verbose)
            sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+add_suffix(src, '_reg')+' -x '+interp_step, verbose)
            src = add_suffix(src, '_reg')
        # register src --> dest
        warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
        warp_forward.append(warp_forward_out)
        warp_inverse.append(warp_inverse_out)

    # Concatenate transformations:
    sct.printv('\nConcatenate transformations: anat --> template...', verbose)
    sct.run('sct_concat_transfo -w warp_curve2straightAffine.nii.gz,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
    # sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
    sct.printv('\nConcatenate transformations: template --> anat...', verbose)
    warp_inverse.reverse()
    sct.run('sct_concat_transfo -w '+','.join(warp_inverse)+',-straight2templateAffine.txt,warp_straight2curve.nii.gz -d data.nii -o warp_template2anat.nii.gz', verbose)

    # Apply warping fields to anat and template
    if output_type == 1:
        sct.run('sct_apply_transfo -i template.nii -o template2anat.nii.gz -d data.nii -w warp_template2anat.nii.gz -crop 1', verbose)
        sct.run('sct_apply_transfo -i data.nii -o anat2template.nii.gz -d template.nii -w warp_anat2template.nii.gz -crop 1', verbose)

    # come back to parent folder
    os.chdir('..')

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(path_tmp+'warp_template2anat.nii.gz', path_output+'warp_template2anat.nii.gz', verbose)
    sct.generate_output_file(path_tmp+'warp_anat2template.nii.gz', path_output+'warp_anat2template.nii.gz', verbose)
    if output_type == 1:
        sct.generate_output_file(path_tmp+'template2anat.nii.gz', path_output+'template2anat'+ext_data, verbose)
        sct.generate_output_file(path_tmp+'anat2template.nii.gz', path_output+'anat2template'+ext_data, verbose)

    # Delete temporary files
    if remove_temp_files:
        sct.printv('\nDelete temporary files...', verbose)
        sct.run('rm -rf '+path_tmp)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: '+str(int(round(elapsed_time)))+'s', verbose)

    # to view results
    sct.printv('\nTo view results, type:', verbose)
    sct.printv('fslview '+fname_data+' '+path_output+'template2anat -b 0,4000 &', verbose, 'info')
    sct.printv('fslview '+fname_template+' -b 0,5000 '+path_output+'anat2template &\n', verbose, 'info')
Esempio n. 12
0
class ProcessLabels(object):
    """
    Bundle of operations on label images (images whose non-zero voxels carry a label value).

    An instance wraps an input label image, an optional reference image and an optional output file name.
    process() dispatches to the requested operation and saves the resulting image when there is one.
    """

    def __init__(self, fname_label, fname_output=None, fname_ref=None, cross_radius=5, dilate=False,
                 coordinates=None, verbose=1):
        """
        :param fname_label: input label image, handed to Image()
        :param fname_output: output file name or list of file names. A one-element list is unwrapped;
                             'remove-symm' reads two entries: [output for input, output for reference]
        :param fname_ref: optional reference image, handed to Image()
        :param cross_radius: radius used by cross(); also passed as the plane half-thickness by process('plan')
        :param dilate: if True, cross() dilates each arm of the cross to a 3x3 patch
        :param coordinates: list of coordinates (x, y, z, value) consumed by create_label()
        :param verbose: verbosity level forwarded to Image()
        """
        self.image_input = Image(fname_label, verbose=verbose)

        # always define the attribute so that operations needing a reference can detect its absence
        self.image_ref = None
        if fname_ref is not None:
            self.image_ref = Image(fname_ref, verbose=verbose)

        if isinstance(fname_output, list) and len(fname_output) == 1:
            self.fname_output = fname_output[0]
        else:
            self.fname_output = fname_output
        self.output_image = None
        self.cross_radius = cross_radius
        self.dilate = dilate
        self.coordinates = coordinates
        self.verbose = verbose

    def _require_ref(self, type_process):
        """
        Raise a clear error when an operation that needs the reference image is run without one.
        :param type_process: name of the operation, used in the error message
        :raises ValueError: if no reference image was provided
        """
        if getattr(self, 'image_ref', None) is None:
            raise ValueError("Process '" + type_process + "' requires a reference image (fname_ref).")

    def process(self, type_process):
        """
        Run the operation named by type_process and, when it produces an image and an output file name is set,
        save that image.
        :param type_process: one of 'cross', 'plan', 'plan_ref', 'increment', 'disks', 'MSE', 'remove',
                             'remove-symm', 'centerline', 'display-voxel', 'create', 'add', 'diff', 'dist-inter',
                             'cubic-to-point'
        """
        # operations returning an image that is saved afterwards
        image_processes = {
            'cross': self.cross,
            'plan': lambda: self.plan(self.cross_radius, 100, 5),
            'plan_ref': self.plan_ref,
            'increment': self.increment_z_inverse,
            'disks': self.labelize_from_disks,
            'remove': self.remove_label,
            'remove-symm': lambda: self.remove_label(symmetry=True),
            'create': self.create_label,
            'add': lambda: self.create_label(add=True),
            'cubic-to-point': self.cubic_to_point,
        }
        # operations that only report: there is nothing to save, so the output file name is dropped
        report_processes = {
            'MSE': self.MSE,
            'display-voxel': self.display_voxel,
            'diff': self.diff,
            'dist-inter': lambda: self.distance_interlabels(5),  # argument is in pixel distance
        }

        # Only an image produced by THIS call may be saved: 'centerline' and unknown processes produce none, and
        # must neither crash on a missing attribute nor re-save an image left over from a previous call.
        output_image = None
        if type_process in image_processes:
            output_image = image_processes[type_process]()
            self.output_image = output_image
        elif type_process in report_processes:
            report_processes[type_process]()
            self.fname_output = None
        elif type_process == 'centerline':
            # writes its own text file to self.fname_output
            self.extract_centerline()
        else:
            sct.printv('Error: The chosen process is not available.',1,'error')

        # save the output image as minimized integers (except for plan_ref, which is saved as is)
        # self.fname_output is read here, after the operation, because remove_label(symmetry=True) rewrites it
        if output_image is not None and self.fname_output is not None:
            output_image.setFileName(self.fname_output)
            if type_process != 'plan_ref':
                output_image.save('minimize_int')
            else:
                output_image.save()

    def cross(self):
        """
        Replace each label of the input image by a cross of four labels located at cross_radius from it, in the
        x-y plane. For a label of value v, the four points get the values 10*v+1 (y+), 10*v+2 (x+), 10*v+3 (y-)
        and 10*v+4 (x-), and the original label is set to 0. If self.dilate is set, each point is dilated to 3x3.
        :return: image containing the crosses
        """
        image_output = Image(self.image_input, self.verbose)
        nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(self.image_input.absolutepath)

        coordinates_input = self.image_input.getNonZeroCoordinates()
        # presumably px/py are the voxel sizes, making cross_radius a physical distance -- TODO confirm.
        # The offsets are used as array indices, so they have to be integers.
        dx = int(round(float(self.cross_radius) / px))
        dy = int(round(float(self.cross_radius) / py))

        data = image_output.data
        for coord in coordinates_input:
            x, y, z = int(coord.x), int(coord.y), int(coord.z)
            # arms of the cross, in the order of their value suffix (1 to 4)
            arms = [(x, y + dy), (x + dx, y), (x, y - dy), (x - dx, y)]

            data[x, y, z] = 0  # remove point on the center of the spinal cord
            for suffix, (arm_x, arm_y) in enumerate(arms, 1):
                data[arm_x, arm_y, z] = coord.value * 10 + suffix  # add point at distance from center

            # dilate cross to 3x3
            if self.dilate:
                for arm_x, arm_y in arms:
                    arm_value = data[arm_x, arm_y, z]
                    for offset_x in (-1, 0, 1):
                        for offset_y in (-1, 0, 1):
                            data[arm_x + offset_x, arm_y + offset_y, z] = arm_value

        return image_output

    def plan(self, width, offset=0, gap=1):
        """
        Create, for each label of the input image, a plane orthogonal to z with the value offset + gap * label.
        The plane covers the slices z-width to z+width-1 (upper bound excluded), clipped to the image.
        :param width: half-thickness of the plane, in slices
        :param offset: constant added to the value of each plane
        :param gap: factor applied to the label value
        :return: image containing the planes
        """
        image_output = Image(self.image_input, self.verbose)
        image_output.data *= 0
        coordinates_input = self.image_input.getNonZeroCoordinates()

        for coord in coordinates_input:
            # a negative start index would be read as "counted from the end" and select the wrong slices
            z_start = max(coord.z - width, 0)
            image_output.data[:, :, z_start:coord.z + width] = offset + gap * coord.value

        return image_output

    def plan_ref(self):
        """
        Generate a plane in the reference space for each label present in the input image.
        Negative and positive labels are handled separately because getNonZeroCoordinates is applied on
        positive values; the sign is restored when the planes are written.
        :return: float32 image, in the reference space, containing the planes
        :raises ValueError: if no reference image was provided
        """
        self._require_ref('plan_ref')

        image_output = Image(self.image_ref, self.verbose)
        image_output.data *= 0

        image_input_neg = Image(self.image_input, self.verbose).copy()
        image_input_pos = Image(self.image_input, self.verbose).copy()
        image_input_neg.data *= 0
        image_input_pos.data *= 0
        is_negative = self.image_input.data < 0
        is_positive = self.image_input.data > 0
        # negative labels are stored with their sign flipped, in order to apply getNonZeroCoordinates
        image_input_neg.data[is_negative] = -self.image_input.data[is_negative]
        image_input_pos.data[is_positive] = self.image_input.data[is_positive]

        coordinates_input_neg = image_input_neg.getNonZeroCoordinates()
        coordinates_input_pos = image_input_pos.getNonZeroCoordinates()

        image_output.changeType('float32')
        for coord in coordinates_input_neg:
            # NOTE(review): original author reports that coord.value may come back truncated to an integer here
            image_output.data[:, :, int(coord.z)] = -coord.value
        for coord in coordinates_input_pos:
            image_output.data[:, :, int(coord.z)] = coord.value

        return image_output

    def cubic_to_point(self):
        """
        Compute the center of mass of each group of labels and return an image of the same size with a single
        label per group, located at the (rounded) center of mass.
        It is to be used after applying a homothetic warping field to a label file, as the labels get dilated.
        Groups are defined by label value: separate clusters sharing the same value are merged into one group.
        :return: image with one label per group
        """
        image_output = self.image_input.copy()
        image_output.data *= 0
        coordinates = self.image_input.getNonZeroCoordinates(sorting='value')

        # gather the voxels of each label value in one pass, keeping the order in which values first appear
        list_values = []
        group_by_value = {}
        for coord in coordinates:
            if coord.value not in group_by_value:
                list_values.append(coord.value)
                group_by_value[coord.value] = []
            group_by_value[coord.value].append(coord)

        # put value of group at each center of mass
        for value in list_values:
            group = group_by_value[value]
            # float division: an integer division would floor the mean before it gets rounded
            nb_voxels = float(len(group))
            bary_x = int(round(sum(coord.x for coord in group) / nb_voxels))
            bary_y = int(round(sum(coord.y for coord in group) / nb_voxels))
            bary_z = int(round(sum(coord.z for coord in group) / nb_voxels))
            image_output.data[bary_x, bary_y, bary_z] = value

        return image_output

    def increment_z_inverse(self):
        """
        Renumber all the labels present in the input image, inversely ordered by z: the label with the highest z
        gets 1, the next one 2, and so on.
        Therefore, labels are incremented from top to bottom, assuming a RPI orientation.
        Labels are assumed to be non-zero.
        :return: image with renumbered labels
        """
        image_output = Image(self.image_input, self.verbose)
        image_output.data *= 0
        coordinates_input = self.image_input.getNonZeroCoordinates(sorting='z', reverse_coord=True)

        # labels come sorted by decreasing z: their rank (starting at 1) becomes their new value
        for rank, coord in enumerate(coordinates_input, 1):
            image_output.data[coord.x, coord.y, coord.z] = rank

        return image_output

    def labelize_from_disks(self):
        """
        Create an image with regions labelized depending on values from the reference image.
        Typically, the input is a segmentation image and the reference holds labels at disk positions; the
        result is the segmentation with each voxel set to the value of the reference label that bounds it from
        above: a voxel gets the value of ref[j] when ref[j+1].z < z <= ref[j].z.
        Labels are assumed to be non-zero and incremented from top to bottom, assuming a RPI orientation.
        Voxels outside every interval stay at 0.
        :return: labelized image
        :raises ValueError: if no reference image was provided
        """
        self._require_ref('disks')

        image_output = Image(self.image_input, self.verbose)
        image_output.data *= 0
        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates(sorting='value')

        # for all points in input, find the value that has to be set up, depending on the vertebral level
        for coord in coordinates_input:
            for j in range(len(coordinates_ref) - 1):
                if coordinates_ref[j + 1].z < coord.z <= coordinates_ref[j].z:
                    image_output.data[coord.x, coord.y, coord.z] = coordinates_ref[j].value

        return image_output

    def symmetrizer(self, side='left'):
        """
        Symmetrize the input image: one side of the image is meant to be copied on the other side, assuming a
        RPI orientation.
        NOTE(review): this looks unfinished -- `side` is not used and the returned image is built from the input
        without any visible symmetrization.
        :param side: string 'left' or 'right'. Side that will be copied on the other side.
        :return: image built from the input image
        """
        image_output = Image(self.image_input, self.verbose)

        # placeholder: the result of this expression is discarded
        image_output[0:]

        # inspiration: (from atlas creation matlab script)
        # temp_sum = temp_g + temp_d;
        # temp_sum_flip = temp_sum(end:-1:1,:);
        # temp_sym = (temp_sum + temp_sum_flip) / 2;
        #
        # temp_g(1:end / 2,:) = 0;
        # temp_g(1 + end / 2:end,:) = temp_sym(1 + end / 2:end,:);
        # temp_d(1:end / 2,:) = temp_sym(1:end / 2,:);
        # temp_d(1 + end / 2:end,:) = 0;
        #
        # tractsHR{label_l}(:,:, num_slice_ref) = temp_g;
        # tractsHR{label_r}(:,:, num_slice_ref) = temp_d;

        return image_output

    def MSE(self, threshold_mse=0):
        """
        Compute the error along z between two sets of labels (input and ref), matched by rounded label value.
        The returned figure is the square root of the mean squared z difference; the sum only includes input
        labels that have a match in the reference, but the mean is taken over all input labels.
        A warning is generated for each label mismatch.
        If the error is above the threshold provided (by default = 0), a log file named
        'error_log_<input file name>.txt' is written next to the input image.
        :param threshold_mse: error above which the log file is written
        :return: error along z. NOTE(review): it is printed as "mm" but computed from voxel indices -- confirm
        :raises ValueError: if no reference image was provided
        :raises ZeroDivisionError: if the input image contains no label
        """
        self._require_ref('MSE')
        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates()

        # check if all the labels in both the images match
        if len(coordinates_input) != len(coordinates_ref):
            sct.printv('ERROR: labels mismatch', 1, 'warning')
        values_input = set(round(coord.value) for coord in coordinates_input)
        values_ref = set(round(coord_ref.value) for coord_ref in coordinates_ref)
        for coord in coordinates_input:
            if round(coord.value) not in values_ref:
                sct.printv('ERROR: labels mismatch', 1, 'warning')
        for coord_ref in coordinates_ref:
            if round(coord_ref.value) not in values_input:
                sct.printv('ERROR: labels mismatch', 1, 'warning')

        # z of the first reference label found for each rounded value
        z_ref_by_value = {}
        for coord_ref in coordinates_ref:
            z_ref_by_value.setdefault(round(coord_ref.value), coord_ref.z)

        result = 0.0
        for coord in coordinates_input:
            value = round(coord.value)
            if value in z_ref_by_value:
                result += (z_ref_by_value[value] - coord.z) ** 2
        result = math.sqrt(result / len(coordinates_input))
        sct.printv('MSE error in Z direction = ' + str(result) + ' mm')

        if result > threshold_mse:
            fname_log = self.image_input.path + 'error_log_' + self.image_input.file_name + '.txt'
            with open(fname_log, 'w') as f:
                f.write('The labels error (MSE) between ' + self.image_input.file_name + ' and ' +
                        self.image_ref.file_name + ' is: ' + str(result))

        return result

    def create_label(self, add=False):
        """
        Create an image with the labels listed by the user.
        This method works only if the user inserted correct coordinates.

        self.coordinates is a list of coordinates (class in msct_types).
        a Coordinate contains x, y, z and value.
        If only one label is to be added, coordinates must be completed with '[]'
        examples:
        For one label:  object_define=ProcessLabels( fname_label, coordinates=[coordi]) where coordi is a 'Coordinate' object from msct_types
        For two labels: object_define=ProcessLabels( fname_label, coordinates=[coordi1, coordi2]) where coordi1 and coordi2 are 'Coordinate' objects from msct_types

        :param add: if True, the labels are added on top of the input image; otherwise the image is first zeroed
        :return: image containing the labels
        """
        image_output = self.image_input.copy()
        if not add:
            image_output.data *= 0

        # loop across labels
        for i, coord in enumerate(self.coordinates):
            # display info
            sct.printv('Label #' + str(i) + ': ' + str(coord.x) + ',' + str(coord.y) + ',' + str(coord.z) + ' --> ' +
                       str(coord.value), 1)
            image_output.data[coord.x, coord.y, coord.z] = coord.value

        return image_output

    @staticmethod
    def remove_label_coord(coord_input, coord_ref, symmetry=False):
        """
        Keep the input coordinates that have a counterpart in the reference and, with symmetry, the reference
        coordinates that have a counterpart among the kept input coordinates.
        When symmetry is requested and both lists hold CoordinateValue objects, the counterpart is looked up with
        a set intersection (relying on CoordinateValue equality/hash); otherwise labels are matched on `.value`.
        :param coord_input: list of coordinates (CoordinateValue for the set-based path)
        :param coord_ref: list of coordinates (CoordinateValue for the set-based path)
        :param symmetry: boolean, if True the reference coordinates are filtered too
        :return: (kept input coordinates, kept reference coordinates), both as lists
        """
        use_intersection = False
        # empty lists have no first element to inspect; they are handled by the value-based path
        if symmetry and coord_input and coord_ref:
            from msct_types import CoordinateValue
            use_intersection = isinstance(coord_input[0], CoordinateValue) and \
                isinstance(coord_ref[0], CoordinateValue)

        if use_intersection:
            coord_intersection = set(coord_input).intersection(set(coord_ref))
            result_coord_input = [coord for coord in coord_input if coord in coord_intersection]
            result_coord_ref = [coord for coord in coord_ref if coord in coord_intersection]
        else:
            # any() instead of a bare filter(): a filter object is always truthy on Python 3
            result_coord_ref = coord_ref
            result_coord_input = [coord for coord in coord_input
                                  if any(ref.value == coord.value for ref in coord_ref)]
            if symmetry:
                result_coord_ref = [coord for coord in coord_ref
                                    if any(kept.value == coord.value for kept in result_coord_input)]

        return result_coord_input, result_coord_ref

    def remove_label(self, symmetry=False):
        """
        Compare two label images and remove any label of the input image that is not in the reference image.
        The symmetry option also removes the labels of the reference image that are not in the input image; in
        that case self.fname_output must hold two file names: the filtered reference is saved to the second one
        and self.fname_output is then reduced to the first one.
        :param symmetry: boolean, if True the reference image is filtered and saved too
        :return: filtered input image
        :raises ValueError: if no reference image was provided
        """
        self._require_ref('remove')
        image_output = Image(self.image_input.dim, orientation=self.image_input.orientation,
                             hdr=self.image_input.hdr, verbose=self.verbose)

        result_coord_input, result_coord_ref = self.remove_label_coord(
            self.image_input.getNonZeroCoordinates(coordValue=True),
            self.image_ref.getNonZeroCoordinates(coordValue=True), symmetry)

        for coord in result_coord_input:
            image_output.data[coord.x, coord.y, coord.z] = int(round(coord.value))

        if symmetry:
            image_output_ref = Image(self.image_ref.dim, orientation=self.image_ref.orientation,
                                     hdr=self.image_ref.hdr, verbose=self.verbose)
            for coord in result_coord_ref:
                image_output_ref.data[coord.x, coord.y, coord.z] = int(round(coord.value))
            image_output_ref.setFileName(self.fname_output[1])
            image_output_ref.save('minimize_int')

            self.fname_output = self.fname_output[0]

        return image_output

    def extract_centerline(self):
        """
        Write a text file (self.fname_output) with the coordinates of the centerline, one "x y z" line per
        non-zero voxel, sorted by z.
        The image is supposed to be RPI.
        """
        coordinates_input = self.image_input.getNonZeroCoordinates(sorting='z')

        # text mode: the lines are str, which cannot be written to a binary-mode file on Python 3
        with open(self.fname_output, 'w') as fo:
            for coord in coordinates_input:
                fo.write("%i %i %i\n" % (coord.x, coord.y, coord.z))

    def display_voxel(self):
        """
        Display all the labels that are contained in the input image, then a compact 'x,y,z,value' notation of
        all of them joined by ':'.
        The image is supposed to be RPI to display voxels, but this works also for other orientations.
        :return: coordinates of the labels, sorted by z
        """
        coordinates_input = self.image_input.getNonZeroCoordinates(sorting='z')
        notations = []
        for coord in coordinates_input:
            position = str(coord.x) + ',' + str(coord.y) + ',' + str(coord.z)
            print('Position=(' + position + ') -- Value= ' + str(coord.value))
            notations.append(position + ',' + str(coord.value))
        print('Useful notation:')
        print(':'.join(notations))
        return coordinates_input

    def diff(self):
        """
        Detect any label mismatch between input image and reference image: print the value of each label
        that has no label of the same value in the other image.
        :raises ValueError: if no reference image was provided
        """
        self._require_ref('diff')
        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates()

        print("Label in input image that are not in reference image:")
        for coord in coordinates_input:
            if not any(coord.value == coord_ref.value for coord_ref in coordinates_ref):
                print(coord.value)

        print("Label in ref image that are not in input image:")
        for coord_ref in coordinates_ref:
            if not any(coord.value == coord_ref.value for coord in coordinates_input):
                print(coord_ref.value)

    def distance_interlabels(self, max_dist):
        """
        Compute the euclidean distance (in voxels) between each pair of consecutive labels of the input image,
        in the order returned by getNonZeroCoordinates, and display a warning depending on max_dist.
        NOTE(review): the warning is triggered when the distance is SMALLER than max_dist, whereas its text
        says "is larger than". One of the two is wrong; behavior is kept as is until the intent is confirmed.
        :param max_dist: distance threshold, in voxels
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()

        # compare each label with the next one
        for i in range(0, len(coordinates_input) - 1):
            current = coordinates_input[i]
            following = coordinates_input[i + 1]
            dist = math.sqrt((current.x - following.x)**2 + (current.y - following.y)**2 +
                             (current.z - following.z)**2)
            if dist < max_dist:
                print('Warning: the distance between label ' + str(i) + '[' + str(current.x) + ',' +
                      str(current.y) + ',' + str(current.z) + ']=' + str(current.value) + ' and label ' +
                      str(i+1) + '[' + str(following.x) + ',' + str(following.y) + ',' + str(following.z) +
                      ']=' + str(following.value) + ' is larger than ' + str(max_dist) + '. Distance=' + str(dist))
def main(args=None):
    """
    Command-line entry point: estimate the warping fields between a subject image and the template.

    The function parses the arguments, copies the input data, segmentation, labels and template files into a
    temporary folder, and then runs one of two pipelines depending on the value of '-ref':
      - 'template': the cord segmentation is straightened and the straightened subject is registered to the template;
      - 'subject': the template is brought to the subject space, starting from a landmark-based transformation.
    In both cases 'warp_template2anat.nii.gz', 'warp_anat2template.nii.gz', 'template2anat' and 'anat2template'
    are generated in the output folder (plus the straightening files when '-ref' is 'template'). A QC report is
    generated when '-qc' is provided.

    :param args: list of command-line arguments; when empty or None, sys.argv[1:] is used
    :return: None
    """

    # initializations
    param = Param()

    # check user arguments
    if not args:
        args = sys.argv[1:]

    # Get parser info
    parser = get_parser()
    arguments = parser.parse(args)
    fname_data = arguments['-i']
    fname_seg = arguments['-s']
    # labels are given either with '-l' (label_type 'body') or with '-ldisc' (label_type 'disc')
    if '-l' in arguments:
        fname_landmarks = arguments['-l']
        label_type = 'body'
    elif '-ldisc' in arguments:
        fname_landmarks = arguments['-ldisc']
        label_type = 'disc'
    else:
        # NOTE(review): fname_landmarks and label_type stay unbound on this path; the code below relies on
        # sct.printv(..., 'error') aborting execution -- confirm.
        sct.printv('ERROR: Labels should be provided.', 1, 'error')
    if '-ofolder' in arguments:
        path_output = arguments['-ofolder']
    else:
        path_output = ''

    # QC folder; stays None when '-qc' is not given, in which case no QC report is generated at the end
    param.path_qc = arguments.get("-qc", None)

    path_template = arguments['-t']
    contrast_template = arguments['-c']
    ref = arguments['-ref']
    remove_temp_files = int(arguments['-r'])
    verbose = int(arguments['-v'])
    param.verbose = verbose  # TODO: not clean, unify verbose or param.verbose in code, but not both
    if '-param-straighten' in arguments:
        param.param_straighten = arguments['-param-straighten']
    # if '-cpu-nb' in arguments:
    #     arg_cpu = ' -cpu-nb '+str(arguments['-cpu-nb'])
    # else:
    #     arg_cpu = ''
    # registration parameters
    if '-param' in arguments:
        # reset parameters but keep step=0 (might be overwritten if user specified step=0)
        paramreg = ParamregMultiStep([step0])
        if ref == 'subject':
            paramreg.steps['0'].dof = 'Tx_Ty_Tz_Rx_Ry_Rz_Sz'
        # add user parameters
        for paramStep in arguments['-param']:
            paramreg.addStep(paramStep)
    else:
        paramreg = ParamregMultiStep([step0, step1, step2])
        # if ref=subject, initialize registration using different affine parameters
        if ref == 'subject':
            paramreg.steps['0'].dof = 'Tx_Ty_Tz_Rx_Ry_Rz_Sz'

    # initialize other parameters
    # file_template_label = param.file_template_label
    # zsubsample is concatenated as a string into the sct_resample '-f' factor further down
    zsubsample = param.zsubsample
    # smoothing_sigma = param.smoothing_sigma

    # retrieve template file names
    file_template_vertebral_labeling = get_file_label(
        os.path.join(path_template, 'template'), 'vertebral labeling')
    file_template = get_file_label(
        os.path.join(path_template, 'template'),
        contrast_template.upper() + '-weighted template')
    file_template_seg = get_file_label(os.path.join(path_template, 'template'),
                                       'spinal cord')

    # start timer
    start_time = time.time()

    # get fname of the template + template objects
    fname_template = os.path.join(path_template, 'template', file_template)
    fname_template_vertebral_labeling = os.path.join(
        path_template, 'template', file_template_vertebral_labeling)
    fname_template_seg = os.path.join(path_template, 'template',
                                      file_template_seg)
    fname_template_disc_labeling = os.path.join(path_template, 'template',
                                                'PAM50_label_disc.nii.gz')

    # check file existence
    # TODO: no need to do that!
    sct.printv('\nCheck template files...')
    sct.check_file_exist(fname_template, verbose)
    sct.check_file_exist(fname_template_vertebral_labeling, verbose)
    sct.check_file_exist(fname_template_seg, verbose)
    # only ext_data is used later on (to build the names of the output images)
    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    # sct.printv(arguments)
    sct.printv('\nCheck parameters:', verbose)
    sct.printv('  Data:                 ' + fname_data, verbose)
    sct.printv('  Landmarks:            ' + fname_landmarks, verbose)
    sct.printv('  Segmentation:         ' + fname_seg, verbose)
    sct.printv('  Path template:        ' + path_template, verbose)
    sct.printv('  Remove temp files:    ' + str(remove_temp_files), verbose)

    # check if data, segmentation and landmarks are in the same space
    # JULIEN 2017-04-25: removed because of issue #1168
    # sct.printv('\nCheck if data, segmentation and landmarks are in the same space...')
    # if not sct.check_if_same_space(fname_data, fname_seg):
    #     sct.printv('ERROR: Data image and segmentation are not in the same space. Please check space and orientation of your files', verbose, 'error')
    # if not sct.check_if_same_space(fname_data, fname_landmarks):
    #     sct.printv('ERROR: Data image and landmarks are not in the same space. Please check space and orientation of your files', verbose, 'error')

    # check input labels
    labels = check_labels(fname_landmarks, label_type=label_type)

    # vertebral alignment is only enabled when more than two disc labels are provided
    vertebral_alignment = False
    if len(labels) > 2 and label_type == 'disc':
        vertebral_alignment = True

    path_tmp = sct.tmp_create(basename="register_to_template", verbose=verbose)

    # set temporary file names
    ftmp_data = 'data.nii'
    ftmp_seg = 'seg.nii.gz'
    ftmp_label = 'label.nii.gz'
    ftmp_template = 'template.nii'
    ftmp_template_seg = 'template_seg.nii.gz'
    ftmp_template_label = 'template_label.nii.gz'
    # ftmp_template_label_disc = 'template_label_disc.nii.gz'

    # copy files to temporary folder
    sct.printv('\nCopying input data to tmp folder and convert to nii...',
               verbose)
    sct.run([
        'sct_convert', '-i', fname_data, '-o',
        os.path.join(path_tmp, ftmp_data)
    ])
    sct.run([
        'sct_convert', '-i', fname_seg, '-o',
        os.path.join(path_tmp, ftmp_seg)
    ])
    sct.run([
        'sct_convert', '-i', fname_landmarks, '-o',
        os.path.join(path_tmp, ftmp_label)
    ])
    sct.run([
        'sct_convert', '-i', fname_template, '-o',
        os.path.join(path_tmp, ftmp_template)
    ])
    sct.run([
        'sct_convert', '-i', fname_template_seg, '-o',
        os.path.join(path_tmp, ftmp_template_seg)
    ])
    # the template label file is first written from the vertebral labeling...
    sct_convert.main(args=[
        '-i', fname_template_vertebral_labeling, '-o',
        os.path.join(path_tmp, ftmp_template_label)
    ])
    # ...and written again to the same file name from the disc labeling when disc labels were provided
    if label_type == 'disc':
        sct_convert.main(args=[
            '-i', fname_template_disc_labeling, '-o',
            os.path.join(path_tmp, ftmp_template_label)
        ])
    # sct.run('sct_convert -i '+fname_template_label+' -o '+os.path.join(path_tmp, ftmp_template_label))

    # go to tmp folder (from here on, the ftmp_* names are used as paths relative to it)
    curdir = os.getcwd()
    os.chdir(path_tmp)

    # Generate labels from template vertebral labeling
    if label_type == 'body':
        sct.printv('\nGenerate labels from template vertebral labeling',
                   verbose)
        sct_label_utils.main(args=[
            '-i', ftmp_template_label, '-vert-body', '0', '-o',
            ftmp_template_label
        ])

    # check if provided labels are available in the template
    sct.printv('\nCheck if provided labels are available in the template',
               verbose)
    image_label_template = Image(ftmp_template_label)
    labels_template = image_label_template.getNonZeroCoordinates(
        sorting='value')
    # only the last element of each list is compared
    # (assumes check_labels returns the labels sorted by increasing value -- TODO confirm)
    if labels[-1].value > labels_template[-1].value:
        sct.printv(
            'ERROR: Wrong landmarks input. Labels must have correspondence in template space. \nLabel max '
            'provided: ' + str(labels[-1].value) +
            '\nLabel max from template: ' + str(labels_template[-1].value),
            verbose, 'error')

    # if only one label is present, force affine transformation to be Tx,Ty,Tz only (no scaling)
    if len(labels) == 1:
        paramreg.steps['0'].dof = 'Tx_Ty_Tz'
        sct.printv(
            'WARNING: Only one label is present. Forcing initial transformation to: '
            + paramreg.steps['0'].dof, 1, 'warning')

    # Project labels onto the spinal cord centerline because later, an affine transformation is estimated between the
    # template's labels (centered in the cord) and the subject's labels (assumed to be centered in the cord).
    # If labels are not centered, mis-registration errors are observed (see issue #1826)
    ftmp_label = project_labels_on_spinalcord(ftmp_label, ftmp_seg)

    # binarize segmentation (in case it has values below 0 caused by manual editing)
    # (the hard-coded 'seg.nii.gz' is the value ftmp_seg holds at this point; the file is overwritten in place)
    sct.printv('\nBinarize segmentation', verbose)
    sct.run(
        ['sct_maths', '-i', 'seg.nii.gz', '-bin', '0.5', '-o', 'seg.nii.gz'])

    # smooth segmentation (jcohenadad, issue #613)
    # sct.printv('\nSmooth segmentation...', verbose)
    # sct.run('sct_maths -i '+ftmp_seg+' -smooth 1.5 -o '+add_suffix(ftmp_seg, '_smooth'))
    # jcohenadad: updated 2016-06-16: DO NOT smooth the seg anymore. Issue #
    # sct.run('sct_maths -i '+ftmp_seg+' -smooth 0 -o '+add_suffix(ftmp_seg, '_smooth'))
    # ftmp_seg = add_suffix(ftmp_seg, '_smooth')

    # Switch between modes: subject->template or template->subject
    if ref == 'template':

        # resample data to 1mm isotropic
        sct.printv('\nResample data to 1mm isotropic...', verbose)
        sct.run([
            'sct_resample', '-i', ftmp_data, '-mm', '1.0x1.0x1.0', '-x',
            'linear', '-o',
            add_suffix(ftmp_data, '_1mm')
        ])
        ftmp_data = add_suffix(ftmp_data, '_1mm')
        sct.run([
            'sct_resample', '-i', ftmp_seg, '-mm', '1.0x1.0x1.0', '-x',
            'linear', '-o',
            add_suffix(ftmp_seg, '_1mm')
        ])
        ftmp_seg = add_suffix(ftmp_seg, '_1mm')
        # N.B. resampling of labels is more complicated, because they are single-point labels, therefore resampling
        # with nearest neighbour can make them disappear.
        resample_labels(ftmp_label, ftmp_data, add_suffix(ftmp_label, '_1mm'))
        ftmp_label = add_suffix(ftmp_label, '_1mm')

        # Change orientation of input images to RPI
        sct.printv('\nChange orientation of input images to RPI...', verbose)
        sct.run([
            'sct_image', '-i', ftmp_data, '-setorient', 'RPI', '-o',
            add_suffix(ftmp_data, '_rpi')
        ])
        ftmp_data = add_suffix(ftmp_data, '_rpi')
        sct.run([
            'sct_image', '-i', ftmp_seg, '-setorient', 'RPI', '-o',
            add_suffix(ftmp_seg, '_rpi')
        ])
        ftmp_seg = add_suffix(ftmp_seg, '_rpi')
        sct.run([
            'sct_image', '-i', ftmp_label, '-setorient', 'RPI', '-o',
            add_suffix(ftmp_label, '_rpi')
        ])
        ftmp_label = add_suffix(ftmp_label, '_rpi')

        if vertebral_alignment:
            # cropping the segmentation based on the label coverage to ensure good registration with vertebral alignment
            # See https://github.com/neuropoly/spinalcordtoolbox/pull/1669 for details
            image_labels = Image(ftmp_label)
            coordinates_labels = image_labels.getNonZeroCoordinates(
                sorting='z')
            nx, ny, nz, nt, px, py, pz, pt = image_labels.dim
            # NOTE(review): if pz is the voxel size in mm, converting 10 mm to a number of slices would be
            # 10.0 / pz rather than 10.0 * pz; both agree when pz == 1 (labels were resampled to 1 mm above) -- confirm.
            offset_crop = 10.0 * pz  # cropping the image 10 mm above and below the highest and lowest label
            cropping_slices = [
                coordinates_labels[0].z - offset_crop,
                coordinates_labels[-1].z + offset_crop
            ]
            # make sure that the cropping slices do not extend outside of the slice range (issue #1811)
            if cropping_slices[0] < 0:
                cropping_slices[0] = 0
            if cropping_slices[1] > nz:
                cropping_slices[1] = nz
            status_crop, output_crop = sct.run([
                'sct_crop_image', '-i', ftmp_seg, '-o',
                add_suffix(ftmp_seg, '_crop'), '-dim', '2', '-start',
                str(cropping_slices[0]), '-end',
                str(cropping_slices[1])
            ], verbose)
        else:
            # if we do not align the vertebral levels, we crop the segmentation from top to bottom
            status_crop, output_crop = sct.run([
                'sct_crop_image', '-i', ftmp_seg, '-o',
                add_suffix(ftmp_seg, '_crop'), '-dim', '2', '-bzmax'
            ], verbose)
            # take the text following 'Dimension 2: ' on that line of the command output and split it on spaces
            # (the resulting items are strings)
            cropping_slices = output_crop.split('Dimension 2: ')[1].split(
                '\n')[0].split(' ')

        # output: segmentation_rpi_crop.nii.gz
        ftmp_seg = add_suffix(ftmp_seg, '_crop')

        # straighten segmentation
        sct.printv(
            '\nStraighten the spinal cord using centerline/segmentation...',
            verbose)

        # check if warp_curve2straight and warp_straight2curve already exist (i.e. no need to do it another time)
        # (they are looked up in the directory the script was started from, not in the tmp folder)
        fn_warp_curve2straight = os.path.join(curdir,
                                              "warp_curve2straight.nii.gz")
        fn_warp_straight2curve = os.path.join(curdir,
                                              "warp_straight2curve.nii.gz")
        fn_straight_ref = os.path.join(curdir, "straight_ref.nii.gz")

        # the cache signature is computed from the files that the straightening takes as input
        cache_input_files = [ftmp_seg]
        if vertebral_alignment:
            cache_input_files += [
                ftmp_template_seg,
                ftmp_label,
                ftmp_template_label,
            ]
        cache_sig = sct.cache_signature(input_files=cache_input_files, )
        cachefile = os.path.join(curdir, "straightening.cache")
        if sct.cache_valid(
                cachefile, cache_sig
        ) and os.path.isfile(fn_warp_curve2straight) and os.path.isfile(
                fn_warp_straight2curve) and os.path.isfile(fn_straight_ref):
            sct.printv(
                'Reusing existing warping field which seems to be valid',
                verbose, 'warning')
            sct.copy(fn_warp_curve2straight, 'warp_curve2straight.nii.gz')
            sct.copy(fn_warp_straight2curve, 'warp_straight2curve.nii.gz')
            sct.copy(fn_straight_ref, 'straight_ref.nii.gz')
            # apply straightening
            sct.run([
                'sct_apply_transfo', '-i', ftmp_seg, '-w',
                'warp_curve2straight.nii.gz', '-d', 'straight_ref.nii.gz',
                '-o',
                add_suffix(ftmp_seg, '_straight')
            ])
        else:
            # no valid cache: run the straightening and save the cache signature afterwards
            from sct_straighten_spinalcord import SpinalCordStraightener
            sc_straight = SpinalCordStraightener(ftmp_seg, ftmp_seg)
            sc_straight.output_filename = add_suffix(ftmp_seg, '_straight')
            sc_straight.path_output = './'
            sc_straight.qc = '0'
            sc_straight.remove_temp_files = remove_temp_files
            sc_straight.verbose = verbose

            if vertebral_alignment:
                sc_straight.centerline_reference_filename = ftmp_template_seg
                sc_straight.use_straight_reference = True
                sc_straight.discs_input_filename = ftmp_label
                sc_straight.discs_ref_filename = ftmp_template_label

            sc_straight.straighten()
            sct.cache_save(cachefile, cache_sig)

        # N.B. DO NOT UPDATE VARIABLE ftmp_seg BECAUSE TEMPORARY USED LATER
        # re-define warping field using non-cropped space (to avoid issue #367)
        sct.run([
            'sct_concat_transfo', '-w', 'warp_straight2curve.nii.gz', '-d',
            ftmp_data, '-o', 'warp_straight2curve.nii.gz'
        ])

        if vertebral_alignment:
            # with vertebral alignment no landmark-based affine is estimated: the straightening warp is used as is
            sct.copy('warp_curve2straight.nii.gz',
                     'warp_curve2straightAffine.nii.gz')
        else:
            # Label preparation:
            # --------------------------------------------------------------------------------
            # Remove unused label on template. Keep only label present in the input label image
            sct.printv(
                '\nRemove unused label on template. Keep only label present in the input label image...',
                verbose)
            sct.run([
                'sct_label_utils', '-i', ftmp_template_label, '-o',
                ftmp_template_label, '-remove', ftmp_label
            ])

            # Dilating the input label so they can be straighten without losing them
            sct.printv('\nDilating input labels using 3vox ball radius')
            sct.run([
                'sct_maths', '-i', ftmp_label, '-o',
                add_suffix(ftmp_label, '_dilate'), '-dilate', '3'
            ])
            ftmp_label = add_suffix(ftmp_label, '_dilate')

            # Apply straightening to labels
            sct.printv('\nApply straightening to labels...', verbose)
            sct.run([
                'sct_apply_transfo', '-i', ftmp_label, '-o',
                add_suffix(ftmp_label, '_straight'), '-d',
                add_suffix(ftmp_seg, '_straight'), '-w',
                'warp_curve2straight.nii.gz', '-x', 'nn'
            ])
            ftmp_label = add_suffix(ftmp_label, '_straight')

            # Compute rigid transformation straight landmarks --> template landmarks
            sct.printv('\nEstimate transformation for step #0...', verbose)
            from msct_register_landmarks import register_landmarks
            try:
                register_landmarks(ftmp_label,
                                   ftmp_template_label,
                                   paramreg.steps['0'].dof,
                                   fname_affine='straight2templateAffine.txt',
                                   verbose=verbose)
            except Exception:
                # any failure of the landmark registration is reported as a label-placement problem
                sct.printv(
                    'ERROR: input labels do not seem to be at the right place. Please check the position of the labels. See documentation for more details: https://sourceforge.net/p/spinalcordtoolbox/wiki/create_labels/',
                    verbose=verbose,
                    type='error')

            # Concatenate transformations: curve --> straight --> affine
            sct.printv(
                '\nConcatenate transformations: curve --> straight --> affine...',
                verbose)
            sct.run([
                'sct_concat_transfo', '-w',
                'warp_curve2straight.nii.gz,straight2templateAffine.txt', '-d',
                'template.nii', '-o', 'warp_curve2straightAffine.nii.gz'
            ])

        # Apply transformation
        sct.printv('\nApply transformation...', verbose)
        sct.run([
            'sct_apply_transfo', '-i', ftmp_data, '-o',
            add_suffix(ftmp_data, '_straightAffine'), '-d', ftmp_template,
            '-w', 'warp_curve2straightAffine.nii.gz'
        ])
        ftmp_data = add_suffix(ftmp_data, '_straightAffine')
        sct.run([
            'sct_apply_transfo', '-i', ftmp_seg, '-o',
            add_suffix(ftmp_seg, '_straightAffine'), '-d', ftmp_template, '-w',
            'warp_curve2straightAffine.nii.gz', '-x', 'linear'
        ])
        ftmp_seg = add_suffix(ftmp_seg, '_straightAffine')
        """
        # Benjamin: Issue from Allan Martin, about the z=0 slice that is screwed up, caused by the affine transform.
        # Solution found: remove slices below and above landmarks to avoid rotation effects
        points_straight = []
        for coord in landmark_template:
            points_straight.append(coord.z)
        min_point, max_point = int(round(np.min(points_straight))), int(round(np.max(points_straight)))
        sct.run('sct_crop_image -i ' + ftmp_seg + ' -start ' + str(min_point) + ' -end ' + str(max_point) + ' -dim 2 -b 0 -o ' + add_suffix(ftmp_seg, '_black'))
        ftmp_seg = add_suffix(ftmp_seg, '_black')
        """

        # binarize
        sct.printv('\nBinarize segmentation...', verbose)
        sct.run([
            'sct_maths', '-i', ftmp_seg, '-bin', '0.5', '-o',
            add_suffix(ftmp_seg, '_bin')
        ])
        ftmp_seg = add_suffix(ftmp_seg, '_bin')

        # find min-max of anat2template (for subsequent cropping)
        zmin_template, zmax_template = find_zmin_zmax(ftmp_seg)

        # crop template in z-direction (for faster processing)
        # (the same z-range is applied to the template, the template segmentation, the data and the segmentation)
        sct.printv('\nCrop data in template space (for faster processing)...',
                   verbose)
        sct.run([
            'sct_crop_image', '-i', ftmp_template, '-o',
            add_suffix(ftmp_template, '_crop'), '-dim', '2', '-start',
            str(zmin_template), '-end',
            str(zmax_template)
        ])
        ftmp_template = add_suffix(ftmp_template, '_crop')
        sct.run([
            'sct_crop_image', '-i', ftmp_template_seg, '-o',
            add_suffix(ftmp_template_seg, '_crop'), '-dim', '2', '-start',
            str(zmin_template), '-end',
            str(zmax_template)
        ])
        ftmp_template_seg = add_suffix(ftmp_template_seg, '_crop')
        sct.run([
            'sct_crop_image', '-i', ftmp_data, '-o',
            add_suffix(ftmp_data, '_crop'), '-dim', '2', '-start',
            str(zmin_template), '-end',
            str(zmax_template)
        ])
        ftmp_data = add_suffix(ftmp_data, '_crop')
        sct.run([
            'sct_crop_image', '-i', ftmp_seg, '-o',
            add_suffix(ftmp_seg, '_crop'), '-dim', '2', '-start',
            str(zmin_template), '-end',
            str(zmax_template)
        ])
        ftmp_seg = add_suffix(ftmp_seg, '_crop')

        # sub-sample in z-direction
        sct.printv('\nSub-sample in z-direction (for faster processing)...',
                   verbose)
        sct.run([
            'sct_resample', '-i', ftmp_template, '-o',
            add_suffix(ftmp_template, '_sub'), '-f', '1x1x' + zsubsample
        ], verbose)
        ftmp_template = add_suffix(ftmp_template, '_sub')
        sct.run([
            'sct_resample', '-i', ftmp_template_seg, '-o',
            add_suffix(ftmp_template_seg, '_sub'), '-f', '1x1x' + zsubsample
        ], verbose)
        ftmp_template_seg = add_suffix(ftmp_template_seg, '_sub')
        sct.run([
            'sct_resample', '-i', ftmp_data, '-o',
            add_suffix(ftmp_data, '_sub'), '-f', '1x1x' + zsubsample
        ], verbose)
        ftmp_data = add_suffix(ftmp_data, '_sub')
        sct.run([
            'sct_resample', '-i', ftmp_seg, '-o',
            add_suffix(ftmp_seg, '_sub'), '-f', '1x1x' + zsubsample
        ], verbose)
        ftmp_seg = add_suffix(ftmp_seg, '_sub')

        # Registration straight spinal cord to template
        sct.printv('\nRegister straight spinal cord to template...', verbose)

        # loop across registration steps
        # (starts at 1: step '0' is not passed to register(), it only provides the dof of the landmark-based step)
        warp_forward = []
        warp_inverse = []
        for i_step in range(1, len(paramreg.steps)):
            sct.printv(
                '\nEstimate transformation for step #' + str(i_step) + '...',
                verbose)
            # identify which is the src and dest
            if paramreg.steps[str(i_step)].type == 'im':
                src = ftmp_data
                dest = ftmp_template
                interp_step = 'linear'
            elif paramreg.steps[str(i_step)].type == 'seg':
                src = ftmp_seg
                dest = ftmp_template_seg
                interp_step = 'nn'
            else:
                sct.printv('ERROR: Wrong image type.', 1, 'error')
            # if step>1, apply warp_forward_concat to the src image to be used
            if i_step > 1:
                # sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+sct.add_suffix(src, '_reg')+' -x '+interp_step, verbose)
                # apply transformation from previous step, to use as new src for registration
                sct.run([
                    'sct_apply_transfo', '-i', src, '-d', dest, '-w',
                    ','.join(warp_forward), '-o',
                    add_suffix(src,
                               '_regStep' + str(i_step - 1)), '-x', interp_step
                ], verbose)
                src = add_suffix(src, '_regStep' + str(i_step - 1))
            # register src --> dest
            # TODO: display param for debugging
            warp_forward_out, warp_inverse_out = register(
                src, dest, paramreg, param, str(i_step))
            warp_forward.append(warp_forward_out)
            warp_inverse.append(warp_inverse_out)

        # Concatenate transformations:
        sct.printv('\nConcatenate transformations: anat --> template...',
                   verbose)
        sct.run([
            'sct_concat_transfo', '-w',
            'warp_curve2straightAffine.nii.gz,' + ','.join(warp_forward), '-d',
            'template.nii', '-o', 'warp_anat2template.nii.gz'
        ], verbose)
        # sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
        sct.printv('\nConcatenate transformations: template --> anat...',
                   verbose)
        # the inverse warps are concatenated in the opposite order of the forward ones
        warp_inverse.reverse()

        if vertebral_alignment:
            sct.run([
                'sct_concat_transfo', '-w',
                ','.join(warp_inverse) + ',warp_straight2curve.nii.gz', '-d',
                'data.nii', '-o', 'warp_template2anat.nii.gz'
            ], verbose)
        else:
            # presumably the leading '-' asks sct_concat_transfo to invert the affine -- TODO confirm
            sct.run([
                'sct_concat_transfo', '-w', ','.join(warp_inverse) +
                ',-straight2templateAffine.txt,warp_straight2curve.nii.gz',
                '-d', 'data.nii', '-o', 'warp_template2anat.nii.gz'
            ], verbose)

    # register template->subject
    elif ref == 'subject':

        # Change orientation of input images to RPI
        sct.printv('\nChange orientation of input images to RPI...', verbose)
        sct.run([
            'sct_image', '-i', ftmp_data, '-setorient', 'RPI', '-o',
            add_suffix(ftmp_data, '_rpi')
        ])
        ftmp_data = add_suffix(ftmp_data, '_rpi')
        sct.run([
            'sct_image', '-i', ftmp_seg, '-setorient', 'RPI', '-o',
            add_suffix(ftmp_seg, '_rpi')
        ])
        ftmp_seg = add_suffix(ftmp_seg, '_rpi')
        sct.run([
            'sct_image', '-i', ftmp_label, '-setorient', 'RPI', '-o',
            add_suffix(ftmp_label, '_rpi')
        ])
        ftmp_label = add_suffix(ftmp_label, '_rpi')

        # Remove unused label on template. Keep only label present in the input label image
        sct.printv(
            '\nRemove unused label on template. Keep only label present in the input label image...',
            verbose)
        sct.run([
            'sct_label_utils', '-i', ftmp_template_label, '-o',
            ftmp_template_label, '-remove', ftmp_label
        ])

        # Add one label because at least 3 orthogonal labels are required to estimate an affine transformation. This
        # new label is a copy of the first label returned by getCoordinatesAveragedByValue(), shifted along x and
        # given the value 99. It is added to both the subject and the template label files.
        # NOTE(review): earlier comments here said "1cm to the right" and "5mm to the left (orientation is RAS)",
        # which contradict each other and the RPI reorientation above; the code shifts by +5.0 / px voxels.
        for i_file in [ftmp_label, ftmp_template_label]:
            im_label = Image(i_file)
            coord_label = im_label.getCoordinatesAveragedByValue(
            )  # N.B. landmarks are sorted by value
            # Create new label
            from copy import deepcopy
            new_label = deepcopy(coord_label[0])
            # shift it by 5.0 / px voxels along +x (i.e. 5 mm if px is the voxel size in mm -- TODO confirm)
            nx, ny, nz, nt, px, py, pz, pt = im_label.dim
            new_label.x = round(coord_label[0].x + 5.0 / px)
            # assign value 99
            new_label.value = 99
            # Add to existing image
            im_label.data[int(new_label.x),
                          int(new_label.y),
                          int(new_label.z)] = new_label.value
            # Overwrite label file
            # im_label.setFileName('label_rpi_modif.nii.gz')
            im_label.save()

        # Bring template to subject space using landmark-based transformation
        sct.printv('\nEstimate transformation for step #0...', verbose)
        from msct_register_landmarks import register_landmarks
        # the affine is the first forward warp; its '-'-prefixed name is the first entry of the inverse list
        warp_forward = ['template2subjectAffine.txt']
        warp_inverse = ['-template2subjectAffine.txt']
        try:
            register_landmarks(ftmp_template_label,
                               ftmp_label,
                               paramreg.steps['0'].dof,
                               fname_affine=warp_forward[0],
                               verbose=verbose,
                               path_qc="./")
        except Exception:
            # any failure of the landmark registration is reported as a label-placement problem
            sct.printv(
                'ERROR: input labels do not seem to be at the right place. Please check the position of the labels. See documentation for more details: https://sourceforge.net/p/spinalcordtoolbox/wiki/create_labels/',
                verbose=verbose,
                type='error')

        # loop across registration steps
        for i_step in range(1, len(paramreg.steps)):
            sct.printv(
                '\nEstimate transformation for step #' + str(i_step) + '...',
                verbose)
            # identify which is the src and dest
            if paramreg.steps[str(i_step)].type == 'im':
                src = ftmp_template
                dest = ftmp_data
                interp_step = 'linear'
            elif paramreg.steps[str(i_step)].type == 'seg':
                src = ftmp_template_seg
                dest = ftmp_seg
                interp_step = 'nn'
            else:
                sct.printv('ERROR: Wrong image type.', 1, 'error')
            # apply transformation from previous step, to use as new src for registration
            sct.run([
                'sct_apply_transfo', '-i', src, '-d', dest, '-w',
                ','.join(warp_forward), '-o',
                add_suffix(src,
                           '_regStep' + str(i_step - 1)), '-x', interp_step
            ], verbose)
            src = add_suffix(src, '_regStep' + str(i_step - 1))
            # register src --> dest
            # TODO: display param for debugging
            warp_forward_out, warp_inverse_out = register(
                src, dest, paramreg, param, str(i_step))
            warp_forward.append(warp_forward_out)
            # prepend, so that the inverse list ends up in the opposite order of the forward list
            warp_inverse.insert(0, warp_inverse_out)

        # Concatenate transformations:
        sct.printv('\nConcatenate transformations: template --> subject...',
                   verbose)
        sct.run([
            'sct_concat_transfo', '-w', ','.join(warp_forward), '-d',
            'data.nii', '-o', 'warp_template2anat.nii.gz'
        ], verbose)
        sct.printv('\nConcatenate transformations: subject --> template...',
                   verbose)
        sct.run([
            'sct_concat_transfo', '-w', ','.join(warp_inverse), '-d',
            'template.nii', '-o', 'warp_anat2template.nii.gz'
        ], verbose)

    # Apply warping fields to anat and template
    # (uses the un-suffixed 'template.nii' and 'data.nii' copied into the tmp folder at the beginning)
    sct.run([
        'sct_apply_transfo', '-i', 'template.nii', '-o',
        'template2anat.nii.gz', '-d', 'data.nii', '-w',
        'warp_template2anat.nii.gz', '-crop', '1'
    ], verbose)
    sct.run([
        'sct_apply_transfo', '-i', 'data.nii', '-o', 'anat2template.nii.gz',
        '-d', 'template.nii', '-w', 'warp_anat2template.nii.gz', '-crop', '1'
    ], verbose)

    # return to the directory the script was started from
    os.chdir(curdir)

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    # the output images take the extension of the input data
    fname_template2anat = os.path.join(path_output, 'template2anat' + ext_data)
    fname_anat2template = os.path.join(path_output, 'anat2template' + ext_data)
    sct.generate_output_file(
        os.path.join(path_tmp, "warp_template2anat.nii.gz"),
        os.path.join(path_output, "warp_template2anat.nii.gz"), verbose)
    sct.generate_output_file(
        os.path.join(path_tmp, "warp_anat2template.nii.gz"),
        os.path.join(path_output, "warp_anat2template.nii.gz"), verbose)
    sct.generate_output_file(os.path.join(path_tmp, "template2anat.nii.gz"),
                             fname_template2anat, verbose)
    sct.generate_output_file(os.path.join(path_tmp, "anat2template.nii.gz"),
                             fname_anat2template, verbose)
    if ref == 'template':
        # copy straightening files in case subsequent SCT functions need them
        sct.generate_output_file(
            os.path.join(path_tmp, "warp_curve2straight.nii.gz"),
            os.path.join(path_output, "warp_curve2straight.nii.gz"), verbose)
        sct.generate_output_file(
            os.path.join(path_tmp, "warp_straight2curve.nii.gz"),
            os.path.join(path_output, "warp_straight2curve.nii.gz"), verbose)
        sct.generate_output_file(
            os.path.join(path_tmp, "straight_ref.nii.gz"),
            os.path.join(path_output, "straight_ref.nii.gz"), verbose)

    # Delete temporary files
    if remove_temp_files:
        sct.printv('\nDelete temporary files...', verbose)
        sct.rmtree(path_tmp, verbose=verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv(
        '\nFinished! Elapsed time: ' + str(int(round(elapsed_time))) + 's',
        verbose)

    # QC report, only when a QC folder was given with '-qc'
    if param.path_qc is not None:
        generate_qc(fname_data, fname_template2anat, fname_seg, args,
                    os.path.abspath(param.path_qc))

    sct.display_viewer_syntax([fname_data, fname_template2anat],
                              verbose=verbose)
    sct.display_viewer_syntax([fname_template, fname_anat2template],
                              verbose=verbose)
Esempio n. 14
0
class ProcessLabels(object):
    def __init__(self, fname_label, fname_output=None, fname_ref=None, cross_radius=5, dilate=False,
                 coordinates=None, verbose=1):
        """
        Load the label image (and the optional reference image) and store the processing options.

        :param fname_label: label image, forwarded to Image()
        :param fname_output: output file name, or list of output file names. A one-element list is unwrapped
                             to its single element; longer lists are stored as is.
        :param fname_ref: optional reference image, forwarded to Image()
        :param cross_radius: radius used by cross(); also passed as the width of plan() by process()
        :param dilate: if True, cross() dilates each branch of the cross to 3x3
        :param coordinates: list of coordinates used by create_label()
        :param verbose: verbosity level, forwarded to Image() and stored on the instance
        """
        self.image_input = Image(fname_label, verbose=verbose)

        # Always define image_ref so that the attribute exists even when no reference image was given.
        self.image_ref = Image(fname_ref, verbose=verbose) if fname_ref is not None else None

        # A list holding a single file name is stored as that file name.
        if isinstance(fname_output, list) and len(fname_output) == 1:
            self.fname_output = fname_output[0]
        else:
            self.fname_output = fname_output

        self.cross_radius = cross_radius
        self.dilate = dilate
        self.coordinates = coordinates
        self.verbose = verbose

    def process(self, type_process):
        """
        Run the process named by type_process and save the resulting image, if any.

        Processes that build an image store it in self.output_image. Processes that only report information
        ('MSE', 'display-voxel', 'diff', 'dist-inter') reset self.fname_output to None. 'centerline' writes its own
        text file and produces no image. An unknown name is reported through sct.printv with the 'error' type.

        The image is saved only when the process produced one and self.fname_output is not None.

        :param type_process: name of the process to run
        """
        # Processes returning an image.
        image_processes = {
            'cross': self.cross,
            'plan': lambda: self.plan(self.cross_radius, 100, 5),
            'plan_ref': self.plan_ref,
            'increment': self.increment_z_inverse,
            'disks': self.labelize_from_disks,
            'remove': self.remove_label,
            'remove-symm': lambda: self.remove_label(symmetry=True),
            'create': self.create_label,
            'add': lambda: self.create_label(add=True),
            'cubic-to-point': self.cubic_to_point,
        }
        # Processes that only report information; they have nothing to save.
        report_processes = {
            'MSE': self.MSE,
            'display-voxel': self.display_voxel,
            'diff': self.diff,
            'dist-inter': lambda: self.distance_interlabels(5),  # argument is a distance in pixels
        }

        image_produced = False
        if type_process in image_processes:
            self.output_image = image_processes[type_process]()
            image_produced = True
        elif type_process in report_processes:
            report_processes[type_process]()
            self.fname_output = None
        elif type_process == 'centerline':
            # writes a text file to self.fname_output; there is no image to save afterwards
            self.extract_centerline()
        else:
            sct.printv('Error: The chosen process is not available.',1,'error')

        # Save only when this call actually produced an image: otherwise self.output_image is either missing or
        # a leftover from a previous call.
        if image_produced and self.fname_output is not None:
            self.output_image.setFileName(self.fname_output)
            if type_process != 'plan_ref':
                # save the output image as minimized integers
                self.output_image.save('minimize_int')
            else:
                self.output_image.save()


    def cross(self):
        """
        Replace each label of the input image by a four-branch cross in the x-y plane.

        For a label of value v, the label voxel is set to 0 and four voxels are written at the same z:
        v*10+1 at (x, y+dy), v*10+2 at (x+dx, y), v*10+3 at (x, y-dy) and v*10+4 at (x-dx, y).
        If self.dilate is set, the 3x3 in-plane neighbourhood of each branch receives the branch value.

        dx and dy are self.cross_radius divided by px and py, rounded to the nearest integer so that they can be
        used as array indices (presumably cross_radius is in mm and px, py are the voxel sizes -- TODO confirm).

        No bounds checking is done: a cross reaching outside the volume either wraps around (negative index) or
        raises IndexError.

        :return: image built from the input image, with labels replaced by crosses
        """
        image_output = Image(self.image_input, self.verbose)
        nx, ny, nz, nt, px, py, pz, pt = Image(self.image_input.absolutepath).dim

        coordinates_input = self.image_input.getNonZeroCoordinates()
        # integer offsets: float values cannot be used as array indices
        dx = int(round(float(self.cross_radius) / px))
        dy = int(round(float(self.cross_radius) / py))

        data = image_output.data
        for coord in coordinates_input:
            x, y, z = int(coord.x), int(coord.y), int(coord.z)

            # remove point on the center of the spinal cord
            data[x, y, z] = 0

            # add points at distance from center of spinal cord
            branches = [
                (x, y + dy, coord.value * 10 + 1),
                (x + dx, y, coord.value * 10 + 2),
                (x, y - dy, coord.value * 10 + 3),
                (x - dx, y, coord.value * 10 + 4),
            ]
            for branch_x, branch_y, value in branches:
                data[branch_x, branch_y, z] = value

            # dilate cross to 3x3
            if self.dilate:
                for branch_x, branch_y, _ in branches:
                    # read the value currently stored at the branch: an earlier dilation may have overwritten
                    # it when the branches are close to each other
                    branch_value = data[branch_x, branch_y, z]
                    for offset_x in (-1, 0, 1):
                        for offset_y in (-1, 0, 1):
                            data[branch_x + offset_x, branch_y + offset_y, z] = branch_value

        return image_output

    def plan(self, width, offset=0, gap=1):
        """
        Create, for each label of the input image, an axial slab valued offset + gap * label value.

        The slab covers slices coord.z - width (included) to coord.z + width (excluded), i.e. 2 * width slices,
        clipped to the volume. Slabs are written in the order returned by getNonZeroCoordinates(), so where two
        slabs overlap the later one wins.

        :param width: half-thickness of the slab, in slices
        :param offset: value added to every slab
        :param gap: multiplier applied to the label value
        :return: image built from the input image, zeroed, then filled with the slabs
        """
        image_output = Image(self.image_input, self.verbose)
        image_output.data *= 0
        coordinates_input = self.image_input.getNonZeroCoordinates()

        for coord in coordinates_input:
            # Clamp the lower bound: a negative slice start would be counted from the end of the axis and the
            # slab would come out empty or misplaced. The upper bound is clipped by the slice itself.
            z_start = max(coord.z - width, 0)
            image_output.data[:, :, z_start:coord.z + width] = offset + gap * coord.value

        return image_output

    def plan_ref(self):
        """
        Generate, in the space of the reference image, one axial plane per label of the input image.

        Each plane fills the whole slice at the z of the label with the label value. Negative labels are
        handled separately from positive ones and are written first, so a positive label on the same slice
        overrides a negative one.

        :return: float32 image built from the reference image
        """
        data_input = self.image_input.data

        image_output = Image(self.image_ref, self.verbose)
        image_output.data *= 0

        # Negative labels are stored with their sign flipped, in order to apply getNonZeroCoordinates
        image_input_neg = Image(self.image_input, self.verbose).copy()
        image_input_neg.data *= 0
        mask_neg = data_input < 0
        image_input_neg.data[mask_neg] = -data_input[mask_neg]

        image_input_pos = Image(self.image_input, self.verbose).copy()
        image_input_pos.data *= 0
        mask_pos = data_input > 0
        image_input_pos.data[mask_pos] = data_input[mask_pos]

        coordinates_input_neg = image_input_neg.getNonZeroCoordinates()
        coordinates_input_pos = image_input_pos.getNonZeroCoordinates()

        image_output.changeType('float32')
        for label in coordinates_input_neg:
            # restore the original (negative) sign; PB: takes the int value of coord.value
            image_output.data[:, :, label.z] = -label.value
        for label in coordinates_input_pos:
            image_output.data[:, :, label.z] = label.value

        return image_output
    def cubic_to_point(self):
        """
        Reduce each group of voxels sharing the same label value to a single voxel at their barycenter.

        It is to be used after applying a homothetic warping field to a label file, as the labels will be dilated.
        Voxels are grouped by value only, so this does not work if several distinct groups share the same value.

        :return: copy of the input image, zeroed, with one voxel per label value located at the rounded mean
                 x, y, z of the voxels carrying that value
        """
        image_output = self.image_input.copy()
        data_output = image_output.data
        data_output *= 0
        coordinates = self.image_input.getNonZeroCoordinates(sorting='value')

        # group the voxels per label value, remembering the order in which the values are met
        list_values = []
        groups = {}
        for coord in coordinates:
            if coord.value not in groups:
                groups[coord.value] = []
                list_values.append(coord.value)
            groups[coord.value].append(coord)

        # put value of group at each center of mass
        for value in list_values:
            group = groups[value]
            # float count: with integer coordinates an integer division would truncate the mean before round()
            count = float(len(group))
            barycenter_x = int(round(sum(coord.x for coord in group) / count))
            barycenter_y = int(round(sum(coord.y for coord in group) / count))
            barycenter_z = int(round(sum(coord.z for coord in group) / count))
            data_output[barycenter_x, barycenter_y, barycenter_z] = group[0].value

        return image_output

        # Process to use if one wants to calculate the center of mass of a group of labels ordered by volume (and not by value)
        # image_output = self.image_input.copy()
        # data_output = image_output.data
        # data_output *= 0
        # nx = image_output.data.shape[0]
        # ny = image_output.data.shape[1]
        # nz = image_output.data.shape[2]
        # print '.. matrix size: '+str(nx)+' x '+str(ny)+' x '+str(nz)
        #
        # z_centerline = [iz for iz in range(0, nz, 1) if data[:,:,iz].any() ]
        # nz_nonz = len(z_centerline)
        # if nz_nonz==0 :
        #     print '\nERROR: Label file is empty'
        #     sys.exit()
        # x_centerline = [0 for iz in range(0, nz_nonz, 1)]
        # y_centerline = [0 for iz in range(0, nz_nonz, 1)]
        # print '\nGet center of mass for each slice of the label file ...'
        # for iz in xrange(len(z_centerline)):
        #     x_centerline[iz], y_centerline[iz] = ndimage.measurements.center_of_mass(array(data[:,:,z_centerline[iz]]))
        #
        # ## Calculate mean coordinate according to z for each cube of labels:
        # list_cube_labels_x = [[]]
        # list_cube_labels_y = [[]]
        # list_cube_labels_z = [[]]
        # count = 0
        # for i in range(nz_nonz-1):
        #     # Make a list of group of slices that contains a non zero value
        #     # check if the group is only one slice of height (at first slice)
        #     if i==0 and z_centerline[i] - z_centerline[i+1] != -1:
        #         list_cube_labels_z[count].append(z_centerline[i])
        #         list_cube_labels_x[count].append(x_centerline[i])
        #         list_cube_labels_y[count].append(y_centerline[i])
        #         list_cube_labels_z.append([])
        #         list_cube_labels_x.append([])
        #         list_cube_labels_y.append([])
        #         count += 1
        #     # check if the group is only one slice of height (in the middle)
        #     if i>0 and z_centerline[i-1] - z_centerline[i] != -1 and z_centerline[i] - z_centerline[i+1] != -1:
        #         list_cube_labels_z[count].append(z_centerline[i])
        #         list_cube_labels_x[count].append(x_centerline[i])
        #         list_cube_labels_y[count].append(y_centerline[i])
        #         list_cube_labels_z.append([])
        #         list_cube_labels_x.append([])
        #         list_cube_labels_y.append([])
        #         count += 1
        #     if z_centerline[i] - z_centerline[i+1] == -1:
        #         # Verify if the value has already been recovered and add if not
        #         #If the group is empty add first value do not if it is not empty as it will copy it for a second time
        #         if len(list_cube_labels_z[count]) == 0 :#or list_cube_labels[count][-1] != z_centerline[i]:
        #             list_cube_labels_z[count].append(z_centerline[i])
        #             list_cube_labels_x[count].append(x_centerline[i])
        #             list_cube_labels_y[count].append(y_centerline[i])
        #         list_cube_labels_z[count].append(z_centerline[i+1])
        #         list_cube_labels_x[count].append(x_centerline[i+1])
        #         list_cube_labels_y[count].append(y_centerline[i+1])
        #         if i+2 < nz_nonz-1 and z_centerline[i+1] - z_centerline[i+2] != -1:
        #             list_cube_labels_z.append([])
        #             list_cube_labels_x.append([])
        #             list_cube_labels_y.append([])
        #             count += 1
        #
        # z_label_mean = [0 for i in range(len(list_cube_labels_z))]
        # x_label_mean = [0 for i in range(len(list_cube_labels_z))]
        # y_label_mean = [0 for i in range(len(list_cube_labels_z))]
        # v_label_mean = [0 for i in range(len(list_cube_labels_z))]
        # for i in range(len(list_cube_labels_z)):
        #     for j in range(len(list_cube_labels_z[i])):
        #         z_label_mean[i] += list_cube_labels_z[i][j]
        #         x_label_mean[i] += list_cube_labels_x[i][j]
        #         y_label_mean[i] += list_cube_labels_y[i][j]
        #     z_label_mean[i] = int(round(z_label_mean[i]/len(list_cube_labels_z[i])))
        #     x_label_mean[i] = int(round(x_label_mean[i]/len(list_cube_labels_x[i])))
        #     y_label_mean[i] = int(round(y_label_mean[i]/len(list_cube_labels_y[i])))
        #     # We suppose that the labels' value of the group is the value of the barycentre
        #     v_label_mean[i] = data[x_label_mean[i],y_label_mean[i], z_label_mean[i]]
        #
        # ## Put labels of value one into mean coordinates
        # for i in range(len(z_label_mean)):
        #     data_output[x_label_mean[i],y_label_mean[i], z_label_mean[i]] = v_label_mean[i]
        #
        # return image_output

    def increment_z_inverse(self):
        """
        Renumber the labels of the input image 1, 2, 3, ... in the order returned by
        getNonZeroCoordinates(sorting='z', reverse_coord=True).

        Labels are therefore incremented from top to bottom, assuming a RPI orientation.
        Labels are assumed to be non-zero.

        :return: image built from the input image, zeroed, holding the renumbered labels
        """
        image_output = Image(self.image_input, self.verbose)
        image_output.data *= 0

        labels_by_z = self.image_input.getNonZeroCoordinates(sorting='z', reverse_coord=True)
        # the first label of the list gets value 1, the next one 2, and so on
        for rank, label in enumerate(labels_by_z, 1):
            image_output.data[label.x, label.y, label.z] = rank

        return image_output

    def labelize_from_disks(self):
        """
        Label every non-zero voxel of the input image with the value of the reference label it lies below.

        Typically, the input is a segmentation image and the reference holds labels at the disks position; the
        output is the segmentation with vertebral levels labelized.
        Reference labels are assumed to be non-zero and incremented from top to bottom, assuming a RPI
        orientation. Voxels that fall in no interval between two consecutive reference labels stay at 0.

        :return: image built from the input image, zeroed, holding the labelized voxels
        """
        image_output = Image(self.image_input, self.verbose)
        image_output.data *= 0
        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates(sorting='value')

        # consecutive reference labels delimit the intervals ]lower.z, upper.z]
        intervals = zip(coordinates_ref[:-1], coordinates_ref[1:])

        for voxel in coordinates_input:
            for upper, lower in intervals:
                if lower.z < voxel.z <= upper.z:
                    # no early exit: if several intervals match, the last one wins
                    image_output.data[voxel.x, voxel.y, voxel.z] = upper.value

        return image_output

    def symmetrizer(self, side='left'):
        """
        Symmetrize the input image: one side of the image is meant to be copied on the other side, assuming a
        RPI orientation.

        NOTE(review): this looks like an unfinished placeholder. The `side` parameter is never read, no voxel
        is modified, and the image built from the input is returned as is.

        :param side: string 'left' or 'right'. Side that will be copied on the other side (currently unused).
        :return: image built from the input image
        """
        image_output = Image(self.image_input, self.verbose)

        # The result of this slicing expression is discarded -- TODO confirm whether Image supports slicing and
        # whether this statement has any effect.
        image_output[0:]

        # The string below is a bare expression statement kept as a design note; it is not a docstring.
        """inspiration: (from atlas creation matlab script)
        temp_sum = temp_g + temp_d;
        temp_sum_flip = temp_sum(end:-1:1,:);
        temp_sym = (temp_sum + temp_sum_flip) / 2;

        temp_g(1:end / 2,:) = 0;
        temp_g(1 + end / 2:end,:) = temp_sym(1 + end / 2:end,:);
        temp_d(1:end / 2,:) = temp_sym(1:end / 2,:);
        temp_d(1 + end / 2:end,:) = 0;

        tractsHR
        {label_l}(:,:, num_slice_ref) = temp_g;
        tractsHR
        {label_r}(:,:, num_slice_ref) = temp_d;
        """

        return image_output

    def MSE(self, threshold_mse=0):
        """
        Compute the error along z between two sets of labels (input and ref).

        Labels are matched on their rounded value; each input label is compared with the first reference label
        of same value. The returned figure is the square root of the mean, over all input labels, of the
        squared z differences. A warning is printed for each label mismatch.
        If the error is above threshold_mse, a log file 'error_log_<input file name>.txt' is written in the
        folder of the input image.

        NOTE(review): z differences are computed on the coordinates returned by getNonZeroCoordinates(), while
        the printed message reports 'mm' -- confirm the unit.

        :param threshold_mse: error above which the log file is written (default 0)
        :return: the error, or NaN when the input image has no label
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates()

        # rounded label values, computed once instead of once per comparison
        values_input = set(round(coord.value) for coord in coordinates_input)
        values_ref = set(round(coord_ref.value) for coord_ref in coordinates_ref)

        # check if all the labels in both the images match
        if len(coordinates_input) != len(coordinates_ref):
            sct.printv('ERROR: labels mismatch', 1, 'warning')
        for coord in coordinates_input:
            if round(coord.value) not in values_ref:
                sct.printv('ERROR: labels mismatch', 1, 'warning')
        for coord_ref in coordinates_ref:
            if round(coord_ref.value) not in values_input:
                sct.printv('ERROR: labels mismatch', 1, 'warning')

        # without input label the mean is undefined: report it instead of dividing by zero
        if len(coordinates_input) == 0:
            sct.printv('ERROR: no label found in input image, MSE cannot be computed', 1, 'warning')
            return float('nan')

        # z of the first reference label found for each rounded value
        ref_z_by_value = {}
        for coord_ref in coordinates_ref:
            ref_z_by_value.setdefault(round(coord_ref.value), coord_ref.z)

        result = 0.0
        for coord in coordinates_input:
            value = round(coord.value)
            if value in ref_z_by_value:
                result += (ref_z_by_value[value] - coord.z) ** 2
        result = math.sqrt(result / len(coordinates_input))
        sct.printv('MSE error in Z direction = ' + str(result) + ' mm')

        if result > threshold_mse:
            # the with-statement closes the log file even if the write fails
            with open(self.image_input.path + 'error_log_' + self.image_input.file_name + '.txt', 'w') as f:
                f.write(
                    'The labels error (MSE) between ' + self.image_input.file_name + ' and ' + self.image_ref.file_name + ' is: ' + str(
                        result))

        return result

    def create_label(self, add=False):
        """
        Create an image holding the labels listed in self.coordinates.

        self.coordinates is a list of coordinates; each one provides x, y, z and value. Even a single label
        must be given inside a list, e.g. ProcessLabels(fname_label, coordinates=[coordi]).
        No check is made on the coordinates: they must lie inside the image.

        :param add: if False (default) the output starts from an empty image; if True the labels are added on
                    top of the labels of the input image
        :return: copy of the input image holding the requested labels
        """
        image_output = self.image_input.copy()
        if not add:
            image_output.data *= 0

        # loop across labels
        for index, label in enumerate(self.coordinates):
            # display info
            position = ','.join([str(label.x), str(label.y), str(label.z)])
            sct.printv('Label #' + str(index) + ': ' + position + ' --> ' + str(label.value), 1)
            image_output.data[label.x, label.y, label.z] = label.value

        return image_output

    @staticmethod
    def remove_label_coord(coord_input, coord_ref, symmetry=False):
        """
        Keep the input labels that also exist in the reference and, with symmetry, the reference labels that
        also exist in the input.

        When symmetry is requested and both lists hold CoordinateValue objects, the common labels are found with
        a set intersection, which relies on the hash/equality of CoordinateValue (defined in msct_types,
        presumably on the label value -- TODO confirm). Otherwise labels are matched on their `value` attribute.
        Empty lists are accepted.

        :param coord_input: list of coordinates of the input image
        :param coord_ref: list of coordinates of the reference image
        :param symmetry: boolean, if True the reference list is filtered as well
        :return: tuple (result_coord_input, result_coord_ref); the second item is coord_ref itself when symmetry
                 is False
        """
        from msct_types import CoordinateValue

        # the emptiness checks guard the [0] accesses
        use_set_intersection = (symmetry
                                and len(coord_input) > 0 and len(coord_ref) > 0
                                and isinstance(coord_input[0], CoordinateValue)
                                and isinstance(coord_ref[0], CoordinateValue))

        if use_set_intersection:
            coord_intersection = set(coord_input).intersection(coord_ref)
            result_coord_input = [coord for coord in coord_input if coord in coord_intersection]
            result_coord_ref = [coord for coord in coord_ref if coord in coord_intersection]
        else:
            result_coord_ref = coord_ref
            # any() gives a real boolean; a filter() result is a list on Python 2 only
            result_coord_input = [coord for coord in coord_input
                                  if any(ref.value == coord.value for ref in coord_ref)]
            if symmetry:
                result_coord_ref = [coord for coord in coord_ref
                                    if any(kept.value == coord.value for kept in result_coord_input)]

        return result_coord_input, result_coord_ref

    def remove_label(self, symmetry=False):
        """
        Compare two label images and remove any label of the input image that is not in the reference image.

        With symmetry, the labels of the reference image that are not in the input image are removed as well:
        the filtered reference image is saved here to self.fname_output[1], and self.fname_output is then
        replaced by its first element (self.fname_output must therefore hold two file names).

        :param symmetry: boolean, if True the reference image is filtered and saved too
        :return: image built from the input image, holding only the labels kept
        """
        image_output = Image(self.image_input, verbose=self.verbose)
        image_output.data *= 0  # put all voxels to 0

        result_coord_input, result_coord_ref = self.remove_label_coord(
            self.image_input.getNonZeroCoordinates(coordValue=True),
            self.image_ref.getNonZeroCoordinates(coordValue=True),
            symmetry)

        for coord in result_coord_input:
            image_output.data[coord.x, coord.y, coord.z] = int(round(coord.value))

        if symmetry:
            image_output_ref = Image(self.image_ref, verbose=self.verbose)
            # start from an empty image, otherwise the labels to remove would still be in the output
            image_output_ref.data *= 0
            for coord in result_coord_ref:
                image_output_ref.data[coord.x, coord.y, coord.z] = int(round(coord.value))
            image_output_ref.setFileName(self.fname_output[1])
            image_output_ref.save('minimize_int')

            self.fname_output = self.fname_output[0]

        return image_output

    def extract_centerline(self):
        """
        Write a text file with the coordinates of the non-zero voxels of the input image, sorted by z.

        The file is written to self.fname_output, one "x y z" line per voxel.
        The image is supposed to be RPI.
        """
        coordinates_input = self.image_input.getNonZeroCoordinates(sorting='z')

        # the with-statement closes the file even if a write fails
        with open(self.fname_output, "wb") as fo:
            for coord in coordinates_input:
                fo.write("%i %i %i\n" % (coord.x, coord.y, coord.z))

    def display_voxel(self):
        """
        Print every label of the input image, sorted by z, followed by a compact one-line notation.

        The compact notation is 'x,y,z,value' for each label, joined with ':'.
        The image is supposed to be RPI to display voxels, but other orientations work as well.

        :return: the list of coordinates that was displayed
        """
        coordinates_input = self.image_input.getNonZeroCoordinates(sorting='z')

        notations = []
        for label in coordinates_input:
            position = str(label.x) + ',' + str(label.y) + ',' + str(label.z)
            print('Position=(' + position + ') -- Value= ' + str(label.value))
            notations.append(position + ',' + str(label.value))

        print('Useful notation:')
        print(':'.join(notations))
        return coordinates_input

    def diff(self):
        """
        Print the label values found in one image but not in the other (input versus reference).

        Values are printed once per label carrying them, first for the input image, then for the reference.
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates()

        def print_missing(labels, others):
            # print the value of each label that has no label of same value in `others`
            for label in labels:
                if not any(label.value == other.value for other in others):
                    print(label.value)

        print("Label in input image that are not in reference image:")
        print_missing(coordinates_input, coordinates_ref)

        print("Label in ref image that are not in input image:")
        print_missing(coordinates_ref, coordinates_input)

    def distance_interlabels(self, max_dist):
        """
        Compute the distance between consecutive labels of the input image, in the order returned by
        getNonZeroCoordinates().

        If a distance is larger than max_dist, a warning message is displayed.

        :param max_dist: distance above which the warning is printed, in the unit of the label coordinates
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()

        def describe(index, label):
            # '<index>[x,y,z]=<value>'
            return (str(index) + '[' + str(label.x) + ',' + str(label.y) + ',' + str(label.z) + ']=' +
                    str(label.value))

        for i in range(0, len(coordinates_input) - 1):
            first = coordinates_input[i]
            second = coordinates_input[i + 1]
            dist = math.sqrt((first.x - second.x) ** 2 + (first.y - second.y) ** 2 + (first.z - second.z) ** 2)
            # warn when the distance exceeds the maximum, as stated by the docstring and by the message
            if dist > max_dist:
                print('Warning: the distance between label ' + describe(i, first) + ' and label ' +
                      describe(i + 1, second) + ' is larger than ' + str(max_dist) + '. Distance=' + str(dist))
def main():
    """
    Command-line entry point: straighten an anatomical volume along a centerline.

    Steps, in the order they are executed below:
      1. Read the options from sys.argv (or use hard-coded debug paths when param.debug == 1).
      2. Copy the inputs into a time-stamped temporary folder, move into it and reorient the
         centerline to RPI.
      3. Smooth the centerline (smooth_centerline) and place a 5-point "cross" of landmarks
         along it, every gapz indices of the fitted centerline.
      4. Build the matching crosses along a straight line located at the center of the FOV.
      5. Write both landmark sets as NIfTI volumes and estimate a rigid plus b-spline
         transformation in both directions with the isct_ANTS* / isct_ComposeMultiTransform tools.
      6. Apply the curve-to-straight warp to the input image and to the centerline, then measure
         how far the straightened centerline lies from the center of the volume in x-y.
      7. Move the two warping fields and the straightened image to the starting folder and,
         if requested, delete the temporary folder.

    Relies on the module-level `param` object and on `usage()`; nothing is returned.
    """

    # Initialization
    fname_anat = ''
    fname_centerline = ''
    gapxy = param.gapxy
    gapz = param.gapz
    padding = param.padding
    remove_temp_files = param.remove_temp_files
    verbose = param.verbose
    interpolation_warp = param.interpolation_warp
    algo_fitting = param.algo_fitting
    window_length = param.window_length
    type_window = param.type_window
    crop = param.crop

    # start timer
    start_time = time.time()

    # get path of the toolbox
    status, path_sct = commands.getstatusoutput('echo $SCT_DIR')
    print path_sct

    # Parameters for debug mode
    if param.debug == 1:
        print '\n*** WARNING: DEBUG MODE ON ***\n'
        fname_anat = '/Users/julien/data/temp/sct_example_data/t2/tmp.150401221259/anat_rpi.nii'  #path_sct+'/testing/sct_testing_data/data/t2/t2.nii.gz'
        fname_centerline = '/Users/julien/data/temp/sct_example_data/t2/tmp.150401221259/centerline_rpi.nii'  # path_sct+'/testing/sct_testing_data/data/t2/t2_seg.nii.gz'
        remove_temp_files = 0
        type_window = 'hanning'
        verbose = 2
    else:
        # Check input param
        # options: -h help, -i input volume, -c centerline, -p padding, -r remove temp files,
        #          -v verbose, -x final interpolation, -a fitting algorithm, -f crop
        try:
            opts, args = getopt.getopt(sys.argv[1:],'hi:c:p:r:v:x:a:f:')
        except getopt.GetoptError as err:
            print str(err)
            usage()
        # NOTE(review): `opts` is unbound here if getopt raised and usage() returned;
        # presumably usage() terminates the program -- confirm.
        if not opts:
            usage()
        # NOTE(review): ('-i') is a plain string, not a 1-tuple, so each `in` below is a
        # substring test rather than a membership test.
        for opt, arg in opts:
            if opt == '-h':
                usage()
            elif opt in ('-i'):
                fname_anat = arg
            elif opt in ('-c'):
                fname_centerline = arg
            elif opt in ('-r'):
                remove_temp_files = int(arg)
            elif opt in ('-p'):
                padding = int(arg)
            elif opt in ('-x'):
                interpolation_warp = str(arg)
            elif opt in ('-a'):
                algo_fitting = str(arg)
            elif opt in ('-f'):
                crop = int(arg)
            # elif opt in ('-f'):
            #     centerline_fitting = str(arg)
            elif opt in ('-v'):
                verbose = int(arg)

    # display usage if a mandatory argument is not provided
    if fname_anat == '' or fname_centerline == '':
        usage()

    # check if algorithm for fitting is correct
    if algo_fitting not in ['hanning','nurbs']:
        sct.printv('ERROR: wrong fitting algorithm',1,'warning')
        usage()

    # update field
    param.verbose = verbose

    # check existence of input files
    sct.check_file_exist(fname_anat)
    sct.check_file_exist(fname_centerline)

    # Display arguments
    print '\nCheck input arguments...'
    print '  Input volume ...................... '+fname_anat
    print '  Centerline ........................ '+fname_centerline
    print '  Final interpolation ............... '+interpolation_warp
    print '  Verbose ........................... '+str(verbose)
    print ''


    # Extract path/file/extension
    path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)
    path_centerline, file_centerline, ext_centerline = sct.extract_fname(fname_centerline)

    # create temporary folder (named after the current date and time)
    path_tmp = 'tmp.'+time.strftime("%y%m%d%H%M%S")
    sct.run('mkdir '+path_tmp, verbose)

    # copy files into tmp folder
    sct.run('cp '+fname_anat+' '+path_tmp)
    sct.run('cp '+fname_centerline+' '+path_tmp)

    # go to tmp folder
    os.chdir(path_tmp)

    # Change orientation of the input centerline into RPI
    sct.printv('\nOrient centerline to RPI orientation...', verbose)
    fname_centerline_orient = file_centerline+'_rpi.nii.gz'
    # NOTE(review): fname_centerline is the path exactly as given by the user, while the working
    # directory is now the tmp folder; a relative path would presumably no longer resolve -- confirm.
    set_orientation(fname_centerline, 'RPI', fname_centerline_orient)

    # Get dimension
    sct.printv('\nGet dimensions...', verbose)
    nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(fname_centerline_orient)
    sct.printv('.. matrix size: '+str(nx)+' x '+str(ny)+' x '+str(nz), verbose)
    sct.printv('.. voxel size:  '+str(px)+'mm x '+str(py)+'mm x '+str(pz)+'mm', verbose)

    # smooth centerline
    x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(fname_centerline_orient, algo_fitting=algo_fitting, type_window=type_window, window_length=window_length,verbose=verbose)

    # Get coordinates of landmarks along curved centerline
    #==========================================================================================
    sct.printv('\nGet coordinates of landmarks along curved centerline...', verbose)
    # landmarks are created along the curved centerline every z=gapz. They consist of a "cross" of size gapx and gapy. In voxel space!!!

    # find z indices along centerline given a specific gap: iz_curved
    nz_nonz = len(z_centerline)
    nb_landmark = int(round(float(nz_nonz)/gapz))

    # always keep at least one landmark
    if nb_landmark == 0:
        nb_landmark = 1

    if nb_landmark == 1:
        iz_curved = [0]
    else:
        iz_curved = [i*gapz for i in range(0, nb_landmark-1)]

    # the last index of the centerline is always used as the final landmark position
    iz_curved.append(nz_nonz-1)
    #print iz_curved, len(iz_curved)
    n_iz_curved = len(iz_curved)
    #print n_iz_curved

    # landmark_curved initialisation: one cross per entry of iz_curved, 5 points per cross, [x, y, z] per point
    landmark_curved = [ [ [ 0 for i in range(0, 3)] for i in range(0, 5) ] for i in iz_curved ]

    ### TODO: THIS PART IS SLOW AND CAN BE MADE FASTER
    ### >>==============================================================================================================
    for index in range(0, n_iz_curved, 1):
        # calculate d (ax+by+cz+d=0): plane through the centerline point whose normal (a, b, c)
        # is the centerline derivative at that point
        # print iz_curved[index]
        a=x_centerline_deriv[iz_curved[index]]
        b=y_centerline_deriv[iz_curved[index]]
        c=z_centerline_deriv[iz_curved[index]]
        x=x_centerline_fit[iz_curved[index]]
        y=y_centerline_fit[iz_curved[index]]
        z=z_centerline[iz_curved[index]]
        d=-(a*x+b*y+c*z)
        #print a,b,c,d,x,y,z
        # set coordinates for landmark at the center of the cross
        landmark_curved[index][0][0], landmark_curved[index][0][1], landmark_curved[index][0][2] = x_centerline_fit[iz_curved[index]], y_centerline_fit[iz_curved[index]], z_centerline[iz_curved[index]]

        # set y coordinate to y_centerline_fit[iz] for elements 1 and 2 of the cross
        for i in range(1, 3):
            landmark_curved[index][i][1] = y_centerline_fit[iz_curved[index]]

        # set x and z coordinates for landmarks +x and -x, forcing the landmark to be in the orthogonal plane and the distance landmark/curve to be gapxy
        # NOTE(review): the expressions below divide by c (z-derivative of the centerline) and unpack
        # exactly two roots from solve(); which root is +x / -x depends on the order solve() returns
        # them -- confirm.
        x_n = Symbol('x_n')
        landmark_curved[index][2][0], landmark_curved[index][1][0]=solve((x_n-x)**2+((-1/c)*(a*x_n+b*y+d)-z)**2-gapxy**2,x_n)  #x for -x and +x
        landmark_curved[index][1][2] = (-1/c)*(a*landmark_curved[index][1][0]+b*y+d)  # z for +x
        landmark_curved[index][2][2] = (-1/c)*(a*landmark_curved[index][2][0]+b*y+d)  # z for -x

        # set x coordinate to x_centerline_fit[iz] for elements 3 and 4 of the cross
        for i in range(3, 5):
            landmark_curved[index][i][0] = x_centerline_fit[iz_curved[index]]

        # set coordinates for landmarks +y and -y. Their x coordinate was set to x_centerline_fit[iz] just above.
        y_n = Symbol('y_n')
        landmark_curved[index][4][1],landmark_curved[index][3][1] = solve((y_n-y)**2+((-1/c)*(a*x+b*y_n+d)-z)**2-gapxy**2,y_n)  #y for -y and +y
        landmark_curved[index][3][2] = (-1/c)*(a*x+b*landmark_curved[index][3][1]+d)  # z for +y
        landmark_curved[index][4][2] = (-1/c)*(a*x+b*landmark_curved[index][4][1]+d)  # z for -y
    ### <<==============================================================================================================

    # in verbose mode 2, show the fitted centerline together with the curved landmarks in a 3D plot
    if verbose == 2:
        from mpl_toolkits.mplot3d import Axes3D
        import matplotlib.pyplot as plt
        fig = plt.figure()
        ax = Axes3D(fig)
        ax.plot(x_centerline_fit, y_centerline_fit,z_centerline,zdir='z')
        ax.plot([landmark_curved[i][j][0] for i in range(0, n_iz_curved) for j in range(0, 5)], \
              [landmark_curved[i][j][1] for i in range(0, n_iz_curved) for j in range(0, 5)], \
              [landmark_curved[i][j][2] for i in range(0, n_iz_curved) for j in range(0, 5)], '.')
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel('z')
        plt.show()

    # Get coordinates of landmarks along straight centerline
    #==========================================================================================
    sct.printv('\nGet coordinates of landmarks along straight centerline...', verbose)
    landmark_straight = [ [ [ 0 for i in range(0,3)] for i in range (0,5) ] for i in iz_curved ] # same structure as landmark_curved

    # calculate the z indices corresponding to the Euclidean distance between two consecutive points on the curved centerline (approximation curve --> line)
    # TODO: DO NOT APPROXIMATE CURVE --> LINE
    # iz_straight gets the same length as iz_curved (2 entries when nb_landmark == 1, nb_landmark otherwise)
    if nb_landmark == 1:
        iz_straight = [0 for i in range(0, nb_landmark+1)]
    else:
        iz_straight = [0 for i in range(0, nb_landmark)]

    # print iz_straight,len(iz_straight)
    iz_straight[0] = iz_curved[0]
    for index in range(1, n_iz_curved, 1):
        # compute vector between two consecutive points on the curved centerline
        vector_centerline = [x_centerline_fit[iz_curved[index]] - x_centerline_fit[iz_curved[index-1]], \
                             y_centerline_fit[iz_curved[index]] - y_centerline_fit[iz_curved[index-1]], \
                             z_centerline[iz_curved[index]] - z_centerline[iz_curved[index-1]] ]
        # compute norm of this vector
        norm_vector_centerline = linalg.norm(vector_centerline, ord=2)
        # round to closest integer value
        norm_vector_centerline_rounded = int(round(norm_vector_centerline, 0))
        # assign this value to the current z-coordinate on the straight centerline
        iz_straight[index] = iz_straight[index-1] + norm_vector_centerline_rounded

    # initialize x0 and y0 to be at the center of the FOV
    x0 = int(round(nx/2))
    y0 = int(round(ny/2))
    for index in range(0, n_iz_curved, 1):
        # set coordinates for landmark at the center of the cross
        landmark_straight[index][0][0], landmark_straight[index][0][1], landmark_straight[index][0][2] = x0, y0, iz_straight[index]
        # set x, y and z coordinates for landmarks +x
        landmark_straight[index][1][0], landmark_straight[index][1][1], landmark_straight[index][1][2] = x0 + gapxy, y0, iz_straight[index]
        # set x, y and z coordinates for landmarks -x
        landmark_straight[index][2][0], landmark_straight[index][2][1], landmark_straight[index][2][2] = x0-gapxy, y0, iz_straight[index]
        # set x, y and z coordinates for landmarks +y
        landmark_straight[index][3][0], landmark_straight[index][3][1], landmark_straight[index][3][2] = x0, y0+gapxy, iz_straight[index]
        # set x, y and z coordinates for landmarks -y
        landmark_straight[index][4][0], landmark_straight[index][4][1], landmark_straight[index][4][2] = x0, y0-gapxy, iz_straight[index]

    # # display
    # fig = plt.figure()
    # ax = fig.add_subplot(111, projection='3d')
    # #ax.plot(x_centerline_fit, y_centerline_fit,z_centerline, 'r')
    # ax.plot([landmark_straight[i][j][0] for i in range(0, n_iz_curved) for j in range(0, 5)], \
    #        [landmark_straight[i][j][1] for i in range(0, n_iz_curved) for j in range(0, 5)], \
    #        [landmark_straight[i][j][2] for i in range(0, n_iz_curved) for j in range(0, 5)], '.')
    # ax.set_xlabel('x')
    # ax.set_ylabel('y')
    # ax.set_zlabel('z')
    # plt.show()
    #

    # Create NIFTI volumes with landmarks
    #==========================================================================================
    # Pad input volume to deal with the fact that some landmarks on the curved centerline might be outside the FOV
    # N.B. IT IS VERY IMPORTANT TO PAD ALSO ALONG X and Y, OTHERWISE SOME LANDMARKS MIGHT GET OUT OF THE FOV!!!
    #sct.run('fslview ' + fname_centerline_orient)
    sct.printv('\nPad input volume to account for landmarks that fall outside the FOV...', verbose)
    sct.run('isct_c3d '+fname_centerline_orient+' -pad '+str(padding)+'x'+str(padding)+'x'+str(padding)+'vox '+str(padding)+'x'+str(padding)+'x'+str(padding)+'vox 0 -o tmp.centerline_pad.nii.gz')

    # Open padded centerline for reading
    sct.printv('\nOpen padded centerline for reading...', verbose)
    file = load('tmp.centerline_pad.nii.gz')
    data = file.get_data()
    hdr = file.get_header()

    # Create volumes containing curved and straight landmarks
    data_curved_landmarks = data * 0
    data_straight_landmarks = data * 0
    # initialize landmark value (each cross element gets its own value, shared by the curved and straight volumes)
    landmark_value = 1
    # Loop across cross index
    for index in range(0, n_iz_curved, 1):
        # loop across cross element index
        for i_element in range(0, 5, 1):
            # get x, y and z coordinates of curved landmark (rounded to closest integer)
            x, y, z = int(round(landmark_curved[index][i_element][0])), int(round(landmark_curved[index][i_element][1])), int(round(landmark_curved[index][i_element][2]))
            # attribute landmark_value to the voxel and its neighbours (indices are shifted by the padding)
            data_curved_landmarks[x+padding-1:x+padding+2, y+padding-1:y+padding+2, z+padding-1:z+padding+2] = landmark_value
            # get x, y and z coordinates of straight landmark (rounded to closest integer)
            x, y, z = int(round(landmark_straight[index][i_element][0])), int(round(landmark_straight[index][i_element][1])), int(round(landmark_straight[index][i_element][2]))
            # attribute landmark_value to the voxel and its neighbours
            data_straight_landmarks[x+padding-1:x+padding+2, y+padding-1:y+padding+2, z+padding-1:z+padding+2] = landmark_value
            # increment landmark value
            landmark_value = landmark_value + 1

    # Write NIFTI volumes
    sct.printv('\nWrite NIFTI volumes...', verbose)
    hdr.set_data_dtype('uint32')  # set imagetype to uint32 #TODO: maybe use int32
    img = Nifti1Image(data_curved_landmarks, None, hdr)
    save(img, 'tmp.landmarks_curved.nii.gz')
    sct.printv('.. File created: tmp.landmarks_curved.nii.gz', verbose)
    img = Nifti1Image(data_straight_landmarks, None, hdr)
    save(img, 'tmp.landmarks_straight.nii.gz')
    sct.printv('.. File created: tmp.landmarks_straight.nii.gz', verbose)


    # Estimate deformation field by pairing landmarks
    #==========================================================================================

    # This stands to avoid overlapping between landmarks
    # NOTE(review): the message names landmark_curved twice; the command itself takes the straight
    # landmarks as input (-i) and the curved landmarks as reference (-r).
    sct.printv('\nMake sure all labels between landmark_curved and landmark_curved match...', verbose)
    sct.run('sct_label_utils -t remove -i tmp.landmarks_straight.nii.gz -o tmp.landmarks_straight.nii.gz -r tmp.landmarks_curved.nii.gz', verbose)

    # convert landmarks to INT
    sct.printv('\nConvert landmarks to INT...', verbose)
    sct.run('isct_c3d tmp.landmarks_straight.nii.gz -type int -o tmp.landmarks_straight.nii.gz', verbose)
    sct.run('isct_c3d tmp.landmarks_curved.nii.gz -type int -o tmp.landmarks_curved.nii.gz', verbose)

    # Estimate rigid transformation
    sct.printv('\nEstimate rigid transformation between paired landmarks...', verbose)
    sct.run('isct_ANTSUseLandmarkImagesToGetAffineTransform tmp.landmarks_straight.nii.gz tmp.landmarks_curved.nii.gz rigid tmp.curve2straight_rigid.txt', verbose)

    # Apply rigid transformation
    sct.printv('\nApply rigid transformation to curved landmarks...', verbose)
    sct.run('sct_apply_transfo -i tmp.landmarks_curved.nii.gz -o tmp.landmarks_curved_rigid.nii.gz -d tmp.landmarks_straight.nii.gz -w tmp.curve2straight_rigid.txt -x nn', verbose)

    # Estimate b-spline transformation curve --> straight
    sct.printv('\nEstimate b-spline transformation: curve --> straight...', verbose)
    sct.run('isct_ANTSUseLandmarkImagesToGetBSplineDisplacementField tmp.landmarks_straight.nii.gz tmp.landmarks_curved_rigid.nii.gz tmp.warp_curve2straight.nii.gz 5x5x10 3 2 0', verbose)

    # remove padding for straight labels (crop along each of the three dimensions in turn), or keep them as they are
    if crop == 1:
        sct.run('sct_crop_image -i tmp.landmarks_straight.nii.gz -o tmp.landmarks_straight_crop.nii.gz -dim 0 -bzmax', verbose)
        sct.run('sct_crop_image -i tmp.landmarks_straight_crop.nii.gz -o tmp.landmarks_straight_crop.nii.gz -dim 1 -bzmax', verbose)
        sct.run('sct_crop_image -i tmp.landmarks_straight_crop.nii.gz -o tmp.landmarks_straight_crop.nii.gz -dim 2 -bzmax', verbose)
    else:
        sct.run('cp tmp.landmarks_straight.nii.gz tmp.landmarks_straight_crop.nii.gz', verbose)

    # Concatenate rigid and non-linear transformations...
    sct.printv('\nConcatenate rigid and non-linear transformations...', verbose)
    #sct.run('isct_ComposeMultiTransform 3 tmp.warp_rigid.nii -R tmp.landmarks_straight.nii tmp.warp.nii tmp.curve2straight_rigid.txt')
    # !!! DO NOT USE sct.run HERE BECAUSE isct_ComposeMultiTransform OUTPUTS A NON-NULL STATUS !!!
    cmd = 'isct_ComposeMultiTransform 3 tmp.curve2straight.nii.gz -R tmp.landmarks_straight_crop.nii.gz tmp.warp_curve2straight.nii.gz tmp.curve2straight_rigid.txt'
    sct.printv(cmd, verbose, 'code')
    commands.getstatusoutput(cmd)

    # Estimate b-spline transformation straight --> curve
    # TODO: invert warping field instead of estimating a new one
    sct.printv('\nEstimate b-spline transformation: straight --> curve...', verbose)
    sct.run('isct_ANTSUseLandmarkImagesToGetBSplineDisplacementField tmp.landmarks_curved_rigid.nii.gz tmp.landmarks_straight.nii.gz tmp.warp_straight2curve.nii.gz 5x5x10 3 2 0', verbose)

    # Concatenate rigid and non-linear transformations...
    sct.printv('\nConcatenate rigid and non-linear transformations...', verbose)
    # cmd = 'isct_ComposeMultiTransform 3 tmp.straight2curve.nii.gz -R tmp.landmarks_straight.nii.gz -i tmp.curve2straight_rigid.txt tmp.warp_straight2curve.nii.gz'
    # same remark as above: the command is run directly and its status/output are discarded
    cmd = 'isct_ComposeMultiTransform 3 tmp.straight2curve.nii.gz -R '+file_anat+ext_anat+' -i tmp.curve2straight_rigid.txt tmp.warp_straight2curve.nii.gz'
    sct.printv(cmd, verbose, 'code')
    commands.getstatusoutput(cmd)

    # Apply transformation to input image
    sct.printv('\nApply transformation to input image...', verbose)
    sct.run('sct_apply_transfo -i '+file_anat+ext_anat+' -o tmp.anat_rigid_warp.nii.gz -d tmp.landmarks_straight_crop.nii.gz -x '+interpolation_warp+' -w tmp.curve2straight.nii.gz', verbose)

    # compute the error between the straightened centerline/segmentation and the central vertical line.
    # Ideally, the error should be zero.
    # Apply deformation to the RPI-oriented centerline (the printed message still says "input image")
    print '\nApply transformation to input image...'
    c = sct.run('sct_apply_transfo -i '+fname_centerline_orient+' -o tmp.centerline_straight.nii.gz -d tmp.landmarks_straight_crop.nii.gz -x nn -w tmp.curve2straight.nii.gz')
    #c = sct.run('sct_crop_image -i tmp.centerline_straight.nii.gz -o tmp.centerline_straight_crop.nii.gz -dim 2 -bzmax')
    from msct_image import Image
    file_centerline_straight = Image('tmp.centerline_straight.nii.gz')
    coordinates_centerline = file_centerline_straight.getNonZeroCoordinates(sorting='z')
    # for each z, average the value-weighted x and y of the straightened centerline voxels
    # NOTE(review): the weights are not normalised by the sum of the values, and range() leaves out
    # the last z of the centerline -- confirm both are intended.
    mean_coord = []
    for z in range(coordinates_centerline[0].z, coordinates_centerline[-1].z):
        mean_coord.append(mean([[coord.x*coord.value, coord.y*coord.value] for coord in coordinates_centerline if coord.z == z], axis=0))

    # compute the in-plane distance between each per-slice mean position and the center of the straightened volume
    from math import sqrt
    mse_curve = 0.0
    max_dist = 0.0
    x0 = int(round(file_centerline_straight.data.shape[0]/2.0))
    y0 = int(round(file_centerline_straight.data.shape[1]/2.0))
    count_mean = 0
    for coord_z in mean_coord:
        # entries containing NaN are left out of the statistics
        if not isnan(sum(coord_z)):
            # squared distance, scaled by the voxel sizes px and py
            dist = ((x0-coord_z[0])*px)**2 + ((y0-coord_z[1])*py)**2
            mse_curve += dist
            dist = sqrt(dist)
            if dist > max_dist:
                max_dist = dist
            count_mean += 1
    # NOTE(review): raises ZeroDivisionError when no slice was counted; the result is a mean of
    # squared distances although it is reported below with the unit "mm".
    mse_curve = mse_curve/float(count_mean)

    # come back to parent folder
    os.chdir('..')

    # Generate output file (in current folder)
    # TODO: do not uncompress the warping field, it is too time consuming!
    sct.printv('\nGenerate output file (in current folder)...', verbose)
    sct.generate_output_file(path_tmp+'/tmp.curve2straight.nii.gz', 'warp_curve2straight.nii.gz', verbose)  # warping field
    sct.generate_output_file(path_tmp+'/tmp.straight2curve.nii.gz', 'warp_straight2curve.nii.gz', verbose)  # warping field
    fname_straight = sct.generate_output_file(path_tmp+'/tmp.anat_rigid_warp.nii.gz', file_anat+'_straight'+ext_anat, verbose)  # straightened anatomic

    # Remove temporary files
    if remove_temp_files:
        sct.printv('\nRemove temporary files...', verbose)
        sct.run('rm -rf '+path_tmp, verbose)

    print '\nDone!\n'

    sct.printv('Maximum x-y error = '+str(round(max_dist,2))+' mm', verbose, 'bold')
    sct.printv('Accuracy of straightening (MSE) = '+str(round(mse_curve,2))+' mm', verbose, 'bold')
    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: '+str(int(round(elapsed_time)))+'s', verbose)
    sct.printv('\nTo view results, type:', verbose)
    sct.printv('fslview '+fname_straight+' &\n', verbose, 'info')
Esempio n. 16
0
class ProcessLabels(object):
    def __init__(self,
                 fname_label,
                 fname_output=None,
                 fname_ref=None,
                 cross_radius=5,
                 dilate=False,
                 coordinates=None,
                 verbose=1,
                 vertebral_levels=None,
                 value=None,
                 msg=""):
        """
        Set up a label-processing job.

        :param fname_label: label image, handed to Image() to build self.image_input
        :param fname_output: output file name, or a list of names (a one-element list is unwrapped)
        :param fname_ref: optional reference image, handed to Image() to build self.image_ref
        :param cross_radius: radius used by the 'cross' and 'plan' processes
        :param dilate: passed to get_crosses_coordinates() by the 'cross' process
        :param coordinates: coordinates used by the label-creation processes
        :param verbose: verbosity level, also forwarded to Image()
        :param vertebral_levels: levels forwarded to label_vertebrae() by the 'vert-body' process
        :param value: value forwarded to the 'add' and 'create-viewer' processes
        :param msg: string. message to display to the user.
        """
        # images: the input is always loaded, the reference only when a file name is given
        self.image_input = Image(fname_label, verbose=verbose)
        self.image_ref = Image(fname_ref, verbose=verbose) if fname_ref is not None else None

        # a list holding a single file name is reduced to that name; anything else is stored untouched
        single_name_in_list = isinstance(fname_output, list) and len(fname_output) == 1
        self.fname_output = fname_output[0] if single_name_in_list else fname_output

        # remaining options are stored as given
        self.cross_radius = cross_radius
        self.vertebral_levels = vertebral_levels
        self.dilate = dilate
        self.coordinates = coordinates
        self.verbose = verbose
        self.value = value
        self.msg = msg

    def process(self, type_process):
        # for some processes, change orientation of input image to RPI
        change_orientation = False
        if type_process in ['vert-body', 'vert-disc', 'vert-continuous']:
            # get orientation of input image
            input_orientation = self.image_input.orientation
            # change orientation
            self.image_input.change_orientation('RPI')
            change_orientation = True
        if type_process == 'add':
            self.output_image = self.add(self.value)
        if type_process == 'cross':
            self.output_image = self.cross()
        if type_process == 'plan':
            self.output_image = self.plan(self.cross_radius, 100, 5)
        if type_process == 'plan_ref':
            self.output_image = self.plan_ref()
        if type_process == 'increment':
            self.output_image = self.increment_z_inverse()
        if type_process == 'disks':
            self.output_image = self.labelize_from_disks()
        if type_process == 'MSE':
            self.MSE()
            self.fname_output = None
        if type_process == 'remove':
            self.output_image = self.remove_label()
        if type_process == 'remove-symm':
            self.output_image = self.remove_label(symmetry=True)
        if type_process == 'centerline':
            self.extract_centerline()
        if type_process == 'create':
            self.output_image = self.create_label()
        if type_process == 'create-add':
            self.output_image = self.create_label(add=True)
        if type_process == 'create-seg':
            self.output_image = self.create_label_along_segmentation()
        if type_process == 'display-voxel':
            self.display_voxel()
            self.fname_output = None
        if type_process == 'diff':
            self.diff()
            self.fname_output = None
        if type_process == 'dist-inter':  # second argument is in pixel distance
            self.distance_interlabels(5)
            self.fname_output = None
        if type_process == 'cubic-to-point':
            self.output_image = self.cubic_to_point()
        if type_process == 'vert-body':
            self.output_image = self.label_vertebrae(self.vertebral_levels)
        if type_process == 'vert-continuous':
            self.output_image = self.continuous_vertebral_levels()
        if type_process == 'create-viewer':
            self.output_image = self.launch_sagittal_viewer(self.value)

        # save the output image as minimized integers
        if self.fname_output is not None:
            self.output_image.setFileName(self.fname_output)
            if change_orientation:
                self.output_image.change_orientation(input_orientation)
            if type_process == 'vert-continuous':
                self.output_image.save('float32')
            elif type_process != 'plan_ref':
                self.output_image.save('minimize_int')
            else:
                self.output_image.save()

    def add(self, value):
        """
        Add a constant to every non-zero voxel of the input image.

        :param value: amount to add; converted with float() for each voxel
        :return: Image built from the input image, with the constant added at each
                 coordinate returned by getNonZeroCoordinates()
        """
        image_output = Image(self.image_input, self.verbose)

        # shift each labelled voxel in place; all other voxels are left untouched
        for coord in self.image_input.getNonZeroCoordinates():
            voxel = (int(coord.x), int(coord.y), int(coord.z))
            image_output.data[voxel] += float(value)
        return image_output

    def create_label(self, add=False):
        """
        Create an image with labels listed by the user.
        This method works only if the user inserted correct coordinates.

        self.coordinates is a list of coordinates (class in msct_types).
        a Coordinate contains x, y, z and value.
        If only one label is to be added, coordinates must be completed with '[]'
        examples:
        For one label:  object_define=ProcessLabels( fname_label, coordinates=[coordi]) where coordi is a 'Coordinate' object from msct_types
        For two labels: object_define=ProcessLabels( fname_label, coordinates=[coordi1, coordi2]) where coordi1 and coordi2 are 'Coordinate' objects from msct_types
        """
        image_output = self.image_input.copy()
        if not add:
            image_output.data *= 0

        # loop across labels
        for i, coord in enumerate(self.coordinates):
            if len(image_output.data.shape) == 3:
                image_output.data[int(coord.x),
                                  int(coord.y),
                                  int(coord.z)] = coord.value
            elif len(image_output.data.shape) == 2:
                assert str(
                    coord.z
                ) == '0', "ERROR: 2D coordinates should have a Z value of 0. Z coordinate is :" + str(
                    coord.z)
                image_output.data[int(coord.x), int(coord.y)] = coord.value
            else:
                sct.printv(
                    'ERROR: Data should be 2D or 3D. Current shape is: ' +
                    str(image_output.data.shape), 1, 'error')
            # display info
            sct.printv(
                'Label #' + str(i) + ': ' + str(coord.x) + ',' + str(coord.y) +
                ',' + str(coord.z) + ' --> ' + str(coord.value), 1)
        return image_output

    def create_label_along_segmentation(self):
        """
        Create an image with labels defined along the spinal cord segmentation (or centerline)
        Example:
        object_define=ProcessLabels(fname_segmentation, coordinates=[coord_1, coord_2, coord_i]), where coord_i='z,value'. If z=-1, then use z=nz/2 (i.e. center of FOV in superior-inferior direction)
        Returns
        -------
        image_output: Image object with labels.
        """
        # copy input Image object (will use the same header)
        image_output = self.image_input.copy()
        # set all voxels to 0
        image_output.data *= 0
        # loop across labels
        for i, coord in enumerate(self.coordinates):
            # split coord string
            list_coord = coord.split(',')
            # convert to int() and assign to variable
            z, value = [int(i) for i in list_coord]
            # if z=-1, replace with nz/2
            if z == -1:
                z = int(round(image_output.dim[2] / 2.0))
            # get center of mass of segmentation at given z
            x, y = ndimage.measurements.center_of_mass(
                np.array(self.image_input.data[:, :, z]))
            # round values to make indices
            x, y = int(round(x)), int(round(y))
            # display info
            sct.printv(
                'Label #' + str(i) + ': ' + str(x) + ',' + str(y) + ',' +
                str(z) + ' --> ' + str(value), 1)
            if len(image_output.data.shape) == 3:
                image_output.data[x, y, z] = value
            elif len(image_output.data.shape) == 2:
                assert str(
                    z
                ) == '0', "ERROR: 2D coordinates should have a Z value of 0. Z coordinate is :" + str(
                    z)
                image_output.data[x, y] = value
        return image_output

    def cross(self):
        """
        Draw a cross around every label of the input image.

        The cross radius (self.cross_radius) is divided by the voxel size along x and the
        result is handed to get_crosses_coordinates() as the gap of the cross. Every
        coordinate it returns is rounded to the nearest voxel and written, with its value,
        into a zeroed Image built from the input image.

        :return: Image object containing the crosses
        """
        output_image = Image(self.image_input, self.verbose)
        # only the voxel size along x is needed; the other seven fields of dim are ignored
        _, _, _, _, px, _, _, _ = Image(self.image_input.absolutepath).dim

        coordinates_input = self.image_input.getNonZeroCoordinates()
        # cross radius expressed in voxels along x
        # (assumes cross_radius is in mm and px is the voxel size in mm -- TODO confirm)
        dx = self.cross_radius / px

        # clean output_image
        output_image.data *= 0

        cross_coordinates = self.get_crosses_coordinates(
            coordinates_input, dx, self.image_ref, self.dilate)

        for coord in cross_coordinates:
            output_image.data[int(round(coord.x)),
                              int(round(coord.y)),
                              int(round(coord.z))] = coord.value

        return output_image

    @staticmethod
    def get_crosses_coordinates(coordinates_input,
                                gapxy=15,
                                image_ref=None,
                                dilate=False,
                                verbose=0):
        """
        Compute the coordinates of the cross points associated with each input label.

        Without reference image, the cross of each label is given by ``compute_cross``. With a reference image
        (segmentation), its centerline is smoothed and the cross is computed by ``compute_cross_centerline`` using
        the centerline derivative at the z position of the label.
        The i-th cross point of a label of value ``v`` receives the value ``v * 10 + i + 1``.

        :param coordinates_input: iterable of label coordinates (objects with x, y, z, value)
        :param gapxy: gap handed over to the cross computation functions
        :param image_ref: optional reference image used to extract the centerline
        :param dilate: if True, the 26 neighbours of each cross point are added with the same value (3x3x3 block)
        :param verbose: verbosity forwarded to ``smooth_centerline`` (only used when image_ref is provided)
        :return: list of cross coordinates, sorted by value
        :raises IndexError: if image_ref is provided and the z of a label is not found on the smoothed centerline
        """
        from itertools import product

        from msct_types import Coordinate

        use_centerline = image_ref is not None
        if use_centerline:
            from numpy import where
            from sct_straighten_spinalcord import compute_cross_centerline, smooth_centerline
            # this is a static method: there is no ``self``, the reference image and verbosity come from arguments
            (_, _, z_centerline, x_centerline_deriv, y_centerline_deriv,
             z_centerline_deriv) = smooth_centerline(image_ref, verbose=verbose)
        else:
            from sct_straighten_spinalcord import compute_cross

        # offsets of the 26 neighbours of a voxel (the null offset is skipped), x varying slowest
        neighbour_offsets = [
            offset for offset in product((0.0, 1.0, -1.0), repeat=3)
            if any(offset)
        ]

        def _shift(component, offset):
            # a null offset leaves the component untouched (no int -> float conversion)
            return component + offset if offset else component

        cross_coordinates = []
        for coord in coordinates_input:
            if use_centerline:
                index_z = where(z_centerline == coord.z)
                deriv = Coordinate([
                    x_centerline_deriv[index_z][0],
                    y_centerline_deriv[index_z][0],
                    z_centerline_deriv[index_z][0], 0.0
                ])
                cross_coordinates_temp = compute_cross_centerline(
                    coord, deriv, gapxy)
            else:
                cross_coordinates_temp = compute_cross(coord, gapxy)

            # give each point of the cross its own value, derived from the label value
            for i, coord_cross in enumerate(cross_coordinates_temp):
                coord_cross.value = coord.value * 10 + i + 1

            # dilate cross to 3x3x3
            if dilate:
                additional_coordinates = [
                    Coordinate([
                        _shift(point.x, off_x),
                        _shift(point.y, off_y),
                        _shift(point.z, off_z), point.value
                    ])
                    for point in cross_coordinates_temp
                    for off_x, off_y, off_z in neighbour_offsets
                ]
                cross_coordinates_temp.extend(additional_coordinates)

            cross_coordinates.extend(cross_coordinates_temp)

        return sorted(cross_coordinates, key=lambda obj: obj.value)

    def plan(self, width, offset=0, gap=1):
        """
        Create a plane around each label of the input image.

        For a label located at slice z, the slices z - width (included) to z + width (excluded) of the output
        are filled with ``offset + gap * value``. When planes overlap, the label processed last wins.

        :param width: number of slices filled on each side of the label slice (see range above)
        :param offset: constant added to the value written in the plane
        :param gap: multiplicative factor applied to the label value
        :return: output image that contains the planes
        """
        image_output = Image(self.image_input, self.verbose)
        image_output.data *= 0
        coordinates_input = self.image_input.getNonZeroCoordinates()

        for coord in coordinates_input:
            z = int(coord.z)
            # clamp the lower bound: a negative start would be interpreted from the end of the axis and the
            # plane of a label close to the first slice would be lost or misplaced
            z_start = max(z - width, 0)
            image_output.data[:, :, z_start:z + width] = offset + gap * coord.value

        return image_output

    def plan_ref(self):
        """
        Generate a plane in the reference space for each label present in the input image.

        Negative and positive labels are separated into two temporary images (the negative ones are stored with
        their sign flipped) so that ``getNonZeroCoordinates`` can be applied to each. Every slice of the output
        (built from ``self.image_ref`` and converted to float32) that holds a label is then filled with the
        signed value of that label; positive labels are written last.

        :return: output image, in the space of the reference image
        """
        image_output = Image(self.image_ref, self.verbose)
        image_output.data *= 0

        image_input_neg = Image(self.image_input, self.verbose).copy()
        image_input_pos = Image(self.image_input, self.verbose).copy()
        image_input_neg.data *= 0
        image_input_pos.data *= 0

        # split the labels by sign with boolean masks instead of looping over the voxels one by one
        input_data = self.image_input.data
        negative = input_data < 0
        positive = input_data > 0
        # sign flipped in order to apply getNonZeroCoordinates
        image_input_neg.data[negative] = -input_data[negative]
        image_input_pos.data[positive] = input_data[positive]

        coordinates_input_neg = image_input_neg.getNonZeroCoordinates()
        coordinates_input_pos = image_input_pos.getNonZeroCoordinates()

        image_output.changeType('float32')
        for coord in coordinates_input_neg:
            # NOTE(review): the original author flagged that coord.value may come back as an integer here
            image_output.data[:, :, int(coord.z)] = -coord.value
        for coord in coordinates_input_pos:
            image_output.data[:, :, int(coord.z)] = coord.value

        return image_output

    def cubic_to_point(self):
        """
        Reduce each group of voxels sharing the same value to a single voxel located at their center of mass.

        Typically used after a homothetic warping field has been applied to a label file, which dilates the labels.
        Warning: voxels are grouped by value only. Two spatially separated groups that carry the same value are
        merged and a single center of mass is computed for both.

        :return: image of the same size as the input, with one voxel per label value
        """
        # blank image with the geometry of the input
        output_image = self.image_input.copy()
        output_image.data *= 0

        # gather the non-null voxels (sorted by value) into one list per value
        voxels_by_value = {}
        for voxel in self.image_input.getNonZeroCoordinates(sorting='value'):
            voxels_by_value.setdefault(voxel.value, []).append(voxel)

        # one output voxel per value, at the rounded center of mass of the group
        for members in voxels_by_value.values():
            barycenter = sum(members) / float(len(members))
            x_round = round(barycenter.x)
            y_round = round(barycenter.y)
            z_round = round(barycenter.z)

            exact_position = str(barycenter.x) + ", " + str(barycenter.y) + ", " + str(barycenter.z)
            voxel_position = str(x_round) + ", " + str(y_round) + ", " + str(z_round)
            sct.printv("Value = " + str(barycenter.value) + " : (" +
                       exact_position + ") --> ( " + voxel_position + ")",
                       verbose=self.verbose)

            output_image.data[int(x_round), int(y_round), int(z_round)] = barycenter.value

        return output_image

    def increment_z_inverse(self):
        """
        Renumber the labels 1, 2, 3, ... following the order given by a reversed sort along z.
        This function assumes an RPI orientation.

        :return: image in which each non-zero voxel of the input holds its rank
        """
        image_output = Image(self.image_input, self.verbose)
        image_output.data *= 0

        ordered_labels = self.image_input.getNonZeroCoordinates(
            sorting='z', reverse_coord=True)

        # ranks start at 1 because 0 is the background
        for rank, label in enumerate(ordered_labels, 1):
            image_output.data[int(label.x), int(label.y), int(label.z)] = rank

        return image_output

    def labelize_from_disks(self):
        """
        Label the non-zero voxels of the input image according to the labels of the reference image.

        Typical use: the input is a segmentation and the reference holds disk labels; the output is the
        segmentation labelled with vertebral levels.
        Reference labels are sorted by value and are expected to be non-zero and incremented from top to bottom
        (RPI orientation). A voxel lying between two consecutive reference labels (upper one included) receives
        the value of the upper one; voxels outside every interval stay at 0.

        :return: labelled image
        """
        image_output = Image(self.image_input, self.verbose)
        image_output.data *= 0
        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates(sorting='value')

        interval_count = len(coordinates_ref) - 1
        for voxel in coordinates_input:
            # every interval is tested: if several match, the last one wins
            for idx in range(interval_count):
                upper = coordinates_ref[idx]
                lower = coordinates_ref[idx + 1]
                if lower.z < voxel.z <= upper.z:
                    image_output.data[int(voxel.x),
                                      int(voxel.y),
                                      int(voxel.z)] = upper.value

        return image_output

    def label_vertebrae(self, levels_user=None):
        """
        Find the center of mass of the vertebral levels specified by the user.

        :param levels_user: list of levels (integers) to keep. None, an empty list or a list starting with 0
                            keep every level.
        :return: image_output: Image with one label per kept level, located at its center of mass.
        """
        # get center of mass of each vertebral level
        image_cubic2point = self.cubic_to_point()

        # no selection from the user: every level is kept
        # (the default value None and an empty list used to crash on levels_user[0])
        if levels_user is None or len(levels_user) == 0 or levels_user[0] == 0:
            return image_cubic2point

        # remove the labels that are not listed by the user
        for coord in image_cubic2point.getNonZeroCoordinates(sorting='value'):
            if int(coord.value) not in levels_user:
                image_cubic2point.data[int(coord.x),
                                       int(coord.y),
                                       int(coord.z)] = 0

        return image_cubic2point

    def MSE(self, threshold_mse=0):
        """
        Compute the error along z between two sets of labels (input and ref).

        Labels are matched on their rounded value (first match in the reference). The result is the square root
        of the sum of squared z differences divided by the number of input labels.
        A warning is printed for each label mismatch. If the result is above ``threshold_mse``, a log file
        ``error_log_<input file name>.txt`` is written in the folder of the input image.

        :param threshold_mse: threshold above which the error log is written (default 0)
        :return: the computed error
        :raises ZeroDivisionError: if the input image has no label
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates()

        # rounded values computed once instead of once per tested label
        values_input = set(round(coord.value) for coord in coordinates_input)
        values_ref = set(round(coord_ref.value) for coord_ref in coordinates_ref)

        # check if all the labels in both the images match
        if len(coordinates_input) != len(coordinates_ref):
            sct.printv('ERROR: labels mismatch', 1, 'warning')
        for coord in coordinates_input:
            if round(coord.value) not in values_ref:
                sct.printv('ERROR: labels mismatch', 1, 'warning')
        for coord_ref in coordinates_ref:
            if round(coord_ref.value) not in values_input:
                sct.printv('ERROR: labels mismatch', 1, 'warning')

        # first reference label for each rounded value
        first_ref_by_value = {}
        for coord_ref in coordinates_ref:
            first_ref_by_value.setdefault(round(coord_ref.value), coord_ref)

        result = 0.0
        for coord in coordinates_input:
            coord_ref = first_ref_by_value.get(round(coord.value))
            if coord_ref is not None:
                result += (coord_ref.z - coord.z)**2
        result = math.sqrt(result / len(coordinates_input))
        sct.printv('MSE error in Z direction = ' + str(result) + ' mm')

        if result > threshold_mse:
            fname_log = (self.image_input.path + 'error_log_' +
                         self.image_input.file_name + '.txt')
            # context manager: the file is closed even if the write fails
            with open(fname_log, 'w') as f:
                f.write('The labels error (MSE) between ' +
                        self.image_input.file_name + ' and ' +
                        self.image_ref.file_name + ' is: ' + str(result))

        return result

    @staticmethod
    def remove_label_coord(coord_input, coord_ref, symmetry=False):
        """
        Keep the input labels that also exist in the reference (and, with symmetry, the other way around).

        When ``symmetry`` is set and both lists hold CoordinateValue objects, labels are matched with the
        equality/hash of CoordinateValue (set intersection). Otherwise labels are matched on their ``value``.
        Empty lists are accepted.

        :param coord_input: list of coordinates from the input image
        :param coord_ref: list of coordinates from the reference image
        :param symmetry: boolean, if True the reference labels without match in the input are removed as well
        :return: (result_coord_input, result_coord_ref): the filtered lists, original order preserved.
                 Without symmetry, result_coord_ref is coord_ref itself.
        """
        from msct_types import CoordinateValue

        # the type check reads the first element: only do it on non-empty lists
        match_on_coordinates = (symmetry and len(coord_input) > 0 and len(coord_ref) > 0
                                and isinstance(coord_input[0], CoordinateValue)
                                and isinstance(coord_ref[0], CoordinateValue))

        if match_on_coordinates:
            # kept as a set so that each membership test below is a hash lookup
            coord_intersection = set(coord_input).intersection(coord_ref)
            result_coord_input = [
                coord for coord in coord_input if coord in coord_intersection
            ]
            result_coord_ref = [
                coord for coord in coord_ref if coord in coord_intersection
            ]
        else:
            values_ref = set(coord.value for coord in coord_ref)
            result_coord_input = [
                coord for coord in coord_input if coord.value in values_ref
            ]
            result_coord_ref = coord_ref
            if symmetry:
                values_kept = set(coord.value for coord in result_coord_input)
                result_coord_ref = [
                    coord for coord in coord_ref if coord.value in values_kept
                ]

        return result_coord_input, result_coord_ref

    def remove_label(self, symmetry=False):
        """
        Compare two label images and remove any labels in input image that are not in reference image.

        With the symmetry option, the labels of the reference image that are not in the input image are removed
        as well: the cleaned reference is saved under ``self.fname_output[1]`` and ``self.fname_output`` is
        replaced by its first element.

        :param symmetry: boolean, also clean and save the reference image
        :return: image that contains the labels of the input image that exist in the reference image
        """
        image_output = Image(self.image_input, verbose=self.verbose)
        image_output.data *= 0  # put all voxels to 0

        result_coord_input, result_coord_ref = self.remove_label_coord(
            self.image_input.getNonZeroCoordinates(coordValue=True),
            self.image_ref.getNonZeroCoordinates(coordValue=True), symmetry)

        for coord in result_coord_input:
            image_output.data[int(coord.x),
                              int(coord.y),
                              int(coord.z)] = int(round(coord.value))

        if symmetry:
            image_output_ref = Image(self.image_ref, verbose=self.verbose)
            # blank the reference copy first, otherwise the labels without match would still be present
            image_output_ref.data *= 0
            for coord in result_coord_ref:
                image_output_ref.data[int(coord.x),
                                      int(coord.y),
                                      int(coord.z)] = int(round(coord.value))
            image_output_ref.setFileName(self.fname_output[1])
            image_output_ref.save('minimize_int')

            # in symmetry mode fname_output holds two names: keep the one of the input for the caller
            self.fname_output = self.fname_output[0]

        return image_output

    def extract_centerline(self):
        """
        Write a text file with the coordinates of the centerline.

        Each non-zero voxel of the input image, sorted along z, produces one line "x y z" (integers) in
        ``self.fname_output``. The image is supposed to be RPI.
        """
        coordinates_input = self.image_input.getNonZeroCoordinates(sorting='z')

        # text mode: the lines are str objects, which cannot be written to a file opened in binary mode on
        # Python 3. The context manager closes the file even if a write fails.
        with open(self.fname_output, "w") as fo:
            for coord in coordinates_input:
                fo.write("%i %i %i\n" % (coord.x, coord.y, coord.z))

    def display_voxel(self):
        """
        Print the position and value of every label of the input image, sorted by value.

        The string representations of the labels, joined with ':', are stored in ``self.useful_notation`` and
        printed as well. The image is supposed to be RPI, but other orientations work too.

        :return: the list of label coordinates, sorted by value
        """
        labels = self.image_input.getNonZeroCoordinates(sorting='value')

        self.useful_notation = ''
        for label in labels:
            position = str(label.x) + ',' + str(label.y) + ',' + str(label.z)
            sct.printv('Position=(' + position + ') -- Value= ' + str(label.value),
                       verbose=self.verbose)
            # ':' separates consecutive labels, nothing is added before the first one
            separator = ':' if self.useful_notation else ''
            self.useful_notation = self.useful_notation + separator + str(label)

        sct.printv('All labels (useful syntax):', verbose=self.verbose)
        sct.printv(self.useful_notation, verbose=self.verbose)
        return labels

    def get_physical_coordinates(self):
        """
        Convert the labels of the input image to the physical referential system.

        :return: a list of CoordinateValue (one per label, sorted by value), in the physical (scanner) space
        """
        def _to_physical(label):
            # transfo_pix2phys works on a list of points: send one point, read back the first result
            physical = self.image_input.transfo_pix2phys(
                [[label.x, label.y, label.z]])[0]
            return CoordinateValue(
                [physical[0], physical[1], physical[2], label.value])

        labels = self.image_input.getNonZeroCoordinates(sorting='value')
        return [_to_physical(label) for label in labels]

    def get_coordinates_in_destination(self, im_dest, type='discrete'):
        """
        Compute the position of the labels in the pixel space of a destination image.

        :param im_dest: Object Image
        :param type: 'discrete' (uses transfo_phys2pix) or 'continuous' (uses transfo_phys2continuouspix)
        :return: a list of CoordinateValue, in the pixel (image) space of the destination image
        :raises ValueError: if type is neither 'discrete' nor 'continuous'
        """
        # strings are compared by value: identity ("is") only works by accident of interning
        if type == 'discrete':
            phys2pix = im_dest.transfo_phys2pix
        elif type == 'continuous':
            phys2pix = im_dest.transfo_phys2continuouspix
        else:
            raise ValueError(
                "The value of 'type' should either be 'discrete' or 'continuous'."
            )

        dest_coord = []
        for c in self.get_physical_coordinates():
            # the third component is z (c.y used to be passed twice)
            c_p = phys2pix([[c.x, c.y, c.z]])[0]
            dest_coord.append(
                CoordinateValue([c_p[0], c_p[1], c_p[2], c.value]))
        return dest_coord

    def diff(self):
        """
        Report the label mismatches between the input image and the reference image.

        Prints the values of the input labels that have no equal value in the reference, then the values of
        the reference labels that have no equal value in the input.
        """
        labels_input = self.image_input.getNonZeroCoordinates()
        labels_ref = self.image_ref.getNonZeroCoordinates()

        sct.printv("Label in input image that are not in reference image:")
        for label in labels_input:
            found = any(label.value == other.value for other in labels_ref)
            if not found:
                sct.printv(label.value)

        sct.printv("Label in ref image that are not in input image:")
        for label_ref in labels_ref:
            found = any(other.value == label_ref.value for other in labels_input)
            if not found:
                sct.printv(label_ref.value)

    def distance_interlabels(self, max_dist):
        """
        Calculate the distance between consecutive labels of the input image (in the order returned by
        getNonZeroCoordinates). If a distance is larger than max_dist, a warning message is displayed.

        :param max_dist: maximum accepted distance between two consecutive labels
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()

        def _describe(index):
            # "<index>[x,y,z]=<value>" for the label at this index
            label = coordinates_input[index]
            return (str(index) + '[' + str(label.x) + ',' + str(label.y) +
                    ',' + str(label.z) + ']=' + str(label.value))

        for i in range(0, len(coordinates_input) - 1):
            current, following = coordinates_input[i], coordinates_input[i + 1]
            dist = math.sqrt((current.x - following.x)**2 +
                             (current.y - following.y)**2 +
                             (current.z - following.z)**2)
            # warn when the distance exceeds the limit (the test used to be inverted)
            if dist > max_dist:
                sct.printv('Warning: the distance between label ' +
                           _describe(i) + ' and label ' + _describe(i + 1) +
                           ' is larger than ' + str(max_dist) + '. Distance=' +
                           str(dist))

    def continuous_vertebral_levels(self):
        """
        Transform the vertebral levels file from the template into a continuous file.

        Instead of an integer vertebral level on each slice, each non-zero voxel receives
        ``level + d / D`` where ``D`` is the length of the centerline inside the level and ``d`` the length of
        the centerline between the slice and the last centerline point of the level. Both lengths are computed
        with the voxel sizes of the image.
        The image must be RPI.

        NOTE(review): ``d`` used to be computed with squared voxel sizes and a factor 2.0, which equals the
        formula above only for 0.5 mm voxels. Results are unchanged for such images and differ for others.
        NOTE(review): the value of a voxel is looked up with its z coordinate; presumably every labelled slice is
        covered by the smoothed centerline, otherwise a KeyError is raised -- confirm.

        :return: image (float32) with continuous vertebral levels
        """
        im_input = Image(self.image_input, self.verbose)
        im_output = Image(self.image_input, self.verbose)
        im_output.data *= 0

        # 1. extract vertebral levels from input image
        #   a. extract centerline
        #   b. for each centerline point, extract corresponding level
        _, _, _, _, px, py, pz, _ = im_input.dim
        from sct_straighten_spinalcord import smooth_centerline
        x_fit, y_fit, z_fit, _, _, _ = smooth_centerline(
            self.image_input, algo_fitting='nurbs', verbose=0)
        value_centerline = np.array([
            im_input.data[int(x_fit[it]), int(y_fit[it]), int(z_fit[it])]
            for it in range(len(z_fit))
        ])

        def _arc_length(indexes):
            # length of the polyline joining the centerline points at these indexes, scaled by voxel sizes
            return np.sum([
                math.sqrt(((x_fit[stop] - x_fit[start]) * px)**2 +
                          ((y_fit[stop] - y_fit[start]) * py)**2 +
                          ((z_fit[stop] - z_fit[start]) * pz)**2)
                for start, stop in zip(indexes[:-1], indexes[1:])
            ])

        # 2. compute the length of each vertebral level --> Di for i being the vertebral levels
        length_levels = {}
        for level in np.unique(value_centerline):
            length_levels[level] = _arc_length(
                np.where(value_centerline == level)[0])

        # 3. for each centerline point:
        #   a. identify corresponding vertebral level --> i
        #   b. calculate the distance to the last centerline point of the level --> d
        #   c. compute relative distance in the vertebral level coordinate system --> d/Di
        continuous_values = {}
        for it, iz in enumerate(z_fit):
            level = value_centerline[it]
            indexes_level = np.where(value_centerline == level)[0]
            distance_from_level = _arc_length(
                indexes_level[indexes_level >= it])
            level_length = float(length_levels[level])
            # a level made of a single centerline point has no length: keep the integer level
            if level_length > 0:
                continuous_values[iz] = level + distance_from_level / level_length
            else:
                continuous_values[iz] = level + 0.0

        # 4. saving data
        # for each non-zero voxel, write the continuous value of its slice
        coordinates_input = self.image_input.getNonZeroCoordinates()
        im_output.changeType('float32')
        for coord in coordinates_input:
            im_output.data[int(coord.x),
                           int(coord.y),
                           int(coord.z)] = continuous_values[coord.z]

        return im_output

    def launch_sagittal_viewer(self, labels):
        """
        Open the sagittal dialog on the input image and return the image handed to it as output.

        :param labels: value assigned to the ``vertebraes`` parameter of the dialog
        :return: blanked copy of the input image, named ``self.fname_output``, passed to the dialog
        """
        from spinalcordtoolbox.gui import base
        from spinalcordtoolbox.gui.sagittal import launch_sagittal_dialog

        # dialog settings
        dialog_params = base.AnatomicalParams()
        dialog_params.vertebraes = labels
        dialog_params.input_file_name = self.image_input.file_name
        dialog_params.output_file_name = self.fname_output
        dialog_params.subtitle = self.msg

        # blank image, same geometry as the input, given to the dialog as its output
        blank_output = self.image_input.copy()
        blank_output.data *= 0
        blank_output.setFileName(self.fname_output)

        launch_sagittal_dialog(self.image_input, blank_output, dialog_params)

        return blank_output
def main():
    """
    Register an anatomical image to the template.

    Processing steps, in order:
      1. build the command-line parser and read the options (or use hard-coded file names
         when ``param.debug`` is set);
      2. check that the input labels are all different and not larger than the largest
         template label;
      3. copy the inputs to a temporary folder, resample them to 1 mm isotropic and
         re-orient them to RPI;
      4. straighten the spinal cord using the segmentation;
      5. estimate a landmark-based transformation (straight --> template) and write it to
         ``straight2templateAffine.txt``;
      6. run the registration steps described by ``paramreg`` (defaults: step 1 on the
         segmentation, step 2 on the image; extra steps can be supplied with ``-p``);
      7. concatenate the transformations into ``warp_anat2template.nii.gz`` and
         ``warp_template2anat.nii.gz``, generate the output files and optionally delete
         the temporary folder.

    The function reads its defaults from the module-level ``param`` object, changes the
    current working directory to the temporary folder during processing and comes back to
    the parent folder before generating the outputs. It returns nothing.

    NOTE(review): several checks only call ``sct.printv(..., 'error')`` and then carry on;
    this presumably relies on ``sct.printv`` aborting execution for the 'error' type -- confirm.
    """

    # get default parameters
    step1 = Paramreg(step='1', type='seg', algo='slicereg', metric='MeanSquares', iter='10')
    step2 = Paramreg(step='2', type='im', algo='syn', metric='MI', iter='3')
    paramreg = ParamregMultiStep([step1, step2])

    # Initialize the parser
    parser = Parser(__file__)
    parser.usage.set_description('Register anatomical image to the template.')
    parser.add_option(name="-i",
                      type_value="file",
                      description="Anatomical image.",
                      mandatory=True,
                      example="anat.nii.gz")
    parser.add_option(name="-s",
                      type_value="file",
                      description="Spinal cord segmentation.",
                      mandatory=True,
                      example="anat_seg.nii.gz")
    parser.add_option(name="-l",
                      type_value="file",
                      description="Labels. See: http://sourceforge.net/p/spinalcordtoolbox/wiki/create_labels/",
                      mandatory=True,
                      default_value='',
                      example="anat_labels.nii.gz")
    parser.add_option(name="-t",
                      type_value="folder",
                      description="Path to MNI-Poly-AMU template.",
                      mandatory=False,
                      default_value=param.path_template)
    parser.add_option(name="-p",
                      type_value=[[':'], 'str'],
                      description="""Parameters for registration (see sct_register_multimodal). Default:\n--\nstep=1\ntype="""+paramreg.steps['1'].type+"""\nalgo="""+paramreg.steps['1'].algo+"""\nmetric="""+paramreg.steps['1'].metric+"""\npoly="""+paramreg.steps['1'].poly+"""\n--\nstep=2\ntype="""+paramreg.steps['2'].type+"""\nalgo="""+paramreg.steps['2'].algo+"""\nmetric="""+paramreg.steps['2'].metric+"""\niter="""+paramreg.steps['2'].iter+"""\nshrink="""+paramreg.steps['2'].shrink+"""\nsmooth="""+paramreg.steps['2'].smooth+"""\ngradStep="""+paramreg.steps['2'].gradStep+"""\n--""",
                      mandatory=False,
                      example="step=2,type=seg,algo=bsplinesyn,metric=MeanSquares,iter=5,shrink=2:step=3,type=im,algo=syn,metric=MI,iter=5,shrink=1,gradStep=0.3")
    parser.add_option(name="-r",
                      type_value="multiple_choice",
                      description="""Remove temporary files.""",
                      mandatory=False,
                      default_value='1',
                      example=['0', '1'])
    parser.add_option(name="-v",
                      type_value="multiple_choice",
                      description="""Verbose. 0: nothing. 1: basic. 2: extended.""",
                      mandatory=False,
                      default_value=param.verbose,
                      example=['0', '1', '2'])
    if param.debug:
        # debug mode: bypass the command line and use hard-coded example data
        # (single-argument call form: prints the same text with the Python 2 print statement)
        print('\n*** WARNING: DEBUG MODE ON ***\n')
        fname_data = '/Users/julien/data/temp/sct_example_data/t2/t2.nii.gz'
        fname_landmarks = '/Users/julien/data/temp/sct_example_data/t2/labels.nii.gz'
        fname_seg = '/Users/julien/data/temp/sct_example_data/t2/t2_seg.nii.gz'
        path_template = param.path_template
        remove_temp_files = 0
        verbose = 2
    else:
        arguments = parser.parse(sys.argv[1:])

        # get arguments
        fname_data = arguments['-i']
        fname_seg = arguments['-s']
        fname_landmarks = arguments['-l']
        path_template = arguments['-t']
        remove_temp_files = int(arguments['-r'])
        verbose = int(arguments['-v'])
        if '-p' in arguments:
            paramreg_user = arguments['-p']
            # update registration parameters
            for paramStep in paramreg_user:
                paramreg.addStep(paramStep)

    # initialize other parameters
    file_template = param.file_template
    file_template_label = param.file_template_label
    file_template_seg = param.file_template_seg
    output_type = param.output_type
    zsubsample = param.zsubsample

    # start timer
    start_time = time.time()

    # get absolute path - TO DO: remove! NEVER USE ABSOLUTE PATH...
    path_template = os.path.abspath(path_template)

    # get fname of the template + template objects
    fname_template = sct.slash_at_the_end(path_template, 1)+file_template
    fname_template_label = sct.slash_at_the_end(path_template, 1)+file_template_label
    fname_template_seg = sct.slash_at_the_end(path_template, 1)+file_template_seg

    # check file existence
    sct.printv('\nCheck template files...')
    sct.check_file_exist(fname_template, verbose)
    sct.check_file_exist(fname_template_label, verbose)
    sct.check_file_exist(fname_template_seg, verbose)

    # print arguments
    sct.printv('\nCheck parameters:', verbose)
    sct.printv('.. Data:                 '+fname_data, verbose)
    sct.printv('.. Landmarks:            '+fname_landmarks, verbose)
    sct.printv('.. Segmentation:         '+fname_seg, verbose)
    sct.printv('.. Path template:        '+path_template, verbose)
    sct.printv('.. Output type:          '+str(output_type), verbose)
    sct.printv('.. Remove temp files:    '+str(remove_temp_files), verbose)

    sct.printv('\nParameters for registration:')
    for pStep in range(1, len(paramreg.steps)+1):
        # steps are stored under string keys '1', '2', ...
        step_param = paramreg.steps[str(pStep)]
        sct.printv('Step #'+step_param.step, verbose)
        sct.printv('.. Type #'+step_param.type, verbose)
        sct.printv('.. Algorithm................ '+step_param.algo, verbose)
        sct.printv('.. Metric................... '+step_param.metric, verbose)
        sct.printv('.. Number of iterations..... '+step_param.iter, verbose)
        sct.printv('.. Shrink factor............ '+step_param.shrink, verbose)
        sct.printv('.. Smoothing factor......... '+step_param.smooth, verbose)
        sct.printv('.. Gradient step............ '+step_param.gradStep, verbose)
        sct.printv('.. Degree of polynomial..... '+step_param.poly, verbose)

    # only the extension of the input data is needed (used to name the output images)
    _, _, ext_data = sct.extract_fname(fname_data)

    sct.printv('\nCheck input labels...')
    # check if label image contains coherent labels
    image_label = Image(fname_landmarks)
    labels = image_label.getNonZeroCoordinates(sorting='value')
    # -> at least one label is required (labels[-1] is read below)
    if len(labels) == 0:
        sct.printv('ERROR: Wrong landmarks input. No label found in the landmarks image.', verbose, 'error')
    # -> all labels must be different: no two distinct coordinates may share the same value
    hasDifferentLabels = not any(lab != otherlabel and lab.hasEqualValue(otherlabel)
                                 for lab in labels for otherlabel in labels)
    if not hasDifferentLabels:
        sct.printv('ERROR: Wrong landmarks input. All labels must be different.', verbose, 'error')
    # all labels must be available in template: compare the largest values (lists are sorted by value)
    image_label_template = Image(fname_template_label)
    labels_template = image_label_template.getNonZeroCoordinates(sorting='value')
    if labels[-1].value > labels_template[-1].value:
        sct.printv('ERROR: Wrong landmarks input. Labels must have correspondance in tempalte space. \nLabel max '
                   'provided: ' + str(labels[-1].value) + '\nLabel max from template: ' +
                   str(labels_template[-1].value), verbose, 'error')

    # create temporary folder
    sct.printv('\nCreate temporary folder...', verbose)
    path_tmp = 'tmp.'+time.strftime("%y%m%d%H%M%S")
    sct.run('mkdir '+path_tmp)

    # copy files to temporary folder
    sct.printv('\nCopy files...', verbose)
    sct.run('isct_c3d '+fname_data+' -o '+path_tmp+'/data.nii')
    sct.run('isct_c3d '+fname_landmarks+' -o '+path_tmp+'/landmarks.nii.gz')
    sct.run('isct_c3d '+fname_seg+' -o '+path_tmp+'/segmentation.nii.gz')
    sct.run('isct_c3d '+fname_template+' -o '+path_tmp+'/template.nii')
    sct.run('isct_c3d '+fname_template_label+' -o '+path_tmp+'/template_labels.nii.gz')
    sct.run('isct_c3d '+fname_template_seg+' -o '+path_tmp+'/template_seg.nii.gz')

    # go to tmp folder
    os.chdir(path_tmp)

    # resample data to 1mm isotropic
    sct.printv('\nResample data to 1mm isotropic...', verbose)
    sct.run('isct_c3d data.nii -resample-mm 1.0x1.0x1.0mm -interpolation Linear -o datar.nii')
    sct.run('isct_c3d segmentation.nii.gz -resample-mm 1.0x1.0x1.0mm -interpolation NearestNeighbor -o segmentationr.nii.gz')
    # N.B. resampling of labels is more complicated: they are single-point labels, so resampling with
    # nearest neighbour can make them disappear. A dedicated function is therefore used.
    resample_labels('landmarks.nii.gz', 'datar.nii', 'landmarksr.nii.gz')

    # Change orientation of input images to RPI
    sct.printv('\nChange orientation of input images to RPI...', verbose)
    set_orientation('datar.nii', 'RPI', 'data_rpi.nii')
    set_orientation('landmarksr.nii.gz', 'RPI', 'landmarks_rpi.nii.gz')
    set_orientation('segmentationr.nii.gz', 'RPI', 'segmentation_rpi.nii.gz')

    # get landmarks in native space
    # crop segmentation
    # output: segmentation_rpi_crop.nii.gz
    sct.run('sct_crop_image -i segmentation_rpi.nii.gz -o segmentation_rpi_crop.nii.gz -dim 2 -bzmax')

    # straighten segmentation
    sct.printv('\nStraighten the spinal cord using centerline/segmentation...', verbose)
    sct.run('sct_straighten_spinalcord -i segmentation_rpi_crop.nii.gz -c segmentation_rpi_crop.nii.gz -r 0 -v '+str(verbose), verbose)
    # re-define warping field using non-cropped space (to avoid issue #367)
    sct.run('sct_concat_transfo -w warp_straight2curve.nii.gz -d data_rpi.nii -o warp_straight2curve.nii.gz')

    # Label preparation:
    # --------------------------------------------------------------------------------
    # Remove unused label on template. Keep only label present in the input label image
    sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
    sct.run('sct_label_utils -t remove -i template_labels.nii.gz -o template_label.nii.gz -r landmarks_rpi.nii.gz')

    # Make sure landmarks are INT
    sct.printv('\nConvert landmarks to INT...', verbose)
    sct.run('isct_c3d template_label.nii.gz -type int -o template_label.nii.gz', verbose)

    # Create a cross for the template labels - 5 mm
    sct.printv('\nCreate a 5 mm cross for the template labels...', verbose)
    sct.run('sct_label_utils -t cross -i template_label.nii.gz -o template_label_cross.nii.gz -c 5')

    # Create a cross for the input labels and dilate for straightening preparation - 5 mm
    sct.printv('\nCreate a 5mm cross for the input labels and dilate for straightening preparation...', verbose)
    sct.run('sct_label_utils -t cross -i landmarks_rpi.nii.gz -o landmarks_rpi_cross3x3.nii.gz -c 5 -d')

    # Apply straightening to labels
    sct.printv('\nApply straightening to labels...', verbose)
    sct.run('sct_apply_transfo -i landmarks_rpi_cross3x3.nii.gz -o landmarks_rpi_cross3x3_straight.nii.gz -d segmentation_rpi_crop_straight.nii.gz -w warp_curve2straight.nii.gz -x nn')

    # Convert landmarks from FLOAT32 to INT
    sct.printv('\nConvert landmarks from FLOAT32 to INT...', verbose)
    sct.run('isct_c3d landmarks_rpi_cross3x3_straight.nii.gz -type int -o landmarks_rpi_cross3x3_straight.nii.gz')

    # Remove labels that do not correspond with each others.
    sct.printv('\nRemove labels that do not correspond with each others.', verbose)
    sct.run('sct_label_utils -t remove-symm -i landmarks_rpi_cross3x3_straight.nii.gz -o landmarks_rpi_cross3x3_straight.nii.gz,template_label_cross.nii.gz -r template_label_cross.nii.gz')

    # Estimate affine transfo: straight --> template (landmark-based)
    sct.printv('\nEstimate affine transfo: straight anat --> template (landmark-based)...', verbose)
    # read the straightened and the template landmarks (converted to physical coordinates below)
    image_straight = Image('landmarks_rpi_cross3x3_straight.nii.gz')
    landmark_straight = image_straight.getNonZeroCoordinates(sorting='value')
    image_template = Image('template_label_cross.nii.gz')
    landmark_template = image_template.getNonZeroCoordinates(sorting='value')
    # Reorganize landmarks: keep one averaged coordinate per label value of the straightened image
    points_fixed, points_moving = [], []
    landmark_straight_mean = []
    for coord in landmark_straight:
        if coord.value not in [c.value for c in landmark_straight_mean]:
            temp_landmark = coord
            temp_number = 1
            for other_coord in landmark_straight:
                if coord.hasEqualValue(other_coord) and coord != other_coord:
                    temp_landmark += other_coord
                    temp_number += 1
            landmark_straight_mean.append(temp_landmark / temp_number)

    # moving points: averaged landmarks of the straightened image, in physical coordinates
    for coord in landmark_straight_mean:
        point_straight = image_straight.transfo_pix2phys([[coord.x, coord.y, coord.z]])
        points_moving.append([point_straight[0][0], point_straight[0][1], point_straight[0][2]])
    # fixed points: template landmarks, in physical coordinates
    for coord in landmark_template:
        point_template = image_template.transfo_pix2phys([[coord.x, coord.y, coord.z]])
        points_fixed.append([point_template[0][0], point_template[0][1], point_template[0][2]])

    # Register straightened landmarks (moving) on template landmarks (fixed) using the python implementation
    sct.printv('\nComputing rigid transformation (algo=translation-scaling-z) ...', verbose)
    import msct_register_landmarks
    (rotation_matrix, translation_array, points_moving_reg, points_moving_barycenter) = \
        msct_register_landmarks.getRigidTransformFromLandmarks(
            points_fixed, points_moving, constraints='translation-scaling-z', show=False)

    # writing rigid transformation file
    # (context manager: the file is closed even if formatting or writing fails)
    with open("straight2templateAffine.txt", "w") as text_file:
        text_file.write("#Insight Transform File V1.0\n")
        text_file.write("#Transform 0\n")
        text_file.write("Transform: FixedCenterOfRotationAffineTransform_double_3_3\n")
        # diagonal terms are written inverted and the z translation negated, as in the original code
        text_file.write("Parameters: %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f\n" % (
            1.0/rotation_matrix[0, 0], rotation_matrix[0, 1],     rotation_matrix[0, 2],
            rotation_matrix[1, 0],     1.0/rotation_matrix[1, 1], rotation_matrix[1, 2],
            rotation_matrix[2, 0],     rotation_matrix[2, 1],     1.0/rotation_matrix[2, 2],
            translation_array[0, 0],   translation_array[0, 1],   -translation_array[0, 2]))
        text_file.write("FixedParameters: %.9f %.9f %.9f\n" % (points_moving_barycenter[0],
                                                               points_moving_barycenter[1],
                                                               points_moving_barycenter[2]))

    # Apply affine transformation: straight --> template
    sct.printv('\nApply affine transformation: straight --> template...', verbose)
    sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt -d template.nii -o warp_curve2straightAffine.nii.gz')
    sct.run('sct_apply_transfo -i data_rpi.nii -o data_rpi_straight2templateAffine.nii -d template.nii -w warp_curve2straightAffine.nii.gz')
    sct.run('sct_apply_transfo -i segmentation_rpi.nii.gz -o segmentation_rpi_straight2templateAffine.nii.gz -d template.nii -w warp_curve2straightAffine.nii.gz -x linear')

    # threshold to 0.5: values below 0.5 are set to 0, the result is saved under a new file name
    nii = Image('segmentation_rpi_straight2templateAffine.nii.gz')
    data = nii.data
    data[data < 0.5] = 0
    nii.data = data
    nii.setFileName('segmentation_rpi_straight2templateAffine_th.nii.gz')
    nii.save()
    # find min-max of anat2template (for subsequent cropping)
    zmin_template, zmax_template = find_zmin_zmax('segmentation_rpi_straight2templateAffine_th.nii.gz')

    # crop template in z-direction (for faster processing)
    sct.printv('\nCrop data in template space (for faster processing)...', verbose)
    sct.run('sct_crop_image -i template.nii -o template_crop.nii -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    sct.run('sct_crop_image -i template_seg.nii.gz -o template_seg_crop.nii.gz -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    sct.run('sct_crop_image -i data_rpi_straight2templateAffine.nii -o data_rpi_straight2templateAffine_crop.nii -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    sct.run('sct_crop_image -i segmentation_rpi_straight2templateAffine.nii.gz -o segmentation_rpi_straight2templateAffine_crop.nii.gz -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    # sub-sample in z-direction
    sct.printv('\nSub-sample in z-direction (for faster processing)...', verbose)
    sct.run('sct_resample -i template_crop.nii -o template_crop_r.nii -f 1x1x'+zsubsample, verbose)
    sct.run('sct_resample -i template_seg_crop.nii.gz -o template_seg_crop_r.nii.gz -f 1x1x'+zsubsample, verbose)
    sct.run('sct_resample -i data_rpi_straight2templateAffine_crop.nii -o data_rpi_straight2templateAffine_crop_r.nii -f 1x1x'+zsubsample, verbose)
    sct.run('sct_resample -i segmentation_rpi_straight2templateAffine_crop.nii.gz -o segmentation_rpi_straight2templateAffine_crop_r.nii.gz -f 1x1x'+zsubsample, verbose)

    # Registration straight spinal cord to template
    sct.printv('\nRegister straight spinal cord to template...', verbose)

    # loop across registration steps
    warp_forward = []
    warp_inverse = []
    for i_step in range(1, len(paramreg.steps)+1):
        sct.printv('\nEstimate transformation for step #'+str(i_step)+'...', verbose)
        # identify which is the src and dest
        step_type = paramreg.steps[str(i_step)].type
        if step_type == 'im':
            src = 'data_rpi_straight2templateAffine_crop_r.nii'
            dest = 'template_crop_r.nii'
            interp_step = 'linear'
        elif step_type == 'seg':
            src = 'segmentation_rpi_straight2templateAffine_crop_r.nii.gz'
            dest = 'template_seg_crop_r.nii.gz'
            interp_step = 'nn'
        else:
            sct.printv('ERROR: Wrong image type.', 1, 'error')
        # if step>1, apply the warping fields of the previous steps to the src image to be used
        if i_step > 1:
            sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+sct.add_suffix(src, '_reg')+' -x '+interp_step, verbose)
            src = sct.add_suffix(src, '_reg')
        # register src --> dest
        warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
        warp_forward.append(warp_forward_out)
        warp_inverse.append(warp_inverse_out)

    # Concatenate transformations:
    sct.printv('\nConcatenate transformations: anat --> template...', verbose)
    sct.run('sct_concat_transfo -w warp_curve2straightAffine.nii.gz,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
    # inverse warping fields are passed in reverse order of the registration steps
    warp_inverse.reverse()
    sct.run('sct_concat_transfo -w '+','.join(warp_inverse)+',-straight2templateAffine.txt,warp_straight2curve.nii.gz -d data.nii -o warp_template2anat.nii.gz', verbose)

    # Apply warping fields to anat and template
    if output_type == 1:
        sct.run('sct_apply_transfo -i template.nii -o template2anat.nii.gz -d data.nii -w warp_template2anat.nii.gz -c 1', verbose)
        sct.run('sct_apply_transfo -i data.nii -o anat2template.nii.gz -d template.nii -w warp_anat2template.nii.gz -c 1', verbose)

    # come back to parent folder
    os.chdir('..')

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(path_tmp+'/warp_template2anat.nii.gz', 'warp_template2anat.nii.gz', verbose)
    sct.generate_output_file(path_tmp+'/warp_anat2template.nii.gz', 'warp_anat2template.nii.gz', verbose)
    if output_type == 1:
        sct.generate_output_file(path_tmp+'/template2anat.nii.gz', 'template2anat'+ext_data, verbose)
        sct.generate_output_file(path_tmp+'/anat2template.nii.gz', 'anat2template'+ext_data, verbose)

    # Delete temporary files
    if remove_temp_files:
        sct.printv('\nDelete temporary files...', verbose)
        sct.run('rm -rf '+path_tmp)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: '+str(int(round(elapsed_time)))+'s', verbose)

    # to view results
    sct.printv('\nTo view results, type:', verbose)
    sct.printv('fslview '+fname_data+' template2anat -b 0,4000 &', verbose, 'info')
    sct.printv('fslview '+fname_template+' -b 0,5000 anat2template &\n', verbose, 'info')
    def straighten(self):
        # Initialization
        fname_anat = self.input_filename
        fname_centerline = self.centerline_filename
        fname_output = self.output_filename
        gapxy = self.gapxy
        gapz = self.gapz
        padding = self.padding
        remove_temp_files = self.remove_temp_files
        verbose = self.verbose
        interpolation_warp = self.interpolation_warp
        algo_fitting = self.algo_fitting
        window_length = self.window_length
        type_window = self.type_window
        crop = self.crop

        # start timer
        start_time = time.time()

        # get path of the toolbox
        status, path_sct = commands.getstatusoutput("echo $SCT_DIR")
        sct.printv(path_sct, verbose)

        if self.debug == 1:
            print "\n*** WARNING: DEBUG MODE ON ***\n"
            fname_anat = (
                "/Users/julien/data/temp/sct_example_data/t2/tmp.150401221259/anat_rpi.nii"
            )  # path_sct+'/testing/sct_testing_data/data/t2/t2.nii.gz'
            fname_centerline = (
                "/Users/julien/data/temp/sct_example_data/t2/tmp.150401221259/centerline_rpi.nii"
            )  # path_sct+'/testing/sct_testing_data/data/t2/t2_seg.nii.gz'
            remove_temp_files = 0
            type_window = "hanning"
            verbose = 2

        # check existence of input files
        sct.check_file_exist(fname_anat, verbose)
        sct.check_file_exist(fname_centerline, verbose)

        # Display arguments
        sct.printv("\nCheck input arguments...", verbose)
        sct.printv("  Input volume ...................... " + fname_anat, verbose)
        sct.printv("  Centerline ........................ " + fname_centerline, verbose)
        sct.printv("  Final interpolation ............... " + interpolation_warp, verbose)
        sct.printv("  Verbose ........................... " + str(verbose), verbose)
        sct.printv("", verbose)

        # Extract path/file/extension
        path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)
        path_centerline, file_centerline, ext_centerline = sct.extract_fname(fname_centerline)

        # create temporary folder
        path_tmp = "tmp." + time.strftime("%y%m%d%H%M%S")
        sct.run("mkdir " + path_tmp, verbose)

        # copy files into tmp folder
        sct.run("cp " + fname_anat + " " + path_tmp, verbose)
        sct.run("cp " + fname_centerline + " " + path_tmp, verbose)

        # go to tmp folder
        os.chdir(path_tmp)

        try:
            # Change orientation of the input centerline into RPI
            sct.printv("\nOrient centerline to RPI orientation...", verbose)
            fname_centerline_orient = file_centerline + "_rpi.nii.gz"
            set_orientation(file_centerline + ext_centerline, "RPI", fname_centerline_orient)

            # Get dimension
            sct.printv("\nGet dimensions...", verbose)
            nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(fname_centerline_orient)
            sct.printv(".. matrix size: " + str(nx) + " x " + str(ny) + " x " + str(nz), verbose)
            sct.printv(".. voxel size:  " + str(px) + "mm x " + str(py) + "mm x " + str(pz) + "mm", verbose)

            # smooth centerline
            x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
                fname_centerline_orient,
                algo_fitting=algo_fitting,
                type_window=type_window,
                window_length=window_length,
                verbose=verbose,
            )

            # Get coordinates of landmarks along curved centerline
            # ==========================================================================================
            sct.printv("\nGet coordinates of landmarks along curved centerline...", verbose)
            # landmarks are created along the curved centerline every z=gapz. They consist of a "cross" of size gapx and gapy. In voxel space!!!

            # find z indices along centerline given a specific gap: iz_curved
            nz_nonz = len(z_centerline)
            nb_landmark = int(round(float(nz_nonz) / gapz))

            if nb_landmark == 0:
                nb_landmark = 1

            if nb_landmark == 1:
                iz_curved = [0]
            else:
                iz_curved = [i * gapz for i in range(0, nb_landmark - 1)]

            iz_curved.append(nz_nonz - 1)
            # print iz_curved, len(iz_curved)
            n_iz_curved = len(iz_curved)
            # print n_iz_curved

            # landmark_curved initialisation
            # landmark_curved = [ [ [ 0 for i in range(0, 3)] for i in range(0, 5) ] for i in iz_curved ]

            from msct_types import Coordinate

            landmark_curved = []
            landmark_curved_value = 1

            ### TODO: THIS PART IS SLOW AND CAN BE MADE FASTER
            ### >>==============================================================================================================
            for iz in range(min(iz_curved), max(iz_curved) + 1, 1):
                if iz in iz_curved:
                    index = iz_curved.index(iz)
                    # calculate d (ax+by+cz+d=0)
                    # print iz_curved[index]
                    a = x_centerline_deriv[iz]
                    b = y_centerline_deriv[iz]
                    c = z_centerline_deriv[iz]
                    x = x_centerline_fit[iz]
                    y = y_centerline_fit[iz]
                    z = z_centerline[iz]
                    d = -(a * x + b * y + c * z)
                    # print a,b,c,d,x,y,z
                    # set coordinates for landmark at the center of the cross
                    coord = Coordinate([0, 0, 0, landmark_curved_value])
                    coord.x, coord.y, coord.z = x_centerline_fit[iz], y_centerline_fit[iz], z_centerline[iz]
                    landmark_curved.append(coord)

                    # set y coordinate to y_centerline_fit[iz] for elements 1 and 2 of the cross
                    cross_coordinates = [
                        Coordinate([0, 0, 0, landmark_curved_value + 1]),
                        Coordinate([0, 0, 0, landmark_curved_value + 2]),
                        Coordinate([0, 0, 0, landmark_curved_value + 3]),
                        Coordinate([0, 0, 0, landmark_curved_value + 4]),
                    ]

                    cross_coordinates[0].y = y_centerline_fit[iz]
                    cross_coordinates[1].y = y_centerline_fit[iz]

                    # set x and z coordinates for landmarks +x and -x, forcing de landmark to be in the orthogonal plan and the distance landmark/curve to be gapxy
                    x_n = Symbol("x_n")
                    cross_coordinates[1].x, cross_coordinates[0].x = solve(
                        (x_n - x) ** 2 + ((-1 / c) * (a * x_n + b * y + d) - z) ** 2 - gapxy ** 2, x_n
                    )  # x for -x and +x
                    cross_coordinates[0].z = (-1 / c) * (a * cross_coordinates[0].x + b * y + d)  # z for +x
                    cross_coordinates[1].z = (-1 / c) * (a * cross_coordinates[1].x + b * y + d)  # z for -x

                    # set x coordinate to x_centerline_fit[iz] for elements 3 and 4 of the cross
                    cross_coordinates[2].x = x_centerline_fit[iz]
                    cross_coordinates[3].x = x_centerline_fit[iz]

                    # set coordinates for landmarks +y and -y. Here, x coordinate is 0 (already initialized).
                    y_n = Symbol("y_n")
                    cross_coordinates[3].y, cross_coordinates[2].y = solve(
                        (y_n - y) ** 2 + ((-1 / c) * (a * x + b * y_n + d) - z) ** 2 - gapxy ** 2, y_n
                    )  # y for -y and +y
                    cross_coordinates[2].z = (-1 / c) * (a * x + b * cross_coordinates[2].y + d)  # z for +y
                    cross_coordinates[3].z = (-1 / c) * (a * x + b * cross_coordinates[3].y + d)  # z for -y

                    for coord in cross_coordinates:
                        landmark_curved.append(coord)
                    landmark_curved_value += 5
                else:
                    if self.all_labels == 1:
                        landmark_curved.append(
                            Coordinate(
                                [x_centerline_fit[iz], y_centerline_fit[iz], z_centerline[iz], landmark_curved_value],
                                mode="continuous",
                            )
                        )
                        landmark_curved_value += 1
            ### <<==============================================================================================================

            # Get coordinates of landmarks along straight centerline
            # ==========================================================================================
            sct.printv("\nGet coordinates of landmarks along straight centerline...", verbose)
            # landmark_straight = [ [ [ 0 for i in range(0,3)] for i in range (0,5) ] for i in iz_curved ] # same structure as landmark_curved

            landmark_straight = []

            # calculate the z indices corresponding to the Euclidean distance between two consecutive points on the curved centerline (approximation curve --> line)
            # TODO: DO NOT APPROXIMATE CURVE --> LINE
            if nb_landmark == 1:
                iz_straight = [0 for i in range(0, nb_landmark + 1)]
            else:
                iz_straight = [0 for i in range(0, nb_landmark)]

            # print iz_straight,len(iz_straight)
            iz_straight[0] = iz_curved[0]
            for index in range(1, n_iz_curved, 1):
                # compute vector between two consecutive points on the curved centerline
                vector_centerline = [
                    x_centerline_fit[iz_curved[index]] - x_centerline_fit[iz_curved[index - 1]],
                    y_centerline_fit[iz_curved[index]] - y_centerline_fit[iz_curved[index - 1]],
                    z_centerline[iz_curved[index]] - z_centerline[iz_curved[index - 1]],
                ]
                # compute norm of this vector
                norm_vector_centerline = linalg.norm(vector_centerline, ord=2)
                # round to closest integer value
                norm_vector_centerline_rounded = int(round(norm_vector_centerline, 0))
                # assign this value to the current z-coordinate on the straight centerline
                iz_straight[index] = iz_straight[index - 1] + norm_vector_centerline_rounded

            # initialize x0 and y0 to be at the center of the FOV
            x0 = int(round(nx / 2))
            y0 = int(round(ny / 2))
            landmark_curved_value = 1
            for iz in range(min(iz_curved), max(iz_curved) + 1, 1):
                if iz in iz_curved:
                    index = iz_curved.index(iz)
                    # set coordinates for landmark at the center of the cross
                    landmark_straight.append(Coordinate([x0, y0, iz_straight[index], landmark_curved_value]))
                    # set x, y and z coordinates for landmarks +x
                    landmark_straight.append(
                        Coordinate([x0 + gapxy, y0, iz_straight[index], landmark_curved_value + 1])
                    )
                    # set x, y and z coordinates for landmarks -x
                    landmark_straight.append(
                        Coordinate([x0 - gapxy, y0, iz_straight[index], landmark_curved_value + 2])
                    )
                    # set x, y and z coordinates for landmarks +y
                    landmark_straight.append(
                        Coordinate([x0, y0 + gapxy, iz_straight[index], landmark_curved_value + 3])
                    )
                    # set x, y and z coordinates for landmarks -y
                    landmark_straight.append(
                        Coordinate([x0, y0 - gapxy, iz_straight[index], landmark_curved_value + 4])
                    )
                    landmark_curved_value += 5
                else:
                    if self.all_labels == 1:
                        landmark_straight.append(Coordinate([x0, y0, iz, landmark_curved_value]))
                        landmark_curved_value += 1

            # Create NIFTI volumes with landmarks
            # ==========================================================================================
            # Pad input volume to deal with the fact that some landmarks on the curved centerline might be outside the FOV
            # N.B. IT IS VERY IMPORTANT TO PAD ALSO ALONG X and Y, OTHERWISE SOME LANDMARKS MIGHT GET OUT OF THE FOV!!!
            # sct.run('fslview ' + fname_centerline_orient)
            sct.printv("\nPad input volume to account for landmarks that fall outside the FOV...", verbose)
            sct.run(
                "isct_c3d "
                + fname_centerline_orient
                + " -pad "
                + str(padding)
                + "x"
                + str(padding)
                + "x"
                + str(padding)
                + "vox "
                + str(padding)
                + "x"
                + str(padding)
                + "x"
                + str(padding)
                + "vox 0 -o tmp.centerline_pad.nii.gz",
                verbose,
            )

            # Open padded centerline for reading
            sct.printv("\nOpen padded centerline for reading...", verbose)
            file = load("tmp.centerline_pad.nii.gz")
            data = file.get_data()
            hdr = file.get_header()

            if self.algo_landmark_rigid is not None and self.algo_landmark_rigid != "None":
                # Reorganize landmarks
                points_fixed, points_moving = [], []
                for coord in landmark_straight:
                    points_fixed.append([coord.x, coord.y, coord.z])
                for coord in landmark_curved:
                    points_moving.append([coord.x, coord.y, coord.z])

                # Register curved landmarks on straight landmarks based on python implementation
                sct.printv("\nComputing rigid transformation (algo=" + self.algo_landmark_rigid + ") ...", verbose)
                import msct_register_landmarks

                (
                    rotation_matrix,
                    translation_array,
                    points_moving_reg,
                ) = msct_register_landmarks.getRigidTransformFromLandmarks(
                    points_fixed, points_moving, constraints=self.algo_landmark_rigid, show=False
                )

                # reorganize registered points
                landmark_curved_rigid = []
                for index_curved, ind in enumerate(range(0, len(points_moving_reg), 1)):
                    coord = Coordinate()
                    coord.x, coord.y, coord.z, coord.value = (
                        points_moving_reg[ind][0],
                        points_moving_reg[ind][1],
                        points_moving_reg[ind][2],
                        index_curved + 1,
                    )
                    landmark_curved_rigid.append(coord)

                # Create volumes containing curved and straight landmarks
                data_curved_landmarks = data * 0
                data_curved_rigid_landmarks = data * 0
                data_straight_landmarks = data * 0

                # Loop across cross index
                for index in range(0, len(landmark_curved_rigid)):
                    x, y, z = (
                        int(round(landmark_curved[index].x)),
                        int(round(landmark_curved[index].y)),
                        int(round(landmark_curved[index].z)),
                    )

                    # attribute landmark_value to the voxel and its neighbours
                    data_curved_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_curved[index].value

                    # get x, y and z coordinates of curved landmark (rounded to closest integer)
                    x, y, z = (
                        int(round(landmark_curved_rigid[index].x)),
                        int(round(landmark_curved_rigid[index].y)),
                        int(round(landmark_curved_rigid[index].z)),
                    )

                    # attribute landmark_value to the voxel and its neighbours
                    data_curved_rigid_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_curved_rigid[index].value

                    # get x, y and z coordinates of straight landmark (rounded to closest integer)
                    x, y, z = (
                        int(round(landmark_straight[index].x)),
                        int(round(landmark_straight[index].y)),
                        int(round(landmark_straight[index].z)),
                    )

                    # attribute landmark_value to the voxel and its neighbours
                    data_straight_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_straight[index].value

                # Write NIFTI volumes
                sct.printv("\nWrite NIFTI volumes...", verbose)
                hdr.set_data_dtype("uint32")  # set imagetype to uint8 #TODO: maybe use int32
                img = Nifti1Image(data_curved_landmarks, None, hdr)
                save(img, "tmp.landmarks_curved.nii.gz")
                sct.printv(".. File created: tmp.landmarks_curved.nii.gz", verbose)
                hdr.set_data_dtype("uint32")  # set imagetype to uint8 #TODO: maybe use int32
                img = Nifti1Image(data_curved_rigid_landmarks, None, hdr)
                save(img, "tmp.landmarks_curved_rigid.nii.gz")
                sct.printv(".. File created: tmp.landmarks_curved_rigid.nii.gz", verbose)
                img = Nifti1Image(data_straight_landmarks, None, hdr)
                save(img, "tmp.landmarks_straight.nii.gz")
                sct.printv(".. File created: tmp.landmarks_straight.nii.gz", verbose)

                # writing rigid transformation file
                text_file = open("tmp.curve2straight_rigid.txt", "w")
                text_file.write("#Insight Transform File V1.0\n")
                text_file.write("#Transform 0\n")
                text_file.write("Transform: AffineTransform_double_3_3\n")
                text_file.write(
                    "Parameters: %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f\n"
                    % (
                        rotation_matrix[0, 0],
                        rotation_matrix[0, 1],
                        rotation_matrix[0, 2],
                        rotation_matrix[1, 0],
                        rotation_matrix[1, 1],
                        rotation_matrix[1, 2],
                        rotation_matrix[2, 0],
                        rotation_matrix[2, 1],
                        rotation_matrix[2, 2],
                        -translation_array[0, 0],
                        translation_array[0, 1],
                        -translation_array[0, 2],
                    )
                )
                text_file.write("FixedParameters: 0 0 0\n")
                text_file.close()

            else:
                # Create volumes containing curved and straight landmarks
                data_curved_landmarks = data * 0
                data_straight_landmarks = data * 0

                # Loop across cross index
                for index in range(0, len(landmark_curved)):
                    x, y, z = (
                        int(round(landmark_curved[index].x)),
                        int(round(landmark_curved[index].y)),
                        int(round(landmark_curved[index].z)),
                    )

                    # attribute landmark_value to the voxel and its neighbours
                    data_curved_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_curved[index].value

                    # get x, y and z coordinates of straight landmark (rounded to closest integer)
                    x, y, z = (
                        int(round(landmark_straight[index].x)),
                        int(round(landmark_straight[index].y)),
                        int(round(landmark_straight[index].z)),
                    )

                    # attribute landmark_value to the voxel and its neighbours
                    data_straight_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_straight[index].value

                # Write NIFTI volumes
                sct.printv("\nWrite NIFTI volumes...", verbose)
                hdr.set_data_dtype("uint32")  # set imagetype to uint8 #TODO: maybe use int32
                img = Nifti1Image(data_curved_landmarks, None, hdr)
                save(img, "tmp.landmarks_curved.nii.gz")
                sct.printv(".. File created: tmp.landmarks_curved.nii.gz", verbose)
                img = Nifti1Image(data_straight_landmarks, None, hdr)
                save(img, "tmp.landmarks_straight.nii.gz")
                sct.printv(".. File created: tmp.landmarks_straight.nii.gz", verbose)

                # Estimate deformation field by pairing landmarks
                # ==========================================================================================
                # convert landmarks to INT
                sct.printv("\nConvert landmarks to INT...", verbose)
                sct.run("isct_c3d tmp.landmarks_straight.nii.gz -type int -o tmp.landmarks_straight.nii.gz", verbose)
                sct.run("isct_c3d tmp.landmarks_curved.nii.gz -type int -o tmp.landmarks_curved.nii.gz", verbose)

                # This step serves to avoid overlap between landmarks
                sct.printv("\nMake sure all labels between landmark_curved and landmark_curved match...", verbose)
                label_process_straight = ProcessLabels(
                    fname_label="tmp.landmarks_straight.nii.gz",
                    fname_output="tmp.landmarks_straight.nii.gz",
                    fname_ref="tmp.landmarks_curved.nii.gz",
                    verbose=verbose,
                )
                label_process_straight.process("remove")
                label_process_curved = ProcessLabels(
                    fname_label="tmp.landmarks_curved.nii.gz",
                    fname_output="tmp.landmarks_curved.nii.gz",
                    fname_ref="tmp.landmarks_straight.nii.gz",
                    verbose=verbose,
                )
                label_process_curved.process("remove")

                # Estimate rigid transformation
                sct.printv("\nEstimate rigid transformation between paired landmarks...", verbose)
                sct.run(
                    "isct_ANTSUseLandmarkImagesToGetAffineTransform tmp.landmarks_straight.nii.gz tmp.landmarks_curved.nii.gz rigid tmp.curve2straight_rigid.txt",
                    verbose,
                )

                # Apply rigid transformation
                sct.printv("\nApply rigid transformation to curved landmarks...", verbose)
                # sct.run('sct_apply_transfo -i tmp.landmarks_curved.nii.gz -o tmp.landmarks_curved_rigid.nii.gz -d tmp.landmarks_straight.nii.gz -w tmp.curve2straight_rigid.txt -x nn', verbose)
                Transform(
                    input_filename="tmp.landmarks_curved.nii.gz",
                    source_reg="tmp.landmarks_curved_rigid.nii.gz",
                    output_filename="tmp.landmarks_straight.nii.gz",
                    warp="tmp.curve2straight_rigid.txt",
                    interp="nn",
                    verbose=verbose,
                ).apply()

            if verbose == 2:
                from mpl_toolkits.mplot3d import Axes3D
                import matplotlib.pyplot as plt

                fig = plt.figure()
                ax = Axes3D(fig)
                ax.plot(x_centerline_fit, y_centerline_fit, z_centerline, zdir="z")
                ax.plot(
                    [coord.x for coord in landmark_curved],
                    [coord.y for coord in landmark_curved],
                    [coord.z for coord in landmark_curved],
                    ".",
                )
                ax.plot(
                    [coord.x for coord in landmark_straight],
                    [coord.y for coord in landmark_straight],
                    [coord.z for coord in landmark_straight],
                    "r.",
                )
                if self.algo_landmark_rigid is not None and self.algo_landmark_rigid != "None":
                    ax.plot(
                        [coord.x for coord in landmark_curved_rigid],
                        [coord.y for coord in landmark_curved_rigid],
                        [coord.z for coord in landmark_curved_rigid],
                        "b.",
                    )
                ax.set_xlabel("x")
                ax.set_ylabel("y")
                ax.set_zlabel("z")
                plt.show()

            # This step serves to avoid overlap between landmarks
            sct.printv("\nMake sure all labels between landmark_curved and landmark_curved match...", verbose)
            label_process = ProcessLabels(
                fname_label="tmp.landmarks_straight.nii.gz",
                fname_output="tmp.landmarks_straight.nii.gz",
                fname_ref="tmp.landmarks_curved_rigid.nii.gz",
                verbose=verbose,
            )
            label_process.process("remove")
            label_process = ProcessLabels(
                fname_label="tmp.landmarks_curved_rigid.nii.gz",
                fname_output="tmp.landmarks_curved_rigid.nii.gz",
                fname_ref="tmp.landmarks_straight.nii.gz",
                verbose=verbose,
            )
            label_process.process("remove")

            # Estimate b-spline transformation curve --> straight
            sct.printv("\nEstimate b-spline transformation: curve --> straight...", verbose)
            sct.run(
                "isct_ANTSUseLandmarkImagesToGetBSplineDisplacementField tmp.landmarks_straight.nii.gz tmp.landmarks_curved_rigid.nii.gz tmp.warp_curve2straight.nii.gz "
                + self.bspline_meshsize
                + " "
                + self.bspline_numberOfLevels
                + " "
                + self.bspline_order
                + " 0",
                verbose,
            )

            # remove padding for straight labels
            if crop == 1:
                ImageCropper(
                    input_file="tmp.landmarks_straight.nii.gz",
                    output_file="tmp.landmarks_straight_crop.nii.gz",
                    dim="0,1,2",
                    bmax=True,
                    verbose=verbose,
                ).crop()
                pass
            else:
                sct.run("cp tmp.landmarks_straight.nii.gz tmp.landmarks_straight_crop.nii.gz", verbose)

            # Concatenate rigid and non-linear transformations...
            sct.printv("\nConcatenate rigid and non-linear transformations...", verbose)
            # sct.run('isct_ComposeMultiTransform 3 tmp.warp_rigid.nii -R tmp.landmarks_straight.nii tmp.warp.nii tmp.curve2straight_rigid.txt')
            # !!! DO NOT USE sct.run HERE BECAUSE isct_ComposeMultiTransform OUTPUTS A NON-NULL STATUS !!!
            cmd = "isct_ComposeMultiTransform 3 tmp.curve2straight.nii.gz -R tmp.landmarks_straight_crop.nii.gz tmp.warp_curve2straight.nii.gz tmp.curve2straight_rigid.txt"
            sct.printv(cmd, verbose, "code")
            sct.run(cmd, self.verbose)
            # commands.getstatusoutput(cmd)

            # Estimate b-spline transformation straight --> curve
            # TODO: invert warping field instead of estimating a new one
            sct.printv("\nEstimate b-spline transformation: straight --> curve...", verbose)
            sct.run(
                "isct_ANTSUseLandmarkImagesToGetBSplineDisplacementField tmp.landmarks_curved_rigid.nii.gz tmp.landmarks_straight.nii.gz tmp.warp_straight2curve.nii.gz "
                + self.bspline_meshsize
                + " "
                + self.bspline_numberOfLevels
                + " "
                + self.bspline_order
                + " 0",
                verbose,
            )

            # Concatenate rigid and non-linear transformations...
            sct.printv("\nConcatenate rigid and non-linear transformations...", verbose)
            cmd = (
                "isct_ComposeMultiTransform 3 tmp.straight2curve.nii.gz -R "
                + file_anat
                + ext_anat
                + " -i tmp.curve2straight_rigid.txt tmp.warp_straight2curve.nii.gz"
            )
            sct.printv(cmd, verbose, "code")
            # commands.getstatusoutput(cmd)
            sct.run(cmd, self.verbose)

            # Apply transformation to input image
            sct.printv("\nApply transformation to input image...", verbose)
            Transform(
                input_filename=str(file_anat + ext_anat),
                source_reg="tmp.anat_rigid_warp.nii.gz",
                output_filename="tmp.landmarks_straight_crop.nii.gz",
                interp=interpolation_warp,
                warp="tmp.curve2straight.nii.gz",
                verbose=verbose,
            ).apply()

            # compute the error between the straightened centerline/segmentation and the central vertical line.
            # Ideally, the error should be zero.
            # Apply deformation to input image
            sct.printv("\nApply transformation to centerline image...", verbose)
            # sct.run('sct_apply_transfo -i '+fname_centerline_orient+' -o tmp.centerline_straight.nii.gz -d tmp.landmarks_straight_crop.nii.gz -x nn -w tmp.curve2straight.nii.gz')
            Transform(
                input_filename=fname_centerline_orient,
                source_reg="tmp.centerline_straight.nii.gz",
                output_filename="tmp.landmarks_straight_crop.nii.gz",
                interp="nn",
                warp="tmp.curve2straight.nii.gz",
                verbose=verbose,
            ).apply()
            # c = sct.run('sct_crop_image -i tmp.centerline_straight.nii.gz -o tmp.centerline_straight_crop.nii.gz -dim 2 -bzmax')
            from msct_image import Image

            file_centerline_straight = Image("tmp.centerline_straight.nii.gz", verbose=verbose)
            coordinates_centerline = file_centerline_straight.getNonZeroCoordinates(sorting="z")
            mean_coord = []
            for z in range(coordinates_centerline[0].z, coordinates_centerline[-1].z):
                mean_coord.append(
                    mean(
                        [
                            [coord.x * coord.value, coord.y * coord.value]
                            for coord in coordinates_centerline
                            if coord.z == z
                        ],
                        axis=0,
                    )
                )

            # compute error between the input data and the nurbs
            from math import sqrt

            x0 = file_centerline_straight.data.shape[0] / 2.0
            y0 = file_centerline_straight.data.shape[1] / 2.0
            count_mean = 0
            for coord_z in mean_coord:
                if not isnan(sum(coord_z)):
                    dist = ((x0 - coord_z[0]) * px) ** 2 + ((y0 - coord_z[1]) * py) ** 2
                    self.mse_straightening += dist
                    dist = sqrt(dist)
                    if dist > self.max_distance_straightening:
                        self.max_distance_straightening = dist
                    count_mean += 1
            self.mse_straightening = sqrt(self.mse_straightening / float(count_mean))

        except Exception as e:
            sct.printv("WARNING: Exception during Straightening:", 1, "warning")
            print e

        os.chdir("..")

        # Generate output file (in current folder)
        # TODO: do not uncompress the warping field, it is too time consuming!
        sct.printv("\nGenerate output file (in current folder)...", verbose)
        sct.generate_output_file(
            path_tmp + "/tmp.curve2straight.nii.gz", "warp_curve2straight.nii.gz", verbose
        )  # warping field
        sct.generate_output_file(
            path_tmp + "/tmp.straight2curve.nii.gz", "warp_straight2curve.nii.gz", verbose
        )  # warping field
        if fname_output == "":
            fname_straight = sct.generate_output_file(
                path_tmp + "/tmp.anat_rigid_warp.nii.gz", file_anat + "_straight" + ext_anat, verbose
            )  # straightened anatomic
        else:
            fname_straight = sct.generate_output_file(
                path_tmp + "/tmp.anat_rigid_warp.nii.gz", fname_output, verbose
            )  # straightened anatomic
        # Remove temporary files
        if remove_temp_files:
            sct.printv("\nRemove temporary files...", verbose)
            sct.run("rm -rf " + path_tmp, verbose)

        sct.printv("\nDone!\n", verbose)

        sct.printv("Maximum x-y error = " + str(round(self.max_distance_straightening, 2)) + " mm", verbose, "bold")
        sct.printv(
            "Accuracy of straightening (MSE) = " + str(round(self.mse_straightening, 2)) + " mm", verbose, "bold"
        )
        # display elapsed time
        elapsed_time = time.time() - start_time
        sct.printv("\nFinished! Elapsed time: " + str(int(round(elapsed_time))) + "s", verbose)
        sct.printv("\nTo view results, type:", verbose)
        sct.printv("fslview " + fname_straight + " &\n", verbose, "info")
Esempio n. 19
0
def main():
    parser = get_parser()
    param = Param()

    arguments = parser.parse(sys.argv[1:])

    # get arguments
    fname_data = arguments['-i']
    fname_seg = arguments['-s']
    fname_landmarks = arguments['-l']
    if '-ofolder' in arguments:
        path_output = arguments['-ofolder']
    else:
        path_output = ''
    path_template = sct.slash_at_the_end(arguments['-t'], 1)
    contrast_template = arguments['-c']
    ref = arguments['-ref']
    remove_temp_files = int(arguments['-r'])
    verbose = int(arguments['-v'])
    param.verbose = verbose  # TODO: not clean, unify verbose or param.verbose in code, but not both
    if '-param-straighten' in arguments:
        param.param_straighten = arguments['-param-straighten']
    # if '-cpu-nb' in arguments:
    #     arg_cpu = ' -cpu-nb '+str(arguments['-cpu-nb'])
    # else:
    #     arg_cpu = ''
    # registration parameters
    if '-param' in arguments:
        # reset parameters but keep step=0 (might be overwritten if user specified step=0)
        paramreg = ParamregMultiStep([step0])
        if ref == 'subject':
            paramreg.steps['0'].dof = 'Tx_Ty_Tz_Rx_Ry_Rz_Sz'
        # add user parameters
        for paramStep in arguments['-param']:
            paramreg.addStep(paramStep)
    else:
        paramreg = ParamregMultiStep([step0, step1, step2])
        # if ref=subject, initialize registration using different affine parameters
        if ref == 'subject':
            paramreg.steps['0'].dof = 'Tx_Ty_Tz_Rx_Ry_Rz_Sz'

    # initialize other parameters
    # file_template_label = param.file_template_label
    zsubsample = param.zsubsample
    template = os.path.basename(os.path.normpath(path_template))
    # smoothing_sigma = param.smoothing_sigma

    # retrieve template file names
    from sct_warp_template import get_file_label
    file_template_vertebral_labeling = get_file_label(path_template+'template/', 'vertebral')
    file_template = get_file_label(path_template+'template/', contrast_template.upper()+'-weighted')
    file_template_seg = get_file_label(path_template+'template/', 'spinal cord')

    # start timer
    start_time = time.time()

    # get fname of the template + template objects
    fname_template = path_template+'template/'+file_template
    fname_template_vertebral_labeling = path_template+'template/'+file_template_vertebral_labeling
    fname_template_seg = path_template+'template/'+file_template_seg

    # check file existence
    # TODO: no need to do that!
    sct.printv('\nCheck template files...')
    sct.check_file_exist(fname_template, verbose)
    sct.check_file_exist(fname_template_vertebral_labeling, verbose)
    sct.check_file_exist(fname_template_seg, verbose)

    # print arguments
    sct.printv('\nCheck parameters:', verbose)
    sct.printv('  Data:                 '+fname_data, verbose)
    sct.printv('  Landmarks:            '+fname_landmarks, verbose)
    sct.printv('  Segmentation:         '+fname_seg, verbose)
    sct.printv('  Path template:        '+path_template, verbose)
    sct.printv('  Remove temp files:    '+str(remove_temp_files), verbose)

    # create QC folder
    sct.create_folder(param.path_qc)

    #
    # sct.printv('\nParameters for registration:')
    # for pStep in range(0, len(paramreg.steps)):
    #     sct.printv('Step #'+paramreg.steps[str(pStep)].step, verbose)
    #     sct.printv('  Type .................... '+paramreg.steps[str(pStep)].type, verbose)
    #     sct.printv('  Algorithm ............... '+paramreg.steps[str(pStep)].algo, verbose)
    #     sct.printv('  Metric .................. '+paramreg.steps[str(pStep)].metric, verbose)
    #     sct.printv('  Number of iterations .... '+paramreg.steps[str(pStep)].iter, verbose)
    #     sct.printv('  Shrink factor ........... '+paramreg.steps[str(pStep)].shrink, verbose)
    #     sct.printv('  Smoothing factor......... '+paramreg.steps[str(pStep)].smooth, verbose)
    #     sct.printv('  Gradient step ........... '+paramreg.steps[str(pStep)].gradStep, verbose)
    #     sct.printv('  Degree of polynomial .... '+paramreg.steps[str(pStep)].poly, verbose)

    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    sct.printv('\nCheck if data, segmentation and landmarks are in the same space...')
    if not sct.check_if_same_space(fname_data, fname_seg):
        sct.printv('ERROR: Data image and segmentation are not in the same space. Please check space and orientation of your files', verbose, 'error')
    if not sct.check_if_same_space(fname_data, fname_landmarks):
        sct.printv('ERROR: Data image and landmarks are not in the same space. Please check space and orientation of your files', verbose, 'error')

    sct.printv('\nCheck input labels...')
    # check if label image contains coherent labels
    image_label = Image(fname_landmarks)
    # -> all labels must be different
    labels = image_label.getNonZeroCoordinates(sorting='value')
    hasDifferentLabels = True
    for lab in labels:
        for otherlabel in labels:
            if lab != otherlabel and lab.hasEqualValue(otherlabel):
                hasDifferentLabels = False
                break
    if not hasDifferentLabels:
        sct.printv('ERROR: Wrong landmarks input. All labels must be different.', verbose, 'error')

    # create temporary folder
    path_tmp = sct.tmp_create(verbose=verbose)

    # set temporary file names
    ftmp_data = 'data.nii'
    ftmp_seg = 'seg.nii.gz'
    ftmp_label = 'label.nii.gz'
    ftmp_template = 'template.nii'
    ftmp_template_seg = 'template_seg.nii.gz'
    ftmp_template_label = 'template_label.nii.gz'

    # copy files to temporary folder
    sct.printv('\nCopying input data to tmp folder and convert to nii...', verbose)
    sct.run('sct_convert -i '+fname_data+' -o '+path_tmp+ftmp_data)
    sct.run('sct_convert -i '+fname_seg+' -o '+path_tmp+ftmp_seg)
    sct.run('sct_convert -i '+fname_landmarks+' -o '+path_tmp+ftmp_label)
    sct.run('sct_convert -i '+fname_template+' -o '+path_tmp+ftmp_template)
    sct.run('sct_convert -i '+fname_template_seg+' -o '+path_tmp+ftmp_template_seg)
    # sct.run('sct_convert -i '+fname_template_label+' -o '+path_tmp+ftmp_template_label)

    # go to tmp folder
    os.chdir(path_tmp)

    # Generate labels from template vertebral labeling
    sct.printv('\nGenerate labels from template vertebral labeling', verbose)
    sct.run('sct_label_utils -i '+fname_template_vertebral_labeling+' -vert-body 0 -o '+ftmp_template_label)

    # check if provided labels are available in the template
    sct.printv('\nCheck if provided labels are available in the template', verbose)
    image_label_template = Image(ftmp_template_label)
    labels_template = image_label_template.getNonZeroCoordinates(sorting='value')
    if labels[-1].value > labels_template[-1].value:
        sct.printv('ERROR: Wrong landmarks input. Labels must have correspondence in template space. \nLabel max '
                   'provided: ' + str(labels[-1].value) + '\nLabel max from template: ' +
                   str(labels_template[-1].value), verbose, 'error')

    # binarize segmentation (in case it has values below 0 caused by manual editing)
    sct.printv('\nBinarize segmentation', verbose)
    sct.run('sct_maths -i seg.nii.gz -bin 0.5 -o seg.nii.gz')

    # smooth segmentation (jcohenadad, issue #613)
    # sct.printv('\nSmooth segmentation...', verbose)
    # sct.run('sct_maths -i '+ftmp_seg+' -smooth 1.5 -o '+add_suffix(ftmp_seg, '_smooth'))
    # jcohenadad: updated 2016-06-16: DO NOT smooth the seg anymore. Issue #
    # sct.run('sct_maths -i '+ftmp_seg+' -smooth 0 -o '+add_suffix(ftmp_seg, '_smooth'))
    # ftmp_seg = add_suffix(ftmp_seg, '_smooth')

    # Switch between modes: subject->template or template->subject
    if ref == 'template':

        # resample data to 1mm isotropic
        sct.printv('\nResample data to 1mm isotropic...', verbose)
        sct.run('sct_resample -i '+ftmp_data+' -mm 1.0x1.0x1.0 -x linear -o '+add_suffix(ftmp_data, '_1mm'))
        ftmp_data = add_suffix(ftmp_data, '_1mm')
        sct.run('sct_resample -i '+ftmp_seg+' -mm 1.0x1.0x1.0 -x linear -o '+add_suffix(ftmp_seg, '_1mm'))
        ftmp_seg = add_suffix(ftmp_seg, '_1mm')
        # N.B. resampling of labels is more complicated: because they are single-point labels, resampling with nearest-neighbour interpolation can make them disappear. Therefore a more clever approach is required.
        resample_labels(ftmp_label, ftmp_data, add_suffix(ftmp_label, '_1mm'))
        ftmp_label = add_suffix(ftmp_label, '_1mm')

        # Change orientation of input images to RPI
        sct.printv('\nChange orientation of input images to RPI...', verbose)
        sct.run('sct_image -i '+ftmp_data+' -setorient RPI -o '+add_suffix(ftmp_data, '_rpi'))
        ftmp_data = add_suffix(ftmp_data, '_rpi')
        sct.run('sct_image -i '+ftmp_seg+' -setorient RPI -o '+add_suffix(ftmp_seg, '_rpi'))
        ftmp_seg = add_suffix(ftmp_seg, '_rpi')
        sct.run('sct_image -i '+ftmp_label+' -setorient RPI -o '+add_suffix(ftmp_label, '_rpi'))
        ftmp_label = add_suffix(ftmp_label, '_rpi')

        # get landmarks in native space
        # crop segmentation
        # output: segmentation_rpi_crop.nii.gz
        status_crop, output_crop = sct.run('sct_crop_image -i '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_crop')+' -dim 2 -bzmax', verbose)
        ftmp_seg = add_suffix(ftmp_seg, '_crop')
        cropping_slices = output_crop.split('Dimension 2: ')[1].split('\n')[0].split(' ')

        # straighten segmentation
        sct.printv('\nStraighten the spinal cord using centerline/segmentation...', verbose)
        # check if warp_curve2straight and warp_straight2curve already exist (i.e. no need to do it another time)
        if os.path.isfile('../warp_curve2straight.nii.gz') and os.path.isfile('../warp_straight2curve.nii.gz') and os.path.isfile('../straight_ref.nii.gz'):
            # if they exist, copy them into current folder
            sct.printv('WARNING: Straightening was already run previously. Copying warping fields...', verbose, 'warning')
            shutil.copy('../warp_curve2straight.nii.gz', 'warp_curve2straight.nii.gz')
            shutil.copy('../warp_straight2curve.nii.gz', 'warp_straight2curve.nii.gz')
            shutil.copy('../straight_ref.nii.gz', 'straight_ref.nii.gz')
            # apply straightening
            sct.run('sct_apply_transfo -i '+ftmp_seg+' -w warp_curve2straight.nii.gz -d straight_ref.nii.gz -o '+add_suffix(ftmp_seg, '_straight'))
        else:
            sct.run('sct_straighten_spinalcord -i '+ftmp_seg+' -s '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_straight')+' -qc 0 -r 0 -v '+str(verbose), verbose)
        # N.B. DO NOT UPDATE VARIABLE ftmp_seg BECAUSE TEMPORARY USED LATER
        # re-define warping field using non-cropped space (to avoid issue #367)
        sct.run('sct_concat_transfo -w warp_straight2curve.nii.gz -d '+ftmp_data+' -o warp_straight2curve.nii.gz')

        # Label preparation:
        # --------------------------------------------------------------------------------
        # Remove unused label on template. Keep only label present in the input label image
        sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
        sct.run('sct_label_utils -i '+ftmp_template_label+' -o '+ftmp_template_label+' -remove '+ftmp_label)

        # Dilate the input labels so they can be straightened without losing them
        sct.printv('\nDilating input labels using 3vox ball radius')
        sct.run('sct_maths -i '+ftmp_label+' -o '+add_suffix(ftmp_label, '_dilate')+' -dilate 3')
        ftmp_label = add_suffix(ftmp_label, '_dilate')

        # Apply straightening to labels
        sct.printv('\nApply straightening to labels...', verbose)
        sct.run('sct_apply_transfo -i '+ftmp_label+' -o '+add_suffix(ftmp_label, '_straight')+' -d '+add_suffix(ftmp_seg, '_straight')+' -w warp_curve2straight.nii.gz -x nn')
        ftmp_label = add_suffix(ftmp_label, '_straight')

        # Compute rigid transformation straight landmarks --> template landmarks
        sct.printv('\nEstimate transformation for step #0...', verbose)
        from msct_register_landmarks import register_landmarks
        try:
            register_landmarks(ftmp_label, ftmp_template_label, paramreg.steps['0'].dof, fname_affine='straight2templateAffine.txt', verbose=verbose)
        except Exception:
            sct.printv('ERROR: input labels do not seem to be at the right place. Please check the position of the labels. See documentation for more details: https://sourceforge.net/p/spinalcordtoolbox/wiki/create_labels/', verbose=verbose, type='error')

        # Concatenate transformations: curve --> straight --> affine
        sct.printv('\nConcatenate transformations: curve --> straight --> affine...', verbose)
        sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt -d template.nii -o warp_curve2straightAffine.nii.gz')

        # Apply transformation
        sct.printv('\nApply transformation...', verbose)
        sct.run('sct_apply_transfo -i '+ftmp_data+' -o '+add_suffix(ftmp_data, '_straightAffine')+' -d '+ftmp_template+' -w warp_curve2straightAffine.nii.gz')
        ftmp_data = add_suffix(ftmp_data, '_straightAffine')
        sct.run('sct_apply_transfo -i '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_straightAffine')+' -d '+ftmp_template+' -w warp_curve2straightAffine.nii.gz -x linear')
        ftmp_seg = add_suffix(ftmp_seg, '_straightAffine')

        """
        # Benjamin: Issue from Allan Martin, about the z=0 slice that is screwed up, caused by the affine transform.
        # Solution found: remove slices below and above landmarks to avoid rotation effects
        points_straight = []
        for coord in landmark_template:
            points_straight.append(coord.z)
        min_point, max_point = int(round(np.min(points_straight))), int(round(np.max(points_straight)))
        sct.run('sct_crop_image -i ' + ftmp_seg + ' -start ' + str(min_point) + ' -end ' + str(max_point) + ' -dim 2 -b 0 -o ' + add_suffix(ftmp_seg, '_black'))
        ftmp_seg = add_suffix(ftmp_seg, '_black')
        """

        # binarize
        sct.printv('\nBinarize segmentation...', verbose)
        sct.run('sct_maths -i '+ftmp_seg+' -bin 0.5 -o '+add_suffix(ftmp_seg, '_bin'))
        ftmp_seg = add_suffix(ftmp_seg, '_bin')

        # find min-max of anat2template (for subsequent cropping)
        zmin_template, zmax_template = find_zmin_zmax(ftmp_seg)

        # crop template in z-direction (for faster processing)
        sct.printv('\nCrop data in template space (for faster processing)...', verbose)
        sct.run('sct_crop_image -i '+ftmp_template+' -o '+add_suffix(ftmp_template, '_crop')+' -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
        ftmp_template = add_suffix(ftmp_template, '_crop')
        sct.run('sct_crop_image -i '+ftmp_template_seg+' -o '+add_suffix(ftmp_template_seg, '_crop')+' -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
        ftmp_template_seg = add_suffix(ftmp_template_seg, '_crop')
        sct.run('sct_crop_image -i '+ftmp_data+' -o '+add_suffix(ftmp_data, '_crop')+' -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
        ftmp_data = add_suffix(ftmp_data, '_crop')
        sct.run('sct_crop_image -i '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_crop')+' -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
        ftmp_seg = add_suffix(ftmp_seg, '_crop')

        # sub-sample in z-direction
        sct.printv('\nSub-sample in z-direction (for faster processing)...', verbose)
        sct.run('sct_resample -i '+ftmp_template+' -o '+add_suffix(ftmp_template, '_sub')+' -f 1x1x'+zsubsample, verbose)
        ftmp_template = add_suffix(ftmp_template, '_sub')
        sct.run('sct_resample -i '+ftmp_template_seg+' -o '+add_suffix(ftmp_template_seg, '_sub')+' -f 1x1x'+zsubsample, verbose)
        ftmp_template_seg = add_suffix(ftmp_template_seg, '_sub')
        sct.run('sct_resample -i '+ftmp_data+' -o '+add_suffix(ftmp_data, '_sub')+' -f 1x1x'+zsubsample, verbose)
        ftmp_data = add_suffix(ftmp_data, '_sub')
        sct.run('sct_resample -i '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_sub')+' -f 1x1x'+zsubsample, verbose)
        ftmp_seg = add_suffix(ftmp_seg, '_sub')

        # Registration straight spinal cord to template
        sct.printv('\nRegister straight spinal cord to template...', verbose)

        # loop across registration steps
        warp_forward = []
        warp_inverse = []
        for i_step in range(1, len(paramreg.steps)):
            sct.printv('\nEstimate transformation for step #'+str(i_step)+'...', verbose)
            # identify which is the src and dest
            if paramreg.steps[str(i_step)].type == 'im':
                src = ftmp_data
                dest = ftmp_template
                interp_step = 'linear'
            elif paramreg.steps[str(i_step)].type == 'seg':
                src = ftmp_seg
                dest = ftmp_template_seg
                interp_step = 'nn'
            else:
                sct.printv('ERROR: Wrong image type.', 1, 'error')
            # if step>1, apply warp_forward_concat to the src image to be used
            if i_step > 1:
                # sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+sct.add_suffix(src, '_reg')+' -x '+interp_step, verbose)
                # apply transformation from previous step, to use as new src for registration
                sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+add_suffix(src, '_regStep'+str(i_step-1))+' -x '+interp_step, verbose)
                src = add_suffix(src, '_regStep'+str(i_step-1))
            # register src --> dest
            # TODO: display param for debugging
            warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
            warp_forward.append(warp_forward_out)
            warp_inverse.append(warp_inverse_out)

        # Concatenate transformations:
        sct.printv('\nConcatenate transformations: anat --> template...', verbose)
        sct.run('sct_concat_transfo -w warp_curve2straightAffine.nii.gz,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
        # sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
        sct.printv('\nConcatenate transformations: template --> anat...', verbose)
        warp_inverse.reverse()
        sct.run('sct_concat_transfo -w '+','.join(warp_inverse)+',-straight2templateAffine.txt,warp_straight2curve.nii.gz -d data.nii -o warp_template2anat.nii.gz', verbose)

    # register template->subject
    elif ref == 'subject':

        # Change orientation of input images to RPI
        sct.printv('\nChange orientation of input images to RPI...', verbose)
        sct.run('sct_image -i ' + ftmp_data + ' -setorient RPI -o ' + add_suffix(ftmp_data, '_rpi'))
        ftmp_data = add_suffix(ftmp_data, '_rpi')
        sct.run('sct_image -i ' + ftmp_seg + ' -setorient RPI -o ' + add_suffix(ftmp_seg, '_rpi'))
        ftmp_seg = add_suffix(ftmp_seg, '_rpi')
        sct.run('sct_image -i ' + ftmp_label + ' -setorient RPI -o ' + add_suffix(ftmp_label, '_rpi'))
        ftmp_label = add_suffix(ftmp_label, '_rpi')

        # Remove unused label on template. Keep only label present in the input label image
        sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
        sct.run('sct_label_utils -i '+ftmp_template_label+' -o '+ftmp_template_label+' -remove '+ftmp_label)

        # Add one label because at least 3 orthogonal labels are required to estimate an affine transformation. This new label is added at the level of the uppermost label (lowest value), offset by 5mm along the x axis (see 5.0 / px below).
        for i_file in [ftmp_label, ftmp_template_label]:
            im_label = Image(i_file)
            coord_label = im_label.getCoordinatesAveragedByValue()  # N.B. landmarks are sorted by value
            # Create new label
            from copy import deepcopy
            new_label = deepcopy(coord_label[0])
            # move it 5mm along +x (NOTE(review): the input labels were set to RPI above, not RAS; in RPI +x presumably points to the left -- confirm)
            nx, ny, nz, nt, px, py, pz, pt = im_label.dim
            new_label.x = round(coord_label[0].x + 5.0 / px)
            # assign value 99
            new_label.value = 99
            # Add to existing image
            im_label.data[new_label.x, new_label.y, new_label.z] = new_label.value
            # Overwrite label file
            # im_label.setFileName('label_rpi_modif.nii.gz')
            im_label.save()

        # Bring template to subject space using landmark-based transformation
        sct.printv('\nEstimate transformation for step #0...', verbose)
        from msct_register_landmarks import register_landmarks
        warp_forward = ['template2subjectAffine.txt']
        warp_inverse = ['-template2subjectAffine.txt']
        try:
            register_landmarks(ftmp_template_label, ftmp_label, paramreg.steps['0'].dof, fname_affine=warp_forward[0], verbose=verbose, path_qc=param.path_qc)
        except Exception:
            sct.printv('ERROR: input labels do not seem to be at the right place. Please check the position of the labels. See documentation for more details: https://sourceforge.net/p/spinalcordtoolbox/wiki/create_labels/', verbose=verbose, type='error')

        # loop across registration steps
        for i_step in range(1, len(paramreg.steps)):
            sct.printv('\nEstimate transformation for step #'+str(i_step)+'...', verbose)
            # identify which is the src and dest
            if paramreg.steps[str(i_step)].type == 'im':
                src = ftmp_template
                dest = ftmp_data
                interp_step = 'linear'
            elif paramreg.steps[str(i_step)].type == 'seg':
                src = ftmp_template_seg
                dest = ftmp_seg
                interp_step = 'nn'
            else:
                sct.printv('ERROR: Wrong image type.', 1, 'error')
            # apply transformation from previous step, to use as new src for registration
            sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+add_suffix(src, '_regStep'+str(i_step-1))+' -x '+interp_step, verbose)
            src = add_suffix(src, '_regStep'+str(i_step-1))
            # register src --> dest
            # TODO: display param for debugging
            warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
            warp_forward.append(warp_forward_out)
            warp_inverse.insert(0, warp_inverse_out)

        # Concatenate transformations:
        sct.printv('\nConcatenate transformations: template --> subject...', verbose)
        sct.run('sct_concat_transfo -w '+','.join(warp_forward)+' -d data.nii -o warp_template2anat.nii.gz', verbose)
        sct.printv('\nConcatenate transformations: subject --> template...', verbose)
        sct.run('sct_concat_transfo -w '+','.join(warp_inverse)+' -d template.nii -o warp_anat2template.nii.gz', verbose)

    # Apply warping fields to anat and template
    sct.run('sct_apply_transfo -i template.nii -o template2anat.nii.gz -d data.nii -w warp_template2anat.nii.gz -crop 1', verbose)
    sct.run('sct_apply_transfo -i data.nii -o anat2template.nii.gz -d template.nii -w warp_anat2template.nii.gz -crop 1', verbose)

    # come back to parent folder
    os.chdir('..')

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(path_tmp+'warp_template2anat.nii.gz', path_output+'warp_template2anat.nii.gz', verbose)
    sct.generate_output_file(path_tmp+'warp_anat2template.nii.gz', path_output+'warp_anat2template.nii.gz', verbose)
    sct.generate_output_file(path_tmp+'template2anat.nii.gz', path_output+'template2anat'+ext_data, verbose)
    sct.generate_output_file(path_tmp+'anat2template.nii.gz', path_output+'anat2template'+ext_data, verbose)
    if ref == 'template':
        # copy straightening files in case subsequent SCT functions need them
        sct.generate_output_file(path_tmp+'warp_curve2straight.nii.gz', path_output+'warp_curve2straight.nii.gz', verbose)
        sct.generate_output_file(path_tmp+'warp_straight2curve.nii.gz', path_output+'warp_straight2curve.nii.gz', verbose)
        sct.generate_output_file(path_tmp+'straight_ref.nii.gz', path_output+'straight_ref.nii.gz', verbose)

    # Delete temporary files
    if remove_temp_files:
        sct.printv('\nDelete temporary files...', verbose)
        sct.run('rm -rf '+path_tmp)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: '+str(int(round(elapsed_time)))+'s', verbose)

    # to view results
    sct.printv('\nTo view results, type:', verbose)
    sct.printv('fslview '+fname_data+' '+path_output+'template2anat -b 0,4000 &', verbose, 'info')
    sct.printv('fslview '+fname_template+' -b 0,5000 '+path_output+'anat2template &\n', verbose, 'info')
def main():
    parser = get_parser()
    param = Param()

    args = sys.argv[1:]

    arguments = parser.parse(args)

    # get arguments
    fname_data = arguments['-i']
    fname_seg = arguments['-s']
    fname_landmarks = arguments['-l']
    if '-ofolder' in arguments:
        path_output = arguments['-ofolder']
    else:
        path_output = ''
    path_template = sct.slash_at_the_end(arguments['-t'], 1)
    contrast_template = arguments['-c']
    ref = arguments['-ref']
    remove_temp_files = int(arguments['-r'])
    verbose = int(arguments['-v'])
    param.verbose = verbose  # TODO: not clean, unify verbose or param.verbose in code, but not both
    if '-param-straighten' in arguments:
        param.param_straighten = arguments['-param-straighten']
    # if '-cpu-nb' in arguments:
    #     arg_cpu = ' -cpu-nb '+str(arguments['-cpu-nb'])
    # else:
    #     arg_cpu = ''
    # registration parameters
    if '-param' in arguments:
        # reset parameters but keep step=0 (might be overwritten if user specified step=0)
        paramreg = ParamregMultiStep([step0])
        if ref == 'subject':
            paramreg.steps['0'].dof = 'Tx_Ty_Tz_Rx_Ry_Rz_Sz'
        # add user parameters
        for paramStep in arguments['-param']:
            paramreg.addStep(paramStep)
    else:
        paramreg = ParamregMultiStep([step0, step1, step2])
        # if ref=subject, initialize registration using different affine parameters
        if ref == 'subject':
            paramreg.steps['0'].dof = 'Tx_Ty_Tz_Rx_Ry_Rz_Sz'

    # initialize other parameters
    # file_template_label = param.file_template_label
    zsubsample = param.zsubsample
    # smoothing_sigma = param.smoothing_sigma

    # retrieve template file names
    from sct_warp_template import get_file_label
    file_template_vertebral_labeling = get_file_label(path_template + 'template/', 'vertebral')
    file_template = get_file_label(path_template + 'template/', contrast_template.upper() + '-weighted')
    file_template_seg = get_file_label(path_template + 'template/', 'spinal cord')

    # start timer
    start_time = time.time()

    # get fname of the template + template objects
    fname_template = path_template + 'template/' + file_template
    fname_template_vertebral_labeling = path_template + 'template/' + file_template_vertebral_labeling
    fname_template_seg = path_template + 'template/' + file_template_seg

    # check file existence
    # TODO: no need to do that!
    sct.printv('\nCheck template files...')
    sct.check_file_exist(fname_template, verbose)
    sct.check_file_exist(fname_template_vertebral_labeling, verbose)
    sct.check_file_exist(fname_template_seg, verbose)
    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    # print arguments
    sct.printv('\nCheck parameters:', verbose)
    sct.printv('  Data:                 ' + fname_data, verbose)
    sct.printv('  Landmarks:            ' + fname_landmarks, verbose)
    sct.printv('  Segmentation:         ' + fname_seg, verbose)
    sct.printv('  Path template:        ' + path_template, verbose)
    sct.printv('  Remove temp files:    ' + str(remove_temp_files), verbose)

    # create QC folder
    sct.create_folder(param.path_qc)

    # check if data, segmentation and landmarks are in the same space
    # JULIEN 2017-04-25: removed because of issue #1168
    # sct.printv('\nCheck if data, segmentation and landmarks are in the same space...')
    # if not sct.check_if_same_space(fname_data, fname_seg):
    #     sct.printv('ERROR: Data image and segmentation are not in the same space. Please check space and orientation of your files', verbose, 'error')
    # if not sct.check_if_same_space(fname_data, fname_landmarks):
    #     sct.printv('ERROR: Data image and landmarks are not in the same space. Please check space and orientation of your files', verbose, 'error')

    # check input labels
    labels = check_labels(fname_landmarks)

    # create temporary folder
    path_tmp = sct.tmp_create(verbose=verbose)

    # set temporary file names
    ftmp_data = 'data.nii'
    ftmp_seg = 'seg.nii.gz'
    ftmp_label = 'label.nii.gz'
    ftmp_template = 'template.nii'
    ftmp_template_seg = 'template_seg.nii.gz'
    ftmp_template_label = 'template_label.nii.gz'

    # copy files to temporary folder
    sct.printv('\nCopying input data to tmp folder and convert to nii...', verbose)
    sct.run('sct_convert -i ' + fname_data + ' -o ' + path_tmp + ftmp_data)
    sct.run('sct_convert -i ' + fname_seg + ' -o ' + path_tmp + ftmp_seg)
    sct.run('sct_convert -i ' + fname_landmarks + ' -o ' + path_tmp + ftmp_label)
    sct.run('sct_convert -i ' + fname_template + ' -o ' + path_tmp + ftmp_template)
    sct.run('sct_convert -i ' + fname_template_seg + ' -o ' + path_tmp + ftmp_template_seg)
    # sct.run('sct_convert -i '+fname_template_label+' -o '+path_tmp+ftmp_template_label)

    # go to tmp folder
    os.chdir(path_tmp)

    # copy header of anat to segmentation (issue #1168)
    # from sct_image import copy_header
    # im_data = Image(ftmp_data)
    # im_seg = Image(ftmp_seg)
    # copy_header(im_data, im_seg)
    # im_seg.save()
    # im_label = Image(ftmp_label)
    # copy_header(im_data, im_label)
    # im_label.save()

    # Generate labels from template vertebral labeling
    sct.printv('\nGenerate labels from template vertebral labeling', verbose)
    sct.run('sct_label_utils -i ' + fname_template_vertebral_labeling + ' -vert-body 0 -o ' + ftmp_template_label)

    # check if provided labels are available in the template
    sct.printv('\nCheck if provided labels are available in the template', verbose)
    image_label_template = Image(ftmp_template_label)
    labels_template = image_label_template.getNonZeroCoordinates(sorting='value')
    if labels[-1].value > labels_template[-1].value:
        sct.printv('ERROR: Wrong landmarks input. Labels must have correspondence in template space. \nLabel max '
                   'provided: ' + str(labels[-1].value) + '\nLabel max from template: ' +
                   str(labels_template[-1].value), verbose, 'error')

    # binarize segmentation (in case it has values below 0 caused by manual editing)
    sct.printv('\nBinarize segmentation', verbose)
    sct.run('sct_maths -i seg.nii.gz -bin 0.5 -o seg.nii.gz')

    # smooth segmentation (jcohenadad, issue #613)
    # sct.printv('\nSmooth segmentation...', verbose)
    # sct.run('sct_maths -i '+ftmp_seg+' -smooth 1.5 -o '+add_suffix(ftmp_seg, '_smooth'))
    # jcohenadad: updated 2016-06-16: DO NOT smooth the seg anymore. Issue #
    # sct.run('sct_maths -i '+ftmp_seg+' -smooth 0 -o '+add_suffix(ftmp_seg, '_smooth'))
    # ftmp_seg = add_suffix(ftmp_seg, '_smooth')

    # Switch between modes: subject->template or template->subject
    if ref == 'template':

        # resample data to 1mm isotropic
        sct.printv('\nResample data to 1mm isotropic...', verbose)
        sct.run('sct_resample -i ' + ftmp_data + ' -mm 1.0x1.0x1.0 -x linear -o ' + add_suffix(ftmp_data, '_1mm'))
        ftmp_data = add_suffix(ftmp_data, '_1mm')
        sct.run('sct_resample -i ' + ftmp_seg + ' -mm 1.0x1.0x1.0 -x linear -o ' + add_suffix(ftmp_seg, '_1mm'))
        ftmp_seg = add_suffix(ftmp_seg, '_1mm')
        # N.B. resampling of labels is more complicated: because they are single-point labels, resampling with nearest-neighbour interpolation can make them disappear. Therefore a more clever approach is required.
        resample_labels(ftmp_label, ftmp_data, add_suffix(ftmp_label, '_1mm'))
        ftmp_label = add_suffix(ftmp_label, '_1mm')

        # Change orientation of input images to RPI
        sct.printv('\nChange orientation of input images to RPI...', verbose)
        sct.run('sct_image -i ' + ftmp_data + ' -setorient RPI -o ' + add_suffix(ftmp_data, '_rpi'))
        ftmp_data = add_suffix(ftmp_data, '_rpi')
        sct.run('sct_image -i ' + ftmp_seg + ' -setorient RPI -o ' + add_suffix(ftmp_seg, '_rpi'))
        ftmp_seg = add_suffix(ftmp_seg, '_rpi')
        sct.run('sct_image -i ' + ftmp_label + ' -setorient RPI -o ' + add_suffix(ftmp_label, '_rpi'))
        ftmp_label = add_suffix(ftmp_label, '_rpi')

        # get landmarks in native space
        # crop segmentation
        # output: segmentation_rpi_crop.nii.gz
        status_crop, output_crop = sct.run('sct_crop_image -i ' + ftmp_seg + ' -o ' + add_suffix(ftmp_seg, '_crop') + ' -dim 2 -bzmax', verbose)
        ftmp_seg = add_suffix(ftmp_seg, '_crop')
        cropping_slices = output_crop.split('Dimension 2: ')[1].split('\n')[0].split(' ')

        # straighten segmentation
        sct.printv('\nStraighten the spinal cord using centerline/segmentation...', verbose)
        # check if warp_curve2straight and warp_straight2curve already exist (i.e. no need to do it another time)
        if os.path.isfile('../warp_curve2straight.nii.gz') and os.path.isfile('../warp_straight2curve.nii.gz') and os.path.isfile('../straight_ref.nii.gz'):
            # if they exist, copy them into current folder
            sct.printv('WARNING: Straightening was already run previously. Copying warping fields...', verbose, 'warning')
            shutil.copy('../warp_curve2straight.nii.gz', 'warp_curve2straight.nii.gz')
            shutil.copy('../warp_straight2curve.nii.gz', 'warp_straight2curve.nii.gz')
            shutil.copy('../straight_ref.nii.gz', 'straight_ref.nii.gz')
            # apply straightening
            sct.run('sct_apply_transfo -i ' + ftmp_seg + ' -w warp_curve2straight.nii.gz -d straight_ref.nii.gz -o ' + add_suffix(ftmp_seg, '_straight'))
        else:
            sct.run('sct_straighten_spinalcord -i ' + ftmp_seg + ' -s ' + ftmp_seg + ' -o ' + add_suffix(ftmp_seg, '_straight') + ' -qc 0 -r 0 -v ' + str(verbose), verbose)
        # N.B. DO NOT UPDATE VARIABLE ftmp_seg BECAUSE TEMPORARY USED LATER
        # re-define warping field using non-cropped space (to avoid issue #367)
        sct.run('sct_concat_transfo -w warp_straight2curve.nii.gz -d ' + ftmp_data + ' -o warp_straight2curve.nii.gz')

        # Label preparation:
        # --------------------------------------------------------------------------------
        # Remove unused label on template. Keep only label present in the input label image
        sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
        sct.run('sct_label_utils -i ' + ftmp_template_label + ' -o ' + ftmp_template_label + ' -remove ' + ftmp_label)

        # Dilate the input labels so they can be straightened without losing them
        sct.printv('\nDilating input labels using 3vox ball radius')
        sct.run('sct_maths -i ' + ftmp_label + ' -o ' + add_suffix(ftmp_label, '_dilate') + ' -dilate 3')
        ftmp_label = add_suffix(ftmp_label, '_dilate')

        # Apply straightening to labels
        sct.printv('\nApply straightening to labels...', verbose)
        sct.run('sct_apply_transfo -i ' + ftmp_label + ' -o ' + add_suffix(ftmp_label, '_straight') + ' -d ' + add_suffix(ftmp_seg, '_straight') + ' -w warp_curve2straight.nii.gz -x nn')
        ftmp_label = add_suffix(ftmp_label, '_straight')

        # Compute rigid transformation straight landmarks --> template landmarks
        sct.printv('\nEstimate transformation for step #0...', verbose)
        from msct_register_landmarks import register_landmarks
        try:
            register_landmarks(ftmp_label, ftmp_template_label, paramreg.steps['0'].dof, fname_affine='straight2templateAffine.txt', verbose=verbose)
        except Exception:
            sct.printv('ERROR: input labels do not seem to be at the right place. Please check the position of the labels. See documentation for more details: https://sourceforge.net/p/spinalcordtoolbox/wiki/create_labels/', verbose=verbose, type='error')

        # Concatenate transformations: curve --> straight --> affine
        sct.printv('\nConcatenate transformations: curve --> straight --> affine...', verbose)
        sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt -d template.nii -o warp_curve2straightAffine.nii.gz')

        # Apply transformation
        sct.printv('\nApply transformation...', verbose)
        sct.run('sct_apply_transfo -i ' + ftmp_data + ' -o ' + add_suffix(ftmp_data, '_straightAffine') + ' -d ' + ftmp_template + ' -w warp_curve2straightAffine.nii.gz')
        ftmp_data = add_suffix(ftmp_data, '_straightAffine')
        sct.run('sct_apply_transfo -i ' + ftmp_seg + ' -o ' + add_suffix(ftmp_seg, '_straightAffine') + ' -d ' + ftmp_template + ' -w warp_curve2straightAffine.nii.gz -x linear')
        ftmp_seg = add_suffix(ftmp_seg, '_straightAffine')

        """
        # Benjamin: Issue from Allan Martin, about the z=0 slice that is screwed up, caused by the affine transform.
        # Solution found: remove slices below and above landmarks to avoid rotation effects
        points_straight = []
        for coord in landmark_template:
            points_straight.append(coord.z)
        min_point, max_point = int(round(np.min(points_straight))), int(round(np.max(points_straight)))
        sct.run('sct_crop_image -i ' + ftmp_seg + ' -start ' + str(min_point) + ' -end ' + str(max_point) + ' -dim 2 -b 0 -o ' + add_suffix(ftmp_seg, '_black'))
        ftmp_seg = add_suffix(ftmp_seg, '_black')
        """

        # binarize
        sct.printv('\nBinarize segmentation...', verbose)
        sct.run('sct_maths -i ' + ftmp_seg + ' -bin 0.5 -o ' + add_suffix(ftmp_seg, '_bin'))
        ftmp_seg = add_suffix(ftmp_seg, '_bin')

        # find min-max of anat2template (for subsequent cropping)
        zmin_template, zmax_template = find_zmin_zmax(ftmp_seg)

        # crop template in z-direction (for faster processing)
        sct.printv('\nCrop data in template space (for faster processing)...', verbose)
        sct.run('sct_crop_image -i ' + ftmp_template + ' -o ' + add_suffix(ftmp_template, '_crop') + ' -dim 2 -start ' + str(zmin_template) + ' -end ' + str(zmax_template))
        ftmp_template = add_suffix(ftmp_template, '_crop')
        sct.run('sct_crop_image -i ' + ftmp_template_seg + ' -o ' + add_suffix(ftmp_template_seg, '_crop') + ' -dim 2 -start ' + str(zmin_template) + ' -end ' + str(zmax_template))
        ftmp_template_seg = add_suffix(ftmp_template_seg, '_crop')
        sct.run('sct_crop_image -i ' + ftmp_data + ' -o ' + add_suffix(ftmp_data, '_crop') + ' -dim 2 -start ' + str(zmin_template) + ' -end ' + str(zmax_template))
        ftmp_data = add_suffix(ftmp_data, '_crop')
        sct.run('sct_crop_image -i ' + ftmp_seg + ' -o ' + add_suffix(ftmp_seg, '_crop') + ' -dim 2 -start ' + str(zmin_template) + ' -end ' + str(zmax_template))
        ftmp_seg = add_suffix(ftmp_seg, '_crop')

        # sub-sample in z-direction
        sct.printv('\nSub-sample in z-direction (for faster processing)...', verbose)
        sct.run('sct_resample -i ' + ftmp_template + ' -o ' + add_suffix(ftmp_template, '_sub') + ' -f 1x1x' + zsubsample, verbose)
        ftmp_template = add_suffix(ftmp_template, '_sub')
        sct.run('sct_resample -i ' + ftmp_template_seg + ' -o ' + add_suffix(ftmp_template_seg, '_sub') + ' -f 1x1x' + zsubsample, verbose)
        ftmp_template_seg = add_suffix(ftmp_template_seg, '_sub')
        sct.run('sct_resample -i ' + ftmp_data + ' -o ' + add_suffix(ftmp_data, '_sub') + ' -f 1x1x' + zsubsample, verbose)
        ftmp_data = add_suffix(ftmp_data, '_sub')
        sct.run('sct_resample -i ' + ftmp_seg + ' -o ' + add_suffix(ftmp_seg, '_sub') + ' -f 1x1x' + zsubsample, verbose)
        ftmp_seg = add_suffix(ftmp_seg, '_sub')

        # Registration straight spinal cord to template
        sct.printv('\nRegister straight spinal cord to template...', verbose)

        # loop across registration steps
        warp_forward = []
        warp_inverse = []
        for i_step in range(1, len(paramreg.steps)):
            sct.printv('\nEstimate transformation for step #' + str(i_step) + '...', verbose)
            # identify which is the src and dest
            if paramreg.steps[str(i_step)].type == 'im':
                src = ftmp_data
                dest = ftmp_template
                interp_step = 'linear'
            elif paramreg.steps[str(i_step)].type == 'seg':
                src = ftmp_seg
                dest = ftmp_template_seg
                interp_step = 'nn'
            else:
                sct.printv('ERROR: Wrong image type.', 1, 'error')
            # if step>1, apply warp_forward_concat to the src image to be used
            if i_step > 1:
                # sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+sct.add_suffix(src, '_reg')+' -x '+interp_step, verbose)
                # apply transformation from previous step, to use as new src for registration
                sct.run('sct_apply_transfo -i ' + src + ' -d ' + dest + ' -w ' + ','.join(warp_forward) + ' -o ' + add_suffix(src, '_regStep' + str(i_step - 1)) + ' -x ' + interp_step, verbose)
                src = add_suffix(src, '_regStep' + str(i_step - 1))
            # register src --> dest
            # TODO: display param for debugging
            warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
            warp_forward.append(warp_forward_out)
            warp_inverse.append(warp_inverse_out)

        # Concatenate transformations:
        sct.printv('\nConcatenate transformations: anat --> template...', verbose)
        sct.run('sct_concat_transfo -w warp_curve2straightAffine.nii.gz,' + ','.join(warp_forward) + ' -d template.nii -o warp_anat2template.nii.gz', verbose)
        # sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
        sct.printv('\nConcatenate transformations: template --> anat...', verbose)
        warp_inverse.reverse()
        sct.run('sct_concat_transfo -w ' + ','.join(warp_inverse) + ',-straight2templateAffine.txt,warp_straight2curve.nii.gz -d data.nii -o warp_template2anat.nii.gz', verbose)

    # register template->subject
    elif ref == 'subject':

        # Change orientation of input images to RPI
        sct.printv('\nChange orientation of input images to RPI...', verbose)
        sct.run('sct_image -i ' + ftmp_data + ' -setorient RPI -o ' + add_suffix(ftmp_data, '_rpi'))
        ftmp_data = add_suffix(ftmp_data, '_rpi')
        sct.run('sct_image -i ' + ftmp_seg + ' -setorient RPI -o ' + add_suffix(ftmp_seg, '_rpi'))
        ftmp_seg = add_suffix(ftmp_seg, '_rpi')
        sct.run('sct_image -i ' + ftmp_label + ' -setorient RPI -o ' + add_suffix(ftmp_label, '_rpi'))
        ftmp_label = add_suffix(ftmp_label, '_rpi')

        # Remove unused label on template. Keep only label present in the input label image
        sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
        sct.run('sct_label_utils -i ' + ftmp_template_label + ' -o ' + ftmp_template_label + ' -remove ' + ftmp_label)

        # Add one label because at least 3 orthogonal labels are required to estimate an affine transformation. This new label is added at the level of the upper most label (lowest value), at 1cm to the right.
        for i_file in [ftmp_label, ftmp_template_label]:
            im_label = Image(i_file)
            coord_label = im_label.getCoordinatesAveragedByValue()  # N.B. landmarks are sorted by value
            # Create new label
            from copy import deepcopy
            new_label = deepcopy(coord_label[0])
            # move it 5mm to the left (orientation is RAS)
            nx, ny, nz, nt, px, py, pz, pt = im_label.dim
            new_label.x = round(coord_label[0].x + 5.0 / px)
            # assign value 99
            new_label.value = 99
            # Add to existing image
            im_label.data[int(new_label.x), int(new_label.y), int(new_label.z)] = new_label.value
            # Overwrite label file
            # im_label.setFileName('label_rpi_modif.nii.gz')
            im_label.save()

        # Bring template to subject space using landmark-based transformation
        sct.printv('\nEstimate transformation for step #0...', verbose)
        from msct_register_landmarks import register_landmarks
        warp_forward = ['template2subjectAffine.txt']
        warp_inverse = ['-template2subjectAffine.txt']
        try:
            register_landmarks(ftmp_template_label, ftmp_label, paramreg.steps['0'].dof, fname_affine=warp_forward[0], verbose=verbose, path_qc=param.path_qc)
        except Exception:
            sct.printv('ERROR: input labels do not seem to be at the right place. Please check the position of the labels. See documentation for more details: https://sourceforge.net/p/spinalcordtoolbox/wiki/create_labels/', verbose=verbose, type='error')

        # loop across registration steps
        for i_step in range(1, len(paramreg.steps)):
            sct.printv('\nEstimate transformation for step #' + str(i_step) + '...', verbose)
            # identify which is the src and dest
            if paramreg.steps[str(i_step)].type == 'im':
                src = ftmp_template
                dest = ftmp_data
                interp_step = 'linear'
            elif paramreg.steps[str(i_step)].type == 'seg':
                src = ftmp_template_seg
                dest = ftmp_seg
                interp_step = 'nn'
            else:
                sct.printv('ERROR: Wrong image type.', 1, 'error')
            # apply transformation from previous step, to use as new src for registration
            sct.run('sct_apply_transfo -i ' + src + ' -d ' + dest + ' -w ' + ','.join(warp_forward) + ' -o ' + add_suffix(src, '_regStep' + str(i_step - 1)) + ' -x ' + interp_step, verbose)
            src = add_suffix(src, '_regStep' + str(i_step - 1))
            # register src --> dest
            # TODO: display param for debugging
            warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
            warp_forward.append(warp_forward_out)
            warp_inverse.insert(0, warp_inverse_out)

        # Concatenate transformations:
        sct.printv('\nConcatenate transformations: template --> subject...', verbose)
        sct.run('sct_concat_transfo -w ' + ','.join(warp_forward) + ' -d data.nii -o warp_template2anat.nii.gz', verbose)
        sct.printv('\nConcatenate transformations: subject --> template...', verbose)
        sct.run('sct_concat_transfo -w ' + ','.join(warp_inverse) + ' -d template.nii -o warp_anat2template.nii.gz', verbose)

    # Apply warping fields to anat and template
    sct.run('sct_apply_transfo -i template.nii -o template2anat.nii.gz -d data.nii -w warp_template2anat.nii.gz -crop 1', verbose)
    sct.run('sct_apply_transfo -i data.nii -o anat2template.nii.gz -d template.nii -w warp_anat2template.nii.gz -crop 1', verbose)

    # come back to parent folder
    os.chdir('..')

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(path_tmp + 'warp_template2anat.nii.gz', path_output + 'warp_template2anat.nii.gz', verbose)
    sct.generate_output_file(path_tmp + 'warp_anat2template.nii.gz', path_output + 'warp_anat2template.nii.gz', verbose)
    sct.generate_output_file(path_tmp + 'template2anat.nii.gz', path_output + 'template2anat' + ext_data, verbose)
    sct.generate_output_file(path_tmp + 'anat2template.nii.gz', path_output + 'anat2template' + ext_data, verbose)
    if ref == 'template':
        # copy straightening files in case subsequent SCT functions need them
        sct.generate_output_file(path_tmp + 'warp_curve2straight.nii.gz', path_output + 'warp_curve2straight.nii.gz', verbose)
        sct.generate_output_file(path_tmp + 'warp_straight2curve.nii.gz', path_output + 'warp_straight2curve.nii.gz', verbose)
        sct.generate_output_file(path_tmp + 'straight_ref.nii.gz', path_output + 'straight_ref.nii.gz', verbose)

    # Delete temporary files
    if remove_temp_files:
        sct.printv('\nDelete temporary files...', verbose)
        sct.run('rm -rf ' + path_tmp)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: ' + str(int(round(elapsed_time))) + 's', verbose)

    if '-qc' in arguments and not arguments.get('-noqc', False):
        qc_path = arguments['-qc']

        import spinalcordtoolbox.reports.qc as qc
        import spinalcordtoolbox.reports.slice as qcslice

        qc_param = qc.Params(fname_data, 'sct_register_to_template', args, 'Sagittal', qc_path)
        report = qc.QcReport(qc_param, '')

        @qc.QcImage(report, 'none', [qc.QcImage.no_seg_seg])
        def test(qslice):
            return qslice.single()

        fname_template2anat = path_output + 'template2anat' + ext_data
        test(qcslice.SagittalTemplate2Anat(Image(fname_data), Image(fname_template2anat), Image(fname_seg)))
        sct.printv('Sucessfully generate the QC results in %s' % qc_param.qc_results)
        sct.printv('Use the following command to see the results in a browser')
        sct.printv('sct_qc -folder %s' % qc_path, type='info')

    # to view results
    sct.printv('\nTo view results, type:', verbose)
    sct.printv('fslview ' + fname_data + ' ' + path_output + 'template2anat -b 0,4000 &', verbose, 'info')
    sct.printv('fslview ' + fname_template + ' -b 0,5000 ' + path_output + 'anat2template &\n', verbose, 'info')
def main():
    """
    Register an anatomical image and the template to each other.

    Command-line arguments are read from ``sys.argv`` and unpacked by
    ``rewrite_arguments`` / ``write_paramaters``. All intermediate files are written in a
    temporary folder; the resulting warping fields and warped images are then copied to
    ``path_output``. Two modes are selected by ``ref``:

    - ``'template'``: the segmentation is straightened, a landmark-based transformation
      aligns the straightened labels with the template labels, and the registration
      steps of ``paramreg`` (from step 1 on) refine the alignment.
    - ``'subject'``: a landmark-based transformation brings the template labels onto the
      subject labels, and the registration steps of ``paramreg`` refine the alignment.

    In both modes the function produces ``warp_template2anat``, ``warp_anat2template``,
    ``template2anat`` and ``anat2template``.
    """
    parser = get_parser()
    param = Param()

    # Rewrite arguments and set parameters
    arguments = parser.parse(sys.argv[1:])
    (fname_data, fname_landmarks, path_output, path_template, contrast_template, ref, remove_temp_files,
     verbose, init_labels, first_label, nb_slice_to_mean) = rewrite_arguments(arguments)
    (param, paramreg) = write_paramaters(arguments, param, ref, verbose)

    if init_labels:
        use_viewer_to_define_labels(fname_data, first_label, nb_slice_to_mean)

    # initialize other parameters
    zsubsample = param.zsubsample

    # retrieve template file names
    from sct_warp_template import get_file_label
    file_template_vertebral_labeling = get_file_label(path_template+'template/', 'vertebral')
    file_template = get_file_label(path_template+'template/', contrast_template.upper()+'-weighted')
    file_template_seg = get_file_label(path_template+'template/', 'spinal cord')

    # Start timer
    start_time = time.time()

    # Manage file of templates
    (fname_template, fname_template_vertebral_labeling, fname_template_seg) = make_fname_of_templates(
        file_template, path_template, file_template_vertebral_labeling, file_template_seg)
    check_do_files_exist(fname_template, fname_template_vertebral_labeling, fname_template_seg, verbose)

    # Print arguments.
    # NOTE(review): fname_seg is not assigned anywhere in this function; it has to come
    # from the enclosing module scope (or be returned by rewrite_arguments) -- confirm.
    sct.printv('\nCheck parameters:', verbose)
    sct.printv('  Data:                 ' + fname_data, verbose)
    sct.printv('  Landmarks:            ' + str(fname_landmarks), verbose)
    sct.printv('  Segmentation:         ' + str(fname_seg), verbose)
    sct.printv('  Path template:        ' + path_template, verbose)
    sct.printv('  Remove temp files:    ' + str(remove_temp_files), verbose)

    # Create QC folder
    sct.create_folder(param.path_qc)

    # Check if data, segmentation and landmarks are in the same space
    (ext_data, path_data, file_data) = check_data_segmentation_landmarks_same_space(
        fname_data, fname_seg, fname_landmarks, verbose)

    # Check input labels
    labels = check_labels(fname_landmarks)

    # Create temporary folder, set temporary file names, copy files into it and go in it
    path_tmp = sct.tmp_create(verbose=verbose)
    (ftmp_data, ftmp_seg, ftmp_label, ftmp_template, ftmp_template_seg, ftmp_template_label) = set_temporary_files()
    copy_files_to_temporary_files(verbose, fname_data, path_tmp, ftmp_seg, ftmp_data, fname_seg, fname_landmarks,
                                  ftmp_label, fname_template, ftmp_template, fname_template_seg, ftmp_template_seg)
    os.chdir(path_tmp)

    # Generate labels from template vertebral labeling
    sct.printv('\nGenerate labels from template vertebral labeling', verbose)
    sct.run('sct_label_utils -i '+fname_template_vertebral_labeling+' -vert-body 0 -o '+ftmp_template_label)

    # Check if provided labels are available in the template
    sct.printv('\nCheck if provided labels are available in the template', verbose)
    image_label_template = Image(ftmp_template_label)
    labels_template = image_label_template.getNonZeroCoordinates(sorting='value')
    # both lists are compared through their last element, i.e. the highest label value
    # for labels_template (sorted by value); labels is presumably sorted the same way
    if labels[-1].value > labels_template[-1].value:
        sct.printv('ERROR: Wrong landmarks input. Labels must have correspondence in template space. \nLabel max '
                   'provided: ' + str(labels[-1].value) + '\nLabel max from template: ' +
                   str(labels_template[-1].value), verbose, 'error')

    # Binarize segmentation (in case it has values below 0 caused by manual editing)
    sct.printv('\nBinarize segmentation', verbose)
    sct.run('sct_maths -i seg.nii.gz -bin 0.5 -o seg.nii.gz')

    # N.B. the segmentation is deliberately NOT smoothed anymore
    # (jcohenadad, issue #613, updated 2016-06-16).

    # Switch between modes: subject->template or template->subject
    if ref == 'template':

        # resample data to 1mm isotropic
        sct.printv('\nResample data to 1mm isotropic...', verbose)
        sct.run('sct_resample -i '+ftmp_data+' -mm 1.0x1.0x1.0 -x linear -o '+add_suffix(ftmp_data, '_1mm'))
        ftmp_data = add_suffix(ftmp_data, '_1mm')
        sct.run('sct_resample -i '+ftmp_seg+' -mm 1.0x1.0x1.0 -x linear -o '+add_suffix(ftmp_seg, '_1mm'))
        ftmp_seg = add_suffix(ftmp_seg, '_1mm')
        # N.B. resampling of labels is more complicated, because they are single-point labels, therefore
        # resampling with nearest neighbour can make them disappear. Therefore a more clever approach is required.
        resample_labels(ftmp_label, ftmp_data, add_suffix(ftmp_label, '_1mm'))
        ftmp_label = add_suffix(ftmp_label, '_1mm')

        # Change orientation of input images to RPI
        sct.printv('\nChange orientation of input images to RPI...', verbose)
        sct.run('sct_image -i '+ftmp_data+' -setorient RPI -o '+add_suffix(ftmp_data, '_rpi'))
        ftmp_data = add_suffix(ftmp_data, '_rpi')
        sct.run('sct_image -i '+ftmp_seg+' -setorient RPI -o '+add_suffix(ftmp_seg, '_rpi'))
        ftmp_seg = add_suffix(ftmp_seg, '_rpi')
        sct.run('sct_image -i '+ftmp_label+' -setorient RPI -o '+add_suffix(ftmp_label, '_rpi'))
        ftmp_label = add_suffix(ftmp_label, '_rpi')

        # crop segmentation along z (the command output is not needed afterwards)
        sct.run('sct_crop_image -i '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_crop')+' -dim 2 -bzmax', verbose)
        ftmp_seg = add_suffix(ftmp_seg, '_crop')

        # straighten segmentation
        sct.printv('\nStraighten the spinal cord using centerline/segmentation...', verbose)
        # check if warp_curve2straight and warp_straight2curve already exist (i.e. no need to do it another time)
        if os.path.isfile('../warp_curve2straight.nii.gz') and os.path.isfile('../warp_straight2curve.nii.gz') and os.path.isfile('../straight_ref.nii.gz'):
            # if they exist, copy them into current folder
            sct.printv('WARNING: Straightening was already run previously. Copying warping fields...', verbose, 'warning')
            shutil.copy('../warp_curve2straight.nii.gz', 'warp_curve2straight.nii.gz')
            shutil.copy('../warp_straight2curve.nii.gz', 'warp_straight2curve.nii.gz')
            shutil.copy('../straight_ref.nii.gz', 'straight_ref.nii.gz')
            # apply straightening
            sct.run('sct_apply_transfo -i '+ftmp_seg+' -w warp_curve2straight.nii.gz -d straight_ref.nii.gz -o '+add_suffix(ftmp_seg, '_straight'))
        else:
            sct.run('sct_straighten_spinalcord -i '+ftmp_seg+' -s '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_straight')+' -qc 0 -r 0 -v '+str(verbose), verbose)
        # N.B. DO NOT UPDATE VARIABLE ftmp_seg BECAUSE TEMPORARY USED LATER
        # re-define warping field using non-cropped space (to avoid issue #367)
        sct.run('sct_concat_transfo -w warp_straight2curve.nii.gz -d '+ftmp_data+' -o warp_straight2curve.nii.gz')

        # Label preparation:
        # --------------------------------------------------------------------------------
        # Remove unused label on template. Keep only label present in the input label image
        sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
        sct.run('sct_label_utils -i '+ftmp_template_label+' -o '+ftmp_template_label+' -remove '+ftmp_label)

        # Dilating the input label so they can be straighten without losing them
        sct.printv('\nDilating input labels using 3vox ball radius')
        sct.run('sct_maths -i '+ftmp_label+' -o '+add_suffix(ftmp_label, '_dilate')+' -dilate 3')
        ftmp_label = add_suffix(ftmp_label, '_dilate')

        # Apply straightening to labels
        sct.printv('\nApply straightening to labels...', verbose)
        sct.run('sct_apply_transfo -i '+ftmp_label+' -o '+add_suffix(ftmp_label, '_straight')+' -d '+add_suffix(ftmp_seg, '_straight')+' -w warp_curve2straight.nii.gz -x nn')
        ftmp_label = add_suffix(ftmp_label, '_straight')

        # Compute transformation straight landmarks --> template landmarks
        sct.printv('\nEstimate transformation for step #0...', verbose)
        from msct_register_landmarks import register_landmarks
        try:
            register_landmarks(ftmp_label, ftmp_template_label, paramreg.steps['0'].dof, fname_affine='straight2templateAffine.txt', verbose=verbose)
        except Exception:
            # any failure of the landmark registration is reported as a label-placement problem
            sct.printv('ERROR: input labels do not seem to be at the right place. Please check the position of the labels. See documentation for more details: https://sourceforge.net/p/spinalcordtoolbox/wiki/create_labels/', verbose=verbose, type='error')

        # Concatenate transformations: curve --> straight --> affine
        sct.printv('\nConcatenate transformations: curve --> straight --> affine...', verbose)
        sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt -d template.nii -o warp_curve2straightAffine.nii.gz')

        # Apply transformation
        sct.printv('\nApply transformation...', verbose)
        sct.run('sct_apply_transfo -i '+ftmp_data+' -o '+add_suffix(ftmp_data, '_straightAffine')+' -d '+ftmp_template+' -w warp_curve2straightAffine.nii.gz')
        ftmp_data = add_suffix(ftmp_data, '_straightAffine')
        sct.run('sct_apply_transfo -i '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_straightAffine')+' -d '+ftmp_template+' -w warp_curve2straightAffine.nii.gz -x linear')
        ftmp_seg = add_suffix(ftmp_seg, '_straightAffine')

        # Disabled step (kept for the record). Benjamin: issue from Allan Martin, about the
        # z=0 slice that is screwed up, caused by the affine transform. Solution found at the
        # time: remove slices below and above landmarks to avoid rotation effects, i.e. crop
        # ftmp_seg between the min and max z of the template landmarks with
        # 'sct_crop_image ... -dim 2 -b 0' (suffix '_black').

        # binarize
        sct.printv('\nBinarize segmentation...', verbose)
        sct.run('sct_maths -i '+ftmp_seg+' -bin 0.5 -o '+add_suffix(ftmp_seg, '_bin'))
        ftmp_seg = add_suffix(ftmp_seg, '_bin')

        # find min-max of anat2template (for subsequent cropping)
        zmin_template, zmax_template = find_zmin_zmax(ftmp_seg)

        # crop template in z-direction (for faster processing)
        sct.printv('\nCrop data in template space (for faster processing)...', verbose)
        sct.run('sct_crop_image -i '+ftmp_template+' -o '+add_suffix(ftmp_template, '_crop')+' -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
        ftmp_template = add_suffix(ftmp_template, '_crop')
        sct.run('sct_crop_image -i '+ftmp_template_seg+' -o '+add_suffix(ftmp_template_seg, '_crop')+' -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
        ftmp_template_seg = add_suffix(ftmp_template_seg, '_crop')
        sct.run('sct_crop_image -i '+ftmp_data+' -o '+add_suffix(ftmp_data, '_crop')+' -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
        ftmp_data = add_suffix(ftmp_data, '_crop')
        sct.run('sct_crop_image -i '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_crop')+' -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
        ftmp_seg = add_suffix(ftmp_seg, '_crop')

        # sub-sample in z-direction
        sct.printv('\nSub-sample in z-direction (for faster processing)...', verbose)
        sct.run('sct_resample -i '+ftmp_template+' -o '+add_suffix(ftmp_template, '_sub')+' -f 1x1x'+zsubsample, verbose)
        ftmp_template = add_suffix(ftmp_template, '_sub')
        sct.run('sct_resample -i '+ftmp_template_seg+' -o '+add_suffix(ftmp_template_seg, '_sub')+' -f 1x1x'+zsubsample, verbose)
        ftmp_template_seg = add_suffix(ftmp_template_seg, '_sub')
        sct.run('sct_resample -i '+ftmp_data+' -o '+add_suffix(ftmp_data, '_sub')+' -f 1x1x'+zsubsample, verbose)
        ftmp_data = add_suffix(ftmp_data, '_sub')
        sct.run('sct_resample -i '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_sub')+' -f 1x1x'+zsubsample, verbose)
        ftmp_seg = add_suffix(ftmp_seg, '_sub')

        # Registration straight spinal cord to template
        sct.printv('\nRegister straight spinal cord to template...', verbose)

        # loop across registration steps (step 0 is the landmark-based transformation done above)
        warp_forward = []
        warp_inverse = []
        for i_step in range(1, len(paramreg.steps)):
            sct.printv('\nEstimate transformation for step #'+str(i_step)+'...', verbose)
            # identify which is the src and dest
            if paramreg.steps[str(i_step)].type == 'im':
                src = ftmp_data
                dest = ftmp_template
                interp_step = 'linear'
            elif paramreg.steps[str(i_step)].type == 'seg':
                src = ftmp_seg
                dest = ftmp_template_seg
                interp_step = 'nn'
            else:
                sct.printv('ERROR: Wrong image type.', 1, 'error')
            # if step>1, apply the transformations of the previous steps to the src image,
            # and use the result as new src for registration
            if i_step > 1:
                sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+add_suffix(src, '_regStep'+str(i_step-1))+' -x '+interp_step, verbose)
                src = add_suffix(src, '_regStep'+str(i_step-1))
            # register src --> dest
            # TODO: display param for debugging
            warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
            warp_forward.append(warp_forward_out)
            warp_inverse.append(warp_inverse_out)

        # Concatenate transformations:
        sct.printv('\nConcatenate transformations: anat --> template...', verbose)
        sct.run('sct_concat_transfo -w warp_curve2straightAffine.nii.gz,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
        sct.printv('\nConcatenate transformations: template --> anat...', verbose)
        # inverse warps are listed in reverse order of the forward steps
        warp_inverse.reverse()
        sct.run('sct_concat_transfo -w '+','.join(warp_inverse)+',-straight2templateAffine.txt,warp_straight2curve.nii.gz -d data.nii -o warp_template2anat.nii.gz', verbose)

    # register template->subject
    elif ref == 'subject':

        # Change orientation of input images to RPI
        sct.printv('\nChange orientation of input images to RPI...', verbose)
        sct.run('sct_image -i ' + ftmp_data + ' -setorient RPI -o ' + add_suffix(ftmp_data, '_rpi'))
        ftmp_data = add_suffix(ftmp_data, '_rpi')
        sct.run('sct_image -i ' + ftmp_seg + ' -setorient RPI -o ' + add_suffix(ftmp_seg, '_rpi'))
        ftmp_seg = add_suffix(ftmp_seg, '_rpi')
        sct.run('sct_image -i ' + ftmp_label + ' -setorient RPI -o ' + add_suffix(ftmp_label, '_rpi'))
        ftmp_label = add_suffix(ftmp_label, '_rpi')

        # Remove unused label on template. Keep only label present in the input label image
        sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
        sct.run('sct_label_utils -i '+ftmp_template_label+' -o '+ftmp_template_label+' -remove '+ftmp_label)

        # Add one label because at least 3 orthogonal labels are required to estimate an affine transformation.
        # The new label (value 99) is placed at the level of the first landmark returned (landmarks are sorted
        # by value), shifted along the x axis by 5.0 / px voxels (5 mm, assuming px is the x voxel size in mm).
        from copy import deepcopy
        for i_file in [ftmp_label, ftmp_template_label]:
            im_label = Image(i_file)
            coord_label = im_label.getCoordinatesAveragedByValue()  # N.B. landmarks are sorted by value
            # Create new label
            new_label = deepcopy(coord_label[0])
            # shift it along x
            nx, ny, nz, nt, px, py, pz, pt = im_label.dim
            new_label.x = round(coord_label[0].x + 5.0 / px)
            # assign value 99
            new_label.value = 99
            # Add to existing image
            im_label.data[int(new_label.x), int(new_label.y), int(new_label.z)] = new_label.value
            # Overwrite label file
            im_label.save()

        # Bring template to subject space using landmark-based transformation
        sct.printv('\nEstimate transformation for step #0...', verbose)
        from msct_register_landmarks import register_landmarks
        warp_forward = ['template2subjectAffine.txt']
        warp_inverse = ['-template2subjectAffine.txt']
        try:
            register_landmarks(ftmp_template_label, ftmp_label, paramreg.steps['0'].dof, fname_affine=warp_forward[0], verbose=verbose, path_qc=param.path_qc)
        except Exception:
            # any failure of the landmark registration is reported as a label-placement problem
            sct.printv('ERROR: input labels do not seem to be at the right place. Please check the position of the labels. See documentation for more details: https://sourceforge.net/p/spinalcordtoolbox/wiki/create_labels/', verbose=verbose, type='error')

        # loop across registration steps
        for i_step in range(1, len(paramreg.steps)):
            sct.printv('\nEstimate transformation for step #'+str(i_step)+'...', verbose)
            # identify which is the src and dest
            if paramreg.steps[str(i_step)].type == 'im':
                src = ftmp_template
                dest = ftmp_data
                interp_step = 'linear'
            elif paramreg.steps[str(i_step)].type == 'seg':
                src = ftmp_template_seg
                dest = ftmp_seg
                interp_step = 'nn'
            else:
                sct.printv('ERROR: Wrong image type.', 1, 'error')
            # apply transformation from previous step, to use as new src for registration
            sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+add_suffix(src, '_regStep'+str(i_step-1))+' -x '+interp_step, verbose)
            src = add_suffix(src, '_regStep'+str(i_step-1))
            # register src --> dest
            # TODO: display param for debugging
            warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
            warp_forward.append(warp_forward_out)
            # inverse warps are prepended so that they end up in reverse order of the forward steps
            warp_inverse.insert(0, warp_inverse_out)

        # Concatenate transformations:
        sct.printv('\nConcatenate transformations: template --> subject...', verbose)
        sct.run('sct_concat_transfo -w '+','.join(warp_forward)+' -d data.nii -o warp_template2anat.nii.gz', verbose)
        sct.printv('\nConcatenate transformations: subject --> template...', verbose)
        sct.run('sct_concat_transfo -w '+','.join(warp_inverse)+' -d template.nii -o warp_anat2template.nii.gz', verbose)

    # Apply warping fields to anat and template
    sct.run('sct_apply_transfo -i template.nii -o template2anat.nii.gz -d data.nii -w warp_template2anat.nii.gz -crop 1', verbose)
    sct.run('sct_apply_transfo -i data.nii -o anat2template.nii.gz -d template.nii -w warp_anat2template.nii.gz -crop 1', verbose)

    # come back to parent folder
    os.chdir('..')

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(path_tmp+'warp_template2anat.nii.gz', path_output+'warp_template2anat.nii.gz', verbose)
    sct.generate_output_file(path_tmp+'warp_anat2template.nii.gz', path_output+'warp_anat2template.nii.gz', verbose)
    sct.generate_output_file(path_tmp+'template2anat.nii.gz', path_output+'template2anat'+ext_data, verbose)
    sct.generate_output_file(path_tmp+'anat2template.nii.gz', path_output+'anat2template'+ext_data, verbose)
    if ref == 'template':
        # copy straightening files in case subsequent SCT functions need them
        sct.generate_output_file(path_tmp+'warp_curve2straight.nii.gz', path_output+'warp_curve2straight.nii.gz', verbose)
        sct.generate_output_file(path_tmp+'warp_straight2curve.nii.gz', path_output+'warp_straight2curve.nii.gz', verbose)
        sct.generate_output_file(path_tmp+'straight_ref.nii.gz', path_output+'straight_ref.nii.gz', verbose)

    # Delete temporary files
    if remove_temp_files:
        sct.printv('\nDelete temporary files...', verbose)
        sct.run('rm -rf '+path_tmp)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: '+str(int(round(elapsed_time)))+'s', verbose)

    # to view results
    sct.printv('\nTo view results, type:', verbose)
    sct.printv('fslview '+fname_data+' '+path_output+'template2anat -b 0,4000 &', verbose, 'info')
    sct.printv('fslview '+fname_template+' -b 0,5000 '+path_output+'anat2template &\n', verbose, 'info')
class ProcessLabels(object):
    def __init__(self, fname_label, fname_output=None, fname_ref=None, cross_radius=5, dilate=False,
                 coordinates=None, verbose=1, vertebral_levels=None):
        """
        Load the label image (and the optional reference image) and store the processing options.

        :param fname_label: label image, handed to Image()
        :param fname_output: output file name or list of output file names; a one-element list is unwrapped
        :param fname_ref: optional reference image, handed to Image() when it is not None
        :param cross_radius: radius used by the 'cross' and 'plan' processes
        :param dilate: dilation flag forwarded to get_crosses_coordinates by the 'cross' process
        :param coordinates: coordinates written by the 'create' and 'add' processes
        :param verbose: verbosity level, also forwarded to Image()
        :param vertebral_levels: levels forwarded to the 'label-vertebrae*' processes
        """
        self.image_input = Image(fname_label, verbose=verbose)
        self.image_ref = Image(fname_ref, verbose=verbose) if fname_ref is not None else None

        # a list holding a single name is stored as that name; anything else is stored untouched
        is_single_name_list = isinstance(fname_output, list) and len(fname_output) == 1
        self.fname_output = fname_output[0] if is_single_name_list else fname_output

        self.cross_radius = cross_radius
        self.vertebral_levels = vertebral_levels
        self.dilate = dilate
        self.coordinates = coordinates
        self.verbose = verbose

    def process(self, type_process):
        if type_process == 'cross':
            self.output_image = self.cross()
        elif type_process == 'plan':
            self.output_image = self.plan(self.cross_radius, 100, 5)
        elif type_process == 'plan_ref':
            self.output_image = self.plan_ref()
        elif type_process == 'increment':
            self.output_image = self.increment_z_inverse()
        elif type_process == 'disks':
            self.output_image = self.labelize_from_disks()
        elif type_process == 'MSE':
            self.MSE()
            self.fname_output = None
        elif type_process == 'remove':
            self.output_image = self.remove_label()
        elif type_process == 'remove-symm':
            self.output_image = self.remove_label(symmetry=True)
        elif type_process == 'centerline':
            self.extract_centerline()
        elif type_process == 'display-voxel':
            self.display_voxel()
            self.fname_output = None
        elif type_process == 'create':
            self.output_image = self.create_label()
        elif type_process == 'add':
            self.output_image = self.create_label(add=True)
        elif type_process == 'diff':
            self.diff()
            self.fname_output = None
        elif type_process == 'dist-inter':  # second argument is in pixel distance
            self.distance_interlabels(5)
            self.fname_output = None
        elif type_process == 'cubic-to-point':
            self.output_image = self.cubic_to_point()
        elif type_process == 'label-vertebrae':
            self.output_image = self.label_vertebrae(self.vertebral_levels)
        elif type_process == 'label-vertebrae-from-disks':
            self.output_image = self.label_vertebrae_from_disks(self.vertebral_levels)
        else:
            sct.printv('Error: The chosen process is not available.', 1, 'error')

        # save the output image as minimized integers
        if self.fname_output is not None:
            self.output_image.setFileName(self.fname_output)
            if type_process != 'plan_ref':
                self.output_image.save('minimize_int')
            else:
                self.output_image.save()

    @staticmethod
    def get_crosses_coordinates(coordinates_input, gapxy=15, image_ref=None, dilate=False, verbose=1):
        """
        Compute the coordinates of the crosses built around each input label.

        :param coordinates_input: iterable of Coordinate objects (x, y, z, value), one per label
        :param gapxy: gap forwarded to compute_cross / compute_cross_centerline
        :param image_ref: optional reference image (segmentation). When provided, its centerline is smoothed and
         each cross is drawn perpendicular to the centerline, using the centerline derivative at the label's z
        :param dilate: if True, every point of every cross is dilated to its 3x3x3 neighbourhood
        :param verbose: verbosity forwarded to smooth_centerline (only used when image_ref is provided)
        :return: list of Coordinate sorted by value; the i-th point (0-based) of the cross built on a label of
         value v carries the value v * 10 + i + 1
        """
        from msct_types import Coordinate

        # if reference image is provided (segmentation), we draw the cross perpendicular to the centerline
        if image_ref is not None:
            from sct_straighten_spinalcord import smooth_centerline, compute_cross_centerline
            from numpy import where
            # this is a static method: the reference image and verbosity come from the arguments, not from self
            x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(image_ref, verbose=verbose)
        else:
            from sct_straighten_spinalcord import compute_cross

        # offsets of the 26 neighbours of a voxel (the centre itself is excluded), in the order x, then y, then z
        steps = (0.0, 1.0, -1.0)
        neighbour_offsets = [(off_x, off_y, off_z) for off_x in steps for off_y in steps for off_z in steps
                             if (off_x, off_y, off_z) != (0.0, 0.0, 0.0)]

        # compute crosses
        cross_coordinates = []
        for coord in coordinates_input:
            if image_ref is None:
                cross_coordinates_temp = compute_cross(coord, gapxy)
            else:
                # derivative of the centerline at the slice of the label
                index_z = where(z_centerline == coord.z)
                deriv = Coordinate([x_centerline_deriv[index_z][0], y_centerline_deriv[index_z][0], z_centerline_deriv[index_z][0], 0.0])
                cross_coordinates_temp = compute_cross_centerline(coord, deriv, gapxy)

            for i, coord_cross in enumerate(cross_coordinates_temp):
                coord_cross.value = coord.value * 10 + i + 1

            # dilate cross to 3x3x3
            if dilate:
                additional_coordinates = [
                    Coordinate([coord_temp.x + off_x, coord_temp.y + off_y, coord_temp.z + off_z, coord_temp.value])
                    for coord_temp in cross_coordinates_temp
                    for off_x, off_y, off_z in neighbour_offsets]
                cross_coordinates_temp.extend(additional_coordinates)

            cross_coordinates.extend(cross_coordinates_temp)

        cross_coordinates = sorted(cross_coordinates, key=lambda obj: obj.value)
        return cross_coordinates

    # JULIEN <<<<<<
    # OLD IMPLEMENTATION:
    # def cross(self):
    #     """
    #     create a cross.
    #     :return:
    #     """
    #     image_output = Image(self.image_input, self.verbose)
    #     nx, ny, nz, nt, px, py, pz, pt = Image(self.image_input.absolutepath).dim
    #
    #     coordinates_input = self.image_input.getNonZeroCoordinates()
    #     d = self.cross_radius  # cross radius in pixel
    #     dx = d / px  # cross radius in mm
    #     dy = d / py
    #
    #     # for all points with non-zeros neighbors, force the neighbors to 0
    #     for coord in coordinates_input:
    #         image_output.data[coord.x][coord.y][coord.z] = 0  # remove point on the center of the spinal cord
    #         image_output.data[coord.x][coord.y + dy][
    #             coord.z] = coord.value * 10 + 1  # add point at distance from center of spinal cord
    #         image_output.data[coord.x + dx][coord.y][coord.z] = coord.value * 10 + 2
    #         image_output.data[coord.x][coord.y - dy][coord.z] = coord.value * 10 + 3
    #         image_output.data[coord.x - dx][coord.y][coord.z] = coord.value * 10 + 4
    #
    #         # dilate cross to 3x3
    #         if self.dilate:
    #             image_output.data[coord.x - 1][coord.y + dy - 1][coord.z] = image_output.data[coord.x][coord.y + dy - 1][coord.z] = \
    #                 image_output.data[coord.x + 1][coord.y + dy - 1][coord.z] = image_output.data[coord.x + 1][coord.y + dy][coord.z] = \
    #                 image_output.data[coord.x + 1][coord.y + dy + 1][coord.z] = image_output.data[coord.x][coord.y + dy + 1][coord.z] = \
    #                 image_output.data[coord.x - 1][coord.y + dy + 1][coord.z] = image_output.data[coord.x - 1][coord.y + dy][coord.z] = \
    #                 image_output.data[coord.x][coord.y + dy][coord.z]
    #             image_output.data[coord.x + dx - 1][coord.y - 1][coord.z] = image_output.data[coord.x + dx][coord.y - 1][coord.z] = \
    #                 image_output.data[coord.x + dx + 1][coord.y - 1][coord.z] = image_output.data[coord.x + dx + 1][coord.y][coord.z] = \
    #                 image_output.data[coord.x + dx + 1][coord.y + 1][coord.z] = image_output.data[coord.x + dx][coord.y + 1][coord.z] = \
    #                 image_output.data[coord.x + dx - 1][coord.y + 1][coord.z] = image_output.data[coord.x + dx - 1][coord.y][coord.z] = \
    #                 image_output.data[coord.x + dx][coord.y][coord.z]
    #             image_output.data[coord.x - 1][coord.y - dy - 1][coord.z] = image_output.data[coord.x][coord.y - dy - 1][coord.z] = \
    #                 image_output.data[coord.x + 1][coord.y - dy - 1][coord.z] = image_output.data[coord.x + 1][coord.y - dy][coord.z] = \
    #                 image_output.data[coord.x + 1][coord.y - dy + 1][coord.z] = image_output.data[coord.x][coord.y - dy + 1][coord.z] = \
    #                 image_output.data[coord.x - 1][coord.y - dy + 1][coord.z] = image_output.data[coord.x - 1][coord.y - dy][coord.z] = \
    #                 image_output.data[coord.x][coord.y - dy][coord.z]
    #             image_output.data[coord.x - dx - 1][coord.y - 1][coord.z] = image_output.data[coord.x - dx][coord.y - 1][coord.z] = \
    #                 image_output.data[coord.x - dx + 1][coord.y - 1][coord.z] = image_output.data[coord.x - dx + 1][coord.y][coord.z] = \
    #                 image_output.data[coord.x - dx + 1][coord.y + 1][coord.z] = image_output.data[coord.x - dx][coord.y + 1][coord.z] = \
    #                 image_output.data[coord.x - dx - 1][coord.y + 1][coord.z] = image_output.data[coord.x - dx - 1][coord.y][coord.z] = \
    #                 image_output.data[coord.x - dx][coord.y][coord.z]
    #
    #     return image_output
    # >>>>>>>>>
    def cross(self):
        """
        Create an image containing a cross around each label of the input image.

        The cross points are computed by get_crosses_coordinates from the non-zero voxels of the input image
        (using self.image_ref and self.dilate) and written into an otherwise empty copy of the input image.

        :return: output image with the crosses
        """
        output_image = Image(self.image_input, self.verbose)
        nx, ny, nz, nt, px, py, pz, pt = Image(self.image_input.absolutepath).dim

        coordinates_input = self.image_input.getNonZeroCoordinates()
        # NOTE(review): the cross radius is divided by the 5th entry of dim (px) before being used as the gap;
        # the original comments labelled the radius "in pixel" and the quotient "in mm", which looks swapped
        # if px is the voxel size in mm -- confirm against Image.dim
        dx = self.cross_radius / px

        # clean output_image
        output_image.data *= 0

        cross_coordinates = self.get_crosses_coordinates(coordinates_input, dx, self.image_ref, self.dilate)

        for coord in cross_coordinates:
            # round() returns a float on Python 2: cast to int, array indices must be integers
            output_image.data[int(round(coord.x)), int(round(coord.y)), int(round(coord.z))] = coord.value

        return output_image
    # >>>

    def plan(self, width, offset=0, gap=1):
        """
        Create a slab of slices around each label of the input image.

        For a label at slice z with value v, slices z - width (included) to z + width (excluded) of the
        output are set to offset + gap * v. Later labels overwrite earlier ones where slabs overlap.

        :param width: half-thickness of the slab, in slices
        :param offset: constant added to the value written in the slab
        :param gap: factor applied to the label value
        :return: output image
        """
        image_output = Image(self.image_input, self.verbose)
        image_output.data *= 0
        coordinates_input = self.image_input.getNonZeroCoordinates()

        for coord in coordinates_input:
            z = int(round(coord.z))
            # clamp the lower bound: a negative start would be interpreted as an index counted from the end
            # of the axis, which selects the wrong slices (or none) for labels close to the first slice
            z_start = max(0, z - width)
            image_output.data[:, :, z_start:z + width] = offset + gap * coord.value

        return image_output

    def plan_ref(self):
        """
        Generate a plane in the reference space for each label present in the input image.

        Negative and positive labels are separated into two images so that getNonZeroCoordinates can be
        applied to each of them; every slice of the reference image holding a label is then filled with the
        (signed) value of that label. Positive labels are written last and win on shared slices.

        :return: float32 image in the space of self.image_ref
        """
        image_output = Image(self.image_ref, self.verbose)
        image_output.data *= 0

        image_input_neg = Image(self.image_input, self.verbose).copy()
        image_input_pos = Image(self.image_input, self.verbose).copy()
        image_input_neg.data *= 0
        image_input_pos.data *= 0

        # negative labels are stored with their sign flipped, in order to apply getNonZeroCoordinates
        mask_neg = self.image_input.data < 0
        image_input_neg.data[mask_neg] = -self.image_input.data[mask_neg]
        mask_pos = self.image_input.data > 0
        image_input_pos.data[mask_pos] = self.image_input.data[mask_pos]

        coordinates_input_neg = image_input_neg.getNonZeroCoordinates()
        coordinates_input_pos = image_input_pos.getNonZeroCoordinates()

        image_output.changeType('float32')
        # NOTE(review): the original author flagged that coord.value may come back truncated to an integer here
        for coord in coordinates_input_neg:
            # slice index must be an integer
            image_output.data[:, :, int(coord.z)] = -coord.value
        for coord in coordinates_input_pos:
            image_output.data[:, :, int(coord.z)] = coord.value

        return image_output

    def cubic_to_point(self):
        """
        Compute the center of mass of each group of labels and return an image of the same size with a single
        label per group, located at the center of mass of this group.
        It is meant to be used after applying a homothetic warping field to a label file, as the labels will be
        dilated.
        Be careful: the center of mass is computed over all voxels sharing the same value. If two groups of
        voxels with the same value are separated in space, a single center of mass is computed for both.

        :return: output image
        """
        # 0. Initialization of output image
        output_image = self.image_input.copy()
        output_image.data *= 0

        # 1. Extraction of coordinates from all non-null voxels in the image. Coordinates are sorted by value.
        coordinates = self.image_input.getNonZeroCoordinates(sorting='value')

        # 2. Separate all coordinates into groups by value
        groups = dict()
        for coord in coordinates:
            groups.setdefault(coord.value, []).append(coord)

        # 3. Compute the center of mass of each group of voxels and write them into the output image
        for value, list_coord in groups.items():
            center_of_mass = sum(list_coord)/float(len(list_coord))
            rounded_x = round(center_of_mass.x)
            rounded_y = round(center_of_mass.y)
            rounded_z = round(center_of_mass.z)
            sct.printv("Value = " + str(center_of_mass.value) + " : ("+str(center_of_mass.x) + ", "+str(center_of_mass.y) + ", " + str(center_of_mass.z) + ") --> ( "+ str(rounded_x) + ", " + str(rounded_y) + ", " + str(rounded_z) + ")", verbose=self.verbose)
            # round() returns a float on Python 2: cast to int, array indices must be integers
            output_image.data[int(rounded_x), int(rounded_y), int(rounded_z)] = center_of_mass.value

        return output_image

    def increment_z_inverse(self):
        """
        Renumber the labels of the input image 1, 2, 3, ... following the order returned by
        getNonZeroCoordinates(sorting='z', reverse_coord=True), i.e. inversely ordered by Z.
        Labels are therefore numbered from top to bottom, assuming a RPI orientation.
        Labels are assumed to be non-zero.

        :return: output image holding the renumbered labels
        """
        image_output = Image(self.image_input, self.verbose)
        image_output.data *= 0
        labels_by_decreasing_z = self.image_input.getNonZeroCoordinates(sorting='z', reverse_coord=True)

        # the first label of the list gets 1, the next one 2, and so on
        new_value = 0
        for label in labels_by_decreasing_z:
            new_value += 1
            image_output.data[label.x, label.y, label.z] = new_value

        return image_output

    def labelize_from_disks(self):
        """
        Label each non-zero voxel of the input image with a value taken from the reference labels.
        Typically, the input is a segmentation and the reference holds the disks position; the result is a
        segmentation labelled by vertebral level.
        Reference labels are taken sorted by value. A voxel receives the value of a reference label when its z
        is at or below that label's z and strictly above the z of the next reference label.
        Labels are assumed to be non-zero and incremented from top to bottom, assuming a RPI orientation.

        :return: output image
        """
        image_output = Image(self.image_input, self.verbose)
        image_output.data *= 0
        voxels = self.image_input.getNonZeroCoordinates()
        disks = self.image_ref.getNonZeroCoordinates(sorting='value')

        # for every voxel, look at each pair of consecutive reference labels
        for voxel in voxels:
            for current_disk, next_disk in zip(disks[:-1], disks[1:]):
                if next_disk.z < voxel.z <= current_disk.z:
                    image_output.data[voxel.x, voxel.y, voxel.z] = current_disk.value

        return image_output


    def label_vertebrae(self, levels_user=None):
        """
        Find the center of mass of the vertebral levels specified by the user.

        :param levels_user: list of integer levels to keep; None keeps every level found in the image
        :return: image with one label per kept level, located at the center of mass of that level
        """
        # get center of mass of each vertebral level
        image_cubic2point = self.cubic_to_point()
        # get list of coordinates for each label
        list_coordinates = image_cubic2point.getNonZeroCoordinates(sorting='value')
        # if user did not specify levels, include all:
        if levels_user is None:
            levels_user = [int(coord.value) for coord in list_coordinates]
        # loop across labels and remove those that are not listed by the user
        for coord in list_coordinates:
            if int(coord.value) not in levels_user:
                image_cubic2point.data[int(coord.x), int(coord.y), int(coord.z)] = 0

        return image_cubic2point

    def label_vertebrae_from_disks(self, levels_user=None):
        """
        Build vertebral labels from disk labels.

        The center of mass of each disk label is computed, then one label is created halfway between each pair
        of consecutive disks (sorted by value). Its value is the average of the two disk values and it is kept
        when the integer part of that value is listed in levels_user.

        :param levels_user: list of integer levels to keep; None keeps every level (as label_vertebrae does)
        :return: image with the vertebral labels
        """
        image_cubic2point = self.cubic_to_point()
        # get list of coordinates for each label
        list_coordinates_disks = image_cubic2point.getNonZeroCoordinates(sorting='value')
        image_cubic2point.data *= 0
        # compute vertebral labels from disk labels
        list_coordinates_vertebrae = [(first_disk + second_disk) / 2.0
                                      for first_disk, second_disk in zip(list_coordinates_disks[:-1], list_coordinates_disks[1:])]
        # loop across labels and write those that are listed by the user
        for coord in list_coordinates_vertebrae:
            # process() passes self.vertebral_levels, which defaults to None: treat it as "keep everything"
            if levels_user is None or int(coord.value) in levels_user:
                image_cubic2point.data[int(coord.x), int(coord.y), int(coord.z)] = coord.value

        return image_cubic2point


    def symmetrizer(self, side='left'):
        """
        Placeholder for the symmetrization of the input image: one side of the image is meant to be copied on
        the other side, assuming a RPI orientation.

        NOTE(review): the symmetrization itself is not implemented. The method currently builds an Image from
        the input image and returns it; the `side` parameter is never read.

        :param side: string 'left' or 'right'. Side that is meant to be copied on the other side (unused).
        :return: image built from the input image
        """
        image_output = Image(self.image_input, self.verbose)

        # NOTE(review): expression statement whose result is discarded; what it does depends on how Image
        # implements subscripting -- confirm whether it can be removed
        image_output[0:]

        # the bare string below is not a docstring, only an inert expression statement kept as a reminder of
        # the Matlab code the implementation is meant to be based on
        """inspiration: (from atlas creation matlab script)
        temp_sum = temp_g + temp_d;
        temp_sum_flip = temp_sum(end:-1:1,:);
        temp_sym = (temp_sum + temp_sum_flip) / 2;

        temp_g(1:end / 2,:) = 0;
        temp_g(1 + end / 2:end,:) = temp_sym(1 + end / 2:end,:);
        temp_d(1:end / 2,:) = temp_sym(1:end / 2,:);
        temp_d(1 + end / 2:end,:) = 0;

        tractsHR
        {label_l}(:,:, num_slice_ref) = temp_g;
        tractsHR
        {label_r}(:,:, num_slice_ref) = temp_d;
        """

        return image_output

    def MSE(self, threshold_mse=0):
        """
        Compute the error in the Z direction between two sets of labels (input and ref).

        Labels are matched on their rounded value. The result is the square root of the sum of squared z
        differences of the matched labels, divided by the number of input labels.
        A warning is printed for each label mismatch.
        If the result is above threshold_mse, a log file 'error_log_<input file name>.txt' is written in the
        folder of the input image.

        :param threshold_mse: threshold above which the log file is written (default 0)
        :return: the computed error
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates()

        # rounded label values of each image, computed once instead of once per tested label
        values_input = set(round(coord.value) for coord in coordinates_input)
        values_ref = set(round(coord_ref.value) for coord_ref in coordinates_ref)

        # check if all the labels in both the images match
        if len(coordinates_input) != len(coordinates_ref):
            sct.printv('ERROR: labels mismatch', 1, 'warning')
        for coord in coordinates_input:
            if round(coord.value) not in values_ref:
                sct.printv('ERROR: labels mismatch', 1, 'warning')
        for coord_ref in coordinates_ref:
            if round(coord_ref.value) not in values_input:
                sct.printv('ERROR: labels mismatch', 1, 'warning')

        # z of the first reference label found for each rounded value
        z_ref_by_value = {}
        for coord_ref in coordinates_ref:
            z_ref_by_value.setdefault(round(coord_ref.value), coord_ref.z)

        result = 0.0
        for coord in coordinates_input:
            value = round(coord.value)
            if value in z_ref_by_value:
                result += (z_ref_by_value[value] - coord.z) ** 2
        result = math.sqrt(result / len(coordinates_input))
        sct.printv('MSE error in Z direction = ' + str(result) + ' mm')

        if result > threshold_mse:
            # the context manager closes the log file even if writing fails
            with open(self.image_input.path + 'error_log_' + self.image_input.file_name + '.txt', 'w') as f:
                f.write(
                    'The labels error (MSE) between ' + self.image_input.file_name + ' and ' + self.image_ref.file_name + ' is: ' + str(
                        result))

        return result

    def create_label(self, add=False):
        """
        Create an image holding the labels listed in self.coordinates.
        This method works only if the coordinates provided by the user are correct.

        self.coordinates is a list of coordinates (class in msct_types); a Coordinate contains x, y, z and
        value. Even a single label has to be provided inside a list.
        Examples:
        For one label:  object_define=ProcessLabels( fname_label, coordinates=[coordi]) where coordi is a 'Coordinate' object from msct_types
        For two labels: object_define=ProcessLabels( fname_label, coordinates=[coordi1, coordi2]) where coordi1 and coordi2 are 'Coordinate' objects from msct_types

        :param add: if True, the labels are written on top of the content of the input image; otherwise the
         output starts from an empty image
        :return: output image
        """
        image_output = self.image_input.copy()
        if not add:
            image_output.data *= 0

        # write each label and report it
        label_number = 0
        for label in self.coordinates:
            position = ','.join([str(label.x), str(label.y), str(label.z)])
            sct.printv('Label #' + str(label_number) + ': ' + position + ' --> ' + str(label.value), 1)
            image_output.data[label.x, label.y, label.z] = label.value
            label_number += 1

        return image_output

    @staticmethod
    def remove_label_coord(coord_input, coord_ref, symmetry=False):
        """
        coord_input and coord_ref should be sets of CoordinateValue in order to improve speed of intersection
        :param coord_input: set of CoordinateValue
        :param coord_ref: set of CoordinateValue
        :param symmetry: boolean,
        :return: intersection of CoordinateValue: list
        """
        from msct_types import CoordinateValue
        if isinstance(coord_input[0], CoordinateValue) and isinstance(coord_ref[0], CoordinateValue) and symmetry:
            coord_intersection = list(set(coord_input).intersection(set(coord_ref)))
            result_coord_input = [coord for coord in coord_input if coord in coord_intersection]
            result_coord_ref = [coord for coord in coord_ref if coord in coord_intersection]
        else:
            result_coord_ref = coord_ref
            result_coord_input = [coord for coord in coord_input if filter(lambda x: x.value == coord.value, coord_ref)]
            if symmetry:
                result_coord_ref = [coord for coord in coord_ref if filter(lambda x: x.value == coord.value, result_coord_input)]

        return result_coord_input, result_coord_ref

    def remove_label(self, symmetry=False):
        """
        Compare two label images and remove any label of the input image that is not in the reference image.
        With the symmetry option, labels of the reference image that are not in the input image are removed as
        well: the cleaned reference image is saved to self.fname_output[1], and self.fname_output is then
        reduced to its first item so that the returned image gets saved there.

        :param symmetry: boolean, also clean (and save) the reference image
        :return: input image without the labels missing from the reference
        """
        image_output = Image(self.image_input, verbose=self.verbose)
        image_output.data *= 0  # put all voxels to 0

        result_coord_input, result_coord_ref = self.remove_label_coord(self.image_input.getNonZeroCoordinates(coordValue=True),
                                                                       self.image_ref.getNonZeroCoordinates(coordValue=True), symmetry)

        for coord in result_coord_input:
            image_output.data[coord.x, coord.y, coord.z] = int(round(coord.value))

        if symmetry:
            image_output_ref = Image(self.image_ref, verbose=self.verbose)
            # start from an empty image, otherwise every reference label would survive in the output
            image_output_ref.data *= 0
            for coord in result_coord_ref:
                image_output_ref.data[coord.x, coord.y, coord.z] = int(round(coord.value))
            image_output_ref.setFileName(self.fname_output[1])
            image_output_ref.save('minimize_int')

            self.fname_output = self.fname_output[0]

        return image_output

    def extract_centerline(self):
        """
        This function write a text file with the coordinates of the centerline.
        The image is suppose to be RPI
        """
        coordinates_input = self.image_input.getNonZeroCoordinates(sorting='z')

        fo = open(self.fname_output, "wb")
        for coord in coordinates_input:
            line = (coord.x,coord.y, coord.z)
            fo.write("%i %i %i\n" % line)
        fo.close()

    def display_voxel(self):
        """
        This function displays all the labels that are contained in the input image.
        The image is suppose to be RPI to display voxels. But works also for other orientations
        """
        coordinates_input = self.image_input.getNonZeroCoordinates(sorting='z')
        useful_notation = ''
        for coord in coordinates_input:
            print 'Position=(' + str(coord.x) + ',' + str(coord.y) + ',' + str(coord.z) + ') -- Value= ' + str(coord.value)
            if useful_notation != '':
                useful_notation = useful_notation + ':'
            useful_notation = useful_notation + str(coord.x) + ',' + str(coord.y) + ',' + str(coord.z) + ',' + str(coord.value)
        print 'Useful notation:'
        print useful_notation
        return coordinates_input

    def diff(self):
        """
        This function detects any label mismatch between input image and reference image
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates()

        print "Label in input image that are not in reference image:"
        for coord in coordinates_input:
            isIn = False
            for coord_ref in coordinates_ref:
                if coord.value == coord_ref.value:
                    isIn = True
                    break
            if not isIn:
                print coord.value

        print "Label in ref image that are not in input image:"
        for coord_ref in coordinates_ref:
            isIn = False
            for coord in coordinates_input:
                if coord.value == coord_ref.value:
                    isIn = True
                    break
            if not isIn:
                print coord_ref.value

    def distance_interlabels(self, max_dist):
        """
        This function calculates the distances between each label in the input image.
        If a distance is larger than max_dist, a warning message is displayed.
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()

        # for all points with non-zeros neighbors, force the neighbors to 0
        for i in range(0, len(coordinates_input) - 1):
            dist = math.sqrt((coordinates_input[i].x - coordinates_input[i+1].x)**2 + (coordinates_input[i].y - coordinates_input[i+1].y)**2 + (coordinates_input[i].z - coordinates_input[i+1].z)**2)
            if dist < max_dist:
                print 'Warning: the distance between label ' + str(i) + '[' + str(coordinates_input[i].x) + ',' + str(coordinates_input[i].y) + ',' + str(
                    coordinates_input[i].z) + ']=' + str(coordinates_input[i].value) + ' and label ' + str(i+1) + '[' + str(
                    coordinates_input[i+1].x) + ',' + str(coordinates_input[i+1].y) + ',' + str(coordinates_input[i+1].z) + ']=' + str(
                    coordinates_input[i+1].value) + ' is larger than ' + str(max_dist) + '. Distance=' + str(dist)
def main():
    """
    Command-line entry point: estimate a landmark-based transformation and optionally apply it.

    Options read from sys.argv: -i input volume, -l subject landmarks, -R template landmarks,
    -o output name, -t transformation type, -w final interpolation, -c compose flag, -v verbosity.
    After checking that the three mandatory files exist, a transformation from the subject
    landmarks to the template landmarks is built with the method selected by -t:

    - 'affine': both landmark files are turned into crosses and an affine transform is written to n2t.txt.
    - 'nurbs':  the z offsets between landmarks of equal value are approximated with b_spline_nurbs and
                written as two displacement fields (subject to template and template to subject).
    - 'SyN':    landmarks are moved to the centre of the xy plane, a B-spline displacement field is
                estimated, only its z component is kept and extended, and the result is checked with
                the MSE process of sct_label_utils.

    The transformation is applied to the input volume only when -w is 'spline'.
    """
    
    
    # get path of the toolbox
    # NOTE(review): neither status nor path_sct is used later in this function.
    status, path_sct = getstatusoutput('echo $SCT_DIR')
    #print path_sct


    #Initialization
    fname = ''
    landmark = ''
    verbose = param.verbose
    output_name = 'aligned.nii.gz'
    template_landmark = ''
    final_warp = param.final_warp
    compose = param.compose
    transfo = 'affine'
        
    # Parse the command line.
    # NOTE(review): if usage() returns instead of exiting, opts is undefined in the loop below — confirm.
    try:
         opts, args = getopt.getopt(sys.argv[1:],'hi:l:o:R:t:w:c:v:')
    except getopt.GetoptError:
        usage()
    # NOTE: ("-i") is a parenthesised string, not a tuple, so each "opt in (...)" below is a substring test.
    for opt, arg in opts :
        if opt == '-h':
            usage()
        elif opt in ("-i"):
            fname = arg
        elif opt in ("-l"):
            landmark = arg       
        elif opt in ("-o"):
            output_name = arg  
        elif opt in ("-R"):
            template_landmark = arg
        elif opt in ("-t"):
            transfo = arg    
        elif opt in ("-w"):
            final_warp = arg
        elif opt in ("-c"):
            # parsed but not used by the active code below (the compose section is commented out)
            compose = int(arg)                          
        elif opt in ('-v'):
            verbose = int(arg)
    
    # display usage if a mandatory argument is not provided
    if fname == '' or landmark == '' or template_landmark == '' :
        usage()
        
    # only these interpolation names are accepted for -w
    if final_warp not in ['','spline','NN']:
        usage()
        
    # only these transformation names are accepted for -t
    if transfo not in ['affine', 'bspline', 'SyN', 'nurbs']:
        usage()       
    
    # check existence of input files
    print'\nCheck if file exists ...'
    
    sct.check_file_exist(fname)
    sct.check_file_exist(landmark)
    sct.check_file_exist(template_landmark)
    
    
        
    # Display arguments
    print'\nCheck input arguments...'
    print'  Input volume ...................... '+fname
    print'  Verbose ........................... '+str(verbose)

    # ---------------------------------------------------------------- affine
    if transfo == 'affine':
        print 'Creating cross using input landmarks\n...'
        sct.run('sct_label_utils -i ' + landmark + ' -o ' + 'cross_native.nii.gz -t cross ' )
    
        print 'Creating cross using template landmarks\n...'
        sct.run('sct_label_utils -i ' + template_landmark + ' -o ' + 'cross_template.nii.gz -t cross ' )
    
        print 'Computing affine transformation between subject and destination landmarks\n...'
        # called through os.system, so the exit status of the tool is not checked here
        os.system('isct_ANTSUseLandmarkImagesToGetAffineTransform cross_template.nii.gz cross_native.nii.gz affine n2t.txt')
        warping = 'n2t.txt'
    # ----------------------------------------------------------------- nurbs
    elif transfo == 'nurbs':
        warping_subject2template = 'warp_subject2template.nii.gz'
        warping_template2subject = 'warp_template2subject.nii.gz'
        # work inside a time-stamped temporary folder that contains a copy of the subject landmarks
        tmp_name = 'tmp.' + time.strftime("%y%m%d%H%M%S")
        sct.run('mkdir ' + tmp_name)
        tmp_abs_path = os.path.abspath(tmp_name)
        sct.run('cp ' + landmark + ' ' + tmp_abs_path)
        os.chdir(tmp_name)

        # NOTE(review): only the subject landmark file is copied; both files are then opened from inside
        # the temporary folder with the paths given on the command line — presumably those are either
        # bare file names (landmark) or absolute paths (template); verify.
        from msct_image import Image
        image_landmark = Image(landmark)
        image_template = Image(template_landmark)
        landmarks_input = image_landmark.getNonZeroCoordinates(sorting='value')
        landmarks_template = image_template.getNonZeroCoordinates(sorting='value')
        # range of label values present in either file (lists are sorted by value)
        min_value = min([int(landmarks_input[0].value), int(landmarks_template[0].value)])
        max_value = max([int(landmarks_input[-1].value), int(landmarks_template[-1].value)])
        nx, ny, nz, nt, px, py, pz, pt = image_landmark.dim

        # For every label value found in both files, store [0.0, z of the source label, z offset to the other space]
        displacement_subject2template, displacement_template2subject = [], []
        for value in range(min_value, max_value+1):
            is_in_input = False
            coord_input = None
            for coord in landmarks_input:
                if int(value) == int(coord.value):
                    coord_input = coord
                    is_in_input = True
                    break
            is_in_template = False
            coord_template = None
            for coord in landmarks_template:
                if int(value) == int(coord.value):
                    coord_template = coord
                    is_in_template = True
                    break
            if is_in_template and is_in_input:
                displacement_subject2template.append([0.0, coord_input.z, coord_template.z - coord_input.z])
                displacement_template2subject.append([0.0, coord_template.z, coord_input.z - coord_template.z])

        # create displacement field
        # Both fields are allocated with the dimensions of the subject landmark image and saved with a
        # copy of the template header.
        from numpy import zeros
        from nibabel import Nifti1Image, save
        data_warp_subject2template = zeros((nx, ny, nz, 1, 3))
        data_warp_template2subject = zeros((nx, ny, nz, 1, 3))
        hdr_warp = image_template.hdr.copy()
        hdr_warp.set_intent('vector', (), '')
        hdr_warp.set_data_dtype('float32')

        # approximate displacement with nurbs
        from msct_smooth import b_spline_nurbs
        # here displacement_z holds the z positions (item[1]) and displacement_x the z offsets (item[2])
        displacement_z = [item[1] for item in displacement_subject2template]
        displacement_x = [item[2] for item in displacement_subject2template]
        # NOTE: this overwrites the verbosity chosen with -v
        verbose = 1
        # NOTE(review): the meaning of the four values returned by b_spline_nurbs is not visible here;
        # the code below uses the second one as slice positions and the first one as offsets — confirm.
        displacement_z, displacement_y, displacement_y_deriv, displacement_z_deriv = b_spline_nurbs(displacement_x, displacement_z, None, nbControl=None, verbose=verbose, all_slices=True)

        # NOTE(review): np is not imported in this function; it must be available at module level.
        arg_min_z, arg_max_z = np.argmin(displacement_y), np.argmax(displacement_y)
        min_z, max_z = int(displacement_y[arg_min_z]), int(displacement_y[arg_max_z])
        # list of [slice position, offset]; positions outside [min_z, max_z] reuse the offset of the nearest end
        displac = []
        for index, iz in enumerate(displacement_y):
            displac.append([iz, displacement_z[index]])
        for iz in range(0, min_z):
            displac.append([iz, displacement_z[arg_min_z]])

        for iz in range(max_z, nz):
            displac.append([iz, displacement_z[arg_max_z]])

        # NOTE(review): the offsets derived from the subject-to-template list are written into the
        # template-to-subject field (and the other way round further down) — presumably because a warping
        # field is defined in the destination space; confirm.
        # The offset is multiplied by pz before being stored in the third vector component.
        for item in displac:
            if 0 <= item[0] < nz:
                data_warp_template2subject[:, :, item[0], 0, 2] = item[1] * pz

        # same procedure for the second list of displacements
        displacement_z = [item[1] for item in displacement_template2subject]
        displacement_x = [item[2] for item in displacement_template2subject]
        verbose = 1
        displacement_z, displacement_y, displacement_y_deriv, displacement_z_deriv = b_spline_nurbs(displacement_x, displacement_z, None, nbControl=None, verbose=verbose, all_slices=True)

        arg_min_z, arg_max_z = np.argmin(displacement_y), np.argmax(displacement_y)
        min_z, max_z = int(displacement_y[arg_min_z]), int(displacement_y[arg_max_z])
        displac = []
        for index, iz in enumerate(displacement_y):
            displac.append([iz, displacement_z[index]])
        for iz in range(0, min_z):
            displac.append([iz, displacement_z[arg_min_z]])
        for iz in range(max_z, nz):
            displac.append([iz, displacement_z[arg_max_z]])

        # NOTE(review): unlike the loop above, this one does not check 0 <= item[0] < nz before indexing.
        for item in displac:
            data_warp_subject2template[:, :, item[0], 0, 2] = item[1] * pz

        img = Nifti1Image(data_warp_template2subject, None, hdr_warp)
        save(img, warping_template2subject)
        sct.printv('\nDONE ! Warping field generated: ' + warping_template2subject, verbose)
        img = Nifti1Image(data_warp_subject2template, None, hdr_warp)
        save(img, warping_subject2template)
        sct.printv('\nDONE ! Warping field generated: ' + warping_subject2template, verbose)

        # Copy warping into parent folder
        sct.run('cp ' + warping_subject2template + ' ../' + warping_subject2template)
        sct.run('cp ' + warping_template2subject + ' ../' + warping_template2subject)
        warping = warping_subject2template

        os.chdir('..')
        # the temporary folder is always removed in this branch
        remove_temp_files = True
        if remove_temp_files:
            sct.run('rm -rf ' + tmp_name)

    # ------------------------------------------------------------------- SyN
    elif transfo == 'SyN':
        warping = 'warp_subject2template.nii.gz'
        # work inside a time-stamped temporary folder that contains a copy of the subject landmarks
        tmp_name = 'tmp.'+time.strftime("%y%m%d%H%M%S")
        sct.run('mkdir '+tmp_name)
        tmp_abs_path = os.path.abspath(tmp_name)
        sct.run('cp ' + landmark + ' ' + tmp_abs_path)
        os.chdir(tmp_name)

        # Previous implementation based on isct_antsRegistration, kept for reference:
        # sct.run('sct_label_utils -i '+landmark+' -t dist-inter')
        # sct.run('sct_label_utils -i '+template_landmark+' -t plan -o template_landmarks_plan.nii.gz -c 5')
        # sct.run('sct_crop_image -i template_landmarks_plan.nii.gz -o template_landmarks_plan_cropped.nii.gz -start 0.35,0.35 -end 0.65,0.65 -dim 0,1')
        # sct.run('sct_label_utils -i '+landmark+' -t plan -o landmarks_plan.nii.gz -c 5')
        # sct.run('sct_crop_image -i landmarks_plan.nii.gz -o landmarks_plan_cropped.nii.gz -start 0.35,0.35 -end 0.65,0.65 -dim 0,1')
        # sct.run('isct_antsRegistration --dimensionality 3 --transform SyN[0.5,3,0] --metric MeanSquares[template_landmarks_plan_cropped.nii.gz,landmarks_plan_cropped.nii.gz,1] --convergence 400x200 --shrink-factors 4x2 --smoothing-sigmas 4x2mm --restrict-deformation 0x0x1 --output [landmarks_reg,landmarks_reg.nii.gz] --interpolation NearestNeighbor --float')
        # sct.run('isct_c3d -mcs landmarks_reg0Warp.nii.gz -oo warp_vecx.nii.gz warp_vecy.nii.gz warp_vecz.nii.gz')
        # sct.run('isct_c3d warp_vecz.nii.gz -resample 200% -o warp_vecz_r.nii.gz')
        # sct.run('isct_c3d warp_vecz_r.nii.gz -smooth 0x0x3mm -o warp_vecz_r_sm.nii.gz')
        # sct.run('sct_crop_image -i warp_vecz_r_sm.nii.gz -o warp_vecz_r_sm_line.nii.gz -start 0.5,0.5 -end 0.5,0.5 -dim 0,1 -b 0')
        # sct.run('sct_label_utils -i warp_vecz_r_sm_line.nii.gz -t plan_ref -o warp_vecz_r_sm_line_extended.nii.gz -c 0 -r '+template_landmark)
        # sct.run('isct_c3d '+template_landmark+' warp_vecx.nii.gz -reslice-identity -o warp_vecx_res.nii.gz')
        # sct.run('isct_c3d '+template_landmark+' warp_vecy.nii.gz -reslice-identity -o warp_vecy_res.nii.gz')
        # sct.run('isct_c3d warp_vecx_res.nii.gz warp_vecy_res.nii.gz warp_vecz_r_sm_line_extended.nii.gz -omc 3 '+warping)  # no x?


        #new
        #put labels of the subject at the center of the image (for plan xOy)
        import nibabel
        from copy import copy
        file_labels_input = nibabel.load(landmark)
        hdr_labels_input = file_labels_input.get_header()
        data_labels_input = file_labels_input.get_data()
        # zeroed array with the shape of the subject label data
        data_labels_middle = copy(data_labels_input)
        data_labels_middle *= 0
        from msct_image import Image
        nx, ny, nz, nt, px, py, pz, pt = Image(landmark).dim
        X,Y,Z = data_labels_input.nonzero()
        # NOTE: the middle voxel is computed from the subject image and is also used for the template below
        x_middle = int(round(nx/2.0))
        y_middle = int(round(ny/2.0))

        #put labels of the template at the center of the image (for plan xOy)  #probably not necessary as already done by average labels
        file_labels_template = nibabel.load(template_landmark)
        hdr_labels_template = file_labels_template.get_header()
        data_labels_template = file_labels_template.get_data()
        data_template_middle = copy(data_labels_template)
        data_template_middle *= 0

        x, y, z = data_labels_template.nonzero()

        # Keep the same number of labels in both images: both sets are sorted by decreasing z and only
        # the first max_num labels of each are written (max_num is the smaller of the two counts).
        max_num = min([len(z), len(Z)])
        index_sort = np.argsort(Z)
        index_sort = index_sort[::-1]
        X = X[index_sort]
        Y = Y[index_sort]
        Z = Z[index_sort]
        index_sort = np.argsort(z)
        index_sort = index_sort[::-1]
        x = x[index_sort]
        y = y[index_sort]
        z = z[index_sort]

        # subject labels moved to (x_middle, y_middle) on their own slice
        for i in range(max_num):
            data_labels_middle[x_middle, y_middle, Z[i]] = data_labels_input[X[i], Y[i], Z[i]]
        img = nibabel.Nifti1Image(data_labels_middle, None, hdr_labels_input)
        nibabel.save(img, 'labels_input_middle_xy.nii.gz')

        # template labels moved to (x_middle, y_middle) on their own slice
        for i in range(max_num):
            data_template_middle[x_middle, y_middle, z[i]] = data_labels_template[x[i], y[i], z[i]]
        img_template = nibabel.Nifti1Image(data_template_middle, None, hdr_labels_template)
        nibabel.save(img_template, 'labels_template_middle_xy.nii.gz')


        #estimate Bspline transform to register to template
        sct.run('isct_ANTSUseLandmarkImagesToGetBSplineDisplacementField labels_template_middle_xy.nii.gz labels_input_middle_xy.nii.gz '+ warping+' 40x40x1 5 5 0')

        # select centerline of warping field according to z and extend it
        # The field is split into its three components, the z component is cropped to the central line and
        # extended with the plan_ref process, x and y are resliced onto the template landmark image, and
        # the three components are merged back into the warping file.
        sct.run('isct_c3d -mcs '+warping+' -oo warp_vecx.nii.gz warp_vecy.nii.gz warp_vecz.nii.gz')
        #sct.run('isct_c3d warp_vecz.nii.gz -resample 200% -o warp_vecz_r.nii.gz')
        #sct.run('isct_c3d warp_vecz.nii.gz -smooth 0x0x3mm -o warp_vecz_r_sm.nii.gz')
        sct.run('sct_crop_image -i warp_vecz.nii.gz -o warp_vecz_r_sm_line.nii.gz -start 0.5,0.5 -end 0.5,0.5 -dim 0,1 -b 0')
        sct.run('sct_label_utils -i warp_vecz_r_sm_line.nii.gz -t plan_ref -o warp_vecz_r_sm_line_extended.nii.gz -r '+template_landmark)
        sct.run('isct_c3d '+template_landmark+' warp_vecx.nii.gz -reslice-identity -o warp_vecx_res.nii.gz')
        sct.run('isct_c3d '+template_landmark+' warp_vecy.nii.gz -reslice-identity -o warp_vecy_res.nii.gz')
        sct.run('isct_c3d warp_vecx_res.nii.gz warp_vecy_res.nii.gz warp_vecz_r_sm_line_extended.nii.gz -omc 3 '+warping)

        # check results
        #dilate first labels
        sct.run('fslmaths labels_input_middle_xy.nii.gz -dilF landmark_dilated.nii.gz') #new
        sct.run('sct_apply_transfo -i landmark_dilated.nii.gz -o label_moved.nii.gz -d labels_template_middle_xy.nii.gz -w '+warping+' -x nn')
        #undilate
        sct.run('sct_label_utils -i label_moved.nii.gz -t cubic-to-point -o label_moved_2point.nii.gz')
        sct.run('sct_label_utils -i labels_template_middle_xy.nii.gz -r label_moved_2point.nii.gz -o template_removed.nii.gz -t remove')
        #end new


        # check results
        #dilate first labels

        #sct.run('fslmaths '+landmark+' -dilF landmark_dilated.nii.gz') #old
        #sct.run('sct_apply_transfo -i landmark_dilated.nii.gz -o label_moved.nii.gz -d '+template_landmark+' -w '+warping+' -x nn') #old



        #undilate
        #sct.run('sct_label_utils -i label_moved.nii.gz -t cubic-to-point -o label_moved_2point.nii.gz') #old

        #sct.run('sct_label_utils -i '+template_landmark+' -r label_moved_2point.nii.gz -o template_removed.nii.gz -t remove') #old


        # # sct.run('sct_apply_transfo -i '+landmark+' -o label_moved.nii.gz -d '+template_landmark+' -w '+warping+' -x nn')
        # # sct.run('sct_label_utils -i '+template_landmark+' -r label_moved.nii.gz -o template_removed.nii.gz -t remove')
        # # status, output = sct.run('sct_label_utils -i label_moved.nii.gz -r template_removed.nii.gz -t MSE')

        # compare the moved labels with the template labels; the output of the tool is printed
        status, output = sct.run('sct_label_utils -i label_moved_2point.nii.gz -r template_removed.nii.gz -t MSE')
        sct.printv(output,1,'info')
        # NOTE: remove_temp_files is False on both paths, so the temporary folder is never deleted in this branch
        remove_temp_files = False
        # an error log file in the working folder is recorded in log.txt together with the input name
        if os.path.isfile('error_log_label_moved.txt'):
            remove_temp_files = False
            with open('log.txt', 'a') as log_file:
                log_file.write('Error for '+fname+'\n')
        # Copy warping into parent folder
        sct.run('cp '+ warping+' ../'+warping)

        os.chdir('..')
        if remove_temp_files:
            sct.run('rm -rf '+tmp_name)



    # NOTE(review): 'bspline' is accepted for -t above, but its implementation is commented out, so
    # `warping` is never assigned for it and the 'spline' step below would raise a NameError.
    # if transfo == 'bspline' :
    #     print 'Computing bspline transformation between subject and destination landmarks\n...'
    #     sct.run('isct_ANTSUseLandmarkImagesToGetBSplineDisplacementField cross_template.nii.gz cross_native.nii.gz warp_ntotemp.nii.gz 5x5x5 3 2 0')
    #     warping = 'warp_ntotemp.nii.gz'
        
    # if final_warp == '' :    
    #     print 'Apply transfo to input image\n...'
    #     sct.run('isct_antsApplyTransforms 3 ' + fname + ' ' + output_name + ' -r ' + template_landmark + ' -t ' + warping + ' -n Linear')
        
    # if final_warp == 'NN':
    #     print 'Apply transfo to input image\n...'
    #     sct.run('isct_antsApplyTransforms 3 ' + fname + ' ' + output_name + ' -r ' + template_landmark + ' -t ' + warping + ' -n NearestNeighbor')
    # Only 'spline' applies the transformation; with '' or 'NN' no output volume is written here.
    if final_warp == 'spline':
        print 'Apply transfo to input image\n...'
        sct.run('sct_apply_transfo -i ' + fname + ' -o ' + output_name + ' -d ' + template_landmark + ' -w ' + warping + ' -x spline')


    # Remove warping
    #os.remove(warping)

    # if compose :
        
    #     print 'Computing affine transformation between subject and destination landmarks\n...'
    #     sct.run('isct_ANTSUseLandmarkImagesToGetAffineTransform cross_template.nii.gz cross_native.nii.gz affine n2t.txt')
    #     warping_affine = 'n2t.txt'
        
        
    #     print 'Apply transfo to input landmarks\n...'
    #     sct.run('isct_antsApplyTransforms 3 ' + cross_native + ' cross_affine.nii.gz -r ' + template_landmark + ' -t ' + warping_affine + ' -n NearestNeighbor')
        
    #     print 'Computing transfo between moved landmarks and template landmarks\n...'
    #     sct.run('isct_ANTSUseLandmarkImagesToGetBSplineDisplacementField cross_template.nii.gz cross_affine.nii.gz warp_affine2temp.nii.gz 5x5x5 3 2 0')
    #     warping_bspline = 'warp_affine2temp.nii.gz'
        
    #     print 'Composing transformations\n...'
    #     sct.run('isct_ComposeMultiTransform 3 warp_full.nii.gz -r ' + template_landmark + ' ' + warping_bspline + ' ' + warping_affine)
    #     warping_concat = 'warp_full.nii.gz'
        
    #     if final_warp == '' :    
    #         print 'Apply concat warp to input image\n...'
    #         sct.run('isct_antsApplyTransforms 3 ' + fname + ' ' + output_name + ' -r ' + template_landmark + ' -t ' + warping_concat + ' -n Linear')
        
    #     if final_warp == 'NN':
    #         print 'Apply concat warp to input image\n...'
    #         sct.run('isct_antsApplyTransforms 3 ' + fname + ' ' + output_name + ' -r ' + template_landmark + ' -t ' + warping_concat + ' -n NearestNeighbor')
        
    #     if final_warp == 'spline':
    #         print 'Apply concat warp to input image\n...'
    #         sct.run('isct_antsApplyTransforms 3 ' + fname + ' ' + output_name + ' -r ' + template_landmark + ' -t ' + warping_concat + ' -n BSpline[3]')
          
    
    
    # NOTE: printed even when no transformation was applied (final_warp other than 'spline')
    print '\nFile created : ' + output_name
class ProcessLabels(object):
    def __init__(self, fname_label, fname_output=None, fname_ref=None, cross_radius=5, dilate=False,
                 coordinates=None, verbose='1'):
        """
        Load the label image and store the options used by the processing methods.

        :param fname_label: label image, passed to Image()
        :param fname_output: name given to the output image by process(); None means nothing is saved
        :param fname_ref: optional reference image, passed to Image() when provided
        :param cross_radius: radius used by cross(), also passed as the width to plan()
        :param dilate: if true, cross() grows every arm of the cross into a 3x3 patch
        :param coordinates: list of coordinates consumed by create_label()
        :param verbose: verbosity value, stored as given
        """
        self.image_input = Image(fname_label)

        # Always define the attribute: methods that need a reference then fail on a None value
        # instead of on an attribute that exists only for some instances.
        self.image_ref = Image(fname_ref) if fname_ref is not None else None

        self.fname_output = fname_output
        self.cross_radius = cross_radius
        self.dilate = dilate
        self.coordinates = coordinates
        self.verbose = verbose

    def process(self, type_process):
        if type_process == 'cross':
            self.output_image = self.cross()
        elif type_process == 'plan':
            self.output_image = self.plan(self.cross_radius, 100, 5)
        elif type_process == 'plan_ref':
            self.output_image = self.plan_ref()
        elif type_process == 'increment':
            self.output_image = self.increment_z_inverse()
        elif type_process == 'disks':
            self.output_image = self.labelize_from_disks()
        elif type_process == 'MSE':
            self.MSE()
            self.fname_output = None
        elif type_process == 'remove':
            self.output_image = self.remove_label()
        elif type_process == 'centerline':
            self.extract_centerline()
        elif type_process == 'display-voxel':
            self.display_voxel()
            self.fname_output = None
        elif type_process == 'create':
            self.output_image = self.create_label()
        elif type_process == 'add':
            self.output_image = self.create_label(add=True)
        elif type_process == 'diff':
            self.diff()
            self.fname_output = None
        elif type_process == 'dist-inter':  # second argument is in pixel distance
            self.distance_interlabels(5)
            self.fname_output = None
        elif type_process == 'cubic-to-point':
            self.output_image = self.cubic_to_point()
        else:
            sct.printv('Error: The chosen process is not available.',1,'error')

        # save the output image as minimized integers
        if self.fname_output is not None:
            self.output_image.setFileName(self.fname_output)
            self.output_image.save('minimize_int')


    def cross(self):
        """
        Replace every label of the input image by a four-point cross in its (x, y) plane.

        For a label of value v at (x, y, z) the central voxel is set to 0 and four voxels are written:
        (x, y + dy) = 10 * v + 1, (x + dx, y) = 10 * v + 2, (x, y - dy) = 10 * v + 3 and
        (x - dx, y) = 10 * v + 4. When self.dilate is true, each of these four voxels is then grown
        into a 3x3 in-plane patch carrying the value found at its centre.

        dx and dy are self.cross_radius divided by the x and y pixel dimensions and rounded to a whole
        number of voxels, because array indices must be integers (recent NumPy rejects float indices).
        No bounds check is made: a cross that reaches outside the volume either wraps around
        (negative index) or raises IndexError.

        :return: the image holding the crosses
        """
        image_output = Image(self.image_input)
        nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(self.image_input.absolutepath)

        coordinates_input = self.image_input.getNonZeroCoordinates()
        # presumably cross_radius is a physical length and px, py the voxel sizes, so that the
        # quotient is a number of voxels — TODO confirm
        d = self.cross_radius
        dx = int(round(float(d) / px))
        dy = int(round(float(d) / py))

        data = image_output.data
        for coord in coordinates_input:
            x, y, z = coord.x, coord.y, coord.z
            # remove point on the center of the spinal cord
            data[x, y, z] = 0

            # add points at distance from the center; values end with 1, 2, 3, 4 in this order
            arms = [(x, y + dy), (x + dx, y), (x, y - dy), (x - dx, y)]
            for rank, (arm_x, arm_y) in enumerate(arms, 1):
                data[arm_x, arm_y, z] = coord.value * 10 + rank

            # dilate cross to 3x3
            if self.dilate:
                for arm_x, arm_y in arms:
                    # The centre is read when its arm is processed: if patches overlap (radius of one
                    # or two voxels), an earlier arm may already have overwritten it.
                    value = data[arm_x, arm_y, z]
                    for shift_x in (-1, 0, 1):
                        for shift_y in (-1, 0, 1):
                            data[arm_x + shift_x, arm_y + shift_y, z] = value

        return image_output

    def plan(self, width, offset=0, gap=1):
        """
        Create an image in which every label of the input becomes a slab of slices along z.

        For a label of value v located at slice z, slices z - width up to (but excluding) z + width
        are filled with ``offset + gap * v``. Slabs are written in the order of the labels, so where
        two slabs overlap the later label wins.

        :param width: half-thickness of the slab, in slices
        :param offset: constant added to every slab value
        :param gap: factor applied to the label value
        :return: the image holding the slabs
        """
        image_output = Image(self.image_input)
        image_output.data *= 0
        coordinates_input = self.image_input.getNonZeroCoordinates()

        for coord in coordinates_input:
            z = int(coord.z)
            # Clamp the lower bound: a negative start would be counted from the end of the axis and
            # produce an empty or misplaced slab for labels closer than `width` to the first slice.
            # The upper bound needs no clamping, slicing stops at the end of the axis.
            z_start = max(0, z - width)
            image_output.data[:, :, z_start:z + width] = offset + gap * coord.value

        return image_output

    def plan_ref(self):
        """
        Build, in the space of the reference image, one full slice per label of the input image.

        The reference image is wrapped in a new Image and zeroed, then for every non-zero
        coordinate of the input image the whole slice at its z index is set to the label value.

        :return: the image holding the slices
        """
        reference_planes = Image(self.image_ref)
        reference_planes.data *= 0

        # one constant slice per label; a later label overwrites an earlier one on the same slice
        for label in self.image_input.getNonZeroCoordinates():
            reference_planes.data[:, :, label.z] = label.value

        return reference_planes
    def cubic_to_point(self):
        """
        This function calculates the center of mass of each group of labels and returns a file of same size with only a label by group at the center of mass.
        It is to be used after applying homothetic warping field to a label file as the labels will be dilated.
        :return:
        """
        from scipy import ndimage
        from numpy import array
        data = self.image_input.data

        image_output = self.image_input.copy()
        data_output = image_output.data
        data_output *= 0
        nx = image_output.data.shape[0]
        ny = image_output.data.shape[1]
        nz = image_output.data.shape[2]
        print '.. matrix size: '+str(nx)+' x '+str(ny)+' x '+str(nz)

        z_centerline = [iz for iz in range(0, nz, 1) if data[:,:,iz].any() ]
        nz_nonz = len(z_centerline)
        if nz_nonz==0 :
            print '\nERROR: Label file is empty'
            sys.exit()
        x_centerline = [0 for iz in range(0, nz_nonz, 1)]
        y_centerline = [0 for iz in range(0, nz_nonz, 1)]
        print '\nGet center of mass for each slice of the label file ...'
        for iz in xrange(len(z_centerline)):
            x_centerline[iz], y_centerline[iz] = ndimage.measurements.center_of_mass(array(data[:,:,z_centerline[iz]]))

        ## Calculate mean coordinate according to z for each cube of labels:
        list_cube_labels_x = [[]]
        list_cube_labels_y = [[]]
        list_cube_labels_z = [[]]
        count = 0
        for i in range(nz_nonz-1):
            # Make a list of group of slices that contains a non zero value
            if z_centerline[i] - z_centerline[i+1] == -1:
                # Verify if the value has already been recovered and add if not
                #If the group is empty add first value do not if it is not empty as it will copy it for a second time
                if len(list_cube_labels_z[count]) == 0 :#or list_cube_labels[count][-1] != z_centerline[i]:
                    list_cube_labels_z[count].append(z_centerline[i])
                    list_cube_labels_x[count].append(x_centerline[i])
                    list_cube_labels_y[count].append(y_centerline[i])
                list_cube_labels_z[count].append(z_centerline[i+1])
                list_cube_labels_x[count].append(x_centerline[i+1])
                list_cube_labels_y[count].append(y_centerline[i+1])
                if i+2 < nz_nonz-1 and z_centerline[i+1] - z_centerline[i+2] != -1:
                    list_cube_labels_z.append([])
                    list_cube_labels_x.append([])
                    list_cube_labels_y.append([])
                    count += 1

        z_label_mean = [0 for i in range(len(list_cube_labels_z))]
        x_label_mean = [0 for i in range(len(list_cube_labels_z))]
        y_label_mean = [0 for i in range(len(list_cube_labels_z))]
        for i in range(len(list_cube_labels_z)):
            for j in range(len(list_cube_labels_z[i])):
                z_label_mean[i] += list_cube_labels_z[i][j]
                x_label_mean[i] += list_cube_labels_x[i][j]
                y_label_mean[i] += list_cube_labels_y[i][j]
            z_label_mean[i] = int(round(z_label_mean[i]/len(list_cube_labels_z[i])))
            x_label_mean[i] = int(round(x_label_mean[i]/len(list_cube_labels_x[i])))
            y_label_mean[i] = int(round(y_label_mean[i]/len(list_cube_labels_y[i])))


        ## Put labels of value one into mean coordinates
        for i in range(len(z_label_mean)):
            data_output[x_label_mean[i],y_label_mean[i], z_label_mean[i]] = 1

        return image_output

    def increment_z_inverse(self):
        """
        Renumber the labels of the input image as 1, 2, 3, ...

        The numbering follows the order returned by
        ``getNonZeroCoordinates(sorting='z', reverse_coord=True)``, presumably decreasing z,
        i.e. from top to bottom for an image in RPI orientation. Every non-zero voxel of the
        input is treated as a label.

        :return: the image holding the renumbered labels
        """
        renumbered = Image(self.image_input)
        renumbered.data *= 0
        ordered_labels = self.image_input.getNonZeroCoordinates(sorting='z', reverse_coord=True)

        # the first label of the list receives 1, the next one 2, and so on
        rank = 0
        for label in ordered_labels:
            rank += 1
            renumbered.data[label.x, label.y, label.z] = rank

        return renumbered

    def labelize_from_disks(self):
        """
        Give every non-zero voxel of the input image the value of the reference label above it.

        The reference labels are read sorted by value. For each pair of successive reference
        labels (upper, lower) in that order, an input voxel whose z satisfies
        ``lower.z < z <= upper.z`` receives the value of the upper label. Voxels that fall in no
        such interval stay at 0.
        Typical use, according to the original description: the input is a segmentation, the
        reference holds disk positions, and the result is the segmentation labelled by vertebral level
        (labels increasing from top to bottom, RPI orientation assumed).

        :return: the labelled image
        """
        labelled = Image(self.image_input)
        labelled.data *= 0
        voxels = self.image_input.getNonZeroCoordinates()
        disks = self.image_ref.getNonZeroCoordinates(sorting='value')

        # every interval is tested for every voxel; if several match, the last one wins
        for voxel in voxels:
            for k in range(1, len(disks)):
                upper, lower = disks[k - 1], disks[k]
                if lower.z < voxel.z <= upper.z:
                    labelled.data[voxel.x, voxel.y, voxel.z] = upper.value

        return labelled

    def symmetrizer(self, side='left'):
        """
        Unfinished: intended to symmetrize the input image by copying one side onto the other,
        assuming an RPI orientation.

        As written, the method wraps the input image in a new Image and returns it without
        performing any symmetrization; the `side` argument is not used.

        :param side: string 'left' or 'right'. Side that will be copied on the other side.
        :return: the new Image built from the input image
        """
        image_output = Image(self.image_input)

        # NOTE(review): this expression only slices the image and discards the result; it modifies
        # nothing (and requires Image to support slicing) — looks like a placeholder.
        image_output[0:]

        # The string below is Matlab pseudo-code kept as a guide for a future implementation.
        """inspiration: (from atlas creation matlab script)
        temp_sum = temp_g + temp_d;
        temp_sum_flip = temp_sum(end:-1:1,:);
        temp_sym = (temp_sum + temp_sum_flip) / 2;

        temp_g(1:end / 2,:) = 0;
        temp_g(1 + end / 2:end,:) = temp_sym(1 + end / 2:end,:);
        temp_d(1:end / 2,:) = temp_sym(1:end / 2,:);
        temp_d(1 + end / 2:end,:) = 0;

        tractsHR
        {label_l}(:,:, num_slice_ref) = temp_g;
        tractsHR
        {label_r}(:,:, num_slice_ref) = temp_d;
        """

        return image_output

    def MSE(self, threshold_mse=0):
        """
        Compute the error along z between the labels of the input image and those of the reference.

        Labels are paired by rounded value; when several reference labels share a value, the first
        one in the reference list is used. The result is the square root of the sum of the squared
        z differences of the paired labels divided by the number of input labels (input labels
        without a partner add nothing to the sum but still count in the divisor).
        A warning is printed when the two images do not hold the same number of labels and for
        every label whose value is missing from the other image.
        If the result is above threshold_mse, a text file named
        ``error_log_<input file name>.txt`` is written in the folder of the input image.

        :param threshold_mse: value above which the error log is written (default 0)
        :return: the computed error
        :raises ZeroDivisionError: if the input image contains no label
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates()

        # Rounded values are collected once instead of being rebuilt for every coordinate.
        # setdefault keeps the z of the first reference label met for each value.
        z_ref_by_value = {}
        for coord_ref in coordinates_ref:
            z_ref_by_value.setdefault(round(coord_ref.value), coord_ref.z)
        values_input = set(round(coord.value) for coord in coordinates_input)

        # check if all the labels in both the images match
        if len(coordinates_input) != len(coordinates_ref):
            sct.printv('ERROR: labels mismatch', 1, 'warning')
        for coord in coordinates_input:
            if round(coord.value) not in z_ref_by_value:
                sct.printv('ERROR: labels mismatch', 1, 'warning')
        for coord_ref in coordinates_ref:
            if round(coord_ref.value) not in values_input:
                sct.printv('ERROR: labels mismatch', 1, 'warning')

        result = 0.0
        for coord in coordinates_input:
            value = round(coord.value)
            if value in z_ref_by_value:
                result += (z_ref_by_value[value] - coord.z) ** 2
        result = math.sqrt(result / len(coordinates_input))
        sct.printv('MSE error in Z direction = ' + str(result) + ' mm')

        if result > threshold_mse:
            # the with-statement closes the file even if the write fails
            with open(self.image_input.path + 'error_log_' + self.image_input.file_name + '.txt', 'w') as f:
                f.write(
                    'The labels error (MSE) between ' + self.image_input.file_name + ' and ' + self.image_ref.file_name + ' is: ' + str(
                        result))

        return result

    def create_label(self, add=False):
        """
        Build a label image from the coordinates listed by the user.

        self.coordinates is a list of coordinates (class in msct_types); a Coordinate contains x, y, z and value.
        The coordinates are used as-is to index the image data, so they must be valid for the input image.
        A single label still has to be wrapped in a list:
        For one label:  object_define=ProcessLabels( fname_label, coordinates=[coordi]) where coordi is a 'Coordinate' object from msct_types
        For two labels: object_define=ProcessLabels( fname_label, coordinates=[coordi1, coordi2]) where coordi1 and coordi2 are 'Coordinate' objects from msct_types

        :param add: when False (default) the copy of the input image is zeroed before the labels are written;
                    when True the labels are written on top of the existing data.
        :return: copy of the input image in which each listed voxel is set to int(value)
        """
        image_output = self.image_input.copy()
        if not add:
            image_output.data *= 0

        # write every label, reporting its position and value
        for label_index, label in enumerate(self.coordinates):
            position = ','.join([str(label.x), str(label.y), str(label.z)])
            message = 'Label #' + str(label_index) + ': ' + position + ' --> ' + str(label.value)
            sct.printv(message, 1)
            image_output.data[label.x, label.y, label.z] = int(label.value)

        return image_output

    def remove_label(self):
        """
        Compare two label images and remove from the input image any label that is not in the reference image.

        An input label matches a reference label when their values differ by less than 0.1. A matched voxel is
        rewritten with the rounded integer value of the (last) matching reference label; a voxel without any match
        is set to 0.

        :return: Image built from the input image, with the updated label values
        """
        image_output = Image(self.image_input)
        labels_input = self.image_input.getNonZeroCoordinates()
        labels_ref = self.image_ref.getNonZeroCoordinates()

        for label in labels_input:
            # stays 0 when no reference label matches, which removes the label
            new_value = 0
            for label_ref in labels_ref:
                # compared with a tolerance: values such as 21.00001 vs 21.0 can show up when the input was
                # down sampled
                if abs(label.value - label_ref.value) < 0.1:
                    new_value = int(round(label_ref.value))
            image_output.data[label.x, label.y, label.z] = new_value

        return image_output

    def extract_centerline(self):
        """
        Write a text file with the coordinates of the centerline.

        One line "x y z" (integers separated by a space) is written to self.fname_output for every non-zero voxel
        of the input image, in the order returned by getNonZeroCoordinates(sorting='z').
        The image is supposed to be RPI.
        """
        coordinates_input = self.image_input.getNonZeroCoordinates(sorting='z')

        # The file stays in binary mode so that '\n' is written unchanged on every platform. Each line is encoded
        # explicitly because binary files only accept bytes on Python 3. The context manager closes the file even
        # if a write fails.
        with open(self.fname_output, "wb") as fo:
            for coord in coordinates_input:
                line = "%i %i %i\n" % (coord.x, coord.y, coord.z)
                fo.write(line.encode("ascii"))

    def display_voxel(self):
        """
        Display all the labels that are contained in the input image.

        For every non-zero voxel (requested with sorting='z') one line 'Position=(x,y,z) -- Value= v' is printed,
        followed by a compact 'x,y,z,v:x,y,z,v:...' notation gathering all the labels.
        The image is supposed to be RPI to display voxels, but this works also for other orientations.
        """
        coordinates_input = self.image_input.getNonZeroCoordinates(sorting='z')

        notations = []
        for coord in coordinates_input:
            position = str(coord.x) + ',' + str(coord.y) + ',' + str(coord.z)
            # single-argument print(...) gives the same output with Python 2 and Python 3
            print('Position=(' + position + ') -- Value= ' + str(coord.value))
            notations.append(position + ',' + str(coord.value))

        print('Useful notation:')
        print(':'.join(notations))

    def diff(self):
        """
        Detect any label mismatch between the input image and the reference image.

        Prints the value of every input label that has no reference label with an equal value, then the value of
        every reference label that has no input label with an equal value. Values are compared with ==.
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates()

        # single-argument print(...) gives the same output with Python 2 and Python 3
        print("Label in input image that are not in reference image:")
        for coord in coordinates_input:
            if not any(coord.value == coord_ref.value for coord_ref in coordinates_ref):
                print(coord.value)

        print("Label in ref image that are not in input image:")
        for coord_ref in coordinates_ref:
            if not any(coord.value == coord_ref.value for coord in coordinates_input):
                print(coord_ref.value)

    def distance_interlabels(self, max_dist):
        """
        Calculate the distance between each pair of consecutive labels of the input image.

        Labels are taken in the order returned by getNonZeroCoordinates(). The distance is the Euclidean distance
        between the x, y, z coordinates of label i and label i+1.
        If a distance is larger than max_dist, a warning message is displayed.

        :param max_dist: distance above which a warning is printed
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()

        for i in range(len(coordinates_input) - 1):
            current = coordinates_input[i]
            following = coordinates_input[i + 1]
            dist = math.sqrt((current.x - following.x) ** 2 +
                             (current.y - following.y) ** 2 +
                             (current.z - following.z) ** 2)
            # warn when the labels are too far apart (the comparison used to be inverted, contradicting both the
            # documented behaviour and the wording of the message)
            if dist > max_dist:
                print('Warning: the distance between label ' + str(i) + '[' + str(current.x) + ',' +
                      str(current.y) + ',' + str(current.z) + ']=' + str(current.value) + ' and label ' +
                      str(i + 1) + '[' + str(following.x) + ',' + str(following.y) + ',' + str(following.z) +
                      ']=' + str(following.value) + ' is larger than ' + str(max_dist) + '. Distance=' + str(dist))
Esempio n. 25
0
def main():

    # get default parameters
    step1 = Paramreg(step='1',
                     type='seg',
                     algo='slicereg',
                     metric='MeanSquares',
                     iter='10')
    step2 = Paramreg(step='2', type='im', algo='syn', metric='MI', iter='3')
    # step1 = Paramreg()
    paramreg = ParamregMultiStep([step1, step2])

    # step1 = Paramreg_step(step='1', type='seg', algo='bsplinesyn', metric='MeanSquares', iter='10', shrink='1', smooth='0', gradStep='0.5')
    # step2 = Paramreg_step(step='2', type='im', algo='syn', metric='MI', iter='10', shrink='1', smooth='0', gradStep='0.5')
    # paramreg = ParamregMultiStep([step1, step2])

    # Initialize the parser
    parser = Parser(__file__)
    parser.usage.set_description('Register anatomical image to the template.')
    parser.add_option(name="-i",
                      type_value="file",
                      description="Anatomical image.",
                      mandatory=True,
                      example="anat.nii.gz")
    parser.add_option(name="-s",
                      type_value="file",
                      description="Spinal cord segmentation.",
                      mandatory=True,
                      example="anat_seg.nii.gz")
    parser.add_option(
        name="-l",
        type_value="file",
        description=
        "Labels. See: http://sourceforge.net/p/spinalcordtoolbox/wiki/create_labels/",
        mandatory=True,
        default_value='',
        example="anat_labels.nii.gz")
    parser.add_option(name="-t",
                      type_value="folder",
                      description="Path to MNI-Poly-AMU template.",
                      mandatory=False,
                      default_value=param.path_template)
    parser.add_option(
        name="-p",
        type_value=[[':'], 'str'],
        description=
        """Parameters for registration (see sct_register_multimodal). Default:\n--\nstep=1\ntype="""
        + paramreg.steps['1'].type + """\nalgo=""" + paramreg.steps['1'].algo +
        """\nmetric=""" + paramreg.steps['1'].metric + """\npoly=""" +
        paramreg.steps['1'].poly + """\n--\nstep=2\ntype=""" +
        paramreg.steps['2'].type + """\nalgo=""" + paramreg.steps['2'].algo +
        """\nmetric=""" + paramreg.steps['2'].metric + """\niter=""" +
        paramreg.steps['2'].iter + """\nshrink=""" +
        paramreg.steps['2'].shrink + """\nsmooth=""" +
        paramreg.steps['2'].smooth + """\ngradStep=""" +
        paramreg.steps['2'].gradStep + """\n--""",
        mandatory=False,
        example=
        "step=2,type=seg,algo=bsplinesyn,metric=MeanSquares,iter=5,shrink=2:step=3,type=im,algo=syn,metric=MI,iter=5,shrink=1,gradStep=0.3"
    )
    parser.add_option(name="-r",
                      type_value="multiple_choice",
                      description="""Remove temporary files.""",
                      mandatory=False,
                      default_value='1',
                      example=['0', '1'])
    parser.add_option(
        name="-v",
        type_value="multiple_choice",
        description="""Verbose. 0: nothing. 1: basic. 2: extended.""",
        mandatory=False,
        default_value=param.verbose,
        example=['0', '1', '2'])
    if param.debug:
        print '\n*** WARNING: DEBUG MODE ON ***\n'
        fname_data = '/Users/julien/data/temp/sct_example_data/t2/t2.nii.gz'
        fname_landmarks = '/Users/julien/data/temp/sct_example_data/t2/labels.nii.gz'
        fname_seg = '/Users/julien/data/temp/sct_example_data/t2/t2_seg.nii.gz'
        path_template = param.path_template
        remove_temp_files = 0
        verbose = 2
        # speed = 'superfast'
        #param_reg = '2,BSplineSyN,0.6,MeanSquares'
    else:
        arguments = parser.parse(sys.argv[1:])

        # get arguments
        fname_data = arguments['-i']
        fname_seg = arguments['-s']
        fname_landmarks = arguments['-l']
        path_template = arguments['-t']
        remove_temp_files = int(arguments['-r'])
        verbose = int(arguments['-v'])
        if '-p' in arguments:
            paramreg_user = arguments['-p']
            # update registration parameters
            for paramStep in paramreg_user:
                paramreg.addStep(paramStep)

    # initialize other parameters
    file_template = param.file_template
    file_template_label = param.file_template_label
    file_template_seg = param.file_template_seg
    output_type = param.output_type
    zsubsample = param.zsubsample
    # smoothing_sigma = param.smoothing_sigma

    # start timer
    start_time = time.time()

    # get absolute path - TO DO: remove! NEVER USE ABSOLUTE PATH...
    path_template = os.path.abspath(path_template)

    # get fname of the template + template objects
    fname_template = sct.slash_at_the_end(path_template, 1) + file_template
    fname_template_label = sct.slash_at_the_end(path_template,
                                                1) + file_template_label
    fname_template_seg = sct.slash_at_the_end(path_template,
                                              1) + file_template_seg

    # check file existence
    sct.printv('\nCheck template files...')
    sct.check_file_exist(fname_template, verbose)
    sct.check_file_exist(fname_template_label, verbose)
    sct.check_file_exist(fname_template_seg, verbose)

    # print arguments
    sct.printv('\nCheck parameters:', verbose)
    sct.printv('.. Data:                 ' + fname_data, verbose)
    sct.printv('.. Landmarks:            ' + fname_landmarks, verbose)
    sct.printv('.. Segmentation:         ' + fname_seg, verbose)
    sct.printv('.. Path template:        ' + path_template, verbose)
    sct.printv('.. Output type:          ' + str(output_type), verbose)
    sct.printv('.. Remove temp files:    ' + str(remove_temp_files), verbose)

    sct.printv('\nParameters for registration:')
    for pStep in range(1, len(paramreg.steps) + 1):
        sct.printv('Step #' + paramreg.steps[str(pStep)].step, verbose)
        sct.printv('.. Type #' + paramreg.steps[str(pStep)].type, verbose)
        sct.printv(
            '.. Algorithm................ ' + paramreg.steps[str(pStep)].algo,
            verbose)
        sct.printv(
            '.. Metric................... ' +
            paramreg.steps[str(pStep)].metric, verbose)
        sct.printv(
            '.. Number of iterations..... ' + paramreg.steps[str(pStep)].iter,
            verbose)
        sct.printv(
            '.. Shrink factor............ ' +
            paramreg.steps[str(pStep)].shrink, verbose)
        sct.printv(
            '.. Smoothing factor......... ' +
            paramreg.steps[str(pStep)].smooth, verbose)
        sct.printv(
            '.. Gradient step............ ' +
            paramreg.steps[str(pStep)].gradStep, verbose)
        sct.printv(
            '.. Degree of polynomial..... ' + paramreg.steps[str(pStep)].poly,
            verbose)

    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    sct.printv('\nCheck input labels...')
    # check if label image contains coherent labels
    image_label = Image(fname_landmarks)
    # -> all labels must be different
    labels = image_label.getNonZeroCoordinates(sorting='value')
    hasDifferentLabels = True
    for lab in labels:
        for otherlabel in labels:
            if lab != otherlabel and lab.hasEqualValue(otherlabel):
                hasDifferentLabels = False
                break
    if not hasDifferentLabels:
        sct.printv(
            'ERROR: Wrong landmarks input. All labels must be different.',
            verbose, 'error')
    # all labels must be available in tempalte
    image_label_template = Image(fname_template_label)
    labels_template = image_label_template.getNonZeroCoordinates(
        sorting='value')
    if labels[-1].value > labels_template[-1].value:
        sct.printv(
            'ERROR: Wrong landmarks input. Labels must have correspondance in tempalte space. \nLabel max '
            'provided: ' + str(labels[-1].value) +
            '\nLabel max from template: ' + str(labels_template[-1].value),
            verbose, 'error')

    # create temporary folder
    sct.printv('\nCreate temporary folder...', verbose)
    path_tmp = 'tmp.' + time.strftime("%y%m%d%H%M%S")
    status, output = sct.run('mkdir ' + path_tmp)

    # copy files to temporary folder
    sct.printv('\nCopy files...', verbose)
    sct.run('isct_c3d ' + fname_data + ' -o ' + path_tmp + '/data.nii')
    sct.run('isct_c3d ' + fname_landmarks + ' -o ' + path_tmp +
            '/landmarks.nii.gz')
    sct.run('isct_c3d ' + fname_seg + ' -o ' + path_tmp +
            '/segmentation.nii.gz')
    sct.run('isct_c3d ' + fname_template + ' -o ' + path_tmp + '/template.nii')
    sct.run('isct_c3d ' + fname_template_label + ' -o ' + path_tmp +
            '/template_labels.nii.gz')
    sct.run('isct_c3d ' + fname_template_seg + ' -o ' + path_tmp +
            '/template_seg.nii.gz')

    # go to tmp folder
    os.chdir(path_tmp)

    # resample data to 1mm isotropic
    sct.printv('\nResample data to 1mm isotropic...', verbose)
    sct.run(
        'isct_c3d data.nii -resample-mm 1.0x1.0x1.0mm -interpolation Linear -o datar.nii'
    )
    sct.run(
        'isct_c3d segmentation.nii.gz -resample-mm 1.0x1.0x1.0mm -interpolation NearestNeighbor -o segmentationr.nii.gz'
    )
    # N.B. resampling of labels is more complicated, because they are single-point labels, therefore resampling with neighrest neighbour can make them disappear. Therefore a more clever approach is required.
    resample_labels('landmarks.nii.gz', 'datar.nii', 'landmarksr.nii.gz')
    # # TODO
    # sct.run('sct_label_utils -i datar.nii -t create -x 124,186,19,2:129,98,23,8 -o landmarksr.nii.gz')

    # Change orientation of input images to RPI
    sct.printv('\nChange orientation of input images to RPI...', verbose)
    set_orientation('datar.nii', 'RPI', 'data_rpi.nii')
    set_orientation('landmarksr.nii.gz', 'RPI', 'landmarks_rpi.nii.gz')
    set_orientation('segmentationr.nii.gz', 'RPI', 'segmentation_rpi.nii.gz')

    # # Change orientation of input images to RPI
    # sct.printv('\nChange orientation of input images to RPI...', verbose)
    # set_orientation('data.nii', 'RPI', 'data_rpi.nii')
    # set_orientation('landmarks.nii.gz', 'RPI', 'landmarks_rpi.nii.gz')
    # set_orientation('segmentation.nii.gz', 'RPI', 'segmentation_rpi.nii.gz')

    # get landmarks in native space
    # crop segmentation
    # output: segmentation_rpi_crop.nii.gz
    sct.run(
        'sct_crop_image -i segmentation_rpi.nii.gz -o segmentation_rpi_crop.nii.gz -dim 2 -bzmax'
    )

    # straighten segmentation
    sct.printv('\nStraighten the spinal cord using centerline/segmentation...',
               verbose)
    sct.run(
        'sct_straighten_spinalcord -i segmentation_rpi_crop.nii.gz -c segmentation_rpi_crop.nii.gz -r 0 -v '
        + str(verbose), verbose)
    # re-define warping field using non-cropped space (to avoid issue #367)
    sct.run(
        'sct_concat_transfo -w warp_straight2curve.nii.gz -d data_rpi.nii -o warp_straight2curve.nii.gz'
    )

    # Label preparation:
    # --------------------------------------------------------------------------------
    # Remove unused label on template. Keep only label present in the input label image
    sct.printv(
        '\nRemove unused label on template. Keep only label present in the input label image...',
        verbose)
    sct.run(
        'sct_label_utils -t remove -i template_labels.nii.gz -o template_label.nii.gz -r landmarks_rpi.nii.gz'
    )

    # Make sure landmarks are INT
    sct.printv('\nConvert landmarks to INT...', verbose)
    sct.run(
        'isct_c3d template_label.nii.gz -type int -o template_label.nii.gz',
        verbose)

    # Create a cross for the template labels - 5 mm
    sct.printv('\nCreate a 5 mm cross for the template labels...', verbose)
    sct.run(
        'sct_label_utils -t cross -i template_label.nii.gz -o template_label_cross.nii.gz -c 5'
    )

    # Create a cross for the input labels and dilate for straightening preparation - 5 mm
    sct.printv(
        '\nCreate a 5mm cross for the input labels and dilate for straightening preparation...',
        verbose)
    sct.run(
        'sct_label_utils -t cross -i landmarks_rpi.nii.gz -o landmarks_rpi_cross3x3.nii.gz -c 5 -d'
    )

    # Apply straightening to labels
    sct.printv('\nApply straightening to labels...', verbose)
    sct.run(
        'sct_apply_transfo -i landmarks_rpi_cross3x3.nii.gz -o landmarks_rpi_cross3x3_straight.nii.gz -d segmentation_rpi_crop_straight.nii.gz -w warp_curve2straight.nii.gz -x nn'
    )

    # Convert landmarks from FLOAT32 to INT
    sct.printv('\nConvert landmarks from FLOAT32 to INT...', verbose)
    sct.run(
        'isct_c3d landmarks_rpi_cross3x3_straight.nii.gz -type int -o landmarks_rpi_cross3x3_straight.nii.gz'
    )

    # Remove labels that do not correspond with each others.
    sct.printv('\nRemove labels that do not correspond with each others.',
               verbose)
    sct.run(
        'sct_label_utils -t remove-symm -i landmarks_rpi_cross3x3_straight.nii.gz -o landmarks_rpi_cross3x3_straight.nii.gz,template_label_cross.nii.gz -r template_label_cross.nii.gz'
    )

    # Estimate affine transfo: straight --> template (landmark-based)'
    sct.printv(
        '\nEstimate affine transfo: straight anat --> template (landmark-based)...',
        verbose)
    # converting landmarks straight and curved to physical coordinates
    image_straight = Image('landmarks_rpi_cross3x3_straight.nii.gz')
    landmark_straight = image_straight.getNonZeroCoordinates(sorting='value')
    image_template = Image('template_label_cross.nii.gz')
    landmark_template = image_template.getNonZeroCoordinates(sorting='value')
    # Reorganize landmarks
    points_fixed, points_moving = [], []
    landmark_straight_mean = []
    for coord in landmark_straight:
        if coord.value not in [c.value for c in landmark_straight_mean]:
            temp_landmark = coord
            temp_number = 1
            for other_coord in landmark_straight:
                if coord.hasEqualValue(other_coord) and coord != other_coord:
                    temp_landmark += other_coord
                    temp_number += 1
            landmark_straight_mean.append(temp_landmark / temp_number)

    for coord in landmark_straight_mean:
        point_straight = image_straight.transfo_pix2phys(
            [[coord.x, coord.y, coord.z]])
        points_moving.append(
            [point_straight[0][0], point_straight[0][1], point_straight[0][2]])
    for coord in landmark_template:
        point_template = image_template.transfo_pix2phys(
            [[coord.x, coord.y, coord.z]])
        points_fixed.append(
            [point_template[0][0], point_template[0][1], point_template[0][2]])

    # Register curved landmarks on straight landmarks based on python implementation
    sct.printv(
        '\nComputing rigid transformation (algo=translation-scaling-z) ...',
        verbose)
    import msct_register_landmarks
    (rotation_matrix, translation_array, points_moving_reg, points_moving_barycenter) = \
        msct_register_landmarks.getRigidTransformFromLandmarks(
            points_fixed, points_moving, constraints='translation-scaling-z', show=False)

    # writing rigid transformation file
    text_file = open("straight2templateAffine.txt", "w")
    text_file.write("#Insight Transform File V1.0\n")
    text_file.write("#Transform 0\n")
    text_file.write(
        "Transform: FixedCenterOfRotationAffineTransform_double_3_3\n")
    text_file.write(
        "Parameters: %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f\n"
        % (1.0 / rotation_matrix[0, 0], rotation_matrix[0, 1],
           rotation_matrix[0, 2], rotation_matrix[1, 0],
           1.0 / rotation_matrix[1, 1], rotation_matrix[1, 2],
           rotation_matrix[2, 0], rotation_matrix[2, 1],
           1.0 / rotation_matrix[2, 2], translation_array[0, 0],
           translation_array[0, 1], -translation_array[0, 2]))
    text_file.write("FixedParameters: %.9f %.9f %.9f\n" %
                    (points_moving_barycenter[0], points_moving_barycenter[1],
                     points_moving_barycenter[2]))
    text_file.close()

    # Apply affine transformation: straight --> template
    sct.printv('\nApply affine transformation: straight --> template...',
               verbose)
    sct.run(
        'sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt -d template.nii -o warp_curve2straightAffine.nii.gz'
    )
    sct.run(
        'sct_apply_transfo -i data_rpi.nii -o data_rpi_straight2templateAffine.nii -d template.nii -w warp_curve2straightAffine.nii.gz'
    )
    sct.run(
        'sct_apply_transfo -i segmentation_rpi.nii.gz -o segmentation_rpi_straight2templateAffine.nii.gz -d template.nii -w warp_curve2straightAffine.nii.gz -x linear'
    )

    # threshold to 0.5
    nii = Image('segmentation_rpi_straight2templateAffine.nii.gz')
    data = nii.data
    data[data < 0.5] = 0
    nii.data = data
    nii.setFileName('segmentation_rpi_straight2templateAffine_th.nii.gz')
    nii.save()
    # find min-max of anat2template (for subsequent cropping)
    zmin_template, zmax_template = find_zmin_zmax(
        'segmentation_rpi_straight2templateAffine_th.nii.gz')

    # crop template in z-direction (for faster processing)
    sct.printv('\nCrop data in template space (for faster processing)...',
               verbose)
    sct.run(
        'sct_crop_image -i template.nii -o template_crop.nii -dim 2 -start ' +
        str(zmin_template) + ' -end ' + str(zmax_template))
    sct.run(
        'sct_crop_image -i template_seg.nii.gz -o template_seg_crop.nii.gz -dim 2 -start '
        + str(zmin_template) + ' -end ' + str(zmax_template))
    sct.run(
        'sct_crop_image -i data_rpi_straight2templateAffine.nii -o data_rpi_straight2templateAffine_crop.nii -dim 2 -start '
        + str(zmin_template) + ' -end ' + str(zmax_template))
    sct.run(
        'sct_crop_image -i segmentation_rpi_straight2templateAffine.nii.gz -o segmentation_rpi_straight2templateAffine_crop.nii.gz -dim 2 -start '
        + str(zmin_template) + ' -end ' + str(zmax_template))
    # sub-sample in z-direction
    sct.printv('\nSub-sample in z-direction (for faster processing)...',
               verbose)
    sct.run(
        'sct_resample -i template_crop.nii -o template_crop_r.nii -f 1x1x' +
        zsubsample, verbose)
    sct.run(
        'sct_resample -i template_seg_crop.nii.gz -o template_seg_crop_r.nii.gz -f 1x1x'
        + zsubsample, verbose)
    sct.run(
        'sct_resample -i data_rpi_straight2templateAffine_crop.nii -o data_rpi_straight2templateAffine_crop_r.nii -f 1x1x'
        + zsubsample, verbose)
    sct.run(
        'sct_resample -i segmentation_rpi_straight2templateAffine_crop.nii.gz -o segmentation_rpi_straight2templateAffine_crop_r.nii.gz -f 1x1x'
        + zsubsample, verbose)

    # Registration straight spinal cord to template
    sct.printv('\nRegister straight spinal cord to template...', verbose)

    # loop across registration steps
    warp_forward = []
    warp_inverse = []
    for i_step in range(1, len(paramreg.steps) + 1):
        sct.printv(
            '\nEstimate transformation for step #' + str(i_step) + '...',
            verbose)
        # identify which is the src and dest
        if paramreg.steps[str(i_step)].type == 'im':
            src = 'data_rpi_straight2templateAffine_crop_r.nii'
            dest = 'template_crop_r.nii'
            interp_step = 'linear'
        elif paramreg.steps[str(i_step)].type == 'seg':
            src = 'segmentation_rpi_straight2templateAffine_crop_r.nii.gz'
            dest = 'template_seg_crop_r.nii.gz'
            interp_step = 'nn'
        else:
            sct.printv('ERROR: Wrong image type.', 1, 'error')
        # if step>1, apply warp_forward_concat to the src image to be used
        if i_step > 1:
            # sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+sct.add_suffix(src, '_reg')+' -x '+interp_step, verbose)
            sct.run(
                'sct_apply_transfo -i ' + src + ' -d ' + dest + ' -w ' +
                ','.join(warp_forward) + ' -o ' + sct.add_suffix(src, '_reg') +
                ' -x ' + interp_step, verbose)
            src = sct.add_suffix(src, '_reg')
        # register src --> dest
        warp_forward_out, warp_inverse_out = register(src, dest, paramreg,
                                                      param, str(i_step))
        warp_forward.append(warp_forward_out)
        warp_inverse.append(warp_inverse_out)

    # Concatenate transformations:
    sct.printv('\nConcatenate transformations: anat --> template...', verbose)
    sct.run(
        'sct_concat_transfo -w warp_curve2straightAffine.nii.gz,' +
        ','.join(warp_forward) +
        ' -d template.nii -o warp_anat2template.nii.gz', verbose)
    # sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
    warp_inverse.reverse()
    sct.run(
        'sct_concat_transfo -w ' + ','.join(warp_inverse) +
        ',-straight2templateAffine.txt,warp_straight2curve.nii.gz -d data.nii -o warp_template2anat.nii.gz',
        verbose)

    # Apply warping fields to anat and template
    if output_type == 1:
        sct.run(
            'sct_apply_transfo -i template.nii -o template2anat.nii.gz -d data.nii -w warp_template2anat.nii.gz -c 1',
            verbose)
        sct.run(
            'sct_apply_transfo -i data.nii -o anat2template.nii.gz -d template.nii -w warp_anat2template.nii.gz -c 1',
            verbose)

    # come back to parent folder
    os.chdir('..')

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(path_tmp + '/warp_template2anat.nii.gz',
                             'warp_template2anat.nii.gz', verbose)
    sct.generate_output_file(path_tmp + '/warp_anat2template.nii.gz',
                             'warp_anat2template.nii.gz', verbose)
    if output_type == 1:
        sct.generate_output_file(path_tmp + '/template2anat.nii.gz',
                                 'template2anat' + ext_data, verbose)
        sct.generate_output_file(path_tmp + '/anat2template.nii.gz',
                                 'anat2template' + ext_data, verbose)

    # Delete temporary files
    if remove_temp_files:
        sct.printv('\nDelete temporary files...', verbose)
        sct.run('rm -rf ' + path_tmp)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv(
        '\nFinished! Elapsed time: ' + str(int(round(elapsed_time))) + 's',
        verbose)

    # to view results
    sct.printv('\nTo view results, type:', verbose)
    sct.printv('fslview ' + fname_data + ' template2anat -b 0,4000 &', verbose,
               'info')
    sct.printv('fslview ' + fname_template + ' -b 0,5000 anat2template &\n',
               verbose, 'info')