예제 #1
0
 def getNonZeroCoordinates(self,
                           sorting=None,
                           reverse_coord=False,
                           coordValue=False):
     """
     Build the list of coordinates of all strictly positive voxels of the image (self.data > 0).

     Each entry is a Coordinate created from [x, y, z, value]. For a 2D image (len(self.dim) == 2)
     z is set to 0 so that the voxel value is stored as the value and not as the z coordinate.

     Any exception raised while collecting the voxels is reported through printv with the 'error' type.

     NOTE(review): in the portion of the function visible here the list is built but neither sorted
     nor returned, and sorting, reverse_coord and coordValue are not used yet. The parameters are
     meant to select the sorting key ('x', 'y', 'z' or 'value'), the sorting direction and the
     coordinate class -- confirm against the complete implementation.
     """
     from msct_types import Coordinate
     from sct_utils import printv
     try:
         if len(self.dim) == 3:
             X, Y, Z = (self.data > 0).nonzero()
             list_coordinates = [
                 Coordinate([x, y, z, self.data[x, y, z]])
                 for x, y, z in zip(X, Y, Z)
             ]
         elif len(self.dim) == 2:
             X, Y = (self.data > 0).nonzero()
             # a three-element list would be read as [x, y, z]: pad z with 0
             list_coordinates = [
                 Coordinate([x, y, 0, self.data[x, y]])
                 for x, y in zip(X, Y)
             ]
     except Exception as e:
         printv(
             'ERROR: Exception ' + str(e) +
             ' caught while geting non Zeros coordinates', 1, 'error')
예제 #2
0
    def getNonZeroCoordinates(self, sorting=None, reverse_coord=False, coordValue=False):
        """
        Return the coordinates of all strictly positive voxels of the image (self.data > 0).

        The dimensionality is deduced from self.dim: 2D when self.dim[2] == 1, otherwise 3D when
        self.dim[3] == 1. For 2D images the z coordinate of every entry is 0, whether the underlying
        array is 2D or 3D with a singleton third axis.

        :param sorting: None (no sorting) or one of 'x', 'y', 'z', 'value': attribute used as sorting key
        :param reverse_coord: if True, coordinates are sorted from larger to smaller
        :param coordValue: if True, entries are CoordinateValue objects instead of Coordinate objects
        :return: list of coordinate objects built from [x, y, z, value]
        :raises ValueError: if the image is 4D, if reverse_coord is not a boolean or if sorting is not
                            one of the accepted keys
        """
        n_dim = 3 if self.dim[3] == 1 else 4
        if self.dim[2] == 1:
            n_dim = 2
        if n_dim == 4:
            # previously this case fell through and failed later on an unbound variable
            raise ValueError('getNonZeroCoordinates only supports 2D and 3D images')

        try:
            indices = (self.data > 0).nonzero()
            if n_dim == 3:
                X, Y, Z = indices
                points = [[x, y, z, self.data[x, y, z]] for x, y, z in zip(X, Y, Z)]
            elif len(indices) == 2:
                X, Y = indices
                points = [[x, y, 0, self.data[x, y]] for x, y in zip(X, Y)]
            else:
                # 2D image stored in a 3D array with a singleton third axis
                X, Y, Z = indices
                points = [[x, y, 0, self.data[x, y, 0]] for x, y in zip(X, Y)]
        except Exception as e:
            sct.printv('ERROR: Exception ' + str(e) + ' caught while geting non Zeros coordinates', 1, 'error')
            # nothing usable was collected: surface the real cause instead of a later NameError
            raise

        # the values were gathered once above, so both coordinate classes receive scalar values
        if coordValue:
            from msct_types import CoordinateValue
            coord_class = CoordinateValue
        else:
            coord_class = Coordinate
        list_coordinates = [coord_class(point) for point in points]

        if sorting is not None:
            if reverse_coord not in [True, False]:
                raise ValueError('reverse_coord parameter must be a boolean')
            if sorting not in ('x', 'y', 'z', 'value'):
                raise ValueError("sorting parameter must be either 'x', 'y', 'z' or 'value'")
            list_coordinates = sorted(list_coordinates, key=lambda obj: getattr(obj, sorting), reverse=reverse_coord)

        return list_coordinates
예제 #3
0
    def getNonZeroCoordinates(self,
                              sorting=None,
                              reverse_coord=False,
                              coordValue=False):
        """
        Collect every strictly positive voxel (self.data > 0.0) of a 3D image as a coordinate object.

        :param sorting: None to keep the order given by nonzero(), or 'x', 'y', 'z' or 'value' to sort
                        the list on that attribute
        :param reverse_coord: if True, coordinates are sorted from larger to smaller
        :param coordValue: if True, entries are CoordinateValue objects, otherwise Coordinate objects
        :return: list of coordinate objects, each built from [x, y, z, value]
        :raises ValueError: when sorting is requested with a non-boolean reverse_coord or an unknown key
        """
        xs, ys, zs = (self.data > 0.0).nonzero()

        # only the class that is actually needed gets imported
        if coordValue:
            from msct_types import CoordinateValue as coord_class
        else:
            from msct_types import Coordinate as coord_class

        list_coordinates = [
            coord_class([px, py, pz, self.data[px, py, pz]])
            for px, py, pz in zip(xs, ys, zs)
        ]

        if sorting is None:
            return list_coordinates

        if reverse_coord not in [True, False]:
            raise ValueError('reverse_coord parameter must be a boolean')
        if sorting not in ('x', 'y', 'z', 'value'):
            raise ValueError(
                "sorting parameter must be either 'x', 'y', 'z' or 'value'")

        # the sorting key is the attribute carrying the same name as the keyword
        return sorted(list_coordinates,
                      key=lambda obj: getattr(obj, sorting),
                      reverse=reverse_coord)
 def create_label_along_segmentation(self):
     """
     Create an image with labels placed along the spinal cord segmentation (or centerline).

     The input image does **not** need to be RPI: it is re-oriented to RPI here and the result is
     put back in the orientation of the input image before being returned.

     Each entry of self.coordinates is a string 'z,value'. The label `value` is written at the
     center of mass of the segmentation in slice z. If z=-1, the slice nz/2 is used instead
     (i.e. center of FOV in superior-inferior direction).

     Example:
     object_define=ProcessLabels(fname_segmentation, coordinates=[coord_1, coord_2, coord_i]), where coord_i='z,value'.

     :return: the label image, in the orientation of the input image
     """
     # work on an RPI copy of the input and on an empty image of the same geometry
     im_rpi = self.image_input.copy().change_orientation('RPI')
     im_output_rpi = zeros_like(im_rpi)

     for ilabel, label_spec in enumerate(self.coordinates):
         # 'z,value' --> two integers
         z, value = [int(token) for token in label_spec.split(',')]

         # z is given in the native orientation and must be expressed in RPI. Since we don't know
         # which native axis is the superior-inferior one, z is put in all three dimensions and
         # only the third permuted component is kept (x and y do not matter here).
         probe = Coordinate([z, z, z])
         _, _, z_rpi = probe.permute(self.image_input, 'RPI')

         # z=-1 is a request for the middle slice
         if z == -1:
             z_rpi = int(np.round(im_output_rpi.dim[2] / 2.0))

         # center of mass of the segmentation in that slice, rounded to voxel indices
         com_x, com_y = ndimage.measurements.center_of_mass(np.array(im_rpi.data[:, :, z_rpi]))
         x = int(np.round(com_x))
         y = int(np.round(com_y))

         # display info
         sct.printv('Label #%s: %s,%s,%s --> %s' % (ilabel, x, y, z_rpi, value), 1)

         n_axes = len(im_output_rpi.data.shape)
         if n_axes == 3:
             im_output_rpi.data[x, y, z_rpi] = value
         elif n_axes == 2:
             assert str(z) == '0', "ERROR: 2D coordinates should have a Z value of 0. Z coordinate is :" + str(z)
             im_output_rpi.data[x, y] = value

     # change orientation back to native
     return im_output_rpi.change_orientation(self.image_input.orientation)
예제 #5
0
def add_label(brainstem_file, segmented_file, output_file_name, label_depth_compared_to_zmax=10 , label_value=1):
    """
    Add one label to the brainstem label image and save the result.

    The label is placed at the center of mass of the segmentation in the slice located
    label_depth_compared_to_zmax slices below the Zmax reported by ComputeZMinMax for the
    segmented file.

    :param brainstem_file: file name of the image holding the existing labels
    :param segmented_file: file name of the segmentation used to position the new label
    :param output_file_name: file name under which the resulting label image is saved
    :param label_depth_compared_to_zmax: number of slices between Zmax and the slice of the new label
    :param label_value: value given to the new label
    :raises ValueError: if the segmentation has no positive voxel in the target slice
    """
    # Zmax of the segmented file
    image_seg = Image(segmented_file)
    z_test = ComputeZMinMax(image_seg)
    zmax = z_test.Zmax
    print("Zmax: ", zmax)

    # sum of the label image before the addition (sanity check)
    brainstem_image = Image(brainstem_file)
    print("nb_label_before=", np.sum(brainstem_image.data))

    # center of mass of the segmentation in the target slice
    z_label = zmax - label_depth_compared_to_zmax
    X, Y = np.nonzero(z_test.image.data[:, :, z_label] > 0)
    if X.shape[0] == 0:
        # used to fail with an uninformative ZeroDivisionError
        raise ValueError("No segmented voxel found in slice " + str(z_label) +
                         ": cannot compute the position of the new label.")
    # float mean: the result no longer depends on the integer division rules of the interpreter
    x_bar = int(round(float(np.mean(X))))
    y_bar = int(round(float(np.mean(Y))))

    # place the new label at coordinates x_bar, y_bar and zmax-label_depth_compared_to_zmax
    coordi = Coordinate([x_bar, y_bar, z_label, label_value])
    object_for_process = ProcessLabels(brainstem_file, coordinates=[coordi])
    file_with_new_label = object_for_process.create_label()

    # output image = existing labels + new label
    im_output = object_for_process.image_input.copy()
    im_output.data = brainstem_image.data + file_with_new_label.data

    # sum of the label image after the addition (sanity check)
    print("nb_label_after=", np.sum(im_output.data))

    # save output file
    im_output.setFileName(output_file_name)
    im_output.save('minimize')
예제 #6
0
class Image(object):
    """
    Image held in memory: a voxel array (data) with its header (hdr), orientation, dimensions (dim)
    and file location (absolutepath, split into path, file_name and ext).
    """
    def __init__(self,
                 param=None,
                 hdr=None,
                 orientation=None,
                 absolutepath="",
                 verbose=1):
        """
        Build an image from one of the supported sources.

        :param param: str: path of the image file to load
                      Image: image to copy
                      list or tuple: dimensions of a new image full of zeros
                      numpy array or numpy scalar: data used as is (not copied)
        :param hdr: header of the image; an empty AnalyzeHeader is created when not provided
        :param orientation: orientation of the image (used with dimensions or array input only)
        :param absolutepath: file name associated with the image (used with dimensions or array input only)
        :param verbose: verbosity level
        :raises TypeError: if param is of none of the supported types
        """
        from numpy import zeros, ndarray, generic
        from sct_utils import extract_fname
        from nibabel import AnalyzeHeader

        # initialization of all parameters
        self.verbose = verbose
        self.data = None
        self.absolutepath = ""
        self.path = ""
        self.file_name = ""
        self.ext = ""
        self.orientation = None
        self.dim = None
        if hdr is None:
            hdr = AnalyzeHeader()  # an empty header
        self.hdr = hdr

        # load an image from file
        if isinstance(param, str):
            self.loadFromPath(param, verbose)
        # copy constructor
        elif isinstance(param, type(self)):
            self.copy(param)
        # create an empty image (full of zero) of dimension [dim]. dim must be [x,y,z] or (x,y,z).
        elif isinstance(param, (list, tuple)):
            self.data = zeros(param)
            self.dim = param
            self.orientation = orientation
            self.absolutepath = absolutepath
            self.path, self.file_name, self.ext = extract_fname(absolutepath)
        # wrap an existing array
        elif isinstance(param, (ndarray, generic)):
            self.data = param
            self.dim = self.data.shape
            self.orientation = orientation
            self.absolutepath = absolutepath
            self.path, self.file_name, self.ext = extract_fname(absolutepath)
        else:
            raise TypeError(' Image constructor takes at least one argument.')

    def __deepcopy__(self, memo):
        """Return a new image built from deep copies of data, hdr, orientation and absolutepath."""
        from copy import deepcopy
        return type(self)(deepcopy(self.data, memo), deepcopy(self.hdr, memo),
                          deepcopy(self.orientation, memo),
                          deepcopy(self.absolutepath, memo))

    def copy(self, image=None):
        """
        Without argument, return a deep copy of this image.
        With an image, copy its content into this image (in place) and return None.

        :param image: image to copy into self, or None to get a copy of self
        :return: a new image when image is None, None otherwise
        """
        from copy import deepcopy
        from sct_utils import extract_fname
        if image is None:
            return deepcopy(self)

        self.data = deepcopy(image.data)
        self.dim = deepcopy(image.dim)
        self.hdr = deepcopy(image.hdr)
        self.orientation = deepcopy(image.orientation)
        self.absolutepath = deepcopy(image.absolutepath)
        self.path, self.file_name, self.ext = extract_fname(
            self.absolutepath)

    def loadFromPath(self, path, verbose):
        """
        Load an image from an absolute path using the nibabel library.

        :param path: path of the file from which the image will be loaded
        :param verbose: verbosity level passed to the file existence check
        :return:
        """
        from nibabel import load, spatialimages
        from sct_utils import check_file_exist, printv, extract_fname, get_dimension
        from sct_orientation import get_orientation

        check_file_exist(path, verbose=verbose)
        try:
            im_file = load(path)
        except spatialimages.ImageFileError:
            printv('Error: make sure ' + path + ' is an image.', 1, 'error')
            # nothing was loaded: do not carry on with an undefined image
            raise
        self.orientation = get_orientation(path)
        self.data = im_file.get_data()
        self.hdr = im_file.get_header()
        self.absolutepath = path
        self.path, self.file_name, self.ext = extract_fname(path)
        nx, ny, nz, nt, px, py, pz, pt = get_dimension(path)
        # only the three spatial dimensions are kept
        self.dim = [nx, ny, nz]

    def setFileName(self, filename):
        """Set the file name of the image and update path, file_name and ext accordingly."""
        from sct_utils import extract_fname
        self.absolutepath = filename
        self.path, self.file_name, self.ext = extract_fname(filename)

    def changeType(self, type=''):
        """
        Change the voxel type of the image, in both the data array and the header.

        :param type:    if not set, the image is converted to the type declared in its header
                        if 'minimize', image space is minimize: the smallest type able to hold all
                        voxel values is used (an integer type when every voxel is a whole number,
                        float32 or float64 otherwise)
                        if 'minimize_int', image space is minimize and values are approximated to integers
                        otherwise, any type name or dtype understood by numpy.dtype, e.g.:
                        (2, 'uint8', np.uint8, "NIFTI_TYPE_UINT8"),
                        (4, 'int16', np.int16, "NIFTI_TYPE_INT16"),
                        (8, 'int32', np.int32, "NIFTI_TYPE_INT32"),
                        (16, 'float32', np.float32, "NIFTI_TYPE_FLOAT32"),
                        (32, 'complex64', np.complex64, "NIFTI_TYPE_COMPLEX64"),
                        (64, 'float64', np.float64, "NIFTI_TYPE_FLOAT64"),
                        (256, 'int8', np.int8, "NIFTI_TYPE_INT8"),
                        (512, 'uint16', np.uint16, "NIFTI_TYPE_UINT16"),
                        (768, 'uint32', np.uint32, "NIFTI_TYPE_UINT32"),
                        (1024,'int64', np.int64, "NIFTI_TYPE_INT64"),
                        (1280, 'uint64', np.uint64, "NIFTI_TYPE_UINT64"),
                        (1536, 'float128', _float128t, "NIFTI_TYPE_FLOAT128"),
                        (1792, 'complex128', np.complex128, "NIFTI_TYPE_COMPLEX128"),
                        (2048, 'complex256', _complex256t, "NIFTI_TYPE_COMPLEX256"),
        :return:
        :raises ValueError: if 'minimize' or 'minimize_int' is requested and no candidate type can
                            hold the voxel values
        """
        from numpy import array, asarray, dtype, finfo, floor, iinfo, nanmax, nanmin

        if type == '':
            type = self.hdr.get_data_dtype()
        elif type == 'minimize' or type == 'minimize_int':
            # compute extreme values in the image and choose the best pixel type to represent all the pixels within smallest memory space
            # warning: does not take intensity resolution into account, neither complex voxels
            voxels = asarray(self.data)
            max_vox = nanmax(voxels)
            min_vox = nanmin(voxels)

            # 'minimize_int' forces an integer type; 'minimize' only uses one when every voxel is
            # a whole number (a NaN voxel counts as not whole, which leads to a float type)
            if type == 'minimize':
                is_integer = bool((voxels == floor(voxels)).all())
            else:
                is_integer = True

            if is_integer:
                if min_vox >= 0:  # unsigned
                    candidates = ('uint8', 'uint16', 'uint32', 'uint64')
                else:
                    candidates = ('int8', 'int16', 'int32', 'int64')
                limits_of = iinfo
                error_message = "Maximum value of the image is to big to be represented."
            else:
                # float16 is not supported by nibabel
                candidates = ('float32', 'float64')
                limits_of = finfo
                error_message = "Voxel values cannot be represented as float32 nor float64."

            # candidates are ordered from smallest to largest: keep the first one that fits
            for candidate in candidates:
                limits = limits_of(dtype(candidate))
                if limits.min <= min_vox and max_vox <= limits.max:
                    type = candidate
                    break
            else:
                raise ValueError(error_message)

        # change type of data in both numpy array and nifti header
        # (numpy.dtype accepts type names as well as the dtype object read from the header)
        target = dtype(type)
        self.data = array(self.data, dtype=target)
        self.hdr.set_data_dtype(target)

    def save(self, type=''):
        """
        Write the image in a nifti file, at the location path + file_name + ext.

        :param type:    if not set, the image is saved in standard type
                        if 'minimize', image space is minimize
                        any other value is passed to changeType before saving, e.g.:
                        (2, 'uint8', np.uint8, "NIFTI_TYPE_UINT8"),
                        (4, 'int16', np.int16, "NIFTI_TYPE_INT16"),
                        (8, 'int32', np.int32, "NIFTI_TYPE_INT32"),
                        (16, 'float32', np.float32, "NIFTI_TYPE_FLOAT32"),
                        (32, 'complex64', np.complex64, "NIFTI_TYPE_COMPLEX64"),
                        (64, 'float64', np.float64, "NIFTI_TYPE_FLOAT64"),
                        (256, 'int8', np.int8, "NIFTI_TYPE_INT8"),
                        (512, 'uint16', np.uint16, "NIFTI_TYPE_UINT16"),
                        (768, 'uint32', np.uint32, "NIFTI_TYPE_UINT32"),
                        (1024,'int64', np.int64, "NIFTI_TYPE_INT64"),
                        (1280, 'uint64', np.uint64, "NIFTI_TYPE_UINT64"),
                        (1536, 'float128', _float128t, "NIFTI_TYPE_FLOAT128"),
                        (1792, 'complex128', np.complex128, "NIFTI_TYPE_COMPLEX128"),
                        (2048, 'complex256', _complex256t, "NIFTI_TYPE_COMPLEX256"),
        """
        from nibabel import Nifti1Image, save
        from sct_utils import printv

        if type != '':
            self.changeType(type)

        # the header must describe the array that is actually written
        self.hdr.set_data_shape(self.data.shape)
        img = Nifti1Image(self.data, None, self.hdr)
        fname_out = self.path + self.file_name + self.ext
        printv('saving ' + fname_out + '\n',
               verbose=self.verbose,
               type='normal')
        save(img, fname_out)

    def flatten(self):
        """Return the voxels as a one-dimensional array of shape (d,), like numpy's flatten."""
        return self.data.flatten()

    def slices(self):
        """Return the list of the slices taken along the first axis of data, each one flattened."""
        return [slc.flatten() for slc in self.data]

    def getNonZeroCoordinates(self,
                              sorting=None,
                              reverse_coord=False,
                              coordValue=False):
        """
        Return the coordinates of all strictly positive voxels of the image (self.data > 0).

        Each entry is built from [x, y, z, value]. For a 2D image (len(self.dim) == 2) z is set to 0,
        so that the voxel value is stored as the value and not as the z coordinate.

        :param sorting: None (no sorting) or one of 'x', 'y', 'z', 'value': attribute used as sorting key
        :param reverse_coord: if True, coordinates are sorted from larger to smaller
        :param coordValue: if True, entries are CoordinateValue objects instead of Coordinate objects
        :return: list of coordinate objects
        :raises ValueError: if the image is neither 2D nor 3D, if reverse_coord is not a boolean or
                            if sorting is not one of the accepted keys
        """
        from sct_utils import printv

        n_dim = len(self.dim)
        if n_dim not in (2, 3):
            # previously this case failed later on an unbound variable
            raise ValueError('getNonZeroCoordinates only supports 2D and 3D images')

        try:
            if n_dim == 3:
                X, Y, Z = (self.data > 0).nonzero()
                points = [[x, y, z, self.data[x, y, z]]
                          for x, y, z in zip(X, Y, Z)]
            else:
                X, Y = (self.data > 0).nonzero()
                points = [[x, y, 0, self.data[x, y]] for x, y in zip(X, Y)]
        except Exception as e:
            printv(
                'ERROR: Exception ' + str(e) +
                ' caught while geting non Zeros coordinates', 1, 'error')
            # nothing usable was collected: surface the real cause instead of a later NameError
            raise

        # the voxels were gathered once above; only the class of the entries depends on coordValue
        if coordValue:
            from msct_types import CoordinateValue as coord_class
        else:
            from msct_types import Coordinate as coord_class
        list_coordinates = [coord_class(point) for point in points]

        if sorting is not None:
            if reverse_coord not in [True, False]:
                raise ValueError('reverse_coord parameter must be a boolean')
            if sorting not in ('x', 'y', 'z', 'value'):
                raise ValueError(
                    "sorting parameter must be either 'x', 'y', 'z' or 'value'"
                )
            list_coordinates = sorted(list_coordinates,
                                      key=lambda obj: getattr(obj, sorting),
                                      reverse=reverse_coord)

        return list_coordinates
    def straighten(self):
        # Initialization
        fname_anat = self.input_filename
        fname_centerline = self.centerline_filename
        fname_output = self.output_filename
        gapxy = self.gapxy
        gapz = self.gapz
        padding = self.padding
        remove_temp_files = self.remove_temp_files
        verbose = self.verbose
        interpolation_warp = self.interpolation_warp
        algo_fitting = self.algo_fitting
        window_length = self.window_length
        type_window = self.type_window
        crop = self.crop

        # start timer
        start_time = time.time()

        # get path of the toolbox
        status, path_sct = commands.getstatusoutput("echo $SCT_DIR")
        sct.printv(path_sct, verbose)

        if self.debug == 1:
            print "\n*** WARNING: DEBUG MODE ON ***\n"
            fname_anat = (
                "/Users/julien/data/temp/sct_example_data/t2/tmp.150401221259/anat_rpi.nii"
            )  # path_sct+'/testing/sct_testing_data/data/t2/t2.nii.gz'
            fname_centerline = (
                "/Users/julien/data/temp/sct_example_data/t2/tmp.150401221259/centerline_rpi.nii"
            )  # path_sct+'/testing/sct_testing_data/data/t2/t2_seg.nii.gz'
            remove_temp_files = 0
            type_window = "hanning"
            verbose = 2

        # check existence of input files
        sct.check_file_exist(fname_anat, verbose)
        sct.check_file_exist(fname_centerline, verbose)

        # Display arguments
        sct.printv("\nCheck input arguments...", verbose)
        sct.printv("  Input volume ...................... " + fname_anat, verbose)
        sct.printv("  Centerline ........................ " + fname_centerline, verbose)
        sct.printv("  Final interpolation ............... " + interpolation_warp, verbose)
        sct.printv("  Verbose ........................... " + str(verbose), verbose)
        sct.printv("", verbose)

        # Extract path/file/extension
        path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)
        path_centerline, file_centerline, ext_centerline = sct.extract_fname(fname_centerline)

        # create temporary folder
        path_tmp = "tmp." + time.strftime("%y%m%d%H%M%S")
        sct.run("mkdir " + path_tmp, verbose)

        # copy files into tmp folder
        sct.run("cp " + fname_anat + " " + path_tmp, verbose)
        sct.run("cp " + fname_centerline + " " + path_tmp, verbose)

        # go to tmp folder
        os.chdir(path_tmp)

        try:
            # Change orientation of the input centerline into RPI
            sct.printv("\nOrient centerline to RPI orientation...", verbose)
            fname_centerline_orient = file_centerline + "_rpi.nii.gz"
            set_orientation(file_centerline + ext_centerline, "RPI", fname_centerline_orient)

            # Get dimension
            sct.printv("\nGet dimensions...", verbose)
            nx, ny, nz, nt, px, py, pz, pt = sct.get_dimension(fname_centerline_orient)
            sct.printv(".. matrix size: " + str(nx) + " x " + str(ny) + " x " + str(nz), verbose)
            sct.printv(".. voxel size:  " + str(px) + "mm x " + str(py) + "mm x " + str(pz) + "mm", verbose)

            # smooth centerline
            x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
                fname_centerline_orient,
                algo_fitting=algo_fitting,
                type_window=type_window,
                window_length=window_length,
                verbose=verbose,
            )

            # Get coordinates of landmarks along curved centerline
            # ==========================================================================================
            sct.printv("\nGet coordinates of landmarks along curved centerline...", verbose)
            # landmarks are created along the curved centerline every z=gapz. They consist of a "cross" of size gapx and gapy. In voxel space!!!

            # find z indices along centerline given a specific gap: iz_curved
            nz_nonz = len(z_centerline)
            nb_landmark = int(round(float(nz_nonz) / gapz))

            if nb_landmark == 0:
                nb_landmark = 1

            if nb_landmark == 1:
                iz_curved = [0]
            else:
                iz_curved = [i * gapz for i in range(0, nb_landmark - 1)]

            iz_curved.append(nz_nonz - 1)
            # print iz_curved, len(iz_curved)
            n_iz_curved = len(iz_curved)
            # print n_iz_curved

            # landmark_curved initialisation
            # landmark_curved = [ [ [ 0 for i in range(0, 3)] for i in range(0, 5) ] for i in iz_curved ]

            from msct_types import Coordinate

            landmark_curved = []
            landmark_curved_value = 1

            ### TODO: THIS PART IS SLOW AND CAN BE MADE FASTER
            ### >>==============================================================================================================
            for iz in range(min(iz_curved), max(iz_curved) + 1, 1):
                if iz in iz_curved:
                    index = iz_curved.index(iz)
                    # calculate d (ax+by+cz+d=0)
                    # print iz_curved[index]
                    a = x_centerline_deriv[iz]
                    b = y_centerline_deriv[iz]
                    c = z_centerline_deriv[iz]
                    x = x_centerline_fit[iz]
                    y = y_centerline_fit[iz]
                    z = z_centerline[iz]
                    d = -(a * x + b * y + c * z)
                    # print a,b,c,d,x,y,z
                    # set coordinates for landmark at the center of the cross
                    coord = Coordinate([0, 0, 0, landmark_curved_value])
                    coord.x, coord.y, coord.z = x_centerline_fit[iz], y_centerline_fit[iz], z_centerline[iz]
                    landmark_curved.append(coord)

                    # set y coordinate to y_centerline_fit[iz] for elements 1 and 2 of the cross
                    cross_coordinates = [
                        Coordinate([0, 0, 0, landmark_curved_value + 1]),
                        Coordinate([0, 0, 0, landmark_curved_value + 2]),
                        Coordinate([0, 0, 0, landmark_curved_value + 3]),
                        Coordinate([0, 0, 0, landmark_curved_value + 4]),
                    ]

                    cross_coordinates[0].y = y_centerline_fit[iz]
                    cross_coordinates[1].y = y_centerline_fit[iz]

                    # set x and z coordinates for landmarks +x and -x, forcing de landmark to be in the orthogonal plan and the distance landmark/curve to be gapxy
                    x_n = Symbol("x_n")
                    cross_coordinates[1].x, cross_coordinates[0].x = solve(
                        (x_n - x) ** 2 + ((-1 / c) * (a * x_n + b * y + d) - z) ** 2 - gapxy ** 2, x_n
                    )  # x for -x and +x
                    cross_coordinates[0].z = (-1 / c) * (a * cross_coordinates[0].x + b * y + d)  # z for +x
                    cross_coordinates[1].z = (-1 / c) * (a * cross_coordinates[1].x + b * y + d)  # z for -x

                    # set x coordinate to x_centerline_fit[iz] for elements 3 and 4 of the cross
                    cross_coordinates[2].x = x_centerline_fit[iz]
                    cross_coordinates[3].x = x_centerline_fit[iz]

                    # set coordinates for landmarks +y and -y. Here, x coordinate is 0 (already initialized).
                    y_n = Symbol("y_n")
                    cross_coordinates[3].y, cross_coordinates[2].y = solve(
                        (y_n - y) ** 2 + ((-1 / c) * (a * x + b * y_n + d) - z) ** 2 - gapxy ** 2, y_n
                    )  # y for -y and +y
                    cross_coordinates[2].z = (-1 / c) * (a * x + b * cross_coordinates[2].y + d)  # z for +y
                    cross_coordinates[3].z = (-1 / c) * (a * x + b * cross_coordinates[3].y + d)  # z for -y

                    for coord in cross_coordinates:
                        landmark_curved.append(coord)
                    landmark_curved_value += 5
                else:
                    if self.all_labels == 1:
                        landmark_curved.append(
                            Coordinate(
                                [x_centerline_fit[iz], y_centerline_fit[iz], z_centerline[iz], landmark_curved_value],
                                mode="continuous",
                            )
                        )
                        landmark_curved_value += 1
            ### <<==============================================================================================================

            # Get coordinates of landmarks along straight centerline
            # ==========================================================================================
            sct.printv("\nGet coordinates of landmarks along straight centerline...", verbose)
            # landmark_straight = [ [ [ 0 for i in range(0,3)] for i in range (0,5) ] for i in iz_curved ] # same structure as landmark_curved

            landmark_straight = []

            # calculate the z indices corresponding to the Euclidean distance between two consecutive points on the curved centerline (approximation curve --> line)
            # TODO: DO NOT APPROXIMATE CURVE --> LINE
            if nb_landmark == 1:
                iz_straight = [0 for i in range(0, nb_landmark + 1)]
            else:
                iz_straight = [0 for i in range(0, nb_landmark)]

            # print iz_straight,len(iz_straight)
            iz_straight[0] = iz_curved[0]
            for index in range(1, n_iz_curved, 1):
                # compute vector between two consecutive points on the curved centerline
                vector_centerline = [
                    x_centerline_fit[iz_curved[index]] - x_centerline_fit[iz_curved[index - 1]],
                    y_centerline_fit[iz_curved[index]] - y_centerline_fit[iz_curved[index - 1]],
                    z_centerline[iz_curved[index]] - z_centerline[iz_curved[index - 1]],
                ]
                # compute norm of this vector
                norm_vector_centerline = linalg.norm(vector_centerline, ord=2)
                # round to closest integer value
                norm_vector_centerline_rounded = int(round(norm_vector_centerline, 0))
                # assign this value to the current z-coordinate on the straight centerline
                iz_straight[index] = iz_straight[index - 1] + norm_vector_centerline_rounded

            # initialize x0 and y0 to be at the center of the FOV
            x0 = int(round(nx / 2))
            y0 = int(round(ny / 2))
            landmark_curved_value = 1
            for iz in range(min(iz_curved), max(iz_curved) + 1, 1):
                if iz in iz_curved:
                    index = iz_curved.index(iz)
                    # set coordinates for landmark at the center of the cross
                    landmark_straight.append(Coordinate([x0, y0, iz_straight[index], landmark_curved_value]))
                    # set x, y and z coordinates for landmarks +x
                    landmark_straight.append(
                        Coordinate([x0 + gapxy, y0, iz_straight[index], landmark_curved_value + 1])
                    )
                    # set x, y and z coordinates for landmarks -x
                    landmark_straight.append(
                        Coordinate([x0 - gapxy, y0, iz_straight[index], landmark_curved_value + 2])
                    )
                    # set x, y and z coordinates for landmarks +y
                    landmark_straight.append(
                        Coordinate([x0, y0 + gapxy, iz_straight[index], landmark_curved_value + 3])
                    )
                    # set x, y and z coordinates for landmarks -y
                    landmark_straight.append(
                        Coordinate([x0, y0 - gapxy, iz_straight[index], landmark_curved_value + 4])
                    )
                    landmark_curved_value += 5
                else:
                    if self.all_labels == 1:
                        landmark_straight.append(Coordinate([x0, y0, iz, landmark_curved_value]))
                        landmark_curved_value += 1

            # Create NIFTI volumes with landmarks
            # ==========================================================================================
            # Pad input volume to deal with the fact that some landmarks on the curved centerline might be outside the FOV
            # N.B. IT IS VERY IMPORTANT TO PAD ALSO ALONG X and Y, OTHERWISE SOME LANDMARKS MIGHT GET OUT OF THE FOV!!!
            # sct.run('fslview ' + fname_centerline_orient)
            sct.printv("\nPad input volume to account for landmarks that fall outside the FOV...", verbose)
            sct.run(
                "isct_c3d "
                + fname_centerline_orient
                + " -pad "
                + str(padding)
                + "x"
                + str(padding)
                + "x"
                + str(padding)
                + "vox "
                + str(padding)
                + "x"
                + str(padding)
                + "x"
                + str(padding)
                + "vox 0 -o tmp.centerline_pad.nii.gz",
                verbose,
            )

            # Open padded centerline for reading
            sct.printv("\nOpen padded centerline for reading...", verbose)
            file = load("tmp.centerline_pad.nii.gz")
            data = file.get_data()
            hdr = file.get_header()

            if self.algo_landmark_rigid is not None and self.algo_landmark_rigid != "None":
                # Reorganize landmarks
                points_fixed, points_moving = [], []
                for coord in landmark_straight:
                    points_fixed.append([coord.x, coord.y, coord.z])
                for coord in landmark_curved:
                    points_moving.append([coord.x, coord.y, coord.z])

                # Register curved landmarks on straight landmarks based on python implementation
                sct.printv("\nComputing rigid transformation (algo=" + self.algo_landmark_rigid + ") ...", verbose)
                import msct_register_landmarks

                (
                    rotation_matrix,
                    translation_array,
                    points_moving_reg,
                ) = msct_register_landmarks.getRigidTransformFromLandmarks(
                    points_fixed, points_moving, constraints=self.algo_landmark_rigid, show=False
                )

                # reorganize registered points
                landmark_curved_rigid = []
                for index_curved, ind in enumerate(range(0, len(points_moving_reg), 1)):
                    coord = Coordinate()
                    coord.x, coord.y, coord.z, coord.value = (
                        points_moving_reg[ind][0],
                        points_moving_reg[ind][1],
                        points_moving_reg[ind][2],
                        index_curved + 1,
                    )
                    landmark_curved_rigid.append(coord)

                # Create volumes containing curved and straight landmarks
                data_curved_landmarks = data * 0
                data_curved_rigid_landmarks = data * 0
                data_straight_landmarks = data * 0

                # Loop across cross index
                for index in range(0, len(landmark_curved_rigid)):
                    x, y, z = (
                        int(round(landmark_curved[index].x)),
                        int(round(landmark_curved[index].y)),
                        int(round(landmark_curved[index].z)),
                    )

                    # attribute landmark_value to the voxel and its neighbours
                    data_curved_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_curved[index].value

                    # get x, y and z coordinates of curved landmark (rounded to closest integer)
                    x, y, z = (
                        int(round(landmark_curved_rigid[index].x)),
                        int(round(landmark_curved_rigid[index].y)),
                        int(round(landmark_curved_rigid[index].z)),
                    )

                    # attribute landmark_value to the voxel and its neighbours
                    data_curved_rigid_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_curved_rigid[index].value

                    # get x, y and z coordinates of straight landmark (rounded to closest integer)
                    x, y, z = (
                        int(round(landmark_straight[index].x)),
                        int(round(landmark_straight[index].y)),
                        int(round(landmark_straight[index].z)),
                    )

                    # attribute landmark_value to the voxel and its neighbours
                    data_straight_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_straight[index].value

                # Write NIFTI volumes
                sct.printv("\nWrite NIFTI volumes...", verbose)
                hdr.set_data_dtype("uint32")  # set imagetype to uint8 #TODO: maybe use int32
                img = Nifti1Image(data_curved_landmarks, None, hdr)
                save(img, "tmp.landmarks_curved.nii.gz")
                sct.printv(".. File created: tmp.landmarks_curved.nii.gz", verbose)
                hdr.set_data_dtype("uint32")  # set imagetype to uint8 #TODO: maybe use int32
                img = Nifti1Image(data_curved_rigid_landmarks, None, hdr)
                save(img, "tmp.landmarks_curved_rigid.nii.gz")
                sct.printv(".. File created: tmp.landmarks_curved_rigid.nii.gz", verbose)
                img = Nifti1Image(data_straight_landmarks, None, hdr)
                save(img, "tmp.landmarks_straight.nii.gz")
                sct.printv(".. File created: tmp.landmarks_straight.nii.gz", verbose)

                # writing rigid transformation file
                text_file = open("tmp.curve2straight_rigid.txt", "w")
                text_file.write("#Insight Transform File V1.0\n")
                text_file.write("#Transform 0\n")
                text_file.write("Transform: AffineTransform_double_3_3\n")
                text_file.write(
                    "Parameters: %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f %.9f\n"
                    % (
                        rotation_matrix[0, 0],
                        rotation_matrix[0, 1],
                        rotation_matrix[0, 2],
                        rotation_matrix[1, 0],
                        rotation_matrix[1, 1],
                        rotation_matrix[1, 2],
                        rotation_matrix[2, 0],
                        rotation_matrix[2, 1],
                        rotation_matrix[2, 2],
                        -translation_array[0, 0],
                        translation_array[0, 1],
                        -translation_array[0, 2],
                    )
                )
                text_file.write("FixedParameters: 0 0 0\n")
                text_file.close()

            else:
                # Create volumes containing curved and straight landmarks
                data_curved_landmarks = data * 0
                data_straight_landmarks = data * 0

                # Loop across cross index
                for index in range(0, len(landmark_curved)):
                    x, y, z = (
                        int(round(landmark_curved[index].x)),
                        int(round(landmark_curved[index].y)),
                        int(round(landmark_curved[index].z)),
                    )

                    # attribute landmark_value to the voxel and its neighbours
                    data_curved_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_curved[index].value

                    # get x, y and z coordinates of straight landmark (rounded to closest integer)
                    x, y, z = (
                        int(round(landmark_straight[index].x)),
                        int(round(landmark_straight[index].y)),
                        int(round(landmark_straight[index].z)),
                    )

                    # attribute landmark_value to the voxel and its neighbours
                    data_straight_landmarks[
                        x + padding - 1 : x + padding + 2,
                        y + padding - 1 : y + padding + 2,
                        z + padding - 1 : z + padding + 2,
                    ] = landmark_straight[index].value

                # Write NIFTI volumes
                sct.printv("\nWrite NIFTI volumes...", verbose)
                hdr.set_data_dtype("uint32")  # set imagetype to uint8 #TODO: maybe use int32
                img = Nifti1Image(data_curved_landmarks, None, hdr)
                save(img, "tmp.landmarks_curved.nii.gz")
                sct.printv(".. File created: tmp.landmarks_curved.nii.gz", verbose)
                img = Nifti1Image(data_straight_landmarks, None, hdr)
                save(img, "tmp.landmarks_straight.nii.gz")
                sct.printv(".. File created: tmp.landmarks_straight.nii.gz", verbose)

                # Estimate deformation field by pairing landmarks
                # ==========================================================================================
                # convert landmarks to INT
                sct.printv("\nConvert landmarks to INT...", verbose)
                sct.run("isct_c3d tmp.landmarks_straight.nii.gz -type int -o tmp.landmarks_straight.nii.gz", verbose)
                sct.run("isct_c3d tmp.landmarks_curved.nii.gz -type int -o tmp.landmarks_curved.nii.gz", verbose)

                # This stands to avoid overlapping between landmarks
                sct.printv("\nMake sure all labels between landmark_curved and landmark_curved match...", verbose)
                label_process_straight = ProcessLabels(
                    fname_label="tmp.landmarks_straight.nii.gz",
                    fname_output="tmp.landmarks_straight.nii.gz",
                    fname_ref="tmp.landmarks_curved.nii.gz",
                    verbose=verbose,
                )
                label_process_straight.process("remove")
                label_process_curved = ProcessLabels(
                    fname_label="tmp.landmarks_curved.nii.gz",
                    fname_output="tmp.landmarks_curved.nii.gz",
                    fname_ref="tmp.landmarks_straight.nii.gz",
                    verbose=verbose,
                )
                label_process_curved.process("remove")

                # Estimate rigid transformation
                sct.printv("\nEstimate rigid transformation between paired landmarks...", verbose)
                sct.run(
                    "isct_ANTSUseLandmarkImagesToGetAffineTransform tmp.landmarks_straight.nii.gz tmp.landmarks_curved.nii.gz rigid tmp.curve2straight_rigid.txt",
                    verbose,
                )

                # Apply rigid transformation
                sct.printv("\nApply rigid transformation to curved landmarks...", verbose)
                # sct.run('sct_apply_transfo -i tmp.landmarks_curved.nii.gz -o tmp.landmarks_curved_rigid.nii.gz -d tmp.landmarks_straight.nii.gz -w tmp.curve2straight_rigid.txt -x nn', verbose)
                Transform(
                    input_filename="tmp.landmarks_curved.nii.gz",
                    source_reg="tmp.landmarks_curved_rigid.nii.gz",
                    output_filename="tmp.landmarks_straight.nii.gz",
                    warp="tmp.curve2straight_rigid.txt",
                    interp="nn",
                    verbose=verbose,
                ).apply()

            if verbose == 2:
                from mpl_toolkits.mplot3d import Axes3D
                import matplotlib.pyplot as plt

                fig = plt.figure()
                ax = Axes3D(fig)
                ax.plot(x_centerline_fit, y_centerline_fit, z_centerline, zdir="z")
                ax.plot(
                    [coord.x for coord in landmark_curved],
                    [coord.y for coord in landmark_curved],
                    [coord.z for coord in landmark_curved],
                    ".",
                )
                ax.plot(
                    [coord.x for coord in landmark_straight],
                    [coord.y for coord in landmark_straight],
                    [coord.z for coord in landmark_straight],
                    "r.",
                )
                if self.algo_landmark_rigid is not None and self.algo_landmark_rigid != "None":
                    ax.plot(
                        [coord.x for coord in landmark_curved_rigid],
                        [coord.y for coord in landmark_curved_rigid],
                        [coord.z for coord in landmark_curved_rigid],
                        "b.",
                    )
                ax.set_xlabel("x")
                ax.set_ylabel("y")
                ax.set_zlabel("z")
                plt.show()

            # This stands to avoid overlapping between landmarks
            sct.printv("\nMake sure all labels between landmark_curved and landmark_curved match...", verbose)
            label_process = ProcessLabels(
                fname_label="tmp.landmarks_straight.nii.gz",
                fname_output="tmp.landmarks_straight.nii.gz",
                fname_ref="tmp.landmarks_curved_rigid.nii.gz",
                verbose=verbose,
            )
            label_process.process("remove")
            label_process = ProcessLabels(
                fname_label="tmp.landmarks_curved_rigid.nii.gz",
                fname_output="tmp.landmarks_curved_rigid.nii.gz",
                fname_ref="tmp.landmarks_straight.nii.gz",
                verbose=verbose,
            )
            label_process.process("remove")

            # Estimate b-spline transformation curve --> straight
            sct.printv("\nEstimate b-spline transformation: curve --> straight...", verbose)
            sct.run(
                "isct_ANTSUseLandmarkImagesToGetBSplineDisplacementField tmp.landmarks_straight.nii.gz tmp.landmarks_curved_rigid.nii.gz tmp.warp_curve2straight.nii.gz "
                + self.bspline_meshsize
                + " "
                + self.bspline_numberOfLevels
                + " "
                + self.bspline_order
                + " 0",
                verbose,
            )

            # remove padding for straight labels
            if crop == 1:
                ImageCropper(
                    input_file="tmp.landmarks_straight.nii.gz",
                    output_file="tmp.landmarks_straight_crop.nii.gz",
                    dim="0,1,2",
                    bmax=True,
                    verbose=verbose,
                ).crop()
                pass
            else:
                sct.run("cp tmp.landmarks_straight.nii.gz tmp.landmarks_straight_crop.nii.gz", verbose)

            # Concatenate rigid and non-linear transformations...
            sct.printv("\nConcatenate rigid and non-linear transformations...", verbose)
            # sct.run('isct_ComposeMultiTransform 3 tmp.warp_rigid.nii -R tmp.landmarks_straight.nii tmp.warp.nii tmp.curve2straight_rigid.txt')
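            # !!! Historical warning: DO NOT USE sct.run HERE BECAUSE isct_ComposeMultiTransform OUTPUTS A NON-NULL STATUS !!!
            # NOTE(review): sct.run is nevertheless used below -- confirm the non-null status issue no longer applies.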
            cmd = "isct_ComposeMultiTransform 3 tmp.curve2straight.nii.gz -R tmp.landmarks_straight_crop.nii.gz tmp.warp_curve2straight.nii.gz tmp.curve2straight_rigid.txt"
            sct.printv(cmd, verbose, "code")
            sct.run(cmd, self.verbose)
            # commands.getstatusoutput(cmd)

            # Estimate b-spline transformation straight --> curve
            # TODO: invert warping field instead of estimating a new one
            sct.printv("\nEstimate b-spline transformation: straight --> curve...", verbose)
            sct.run(
                "isct_ANTSUseLandmarkImagesToGetBSplineDisplacementField tmp.landmarks_curved_rigid.nii.gz tmp.landmarks_straight.nii.gz tmp.warp_straight2curve.nii.gz "
                + self.bspline_meshsize
                + " "
                + self.bspline_numberOfLevels
                + " "
                + self.bspline_order
                + " 0",
                verbose,
            )

            # Concatenate rigid and non-linear transformations...
            sct.printv("\nConcatenate rigid and non-linear transformations...", verbose)
            cmd = (
                "isct_ComposeMultiTransform 3 tmp.straight2curve.nii.gz -R "
                + file_anat
                + ext_anat
                + " -i tmp.curve2straight_rigid.txt tmp.warp_straight2curve.nii.gz"
            )
            sct.printv(cmd, verbose, "code")
            # commands.getstatusoutput(cmd)
            sct.run(cmd, self.verbose)

            # Apply transformation to input image
            sct.printv("\nApply transformation to input image...", verbose)
            Transform(
                input_filename=str(file_anat + ext_anat),
                source_reg="tmp.anat_rigid_warp.nii.gz",
                output_filename="tmp.landmarks_straight_crop.nii.gz",
                interp=interpolation_warp,
                warp="tmp.curve2straight.nii.gz",
                verbose=verbose,
            ).apply()

            # compute the error between the straightened centerline/segmentation and the central vertical line.
            # Ideally, the error should be zero.
            # Apply deformation to input image
            sct.printv("\nApply transformation to centerline image...", verbose)
            # sct.run('sct_apply_transfo -i '+fname_centerline_orient+' -o tmp.centerline_straight.nii.gz -d tmp.landmarks_straight_crop.nii.gz -x nn -w tmp.curve2straight.nii.gz')
            Transform(
                input_filename=fname_centerline_orient,
                source_reg="tmp.centerline_straight.nii.gz",
                output_filename="tmp.landmarks_straight_crop.nii.gz",
                interp="nn",
                warp="tmp.curve2straight.nii.gz",
                verbose=verbose,
            ).apply()
            # c = sct.run('sct_crop_image -i tmp.centerline_straight.nii.gz -o tmp.centerline_straight_crop.nii.gz -dim 2 -bzmax')
            from msct_image import Image

            file_centerline_straight = Image("tmp.centerline_straight.nii.gz", verbose=verbose)
            coordinates_centerline = file_centerline_straight.getNonZeroCoordinates(sorting="z")
            mean_coord = []
            for z in range(coordinates_centerline[0].z, coordinates_centerline[-1].z):
                mean_coord.append(
                    mean(
                        [
                            [coord.x * coord.value, coord.y * coord.value]
                            for coord in coordinates_centerline
                            if coord.z == z
                        ],
                        axis=0,
                    )
                )

            # compute error between the input data and the nurbs
            from math import sqrt

            x0 = file_centerline_straight.data.shape[0] / 2.0
            y0 = file_centerline_straight.data.shape[1] / 2.0
            count_mean = 0
            for coord_z in mean_coord:
                if not isnan(sum(coord_z)):
                    dist = ((x0 - coord_z[0]) * px) ** 2 + ((y0 - coord_z[1]) * py) ** 2
                    self.mse_straightening += dist
                    dist = sqrt(dist)
                    if dist > self.max_distance_straightening:
                        self.max_distance_straightening = dist
                    count_mean += 1
            self.mse_straightening = sqrt(self.mse_straightening / float(count_mean))

        except Exception as e:
            sct.printv("WARNING: Exception during Straightening:", 1, "warning")
            print e

        os.chdir("..")

        # Generate output file (in current folder)
        # TODO: do not uncompress the warping field, it is too time consuming!
        sct.printv("\nGenerate output file (in current folder)...", verbose)
        sct.generate_output_file(
            path_tmp + "/tmp.curve2straight.nii.gz", "warp_curve2straight.nii.gz", verbose
        )  # warping field
        sct.generate_output_file(
            path_tmp + "/tmp.straight2curve.nii.gz", "warp_straight2curve.nii.gz", verbose
        )  # warping field
        if fname_output == "":
            fname_straight = sct.generate_output_file(
                path_tmp + "/tmp.anat_rigid_warp.nii.gz", file_anat + "_straight" + ext_anat, verbose
            )  # straightened anatomic
        else:
            fname_straight = sct.generate_output_file(
                path_tmp + "/tmp.anat_rigid_warp.nii.gz", fname_output, verbose
            )  # straightened anatomic
        # Remove temporary files
        if remove_temp_files:
            sct.printv("\nRemove temporary files...", verbose)
            sct.run("rm -rf " + path_tmp, verbose)

        sct.printv("\nDone!\n", verbose)

        sct.printv("Maximum x-y error = " + str(round(self.max_distance_straightening, 2)) + " mm", verbose, "bold")
        sct.printv(
            "Accuracy of straightening (MSE) = " + str(round(self.mse_straightening, 2)) + " mm", verbose, "bold"
        )
        # display elapsed time
        elapsed_time = time.time() - start_time
        sct.printv("\nFinished! Elapsed time: " + str(int(round(elapsed_time))) + "s", verbose)
        sct.printv("\nTo view results, type:", verbose)
        sct.printv("fslview " + fname_straight + " &\n", verbose, "info")
예제 #8
0
    def get_crosses_coordinates(coordinates_input,
                                gapxy=15,
                                image_ref=None,
                                dilate=False):
        from msct_types import Coordinate

        # if reference image is provided (segmentation), we draw the cross perpendicular to the centerline
        if image_ref is not None:
            # smooth centerline
            from sct_straighten_spinalcord import smooth_centerline
            x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
                self.image_ref, verbose=self.verbose)

        # compute crosses
        cross_coordinates = []
        for coord in coordinates_input:
            if image_ref is None:
                from sct_straighten_spinalcord import compute_cross
                cross_coordinates_temp = compute_cross(coord, gapxy)
            else:
                from sct_straighten_spinalcord import compute_cross_centerline
                from numpy import where
                index_z = where(z_centerline == coord.z)
                deriv = Coordinate([
                    x_centerline_deriv[index_z][0],
                    y_centerline_deriv[index_z][0],
                    z_centerline_deriv[index_z][0], 0.0
                ])
                cross_coordinates_temp = compute_cross_centerline(
                    coord, deriv, gapxy)

            for i, coord_cross in enumerate(cross_coordinates_temp):
                coord_cross.value = coord.value * 10 + i + 1

            # dilate cross to 3x3x3
            if dilate:
                additional_coordinates = []
                for coord_temp in cross_coordinates_temp:
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x, coord_temp.y, coord_temp.z + 1.0,
                            coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x, coord_temp.y, coord_temp.z - 1.0,
                            coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x, coord_temp.y + 1.0, coord_temp.z,
                            coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x, coord_temp.y + 1.0,
                            coord_temp.z + 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x, coord_temp.y + 1.0,
                            coord_temp.z - 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x, coord_temp.y - 1.0, coord_temp.z,
                            coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x, coord_temp.y - 1.0,
                            coord_temp.z + 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x, coord_temp.y - 1.0,
                            coord_temp.z - 1.0, coord_temp.value
                        ]))

                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x + 1.0, coord_temp.y, coord_temp.z,
                            coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x + 1.0, coord_temp.y,
                            coord_temp.z + 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x + 1.0, coord_temp.y,
                            coord_temp.z - 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x + 1.0, coord_temp.y + 1.0,
                            coord_temp.z, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x + 1.0, coord_temp.y + 1.0,
                            coord_temp.z + 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x + 1.0, coord_temp.y + 1.0,
                            coord_temp.z - 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x + 1.0, coord_temp.y - 1.0,
                            coord_temp.z, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x + 1.0, coord_temp.y - 1.0,
                            coord_temp.z + 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x + 1.0, coord_temp.y - 1.0,
                            coord_temp.z - 1.0, coord_temp.value
                        ]))

                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x - 1.0, coord_temp.y, coord_temp.z,
                            coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x - 1.0, coord_temp.y,
                            coord_temp.z + 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x - 1.0, coord_temp.y,
                            coord_temp.z - 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x - 1.0, coord_temp.y + 1.0,
                            coord_temp.z, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x - 1.0, coord_temp.y + 1.0,
                            coord_temp.z + 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x - 1.0, coord_temp.y + 1.0,
                            coord_temp.z - 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x - 1.0, coord_temp.y - 1.0,
                            coord_temp.z, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x - 1.0, coord_temp.y - 1.0,
                            coord_temp.z + 1.0, coord_temp.value
                        ]))
                    additional_coordinates.append(
                        Coordinate([
                            coord_temp.x - 1.0, coord_temp.y - 1.0,
                            coord_temp.z - 1.0, coord_temp.value
                        ]))

                cross_coordinates_temp.extend(additional_coordinates)

            cross_coordinates.extend(cross_coordinates_temp)

        cross_coordinates = sorted(cross_coordinates,
                                   key=lambda obj: obj.value)
        return cross_coordinates