Example #1
def get_centerline(im_seg,
                   algo_fitting='polyfit',
                   minmax=True,
                   param=ParamCenterline(),
                   verbose=1):
    """
    Extract centerline from an image (using optic) or from a binary or weighted segmentation (using the center of mass).
    :param im_seg: Image(): Input segmentation or series of points along the centerline.
    :param algo_fitting: str:
        polyfit: Polynomial fitting
        bspline: B-spline fitting
        linear: Linear interpolation (derivatives estimated with a polynomial fit)
        nurbs: Non-Uniform Rational B-Spline fitting
        optic: Automatic centerline detection using SVM and HOG. See [Gros et al. MIA 2018].
    :param minmax: Crop output centerline to where the segmentation starts/ends. If False, the centerline will span all slices.
    :param param: ParamCenterline()
    :param verbose: int: verbose level
    :return: im_centerline: Image: Centerline in discrete coordinates (int)
    :return: arr_centerline: 3xn array: Centerline in continuous coordinates (float) for each slice, in RPI orientation.
    :return: arr_centerline_deriv: 3xn array: Derivatives of the x and y centerline wrt. z for each slice, in RPI orientation.
    """

    if not isinstance(im_seg, Image):
        raise ValueError("Expecting an image")
    # Open image and change to RPI orientation
    native_orientation = im_seg.orientation
    im_seg.change_orientation('RPI')
    px, py, pz = im_seg.dim[4:7]

    # Take the center of mass at each slice to avoid: https://stackoverflow.com/questions/2009379/interpolate-question
    x_mean, y_mean, z_mean = find_and_sort_coord(im_seg)

    # Crop output centerline to where the segmentation starts/ends
    if minmax:
        z_ref = np.array(
            range(z_mean.min().astype(int),
                  z_mean.max().astype(int)))
    else:
        z_ref = np.array(range(im_seg.dim[2]))

    # Choose method
    if algo_fitting == 'polyfit':
        x_centerline_fit, x_centerline_deriv = curve_fitting.polyfit_1d(
            z_mean, x_mean, z_ref, deg=param.degree)
        y_centerline_fit, y_centerline_deriv = curve_fitting.polyfit_1d(
            z_mean, y_mean, z_ref, deg=param.degree)

    elif algo_fitting == 'bspline':
        x_centerline_fit, x_centerline_deriv = curve_fitting.bspline(
            z_mean, x_mean, z_ref, deg=param.degree)
        y_centerline_fit, y_centerline_deriv = curve_fitting.bspline(
            z_mean, y_mean, z_ref, deg=param.degree)

    elif algo_fitting == 'linear':
        # Simple linear interpolation
        x_centerline_fit = curve_fitting.linear(z_mean, x_mean, z_ref)
        y_centerline_fit = curve_fitting.linear(z_mean, y_mean, z_ref)
        # Compute derivatives using a polynomial fit, because derivatives are not defined for linear interpolation
        _, x_centerline_deriv = curve_fitting.polyfit_1d(z_mean,
                                                         x_mean,
                                                         z_ref,
                                                         deg=5)
        _, y_centerline_deriv = curve_fitting.polyfit_1d(z_mean,
                                                         y_mean,
                                                         z_ref,
                                                         deg=5)

    elif algo_fitting == 'nurbs':
        from spinalcordtoolbox.centerline.nurbs import b_spline_nurbs
        # Interpolate such that the output centerline has the same length as z_ref
        x_mean_interp = curve_fitting.linear(z_mean, x_mean, z_ref)
        y_mean_interp = curve_fitting.linear(z_mean, y_mean, z_ref)
        x_centerline_fit, y_centerline_fit, z_centerline_fit, x_centerline_deriv, y_centerline_deriv, \
            z_centerline_deriv, error = b_spline_nurbs(x_mean_interp, y_mean_interp, z_ref, nbControl=None, point_number=3000,
                                                       all_slices=True)

    elif algo_fitting == 'optic':
        # This method is particular compared to the previous ones, as here we estimate the centerline based on the
        # image itself (not the segmentation). Hence, we can bypass the fitting procedure and centerline creation
        # and directly output results.
        from spinalcordtoolbox.centerline import optic
        im_centerline = optic.detect_centerline(im_seg, param.contrast)
        x_centerline_fit, y_centerline_fit, z_centerline = find_and_sort_coord(
            im_centerline)
        # Compute derivatives using polynomial fit
        # TODO: Fix below with reorientation of axes
        _, x_centerline_deriv = curve_fitting.polyfit_1d(z_centerline,
                                                         x_centerline_fit,
                                                         z_centerline,
                                                         deg=5)
        _, y_centerline_deriv = curve_fitting.polyfit_1d(z_centerline,
                                                         y_centerline_fit,
                                                         z_centerline,
                                                         deg=5)
        return im_centerline.change_orientation(native_orientation), \
               np.array([x_centerline_fit, y_centerline_fit, z_centerline]), \
               np.array([x_centerline_deriv, y_centerline_deriv, np.ones_like(z_centerline)])

    # Display fig of fitted curves
    if verbose == 2:
        from datetime import datetime
        import matplotlib
        matplotlib.use('Agg')  # prevent display figure
        import matplotlib.pyplot as plt
        plt.figure()
        plt.subplot(2, 1, 1)
        plt.title("Algo=%s, Deg=%s" % (algo_fitting, param.degree))
        plt.plot(z_ref * pz, x_centerline_fit * px)
        plt.plot(z_ref * pz, x_centerline_fit * px, 'b.')
        plt.plot(z_mean * pz, x_mean * px, 'ro')
        plt.ylabel("X [mm]")
        plt.subplot(2, 1, 2)
        plt.plot(z_ref * pz, y_centerline_fit * py)
        plt.plot(z_ref * pz, y_centerline_fit * py, 'b.')
        plt.plot(z_mean * pz, y_mean * py, 'ro')
        plt.xlabel("Z [mm]")
        plt.ylabel("Y [mm]")
        plt.savefig('fig_centerline_' +
                    datetime.now().strftime("%y%m%d%H%M%S%f") + '_' +
                    algo_fitting + '.png')
        plt.close()

    # Create an image with the centerline
    im_centerline = im_seg.copy()
    im_centerline.data = np.zeros(im_centerline.data.shape)
    # Assign value=1 to centerline. Make sure to clip to avoid array overflow.
    # TODO: check this round and clip-- suspicious
    im_centerline.data[
        round_and_clip(x_centerline_fit, clip=[0, im_centerline.data.shape[0]]
                       ),
        round_and_clip(y_centerline_fit, clip=[0, im_centerline.data.shape[1]]
                       ), z_ref] = 1
    # reorient centerline to native orientation
    im_centerline.change_orientation(native_orientation)
    # TODO: Reorient centerline in native orientation. For now, we output the array in RPI. Note that it is tricky to
    #   reorient in native orientation, because the voxel center is not in the middle, but in the top corner, so this
    #   needs to be taken into account during reorientation. The code below does not work properly.
    # # Get a permutation and inversion based on native orientation
    # perm, inversion = _get_permutations(im_seg.orientation, native_orientation)
    # # axes inversion (flip)
    # # ctl = np.array([x_centerline_fit[::inversion[0]], y_centerline_fit[::inversion[1]], z_ref[::inversion[2]]])
    # ctl = np.array([x_centerline_fit, y_centerline_fit, z_ref])
    # ctl_deriv = np.array([x_centerline_deriv[::inversion[0]], y_centerline_deriv[::inversion[1]], np.ones_like(z_ref)])
    # return im_centerline, \
    #        np.array([ctl[perm[0]], ctl[perm[1]], ctl[perm[2]]]), \
    #        np.array([ctl_deriv[perm[0]], ctl_deriv[perm[1]], ctl_deriv[perm[2]]])
    return im_centerline, \
           np.array([x_centerline_fit, y_centerline_fit, z_ref]), \
           np.array([x_centerline_deriv, y_centerline_deriv, np.ones_like(z_ref)])
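
A minimal usage sketch for the get_centerline() defined above (not part of the original listing). The segmentation file name is illustrative and the Image import path is assumed:

# Hypothetical usage of get_centerline(); 't2_seg.nii.gz' is an illustrative input.
from spinalcordtoolbox.image import Image  # assumed import path for the Image class used above

im_seg = Image('t2_seg.nii.gz')  # binary spinal cord segmentation
im_ctl, arr_ctl, arr_ctl_deriv = get_centerline(im_seg,
                                                algo_fitting='polyfit',
                                                minmax=True,
                                                verbose=1)
print(arr_ctl.shape)  # (3, n_slices): x, y, z of the fitted centerline in RPI
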
Example #2
        if use_viewer == "centerline":
            cmd += ["-init-centerline", tmp_output_file.absolutepath]
        elif use_viewer == "mask":
            cmd += ["-init-mask", tmp_output_file.absolutepath]

    # If using OptiC
    elif use_optic:
        path_script = os.path.dirname(__file__)
        path_sct = os.path.dirname(path_script)
        path_classifier = os.path.join(path_sct, 'data/optic_models',
                                       '{}_model'.format(contrast_type))

        init_option_optic, optic_filename = optic.detect_centerline(
            fname_data,
            contrast_type,
            path_classifier,
            folder_output,
            remove_temp_files,
            init_option,
            verbose=verbose)
        if init_option is not None:
            cmd += ["-init", str(init_option_optic)]

        cmd += ["-init-centerline", optic_filename]

    # enabling centerline extraction by default (needed by check_and_correct_segmentation() )
    cmd += ['-centerline-binary']

    # run propseg
    status, output = sct.run(cmd, verbose, raise_exception=False)

    # check status is not 0
Example #3
def propseg(img_input, options_dict):
    """
    :param img_input: source image, to be segmented
    :param options_dict: arguments as dictionary
    :return: segmented Image
    """
    arguments = options_dict
    fname_input_data = img_input.absolutepath
    fname_data = os.path.abspath(fname_input_data)
    contrast_type = arguments.c
    contrast_type_conversion = {
        't1': 't1',
        't2': 't2',
        't2s': 't2',
        'dwi': 't1'
    }
    contrast_type_propseg = contrast_type_conversion[contrast_type]

    # Starting building the command
    cmd = ['isct_propseg', '-t', contrast_type_propseg]

    if arguments.o is not None:
        fname_out = arguments.o
    else:
        fname_out = os.path.basename(add_suffix(fname_data, "_seg"))

    folder_output = os.path.dirname(fname_out)
    cmd += ['-o', folder_output]
    if not os.path.isdir(folder_output) and os.path.exists(folder_output):
        logger.error("output directory %s is not a valid directory" %
                     folder_output)
    if not os.path.exists(folder_output):
        os.makedirs(folder_output)

    if arguments.down is not None:
        cmd += ["-down", str(arguments.down)]
    if arguments.up is not None:
        cmd += ["-up", str(arguments.up)]

    remove_temp_files = arguments.r

    verbose = int(arguments.v)
    # Update for propseg binary
    if verbose > 0:
        cmd += ["-verbose"]

    # Output options
    if arguments.mesh is not None:
        cmd += ["-mesh"]
    if arguments.centerline_binary is not None:
        cmd += ["-centerline-binary"]
    if arguments.CSF is not None:
        cmd += ["-CSF"]
    if arguments.centerline_coord is not None:
        cmd += ["-centerline-coord"]
    if arguments.cross is not None:
        cmd += ["-cross"]
    if arguments.init_tube is not None:
        cmd += ["-init-tube"]
    if arguments.low_resolution_mesh is not None:
        cmd += ["-low-resolution-mesh"]
    # TODO: Not present. Why is this here? Was this renamed?
    # if arguments.detect_nii is not None:
    #     cmd += ["-detect-nii"]
    # TODO: Not present. Why is this here? Was this renamed?
    # if arguments.detect_png is not None:
    #     cmd += ["-detect-png"]

    # Helping options
    use_viewer = None
    use_optic = True  # enabled by default
    init_option = None
    rescale_header = arguments.rescale
    if arguments.init is not None:
        init_option = float(arguments.init)
        if init_option < 0:
            printv(
                'Command-line usage error: ' + str(init_option) +
                " is not a valid value for '-init'", 1, 'error')
            sys.exit(1)
    if arguments.init_centerline is not None:
        if str(arguments.init_centerline) == "viewer":
            use_viewer = "centerline"
        elif str(arguments.init_centerline) == "hough":
            use_optic = False
        else:
            if rescale_header != 1:
                fname_labels_viewer = func_rescale_header(str(
                    arguments.init_centerline),
                                                          rescale_header,
                                                          verbose=verbose)
            else:
                fname_labels_viewer = str(arguments.init_centerline)
            cmd += ["-init-centerline", fname_labels_viewer]
            use_optic = False
    if arguments.init_mask is not None:
        if str(arguments.init_mask) == "viewer":
            use_viewer = "mask"
        else:
            if rescale_header != 1:
                fname_labels_viewer = func_rescale_header(
                    str(arguments.init_mask), rescale_header)
            else:
                fname_labels_viewer = str(arguments.init_mask)
            cmd += ["-init-mask", fname_labels_viewer]
            use_optic = False
    if arguments.mask_correction is not None:
        cmd += ["-mask-correction", str(arguments.mask_correction)]
    if arguments.radius is not None:
        cmd += ["-radius", str(arguments.radius)]
    # TODO: Not present. Why is this here? Was this renamed?
    # if arguments.detect_n is not None:
    #     cmd += ["-detect-n", str(arguments.detect_n)]
    # TODO: Not present. Why is this here? Was this renamed?
    # if arguments.detect_gap is not None:
    #     cmd += ["-detect-gap", str(arguments.detect_gap)]
    # TODO: Not present. Why is this here? Was this renamed?
    # if arguments.init_validation is not None:
    #     cmd += ["-init-validation"]
    if arguments.nbiter is not None:
        cmd += ["-nbiter", str(arguments.nbiter)]
    if arguments.max_area is not None:
        cmd += ["-max-area", str(arguments.max_area)]
    if arguments.max_deformation is not None:
        cmd += ["-max-deformation", str(arguments.max_deformation)]
    if arguments.min_contrast is not None:
        cmd += ["-min-contrast", str(arguments.min_contrast)]
    if arguments.d is not None:
        cmd += ["-d", str(arguments["-d"])]
    if arguments.distance_search is not None:
        cmd += ["-dsearch", str(arguments.distance_search)]
    if arguments.alpha is not None:
        cmd += ["-alpha", str(arguments.alpha)]

    # check if input image is in 3D. Otherwise itk image reader will cut the 4D image in 3D volumes and only take the first one.
    image_input = Image(fname_data)
    image_input_rpi = image_input.copy().change_orientation('RPI')
    nx, ny, nz, nt, px, py, pz, pt = image_input_rpi.dim
    if nt > 1:
        printv(
            'ERROR: your input image needs to be 3D in order to be segmented.',
            1, 'error')

    path_data, file_data, ext_data = extract_fname(fname_data)
    path_tmp = tmp_create(basename="label_vertebrae")

    # rescale header (see issue #1406)
    if rescale_header != 1:
        fname_data_propseg = func_rescale_header(fname_data, rescale_header)
    else:
        fname_data_propseg = fname_data

    # add to command
    cmd += ['-i', fname_data_propseg]

    # if centerline or mask is asked using viewer
    if use_viewer:
        from spinalcordtoolbox.gui.base import AnatomicalParams
        from spinalcordtoolbox.gui.centerline import launch_centerline_dialog

        params = AnatomicalParams()
        if use_viewer == 'mask':
            params.num_points = 3
            params.interval_in_mm = 15  # superior-inferior interval between two consecutive labels
            params.starting_slice = 'midfovminusinterval'
        if use_viewer == 'centerline':
            # setting maximum number of points to a reasonable value
            params.num_points = 20
            params.interval_in_mm = 30
            params.starting_slice = 'top'
        im_data = Image(fname_data_propseg)

        im_mask_viewer = zeros_like(im_data)
        # im_mask_viewer.absolutepath = add_suffix(fname_data_propseg, '_labels_viewer')
        controller = launch_centerline_dialog(im_data, im_mask_viewer, params)
        fname_labels_viewer = add_suffix(fname_data_propseg, '_labels_viewer')

        if not controller.saved:
            printv(
                'The viewer has been closed before entering all manual points. Please try again.',
                1, 'error')
            sys.exit(1)
        # save labels
        controller.as_niftii(fname_labels_viewer)

        # add mask filename to parameters string
        if use_viewer == "centerline":
            cmd += ["-init-centerline", fname_labels_viewer]
        elif use_viewer == "mask":
            cmd += ["-init-mask", fname_labels_viewer]

    # If using OptiC
    elif use_optic:
        image_centerline = optic.detect_centerline(image_input, contrast_type,
                                                   verbose)
        fname_centerline_optic = os.path.join(path_tmp,
                                              'centerline_optic.nii.gz')
        image_centerline.save(fname_centerline_optic)
        cmd += ["-init-centerline", fname_centerline_optic]

    if init_option is not None:
        if init_option > 1:
            init_option /= (nz - 1)
        cmd += ['-init', str(init_option)]

    # enabling centerline extraction by default (needed by check_and_correct_segmentation() )
    cmd += ['-centerline-binary']

    # run propseg
    status, output = run_proc(cmd,
                              verbose,
                              raise_exception=False,
                              is_sct_binary=True)

    # check status is not 0
    if not status == 0:
        printv(
            'Automatic cord detection failed. Please initialize using -init-centerline or -init-mask (see help)',
            1, 'error')
        sys.exit(1)

    # build output filename
    fname_seg = os.path.join(folder_output, fname_out)
    fname_centerline = os.path.join(
        folder_output, os.path.basename(add_suffix(fname_data, "_centerline")))
    # in case header was rescaled, we need to update the output file names by removing the "_rescaled"
    if rescale_header != 1:
        mv(
            os.path.join(
                folder_output,
                add_suffix(os.path.basename(fname_data_propseg), "_seg")),
            fname_seg)
        mv(
            os.path.join(
                folder_output,
                add_suffix(os.path.basename(fname_data_propseg),
                           "_centerline")), fname_centerline)
        # if the viewer was used, copy the labelled points to the output folder (they will then be scaled back)
        if use_viewer:
            fname_labels_viewer_new = os.path.join(
                folder_output,
                os.path.basename(add_suffix(fname_data, "_labels_viewer")))
            copy(fname_labels_viewer, fname_labels_viewer_new)
            # update variable (used later)
            fname_labels_viewer = fname_labels_viewer_new

    # check consistency of segmentation
    if arguments.correct_seg:
        check_and_correct_segmentation(fname_seg,
                                       fname_centerline,
                                       folder_output=folder_output,
                                       threshold_distance=3.0,
                                       remove_temp_files=remove_temp_files,
                                       verbose=verbose)

    # copy header from input to segmentation to make sure qform is the same
    printv("Copy header input --> output(s) to make sure qform is the same.",
           verbose)
    list_fname = [fname_seg, fname_centerline]
    if use_viewer:
        list_fname.append(fname_labels_viewer)
    for fname in list_fname:
        im = Image(fname)
        im.header = image_input.header
        # they are all binary masks, hence fine to save as int8
        im.save(dtype='int8')

    return Image(fname_seg)
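
A hedged invocation sketch for the attribute-style propseg() above (not from the original listing). The attribute names mirror the accesses made in the function; in SCT they would normally come from the sct_propseg argument parser, and all values here are illustrative:

# Hypothetical arguments object for the propseg() defined above. Every attribute
# the function reads is present; options left as None are simply skipped.
from types import SimpleNamespace

from spinalcordtoolbox.image import Image  # assumed import path for the Image class used above

_unused = dict.fromkeys(
    ['down', 'up', 'mesh', 'centerline_binary', 'CSF', 'centerline_coord',
     'cross', 'init_tube', 'low_resolution_mesh', 'init', 'init_centerline',
     'init_mask', 'mask_correction', 'radius', 'nbiter', 'max_area',
     'max_deformation', 'min_contrast', 'd', 'distance_search', 'alpha'])
args = SimpleNamespace(c='t2', o='./t2_seg.nii.gz', r=1, v=1, rescale=1,
                       correct_seg=True, **_unused)

im_seg = propseg(Image('t2.nii.gz'), args)  # 't2.nii.gz' is an illustrative input
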
Example #4
def propseg(img_input, options_dict):
    """
    :param img_input: source image, to be segmented
    :param options_dict: arguments as dictionary
    :return: segmented Image
    """
    arguments = options_dict
    fname_input_data = img_input.absolutepath
    fname_data = os.path.abspath(fname_input_data)
    contrast_type = arguments["-c"]
    contrast_type_conversion = {
        't1': 't1',
        't2': 't2',
        't2s': 't2',
        'dwi': 't1'
    }
    contrast_type_propseg = contrast_type_conversion[contrast_type]

    # Starting building the command
    cmd = ['isct_propseg', '-t', contrast_type_propseg]

    if "-ofolder" in arguments:
        folder_output = arguments["-ofolder"]
    else:
        folder_output = './'
    cmd += ['-o', folder_output]
    if not os.path.isdir(folder_output) and os.path.exists(folder_output):
        sct.log.error("output directory %s is not a valid directory" %
                      folder_output)
    if not os.path.exists(folder_output):
        os.makedirs(folder_output)

    if "-down" in arguments:
        cmd += ["-down", str(arguments["-down"])]
    if "-up" in arguments:
        cmd += ["-up", str(arguments["-up"])]

    remove_temp_files = 1
    if "-r" in arguments:
        remove_temp_files = int(arguments["-r"])

    verbose = 0
    if "-v" in arguments:
        if arguments["-v"] is "1":
            verbose = 2
            cmd += ["-verbose"]

    # Output options
    if "-mesh" in arguments:
        cmd += ["-mesh"]
    if "-centerline-binary" in arguments:
        cmd += ["-centerline-binary"]
    if "-CSF" in arguments:
        cmd += ["-CSF"]
    if "-centerline-coord" in arguments:
        cmd += ["-centerline-coord"]
    if "-cross" in arguments:
        cmd += ["-cross"]
    if "-init-tube" in arguments:
        cmd += ["-init-tube"]
    if "-low-resolution-mesh" in arguments:
        cmd += ["-low-resolution-mesh"]
    if "-detect-nii" in arguments:
        cmd += ["-detect-nii"]
    if "-detect-png" in arguments:
        cmd += ["-detect-png"]

    # Helping options
    use_viewer = None
    use_optic = True  # enabled by default
    init_option = None
    rescale_header = arguments["-rescale"]
    if "-init" in arguments:
        init_option = float(arguments["-init"])
        if init_option < 0:
            sct.log.error('Command-line usage error: ' + str(init_option) +
                          " is not a valid value for '-init'")
            sys.exit(1)
    if "-init-centerline" in arguments:
        if str(arguments["-init-centerline"]) == "viewer":
            use_viewer = "centerline"
        elif str(arguments["-init-centerline"]) == "hough":
            use_optic = False
        else:
            if rescale_header != 1:
                fname_labels_viewer = func_rescale_header(str(
                    arguments["-init-centerline"]),
                                                          rescale_header,
                                                          verbose=verbose)
            else:
                fname_labels_viewer = str(arguments["-init-centerline"])
            cmd += ["-init-centerline", fname_labels_viewer]
            use_optic = False
    if "-init-mask" in arguments:
        if str(arguments["-init-mask"]) == "viewer":
            use_viewer = "mask"
        else:
            if rescale_header != 1:
                fname_labels_viewer = func_rescale_header(
                    str(arguments["-init-mask"]), rescale_header)
            else:
                fname_labels_viewer = str(arguments["-init-mask"])
            cmd += ["-init-mask", fname_labels_viewer]
            use_optic = False
    if "-mask-correction" in arguments:
        cmd += ["-mask-correction", str(arguments["-mask-correction"])]
    if "-radius" in arguments:
        cmd += ["-radius", str(arguments["-radius"])]
    if "-detect-n" in arguments:
        cmd += ["-detect-n", str(arguments["-detect-n"])]
    if "-detect-gap" in arguments:
        cmd += ["-detect-gap", str(arguments["-detect-gap"])]
    if "-init-validation" in arguments:
        cmd += ["-init-validation"]
    if "-nbiter" in arguments:
        cmd += ["-nbiter", str(arguments["-nbiter"])]
    if "-max-area" in arguments:
        cmd += ["-max-area", str(arguments["-max-area"])]
    if "-max-deformation" in arguments:
        cmd += ["-max-deformation", str(arguments["-max-deformation"])]
    if "-min-contrast" in arguments:
        cmd += ["-min-contrast", str(arguments["-min-contrast"])]
    if "-d" in arguments:
        cmd += ["-d", str(arguments["-d"])]
    if "-distance-search" in arguments:
        cmd += ["-dsearch", str(arguments["-distance-search"])]
    if "-alpha" in arguments:
        cmd += ["-alpha", str(arguments["-alpha"])]

    # check if input image is in 3D. Otherwise itk image reader will cut the 4D image in 3D volumes and only take the first one.
    image_input = Image(fname_data)
    nx, ny, nz, nt, px, py, pz, pt = image_input.dim
    if nt > 1:
        sct.log.error(
            'ERROR: your input image needs to be 3D in order to be segmented.')

    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    # rescale header (see issue #1406)
    if rescale_header != 1:
        fname_data_propseg = func_rescale_header(fname_data, rescale_header)
    else:
        fname_data_propseg = fname_data

    # add to command
    cmd += ['-i', fname_data_propseg]

    # if centerline or mask is asked using viewer
    if use_viewer:
        from spinalcordtoolbox.gui.base import AnatomicalParams
        from spinalcordtoolbox.gui.centerline import launch_centerline_dialog

        params = AnatomicalParams()
        if use_viewer == 'mask':
            params.num_points = 3
            params.interval_in_mm = 15  # superior-inferior interval between two consecutive labels
            params.starting_slice = 'midfovminusinterval'
        if use_viewer == 'centerline':
            # setting maximum number of points to a reasonable value
            params.num_points = 20
            params.interval_in_mm = 30
            params.starting_slice = 'top'
        im_data = Image(fname_data_propseg)

        im_mask_viewer = msct_image.zeros_like(im_data)
        # im_mask_viewer.absolutepath = sct.add_suffix(fname_data_propseg, '_labels_viewer')
        controller = launch_centerline_dialog(im_data, im_mask_viewer, params)
        fname_labels_viewer = sct.add_suffix(fname_data_propseg,
                                             '_labels_viewer')

        if not controller.saved:
            sct.log.error(
                'The viewer has been closed before entering all manual points. Please try again.'
            )
            sys.exit(1)
        # save labels
        controller.as_niftii(fname_labels_viewer)

        # add mask filename to parameters string
        if use_viewer == "centerline":
            cmd += ["-init-centerline", fname_labels_viewer]
        elif use_viewer == "mask":
            cmd += ["-init-mask", fname_labels_viewer]

    # If using OptiC
    elif use_optic:
        path_script = os.path.dirname(__file__)
        path_sct = os.path.dirname(path_script)
        path_classifier = os.path.join(path_sct, 'data/optic_models',
                                       '{}_model'.format(contrast_type))

        init_option_optic, fname_centerline = optic.detect_centerline(
            fname_data_propseg,
            contrast_type,
            path_classifier,
            folder_output,
            remove_temp_files,
            init_option,
            verbose=verbose)
        if init_option is not None:
            # TODO: what's this???
            cmd += ["-init", str(init_option_optic)]

        cmd += ["-init-centerline", fname_centerline]

    # enabling centerline extraction by default (needed by check_and_correct_segmentation() )
    cmd += ['-centerline-binary']

    # run propseg
    status, output = sct.run(cmd, verbose, raise_exception=False)

    # check status is not 0
    if not status == 0:
        sct.log.error(
            'Automatic cord detection failed. Please initialize using -init-centerline or '
            '-init-mask (see help).')
        sys.exit(1)

    # build output filename
    fname_seg = os.path.join(
        folder_output, os.path.basename(sct.add_suffix(fname_data, "_seg")))
    fname_centerline = os.path.join(
        folder_output,
        os.path.basename(sct.add_suffix(fname_data, "_centerline")))
    # in case header was rescaled, we need to update the output file names by removing the "_rescaled"
    if rescale_header != 1:
        sct.mv(
            os.path.join(
                folder_output,
                sct.add_suffix(os.path.basename(fname_data_propseg), "_seg")),
            fname_seg)
        sct.mv(
            os.path.join(
                folder_output,
                sct.add_suffix(os.path.basename(fname_data_propseg),
                               "_centerline")), fname_centerline)
        # if the viewer was used, copy the labelled points to the output folder (they will then be scaled back)
        if use_viewer:
            fname_labels_viewer_new = os.path.join(
                folder_output,
                os.path.basename(sct.add_suffix(fname_data, "_labels_viewer")))
            sct.copy(fname_labels_viewer, fname_labels_viewer_new)
            # update variable (used later)
            fname_labels_viewer = fname_labels_viewer_new

    # check consistency of segmentation
    if arguments["-correct-seg"] == "1":
        check_and_correct_segmentation(fname_seg,
                                       fname_centerline,
                                       folder_output=folder_output,
                                       threshold_distance=3.0,
                                       remove_temp_files=remove_temp_files,
                                       verbose=verbose)

    # copy header from input to segmentation to make sure qform is the same
    sct.printv(
        "Copy header input --> output(s) to make sure qform is the same.",
        verbose)
    list_fname = [fname_seg, fname_centerline]
    if use_viewer:
        list_fname.append(fname_labels_viewer)
    for fname in list_fname:
        im = Image(fname)
        im.header = image_input.header
        # they are all binary masks, hence fine to save as int8
        im.save(dtype='int8')

    return Image(fname_seg)
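
The same caller pattern for the dictionary-based variant above, as a hedged sketch (not from the original listing). The keys mirror the CLI flags the function looks up; only '-c', '-rescale' and '-correct-seg' are read unconditionally, and all values are illustrative:

# Hypothetical dictionary of parsed options for the propseg() defined above;
# optional flags are simply omitted.
from spinalcordtoolbox.image import Image  # assumed; the original module imports its own Image class

opts = {
    "-c": "t2",           # contrast type
    "-ofolder": "./",     # output folder
    "-rescale": 1,        # no header rescaling
    "-correct-seg": "1",  # run check_and_correct_segmentation() afterwards
}
im_seg = propseg(Image('t2.nii.gz'), opts)  # 't2.nii.gz' is an illustrative input
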
Example #5
def deep_segmentation_spinalcord(fname_image,
                                 contrast_type,
                                 output_folder,
                                 ctr_algo='cnn',
                                 brain_bool=True,
                                 kernel_size='2d',
                                 remove_temp_files=1,
                                 verbose=1):
    """Pipeline."""
    path_script = os.path.dirname(__file__)
    path_sct = os.path.dirname(path_script)

    # create temporary folder with intermediate results
    sct.log.info("Creating temporary folder...")
    file_fname = os.path.basename(fname_image)
    tmp_folder = sct.TempFolder()
    tmp_folder_path = tmp_folder.get_path()
    fname_image_tmp = tmp_folder.copy_from(fname_image)
    tmp_folder.chdir()

    # orientation of the image, should be RPI
    sct.log.info("Reorient the image to RPI, if necessary...")
    fname_orient = sct.add_suffix(file_fname, '_RPI')
    im_2orient = Image(file_fname)
    original_orientation = im_2orient.orientation
    if original_orientation != 'RPI':
        im_orient = set_orientation(im_2orient, 'RPI')
        im_orient.setFileName(fname_orient)
        im_orient.save()
    else:
        im_orient = im_2orient
        sct.copy(fname_image_tmp, fname_orient)

    # resampling RPI image
    sct.log.info("Resample the image to 0.5 mm isotropic resolution...")
    fname_res = sct.add_suffix(fname_orient, '_resampled')
    im_2res = im_orient
    input_resolution = im_2res.dim[4:7]
    new_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
    spinalcordtoolbox.resample.nipy_resample.resample_file(fname_orient,
                                                           fname_res,
                                                           new_resolution,
                                                           'mm',
                                                           'linear',
                                                           verbose=0)

    # find the spinal cord centerline - execute OptiC binary
    sct.log.info("Finding the spinal cord centerline...")
    if ctr_algo == 'svm':
        # run optic on a heatmap computed by a trained SVM+HoG algorithm
        optic_models_fname = os.path.join(path_sct, 'data', 'optic_models',
                                          '{}_model'.format(contrast_type))
        _, centerline_filename = optic.detect_centerline(
            image_fname=fname_res,
            contrast_type=contrast_type,
            optic_models_path=optic_models_fname,
            folder_output=tmp_folder_path,
            remove_temp_files=remove_temp_files,
            output_roi=False,
            verbose=0)
    elif ctr_algo == 'cnn':
        # CNN parameters
        dct_patch_ctr = {
            't2': {
                'size': (80, 80),
                'mean': 51.1417,
                'std': 57.4408
            },
            't2s': {
                'size': (80, 80),
                'mean': 68.8591,
                'std': 71.4659
            },
            't1': {
                'size': (80, 80),
                'mean': 55.7359,
                'std': 64.3149
            },
            'dwi': {
                'size': (80, 80),
                'mean': 55.744,
                'std': 45.003
            }
        }
        dct_params_ctr = {
            't2': {
                'features': 16,
                'dilation_layers': 2
            },
            't2s': {
                'features': 8,
                'dilation_layers': 3
            },
            't1': {
                'features': 24,
                'dilation_layers': 3
            },
            'dwi': {
                'features': 8,
                'dilation_layers': 2
            }
        }

        # load model
        ctr_model_fname = os.path.join(path_sct, 'data', 'deepseg_sc_models',
                                       '{}_ctr.h5'.format(contrast_type))
        ctr_model = nn_architecture_ctr(
            height=dct_patch_ctr[contrast_type]['size'][0],
            width=dct_patch_ctr[contrast_type]['size'][1],
            channels=1,
            classes=1,
            features=dct_params_ctr[contrast_type]['features'],
            depth=2,
            temperature=1.0,
            padding='same',
            batchnorm=True,
            dropout=0.0,
            dilation_layers=dct_params_ctr[contrast_type]['dilation_layers'])
        ctr_model.load_weights(ctr_model_fname)

        # compute the heatmap
        fname_heatmap = sct.add_suffix(fname_res, "_heatmap")
        img_filename = ''.join(sct.extract_fname(fname_heatmap)[:2])
        fname_heatmap_nii = img_filename + '.nii'
        z_max = heatmap(filename_in=fname_res,
                        filename_out=fname_heatmap_nii,
                        model=ctr_model,
                        patch_shape=dct_patch_ctr[contrast_type]['size'],
                        mean_train=dct_patch_ctr[contrast_type]['mean'],
                        std_train=dct_patch_ctr[contrast_type]['std'],
                        brain_bool=brain_bool)

        # run optic on the heatmap
        centerline_filename = sct.add_suffix(fname_heatmap, "_ctr")
        heatmap2optic(fname_heatmap=fname_heatmap_nii,
                      lambda_value=7 if contrast_type == 't2s' else 1,
                      fname_out=centerline_filename,
                      z_max=z_max if brain_bool else None)

    # crop image around the spinal cord centerline
    sct.log.info("Cropping the image around the spinal cord...")
    fname_crop = sct.add_suffix(fname_res, '_crop')
    crop_size = 64 if kernel_size == '2d' else 96
    X_CROP_LST, Y_CROP_LST = crop_image_around_centerline(
        filename_in=fname_res,
        filename_ctr=centerline_filename,
        filename_out=fname_crop,
        crop_size=crop_size)

    # normalize the intensity of the images
    sct.log.info("Normalizing the intensity...")
    fname_norm = sct.add_suffix(fname_crop, '_norm')
    apply_intensity_normalization(img_path=fname_crop, fname_out=fname_norm)

    if kernel_size == '2d':
        # segment data using 2D convolutions
        sct.log.info(
            "Segmenting the spinal cord using deep learning on 2D patches...")
        segmentation_model_fname = os.path.join(
            path_sct, 'data', 'deepseg_sc_models',
            '{}_sc.h5'.format(contrast_type))
        fname_seg_crop = sct.add_suffix(fname_norm, '_seg')
        seg_crop_data = segment_2d(model_fname=segmentation_model_fname,
                                   contrast_type=contrast_type,
                                   input_size=(crop_size, crop_size),
                                   fname_in=fname_norm,
                                   fname_out=fname_seg_crop)
    elif kernel_size == '3d':
        # resample to 0.5mm isotropic
        fname_res3d = sct.add_suffix(fname_norm, '_resampled3d')
        spinalcordtoolbox.resample.nipy_resample.resample_file(fname_norm,
                                                               fname_res3d,
                                                               '0.5x0.5x0.5',
                                                               'mm',
                                                               'linear',
                                                               verbose=0)

        # segment data using 3D convolutions
        sct.log.info(
            "Segmenting the spinal cord using deep learning on 3D patches...")
        segmentation_model_fname = os.path.join(
            path_sct, 'data', 'deepseg_sc_models',
            '{}_sc_3D.h5'.format(contrast_type))
        fname_seg_crop_res = sct.add_suffix(fname_res3d, '_seg')
        segment_3d(model_fname=segmentation_model_fname,
                   contrast_type=contrast_type,
                   fname_in=fname_res3d,
                   fname_out=fname_seg_crop_res)

        # resample to the initial pz resolution
        fname_seg_res2d = sct.add_suffix(fname_seg_crop_res, '_resampled2d')
        initial_2d_resolution = 'x'.join(
            ['0.5', '0.5', str(input_resolution[2])])
        spinalcordtoolbox.resample.nipy_resample.resample_file(
            fname_seg_crop_res,
            fname_seg_res2d,
            initial_2d_resolution,
            'mm',
            'linear',
            verbose=0)
        seg_crop_data = Image(fname_seg_res2d).data

    # reconstruct the segmentation from the crop data
    sct.log.info("Reassembling the image...")
    fname_seg_res_RPI = sct.add_suffix(file_fname, '_res_RPI_seg')
    uncrop_image(fname_ref=fname_res,
                 fname_out=fname_seg_res_RPI,
                 data_crop=seg_crop_data,
                 x_crop_lst=X_CROP_LST,
                 y_crop_lst=Y_CROP_LST)

    # resample to initial resolution
    sct.log.info(
        "Resampling the segmentation to the original image resolution...")
    fname_seg_RPI = sct.add_suffix(file_fname, '_RPI_seg')
    initial_resolution = 'x'.join([
        str(input_resolution[0]),
        str(input_resolution[1]),
        str(input_resolution[2])
    ])
    spinalcordtoolbox.resample.nipy_resample.resample_file(fname_seg_res_RPI,
                                                           fname_seg_RPI,
                                                           initial_resolution,
                                                           'mm',
                                                           'linear',
                                                           verbose=0)

    # binarize the resampled image to remove interpolation effects
    sct.log.info(
        "Binarizing the segmentation to avoid interpolation effects...")
    thr = '0.0001' if contrast_type in ['t1', 'dwi'] else '0.5'
    sct.run(
        ['sct_maths', '-i', fname_seg_RPI, '-bin', thr, '-o', fname_seg_RPI],
        verbose=0)

    # post processing step to z_regularized
    post_processing_volume_wise(fname_in=fname_seg_RPI)

    # reorient to initial orientation
    sct.log.info(
        "Reorienting the segmentation to the original image orientation...")
    fname_seg = sct.add_suffix(file_fname, '_seg')
    if original_orientation != 'RPI':
        im_seg_orient = set_orientation(Image(fname_seg_RPI),
                                        original_orientation)
        im_seg_orient.setFileName(fname_seg)
        im_seg_orient.save()
    else:
        sct.copy(fname_seg_RPI, fname_seg)

    tmp_folder.chdir_undo()

    # copy image from temporary folder into output folder
    sct.copy(os.path.join(tmp_folder_path, fname_seg), output_folder)

    # remove temporary files
    if remove_temp_files:
        sct.log.info("Remove temporary files...")
        tmp_folder.cleanup()

    return os.path.join(output_folder, fname_seg)
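
A hedged call sketch for the pipeline above (not from the original listing); the input file name is illustrative and the keyword values restate the defaults from the signature:

# Hypothetical invocation of deep_segmentation_spinalcord(); input path is illustrative.
fname_seg = deep_segmentation_spinalcord('t2.nii.gz',
                                         contrast_type='t2',
                                         output_folder='./',
                                         ctr_algo='cnn',
                                         brain_bool=True,
                                         kernel_size='2d',
                                         remove_temp_files=1,
                                         verbose=1)
print(fname_seg)  # path of the segmentation copied into output_folder
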
Example #6
def run_main():
    parser = Parser(__file__)
    parser.usage.set_description(
        """This program will use the OptiC method to detect the spinal cord centerline."""
    )

    parser.add_option(name="-i",
                      type_value="image_nifti",
                      description="input image.",
                      mandatory=True,
                      example="t1.nii.gz")

    parser.add_option(name="-c",
                      type_value="multiple_choice",
                      description="type of image contrast.",
                      mandatory=True,
                      example=['t1', 't2', 't2s', 'dwi'])

    parser.add_option(name="-ofolder",
                      type_value="folder_creation",
                      description="output folder.",
                      mandatory=False,
                      example="My_Output_Folder/",
                      default_value="")

    parser.add_option(
        name="-roi",
        type_value="multiple_choice",
        description="outputs a ROI file, compatible with JIM software.",
        mandatory=False,
        example=['0', '1'],
        default_value='0')

    parser.add_option(name="-r",
                      type_value="multiple_choice",
                      description="remove temporary files.",
                      mandatory=False,
                      example=['0', '1'],
                      default_value='1')

    parser.add_option(name="-v",
                      type_value="multiple_choice",
                      description="1: display on, 0: display off (default)",
                      mandatory=False,
                      example=["0", "1"],
                      default_value="1")

    args = sys.argv[1:]
    arguments = parser.parse(args)

    # Input filename
    fname_input_data = arguments["-i"]
    fname_data = os.path.abspath(fname_input_data)

    # Contrast type
    contrast_type = arguments["-c"]

    # Output folder
    if "-ofolder" in arguments:
        folder_output = sct.slash_at_the_end(arguments["-ofolder"], slash=1)
    else:
        folder_output = './'

    # Remove temporary files
    remove_temp_files = True
    if "-r" in arguments:
        remove_temp_files = bool(int(arguments["-r"]))

    # Outputs a ROI file
    output_roi = False
    if "-roi" in arguments:
        output_roi = bool(int(arguments["-roi"]))

    # Verbosity
    verbose = 0
    if "-v" in arguments:
        if arguments["-v"] is "1":
            verbose = 2

    # OptiC models
    path_script = os.path.dirname(__file__)
    path_sct = os.path.dirname(path_script)
    optic_models_path = os.path.join(path_sct, 'data/optic_models',
                                     '{}_model'.format(contrast_type))

    # Execute OptiC binary
    _, optic_filename = optic.detect_centerline(
        image_fname=fname_data,
        contrast_type=contrast_type,
        optic_models_path=optic_models_path,
        folder_output=folder_output,
        remove_temp_files=remove_temp_files,
        output_roi=output_roi,
        verbose=verbose)

    sct.printv('\nDone! To view results, type:', verbose)
    sct.printv(
        "fslview " + fname_input_data + " " + optic_filename +
        " -l Red -b 0,1 -t 0.7 &\n", verbose, 'info')
Example #7
def run_main():
    sct.start_stream_logger()
    parser = get_parser()
    args = sys.argv[1:]
    arguments = parser.parse(args)

    # Input filename
    fname_input_data = arguments["-i"]
    fname_data = os.path.abspath(fname_input_data)

    # Method used
    method = 'optic'
    if "-method" in arguments:
        method = arguments["-method"]

    # Contrast type
    contrast_type = ''
    if "-c" in arguments:
        contrast_type = arguments["-c"]
    if method == 'optic' and not contrast_type:
        # Contrast type must be specified when using the OptiC method
        error = 'ERROR: -c is a mandatory argument when using Optic method.'
        sct.printv(error, type='error')
        return

    # Gap between slices
    interslice_gap = 10.0
    if "-gap" in arguments:
        interslice_gap = float(arguments["-gap"])

    # Output folder
    if "-ofolder" in arguments:
        folder_output = sct.slash_at_the_end(arguments["-ofolder"], slash=1)
    else:
        folder_output = './'

    # Remove temporary files
    remove_temp_files = True
    if "-r" in arguments:
        remove_temp_files = bool(int(arguments["-r"]))

    # Outputs a ROI file
    output_roi = False
    if "-roi" in arguments:
        output_roi = bool(int(arguments["-roi"]))

    # Verbosity
    verbose = 0
    if "-v" in arguments:
        verbose = int(arguments["-v"])

    if method == 'viewer':
        path_data, file_data, ext_data = sct.extract_fname(fname_data)

        # create temporary folder
        temp_folder = sct.TempFolder()
        temp_folder.copy_from(fname_data)
        temp_folder.chdir()

        # make sure image is in SAL orientation, as it is the orientation used by the viewer
        image_input = Image(fname_data)
        image_input_orientation = orientation(image_input,
                                              get=True,
                                              verbose=False)
        reoriented_image_filename = sct.add_suffix(file_data + ext_data,
                                                   "_SAL")
        cmd_image = 'sct_image -i "%s" -o "%s" -setorient SAL -v 0' % (
            fname_data, reoriented_image_filename)
        sct.run(cmd_image, verbose=False)

        # extract points manually using the viewer
        fname_points = viewer_centerline(image_fname=reoriented_image_filename,
                                         interslice_gap=interslice_gap,
                                         verbose=verbose)

        if fname_points is not None:
            image_points_RPI = sct.add_suffix(fname_points, "_RPI")
            cmd_image = 'sct_image -i "%s" -o "%s" -setorient RPI -v 0' % (
                fname_points, image_points_RPI)
            sct.run(cmd_image, verbose=False)

            image_input_reoriented = Image(image_points_RPI)

            # fit centerline, smooth it and return the first derivative (in physical space)
            x_centerline_fit, y_centerline_fit, z_centerline, x_centerline_deriv, y_centerline_deriv, z_centerline_deriv = smooth_centerline(
                image_points_RPI,
                algo_fitting='nurbs',
                nurbs_pts_number=3000,
                phys_coordinates=True,
                verbose=verbose,
                all_slices=False)
            centerline = Centerline(x_centerline_fit, y_centerline_fit,
                                    z_centerline, x_centerline_deriv,
                                    y_centerline_deriv, z_centerline_deriv)

            # average centerline coordinates over slices of the image
            x_centerline_fit_rescorr, y_centerline_fit_rescorr, z_centerline_rescorr, x_centerline_deriv_rescorr, y_centerline_deriv_rescorr, z_centerline_deriv_rescorr = centerline.average_coordinates_over_slices(
                image_input_reoriented)

            # compute z_centerline in image coordinates for usage in vertebrae mapping
            voxel_coordinates = image_input_reoriented.transfo_phys2pix([[
                x_centerline_fit_rescorr[i], y_centerline_fit_rescorr[i],
                z_centerline_rescorr[i]
            ] for i in range(len(z_centerline_rescorr))])
            x_centerline_voxel = [coord[0] for coord in voxel_coordinates]
            y_centerline_voxel = [coord[1] for coord in voxel_coordinates]
            z_centerline_voxel = [coord[2] for coord in voxel_coordinates]

            # compute z_centerline in image coordinates with continuous precision
            voxel_coordinates = image_input_reoriented.transfo_phys2continuouspix(
                [[
                    x_centerline_fit_rescorr[i], y_centerline_fit_rescorr[i],
                    z_centerline_rescorr[i]
                ] for i in range(len(z_centerline_rescorr))])
            x_centerline_voxel_cont = [coord[0] for coord in voxel_coordinates]
            y_centerline_voxel_cont = [coord[1] for coord in voxel_coordinates]
            z_centerline_voxel_cont = [coord[2] for coord in voxel_coordinates]

            # Create an image with the centerline
            image_input_reoriented.data *= 0
            min_z_index, max_z_index = int(round(
                min(z_centerline_voxel))), int(round(max(z_centerline_voxel)))
            for iz in range(min_z_index, max_z_index + 1):
                image_input_reoriented.data[
                    int(round(x_centerline_voxel[iz - min_z_index])),
                    int(round(y_centerline_voxel[iz - min_z_index])),
                    int(
                        iz
                    )] = 1  # if index is out of bounds here for hanning: either the segmentation has holes or labels have been added to the file

            # Write the centerline image
            sct.printv('\nWrite NIFTI volumes...', verbose)
            fname_centerline_oriented = file_data + '_centerline' + ext_data
            image_input_reoriented.setFileName(fname_centerline_oriented)
            image_input_reoriented.changeType('uint8')
            image_input_reoriented.save()

            sct.printv('\nSet to original orientation...', verbose)
            sct.run('sct_image -i ' + fname_centerline_oriented +
                    ' -setorient ' + image_input_orientation + ' -o ' +
                    fname_centerline_oriented)

            # create a txt file with the centerline
            fname_centerline_oriented_txt = file_data + '_centerline.txt'
            file_results = open(fname_centerline_oriented_txt, 'w')
            for i in range(min_z_index, max_z_index + 1):
                file_results.write(
                    str(int(i)) + ' ' +
                    str(round(x_centerline_voxel_cont[i - min_z_index], 2)) +
                    ' ' +
                    str(round(y_centerline_voxel_cont[i - min_z_index], 2)) +
                    '\n')
            file_results.close()

            fname_centerline_oriented_roi = optic.centerline2roi(
                fname_image=fname_centerline_oriented,
                folder_output='./',
                verbose=verbose)

            # return to initial folder
            temp_folder.chdir_undo()

            # copy result to output folder
            shutil.copy(temp_folder.get_path() + fname_centerline_oriented,
                        folder_output)
            shutil.copy(temp_folder.get_path() + fname_centerline_oriented_txt,
                        folder_output)
            if output_roi:
                shutil.copy(
                    temp_folder.get_path() + fname_centerline_oriented_roi,
                    folder_output)
            centerline_filename = folder_output + fname_centerline_oriented

        else:
            centerline_filename = 'error'

        # delete temporary folder
        if remove_temp_files:
            temp_folder.cleanup()

    else:
        # condition on verbose when using OptiC
        if verbose == 1:
            verbose = 2

        # OptiC models
        path_script = os.path.dirname(__file__)
        path_sct = os.path.dirname(path_script)
        optic_models_path = os.path.join(path_sct, 'data/optic_models',
                                         '{}_model'.format(contrast_type))

        # Execute OptiC binary
        _, centerline_filename = optic.detect_centerline(
            image_fname=fname_data,
            contrast_type=contrast_type,
            optic_models_path=optic_models_path,
            folder_output=folder_output,
            remove_temp_files=remove_temp_files,
            output_roi=output_roi,
            verbose=verbose)

    sct.printv('\nDone! To view results, type:', verbose)
    sct.printv(
        "fslview " + fname_input_data + " " + centerline_filename +
        " -l Red -b 0,1 -t 0.7 &\n", verbose, 'info')
Example #8
def run_main():
    sct.init_sct()
    parser = get_parser()
    args = sys.argv[1:]
    arguments = parser.parse(args)

    # Input filename
    fname_input_data = arguments["-i"]
    fname_data = os.path.abspath(fname_input_data)

    # Method used
    method = 'optic'
    if "-method" in arguments:
        method = arguments["-method"]

    # Contrast type
    contrast_type = ''
    if "-c" in arguments:
        contrast_type = arguments["-c"]
    if method == 'optic' and not contrast_type:
        # Contrast type must be specified when using the OptiC method
        error = 'ERROR: -c is a mandatory argument when using Optic method.'
        sct.printv(error, type='error')
        return

    # Gap between slices
    interslice_gap = 10.0
    if "-gap" in arguments:
        interslice_gap = float(arguments["-gap"])

    # Output folder
    if "-ofolder" in arguments:
        folder_output = arguments["-ofolder"]
    else:
        folder_output = '.'

    # Remove temporary files
    remove_temp_files = True
    if "-r" in arguments:
        remove_temp_files = bool(int(arguments["-r"]))

    # Outputs a ROI file
    output_roi = False
    if "-roi" in arguments:
        output_roi = bool(int(arguments["-roi"]))

    # Verbosity
    verbose = 0
    if "-v" in arguments:
        verbose = int(arguments["-v"])

    if method == 'viewer':
        fname_labels_viewer = _call_viewer_centerline(
            fname_in=fname_data, interslice_gap=interslice_gap)
        centerline_filename = extract_centerline(
            fname_labels_viewer,
            remove_temp_files=remove_temp_files,
            verbose=verbose,
            algo_fitting='nurbs',
            nurbs_pts_number=8000)

    else:
        # condition on verbose when using OptiC
        if verbose == 1:
            verbose = 2

        # OptiC models
        path_script = os.path.dirname(__file__)
        path_sct = os.path.dirname(path_script)
        optic_models_path = os.path.join(path_sct, 'data', 'optic_models',
                                         '{}_model'.format(contrast_type))

        # Execute OptiC binary
        _, centerline_filename = optic.detect_centerline(
            image_fname=fname_data,
            contrast_type=contrast_type,
            optic_models_path=optic_models_path,
            folder_output=folder_output,
            remove_temp_files=remove_temp_files,
            output_roi=output_roi,
            verbose=verbose)

    sct.display_viewer_syntax([fname_input_data, centerline_filename],
                              colormaps=['gray', 'red'],
                              opacities=['', '1'])
Example #9
def get_centerline(im_seg, param=ParamCenterline(), verbose=1):
    """
    Extract centerline from an image (using optic) or from a binary or weighted segmentation (using the center of mass).

    :param im_seg: Image(): Input segmentation or series of points along the centerline.
    :param param: ParamCenterline() class:
    :param verbose: int: verbose level
    :return: im_centerline: Image: Centerline in discrete coordinates (int)
    :return: arr_centerline: 3xn array: Centerline in continuous coordinates (float) for each slice, in RPI orientation.
    :return: arr_centerline_deriv: 3xn array: Derivatives of the x and y centerline wrt. z for each slice, in RPI orientation.
    :return: fit_results: FitResults class
    """

    if not isinstance(im_seg, Image):
        raise ValueError("Expecting an image")
    # Open image and change to RPI orientation
    native_orientation = im_seg.orientation
    im_seg.change_orientation('RPI')
    px, py, pz = im_seg.dim[4:7]

    # Take the center of mass at each slice to avoid: https://stackoverflow.com/questions/2009379/interpolate-question
    x_mean, y_mean, z_mean = find_and_sort_coord(im_seg)

    # Crop output centerline to where the segmentation starts/ends
    if param.minmax:
        z_ref = np.array(range(z_mean.min().astype(int), z_mean.max().astype(int) + 1))
    else:
        z_ref = np.array(range(im_seg.dim[2]))
    index_mean = np.array([list(z_ref).index(i) for i in z_mean])
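    # index_mean maps each z in z_mean to its position within z_ref; it is used further down to compare the fitted
    # curve against the original center-of-mass points when computing the RMSE.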

    # Choose method
    if param.algo_fitting == 'polyfit':
        x_centerline_fit, x_centerline_deriv = curve_fitting.polyfit_1d(z_mean, x_mean, z_ref, deg=param.degree)
        y_centerline_fit, y_centerline_deriv = curve_fitting.polyfit_1d(z_mean, y_mean, z_ref, deg=param.degree)
        fig_title = 'Algo={}, Deg={}'.format(param.algo_fitting, param.degree)

    elif param.algo_fitting == 'bspline':
        x_centerline_fit, x_centerline_deriv = curve_fitting.bspline(z_mean, x_mean, z_ref, param.smooth, pz=pz)
        y_centerline_fit, y_centerline_deriv = curve_fitting.bspline(z_mean, y_mean, z_ref, param.smooth, pz=pz)
        fig_title = 'Algo={}, Smooth={}'.format(param.algo_fitting, param.smooth)

    elif param.algo_fitting == 'linear':
        # Simple linear interpolation
        x_centerline_fit, x_centerline_deriv = curve_fitting.linear(z_mean, x_mean, z_ref, param.smooth, pz=pz)
        y_centerline_fit, y_centerline_deriv = curve_fitting.linear(z_mean, y_mean, z_ref, param.smooth, pz=pz)
        fig_title = 'Algo={}, Smooth={}'.format(param.algo_fitting, param.smooth)

    elif param.algo_fitting == 'nurbs':
        from spinalcordtoolbox.centerline.nurbs import b_spline_nurbs
        point_number = 3000
        # Interpolate such that the output centerline has the same length as z_ref
        x_mean_interp, _ = curve_fitting.linear(z_mean, x_mean, z_ref, 0)
        y_mean_interp, _ = curve_fitting.linear(z_mean, y_mean, z_ref, 0)
        x_centerline_fit, y_centerline_fit, z_centerline_fit, x_centerline_deriv, y_centerline_deriv, \
            z_centerline_deriv, error = b_spline_nurbs(x_mean_interp, y_mean_interp, z_ref, nbControl=None,
                                                       point_number=point_number, all_slices=True)
        # Normalize by the z-derivative so the outputs represent dX/dZ and dY/dZ (chain rule)
        x_centerline_deriv = x_centerline_deriv / z_centerline_deriv
        y_centerline_deriv = y_centerline_deriv / z_centerline_deriv
        fig_title = 'Algo={}, NumberPoints={}'.format(param.algo_fitting, point_number)

    elif param.algo_fitting == 'optic':
        # This method differs from the previous ones: the centerline is estimated from the image itself (not from
        # the segmentation), so the fitting procedure and centerline creation are bypassed and the results are
        # output directly.
        from spinalcordtoolbox.centerline import optic
        assert param.contrast is not None
        im_centerline = optic.detect_centerline(im_seg, param.contrast, verbose)
        x_centerline_fit, y_centerline_fit, z_centerline = find_and_sort_coord(im_centerline)
        # Compute derivatives using polynomial fit
        # TODO: Fix below with reorientation of axes
        _, x_centerline_deriv = curve_fitting.polyfit_1d(z_centerline, x_centerline_fit, z_centerline, deg=param.degree)
        _, y_centerline_deriv = curve_fitting.polyfit_1d(z_centerline, y_centerline_fit, z_centerline, deg=param.degree)
        return \
            im_centerline.change_orientation(native_orientation), \
            np.array([x_centerline_fit, y_centerline_fit, z_centerline]), \
            np.array([x_centerline_deriv, y_centerline_deriv, np.ones_like(z_centerline)]), \
            None
    else:
        logger.error('algo_fitting "' + param.algo_fitting + '" does not exist.')
        raise ValueError

    # Create an image with the centerline
    im_centerline = im_seg.copy()
    im_centerline.data = np.zeros(im_centerline.data.shape)
    # Assign value=1 to centerline. Make sure to clip to avoid array overflow.
    # TODO: check this round and clip-- suspicious
    im_centerline.data[round_and_clip(x_centerline_fit, clip=[0, im_centerline.data.shape[0]]),
                       round_and_clip(y_centerline_fit, clip=[0, im_centerline.data.shape[1]]),
                       z_ref] = 1
    # reorient centerline to native orientation
    im_centerline.change_orientation(native_orientation)
    im_seg.change_orientation(native_orientation)
    # TODO: Reorient centerline in native orientation. For now, we output the array in RPI. Note that it is tricky to
    #   reorient in native orientation, because the voxel center is not in the middle, but in the top corner, so this
    #   needs to be taken into account during reorientation. The code below does not work properly.
    # # Get a permutation and inversion based on native orientation
    # perm, inversion = _get_permutations(im_seg.orientation, native_orientation)
    # # axes inversion (flip)
    # # ctl = np.array([x_centerline_fit[::inversion[0]], y_centerline_fit[::inversion[1]], z_ref[::inversion[2]]])
    # ctl = np.array([x_centerline_fit, y_centerline_fit, z_ref])
    # ctl_deriv = np.array([x_centerline_deriv[::inversion[0]], y_centerline_deriv[::inversion[1]], np.ones_like(z_ref)])
    # return im_centerline, \
    #        np.array([ctl[perm[0]], ctl[perm[1]], ctl[perm[2]]]), \
    #        np.array([ctl_deriv[perm[0]], ctl_deriv[perm[1]], ctl_deriv[perm[2]]])

    # Compute fitting metrics
    fit_results = FitResults()
    fit_results.rmse = np.sqrt(np.mean((x_mean - x_centerline_fit[index_mean]) ** 2) * px +
                               np.mean((y_mean - y_centerline_fit[index_mean]) ** 2) * py)
    fit_results.laplacian_max = np.max([
        np.absolute(np.gradient(np.array(x_centerline_deriv * px))).max(),
        np.absolute(np.gradient(np.array(y_centerline_deriv * py))).max()])
    fit_results.data.zmean = z_mean
    fit_results.data.zref = z_ref
    fit_results.data.xmean = x_mean
    fit_results.data.xfit = x_centerline_fit
    fit_results.data.ymean = y_mean
    fit_results.data.yfit = y_centerline_fit
    fit_results.param = param

    # Display fig of fitted curves
    if verbose == 2:
        from datetime import datetime
        import matplotlib
        matplotlib.use('Agg')  # prevent display figure
        import matplotlib.pyplot as plt
        plt.figure(figsize=(16, 10))
        plt.subplot(3, 1, 1)
        plt.title(fig_title + '\nRMSE[mm]={:0.2f}, LaplacianMax={:0.2f}'.format(fit_results.rmse, fit_results.laplacian_max))
        plt.plot(z_mean * pz, x_mean * px, 'ro')
        plt.plot(z_ref * pz, x_centerline_fit * px, 'k')
        plt.plot(z_ref * pz, x_centerline_fit * px, 'k.')
        plt.ylabel("X [mm]")
        plt.legend(['Reference', 'Fitting', 'Fitting points'])

        plt.subplot(3, 1, 2)
        plt.plot(z_mean * pz, y_mean * py, 'ro')
        plt.plot(z_ref * pz, y_centerline_fit * py, 'b')
        plt.plot(z_ref * pz, y_centerline_fit * py, 'b.')
        plt.xlabel("Z [mm]")
        plt.ylabel("Y [mm]")
        plt.legend(['Reference', 'Fitting', 'Fitting points'])

        plt.subplot(3, 1, 3)
        plt.plot(z_ref * pz, x_centerline_deriv * px, 'k.')
        plt.plot(z_ref * pz, y_centerline_deriv * py, 'b.')
        plt.grid(axis='y', color='grey', linestyle=':', linewidth=1)
        plt.axhline(color='grey', linestyle='-', linewidth=1)
        # plt.plot(z_ref * pz, z_centerline_deriv * pz, 'r.')
        plt.ylabel("dX/dZ, dY/dZ")
        plt.xlabel("Z [mm]")
        plt.legend(['X-deriv', 'Y-deriv'])

        plt.savefig('fig_centerline_' + datetime.now().strftime("%y%m%d-%H%M%S%f") + '_' + param.algo_fitting + '.png')
        plt.close()

    return im_centerline, \
           np.array([x_centerline_fit, y_centerline_fit, z_ref]), \
           np.array([x_centerline_deriv, y_centerline_deriv, np.ones_like(z_ref)]), \
           fit_results
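A minimal usage sketch for the function above. The import paths and the ParamCenterline constructor keywords are not shown in this listing and are assumptions here; the file names are hypothetical.

from spinalcordtoolbox.image import Image                                       # assumed import path
from spinalcordtoolbox.centerline.core import ParamCenterline, get_centerline   # assumed import path

im_seg = Image('t2_seg.nii.gz')                              # hypothetical binary cord segmentation
param = ParamCenterline(algo_fitting='bspline', smooth=30)   # assumed constructor keywords
im_ctl, arr_ctl, arr_ctl_deriv, fit_results = get_centerline(im_seg, param=param, verbose=1)
im_ctl.save('t2_centerline.nii.gz')                          # discrete centerline mask, native orientation
print('RMSE [mm]: {:.2f}'.format(fit_results.rmse))          # fitting metric computed above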
Example No. 10
def get_centerline(im_seg, algo_fitting='polyfit', minmax=True, contrast=None, degree=5, smooth=10, verbose=1):
    """
    Extract centerline from an image (using optic) or from a binary or weighted segmentation (using the center of mass).
    :param im_seg: Image(): Input segmentation or series of points along the centerline.
    :param algo_fitting: str:
        polyfit: Polynomial fitting
        nurbs:
        optic: Automatic segmentation using SVM and HOG. See [Gros et al. MIA 2018].
    :param minmax: Crop output centerline where the segmentation starts/end. If False, centerline will span all slices.
    :param contrast: Contrast type for algo=optic.
    :param degree: int: Max degree for polynomial fitting.
    :param smooth: int: Smoothing factor for bspline fitting. 1: none, 10: moderate smoothing, 100: large smoothing
    :param verbose: int: verbose level
    :return: im_centerline: Image: Centerline in discrete coordinate (int)
    :return: arr_centerline: 3x1 array: Centerline in continuous coordinate (float) for each slice in RPI orientation.
    :return: arr_centerline_deriv: 3x1 array: Derivatives of x and y centerline wrt. z for each slice in RPI orient.
    """

    if not isinstance(im_seg, Image):
        raise ValueError("Expecting an image")
    # Open image and change to RPI orientation
    native_orientation = im_seg.orientation
    im_seg.change_orientation('RPI')
    px, py, pz = im_seg.dim[4:7]

    # Take the center of mass at each slice to avoid: https://stackoverflow.com/questions/2009379/interpolate-question
    x_mean, y_mean, z_mean = find_and_sort_coord(im_seg)

    # Crop output centerline to where the segmentation starts/ends
    if minmax:
        z_ref = np.array(range(z_mean.min().astype(int), z_mean.max().astype(int) + 1))
    else:
        z_ref = np.array(range(im_seg.dim[2]))

    # Choose method
    if algo_fitting == 'polyfit':
        x_centerline_fit, x_centerline_deriv = curve_fitting.polyfit_1d(z_mean, x_mean, z_ref, deg=degree)
        y_centerline_fit, y_centerline_deriv = curve_fitting.polyfit_1d(z_mean, y_mean, z_ref, deg=degree)

    elif algo_fitting == 'bspline':
        x_centerline_fit, x_centerline_deriv = curve_fitting.bspline(z_mean, x_mean, z_ref, smooth=smooth)
        y_centerline_fit, y_centerline_deriv = curve_fitting.bspline(z_mean, y_mean, z_ref, smooth=smooth)

    elif algo_fitting == 'linear':
        # Simple linear interpolation
        x_centerline_fit = curve_fitting.linear(z_mean, x_mean, z_ref)
        y_centerline_fit = curve_fitting.linear(z_mean, y_mean, z_ref)
        # Compute derivatives with a polynomial fit, because derivatives are not defined for linear interpolation
        _, x_centerline_deriv = curve_fitting.polyfit_1d(z_mean, x_mean, z_ref, deg=degree)
        _, y_centerline_deriv = curve_fitting.polyfit_1d(z_mean, y_mean, z_ref, deg=degree)

    elif algo_fitting == 'nurbs':
        from spinalcordtoolbox.centerline.nurbs import b_spline_nurbs
        # Interpolate such that the output centerline has the same length as z_ref
        x_mean_interp = curve_fitting.linear(z_mean, x_mean, z_ref)
        y_mean_interp = curve_fitting.linear(z_mean, y_mean, z_ref)
        x_centerline_fit, y_centerline_fit, z_centerline_fit, x_centerline_deriv, y_centerline_deriv, \
            z_centerline_deriv, error = b_spline_nurbs(x_mean_interp, y_mean_interp, z_ref, nbControl=None, point_number=3000,
                                                       all_slices=True)

    elif algo_fitting == 'optic':
        # This method differs from the previous ones: the centerline is estimated from the image itself (not from
        # the segmentation), so the fitting procedure and centerline creation are bypassed and the results are
        # output directly.
        from spinalcordtoolbox.centerline import optic
        im_centerline = optic.detect_centerline(im_seg, contrast, verbose)
        x_centerline_fit, y_centerline_fit, z_centerline = find_and_sort_coord(im_centerline)
        # Compute derivatives using polynomial fit
        # TODO: Fix below with reorientation of axes
        _, x_centerline_deriv = curve_fitting.polyfit_1d(z_centerline, x_centerline_fit, z_centerline, deg=degree)
        _, y_centerline_deriv = curve_fitting.polyfit_1d(z_centerline, y_centerline_fit, z_centerline, deg=degree)
        return im_centerline.change_orientation(native_orientation), \
               np.array([x_centerline_fit, y_centerline_fit, z_centerline]), \
               np.array([x_centerline_deriv, y_centerline_deriv, np.ones_like(z_centerline)]),

    else:
        logging.error('algo_fitting "' + algo_fitting + '" does not exist.')
        raise ValueError

    # Display fig of fitted curves
    if verbose == 2:
        from datetime import datetime
        import matplotlib
        matplotlib.use('Agg')  # prevent display figure
        import matplotlib.pyplot as plt
        plt.figure()
        plt.subplot(2, 1, 1)
        plt.title("Algo=%s, Deg=%s" % (algo_fitting, degree))
        plt.plot(z_ref * pz, x_centerline_fit * px)
        plt.plot(z_ref * pz, x_centerline_fit * px, 'b.')
        plt.plot(z_mean * pz, x_mean * px, 'ro')
        plt.ylabel("X [mm]")
        plt.subplot(2, 1, 2)
        plt.plot(z_ref, y_centerline_fit)
        plt.plot(z_ref, y_centerline_fit, 'b.')
        plt.plot(z_mean, y_mean, 'ro')
        plt.xlabel("Z [mm]")
        plt.ylabel("Y [mm]")
        plt.savefig('fig_centerline_' + datetime.now().strftime("%y%m%d%H%M%S%f") + '_' + algo_fitting + '.png')
        plt.close()

    # Create an image with the centerline
    im_centerline = im_seg.copy()
    im_centerline.data = np.zeros(im_centerline.data.shape)
    # Assign value=1 to centerline. Make sure to clip to avoid array overflow.
    # TODO: check this round and clip-- suspicious
    im_centerline.data[round_and_clip(x_centerline_fit, clip=[0, im_centerline.data.shape[0]]),
                       round_and_clip(y_centerline_fit, clip=[0, im_centerline.data.shape[1]]),
                       z_ref] = 1
    # reorient centerline to native orientation
    im_centerline.change_orientation(native_orientation)
    # TODO: Reorient centerline in native orientation. For now, we output the array in RPI. Note that it is tricky to
    #   reorient in native orientation, because the voxel center is not in the middle, but in the top corner, so this
    #   needs to be taken into account during reorientation. The code below does not work properly.
    # # Get a permutation and inversion based on native orientation
    # perm, inversion = _get_permutations(im_seg.orientation, native_orientation)
    # # axes inversion (flip)
    # # ctl = np.array([x_centerline_fit[::inversion[0]], y_centerline_fit[::inversion[1]], z_ref[::inversion[2]]])
    # ctl = np.array([x_centerline_fit, y_centerline_fit, z_ref])
    # ctl_deriv = np.array([x_centerline_deriv[::inversion[0]], y_centerline_deriv[::inversion[1]], np.ones_like(z_ref)])
    # return im_centerline, \
    #        np.array([ctl[perm[0]], ctl[perm[1]], ctl[perm[2]]]), \
    #        np.array([ctl_deriv[perm[0]], ctl_deriv[perm[1]], ctl_deriv[perm[2]]])
    return im_centerline, \
           np.array([x_centerline_fit, y_centerline_fit, z_ref]), \
           np.array([x_centerline_deriv, y_centerline_deriv, np.ones_like(z_ref)])
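A usage sketch for this older keyword-argument signature, with the same assumed import paths as in the previous sketch; the file name is hypothetical.

from spinalcordtoolbox.image import Image                     # assumed import path
from spinalcordtoolbox.centerline.core import get_centerline  # assumed import path

im_seg = Image('t2_seg.nii.gz')                               # hypothetical binary cord segmentation
im_ctl, arr_ctl, arr_ctl_deriv = get_centerline(im_seg, algo_fitting='polyfit', minmax=True, degree=3)
x_fit, y_fit, z_ref = arr_ctl                                 # continuous centerline per slice, in RPI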
Example No. 11
def propseg(img_input, options_dict):
    """
    :param img_input: source image, to be segmented
    :param options_dict: arguments as dictionary
    :return: segmented Image
    """
    arguments = options_dict
    fname_input_data = img_input.absolutepath
    fname_data = os.path.abspath(fname_input_data)
    contrast_type = arguments["-c"]
    contrast_type_conversion = {'t1': 't1', 't2': 't2', 't2s': 't2', 'dwi': 't1'}
    contrast_type_propseg = contrast_type_conversion[contrast_type]

    # Starting building the command
    cmd = ['isct_propseg', '-t', contrast_type_propseg]

    if "-ofolder" in arguments:
        folder_output = arguments["-ofolder"]
    else:
        folder_output = './'
    cmd += ['-o', folder_output]
    if not os.path.isdir(folder_output) and os.path.exists(folder_output):
        logger.error("output directory %s is not a valid directory" % folder_output)
    if not os.path.exists(folder_output):
        os.makedirs(folder_output)

    if "-down" in arguments:
        cmd += ["-down", str(arguments["-down"])]
    if "-up" in arguments:
        cmd += ["-up", str(arguments["-up"])]

    remove_temp_files = 1
    if "-r" in arguments:
        remove_temp_files = int(arguments["-r"])

    verbose = int(arguments.get('-v'))
    sct.init_sct(log_level=verbose, update=True)  # Update log level
    # Update for propseg binary
    if verbose > 0:
        cmd += ["-verbose"]

    # Output options
    if "-mesh" in arguments:
        cmd += ["-mesh"]
    if "-centerline-binary" in arguments:
        cmd += ["-centerline-binary"]
    if "-CSF" in arguments:
        cmd += ["-CSF"]
    if "-centerline-coord" in arguments:
        cmd += ["-centerline-coord"]
    if "-cross" in arguments:
        cmd += ["-cross"]
    if "-init-tube" in arguments:
        cmd += ["-init-tube"]
    if "-low-resolution-mesh" in arguments:
        cmd += ["-low-resolution-mesh"]
    if "-detect-nii" in arguments:
        cmd += ["-detect-nii"]
    if "-detect-png" in arguments:
        cmd += ["-detect-png"]

    # Helping options
    use_viewer = None
    use_optic = True  # enabled by default
    init_option = None
    rescale_header = arguments["-rescale"]
    if "-init" in arguments:
        init_option = float(arguments["-init"])
        if init_option < 0:
            sct.printv('Command-line usage error: ' + str(init_option) + " is not a valid value for '-init'", 1, 'error')
            sys.exit(1)
    if "-init-centerline" in arguments:
        if str(arguments["-init-centerline"]) == "viewer":
            use_viewer = "centerline"
        elif str(arguments["-init-centerline"]) == "hough":
            use_optic = False
        else:
            if rescale_header != 1:
                fname_labels_viewer = func_rescale_header(str(arguments["-init-centerline"]), rescale_header, verbose=verbose)
            else:
                fname_labels_viewer = str(arguments["-init-centerline"])
            cmd += ["-init-centerline", fname_labels_viewer]
            use_optic = False
    if "-init-mask" in arguments:
        if str(arguments["-init-mask"]) == "viewer":
            use_viewer = "mask"
        else:
            if rescale_header != 1:
                fname_labels_viewer = func_rescale_header(str(arguments["-init-mask"]), rescale_header)
            else:
                fname_labels_viewer = str(arguments["-init-mask"])
            cmd += ["-init-mask", fname_labels_viewer]
            use_optic = False
    if "-mask-correction" in arguments:
        cmd += ["-mask-correction", str(arguments["-mask-correction"])]
    if "-radius" in arguments:
        cmd += ["-radius", str(arguments["-radius"])]
    if "-detect-n" in arguments:
        cmd += ["-detect-n", str(arguments["-detect-n"])]
    if "-detect-gap" in arguments:
        cmd += ["-detect-gap", str(arguments["-detect-gap"])]
    if "-init-validation" in arguments:
        cmd += ["-init-validation"]
    if "-nbiter" in arguments:
        cmd += ["-nbiter", str(arguments["-nbiter"])]
    if "-max-area" in arguments:
        cmd += ["-max-area", str(arguments["-max-area"])]
    if "-max-deformation" in arguments:
        cmd += ["-max-deformation", str(arguments["-max-deformation"])]
    if "-min-contrast" in arguments:
        cmd += ["-min-contrast", str(arguments["-min-contrast"])]
    if "-d" in arguments:
        cmd += ["-d", str(arguments["-d"])]
    if "-distance-search" in arguments:
        cmd += ["-dsearch", str(arguments["-distance-search"])]
    if "-alpha" in arguments:
        cmd += ["-alpha", str(arguments["-alpha"])]

    # check that the input image is 3D; otherwise the ITK image reader would split the 4D image into 3D volumes and only take the first one.
    image_input = Image(fname_data)
    image_input_rpi = image_input.copy().change_orientation('RPI')
    nx, ny, nz, nt, px, py, pz, pt = image_input_rpi.dim
    if nt > 1:
        sct.printv('ERROR: your input image needs to be 3D in order to be segmented.', 1, 'error')

    path_data, file_data, ext_data = sct.extract_fname(fname_data)
    path_tmp = sct.tmp_create(basename="label_vertebrae", verbose=verbose)

    # rescale header (see issue #1406)
    if rescale_header != 1:
        fname_data_propseg = func_rescale_header(fname_data, rescale_header)
    else:
        fname_data_propseg = fname_data

    # add to command
    cmd += ['-i', fname_data_propseg]

    # if centerline or mask is asked using viewer
    if use_viewer:
        from spinalcordtoolbox.gui.base import AnatomicalParams
        from spinalcordtoolbox.gui.centerline import launch_centerline_dialog

        params = AnatomicalParams()
        if use_viewer == 'mask':
            params.num_points = 3
            params.interval_in_mm = 15  # superior-inferior interval between two consecutive labels
            params.starting_slice = 'midfovminusinterval'
        if use_viewer == 'centerline':
            # setting maximum number of points to a reasonable value
            params.num_points = 20
            params.interval_in_mm = 30
            params.starting_slice = 'top'
        im_data = Image(fname_data_propseg)

        im_mask_viewer = msct_image.zeros_like(im_data)
        # im_mask_viewer.absolutepath = sct.add_suffix(fname_data_propseg, '_labels_viewer')
        controller = launch_centerline_dialog(im_data, im_mask_viewer, params)
        fname_labels_viewer = sct.add_suffix(fname_data_propseg, '_labels_viewer')

        if not controller.saved:
            sct.printv('The viewer has been closed before entering all manual points. Please try again.', 1, 'error')
            sys.exit(1)
        # save labels
        controller.as_niftii(fname_labels_viewer)

        # add mask filename to parameters string
        if use_viewer == "centerline":
            cmd += ["-init-centerline", fname_labels_viewer]
        elif use_viewer == "mask":
            cmd += ["-init-mask", fname_labels_viewer]

    # If using OptiC
    elif use_optic:
        image_centerline = optic.detect_centerline(image_input, contrast_type, verbose)
        fname_centerline_optic = os.path.join(path_tmp, 'centerline_optic.nii.gz')
        image_centerline.save(fname_centerline_optic)
        cmd += ["-init-centerline", fname_centerline_optic]

    if init_option is not None:
        if init_option > 1:
            init_option /= (nz - 1)
        cmd += ['-init', str(init_option)]

    # enabling centerline extraction by default (needed by check_and_correct_segmentation() )
    cmd += ['-centerline-binary']

    # run propseg
    status, output = sct.run(cmd, verbose, raise_exception=False, is_sct_binary=True)

    # raise an error if propseg returned a non-zero status
    if status != 0:
        sct.printv('Automatic cord detection failed. Please initialize using -init-centerline or -init-mask (see help)',
                   1, 'error')
        sys.exit(1)

    # build output filename
    fname_seg = os.path.join(folder_output, os.path.basename(sct.add_suffix(fname_data, "_seg")))
    fname_centerline = os.path.join(folder_output, os.path.basename(sct.add_suffix(fname_data, "_centerline")))
    # in case header was rescaled, we need to update the output file names by removing the "_rescaled"
    if rescale_header != 1:
        sct.mv(os.path.join(folder_output, sct.add_suffix(os.path.basename(fname_data_propseg), "_seg")),
                  fname_seg)
        sct.mv(os.path.join(folder_output, sct.add_suffix(os.path.basename(fname_data_propseg), "_centerline")),
                  fname_centerline)
        # if the viewer was used, copy the labelled points to the output folder (they will then be scaled back)
        if use_viewer:
            fname_labels_viewer_new = os.path.join(folder_output, os.path.basename(sct.add_suffix(fname_data,
                                                                                                  "_labels_viewer")))
            sct.copy(fname_labels_viewer, fname_labels_viewer_new)
            # update variable (used later)
            fname_labels_viewer = fname_labels_viewer_new

    # check consistency of segmentation
    if arguments["-correct-seg"] == "1":
        check_and_correct_segmentation(fname_seg, fname_centerline, folder_output=folder_output, threshold_distance=3.0,
                                       remove_temp_files=remove_temp_files, verbose=verbose)

    # copy header from input to segmentation to make sure qform is the same
    sct.printv("Copy header input --> output(s) to make sure qform is the same.", verbose)
    list_fname = [fname_seg, fname_centerline]
    if use_viewer:
        list_fname.append(fname_labels_viewer)
    for fname in list_fname:
        im = Image(fname)
        im.header = image_input.header
        im.save(dtype='int8')  # they are all binary masks hence fine to save as int8

    return Image(fname_seg)
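A call sketch for propseg as defined above. The flag keys follow the same '-flag' convention as the other handlers in this listing; the ones read unconditionally ('-c', '-rescale', '-correct-seg', '-v') must be present. The input file name and the Image import path are assumptions.

from spinalcordtoolbox.image import Image   # assumed import path

img = Image('t2.nii.gz')                    # hypothetical 3D input volume
opts = {
    '-c': 't2',             # contrast; mapped to the isct_propseg '-t' flag above
    '-rescale': 1,          # no header rescaling
    '-correct-seg': '1',    # run check_and_correct_segmentation() on the result
    '-v': '1',              # verbosity
    '-ofolder': './seg',    # optional output folder
}
im_seg = propseg(img, opts)                 # returns the segmentation as an Image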
Example No. 12
def find_centerline(algo, image_fname, path_sct, contrast_type, brain_bool,
                    folder_output, remove_temp_files, centerline_fname):

    if Image(image_fname).dim[2] == 1:  # isct_spine_detect requires nz > 1
        from sct_image import concat_data
        im_concat = concat_data([image_fname, image_fname], dim=2)
        im_concat.save(sct.add_suffix(image_fname, '_concat'))
        image_fname = sct.add_suffix(image_fname, '_concat')
        bool_2d = True
    else:
        bool_2d = False

    if algo == 'svm':
        # run optic on a heatmap computed by a trained SVM+HoG algorithm
        optic_models_fname = os.path.join(path_sct, 'data', 'optic_models',
                                          '{}_model'.format(contrast_type))
        _, centerline_filename = optic.detect_centerline(
            image_fname=image_fname,
            contrast_type=contrast_type,
            optic_models_path=optic_models_fname,
            folder_output=folder_output,
            remove_temp_files=remove_temp_files,
            output_roi=False,
            verbose=0)
    elif algo == 'cnn':
        # CNN parameters
        dct_patch_ctr = {
            't2': {
                'size': (80, 80),
                'mean': 51.1417,
                'std': 57.4408
            },
            't2s': {
                'size': (80, 80),
                'mean': 68.8591,
                'std': 71.4659
            },
            't1': {
                'size': (80, 80),
                'mean': 55.7359,
                'std': 64.3149
            },
            'dwi': {
                'size': (80, 80),
                'mean': 55.744,
                'std': 45.003
            }
        }
        dct_params_ctr = {
            't2': {
                'features': 16,
                'dilation_layers': 2
            },
            't2s': {
                'features': 8,
                'dilation_layers': 3
            },
            't1': {
                'features': 24,
                'dilation_layers': 3
            },
            'dwi': {
                'features': 8,
                'dilation_layers': 2
            }
        }

        # load model
        ctr_model_fname = os.path.join(path_sct, 'data', 'deepseg_sc_models',
                                       '{}_ctr.h5'.format(contrast_type))
        ctr_model = nn_architecture_ctr(
            height=dct_patch_ctr[contrast_type]['size'][0],
            width=dct_patch_ctr[contrast_type]['size'][1],
            channels=1,
            classes=1,
            features=dct_params_ctr[contrast_type]['features'],
            depth=2,
            temperature=1.0,
            padding='same',
            batchnorm=True,
            dropout=0.0,
            dilation_layers=dct_params_ctr[contrast_type]['dilation_layers'])
        ctr_model.load_weights(ctr_model_fname)

        sct.log.info("Resample the image to 0.5 mm isotropic resolution...")
        fname_res = sct.add_suffix(image_fname, '_resampled')
        input_resolution = Image(image_fname).dim[4:7]
        new_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])

        spinalcordtoolbox.resample.nipy_resample.resample_file(image_fname,
                                                               fname_res,
                                                               new_resolution,
                                                               'mm',
                                                               'linear',
                                                               verbose=0)

        # compute the heatmap
        fname_heatmap = sct.add_suffix(image_fname, "_heatmap")
        img_filename = ''.join(sct.extract_fname(fname_heatmap)[:2])
        fname_heatmap_nii = img_filename + '.nii'
        z_max = heatmap(filename_in=fname_res,
                        filename_out=fname_heatmap_nii,
                        model=ctr_model,
                        patch_shape=dct_patch_ctr[contrast_type]['size'],
                        mean_train=dct_patch_ctr[contrast_type]['mean'],
                        std_train=dct_patch_ctr[contrast_type]['std'],
                        brain_bool=brain_bool)

        # run optic on the heatmap
        centerline_filename = sct.add_suffix(fname_heatmap, "_ctr")
        heatmap2optic(fname_heatmap=fname_heatmap_nii,
                      lambda_value=7 if contrast_type == 't2s' else 1,
                      fname_out=centerline_filename,
                      z_max=z_max if brain_bool else None)

    elif algo == 'viewer':
        centerline_filename = sct.add_suffix(image_fname, "_ctr")
        fname_labels_viewer = _call_viewer_centerline(fname_in=image_fname)
        centerline_filename = extract_centerline(fname_labels_viewer,
                                                 remove_temp_files=True,
                                                 algo_fitting='nurbs',
                                                 nurbs_pts_number=8000)
    elif algo == 'manual':
        centerline_filename = sct.add_suffix(image_fname, "_ctr")
        image_manual_centerline = Image(centerline_fname)
        # Re-orient and Re-sample the manual centerline
        msct_image.change_orientation(image_manual_centerline,
                                      'RPI').save(centerline_filename)
    else:
        sct.log.error(
            'The parameter "-centerline" is incorrect. Please try again.')
        sys.exit(1)

    if algo != 'cnn':
        sct.log.info("Resample the image to 0.5 mm isotropic resolution...")
        fname_res = sct.add_suffix(image_fname, '_resampled')
        input_resolution = Image(image_fname).dim[4:7]
        new_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])

        spinalcordtoolbox.resample.nipy_resample.resample_file(image_fname,
                                                               fname_res,
                                                               new_resolution,
                                                               'mm',
                                                               'linear',
                                                               verbose=0)

        spinalcordtoolbox.resample.nipy_resample.resample_file(
            centerline_filename,
            centerline_filename,
            new_resolution,
            'mm',
            'linear',
            verbose=0)

    if bool_2d:
        from sct_image import split_data
        im_split_lst = split_data(Image(centerline_filename), dim=2)
        im_split_lst[0].save(centerline_filename)

    return fname_res, centerline_filename
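A call sketch for find_centerline with the SVM-based detector. The input file name is hypothetical, and obtaining path_sct from an SCT_DIR environment variable is an assumption (the listing only shows that the OptiC models are expected under path_sct/data/optic_models).

import os

fname_res, fname_ctr = find_centerline(
    algo='svm',                               # one of 'svm', 'cnn', 'viewer', 'manual'
    image_fname='t2.nii.gz',                  # hypothetical input image
    path_sct=os.environ.get('SCT_DIR', '.'),  # assumed: SCT installation path taken from the environment
    contrast_type='t2',                       # selects the OptiC model '{contrast}_model'
    brain_bool=True,                          # only used by the CNN heatmap step
    folder_output='.',
    remove_temp_files=1,
    centerline_fname=None,                    # only used when algo='manual'
)
# fname_res: image resampled to 0.5 x 0.5 mm in-plane; fname_ctr: the detected centerline file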