def run_main():
    """Command-line entry point: extract the spinal cord centerline of an image.

    Arguments are read from ``sys.argv[1:]`` with the parser returned by ``get_parser()``.
    Two methods are supported:

    * ``'optic'`` (default): automatic detection; requires a contrast (``-c``).
    * ``'viewer'``: the user places labels in a viewer, then a degree-3 polynomial is fitted.

    The centerline is written as ``<output>.nii.gz`` (discrete image) and ``<output>.csv``
    (fitted coordinates, one row per point). ``<output>`` defaults to ``<input>_centerline``
    next to the input file.
    """
    sct.init_sct()
    arguments = get_parser().parse(sys.argv[1:])

    # Input image: keep the name as given (for display) and its absolute path (for processing)
    fname_input_data = arguments["-i"]
    fname_data = os.path.abspath(fname_input_data)

    method = arguments["-method"] if "-method" in arguments else 'optic'
    contrast_type = arguments["-c"] if "-c" in arguments else ''

    # The optic method cannot run without knowing the image contrast
    if method == 'optic' and not contrast_type:
        sct.printv('ERROR: -c is a mandatory argument when using Optic method.', type='error')
        return

    # Gap between slices; only forwarded to the viewer
    interslice_gap = float(arguments["-gap"]) if "-gap" in arguments else 10.0

    # Output basename (extensions are appended when saving)
    if "-o" in arguments:
        file_output = arguments["-o"]
    else:
        path_data, file_data, _ = sct.extract_fname(fname_data)
        file_output = os.path.join(path_data, file_data + '_centerline')

    verbose = int(arguments["-v"]) if "-v" in arguments else 0

    if method == 'viewer':
        # Manual labels, then a polynomial fit through them
        im_labels = _call_viewer_centerline(Image(fname_data), interslice_gap=interslice_gap)
        im_centerline, arr_centerline, _ = get_centerline(
            im_labels, algo_fitting='polyfit', param=ParamCenterline(degree=3), minmax=True, verbose=verbose)
    else:
        # Any other method value falls back to automatic optic detection
        im_centerline, arr_centerline, _ = get_centerline(
            Image(fname_data), algo_fitting='optic', param=ParamCenterline(contrast=contrast_type),
            minmax=True, verbose=verbose)

    # Discrete centerline as NIfTI, continuous coordinates as CSV
    fname_centerline_nii = file_output + '.nii.gz'
    im_centerline.save(fname_centerline_nii)
    np.savetxt(file_output + '.csv', arr_centerline.transpose(), delimiter=",")

    sct.display_viewer_syntax([fname_input_data, fname_centerline_nii],
                              colormaps=['gray', 'red'],
                              opacities=['', '1'])
def test_get_centerline_optic():
    """Optic centerline of the t2 testing image stays close to the centerline of its ground-truth segmentation.

    A NaN and an Inf voxel are injected in a corner of the input to check the optic pipeline tolerates them.
    """
    # install: sct_download_data -d sct_testing_data
    fname_t2 = os.path.join(__sct_dir__, 'sct_testing_data/t2/t2.nii.gz')
    img_t2 = Image(fname_t2)
    img_t2.change_type('float32')
    img_t2.data[0, 0, 0] = np.nan
    img_t2.data[1, 0, 0] = np.inf
    param_optic = ParamCenterline(algo_fitting='optic', contrast='t2', minmax=False)
    img_out, _, _, _ = get_centerline(img_t2, param_optic, verbose=VERBOSE)

    # Reference: bspline centerline fitted on the ground-truth segmentation
    fname_t2_seg = os.path.join(__sct_dir__, 'sct_testing_data/t2/t2_seg.nii.gz')
    img_seg = Image(fname_t2_seg)
    param_ref = ParamCenterline(algo_fitting='bspline', minmax=False)
    img_seg_out, _, _, _ = get_centerline(img_seg, param_ref, verbose=VERBOSE)

    distance = np.linalg.norm(find_and_sort_coord(img_seg_out) - find_and_sort_coord(img_out))
    assert distance < 3.5
Exemple #3
0
def test_get_centerline_nurbs(img_ctl, expected, params):
    """NURBS fitting: median coordinate offset equals ``expected['median']`` and the maximum laplacian stays
    below ``expected['laplacian']``."""
    img_ref = img_ctl[0].copy()
    img_input = img_ctl[1].copy()
    param = ParamCenterline(algo_fitting='nurbs', minmax=False)
    img_out, _, _, fit_results = get_centerline(img_input, param, verbose=VERBOSE)
    offsets = find_and_sort_coord(img_ref) - find_and_sort_coord(img_out)
    assert np.median(offsets) == expected['median']
    assert fit_results.laplacian_max < expected['laplacian']
Exemple #4
0
def test_get_centerline_polyfit_minmax(img_ctl, expected):
    """Polyfit with minmax=True: the fitted centerline array has shape ``expected``."""
    _img_ref, img_input = img_ctl[0].copy(), img_ctl[1].copy()
    param = ParamCenterline(algo_fitting='polyfit', degree=3, minmax=True)
    _, arr_out, _, _ = get_centerline(img_input, param, verbose=VERBOSE)
    assert arr_out.shape == expected
    def angle_correction(self):
        """Fill ``self.angles`` with the angle (degrees) between the centerline tangent and the z axis.

        The segmentation ``self.fname_sc`` is loaded and its centerline fitted with default
        ``ParamCenterline()`` settings. ``self.angles`` has one entry per slice (``nz``). One angle is
        written per derivative sample; entries not reached stay NaN.
        """
        im_seg = Image(self.fname_sc)
        _, _, nz, _, px, py, pz, _ = im_seg.dim

        # First derivative of the fitted centerline
        _, _, arr_ctl_der, _ = get_centerline(im_seg, param=ParamCenterline(), verbose=1)
        deriv_x, deriv_y, _ = arr_ctl_der

        self.angles = np.full(nz, np.nan, dtype=np.double)
        z_axis = np.array([0, 0, 1])

        # Iterate over the derivative samples rather than [min_z_index, max_z_index], which could vary after
        # interpolation.
        # NOTE(review): angles are stored at the sample's array position, not at its z coordinate; if the fitted
        # centerline does not start at slice 0 the two may differ — confirm against get_centerline's output.
        for iz, dx in enumerate(deriv_x):
            # Tangent vector scaled by voxel size, then normalized
            tangent = self._normalize(np.array([dx * px, deriv_y[iz] * py, pz]))
            self.angles[iz] = math.degrees(np.arccos(np.vdot(tangent, z_axis)))
Exemple #6
0
 def get_center_spit(self, img_idx=-1):
     """Retrieve index of the medial plane (in the R-L direction) for each slice (in the I-S direction) in order
     to center the spinal cord in the sagittal view.

     The mask ``self._images[img_idx]`` is expected in SAL orientation (axis 0: S-I, axis 2: R-L).
     Exception: if the input mask only has a single label (e.g., for sct_detect_pmj), then output the index that has
     the sagittal slice centered at that label.

     :param img_idx: index in ``self._images`` of the mask to center on (default: last image).
     :return: one R-L index (SAL voxel coordinate) per S-I slice, or None if the mask is empty (an error is logged).
     """
     image = self._images[img_idx].copy()
     coords = np.argwhere(image.data)
     if coords.shape[0] == 0:
         # Nothing to center on: log and return None
         logging.error('Mask is empty')
         return None
     if coords.shape[0] == 1:
         # Single label (e.g., sct_detect_pmj): repeat its R-L index (SAL axis 2) once per S-I slice (SAL axis 0)
         return [coords[0][2]] * image.data.shape[0]
     # Otherwise, follow the fitted centerline of the mask
     from spinalcordtoolbox.centerline.core import ParamCenterline, get_centerline
     image.change_orientation('RPI')  # need to do that because get_centerline operates in RPI orientation
     nx, nz = image.data.shape[0], image.data.shape[2]
     _, arr_ctl_RPI, _, _ = get_centerline(image, param=ParamCenterline())
     # One R-L value per axial slice, aligned with its slice index. np.interp clamps outside the sampled z range,
     # i.e. it copies the end values below zmin and above zmax to avoid discontinuities.
     # Assumes the centerline z samples (row 2) are increasing, as returned by get_centerline.
     index_RL_in_RPI = np.interp(np.arange(nz), arr_ctl_RPI[2, :], arr_ctl_RPI[0, :])
     # R-L axis: mirror the voxel index within [0, nx - 1] to go from RPI to SAL
     index_RL_in_SAL = (nx - 1) - index_RL_in_RPI
     # S-I axis: RPI slices run I->S while SAL axis 0 runs S->I, so reverse the slice order
     return index_RL_in_SAL[::-1]
    def get_center_spit(self, img_idx=-1):
        """
        Retrieve index along in the R-L direction for each S-I slice in order to center the spinal cord in the
        medial plane, around the labels or segmentation.

        By default, it looks at the latest image in the input list of images, assuming the latest is the labels or
        segmentation.

        If only one label is found, the cord will be centered at that label.

        :param img_idx: index in ``self._images`` of the labels/segmentation image (must be SAL).
        :return: index: [int] * n_SI, or None if the mask is empty (an error is logged).
        """
        image = self._images[img_idx].copy()
        assert image.orientation == 'SAL'
        coords = np.argwhere(image.data)
        if coords.shape[0] == 0:
            # Nothing to center on: log and return None
            logging.error('Mask is empty')
            return None
        if coords.shape[0] == 1:
            # Single label (e.g. sct_detect_pmj): repeat its R-L index (SAL axis 2) once per S-I slice (SAL axis 0)
            return [coords[0][2]] * image.data.shape[0]
        # Otherwise, fit a centerline through the labels (per axial plane) and extrapolate linearly
        image.change_orientation('RPI')  # need to do that because get_centerline operates in RPI orientation
        # smooth=0 so that the centerline passes through the labels; minmax=False so that the centerline spans all
        # slices (extrapolated below zmin and above zmax), which avoids discontinuities.
        im_ctl, _, _, _ = get_centerline(
            image, param=ParamCenterline(algo_fitting='linear', smooth=0, minmax=False))
        im_ctl.change_orientation('SAL')
        # argwhere is sorted along axis 0 (S-I), so this lists the R-L index slice by slice
        return [coord[2] for coord in np.argwhere(im_ctl.data)]
def test_get_centerline_linear(img_ctl, expected):
    """Linear interpolation (legacy keyword API): fitted centerline stays within ``expected`` (L2 norm) of the
    reference centerline."""
    img_ref = img_ctl[0].copy()
    img_input = img_ctl[1].copy()
    param = ParamCenterline(degree=3)
    img_out, _, _ = get_centerline(img_input, algo_fitting='linear', param=param, minmax=False, verbose=VERBOSE)
    error = np.linalg.norm(find_and_sort_coord(img_ref) - find_and_sort_coord(img_out))
    assert error < expected
