def main(argv=None):
    """
    Command-line entry point of `sct_propseg`.

    Refuses to run on native Windows (reported through the parser). Otherwise it parses
    the arguments, segments the input image with `propseg()`, generates a QC report if a
    QC path was given, and finally prints the viewer syntax for the input and the
    segmentation.

    :param argv: Argument list forwarded to the parser's `parse_args()`.
    """
    parser = get_parser()
    on_native_windows = sys.platform.startswith("win32")
    if on_native_windows:
        # Not strictly a parsing problem, but reporting it through the parser makes the message look more official
        unsupported_msg = (
            "`sct_propseg` is not currently supported on native Windows installations. \n\n"
            "For spinal cord segmentation, please migrate to the new and improved `sct_deepseg_sc` tool, "
            "or consider using WSL to install SCT instead.\n\n"
            "For further updates on `sct_propseg` Windows support, please visit:\n"
            "https://github.com/spinalcordtoolbox/spinalcordtoolbox/issues/3694"
        )
        parser.error(unsupported_msg)
    arguments = parser.parse_args(argv)
    set_loglevel(verbose=arguments.v)

    # Run the segmentation on the (absolute) input path
    fname_input_data = os.path.abspath(arguments.i)
    fname_seg = propseg(Image(fname_input_data), arguments).absolutepath

    # QC report is only produced when a QC folder was requested
    path_qc = arguments.qc
    qc_info = {'dataset': arguments.qc_dataset, 'subject': arguments.qc_subject}
    if path_qc is not None:
        generate_qc(fname_in1=fname_input_data, fname_seg=fname_seg, args=arguments,
                    path_qc=os.path.abspath(path_qc), process='sct_propseg', **qc_info)

    display_viewer_syntax([fname_input_data, fname_seg], colormaps=['gray', 'red'], opacities=['', '1'])
Example #2
0
def main(argv):
    """
    Command-line entry point: segment an image by running all the models of a task in sequence.

    Each model of the task is installed on demand if it is an official model, or validated if
    it is a path to a custom model. The segmentation produced by one model is saved to the
    output file and passed to the next model of the pipeline as a prior.

    :param argv: List of command-line arguments. The help is displayed when it is empty or None.
    """
    parser = get_parser()
    args = parser.parse_args(argv if argv else ['--help'])

    # Deal with task
    if args.list_tasks:
        deepseg.models.display_list_tasks()

    if args.install_task is not None:
        for name_model in deepseg.models.TASKS[args.install_task]['models']:
            deepseg.models.install_model(name_model)
        # Raise SystemExit directly: the `exit()` helper is injected by the `site` module and is
        # not guaranteed to be available (e.g. `python -S` or frozen applications).
        raise SystemExit(0)

    # Deal with input/output
    if not os.path.isfile(args.i):
        parser.error("This file does not exist: {}".format(args.i))

    # Check that a task has been specified
    if args.task is None:
        parser.error("You need to specify a task.")

    # Get pipeline model names
    name_models = deepseg.models.TASKS[args.task]['models']

    # The output file name does not depend on the model, so resolve it once before the loop.
    # This also guarantees that `fname_seg` is defined when the final viewer syntax is displayed.
    if getattr(args, 'o', None) is not None:
        fname_seg = args.o
    else:
        fname_seg = ''.join([sct.image.splitext(args.i)[0], '_seg.nii.gz'])
    path_out = os.path.dirname(fname_seg)

    # Run pipeline by iterating through the models
    fname_prior = None
    for name_model in name_models:
        # Check if this is an official model
        if name_model in deepseg.models.MODELS:
            # If it is, check if it is installed
            path_model = deepseg.models.folder(name_model)
            if not deepseg.models.is_valid(path_model):
                printv("Model {} is not installed. Installing it now...".format(name_model))
                deepseg.models.install_model(name_model)
        # If it is not, check if this is a path to a valid model
        else:
            path_model = os.path.abspath(name_model)
            if not deepseg.models.is_valid(path_model):
                parser.error("The input model is invalid: {}".format(path_model))

        # Run the segmentation; the previous output (if any) is provided through "fname_prior"
        options = {**vars(args), "fname_prior": fname_prior}
        nii_seg = imed.utils.segment_volume(path_model, args.i, options=options)

        # If output folder does not exist, create it (exist_ok guards against a concurrent creation)
        if path_out != '' and not os.path.exists(path_out):
            os.makedirs(path_out, exist_ok=True)
        nib.save(nii_seg, fname_seg)

        # Use the result of the current model as additional input of the next model
        fname_prior = fname_seg

    display_viewer_syntax([args.i, fname_seg], colormaps=['gray', 'red'], opacities=['', '0.7'])
Example #3
0
def main(argv=None):
    """
    Command-line entry point: merge several source images into a destination image space.

    Parses the arguments, fills a `Param` object with the user settings, checks that there is
    one warping field per source image, calls `merge_images()` and prints the viewer syntax.

    :param argv: Argument list forwarded to the parser's `parse_args()`.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    # create param objects
    param = Param()

    # set param arguments as inputted by user
    list_fname_src = arguments.i
    fname_dest = arguments.d
    list_fname_warp = arguments.w
    param.fname_out = arguments.o

    # only override the Param defaults when the user provided a value
    if arguments.x is not None:
        param.interp = arguments.x
    if arguments.r is not None:
        param.rm_tmp = arguments.r

    # Check that the list of input files and warping fields have the same length.
    # This is user-input validation, so it must not rely on `assert` (stripped under `python -O`).
    if len(list_fname_src) != len(list_fname_warp):
        parser.error("The list of input files and the list of warping fields are not of the same length "
                     "({} vs. {}).".format(len(list_fname_src), len(list_fname_warp)))

    # merge src images to destination image
    try:
        merge_images(list_fname_src, fname_dest, list_fname_warp, param)
    except Exception as e:
        # Any failure is reported through printv() with the 'error' type
        printv(str(e), 1, 'error')

    display_viewer_syntax([fname_dest, os.path.abspath(param.fname_out)])
Example #4
0
def main():
    """
    Command-line entry point: merge several source images into a destination image space.

    Reads the arguments from `sys.argv` (the help is displayed when none is given), fills a
    `Param` object with the user settings, checks that there is one warping field per source
    image, calls `merge_images()` and prints the viewer syntax.
    """
    # create param objects
    param = Param()

    # get parser
    parser = get_parser()
    arguments = parser.parse_args(args=None if sys.argv[1:] else ['--help'])

    # set param arguments as inputted by user
    list_fname_src = arguments.i
    fname_dest = arguments.d
    list_fname_warp = arguments.w
    param.fname_out = arguments.o

    # only override the Param defaults when the user provided a value
    if arguments.x is not None:
        param.interp = arguments.x
    if arguments.r is not None:
        param.rm_tmp = arguments.r
    param.verbose = arguments.v
    init_sct(log_level=param.verbose, update=True)  # Update log level

    # Check that the list of input files and warping fields have the same length.
    # This is user-input validation, so it must not rely on `assert` (stripped under `python -O`).
    if len(list_fname_src) != len(list_fname_warp):
        parser.error("The list of input files and the list of warping fields are not of the same length "
                     "({} vs. {}).".format(len(list_fname_src), len(list_fname_warp)))

    # merge src images to destination image
    try:
        merge_images(list_fname_src, fname_dest, list_fname_warp, param)
    except Exception as e:
        # Any failure is reported through printv() with the 'error' type
        printv(str(e), 1, 'error')

    display_viewer_syntax([fname_dest, os.path.abspath(param.fname_out)])
def main():
    """
    Command-line entry point: segment an image with a list of models or with the models of a task.

    Reads the arguments from `sys.argv` (the help is displayed when none is given). Handles the
    listing/installation options first, then runs every model of the pipeline in sequence, the
    output of one model being passed to the next one as a prior.
    """
    parser = get_parser()
    args = parser.parse_args(args=None if sys.argv[1:] else ['--help'])

    # Deal with model
    if args.list_models:
        deepseg.models.display_list_models()

    # Deal with task
    if args.list_tasks:
        deepseg.models.display_list_tasks()

    # Raise SystemExit directly in the two install branches below: the `exit()` helper is injected
    # by the `site` module and is not guaranteed to be available (e.g. `python -S`, frozen apps).
    if args.install_model is not None:
        deepseg.models.install_model(args.install_model)
        raise SystemExit(0)

    if args.install_task is not None:
        for name_model in deepseg.models.TASKS[args.install_task]['models']:
            deepseg.models.install_model(name_model)
        raise SystemExit(0)

    # Deal with input/output
    if not os.path.isfile(args.i):
        parser.error("This file does not exist: {}".format(args.i))

    # Check if at least a model or task has been specified
    if args.model is None and args.task is None:
        parser.error("You need to specify a model or a task.")

    # Get pipeline model names (an explicit list of models takes precedence over a task)
    if args.task is not None:
        name_models = deepseg.models.TASKS[args.task]['models']

    if args.model is not None:
        name_models = args.model

    # Run pipeline by iterating through the models
    fname_prior = None
    for name_model in name_models:
        # Check if this is an official model
        if name_model in deepseg.models.MODELS:
            # If it is, check if it is installed
            path_model = deepseg.models.folder(name_model)
            if not deepseg.models.is_valid(path_model):
                printv("Model {} is not installed. Installing it now...".format(name_model))
                deepseg.models.install_model(name_model)
        # If it is not, check if this is a path to a valid model
        else:
            path_model = os.path.abspath(name_model)
            if not deepseg.models.is_valid(path_model):
                parser.error("The input model is invalid: {}".format(path_model))

        # Call segment_nifti
        fname_seg = deepseg.core.segment_nifti(args.i, path_model, fname_prior, vars(args))
        # Use the result of the current model as additional input of the next model
        fname_prior = fname_seg

    display_viewer_syntax([args.i, fname_seg], colormaps=['gray', 'red'], opacities=['', '0.7'])
def main():
    """
    Command-line entry point of the PMJ detection script.

    Reads the arguments from `sys.argv` (the help is displayed when none is given), runs
    `DetectPMJ`, removes the temporary folder if requested and, when an output file was
    produced, optionally generates a QC report and prints the viewer syntax.
    """
    parser = get_parser()
    arguments = parser.parse_args(args=None if sys.argv[1:] else ['--help'])

    fname_in, contrast = arguments.i, arguments.c

    # Optional segmentation or centerline: dropped (with a warning) when the file does not exist
    fname_seg = arguments.s
    if fname_seg is not None and not os.path.isfile(fname_seg):
        fname_seg = None
        printv('WARNING: -s input file: "' + arguments.s + '" does not exist.\nDetecting PMJ without using segmentation information', 1, 'warning')

    # Output folder: defaults to the current directory, created when missing
    path_results = arguments.ofolder
    if path_results is None:
        path_results = '.'
    else:
        if os.path.exists(path_results) and not os.path.isdir(path_results):
            printv("ERROR output directory %s is not a valid directory" % path_results, 1, 'error')
        if not os.path.exists(path_results):
            os.makedirs(path_results)

    path_qc = arguments.qc
    rm_tmp = bool(arguments.r)  # whether to delete the temporary folder afterwards

    verbose = arguments.v
    init_sct(log_level=verbose, update=True)  # Update log level

    # Configure the detector and run the extraction
    detector = DetectPMJ(fname_im=fname_in, contrast=contrast, fname_seg=fname_seg,
                         path_out=path_results, verbose=verbose)
    fname_out, tmp_dir = detector.apply()

    if rm_tmp:
        rmtree(tmp_dir)

    # Nothing to report or display when no output file was produced
    if fname_out is None:
        return

    if path_qc is not None:
        from spinalcordtoolbox.reports.qc import generate_qc
        generate_qc(fname_in, fname_seg=fname_out, args=sys.argv[1:], path_qc=os.path.abspath(path_qc), process='sct_detect_pmj')

    display_viewer_syntax([fname_in, fname_out], colormaps=['gray', 'red'])
def main(argv=None):
    """
    Command-line entry point of the template warping script.

    Builds a `WarpTemplate` object from the parsed arguments and the `Param` defaults, then
    optionally generates a QC report and prints the viewer syntax. Labels missing from the
    warped template (signalled by a RuntimeError from `get_file_label()`) only disable the
    QC report / viewer syntax; they do not abort the script.

    :param argv: Argument list forwarded to the parser's `parse_args()`.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    param = Param()

    fname_src = arguments.d
    path_qc = arguments.qc
    qc_dataset, qc_subject = arguments.qc_dataset, arguments.qc_subject

    # call main function
    w = WarpTemplate(
        fname_src, arguments.w, arguments.a, arguments.s, arguments.ofolder, arguments.t,
        param.folder_template, param.folder_atlas, param.folder_spinal_levels,
        param.file_info_label, param.list_labels_nn, verbose)

    # From here on, `path_template` designates the warped template inside the output folder
    path_template = os.path.join(w.folder_out, w.folder_template)
    get_label = spinalcordtoolbox.metadata.get_file_label

    # Deal with QC report
    if path_qc is not None:
        try:
            # id_label=4: 'white matter mask (probabilistic)'
            file_wm = get_label(path_template, id_label=4)
            fname_wm = os.path.join(w.folder_out, w.folder_template, file_wm)
            generate_qc(fname_src, fname_seg=fname_wm, args=sys.argv[1:],
                        path_qc=os.path.abspath(path_qc), dataset=qc_dataset, subject=qc_subject,
                        process='sct_warp_template')
        except RuntimeError:
            # get_file_label() throws a RuntimeError when the label is missing
            printv("QC not generated since expected labels are missing from template", type="warning")

    # Deal with verbose
    try:
        # Label ids, in display order:
        #   1: 'T2-weighted template'
        #   5: 'gray matter mask (probabilistic)'
        #   4: 'white matter mask (probabilistic)'
        fname_labels = [get_label(path_template, id_label=id_label, output="filewithpath")
                        for id_label in (1, 5, 4)]
        display_viewer_syntax(
            [fname_src] + fname_labels,
            colormaps=['gray', 'gray', 'red-yellow', 'blue-lightblue'],
            opacities=['1', '1', '0.5', '0.5'],
            minmax=['', '0,4000', '0.4,1', '0.4,1'],
            verbose=verbose)
    except RuntimeError:
        # If a label is missing, continue silently
        pass
