Exemplo n.º 1
0
def compute_ICBM152_centerline(dataset_info):
    """
    Build the centerline of the ICBM152 brain template together with its vertebral distribution.

    The template is downloaded first when the 'icbm152/' folder is missing from the data folder.

    :param dataset_info: dictionary containing dataset information (only 'path_data' is read here)
    :return: Centerline object fitted on the manual template centerline, on which
             compute_vertebral_distribution() has been called with label_reference='PMG'
    """
    path_data = dataset_info['path_data']
    folder_template = path_data + 'icbm152/'

    # fetch the template only if it is not already on disk
    if not os.path.isdir(folder_template):
        download_data_template(path_data=path_data, name='icbm152', force=False)

    # labels kept in addition to everything up to 22 (22 corresponds to L2)
    extra_labels = [48, 49, 50, 51, 52]

    # collect the physical position (plus label value) of every retained disk label
    im_disks = Image(folder_template + 'mni_icbm152_t1_tal_nlin_sym_09c_disks_manual.nii.gz')
    labels_physical = []
    for label in im_disks.getNonZeroCoordinates(sorting='z', reverse_coord=True):
        if not (label.value <= 22 or label.value in extra_labels):
            continue
        point = im_disks.transfo_pix2phys([[label.x, label.y, label.z]])[0]
        point.append(label.value)
        labels_physical.append(point)

    # fit the manual centerline in physical space
    x_fit, y_fit, z_fit, x_deriv, y_deriv, z_deriv = smooth_centerline(
        folder_template + 'mni_icbm152_t1_centerline_manual.nii.gz',
        algo_fitting='nurbs',
        verbose=0,
        nurbs_pts_number=300,
        all_slices=False,
        phys_coordinates=True,
        remove_outliers=False)

    centerline = Centerline(x_fit, y_fit, z_fit, x_deriv, y_deriv, z_deriv)
    centerline.compute_vertebral_distribution(labels_physical, label_reference='PMG')
    return centerline
    def angle_correction(self):
        """
        Compute, for each z slice of the segmentation, the angle between the centerline
        tangent and the Z axis of the image, and store the result in self.angles.

        self.angles is an array with one entry per z slice (im_seg.dim[2]); the values are
        those returned by np.arccos. Slices outside the z extent of the segmentation keep
        the value 0.

        When self.fname_sc is None there is no segmentation to fit a centerline on, so the
        method returns immediately without creating or modifying self.angles.

        :return: None
        """
        # nothing can be computed without a segmentation
        if self.fname_sc is None:
            return

        im_seg = Image(self.fname_sc)
        _, _, Z = (im_seg.data > 0).nonzero()
        min_z_index, max_z_index = int(min(Z)), int(max(Z))

        # fit centerline, smooth it and return the first derivative (in physical space)
        x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
            self.fname_sc, algo_fitting='hanning', type_window='hanning', window_length=80,
            nurbs_pts_number=3000, phys_coordinates=True, verbose=self.verbose, all_slices=False)
        centerline = Centerline(x_centerline_fit, y_centerline_fit, z_centerline,
                                x_centerline_deriv, y_centerline_deriv, z_centerline_deriv)

        # average centerline coordinates over slices of the image (only the derivatives are needed)
        x_centerline_deriv_rescorr, y_centerline_deriv_rescorr, z_centerline_deriv_rescorr = \
            centerline.average_coordinates_over_slices(im_seg)[3:]

        # compute Z axis of the image, in physical coordinate
        axis_Z = im_seg.get_directions()[2]

        # Empty array in which the angle for each z slice will be stored
        self.angles = np.zeros(im_seg.dim[2])

        # only the slices covered by the segmentation get an angle
        for zz in range(min_z_index, max_z_index + 1):
            # in the case of problematic segmentation (e.g., non continuous segmentation often at
            # the extremities), display a warning but do not crash
            try:
                # normalize the tangent vector to the centerline (i.e. its derivative)
                tangent_vect = self._normalize(np.array([x_centerline_deriv_rescorr[zz],
                                                         y_centerline_deriv_rescorr[zz],
                                                         z_centerline_deriv_rescorr[zz]]))
            except IndexError:
                printv('WARNING: Your segmentation does not seem continuous, which could cause wrong estimations at the problematic slices. Please check it, especially at the extremities.', type='warning')
                continue

            # compute the angle between the tangent vector and the vector z; the dot product is
            # clipped to [-1, 1] so that rounding errors cannot push arccos out of its domain
            cos_angle = np.clip(np.vdot(tangent_vect, axis_Z), -1.0, 1.0)
            self.angles[zz] = np.arccos(cos_angle)
Exemplo n.º 3
0
def generate_centerline(dataset_info, contrast='t1', regenerate=False):
    """
    This function generates spinal cord centerline from binary images (either an image of centerline or segmentation)

    For each subject, the centerline is loaded from 'centerline.npz' in the subject folder when
    that file exists (unless regenerate is set); otherwise it is computed from the centerline
    and disk images and saved to that file.

    The working directory is changed to each subject folder during processing and is always
    restored to its initial value before returning, including when an exception is raised.

    :param dataset_info: dictionary containing dataset information; the entries 'path_data',
                         'subjects', 'suffix_centerline' and 'suffix_disks' are read
    :param contrast: {'t1', 't2'}
    :param regenerate: if True, recompute the centerline even when a saved one exists
    :return list of centerline objects, one per subject, in the order of dataset_info['subjects']
    """
    path_data = dataset_info['path_data']
    list_subjects = dataset_info['subjects']
    list_centerline = []

    current_path = os.getcwd()

    timer_centerline = sct.Timer(len(list_subjects))
    timer_centerline.start()
    try:
        for subject_name in list_subjects:
            path_data_subject = path_data + subject_name + '/' + contrast + '/'
            fname_image_centerline = path_data_subject + contrast + dataset_info['suffix_centerline'] + '.nii.gz'
            fname_image_disks = path_data_subject + contrast + dataset_info['suffix_disks'] + '.nii.gz'

            # go to output folder
            sct.printv('\nExtracting centerline from ' + path_data_subject)
            os.chdir(path_data_subject)

            fname_centerline = 'centerline'
            # if centerline exists, we load it, if not, we compute it
            if os.path.isfile(fname_centerline + '.npz') and not regenerate:
                centerline = Centerline(fname=path_data_subject + fname_centerline + '.npz')
            else:
                # extracting intervertebral disks
                im = Image(fname_image_disks)
                coord = im.getNonZeroCoordinates(sorting='z', reverse_coord=True)
                coord_physical = []
                for c in coord:
                    if c.value <= 22 or c.value in [48, 49, 50, 51, 52]:  # 22 corresponds to L2
                        c_p = im.transfo_pix2phys([[c.x, c.y, c.z]])[0]
                        c_p.append(c.value)
                        coord_physical.append(c_p)

                # extracting centerline from binary image and create centerline object with vertebral distribution
                x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
                    fname_image_centerline, algo_fitting='nurbs',
                    verbose=0, nurbs_pts_number=4000, all_slices=False, phys_coordinates=True, remove_outliers=False)
                centerline = Centerline(x_centerline_fit, y_centerline_fit, z_centerline,
                                        x_centerline_deriv, y_centerline_deriv, z_centerline_deriv)
                centerline.compute_vertebral_distribution(coord_physical)
                # saved relative to the current (subject) folder
                centerline.save_centerline(fname_output=fname_centerline)

            list_centerline.append(centerline)
            timer_centerline.add_iteration()
        timer_centerline.stop()
    finally:
        # never leave the caller stranded in a subject folder
        os.chdir(current_path)

    return list_centerline
Exemplo n.º 4
0
def _get_centerline(img, algo_fitting, verbose):
    """
    Extract the centerline of an image and wrap it in a Centerline object.

    The centerline points returned by get_centerline() are converted to physical coordinates
    with img.transfo_pix2phys(), and the derivatives are scaled by the pixel size along each axis.

    :param img: image on which the centerline is computed
    :param algo_fitting: fitting algorithm, forwarded to get_centerline()
    :param verbose: verbosity level, forwarded to get_centerline()
    :return: Centerline object built from the physical coordinates and scaled derivatives
    """
    # only the pixel sizes are needed from the image dimensions
    _nx, _ny, _nz, _nt, px, py, pz, _pt = img.dim

    _, arr_ctl, arr_ctl_der = get_centerline(img, algo_fitting=algo_fitting, minmax=True, verbose=verbose)

    # Transform centerline to physical coordinate system
    points_pix = [list(point) for point in zip(arr_ctl[0], arr_ctl[1], arr_ctl[2])]
    arr_ctl_phys = img.transfo_pix2phys(points_pix)

    # Adjust derivatives with pixel size
    x_deriv = arr_ctl_der[0] * px
    y_deriv = arr_ctl_der[1] * py
    z_deriv = arr_ctl_der[2] * pz

    # Construct centerline object
    return Centerline(arr_ctl_phys[:, 0].tolist(),
                      arr_ctl_phys[:, 1].tolist(),
                      arr_ctl_phys[:, 2].tolist(),
                      x_deriv.tolist(),
                      y_deriv.tolist(),
                      z_deriv.tolist())
Exemplo n.º 5
0
def normalize_intensity_template(dataset_info, fname_template_centerline=None, contrast='t1', verbose=1):
    """
    This function normalizes the intensity of the image inside the spinal cord

    For every subject, an intensity profile along z is measured on the centerline, extended to
    cover all slices of the image, and smoothed. Each slice of the subject's straightened image
    ('<contrast>_straight.nii.gz') is then scaled so that the smoothed profile becomes 1000.0,
    and the result is saved with the '_norm' suffix.

    :param dataset_info: dictionary containing dataset information; the entries 'path_data',
                         'subjects' and (when fname_template_centerline is None)
                         'suffix_centerline' are read
    :param fname_template_centerline: path to template centerline (binary image or npz). If None,
                                      each subject's own image and centerline image are used to
                                      measure the profile; otherwise the straightened image is used
    :param contrast: contrast name, used to build the file names
    :param verbose: if 2, plot the raw and smoothed intensity profile of each subject
    :return: None
    """

    def smooth(x, window_len=11, window='hanning'):
        """smooth the data using a window with requested size.
        """
        if x.ndim != 1:
            raise ValueError("smooth only accepts 1 dimension arrays.")

        if x.size < window_len:
            raise ValueError("Input vector needs to be bigger than window size.")

        if window_len < 3:
            return x

        if window not in ['flat', 'hanning', 'hamming', 'bartlett', 'blackman']:
            raise ValueError("Window is on of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'")

        # pad the signal with mirrored copies of its ends
        s = np.r_[x[window_len - 1:0:-1], x, x[-2:-window_len - 1:-1]]
        if window == 'flat':  # moving average
            w = np.ones(window_len, 'd')
        else:
            # window was validated above, so this resolves to np.hanning, np.hamming, ...
            w = getattr(np, window)(window_len)

        y = np.convolve(w / w.sum(), s, mode='same')
        return y[window_len - 1:-window_len + 1]

    path_data = dataset_info['path_data']
    list_subjects = dataset_info['subjects']

    intensity_profiles = {}

    timer_profile = sct.Timer(len(list_subjects))
    timer_profile.start()

    # computing the intensity profile for each subject
    for subject_name in list_subjects:
        path_data_subject = path_data + subject_name + '/' + contrast + '/'
        if fname_template_centerline is None:
            fname_image = path_data_subject + contrast + '.nii.gz'
            fname_image_centerline = path_data_subject + contrast + dataset_info['suffix_centerline'] + '.nii.gz'
        else:
            fname_image = path_data_subject + contrast + '_straight.nii.gz'
            if fname_template_centerline.endswith('.npz'):
                fname_image_centerline = None
            else:
                fname_image_centerline = fname_template_centerline

        image = Image(fname_image)
        nz = image.dim[2]

        if fname_image_centerline is not None:
            # open centerline from template
            number_of_points_in_centerline = 4000
            x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
                fname_image_centerline, algo_fitting='nurbs', verbose=0,
                nurbs_pts_number=number_of_points_in_centerline,
                all_slices=False, phys_coordinates=True, remove_outliers=True)
            centerline_template = Centerline(x_centerline_fit, y_centerline_fit, z_centerline,
                                             x_centerline_deriv, y_centerline_deriv, z_centerline_deriv)
        else:
            centerline_template = Centerline(fname=fname_template_centerline)

        x, y, z = centerline_template.average_coordinates_over_slices(image)[:3]

        # Compute intensity values
        z_values, intensities = [], []
        extend = 1  # this means the mean intensity of the slice will be calculated over a 3x3 square
        for x_phys, y_phys, z_phys in zip(x, y, z):
            coord_pix = image.transfo_phys2pix([[x_phys, y_phys, z_phys]])[0]
            cx, cy, cz = coord_pix[0], coord_pix[1], coord_pix[2]
            z_values.append(cz)
            # square centered on the centerline voxel; the lower bound is clamped so that a
            # negative start index cannot wrap around to the other side of the image
            patch = image.data[max(cx - extend, 0):cx + extend + 1,
                               max(cy - extend, 0):cy + extend + 1,
                               cz]
            intensities.append(np.mean(patch))

        # for the slices that are not in the image, extend min and max values to cover the whole image
        min_z, max_z = min(z_values), max(z_values)
        intensity_at_min = intensities[z_values.index(min_z)]
        intensity_at_max = intensities[z_values.index(max_z)]
        known_slices = set(z_values)
        # sorted copy of the measured profile, used to interpolate slices missing in the middle
        order = np.argsort(z_values, kind='mergesort')
        z_sorted = np.asarray(z_values, dtype=float)[order]
        intensities_sorted = np.asarray(intensities, dtype=float)[order]

        intensities_temp = list(intensities)
        z_values_temp = list(z_values)
        for cz in range(nz):
            if cz in known_slices:
                continue
            z_values_temp.append(cz)
            if cz < min_z:
                intensities_temp.append(intensity_at_min)
            elif cz > max_z:
                intensities_temp.append(intensity_at_max)
            else:
                # slice missing inside the covered range: interpolate between its neighbours so
                # that z values and intensities stay aligned
                sct.printv('WARNING: no centerline point for slice ' + str(cz) + ', interpolating its intensity.', verbose, type='warning')
                intensities_temp.append(float(np.interp(cz, z_sorted, intensities_sorted)))

        # Preparing data for smoothing: make sure it is ordered with z
        arr_int = sorted(zip(z_values_temp, intensities_temp), key=lambda pair: pair[0])

        # Smoothing
        intensities = [pair[1] for pair in arr_int]
        intensity_profile_smooth = smooth(np.array(intensities), window_len=50)

        intensity_profiles[subject_name] = intensity_profile_smooth

        if verbose == 2:
            import matplotlib.pyplot as plt
            plt.figure()
            plt.title(subject_name)
            plt.plot(intensities)
            plt.plot(intensity_profile_smooth)
            plt.show()

        timer_profile.add_iteration()
    timer_profile.stop()

    # target intensity inside the spinal cord, identical for the entire dataset
    average_intensity = 1000.0

    # normalize the intensity of the image based on spinal cord
    for subject_name in list_subjects:
        path_data_subject = path_data + subject_name + '/' + contrast + '/'
        fname_image = path_data_subject + contrast + '_straight.nii.gz'

        image = Image(fname_image)
        nz = image.dim[2]

        image_image_new = image.copy()
        image_image_new.changeType(type='float32')
        for i in range(nz):
            image_image_new.data[:, :, i] *= average_intensity / intensity_profiles[subject_name][i]

        # Save intensity normalized template
        fname_image_normalized = sct.add_suffix(fname_image, '_norm')
        image_image_new.setFileName(fname_image_normalized)
        image_image_new.save()
