def apply_transfo(im_src, im_dest, warp, interp='spline', rm_tmp=True):
    """
    Apply a warping field to an image, working inside a temporary directory.

    :param im_src: source image object; it is saved to disk (``im_src.save``) and used as the ``-i`` input.
    :param im_dest: destination image object; it is saved to disk and used as the ``-d`` destination.
    :param warp: file name of the warping field; it is copied into the temporary directory.
    :param interp: interpolation passed to ``sct_apply_transfo`` through ``-x``.
    :param rm_tmp: if True, remove the temporary directory once the result has been loaded.
    :return: the warped source image, loaded with ``Image``.
    """
    tmp_dir = tmp_create()

    # copy warping field to tmp dir; from now on refer to it by "<name><ext>" only,
    # because the transformation is run from inside the tmp dir
    copy(warp, tmp_dir)
    warp = ''.join(extract_fname(warp)[1:])

    curdir = os.getcwd()
    os.chdir(tmp_dir)
    try:
        # save image and seg
        fname_src = 'src.nii.gz'
        im_src.save(fname_src)
        fname_dest = 'dest.nii.gz'
        im_dest.save(fname_dest)

        # apply warping field
        fname_src_reg = add_suffix(fname_src, '_reg')
        sct_apply_transfo.main(
            argv=['-i', fname_src, '-d', fname_dest, '-w', warp, '-x', interp])
        im_src_reg = Image(fname_src_reg)
    finally:
        # always leave the tmp dir, even if saving or warping failed, so the
        # caller's working directory is never left changed
        os.chdir(curdir)

    if rm_tmp:
        # on failure the tmp dir is intentionally kept (not reached) for inspection
        rmtree(tmp_dir)

    return im_src_reg
def visualize_warp(im_warp: Image, im_grid: Image = None, step=3, rm_tmp=True):
    """
    Apply a warping field to a grid image so that the deformation can be inspected visually.

    If no grid is supplied, one is generated with the first three dimensions of the warping field: every complete
    ``step`` x ``step`` in-plane block has its last row and last column set to 1. The grid is then resampled
    (``sct_resample -f 3x3x1 -x nn``) inside a temporary directory.

    The warped grid is written next to the warping field, named
    ``<grid file name>_<warp file name><warp extension>``.

    :param im_warp: warping field image; its ``absolutepath``, ``data`` and ``hdr`` are used.
    :param im_grid: optional grid image; when given, its ``absolutepath`` is used as is.
    :param step: size (in voxels) of the grid cells.
    :param rm_tmp: if True, remove the temporary directory (only created when the grid is generated here).
    """
    fname_warp = im_warp.absolutepath
    # stays None when the caller supplies the grid: there is then nothing to clean up
    tmp_dir = None

    if im_grid:
        fname_grid = im_grid.absolutepath
    else:
        tmp_dir = tmp_create()
        nx, ny, nz = im_warp.data.shape[0:3]

        # one grid cell: last row and last column are "on"
        sq = np.zeros((step, step))
        sq[step - 1] = 1
        sq[:, step - 1] = 1

        dat = np.zeros((nx, ny, nz))
        for i in range(0, nx, step):
            for j in range(0, ny, step):
                block = dat[i:i + step, j:j + step, :]
                # only complete cells are drawn; partial cells at the border stay empty
                if block.shape[:2] == (step, step):
                    # same pattern on every slice (block is a view into dat)
                    block[...] = sq[:, :, np.newaxis]

        curdir = os.getcwd()
        os.chdir(tmp_dir)
        try:
            im_grid = Image(param=dat)
            im_grid.hdr = im_warp.hdr
            fname_grid = 'grid_' + str(step) + '.nii.gz'
            im_grid.save(fname_grid)
            fname_grid_resample = add_suffix(fname_grid, '_resample')
            sct_resample.main(argv=['-i', fname_grid, '-f', '3x3x1', '-x', 'nn', '-o', fname_grid_resample])
            fname_grid = os.path.join(tmp_dir, fname_grid_resample)
        finally:
            # restore the working directory even if grid generation failed
            os.chdir(curdir)

    path_warp, file_warp, ext_warp = extract_fname(fname_warp)
    grid_warped = os.path.join(path_warp, extract_fname(fname_grid)[1] + '_' + file_warp + ext_warp)
    sct_apply_transfo.main(argv=['-i', fname_grid, '-d', fname_grid, '-w', fname_warp, '-o', grid_warped])

    if rm_tmp and tmp_dir is not None:
        rmtree(tmp_dir)
def test_integrity(param_test):
    """
    Check registration integrity by warping segmentations in both directions and scoring them with Dice.

    The ground-truth template segmentation is warped to the anatomical space, and the anatomical segmentation is
    warped to the template space (nearest-neighbour interpolation in both cases). Each warped mask is compared
    with the mask native to its destination space. A Dice coefficient that is not above
    ``param_test.dice_threshold`` sets ``param_test.status`` to 99.

    :param param_test: test-parameter object; its ``output``, ``status`` and ``results`` are updated in place.
    :return: the same ``param_test`` object.
    """
    # index of the test being performed
    i_test = param_test.default_args.index(param_test.args)
    fname_template_seg = param_test.fname_gt[i_test]
    fname_seg = param_test.file_seg

    # (input, destination, warping field, output): template --> anat, then anat --> template
    warps = (
        (fname_template_seg, fname_seg, 'warp_template2anat.nii.gz', 'test_template2anat.nii.gz'),
        (fname_seg, fname_template_seg, 'warp_anat2template.nii.gz', 'test_anat2template.nii.gz'),
    )
    for fname_input, fname_dest, fname_warp, fname_output in warps:
        sct_apply_transfo.main(args=[
            '-i', fname_input,
            '-d', fname_dest,
            '-w', fname_warp,
            '-o', fname_output,
            '-x', 'nn',
            '-v', '0',
        ])

    def _verdict(dice):
        # text appended to the report; a failure also flags the test status
        if dice > param_test.dice_threshold:
            return '\n--> PASSED'
        param_test.status = 99
        return '\n--> FAILED'

    # template segmentation warped to anat vs. segmentation from anat
    dice_template2anat = msct_image.compute_dice(
        Image(fname_seg), Image('test_template2anat.nii.gz'), mode='3d', zboundaries=True)
    param_test.output += 'Dice[seg,template_seg_reg]: ' + str(dice_template2anat)
    param_test.output += _verdict(dice_template2anat)

    # anat segmentation warped to template vs. segmentation from template
    dice_anat2template = msct_image.compute_dice(
        Image('test_anat2template.nii.gz'), Image(fname_template_seg), mode='3d', zboundaries=True)
    param_test.output += '\n\nDice[seg_reg,template_seg]: ' + str(dice_anat2template)
    param_test.output += _verdict(dice_anat2template)

    # store both scores in the results structure
    param_test.results['dice_template2anat'] = dice_template2anat
    param_test.results['dice_anat2template'] = dice_anat2template

    return param_test
def test_sct_apply_transfo_output_image_attributes(path_in, path_dest, path_warp, path_out, remaining_args):
    """Run the CLI script, then check orientation, content and dimensions of the transformed image."""
    base_args = ['-i', path_in, '-d', path_dest, '-w', path_warp, '-o', path_out]
    sct_apply_transfo.main(argv=base_args + remaining_args)

    # loaded in this order: source, reference (destination), output
    source, reference, result = (Image(path) for path in (path_in, path_dest, path_out))

    # output lives in the reference space...
    assert result.orientation == reference.orientation

    # ...and is not an empty volume
    has_nonzero_voxel = (result.data != 0).any()
    assert has_nonzero_voxel

    # Spatial dims follow the reference (only the first 3 are compared because one test involves a 4D volume)
    assert reference.dim[0:3] == result.dim[0:3]

    # The 4th dim follows the input image, not the reference image
    assert source.dim[3] == result.dim[3]
def warp_label(path_label, folder_label, file_label, fname_src, fname_transfo, path_out, list_labels_nn, verbose):
    """
    Warp every label file listed in a label-description file, and copy that description file to the output.

    :param path_label: parent folder that contains ``folder_label``.
    :param folder_label: name of the label folder (also used as the output sub-folder name).
    :param file_label: name of the label-description file inside the label folder (e.g. info_label.txt).
    :param fname_src: image used as destination (``-d``) of the transformation.
    :param fname_transfo: warping field (``-w``) applied to each label file.
    :param path_out: output parent folder; labels are written to ``path_out/folder_label``.
    :param list_labels_nn: forwarded to ``get_interp`` to choose the interpolation of each label file.
    :param verbose: unused here; kept for interface compatibility.
    :raises Exception: whatever ``read_label_file`` raises is re-raised after a warning is printed.
    """
    dir_label_in = os.path.join(path_label, folder_label)
    dir_label_out = os.path.join(path_out, folder_label)

    try:
        # Read label file (only the list of label file names is needed here)
        _, _, template_label_file, _, _, _, _ = \
            spinalcordtoolbox.metadata.read_label_file(dir_label_in, file_label)
    except Exception as error:
        printv('\nWARNING: Cannot warp label ' + folder_label + ': ' + str(error), 1, 'warning')
        raise

    # create output folder (no check-then-create race)
    os.makedirs(dir_label_out, exist_ok=True)

    # Warp label
    for fname in template_label_file:
        sct_apply_transfo.main([
            '-i', os.path.join(dir_label_in, fname),
            '-d', fname_src,
            '-w', fname_transfo,
            '-o', os.path.join(dir_label_out, fname),
            '-x', get_interp(fname, list_labels_nn),
            '-v', '0',
        ])

    # Copy list.txt
    copy(os.path.join(dir_label_in, file_label), dir_label_out)