Exemple #9
0
def test_get_centerline_polyfit(img_ctl, expected, params):
    """Polyfit: median offset, derivative smoothness and (for RPI input) output coordinates match ``expected``."""
    img_ref = img_ctl[0].copy()
    img_input = img_ctl[1].copy()
    param = ParamCenterline(algo_fitting='polyfit', minmax=False)
    img_out, arr_out, arr_deriv_out, _ = get_centerline(img_input, param, verbose=VERBOSE)
    assert np.median(find_and_sort_coord(img_ref) - find_and_sort_coord(img_out)) == expected['median']
    # Largest jump between consecutive derivative values
    assert np.abs(np.diff(arr_deriv_out)).max() < expected['laplacian']
    # The output array is always in RPI, so it is only comparable when the input is RPI too
    if img_ref.orientation == 'RPI':
        assert np.linalg.norm(find_and_sort_coord(img_ref) - arr_out) < expected['norm']
def test_get_centerline_polyfit(img_ctl, expected):
    """Polyfit of degree 3 (legacy keyword API): fitted centerline stays within ``expected`` of the reference."""
    degree = 3
    img_ref = img_ctl[0].copy()
    img_input = img_ctl[1].copy()
    img_out, arr_out, _ = get_centerline(img_input, algo_fitting='polyfit', param=ParamCenterline(degree=degree),
                                         minmax=False, verbose=VERBOSE)
    assert np.linalg.norm(find_and_sort_coord(img_ref) - find_and_sort_coord(img_out)) < expected
    # The output array is always in RPI, so it is only comparable when the input is RPI too
    if img_ref.orientation == 'RPI':
        assert np.linalg.norm(find_and_sort_coord(img_ref) - arr_out) < expected
Exemple #11
0
def flatten_sagittal(im_anat, im_centerline, verbose):
    """
    Flatten a 3D volume using the segmentation, such that the spinal cord is centered in the R-L medial plane.

    Each axial slice is translated along R-L so that the fitted centerline lands on column ``round(nx / 2)``.

    :param im_anat: Image to flatten. NOTE: it is reoriented to RPI in place and not restored.
    :param im_centerline: centerline/segmentation Image used to fit the centerline (also reoriented to RPI in place).
    :param verbose: verbosity forwarded to get_centerline.
    :return: flattened float32 Image, intensities rescaled to [-1, 1], in the native orientation of im_anat.
    """
    # re-orient to RPI
    orientation_native = im_anat.orientation
    im_anat.change_orientation("RPI")
    im_centerline.change_orientation("RPI")
    nx, _, nz, _, _, _, _, _ = im_anat.dim

    # smooth centerline and return fitted coordinates in voxel space
    _, arr_ctl, _, _ = get_centerline(im_centerline,
                                      param=ParamCenterline(),
                                      verbose=verbose)
    x_centerline_fit, _, z_centerline = arr_ctl

    # One x value per slice, aligned with the slice index. np.interp copies the end values below zmin and above zmax
    # to avoid discontinuities (assumes z_centerline is increasing, as returned by get_centerline).
    x_centerline_extended = np.interp(np.arange(nz), z_centerline, x_centerline_fit)

    # change type to float32 and scale between -1 and 1 as requested by img_as_float(). See #1790, #2069
    im_anat_flattened = change_type(im_anat, np.float32)
    min_data, max_data = np.min(im_anat_flattened.data), np.max(im_anat_flattened.data)
    if max_data > min_data:
        # Subtract the minimum first so that [min, max] maps onto [-1, 1] whatever the data offset
        im_anat_flattened.data = 2 * (im_anat_flattened.data - min_data) / (max_data - min_data) - 1
    else:
        # Constant image: avoid 0/0 and map every voxel to the middle of the range
        im_anat_flattened.data = np.zeros_like(im_anat_flattened.data)

    # loop and translate each axial slice, such that the flattened centerline is centered in the medial plane (R-L)
    x_center = np.round(nx / 2.0)
    for iz in range(nz):
        translation_x = x_centerline_extended[iz] - x_center
        # skimage translations are (col, row); rows of the 2D slice are the x (R-L) axis
        tform = transform.SimilarityTransform(translation=(0, translation_x))
        # important to force input in float to skikit image, because it will output float values
        img = img_as_float(im_anat_flattened.data[:, :, iz])
        im_anat_flattened.data[:, :, iz] = transform.warp(img, tform)

    # change back to native orientation
    im_anat_flattened.change_orientation(orientation_native)

    return im_anat_flattened
Exemple #12
0
def test_get_centerline_optic(params):
    """Optic centerline (with a NaN and an Inf voxel injected) matches the stored ground-truth centerline exactly."""
    # TODO: add assert on the output .csv files for more precision
    im = Image(params['fname_image'])
    im.change_type('float32')
    # Non-finite values in a corner voxel must not break the optic pipeline
    im.data[0, 0, 0] = np.nan
    im.data[1, 0, 0] = np.inf
    param = ParamCenterline(algo_fitting='optic', contrast=params['contrast'], minmax=False)
    im_centerline, _, _, _ = get_centerline(im, param, verbose=VERBOSE)
    im_truth = Image(params['fname_centerline-optic'])
    assert (im_centerline.data == im_truth.data).all()
def test_compute_shape(im_seg, expected, params):
    """Shape metrics from ``process_seg.compute_shape`` match ``expected`` within 5% (or are NaN when NaN expected).

    With ``params['slice']`` each metric is read at that slice; otherwise its mean across slices is used.
    """
    metrics, fit_results = process_seg.compute_shape(
        im_seg,
        angle_correction=params['angle_corr'],
        param_centerline=ParamCenterline(),
        verbose=VERBOSE)
    for key, expected_raw in expected.items():
        # fetch obtained_value for the metric under test (not always 'area')
        if 'slice' in params:
            obtained_value = float(metrics[key].data[params['slice']])
        else:
            obtained_value = float(np.mean(metrics[key].data))
        # NaN expected: NaN obtained; keep checking the remaining keys instead of stopping here
        if isinstance(expected_raw, float) and math.isnan(expected_raw):
            assert math.isnan(obtained_value)
            continue
        assert obtained_value == pytest.approx(expected_raw, rel=0.05)
Exemple #14
0
    def __init__(self,
                 input_filename,
                 centerline_filename,
                 debug=0,
                 param_centerline=None,
                 interpolation_warp='spline',
                 rm_tmp_files=1,
                 verbose=1,
                 precision=2.0,
                 threshold_distance=10,
                 output_filename=''):
        """Store inputs and options; initialise QC metrics and output flags.

        :param input_filename: image to process.
        :param centerline_filename: centerline (or segmentation) image.
        :param param_centerline: ParamCenterline used for centerline fitting. When omitted (None), a new
            ``ParamCenterline()`` is created for this instance; a default instance in the signature would be
            evaluated once and shared (and mutated) by every instance.
        :param rm_tmp_files: remove temporary files.
        :param output_filename: output file name ('' by default).
        """
        self.input_filename = input_filename
        self.centerline_filename = centerline_filename
        self.output_filename = output_filename
        self.debug = debug
        self.interpolation_warp = interpolation_warp
        self.remove_temp_files = rm_tmp_files  # remove temporary files
        self.verbose = verbose
        self.precision = precision
        self.threshold_distance = threshold_distance
        self.path_output = ""
        self.use_straight_reference = False
        self.centerline_reference_filename = ""
        self.discs_input_filename = ""
        self.discs_ref_filename = ""
        self.speed_factor = 1.0  # Speed parameter
        self.xy_size = 70  # in mm
        self.param_centerline = ParamCenterline() if param_centerline is None else param_centerline

        # QC metrics
        self.accuracy_results = 0
        self.mse_straightening = 0.0
        self.max_distance_straightening = 0.0
        self.elapsed_time = 0.0
        self.elapsed_time_accuracy = 0.0

        # Outputs
        self.curved2straight = True
        self.straight2curved = True
        self.path_qc = None

        self.template_orientation = 0
def test_compute_shape(im_seg, expected, params):
    """Shape metrics from ``process_seg.compute_shape`` match ``expected`` within 5% (or are NaN when NaN expected).

    With ``params['slice']`` each metric is read at that slice. Otherwise 'length' is summed across slices and
    every other metric is averaged.
    """
    metrics, fit_results = process_seg.compute_shape(im_seg,
                                                     angle_correction=params['angle_corr'],
                                                     param_centerline=ParamCenterline(),
                                                     verbose=VERBOSE)
    for key, expected_raw in expected.items():
        # fetch obtained_value for the metric under test (not always 'area')
        if 'slice' in params:
            obtained_value = float(metrics[key].data[params['slice']])
        elif key == 'length':
            # when computing length, sums values across slices
            obtained_value = metrics[key].data.sum()
        else:
            # otherwise, average across slices
            obtained_value = metrics[key].data.mean()
        # NaN expected: NaN obtained; keep checking the remaining keys instead of stopping here
        if isinstance(expected_raw, float) and math.isnan(expected_raw):
            assert math.isnan(obtained_value)
            continue
        assert obtained_value == pytest.approx(expected_raw, rel=0.05)
