예제 #1
0
 def _resample_slicewise(self, image, p_resample, type_img, image_ref=None):
     """
     Resample an image so that the cord is displayed at a consistent scale, whatever the native resolution.
     The image is assumed to be in SAL orientation (QC module convention).

     :param image: Image() to resample
     :param p_resample: float: target resolution (mm) of the 2nd and 3rd axes; the 1st axis keeps its native
         spacing. Ignored when image_ref is given.
     :param type_img: {'im', 'seg'}: 'im' is interpolated with spline; 'seg' is interpolated linearly and then
         binarized with a threshold at 0.5.
     :param image_ref: optional destination Image(); when given, image is resampled into its space.
     :return: Image(): the resampled image, in the same orientation as the input image.
     """
     interpolation = {'im': 'spline', 'seg': 'linear'}[type_img]
     nii_in = Nifti1Image(image.data, image.hdr.get_best_affine())
     if image_ref is not None:
         # Resample into the voxel grid of the reference image
         nii_ref = Nifti1Image(image_ref.data, image_ref.hdr.get_best_affine())
         nii_out = resample_nib(nii_in, image_dest=nii_ref, interpolation=interpolation)
     else:
         # Target spacing: px x p_resample x p_resample mm (SAL orientation)
         nii_out = resample_nib(nii_in, new_size=[image.dim[4], p_resample, p_resample], new_size_type='mm',
                                interpolation=interpolation)
     data_out = nii_out.get_data()
     if type_img == 'seg':
         # Linear interpolation produces partial values: binarize at 0.5
         data_out = (data_out > 0.5) * 1
     image_out = Image(data_out, hdr=nii_out.header, dim=nii_out.header.get_data_shape())
     return image_out.change_orientation(image.orientation)
예제 #2
0
def test_nib_resample_image_3d_to_dest(fake_3dimage_nib, fake_3dimage_nib_big):
    """Resample a 3D nibabel image into the space of a (bigger) destination image, with linear interpolation."""
    resampled = resampling.resample_nib(fake_3dimage_nib, img_dest=fake_3dimage_nib_big, interpolation='linear')
    data = resampled.get_data()
    assert data.shape == (29, 39, 19)
    assert data[4, 4, 4] == 1.0
예제 #3
0
def segment_file(input_filename, output_filename, model_name, threshold,
                 verbosity, use_tta):
    """Segment a volume file.

    The volume is resampled to 0.25 x 0.25 mm in-plane (slice thickness unchanged), segmented, resampled back
    to the original voxel size (linear interpolation), optionally re-binarized, and saved.

    :param input_filename: the input filename.
    :param output_filename: the output filename.
    :param model_name: the name of model to use.
    :param threshold: threshold to apply in predictions (if None,
                      no threshold will be applied)
    :param verbosity: the verbosity level (not used by this function).
    :param use_tta: whether it should use TTA (test-time augmentation)
                    or not.
    :return: the output filename.
    """
    nii_original = nib.load(input_filename)
    pixdim = nii_original.header["pixdim"]
    # pixdim[1], pixdim[2], pixdim[3] are used as the voxel size (mm) along x, y, z
    original_res = [pixdim[1], pixdim[2], pixdim[3]]
    target_resample = [0.25, 0.25, pixdim[3]]

    nii_resampled = resampling.resample_nib(nii_original,
                                            new_size=target_resample,
                                            new_size_type='mm',
                                            interpolation='linear')
    pred_slices = segment_volume(nii_resampled, model_name, threshold, use_tta)

    nii_segmentation = nib.Nifti1Image(pred_slices, nii_resampled.affine,
                                       nii_resampled.header)
    nii_resampled_original = resampling.resample_nib(nii_segmentation,
                                                     new_size=original_res,
                                                     new_size_type='mm',
                                                     interpolation='linear')

    # Threshold after resampling, only if specified: linear interpolation re-introduces partial values,
    # so the (already thresholded) prediction is re-binarized at 0.5.
    if threshold is not None:
        res_data = threshold_predictions(nii_resampled_original.get_data(), 0.5)
        # Explicitly build the image from the thresholded data so that it is what gets saved, instead of
        # relying on threshold_predictions() mutating the image's cached array in place.
        nii_resampled_original = nib.Nifti1Image(res_data, nii_resampled_original.affine,
                                                 nii_resampled_original.header)

    nib.save(nii_resampled_original, output_filename)
    return output_filename
예제 #4
0
def test_nib_resample_image_3d(fake_3dimage_nib):
    """Upsample a 3D nibabel image by a factor 2 in-plane, with nearest-neighbour interpolation."""
    resampled = resampling.resample_nib(fake_3dimage_nib, new_size=[2, 2, 1], new_size_type='factor',
                                        interpolation='nn')
    data = resampled.get_data()
    assert data.shape == (18, 18, 9)
    # make sure there is no displacement in world coordinate system
    assert data[8, 8, 4] == 1.0
    assert resampled.header.get_zooms() == (0.5, 0.5, 1.0)
예제 #5
0
 def get_bbox_from_ref(self, img_ref):
     """
     Get bounding box from input reference image, by looking at min/max indices in each dimension.
     img_ref and self.img_in should have the same number of dimensions.

     A copy of img_ref filled with ones is resampled (nearest neighbour, constant mode) into the space of
     self.img_in, and the resulting mask is passed to self.get_bbox_from_mask().

     :param img_ref: reference Image(). Its data is not modified.
     :raises ValueError: if img_ref and self.img_in do not have the same number of dimensions.
     """
     from spinalcordtoolbox.resampling import resample_nib
     # Check that img_ref has the same number of dimensions as img_in
     n_dim_ref = len(img_ref.data.shape)
     n_dim_in = len(self.img_in.data.shape)
     if n_dim_ref != n_dim_in:
         msg = "Inconsistent dimensions: n_dim(img_ref)={}; n_dim(img_in)={}".format(n_dim_ref, n_dim_in)
         logger.error(msg)
         # ValueError subclasses Exception, so callers catching Exception keep working
         raise ValueError(msg)
     # Fill a copy of the reference data with ones, so that the caller's image is left untouched
     img_ref_ones = img_ref.copy()
     img_ref_ones.data[:] = 1
     # Resample new image (in reference coordinates) into input image
     img_ref_r = resample_nib(img_ref_ones, image_dest=self.img_in, interpolation='nn', mode='constant')
     # Get bbox from this resampled mask
     self.get_bbox_from_mask(img_ref_r)
