예제 #1
0
def test_sct_register_multimodal_mask_files_exist(tmp_path):
    """
    Register the PAM50 template to the MT data with a mask, then check which output files exist.

    Only the presence/absence of the output files is asserted; the registration itself is not
    validated.

    - TODO: Write a check that verifies the registration results.
    - TODO: Parametrize this test to add '-initwarpinv warp_anat2template.nii.gz',
            after the file is added to sct_testing_data:
            https://github.com/spinalcordtoolbox/spinalcordtoolbox/pull/3407#discussion_r646895013
    """
    fname_mt1 = sct_test_path('mt', 'mt1.nii.gz')
    fname_mt1_seg = sct_test_path('mt', 'mt1_seg.nii.gz')
    fname_mask = str(tmp_path / 'mask_mt1.nii.gz')

    # Step 1: build a cylindrical mask centered on the cord segmentation
    mask_args = [
        '-i', fname_mt1,
        '-p', f"centerline,{fname_mt1_seg}",
        '-size', '35mm',
        '-f', 'cylinder',
        '-o', fname_mask,
    ]
    sct_create_mask.main(mask_args)

    # Step 2: register the template to the MT image, writing everything into tmp_path
    template_dir = 'data/PAM50/template/'
    registration_args = [
        '-i', sct_dir_local_path(template_dir, 'PAM50_t2.nii.gz'),
        '-iseg', sct_dir_local_path(template_dir, 'PAM50_cord.nii.gz'),
        '-d', fname_mt1,
        '-dseg', fname_mt1_seg,
        '-param', 'step=1,type=seg,algo=centermass:step=2,type=seg,algo=bsplinesyn,slicewise=1,iter=3',
        '-m', fname_mask,
        '-initwarp', sct_test_path('t2', 'warp_template2anat.nii.gz'),
        '-ofolder', str(tmp_path),
    ]
    sct_register_multimodal.main(registration_args)

    expected_outputs = ("PAM50_t2_reg.nii.gz", "warp_PAM50_t22mt1.nii.gz")
    for name in expected_outputs:
        assert (tmp_path / name).exists()

    # Because `-initwarp` was specified (but `-initwarpinv` wasn't) the dest->seg files should NOT exist
    absent_outputs = ("mt1_reg.nii.gz", "warp_mt12PAM50_t2.nii.gz")
    for name in absent_outputs:
        assert not (tmp_path / name).exists()
예제 #2
0
def test_sct_register_multimodal_mask_no_checks(tmp_path):
    """Run the script without validating results.

    All outputs (the mask and the registration results) are directed to `tmp_path`, so that the
    test does not leave files behind in the current working directory.

    TODO: Write a check that verifies the registration results as part of
    https://github.com/spinalcordtoolbox/spinalcordtoolbox/pull/3246."""
    fname_mask = str(tmp_path / 'mask_mt1.nii.gz')
    # Build a cylindrical mask around the cord centerline, used below to restrict registration
    sct_create_mask.main([
        '-i',
        sct_test_path('mt', 'mt1.nii.gz'), '-p',
        f"centerline,{sct_test_path('mt', 'mt1_seg.nii.gz')}", '-size', '35mm',
        '-f', 'cylinder', '-o', fname_mask
    ])
    sct_register_multimodal.main([
        '-i',
        sct_dir_local_path('data/PAM50/template/', 'PAM50_t2.nii.gz'), '-iseg',
        sct_dir_local_path('data/PAM50/template/', 'PAM50_cord.nii.gz'), '-d',
        sct_test_path('mt', 'mt1.nii.gz'), '-dseg',
        sct_test_path('mt', 'mt1_seg.nii.gz'), '-param',
        'step=1,type=seg,algo=centermass:step=2,type=seg,algo=bsplinesyn,slicewise=1,iter=3',
        '-m', fname_mask, '-initwarp',
        sct_test_path('t2', 'warp_template2anat.nii.gz'),
        # Keep the registration outputs inside the per-test temporary directory
        '-ofolder', str(tmp_path)
    ])
예제 #3
0
    def __init__(self, qc_params, usage):
        """
        Keep the QC parameters and set the fixed asset folder and output base names.

        :param qc_params: arguments of the "-param-qc" option in Terminal; its `command` and
            `orientation` attributes are read here
        :param usage: str: description of the process
        """
        # Values derived from the QC parameters
        command, orientation = qc_params.command, qc_params.orientation
        self.tool_name = command
        self.slice_name = orientation

        # Inputs kept as-is
        self.qc_params = qc_params
        self.usage = usage

        # Fixed locations and base names
        self.assets_folder = sct_dir_local_path('assets')
        self.img_base_name = 'bkg_img'
        self.description_base_name = "qc_results"