Exemple #16
0
def main(args=None):
    """Compute spinal cord shape metrics from a segmentation and save them, aggregated, as CSV.

    'length' is summed across the selected slices/levels; every other metric is summarised with its
    weighted mean ('MEAN') and standard deviation ('STD'). Optionally generates a QC report.

    :param args: list of command-line arguments. When empty or None, ``sys.argv[1:]`` is parsed, and
        ``--help`` is shown if that is empty too.
    """
    parser = get_parser()
    if args:
        arguments = parser.parse_args(args)
    else:
        arguments = parser.parse_args(args=None if sys.argv[1:] else ['--help'])

    # Aggregation functions applied along S-I
    group_funcs_default = (('MEAN', func_wa), ('STD', func_std))
    group_funcs_length = (('SUM', func_sum), )

    fname_segmentation = get_absolute_path(arguments.i)
    file_out = '' if arguments.o is None else os.path.abspath(arguments.o)
    append = 0 if arguments.append is None else arguments.append
    vert_levels = '' if arguments.vert is None else arguments.vert
    remove_temp_files = arguments.r  # read but currently unused
    fname_vert_levels = '' if arguments.vertfile is None else arguments.vertfile
    perlevel = arguments.perlevel
    slices = '' if arguments.z is None else arguments.z
    perslice = arguments.perslice
    angle_correction = arguments.angle_corr
    param_centerline = ParamCenterline(algo_fitting=arguments.centerline_algo,
                                       smooth=arguments.centerline_smooth,
                                       minmax=True)
    path_qc = arguments.qc
    qc_dataset = arguments.qc_dataset
    qc_subject = arguments.qc_subject

    verbose = int(arguments.v)
    init_sct(log_level=verbose, update=True)  # Update log level

    file_out = file_out or 'csa.csv'

    metrics, fit_results = compute_shape(fname_segmentation,
                                         angle_correction=angle_correction,
                                         param_centerline=param_centerline,
                                         verbose=verbose)
    metrics_agg = {}
    for key in metrics:
        # Cord length is a sum of slice-wise lengths; other metrics are averaged
        group_funcs = group_funcs_length if key == 'length' else group_funcs_default
        metrics_agg[key] = aggregate_per_slice_or_level(
            metrics[key],
            slices=parse_num_list(slices),
            levels=parse_num_list(vert_levels),
            perslice=perslice,
            perlevel=perlevel,
            vert_level=fname_vert_levels,
            group_funcs=group_funcs)
    metrics_agg_merged = merge_dict(metrics_agg)
    save_as_csv(metrics_agg_merged, file_out, fname_in=fname_segmentation, append=append)

    # QC report (only show CSA for clarity)
    if path_qc is not None:
        generate_qc(fname_segmentation,
                    args=args,
                    path_qc=os.path.abspath(path_qc),
                    dataset=qc_dataset,
                    subject=qc_subject,
                    path_img=_make_figure(metrics_agg_merged, fit_results),
                    process='sct_process_segmentation')

    display_open(file_out)
    def continuous_vertebral_levels(self):
        """
        Transform the vertebral levels image into a continuous-valued image.

        Instead of an integer vertebral level on each slice, every non-zero voxel receives a continuous value that
        locates its slice within the vertebral-level coordinate system: ``level + 2.0 * d / D``.

        * ``D`` is the centerline length of that level.
        * ``d`` is the centerline length from the slice to the highest-index centerline sample of the same level.

        Both lengths are measured in physical units (voxel deltas scaled by the pixel dimensions). The image must
        be RPI.

        :return: float32 Image with the geometry of ``self.image_input``.
        """
        # NOTE(review): Image()'s second positional argument receives self.verbose — confirm this is intended.
        im_input = Image(self.image_input, self.verbose)
        im_output = msct_image.zeros_like(self.image_input)

        # 1. extract vertebral levels from input image
        #   a. extract centerline
        #   b. for each slice, extract corresponding level
        _, _, _, _, px, py, pz, _ = im_input.dim
        from spinalcordtoolbox.centerline.core import ParamCenterline, get_centerline
        _, arr_ctl, _, _ = get_centerline(self.image_input, param=ParamCenterline())
        x_centerline_fit, y_centerline_fit, z_centerline = arr_ctl
        value_centerline = np.array([
            im_input.data[int(x), int(y), int(z)]
            for x, y, z in zip(x_centerline_fit, y_centerline_fit, z_centerline)
        ])

        # Centerline coordinates (voxels), one column per sample, and pixel dimensions as a column vector
        coords = np.vstack([x_centerline_fit, y_centerline_fit, z_centerline]).astype(float)
        pixdim = np.array([px, py, pz], dtype=float)[:, np.newaxis]

        def _path_length(indexes):
            """Physical length of the polyline joining the centerline samples at ``indexes``, in order."""
            steps = np.diff(coords[:, indexes], axis=1) * pixdim
            return np.sum(np.sqrt(np.sum(steps ** 2, axis=0)))

        # 2. compute the centerline length of each vertebral level --> Di for i being the vertebral levels
        length_levels = {
            level: _path_length(np.flatnonzero(value_centerline == level))
            for level in dict.fromkeys(value_centerline)
        }

        # 3. for each slice:
        #   a. identify corresponding vertebral level --> i
        #   b. calculate distance of slice from upper vertebral level --> d (same metric as Di)
        #   c. compute relative distance in the vertebral level coordinate system --> d/Di
        # NOTE(review): the 2.0 factor makes values span [level, level + 2]; confirm this is intended.
        continuous_values = {}
        for it, iz in enumerate(z_centerline):
            level = value_centerline[it]
            same_level = np.flatnonzero(value_centerline == level)
            distance_from_level = _path_length(same_level[same_level >= it])
            length_level = length_levels[level]
            if length_level > 0:
                continuous_values[iz] = level + 2.0 * distance_from_level / float(length_level)
            else:
                # Level sampled by a single centerline point: zero length, avoid 0/0
                continuous_values[iz] = level

        # 4. saving data
        # for each slice, get all non-zero pixels and replace with continuous values
        coordinates_input = self.image_input.getNonZeroCoordinates()
        im_output.change_type(np.float32)
        for coord in coordinates_input:
            im_output.data[int(coord.x), int(coord.y), int(coord.z)] = continuous_values[coord.z]

        return im_output
Exemple #18
0
def find_centerline(algo, image_fname, contrast_type, brain_bool, folder_output, remove_temp_files, centerline_fname):
    """
    Compute (or load) the spinal cord centerline of an image. Assumes RPI orientation.

    :param algo: 'svm' (optic), 'cnn' (optic on a CNN heatmap), 'viewer' (manual labels) or 'file' (load
        ``centerline_fname``). Any other value logs an error and exits with status 1.
    :param image_fname: path of the input image.
    :param contrast_type: contrast key; for 'cnn' one of 't2', 't2s', 't1', 'dwi'.
    :param brain_bool: forwarded to heatmap() ('cnn' only).
    :param folder_output: unused.
    :param remove_temp_files: unused.
    :param centerline_fname: centerline image path ('file' only).
    :return: tuple ("dummy_file_name", centerline Image, viewer labels Image or None)
    """
    im = Image(image_fname)
    ctl_absolute_path = sct.add_suffix(im.absolutepath, "_ctr")

    # isct_spine_detect requires nz > 1: duplicate a single slice, and split it back at the end
    bool_2d = im.dim[2] == 1
    if bool_2d:
        im = concat_data([im, im], dim=2)
        im.hdr['dim'][3] = 2  # Needs to be change manually since dim not updated during concat_data

    # Only the viewer produces labels
    im_labels = None

    # TODO: maybe change 'svm' for 'optic', because this is how we call it in sct_get_centerline
    if algo == 'svm':
        # run optic on a heatmap computed by a trained SVM+HoG algorithm
        im_ctl, _, _, _ = get_centerline(im, ParamCenterline(algo_fitting='optic', contrast=contrast_type))

    elif algo == 'cnn':
        # CNN parameters, per contrast
        patch_by_contrast = {'t2': {'size': (80, 80), 'mean': 51.1417, 'std': 57.4408},
                             't2s': {'size': (80, 80), 'mean': 68.8591, 'std': 71.4659},
                             't1': {'size': (80, 80), 'mean': 55.7359, 'std': 64.3149},
                             'dwi': {'size': (80, 80), 'mean': 55.744, 'std': 45.003}}
        arch_by_contrast = {'t2': {'features': 16, 'dilation_layers': 2},
                            't2s': {'features': 8, 'dilation_layers': 3},
                            't1': {'features': 24, 'dilation_layers': 3},
                            'dwi': {'features': 8, 'dilation_layers': 2}}

        # load model
        ctr_model_fname = os.path.join(sct.__sct_dir__, 'data', 'deepseg_sc_models', '{}_ctr.h5'.format(contrast_type))
        patch = patch_by_contrast[contrast_type]
        arch = arch_by_contrast[contrast_type]
        ctr_model = nn_architecture_ctr(height=patch['size'][0],
                                        width=patch['size'][1],
                                        channels=1,
                                        classes=1,
                                        features=arch['features'],
                                        depth=2,
                                        temperature=1.0,
                                        padding='same',
                                        batchnorm=True,
                                        dropout=0.0,
                                        dilation_layers=arch['dilation_layers'])
        ctr_model.load_weights(ctr_model_fname)

        # compute the heatmap, then run optic on it
        im_heatmap, z_max = heatmap(im=im,
                                    model=ctr_model,
                                    patch_shape=patch['size'],
                                    mean_train=patch['mean'],
                                    std_train=patch['std'],
                                    brain_bool=brain_bool)
        im_ctl, _, _, _ = get_centerline(im_heatmap, ParamCenterline(algo_fitting='optic', contrast=contrast_type))

        # Remove the centerline above the brain cut-off, when one was detected
        if z_max is not None:
            sct.printv('Cropping brain section.')
            im_ctl.data[:, :, z_max:] = 0

    elif algo == 'viewer':
        im_labels = _call_viewer_centerline(im)
        im_ctl, _, _, _ = get_centerline(im_labels, param=ParamCenterline())

    elif algo == 'file':
        im_ctl = Image(centerline_fname)
        im_ctl.change_orientation('RPI')

    else:
        logger.error('The parameter "-centerline" is incorrect. Please try again.')
        sys.exit(1)

    # TODO: for some reason, when algo == 'file', the absolutepath is changed to None out of the method find_centerline
    im_ctl.absolutepath = ctl_absolute_path

    if bool_2d:
        im_ctl = split_data(im_ctl, dim=2)[0]

    # TODO: remove unecessary return params
    return "dummy_file_name", im_ctl, im_labels
