Example #1
0
def centerline2roi(fname_image, folder_output='./', verbose=0):
    """
    This method converts a binary centerline image to a .roi centerline file.

    One "Marker ROI" entry is written for every non-zero voxel of the image (voxels sorted along z). Each entry
    stores the 1-based slice number and the physical X/Y offset between the in-plane centre of that slice and the
    centerline voxel.

    Args:
        fname_image: filename of the binary centerline image, in RPI orientation
        folder_output: path to output folder where to copy .roi centerline
        verbose: adjusts the verbosity of the logging.

    Returns: filename of the .roi centerline that has been created. The file is written in the current working
        directory, named after the input image, and copied to folder_output when that folder is not the current
        working directory.

    """
    file_data = sct.extract_fname(fname_image)[1]
    fname_output = file_data + '.roi'

    # one timestamp shared by every marker of the file: format it once instead of once per marker
    creation_date = datetime.now().strftime("%d %B %Y %H:%M:%S.%f %Z")
    ROI_TEMPLATE = 'Begin Marker ROI\n' \
                   '  Build version="7.0_33"\n' \
                   '  Annotation=""\n' \
                   '  Colour=0\n' \
                   '  Image source="{fname_segmentation}"\n' \
                   '  Created  "{creation_date}" by Operator ID="SCT"\n' \
                   '  Slice={slice_num}\n' \
                   '  Begin Shape\n' \
                   '    X={position_x}; Y={position_y}\n' \
                   '  End Shape\n' \
                   'End Marker ROI\n'

    im = Image(fname_image)
    nx, ny = im.dim[0], im.dim[1]
    coordinates_centerline = im.getNonZeroCoordinates(sorting='z')

    # the context manager closes the file even if a coordinate conversion raises
    with open(fname_output, "w") as f:
        sct.printv('\nWriting ROI file...', verbose)
        for coord in coordinates_centerline:
            # physical coordinates of the in-plane centre of slice coord.z
            coord_phys_center = im.transfo_pix2phys([[(nx - 1) / 2.0, (ny - 1) / 2.0, coord.z]])[0]
            coord_phys = im.transfo_pix2phys([[coord.x, coord.y, coord.z]])[0]
            f.write(
                ROI_TEMPLATE.format(
                    fname_segmentation=fname_image,
                    creation_date=creation_date,
                    slice_num=coord.z + 1,  # slice index written 1-based
                    position_x=coord_phys_center[0] - coord_phys[0],
                    position_y=coord_phys_center[1] - coord_phys[1]))

    if os.path.abspath(folder_output) != os.getcwd():
        shutil.copy(fname_output, folder_output)

    return fname_output
Example #2
0
def compute_ICBM152_centerline(dataset_info):
    """
    Extract the spinal cord centerline of the ICBM152 brain template.

    The template is fetched with download_data_template when the folder ``<path_data>icbm152/`` does not exist.
    Disk labels whose value is <= 22 (22 being L2) or in 48..52 are converted to physical coordinates. They are then
    used to compute the vertebral distribution of the NURBS-smoothed manual centerline, with 'PMG' as the reference
    label.

    :param dataset_info: dictionary containing dataset information (only the 'path_data' key is read)
    :return: Centerline object of the ICBM152 template
    """
    root = dataset_info['path_data']
    folder_icbm = root + 'icbm152/'

    if not os.path.isdir(folder_icbm):
        download_data_template(path_data=root, name='icbm152', force=False)

    disks = Image(folder_icbm + 'mni_icbm152_t1_tal_nlin_sym_09c_disks_manual.nii.gz')
    extra_labels = [48, 49, 50, 51, 52]
    disks_phys = []
    for voxel in disks.getNonZeroCoordinates(sorting='z', reverse_coord=True):
        # keep disks down to L2 (label 22) plus labels 48 to 52
        if not (voxel.value <= 22 or voxel.value in extra_labels):
            continue
        point = disks.transfo_pix2phys([[voxel.x, voxel.y, voxel.z]])[0]
        point.append(voxel.value)  # label value appended after the physical coordinates
        disks_phys.append(point)

    fit = smooth_centerline(
        folder_icbm + 'mni_icbm152_t1_centerline_manual.nii.gz',
        algo_fitting='nurbs',
        verbose=0,
        nurbs_pts_number=300,
        all_slices=False,
        phys_coordinates=True,
        remove_outliers=False
    )
    x_fit, y_fit, z_fit, dx_fit, dy_fit, dz_fit = fit

    centerline = Centerline(x_fit, y_fit, z_fit, dx_fit, dy_fit, dz_fit)
    centerline.compute_vertebral_distribution(disks_phys, label_reference='PMG')
    return centerline
def test_integrity(param_test):
    """
    Test integrity of function: compare the detected PMJ label with the manual one.

    The voxel of value 50 is looked up in the output image (input name + '_pmj' in param_test.path_output) and in
    the ground truth param_test.fname_gt. The Euclidean distance between their physical coordinates is compared to
    param_test.dist_threshold. The test fails when that distance is higher, or when it cannot be computed because
    the label is missing in either image.

    :param param_test: test parameters; reads path_output, file_input, fname_gt and dist_threshold, and updates
        output, status (set to 99 on failure) and results['distance_detection'].
    :return: param_test
    """
    # nan means "distance could not be computed"
    distance_detection = float('nan')

    # extract name of output PMJ label image
    file_pmj = os.path.join(param_test.path_output,
                            sct.add_suffix(param_test.file_input, '_pmj'))

    # open output label and ground truth
    im_pmj = Image(file_pmj)
    im_pmj_manual = Image(param_test.fname_gt)

    # voxel indices of the PMJ label (value 50)
    x_true, y_true, z_true = np.where(im_pmj_manual.data == 50)
    x_pred, y_pred, z_pred = np.where(im_pmj.data == 50)

    if len(x_true) and len(x_pred):
        # only the first labelled voxel of each image is used
        xt, yt, zt = im_pmj_manual.transfo_pix2phys(
            [[x_true[0], y_true[0], z_true[0]]])[0]
        xp, yp, zp = im_pmj.transfo_pix2phys(
            [[x_pred[0], y_pred[0], z_pred[0]]])[0]
        distance_detection = math.sqrt((xt - xp) ** 2 +
                                       (yt - yp) ** 2 +
                                       (zt - zp) ** 2)
    else:
        param_test.output += 'PMJ label (value 50) not found in prediction and/or ground truth.\n'

    param_test.output += 'Computed distance: ' + str(distance_detection) + '\n'
    param_test.output += 'Distance threshold (if computed Distance higher: fail): ' + str(
        param_test.dist_threshold) + '\n'

    # a nan distance must fail explicitly: "nan > threshold" is always False
    if math.isnan(distance_detection) or distance_detection > param_test.dist_threshold:
        param_test.status = 99
        param_test.output += '--> FAILED'
    else:
        param_test.output += '--> PASSED'

    # store the distance in the results structure
    param_test.results['distance_detection'] = distance_detection

    return param_test
Example #4
0
def generate_centerline(dataset_info, contrast='t1', regenerate=False):
    """
    This function generates spinal cord centerline from binary images (either an image of centerline or segmentation)

    For each subject, the centerline is loaded from 'centerline.npz' in the subject folder if that file exists and
    regenerate is False. Otherwise it is fitted from the centerline image, its vertebral distribution is computed
    from the disk labels, and it is saved in the subject folder.

    :param dataset_info: dictionary containing dataset information; reads 'path_data', 'subjects',
        'suffix_centerline' and 'suffix_disks'
    :param contrast: {'t1', 't2'}
    :param regenerate: if True, recompute (and overwrite) centerlines even when a saved one exists
    :return list of centerline objects, in the order of dataset_info['subjects']
    """
    current_path = os.getcwd()
    # The loop changes the working directory, so anchor path_data to the starting directory. Otherwise, relative
    # paths built from it would be resolved against the subject folder. os.path.join leaves absolute paths
    # unchanged.
    path_data = os.path.join(current_path, dataset_info['path_data'])
    list_subjects = dataset_info['subjects']
    list_centerline = []

    timer_centerline = sct.Timer(len(list_subjects))
    timer_centerline.start()
    try:
        for subject_name in list_subjects:
            path_data_subject = path_data + subject_name + '/' + contrast + '/'
            fname_image_centerline = path_data_subject + contrast + dataset_info['suffix_centerline'] + '.nii.gz'
            fname_image_disks = path_data_subject + contrast + dataset_info['suffix_disks'] + '.nii.gz'

            # go to output folder (the computed centerline is saved there)
            sct.printv('\nExtracting centerline from ' + path_data_subject)
            os.chdir(path_data_subject)

            fname_centerline = 'centerline'
            # if centerline exists, we load it, if not, we compute it
            if os.path.isfile(fname_centerline + '.npz') and not regenerate:
                centerline = Centerline(fname=path_data_subject + fname_centerline + '.npz')
            else:
                # extracting intervertebral disks
                im = Image(fname_image_disks)
                coord = im.getNonZeroCoordinates(sorting='z', reverse_coord=True)
                coord_physical = []
                for c in coord:
                    if c.value <= 22 or c.value in [48, 49, 50, 51, 52]:  # 22 corresponds to L2
                        c_p = im.transfo_pix2phys([[c.x, c.y, c.z]])[0]
                        c_p.append(c.value)
                        coord_physical.append(c_p)

                # extracting centerline from binary image and create centerline object with vertebral distribution
                x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
                    fname_image_centerline, algo_fitting='nurbs',
                    verbose=0, nurbs_pts_number=4000, all_slices=False, phys_coordinates=True, remove_outliers=False)
                centerline = Centerline(x_centerline_fit, y_centerline_fit, z_centerline,
                                        x_centerline_deriv, y_centerline_deriv, z_centerline_deriv)
                centerline.compute_vertebral_distribution(coord_physical)
                centerline.save_centerline(fname_output=fname_centerline)

            list_centerline.append(centerline)
            timer_centerline.add_iteration()
        timer_centerline.stop()
    finally:
        # always come back to the starting directory, even if a subject fails
        os.chdir(current_path)

    return list_centerline
def register_seg(seg_input, seg_dest):
    """Register two segmentations slice by slice, using translations only.

    The translation of each slice is the difference between the 2D centers of mass of the two segmentations in
    that slice. Slice 0 of seg_dest is paired with the seg_input slice located at the same physical position
    (found with transfo_pix2phys / transfo_phys2pix). seg_dest is expected to be smaller than seg_input.

    input:
        seg_input: name of moving segmentation file (type: string)
        seg_dest: name of fixed segmentation file (type: string)

    output:
        x_displacement: list of translation along x axis, one entry per seg_input slice (type: list)
        y_displacement: list of translation along y axis, one entry per seg_input slice (type: list)
    """
    img_moving = Image(seg_input)
    img_fixed = Image(seg_dest)
    data_moving = img_moving.data
    data_fixed = img_fixed.data
    n_fixed = data_fixed.shape[2]

    # the wording differs if the two segmentations have different sizes
    print('\nGet center of mass of the input segmentation for each slice (corresponding to a slice in the output segmentation)...')
    # TO DO: select only the slices corresponding to the output segmentation
    # voxel of the moving image located at the physical position of the first fixed voxel
    [[x_o, y_o, z_o]] = img_moving.transfo_phys2pix(img_fixed.transfo_pix2phys([[0, 0, 0]]))
    com_moving = [ndimage.measurements.center_of_mass(array(data_moving[:, :, z_o + k]))
                  for k in range(n_fixed)]

    print('\nGet center of mass of the output segmentation for each slice ...')
    com_fixed = [ndimage.measurements.center_of_mass(array(data_fixed[:, :, k]))
                 for k in range(n_fixed)]

    n_moving = data_moving.shape[2]
    x_displacement = [0] * n_moving
    y_displacement = [0] * n_moving
    print('\nGet displacement by voxel...')
    for k, ((x_in, y_in), (x_out, y_out)) in enumerate(zip(com_moving, com_fixed)):
        # WARNING: in ITK's coordinate system, this is actually Tx and not -Tx
        x_displacement[k] = -(x_out - x_in)
        # This is Ty in ITK's and fslview' coordinate systems
        y_displacement[k] = y_out - y_in

    return x_displacement, y_displacement
def register_seg(seg_input, seg_dest):
    """Estimate a per-slice translation between two segmentations from their centers of mass.

    Each slice of seg_dest is compared with the seg_input slice found at the same physical position: the first
    voxel of seg_dest is mapped into seg_input pixel space. The two segmentations may differ in size, but seg_dest
    must fit inside seg_input.

    input:
        seg_input: name of moving segmentation file (type: string)
        seg_dest: name of fixed segmentation file (type: string)

    output:
        x_displacement: list of translation along x axis for each slice (type: list)
        y_displacement: list of translation along y axis for each slice (type: list)
    """
    input_image = Image(seg_input)
    dest_image = Image(seg_dest)
    input_volume = input_image.data
    dest_volume = dest_image.data
    n_slices = dest_volume.shape[2]

    def slice_center(volume, z):
        # 2D center of mass of slice z of the volume
        return ndimage.measurements.center_of_mass(array(volume[:, :, z]))

    x_com_in = [0] * n_slices
    y_com_in = [0] * n_slices
    print("\nGet center of mass of the input segmentation for each slice (corresponding to a slice in the output segmentation)...")
    # TO DO: select only the slices corresponding to the output segmentation
    dest_origin_phys = dest_image.transfo_pix2phys([[0, 0, 0]])
    [[_, _, z_offset]] = input_image.transfo_phys2pix(dest_origin_phys)
    for z in range(n_slices):
        x_com_in[z], y_com_in[z] = slice_center(input_volume, z_offset + z)

    x_com_dest = [0] * n_slices
    y_com_dest = [0] * n_slices
    print("\nGet center of mass of the output segmentation for each slice ...")
    for z in range(n_slices):
        x_com_dest[z], y_com_dest[z] = slice_center(dest_volume, z)

    n_input_slices = input_volume.shape[2]
    x_displacement = [0] * n_input_slices
    y_displacement = [0] * n_input_slices
    print("\nGet displacement by voxel...")
    for z in range(n_slices):
        # WARNING: in ITK's coordinate system, this is actually Tx and not -Tx
        x_displacement[z] = -(x_com_dest[z] - x_com_in[z])
        # This is Ty in ITK's and fslview' coordinate systems
        y_displacement[z] = y_com_dest[z] - y_com_in[z]

    return x_displacement, y_displacement
def register_seg(seg_input, seg_dest):
    """Slice-by-slice registration by translation of two segmentations.

    For each slice of seg_dest, the translation is the difference between the 2D centers of mass of the
    corresponding slices of the two segmentations. The seg_input slice paired with slice 0 of seg_dest is found
    through the physical coordinates of the first voxel of seg_dest. seg_input must therefore contain every slice
    of seg_dest.

    input:
        seg_input: name of moving segmentation file (type: string)
        seg_dest: name of fixed segmentation file (type: string)

    output:
        x_displacement: list of translation along x axis for each slice (type: list)
        y_displacement: list of translation along y axis for each slice (type: list)
        Both lists have one entry per seg_input slice; only the first N entries (N = number of seg_dest slices)
        are computed, the others stay 0.
    """
    seg_input_img = Image(seg_input)
    seg_dest_img = Image(seg_dest)
    seg_input_data = seg_input_img.data
    seg_dest_data = seg_dest_img.data
    nz_dest = seg_dest_data.shape[2]

    x_center_of_mass_input = [0] * nz_dest
    y_center_of_mass_input = [0] * nz_dest
    print('\nGet center of mass of the input segmentation for each slice (corresponding to a slice in the output segmentation)...')
    # TO DO: select only the slices corresponding to the output segmentation
    # voxel of seg_input located at the physical position of the first voxel of seg_dest
    coord_origin_dest = seg_dest_img.transfo_pix2phys([[0, 0, 0]])
    [[x_o, y_o, z_o]] = seg_input_img.transfo_phys2pix(coord_origin_dest)
    for iz in range(nz_dest):
        x_center_of_mass_input[iz], y_center_of_mass_input[iz] = ndimage.measurements.center_of_mass(
            array(seg_input_data[:, :, z_o + iz]))

    x_center_of_mass_output = [0] * nz_dest
    y_center_of_mass_output = [0] * nz_dest
    print('\nGet center of mass of the output segmentation for each slice ...')
    for iz in range(nz_dest):
        x_center_of_mass_output[iz], y_center_of_mass_output[iz] = ndimage.measurements.center_of_mass(
            array(seg_dest_data[:, :, iz]))

    x_displacement = [0] * seg_input_data.shape[2]
    y_displacement = [0] * seg_input_data.shape[2]
    print('\nGet displacement by voxel...')
    for iz in range(nz_dest):
        # WARNING: in ITK's coordinate system, this is actually Tx and not -Tx
        x_displacement[iz] = -(x_center_of_mass_output[iz] - x_center_of_mass_input[iz])
        # This is Ty in ITK's and fslview' coordinate systems
        y_displacement[iz] = y_center_of_mass_output[iz] - y_center_of_mass_input[iz]

    return x_displacement, y_displacement
def register_seg(seg_input, seg_dest):
    """Slice-by-slice registration by translation of two segmentations.

    For each slice of seg_dest, the translation is the difference between the 2D centers of mass of the
    corresponding slices of the two segmentations. The seg_input slice paired with slice 0 of seg_dest is found
    through the physical coordinates of the first voxel of seg_dest. seg_input must therefore contain every slice
    of seg_dest.

    input:
        seg_input: name of moving segmentation file (type: string)
        seg_dest: name of fixed segmentation file (type: string)

    output:
        x_displacement: list of translation along x axis for each slice (type: list)
        y_displacement: list of translation along y axis for each slice (type: list)
        Both lists have one entry per seg_input slice; only the first N entries (N = number of seg_dest slices)
        are computed, the others stay 0.
    """
    seg_input_img = Image(seg_input)
    seg_dest_img = Image(seg_dest)
    seg_input_data = seg_input_img.data
    seg_dest_data = seg_dest_img.data
    nz_dest = seg_dest_data.shape[2]

    x_center_of_mass_input = [0] * nz_dest
    y_center_of_mass_input = [0] * nz_dest
    print("\nGet center of mass of the input segmentation for each slice (corresponding to a slice in the output segmentation)...")
    # TO DO: select only the slices corresponding to the output segmentation
    # voxel of seg_input located at the physical position of the first voxel of seg_dest
    coord_origin_dest = seg_dest_img.transfo_pix2phys([[0, 0, 0]])
    [[x_o, y_o, z_o]] = seg_input_img.transfo_phys2pix(coord_origin_dest)
    for iz in range(nz_dest):
        x_center_of_mass_input[iz], y_center_of_mass_input[iz] = ndimage.measurements.center_of_mass(
            array(seg_input_data[:, :, z_o + iz])
        )

    x_center_of_mass_output = [0] * nz_dest
    y_center_of_mass_output = [0] * nz_dest
    print("\nGet center of mass of the output segmentation for each slice ...")
    for iz in range(nz_dest):
        x_center_of_mass_output[iz], y_center_of_mass_output[iz] = ndimage.measurements.center_of_mass(
            array(seg_dest_data[:, :, iz])
        )

    x_displacement = [0] * seg_input_data.shape[2]
    y_displacement = [0] * seg_input_data.shape[2]
    print("\nGet displacement by voxel...")
    for iz in range(nz_dest):
        # WARNING: in ITK's coordinate system, this is actually Tx and not -Tx
        x_displacement[iz] = -(x_center_of_mass_output[iz] - x_center_of_mass_input[iz])
        # This is Ty in ITK's and fslview' coordinate systems
        y_displacement[iz] = y_center_of_mass_output[iz] - y_center_of_mass_input[iz]

    return x_displacement, y_displacement
def register_images(
    im_input,
    im_dest,
    mask="",
    paramreg=Paramreg(
        step="0", type="im", algo="Translation", metric="MI", iter="5", shrink="1", smooth="0", gradStep="0.5"
    ),
    remove_tmp_folder=1,
):
    """Slice-by-slice 2D registration of im_input onto im_dest with isct_antsRegistration.

    Both volumes (and the optional mask) are copied into a temporary folder ``tmp.<timestamp>`` and split along z
    with fslsplit. Each destination slice (fixed image) is then registered to an input slice (moving image) offset
    by z_o slices. For Rigid, Translation and Affine, the parameters of each 2D transform are read back from the
    ANTs ``.mat`` output. When the registration or parameter extraction of a slice fails, the parameters of the
    previous slice are reused.

    :param im_input: filename of the input (moving) volume
    :param im_dest: filename of the destination (fixed) volume
    :param mask: optional filename of a mask in destination space ("" for no mask)
    :param paramreg: Paramreg object; reads algo, metric, gradStep, iter, shrink and smooth
    :param remove_tmp_folder: if true, delete the temporary folder at the end
    :return:
        - Rigid: (x_displacement, y_displacement, theta_rotation), theta in radians
        - Translation: (x_displacement, y_displacement)
        - Affine: (x_displacement, y_displacement, matrix_def)
        - any other algo: None (registrations are run but no parameters are extracted)
    """
    path_i, root_i, ext_i = sct.extract_fname(im_input)
    path_d, root_d, ext_d = sct.extract_fname(im_dest)
    # names of the copies made in the temporary folder; only these are used once inside it
    fname_input_tmp = root_i + ext_i
    fname_dest_tmp = root_d + ext_d

    # set metricSize
    if paramreg.metric == "MI":
        metricSize = "32"  # corresponds to number of bins
    else:
        metricSize = "4"  # corresponds to radius (for CC, MeanSquares...)

    # initiate default parameters of antsRegistration transformation
    ants_registration_params = {
        "rigid": "",
        "affine": "",
        "compositeaffine": "",
        "similarity": "",
        "translation": "",
        "bspline": ",10",
        "gaussiandisplacementfield": ",3,0",
        "bsplinedisplacementfield": ",5,10",
        "syn": ",3,0",
        "bsplinesyn": ",3,32",
    }

    # Get image dimensions and retrieve nz
    print("\nGet image dimensions of destination image...")
    nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(im_dest)
    print(".. matrix size: " + str(nx) + " x " + str(ny) + " x " + str(nz))
    print(".. voxel size:  " + str(px) + "mm x " + str(py) + "mm x " + str(pz) + "mm")

    # per-slice transform parameters
    x_displacement = [0] * nz
    y_displacement = [0] * nz
    theta_rotation = [0] * nz
    matrix_def = [0] * nz

    # create temporary folder
    print("\nCreate temporary folder...")
    path_tmp = "tmp." + time.strftime("%y%m%d%H%M%S")
    sct.create_folder(path_tmp)
    print("\nCopy input data...")
    sct.run("cp " + im_input + " " + path_tmp + "/" + fname_input_tmp)
    sct.run("cp " + im_dest + " " + path_tmp + "/" + fname_dest_tmp)
    if mask:
        sct.run("cp " + mask + " " + path_tmp + "/mask.nii.gz")

    # go to temporary folder; the finally clause guarantees we come back even on error
    path_cwd = os.getcwd()
    os.chdir(path_tmp)
    try:
        # Split input, destination and mask volumes along z (use the local copies: the original paths may be
        # relative to the previous working directory)
        print("\nSplit input volume...")
        sct.run(sct.fsloutput + "fslsplit " + fname_input_tmp + " " + root_i + "_z -z")
        print("\nSplit destination volume...")
        sct.run(sct.fsloutput + "fslsplit " + fname_dest_tmp + " " + root_d + "_z -z")
        if mask:
            print("\nSplit mask volume...")
            sct.run(sct.fsloutput + "fslsplit mask.nii.gz mask_z -z")

        im_dest_img = Image(fname_dest_tmp)
        im_input_img = Image(fname_input_tmp)
        coord_origin_dest = im_dest_img.transfo_pix2phys([[0, 0, 0]])
        coord_origin_input = im_input_img.transfo_pix2phys([[0, 0, 0]])
        coord_diff_origin_z = coord_origin_dest[0][2] - coord_origin_input[0][2]
        # NOTE(review): this converts the *point* (0, 0, dz) to input pixel space, not the offset dz itself; confirm
        # that z_o is the intended slice offset between the two volumes.
        [[x_o, y_o, z_o]] = im_input_img.transfo_phys2pix([[0, 0, coord_diff_origin_z]])

        # loop across slices
        for i in range(nz):
            num = numerotation(i)
            num_2 = numerotation(int(num) + z_o)
            masking = "-x mask_z" + num + ".nii" if mask else ""

            cmd = (
                "isct_antsRegistration "
                "--dimensionality 2 "
                "--transform " + paramreg.algo
                + "[" + paramreg.gradStep + ants_registration_params[paramreg.algo.lower()] + "] "
                # [fixedImage,movingImage,metricWeight,nb_of_bins (MI) or radius (other)]
                "--metric " + paramreg.metric
                + "[" + root_d + "_z" + num + ".nii," + root_i + "_z" + num_2 + ".nii,1," + metricSize + "] "
                "--convergence " + paramreg.iter + " "
                "--shrink-factors " + paramreg.shrink + " "
                "--smoothing-sigmas " + paramreg.smooth + "mm "
                # output prefix; the affine parameters are read from transform_<num>0GenericAffine.mat below
                "--output [transform_" + num + "] "
                "--interpolation BSpline[3] " + masking
            )

            try:
                sct.run(cmd)
                if paramreg.algo in ("Rigid", "Translation", "Affine"):
                    matfile = loadmat("transform_" + num + "0GenericAffine.mat", struct_as_record=True)
                    # indices 0-3 hold the 2x2 matrix, 4-5 the translation
                    array_transfo = matfile["AffineTransform_double_2_2"]
                    x_displacement[i] = -array_transfo[4][0]  # NOTE(review): sign/axis convention to confirm
                    y_displacement[i] = array_transfo[5][0]
                    if paramreg.algo == "Affine":
                        matrix_def[i] = [
                            [array_transfo[0][0], array_transfo[1][0]],
                            [array_transfo[2][0], array_transfo[3][0]],
                        ]
                    else:
                        # rotation angle (radians) taken as asin of matrix coefficient [2]
                        theta_rotation[i] = asin(array_transfo[2][0])
            except (Exception, SystemExit):
                # Registration or parameter extraction failed for this slice: reuse the previous slice's values.
                # For the first slice, index -1 reads the last entry of the list.
                # NOTE(review): SystemExit is caught because sct.run may abort through it when a command fails;
                # confirm. Unlike the former bare except, KeyboardInterrupt is no longer swallowed.
                if paramreg.algo in ("Rigid", "Translation", "Affine"):
                    x_displacement[i] = x_displacement[i - 1]
                    y_displacement[i] = y_displacement[i - 1]
                    if paramreg.algo == "Affine":
                        matrix_def[i] = matrix_def[i - 1]
                    else:
                        theta_rotation[i] = theta_rotation[i - 1]
    finally:
        os.chdir(path_cwd)

    # Delete tmp folder
    if remove_tmp_folder:
        print("\nRemove temporary files...")
        sct.run("rm -rf " + path_tmp)

    if paramreg.algo == "Rigid":
        # check if the displacement are not inverted (x_dis = -x_disp...); theta is in radian
        return x_displacement, y_displacement, theta_rotation
    if paramreg.algo == "Translation":
        return x_displacement, y_displacement
    if paramreg.algo == "Affine":
        return x_displacement, y_displacement, matrix_def
    # TO DO: different treatment for other algo
    return None