예제 #4
0
def get_dependencies(requirements_txt=None):
    """
    Yield the packages listed in a requirements file, together with their pinned version.

    :param requirements_txt: path to the requirements file. If None, the "requirements.txt"
        located with sct_dir_local_path() is used.
    :return: generator of (name, version) tuples, where version is the value of the "=="
        specifier, or None when the requirement has no "==" specifier. Requirements carrying an
        environment marker (text after ';') for which _test_condition() is falsy are skipped.
    """
    if requirements_txt is None:
        requirements_txt = sct_dir_local_path("requirements.txt")

    # The context manager guarantees the file is closed once the generator is exhausted or closed
    with open(requirements_txt, "r", encoding="utf-8") as requirements_file:
        # workaround for https://github.com/davidfischer/requirements-parser/issues/39
        warnings.filterwarnings(action='ignore', module='requirements')

        for req in requirements.parse(requirements_file):
            if ';' in req.line:  # handle environment markers; TODO: move this upstream into requirements-parser
                condition = req.line.split(';', 1)[-1].strip()
                if not _test_condition(condition):
                    continue
            pkg = req.name
            # TODO: just return req directly and make sure caller can deal with fancier specs
            ver = dict(req.specs).get("==", None)
            yield pkg, ver
예제 #5
0
def detect_centerline(img, contrast, verbose=1):
    """Detect spinal cord centerline using OptiC.

    The image is rescaled to the uint16 range, reoriented to RPI and saved in a temporary folder,
    where the `isct_spine_detect` binary is run on it. The working directory is always restored,
    even if one of the steps fails.

    :param img: input Image() object. NOTE(review): NaN/inf entries of `img.data` are zeroed by direct
        indexed assignment, which presumably modifies the caller's image in place -- confirm.
    :param contrast: str: The type of contrast. Will define the path to Optic model.
    :param verbose: int: the temporary folder is removed unless verbose >= 2.
    :returns: Image(): Output centerline
    """

    # Fetch path to Optic model based on contrast
    optic_models_path = sct_dir_local_path('data', 'optic_models',
                                           '{}_model'.format(contrast))

    logger.debug('Detecting the spinal cord using OptiC')
    img_orientation = img.orientation

    temp_folder = sct.TempFolder()
    temp_folder.chdir()

    try:
        # convert image data type to uint16, as required by opencv (backend in OptiC)
        img_int16 = img.copy()
        # Replace non-numeric values by zero
        img_data = img.data
        img_data[np.where(np.isnan(img_data))] = 0
        img_data[np.where(np.isinf(img_data))] = 0
        img_int16.data[np.where(np.isnan(img_int16.data))] = 0
        img_int16.data[np.where(np.isinf(img_int16.data))] = 0
        # rescale intensity to the full uint16 range
        min_out = np.iinfo('uint16').min
        max_out = np.iinfo('uint16').max
        min_in = np.nanmin(img_data)
        max_in = np.nanmax(img_data)
        data_rescaled = img_data.astype('float') * (max_out - min_out) / (max_in - min_in)
        # shift so that the minimum of the rescaled data lands on min_out
        img_int16.data = data_rescaled - (data_rescaled.min() - min_out)
        # change data type
        img_int16.change_type(np.uint16)
        # reorient the input image to RPI + convert to .nii
        img_int16.change_orientation('RPI')
        file_img = 'img_rpi_uint16'
        img_int16.save(file_img + '.nii')

        # call the OptiC method to generate the spinal cord centerline
        optic_input = file_img
        optic_filename = file_img + '_optic'
        os.environ["FSLOUTPUTTYPE"] = "NIFTI_PAIR"
        cmd_optic = [
            'isct_spine_detect',
            '-ctype=dpdt',
            '-lambda=1',
            optic_models_path,
            optic_input,
            optic_filename,
        ]
        # TODO: output coordinates, for each slice, in continuous (not discrete) values.

        sct.run(cmd_optic, is_sct_binary=True, verbose=0)

        # load the .img/.hdr pair written by OptiC and restore the input orientation
        img_ctl = Image(file_img + '_optic_ctr.hdr')
        img_ctl.change_orientation(img_orientation)
    finally:
        # return to initial folder, whether or not the detection succeeded
        temp_folder.chdir_undo()

    if verbose < 2:
        logger.info("Remove temporary files...")
        temp_folder.cleanup()

    return img_ctl