Exemple #19
0
def get_parser():
    param = Param()
    parser = Parser(__file__)
    parser.usage.set_description('Register an anatomical image to the spinal cord MRI template (default: PAM50).\n\n'
                                 'The registration process includes three main registration steps:\n'
                                   '1. straightening of the image using the spinal cord segmentation (see sct_straighten_spinalcord for details);\n'
                                   '2. vertebral alignment between the image and the template, using labels along the spine;\n'
                                   '3. iterative slice-wise non-linear registration (see sct_register_multimodal for details)\n\n'
                                 'To register a subject to the template, try the default command:\n'
                                   'sct_register_to_template -i data.nii.gz -s data_seg.nii.gz -l data_labels.nii.gz\n\n'
                                 'If this default command does not produce satisfactory results, please refer to:\n'
                                   'https://sourceforge.net/p/spinalcordtoolbox/wiki/registration_tricks/\n\n'
                                 'The default registration method brings the subject image to the template, which can be problematic with highly non-isotropic images as it would induce large interpolation errors during the straightening procedure. Although the default method is recommended, you may want to register the template to the subject (instead of the subject to the template) by skipping the straightening procedure. To do so, use the parameter "-ref subject". Example below:\n'
                                   'sct_register_to_template -i data.nii.gz -s data_seg.nii.gz -l data_labels.nii.gz -ref subject -param step=1,type=seg,algo=centermassrot,smooth=0:step=2,type=seg,algo=columnwise,smooth=0,smoothWarpXY=2\n\n'
                                 'Vertebral alignment (step 2) consists in aligning the vertebrae between the subject and the template. Two types of labels are possible:\n'
                                   '- Vertebrae mid-body labels, created at the center of the spinal cord using the parameter "-l";\n'
                                   '- Posterior edge of the intervertebral discs, using the parameter "-ldisc".\n\n'
                                 'If only one label is provided, a simple translation will be applied between the subject label and the template label. No scaling will be performed. \n\n'
                                 'If two labels are provided, a linear transformation (translation + rotation + superior-inferior linear scaling) will be applied. The strategy here is to defined labels that cover the region of interest. For example, if you are interested in studying C2 to C6 levels, then provide one label at C2 and another at C6. However, note that if the two labels are very far apart (e.g. C2 and T12), there might be a mis-alignment of discs because a subject''s intervertebral discs distance might differ from that of the template.\n\n'
                                 'If more than two labels (only with the parameter "-disc") are used, a non-linear registration will be applied to align the each intervertebral disc between the subject and the template, as described in sct_straighten_spinalcord. This the most accurate and preferred method. This feature does not work with the parameter "-ref subject".\n\n'
                                 'More information about label creation can be found at https://www.slideshare.net/neuropoly/sct-course-20190121/42'
      )
    parser.add_option(name="-i",
                      type_value="file",
                      description="Anatomical image.",
                      mandatory=True,
                      example="anat.nii.gz")
    parser.add_option(name="-s",
                      type_value="file",
                      description="Spinal cord segmentation.",
                      mandatory=True,
                      example="anat_seg.nii.gz")
    parser.add_option(name="-l",
                      type_value="file",
                      description="One or two labels (preferred) located at the center of the spinal cord, on the "
                                  "mid-vertebral slice. For more information about label creation, please see: "
                                  "https://www.slideshare.net/neuropoly/sct-course-20190121/42",
                      mandatory=False,
                      default_value='',
                      example="anat_labels.nii.gz")
    parser.add_option(name="-ldisc",
                      type_value="file",
                      description="Labels located at the posterior edge of the intervertebral discs. If you are using "
                                  "more than 2 labels, all disc covering the region of interest should be provided. "
                                  "E.g., if you are interested in levels C2 to C7, then you should provide disc labels "
                                  "2,3,4,5,6,7). For more information about label creation, please refer to "
                                  "https://www.slideshare.net/neuropoly/sct-course-20190121/42",  # TODO: update URL
                      mandatory=False,
                      default_value='',
                      example="anat_labels.nii.gz")
    parser.add_option(name="-ofolder",
                      type_value="folder_creation",
                      description="Output folder.",
                      mandatory=False,
                      default_value='')
    parser.add_option(name="-t",
                      type_value="folder",
                      description="Path to template.",
                      mandatory=False,
                      default_value=param.path_template)
    parser.add_option(name='-c',
                      type_value='multiple_choice',
                      description='Contrast to use for registration.',
                      mandatory=False,
                      default_value='t2',
                      example=['t1', 't2', 't2s'])
    parser.add_option(name='-ref',
                      type_value='multiple_choice',
                      description='Reference for registration: template: subject->template, subject: template->subject.',
                      mandatory=False,
                      default_value='template',
                      example=['template', 'subject'])
    parser.add_option(name="-param",
                      type_value=[[':'], 'str'],
                      description='Parameters for registration (see sct_register_multimodal). Default: \
                      \n--\nstep=0\ntype=' + paramreg.steps['0'].type + '\ndof=' + paramreg.steps['0'].dof + '\
                      \n--\nstep=1\ntype=' + paramreg.steps['1'].type + '\nalgo=' + paramreg.steps['1'].algo + '\nmetric=' + paramreg.steps['1'].metric + '\niter=' + paramreg.steps['1'].iter + '\nsmooth=' + paramreg.steps['1'].smooth + '\ngradStep=' + paramreg.steps['1'].gradStep + '\nslicewise=' + paramreg.steps['1'].slicewise + '\nsmoothWarpXY=' + paramreg.steps['1'].smoothWarpXY + '\npca_eigenratio_th=' + paramreg.steps['1'].pca_eigenratio_th + '\
                      \n--\nstep=2\ntype=' + paramreg.steps['2'].type + '\nalgo=' + paramreg.steps['2'].algo + '\nmetric=' + paramreg.steps['2'].metric + '\niter=' + paramreg.steps['2'].iter + '\nsmooth=' + paramreg.steps['2'].smooth + '\ngradStep=' + paramreg.steps['2'].gradStep + '\nslicewise=' + paramreg.steps['2'].slicewise + '\nsmoothWarpXY=' + paramreg.steps['2'].smoothWarpXY + '\npca_eigenratio_th=' + paramreg.steps['1'].pca_eigenratio_th,
                      mandatory=False)
    parser.add_option(name='-centerline-algo',
                      type_value='multiple_choice',
                      description='Algorithm for centerline fitting (when straightening the spinal cord).',
                      mandatory=False,
                      example=['polyfit', 'bspline', 'linear', 'nurbs'],
                      default_value=ParamCenterline().algo_fitting)
    parser.add_option(name='-centerline-smooth',
                      type_value='int',
                      description='Degree of smoothing for centerline fitting. Only use with -centerline-algo {bspline, linear}.',
                      mandatory=False,
                      default_value=ParamCenterline().smooth)
    parser.add_option(name='-qc',
                      type_value='folder_creation',
                      description='The path where the quality control generated content will be saved',
                      default_value=param.path_qc)
    parser.add_option(name='-qc-dataset',
                      type_value='str',
                      description='If provided, this string will be mentioned in the QC report as the dataset the process was run on',
                      )
    parser.add_option(name='-qc-subject',
                      type_value='str',
                      description='If provided, this string will be mentioned in the QC report as the subject the process was run on',
                      )
    parser.add_option(name="-igt",
                      type_value="image_nifti",
                      description="File name of ground-truth template cord segmentation (binary nifti).",
                      mandatory=False)
    parser.add_option(name="-r",
                      type_value="multiple_choice",
                      description="""Remove temporary files.""",
                      mandatory=False,
                      default_value=param.remove_temp_files,
                      example=['0', '1'])
    parser.add_option(name="-v",
                      type_value="multiple_choice",
                      description="""Verbose. 0: nothing. 1: basic. 2: extended.""",
                      mandatory=False,
                      default_value=param.verbose,
                      example=['0', '1', '2'])
    return parser