def dummy_segmentation(size_arr=(256, 256, 256),
                       pixdim=(1, 1, 1),
                       dtype=np.float64,
                       orientation='LPI',
                       shape='rectangle',
                       angle_RL=0,
                       angle_AP=0,
                       angle_IS=0,
                       radius_RL=5.0,
                       radius_AP=3.0,
                       degree=2,
                       interleaved=False,
                       zeroslice=None,
                       debug=False):
    """Create a dummy Image with an ellipse or a rectangle of ones running from top to bottom in the 3rd dimension,
    and rotate the image to make sure that compute_csa and compute_shape properly estimate the centerline angle.

    The object is drawn on a 1 mm isotropic grid of int(size_arr * pixdim) voxels, then resampled (linear
    interpolation) to the voxel size `pixdim`.

    :param size_arr: tuple: (nx, ny, nz)
    :param pixdim: tuple: (px, py, pz)
    :param dtype: Numpy dtype. NOTE: currently not applied to the output data.
    :param orientation: Orientation of the image. Default: LPI
    :param shape: {'rectangle', 'ellipse'}
    :param angle_RL: int: angle around RL axis (in deg)
    :param angle_AP: int: angle around AP axis (in deg)
    :param angle_IS: int: angle around IS axis (in deg)
    :param radius_RL: float: 1st radius. With a, b = 50.0, 30.0 (in mm), theoretical CSA of ellipse is 4712.4
    :param radius_AP: float: 2nd radius
    :param degree: int: degree of polynomial fit
    :param interleaved: bool: create a dummy segmentation simulating interleaved acquisition
    :param zeroslice: list of int: indices (3rd dimension) of the slices to set to zero. Default: None (no slice).
    :param debug: Write temp files for debug
    :return: img: Image object
    """
    # Initialization
    padding = 15  # Padding size (isotropic) to avoid edge effect during rotation
    # Create a 3d array, with dimensions corresponding to x: RL, y: AP, z: IS
    nx, ny, nz = [int(size_arr[i] * pixdim[i]) for i in range(3)]
    data = np.zeros((nx, ny, nz))
    xx, yy = np.mgrid[:nx, :ny]

    # Create a dummy segmentation using polynomial function
    # create regularized curve, within Y-Z plane (A-P), located at x=nx/2:
    x = [round(nx / 2.)] * nz
    # and passing through the following points: straight curve (same location of AP across SI).
    # For an oblique curve, use e.g. y = [round(ny / 4.), round(ny / 2.), round(3 * ny / 4.)]
    y = [round(ny / 2.)] * 3
    z = np.array([0, round(nz / 2.), nz - 1])
    # we use poly (instead of bspline) in order to allow change of scalar for each term of polynomial function
    p = np.polynomial.Polynomial.fit(z, y, deg=degree)

    if interleaved:
        # Create two polynomial fits by multiplying each term by a random scalar between 0.5 and 1, then
        # interleave them (one for even slices, the other for odd slices)
        p_even = copy.copy(p)
        p_odd = copy.copy(p)
        p_even.coef = [element * uniform(0.5, 1) for element in p_even.coef]
        p_odd.coef = [element * uniform(0.5, 1) for element in p_odd.coef]
        yfit = np.zeros(nz)
        yfit[0::2] = np.round(p_even(range(nz)))[0::2]
        yfit[1::2] = np.round(p_odd(range(nz)))[1::2]
    else:
        # Single polynomial fit, without modification of the term scalars
        yfit = np.round(p(range(nz)))
    # Values were rounded above for a correct float -> int conversion. Builtin `int` is used because the
    # `np.int` alias was removed in NumPy 1.24.
    yfit = yfit.astype(int)

    # loop across slices and add object
    for iz in range(nz):
        if shape == 'rectangle':  # theoretical CSA: (a*2+1)(b*2+1)
            data[:, :, iz] = ((abs(xx - x[iz]) <= radius_RL) &
                              (abs(yy - yfit[iz]) <= radius_AP)) * 1
        elif shape == 'ellipse':
            data[:, :, iz] = (((xx - x[iz]) / radius_RL) ** 2 +
                              ((yy - yfit[iz]) / radius_AP) ** 2 <= 1) * 1

    # Pad to avoid edge effect during rotation
    data = np.pad(data, padding, 'reflect')

    def _rotate_first_two_axes(arr, angle):
        """Rotate arr by angle (deg) within the plane of its first two axes, re-gridding with linear interpolation."""
        return rotate(arr, angle, resize=False, center=None, order=1, mode='constant', cval=0, clip=False,
                      preserve_range=False)

    # Rotation about IS axis (x-y plane)
    data_rot = _rotate_first_two_axes(data, angle_IS)
    # Rotation about RL axis: swap x-z axes so that the y-z plane is made of the first two dims, then swap back
    data_rot = _rotate_first_two_axes(data_rot.swapaxes(0, 2), angle_RL).swapaxes(0, 2)
    # Rotation about AP axis: swap y-z axes so that the x-z plane is made of the first two dims, then swap back
    data_rot = _rotate_first_two_axes(data_rot.swapaxes(1, 2), angle_AP).swapaxes(1, 2)

    # Crop image (to remove padding)
    data_rot_crop = data_rot[padding:nx + padding, padding:ny + padding,
                             padding:nz + padding]

    # Zero specified slices
    if zeroslice is not None and len(zeroslice) > 0:
        data_rot_crop[:, :, zeroslice] = 0

    # Create nibabel object with an identity affine (1 mm isotropic voxels)
    nii = nib.nifti1.Nifti1Image(data_rot_crop.astype('float32'), np.eye(4))
    # resample to desired resolution
    nii_r = resample_nib(nii,
                         new_size=pixdim,
                         new_size_type='mm',
                         interpolation='linear')
    # Create Image object. Default orientation is LPI.
    # For debugging add .save() at the end of the command below
    img = Image(nii_r.get_data(),
                hdr=nii_r.header,
                dim=nii_r.header.get_data_shape())
    # Update orientation
    img.change_orientation(orientation)
    if debug:
        img.save('tmp_dummy_seg_' + datetime.now().strftime("%Y%m%d%H%M%S%f") +
                 '.nii.gz')
    return img
예제 #7
0
def compute_shape(segmentation,
                  angle_correction=True,
                  param_centerline=None,
                  verbose=1):
    """
    Compute morphometric measures of the spinal cord in the transverse (axial) plane from the segmentation.
    The segmentation could be binary or weighted for partial volume [0,1].

    The segmentation is reoriented to RPI and resampled (linear interpolation) to an isotropic in-plane
    resolution equal to its smallest in-plane pixel dimension before the slice-wise analysis.

    :param segmentation: input segmentation. Could be either an Image or a file name.
    :param angle_correction: if True, estimate the angle between the centerline and the normal to each slice,
        and rescale the slice by the cosine of the RL/AP angles before computing its properties.
    :param param_centerline: see centerline.core.ParamCenterline()
    :param verbose: verbosity level forwarded to get_centerline().
    :return metrics: Dict of class Metric(). If a metric cannot be calculated, its value will be nan.
    :return fit_results: class centerline.core.FitResults(), or None if angle_correction is False.
    :raises ValueError: if the segmentation contains no voxel with a value > 0.
    """
    # List of properties to output (in the right order)
    property_list = [
        'area', 'angle_AP', 'angle_RL', 'diameter_AP', 'diameter_RL',
        'eccentricity', 'orientation', 'solidity', 'length'
    ]

    im_seg = Image(segmentation).change_orientation('RPI')
    # Getting image dimensions. x, y and z respectively correspond to RL, PA and IS.
    nx, ny, nz, nt, px, py, pz, pt = im_seg.dim
    pr = min([px, py])
    # Resample to isotropic resolution in the axial plane. Use the minimum pixel dimension as target dimension.
    im_segr = resample_nib(im_seg,
                           new_size=[pr, pr, pz],
                           new_size_type='mm',
                           interpolation='linear')

    # Update dimensions from resampled image.
    nx, ny, nz, nt, px, py, pz, pt = im_segr.dim

    # Extract min and max index in Z direction
    _, _, z_nonzero = (im_segr.data > 0).nonzero()
    if z_nonzero.size == 0:
        # Explicit message instead of the cryptic "min() arg is an empty sequence"
        raise ValueError("The segmentation is empty (no voxel > 0): cannot compute shape metrics.")
    min_z_index, max_z_index = z_nonzero.min(), z_nonzero.max()

    # Initialize dictionary of property_list, with 1d array of nan (default value if no property for a given slice).
    shape_properties = {key: np.full(nz, np.nan, dtype=np.double) for key in property_list}

    fit_results = None

    if angle_correction:
        # compute the spinal cord centerline based on the spinal cord segmentation
        # here, param_centerline.minmax needs to be False because we need to retrieve the total number of input slices
        # NOTE(review): the loop below indexes arr_ctl_der with (iz - min_z_index), i.e. it assumes the centerline
        # arrays start at the first non-empty slice — TODO confirm against get_centerline()/minmax.
        _, arr_ctl, arr_ctl_der, fit_results = get_centerline(
            im_segr, param=param_centerline, verbose=verbose)

    # Loop across z and compute shape analysis
    for iz in sct_progress_bar(range(min_z_index, max_z_index + 1),
                               unit='iter',
                               unit_scale=False,
                               desc="Compute shape analysis",
                               ascii=True,
                               ncols=80):
        # Extract 2D patch
        current_patch = im_segr.data[:, :, iz]
        if angle_correction:
            # Extract tangent vector to the centerline (i.e. its derivative), scaled by the voxel size
            tangent_vect = np.array([
                arr_ctl_der[0][iz - min_z_index] * px,
                arr_ctl_der[1][iz - min_z_index] * py, pz
            ])
            # Normalize vector by its L2 norm
            tangent_vect = tangent_vect / np.linalg.norm(tangent_vect)
            # Signed angle between a 2D vector and the slice normal [0, 1], via atan2(det, dot).
            # `math.atan2` is used because the `np.math` alias is deprecated/removed in recent NumPy.
            v1 = [0, 1]
            # Compute the angle about AP axis between the centerline and the normal vector to the slice
            v0 = [tangent_vect[0], tangent_vect[2]]
            angle_AP_rad = math.atan2(np.linalg.det([v0, v1]), np.dot(v0, v1))
            # Compute the angle about RL axis between the centerline and the normal vector to the slice
            v0 = [tangent_vect[1], tangent_vect[2]]
            angle_RL_rad = math.atan2(np.linalg.det([v0, v1]), np.dot(v0, v1))
            # Apply affine transformation to account for the angle between the centerline and the normal to the patch
            tform = transform.AffineTransform(scale=(np.cos(angle_RL_rad),
                                                     np.cos(angle_AP_rad)))
            # Convert to float64, to avoid problems in image indexation causing issues when applying transform.warp
            current_patch = current_patch.astype(np.float64)
            # TODO: make sure pattern does not go extend outside of image border
            current_patch_scaled = transform.warp(
                current_patch,
                tform.inverse,
                output_shape=current_patch.shape,
                order=1,
            )
        else:
            current_patch_scaled = current_patch
            angle_AP_rad, angle_RL_rad = 0.0, 0.0
        # compute shape properties on 2D patch
        shape_property = _properties2d(current_patch_scaled, [px, py])
        if shape_property is not None:
            # Add custom fields
            shape_property['angle_AP'] = angle_AP_rad * 180.0 / math.pi
            shape_property['angle_RL'] = angle_RL_rad * 180.0 / math.pi
            shape_property['length'] = pz / (np.cos(angle_AP_rad) *
                                             np.cos(angle_RL_rad))
            # Loop across properties and assign values for function output
            for property_name in property_list:
                shape_properties[property_name][iz] = shape_property[
                    property_name]
        else:
            logging.warning('\nNo properties for slice: {}'.format(iz))

    metrics = {}
    for key, value in shape_properties.items():
        # Making sure all entries added to metrics have results. `value` is an ndarray, so test its size rather
        # than comparing it to [] (which is an elementwise, shape-dependent comparison in NumPy).
        if value.size > 0:
            metrics[key] = Metric(data=np.array(value), label=key)

    return metrics, fit_results