def test_sct_register_to_template_dice_coefficient_against_groundtruth(
        fname_gt, remaining_args):
    """Run the CLI script, warp masks in both directions and check both Dice coefficients exceed the threshold."""
    fname_seg = 't2/t2_seg-manual.nii.gz'
    dice_threshold = 0.9

    register_args = ['-i', 't2/t2.nii.gz', '-s', fname_seg]
    sct_register_to_template.main(argv=register_args + remaining_args)

    # (input, destination, warping field, output): template --> anat, then anat --> template
    for fname_input, fname_dest, fname_warp, fname_output in (
            (fname_gt, fname_seg, 'warp_template2anat.nii.gz', 'test_template2anat.nii.gz'),
            (fname_seg, fname_gt, 'warp_anat2template.nii.gz', 'test_anat2template.nii.gz'),
    ):
        sct_apply_transfo.main(argv=[
            '-i', fname_input,
            '-d', fname_dest,
            '-w', fname_warp,
            '-o', fname_output,
            '-x', 'nn',
            '-v', '0',
        ])

    # template segmentation warped to anat vs. segmentation from anat
    dice_template2anat = compute_dice(
        Image(fname_seg), Image('test_template2anat.nii.gz'), mode='3d', zboundaries=True)
    assert dice_template2anat > dice_threshold

    # anat segmentation warped to template vs. segmentation from template
    dice_anat2template = compute_dice(
        Image('test_anat2template.nii.gz'), Image(fname_gt), mode='3d', zboundaries=True)
    assert dice_anat2template > dice_threshold
def register(param, file_src, file_dest, file_mat, file_out, im_mask=None):
    """
    Register two images by estimating slice-wise Tx and Ty transformations, which are regularized along Z.

    When ``param.todo`` is 'estimate' or 'estimate_and_apply', the transformation is estimated with ANTs:
    ``isct_antsRegistration`` in 2D mode if the third orientation axis of the source is L or R (sagittal),
    ``isct_antsSliceRegularizedRegistration`` otherwise. When ``param.todo`` is 'apply', an existing
    transformation (``file_mat + param.suffix_mat``) is applied with ``sct_apply_transfo``.

    :param param: parameter object; reads todo, metric, sampling, gradStep, iter, smooth, poly, interp,
        suffix_mat and verbose.
    :param file_src: file name of the image to register.
    :param file_dest: file name of the target image.
    :param file_mat: output prefix of the transformation.
    :param file_out: file name of the registered image.
    :param im_mask: Image of mask, could be 2D or 3D
    :return: 0 if the output file was produced, 1 otherwise.
    """
    # TODO: deal with mask
    failed_transfo = 0  # by default, failed matrix is 0 (i.e., no failure)
    do_registration = True

    # metric radius (if MeanSquares, CC) or nb bins (if MI)
    metric_radius = '16' if param.metric == 'MI' else '4'
    file_out_concat = file_out

    im_data = Image(file_src)  # TODO: pass argument to use antsReg instead of opening Image each time
    is_sagittal = im_data.orientation[2] in 'LR'

    # register file_src to file_dest
    if param.todo in ('estimate', 'estimate_and_apply'):
        # Note: the parameter --restrict-deformation is irrelevant with affine transfo
        if param.sampling == 'None':
            # 'None' sampling means 'fully dense' sampling
            # see https://github.com/ANTsX/ANTs/wiki/antsRegistration-reproducibility-issues
            sampling = param.sampling
        else:
            # param.sampling should be a float in [0,1], and means the
            # samplingPercentage that chooses a subset of points to
            # estimate from. We always use 'Regular' (evenly-spaced)
            # mode, though antsRegistration offers 'Random' as well.
            # Be aware: even 'Regular' is not fully deterministic:
            # > Regular includes a random perturbation on the grid sampling
            # - https://github.com/ANTsX/ANTs/issues/976#issuecomment-602313884
            sampling = 'Regular,' + param.sampling

        metric = param.metric + '[' + file_dest + ',' + file_src + ',1,' + metric_radius + ',' + sampling + ']'
        output = '[' + file_mat + ',' + file_out_concat + ']'

        if is_sagittal:
            # If orientation is sagittal, use antsRegistration in 2D mode
            cmd = [
                'isct_antsRegistration',
                '-d', '2',
                '--transform', 'Affine[%s]' % param.gradStep,
                '--metric', metric,
                '--convergence', param.iter,
                '--shrink-factors', '1',
                '--smoothing-sigmas', param.smooth,
                '--verbose', '1',
                '--output', output,
            ]
            cmd += get_interpolation('isct_antsRegistration', param.interp)
            if im_mask is not None:
                # if user specified a mask, make sure there are non-null voxels in the image before running the
                # registration
                if np.count_nonzero(im_mask.data):
                    cmd += ['--masks', im_mask.absolutepath]
                else:
                    # Mask only contains zeros. Copying the image instead of estimating registration.
                    copy(file_src, file_out_concat, verbose=0)
                    do_registration = False
                    # TODO: create affine mat file with identity, in case used by -g 2
        else:
            # 3D mode
            cmd = [
                'isct_antsSliceRegularizedRegistration',
                '--polydegree', param.poly,
                '--transform', 'Translation[%s]' % param.gradStep,
                '--metric', metric,
                '--iterations', param.iter,
                '--shrinkFactors', '1',
                '--smoothingSigmas', param.smooth,
                '--verbose', '1',
                '--output', output,
            ]
            cmd += get_interpolation('isct_antsSliceRegularizedRegistration', param.interp)
            if im_mask is not None:
                cmd += ['--mask', im_mask.absolutepath]

        # run command
        if do_registration:
            # reducing the number of CPU used for moco (see issue #201 and #2642)
            env = {**os.environ, "ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS": "1"}
            # the returned status is deliberately ignored: see the note on output-file checking below
            run_proc(cmd, verbose=1 if param.verbose == 2 else 0, env=env, is_sct_binary=True)

    elif param.todo == 'apply':
        sct_apply_transfo.main(argv=[
            '-i', file_src,
            '-d', file_dest,
            '-w', file_mat + param.suffix_mat,
            '-o', file_out_concat,
            '-x', param.interp,
            '-v', '0',
        ])

    # check if output file exists
    # Note (from JCA): In the past, i've tried to catch non-zero output from ANTs function (via the 'status' variable),
    # but in some OSs, the function can fail while outputing zero. So as a pragmatic approach, I decided to go with
    # the "output file checking" approach, which is 100% sensitive.
    if not os.path.isfile(file_out_concat):
        printv(
            'WARNING in ' + os.path.basename(__file__) + ': No output. Maybe related to improper calculation of '
            'mutual information. Either the mask you provided is '
            'too small, or the subject moved a lot. If you see too '
            'many messages like this try with a bigger mask. '
            'Using previous transformation for this volume (if it '
            'exists).', param.verbose, 'warning')
        failed_transfo = 1

    # If sagittal, copy header (because ANTs screws it) and add singleton in 3rd dimension (for z-concatenation).
    # Skipped when the output is missing: opening it would raise and hide the failure status from the caller.
    if is_sagittal and do_registration and not failed_transfo:
        im_out = Image(file_out_concat)
        im_out.header = im_data.header
        im_out.data = np.expand_dims(im_out.data, 2)
        im_out.save(file_out, verbose=0)

    # return status of failure
    return failed_transfo