Exemple #20
0
def main(args=None):

    # initializations
    param = Param()

    # check user arguments
    if not args:
        args = sys.argv[1:]

    # Get parser info
    parser = get_parser()
    arguments = parser.parse(args)
    fname_data = arguments['-i']
    fname_seg = arguments['-s']
    if '-l' in arguments:
        fname_landmarks = arguments['-l']
        label_type = 'body'
    elif '-ldisc' in arguments:
        fname_landmarks = arguments['-ldisc']
        label_type = 'disc'
    else:
        sct.printv('ERROR: Labels should be provided.', 1, 'error')
    if '-ofolder' in arguments:
        path_output = arguments['-ofolder']
    else:
        path_output = ''

    param.path_qc = arguments.get("-qc", None)

    path_template = arguments['-t']
    contrast_template = arguments['-c']
    ref = arguments['-ref']
    param.remove_temp_files = int(arguments.get('-r'))
    verbose = int(arguments.get('-v'))
    sct.init_sct(log_level=verbose, update=True)  # Update log level
    param.verbose = verbose  # TODO: not clean, unify verbose or param.verbose in code, but not both
    param_centerline = ParamCenterline(
        algo_fitting=arguments['-centerline-algo'],
        smooth=arguments['-centerline-smooth'])
    # registration parameters
    if '-param' in arguments:
        # reset parameters but keep step=0 (might be overwritten if user specified step=0)
        paramreg = ParamregMultiStep([step0])
        if ref == 'subject':
            paramreg.steps['0'].dof = 'Tx_Ty_Tz_Rx_Ry_Rz_Sz'
        # add user parameters
        for paramStep in arguments['-param']:
            paramreg.addStep(paramStep)
    else:
        paramreg = ParamregMultiStep([step0, step1, step2])
        # if ref=subject, initialize registration using different affine parameters
        if ref == 'subject':
            paramreg.steps['0'].dof = 'Tx_Ty_Tz_Rx_Ry_Rz_Sz'

    # initialize other parameters
    zsubsample = param.zsubsample

    # retrieve template file names
    file_template_vertebral_labeling = get_file_label(os.path.join(path_template, 'template'), 'vertebral labeling')
    file_template = get_file_label(os.path.join(path_template, 'template'), contrast_template.upper() + '-weighted template')
    file_template_seg = get_file_label(os.path.join(path_template, 'template'), 'spinal cord')

    # start timer
    start_time = time.time()

    # get fname of the template + template objects
    fname_template = os.path.join(path_template, 'template', file_template)
    fname_template_vertebral_labeling = os.path.join(path_template, 'template', file_template_vertebral_labeling)
    fname_template_seg = os.path.join(path_template, 'template', file_template_seg)
    fname_template_disc_labeling = os.path.join(path_template, 'template', 'PAM50_label_disc.nii.gz')

    # check file existence
    # TODO: no need to do that!
    sct.printv('\nCheck template files...')
    sct.check_file_exist(fname_template, verbose)
    sct.check_file_exist(fname_template_vertebral_labeling, verbose)
    sct.check_file_exist(fname_template_seg, verbose)
    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    # sct.printv(arguments)
    sct.printv('\nCheck parameters:', verbose)
    sct.printv('  Data:                 ' + fname_data, verbose)
    sct.printv('  Landmarks:            ' + fname_landmarks, verbose)
    sct.printv('  Segmentation:         ' + fname_seg, verbose)
    sct.printv('  Path template:        ' + path_template, verbose)
    sct.printv('  Remove temp files:    ' + str(param.remove_temp_files), verbose)

    # check input labels
    labels = check_labels(fname_landmarks, label_type=label_type)

    vertebral_alignment = False
    if len(labels) > 2 and label_type == 'disc':
        vertebral_alignment = True

    path_tmp = sct.tmp_create(basename="register_to_template", verbose=verbose)

    # set temporary file names
    ftmp_data = 'data.nii'
    ftmp_seg = 'seg.nii.gz'
    ftmp_label = 'label.nii.gz'
    ftmp_template = 'template.nii'
    ftmp_template_seg = 'template_seg.nii.gz'
    ftmp_template_label = 'template_label.nii.gz'

    # copy files to temporary folder
    sct.printv('\nCopying input data to tmp folder and convert to nii...', verbose)
    Image(fname_data).save(os.path.join(path_tmp, ftmp_data))
    Image(fname_seg).save(os.path.join(path_tmp, ftmp_seg))
    Image(fname_landmarks).save(os.path.join(path_tmp, ftmp_label))
    Image(fname_template).save(os.path.join(path_tmp, ftmp_template))
    Image(fname_template_seg).save(os.path.join(path_tmp, ftmp_template_seg))
    Image(fname_template_vertebral_labeling).save(os.path.join(path_tmp, ftmp_template_label))
    if label_type == 'disc':
        Image(fname_template_disc_labeling).save(os.path.join(path_tmp, ftmp_template_label))

    # go to tmp folder
    curdir = os.getcwd()
    os.chdir(path_tmp)

    # Generate labels from template vertebral labeling
    if label_type == 'body':
        sct.printv('\nGenerate labels from template vertebral labeling', verbose)
        ftmp_template_label_, ftmp_template_label = ftmp_template_label, sct.add_suffix(ftmp_template_label, "_body")
        sct_label_utils.main(args=['-i', ftmp_template_label_, '-vert-body', '0', '-o', ftmp_template_label])

    # check if provided labels are available in the template
    sct.printv('\nCheck if provided labels are available in the template', verbose)
    image_label_template = Image(ftmp_template_label)
    labels_template = image_label_template.getNonZeroCoordinates(sorting='value')
    if labels[-1].value > labels_template[-1].value:
        sct.printv('ERROR: Wrong landmarks input. Labels must have correspondence in template space. \nLabel max '
                   'provided: ' + str(labels[-1].value) + '\nLabel max from template: ' +
                   str(labels_template[-1].value), verbose, 'error')

    # if only one label is present, force affine transformation to be Tx,Ty,Tz only (no scaling)
    if len(labels) == 1:
        paramreg.steps['0'].dof = 'Tx_Ty_Tz'
        sct.printv('WARNING: Only one label is present. Forcing initial transformation to: ' + paramreg.steps['0'].dof,
                   1, 'warning')

    # Project labels onto the spinal cord centerline because later, an affine transformation is estimated between the
    # template's labels (centered in the cord) and the subject's labels (assumed to be centered in the cord).
    # If labels are not centered, mis-registration errors are observed (see issue #1826)
    ftmp_label = project_labels_on_spinalcord(ftmp_label, ftmp_seg, param_centerline)

    # binarize segmentation (in case it has values below 0 caused by manual editing)
    sct.printv('\nBinarize segmentation', verbose)
    ftmp_seg_, ftmp_seg = ftmp_seg, sct.add_suffix(ftmp_seg, "_bin")
    sct_maths.main(['-i', ftmp_seg_,
                    '-bin', '0.5',
                    '-o', ftmp_seg])

    # Switch between modes: subject->template or template->subject
    if ref == 'template':

        # resample data to 1mm isotropic
        sct.printv('\nResample data to 1mm isotropic...', verbose)
        resample_file(ftmp_data, add_suffix(ftmp_data, '_1mm'), '1.0x1.0x1.0', 'mm', 'linear', verbose)
        ftmp_data = add_suffix(ftmp_data, '_1mm')
        resample_file(ftmp_seg, add_suffix(ftmp_seg, '_1mm'), '1.0x1.0x1.0', 'mm', 'linear', verbose)
        ftmp_seg = add_suffix(ftmp_seg, '_1mm')
        # N.B. resampling of labels is more complicated, because they are single-point labels, therefore resampling
        # with nearest neighbour can make them disappear.
        resample_labels(ftmp_label, ftmp_data, add_suffix(ftmp_label, '_1mm'))
        ftmp_label = add_suffix(ftmp_label, '_1mm')

        # Change orientation of input images to RPI
        sct.printv('\nChange orientation of input images to RPI...', verbose)

        ftmp_data = Image(ftmp_data).change_orientation("RPI", generate_path=True).save().absolutepath
        ftmp_seg = Image(ftmp_seg).change_orientation("RPI", generate_path=True).save().absolutepath
        ftmp_label = Image(ftmp_label).change_orientation("RPI", generate_path=True).save().absolutepath


        ftmp_seg_, ftmp_seg = ftmp_seg, add_suffix(ftmp_seg, '_crop')
        if vertebral_alignment:
            # cropping the segmentation based on the label coverage to ensure good registration with vertebral alignment
            # See https://github.com/neuropoly/spinalcordtoolbox/pull/1669 for details
            image_labels = Image(ftmp_label)
            coordinates_labels = image_labels.getNonZeroCoordinates(sorting='z')
            nx, ny, nz, nt, px, py, pz, pt = image_labels.dim
            offset_crop = 10.0 * pz  # cropping the image 10 mm above and below the highest and lowest label
            cropping_slices = [coordinates_labels[0].z - offset_crop, coordinates_labels[-1].z + offset_crop]
            # make sure that the cropping slices do not extend outside of the slice range (issue #1811)
            if cropping_slices[0] < 0:
                cropping_slices[0] = 0
            if cropping_slices[1] > nz:
                cropping_slices[1] = nz
            msct_image.spatial_crop(Image(ftmp_seg_), dict(((2, np.int32(np.round(cropping_slices))),))).save(ftmp_seg)
        else:
            # if we do not align the vertebral levels, we crop the segmentation from top to bottom
            im_seg_rpi = Image(ftmp_seg_)
            bottom = 0
            for data in msct_image.SlicerOneAxis(im_seg_rpi, "IS"):
                if (data != 0).any():
                    break
                bottom += 1
            top = im_seg_rpi.data.shape[2]
            for data in msct_image.SlicerOneAxis(im_seg_rpi, "SI"):
                if (data != 0).any():
                    break
                top -= 1
            msct_image.spatial_crop(im_seg_rpi, dict(((2, (bottom, top)),))).save(ftmp_seg)


        # straighten segmentation
        sct.printv('\nStraighten the spinal cord using centerline/segmentation...', verbose)

        # check if warp_curve2straight and warp_straight2curve already exist (i.e. no need to do it another time)
        fn_warp_curve2straight = os.path.join(curdir, "warp_curve2straight.nii.gz")
        fn_warp_straight2curve = os.path.join(curdir, "warp_straight2curve.nii.gz")
        fn_straight_ref = os.path.join(curdir, "straight_ref.nii.gz")

        cache_input_files=[ftmp_seg]
        if vertebral_alignment:
            cache_input_files += [
             ftmp_template_seg,
             ftmp_label,
             ftmp_template_label,
            ]
        cache_sig = sct.cache_signature(
         input_files=cache_input_files,
        )
        cachefile = os.path.join(curdir, "straightening.cache")
        if sct.cache_valid(cachefile, cache_sig) and os.path.isfile(fn_warp_curve2straight) and os.path.isfile(fn_warp_straight2curve) and os.path.isfile(fn_straight_ref):
            sct.printv('Reusing existing warping field which seems to be valid', verbose, 'warning')
            sct.copy(fn_warp_curve2straight, 'warp_curve2straight.nii.gz')
            sct.copy(fn_warp_straight2curve, 'warp_straight2curve.nii.gz')
            sct.copy(fn_straight_ref, 'straight_ref.nii.gz')
            # apply straightening
            sct_apply_transfo.main(args=[
                '-i', ftmp_seg,
                '-w', 'warp_curve2straight.nii.gz',
                '-d', 'straight_ref.nii.gz',
                '-o', add_suffix(ftmp_seg, '_straight')])
        else:
            from spinalcordtoolbox.straightening import SpinalCordStraightener
            sc_straight = SpinalCordStraightener(ftmp_seg, ftmp_seg)
            sc_straight.param_centerline = param_centerline
            sc_straight.output_filename = add_suffix(ftmp_seg, '_straight')
            sc_straight.path_output = './'
            sc_straight.qc = '0'
            sc_straight.remove_temp_files = param.remove_temp_files
            sc_straight.verbose = verbose

            if vertebral_alignment:
                sc_straight.centerline_reference_filename = ftmp_template_seg
                sc_straight.use_straight_reference = True
                sc_straight.discs_input_filename = ftmp_label
                sc_straight.discs_ref_filename = ftmp_template_label

            sc_straight.straighten()
            sct.cache_save(cachefile, cache_sig)

        # N.B. DO NOT UPDATE VARIABLE ftmp_seg BECAUSE TEMPORARY USED LATER
        # re-define warping field using non-cropped space (to avoid issue #367)
        sct_concat_transfo.main(args=[
            '-w', 'warp_straight2curve.nii.gz',
            '-d', ftmp_data,
            '-o', 'warp_straight2curve.nii.gz'])

        if vertebral_alignment:
            sct.copy('warp_curve2straight.nii.gz', 'warp_curve2straightAffine.nii.gz')
        else:
            # Label preparation:
            # --------------------------------------------------------------------------------
            # Remove unused label on template. Keep only label present in the input label image
            sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
            sct.run(['sct_label_utils', '-i', ftmp_template_label, '-o', ftmp_template_label, '-remove-reference', ftmp_label])

            # Dilating the input label so they can be straighten without losing them
            sct.printv('\nDilating input labels using 3vox ball radius')
            sct_maths.main(['-i', ftmp_label,
                            '-dilate', '3',
                            '-o', add_suffix(ftmp_label, '_dilate')])
            ftmp_label = add_suffix(ftmp_label, '_dilate')

            # Apply straightening to labels
            sct.printv('\nApply straightening to labels...', verbose)
            sct_apply_transfo.main(args=[
                '-i', ftmp_label,
                '-o', add_suffix(ftmp_label, '_straight'),
                '-d', add_suffix(ftmp_seg, '_straight'),
                '-w', 'warp_curve2straight.nii.gz',
                '-x', 'nn'])
            ftmp_label = add_suffix(ftmp_label, '_straight')

            # Compute rigid transformation straight landmarks --> template landmarks
            sct.printv('\nEstimate transformation for step #0...', verbose)
            try:
                register_landmarks(ftmp_label, ftmp_template_label, paramreg.steps['0'].dof,
                                   fname_affine='straight2templateAffine.txt', verbose=verbose)
            except RuntimeError:
                raise('Input labels do not seem to be at the right place. Please check the position of the labels. '
                      'See documentation for more details: https://www.slideshare.net/neuropoly/sct-course-20190121/42')

            # Concatenate transformations: curve --> straight --> affine
            sct.printv('\nConcatenate transformations: curve --> straight --> affine...', verbose)
            sct_concat_transfo.main(args=[
                '-w', ['warp_curve2straight.nii.gz', 'straight2templateAffine.txt'],
                '-d', 'template.nii',
                '-o', 'warp_curve2straightAffine.nii.gz'])

        # Apply transformation
        sct.printv('\nApply transformation...', verbose)
        sct_apply_transfo.main(args=[
            '-i', ftmp_data,
            '-o', add_suffix(ftmp_data, '_straightAffine'),
            '-d', ftmp_template,
            '-w', 'warp_curve2straightAffine.nii.gz'])
        ftmp_data = add_suffix(ftmp_data, '_straightAffine')
        sct_apply_transfo.main(args=[
            '-i', ftmp_seg,
            '-o', add_suffix(ftmp_seg, '_straightAffine'),
            '-d', ftmp_template,
            '-w', 'warp_curve2straightAffine.nii.gz',
            '-x', 'linear'])
        ftmp_seg = add_suffix(ftmp_seg, '_straightAffine')

        """
        # Benjamin: Issue from Allan Martin, about the z=0 slice that is screwed up, caused by the affine transform.
        # Solution found: remove slices below and above landmarks to avoid rotation effects
        points_straight = []
        for coord in landmark_template:
            points_straight.append(coord.z)
        min_point, max_point = int(np.round(np.min(points_straight))), int(np.round(np.max(points_straight)))
        ftmp_seg_, ftmp_seg = ftmp_seg, add_suffix(ftmp_seg, '_black')
        msct_image.spatial_crop(Image(ftmp_seg_), dict(((2, (min_point,max_point)),))).save(ftmp_seg)

        """
        # open segmentation
        im = Image(ftmp_seg)
        im_new = msct_image.empty_like(im)
        # binarize
        im_new.data = im.data > 0.5
        # find min-max of anat2template (for subsequent cropping)
        zmin_template, zmax_template = msct_image.find_zmin_zmax(im_new, threshold=0.5)
        # save binarized segmentation
        im_new.save(add_suffix(ftmp_seg, '_bin')) # unused?
        # crop template in z-direction (for faster processing)
        # TODO: refactor to use python module instead of doing i/o
        sct.printv('\nCrop data in template space (for faster processing)...', verbose)
        ftmp_template_, ftmp_template = ftmp_template, add_suffix(ftmp_template, '_crop')
        msct_image.spatial_crop(Image(ftmp_template_), dict(((2, (zmin_template,zmax_template)),))).save(ftmp_template)

        ftmp_template_seg_, ftmp_template_seg = ftmp_template_seg, add_suffix(ftmp_template_seg, '_crop')
        msct_image.spatial_crop(Image(ftmp_template_seg_), dict(((2, (zmin_template,zmax_template)),))).save(ftmp_template_seg)

        ftmp_data_, ftmp_data = ftmp_data, add_suffix(ftmp_data, '_crop')
        msct_image.spatial_crop(Image(ftmp_data_), dict(((2, (zmin_template,zmax_template)),))).save(ftmp_data)

        ftmp_seg_, ftmp_seg = ftmp_seg, add_suffix(ftmp_seg, '_crop')
        msct_image.spatial_crop(Image(ftmp_seg_), dict(((2, (zmin_template,zmax_template)),))).save(ftmp_seg)

        # sub-sample in z-direction
        # TODO: refactor to use python module instead of doing i/o
        sct.printv('\nSub-sample in z-direction (for faster processing)...', verbose)
        sct.run(['sct_resample', '-i', ftmp_template, '-o', add_suffix(ftmp_template, '_sub'), '-f', '1x1x' + zsubsample], verbose)
        ftmp_template = add_suffix(ftmp_template, '_sub')
        sct.run(['sct_resample', '-i', ftmp_template_seg, '-o', add_suffix(ftmp_template_seg, '_sub'), '-f', '1x1x' + zsubsample], verbose)
        ftmp_template_seg = add_suffix(ftmp_template_seg, '_sub')
        sct.run(['sct_resample', '-i', ftmp_data, '-o', add_suffix(ftmp_data, '_sub'), '-f', '1x1x' + zsubsample], verbose)
        ftmp_data = add_suffix(ftmp_data, '_sub')
        sct.run(['sct_resample', '-i', ftmp_seg, '-o', add_suffix(ftmp_seg, '_sub'), '-f', '1x1x' + zsubsample], verbose)
        ftmp_seg = add_suffix(ftmp_seg, '_sub')

        # Registration straight spinal cord to template
        sct.printv('\nRegister straight spinal cord to template...', verbose)

        # loop across registration steps
        warp_forward = []
        warp_inverse = []
        for i_step in range(1, len(paramreg.steps)):
            sct.printv('\nEstimate transformation for step #' + str(i_step) + '...', verbose)
            # identify which is the src and dest
            if paramreg.steps[str(i_step)].type == 'im':
                src = ftmp_data
                dest = ftmp_template
                interp_step = 'linear'
            elif paramreg.steps[str(i_step)].type == 'seg':
                src = ftmp_seg
                dest = ftmp_template_seg
                interp_step = 'nn'
            else:
                sct.printv('ERROR: Wrong image type.', 1, 'error')

            if paramreg.steps[str(i_step)].algo == 'centermassrot' and paramreg.steps[str(i_step)].rot_method == 'hog':
                src_seg = ftmp_seg
                dest_seg = ftmp_template_seg
            # if step>1, apply warp_forward_concat to the src image to be used
            if i_step > 1:
                # apply transformation from previous step, to use as new src for registration
                sct_apply_transfo.main(args=[
                    '-i', src,
                    '-d', dest,
                    '-w', warp_forward,
                    '-o', add_suffix(src, '_regStep' + str(i_step - 1)),
                    '-x', interp_step])
                src = add_suffix(src, '_regStep' + str(i_step - 1))
                if paramreg.steps[str(i_step)].algo == 'centermassrot' and paramreg.steps[str(i_step)].rot_method == 'hog':  # also apply transformation to the seg
                    sct_apply_transfo.main(args=[
                        '-i', src_seg,
                        '-d', dest_seg,
                        '-w', warp_forward,
                        '-o', add_suffix(src, '_regStep' + str(i_step - 1)),
                        '-x', interp_step])
                    src_seg = add_suffix(src_seg, '_regStep' + str(i_step - 1))
            # register src --> dest
            # TODO: display param for debugging
            if paramreg.steps[str(i_step)].algo == 'centermassrot' and paramreg.steps[str(i_step)].rot_method == 'hog': # im_seg case
                warp_forward_out, warp_inverse_out = register([src, src_seg], [dest, dest_seg], paramreg, param, str(i_step))
            else:
                warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
            warp_forward.append(warp_forward_out)
            warp_inverse.append(warp_inverse_out)

        # Concatenate transformations: anat --> template
        sct.printv('\nConcatenate transformations: anat --> template...', verbose)
        warp_forward.insert(0, 'warp_curve2straightAffine.nii.gz')
        sct_concat_transfo.main(args=[
            '-w', warp_forward,
            '-d', 'template.nii',
            '-o', 'warp_anat2template.nii.gz'])

        # Concatenate transformations: template --> anat
        sct.printv('\nConcatenate transformations: template --> anat...', verbose)
        warp_inverse.reverse()
        if vertebral_alignment:
            warp_inverse.append('warp_straight2curve.nii.gz')
            sct_concat_transfo.main(args=[
                '-w', warp_inverse,
                '-d', 'data.nii',
                '-o', 'warp_template2anat.nii.gz'])
        else:
            warp_inverse.append('straight2templateAffine.txt')
            warp_inverse.append('warp_straight2curve.nii.gz')
            sct_concat_transfo.main(args=[
                '-w', warp_inverse,
                '-winv', ['straight2templateAffine.txt'],
                '-d', 'data.nii',
                '-o', 'warp_template2anat.nii.gz'])

    # register template->subject
    elif ref == 'subject':

        # Change orientation of input images to RPI
        sct.printv('\nChange orientation of input images to RPI...', verbose)
        ftmp_data = Image(ftmp_data).change_orientation("RPI", generate_path=True).save().absolutepath
        ftmp_seg = Image(ftmp_seg).change_orientation("RPI", generate_path=True).save().absolutepath
        ftmp_label = Image(ftmp_label).change_orientation("RPI", generate_path=True).save().absolutepath

        # Remove unused label on template. Keep only label present in the input label image
        sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
        sct.run(['sct_label_utils', '-i', ftmp_template_label, '-o', ftmp_template_label, '-remove-reference', ftmp_label])

        # Add one label because at least 3 orthogonal labels are required to estimate an affine transformation. This
        # new label is added at the level of the upper most label (lowest value), at 1cm to the right.
        for i_file in [ftmp_label, ftmp_template_label]:
            im_label = Image(i_file)
            coord_label = im_label.getCoordinatesAveragedByValue()  # N.B. landmarks are sorted by value
            # Create new label
            from copy import deepcopy
            new_label = deepcopy(coord_label[0])
            # move it 5mm to the left (orientation is RAS)
            nx, ny, nz, nt, px, py, pz, pt = im_label.dim
            new_label.x = np.round(coord_label[0].x + 5.0 / px)
            # assign value 99
            new_label.value = 99
            # Add to existing image
            im_label.data[int(new_label.x), int(new_label.y), int(new_label.z)] = new_label.value
            # Overwrite label file
            # im_label.absolutepath = 'label_rpi_modif.nii.gz'
            im_label.save()

        # Bring template to subject space using landmark-based transformation
        sct.printv('\nEstimate transformation for step #0...', verbose)
        warp_forward = ['template2subjectAffine.txt']
        warp_inverse = ['template2subjectAffine.txt']
        try:
            register_landmarks(ftmp_template_label, ftmp_label, paramreg.steps['0'].dof, fname_affine=warp_forward[0], verbose=verbose, path_qc="./")
        except Exception:
            sct.printv('ERROR: input labels do not seem to be at the right place. Please check the position of the labels. See documentation for more details: https://www.slideshare.net/neuropoly/sct-course-20190121/42', verbose=verbose, type='error')

        # loop across registration steps
        for i_step in range(1, len(paramreg.steps)):
            sct.printv('\nEstimate transformation for step #' + str(i_step) + '...', verbose)
            # identify which is the src and dest
            if paramreg.steps[str(i_step)].type == 'im':
                src = ftmp_template
                dest = ftmp_data
                interp_step = 'linear'
            elif paramreg.steps[str(i_step)].type == 'seg':
                src = ftmp_template_seg
                dest = ftmp_seg
                interp_step = 'nn'
            else:
                sct.printv('ERROR: Wrong image type.', 1, 'error')
            # apply transformation from previous step, to use as new src for registration
            sct_apply_transfo.main(args=[
                '-i', src,
                '-d', dest,
                '-w', warp_forward,
                '-o', add_suffix(src, '_regStep' + str(i_step - 1)),
                '-x', interp_step])
            src = add_suffix(src, '_regStep' + str(i_step - 1))
            # register src --> dest
            # TODO: display param for debugging
            warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
            warp_forward.append(warp_forward_out)
            warp_inverse.insert(0, warp_inverse_out)

        # Concatenate transformations:
        sct.printv('\nConcatenate transformations: template --> subject...', verbose)
        sct_concat_transfo.main(args=[
            '-w', warp_forward,
            '-d', 'data.nii',
            '-o', 'warp_template2anat.nii.gz'])
        sct.printv('\nConcatenate transformations: subject --> template...', verbose)
        sct_concat_transfo.main(args=[
            '-w', warp_inverse,
            '-winv', ['template2subjectAffine.txt'],
            '-d', 'template.nii',
            '-o', 'warp_anat2template.nii.gz'])

    # Apply warping fields to anat and template
    sct.run(['sct_apply_transfo', '-i', 'template.nii', '-o', 'template2anat.nii.gz', '-d', 'data.nii', '-w', 'warp_template2anat.nii.gz', '-crop', '1'], verbose)
    sct.run(['sct_apply_transfo', '-i', 'data.nii', '-o', 'anat2template.nii.gz', '-d', 'template.nii', '-w', 'warp_anat2template.nii.gz', '-crop', '1'], verbose)

    # come back
    os.chdir(curdir)

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    fname_template2anat = os.path.join(path_output, 'template2anat' + ext_data)
    fname_anat2template = os.path.join(path_output, 'anat2template' + ext_data)
    sct.generate_output_file(os.path.join(path_tmp, "warp_template2anat.nii.gz"), os.path.join(path_output, "warp_template2anat.nii.gz"), verbose)
    sct.generate_output_file(os.path.join(path_tmp, "warp_anat2template.nii.gz"), os.path.join(path_output, "warp_anat2template.nii.gz"), verbose)
    sct.generate_output_file(os.path.join(path_tmp, "template2anat.nii.gz"), fname_template2anat, verbose)
    sct.generate_output_file(os.path.join(path_tmp, "anat2template.nii.gz"), fname_anat2template, verbose)
    if ref == 'template':
        # copy straightening files in case subsequent SCT functions need them
        sct.generate_output_file(os.path.join(path_tmp, "warp_curve2straight.nii.gz"), os.path.join(path_output, "warp_curve2straight.nii.gz"), verbose)
        sct.generate_output_file(os.path.join(path_tmp, "warp_straight2curve.nii.gz"), os.path.join(path_output, "warp_straight2curve.nii.gz"), verbose)
        sct.generate_output_file(os.path.join(path_tmp, "straight_ref.nii.gz"), os.path.join(path_output, "straight_ref.nii.gz"), verbose)

    # Delete temporary files
    if param.remove_temp_files:
        sct.printv('\nDelete temporary files...', verbose)
        sct.rmtree(path_tmp, verbose=verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's', verbose)

    qc_dataset = arguments.get("-qc-dataset", None)
    qc_subject = arguments.get("-qc-subject", None)
    if param.path_qc is not None:
        generate_qc(fname_data, fname_in2=fname_template2anat, fname_seg=fname_seg, args=args,
                    path_qc=os.path.abspath(param.path_qc), dataset=qc_dataset, subject=qc_subject,
                    process='sct_register_to_template')
    sct.display_viewer_syntax([fname_data, fname_template2anat], verbose=verbose)
    sct.display_viewer_syntax([fname_template, fname_anat2template], verbose=verbose)