Example #10
0
def register2d_columnwise(fname_src,
                          fname_dest,
                          fname_warp='warp_forward.nii.gz',
                          fname_warp_inv='warp_inverse.nii.gz',
                          verbose=0,
                          path_qc='./',
                          smoothWarpXY=1):
    """
    Column-wise non-linear registration of segmentations. Based on an idea from Allan Martin.
    - Assumes src/dest are segmentations (not necessarily binary), and already registered by center of mass
    - Assumes src/dest are in RPI orientation.
    - Split along Z, then for each slice:
    - scale in R-L direction to match src/dest
    - loop across R-L columns and register by (i) matching center of mass and (ii) scaling.
    :param fname_src:
    :param fname_dest:
    :param fname_warp:
    :param fname_warp_inv:
    :param verbose:
    :return:
    """

    # initialization
    th_nonzero = 0.5  # values below are considered zero

    # for display stuff
    if verbose == 2:
        import matplotlib
        matplotlib.use('Agg')  # prevent display figure
        import matplotlib.pyplot as plt

    # Get image dimensions and retrieve nz
    sct.printv('\nGet image dimensions of destination image...', verbose)
    nx, ny, nz, nt, px, py, pz, pt = Image(fname_dest).dim
    sct.printv('  matrix size: ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz),
               verbose)
    sct.printv(
        '  voxel size:  ' + str(px) + 'mm x ' + str(py) + 'mm x ' + str(pz) +
        'mm', verbose)

    # Split source volume along z
    sct.printv('\nSplit input volume...', verbose)
    from sct_image import split_data
    im_src = Image('src.nii')
    split_source_list = split_data(im_src, 2)
    for im in split_source_list:
        im.save()

    # Split destination volume along z
    sct.printv('\nSplit destination volume...', verbose)
    im_dest = Image('dest.nii')
    split_dest_list = split_data(im_dest, 2)
    for im in split_dest_list:
        im.save()

    # open image
    data_src = im_src.data
    data_dest = im_dest.data

    if len(data_src.shape) == 2:
        # reshape 2D data into pseudo 3D (only one slice)
        new_shape = list(data_src.shape)
        new_shape.append(1)
        new_shape = tuple(new_shape)
        data_src = data_src.reshape(new_shape)
        data_dest = data_dest.reshape(new_shape)

    # initialize forward warping field (defined in destination space)
    warp_x = np.zeros(data_dest.shape)
    warp_y = np.zeros(data_dest.shape)

    # initialize inverse warping field (defined in source space)
    warp_inv_x = np.zeros(data_src.shape)
    warp_inv_y = np.zeros(data_src.shape)

    # Loop across slices
    sct.printv('\nEstimate columnwise transformation...', verbose)
    for iz in range(0, nz):
        print str(iz) + '/' + str(nz) + '..',

        # PREPARE COORDINATES
        # ============================================================
        # get indices of x and y coordinates
        row, col = np.indices((nx, ny))
        # build 2xn array of coordinates in pixel space
        # ordering of indices is as follows:
        # coord_init_pix[:, 0] = 0, 0, 0, ..., 1, 1, 1..., nx, nx, nx
        # coord_init_pix[:, 1] = 0, 1, 2, ..., 0, 1, 2..., 0, 1, 2
        coord_init_pix = np.array([
            row.ravel(),
            col.ravel(),
            np.array(np.ones(len(row.ravel())) * iz)
        ]).T
        # convert coordinates to physical space
        coord_init_phy = np.array(im_src.transfo_pix2phys(coord_init_pix))
        # get 2d data from the selected slice
        src2d = data_src[:, :, iz]
        dest2d = data_dest[:, :, iz]
        # julien 20161105
        #<<<
        # threshold at 0.5
        src2d[src2d < th_nonzero] = 0
        dest2d[dest2d < th_nonzero] = 0
        # get non-zero coordinates, and transpose to obtain nx2 dimensions
        coord_src2d = np.array(np.where(src2d > 0)).T
        coord_dest2d = np.array(np.where(dest2d > 0)).T
        # here we use 0.5 as threshold for non-zero value
        # coord_src2d = np.array(np.where(src2d > th_nonzero)).T
        # coord_dest2d = np.array(np.where(dest2d > th_nonzero)).T
        #>>>

        # SCALING R-L (X dimension)
        # ============================================================
        # sum data across Y to obtain 1D signal: src_y and dest_y
        src1d = np.sum(src2d, 1)
        dest1d = np.sum(dest2d, 1)
        # make sure there are non-zero data in src or dest
        if np.any(src1d > th_nonzero) and np.any(dest1d > th_nonzero):
            # retrieve min/max of non-zeros elements (edge of the segmentation)
            # julien 20161105
            # <<<
            src1d_min, src1d_max = min(np.where(src1d != 0)[0]), max(
                np.where(src1d != 0)[0])
            dest1d_min, dest1d_max = min(np.where(dest1d != 0)[0]), max(
                np.where(dest1d != 0)[0])
            # for i in xrange(len(src1d)):
            #     if src1d[i] > 0.5:
            #         found index above 0.5, exit loop
            # break
            # get indices (in continuous space) at half-maximum of upward and downward slope
            # src1d_min, src1d_max = find_index_halfmax(src1d)
            # dest1d_min, dest1d_max = find_index_halfmax(dest1d)
            # >>>
            # 1D matching between src_y and dest_y
            mean_dest_x = (dest1d_max + dest1d_min) / 2
            mean_src_x = (src1d_max + src1d_min) / 2
            # compute x-scaling factor
            Sx = (dest1d_max - dest1d_min + 1) / float(src1d_max - src1d_min +
                                                       1)
            # apply transformation to coordinates
            coord_src2d_scaleX = np.copy(
                coord_src2d)  # need to use np.copy to avoid copying pointer
            coord_src2d_scaleX[:, 0] = (coord_src2d[:, 0] -
                                        mean_src_x) * Sx + mean_dest_x
            coord_init_pix_scaleX = np.copy(coord_init_pix)
            coord_init_pix_scaleX[:, 0] = (coord_init_pix[:, 0] -
                                           mean_src_x) * Sx + mean_dest_x
            coord_init_pix_scaleXinv = np.copy(coord_init_pix)
            coord_init_pix_scaleXinv[:, 0] = (
                coord_init_pix[:, 0] - mean_dest_x) / float(Sx) + mean_src_x
            # apply transformation to image
            from skimage.transform import warp
            row_scaleXinv = np.reshape(coord_init_pix_scaleXinv[:, 0],
                                       [nx, ny])
            src2d_scaleX = warp(src2d, np.array([row_scaleXinv, col]), order=1)

            # ============================================================
            # COLUMN-WISE REGISTRATION (Y dimension for each Xi)
            # ============================================================
            coord_init_pix_scaleY = np.copy(
                coord_init_pix)  # need to use np.copy to avoid copying pointer
            coord_init_pix_scaleYinv = np.copy(
                coord_init_pix)  # need to use np.copy to avoid copying pointer
            # coord_src2d_scaleXY = np.copy(coord_src2d_scaleX)  # need to use np.copy to avoid copying pointer
            # loop across columns (X dimension)
            for ix in xrange(nx):
                # retrieve 1D signal along Y
                src1d = src2d_scaleX[ix, :]
                dest1d = dest2d[ix, :]
                # make sure there are non-zero data in src or dest
                if np.any(src1d > th_nonzero) and np.any(dest1d > th_nonzero):
                    # retrieve min/max of non-zeros elements (edge of the segmentation)
                    # src1d_min, src1d_max = min(np.nonzero(src1d)[0]), max(np.nonzero(src1d)[0])
                    # dest1d_min, dest1d_max = min(np.nonzero(dest1d)[0]), max(np.nonzero(dest1d)[0])
                    # 1D matching between src_y and dest_y
                    # Ty = (dest1d_max + dest1d_min)/2 - (src1d_max + src1d_min)/2
                    # Sy = (dest1d_max - dest1d_min) / float(src1d_max - src1d_min)
                    # apply translation and scaling to coordinates in column
                    # get indices (in continuous space) at half-maximum of upward and downward slope
                    # src1d_min, src1d_max = find_index_halfmax(src1d)
                    # dest1d_min, dest1d_max = find_index_halfmax(dest1d)
                    src1d_min, src1d_max = np.min(
                        np.where(src1d > th_nonzero)), np.max(
                            np.where(src1d > th_nonzero))
                    dest1d_min, dest1d_max = np.min(
                        np.where(dest1d > th_nonzero)), np.max(
                            np.where(dest1d > th_nonzero))
                    # 1D matching between src_y and dest_y
                    mean_dest_y = (dest1d_max + dest1d_min) / 2
                    mean_src_y = (src1d_max + src1d_min) / 2
                    # Tx = (dest1d_max + dest1d_min)/2 - (src1d_max + src1d_min)/2
                    Sy = (dest1d_max - dest1d_min + 1) / float(src1d_max -
                                                               src1d_min + 1)
                    # apply forward transformation (in pixel space)
                    # below: only for debugging purpose
                    # coord_src2d_scaleX = np.copy(coord_src2d)  # need to use np.copy to avoid copying pointer
                    # coord_src2d_scaleX[:, 0] = (coord_src2d[:, 0] - mean_src) * Sx + mean_dest
                    # coord_init_pix_scaleY = np.copy(coord_init_pix)  # need to use np.copy to avoid copying pointer
                    # coord_init_pix_scaleY[:, 0] = (coord_init_pix[:, 0] - mean_src ) * Sx + mean_dest
                    range_x = range(ix * ny, ix * ny + nx)
                    coord_init_pix_scaleY[range_x,
                                          1] = (coord_init_pix[range_x, 1] -
                                                mean_src_y) * Sy + mean_dest_y
                    coord_init_pix_scaleYinv[
                        range_x, 1] = (coord_init_pix[range_x, 1] -
                                       mean_dest_y) / float(Sy) + mean_src_y
            # apply transformation to image
            col_scaleYinv = np.reshape(coord_init_pix_scaleYinv[:, 1],
                                       [nx, ny])
            src2d_scaleXY = warp(src2d,
                                 np.array([row_scaleXinv, col_scaleYinv]),
                                 order=1)
            # regularize Y warping fields
            from skimage.filters import gaussian
            col_scaleY = np.reshape(coord_init_pix_scaleY[:, 1], [nx, ny])
            col_scaleYsmooth = gaussian(col_scaleY, smoothWarpXY)
            col_scaleYinvsmooth = gaussian(col_scaleYinv, smoothWarpXY)
            # apply smoothed transformation to image
            src2d_scaleXYsmooth = warp(
                src2d, np.array([row_scaleXinv, col_scaleYinvsmooth]), order=1)
            # reshape warping field as 1d
            coord_init_pix_scaleY[:, 1] = col_scaleYsmooth.ravel()
            coord_init_pix_scaleYinv[:, 1] = col_scaleYinvsmooth.ravel()
            # display
            if verbose == 2:
                # FIG 1
                plt.figure(figsize=(15, 3))
                # plot #1
                ax = plt.subplot(141)
                plt.imshow(np.swapaxes(src2d, 1, 0),
                           cmap=plt.cm.gray,
                           interpolation='none')
                plt.hold(True)  # add other layer
                plt.imshow(np.swapaxes(dest2d, 1, 0),
                           cmap=plt.cm.copper,
                           interpolation='none',
                           alpha=0.5)
                plt.title('src')
                plt.xlabel('x')
                plt.ylabel('y')
                plt.xlim(mean_dest_x - 15, mean_dest_x + 15)
                plt.ylim(mean_dest_y - 15, mean_dest_y + 15)
                ax.grid(True, color='w')
                # plot #2
                ax = plt.subplot(142)
                plt.imshow(np.swapaxes(src2d_scaleX, 1, 0),
                           cmap=plt.cm.gray,
                           interpolation='none')
                plt.hold(True)  # add other layer
                plt.imshow(np.swapaxes(dest2d, 1, 0),
                           cmap=plt.cm.copper,
                           interpolation='none',
                           alpha=0.5)
                plt.title('src_scaleX')
                plt.xlabel('x')
                plt.ylabel('y')
                plt.xlim(mean_dest_x - 15, mean_dest_x + 15)
                plt.ylim(mean_dest_y - 15, mean_dest_y + 15)
                ax.grid(True, color='w')
                # plot #3
                ax = plt.subplot(143)
                plt.imshow(np.swapaxes(src2d_scaleXY, 1, 0),
                           cmap=plt.cm.gray,
                           interpolation='none')
                plt.hold(True)  # add other layer
                plt.imshow(np.swapaxes(dest2d, 1, 0),
                           cmap=plt.cm.copper,
                           interpolation='none',
                           alpha=0.5)
                plt.title('src_scaleXY')
                plt.xlabel('x')
                plt.ylabel('y')
                plt.xlim(mean_dest_x - 15, mean_dest_x + 15)
                plt.ylim(mean_dest_y - 15, mean_dest_y + 15)
                ax.grid(True, color='w')
                # plot #4
                ax = plt.subplot(144)
                plt.imshow(np.swapaxes(src2d_scaleXYsmooth, 1, 0),
                           cmap=plt.cm.gray,
                           interpolation='none')
                plt.hold(True)  # add other layer
                plt.imshow(np.swapaxes(dest2d, 1, 0),
                           cmap=plt.cm.copper,
                           interpolation='none',
                           alpha=0.5)
                plt.title('src_scaleXYsmooth (s=' + str(smoothWarpXY) + ')')
                plt.xlabel('x')
                plt.ylabel('y')
                plt.xlim(mean_dest_x - 15, mean_dest_x + 15)
                plt.ylim(mean_dest_y - 15, mean_dest_y + 15)
                ax.grid(True, color='w')
                # save figure
                plt.savefig(path_qc + 'register2d_columnwise_image_z' +
                            str(iz) + '.png')
                plt.close()

            # ============================================================
            # CALCULATE TRANSFORMATIONS
            # ============================================================
            # calculate forward transformation (in physical space)
            coord_init_phy_scaleX = np.array(
                im_dest.transfo_pix2phys(coord_init_pix_scaleX))
            coord_init_phy_scaleY = np.array(
                im_dest.transfo_pix2phys(coord_init_pix_scaleY))
            # calculate inverse transformation (in physical space)
            coord_init_phy_scaleXinv = np.array(
                im_src.transfo_pix2phys(coord_init_pix_scaleXinv))
            coord_init_phy_scaleYinv = np.array(
                im_src.transfo_pix2phys(coord_init_pix_scaleYinv))
            # compute displacement per pixel in destination space (for forward warping field)
            warp_x[:, :, iz] = np.array([
                coord_init_phy_scaleXinv[i, 0] - coord_init_phy[i, 0]
                for i in xrange(nx * ny)
            ]).reshape((nx, ny))
            warp_y[:, :, iz] = np.array([
                coord_init_phy_scaleYinv[i, 1] - coord_init_phy[i, 1]
                for i in xrange(nx * ny)
            ]).reshape((nx, ny))
            # compute displacement per pixel in source space (for inverse warping field)
            warp_inv_x[:, :, iz] = np.array([
                coord_init_phy_scaleX[i, 0] - coord_init_phy[i, 0]
                for i in xrange(nx * ny)
            ]).reshape((nx, ny))
            warp_inv_y[:, :, iz] = np.array([
                coord_init_phy_scaleY[i, 1] - coord_init_phy[i, 1]
                for i in xrange(nx * ny)
            ]).reshape((nx, ny))

    # Generate forward warping field (defined in destination space)
    generate_warping_field(fname_dest, warp_x, warp_y, fname_warp, verbose)
    # Generate inverse warping field (defined in source space)
    generate_warping_field(fname_src, warp_inv_x, warp_inv_y, fname_warp_inv,
                           verbose)
def _slicewise_split_warp_xy(fname_warp, fname_warp_inv, num):
    """Split a forward/inverse 2D warping-field pair into its x and y components.

    Runs ``isct_c3d -mcs`` on each field (in the current directory) so that the
    x and y displacements can later be concatenated along z separately.

    Args:
        fname_warp: filename of the forward multi-component warping field.
        fname_warp_inv: filename of the inverse multi-component warping field.
        num: slice number string used to name the outputs.

    Returns:
        tuple (warp_x, warp_x_inverse, warp_y, warp_y_inverse) holding the four
        ``transform_<num>0[Inverse]Warp_{x,y}.nii.gz`` filenames written.
    """
    warp_x = 'transform_' + num + '0Warp_x.nii.gz'
    warp_y = 'transform_' + num + '0Warp_y.nii.gz'
    warp_x_inv = 'transform_' + num + '0InverseWarp_x.nii.gz'
    warp_y_inv = 'transform_' + num + '0InverseWarp_y.nii.gz'
    sct.run('isct_c3d -mcs ' + fname_warp + ' -oo ' + warp_x + ' ' + warp_y)
    sct.run('isct_c3d -mcs ' + fname_warp_inv + ' -oo ' + warp_x_inv + ' ' +
            warp_y_inv)
    return warp_x, warp_x_inv, warp_y, warp_y_inv


def _slicewise_affine_mat_to_warp(num, num_2, root_i, root_d):
    """Convert slice ``num``'s ANTs affine .mat into forward/inverse NIfTI warps.

    A null warping field is generated on the registered slice and on the
    destination slice. Each is then composed with the affine .mat (forward) or
    its inverse (``-i``) using ``isct_ComposeMultiTransform``.

    Args:
        num: destination slice number string.
        num_2: matching source slice number string (``num`` shifted by the
            z-offset between the two image origins).
        root_i: filename root of the source image.
        root_d: filename root of the destination image.

    Returns:
        tuple (forward warp filename, inverse warp filename).
    """
    name_dest = root_d + '_z' + num + '.nii'
    # antsRegistration's '--output' names the registered slice with the
    # *source* slice number (num_2), not the destination one.
    name_reg = root_i + '_z' + num_2 + 'reg.nii'
    name_output_warp = 'warp_from_mat_' + num_2 + '.nii.gz'
    name_output_warp_inverse = 'warp_from_mat_' + num + '_inverse.nii.gz'
    name_warp_null = 'warp_null_' + num + '.nii.gz'
    name_warp_null_dest = 'warp_null_dest' + num + '.nii.gz'
    name_warp_mat = 'transform_' + num + '0GenericAffine.mat'
    # Only the number of slices is needed to build the zero translations.
    nz_reg = Image(name_reg).dim[2]
    nz_dest = Image(name_dest).dim[2]
    generate_warping_field(name_reg,
                           x_trans=[0] * nz_reg,
                           y_trans=[0] * nz_reg,
                           fname=name_warp_null,
                           verbose=0)
    generate_warping_field(name_dest,
                           x_trans=[0] * nz_dest,
                           y_trans=[0] * nz_dest,
                           fname=name_warp_null_dest,
                           verbose=0)
    # Concatenate the .mat with the null warp to get an equivalent NIfTI warp.
    sct.run('isct_ComposeMultiTransform 2 ' + name_output_warp + ' -R ' +
            name_reg + ' ' + name_warp_null + ' ' + name_warp_mat)
    sct.run('isct_ComposeMultiTransform 2 ' + name_output_warp_inverse +
            ' -R ' + name_dest + ' ' + name_warp_null_dest + ' -i ' +
            name_warp_mat)
    return name_output_warp, name_output_warp_inverse


def _slicewise_copy_previous_warp(algo, num, num_prev):
    """Reuse slice ``num_prev``'s warping field for slice ``num``.

    This is used after antsRegistration failed on slice ``num``.
    The Affine path builds its field from the GenericAffine .mat and keeps only
    the split x/y components, so those are copied directly. The other warp
    algorithms copy the full forward/inverse fields and then split them.

    Returns:
        tuple (warp_x, warp_x_inverse, warp_y, warp_y_inverse) of filenames for
        slice ``num``.
    """
    if algo == 'Affine':
        names = []
        for suffix in ('0Warp_x.nii.gz', '0InverseWarp_x.nii.gz',
                       '0Warp_y.nii.gz', '0InverseWarp_y.nii.gz'):
            sct.run('cp transform_' + num_prev + suffix + ' transform_' + num +
                    suffix)
            names.append('transform_' + num + suffix)
        return tuple(names)
    fname_warp = 'transform_' + num + '0Warp.nii.gz'
    fname_warp_inv = 'transform_' + num + '0InverseWarp.nii.gz'
    sct.run('cp transform_' + num_prev + '0Warp.nii.gz ' + fname_warp)
    sct.run('cp transform_' + num_prev + '0InverseWarp.nii.gz ' +
            fname_warp_inv)
    return _slicewise_split_warp_xy(fname_warp, fname_warp_inv, num)