def moco(param):
    """
    Main function that performs motion correction.

    The data are optionally split along Z (sagittal acquisitions), then split along T. Each volume is passed to
    ``register()``; transformations that failed are replaced by the closest (in time) successful one. Unless
    ``param.todo`` is 'estimate', the corrected volumes are merged back along T (and along Z if sagittal).

    :param param: parameter object; reads file_data, file_target, mat_moco, todo, suffix, verbose, poly, smooth,
        gradStep, metric, sampling, fname_mask, is_sagittal, iterAvg and interp.
    :return: (file_mat, im_out). ``file_mat`` is an object array of transformation prefixes, of shape (nz, nt) if
        sagittal and (1, nt) otherwise. ``im_out`` is the merged motion-corrected image, or None when
        ``param.todo`` is 'estimate' (nothing is merged in that mode).
    """
    # retrieve parameters
    file_data = param.file_data
    folder_mat = param.mat_moco  # output folder of mat file
    todo = param.todo
    suffix = param.suffix
    verbose = param.verbose

    # other parameters
    file_mask = 'mask.nii'
    has_mask = not param.fname_mask == ''

    printv('\nInput parameters:', param.verbose)
    printv(' Input file ............ ' + file_data, param.verbose)
    printv(' Reference file ........ ' + param.file_target, param.verbose)
    printv(' Polynomial degree ..... ' + param.poly, param.verbose)
    printv(' Smoothing kernel ...... ' + param.smooth, param.verbose)
    printv(' Gradient step ......... ' + param.gradStep, param.verbose)
    printv(' Metric ................ ' + param.metric, param.verbose)
    printv(' Sampling .............. ' + param.sampling, param.verbose)
    printv(' Todo .................. ' + todo, param.verbose)
    printv(' Mask ................. ' + param.fname_mask, param.verbose)
    printv(' Output mat folder ..... ' + folder_mat, param.verbose)

    try:
        os.makedirs(folder_mat)
    except FileExistsError:
        pass

    # Get size of data
    printv('\nData dimensions:', verbose)
    im_data = Image(param.file_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    printv((' ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt)), verbose)

    # copy file_target to a temporary file
    printv('\nCopy file_target to a temporary file...', verbose)
    file_target = "target.nii.gz"
    convert(param.file_target, file_target, verbose=0)

    # Check if user specified a mask
    if has_mask:
        # Check if this mask is soft (i.e., non-binary, such as a Gaussian mask)
        im_mask = Image(param.fname_mask)
        if not np.array_equal(im_mask.data, im_mask.data.astype(bool)):
            # If it is a soft mask, multiply the target by the soft mask.
            im = Image(file_target)
            im_masked = im.copy()
            im_masked.data = im.data * im_mask.data
            im_masked.save(verbose=0)  # silence warning about file overwritting

    # If scan is sagittal, split src and target along Z (slice)
    if param.is_sagittal:
        dim_sag = 2  # TODO: find it
        # z-split data (time series)
        file_data_splitZ = []
        for im_z in split_data(im_data, dim=dim_sag, squeeze_data=False):
            im_z.save(verbose=0)
            file_data_splitZ.append(im_z.absolutepath)
        # z-split target
        file_target_splitZ = []
        for im_targetz in split_data(Image(file_target), dim=dim_sag, squeeze_data=False):
            im_targetz.save(verbose=0)
            file_target_splitZ.append(im_targetz.absolutepath)
        # z-split mask (if exists)
        # NOTE(review): file_mask is not created in this function for the sagittal case; presumably the caller
        # has already written it to the working directory -- confirm.
        if has_mask:
            im_maskz_list = split_data(Image(file_mask), dim=dim_sag, squeeze_data=False)
            for im_maskz in im_maskz_list:
                im_maskz.save(verbose=0)
        # initialize file list for output matrices
        file_mat = np.empty((nz, nt), dtype=object)

    # axial orientation
    else:
        file_data_splitZ = [file_data]  # TODO: make it absolute like above
        file_target_splitZ = [file_target]  # TODO: make it absolute like above
        # initialize file list for output matrices
        file_mat = np.empty((1, nt), dtype=object)

        # deal with mask
        if has_mask:
            convert(param.fname_mask, file_mask, squeeze_data=False, verbose=0)
            im_maskz_list = [Image(file_mask)]  # use a list with single element

    # Loop across file list, where each file is either a 2D volume (if sagittal) or a 3D volume (otherwise)
    file_data_splitZ_moco = []
    # stays None when nothing is merged (todo == 'estimate'), so the final return is always well defined
    im_out = None
    printv('\nRegister. Loop across Z (note: there is only one Z if orientation is axial)')
    for iz, fname_z in enumerate(file_data_splitZ):
        # Split data along T dimension
        file_data_splitZ_splitT = []
        for im_zt in split_data(Image(fname_z), dim=3):
            im_zt.save(verbose=0)
            file_data_splitZ_splitT.append(im_zt.absolutepath)

        # Motion correction: initialization
        file_data_splitZ_splitT_moco = []
        failed_transfo = [0] * nt

        # Motion correction: Loop across T
        for it in sct_progress_bar(range(nt), unit='iter', unit_scale=False,
                                   desc="Z=" + str(iz) + "/" + str(len(file_data_splitZ) - 1),
                                   ascii=False, ncols=80):
            # prefix of the transformation and name of the corrected volume
            file_mat[iz][it] = os.path.join(folder_mat, "mat.Z") + str(iz).zfill(4) + 'T' + str(it).zfill(4)
            file_data_splitZ_splitT_moco.append(add_suffix(file_data_splitZ_splitT[it], '_moco'))

            # deal with masking (except in the 'apply' case, where masking is irrelevant)
            input_mask = None
            if has_mask and not param.todo == 'apply':
                # Check if mask is binary
                if np.array_equal(im_maskz_list[iz].data, im_maskz_list[iz].data.astype(bool)):
                    # If it is, pass this mask into register() to be used
                    input_mask = im_maskz_list[iz]
                else:
                    # If not, do not pass this mask into register() because ANTs cannot handle non-binary masks.
                    # Instead, multiply the input data by the Gaussian mask.
                    im = Image(file_data_splitZ_splitT[it])
                    im_masked = im.copy()
                    im_masked.data = im.data * im_maskz_list[iz].data
                    im_masked.save(verbose=0)  # silence warning about file overwritting

            # run 3D registration
            failed_transfo[it] = register(param, file_data_splitZ_splitT[it], file_target_splitZ[iz],
                                          file_mat[iz][it], file_data_splitZ_splitT_moco[it],
                                          im_mask=input_mask)

            # average registered volume with target image
            # N.B. use weighted averaging: (target * nb_it + moco) / (nb_it + 1)
            if param.iterAvg and it < 10 and failed_transfo[it] == 0 and not param.todo == 'apply':
                im_targetz = Image(file_target_splitZ[iz])
                data_mocoz = Image(file_data_splitZ_splitT_moco[it]).data
                im_targetz.data = (im_targetz.data * (it + 1) + data_mocoz) / (it + 2)
                im_targetz.save(verbose=0)

        # Replace failed transformation with the closest good one
        fT = [i for i, j in enumerate(failed_transfo) if j == 1]
        gT = [i for i, j in enumerate(failed_transfo) if j == 0]
        for i_failed in fT:
            if not gT:
                # exit program if no transformation exists.
                printv('\nERROR in ' + os.path.basename(__file__) + ': No good transformation exist. Exit program.\n',
                       verbose, 'error')
                sys.exit(2)
            # closest successful volume in time (first one wins on ties)
            i_good = min(gT, key=lambda i: abs(i - i_failed))
            printv(' transfo #' + str(i_failed) + ' --> use transfo #' + str(i_good), verbose)
            # copy transformation
            copy(file_mat[iz][i_good] + 'Warp.nii.gz', file_mat[iz][i_failed] + 'Warp.nii.gz')
            # apply transformation
            sct_apply_transfo.main(argv=[
                '-i', file_data_splitZ_splitT[i_failed],
                '-d', file_target,
                '-w', file_mat[iz][i_failed] + 'Warp.nii.gz',
                '-o', file_data_splitZ_splitT_moco[i_failed],
                '-x', param.interp,
            ])

        # Merge data along T
        file_data_splitZ_moco.append(add_suffix(fname_z, suffix))
        if todo != 'estimate':
            im_data_splitZ_splitT_moco = [Image(fname) for fname in file_data_splitZ_splitT_moco]
            im_out = concat_data(im_data_splitZ_splitT_moco, 3)
            im_out.absolutepath = file_data_splitZ_moco[iz]
            im_out.save(verbose=0)

    # If sagittal, merge along Z. The per-Z files only exist if they were merged along T above.
    if param.is_sagittal and todo != 'estimate':
        # TODO: im_out.dim is incorrect: Z value is one
        im_data_splitZ_moco = [Image(fname) for fname in file_data_splitZ_moco]
        im_out = concat_data(im_data_splitZ_moco, 2)
        dirname, basename, ext = extract_fname(file_data)
        im_out.absolutepath = os.path.join(dirname, basename + suffix + ext)
        im_out.save(verbose=0)

    return file_mat, im_out
def merge_images(list_fname_src, fname_dest, list_fname_warp, param):
    """
    Merge multiple source images onto destination space. All images are warped to the destination space and then
    added. To deal with overlap during merging (e.g. one voxel in destination image is shared with two input images),
    the resulting voxel is divided by the sum of the partial volume of each image. For example, if src(x,y,z)=1 is
    mapped to dest(i,j,k) with a partial volume of 0.5 (because destination voxel is bigger), then its value after
    linear interpolation will be 0.5. To account for partial volume, the resulting voxel will be:
    dest(i,j,k) = 0.5*0.5/0.5 = 0.5. Now, if two voxels overlap in the destination space, let's say:
    src(x,y,z)=1 and src2'(x',y',z')=1, then the resulting value will be:
    dest(i,j,k) = (0.5*0.5 + 0.5*0.5) / (0.5+0.5) = 0.5. So this function acts like a weighted average operator,
    only in destination voxels that share multiple source voxels.

    Intermediate files are written to a temporary folder, which is removed at the end if ``param.rm_tmp`` is set.

    Parameters
    ----------
    list_fname_src
        File names of the source images.
    fname_dest
        File name of the destination image; defines the output grid and is used to write the result.
    list_fname_warp
        Warping fields, one per source image (same order as ``list_fname_src``).
    param
        Parameter object; reads interp, verbose, almost_zero, fname_out and rm_tmp.

    Returns
    -------
    None
        The merged image is saved to ``param.fname_out``.
    """
    # create temporary folder
    path_tmp = tmp_create()

    # get dimensions of destination file
    nii_dest = Image(fname_dest)
    shape_4d = [nii_dest.dim[0], nii_dest.dim[1], nii_dest.dim[2], len(list_fname_src)]

    # one 3D volume per source image, stacked along the 4th dimension
    data = np.zeros(shape_4d)
    partial_volume = np.zeros(shape_4d)

    # loop across files
    for i_file, fname_src in enumerate(list_fname_src):
        fname_warp = list_fname_warp[i_file]
        # intermediate files live in the temporary folder so they do not pollute the working directory
        fname_src_dest = os.path.join(path_tmp, 'src_' + str(i_file) + '_template.nii.gz')
        fname_bin = os.path.join(path_tmp, 'src_' + str(i_file) + 'native_bin.nii.gz')
        fname_pv = os.path.join(path_tmp, 'src_' + str(i_file) + '_template_partialVolume.nii.gz')

        # apply transformation src --> dest
        sct_apply_transfo.main(argv=[
            '-i', fname_src,
            '-d', fname_dest,
            '-w', fname_warp,
            '-x', param.interp,
            '-o', fname_src_dest,
            '-v', str(param.verbose),
        ])

        # create binary mask from input file by assigning one to all non-null voxels
        out = Image(fname_src).copy()
        out.data = binarize(out.data, param.almost_zero)
        out.save(path=fname_bin)

        # apply transformation to binary mask to compute partial volume
        sct_apply_transfo.main(argv=[
            '-i', fname_bin,
            '-d', fname_dest,
            '-w', fname_warp,
            '-x', param.interp,
            '-o', fname_pv,
        ])

        # open data
        data[:, :, :, i_file] = Image(fname_src_dest).data
        partial_volume[:, :, :, i_file] = Image(fname_pv).data

    # merge files using partial volume information (and convert nan resulting from division by zero to zeros).
    # Division by zero is expected wherever no source image contributes, so the numpy warnings are silenced.
    with np.errstate(divide='ignore', invalid='ignore'):
        data_merge = np.divide(np.sum(data * partial_volume, axis=3), np.sum(partial_volume, axis=3))
    data_merge = np.nan_to_num(data_merge)

    # write result in file
    nii_dest.data = data_merge
    nii_dest.save(param.fname_out)

    # remove temporary folder
    if param.rm_tmp:
        rmtree(path_tmp)
def main(argv=None):
    """
    Command-line entry point: label the vertebral levels of a spinal cord segmentation.

    Steps: copy the inputs to a temporary folder, straighten the cord (reusing cached warping fields from the
    current directory when valid), resample, initialise the labeling (disc file, ``-initz``/``-initcenter``,
    initial label file, or automatic C2-C3 detection), run the vertebral detection, un-straighten the labeled
    files, and write the outputs (plus an optional QC report) to the output folder.

    :param argv: list of command-line arguments; forwarded to the argument parser.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv)
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    # initializations
    initz = ''
    initcenter = ''
    fname_initlabel = ''
    file_labelz = 'labelz.nii.gz'
    param = Param()

    fname_in = os.path.abspath(arguments.i)
    fname_seg = os.path.abspath(arguments.s)
    contrast = arguments.c
    path_template = os.path.abspath(arguments.t)
    scale_dist = arguments.scale_dist
    path_output = arguments.ofolder
    param.path_qc = arguments.qc
    if arguments.discfile is not None:
        fname_disc = os.path.abspath(arguments.discfile)
    else:
        fname_disc = None
    if arguments.initz is not None:
        initz = arguments.initz
        if len(initz) != 2:
            raise ValueError('--initz takes two arguments: position in superior-inferior direction, label value')
    if arguments.initcenter is not None:
        initcenter = arguments.initcenter
    # if user provided text file, parse and overwrite arguments
    if arguments.initfile is not None:
        # context manager so the file handle is always closed
        with open(arguments.initfile, 'r') as file_init:
            initfile = ' ' + file_init.read().replace('\n', '')
        arg_initfile = initfile.split(' ')
        for idx_arg, arg in enumerate(arg_initfile):
            if arg == '-initz':
                initz = [int(x) for x in arg_initfile[idx_arg + 1].split(',')]
                if len(initz) != 2:
                    raise ValueError('--initz takes two arguments: position in superior-inferior direction, label value')
            if arg == '-initcenter':
                initcenter = int(arg_initfile[idx_arg + 1])
    if arguments.initlabel is not None:
        # get absolute path of label
        fname_initlabel = os.path.abspath(arguments.initlabel)
    if arguments.param is not None:
        param.update(arguments.param[0])
    remove_temp_files = arguments.r
    clean_labels = arguments.clean_labels
    # named so that it does not shadow the laplacian() filter function called further down
    apply_laplacian = arguments.laplacian

    path_tmp = tmp_create(basename="label_vertebrae")

    # Copying input data to tmp folder
    printv('\nCopying input data to tmp folder...', verbose)
    Image(fname_in).save(os.path.join(path_tmp, "data.nii"))
    Image(fname_seg).save(os.path.join(path_tmp, "segmentation.nii"))

    # Go go temp folder
    curdir = os.getcwd()
    os.chdir(path_tmp)

    # Straighten spinal cord
    printv('\nStraighten spinal cord...', verbose)

    # check if warp_curve2straight and warp_straight2curve already exist (i.e. no need to do it another time)
    cache_sig = cache_signature(
        input_files=[fname_in, fname_seg],
    )
    cachefile = os.path.join(curdir, "straightening.cache")
    straightening_files = ("warp_curve2straight.nii.gz", "warp_straight2curve.nii.gz", "straight_ref.nii.gz")
    if cache_valid(cachefile, cache_sig) and all(
            os.path.isfile(os.path.join(curdir, fname)) for fname in straightening_files):
        # if they exist, copy them into current folder
        printv('Reusing existing warping field which seems to be valid', verbose, 'warning')
        for fname in straightening_files:
            copy(os.path.join(curdir, fname), fname)
        # apply straightening
        run_proc(['sct_apply_transfo', '-i', 'data.nii', '-w', 'warp_curve2straight.nii.gz',
                  '-d', 'straight_ref.nii.gz', '-o', 'data_straight.nii'])
    else:
        sct_straighten_spinalcord.main(argv=[
            '-i', 'data.nii',
            '-s', 'segmentation.nii',
            '-r', str(remove_temp_files),
            '-v', '0',
        ])
        cache_save(cachefile, cache_sig)

    # resample to 0.5mm isotropic to match template resolution
    printv('\nResample to 0.5mm isotropic...', verbose)
    run_proc(['sct_resample', '-i', 'data_straight.nii', '-mm', '0.5x0.5x0.5', '-x', 'linear',
              '-o', 'data_straightr.nii'], verbose=verbose)

    # Apply straightening to segmentation
    # N.B. Output is RPI
    printv('\nApply straightening to segmentation...', verbose)
    sct_apply_transfo.main(['-i', 'segmentation.nii',
                            '-d', 'data_straightr.nii',
                            '-w', 'warp_curve2straight.nii.gz',
                            '-o', 'segmentation_straight.nii',
                            '-x', 'linear',
                            '-v', '0'])

    # Threshold segmentation at 0.5
    img = Image('segmentation_straight.nii')
    img.data = threshold(img.data, 0.5)
    img.save()

    # If disc label file is provided, label vertebrae using that file instead of automatically
    if fname_disc:
        # Apply straightening to disc-label
        printv('\nApply straightening to disc labels...', verbose)
        # list form (not a formatted shell string) so that paths containing spaces are passed intact
        run_proc(['sct_apply_transfo',
                  '-i', fname_disc,
                  '-d', 'data_straightr.nii',
                  '-w', 'warp_curve2straight.nii.gz',
                  '-o', 'labeldisc_straight.nii.gz',
                  '-x', 'label'],
                 verbose=verbose)
        label_vert('segmentation_straight.nii', 'labeldisc_straight.nii.gz', verbose=1)

    else:
        # create label to identify disc
        printv('\nCreate label to identify disc...', verbose)
        fname_labelz = os.path.join(path_tmp, file_labelz)
        if initz or initcenter:
            if initcenter:
                # find z centered in FOV
                nii = Image('segmentation.nii').change_orientation("RPI")
                nx, ny, nz, nt, px, py, pz, pt = nii.dim  # Get dimensions
                z_center = int(np.round(nz / 2))  # get z_center
                initz = [z_center, initcenter]

            im_label = create_labels_along_segmentation(Image('segmentation.nii'), [(initz[0], initz[1])])
            im_label.data = dilate(im_label.data, 3, 'ball')
            im_label.save(fname_labelz)

        elif fname_initlabel:
            Image(fname_initlabel).save(fname_labelz)

        else:
            # automatically finds C2-C3 disc
            im_data = Image('data.nii')
            im_seg = Image('segmentation.nii')
            # because verbose is here also used for keeping temp files
            verbose_detect_c2c3 = 0 if remove_temp_files else 2
            im_label_c2c3 = detect_c2c3(im_data, im_seg, contrast, verbose=verbose_detect_c2c3)
            ind_label = np.where(im_label_c2c3.data)
            if np.size(ind_label) == 0:
                printv('Automatic C2-C3 detection failed. Please provide manual label with sct_label_utils',
                       1, 'error')
                # non-zero status: this is a failure, not a normal termination
                sys.exit(1)
            im_label_c2c3.data[ind_label] = 3
            im_label_c2c3.save(fname_labelz)

        # dilate label so it is not lost when applying warping
        dilate(Image(fname_labelz), 3, 'ball').save(fname_labelz)

        # Apply straightening to z-label
        printv('\nAnd apply straightening to label...', verbose)
        sct_apply_transfo.main(['-i', file_labelz,
                                '-d', 'data_straightr.nii',
                                '-w', 'warp_curve2straight.nii.gz',
                                '-o', 'labelz_straight.nii.gz',
                                '-x', 'nn',
                                '-v', '0'])
        # get z value and disk value to initialize labeling
        printv('\nGet z and disc values from straight label...', verbose)
        init_disc = get_z_and_disc_values_from_label('labelz_straight.nii.gz')
        printv('.. ' + str(init_disc), verbose)

        # apply laplacian filtering
        if apply_laplacian:
            printv('\nApply Laplacian filter...', verbose)
            img = Image("data_straightr.nii")
            # std dev of 1 on each of the 3 spatial axes, adjusted by the voxel size (img.dim[4:7])
            sigmas = [1 / img.dim[i + 4] for i in range(3)]
            # smooth data
            img.data = laplacian(img.data, sigmas)
            img.save()

        # detect vertebral levels on straight spinal cord
        init_disc[1] = init_disc[1] - 1
        vertebral_detection('data_straightr.nii', 'segmentation_straight.nii', contrast, param, init_disc=init_disc,
                            verbose=verbose, path_template=path_template, path_output=path_output,
                            scale_dist=scale_dist)

    # un-straighten labeled spinal cord
    printv('\nUn-straighten labeling...', verbose)
    sct_apply_transfo.main(['-i', 'segmentation_straight_labeled.nii',
                            '-d', 'segmentation.nii',
                            '-w', 'warp_straight2curve.nii.gz',
                            '-o', 'segmentation_labeled.nii',
                            '-x', 'nn',
                            '-v', '0'])

    if clean_labels:
        # Clean labeled segmentation
        printv('\nClean labeled segmentation (correct interpolation errors)...', verbose)
        clean_labeled_segmentation('segmentation_labeled.nii', 'segmentation.nii', 'segmentation_labeled.nii')

    # label discs
    printv('\nLabel discs...', verbose)
    printv('\nUn-straighten labeled discs...', verbose)
    run_proc(['sct_apply_transfo',
              '-i', 'segmentation_straight_labeled_disc.nii',
              '-d', 'segmentation.nii',
              '-w', 'warp_straight2curve.nii.gz',
              '-o', 'segmentation_labeled_disc.nii',
              '-x', 'label'],
             verbose=verbose,
             is_sct_binary=True,
             )

    # come back
    os.chdir(curdir)

    # Generate output files
    path_seg, file_seg, ext_seg = extract_fname(fname_seg)
    fname_seg_labeled = os.path.join(path_output, file_seg + '_labeled' + ext_seg)
    printv('\nGenerate output files...', verbose)
    generate_output_file(os.path.join(path_tmp, "segmentation_labeled.nii"), fname_seg_labeled)
    generate_output_file(os.path.join(path_tmp, "segmentation_labeled_disc.nii"),
                         os.path.join(path_output, file_seg + '_labeled_discs' + ext_seg))
    # copy straightening files in case subsequent SCT functions need them
    for fname in ("warp_curve2straight.nii.gz", "warp_straight2curve.nii.gz", "straight_ref.nii.gz"):
        generate_output_file(os.path.join(path_tmp, fname), os.path.join(path_output, fname), verbose=verbose)

    # Remove temporary files
    if remove_temp_files == 1:
        printv('\nRemove temporary files...', verbose)
        rmtree(path_tmp)

    # Generate QC report
    if param.path_qc is not None:
        path_qc = os.path.abspath(arguments.qc)
        qc_dataset = arguments.qc_dataset
        qc_subject = arguments.qc_subject
        labeled_seg_file = os.path.join(path_output, file_seg + '_labeled' + ext_seg)
        generate_qc(fname_in, fname_seg=labeled_seg_file, args=argv, path_qc=os.path.abspath(path_qc),
                    dataset=qc_dataset, subject=qc_subject, process='sct_label_vertebrae')

    display_viewer_syntax([fname_in, fname_seg_labeled], colormaps=['', 'subcortical'], opacities=['1', '0.5'])
def straighten(self):
    """
    Straighten the spinal cord and generate the curved<->straight warping fields.

    Steps (everything is done in physical space):
      1. open input image and centerline image
      2. extract b-spline fitting of the centerline, and its derivatives
      3. compute length of centerline
      4. compute and generate straight space
      5. compute transformations for each voxel of one space
         (done using matrices --> improves speed by a factor x300):
           a. determine which plane of the spinal cord centerline it is included in
           b. compute the position of the voxel in the plane (X and Y distance from centerline, along the plane)
           c. find the corresponding centerline point in the other space
           d. find the correspondence of the voxel in the corresponding plane
      6. generate warping fields for each transformation
      7. write warping fields and apply them

    Step 5.a -- how to find the corresponding plane?
      The centerline plane corresponding to a voxel is the one of the nearest centerline point. The distance
      between the voxel and that plane is also computed: if it is farther than ``self.threshold_distance``, the
      voxel is flagged as out of range and its displacement is set to a large constant (100000.0).

    Step 5.c -- how to match centerline points between both images?
      Both centerlines have the same length, so points are mapped via their position along the curve. When disc
      labels are provided, look-up tables between both centerlines are built from the labels instead.

    The method writes the warping fields (and, if ``self.curved2straight`` is set, the straightened image and
    ``straight_ref.nii.gz``) into ``self.path_output``. It updates ``self.elapsed_time`` and, when
    ``self.accuracy_results`` is set, ``self.mse_straightening``, ``self.max_distance_straightening`` and
    ``self.elapsed_time_accuracy``.

    :return: value returned by ``generate_output_file`` for the straightened image, or None when
             ``self.curved2straight`` is false (no straightened image is produced in that case).
    """
    # Initialization
    fname_anat = self.input_filename
    fname_centerline = self.centerline_filename
    fname_output = self.output_filename
    remove_temp_files = self.remove_temp_files
    verbose = self.verbose

    # The straightened image is only produced when self.curved2straight is set; default to None so that the final
    # return does not raise UnboundLocalError otherwise.
    fname_straight = None

    # start timer
    start_time = time.time()

    # Extract file/extension (the folder part is not needed here)
    _, file_anat, ext_anat = extract_fname(fname_anat)

    path_tmp = tmp_create(basename="straighten_spinalcord")

    # Copying input data to tmp folder
    logger.info('Copy files to tmp folder...')
    Image(fname_anat, check_sform=True).save(os.path.join(path_tmp, "data.nii"))
    Image(fname_centerline, check_sform=True).save(os.path.join(path_tmp, "centerline.nii.gz"))

    if self.use_straight_reference:
        Image(self.centerline_reference_filename, check_sform=True).save(
            os.path.join(path_tmp, "centerline_ref.nii.gz"))
    if self.discs_input_filename != '':
        Image(self.discs_input_filename, check_sform=True).save(os.path.join(path_tmp, "labels_input.nii.gz"))
    if self.discs_ref_filename != '':
        Image(self.discs_ref_filename, check_sform=True).save(os.path.join(path_tmp, "labels_ref.nii.gz"))

    # go to tmp folder
    curdir = os.getcwd()
    os.chdir(path_tmp)

    # Change orientation of the input centerline into RPI
    image_centerline = Image("centerline.nii.gz").change_orientation("RPI").save("centerline_rpi.nii.gz",
                                                                                 mutable=True)

    # Get dimension
    nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim
    if self.speed_factor != 1.0:
        intermediate_resampling = True
        px_r, py_r, pz_r = px * self.speed_factor, py * self.speed_factor, pz * self.speed_factor
    else:
        intermediate_resampling = False

    if intermediate_resampling:
        # keep the native-resolution centerline aside and work on a resampled copy
        mv('centerline_rpi.nii.gz', 'centerline_rpi_native.nii.gz')
        pz_native = pz
        # TODO: remove system call
        run_proc(['sct_resample', '-i', 'centerline_rpi_native.nii.gz', '-mm',
                  str(px_r) + 'x' + str(py_r) + 'x' + str(pz_r), '-o', 'centerline_rpi.nii.gz'])
        image_centerline = Image('centerline_rpi.nii.gz')
        nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim

    # clip centerline values to [0, 1]
    if np.min(image_centerline.data) < 0 or np.max(image_centerline.data) > 1:
        image_centerline.data[image_centerline.data < 0] = 0
        image_centerline.data[image_centerline.data > 1] = 1
        image_centerline.save()

    # 2. extract bspline fitting of the centerline, and its derivatives
    img_ctl = Image('centerline_rpi.nii.gz')
    centerline = _get_centerline(img_ctl, self.param_centerline, verbose)
    number_of_points = centerline.number_of_points

    # ==========================================================================================
    logger.info('Create the straight space and the safe zone')
    # 3. compute length of centerline
    # compute the length of the spinal cord based on fitted centerline and size of centerline in z direction

    # Computation of the safe zone.
    # The safe zone is defined as the length of the spinal cord for which an axial segmentation will be complete
    # The safe length (to remove) is computed using the safe radius (given as parameter) and the angle of the
    # last centerline point with the inferior-superior direction. Formula: Ls = Rs * sin(angle)
    # Calculate Ls for both edges and remove appropriate number of centerline points
    radius_safe = 0.0  # mm

    # inferior edge
    u = centerline.derivatives[0]
    v = np.array([0, 0, -1])
    angle_inferior = np.arctan2(np.linalg.norm(np.cross(u, v)), np.dot(u, v))
    length_safe_inferior = radius_safe * np.sin(angle_inferior)

    # superior edge
    u = centerline.derivatives[-1]
    v = np.array([0, 0, 1])
    angle_superior = np.arctan2(np.linalg.norm(np.cross(u, v)), np.dot(u, v))
    length_safe_superior = radius_safe * np.sin(angle_superior)

    # remove points
    inferior_bound = bisect.bisect(centerline.progressive_length, length_safe_inferior) - 1
    superior_bound = centerline.number_of_points - bisect.bisect(centerline.progressive_length_inverse,
                                                                 length_safe_superior)

    z_centerline = centerline.points[:, 2]
    length_centerline = centerline.length
    size_z_centerline = z_centerline[-1] - z_centerline[0]

    # compute the size factor between initial centerline and straight bended centerline
    factor_curved_straight = length_centerline / size_z_centerline
    middle_slice = (z_centerline[0] + z_centerline[-1]) / 2.0

    bound_curved = [z_centerline[inferior_bound], z_centerline[superior_bound]]
    bound_straight = [(z_centerline[inferior_bound] - middle_slice) * factor_curved_straight + middle_slice,
                      (z_centerline[superior_bound] - middle_slice) * factor_curved_straight + middle_slice]

    logger.info('Length of spinal cord: {}'.format(length_centerline))
    logger.info('Size of spinal cord in z direction: {}'.format(size_z_centerline))
    logger.info('Ratio length/size: {}'.format(factor_curved_straight))
    logger.info('Safe zone boundaries (curved space): {}'.format(bound_curved))
    logger.info('Safe zone boundaries (straight space): {}'.format(bound_straight))

    # 4. compute and generate straight space
    # points along curved centerline are already regularly spaced.
    # calculate position of points along straight centerline

    # Create straight NIFTI volumes.
    # ==========================================================================================
    # TODO: maybe this if case is not needed?
    if self.use_straight_reference:
        image_centerline_pad = Image('centerline_rpi.nii.gz')
        nx, ny, nz, nt, px, py, pz, pt = image_centerline_pad.dim

        fname_ref = 'centerline_ref_rpi.nii.gz'
        image_centerline_straight = Image('centerline_ref.nii.gz') \
            .change_orientation("RPI") \
            .save(fname_ref, mutable=True)
        centerline_straight = _get_centerline(image_centerline_straight, self.param_centerline, verbose)
        nx_s, ny_s, nz_s, nt_s, px_s, py_s, pz_s, pt_s = image_centerline_straight.dim

        # Prepare warping fields headers
        hdr_warp = image_centerline_pad.hdr.copy()
        hdr_warp.set_data_dtype('float32')
        hdr_warp_s = image_centerline_straight.hdr.copy()
        hdr_warp_s.set_data_dtype('float32')

        if self.discs_input_filename != "" and self.discs_ref_filename != "":
            # attach the disc labels (in physical coordinates) to the curved centerline...
            discs_input_image = Image('labels_input.nii.gz')
            coord = discs_input_image.getNonZeroCoordinates(sorting='z', reverse_coord=True)
            coord_physical = []
            for c in coord:
                c_p = discs_input_image.transfo_pix2phys([[c.x, c.y, c.z]]).tolist()[0]
                c_p.append(c.value)
                coord_physical.append(c_p)
            centerline.compute_vertebral_distribution(coord_physical)
            centerline.save_centerline(image=discs_input_image, fname_output='discs_input_image.nii.gz')

            # ... and to the reference (straight) centerline
            discs_ref_image = Image('labels_ref.nii.gz')
            coord = discs_ref_image.getNonZeroCoordinates(sorting='z', reverse_coord=True)
            coord_physical = []
            for c in coord:
                c_p = discs_ref_image.transfo_pix2phys([[c.x, c.y, c.z]]).tolist()[0]
                c_p.append(c.value)
                coord_physical.append(c_p)
            centerline_straight.compute_vertebral_distribution(coord_physical)
            centerline_straight.save_centerline(image=discs_ref_image, fname_output='discs_ref_image.nii.gz')

    else:
        logger.info('Pad input volume to account for spinal cord length...')

        start_point, end_point = bound_straight[0], bound_straight[1]
        offset_z = 0

        # if the destination image is resampled, we still create the straight reference space with the native
        # resolution.
        # TODO: Maybe this if case is not needed?
        if intermediate_resampling:
            padding_z = int(np.ceil(1.5 * ((length_centerline - size_z_centerline) / 2.0) / pz_native))
            run_proc(
                ['sct_image', '-i', 'centerline_rpi_native.nii.gz', '-o', 'tmp.centerline_pad_native.nii.gz',
                 '-pad', '0,0,' + str(padding_z)])
            image_centerline_pad = Image('centerline_rpi_native.nii.gz')
            nx, ny, nz, nt, px, py, pz, pt = image_centerline_pad.dim
            start_point_coord_native = image_centerline_pad.transfo_phys2pix([[0, 0, start_point]])[0]
            end_point_coord_native = image_centerline_pad.transfo_phys2pix([[0, 0, end_point]])[0]
            straight_size_x = int(self.xy_size / px)
            straight_size_y = int(self.xy_size / py)
            warp_space_x = [int(np.round(nx / 2)) - straight_size_x, int(np.round(nx / 2)) + straight_size_x]
            warp_space_y = [int(np.round(ny / 2)) - straight_size_y, int(np.round(ny / 2)) + straight_size_y]
            if warp_space_x[0] < 0:
                warp_space_x[1] += warp_space_x[0] - 2
                warp_space_x[0] = 0
            if warp_space_y[0] < 0:
                warp_space_y[1] += warp_space_y[0] - 2
                warp_space_y[0] = 0

            spec = dict((
                (0, warp_space_x),
                (1, warp_space_y),
                (2, (0, end_point_coord_native[2] - start_point_coord_native[2])),
            ))
            spatial_crop(Image("tmp.centerline_pad_native.nii.gz"), spec).save(
                "tmp.centerline_pad_crop_native.nii.gz")

            fname_ref = 'tmp.centerline_pad_crop_native.nii.gz'
            offset_z = 4
        else:
            fname_ref = 'tmp.centerline_pad_crop.nii.gz'

        nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim
        padding_z = int(np.ceil(1.5 * ((length_centerline - size_z_centerline) / 2.0) / pz)) + offset_z
        image_centerline_pad = pad_image(image_centerline, pad_z_i=padding_z, pad_z_f=padding_z)
        nx, ny, nz = image_centerline_pad.data.shape
        hdr_warp = image_centerline_pad.hdr.copy()
        hdr_warp.set_data_dtype('float32')
        start_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, start_point]])[0]
        end_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, end_point]])[0]

        straight_size_x = int(self.xy_size / px)
        straight_size_y = int(self.xy_size / py)
        warp_space_x = [int(np.round(nx / 2)) - straight_size_x, int(np.round(nx / 2)) + straight_size_x]
        warp_space_y = [int(np.round(ny / 2)) - straight_size_y, int(np.round(ny / 2)) + straight_size_y]

        # keep the crop window inside the padded volume
        if warp_space_x[0] < 0:
            warp_space_x[1] += warp_space_x[0] - 2
            warp_space_x[0] = 0
        if warp_space_x[1] >= nx:
            warp_space_x[1] = nx - 1
        if warp_space_y[0] < 0:
            warp_space_y[1] += warp_space_y[0] - 2
            warp_space_y[0] = 0
        if warp_space_y[1] >= ny:
            warp_space_y[1] = ny - 1

        spec = dict((
            (0, warp_space_x),
            (1, warp_space_y),
            (2, (0, end_point_coord[2] - start_point_coord[2] + offset_z)),
        ))
        image_centerline_straight = spatial_crop(image_centerline_pad, spec)

        nx_s, ny_s, nz_s, nt_s, px_s, py_s, pz_s, pt_s = image_centerline_straight.dim
        hdr_warp_s = image_centerline_straight.hdr.copy()
        hdr_warp_s.set_data_dtype('float32')

        if self.template_orientation == 1:
            raise NotImplementedError()

        start_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, start_point]])[0]
        end_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, end_point]])[0]

        number_of_voxel = nx * ny * nz
        logger.debug('Number of voxels: {}'.format(number_of_voxel))

        time_centerlines = time.time()

        # Build the straight centerline: a vertical line through the middle of the straight volume, with the same
        # number of points as the curved centerline and a constant derivative (0, 0, 1).
        coord_straight = np.empty((number_of_points, 3))
        coord_straight[..., 0] = int(np.round(nx_s / 2))
        coord_straight[..., 1] = int(np.round(ny_s / 2))
        coord_straight[..., 2] = np.linspace(0, end_point_coord[2] - start_point_coord[2], number_of_points)
        coord_phys_straight = image_centerline_straight.transfo_pix2phys(coord_straight)
        derivs_straight = np.empty((number_of_points, 3))
        derivs_straight[..., 0] = derivs_straight[..., 1] = 0
        derivs_straight[..., 2] = 1
        dx_straight, dy_straight, dz_straight = derivs_straight.T
        centerline_straight = Centerline(coord_phys_straight[:, 0], coord_phys_straight[:, 1],
                                         coord_phys_straight[:, 2],
                                         dx_straight, dy_straight, dz_straight)

        time_centerlines = time.time() - time_centerlines
        logger.info('Time to generate centerline: {} ms'.format(np.round(time_centerlines * 1000.0)))

    if verbose == 2:
        # Debug figure: cumulated normalized progressive length of both centerlines
        # TODO: use OO
        import matplotlib.pyplot as plt
        from datetime import datetime
        curved_points = centerline.progressive_length
        straight_points = centerline_straight.progressive_length
        range_points = np.linspace(0, 1, number_of_points)
        dist_curved = np.zeros(number_of_points)
        dist_straight = np.zeros(number_of_points)
        for i in range(1, number_of_points):
            dist_curved[i] = dist_curved[i - 1] + curved_points[i - 1] / centerline.length
            dist_straight[i] = dist_straight[i - 1] + straight_points[i - 1] / centerline_straight.length
        plt.plot(range_points, dist_curved)
        plt.plot(range_points, dist_straight)
        plt.grid(True)
        plt.savefig('fig_straighten_' + datetime.now().strftime("%y%m%d%H%M%S%f") + '.png')
        plt.close()

    # alignment_mode = 'length'
    alignment_mode = 'levels'

    # Look-up table curved --> straight (identity unless disc labels are provided)
    lookup_curved2straight = list(range(centerline.number_of_points))
    if self.discs_input_filename != "":
        # create look-up table curved to straight
        for index in range(centerline.number_of_points):
            disc_label = centerline.l_points[index]
            if alignment_mode == 'length':
                relative_position = centerline.dist_points[index]
            else:
                relative_position = centerline.dist_points_rel[index]
            idx_closest = centerline_straight.get_closest_to_absolute_position(
                disc_label, relative_position,
                backup_index=index,
                backup_centerline=centerline_straight,
                mode=alignment_mode)
            if idx_closest is not None:
                lookup_curved2straight[index] = idx_closest
            else:
                lookup_curved2straight[index] = 0
    # Zero out runs of repeated indexes at both ends of the table; entries equal to 0 are later treated as
    # "out of range" when the warping fields are filled.
    for p in range(0, len(lookup_curved2straight) // 2):
        if lookup_curved2straight[p] == lookup_curved2straight[p + 1]:
            lookup_curved2straight[p] = 0
        else:
            break
    for p in range(len(lookup_curved2straight) - 1, len(lookup_curved2straight) // 2, -1):
        if lookup_curved2straight[p] == lookup_curved2straight[p - 1]:
            lookup_curved2straight[p] = 0
        else:
            break
    lookup_curved2straight = np.array(lookup_curved2straight)

    # Look-up table straight --> curved (identity unless disc labels are provided)
    lookup_straight2curved = list(range(centerline_straight.number_of_points))
    if self.discs_input_filename != "":
        for index in range(centerline_straight.number_of_points):
            disc_label = centerline_straight.l_points[index]
            if alignment_mode == 'length':
                relative_position = centerline_straight.dist_points[index]
            else:
                relative_position = centerline_straight.dist_points_rel[index]
            idx_closest = centerline.get_closest_to_absolute_position(
                disc_label, relative_position,
                backup_index=index,
                backup_centerline=centerline_straight,
                mode=alignment_mode)
            if idx_closest is not None:
                lookup_straight2curved[index] = idx_closest
    for p in range(0, len(lookup_straight2curved) // 2):
        if lookup_straight2curved[p] == lookup_straight2curved[p + 1]:
            lookup_straight2curved[p] = 0
        else:
            break
    for p in range(len(lookup_straight2curved) - 1, len(lookup_straight2curved) // 2, -1):
        if lookup_straight2curved[p] == lookup_straight2curved[p - 1]:
            lookup_straight2curved[p] = 0
        else:
            break
    lookup_straight2curved = np.array(lookup_straight2curved)

    # Create volumes containing curved and straight warping fields
    data_warp_curved2straight = np.zeros((nx_s, ny_s, nz_s, 1, 3))
    data_warp_straight2curved = np.zeros((nx, ny, nz, 1, 3))

    # 5. compute transformations
    # Curved and straight images and the same dimensions, so we compute both warping fields at the same time.
    # b. determine which plane of spinal cord centreline it is included

    if self.curved2straight:
        # one axial slice of the straight space at a time
        for u in sct_progress_bar(range(nz_s)):
            x_s, y_s, z_s = np.mgrid[0:nx_s, 0:ny_s, u:u + 1]
            indexes_straight = np.array(list(zip(x_s.ravel(), y_s.ravel(), z_s.ravel())))
            physical_coordinates_straight = image_centerline_straight.transfo_pix2phys(indexes_straight)
            nearest_indexes_straight = centerline_straight.find_nearest_indexes(physical_coordinates_straight)
            distances_straight = centerline_straight.get_distances_from_planes(physical_coordinates_straight,
                                                                               nearest_indexes_straight)
            lookup = lookup_straight2curved[nearest_indexes_straight]
            # voxels too far from their plane, or without a valid look-up entry, are flagged as out of range
            indexes_out_distance_straight = np.logical_or(
                np.logical_or(distances_straight > self.threshold_distance,
                              distances_straight < -self.threshold_distance),
                lookup == 0)
            projected_points_straight = centerline_straight.get_projected_coordinates_on_planes(
                physical_coordinates_straight, nearest_indexes_straight)
            coord_in_planes_straight = centerline_straight.get_in_plans_coordinates(projected_points_straight,
                                                                                    nearest_indexes_straight)

            coord_straight2curved = centerline.get_inverse_plans_coordinates(coord_in_planes_straight, lookup)
            displacements_straight = coord_straight2curved - physical_coordinates_straight
            # Invert Z coordinate as ITK & ANTs physical coordinate system is LPS- (RAI+)
            # while ours is LPI-
            # Refs: https://sourceforge.net/p/advants/discussion/840261/thread/2a1e9307/#fb5a
            #       https://www.slicer.org/wiki/Coordinate_systems
            displacements_straight[:, 2] = -displacements_straight[:, 2]
            displacements_straight[indexes_out_distance_straight] = [100000.0, 100000.0, 100000.0]

            data_warp_curved2straight[
                indexes_straight[:, 0], indexes_straight[:, 1], indexes_straight[:, 2], 0, :
            ] = -displacements_straight

    if self.straight2curved:
        # one axial slice of the (padded) curved space at a time
        for u in sct_progress_bar(range(nz)):
            x, y, z = np.mgrid[0:nx, 0:ny, u:u + 1]
            indexes = np.array(list(zip(x.ravel(), y.ravel(), z.ravel())))
            physical_coordinates = image_centerline_pad.transfo_pix2phys(indexes)
            nearest_indexes_curved = centerline.find_nearest_indexes(physical_coordinates)
            distances_curved = centerline.get_distances_from_planes(physical_coordinates,
                                                                    nearest_indexes_curved)
            lookup = lookup_curved2straight[nearest_indexes_curved]
            indexes_out_distance_curved = np.logical_or(
                np.logical_or(distances_curved > self.threshold_distance,
                              distances_curved < -self.threshold_distance),
                lookup == 0)
            projected_points_curved = centerline.get_projected_coordinates_on_planes(physical_coordinates,
                                                                                     nearest_indexes_curved)
            coord_in_planes_curved = centerline.get_in_plans_coordinates(projected_points_curved,
                                                                         nearest_indexes_curved)

            # fancy indexing returns a copy, so the in-place additions do not modify centerline_straight.points
            coord_curved2straight = centerline_straight.points[lookup]
            coord_curved2straight[:, 0:2] += coord_in_planes_curved[:, 0:2]
            coord_curved2straight[:, 2] += distances_curved

            displacements_curved = coord_curved2straight - physical_coordinates
            # same Z inversion as for the curved2straight field
            displacements_curved[:, 2] = -displacements_curved[:, 2]
            displacements_curved[indexes_out_distance_curved] = [100000.0, 100000.0, 100000.0]

            data_warp_straight2curved[indexes[:, 0], indexes[:, 1], indexes[:, 2], 0, :] = -displacements_curved

    # Creation of the safe zone based on pre-calculated safe boundaries
    coord_bound_curved_inf, coord_bound_curved_sup = image_centerline_pad.transfo_phys2pix(
        [[0, 0, bound_curved[0]]]), image_centerline_pad.transfo_phys2pix([[0, 0, bound_curved[1]]])
    coord_bound_straight_inf, coord_bound_straight_sup = image_centerline_straight.transfo_phys2pix(
        [[0, 0, bound_straight[0]]]), image_centerline_straight.transfo_phys2pix([[0, 0, bound_straight[1]]])

    # radius_safe is hard-coded to 0.0 above, so this branch is currently inactive
    if radius_safe > 0:
        data_warp_curved2straight[:, :, 0:coord_bound_straight_inf[0][2], 0, :] = 100000.0
        data_warp_curved2straight[:, :, coord_bound_straight_sup[0][2]:, 0, :] = 100000.0
        data_warp_straight2curved[:, :, 0:coord_bound_curved_inf[0][2], 0, :] = 100000.0
        data_warp_straight2curved[:, :, coord_bound_curved_sup[0][2]:, 0, :] = 100000.0

    # Generate warp files as a warping fields
    hdr_warp_s.set_intent('vector', (), '')
    hdr_warp_s.set_data_dtype('float32')
    hdr_warp.set_intent('vector', (), '')
    hdr_warp.set_data_dtype('float32')
    if self.curved2straight:
        img = Nifti1Image(data_warp_curved2straight, None, hdr_warp_s)
        save(img, 'tmp.curve2straight.nii.gz')
        logger.info('Warping field generated: tmp.curve2straight.nii.gz')

    if self.straight2curved:
        img = Nifti1Image(data_warp_straight2curved, None, hdr_warp)
        save(img, 'tmp.straight2curve.nii.gz')
        logger.info('Warping field generated: tmp.straight2curve.nii.gz')

    image_centerline_straight.save(fname_ref)
    if self.curved2straight:
        logger.info('Apply transformation to input image...')
        sct_apply_transfo.main(['-i', 'data.nii',
                                '-d', fname_ref,
                                '-w', 'tmp.curve2straight.nii.gz',
                                '-o', 'tmp.anat_rigid_warp.nii.gz',
                                '-x', 'spline',
                                '-v', '0'])

    if self.accuracy_results:
        time_accuracy_results = time.time()
        # compute the error between the straightened centerline/segmentation and the central vertical line.
        # Ideally, the error should be zero.
        # Apply deformation to input image
        logger.info('Apply transformation to centerline image...')
        sct_apply_transfo.main(['-i', 'centerline.nii.gz',
                                '-d', fname_ref,
                                '-w', 'tmp.curve2straight.nii.gz',
                                '-o', 'tmp.centerline_straight.nii.gz',
                                '-x', 'nn',
                                '-v', '0'])
        file_centerline_straight = Image('tmp.centerline_straight.nii.gz', verbose=verbose)
        nx, ny, nz, nt, px, py, pz, pt = file_centerline_straight.dim
        coordinates_centerline = file_centerline_straight.getNonZeroCoordinates(sorting='z')
        # value-weighted mean (x, y) of the straightened centerline on each slice
        mean_coord = []
        for z in range(coordinates_centerline[0].z, coordinates_centerline[-1].z):
            temp_mean = [coord.value for coord in coordinates_centerline if coord.z == z]
            if temp_mean:
                mean_value = np.mean(temp_mean)
                mean_coord.append(
                    np.mean([[coord.x * coord.value / mean_value, coord.y * coord.value / mean_value]
                             for coord in coordinates_centerline if coord.z == z], axis=0))

        # compute error between the straightened centerline and the straight line.
        x0 = file_centerline_straight.data.shape[0] / 2.0
        y0 = file_centerline_straight.data.shape[1] / 2.0
        count_mean = 0
        if number_of_points >= 10:
            mean_c = mean_coord[2:-2]  # we don't include the four extrema because there are usually messy.
        else:
            mean_c = mean_coord
        for coord_z in mean_c:
            if not np.isnan(np.sum(coord_z)):
                dist = ((x0 - coord_z[0]) * px) ** 2 + ((y0 - coord_z[1]) * py) ** 2
                self.mse_straightening += dist
                dist = np.sqrt(dist)
                if dist > self.max_distance_straightening:
                    self.max_distance_straightening = dist
                count_mean += 1
        if count_mean > 0:
            self.mse_straightening = np.sqrt(self.mse_straightening / float(count_mean))
        else:
            # No valid slice was found: the accuracy is undefined; avoid dividing by zero.
            logger.warning('Could not compute straightening accuracy: no valid centerline slice found.')
            self.mse_straightening = np.nan

        self.elapsed_time_accuracy = time.time() - time_accuracy_results

    os.chdir(curdir)

    # Generate output file (in current folder)
    # TODO: do not uncompress the warping field, it is too time consuming!
    logger.info('Generate output files...')
    if self.curved2straight:
        generate_output_file(os.path.join(path_tmp, "tmp.curve2straight.nii.gz"),
                             os.path.join(self.path_output, "warp_curve2straight.nii.gz"), verbose)
    if self.straight2curved:
        generate_output_file(os.path.join(path_tmp, "tmp.straight2curve.nii.gz"),
                             os.path.join(self.path_output, "warp_straight2curve.nii.gz"), verbose)

    # create ref_straight.nii.gz file that can be used by other SCT functions that need a straight reference space
    if self.curved2straight:
        copy(os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
             os.path.join(self.path_output, "straight_ref.nii.gz"))
        # move straightened input file
        if fname_output == '':
            fname_straight = generate_output_file(os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
                                                  os.path.join(self.path_output,
                                                               file_anat + "_straight" + ext_anat), verbose)
        else:
            fname_straight = generate_output_file(os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
                                                  os.path.join(self.path_output, fname_output),
                                                  verbose)  # straightened anatomic

    # Remove temporary files
    if remove_temp_files:
        logger.info('Remove temporary files...')
        rmtree(path_tmp)

    if self.accuracy_results:
        logger.info('Maximum x-y error: {} mm'.format(self.max_distance_straightening))
        logger.info('Accuracy of straightening (MSE): {} mm'.format(self.mse_straightening))

    # display elapsed time
    self.elapsed_time = int(np.round(time.time() - start_time))

    return fname_straight