def main(argv=None):
    """
    Command-line entry point of the diffusion motion correction script.

    Fills a `ParamMoco` object from the parsed arguments, runs `moco_wrapper()`, optionally
    generates a QC report and prints the viewer syntax for the corrected and original data.

    :param argv: Argument list forwarded to the parser's `parse_args()`.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    # initialization
    param = ParamMoco(is_diffusion=True, group_size=3, metric='MI', smooth='1')

    # Fetch user arguments
    param.fname_data = arguments.i
    param.fname_bvecs = arguments.bvec
    param.fname_bvals = arguments.bval
    param.bval_min = arguments.bvalmin
    param.group_size = arguments.g
    param.fname_mask = arguments.m
    param.interp = arguments.x
    param.path_out = arguments.ofolder
    param.remove_temp_files = arguments.r
    if arguments.param is not None:
        param.update(arguments.param)

    path_qc = arguments.qc
    qc_fps = arguments.qc_fps
    qc_dataset = arguments.qc_dataset
    qc_subject = arguments.qc_subject
    qc_seg = arguments.qc_seg

    # '-qc' and '-qc-seg' are mutually inclusive: either both are given, or neither.
    # parser.error() reports the problem and exits by itself, so it must not be `raise`d
    # (raising its return value would be a TypeError if it ever returned).
    if (path_qc is None) != (qc_seg is None):
        parser.error(
            "Both '-qc' and '-qc-seg' are required in order to generate a QC report."
        )

    # run moco
    fname_output_image = moco_wrapper(param)

    set_global_loglevel(
        verbose)  # moco_wrapper changes verbose to 0, see issue #3341

    # QC report
    if path_qc is not None:
        generate_qc(fname_in1=fname_output_image,
                    fname_in2=param.fname_data,
                    fname_seg=qc_seg,
                    args=sys.argv[1:],
                    path_qc=os.path.abspath(path_qc),
                    fps=qc_fps,
                    dataset=qc_dataset,
                    subject=qc_subject,
                    process='sct_dmri_moco')

    display_viewer_syntax([fname_output_image, param.fname_data],
                          mode='ortho,ortho')
Example #9
0
def main(argv=None):
    """
    Command-line entry point of `sct_propseg`.

    Parses the arguments, segments the input image with `propseg()`, generates a QC report
    if a QC path was given, and prints the viewer syntax for the input and the segmentation.

    :param argv: Argument list forwarded to the parser's `parse_args()`.
    """
    arguments = get_parser().parse_args(argv)
    set_global_loglevel(verbose=arguments.v)

    # Run the segmentation on the (absolute) input path
    fname_input_data = os.path.abspath(arguments.i)
    fname_seg = propseg(Image(fname_input_data), arguments).absolutepath

    # QC report is only produced when a QC folder was requested
    path_qc = arguments.qc
    qc_info = {'dataset': arguments.qc_dataset, 'subject': arguments.qc_subject}
    if path_qc is not None:
        generate_qc(
            fname_in1=fname_input_data,
            fname_seg=fname_seg,
            args=arguments,
            path_qc=os.path.abspath(path_qc),
            process='sct_propseg',
            **qc_info
        )

    display_viewer_syntax(
        [fname_input_data, fname_seg],
        colormaps=['gray', 'red'],
        opacities=['', '1'],
    )
Example #10
0
def main(arguments):
    """
    Run `propseg()` on the input image described by already-parsed arguments.

    Generates a QC report if a QC path was given, then prints the viewer syntax for the
    input image and its segmentation.

    :param arguments: Parsed arguments; the attributes `i`, `qc`, `qc_dataset` and `qc_subject`
                      are read here, and the whole object is forwarded to `propseg()` and
                      `generate_qc()`.
    """
    fname_input_data = os.path.abspath(arguments.i)
    fname_seg = propseg(Image(fname_input_data), arguments).absolutepath

    # QC report is only produced when a QC folder was requested
    path_qc = arguments.qc
    qc_info = {'dataset': arguments.qc_dataset, 'subject': arguments.qc_subject}
    if path_qc is not None:
        generate_qc(fname_in1=fname_input_data, fname_seg=fname_seg, args=arguments,
                    path_qc=os.path.abspath(path_qc), process='sct_propseg', **qc_info)

    display_viewer_syntax([fname_input_data, fname_seg], colormaps=['gray', 'red'], opacities=['', '1'])
Example #11
0
def main(argv=None):
    """
    Command-line entry point of `sct_deepseg_sc`.

    Parses and validates the arguments, segments the spinal cord with
    `deep_segmentation_spinalcord()`, saves the segmentation, optionally generates a QC
    report and prints the viewer syntax.

    :param argv: Argument list forwarded to the parser's `parse_args()`.
    :raises SyntaxError: if the threshold is neither within [0, 1] nor equal to -1.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_loglevel(verbose=verbose)

    fname_image = os.path.abspath(arguments.i)
    contrast_type = arguments.c

    ctr_algo = arguments.centerline

    # When "-brain" is not given, its default depends on the contrast
    if arguments.brain is not None:
        brain_bool = bool(arguments.brain)
    elif contrast_type in ['t1', 't2']:
        brain_bool = True
    elif contrast_type in ['t2s', 'dwi']:
        brain_bool = False
    else:
        # Without this branch `brain_bool` would stay unbound and the script would only fail
        # much later with an UnboundLocalError.
        parser.error('No default value of "-brain" is defined for contrast "{}". '
                     'Please set "-brain" explicitly.'.format(contrast_type))

    if bool(arguments.brain) and ctr_algo == 'svm':
        printv('Please only use the flag "-brain 1" with "-centerline cnn".',
               1, 'warning')
        sys.exit(1)

    kernel_size = arguments.kernel
    if kernel_size == '3d' and contrast_type == 'dwi':
        kernel_size = '2d'
        printv(
            '3D kernel model for dwi contrast is not available. 2D kernel model is used instead.',
            type="warning")

    if ctr_algo == 'file' and arguments.file_centerline is None:
        printv(
            'Please use the flag -file_centerline to indicate the centerline filename.',
            1, 'warning')
        sys.exit(1)

    # Providing a centerline file forces the 'file' centerline algorithm
    if arguments.file_centerline is not None:
        manual_centerline_fname = arguments.file_centerline
        ctr_algo = 'file'
    else:
        manual_centerline_fname = None

    # Default output name: <input file name>_seg<input extension>
    if arguments.o is not None:
        fname_out = arguments.o
    else:
        _, file_name, ext = extract_fname(fname_image)
        fname_out = file_name + '_seg' + ext

    threshold = arguments.thr

    if threshold is not None:
        if threshold > 1.0 or (threshold < 0.0 and threshold != -1.0):
            raise SyntaxError(
                "Threshold should be between 0 and 1, or equal to -1 (no threshold)"
            )

    remove_temp_files = arguments.r

    path_qc = arguments.qc
    qc_dataset = arguments.qc_dataset
    qc_subject = arguments.qc_subject
    output_folder = arguments.ofolder

    # check if input image is 2D or 3D
    check_dim(fname_image, dim_lst=[2, 3])

    # Segment image
    im_image = Image(fname_image)
    # note: below we pass im_image.copy() otherwise the field absolutepath becomes None after execution of this function
    im_seg, im_image_RPI_upsamp, im_seg_RPI_upsamp = \
        deep_segmentation_spinalcord(im_image.copy(), contrast_type, ctr_algo=ctr_algo,
                                     ctr_file=manual_centerline_fname, brain_bool=brain_bool, kernel_size=kernel_size,
                                     threshold_seg=threshold, remove_temp_files=remove_temp_files, verbose=verbose)

    # Save segmentation
    fname_seg = os.path.abspath(os.path.join(output_folder, fname_out))
    im_seg.save(fname_seg)

    # Generate QC report
    if path_qc is not None:
        generate_qc(fname_image,
                    fname_seg=fname_seg,
                    args=sys.argv[1:],
                    path_qc=os.path.abspath(path_qc),
                    dataset=qc_dataset,
                    subject=qc_subject,
                    process='sct_deepseg_sc')
    display_viewer_syntax([fname_image, fname_seg],
                          colormaps=['gray', 'red'],
                          opacities=['', '0.7'])
Example #12
0
def main(argv=None):
    """
    Command-line entry point: smooth an image along the spinal cord.

    The processing happens in a temporary folder: the image is straightened along the
    centerline/segmentation (reusing warping fields cached in the calling directory when they
    are still valid), smoothed, and warped back to the curved space. Voxels that end up being
    zero are replaced by the original image values before the output file is generated.

    :param argv: Argument list forwarded to the parser's `parse_args()`.
    :raises ValueError: if the number of sigmas is neither 1 nor the number of image dimensions.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    # Initialization
    param = Param()
    start_time = time.time()

    fname_anat = arguments.i
    fname_centerline = arguments.s
    param.algo_fitting = arguments.algo_fitting

    # NOTE(review): `sigmas` stays undefined if `arguments.smooth` is None; this presumably relies
    # on the parser providing a default value -- confirm against get_parser().
    if arguments.smooth is not None:
        sigmas = arguments.smooth
    remove_temp_files = arguments.r
    if arguments.o is not None:
        fname_out = arguments.o
    else:
        fname_out = extract_fname(fname_anat)[1] + '_smooth.nii'

    # Display arguments
    printv('\nCheck input arguments...')
    printv('  Volume to smooth .................. ' + fname_anat)
    printv('  Centerline ........................ ' + fname_centerline)
    printv('  Sigma (mm) ........................ ' + str(sigmas))
    printv('  Verbose ........................... ' + str(verbose))

    # Check that input is 3D:
    nx, ny, nz, nt, px, py, pz, pt = Image(fname_anat).dim
    dim = 4  # by default, will be adjusted later
    if nt == 1:
        dim = 3
    if nz == 1:
        dim = 2
    if dim == 4:
        printv(
            'WARNING: the input image is 4D, please split your image to 3D before smoothing spinalcord using :\n'
            'sct_image -i ' + fname_anat + ' -split t -o ' + fname_anat,
            verbose, 'warning')
        printv('4D images not supported, aborting ...', verbose, 'error')

    # Extract extensions (needed to name the copies in the tmp folder)
    ext_anat = extract_fname(fname_anat)[2]
    ext_centerline = extract_fname(fname_centerline)[2]

    path_tmp = tmp_create(basename="smooth_spinalcord")

    # Copying input data to tmp folder
    printv('\nCopying input data to tmp folder and convert to nii...', verbose)
    copy(fname_anat, os.path.join(path_tmp, "anat" + ext_anat))
    copy(fname_centerline, os.path.join(path_tmp,
                                        "centerline" + ext_centerline))

    # go to tmp folder
    curdir = os.getcwd()
    os.chdir(path_tmp)
    # Everything below works with paths relative to the tmp folder. The try/finally guarantees
    # that the caller's working directory is restored even if a processing step fails.
    try:
        # convert to nii format
        im_anat = convert(Image('anat' + ext_anat))
        im_anat.save('anat.nii', mutable=True, verbose=verbose)
        im_centerline = convert(Image('centerline' + ext_centerline))
        im_centerline.save('centerline.nii', mutable=True, verbose=verbose)

        # Change orientation of the input image into RPI
        printv('\nOrient input volume to RPI orientation...')

        img_anat_rpi = Image("anat.nii").change_orientation("RPI")
        fname_anat_rpi = add_suffix(img_anat_rpi.absolutepath, "_rpi")
        img_anat_rpi.save(path=fname_anat_rpi, mutable=True)

        # Change orientation of the centerline into RPI
        printv('\nOrient centerline to RPI orientation...')

        img_centerline_rpi = Image("centerline.nii").change_orientation("RPI")
        fname_centerline_rpi = add_suffix(img_centerline_rpi.absolutepath, "_rpi")
        img_centerline_rpi.save(path=fname_centerline_rpi, mutable=True)

        # Straighten the spinal cord
        printv('\nStraighten the spinal cord using centerline/segmentation...',
               verbose)
        cache_sig = cache_signature(
            input_files=[fname_anat_rpi, fname_centerline_rpi],
            input_params={"x": "spline"})
        cachefile = os.path.join(curdir, "straightening.cache")
        # Files that must be present in the calling directory for the cache to be usable
        cached_files = ('warp_curve2straight.nii.gz',
                        'warp_straight2curve.nii.gz',
                        'straight_ref.nii.gz')
        if cache_valid(cachefile, cache_sig) and all(
                os.path.isfile(os.path.join(curdir, fname)) for fname in cached_files):
            # if they exist, copy them into current folder
            printv('Reusing existing warping field which seems to be valid',
                   verbose, 'warning')
            for fname in cached_files:
                copy(os.path.join(curdir, fname), fname)
            # apply straightening
            run_proc([
                'sct_apply_transfo', '-i', fname_anat_rpi, '-w',
                'warp_curve2straight.nii.gz', '-d', 'straight_ref.nii.gz', '-o',
                'anat_rpi_straight.nii', '-x', 'spline'
            ], verbose)
        else:
            run_proc([
                'sct_straighten_spinalcord', '-i', fname_anat_rpi, '-o',
                'anat_rpi_straight.nii', '-s', fname_centerline_rpi, '-x',
                'spline', '-param', 'algo_fitting=' + param.algo_fitting
            ], verbose)
            cache_save(cachefile, cache_sig)
            # move warping fields locally (to use caching next time)
            copy('warp_curve2straight.nii.gz',
                 os.path.join(curdir, 'warp_curve2straight.nii.gz'))
            copy('warp_straight2curve.nii.gz',
                 os.path.join(curdir, 'warp_straight2curve.nii.gz'))

        # Smooth the straightened image along z
        printv('\nSmooth the straightened image...')

        img = Image("anat_rpi_straight.nii")
        out = img.copy()

        # A single sigma is applied to every dimension of the image
        if len(sigmas) == 1:
            sigmas = [sigmas[0]] * len(img.data.shape)
        elif len(sigmas) != len(img.data.shape):
            raise ValueError(
                "-smooth need the same number of inputs as the number of image dimension OR only one input"
            )

        # Divide each sigma by img.dim[4:7] (unpacked above as px, py, pz for the input image)
        sigmas = [sigmas[i] / img.dim[i + 4] for i in range(3)]
        out.data = smooth(out.data, sigmas)
        out.save(path="anat_rpi_straight_smooth.nii")

        # Apply the reversed warping field to get back the curved spinal cord
        printv(
            '\nApply the reversed warping field to get back the curved spinal cord...'
        )
        run_proc([
            'sct_apply_transfo', '-i', 'anat_rpi_straight_smooth.nii', '-o',
            'anat_rpi_straight_smooth_curved.nii', '-d', 'anat.nii', '-w',
            'warp_straight2curve.nii.gz', '-x', 'spline'
        ], verbose)

        # replace zeroed voxels by original image (issue #937)
        printv('\nReplace zeroed voxels by original image...', verbose)
        nii_smooth = Image('anat_rpi_straight_smooth_curved.nii')
        data_smooth = nii_smooth.data
        data_input = Image('anat.nii').data
        indzero = np.where(data_smooth == 0)
        data_smooth[indzero] = data_input[indzero]
        nii_smooth.data = data_smooth
        nii_smooth.save('anat_rpi_straight_smooth_curved_nonzero.nii')
    finally:
        # come back
        os.chdir(curdir)

    # Generate output file
    printv('\nGenerate output file...')
    generate_output_file(
        os.path.join(path_tmp, "anat_rpi_straight_smooth_curved_nonzero.nii"),
        fname_out)

    # Remove temporary files
    if remove_temp_files == 1:
        printv('\nRemove temporary files...')
        rmtree(path_tmp)

    # Display elapsed time
    elapsed_time = time.time() - start_time
    printv('\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) +
           's\n')

    display_viewer_syntax([fname_anat, fname_out], verbose=verbose)
def main():
    """
    Command-line entry point of the MS lesion segmentation script.

    Reads the arguments from `sys.argv` (the help is displayed when none is given), runs
    `deep_segmentation_MSlesion()` and saves the lesion segmentation in the output folder.
    The viewer labels are saved as well when the centerline algorithm is 'viewer', and the
    centerline when verbose equals 2.
    """
    parser = get_parser()
    args = parser.parse_args(args=None if sys.argv[1:] else ['--help'])

    fname_image = args.i
    contrast_type = args.c
    ctr_algo = args.centerline

    # Whether the brain section is assumed to be included in the image
    if args.brain is None and contrast_type in ['t2s', 't2_ax']:
        brain_bool = False
    else:
        brain_bool = bool(args.brain)

    output_folder = args.ofolder

    manual_centerline_fname = args.file_centerline
    if manual_centerline_fname is None:
        if ctr_algo == 'file':
            printv(
                'Please use the flag -file_centerline to indicate the centerline filename.',
                1, 'error')
            sys.exit(1)
    else:
        # Providing a centerline file forces the 'file' centerline algorithm
        ctr_algo = 'file'

    remove_temp_files = args.r
    verbose = args.v
    init_sct(log_level=verbose, update=True)  # Update log level

    # Summary of the method configuration
    printv(''.join([
        '\nMethod:',
        '\n\tCenterline algorithm: ' + str(ctr_algo),
        '\n\tAssumes brain section included in the image: ' + str(brain_bool) + '\n',
    ]))

    # Segment image
    from spinalcordtoolbox.image import Image
    from spinalcordtoolbox.deepseg_lesion.core import deep_segmentation_MSlesion
    im_seg, im_labels_viewer, im_ctr = deep_segmentation_MSlesion(
        Image(fname_image), contrast_type, ctr_algo=ctr_algo, ctr_file=manual_centerline_fname,
        brain_bool=brain_bool, remove_temp_files=remove_temp_files, verbose=verbose)

    # All outputs are named <input file name><suffix><input extension> inside the output folder
    fname_parts = extract_fname(fname_image)
    file_stem, file_ext = fname_parts[1], fname_parts[2]

    def output_path(suffix):
        return os.path.abspath(os.path.join(output_folder, file_stem + suffix + file_ext))

    # Save segmentation
    fname_seg = output_path('_lesionseg')
    im_seg.save(fname_seg)

    # Save labels
    if ctr_algo == 'viewer':
        im_labels_viewer.save(output_path('_labels-centerline'))

    # Save ctr
    if verbose == 2:
        im_ctr.save(output_path('_centerline'))

    display_viewer_syntax([fname_image, fname_seg], colormaps=['gray', 'red'], opacities=['', '0.7'])
def main(argv=None):
    """
    Command-line entry point of the spinal cord straightening script.

    Configures a `SpinalCordStraightener` from the parsed arguments, runs the straightening,
    optionally generates a QC report and prints the viewer syntax for the straightened image.

    :param argv: Argument list forwarded to the parser's `parse_args()`.
    :return: None
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    input_filename = arguments.i
    centerline_file = arguments.s

    sc_straight = SpinalCordStraightener(input_filename, centerline_file)

    if arguments.dest is not None:
        sc_straight.use_straight_reference = True
        sc_straight.centerline_reference_filename = str(arguments.dest)

    # Disc labels are only used together with a straight reference
    if arguments.ldisc_input is not None:
        if not sc_straight.use_straight_reference:
            printv(
                'Warning: discs position are not taken into account if reference is not provided.'
            )
        else:
            sc_straight.discs_input_filename = str(arguments.ldisc_input)
            sc_straight.precision = 4.0
    if arguments.ldisc_dest is not None:
        if not sc_straight.use_straight_reference:
            printv(
                'Warning: discs position are not taken into account if reference is not provided.'
            )
        else:
            sc_straight.discs_ref_filename = str(arguments.ldisc_dest)
            sc_straight.precision = 4.0

    # Handling optional arguments
    sc_straight.remove_temp_files = arguments.r
    sc_straight.interpolation_warp = arguments.x
    sc_straight.output_filename = arguments.o
    sc_straight.path_output = arguments.ofolder
    path_qc = arguments.qc
    sc_straight.verbose = verbose

    if arguments.disable_straight2curved:
        sc_straight.straight2curved = False
    if arguments.disable_curved2straight:
        sc_straight.curved2straight = False

    if arguments.speed_factor:
        sc_straight.speed_factor = arguments.speed_factor

    if arguments.xy_size:
        sc_straight.xy_size = arguments.xy_size

    sc_straight.param_centerline = ParamCenterline(
        algo_fitting=arguments.centerline_algo,
        smooth=arguments.centerline_smooth)
    if arguments.param is not None:
        params_user = arguments.param
        # Update the straightening parameters from the "key=value" items given by the user
        for param in params_user:
            param_split = param.split('=')
            if param_split[0] == 'precision':
                sc_straight.precision = float(param_split[1])
            if param_split[0] == 'threshold_distance':
                sc_straight.threshold_distance = float(param_split[1])
            if param_split[0] == 'accuracy_results':
                sc_straight.accuracy_results = int(param_split[1])
            if param_split[0] == 'template_orientation':
                sc_straight.template_orientation = int(param_split[1])

    fname_straight = sc_straight.straighten()

    printv("\nFinished! Elapsed time: {} s".format(sc_straight.elapsed_time),
           verbose)

    # Generate QC report
    if path_qc is not None:
        qc_dataset = arguments.qc_dataset
        qc_subject = arguments.qc_subject
        # Script name without its extension. Note that `str.strip('.py')` must not be used here:
        # it strips any leading/trailing '.', 'p' or 'y' character instead of removing the suffix.
        process = os.path.splitext(os.path.basename(__file__))[0]
        generate_qc(fname_straight,
                    args=arguments,
                    path_qc=os.path.abspath(path_qc),
                    dataset=qc_dataset,
                    subject=qc_subject,
                    process=process)

    display_viewer_syntax([fname_straight], verbose=verbose)