예제 #6
0
def run(cmd,
        verbose=1,
        raise_exception=True,
        cwd=None,
        env=None,
        is_sct_binary=False):
    """
    Run a command in a subprocess and collect its output (stderr is merged into stdout).

    :param cmd: command to run, either a list of arguments (executed without a shell) or a single
        string (executed through the shell). The caller's list is never modified.
    :param verbose: if truthy, print the command line before running it; if 2, additionally print
        every output line as it is produced.
    :param raise_exception: if True, raise RunError when the exit status is non-zero.
    :param cwd: working directory of the child process (default: current working directory).
    :param env: environment of the child process (default: os.environ).
    :param is_sct_binary: if True, the executable name is resolved inside the SCT "bin" folder; when
        it is not found there, the binaries are first fetched with "sct_download_data".
    :return: (status, output) -- exit status and captured output, each line stripped of
        surrounding whitespace and joined by newlines.
    :raises RunError: when the exit status is non-zero and raise_exception is True; the exception
        carries the captured output.
    """
    if cwd is None:
        cwd = os.getcwd()

    if env is None:
        env = os.environ

    # Python 2 only: the right-hand side is never evaluated on Python 3
    if sys.hexversion < 0x03000000 and isinstance(cmd, unicode):
        cmd = str(cmd)

    if is_sct_binary:
        name = cmd[0] if isinstance(cmd, list) else cmd.split(" ", 1)[0]
        path = None
        binaries_location_default = sct_dir_local_path("bin")
        for directory in (binaries_location_default, ):
            candidate = os.path.join(directory, name)
            if os.path.exists(candidate):
                path = candidate
        if path is None:
            # Binary not installed yet: download the binaries, then assume it is now present
            run([
                "sct_download_data", "-d",
                which_sct_binaries(), "-o", binaries_location_default
            ])
            path = os.path.join(binaries_location_default, name)

        if isinstance(cmd, list):
            # Build a new list rather than overwriting cmd[0], so the caller's list stays intact
            cmd = [path] + cmd[1:]
        elif isinstance(cmd, str):
            rem = cmd.split(" ", 1)[1:]
            cmd = path if len(rem) == 0 else "{} {}".format(path, rem[0])

    if isinstance(cmd, str):
        cmdline = cmd
    else:
        cmdline = list2cmdline(cmd)

    if verbose:
        printv("%s # in %s" % (cmdline, cwd), 1, 'code')

    # A string command is interpreted by the shell; a list is executed directly
    shell = isinstance(cmd, str)

    process = subprocess.Popen(cmd,
                               shell=shell,
                               cwd=cwd,
                               stdin=subprocess.PIPE,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.STDOUT,
                               env=env)
    output_lines = []
    try:
        while True:
            # Watch out for deadlock!!!
            # Undecodable bytes are replaced rather than aborting the whole run
            output = process.stdout.readline().decode("utf-8", "replace")
            if output == '' and process.poll() is not None:
                break
            if output:
                if verbose == 2:
                    printv(output.strip())
                output_lines.append(output.strip())
    finally:
        # Release the pipe file descriptors instead of leaving them to the garbage collector
        process.stdout.close()
        process.stdin.close()

    status = process.returncode
    output = "\n".join(output_lines).rstrip()

    if status != 0 and raise_exception:
        raise RunError(output)

    return status, output