def main(argv=None):
    """
    Straighten the spinal cord of an image along its centerline.

    Parses the command line, configures a ``SpinalCordStraightener`` from it, runs the
    straightening and, if ``-qc`` is given, generates a QC report on the straightened image.

    :param argv: list of command-line arguments (without the program name). If None,
        argparse reads ``sys.argv[1:]``.
    :return: None
    """
    import sys  # local import: needed to recover the raw CLI arguments for the QC report

    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    input_filename = arguments.i
    centerline_file = arguments.s

    sc_straight = SpinalCordStraightener(input_filename, centerline_file)

    if arguments.dest is not None:
        sc_straight.use_straight_reference = True
        sc_straight.centerline_reference_filename = str(arguments.dest)

    # Disc labels are only meaningful when a straight reference (-dest) is provided
    if arguments.ldisc_input is not None:
        if not sc_straight.use_straight_reference:
            printv(
                'Warning: discs position are not taken into account if reference is not provided.'
            )
        else:
            sc_straight.discs_input_filename = str(arguments.ldisc_input)
            sc_straight.precision = 4.0
    if arguments.ldisc_dest is not None:
        if not sc_straight.use_straight_reference:
            printv(
                'Warning: discs position are not taken into account if reference is not provided.'
            )
        else:
            sc_straight.discs_ref_filename = str(arguments.ldisc_dest)
            sc_straight.precision = 4.0

    # Handling optional arguments
    sc_straight.remove_temp_files = arguments.r
    sc_straight.interpolation_warp = arguments.x
    sc_straight.output_filename = arguments.o
    sc_straight.path_output = arguments.ofolder
    path_qc = arguments.qc
    sc_straight.verbose = verbose

    if arguments.disable_straight2curved:
        sc_straight.straight2curved = False
    if arguments.disable_curved2straight:
        sc_straight.curved2straight = False

    if arguments.speed_factor:
        sc_straight.speed_factor = arguments.speed_factor

    if arguments.xy_size:
        sc_straight.xy_size = arguments.xy_size

    sc_straight.param_centerline = ParamCenterline(
        algo_fitting=arguments.centerline_algo,
        smooth=arguments.centerline_smooth)
    if arguments.param is not None:
        params_user = arguments.param
        # update straightening parameters given as "key=value" strings
        for param in params_user:
            param_split = param.split('=')
            if param_split[0] == 'precision':
                sc_straight.precision = float(param_split[1])
            if param_split[0] == 'threshold_distance':
                sc_straight.threshold_distance = float(param_split[1])
            if param_split[0] == 'accuracy_results':
                sc_straight.accuracy_results = int(param_split[1])
            if param_split[0] == 'template_orientation':
                sc_straight.template_orientation = int(param_split[1])

    fname_straight = sc_straight.straighten()

    printv("\nFinished! Elapsed time: {} s".format(sc_straight.elapsed_time),
           verbose)

    # Generate QC report
    if path_qc is not None:
        generate_qc(fname_straight,
                    # pass the raw CLI arguments (a list of strings), like the other SCT scripts do,
                    # rather than the argparse Namespace
                    args=sys.argv[1:] if argv is None else argv,
                    path_qc=os.path.abspath(path_qc),
                    dataset=arguments.qc_dataset,
                    subject=arguments.qc_subject,
                    # str.strip('.py') strips characters, not a suffix: use splitext instead
                    process=os.path.splitext(os.path.basename(__file__))[0])

    display_viewer_syntax([fname_straight], verbose=verbose)
