def _preprocess_segment(fname_t2, fname_t2_seg, contrast_test, dim_3=False):
    """Build a cropped, normalized image / ground-truth pair from two files.

    All intermediate files are written to a temporary folder, which is
    removed (and the previous working directory restored) before returning,
    including when a step fails.

    Steps:
      1. Re-orient both volumes to RPI and save them.
      2. Run ``deepseg_sc.find_centerline`` (``algo='svm'``) on the
         re-oriented image and load the image and centerline files it returns.
      3. Resample the ground truth to 0.5 x 0.5 mm in-plane (third voxel
         dimension unchanged), linear interpolation.
      4. Crop image and ground truth around the centerline
         (``crop_size=64``) and intensity-normalize the image.
      5. If ``dim_3`` is true, save both crops, resample them once more to
         the same target resolution and reload them.

    :param fname_t2: path to the input image (relative paths are resolved
        against the working directory at call time).
    :param fname_t2_seg: path to the ground-truth segmentation of the image.
    :param contrast_test: contrast identifier forwarded to ``find_centerline``.
    :param dim_3: enable the extra save/resample/reload step used for 3D
        kernels.
    :return: tuple ``(img, gt)`` of preprocessed ``Image`` objects.
    """
    import os  # local import: keeps the block self-contained

    # Resolve the inputs BEFORE changing directory; otherwise relative paths
    # would be looked up inside the temporary folder and not be found.
    fname_t2 = os.path.abspath(fname_t2)
    fname_t2_seg = os.path.abspath(fname_t2_seg)

    tmp_folder = sct.TempFolder()
    tmp_folder_path = tmp_folder.get_path()
    tmp_folder.chdir()

    try:
        img = Image(fname_t2)
        gt = Image(fname_t2_seg)

        fname_t2_RPI, fname_t2_seg_RPI = 'img_RPI.nii.gz', 'seg_RPI.nii.gz'
        img.change_orientation('RPI').save(fname_t2_RPI)
        gt.change_orientation('RPI').save(fname_t2_seg_RPI)
        # dim[4:7] holds the voxel sizes along the three spatial axes
        input_resolution = gt.dim[4:7]
        del img, gt

        fname_res, fname_ctr, _ = deepseg_sc.find_centerline(algo='svm',
                                                                image_fname=fname_t2_RPI,
                                                                contrast_type=contrast_test,
                                                                brain_bool=False,
                                                                folder_output=tmp_folder_path,
                                                                remove_temp_files=1,
                                                                centerline_fname=None)

        # Bring the ground truth to the 0.5 x 0.5 mm in-plane grid
        fname_t2_seg_RPI_res = 'seg_RPI_res.nii.gz'
        new_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
        resample_file(fname_t2_seg_RPI, fname_t2_seg_RPI_res, new_resolution, 'mm', 'linear', verbose=0)

        img, ctr, gt = Image(fname_res), Image(fname_ctr), Image(fname_t2_seg_RPI_res)
        _, _, _, img = deepseg_sc.crop_image_around_centerline(im_in=img,
                                                            ctr_in=ctr,
                                                            crop_size=64)
        _, _, _, gt = deepseg_sc.crop_image_around_centerline(im_in=gt,
                                                            ctr_in=ctr,
                                                            crop_size=64)
        del ctr

        img = deepseg_sc.apply_intensity_normalization(im_in=img)

        if dim_3:  # If 3D kernels
            fname_t2_RPI_res_crop, fname_t2_seg_RPI_res_crop = 'img_RPI_res_crop.nii.gz', 'seg_RPI_res_crop.nii.gz'
            img.save(fname_t2_RPI_res_crop)
            gt.save(fname_t2_seg_RPI_res_crop)
            del img, gt

            fname_t2_RPI_res_crop_res = 'img_RPI_res_crop_res.nii.gz'
            fname_t2_seg_RPI_res_crop_res = 'seg_RPI_res_crop_res.nii.gz'
            resample_file(fname_t2_RPI_res_crop, fname_t2_RPI_res_crop_res, new_resolution, 'mm', 'linear', verbose=0)
            resample_file(fname_t2_seg_RPI_res_crop, fname_t2_seg_RPI_res_crop_res, new_resolution, 'mm', 'linear', verbose=0)
            img, gt = Image(fname_t2_RPI_res_crop_res), Image(fname_t2_seg_RPI_res_crop_res)
    finally:
        # Always leave the temporary folder and delete it, even on failure,
        # so an error does not strand the process in a soon-to-be-stale cwd.
        tmp_folder.chdir_undo()
        tmp_folder.cleanup()

    return img, gt
def _preprocess_segment(fname_t2, fname_t2_seg, contrast_test, dim_3=False):
    """Prepare an image and its ground truth as model-ready cropped volumes.

    The function works in a temporary folder that is always left and deleted
    before returning, whether or not a processing step raised.

    Processing chain:
      * both inputs are re-oriented to RPI and written to disk;
      * ``deepseg_sc.find_centerline`` (``algo='svm'``) is run on the
        re-oriented image; the two files it returns (image, centerline) are
        loaded;
      * the ground truth is resampled to 0.5 x 0.5 mm in-plane, keeping its
        third voxel dimension, with linear interpolation;
      * image and ground truth are cropped around the centerline with
        ``crop_size=64`` and the image is intensity-normalized;
      * when ``dim_3`` is true, both crops are saved, resampled again to the
        same target resolution, and reloaded.

    :param fname_t2: path to the input image; a relative path is interpreted
        relative to the working directory at call time.
    :param fname_t2_seg: path to the ground-truth segmentation.
    :param contrast_test: contrast identifier forwarded to ``find_centerline``.
    :param dim_3: run the additional save/resample/reload step (3D kernels).
    :return: tuple ``(img, gt)`` of preprocessed ``Image`` objects.
    """
    import os  # local import: keeps the block self-contained

    # The paths must be made absolute before entering the temporary folder,
    # otherwise relative inputs can no longer be located after chdir().
    fname_t2 = os.path.abspath(fname_t2)
    fname_t2_seg = os.path.abspath(fname_t2_seg)

    tmp_folder = sct.TempFolder()
    tmp_folder_path = tmp_folder.get_path()
    tmp_folder.chdir()

    try:
        img = Image(fname_t2)
        gt = Image(fname_t2_seg)

        fname_t2_RPI, fname_t2_seg_RPI = 'img_RPI.nii.gz', 'seg_RPI.nii.gz'
        img.change_orientation('RPI').save(fname_t2_RPI)
        gt.change_orientation('RPI').save(fname_t2_seg_RPI)
        # dim[4:7] holds the voxel sizes along the three spatial axes
        input_resolution = gt.dim[4:7]
        del img, gt

        fname_res, fname_ctr = deepseg_sc.find_centerline(algo='svm',
                                                            image_fname=fname_t2_RPI,
                                                            contrast_type=contrast_test,
                                                            brain_bool=False,
                                                            folder_output=tmp_folder_path,
                                                            remove_temp_files=1,
                                                            centerline_fname=None)

        # Bring the ground truth to the 0.5 x 0.5 mm in-plane grid
        fname_t2_seg_RPI_res = 'seg_RPI_res.nii.gz'
        new_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
        resample_file(fname_t2_seg_RPI, fname_t2_seg_RPI_res, new_resolution, 'mm', 'linear', verbose=0)

        img, ctr, gt = Image(fname_res), Image(fname_ctr), Image(fname_t2_seg_RPI_res)
        _, _, _, img = deepseg_sc.crop_image_around_centerline(im_in=img,
                                                            ctr_in=ctr,
                                                            crop_size=64)
        _, _, _, gt = deepseg_sc.crop_image_around_centerline(im_in=gt,
                                                            ctr_in=ctr,
                                                            crop_size=64)
        del ctr

        img = deepseg_sc.apply_intensity_normalization(im_in=img)

        if dim_3:  # If 3D kernels
            fname_t2_RPI_res_crop, fname_t2_seg_RPI_res_crop = 'img_RPI_res_crop.nii.gz', 'seg_RPI_res_crop.nii.gz'
            img.save(fname_t2_RPI_res_crop)
            gt.save(fname_t2_seg_RPI_res_crop)
            del img, gt

            fname_t2_RPI_res_crop_res = 'img_RPI_res_crop_res.nii.gz'
            fname_t2_seg_RPI_res_crop_res = 'seg_RPI_res_crop_res.nii.gz'
            resample_file(fname_t2_RPI_res_crop, fname_t2_RPI_res_crop_res, new_resolution, 'mm', 'linear', verbose=0)
            resample_file(fname_t2_seg_RPI_res_crop, fname_t2_seg_RPI_res_crop_res, new_resolution, 'mm', 'linear', verbose=0)
            img, gt = Image(fname_t2_RPI_res_crop_res), Image(fname_t2_seg_RPI_res_crop_res)
    finally:
        # Restore the caller's working directory and drop the temporary
        # folder on every exit path, not only on success.
        tmp_folder.chdir_undo()
        tmp_folder.cleanup()

    return img, gt
def deep_segmentation_spinalcord(im_image,
                                 contrast_type,
                                 ctr_algo='cnn',
                                 ctr_file=None,
                                 brain_bool=True,
                                 kernel_size='2d',
                                 remove_temp_files=1,
                                 verbose=1):
    """Segment the spinal cord of an image with the deep-learning pipeline.

    Pipeline, run inside a temporary folder:
      1. re-orient the image to RPI and save it;
      2. detect the centerline with ``find_centerline`` (which also returns
         the file name of the resampled image);
      3. crop the resampled image around the centerline (96-voxel patches
         for 3D kernels on 't2s', 64 otherwise) and normalize intensities;
      4. segment the crops with the 2D or 3D model of ``contrast_type``;
      5. paste the crops back into the resampled image space;
      6. resample the segmentation to the input voxel size, binarize it
         (threshold 0.0001 for 't1'/'dwi', 0.5 otherwise) and post-process it.

    :param im_image: input ``Image``. ``change_orientation('RPI')`` is called
        on it. NOTE(review): whether this modifies the caller's object
        depends on ``Image.change_orientation`` -- confirm.
    :param contrast_type: contrast identifier; selects the model files and
        the binarization threshold.
    :param ctr_algo: centerline method forwarded to ``find_centerline``.
        'file' uses ``ctr_file``; 'viewer' additionally returns the labels.
    :param ctr_file: path to an existing centerline, required when
        ``ctr_algo == 'file'``.
    :param brain_bool: forwarded to ``find_centerline``.
    :param kernel_size: '2d' or '3d'.
    :param remove_temp_files: if truthy, delete the temporary folder on
        success.
    :param verbose: verbosity; with ``2`` the centerline image is returned.
    :return: 5-tuple ``(segmentation, resampled image, segmentation in the
        resampled space, viewer labels or None, centerline or None)``. The
        first, fourth and fifth items are re-oriented to the input
        orientation.
    :raises ValueError: if ``kernel_size`` is neither '2d' nor '3d'.
    """
    # Fail fast: an unsupported kernel size would otherwise only surface as a
    # NameError on `seg_crop_data`, after centerline detection has already run.
    if kernel_size not in ('2d', '3d'):
        raise ValueError(
            "kernel_size must be '2d' or '3d', got {!r}".format(kernel_size))

    # create temporary folder with intermediate results
    tmp_folder = sct.TempFolder(verbose=verbose)
    tmp_folder_path = tmp_folder.get_path()
    if ctr_algo == 'file':  # if the ctr_file is provided
        tmp_folder.copy_from(ctr_file)
        file_ctr = os.path.basename(ctr_file)
    else:
        file_ctr = None
    tmp_folder.chdir()

    try:
        # re-orient image to RPI
        logger.info("Reorient the image to RPI, if necessary...")
        original_orientation = im_image.orientation
        fname_orient = 'image_in_RPI.nii'
        im_image.change_orientation('RPI').save(fname_orient)

        # voxel sizes of the re-oriented input, used to resample back at the end
        input_resolution = im_image.dim[4:7]

        # find the spinal cord centerline - execute OptiC binary
        logger.info("Finding the spinal cord centerline...")
        fname_res, centerline_filename = find_centerline(
            algo=ctr_algo,
            image_fname=fname_orient,
            contrast_type=contrast_type,
            brain_bool=brain_bool,
            folder_output=tmp_folder_path,
            remove_temp_files=remove_temp_files,
            centerline_fname=file_ctr)

        im_nii, ctr_nii = Image(fname_res), Image(centerline_filename)

        # crop image around the spinal cord centerline
        logger.info("Cropping the image around the spinal cord...")
        crop_size = 96 if (kernel_size == '3d' and contrast_type == 't2s') else 64
        X_CROP_LST, Y_CROP_LST, Z_CROP_LST, im_crop_nii = crop_image_around_centerline(
            im_in=im_nii, ctr_in=ctr_nii, crop_size=crop_size)
        del ctr_nii

        # normalize the intensity of the images
        logger.info("Normalizing the intensity...")
        im_norm_in = apply_intensity_normalization(im_in=im_crop_nii)
        del im_crop_nii

        if kernel_size == '2d':
            # segment data using 2D convolutions
            logger.info(
                "Segmenting the spinal cord using deep learning on 2D patches...")
            segmentation_model_fname = \
                os.path.join(sct.__sct_dir__, 'data', 'deepseg_sc_models', '{}_sc.h5'.format(contrast_type))
            seg_crop_data = segment_2d(model_fname=segmentation_model_fname,
                                       contrast_type=contrast_type,
                                       input_size=(crop_size, crop_size),
                                       im_in=im_norm_in)
        else:  # '3d' (guaranteed by the validation at the top)
            # segment data using 3D convolutions
            logger.info(
                "Segmenting the spinal cord using deep learning on 3D patches...")
            segmentation_model_fname = \
                os.path.join(sct.__sct_dir__, 'data', 'deepseg_sc_models', '{}_sc_3D.h5'.format(contrast_type))
            seg_crop_data = segment_3d(model_fname=segmentation_model_fname,
                                       contrast_type=contrast_type,
                                       im_in=im_norm_in)
        del im_norm_in

        # reconstruct the segmentation from the crop data
        logger.info("Reassembling the image...")
        seg_uncrop_nii = uncrop_image(ref_in=im_nii,
                                      data_crop=seg_crop_data,
                                      x_crop_lst=X_CROP_LST,
                                      y_crop_lst=Y_CROP_LST,
                                      z_crop_lst=Z_CROP_LST)
        fname_res_seg = sct.add_suffix(fname_res, '_seg')
        seg_uncrop_nii.save(fname_res_seg)
        del seg_crop_data

        # resample to initial resolution
        logger.info(
            "Resampling the segmentation to the original image resolution...")
        initial_resolution = 'x'.join([
            str(input_resolution[0]),
            str(input_resolution[1]),
            str(input_resolution[2])
        ])
        fname_res_seg_downsamp = sct.add_suffix(fname_res_seg, '_downsamp')

        resampling.resample_file(fname_res_seg,
                                 fname_res_seg_downsamp,
                                 initial_resolution,
                                 'mm',
                                 'linear',
                                 verbose=0)
        im_image_res_seg_downsamp = Image(fname_res_seg_downsamp)

        if ctr_algo == 'viewer':  # resample and reorient the viewer labels
            # NOTE(review): find_centerline names the labels after the image
            # it actually processed, which carries a '_concat' suffix for
            # single-slice inputs -- confirm this name matches in that case.
            fname_res_labels = sct.add_suffix(fname_orient, '_labels-centerline')
            resampling.resample_file(fname_res_labels,
                                     fname_res_labels,
                                     initial_resolution,
                                     'mm',
                                     'linear',
                                     verbose=0)
            im_image_res_labels_downsamp = Image(
                fname_res_labels).change_orientation(original_orientation)
        else:
            im_image_res_labels_downsamp = None

        if verbose == 2:
            # Use the file name returned by find_centerline rather than
            # rebuilding it from `fname_orient`: with algo 'cnn' the
            # centerline is named after the heatmap ('*_heatmap_ctr'), and
            # single-slice inputs get a '_concat' suffix, so the rebuilt
            # name would point to a file that was never written.
            resampling.resample_file(centerline_filename,
                                     centerline_filename,
                                     initial_resolution,
                                     'mm',
                                     'linear',
                                     verbose=0)
            im_image_res_ctr_downsamp = Image(
                centerline_filename).change_orientation(original_orientation)
        else:
            im_image_res_ctr_downsamp = None

        # binarize the resampled image to remove interpolation effects
        logger.info(
            "Binarizing the segmentation to avoid interpolation effects...")
        thr = 0.0001 if contrast_type in ['t1', 'dwi'] else 0.5
        # Boolean masks index the array directly; no need to materialize
        # coordinate tuples with np.where.
        seg_data = im_image_res_seg_downsamp.data
        seg_data[seg_data >= thr] = 1
        seg_data[seg_data < thr] = 0

        # post processing step to z_regularized
        im_image_res_seg_downsamp_postproc = post_processing_volume_wise(
            im_in=im_image_res_seg_downsamp)
    finally:
        # Restore the caller's working directory on every exit path. The
        # temporary folder is deliberately kept when a step fails so that
        # intermediate files remain available for inspection.
        tmp_folder.chdir_undo()

    # remove temporary files
    if remove_temp_files:
        logger.info("Remove temporary files...")
        tmp_folder.cleanup()

    # reorient to initial orientation
    return im_image_res_seg_downsamp_postproc.change_orientation(original_orientation), \
           im_nii, \
           seg_uncrop_nii.change_orientation('RPI'), \
           im_image_res_labels_downsamp, \
           im_image_res_ctr_downsamp
def find_centerline(algo, image_fname, contrast_type, brain_bool,
                    folder_output, remove_temp_files, centerline_fname):
    """
    Detect the spinal cord centerline and resample the image to 0.5 x 0.5 mm
    in-plane. The input image is assumed to be in RPI orientation.

    A single-slice image is first stacked with itself (the detection binary
    needs nz > 1); the extra centerline slice is dropped again at the end.

    :param algo: centerline method: 'svm', 'cnn', 'viewer' or 'file'. Any
        other value logs an error and exits the process with status 1.
    :param image_fname: file name of the input image.
    :param contrast_type: contrast identifier. With 'cnn' it must be one of
        't2', 't2s', 't1', 'dwi'; with 'svm' it is forwarded to
        ``get_centerline``.
    :param brain_bool: forwarded to ``heatmap``; when true, the ``z_max`` it
        returns is passed to ``heatmap2optic`` ('cnn' only).
    :param folder_output: not used by this function.
    :param remove_temp_files: not used by this function.
    :param centerline_fname: existing centerline file, read only when
        ``algo == 'file'``.
    :return: tuple ``(fname_res, centerline_filename)``: file names of the
        resampled image and of the centerline.
    """

    def _half_mm_inplane(src_fname):
        # 0.5 mm in-plane, third voxel dimension taken from the file itself
        voxel_size = Image(src_fname).dim[4:7]
        return 'x'.join(['0.5', '0.5', str(voxel_size[2])])

    def _resample(src_fname, dst_fname, resolution):
        resampling.resample_file(src_fname, dst_fname, resolution,
                                 'mm', 'linear', verbose=0)

    # TODO: remove unnecessary i/o
    single_slice = Image(image_fname).dim[2] == 1
    if single_slice:  # isct_spine_detect requires nz > 1
        stacked_fname = sct.add_suffix(image_fname, '_concat')
        concat_data([image_fname, image_fname], dim=2).save(stacked_fname)
        image_fname = stacked_fname

    # TODO: maybe change 'svm' for 'optic', because this is how we call it in sct_get_centerline
    if algo == 'svm':
        # run optic on a heatmap computed by a trained SVM+HoG algorithm
        # TODO: replace with get_centerline(method=optic)
        svm_centerline, _, _ = get_centerline(Image(image_fname),
                                              algo_fitting='optic',
                                              contrast=contrast_type)
        centerline_filename = sct.add_suffix(image_fname, "_ctr")
        svm_centerline.save(centerline_filename)

    elif algo == 'cnn':
        # Per-contrast CNN settings: patch geometry / training statistics
        # and network hyper-parameters.
        cnn_settings = {
            't2': {'size': (80, 80), 'mean': 51.1417, 'std': 57.4408,
                   'features': 16, 'dilation_layers': 2},
            't2s': {'size': (80, 80), 'mean': 68.8591, 'std': 71.4659,
                    'features': 8, 'dilation_layers': 3},
            't1': {'size': (80, 80), 'mean': 55.7359, 'std': 64.3149,
                   'features': 24, 'dilation_layers': 3},
            'dwi': {'size': (80, 80), 'mean': 55.744, 'std': 45.003,
                    'features': 8, 'dilation_layers': 2},
        }

        # load model
        weights_fname = os.path.join(sct.__sct_dir__, 'data',
                                     'deepseg_sc_models',
                                     '{}_ctr.h5'.format(contrast_type))
        settings = cnn_settings[contrast_type]
        patch_size = settings['size']
        ctr_model = nn_architecture_ctr(height=patch_size[0],
                                        width=patch_size[1],
                                        channels=1,
                                        classes=1,
                                        features=settings['features'],
                                        depth=2,
                                        temperature=1.0,
                                        padding='same',
                                        batchnorm=True,
                                        dropout=0.0,
                                        dilation_layers=settings['dilation_layers'])
        ctr_model.load_weights(weights_fname)

        # the heatmap is computed on the resampled image
        logger.info("Resample the image to 0.5x0.5 mm in-plane resolution...")
        fname_res = sct.add_suffix(image_fname, '_resampled')
        _resample(image_fname, fname_res, _half_mm_inplane(image_fname))

        # compute the heatmap
        fname_heatmap = sct.add_suffix(image_fname, "_heatmap")
        heatmap_stem = ''.join(sct.extract_fname(fname_heatmap)[:2])
        fname_heatmap_nii = heatmap_stem + '.nii'
        z_max = heatmap(filename_in=fname_res,
                        filename_out=fname_heatmap_nii,
                        model=ctr_model,
                        patch_shape=settings['size'],
                        mean_train=settings['mean'],
                        std_train=settings['std'],
                        brain_bool=brain_bool)

        # run optic on the heatmap
        centerline_filename = sct.add_suffix(fname_heatmap, "_ctr")
        optic_lambda = 7 if contrast_type == 't2s' else 1
        optic_z_max = z_max if brain_bool else None
        heatmap2optic(fname_heatmap=fname_heatmap_nii,
                      lambda_value=optic_lambda,
                      fname_out=centerline_filename,
                      z_max=optic_z_max)

    elif algo == 'viewer':
        # centerline fitted through the labels returned by the viewer
        viewer_labels = _call_viewer_centerline(Image(image_fname))
        viewer_centerline, _, _ = get_centerline(viewer_labels)
        centerline_filename = sct.add_suffix(image_fname, "_ctr")
        labels_filename = sct.add_suffix(image_fname, "_labels-centerline")
        viewer_centerline.save(centerline_filename)
        viewer_labels.save(labels_filename)

    elif algo == 'file':
        centerline_filename = sct.add_suffix(image_fname, "_ctr")
        # Re-orient the manual centerline
        manual_centerline = Image(centerline_fname).change_orientation('RPI')
        manual_centerline.save(centerline_filename)

    else:
        logger.error(
            'The parameter "-centerline" is incorrect. Please try again.')
        sys.exit(1)

    if algo != 'cnn':
        # The 'cnn' branch already resampled the image; the other methods
        # work at native resolution, so image and centerline are resampled here.
        logger.info("Resample the image to 0.5x0.5 mm in-plane resolution...")
        fname_res = sct.add_suffix(image_fname, '_resampled')
        target_resolution = _half_mm_inplane(image_fname)
        _resample(image_fname, fname_res, target_resolution)
        _resample(centerline_filename, centerline_filename, target_resolution)

    if single_slice:
        # drop the duplicated slice: keep only the first centerline slice
        first_slice = split_data(Image(centerline_filename), dim=2)[0]
        first_slice.save(centerline_filename)

    return fname_res, centerline_filename
# Example #5
# 0
def main(args=None):

    # initializations
    param = Param()

    # check user arguments
    if not args:
        args = sys.argv[1:]

    # Get parser info
    parser = get_parser()
    arguments = parser.parse(args)
    fname_data = arguments['-i']
    fname_seg = arguments['-s']
    if '-l' in arguments:
        fname_landmarks = arguments['-l']
        label_type = 'body'
    elif '-ldisc' in arguments:
        fname_landmarks = arguments['-ldisc']
        label_type = 'disc'
    else:
        sct.printv('ERROR: Labels should be provided.', 1, 'error')
    if '-ofolder' in arguments:
        path_output = arguments['-ofolder']
    else:
        path_output = ''

    param.path_qc = arguments.get("-qc", None)

    path_template = arguments['-t']
    contrast_template = arguments['-c']
    ref = arguments['-ref']
    param.remove_temp_files = int(arguments.get('-r'))
    verbose = int(arguments.get('-v'))
    sct.init_sct(log_level=verbose, update=True)  # Update log level
    param.verbose = verbose  # TODO: not clean, unify verbose or param.verbose in code, but not both
    param_centerline = ParamCenterline(
        algo_fitting=arguments['-centerline-algo'],
        smooth=arguments['-centerline-smooth'])
    # registration parameters
    if '-param' in arguments:
        # reset parameters but keep step=0 (might be overwritten if user specified step=0)
        paramreg = ParamregMultiStep([step0])
        if ref == 'subject':
            paramreg.steps['0'].dof = 'Tx_Ty_Tz_Rx_Ry_Rz_Sz'
        # add user parameters
        for paramStep in arguments['-param']:
            paramreg.addStep(paramStep)
    else:
        paramreg = ParamregMultiStep([step0, step1, step2])
        # if ref=subject, initialize registration using different affine parameters
        if ref == 'subject':
            paramreg.steps['0'].dof = 'Tx_Ty_Tz_Rx_Ry_Rz_Sz'

    # initialize other parameters
    zsubsample = param.zsubsample

    # retrieve template file names
    file_template_vertebral_labeling = get_file_label(os.path.join(path_template, 'template'), 'vertebral labeling')
    file_template = get_file_label(os.path.join(path_template, 'template'), contrast_template.upper() + '-weighted template')
    file_template_seg = get_file_label(os.path.join(path_template, 'template'), 'spinal cord')

    # start timer
    start_time = time.time()

    # get fname of the template + template objects
    fname_template = os.path.join(path_template, 'template', file_template)
    fname_template_vertebral_labeling = os.path.join(path_template, 'template', file_template_vertebral_labeling)
    fname_template_seg = os.path.join(path_template, 'template', file_template_seg)
    fname_template_disc_labeling = os.path.join(path_template, 'template', 'PAM50_label_disc.nii.gz')

    # check file existence
    # TODO: no need to do that!
    sct.printv('\nCheck template files...')
    sct.check_file_exist(fname_template, verbose)
    sct.check_file_exist(fname_template_vertebral_labeling, verbose)
    sct.check_file_exist(fname_template_seg, verbose)
    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    # sct.printv(arguments)
    sct.printv('\nCheck parameters:', verbose)
    sct.printv('  Data:                 ' + fname_data, verbose)
    sct.printv('  Landmarks:            ' + fname_landmarks, verbose)
    sct.printv('  Segmentation:         ' + fname_seg, verbose)
    sct.printv('  Path template:        ' + path_template, verbose)
    sct.printv('  Remove temp files:    ' + str(param.remove_temp_files), verbose)

    # check input labels
    labels = check_labels(fname_landmarks, label_type=label_type)

    vertebral_alignment = False
    if len(labels) > 2 and label_type == 'disc':
        vertebral_alignment = True

    path_tmp = sct.tmp_create(basename="register_to_template", verbose=verbose)

    # set temporary file names
    ftmp_data = 'data.nii'
    ftmp_seg = 'seg.nii.gz'
    ftmp_label = 'label.nii.gz'
    ftmp_template = 'template.nii'
    ftmp_template_seg = 'template_seg.nii.gz'
    ftmp_template_label = 'template_label.nii.gz'

    # copy files to temporary folder
    sct.printv('\nCopying input data to tmp folder and convert to nii...', verbose)
    Image(fname_data).save(os.path.join(path_tmp, ftmp_data))
    Image(fname_seg).save(os.path.join(path_tmp, ftmp_seg))
    Image(fname_landmarks).save(os.path.join(path_tmp, ftmp_label))
    Image(fname_template).save(os.path.join(path_tmp, ftmp_template))
    Image(fname_template_seg).save(os.path.join(path_tmp, ftmp_template_seg))
    Image(fname_template_vertebral_labeling).save(os.path.join(path_tmp, ftmp_template_label))
    if label_type == 'disc':
        Image(fname_template_disc_labeling).save(os.path.join(path_tmp, ftmp_template_label))

    # go to tmp folder
    curdir = os.getcwd()
    os.chdir(path_tmp)

    # Generate labels from template vertebral labeling
    if label_type == 'body':
        sct.printv('\nGenerate labels from template vertebral labeling', verbose)
        ftmp_template_label_, ftmp_template_label = ftmp_template_label, sct.add_suffix(ftmp_template_label, "_body")
        sct_label_utils.main(args=['-i', ftmp_template_label_, '-vert-body', '0', '-o', ftmp_template_label])

    # check if provided labels are available in the template
    sct.printv('\nCheck if provided labels are available in the template', verbose)
    image_label_template = Image(ftmp_template_label)
    labels_template = image_label_template.getNonZeroCoordinates(sorting='value')
    if labels[-1].value > labels_template[-1].value:
        sct.printv('ERROR: Wrong landmarks input. Labels must have correspondence in template space. \nLabel max '
                   'provided: ' + str(labels[-1].value) + '\nLabel max from template: ' +
                   str(labels_template[-1].value), verbose, 'error')

    # if only one label is present, force affine transformation to be Tx,Ty,Tz only (no scaling)
    if len(labels) == 1:
        paramreg.steps['0'].dof = 'Tx_Ty_Tz'
        sct.printv('WARNING: Only one label is present. Forcing initial transformation to: ' + paramreg.steps['0'].dof,
                   1, 'warning')

    # Project labels onto the spinal cord centerline because later, an affine transformation is estimated between the
    # template's labels (centered in the cord) and the subject's labels (assumed to be centered in the cord).
    # If labels are not centered, mis-registration errors are observed (see issue #1826)
    ftmp_label = project_labels_on_spinalcord(ftmp_label, ftmp_seg, param_centerline)

    # binarize segmentation (in case it has values below 0 caused by manual editing)
    sct.printv('\nBinarize segmentation', verbose)
    ftmp_seg_, ftmp_seg = ftmp_seg, sct.add_suffix(ftmp_seg, "_bin")
    sct_maths.main(['-i', ftmp_seg_,
                    '-bin', '0.5',
                    '-o', ftmp_seg])

    # Switch between modes: subject->template or template->subject
    if ref == 'template':

        # resample data to 1mm isotropic
        sct.printv('\nResample data to 1mm isotropic...', verbose)
        resample_file(ftmp_data, add_suffix(ftmp_data, '_1mm'), '1.0x1.0x1.0', 'mm', 'linear', verbose)
        ftmp_data = add_suffix(ftmp_data, '_1mm')
        resample_file(ftmp_seg, add_suffix(ftmp_seg, '_1mm'), '1.0x1.0x1.0', 'mm', 'linear', verbose)
        ftmp_seg = add_suffix(ftmp_seg, '_1mm')
        # N.B. resampling of labels is more complicated, because they are single-point labels, therefore resampling
        # with nearest neighbour can make them disappear.
        resample_labels(ftmp_label, ftmp_data, add_suffix(ftmp_label, '_1mm'))
        ftmp_label = add_suffix(ftmp_label, '_1mm')

        # Change orientation of input images to RPI
        sct.printv('\nChange orientation of input images to RPI...', verbose)

        ftmp_data = Image(ftmp_data).change_orientation("RPI", generate_path=True).save().absolutepath
        ftmp_seg = Image(ftmp_seg).change_orientation("RPI", generate_path=True).save().absolutepath
        ftmp_label = Image(ftmp_label).change_orientation("RPI", generate_path=True).save().absolutepath


        ftmp_seg_, ftmp_seg = ftmp_seg, add_suffix(ftmp_seg, '_crop')
        if vertebral_alignment:
            # cropping the segmentation based on the label coverage to ensure good registration with vertebral alignment
            # See https://github.com/neuropoly/spinalcordtoolbox/pull/1669 for details
            image_labels = Image(ftmp_label)
            coordinates_labels = image_labels.getNonZeroCoordinates(sorting='z')
            nx, ny, nz, nt, px, py, pz, pt = image_labels.dim
            offset_crop = 10.0 * pz  # cropping the image 10 mm above and below the highest and lowest label
            cropping_slices = [coordinates_labels[0].z - offset_crop, coordinates_labels[-1].z + offset_crop]
            # make sure that the cropping slices do not extend outside of the slice range (issue #1811)
            if cropping_slices[0] < 0:
                cropping_slices[0] = 0
            if cropping_slices[1] > nz:
                cropping_slices[1] = nz
            msct_image.spatial_crop(Image(ftmp_seg_), dict(((2, np.int32(np.round(cropping_slices))),))).save(ftmp_seg)
        else:
            # if we do not align the vertebral levels, we crop the segmentation from top to bottom
            im_seg_rpi = Image(ftmp_seg_)
            bottom = 0
            for data in msct_image.SlicerOneAxis(im_seg_rpi, "IS"):
                if (data != 0).any():
                    break
                bottom += 1
            top = im_seg_rpi.data.shape[2]
            for data in msct_image.SlicerOneAxis(im_seg_rpi, "SI"):
                if (data != 0).any():
                    break
                top -= 1
            msct_image.spatial_crop(im_seg_rpi, dict(((2, (bottom, top)),))).save(ftmp_seg)


        # straighten segmentation
        sct.printv('\nStraighten the spinal cord using centerline/segmentation...', verbose)

        # check if warp_curve2straight and warp_straight2curve already exist (i.e. no need to do it another time)
        fn_warp_curve2straight = os.path.join(curdir, "warp_curve2straight.nii.gz")
        fn_warp_straight2curve = os.path.join(curdir, "warp_straight2curve.nii.gz")
        fn_straight_ref = os.path.join(curdir, "straight_ref.nii.gz")

        cache_input_files=[ftmp_seg]
        if vertebral_alignment:
            cache_input_files += [
             ftmp_template_seg,
             ftmp_label,
             ftmp_template_label,
            ]
        cache_sig = sct.cache_signature(
         input_files=cache_input_files,
        )
        cachefile = os.path.join(curdir, "straightening.cache")
        if sct.cache_valid(cachefile, cache_sig) and os.path.isfile(fn_warp_curve2straight) and os.path.isfile(fn_warp_straight2curve) and os.path.isfile(fn_straight_ref):
            sct.printv('Reusing existing warping field which seems to be valid', verbose, 'warning')
            sct.copy(fn_warp_curve2straight, 'warp_curve2straight.nii.gz')
            sct.copy(fn_warp_straight2curve, 'warp_straight2curve.nii.gz')
            sct.copy(fn_straight_ref, 'straight_ref.nii.gz')
            # apply straightening
            sct_apply_transfo.main(args=[
                '-i', ftmp_seg,
                '-w', 'warp_curve2straight.nii.gz',
                '-d', 'straight_ref.nii.gz',
                '-o', add_suffix(ftmp_seg, '_straight')])
        else:
            from spinalcordtoolbox.straightening import SpinalCordStraightener
            sc_straight = SpinalCordStraightener(ftmp_seg, ftmp_seg)
            sc_straight.param_centerline = param_centerline
            sc_straight.output_filename = add_suffix(ftmp_seg, '_straight')
            sc_straight.path_output = './'
            sc_straight.qc = '0'
            sc_straight.remove_temp_files = param.remove_temp_files
            sc_straight.verbose = verbose

            if vertebral_alignment:
                sc_straight.centerline_reference_filename = ftmp_template_seg
                sc_straight.use_straight_reference = True
                sc_straight.discs_input_filename = ftmp_label
                sc_straight.discs_ref_filename = ftmp_template_label

            sc_straight.straighten()
            sct.cache_save(cachefile, cache_sig)

        # N.B. DO NOT UPDATE VARIABLE ftmp_seg BECAUSE TEMPORARY USED LATER
        # re-define warping field using non-cropped space (to avoid issue #367)
        sct_concat_transfo.main(args=[
            '-w', 'warp_straight2curve.nii.gz',
            '-d', ftmp_data,
            '-o', 'warp_straight2curve.nii.gz'])

        if vertebral_alignment:
            sct.copy('warp_curve2straight.nii.gz', 'warp_curve2straightAffine.nii.gz')
        else:
            # Label preparation:
            # --------------------------------------------------------------------------------
            # Remove unused label on template. Keep only label present in the input label image
            sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
            sct.run(['sct_label_utils', '-i', ftmp_template_label, '-o', ftmp_template_label, '-remove-reference', ftmp_label])

            # Dilating the input label so they can be straighten without losing them
            sct.printv('\nDilating input labels using 3vox ball radius')
            sct_maths.main(['-i', ftmp_label,
                            '-dilate', '3',
                            '-o', add_suffix(ftmp_label, '_dilate')])
            ftmp_label = add_suffix(ftmp_label, '_dilate')

            # Apply straightening to labels
            sct.printv('\nApply straightening to labels...', verbose)
            sct_apply_transfo.main(args=[
                '-i', ftmp_label,
                '-o', add_suffix(ftmp_label, '_straight'),
                '-d', add_suffix(ftmp_seg, '_straight'),
                '-w', 'warp_curve2straight.nii.gz',
                '-x', 'nn'])
            ftmp_label = add_suffix(ftmp_label, '_straight')

            # Compute rigid transformation straight landmarks --> template landmarks
            sct.printv('\nEstimate transformation for step #0...', verbose)
            try:
                register_landmarks(ftmp_label, ftmp_template_label, paramreg.steps['0'].dof,
                                   fname_affine='straight2templateAffine.txt', verbose=verbose)
            except RuntimeError:
                raise('Input labels do not seem to be at the right place. Please check the position of the labels. '
                      'See documentation for more details: https://www.slideshare.net/neuropoly/sct-course-20190121/42')

            # Concatenate transformations: curve --> straight --> affine
            sct.printv('\nConcatenate transformations: curve --> straight --> affine...', verbose)
            sct_concat_transfo.main(args=[
                '-w', ['warp_curve2straight.nii.gz', 'straight2templateAffine.txt'],
                '-d', 'template.nii',
                '-o', 'warp_curve2straightAffine.nii.gz'])

        # Apply transformation
        sct.printv('\nApply transformation...', verbose)
        sct_apply_transfo.main(args=[
            '-i', ftmp_data,
            '-o', add_suffix(ftmp_data, '_straightAffine'),
            '-d', ftmp_template,
            '-w', 'warp_curve2straightAffine.nii.gz'])
        ftmp_data = add_suffix(ftmp_data, '_straightAffine')
        sct_apply_transfo.main(args=[
            '-i', ftmp_seg,
            '-o', add_suffix(ftmp_seg, '_straightAffine'),
            '-d', ftmp_template,
            '-w', 'warp_curve2straightAffine.nii.gz',
            '-x', 'linear'])
        ftmp_seg = add_suffix(ftmp_seg, '_straightAffine')

        """
        # Benjamin: Issue from Allan Martin, about the z=0 slice that is screwed up, caused by the affine transform.
        # Solution found: remove slices below and above landmarks to avoid rotation effects
        points_straight = []
        for coord in landmark_template:
            points_straight.append(coord.z)
        min_point, max_point = int(np.round(np.min(points_straight))), int(np.round(np.max(points_straight)))
        ftmp_seg_, ftmp_seg = ftmp_seg, add_suffix(ftmp_seg, '_black')
        msct_image.spatial_crop(Image(ftmp_seg_), dict(((2, (min_point,max_point)),))).save(ftmp_seg)

        """
        # open segmentation
        im = Image(ftmp_seg)
        im_new = msct_image.empty_like(im)
        # binarize
        im_new.data = im.data > 0.5
        # find min-max of anat2template (for subsequent cropping)
        zmin_template, zmax_template = msct_image.find_zmin_zmax(im_new, threshold=0.5)
        # save binarized segmentation
        im_new.save(add_suffix(ftmp_seg, '_bin')) # unused?
        # crop template in z-direction (for faster processing)
        # TODO: refactor to use python module instead of doing i/o
        sct.printv('\nCrop data in template space (for faster processing)...', verbose)
        ftmp_template_, ftmp_template = ftmp_template, add_suffix(ftmp_template, '_crop')
        msct_image.spatial_crop(Image(ftmp_template_), dict(((2, (zmin_template,zmax_template)),))).save(ftmp_template)

        ftmp_template_seg_, ftmp_template_seg = ftmp_template_seg, add_suffix(ftmp_template_seg, '_crop')
        msct_image.spatial_crop(Image(ftmp_template_seg_), dict(((2, (zmin_template,zmax_template)),))).save(ftmp_template_seg)

        ftmp_data_, ftmp_data = ftmp_data, add_suffix(ftmp_data, '_crop')
        msct_image.spatial_crop(Image(ftmp_data_), dict(((2, (zmin_template,zmax_template)),))).save(ftmp_data)

        ftmp_seg_, ftmp_seg = ftmp_seg, add_suffix(ftmp_seg, '_crop')
        msct_image.spatial_crop(Image(ftmp_seg_), dict(((2, (zmin_template,zmax_template)),))).save(ftmp_seg)

        # sub-sample in z-direction
        # TODO: refactor to use python module instead of doing i/o
        sct.printv('\nSub-sample in z-direction (for faster processing)...', verbose)
        sct.run(['sct_resample', '-i', ftmp_template, '-o', add_suffix(ftmp_template, '_sub'), '-f', '1x1x' + zsubsample], verbose)
        ftmp_template = add_suffix(ftmp_template, '_sub')
        sct.run(['sct_resample', '-i', ftmp_template_seg, '-o', add_suffix(ftmp_template_seg, '_sub'), '-f', '1x1x' + zsubsample], verbose)
        ftmp_template_seg = add_suffix(ftmp_template_seg, '_sub')
        sct.run(['sct_resample', '-i', ftmp_data, '-o', add_suffix(ftmp_data, '_sub'), '-f', '1x1x' + zsubsample], verbose)
        ftmp_data = add_suffix(ftmp_data, '_sub')
        sct.run(['sct_resample', '-i', ftmp_seg, '-o', add_suffix(ftmp_seg, '_sub'), '-f', '1x1x' + zsubsample], verbose)
        ftmp_seg = add_suffix(ftmp_seg, '_sub')

        # Registration straight spinal cord to template
        sct.printv('\nRegister straight spinal cord to template...', verbose)

        # loop across registration steps
        warp_forward = []
        warp_inverse = []
        for i_step in range(1, len(paramreg.steps)):
            sct.printv('\nEstimate transformation for step #' + str(i_step) + '...', verbose)
            # identify which is the src and dest
            if paramreg.steps[str(i_step)].type == 'im':
                src = ftmp_data
                dest = ftmp_template
                interp_step = 'linear'
            elif paramreg.steps[str(i_step)].type == 'seg':
                src = ftmp_seg
                dest = ftmp_template_seg
                interp_step = 'nn'
            else:
                sct.printv('ERROR: Wrong image type.', 1, 'error')

            if paramreg.steps[str(i_step)].algo == 'centermassrot' and paramreg.steps[str(i_step)].rot_method == 'hog':
                src_seg = ftmp_seg
                dest_seg = ftmp_template_seg
            # if step>1, apply warp_forward_concat to the src image to be used
            if i_step > 1:
                # apply transformation from previous step, to use as new src for registration
                sct_apply_transfo.main(args=[
                    '-i', src,
                    '-d', dest,
                    '-w', warp_forward,
                    '-o', add_suffix(src, '_regStep' + str(i_step - 1)),
                    '-x', interp_step])
                src = add_suffix(src, '_regStep' + str(i_step - 1))
                if paramreg.steps[str(i_step)].algo == 'centermassrot' and paramreg.steps[str(i_step)].rot_method == 'hog':  # also apply transformation to the seg
                    sct_apply_transfo.main(args=[
                        '-i', src_seg,
                        '-d', dest_seg,
                        '-w', warp_forward,
                        '-o', add_suffix(src, '_regStep' + str(i_step - 1)),
                        '-x', interp_step])
                    src_seg = add_suffix(src_seg, '_regStep' + str(i_step - 1))
            # register src --> dest
            # TODO: display param for debugging
            if paramreg.steps[str(i_step)].algo == 'centermassrot' and paramreg.steps[str(i_step)].rot_method == 'hog': # im_seg case
                warp_forward_out, warp_inverse_out = register([src, src_seg], [dest, dest_seg], paramreg, param, str(i_step))
            else:
                warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
            warp_forward.append(warp_forward_out)
            warp_inverse.append(warp_inverse_out)

        # Concatenate transformations: anat --> template
        sct.printv('\nConcatenate transformations: anat --> template...', verbose)
        warp_forward.insert(0, 'warp_curve2straightAffine.nii.gz')
        sct_concat_transfo.main(args=[
            '-w', warp_forward,
            '-d', 'template.nii',
            '-o', 'warp_anat2template.nii.gz'])

        # Concatenate transformations: template --> anat
        sct.printv('\nConcatenate transformations: template --> anat...', verbose)
        warp_inverse.reverse()
        if vertebral_alignment:
            warp_inverse.append('warp_straight2curve.nii.gz')
            sct_concat_transfo.main(args=[
                '-w', warp_inverse,
                '-d', 'data.nii',
                '-o', 'warp_template2anat.nii.gz'])
        else:
            warp_inverse.append('straight2templateAffine.txt')
            warp_inverse.append('warp_straight2curve.nii.gz')
            sct_concat_transfo.main(args=[
                '-w', warp_inverse,
                '-winv', ['straight2templateAffine.txt'],
                '-d', 'data.nii',
                '-o', 'warp_template2anat.nii.gz'])

    # register template->subject
    elif ref == 'subject':

        # Change orientation of input images to RPI
        sct.printv('\nChange orientation of input images to RPI...', verbose)
        ftmp_data = Image(ftmp_data).change_orientation("RPI", generate_path=True).save().absolutepath
        ftmp_seg = Image(ftmp_seg).change_orientation("RPI", generate_path=True).save().absolutepath
        ftmp_label = Image(ftmp_label).change_orientation("RPI", generate_path=True).save().absolutepath

        # Remove unused label on template. Keep only label present in the input label image
        sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
        sct.run(['sct_label_utils', '-i', ftmp_template_label, '-o', ftmp_template_label, '-remove-reference', ftmp_label])

        # Add one label because at least 3 orthogonal labels are required to estimate an affine transformation. This
        # new label is added at the level of the upper most label (lowest value), at 1cm to the right.
        for i_file in [ftmp_label, ftmp_template_label]:
            im_label = Image(i_file)
            coord_label = im_label.getCoordinatesAveragedByValue()  # N.B. landmarks are sorted by value
            # Create new label
            from copy import deepcopy
            new_label = deepcopy(coord_label[0])
            # move it 5mm to the left (orientation is RAS)
            nx, ny, nz, nt, px, py, pz, pt = im_label.dim
            new_label.x = np.round(coord_label[0].x + 5.0 / px)
            # assign value 99
            new_label.value = 99
            # Add to existing image
            im_label.data[int(new_label.x), int(new_label.y), int(new_label.z)] = new_label.value
            # Overwrite label file
            # im_label.absolutepath = 'label_rpi_modif.nii.gz'
            im_label.save()

        # Bring template to subject space using landmark-based transformation
        sct.printv('\nEstimate transformation for step #0...', verbose)
        warp_forward = ['template2subjectAffine.txt']
        warp_inverse = ['template2subjectAffine.txt']
        try:
            register_landmarks(ftmp_template_label, ftmp_label, paramreg.steps['0'].dof, fname_affine=warp_forward[0], verbose=verbose, path_qc="./")
        except Exception:
            sct.printv('ERROR: input labels do not seem to be at the right place. Please check the position of the labels. See documentation for more details: https://www.slideshare.net/neuropoly/sct-course-20190121/42', verbose=verbose, type='error')

        # loop across registration steps
        for i_step in range(1, len(paramreg.steps)):
            sct.printv('\nEstimate transformation for step #' + str(i_step) + '...', verbose)
            # identify which is the src and dest
            if paramreg.steps[str(i_step)].type == 'im':
                src = ftmp_template
                dest = ftmp_data
                interp_step = 'linear'
            elif paramreg.steps[str(i_step)].type == 'seg':
                src = ftmp_template_seg
                dest = ftmp_seg
                interp_step = 'nn'
            else:
                sct.printv('ERROR: Wrong image type.', 1, 'error')
            # apply transformation from previous step, to use as new src for registration
            sct_apply_transfo.main(args=[
                '-i', src,
                '-d', dest,
                '-w', warp_forward,
                '-o', add_suffix(src, '_regStep' + str(i_step - 1)),
                '-x', interp_step])
            src = add_suffix(src, '_regStep' + str(i_step - 1))
            # register src --> dest
            # TODO: display param for debugging
            warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
            warp_forward.append(warp_forward_out)
            warp_inverse.insert(0, warp_inverse_out)

        # Concatenate transformations:
        sct.printv('\nConcatenate transformations: template --> subject...', verbose)
        sct_concat_transfo.main(args=[
            '-w', warp_forward,
            '-d', 'data.nii',
            '-o', 'warp_template2anat.nii.gz'])
        sct.printv('\nConcatenate transformations: subject --> template...', verbose)
        sct_concat_transfo.main(args=[
            '-w', warp_inverse,
            '-winv', ['template2subjectAffine.txt'],
            '-d', 'template.nii',
            '-o', 'warp_anat2template.nii.gz'])

    # Apply warping fields to anat and template
    sct.run(['sct_apply_transfo', '-i', 'template.nii', '-o', 'template2anat.nii.gz', '-d', 'data.nii', '-w', 'warp_template2anat.nii.gz', '-crop', '1'], verbose)
    sct.run(['sct_apply_transfo', '-i', 'data.nii', '-o', 'anat2template.nii.gz', '-d', 'template.nii', '-w', 'warp_anat2template.nii.gz', '-crop', '1'], verbose)

    # come back
    os.chdir(curdir)

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    fname_template2anat = os.path.join(path_output, 'template2anat' + ext_data)
    fname_anat2template = os.path.join(path_output, 'anat2template' + ext_data)
    sct.generate_output_file(os.path.join(path_tmp, "warp_template2anat.nii.gz"), os.path.join(path_output, "warp_template2anat.nii.gz"), verbose)
    sct.generate_output_file(os.path.join(path_tmp, "warp_anat2template.nii.gz"), os.path.join(path_output, "warp_anat2template.nii.gz"), verbose)
    sct.generate_output_file(os.path.join(path_tmp, "template2anat.nii.gz"), fname_template2anat, verbose)
    sct.generate_output_file(os.path.join(path_tmp, "anat2template.nii.gz"), fname_anat2template, verbose)
    if ref == 'template':
        # copy straightening files in case subsequent SCT functions need them
        sct.generate_output_file(os.path.join(path_tmp, "warp_curve2straight.nii.gz"), os.path.join(path_output, "warp_curve2straight.nii.gz"), verbose)
        sct.generate_output_file(os.path.join(path_tmp, "warp_straight2curve.nii.gz"), os.path.join(path_output, "warp_straight2curve.nii.gz"), verbose)
        sct.generate_output_file(os.path.join(path_tmp, "straight_ref.nii.gz"), os.path.join(path_output, "straight_ref.nii.gz"), verbose)

    # Delete temporary files
    if param.remove_temp_files:
        sct.printv('\nDelete temporary files...', verbose)
        sct.rmtree(path_tmp, verbose=verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's', verbose)

    qc_dataset = arguments.get("-qc-dataset", None)
    qc_subject = arguments.get("-qc-subject", None)
    if param.path_qc is not None:
        generate_qc(fname_data, fname_in2=fname_template2anat, fname_seg=fname_seg, args=args,
                    path_qc=os.path.abspath(param.path_qc), dataset=qc_dataset, subject=qc_subject,
                    process='sct_register_to_template')
    sct.display_viewer_syntax([fname_data, fname_template2anat], verbose=verbose)
    sct.display_viewer_syntax([fname_template, fname_anat2template], verbose=verbose)
Exemple #6
0
def deep_segmentation_spinalcord(im_image, contrast_type, ctr_algo='cnn', ctr_file=None, brain_bool=True,
                                 kernel_size='2d', remove_temp_files=1, verbose=1):
    """
    Segment the spinal cord: centerline detection, cropping, intensity normalization, 2D or 3D patch
    segmentation, reassembly, resampling back to the input resolution, binarization and post-processing.

    All intermediate files are written in a temporary folder, which is the working directory while the pipeline
    runs. The previous working directory is restored even when a step fails.

    :param im_image: Image() object to segment. change_orientation('RPI') is called on it.
        NOTE(review): whether that call mutates the caller's object depends on Image.change_orientation - confirm.
    :param contrast_type: contrast key, used to select the model file ('<contrast>_sc.h5' or '<contrast>_sc_3D.h5'
        under data/deepseg_sc_models) and the binarization threshold (0.0001 for 't1' and 'dwi', 0.5 otherwise).
    :param ctr_algo: centerline method, forwarded to find_centerline(). With 'file', ctr_file is copied into the
        temporary folder; with 'viewer', the viewer labels are also returned.
    :param ctr_file: centerline file name, only used when ctr_algo == 'file'.
    :param brain_bool: forwarded to find_centerline().
    :param kernel_size: '2d' or '3d'. The crop size is 96 for '3d' with contrast 't2s', 64 otherwise.
    :param remove_temp_files: if truthy, the temporary folder is deleted once the pipeline succeeded.
    :param verbose: forwarded to sct.TempFolder(); when equal to 2, the centerline image is also returned.
    :return: tuple of
        - post-processed binary segmentation at the input resolution, in the original orientation,
        - image loaded from the resampled file returned by find_centerline(),
        - segmentation reassembled on that resampled grid, after change_orientation('RPI'),
        - viewer labels in the original orientation (ctr_algo == 'viewer'), else None,
        - centerline in the original orientation (verbose == 2), else None.
    :raises ValueError: if kernel_size is neither '2d' nor '3d'.
    """
    # Fail fast: an unsupported kernel size used to surface only as an UnboundLocalError on seg_crop_data, after
    # the centerline detection and cropping had already been run.
    if kernel_size not in ('2d', '3d'):
        raise ValueError("kernel_size must be '2d' or '3d', got {!r}".format(kernel_size))

    # create temporary folder with intermediate results
    tmp_folder = sct.TempFolder(verbose=verbose)
    tmp_folder_path = tmp_folder.get_path()
    if ctr_algo == 'file':  # if the ctr_file is provided
        tmp_folder.copy_from(ctr_file)
        file_ctr = os.path.basename(ctr_file)
    else:
        file_ctr = None
    tmp_folder.chdir()

    try:
        # re-orient image to RPI
        logger.info("Reorient the image to RPI, if necessary...")
        original_orientation = im_image.orientation
        fname_orient = 'image_in_RPI.nii'
        im_image.change_orientation('RPI').save(fname_orient)

        # pixel sizes (px, py, pz), read after the re-orientation call above
        input_resolution = im_image.dim[4:7]

        # find the spinal cord centerline - execute OptiC binary
        logger.info("Finding the spinal cord centerline...")
        fname_res, centerline_filename = find_centerline(algo=ctr_algo,
                                                         image_fname=fname_orient,
                                                         contrast_type=contrast_type,
                                                         brain_bool=brain_bool,
                                                         folder_output=tmp_folder_path,
                                                         remove_temp_files=remove_temp_files,
                                                         centerline_fname=file_ctr)

        im_nii, ctr_nii = Image(fname_res), Image(centerline_filename)

        # crop image around the spinal cord centerline
        logger.info("Cropping the image around the spinal cord...")
        crop_size = 96 if (kernel_size == '3d' and contrast_type == 't2s') else 64
        X_CROP_LST, Y_CROP_LST, Z_CROP_LST, im_crop_nii = crop_image_around_centerline(im_in=im_nii,
                                                                                       ctr_in=ctr_nii,
                                                                                       crop_size=crop_size)
        del ctr_nii

        # normalize the intensity of the images
        logger.info("Normalizing the intensity...")
        im_norm_in = apply_intensity_normalization(im_in=im_crop_nii)
        del im_crop_nii

        if kernel_size == '2d':
            # segment data using 2D convolutions
            logger.info("Segmenting the spinal cord using deep learning on 2D patches...")
            segmentation_model_fname = \
                os.path.join(sct.__sct_dir__, 'data', 'deepseg_sc_models', '{}_sc.h5'.format(contrast_type))
            seg_crop_data = segment_2d(model_fname=segmentation_model_fname,
                                       contrast_type=contrast_type,
                                       input_size=(crop_size, crop_size),
                                       im_in=im_norm_in)
        else:  # '3d', guaranteed by the validation at the top
            # segment data using 3D convolutions
            logger.info("Segmenting the spinal cord using deep learning on 3D patches...")
            segmentation_model_fname = \
                os.path.join(sct.__sct_dir__, 'data', 'deepseg_sc_models', '{}_sc_3D.h5'.format(contrast_type))
            seg_crop_data = segment_3d(model_fname=segmentation_model_fname,
                                       contrast_type=contrast_type,
                                       im_in=im_norm_in)
        del im_norm_in

        # reconstruct the segmentation from the crop data
        logger.info("Reassembling the image...")
        seg_uncrop_nii = uncrop_image(ref_in=im_nii,
                                      data_crop=seg_crop_data,
                                      x_crop_lst=X_CROP_LST,
                                      y_crop_lst=Y_CROP_LST,
                                      z_crop_lst=Z_CROP_LST)
        fname_res_seg = sct.add_suffix(fname_res, '_seg')
        seg_uncrop_nii.save(fname_res_seg)
        del seg_crop_data

        # resample to initial resolution
        logger.info("Resampling the segmentation to the original image resolution...")
        initial_resolution = 'x'.join([str(input_resolution[0]), str(input_resolution[1]), str(input_resolution[2])])
        fname_res_seg_downsamp = sct.add_suffix(fname_res_seg, '_downsamp')

        resampling.resample_file(fname_res_seg, fname_res_seg_downsamp, initial_resolution,
                                 'mm', 'linear', verbose=0)
        im_image_res_seg_downsamp = Image(fname_res_seg_downsamp)

        if ctr_algo == 'viewer':  # resample and reorient the viewer labels
            fname_res_labels = sct.add_suffix(fname_orient, '_labels-centerline')
            resampling.resample_file(fname_res_labels, fname_res_labels, initial_resolution, 'mm', 'linear',
                                     verbose=0)
            im_image_res_labels_downsamp = Image(fname_res_labels).change_orientation(original_orientation)
        else:
            im_image_res_labels_downsamp = None

        if verbose == 2:
            fname_res_ctr = sct.add_suffix(fname_orient, '_ctr')
            resampling.resample_file(fname_res_ctr, fname_res_ctr, initial_resolution, 'mm', 'linear', verbose=0)
            im_image_res_ctr_downsamp = Image(fname_res_ctr).change_orientation(original_orientation)
        else:
            im_image_res_ctr_downsamp = None

        # binarize the resampled image to remove interpolation effects
        logger.info("Binarizing the segmentation to avoid interpolation effects...")
        thr = 0.0001 if contrast_type in ['t1', 'dwi'] else 0.5
        # Boolean masks give the same result as the former np.where() index tuples without materialising them.
        seg_data = im_image_res_seg_downsamp.data
        seg_data[seg_data >= thr] = 1
        seg_data[seg_data < thr] = 0

        # post processing step to z_regularized
        im_image_res_seg_downsamp_postproc = post_processing_volume_wise(im_in=im_image_res_seg_downsamp)
    finally:
        # always give the caller its working directory back, including when a step above raised
        tmp_folder.chdir_undo()

    # remove temporary files (only reached on success, so a failed run keeps its intermediate files)
    if remove_temp_files:
        logger.info("Remove temporary files...")
        tmp_folder.cleanup()

    # reorient to initial orientation
    return im_image_res_seg_downsamp_postproc.change_orientation(original_orientation), \
           im_nii, \
           seg_uncrop_nii.change_orientation('RPI'), \
           im_image_res_labels_downsamp, \
           im_image_res_ctr_downsamp
Exemple #7
0
def find_centerline(algo, image_fname, contrast_type, brain_bool, folder_output, remove_temp_files, centerline_fname):
    """
    Detect (or load) the spinal cord centerline of an image and resample the image to 0.5x0.5 mm in-plane.

    Assumes RPI orientation.

    :param algo: centerline method: 'svm', 'cnn', 'viewer' or 'file'. Any other value logs an error and exits.
    :param image_fname: file name of the input image. A single-slice image is first stacked on itself.
    :param contrast_type: contrast key; forwarded to get_centerline() ('svm') or used to select the CNN model
        file and its parameters ('cnn').
    :param brain_bool: forwarded to heatmap(); also decides whether z_max is given to heatmap2optic() ('cnn' only).
    :param folder_output: not used by this function.
    :param remove_temp_files: not used by this function.
    :param centerline_fname: file name of an existing centerline; only read when algo is 'file'.
    :return: (fname_res, centerline_filename), the file names of the resampled image and of the centerline.
    """

    def _resample_image_in_plane(fname_src):
        """Write the 0.5x0.5 mm in-plane version of fname_src; return its file name and the resolution string."""
        logger.info("Resample the image to 0.5x0.5 mm in-plane resolution...")
        fname_dst = sct.add_suffix(fname_src, '_resampled')
        # through-plane pixel size is kept as is
        pz = Image(fname_src).dim[4:7][2]
        resolution = 'x'.join(['0.5', '0.5', str(pz)])
        resampling.resample_file(fname_src, fname_dst, resolution, 'mm', 'linear', verbose=0)
        return fname_dst, resolution

    # TODO: remove unnecessary i/o
    # isct_spine_detect requires nz > 1, so a single-slice image is duplicated along z
    single_slice = Image(image_fname).dim[2] == 1
    if single_slice:
        fname_stacked = sct.add_suffix(image_fname, '_concat')
        concat_data([image_fname, image_fname], dim=2).save(fname_stacked)
        image_fname = fname_stacked

    if algo == 'cnn':
        # CNN parameters: patch statistics and network settings, per contrast
        cnn_settings = {
            't2': {'size': (80, 80), 'mean': 51.1417, 'std': 57.4408, 'features': 16, 'dilation_layers': 2},
            't2s': {'size': (80, 80), 'mean': 68.8591, 'std': 71.4659, 'features': 8, 'dilation_layers': 3},
            't1': {'size': (80, 80), 'mean': 55.7359, 'std': 64.3149, 'features': 24, 'dilation_layers': 3},
            'dwi': {'size': (80, 80), 'mean': 55.744, 'std': 45.003, 'features': 8, 'dilation_layers': 2},
        }

        # load model
        ctr_model_fname = os.path.join(sct.__sct_dir__, 'data', 'deepseg_sc_models', '{}_ctr.h5'.format(contrast_type))
        settings = cnn_settings[contrast_type]
        patch_size = settings['size']
        ctr_model = nn_architecture_ctr(height=patch_size[0],
                                        width=patch_size[1],
                                        channels=1,
                                        classes=1,
                                        features=settings['features'],
                                        depth=2,
                                        temperature=1.0,
                                        padding='same',
                                        batchnorm=True,
                                        dropout=0.0,
                                        dilation_layers=settings['dilation_layers'])
        ctr_model.load_weights(ctr_model_fname)

        # the heatmap is computed on the resampled image
        fname_res, _ = _resample_image_in_plane(image_fname)

        # compute the heatmap
        fname_heatmap = sct.add_suffix(image_fname, "_heatmap")
        fname_heatmap_nii = ''.join(sct.extract_fname(fname_heatmap)[:2]) + '.nii'
        z_max = heatmap(filename_in=fname_res,
                        filename_out=fname_heatmap_nii,
                        model=ctr_model,
                        patch_shape=patch_size,
                        mean_train=settings['mean'],
                        std_train=settings['std'],
                        brain_bool=brain_bool)

        # run optic on the heatmap
        centerline_filename = sct.add_suffix(fname_heatmap, "_ctr")
        optic_lambda = 7 if contrast_type == 't2s' else 1
        heatmap2optic(fname_heatmap=fname_heatmap_nii,
                      lambda_value=optic_lambda,
                      fname_out=centerline_filename,
                      z_max=z_max if brain_bool else None)

    else:
        # TODO: maybe change 'svm' for 'optic', because this is how we call it in sct_get_centerline
        if algo == 'svm':
            # run optic on a heatmap computed by a trained SVM+HoG algorithm
            # TODO: replace with get_centerline(method=optic)
            im_centerline, _, _ = get_centerline(Image(image_fname), algo_fitting='optic', contrast=contrast_type)
            centerline_filename = sct.add_suffix(image_fname, "_ctr")
            im_centerline.save(centerline_filename)

        elif algo == 'viewer':
            # centerline fitted through the labels returned by the viewer
            im_viewer_labels = _call_viewer_centerline(Image(image_fname))
            im_centerline, _, _ = get_centerline(im_viewer_labels)
            centerline_filename = sct.add_suffix(image_fname, "_ctr")
            fname_viewer_labels = sct.add_suffix(image_fname, "_labels-centerline")
            im_centerline.save(centerline_filename)
            im_viewer_labels.save(fname_viewer_labels)

        elif algo == 'file':
            # Re-orient the manual centerline
            centerline_filename = sct.add_suffix(image_fname, "_ctr")
            Image(centerline_fname).change_orientation('RPI').save(centerline_filename)

        else:
            logger.error('The parameter "-centerline" is incorrect. Please try again.')
            sys.exit(1)

        # these centerlines were obtained at the input resolution: resample image and centerline alike
        fname_res, new_resolution = _resample_image_in_plane(image_fname)
        resampling.resample_file(centerline_filename, centerline_filename, new_resolution, 'mm', 'linear', verbose=0)

    if single_slice:
        # keep only the first slice of the centerline computed on the stacked image
        split_data(Image(centerline_filename), dim=2)[0].save(centerline_filename)

    return fname_res, centerline_filename
Exemple #8
0
def deep_segmentation_MSlesion(im_image, contrast_type, ctr_algo='svm', ctr_file=None, brain_bool=True,
                               remove_temp_files=1, verbose=1):
    """
    Segment lesions from MRI data.

    Intermediate files are written in a temporary folder, which is the working directory while the pipeline runs.
    The previous working directory is restored even when a step fails.

    :param im_image: Image() object containing the lesions to segment
    :param contrast_type: Contrast of the image. Need to use one supported by the CNN models. The part before the
        first '_' is the contrast given to the centerline detection.
    :param ctr_algo: Algo to find the centerline. See sct_get_centerline
    :param ctr_file: Centerline or segmentation (optional); copied to the temporary folder when ctr_algo == 'file'
    :param brain_bool: If brain is present or not in the image.
    :param remove_temp_files: if truthy, the temporary folder is deleted once the pipeline succeeded
    :param verbose: forwarded to sct.TempFolder(); when equal to 2, the centerline image is also returned
    :return: tuple of
        - binarized lesion segmentation at the input resolution, in the original orientation,
        - viewer labels in the original orientation (ctr_algo == 'viewer'), else None,
        - centerline in the original orientation (verbose == 2), else None.
    """

    # create temporary folder with intermediate results
    tmp_folder = sct.TempFolder(verbose=verbose)
    tmp_folder_path = tmp_folder.get_path()
    if ctr_algo == 'file':  # if the ctr_file is provided
        tmp_folder.copy_from(ctr_file)
        file_ctr = os.path.basename(ctr_file)
    else:
        file_ctr = None
    tmp_folder.chdir()

    try:
        # orientation of the image, should be RPI
        logger.info("\nReorient the image to RPI, if necessary...")
        fname_in = im_image.absolutepath
        original_orientation = im_image.orientation
        fname_orient = 'image_in_RPI.nii'
        im_image.change_orientation('RPI').save(fname_orient)

        # pixel sizes (px, py, pz), read after the re-orientation call above
        input_resolution = im_image.dim[4:7]

        # find the spinal cord centerline - execute OptiC binary
        logger.info("\nFinding the spinal cord centerline...")
        contrast_type_ctr = contrast_type.split('_')[0]
        fname_res, centerline_filename = find_centerline(algo=ctr_algo,
                                                         image_fname=fname_orient,
                                                         contrast_type=contrast_type_ctr,
                                                         brain_bool=brain_bool,
                                                         folder_output=tmp_folder_path,
                                                         remove_temp_files=remove_temp_files,
                                                         centerline_fname=file_ctr)
        im_nii, ctr_nii = Image(fname_res), Image(centerline_filename)

        # crop image around the spinal cord centerline
        logger.info("\nCropping the image around the spinal cord...")
        crop_size = 48
        X_CROP_LST, Y_CROP_LST, Z_CROP_LST, im_crop_nii = crop_image_around_centerline(im_in=im_nii,
                                                                                       ctr_in=ctr_nii,
                                                                                       crop_size=crop_size)
        del ctr_nii

        # normalize the intensity of the images
        logger.info("Normalizing the intensity...")
        im_norm_in = apply_intensity_normalization(img=im_crop_nii, contrast=contrast_type)
        del im_crop_nii

        # resample to 0.5mm isotropic
        fname_norm = sct.add_suffix(fname_orient, '_norm')
        im_norm_in.save(fname_norm)
        fname_res3d = sct.add_suffix(fname_norm, '_resampled3d')
        resampling.resample_file(fname_norm, fname_res3d, '0.5x0.5x0.5', 'mm', 'linear',
                                 verbose=0)

        # segment data using 3D convolutions
        logger.info("\nSegmenting the MS lesions using deep learning on 3D patches...")
        segmentation_model_fname = os.path.join(sct.__sct_dir__, 'data', 'deepseg_lesion_models',
                                                '{}_lesion.h5'.format(contrast_type))
        fname_seg_crop_res = sct.add_suffix(fname_res3d, '_lesionseg')
        im_res3d = Image(fname_res3d)
        seg_im = segment_3d(model_fname=segmentation_model_fname,
                            contrast_type=contrast_type,
                            im=im_res3d.copy())
        seg_im.save(fname_seg_crop_res)
        del im_res3d, seg_im

        # resample to the initial pz resolution
        fname_seg_res2d = sct.add_suffix(fname_seg_crop_res, '_resampled2d')
        initial_2d_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
        resampling.resample_file(fname_seg_crop_res, fname_seg_res2d, initial_2d_resolution,
                                 'mm', 'linear', verbose=0)
        seg_crop = Image(fname_seg_res2d)

        # reconstruct the segmentation from the crop data
        logger.info("\nReassembling the image...")
        seg_uncrop_nii = uncrop_image(ref_in=im_nii, data_crop=seg_crop.copy().data, x_crop_lst=X_CROP_LST,
                                      y_crop_lst=Y_CROP_LST, z_crop_lst=Z_CROP_LST)
        # NOTE(review): fname_in comes from im_image.absolutepath, so this file (and fname_seg_RPI below) is
        # presumably written next to the input image rather than in the temporary folder - confirm it is intended.
        fname_seg_res_RPI = sct.add_suffix(fname_in, '_res_RPI_seg')
        seg_uncrop_nii.save(fname_seg_res_RPI)
        del seg_crop

        # resample to initial resolution
        logger.info("Resampling the segmentation to the original image resolution...")
        initial_resolution = 'x'.join([str(input_resolution[0]), str(input_resolution[1]), str(input_resolution[2])])
        fname_seg_RPI = sct.add_suffix(fname_in, '_RPI_seg')
        resampling.resample_file(fname_seg_res_RPI, fname_seg_RPI, initial_resolution,
                                 'mm', 'linear', verbose=0)
        seg_initres_nii = Image(fname_seg_RPI)

        if ctr_algo == 'viewer':  # resample and reorient the viewer labels
            fname_res_labels = sct.add_suffix(fname_orient, '_labels-centerline')
            resampling.resample_file(fname_res_labels, fname_res_labels, initial_resolution,
                                     'mm', 'linear', verbose=0)
            im_image_res_labels_downsamp = Image(fname_res_labels).change_orientation(original_orientation)
        else:
            im_image_res_labels_downsamp = None

        if verbose == 2:
            fname_res_ctr = sct.add_suffix(fname_orient, '_ctr')
            resampling.resample_file(fname_res_ctr, fname_res_ctr, initial_resolution,
                                     'mm', 'linear', verbose=0)
            im_image_res_ctr_downsamp = Image(fname_res_ctr).change_orientation(original_orientation)
        else:
            im_image_res_ctr_downsamp = None

        # binarize the resampled image to remove interpolation effects
        logger.info("\nBinarizing the segmentation to avoid interpolation effects...")
        thr = 0.1
        # Boolean masks give the same result as np.where() index tuples without materialising them.
        seg_data = seg_initres_nii.data
        seg_data[seg_data >= thr] = 1
        seg_data[seg_data < thr] = 0

        # the re-orientation itself is done in the return statement, on the in-memory image
        logger.info("\nReorienting the segmentation to the original image orientation...")
    finally:
        # always give the caller its working directory back, including when a step above raised
        tmp_folder.chdir_undo()

    # remove temporary files (only reached on success, so a failed run keeps its intermediate files)
    if remove_temp_files:
        logger.info("\nRemove temporary files...")
        tmp_folder.cleanup()

    # reorient to initial orientation
    return seg_initres_nii.change_orientation(original_orientation), im_image_res_labels_downsamp, im_image_res_ctr_downsamp
Exemple #9
0
def deep_segmentation_MSlesion(im_image,
                               contrast_type,
                               ctr_algo='svm',
                               ctr_file=None,
                               brain_bool=True,
                               remove_temp_files=1,
                               verbose=1):
    """
    Segment lesions from MRI data.

    Pipeline: reorient to RPI, resample to 0.5mm in-plane, find the spinal cord centerline, crop around it,
    normalize the intensity, resample to 0.5mm isotropic, segment with the 3D model '<contrast_type>_lesion.h5',
    then resample back to the input resolution, binarize (threshold 0.1), cast to uint8 and reorient to the
    input orientation.

    All processing happens with a temporary folder as working directory; the caller's working directory is
    restored even if a step raises.

    :param im_image: Image() object containing the lesions to segment.
        NOTE(review): change_orientation('RPI') is called on it without using the return value, so the
        caller's object appears to be reoriented in place -- confirm.
    :param contrast_type: Contrast of the image. Needs to be one supported by the CNN models. The part before
        the first '_' is what gets passed to find_centerline.
    :param ctr_algo: Algo to find the centerline. See sct_get_centerline. 'file' uses ctr_file; 'viewer' makes
        the function also return the viewer labels.
    :param ctr_file: Centerline or segmentation (optional). Copied into the temporary folder when
        ctr_algo == 'file'.
    :param brain_bool: Whether the brain is present in the image or not.
    :param remove_temp_files: If truthy, the temporary folder is cleaned up on success.
    :param verbose: Passed to the temporary folder; when equal to 2, the centerline image is also returned.
    :return: tuple (segmentation, viewer labels, centerline):
        - the binarized uint8 segmentation, reoriented to the orientation of the input image;
        - the viewer labels resampled to the input resolution and reoriented, or None if ctr_algo != 'viewer';
        - the centerline image resampled to the input resolution and reoriented, or None if verbose != 2.
    """

    # create temporary folder with intermediate results
    tmp_folder = sct.TempFolder(verbose=verbose)
    tmp_folder_path = tmp_folder.get_path()
    if ctr_algo == 'file':  # if the ctr_file is provided
        tmp_folder.copy_from(ctr_file)
        file_ctr = os.path.basename(ctr_file)
    else:
        file_ctr = None
    tmp_folder.chdir()

    # From here on the working directory is the temporary folder. The finally clause guarantees that the
    # caller's working directory is restored even when one of the processing steps raises.
    try:
        fname_in = im_image.absolutepath

        # re-orient image to RPI
        logger.info("Reorient the image to RPI, if necessary...")
        original_orientation = im_image.orientation
        im_image.change_orientation('RPI')

        # voxel size (px, py, pz) of the RPI image, needed to resample the result back at the end
        input_resolution = im_image.dim[4:7]

        # Resample image to 0.5mm in plane (slice thickness is kept)
        im_image_res = resampling.resample_nib(im_image,
                                               new_size=[0.5, 0.5, im_image.dim[6]],
                                               new_size_type='mm',
                                               interpolation='linear')

        fname_orient = 'image_in_RPI_res.nii'
        im_image_res.save(fname_orient)

        # find the spinal cord centerline - execute OptiC binary
        logger.info("\nFinding the spinal cord centerline...")
        contrast_type_ctr = contrast_type.split('_')[0]
        _, im_ctl, im_labels_viewer = find_centerline(
            algo=ctr_algo,
            image_fname=fname_orient,
            contrast_type=contrast_type_ctr,
            brain_bool=brain_bool,
            folder_output=tmp_folder_path,
            remove_temp_files=remove_temp_files,
            centerline_fname=file_ctr)
        if ctr_algo == 'file':
            # a user-provided centerline is brought onto the same 0.5mm in-plane grid as the image
            im_ctl = resampling.resample_nib(im_ctl,
                                             new_size=[0.5, 0.5, im_image.dim[6]],
                                             new_size_type='mm',
                                             interpolation='linear')

        # crop image around the spinal cord centerline
        logger.info("\nCropping the image around the spinal cord...")
        crop_size = 48
        X_CROP_LST, Y_CROP_LST, Z_CROP_LST, im_crop_nii = crop_image_around_centerline(
            im_in=im_image_res, ctr_in=im_ctl, crop_size=crop_size)
        del im_ctl

        # normalize the intensity of the images
        logger.info("Normalizing the intensity...")
        im_norm_in = apply_intensity_normalization(img=im_crop_nii,
                                                   contrast=contrast_type)
        del im_crop_nii

        # resample to 0.5mm isotropic
        fname_norm = sct.add_suffix(fname_orient, '_norm')
        im_norm_in.save(fname_norm)
        fname_res3d = sct.add_suffix(fname_norm, '_resampled3d')
        resampling.resample_file(fname_norm,
                                 fname_res3d,
                                 '0.5x0.5x0.5',
                                 'mm',
                                 'linear',
                                 verbose=0)

        # segment data using 3D convolutions
        logger.info(
            "\nSegmenting the MS lesions using deep learning on 3D patches...")
        segmentation_model_fname = sct_dir_local_path(
            'data', 'deepseg_lesion_models', '{}_lesion.h5'.format(contrast_type))
        fname_seg_crop_res = sct.add_suffix(fname_res3d, '_lesionseg')
        im_res3d = Image(fname_res3d)
        seg_im = segment_3d(model_fname=segmentation_model_fname,
                            contrast_type=contrast_type,
                            im=im_res3d.copy())
        seg_im.save(fname_seg_crop_res)
        del im_res3d, seg_im

        # resample to the initial pz resolution
        fname_seg_res2d = sct.add_suffix(fname_seg_crop_res, '_resampled2d')
        initial_2d_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
        resampling.resample_file(fname_seg_crop_res,
                                 fname_seg_res2d,
                                 initial_2d_resolution,
                                 'mm',
                                 'linear',
                                 verbose=0)
        seg_crop = Image(fname_seg_res2d)

        # reconstruct the segmentation from the crop data
        logger.info("\nReassembling the image...")
        seg_uncrop_nii = uncrop_image(ref_in=im_image_res,
                                      data_crop=seg_crop.copy().data,
                                      x_crop_lst=X_CROP_LST,
                                      y_crop_lst=Y_CROP_LST,
                                      z_crop_lst=Z_CROP_LST)
        # NOTE(review): this file name (and fname_seg_RPI below) derives from im_image.absolutepath, so the
        # file presumably lands next to the input image rather than in the temp folder -- confirm.
        fname_seg_res_RPI = sct.add_suffix(fname_in, '_res_RPI_seg')
        seg_uncrop_nii.save(fname_seg_res_RPI)
        del seg_crop

        # resample to initial resolution
        logger.info(
            "Resampling the segmentation to the original image resolution...")
        initial_resolution = 'x'.join(str(res) for res in input_resolution)
        fname_seg_RPI = sct.add_suffix(fname_in, '_RPI_seg')
        resampling.resample_file(fname_seg_res_RPI,
                                 fname_seg_RPI,
                                 initial_resolution,
                                 'mm',
                                 'linear',
                                 verbose=0)
        seg_initres_nii = Image(fname_seg_RPI)

        if ctr_algo == 'viewer':  # resample and reorient the viewer labels
            im_labels_viewer_nib = nib.nifti1.Nifti1Image(
                im_labels_viewer.data, im_labels_viewer.hdr.get_best_affine())
            im_viewer_r_nib = resampling.resample_nib(im_labels_viewer_nib,
                                                      new_size=input_resolution,
                                                      new_size_type='mm',
                                                      interpolation='linear')
            im_viewer = Image(
                im_viewer_r_nib.get_data(),
                hdr=im_viewer_r_nib.header,
                orientation='RPI',
                dim=im_viewer_r_nib.header.get_data_shape()).change_orientation(
                    original_orientation)
        else:
            im_viewer = None

        if verbose == 2:
            # presumably this '_ctr' file was written to the temp folder by find_centerline -- confirm
            fname_res_ctr = sct.add_suffix(fname_orient, '_ctr')
            resampling.resample_file(fname_res_ctr,
                                     fname_res_ctr,
                                     initial_resolution,
                                     'mm',
                                     'linear',
                                     verbose=0)
            im_image_res_ctr_downsamp = Image(fname_res_ctr).change_orientation(
                original_orientation)
        else:
            im_image_res_ctr_downsamp = None

        # binarize the resampled image to remove interpolation effects
        logger.info(
            "\nBinarizing the segmentation to avoid interpolation effects...")
        thr = 0.1
        # two separate comparisons on purpose: values that compare False both times (NaN) are left untouched
        seg_initres_nii.data[seg_initres_nii.data >= thr] = 1
        seg_initres_nii.data[seg_initres_nii.data < thr] = 0

        # change data type
        seg_initres_nii.change_type(np.uint8)

        # reorient to initial orientation
        logger.info(
            "\nReorienting the segmentation to the original image orientation...")
    finally:
        # always go back to the directory the caller was in
        tmp_folder.chdir_undo()

    # remove temporary files (only reached on success, so they are kept for inspection after a failure)
    if remove_temp_files:
        logger.info("\nRemove temporary files...")
        tmp_folder.cleanup()

    # reorient to initial orientation
    return seg_initres_nii.change_orientation(
        original_orientation), im_viewer, im_image_res_ctr_downsamp
# Example #10
# 0
def deep_segmentation_MSlesion(im_image,
                               contrast_type,
                               ctr_algo='svm',
                               ctr_file=None,
                               brain_bool=True,
                               remove_temp_files=1):
    """
    Segment lesions from MRI data.

    Pipeline: reorient to RPI, find the spinal cord centerline, crop around it, normalize the intensity,
    resample to 0.5mm isotropic, segment with the 3D model '<contrast_type>_lesion.h5', then resample back to
    the input resolution, binarize (threshold 0.1) and reorient to the input orientation.

    All processing happens with a temporary folder as working directory; the caller's working directory is
    restored even if a step raises.

    :param im_image: Image() object containing the lesions to segment.
        NOTE(review): change_orientation('RPI') is called on it and im_image.dim is read afterwards, so the
        caller's object appears to be reoriented in place -- confirm.
    :param contrast_type: Contrast of the image. Needs to be one supported by the CNN models. The part before
        the first '_' is what gets passed to find_centerline.
    :param ctr_algo: Algo to find the centerline. See sct_get_centerline. 'manual' uses ctr_file.
    :param ctr_file: Centerline or segmentation (optional). Copied into the temporary folder when
        ctr_algo == 'manual'.
    :param brain_bool: Whether the brain is present in the image or not.
    :param remove_temp_files: If truthy, the temporary folder is cleaned up on success.
    :return: the binarized segmentation, reoriented to the orientation of the input image.
    """

    # create temporary folder with intermediate results
    sct.log.info("\nCreating temporary folder...")
    tmp_folder = sct.TempFolder()
    tmp_folder_path = tmp_folder.get_path()
    if ctr_algo == 'manual':  # if the ctr_file is provided
        tmp_folder.copy_from(ctr_file)
        file_ctr = os.path.basename(ctr_file)
    else:
        file_ctr = None
    tmp_folder.chdir()

    # From here on the working directory is the temporary folder. The finally clause guarantees that the
    # caller's working directory is restored even when one of the processing steps raises.
    try:
        # orientation of the image, should be RPI
        sct.log.info("\nReorient the image to RPI, if necessary...")
        fname_in = im_image.absolutepath
        original_orientation = im_image.orientation
        fname_orient = 'image_in_RPI.nii'
        im_image.change_orientation('RPI').save(fname_orient)

        # voxel size (px, py, pz) of the RPI image, needed to resample the result back at the end
        input_resolution = im_image.dim[4:7]

        # find the spinal cord centerline - execute OptiC binary
        sct.log.info("\nFinding the spinal cord centerline...")
        contrast_type_ctr = contrast_type.split('_')[0]
        fname_res, centerline_filename = find_centerline(
            algo=ctr_algo,
            image_fname=fname_orient,
            contrast_type=contrast_type_ctr,
            brain_bool=brain_bool,
            folder_output=tmp_folder_path,
            remove_temp_files=remove_temp_files,
            centerline_fname=file_ctr)
        im_nii, ctr_nii = Image(fname_res), Image(centerline_filename)

        # crop image around the spinal cord centerline
        sct.log.info("\nCropping the image around the spinal cord...")
        crop_size = 48
        X_CROP_LST, Y_CROP_LST, Z_CROP_LST, im_crop_nii = crop_image_around_centerline(
            im_in=im_nii, ctr_in=ctr_nii, crop_size=crop_size)
        del ctr_nii

        # normalize the intensity of the images
        sct.log.info("Normalizing the intensity...")
        im_norm_in = apply_intensity_normalization(img=im_crop_nii,
                                                   contrast=contrast_type)
        del im_crop_nii

        # resample to 0.5mm isotropic
        fname_norm = sct.add_suffix(fname_orient, '_norm')
        im_norm_in.save(fname_norm)
        fname_res3d = sct.add_suffix(fname_norm, '_resampled3d')
        resampling.resample_file(fname_norm,
                                 fname_res3d,
                                 '0.5x0.5x0.5',
                                 'mm',
                                 'linear',
                                 verbose=0)

        # segment data using 3D convolutions
        sct.log.info(
            "\nSegmenting the MS lesions using deep learning on 3D patches...")
        segmentation_model_fname = os.path.join(
            sct.__sct_dir__, 'data', 'deepseg_lesion_models',
            '{}_lesion.h5'.format(contrast_type))
        fname_seg_crop_res = sct.add_suffix(fname_res3d, '_lesionseg')
        im_res3d = Image(fname_res3d)
        seg_im = segment_3d(model_fname=segmentation_model_fname,
                            contrast_type=contrast_type,
                            im=im_res3d.copy())
        seg_im.save(fname_seg_crop_res)
        del im_res3d, seg_im

        # resample to the initial pz resolution
        fname_seg_res2d = sct.add_suffix(fname_seg_crop_res, '_resampled2d')
        initial_2d_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
        resampling.resample_file(fname_seg_crop_res,
                                 fname_seg_res2d,
                                 initial_2d_resolution,
                                 'mm',
                                 'linear',
                                 verbose=0)
        seg_crop = Image(fname_seg_res2d)

        # reconstruct the segmentation from the crop data
        sct.log.info("\nReassembling the image...")
        seg_uncrop_nii = uncrop_image(ref_in=im_nii,
                                      data_crop=seg_crop.copy().data,
                                      x_crop_lst=X_CROP_LST,
                                      y_crop_lst=Y_CROP_LST,
                                      z_crop_lst=Z_CROP_LST)
        # NOTE(review): this file name (and fname_seg_RPI below) derives from im_image.absolutepath, so the
        # file presumably lands next to the input image rather than in the temp folder -- confirm.
        fname_seg_res_RPI = sct.add_suffix(fname_in, '_res_RPI_seg')
        seg_uncrop_nii.save(fname_seg_res_RPI)
        del seg_crop

        # resample to initial resolution
        sct.log.info(
            "Resampling the segmentation to the original image resolution...")
        initial_resolution = 'x'.join(str(res) for res in input_resolution)
        fname_seg_RPI = sct.add_suffix(fname_in, '_RPI_seg')
        resampling.resample_file(fname_seg_res_RPI,
                                 fname_seg_RPI,
                                 initial_resolution,
                                 'mm',
                                 'linear',
                                 verbose=0)
        seg_initres_nii = Image(fname_seg_RPI)

        # binarize the resampled image to remove interpolation effects
        sct.log.info(
            "\nBinarizing the segmentation to avoid interpolation effects...")
        thr = 0.1
        # two separate comparisons on purpose: values that compare False both times (NaN) are left untouched
        seg_initres_nii.data[seg_initres_nii.data >= thr] = 1
        seg_initres_nii.data[seg_initres_nii.data < thr] = 0

        # reorient to initial orientation
        sct.log.info(
            "\nReorienting the segmentation to the original image orientation...")
    finally:
        # always go back to the directory the caller was in
        tmp_folder.chdir_undo()

    # remove temporary files (only reached on success, so they are kept for inspection after a failure)
    if remove_temp_files:
        sct.log.info("\nRemove temporary files...")
        tmp_folder.cleanup()

    # reorient to initial orientation
    return seg_initres_nii.change_orientation(original_orientation)
def main(args=None):

    # initializations
    param = Param()

    # check user arguments
    if not args:
        args = sys.argv[1:]

    # Get parser info
    parser = get_parser()
    arguments = parser.parse(args)
    fname_data = arguments['-i']
    fname_seg = arguments['-s']
    if '-l' in arguments:
        fname_landmarks = arguments['-l']
        label_type = 'body'
    elif '-ldisc' in arguments:
        fname_landmarks = arguments['-ldisc']
        label_type = 'disc'
    else:
        sct.printv('ERROR: Labels should be provided.', 1, 'error')
    if '-ofolder' in arguments:
        path_output = arguments['-ofolder']
    else:
        path_output = ''

    param.path_qc = arguments.get("-qc", None)

    path_template = arguments['-t']
    contrast_template = arguments['-c']
    ref = arguments['-ref']
    param.remove_temp_files = int(arguments.get('-r'))
    verbose = int(arguments.get('-v'))
    sct.init_sct(log_level=verbose, update=True)  # Update log level
    param.verbose = verbose  # TODO: not clean, unify verbose or param.verbose in code, but not both
    param.straighten_fitting = arguments['-straighten-fitting']
    # if '-cpu-nb' in arguments:
    #     arg_cpu = ' -cpu-nb '+str(arguments['-cpu-nb'])
    # else:
    #     arg_cpu = ''
    # registration parameters
    if '-param' in arguments:
        # reset parameters but keep step=0 (might be overwritten if user specified step=0)
        paramreg = ParamregMultiStep([step0])
        if ref == 'subject':
            paramreg.steps['0'].dof = 'Tx_Ty_Tz_Rx_Ry_Rz_Sz'
        # add user parameters
        for paramStep in arguments['-param']:
            paramreg.addStep(paramStep)
    else:
        paramreg = ParamregMultiStep([step0, step1, step2])
        # if ref=subject, initialize registration using different affine parameters
        if ref == 'subject':
            paramreg.steps['0'].dof = 'Tx_Ty_Tz_Rx_Ry_Rz_Sz'

    # initialize other parameters
    zsubsample = param.zsubsample

    # retrieve template file names
    file_template_vertebral_labeling = get_file_label(os.path.join(path_template, 'template'), 'vertebral labeling')
    file_template = get_file_label(os.path.join(path_template, 'template'), contrast_template.upper() + '-weighted template')
    file_template_seg = get_file_label(os.path.join(path_template, 'template'), 'spinal cord')

    # start timer
    start_time = time.time()

    # get fname of the template + template objects
    fname_template = os.path.join(path_template, 'template', file_template)
    fname_template_vertebral_labeling = os.path.join(path_template, 'template', file_template_vertebral_labeling)
    fname_template_seg = os.path.join(path_template, 'template', file_template_seg)
    fname_template_disc_labeling = os.path.join(path_template, 'template', 'PAM50_label_disc.nii.gz')

    # check file existence
    # TODO: no need to do that!
    sct.printv('\nCheck template files...')
    sct.check_file_exist(fname_template, verbose)
    sct.check_file_exist(fname_template_vertebral_labeling, verbose)
    sct.check_file_exist(fname_template_seg, verbose)
    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    # sct.printv(arguments)
    sct.printv('\nCheck parameters:', verbose)
    sct.printv('  Data:                 ' + fname_data, verbose)
    sct.printv('  Landmarks:            ' + fname_landmarks, verbose)
    sct.printv('  Segmentation:         ' + fname_seg, verbose)
    sct.printv('  Path template:        ' + path_template, verbose)
    sct.printv('  Remove temp files:    ' + str(param.remove_temp_files), verbose)

    # check input labels
    labels = check_labels(fname_landmarks, label_type=label_type)

    vertebral_alignment = False
    if len(labels) > 2 and label_type == 'disc':
        vertebral_alignment = True

    path_tmp = sct.tmp_create(basename="register_to_template", verbose=verbose)

    # set temporary file names
    ftmp_data = 'data.nii'
    ftmp_seg = 'seg.nii.gz'
    ftmp_label = 'label.nii.gz'
    ftmp_template = 'template.nii'
    ftmp_template_seg = 'template_seg.nii.gz'
    ftmp_template_label = 'template_label.nii.gz'

    # copy files to temporary folder
    sct.printv('\nCopying input data to tmp folder and convert to nii...', verbose)
    Image(fname_data).save(os.path.join(path_tmp, ftmp_data))
    Image(fname_seg).save(os.path.join(path_tmp, ftmp_seg))
    Image(fname_landmarks).save(os.path.join(path_tmp, ftmp_label))
    Image(fname_template).save(os.path.join(path_tmp, ftmp_template))
    Image(fname_template_seg).save(os.path.join(path_tmp, ftmp_template_seg))
    Image(fname_template_vertebral_labeling).save(os.path.join(path_tmp, ftmp_template_label))
    if label_type == 'disc':
        Image(fname_template_disc_labeling).save(os.path.join(path_tmp, ftmp_template_label))

    # go to tmp folder
    curdir = os.getcwd()
    os.chdir(path_tmp)

    # Generate labels from template vertebral labeling
    if label_type == 'body':
        sct.printv('\nGenerate labels from template vertebral labeling', verbose)
        ftmp_template_label_, ftmp_template_label = ftmp_template_label, sct.add_suffix(ftmp_template_label, "_body")
        sct_label_utils.main(args=['-i', ftmp_template_label_, '-vert-body', '0', '-o', ftmp_template_label])

    # check if provided labels are available in the template
    sct.printv('\nCheck if provided labels are available in the template', verbose)
    image_label_template = Image(ftmp_template_label)
    labels_template = image_label_template.getNonZeroCoordinates(sorting='value')
    if labels[-1].value > labels_template[-1].value:
        sct.printv('ERROR: Wrong landmarks input. Labels must have correspondence in template space. \nLabel max '
                   'provided: ' + str(labels[-1].value) + '\nLabel max from template: ' +
                   str(labels_template[-1].value), verbose, 'error')

    # if only one label is present, force affine transformation to be Tx,Ty,Tz only (no scaling)
    if len(labels) == 1:
        paramreg.steps['0'].dof = 'Tx_Ty_Tz'
        sct.printv('WARNING: Only one label is present. Forcing initial transformation to: ' + paramreg.steps['0'].dof,
                   1, 'warning')

    # Project labels onto the spinal cord centerline because later, an affine transformation is estimated between the
    # template's labels (centered in the cord) and the subject's labels (assumed to be centered in the cord).
    # If labels are not centered, mis-registration errors are observed (see issue #1826)
    ftmp_label = project_labels_on_spinalcord(ftmp_label, ftmp_seg)

    # binarize segmentation (in case it has values below 0 caused by manual editing)
    sct.printv('\nBinarize segmentation', verbose)
    ftmp_seg_, ftmp_seg = ftmp_seg, sct.add_suffix(ftmp_seg, "_bin")
    sct_maths.main(['-i', ftmp_seg_,
                    '-bin', '0.5',
                    '-o', ftmp_seg])

    # Switch between modes: subject->template or template->subject
    if ref == 'template':

        # resample data to 1mm isotropic
        sct.printv('\nResample data to 1mm isotropic...', verbose)
        resample_file(ftmp_data, add_suffix(ftmp_data, '_1mm'), '1.0x1.0x1.0', 'mm', 'linear', verbose)
        ftmp_data = add_suffix(ftmp_data, '_1mm')
        resample_file(ftmp_seg, add_suffix(ftmp_seg, '_1mm'), '1.0x1.0x1.0', 'mm', 'linear', verbose)
        ftmp_seg = add_suffix(ftmp_seg, '_1mm')
        # N.B. resampling of labels is more complicated, because they are single-point labels, therefore resampling
        # with nearest neighbour can make them disappear.
        resample_labels(ftmp_label, ftmp_data, add_suffix(ftmp_label, '_1mm'))
        ftmp_label = add_suffix(ftmp_label, '_1mm')

        # Change orientation of input images to RPI
        sct.printv('\nChange orientation of input images to RPI...', verbose)

        ftmp_data = Image(ftmp_data).change_orientation("RPI", generate_path=True).save().absolutepath
        ftmp_seg = Image(ftmp_seg).change_orientation("RPI", generate_path=True).save().absolutepath
        ftmp_label = Image(ftmp_label).change_orientation("RPI", generate_path=True).save().absolutepath


        ftmp_seg_, ftmp_seg = ftmp_seg, add_suffix(ftmp_seg, '_crop')
        if vertebral_alignment:
            # cropping the segmentation based on the label coverage to ensure good registration with vertebral alignment
            # See https://github.com/neuropoly/spinalcordtoolbox/pull/1669 for details
            image_labels = Image(ftmp_label)
            coordinates_labels = image_labels.getNonZeroCoordinates(sorting='z')
            nx, ny, nz, nt, px, py, pz, pt = image_labels.dim
            offset_crop = 10.0 * pz  # cropping the image 10 mm above and below the highest and lowest label
            cropping_slices = [coordinates_labels[0].z - offset_crop, coordinates_labels[-1].z + offset_crop]
            # make sure that the cropping slices do not extend outside of the slice range (issue #1811)
            if cropping_slices[0] < 0:
                cropping_slices[0] = 0
            if cropping_slices[1] > nz:
                cropping_slices[1] = nz
            msct_image.spatial_crop(Image(ftmp_seg_), dict(((2, np.int32(np.round(cropping_slices))),))).save(ftmp_seg)
        else:
            # if we do not align the vertebral levels, we crop the segmentation from top to bottom
            im_seg_rpi = Image(ftmp_seg_)
            bottom = 0
            for data in msct_image.SlicerOneAxis(im_seg_rpi, "IS"):
                if (data != 0).any():
                    break
                bottom += 1
            top = im_seg_rpi.data.shape[2]
            for data in msct_image.SlicerOneAxis(im_seg_rpi, "SI"):
                if (data != 0).any():
                    break
                top -= 1
            msct_image.spatial_crop(im_seg_rpi, dict(((2, (bottom, top)),))).save(ftmp_seg)


        # straighten segmentation
        sct.printv('\nStraighten the spinal cord using centerline/segmentation...', verbose)

        # check if warp_curve2straight and warp_straight2curve already exist (i.e. no need to do it another time)
        fn_warp_curve2straight = os.path.join(curdir, "warp_curve2straight.nii.gz")
        fn_warp_straight2curve = os.path.join(curdir, "warp_straight2curve.nii.gz")
        fn_straight_ref = os.path.join(curdir, "straight_ref.nii.gz")

        cache_input_files=[ftmp_seg]
        if vertebral_alignment:
            cache_input_files += [
             ftmp_template_seg,
             ftmp_label,
             ftmp_template_label,
            ]
        cache_sig = sct.cache_signature(
         input_files=cache_input_files,
        )
        cachefile = os.path.join(curdir, "straightening.cache")
        if sct.cache_valid(cachefile, cache_sig) and os.path.isfile(fn_warp_curve2straight) and os.path.isfile(fn_warp_straight2curve) and os.path.isfile(fn_straight_ref):
            sct.printv('Reusing existing warping field which seems to be valid', verbose, 'warning')
            sct.copy(fn_warp_curve2straight, 'warp_curve2straight.nii.gz')
            sct.copy(fn_warp_straight2curve, 'warp_straight2curve.nii.gz')
            sct.copy(fn_straight_ref, 'straight_ref.nii.gz')
            # apply straightening
            sct.run(['sct_apply_transfo', '-i', ftmp_seg, '-w', 'warp_curve2straight.nii.gz', '-d', 'straight_ref.nii.gz', '-o', add_suffix(ftmp_seg, '_straight')])
        else:
            from spinalcordtoolbox.straightening import SpinalCordStraightener
            sc_straight = SpinalCordStraightener(ftmp_seg, ftmp_seg)
            sc_straight.algo_fitting = param.straighten_fitting
            sc_straight.output_filename = add_suffix(ftmp_seg, '_straight')
            sc_straight.path_output = './'
            sc_straight.qc = '0'
            sc_straight.remove_temp_files = param.remove_temp_files
            sc_straight.verbose = verbose

            if vertebral_alignment:
                sc_straight.centerline_reference_filename = ftmp_template_seg
                sc_straight.use_straight_reference = True
                sc_straight.discs_input_filename = ftmp_label
                sc_straight.discs_ref_filename = ftmp_template_label

            sc_straight.straighten()
            sct.cache_save(cachefile, cache_sig)

        # N.B. DO NOT UPDATE VARIABLE ftmp_seg BECAUSE TEMPORARY USED LATER
        # re-define warping field using non-cropped space (to avoid issue #367)
        s, o = sct.run(['sct_concat_transfo', '-w', 'warp_straight2curve.nii.gz', '-d', ftmp_data, '-o', 'warp_straight2curve.nii.gz'])

        if vertebral_alignment:
            sct.copy('warp_curve2straight.nii.gz', 'warp_curve2straightAffine.nii.gz')
        else:
            # Label preparation:
            # --------------------------------------------------------------------------------
            # Remove unused label on template. Keep only label present in the input label image
            sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
            sct.run(['sct_label_utils', '-i', ftmp_template_label, '-o', ftmp_template_label, '-remove-reference', ftmp_label])

            # Dilating the input label so they can be straighten without losing them
            sct.printv('\nDilating input labels using 3vox ball radius')
            sct_maths.main(['-i', ftmp_label,
                            '-dilate', '3',
                            '-o', add_suffix(ftmp_label, '_dilate')])
            ftmp_label = add_suffix(ftmp_label, '_dilate')

            # Apply straightening to labels
            sct.printv('\nApply straightening to labels...', verbose)
            sct.run(['sct_apply_transfo', '-i', ftmp_label, '-o', add_suffix(ftmp_label, '_straight'), '-d', add_suffix(ftmp_seg, '_straight'), '-w', 'warp_curve2straight.nii.gz', '-x', 'nn'])
            ftmp_label = add_suffix(ftmp_label, '_straight')

            # Compute rigid transformation straight landmarks --> template landmarks
            sct.printv('\nEstimate transformation for step #0...', verbose)
            try:
                register_landmarks(ftmp_label, ftmp_template_label, paramreg.steps['0'].dof,
                                   fname_affine='straight2templateAffine.txt', verbose=verbose)
            except RuntimeError:
                raise('Input labels do not seem to be at the right place. Please check the position of the labels. '
                      'See documentation for more details: https://www.slideshare.net/neuropoly/sct-course-20190121/42')

            # Concatenate transformations: curve --> straight --> affine
            sct.printv('\nConcatenate transformations: curve --> straight --> affine...', verbose)
            sct.run(['sct_concat_transfo', '-w', 'warp_curve2straight.nii.gz,straight2templateAffine.txt', '-d', 'template.nii', '-o', 'warp_curve2straightAffine.nii.gz'])

        # Apply transformation
        sct.printv('\nApply transformation...', verbose)
        sct.run(['sct_apply_transfo', '-i', ftmp_data, '-o', add_suffix(ftmp_data, '_straightAffine'), '-d', ftmp_template, '-w', 'warp_curve2straightAffine.nii.gz'])
        ftmp_data = add_suffix(ftmp_data, '_straightAffine')
        sct.run(['sct_apply_transfo', '-i', ftmp_seg, '-o', add_suffix(ftmp_seg, '_straightAffine'), '-d', ftmp_template, '-w', 'warp_curve2straightAffine.nii.gz', '-x', 'linear'])
        ftmp_seg = add_suffix(ftmp_seg, '_straightAffine')

        """
        # Benjamin: Issue from Allan Martin, about the z=0 slice that is screwed up, caused by the affine transform.
        # Solution found: remove slices below and above landmarks to avoid rotation effects
        points_straight = []
        for coord in landmark_template:
            points_straight.append(coord.z)
        min_point, max_point = int(np.round(np.min(points_straight))), int(np.round(np.max(points_straight)))
        ftmp_seg_, ftmp_seg = ftmp_seg, add_suffix(ftmp_seg, '_black')
        msct_image.spatial_crop(Image(ftmp_seg_), dict(((2, (min_point,max_point)),))).save(ftmp_seg)

        """
        # open segmentation
        im = Image(ftmp_seg)
        im_new = msct_image.empty_like(im)
        # binarize
        im_new.data = im.data > 0.5
        # find min-max of anat2template (for subsequent cropping)
        zmin_template, zmax_template = msct_image.find_zmin_zmax(im_new, threshold=0.5)
        # save binarized segmentation
        im_new.save(add_suffix(ftmp_seg, '_bin')) # unused?
        # crop template in z-direction (for faster processing)
        # TODO: refactor to use python module instead of doing i/o
        sct.printv('\nCrop data in template space (for faster processing)...', verbose)
        ftmp_template_, ftmp_template = ftmp_template, add_suffix(ftmp_template, '_crop')
        msct_image.spatial_crop(Image(ftmp_template_), dict(((2, (zmin_template,zmax_template)),))).save(ftmp_template)

        ftmp_template_seg_, ftmp_template_seg = ftmp_template_seg, add_suffix(ftmp_template_seg, '_crop')
        msct_image.spatial_crop(Image(ftmp_template_seg_), dict(((2, (zmin_template,zmax_template)),))).save(ftmp_template_seg)

        ftmp_data_, ftmp_data = ftmp_data, add_suffix(ftmp_data, '_crop')
        msct_image.spatial_crop(Image(ftmp_data_), dict(((2, (zmin_template,zmax_template)),))).save(ftmp_data)

        ftmp_seg_, ftmp_seg = ftmp_seg, add_suffix(ftmp_seg, '_crop')
        msct_image.spatial_crop(Image(ftmp_seg_), dict(((2, (zmin_template,zmax_template)),))).save(ftmp_seg)

        # sub-sample in z-direction
        # TODO: refactor to use python module instead of doing i/o
        sct.printv('\nSub-sample in z-direction (for faster processing)...', verbose)
        sct.run(['sct_resample', '-i', ftmp_template, '-o', add_suffix(ftmp_template, '_sub'), '-f', '1x1x' + zsubsample], verbose)
        ftmp_template = add_suffix(ftmp_template, '_sub')
        sct.run(['sct_resample', '-i', ftmp_template_seg, '-o', add_suffix(ftmp_template_seg, '_sub'), '-f', '1x1x' + zsubsample], verbose)
        ftmp_template_seg = add_suffix(ftmp_template_seg, '_sub')
        sct.run(['sct_resample', '-i', ftmp_data, '-o', add_suffix(ftmp_data, '_sub'), '-f', '1x1x' + zsubsample], verbose)
        ftmp_data = add_suffix(ftmp_data, '_sub')
        sct.run(['sct_resample', '-i', ftmp_seg, '-o', add_suffix(ftmp_seg, '_sub'), '-f', '1x1x' + zsubsample], verbose)
        ftmp_seg = add_suffix(ftmp_seg, '_sub')

        # Registration straight spinal cord to template
        sct.printv('\nRegister straight spinal cord to template...', verbose)

        # loop across registration steps
        warp_forward = []
        warp_inverse = []
        for i_step in range(1, len(paramreg.steps)):
            sct.printv('\nEstimate transformation for step #' + str(i_step) + '...', verbose)
            # identify which is the src and dest
            if paramreg.steps[str(i_step)].type == 'im':
                src = ftmp_data
                dest = ftmp_template
                interp_step = 'linear'
            elif paramreg.steps[str(i_step)].type == 'seg':
                src = ftmp_seg
                dest = ftmp_template_seg
                interp_step = 'nn'
            else:
                sct.printv('ERROR: Wrong image type.', 1, 'error')

            if paramreg.steps[str(i_step)].algo == 'centermassrot' and paramreg.steps[str(i_step)].rot_method == 'hog':
                src_seg = ftmp_seg
                dest_seg = ftmp_template_seg
            # if step>1, apply warp_forward_concat to the src image to be used
            if i_step > 1:
                # apply transformation from previous step, to use as new src for registration
                sct.run(['sct_apply_transfo', '-i', src, '-d', dest, '-w', ','.join(warp_forward), '-o', add_suffix(src, '_regStep' + str(i_step - 1)), '-x', interp_step], verbose)
                src = add_suffix(src, '_regStep' + str(i_step - 1))
                if paramreg.steps[str(i_step)].algo == 'centermassrot' and paramreg.steps[str(i_step)].rot_method == 'hog':  # also apply transformation to the seg
                    sct.run(['sct_apply_transfo', '-i', src_seg, '-d', dest_seg, '-w', ','.join(warp_forward), '-o', add_suffix(src, '_regStep' + str(i_step - 1)), '-x', interp_step], verbose)
                    src_seg = add_suffix(src_seg, '_regStep' + str(i_step - 1))
            # register src --> dest
            # TODO: display param for debugging
            if paramreg.steps[str(i_step)].algo == 'centermassrot' and paramreg.steps[str(i_step)].rot_method == 'hog': # im_seg case
                warp_forward_out, warp_inverse_out = register([src, src_seg], [dest, dest_seg], paramreg, param, str(i_step))
            else:
                warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
            warp_forward.append(warp_forward_out)
            warp_inverse.append(warp_inverse_out)

        # Concatenate transformations:
        sct.printv('\nConcatenate transformations: anat --> template...', verbose)
        sct.run(['sct_concat_transfo', '-w', 'warp_curve2straightAffine.nii.gz,' + ','.join(warp_forward), '-d', 'template.nii', '-o', 'warp_anat2template.nii.gz'], verbose)
        # sct.run('sct_concat_transfo -w warp_curve2straight.nii.gz,straight2templateAffine.txt,'+','.join(warp_forward)+' -d template.nii -o warp_anat2template.nii.gz', verbose)
        sct.printv('\nConcatenate transformations: template --> anat...', verbose)
        warp_inverse.reverse()

        if vertebral_alignment:
            sct.run(['sct_concat_transfo', '-w', ','.join(warp_inverse) + ',warp_straight2curve.nii.gz', '-d', 'data.nii', '-o', 'warp_template2anat.nii.gz'], verbose)
        else:
            sct.run(['sct_concat_transfo', '-w', ','.join(warp_inverse) + ',-straight2templateAffine.txt,warp_straight2curve.nii.gz', '-d', 'data.nii', '-o', 'warp_template2anat.nii.gz'], verbose)

    # register template->subject
    elif ref == 'subject':

        # Change orientation of input images to RPI
        sct.printv('\nChange orientation of input images to RPI...', verbose)
        ftmp_data = Image(ftmp_data).change_orientation("RPI", generate_path=True).save().absolutepath
        ftmp_seg = Image(ftmp_seg).change_orientation("RPI", generate_path=True).save().absolutepath
        ftmp_label = Image(ftmp_label).change_orientation("RPI", generate_path=True).save().absolutepath

        # Remove unused label on template. Keep only label present in the input label image
        sct.printv('\nRemove unused label on template. Keep only label present in the input label image...', verbose)
        sct.run(['sct_label_utils', '-i', ftmp_template_label, '-o', ftmp_template_label, '-remove-reference', ftmp_label])

        # Add one label because at least 3 orthogonal labels are required to estimate an affine transformation. This
        # new label is added at the level of the uppermost label (lowest value), offset by 5 mm along the x axis.
        for i_file in [ftmp_label, ftmp_template_label]:
            im_label = Image(i_file)
            coord_label = im_label.getCoordinatesAveragedByValue()  # N.B. landmarks are sorted by value
            # Create new label
            from copy import deepcopy
            new_label = deepcopy(coord_label[0])
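            # shift it by 5 mm along +x (i.e. 5.0 / px voxels). NOTE(review): the input label was reoriented to RPI above, not RAS -- confirm the intended direction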
            nx, ny, nz, nt, px, py, pz, pt = im_label.dim
            new_label.x = np.round(coord_label[0].x + 5.0 / px)
            # assign value 99
            new_label.value = 99
            # Add to existing image
            im_label.data[int(new_label.x), int(new_label.y), int(new_label.z)] = new_label.value
            # Overwrite label file
            # im_label.absolutepath = 'label_rpi_modif.nii.gz'
            im_label.save()

        # Bring template to subject space using landmark-based transformation
        sct.printv('\nEstimate transformation for step #0...', verbose)
        warp_forward = ['template2subjectAffine.txt']
        warp_inverse = ['-template2subjectAffine.txt']
        try:
            register_landmarks(ftmp_template_label, ftmp_label, paramreg.steps['0'].dof, fname_affine=warp_forward[0], verbose=verbose, path_qc="./")
        except Exception:
            sct.printv('ERROR: input labels do not seem to be at the right place. Please check the position of the labels. See documentation for more details: https://www.slideshare.net/neuropoly/sct-course-20190121/42', verbose=verbose, type='error')

        # loop across registration steps
        for i_step in range(1, len(paramreg.steps)):
            sct.printv('\nEstimate transformation for step #' + str(i_step) + '...', verbose)
            # identify which is the src and dest
            if paramreg.steps[str(i_step)].type == 'im':
                src = ftmp_template
                dest = ftmp_data
                interp_step = 'linear'
            elif paramreg.steps[str(i_step)].type == 'seg':
                src = ftmp_template_seg
                dest = ftmp_seg
                interp_step = 'nn'
            else:
                sct.printv('ERROR: Wrong image type.', 1, 'error')
            # apply transformation from previous step, to use as new src for registration
            sct.run(['sct_apply_transfo', '-i', src, '-d', dest, '-w', ','.join(warp_forward), '-o', add_suffix(src, '_regStep' + str(i_step - 1)), '-x', interp_step], verbose)
            src = add_suffix(src, '_regStep' + str(i_step - 1))
            # register src --> dest
            # TODO: display param for debugging
            warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
            warp_forward.append(warp_forward_out)
            warp_inverse.insert(0, warp_inverse_out)

        # Concatenate transformations:
        sct.printv('\nConcatenate transformations: template --> subject...', verbose)
        sct.run(['sct_concat_transfo', '-w', ','.join(warp_forward), '-d', 'data.nii', '-o', 'warp_template2anat.nii.gz'], verbose)
        sct.printv('\nConcatenate transformations: subject --> template...', verbose)
        sct.run(['sct_concat_transfo', '-w', ','.join(warp_inverse), '-d', 'template.nii', '-o', 'warp_anat2template.nii.gz'], verbose)

    # Apply warping fields to anat and template
    sct.run(['sct_apply_transfo', '-i', 'template.nii', '-o', 'template2anat.nii.gz', '-d', 'data.nii', '-w', 'warp_template2anat.nii.gz', '-crop', '1'], verbose)
    sct.run(['sct_apply_transfo', '-i', 'data.nii', '-o', 'anat2template.nii.gz', '-d', 'template.nii', '-w', 'warp_anat2template.nii.gz', '-crop', '1'], verbose)

    # come back
    os.chdir(curdir)

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    fname_template2anat = os.path.join(path_output, 'template2anat' + ext_data)
    fname_anat2template = os.path.join(path_output, 'anat2template' + ext_data)
    sct.generate_output_file(os.path.join(path_tmp, "warp_template2anat.nii.gz"), os.path.join(path_output, "warp_template2anat.nii.gz"), verbose)
    sct.generate_output_file(os.path.join(path_tmp, "warp_anat2template.nii.gz"), os.path.join(path_output, "warp_anat2template.nii.gz"), verbose)
    sct.generate_output_file(os.path.join(path_tmp, "template2anat.nii.gz"), fname_template2anat, verbose)
    sct.generate_output_file(os.path.join(path_tmp, "anat2template.nii.gz"), fname_anat2template, verbose)
    if ref == 'template':
        # copy straightening files in case subsequent SCT functions need them
        sct.generate_output_file(os.path.join(path_tmp, "warp_curve2straight.nii.gz"), os.path.join(path_output, "warp_curve2straight.nii.gz"), verbose)
        sct.generate_output_file(os.path.join(path_tmp, "warp_straight2curve.nii.gz"), os.path.join(path_output, "warp_straight2curve.nii.gz"), verbose)
        sct.generate_output_file(os.path.join(path_tmp, "straight_ref.nii.gz"), os.path.join(path_output, "straight_ref.nii.gz"), verbose)

    # Delete temporary files
    if param.remove_temp_files:
        sct.printv('\nDelete temporary files...', verbose)
        sct.rmtree(path_tmp, verbose=verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's', verbose)

    qc_dataset = arguments.get("-qc-dataset", None)
    qc_subject = arguments.get("-qc-subject", None)
    if param.path_qc is not None:
        generate_qc(fname_data, fname_in2=fname_template2anat, fname_seg=fname_seg, args=args,
                    path_qc=os.path.abspath(param.path_qc), dataset=qc_dataset, subject=qc_subject,
                    process='sct_register_to_template')
    sct.display_viewer_syntax([fname_data, fname_template2anat], verbose=verbose)
    sct.display_viewer_syntax([fname_template, fname_anat2template], verbose=verbose)