Exemplo n.º 6
0
def generate_initial_template_space(dataset_info, points_average_centerline, position_template_disks):
    """
    Create the initial template space and the files derived from it.

    Four files are written in dataset_info['path_template']:
    'template_space.nii.gz' (empty volume), 'template_centerline.nii.gz' (centerline voxels set
    to 1), 'template_disks.nii.gz' (one labelled voxel per disk) and 'template_centerline'
    (centerline saved through Centerline.save_centerline).

    :param dataset_info: dictionary containing dataset information (only 'path_template' is read here)
    :param points_average_centerline: list of points (x, y, z) of the average spinal cord and brainstem centerline
    :param position_template_disks: mapping from disk name to its position along the template centerline
    :return: None
    """

    def _is_inside(voxel, volume):
        # True when the first three voxel indices fall inside the volume
        return all(0 <= voxel[axis] < volume.shape[axis] for axis in range(3))

    # initializing variables
    path_template = dataset_info['path_template']
    size_x, size_y = 201, 201
    spacing = 0.5
    first_point = points_average_centerline[0]
    last_point = points_average_centerline[-1]

    # creating template space
    size_z = int(abs(first_point[2] - last_point[2]) / spacing) + 15
    template_space = Image([size_x, size_y, size_z])
    template_space.data = np.zeros((size_x, size_y, size_z))
    template_space.hdr.set_data_dtype('float32')
    origin = [last_point[0] + size_x * spacing / 2.0,
              last_point[1] - size_y * spacing / 2.0,
              (last_point[2] - spacing)]

    # header geometry: dimensions, then offsets, then the affine rows
    template_space.hdr.as_analyze_map()['dim'] = [3.0, size_x, size_y, size_z, 1.0, 1.0, 1.0, 1.0]
    for field, offset in zip(('qoffset_x', 'qoffset_y', 'qoffset_z'), origin):
        template_space.hdr.as_analyze_map()[field] = offset
    for field, offset in zip(('srow_x', 'srow_y', 'srow_z'), origin):
        template_space.hdr.as_analyze_map()[field][-1] = offset
    for axis, (field, step) in enumerate((('srow_x', -spacing), ('srow_y', spacing), ('srow_z', spacing))):
        template_space.hdr.as_analyze_map()[field][axis] = step
    template_space.hdr.set_sform(template_space.hdr.get_sform())
    template_space.hdr.set_qform(template_space.hdr.get_sform())
    template_space.setFileName(path_template + 'template_space.nii.gz')
    template_space.save(type='uint8')

    # generate template centerline as an image
    image_centerline = template_space.copy()
    for point in points_average_centerline:
        voxel = image_centerline.transfo_phys2pix([point])[0]
        if _is_inside(voxel, image_centerline.data):
            image_centerline.data[int(voxel[0]), int(voxel[1]), int(voxel[2])] = 1
    image_centerline.setFileName(path_template + 'template_centerline.nii.gz')
    image_centerline.save(type='float32')

    # generate template disks position
    coord_physical = []
    image_disks = template_space.copy()
    for disk in position_template_disks:
        label = labels_regions[disk]
        position = position_template_disks[disk]
        voxel = image_disks.transfo_phys2pix([position])[0]

        # physical position followed by the label value
        labelled_position = position.tolist()
        labelled_position.append(label)
        coord_physical.append(labelled_position)

        if not _is_inside(voxel, image_disks.data):
            sct.printv(str(voxel))
            sct.printv('ERROR: the disk label ' + str(disk) + ' is not in the template image.')
            continue
        image_disks.data[int(voxel[0]), int(voxel[1]), int(voxel[2])] = label
    image_disks.setFileName(path_template + 'template_disks.nii.gz')
    image_disks.save(type='uint8')

    # generate template centerline as a npz file
    x_fit, y_fit, z_fit, x_deriv, y_deriv, z_deriv = smooth_centerline(
        path_template + 'template_centerline.nii.gz',
        algo_fitting='nurbs',
        verbose=0,
        nurbs_pts_number=4000,
        all_slices=False,
        phys_coordinates=True,
        remove_outliers=True)
    centerline_template = Centerline(x_fit, y_fit, z_fit, x_deriv, y_deriv, z_deriv)
    centerline_template.compute_vertebral_distribution(coord_physical)
    centerline_template.save_centerline(fname_output=path_template + 'template_centerline')
Exemplo n.º 7
0
def compute_properties_along_centerline(fname_seg_image,
                                        property_list,
                                        fname_disks_image=None,
                                        smooth_factor=5.0,
                                        interpolation_mode=0,
                                        remove_temp_files=1,
                                        verbose=1):
    """
    Compute shape properties of a segmentation on patches perpendicular to its centerline.

    The segmentation (and the disks image, if any) is copied to a temporary folder and
    reoriented to RPI. A centerline is fitted on the segmentation, and for each centerline
    point a square patch perpendicular to the centerline is extracted and passed to
    properties2d(). The working directory is restored, and the temporary folder removed when
    remove_temp_files is set, even if an error occurs.

    :param fname_seg_image: file name of the segmentation image
    :param property_list: names of the properties to extract from properties2d(). 'diameters' is
                          expanded into major/minor axis length and orientation, and reported as
                          'RL_diameter' and 'AP_diameter'
    :param fname_disks_image: optional file name of the disks label image; when given, the
                              vertebral distribution is computed and 'distance_from_C1' and
                              'vertebral_level' are filled
    :param smooth_factor: size (in mm) of the Hann window used to smooth each property along the
                          centerline; 0.0 disables smoothing. Smoothing is also skipped when the
                          window would contain fewer than 3 points
    :param interpolation_mode: forwarded to Centerline.extract_perpendicular_square()
    :param remove_temp_files: if true, delete the temporary folder before returning
    :param verbose: verbosity; 2 additionally plots the properties
    :return: (property_list, properties). properties is a dictionary holding one sequence per
             property plus 'incremental_length', 'distance_from_C1', 'vertebral_level', 'z_slice'
             and the Centerline object under 'centerline'. property_list is the input list, or
             the expanded list when 'diameters' was requested
    """
    # Check list of properties
    # If diameters is in the list, compute major and minor axis length and check orientation
    compute_diameters = False
    property_list_local = list(property_list)
    if 'diameters' in property_list_local:
        compute_diameters = True
        property_list_local.remove('diameters')
        property_list_local.append('major_axis_length')
        property_list_local.append('minor_axis_length')
        property_list_local.append('orientation')

    # TODO: make sure fname_segmentation and fname_disks are in the same space
    # create temporary folder and copying data
    sct.printv('\nCreate temporary folder...', verbose)
    path_tmp = sct.slash_at_the_end(
        'tmp.' + time.strftime("%y%m%d%H%M%S") + '_' +
        str(randint(1, 1000000)), 1)
    sct.run('mkdir ' + path_tmp, verbose)

    sct.run('cp ' + fname_seg_image + ' ' + path_tmp)
    if fname_disks_image is not None:
        sct.run('cp ' + fname_disks_image + ' ' + path_tmp)

    # go to tmp folder, remembering where we came from
    path_origin = os.getcwd()
    os.chdir(path_tmp)

    try:
        # only the file name and extension are used: the file was copied in the tmp folder
        path_data, file_data, ext_data = sct.extract_fname(os.path.abspath(fname_seg_image))

        # Change orientation of the input centerline into RPI
        sct.printv('\nOrient centerline to RPI orientation...', verbose)
        im_seg = Image(file_data + ext_data)
        fname_segmentation_orient = 'segmentation_rpi' + ext_data
        image = set_orientation(im_seg, 'RPI')
        image.setFileName(fname_segmentation_orient)
        image.save()

        # Initiating some variables
        nz = image.dim[2]
        resolution = 0.5
        properties = {key: [] for key in property_list_local}
        properties['incremental_length'] = []
        properties['distance_from_C1'] = []
        properties['vertebral_level'] = []
        properties['z_slice'] = []

        # compute the spinal cord centerline based on the spinal cord segmentation
        number_of_points = 5 * nz
        x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
            fname_segmentation_orient,
            algo_fitting='nurbs',
            verbose=verbose,
            nurbs_pts_number=number_of_points,
            all_slices=False,
            phys_coordinates=True,
            remove_outliers=True)
        centerline = Centerline(x_centerline_fit, y_centerline_fit, z_centerline,
                                x_centerline_deriv, y_centerline_deriv,
                                z_centerline_deriv)

        # Compute vertebral distribution along centerline based on position of intervertebral disks
        if fname_disks_image is not None:
            path_data, file_data, ext_data = sct.extract_fname(os.path.abspath(fname_disks_image))
            im_disks = Image(file_data + ext_data)
            fname_disks_orient = 'disks_rpi' + ext_data
            image_disks = set_orientation(im_disks, 'RPI')
            image_disks.setFileName(fname_disks_orient)
            image_disks.save()

            image_disks = Image(fname_disks_orient)
            coord = image_disks.getNonZeroCoordinates(sorting='z',
                                                      reverse_coord=True)
            coord_physical = []
            for c in coord:
                c_p = image_disks.transfo_pix2phys([[c.x, c.y, c.z]])[0]
                c_p.append(c.value)
                coord_physical.append(c_p)
            centerline.compute_vertebral_distribution(coord_physical)

        sct.printv('Computing spinal cord shape along the spinal cord...')
        timer_properties = sct.Timer(
            number_of_iteration=centerline.number_of_points)
        timer_properties.start()
        # Extracting patches perpendicular to the spinal cord and computing spinal cord shape
        value_out = 0.0  # value given to the pixels of the patch that fall outside the image
        for index in range(centerline.number_of_points):
            current_patch = centerline.extract_perpendicular_square(
                image,
                index,
                resolution=resolution,
                interpolation_mode=interpolation_mode,
                border='constant',
                cval=value_out)

            # pixels that are out of the image are considered as background
            patch_zero = np.copy(current_patch)
            patch_zero[patch_zero == value_out] = 0.0

            sc_properties = properties2d(patch_zero, [resolution, resolution])
            z_slice = image.transfo_phys2pix([centerline.points[index]])[0][2]
            if sc_properties is not None:
                properties['incremental_length'].append(
                    centerline.incremental_length[index])
                if fname_disks_image is not None:
                    properties['distance_from_C1'].append(
                        centerline.dist_points[index])
                    properties['vertebral_level'].append(
                        centerline.l_points[index])
                properties['z_slice'].append(z_slice)
                for property_name in property_list_local:
                    properties[property_name].append(sc_properties[property_name])
            else:
                sct.printv('WARNING: no properties for slice ' + str(z_slice), 1, 'warning')

            timer_properties.add_iteration()
        timer_properties.stop()

        # Adding centerline to the properties for later use
        properties['centerline'] = centerline

        # We assume that the major axis is in the right-left direction
        # this script checks the orientation of the spinal cord and invert axis if necessary to make sure the major axis is right-left
        if compute_diameters:
            diameter_major = properties['major_axis_length']
            diameter_minor = properties['minor_axis_length']
            for i, orientation_item in enumerate(properties['orientation']):
                if not -45.0 < orientation_item < 45.0:
                    diameter_major[i], diameter_minor[i] = diameter_minor[i], diameter_major[i]

            properties['RL_diameter'] = properties.pop('major_axis_length')
            properties['AP_diameter'] = properties.pop('minor_axis_length')

            # rename the properties in the local list as well, BEFORE smoothing: the smoothing
            # loop below looks every name of this list up in the properties dictionary
            property_list_local.remove('major_axis_length')
            property_list_local.remove('minor_axis_length')
            property_list_local.append('RL_diameter')
            property_list_local.append('AP_diameter')
            property_list = property_list_local

        # smooth the spinal cord shape with a Hann window if required
        # TODO: not all properties can be smoothed
        if smooth_factor != 0.0 and properties['z_slice']:  # smooth_factor is in mm
            import scipy.signal
            mean_step = np.mean(centerline.progressive_length)
            window_len = int(round(smooth_factor / mean_step)) if mean_step > 0 else 0
            # a Hann window shorter than 3 points sums to zero and would only produce NaN
            if window_len >= 3:
                window = np.hanning(window_len)
                window_sum = np.sum(window)
                for property_name in property_list_local:
                    properties[property_name] = scipy.signal.convolve(
                        properties[property_name], window,
                        mode='same') / window_sum

        # Display properties on the referential space. Requires intervertebral disks
        if verbose == 2 and property_list_local:
            x_increment = 'distance_from_C1'
            if fname_disks_image is None:
                x_increment = 'incremental_length'

            # one subplot per property; squeeze=False keeps an array even for a single property
            fig, axes = plt.subplots(len(property_list_local),
                                     sharex=True,
                                     sharey=False,
                                     squeeze=False)
            axes = axes.ravel()
            for k, property_name in enumerate(property_list_local):
                axes[k].plot(properties[x_increment], properties[property_name])
                axes[k].set_ylabel(property_name)

            if fname_disks_image is not None:
                properties[
                    'distance_disk_from_C1'] = centerline.distance_from_C1label  # distance between each disk and C1 (or first disk)
                xlabel_disks = [
                    centerline.convert_vertlabel2disklabel[label]
                    for label in properties['distance_disk_from_C1']
                ]
                xtick_disks = [
                    properties['distance_disk_from_C1'][label]
                    for label in properties['distance_disk_from_C1']
                ]
                plt.xticks(xtick_disks, xlabel_disks, rotation=30)
            else:
                axes[-1].set_xlabel('Position along the spinal cord (in mm)')

            plt.show()
    finally:
        # leave the temporary folder, and remove it if requested
        os.chdir(path_origin)
        if remove_temp_files:
            shutil.rmtree(path_tmp, ignore_errors=True)

    return property_list, properties
