def apply_transfo(im_src, im_dest, warp, interp='spline', rm_tmp=True):
    """
    Apply a warping field to a source image, using a destination image as reference space.

    The images are saved in a fresh temporary folder, ``sct_apply_transfo`` is run there,
    and the result is loaded back as an ``Image``.

    :param im_src: Image: source image (only its ``save`` method is used here)
    :param im_dest: Image: destination image (only its ``save`` method is used here)
    :param warp: str: path to the warping field; it is copied into the temporary folder
    :param interp: str: interpolation, forwarded to ``sct_apply_transfo`` through ``-x``
    :param rm_tmp: bool: remove the temporary folder once the result has been loaded
    :return: Image: the image loaded from the ``_reg``-suffixed source file name
    """
    # create tmp dir
    tmp_dir = tmp_create()
    # copy warping field to tmp dir and keep only its name relative to that folder
    # (presumably file name + extension -- depends on what extract_fname returns)
    copy(warp, tmp_dir)
    warp = ''.join(extract_fname(warp)[1:])
    # go to tmp dir; the try/finally guarantees the caller's working directory is
    # restored even when saving or applying the transformation fails
    curdir = os.getcwd()
    os.chdir(tmp_dir)
    try:
        # save source and destination images
        fname_src = 'src.nii.gz'
        im_src.save(fname_src)
        fname_dest = 'dest.nii.gz'
        im_dest.save(fname_dest)
        # apply warping field. No '-o' is given: the code relies on the default output
        # name of sct_apply_transfo matching add_suffix(fname_src, '_reg')
        fname_src_reg = add_suffix(fname_src, '_reg')
        sct_apply_transfo.main(
            argv=['-i', fname_src, '-d', fname_dest, '-w', warp, '-x', interp])

        im_src_reg = Image(fname_src_reg)
    finally:
        # get out of tmp dir
        os.chdir(curdir)
    if rm_tmp:
        # remove tmp dir
        rmtree(tmp_dir)
    # return res image
    return im_src_reg
def visualize_warp(im_warp: Image, im_grid: Image = None, step=3, rm_tmp=True):
    """
    Warp a grid image with a warping field so the deformation can be inspected visually.

    When no grid is supplied, a binary grid is generated in a temporary folder: the
    (x, y) plane is tiled with ``step`` x ``step`` squares whose last row and last column
    are set to 1, repeated on every z slice, then resampled with ``sct_resample``.
    The warped grid is written next to the warping field, named
    ``<grid file name>_<warp file name><warp extension>``.

    :param im_warp: Image: warping field; its ``absolutepath``, ``data`` shape and ``hdr`` are used
    :param im_grid: Image: optional existing grid; only its ``absolutepath`` is used
    :param step: int: size (in voxels) of one grid square
    :param rm_tmp: bool: remove the temporary folder, if one was created
    :return: None
    """
    fname_warp = im_warp.absolutepath
    # A temporary folder only exists when the grid is generated here. Tracking that
    # explicitly avoids the UnboundLocalError the cleanup used to raise when a grid
    # was supplied by the caller.
    tmp_dir = None
    if im_grid:
        fname_grid = im_grid.absolutepath
    else:
        tmp_dir = tmp_create()
        nx, ny, nz = im_warp.data.shape[0:3]
        curdir = os.getcwd()
        os.chdir(tmp_dir)
        try:
            # one grid cell: zeros with its last row and last column set to 1
            sq = np.zeros((step, step))
            sq[step - 1] = 1
            sq[:, step - 1] = 1
            dat = np.zeros((nx, ny, nz))
            # Only complete cells are drawn (partial cells at the border stay empty).
            # Whether a cell is complete does not depend on z, so each cell is written
            # to all slices at once.
            for i in range(0, nx - step + 1, step):
                for j in range(0, ny - step + 1, step):
                    dat[i:i + step, j:j + step, :] = sq[:, :, np.newaxis]
            im_grid = Image(param=dat)
            # reuse the header of the warping field for the grid
            grid_hdr = im_warp.hdr
            im_grid.hdr = grid_hdr
            fname_grid = 'grid_' + str(step) + '.nii.gz'
            im_grid.save(fname_grid)
            fname_grid_resample = add_suffix(fname_grid, '_resample')
            sct_resample.main(argv=['-i', fname_grid, '-f', '3x3x1', '-x', 'nn', '-o', fname_grid_resample])
            fname_grid = os.path.join(tmp_dir, fname_grid_resample)
        finally:
            # always give the caller its working directory back
            os.chdir(curdir)
    path_warp, file_warp, ext_warp = extract_fname(fname_warp)
    grid_warped = os.path.join(path_warp, extract_fname(fname_grid)[1] + '_' + file_warp + ext_warp)
    # the grid is both the input and the destination space of the transformation
    sct_apply_transfo.main(argv=['-i', fname_grid, '-d', fname_grid, '-w', fname_warp, '-o', grid_warped])
    if rm_tmp and tmp_dir is not None:
        rmtree(tmp_dir)
Exemple #3
0
def register_data(im_src, im_dest, param_reg, path_copy_warp=None, rm_tmp=True):
    '''
    Register a source image onto a destination image with sct_register_multimodal.

    Both images are binarized to provide the segmentations given to the registration,
    everything is saved in a temporary folder, and the registration is run there.

    Parameters
    ----------
    im_src: class Image: source image
    im_dest: class Image: destination image
    param_reg: str: registration parameter (forwarded through '-param')
    path_copy_warp: path: existing folder to copy the warping fields to. Ignored when
        None or when it is not a directory.
    rm_tmp: bool: remove the temporary folder when done

    Returns
    -------
    im_src_reg: class Image: source image registered on destination image
    fname_src2dest: str: file name (not a path) of the source-to-destination warping field
    fname_dest2src: str: file name (not a path) of the destination-to-source warping field

    The two names are returned even when nothing was copied. The files themselves are
    only found in path_copy_warp (when the copy happened) or in the temporary folder
    (when rm_tmp is False).
    '''
    # im_src and im_dest are already preprocessed (in theory: im_dest = mean_image)
    # binarize images to get seg
    im_src_seg = binarize(im_src, thr_min=1, thr_max=1)
    im_dest_seg = binarize(im_dest)
    # create tmp dir and go in it
    tmp_dir = tmp_create()
    curdir = os.getcwd()
    os.chdir(tmp_dir)
    try:
        # save image and seg
        fname_src = 'src.nii.gz'
        im_src.save(fname_src)
        fname_src_seg = 'src_seg.nii.gz'
        im_src_seg.save(fname_src_seg)
        fname_dest = 'dest.nii.gz'
        im_dest.save(fname_dest)
        fname_dest_seg = 'dest_seg.nii.gz'
        im_dest_seg.save(fname_dest_seg)
        # do registration using param_reg
        sct_register_multimodal.main(args=['-i', fname_src,
                                           '-d', fname_dest,
                                           '-iseg', fname_src_seg,
                                           '-dseg', fname_dest_seg,
                                           '-param', param_reg])

        # get registration result
        fname_src_reg = add_suffix(fname_src, '_reg')
        im_src_reg = Image(fname_src_reg)
    finally:
        # get out of tmp dir, even if the registration failed
        os.chdir(curdir)

    # Names of the warping fields. They are computed unconditionally because they are
    # part of the return value: previously they were only defined inside the copy
    # branch, so the function raised UnboundLocalError whenever path_copy_warp was
    # None or not an existing directory.
    file_src = extract_fname(fname_src)[1]
    file_dest = extract_fname(fname_dest)[1]
    fname_src2dest = 'warp_' + file_src + '2' + file_dest + '.nii.gz'
    fname_dest2src = 'warp_' + file_dest + '2' + file_src + '.nii.gz'

    # copy warping fields
    if path_copy_warp is not None and os.path.isdir(os.path.abspath(path_copy_warp)):
        path_copy_warp = os.path.abspath(path_copy_warp)
        copy(os.path.join(tmp_dir, fname_src2dest), path_copy_warp)
        copy(os.path.join(tmp_dir, fname_dest2src), path_copy_warp)

    if rm_tmp:
        # remove tmp dir
        rmtree(tmp_dir)
    # return res image and warping field names
    return im_src_reg, fname_src2dest, fname_dest2src
def main():
    """
    Command-line entry point: run DetectPMJ on the input image.

    Arguments are read from ``sys.argv`` (the help is shown when none are given).
    Optionally generates a QC report and prints the viewer syntax for the result.
    """
    arguments = get_parser().parse_args(args=None if sys.argv[1:] else ['--help'])

    # Mandatory inputs
    fname_in = arguments.i
    contrast = arguments.c

    # Optional segmentation/centerline: dropped (with a warning) when the file is missing
    fname_seg = arguments.s
    if fname_seg is not None and not os.path.isfile(fname_seg):
        fname_seg = None
        printv('WARNING: -s input file: "' + arguments.s + '" does not exist.\nDetecting PMJ without using segmentation information', 1, 'warning')

    # Output folder: defaults to the current directory, created when absent
    if arguments.ofolder is None:
        path_results = '.'
    else:
        path_results = arguments.ofolder
        if os.path.exists(path_results) and not os.path.isdir(path_results):
            printv("ERROR output directory %s is not a valid directory" % path_results, 1, 'error')
        if not os.path.exists(path_results):
            os.makedirs(path_results)

    path_qc = arguments.qc
    rm_tmp = bool(arguments.r)  # whether to delete the temporary folder
    verbose = arguments.v
    init_sct(log_level=verbose, update=True)  # Update log level

    # Build the detector and run it
    detector = DetectPMJ(fname_im=fname_in,
                         contrast=contrast,
                         fname_seg=fname_seg,
                         path_out=path_results,
                         verbose=verbose)
    fname_out, tmp_dir = detector.apply()

    if rm_tmp:
        rmtree(tmp_dir)

    # Nothing to show when the detector returned no output file
    if fname_out is None:
        return

    if path_qc is not None:
        from spinalcordtoolbox.reports.qc import generate_qc
        generate_qc(fname_in, fname_seg=fname_out, args=sys.argv[1:], path_qc=os.path.abspath(path_qc), process='sct_detect_pmj')

    display_viewer_syntax([fname_in, fname_out], colormaps=['gray', 'red'])
def main(argv=None):
    """
    Entry point: extract GLCM texture features with ExtractGLCM.

    :param argv: list of command-line arguments handed to the parser
    :return: None
    """
    arguments = get_parser().parse_args(argv)
    verbose = arguments.v
    set_loglevel(verbose=verbose)

    # Parameter holders, filled from the user's arguments below
    param = Param()
    param_glcm = ParamGLCM()

    param.fname_im = arguments.i
    param.fname_seg = arguments.m

    if arguments.ofolder is not None:
        param.path_results = arguments.ofolder

    # The output location must be a directory; it is created when absent
    if os.path.exists(param.path_results) and not os.path.isdir(param.path_results):
        printv("ERROR output directory %s is not a valid directory" % param.path_results, 1, 'error')
    if not os.path.exists(param.path_results):
        os.makedirs(param.path_results)

    # GLCM options: only override the defaults the user actually set
    for option in ('feature', 'distance', 'angle'):
        value = getattr(arguments, option)
        if value is not None:
            setattr(param_glcm, option, value)

    if arguments.dim is not None:
        param.dim = arguments.dim
    if arguments.r is not None:
        param.rm_tmp = bool(arguments.r)

    # Build the extractor and run it
    glcm = ExtractGLCM(param=param, param_glcm=param_glcm)
    fname_out_lst = glcm.extract()

    # Clean up the extractor's temporary folder
    if param.rm_tmp:
        rmtree(glcm.tmp_dir)

    # Tell the user how to open the results
    overlays = ' -cm red-yellow -a 70.0 '.join(fname_out_lst)
    viewer_cmd = 'fsleyes ' + arguments.i + ' ' + overlays + ' -cm Red-Yellow -a 70.0 & \n'
    printv('\nDone! To view results, type:', param.verbose)
    printv(viewer_cmd, param.verbose, 'info')
def main(argv=None):
    """
    Entry point: run the lesion analysis (AnalyzeLeion) on a lesion mask.

    :param argv: list of command-line arguments handed to the parser
    :return: None
    """
    arguments = get_parser().parse_args(argv)
    verbose = arguments.v
    set_loglevel(verbose=verbose)

    fname_mask, fname_sc, fname_ref = arguments.m, arguments.s, arguments.i

    # TODO: validate in the parser that the template path (-f) and the output folder
    #  are directories; no such check is done here.
    path_template = arguments.f

    # Output folder, created when absent
    path_results = arguments.ofolder
    if not os.path.exists(path_results):
        os.makedirs(path_results)

    # Temporary files are removed unless the user says otherwise
    rm_tmp = True if arguments.r is None else bool(arguments.r)

    # Build the analysis object and run it
    lesion_obj = AnalyzeLeion(fname_mask=fname_mask,
                              fname_sc=fname_sc,
                              fname_ref=fname_ref,
                              path_template=path_template,
                              path_ofolder=path_results,
                              verbose=verbose)
    lesion_obj.analyze()

    if rm_tmp:
        rmtree(lesion_obj.tmp_dir)

    # Viewer command: the mask is listed first only when a reference image was given
    printv('\nDone! To view the labeled lesion file (one value per lesion), type:', verbose)
    if fname_ref is None:
        cmd_start = 'fsleyes '
    else:
        cmd_start = 'fsleyes ' + fname_mask + ' '
    printv(cmd_start + os.path.join(path_results, lesion_obj.fname_label) + ' -cm red-yellow -a 70.0 & \n', verbose, 'info')
Exemple #7
0
def visualize_warp(fname_warp, fname_grid=None, step=3, rm_tmp=True):
    """
    Warp a grid image with a warping field so the deformation can be inspected visually.

    When no grid file is supplied, a binary grid is generated in a temporary folder:
    the (x, y) plane is tiled with ``step`` x ``step`` squares whose last row and last
    column are set to 1, repeated on every z slice, then resampled with ``sct_resample``.
    The warped grid is written next to the warping field, named
    ``<grid file name>_<warp file name><warp extension>``.

    :param fname_warp: str: path to the warping field
    :param fname_grid: str: optional path to an existing grid image
    :param step: int: size (in voxels) of one grid square
    :param rm_tmp: bool: remove the temporary folder, if one was created
    :return: None
    """
    # A temporary folder only exists when the grid is generated here. Tracking that
    # explicitly avoids the UnboundLocalError the cleanup used to raise when a grid
    # file was supplied by the caller.
    tmp_dir = None
    if fname_grid is None:
        from numpy import zeros
        tmp_dir = tmp_create()
        # load the warping field before changing directory, so a relative path still resolves
        im_warp = Image(fname_warp)
        # Grid dimensions come from the loaded image rather than from parsing the text
        # output of `fslhd`: no external FSL call and no whitespace-sensitive parsing.
        nx, ny, nz = im_warp.data.shape[0:3]
        curdir = os.getcwd()
        os.chdir(tmp_dir)
        try:
            # one grid cell: zeros with its last row and last column set to 1
            sq = zeros((step, step))
            sq[step - 1] = 1
            sq[:, step - 1] = 1
            dat = zeros((nx, ny, nz))
            # Only complete cells are drawn (partial cells at the border stay empty).
            # Whether a cell is complete does not depend on z, so each cell is written
            # to all slices at once.
            for i in range(0, nx - step + 1, step):
                for j in range(0, ny - step + 1, step):
                    dat[i:i + step, j:j + step, :] = sq[:, :, None]
            fname_grid = 'grid_' + str(step) + '.nii.gz'
            im_grid = Image(param=dat)
            # reuse the header of the warping field for the grid
            grid_hdr = im_warp.hdr
            im_grid.hdr = grid_hdr
            im_grid.absolutepath = fname_grid
            im_grid.save()
            fname_grid_resample = add_suffix(fname_grid, '_resample')
            run_proc([
                'sct_resample', '-i', fname_grid, '-f', '3x3x1', '-x', 'nn', '-o',
                fname_grid_resample
            ])
            fname_grid = os.path.join(tmp_dir, fname_grid_resample)
        finally:
            # always give the caller its working directory back
            os.chdir(curdir)
    path_warp, file_warp, ext_warp = extract_fname(fname_warp)
    grid_warped = os.path.join(
        path_warp,
        extract_fname(fname_grid)[1] + '_' + file_warp + ext_warp)
    # the grid is both the input and the destination space of the transformation
    run_proc([
        'sct_apply_transfo', '-i', fname_grid, '-d', fname_grid, '-w',
        fname_warp, '-o', grid_warped
    ])
    if rm_tmp and tmp_dir is not None:
        rmtree(tmp_dir)