예제 #8
0
def _preprocess_segment(fname_t2, fname_t2_seg, contrast_test, dim_3=False):
    """Load an image and its ground-truth segmentation, and preprocess them as in the deepseg_sc pipeline.

    Both volumes are reoriented to RPI, resampled to 0.5 x 0.5 mm in-plane (keeping the slice thickness of the
    image, linear interpolation) and cropped (crop size 64) around the centerline found with the 'svm' algorithm
    on the image. The image intensity is then normalized. If `dim_3` is True, the cropped volumes are saved,
    passed through resampling.resample_file() at that same 0.5 x 0.5 x pz mm resolution, and reloaded.
    Intermediate files are written in a temporary folder, which is removed before returning (also on error).

    :param fname_t2: path to the input image.
    :param fname_t2_seg: path to the ground-truth segmentation of the input image.
    :param contrast_test: contrast type passed to the centerline detection.
    :param dim_3: set to True when preparing data for 3D kernels.
    :return: tuple (img, gt) of preprocessed Image() objects.
    """
    # Load the inputs before moving into the temporary folder, so that relative paths are resolved against the
    # caller's working directory rather than the (empty) temporary folder.
    img = Image(fname_t2)
    gt = Image(fname_t2_seg)
    img.change_orientation('RPI')
    gt.change_orientation('RPI')
    pz = img.dim[6]
    new_resolution = 'x'.join(['0.5', '0.5', str(pz)])

    tmp_folder = sct.TempFolder()
    tmp_folder_path = tmp_folder.get_path()
    tmp_folder.chdir()
    # Always go back to the original working directory and remove the temporary folder, even if a step fails
    try:
        # Both volumes use the slice thickness of the image
        img_res = resampling.resample_nib(img, new_size=[0.5, 0.5, pz], new_size_type='mm', interpolation='linear')
        gt_res = resampling.resample_nib(gt, new_size=[0.5, 0.5, pz], new_size_type='mm', interpolation='linear')

        fname_t2_RPI = 'img_RPI.nii.gz'
        img_res.save(fname_t2_RPI)

        _, ctr_im, _ = deepseg_sc.find_centerline(algo='svm',
                                                  image_fname=fname_t2_RPI,
                                                  contrast_type=contrast_test,
                                                  brain_bool=False,
                                                  folder_output=tmp_folder_path,
                                                  remove_temp_files=1,
                                                  centerline_fname=None)

        _, _, _, img = deepseg_sc.crop_image_around_centerline(im_in=img_res,
                                                               ctr_in=ctr_im,
                                                               crop_size=64)
        _, _, _, gt = deepseg_sc.crop_image_around_centerline(im_in=gt_res,
                                                              ctr_in=ctr_im,
                                                              crop_size=64)
        del ctr_im

        img = deepseg_sc.apply_intensity_normalization(im_in=img)

        if dim_3:  # If 3D kernels
            fname_t2_RPI_res_crop, fname_t2_seg_RPI_res_crop = 'img_RPI_res_crop.nii.gz', 'seg_RPI_res_crop.nii.gz'
            img.save(fname_t2_RPI_res_crop)
            gt.save(fname_t2_seg_RPI_res_crop)
            del img, gt

            fname_t2_RPI_res_crop_res = 'img_RPI_res_crop_res.nii.gz'
            fname_t2_seg_RPI_res_crop_res = 'seg_RPI_res_crop_res.nii.gz'
            resampling.resample_file(fname_t2_RPI_res_crop,
                                     fname_t2_RPI_res_crop_res,
                                     new_resolution,
                                     'mm',
                                     'linear',
                                     verbose=0)
            resampling.resample_file(fname_t2_seg_RPI_res_crop,
                                     fname_t2_seg_RPI_res_crop_res,
                                     new_resolution,
                                     'mm',
                                     'linear',
                                     verbose=0)
            img, gt = Image(fname_t2_RPI_res_crop_res), Image(fname_t2_seg_RPI_res_crop_res)
    finally:
        tmp_folder.chdir_undo()
        tmp_folder.cleanup()

    return img, gt
예제 #9
0
def deep_segmentation_MSlesion(im_image,
                               contrast_type,
                               ctr_algo='svm',
                               ctr_file=None,
                               brain_bool=True,
                               remove_temp_files=1,
                               verbose=1):
    """
    Segment lesions from MRI data.

    The image is reoriented to RPI, resampled to 0.5 mm in-plane, cropped around the spinal cord centerline,
    intensity-normalized, resampled to 0.5 mm isotropic and segmented with a 3D model. The segmentation is then
    brought back to the input resolution, binarized (threshold 0.1), cast to uint8 and reoriented to the input
    orientation. All intermediate files are written in a temporary folder; the working directory is always
    restored, even if a step fails.

    NOTE: `im_image` itself is reoriented to RPI (change_orientation() is called on it) and is not reoriented back.

    :param im_image: Image() object containing the lesions to segment
    :param contrast_type: Contrast of the image. Need to use one supported by the CNN models.
    :param ctr_algo: Algo to find the centerline. See sct_get_centerline
    :param ctr_file: Centerline or segmentation (optional), used when ctr_algo == 'file'
    :param brain_bool: If brain is present or not in the image.
    :param remove_temp_files: if truthy, delete the temporary folder at the end.
    :param verbose: verbosity level; if 2, the '_ctr' image of the temporary folder is also returned.
    :return: tuple (seg, im_viewer, im_ctr):
        - seg: Image() of the binary (uint8) lesion segmentation, in the orientation of `im_image`
        - im_viewer: Image() of the resampled viewer labels if ctr_algo == 'viewer', else None
        - im_ctr: Image() of the '_ctr' file resampled to the input resolution if verbose == 2, else None
    """

    # create temporary folder with intermediate results
    tmp_folder = sct.TempFolder(verbose=verbose)
    tmp_folder_path = tmp_folder.get_path()
    if ctr_algo == 'file':  # if the ctr_file is provided
        tmp_folder.copy_from(ctr_file)
        file_ctr = os.path.basename(ctr_file)
    else:
        file_ctr = None
    tmp_folder.chdir()
    try:
        # Only the base name is kept: the intermediate files derived from it (see add_suffix() calls below) are
        # written in the temporary folder (cwd) instead of next to the user's input file.
        fname_in = os.path.basename(im_image.absolutepath)

        # re-orient image to RPI
        logger.info("Reorient the image to RPI, if necessary...")
        original_orientation = im_image.orientation
        im_image.change_orientation('RPI')

        input_resolution = im_image.dim[4:7]

        # Resample image to 0.5mm in plane
        im_image_res = resampling.resample_nib(im_image, new_size=[0.5, 0.5, im_image.dim[6]], new_size_type='mm',
                                               interpolation='linear')

        fname_orient = 'image_in_RPI_res.nii'
        im_image_res.save(fname_orient)

        # find the spinal cord centerline - execute OptiC binary
        logger.info("\nFinding the spinal cord centerline...")
        contrast_type_ctr = contrast_type.split('_')[0]
        _, im_ctl, im_labels_viewer = find_centerline(
            algo=ctr_algo,
            image_fname=fname_orient,
            contrast_type=contrast_type_ctr,
            brain_bool=brain_bool,
            folder_output=tmp_folder_path,
            remove_temp_files=remove_temp_files,
            centerline_fname=file_ctr)
        if ctr_algo == 'file':
            im_ctl = resampling.resample_nib(im_ctl, new_size=[0.5, 0.5, im_image.dim[6]], new_size_type='mm',
                                             interpolation='linear')

        # crop image around the spinal cord centerline
        logger.info("\nCropping the image around the spinal cord...")
        crop_size = 48
        X_CROP_LST, Y_CROP_LST, Z_CROP_LST, im_crop_nii = crop_image_around_centerline(
            im_in=im_image_res, ctr_in=im_ctl, crop_size=crop_size)
        del im_ctl

        # normalize the intensity of the images
        logger.info("Normalizing the intensity...")
        im_norm_in = apply_intensity_normalization(img=im_crop_nii,
                                                   contrast=contrast_type)
        del im_crop_nii

        # resample to 0.5mm isotropic
        fname_norm = sct.add_suffix(fname_orient, '_norm')
        im_norm_in.save(fname_norm)
        fname_res3d = sct.add_suffix(fname_norm, '_resampled3d')
        resampling.resample_file(fname_norm,
                                 fname_res3d,
                                 '0.5x0.5x0.5',
                                 'mm',
                                 'linear',
                                 verbose=0)

        # segment data using 3D convolutions
        logger.info(
            "\nSegmenting the MS lesions using deep learning on 3D patches...")
        segmentation_model_fname = sct_dir_local_path(
            'data', 'deepseg_lesion_models', '{}_lesion.h5'.format(contrast_type))
        fname_seg_crop_res = sct.add_suffix(fname_res3d, '_lesionseg')
        im_res3d = Image(fname_res3d)
        seg_im = segment_3d(model_fname=segmentation_model_fname,
                            contrast_type=contrast_type,
                            im=im_res3d.copy())
        seg_im.save(fname_seg_crop_res)
        del im_res3d, seg_im

        # resample to the initial pz resolution
        fname_seg_res2d = sct.add_suffix(fname_seg_crop_res, '_resampled2d')
        initial_2d_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
        resampling.resample_file(fname_seg_crop_res,
                                 fname_seg_res2d,
                                 initial_2d_resolution,
                                 'mm',
                                 'linear',
                                 verbose=0)
        seg_crop = Image(fname_seg_res2d)

        # reconstruct the segmentation from the crop data
        logger.info("\nReassembling the image...")
        seg_uncrop_nii = uncrop_image(ref_in=im_image_res,
                                      data_crop=seg_crop.copy().data,
                                      x_crop_lst=X_CROP_LST,
                                      y_crop_lst=Y_CROP_LST,
                                      z_crop_lst=Z_CROP_LST)
        fname_seg_res_RPI = sct.add_suffix(fname_in, '_res_RPI_seg')
        seg_uncrop_nii.save(fname_seg_res_RPI)
        del seg_crop

        # resample to initial resolution
        logger.info(
            "Resampling the segmentation to the original image resolution...")
        initial_resolution = 'x'.join(str(res) for res in input_resolution)
        fname_seg_RPI = sct.add_suffix(fname_in, '_RPI_seg')
        resampling.resample_file(fname_seg_res_RPI,
                                 fname_seg_RPI,
                                 initial_resolution,
                                 'mm',
                                 'linear',
                                 verbose=0)
        seg_initres_nii = Image(fname_seg_RPI)

        if ctr_algo == 'viewer':  # resample and reorient the viewer labels
            im_labels_viewer_nib = nib.nifti1.Nifti1Image(
                im_labels_viewer.data, im_labels_viewer.hdr.get_best_affine())
            im_viewer_r_nib = resampling.resample_nib(im_labels_viewer_nib,
                                                      new_size=input_resolution,
                                                      new_size_type='mm',
                                                      interpolation='linear')
            im_viewer = Image(
                im_viewer_r_nib.get_data(),
                hdr=im_viewer_r_nib.header,
                orientation='RPI',
                dim=im_viewer_r_nib.header.get_data_shape()).change_orientation(
                    original_orientation)
        else:
            im_viewer = None

        if verbose == 2:
            fname_res_ctr = sct.add_suffix(fname_orient, '_ctr')
            resampling.resample_file(fname_res_ctr,
                                     fname_res_ctr,
                                     initial_resolution,
                                     'mm',
                                     'linear',
                                     verbose=0)
            im_image_res_ctr_downsamp = Image(fname_res_ctr).change_orientation(
                original_orientation)
        else:
            im_image_res_ctr_downsamp = None

        # binarize the resampled image to remove interpolation effects
        logger.info(
            "\nBinarizing the segmentation to avoid interpolation effects...")
        thr = 0.1
        seg_initres_nii.data[np.where(seg_initres_nii.data >= thr)] = 1
        seg_initres_nii.data[np.where(seg_initres_nii.data < thr)] = 0

        # change data type
        seg_initres_nii.change_type(np.uint8)

        logger.info(
            "\nReorienting the segmentation to the original image orientation...")
    finally:
        # Restore the working directory (and clean up if requested) whether or not a step above failed
        tmp_folder.chdir_undo()
        # remove temporary files
        if remove_temp_files:
            logger.info("\nRemove temporary files...")
            tmp_folder.cleanup()

    # reorient to initial orientation (images are in memory, so this does not depend on the temporary folder)
    return seg_initres_nii.change_orientation(
        original_orientation), im_viewer, im_image_res_ctr_downsamp