Exemplo n.º 8
0
    def construct3D_uniform(self, P, k, prec):
        """
        Evaluate the curve defined by the control points P at prec points that are uniformly
        spaced along the curve length, together with its derivatives.

        The curve is first evaluated on a uniform parameter grid, its cumulative length is
        measured with a Centerline object, and the parameters are then remapped so that the
        final evaluation is uniform in curve length. When self.all_slices is set, z values are
        rounded to integers, missing z values are filled by averaging neighbours, points are
        averaged per z value, and z values absent from self.P_z are dropped.

        :param P: control points
        :param k: forwarded to self.calculX3D, self.N, self.Np and
                  self.compute_curve_from_parametrization
        :param prec: number of points at which the curve is evaluated
        :return: [P_x, P_y, P_z], [P_x_d, P_y_d, P_z_d]
        """
        # module-level tables, assigned here before self.N / self.Np are called
        global Nik_temp, Nik_temp_deriv
        n = len(P)  # number of entries in P

        # Compute the xi
        x = self.calculX3D(P, k)

        # Compute the N(i,k) coefficients: fill the table first, then collect the last column
        Nik_temp = [[-1] * k for _ in range(n)]
        for ctrl in range(n):
            Nik_temp[ctrl][-1] = self.N(ctrl, k, x)
        Nik = [Nik_temp[ctrl][-1] for ctrl in range(n)]

        # Compute the derivatives of the N(i,k) coefficients, same two-step scheme
        Nik_temp_deriv = [[-1] for _ in range(n)]
        for ctrl in range(n):
            Nik_temp_deriv[ctrl][-1] = self.Np(ctrl, k, x)
        Nikp = [Nik_temp_deriv[ctrl][-1] for ctrl in range(n)]

        # Compute the curve on a uniform parameter grid
        param = np.linspace(x[0], x[-1], prec)
        P_x, P_y, P_z, P_x_d, P_y_d, P_z_d = self.compute_curve_from_parametrization(P, k, x, Nik, Nikp, param)

        # reparametrization of the curve, based on the normalized cumulative length
        from msct_types import Centerline
        centerline = Centerline(P_x, P_y, P_z, P_x_d, P_y_d, P_z_d)
        segment_lengths = centerline.progressive_length
        uniform_grid = np.linspace(0.0, 1.0, prec)
        dist_curved = np.zeros(prec)
        for pt in range(1, prec):
            dist_curved[pt] = dist_curved[pt - 1] + segment_lengths[pt - 1] / centerline.length
        param = x[0] + (x[-1] - x[0]) * np.interp(uniform_grid, dist_curved, uniform_grid)
        P_x, P_y, P_z, P_x_d, P_y_d, P_z_d = self.compute_curve_from_parametrization(P, k, x, Nik, Nikp, param)

        if self.all_slices:
            P_z = np.array([int(np.round(z_value)) for z_value in P_z])

            # not perfect but works (if "enough" points), in order to deal with missing z slices
            for z_slice in range(min(P_z), max(P_z) + 1, 1):
                if z_slice in P_z:
                    continue
                # last point of the previous slice; the new point is inserted right after it,
                # with the average of its two neighbours for every other array
                before = np.where(P_z == z_slice - 1)[-1][-1]
                after = before + 1
                P_z_temp = np.insert(P_z, after, z_slice)
                P_x_temp = np.insert(P_x, after, (P_x[before] + P_x[after]) / 2)
                P_y_temp = np.insert(P_y, after, (P_y[before] + P_y[after]) / 2)
                P_x_d_temp = np.insert(P_x_d, after, (P_x_d[before] + P_x_d[after]) / 2)
                P_y_d_temp = np.insert(P_y_d, after, (P_y_d[before] + P_y_d[after]) / 2)
                P_z_d_temp = np.insert(P_z_d, after, (P_z_d[before] + P_z_d[after]) / 2)
                P_x, P_y, P_z = P_x_temp, P_y_temp, P_z_temp
                P_x_d, P_y_d, P_z_d = P_x_d_temp, P_y_d_temp, P_z_d_temp

            # average positions and derivatives over the points of each z slice
            slices = range(min(P_z), max(P_z) + 1, 1)
            coord_mean = np.array(
                [[np.mean(P_x[P_z == s]), np.mean(P_y[P_z == s]), s] for s in slices])
            coord_mean_d = np.array(
                [[np.mean(P_x_d[P_z == s]), np.mean(P_y_d[P_z == s]), np.mean(P_z_d[P_z == s])] for s in slices])

            P_x, P_y, P_z = coord_mean[:, 0], coord_mean[:, 1], coord_mean[:, 2]
            P_x_d, P_y_d, P_z_d = coord_mean_d[:, 0], coord_mean_d[:, 1], coord_mean_d[:, 2]

            # check if slice should be in the result, based on self.P_z
            indexes_to_remove = [pos for pos, ind_z in enumerate(P_z) if ind_z not in self.P_z]

            P_x = np.delete(P_x, indexes_to_remove)
            P_y = np.delete(P_y, indexes_to_remove)
            P_z = np.delete(P_z, indexes_to_remove)
            P_x_d = np.delete(P_x_d, indexes_to_remove)
            P_y_d = np.delete(P_y_d, indexes_to_remove)
            P_z_d = np.delete(P_z_d, indexes_to_remove)

        return [P_x, P_y, P_z], [P_x_d, P_y_d, P_z_d]
Exemplo n.º 9
0
def run_main():
    """
    Command-line entry point: extract the centerline of the spinal cord from an image.

    Arguments are read from ``sys.argv`` through ``get_parser()``:
        -i        input image (mandatory)
        -method   'optic' (default) or 'viewer'
        -c        contrast type; mandatory when the method is 'optic'
        -gap      gap between slices handed to the viewer (default 10.0)
        -ofolder  output folder (default './')
        -r        remove temporary files, 0 or 1 (default 1)
        -roi      also output a ROI file, 0 or 1 (default 0)
        -v        verbosity (default 0)

    With the 'viewer' method the image is reoriented to SAL in a temporary folder, points are picked manually,
    a NURBS centerline is fitted through them, and a centerline image, a text file (one "z x y" line per slice)
    and optionally a ROI file are copied to the output folder. With any other method the work is delegated to
    ``optic.detect_centerline``.

    :return: None
    """
    sct.start_stream_logger()
    parser = get_parser()
    args = sys.argv[1:]
    arguments = parser.parse(args)

    # Input filename
    fname_input_data = arguments["-i"]
    fname_data = os.path.abspath(fname_input_data)

    # Method used
    method = 'optic'
    if "-method" in arguments:
        method = arguments["-method"]

    # Contrast type
    contrast_type = ''
    if "-c" in arguments:
        contrast_type = arguments["-c"]
    if method == 'optic' and not contrast_type:
        # The contrast selects the OptiC model, so it cannot be omitted with this method
        error = 'ERROR: -c is a mandatory argument when using Optic method.'
        sct.printv(error, type='error')
        return

    # Gap between slices
    interslice_gap = 10.0
    if "-gap" in arguments:
        interslice_gap = float(arguments["-gap"])

    # Output folder
    if "-ofolder" in arguments:
        folder_output = sct.slash_at_the_end(arguments["-ofolder"], slash=1)
    else:
        folder_output = './'

    # Remove temporary files
    remove_temp_files = True
    if "-r" in arguments:
        remove_temp_files = bool(int(arguments["-r"]))

    # Outputs a ROI file
    output_roi = False
    if "-roi" in arguments:
        output_roi = bool(int(arguments["-roi"]))

    # Verbosity
    verbose = 0
    if "-v" in arguments:
        verbose = int(arguments["-v"])

    if method == 'viewer':
        _, file_data, ext_data = sct.extract_fname(fname_data)

        # create temporary folder
        temp_folder = sct.TempFolder()
        temp_folder.copy_from(fname_data)
        temp_folder.chdir()

        # make sure image is in SAL orientation, as it is the orientation used by the viewer
        image_input = Image(fname_data)
        image_input_orientation = orientation(image_input,
                                              get=True,
                                              verbose=False)
        reoriented_image_filename = sct.add_suffix(file_data + ext_data,
                                                   "_SAL")
        cmd_image = 'sct_image -i "%s" -o "%s" -setorient SAL -v 0' % (
            fname_data, reoriented_image_filename)
        sct.run(cmd_image, verbose=False)

        # extract points manually using the viewer
        fname_points = viewer_centerline(image_fname=reoriented_image_filename,
                                         interslice_gap=interslice_gap,
                                         verbose=verbose)

        if fname_points is not None:
            image_points_RPI = sct.add_suffix(fname_points, "_RPI")
            cmd_image = 'sct_image -i "%s" -o "%s" -setorient RPI -v 0' % (
                fname_points, image_points_RPI)
            sct.run(cmd_image, verbose=False)

            image_input_reoriented = Image(image_points_RPI)

            # fit centerline, smooth it and return the first derivative (in physical space)
            x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
                image_points_RPI,
                algo_fitting='nurbs',
                nurbs_pts_number=3000,
                phys_coordinates=True,
                verbose=verbose,
                all_slices=False)
            centerline = Centerline(x_centerline_fit, y_centerline_fit,
                                    z_centerline, x_centerline_deriv,
                                    y_centerline_deriv, z_centerline_deriv)

            # average centerline coordinates over slices of the image
            x_centerline_fit_rescorr, y_centerline_fit_rescorr, z_centerline_rescorr, x_centerline_deriv_rescorr, y_centerline_deriv_rescorr, z_centerline_deriv_rescorr = centerline.average_coordinates_over_slices(
                image_input_reoriented)

            # physical coordinates of the averaged centerline, shared by both conversions below
            physical_points = [[x, y, z] for x, y, z in zip(x_centerline_fit_rescorr,
                                                            y_centerline_fit_rescorr,
                                                            z_centerline_rescorr)]

            # compute z_centerline in image coordinates for usage in vertebrae mapping
            voxel_coordinates = image_input_reoriented.transfo_phys2pix(physical_points)
            x_centerline_voxel = [coord[0] for coord in voxel_coordinates]
            y_centerline_voxel = [coord[1] for coord in voxel_coordinates]
            z_centerline_voxel = [coord[2] for coord in voxel_coordinates]

            # compute z_centerline in image coordinates with continuous precision
            voxel_coordinates = image_input_reoriented.transfo_phys2continuouspix(physical_points)
            x_centerline_voxel_cont = [coord[0] for coord in voxel_coordinates]
            y_centerline_voxel_cont = [coord[1] for coord in voxel_coordinates]

            # Create an image with the centerline
            image_input_reoriented.data *= 0
            min_z_index = int(round(min(z_centerline_voxel)))
            max_z_index = int(round(max(z_centerline_voxel)))
            for iz in range(min_z_index, max_z_index + 1):
                # if index is out of bounds here for hanning: either the segmentation has holes or labels have
                # been added to the file
                image_input_reoriented.data[
                    int(round(x_centerline_voxel[iz - min_z_index])),
                    int(round(y_centerline_voxel[iz - min_z_index])),
                    int(iz)] = 1

            # Write the centerline image
            sct.printv('\nWrite NIFTI volumes...', verbose)
            fname_centerline_oriented = file_data + '_centerline' + ext_data
            image_input_reoriented.setFileName(fname_centerline_oriented)
            image_input_reoriented.changeType('uint8')
            image_input_reoriented.save()

            sct.printv('\nSet to original orientation...', verbose)
            # file names are quoted, like the other sct_image calls above, so that paths with spaces survive
            cmd_image = 'sct_image -i "%s" -setorient %s -o "%s"' % (
                fname_centerline_oriented, image_input_orientation, fname_centerline_oriented)
            sct.run(cmd_image)

            # create a txt file with the centerline; the context manager closes it even if a write fails
            fname_centerline_oriented_txt = file_data + '_centerline.txt'
            with open(fname_centerline_oriented_txt, 'w') as file_results:
                for i in range(min_z_index, max_z_index + 1):
                    file_results.write(
                        str(int(i)) + ' ' +
                        str(round(x_centerline_voxel_cont[i - min_z_index], 2)) +
                        ' ' +
                        str(round(y_centerline_voxel_cont[i - min_z_index], 2)) +
                        '\n')

            fname_centerline_oriented_roi = optic.centerline2roi(
                fname_image=fname_centerline_oriented,
                folder_output='./',
                verbose=verbose)

            # return to initial folder
            temp_folder.chdir_undo()

            # copy result to output folder
            shutil.copy(temp_folder.get_path() + fname_centerline_oriented,
                        folder_output)
            shutil.copy(temp_folder.get_path() + fname_centerline_oriented_txt,
                        folder_output)
            if output_roi:
                shutil.copy(
                    temp_folder.get_path() + fname_centerline_oriented_roi,
                    folder_output)
            centerline_filename = folder_output + fname_centerline_oriented

        else:
            # no points were returned by the viewer: still leave the temporary folder, otherwise the process
            # would keep a working directory that is deleted by the cleanup below
            temp_folder.chdir_undo()
            centerline_filename = 'error'

        # delete temporary folder
        if remove_temp_files:
            temp_folder.cleanup()

    else:
        # condition on verbose when using OptiC
        if verbose == 1:
            verbose = 2

        # OptiC models
        path_script = os.path.dirname(__file__)
        path_sct = os.path.dirname(path_script)
        optic_models_path = os.path.join(path_sct, 'data/optic_models',
                                         '{}_model'.format(contrast_type))

        # Execute OptiC binary
        _, centerline_filename = optic.detect_centerline(
            image_fname=fname_data,
            contrast_type=contrast_type,
            optic_models_path=optic_models_path,
            folder_output=folder_output,
            remove_temp_files=remove_temp_files,
            output_roi=output_roi,
            verbose=verbose)

    # NOTE(review): when the viewer returned no points, centerline_filename is the placeholder 'error' and the
    # command printed below is not usable; kept as in the original output.
    sct.printv('\nDone! To view results, type:', verbose)
    sct.printv(
        "fslview " + fname_input_data + " " + centerline_filename +
        " -l Red -b 0,1 -t 0.7 &\n", verbose, 'info')