def main(args=None):
    """
    Separate the b=0 and the DWI volumes of a 4D diffusion dataset.

    The data are converted and split along the 4th dimension inside a temporary folder;
    the volumes identified as b=0 and as DWI by ``identify_b0`` are concatenated into two
    files and, when averaging is requested, their mean along t is computed with
    ``sct_maths``. Results are then written to the output folder.

    :param args: list of command-line arguments; when empty or None, ``sys.argv`` is
        parsed instead (and the help is shown if it holds no argument)
    :return: tuple (fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean) of absolute output
        paths. The two ``*_mean`` paths are returned even when averaging was not requested,
        in which case those files are not generated here.
    """
    # initialize parameters
    param = Param()
    # parse arguments
    parser = get_parser()
    if args:
        arguments = parser.parse_args(args)
    else:
        arguments = parser.parse_args(args=None if sys.argv[1:] else ['--help'])

    fname_data = arguments.i
    fname_bvecs = arguments.bvec
    average = arguments.a
    verbose = int(arguments.v)
    init_sct(log_level=verbose, update=True)  # Update log level
    remove_temp_files = arguments.r
    path_out = arguments.ofolder

    fname_bvals = arguments.bval
    if arguments.bvalmin:
        param.bval_min = arguments.bvalmin

    # Initialization
    start_time = time.time()

    printv('\nInput parameters:', verbose)
    printv('  input file ............' + fname_data, verbose)
    printv('  bvecs file ............' + fname_bvecs, verbose)
    # the bvals file is optional (see the check below): str() keeps this line from
    # raising TypeError when it was not provided
    printv('  bvals file ............' + str(fname_bvals), verbose)
    printv('  average ...............' + str(average), verbose)

    # Get full path
    fname_data = os.path.abspath(fname_data)
    fname_bvecs = os.path.abspath(fname_bvecs)
    if fname_bvals:
        fname_bvals = os.path.abspath(fname_bvals)

    # Extract path, file and extension
    path_data, file_data, ext_data = extract_fname(fname_data)

    # create temporary folder
    path_tmp = tmp_create(basename="dmri_separate")

    # copy files into tmp folder and convert to nifti
    printv('\nCopy files into temporary folder...', verbose)
    ext = '.nii'
    dmri_name = 'dmri'
    b0_name = file_data + '_b0'
    b0_mean_name = b0_name + '_mean'
    dwi_name = file_data + '_dwi'
    dwi_mean_name = dwi_name + '_mean'

    if not convert(fname_data, os.path.join(path_tmp, dmri_name + ext)):
        printv('ERROR in convert.', 1, 'error')
    copy(fname_bvecs, os.path.join(path_tmp, "bvecs"), verbose=verbose)

    # go to tmp folder; the try/finally restores the caller's working directory even
    # when one of the processing steps fails
    curdir = os.getcwd()
    os.chdir(path_tmp)
    try:
        # Get size of data
        im_dmri = Image(dmri_name + ext)
        printv('\nGet dimensions data...', verbose)
        nx, ny, nz, nt, px, py, pz, pt = im_dmri.dim
        printv('.. ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt), verbose)

        # Identify b=0 and DWI images
        printv(fname_bvals)
        index_b0, index_dwi, nb_b0, nb_dwi = identify_b0(fname_bvecs, fname_bvals, param.bval_min, verbose)

        # Split into T dimension
        printv('\nSplit along T dimension...', verbose)
        im_dmri_split_list = split_data(im_dmri, 3)
        for im_d in im_dmri_split_list:
            im_d.save()

        # Merge b=0 images. The file names below assume the split volumes were saved
        # as '<dmri_name>_T<4-digit index><ext>' -- presumably the naming used by
        # split_data; confirm against its implementation.
        printv('\nMerge b=0...', verbose)
        from sct_image import concat_data
        fnames_b0 = [dmri_name + '_T' + str(index_b0[it]).zfill(4) + ext for it in range(nb_b0)]
        concat_data(fnames_b0, 3).save(b0_name + ext)

        # Average b=0 images
        if average:
            printv('\nAverage b=0...', verbose)
            run_proc(['sct_maths', '-i', b0_name + ext, '-o', b0_mean_name + ext, '-mean', 't'], verbose)

        # Merge DWI
        fnames_dwi = [dmri_name + '_T' + str(index_dwi[it]).zfill(4) + ext for it in range(nb_dwi)]
        concat_data(fnames_dwi, 3).save(dwi_name + ext)

        # Average DWI images
        if average:
            printv('\nAverage DWI...', verbose)
            run_proc(['sct_maths', '-i', dwi_name + ext, '-o', dwi_mean_name + ext, '-mean', 't'], verbose)
    finally:
        # come back
        os.chdir(curdir)

    # Generate output files
    fname_b0 = os.path.abspath(os.path.join(path_out, b0_name + ext_data))
    fname_dwi = os.path.abspath(os.path.join(path_out, dwi_name + ext_data))
    fname_b0_mean = os.path.abspath(os.path.join(path_out, b0_mean_name + ext_data))
    fname_dwi_mean = os.path.abspath(os.path.join(path_out, dwi_mean_name + ext_data))
    printv('\nGenerate output files...', verbose)
    generate_output_file(os.path.join(path_tmp, b0_name + ext), fname_b0, verbose=verbose)
    generate_output_file(os.path.join(path_tmp, dwi_name + ext), fname_dwi, verbose=verbose)
    if average:
        generate_output_file(os.path.join(path_tmp, b0_mean_name + ext), fname_b0_mean, verbose=verbose)
        generate_output_file(os.path.join(path_tmp, dwi_mean_name + ext), fname_dwi_mean, verbose=verbose)

    # Remove temporary files
    if remove_temp_files == 1:
        printv('\nRemove temporary files...', verbose)
        rmtree(path_tmp, verbose=verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    printv('\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's', verbose)

    return fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean
Exemple #9
0
def merge_images(list_fname_src, fname_dest, list_fname_warp, param):
    """
    Merge multiple source images onto destination space. All images are warped to the destination space and then added.
    To deal with overlap during merging (e.g. one voxel in destination image is shared with two input images), the
    resulting voxel is divided by the sum of the partial volume of each image. For example, if src(x,y,z)=1 is mapped to
    dest(i,j,k) with a partial volume of 0.5 (because destination voxel is bigger), then its value after linear interpolation
    will be 0.5. To account for partial volume, the resulting voxel will be: dest(i,j,k) = 0.5*0.5/0.5 = 0.5.
    Now, if two voxels overlap in the destination space, let's say: src(x,y,z)=1 and src2'(x',y',z')=1, then the
    resulting value will be: dest(i,j,k) = (0.5*0.5 + 0.5*0.5) / (0.5+0.5) = 0.5. So this function acts like a weighted
    average operator, only in destination voxels that share multiple source voxels.

    Intermediate files (warped sources, binary masks, warped masks) are written inside a temporary folder, so that
    nothing is left in the caller's working directory and ``param.rm_tmp`` really controls their removal.

    Parameters
    ----------
    list_fname_src : list of file names of the source images
    fname_dest : file name of the destination image, which defines the output space
    list_fname_warp : list of warping fields (source --> destination), one per source image and in the same order
    param : object providing ``interp``, ``verbose``, ``almost_zero``, ``fname_out`` and ``rm_tmp``

    Returns
    -------
    None. The merged image is saved to ``param.fname_out``.
    """
    # create temporary folder (holds every intermediate file)
    path_tmp = tmp_create()

    # get dimensions of destination file
    nii_dest = Image(fname_dest)

    # initialize variables: one 3D volume per source image, stacked along the 4th axis
    shape_4d = [
        nii_dest.dim[0], nii_dest.dim[1], nii_dest.dim[2],
        len(list_fname_src)
    ]
    data = np.zeros(shape_4d)
    partial_volume = np.zeros(shape_4d)

    # loop across files
    for i_file, fname_src in enumerate(list_fname_src):
        fname_warp = list_fname_warp[i_file]
        fname_src_dest = os.path.join(
            path_tmp, 'src_' + str(i_file) + '_template.nii.gz')
        fname_src_bin = os.path.join(
            path_tmp, f"src_{i_file}native_bin.nii.gz")
        fname_src_pv = os.path.join(
            path_tmp, 'src_' + str(i_file) + '_template_partialVolume.nii.gz')

        # apply transformation src --> dest
        sct_apply_transfo.main(argv=[
            '-i', fname_src, '-d', fname_dest, '-w', fname_warp,
            '-x', param.interp, '-o', fname_src_dest, '-v',
            str(param.verbose)
        ])

        # create binary mask from input file (binarize is called with param.almost_zero)
        img = Image(fname_src)
        out = img.copy()
        out.data = binarize(out.data, param.almost_zero)
        out.save(path=fname_src_bin)

        # apply transformation to binary mask to compute partial volume
        sct_apply_transfo.main(argv=[
            '-i', fname_src_bin, '-d', fname_dest,
            '-w', fname_warp, '-x', param.interp, '-o', fname_src_pv
        ])

        # open data
        data[:, :, :, i_file] = Image(fname_src_dest).data
        partial_volume[:, :, :, i_file] = Image(fname_src_pv).data

    # Merge files using partial volume information. Voxels whose summed partial volume
    # is zero are set to zero directly instead of being divided: this avoids the
    # divide-by-zero runtime warnings and prevents a non-zero numerator over a zero
    # denominator from turning into a huge finite value after nan_to_num.
    weighted_sum = np.sum(data * partial_volume, axis=3)
    total_pv = np.sum(partial_volume, axis=3)
    data_merge = np.divide(weighted_sum, total_pv,
                           out=np.zeros_like(weighted_sum),
                           where=total_pv != 0)
    # still sanitize non-finite values coming from the inputs themselves
    data_merge = np.nan_to_num(data_merge)

    # write result in file
    nii_dest.data = data_merge
    nii_dest.save(param.fname_out)

    # remove temporary folder
    if param.rm_tmp:
        rmtree(path_tmp)
Exemple #10
0
def main(argv=None):
    """
    Entry point: smooth an image along the spinal cord.

    The image is reoriented to RPI, straightened using the centerline/segmentation,
    smoothed, and un-straightened with the reverse warping field. Voxels that are zero
    after un-straightening are filled with the original image values. The warping
    fields, the straight reference image and a cache signature are kept in the launch
    directory so a later run on the same inputs can skip the straightening.

    :param argv: list of command-line arguments handed to the parser
    :return: None
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    # Initialization
    param = Param()
    start_time = time.time()

    fname_anat = arguments.i
    fname_centerline = arguments.s
    param.algo_fitting = arguments.algo_fitting

    # NOTE(review): `sigmas` stays undefined when -smooth is None; presumably the parser
    # always supplies a value (default) -- confirm in get_parser().
    if arguments.smooth is not None:
        sigmas = arguments.smooth
    remove_temp_files = arguments.r
    if arguments.o is not None:
        fname_out = arguments.o
    else:
        fname_out = extract_fname(fname_anat)[1] + '_smooth.nii'

    # Display arguments
    printv('\nCheck input arguments...')
    printv('  Volume to smooth .................. ' + fname_anat)
    printv('  Centerline ........................ ' + fname_centerline)
    printv('  Sigma (mm) ........................ ' + str(sigmas))
    printv('  Verbose ........................... ' + str(verbose))

    # Check that input is 3D:
    nx, ny, nz, nt, px, py, pz, pt = Image(fname_anat).dim
    dim = 4  # by default, will be adjusted later
    if nt == 1:
        dim = 3
    if nz == 1:
        dim = 2
    if dim == 4:
        printv(
            'WARNING: the input image is 4D, please split your image to 3D before smoothing spinalcord using :\n'
            'sct_image -i ' + fname_anat + ' -split t -o ' + fname_anat,
            verbose, 'warning')
        printv('4D images not supported, aborting ...', verbose, 'error')

    # Extract path/file/extension
    path_anat, file_anat, ext_anat = extract_fname(fname_anat)
    path_centerline, file_centerline, ext_centerline = extract_fname(
        fname_centerline)

    path_tmp = tmp_create(basename="smooth_spinalcord")

    # Copying input data to tmp folder
    printv('\nCopying input data to tmp folder and convert to nii...', verbose)
    copy(fname_anat, os.path.join(path_tmp, "anat" + ext_anat))
    copy(fname_centerline, os.path.join(path_tmp,
                                        "centerline" + ext_centerline))

    # Files that must all be present in the launch directory for the cached
    # straightening to be reusable
    cached_files = ('warp_curve2straight.nii.gz',
                    'warp_straight2curve.nii.gz',
                    'straight_ref.nii.gz')

    # go to tmp folder; the try/finally restores the caller's working directory even
    # when one of the processing steps fails
    curdir = os.getcwd()
    os.chdir(path_tmp)
    try:
        # convert to nii format
        im_anat = convert(Image('anat' + ext_anat))
        im_anat.save('anat.nii', mutable=True, verbose=verbose)
        im_centerline = convert(Image('centerline' + ext_centerline))
        im_centerline.save('centerline.nii', mutable=True, verbose=verbose)

        # Change orientation of the input image into RPI
        printv('\nOrient input volume to RPI orientation...')

        img_anat_rpi = Image("anat.nii").change_orientation("RPI")
        fname_anat_rpi = add_suffix(img_anat_rpi.absolutepath, "_rpi")
        img_anat_rpi.save(path=fname_anat_rpi, mutable=True)

        # Change orientation of the centerline into RPI
        printv('\nOrient centerline to RPI orientation...')

        img_centerline_rpi = Image("centerline.nii").change_orientation("RPI")
        fname_centerline_rpi = add_suffix(img_centerline_rpi.absolutepath, "_rpi")
        img_centerline_rpi.save(path=fname_centerline_rpi, mutable=True)

        # Straighten the spinal cord
        printv('\nStraighten the spinal cord using centerline/segmentation...',
               verbose)
        cache_sig = cache_signature(
            input_files=[fname_anat_rpi, fname_centerline_rpi],
            input_params={"x": "spline"})
        cachefile = os.path.join(curdir, "straightening.cache")
        if cache_valid(cachefile, cache_sig) and all(
                os.path.isfile(os.path.join(curdir, fname))
                for fname in cached_files):
            # if they exist, copy them into current folder
            printv('Reusing existing warping field which seems to be valid',
                   verbose, 'warning')
            for fname in cached_files:
                copy(os.path.join(curdir, fname), fname)
            # apply straightening
            run_proc([
                'sct_apply_transfo', '-i', fname_anat_rpi, '-w',
                'warp_curve2straight.nii.gz', '-d', 'straight_ref.nii.gz', '-o',
                'anat_rpi_straight.nii', '-x', 'spline'
            ], verbose)
        else:
            run_proc([
                'sct_straighten_spinalcord', '-i', fname_anat_rpi, '-o',
                'anat_rpi_straight.nii', '-s', fname_centerline_rpi, '-x',
                'spline', '-param', 'algo_fitting=' + param.algo_fitting
            ], verbose)
            cache_save(cachefile, cache_sig)
            # move warping fields locally (to use caching next time)
            copy('warp_curve2straight.nii.gz',
                 os.path.join(curdir, 'warp_curve2straight.nii.gz'))
            copy('warp_straight2curve.nii.gz',
                 os.path.join(curdir, 'warp_straight2curve.nii.gz'))
            # The cache check above also requires straight_ref.nii.gz in the launch
            # directory; without this copy the cache could never be reused. Guarded so
            # that a missing reference file does not abort the current run.
            if os.path.isfile('straight_ref.nii.gz'):
                copy('straight_ref.nii.gz',
                     os.path.join(curdir, 'straight_ref.nii.gz'))

        # Smooth the straightened image
        printv('\nSmooth the straightened image...')

        img = Image("anat_rpi_straight.nii")
        out = img.copy()

        # a single sigma applies to every dimension
        if len(sigmas) == 1:
            sigmas = [sigmas[0]] * len(img.data.shape)
        elif len(sigmas) != len(img.data.shape):
            raise ValueError(
                "-smooth need the same number of inputs as the number of image dimension OR only one input"
            )

        # divide each sigma by img.dim[4:7] (px, py, pz in the 8-value dim tuple)
        sigmas = [sigmas[i] / img.dim[i + 4] for i in range(3)]
        out.data = smooth(out.data, sigmas)
        out.save(path="anat_rpi_straight_smooth.nii")

        # Apply the reversed warping field to get back the curved spinal cord
        printv(
            '\nApply the reversed warping field to get back the curved spinal cord...'
        )
        run_proc([
            'sct_apply_transfo', '-i', 'anat_rpi_straight_smooth.nii', '-o',
            'anat_rpi_straight_smooth_curved.nii', '-d', 'anat.nii', '-w',
            'warp_straight2curve.nii.gz', '-x', 'spline'
        ], verbose)

        # replace zeroed voxels by original image (issue #937)
        printv('\nReplace zeroed voxels by original image...', verbose)
        nii_smooth = Image('anat_rpi_straight_smooth_curved.nii')
        data_smooth = nii_smooth.data
        data_input = Image('anat.nii').data
        indzero = np.where(data_smooth == 0)
        data_smooth[indzero] = data_input[indzero]
        nii_smooth.data = data_smooth
        nii_smooth.save('anat_rpi_straight_smooth_curved_nonzero.nii')
    finally:
        # come back
        os.chdir(curdir)

    # Generate output file
    printv('\nGenerate output file...')
    generate_output_file(
        os.path.join(path_tmp, "anat_rpi_straight_smooth_curved_nonzero.nii"),
        fname_out)

    # Remove temporary files
    if remove_temp_files == 1:
        printv('\nRemove temporary files...')
        rmtree(path_tmp)

    # Display elapsed time
    elapsed_time = time.time() - start_time
    printv('\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) +
           's\n')

    display_viewer_syntax([fname_anat, fname_out], verbose=verbose)
Exemple #11
0
def check_and_correct_segmentation(fname_segmentation,
                                   fname_centerline,
                                   folder_output='',
                                   threshold_distance=5.0,
                                   remove_temp_files=1,
                                   verbose=0):
    """
    This function takes the outputs of isct_propseg (centerline and segmentation) and check if the centerline of the
    segmentation is coherent with the centerline provided by the isct_propseg, especially on the edges (related
    to issue #1074).

    A slice is flagged when it contains more than one object, or when the center of mass of its single object is
    at least ``threshold_distance`` away from the nearest centerline point. Starting from the middle of the
    centerline and moving outwards in both directions, the first flagged slice and every slice beyond it are
    zeroed. The corrected segmentation overwrites ``fname_segmentation``, in its original orientation. When the
    centerline image is empty, a warning is printed and the segmentation is left untouched.

    Args:
        fname_segmentation: filename of binary segmentation (overwritten in place)
        fname_centerline: filename of binary centerline
        folder_output: unused, kept for backward compatibility
        threshold_distance: threshold, in mm, beyond which centerlines are not coherent
        remove_temp_files: if truthy, the temporary folder is removed at the end
        verbose: verbosity passed to printv

    Returns: None
    """
    printv('\nCheck consistency of segmentation...', verbose)
    # creating a temporary folder in which all temporary files will be placed and deleted afterwards
    path_tmp = tmp_create(basename="propseg")
    convert(fname_segmentation,
            os.path.join(path_tmp, "tmp.segmentation.nii.gz"),
            verbose=0)
    convert(fname_centerline,
            os.path.join(path_tmp, "tmp.centerline.nii.gz"),
            verbose=0)
    fname_seg_absolute = os.path.abspath(fname_segmentation)

    # go to tmp folder; the try/finally restores the caller's working directory even on failure
    curdir = os.getcwd()
    os.chdir(path_tmp)
    try:
        # convert segmentation image to RPI
        im_input = Image('tmp.segmentation.nii.gz')
        image_input_orientation = im_input.orientation

        sct_image.main([
            '-i', 'tmp.segmentation.nii.gz', '-setorient', 'RPI',
            '-o', 'tmp.segmentation_RPI.nii.gz', '-v', '0'
        ])
        sct_image.main([
            '-i', 'tmp.centerline.nii.gz', '-setorient', 'RPI',
            '-o', 'tmp.centerline_RPI.nii.gz', '-v', '0'
        ])

        # go through segmentation image, and compare with centerline from propseg
        im_seg = Image('tmp.segmentation_RPI.nii.gz')
        im_centerline = Image('tmp.centerline_RPI.nii.gz')

        # Get size of data
        printv('\nGet data dimensions...', verbose)
        nx, ny, nz, nt, px, py, pz, pt = im_seg.dim

        # extraction of centerline provided by isct_propseg and computation of center of mass for each slice
        # the centerline is defined as the center of the tubular mesh outputed by propseg.
        centerline, key_centerline = {}, []
        for iz in range(nz):
            slice_centerline = im_centerline.data[:, :, iz]
            if np.any(slice_centerline):
                centerline[iz] = ndi.center_of_mass(slice_centerline)
                key_centerline.append(iz)

        if not key_centerline:
            # nothing to compare the segmentation with: leave it untouched rather than
            # crashing on the min/max of an empty list
            printv('WARNING: centerline is empty, segmentation was not checked.',
                   1, 'warning')
        else:
            # key_centerline is filled in increasing z order
            minz_centerline = key_centerline[0]
            maxz_centerline = key_centerline[-1]
            # middle of the centerline extent (not half of its length, which points
            # outside the centerline as soon as it does not start at slice 0)
            mid_slice = (minz_centerline + maxz_centerline) // 2

            # for each slice of the segmentation, check if only one object is present. If not, remove the slice
            # from segmentation. If only one object (the spinal cord) is present in the slice, check if its center
            # of mass is close to the centerline of isct_propseg.
            slices_to_remove = [False] * nz  # flag that decides if the slice must be removed
            for iz in range(minz_centerline, maxz_centerline + 1):
                slice_seg = im_seg.data[:, :, iz]
                label_objects, nb_labels = ndi.label(slice_seg)  # count binary objects in the slice
                if nb_labels > 1:
                    # more than one object in the slice: the slice is removed from the segmentation
                    slices_to_remove[iz] = True
                elif nb_labels == 1:
                    # check if the centerline is coherent with the one from isct_propseg
                    x_seg, y_seg = ndi.center_of_mass(slice_seg)
                    slice_nearest_coord = min(key_centerline, key=lambda z: abs(z - iz))
                    coord_nearest_coord = centerline[slice_nearest_coord]
                    # distance scaled by the px, py, pz values of im_seg.dim
                    distance = np.sqrt(
                        ((x_seg - coord_nearest_coord[0]) * px)**2 +
                        ((y_seg - coord_nearest_coord[1]) * py)**2 +
                        ((iz - slice_nearest_coord) * pz)**2)

                    if distance >= threshold_distance:  # threshold must be adjusted, default is 5 mm
                        slices_to_remove[iz] = True

            # Keep one continuous piece of segmentation around the middle of the centerline:
            # starting from mid-centerline (in both directions), the first True encountered is
            # applied to all following slices.
            slice_to_change = False
            for iz in range(mid_slice, nz):
                if slice_to_change:
                    slices_to_remove[iz] = True
                elif slices_to_remove[iz]:
                    slice_to_change = True

            slice_to_change = False
            # stop value -1 so that slice 0 is included in the propagation
            for iz in range(mid_slice, -1, -1):
                if slice_to_change:
                    slices_to_remove[iz] = True
                elif slices_to_remove[iz]:
                    slice_to_change = True

            for iz in range(0, nz):
                # remove the slice
                if slices_to_remove[iz]:
                    im_seg.data[:, :, iz] *= 0

            # saving the image
            im_seg.save('tmp.segmentation_RPI_c.nii.gz')

            # replacing old segmentation with the corrected one. The arguments are passed as an
            # explicit list so that an output path containing spaces stays a single argument.
            sct_image.main([
                '-i', 'tmp.segmentation_RPI_c.nii.gz',
                '-setorient', str(image_input_orientation),
                '-o', fname_seg_absolute, '-v', '0'
            ])
    finally:
        os.chdir(curdir)

    # remove temporary files
    if remove_temp_files:
        rmtree(path_tmp)
Exemple #12
0
def main(argv=None):
    """
    Label the vertebral levels of a spinal cord segmentation.

    Steps, all run from inside a temporary working directory:
      1. copy the input image and its segmentation into the temporary folder;
      2. straighten the cord, or reuse the warping fields found in the
         directory the command was started from when the cache is valid;
      3. resample the straightened image to 0.5 mm isotropic and bring the
         segmentation into that space;
      4. label the vertebrae, either from a user-provided disc file or from
         an initial disc label followed by `vertebral_detection`;
      5. warp the labeled segmentation and labeled discs back to the
         original space;
      6. write the output files, optionally delete the temporary folder and
         optionally generate a QC report.

    :param argv: list of command-line arguments forwarded to the parser.
    """
    arguments = get_parser().parse_args(argv)
    verbose = arguments.v
    set_loglevel(verbose=verbose)

    # Paths are made absolute now, because the working directory changes later
    fname_in = os.path.abspath(arguments.i)
    fname_seg = os.path.abspath(arguments.s)
    contrast = arguments.c
    path_template = os.path.abspath(arguments.t)
    scale_dist = arguments.scale_dist
    path_output = os.path.abspath(arguments.ofolder)
    fname_disc = (None if arguments.discfile is None
                  else os.path.abspath(arguments.discfile))
    initz = arguments.initz
    initcenter = arguments.initcenter
    fname_initlabel = (None if arguments.initlabel is None
                       else os.path.abspath(arguments.initlabel))
    remove_temp_files = arguments.r
    clean_labels = arguments.clean_labels

    path_tmp = tmp_create(basename="label_vertebrae")

    # Bring the inputs into the temporary folder under fixed names
    printv('\nCopying input data to tmp folder...', verbose)
    for fname_user, name_in_tmp in ((fname_in, "data.nii"),
                                    (fname_seg, "segmentation.nii")):
        Image(fname_user).save(os.path.join(path_tmp, name_in_tmp))

    # Work from inside the temporary folder (restored further down)
    curdir = os.getcwd()
    os.chdir(path_tmp)

    # --- Straightening -----------------------------------------------------
    printv('\nStraighten spinal cord...', verbose)
    # Files produced by a previous straightening; when all of them are present
    # next to a valid cache file, the straightening is not recomputed.
    straightening_files = (
        "warp_curve2straight.nii.gz",
        "warp_straight2curve.nii.gz",
        "straight_ref.nii.gz",
    )
    cache_sig = cache_signature(input_files=[fname_in, fname_seg])
    fname_cache = "straightening.cache"
    reuse_straightening = (
        cache_valid(os.path.join(curdir, fname_cache), cache_sig)
        and all(os.path.isfile(os.path.join(curdir, fname))
                for fname in straightening_files)
    )
    if reuse_straightening:
        printv('Reusing existing warping field which seems to be valid',
               verbose, 'warning')
        # pull the existing files into the temporary folder
        for fname in straightening_files:
            copy(os.path.join(curdir, fname), fname)
        # only the warp has to be applied to the data
        _status, _output = run_proc([
            'sct_apply_transfo',
            '-i', 'data.nii',
            '-w', 'warp_curve2straight.nii.gz',
            '-d', 'straight_ref.nii.gz',
            '-o', 'data_straight.nii',
        ])
    else:
        sct_straighten_spinalcord.main(argv=[
            '-i', 'data.nii',
            '-s', 'segmentation.nii',
            '-r', str(remove_temp_files),
            '-v', '0',
        ])
        cache_save(os.path.join(path_output, fname_cache), cache_sig)

    # --- Resampling (template resolution is 0.5 mm isotropic) ---------------
    printv('\nResample to 0.5mm isotropic...', verbose)
    resample_cmd = [
        'sct_resample',
        '-i', 'data_straight.nii',
        '-mm', '0.5x0.5x0.5',
        '-x', 'linear',
        '-o', 'data_straightr.nii',
    ]
    _status, _output = run_proc(resample_cmd, verbose=verbose)

    # --- Straighten the segmentation (N.B. output is RPI) -------------------
    printv('\nApply straightening to segmentation...', verbose)
    sct_apply_transfo.main([
        '-i', 'segmentation.nii',
        '-d', 'data_straightr.nii',
        '-w', 'warp_curve2straight.nii.gz',
        '-o', 'segmentation_straight.nii',
        '-x', 'linear',
        '-v', '0',
    ])

    # Linear interpolation was used above: threshold the result at 0.5
    im_seg_straight = Image('segmentation_straight.nii')
    im_seg_straight.data = threshold(im_seg_straight.data, 0.5)
    im_seg_straight.save()

    if fname_disc:
        # Disc labels supplied by the user: label from them, no detection
        printv('\nApply straightening to disc labels...', verbose)
        run_proc(
            'sct_apply_transfo -i {} -d {} -w {} -o {} -x {}'.format(
                fname_disc, 'data_straightr.nii',
                'warp_curve2straight.nii.gz', 'labeldisc_straight.nii.gz',
                'label'),
            verbose=verbose)
        label_vert('segmentation_straight.nii',
                   'labeldisc_straight.nii.gz',
                   verbose=1)

    else:
        printv('\nCreate label to identify disc...', verbose)
        fname_labelz = os.path.join(path_tmp, 'labelz.nii.gz')

        if initcenter is not None:
            # use the slice at the middle of the field of view (RPI frame)
            im_seg_rpi = Image('segmentation.nii').change_orientation("RPI")
            _nx, _ny, nz, _nt, _px, _py, _pz, _pt = im_seg_rpi.dim
            initz = [round(nz / 2), initcenter]

        if initz is not None:
            create_labels_along_segmentation(
                Image('segmentation.nii'), [tuple(initz)]).save(fname_labelz)
        elif fname_initlabel is not None:
            Image(fname_initlabel).save(fname_labelz)
        else:
            # no initialisation given: detect the C2-C3 disc automatically.
            # detect_c2c3 also uses its verbose level to keep temp files.
            im_label_c2c3 = detect_c2c3(
                Image('data.nii'),
                Image('segmentation.nii'),
                contrast,
                verbose=0 if remove_temp_files else 2)
            ind_label = np.where(im_label_c2c3.data)
            if not np.size(ind_label):
                printv(
                    'Automatic C2-C3 detection failed. Please provide manual label with sct_label_utils',
                    1, 'error')
                sys.exit(1)
            im_label_c2c3.data[ind_label] = 3
            im_label_c2c3.save(fname_labelz)

        # a dilated label is less likely to vanish during the warping
        dilate(Image(fname_labelz), 3, 'ball').save(fname_labelz)

        # Straighten the initial label
        printv('\nAnd apply straightening to label...', verbose)
        sct_apply_transfo.main([
            '-i', 'labelz.nii.gz',
            '-d', 'data_straightr.nii',
            '-w', 'warp_curve2straight.nii.gz',
            '-o', 'labelz_straight.nii.gz',
            '-x', 'nn',
            '-v', '0',
        ])
        # z position and disc value used to initialise the labeling
        printv('\nGet z and disc values from straight label...', verbose)
        init_disc = get_z_and_disc_values_from_label('labelz_straight.nii.gz')
        printv('.. ' + str(init_disc), verbose)

        # Optional Laplacian filtering of the straightened image
        if arguments.laplacian:
            printv('\nApply Laplacian filter...', verbose)
            im_straight = Image("data_straightr.nii")

            # one unit sigma per array axis ...
            unit_sigmas = [1] * len(im_straight.data.shape)
            # ... scaled by the voxel size along each of the first three axes
            sigmas = [unit_sigmas[axis] / im_straight.dim[axis + 4]
                      for axis in range(3)]

            im_straight.data = laplacian(im_straight.data, sigmas)
            im_straight.save()

        # Vertebral level detection on the straightened cord
        init_disc[1] -= 1
        vertebral_detection('data_straightr.nii',
                            'segmentation_straight.nii',
                            contrast,
                            arguments.param,
                            init_disc=init_disc,
                            verbose=verbose,
                            path_template=path_template,
                            path_output=path_output,
                            scale_dist=scale_dist)

    # --- Back to the original (curved) space --------------------------------
    printv('\nUn-straighten labeling...', verbose)
    sct_apply_transfo.main([
        '-i', 'segmentation_straight_labeled.nii',
        '-d', 'segmentation.nii',
        '-w', 'warp_straight2curve.nii.gz',
        '-o', 'segmentation_labeled.nii',
        '-x', 'nn',
        '-v', '0',
    ])

    if clean_labels >= 1:
        printv('\nCleaning labeled segmentation:', verbose)
        im_labeled_seg = Image('segmentation_labeled.nii')
        im_seg = Image('segmentation.nii')
        if clean_labels >= 2:
            printv('  filling in missing label voxels ...', verbose)
            expand_labels(im_labeled_seg)
        printv('  removing labeled voxels outside segmentation...', verbose)
        crop_labels(im_labeled_seg, im_seg)
        printv('Done cleaning.', verbose)
        im_labeled_seg.save()

    # Labeled discs
    printv('\nLabel discs...', verbose)
    printv('\nUn-straighten labeled discs...', verbose)
    run_proc(
        'sct_apply_transfo -i {} -d {} -w {} -o {} -x {}'.format(
            'segmentation_straight_labeled_disc.nii', 'segmentation.nii',
            'warp_straight2curve.nii.gz', 'segmentation_labeled_disc.nii',
            'label'),
        verbose=verbose,
        is_sct_binary=True,
    )

    # Leave the temporary folder
    os.chdir(curdir)

    # --- Outputs ------------------------------------------------------------
    path_seg, file_seg, ext_seg = extract_fname(fname_seg)
    fname_seg_labeled = os.path.join(path_output,
                                     file_seg + '_labeled' + ext_seg)
    printv('\nGenerate output files...', verbose)
    generate_output_file(os.path.join(path_tmp, "segmentation_labeled.nii"),
                         fname_seg_labeled)
    generate_output_file(
        os.path.join(path_tmp, "segmentation_labeled_disc.nii"),
        os.path.join(path_output, file_seg + '_labeled_discs' + ext_seg))
    # the straightening files are exported too, in case later SCT calls need them
    for fname in straightening_files:
        generate_output_file(os.path.join(path_tmp, fname),
                             os.path.join(path_output, fname),
                             verbose=verbose)

    # Temporary folder clean-up
    if remove_temp_files == 1:
        printv('\nRemove temporary files...', verbose)
        rmtree(path_tmp)

    # QC report (only when a QC folder was requested)
    if arguments.qc is not None:
        generate_qc(fname_in,
                    fname_seg=fname_seg_labeled,
                    args=argv,
                    path_qc=os.path.abspath(arguments.qc),
                    dataset=arguments.qc_dataset,
                    subject=arguments.qc_subject,
                    process='sct_label_vertebrae')

    display_viewer_syntax([fname_in, fname_seg_labeled],
                          colormaps=['', 'subcortical'],
                          opacities=['1', '0.5'])
def main(argv=None):
    """
    Label the vertebral levels of a spinal cord segmentation.

    The input image and segmentation are copied to a temporary folder, the
    cord is straightened (or previously computed warping fields are reused
    when the cache in the launch directory is valid), the straightened image
    is resampled to 0.5 mm isotropic, vertebrae are labeled (from a disc file
    if one is given, otherwise from an initial disc label followed by
    `vertebral_detection`), and the labels are warped back to the original
    space. Output files are then generated in the output folder and a QC
    report is produced when a QC path was given.

    :param argv: list of command-line arguments. When it is empty or None the
        parser is called with ``--help``.
    :raises ValueError: if the init z value (from ``arguments.initz`` or from
        a ``-initz`` entry of the init file) does not hold exactly two values.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv if argv else ['--help'])
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    # initializations (empty strings mean "not provided")
    initz = ''
    initcenter = ''
    fname_initlabel = ''
    file_labelz = 'labelz.nii.gz'
    param = Param()

    fname_in = os.path.abspath(arguments.i)
    fname_seg = os.path.abspath(arguments.s)
    contrast = arguments.c
    path_template = os.path.abspath(arguments.t)
    scale_dist = arguments.scale_dist
    # Resolved now: the working directory is changed to the temporary folder
    # before this path is handed to vertebral_detection, so a relative path
    # would otherwise be interpreted from inside the temporary folder.
    path_output = os.path.abspath(arguments.ofolder)
    param.path_qc = arguments.qc
    if arguments.discfile is not None:
        fname_disc = os.path.abspath(arguments.discfile)
    else:
        fname_disc = None
    if arguments.initz is not None:
        initz = arguments.initz
        if len(initz) != 2:
            raise ValueError(
                '--initz takes two arguments: position in superior-inferior direction, label value'
            )
    if arguments.initcenter is not None:
        initcenter = arguments.initcenter
    # if user provided text file, parse and overwrite arguments
    if arguments.initfile is not None:
        # the context manager guarantees the file handle is released
        with open(arguments.initfile, 'r') as file_init:
            initfile = ' ' + file_init.read().replace('\n', '')
        arg_initfile = initfile.split(' ')
        for idx_arg, arg in enumerate(arg_initfile):
            if arg == '-initz':
                initz = [int(x) for x in arg_initfile[idx_arg + 1].split(',')]
                if len(initz) != 2:
                    raise ValueError(
                        '--initz takes two arguments: position in superior-inferior direction, label value'
                    )
            if arg == '-initcenter':
                initcenter = int(arg_initfile[idx_arg + 1])
    if arguments.initlabel is not None:
        # get absolute path of label
        fname_initlabel = os.path.abspath(arguments.initlabel)
    if arguments.param is not None:
        param.update(arguments.param[0])
    remove_temp_files = arguments.r
    clean_labels = arguments.clean_labels
    laplacian = arguments.laplacian

    path_tmp = tmp_create(basename="label_vertebrae")

    # Copying input data to tmp folder
    printv('\nCopying input data to tmp folder...', verbose)
    Image(fname_in).save(os.path.join(path_tmp, "data.nii"))
    Image(fname_seg).save(os.path.join(path_tmp, "segmentation.nii"))

    # Go to temp folder
    curdir = os.getcwd()
    os.chdir(path_tmp)

    # Straighten spinal cord
    printv('\nStraighten spinal cord...', verbose)
    # check if warp_curve2straight and warp_straight2curve already exist (i.e. no need to do it another time)
    straightening_files = (
        "warp_curve2straight.nii.gz",
        "warp_straight2curve.nii.gz",
        "straight_ref.nii.gz",
    )
    cache_sig = cache_signature(input_files=[fname_in, fname_seg], )
    cachefile = os.path.join(curdir, "straightening.cache")
    if cache_valid(cachefile, cache_sig) and all(
            os.path.isfile(os.path.join(curdir, fname))
            for fname in straightening_files):
        # if they exist, copy them into current folder
        printv('Reusing existing warping field which seems to be valid',
               verbose, 'warning')
        for fname in straightening_files:
            copy(os.path.join(curdir, fname), fname)
        # apply straightening
        run_proc([
            'sct_apply_transfo', '-i', 'data.nii', '-w',
            'warp_curve2straight.nii.gz', '-d', 'straight_ref.nii.gz', '-o',
            'data_straight.nii'
        ])
    else:
        sct_straighten_spinalcord.main(argv=[
            '-i',
            'data.nii',
            '-s',
            'segmentation.nii',
            '-r',
            str(remove_temp_files),
            '-v',
            str(verbose),
        ])
        cache_save(cachefile, cache_sig)

    # resample to 0.5mm isotropic to match template resolution
    printv('\nResample to 0.5mm isotropic...', verbose)
    run_proc([
        'sct_resample', '-i', 'data_straight.nii', '-mm', '0.5x0.5x0.5', '-x',
        'linear', '-o', 'data_straightr.nii'
    ],
             verbose=verbose)

    # Apply straightening to segmentation
    # N.B. Output is RPI
    printv('\nApply straightening to segmentation...', verbose)
    run_proc(
        'isct_antsApplyTransforms -d 3 -i %s -r %s -t %s -o %s -n %s' %
        ('segmentation.nii', 'data_straightr.nii',
         'warp_curve2straight.nii.gz', 'segmentation_straight.nii', 'Linear'),
        verbose=verbose,
        is_sct_binary=True,
    )
    # Threshold segmentation at 0.5
    run_proc([
        'sct_maths', '-i', 'segmentation_straight.nii', '-thr', '0.5', '-o',
        'segmentation_straight.nii'
    ], verbose)

    # If disc label file is provided, label vertebrae using that file instead of automatically
    if fname_disc:
        # Apply straightening to disc-label
        printv('\nApply straightening to disc labels...', verbose)
        run_proc(
            'sct_apply_transfo -i %s -d %s -w %s -o %s -x %s' %
            (fname_disc, 'data_straightr.nii', 'warp_curve2straight.nii.gz',
             'labeldisc_straight.nii.gz', 'label'),
            verbose=verbose)
        label_vert('segmentation_straight.nii',
                   'labeldisc_straight.nii.gz',
                   verbose=1)

    else:
        # create label to identify disc
        printv('\nCreate label to identify disc...', verbose)
        fname_labelz = os.path.join(path_tmp, file_labelz)
        if initz or initcenter:
            if initcenter:
                # find z centered in FOV
                nii = Image('segmentation.nii').change_orientation("RPI")
                nz = nii.dim[2]  # third entry of dim is the number of slices
                z_center = int(np.round(nz / 2))  # get z_center
                initz = [z_center, initcenter]

            im_label = create_labels_along_segmentation(
                Image('segmentation.nii'), [(initz[0], initz[1])])
            # NOTE(review): the label is dilated here and again after this
            # if/elif/else chain, so this path dilates twice. Kept as is to
            # preserve the current output; confirm whether it is intended.
            im_label.data = dilate(im_label.data, 3, 'ball')
            im_label.save(fname_labelz)

        elif fname_initlabel:
            Image(fname_initlabel).save(fname_labelz)

        else:
            # automatically finds C2-C3 disc
            im_data = Image('data.nii')
            im_seg = Image('segmentation.nii')
            # because verbose is here also used for keeping temp files
            verbose_detect_c2c3 = 0 if remove_temp_files else 2
            im_label_c2c3 = detect_c2c3(im_data,
                                        im_seg,
                                        contrast,
                                        verbose=verbose_detect_c2c3)
            ind_label = np.where(im_label_c2c3.data)
            if np.size(ind_label) == 0:
                printv(
                    'Automatic C2-C3 detection failed. Please provide manual label with sct_label_utils',
                    1, 'error')
                # a failed detection must not report success to the caller
                sys.exit(1)
            im_label_c2c3.data[ind_label] = 3
            im_label_c2c3.save(fname_labelz)

        # dilate label so it is not lost when applying warping
        dilate(Image(fname_labelz), 3, 'ball').save(fname_labelz)

        # Apply straightening to z-label
        printv('\nAnd apply straightening to label...', verbose)
        run_proc(
            'isct_antsApplyTransforms -d 3 -i %s -r %s -t %s -o %s -n %s' %
            (file_labelz, 'data_straightr.nii', 'warp_curve2straight.nii.gz',
             'labelz_straight.nii.gz', 'NearestNeighbor'),
            verbose=verbose,
            is_sct_binary=True,
        )
        # get z value and disk value to initialize labeling
        printv('\nGet z and disc values from straight label...', verbose)
        init_disc = get_z_and_disc_values_from_label('labelz_straight.nii.gz')
        printv('.. ' + str(init_disc), verbose)

        # apply laplacian filtering
        if laplacian:
            printv('\nApply Laplacian filter...', verbose)
            run_proc([
                'sct_maths', '-i', 'data_straightr.nii', '-laplacian', '1',
                '-o', 'data_straightr.nii'
            ], verbose)

        # detect vertebral levels on straight spinal cord
        init_disc[1] = init_disc[1] - 1
        vertebral_detection('data_straightr.nii',
                            'segmentation_straight.nii',
                            contrast,
                            param,
                            init_disc=init_disc,
                            verbose=verbose,
                            path_template=path_template,
                            path_output=path_output,
                            scale_dist=scale_dist)

    # un-straighten labeled spinal cord
    printv('\nUn-straighten labeling...', verbose)
    run_proc(
        'isct_antsApplyTransforms -d 3 -i %s -r %s -t %s -o %s -n %s' %
        ('segmentation_straight_labeled.nii', 'segmentation.nii',
         'warp_straight2curve.nii.gz', 'segmentation_labeled.nii',
         'NearestNeighbor'),
        verbose=verbose,
        is_sct_binary=True,
    )

    if clean_labels:
        # Clean labeled segmentation
        printv(
            '\nClean labeled segmentation (correct interpolation errors)...',
            verbose)
        clean_labeled_segmentation('segmentation_labeled.nii',
                                   'segmentation.nii',
                                   'segmentation_labeled.nii')

    # label discs
    printv('\nLabel discs...', verbose)
    printv('\nUn-straighten labeled discs...', verbose)
    run_proc(
        'sct_apply_transfo -i %s -d %s -w %s -o %s -x %s' %
        ('segmentation_straight_labeled_disc.nii', 'segmentation.nii',
         'warp_straight2curve.nii.gz', 'segmentation_labeled_disc.nii',
         'label'),
        verbose=verbose,
        is_sct_binary=True,
    )

    # come back
    os.chdir(curdir)

    # Generate output files
    path_seg, file_seg, ext_seg = extract_fname(fname_seg)
    fname_seg_labeled = os.path.join(path_output,
                                     file_seg + '_labeled' + ext_seg)
    printv('\nGenerate output files...', verbose)
    generate_output_file(os.path.join(path_tmp, "segmentation_labeled.nii"),
                         fname_seg_labeled)
    generate_output_file(
        os.path.join(path_tmp, "segmentation_labeled_disc.nii"),
        os.path.join(path_output, file_seg + '_labeled_discs' + ext_seg))
    # copy straightening files in case subsequent SCT functions need them
    for fname in straightening_files:
        generate_output_file(os.path.join(path_tmp, fname),
                             os.path.join(path_output, fname),
                             verbose=verbose)

    # Remove temporary files
    if remove_temp_files == 1:
        printv('\nRemove temporary files...', verbose)
        rmtree(path_tmp)

    # Generate QC report
    if param.path_qc is not None:
        path_qc = os.path.abspath(arguments.qc)
        qc_dataset = arguments.qc_dataset
        qc_subject = arguments.qc_subject
        generate_qc(fname_in,
                    fname_seg=fname_seg_labeled,
                    args=argv,
                    path_qc=path_qc,
                    dataset=qc_dataset,
                    subject=qc_subject,
                    process='sct_label_vertebrae')

    display_viewer_syntax([fname_in, fname_seg_labeled],
                          colormaps=['', 'subcortical'],
                          opacities=['1', '0.5'])
def main():
    """
    Self-test of the ANTs registration binary on synthetic label volumes.

    Two 61x61x61 int16 volumes, each holding three small cubic labels
    (values 1, 2, 3), are saved in a temporary folder. The source volume is
    registered to the destination volume with `isct_antsRegistration`, and the
    Dice coefficient between the destination and the registered source is
    compared with an acceptance threshold.

    Command-line options (read from ``sys.argv``):
      -h        print the usage and return
      -v <int>  verbosity level
      -r <int>  1 to delete the temporary folder at the end

    :raises SystemExit: with code 2 when the options cannot be parsed, with
        code 1 when the Dice coefficient is not above the threshold.
    """
    # Initialization
    size_data = 61
    size_label = 1  # put zero for labels that are single points.
    dice_acceptable = 0.39  # computed DICE should be 0.931034
    remove_temp_files = 1
    verbose = 1

    # Check input parameters.
    # Both -v and -r carry an integer value, so each needs a trailing colon
    # in the option specification.
    try:
        opts, _ = getopt.getopt(sys.argv[1:], 'hv:r:')
    except getopt.GetoptError:
        usage()
        raise SystemExit(2)
    for opt, arg in opts:
        if opt == '-h':
            usage()
            return
        elif opt == '-v':
            verbose = int(arg)
        elif opt == '-r':
            remove_temp_files = int(arg)

    path_tmp = tmp_create(basename="test_ants")

    # go to tmp folder
    curdir = os.getcwd()
    os.chdir(path_tmp)

    def add_label(volume, center, value):
        """Write `value` in the cube of half-width `size_label` around `center`."""
        x, y, z = center
        volume[x - size_label:x + size_label + 1,
               y - size_label:y + size_label + 1,
               z - size_label:z + size_label + 1] = value

    # Initialise numpy volumes
    data_src = np.zeros((size_data, size_data, size_data), dtype=np.int16)
    data_dest = np.zeros((size_data, size_data, size_data), dtype=np.int16)

    # add labels for src image (curved).
    # Labels can be big (more than single point), because when applying NN interpolation, single points might disappear
    for value, center in enumerate([(20, 20, 10), (30, 30, 30), (20, 20, 50)],
                                   start=1):
        add_label(data_src, center, value)

    # add labels for dest image (straight): same z positions, aligned in x/y.
    for value, center in enumerate([(30, 30, 10), (30, 30, 30), (30, 30, 50)],
                                   start=1):
        add_label(data_dest, center, value)

    # save as nifti
    img_src = nib.Nifti1Image(data_src, np.eye(4))
    nib.save(img_src, 'data_src.nii.gz')
    img_dest = nib.Nifti1Image(data_dest, np.eye(4))
    nib.save(img_dest, 'data_dest.nii.gz')

    # Estimate the transformation.
    # N.B. the message below says "rigid", but the command requests a
    # 'syn[1,3,1]' transform.
    printv('\nEstimate rigid transformation between paired landmarks...',
           verbose)
    # TODO fixup isct_ants* parsers
    run_proc([
        'isct_antsRegistration', '-d', '3', '-t', 'syn[1,3,1]', '-m',
        'MeanSquares[data_dest.nii.gz,data_src.nii.gz,1,3]', '-f', '2', '-s',
        '0', '-o', '[src2reg,data_src_reg.nii.gz]', '-c', '5', '-v', '1', '-n',
        'NearestNeighbor'
    ],
             verbose,
             is_sct_binary=True)

    # Compute DICE coefficient between src and dest
    printv('\nCompute DICE coefficient...', verbose)
    run_proc([
        "sct_dice_coefficient", "-i", "data_dest.nii.gz", "-d",
        "data_src_reg.nii.gz", "-o", "dice.txt"
    ], verbose)
    with open("dice.txt", "r") as file_dice:
        dice = float(file_dice.read().replace('3D Dice coefficient = ', ''))
    printv(
        'Dice coeff = ' + str(dice) + ' (should be above ' +
        str(dice_acceptable) + ')', verbose)

    # Check if DICE coefficient is above acceptable value
    test_passed = dice > dice_acceptable

    # come back
    os.chdir(curdir)

    # Delete temporary files
    if remove_temp_files == 1:
        printv('\nDelete temporary files...', verbose)
        rmtree(path_tmp)

    # output result for parent function
    if test_passed:
        printv('\nTest passed!\n', verbose)
    else:
        printv('\nTest failed!\n', verbose)
        raise SystemExit(1)
    def apply(self):
        # Initialization
        fname_src = self.input_filename  # source image (moving)
        list_warp = self.list_warp  # list of warping fields
        fname_out = self.output_filename  # output
        fname_dest = self.fname_dest  # destination image (fix)
        verbose = self.verbose
        remove_temp_files = self.remove_temp_files
        crop_reference = self.crop  # if = 1, put 0 everywhere around warping field, if = 2, real crop

        islabel = False
        if self.interp == 'label':
            islabel = True
            self.interp = 'nn'

        interp = get_interpolation('isct_antsApplyTransforms', self.interp)

        # Parse list of warping fields
        printv('\nParse list of warping fields...', verbose)
        use_inverse = []
        fname_warp_list_invert = []
        # list_warp = list_warp.replace(' ', '')  # remove spaces
        # list_warp = list_warp.split(",")  # parse with comma
        for idx_warp, path_warp in enumerate(self.list_warp):
            # Check if this transformation should be inverted
            if path_warp in self.list_warpinv:
                use_inverse.append('-i')
                # list_warp[idx_warp] = path_warp[1:]  # remove '-'
                fname_warp_list_invert += [[
                    use_inverse[idx_warp], list_warp[idx_warp]
                ]]
            else:
                use_inverse.append('')
                fname_warp_list_invert += [[path_warp]]
            path_warp = list_warp[idx_warp]
            if path_warp.endswith((".nii", ".nii.gz")) \
                    and Image(list_warp[idx_warp]).header.get_intent()[0] != 'vector':
                raise ValueError(
                    "Displacement field in {} is invalid: should be encoded"
                    " in a 5D file with vector intent code"
                    " (see https://nifti.nimh.nih.gov/pub/dist/src/niftilib/nifti1.h"
                    .format(path_warp))
        # need to check if last warping field is an affine transfo
        isLastAffine = False
        path_fname, file_fname, ext_fname = extract_fname(
            fname_warp_list_invert[-1][-1])
        if ext_fname in ['.txt', '.mat']:
            isLastAffine = True

        # check if destination file is 3d
        # check_dim(fname_dest, dim_lst=[3]) # PR 2598: we decided to skip this line.

        # N.B. Here we take the inverse of the warp list, because sct_WarpImageMultiTransform concatenates in the reverse order
        fname_warp_list_invert.reverse()
        fname_warp_list_invert = functools.reduce(lambda x, y: x + y,
                                                  fname_warp_list_invert)

        # Extract path, file and extension
        path_src, file_src, ext_src = extract_fname(fname_src)
        path_dest, file_dest, ext_dest = extract_fname(fname_dest)

        # Get output folder and file name
        if fname_out == '':
            path_out = ''  # output in user's current directory
            file_out = file_src + '_reg'
            ext_out = ext_src
            fname_out = os.path.join(path_out, file_out + ext_out)

        # Get dimensions of data
        printv('\nGet dimensions of data...', verbose)
        img_src = Image(fname_src)
        nx, ny, nz, nt, px, py, pz, pt = img_src.dim
        # nx, ny, nz, nt, px, py, pz, pt = get_dimension(fname_src)
        printv(
            '  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' +
            str(nt), verbose)

        # if 3d
        if nt == 1:
            # Apply transformation
            printv('\nApply transformation...', verbose)
            if nz in [0, 1]:
                dim = '2'
            else:
                dim = '3'
            # if labels, dilate before resampling
            if islabel:
                printv("\nDilate labels before warping...")
                path_tmp = tmp_create(basename="apply_transfo")
                fname_dilated_labels = os.path.join(path_tmp,
                                                    "dilated_data.nii")
                # dilate points
                dilate(Image(fname_src), 4, 'ball').save(fname_dilated_labels)
                fname_src = fname_dilated_labels

            printv(
                "\nApply transformation and resample to destination space...",
                verbose)
            run_proc([
                'isct_antsApplyTransforms', '-d', dim, '-i', fname_src, '-o',
                fname_out, '-t'
            ] + fname_warp_list_invert + ['-r', fname_dest] + interp,
                     is_sct_binary=True)

        # if 4d, loop across the T dimension
        else:
            if islabel:
                raise NotImplementedError

            dim = '4'
            path_tmp = tmp_create(basename="apply_transfo")

            # convert to nifti into temp folder
            printv('\nCopying input data to tmp folder and convert to nii...',
                   verbose)
            img_src.save(os.path.join(path_tmp, "data.nii"))
            copy(fname_dest, os.path.join(path_tmp, file_dest + ext_dest))
            fname_warp_list_tmp = []
            for fname_warp in list_warp:
                path_warp, file_warp, ext_warp = extract_fname(fname_warp)
                copy(fname_warp, os.path.join(path_tmp, file_warp + ext_warp))
                fname_warp_list_tmp.append(file_warp + ext_warp)
            fname_warp_list_invert_tmp = fname_warp_list_tmp[::-1]

            curdir = os.getcwd()
            os.chdir(path_tmp)

            # split along T dimension
            printv('\nSplit along T dimension...', verbose)

            im_dat = Image('data.nii')
            im_header = im_dat.hdr
            data_split_list = sct_image.split_data(im_dat, 3)
            for im in data_split_list:
                im.save()

            # apply transfo
            printv('\nApply transformation to each 3D volume...', verbose)
            for it in range(nt):
                file_data_split = 'data_T' + str(it).zfill(4) + '.nii'
                file_data_split_reg = 'data_reg_T' + str(it).zfill(4) + '.nii'

                status, output = run_proc([
                    'isct_antsApplyTransforms',
                    '-d',
                    '3',
                    '-i',
                    file_data_split,
                    '-o',
                    file_data_split_reg,
                    '-t',
                ] + fname_warp_list_invert_tmp + [
                    '-r',
                    file_dest + ext_dest,
                ] + interp,
                                          verbose,
                                          is_sct_binary=True)

            # Merge files back
            printv('\nMerge file back...', verbose)
            import glob
            path_out, name_out, ext_out = extract_fname(fname_out)
            # im_list = [Image(file_name) for file_name in glob.glob('data_reg_T*.nii')]
            # concat_data use to take a list of image in input, now takes a list of file names to open the files one by one (see issue #715)
            fname_list = glob.glob('data_reg_T*.nii')
            fname_list.sort()
            im_list = [Image(fname) for fname in fname_list]
            im_out = sct_image.concat_data(im_list, 3, im_header['pixdim'])
            im_out.save(name_out + ext_out)

            os.chdir(curdir)
            generate_output_file(os.path.join(path_tmp, name_out + ext_out),
                                 fname_out)
            # Delete temporary folder if specified
            if remove_temp_files:
                printv('\nRemove temporary files...', verbose)
                rmtree(path_tmp, verbose=verbose)

        # Copy affine matrix from destination space to make sure qform/sform are the same
        printv(
            "Copy affine matrix from destination space to make sure qform/sform are the same.",
            verbose)
        im_src_reg = Image(fname_out)
        im_src_reg.copy_qform_from_ref(Image(fname_dest))
        im_src_reg.save(
            verbose=0
        )  # set verbose=0 to avoid warning message about rewriting file

        if islabel:
            printv(
                "\nTake the center of mass of each registered dilated labels..."
            )
            labeled_img = cubic_to_point(im_src_reg)
            labeled_img.save(path=fname_out)
            if remove_temp_files:
                printv('\nRemove temporary files...', verbose)
                rmtree(path_tmp, verbose=verbose)

        # Crop the resulting image using dimensions from the warping field
        warping_field = fname_warp_list_invert[-1]
        # If the last transformation is not an affine transfo, we need to compute the matrix space of the concatenated
        # warping field
        if not isLastAffine and crop_reference in [1, 2]:
            printv('Last transformation is not affine.')
            if crop_reference in [1, 2]:
                # Extract only the first ndim of the warping field
                img_warp = Image(warping_field)
                if dim == '2':
                    img_warp_ndim = Image(img_src.data[:, :], hdr=img_warp.hdr)
                elif dim in ['3', '4']:
                    img_warp_ndim = Image(img_src.data[:, :, :],
                                          hdr=img_warp.hdr)
                # Set zero to everything outside the warping field
                cropper = ImageCropper(Image(fname_out))
                cropper.get_bbox_from_ref(img_warp_ndim)
                if crop_reference == 1:
                    printv(
                        'Cropping strategy is: keep same matrix size, put 0 everywhere around warping field'
                    )
                    img_out = cropper.crop(background=0)
                elif crop_reference == 2:
                    printv(
                        'Cropping strategy is: crop around warping field (the size of warping field will '
                        'change)')
                    img_out = cropper.crop()
                img_out.save(fname_out)

        display_viewer_syntax([fname_dest, fname_out], verbose=verbose)
Exemple #16
0
def main(argv=None):
    """
    Split a 4D dMRI file into a b=0 series and a DWI series, optionally averaging each along T.

    The input is converted to ``dmri.nii`` inside a temporary folder and split along the 4th dimension. The
    volumes listed by ``identify_b0`` are then concatenated into ``<file>_b0.nii`` and ``<file>_dwi.nii``
    (plus ``*_mean.nii`` when averaging is requested), and ``generate_output_file`` is used to produce the final
    files in the output folder, with the extension of the input file.

    :param argv: command-line arguments handed to the parser returned by ``get_parser()``.
    :return: tuple ``(fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean)`` of absolute output paths. The two
        ``*_mean`` paths are returned even when averaging is disabled; in that case the files are not generated.
    :raises ValueError: if a merged series has fewer than 4 dimensions when averaging is requested.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    # initialize parameters
    param = Param()

    fname_data = arguments.i
    fname_bvecs = arguments.bvec
    fname_bvals = arguments.bval
    average = arguments.a
    remove_temp_files = arguments.r
    path_out = arguments.ofolder
    if arguments.bvalmin:
        param.bval_min = arguments.bvalmin

    # start timer
    start_time = time.time()

    printv('\nInput parameters:', verbose)
    printv('  input file ............' + fname_data, verbose)
    printv('  bvecs file ............' + fname_bvecs, verbose)
    # bvals is treated as optional below (`if fname_bvals`), so do not assume it is a str here
    printv('  bvals file ............' + str(fname_bvals), verbose)
    printv('  average ...............' + str(average), verbose)

    # Get full path (needed because the working directory changes below)
    fname_data = os.path.abspath(fname_data)
    fname_bvecs = os.path.abspath(fname_bvecs)
    if fname_bvals:
        fname_bvals = os.path.abspath(fname_bvals)

    # Extract path, file and extension
    path_data, file_data, ext_data = extract_fname(fname_data)

    # create temporary folder
    path_tmp = tmp_create(basename="dmri_separate")

    # copy files into tmp folder and convert to nifti
    printv('\nCopy files into temporary folder...', verbose)
    ext = '.nii'
    dmri_name = 'dmri'
    b0_name = file_data + '_b0'
    b0_mean_name = b0_name + '_mean'
    dwi_name = file_data + '_dwi'
    dwi_mean_name = dwi_name + '_mean'

    if not convert(fname_data, os.path.join(path_tmp, dmri_name + ext)):
        printv('ERROR in convert.', 1, 'error')
    copy(fname_bvecs, os.path.join(path_tmp, "bvecs"), verbose=verbose)

    def _merge_volumes(indices, count, name_out):
        # Concatenate along T the split volumes whose T index is in indices[0:count].
        # The file names assume split volumes are saved as "<dmri_name>_T####.nii".
        fnames = [
            dmri_name + '_T' + str(indices[it]).zfill(4) + ext
            for it in range(count)
        ]
        concat_data([Image(fname) for fname in fnames], 3).save(name_out + ext)

    def _average_along_t(name_in, name_out):
        # Average a 4D series along its 4th dimension and save the result.
        img = Image(name_in + ext)
        out = img.copy()
        dim_idx = 3
        if len(np.shape(img.data)) < dim_idx + 1:
            raise ValueError("Expecting image with 4 dimensions!")
        out.data = np.mean(out.data, dim_idx)
        out.save(path=name_out + ext)

    # go to tmp folder; always come back, even if a step below raises, so that the
    # caller is not left inside the temporary directory
    curdir = os.getcwd()
    os.chdir(path_tmp)
    try:
        # Get size of data
        im_dmri = Image(dmri_name + ext)
        printv('\nGet dimensions data...', verbose)
        nx, ny, nz, nt, px, py, pz, pt = im_dmri.dim
        printv(
            '.. ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' +
            str(nt), verbose)

        # Identify b=0 and DWI images
        printv(fname_bvals)
        index_b0, index_dwi, nb_b0, nb_dwi = identify_b0(
            fname_bvecs, fname_bvals, param.bval_min, verbose)

        # Split into T dimension
        printv('\nSplit along T dimension...', verbose)
        for im_d in split_data(im_dmri, 3):
            im_d.save()

        # Merge (and optionally average) b=0 images
        printv('\nMerge b=0...', verbose)
        _merge_volumes(index_b0, nb_b0, b0_name)
        if average:
            printv('\nAverage b=0...', verbose)
            _average_along_t(b0_name, b0_mean_name)

        # Merge (and optionally average) DWI
        _merge_volumes(index_dwi, nb_dwi, dwi_name)
        if average:
            printv('\nAverage DWI...', verbose)
            _average_along_t(dwi_name, dwi_mean_name)
    finally:
        # come back
        os.chdir(curdir)

    # Generate output files
    fname_b0 = os.path.abspath(os.path.join(path_out, b0_name + ext_data))
    fname_dwi = os.path.abspath(os.path.join(path_out, dwi_name + ext_data))
    fname_b0_mean = os.path.abspath(
        os.path.join(path_out, b0_mean_name + ext_data))
    fname_dwi_mean = os.path.abspath(
        os.path.join(path_out, dwi_mean_name + ext_data))
    printv('\nGenerate output files...', verbose)
    generate_output_file(os.path.join(path_tmp, b0_name + ext),
                         fname_b0,
                         verbose=verbose)
    generate_output_file(os.path.join(path_tmp, dwi_name + ext),
                         fname_dwi,
                         verbose=verbose)
    if average:
        generate_output_file(os.path.join(path_tmp, b0_mean_name + ext),
                             fname_b0_mean,
                             verbose=verbose)
        generate_output_file(os.path.join(path_tmp, dwi_mean_name + ext),
                             fname_dwi_mean,
                             verbose=verbose)

    # Remove temporary files
    if remove_temp_files == 1:
        printv('\nRemove temporary files...', verbose)
        rmtree(path_tmp, verbose=verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    printv(
        '\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's',
        verbose)

    return fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean
def main():
    """
    Smooth a segmentation by upsampling it, smoothing along the centerline, and downsampling it back.

    Options are read from ``sys.argv`` with getopt: ``-i`` (input file, mandatory), ``-r`` (remove temporary
    files, int), ``-s`` (smoothing sigma), ``-v`` (verbosity, int) and ``-h`` (print usage and return).
    Defaults come from the module-level ``param`` object; when ``param.debug`` is set, the command line is
    ignored and a file from the testing data is used instead.

    Processing happens in a temporary folder through ``sct_resample`` and ``sct_smooth_spinalcord``. The result
    is generated in the current directory as ``<file><suffix><ext>``.

    :return: None
    :raises SystemExit: with code 2 when the options are invalid or the input file is missing.
    """
    # Initialization
    fname_data = ''
    interp_factor = param.interp_factor
    remove_temp_files = param.remove_temp_files
    verbose = param.verbose
    suffix = param.suffix
    smoothing_sigma = param.smoothing_sigma

    # start timer
    start_time = time.time()

    # Parameters for debug mode
    if param.debug:
        fname_data = os.path.join(__data_dir__, 'sct_testing_data', 't2',
                                  't2_seg.nii.gz')
        remove_temp_files = 0
        param.mask_size = 10
    else:
        # Check input parameters
        try:
            opts, _ = getopt.getopt(sys.argv[1:], 'hi:v:r:s:')
        except getopt.GetoptError:
            usage()
            raise SystemExit(2)
        if not opts:
            usage()
            raise SystemExit(2)
        for opt, arg in opts:
            # N.B. compare with `==`: `opt in ('-i')` is a substring test against the
            # string '-i', not a tuple membership test.
            if opt == '-h':
                usage()
                return
            elif opt == '-i':
                fname_data = arg
            elif opt == '-r':
                remove_temp_files = int(arg)
            elif opt == '-s':
                smoothing_sigma = arg
            elif opt == '-v':
                verbose = int(arg)

    # display usage if a mandatory argument is not provided
    if fname_data == '':
        usage()
        raise SystemExit(2)

    printv('\nCheck parameters:')
    printv('  segmentation ........... ' + fname_data)
    printv('  interp factor .......... ' + str(interp_factor))
    printv('  smoothing sigma ........ ' + str(smoothing_sigma))

    # check existence of input files
    printv('\nCheck existence of input files...')
    check_file_exist(fname_data, verbose)

    # Extract path, file and extension
    path_data, file_data, ext_data = extract_fname(fname_data)

    path_tmp = tmp_create(basename="binary_to_trilinear")

    from sct_convert import convert
    # use the verbosity selected with -v, like every other step below
    printv('\nCopying input data to tmp folder and convert to nii...',
           verbose)
    convert(fname_data, os.path.join(path_tmp, "data.nii"))

    # go to tmp folder; always come back, even if one of the steps below raises
    curdir = os.getcwd()
    os.chdir(path_tmp)
    try:
        # Get dimensions of data
        printv('\nGet dimensions of data...', verbose)
        nx, ny, nz, nt, px, py, pz, pt = Image('data.nii').dim
        printv('.. ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), verbose)

        # upsample data: request a matrix interp_factor times larger along each axis
        printv('\nUpsample data...', verbose)
        size_up = 'x'.join(
            str(n * interp_factor) for n in (nx, ny, nz))
        run_proc([
            "sct_resample", "-i", "data.nii", "-x", "linear", "-vox",
            size_up, "-o", "data_up.nii"
        ], verbose)

        # Smooth along centerline (the upsampled data is passed both as image and as segmentation)
        printv('\nSmooth along centerline...', verbose)
        run_proc([
            "sct_smooth_spinalcord", "-i", "data_up.nii", "-s", "data_up.nii",
            "-smooth",
            str(smoothing_sigma), "-r",
            str(remove_temp_files), "-v",
            str(verbose)
        ], verbose)

        # downsample data back to the original matrix size
        printv('\nDownsample data...', verbose)
        size_native = 'x'.join(str(n) for n in (nx, ny, nz))
        run_proc([
            "sct_resample", "-i", "data_up_smooth.nii", "-x", "linear",
            "-vox", size_native, "-o", "data_up_smooth_down.nii"
        ], verbose)
    finally:
        # come back
        os.chdir(curdir)

    # Generate output files
    printv('\nGenerate output files...')
    generate_output_file(
        os.path.join(path_tmp, "data_up_smooth_down.nii"),
        '' + file_data + suffix + ext_data)

    # Delete temporary files
    if remove_temp_files == 1:
        printv('\nRemove temporary files...')
        rmtree(path_tmp)

    # display elapsed time
    elapsed_time = time.time() - start_time
    printv('\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) +
           's')

    # to view results
    printv('\nTo view results, type:')
    printv('fslview ' + file_data + ' ' + file_data + suffix + ' &\n')
def create_mask(param):
    """
    Create a mask along Z, centred on a coordinate, a point label, the centre of the FOV, or a centerline.

    The input is reoriented to RPI inside a temporary folder. For every slice where the centerline (either the
    one provided, or a line generated by ``create_line``) is non-zero, a 2D mask from ``create_mask2d`` is
    centred on the centre of mass of the centerline; other slices keep the (all-zero) centerline slice. The
    stacked mask is put back in the input orientation, given the header of the input image, and saved.

    :param param: parameter object. Attributes read here: ``process`` (method type, followed by its value
        unless the type is 'center'), ``fname_data``, ``fname_out``, ``file_prefix``, ``shape``, ``size``,
        ``verbose`` and ``remove_temp_files``. ``fname_out`` is overwritten with a default name (prefix + input
        file name, in the current directory) when it is empty.
    :return: None. The mask is written to ``param.fname_out``.
    """
    # parse argument for method
    method_type = param.process[0]
    # every method except 'center' comes with a value (coordinate string or file name)
    if method_type != 'center':
        method_val = param.process[1]

    # check existence of input files
    if method_type == 'centerline':
        check_file_exist(method_val, param.verbose)

    # Extract path/file/extension
    path_data, file_data, ext_data = extract_fname(param.fname_data)

    # Get output folder and file name
    if param.fname_out == '':
        param.fname_out = os.path.abspath(param.file_prefix + file_data +
                                          ext_data)

    # Both paths are used again after moving into the tmp folder, so resolve them now:
    # a relative path would otherwise be interpreted relative to the tmp folder.
    fname_data = os.path.abspath(param.fname_data)
    fname_out = os.path.abspath(param.fname_out)

    path_tmp = tmp_create(basename="create_mask")

    printv('\nOrientation:', param.verbose)
    orientation_input = Image(fname_data).orientation
    printv('  ' + orientation_input, param.verbose)

    # copy input data to tmp folder and re-orient to RPI
    Image(fname_data).change_orientation("RPI").save(
        os.path.join(path_tmp, "data_RPI.nii"))
    if method_type == 'centerline':
        Image(method_val).change_orientation("RPI").save(
            os.path.join(path_tmp, "centerline_RPI.nii"))
    if method_type == 'point':
        Image(method_val).change_orientation("RPI").save(
            os.path.join(path_tmp, "point_RPI.nii"))

    # go to tmp folder; always come back, even if one of the steps below raises
    curdir = os.getcwd()
    os.chdir(path_tmp)
    try:
        # Get dimensions of data
        im_data = Image('data_RPI.nii')
        nx, ny, nz, nt, px, py, pz, pt = im_data.dim
        printv('\nDimensions:', param.verbose)
        printv(im_data.dim, param.verbose)
        # in case user input 4d data
        if nt != 1:
            printv(
                'WARNING in ' + os.path.basename(__file__) +
                ': Input image is 4d but output mask will be 3D from first time slice.',
                param.verbose, 'warning')
            # extract first volume to have 3d reference
            # NOTE(review): empty_like presumably does not keep the voxel values; the file
            # seems to be used only as a geometric reference afterwards — confirm.
            nii = empty_like(Image('data_RPI.nii'))
            nii.data = nii.data[:, :, :, 0]
            nii.save('data_RPI.nii')

        if method_type == 'coord':
            # parse to get coordinate ("<x>x<y>" string)
            coord = [int(val) for val in method_val.split('x')]

        if method_type == 'point':
            # extract coordinate of point
            printv('\nExtract coordinate of point...', param.verbose)
            coord = Image("point_RPI.nii").getNonZeroCoordinates()

        if method_type == 'center':
            # set coordinate at center of FOV
            coord = np.round(float(nx) / 2), np.round(float(ny) / 2)

        if method_type == 'centerline':
            # get name of centerline from user argument
            fname_centerline = 'centerline_RPI.nii'
        else:
            # generate volume with line along Z at coordinates 'coord'
            printv('\nCreate line...', param.verbose)
            fname_centerline = create_line(param, 'data_RPI.nii', coord, nz)

        # create mask
        printv('\nCreate mask...', param.verbose)
        centerline = nibabel.load(fname_centerline)  # open centerline
        # `.header` and `np.asanyarray(img.dataobj)` are the replacements for the deprecated
        # `get_header()` and `get_data()` accessors of nibabel images.
        hdr = centerline.header
        hdr.set_data_dtype('uint8')  # set imagetype to uint8
        data_centerline = np.asanyarray(centerline.dataobj)
        # if data is 2D, reshape with empty third dimension
        if len(data_centerline.shape) == 2:
            data_centerline = data_centerline.reshape(
                list(data_centerline.shape) + [1])

        # create 2d masks, one per slice
        im_list = []
        for iz in range(nz):
            slice_centerline = data_centerline[:, :, iz]
            if slice_centerline.any():
                # centre the mask on the centre of mass of the centerline in this slice
                center = np.array(
                    ndimage.center_of_mass(np.array(slice_centerline)))
                mask2d = create_mask2d(param,
                                       center,
                                       param.shape,
                                       param.size,
                                       im_data=im_data)
                im_list.append(Image(mask2d, hdr=hdr))
            else:
                # no centerline here: keep the all-zero slice
                im_list.append(Image(slice_centerline, hdr=hdr))
        # the object returned by save() is reused below as the output image
        im_out = concat_data(im_list, dim=2).save('mask_RPI.nii.gz')

        im_out.change_orientation(orientation_input)
        im_out.header = Image(fname_data).header
        im_out.save(fname_out)
    finally:
        # come back
        os.chdir(curdir)

    # Remove temporary files
    if param.remove_temp_files == 1:
        printv('\nRemove temporary files...', param.verbose)
        rmtree(path_tmp)

    display_viewer_syntax([param.fname_data, param.fname_out],
                          colormaps=['gray', 'red'],
                          opacities=['', '0.5'])
    def straighten(self):
        """
        Straighten spinal cord. Steps: (everything is done in physical space)
        1. open input image and centreline image
        2. extract bspline fitting of the centreline, and its derivatives
        3. compute length of centerline
        4. compute and generate straight space
        5. compute transformations
            for each voxel of one space: (done using matrices --> improves speed by a factor x300)
                a. determine which plane of spinal cord centreline it is included
                b. compute the position of the voxel in the plane (X and Y distance from centreline, along the plane)
                c. find the correspondant centreline point in the other space
                d. find the correspondance of the voxel in the corresponding plane
        6. generate warping fields for each transformations
        7. write warping fields and apply them

        step 5.b: how to find the corresponding plane?
            The centerline plane corresponding to a voxel correspond to the nearest point of the centerline.
            However, we need to compute the distance between the voxel position and the plane to be sure it is part of the plane and not too distant.
            If it is more far than a threshold, warping value should be 0.

        step 5.d: how to make the correspondance between centerline point in both images?
            Both centerline have the same lenght. Therefore, we can map centerline point via their position along the curve.
            If we use the same number of points uniformely along the spinal cord (1000 for example), the correspondance is straight-forward.

        :return:
        """
        # Initialization
        fname_anat = self.input_filename
        fname_centerline = self.centerline_filename
        fname_output = self.output_filename
        remove_temp_files = self.remove_temp_files
        verbose = self.verbose
        interpolation_warp = self.interpolation_warp  # TODO: remove this

        # start timer
        start_time = time.time()

        # Extract path/file/extension
        path_anat, file_anat, ext_anat = extract_fname(fname_anat)

        path_tmp = tmp_create(basename="straighten_spinalcord")

        # Copying input data to tmp folder
        logger.info('Copy files to tmp folder...')
        Image(fname_anat,
              check_sform=True).save(os.path.join(path_tmp, "data.nii"))
        Image(fname_centerline, check_sform=True).save(
            os.path.join(path_tmp, "centerline.nii.gz"))

        if self.use_straight_reference:
            Image(self.centerline_reference_filename, check_sform=True).save(
                os.path.join(path_tmp, "centerline_ref.nii.gz"))
        if self.discs_input_filename != '':
            Image(self.discs_input_filename, check_sform=True).save(
                os.path.join(path_tmp, "labels_input.nii.gz"))
        if self.discs_ref_filename != '':
            Image(self.discs_ref_filename, check_sform=True).save(
                os.path.join(path_tmp, "labels_ref.nii.gz"))

        # go to tmp folder
        curdir = os.getcwd()
        os.chdir(path_tmp)

        # Change orientation of the input centerline into RPI
        image_centerline = Image("centerline.nii.gz").change_orientation(
            "RPI").save("centerline_rpi.nii.gz", mutable=True)

        # Get dimension
        nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim
        if self.speed_factor != 1.0:
            intermediate_resampling = True
            px_r, py_r, pz_r = px * self.speed_factor, py * self.speed_factor, pz * self.speed_factor
        else:
            intermediate_resampling = False

        if intermediate_resampling:
            mv('centerline_rpi.nii.gz', 'centerline_rpi_native.nii.gz')
            pz_native = pz
            # TODO: remove system call
            run_proc([
                'sct_resample', '-i', 'centerline_rpi_native.nii.gz', '-mm',
                str(px_r) + 'x' + str(py_r) + 'x' + str(pz_r), '-o',
                'centerline_rpi.nii.gz'
            ])
            image_centerline = Image('centerline_rpi.nii.gz')
            nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim

        if np.min(image_centerline.data) < 0 or np.max(
                image_centerline.data) > 1:
            image_centerline.data[image_centerline.data < 0] = 0
            image_centerline.data[image_centerline.data > 1] = 1
            image_centerline.save()

        # 2. extract bspline fitting of the centerline, and its derivatives
        img_ctl = Image('centerline_rpi.nii.gz')
        centerline = _get_centerline(img_ctl, self.param_centerline, verbose)
        number_of_points = centerline.number_of_points

        # ==========================================================================================
        logger.info('Create the straight space and the safe zone')
        # 3. compute length of centerline
        # compute the length of the spinal cord based on fitted centerline and size of centerline in z direction

        # Computation of the safe zone.
        # The safe zone is defined as the length of the spinal cord for which an axial segmentation will be complete
        # The safe length (to remove) is computed using the safe radius (given as parameter) and the angle of the
        # last centerline point with the inferior-superior direction. Formula: Ls = Rs * sin(angle)
        # Calculate Ls for both edges and remove appropriate number of centerline points
        radius_safe = 0.0  # mm

        # inferior edge
        u = centerline.derivatives[0]
        v = np.array([0, 0, -1])

        angle_inferior = np.arctan2(np.linalg.norm(np.cross(u, v)),
                                    np.dot(u, v))
        length_safe_inferior = radius_safe * np.sin(angle_inferior)

        # superior edge
        u = centerline.derivatives[-1]
        v = np.array([0, 0, 1])
        angle_superior = np.arctan2(np.linalg.norm(np.cross(u, v)),
                                    np.dot(u, v))
        length_safe_superior = radius_safe * np.sin(angle_superior)

        # remove points
        inferior_bound = bisect.bisect(centerline.progressive_length,
                                       length_safe_inferior) - 1
        superior_bound = centerline.number_of_points - bisect.bisect(
            centerline.progressive_length_inverse, length_safe_superior)

        z_centerline = centerline.points[:, 2]
        length_centerline = centerline.length
        size_z_centerline = z_centerline[-1] - z_centerline[0]

        # compute the size factor between initial centerline and straight bended centerline
        factor_curved_straight = length_centerline / size_z_centerline
        middle_slice = (z_centerline[0] + z_centerline[-1]) / 2.0

        bound_curved = [
            z_centerline[inferior_bound], z_centerline[superior_bound]
        ]
        bound_straight = [(z_centerline[inferior_bound] - middle_slice) *
                          factor_curved_straight + middle_slice,
                          (z_centerline[superior_bound] - middle_slice) *
                          factor_curved_straight + middle_slice]

        logger.info('Length of spinal cord: {}'.format(length_centerline))
        logger.info(
            'Size of spinal cord in z direction: {}'.format(size_z_centerline))
        logger.info('Ratio length/size: {}'.format(factor_curved_straight))
        logger.info(
            'Safe zone boundaries (curved space): {}'.format(bound_curved))
        logger.info(
            'Safe zone boundaries (straight space): {}'.format(bound_straight))

        # 4. compute and generate straight space
        # points along curved centerline are already regularly spaced.
        # calculate position of points along straight centerline

        # Create straight NIFTI volumes.
        # ==========================================================================================
        # TODO: maybe this if case is not needed?
        if self.use_straight_reference:
            image_centerline_pad = Image('centerline_rpi.nii.gz')
            nx, ny, nz, nt, px, py, pz, pt = image_centerline_pad.dim

            fname_ref = 'centerline_ref_rpi.nii.gz'
            image_centerline_straight = Image('centerline_ref.nii.gz') \
                .change_orientation("RPI") \
                .save(fname_ref, mutable=True)
            centerline_straight = _get_centerline(image_centerline_straight,
                                                  self.param_centerline,
                                                  verbose)
            nx_s, ny_s, nz_s, nt_s, px_s, py_s, pz_s, pt_s = image_centerline_straight.dim

            # Prepare warping fields headers
            hdr_warp = image_centerline_pad.hdr.copy()
            hdr_warp.set_data_dtype('float32')
            hdr_warp_s = image_centerline_straight.hdr.copy()
            hdr_warp_s.set_data_dtype('float32')

            if self.discs_input_filename != "" and self.discs_ref_filename != "":
                discs_input_image = Image('labels_input.nii.gz')
                coord = discs_input_image.getNonZeroCoordinates(
                    sorting='z', reverse_coord=True)
                coord_physical = []
                for c in coord:
                    c_p = discs_input_image.transfo_pix2phys([[c.x, c.y, c.z]
                                                              ]).tolist()[0]
                    c_p.append(c.value)
                    coord_physical.append(c_p)
                centerline.compute_vertebral_distribution(coord_physical)
                centerline.save_centerline(
                    image=discs_input_image,
                    fname_output='discs_input_image.nii.gz')

                discs_ref_image = Image('labels_ref.nii.gz')
                coord = discs_ref_image.getNonZeroCoordinates(
                    sorting='z', reverse_coord=True)
                coord_physical = []
                for c in coord:
                    c_p = discs_ref_image.transfo_pix2phys([[c.x, c.y,
                                                             c.z]]).tolist()[0]
                    c_p.append(c.value)
                    coord_physical.append(c_p)
                centerline_straight.compute_vertebral_distribution(
                    coord_physical)
                centerline_straight.save_centerline(
                    image=discs_ref_image,
                    fname_output='discs_ref_image.nii.gz')

        else:
            logger.info(
                'Pad input volume to account for spinal cord length...')

            start_point, end_point = bound_straight[0], bound_straight[1]
            offset_z = 0

            # if the destination image is resampled, we still create the straight reference space with the native
            # resolution.
            # TODO: Maybe this if case is not needed?
            if intermediate_resampling:
                padding_z = int(
                    np.ceil(1.5 *
                            ((length_centerline - size_z_centerline) / 2.0) /
                            pz_native))
                run_proc([
                    'sct_image', '-i', 'centerline_rpi_native.nii.gz', '-o',
                    'tmp.centerline_pad_native.nii.gz', '-pad',
                    '0,0,' + str(padding_z)
                ])
                image_centerline_pad = Image('centerline_rpi_native.nii.gz')
                nx, ny, nz, nt, px, py, pz, pt = image_centerline_pad.dim
                start_point_coord_native = image_centerline_pad.transfo_phys2pix(
                    [[0, 0, start_point]])[0]
                end_point_coord_native = image_centerline_pad.transfo_phys2pix(
                    [[0, 0, end_point]])[0]
                straight_size_x = int(self.xy_size / px)
                straight_size_y = int(self.xy_size / py)
                warp_space_x = [
                    int(np.round(nx / 2)) - straight_size_x,
                    int(np.round(nx / 2)) + straight_size_x
                ]
                warp_space_y = [
                    int(np.round(ny / 2)) - straight_size_y,
                    int(np.round(ny / 2)) + straight_size_y
                ]
                if warp_space_x[0] < 0:
                    warp_space_x[1] += warp_space_x[0] - 2
                    warp_space_x[0] = 0
                if warp_space_y[0] < 0:
                    warp_space_y[1] += warp_space_y[0] - 2
                    warp_space_y[0] = 0

                spec = dict((
                    (0, warp_space_x),
                    (1, warp_space_y),
                    (2, (0, end_point_coord_native[2] -
                         start_point_coord_native[2])),
                ))
                spatial_crop(
                    Image("tmp.centerline_pad_native.nii.gz"),
                    spec).save("tmp.centerline_pad_crop_native.nii.gz")

                fname_ref = 'tmp.centerline_pad_crop_native.nii.gz'
                offset_z = 4
            else:
                fname_ref = 'tmp.centerline_pad_crop.nii.gz'

            nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim
            padding_z = int(
                np.ceil(1.5 * ((length_centerline - size_z_centerline) / 2.0) /
                        pz)) + offset_z
            image_centerline_pad = pad_image(image_centerline,
                                             pad_z_i=padding_z,
                                             pad_z_f=padding_z)
            nx, ny, nz = image_centerline_pad.data.shape
            hdr_warp = image_centerline_pad.hdr.copy()
            hdr_warp.set_data_dtype('float32')
            start_point_coord = image_centerline_pad.transfo_phys2pix(
                [[0, 0, start_point]])[0]
            end_point_coord = image_centerline_pad.transfo_phys2pix(
                [[0, 0, end_point]])[0]

            straight_size_x = int(self.xy_size / px)
            straight_size_y = int(self.xy_size / py)
            warp_space_x = [
                int(np.round(nx / 2)) - straight_size_x,
                int(np.round(nx / 2)) + straight_size_x
            ]
            warp_space_y = [
                int(np.round(ny / 2)) - straight_size_y,
                int(np.round(ny / 2)) + straight_size_y
            ]

            if warp_space_x[0] < 0:
                warp_space_x[1] += warp_space_x[0] - 2
                warp_space_x[0] = 0
            if warp_space_x[1] >= nx:
                warp_space_x[1] = nx - 1
            if warp_space_y[0] < 0:
                warp_space_y[1] += warp_space_y[0] - 2
                warp_space_y[0] = 0
            if warp_space_y[1] >= ny:
                warp_space_y[1] = ny - 1

            spec = dict((
                (0, warp_space_x),
                (1, warp_space_y),
                (2, (0, end_point_coord[2] - start_point_coord[2] + offset_z)),
            ))
            image_centerline_straight = spatial_crop(image_centerline_pad,
                                                     spec)

            nx_s, ny_s, nz_s, nt_s, px_s, py_s, pz_s, pt_s = image_centerline_straight.dim
            hdr_warp_s = image_centerline_straight.hdr.copy()
            hdr_warp_s.set_data_dtype('float32')

            if self.template_orientation == 1:
                raise NotImplementedError()

            start_point_coord = image_centerline_pad.transfo_phys2pix(
                [[0, 0, start_point]])[0]
            end_point_coord = image_centerline_pad.transfo_phys2pix(
                [[0, 0, end_point]])[0]

            number_of_voxel = nx * ny * nz
            logger.debug('Number of voxels: {}'.format(number_of_voxel))

            time_centerlines = time.time()

            coord_straight = np.empty((number_of_points, 3))
            coord_straight[..., 0] = int(np.round(nx_s / 2))
            coord_straight[..., 1] = int(np.round(ny_s / 2))
            coord_straight[..., 2] = np.linspace(
                0, end_point_coord[2] - start_point_coord[2], number_of_points)
            coord_phys_straight = image_centerline_straight.transfo_pix2phys(
                coord_straight)
            derivs_straight = np.empty((number_of_points, 3))
            derivs_straight[..., 0] = derivs_straight[..., 1] = 0
            derivs_straight[..., 2] = 1
            dx_straight, dy_straight, dz_straight = derivs_straight.T
            centerline_straight = Centerline(coord_phys_straight[:, 0],
                                             coord_phys_straight[:, 1],
                                             coord_phys_straight[:, 2],
                                             dx_straight, dy_straight,
                                             dz_straight)

            time_centerlines = time.time() - time_centerlines
            logger.info('Time to generate centerline: {} ms'.format(
                np.round(time_centerlines * 1000.0)))

        if verbose == 2:
            # TODO: use OO
            import matplotlib.pyplot as plt
            from datetime import datetime
            curved_points = centerline.progressive_length
            straight_points = centerline_straight.progressive_length
            range_points = np.linspace(0, 1, number_of_points)
            dist_curved = np.zeros(number_of_points)
            dist_straight = np.zeros(number_of_points)
            for i in range(1, number_of_points):
                dist_curved[i] = dist_curved[
                    i - 1] + curved_points[i - 1] / centerline.length
                dist_straight[i] = dist_straight[i - 1] + straight_points[
                    i - 1] / centerline_straight.length
            plt.plot(range_points, dist_curved)
            plt.plot(range_points, dist_straight)
            plt.grid(True)
            plt.savefig('fig_straighten_' +
                        datetime.now().strftime("%y%m%d%H%M%S%f") + '.png')
            plt.close()

        # alignment_mode = 'length'
        alignment_mode = 'levels'

        lookup_curved2straight = list(range(centerline.number_of_points))
        if self.discs_input_filename != "":
            # create look-up table curved to straight
            for index in range(centerline.number_of_points):
                disc_label = centerline.l_points[index]
                if alignment_mode == 'length':
                    relative_position = centerline.dist_points[index]
                else:
                    relative_position = centerline.dist_points_rel[index]
                idx_closest = centerline_straight.get_closest_to_absolute_position(
                    disc_label,
                    relative_position,
                    backup_index=index,
                    backup_centerline=centerline_straight,
                    mode=alignment_mode)
                if idx_closest is not None:
                    lookup_curved2straight[index] = idx_closest
                else:
                    lookup_curved2straight[index] = 0
        for p in range(0, len(lookup_curved2straight) // 2):
            if lookup_curved2straight[p] == lookup_curved2straight[p + 1]:
                lookup_curved2straight[p] = 0
            else:
                break
        for p in range(
                len(lookup_curved2straight) - 1,
                len(lookup_curved2straight) // 2, -1):
            if lookup_curved2straight[p] == lookup_curved2straight[p - 1]:
                lookup_curved2straight[p] = 0
            else:
                break
        lookup_curved2straight = np.array(lookup_curved2straight)

        lookup_straight2curved = list(
            range(centerline_straight.number_of_points))
        if self.discs_input_filename != "":
            for index in range(centerline_straight.number_of_points):
                disc_label = centerline_straight.l_points[index]
                if alignment_mode == 'length':
                    relative_position = centerline_straight.dist_points[index]
                else:
                    relative_position = centerline_straight.dist_points_rel[
                        index]
                idx_closest = centerline.get_closest_to_absolute_position(
                    disc_label,
                    relative_position,
                    backup_index=index,
                    backup_centerline=centerline_straight,
                    mode=alignment_mode)
                if idx_closest is not None:
                    lookup_straight2curved[index] = idx_closest
        for p in range(0, len(lookup_straight2curved) // 2):
            if lookup_straight2curved[p] == lookup_straight2curved[p + 1]:
                lookup_straight2curved[p] = 0
            else:
                break
        for p in range(
                len(lookup_straight2curved) - 1,
                len(lookup_straight2curved) // 2, -1):
            if lookup_straight2curved[p] == lookup_straight2curved[p - 1]:
                lookup_straight2curved[p] = 0
            else:
                break
        lookup_straight2curved = np.array(lookup_straight2curved)

        # Create volumes containing curved and straight warping fields
        data_warp_curved2straight = np.zeros((nx_s, ny_s, nz_s, 1, 3))
        data_warp_straight2curved = np.zeros((nx, ny, nz, 1, 3))

        # 5. compute transformations
        # Curved and straight images have the same dimensions, so we compute both warping fields at the same time.
        # b. determine which plane of spinal cord centreline it is included

        if self.curved2straight:
            for u in sct_progress_bar(range(nz_s)):
                x_s, y_s, z_s = np.mgrid[0:nx_s, 0:ny_s, u:u + 1]
                indexes_straight = np.array(
                    list(zip(x_s.ravel(), y_s.ravel(), z_s.ravel())))
                physical_coordinates_straight = image_centerline_straight.transfo_pix2phys(
                    indexes_straight)
                nearest_indexes_straight = centerline_straight.find_nearest_indexes(
                    physical_coordinates_straight)
                distances_straight = centerline_straight.get_distances_from_planes(
                    physical_coordinates_straight, nearest_indexes_straight)
                lookup = lookup_straight2curved[nearest_indexes_straight]
                indexes_out_distance_straight = np.logical_or(
                    np.logical_or(
                        distances_straight > self.threshold_distance,
                        distances_straight < -self.threshold_distance),
                    lookup == 0)
                projected_points_straight = centerline_straight.get_projected_coordinates_on_planes(
                    physical_coordinates_straight, nearest_indexes_straight)
                coord_in_planes_straight = centerline_straight.get_in_plans_coordinates(
                    projected_points_straight, nearest_indexes_straight)

                coord_straight2curved = centerline.get_inverse_plans_coordinates(
                    coord_in_planes_straight, lookup)
                displacements_straight = coord_straight2curved - physical_coordinates_straight
                # Invert Z coordinate as ITK & ANTs physical coordinate system is LPS- (RAI+)
                # while ours is LPI-
                # Refs: https://sourceforge.net/p/advants/discussion/840261/thread/2a1e9307/#fb5a
                #  https://www.slicer.org/wiki/Coordinate_systems
                displacements_straight[:, 2] = -displacements_straight[:, 2]
                displacements_straight[indexes_out_distance_straight] = [
                    100000.0, 100000.0, 100000.0
                ]

                data_warp_curved2straight[indexes_straight[:, 0], indexes_straight[:, 1], indexes_straight[:, 2], 0, :]\
                    = -displacements_straight

        if self.straight2curved:
            for u in sct_progress_bar(range(nz)):
                x, y, z = np.mgrid[0:nx, 0:ny, u:u + 1]
                indexes = np.array(list(zip(x.ravel(), y.ravel(), z.ravel())))
                physical_coordinates = image_centerline_pad.transfo_pix2phys(
                    indexes)
                nearest_indexes_curved = centerline.find_nearest_indexes(
                    physical_coordinates)
                distances_curved = centerline.get_distances_from_planes(
                    physical_coordinates, nearest_indexes_curved)
                lookup = lookup_curved2straight[nearest_indexes_curved]
                indexes_out_distance_curved = np.logical_or(
                    np.logical_or(distances_curved > self.threshold_distance,
                                  distances_curved < -self.threshold_distance),
                    lookup == 0)
                projected_points_curved = centerline.get_projected_coordinates_on_planes(
                    physical_coordinates, nearest_indexes_curved)
                coord_in_planes_curved = centerline.get_in_plans_coordinates(
                    projected_points_curved, nearest_indexes_curved)

                coord_curved2straight = centerline_straight.points[lookup]
                coord_curved2straight[:, 0:2] += coord_in_planes_curved[:, 0:2]
                coord_curved2straight[:, 2] += distances_curved

                displacements_curved = coord_curved2straight - physical_coordinates

                displacements_curved[:, 2] = -displacements_curved[:, 2]
                displacements_curved[indexes_out_distance_curved] = [
                    100000.0, 100000.0, 100000.0
                ]

                data_warp_straight2curved[indexes[:, 0], indexes[:, 1],
                                          indexes[:, 2],
                                          0, :] = -displacements_curved

        # Creation of the safe zone based on pre-calculated safe boundaries
        coord_bound_curved_inf, coord_bound_curved_sup = image_centerline_pad.transfo_phys2pix(
            [[0, 0, bound_curved[0]]]), image_centerline_pad.transfo_phys2pix(
                [[0, 0, bound_curved[1]]])
        coord_bound_straight_inf, coord_bound_straight_sup = image_centerline_straight.transfo_phys2pix(
            [[0, 0,
              bound_straight[0]]]), image_centerline_straight.transfo_phys2pix(
                  [[0, 0, bound_straight[1]]])

        if radius_safe > 0:
            data_warp_curved2straight[:, :, 0:coord_bound_straight_inf[0][2],
                                      0, :] = 100000.0
            data_warp_curved2straight[:, :, coord_bound_straight_sup[0][2]:,
                                      0, :] = 100000.0
            data_warp_straight2curved[:, :, 0:coord_bound_curved_inf[0][2],
                                      0, :] = 100000.0
            data_warp_straight2curved[:, :, coord_bound_curved_sup[0][2]:,
                                      0, :] = 100000.0

        # Generate warp files as a warping fields
        hdr_warp_s.set_intent('vector', (), '')
        hdr_warp_s.set_data_dtype('float32')
        hdr_warp.set_intent('vector', (), '')
        hdr_warp.set_data_dtype('float32')
        if self.curved2straight:
            img = Nifti1Image(data_warp_curved2straight, None, hdr_warp_s)
            save(img, 'tmp.curve2straight.nii.gz')
            logger.info('Warping field generated: tmp.curve2straight.nii.gz')

        if self.straight2curved:
            img = Nifti1Image(data_warp_straight2curved, None, hdr_warp)
            save(img, 'tmp.straight2curve.nii.gz')
            logger.info('Warping field generated: tmp.straight2curve.nii.gz')

        image_centerline_straight.save(fname_ref)
        if self.curved2straight:
            logger.info('Apply transformation to input image...')
            run_proc([
                'isct_antsApplyTransforms', '-d', '3', '-r', fname_ref, '-i',
                'data.nii', '-o', 'tmp.anat_rigid_warp.nii.gz', '-t',
                'tmp.curve2straight.nii.gz', '-n', 'BSpline[3]'
            ],
                     is_sct_binary=True,
                     verbose=verbose)

        if self.accuracy_results:
            time_accuracy_results = time.time()
            # compute the error between the straightened centerline/segmentation and the central vertical line.
            # Ideally, the error should be zero.
            # Apply deformation to input image
            logger.info('Apply transformation to centerline image...')
            run_proc([
                'isct_antsApplyTransforms', '-d', '3', '-r', fname_ref, '-i',
                'centerline.nii.gz', '-o', 'tmp.centerline_straight.nii.gz',
                '-t', 'tmp.curve2straight.nii.gz', '-n', 'NearestNeighbor'
            ],
                     is_sct_binary=True,
                     verbose=verbose)
            file_centerline_straight = Image('tmp.centerline_straight.nii.gz',
                                             verbose=verbose)
            nx, ny, nz, nt, px, py, pz, pt = file_centerline_straight.dim
            coordinates_centerline = file_centerline_straight.getNonZeroCoordinates(
                sorting='z')
            mean_coord = []
            for z in range(coordinates_centerline[0].z,
                           coordinates_centerline[-1].z):
                temp_mean = [
                    coord.value for coord in coordinates_centerline
                    if coord.z == z
                ]
                if temp_mean:
                    mean_value = np.mean(temp_mean)
                    mean_coord.append(
                        np.mean([[
                            coord.x * coord.value / mean_value,
                            coord.y * coord.value / mean_value
                        ] for coord in coordinates_centerline if coord.z == z],
                                axis=0))

            # compute error between the straightened centerline and the straight line.
            x0 = file_centerline_straight.data.shape[0] / 2.0
            y0 = file_centerline_straight.data.shape[1] / 2.0
            count_mean = 0
            if number_of_points >= 10:
                mean_c = mean_coord[
                    2:
                    -2]  # we don't include the four extrema because there are usually messy.
            else:
                mean_c = mean_coord
            for coord_z in mean_c:
                if not np.isnan(np.sum(coord_z)):
                    dist = ((x0 - coord_z[0]) * px)**2 + (
                        (y0 - coord_z[1]) * py)**2
                    self.mse_straightening += dist
                    dist = np.sqrt(dist)
                    if dist > self.max_distance_straightening:
                        self.max_distance_straightening = dist
                    count_mean += 1
            self.mse_straightening = np.sqrt(self.mse_straightening /
                                             float(count_mean))

            self.elapsed_time_accuracy = time.time() - time_accuracy_results

        os.chdir(curdir)

        # Generate output file (in current folder)
        # TODO: do not uncompress the warping field, it is too time consuming!
        logger.info('Generate output files...')
        if self.curved2straight:
            generate_output_file(
                os.path.join(path_tmp, "tmp.curve2straight.nii.gz"),
                os.path.join(self.path_output, "warp_curve2straight.nii.gz"),
                verbose)
        if self.straight2curved:
            generate_output_file(
                os.path.join(path_tmp, "tmp.straight2curve.nii.gz"),
                os.path.join(self.path_output, "warp_straight2curve.nii.gz"),
                verbose)

        # create ref_straight.nii.gz file that can be used by other SCT functions that need a straight reference space
        if self.curved2straight:
            copy(os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
                 os.path.join(self.path_output, "straight_ref.nii.gz"))
            # move straightened input file
            if fname_output == '':
                fname_straight = generate_output_file(
                    os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
                    os.path.join(self.path_output,
                                 file_anat + "_straight" + ext_anat), verbose)
            else:
                fname_straight = generate_output_file(
                    os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
                    os.path.join(self.path_output, fname_output),
                    verbose)  # straightened anatomic

        # Remove temporary files
        if remove_temp_files:
            logger.info('Remove temporary files...')
            rmtree(path_tmp)

        if self.accuracy_results:
            logger.info('Maximum x-y error: {} mm'.format(
                self.max_distance_straightening))
            logger.info('Accuracy of straightening (MSE): {} mm'.format(
                self.mse_straightening))

        # display elapsed time
        self.elapsed_time = int(np.round(time.time() - start_time))

        return fname_straight
# Example #20
# 0
def moco_wrapper(param):
    """
    Wrapper that performs motion correction.

    :param param: ParamMoco class
    :return: None
    """
    file_data = 'data.nii'  # corresponds to the full input data (e.g. dmri or fmri)
    file_data_dirname, file_data_basename, file_data_ext = extract_fname(
        file_data)
    file_b0 = 'b0.nii'
    file_datasub = 'datasub.nii'  # corresponds to the full input data minus the b=0 scans (if param.is_diffusion=True)
    file_datasubgroup = 'datasub-groups.nii'  # concatenation of the average of each file_datasub
    file_mask = 'mask.nii'
    file_moco_params_csv = 'moco_params.tsv'
    file_moco_params_x = 'moco_params_x.nii.gz'
    file_moco_params_y = 'moco_params_y.nii.gz'
    ext_data = '.nii.gz'  # workaround "too many open files" by slurping the data
    # TODO: check if .nii can be used
    mat_final = 'mat_final/'
    # ext_mat = 'Warp.nii.gz'  # warping field

    # Start timer
    start_time = time.time()

    printv('\nInput parameters:', param.verbose)
    printv('  Input file ............ ' + param.fname_data, param.verbose)
    printv('  Group size ............ {}'.format(param.group_size),
           param.verbose)

    # Get full path
    # param.fname_data = os.path.abspath(param.fname_data)
    # param.fname_bvecs = os.path.abspath(param.fname_bvecs)
    # if param.fname_bvals != '':
    #     param.fname_bvals = os.path.abspath(param.fname_bvals)

    # Extract path, file and extension
    # path_data, file_data, ext_data = extract_fname(param.fname_data)
    # path_mask, file_mask, ext_mask = extract_fname(param.fname_mask)

    path_tmp = tmp_create(basename="moco")

    # Copying input data to tmp folder
    printv('\nCopying input data to tmp folder and convert to nii...',
           param.verbose)
    convert(param.fname_data, os.path.join(path_tmp, file_data))
    if param.fname_mask != '':
        convert(param.fname_mask,
                os.path.join(path_tmp, file_mask),
                verbose=param.verbose)
        # Update field in param (because used later in another function, and param class will be passed)
        param.fname_mask = file_mask

    # Build absolute output path and go to tmp folder
    curdir = os.getcwd()
    path_out_abs = os.path.abspath(param.path_out)
    os.chdir(path_tmp)

    # Get dimensions of data
    printv('\nGet dimensions of data...', param.verbose)
    im_data = Image(file_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    printv('  ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), param.verbose)

    # Get orientation
    printv('\nData orientation: ' + im_data.orientation, param.verbose)
    if im_data.orientation[2] in 'LR':
        param.is_sagittal = True
        printv('  Treated as sagittal')
    elif im_data.orientation[2] in 'IS':
        param.is_sagittal = False
        printv('  Treated as axial')
    else:
        param.is_sagittal = False
        printv(
            'WARNING: Orientation seems to be neither axial nor sagittal. Treated as axial.'
        )

    printv(
        "\nSet suffix of transformation file name, which depends on the orientation:"
    )
    if param.is_sagittal:
        param.suffix_mat = '0GenericAffine.mat'
        printv(
            "Orientation is sagittal, suffix is '{}'. The image is split across the R-L direction, and the "
            "estimated transformation is a 2D affine transfo.".format(
                param.suffix_mat))
    else:
        param.suffix_mat = 'Warp.nii.gz'
        printv(
            "Orientation is axial, suffix is '{}'. The estimated transformation is a 3D warping field, which is "
            "composed of a stack of 2D Tx-Ty transformations".format(
                param.suffix_mat))

    # Adjust group size in case of sagittal scan
    if param.is_sagittal and param.group_size != 1:
        printv(
            'For sagittal data group_size should be one for more robustness. Forcing group_size=1.',
            1, 'warning')
        param.group_size = 1

    if param.is_diffusion:
        # Identify b=0 and DWI images
        index_b0, index_dwi, nb_b0, nb_dwi = \
            sct_dmri_separate_b0_and_dwi.identify_b0(param.fname_bvecs, param.fname_bvals, param.bval_min,
                                                     param.verbose)

        # check if dmri and bvecs are the same size
        if not nb_b0 + nb_dwi == nt:
            printv(
                '\nERROR in ' + os.path.basename(__file__) +
                ': Size of data (' + str(nt) + ') and size of bvecs (' +
                str(nb_b0 + nb_dwi) +
                ') are not the same. Check your bvecs file.\n', 1, 'error')
            sys.exit(2)

    # ==================================================================================================================
    # Prepare data (mean/groups...)
    # ==================================================================================================================

    # Split into T dimension
    printv('\nSplit along T dimension...', param.verbose)
    im_data_split_list = split_data(im_data, 3)
    for im in im_data_split_list:
        x_dirname, x_basename, x_ext = extract_fname(im.absolutepath)
        im.absolutepath = os.path.join(x_dirname, x_basename + ".nii.gz")
        im.save()

    if param.is_diffusion:
        # Merge and average b=0 images
        printv('\nMerge and average b=0 data...', param.verbose)
        im_b0_list = []
        for it in range(nb_b0):
            im_b0_list.append(im_data_split_list[index_b0[it]])
        im_b0 = concat_data(im_b0_list, 3).save(file_b0, verbose=0)
        # Average across time
        im_b0.mean(dim=3).save(add_suffix(file_b0, '_mean'))

        n_moco = nb_dwi  # set number of data to perform moco on (using grouping)
        index_moco = index_dwi

    # If not a diffusion scan, we will motion-correct all volumes
    else:
        n_moco = nt
        index_moco = list(range(0, nt))

    nb_groups = int(math.floor(n_moco / param.group_size))

    # Generate groups indexes
    group_indexes = []
    for iGroup in range(nb_groups):
        group_indexes.append(index_moco[(iGroup *
                                         param.group_size):((iGroup + 1) *
                                                            param.group_size)])

    # add the remaining images to a new last group (in case the total number of images is not divisible by group_size)
    nb_remaining = n_moco % param.group_size  # number of remaining images
    if nb_remaining > 0:
        nb_groups += 1
        group_indexes.append(index_moco[len(index_moco) -
                                        nb_remaining:len(index_moco)])

    _, file_dwi_basename, file_dwi_ext = extract_fname(file_datasub)
    # Group data
    list_file_group = []
    for iGroup in sct_progress_bar(range(nb_groups),
                                   unit='iter',
                                   unit_scale=False,
                                   desc="Merge within groups",
                                   ascii=False,
                                   ncols=80):
        # get index
        index_moco_i = group_indexes[iGroup]
        n_moco_i = len(index_moco_i)
        # concatenate images across time, within this group
        file_dwi_merge_i = os.path.join(file_dwi_basename + '_' + str(iGroup) +
                                        ext_data)
        im_dwi_list = []
        for it in range(n_moco_i):
            im_dwi_list.append(im_data_split_list[index_moco_i[it]])
        im_dwi_out = concat_data(im_dwi_list, 3).save(file_dwi_merge_i,
                                                      verbose=0)
        # Average across time
        list_file_group.append(
            os.path.join(file_dwi_basename + '_' + str(iGroup) + '_mean' +
                         ext_data))
        im_dwi_out.mean(dim=3).save(list_file_group[-1])

    # Merge across groups
    printv('\nMerge across groups...', param.verbose)
    # file_dwi_groups_means_merge = 'dwi_averaged_groups'
    fname_dw_list = []
    for iGroup in range(nb_groups):
        fname_dw_list.append(list_file_group[iGroup])
    im_dw_list = [Image(fname) for fname in fname_dw_list]
    concat_data(im_dw_list, 3).save(file_datasubgroup, verbose=0)

    # Cleanup
    del im, im_data_split_list

    # ==================================================================================================================
    # Estimate moco
    # ==================================================================================================================

    # Initialize another class instance that will be passed on to the moco() function
    param_moco = deepcopy(param)

    if param.is_diffusion:
        # Estimate moco on b0 groups
        printv(
            '\n-------------------------------------------------------------------------------',
            param.verbose)
        printv('  Estimating motion on b=0 images...', param.verbose)
        printv(
            '-------------------------------------------------------------------------------',
            param.verbose)
        param_moco.file_data = 'b0.nii'
        # Identify target image
        if index_moco[0] != 0:
            # If first DWI is not the first volume (most common), then there is at least one b=0 image before. In that
            # case select it as the target image for registration of all b=0
            param_moco.file_target = os.path.join(
                file_data_dirname, file_data_basename + '_T' +
                str(index_b0[index_moco[0] - 1]).zfill(4) + ext_data)
        else:
            # If first DWI is the first volume, then the target b=0 is the first b=0 from the index_b0.
            param_moco.file_target = os.path.join(
                file_data_dirname, file_data_basename + '_T' +
                str(index_b0[0]).zfill(4) + ext_data)
        # Run moco
        param_moco.path_out = ''
        param_moco.todo = 'estimate_and_apply'
        param_moco.mat_moco = 'mat_b0groups'
        file_mat_b0, _ = moco(param_moco)

    # Estimate moco across groups
    printv(
        '\n-------------------------------------------------------------------------------',
        param.verbose)
    printv('  Estimating motion across groups...', param.verbose)
    printv(
        '-------------------------------------------------------------------------------',
        param.verbose)
    param_moco.file_data = file_datasubgroup
    param_moco.file_target = list_file_group[
        0]  # target is the first volume (closest to the first b=0 if DWI scan)
    param_moco.path_out = ''
    param_moco.todo = 'estimate_and_apply'
    param_moco.mat_moco = 'mat_groups'
    file_mat_datasub_group, _ = moco(param_moco)

    # Spline Regularization along T
    if param.spline_fitting:
        # TODO: fix this scenario (haven't touched that code for a while-- it is probably buggy)
        raise NotImplementedError()
        # spline(mat_final, nt, nz, param.verbose, np.array(index_b0), param.plot_graph)

    # ==================================================================================================================
    # Apply moco
    # ==================================================================================================================

    # If group_size>1, assign transformation to each individual ungrouped 3d volume
    if param.group_size > 1:
        file_mat_datasub = []
        for iz in range(len(file_mat_datasub_group)):
            # duplicate by factor group_size the transformation file for each it
            #  example: [mat.Z0000T0001Warp.nii] --> [mat.Z0000T0001Warp.nii, mat.Z0000T0001Warp.nii] for group_size=2
            file_mat_datasub.append(
                functools.reduce(operator.iconcat,
                                 [[i] * param.group_size
                                  for i in file_mat_datasub_group[iz]], []))
    else:
        file_mat_datasub = file_mat_datasub_group

    # Copy transformations to mat_final folder and rename them appropriately
    copy_mat_files(nt, file_mat_datasub, index_moco, mat_final, param)
    if param.is_diffusion:
        copy_mat_files(nt, file_mat_b0, index_b0, mat_final, param)

    # Apply moco on all dmri data
    printv(
        '\n-------------------------------------------------------------------------------',
        param.verbose)
    printv('  Apply moco', param.verbose)
    printv(
        '-------------------------------------------------------------------------------',
        param.verbose)
    param_moco.file_data = file_data
    param_moco.file_target = list_file_group[
        0]  # reference for reslicing into proper coordinate system
    param_moco.path_out = ''  # TODO not used in moco()
    param_moco.mat_moco = mat_final
    param_moco.todo = 'apply'
    file_mat_data, im_moco = moco(param_moco)

    # copy geometric information from header
    # NB: this is required because WarpImageMultiTransform in 2D mode wrongly sets pixdim(3) to "1".
    im_moco.header = im_data.header
    im_moco.save(verbose=0)

    # Average across time
    if param.is_diffusion:
        # generate b0_moco_mean and dwi_moco_mean
        args = [
            '-i', im_moco.absolutepath, '-bvec', param.fname_bvecs, '-a', '1',
            '-v', '0'
        ]
        if not param.fname_bvals == '':
            # if bvals file is provided
            args += ['-bval', param.fname_bvals]
        fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean = sct_dmri_separate_b0_and_dwi.main(
            argv=args)
    else:
        fname_moco_mean = add_suffix(im_moco.absolutepath, '_mean')
        im_moco.mean(dim=3).save(fname_moco_mean)

    # Extract and output the motion parameters (doesn't work for sagittal orientation)
    printv('Extract motion parameters...')
    if param.output_motion_param:
        if param.is_sagittal:
            printv(
                'Motion parameters cannot be generated for sagittal images.',
                1, 'warning')
        else:
            files_warp_X, files_warp_Y = [], []
            moco_param = []
            for fname_warp in file_mat_data[0]:
                # Cropping the image to keep only one voxel in the XY plane
                im_warp = Image(fname_warp + param.suffix_mat)
                im_warp.data = np.expand_dims(np.expand_dims(
                    im_warp.data[0, 0, :, :, :], axis=0),
                                              axis=0)

                # These three lines allow to generate one file instead of two, containing X, Y and Z moco parameters
                #fname_warp_crop = fname_warp + '_crop_' + ext_mat
                # files_warp.append(fname_warp_crop)
                # im_warp.save(fname_warp_crop)

                # Separating the three components and saving X and Y only (Z is equal to 0 by default).
                im_warp_XYZ = multicomponent_split(im_warp)

                fname_warp_crop_X = fname_warp + '_crop_X_' + param.suffix_mat
                im_warp_XYZ[0].save(fname_warp_crop_X)
                files_warp_X.append(fname_warp_crop_X)

                fname_warp_crop_Y = fname_warp + '_crop_Y_' + param.suffix_mat
                im_warp_XYZ[1].save(fname_warp_crop_Y)
                files_warp_Y.append(fname_warp_crop_Y)

                # Calculating the slice-wise average moco estimate to provide a QC file
                moco_param.append([
                    np.mean(np.ravel(im_warp_XYZ[0].data)),
                    np.mean(np.ravel(im_warp_XYZ[1].data))
                ])

            # These two lines allow to generate one file instead of two, containing X, Y and Z moco parameters
            # im_warp = [Image(fname) for fname in files_warp]
            # im_warp_concat = concat_data(im_warp, dim=3)
            # im_warp_concat.save('fmri_moco_params.nii')

            # Concatenating the moco parameters into a time series for X and Y components.
            im_warp_X = [Image(fname) for fname in files_warp_X]
            im_warp_concat = concat_data(im_warp_X, dim=3)
            im_warp_concat.save(file_moco_params_x)

            im_warp_Y = [Image(fname) for fname in files_warp_Y]
            im_warp_concat = concat_data(im_warp_Y, dim=3)
            im_warp_concat.save(file_moco_params_y)

            # Writing a TSV file with the slicewise average estimate of the moco parameters. Useful for QC
            with open(file_moco_params_csv, 'wt') as out_file:
                tsv_writer = csv.writer(out_file, delimiter='\t')
                tsv_writer.writerow(['X', 'Y'])
                for mocop in moco_param:
                    tsv_writer.writerow([mocop[0], mocop[1]])

    # Generate output files
    printv('\nGenerate output files...', param.verbose)
    fname_moco = os.path.join(
        path_out_abs,
        add_suffix(os.path.basename(param.fname_data), param.suffix))
    generate_output_file(im_moco.absolutepath, fname_moco)
    if param.is_diffusion:
        generate_output_file(fname_b0_mean, add_suffix(fname_moco, '_b0_mean'))
        generate_output_file(fname_dwi_mean,
                             add_suffix(fname_moco, '_dwi_mean'))
    else:
        generate_output_file(fname_moco_mean, add_suffix(fname_moco, '_mean'))
    if os.path.exists(file_moco_params_csv):
        generate_output_file(file_moco_params_x,
                             os.path.join(path_out_abs, file_moco_params_x),
                             squeeze_data=False)
        generate_output_file(file_moco_params_y,
                             os.path.join(path_out_abs, file_moco_params_y),
                             squeeze_data=False)
        generate_output_file(file_moco_params_csv,
                             os.path.join(path_out_abs, file_moco_params_csv))

    # Delete temporary files
    if param.remove_temp_files == 1:
        printv('\nDelete temporary files...', param.verbose)
        rmtree(path_tmp, verbose=param.verbose)

    # come back to working directory
    os.chdir(curdir)

    # display elapsed time
    elapsed_time = time.time() - start_time
    printv(
        '\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's',
        param.verbose)

    display_viewer_syntax([
        os.path.join(
            param.path_out,
            add_suffix(os.path.basename(param.fname_data), param.suffix)),
        param.fname_data
    ],
                          mode='ortho,ortho')
def pre_processing(fname_target,
                   fname_sc_seg,
                   fname_level=None,
                   fname_manual_gmseg=None,
                   new_res=0.3,
                   square_size_size_mm=22.5,
                   denoising=True,
                   verbose=1,
                   rm_tmp=True,
                   for_model=False):
    """
    Pre-process a target image and its spinal cord (SC) segmentation into a list of masked 2D slices.

    Steps, all run inside a temporary directory:
      1. copy both inputs to the tmp dir and reorient them to RPI;
      2. optionally crop both around the SC (box mask from ``sct_create_mask``) and denoise the target;
      3. resample both with ``interpolate_im_to_ref`` and multiply each image slice by the binarized SC slice;
      4. wrap each slice in a ``Slice`` object, then optionally attach levels and manual GM segmentations.

    The caller's working directory is always restored, even if a step raises.

    :param fname_target: file name of the image to pre-process.
    :param fname_sc_seg: file name of the SC segmentation of ``fname_target``.
    :param fname_level: optional level file, copied to the tmp dir and handed to ``load_level``.
    :param fname_manual_gmseg: optional manual GM segmentation(s), handed to ``load_manual_gmseg``.
    :param new_res: forwarded as ``new_res`` to ``interpolate_im_to_ref`` and ``load_manual_gmseg``.
    :param square_size_size_mm: forwarded as ``sq_size_size_mm`` to ``interpolate_im_to_ref``; also used
        (plus 1) to size the crop box applied before denoising.
    :param denoising: if true, crop around the SC and denoise the target with ``denoise_nlmeans``.
    :param verbose: verbosity forwarded to ``printv``.
    :param rm_tmp: if true, remove the temporary directory before returning.
    :param for_model: forwarded to ``load_manual_gmseg``.
    :return: tuple ``(list_slices_target, original_info)`` where ``original_info`` is a dict with keys
        ``'orientation'`` (orientation of the input target), ``'im_sc_seg_rpi'`` (copy of the RPI SC
        segmentation taken before cropping) and ``'interpolated_images'`` (list returned by
        ``interpolate_im_to_ref`` for the target).
    :raises AssertionError: if the target and the SC segmentation do not share the same orientation.
    """
    printv('\nPre-process data...', verbose, 'normal')

    tmp_dir = tmp_create()

    # Inputs are copied before changing directory; from here on they are
    # referenced by file name only (i.e. relative to the tmp dir).
    copy(fname_target, tmp_dir)
    fname_target = ''.join(extract_fname(fname_target)[1:])
    copy(fname_sc_seg, tmp_dir)
    fname_sc_seg = ''.join(extract_fname(fname_sc_seg)[1:])

    original_info = {
        'orientation': None,
        'im_sc_seg_rpi': None,
        'interpolated_images': []
    }
    # Only set when the denoising step creates a crop mask
    fname_mask = None

    curdir = os.getcwd()
    os.chdir(tmp_dir)
    try:
        im_target = Image(fname_target).copy()
        im_sc_seg = Image(fname_sc_seg).copy()

        # get original orientation
        printv('  Reorient...', verbose, 'normal')
        original_info['orientation'] = im_target.orientation

        # Explicit raise (rather than `assert`) so the check survives `python -O`;
        # the exception type and message are unchanged.
        if im_target.orientation != im_sc_seg.orientation:
            raise AssertionError(
                "ERROR: the image to segment and it's SC segmentation are not in the same orientation"
            )

        im_target_rpi = im_target.copy().change_orientation(
            'RPI', generate_path=True).save()
        im_sc_seg_rpi = im_sc_seg.copy().change_orientation(
            'RPI', generate_path=True).save()
        # Copy taken before any cropping below; returned to the caller in original_info
        original_info['im_sc_seg_rpi'] = im_sc_seg_rpi.copy()

        # denoise using P. Coupe non local means algorithm (see [Manjon et al. JMRI 2010]) implemented in dipy
        if denoising:
            printv('  Denoise...', verbose, 'normal')
            # Crop the image before denoising to speed it up. Only the in-plane
            # pixel sizes are needed from the 8-element dim tuple.
            _, _, _, _, px, py, _, _ = im_target_rpi.dim
            size_x = (square_size_size_mm + 1) / px
            size_y = (square_size_size_mm + 1) / py
            size = int(np.ceil(max(size_x, size_y)))
            # create mask
            fname_mask = 'mask_pre_crop.nii.gz'
            sct_create_mask.main([
                '-i', im_target_rpi.absolutepath, '-p',
                'centerline,' + im_sc_seg_rpi.absolutepath, '-f', 'box',
                '-size',
                str(size), '-o', fname_mask
            ])
            # crop image
            cropper = ImageCropper(im_target_rpi)
            cropper.get_bbox_from_mask(Image(fname_mask))
            im_target_rpi_crop = cropper.crop()
            # crop segmentation
            cropper = ImageCropper(im_sc_seg_rpi)
            cropper.get_bbox_from_mask(Image(fname_mask))
            im_sc_seg_rpi_crop = cropper.crop()
            # denoising
            from spinalcordtoolbox.math import denoise_nlmeans
            # Default block radius is 3, reduced to half the number of slices
            # when the cropped volume has fewer than 6 slices.
            block_radius = min(3, im_target_rpi_crop.data.shape[2] // 2)
            # NOTE(review): a single-slice crop gives block_radius == 0 and
            # patch_radius == -1 -- confirm denoise_nlmeans handles this.
            patch_radius = block_radius - 1
            im_target_rpi_crop.data = denoise_nlmeans(
                im_target_rpi_crop.data,
                block_radius=block_radius,
                patch_radius=patch_radius)

            im_target_rpi = im_target_rpi_crop
            im_sc_seg_rpi = im_sc_seg_rpi_crop

        # interpolate image to reference square image (resample and square crop centered on SC)
        printv('  Interpolate data to the model space...', verbose, 'normal')
        list_im_slices = interpolate_im_to_ref(
            im_target_rpi,
            im_sc_seg_rpi,
            new_res=new_res,
            sq_size_size_mm=square_size_size_mm)
        # list of images (not Slice() objects)
        original_info['interpolated_images'] = list_im_slices

        printv('  Mask data using the spinal cord segmentation...', verbose,
               'normal')
        list_sc_seg_slices = interpolate_im_to_ref(
            im_sc_seg_rpi,
            im_sc_seg_rpi,
            new_res=new_res,
            sq_size_size_mm=square_size_size_mm,
            interpolation_mode=1)
        for i, im_slice in enumerate(list_im_slices):
            # Binarize the resampled segmentation, then use it as a multiplicative mask
            sc_seg_bin = binarize(list_sc_seg_slices[i],
                                  thr_min=0.5,
                                  thr_max=1)
            im_slice.data = im_slice.data * sc_seg_bin.data

        printv('  Split along rostro-caudal direction...', verbose, 'normal')
        list_slices_target = [
            Slice(slice_id=i, im=im_slice.data, gm_seg=[], wm_seg=[])
            for i, im_slice in enumerate(list_im_slices)
        ]

        # load vertebral levels
        if fname_level is not None:
            printv('  Load vertebral levels...', verbose, 'normal')
            # Go back to the original directory for the copy, since
            # fname_level was given relative to it, not to the tmp dir.
            os.chdir(curdir)
            copy(fname_level, tmp_dir)
            os.chdir(tmp_dir)
            # change fname level to only file name (path = tmp dir now)
            fname_level = ''.join(extract_fname(fname_level)[1:])
            # load levels
            list_slices_target = load_level(list_slices_target, fname_level)
    finally:
        # Never leave the process inside the tmp dir, even when a step fails
        os.chdir(curdir)

    # load manual gmseg if there is one (model data)
    if fname_manual_gmseg is not None:
        printv('\n\tLoad manual GM segmentation(s) ...', verbose, 'normal')
        list_slices_target = load_manual_gmseg(list_slices_target,
                                               fname_manual_gmseg,
                                               tmp_dir,
                                               im_sc_seg_rpi,
                                               new_res,
                                               square_size_size_mm,
                                               for_model=for_model,
                                               fname_mask=fname_mask)

    if rm_tmp:
        # remove tmp folder
        rmtree(tmp_dir)
    return list_slices_target, original_info
def main(argv=None):
    """
    Command-line entry point: run ``isct_dice_coefficient`` on two input images.

    Both inputs are copied to a temporary directory, optionally binarized with ``sct_maths``, and the
    header of the first image is copied onto the second before the binary is called. The output of the
    binary is printed with ``printv``; if ``-o`` was given, the output file is copied back from the
    temporary directory to the requested location.

    The caller's working directory is always restored, even if a step raises.

    :param argv: list of command-line arguments; when ``None`` or empty, ``['--help']`` is passed to the
        parser instead.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv if argv else ['--help'])
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    fname_input1 = arguments.i
    fname_input2 = arguments.d
    rm_tmp = bool(arguments.r)

    tmp_dir = os.path.abspath(tmp_create())  # create tmp directory

    # Copy input files to the tmp directory, then refer to them by file name only.
    # NOTE(review): both inputs land in the same folder under their base name, so two
    # inputs sharing a file name (in different folders) would overwrite each other.
    copy(fname_input1, tmp_dir)
    copy(fname_input2, tmp_dir)
    fname_input1 = ''.join(extract_fname(fname_input1)[1:])
    fname_input2 = ''.join(extract_fname(fname_input2)[1:])

    curdir = os.getcwd()
    os.chdir(tmp_dir)  # go to tmp directory
    try:
        if arguments.bin is not None:
            # Binarize both inputs and continue with the binarized files
            fnames_bin = []
            for fname in (fname_input1, fname_input2):
                fname_bin = add_suffix(fname, '_bin')
                run_proc(['sct_maths', '-i', fname, '-bin', '0', '-o', fname_bin])
                fnames_bin.append(fname_bin)
            fname_input1, fname_input2 = fnames_bin

        # copy header of im_1 to im_2
        from spinalcordtoolbox.image import Image
        im_1 = Image(fname_input1)
        im_2 = Image(fname_input2)
        im_2.header = im_1.header
        im_2.save()

        cmd = ['isct_dice_coefficient', fname_input1, fname_input2]

        # "2d_slices" is not a valid attribute name, hence the lookup through vars()
        slices_2d = vars(arguments)["2d_slices"]
        if slices_2d is not None:
            cmd += ['-2d-slices', str(slices_2d)]
        if arguments.b is not None:
            # List concatenation: this requires arguments.b to be a list.
            # NOTE(review): confirm against the parser definition of -b.
            cmd += ['-b'] + arguments.b
        if arguments.bmax == 1:
            cmd += ['-bmax']
        if arguments.bzmax == 1:
            cmd += ['-bzmax']
        if arguments.o is not None:
            # The binary writes into the tmp dir; only the file name is passed to it
            path_output, fname_output, ext = extract_fname(arguments.o)
            cmd += ['-o', fname_output + ext]

        # A Python-based computation (compute_dice) was previously sketched here but is not used, as it
        # does not cover all the features of isct_dice_coefficient.
        _, output = run_proc(cmd, verbose, is_sct_binary=True)
    finally:
        # Never leave the process inside the tmp dir, even when a step fails
        os.chdir(curdir)  # go back to original directory

    # copy output file into original directory
    if arguments.o is not None:
        copy(os.path.join(tmp_dir, fname_output + ext), os.path.join(path_output, fname_output + ext))

    # remove tmp_dir
    if rm_tmp:
        rmtree(tmp_dir)

    printv(output, verbose)
Exemple #23
0
    if arguments.bmax is not None and arguments.bmax == 1:
        cmd += ['-bmax']
    if arguments.bzmax is not None and arguments.bzmax == 1:
        cmd += ['-bzmax']
    if arguments.o is not None:
        path_output, fname_output, ext = extract_fname(arguments.o)
        cmd += ['-o', fname_output + ext]

    rm_tmp = bool(arguments.r)

    # # Computation of Dice coefficient using Python implementation.
    # # commented for now as it does not cover all the feature of isct_dice_coefficient
    # #from spinalcordtoolbox.image import Image, compute_dice
    # #dice = compute_dice(Image(fname_input1), Image(fname_input2), mode='3d', zboundaries=False)
    # #printv('Dice (python-based) = ' + str(dice), verbose)

    status, output = run_proc(cmd, verbose, is_sct_binary=True)

    os.chdir(curdir)  # go back to original directory

    # copy output file into original directory
    if arguments.o is not None:
        copy(os.path.join(tmp_dir, fname_output + ext),
             os.path.join(path_output, fname_output + ext))

    # remove tmp_dir
    if rm_tmp:
        rmtree(tmp_dir)

    printv(output, verbose)