예제 #10
0
def deep_segmentation_spinalcord(im_image, contrast_type, ctr_algo='cnn', ctr_file=None, brain_bool=True,
                                 kernel_size='2d', remove_temp_files=1, verbose=1):
    """
    CNN-based spinal cord segmentation pipeline.

    Steps: reorient the image to RPI, find the cord centerline, crop around it, normalize intensities, segment the
    crop with a 2D or 3D model, uncrop, resample back to the RPI input voxel size, binarize, post-process and
    reorient to the input orientation. Intermediate files are written to a temporary folder.

    :param im_image: Image() to segment. It is reoriented to RPI (change_orientation) and saved in the temp folder.
    :param contrast_type: contrast name. Selects the model file ('<contrast>_sc.h5' or '<contrast>_sc_3D.h5'), the
        crop size (96 for 3D t2s, 64 otherwise) and the binarization threshold (0.0001 for 't1'/'dwi', else 0.5).
    :param ctr_algo: centerline algorithm passed to find_centerline(). With 'file', ctr_file is copied into the
        temporary folder; with 'viewer', the viewer labels are resampled and returned.
    :param ctr_file: path to a centerline file; only used when ctr_algo == 'file'.
    :param brain_bool: passed to find_centerline() (whether the brain is present in the image).
    :param kernel_size: '2d' or '3d': dimensionality of the model's convolutions.
    :param remove_temp_files: if truthy, delete the temporary folder at the end.
    :param verbose: verbosity passed to the temporary folder.
    :return: tuple (segmentation in the input orientation (uint8),
                    Image loaded from the first path returned by find_centerline(),
                    uncropped segmentation before resampling (RPI),
                    resampled viewer labels in the input orientation or None,
                    None)
    :raises ValueError: if kernel_size is not '2d' or '3d'.
    """
    if kernel_size not in ('2d', '3d'):
        # Fail fast: otherwise seg_crop is never assigned and the error surfaces deep inside the temp folder
        raise ValueError("kernel_size must be '2d' or '3d', got {!r}".format(kernel_size))

    # create temporary folder with intermediate results
    tmp_folder = sct.TempFolder(verbose=verbose)
    tmp_folder_path = tmp_folder.get_path()
    if ctr_algo == 'file':  # if the ctr_file is provided
        tmp_folder.copy_from(ctr_file)
        file_ctr = os.path.basename(ctr_file)
    else:
        file_ctr = None
    tmp_folder.chdir()
    # Always restore the caller's working directory, even if one of the steps below fails
    try:
        # re-orient image to RPI
        logger.info("Reorient the image to RPI, if necessary...")
        original_orientation = im_image.orientation
        fname_orient = 'image_in_RPI.nii'
        im_image.change_orientation('RPI').save(fname_orient)

        # voxel size of the RPI image (dim[4:7]), used to resample the segmentation back at the end
        input_resolution = im_image.dim[4:7]

        # find the spinal cord centerline - execute OptiC binary
        logger.info("Finding the spinal cord centerline...")
        fname_res, centerline_filename, im_labels_viewer = find_centerline(algo=ctr_algo,
                                                                            image_fname=fname_orient,
                                                                            contrast_type=contrast_type,
                                                                            brain_bool=brain_bool,
                                                                            folder_output=tmp_folder_path,
                                                                            remove_temp_files=remove_temp_files,
                                                                            centerline_fname=file_ctr)

        im_nii, ctr_nii = Image(fname_res), Image(centerline_filename)

        # crop image around the spinal cord centerline
        logger.info("Cropping the image around the spinal cord...")
        crop_size = 96 if (kernel_size == '3d' and contrast_type == 't2s') else 64
        X_CROP_LST, Y_CROP_LST, Z_CROP_LST, im_crop_nii = crop_image_around_centerline(im_in=im_nii,
                                                                                       ctr_in=ctr_nii,
                                                                                       crop_size=crop_size)
        del ctr_nii

        # normalize the intensity of the images
        logger.info("Normalizing the intensity...")
        im_norm_in = apply_intensity_normalization(im_in=im_crop_nii)
        del im_crop_nii

        if kernel_size == '2d':
            # segment data using 2D convolutions
            logger.info("Segmenting the spinal cord using deep learning on 2D patches...")
            segmentation_model_fname = \
                os.path.join(sct.__sct_dir__, 'data', 'deepseg_sc_models', '{}_sc.h5'.format(contrast_type))
            seg_crop = segment_2d(model_fname=segmentation_model_fname,
                                  contrast_type=contrast_type,
                                  input_size=(crop_size, crop_size),
                                  im_in=im_norm_in)
        else:  # kernel_size == '3d' (validated above)
            # segment data using 3D convolutions
            logger.info("Segmenting the spinal cord using deep learning on 3D patches...")
            segmentation_model_fname = \
                os.path.join(sct.__sct_dir__, 'data', 'deepseg_sc_models', '{}_sc_3D.h5'.format(contrast_type))
            seg_crop = segment_3d(model_fname=segmentation_model_fname,
                                  contrast_type=contrast_type,
                                  im_in=im_norm_in)
        del im_norm_in

        # reconstruct the segmentation from the crop data
        logger.info("Reassembling the image...")
        im_seg = uncrop_image(ref_in=im_nii,
                              data_crop=seg_crop,
                              x_crop_lst=X_CROP_LST,
                              y_crop_lst=Y_CROP_LST,
                              z_crop_lst=Z_CROP_LST)
        # fname_res_seg = sct.add_suffix(fname_res, '_seg')
        # seg_uncrop_nii.save(fname_res_seg)  # for debugging
        del seg_crop

        # resample to initial resolution
        logger.info("Resampling the segmentation to the native image resolution using linear interpolation...")
        # create nibabel object
        seg_uncrop_nii_nib = nib.nifti1.Nifti1Image(im_seg.data, im_seg.hdr.get_best_affine())
        seg_uncrop_nii_nibr = resampling.resample_nib(seg_uncrop_nii_nib, new_size=input_resolution,
                                                      new_size_type='mm', interpolation='linear')
        # Convert back to Image type
        im_seg_r = Image(seg_uncrop_nii_nibr.get_data(), hdr=seg_uncrop_nii_nibr.header, orientation='RPI',
                         dim=seg_uncrop_nii_nibr.header.get_data_shape())

        if ctr_algo == 'viewer':  # resample and reorient the viewer labels
            im_labels_viewer_nib = nib.nifti1.Nifti1Image(im_labels_viewer.data,
                                                          im_labels_viewer.hdr.get_best_affine())
            im_viewer_r_nib = resampling.resample_nib(im_labels_viewer_nib, new_size=input_resolution,
                                                      new_size_type='mm', interpolation='linear')
            im_viewer = Image(im_viewer_r_nib.get_data(), hdr=im_viewer_r_nib.header, orientation='RPI',
                              dim=im_viewer_r_nib.header.get_data_shape()).change_orientation(original_orientation)
        else:
            im_viewer = None

        # TODO: Deal with that later-- ideally this file should be written when debugging, not with verbose=2
        # if verbose == 2:
        #     fname_res_ctr = sct.add_suffix(fname_orient, '_ctr')
        #     resampling.resample_file(fname_res_ctr, fname_res_ctr, initial_resolution, 'mm', 'linear', verbose=0)
        #     im_image_res_ctr_downsamp = Image(fname_res_ctr).change_orientation(original_orientation)
        # else:
        im_image_res_ctr_downsamp = None

        # Binarize the resampled image to remove interpolation effects
        logger.info("Binarizing the resampled segmentation...")
        thr = 0.0001 if contrast_type in ['t1', 'dwi'] else 0.5
        # Boolean masks give the same result as np.where() index arrays without materializing the indices
        im_seg_r.data[im_seg_r.data >= thr] = 1
        im_seg_r.data[im_seg_r.data < thr] = 0

        # post processing step to z_regularized
        im_seg_r_postproc = post_processing_volume_wise(im_seg_r)

        # change data type
        im_seg_r_postproc.change_type(np.uint8)
    finally:
        tmp_folder.chdir_undo()

    # remove temporary files
    if remove_temp_files:
        logger.info("Remove temporary files...")
        tmp_folder.cleanup()

    # reorient to initial orientation
    return (im_seg_r_postproc.change_orientation(original_orientation),
            im_nii,
            im_seg.change_orientation('RPI'),
            im_viewer,
            im_image_res_ctr_downsamp)