Example #15
0
def main(argv=None):
    """
    Entry point: run the single image operation selected on the command line.

    All input files are opened first. Apart from '-concat' and '-omc', every operation acts on the first
    input image only. Operations that produce images write them to '-o' (or over the first input file name
    when '-o' is not given); '-mcs' and '-split' instead write one suffixed file per output image.

    :param argv: list of command-line arguments forwarded to ``parser.parse_args``
    :return: None
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_loglevel(verbose=verbose)

    # initializations
    output_type = None
    dim_list = ['x', 'y', 'z', 't']

    fname_in = arguments.i

    im_in_list = [Image(fname) for fname in fname_in]
    # '-omc' is dispatched on truthiness further down, so the guard uses truthiness as well: an
    # `is None` test would never fire if the flag defaults to False.
    if len(im_in_list) > 1 and arguments.concat is None and not arguments.omc:
        parser.error(
            "Multi-image input is only supported for the '-concat' and '-omc' arguments."
        )

    # Apply initialization steps to all input images first
    if arguments.set_sform_to_qform:
        for im in im_in_list:
            im.set_sform_to_qform()
    elif arguments.set_qform_to_sform:
        for im in im_in_list:
            im.set_qform_to_sform()

    # Most sct_image options don't accept multi-image input, so here we simply separate out the first image
    # TODO: Extend the options so that they iterate through the list of images (to support multi-image input)
    im_in = im_in_list[0]

    # Stays None when '-o' was not given; a default is substituted just before writing
    fname_out = arguments.o

    # Run command
    # Arguments are sorted alphabetically (not according to the usage order)
    if arguments.concat is not None:
        dim = arguments.concat
        assert dim in dim_list
        dim = dim_list.index(dim)
        im_out = [concat_data(im_in_list, dim)]

    elif arguments.copy_header is not None:
        if fname_out is None:
            raise ValueError("Need to specify output image with -o!")
        # Header comes from the input image, data from the '-copy-header' image
        im_dest = Image(arguments.copy_header)
        im_dest_new = im_in.copy()
        im_dest_new.data = im_dest.data.copy()
        im_dest_new.absolutepath = im_dest.absolutepath
        im_out = [im_dest_new]

    elif arguments.display_warp:
        visualize_warp(im_warp=im_in, im_grid=None, step=3, rm_tmp=True)
        im_out = None

    elif arguments.getorient:
        orient = im_in.orientation
        im_out = None

    elif arguments.keep_vol is not None:
        index_vol = [int(vol) for vol in arguments.keep_vol.split(',')]
        im_out = [remove_vol(im_in, index_vol, todo='keep')]

    elif arguments.mcs:
        if len(im_in.data.shape) != 5:
            # parser.error() prints the message and exits
            parser.error('ERROR: -mcs input need to be a multi-component image')
        im_out = multicomponent_split(im_in)

    elif arguments.omc:
        ref_shape = im_in_list[0].data.shape
        if any(im.data.shape != ref_shape for im in im_in_list):
            parser.error('ERROR: -omc inputs need to have all the same shapes')
        im_out = [multicomponent_merge(im_in_list=im_in_list)]

    elif arguments.pad is not None:
        ndims = len(im_in.data.shape)
        if ndims != 3:
            printv('ERROR: you need to specify a 3D input file.', 1, 'error')
            return

        pad_arguments = arguments.pad.split(',')
        if len(pad_arguments) != 3:
            printv('ERROR: you need to specify 3 padding values.', 1, 'error')
            # Same early exit as the dimension check above; avoids an unrelated unpacking error below
            return

        padx, pady, padz = (int(value) for value in pad_arguments)
        im_out = [
            pad_image(im_in,
                      pad_x_i=padx,
                      pad_x_f=padx,
                      pad_y_i=pady,
                      pad_y_f=pady,
                      pad_z_i=padz,
                      pad_z_f=padz)
        ]

    elif arguments.pad_asym is not None:
        ndims = len(im_in.data.shape)
        if ndims != 3:
            printv('ERROR: you need to specify a 3D input file.', 1, 'error')
            return

        pad_arguments = arguments.pad_asym.split(',')
        if len(pad_arguments) != 6:
            printv('ERROR: you need to specify 6 padding values.', 1, 'error')
            # Same early exit as the dimension check above; avoids an unrelated unpacking error below
            return

        padxi, padxf, padyi, padyf, padzi, padzf = (int(value) for value in pad_arguments)
        im_out = [
            pad_image(im_in,
                      pad_x_i=padxi,
                      pad_x_f=padxf,
                      pad_y_i=padyi,
                      pad_y_f=padyf,
                      pad_z_i=padzi,
                      pad_z_f=padzf)
        ]

    elif arguments.remove_vol is not None:
        index_vol = [int(vol) for vol in arguments.remove_vol.split(',')]
        im_out = [remove_vol(im_in, index_vol, todo='remove')]

    elif arguments.setorient is not None:
        printv(im_in.absolutepath)
        im_out = [change_orientation(im_in, arguments.setorient)]

    elif arguments.setorient_data is not None:
        im_out = [
            change_orientation(im_in, arguments.setorient_data, data_only=True)
        ]

    elif arguments.header is not None:
        header = im_in.header
        # Necessary because of https://github.com/nipy/nibabel/issues/480#issuecomment-239227821
        im_file = nib.load(im_in.absolutepath)
        header.structarr['scl_slope'] = im_file.dataobj.slope
        header.structarr['scl_inter'] = im_file.dataobj.inter
        printv(create_formatted_header_string(header=header,
                                              output_format=arguments.header),
               verbose=verbose)
        im_out = None

    elif arguments.split is not None:
        dim = arguments.split
        assert dim in dim_list
        dim = dim_list.index(dim)
        im_out = split_data(im_in, dim)

    elif arguments.type is not None:
        # The conversion itself happens when saving, through the dtype argument
        output_type = arguments.type
        im_out = [im_in]

    elif arguments.to_fsl is not None:
        space_files = arguments.to_fsl
        if not 1 <= len(space_files) <= 2:
            parser.error('ERROR: -to-fsl expects 1 or 2 arguments')
        spaces = [Image(s) for s in space_files]
        if len(spaces) < 2:
            spaces.append(None)
        im_out = [displacement_to_abs_fsl(im_in, spaces[0], spaces[1])]

    # If these arguments are used standalone, simply pass the input image to the output (the affines were set earlier)
    elif arguments.set_sform_to_qform or arguments.set_qform_to_sform:
        im_out = [im_in]

    else:
        parser.error(
            'ERROR: you need to specify an operation to do on the input image'
        )

    # in case fname_out is not defined, use first element of input file name list
    if fname_out is None:
        fname_out = fname_in[0]

    # Write output
    if im_out is not None:
        printv('Generate output files...', verbose)
        # if only one output
        if len(im_out) == 1 and arguments.split is None:
            im_out[0].save(fname_out, dtype=output_type, verbose=verbose)
            display_viewer_syntax([fname_out], verbose=verbose)
        if arguments.mcs:
            # use input file name and add _X, _Y _Z. Keep the same extension
            l_fname_out = []
            for i_dim in range(3):
                l_fname_out.append(
                    add_suffix(fname_out or fname_in[0],
                               '_' + dim_list[i_dim].upper()))
                im_out[i_dim].save(l_fname_out[i_dim], verbose=verbose)
            # Show the files that were actually written (a list), not the bare base name string
            display_viewer_syntax(l_fname_out)
        if arguments.split is not None:
            # use input file name and add _"DIM+NUMBER". Keep the same extension
            l_fname_out = []
            for i, im in enumerate(im_out):
                l_fname_out.append(
                    add_suffix(fname_out or fname_in[0],
                               '_' + dim_list[dim].upper() + str(i).zfill(4)))
                im.save(l_fname_out[i])
            display_viewer_syntax(l_fname_out)

    elif arguments.getorient:
        printv(orient)

    elif arguments.display_warp:
        printv('Warping grid generated.', verbose, 'info')
    def apply(self):
        """
        Apply the configured list of transformations to the source image, resampling into the destination space.

        Configuration is read from the instance: ``input_filename``, ``list_warp``, ``list_warpinv``,
        ``output_filename``, ``fname_dest``, ``interp``, ``crop``, ``remove_temp_files`` and ``verbose``.
        Images with a single time point are transformed in one call to ``isct_antsApplyTransforms``; otherwise
        the image is split along T in a temporary folder, each volume is transformed, and the results are merged.

        Side effect: when ``self.interp`` is 'label' it is overwritten with 'nn'.

        :return: None
        :raises ValueError: if a .nii/.nii.gz warping field does not have the 'vector' intent code
        :raises NotImplementedError: if 'label' interpolation is requested for an image with several time points
        """
        # Initialization
        fname_src = self.input_filename  # source image (moving)
        list_warp = self.list_warp  # list of warping fields
        fname_out = self.output_filename  # output
        fname_dest = self.fname_dest  # destination image (fix)
        verbose = self.verbose
        remove_temp_files = self.remove_temp_files
        crop_reference = self.crop  # if = 1, put 0 everywhere around warping field, if = 2, real crop

        # Labels are dilated, warped with nearest-neighbour, then reduced back to points (see below)
        islabel = self.interp == 'label'
        if islabel:
            self.interp = 'nn'

        interp = get_interpolation('isct_antsApplyTransforms', self.interp)

        # Parse list of warping fields
        printv('\nParse list of warping fields...', verbose)
        # One entry per transformation: [path], or ['-i', path] when it has to be inverted
        fname_warp_list_invert = []
        for path_warp in list_warp:
            if path_warp in self.list_warpinv:
                fname_warp_list_invert.append(['-i', path_warp])
            else:
                fname_warp_list_invert.append([path_warp])
            if path_warp.endswith((".nii", ".nii.gz")) \
                    and Image(path_warp).header.get_intent()[0] != 'vector':
                raise ValueError(
                    "Displacement field in {} is invalid: should be encoded"
                    " in a 5D file with vector intent code"
                    " (see https://nifti.nimh.nih.gov/pub/dist/src/niftilib/nifti1.h)"
                    .format(path_warp))

        # need to check if last warping field is an affine transfo
        _, _, ext_last_warp = extract_fname(fname_warp_list_invert[-1][-1])
        is_last_affine = ext_last_warp in ['.txt', '.mat']

        # check if destination file is 3d
        # check_dim(fname_dest, dim_lst=[3]) # PR 2598: we decided to skip this line.

        # N.B. Here we take the inverse of the warp list, because sct_WarpImageMultiTransform concatenates in the
        # reverse order. The per-transformation entries are flattened into one argument list at the same time.
        fname_warp_list_invert = [
            token
            for entry in reversed(fname_warp_list_invert)
            for token in entry
        ]

        # Extract file name and extension
        _, file_src, ext_src = extract_fname(fname_src)
        _, file_dest, ext_dest = extract_fname(fname_dest)

        # Default output name: <source>_reg<ext>, written to the user's current directory
        if fname_out == '':
            fname_out = os.path.join('', file_src + '_reg' + ext_src)

        # Get dimensions of data
        printv('\nGet dimensions of data...', verbose)
        img_src = Image(fname_src)
        nx, ny, nz, nt, px, py, pz, pt = img_src.dim
        printv(
            '  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' +
            str(nt), verbose)

        # if 3d
        if nt == 1:
            # Apply transformation
            printv('\nApply transformation...', verbose)
            dim = '2' if nz in [0, 1] else '3'
            # if labels, dilate before resampling
            if islabel:
                printv("\nDilate labels before warping...")
                path_tmp = tmp_create(basename="apply_transfo")
                fname_dilated_labels = os.path.join(path_tmp,
                                                    "dilated_data.nii")
                # dilate points
                dilate(Image(fname_src), 4, 'ball').save(fname_dilated_labels)
                fname_src = fname_dilated_labels

            printv(
                "\nApply transformation and resample to destination space...",
                verbose)
            run_proc([
                'isct_antsApplyTransforms', '-d', dim, '-i', fname_src, '-o',
                fname_out, '-t'
            ] + fname_warp_list_invert + ['-r', fname_dest] + interp,
                     is_sct_binary=True)

        # if 4d, loop across the T dimension
        else:
            if islabel:
                raise NotImplementedError

            dim = '4'
            path_tmp = tmp_create(basename="apply_transfo")

            # convert to nifti into temp folder
            printv('\nCopying input data to tmp folder and convert to nii...',
                   verbose)
            img_src.save(os.path.join(path_tmp, "data.nii"))
            copy(fname_dest, os.path.join(path_tmp, file_dest + ext_dest))
            fname_warp_list_tmp = []
            for fname_warp in list_warp:
                _, file_warp, ext_warp = extract_fname(fname_warp)
                copy(fname_warp, os.path.join(path_tmp, file_warp + ext_warp))
                fname_warp_list_tmp.append(file_warp + ext_warp)
            # NOTE(review): unlike the 3D path, this list carries no '-i' flags, so transformations listed
            # in list_warpinv do not appear to be inverted here -- confirm whether this is intended.
            fname_warp_list_invert_tmp = fname_warp_list_tmp[::-1]

            curdir = os.getcwd()
            os.chdir(path_tmp)
            try:
                # split along T dimension
                printv('\nSplit along T dimension...', verbose)

                im_dat = Image('data.nii')
                im_header = im_dat.hdr
                data_split_list = sct_image.split_data(im_dat, 3)
                for im in data_split_list:
                    im.save()

                # apply transfo
                printv('\nApply transformation to each 3D volume...', verbose)
                for it in range(nt):
                    file_data_split = 'data_T' + str(it).zfill(4) + '.nii'
                    file_data_split_reg = 'data_reg_T' + str(it).zfill(4) + '.nii'

                    run_proc([
                        'isct_antsApplyTransforms',
                        '-d',
                        '3',
                        '-i',
                        file_data_split,
                        '-o',
                        file_data_split_reg,
                        '-t',
                    ] + fname_warp_list_invert_tmp + [
                        '-r',
                        file_dest + ext_dest,
                    ] + interp,
                             verbose,
                             is_sct_binary=True)

                # Merge files back
                printv('\nMerge file back...', verbose)
                import glob
                _, name_out, ext_out = extract_fname(fname_out)
                # Sorted so that volumes are concatenated in T order
                fname_list = glob.glob('data_reg_T*.nii')
                fname_list.sort()
                im_list = [Image(fname) for fname in fname_list]
                im_out = sct_image.concat_data(im_list, 3, im_header['pixdim'])
                im_out.save(name_out + ext_out)
            finally:
                # Always return to the caller's directory, even when a step above fails
                os.chdir(curdir)

            generate_output_file(os.path.join(path_tmp, name_out + ext_out),
                                 fname_out)
            # Delete temporary folder if specified
            if remove_temp_files:
                printv('\nRemove temporary files...', verbose)
                rmtree(path_tmp, verbose=verbose)

        # Copy affine matrix from destination space to make sure qform/sform are the same
        printv(
            "Copy affine matrix from destination space to make sure qform/sform are the same.",
            verbose)
        im_src_reg = Image(fname_out)
        im_src_reg.copy_qform_from_ref(Image(fname_dest))
        im_src_reg.save(
            verbose=0
        )  # set verbose=0 to avoid warning message about rewriting file

        if islabel:
            printv(
                "\nTake the center of mass of each registered dilated labels..."
            )
            labeled_img = cubic_to_point(im_src_reg)
            labeled_img.save(path=fname_out)
            # path_tmp was created by the label dilation step of the 3D branch
            if remove_temp_files:
                printv('\nRemove temporary files...', verbose)
                rmtree(path_tmp, verbose=verbose)

        # Crop the resulting image using dimensions from the warping field
        warping_field = fname_warp_list_invert[-1]
        # If the last transformation is not an affine transfo, we need to compute the matrix space of the concatenated
        # warping field
        if not is_last_affine and crop_reference in [1, 2]:
            printv('Last transformation is not affine.')
            # Extract only the first ndim of the warping field
            img_warp = Image(warping_field)
            if dim == '2':
                img_warp_ndim = Image(img_src.data[:, :], hdr=img_warp.hdr)
            else:  # dim is '3' or '4'
                img_warp_ndim = Image(img_src.data[:, :, :],
                                      hdr=img_warp.hdr)
            # Set zero to everything outside the warping field
            cropper = ImageCropper(Image(fname_out))
            cropper.get_bbox_from_ref(img_warp_ndim)
            if crop_reference == 1:
                printv(
                    'Cropping strategy is: keep same matrix size, put 0 everywhere around warping field'
                )
                img_out = cropper.crop(background=0)
            else:  # crop_reference == 2
                printv(
                    'Cropping strategy is: crop around warping field (the size of warping field will '
                    'change)')
                img_out = cropper.crop()
            img_out.save(fname_out)

        display_viewer_syntax([fname_dest, fname_out], verbose=verbose)
Example #17
0
def main(argv=None):
    """
    Entry point: apply the single mathematical operation selected on the command line to the input image.

    The input image is loaded, the first matching operation flag is applied to its data array, and the result
    is saved to '-o' using the header of the input file. The similarity operations ('-mi', '-minorm', '-corr')
    produce no image; they hand the output file name to ``compute_similarity`` instead.

    :param argv: list of command-line arguments forwarded to ``parser.parse_args``
    :return: None
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    dim_list = ['x', 'y', 'z', 't']

    fname_in = arguments.i
    fname_out = arguments.o
    output_type = arguments.type

    # Open file(s)
    im = Image(fname_in)
    data = im.data  # 3d or 4d numpy array
    dim = im.dim  # entries 4.. are used below as voxel sizes to rescale sigmas

    # run command
    if arguments.otsu is not None:
        param = arguments.otsu
        data_out = sct_math.otsu(data, param)

    elif arguments.adap is not None:
        param = arguments.adap
        data_out = sct_math.adap(data, param[0], param[1])

    elif arguments.otsu_median is not None:
        param = arguments.otsu_median
        data_out = sct_math.otsu_median(data, param[0], param[1])

    elif arguments.thr is not None:
        param = arguments.thr
        data_out = sct_math.threshold(data, param)

    elif arguments.percent is not None:
        param = arguments.percent
        data_out = sct_math.perc(data, param)

    elif arguments.bin is not None:
        bin_thr = arguments.bin
        data_out = sct_math.binarize(data, bin_thr=bin_thr)

    elif arguments.add is not None:
        data2 = get_data_or_scalar(arguments.add, data)
        data_concat = sct_math.concatenate_along_4th_dimension(data, data2)
        data_out = np.sum(data_concat, axis=3)

    elif arguments.sub is not None:
        data2 = get_data_or_scalar(arguments.sub, data)
        data_out = data - data2

    elif arguments.laplacian is not None:
        sigmas = arguments.laplacian
        if len(sigmas) == 1:
            # Repeat the single value for every dimension (the value itself, not the 1-element list)
            sigmas = [sigmas[0]] * len(data.shape)
        elif len(sigmas) != len(data.shape):
            # parser.error() prints the message and exits
            parser.error(
                'ERROR: -laplacian need the same number of inputs as the number of image dimension OR only one input'
            )
        # adjust sigma based on voxel size
        sigmas = [sigmas[i] / dim[i + 4] for i in range(3)]
        # smooth data
        data_out = sct_math.laplacian(data, sigmas)

    elif arguments.mul is not None:
        data2 = get_data_or_scalar(arguments.mul, data)
        data_concat = sct_math.concatenate_along_4th_dimension(data, data2)
        data_out = np.prod(data_concat, axis=3)

    elif arguments.div is not None:
        data2 = get_data_or_scalar(arguments.div, data)
        data_out = np.divide(data, data2)

    elif arguments.mean is not None:
        axis = dim_list.index(arguments.mean)
        if axis + 1 > len(np.shape(data)):  # in case input volume is 3d and dim=t
            data = data[..., np.newaxis]
        data_out = np.mean(data, axis)

    elif arguments.rms is not None:
        axis = dim_list.index(arguments.rms)
        if axis + 1 > len(np.shape(data)):  # in case input volume is 3d and dim=t
            data = data[..., np.newaxis]
        data_out = np.sqrt(np.mean(np.square(data.astype(float)), axis))

    elif arguments.std is not None:
        axis = dim_list.index(arguments.std)
        if axis + 1 > len(np.shape(data)):  # in case input volume is 3d and dim=t
            data = data[..., np.newaxis]
        data_out = np.std(data, axis, ddof=1)

    elif arguments.smooth is not None:
        sigmas = arguments.smooth
        if len(sigmas) == 1:
            sigmas = [sigmas[0]] * len(data.shape)
        elif len(sigmas) != len(data.shape):
            parser.error(
                'ERROR: -smooth need the same number of inputs as the number of image dimension OR only one input'
            )
        # adjust sigma based on voxel size
        sigmas = [sigmas[i] / dim[i + 4] for i in range(3)]
        # smooth data
        data_out = sct_math.smooth(data, sigmas)

    elif arguments.dilate is not None:
        if arguments.shape in ['disk', 'square'] and arguments.dim is None:
            parser.error(
                'ERROR: -dim is required for -dilate with 2D morphological kernel'
            )
        data_out = sct_math.dilate(data,
                                   size=arguments.dilate,
                                   shape=arguments.shape,
                                   dim=arguments.dim)

    elif arguments.erode is not None:
        if arguments.shape in ['disk', 'square'] and arguments.dim is None:
            parser.error(
                'ERROR: -dim is required for -erode with 2D morphological kernel'
            )
        data_out = sct_math.erode(data,
                                  size=arguments.erode,
                                  shape=arguments.shape,
                                  dim=arguments.dim)

    elif arguments.denoise is not None:
        # parse denoising arguments, given as comma-separated "p=<int>" / "b=<int>" items
        p, b = 1, 5  # default arguments
        for item in arguments.denoise.split(","):
            if 'p' in item:
                p = int(item.split('=')[1])
            if 'b' in item:
                b = int(item.split('=')[1])
        data_out = sct_math.denoise_nlmeans(data,
                                            patch_radius=p,
                                            block_radius=b)

    elif arguments.symmetrize is not None:
        # Average the volume with its mirror image along the first axis
        data_out = (data + data[::-1, :, :]) / float(2)

    elif arguments.mi is not None:
        # input 1 = from flag -i --> im
        # input 2 = from flag -mi
        im_2 = Image(arguments.mi)
        compute_similarity(im,
                           im_2,
                           fname_out,
                           metric='mi',
                           metric_full='Mutual information',
                           verbose=verbose)
        data_out = None

    elif arguments.minorm is not None:
        # input 1 = from flag -i --> im
        # input 2 = from flag -minorm
        im_2 = Image(arguments.minorm)
        compute_similarity(im,
                           im_2,
                           fname_out,
                           metric='minorm',
                           metric_full='Normalized Mutual information',
                           verbose=verbose)
        data_out = None

    elif arguments.corr is not None:
        # input 1 = from flag -i --> im
        # input 2 = from flag -corr
        im_2 = Image(arguments.corr)
        compute_similarity(im,
                           im_2,
                           fname_out,
                           metric='corr',
                           metric_full='Pearson correlation coefficient',
                           verbose=verbose)
        data_out = None

    # if no flag is set
    else:
        parser.error(
            'ERROR: you need to specify an operation to do on the input image'
        )

    # TODO: case of multiple outputs
    if data_out is not None:
        # Write output
        nii_out = Image(fname_in)  # use header of input file
        nii_out.data = data_out
        nii_out.save(fname_out, dtype=output_type)
        # display message
        display_viewer_syntax([fname_out], verbose=verbose)
    else:
        printv('\nDone! File created: ' + fname_out, verbose, 'info')