예제 #7
0
def deep_segmentation_MSlesion(im_image,
                               contrast_type,
                               ctr_algo='svm',
                               ctr_file=None,
                               brain_bool=True,
                               remove_temp_files=1,
                               verbose=1):
    """
    Segment lesions from MRI data.

    Pipeline: reorient to RPI, resample in-plane to 0.5mm, find the cord centerline, crop around it,
    normalize intensities, resample to 0.5mm isotropic, segment with the 3D model, then undo the
    resampling/cropping steps and binarize the result. All intermediate files are written to a
    temporary folder, which is the working directory for the duration of the processing.

    :param im_image: Image() object containing the lesions to segment. Its change_orientation('RPI')
        result is discarded, so the code relies on that call modifying im_image itself.
    :param contrast_type: Contrast of the image. Need to use one supported by the CNN models. The part
        before the first '_' is what gets passed to the centerline detection.
    :param ctr_algo: Algo to find the centerline. See sct_get_centerline
    :param ctr_file: Centerline or segmentation (optional); only used when ctr_algo == 'file'
    :param brain_bool: If brain if present or not in the image.
    :param remove_temp_files: if truthy, the temporary folder is removed at the end
    :param verbose: passed to the temporary folder; when equal to 2, the centerline image is also
        resampled and returned
    :return: tuple (segmentation, im_viewer, im_image_res_ctr_downsamp): the uint8 segmentation in the
        original orientation; the viewer labels (only when ctr_algo == 'viewer', else None); the
        centerline image (only when verbose == 2, else None)
    """

    # create temporary folder with intermediate results
    tmp_folder = sct.TempFolder(verbose=verbose)
    tmp_folder_path = tmp_folder.get_path()
    if ctr_algo == 'file':  # if the ctr_file is provided
        tmp_folder.copy_from(ctr_file)
        file_ctr = os.path.basename(ctr_file)
    else:
        file_ctr = None
    tmp_folder.chdir()
    # kept to name the output files after the input file (see add_suffix calls below)
    fname_in = im_image.absolutepath

    # re-orient image to RPI
    logger.info("Reorient the image to RPI, if necessary...")
    original_orientation = im_image.orientation
    # fname_orient = 'image_in_RPI.nii'
    im_image.change_orientation('RPI')

    # used below as the target 'mm' sizes when resampling back -- presumably (px, py, pz); TODO confirm
    input_resolution = im_image.dim[4:7]

    # Resample image to 0.5mm in plane
    im_image_res = \
        resampling.resample_nib(im_image, new_size=[0.5, 0.5, im_image.dim[6]], new_size_type='mm', interpolation='linear')

    fname_orient = 'image_in_RPI_res.nii'
    im_image_res.save(fname_orient)

    # find the spinal cord centerline - execute OptiC binary
    logger.info("\nFinding the spinal cord centerline...")
    contrast_type_ctr = contrast_type.split('_')[0]
    _, im_ctl, im_labels_viewer = find_centerline(
        algo=ctr_algo,
        image_fname=fname_orient,
        contrast_type=contrast_type_ctr,
        brain_bool=brain_bool,
        folder_output=tmp_folder_path,
        remove_temp_files=remove_temp_files,
        centerline_fname=file_ctr)
    # a user-provided centerline is brought to the same in-plane resolution as the resampled image
    if ctr_algo == 'file':
        im_ctl = \
            resampling.resample_nib(im_ctl, new_size=[0.5, 0.5, im_image.dim[6]], new_size_type='mm', interpolation='linear')

    # crop image around the spinal cord centerline
    logger.info("\nCropping the image around the spinal cord...")
    crop_size = 48
    X_CROP_LST, Y_CROP_LST, Z_CROP_LST, im_crop_nii = crop_image_around_centerline(
        im_in=im_image_res, ctr_in=im_ctl, crop_size=crop_size)
    del im_ctl

    # normalize the intensity of the images
    logger.info("Normalizing the intensity...")
    im_norm_in = apply_intensity_normalization(img=im_crop_nii,
                                               contrast=contrast_type)
    del im_crop_nii

    # resample to 0.5mm isotropic
    fname_norm = sct.add_suffix(fname_orient, '_norm')
    im_norm_in.save(fname_norm)
    fname_res3d = sct.add_suffix(fname_norm, '_resampled3d')
    resampling.resample_file(fname_norm,
                             fname_res3d,
                             '0.5x0.5x0.5',
                             'mm',
                             'linear',
                             verbose=0)

    # segment data using 3D convolutions
    logger.info(
        "\nSegmenting the MS lesions using deep learning on 3D patches...")
    segmentation_model_fname = sct_dir_local_path(
        'data', 'deepseg_lesion_models', '{}_lesion.h5'.format(contrast_type))
    fname_seg_crop_res = sct.add_suffix(fname_res3d, '_lesionseg')
    im_res3d = Image(fname_res3d)
    seg_im = segment_3d(model_fname=segmentation_model_fname,
                        contrast_type=contrast_type,
                        im=im_res3d.copy())
    seg_im.save(fname_seg_crop_res)
    del im_res3d, seg_im

    # resample to the initial pz resolution
    fname_seg_res2d = sct.add_suffix(fname_seg_crop_res, '_resampled2d')
    initial_2d_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
    resampling.resample_file(fname_seg_crop_res,
                             fname_seg_res2d,
                             initial_2d_resolution,
                             'mm',
                             'linear',
                             verbose=0)
    seg_crop = Image(fname_seg_res2d)

    # reconstruct the segmentation from the crop data
    logger.info("\nReassembling the image...")
    seg_uncrop_nii = uncrop_image(ref_in=im_image_res,
                                  data_crop=seg_crop.copy().data,
                                  x_crop_lst=X_CROP_LST,
                                  y_crop_lst=Y_CROP_LST,
                                  z_crop_lst=Z_CROP_LST)
    fname_seg_res_RPI = sct.add_suffix(fname_in, '_res_RPI_seg')
    seg_uncrop_nii.save(fname_seg_res_RPI)
    del seg_crop

    # resample to initial resolution
    logger.info(
        "Resampling the segmentation to the original image resolution...")
    initial_resolution = 'x'.join([
        str(input_resolution[0]),
        str(input_resolution[1]),
        str(input_resolution[2])
    ])
    fname_seg_RPI = sct.add_suffix(fname_in, '_RPI_seg')
    resampling.resample_file(fname_seg_res_RPI,
                             fname_seg_RPI,
                             initial_resolution,
                             'mm',
                             'linear',
                             verbose=0)
    seg_initres_nii = Image(fname_seg_RPI)

    if ctr_algo == 'viewer':  # resample and reorient the viewer labels
        im_labels_viewer_nib = nib.nifti1.Nifti1Image(
            im_labels_viewer.data, im_labels_viewer.hdr.get_best_affine())
        im_viewer_r_nib = resampling.resample_nib(im_labels_viewer_nib,
                                                  new_size=input_resolution,
                                                  new_size_type='mm',
                                                  interpolation='linear')
        im_viewer = Image(
            im_viewer_r_nib.get_data(),
            hdr=im_viewer_r_nib.header,
            orientation='RPI',
            dim=im_viewer_r_nib.header.get_data_shape()).change_orientation(
                original_orientation)

    else:
        im_viewer = None

    if verbose == 2:
        # NOTE(review): assumes a '<fname_orient>_ctr' file exists in the temporary folder at this
        # point; nothing in this function writes it -- confirm it is produced by find_centerline.
        fname_res_ctr = sct.add_suffix(fname_orient, '_ctr')
        # input and output file names are the same: the centerline file is overwritten
        resampling.resample_file(fname_res_ctr,
                                 fname_res_ctr,
                                 initial_resolution,
                                 'mm',
                                 'linear',
                                 verbose=0)
        im_image_res_ctr_downsamp = Image(fname_res_ctr).change_orientation(
            original_orientation)
    else:
        im_image_res_ctr_downsamp = None

    # binarize the resampled image to remove interpolation effects
    logger.info(
        "\nBinarizing the segmentation to avoid interpolation effects...")
    thr = 0.1
    # order matters only in that voxels set to 1 here stay >= thr and are untouched by the next line
    seg_initres_nii.data[np.where(seg_initres_nii.data >= thr)] = 1
    seg_initres_nii.data[np.where(seg_initres_nii.data < thr)] = 0

    # change data type
    seg_initres_nii.change_type(np.uint8)

    # leave the temporary folder (the actual reorientation happens in the return statement below)
    logger.info(
        "\nReorienting the segmentation to the original image orientation...")
    tmp_folder.chdir_undo()

    # remove temporary files
    if remove_temp_files:
        logger.info("\nRemove temporary files...")
        tmp_folder.cleanup()

    # reorient to initial orientation
    return seg_initres_nii.change_orientation(
        original_orientation), im_viewer, im_image_res_ctr_downsamp