예제 #11
0
def deep_segmentation_spinalcord(im_image, contrast_type, ctr_algo='cnn', ctr_file=None, brain_bool=True,
                                 kernel_size='2d', threshold_seg=None, remove_temp_files=1, verbose=1):
    """
    Main pipeline for CNN-based segmentation of the spinal cord.

    The image is reoriented to RPI, resampled to 0.5 mm in-plane, cropped around the detected centerline,
    intensity-normalized and segmented with a 2D or 3D model. When binarizing, each predicted slice is thresholded,
    hole-filled and reduced to its largest object. The result is then uncropped, resampled back to the RPI input
    grid, binarized, regularized across slices and reoriented to the input orientation. Intermediate files are
    written to a temporary folder.

    :param im_image: Image() to segment. It is reoriented to RPI in place (the result of change_orientation() is
        not reassigned), so the caller's object is left in RPI.
    :param contrast_type: {'t1', 't2', 't2s', 'dwi'}. Selects the model file, the crop size (96 for 3D t2s, 64
        otherwise) and the default threshold THR_DEEPSEG[contrast_type].
    :param ctr_algo: centerline algorithm passed to find_centerline(). With 'file', ctr_file is copied into the
        temporary folder and resampled to the working grid; with 'viewer', the viewer labels are saved in the
        temporary folder for debugging.
    :param ctr_file: path to a centerline file; only used when ctr_algo == 'file'.
    :param brain_bool: whether the brain is in the image; passed to find_centerline().
    :param kernel_size: '2d' or '3d': dimensionality of the model's convolutions.
    :param threshold_seg: Binarization threshold (between 0 and 1) to apply to the segmentation prediction. Set to -1
        for no binarization (i.e. soft segmentation output). Default (None): THR_DEEPSEG[contrast_type].
    :param remove_temp_files: if truthy, delete the temporary folder at the end.
    :param verbose: verbosity passed to the temporary folder.
    :return: tuple (segmentation in the input orientation (uint8 unless soft),
                    image resampled to 0.5 mm in-plane (RPI),
                    segmentation on that resampled grid (RPI))
    :raises ValueError: if kernel_size is not '2d' or '3d'.
    """
    if kernel_size not in ('2d', '3d'):
        # Fail fast: otherwise seg_crop is never assigned and the error surfaces deep inside the temp folder
        raise ValueError("kernel_size must be '2d' or '3d', got {!r}".format(kernel_size))

    if threshold_seg is None:
        threshold_seg = THR_DEEPSEG[contrast_type]

    # Display stuff
    logger.info("Config deepseg_sc:")
    logger.info("  Centerline algorithm: {}".format(ctr_algo))
    logger.info("  Brain in image: {}".format(brain_bool))
    logger.info("  Kernel dimension: {}".format(kernel_size))
    logger.info("  Contrast: {}".format(contrast_type))
    logger.info("  Threshold: {}".format(threshold_seg))

    # create temporary folder with intermediate results
    tmp_folder = sct.TempFolder(verbose=verbose)
    tmp_folder_path = tmp_folder.get_path()
    if ctr_algo == 'file':  # if the ctr_file is provided
        tmp_folder.copy_from(ctr_file)
        file_ctr = os.path.basename(ctr_file)
    else:
        file_ctr = None
    tmp_folder.chdir()
    # Always restore the caller's working directory, even if one of the steps below fails
    try:
        # re-orient image to RPI
        logger.info("Reorient the image to RPI, if necessary...")
        original_orientation = im_image.orientation
        im_image.change_orientation('RPI')

        # Resample image to 0.5mm in plane
        im_image_res = resampling.resample_nib(im_image, new_size=[0.5, 0.5, im_image.dim[6]], new_size_type='mm',
                                               interpolation='linear')

        fname_orient = 'image_in_RPI_res.nii'
        im_image_res.save(fname_orient)

        # find the spinal cord centerline - execute OptiC binary
        logger.info("Finding the spinal cord centerline...")
        _, im_ctl, im_labels_viewer = find_centerline(algo=ctr_algo,
                                                      image_fname=fname_orient,
                                                      contrast_type=contrast_type,
                                                      brain_bool=brain_bool,
                                                      folder_output=tmp_folder_path,
                                                      remove_temp_files=remove_temp_files,
                                                      centerline_fname=file_ctr)

        if ctr_algo == 'file':
            # bring the user-provided centerline onto the same 0.5 mm in-plane grid as im_image_res
            im_ctl = resampling.resample_nib(im_ctl, new_size=[0.5, 0.5, im_image.dim[6]], new_size_type='mm',
                                             interpolation='linear')

        # crop image around the spinal cord centerline
        logger.info("Cropping the image around the spinal cord...")
        crop_size = 96 if (kernel_size == '3d' and contrast_type == 't2s') else 64
        X_CROP_LST, Y_CROP_LST, Z_CROP_LST, im_crop_nii = crop_image_around_centerline(im_in=im_image_res,
                                                                                       ctr_in=im_ctl,
                                                                                       crop_size=crop_size)

        # normalize the intensity of the images
        logger.info("Normalizing the intensity...")
        im_norm_in = apply_intensity_normalization(im_in=im_crop_nii)
        del im_crop_nii

        if kernel_size == '2d':
            # segment data using 2D convolutions
            logger.info("Segmenting the spinal cord using deep learning on 2D patches...")
            segmentation_model_fname = \
                os.path.join(sct.__sct_dir__, 'data', 'deepseg_sc_models', '{}_sc.h5'.format(contrast_type))
            seg_crop = segment_2d(model_fname=segmentation_model_fname,
                                  contrast_type=contrast_type,
                                  input_size=(crop_size, crop_size),
                                  im_in=im_norm_in)
        else:  # kernel_size == '3d' (validated above)
            # segment data using 3D convolutions
            logger.info("Segmenting the spinal cord using deep learning on 3D patches...")
            segmentation_model_fname = \
                os.path.join(sct.__sct_dir__, 'data', 'deepseg_sc_models', '{}_sc_3D.h5'.format(contrast_type))
            seg_crop = segment_3d(model_fname=segmentation_model_fname,
                                  contrast_type=contrast_type,
                                  im_in=im_norm_in)

        # Postprocessing (slice by slice along the 3rd axis)
        seg_crop_postproc = np.zeros_like(seg_crop)
        x_cOm, y_cOm = None, None
        for zz in range(im_norm_in.dim[2]):
            # Fill holes (only for binary segmentations)
            if threshold_seg >= 0:
                pred_seg_th = fill_holes_2d((seg_crop[:, :, zz] > threshold_seg).astype(int))
                pred_seg_pp = keep_largest_object(pred_seg_th, x_cOm, y_cOm)
                # Update center of mass for slice i+1
                if 1 in pred_seg_pp:
                    x_cOm, y_cOm = center_of_mass(pred_seg_pp)
                    x_cOm, y_cOm = np.round(x_cOm), np.round(y_cOm)
            else:
                # If soft segmentation, do nothing
                pred_seg_pp = seg_crop[:, :, zz]

            seg_crop_postproc[:, :, zz] = pred_seg_pp  # dtype is float32

        # reconstruct the segmentation from the crop data
        logger.info("Reassembling the image...")
        im_seg = uncrop_image(ref_in=im_image_res,
                              data_crop=seg_crop_postproc,
                              x_crop_lst=X_CROP_LST,
                              y_crop_lst=Y_CROP_LST,
                              z_crop_lst=Z_CROP_LST)
        # seg_uncrop_nii.save(sct.add_suffix(fname_res, '_seg'))  # for debugging
        del seg_crop, seg_crop_postproc, im_norm_in

        # resample to initial resolution
        logger.info("Resampling the segmentation to the native image resolution using linear interpolation...")
        im_seg_r = resampling.resample_nib(im_seg, image_dest=im_image, interpolation='linear')

        if ctr_algo == 'viewer':  # for debugging
            im_labels_viewer.save(sct.add_suffix(fname_orient, '_labels-viewer'))

        # Binarize the resampled image (except for soft segmentation, defined by threshold_seg=-1)
        if threshold_seg >= 0:
            logger.info("Binarizing the resampled segmentation...")
            im_seg_r.data = (im_seg_r.data > 0.5).astype(np.uint8)

        # post processing step to z_regularized
        im_seg_r_postproc = post_processing_volume_wise(im_seg_r)

        # Change data type. By default, dtype is float32
        if threshold_seg >= 0:
            im_seg_r_postproc.change_type(np.uint8)
    finally:
        tmp_folder.chdir_undo()

    # remove temporary files
    if remove_temp_files:
        logger.info("Remove temporary files...")
        tmp_folder.cleanup()

    # reorient to initial orientation
    im_seg_r_postproc.change_orientation(original_orientation)

    # copy q/sform from input image to output segmentation
    # NOTE(review): im_seg lives on the 0.5 mm im_image_res grid, while im_image is on the native grid; confirm
    # the q/sform copy is meant for im_seg rather than for im_seg_r_postproc.
    im_seg.copy_qform_from_ref(im_image)

    return im_seg_r_postproc, im_image_res, im_seg.change_orientation('RPI')