def register_images(
        fname_source,
        fname_dest,
        mask='',
        paramreg=Paramreg(step='0',
                          type='im',
                          algo='Translation',
                          metric='MI',
                          iter='5',
                          shrink='1',
                          smooth='0',
                          gradStep='0.5'),
        ants_registration_params={
            'rigid': '',
            'affine': '',
            'compositeaffine': '',
            'similarity': '',
            'translation': '',
            'bspline': ',10',
            'gaussiandisplacementfield': ',3,0',
            'bsplinedisplacementfield': ',5,10',
            'syn': ',3,0',
            'bsplinesyn': ',1,3'
        },
        remove_tmp_folder=1):
    """Slice-by-slice registration of two images.

    We first split the 3D images into 2D images (and the mask if inputted). Then we register slices of the two images
    that physically correspond to one another, based on the physical origin of each image. The images can be of
    different sizes, but the destination image must be smaller than the input image. We do that using antsRegistration
    in 2D. Once this has been done for each slice, we gather the results and return them.
    Algorithms implemented: Translation, Rigid, Affine, SyN and BSplineSyN. ``paramreg.algo`` is compared
    case-sensitively with these spellings, and its lowercase form must be a key of ``ants_registration_params``.
    N.B.: If the mask is inputted, it must also be 3D and it must be in the same space as the destination image.

    Intermediate files are written to a folder 'tmp.<timestamp>' created in the current directory. The working
    directory is switched to that folder during processing and is always restored afterwards, even on error. The
    folder is removed only after a successful run when ``remove_tmp_folder`` is true.
    If antsRegistration (or the post-processing) fails for a slice, the result of the previous slice is reused.

    input:
        fname_source: name of moving image (type: string)
        fname_dest: name of fixed image (type: string)
        mask[optional]: name of mask file (type: string) (parameter -x of antsRegistration)
        paramreg[optional]: parameters of antsRegistration (type: Paramreg class from sct_register_multimodal)
        ants_registration_params[optional]: specific algorithm's parameters for antsRegistration (type: dictionary)
        remove_tmp_folder[optional]: delete the temporary folder when done (type: int/bool)

    output:
        if algo=='Translation':
            x_displacement: list of translation along x axis for each slice (type: list)
            y_displacement: list of translation along y axis for each slice (type: list)
        if algo=='Rigid':
            x_displacement: list of translation along x axis for each slice (type: list)
            y_displacement: list of translation along y axis for each slice (type: list)
            theta_rotation: list of rotation angle in radian (and in ITK's coordinate system) for each slice (type: list)
        if algo in ('Affine', 'SyN', 'BSplineSyN'):
            returns None; writes Warp_total.nii.gz (forward, header copied from the destination image) and
            Warp_total_inverse.nii.gz (inverse, header copied from the source image) in the current directory.
            These are the concatenations of the slice-by-slice warps.
    """
    # Extracting names
    _, root_i, ext_i = sct.extract_fname(fname_source)
    _, root_d, ext_d = sct.extract_fname(fname_dest)
    # Names of the copies made inside the temporary folder. Once we chdir
    # there, the caller's (possibly relative) paths are no longer valid.
    fname_source_tmp = root_i + ext_i
    fname_dest_tmp = root_d + ext_d

    # metricSize: number of bins for MI, radius for CC, MeanSquares...
    metricSize = '32' if paramreg.metric == 'MI' else '4'

    # Get image dimensions and retrieve nz
    print('\nGet image dimensions of destination image...')
    nx, ny, nz, nt, px, py, pz, pt = Image(fname_dest).dim
    print('.. matrix size: ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz))
    print('.. voxel size:  ' + str(px) + 'mm x ' + str(py) + 'mm x ' +
          str(pz) + 'mm')

    # Per-slice rigid/translation parameters
    x_displacement = [0] * nz
    y_displacement = [0] * nz
    theta_rotation = [0] * nz

    use_warp = paramreg.algo in ('BSplineSyN', 'SyN', 'Affine')
    list_warp_x = []
    list_warp_x_inv = []
    list_warp_y = []
    list_warp_y_inv = []
    # If modified, the name should also be modified in msct_register
    # (algo slicereg2d_bsplinesyn and slicereg2d_syn).
    name_warp_final = 'Warp_total'

    # create temporary folder
    print('\nCreate temporary folder...')
    path_tmp = 'tmp.' + time.strftime("%y%m%d%H%M%S")
    sct.create_folder(path_tmp)
    print('\nCopy input data...')
    sct.run('cp ' + fname_source + ' ' + path_tmp + '/' + fname_source_tmp)
    sct.run('cp ' + fname_dest + ' ' + path_tmp + '/' + fname_dest_tmp)
    if mask:
        sct.run('cp ' + mask + ' ' + path_tmp + '/mask.nii.gz')

    # go to temporary folder (always come back, even if something fails)
    path_cwd = os.getcwd()
    os.chdir(path_tmp)
    try:
        # Split input, destination and mask volumes along z
        print('\nSplit input volume...')
        from sct_split_data import split_data
        split_data(fname_source_tmp, 2, '_z')
        print('\nSplit destination volume...')
        split_data(fname_dest_tmp, 2, '_z')
        if mask:
            print('\nSplit mask volume...')
            split_data('mask.nii.gz', 2, '_z')

        # Slice offset between the two images, from their physical origins
        coord_origin_dest = Image(fname_dest_tmp).transfo_pix2phys([[0, 0, 0]])
        coord_origin_input = Image(fname_source_tmp).transfo_pix2phys(
            [[0, 0, 0]])
        coord_diff_origin = (asarray(coord_origin_dest[0]) -
                             asarray(coord_origin_input[0])).tolist()
        z_o = coord_diff_origin[2] * 1.0 / pz

        # loop across slices
        for i in range(nz):
            num = numerotation(i)
            num_2 = numerotation(int(num) + int(z_o))
            masking = '-x mask_z' + num + '.nii' if mask else ''

            cmd = (
                'isct_antsRegistration '
                '--dimensionality 2 '
                '--transform ' + paramreg.algo + '[' + str(paramreg.gradStep) +
                ants_registration_params[paramreg.algo.lower()] + '] '
                # [fixedImage,movingImage,metricWeight,nb_of_bins (MI) or radius (other)]
                '--metric ' + paramreg.metric + '[' + root_d + '_z' + num +
                '.nii' + ',' + root_i + '_z' + num_2 + '.nii' + ',1,' +
                metricSize + '] '
                '--convergence ' + str(paramreg.iter) + ' '
                '--shrink-factors ' + str(paramreg.shrink) + ' '
                '--smoothing-sigmas ' + str(paramreg.smooth) + 'mm '
                # --> file.mat (contains Tx, Ty, theta)
                '--output [transform_' + num + ',' + root_i + '_z' + num_2 +
                'reg.nii] '
                '--interpolation BSpline[3] ' + masking)

            warps = None  # (x, x_inv, y, y_inv) filenames for warp algorithms
            try:
                sct.run(cmd)
                if paramreg.algo in ('Rigid', 'Translation'):
                    matfile = loadmat('transform_' + num + '0GenericAffine.mat',
                                      struct_as_record=True)
                    array_transfo = matfile['AffineTransform_double_2_2']
                    # Tx in ITK's coordinate system
                    x_displacement[i] = array_transfo[4][0]
                    # Ty in ITK's and fslview's coordinate systems
                    y_displacement[i] = array_transfo[5][0]
                    # angle of rotation theta in ITK's coordinate system
                    # (minus theta for fslview)
                    theta_rotation[i] = asin(array_transfo[2])
                elif paramreg.algo == 'Affine':
                    fname_warp, fname_warp_inv = _slicewise_affine_mat_to_warp(
                        num, num_2, root_i, root_d)
                    warps = _slicewise_split_warp_xy(fname_warp, fname_warp_inv,
                                                     num)
                elif paramreg.algo in ('BSplineSyN', 'SyN'):
                    # Split x and y separately: merging 3d warping fields
                    # along z does not work properly.
                    warps = _slicewise_split_warp_xy(
                        'transform_' + num + '0Warp.nii.gz',
                        'transform_' + num + '0InverseWarp.nii.gz', num)
            # If ants (or the post-processing) fails, reuse the previous
            # slice's transformation. SystemExit is caught as well because
            # sct.run may terminate through sys.exit on a failing command
            # (NOTE(review): confirm); KeyboardInterrupt is no longer swallowed.
            except (Exception, SystemExit):
                if paramreg.algo in ('Rigid', 'Translation'):
                    x_displacement[i] = x_displacement[i - 1]
                    y_displacement[i] = y_displacement[i - 1]
                    theta_rotation[i] = theta_rotation[i - 1]
                if use_warp:
                    print('Problem with ants for slice ' + str(i) +
                          '. Copy of the last warping field.')
                    warps = _slicewise_copy_previous_warp(
                        paramreg.algo, num, numerotation(i - 1))

            # List names of warping fields for future merge
            if warps is not None:
                warp_x, warp_x_inv, warp_y, warp_y_inv = warps
                list_warp_x.append(warp_x)
                list_warp_x_inv.append(warp_x_inv)
                list_warp_y.append(warp_y)
                list_warp_y_inv.append(warp_y_inv)

        if use_warp:
            print('\nMerge along z of the warping fields...')
            sct.run('sct_concat_data -i ' + ','.join(list_warp_x) + ' -o ' +
                    name_warp_final + '_x.nii.gz -dim z')
            sct.run('sct_concat_data -i ' + ','.join(list_warp_x_inv) +
                    ' -o ' + name_warp_final + '_x_inverse.nii.gz -dim z')
            sct.run('sct_concat_data -i ' + ','.join(list_warp_y) + ' -o ' +
                    name_warp_final + '_y.nii.gz -dim z')
            sct.run('sct_concat_data -i ' + ','.join(list_warp_y_inv) +
                    ' -o ' + name_warp_final + '_y_inverse.nii.gz -dim z')
            print('\nChange resolution of warping fields to match the resolution of the destination image...')
            from sct_copy_header import copy_header
            copy_header(fname_dest_tmp, name_warp_final + '_x.nii.gz')
            copy_header(fname_source_tmp, name_warp_final + '_x_inverse.nii.gz')
            copy_header(fname_dest_tmp, name_warp_final + '_y.nii.gz')
            copy_header(fname_source_tmp, name_warp_final + '_y_inverse.nii.gz')
            print('\nMerge translation fields along x and y into one global warping field ')
            sct.run('isct_c3d ' + name_warp_final + '_x.nii.gz ' +
                    name_warp_final + '_y.nii.gz -omc 2 ' + name_warp_final +
                    '.nii.gz')
            sct.run('isct_c3d ' + name_warp_final + '_x_inverse.nii.gz ' +
                    name_warp_final + '_y_inverse.nii.gz -omc 2 ' +
                    name_warp_final + '_inverse.nii.gz')
            print('\nCopy to parent folder...')
            sct.run('cp ' + name_warp_final + '.nii.gz ../')
            sct.run('cp ' + name_warp_final + '_inverse.nii.gz ../')
    finally:
        os.chdir(path_cwd)

    # Delete tmp folder
    if remove_tmp_folder:
        print('\nRemove temporary files...')
        sct.run('rm -rf ' + path_tmp)
    if paramreg.algo == 'Rigid':
        return x_displacement, y_displacement, theta_rotation
    if paramreg.algo == 'Translation':
        return x_displacement, y_displacement
def main():
    """Register an anatomical image (and its spinal cord segmentation) to the template.

    Command-line entry point. The processing is a strictly ordered pipeline: each step
    reads files written by the previous ones inside a temporary folder, so statements
    must not be reordered.

    1. Parse the arguments (-i, -s, -l, -t, -p, -r, -v), or use hard-coded paths when
       ``param.debug`` is set.
    2. Check that the template files exist and that the input labels are usable: the
       label image must not be empty, all label values must differ, and no label value
       may exceed the largest label value of the template.
    3. Copy inputs and template files into ``tmp.<timestamp>``, resample data,
       segmentation and labels to 1 mm isotropic, and reorient them to RPI.
    4. Straighten the cropped segmentation, then estimate a landmark-based
       translation-scaling-z transform (straight labels -> template labels) and write it
       as an ITK text transform ``straight2templateAffine.txt``.
    5. Crop and sub-sample the template and the affine-registered data along z, then
       run every registration step of ``paramreg`` (image- or segmentation-based).
    6. Concatenate all transforms into ``warp_anat2template.nii.gz`` and
       ``warp_template2anat.nii.gz``, optionally apply them (``output_type == 1``),
       copy the results next to the temporary folder and optionally delete it.

    Relies on the module-level ``param`` object and on the external ``sct_*`` /
    ``isct_*`` command-line tools, invoked through ``sct.run``.
    """
    # get default parameters
    step1 = Paramreg(step='1', type='seg', algo='slicereg', metric='MeanSquares', iter='10')
    step2 = Paramreg(step='2', type='im', algo='syn', metric='MI', iter='3')
    paramreg = ParamregMultiStep([step1, step2])

    # Initialize the parser
    parser = Parser(__file__)
    parser.usage.set_description('Register anatomical image to the template.')
    parser.add_option(name="-i",
                      type_value="file",
                      description="Anatomical image.",
                      mandatory=True,
                      example="anat.nii.gz")
    parser.add_option(name="-s",
                      type_value="file",
                      description="Spinal cord segmentation.",
                      mandatory=True,
                      example="anat_seg.nii.gz")
    parser.add_option(name="-l",
                      type_value="file",
                      description="Labels. See: http://sourceforge.net/p/spinalcordtoolbox/wiki/create_labels/",
                      mandatory=True,
                      default_value='',
                      example="anat_labels.nii.gz")
    parser.add_option(name="-t",
                      type_value="folder",
                      description="Path to MNI-Poly-AMU template.",
                      mandatory=False,
                      default_value=param.path_template)
    parser.add_option(name="-p",
                      type_value=[[':'], 'str'],
                      description="""Parameters for registration (see sct_register_multimodal). Default:\n--\nstep=1\ntype="""+paramreg.steps['1'].type+"""\nalgo="""+paramreg.steps['1'].algo+"""\nmetric="""+paramreg.steps['1'].metric+"""\npoly="""+paramreg.steps['1'].poly+"""\n--\nstep=2\ntype="""+paramreg.steps['2'].type+"""\nalgo="""+paramreg.steps['2'].algo+"""\nmetric="""+paramreg.steps['2'].metric+"""\niter="""+paramreg.steps['2'].iter+"""\nshrink="""+paramreg.steps['2'].shrink+"""\nsmooth="""+paramreg.steps['2'].smooth+"""\ngradStep="""+paramreg.steps['2'].gradStep+"""\n--""",
                      mandatory=False,
                      example="step=2,type=seg,algo=bsplinesyn,metric=MeanSquares,iter=5,shrink=2:step=3,type=im,algo=syn,metric=MI,iter=5,shrink=1,gradStep=0.3")
    parser.add_option(name="-r",
                      type_value="multiple_choice",
                      description="""Remove temporary files.""",
                      mandatory=False,
                      default_value='1',
                      example=['0', '1'])
    parser.add_option(name="-v",
                      type_value="multiple_choice",
                      description="""Verbose. 0: nothing. 1: basic. 2: extended.""",
                      mandatory=False,
                      default_value=param.verbose,
                      example=['0', '1', '2'])
    if param.debug:
        # hard-coded developer paths; only meaningful on the original author's machine
        print('\n*** WARNING: DEBUG MODE ON ***\n')
        fname_data = '/Users/julien/data/temp/sct_example_data/t2/t2.nii.gz'
        fname_landmarks = '/Users/julien/data/temp/sct_example_data/t2/labels.nii.gz'
        fname_seg = '/Users/julien/data/temp/sct_example_data/t2/t2_seg.nii.gz'
        path_template = param.path_template
        remove_temp_files = 0
        verbose = 2
    else:
        arguments = parser.parse(sys.argv[1:])

        # get arguments
        fname_data = arguments['-i']
        fname_seg = arguments['-s']
        fname_landmarks = arguments['-l']
        path_template = arguments['-t']
        remove_temp_files = int(arguments['-r'])
        verbose = int(arguments['-v'])
        if '-p' in arguments:
            paramreg_user = arguments['-p']
            # update registration parameters
            for paramStep in paramreg_user:
                paramreg.addStep(paramStep)

    # initialize other parameters
    file_template = param.file_template
    file_template_label = param.file_template_label
    file_template_seg = param.file_template_seg
    output_type = param.output_type
    zsubsample = param.zsubsample

    # start timer
    start_time = time.time()

    # get absolute path - TO DO: remove! NEVER USE ABSOLUTE PATH...
    path_template = os.path.abspath(path_template)

    # get fname of the template + template objects
    fname_template = sct.slash_at_the_end(path_template, 1)+file_template
    fname_template_label = sct.slash_at_the_end(path_template, 1)+file_template_label
    fname_template_seg = sct.slash_at_the_end(path_template, 1)+file_template_seg

    # check file existence
    sct.printv('\nCheck template files...', verbose)
    sct.check_file_exist(fname_template, verbose)
    sct.check_file_exist(fname_template_label, verbose)
    sct.check_file_exist(fname_template_seg, verbose)

    # print arguments
    sct.printv('\nCheck parameters:', verbose)
    sct.printv('.. Data:                 '+fname_data, verbose)
    sct.printv('.. Landmarks:            '+fname_landmarks, verbose)
    sct.printv('.. Segmentation:         '+fname_seg, verbose)
    sct.printv('.. Path template:        '+path_template, verbose)
    sct.printv('.. Output type:          '+str(output_type), verbose)
    sct.printv('.. Remove temp files:    '+str(remove_temp_files), verbose)

    sct.printv('\nParameters for registration:', verbose)
    # steps are stored in a dict keyed by their 1-based number as a string
    for pStep in range(1, len(paramreg.steps)+1):
        step = paramreg.steps[str(pStep)]
        sct.printv('Step #'+step.step, verbose)
        sct.printv('.. Type #'+step.type, verbose)
        sct.printv('.. Algorithm................ '+step.algo, verbose)
        sct.printv('.. Metric................... '+step.metric, verbose)
        sct.printv('.. Number of iterations..... '+step.iter, verbose)
        sct.printv('.. Shrink factor............ '+step.shrink, verbose)
        sct.printv('.. Smoothing factor......... '+step.smooth, verbose)
        sct.printv('.. Gradient step............ '+step.gradStep, verbose)
        sct.printv('.. Degree of polynomial..... '+step.poly, verbose)

    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    sct.printv('\nCheck input labels...', verbose)
    # check if label image contains coherent labels
    image_label = Image(fname_landmarks)
    # -> all labels must be different (any() stops at the first duplicate found)
    labels = image_label.getNonZeroCoordinates(sorting='value')
    hasDifferentLabels = not any(lab != otherlabel and lab.hasEqualValue(otherlabel)
                                 for lab in labels for otherlabel in labels)
    if not hasDifferentLabels:
        sct.printv('ERROR: Wrong landmarks input. All labels must be different.', verbose, 'error')
    # -> all labels must be available in template. Both lists are sorted by value, so
    #    their last element holds the largest label value; guard against empty images
    #    before indexing [-1].
    image_label_template = Image(fname_template_label)
    labels_template = image_label_template.getNonZeroCoordinates(sorting='value')
    if not labels:
        sct.printv('ERROR: Wrong landmarks input. No label found in '+fname_landmarks+'.', verbose, 'error')
    elif not labels_template:
        sct.printv('ERROR: No label found in template label file '+fname_template_label+'.', verbose, 'error')
    elif labels[-1].value > labels_template[-1].value:
        sct.printv('ERROR: Wrong landmarks input. Labels must have correspondence in template space. \nLabel max '
                   'provided: ' + str(labels[-1].value) + '\nLabel max from template: ' +
                   str(labels_template[-1].value), verbose, 'error')

    # create temporary folder
    sct.printv('\nCreate temporary folder...', verbose)
    path_tmp = 'tmp.'+time.strftime("%y%m%d%H%M%S")
    sct.run('mkdir '+path_tmp)

    # copy files to temporary folder
    sct.printv('\nCopy files...', verbose)
    sct.run('isct_c3d '+fname_data+' -o '+path_tmp+'/data.nii')
    sct.run('isct_c3d '+fname_landmarks+' -o '+path_tmp+'/landmarks.nii.gz')
    sct.run('isct_c3d '+fname_seg+' -o '+path_tmp+'/segmentation.nii.gz')
    sct.run('isct_c3d '+fname_template+' -o '+path_tmp+'/template.nii')
    sct.run('isct_c3d '+fname_template_label+' -o '+path_tmp+'/template_labels.nii.gz')
    sct.run('isct_c3d '+fname_template_seg+' -o '+path_tmp+'/template_seg.nii.gz')

    # go to tmp folder (all following file names are relative to it)
    os.chdir(path_tmp)

    # resample data to 1mm isotropic
    sct.printv('\nResample data to 1mm isotropic...', verbose)
    sct.run('isct_c3d data.nii -resample-mm 1.0x1.0x1.0mm -interpolation Linear -o datar.nii')
    sct.run('isct_c3d segmentation.nii.gz -resample-mm 1.0x1.0x1.0mm -interpolation NearestNeighbor -o segmentationr.nii.gz')
    # N.B. labels are single-point labels, so nearest-neighbour resampling could make
    # them disappear; resample_labels is used instead.
    resample_labels('landmarks.nii.gz', 'datar.nii', 'landmarksr.nii.gz')

    # Change orientation of input images to RPI
    sct.printv('\nChange orientation of input images to RPI...', verbose)
    set_orientation('datar.nii', 'RPI', 'data_rpi.nii')
    set_orientation('landmarksr.nii.gz', 'RPI', 'landmarks_rpi.nii.gz')
    set_orientation('segmentationr.nii.gz', 'RPI', 'segmentation_rpi.nii.gz')

    # get landmarks in native space
    # crop segmentation
    # output: segmentation_rpi_crop.nii.gz
    sct.run('sct_crop_image -i segmentation_rpi.nii.gz -o segmentation_rpi_crop.nii.gz -dim 2 -bzmax')

    # straighten segmentation
    sct.printv('\nStraighten the spinal cord using centerline/segmentation...', verbose)
    sct.run('sct_straighten_spinalcord -i segmentation_rpi_crop.nii.gz -c segmentation_rpi_crop.nii.gz -r 0 -v '+str(verbose), verbose)
    # re-define warping field using non-cropped space (to avoid issue #367)
    sct.run('sct_concat_transfo -w warp_straight2curve.nii.gz -d data_rpi.nii -o warp_straight2curve.nii.gz')

    # Label preparation:
    # --------------------------------------------------------------------------------
    # Remove unused label on template. Keep only label present in the input label image
    sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
    sct.run('sct_label_utils -t remove -i template_labels.nii.gz -o template_label.nii.gz -r landmarks_rpi.nii.gz')

    # Make sure landmarks are INT
    sct.printv('\nConvert landmarks to INT...', verbose)
    sct.run('isct_c3d template_label.nii.gz -type int -o template_label.nii.gz', verbose)

    # Create a cross for the template labels - 5 mm
    sct.printv('\nCreate a 5 mm cross for the template labels...', verbose)
    sct.run('sct_label_utils -t cross -i template_label.nii.gz -o template_label_cross.nii.gz -c 5')

    # Create a cross for the input labels and dilate for straightening preparation - 5 mm
    sct.printv('\nCreate a 5mm cross for the input labels and dilate for straightening preparation...', verbose)
    sct.run('sct_label_utils -t cross -i landmarks_rpi.nii.gz -o landmarks_rpi_cross3x3.nii.gz -c 5 -d')

    # Apply straightening to labels
    sct.printv('\nApply straightening to labels...', verbose)
    sct.run('sct_apply_transfo -i landmarks_rpi_cross3x3.nii.gz -o landmarks_rpi_cross3x3_straight.nii.gz -d segmentation_rpi_crop_straight.nii.gz -w warp_curve2straight.nii.gz -x nn')

    # Convert landmarks from FLOAT32 to INT
    sct.printv('\nConvert landmarks from FLOAT32 to INT...', verbose)
    sct.run('isct_c3d landmarks_rpi_cross3x3_straight.nii.gz -type int -o landmarks_rpi_cross3x3_straight.nii.gz')

    # Remove labels that do not correspond with each others.
    sct.printv('\nRemove labels that do not correspond with each others.', verbose)
    sct.run('sct_label_utils -t remove-symm -i landmarks_rpi_cross3x3_straight.nii.gz -o landmarks_rpi_cross3x3_straight.nii.gz,template_label_cross.nii.gz -r template_label_cross.nii.gz')

    # Estimate affine transfo: straight --> template (landmark-based)
    sct.printv('\nEstimate affine transfo: straight anat --> template (landmark-based)...', verbose)
    # converting landmarks straight and curved to physical coordinates
    image_straight = Image('landmarks_rpi_cross3x3_straight.nii.gz')
    landmark_straight = image_straight.getNonZeroCoordinates(sorting='value')
    image_template = Image('template_label_cross.nii.gz')
    landmark_template = image_template.getNonZeroCoordinates(sorting='value')
    # Reorganize landmarks: each straightened label is a dilated cross made of several
    # voxels; average all voxels sharing a label value into a single landmark.
    points_fixed, points_moving = [], []
    landmark_straight_mean = []
    # Track the label values already averaged by their ORIGINAL value, rather than
    # reading .value back from the averaged coordinates.
    seen_values = set()
    for coord in landmark_straight:
        if coord.value in seen_values:
            continue
        seen_values.add(coord.value)
        temp_landmark = coord
        temp_number = 1
        for other_coord in landmark_straight:
            if coord.hasEqualValue(other_coord) and coord != other_coord:
                temp_landmark += other_coord
                temp_number += 1
        landmark_straight_mean.append(temp_landmark / temp_number)

    for coord in landmark_straight_mean:
        point_straight = image_straight.transfo_pix2phys([[coord.x, coord.y, coord.z]])
        points_moving.append([point_straight[0][0], point_straight[0][1], point_straight[0][2]])
    for coord in landmark_template:
        point_template = image_template.transfo_pix2phys([[coord.x, coord.y, coord.z]])
        points_fixed.append([point_template[0][0], point_template[0][1], point_template[0][2]])

    # Register curved landmarks on straight landmarks based on python implementation
    sct.printv('\nComputing rigid transformation (algo=translation-scaling-z) ...', verbose)
    import msct_register_landmarks
    (rotation_matrix, translation_array, points_moving_reg, points_moving_barycenter) = \
        msct_register_landmarks.getRigidTransformFromLandmarks(
            points_fixed, points_moving, constraints='translation-scaling-z', show=False)

    # writing rigid transformation file (ITK text transform). The diagonal terms are
    # inverted and the z translation negated, exactly as the estimated transform
    # is consumed below.
    with open("straight2templateAffine.txt", "w") as text_file:
        text_file.write("#Insight Transform File V1.0\n")
        text_file.write("#Transform 0\n")
        text_file.write("Transform: FixedCenterOfRotationAffineTransform_double_3_3\n")
        text_file.write("Parameters: %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f\n" % (
            1.0/rotation_matrix[0, 0], rotation_matrix[0, 1],     rotation_matrix[0, 2],
            rotation_matrix[1, 0],     1.0/rotation_matrix[1, 1], rotation_matrix[1, 2],
            rotation_matrix[2, 0],     rotation_matrix[2, 1],     1.0/rotation_matrix[2, 2],
            translation_array[0, 0],   translation_array[0, 1],   -translation_array[0, 2]))
        text_file.write("FixedParameters: %.9f %.9f %.9f\n" % (points_moving_barycenter[0],
                                                               points_moving_barycenter[1],
                                                               points_moving_barycenter[2]))

    # Apply affine transformation: straight --> template
    sct.printv('\nApply affine transformation: straight --> template...', verbose)
    sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt -d template.nii -o warp_curve2straightAffine.nii.gz')
    sct.run('sct_apply_transfo -i data_rpi.nii -o data_rpi_straight2templateAffine.nii -d template.nii -w warp_curve2straightAffine.nii.gz')
    sct.run('sct_apply_transfo -i segmentation_rpi.nii.gz -o segmentation_rpi_straight2templateAffine.nii.gz -d template.nii -w warp_curve2straightAffine.nii.gz -x linear')

    # threshold to 0.5 (segmentation was linearly interpolated above)
    nii = Image('segmentation_rpi_straight2templateAffine.nii.gz')
    data = nii.data
    data[data < 0.5] = 0
    nii.data = data
    nii.setFileName('segmentation_rpi_straight2templateAffine_th.nii.gz')
    nii.save()
    # find min-max of anat2template (for subsequent cropping)
    zmin_template, zmax_template = find_zmin_zmax('segmentation_rpi_straight2templateAffine_th.nii.gz')

    # crop template in z-direction (for faster processing)
    sct.printv('\nCrop data in template space (for faster processing)...', verbose)
    sct.run('sct_crop_image -i template.nii -o template_crop.nii -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    sct.run('sct_crop_image -i template_seg.nii.gz -o template_seg_crop.nii.gz -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    sct.run('sct_crop_image -i data_rpi_straight2templateAffine.nii -o data_rpi_straight2templateAffine_crop.nii -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    sct.run('sct_crop_image -i segmentation_rpi_straight2templateAffine.nii.gz -o segmentation_rpi_straight2templateAffine_crop.nii.gz -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    # sub-sample in z-direction
    sct.printv('\nSub-sample in z-direction (for faster processing)...', verbose)
    sct.run('sct_resample -i template_crop.nii -o template_crop_r.nii -f 1x1x'+zsubsample, verbose)
    sct.run('sct_resample -i template_seg_crop.nii.gz -o template_seg_crop_r.nii.gz -f 1x1x'+zsubsample, verbose)
    sct.run('sct_resample -i data_rpi_straight2templateAffine_crop.nii -o data_rpi_straight2templateAffine_crop_r.nii -f 1x1x'+zsubsample, verbose)
    sct.run('sct_resample -i segmentation_rpi_straight2templateAffine_crop.nii.gz -o segmentation_rpi_straight2templateAffine_crop_r.nii.gz -f 1x1x'+zsubsample, verbose)

    # Registration straight spinal cord to template
    sct.printv('\nRegister straight spinal cord to template...', verbose)

    # loop across registration steps
    warp_forward = []
    warp_inverse = []
    for i_step in range(1, len(paramreg.steps)+1):
        sct.printv('\nEstimate transformation for step #'+str(i_step)+'...', verbose)
        # identify which is the src and dest
        if paramreg.steps[str(i_step)].type == 'im':
            src = 'data_rpi_straight2templateAffine_crop_r.nii'
            dest = 'template_crop_r.nii'
            interp_step = 'linear'
        elif paramreg.steps[str(i_step)].type == 'seg':
            src = 'segmentation_rpi_straight2templateAffine_crop_r.nii.gz'
            dest = 'template_seg_crop_r.nii.gz'
            interp_step = 'nn'
        else:
            sct.printv('ERROR: Wrong image type.', 1, 'error')
        # if step>1, apply the warps estimated so far to the src image to be used
        if i_step > 1:
            sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+sct.add_suffix(src, '_reg')+' -x '+interp_step, verbose)
            src = sct.add_suffix(src, '_reg')
        # register src --> dest
        warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
        warp_forward.append(warp_forward_out)
        warp_inverse.append(warp_inverse_out)

    # Concatenate transformations:
    sct.printv('\nConcatenate transformations: anat --> template...', verbose)
    sct.run('sct_concat_transfo -w warp_curve2straightAffine.nii.gz,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
    # inverse warps must be applied in the reverse order of the forward ones
    warp_inverse.reverse()
    sct.run('sct_concat_transfo -w '+','.join(warp_inverse)+',-straight2templateAffine.txt,warp_straight2curve.nii.gz -d data.nii -o warp_template2anat.nii.gz', verbose)

    # Apply warping fields to anat and template
    if output_type == 1:
        sct.run('sct_apply_transfo -i template.nii -o template2anat.nii.gz -d data.nii -w warp_template2anat.nii.gz -c 1', verbose)
        sct.run('sct_apply_transfo -i data.nii -o anat2template.nii.gz -d template.nii -w warp_anat2template.nii.gz -c 1', verbose)

    # come back to parent folder
    os.chdir('..')

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(path_tmp+'/warp_template2anat.nii.gz', 'warp_template2anat.nii.gz', verbose)
    sct.generate_output_file(path_tmp+'/warp_anat2template.nii.gz', 'warp_anat2template.nii.gz', verbose)
    if output_type == 1:
        sct.generate_output_file(path_tmp+'/template2anat.nii.gz', 'template2anat'+ext_data, verbose)
        sct.generate_output_file(path_tmp+'/anat2template.nii.gz', 'anat2template'+ext_data, verbose)

    # Delete temporary files
    if remove_temp_files:
        sct.printv('\nDelete temporary files...', verbose)
        sct.run('rm -rf '+path_tmp)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: '+str(int(round(elapsed_time)))+'s', verbose)

    # to view results
    sct.printv('\nTo view results, type:', verbose)
    sct.printv('fslview '+fname_data+' template2anat -b 0,4000 &', verbose, 'info')
    sct.printv('fslview '+fname_template+' -b 0,5000 anat2template &\n', verbose, 'info')
Exemple #13
0
#
# Scratch/debug script (module-level code): loads a template volume and passes a few
# voxel coordinates through Image.transfo_pix2phys and back through
# Image.transfo_phys2pix.
# NOTE(review): uses hard-coded, user-specific absolute paths and calls os.chdir() at
# import time, so it only runs on the original author's machine.
#
# generate_warping_field('data_T2_RPI.nii.gz', x_disp_2_smooth, y_disp_2_smooth, fname='warping_field_im_trans.nii.gz')
# sct.run('sct_apply_transfo -i data_RPI_registered_reg1.nii.gz -d data_T2_RPI.nii.gz -w warping_field_im_trans.nii.gz -o data_RPI_registered_reg2.nii.gz -x spline')

f_1 = "/Users/tamag/data/data_template/independant_templates/Results_magma/t2_avg_RPI.nii.gz"
f_2 = "/Users/tamag/data/data_template/independant_templates/Results_magma/t1_avg.independent_RPI_reg1_unpad.nii.gz"
# f_3 is only referenced by the commented-out nibabel code at the end of this snippet
f_3 = "/Users/tamag/data/data_template/independant_templates/Results_magma/t1_avg.independent_RPI.nii.gz"

# changes the process working directory (side effect at module level)
os.chdir("/Users/tamag/data/data_template/independant_templates/Results_magma")

im_1 = Image(f_1)
# im_2 and data_1 are loaded but not used by the live code below
im_2 = Image(f_2)

data_1 = im_1.data

# coord_test1 is unused; coord_test is the list of voxel coordinates round-tripped below
coord_test1 = [[1, 1, 1]]
coord_test = [[1, 1, 1], [2, 2, 2], [3, 3, 3]]

# voxel -> physical -> voxel; coordi_pix is presumably expected to match coord_test
# (not checked here)
coordi_phys = im_1.transfo_pix2phys(coordi=coord_test)
coordi_pix = im_1.transfo_phys2pix(coordi=coordi_phys)
# NOTE(review): `bla` is not defined in this snippet; unless it is defined elsewhere,
# evaluating it raises NameError, apparently used as a deliberate stop point — confirm.
bla

# im_3 = nibabel.load(f_3)
# data_3 = im_3.get_data()
# hdr_3 = im_3.get_header()
#
# data_f = data_3 - laplace(data_3)
#
# img_f = nibabel.Nifti1Image(data_f, None, hdr_3)
# nibabel.save(img_f, "rehauss.nii.gz")
def register_images(fname_source, fname_dest, mask='', paramreg=Paramreg(step='0', type='im', algo='Translation', metric='MI', iter='5', shrink='1', smooth='0', gradStep='0.5'),
                    ants_registration_params={'rigid': '', 'affine': '', 'compositeaffine': '', 'similarity': '', 'translation': '','bspline': ',10', 'gaussiandisplacementfield': ',3,0',
                                              'bsplinedisplacementfield': ',5,10', 'syn': ',3,0', 'bsplinesyn': ',1,3'}, remove_tmp_folder = 1):
    """Slice-by-slice registration of two images.

    We first split the 3D images into 2D images (and the mask if inputted). Then we register slices of the two images
    that physically correspond to one another looking at the physical origin of each image. The images can be of
    different sizes but the destination image must be smaller thant the input image. We do that using antsRegistration
    in 2D. Once this has been done for each slices, we gather the results and return them.
    Algorithms implemented: translation, rigid, affine, syn and BsplineSyn.
    N.B.: If the mask is inputted, it must also be 3D and it must be in the same space as the destination image.

    input:
        fname_source: name of moving image (type: string)
        fname_dest: name of fixed image (type: string)
        mask[optional]: name of mask file (type: string) (parameter -x of antsRegistration)
        paramreg[optional]: parameters of antsRegistration (type: Paramreg class from sct_register_multimodal)
        ants_registration_params[optional]: specific algorithm's parameters for antsRegistration (type: dictionary)

    output:
        if algo==translation:
            x_displacement: list of translation along x axis for each slice (type: list)
            y_displacement: list of translation along y axis for each slice (type: list)
        if algo==rigid:
            x_displacement: list of translation along x axis for each slice (type: list)
            y_displacement: list of translation along y axis for each slice (type: list)
            theta_rotation: list of rotation angle in radian (and in ITK's coordinate system) for each slice (type: list)
        if algo==affine or algo==syn or algo==bsplinesyn:
            creation of two 3D warping fields (forward and inverse) that are the concatenations of the slice-by-slice
            warps.
    """
    # Extracting names
    path_i, root_i, ext_i = sct.extract_fname(fname_source)
    path_d, root_d, ext_d = sct.extract_fname(fname_dest)

    # set metricSize
    if paramreg.metric == 'MI':
        metricSize = '32'  # corresponds to number of bins
    else:
        metricSize = '4'  # corresponds to radius (for CC, MeanSquares...)


    # Get image dimensions and retrieve nz
    print '\nGet image dimensions of destination image...'
    nx, ny, nz, nt, px, py, pz, pt = Image(fname_dest).dim
    print '.. matrix size: '+str(nx)+' x '+str(ny)+' x '+str(nz)
    print '.. voxel size:  '+str(px)+'mm x '+str(py)+'mm x '+str(pz)+'mm'

    # Define x and y displacement as list
    x_displacement = [0 for i in range(nz)]
    y_displacement = [0 for i in range(nz)]
    theta_rotation = [0 for i in range(nz)]

    # create temporary folder
    print('\nCreate temporary folder...')
    path_tmp = 'tmp.'+time.strftime("%y%m%d%H%M%S")
    sct.create_folder(path_tmp)
    print '\nCopy input data...'
    sct.run('cp '+fname_source+ ' ' + path_tmp +'/'+ root_i+ext_i)
    sct.run('cp '+fname_dest+ ' ' + path_tmp +'/'+ root_d+ext_d)
    if mask:
        sct.run('cp '+mask+ ' '+path_tmp +'/mask.nii.gz')

    # go to temporary folder
    os.chdir(path_tmp)

    # Split input volume along z
    print '\nSplit input volume...'
    from sct_split_data import split_data
    split_data(fname_source, 2, '_z')

    # Split destination volume along z
    print '\nSplit destination volume...'
    split_data(fname_dest, 2, '_z')

    # Split mask volume along z
    if mask:
        print '\nSplit mask volume...'
        split_data('mask.nii.gz', 2, '_z')

    im_dest_img = Image(fname_dest)
    im_input_img = Image(fname_source)
    coord_origin_dest = im_dest_img.transfo_pix2phys([[0,0,0]])
    coord_origin_input = im_input_img.transfo_pix2phys([[0,0,0]])
    coord_diff_origin = (asarray(coord_origin_dest[0]) - asarray(coord_origin_input[0])).tolist()
    [x_o, y_o, z_o] = [coord_diff_origin[0] * 1.0/px, coord_diff_origin[1] * 1.0/py, coord_diff_origin[2] * 1.0/pz]

    if paramreg.algo == 'BSplineSyN' or paramreg.algo == 'SyN' or paramreg.algo == 'Affine':
        list_warp_x = []
        list_warp_x_inv = []
        list_warp_y = []
        list_warp_y_inv = []
        name_warp_final = 'Warp_total' #if modified, name should also be modified in msct_register (algo slicereg2d_bsplinesyn and slicereg2d_syn)

    # loop across slices
    for i in range(nz):
        # set masking
        num = numerotation(i)
        num_2 = numerotation(int(num) + int(z_o))
        if mask:
            masking = '-x mask_z' +num+ '.nii'
        else:
            masking = ''

        cmd = ('isct_antsRegistration '
               '--dimensionality 2 '
               '--transform '+paramreg.algo+'['+str(paramreg.gradStep) +
               ants_registration_params[paramreg.algo.lower()]+'] '
               '--metric '+paramreg.metric+'['+root_d+'_z'+ num +'.nii' +','+root_i+'_z'+ num_2 +'.nii' +',1,'+metricSize+'] '  #[fixedImage,movingImage,metricWeight +nb_of_bins (MI) or radius (other)
               '--convergence '+str(paramreg.iter)+' '
               '--shrink-factors '+str(paramreg.shrink)+' '
               '--smoothing-sigmas '+str(paramreg.smooth)+'mm '
               #'--restrict-deformation 1x1x0 '    # how to restrict? should not restrict here, if transform is precised...?
               '--output [transform_' + num + ','+root_i+'_z'+ num_2 +'reg.nii] '    #--> file.mat (contains Tx,Ty, theta)
               '--interpolation BSpline[3] '
               +masking)

        try:
            sct.run(cmd)

            if paramreg.algo == 'Rigid' or paramreg.algo == 'Translation':
                f = 'transform_' +num+ '0GenericAffine.mat'
                matfile = loadmat(f, struct_as_record=True)
                array_transfo = matfile['AffineTransform_double_2_2']
                x_displacement[i] = array_transfo[4][0]  # Tx in ITK'S coordinate system
                y_displacement[i] = array_transfo[5][0]  # Ty  in ITK'S and fslview's coordinate systems
                theta_rotation[i] = asin(array_transfo[2]) # angle of rotation theta in ITK'S coordinate system (minus theta for fslview)

            if paramreg.algo == 'Affine':
                # New process added for generating total nifti warping field from mat warp
                name_dest = root_d+'_z'+ num +'.nii'
                name_reg = root_i+'_z'+ num +'reg.nii'
                name_output_warp = 'warp_from_mat_' + num_2 + '.nii.gz'
                name_output_warp_inverse = 'warp_from_mat_' + num + '_inverse.nii.gz'
                name_warp_null = 'warp_null_' + num + '.nii.gz'
                name_warp_null_dest = 'warp_null_dest' + num + '.nii.gz'
                name_warp_mat = 'transform_' + num + '0GenericAffine.mat'
                # Generating null nifti warping fields
                nx, ny, nz, nt, px, py, pz, pt = Image(name_reg).dim
                nx_d, ny_d, nz_d, nt_d, px_d, py_d, pz_d, pt_d = Image(name_dest).dim
                x_trans = [0 for i in range(nz)]
                x_trans_d = [0 for i in range(nz_d)]
                y_trans= [0 for i in range(nz)]
                y_trans_d = [0 for i in range(nz_d)]
                generate_warping_field(name_reg, x_trans=x_trans, y_trans=y_trans, fname=name_warp_null, verbose=0)
                generate_warping_field(name_dest, x_trans=x_trans_d, y_trans=y_trans_d, fname=name_warp_null_dest, verbose=0)
                # Concatenating mat wrp and null nifti warp to obtain equivalent nifti warp to mat warp
                sct.run('isct_ComposeMultiTransform 2 ' + name_output_warp + ' -R ' + name_reg + ' ' + name_warp_null + ' ' + name_warp_mat)
                sct.run('isct_ComposeMultiTransform 2 ' + name_output_warp_inverse + ' -R ' + name_dest + ' ' + name_warp_null_dest + ' -i ' + name_warp_mat)
                # Split the warping fields into two for displacement along x and y before merge
                sct.run('isct_c3d -mcs ' + name_output_warp + ' -oo transform_'+num+'0Warp_x.nii.gz transform_'+num+'0Warp_y.nii.gz')
                sct.run('isct_c3d -mcs ' + name_output_warp_inverse + ' -oo transform_'+num+'0InverseWarp_x.nii.gz transform_'+num+'0InverseWarp_y.nii.gz')
                # List names of warping fields for futur merge
                list_warp_x.append('transform_'+num+'0Warp_x.nii.gz')
                list_warp_x_inv.append('transform_'+num+'0InverseWarp_x.nii.gz')
                list_warp_y.append('transform_'+num+'0Warp_y.nii.gz')
                list_warp_y_inv.append('transform_'+num+'0InverseWarp_y.nii.gz')

            if paramreg.algo == 'BSplineSyN' or paramreg.algo == 'SyN':
                # Split the warping fields into two for displacement along x and y before merge
                # Need to separate the merge for x and y displacement as merge of 3d warping fields does not work properly
                sct.run('isct_c3d -mcs transform_'+num+'0Warp.nii.gz -oo transform_'+num+'0Warp_x.nii.gz transform_'+num+'0Warp_y.nii.gz')
                sct.run('isct_c3d -mcs transform_'+num+'0InverseWarp.nii.gz -oo transform_'+num+'0InverseWarp_x.nii.gz transform_'+num+'0InverseWarp_y.nii.gz')
                # List names of warping fields for futur merge
                list_warp_x.append('transform_'+num+'0Warp_x.nii.gz')
                list_warp_x_inv.append('transform_'+num+'0InverseWarp_x.nii.gz')
                list_warp_y.append('transform_'+num+'0Warp_y.nii.gz')
                list_warp_y_inv.append('transform_'+num+'0InverseWarp_y.nii.gz')
        # if an exception occurs with ants, take the last value for the transformation
        except:
                if paramreg.algo == 'Rigid' or paramreg.algo == 'Translation':
                    x_displacement[i] = x_displacement[i-1]
                    y_displacement[i] = y_displacement[i-1]
                    theta_rotation[i] = theta_rotation[i-1]


                if paramreg.algo == 'BSplineSyN' or paramreg.algo == 'SyN' or paramreg.algo == 'Affine':
                    print'Problem with ants for slice '+str(i)+'. Copy of the last warping field.'
                    sct.run('cp transform_' + numerotation(i-1) + '0Warp.nii.gz transform_' + num + '0Warp.nii.gz')
                    sct.run('cp transform_' + numerotation(i-1) + '0InverseWarp.nii.gz transform_' + num + '0InverseWarp.nii.gz')
                    # Split the warping fields into two for displacement along x and y before merge
                    sct.run('isct_c3d -mcs transform_'+num+'0Warp.nii.gz -oo transform_'+num+'0Warp_x.nii.gz transform_'+num+'0Warp_y.nii.gz')
                    sct.run('isct_c3d -mcs transform_'+num+'0InverseWarp.nii.gz -oo transform_'+num+'0InverseWarp_x.nii.gz transform_'+num+'0InverseWarp_y.nii.gz')
                    # List names of warping fields for futur merge
                    list_warp_x.append('transform_'+num+'0Warp_x.nii.gz')
                    list_warp_x_inv.append('transform_'+num+'0InverseWarp_x.nii.gz')
                    list_warp_y.append('transform_'+num+'0Warp_y.nii.gz')
                    list_warp_y_inv.append('transform_'+num+'0InverseWarp_y.nii.gz')

    if paramreg.algo == 'BSplineSyN' or paramreg.algo == 'SyN' or paramreg.algo == 'Affine':
        print'\nMerge along z of the warping fields...'
        # from sct_concat_data import concat_data
        sct.run('sct_concat_data -i '+','.join(list_warp_x)+' -o '+name_warp_final+'_x.nii.gz -dim z')
        sct.run('sct_concat_data -i '+','.join(list_warp_x_inv)+' -o '+name_warp_final+'_x_inverse.nii.gz -dim z')
        sct.run('sct_concat_data -i '+','.join(list_warp_y)+' -o '+name_warp_final+'_y.nii.gz -dim z')
        sct.run('sct_concat_data -i '+','.join(list_warp_y_inv)+' -o '+name_warp_final+'_y_inverse.nii.gz -dim z')
        # concat_data(','.join(list_warp_x), name_warp_final+'_x.nii.gz', 2)
        # concat_data(','.join(list_warp_x_inv), name_warp_final+'_x_inverse.nii.gz', 2)
        # concat_data(','.join(list_warp_y), name_warp_final+'_y.nii.gz', 2)
        # concat_data(','.join(list_warp_y_inv), name_warp_final+'_y_inverse.nii.gz', 2)
        # sct.run('fslmerge -z ' + name_warp_final + '_x ' + " ".join(list_warp_x))
        # sct.run('fslmerge -z ' + name_warp_final + '_x_inverse ' + " ".join(list_warp_x_inv))
        # sct.run('fslmerge -z ' + name_warp_final + '_y ' + " ".join(list_warp_y))
        # sct.run('fslmerge -z ' + name_warp_final + '_y_inverse ' + " ".join(list_warp_y_inv))
        print'\nChange resolution of warping fields to match the resolution of the destination image...'
        from sct_copy_header import copy_header
        copy_header(fname_dest, name_warp_final + '_x.nii.gz')
        copy_header(fname_source, name_warp_final + '_x_inverse.nii.gz')
        copy_header(fname_dest, name_warp_final + '_y.nii.gz')
        copy_header(fname_source, name_warp_final + '_y_inverse.nii.gz')
        print'\nMerge translation fields along x and y into one global warping field '
        sct.run('isct_c3d ' + name_warp_final + '_x.nii.gz ' + name_warp_final + '_y.nii.gz -omc 2 ' + name_warp_final + '.nii.gz')
        sct.run('isct_c3d ' + name_warp_final + '_x_inverse.nii.gz ' + name_warp_final + '_y_inverse.nii.gz -omc 2 ' + name_warp_final + '_inverse.nii.gz')
        print'\nCopy to parent folder...'
        sct.run('cp ' + name_warp_final + '.nii.gz ../')
        sct.run('cp ' + name_warp_final + '_inverse.nii.gz ../')

    #Delete tmp folder
    os.chdir('../')
    if remove_tmp_folder:
        print('\nRemove temporary files...')
        sct.run('rm -rf '+path_tmp)
    if paramreg.algo == 'Rigid':
        return x_displacement, y_displacement, theta_rotation
    if paramreg.algo == 'Translation':
        return x_displacement, y_displacement
def register_images(im_input,
                    im_dest,
                    mask='',
                    paramreg=Paramreg(step='0',
                                      type='im',
                                      algo='Translation',
                                      metric='MI',
                                      iter='5',
                                      shrink='1',
                                      smooth='0',
                                      gradStep='0.5'),
                    remove_tmp_folder=1):
    """Slice-by-slice 2D registration of ``im_input`` onto ``im_dest`` with isct_antsRegistration.

    Both volumes (and the optional mask) are copied into a temporary folder and split along z.
    Destination slice ``i`` (fixed image) is registered with input slice ``i + z_o`` (moving
    image), where ``z_o`` accounts for the z offset between the two image origins. When ANTs
    (or the parsing of its output) fails for a slice, the parameters of the previous slice are
    reused (0 for the first slice).

    Args:
        im_input: filename of the moving image.
        im_dest: filename of the destination (fixed) image; its number of slices drives the loop.
        mask: optional filename of a mask, split along z and given to ANTs with ``-x``.
        paramreg: Paramreg object; ``algo``, ``metric``, ``gradStep``, ``iter``, ``shrink``
            and ``smooth`` are used to build the ANTs command line.
        remove_tmp_folder: if truthy, the temporary folder is deleted at the end.

    Returns:
        'Rigid': (x_displacement, y_displacement, theta_rotation), theta in radians.
        'Translation': (x_displacement, y_displacement).
        'Affine': (x_displacement, y_displacement, matrix_def), matrix_def[i] being a 2x2
            nested list.
        Any other algo: None.
    """
    path_i, root_i, ext_i = sct.extract_fname(im_input)
    path_d, root_d, ext_d = sct.extract_fname(im_dest)

    # Part of the processing runs from inside the temporary folder: absolute paths keep the
    # original files reachable even when the caller passed relative filenames.
    fname_input_abs = os.path.abspath(im_input)
    fname_dest_abs = os.path.abspath(im_dest)

    # set metricSize
    if paramreg.metric == 'MI':
        metricSize = '32'  # corresponds to number of bins
    else:
        metricSize = '4'  # corresponds to radius (for CC, MeanSquares...)

    # extra parameters appended to the --transform option of antsRegistration, per algorithm
    ants_registration_params = {
        'rigid': '',
        'affine': '',
        'compositeaffine': '',
        'similarity': '',
        'translation': '',
        'bspline': ',10',
        'gaussiandisplacementfield': ',3,0',
        'bsplinedisplacementfield': ',5,10',
        'syn': ',3,0',
        'bsplinesyn': ',3,32'
    }

    # Get image dimensions and retrieve nz
    print('\nGet image dimensions of destination image...')
    nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(im_dest)
    print('.. matrix size: ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz))
    print('.. voxel size:  ' + str(px) + 'mm x ' + str(py) + 'mm x ' + str(pz) + 'mm')

    # per-slice results
    x_displacement = [0] * nz
    y_displacement = [0] * nz
    theta_rotation = [0] * nz
    matrix_def = [0] * nz

    # create temporary folder
    print('\nCreate temporary folder...')
    path_tmp = 'tmp.' + time.strftime("%y%m%d%H%M%S")
    sct.create_folder(path_tmp)
    print('\nCopy input data...')
    sct.run('cp ' + im_input + ' ' + path_tmp + '/' + root_i + ext_i)
    sct.run('cp ' + im_dest + ' ' + path_tmp + '/' + root_d + ext_d)
    if mask:
        sct.run('cp ' + mask + ' ' + path_tmp + '/mask.nii.gz')

    path_cwd = os.getcwd()
    os.chdir(path_tmp)
    try:
        # Split input volume along z
        print('\nSplit input volume...')
        sct.run(sct.fsloutput + 'fslsplit ' + fname_input_abs + ' ' + root_i + '_z -z')

        # Split destination volume along z
        print('\nSplit destination volume...')
        sct.run(sct.fsloutput + 'fslsplit ' + fname_dest_abs + ' ' + root_d + '_z -z')

        # Split mask volume along z
        if mask:
            print('\nSplit mask volume...')
            sct.run(sct.fsloutput + 'fslsplit mask.nii.gz mask_z -z')

        # z offset (in input voxels) between the origins of the destination and input images,
        # used to pair each destination slice with the corresponding input slice
        im_dest_img = Image(fname_dest_abs)
        im_input_img = Image(fname_input_abs)
        coord_origin_dest = im_dest_img.transfo_pix2phys([[0, 0, 0]])
        coord_origin_input = im_input_img.transfo_pix2phys([[0, 0, 0]])
        coord_diff_origin_z = coord_origin_dest[0][2] - coord_origin_input[0][2]
        [[x_o, y_o, z_o]] = im_input_img.transfo_phys2pix([[0, 0, coord_diff_origin_z]])

        # loop across slices
        for i in range(nz):
            num = numerotation(i)
            num_2 = numerotation(int(num) + z_o)
            if mask:
                masking = '-x mask_z' + num + '.nii'
            else:
                masking = ''

            cmd = ('isct_antsRegistration '
                   '--dimensionality 2 '
                   '--transform ' + paramreg.algo + '[' + paramreg.gradStep +
                   ants_registration_params[paramreg.algo.lower()] + '] '
                   '--metric ' + paramreg.metric + '[' + root_d + '_z' + num +
                   '.nii' + ',' + root_i + '_z' + num_2 + '.nii' + ',1,' +
                   metricSize + '] '  # [fixedImage,movingImage,metricWeight,nb_of_bins (MI) or radius (other)]
                   '--convergence ' + paramreg.iter + ' '
                   '--shrink-factors ' + paramreg.shrink + ' '
                   '--smoothing-sigmas ' + paramreg.smooth + 'mm '
                   '--output [transform_' + num + '] '  # prefix of the output transform files
                   '--interpolation BSpline[3] ' + masking)

            try:
                sct.run(cmd)

                if paramreg.algo == 'Rigid' or paramreg.algo == 'Translation':
                    matfile = loadmat('transform_' + num + '0GenericAffine.mat', struct_as_record=True)
                    array_transfo = matfile['AffineTransform_double_2_2']
                    # NOTE(review): sign/axis of the translation not verified ("is it? or is it y?")
                    x_displacement[i] = -array_transfo[4][0]
                    y_displacement[i] = array_transfo[5][0]
                    # element [2] of the 2x2 matrix; asin gives the angle in radians
                    theta_rotation[i] = asin(array_transfo[2][0])

                if paramreg.algo == 'Affine':
                    matfile = loadmat('transform_' + num + '0GenericAffine.mat', struct_as_record=True)
                    array_transfo = matfile['AffineTransform_double_2_2']
                    # NOTE(review): sign/axis of the translation not verified ("is it? or is it y?")
                    x_displacement[i] = -array_transfo[4][0]
                    y_displacement[i] = array_transfo[5][0]
                    # NOTE(review): ordering of the matrix elements not verified
                    # (original comment: "how to know which is which?")
                    matrix_def[i] = [[array_transfo[0][0], array_transfo[1][0]],
                                     [array_transfo[2][0], array_transfo[3][0]]]

            except (Exception, SystemExit):
                # NOTE(review): sct.run may report a failed command by exiting (SystemExit);
                # both cases fall back to the previous slice. KeyboardInterrupt still propagates.
                print('Problem with ants for slice ' + str(i) + '. Copy of the parameters of the previous slice.')
                if paramreg.algo == 'Rigid' or paramreg.algo == 'Translation':
                    x_displacement[i] = x_displacement[i - 1] if i > 0 else 0
                    y_displacement[i] = y_displacement[i - 1] if i > 0 else 0
                    theta_rotation[i] = theta_rotation[i - 1] if i > 0 else 0
                if paramreg.algo == 'Affine':
                    x_displacement[i] = x_displacement[i - 1] if i > 0 else 0
                    y_displacement[i] = y_displacement[i - 1] if i > 0 else 0
                    matrix_def[i] = matrix_def[i - 1] if i > 0 else 0
            # TODO: different treatment for other algo
    finally:
        # always leave the temporary folder, even if one of the steps above failed
        os.chdir(path_cwd)

    # Delete tmp folder
    if remove_tmp_folder:
        print('\nRemove temporary files...')
        sct.run('rm -rf ' + path_tmp)
    if paramreg.algo == 'Rigid':
        # NOTE(review): check whether the displacements are inverted (x_disp = -x_disp...)
        return x_displacement, y_displacement, theta_rotation
    if paramreg.algo == 'Translation':
        return x_displacement, y_displacement
    if paramreg.algo == 'Affine':
        return x_displacement, y_displacement, matrix_def
def interpolate_im_to_ref(im_input,
                          im_input_sc,
                          new_res=0.3,
                          sq_size_size_mm=22.5,
                          interpolation_mode=3):
    """Resample every slice of ``im_input`` on a square grid centred on the spinal cord.

    For each slice iz, a square reference image of side ``sq_size_size_mm`` with in-plane
    resolution ``new_res`` is centred on the centre of mass of the non-zero voxels of
    ``im_input_sc`` in that slice, and ``im_input`` is interpolated onto it.

    Args:
        im_input: Image to resample (a copy is modified, not the caller's object).
        im_input_sc: Image of the spinal cord segmentation; assumed to share the grid of
            im_input (it receives im_input's header) - TODO confirm.
        new_res: in-plane resolution of the output slices (presumably mm, like sq_size_size_mm).
        sq_size_size_mm: side of the output square, in mm.
        interpolation_mode: forwarded to Image.interpolate_from_image.

    Returns:
        list with one interpolated Image per slice of im_input; 3D results are reshaped
        to 2D by dropping the last axis.

    Side effect: writes 'im_ref.nii.gz' in the current working directory.
    """
    nx, ny, nz, nt, px, py, pz, pt = im_input.dim

    im_input_sc = im_input_sc.copy()
    im_input = im_input.copy()

    # keep only spacing and origin in qform to avoid rotation issues:
    # zero every off-diagonal entry except the translation column (j == 3)
    input_qform = im_input.hdr.get_qform()
    for i in range(4):
        for j in range(4):
            if i != j and j != 3:
                input_qform[i, j] = 0

    im_input.hdr.set_qform(input_qform)
    im_input.hdr.set_sform(input_qform)
    # the segmentation uses the same simplified header as the input
    im_input_sc.hdr = im_input.hdr

    sq_size = int(sq_size_size_mm / new_res)
    # create a reference image: square of ones
    # (builtin int: the np.int alias was removed in NumPy 1.24)
    im_ref = Image(np.ones((sq_size, sq_size, 1), dtype=int),
                   dim=(sq_size, sq_size, 1, 0, new_res, new_res, pz, 0),
                   orientation='RPI')

    # copy input qform matrix to reference image
    im_ref.hdr.set_qform(im_input.hdr.get_qform())
    im_ref.hdr.set_sform(im_input.hdr.get_sform())

    # set correct header to reference image
    im_ref.hdr.set_data_shape((sq_size, sq_size, 1))
    im_ref.hdr.set_zooms((new_res, new_res, pz))

    # save image to set orientation to RPI (not properly done at the creation of the image)
    fname_ref = 'im_ref.nii.gz'
    im_ref.setFileName(fname_ref)
    im_ref.save()
    im_ref = set_orientation(im_ref, 'RPI', fname_out=fname_ref)

    # set header origin to zero to get physical coordinates of the center of the square
    im_ref.hdr.as_analyze_map()['qoffset_x'] = 0
    im_ref.hdr.as_analyze_map()['qoffset_y'] = 0
    im_ref.hdr.as_analyze_map()['qoffset_z'] = 0
    im_ref.hdr.set_sform(im_ref.hdr.get_qform())
    im_ref.hdr.set_qform(im_ref.hdr.get_qform())
    [[x_square_center_phys, y_square_center_phys,
      z_square_center_phys]] = im_ref.transfo_pix2phys(
          coordi=[[int(sq_size / 2), int(sq_size / 2), 0]])

    list_interpolate_images = []
    # iterate on z dimension of input image
    for iz in range(nz):
        # copy reference image: one reference image per slice
        im_ref_slice_iz = im_ref.copy()

        # get center of mass of SC for slice iz
        # NOTE(review): a slice without segmentation gives np.mean([]) -> NaN centre
        x_seg, y_seg = (im_input_sc.data[:, :, iz] > 0).nonzero()
        x_center, y_center = np.mean(x_seg), np.mean(y_seg)
        [[x_center_phys, y_center_phys, z_center_phys]
         ] = im_input_sc.transfo_pix2phys(coordi=[[x_center, y_center, iz]])

        # center reference image on SC for slice iz
        im_ref_slice_iz.hdr.as_analyze_map(
        )['qoffset_x'] = x_center_phys - x_square_center_phys
        im_ref_slice_iz.hdr.as_analyze_map(
        )['qoffset_y'] = y_center_phys - y_square_center_phys
        im_ref_slice_iz.hdr.as_analyze_map()['qoffset_z'] = z_center_phys
        im_ref_slice_iz.hdr.set_sform(im_ref_slice_iz.hdr.get_qform())
        im_ref_slice_iz.hdr.set_qform(im_ref_slice_iz.hdr.get_qform())

        # interpolate input image to reference image
        im_input_interpolate_iz = im_input.interpolate_from_image(
            im_ref_slice_iz,
            interpolation_mode=interpolation_mode,
            border='nearest')
        # reshape data to 2D if needed
        if len(im_input_interpolate_iz.data.shape) == 3:
            im_input_interpolate_iz.data = im_input_interpolate_iz.data.reshape(
                im_input_interpolate_iz.data.shape[:-1])
        # add slice to list
        list_interpolate_images.append(im_input_interpolate_iz)

    return list_interpolate_images
def register_seg(seg_input, seg_dest, verbose=1):
    """Slice-by-slice registration by translation of two segmentations.
    For each slice, we estimate the translation vector by calculating the difference of position of the two centers of
    mass in voxel unit.
    The segmentations can be of different sizes but the output segmentation must be smaller than the input segmentation.

    input:
        seg_input: name of moving segmentation file (type: string)
        seg_dest: name of fixed segmentation file (type: string)
        verbose: verbosity passed to sct.printv

    output: tuple (x_displacement, y_displacement, None)
        x_displacement: list of translation along x axis for each slice (type: list)
        y_displacement: list of translation along y axis for each slice (type: list)
        Both lists have one entry per slice of seg_input; only the first nz(seg_dest) entries
        are computed, the others stay 0.
        The third element is always None.

    """

    seg_input_img = Image(seg_input)
    seg_dest_img = Image(seg_dest)
    seg_input_data = seg_input_img.data
    seg_dest_data = seg_dest_img.data
    nz_dest = seg_dest_data.shape[2]

    x_center_of_mass_input = [0] * nz_dest
    y_center_of_mass_input = [0] * nz_dest
    sct.printv('\nGet center of mass of the input segmentation for each slice '
               '(corresponding to a slice in the output segmentation)...', verbose)  # different if size of the two seg are different
    # TODO: select only the slices corresponding to the output segmentation

    # grab physical coordinates of destination origin
    coord_origin_dest = seg_dest_img.transfo_pix2phys([[0, 0, 0]])

    # grab the voxel coordinates of the destination origin from the source image;
    # int() so it can be used as an array index even if a float is returned
    [[x_o, y_o, z_o]] = seg_input_img.transfo_phys2pix(coord_origin_dest)
    z_o = int(z_o)

    # calculate center of mass for each slice of the input image
    for iz in range(nz_dest):
        # starts from z_o, which is the origin of the destination image in the source image
        x_center_of_mass_input[iz], y_center_of_mass_input[iz] = ndimage.measurements.center_of_mass(array(seg_input_data[:, :, z_o + iz]))

    # initialize data
    x_center_of_mass_output = [0] * nz_dest
    y_center_of_mass_output = [0] * nz_dest

    # calculate center of mass for each slice of the destination image
    sct.printv('\nGet center of mass of the destination segmentation for each slice ...', verbose)
    for iz in range(nz_dest):
        try:
            x_center_of_mass_output[iz], y_center_of_mass_output[iz] = ndimage.measurements.center_of_mass(array(seg_dest_data[:, :, iz]))
        except Exception as e:
            # best effort: the slice keeps a (0, 0) centre of mass
            sct.printv('WARNING: Exception error in msct_register_regularized during register_seg:', 1, 'warning')
            print('Error on line {}'.format(sys.exc_info()[-1].tb_lineno))
            print(e)

    # calculate displacement in voxel space
    x_displacement = [0] * seg_input_data.shape[2]
    y_displacement = [0] * seg_input_data.shape[2]
    sct.printv('\nGet displacement by voxel...', verbose)
    for iz in range(nz_dest):
        x_displacement[iz] = -(x_center_of_mass_output[iz] - x_center_of_mass_input[iz])    # WARNING: in ITK's coordinate system, this is actually Tx and not -Tx
        y_displacement[iz] = y_center_of_mass_output[iz] - y_center_of_mass_input[iz]      # This is Ty in ITK's and fslview' coordinate systems

    return x_displacement, y_displacement, None
class ProcessLabels(object):
    def __init__(self,
                 fname_label,
                 fname_output=None,
                 fname_ref=None,
                 cross_radius=5,
                 dilate=False,
                 coordinates=None,
                 verbose=1,
                 vertebral_levels=None,
                 value=None,
                 msg=""):
        """
        Collection of processes that deal with label creation/modification.
        :param fname_label: filename of the input label image (loaded into self.image_input).
        :param fname_output: output filename; a one-element list is unwrapped to its single item,
            longer lists are kept as-is.
        :param fname_ref: optional reference image filename (loaded into self.image_ref, else None).
        :param cross_radius: radius used by the 'cross' and 'plan' processes.
        :param dilate: forwarded to the cross computation.
        :param coordinates: list of coordinates used by the label-creation processes.
        :param verbose: verbosity, also passed to Image.
        :param vertebral_levels: used by the 'vert-body' process.
        :param value: used by the 'add' and 'create-viewer' processes.
        :param msg: string. message to display to the user.
        """
        self.image_input = Image(fname_label, verbose=verbose)
        self.image_ref = Image(fname_ref, verbose=verbose) if fname_ref is not None else None

        # a list holding exactly one filename is reduced to that filename
        if isinstance(fname_output, list) and len(fname_output) == 1:
            fname_output = fname_output[0]
        self.fname_output = fname_output

        self.cross_radius = cross_radius
        self.vertebral_levels = vertebral_levels
        self.dilate = dilate
        self.coordinates = coordinates
        self.verbose = verbose
        self.value = value
        self.msg = msg

    def process(self, type_process):
        """
        Run the label operation named ``type_process`` and save its result.

        Image-producing operations store their result in ``self.output_image``. Reporting
        operations ('MSE', 'display-voxel', 'diff', 'dist-inter') reset ``self.fname_output``
        to None so nothing is written. If ``self.fname_output`` is still set at the end,
        ``self.output_image`` is saved there: 'vert-continuous' as float32, 'plan_ref' with the
        default type, everything else as 'minimize_int'.

        For 'vert-body', 'vert-disc' and 'vert-continuous' the input image is reoriented to RPI
        first, and the output is put back in the original orientation before saving.
        NOTE(review): 'vert-disc' triggers the reorientation but has no operation here.

        :param type_process: name of the operation to run.
        """
        needs_rpi = type_process in ['vert-body', 'vert-disc', 'vert-continuous']
        if needs_rpi:
            orientation_initial = self.image_input.orientation
            self.image_input.change_orientation('RPI')

        # operations whose return value becomes the output image
        # (lambdas so attributes/methods are only looked up for the selected operation)
        image_producers = {
            'add': lambda: self.add(self.value),
            'cross': lambda: self.cross(),
            'plan': lambda: self.plan(self.cross_radius, 100, 5),
            'plan_ref': lambda: self.plan_ref(),
            'increment': lambda: self.increment_z_inverse(),
            'disks': lambda: self.labelize_from_disks(),
            'remove': lambda: self.remove_label(),
            'remove-symm': lambda: self.remove_label(symmetry=True),
            'create': lambda: self.create_label(),
            'create-add': lambda: self.create_label(add=True),
            'create-seg': lambda: self.create_label_along_segmentation(),
            'cubic-to-point': lambda: self.cubic_to_point(),
            'vert-body': lambda: self.label_vertebrae(self.vertebral_levels),
            'vert-continuous': lambda: self.continuous_vertebral_levels(),
            'create-viewer': lambda: self.launch_sagittal_viewer(self.value),
        }
        # operations that only report information: no output file is written afterwards
        reporters = {
            'MSE': lambda: self.MSE(),
            'display-voxel': lambda: self.display_voxel(),
            'diff': lambda: self.diff(),
            'dist-inter': lambda: self.distance_interlabels(5),  # argument is a distance in pixels
        }

        if type_process in image_producers:
            self.output_image = image_producers[type_process]()
        elif type_process in reporters:
            reporters[type_process]()
            self.fname_output = None
        elif type_process == 'centerline':
            # NOTE(review): saving below relies on self.output_image being set elsewhere
            self.extract_centerline()

        # save the output image as minimized integers
        if self.fname_output is None:
            return
        self.output_image.setFileName(self.fname_output)
        if needs_rpi:
            self.output_image.change_orientation(orientation_initial)
        if type_process == 'vert-continuous':
            self.output_image.save('float32')
        elif type_process == 'plan_ref':
            self.output_image.save()
        else:
            self.output_image.save('minimize_int')

    def add(self, value):
        """
        Add a constant to every non-zero voxel of the input image.

        :param value: number (or anything accepted by float()) added to each non-zero voxel.
        :return: new Image built from self.image_input with the shifted voxel values.
        """
        image_output = Image(self.image_input, self.verbose)
        nonzero_coords = self.image_input.getNonZeroCoordinates()

        for coord in nonzero_coords:
            voxel = (int(coord.x), int(coord.y), int(coord.z))
            image_output.data[voxel] = image_output.data[voxel] + float(value)
        return image_output

    def create_label(self, add=False):
        """
        Create an image with labels listed by the user.
        This method works only if the user inserted correct coordinates.

        self.coordinates is a list of coordinates (class in msct_types).
        a Coordinate contains x, y, z and value.
        If only one label is to be added, coordinates must be completed with '[]'
        examples:
        For one label:  object_define=ProcessLabels( fname_label, coordinates=[coordi]) where coordi is a 'Coordinate' object from msct_types
        For two labels: object_define=ProcessLabels( fname_label, coordinates=[coordi1, coordi2]) where coordi1 and coordi2 are 'Coordinate' objects from msct_types

        :param add: if False, the labels are written on an all-zero copy of the input image;
            if True, they are written on top of the existing input values.
        :return: Image containing the labels.
        """
        image_output = self.image_input.copy()
        if not add:
            image_output.data *= 0

        # loop across labels
        for i, coord in enumerate(self.coordinates):
            ndim = len(image_output.data.shape)
            if ndim == 3:
                image_output.data[int(coord.x),
                                  int(coord.y),
                                  int(coord.z)] = coord.value
            elif ndim == 2:
                # compare numerically: str() of a float zero is '0.0', which a string
                # comparison with '0' would wrongly reject
                assert float(coord.z) == 0, "ERROR: 2D coordinates should have a Z value of 0. Z coordinate is :" + str(
                    coord.z)
                image_output.data[int(coord.x), int(coord.y)] = coord.value
            else:
                sct.printv(
                    'ERROR: Data should be 2D or 3D. Current shape is: ' +
                    str(image_output.data.shape), 1, 'error')
            # display info
            sct.printv(
                'Label #' + str(i) + ': ' + str(coord.x) + ',' + str(coord.y) +
                ',' + str(coord.z) + ' --> ' + str(coord.value), 1)
        return image_output

    def create_label_along_segmentation(self):
        """
        Create an image with labels defined along the spinal cord segmentation (or centerline).

        Each entry of self.coordinates is a string 'z,value'. The label is written at the centre
        of mass (rounded to the nearest voxel) of the input image in slice z. If z=-1, then
        z = round(nz/2) is used (i.e. center of FOV in superior-inferior direction).
        For 2D images only z=0 is accepted and the whole image is used.
        Example:
        object_define=ProcessLabels(fname_segmentation, coordinates=[coord_1, coord_2, coord_i]), where coord_i='z,value'.
        Returns
        -------
        image_output: Image object (zeroed copy of the input) with labels.
        """
        # copy input Image object (will use the same header)
        image_output = self.image_input.copy()
        # set all voxels to 0
        image_output.data *= 0
        is_2d = len(image_output.data.shape) == 2
        # loop across labels
        for i, coord in enumerate(self.coordinates):
            # use a dedicated comprehension variable: in Python 2 it leaks into the function
            # scope and would overwrite the label index ``i`` displayed below
            z, value = [int(field) for field in coord.split(',')]
            # if z=-1, replace with nz/2
            if z == -1:
                z = int(round(image_output.dim[2] / 2.0))
            if is_2d:
                # a 2D image has no slice axis, so data[:, :, z] cannot be used
                assert z == 0, "ERROR: 2D coordinates should have a Z value of 0. Z coordinate is :" + str(
                    z)
                data_slice = self.image_input.data
            else:
                data_slice = self.image_input.data[:, :, z]
            # get center of mass of segmentation at given z
            x, y = ndimage.measurements.center_of_mass(np.array(data_slice))
            # round values to make indices
            x, y = int(round(x)), int(round(y))
            # display info
            sct.printv(
                'Label #' + str(i) + ': ' + str(x) + ',' + str(y) + ',' +
                str(z) + ' --> ' + str(value), 1)
            if len(image_output.data.shape) == 3:
                image_output.data[x, y, z] = value
            elif is_2d:
                image_output.data[x, y] = value
        return image_output

    def cross(self):
        """
        Create an image with a cross drawn around each non-zero voxel of the input image.

        The crosses are computed by get_crosses_coordinates from self.cross_radius, with
        self.image_ref and self.dilate forwarded to it.
        :return: Image (copy of the input, zeroed) whose voxels on the crosses carry the values
            returned by get_crosses_coordinates.
        """
        output_image = Image(self.image_input, self.verbose)
        # voxel size is read from the image file on disk
        nx, ny, nz, nt, px, py, pz, pt = Image(
            self.image_input.absolutepath).dim

        coordinates_input = self.image_input.getNonZeroCoordinates()
        # dividing by the voxel size px converts the radius into voxels (this assumes
        # cross_radius is given in mm); only the x voxel size is used
        radius_vox = self.cross_radius / px

        # clean output_image
        output_image.data *= 0

        cross_coordinates = self.get_crosses_coordinates(
            coordinates_input, radius_vox, self.image_ref, self.dilate)

        for coord in cross_coordinates:
            output_image.data[int(round(coord.x)),
                              int(round(coord.y)),
                              int(round(coord.z))] = coord.value

        return output_image

    @staticmethod
    def get_crosses_coordinates(coordinates_input,
                                gapxy=15,
                                image_ref=None,
                                dilate=False):
        from msct_types import Coordinate

        # if reference image is provided (segmentation), we draw the cross perpendicular to the centerline
        if image_ref is not None:
            # smooth centerline
            from sct_straighten_spinalcord import smooth_centerline
            x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
                self.image_ref, verbose=self.verbose)

        # compute crosses
        cross_coordinates = []
        for coord in coordinates_input:
            if image_ref is None:
                from sct_straighten_spinalcord import compute_cross
                cross_coordinates_temp = compute_cross(coord, gapxy)
            else:
                from sct_straighten_spinalcord import compute_cross_centerline
                from numpy import where
                index_z = where(z_centerline == coord.z)
                deriv = Coordinate([
                    x_centerline_deriv[index_z][0],
                    y_centerline_deriv[index_z][0],
                    z_centerline_deriv[index_z][0], 0.0
                ])
                cross_coordinates_temp = compute_cross_centerline(
                    coord, deriv, gapxy)

            for i, coord_cross in enumerate(cross_coordinates_temp):
                coord_cross.value = coord.value * 10 + i + 1

            # dilate cross to 3x3x3
            if dilate:
                additional_coordinates = []
                for coord_temp in cross_coordinates_temp:
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x, coord_temp.y, coord_temp.z + 1.0,
                            coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x, coord_temp.y, coord_temp.z - 1.0,
                            coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x, coord_temp.y + 1.0, coord_temp.z,
                            coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x, coord_temp.y + 1.0,
                            coord_temp.z + 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x, coord_temp.y + 1.0,
                            coord_temp.z - 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x, coord_temp.y - 1.0, coord_temp.z,
                            coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x, coord_temp.y - 1.0,
                            coord_temp.z + 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x, coord_temp.y - 1.0,
                            coord_temp.z - 1.0, coord_temp.value
                        ]))

                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x + 1.0, coord_temp.y, coord_temp.z,
                            coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x + 1.0, coord_temp.y,
                            coord_temp.z + 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x + 1.0, coord_temp.y,
                            coord_temp.z - 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x + 1.0, coord_temp.y + 1.0,
                            coord_temp.z, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x + 1.0, coord_temp.y + 1.0,
                            coord_temp.z + 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x + 1.0, coord_temp.y + 1.0,
                            coord_temp.z - 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x + 1.0, coord_temp.y - 1.0,
                            coord_temp.z, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x + 1.0, coord_temp.y - 1.0,
                            coord_temp.z + 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x + 1.0, coord_temp.y - 1.0,
                            coord_temp.z - 1.0, coord_temp.value
                        ]))

                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x - 1.0, coord_temp.y, coord_temp.z,
                            coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x - 1.0, coord_temp.y,
                            coord_temp.z + 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x - 1.0, coord_temp.y,
                            coord_temp.z - 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x - 1.0, coord_temp.y + 1.0,
                            coord_temp.z, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x - 1.0, coord_temp.y + 1.0,
                            coord_temp.z + 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x - 1.0, coord_temp.y + 1.0,
                            coord_temp.z - 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x - 1.0, coord_temp.y - 1.0,
                            coord_temp.z, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x - 1.0, coord_temp.y - 1.0,
                            coord_temp.z + 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x - 1.0, coord_temp.y - 1.0,
                            coord_temp.z - 1.0, coord_temp.value
                        ]))

                cross_coordinates_temp.extend(additional_coordinates)

            cross_coordinates.extend(cross_coordinates_temp)

        cross_coordinates = sorted(cross_coordinates,
                                   key=lambda obj: obj.value)
        return cross_coordinates

    def plan(self, width, offset=0, gap=1):
        """
        Create a plane of thickness="width" and changes its value with an offset and a gap between labels.

        For every non-zero voxel of the input image, all output voxels whose z index lies in
        [z - width, z + width) (i.e. up to 2*width axial slices) are set to ``offset + gap * value``.

        :param width: half-extent, in slices, of the slab written around each label along z.
        :param offset: constant added to every written value.
        :param gap: multiplier applied to each label value.
        :return: image_output, an Image with the geometry of the input image.
        """
        image_output = Image(self.image_input, self.verbose)
        image_output.data *= 0
        coordinates_input = self.image_input.getNonZeroCoordinates()

        for coord in coordinates_input:
            z = int(coord.z)
            # Clamp the lower bound: a negative slice start would wrap around to the end of the
            # array and write the wrong slab (or nothing at all) for labels close to z=0.
            z_min = max(z - width, 0)
            image_output.data[:, :, z_min:z + width] = offset + gap * coord.value

        return image_output

    def plan_ref(self):
        """
        Generate a plane in the reference space for each label present in the input image.

        For every label voxel of the input image, the whole axial slice (same z index) of an image in
        the reference geometry is set to the label value. Negative labels are written first and
        positive labels afterwards, so a positive label overwrites a negative one on the same slice.

        :return: image_output, a float32 Image in the geometry of ``self.image_ref``.
        """
        image_output = Image(self.image_ref, self.verbose)
        image_output.data *= 0

        # Two empty volumes with the input geometry: one receives the magnitude of the negative
        # labels (so that getNonZeroCoordinates can be applied to them), the other the positive ones.
        image_neg = Image(self.image_input, self.verbose).copy()
        image_pos = Image(self.image_input, self.verbose).copy()
        image_neg.data *= 0
        image_pos.data *= 0

        source = self.image_input.data
        xs, ys, zs = (source < 0).nonzero()
        image_neg.data[xs, ys, zs] = -source[xs, ys, zs]
        xs, ys, zs = (source > 0).nonzero()
        image_pos.data[xs, ys, zs] = source[xs, ys, zs]

        labels_neg = image_neg.getNonZeroCoordinates()
        labels_pos = image_pos.getNonZeroCoordinates()

        image_output.changeType('float32')
        # NOTE(review): the original author flagged that the int value of coord.value may be used here.
        for label in labels_neg:
            image_output.data[:, :, int(label.z)] = -label.value
        for label in labels_pos:
            image_output.data[:, :, int(label.z)] = label.value

        return image_output

    def cubic_to_point(self):
        """
        Collapse each group of equal-valued voxels into one voxel at the group's center of mass.

        The output image has the size of the input image and holds, for every label value, a single
        voxel (the rounded center of mass of all voxels carrying that value). It is meant to be used
        after applying a homothetic warping field to a label file, since the labels get dilated.
        Be careful: voxels are grouped by value only. If two groups with the same value are present
        but separated in space, a single center of mass is computed for both groups together.
        :return: image_output
        """
        output_image = self.image_input.copy()
        output_image.data *= 0

        # Group every non-zero voxel by its label value (coordinates are requested sorted by value).
        voxels_by_label = {}
        for voxel in self.image_input.getNonZeroCoordinates(sorting='value'):
            voxels_by_label.setdefault(voxel.value, []).append(voxel)

        for label_voxels in voxels_by_label.values():
            com = sum(label_voxels) / float(len(label_voxels))
            axes = (com.x, com.y, com.z)
            exact = ", ".join(str(v) for v in axes)
            rounded = ", ".join(str(round(v)) for v in axes)
            sct.printv("Value = " + str(com.value) + " : (" + exact + ") --> ( " + rounded + ")",
                       verbose=self.verbose)
            ix, iy, iz = [int(round(v)) for v in axes]
            output_image.data[ix, iy, iz] = com.value

        return output_image

    def increment_z_inverse(self):
        """
        Relabel the non-zero voxels with consecutive integers 1, 2, 3, ... in inverse z order.

        Voxels are taken in the order returned by
        ``getNonZeroCoordinates(sorting='z', reverse_coord=True)``. The first one gets 1, the next
        gets 2, and so on. This function assumes RPI orientation.

        :return: image_output
        """
        image_output = Image(self.image_input, self.verbose)
        image_output.data *= 0
        ordered_voxels = self.image_input.getNonZeroCoordinates(sorting='z', reverse_coord=True)

        for rank, voxel in enumerate(ordered_voxels, start=1):
            image_output.data[int(voxel.x), int(voxel.y), int(voxel.z)] = rank

        return image_output

    def labelize_from_disks(self):
        """
        Create an image with regions labelized depending on values from reference.

        Typically, the user provides a segmentation image as input and disk-position labels as
        reference. The result is a segmentation image labelized by vertebral level. Reference labels
        (taken sorted by value) are assumed to be non-zero and incremented from top to bottom in RPI
        orientation. An input voxel gets the value of disk j when disk j+1 lies strictly below it and
        disk j at or above it, in z.

        :return: image_output
        """
        image_output = Image(self.image_input, self.verbose)
        image_output.data *= 0
        voxels = self.image_input.getNonZeroCoordinates()
        disks = self.image_ref.getNonZeroCoordinates(sorting='value')

        # consecutive (upper, lower) disk pairs, in value order
        disk_pairs = list(zip(disks[:-1], disks[1:]))
        for voxel in voxels:
            for upper, lower in disk_pairs:
                if lower.z < voxel.z <= upper.z:
                    image_output.data[int(voxel.x), int(voxel.y), int(voxel.z)] = upper.value

        return image_output

    def label_vertebrae(self, levels_user=None):
        """
        Find the center of mass of vertebral levels specified by the user.

        The input label image is first collapsed with ``cubic_to_point`` (one voxel per label value).
        Then every label whose integer value is not requested is set to 0.

        :param levels_user: sequence of vertebral levels to keep. ``None``, an empty sequence, or a
            sequence whose first element is 0 means "keep all levels".
        :return: image_output: Image with labels.
        """
        image_cubic2point = self.cubic_to_point()
        list_coordinates = image_cubic2point.getNonZeroCoordinates(sorting='value')

        # The default None used to crash on levels_user[0]; treat it like 0 ("all levels").
        if levels_user is None or len(levels_user) == 0 or levels_user[0] == 0:
            levels_to_keep = set(int(coord.value) for coord in list_coordinates)
        else:
            levels_to_keep = set(levels_user)

        for coord in list_coordinates:
            if int(coord.value) not in levels_to_keep:
                image_cubic2point.data[int(coord.x), int(coord.y), int(coord.z)] = 0

        return image_cubic2point

    def MSE(self, threshold_mse=0):
        """
        Compute the root-mean-square distance error along z between two sets of labels (input and ref).

        Labels are matched by rounded value. A warning is printed when the two images have a
        different number of labels, and once for every label whose value is present in only one of
        them. Input labels without a match contribute nothing to the sum but still count in the mean.
        If the error is above ``threshold_mse`` (by default 0mm), a log file
        ``<input path>error_log_<input file name>.txt`` is written with the filenames considered here.

        NOTE(review): the z values come from getNonZeroCoordinates (voxel indices) although the
        printed message says "mm". Confirm whether a pixel-size scaling is intended.

        :param threshold_mse: error above which the log file is written.
        :return: the error (float).
        :raises ZeroDivisionError: if the input image contains no label.
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()
        coordinates_ref = self.image_ref.getNonZeroCoordinates()

        # Build the value sets once instead of rebuilding a list for every label (was O(n*m)).
        input_values = set(round(coord.value) for coord in coordinates_input)
        ref_values = set(round(coord_ref.value) for coord_ref in coordinates_ref)

        # check if all the labels in both the images match
        if len(coordinates_input) != len(coordinates_ref):
            sct.printv('ERROR: labels mismatch', 1, 'warning')
        for coord in coordinates_input:
            if round(coord.value) not in ref_values:
                sct.printv('ERROR: labels mismatch', 1, 'warning')
        for coord_ref in coordinates_ref:
            if round(coord_ref.value) not in input_values:
                sct.printv('ERROR: labels mismatch', 1, 'warning')

        # z of the FIRST reference label for each rounded value (same first-match rule as before)
        ref_z_by_value = {}
        for coord_ref in coordinates_ref:
            ref_z_by_value.setdefault(round(coord_ref.value), coord_ref.z)

        result = 0.0
        for coord in coordinates_input:
            key = round(coord.value)
            if key in ref_z_by_value:
                result += (ref_z_by_value[key] - coord.z) ** 2
        result = math.sqrt(result / len(coordinates_input))
        sct.printv('MSE error in Z direction = ' + str(result) + ' mm')

        if result > threshold_mse:
            fname_log = (self.image_input.path + 'error_log_' +
                         self.image_input.file_name + '.txt')
            # `with` guarantees the log file is closed even if a write fails
            with open(fname_log, 'w') as f:
                f.write('The labels error (MSE) between ' +
                        self.image_input.file_name + ' and ' +
                        self.image_ref.file_name + ' is: ' + str(result))

        return result

    @staticmethod
    def remove_label_coord(coord_input, coord_ref, symmetry=False):
        """
        coord_input and coord_ref should be sets of CoordinateValue in order to improve speed of intersection
        :param coord_input: set of CoordinateValue
        :param coord_ref: set of CoordinateValue
        :param symmetry: boolean,
        :return: intersection of CoordinateValue: list
        """
        from msct_types import CoordinateValue
        if isinstance(coord_input[0], CoordinateValue) and isinstance(
                coord_ref[0], CoordinateValue) and symmetry:
            coord_intersection = list(
                set(coord_input).intersection(set(coord_ref)))
            result_coord_input = [
                coord for coord in coord_input if coord in coord_intersection
            ]
            result_coord_ref = [
                coord for coord in coord_ref if coord in coord_intersection
            ]
        else:
            result_coord_ref = coord_ref
            result_coord_input = [
                coord for coord in coord_input
                if list(filter(lambda x: x.value == coord.value, coord_ref))
            ]
            if symmetry:
                result_coord_ref = [
                    coord for coord in coord_ref if list(
                        filter(lambda x: x.value == coord.value,
                               result_coord_input))
                ]

        return result_coord_input, result_coord_ref

    def remove_label(self, symmetry=False):
        """
        Compare two label images and remove any labels in input image that are not in reference image.

        The symmetry option also removes labels from the reference image that are not in the input
        image. In that case the filtered reference image is saved to ``self.fname_output[1]`` and
        ``self.fname_output`` is replaced by its first element.

        :param symmetry: also filter (and save) the reference image.
        :return: image_output, the filtered input image (label values rounded to int).
        """
        image_output = Image(self.image_input, verbose=self.verbose)
        image_output.data *= 0  # put all voxels to 0

        result_coord_input, result_coord_ref = self.remove_label_coord(
            self.image_input.getNonZeroCoordinates(coordValue=True),
            self.image_ref.getNonZeroCoordinates(coordValue=True), symmetry)

        for coord in result_coord_input:
            image_output.data[int(coord.x),
                              int(coord.y),
                              int(coord.z)] = int(round(coord.value))

        if symmetry:
            image_output_ref = Image(self.image_ref, verbose=self.verbose)
            # Bug fix: start from an empty volume. Otherwise every original reference label is kept
            # and nothing is actually removed from the reference image.
            image_output_ref.data *= 0
            for coord in result_coord_ref:
                image_output_ref.data[int(coord.x),
                                      int(coord.y),
                                      int(coord.z)] = int(round(coord.value))
            image_output_ref.setFileName(self.fname_output[1])
            image_output_ref.save('minimize_int')

            self.fname_output = self.fname_output[0]

        return image_output

    def extract_centerline(self):
        """
        Write a text file with the coordinates of the centerline.

        One line ``"x y z"`` (formatted as integers) is written to ``self.fname_output`` for every
        non-zero voxel, in the order given by ``getNonZeroCoordinates(sorting='z')``.
        The image is supposed to be RPI.
        """
        coordinates_input = self.image_input.getNonZeroCoordinates(sorting='z')

        # Text mode: the lines are str, which a binary-mode ("wb") file rejects on Python 3.
        # `with` closes the file even if a write fails.
        with open(self.fname_output, "w") as fo:
            for coord in coordinates_input:
                fo.write("%i %i %i\n" % (coord.x, coord.y, coord.z))

    def display_voxel(self):
        """
        Print every label of the input image, then print all labels as one colon-separated string.

        Each label is printed as ``Position=(x,y,z) -- Value= v``. The colon-separated string (each
        coordinate rendered with ``str``) is stored in ``self.useful_notation``. The image is supposed
        to be RPI to display voxels, but this also works for other orientations.

        :return: the non-zero coordinates of the input image, sorted by value.
        """
        labels = self.image_input.getNonZeroCoordinates(sorting='value')
        rendered = []
        self.useful_notation = ''
        for label in labels:
            position = ','.join(str(v) for v in (label.x, label.y, label.z))
            sct.printv('Position=(' + position + ') -- Value= ' + str(label.value),
                       verbose=self.verbose)
            rendered.append(str(label))
        self.useful_notation = ':'.join(rendered)
        sct.printv('All labels (useful syntax):', verbose=self.verbose)
        sct.printv(self.useful_notation, verbose=self.verbose)
        return labels

    def get_physical_coordinates(self):
        """
        Return the labels expressed in the physical (scanner) referential.

        Labels are taken sorted by value. Each voxel position is converted with the input image's
        ``transfo_pix2phys``, and the label value is kept.
        :return: a list of CoordinateValue, in the physical (scanner) space
        """
        image = self.image_input
        physical_labels = []
        for label in image.getNonZeroCoordinates(sorting='value'):
            # transfo_pix2phys maps a list of points; convert this single point and take row 0
            point = image.transfo_pix2phys([[label.x, label.y, label.z]])[0]
            physical_labels.append(
                CoordinateValue([point[0], point[1], point[2], label.value]))
        return physical_labels

    def get_coordinates_in_destination(self, im_dest, type='discrete'):
        """
        Compute the position of the labels in the pixelar space of a destination image.

        Labels are first converted to physical coordinates (``get_physical_coordinates``), then
        mapped into the voxel space of ``im_dest``.

        :param im_dest: Object Image
        :param type: 'discrete' (uses ``im_dest.transfo_phys2pix``) or 'continuous'
            (uses ``im_dest.transfo_phys2continuouspix``)
        :return: a list of CoordinateValue, in the pixelar (image) space of the destination image
        :raises ValueError: if ``type`` is neither 'discrete' nor 'continuous'
        """
        # Compare strings by value: `is` depends on string interning and can fail for equal strings.
        if type == 'discrete':
            phys2pix = im_dest.transfo_phys2pix
        elif type == 'continuous':
            phys2pix = im_dest.transfo_phys2continuouspix
        else:
            raise ValueError(
                "The value of 'type' should either be 'discrete' or 'continuous'."
            )

        dest_coord = []
        for c in self.get_physical_coordinates():
            # Bug fix: the third component is z (c.y used to be passed twice).
            c_p = phys2pix([[c.x, c.y, c.z]])[0]
            dest_coord.append(
                CoordinateValue([c_p[0], c_p[1], c_p[2], c.value]))
        return dest_coord

    def diff(self):
        """
        Detect any label mismatch between input image and reference image.

        Prints, under two headers, the values of the input labels that have no equal value among the
        reference labels, then the values of the reference labels that have no equal value among the
        input labels. One line is printed per unmatched label.
        """
        labels_input = self.image_input.getNonZeroCoordinates()
        labels_ref = self.image_ref.getNonZeroCoordinates()

        sct.printv("Label in input image that are not in reference image:")
        for label in labels_input:
            if not any(label.value == other.value for other in labels_ref):
                sct.printv(label.value)

        sct.printv("Label in ref image that are not in input image:")
        for label in labels_ref:
            if not any(other.value == label.value for other in labels_input):
                sct.printv(label.value)

    def distance_interlabels(self, max_dist):
        """
        Calculate the distances between consecutive labels in the input image.

        Labels are taken in the order returned by ``getNonZeroCoordinates()``. The Euclidean distance,
        in voxel-index units, is computed between each label and the next one. If a distance is
        larger than max_dist, a warning message is displayed.

        :param max_dist: distance above which a warning is printed.
        """
        coordinates_input = self.image_input.getNonZeroCoordinates()

        for i in range(len(coordinates_input) - 1):
            current = coordinates_input[i]
            following = coordinates_input[i + 1]
            dist = math.sqrt((current.x - following.x) ** 2 +
                             (current.y - following.y) ** 2 +
                             (current.z - following.z) ** 2)
            # Bug fix: the test was inverted (`<`), contradicting the docstring and the message.
            if dist > max_dist:
                sct.printv('Warning: the distance between label ' + str(i) +
                           '[' + str(current.x) + ',' +
                           str(current.y) + ',' +
                           str(current.z) + ']=' +
                           str(current.value) + ' and label ' +
                           str(i + 1) + '[' + str(following.x) +
                           ',' + str(following.y) + ',' +
                           str(following.z) + ']=' +
                           str(following.value) +
                           ' is larger than ' + str(max_dist) + '. Distance=' +
                           str(dist))

    def continuous_vertebral_levels(self):
        """
        Transform a vertebral-level image (e.g. from the template) into a continuous file.

        Instead of an integer vertebral level on each slice, every non-zero voxel receives a
        continuous value representing the position of its slice in the vertebral level coordinate
        system: ``level + 2 * d / D``. Along the smoothed centerline, ``D`` is the length of the level
        and ``d`` is the length from the current centerline point to the last point of the same level
        (points taken in the order returned by smooth_centerline). Lengths use physical units (voxel
        steps scaled by the pixel dimensions px, py, pz). A level whose length is 0 (a single
        centerline point) keeps its integer value.
        The image must be RPI.
        :return: im_output, a float32 Image with the geometry of the input image.
        """
        im_input = Image(self.image_input, self.verbose)
        im_output = Image(self.image_input, self.verbose)
        im_output.data *= 0

        # 1. extract vertebral levels from input image
        #   a. extract centerline
        #   b. for each slice, extract corresponding level
        px, py, pz = im_input.dim[4:7]
        from sct_straighten_spinalcord import smooth_centerline
        x_centerline_fit, y_centerline_fit, z_centerline_fit = smooth_centerline(
            self.image_input, algo_fitting='nurbs', verbose=0)[:3]
        value_centerline = np.array([
            im_input.data[int(x_centerline_fit[it]),
                          int(y_centerline_fit[it]),
                          int(z_centerline_fit[it])]
            for it in range(len(z_centerline_fit))
        ])

        def path_length(indexes):
            # Sum of the physical lengths of the segments joining consecutive centerline points.
            return np.sum([
                math.sqrt(((x_centerline_fit[b] - x_centerline_fit[a]) * px) ** 2 +
                          ((y_centerline_fit[b] - y_centerline_fit[a]) * py) ** 2 +
                          ((z_centerline_fit[b] - z_centerline_fit[a]) * pz) ** 2)
                for a, b in zip(indexes[:-1], indexes[1:])
            ])

        # 2. compute the length D of each vertebral level along the centerline
        length_levels = {}
        for level in np.unique(value_centerline):
            length_levels[level] = path_length(np.where(value_centerline == level)[0])

        # 3. for each centerline point:
        #   a. identify corresponding vertebral level --> i
        #   b. calculate distance d to the end of the level, with the SAME scaling as D
        #      (bug fix: it used px*px, py*py, pz*pz, inconsistent for anisotropic voxels)
        #   c. compute relative distance in the vertebral level coordinate system --> d/D
        continuous_values = {}
        for it, iz in enumerate(z_centerline_fit):
            level = value_centerline[it]
            indexes_level = np.where(value_centerline == level)[0]
            distance_from_level = path_length(indexes_level[indexes_level >= it])
            length_level = length_levels[level]
            if length_level > 0:
                continuous_values[iz] = level + 2.0 * distance_from_level / float(length_level)
            else:
                # 0/0 used to produce NaN for a level covering a single centerline point.
                continuous_values[iz] = level

        # 4. saving data
        # for each slice, get all non-zero pixels and replace with continuous values
        coordinates_input = self.image_input.getNonZeroCoordinates()
        im_output.changeType('float32')
        for coord in coordinates_input:
            # NOTE(review): raises KeyError if a labelled slice has no centerline point — confirm.
            im_output.data[int(coord.x),
                           int(coord.y),
                           int(coord.z)] = continuous_values[coord.z]

        return im_output

    def launch_sagittal_viewer(self, labels):
        """
        Open the sagittal labelling dialog on the input image.

        :param labels: labels offered in the dialog (assigned to ``params.vertebraes``).
        :return: the zeroed copy of the input image, named ``self.fname_output``, that is passed to
            the dialog as its output image.
        """
        from spinalcordtoolbox.gui import base
        from spinalcordtoolbox.gui.sagittal import launch_sagittal_dialog

        # blank output volume with the input geometry
        output = self.image_input.copy()
        output.data *= 0
        output.setFileName(self.fname_output)

        params = base.AnatomicalParams()
        params.vertebraes = labels
        params.input_file_name = self.image_input.file_name
        params.output_file_name = self.fname_output
        params.subtitle = self.msg

        launch_sagittal_dialog(self.image_input, output, params)
        return output
def validate_scad(folder_input, contrast):
    """
    Run SCAD on every patient of a dataset and compare its centerline with the manual segmentation.

    Expecting folder to have the following structure::

        errsm_01:
        - t2
        -- errsm_01_t2.nii.gz or t2.nii.gz
        -- errsm_01_t2_manual_segmentation.nii.gz

    For each patient, the in-plane (x, y) physical distance between the SCAD centerline and the
    center of mass of the manual segmentation is computed on every slice except the first and last.
    The distances are written to ``<patient>_<contrast>_results.txt`` in the patient's contrast
    folder. Failures for one patient are printed and do not stop the loop. The working directory is
    restored on exit.

    :param folder_input: dataset folder containing one sub-folder per patient.
    :param contrast: name of the contrast sub-folder (e.g. "t2"); also used in the file names.
    :return: None
    """
    from sct_get_centerline import ind2sub

    current_folder = os.getcwd()
    # Resolve once: the loop changes directory, so a relative path would no longer be valid
    # when changing back to the dataset folder.
    folder_input = os.path.abspath(folder_input)
    os.chdir(folder_input)

    overall_distance = {}
    max_distance = {}
    overall_std = {}
    rmse = {}
    try:
        patients = next(os.walk('.'))[1]
        for i in patients:
            try:
                os.chdir(os.path.join(i, str(contrast)))
            except OSError:
                print(str(i) + " : " + contrast + " directory not found")
                continue  # previously the patient was processed from the wrong folder
            try:
                stats = _validate_scad_patient(i, contrast, ind2sub)
                if stats is not None:
                    overall_distance[i], max_distance[i], overall_std[i], rmse[i] = stats
            except Exception as e:
                # Best effort: report the failure for this patient and continue with the next one.
                print(str(e))
            finally:
                os.chdir(folder_input)
    finally:
        os.chdir(current_folder)


def _validate_scad_patient(patient, contrast, ind2sub):
    """Process one patient (cwd must be its contrast folder); return (mean, max, std, rmse) or None."""
    import math
    import numpy as np
    from scipy.ndimage import center_of_mass

    prefix = patient + "_" + contrast
    if os.path.isfile(prefix + ".nii.gz"):
        raw_image = Image(prefix + ".nii.gz")
    elif os.path.isfile(contrast + ".nii.gz"):
        raw_image = Image(contrast + ".nii.gz")
    else:
        raise Exception("Patient scan not found")

    fname_manual = prefix + "_manual_segmentation.nii.gz"
    if not os.path.isfile(fname_manual):
        printv("Cannot find the manual segmentation", type="warning")
        return None

    raw_image.change_orientation()
    scad = SCAD(raw_image, contrast=contrast, rm_tmp_file=1, verbose=1)
    scad.execute()

    manual_seg = Image(fname_manual)
    manual_seg.change_orientation()

    # center of mass of the manual segmentation on each axial slice
    com = [center_of_mass(manual_seg.data[:, :, iz])
           for iz in range(manual_seg.data.shape[2])]

    centerline_scad = Image(prefix + "_centerline.nii.gz")
    centerline_scad.change_orientation()

    distance = {}
    for iz in range(1, centerline_scad.data.shape[2] - 1):
        ind1 = np.argmax(centerline_scad.data[:, :, iz])
        X, Y = ind2sub(centerline_scad.data[:, :, iz].shape, ind1)
        com_phys = np.array(manual_seg.transfo_pix2phys([[com[iz][0], com[iz][1], iz]]))
        scad_phys = np.array(centerline_scad.transfo_pix2phys([[X, Y, iz]]))
        # In-plane distance only: the z component is deliberately set to 0.
        distance_magnitude = np.linalg.norm([com_phys[0][0] - scad_phys[0][0],
                                             com_phys[0][1] - scad_phys[0][1], 0])
        if math.isnan(distance_magnitude):
            print("Value is nan")
        else:
            distance[iz] = distance_magnitude

    if not distance:
        raise Exception("No valid slice to compare for " + patient)

    # A real list: np.array(dict.values()) builds an object array on Python 3.
    values = list(distance.values())
    standard_deviation = np.std(values)
    average = sum(values) / len(values)
    root_mean_square = np.sqrt(np.mean(np.square(values)))

    with open(prefix + "_results.txt", 'w+') as f:
        # one record per line (newlines were missing, producing a single unreadable line)
        f.write("Patient,Slice,Distance\n")
        for key, value in distance.items():
            f.write(patient + "," + str(key) + "," + str(value) + "\n")
        f.write("\nAverage : " + str(average))
        f.write("\nStandard Deviation : " + str(standard_deviation))

    return average, max(values), standard_deviation, root_mean_square
def project_labels_on_spinalcord(fname_label, fname_seg):
    """
    Project labels orthogonally on the spinal cord centerline. The algorithm works by finding the smallest distance
    between each label and the spinal cord center of mass.
    :param fname_label: file name of labels
    :param fname_seg: file name of cord segmentation (could also be of centerline)
    :return: file name of projected labels (fname_label with suffix "_projected"). The projected labels are written
             in the space of fname_seg, in its native orientation.
    """
    # build output name
    fname_label_projected = sct.add_suffix(fname_label, "_projected")
    # open labels and segmentation
    im_label = Image(fname_label)
    im_seg = Image(fname_seg)
    # orient to RPI
    native_orient = im_seg.change_orientation('RPI')
    im_label.change_orientation('RPI')
    # smooth centerline and return fitted coordinates in voxel space (derivatives are not needed here)
    centerline_x, centerline_y, centerline_z, _, _, _ = smooth_centerline(
        im_seg,
        algo_fitting="hanning",
        type_window="hanning",
        window_length=50,
        nurbs_pts_number=3000,
        phys_coordinates=False,
        all_slices=True)
    # convert all centerline points into physical coordinates in a single call -> array of shape (n_points, 3).
    # A plain list is passed on purpose (transfo_pix2phys receives the same type of argument as before).
    centerline_phys = np.array(im_seg.transfo_pix2phys(
        [[x, y, z] for x, y, z in zip(centerline_x, centerline_y, centerline_z)]))
    # get center of mass of label
    labels = im_label.getCoordinatesAveragedByValue()
    # initialize image of projected labels. Note that we use the space of the seg (not label).
    # NOTE(review): uint8 storage means label values above 255 cannot be represented.
    im_label_projected = im_seg.copy()
    im_label_projected.data = np.zeros(im_label_projected.data.shape,
                                       dtype='uint8')
    # loop across label values
    for label in labels:
        # convert pixel into physical coordinates for the label
        label_phys = np.array(im_label.transfo_pix2phys(
            [[label.x, label.y, label.z]])[0])
        # Euclidean distance between the label and every point of the centerline
        distance_centerline = np.linalg.norm(centerline_phys - label_phys, axis=1)
        # centerline point closest to the label (first one in case of ties, as np.argmin)
        min_phy_x, min_phy_y, min_phy_z = centerline_phys[np.argmin(distance_centerline)]
        # convert coordinate to voxel space
        minx, miny, minz = im_seg.transfo_phys2pix(
            [[min_phy_x, min_phy_y, min_phy_z]])[0]
        # assign projected label on the centerline
        im_label_projected.data[minx, miny, minz] = label.value
    # re-orient projected labels to native orientation and save
    im_label_projected.change_orientation(
        native_orient)  # note: native_orient refers to im_seg (not im_label)
    im_label_projected.setFileName(fname_label_projected)
    im_label_projected.save()
    return fname_label_projected
Exemple #21
0
def test(path_data='', parameters=''):
    """
    Run sct_get_centerline on the t2 test data and evaluate the resulting centerline.

    For each axial slice (except the first and the last one), the in-plane (x-y) physical
    distance between the detected centerline and the center of mass of the manual
    segmentation is computed; the maximum and the root mean square of these distances
    are reported.

    :param path_data: folder containing the test data; prepended to the input file names.
    :param parameters: command-line parameters passed to sct_get_centerline.
        Default: '-i t2/t2.nii.gz -c t2 -p auto'.
    :return: tuple (status, output, results): the command status, its output (or an error
        message), and a pandas DataFrame indexed by path_data with columns 'status',
        'output', 'mse' (root mean square distance) and 'dist_max' (maximum distance).
    """
    if not parameters:
        parameters = '-i t2/t2.nii.gz -c t2 -p auto'

    # folder (relative to path_data) that contains the manual segmentation
    folder_data = 't2/'

    parser = sct_get_centerline.get_parser()
    dict_param = parser.parse(parameters.split(), check_file_exist=False)
    contrast = dict_param['-c']
    dict_param_with_path = parser.add_path_to_file(dict_param,
                                                   path_data,
                                                   input_file=True)
    param_with_path = parser.dictionary_to_string(dict_param_with_path)

    # Check if input files exist
    if not os.path.isfile(dict_param_with_path['-i']):
        status = 200
        output = 'ERROR: the file(s) provided to test function do not exist in folder: ' + path_data
        return status, output, DataFrame(data={
            'status': status,
            'output': output,
            'mse': float('nan'),
            'dist_max': float('nan')
        },
                                         index=[path_data])

    cmd = 'sct_get_centerline ' + param_with_path
    status, output = sct.run(cmd, 0)

    # metrics reported when the command fails or the evaluation raises
    max_distance = 0
    rmse = 0

    try:
        if status == 0:
            from scipy.ndimage import center_of_mass
            # images are read only once the command succeeded, inside the try,
            # so that a missing centerline file does not escape as an uncaught error
            scad_centerline = Image(contrast + "_centerline.nii.gz")
            manual_seg = Image(path_data + folder_data + contrast +
                               '_seg_manual.nii.gz')
            manual_seg.change_orientation()
            scad_centerline.change_orientation()

            # center of mass (x, y) of the manual segmentation in each axial slice
            com = [center_of_mass(manual_seg.data[:, :, iz])
                   for iz in range(manual_seg.data.shape[2])]

            # in-plane distance between the detected centerline and the center of mass, per slice
            distance = {}
            for iz in range(1, scad_centerline.data.shape[2] - 1):
                slice_scad = scad_centerline.data[:, :, iz]
                # voxel of maximum intensity of the detected centerline in this slice
                X, Y = ind2sub(slice_scad.shape, np.argmax(slice_scad))
                com_phys = manual_seg.transfo_pix2phys([[com[iz][0], com[iz][1], iz]])[0]
                scad_phys = scad_centerline.transfo_pix2phys([[X, Y, iz]])[0]
                distance_magnitude = np.linalg.norm([
                    com_phys[0] - scad_phys[0],
                    com_phys[1] - scad_phys[1], 0
                ])
                # presumably nan when the manual segmentation is empty in this slice
                if math.isnan(distance_magnitude):
                    print("Value is nan")
                else:
                    distance[iz] = distance_magnitude

            # raises ValueError (caught below) when no slice produced a valid distance
            values = np.array(list(distance.values()))
            max_distance = values.max()
            rmse = np.sqrt(np.mean(np.square(values)))

    except Exception as e:
        sct.printv("Exception found while testing scad integrity")
        output = str(e)

    # same structure as the early-return DataFrame above, so callers can always unpack 3 values
    results = DataFrame(data={
        'status': status,
        'output': output,
        'mse': rmse,
        'dist_max': max_distance
    },
                        index=[path_data])
    return status, output, results
#
# generate_warping_field('data_T2_RPI.nii.gz', x_disp_2_smooth, y_disp_2_smooth, fname='warping_field_im_trans.nii.gz')
# sct.run('sct_apply_transfo -i data_RPI_registered_reg1.nii.gz -d data_T2_RPI.nii.gz -w warping_field_im_trans.nii.gz -o data_RPI_registered_reg2.nii.gz -x spline')


# Scratch/debug script: round-trips a few voxel coordinates through
# Image.transfo_pix2phys and Image.transfo_phys2pix on a template image.
# NOTE(review): this runs at import time, depends on hard-coded absolute paths
# on one developer's machine, and changes the current working directory.
f_1 = "/Users/tamag/data/data_template/independant_templates/Results_magma/t2_avg_RPI.nii.gz"
f_2 = "/Users/tamag/data/data_template/independant_templates/Results_magma/t1_avg.independent_RPI_reg1_unpad.nii.gz"
# f_3 is only referenced by the commented-out code below
f_3 = "/Users/tamag/data/data_template/independant_templates/Results_magma/t1_avg.independent_RPI.nii.gz"

os.chdir("/Users/tamag/data/data_template/independant_templates/Results_magma")

im_1 = Image(f_1)
# im_2 and data_1 are not used afterwards in this chunk
im_2 = Image(f_2)

data_1 = im_1.data

coord_test1 = [[1,1,1]]
coord_test = [[1,1,1],[2,2,2],[3,3,3]]

# voxel -> physical -> voxel; coordi_pix should match coord_test
coordi_phys = im_1.transfo_pix2phys(coordi=coord_test)
coordi_pix = im_1.transfo_phys2pix(coordi = coordi_phys)
# NOTE(review): 'bla' is not defined in this chunk, so evaluating it raises NameError;
# presumably a deliberate stop point for interactive debugging — TODO confirm and remove.
bla

# im_3 = nibabel.load(f_3)
# data_3 = im_3.get_data()
# hdr_3 = im_3.get_header()
#
# data_f = data_3 - laplace(data_3)
#
# img_f = nibabel.Nifti1Image(data_f, None, hdr_3)
# nibabel.save(img_f, "rehauss.nii.gz")
Exemple #23
0
def register2d_columnwise(fname_src, fname_dest, fname_warp='warp_forward.nii.gz', fname_warp_inv='warp_inverse.nii.gz', verbose=0, path_qc='./', smoothWarpXY=1):
    """
    Column-wise non-linear registration of segmentations. Based on an idea from Allan Martin.
    - Assumes src/dest are segmentations (not necessarily binary), and already registered by center of mass
    - Assumes src/dest are in RPI orientation.
    - Split along Z, then for each slice:
    - scale in R-L direction to match src/dest
    - loop across R-L columns and register by (i) matching center of mass and (ii) scaling.
    :param fname_src: file name of the source image
    :param fname_dest: file name of the destination image
    :param fname_warp: output file name of the forward warping field (defined in destination space)
    :param fname_warp_inv: output file name of the inverse warping field (defined in source space)
    :param verbose: verbosity. With verbose == 2, a QC figure is saved for each slice containing data.
    :param path_qc: prefix (folder) of the QC figures saved when verbose == 2
    :param smoothWarpXY: sigma of the Gaussian filter used to regularize the Y component of the warping fields
    :return: None. The warping fields are passed to generate_warping_field with fname_warp / fname_warp_inv.
    """
    import sys
    from skimage.transform import warp
    from skimage.filters import gaussian
    from sct_image import split_data

    # initialization
    th_nonzero = 0.5  # values below are considered zero

    # for display stuff
    if verbose == 2:
        import matplotlib
        matplotlib.use('Agg')  # prevent display figure
        import matplotlib.pyplot as plt

    # Get image dimensions and retrieve nz
    sct.printv('\nGet image dimensions of destination image...', verbose)
    im_dest = Image(fname_dest)
    nx, ny, nz, nt, px, py, pz, pt = im_dest.dim
    sct.printv('  matrix size: '+str(nx)+' x '+str(ny)+' x '+str(nz), verbose)
    sct.printv('  voxel size:  '+str(px)+'mm x '+str(py)+'mm x '+str(pz)+'mm', verbose)

    # Split source volume along z
    # NOTE(review): the split slices are saved to disk but not used in this function; kept for callers.
    sct.printv('\nSplit input volume...', verbose)
    im_src = Image(fname_src)
    for im in split_data(im_src, 2):
        im.save()

    # Split destination volume along z
    sct.printv('\nSplit destination volume...', verbose)
    for im in split_data(im_dest, 2):
        im.save()

    # open image
    data_src = im_src.data
    data_dest = im_dest.data

    if len(data_src.shape) == 2:
        # reshape 2D data into pseudo 3D (only one slice)
        new_shape = tuple(data_src.shape) + (1,)
        data_src = data_src.reshape(new_shape)
        data_dest = data_dest.reshape(new_shape)

    # initialize forward warping field (defined in destination space)
    warp_x = np.zeros(data_dest.shape)
    warp_y = np.zeros(data_dest.shape)

    # initialize inverse warping field (defined in source space)
    warp_inv_x = np.zeros(data_src.shape)
    warp_inv_y = np.zeros(data_src.shape)

    # in-plane pixel grid (independent of z, so computed once).
    # Once flattened (C order), the ordering of indices is as follows:
    # coord_init_pix[:, 0] = 0, 0, 0, ..., 1, 1, 1..., nx-1, nx-1, nx-1
    # coord_init_pix[:, 1] = 0, 1, 2, ..., 0, 1, 2..., 0, 1, 2
    # i.e. the ny entries of column ix are stored at indices ix*ny ... ix*ny + ny - 1
    row, col = np.indices((nx, ny))

    # Loop across slices
    sct.printv('\nEstimate columnwise transformation...', verbose)
    for iz in range(nz):
        sys.stdout.write(str(iz) + '/' + str(nz) + '.. ')

        # PREPARE COORDINATES
        # ============================================================
        # build (nx*ny)x3 array of coordinates in pixel space
        coord_init_pix = np.array([row.ravel(), col.ravel(), np.ones(nx * ny) * iz]).T
        # convert coordinates to physical space
        coord_init_phy = np.array(im_src.transfo_pix2phys(coord_init_pix))
        # get 2d data from the selected slice (views: thresholding below also modifies data_src/data_dest)
        src2d = data_src[:, :, iz]
        dest2d = data_dest[:, :, iz]
        # threshold at th_nonzero
        src2d[src2d < th_nonzero] = 0
        dest2d[dest2d < th_nonzero] = 0

        # SCALING R-L (X dimension)
        # ============================================================
        # sum data across Y to obtain 1D signal along X
        src1d = np.sum(src2d, 1)
        dest1d = np.sum(dest2d, 1)
        # skip slices without data in src or dest (warping field stays zero there)
        if not (np.any(src1d > th_nonzero) and np.any(dest1d > th_nonzero)):
            continue
        # retrieve min/max of non-zeros elements (edge of the segmentation)
        src_nonzero_x = np.where(src1d != 0)[0]
        dest_nonzero_x = np.where(dest1d != 0)[0]
        src1d_min, src1d_max = src_nonzero_x.min(), src_nonzero_x.max()
        dest1d_min, dest1d_max = dest_nonzero_x.min(), dest_nonzero_x.max()
        # centers of the segmentation along X (float division: no half-pixel truncation under Python 2)
        mean_dest_x = (dest1d_max + dest1d_min) / 2.0
        mean_src_x = (src1d_max + src1d_min) / 2.0
        # compute x-scaling factor
        Sx = (dest1d_max - dest1d_min + 1) / float(src1d_max - src1d_min + 1)
        # apply forward and inverse X-scaling to the pixel grid
        coord_init_pix_scaleX = np.copy(coord_init_pix)  # np.copy: do not modify coord_init_pix
        coord_init_pix_scaleX[:, 0] = (coord_init_pix[:, 0] - mean_src_x) * Sx + mean_dest_x
        coord_init_pix_scaleXinv = np.copy(coord_init_pix)
        coord_init_pix_scaleXinv[:, 0] = (coord_init_pix[:, 0] - mean_dest_x) / float(Sx) + mean_src_x
        # apply X-scaling to the source image
        row_scaleXinv = np.reshape(coord_init_pix_scaleXinv[:, 0], [nx, ny])
        src2d_scaleX = warp(src2d, np.array([row_scaleXinv, col]), order=1)

        # ============================================================
        # COLUMN-WISE REGISTRATION (Y dimension for each Xi)
        # ============================================================
        coord_init_pix_scaleY = np.copy(coord_init_pix)
        coord_init_pix_scaleYinv = np.copy(coord_init_pix)
        # loop across columns (X dimension)
        for ix in range(nx):
            # retrieve 1D signal along Y
            src1d = src2d_scaleX[ix, :]
            dest1d = dest2d[ix, :]
            # make sure there are non-zero data in src and dest
            if np.any(src1d > th_nonzero) and np.any(dest1d > th_nonzero):
                # retrieve min/max of elements above threshold (edge of the segmentation)
                src_nonzero_y = np.where(src1d > th_nonzero)[0]
                dest_nonzero_y = np.where(dest1d > th_nonzero)[0]
                src1d_min, src1d_max = src_nonzero_y.min(), src_nonzero_y.max()
                dest1d_min, dest1d_max = dest_nonzero_y.min(), dest_nonzero_y.max()
                # 1D matching between src_y and dest_y
                mean_dest_y = (dest1d_max + dest1d_min) / 2.0
                mean_src_y = (src1d_max + src1d_min) / 2.0
                Sy = (dest1d_max - dest1d_min + 1) / float(src1d_max - src1d_min + 1)
                # entries of column ix in the flattened grid (see ordering above): ny entries, not nx
                column = slice(ix * ny, (ix + 1) * ny)
                # apply forward and inverse transformation (in pixel space)
                coord_init_pix_scaleY[column, 1] = (coord_init_pix[column, 1] - mean_src_y) * Sy + mean_dest_y
                coord_init_pix_scaleYinv[column, 1] = (coord_init_pix[column, 1] - mean_dest_y) / float(Sy) + mean_src_y
        # reshape Y fields to 2D (col_scaleY/col_scaleYinv may be views of the coordinate arrays)
        col_scaleYinv = np.reshape(coord_init_pix_scaleYinv[:, 1], [nx, ny])
        col_scaleY = np.reshape(coord_init_pix_scaleY[:, 1], [nx, ny])
        # regularize Y warping fields
        col_scaleYsmooth = gaussian(col_scaleY, smoothWarpXY)
        col_scaleYinvsmooth = gaussian(col_scaleYinv, smoothWarpXY)
        if verbose == 2:
            # images used only for display; computed before the Y fields are overwritten below
            src2d_scaleXY = warp(src2d, np.array([row_scaleXinv, col_scaleYinv]), order=1)
            src2d_scaleXYsmooth = warp(src2d, np.array([row_scaleXinv, col_scaleYinvsmooth]), order=1)
        # store smoothed warping field as 1d
        coord_init_pix_scaleY[:, 1] = col_scaleYsmooth.ravel()
        coord_init_pix_scaleYinv[:, 1] = col_scaleYinvsmooth.ravel()

        # display
        if verbose == 2:
            panels = [(src2d, 'src'),
                      (src2d_scaleX, 'src_scaleX'),
                      (src2d_scaleXY, 'src_scaleXY'),
                      (src2d_scaleXYsmooth, 'src_scaleXYsmooth (s='+str(smoothWarpXY)+')')]
            plt.figure(figsize=(15, 3))
            for i_panel, (image_panel, title_panel) in enumerate(panels):
                ax = plt.subplot(1, len(panels), i_panel + 1)
                plt.imshow(np.swapaxes(image_panel, 1, 0), cmap=plt.cm.gray, interpolation='none')
                # overlay destination on the same axes (plt.hold was removed in matplotlib 3.0)
                plt.imshow(np.swapaxes(dest2d, 1, 0), cmap=plt.cm.copper, interpolation='none', alpha=0.5)
                plt.title(title_panel)
                plt.xlabel('x')
                plt.ylabel('y')
                plt.xlim(mean_dest_x - 15, mean_dest_x + 15)
                plt.ylim(mean_dest_y - 15, mean_dest_y + 15)
                ax.grid(True, color='w')
            # save figure
            plt.savefig(path_qc + 'register2d_columnwise_image_z' + str(iz) + '.png')
            plt.close()

        # ============================================================
        # CALCULATE TRANSFORMATIONS
        # ============================================================
        # calculate forward transformation (in physical space)
        coord_init_phy_scaleX = np.array(im_dest.transfo_pix2phys(coord_init_pix_scaleX))
        coord_init_phy_scaleY = np.array(im_dest.transfo_pix2phys(coord_init_pix_scaleY))
        # calculate inverse transformation (in physical space)
        coord_init_phy_scaleXinv = np.array(im_src.transfo_pix2phys(coord_init_pix_scaleXinv))
        coord_init_phy_scaleYinv = np.array(im_src.transfo_pix2phys(coord_init_pix_scaleYinv))
        # compute displacement per pixel in destination space (for forward warping field)
        warp_x[:, :, iz] = (coord_init_phy_scaleXinv[:, 0] - coord_init_phy[:, 0]).reshape((nx, ny))
        warp_y[:, :, iz] = (coord_init_phy_scaleYinv[:, 1] - coord_init_phy[:, 1]).reshape((nx, ny))
        # compute displacement per pixel in source space (for inverse warping field)
        warp_inv_x[:, :, iz] = (coord_init_phy_scaleX[:, 0] - coord_init_phy[:, 0]).reshape((nx, ny))
        warp_inv_y[:, :, iz] = (coord_init_phy_scaleY[:, 1] - coord_init_phy[:, 1]).reshape((nx, ny))

    # Generate forward warping field (defined in destination space)
    generate_warping_field(fname_dest, warp_x, warp_y, fname_warp, verbose)
    # Generate inverse warping field (defined in source space)
    generate_warping_field(fname_src, warp_inv_x, warp_inv_y, fname_warp_inv, verbose)
def main():
    parser = get_parser()
    param = Param()

    arguments = parser.parse(sys.argv[1:])

    # get arguments
    fname_data = arguments['-i']
    fname_seg = arguments['-s']
    fname_landmarks = arguments['-l']
    if '-ofolder' in arguments:
        path_output = arguments['-ofolder']
    else:
        path_output = ''
    path_template = sct.slash_at_the_end(arguments['-t'], 1)
    contrast_template = arguments['-c']
    remove_temp_files = int(arguments['-r'])
    verbose = int(arguments['-v'])
    if '-param-straighten' in arguments:
        param.param_straighten = arguments['-param-straighten']
    if 'cpu-nb' in arguments:
        arg_cpu = ' -cpu-nb '+arguments['-cpu-nb']
    else:
        arg_cpu = ''
    if '-param' in arguments:
        paramreg_user = arguments['-param']
        # update registration parameters
        for paramStep in paramreg_user:
            paramreg.addStep(paramStep)

    # initialize other parameters
    file_template_label = param.file_template_label
    output_type = param.output_type
    zsubsample = param.zsubsample
    # smoothing_sigma = param.smoothing_sigma

    # capitalize letters for contrast
    if contrast_template == 't1':
        contrast_template = 'T1'
    elif contrast_template == 't2':
        contrast_template = 'T2'

    # retrieve file_template based on contrast
    fname_template_list = glob(path_template+param.folder_template+'*'+contrast_template+'.nii.gz')
    # TODO: make sure there is only one file -- check if file is there otherwise it crashes
    fname_template = fname_template_list[0]

    # retrieve file_template_seg
    fname_template_seg_list = glob(path_template+param.folder_template+'*cord.nii.gz')
    # TODO: make sure there is only one file
    fname_template_seg = fname_template_seg_list[0]

    # start timer
    start_time = time.time()

    # get absolute path - TO DO: remove! NEVER USE ABSOLUTE PATH...
    path_template = os.path.abspath(path_template+param.folder_template)

    # get fname of the template + template objects
    # fname_template = sct.slash_at_the_end(path_template, 1)+file_template
    fname_template_label = sct.slash_at_the_end(path_template, 1)+file_template_label
    # fname_template_seg = sct.slash_at_the_end(path_template, 1)+file_template_seg

    # check file existence
    sct.printv('\nCheck template files...')
    sct.check_file_exist(fname_template, verbose)
    sct.check_file_exist(fname_template_label, verbose)
    sct.check_file_exist(fname_template_seg, verbose)

    # print arguments
    sct.printv('\nCheck parameters:', verbose)
    sct.printv('.. Data:                 '+fname_data, verbose)
    sct.printv('.. Landmarks:            '+fname_landmarks, verbose)
    sct.printv('.. Segmentation:         '+fname_seg, verbose)
    sct.printv('.. Path template:        '+path_template, verbose)
    sct.printv('.. Path output:          '+path_output, verbose)
    sct.printv('.. Output type:          '+str(output_type), verbose)
    sct.printv('.. Remove temp files:    '+str(remove_temp_files), verbose)

    sct.printv('\nParameters for registration:')
    for pStep in range(1, len(paramreg.steps)+1):
        sct.printv('Step #'+paramreg.steps[str(pStep)].step, verbose)
        sct.printv('.. Type #'+paramreg.steps[str(pStep)].type, verbose)
        sct.printv('.. Algorithm................ '+paramreg.steps[str(pStep)].algo, verbose)
        sct.printv('.. Metric................... '+paramreg.steps[str(pStep)].metric, verbose)
        sct.printv('.. Number of iterations..... '+paramreg.steps[str(pStep)].iter, verbose)
        sct.printv('.. Shrink factor............ '+paramreg.steps[str(pStep)].shrink, verbose)
        sct.printv('.. Smoothing factor......... '+paramreg.steps[str(pStep)].smooth, verbose)
        sct.printv('.. Gradient step............ '+paramreg.steps[str(pStep)].gradStep, verbose)
        sct.printv('.. Degree of polynomial..... '+paramreg.steps[str(pStep)].poly, verbose)

    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    sct.printv('\nCheck input labels...')
    # check if label image contains coherent labels
    image_label = Image(fname_landmarks)
    # -> all labels must be different
    labels = image_label.getNonZeroCoordinates(sorting='value')
    hasDifferentLabels = True
    for lab in labels:
        for otherlabel in labels:
            if lab != otherlabel and lab.hasEqualValue(otherlabel):
                hasDifferentLabels = False
                break
    if not hasDifferentLabels:
        sct.printv('ERROR: Wrong landmarks input. All labels must be different.', verbose, 'error')
    # all labels must be available in tempalte
    image_label_template = Image(fname_template_label)
    labels_template = image_label_template.getNonZeroCoordinates(sorting='value')
    if labels[-1].value > labels_template[-1].value:
        sct.printv('ERROR: Wrong landmarks input. Labels must have correspondence in template space. \nLabel max '
                   'provided: ' + str(labels[-1].value) + '\nLabel max from template: ' +
                   str(labels_template[-1].value), verbose, 'error')

    # create temporary folder
    path_tmp = sct.tmp_create(verbose=verbose)

    # set temporary file names
    ftmp_data = 'data.nii'
    ftmp_seg = 'seg.nii.gz'
    ftmp_label = 'label.nii.gz'
    ftmp_template = 'template.nii'
    ftmp_template_seg = 'template_seg.nii.gz'
    ftmp_template_label = 'template_label.nii.gz'

    # copy files to temporary folder
    sct.printv('\nCopying input data to tmp folder and convert to nii...', verbose)
    sct.run('sct_convert -i '+fname_data+' -o '+path_tmp+ftmp_data)
    sct.run('sct_convert -i '+fname_seg+' -o '+path_tmp+ftmp_seg)
    sct.run('sct_convert -i '+fname_landmarks+' -o '+path_tmp+ftmp_label)
    sct.run('sct_convert -i '+fname_template+' -o '+path_tmp+ftmp_template)
    sct.run('sct_convert -i '+fname_template_seg+' -o '+path_tmp+ftmp_template_seg)
    sct.run('sct_convert -i '+fname_template_label+' -o '+path_tmp+ftmp_template_label)

    # go to tmp folder
    os.chdir(path_tmp)

    # smooth segmentation (jcohenadad, issue #613)
    sct.printv('\nSmooth segmentation...', verbose)
    sct.run('sct_maths -i '+ftmp_seg+' -smooth 1.5 -o '+add_suffix(ftmp_seg, '_smooth'))
    ftmp_seg = add_suffix(ftmp_seg, '_smooth')

    # resample data to 1mm isotropic
    sct.printv('\nResample data to 1mm isotropic...', verbose)
    sct.run('sct_resample -i '+ftmp_data+' -mm 1.0x1.0x1.0 -x linear -o '+add_suffix(ftmp_data, '_1mm'))
    ftmp_data = add_suffix(ftmp_data, '_1mm')
    sct.run('sct_resample -i '+ftmp_seg+' -mm 1.0x1.0x1.0 -x linear -o '+add_suffix(ftmp_seg, '_1mm'))
    ftmp_seg = add_suffix(ftmp_seg, '_1mm')
    # N.B. resampling of labels is more complicated, because they are single-point labels, therefore resampling with neighrest neighbour can make them disappear. Therefore a more clever approach is required.
    resample_labels(ftmp_label, ftmp_data, add_suffix(ftmp_label, '_1mm'))
    ftmp_label = add_suffix(ftmp_label, '_1mm')

    # Change orientation of input images to RPI
    sct.printv('\nChange orientation of input images to RPI...', verbose)
    sct.run('sct_image -i '+ftmp_data+' -setorient RPI -o '+add_suffix(ftmp_data, '_rpi'))
    ftmp_data = add_suffix(ftmp_data, '_rpi')
    sct.run('sct_image -i '+ftmp_seg+' -setorient RPI -o '+add_suffix(ftmp_seg, '_rpi'))
    ftmp_seg = add_suffix(ftmp_seg, '_rpi')
    sct.run('sct_image -i '+ftmp_label+' -setorient RPI -o '+add_suffix(ftmp_label, '_rpi'))
    ftmp_label = add_suffix(ftmp_label, '_rpi')

    # get landmarks in native space
    # crop segmentation
    # output: segmentation_rpi_crop.nii.gz
    status_crop, output_crop = sct.run('sct_crop_image -i '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_crop')+' -dim 2 -bzmax', verbose)
    ftmp_seg = add_suffix(ftmp_seg, '_crop')
    cropping_slices = output_crop.split('Dimension 2: ')[1].split('\n')[0].split(' ')

    # straighten segmentation
    sct.printv('\nStraighten the spinal cord using centerline/segmentation...', verbose)
    sct.run('sct_straighten_spinalcord -i '+ftmp_seg+' -s '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_straight')+' -qc 0 -r 0 -v '+str(verbose)+' '+param.param_straighten+arg_cpu, verbose)
    # N.B. DO NOT UPDATE VARIABLE ftmp_seg BECAUSE TEMPORARY USED LATER
    # re-define warping field using non-cropped space (to avoid issue #367)
    sct.run('sct_concat_transfo -w warp_straight2curve.nii.gz -d '+ftmp_data+' -o warp_straight2curve.nii.gz')

    # Label preparation:
    # --------------------------------------------------------------------------------
    # Remove unused label on template. Keep only label present in the input label image
    sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
    sct.run('sct_label_utils -p remove -i '+ftmp_template_label+' -o '+ftmp_template_label+' -r '+ftmp_label)

    # Dilating the input label so they can be straighten without losing them
    sct.printv('\nDilating input labels using 3vox ball radius')
    sct.run('sct_maths -i '+ftmp_label+' -o '+add_suffix(ftmp_label, '_dilate')+' -dilate 3')
    ftmp_label = add_suffix(ftmp_label, '_dilate')

    # Apply straightening to labels
    sct.printv('\nApply straightening to labels...', verbose)
    sct.run('sct_apply_transfo -i '+ftmp_label+' -o '+add_suffix(ftmp_label, '_straight')+' -d '+add_suffix(ftmp_seg, '_straight')+' -w warp_curve2straight.nii.gz -x nn')
    ftmp_label = add_suffix(ftmp_label, '_straight')

    # Create crosses for the template labels and get coordinates
    sct.printv('\nCreate a 15 mm cross for the template labels...', verbose)
    template_image = Image(ftmp_template_label)
    coordinates_input = template_image.getNonZeroCoordinates(sorting='value')
    # jcohenadad, issue #628 <<<<<
    # landmark_template = ProcessLabels.get_crosses_coordinates(coordinates_input, gapxy=15)
    landmark_template = coordinates_input
    # >>>>>
    if verbose == 2:
        # TODO: assign cross to image before saving
        template_image.setFileName(add_suffix(ftmp_template_label, '_cross'))
        template_image.save(type='minimize_int')

    # Create crosses for the input labels into straight space and get coordinates
    sct.printv('\nCreate a 15 mm cross for the input labels...', verbose)
    label_straight_image = Image(ftmp_label)
    coordinates_input = label_straight_image.getCoordinatesAveragedByValue()  # landmarks are sorted by value
    # jcohenadad, issue #628 <<<<<
    # landmark_straight = ProcessLabels.get_crosses_coordinates(coordinates_input, gapxy=15)
    landmark_straight = coordinates_input
    # >>>>>
    if verbose == 2:
        # TODO: assign cross to image before saving
        label_straight_image.setFileName(add_suffix(ftmp_label, '_cross'))
        label_straight_image.save(type='minimize_int')

    # Reorganize landmarks
    points_fixed, points_moving = [], []
    for coord in landmark_straight:
        point_straight = label_straight_image.transfo_pix2phys([[coord.x, coord.y, coord.z]])
        points_moving.append([point_straight[0][0], point_straight[0][1], point_straight[0][2]])

    for coord in landmark_template:
        point_template = template_image.transfo_pix2phys([[coord.x, coord.y, coord.z]])
        points_fixed.append([point_template[0][0], point_template[0][1], point_template[0][2]])

    # Register curved landmarks on straight landmarks based on python implementation
    sct.printv('\nComputing rigid transformation (algo=translation-scaling-z) ...', verbose)

    import msct_register_landmarks
    # for some reason, the moving and fixed points are inverted between ITK transform and our python-based transform.
    # and for another unknown reason, x and y dimensions have a negative sign (at least for translation and center of rotation).
    if verbose == 2:
        show_transfo = True
    else:
        show_transfo = False
    (rotation_matrix, translation_array, points_moving_reg, points_moving_barycenter) = msct_register_landmarks.getRigidTransformFromLandmarks(points_moving, points_fixed, constraints='translation-scaling-z', show=show_transfo)
    # writing rigid transformation file
    text_file = open("straight2templateAffine.txt", "w")
    text_file.write("#Insight Transform File V1.0\n")
    text_file.write("#Transform 0\n")
    text_file.write("Transform: AffineTransform_double_3_3\n")
    text_file.write("Parameters: %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f\n" % (
        rotation_matrix[0, 0], rotation_matrix[0, 1], rotation_matrix[0, 2],
        rotation_matrix[1, 0], rotation_matrix[1, 1], rotation_matrix[1, 2],
        rotation_matrix[2, 0], rotation_matrix[2, 1], rotation_matrix[2, 2],
        -translation_array[0, 0], -translation_array[0, 1], translation_array[0, 2]))
    text_file.write("FixedParameters: %.9f %.9f %.9f\n" % (-points_moving_barycenter[0],
                                                           -points_moving_barycenter[1],
                                                           points_moving_barycenter[2]))
    text_file.close()

    # Concatenate transformations: curve --> straight --> affine
    sct.printv('\nConcatenate transformations: curve --> straight --> affine...', verbose)
    sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt -d template.nii -o warp_curve2straightAffine.nii.gz')

    # Apply transformation
    sct.printv('\nApply transformation...', verbose)
    sct.run('sct_apply_transfo -i '+ftmp_data+' -o '+add_suffix(ftmp_data, '_straightAffine')+' -d '+ftmp_template+' -w warp_curve2straightAffine.nii.gz')
    ftmp_data = add_suffix(ftmp_data, '_straightAffine')
    sct.run('sct_apply_transfo -i '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_straightAffine')+' -d '+ftmp_template+' -w warp_curve2straightAffine.nii.gz -x linear')
    ftmp_seg = add_suffix(ftmp_seg, '_straightAffine')

    # threshold and binarize
    sct.printv('\nBinarize segmentation...', verbose)
    sct.run('sct_maths -i '+ftmp_seg+' -thr 0.4 -o '+add_suffix(ftmp_seg, '_thr'))
    sct.run('sct_maths -i '+add_suffix(ftmp_seg, '_thr')+' -bin -o '+add_suffix(ftmp_seg, '_thr_bin'))
    ftmp_seg = add_suffix(ftmp_seg, '_thr_bin')

    # find min-max of anat2template (for subsequent cropping)
    zmin_template, zmax_template = find_zmin_zmax(ftmp_seg)

    # crop template in z-direction (for faster processing)
    sct.printv('\nCrop data in template space (for faster processing)...', verbose)
    sct.run('sct_crop_image -i '+ftmp_template+' -o '+add_suffix(ftmp_template, '_crop')+' -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    ftmp_template = add_suffix(ftmp_template, '_crop')
    sct.run('sct_crop_image -i '+ftmp_template_seg+' -o '+add_suffix(ftmp_template_seg, '_crop')+' -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    ftmp_template_seg = add_suffix(ftmp_template_seg, '_crop')
    sct.run('sct_crop_image -i '+ftmp_data+' -o '+add_suffix(ftmp_data, '_crop')+' -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    ftmp_data = add_suffix(ftmp_data, '_crop')
    sct.run('sct_crop_image -i '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_crop')+' -dim 2 -start '+str(zmin_template)+' -end '+str(zmax_template))
    ftmp_seg = add_suffix(ftmp_seg, '_crop')

    # sub-sample in z-direction
    sct.printv('\nSub-sample in z-direction (for faster processing)...', verbose)
    sct.run('sct_resample -i '+ftmp_template+' -o '+add_suffix(ftmp_template, '_sub')+' -f 1x1x'+zsubsample, verbose)
    ftmp_template = add_suffix(ftmp_template, '_sub')
    sct.run('sct_resample -i '+ftmp_template_seg+' -o '+add_suffix(ftmp_template_seg, '_sub')+' -f 1x1x'+zsubsample, verbose)
    ftmp_template_seg = add_suffix(ftmp_template_seg, '_sub')
    sct.run('sct_resample -i '+ftmp_data+' -o '+add_suffix(ftmp_data, '_sub')+' -f 1x1x'+zsubsample, verbose)
    ftmp_data = add_suffix(ftmp_data, '_sub')
    sct.run('sct_resample -i '+ftmp_seg+' -o '+add_suffix(ftmp_seg, '_sub')+' -f 1x1x'+zsubsample, verbose)
    ftmp_seg = add_suffix(ftmp_seg, '_sub')

    # Registration straight spinal cord to template
    sct.printv('\nRegister straight spinal cord to template...', verbose)

    # loop across registration steps
    warp_forward = []
    warp_inverse = []
    for i_step in range(1, len(paramreg.steps)+1):
        sct.printv('\nEstimate transformation for step #'+str(i_step)+'...', verbose)
        # identify which is the src and dest
        if paramreg.steps[str(i_step)].type == 'im':
            src = ftmp_data
            dest = ftmp_template
            interp_step = 'linear'
        elif paramreg.steps[str(i_step)].type == 'seg':
            src = ftmp_seg
            dest = ftmp_template_seg
            interp_step = 'nn'
        else:
            sct.printv('ERROR: Wrong image type.', 1, 'error')
        # if step>1, apply warp_forward_concat to the src image to be used
        if i_step > 1:
            # sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+sct.add_suffix(src, '_reg')+' -x '+interp_step, verbose)
            sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+add_suffix(src, '_reg')+' -x '+interp_step, verbose)
            src = add_suffix(src, '_reg')
        # register src --> dest
        warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
        warp_forward.append(warp_forward_out)
        warp_inverse.append(warp_inverse_out)

    # Concatenate transformations:
    sct.printv('\nConcatenate transformations: anat --> template...', verbose)
    sct.run('sct_concat_transfo -w warp_curve2straightAffine.nii.gz,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
    # sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
    sct.printv('\nConcatenate transformations: template --> anat...', verbose)
    warp_inverse.reverse()
    sct.run('sct_concat_transfo -w '+','.join(warp_inverse)+',-straight2templateAffine.txt,warp_straight2curve.nii.gz -d data.nii -o warp_template2anat.nii.gz', verbose)

    # Apply warping fields to anat and template
    if output_type == 1:
        sct.run('sct_apply_transfo -i template.nii -o template2anat.nii.gz -d data.nii -w warp_template2anat.nii.gz -crop 1', verbose)
        sct.run('sct_apply_transfo -i data.nii -o anat2template.nii.gz -d template.nii -w warp_anat2template.nii.gz -crop 1', verbose)

    # come back to parent folder
    os.chdir('..')

   # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(path_tmp+'warp_template2anat.nii.gz', path_output+'warp_template2anat.nii.gz', verbose)
    sct.generate_output_file(path_tmp+'warp_anat2template.nii.gz', path_output+'warp_anat2template.nii.gz', verbose)
    if output_type == 1:
        sct.generate_output_file(path_tmp+'template2anat.nii.gz', path_output+'template2anat'+ext_data, verbose)
        sct.generate_output_file(path_tmp+'anat2template.nii.gz', path_output+'anat2template'+ext_data, verbose)

    # Delete temporary files
    if remove_temp_files:
        sct.printv('\nDelete temporary files...', verbose)
        sct.run('rm -rf '+path_tmp)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: '+str(int(round(elapsed_time)))+'s', verbose)

    # to view results
    sct.printv('\nTo view results, type:', verbose)
    sct.printv('fslview '+fname_data+' '+path_output+'template2anat -b 0,4000 &', verbose, 'info')
    sct.printv('fslview '+fname_template+' -b 0,5000 '+path_output+'anat2template &\n', verbose, 'info')
def test(path_data='', parameters=''):
    """
    Validate sct_get_centerline against a manual segmentation.

    Runs ``sct_get_centerline`` with ``parameters`` (input paths prefixed with ``path_data``), then, for
    every axial slice except the first and the last, measures the in-plane physical distance between the
    brightest voxel of the computed centerline and the centre of mass of the manual segmentation.

    :param path_data: folder prepended to the input file paths; the manual segmentation is read from
        path_data + 't2/' + <contrast> + '_seg_manual.nii.gz'.
    :param parameters: sct_get_centerline command-line parameters (default: '-i t2/t2.nii.gz -c t2 -p auto').
    :return: tuple (status, output, DataFrame). The DataFrame has a single row indexed by path_data with the
        columns 'status', 'output', 'mse' and 'dist_max'; metrics are NaN when they could not be computed.
    """
    if not parameters:
        parameters = '-i t2/t2.nii.gz -c t2 -p auto'

    # sub-folder of path_data holding the manual segmentation
    folder_data = 't2/'

    parser = sct_get_centerline.get_parser()
    dict_param = parser.parse(parameters.split(), check_file_exist=False)
    contrast = dict_param['-c']
    dict_param_with_path = parser.add_path_to_file(dict_param, path_data, input_file=True)
    param_with_path = parser.dictionary_to_string(dict_param_with_path)

    # Check if input files exist
    if not os.path.isfile(dict_param_with_path['-i']):
        status = 200
        output = 'ERROR: the file(s) provided to test function do not exist in folder: ' + path_data
        return status, output, DataFrame(data={'status': status, 'output': output, 'mse': float('nan'), 'dist_max': float('nan')}, index=[path_data])

    cmd = 'sct_get_centerline ' + param_with_path
    status, output = sct.run(cmd, 0)

    # metrics stay NaN unless they are successfully computed below
    max_distance = float('nan')
    rmse = float('nan')

    try:
        if status == 0:
            from scipy.ndimage import center_of_mass

            # images are opened only once the command succeeded, so a missing output file is
            # reported through `output` instead of escaping as an exception
            scad_centerline = Image(contrast + "_centerline.nii.gz")
            manual_seg = Image(path_data + folder_data + contrast + '_seg_manual.nii.gz')
            manual_seg.change_orientation()
            scad_centerline.change_orientation()

            # centre of mass (x, y) of the manual segmentation in each slice (NaN for empty slices)
            com = [center_of_mass(manual_seg.data[:, :, iz]) for iz in range(manual_seg.data.shape[2])]

            distances = []
            # the first and last slices are excluded from the comparison
            for iz in range(1, scad_centerline.data.shape[2] - 1):
                ind1 = np.argmax(scad_centerline.data[:, :, iz])
                X, Y = ind2sub(scad_centerline.data[:, :, iz].shape, ind1)
                com_phys = np.array(manual_seg.transfo_pix2phys([[com[iz][0], com[iz][1], iz]]))
                scad_phys = np.array(scad_centerline.transfo_pix2phys([[X, Y, iz]]))
                # in-plane distance only: the z component is forced to 0
                distance_magnitude = np.linalg.norm([com_phys[0][0] - scad_phys[0][0], com_phys[0][1] - scad_phys[0][1], 0])
                if math.isnan(distance_magnitude):
                    print("Value is nan")
                else:
                    distances.append(distance_magnitude)

            # max() raises ValueError when no slice yields a valid distance; reported below via `output`
            max_distance = max(distances)
            rmse = np.sqrt(np.mean(np.square(distances)))
    except Exception as e:
        sct.printv("Exception found while testing scad integrity")
        output = str(e)

    # N.B. the 'mse' column holds the root-mean-square distance; the column name is kept identical to
    # the one used by the missing-file branch above
    return status, output, DataFrame(data={'status': status, 'output': output, 'mse': rmse, 'dist_max': max_distance}, index=[path_data])
Exemple #26
0
def _scad_subject_distances(subject):
    """
    Run SCAD for one subject and return its per-slice centerline-to-manual-segmentation distances.

    The current working directory must be <dataset>/<subject>/t2.

    :param subject: subject folder name, also used as file-name prefix.
    :return: list of physical distances, one per axial slice of the SCAD centerline.
    :raises Exception: if neither <subject>_t2.nii.gz nor t2.nii.gz exists.
    """
    if os.path.isfile(subject + "_t2.nii.gz"):
        raw_image = Image(subject + "_t2.nii.gz")
    elif os.path.isfile("t2.nii.gz"):
        raw_image = Image("t2.nii.gz")
    else:
        raise Exception("t2.nii.gz or " + subject + "_t2.nii.gz file is not found")

    raw_image.change_orientation()
    SCAD(raw_image, contrast="t2", rm_tmp_file=1, verbose=1).test_debug()

    manual_seg = Image(subject + "_t2_manual_segmentation.nii.gz")
    manual_seg.change_orientation()

    from scipy.ndimage import center_of_mass
    # centre of mass (x, y) of the manual segmentation in each slice (NaN for empty slices)
    com = [center_of_mass(manual_seg.data[:, :, iz]) for iz in range(manual_seg.data.shape[2])]

    # centerline file (presumably produced by SCAD above): read it, then remove it from disk
    centerline_scad = Image(subject + "_t2_centerline.nii.gz")
    os.remove(subject + "_t2_centerline.nii.gz")
    centerline_scad.change_orientation()

    distances = []
    for iz in range(centerline_scad.data.shape[2]):
        slice_iz = centerline_scad.data[:, :, iz]
        # np.argmax returns a C-order flat index: unravel it with the shape of this slice
        x, y = np.unravel_index(np.argmax(slice_iz), slice_iz.shape)
        # the centre of mass is expressed in the voxel grid of the manual segmentation
        com_phys = np.array(manual_seg.transfo_pix2phys([[com[iz][0], com[iz][1], iz]]))
        scad_phys = np.array(centerline_scad.transfo_pix2phys([[x, y, iz]]))
        distances.append(np.linalg.norm(com_phys - scad_phys))
    return distances


def validate_scad(folder_input):
    """
    Run SCAD on every subject of a dataset and measure its distance to the manual segmentation.

    Expecting folder to have the following structure :
    errsm_01:
    - t2
    -- errsm_01.nii.gz or t2.nii.gz
    Subjects "errsm_01" and "errsm_02" are skipped. Errors raised for a subject are printed and
    that subject is left out of the result. The working directory is restored before returning.

    :param folder_input: folder containing one sub-folder per subject.
    :return: dict {subject: list of per-slice distances} for the subjects processed without error.
    """
    current_folder = os.getcwd()
    # absolute path, so we can come back to it from inside <subject>/t2
    folder_input = os.path.abspath(folder_input)
    os.chdir(folder_input)
    distances_per_subject = {}
    try:
        subjects = next(os.walk('.'))[1]
        for subject in subjects:
            if subject in ("errsm_01", "errsm_02"):
                continue
            try:
                os.chdir(os.path.join(subject, "t2"))
                distances_per_subject[subject] = _scad_subject_distances(subject)
            except Exception as e:
                print(str(e))
            finally:
                # always go back to the dataset folder, even when this subject failed
                os.chdir(folder_input)
    except Exception as e:
        print(str(e))
    finally:
        os.chdir(current_folder)
    return distances_per_subject
def register_landmarks(fname_src, fname_dest, dof, fname_affine="affine.txt", verbose=1, path_qc="./"):
    """
    Register two NIFTI volumes containing landmarks
    Landmarks of both volumes are retrieved with getCoordinatesAveragedByValue (sorted by value) and
    paired in that order.
    :param fname_src: fname of source landmarks
    :param fname_dest: fname of destination landmarks
    :param dof: degree of freedom. Separate with "_". Example: Tx_Ty_Tz_Rx_Ry_Sz
    :param fname_affine: output affine transformation, written as an ITK "Insight Transform File V1.0"
        text file (AffineTransform_double_3_3)
    :param verbose: 0, 1, 2
    :param path_qc: forwarded to getRigidTransformFromLandmarks
    :return: None
    :raises Exception: if src and dest do not contain the same number of landmarks
    """
    from msct_image import Image

    def _landmarks_itk(im):
        """Return the landmarks of `im` (sorted by value) in ITK world coordinates."""
        points = []
        for coord in im.getCoordinatesAveragedByValue():
            point = im.transfo_pix2phys([[coord.x, coord.y, coord.z]])[0]
            # convert NIFTI to ITK world coordinate: x and y are negated
            points.append([-point[0], -point[1], point[2]])
        return points

    # open src label, then dest labels
    points_src = _landmarks_itk(Image(fname_src))
    points_dest = _landmarks_itk(Image(fname_dest))

    # display
    sct.printv("Labels src: " + str(points_src), verbose)
    sct.printv("Labels dest: " + str(points_dest), verbose)
    sct.printv("Degrees of freedom (dof): " + dof, verbose)

    if len(points_src) != len(points_dest):
        raise Exception(
            "Error: number of source and destination landmarks are not the same, so landmarks cannot be paired."
        )

    # estimate transformation
    # N.B. points_src and points_dest are inverted below, because ITK uses inverted transformation matrices, i.e., src->dest is defined in dest instead of src.
    # NOTE(review): the call passes (src, dest) in that order; the inversion presumably refers to the
    # parameter order of getRigidTransformFromLandmarks — confirm against its signature.
    rotation_matrix, translation_array, _, points_moving_barycenter = getRigidTransformFromLandmarks(
        points_src, points_dest, constraints=dof, verbose=verbose, path_qc=path_qc
    )

    # writing rigid transformation file
    # N.B. x and y dimensions have a negative sign to ensure compatibility between Python and ITK transfo
    # (the sign flip is applied to the landmark coordinates above)
    parameters = [rotation_matrix[i, j] for i in range(3) for j in range(3)]
    parameters += [translation_array[0, k] for k in range(3)]
    # `with` guarantees the file is closed even if a write fails
    with open(fname_affine, "w") as text_file:
        text_file.write("#Insight Transform File V1.0\n")
        text_file.write("#Transform 0\n")
        text_file.write("Transform: AffineTransform_double_3_3\n")
        text_file.write("Parameters: " + " ".join("%.9f" % value for value in parameters) + "\n")
        text_file.write(
            "FixedParameters: %.9f %.9f %.9f\n"
            % (points_moving_barycenter[0], points_moving_barycenter[1], points_moving_barycenter[2])
        )
Exemple #28
0
def interpolate_im_to_ref(im_input, im_input_sc, new_res=0.3, sq_size_size_mm=22.5, interpolation_mode=3):
    """
    Resample an image, slice by slice, onto square in-plane grids centred on the spinal cord.

    For each axial slice iz of ``im_input``, a square reference grid of ``int(sq_size_size_mm / new_res)``
    pixels per side, with in-plane pixel size ``new_res``, is centred on the mean position of the non-zero
    voxels of ``im_input_sc`` in that slice, and ``im_input`` is interpolated onto it.

    Side effect: writes the reference image to 'im_ref.nii.gz' in the current working directory.

    :param im_input: Image to interpolate.
    :param im_input_sc: spinal cord segmentation Image (non-zero voxels = cord); its header is replaced by
        the one of im_input, so it is presumably expected in the same voxel space — TODO confirm.
    :param new_res: in-plane pixel size of the output grids (presumably mm).
    :param sq_size_size_mm: side length of the square grid (presumably mm).
    :param interpolation_mode: forwarded to Image.interpolate_from_image.
    :return: list of interpolated Images, one per slice of im_input; 3D data is reshaped to 2D.
    """
    # only the number of slices and the slice thickness of the input are needed
    _, _, nz, _, _, _, pz, _ = im_input.dim

    # work on copies (presumably so the header edits below do not leak to the caller's images)
    im_input_sc = im_input_sc.copy()
    im_input = im_input.copy()

    # keep only spacing and origin in qform to avoid rotation issues
    input_qform = im_input.hdr.get_qform()
    for i in range(4):
        for j in range(4):
            # zero every off-diagonal entry except the translation column (j == 3)
            if i != j and j != 3:
                input_qform[i, j] = 0

    im_input.hdr.set_qform(input_qform)
    im_input.hdr.set_sform(input_qform)
    im_input_sc.hdr = im_input.hdr

    sq_size = int(sq_size_size_mm / new_res)
    # create a reference image : square of ones
    # N.B. builtin `int` instead of `np.int`: np.int was only an alias of `int`, deprecated in
    # NumPy 1.20 and removed in NumPy 1.24 (AttributeError)
    im_ref = Image(np.ones((sq_size, sq_size, 1), dtype=int), dim=(sq_size, sq_size, 1, 0, new_res, new_res, pz, 0), orientation='RPI')

    # copy input qform matrix to reference image
    im_ref.hdr.set_qform(im_input.hdr.get_qform())
    im_ref.hdr.set_sform(im_input.hdr.get_sform())

    # set correct header to reference image
    im_ref.hdr.set_data_shape((sq_size, sq_size, 1))
    im_ref.hdr.set_zooms((new_res, new_res, pz))

    # save image to set orientation to RPI (not properly done at the creation of the image)
    fname_ref = 'im_ref.nii.gz'
    im_ref.setFileName(fname_ref)
    im_ref.save()
    im_ref = set_orientation(im_ref, 'RPI', fname_out=fname_ref)

    # set header origin to zero to get physical coordinates of the center of the square
    im_ref.hdr.as_analyze_map()['qoffset_x'] = 0
    im_ref.hdr.as_analyze_map()['qoffset_y'] = 0
    im_ref.hdr.as_analyze_map()['qoffset_z'] = 0
    im_ref.hdr.set_sform(im_ref.hdr.get_qform())
    im_ref.hdr.set_qform(im_ref.hdr.get_qform())
    [[x_square_center_phys, y_square_center_phys, z_square_center_phys]] = im_ref.transfo_pix2phys(coordi=[[int(sq_size / 2), int(sq_size / 2), 0]])

    list_interpolate_images = []
    # iterate on z dimension of input image
    for iz in range(nz):
        # copy reference image: one reference image per slice
        im_ref_slice_iz = im_ref.copy()

        # get center of mass of SC for slice iz
        # NOTE(review): for a slice without any cord voxel, np.mean of an empty array gives NaN
        # (with a RuntimeWarning); that case is not handled here
        x_seg, y_seg = (im_input_sc.data[:, :, iz] > 0).nonzero()
        x_center, y_center = np.mean(x_seg), np.mean(y_seg)
        [[x_center_phys, y_center_phys, z_center_phys]] = im_input_sc.transfo_pix2phys(coordi=[[x_center, y_center, iz]])

        # center reference image on SC for slice iz
        im_ref_slice_iz.hdr.as_analyze_map()['qoffset_x'] = x_center_phys - x_square_center_phys
        im_ref_slice_iz.hdr.as_analyze_map()['qoffset_y'] = y_center_phys - y_square_center_phys
        im_ref_slice_iz.hdr.as_analyze_map()['qoffset_z'] = z_center_phys
        im_ref_slice_iz.hdr.set_sform(im_ref_slice_iz.hdr.get_qform())
        im_ref_slice_iz.hdr.set_qform(im_ref_slice_iz.hdr.get_qform())

        # interpolate input image to reference image
        im_input_interpolate_iz = im_input.interpolate_from_image(im_ref_slice_iz, interpolation_mode=interpolation_mode, border='nearest')
        # reshape data to 2D if needed
        if len(im_input_interpolate_iz.data.shape) == 3:
            im_input_interpolate_iz.data = im_input_interpolate_iz.data.reshape(im_input_interpolate_iz.data.shape[:-1])
        # add slice to list
        list_interpolate_images.append(im_input_interpolate_iz)

    return list_interpolate_images
Exemple #29
0
def register_landmarks(fname_src,
                       fname_dest,
                       dof,
                       fname_affine='affine.txt',
                       verbose=1,
                       path_qc='./'):
    """
    Register two NIFTI volumes containing landmarks
    Landmarks of both volumes are retrieved with getCoordinatesAveragedByValue (sorted by value) and
    paired in that order.
    :param fname_src: fname of source landmarks
    :param fname_dest: fname of destination landmarks
    :param dof: degree of freedom. Separate with "_". Example: Tx_Ty_Tz_Rx_Ry_Sz
    :param fname_affine: output affine transformation, written as an ITK "Insight Transform File V1.0"
        text file (AffineTransform_double_3_3)
    :param verbose: 0, 1, 2
    :param path_qc: forwarded to getRigidTransformFromLandmarks
    :return: None
    :raises Exception: if src and dest do not contain the same number of landmarks
    """
    from msct_image import Image

    def _to_itk_world(im, coords):
        # convert NIFTI to ITK world coordinate: x and y are negated
        itk_points = []
        for coord in coords:
            phys = im.transfo_pix2phys([[coord.x, coord.y, coord.z]])[0]
            itk_points.append([-phys[0], -phys[1], phys[2]])
        return itk_points

    # open src label
    im_src = Image(fname_src)
    coord_src = im_src.getCoordinatesAveragedByValue()  # landmarks are sorted by value
    # open dest labels
    im_dest = Image(fname_dest)
    coord_dest = im_dest.getCoordinatesAveragedByValue()

    # Reorganize landmarks
    points_src = _to_itk_world(im_src, coord_src)
    points_dest = _to_itk_world(im_dest, coord_dest)

    # display
    sct.printv('Labels src: ' + str(points_src), verbose)
    sct.printv('Labels dest: ' + str(points_dest), verbose)
    sct.printv('Degrees of freedom (dof): ' + dof, verbose)

    if len(coord_src) != len(coord_dest):
        raise Exception(
            'Error: number of source and destination landmarks are not the same, so landmarks cannot be paired.'
        )

    # estimate transformation
    # N.B. points_src and points_dest are inverted below, because ITK uses inverted transformation matrices, i.e., src->dest is defined in dest instead of src.
    # NOTE(review): the call passes (src, dest) in that order; the inversion presumably refers to the
    # parameter order of getRigidTransformFromLandmarks — confirm against its signature.
    (rotation_matrix, translation_array, _,
     points_moving_barycenter) = getRigidTransformFromLandmarks(
         points_src,
         points_dest,
         constraints=dof,
         verbose=verbose,
         path_qc=path_qc)

    # writing rigid transformation file
    # N.B. x and y dimensions have a negative sign to ensure compatibility between Python and ITK transfo
    # (the sign flip is applied to the landmark coordinates above)
    # `with` guarantees the file is closed even if a write fails
    with open(fname_affine, 'w') as text_file:
        text_file.write("#Insight Transform File V1.0\n")
        text_file.write("#Transform 0\n")
        text_file.write("Transform: AffineTransform_double_3_3\n")
        text_file.write(
            "Parameters: %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f\n"
            % (rotation_matrix[0, 0], rotation_matrix[0, 1], rotation_matrix[0, 2],
               rotation_matrix[1, 0], rotation_matrix[1, 1], rotation_matrix[1, 2],
               rotation_matrix[2, 0], rotation_matrix[2, 1], rotation_matrix[2, 2],
               translation_array[0, 0], translation_array[0, 1],
               translation_array[0, 2]))
        text_file.write("FixedParameters: %.9f %.9f %.9f\n" %
                        (points_moving_barycenter[0], points_moving_barycenter[1],
                         points_moving_barycenter[2]))
def register_images(
    im_input,
    im_dest,
    mask="",
    paramreg=None,
    ants_registration_params=None,
    remove_tmp_folder=1,
):
    """Slice-by-slice registration of two images.

    Both 3D images (and the mask, if provided) are copied into a temporary folder and split into 2D slices along z.
    Each destination slice is then registered with antsRegistration (in 2D) to the input slice that physically
    corresponds to it; the correspondence is obtained from the z-offset between the physical origins of the two
    images. The images can be of different sizes but the destination image must be smaller than the input image.
    Once every slice has been processed, the results are gathered and returned (or written, for non-linear algos).
    Algorithms implemented: translation, rigid, affine, syn and BsplineSyn.
    N.B.: If the mask is inputted, it must also be 3D and it must be in the same space as the destination image.

    input:
        im_input: name of moving image (type: string)
        im_dest: name of fixed image (type: string)
        mask[optional]: name of mask file (type: string) (parameter -x of antsRegistration)
        paramreg[optional]: parameters of antsRegistration (type: Paramreg class from sct_register_multimodal).
            None (default) means Paramreg(step="0", type="im", algo="Translation", metric="MI", iter="5",
            shrink="1", smooth="0", gradStep="0.5").
        ants_registration_params[optional]: algorithm-specific suffix appended to the --transform argument of
            antsRegistration, keyed by lower-case algorithm name (type: dictionary). None uses the built-in defaults.
        remove_tmp_folder[optional]: if true, delete the temporary folder after a successful run.

    output:
        if algo==translation:
            x_displacement: list of translation along x axis for each slice (type: list)
            y_displacement: list of translation along y axis for each slice (type: list)
        if algo==rigid:
            x_displacement: list of translation along x axis for each slice (type: list)
            y_displacement: list of translation along y axis for each slice (type: list)
            theta_rotation: list of rotation angle in radian (and in ITK's coordinate system) for each slice (type: list)
        if algo==affine or algo==syn or algo==bsplinesyn:
            creation of two 3D warping fields (forward and inverse) that are the concatenations of the slice-by-slice
            warps, copied as Warp_total.nii.gz and Warp_total_inverse.nii.gz into the current folder (returns None).

    If antsRegistration fails on a slice, the transformation of the previous slice is reused. For the warp-based
    algorithms a failure on the very first slice raises RuntimeError (there is nothing to fall back on).
    """
    # Build defaults per call (a default evaluated at def time would be shared between calls)
    if paramreg is None:
        paramreg = Paramreg(
            step="0", type="im", algo="Translation", metric="MI", iter="5", shrink="1", smooth="0", gradStep="0.5"
        )
    if ants_registration_params is None:
        ants_registration_params = {
            "rigid": "",
            "affine": "",
            "compositeaffine": "",
            "similarity": "",
            "translation": "",
            "bspline": ",10",
            "gaussiandisplacementfield": ",3,0",
            "bsplinedisplacementfield": ",5,10",
            "syn": ",3,0",
            "bsplinesyn": ",1,3",
        }

    # Names of the local copies of the inputs (relative to the temporary folder)
    root_i, ext_i = sct.extract_fname(im_input)[1:]
    root_d, ext_d = sct.extract_fname(im_dest)[1:]
    fname_input_local = root_i + ext_i
    fname_dest_local = root_d + ext_d

    is_linear_algo = paramreg.algo in ("Rigid", "Translation")
    is_warp_algo = paramreg.algo in ("BSplineSyN", "SyN", "Affine")

    # set metricSize
    if paramreg.metric == "MI":
        metricSize = "32"  # corresponds to number of bins
    else:
        metricSize = "4"  # corresponds to radius (for CC, MeanSquares...)

    # Get image dimensions and retrieve nz
    print("\nGet image dimensions of destination image...")
    nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(im_dest)
    print(".. matrix size: " + str(nx) + " x " + str(ny) + " x " + str(nz))
    print(".. voxel size:  " + str(px) + "mm x " + str(py) + "mm x " + str(pz) + "mm")

    # Per-slice results of the linear algorithms
    x_displacement = [0] * nz
    y_displacement = [0] * nz
    theta_rotation = [0] * nz

    # create temporary folder
    print("\nCreate temporary folder...")
    path_tmp = "tmp." + time.strftime("%y%m%d%H%M%S")
    sct.create_folder(path_tmp)
    print("\nCopy input data...")
    sct.run("cp " + im_input + " " + path_tmp + "/" + fname_input_local)
    sct.run("cp " + im_dest + " " + path_tmp + "/" + fname_dest_local)
    if mask:
        sct.run("cp " + mask + " " + path_tmp + "/mask.nii.gz")

    # x/y components of the per-slice warping fields, collected for the final merge along z
    list_warp_x = []
    list_warp_x_inv = []
    list_warp_y = []
    list_warp_y_inv = []
    # if modified, name should also be modified in msct_register (algo slicereg2d_bsplinesyn and slicereg2d_syn)
    name_warp_final = "Warp_total"

    def _record_warp_components(num):
        """Register the x/y component files of slice `num` for the merge along z."""
        list_warp_x.append("transform_" + num + "0Warp_x.nii.gz")
        list_warp_x_inv.append("transform_" + num + "0InverseWarp_x.nii.gz")
        list_warp_y.append("transform_" + num + "0Warp_y.nii.gz")
        list_warp_y_inv.append("transform_" + num + "0InverseWarp_y.nii.gz")

    def _split_warp_components(num, fname_warp, fname_warp_inv):
        """Split the forward/inverse 2D warps of slice `num` into x and y components, then record them.

        The x and y displacements are merged separately because merging 3D warping fields directly does not work
        properly.
        """
        sct.run(
            "isct_c3d -mcs " + fname_warp
            + " -oo transform_" + num + "0Warp_x.nii.gz transform_" + num + "0Warp_y.nii.gz"
        )
        sct.run(
            "isct_c3d -mcs " + fname_warp_inv
            + " -oo transform_" + num + "0InverseWarp_x.nii.gz transform_" + num + "0InverseWarp_y.nii.gz"
        )
        _record_warp_components(num)

    # go to temporary folder; always come back, even if a step fails
    path_cwd = os.getcwd()
    os.chdir(path_tmp)
    try:
        # Split the local copies (the caller's paths may be relative to the original folder)
        print("\nSplit input volume...")
        sct.run(sct.fsloutput + "fslsplit " + fname_input_local + " " + root_i + "_z -z")
        print("\nSplit destination volume...")
        sct.run(sct.fsloutput + "fslsplit " + fname_dest_local + " " + root_d + "_z -z")
        if mask:
            print("\nSplit mask volume...")
            sct.run(sct.fsloutput + "fslsplit mask.nii.gz mask_z -z")

        # Offset, in slices, between the physical origins of the two images along z. Rounded (not truncated)
        # so that floating-point noise such as 2.9999 does not select the wrong input slice.
        origin_dest = Image(fname_dest_local).transfo_pix2phys([[0, 0, 0]])[0]
        origin_input = Image(fname_input_local).transfo_pix2phys([[0, 0, 0]])[0]
        z_offset = int(round((origin_dest[2] - origin_input[2]) * 1.0 / pz))

        # loop across slices
        for i in range(nz):
            num = numerotation(i)
            num_2 = numerotation(int(num) + z_offset)  # matching slice of the input image
            masking = "-x mask_z" + num + ".nii" if mask else ""
            fname_moving_reg = root_i + "_z" + num_2 + "reg.nii"  # registered moving slice written by ANTs

            cmd = (
                "isct_antsRegistration --dimensionality 2 "
                + "--transform " + paramreg.algo + "[" + str(paramreg.gradStep)
                + ants_registration_params[paramreg.algo.lower()] + "] "
                # [fixedImage,movingImage,metricWeight,nb_of_bins (MI) or radius (other)]
                + "--metric " + paramreg.metric + "[" + root_d + "_z" + num + ".nii,"
                + root_i + "_z" + num_2 + ".nii,1," + metricSize + "] "
                + "--convergence " + str(paramreg.iter) + " "
                + "--shrink-factors " + str(paramreg.shrink) + " "
                + "--smoothing-sigmas " + str(paramreg.smooth) + "mm "
                # linear algos write transform_<num>0GenericAffine.mat (contains Tx, Ty, theta)
                + "--output [transform_" + num + "," + fname_moving_reg + "] "
                + "--interpolation BSpline[3] " + masking
            )

            try:
                sct.run(cmd)

                if is_linear_algo:
                    matfile = loadmat("transform_" + num + "0GenericAffine.mat", struct_as_record=True)
                    array_transfo = matfile["AffineTransform_double_2_2"]
                    x_displacement[i] = array_transfo[4][0]  # Tx in ITK'S coordinate system
                    y_displacement[i] = array_transfo[5][0]  # Ty  in ITK'S and fslview's coordinate systems
                    # angle of rotation theta in ITK'S coordinate system (minus theta for fslview)
                    theta_rotation[i] = asin(array_transfo[2][0])

                elif paramreg.algo == "Affine":
                    # Convert the .mat transform into equivalent nifti warping fields by composing it with null warps
                    name_dest = root_d + "_z" + num + ".nii"
                    name_output_warp = "warp_from_mat_" + num + ".nii.gz"
                    name_output_warp_inverse = "warp_from_mat_" + num + "_inverse.nii.gz"
                    name_warp_null = "warp_null_" + num + ".nii.gz"
                    name_warp_null_dest = "warp_null_dest" + num + ".nii.gz"
                    name_warp_mat = "transform_" + num + "0GenericAffine.mat"
                    # Null warping fields on the grids of the registered and destination slices
                    # (distinct names: do not clobber the volume dimensions or the loop index)
                    nz_reg = sct.get_dimension(fname_moving_reg)[2]
                    nz_dest_slice = sct.get_dimension(name_dest)[2]
                    generate_warping_field(
                        fname_moving_reg, x_trans=[0] * nz_reg, y_trans=[0] * nz_reg, fname=name_warp_null, verbose=0
                    )
                    generate_warping_field(
                        name_dest,
                        x_trans=[0] * nz_dest_slice,
                        y_trans=[0] * nz_dest_slice,
                        fname=name_warp_null_dest,
                        verbose=0,
                    )
                    sct.run(
                        "isct_ComposeMultiTransform 2 " + name_output_warp
                        + " -R " + fname_moving_reg + " " + name_warp_null + " " + name_warp_mat
                    )
                    sct.run(
                        "isct_ComposeMultiTransform 2 " + name_output_warp_inverse
                        + " -R " + name_dest + " " + name_warp_null_dest + " -i " + name_warp_mat
                    )
                    _split_warp_components(num, name_output_warp, name_output_warp_inverse)

                elif is_warp_algo:  # SyN / BSplineSyN write nifti warps directly
                    _split_warp_components(
                        num, "transform_" + num + "0Warp.nii.gz", "transform_" + num + "0InverseWarp.nii.gz"
                    )

            # If ANTs (or a post-processing step) fails, reuse the transformation of the previous slice.
            # sct.run() may terminate through sys.exit(), hence SystemExit is caught on purpose; a bare
            # except would also swallow KeyboardInterrupt.
            except (Exception, SystemExit):
                if is_linear_algo:
                    print("Problem with ants for slice " + str(i) + ". Copy of the last transformation.")
                    if i > 0:  # slice 0 keeps the identity (zero) transformation
                        x_displacement[i] = x_displacement[i - 1]
                        y_displacement[i] = y_displacement[i - 1]
                        theta_rotation[i] = theta_rotation[i - 1]

                elif is_warp_algo:
                    print("Problem with ants for slice " + str(i) + ". Copy of the last warping field.")
                    if i == 0:
                        raise RuntimeError(
                            "antsRegistration failed on the first slice: no previous warping field to copy."
                        )
                    # The split x/y components exist for the previous slice whatever the algorithm
                    # (Affine never produces a transform_<n>0Warp.nii.gz file).
                    num_prev = numerotation(i - 1)
                    for component in ("0Warp_x", "0Warp_y", "0InverseWarp_x", "0InverseWarp_y"):
                        sct.run(
                            "cp transform_" + num_prev + component + ".nii.gz transform_" + num + component + ".nii.gz"
                        )
                    _record_warp_components(num)

        if is_warp_algo:
            print("\nMerge along z of the warping fields...")
            sct.run("fslmerge -z " + name_warp_final + "_x " + " ".join(list_warp_x))
            sct.run("fslmerge -z " + name_warp_final + "_x_inverse " + " ".join(list_warp_x_inv))
            sct.run("fslmerge -z " + name_warp_final + "_y " + " ".join(list_warp_y))
            sct.run("fslmerge -z " + name_warp_final + "_y_inverse " + " ".join(list_warp_y_inv))
            print("\nChange resolution of warping fields to match the resolution of the destination image...")
            sct.run("fslcpgeom " + fname_dest_local + " " + name_warp_final + "_x.nii.gz")
            sct.run("fslcpgeom " + fname_input_local + " " + name_warp_final + "_x_inverse.nii.gz")
            sct.run("fslcpgeom " + fname_dest_local + " " + name_warp_final + "_y.nii.gz")
            sct.run("fslcpgeom " + fname_input_local + " " + name_warp_final + "_y_inverse.nii.gz")
            print("\nMerge translation fields along x and y into one global warping field ")
            sct.run(
                "isct_c3d " + name_warp_final + "_x.nii.gz " + name_warp_final + "_y.nii.gz -omc 2 "
                + name_warp_final + ".nii.gz"
            )
            sct.run(
                "isct_c3d " + name_warp_final + "_x_inverse.nii.gz " + name_warp_final + "_y_inverse.nii.gz -omc 2 "
                + name_warp_final + "_inverse.nii.gz"
            )
            print("\nCopy to parent folder...")
            sct.run("cp " + name_warp_final + ".nii.gz ../")
            sct.run("cp " + name_warp_final + "_inverse.nii.gz ../")
    finally:
        os.chdir(path_cwd)

    # Delete tmp folder (kept on failure, for debugging)
    if remove_tmp_folder:
        print("\nRemove temporary files...")
        sct.run("rm -rf " + path_tmp)
    if paramreg.algo == "Rigid":
        return x_displacement, y_displacement, theta_rotation
    if paramreg.algo == "Translation":
        return x_displacement, y_displacement
Exemple #31
0
def register2d_centermassrot(fname_src,
                             fname_dest,
                             fname_warp='warp_forward.nii.gz',
                             fname_warp_inv='warp_inverse.nii.gz',
                             rot=1,
                             poly=0,
                             path_qc='./',
                             verbose=0,
                             pca_eigenratio_th=1.6):
    """
    Rotate the source image to match the orientation of the destination image, using the first and second eigenvector
    of the PCA. This function should be used on segmentations (not images).
    This works for 2D and 3D images.  If 3D, it splits the image and performs the rotation slice-by-slice: slice iz of
    the source is matched with slice iz of the destination, so both images are expected to have the same number of
    slices along z. Slices where the PCA cannot be computed (compute_pca raises ValueError) are ignored.
    input:
        fname_src: name of moving image (type: string)
        fname_dest: name of fixed image (type: string)
        fname_warp: name of output 3d forward warping field (defined in the destination space)
        fname_warp_inv: name of output 3d inverse warping field (defined in the source space)
        rot: estimate rotation with PCA (type: int). When not 1, only the centers of mass are aligned.
        poly: degree of polynomial regularization along z for rotation angle (type: int). 0: no regularization
        path_qc: prefix (folder) of the QC figures written when verbose == 2
        verbose: verbosity level; 2 additionally saves QC figures under path_qc
        pca_eigenratio_th: if, in either image, the ratio between the first two PCA explained-variance ratios is
            below this threshold, the orientation is considered unreliable and the angle of that slice is set to 0
    output:
        none (the warping fields are written to fname_warp and fname_warp_inv; the 2D slices of both volumes are
        also saved to disk)
    """
    import sys
    from sct_image import split_data

    if verbose == 2:
        import matplotlib
        matplotlib.use('Agg')  # prevent display figure
        import matplotlib.pyplot as plt

    def _slice_grid_phys(image, shape_2d, iz):
        """Physical coordinates (n x 3 array, row-major over shape_2d) of every voxel of slice iz of `image`."""
        rows, cols = np.indices(shape_2d)
        coord_pix = np.column_stack((rows.ravel(), cols.ravel(), np.ones(rows.size) * iz))
        return np.array(image.transfo_pix2phys(coord_pix))

    def _save_pca_figure(iz, R):
        """Save a QC figure of the src/dest PCA point clouds of slice iz, before and after rotation by R."""
        coord_src_rot = np.dot(coord_src[iz], R)
        coord_dest_rot = np.dot(coord_dest[iz], R.T)
        # (subplot, points, color, PCA whose axes are drawn, title)
        panels = (
            (221, coord_src[iz], 'steelblue', pca_src[iz], 'src'),
            (222, coord_src_rot, 'steelblue', pca_dest[iz], 'src_rot'),
            (223, coord_dest[iz], 'red', pca_dest[iz], 'dest'),
            (224, coord_dest_rot, 'red', pca_src[iz], 'dest_rot'),
        )
        plt.figure('iz=' + str(iz) + ', angle_src_dest=' + str(angle_src_dest[iz]), figsize=(9, 9))
        for isub, points, color, pca, title in panels:
            plt.subplot(isub)
            plt.scatter(points[:, 0], points[:, 1], s=5, marker='o', zorder=10, color=color, alpha=0.5)
            pcaaxis = pca.components_.T
            plt.title(title)
            plt.text(-2.5, -2, 'eigenvectors:', horizontalalignment='left', verticalalignment='bottom')
            plt.text(-2.5, -2.8, str(pcaaxis), horizontalalignment='left', verticalalignment='bottom')
            plt.text(-2.5, 2.5, 'eigenval_ratio:', horizontalalignment='left', verticalalignment='bottom')
            plt.text(-2.5, 2, str(pca.explained_variance_ratio_), horizontalalignment='left',
                     verticalalignment='bottom')
            plt.plot([0, pcaaxis[0, 0]], [0, pcaaxis[1, 0]], linewidth=2, color='red')
            plt.plot([0, pcaaxis[0, 1]], [0, pcaaxis[1, 1]], linewidth=2, color='orange')
            plt.axis([-3, 3, -3, 3])
            plt.gca().set_aspect('equal', adjustable='box')
        plt.savefig(path_qc + 'register2d_centermassrot_pca_z' + str(iz) + '.png')
        plt.close()

    # Load the images given as arguments (not hard-coded file names)
    im_src = Image(fname_src)
    im_dest = Image(fname_dest)

    # Get image dimensions and retrieve nz
    sct.printv('\nGet image dimensions of destination image...', verbose)
    nx, ny, nz, nt, px, py, pz, pt = im_dest.dim
    sct.printv('  matrix size: ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), verbose)
    sct.printv('  voxel size:  ' + str(px) + 'mm x ' + str(py) + 'mm x ' + str(pz) + 'mm', verbose)

    # Split source and destination volumes along z (each slice is saved to disk)
    sct.printv('\nSplit input volume...', verbose)
    for im in split_data(im_src, 2):
        im.save()
    sct.printv('\nSplit destination volume...', verbose)
    for im in split_data(im_dest, 2):
        im.save()

    # reshape 2D data into pseudo 3D (only one slice), each with its own shape
    data_src = im_src.data
    data_dest = im_dest.data
    if data_src.ndim == 2:
        data_src = data_src.reshape(data_src.shape + (1,))
    if data_dest.ndim == 2:
        data_dest = data_dest.reshape(data_dest.shape + (1,))

    # initialize per-slice PCA results, centers of mass (pixel space) and rotation angles
    coord_src = [None] * nz
    pca_src = [None] * nz
    coord_dest = [None] * nz
    pca_dest = [None] * nz
    centermass_src = np.zeros([nz, 2])
    centermass_dest = np.zeros([nz, 2])
    angle_src_dest = np.zeros(nz)
    z_nonzero = []
    for iz in range(nz):
        try:
            # compute PCA and get center of mass
            coord_src[iz], pca_src[iz], centermass_src[iz, :] = compute_pca(data_src[:, :, iz])
            coord_dest[iz], pca_dest[iz], centermass_dest[iz, :] = compute_pca(data_dest[:, :, iz])
            if rot == 1:
                # (src,dest) angle between the first eigenvectors
                eigenv_src = pca_src[iz].components_.T[0][0], pca_src[iz].components_.T[1][0]
                eigenv_dest = pca_dest[iz].components_.T[0][0], pca_dest[iz].components_.T[1][0]
                angle_src_dest[iz] = angle_between(eigenv_src, eigenv_dest)
                # an elongation ratio that is too low gives an unreliable orientation: do not rotate
                ratio_src = pca_src[iz].explained_variance_ratio_
                ratio_dest = pca_dest[iz].explained_variance_ratio_
                if (ratio_src[0] / ratio_src[1] < pca_eigenratio_th
                        or ratio_dest[0] / ratio_dest[1] < pca_eigenratio_th):
                    angle_src_dest[iz] = 0
            z_nonzero.append(iz)
        # if one of the slice is empty, ignore it
        except ValueError:
            sct.printv('WARNING: Slice #' + str(iz) + ' is empty. It will be ignored.', verbose, 'warning')

    # regularize rotation (skipped when no slice could be processed)
    if poly != 0 and rot == 1 and z_nonzero:
        from msct_smooth import polynomial_fit
        angle_src_dest_regularized = polynomial_fit(z_nonzero, angle_src_dest[z_nonzero], poly)[0]
        if verbose == 2:
            plt.plot(180 * angle_src_dest[z_nonzero] / np.pi)
            plt.plot(180 * angle_src_dest_regularized / np.pi, 'r', linewidth=2)
            plt.grid()
            plt.xlabel('z')
            plt.ylabel('Angle (deg)')
            plt.savefig(path_qc + 'register2d_centermassrot_regularize_rotation.png')
            plt.close()
        angle_src_dest[z_nonzero] = angle_src_dest_regularized

    # N.B. forward transfo is defined in destination space and inverse transfo is defined in the source space
    warp_x = np.zeros(data_dest.shape)
    warp_y = np.zeros(data_dest.shape)
    warp_inv_x = np.zeros(data_src.shape)
    warp_inv_y = np.zeros(data_src.shape)
    shape_dest_2d = (nx, ny)
    shape_src_2d = data_src.shape[:2]

    for iz in z_nonzero:
        sys.stdout.write(str(iz) + '/' + str(nz) + '.. ')
        # physical coordinates of the voxels of slice iz, each on its own image grid
        coord_dest_grid_phy = _slice_grid_phys(im_dest, shape_dest_2d, iz)
        coord_src_grid_phy = _slice_grid_phys(im_src, shape_src_2d, iz)
        # centers of mass in physical space, each converted with the image it was measured in
        centermass_src_phy = np.asarray(
            im_src.transfo_pix2phys([[centermass_src[iz, 0], centermass_src[iz, 1], iz]])[0])
        centermass_dest_phy = np.asarray(
            im_dest.transfo_pix2phys([[centermass_dest[iz, 0], centermass_dest[iz, 1], iz]])[0])
        # in-plane rotation matrix, embedded in 3D (z left unchanged)
        cos_a = cos(angle_src_dest[iz])
        sin_a = sin(angle_src_dest[iz])
        R = np.array([[cos_a, sin_a], [-sin_a, cos_a]])
        R3d = np.eye(3)
        R3d[0:2, 0:2] = R
        # forward: dest point -> rotated about the dest center of mass, then moved onto the src center of mass
        coord_forward_phy = np.dot(coord_dest_grid_phy - centermass_dest_phy, R3d) + centermass_src_phy
        # inverse: src point -> rotated back about the src center of mass, moved onto the dest center of mass
        coord_inverse_phy = np.dot(coord_src_grid_phy - centermass_src_phy, R3d.T) + centermass_dest_phy

        if verbose == 2 and angle_src_dest[iz] != 0:
            _save_pca_figure(iz, R)

        # displacement = transformed position - original position, reshaped back onto the slice grid
        displacement = coord_forward_phy - coord_dest_grid_phy
        warp_x[:, :, iz] = displacement[:, 0].reshape(shape_dest_2d)
        warp_y[:, :, iz] = displacement[:, 1].reshape(shape_dest_2d)
        displacement_inv = coord_inverse_phy - coord_src_grid_phy
        warp_inv_x[:, :, iz] = displacement_inv[:, 0].reshape(shape_src_2d)
        warp_inv_y[:, :, iz] = displacement_inv[:, 1].reshape(shape_src_2d)

    # Generate forward warping field (defined in destination space) and inverse (defined in source space)
    generate_warping_field(fname_dest, warp_x, warp_y, fname_warp, verbose)
    generate_warping_field(fname_src, warp_inv_x, warp_inv_y, fname_warp_inv, verbose)
Exemple #32
0
def register2d_centermassrot(fname_src, fname_dest, fname_warp='warp_forward.nii.gz', fname_warp_inv='warp_inverse.nii.gz', rot=1, poly=0, path_qc='./', verbose=0, pca_eigenratio_th=1.6):
    """
    Rotate the source image to match the orientation of the destination image, using the first and second eigenvector
    of the PCA. This function should be used on segmentations (not images).
    This works for 2D and 3D images.  If 3D, it splits the image and performs the rotation slice-by-slice.
    input:
        fname_source: name of moving image (type: string)
        fname_dest: name of fixed image (type: string)
        fname_warp: name of output 3d forward warping field
        fname_warp_inv: name of output 3d inverse warping field
        rot: estimate rotation with PCA (type: int)
        poly: degree of polynomial regularization along z for rotation angle (type: int). 0: no regularization
        verbose:
    output:
        none
    """

    if verbose == 2:
        import matplotlib
        matplotlib.use('Agg')  # prevent display figure
        import matplotlib.pyplot as plt

    # Get image dimensions and retrieve nz
    sct.printv('\nGet image dimensions of destination image...', verbose)
    nx, ny, nz, nt, px, py, pz, pt = Image(fname_dest).dim
    sct.printv('  matrix size: '+str(nx)+' x '+str(ny)+' x '+str(nz), verbose)
    sct.printv('  voxel size:  '+str(px)+'mm x '+str(py)+'mm x '+str(pz)+'mm', verbose)

    # Split source volume along z
    sct.printv('\nSplit input volume...', verbose)
    from sct_image import split_data
    im_src = Image('src.nii')
    split_source_list = split_data(im_src, 2)
    for im in split_source_list:
        im.save()

    # Split destination volume along z
    sct.printv('\nSplit destination volume...', verbose)
    im_dest = Image('dest.nii')
    split_dest_list = split_data(im_dest, 2)
    for im in split_dest_list:
        im.save()

    # display image
    data_src = im_src.data
    data_dest = im_dest.data

    if len(data_src.shape) == 2:
        # reshape 2D data into pseudo 3D (only one slice)
        new_shape = list(data_src.shape)
        new_shape.append(1)
        new_shape = tuple(new_shape)
        data_src = data_src.reshape(new_shape)
        data_dest = data_dest.reshape(new_shape)

    # initialize displacement and rotation
    coord_src = [None] * nz
    pca_src = [None] * nz
    coord_dest = [None] * nz
    pca_dest = [None] * nz
    centermass_src = np.zeros([nz, 2])
    centermass_dest = np.zeros([nz, 2])
    # displacement_forward = np.zeros([nz, 2])
    # displacement_inverse = np.zeros([nz, 2])
    angle_src_dest = np.zeros(nz)
    z_nonzero = []
    # Loop across slices
    for iz in range(0, nz):
        try:
            # compute PCA and get center or mass
            coord_src[iz], pca_src[iz], centermass_src[iz, :] = compute_pca(data_src[:, :, iz])
            coord_dest[iz], pca_dest[iz], centermass_dest[iz, :] = compute_pca(data_dest[:, :, iz])
            # compute (src,dest) angle for first eigenvector
            if rot == 1:
                eigenv_src = pca_src[iz].components_.T[0][0], pca_src[iz].components_.T[1][0]  # pca_src.components_.T[0]
                eigenv_dest = pca_dest[iz].components_.T[0][0], pca_dest[iz].components_.T[1][0]  # pca_dest.components_.T[0]
                angle_src_dest[iz] = angle_between(eigenv_src, eigenv_dest)
                # check if ratio between the two eigenvectors is high enough to prevent poor robustness
                if pca_src[iz].explained_variance_ratio_[0] / pca_src[iz].explained_variance_ratio_[1] < pca_eigenratio_th:
                    angle_src_dest[iz] = 0
                if pca_dest[iz].explained_variance_ratio_[0] / pca_dest[iz].explained_variance_ratio_[1] < pca_eigenratio_th:
                    angle_src_dest[iz] = 0
            # append to list of z_nonzero
            z_nonzero.append(iz)
        # if one of the slice is empty, ignore it
        except ValueError:
            sct.printv('WARNING: Slice #' + str(iz) + ' is empty. It will be ignored.', verbose, 'warning')

    # regularize rotation
    if not poly == 0 and rot == 1:
        from msct_smooth import polynomial_fit
        angle_src_dest_regularized = polynomial_fit(z_nonzero, angle_src_dest[z_nonzero], poly)[0]
        # display
        if verbose == 2:
            plt.plot(180 * angle_src_dest[z_nonzero] / np.pi)
            plt.plot(180 * angle_src_dest_regularized / np.pi, 'r', linewidth=2)
            plt.grid()
            plt.xlabel('z')
            plt.ylabel('Angle (deg)')
            plt.savefig(path_qc+'register2d_centermassrot_regularize_rotation.png')
            plt.close()
        # update variable
        angle_src_dest[z_nonzero] = angle_src_dest_regularized

    # initialize warping fields
    # N.B. forward transfo is defined in destination space and inverse transfo is defined in the source space
    warp_x = np.zeros(data_dest.shape)
    warp_y = np.zeros(data_dest.shape)
    warp_inv_x = np.zeros(data_src.shape)
    warp_inv_y = np.zeros(data_src.shape)

    # construct 3D warping matrix
    for iz in z_nonzero:
        print str(iz)+'/'+str(nz)+'..',
        # get indices of x and y coordinates
        row, col = np.indices((nx, ny))
        # build 2xn array of coordinates in pixel space
        coord_init_pix = np.array([row.ravel(), col.ravel(), np.array(np.ones(len(row.ravel())) * iz)]).T
        # convert coordinates to physical space
        coord_init_phy = np.array(im_src.transfo_pix2phys(coord_init_pix))
        # get centermass coordinates in physical space
        centermass_src_phy = im_src.transfo_pix2phys([[centermass_src[iz, :].T[0], centermass_src[iz, :].T[1], iz]])[0]
        centermass_dest_phy = im_src.transfo_pix2phys([[centermass_dest[iz, :].T[0], centermass_dest[iz, :].T[1], iz]])[0]
        # build rotation matrix
        R = np.matrix(((cos(angle_src_dest[iz]), sin(angle_src_dest[iz])), (-sin(angle_src_dest[iz]), cos(angle_src_dest[iz]))))
        # build 3D rotation matrix
        R3d = np.eye(3)
        R3d[0:2, 0:2] = R
        # apply forward transformation (in physical space)
        coord_forward_phy = np.array(np.dot((coord_init_phy - np.transpose(centermass_dest_phy)), R3d) + np.transpose(centermass_src_phy))
        # apply inverse transformation (in physical space)
        coord_inverse_phy = np.array(np.dot((coord_init_phy - np.transpose(centermass_src_phy)), R3d.T) + np.transpose(centermass_dest_phy))
        # display rotations
        if verbose == 2 and not angle_src_dest[iz] == 0:
            # compute new coordinates
            coord_src_rot = coord_src[iz] * R
            coord_dest_rot = coord_dest[iz] * R.T
            # generate figure
            plt.figure('iz=' + str(iz) + ', angle_src_dest=' + str(angle_src_dest[iz]), figsize=(9, 9))
            # plt.ion()  # enables interactive mode (allows keyboard interruption)
            # plt.title('iz='+str(iz))
            for isub in [221, 222, 223, 224]:
                # plt.figure
                plt.subplot(isub)
                # ax = matplotlib.pyplot.axis()
                if isub == 221:
                    plt.scatter(coord_src[iz][:, 0], coord_src[iz][:, 1], s=5, marker='o', zorder=10, color='steelblue',
                                alpha=0.5)
                    pcaaxis = pca_src[iz].components_.T
                    pca_eigenratio = pca_src[iz].explained_variance_ratio_
                    plt.title('src')
                elif isub == 222:
                    plt.scatter(coord_src_rot[:, 0], coord_src_rot[:, 1], s=5, marker='o', zorder=10,
                                color='steelblue',
                                alpha=0.5)
                    pcaaxis = pca_dest[iz].components_.T
                    pca_eigenratio = pca_dest[iz].explained_variance_ratio_
                    plt.title('src_rot')
                elif isub == 223:
                    plt.scatter(coord_dest[iz][:, 0], coord_dest[iz][:, 1], s=5, marker='o', zorder=10, color='red',
                                alpha=0.5)
                    pcaaxis = pca_dest[iz].components_.T
                    pca_eigenratio = pca_dest[iz].explained_variance_ratio_
                    plt.title('dest')
                elif isub == 224:
                    plt.scatter(coord_dest_rot[:, 0], coord_dest_rot[:, 1], s=5, marker='o', zorder=10, color='red',
                                alpha=0.5)
                    pcaaxis = pca_src[iz].components_.T
                    pca_eigenratio = pca_src[iz].explained_variance_ratio_
                    plt.title('dest_rot')
                plt.text(-2.5, -2, 'eigenvectors:', horizontalalignment='left', verticalalignment='bottom')
                plt.text(-2.5, -2.8, str(pcaaxis), horizontalalignment='left', verticalalignment='bottom')
                plt.text(-2.5, 2.5, 'eigenval_ratio:', horizontalalignment='left', verticalalignment='bottom')
                plt.text(-2.5, 2, str(pca_eigenratio), horizontalalignment='left', verticalalignment='bottom')
                plt.plot([0, pcaaxis[0, 0]], [0, pcaaxis[1, 0]], linewidth=2, color='red')
                plt.plot([0, pcaaxis[0, 1]], [0, pcaaxis[1, 1]], linewidth=2, color='orange')
                plt.axis([-3, 3, -3, 3])
                plt.gca().set_aspect('equal', adjustable='box')
                # plt.axis('equal')
            plt.savefig(path_qc + 'register2d_centermassrot_pca_z' + str(iz) + '.png')
            plt.close()

        # construct 3D warping matrix
        warp_x[:, :, iz] = np.array([coord_forward_phy[i, 0] - coord_init_phy[i, 0] for i in xrange(nx * ny)]).reshape((nx, ny))
        warp_y[:, :, iz] = np.array([coord_forward_phy[i, 1] - coord_init_phy[i, 1] for i in xrange(nx * ny)]).reshape((nx, ny))
        warp_inv_x[:, :, iz] = np.array([coord_inverse_phy[i, 0] - coord_init_phy[i, 0] for i in xrange(nx * ny)]).reshape((nx, ny))
        warp_inv_y[:, :, iz] = np.array([coord_inverse_phy[i, 1] - coord_init_phy[i, 1] for i in xrange(nx * ny)]).reshape((nx, ny))

    # Generate forward warping field (defined in destination space)
    generate_warping_field(fname_dest, warp_x, warp_y, fname_warp, verbose)
    generate_warping_field(fname_src, warp_inv_x, warp_inv_y, fname_warp_inv, verbose)
Exemple #33
0
def main():

    # get default parameters
    step1 = Paramreg(step='1',
                     type='seg',
                     algo='slicereg',
                     metric='MeanSquares',
                     iter='10')
    step2 = Paramreg(step='2', type='im', algo='syn', metric='MI', iter='3')
    # step1 = Paramreg()
    paramreg = ParamregMultiStep([step1, step2])

    # step1 = Paramreg_step(step='1', type='seg', algo='bsplinesyn', metric='MeanSquares', iter='10', shrink='1', smooth='0', gradStep='0.5')
    # step2 = Paramreg_step(step='2', type='im', algo='syn', metric='MI', iter='10', shrink='1', smooth='0', gradStep='0.5')
    # paramreg = ParamregMultiStep([step1, step2])

    # Initialize the parser
    parser = Parser(__file__)
    parser.usage.set_description('Register anatomical image to the template.')
    parser.add_option(name="-i",
                      type_value="file",
                      description="Anatomical image.",
                      mandatory=True,
                      example="anat.nii.gz")
    parser.add_option(name="-s",
                      type_value="file",
                      description="Spinal cord segmentation.",
                      mandatory=True,
                      example="anat_seg.nii.gz")
    parser.add_option(
        name="-l",
        type_value="file",
        description=
        "Labels. See: http://sourceforge.net/p/spinalcordtoolbox/wiki/create_labels/",
        mandatory=True,
        default_value='',
        example="anat_labels.nii.gz")
    parser.add_option(name="-t",
                      type_value="folder",
                      description="Path to MNI-Poly-AMU template.",
                      mandatory=False,
                      default_value=param.path_template)
    parser.add_option(
        name="-p",
        type_value=[[':'], 'str'],
        description=
        """Parameters for registration (see sct_register_multimodal). Default:\n--\nstep=1\ntype="""
        + paramreg.steps['1'].type + """\nalgo=""" + paramreg.steps['1'].algo +
        """\nmetric=""" + paramreg.steps['1'].metric + """\npoly=""" +
        paramreg.steps['1'].poly + """\n--\nstep=2\ntype=""" +
        paramreg.steps['2'].type + """\nalgo=""" + paramreg.steps['2'].algo +
        """\nmetric=""" + paramreg.steps['2'].metric + """\niter=""" +
        paramreg.steps['2'].iter + """\nshrink=""" +
        paramreg.steps['2'].shrink + """\nsmooth=""" +
        paramreg.steps['2'].smooth + """\ngradStep=""" +
        paramreg.steps['2'].gradStep + """\n--""",
        mandatory=False,
        example=
        "step=2,type=seg,algo=bsplinesyn,metric=MeanSquares,iter=5,shrink=2:step=3,type=im,algo=syn,metric=MI,iter=5,shrink=1,gradStep=0.3"
    )
    parser.add_option(name="-r",
                      type_value="multiple_choice",
                      description="""Remove temporary files.""",
                      mandatory=False,
                      default_value='1',
                      example=['0', '1'])
    parser.add_option(
        name="-v",
        type_value="multiple_choice",
        description="""Verbose. 0: nothing. 1: basic. 2: extended.""",
        mandatory=False,
        default_value=param.verbose,
        example=['0', '1', '2'])
    if param.debug:
        print '\n*** WARNING: DEBUG MODE ON ***\n'
        fname_data = '/Users/julien/data/temp/sct_example_data/t2/t2.nii.gz'
        fname_landmarks = '/Users/julien/data/temp/sct_example_data/t2/labels.nii.gz'
        fname_seg = '/Users/julien/data/temp/sct_example_data/t2/t2_seg.nii.gz'
        path_template = param.path_template
        remove_temp_files = 0
        verbose = 2
        # speed = 'superfast'
        #param_reg = '2,BSplineSyN,0.6,MeanSquares'
    else:
        arguments = parser.parse(sys.argv[1:])

        # get arguments
        fname_data = arguments['-i']
        fname_seg = arguments['-s']
        fname_landmarks = arguments['-l']
        path_template = arguments['-t']
        remove_temp_files = int(arguments['-r'])
        verbose = int(arguments['-v'])
        if '-p' in arguments:
            paramreg_user = arguments['-p']
            # update registration parameters
            for paramStep in paramreg_user:
                paramreg.addStep(paramStep)

    # initialize other parameters
    file_template = param.file_template
    file_template_label = param.file_template_label
    file_template_seg = param.file_template_seg
    output_type = param.output_type
    zsubsample = param.zsubsample
    # smoothing_sigma = param.smoothing_sigma

    # start timer
    start_time = time.time()

    # get absolute path - TO DO: remove! NEVER USE ABSOLUTE PATH...
    path_template = os.path.abspath(path_template)

    # get fname of the template + template objects
    fname_template = sct.slash_at_the_end(path_template, 1) + file_template
    fname_template_label = sct.slash_at_the_end(path_template,
                                                1) + file_template_label
    fname_template_seg = sct.slash_at_the_end(path_template,
                                              1) + file_template_seg

    # check file existence
    sct.printv('\nCheck template files...')
    sct.check_file_exist(fname_template, verbose)
    sct.check_file_exist(fname_template_label, verbose)
    sct.check_file_exist(fname_template_seg, verbose)

    # print arguments
    sct.printv('\nCheck parameters:', verbose)
    sct.printv('.. Data:                 ' + fname_data, verbose)
    sct.printv('.. Landmarks:            ' + fname_landmarks, verbose)
    sct.printv('.. Segmentation:         ' + fname_seg, verbose)
    sct.printv('.. Path template:        ' + path_template, verbose)
    sct.printv('.. Output type:          ' + str(output_type), verbose)
    sct.printv('.. Remove temp files:    ' + str(remove_temp_files), verbose)

    sct.printv('\nParameters for registration:')
    for pStep in range(1, len(paramreg.steps) + 1):
        sct.printv('Step #' + paramreg.steps[str(pStep)].step, verbose)
        sct.printv('.. Type #' + paramreg.steps[str(pStep)].type, verbose)
        sct.printv(
            '.. Algorithm................ ' + paramreg.steps[str(pStep)].algo,
            verbose)
        sct.printv(
            '.. Metric................... ' +
            paramreg.steps[str(pStep)].metric, verbose)
        sct.printv(
            '.. Number of iterations..... ' + paramreg.steps[str(pStep)].iter,
            verbose)
        sct.printv(
            '.. Shrink factor............ ' +
            paramreg.steps[str(pStep)].shrink, verbose)
        sct.printv(
            '.. Smoothing factor......... ' +
            paramreg.steps[str(pStep)].smooth, verbose)
        sct.printv(
            '.. Gradient step............ ' +
            paramreg.steps[str(pStep)].gradStep, verbose)
        sct.printv(
            '.. Degree of polynomial..... ' + paramreg.steps[str(pStep)].poly,
            verbose)

    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    sct.printv('\nCheck input labels...')
    # check if label image contains coherent labels
    image_label = Image(fname_landmarks)
    # -> all labels must be different
    labels = image_label.getNonZeroCoordinates(sorting='value')
    hasDifferentLabels = True
    for lab in labels:
        for otherlabel in labels:
            if lab != otherlabel and lab.hasEqualValue(otherlabel):
                hasDifferentLabels = False
                break
    if not hasDifferentLabels:
        sct.printv(
            'ERROR: Wrong landmarks input. All labels must be different.',
            verbose, 'error')
    # all labels must be available in tempalte
    image_label_template = Image(fname_template_label)
    labels_template = image_label_template.getNonZeroCoordinates(
        sorting='value')
    if labels[-1].value > labels_template[-1].value:
        sct.printv(
            'ERROR: Wrong landmarks input. Labels must have correspondance in tempalte space. \nLabel max '
            'provided: ' + str(labels[-1].value) +
            '\nLabel max from template: ' + str(labels_template[-1].value),
            verbose, 'error')

    # create temporary folder
    sct.printv('\nCreate temporary folder...', verbose)
    path_tmp = 'tmp.' + time.strftime("%y%m%d%H%M%S")
    status, output = sct.run('mkdir ' + path_tmp)

    # copy files to temporary folder
    sct.printv('\nCopy files...', verbose)
    sct.run('isct_c3d ' + fname_data + ' -o ' + path_tmp + '/data.nii')
    sct.run('isct_c3d ' + fname_landmarks + ' -o ' + path_tmp +
            '/landmarks.nii.gz')
    sct.run('isct_c3d ' + fname_seg + ' -o ' + path_tmp +
            '/segmentation.nii.gz')
    sct.run('isct_c3d ' + fname_template + ' -o ' + path_tmp + '/template.nii')
    sct.run('isct_c3d ' + fname_template_label + ' -o ' + path_tmp +
            '/template_labels.nii.gz')
    sct.run('isct_c3d ' + fname_template_seg + ' -o ' + path_tmp +
            '/template_seg.nii.gz')

    # go to tmp folder
    os.chdir(path_tmp)

    # resample data to 1mm isotropic
    sct.printv('\nResample data to 1mm isotropic...', verbose)
    sct.run(
        'isct_c3d data.nii -resample-mm 1.0x1.0x1.0mm -interpolation Linear -o datar.nii'
    )
    sct.run(
        'isct_c3d segmentation.nii.gz -resample-mm 1.0x1.0x1.0mm -interpolation NearestNeighbor -o segmentationr.nii.gz'
    )
    # N.B. resampling of labels is more complicated, because they are single-point labels, therefore resampling with neighrest neighbour can make them disappear. Therefore a more clever approach is required.
    resample_labels('landmarks.nii.gz', 'datar.nii', 'landmarksr.nii.gz')
    # # TODO
    # sct.run('sct_label_utils -i datar.nii -t create -x 124,186,19,2:129,98,23,8 -o landmarksr.nii.gz')

    # Change orientation of input images to RPI
    sct.printv('\nChange orientation of input images to RPI...', verbose)
    set_orientation('datar.nii', 'RPI', 'data_rpi.nii')
    set_orientation('landmarksr.nii.gz', 'RPI', 'landmarks_rpi.nii.gz')
    set_orientation('segmentationr.nii.gz', 'RPI', 'segmentation_rpi.nii.gz')

    # # Change orientation of input images to RPI
    # sct.printv('\nChange orientation of input images to RPI...', verbose)
    # set_orientation('data.nii', 'RPI', 'data_rpi.nii')
    # set_orientation('landmarks.nii.gz', 'RPI', 'landmarks_rpi.nii.gz')
    # set_orientation('segmentation.nii.gz', 'RPI', 'segmentation_rpi.nii.gz')

    # get landmarks in native space
    # crop segmentation
    # output: segmentation_rpi_crop.nii.gz
    sct.run(
        'sct_crop_image -i segmentation_rpi.nii.gz -o segmentation_rpi_crop.nii.gz -dim 2 -bzmax'
    )

    # straighten segmentation
    sct.printv('\nStraighten the spinal cord using centerline/segmentation...',
               verbose)
    sct.run(
        'sct_straighten_spinalcord -i segmentation_rpi_crop.nii.gz -c segmentation_rpi_crop.nii.gz -r 0 -v '
        + str(verbose), verbose)
    # re-define warping field using non-cropped space (to avoid issue #367)
    sct.run(
        'sct_concat_transfo -w warp_straight2curve.nii.gz -d data_rpi.nii -o warp_straight2curve.nii.gz'
    )

    # Label preparation:
    # --------------------------------------------------------------------------------
    # Remove unused label on template. Keep only label present in the input label image
    sct.printv(
        '\nRemove unused label on template. Keep only label present in the input label image...',
        verbose)
    sct.run(
        'sct_label_utils -t remove -i template_labels.nii.gz -o template_label.nii.gz -r landmarks_rpi.nii.gz'
    )

    # Make sure landmarks are INT
    sct.printv('\nConvert landmarks to INT...', verbose)
    sct.run(
        'isct_c3d template_label.nii.gz -type int -o template_label.nii.gz',
        verbose)

    # Create a cross for the template labels - 5 mm
    sct.printv('\nCreate a 5 mm cross for the template labels...', verbose)
    sct.run(
        'sct_label_utils -t cross -i template_label.nii.gz -o template_label_cross.nii.gz -c 5'
    )

    # Create a cross for the input labels and dilate for straightening preparation - 5 mm
    sct.printv(
        '\nCreate a 5mm cross for the input labels and dilate for straightening preparation...',
        verbose)
    sct.run(
        'sct_label_utils -t cross -i landmarks_rpi.nii.gz -o landmarks_rpi_cross3x3.nii.gz -c 5 -d'
    )

    # Apply straightening to labels
    sct.printv('\nApply straightening to labels...', verbose)
    sct.run(
        'sct_apply_transfo -i landmarks_rpi_cross3x3.nii.gz -o landmarks_rpi_cross3x3_straight.nii.gz -d segmentation_rpi_crop_straight.nii.gz -w warp_curve2straight.nii.gz -x nn'
    )

    # Convert landmarks from FLOAT32 to INT
    sct.printv('\nConvert landmarks from FLOAT32 to INT...', verbose)
    sct.run(
        'isct_c3d landmarks_rpi_cross3x3_straight.nii.gz -type int -o landmarks_rpi_cross3x3_straight.nii.gz'
    )

    # Remove labels that do not correspond with each others.
    sct.printv('\nRemove labels that do not correspond with each others.',
               verbose)
    sct.run(
        'sct_label_utils -t remove-symm -i landmarks_rpi_cross3x3_straight.nii.gz -o landmarks_rpi_cross3x3_straight.nii.gz,template_label_cross.nii.gz -r template_label_cross.nii.gz'
    )

    # Estimate affine transfo: straight --> template (landmark-based)'
    sct.printv(
        '\nEstimate affine transfo: straight anat --> template (landmark-based)...',
        verbose)
    # converting landmarks straight and curved to physical coordinates
    image_straight = Image('landmarks_rpi_cross3x3_straight.nii.gz')
    landmark_straight = image_straight.getNonZeroCoordinates(sorting='value')
    image_template = Image('template_label_cross.nii.gz')
    landmark_template = image_template.getNonZeroCoordinates(sorting='value')
    # Reorganize landmarks
    points_fixed, points_moving = [], []
    landmark_straight_mean = []
    for coord in landmark_straight:
        if coord.value not in [c.value for c in landmark_straight_mean]:
            temp_landmark = coord
            temp_number = 1
            for other_coord in landmark_straight:
                if coord.hasEqualValue(other_coord) and coord != other_coord:
                    temp_landmark += other_coord
                    temp_number += 1
            landmark_straight_mean.append(temp_landmark / temp_number)

    for coord in landmark_straight_mean:
        point_straight = image_straight.transfo_pix2phys(
            [[coord.x, coord.y, coord.z]])
        points_moving.append(
            [point_straight[0][0], point_straight[0][1], point_straight[0][2]])
    for coord in landmark_template:
        point_template = image_template.transfo_pix2phys(
            [[coord.x, coord.y, coord.z]])
        points_fixed.append(
            [point_template[0][0], point_template[0][1], point_template[0][2]])

    # Register curved landmarks on straight landmarks based on python implementation
    sct.printv(
        '\nComputing rigid transformation (algo=translation-scaling-z) ...',
        verbose)
    import msct_register_landmarks
    (rotation_matrix, translation_array, points_moving_reg, points_moving_barycenter) = \
        msct_register_landmarks.getRigidTransformFromLandmarks(
            points_fixed, points_moving, constraints='translation-scaling-z', show=False)

    # writing rigid transformation file
    text_file = open("straight2templateAffine.txt", "w")
    text_file.write("#Insight Transform File V1.0\n")
    text_file.write("#Transform 0\n")
    text_file.write(
        "Transform: FixedCenterOfRotationAffineTransform_double_3_3\n")
    text_file.write(
        "Parameters: %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f\n"
        % (1.0 / rotation_matrix[0, 0], rotation_matrix[0, 1],
           rotation_matrix[0, 2], rotation_matrix[1, 0],
           1.0 / rotation_matrix[1, 1], rotation_matrix[1, 2],
           rotation_matrix[2, 0], rotation_matrix[2, 1],
           1.0 / rotation_matrix[2, 2], translation_array[0, 0],
           translation_array[0, 1], -translation_array[0, 2]))
    text_file.write("FixedParameters: %.9f %.9f %.9f\n" %
                    (points_moving_barycenter[0], points_moving_barycenter[1],
                     points_moving_barycenter[2]))
    text_file.close()

    # Apply affine transformation: straight --> template
    sct.printv('\nApply affine transformation: straight --> template...',
               verbose)
    sct.run(
        'sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt -d template.nii -o warp_curve2straightAffine.nii.gz'
    )
    sct.run(
        'sct_apply_transfo -i data_rpi.nii -o data_rpi_straight2templateAffine.nii -d template.nii -w warp_curve2straightAffine.nii.gz'
    )
    sct.run(
        'sct_apply_transfo -i segmentation_rpi.nii.gz -o segmentation_rpi_straight2templateAffine.nii.gz -d template.nii -w warp_curve2straightAffine.nii.gz -x linear'
    )

    # threshold to 0.5
    nii = Image('segmentation_rpi_straight2templateAffine.nii.gz')
    data = nii.data
    data[data < 0.5] = 0
    nii.data = data
    nii.setFileName('segmentation_rpi_straight2templateAffine_th.nii.gz')
    nii.save()
    # find min-max of anat2template (for subsequent cropping)
    zmin_template, zmax_template = find_zmin_zmax(
        'segmentation_rpi_straight2templateAffine_th.nii.gz')

    # crop template in z-direction (for faster processing)
    sct.printv('\nCrop data in template space (for faster processing)...',
               verbose)
    sct.run(
        'sct_crop_image -i template.nii -o template_crop.nii -dim 2 -start ' +
        str(zmin_template) + ' -end ' + str(zmax_template))
    sct.run(
        'sct_crop_image -i template_seg.nii.gz -o template_seg_crop.nii.gz -dim 2 -start '
        + str(zmin_template) + ' -end ' + str(zmax_template))
    sct.run(
        'sct_crop_image -i data_rpi_straight2templateAffine.nii -o data_rpi_straight2templateAffine_crop.nii -dim 2 -start '
        + str(zmin_template) + ' -end ' + str(zmax_template))
    sct.run(
        'sct_crop_image -i segmentation_rpi_straight2templateAffine.nii.gz -o segmentation_rpi_straight2templateAffine_crop.nii.gz -dim 2 -start '
        + str(zmin_template) + ' -end ' + str(zmax_template))
    # sub-sample in z-direction
    sct.printv('\nSub-sample in z-direction (for faster processing)...',
               verbose)
    sct.run(
        'sct_resample -i template_crop.nii -o template_crop_r.nii -f 1x1x' +
        zsubsample, verbose)
    sct.run(
        'sct_resample -i template_seg_crop.nii.gz -o template_seg_crop_r.nii.gz -f 1x1x'
        + zsubsample, verbose)
    sct.run(
        'sct_resample -i data_rpi_straight2templateAffine_crop.nii -o data_rpi_straight2templateAffine_crop_r.nii -f 1x1x'
        + zsubsample, verbose)
    sct.run(
        'sct_resample -i segmentation_rpi_straight2templateAffine_crop.nii.gz -o segmentation_rpi_straight2templateAffine_crop_r.nii.gz -f 1x1x'
        + zsubsample, verbose)

    # Registration straight spinal cord to template
    sct.printv('\nRegister straight spinal cord to template...', verbose)

    # loop across registration steps
    warp_forward = []
    warp_inverse = []
    for i_step in range(1, len(paramreg.steps) + 1):
        sct.printv(
            '\nEstimate transformation for step #' + str(i_step) + '...',
            verbose)
        # identify which is the src and dest
        if paramreg.steps[str(i_step)].type == 'im':
            src = 'data_rpi_straight2templateAffine_crop_r.nii'
            dest = 'template_crop_r.nii'
            interp_step = 'linear'
        elif paramreg.steps[str(i_step)].type == 'seg':
            src = 'segmentation_rpi_straight2templateAffine_crop_r.nii.gz'
            dest = 'template_seg_crop_r.nii.gz'
            interp_step = 'nn'
        else:
            sct.printv('ERROR: Wrong image type.', 1, 'error')
        # if step>1, apply warp_forward_concat to the src image to be used
        if i_step > 1:
            # sct.run('sct_apply_transfo -i '+src+' -d '+dest+' -w '+','.join(warp_forward)+' -o '+sct.add_suffix(src, '_reg')+' -x '+interp_step, verbose)
            sct.run(
                'sct_apply_transfo -i ' + src + ' -d ' + dest + ' -w ' +
                ','.join(warp_forward) + ' -o ' + sct.add_suffix(src, '_reg') +
                ' -x ' + interp_step, verbose)
            src = sct.add_suffix(src, '_reg')
        # register src --> dest
        warp_forward_out, warp_inverse_out = register(src, dest, paramreg,
                                                      param, str(i_step))
        warp_forward.append(warp_forward_out)
        warp_inverse.append(warp_inverse_out)

    # Concatenate transformations:
    sct.printv('\nConcatenate transformations: anat --> template...', verbose)
    sct.run(
        'sct_concat_transfo -w warp_curve2straightAffine.nii.gz,' +
        ','.join(warp_forward) +
        ' -d template.nii -o warp_anat2template.nii.gz', verbose)
    # sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
    warp_inverse.reverse()
    sct.run(
        'sct_concat_transfo -w ' + ','.join(warp_inverse) +
        ',-straight2templateAffine.txt,warp_straight2curve.nii.gz -d data.nii -o warp_template2anat.nii.gz',
        verbose)

    # Apply warping fields to anat and template
    if output_type == 1:
        sct.run(
            'sct_apply_transfo -i template.nii -o template2anat.nii.gz -d data.nii -w warp_template2anat.nii.gz -c 1',
            verbose)
        sct.run(
            'sct_apply_transfo -i data.nii -o anat2template.nii.gz -d template.nii -w warp_anat2template.nii.gz -c 1',
            verbose)

    # come back to parent folder
    os.chdir('..')

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(path_tmp + '/warp_template2anat.nii.gz',
                             'warp_template2anat.nii.gz', verbose)
    sct.generate_output_file(path_tmp + '/warp_anat2template.nii.gz',
                             'warp_anat2template.nii.gz', verbose)
    if output_type == 1:
        sct.generate_output_file(path_tmp + '/template2anat.nii.gz',
                                 'template2anat' + ext_data, verbose)
        sct.generate_output_file(path_tmp + '/anat2template.nii.gz',
                                 'anat2template' + ext_data, verbose)

    # Delete temporary files
    if remove_temp_files:
        sct.printv('\nDelete temporary files...', verbose)
        sct.run('rm -rf ' + path_tmp)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv(
        '\nFinished! Elapsed time: ' + str(int(round(elapsed_time))) + 's',
        verbose)

    # to view results
    sct.printv('\nTo view results, type:', verbose)
    sct.printv('fslview ' + fname_data + ' template2anat -b 0,4000 &', verbose,
               'info')
    sct.printv('fslview ' + fname_template + ' -b 0,5000 anat2template &\n',
               verbose, 'info')