Example #18
0
def main(argv=None):
    """
    Run the vertebral labeling pipeline (reported to the QC module as process 'sct_label_vertebrae').

    Steps performed, in order:
      1. Copy the input image (arguments.i) and segmentation (arguments.s) into a temporary folder and work
         from there.
      2. Straighten the image, reusing warping fields found in the launch directory when the cache signature
         is still valid, otherwise by calling sct_straighten_spinalcord.
      3. Resample the straightened image to 0.5mm isotropic, then straighten and threshold the segmentation.
      4. Label the straightened segmentation, either from a disc label file (arguments.discfile) or from an
         initial label (initz / initcenter / initlabel / automatic C2-C3 detection) given to
         vertebral_detection().
      5. Warp the labeled segmentation and the disc labels back to native space; optionally clean the labels.
      6. Write outputs and warping fields to the output folder and optionally generate a QC report.

    :param argv: List of command-line arguments (program name excluded). When None, argparse reads
        sys.argv[1:].
    :return: None
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_loglevel(verbose=verbose)

    # Paths are made absolute up front because the working directory changes below.
    fname_in = os.path.abspath(arguments.i)
    fname_seg = os.path.abspath(arguments.s)
    contrast = arguments.c
    path_template = os.path.abspath(arguments.t)
    scale_dist = arguments.scale_dist
    path_output = os.path.abspath(arguments.ofolder)
    fname_disc = arguments.discfile
    if fname_disc is not None:
        fname_disc = os.path.abspath(fname_disc)
    initz = arguments.initz
    initcenter = arguments.initcenter
    fname_initlabel = arguments.initlabel
    if fname_initlabel is not None:
        fname_initlabel = os.path.abspath(fname_initlabel)
    remove_temp_files = arguments.r
    clean_labels = arguments.clean_labels

    path_tmp = tmp_create(basename="label_vertebrae")

    # Copying input data to tmp folder
    printv('\nCopying input data to tmp folder...', verbose)
    Image(fname_in).save(os.path.join(path_tmp, "data.nii"))
    Image(fname_seg).save(os.path.join(path_tmp, "segmentation.nii"))

    # Files produced by the straightening step, reused from / exported to other folders
    fnames_straightening = ("warp_curve2straight.nii.gz",
                            "warp_straight2curve.nii.gz",
                            "straight_ref.nii.gz")

    # Go to temp folder. The try/finally guarantees that the caller's working directory is restored even if
    # a step fails (including the sys.exit() below), which matters when main() is called from Python code.
    curdir = os.getcwd()
    os.chdir(path_tmp)
    try:
        # Straighten spinal cord
        printv('\nStraighten spinal cord...', verbose)
        # check if warp_curve2straight and warp_straight2curve already exist (i.e. no need to do it another time)
        # NOTE(review): the cache is looked up in the launch directory (curdir) but saved in path_output below;
        # the two only coincide when the output folder is the launch directory -- confirm this is intended.
        cache_sig = cache_signature(input_files=[fname_in, fname_seg], )
        fname_cache = "straightening.cache"
        if (cache_valid(os.path.join(curdir, fname_cache), cache_sig)
                and all(os.path.isfile(os.path.join(curdir, fname))
                        for fname in fnames_straightening)):
            # if they exist, copy them into current folder
            printv('Reusing existing warping field which seems to be valid',
                   verbose, 'warning')
            for fname in fnames_straightening:
                copy(os.path.join(curdir, fname), fname)
            # apply straightening
            run_proc([
                'sct_apply_transfo', '-i', 'data.nii', '-w',
                'warp_curve2straight.nii.gz', '-d', 'straight_ref.nii.gz',
                '-o', 'data_straight.nii'
            ])
        else:
            sct_straighten_spinalcord.main(argv=[
                '-i',
                'data.nii',
                '-s',
                'segmentation.nii',
                '-r',
                str(remove_temp_files),
                '-v',
                '0',
            ])
            cache_save(os.path.join(path_output, fname_cache), cache_sig)

        # resample to 0.5mm isotropic to match template resolution
        printv('\nResample to 0.5mm isotropic...', verbose)
        run_proc([
            'sct_resample', '-i', 'data_straight.nii', '-mm', '0.5x0.5x0.5',
            '-x', 'linear', '-o', 'data_straightr.nii'
        ],
                 verbose=verbose)

        # Apply straightening to segmentation
        # N.B. Output is RPI
        printv('\nApply straightening to segmentation...', verbose)
        sct_apply_transfo.main([
            '-i', 'segmentation.nii', '-d', 'data_straightr.nii', '-w',
            'warp_curve2straight.nii.gz', '-o', 'segmentation_straight.nii',
            '-x', 'linear', '-v', '0'
        ])

        # Threshold segmentation at 0.5
        img = Image('segmentation_straight.nii')
        img.data = threshold(img.data, 0.5)
        img.save()

        # If disc label file is provided, label vertebrae using that file instead of automatically
        if fname_disc:
            # Apply straightening to disc-label
            printv('\nApply straightening to disc labels...', verbose)
            # The command is passed as a list (like the other run_proc calls of this function) so that a
            # user-provided path containing spaces is kept as a single argument.
            run_proc([
                'sct_apply_transfo', '-i', fname_disc, '-d',
                'data_straightr.nii', '-w', 'warp_curve2straight.nii.gz',
                '-o', 'labeldisc_straight.nii.gz', '-x', 'label'
            ],
                     verbose=verbose)
            label_vert('segmentation_straight.nii',
                       'labeldisc_straight.nii.gz',
                       verbose=1)

        else:
            printv('\nCreate label to identify disc...', verbose)
            fname_labelz = os.path.join(path_tmp, 'labelz.nii.gz')
            if initcenter is not None:
                # find z centered in FOV
                nii = Image('segmentation.nii').change_orientation("RPI")
                nx, ny, nz, nt, px, py, pz, pt = nii.dim
                z_center = round(nz / 2)
                # initcenter takes precedence over initz: it is converted into a (z, value) pair
                initz = [z_center, initcenter]
            if initz is not None:
                im_label = create_labels_along_segmentation(
                    Image('segmentation.nii'), [tuple(initz)])
                im_label.save(fname_labelz)
            elif fname_initlabel is not None:
                Image(fname_initlabel).save(fname_labelz)
            else:
                # automatically finds C2-C3 disc
                im_data = Image('data.nii')
                im_seg = Image('segmentation.nii')
                # because verbose is also used for keeping temp files
                verbose_detect_c2c3 = 0 if remove_temp_files else 2
                im_label_c2c3 = detect_c2c3(im_data,
                                            im_seg,
                                            contrast,
                                            verbose=verbose_detect_c2c3)
                ind_label = np.where(im_label_c2c3.data)
                if np.size(ind_label) == 0:
                    printv(
                        'Automatic C2-C3 detection failed. Please provide manual label with sct_label_utils',
                        1, 'error')
                    sys.exit(1)
                # every detected voxel is given the value 3
                im_label_c2c3.data[ind_label] = 3
                im_label_c2c3.save(fname_labelz)

            # dilate label so it is not lost when applying warping
            dilate(Image(fname_labelz), 3, 'ball').save(fname_labelz)

            # Apply straightening to z-label
            printv('\nAnd apply straightening to label...', verbose)
            sct_apply_transfo.main([
                '-i', 'labelz.nii.gz', '-d', 'data_straightr.nii', '-w',
                'warp_curve2straight.nii.gz', '-o', 'labelz_straight.nii.gz',
                '-x', 'nn', '-v', '0'
            ])
            # get z value and disk value to initialize labeling
            printv('\nGet z and disc values from straight label...', verbose)
            init_disc = get_z_and_disc_values_from_label(
                'labelz_straight.nii.gz')
            printv('.. ' + str(init_disc), verbose)

            # apply laplacian filtering
            if arguments.laplacian:
                printv('\nApply Laplacian filter...', verbose)
                img = Image("data_straightr.nii")

                # apply std dev to each axis of the image
                sigmas = [1 for i in range(len(img.data.shape))]

                # adjust sigma based on voxel size (img.dim[4:7] are px, py, pz)
                sigmas = [sigmas[i] / img.dim[i + 4] for i in range(3)]

                # smooth data
                img.data = laplacian(img.data, sigmas)
                img.save()

            # detect vertebral levels on straight spinal cord
            init_disc[1] = init_disc[1] - 1
            vertebral_detection('data_straightr.nii',
                                'segmentation_straight.nii',
                                contrast,
                                arguments.param,
                                init_disc=init_disc,
                                verbose=verbose,
                                path_template=path_template,
                                path_output=path_output,
                                scale_dist=scale_dist)

        # un-straighten labeled spinal cord
        printv('\nUn-straighten labeling...', verbose)
        sct_apply_transfo.main([
            '-i', 'segmentation_straight_labeled.nii', '-d',
            'segmentation.nii', '-w', 'warp_straight2curve.nii.gz', '-o',
            'segmentation_labeled.nii', '-x', 'nn', '-v', '0'
        ])

        if clean_labels >= 1:
            printv('\nCleaning labeled segmentation:', verbose)
            im_labeled_seg = Image('segmentation_labeled.nii')
            im_seg = Image('segmentation.nii')
            if clean_labels >= 2:
                printv('  filling in missing label voxels ...', verbose)
                expand_labels(im_labeled_seg)
            printv('  removing labeled voxels outside segmentation...',
                   verbose)
            crop_labels(im_labeled_seg, im_seg)
            printv('Done cleaning.', verbose)
            im_labeled_seg.save()

        # label discs
        printv('\nLabel discs...', verbose)
        printv('\nUn-straighten labeled discs...', verbose)
        run_proc(
            'sct_apply_transfo -i %s -d %s -w %s -o %s -x %s' %
            ('segmentation_straight_labeled_disc.nii', 'segmentation.nii',
             'warp_straight2curve.nii.gz', 'segmentation_labeled_disc.nii',
             'label'),
            verbose=verbose,
            is_sct_binary=True,
        )
    finally:
        # come back
        os.chdir(curdir)

    # Generate output files
    path_seg, file_seg, ext_seg = extract_fname(fname_seg)
    fname_seg_labeled = os.path.join(path_output,
                                     file_seg + '_labeled' + ext_seg)
    printv('\nGenerate output files...', verbose)
    generate_output_file(os.path.join(path_tmp, "segmentation_labeled.nii"),
                         fname_seg_labeled)
    generate_output_file(
        os.path.join(path_tmp, "segmentation_labeled_disc.nii"),
        os.path.join(path_output, file_seg + '_labeled_discs' + ext_seg))
    # copy straightening files in case subsequent SCT functions need them
    for fname in fnames_straightening:
        generate_output_file(os.path.join(path_tmp, fname),
                             os.path.join(path_output, fname),
                             verbose=verbose)

    # Remove temporary files
    if remove_temp_files == 1:
        printv('\nRemove temporary files...', verbose)
        rmtree(path_tmp)

    # Generate QC report
    if arguments.qc is not None:
        generate_qc(fname_in,
                    fname_seg=fname_seg_labeled,
                    # argv is None when main() is called without arguments; argparse then parsed
                    # sys.argv[1:], so report that list instead of None.
                    args=sys.argv[1:] if argv is None else argv,
                    path_qc=os.path.abspath(arguments.qc),
                    dataset=arguments.qc_dataset,
                    subject=arguments.qc_subject,
                    process='sct_label_vertebrae')

    display_viewer_syntax([fname_in, fname_seg_labeled],
                          colormaps=['', 'subcortical'],
                          opacities=['1', '0.5'])
Example #19
0
def moco_wrapper(param):
    """
    Wrapper that performs motion correction.

    The input data is converted into a temporary folder, split along T, merged into groups of
    param.group_size volumes whose means are registered to the first group (for diffusion data the b=0
    volumes are registered separately), and the resulting transformations are applied to every volume.
    Outputs (motion-corrected data, temporal means and, optionally, motion parameters) are written to
    param.path_out.

    Side effects on param: fname_mask is replaced by the name of the copy in the temporary folder (when a
    mask is given), is_sagittal and suffix_mat are set from the data orientation, and group_size is forced
    to 1 for sagittal data.

    :param param: ParamMoco class
    :return: None
    :raises NotImplementedError: if param.spline_fitting is set.
    """
    file_data = 'data.nii'  # corresponds to the full input data (e.g. dmri or fmri)
    file_data_dirname, file_data_basename, file_data_ext = extract_fname(
        file_data)
    file_b0 = 'b0.nii'
    file_datasub = 'datasub.nii'  # corresponds to the full input data minus the b=0 scans (if param.is_diffusion=True)
    file_datasubgroup = 'datasub-groups.nii'  # concatenation of the average of each file_datasub
    file_mask = 'mask.nii'
    file_moco_params_csv = 'moco_params.tsv'
    file_moco_params_x = 'moco_params_x.nii.gz'
    file_moco_params_y = 'moco_params_y.nii.gz'
    ext_data = '.nii.gz'  # workaround "too many open files" by slurping the data
    # TODO: check if .nii can be used
    mat_final = 'mat_final/'

    # Start timer
    start_time = time.time()

    printv('\nInput parameters:', param.verbose)
    printv('  Input file ............ ' + param.fname_data, param.verbose)
    printv('  Group size ............ {}'.format(param.group_size),
           param.verbose)

    path_tmp = tmp_create(basename="moco")

    # Copying input data to tmp folder
    printv('\nCopying input data to tmp folder and convert to nii...',
           param.verbose)
    convert(param.fname_data, os.path.join(path_tmp, file_data))
    if param.fname_mask != '':
        convert(param.fname_mask,
                os.path.join(path_tmp, file_mask),
                verbose=param.verbose)
        # Update field in param (because used later in another function, and param class will be passed)
        param.fname_mask = file_mask

    # Build absolute output path and go to tmp folder. The try/finally guarantees that the caller's working
    # directory is restored when a step fails, and that we have left the tmp folder before deleting it.
    curdir = os.getcwd()
    path_out_abs = os.path.abspath(param.path_out)
    os.chdir(path_tmp)
    try:
        # Get dimensions of data
        printv('\nGet dimensions of data...', param.verbose)
        im_data = Image(file_data)
        nx, ny, nz, nt, px, py, pz, pt = im_data.dim
        printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz),
               param.verbose)

        # Get orientation (third letter of the orientation string decides sagittal vs. axial processing)
        printv('\nData orientation: ' + im_data.orientation, param.verbose)
        if im_data.orientation[2] in 'LR':
            param.is_sagittal = True
            printv('  Treated as sagittal', param.verbose)
        elif im_data.orientation[2] in 'IS':
            param.is_sagittal = False
            printv('  Treated as axial', param.verbose)
        else:
            param.is_sagittal = False
            printv(
                'WARNING: Orientation seems to be neither axial nor sagittal. Treated as axial.',
                param.verbose)

        printv(
            "\nSet suffix of transformation file name, which depends on the orientation:",
            param.verbose)
        if param.is_sagittal:
            param.suffix_mat = '0GenericAffine.mat'
            printv(
                "Orientation is sagittal, suffix is '{}'. The image is split across the R-L direction, and the "
                "estimated transformation is a 2D affine transfo.".format(
                    param.suffix_mat), param.verbose)
        else:
            param.suffix_mat = 'Warp.nii.gz'
            printv(
                "Orientation is axial, suffix is '{}'. The estimated transformation is a 3D warping field, which is "
                "composed of a stack of 2D Tx-Ty transformations".format(
                    param.suffix_mat), param.verbose)

        # Adjust group size in case of sagittal scan
        if param.is_sagittal and param.group_size != 1:
            printv(
                'For sagittal data group_size should be one for more robustness. Forcing group_size=1.',
                1, 'warning')
            param.group_size = 1

        if param.is_diffusion:
            # Identify b=0 and DWI images
            index_b0, index_dwi, nb_b0, nb_dwi = \
                sct_dmri_separate_b0_and_dwi.identify_b0(param.fname_bvecs, param.fname_bvals, param.bval_min,
                                                         param.verbose)

            # check if dmri and bvecs are the same size
            if nb_b0 + nb_dwi != nt:
                printv(
                    '\nERROR in ' + os.path.basename(__file__) +
                    ': Size of data (' + str(nt) + ') and size of bvecs (' +
                    str(nb_b0 + nb_dwi) +
                    ') are not the same. Check your bvecs file.\n', 1,
                    'error')
                sys.exit(2)

        # ==============================================================================================================
        # Prepare data (mean/groups...)
        # ==============================================================================================================

        # Split into T dimension
        printv('\nSplit along T dimension...', param.verbose)
        im_data_split_list = split_data(im_data, 3)
        for im in im_data_split_list:
            x_dirname, x_basename, x_ext = extract_fname(im.absolutepath)
            im.absolutepath = os.path.join(x_dirname, x_basename + ".nii.gz")
            im.save()

        if param.is_diffusion:
            # Merge and average b=0 images
            printv('\nMerge and average b=0 data...', param.verbose)
            im_b0_list = [
                im_data_split_list[index_b0[it]] for it in range(nb_b0)
            ]
            im_b0 = concat_data(im_b0_list, 3).save(file_b0, verbose=0)
            # Average across time
            im_b0.mean(dim=3).save(add_suffix(file_b0, '_mean'))

            n_moco = nb_dwi  # set number of data to perform moco on (using grouping)
            index_moco = index_dwi

        # If not a diffusion scan, we will motion-correct all volumes
        else:
            n_moco = nt
            index_moco = list(range(0, nt))

        nb_groups = n_moco // param.group_size

        # Generate groups indexes
        group_indexes = [
            index_moco[(iGroup * param.group_size):((iGroup + 1) *
                                                    param.group_size)]
            for iGroup in range(nb_groups)
        ]

        # add the remaining images to a new last group (in case the total number of image is not divisible by
        # group_size)
        nb_remaining = n_moco % param.group_size  # number of remaining images
        if nb_remaining > 0:
            nb_groups += 1
            group_indexes.append(index_moco[len(index_moco) - nb_remaining:])

        _, file_dwi_basename, file_dwi_ext = extract_fname(file_datasub)
        # Group data
        list_file_group = []
        for iGroup in sct_progress_bar(range(nb_groups),
                                       unit='iter',
                                       unit_scale=False,
                                       desc="Merge within groups",
                                       ascii=False,
                                       ncols=80):
            # concatenate images across time, within this group
            file_dwi_merge_i = file_dwi_basename + '_' + str(iGroup) + ext_data
            im_dwi_list = [
                im_data_split_list[idx] for idx in group_indexes[iGroup]
            ]
            im_dwi_out = concat_data(im_dwi_list, 3).save(file_dwi_merge_i,
                                                          verbose=0)
            # Average across time
            list_file_group.append(file_dwi_basename + '_' + str(iGroup) +
                                   '_mean' + ext_data)
            im_dwi_out.mean(dim=3).save(list_file_group[-1])

        # Merge across groups
        printv('\nMerge across groups...', param.verbose)
        im_dw_list = [Image(fname) for fname in list_file_group]
        concat_data(im_dw_list, 3).save(file_datasubgroup, verbose=0)

        # Cleanup
        del im, im_data_split_list

        # ==============================================================================================================
        # Estimate moco
        # ==============================================================================================================

        # Initialize another class instance that will be passed on to the moco() function
        param_moco = deepcopy(param)

        if param.is_diffusion:
            # Estimate moco on b0 groups
            printv(
                '\n-------------------------------------------------------------------------------',
                param.verbose)
            printv('  Estimating motion on b=0 images...', param.verbose)
            printv(
                '-------------------------------------------------------------------------------',
                param.verbose)
            param_moco.file_data = 'b0.nii'
            # Identify target image
            if index_moco[0] != 0:
                # If first DWI is not the first volume (most common), then there is a least one b=0 image before. In
                # that case select it as the target image for registration of all b=0
                param_moco.file_target = os.path.join(
                    file_data_dirname, file_data_basename + '_T' +
                    str(index_b0[index_moco[0] - 1]).zfill(4) + ext_data)
            else:
                # If first DWI is the first volume, then the target b=0 is the first b=0 from the index_b0.
                param_moco.file_target = os.path.join(
                    file_data_dirname, file_data_basename + '_T' +
                    str(index_b0[0]).zfill(4) + ext_data)
            # Run moco
            param_moco.path_out = ''
            param_moco.todo = 'estimate_and_apply'
            param_moco.mat_moco = 'mat_b0groups'
            file_mat_b0, _ = moco(param_moco)

        # Estimate moco across groups
        printv(
            '\n-------------------------------------------------------------------------------',
            param.verbose)
        printv('  Estimating motion across groups...', param.verbose)
        printv(
            '-------------------------------------------------------------------------------',
            param.verbose)
        param_moco.file_data = file_datasubgroup
        param_moco.file_target = list_file_group[
            0]  # target is the first volume (closest to the first b=0 if DWI scan)
        param_moco.path_out = ''
        param_moco.todo = 'estimate_and_apply'
        param_moco.mat_moco = 'mat_groups'
        file_mat_datasub_group, _ = moco(param_moco)

        # Spline Regularization along T
        if param.spline_fitting:
            # TODO: fix this scenario (haven't touched that code for a while-- it is probably buggy)
            raise NotImplementedError()
            # spline(mat_final, nt, nz, param.verbose, np.array(index_b0), param.plot_graph)

        # ==============================================================================================================
        # Apply moco
        # ==============================================================================================================

        # If group_size>1, assign transformation to each individual ungrouped 3d volume
        if param.group_size > 1:
            file_mat_datasub = []
            for iz in range(len(file_mat_datasub_group)):
                # duplicate by factor group_size the transformation file for each it
                #  example: [mat.Z0000T0001Warp.nii] --> [mat.Z0000T0001Warp.nii, mat.Z0000T0001Warp.nii] for
                #  group_size=2
                file_mat_datasub.append(
                    functools.reduce(operator.iconcat,
                                     [[i] * param.group_size
                                      for i in file_mat_datasub_group[iz]],
                                     []))
        else:
            file_mat_datasub = file_mat_datasub_group

        # Copy transformations to mat_final folder and rename them appropriately
        copy_mat_files(nt, file_mat_datasub, index_moco, mat_final, param)
        if param.is_diffusion:
            copy_mat_files(nt, file_mat_b0, index_b0, mat_final, param)

        # Apply moco on all dmri data
        printv(
            '\n-------------------------------------------------------------------------------',
            param.verbose)
        printv('  Apply moco', param.verbose)
        printv(
            '-------------------------------------------------------------------------------',
            param.verbose)
        param_moco.file_data = file_data
        param_moco.file_target = list_file_group[
            0]  # reference for reslicing into proper coordinate system
        param_moco.path_out = ''  # TODO not used in moco()
        param_moco.mat_moco = mat_final
        param_moco.todo = 'apply'
        file_mat_data, im_moco = moco(param_moco)

        # copy geometric information from header
        # NB: this is required because WarpImageMultiTransform in 2D mode wrongly sets pixdim(3) to "1".
        im_moco.header = im_data.header
        im_moco.save(verbose=0)

        # Average across time
        if param.is_diffusion:
            # generate b0_moco_mean and dwi_moco_mean
            args = [
                '-i', im_moco.absolutepath, '-bvec', param.fname_bvecs, '-a',
                '1', '-v', '0'
            ]
            if param.fname_bvals != '':
                # if bvals file is provided
                args += ['-bval', param.fname_bvals]
            fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean = sct_dmri_separate_b0_and_dwi.main(
                argv=args)
        else:
            fname_moco_mean = add_suffix(im_moco.absolutepath, '_mean')
            im_moco.mean(dim=3).save(fname_moco_mean)

        # Extract and output the motion parameters (doesn't work for sagittal orientation)
        if param.output_motion_param:
            printv('Extract motion parameters...', param.verbose)
            if param.is_sagittal:
                printv(
                    'Motion parameters cannot be generated for sagittal images.',
                    1, 'warning')
            else:
                files_warp_X, files_warp_Y = [], []
                moco_param = []
                for fname_warp in file_mat_data[0]:
                    # Cropping the image to keep only one voxel in the XY plane
                    im_warp = Image(fname_warp + param.suffix_mat)
                    im_warp.data = np.expand_dims(np.expand_dims(
                        im_warp.data[0, 0, :, :, :], axis=0),
                                                  axis=0)

                    # Separating the three components and saving X and Y only (Z is equal to 0 by default).
                    im_warp_XYZ = multicomponent_split(im_warp)

                    fname_warp_crop_X = fname_warp + '_crop_X_' + param.suffix_mat
                    im_warp_XYZ[0].save(fname_warp_crop_X)
                    files_warp_X.append(fname_warp_crop_X)

                    fname_warp_crop_Y = fname_warp + '_crop_Y_' + param.suffix_mat
                    im_warp_XYZ[1].save(fname_warp_crop_Y)
                    files_warp_Y.append(fname_warp_crop_Y)

                    # Calculating the slice-wise average moco estimate to provide a QC file
                    moco_param.append([
                        np.mean(np.ravel(im_warp_XYZ[0].data)),
                        np.mean(np.ravel(im_warp_XYZ[1].data))
                    ])

                # Concatenating the moco parameters into a time series for X and Y components.
                im_warp_X = [Image(fname) for fname in files_warp_X]
                im_warp_concat = concat_data(im_warp_X, dim=3)
                im_warp_concat.save(file_moco_params_x)

                im_warp_Y = [Image(fname) for fname in files_warp_Y]
                im_warp_concat = concat_data(im_warp_Y, dim=3)
                im_warp_concat.save(file_moco_params_y)

                # Writing a TSV file with the slicewise average estimate of the moco parameters. Useful for QC
                with open(file_moco_params_csv, 'wt') as out_file:
                    tsv_writer = csv.writer(out_file, delimiter='\t')
                    tsv_writer.writerow(['X', 'Y'])
                    for mocop in moco_param:
                        tsv_writer.writerow([mocop[0], mocop[1]])

        # Generate output files (still inside the tmp folder: some source names below are relative to it)
        printv('\nGenerate output files...', param.verbose)
        fname_moco = os.path.join(
            path_out_abs,
            add_suffix(os.path.basename(param.fname_data), param.suffix))
        generate_output_file(im_moco.absolutepath, fname_moco)
        if param.is_diffusion:
            generate_output_file(fname_b0_mean,
                                 add_suffix(fname_moco, '_b0_mean'))
            generate_output_file(fname_dwi_mean,
                                 add_suffix(fname_moco, '_dwi_mean'))
        else:
            generate_output_file(fname_moco_mean,
                                 add_suffix(fname_moco, '_mean'))
        # The TSV file only exists when the motion parameters were extracted above
        if os.path.exists(file_moco_params_csv):
            generate_output_file(file_moco_params_x,
                                 os.path.join(path_out_abs,
                                              file_moco_params_x),
                                 squeeze_data=False)
            generate_output_file(file_moco_params_y,
                                 os.path.join(path_out_abs,
                                              file_moco_params_y),
                                 squeeze_data=False)
            generate_output_file(
                file_moco_params_csv,
                os.path.join(path_out_abs, file_moco_params_csv))
    finally:
        # come back to working directory
        os.chdir(curdir)

    # Delete temporary files (done after leaving the tmp folder, so the current directory is never removed)
    if param.remove_temp_files == 1:
        printv('\nDelete temporary files...', param.verbose)
        rmtree(path_tmp, verbose=param.verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    printv(
        '\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's',
        param.verbose)

    display_viewer_syntax([
        os.path.join(
            param.path_out,
            add_suffix(os.path.basename(param.fname_data), param.suffix)),
        param.fname_data
    ],
                          mode='ortho,ortho')
def main(argv: Sequence[str]):
    """
    Main function. When this script is run via CLI, the main function is called using main(sys.argv[1:]).

    Exactly one label operation is performed, selected by the first matching flag in the if/elif chain below.
    The display and MSE operations only print their result and return; every other operation produces an
    image that is saved to the output file (arguments.o), after which an optional QC report is generated.

    :param argv: A list of unparsed arguments, which is passed to ArgumentParser.parse_args()
    :raises DeprecationWarning: if '-create-seg' is used with the deprecated '-1' slice value.
    """
    for i, arg in enumerate(argv):
        if arg == '-create-seg' and len(argv) > i+1 and '-1,' in argv[i+1]:
            raise DeprecationWarning("The use of '-1' for '-create-seg' has been deprecated. Please use "
                                     "'-create-seg-mid' instead.")

    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    input_filename = arguments.i
    output_fname = arguments.o

    img = Image(input_filename)
    dtype = None  # None is forwarded to save(); only the continuous-levels operation forces a dtype

    if arguments.add is not None:
        value = arguments.add
        out = sct_labels.add(img, value)
    elif arguments.create is not None:
        labels = arguments.create
        out = sct_labels.create_labels_empty(img, labels)
    elif arguments.create_add is not None:
        labels = arguments.create_add
        out = sct_labels.create_labels(img, labels)
    elif arguments.create_seg is not None:
        labels = arguments.create_seg
        out = sct_labels.create_labels_along_segmentation(img, labels)
    elif arguments.create_seg_mid is not None:
        # -1 is the slice value used here to request the label at the middle of the segmentation
        labels = [(-1, arguments.create_seg_mid)]
        out = sct_labels.create_labels_along_segmentation(img, labels)
    elif arguments.cubic_to_point:
        out = sct_labels.cubic_to_point(img)
    elif arguments.display:
        display_voxel(img, verbose)
        return
    elif arguments.increment:
        out = sct_labels.increment_z_inverse(img)
    elif arguments.disc is not None:
        ref = Image(arguments.disc)
        out = sct_labels.labelize_from_discs(img, ref)
    elif arguments.vert_body is not None:
        levels = arguments.vert_body
        if len(levels) == 1 and levels[0] == 0:
            levels = None  # all levels
        out = sct_labels.label_vertebrae(img, levels)
    elif arguments.vert_continuous:
        out = sct_labels.continuous_vertebral_levels(img)
        dtype = 'float32'
    elif arguments.MSE is not None:
        ref = Image(arguments.MSE)
        mse = sct_labels.compute_mean_squared_error(img, ref)
        printv(f"Computed MSE: {mse}")
        return
    elif arguments.remove_reference is not None:
        ref = Image(arguments.remove_reference)
        out = sct_labels.remove_missing_labels(img, ref)
    elif arguments.remove_sym is not None:
        # first pass use img as source
        # The reference comes from the remove_sym flag itself: arguments.remove_reference is necessarily
        # None in this branch, so it cannot be used to load the reference image.
        ref = Image(arguments.remove_sym)
        out = sct_labels.remove_missing_labels(img, ref)

        # second pass use previous pass result as reference
        # NB: the cleaned reference overwrites the reference file on disk
        ref_out = sct_labels.remove_missing_labels(ref, out)
        ref_out.save(path=ref.absolutepath)
    elif arguments.remove is not None:
        labels = arguments.remove
        out = sct_labels.remove_labels_from_image(img, labels)
    elif arguments.keep is not None:
        labels = arguments.keep
        out = sct_labels.remove_other_labels_from_image(img, labels)
    elif arguments.create_viewer is not None:
        msg = "" if arguments.msg is None else f"{arguments.msg}\n"
        if arguments.ilabel is not None:
            input_labels_img = Image(arguments.ilabel)
            out = launch_manual_label_gui(img, input_labels_img, parse_num_list(arguments.create_viewer), msg)
        else:
            out = launch_sagittal_viewer(img, parse_num_list(arguments.create_viewer), msg)
    else:
        # No operation flag was given: report it as a usage error instead of failing below on an
        # unbound 'out' variable. parser.error() prints the usage and exits.
        parser.error("You need to specify an operation to perform on the input image.")

    printv("Generating output files...")
    out.save(path=output_fname, dtype=dtype)
    display_viewer_syntax([input_filename, output_fname])

    if arguments.qc is not None:
        generate_qc(fname_in1=input_filename, fname_seg=output_fname, args=argv,
                    path_qc=os.path.abspath(arguments.qc), dataset=arguments.qc_dataset,
                    subject=arguments.qc_subject, process='sct_label_utils')
def main(argv=None):
    """
    Command-line entry point: register a source image to a destination image.

    Reads the parsed arguments, runs `register_wrapper`, optionally builds a QC report and finally prints the
    viewer syntax for the registered images.

    :param argv: list of command-line arguments handed to `parser.parse_args` (None makes argparse read `sys.argv`).
    :return: None
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_loglevel(verbose=verbose)

    def _or_empty(value):
        # Optional flags that were not given are represented by '' further down
        return '' if value is None else value

    def _abspath_or_empty(value):
        # Same convention as above, but a provided path is made absolute
        return '' if value is None else os.path.abspath(value)

    param = Param()
    start_time = time.time()

    # Mandatory inputs
    fname_src = arguments.i
    fname_dest = arguments.d

    # Optional inputs
    fname_src_seg = _or_empty(arguments.iseg)
    fname_dest_seg = _or_empty(arguments.dseg)
    fname_src_label = _or_empty(arguments.ilabel)
    fname_dest_label = _or_empty(arguments.dlabel)
    fname_mask = _or_empty(arguments.m)
    fname_initwarp = _abspath_or_empty(arguments.initwarp)
    fname_initwarpinv = _abspath_or_empty(arguments.initwarpinv)

    # Optional outputs
    fname_output = _or_empty(arguments.o)
    path_out = _or_empty(arguments.ofolder)
    fname_output_warp = _or_empty(arguments.owarp)
    fname_output_warpinv = _or_empty(arguments.owarpinv)

    # Registration steps: start from a private copy of the defaults, then add the user-defined steps
    paramregmulti = deepcopy(DEFAULT_PARAMREGMULTI)
    user_steps = arguments.param
    if user_steps is not None:
        for user_step in user_steps:
            paramregmulti.addStep(user_step)

    # Remaining options
    padding = arguments.z
    identity = arguments.identity
    interp = arguments.x
    remove_temp_files = arguments.r
    path_qc = arguments.qc
    qc_dataset = arguments.qc_dataset
    qc_subject = arguments.qc_subject

    printv('\nInput parameters:')
    printv('  Source .............. ' + fname_src)
    printv('  Destination ......... ' + fname_dest)
    printv('  Init transfo ........ ' + fname_initwarp)
    printv('  Mask ................ ' + fname_mask)
    printv('  Output name ......... ' + fname_output)
    printv('  Remove temp files ... ' + str(remove_temp_files))
    printv('  Verbose ............. ' + str(verbose))

    # Forward the relevant options to the parameter object consumed by register_wrapper
    param.verbose = verbose
    param.padding = padding
    param.fname_mask = fname_mask
    param.remove_temp_files = remove_temp_files

    # Both images have to be 3D
    printv('\nCheck if input data are 3D...', verbose)
    check_dim(fname_src, dim_lst=[3])
    check_dim(fname_dest, dim_lst=[3])

    # A user step with type=seg needs both segmentations
    wants_seg_step = user_steps is not None and any('type=seg' in user_step for user_step in user_steps)
    if wants_seg_step and (fname_src_seg == '' or fname_dest_seg == ''):
        printv('\nERROR: if you select type=seg you must specify -iseg and -dseg flags.\n', 1, 'error')

    # TODO: Check if it is necessary to first put source into destination space using the header only
    # TODO: use that as step=0

    fname_src2dest, fname_dest2src, _, _ = register_wrapper(
        fname_src, fname_dest, param, paramregmulti,
        fname_src_seg=fname_src_seg,
        fname_dest_seg=fname_dest_seg,
        fname_src_label=fname_src_label,
        fname_dest_label=fname_dest_label,
        fname_mask=fname_mask,
        fname_initwarp=fname_initwarp,
        fname_initwarpinv=fname_initwarpinv,
        identity=identity,
        interp=interp,
        fname_output=fname_output,
        fname_output_warp=fname_output_warp,
        fname_output_warpinv=fname_output_warpinv,
        path_out=path_out,
    )

    # display elapsed time
    elapsed_time = time.time() - start_time
    printv('\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's', verbose)

    # QC report (only possible when the destination segmentation is available)
    if path_qc is not None:
        if not fname_dest_seg:
            printv('WARNING: Cannot generate QC because it requires destination segmentation.', 1, 'warning')
        else:
            generate_qc(fname_src2dest, fname_in2=fname_dest, fname_seg=fname_dest_seg, args=argv,
                        path_qc=os.path.abspath(path_qc), dataset=qc_dataset, subject=qc_subject,
                        process='sct_register_multimodal')

    # If dest wasn't registered (e.g. unidirectional registration due to '-initwarp'), then don't output syntax
    if fname_dest2src:
        display_viewer_syntax([fname_src, fname_dest2src], verbose=verbose)
    display_viewer_syntax([fname_dest, fname_src2dest], verbose=verbose)
Example #22
0
def main(args=None):
    """
    Command-line entry point: register a source image to a destination image.

    Builds the default two-step registration parameters, parses the arguments, runs `register_wrapper`,
    optionally generates a QC report and prints the viewer syntax for the registered images.

    :param args: list of command-line arguments. When None, `sys.argv[1:]` is used. An empty list displays the help.
    :return: None
    """
    if args is None:
        args = sys.argv[1:]

    # initialize parameters
    param = Param()

    # Initialization
    fname_output = ''
    path_out = ''
    fname_src_seg = ''
    fname_dest_seg = ''
    fname_src_label = ''
    fname_dest_label = ''
    generate_warpinv = 1
    paramregmulti_user = None  # stays None unless the user provides -param

    start_time = time.time()

    # get default registration parameters
    step0 = Paramreg(
        step='0',
        type='im',
        algo='syn',
        metric='MI',
        iter='0',
        shrink='1',
        smooth='0',
        gradStep='0.5',
        slicewise='0',
        dof='Tx_Ty_Tz_Rx_Ry_Rz')  # only used to put src into dest space
    step1 = Paramreg(step='1', type='im')
    paramregmulti = ParamregMultiStep([step0, step1])

    parser = get_parser(paramregmulti=paramregmulti)

    # Parse the list received by this function (which defaults to sys.argv[1:]) so that arguments passed
    # programmatically are honoured; with no argument at all, display the help.
    arguments = parser.parse_args(args if args else ['--help'])

    # get arguments
    fname_src = arguments.i
    fname_dest = arguments.d
    if arguments.iseg is not None:
        fname_src_seg = arguments.iseg
    if arguments.dseg is not None:
        fname_dest_seg = arguments.dseg
    if arguments.ilabel is not None:
        fname_src_label = arguments.ilabel
    if arguments.dlabel is not None:
        fname_dest_label = arguments.dlabel
    if arguments.o is not None:
        fname_output = arguments.o
    if arguments.ofolder is not None:
        path_out = arguments.ofolder
    fname_output_warp = arguments.owarp if arguments.owarp is not None else ''
    fname_initwarp = os.path.abspath(arguments.initwarp) if arguments.initwarp is not None else ''
    fname_initwarpinv = os.path.abspath(arguments.initwarpinv) if arguments.initwarpinv is not None else ''
    fname_mask = arguments.m if arguments.m is not None else ''
    padding = arguments.z
    if arguments.param is not None:
        paramregmulti_user = arguments.param
        # update registration parameters
        for paramStep in paramregmulti_user:
            paramregmulti.addStep(paramStep)
    path_qc = arguments.qc
    qc_dataset = arguments.qc_dataset
    qc_subject = arguments.qc_subject

    identity = arguments.identity
    interp = arguments.x
    remove_temp_files = arguments.r
    verbose = int(arguments.v)
    init_sct(log_level=verbose, update=True)  # Update log level

    printv('\nInput parameters:')
    printv('  Source .............. ' + fname_src)
    printv('  Destination ......... ' + fname_dest)
    printv('  Init transfo ........ ' + fname_initwarp)
    printv('  Mask ................ ' + fname_mask)
    printv('  Output name ......... ' + fname_output)
    printv('  Remove temp files ... ' + str(remove_temp_files))
    printv('  Verbose ............. ' + str(verbose))

    # update param
    param.verbose = verbose
    param.padding = padding
    param.fname_mask = fname_mask
    param.remove_temp_files = remove_temp_files

    # Get if input is 3D
    printv('\nCheck if input data are 3D...', verbose)
    check_dim(fname_src, dim_lst=[3])
    check_dim(fname_dest, dim_lst=[3])

    # Check if user selected type=seg, but did not input segmentation data
    if paramregmulti_user is not None and any('type=seg' in step for step in paramregmulti_user):
        if fname_src_seg == '' or fname_dest_seg == '':
            printv(
                '\nERROR: if you select type=seg you must specify -iseg and -dseg flags.\n',
                1, 'error')

    # TODO: Check if it is necessary to first put source into destination space using the header only
    # TODO: use that as step=0

    fname_src2dest, fname_dest2src, _, _ = \
        register_wrapper(fname_src, fname_dest, param, paramregmulti, fname_src_seg=fname_src_seg,
                         fname_dest_seg=fname_dest_seg, fname_src_label=fname_src_label,
                         fname_dest_label=fname_dest_label, fname_mask=fname_mask, fname_initwarp=fname_initwarp,
                         fname_initwarpinv=fname_initwarpinv, identity=identity, interp=interp,
                         fname_output=fname_output,
                         fname_output_warp=fname_output_warp,
                         path_out=path_out)

    # display elapsed time
    elapsed_time = time.time() - start_time
    printv(
        '\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's',
        verbose)

    if path_qc is not None:
        if fname_dest_seg:
            generate_qc(fname_src2dest,
                        fname_in2=fname_dest,
                        fname_seg=fname_dest_seg,
                        args=args,
                        path_qc=os.path.abspath(path_qc),
                        dataset=qc_dataset,
                        subject=qc_subject,
                        process='sct_register_multimodal')
        else:
            printv(
                'WARNING: Cannot generate QC because it requires destination segmentation.',
                1, 'warning')

    # Only advertise the dest->src result when the wrapper actually returned a file name for it
    if generate_warpinv and fname_dest2src:
        display_viewer_syntax([fname_src, fname_dest2src], verbose=verbose)
    display_viewer_syntax([fname_dest, fname_src2dest], verbose=verbose)
def main(argv=None):
    """
    Command-line entry point: segment the input image(s) with one task (pipeline of models) or custom model(s).

    Validates the arguments, resolves the models to run, runs them one after the other (the segmentation
    written for a model is passed as `fname_prior` to the next one), saves the outputs and prints the viewer
    syntax for the outputs of the last model.

    :param argv: list of command-line arguments handed to `parser.parse_args` (None makes argparse read `sys.argv`).
    :return: None
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    # At least one of the three usage modes has to be selected
    has_input_and_task = arguments.i is not None and arguments.task is not None
    if arguments.list_tasks is False and arguments.install_task is None and not has_input_and_task:
        parser.error("You must specify either '-list-tasks', '-install-task', or both '-i' + '-task'.")

    # Deal with task
    if arguments.list_tasks:
        deepseg.models.display_list_tasks()

    if arguments.install_task is not None:
        for name_model in deepseg.models.TASKS[arguments.install_task]['models']:
            deepseg.models.install_model(name_model)
        exit(0)

    # Every input file has to exist
    for fname_input in arguments.i:
        if not os.path.isfile(fname_input):
            parser.error("This file does not exist: {}".format(fname_input))

    # Either an "official" task, or path(s) to folder(s) containing custom models
    is_official_task = len(arguments.task) == 1 and arguments.task[0] in deepseg.models.TASKS
    if is_official_task:
        # Check if all input images are provided
        required_contrasts = deepseg.models.get_required_contrasts(arguments.task[0])
        n_contrasts = len(required_contrasts)
        # Get pipeline model names
        name_models = deepseg.models.TASKS[arguments.task[0]]['models']
    else:
        n_contrasts = len(arguments.i)
        name_models = arguments.task

    # Can only differ for official tasks, so required_contrasts is always defined when this triggers
    if len(arguments.i) != n_contrasts:
        parser.error(
            "{} input files found. Please provide all required input files for the task {}, i.e. contrasts: {}."
            .format(len(arguments.i), arguments.task, ', '.join(required_contrasts)))

    # With several inputs, the user has to tell which contrast is which
    if len(arguments.i) > 1 and arguments.c is None:
        parser.error(
            "Please specify the order in which you put the contrasts in the input images (-i) with flag -c, e.g., "
            "-c t1 t2")

    # Run pipeline by iterating through the models
    fname_prior = None
    output_filenames = None
    for name_model in name_models:
        if name_model in deepseg.models.MODELS:
            # Official model: install it on the fly if needed
            path_model = deepseg.models.folder(name_model)
            if not deepseg.models.is_valid(path_model):
                printv("Model {} is not installed. Installing it now...".format(name_model))
                deepseg.models.install_model(name_model)
        else:
            # Otherwise it has to be a path to a valid model
            path_model = os.path.abspath(name_model)
            if not deepseg.models.is_valid(path_model):
                parser.error("The input model is invalid: {}".format(path_model))

        # Order input images according to the contrasts required by the model
        # NOTE(review): this indexes MODELS with name_model, which looks like it assumes an official model
        # whenever -c is given -- confirm the behaviour for custom model paths.
        if arguments.c is None:
            input_filenames = arguments.i
        else:
            input_filenames = [
                fname_provided
                for contrast_required in deepseg.models.MODELS[name_model]['contrasts']
                for contrast_provided, fname_provided in zip(arguments.c, arguments.i)
                if contrast_required == contrast_provided
            ]

        # Call segment_nifti
        options = dict(vars(arguments), fname_prior=fname_prior)
        nii_lst, target_lst = imed_inference.segment_volume(path_model, input_filenames, options=options)

        # Delete intermediate outputs
        if fname_prior and os.path.isfile(fname_prior) and arguments.r:
            logger.info("Remove temporary files...")
            os.remove(fname_prior)

        # Save output seg
        output_filenames = []
        several_targets = len(target_lst) > 1
        for nii_seg, target in zip(nii_lst, target_lst):
            fname_user = options.get('o')
            if fname_user is None:
                # Default name: first input file name (without extension) + target suffix
                fname_seg = sct.image.splitext(input_filenames[0])[0] + (target + '.nii.gz')
            elif not several_targets:
                fname_seg = fname_user
            else:
                # To support if the user adds the extension or not
                if ".nii.gz" in fname_user:
                    extension = ".nii.gz"
                elif ".nii" in fname_user:
                    extension = ".nii"
                else:
                    extension = ""
                if extension == "":
                    fname_seg = fname_user + target
                else:
                    fname_seg = fname_user.replace(extension, target + extension)

            # If output folder does not exist, create it
            path_out = os.path.dirname(fname_seg)
            if path_out != '' and not os.path.exists(path_out):
                os.makedirs(path_out)

            nib.save(nii_seg, fname_seg)
            output_filenames.append(fname_seg)

        # Use the result of the current model as additional input of the next model
        fname_prior = fname_seg

    for output_filename in output_filenames:
        display_viewer_syntax([arguments.i[0], output_filename], colormaps=['gray', 'red'], opacities=['', '0.7'])
Example #24
0
def run_main():
    """
    Command-line entry point: compute the centerline of the input image and save it.

    Depending on `-method`, the centerline is obtained from manual labels ('viewer'), from the input itself
    ('fitseg') or detected automatically ('optic'). The result is written as a NIfTI file (discrete) and as a
    CSV file (continuous). A QC report is generated when `-qc` is given, and the viewer syntax is printed once
    at the very end.

    Arguments are read from `sys.argv`; without any argument the help is displayed.

    :return: None
    """
    init_sct()
    parser = get_parser()
    arguments = parser.parse_args(args=None if sys.argv[1:] else ['--help'])

    # Input filename
    fname_input_data = arguments.i
    fname_data = os.path.abspath(fname_input_data)

    # Method used
    method = arguments.method

    # Contrast type
    contrast_type = arguments.c
    if method == 'optic' and not contrast_type:
        # The 'optic' method cannot run without knowing the contrast
        error = "ERROR: -c is a mandatory argument when using 'optic' method."
        printv(error, type='error')
        return

    # Gap between slices
    interslice_gap = arguments.gap

    param_centerline = ParamCenterline(algo_fitting=arguments.centerline_algo,
                                       smooth=arguments.centerline_smooth,
                                       minmax=True)

    # Output file name (without extension)
    if arguments.o is not None:
        file_output = arguments.o
    else:
        path_data, file_data, _ = extract_fname(fname_data)
        file_output = os.path.join(path_data, file_data + '_centerline')
    fname_centerline = file_output + '.nii.gz'

    verbose = int(arguments.v)
    init_sct(log_level=verbose, update=True)  # Update log level

    if method == 'viewer':
        # Manual labeling of cord centerline
        im_labels = _call_viewer_centerline(Image(fname_data),
                                            interslice_gap=interslice_gap)
    elif method == 'fitseg':
        im_labels = Image(fname_data)
    elif method == 'optic':
        # Automatic detection of cord centerline
        im_labels = Image(fname_data)
        param_centerline.algo_fitting = 'optic'
        param_centerline.contrast = contrast_type
    else:
        printv(
            "ERROR: The selected method is not available: {}. Please look at the help."
            .format(method),
            type='error')
        return

    # Extrapolate and regularize (or detect if optic) cord centerline
    im_centerline, arr_centerline, _, _ = get_centerline(
        im_labels, param=param_centerline, verbose=verbose)

    # save centerline as nifti (discrete) and csv (continuous) files
    im_centerline.save(fname_centerline)
    np.savetxt(file_output + '.csv', arr_centerline.transpose(), delimiter=",")

    path_qc = arguments.qc
    qc_dataset = arguments.qc_dataset
    qc_subject = arguments.qc_subject

    # Generate QC report
    if path_qc is not None:
        generate_qc(fname_input_data,
                    fname_seg=fname_centerline,
                    args=sys.argv[1:],
                    path_qc=os.path.abspath(path_qc),
                    dataset=qc_dataset,
                    subject=qc_subject,
                    process='sct_get_centerline')

    # Print the viewer syntax a single time, once every output has been produced
    display_viewer_syntax([fname_input_data, fname_centerline],
                          colormaps=['gray', 'red'],
                          opacities=['', '0.7'])
def create_mask(param):
    """
    Create a mask made of one 2D mask per z-slice, centered on a centerline, and save it to `param.fname_out`.

    The centerline is either provided by the user (method 'centerline') or generated by `create_line` from a
    coordinate: parsed from a string (method 'coord'), read from a label image (method 'point') or set at the
    center of the field of view (method 'center'). Processing happens in a temporary folder on data
    re-oriented to RPI; the output is put back into the orientation of the input.

    :param param: parameter object. This function reads `process`, `fname_data`, `fname_out`, `file_prefix`,
        `verbose`, `shape`, `size` and `remove_temp_files`, and sets `fname_out` when it is empty.
    :return: None
    """
    # parse argument for method
    method_type = param.process[0]
    # every method except 'center' comes with a value
    if not method_type == 'center':
        method_val = param.process[1]

    # check existence of input files
    if method_type == 'centerline':
        check_file_exist(method_val, param.verbose)

    # Extract path/file/extension
    path_data, file_data, ext_data = extract_fname(param.fname_data)

    # Get output folder and file name
    if param.fname_out == '':
        param.fname_out = os.path.abspath(param.file_prefix + file_data +
                                          ext_data)

    # Resolve these against the caller's working directory now: they are used again further down, after the
    # working directory has been switched to the temporary folder.
    fname_data_abs = os.path.abspath(param.fname_data)
    fname_out_abs = os.path.abspath(param.fname_out)

    path_tmp = tmp_create(basename="create_mask")

    printv('\nOrientation:', param.verbose)
    orientation_input = Image(param.fname_data).orientation
    printv('  ' + orientation_input, param.verbose)

    # copy input data to tmp folder and re-orient to RPI
    Image(param.fname_data).change_orientation("RPI").save(
        os.path.join(path_tmp, "data_RPI.nii"))
    if method_type == 'centerline':
        Image(method_val).change_orientation("RPI").save(
            os.path.join(path_tmp, "centerline_RPI.nii"))
    if method_type == 'point':
        Image(method_val).change_orientation("RPI").save(
            os.path.join(path_tmp, "point_RPI.nii"))

    # go to tmp folder; the finally clause guarantees we come back even if something fails in between
    curdir = os.getcwd()
    os.chdir(path_tmp)
    try:
        # Get dimensions of data
        im_data = Image('data_RPI.nii')
        nx, ny, nz, nt, px, py, pz, pt = im_data.dim
        printv('\nDimensions:', param.verbose)
        printv(im_data.dim, param.verbose)
        # in case user input 4d data
        if nt != 1:
            printv(
                'WARNING in ' + os.path.basename(__file__) +
                ': Input image is 4d but output mask will be 3D from first time slice.',
                param.verbose, 'warning')
            # extract first volume to have 3d reference
            nii = empty_like(Image('data_RPI.nii'))
            data3d = nii.data[:, :, :, 0]
            nii.data = data3d
            nii.save('data_RPI.nii')

        if method_type == 'coord':
            # parse to get coordinate
            coord = list(map(int, method_val.split('x')))

        if method_type == 'point':
            # extract coordinate of point
            printv('\nExtract coordinate of point...', param.verbose)
            coord = Image("point_RPI.nii").getNonZeroCoordinates()

        if method_type == 'center':
            # set coordinate at center of FOV
            coord = np.round(float(nx) / 2), np.round(float(ny) / 2)

        if method_type == 'centerline':
            # get name of centerline from user argument
            fname_centerline = 'centerline_RPI.nii'
        else:
            # generate volume with line along Z at coordinates 'coord'
            printv('\nCreate line...', param.verbose)
            fname_centerline = create_line(param, 'data_RPI.nii', coord, nz)

        # create mask
        printv('\nCreate mask...', param.verbose)
        centerline = nibabel.load(fname_centerline)  # open centerline
        hdr = centerline.header  # header attribute (replaces the removed get_header() accessor)
        hdr.set_data_dtype('uint8')  # set imagetype to uint8
        spacing = hdr.structarr['pixdim']
        # documented replacement for the removed get_data() accessor
        data_centerline = np.asanyarray(centerline.dataobj)
        # if data is 2D, reshape with empty third dimension
        if len(data_centerline.shape) == 2:
            data_centerline_shape = list(data_centerline.shape)
            data_centerline_shape.append(1)
            data_centerline = data_centerline.reshape(data_centerline_shape)

        # create 2d masks, one per slice
        im_list = []
        for iz in range(nz):
            slice_centerline = data_centerline[:, :, iz]
            if not slice_centerline.any():
                # no centerline on this slice: keep the (empty) slice as is
                im_list.append(Image(slice_centerline, hdr=hdr))
                continue
            # center the 2D mask on the center of mass of the centerline
            cx, cy = ndimage.center_of_mass(np.array(slice_centerline))
            center = np.array([cx, cy])
            mask2d = create_mask2d(param,
                                   center,
                                   param.shape,
                                   param.size,
                                   im_data=im_data)
            im_list.append(Image(mask2d, hdr=hdr))
        im_out = concat_data(im_list, dim=2).save('mask_RPI.nii.gz')

        # put the mask back into the orientation of the input, with the header of the input
        im_out.change_orientation(orientation_input)
        im_out.header = Image(fname_data_abs).header
        im_out.save(fname_out_abs)
    finally:
        # come back
        os.chdir(curdir)

    # Remove temporary files
    if param.remove_temp_files == 1:
        printv('\nRemove temporary files...', param.verbose)
        rmtree(path_tmp)

    display_viewer_syntax([param.fname_data, param.fname_out],
                          colormaps=['gray', 'red'],
                          opacities=['', '0.5'])
Example #26
0
def main(argv=None):
    """
    Command-line entry point: run one operation on the input image(s) and write the result.

    Exactly one operation is executed; when several are requested, the first one in the order of the
    `if`/`elif` chain below wins. Operations producing images (concat, copy-header, keep/remove volumes,
    multi-component split/merge, padding, re-orientation, split, type conversion, to-fsl) are saved to disk;
    `-getorient` prints the orientation and `-display-warp` only generates the warping grid.

    :param argv: list of command-line arguments handed to `parser.parse_args` (None makes argparse read `sys.argv`).
    :return: None
    :raises ValueError: if `-copy-header` is used without `-o`.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    # initializations
    output_type = None
    dim_list = ['x', 'y', 'z', 't']

    fname_in = arguments.i
    n_in = len(fname_in)

    # May be None here; it falls back to the first input file name before the outputs are written
    fname_out = arguments.o

    # Run command
    # Arguments are sorted alphabetically (not according to the usage order)
    if arguments.concat is not None:
        dim = arguments.concat
        assert dim in dim_list
        dim = dim_list.index(dim)
        im_out = [concat_data(fname_in, dim)]  # TODO: adapt to fname_in

    elif arguments.copy_header is not None:
        if fname_out is None:
            raise ValueError("Need to specify output image with -o!")
        im_in = Image(fname_in[0])
        im_dest = Image(arguments.copy_header)
        # header comes from the input image, data from the destination image
        im_dest_new = im_in.copy()
        im_dest_new.data = im_dest.data.copy()
        im_dest_new.absolutepath = im_dest.absolutepath
        im_out = [im_dest_new]

    elif arguments.display_warp:
        im_in = fname_in[0]
        visualize_warp(im_in, fname_grid=None, step=3, rm_tmp=True)
        im_out = None

    elif arguments.getorient:
        im_in = Image(fname_in[0])
        orient = im_in.orientation
        im_out = None

    elif arguments.keep_vol is not None:
        index_vol = [int(vol) for vol in arguments.keep_vol.split(',')]
        im_in = Image(fname_in[0])
        im_out = [remove_vol(im_in, index_vol, todo='keep')]

    elif arguments.mcs:
        im_in = Image(fname_in[0])
        # parser.error() prints the message and exits
        if n_in != 1:
            parser.error('ERROR: -mcs need only one input')
        if len(im_in.data.shape) != 5:
            parser.error('ERROR: -mcs input need to be a multi-component image')
        im_out = multicomponent_split(im_in)

    elif arguments.omc:
        im_ref = Image(fname_in[0])
        for fname in fname_in:
            im = Image(fname)
            if im.data.shape != im_ref.data.shape:
                parser.error('ERROR: -omc inputs need to have all the same shapes')
            del im
        im_out = [multicomponent_merge(fname_in)]  # TODO: adapt to fname_in

    elif arguments.pad is not None:
        im_in = Image(fname_in[0])
        ndims = len(im_in.data.shape)
        if ndims != 3:
            printv('ERROR: you need to specify a 3D input file.', 1, 'error')
            return

        pad_arguments = arguments.pad.split(',')
        if len(pad_arguments) != 3:
            printv('ERROR: you need to specify 3 padding values.', 1, 'error')
            # do not fall through to the unpacking below with a wrong number of values
            return

        padx, pady, padz = (int(value) for value in pad_arguments)
        im_out = [
            pad_image(im_in,
                      pad_x_i=padx,
                      pad_x_f=padx,
                      pad_y_i=pady,
                      pad_y_f=pady,
                      pad_z_i=padz,
                      pad_z_f=padz)
        ]

    elif arguments.pad_asym is not None:
        im_in = Image(fname_in[0])
        ndims = len(im_in.data.shape)
        if ndims != 3:
            printv('ERROR: you need to specify a 3D input file.', 1, 'error')
            return

        pad_arguments = arguments.pad_asym.split(',')
        if len(pad_arguments) != 6:
            printv('ERROR: you need to specify 6 padding values.', 1, 'error')
            # do not fall through to the unpacking below with a wrong number of values
            return

        padxi, padxf, padyi, padyf, padzi, padzf = (int(value) for value in pad_arguments)
        im_out = [
            pad_image(im_in,
                      pad_x_i=padxi,
                      pad_x_f=padxf,
                      pad_y_i=padyi,
                      pad_y_f=padyf,
                      pad_z_i=padzi,
                      pad_z_f=padzf)
        ]

    elif arguments.remove_vol is not None:
        index_vol = [int(vol) for vol in arguments.remove_vol.split(',')]
        im_in = Image(fname_in[0])
        im_out = [remove_vol(im_in, index_vol, todo='remove')]

    elif arguments.setorient is not None:
        printv(fname_in[0])
        im_in = Image(fname_in[0])
        im_out = [change_orientation(im_in, arguments.setorient)]

    elif arguments.setorient_data is not None:
        im_in = Image(fname_in[0])
        im_out = [
            change_orientation(im_in, arguments.setorient_data, data_only=True)
        ]

    elif arguments.split is not None:
        dim = arguments.split
        assert dim in dim_list
        im_in = Image(fname_in[0])
        dim = dim_list.index(dim)
        im_out = split_data(im_in, dim)

    elif arguments.type is not None:
        output_type = arguments.type
        im_in = Image(fname_in[0])
        im_out = [im_in]  # TODO: adapt to fname_in

    elif arguments.to_fsl is not None:
        space_files = arguments.to_fsl
        if len(space_files) > 2 or len(space_files) < 1:
            parser.error('ERROR: -to-fsl expects 1 or 2 arguments')
        spaces = [Image(s) for s in space_files]
        if len(spaces) < 2:
            spaces.append(None)
        im_out = [
            displacement_to_abs_fsl(Image(fname_in[0]), spaces[0], spaces[1])
        ]

    else:
        im_out = None
        parser.error('ERROR: you need to specify an operation to do on the input image')

    # in case fname_out is not defined, use first element of input file name list
    if fname_out is None:
        fname_out = fname_in[0]

    # Write output
    if im_out is not None:
        printv('Generate output files...', verbose)
        # if only one output
        if len(im_out) == 1 and arguments.split is None:
            im_out[0].save(fname_out, dtype=output_type, verbose=verbose)
            display_viewer_syntax([fname_out], verbose=verbose)
        if arguments.mcs:
            # use input file name and add _X, _Y _Z. Keep the same extension
            l_fname_out = []
            for i_dim in range(3):
                l_fname_out.append(
                    add_suffix(fname_out or fname_in[0],
                               '_' + dim_list[i_dim].upper()))
                im_out[i_dim].save(l_fname_out[i_dim], verbose=verbose)
            # show the files that were actually written (a list), not the base output name (a string)
            display_viewer_syntax(l_fname_out)
        if arguments.split is not None:
            # use input file name and add _"DIM+NUMBER". Keep the same extension
            l_fname_out = []
            for i, im in enumerate(im_out):
                l_fname_out.append(
                    add_suffix(fname_out or fname_in[0],
                               '_' + dim_list[dim].upper() + str(i).zfill(4)))
                im.save(l_fname_out[i])
            display_viewer_syntax(l_fname_out)

    elif arguments.getorient:
        printv(orient)

    elif arguments.display_warp:
        printv('Warping grid generated.', verbose, 'info')
def main(argv=None):
    """
    Command-line entry point: label the vertebral levels of a spinal cord segmentation.

    Steps performed, in order:
      1. Parse the arguments (including an optional init file overriding -initz / -initcenter).
      2. Copy the input image and segmentation into a temporary folder and work from there.
      3. Straighten the cord, reusing the warping fields found in the calling directory when the
         cache signature is still valid.
      4. Label the straightened segmentation, either from a disc-label file or from an initial
         label (user-provided, or automatically detected C2-C3 disc).
      5. Warp the labels back to the original space, optionally clean them, and write the outputs
         (plus the straightening files) to the output folder.
      6. Optionally generate a QC report and print the viewer syntax.

    :param argv: list of command-line arguments. When empty or None, '--help' is passed to the parser.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv if argv else ['--help'])
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    # initializations
    initz = ''
    initcenter = ''
    fname_initlabel = ''
    file_labelz = 'labelz.nii.gz'
    param = Param()
    # files produced by the straightening step; they are cached, reused and exported together
    straightening_files = (
        "warp_curve2straight.nii.gz",
        "warp_straight2curve.nii.gz",
        "straight_ref.nii.gz",
    )

    fname_in = os.path.abspath(arguments.i)
    fname_seg = os.path.abspath(arguments.s)
    contrast = arguments.c
    path_template = os.path.abspath(arguments.t)
    scale_dist = arguments.scale_dist
    path_output = arguments.ofolder
    param.path_qc = arguments.qc
    if arguments.discfile is not None:
        fname_disc = os.path.abspath(arguments.discfile)
    else:
        fname_disc = None
    if arguments.initz is not None:
        initz = arguments.initz
        if len(initz) != 2:
            raise ValueError(
                '--initz takes two arguments: position in superior-inferior direction, label value'
            )
    if arguments.initcenter is not None:
        initcenter = arguments.initcenter
    # if user provided text file, parse and overwrite arguments
    if arguments.initfile is not None:
        # use a context manager so the file handle is always released
        with open(arguments.initfile, 'r') as f_init:
            initfile = ' ' + f_init.read().replace('\n', '')
        arg_initfile = initfile.split(' ')
        for idx_arg, arg in enumerate(arg_initfile):
            if arg == '-initz':
                initz = [int(x) for x in arg_initfile[idx_arg + 1].split(',')]
                if len(initz) != 2:
                    raise ValueError(
                        '--initz takes two arguments: position in superior-inferior direction, label value'
                    )
            if arg == '-initcenter':
                initcenter = int(arg_initfile[idx_arg + 1])
    if arguments.initlabel is not None:
        # get absolute path of label
        fname_initlabel = os.path.abspath(arguments.initlabel)
    if arguments.param is not None:
        param.update(arguments.param[0])
    remove_temp_files = arguments.r
    clean_labels = arguments.clean_labels
    laplacian = arguments.laplacian

    path_tmp = tmp_create(basename="label_vertebrae")

    # Copying input data to tmp folder
    printv('\nCopying input data to tmp folder...', verbose)
    Image(fname_in).save(os.path.join(path_tmp, "data.nii"))
    Image(fname_seg).save(os.path.join(path_tmp, "segmentation.nii"))

    # Go to temp folder
    curdir = os.getcwd()
    os.chdir(path_tmp)

    # Straighten spinal cord
    printv('\nStraighten spinal cord...', verbose)
    # check if warp_curve2straight and warp_straight2curve already exist (i.e. no need to do it another time)
    cache_sig = cache_signature(input_files=[fname_in, fname_seg], )
    cachefile = os.path.join(curdir, "straightening.cache")
    if cache_valid(cachefile, cache_sig) and all(
            os.path.isfile(os.path.join(curdir, fname))
            for fname in straightening_files):
        # if they exist, copy them into current folder
        printv('Reusing existing warping field which seems to be valid',
               verbose, 'warning')
        for fname in straightening_files:
            copy(os.path.join(curdir, fname), fname)
        # apply straightening
        run_proc([
            'sct_apply_transfo', '-i', 'data.nii', '-w',
            'warp_curve2straight.nii.gz', '-d', 'straight_ref.nii.gz', '-o',
            'data_straight.nii'
        ])
    else:
        sct_straighten_spinalcord.main(argv=[
            '-i',
            'data.nii',
            '-s',
            'segmentation.nii',
            '-r',
            str(remove_temp_files),
            '-v',
            str(verbose),
        ])
        cache_save(cachefile, cache_sig)

    # resample to 0.5mm isotropic to match template resolution
    printv('\nResample to 0.5mm isotropic...', verbose)
    run_proc([
        'sct_resample', '-i', 'data_straight.nii', '-mm', '0.5x0.5x0.5', '-x',
        'linear', '-o', 'data_straightr.nii'
    ],
             verbose=verbose)

    # Apply straightening to segmentation
    # N.B. Output is RPI
    printv('\nApply straightening to segmentation...', verbose)
    run_proc(
        'isct_antsApplyTransforms -d 3 -i %s -r %s -t %s -o %s -n %s' %
        ('segmentation.nii', 'data_straightr.nii',
         'warp_curve2straight.nii.gz', 'segmentation_straight.nii', 'Linear'),
        verbose=verbose,
        is_sct_binary=True,
    )
    # Threshold segmentation at 0.5
    run_proc([
        'sct_maths', '-i', 'segmentation_straight.nii', '-thr', '0.5', '-o',
        'segmentation_straight.nii'
    ], verbose)

    # If disc label file is provided, label vertebrae using that file instead of automatically
    if fname_disc:
        # Apply straightening to disc-label
        printv('\nApply straightening to disc labels...', verbose)
        run_proc(
            'sct_apply_transfo -i %s -d %s -w %s -o %s -x %s' %
            (fname_disc, 'data_straightr.nii', 'warp_curve2straight.nii.gz',
             'labeldisc_straight.nii.gz', 'label'),
            verbose=verbose)
        label_vert('segmentation_straight.nii',
                   'labeldisc_straight.nii.gz',
                   verbose=1)

    else:
        # create label to identify disc
        printv('\nCreate label to identify disc...', verbose)
        fname_labelz = os.path.join(path_tmp, file_labelz)
        if initz or initcenter:
            if initcenter:
                # find z centered in FOV
                nii = Image('segmentation.nii').change_orientation("RPI")
                nx, ny, nz, nt, px, py, pz, pt = nii.dim  # Get dimensions
                z_center = int(np.round(nz / 2))  # get z_center
                initz = [z_center, initcenter]

            im_label = create_labels_along_segmentation(
                Image('segmentation.nii'), [(initz[0], initz[1])])
            im_label.data = dilate(im_label.data, 3, 'ball')
            im_label.save(fname_labelz)

        elif fname_initlabel:
            Image(fname_initlabel).save(fname_labelz)

        else:
            # automatically finds C2-C3 disc
            im_data = Image('data.nii')
            im_seg = Image('segmentation.nii')
            if not remove_temp_files:  # because verbose is here also used for keeping temp files
                verbose_detect_c2c3 = 2
            else:
                verbose_detect_c2c3 = 0
            im_label_c2c3 = detect_c2c3(im_data,
                                        im_seg,
                                        contrast,
                                        verbose=verbose_detect_c2c3)
            ind_label = np.where(im_label_c2c3.data)
            if not np.size(ind_label) == 0:
                im_label_c2c3.data[ind_label] = 3
            else:
                printv(
                    'Automatic C2-C3 detection failed. Please provide manual label with sct_label_utils',
                    1, 'error')
                # this is a failure path: make sure the process does not report success
                sys.exit(1)
            im_label_c2c3.save(fname_labelz)

        # dilate label so it is not lost when applying warping
        dilate(Image(fname_labelz), 3, 'ball').save(fname_labelz)

        # Apply straightening to z-label
        printv('\nAnd apply straightening to label...', verbose)
        run_proc(
            'isct_antsApplyTransforms -d 3 -i %s -r %s -t %s -o %s -n %s' %
            (file_labelz, 'data_straightr.nii', 'warp_curve2straight.nii.gz',
             'labelz_straight.nii.gz', 'NearestNeighbor'),
            verbose=verbose,
            is_sct_binary=True,
        )
        # get z value and disk value to initialize labeling
        printv('\nGet z and disc values from straight label...', verbose)
        init_disc = get_z_and_disc_values_from_label('labelz_straight.nii.gz')
        printv('.. ' + str(init_disc), verbose)

        # apply laplacian filtering
        if laplacian:
            printv('\nApply Laplacian filter...', verbose)
            run_proc([
                'sct_maths', '-i', 'data_straightr.nii', '-laplacian', '1',
                '-o', 'data_straightr.nii'
            ], verbose)

        # detect vertebral levels on straight spinal cord
        init_disc[1] = init_disc[1] - 1
        vertebral_detection('data_straightr.nii',
                            'segmentation_straight.nii',
                            contrast,
                            param,
                            init_disc=init_disc,
                            verbose=verbose,
                            path_template=path_template,
                            path_output=path_output,
                            scale_dist=scale_dist)

    # un-straighten labeled spinal cord
    printv('\nUn-straighten labeling...', verbose)
    run_proc(
        'isct_antsApplyTransforms -d 3 -i %s -r %s -t %s -o %s -n %s' %
        ('segmentation_straight_labeled.nii', 'segmentation.nii',
         'warp_straight2curve.nii.gz', 'segmentation_labeled.nii',
         'NearestNeighbor'),
        verbose=verbose,
        is_sct_binary=True,
    )

    if clean_labels:
        # Clean labeled segmentation
        printv(
            '\nClean labeled segmentation (correct interpolation errors)...',
            verbose)
        clean_labeled_segmentation('segmentation_labeled.nii',
                                   'segmentation.nii',
                                   'segmentation_labeled.nii')

    # label discs
    printv('\nLabel discs...', verbose)
    printv('\nUn-straighten labeled discs...', verbose)
    run_proc(
        'sct_apply_transfo -i %s -d %s -w %s -o %s -x %s' %
        ('segmentation_straight_labeled_disc.nii', 'segmentation.nii',
         'warp_straight2curve.nii.gz', 'segmentation_labeled_disc.nii',
         'label'),
        verbose=verbose,
        is_sct_binary=True,
    )

    # come back
    os.chdir(curdir)

    # Generate output files
    path_seg, file_seg, ext_seg = extract_fname(fname_seg)
    fname_seg_labeled = os.path.join(path_output,
                                     file_seg + '_labeled' + ext_seg)
    printv('\nGenerate output files...', verbose)
    generate_output_file(os.path.join(path_tmp, "segmentation_labeled.nii"),
                         fname_seg_labeled)
    generate_output_file(
        os.path.join(path_tmp, "segmentation_labeled_disc.nii"),
        os.path.join(path_output, file_seg + '_labeled_discs' + ext_seg))
    # copy straightening files in case subsequent SCT functions need them
    for fname in straightening_files:
        generate_output_file(os.path.join(path_tmp, fname),
                             os.path.join(path_output, fname),
                             verbose=verbose)

    # Remove temporary files
    if remove_temp_files == 1:
        printv('\nRemove temporary files...', verbose)
        rmtree(path_tmp)

    # Generate QC report
    if param.path_qc is not None:
        # the QC overlay uses the labeled segmentation that was just written to the output folder
        generate_qc(fname_in,
                    fname_seg=fname_seg_labeled,
                    args=argv,
                    path_qc=os.path.abspath(arguments.qc),
                    dataset=arguments.qc_dataset,
                    subject=arguments.qc_subject,
                    process='sct_label_vertebrae')

    display_viewer_syntax([fname_in, fname_seg_labeled],
                          colormaps=['', 'subcortical'],
                          opacities=['1', '0.5'])
def main(args=None):
    """
    Command-line entry point: apply one label operation to the input image.

    Exactly one operation is executed, the first one whose argument is set (in the order of the
    if/elif chain below). The `display` and `MSE` operations only print a result and return;
    every other operation produces an image that is saved to the output file name, after which
    the viewer syntax is printed and, if requested, a QC report is generated.

    :param args: list of command-line arguments. When empty or None, `sys.argv[1:]` is parsed,
                 or '--help' is passed to the parser if no command-line arguments were given.
    """
    parser = get_parser()
    if args:
        arguments = parser.parse_args(args)
    else:
        arguments = parser.parse_args(args=None if sys.argv[1:] else ['--help'])

    verbosity = arguments.v
    init_sct(log_level=verbosity, update=True)  # Update log level

    input_filename = arguments.i
    output_fname = arguments.o

    img = Image(input_filename)
    dtype = None

    if arguments.add is not None:
        value = arguments.add
        out = sct_labels.add(img, value)
    elif arguments.create is not None:
        labels = arguments.create
        out = sct_labels.create_labels_empty(img, labels)
    elif arguments.create_add is not None:
        labels = arguments.create_add
        out = sct_labels.create_labels(img, labels)
    elif arguments.create_seg is not None:
        labels = arguments.create_seg
        out = sct_labels.create_labels_along_segmentation(img, labels)
    elif arguments.cubic_to_point:
        out = sct_labels.cubic_to_point(img)
    elif arguments.display:
        display_voxel(img, verbosity)
        return
    elif arguments.increment:
        out = sct_labels.increment_z_inverse(img)
    elif arguments.disc is not None:
        ref = Image(arguments.disc)
        out = sct_labels.labelize_from_discs(img, ref)
    elif arguments.vert_body is not None:
        levels = arguments.vert_body
        if len(levels) == 1 and levels[0] == 0:
            levels = None  # all levels
        out = sct_labels.label_vertebrae(img, levels)
    elif arguments.vert_continuous:
        out = sct_labels.continuous_vertebral_levels(img)
        dtype = 'float32'
    elif arguments.MSE is not None:
        ref = Image(arguments.MSE)
        mse = sct_labels.compute_mean_squared_error(img, ref)
        printv(f"Computed MSE: {mse}")
        return
    elif arguments.remove_reference is not None:
        ref = Image(arguments.remove_reference)
        out = sct_labels.remove_missing_labels(img, ref)
    elif arguments.remove_sym is not None:
        # first pass use img as source
        # NOTE: the reference must come from `remove_sym`; `remove_reference` is necessarily
        # None in this branch, since the previous `elif` was not taken.
        ref = Image(arguments.remove_sym)
        out = sct_labels.remove_missing_labels(img, ref)

        # second pass use previous pass result as reference
        ref_out = sct_labels.remove_missing_labels(ref, out)
        ref_out.save(path=ref.absolutepath)
    elif arguments.remove is not None:
        labels = arguments.remove
        out = sct_labels.remove_labels_from_image(img, labels)
    elif arguments.keep is not None:
        labels = arguments.keep
        out = sct_labels.remove_other_labels_from_image(img, labels)
    elif arguments.create_viewer is not None:
        msg = "" if arguments.msg is None else f"{arguments.msg}\n"
        if arguments.ilabel is not None:
            input_labels_img = Image(arguments.ilabel)
            out = launch_manual_label_gui(img, input_labels_img, parse_num_list(arguments.create_viewer), msg)
        else:
            out = launch_sagittal_viewer(img, parse_num_list(arguments.create_viewer), msg)
    else:
        # Without this guard `out` would be unbound below and the script would die with an
        # obscure UnboundLocalError instead of a usage message.
        parser.error("You need to specify an operation to do on the input image.")

    printv("Generating output files...")
    out.save(path=output_fname, dtype=dtype)
    display_viewer_syntax([input_filename, output_fname])

    if arguments.qc is not None:
        generate_qc(fname_in1=input_filename, fname_seg=output_fname, args=args,
                    path_qc=os.path.abspath(arguments.qc), dataset=arguments.qc_dataset,
                    subject=arguments.qc_subject, process='sct_label_utils')