예제 #12
0
def dummy_segmentation(size_arr=(256, 256, 256),
                       pixdim=(1, 1, 1),
                       dtype=np.float64,
                       orientation='LPI',
                       shape='rectangle',
                       angle_RL=0,
                       angle_AP=0,
                       angle_IS=0,
                       radius_RL=5.0,
                       radius_AP=3.0,
                       interleaved=False,
                       factor=1,
                       zeroslice=None,
                       debug=False):
    """Create a dummy Image with an ellipse or a rectangle running from top to bottom in the 3rd dimension, and rotate
    the image to make sure that compute_csa and compute_shape properly estimate the centerline angle.

    The object is drawn on a 1 mm isotropic grid of size size_arr * pixdim, rotated, then resampled to pixdim.

    :param size_arr: tuple: (nx, ny, nz)
    :param pixdim: tuple: (px, py, pz)
    :param dtype: Numpy dtype. NOTE: currently not applied; the volume is built as float32 before resampling.
    :param orientation: Orientation of the image. Default: LPI
    :param shape: {'rectangle', 'ellipse'}
    :param angle_RL: int: angle around RL axis (in deg)
    :param angle_AP: int: angle around AP axis (in deg)
    :param angle_IS: int: angle around IS axis (in deg)
    :param radius_RL: float: 1st radius. With a, b = 50.0, 30.0 (in mm), theoretical CSA of ellipse is 4712.4
    :param radius_AP: float: 2nd radius
    :param interleaved: bool: use polynomial function to simulate slicewise motion
    :param factor: amplitude of the alternating A-P shift used when interleaved=True, in multiples of pixdim[0]
    :param zeroslice: list int: zero all slices (along the 3rd axis) listed in this param. Default: None (no slice).
    :param debug: Write temp files for debug
    :return: img: Image object
    :raises ValueError: if shape is not 'rectangle' or 'ellipse'.
    """
    if shape not in ('rectangle', 'ellipse'):
        raise ValueError("shape must be 'rectangle' or 'ellipse', got {!r}".format(shape))

    # Initialization
    padding = 15  # Padding size (isotropic) to avoid edge effect during rotation
    # Create a 3d array, with dimensions corresponding to x: RL, y: AP, z: IS
    nx, ny, nz = [int(size_arr[i] * pixdim[i]) for i in range(3)]
    data = np.zeros((nx, ny, nz))
    xx, yy = np.mgrid[:nx, :ny]

    # A-P center of the object for each slice
    if interleaved:
        # Alternate between two A-P positions from one slice to the next and fit a polynomial through them, within
        # the Y-Z plane, to simulate slicewise motion in A-P.
        # NOTE(review): the positions are centered on nx / 2 (not ny / 2); identical only when nx == ny.
        center_hi = round(nx / 2.) + pixdim[0] * factor
        center_lo = round(nx / 2.) - pixdim[0] * factor
        # nz // 2 pairs, plus one trailing value when nz is odd, yields exactly nz samples. (round(nz / 2) rounds half
        # to even, which gives nz + 1 samples for nz = 3, 7, 11, ... and breaks the reshape.)
        y = np.tile([center_hi, center_lo], nz // 2)
        if nz % 2 != 0:  # if z-dimension is odd, add one more element to fit size
            y = np.append(y, center_hi)
        y = y.reshape(nz)
        z = np.arange(0, nz)
        p = np.poly1d(np.polyfit(z, y, deg=nz))
        centers_ap = [p(iz) for iz in range(nz)]
    else:
        centers_ap = [ny / 2] * nz

    # loop across slices and add object
    for iz in range(nz):
        yc = centers_ap[iz]
        if shape == 'rectangle':  # theoretical CSA: (a*2+1)(b*2+1)
            data[:, :, iz] = ((abs(xx - nx / 2) <= radius_RL) &
                              (abs(yy - yc) <= radius_AP)) * 1
        else:  # ellipse
            data[:, :, iz] = (((xx - nx / 2) / radius_RL)**2 +
                              ((yy - yc) / radius_AP)**2 <= 1) * 1

    def _rotate_first_two_axes(arr, angle):
        # rotate (in deg) within the plane of the first two axes, and re-grid using linear interpolation
        return rotate(arr,
                      angle,
                      resize=False,
                      center=None,
                      order=1,
                      mode='constant',
                      cval=0,
                      clip=False,
                      preserve_range=False)

    # Pad to avoid edge effect during rotation
    data = np.pad(data, padding, 'reflect')

    # ROTATION ABOUT IS AXIS
    data_rot = _rotate_first_two_axes(data, angle_IS)
    # ROTATION ABOUT RL AXIS: swap x-z axes so the rotation happens within the y-z plane, then swap back
    data_rot = _rotate_first_two_axes(data_rot.swapaxes(0, 2), angle_RL).swapaxes(0, 2)
    # ROTATION ABOUT AP AXIS: swap y-z axes so the rotation happens within the x-z plane, then swap back
    data_rot = _rotate_first_two_axes(data_rot.swapaxes(1, 2), angle_AP).swapaxes(1, 2)

    # Crop image (to remove padding)
    data_rot_crop = data_rot[padding:nx + padding, padding:ny + padding, padding:nz + padding]

    # Zero specified slices
    if zeroslice is not None:
        data_rot_crop[:, :, zeroslice] = 0

    # Create nibabel object with an identity affine (1 mm isotropic voxels)
    xform = np.eye(4)
    nii = nib.nifti1.Nifti1Image(data_rot_crop.astype('float32'), xform)
    # resample to desired resolution
    nii_r = resample_nib(nii,
                         new_size=pixdim,
                         new_size_type='mm',
                         interpolation='linear')
    # Create Image object. Default orientation is LPI.
    # For debugging add .save() at the end of the command below
    img = Image(nii_r.get_data(),
                hdr=nii_r.header,
                dim=nii_r.header.get_data_shape())
    # Update orientation
    img.change_orientation(orientation)
    if debug:
        img.save('tmp_dummy_seg_' + datetime.now().strftime("%Y%m%d%H%M%S%f") +
                 '.nii.gz')
    return img
예제 #13
0
def main():
    """Build mosaic figure(s) of mid-slices taken from every matching subject image found under i_folder.

    Relies on module-level settings defined outside this function: i_folder, im_string, seg_string, plane, affine,
    winsize, nb_column, nb_row and o_fname. With plane == 'ax', the mid axial slice is cropped around the spinal
    cord segmentation; otherwise the mid sagittal slice is used. Each slice is CLAHE-equalized, intensity-scaled
    and resized to the shape of the first slice, then slices are tiled nb_column x nb_row per output figure.
    """
    def _as_image(arr):
        # Wrap an array into an Image, using a header generated by nibabel with the module-level affine
        header = nib.nifti1.Nifti1Image(arr, affine).header
        return Image(arr, hdr=header, dim=header.get_data_shape())

    slices = []
    target_shape = None
    for dirpath, _subdirs, _filenames in os.walk(i_folder):
        # 'sub' prefix: prevents fetching warp files
        for fname_im in glob.glob(os.path.join(dirpath, 'sub' + im_string)):
            print('\nLoading: ' + fname_im)
            if plane == 'ax':
                fname_seg = glob.glob(os.path.join(dirpath, 'sub' + seg_string))[0]
                # only the mid axial slice of the image and of its segmentation is kept (saves time)
                im_full = Image(fname_im).change_orientation('RPI')
                seg_full = Image(fname_seg).change_orientation('RPI')
                z_mid = int(float(im_full.dim[2]) // 2)
                img_mid = _as_image(im_full.data[:, :, z_mid])
                seg_mid = _as_image(seg_full.data[:, :, z_mid])
                del im_full, seg_full

                axial = qcslice.Axial([img_mid, seg_mid])
                # center of mass of the segmentation
                center_x_lst, center_y_lst = axial.get_center()
                mid_slice = axial.get_slice(axial._images[0].data, 0)
                # crop the slice around the cord center
                mid_slice = axial.crop(mid_slice, int(center_x_lst[0]), int(center_y_lst[0]), 30, 30)
            else:
                im_sag = Image(fname_im).change_orientation('RSP')
                if not np.isclose(im_sag.dim[5], im_sag.dim[6]):
                    # anisotropic data: resample so that the last axis gets the voxel size of the second one
                    im_sag = resample_nib(im_sag.copy(),
                                          new_size=[im_sag.dim[4], im_sag.dim[5], im_sag.dim[5]],
                                          new_size_type='mm')
                mid_slice = im_sag.data[int(im_sag.dim[0] // 2), :, :]
                del im_sag

            # CLAHE histogram equalization, then a common intensity range across all subjects
            slice_cur = scale_intensity(equalized(mid_slice, winsize))

            # every slice takes the shape of the first one loaded
            if slices:
                slice_cur = resize(slice_cur, target_shape, anti_aliasing=True)
            else:
                target_shape = slice_cur.shape

            slices.append(slice_cur)

    # stack the samples to display along a new last axis
    img = _as_image(np.stack(slices, axis=-1))

    n_img = img.data.shape[2]
    per_mosaic = nb_column * nb_row
    n_mosaic = int(np.ceil(float(n_img) / per_mosaic))
    for i_mosaic in range(n_mosaic):
        if n_mosaic == 1:
            fname_out = o_fname
        else:
            stem, ext = os.path.splitext(o_fname)
            fname_out = stem + '_' + str(i_mosaic).zfill(3) + ext
        print('\nCreating: ' + fname_out)

        # slices belonging to this mosaic
        first = i_mosaic * per_mosaic
        last = min(first + per_mosaic, n_img)
        mosaic = get_mosaic(img.data[:, :, first:last], nb_column, nb_row)

        # save mosaic
        plt.figure()
        plt.subplot(1, 1, 1)
        plt.axis("off")
        plt.imshow(mosaic, interpolation='bilinear', cmap='gray', aspect='equal')
        plt.savefig(fname_out, dpi=300, bbox_inches='tight', pad_inches=0)
        plt.close()
예제 #14
0
def deep_segmentation_spinalcord(im_image,
                                 contrast_type,
                                 ctr_algo='cnn',
                                 ctr_file=None,
                                 brain_bool=True,
                                 kernel_size='2d',
                                 remove_temp_files=1,
                                 verbose=1):
    """
    CNN-based spinal cord segmentation pipeline.

    The image is reoriented to RPI, resampled to 0.5 mm in-plane, cropped around the detected centerline,
    intensity-normalized and segmented with a 2D or 3D model. The prediction is uncropped, resampled back to the RPI
    input grid, binarized at 0.5, regularized across slices and reoriented to the input orientation. Intermediate
    files are written to a temporary folder.

    :param im_image: Image() to segment. It is reoriented to RPI in place (the result of change_orientation() is
        not reassigned), so the caller's object is left in RPI.
    :param contrast_type: contrast name. Selects the model file ('<contrast>_sc.h5' or '<contrast>_sc_3D.h5') and
        the crop size (96 for 3D t2s, 64 otherwise).
    :param ctr_algo: centerline algorithm passed to find_centerline(). With 'file', ctr_file is copied into the
        temporary folder and resampled to the working grid; with 'viewer', the viewer labels are saved in the
        temporary folder for debugging.
    :param ctr_file: path to a centerline file; only used when ctr_algo == 'file'.
    :param brain_bool: passed to find_centerline() (whether the brain is present in the image).
    :param kernel_size: '2d' or '3d': dimensionality of the model's convolutions.
    :param remove_temp_files: if truthy, delete the temporary folder at the end.
    :param verbose: verbosity passed to the temporary folder.
    :return: tuple (binary segmentation in the input orientation (uint8),
                    image resampled to 0.5 mm in-plane (RPI),
                    segmentation on that resampled grid (RPI))
    :raises ValueError: if kernel_size is not '2d' or '3d'.
    """
    if kernel_size not in ('2d', '3d'):
        # Fail fast: otherwise seg_crop is never assigned and the error surfaces deep inside the temp folder
        raise ValueError("kernel_size must be '2d' or '3d', got {!r}".format(kernel_size))

    # create temporary folder with intermediate results
    tmp_folder = sct.TempFolder(verbose=verbose)
    tmp_folder_path = tmp_folder.get_path()
    if ctr_algo == 'file':  # if the ctr_file is provided
        tmp_folder.copy_from(ctr_file)
        file_ctr = os.path.basename(ctr_file)
    else:
        file_ctr = None
    tmp_folder.chdir()
    # Always restore the caller's working directory, even if one of the steps below fails
    try:
        # re-orient image to RPI
        logger.info("Reorient the image to RPI, if necessary...")
        original_orientation = im_image.orientation
        im_image.change_orientation('RPI')

        # Resample image to 0.5mm in plane
        im_image_res = resampling.resample_nib(im_image, new_size=[0.5, 0.5, im_image.dim[6]], new_size_type='mm',
                                               interpolation='linear')

        fname_orient = 'image_in_RPI_res.nii'
        im_image_res.save(fname_orient)

        # find the spinal cord centerline - execute OptiC binary
        logger.info("Finding the spinal cord centerline...")
        _, im_ctl, im_labels_viewer = find_centerline(
            algo=ctr_algo,
            image_fname=fname_orient,
            contrast_type=contrast_type,
            brain_bool=brain_bool,
            folder_output=tmp_folder_path,
            remove_temp_files=remove_temp_files,
            centerline_fname=file_ctr)

        if ctr_algo == 'file':
            # bring the user-provided centerline onto the same 0.5 mm in-plane grid as im_image_res
            im_ctl = resampling.resample_nib(im_ctl, new_size=[0.5, 0.5, im_image.dim[6]], new_size_type='mm',
                                             interpolation='linear')

        # crop image around the spinal cord centerline
        logger.info("Cropping the image around the spinal cord...")
        crop_size = 96 if (kernel_size == '3d' and contrast_type == 't2s') else 64
        X_CROP_LST, Y_CROP_LST, Z_CROP_LST, im_crop_nii = crop_image_around_centerline(
            im_in=im_image_res, ctr_in=im_ctl, crop_size=crop_size)

        # normalize the intensity of the images
        logger.info("Normalizing the intensity...")
        im_norm_in = apply_intensity_normalization(im_in=im_crop_nii)
        del im_crop_nii

        if kernel_size == '2d':
            # segment data using 2D convolutions
            logger.info(
                "Segmenting the spinal cord using deep learning on 2D patches...")
            segmentation_model_fname = \
                os.path.join(sct.__sct_dir__, 'data', 'deepseg_sc_models', '{}_sc.h5'.format(contrast_type))
            seg_crop = segment_2d(model_fname=segmentation_model_fname,
                                  contrast_type=contrast_type,
                                  input_size=(crop_size, crop_size),
                                  im_in=im_norm_in)
        else:  # kernel_size == '3d' (validated above)
            # segment data using 3D convolutions
            logger.info(
                "Segmenting the spinal cord using deep learning on 3D patches...")
            segmentation_model_fname = \
                os.path.join(sct.__sct_dir__, 'data', 'deepseg_sc_models', '{}_sc_3D.h5'.format(contrast_type))
            seg_crop = segment_3d(model_fname=segmentation_model_fname,
                                  contrast_type=contrast_type,
                                  im_in=im_norm_in)
        del im_norm_in

        # reconstruct the segmentation from the crop data
        logger.info("Reassembling the image...")
        im_seg = uncrop_image(ref_in=im_image_res,
                              data_crop=seg_crop,
                              x_crop_lst=X_CROP_LST,
                              y_crop_lst=Y_CROP_LST,
                              z_crop_lst=Z_CROP_LST)
        # seg_uncrop_nii.save(sct.add_suffix(fname_res, '_seg'))  # for debugging
        del seg_crop

        # Change type uint8 --> float32 otherwise resampling will produce binary output (even with linear interpolation)
        im_seg.change_type(np.float32)
        # resample to initial resolution
        logger.info(
            "Resampling the segmentation to the native image resolution using linear interpolation..."
        )
        im_seg_r = resampling.resample_nib(im_seg,
                                           image_dest=im_image,
                                           interpolation='linear')

        if ctr_algo == 'viewer':  # for debugging
            im_labels_viewer.save(sct.add_suffix(fname_orient, '_labels-viewer'))

        # Binarize the resampled image to remove interpolation effects
        logger.info("Binarizing the resampled segmentation...")
        thr = 0.5
        # Boolean masks give the same result as np.where() index arrays without materializing the indices
        im_seg_r.data[im_seg_r.data > thr] = 1
        im_seg_r.data[im_seg_r.data <= thr] = 0

        # post processing step to z_regularized
        im_seg_r_postproc = post_processing_volume_wise(im_seg_r)

        # change data type
        im_seg_r_postproc.change_type(np.uint8)
    finally:
        tmp_folder.chdir_undo()

    # remove temporary files
    if remove_temp_files:
        logger.info("Remove temporary files...")
        tmp_folder.cleanup()

    # reorient to initial orientation
    return (im_seg_r_postproc.change_orientation(original_orientation),
            im_image_res,
            im_seg.change_orientation('RPI'))
예제 #15
0
def main():
    """Build mosaic image(s) from the mid-slice of every matching input file.

    Files are searched recursively under ``args.input_folder`` with the pattern
    ``'**/sub' + args.input``. For each file, the mid-slice is extracted in the
    plane given by ``args.plane``:

    - ``'ax'``: the image and its segmentation (file name built with
      ``add_suffix(file, args.segmentation)``) are reoriented to RPI, the middle
      axial slice is taken and cropped around the segmentation center.
    - ``'sag'``: the image is reoriented to RSP, resampled if the in-plane voxel
      sizes differ, and the middle slice along the first axis is taken.

    Each slice is histogram-equalized with CLAHE (window ``args.winsize_CLAHE``),
    rescaled to a common intensity range and resized to the shape of the first
    slice. Slices are then tiled into mosaics of ``args.col`` x ``args.row``
    items. When a single mosaic is produced it is saved to ``args.output``;
    otherwise a zero-padded index is inserted before the extension
    (``<root>_000<ext>``, ``<root>_001<ext>``, ...).

    :raises ValueError: if ``args.plane`` is not 'ax' or 'sag', or if the mosaic
        has fewer than one column or row.
    :raises FileNotFoundError: if no input file matches the search pattern.
    """
    args = get_parameters()
    print(args)
    im_string = args.input
    plane = args.plane
    nb_column = int(args.col)
    nb_row = int(args.row)
    winsize = int(args.winsize_CLAHE)
    o_fname = args.output

    # Validate options before doing any (potentially long) processing.
    # Any other plane value would leave `mid_slice` undefined inside the loop.
    if plane not in ('ax', 'sag'):
        raise ValueError("Unsupported plane '{}': expected 'ax' or 'sag'.".format(plane))
    if nb_column < 1 or nb_row < 1:
        raise ValueError("Mosaic must have at least one column and one row (got col={}, row={})."
                         .format(nb_column, nb_row))

    # List input files (recursive search), sorted so the mosaic order is reproducible
    pattern = 'sub' + im_string
    files = sorted(glob.glob(os.path.join(args.input_folder, '**/' + pattern), recursive=True))
    if not files:
        # Without this check, np.stack() below fails with an obscure "need at least one array" error
        raise FileNotFoundError("No file matching '{}' found in: {}".format(pattern, args.input_folder))

    # Initialize list that will store each mosaic element
    slice_lst = []
    # Shape of the first loaded slice; every subsequent slice is resized to it
    slice_size = None
    for idx_file, fname in enumerate(files, start=1):
        print("Processing ({}/{}): {}".format(idx_file, len(files), fname))
        if plane == 'ax':
            file_seg = add_suffix(fname, args.segmentation)
            # Extract the mid-slice of both the image and its segmentation
            img, seg = Image(fname).change_orientation('RPI'), Image(file_seg).change_orientation('RPI')
            mid_slice_idx = int(float(img.dim[2]) // 2)
            nii_mid = nib.nifti2.Nifti2Image(img.data[:, :, mid_slice_idx], img.hdr.get_best_affine())
            nii_mid_seg = nib.nifti2.Nifti2Image(seg.data[:, :, mid_slice_idx], seg.hdr.get_best_affine())
            img_mid = Image(img.data[:, :, mid_slice_idx], hdr=nii_mid.header, dim=nii_mid.header.get_data_shape())
            seg_mid = Image(seg.data[:, :, mid_slice_idx], hdr=nii_mid_seg.header,
                            dim=nii_mid_seg.header.get_data_shape())
            # Instantiate spinalcordtoolbox.reports.slice.Axial class
            qcslice_cur = qcslice.Axial([img_mid, seg_mid])
            # Find center of mass of the segmentation
            center_x_lst, center_y_lst = qcslice_cur.get_center()
            # Select the mid-slice
            mid_slice = qcslice_cur.get_slice(qcslice_cur._images[0].data, 0)
            # Crop image around SC seg
            mid_slice = qcslice_cur.crop(mid_slice,
                                         int(center_x_lst[0]), int(center_y_lst[0]),
                                         20, 20)
        else:  # plane == 'sag' (validated above)
            sag_im = Image(fname).change_orientation('RSP')
            # Resample when the two in-plane voxel sizes differ (dim[5] vs dim[6])
            if not np.isclose(sag_im.dim[5], sag_im.dim[6]):
                sag_im = resample_nib(sag_im.copy(),
                                      new_size=[sag_im.dim[4], sag_im.dim[5], sag_im.dim[5]],
                                      new_size_type='mm')
            mid_slice_idx = int(sag_im.dim[0] // 2)
            mid_slice = sag_im.data[mid_slice_idx, :, :]
            del sag_im

        # Histogram equalization using CLAHE
        slice_cur = equalized(mid_slice, winsize)
        # Scale intensities of all slices (ie of all subjects) in a common range of values
        slice_cur = scale_intensity(slice_cur)

        # Resize all slices with the shape of the first loaded slice
        if slice_size is None:
            slice_size = slice_cur.shape
        else:
            slice_cur = resize(slice_cur, slice_size, anti_aliasing=True)

        slice_lst.append(slice_cur)

    # Stack slices along a new last axis: data[:, :, k] is the k-th slice
    data = np.stack(slice_lst, axis=-1)
    nb_img = data.shape[2]
    nb_items_mosaic = nb_column * nb_row
    nb_mosaic = int(np.ceil(float(nb_img) / nb_items_mosaic))
    fname_root, fname_ext = os.path.splitext(o_fname)
    for i in range(nb_mosaic):
        if nb_mosaic == 1:
            fname_out = o_fname
        else:
            fname_out = fname_root + '_' + str(i).zfill(3) + fname_ext
        # create mosaic from the i-th chunk of slices (the last chunk may be partial)
        idx_start = i * nb_items_mosaic
        idx_end = min(idx_start + nb_items_mosaic, nb_img)
        data_mosaic = data[:, :, idx_start:idx_end]
        mosaic = get_mosaic(data_mosaic, nb_column, nb_row)
        # save mosaic
        plt.figure()
        plt.subplot(1, 1, 1)
        plt.axis("off")
        plt.imshow(mosaic, interpolation='bilinear', cmap='gray', aspect='equal')
        plt.savefig(fname_out, dpi=300, bbox_inches='tight', pad_inches=0)
        plt.close()
        print('\nCreated: {}'.format(fname_out))