예제 #8
0
import sys, os
from tempfile import TemporaryDirectory
import pytest

import multiprocessing

from spinalcordtoolbox.utils import sct_test_path, sct_dir_local_path
sys.path.append(sct_dir_local_path('scripts'))
from spinalcordtoolbox import resampling
import spinalcordtoolbox.reports.qc as qc
from spinalcordtoolbox.image import Image
import spinalcordtoolbox.reports.slice as qcslice


def gen_qc(args):
    """Generate a QC report for the t2 test image and its manual segmentation.

    :param args: 2-item sequence (index, path_qc); the first item is ignored, the second is the
        QC output folder
    :return: True once qc.generate_qc() has returned
    """
    _, path_qc = args

    qc.generate_qc(
        fname_in1=sct_test_path('t2', 't2.nii.gz'),
        fname_seg=sct_test_path('t2', 't2_seg-manual.nii.gz'),
        path_qc=path_qc,
        process="sct_deepseg_gm",
    )
    return True


def test_many_qc():
    """Test many qc images can be made in parallel"""
    if multiprocessing.cpu_count() < 2:
        pytest.skip("Can't test parallel behaviour")

    with TemporaryDirectory(prefix="sct-qc-") as tmpdir:
예제 #9
0
def deep_segmentation_spinalcord(im_image,
                                 contrast_type,
                                 ctr_algo='cnn',
                                 ctr_file=None,
                                 brain_bool=True,
                                 kernel_size='2d',
                                 threshold_seg=None,
                                 remove_temp_files=1,
                                 verbose=1):
    """
    Main pipeline for CNN-based segmentation of the spinal cord.

    Steps: reorient to RPI, resample in-plane to 0.5mm, find the centerline, crop around it,
    normalize intensities, segment (2D or 3D model), post-process slice by slice, un-crop, resample
    back onto the input image grid and (unless soft segmentation is requested) binarize.

    :param im_image: Image() to segment. The result of change_orientation('RPI') is discarded, so the
        code relies on that call modifying im_image itself.
    :param contrast_type: {'t1', 't2', 't2s', 'dwi'}
    :param ctr_algo: centerline detection method, forwarded to find_centerline(); the values 'file'
        and 'viewer' get extra handling here
    :param ctr_file: centerline file, only used when ctr_algo == 'file'
    :param brain_bool: forwarded to find_centerline()
    :param kernel_size: '2d' or '3d'; selects the model and the segmentation function.
        NOTE(review): any other value leaves `seg_crop` undefined and fails at the post-processing step.
    :param threshold_seg: Binarization threshold (between 0 and 1) to apply to the segmentation prediction. Set to -1
        for no binarization (i.e. soft segmentation output). If None, THR_DEEPSEG[contrast_type] is used.
    :param remove_temp_files: if truthy, the temporary folder is removed at the end
    :param verbose: passed to the temporary folder
    :return: tuple (im_seg_r_postproc, im_image_res, im_seg): the post-processed segmentation in the
        original orientation, the in-plane resampled input image, and the un-cropped segmentation at
        the resampled resolution after change_orientation('RPI')
    """
    if threshold_seg is None:
        threshold_seg = THR_DEEPSEG[contrast_type]

    # Display stuff
    logger.info("Config deepseg_sc:")
    logger.info("  Centerline algorithm: {}".format(ctr_algo))
    logger.info("  Brain in image: {}".format(brain_bool))
    logger.info("  Kernel dimension: {}".format(kernel_size))
    logger.info("  Contrast: {}".format(contrast_type))
    logger.info("  Threshold: {}".format(threshold_seg))

    # create temporary folder with intermediate results
    tmp_folder = TempFolder(verbose=verbose)
    tmp_folder_path = tmp_folder.get_path()
    if ctr_algo == 'file':  # if the ctr_file is provided
        tmp_folder.copy_from(ctr_file)
        file_ctr = os.path.basename(ctr_file)
    else:
        file_ctr = None
    tmp_folder.chdir()

    # re-orient image to RPI
    logger.info("Reorient the image to RPI, if necessary...")
    original_orientation = im_image.orientation
    # fname_orient = 'image_in_RPI.nii'
    im_image.change_orientation('RPI')

    # Resample image to 0.5mm in plane
    im_image_res = \
        resampling.resample_nib(im_image, new_size=[0.5, 0.5, im_image.dim[6]], new_size_type='mm', interpolation='linear')

    fname_orient = 'image_in_RPI_res.nii'
    im_image_res.save(fname_orient)

    # find the spinal cord centerline - execute OptiC binary
    logger.info("Finding the spinal cord centerline...")
    _, im_ctl, im_labels_viewer = find_centerline(
        algo=ctr_algo,
        image_fname=fname_orient,
        contrast_type=contrast_type,
        brain_bool=brain_bool,
        folder_output=tmp_folder_path,
        remove_temp_files=remove_temp_files,
        centerline_fname=file_ctr)

    # a user-provided centerline is brought to the same in-plane resolution as the resampled image
    if ctr_algo == 'file':
        im_ctl = \
            resampling.resample_nib(im_ctl, new_size=[0.5, 0.5, im_image.dim[6]], new_size_type='mm', interpolation='linear')

    # crop image around the spinal cord centerline
    logger.info("Cropping the image around the spinal cord...")
    # the 3D t2s model is the only one using 96-wide crops; every other configuration uses 64
    crop_size = 96 if (kernel_size == '3d' and contrast_type == 't2s') else 64
    X_CROP_LST, Y_CROP_LST, Z_CROP_LST, im_crop_nii = crop_image_around_centerline(
        im_in=im_image_res, ctr_in=im_ctl, crop_size=crop_size)

    # normalize the intensity of the images
    logger.info("Normalizing the intensity...")
    im_norm_in = apply_intensity_normalization(im_in=im_crop_nii)
    del im_crop_nii

    if kernel_size == '2d':
        # segment data using 2D convolutions
        logger.info(
            "Segmenting the spinal cord using deep learning on 2D patches...")
        segmentation_model_fname = \
            sct_dir_local_path('data', 'deepseg_sc_models', '{}_sc.h5'.format(contrast_type))
        seg_crop = segment_2d(model_fname=segmentation_model_fname,
                              contrast_type=contrast_type,
                              input_size=(crop_size, crop_size),
                              im_in=im_norm_in)
    elif kernel_size == '3d':
        # segment data using 3D convolutions
        logger.info(
            "Segmenting the spinal cord using deep learning on 3D patches...")
        segmentation_model_fname = \
            sct_dir_local_path('data', 'deepseg_sc_models', '{}_sc_3D.h5'.format(contrast_type))
        seg_crop = segment_3d(model_fname=segmentation_model_fname,
                              contrast_type=contrast_type,
                              im_in=im_norm_in)

    # Postprocessing
    seg_crop_postproc = np.zeros_like(seg_crop)
    # center of mass of the last non-empty slice, passed to keep_largest_object for the next slice
    x_cOm, y_cOm = None, None
    for zz in range(im_norm_in.dim[2]):
        # Fill holes (only for binary segmentations)
        if threshold_seg >= 0:
            pred_seg_th = fill_holes_2d(
                (seg_crop[:, :, zz] > threshold_seg).astype(int))
            pred_seg_pp = keep_largest_object(pred_seg_th, x_cOm, y_cOm)
            # Update center of mass for slice i+1
            if 1 in pred_seg_pp:
                x_cOm, y_cOm = center_of_mass(pred_seg_pp)
                x_cOm, y_cOm = np.round(x_cOm), np.round(y_cOm)
        else:
            # If soft segmentation, do nothing
            pred_seg_pp = seg_crop[:, :, zz]

        seg_crop_postproc[:, :, zz] = pred_seg_pp  # dtype is float32

    # reconstruct the segmentation from the crop data
    logger.info("Reassembling the image...")
    im_seg = uncrop_image(ref_in=im_image_res,
                          data_crop=seg_crop_postproc,
                          x_crop_lst=X_CROP_LST,
                          y_crop_lst=Y_CROP_LST,
                          z_crop_lst=Z_CROP_LST)
    # seg_uncrop_nii.save(add_suffix(fname_res, '_seg'))  # for debugging
    del seg_crop, seg_crop_postproc, im_norm_in

    # resample to initial resolution
    logger.info(
        "Resampling the segmentation to the native image resolution using linear interpolation..."
    )
    im_seg_r = resampling.resample_nib(im_seg,
                                       image_dest=im_image,
                                       interpolation='linear')

    if ctr_algo == 'viewer':  # for debugging
        im_labels_viewer.save(add_suffix(fname_orient, '_labels-viewer'))

    # Binarize the resampled image (except for soft segmentation, defined by threshold_seg=-1)
    # NOTE: the cut-off here is fixed at 0.5; threshold_seg was already applied per slice above
    if threshold_seg >= 0:
        logger.info("Binarizing the resampled segmentation...")
        im_seg_r.data = (im_seg_r.data > 0.5).astype(np.uint8)

    # post processing step to z_regularized
    im_seg_r_postproc = post_processing_volume_wise(im_seg_r)

    # Change data type. By default, dtype is float32
    if threshold_seg >= 0:
        im_seg_r_postproc.change_type(np.uint8)

    tmp_folder.chdir_undo()

    # remove temporary files
    if remove_temp_files:
        logger.info("Remove temporary files...")
        tmp_folder.cleanup()

    # reorient to initial orientation
    im_seg_r_postproc.change_orientation(original_orientation)

    # copy q/sform from input image to output segmentation
    # NOTE(review): im_image was switched to RPI above and is not switched back in this function
    im_seg.copy_qform_from_ref(im_image)

    return im_seg_r_postproc, im_image_res, im_seg.change_orientation('RPI')