Exemple #22
0
def run_main():
    """Entry point of sct_get_centerline (argparse version).

    Builds a cord centerline from the input image, depending on ``-method``:
      - 'viewer': labels are picked manually in a viewer, then fitted;
      - 'fitseg': the input image itself is used as labels and fitted;
      - 'optic': automatic detection with OptiC (requires ``-c``).
    The centerline is saved as ``<output>.nii.gz`` (discrete) and ``<output>.csv``
    (continuous). A QC report is generated if ``-qc`` is given. The viewer command
    is then printed once.
    """
    init_sct()
    parser = get_parser()
    # Show the help when the script is called without any argument
    arguments = parser.parse_args(args=None if sys.argv[1:] else ['--help'])

    # Input filename
    fname_input_data = arguments.i
    fname_data = os.path.abspath(fname_input_data)

    # Method used
    method = arguments.method

    # Contrast type (mandatory for OptiC)
    contrast_type = arguments.c
    if method == 'optic' and not contrast_type:
        error = "ERROR: -c is a mandatory argument when using 'optic' method."
        printv(error, type='error')
        return

    # Gap between slices (used by the manual viewer)
    interslice_gap = arguments.gap

    param_centerline = ParamCenterline(algo_fitting=arguments.centerline_algo,
                                       smooth=arguments.centerline_smooth,
                                       minmax=True)

    # Output file prefix
    if arguments.o is not None:
        file_output = arguments.o
    else:
        path_data, file_data, _ = extract_fname(fname_data)
        file_output = os.path.join(path_data, file_data + '_centerline')

    verbose = int(arguments.v)
    init_sct(log_level=verbose, update=True)  # Update log level

    if method == 'viewer':
        # Manual labeling of cord centerline
        im_labels = _call_viewer_centerline(Image(fname_data),
                                            interslice_gap=interslice_gap)
    elif method == 'fitseg':
        im_labels = Image(fname_data)
    elif method == 'optic':
        # Automatic detection of cord centerline
        im_labels = Image(fname_data)
        param_centerline.algo_fitting = 'optic'
        param_centerline.contrast = contrast_type
    else:
        printv(
            "ERROR: The selected method is not available: {}. Please look at the help."
            .format(method),
            type='error')
        return

    # Extrapolate and regularize (or detect if optic) cord centerline
    im_centerline, arr_centerline, _, _ = get_centerline(
        im_labels, param=param_centerline, verbose=verbose)

    # save centerline as nifti (discrete) and csv (continuous) files
    fname_centerline = file_output + '.nii.gz'
    im_centerline.save(fname_centerline)
    np.savetxt(file_output + '.csv', arr_centerline.transpose(), delimiter=",")

    path_qc = arguments.qc
    qc_dataset = arguments.qc_dataset
    qc_subject = arguments.qc_subject

    # Generate QC report
    if path_qc is not None:
        generate_qc(fname_input_data,
                    fname_seg=fname_centerline,
                    args=sys.argv[1:],
                    path_qc=os.path.abspath(path_qc),
                    dataset=qc_dataset,
                    subject=qc_subject,
                    process='sct_get_centerline')

    # Print the viewer command only once (it used to be printed twice with conflicting opacities)
    display_viewer_syntax([fname_input_data, fname_centerline],
                          colormaps=['gray', 'red'],
                          opacities=['', '0.7'])
Exemple #23
0
def main(args):
    """Compute cord shape metrics from a segmentation and save them to a CSV file.

    Metrics are computed by ``process_seg.compute_shape``. They are then aggregated
    (MEAN and STD, see ``group_funcs``) per slice and/or per vertebral level, according
    to the command-line options. Optionally generates a QC report that shows only the CSA.

    :param args: list of command-line arguments (without the program name).
    :raises ValueError: if ``-angle-corr`` is given with a value other than '0' or '1'.
    """
    parser = get_parser()
    arguments = parser.parse(args)

    # Functions to perform when aggregating metrics along S-I
    group_funcs = (('MEAN', func_wa), ('STD', func_std))

    fname_segmentation = sct.get_absolute_path(arguments['-i'])
    # Output CSV: user-provided path (made absolute), else 'csa.csv' in the current directory
    file_out = os.path.abspath(arguments['-o']) if '-o' in arguments else 'csa.csv'
    append = int(arguments['-append']) if '-append' in arguments else 0
    vert_levels = arguments['-vert'] if '-vert' in arguments else ''
    fname_vert_levels = arguments['-vertfile'] if '-vertfile' in arguments else ''
    perlevel = arguments['-perlevel'] if '-perlevel' in arguments else None
    slices = arguments['-z'] if '-z' in arguments else ''
    perslice = arguments['-perslice'] if '-perslice' in arguments else None

    # Angle correction is enabled unless '-angle-corr 0' is given. Previously the variable was
    # left unbound (NameError further down) when the option was absent or had another value.
    # NOTE(review): the default True is assumed to match the parser's default for -angle-corr; confirm.
    angle_correction = True
    if '-angle-corr' in arguments:
        angle_corr = str(arguments['-angle-corr'])
        if angle_corr == '0':
            angle_correction = False
        elif angle_corr != '1':
            raise ValueError(
                "Invalid value for -angle-corr: '{}' (expected '0' or '1').".format(angle_corr))

    param_centerline = ParamCenterline(
        algo_fitting=arguments['-centerline-algo'],
        smooth=arguments['-centerline-smooth'],
        minmax=True)
    path_qc = arguments.get("-qc", None)
    qc_dataset = arguments.get("-qc-dataset", None)
    qc_subject = arguments.get("-qc-subject", None)

    verbose = int(arguments.get('-v'))
    sct.init_sct(log_level=verbose, update=True)  # Update log level

    metrics, fit_results = process_seg.compute_shape(fname_segmentation,
                                                     angle_correction=angle_correction,
                                                     param_centerline=param_centerline,
                                                     verbose=verbose)
    metrics_agg = {}
    for key in metrics:
        metrics_agg[key] = aggregate_per_slice_or_level(metrics[key], slices=parse_num_list(slices),
                                                        levels=parse_num_list(vert_levels), perslice=perslice,
                                                        perlevel=perlevel, vert_level=fname_vert_levels,
                                                        group_funcs=group_funcs)
    metrics_agg_merged = _merge_dict(metrics_agg)
    save_as_csv(metrics_agg_merged, file_out, fname_in=fname_segmentation, append=append)

    # QC report (only show CSA for clarity)
    if path_qc is not None:
        generate_qc(fname_segmentation, args=args, path_qc=os.path.abspath(path_qc), dataset=qc_dataset,
                    subject=qc_subject, path_img=_make_figure(metrics_agg_merged, fit_results),
                    process='sct_process_segmentation')

    sct.display_open(file_out)