Exemplo n.º 10
0
def compute_properties_along_centerline(im_seg, smooth_factor=5.0, interpolation_mode=0, algo_fitting='hanning',
                                        window_length=50, size_patch=7, remove_temp_files=1, verbose=1):
    """
    Compute shape properties along the spinal cord centerline. The centerline is fitted on the segmentation, a 2D
    patch orthogonal to the centerline is extracted for each processed slice, and the shape properties are
    measured on that patch.
    :param im_seg: Image of segmentation, already oriented in RPI
    :param smooth_factor: currently unused (only referenced by smoothing code that has been disabled)
    :param interpolation_mode: interpolation mode forwarded to Centerline.extract_perpendicular_square
    :param algo_fitting: fitting algorithm forwarded to smooth_centerline
    :param window_length: window length forwarded to smooth_centerline
    :param size_patch: size of the extracted patch, forwarded to Centerline.extract_perpendicular_square
    :param remove_temp_files: currently unused
    :param verbose: verbosity forwarded to smooth_centerline
    :return: dict mapping each property name ('area', 'equivalent_diameter', 'AP_diameter', 'RL_diameter',
             'ratio_minor_major', 'eccentricity', 'solidity', 'orientation') to a 1D array with one element per
             slice of im_seg; slices for which nothing was computed hold NaN
    """
    # TODO: put size_patch back to 20 (was put to 7 for debugging purpose)
    # List of properties to output (in the right order)
    property_list = ['area',
                     'equivalent_diameter',
                     'AP_diameter',
                     'RL_diameter',
                     'ratio_minor_major',
                     'eccentricity',
                     'solidity',
                     'orientation']

    # Only the number of slices and the in-plane pixel sizes are needed from the image dimensions
    _, _, nz, _, px, py, _, _ = im_seg.dim

    # Extract min and max index in Z direction
    _, _, z_indices = (im_seg.data > 0).nonzero()
    min_z_index, max_z_index = z_indices.min(), z_indices.max()

    # Define the resampling resolution. Here, we take the minimum of half the pixel size along X or Y in order to have
    # sufficient precision upon resampling. Since we want isotropic resampling, we take the min between the two dims.
    resolution = min(float(px) / 2, float(py) / 2)

    # Initialize 1d array with nan. Each element corresponds to a slice.
    properties = {key: np.full(nz, np.nan, dtype=np.double) for key in property_list}

    # compute the spinal cord centerline based on the spinal cord segmentation
    number_of_points = nz  # 5 * nz
    x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = \
        smooth_centerline(im_seg, algo_fitting=algo_fitting, window_length=window_length,
                          verbose=verbose, nurbs_pts_number=number_of_points, all_slices=False, phys_coordinates=True,
                          remove_outliers=True)
    centerline = Centerline(x_centerline_fit, y_centerline_fit, z_centerline,
                            x_centerline_deriv, y_centerline_deriv, z_centerline_deriv)

    # value given to patch pixels that fall outside of the image
    value_out = 0.0

    # NOTE(review): the upper bound max_z_index - 1 leaves the last two segmented slices at NaN. It is kept as in
    # the original because the number of points of the fitted centerline cannot be checked from here - TODO confirm.
    slices_to_process = range(min_z_index, max_z_index - 1)

    sct.printv('Computing spinal cord shape along the spinal cord...')
    # the progress bar total is derived from the very range that is iterated, so that it can reach 100%
    with tqdm.tqdm(total=len(slices_to_process)) as pbar:

        # Extracting patches perpendicular to the spinal cord and computing spinal cord shape.
        # i_centerline is the index of the point in the centerline object, iz the slice index in im_seg.
        for i_centerline, iz in enumerate(slices_to_process):
            # TODO: correct for angulation using the cosine. The current approach has 2 issues:
            # - the centerline is not homogeneously sampled along z (which is the reason it is oversampled)
            # - computationally expensive
            # - requires resampling to higher resolution --> to check: maybe required with cosine approach
            current_patch = centerline.extract_perpendicular_square(im_seg, i_centerline, size=size_patch,
                                                                    resolution=resolution,
                                                                    interpolation_mode=interpolation_mode,
                                                                    border='constant', cval=value_out)

            # pixels that are out of the image are set to zero before measuring the shape
            patch_zero = np.copy(current_patch)
            patch_zero[patch_zero == value_out] = 0.0

            # compute shape properties on 2D patch
            sc_properties = properties2d(patch_zero, [resolution, resolution])
            # assign AP and RL to minor or major axis, depending on the orientation. The code below expects that
            # sc_properties may be None, so the helper is only called when there is something to re-label.
            if sc_properties is not None:
                sc_properties = assign_AP_and_RL_diameter(sc_properties)

            # loop across properties and assign values for function output
            if sc_properties is not None:
                for property_name in property_list:
                    properties[property_name][iz] = sc_properties[property_name]
            else:
                c = im_seg.transfo_phys2pix([centerline.points[i_centerline]])[0]
                # the slice index belongs in the message: the second positional argument of printv is the
                # verbosity, not an extra value to print
                sct.printv('WARNING: no properties for slice {}'.format(c[2]))

            pbar.update(1)

    return properties