예제 #10
0
def find_centerline(algo, image_fname, contrast_type, brain_bool,
                    folder_output, remove_temp_files, centerline_fname):
    """
    Compute (or load) the spinal cord centerline of an image. Assumes RPI orientation.

    :param algo: one of 'svm', 'cnn', 'viewer' or 'file'; any other value logs an error and exits
    :param image_fname: file name of the image to process
    :param contrast_type: contrast, used to select the OptiC/CNN model and its parameters
    :param brain_bool: forwarded to heatmap() (only used when algo == 'cnn')
    :param folder_output: not used in this function
    :param remove_temp_files: not used in this function
    :param centerline_fname: centerline file to load (only used when algo == 'file')
    :return: tuple ("dummy_file_name", im_ctl, im_labels); im_labels is None unless algo == 'viewer'
    """

    im = Image(image_fname)
    ctl_absolute_path = add_suffix(im.absolutepath, "_ctr")

    # isct_spine_detect requires nz > 1: duplicate the single slice, and remember to undo it later
    bool_2d = im.dim[2] == 1
    if bool_2d:
        im = concat_data([im, im], dim=2)
        # Needs to be changed manually since dim not updated during concat_data
        im.hdr['dim'][3] = 2

    # Only the viewer produces labels; every other method returns None for them
    im_labels = None

    # TODO: maybe change 'svm' for 'optic', because this is how we call it in sct_get_centerline
    if algo == 'svm':
        # run optic on a heatmap computed by a trained SVM+HoG algorithm
        # TODO: replace with get_centerline(method=optic)
        param_optic = ParamCenterline(algo_fitting='optic', contrast=contrast_type)
        im_ctl, _, _, _ = get_centerline(im, param_optic)

    elif algo == 'cnn':
        # CNN parameters, per contrast: patch geometry/statistics and network hyper-parameters
        cnn_settings = {
            't2': {'size': (80, 80), 'mean': 51.1417, 'std': 57.4408,
                   'features': 16, 'dilation_layers': 2},
            't2s': {'size': (80, 80), 'mean': 68.8591, 'std': 71.4659,
                    'features': 8, 'dilation_layers': 3},
            't1': {'size': (80, 80), 'mean': 55.7359, 'std': 64.3149,
                   'features': 24, 'dilation_layers': 3},
            'dwi': {'size': (80, 80), 'mean': 55.744, 'std': 45.003,
                    'features': 8, 'dilation_layers': 2},
        }

        # load model
        ctr_model_fname = sct_dir_local_path('data', 'deepseg_sc_models',
                                             '{}_ctr.h5'.format(contrast_type))
        settings = cnn_settings[contrast_type]
        patch_size = settings['size']
        ctr_model = nn_architecture_ctr(
            height=patch_size[0],
            width=patch_size[1],
            channels=1,
            classes=1,
            features=settings['features'],
            depth=2,
            temperature=1.0,
            padding='same',
            batchnorm=True,
            dropout=0.0,
            dilation_layers=settings['dilation_layers'])
        ctr_model.load_weights(ctr_model_fname)

        # compute the heatmap, then fit the centerline on it
        im_heatmap, z_max = heatmap(
            im=im,
            model=ctr_model,
            patch_shape=patch_size,
            mean_train=settings['mean'],
            std_train=settings['std'],
            brain_bool=brain_bool)
        param_optic = ParamCenterline(algo_fitting='optic', contrast=contrast_type)
        im_ctl, _, _, _ = get_centerline(im_heatmap, param_optic)

        if z_max is not None:
            logger.info('Cropping brain section.')
            im_ctl.data[:, :, z_max:] = 0

    elif algo == 'viewer':
        im_labels = _call_viewer_centerline(im)
        im_ctl, _, _, _ = get_centerline(im_labels, param=ParamCenterline())

    elif algo == 'file':
        im_ctl = Image(centerline_fname)
        im_ctl.change_orientation('RPI')

    else:
        logger.error(
            'The parameter "-centerline" is incorrect. Please try again.')
        sys.exit(1)

    # TODO: for some reason, when algo == 'file', the absolutepath is changed to None out of the method find_centerline
    im_ctl.absolutepath = ctl_absolute_path

    # Drop the duplicated slice that was added for single-slice inputs
    if bool_2d:
        im_ctl = split_data(im_ctl, dim=2)[0]

    # TODO: remove unecessary return params
    return "dummy_file_name", im_ctl, im_labels