Exemple #24
0
def run_main():
    """Entry point of sct_get_centerline (msct_parser version).

    Extracts the cord centerline of the input image in one of two ways:
      - automatically with OptiC (the default; requires ``-c``);
      - manually, when ``-method viewer`` is passed, from labels picked in a viewer.
    Writes ``<output>.nii.gz`` (discrete centerline) and ``<output>.csv`` (continuous
    coordinates), then prints the command to open the result in a viewer.
    """
    sct.init_sct()
    cli_args = sys.argv[1:]
    arguments = get_parser().parse(cli_args)

    fname_input_data = arguments["-i"]
    fname_data = os.path.abspath(fname_input_data)

    method = arguments["-method"] if "-method" in arguments else 'optic'

    # OptiC cannot run without knowing the contrast
    contrast_type = arguments["-c"] if "-c" in arguments else ''
    if method == 'optic' and not contrast_type:
        sct.printv('ERROR: -c is a mandatory argument when using Optic method.', type='error')
        return

    interslice_gap = float(arguments["-gap"]) if "-gap" in arguments else 10.0

    param_centerline = ParamCenterline(
        algo_fitting=arguments['-centerline-algo'],
        smooth=arguments['-centerline-smooth'],
        minmax=True)

    # Output prefix: explicit -o, otherwise "<input>_centerline" next to the input
    if "-o" not in arguments:
        dir_in, base_in, _ = sct.extract_fname(fname_data)
        file_output = os.path.join(dir_in, base_in + '_centerline')
    else:
        file_output = arguments["-o"]

    verbose = int(arguments.get('-v'))
    sct.init_sct(log_level=verbose, update=True)  # refresh log level now that -v is known

    image_in = Image(fname_data)
    if method != 'viewer':
        # Automatic detection: the raw image is fed to OptiC
        im_labels = image_in
        param_centerline.algo_fitting = 'optic'
        param_centerline.contrast = contrast_type
    else:
        # Manual labeling, later regularized by get_centerline
        im_labels = _call_viewer_centerline(image_in, interslice_gap=interslice_gap)

    im_centerline, arr_centerline, _, _ = get_centerline(im_labels, param=param_centerline, verbose=verbose)

    fname_centerline_nii = file_output + '.nii.gz'
    im_centerline.save(fname_centerline_nii)
    np.savetxt(file_output + '.csv', arr_centerline.transpose(), delimiter=",")

    sct.display_viewer_syntax([fname_input_data, fname_centerline_nii],
                              colormaps=['gray', 'red'],
                              opacities=['', '1'])
Exemple #25
0
def _resample_inplane_05mm(image_fname):
    """Resample an image to 0.5 mm in-plane resolution, keeping its original slice thickness.

    The result is written next to the input, with the '_resampled' suffix.

    :param image_fname: path of the image to resample.
    :return: tuple ``(fname_res, new_resolution)``: the path of the resampled image and the
        'AxBxC' resolution string (in mm) that was used.
    """
    sct.log.info("Resample the image to 0.5 mm in-plane resolution...")
    fname_res = sct.add_suffix(image_fname, '_resampled')
    # dim[4:7] are the voxel sizes (px, py, pz); only the slice thickness pz is preserved
    input_resolution = Image(image_fname).dim[4:7]
    new_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
    resampling.resample_file(image_fname, fname_res, new_resolution, 'mm', 'linear', verbose=0)
    return fname_res, new_resolution


def find_centerline(algo, image_fname, contrast_type, brain_bool, folder_output, remove_temp_files, centerline_fname):
    """
    Compute the spinal cord centerline of an image and resample image and centerline to 0.5 mm in-plane.

    Assumes RPI orientation. 2D inputs (nz == 1) are duplicated along z for processing,
    and the centerline is split back to a single slice at the end.

    :param algo: 'svm' (OptiC), 'cnn' (CNN heatmap + OptiC), 'viewer' (manual labels in a viewer)
        or 'manual' (existing centerline file given by ``centerline_fname``).
    :param image_fname: path of the input image.
    :param contrast_type: contrast key ('t1', 't2', 't2s', 'dwi' for 'cnn').
    :param brain_bool: passed to the CNN heatmap step; when True, OptiC is limited to ``z_max``.
    :param folder_output: unused here.
    :param remove_temp_files: unused here.
    :param centerline_fname: path of the manual centerline (only used when algo == 'manual').
    :return: tuple ``(fname_res, centerline_filename)``: the resampled image and the centerline file.
    """
    # TODO: remove unnecessary i/o
    if Image(image_fname).dim[2] == 1:  # isct_spine_detect requires nz > 1
        from sct_image import concat_data
        im_concat = concat_data([image_fname, image_fname], dim=2)
        im_concat.save(sct.add_suffix(image_fname, '_concat'))
        image_fname = sct.add_suffix(image_fname, '_concat')
        bool_2d = True
    else:
        bool_2d = False

    # TODO: maybe change 'svm' for 'optic', because this is how we call it in sct_get_centerline
    if algo == 'svm':
        # run optic on a heatmap computed by a trained SVM+HoG algorithm
        img_ctl, _, _ = get_centerline(Image(image_fname), algo_fitting='optic',
                                       param=ParamCenterline(contrast=contrast_type))
        centerline_filename = sct.add_suffix(image_fname, "_ctr")
        img_ctl.save(centerline_filename)

    elif algo == 'cnn':
        # CNN parameters, per contrast
        dct_patch_ctr = {'t2': {'size': (80, 80), 'mean': 51.1417, 'std': 57.4408},
                         't2s': {'size': (80, 80), 'mean': 68.8591, 'std': 71.4659},
                         't1': {'size': (80, 80), 'mean': 55.7359, 'std': 64.3149},
                         'dwi': {'size': (80, 80), 'mean': 55.744, 'std': 45.003}}
        dct_params_ctr = {'t2': {'features': 16, 'dilation_layers': 2},
                          't2s': {'features': 8, 'dilation_layers': 3},
                          't1': {'features': 24, 'dilation_layers': 3},
                          'dwi': {'features': 8, 'dilation_layers': 2}}

        # Fail with a clear message instead of a bare KeyError on unsupported contrasts
        if contrast_type not in dct_patch_ctr:
            sct.log.error('The contrast "{}" is not supported by the CNN centerline detection. '
                          'Supported contrasts: {}.'.format(contrast_type, ', '.join(sorted(dct_patch_ctr))))
            sys.exit(1)

        # load model
        ctr_model_fname = os.path.join(sct.__sct_dir__, 'data', 'deepseg_sc_models', '{}_ctr.h5'.format(contrast_type))
        ctr_model = nn_architecture_ctr(height=dct_patch_ctr[contrast_type]['size'][0],
                                        width=dct_patch_ctr[contrast_type]['size'][1],
                                        channels=1,
                                        classes=1,
                                        features=dct_params_ctr[contrast_type]['features'],
                                        depth=2,
                                        temperature=1.0,
                                        padding='same',
                                        batchnorm=True,
                                        dropout=0.0,
                                        dilation_layers=dct_params_ctr[contrast_type]['dilation_layers'])
        ctr_model.load_weights(ctr_model_fname)

        # The CNN works on the image resampled to 0.5 mm in-plane
        fname_res, _ = _resample_inplane_05mm(image_fname)

        # compute the heatmap
        fname_heatmap = sct.add_suffix(image_fname, "_heatmap")
        img_filename = ''.join(sct.extract_fname(fname_heatmap)[:2])
        fname_heatmap_nii = img_filename + '.nii'
        z_max = heatmap(filename_in=fname_res,
                        filename_out=fname_heatmap_nii,
                        model=ctr_model,
                        patch_shape=dct_patch_ctr[contrast_type]['size'],
                        mean_train=dct_patch_ctr[contrast_type]['mean'],
                        std_train=dct_patch_ctr[contrast_type]['std'],
                        brain_bool=brain_bool)

        # run optic on the heatmap
        centerline_filename = sct.add_suffix(fname_heatmap, "_ctr")
        heatmap2optic(fname_heatmap=fname_heatmap_nii,
                      lambda_value=7 if contrast_type == 't2s' else 1,
                      fname_out=centerline_filename,
                      z_max=z_max if brain_bool else None)

    elif algo == 'viewer':
        im_labels = _call_viewer_centerline(Image(image_fname))
        im_centerline, _, _ = get_centerline(im_labels)
        centerline_filename = sct.add_suffix(image_fname, "_ctr")
        im_centerline.save(centerline_filename)
    elif algo == 'manual':
        centerline_filename = sct.add_suffix(image_fname, "_ctr")
        # Re-orient the manual centerline
        Image(centerline_fname).change_orientation('RPI').save(centerline_filename)
    else:
        sct.log.error('The parameter "-centerline" is incorrect. Please try again.')
        sys.exit(1)

    if algo != 'cnn':
        # Bring both the image and the centerline (overwritten in place) to the same 0.5 mm in-plane grid
        fname_res, new_resolution = _resample_inplane_05mm(image_fname)
        resampling.resample_file(centerline_filename, centerline_filename, new_resolution,
                                 'mm', 'linear', verbose=0)

    if bool_2d:
        # Undo the z-duplication done for 2D inputs: keep only the first slice
        from sct_image import split_data
        im_split_lst = split_data(Image(centerline_filename), dim=2)
        im_split_lst[0].save(centerline_filename)

    return fname_res, centerline_filename