Exemplo n.º 11
0
    def straighten(self):
        """
        Straighten spinal cord. Steps: (everything is done in physical space)
        1. open input image and centreline image
        2. extract bspline fitting of the centreline, and its derivatives
        3. compute length of centerline
        4. compute and generate straight space
        5. compute transformations
            for each voxel of one space: (done using matrices --> improves speed by a factor x300)
                a. determine which plane of spinal cord centreline it is included
                b. compute the position of the voxel in the plane (X and Y distance from centreline, along the plane)
                c. find the correspondant centreline point in the other space
                d. find the correspondance of the voxel in the corresponding plane
        6. generate warping fields for each transformations
        7. write warping fields and apply them

        step 5.b: how to find the corresponding plane?
            The centerline plane corresponding to a voxel correspond to the nearest point of the centerline.
            However, we need to compute the distance between the voxel position and the plane to be sure it is part of the plane and not too distant.
            If it is more far than a threshold, warping value should be 0.

        step 5.d: how to make the correspondance between centerline point in both images?
            Both centerline have the same lenght. Therefore, we can map centerline point via their position along the curve.
            If we use the same number of points uniformely along the spinal cord (1000 for example), the correspondance is straight-forward.

        :return:
        """
        # Initialization
        fname_anat = self.input_filename
        fname_centerline = self.centerline_filename
        fname_output = self.output_filename
        remove_temp_files = self.remove_temp_files
        verbose = self.verbose
        interpolation_warp = self.interpolation_warp
        algo_fitting = self.algo_fitting

        # start timer
        start_time = time.time()

        # Extract path/file/extension
        path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)

        path_tmp = sct.tmp_create(basename="straighten_spinalcord", verbose=verbose)

        # Copying input data to tmp folder
        sct.printv('\nCopy files to tmp folder...', verbose)
        Image(fname_anat).save(os.path.join(path_tmp, "data.nii"))
        Image(fname_centerline).save(os.path.join(path_tmp, "centerline.nii.gz"))

        if self.use_straight_reference:
            Image(self.centerline_reference_filename).save(os.path.join(path_tmp, "centerline_ref.nii.gz"))
        if self.discs_input_filename != '':
            Image(self.discs_input_filename).save(os.path.join(path_tmp, "labels_input.nii.gz"))
        if self.discs_ref_filename != '':
            Image(self.discs_ref_filename).save(os.path.join(path_tmp, "labels_ref.nii.gz"))

        # go to tmp folder
        curdir = os.getcwd()
        os.chdir(path_tmp)

        # Change orientation of the input centerline into RPI
        image_centerline = Image("centerline.nii.gz").change_orientation("RPI").save("centerline_rpi.nii.gz",
                                                                                     mutable=True)

        # Get dimension
        nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim
        if self.speed_factor != 1.0:
            intermediate_resampling = True
            px_r, py_r, pz_r = px * self.speed_factor, py * self.speed_factor, pz * self.speed_factor
        else:
            intermediate_resampling = False

        if intermediate_resampling:
            sct.mv('centerline_rpi.nii.gz', 'centerline_rpi_native.nii.gz')
            pz_native = pz
            # TODO: remove system call
            sct.run(['sct_resample', '-i', 'centerline_rpi_native.nii.gz', '-mm',
                     str(px_r) + 'x' + str(py_r) + 'x' + str(pz_r), '-o', 'centerline_rpi.nii.gz'])
            image_centerline = Image('centerline_rpi.nii.gz')
            nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim

        if np.min(image_centerline.data) < 0 or np.max(image_centerline.data) > 1:
            image_centerline.data[image_centerline.data < 0] = 0
            image_centerline.data[image_centerline.data > 1] = 1
            image_centerline.save()

        # 2. extract bspline fitting of the centerline, and its derivatives
        img_ctl = Image('centerline_rpi.nii.gz')
        centerline = _get_centerline(img_ctl, algo_fitting, self.degree, verbose)
        number_of_points = centerline.number_of_points

        # ==========================================================================================
        logger.info('Create the straight space and the safe zone')
        # 3. compute length of centerline
        # compute the length of the spinal cord based on fitted centerline and size of centerline in z direction

        # Computation of the safe zone.
        # The safe zone is defined as the length of the spinal cord for which an axial segmentation will be complete
        # The safe length (to remove) is computed using the safe radius (given as parameter) and the angle of the
        # last centerline point with the inferior-superior direction. Formula: Ls = Rs * sin(angle)
        # Calculate Ls for both edges and remove appropriate number of centerline points
        radius_safe = 0.0  # mm

        # inferior edge
        u = centerline.derivatives[0]
        v = np.array([0, 0, -1])

        angle_inferior = np.arctan2(np.linalg.norm(np.cross(u, v)), np.dot(u, v))
        length_safe_inferior = radius_safe * np.sin(angle_inferior)

        # superior edge
        u = centerline.derivatives[-1]
        v = np.array([0, 0, 1])
        angle_superior = np.arctan2(np.linalg.norm(np.cross(u, v)), np.dot(u, v))
        length_safe_superior = radius_safe * np.sin(angle_superior)

        # remove points
        inferior_bound = bisect.bisect(centerline.progressive_length, length_safe_inferior) - 1
        superior_bound = centerline.number_of_points - bisect.bisect(centerline.progressive_length_inverse,
                                                                     length_safe_superior)

        z_centerline = centerline.points[:, 2]
        length_centerline = centerline.length
        size_z_centerline = z_centerline[-1] - z_centerline[0]

        # compute the size factor between initial centerline and straight bended centerline
        factor_curved_straight = length_centerline / size_z_centerline
        middle_slice = (z_centerline[0] + z_centerline[-1]) / 2.0

        bound_curved = [z_centerline[inferior_bound], z_centerline[superior_bound]]
        bound_straight = [(z_centerline[inferior_bound] - middle_slice) * factor_curved_straight + middle_slice,
                          (z_centerline[superior_bound] - middle_slice) * factor_curved_straight + middle_slice]

        logger.info('Length of spinal cord: {}'.format(length_centerline))
        logger.info('Size of spinal cord in z direction: {}'.format(size_z_centerline))
        logger.info('Ratio length/size: {}'.format(factor_curved_straight))
        logger.info('Safe zone boundaries (curved space): {}'.format(bound_curved))
        logger.info('Safe zone boundaries (straight space): {}'.format(bound_straight))

        # 4. compute and generate straight space
        # points along curved centerline are already regularly spaced.
        # calculate position of points along straight centerline

        # Create straight NIFTI volumes.
        # ==========================================================================================
        # TODO: maybe this if case is not needed?
        if self.use_straight_reference:
            image_centerline_pad = Image('centerline_rpi.nii.gz')
            nx, ny, nz, nt, px, py, pz, pt = image_centerline_pad.dim

            fname_ref = 'centerline_ref_rpi.nii.gz'
            image_centerline_straight = Image('centerline_ref.nii.gz') \
                .change_orientation("RPI") \
                .save(fname_ref, mutable=True)
            centerline_straight = _get_centerline(image_centerline_straight, algo_fitting, self.degree, verbose)
            nx_s, ny_s, nz_s, nt_s, px_s, py_s, pz_s, pt_s = image_centerline_straight.dim

            # Prepare warping fields headers
            hdr_warp = image_centerline_pad.hdr.copy()
            hdr_warp.set_data_dtype('float32')
            hdr_warp_s = image_centerline_straight.hdr.copy()
            hdr_warp_s.set_data_dtype('float32')

            if self.discs_input_filename != "" and self.discs_ref_filename != "":
                discs_input_image = Image('labels_input.nii.gz')
                coord = discs_input_image.getNonZeroCoordinates(sorting='z', reverse_coord=True)
                coord_physical = []
                for c in coord:
                    c_p = discs_input_image.transfo_pix2phys([[c.x, c.y, c.z]]).tolist()[0]
                    c_p.append(c.value)
                    coord_physical.append(c_p)
                centerline.compute_vertebral_distribution(coord_physical)
                centerline.save_centerline(image=discs_input_image, fname_output='discs_input_image.nii.gz')

                discs_ref_image = Image('labels_ref.nii.gz')
                coord = discs_ref_image.getNonZeroCoordinates(sorting='z', reverse_coord=True)
                coord_physical = []
                for c in coord:
                    c_p = discs_ref_image.transfo_pix2phys([[c.x, c.y, c.z]]).tolist()[0]
                    c_p.append(c.value)
                    coord_physical.append(c_p)
                centerline_straight.compute_vertebral_distribution(coord_physical)
                centerline_straight.save_centerline(image=discs_ref_image, fname_output='discs_ref_image.nii.gz')

        else:
            logger.info('Pad input volume to account for spinal cord length...')

            start_point, end_point = bound_straight[0], bound_straight[1]
            offset_z = 0

            # if the destination image is resampled, we still create the straight reference space with the native
            # resolution.
            # TODO: Maybe this if case is not needed?
            if intermediate_resampling:
                padding_z = int(np.ceil(1.5 * ((length_centerline - size_z_centerline) / 2.0) / pz_native))
                sct.run(
                    ['sct_image', '-i', 'centerline_rpi_native.nii.gz', '-o', 'tmp.centerline_pad_native.nii.gz',
                     '-pad', '0,0,' + str(padding_z)])
                image_centerline_pad = Image('centerline_rpi_native.nii.gz')
                nx, ny, nz, nt, px, py, pz, pt = image_centerline_pad.dim
                start_point_coord_native = image_centerline_pad.transfo_phys2pix([[0, 0, start_point]])[0]
                end_point_coord_native = image_centerline_pad.transfo_phys2pix([[0, 0, end_point]])[0]
                straight_size_x = int(self.xy_size / px)
                straight_size_y = int(self.xy_size / py)
                warp_space_x = [int(np.round(nx / 2)) - straight_size_x, int(np.round(nx / 2)) + straight_size_x]
                warp_space_y = [int(np.round(ny / 2)) - straight_size_y, int(np.round(ny / 2)) + straight_size_y]
                if warp_space_x[0] < 0:
                    warp_space_x[1] += warp_space_x[0] - 2
                    warp_space_x[0] = 0
                if warp_space_y[0] < 0:
                    warp_space_y[1] += warp_space_y[0] - 2
                    warp_space_y[0] = 0

                spec = dict((
                    (0, warp_space_x),
                    (1, warp_space_y),
                    (2, (0, end_point_coord_native[2] - start_point_coord_native[2])),
                ))
                msct_image.spatial_crop(Image("tmp.centerline_pad_native.nii.gz"), spec).save(
                    "tmp.centerline_pad_crop_native.nii.gz")

                fname_ref = 'tmp.centerline_pad_crop_native.nii.gz'
                offset_z = 4
            else:
                fname_ref = 'tmp.centerline_pad_crop.nii.gz'

            nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim
            padding_z = int(np.ceil(1.5 * ((length_centerline - size_z_centerline) / 2.0) / pz)) + offset_z
            image_centerline_pad = pad_image(image_centerline, pad_z_i=padding_z, pad_z_f=padding_z)
            nx, ny, nz = image_centerline_pad.data.shape
            hdr_warp = image_centerline_pad.hdr.copy()
            hdr_warp.set_data_dtype('float32')
            start_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, start_point]])[0]
            end_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, end_point]])[0]

            straight_size_x = int(self.xy_size / px)
            straight_size_y = int(self.xy_size / py)
            warp_space_x = [int(np.round(nx / 2)) - straight_size_x, int(np.round(nx / 2)) + straight_size_x]
            warp_space_y = [int(np.round(ny / 2)) - straight_size_y, int(np.round(ny / 2)) + straight_size_y]

            if warp_space_x[0] < 0:
                warp_space_x[1] += warp_space_x[0] - 2
                warp_space_x[0] = 0
            if warp_space_x[1] >= nx:
                warp_space_x[1] = nx - 1
            if warp_space_y[0] < 0:
                warp_space_y[1] += warp_space_y[0] - 2
                warp_space_y[0] = 0
            if warp_space_y[1] >= ny:
                warp_space_y[1] = ny - 1

            spec = dict((
                (0, warp_space_x),
                (1, warp_space_y),
                (2, (0, end_point_coord[2] - start_point_coord[2] + offset_z)),
            ))
            image_centerline_straight = msct_image.spatial_crop(image_centerline_pad, spec)

            nx_s, ny_s, nz_s, nt_s, px_s, py_s, pz_s, pt_s = image_centerline_straight.dim
            hdr_warp_s = image_centerline_straight.hdr.copy()
            hdr_warp_s.set_data_dtype('float32')

            if self.template_orientation == 1:
                raise NotImplementedError()

            start_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, start_point]])[0]
            end_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, end_point]])[0]

            number_of_voxel = nx * ny * nz
            logger.debug('Number of voxels: {}'.format(number_of_voxel))

            time_centerlines = time.time()

            coord_straight = np.empty((number_of_points, 3))
            coord_straight[..., 0] = int(np.round(nx_s / 2))
            coord_straight[..., 1] = int(np.round(ny_s / 2))
            coord_straight[..., 2] = np.linspace(0, end_point_coord[2] - start_point_coord[2], number_of_points)
            coord_phys_straight = image_centerline_straight.transfo_pix2phys(coord_straight)
            derivs_straight = np.empty((number_of_points, 3))
            derivs_straight[..., 0] = derivs_straight[..., 1] = 0
            derivs_straight[..., 2] = 1
            dx_straight, dy_straight, dz_straight = derivs_straight.T
            centerline_straight = Centerline(coord_phys_straight[:, 0], coord_phys_straight[:, 1],
                                             coord_phys_straight[:, 2],
                                             dx_straight, dy_straight, dz_straight)

            time_centerlines = time.time() - time_centerlines
            logger.info('Time to generate centerline: {} ms'.format(np.round(time_centerlines * 1000.0)))

        if verbose == 2:
            # TODO: use OO
            import matplotlib.pyplot as plt
            from datetime import datetime
            curved_points = centerline.progressive_length
            straight_points = centerline_straight.progressive_length
            range_points = np.linspace(0, 1, number_of_points)
            dist_curved = np.zeros(number_of_points)
            dist_straight = np.zeros(number_of_points)
            for i in range(1, number_of_points):
                dist_curved[i] = dist_curved[i - 1] + curved_points[i - 1] / centerline.length
                dist_straight[i] = dist_straight[i - 1] + straight_points[i - 1] / centerline_straight.length
            plt.plot(range_points, dist_curved)
            plt.plot(range_points, dist_straight)
            plt.grid(True)
            plt.savefig('fig_straighten_' + datetime.now().strftime("%y%m%d%H%M%S%f") + '.png')
            plt.close()

        # alignment_mode = 'length'
        alignment_mode = 'levels'

        lookup_curved2straight = list(range(centerline.number_of_points))
        if self.discs_input_filename != "":
            # create look-up table curved to straight
            for index in range(centerline.number_of_points):
                disc_label = centerline.l_points[index]
                if alignment_mode == 'length':
                    relative_position = centerline.dist_points[index]
                else:
                    relative_position = centerline.dist_points_rel[index]
                idx_closest = centerline_straight.get_closest_to_absolute_position(disc_label, relative_position,
                                                                                   backup_index=index,
                                                                                   backup_centerline=centerline_straight,
                                                                                   mode=alignment_mode)
                if idx_closest is not None:
                    lookup_curved2straight[index] = idx_closest
                else:
                    lookup_curved2straight[index] = 0
        for p in range(0, len(lookup_curved2straight) // 2):
            if lookup_curved2straight[p] == lookup_curved2straight[p + 1]:
                lookup_curved2straight[p] = 0
            else:
                break
        for p in range(len(lookup_curved2straight) - 1, len(lookup_curved2straight) // 2, -1):
            if lookup_curved2straight[p] == lookup_curved2straight[p - 1]:
                lookup_curved2straight[p] = 0
            else:
                break
        lookup_curved2straight = np.array(lookup_curved2straight)

        lookup_straight2curved = list(range(centerline_straight.number_of_points))
        if self.discs_input_filename != "":
            for index in range(centerline_straight.number_of_points):
                disc_label = centerline_straight.l_points[index]
                if alignment_mode == 'length':
                    relative_position = centerline_straight.dist_points[index]
                else:
                    relative_position = centerline_straight.dist_points_rel[index]
                idx_closest = centerline.get_closest_to_absolute_position(disc_label, relative_position,
                                                                          backup_index=index,
                                                                          backup_centerline=centerline_straight,
                                                                          mode=alignment_mode)
                if idx_closest is not None:
                    lookup_straight2curved[index] = idx_closest
        for p in range(0, len(lookup_straight2curved) // 2):
            if lookup_straight2curved[p] == lookup_straight2curved[p + 1]:
                lookup_straight2curved[p] = 0
            else:
                break
        for p in range(len(lookup_straight2curved) - 1, len(lookup_straight2curved) // 2, -1):
            if lookup_straight2curved[p] == lookup_straight2curved[p - 1]:
                lookup_straight2curved[p] = 0
            else:
                break
        lookup_straight2curved = np.array(lookup_straight2curved)

        # Create volumes containing curved and straight warping fields
        data_warp_curved2straight = np.zeros((nx_s, ny_s, nz_s, 1, 3))
        data_warp_straight2curved = np.zeros((nx, ny, nz, 1, 3))

        # 5. compute transformations
        # Curved and straight images and the same dimensions, so we compute both warping fields at the same time.
        # b. determine which plane of the spinal cord centreline each voxel belongs to
        # sct.printv(nx * ny * nz, nx_s * ny_s * nz_s)

        if self.curved2straight:
            for u in tqdm(range(nz_s)):
                x_s, y_s, z_s = np.mgrid[0:nx_s, 0:ny_s, u:u + 1]
                indexes_straight = np.array(list(zip(x_s.ravel(), y_s.ravel(), z_s.ravel())))
                physical_coordinates_straight = image_centerline_straight.transfo_pix2phys(indexes_straight)
                nearest_indexes_straight = centerline_straight.find_nearest_indexes(physical_coordinates_straight)
                distances_straight = centerline_straight.get_distances_from_planes(physical_coordinates_straight,
                                                                                   nearest_indexes_straight)
                lookup = lookup_straight2curved[nearest_indexes_straight]
                indexes_out_distance_straight = np.logical_or(
                    np.logical_or(distances_straight > self.threshold_distance,
                                  distances_straight < -self.threshold_distance), lookup == 0)
                projected_points_straight = centerline_straight.get_projected_coordinates_on_planes(
                    physical_coordinates_straight, nearest_indexes_straight)
                coord_in_planes_straight = centerline_straight.get_in_plans_coordinates(projected_points_straight,
                                                                                        nearest_indexes_straight)

                coord_straight2curved = centerline.get_inverse_plans_coordinates(coord_in_planes_straight, lookup)
                displacements_straight = coord_straight2curved - physical_coordinates_straight
                # Invert Z coordinate as ITK & ANTs physical coordinate system is LPS- (RAI+)
                # while ours is LPI-
                # Refs: https://sourceforge.net/p/advants/discussion/840261/thread/2a1e9307/#fb5a
                #  https://www.slicer.org/wiki/Coordinate_systems
                displacements_straight[:, 2] = -displacements_straight[:, 2]
                displacements_straight[indexes_out_distance_straight] = [100000.0, 100000.0, 100000.0]

                data_warp_curved2straight[indexes_straight[:, 0], indexes_straight[:, 1], indexes_straight[:, 2], 0, :]\
                    = -displacements_straight

        if self.straight2curved:
            for u in tqdm(range(nz)):
                x, y, z = np.mgrid[0:nx, 0:ny, u:u + 1]
                indexes = np.array(list(zip(x.ravel(), y.ravel(), z.ravel())))
                physical_coordinates = image_centerline_pad.transfo_pix2phys(indexes)
                nearest_indexes_curved = centerline.find_nearest_indexes(physical_coordinates)
                distances_curved = centerline.get_distances_from_planes(physical_coordinates,
                                                                        nearest_indexes_curved)
                lookup = lookup_curved2straight[nearest_indexes_curved]
                indexes_out_distance_curved = np.logical_or(
                    np.logical_or(distances_curved > self.threshold_distance,
                                  distances_curved < -self.threshold_distance), lookup == 0)
                projected_points_curved = centerline.get_projected_coordinates_on_planes(physical_coordinates,
                                                                                         nearest_indexes_curved)
                coord_in_planes_curved = centerline.get_in_plans_coordinates(projected_points_curved,
                                                                             nearest_indexes_curved)

                coord_curved2straight = centerline_straight.points[lookup]
                coord_curved2straight[:, 0:2] += coord_in_planes_curved[:, 0:2]
                coord_curved2straight[:, 2] += distances_curved

                displacements_curved = coord_curved2straight - physical_coordinates

                displacements_curved[:, 2] = -displacements_curved[:, 2]
                displacements_curved[indexes_out_distance_curved] = [100000.0, 100000.0, 100000.0]

                data_warp_straight2curved[indexes[:, 0], indexes[:, 1], indexes[:, 2], 0, :] = -displacements_curved

        # Creation of the safe zone based on pre-calculated safe boundaries
        coord_bound_curved_inf, coord_bound_curved_sup = image_centerline_pad.transfo_phys2pix(
            [[0, 0, bound_curved[0]]]), image_centerline_pad.transfo_phys2pix([[0, 0, bound_curved[1]]])
        coord_bound_straight_inf, coord_bound_straight_sup = image_centerline_straight.transfo_phys2pix(
            [[0, 0, bound_straight[0]]]), image_centerline_straight.transfo_phys2pix([[0, 0, bound_straight[1]]])

        if radius_safe > 0:
            data_warp_curved2straight[:, :, 0:coord_bound_straight_inf[0][2], 0, :] = 100000.0
            data_warp_curved2straight[:, :, coord_bound_straight_sup[0][2]:, 0, :] = 100000.0
            data_warp_straight2curved[:, :, 0:coord_bound_curved_inf[0][2], 0, :] = 100000.0
            data_warp_straight2curved[:, :, coord_bound_curved_sup[0][2]:, 0, :] = 100000.0

        # Generate warp files as a warping fields
        hdr_warp_s.set_intent('vector', (), '')
        hdr_warp_s.set_data_dtype('float32')
        hdr_warp.set_intent('vector', (), '')
        hdr_warp.set_data_dtype('float32')
        if self.curved2straight:
            img = Nifti1Image(data_warp_curved2straight, None, hdr_warp_s)
            save(img, 'tmp.curve2straight.nii.gz')
            logger.info('Warping field generated: tmp.curve2straight.nii.gz')

        if self.straight2curved:
            img = Nifti1Image(data_warp_straight2curved, None, hdr_warp)
            save(img, 'tmp.straight2curve.nii.gz')
            logger.info('Warping field generated: tmp.straight2curve.nii.gz')

        image_centerline_straight.save(fname_ref)
        if self.curved2straight:
            logger.info('Apply transformation to input image...')
            sct.run(['isct_antsApplyTransforms',
                     '-d', '3',
                     '-r', fname_ref,
                     '-i', 'data.nii',
                     '-o', 'tmp.anat_rigid_warp.nii.gz',
                     '-t', 'tmp.curve2straight.nii.gz',
                     '-n', 'BSpline[3]'],
                    is_sct_binary=True,
                    verbose=verbose)

        if self.accuracy_results:
            time_accuracy_results = time.time()
            # compute the error between the straightened centerline/segmentation and the central vertical line.
            # Ideally, the error should be zero.
            # Apply deformation to input image
            logger.info('Apply transformation to centerline image...')
            sct.run(['isct_antsApplyTransforms',
                     '-d', '3',
                     '-r', fname_ref,
                     '-i', 'centerline.nii.gz',
                     '-o', 'tmp.centerline_straight.nii.gz',
                     '-t', 'tmp.curve2straight.nii.gz',
                     '-n', 'NearestNeighbor'],
                    is_sct_binary=True,
                    verbose=verbose)
            file_centerline_straight = Image('tmp.centerline_straight.nii.gz', verbose=verbose)
            nx, ny, nz, nt, px, py, pz, pt = file_centerline_straight.dim
            coordinates_centerline = file_centerline_straight.getNonZeroCoordinates(sorting='z')
            mean_coord = []
            for z in range(coordinates_centerline[0].z, coordinates_centerline[-1].z):
                temp_mean = [coord.value for coord in coordinates_centerline if coord.z == z]
                if temp_mean:
                    mean_value = np.mean(temp_mean)
                    mean_coord.append(
                        np.mean([[coord.x * coord.value / mean_value, coord.y * coord.value / mean_value]
                                 for coord in coordinates_centerline if coord.z == z], axis=0))

            # compute error between the straightened centerline and the straight line.
            x0 = file_centerline_straight.data.shape[0] / 2.0
            y0 = file_centerline_straight.data.shape[1] / 2.0
            count_mean = 0
            if number_of_points >= 10:
                mean_c = mean_coord[2:-2]  # we don't include the four extrema because there are usually messy.
            else:
                mean_c = mean_coord
            for coord_z in mean_c:
                if not np.isnan(np.sum(coord_z)):
                    dist = ((x0 - coord_z[0]) * px) ** 2 + ((y0 - coord_z[1]) * py) ** 2
                    self.mse_straightening += dist
                    dist = np.sqrt(dist)
                    if dist > self.max_distance_straightening:
                        self.max_distance_straightening = dist
                    count_mean += 1
            self.mse_straightening = np.sqrt(self.mse_straightening / float(count_mean))

            self.elapsed_time_accuracy = time.time() - time_accuracy_results

        os.chdir(curdir)

        # Generate output file (in current folder)
        # TODO: do not uncompress the warping field, it is too time consuming!
        logger.info('Generate output files...')
        if self.curved2straight:
            sct.generate_output_file(os.path.join(path_tmp, "tmp.curve2straight.nii.gz"),
                                     os.path.join(self.path_output, "warp_curve2straight.nii.gz"), verbose)
        if self.straight2curved:
            sct.generate_output_file(os.path.join(path_tmp, "tmp.straight2curve.nii.gz"),
                                     os.path.join(self.path_output, "warp_straight2curve.nii.gz"), verbose)

        # create ref_straight.nii.gz file that can be used by other SCT functions that need a straight reference space
        if self.curved2straight:
            sct.copy(os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
                     os.path.join(self.path_output, "straight_ref.nii.gz"))
            # move straightened input file
            if fname_output == '':
                fname_straight = sct.generate_output_file(os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
                                                          os.path.join(self.path_output,
                                                                       file_anat + "_straight" + ext_anat), verbose)
            else:
                fname_straight = sct.generate_output_file(os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
                                                          os.path.join(self.path_output, fname_output),
                                                          verbose)  # straightened anatomic

        # Remove temporary files
        if remove_temp_files:
            logger.info('Remove temporary files...')
            sct.rmtree(path_tmp)

        if self.accuracy_results:
            logger.info('Maximum x-y error: {} mm'.format(self.max_distance_straightening))
            logger.info('Accuracy of straightening (MSE): {} mm'.format(self.mse_straightening))

        # display elapsed time
        self.elapsed_time = int(np.round(time.time() - start_time))

        return fname_straight
Exemplo n.º 12
0
    def straighten(self):
        # Initialization
        fname_anat = self.input_filename
        fname_centerline = self.centerline_filename
        fname_output = self.output_filename
        remove_temp_files = self.remove_temp_files
        verbose = self.verbose
        interpolation_warp = self.interpolation_warp
        algo_fitting = self.algo_fitting

        # start timer
        start_time = time.time()

        # get path of the toolbox
        path_sct = os.environ.get("SCT_DIR", os.path.dirname(os.path.dirname(__file__)))
        sct.printv(path_sct, verbose)

        # Display arguments
        sct.printv("\nCheck input arguments:", verbose)
        sct.printv("  Input volume ...................... " + fname_anat, verbose)
        sct.printv("  Centerline ........................ " + fname_centerline, verbose)
        sct.printv("  Final interpolation ............... " + interpolation_warp, verbose)
        sct.printv("  Verbose ........................... " + str(verbose), verbose)
        sct.printv("", verbose)

        # Extract path/file/extension
        path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)
        path_centerline, file_centerline, ext_centerline = sct.extract_fname(fname_centerline)

        path_tmp = sct.tmp_create(basename="straighten_spinalcord", verbose=verbose)

        # Copying input data to tmp folder
        sct.printv('\nCopy files to tmp folder...', verbose)
        img_src = Image(fname_anat).save(os.path.join(path_tmp, "data.nii"))
        img_centerline = Image(fname_centerline).save(os.path.join(path_tmp, "centerline.nii.gz"))

        if self.use_straight_reference:
            Image(self.centerline_reference_filename).save(os.path.join(path_tmp, "centerline_ref.nii.gz"))
        if self.discs_input_filename != '':
            Image(self.discs_input_filename).save(os.path.join(path_tmp, "labels_input.nii.gz"))
        if self.discs_ref_filename != '':
            Image(self.discs_ref_filename).save(os.path.join(path_tmp, "labels_ref.nii.gz"))

        # go to tmp folder
        curdir = os.getcwd()
        os.chdir(path_tmp)

        try:
            # Change orientation of the input centerline into RPI
            sct.printv("\nOrient centerline to RPI orientation...", verbose)
            image_centerline = Image("centerline.nii.gz").change_orientation("RPI").save("centerline_rpi.nii.gz", mutable=True)

            # Get dimension
            sct.printv('\nGet dimensions...', verbose)
            nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim
            sct.printv('.. matrix size: ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), verbose)
            sct.printv('.. voxel size:  ' + str(px) + 'mm x ' + str(py) + 'mm x ' + str(pz) + 'mm', verbose)
            if self.speed_factor != 1.0:
                intermediate_resampling = True
                px_r, py_r, pz_r = px * self.speed_factor, py * self.speed_factor, pz * self.speed_factor
            else:
                intermediate_resampling = False

            if intermediate_resampling:
                sct.mv('centerline_rpi.nii.gz', 'centerline_rpi_native.nii.gz')
                pz_native = pz

                sct.run(['sct_resample', '-i', 'centerline_rpi_native.nii.gz', '-mm', str(px_r) + 'x' + str(py_r) + 'x' + str(pz_r), '-o', 'centerline_rpi.nii.gz'])
                image_centerline = Image('centerline_rpi.nii.gz')
                nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim

            if np.min(image_centerline.data) < 0 or np.max(image_centerline.data) > 1:
                image_centerline.data[image_centerline.data < 0] = 0
                image_centerline.data[image_centerline.data > 1] = 1
                image_centerline.save()

            """
            Steps: (everything is done in physical space)
            1. open input image and centreline image
            2. extract bspline fitting of the centreline, and its derivatives
            3. compute length of centerline
            4. compute and generate straight space
            5. compute transformations
                for each voxel of one space: (done using matrices --> improves speed by a factor x300)
                    a. determine which plane of spinal cord centreline it is included
                    b. compute the position of the voxel in the plane (X and Y distance from centreline, along the plane)
                    c. find the correspondant centreline point in the other space
                    d. find the correspondance of the voxel in the corresponding plane
            6. generate warping fields for each transformations
            7. write warping fields and apply them

            step 5.b: how to find the corresponding plane?
                The centerline plane corresponding to a voxel correspond to the nearest point of the centerline.
                However, we need to compute the distance between the voxel position and the plane to be sure it is part of the plane and not too distant.
                If it is more far than a threshold, warping value should be 0.

            step 5.d: how to make the correspondance between centerline point in both images?
                Both centerline have the same lenght. Therefore, we can map centerline point via their position along the curve.
                If we use the same number of points uniformely along the spinal cord (1000 for example), the correspondance is straight-forward.
            """

            # number of points along the spinal cord
            if algo_fitting == 'nurbs':
                number_of_points = int(self.precision * (float(nz) / pz))
                if number_of_points < 100:
                    number_of_points *= 50
                if number_of_points == 0:
                    number_of_points = 50
            else:
                number_of_points = nz

            # 2. extract bspline fitting of the centerline, and its derivatives
            img_ctl = Image('centerline_rpi.nii.gz')
            centerline = _get_centerline(img_ctl, algo_fitting, verbose)
            number_of_points = centerline.number_of_points

            # ==========================================================================================
            sct.printv("\nCreate the straight space and the safe zone...", verbose)
            # 3. compute length of centerline
            # compute the length of the spinal cord based on fitted centerline and size of centerline in z direction

            # Computation of the safe zone.
            # The safe zone is defined as the length of the spinal cord for which an axial segmentation will be complete
            # The safe length (to remove) is computed using the safe radius (given as parameter) and the angle of the
            # last centerline point with the inferior-superior direction. Formula: Ls = Rs * sin(angle)
            # Calculate Ls for both edges and remove appropriate number of centerline points
            radius_safe = 0.0  # mm

            # inferior edge
            u = centerline.derivatives[0]
            v = np.array([0, 0, -1])

            angle_inferior = np.arctan2(np.linalg.norm(np.cross(u, v)), np.dot(u, v))
            length_safe_inferior = radius_safe * np.sin(angle_inferior)

            # superior edge
            u = centerline.derivatives[-1]
            v = np.array([0, 0, 1])
            angle_superior = np.arctan2(np.linalg.norm(np.cross(u, v)), np.dot(u, v))
            length_safe_superior = radius_safe * np.sin(angle_superior)

            # remove points
            inferior_bound = bisect.bisect(centerline.progressive_length, length_safe_inferior) - 1
            superior_bound = centerline.number_of_points - bisect.bisect(centerline.progressive_length_inverse, length_safe_superior)

            z_centerline = centerline.points[:, 2]
            length_centerline = centerline.length
            size_z_centerline = z_centerline[-1] - z_centerline[0]

            # compute the size factor between the initial (curved) centerline and the straightened centerline
            factor_curved_straight = length_centerline / size_z_centerline
            middle_slice = (z_centerline[0] + z_centerline[-1]) / 2.0

            bound_curved = [z_centerline[inferior_bound], z_centerline[superior_bound]]
            bound_straight = [(z_centerline[inferior_bound] - middle_slice) * factor_curved_straight + middle_slice,
                              (z_centerline[superior_bound] - middle_slice) * factor_curved_straight + middle_slice]

            if verbose == 2:
                sct.printv("Length of spinal cord = " + str(length_centerline))
                sct.printv("Size of spinal cord in z direction = " + str(size_z_centerline))
                sct.printv("Ratio length/size = " + str(factor_curved_straight))
                sct.printv("Safe zone boundaries: ")
                sct.printv("Curved space = " + str(bound_curved))
                sct.printv("Straight space = " + str(bound_straight))

            # 4. compute and generate straight space
            # points along curved centerline are already regularly spaced.
            # calculate position of points along straight centerline

            # Create straight NIFTI volumes. TODO: maybe this if case is not needed?
            # ==========================================================================================
            if self.use_straight_reference:
                image_centerline_pad = Image('centerline_rpi.nii.gz')
                nx, ny, nz, nt, px, py, pz, pt = image_centerline_pad.dim

                fname_ref = 'centerline_ref_rpi.nii.gz'
                image_centerline_straight = Image('centerline_ref.nii.gz')\
                    .change_orientation("RPI")\
                    .save(fname_ref, mutable=True)
                centerline_straight = _get_centerline(image_centerline_straight, algo_fitting, verbose)
                nx_s, ny_s, nz_s, nt_s, px_s, py_s, pz_s, pt_s = image_centerline_straight.dim

                # Prepare warping fields headers
                hdr_warp = image_centerline_pad.hdr.copy()
                hdr_warp.set_data_dtype('float32')
                hdr_warp_s = image_centerline_straight.hdr.copy()
                hdr_warp_s.set_data_dtype('float32')

                if self.discs_input_filename != "" and self.discs_ref_filename != "":
                    discs_input_image = Image('labels_input.nii.gz')
                    coord = discs_input_image.getNonZeroCoordinates(sorting='z', reverse_coord=True)
                    coord_physical = []
                    for c in coord:
                        c_p = discs_input_image.transfo_pix2phys([[c.x, c.y, c.z]]).tolist()[0]
                        c_p.append(c.value)
                        coord_physical.append(c_p)
                    centerline.compute_vertebral_distribution(coord_physical)
                    centerline.save_centerline(image=discs_input_image, fname_output='discs_input_image.nii.gz')

                    discs_ref_image = Image('labels_ref.nii.gz')
                    coord = discs_ref_image.getNonZeroCoordinates(sorting='z', reverse_coord=True)
                    coord_physical = []
                    for c in coord:
                        c_p = discs_ref_image.transfo_pix2phys([[c.x, c.y, c.z]]).tolist()[0]
                        c_p.append(c.value)
                        coord_physical.append(c_p)
                    centerline_straight.compute_vertebral_distribution(coord_physical)
                    centerline_straight.save_centerline(image=discs_ref_image, fname_output='discs_ref_image.nii.gz')

            else:
                sct.printv('\nPad input volume to account for spinal cord length...', verbose)

                start_point = (z_centerline[0] - middle_slice) * factor_curved_straight + middle_slice
                end_point = (z_centerline[-1] - middle_slice) * factor_curved_straight + middle_slice

                offset_z = 0

                # if the destination image is resampled, we still create the straight reference space with the native resolution. # TODO: Maybe this if case is not needed?
                if intermediate_resampling:
                    padding_z = int(np.ceil(1.5 * ((length_centerline - size_z_centerline) / 2.0) / pz_native))
                    sct.run(['sct_image', '-i', 'centerline_rpi_native.nii.gz', '-o', 'tmp.centerline_pad_native.nii.gz', '-pad', '0,0,' + str(padding_z)])
                    image_centerline_pad = Image('centerline_rpi_native.nii.gz')
                    nx, ny, nz, nt, px, py, pz, pt = image_centerline_pad.dim
                    start_point_coord_native = image_centerline_pad.transfo_phys2pix([[0, 0, start_point]])[0]
                    end_point_coord_native = image_centerline_pad.transfo_phys2pix([[0, 0, end_point]])[0]
                    straight_size_x = int(self.xy_size / px)
                    straight_size_y = int(self.xy_size / py)
                    warp_space_x = [int(np.round(nx / 2)) - straight_size_x, int(np.round(nx / 2)) + straight_size_x]
                    warp_space_y = [int(np.round(ny / 2)) - straight_size_y, int(np.round(ny / 2)) + straight_size_y]
                    if warp_space_x[0] < 0:
                        warp_space_x[1] += warp_space_x[0] - 2
                        warp_space_x[0] = 0
                    if warp_space_y[0] < 0:
                        warp_space_y[1] += warp_space_y[0] - 2
                        warp_space_y[0] = 0

                    spec = dict((
                     (0, warp_space_x),
                     (1, warp_space_y),
                     (2, (0, end_point_coord_native[2] - start_point_coord_native[2])),
                    ))
                    msct_image.spatial_crop(Image("tmp.centerline_pad_native.nii.gz"), spec).save("tmp.centerline_pad_crop_native.nii.gz")

                    fname_ref = 'tmp.centerline_pad_crop_native.nii.gz'
                    offset_z = 4
                else:
                    fname_ref = 'tmp.centerline_pad_crop.nii.gz'

                nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim
                padding_z = int(np.ceil(1.5 * ((length_centerline - size_z_centerline) / 2.0) / pz)) + offset_z
                from sct_image import pad_image
                image_centerline_pad = pad_image(image_centerline, pad_z_i=padding_z, pad_z_f=padding_z)
                # sct.run(['sct_image', '-i', 'centerline_rpi.nii.gz', '-o', 'tmp.centerline_pad.nii.gz', '-pad', '0,0,' + str(padding_z)])
                # image_centerline_pad = Image('tmp.centerline_pad.nii.gz')
                nx, ny, nz = image_centerline_pad.data.shape
                hdr_warp = image_centerline_pad.hdr.copy()
                hdr_warp.set_data_dtype('float32')
                start_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, start_point]])[0]
                end_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, end_point]])[0]

                straight_size_x = int(self.xy_size / px)
                straight_size_y = int(self.xy_size / py)
                warp_space_x = [int(np.round(nx / 2)) - straight_size_x, int(np.round(nx / 2)) + straight_size_x]
                warp_space_y = [int(np.round(ny / 2)) - straight_size_y, int(np.round(ny / 2)) + straight_size_y]

                if warp_space_x[0] < 0:
                    warp_space_x[1] += warp_space_x[0] - 2
                    warp_space_x[0] = 0
                if warp_space_x[1] >= nx:
                    warp_space_x[1] = nx - 1
                if warp_space_y[0] < 0:
                    warp_space_y[1] += warp_space_y[0] - 2
                    warp_space_y[0] = 0
                if warp_space_y[1] >= ny:
                    warp_space_y[1] = ny - 1

                spec = dict((
                 (0, warp_space_x),
                 (1, warp_space_y),
                 (2, (0, end_point_coord[2] - start_point_coord[2] + offset_z)),
                ))
                # msct_image.spatial_crop(Image("tmp.centerline_pad.nii.gz"), spec).save("tmp.centerline_pad_crop.nii.gz")
                image_centerline_straight = msct_image.spatial_crop(image_centerline_pad, spec)

                # image_centerline_straight = Image('tmp.centerline_pad_crop.nii.gz')
                nx_s, ny_s, nz_s, nt_s, px_s, py_s, pz_s, pt_s = image_centerline_straight.dim
                hdr_warp_s = image_centerline_straight.hdr.copy()
                hdr_warp_s.set_data_dtype('float32')

                if self.template_orientation == 1:
                    raise NotImplementedError()

                start_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, start_point]])[0]
                end_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, end_point]])[0]

                number_of_voxel = nx * ny * nz
                sct.printv("Number of voxel = " + str(number_of_voxel))

                time_centerlines = time.time()

                coord_straight = np.empty((number_of_points,3))
                coord_straight[...,0] = int(np.round(nx_s / 2))
                coord_straight[...,1] = int(np.round(ny_s / 2))
                coord_straight[...,2] = np.linspace(0, end_point_coord[2] - start_point_coord[2], number_of_points)
                coord_phys_straight = image_centerline_straight.transfo_pix2phys(coord_straight)
                derivs_straight = np.empty((number_of_points,3))
                derivs_straight[...,0] = derivs_straight[...,1] = 0
                derivs_straight[...,2] = 1
                dx_straight, dy_straight, dz_straight = derivs_straight.T
                centerline_straight = Centerline(coord_phys_straight[:, 0], coord_phys_straight[:, 1], coord_phys_straight[:, 2],
                                                 dx_straight, dy_straight, dz_straight)

                time_centerlines = time.time() - time_centerlines
                sct.printv('Time to generate centerline: ' + str(np.round(time_centerlines * 1000.0)) + ' ms', verbose)

            if verbose == 2:
                import matplotlib.pyplot as plt
                from datetime import datetime
                curved_points = centerline.progressive_length
                straight_points = centerline_straight.progressive_length
                range_points = np.linspace(0, 1, number_of_points)
                dist_curved = np.zeros(number_of_points)
                dist_straight = np.zeros(number_of_points)
                for i in range(1, number_of_points):
                    dist_curved[i] = dist_curved[i - 1] + curved_points[i - 1] / centerline.length
                    dist_straight[i] = dist_straight[i - 1] + straight_points[i - 1] / centerline_straight.length
                plt.plot(range_points, dist_curved)
                plt.plot(range_points, dist_straight)
                plt.grid(True)
                plt.savefig('fig_straighten_' + datetime.now().strftime("%y%m%d%H%M%S%f") + '.png')
                plt.close()

            #alignment_mode = 'length'
            alignment_mode = 'levels'

            lookup_curved2straight = list(range(centerline.number_of_points))
            if self.discs_input_filename != "":
                # create look-up table curved to straight
                for index in range(centerline.number_of_points):
                    disc_label = centerline.l_points[index]
                    if alignment_mode == 'length':
                        relative_position = centerline.dist_points[index]
                    else:
                        relative_position = centerline.dist_points_rel[index]
                    idx_closest = centerline_straight.get_closest_to_absolute_position(disc_label, relative_position, backup_index=index, backup_centerline=centerline_straight, mode=alignment_mode)
                    if idx_closest is not None:
                        lookup_curved2straight[index] = idx_closest
                    else:
                        lookup_curved2straight[index] = 0
            for p in range(0, len(lookup_curved2straight)//2):
                if lookup_curved2straight[p] == lookup_curved2straight[p + 1]:
                    lookup_curved2straight[p] = 0
                else:
                    break
            for p in range(len(lookup_curved2straight)-1, len(lookup_curved2straight)//2, -1):
                if lookup_curved2straight[p] == lookup_curved2straight[p - 1]:
                    lookup_curved2straight[p] = 0
                else:
                    break
            lookup_curved2straight = np.array(lookup_curved2straight)

            lookup_straight2curved = list(range(centerline_straight.number_of_points))
            if self.discs_input_filename != "":
                for index in range(centerline_straight.number_of_points):
                    disc_label = centerline_straight.l_points[index]
                    if alignment_mode == 'length':
                        relative_position = centerline_straight.dist_points[index]
                    else:
                        relative_position = centerline_straight.dist_points_rel[index]
                    idx_closest = centerline.get_closest_to_absolute_position(disc_label, relative_position, backup_index=index, backup_centerline=centerline_straight, mode=alignment_mode)
                    if idx_closest is not None:
                        lookup_straight2curved[index] = idx_closest
            for p in range(0, len(lookup_straight2curved)//2):
                if lookup_straight2curved[p] == lookup_straight2curved[p + 1]:
                    lookup_straight2curved[p] = 0
                else:
                    break
            for p in range(len(lookup_straight2curved)-1, len(lookup_straight2curved)//2, -1):
                if lookup_straight2curved[p] == lookup_straight2curved[p - 1]:
                    lookup_straight2curved[p] = 0
                else:
                    break
            lookup_straight2curved = np.array(lookup_straight2curved)

            # Create volumes containing curved and straight warping fields
            time_generation_volumes = time.time()
            data_warp_curved2straight = np.zeros((nx_s, ny_s, nz_s, 1, 3))
            data_warp_straight2curved = np.zeros((nx, ny, nz, 1, 3))

            # 5. compute transformations
            # Curved and straight images and the same dimensions, so we compute both warping fields at the same time.
            # b. determine which plane of spinal cord centreline it is included
            # sct.printv(nx * ny * nz, nx_s * ny_s * nz_s)

            if self.curved2straight:
                for u in tqdm.tqdm(range(nz_s)):
                    x_s, y_s, z_s = np.mgrid[0:nx_s, 0:ny_s, u:u + 1]
                    indexes_straight = np.array(list(zip(x_s.ravel(), y_s.ravel(), z_s.ravel())))
                    physical_coordinates_straight = image_centerline_straight.transfo_pix2phys(indexes_straight)
                    nearest_indexes_straight = centerline_straight.find_nearest_indexes(physical_coordinates_straight)
                    distances_straight = centerline_straight.get_distances_from_planes(physical_coordinates_straight, nearest_indexes_straight)
                    lookup = lookup_straight2curved[nearest_indexes_straight]
                    indexes_out_distance_straight = np.logical_or(np.logical_or(distances_straight > self.threshold_distance, distances_straight < -self.threshold_distance), lookup == 0)
                    projected_points_straight = centerline_straight.get_projected_coordinates_on_planes(physical_coordinates_straight, nearest_indexes_straight)
                    coord_in_planes_straight = centerline_straight.get_in_plans_coordinates(projected_points_straight, nearest_indexes_straight)

                    coord_straight2curved = centerline.get_inverse_plans_coordinates(coord_in_planes_straight, lookup)
                    displacements_straight = coord_straight2curved - physical_coordinates_straight
                    # Invert Z coordinate as ITK & ANTs physical coordinate system is LPS- (RAI+)
                    # while ours is LPI-
                    # Refs: https://sourceforge.net/p/advants/discussion/840261/thread/2a1e9307/#fb5a
                    #  https://www.slicer.org/wiki/Coordinate_systems
                    displacements_straight[:, 2] = -displacements_straight[:, 2]
                    displacements_straight[indexes_out_distance_straight] = [100000.0, 100000.0, 100000.0]

                    data_warp_curved2straight[indexes_straight[:, 0], indexes_straight[:, 1], indexes_straight[:, 2], 0, :] = -displacements_straight

            if self.straight2curved:
                for u in tqdm.tqdm(range(nz)):
                    x, y, z = np.mgrid[0:nx, 0:ny, u:u + 1]
                    indexes = np.array(list(zip(x.ravel(), y.ravel(), z.ravel())))
                    physical_coordinates = image_centerline_pad.transfo_pix2phys(indexes)
                    nearest_indexes_curved = centerline.find_nearest_indexes(physical_coordinates)
                    distances_curved = centerline.get_distances_from_planes(physical_coordinates, nearest_indexes_curved)
                    lookup = lookup_curved2straight[nearest_indexes_curved]
                    indexes_out_distance_curved = np.logical_or(np.logical_or(distances_curved > self.threshold_distance, distances_curved < -self.threshold_distance), lookup == 0)
                    projected_points_curved = centerline.get_projected_coordinates_on_planes(physical_coordinates, nearest_indexes_curved)
                    coord_in_planes_curved = centerline.get_in_plans_coordinates(projected_points_curved, nearest_indexes_curved)

                    coord_curved2straight = centerline_straight.points[lookup]
                    coord_curved2straight[:, 0:2] += coord_in_planes_curved[:, 0:2]
                    coord_curved2straight[:, 2] += distances_curved

                    displacements_curved = coord_curved2straight - physical_coordinates

                    displacements_curved[:, 2] = -displacements_curved[:, 2]
                    displacements_curved[indexes_out_distance_curved] = [100000.0, 100000.0, 100000.0]

                    data_warp_straight2curved[indexes[:, 0], indexes[:, 1], indexes[:, 2], 0, :] = -displacements_curved

            # Creation of the safe zone based on pre-calculated safe boundaries
            coord_bound_curved_inf, coord_bound_curved_sup = image_centerline_pad.transfo_phys2pix([[0, 0, bound_curved[0]]]), image_centerline_pad.transfo_phys2pix([[0, 0, bound_curved[1]]])
            coord_bound_straight_inf, coord_bound_straight_sup = image_centerline_straight.transfo_phys2pix([[0, 0, bound_straight[0]]]), image_centerline_straight.transfo_phys2pix([[0, 0, bound_straight[1]]])

            if radius_safe > 0:
                data_warp_curved2straight[:, :, 0:coord_bound_straight_inf[0][2], 0, :] = 100000.0
                data_warp_curved2straight[:, :, coord_bound_straight_sup[0][2]:, 0, :] = 100000.0
                data_warp_straight2curved[:, :, 0:coord_bound_curved_inf[0][2], 0, :] = 100000.0
                data_warp_straight2curved[:, :, coord_bound_curved_sup[0][2]:, 0, :] = 100000.0

            # Generate warp files as a warping fields
            hdr_warp_s.set_intent('vector', (), '')
            hdr_warp_s.set_data_dtype('float32')
            hdr_warp.set_intent('vector', (), '')
            hdr_warp.set_data_dtype('float32')
            if self.curved2straight:
                img = Nifti1Image(data_warp_curved2straight, None, hdr_warp_s)
                save(img, 'tmp.curve2straight.nii.gz')
                sct.printv('\nDONE ! Warping field generated: tmp.curve2straight.nii.gz', verbose)

            if self.straight2curved:
                img = Nifti1Image(data_warp_straight2curved, None, hdr_warp)
                save(img, 'tmp.straight2curve.nii.gz')
                sct.printv('\nDONE ! Warping field generated: tmp.straight2curve.nii.gz', verbose)

            image_centerline_straight.save(fname_ref)
            if self.curved2straight:
                sct.printv('\nApply transformation to input image...', verbose)
                sct.run(['isct_antsApplyTransforms',
                         '-d', '3',
                         '-r', fname_ref,
                         '-i', 'data.nii',
                         '-o', 'tmp.anat_rigid_warp.nii.gz',
                         '-t', 'tmp.curve2straight.nii.gz',
                         '-n', 'BSpline[3]'],
                         verbose=verbose)

            if self.accuracy_results:
                time_accuracy_results = time.time()
                # compute the error between the straightened centerline/segmentation and the central vertical line.
                # Ideally, the error should be zero.
                # Apply deformation to input image
                sct.printv('\nApply transformation to centerline image...', verbose)
                Transform(input_filename='centerline.nii.gz', fname_dest=fname_ref,
                          output_filename="tmp.centerline_straight.nii.gz", interp="nn",
                          warp="tmp.curve2straight.nii.gz", verbose=verbose).apply()
                file_centerline_straight = Image('tmp.centerline_straight.nii.gz', verbose=verbose)
                nx, ny, nz, nt, px, py, pz, pt = file_centerline_straight.dim
                coordinates_centerline = file_centerline_straight.getNonZeroCoordinates(sorting='z')
                mean_coord = []
                for z in range(coordinates_centerline[0].z, coordinates_centerline[-1].z):
                    temp_mean = [coord.value for coord in coordinates_centerline if coord.z == z]
                    if temp_mean:
                        mean_value = np.mean(temp_mean)
                        mean_coord.append(np.mean([[coord.x * coord.value / mean_value, coord.y * coord.value / mean_value]
                                                    for coord in coordinates_centerline if coord.z == z], axis=0))

                # compute error between the straightened centerline and the straight line.
                x0 = file_centerline_straight.data.shape[0] / 2.0
                y0 = file_centerline_straight.data.shape[1] / 2.0
                count_mean = 0
                if number_of_points >= 10:
                    mean_c = mean_coord[2:-2]  # we don't include the four extrema because there are usually messy.
                else:
                    mean_c = mean_coord
                for coord_z in mean_c:
                    if not np.isnan(np.sum(coord_z)):
                        dist = ((x0 - coord_z[0]) * px)**2 + ((y0 - coord_z[1]) * py)**2
                        self.mse_straightening += dist
                        dist = np.sqrt(dist)
                        if dist > self.max_distance_straightening:
                            self.max_distance_straightening = dist
                        count_mean += 1
                self.mse_straightening = np.sqrt(self.mse_straightening / float(count_mean))

                self.elapsed_time_accuracy = time.time() - time_accuracy_results

        except Exception as e:
            sct.printv('WARNING: Exception during Straightening:', 1, 'warning')
            sct.printv('Error on line {}'.format(sys.exc_info()[-1].tb_lineno), 1, 'warning')
            sct.printv(str(e), 1, 'warning')
            raise
        finally:
            os.chdir(curdir)

        # Generate output file (in current folder)
        # TODO: do not uncompress the warping field, it is too time consuming!
        sct.printv("\nGenerate output file (in current folder)...", verbose)
        if self.curved2straight:
            sct.generate_output_file(os.path.join(path_tmp, "tmp.curve2straight.nii.gz"), os.path.join(self.path_output, "warp_curve2straight.nii.gz"), verbose)
        if self.straight2curved:
            sct.generate_output_file(os.path.join(path_tmp, "tmp.straight2curve.nii.gz"), os.path.join(self.path_output, "warp_straight2curve.nii.gz"), verbose)

        # create ref_straight.nii.gz file that can be used by other SCT functions that need a straight reference space
        if self.curved2straight:
            sct.copy(os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"), os.path.join(self.path_output, "straight_ref.nii.gz"))
            # move straightened input file
            if fname_output == '':
                fname_straight = sct.generate_output_file(os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
                                                          os.path.join(self.path_output, file_anat + "_straight" + ext_anat), verbose)
            else:
                fname_straight = sct.generate_output_file(os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
                                                          os.path.join(self.path_output, fname_output), verbose)  # straightened anatomic

        # Remove temporary files
        if remove_temp_files:
            sct.printv("\nRemove temporary files...", verbose)
            sct.rmtree(path_tmp)

        sct.printv('\nDone!\n', verbose)

        if self.accuracy_results:
            sct.printv("Maximum x-y error = " + str(np.round(self.max_distance_straightening, 2)) + " mm", verbose, "bold")
            sct.printv("Accuracy of straightening (MSE) = " + str(np.round(self.mse_straightening, 2)) +
                       " mm", verbose, "bold")

        # display elapsed time
        self.elapsed_time = time.time() - start_time
        sct.printv("\nFinished! Elapsed time: " + str(int(np.round(self.elapsed_time))) + " s", verbose)
        if self.accuracy_results:
            sct.printv('    including ' + str(int(np.round(self.elapsed_time_accuracy))) + ' s spent computing '
                                                                                      'accuracy results', verbose)

        return fname_straight
Exemplo n.º 13
0
def compute_csa(segmentation,
                algo_fitting='bspline',
                angle_correction=True,
                use_phys_coord=True,
                remove_temp_files=1,
                verbose=1):
    """
    Compute the cross-sectional area (CSA) of a segmentation, slice by slice along z.
    Note: segmentation can be binary or weighted for partial volume effect.
    :param segmentation: input segmentation. Could be either an Image or a file name.
    :param algo_fitting: fitting algorithm, forwarded to get_centerline().
    :param angle_correction: if true, each slice area is multiplied by the cosine of the angle between the
        centerline tangent and the z axis; otherwise the angle is taken as 0.
    :param use_phys_coord: kept for backward compatibility. The centerline derivatives used for the angle come
        from the same get_centerline() call whatever its value, so it has no effect on the returned metrics.
    :param remove_temp_files: if true, the temporary folder is deleted before returning.
    :param verbose: verbosity level, forwarded to get_centerline() and sct.printv().
    :return metrics: Dict of class process_seg.Metric(), with keys 'csa' and 'angle'. Each holds a 1d array
        with one element per z slice; slices outside the z extent of the segmentation stay NaN.
    """
    # create temporary folder
    path_tmp = sct.tmp_create()
    # open image and save in temp folder
    im_seg = msct_image.Image(segmentation).save(path_tmp)

    # change orientation to RPI
    im_seg.change_orientation('RPI')
    _, _, nz, _, px, py, pz, _ = im_seg.dim
    im_seg.save(os.path.join(path_tmp, 'segmentation_RPI.nii.gz'))

    # Extract min and max index in Z direction
    data_seg = im_seg.data
    _, _, z_nonzero = (data_seg > 0).nonzero()
    min_z_index, max_z_index = int(z_nonzero.min()), int(z_nonzero.max())

    # Fit the centerline and keep its first derivative. The fit is run even when angle_correction is off,
    # as the original code did (see #1791).
    # TODO: check if we need use_phys_coord case with recent changes in centerline
    _, _, arr_ctl_der = get_centerline(im_seg,
                                       algo_fitting=algo_fitting,
                                       verbose=verbose)
    x_centerline_deriv, y_centerline_deriv, _ = arr_ctl_der

    # Compute CSA
    sct.printv('\nCompute CSA...', verbose)

    # Initialize 1d array with nan. Each element corresponds to a slice.
    csa = np.full(nz, np.nan, dtype=np.double)
    angles = np.full(nz, np.nan, dtype=np.double)

    z_axis = np.array([0, 0, 1])
    # Tangent used when a slice has no centerline derivative. It starts as the z axis (angle 0) so that the
    # name is always bound, and is afterwards the tangent of the last slice that had a derivative.
    tangent_vect = np.array([0.0, 0.0, 1.0])

    for iz in range(min_z_index, max_z_index + 1):
        if angle_correction:
            i_ctl = iz - min_z_index
            # in the case of problematic segmentation (e.g., non continuous segmentation often at the extremities),
            # display a warning but do not crash
            try:
                # normalize the tangent vector to the centerline (i.e. its derivative)
                tangent_vect = normalize(
                    np.array([
                        x_centerline_deriv[i_ctl] * px,
                        y_centerline_deriv[i_ctl] * py, pz
                    ]))

            except IndexError:
                sct.printv(
                    'WARNING: Your segmentation does not seem continuous, which could cause wrong estimations at the '
                    'problematic slices. Please check it, especially at the extremities.',
                    type='warning')

            # compute the angle between the normal vector of the plane and the vector z
            angle = np.arccos(np.vdot(tangent_vect, z_axis))
        else:
            angle = 0.0

        # compute the number of voxels, assuming the segmentation is coded for partial volume effect between 0 and 1.
        number_voxels = np.sum(data_seg[:, :, iz])

        # compute CSA, by scaling with voxel size (in mm) and adjusting for oblique plane
        csa[iz] = number_voxels * px * py * np.cos(angle)
        angles[iz] = math.degrees(angle)

    # Remove temporary files
    if remove_temp_files:
        sct.printv('\nRemove temporary files...', verbose)
        sct.rmtree(path_tmp)

    # prepare output
    metrics = {
        'csa': Metric(data=csa, label='CSA [mm^2]'),
        'angle': Metric(data=angles,
                        label='Angle between cord axis and z [deg]')
    }
    return metrics
Exemplo n.º 14
0
def extract_centerline(segmentation, verbose=0, algo_fitting='hanning', type_window='hanning',
                       window_length=5, use_phys_coord=True, file_out='centerline'):
    """
    Extract centerline from a binary or weighted segmentation by computing the center of mass slicewise.
    :param segmentation: input segmentation. Could be either an Image or a file name.
    :param verbose:
    :param algo_fitting:
    :param type_window:
    :param window_length:
    :param use_phys_coord: TODO: Explain the pros/cons of use_phys_coord.
    :param file_out:
    :return: None
    """
    # TODO: output continuous centerline (and add in unit test)
    # TODO: centerline coordinate should have the same orientation as the input image
    # TODO: no need for unecessary i/o. Everything could be done in RAM

    # Create temp folder
    path_tmp = sct.tmp_create()
    # Open segmentation volume
    im_seg = msct_image.Image(segmentation)
    # im_seg.change_orientation('RPI', generate_path=True)
    native_orientation = im_seg.orientation
    im_seg.change_orientation("RPI", generate_path=True).save(path_tmp, mutable=True)
    fname_tmp_seg = im_seg.absolutepath

    # extract centerline and smooth it
    if use_phys_coord:
        # fit centerline, smooth it and return the first derivative (in physical space)
        x_centerline_fit, y_centerline_fit, z_centerline, \
        x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
            fname_tmp_seg, algo_fitting=algo_fitting, type_window=type_window, window_length=window_length,
            nurbs_pts_number=3000, phys_coordinates=True, verbose=verbose, all_slices=False)
        centerline = Centerline(x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv,
                                y_centerline_deriv, z_centerline_deriv)

        # average centerline coordinates over slices of the image (floating point)
        x_centerline_fit_rescorr, y_centerline_fit_rescorr, z_centerline_rescorr, \
        x_centerline_deriv_rescorr, y_centerline_deriv_rescorr, z_centerline_deriv_rescorr = \
            centerline.average_coordinates_over_slices(im_seg)

        # compute z_centerline in image coordinates (discrete)
        voxel_coordinates = im_seg.transfo_phys2pix(
            [[x_centerline_fit_rescorr[i], y_centerline_fit_rescorr[i], z_centerline_rescorr[i]] for i in
             range(len(z_centerline_rescorr))])
        x_centerline_voxel = [coord[0] for coord in voxel_coordinates]
        y_centerline_voxel = [coord[1] for coord in voxel_coordinates]
        z_centerline_voxel = [coord[2] for coord in voxel_coordinates]

    else:
        # fit centerline, smooth it and return the first derivative (in voxel space but FITTED coordinates)
        x_centerline_voxel, y_centerline_voxel, z_centerline_voxel, \
        x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
            'segmentation_RPI.nii.gz', algo_fitting=algo_fitting, type_window=type_window, window_length=window_length,
            nurbs_pts_number=3000, phys_coordinates=False, verbose=verbose, all_slices=True)

    if verbose == 2:
        # TODO: code below does not work
        import matplotlib.pyplot as plt

        # Creation of a vector x that takes into account the distance between the labels
        nz_nonz = len(z_centerline_voxel)
        x_display = [0 for i in range(x_centerline_voxel.shape[0])]
        y_display = [0 for i in range(y_centerline_voxel.shape[0])]
        for i in range(0, nz_nonz, 1):
            x_display[int(z_centerline_voxel[i] - z_centerline_voxel[0])] = x_centerline[i]
            y_display[int(z_centerline_voxel[i] - z_centerline_voxel[0])] = y_centerline[i]

        plt.figure(1)
        plt.subplot(2, 1, 1)
        plt.plot(z_centerline_voxel, x_display, 'ro')
        plt.plot(z_centerline_voxel, x_centerline_voxel)
        plt.xlabel("Z")
        plt.ylabel("X")
        plt.title("x and x_fit coordinates")

        plt.subplot(2, 1, 2)
        plt.plot(z_centerline_voxel, y_display, 'ro')
        plt.plot(z_centerline_voxel, y_centerline_voxel)
        plt.xlabel("Z")
        plt.ylabel("Y")
        plt.title("y and y_fit coordinates")
        plt.show()

    # Create an image with the centerline
    # TODO: write the center of mass, not the discrete image coordinate (issue #1938)
    im_centerline = im_seg.copy()
    data_centerline = im_centerline.data * 0
    # Find z-boundaries above which and below which there is no non-null slices
    min_z_index, max_z_index = int(round(min(z_centerline_voxel))), int(round(max(z_centerline_voxel)))
    # loop across slices and set centerline pixel to value=1
    for iz in range(min_z_index, max_z_index + 1):
        data_centerline[int(round(x_centerline_voxel[iz - min_z_index])),
                        int(round(y_centerline_voxel[iz - min_z_index])),
                        int(iz)] = 1
    # assign data to centerline image
    im_centerline.data = data_centerline
    # reorient centerline to native orientation
    im_centerline.change_orientation(native_orientation)
    # save nifti volume
    fname_centerline = file_out + '.nii.gz'
    im_centerline.save(fname_centerline, dtype='uint8')
    # display stuff
    # sct.display_viewer_syntax([fname_segmentation, fname_centerline], colormaps=['gray', 'green'])

    # output csv with centerline coordinates
    fname_centerline_csv = file_out + '.csv'
    f_csv = open(fname_centerline_csv, 'w')
    f_csv.write('x,y,z\n')  # csv header
    for i in range(min_z_index, max_z_index + 1):
        f_csv.write("%d,%d,%d\n" % (int(i),
                                    x_centerline_voxel[i - min_z_index],
                                    y_centerline_voxel[i - min_z_index]))
    f_csv.close()
    # TODO: display open syntax for csv

    # create a .roi file
    fname_roi_centerline = optic.centerline2roi(fname_image=fname_centerline,
                                                folder_output='./',
                                                verbose=verbose)