def compute_shape(segmentation, angle_correction=True, param_centerline=None, verbose=1):
    """
    Compute morphometric measures of the spinal cord in the transverse (axial) plane from the segmentation.
    The segmentation could be binary or weighted for partial volume [0,1].

    :param segmentation: input segmentation. Could be either an Image or a file name.
    :param angle_correction: if True, the centerline is estimated with get_centerline() and each axial patch is
        rescaled by the cosine of the centerline angles before being measured. If False, patches are measured
        as-is and the reported angles are 0.
    :param param_centerline: see centerline.core.ParamCenterline()
    :param verbose: verbosity level, forwarded to get_centerline()
    :return metrics: Dict of class Metric(). If a metric cannot be calculated, its value will be nan.
    :return fit_results: class centerline.core.FitResults(), or None when angle_correction is False
    :raises ValueError: if the segmentation does not contain any voxel > 0
    """
    # List of properties to output (in the right order)
    property_list = [
        'area',
        'angle_AP',
        'angle_RL',
        'diameter_AP',
        'diameter_RL',
        'eccentricity',
        'orientation',
        'solidity',
        'length',
    ]

    im_seg = Image(segmentation).change_orientation('RPI')

    # Getting image dimensions. x, y and z respectively correspond to RL, PA and IS.
    nx, ny, nz, nt, px, py, pz, pt = im_seg.dim
    pr = min([px, py])
    # Resample to isotropic resolution in the axial plane. Use the minimum pixel dimension as target dimension.
    im_segr = resample_nib(im_seg, new_size=[pr, pr, pz], new_size_type='mm', interpolation='linear')

    # Update dimensions from resampled image.
    nx, ny, nz, nt, px, py, pz, pt = im_segr.dim

    # Extract min and max index in Z direction
    data_seg = im_segr.data
    _, _, Z = (data_seg > 0).nonzero()
    if Z.size == 0:
        # Same exception type as the builtin min()/max() on an empty sequence, with an actionable message.
        raise ValueError("Segmentation is empty (no voxel > 0): cannot compute shape properties.")
    min_z_index, max_z_index = int(Z.min()), int(Z.max())

    # Initialize dictionary of property_list, with 1d array of nan (default value if no property for a given slice).
    shape_properties = {key: np.full(nz, np.nan, dtype=np.double) for key in property_list}

    fit_results = None

    if angle_correction:
        # compute the spinal cord centerline based on the spinal cord segmentation
        # here, param_centerline.minmax needs to be False because we need to retrieve the total number of input slices
        _, arr_ctl, arr_ctl_der, fit_results = get_centerline(im_segr, param=param_centerline, verbose=verbose)

    # Loop across z and compute shape analysis
    for iz in sct_progress_bar(range(min_z_index, max_z_index + 1), unit='iter', unit_scale=False,
                               desc="Compute shape analysis", ascii=True, ncols=80):
        # Extract 2D patch
        current_patch = im_segr.data[:, :, iz]
        if angle_correction:
            # Extract tangent vector to the centerline (i.e. its derivative)
            tangent_vect = np.array([
                arr_ctl_der[0][iz - min_z_index] * px,
                arr_ctl_der[1][iz - min_z_index] * py,
                pz,
            ])
            # Normalize vector by its L2 norm
            tangent_vect = tangent_vect / np.linalg.norm(tangent_vect)
            # Compute the angle about AP axis between the centerline and the normal vector to the slice.
            # math.atan2 is used instead of np.math.atan2: the np.math alias was removed from NumPy.
            v0 = [tangent_vect[0], tangent_vect[2]]
            v1 = [0, 1]
            angle_AP_rad = math.atan2(np.linalg.det([v0, v1]), np.dot(v0, v1))
            # Compute the angle about RL axis between the centerline and the normal vector to the slice
            v0 = [tangent_vect[1], tangent_vect[2]]
            v1 = [0, 1]
            angle_RL_rad = math.atan2(np.linalg.det([v0, v1]), np.dot(v0, v1))
            # Apply affine transformation to account for the angle between the centerline and the normal to the patch
            tform = transform.AffineTransform(scale=(np.cos(angle_RL_rad), np.cos(angle_AP_rad)))
            # Convert to float64, to avoid problems in image indexation causing issues when applying transform.warp
            current_patch = current_patch.astype(np.float64)
            # TODO: make sure pattern does not go extend outside of image border
            current_patch_scaled = transform.warp(
                current_patch,
                tform.inverse,
                output_shape=current_patch.shape,
                order=1,
            )
        else:
            current_patch_scaled = current_patch
            angle_AP_rad, angle_RL_rad = 0.0, 0.0

        # compute shape properties on 2D patch
        shape_property = _properties2d(current_patch_scaled, [px, py])
        if shape_property is not None:
            # Add custom fields
            shape_property['angle_AP'] = angle_AP_rad * 180.0 / math.pi
            shape_property['angle_RL'] = angle_RL_rad * 180.0 / math.pi
            shape_property['length'] = pz / (np.cos(angle_AP_rad) * np.cos(angle_RL_rad))
            # Loop across properties and assign values for function output
            for property_name in property_list:
                shape_properties[property_name][iz] = shape_property[property_name]
        else:
            logging.warning('\nNo properties for slice: {}'.format(iz))

        # DEBUG
        # from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
        # from matplotlib.figure import Figure
        # fig = Figure()
        # FigureCanvas(fig)
        # ax = fig.add_subplot(111)
        # ax.imshow(current_patch_scaled)
        # ax.grid()
        # ax.set_xlabel('y')
        # ax.set_ylabel('x')
        # fig.savefig('tmp_fig.png')

    metrics = {}
    for key, value in shape_properties.items():
        # Making sure all entries added to metrics have results.
        # A size check is used because comparing an ndarray with `[]` is an elementwise comparison
        # that newer NumPy versions reject.
        if value.size > 0:
            metrics[key] = Metric(data=np.array(value), label=key)

    return metrics, fit_results
def compute_texture(self):
    """Compute GLCM texture metrics voxel by voxel and save one output image per metric.

    For every voxel of every slice in ``self.dct_im_seg`` whose square neighbourhood (half-width given by
    ``self.param_glcm.distance``) lies entirely inside both the slice and the mask, one GLCM is built per
    angle listed in ``self.param_glcm.angle`` and each property named in ``self.metric_lst`` is written at
    that voxel. The resulting images are saved and their file names stored in ``self.fname_metric_lst``.
    """
    half_width = int(self.param_glcm.distance)

    sct.printv('\nCompute texture metrics...', self.param.verbose, 'normal')

    # open image and re-orient it to RPI if needed
    im_ref = Image(self.param.fname_im)
    if self.orientation_im != self.orientation_extraction:
        im_ref.change_orientation(self.orientation_extraction)

    # one zero-filled output image per requested metric
    dct_metric = {
        metric: msct_image.zeros_like(im_ref, dtype='float64')
        for metric in self.metric_lst
    }

    with sct_progress_bar() as pbar:
        slice_pairs = zip(self.dct_im_seg['im'], self.dct_im_seg['seg'], range(len(self.dct_im_seg['im'])))
        for im_slice, seg_slice, zz in slice_pairs:
            for xx in range(im_slice.shape[0]):
                for yy in range(im_slice.shape[1]):
                    # voxel must belong to the mask
                    if not seg_slice[xx, yy]:
                        continue
                    # the whole glcm window must be in the axial slice
                    if xx < half_width or yy < half_width:
                        continue
                    if xx > (im_slice.shape[0] - half_width - 1) or yy > (im_slice.shape[1] - half_width - 1):
                        continue

                    rows = slice(xx - half_width, xx + half_width + 1)
                    cols = slice(yy - half_width, yy + half_width + 1)

                    # the whole glcm window must be in the mask of the axial slice
                    if False in np.unique(seg_slice[rows, cols]):
                        continue

                    glcm_window = im_slice[rows, cols]
                    glcm_window = glcm_window.astype(np.uint8)

                    # compute the GLCM for self.param_glcm.distance and for each self.param_glcm.angle
                    dct_glcm = {
                        angle: greycomatrix(glcm_window,
                                            [self.param_glcm.distance],
                                            [np.radians(int(angle))],
                                            symmetric=self.param_glcm.symmetric,
                                            normed=self.param_glcm.normed)
                        for angle in self.param_glcm.angle.split(',')
                    }

                    # metric names are split on '_': field 0 is the GLCM property, field 2 the angle key
                    for metric in self.metric_lst:
                        name_parts = metric.split('_')
                        value = greycoprops(dct_glcm[name_parts[2]], name_parts[0])[0][0]
                        dct_metric[metric].data[xx, yy, zz] = value

                    # NOTE(review): progress is reported once per processed voxel; the original nesting of
                    # these two calls is assumed -- confirm against the upstream file.
                    pbar.set_postfix(pos="{}/{}".format(zz, len(self.dct_im_seg["im"])))
                    pbar.update(1)

    for metric in self.metric_lst:
        base_name = "".join(sct.extract_fname(self.param.fname_im)[1:])
        fname_out = sct.add_suffix(base_name, '_' + metric)
        dct_metric[metric].save(fname_out)
        self.fname_metric_lst[metric] = fname_out
def download_data(urls):
    """Download the binaries from a URL and return the destination filename

    Retry downloading if either server or connection errors occur on a SSL connection

    :param urls: list of several urls (mirror servers) or single url (string)
    :return: path of the downloaded file, written inside a newly created temporary directory
    :raises Exception: with args ``('Download error', [exceptions])`` when none of the URLs could be
        downloaded; the list holds the exception caught for each URL that was tried
    """
    # Function-scope import: email.message parses the Content-Disposition header here instead of
    # cgi.parse_header(), because the cgi module is deprecated and no longer ships with recent Python.
    from email.message import Message

    # make sure urls becomes a list, in case user inputs a str
    if isinstance(urls, str):
        urls = [urls]

    # loop through URLs
    exceptions = []
    for url in urls:
        try:
            logger.info('Trying URL: %s', url)
            retry = Retry(total=3, backoff_factor=0.5, status_forcelist=[500, 503, 504])
            # The context manager releases the connection pool whether or not the download succeeds.
            with requests.Session() as session:
                session.mount('https://', HTTPAdapter(max_retries=retry))
                response = session.get(url, stream=True)
                response.raise_for_status()

                filename = os.path.basename(urllib.parse.urlparse(url).path)
                if "Content-Disposition" in response.headers:
                    header = Message()
                    header['Content-Disposition'] = response.headers['Content-Disposition']
                    # Keep the URL-derived name when the header has no filename parameter
                    # (e.g. a bare "inline"), instead of failing this mirror.
                    filename = header.get_filename() or filename

                # protect against directory traversal
                filename = os.path.basename(filename)
                if not filename:
                    # this handles cases where you're loading something like an index page
                    # instead of a specific file. e.g. https://osf.io/ugscu/?action=view.
                    raise ValueError("Unable to determine target filename for URL: %s" % (url, ))

                tmp_path = os.path.join(tempfile.mkdtemp(), filename)

                logger.info('Downloading: %s', filename)

                with open(tmp_path, 'wb') as tmp_file:
                    total = int(response.headers.get('content-length', 1))
                    sct_bar = sct_progress_bar(total=total, unit='B', unit_scale=True, desc="Status",
                                               ascii=False, position=0)
                    try:
                        for chunk in response.iter_content(chunk_size=8192):
                            if chunk:
                                tmp_file.write(chunk)
                                sct_bar.update(len(chunk))
                    finally:
                        # Close the bar even if the transfer is interrupted.
                        sct_bar.close()
            return tmp_path

        except Exception as e:
            # Deliberate best-effort: any failure on this mirror is recorded and the next one is tried.
            logger.warning("Link download error, trying next mirror (error was: %s)", e)
            exceptions.append(e)

    # Reached only when no URL returned successfully (including an empty list of URLs).
    raise Exception('Download error', exceptions)
def moco(param):
    """
    Main function that performs motion correction.

    The data is split along Z when ``param.is_sagittal`` is set (otherwise it is processed as a single
    volume), then along T. Each volume is passed to register(); transformations flagged as failed are
    replaced by the closest (in T) successful one. Unless ``param.todo == 'estimate'``, corrected volumes
    are merged back and saved.

    :param param: parameter object. Fields read here: file_data, file_target, mat_moco, todo, suffix,
        verbose, poly, smooth, gradStep, metric, sampling, fname_mask, is_sagittal, iterAvg, interp.
    :return: tuple (file_mat, im_out). file_mat is a 2D object array [iz][it] of transformation file
        prefixes. im_out is the merged motion-corrected image, or None when ``param.todo == 'estimate'``
        (nothing is merged in that case).
    """
    # retrieve parameters
    file_data = param.file_data
    file_target = param.file_target
    folder_mat = param.mat_moco  # output folder of mat file
    todo = param.todo
    suffix = param.suffix
    verbose = param.verbose

    # other parameters
    file_mask = 'mask.nii'

    sct.printv('\nInput parameters:', param.verbose)
    sct.printv(' Input file ............ ' + file_data, param.verbose)
    sct.printv(' Reference file ........ ' + file_target, param.verbose)
    sct.printv(' Polynomial degree ..... ' + param.poly, param.verbose)
    sct.printv(' Smoothing kernel ...... ' + param.smooth, param.verbose)
    sct.printv(' Gradient step ......... ' + param.gradStep, param.verbose)
    sct.printv(' Metric ................ ' + param.metric, param.verbose)
    sct.printv(' Sampling .............. ' + param.sampling, param.verbose)
    sct.printv(' Todo .................. ' + todo, param.verbose)
    sct.printv(' Mask ................. ' + param.fname_mask, param.verbose)
    sct.printv(' Output mat folder ..... ' + folder_mat, param.verbose)

    # create folder for mat files
    sct.create_folder(folder_mat)

    # Get size of data
    sct.printv('\nData dimensions:', verbose)
    im_data = Image(param.file_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    sct.printv((' ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt)), verbose)

    # copy file_target to a temporary file
    sct.printv('\nCopy file_target to a temporary file...', verbose)
    file_target = "target.nii.gz"
    convert(param.file_target, file_target, verbose=0)

    # Check if user specified a mask
    if not param.fname_mask == '':
        # Check if this mask is soft (i.e., non-binary, such as a Gaussian mask)
        im_mask = Image(param.fname_mask)
        if not np.array_equal(im_mask.data, im_mask.data.astype(bool)):
            # If it is a soft mask, multiply the target by the soft mask.
            im = Image(file_target)
            im_masked = im.copy()
            im_masked.data = im.data * im_mask.data
            im_masked.save(verbose=0)  # silence warning about file overwritting

    # If scan is sagittal, split src and target along Z (slice)
    if param.is_sagittal:
        dim_sag = 2  # TODO: find it
        # z-split data (time series)
        im_z_list = split_data(im_data, dim=dim_sag, squeeze_data=False)
        file_data_splitZ = []
        for im_z in im_z_list:
            im_z.save(verbose=0)
            file_data_splitZ.append(im_z.absolutepath)
        # z-split target
        im_targetz_list = split_data(Image(file_target), dim=dim_sag, squeeze_data=False)
        file_target_splitZ = []
        for im_targetz in im_targetz_list:
            im_targetz.save(verbose=0)
            file_target_splitZ.append(im_targetz.absolutepath)
        # z-split mask (if exists)
        if not param.fname_mask == '':
            im_maskz_list = split_data(Image(file_mask), dim=dim_sag, squeeze_data=False)
            for im_maskz in im_maskz_list:
                im_maskz.save(verbose=0)
        # initialize file list for output matrices
        file_mat = np.empty((nz, nt), dtype=object)

    # axial orientation
    else:
        file_data_splitZ = [file_data]  # TODO: make it absolute like above
        file_target_splitZ = [file_target]  # TODO: make it absolute like above
        # initialize file list for output matrices
        file_mat = np.empty((1, nt), dtype=object)

        # deal with mask
        if not param.fname_mask == '':
            convert(param.fname_mask, file_mask, squeeze_data=False, verbose=0)
            im_maskz_list = [Image(file_mask)]  # use a list with single element

    # Loop across file list, where each file is either a 2D volume (if sagittal) or a 3D volume (otherwise)
    file_data_splitZ_moco = []
    # Stays None when todo == 'estimate': in that case no corrected image is merged or saved.
    im_out = None
    sct.printv('\nRegister. Loop across Z (note: there is only one Z if orientation is axial)')
    # enumerate() gives the position directly (list.index() would rescan the list at every iteration and
    # return the wrong position if two entries were equal).
    for iz, file in enumerate(file_data_splitZ):
        # Split data along T dimension
        im_z = Image(file)
        list_im_zt = split_data(im_z, dim=3)
        file_data_splitZ_splitT = []
        for im_zt in list_im_zt:
            im_zt.save(verbose=0)
            file_data_splitZ_splitT.append(im_zt.absolutepath)

        # Motion correction: initialization
        file_data_splitZ_splitT_moco = []
        failed_transfo = [0] * nt

        # Motion correction: Loop across T
        for it in sct_progress_bar(range(nt), unit='iter', unit_scale=False,
                                   desc="Z=" + str(iz) + "/" + str(len(file_data_splitZ) - 1),
                                   ascii=False, ncols=80):
            # create indices and display stuff
            file_mat[iz][it] = os.path.join(folder_mat, "mat.Z") + str(iz).zfill(4) + 'T' + str(it).zfill(4)
            file_data_splitZ_splitT_moco.append(sct.add_suffix(file_data_splitZ_splitT[it], '_moco'))

            # deal with masking (except in the 'apply' case, where masking is irrelevant)
            input_mask = None
            if not param.fname_mask == '' and not param.todo == 'apply':
                # Check if mask is binary
                if np.array_equal(im_maskz_list[iz].data, im_maskz_list[iz].data.astype(bool)):
                    # If it is, pass this mask into register() to be used
                    input_mask = im_maskz_list[iz]
                else:
                    # If not, do not pass this mask into register() because ANTs cannot handle non-binary masks.
                    # Instead, multiply the input data by the Gaussian mask.
                    im = Image(file_data_splitZ_splitT[it])
                    im_masked = im.copy()
                    im_masked.data = im.data * im_maskz_list[iz].data
                    im_masked.save(verbose=0)  # silence warning about file overwritting

            # run 3D registration
            failed_transfo[it] = register(param, file_data_splitZ_splitT[it], file_target_splitZ[iz],
                                          file_mat[iz][it], file_data_splitZ_splitT_moco[it],
                                          im_mask=input_mask)

            # average registered volume with target image
            # N.B. use weighted averaging: (target * nb_it + moco) / (nb_it + 1)
            if param.iterAvg and it < 10 and failed_transfo[it] == 0 and not param.todo == 'apply':
                im_targetz = Image(file_target_splitZ[iz])
                data_targetz = im_targetz.data
                data_mocoz = Image(file_data_splitZ_splitT_moco[it]).data
                data_targetz = (data_targetz * (it + 1) + data_mocoz) / (it + 2)
                im_targetz.data = data_targetz
                im_targetz.save(verbose=0)

        # Replace failed transformation with the closest good one
        fT = [i for i, j in enumerate(failed_transfo) if j == 1]
        gT = [i for i, j in enumerate(failed_transfo) if j == 0]
        for it_failed in fT:
            if not gT:
                # exit program if no transformation exists.
                sct.printv('\nERROR in ' + os.path.basename(__file__) + ': No good transformation exist. Exit program.\n',
                           verbose, 'error')
                sys.exit(2)
            # first good transformation at the smallest distance in T
            abs_dist = [abs(it_good - it_failed) for it_good in gT]
            it_closest = gT[abs_dist.index(min(abs_dist))]
            sct.printv(' transfo #' + str(it_failed) + ' --> use transfo #' + str(it_closest), verbose)
            # copy transformation
            sct.copy(file_mat[iz][it_closest] + 'Warp.nii.gz', file_mat[iz][it_failed] + 'Warp.nii.gz')
            # apply transformation
            sct_apply_transfo.main(args=['-i', file_data_splitZ_splitT[it_failed],
                                         '-d', file_target,
                                         '-w', file_mat[iz][it_failed] + 'Warp.nii.gz',
                                         '-o', file_data_splitZ_splitT_moco[it_failed],
                                         '-x', param.interp])

        # Merge data along T
        file_data_splitZ_moco.append(sct.add_suffix(file, suffix))
        if todo != 'estimate':
            im_out = concat_data(file_data_splitZ_splitT_moco, 3)
            im_out.absolutepath = file_data_splitZ_moco[iz]
            im_out.save(verbose=0)

    # If sagittal, merge along Z. Skipped for 'estimate': the per-Z files were not written above.
    if param.is_sagittal and todo != 'estimate':
        # TODO: im_out.dim is incorrect: Z value is one
        im_out = concat_data(file_data_splitZ_moco, 2)
        dirname, basename, ext = sct.extract_fname(file_data)
        path_out = os.path.join(dirname, basename + suffix + ext)
        im_out.absolutepath = path_out
        im_out.save(verbose=0)

    return file_mat, im_out
def moco_wrapper(param):
    """
    Wrapper that performs motion correction.

    Steps: copy the input into a temporary folder and work from there; split the data along T; average
    volumes within groups of ``param.group_size`` (b=0 volumes are handled separately when
    ``param.is_diffusion`` is set); estimate motion with moco() on the b=0 volumes and across groups;
    copy the transformations to a final folder; apply them to the full data; write the outputs under
    ``param.path_out``.

    The fields ``fname_mask``, ``is_sagittal``, ``suffix_mat`` and ``group_size`` of ``param`` may be
    modified in place. The working directory is restored before returning, including when a step fails.

    :param param: ParamMoco class
    :return: None
    """
    file_data = 'data.nii'  # corresponds to the full input data (e.g. dmri or fmri)
    file_data_dirname, file_data_basename, _ = sct.extract_fname(file_data)
    file_b0 = 'b0.nii'
    file_datasub = 'datasub.nii'  # corresponds to the full input data minus the b=0 scans (if param.is_diffusion=True)
    file_datasubgroup = 'datasub-groups.nii'  # concatenation of the average of each file_datasub
    file_mask = 'mask.nii'
    file_moco_params_csv = 'moco_params.tsv'
    file_moco_params_x = 'moco_params_x.nii.gz'
    file_moco_params_y = 'moco_params_y.nii.gz'
    ext_data = '.nii.gz'  # workaround "too many open files" by slurping the data
    # TODO: check if .nii can be used
    mat_final = 'mat_final/'
    # ext_mat = 'Warp.nii.gz'  # warping field

    # Start timer
    start_time = time.time()

    sct.printv('\nInput parameters:', param.verbose)
    sct.printv(' Input file ............ ' + param.fname_data, param.verbose)
    sct.printv(' Group size ............ {}'.format(param.group_size), param.verbose)

    # Get full path
    # param.fname_data = os.path.abspath(param.fname_data)
    # param.fname_bvecs = os.path.abspath(param.fname_bvecs)
    # if param.fname_bvals != '':
    #     param.fname_bvals = os.path.abspath(param.fname_bvals)

    # Extract path, file and extension
    # path_data, file_data, ext_data = sct.extract_fname(param.fname_data)
    # path_mask, file_mask, ext_mask = sct.extract_fname(param.fname_mask)

    path_tmp = sct.tmp_create(basename="moco", verbose=param.verbose)

    # Copying input data to tmp folder
    sct.printv('\nCopying input data to tmp folder and convert to nii...', param.verbose)
    convert(param.fname_data, os.path.join(path_tmp, file_data))
    if param.fname_mask != '':
        convert(param.fname_mask, os.path.join(path_tmp, file_mask), verbose=param.verbose)
        # Update field in param (because used later in another function, and param class will be passed)
        param.fname_mask = file_mask

    # Build absolute output path and go to tmp folder
    curdir = os.getcwd()
    path_out_abs = os.path.abspath(param.path_out)
    os.chdir(path_tmp)

    # Everything below runs from inside the tmp folder; the finally clause guarantees that the caller's
    # working directory is restored even if a step raises or exits.
    try:
        # Get dimensions of data
        sct.printv('\nGet dimensions of data...', param.verbose)
        im_data = Image(file_data)
        nx, ny, nz, nt, px, py, pz, pt = im_data.dim
        sct.printv(' ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), param.verbose)

        # Get orientation
        sct.printv('\nData orientation: ' + im_data.orientation, param.verbose)
        if im_data.orientation[2] in 'LR':
            param.is_sagittal = True
            sct.printv(' Treated as sagittal')
        elif im_data.orientation[2] in 'IS':
            param.is_sagittal = False
            sct.printv(' Treated as axial')
        else:
            param.is_sagittal = False
            sct.printv('WARNING: Orientation seems to be neither axial nor sagittal. Treated as axial.')

        sct.printv("\nSet suffix of transformation file name, which depends on the orientation:")
        if param.is_sagittal:
            param.suffix_mat = '0GenericAffine.mat'
            sct.printv("Orientation is sagittal, suffix is '{}'. The image is split across the R-L direction, and the "
                       "estimated transformation is a 2D affine transfo.".format(param.suffix_mat))
        else:
            param.suffix_mat = 'Warp.nii.gz'
            sct.printv("Orientation is axial, suffix is '{}'. The estimated transformation is a 3D warping field, which is "
                       "composed of a stack of 2D Tx-Ty transformations".format(param.suffix_mat))

        # Adjust group size in case of sagittal scan
        if param.is_sagittal and param.group_size != 1:
            sct.printv('For sagittal data group_size should be one for more robustness. Forcing group_size=1.', 1,
                       'warning')
            param.group_size = 1

        if param.is_diffusion:
            # Identify b=0 and DWI images
            index_b0, index_dwi, nb_b0, nb_dwi = \
                sct_dmri_separate_b0_and_dwi.identify_b0(param.fname_bvecs, param.fname_bvals, param.bval_min,
                                                         param.verbose)

            # check if dmri and bvecs are the same size
            if not nb_b0 + nb_dwi == nt:
                sct.printv(
                    '\nERROR in ' + os.path.basename(__file__) + ': Size of data (' + str(nt) + ') and size of bvecs (' + str(
                        nb_b0 + nb_dwi) + ') are not the same. Check your bvecs file.\n', 1, 'error')
                sys.exit(2)

        # ==============================================================================================================
        # Prepare data (mean/groups...)
        # ==============================================================================================================

        # Split into T dimension
        sct.printv('\nSplit along T dimension...', param.verbose)
        im_data_split_list = split_data(im_data, 3)
        for im in im_data_split_list:
            x_dirname, x_basename, _ = sct.extract_fname(im.absolutepath)
            im.absolutepath = os.path.join(x_dirname, x_basename + ".nii.gz")
            im.save()

        if param.is_diffusion:
            # Merge and average b=0 images
            sct.printv('\nMerge and average b=0 data...', param.verbose)
            im_b0_list = [im_data_split_list[index_b0[it]] for it in range(nb_b0)]
            im_b0 = concat_data(im_b0_list, 3).save(file_b0, verbose=0)
            # Average across time
            im_b0.mean(dim=3).save(sct.add_suffix(file_b0, '_mean'))

            n_moco = nb_dwi  # set number of data to perform moco on (using grouping)
            index_moco = index_dwi

        # If not a diffusion scan, we will motion-correct all volumes
        else:
            n_moco = nt
            index_moco = list(range(0, nt))

        nb_groups = n_moco // param.group_size

        # Generate groups indexes
        group_indexes = []
        for iGroup in range(nb_groups):
            group_indexes.append(index_moco[(iGroup * param.group_size):((iGroup + 1) * param.group_size)])

        # add the remaining images to a new last group (in case the total number of image is not divisible by
        # group_size)
        nb_remaining = n_moco % param.group_size  # number of remaining images
        if nb_remaining > 0:
            nb_groups += 1
            group_indexes.append(index_moco[len(index_moco) - nb_remaining:len(index_moco)])

        _, file_dwi_basename, _ = sct.extract_fname(file_datasub)

        # Group data
        list_file_group = []
        for iGroup in sct_progress_bar(range(nb_groups), unit='iter', unit_scale=False, desc="Merge within groups",
                                       ascii=False, ncols=80):
            # get index
            index_moco_i = group_indexes[iGroup]
            # concatenate images across time, within this group
            file_dwi_merge_i = os.path.join(file_dwi_basename + '_' + str(iGroup) + ext_data)
            im_dwi_list = [im_data_split_list[it] for it in index_moco_i]
            im_dwi_out = concat_data(im_dwi_list, 3).save(file_dwi_merge_i, verbose=0)
            # Average across time
            list_file_group.append(os.path.join(file_dwi_basename + '_' + str(iGroup) + '_mean' + ext_data))
            im_dwi_out.mean(dim=3).save(list_file_group[-1])

        # Merge across groups
        sct.printv('\nMerge across groups...', param.verbose)
        # file_dwi_groups_means_merge = 'dwi_averaged_groups'
        im_dw_list = [list_file_group[iGroup] for iGroup in range(nb_groups)]
        concat_data(im_dw_list, 3).save(file_datasubgroup, verbose=0)

        # Cleanup
        del im, im_data_split_list

        # ==============================================================================================================
        # Estimate moco
        # ==============================================================================================================

        # Initialize another class instance that will be passed on to the moco() function
        param_moco = deepcopy(param)

        if param.is_diffusion:
            # Estimate moco on b0 groups
            sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
            sct.printv(' Estimating motion on b=0 images...', param.verbose)
            sct.printv('-------------------------------------------------------------------------------', param.verbose)
            param_moco.file_data = file_b0
            # Identify target image
            if index_moco[0] != 0:
                # If first DWI is not the first volume (most common), then there is a least one b=0 image before. In that
                # case select it as the target image for registration of all b=0
                param_moco.file_target = os.path.join(
                    file_data_dirname,
                    file_data_basename + '_T' + str(index_b0[index_moco[0] - 1]).zfill(4) + ext_data)
            else:
                # If first DWI is the first volume, then the target b=0 is the first b=0 from the index_b0.
                param_moco.file_target = os.path.join(
                    file_data_dirname,
                    file_data_basename + '_T' + str(index_b0[0]).zfill(4) + ext_data)
            # Run moco
            param_moco.path_out = ''
            param_moco.todo = 'estimate_and_apply'
            param_moco.mat_moco = 'mat_b0groups'
            file_mat_b0, _ = moco(param_moco)

        # Estimate moco across groups
        sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
        sct.printv(' Estimating motion across groups...', param.verbose)
        sct.printv('-------------------------------------------------------------------------------', param.verbose)
        param_moco.file_data = file_datasubgroup
        param_moco.file_target = list_file_group[0]  # target is the first volume (closest to the first b=0 if DWI scan)
        param_moco.path_out = ''
        param_moco.todo = 'estimate_and_apply'
        param_moco.mat_moco = 'mat_groups'
        file_mat_datasub_group, _ = moco(param_moco)

        # Spline Regularization along T
        if param.spline_fitting:
            # TODO: fix this scenario (haven't touched that code for a while-- it is probably buggy)
            raise NotImplementedError()
            # spline(mat_final, nt, nz, param.verbose, np.array(index_b0), param.plot_graph)

        # ==============================================================================================================
        # Apply moco
        # ==============================================================================================================

        # If group_size>1, assign transformation to each individual ungrouped 3d volume
        if param.group_size > 1:
            file_mat_datasub = []
            for iz in range(len(file_mat_datasub_group)):
                # duplicate by factor group_size the transformation file for each it
                # example: [mat.Z0000T0001Warp.nii] --> [mat.Z0000T0001Warp.nii, mat.Z0000T0001Warp.nii] for
                # group_size=2
                file_mat_datasub.append(
                    functools.reduce(operator.iconcat,
                                     [[i] * param.group_size for i in file_mat_datasub_group[iz]],
                                     []))
        else:
            file_mat_datasub = file_mat_datasub_group

        # Copy transformations to mat_final folder and rename them appropriately
        copy_mat_files(nt, file_mat_datasub, index_moco, mat_final, param)
        if param.is_diffusion:
            copy_mat_files(nt, file_mat_b0, index_b0, mat_final, param)

        # Apply moco on all dmri data
        sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
        sct.printv(' Apply moco', param.verbose)
        sct.printv('-------------------------------------------------------------------------------', param.verbose)
        param_moco.file_data = file_data
        param_moco.file_target = list_file_group[0]  # reference for reslicing into proper coordinate system
        param_moco.path_out = ''  # TODO not used in moco()
        param_moco.mat_moco = mat_final
        param_moco.todo = 'apply'
        file_mat_data, im_moco = moco(param_moco)

        # copy geometric information from header
        # NB: this is required because WarpImageMultiTransform in 2D mode wrongly sets pixdim(3) to "1".
        im_moco.header = im_data.header
        im_moco.save(verbose=0)

        # Average across time
        if param.is_diffusion:
            # generate b0_moco_mean and dwi_moco_mean
            args = ['-i', im_moco.absolutepath, '-bvec', param.fname_bvecs, '-a', '1', '-v', '0']
            if not param.fname_bvals == '':
                # if bvals file is provided
                args += ['-bval', param.fname_bvals]
            fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean = sct_dmri_separate_b0_and_dwi.main(args=args)
        else:
            fname_moco_mean = sct.add_suffix(im_moco.absolutepath, '_mean')
            im_moco.mean(dim=3).save(fname_moco_mean)

        # Extract and output the motion parameters (doesn't work for sagittal orientation)
        sct.printv('Extract motion parameters...')
        if param.output_motion_param:
            if param.is_sagittal:
                sct.printv('Motion parameters cannot be generated for sagittal images.', 1, 'warning')
            else:
                files_warp_X, files_warp_Y = [], []
                moco_param = []
                for fname_warp in file_mat_data[0]:
                    # Cropping the image to keep only one voxel in the XY plane
                    im_warp = Image(fname_warp + param.suffix_mat)
                    im_warp.data = np.expand_dims(np.expand_dims(im_warp.data[0, 0, :, :, :], axis=0), axis=0)

                    # These three lines allow to generate one file instead of two, containing X, Y and Z moco
                    # parameters
                    # fname_warp_crop = fname_warp + '_crop_' + ext_mat
                    # files_warp.append(fname_warp_crop)
                    # im_warp.save(fname_warp_crop)

                    # Separating the three components and saving X and Y only (Z is equal to 0 by default).
                    im_warp_XYZ = multicomponent_split(im_warp)

                    fname_warp_crop_X = fname_warp + '_crop_X_' + param.suffix_mat
                    im_warp_XYZ[0].save(fname_warp_crop_X)
                    files_warp_X.append(fname_warp_crop_X)

                    fname_warp_crop_Y = fname_warp + '_crop_Y_' + param.suffix_mat
                    im_warp_XYZ[1].save(fname_warp_crop_Y)
                    files_warp_Y.append(fname_warp_crop_Y)

                    # Calculating the slice-wise average moco estimate to provide a QC file
                    moco_param.append([np.mean(np.ravel(im_warp_XYZ[0].data)),
                                       np.mean(np.ravel(im_warp_XYZ[1].data))])

                # These two lines allow to generate one file instead of two, containing X, Y and Z moco parameters
                # im_warp_concat = concat_data(files_warp, dim=3)
                # im_warp_concat.save('fmri_moco_params.nii')

                # Concatenating the moco parameters into a time series for X and Y components.
                im_warp_concat = concat_data(files_warp_X, dim=3)
                im_warp_concat.save(file_moco_params_x)

                im_warp_concat = concat_data(files_warp_Y, dim=3)
                im_warp_concat.save(file_moco_params_y)

                # Writing a TSV file with the slicewise average estimate of the moco parameters. Useful for QC
                # newline='' lets the csv module control line endings (avoids doubled '\r' on Windows).
                with open(file_moco_params_csv, 'wt', newline='') as out_file:
                    tsv_writer = csv.writer(out_file, delimiter='\t')
                    tsv_writer.writerow(['X', 'Y'])
                    for mocop in moco_param:
                        tsv_writer.writerow([mocop[0], mocop[1]])

        # Generate output files
        sct.printv('\nGenerate output files...', param.verbose)
        fname_moco = os.path.join(path_out_abs, sct.add_suffix(os.path.basename(param.fname_data), param.suffix))
        sct.generate_output_file(im_moco.absolutepath, fname_moco)
        if param.is_diffusion:
            sct.generate_output_file(fname_b0_mean, sct.add_suffix(fname_moco, '_b0_mean'))
            sct.generate_output_file(fname_dwi_mean, sct.add_suffix(fname_moco, '_dwi_mean'))
        else:
            sct.generate_output_file(fname_moco_mean, sct.add_suffix(fname_moco, '_mean'))
        if os.path.exists(file_moco_params_csv):
            sct.generate_output_file(file_moco_params_x, os.path.join(path_out_abs, file_moco_params_x),
                                     squeeze_data=False)
            sct.generate_output_file(file_moco_params_y, os.path.join(path_out_abs, file_moco_params_y),
                                     squeeze_data=False)
            sct.generate_output_file(file_moco_params_csv, os.path.join(path_out_abs, file_moco_params_csv))
    finally:
        # come back to working directory (done before deleting the tmp folder, so that the folder being
        # removed is never the current working directory)
        os.chdir(curdir)

    # Delete temporary files (only reached on success: after a failure the tmp folder is kept)
    if param.remove_temp_files == 1:
        sct.printv('\nDelete temporary files...', param.verbose)
        sct.rmtree(path_tmp, verbose=param.verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's', param.verbose)

    sct.display_viewer_syntax(
        [os.path.join(param.path_out, sct.add_suffix(os.path.basename(param.fname_data), param.suffix)),
         param.fname_data],
        mode='ortho,ortho')
def straighten(self):
    """
    Straighten the spinal cord.

    Steps (everything is done in physical space):
      1. open input image and centerline image
      2. extract the fitting of the centerline, and its derivatives
      3. compute the length of the centerline
      4. compute and generate the straight space
      5. compute transformations for each voxel of one space (done using matrices, which is much faster than
         a per-voxel loop):
           a. determine which plane of the spinal cord centerline it is included in
           b. compute the position of the voxel in the plane (X and Y distance from the centerline, along
              the plane)
           c. find the corresponding centerline point in the other space
           d. find the correspondence of the voxel in the corresponding plane
      6. generate warping fields for each transformation
      7. write warping fields and apply them

    Step 5.b: how to find the corresponding plane?
      The centerline plane corresponding to a voxel is the one of the nearest centerline point. However, we
      need to compute the distance between the voxel position and the plane to be sure it is part of the plane
      and not too distant. Voxels farther than ``self.threshold_distance`` are flagged with a very large
      displacement (100000.0) in the warping field.

    Step 5.d: how to make the correspondence between centerline points in both images?
      Both centerlines have the same length. Therefore, we can map centerline points via their position along
      the curve. If we use the same number of points uniformly along the spinal cord (1000 for example), the
      correspondence is straightforward.

    Side effects: creates a temporary folder, temporarily changes the process working directory to it (always
    restored, even on error), writes the warping fields / straightened image into ``self.path_output``, and
    updates ``self.mse_straightening``, ``self.max_distance_straightening``, ``self.elapsed_time_accuracy``
    and ``self.elapsed_time``.

    :return: value returned by ``sct.generate_output_file()`` for the straightened image, or ``None`` when
             ``self.curved2straight`` is false (no straightened image is produced in that case).
    :raises NotImplementedError: if no straight reference is used and ``self.template_orientation == 1``.
    """
    # Initialization
    fname_anat = self.input_filename
    fname_centerline = self.centerline_filename
    fname_output = self.output_filename
    remove_temp_files = self.remove_temp_files
    verbose = self.verbose
    # Only assigned when the curved->straight direction is computed. Initialised here so that the final
    # `return` can never hit an unbound local.
    fname_straight = None

    # start timer
    start_time = time.time()

    # Extract path/file/extension
    path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)

    path_tmp = sct.tmp_create(basename="straighten_spinalcord", verbose=verbose)

    # Copying input data to tmp folder
    sct.printv('\nCopy files to tmp folder...', verbose)
    Image(fname_anat).save(os.path.join(path_tmp, "data.nii"))
    Image(fname_centerline).save(os.path.join(path_tmp, "centerline.nii.gz"))

    if self.use_straight_reference:
        Image(self.centerline_reference_filename).save(os.path.join(path_tmp, "centerline_ref.nii.gz"))
    if self.discs_input_filename != '':
        Image(self.discs_input_filename).save(os.path.join(path_tmp, "labels_input.nii.gz"))
    if self.discs_ref_filename != '':
        Image(self.discs_ref_filename).save(os.path.join(path_tmp, "labels_ref.nii.gz"))

    # go to tmp folder; the try/finally guarantees that the caller's working directory is restored even if
    # one of the steps below raises.
    curdir = os.getcwd()
    os.chdir(path_tmp)
    try:
        # Change orientation of the input centerline into RPI
        image_centerline = Image("centerline.nii.gz").change_orientation("RPI").save(
            "centerline_rpi.nii.gz", mutable=True)

        # Get dimension
        nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim
        if self.speed_factor != 1.0:
            intermediate_resampling = True
            px_r, py_r, pz_r = px * self.speed_factor, py * self.speed_factor, pz * self.speed_factor
        else:
            intermediate_resampling = False

        if intermediate_resampling:
            sct.mv('centerline_rpi.nii.gz', 'centerline_rpi_native.nii.gz')
            pz_native = pz
            # TODO: remove system call
            sct.run([
                'sct_resample', '-i', 'centerline_rpi_native.nii.gz', '-mm',
                str(px_r) + 'x' + str(py_r) + 'x' + str(pz_r), '-o', 'centerline_rpi.nii.gz'
            ])
            image_centerline = Image('centerline_rpi.nii.gz')
            nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim

        # Clip the centerline image to [0, 1] and overwrite it on disk if any value falls outside that range.
        if np.min(image_centerline.data) < 0 or np.max(image_centerline.data) > 1:
            image_centerline.data[image_centerline.data < 0] = 0
            image_centerline.data[image_centerline.data > 1] = 1
            image_centerline.save()

        # 2. extract fitting of the centerline, and its derivatives
        img_ctl = Image('centerline_rpi.nii.gz')
        centerline = _get_centerline(img_ctl, self.param_centerline, verbose)
        number_of_points = centerline.number_of_points

        # ==========================================================================================
        logger.info('Create the straight space and the safe zone')
        # 3. compute length of centerline
        # compute the length of the spinal cord based on fitted centerline and size of centerline in z direction

        # Computation of the safe zone.
        # The safe zone is defined as the length of the spinal cord for which an axial segmentation will be
        # complete. The safe length (to remove) is computed using the safe radius (given as parameter) and the
        # angle of the last centerline point with the inferior-superior direction.
        # Formula: Ls = Rs * sin(angle)
        # Calculate Ls for both edges and remove appropriate number of centerline points
        radius_safe = 0.0  # mm (0 disables the safe-zone masking applied further down)

        # inferior edge
        u = centerline.derivatives[0]
        v = np.array([0, 0, -1])
        angle_inferior = np.arctan2(np.linalg.norm(np.cross(u, v)), np.dot(u, v))
        length_safe_inferior = radius_safe * np.sin(angle_inferior)

        # superior edge
        u = centerline.derivatives[-1]
        v = np.array([0, 0, 1])
        angle_superior = np.arctan2(np.linalg.norm(np.cross(u, v)), np.dot(u, v))
        length_safe_superior = radius_safe * np.sin(angle_superior)

        # remove points
        inferior_bound = bisect.bisect(centerline.progressive_length, length_safe_inferior) - 1
        superior_bound = centerline.number_of_points - bisect.bisect(
            centerline.progressive_length_inverse, length_safe_superior)

        z_centerline = centerline.points[:, 2]
        length_centerline = centerline.length
        size_z_centerline = z_centerline[-1] - z_centerline[0]

        # compute the size factor between initial centerline and straight bended centerline
        factor_curved_straight = length_centerline / size_z_centerline
        middle_slice = (z_centerline[0] + z_centerline[-1]) / 2.0

        bound_curved = [z_centerline[inferior_bound], z_centerline[superior_bound]]
        bound_straight = [
            (z_centerline[inferior_bound] - middle_slice) * factor_curved_straight + middle_slice,
            (z_centerline[superior_bound] - middle_slice) * factor_curved_straight + middle_slice,
        ]

        logger.info('Length of spinal cord: {}'.format(length_centerline))
        logger.info('Size of spinal cord in z direction: {}'.format(size_z_centerline))
        logger.info('Ratio length/size: {}'.format(factor_curved_straight))
        logger.info('Safe zone boundaries (curved space): {}'.format(bound_curved))
        logger.info('Safe zone boundaries (straight space): {}'.format(bound_straight))

        # 4. compute and generate straight space
        # points along curved centerline are already regularly spaced.
        # calculate position of points along straight centerline

        # Create straight NIFTI volumes.
        # ==========================================================================================
        # TODO: maybe this if case is not needed?
        if self.use_straight_reference:
            image_centerline_pad = Image('centerline_rpi.nii.gz')
            nx, ny, nz, nt, px, py, pz, pt = image_centerline_pad.dim

            fname_ref = 'centerline_ref_rpi.nii.gz'
            image_centerline_straight = Image('centerline_ref.nii.gz') \
                .change_orientation("RPI") \
                .save(fname_ref, mutable=True)
            centerline_straight = _get_centerline(image_centerline_straight, self.param_centerline, verbose)
            nx_s, ny_s, nz_s, nt_s, px_s, py_s, pz_s, pt_s = image_centerline_straight.dim

            # Prepare warping fields headers
            hdr_warp = image_centerline_pad.hdr.copy()
            hdr_warp.set_data_dtype('float32')
            hdr_warp_s = image_centerline_straight.hdr.copy()
            hdr_warp_s.set_data_dtype('float32')

            if self.discs_input_filename != "" and self.discs_ref_filename != "":
                # Convert disc labels of the input image to physical coordinates (+ label value) and attach
                # them to the curved centerline.
                discs_input_image = Image('labels_input.nii.gz')
                coord = discs_input_image.getNonZeroCoordinates(sorting='z', reverse_coord=True)
                coord_physical = []
                for c in coord:
                    c_p = discs_input_image.transfo_pix2phys([[c.x, c.y, c.z]]).tolist()[0]
                    c_p.append(c.value)
                    coord_physical.append(c_p)
                centerline.compute_vertebral_distribution(coord_physical)
                centerline.save_centerline(image=discs_input_image, fname_output='discs_input_image.nii.gz')

                # Same for the disc labels of the reference image and the straight centerline.
                discs_ref_image = Image('labels_ref.nii.gz')
                coord = discs_ref_image.getNonZeroCoordinates(sorting='z', reverse_coord=True)
                coord_physical = []
                for c in coord:
                    c_p = discs_ref_image.transfo_pix2phys([[c.x, c.y, c.z]]).tolist()[0]
                    c_p.append(c.value)
                    coord_physical.append(c_p)
                centerline_straight.compute_vertebral_distribution(coord_physical)
                centerline_straight.save_centerline(image=discs_ref_image,
                                                    fname_output='discs_ref_image.nii.gz')

        else:
            logger.info('Pad input volume to account for spinal cord length...')

            start_point, end_point = bound_straight[0], bound_straight[1]
            offset_z = 0

            # if the destination image is resampled, we still create the straight reference space with the
            # native resolution.
            # TODO: Maybe this if case is not needed?
            if intermediate_resampling:
                padding_z = int(np.ceil(1.5 * ((length_centerline - size_z_centerline) / 2.0) / pz_native))
                sct.run([
                    'sct_image', '-i', 'centerline_rpi_native.nii.gz', '-o',
                    'tmp.centerline_pad_native.nii.gz', '-pad', '0,0,' + str(padding_z)
                ])
                # NOTE(review): the un-padded native image is loaded here although the padded one was just
                # written -- confirm this is intended.
                image_centerline_pad = Image('centerline_rpi_native.nii.gz')
                nx, ny, nz, nt, px, py, pz, pt = image_centerline_pad.dim
                start_point_coord_native = image_centerline_pad.transfo_phys2pix([[0, 0, start_point]])[0]
                end_point_coord_native = image_centerline_pad.transfo_phys2pix([[0, 0, end_point]])[0]
                straight_size_x = int(self.xy_size / px)
                straight_size_y = int(self.xy_size / py)
                warp_space_x = [
                    int(np.round(nx / 2)) - straight_size_x,
                    int(np.round(nx / 2)) + straight_size_x
                ]
                warp_space_y = [
                    int(np.round(ny / 2)) - straight_size_y,
                    int(np.round(ny / 2)) + straight_size_y
                ]
                if warp_space_x[0] < 0:
                    warp_space_x[1] += warp_space_x[0] - 2
                    warp_space_x[0] = 0
                if warp_space_y[0] < 0:
                    warp_space_y[1] += warp_space_y[0] - 2
                    warp_space_y[0] = 0

                spec = dict((
                    (0, warp_space_x),
                    (1, warp_space_y),
                    (2, (0, end_point_coord_native[2] - start_point_coord_native[2])),
                ))
                msct_image.spatial_crop(Image("tmp.centerline_pad_native.nii.gz"), spec).save(
                    "tmp.centerline_pad_crop_native.nii.gz")

                fname_ref = 'tmp.centerline_pad_crop_native.nii.gz'
                offset_z = 4
            else:
                fname_ref = 'tmp.centerline_pad_crop.nii.gz'

            nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim
            padding_z = int(np.ceil(1.5 * ((length_centerline - size_z_centerline) / 2.0) / pz)) + offset_z
            image_centerline_pad = pad_image(image_centerline, pad_z_i=padding_z, pad_z_f=padding_z)
            nx, ny, nz = image_centerline_pad.data.shape
            hdr_warp = image_centerline_pad.hdr.copy()
            hdr_warp.set_data_dtype('float32')
            start_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, start_point]])[0]
            end_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, end_point]])[0]

            straight_size_x = int(self.xy_size / px)
            straight_size_y = int(self.xy_size / py)
            warp_space_x = [
                int(np.round(nx / 2)) - straight_size_x,
                int(np.round(nx / 2)) + straight_size_x
            ]
            warp_space_y = [
                int(np.round(ny / 2)) - straight_size_y,
                int(np.round(ny / 2)) + straight_size_y
            ]

            # Keep the crop window inside the padded image.
            if warp_space_x[0] < 0:
                warp_space_x[1] += warp_space_x[0] - 2
                warp_space_x[0] = 0
            if warp_space_x[1] >= nx:
                warp_space_x[1] = nx - 1
            if warp_space_y[0] < 0:
                warp_space_y[1] += warp_space_y[0] - 2
                warp_space_y[0] = 0
            if warp_space_y[1] >= ny:
                warp_space_y[1] = ny - 1

            spec = dict((
                (0, warp_space_x),
                (1, warp_space_y),
                (2, (0, end_point_coord[2] - start_point_coord[2] + offset_z)),
            ))
            image_centerline_straight = msct_image.spatial_crop(image_centerline_pad, spec)

            nx_s, ny_s, nz_s, nt_s, px_s, py_s, pz_s, pt_s = image_centerline_straight.dim
            hdr_warp_s = image_centerline_straight.hdr.copy()
            hdr_warp_s.set_data_dtype('float32')

            if self.template_orientation == 1:
                raise NotImplementedError()

            start_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, start_point]])[0]
            end_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, end_point]])[0]

            number_of_voxel = nx * ny * nz
            logger.debug('Number of voxels: {}'.format(number_of_voxel))

            time_centerlines = time.time()

            # Build a perfectly vertical centerline in the middle of the straight volume, with the same number
            # of points as the curved one.
            coord_straight = np.empty((number_of_points, 3))
            coord_straight[..., 0] = int(np.round(nx_s / 2))
            coord_straight[..., 1] = int(np.round(ny_s / 2))
            coord_straight[..., 2] = np.linspace(0, end_point_coord[2] - start_point_coord[2], number_of_points)
            coord_phys_straight = image_centerline_straight.transfo_pix2phys(coord_straight)
            derivs_straight = np.empty((number_of_points, 3))
            derivs_straight[..., 0] = derivs_straight[..., 1] = 0
            derivs_straight[..., 2] = 1
            dx_straight, dy_straight, dz_straight = derivs_straight.T
            centerline_straight = Centerline(coord_phys_straight[:, 0], coord_phys_straight[:, 1],
                                             coord_phys_straight[:, 2],
                                             dx_straight, dy_straight, dz_straight)

            time_centerlines = time.time() - time_centerlines
            logger.info('Time to generate centerline: {} ms'.format(np.round(time_centerlines * 1000.0)))

        if verbose == 2:
            # Debug figure comparing the curved and straight centerline progressions.
            # TODO: use OO
            import matplotlib.pyplot as plt
            from datetime import datetime
            curved_points = centerline.progressive_length
            straight_points = centerline_straight.progressive_length
            range_points = np.linspace(0, 1, number_of_points)
            dist_curved = np.zeros(number_of_points)
            dist_straight = np.zeros(number_of_points)
            for i in range(1, number_of_points):
                dist_curved[i] = dist_curved[i - 1] + curved_points[i - 1] / centerline.length
                dist_straight[i] = dist_straight[i - 1] + straight_points[i - 1] / centerline_straight.length
            plt.plot(range_points, dist_curved)
            plt.plot(range_points, dist_straight)
            plt.grid(True)
            plt.savefig('fig_straighten_' + datetime.now().strftime("%y%m%d%H%M%S%f") + '.png')
            plt.close()

        # alignment_mode = 'length'
        alignment_mode = 'levels'

        # Look-up table: index of curved centerline point -> index of straight centerline point.
        # Identity by default; refined with the disc labels when they are provided.
        lookup_curved2straight = list(range(centerline.number_of_points))
        if self.discs_input_filename != "":
            # create look-up table curved to straight
            for index in range(centerline.number_of_points):
                disc_label = centerline.l_points[index]
                if alignment_mode == 'length':
                    relative_position = centerline.dist_points[index]
                else:
                    relative_position = centerline.dist_points_rel[index]
                idx_closest = centerline_straight.get_closest_to_absolute_position(
                    disc_label, relative_position,
                    backup_index=index,
                    backup_centerline=centerline_straight,
                    mode=alignment_mode)
                if idx_closest is not None:
                    lookup_curved2straight[index] = idx_closest
                else:
                    lookup_curved2straight[index] = 0
        # Zero-out runs of repeated indexes at both extremities; a lookup value of 0 is treated as "no
        # correspondence" when the warping fields are filled below.
        for p in range(0, len(lookup_curved2straight) // 2):
            if lookup_curved2straight[p] == lookup_curved2straight[p + 1]:
                lookup_curved2straight[p] = 0
            else:
                break
        for p in range(len(lookup_curved2straight) - 1, len(lookup_curved2straight) // 2, -1):
            if lookup_curved2straight[p] == lookup_curved2straight[p - 1]:
                lookup_curved2straight[p] = 0
            else:
                break
        lookup_curved2straight = np.array(lookup_curved2straight)

        # Look-up table: index of straight centerline point -> index of curved centerline point.
        lookup_straight2curved = list(range(centerline_straight.number_of_points))
        if self.discs_input_filename != "":
            for index in range(centerline_straight.number_of_points):
                disc_label = centerline_straight.l_points[index]
                if alignment_mode == 'length':
                    relative_position = centerline_straight.dist_points[index]
                else:
                    relative_position = centerline_straight.dist_points_rel[index]
                # NOTE(review): backup_centerline is the straight centerline in both directions -- looks
                # like it could be a copy/paste; confirm against get_closest_to_absolute_position().
                idx_closest = centerline.get_closest_to_absolute_position(
                    disc_label, relative_position,
                    backup_index=index,
                    backup_centerline=centerline_straight,
                    mode=alignment_mode)
                if idx_closest is not None:
                    lookup_straight2curved[index] = idx_closest
        for p in range(0, len(lookup_straight2curved) // 2):
            if lookup_straight2curved[p] == lookup_straight2curved[p + 1]:
                lookup_straight2curved[p] = 0
            else:
                break
        for p in range(len(lookup_straight2curved) - 1, len(lookup_straight2curved) // 2, -1):
            if lookup_straight2curved[p] == lookup_straight2curved[p - 1]:
                lookup_straight2curved[p] = 0
            else:
                break
        lookup_straight2curved = np.array(lookup_straight2curved)

        # Create volumes containing curved and straight warping fields
        data_warp_curved2straight = np.zeros((nx_s, ny_s, nz_s, 1, 3))
        data_warp_straight2curved = np.zeros((nx, ny, nz, 1, 3))

        # 5. compute transformations
        # Curved and straight images have the same dimensions, so we compute both warping fields at the same
        # time.
        # b. determine which plane of spinal cord centerline it is included
        # sct.printv(nx * ny * nz, nx_s * ny_s * nz_s)

        if self.curved2straight:
            # One axial slice of the straight space per iteration.
            for u in sct_progress_bar(range(nz_s)):
                x_s, y_s, z_s = np.mgrid[0:nx_s, 0:ny_s, u:u + 1]
                indexes_straight = np.array(list(zip(x_s.ravel(), y_s.ravel(), z_s.ravel())))
                physical_coordinates_straight = image_centerline_straight.transfo_pix2phys(indexes_straight)
                nearest_indexes_straight = centerline_straight.find_nearest_indexes(
                    physical_coordinates_straight)
                distances_straight = centerline_straight.get_distances_from_planes(
                    physical_coordinates_straight, nearest_indexes_straight)
                lookup = lookup_straight2curved[nearest_indexes_straight]
                # Voxels too far from their plane, or without correspondence, are flagged as out of range.
                indexes_out_distance_straight = np.logical_or(
                    np.logical_or(distances_straight > self.threshold_distance,
                                  distances_straight < -self.threshold_distance),
                    lookup == 0)
                projected_points_straight = centerline_straight.get_projected_coordinates_on_planes(
                    physical_coordinates_straight, nearest_indexes_straight)
                coord_in_planes_straight = centerline_straight.get_in_plans_coordinates(
                    projected_points_straight, nearest_indexes_straight)

                coord_straight2curved = centerline.get_inverse_plans_coordinates(
                    coord_in_planes_straight, lookup)
                displacements_straight = coord_straight2curved - physical_coordinates_straight
                # Invert Z coordinate as ITK & ANTs physical coordinate system is LPS- (RAI+)
                # while ours is LPI-
                # Refs: https://sourceforge.net/p/advants/discussion/840261/thread/2a1e9307/#fb5a
                #       https://www.slicer.org/wiki/Coordinate_systems
                displacements_straight[:, 2] = -displacements_straight[:, 2]
                displacements_straight[indexes_out_distance_straight] = [100000.0, 100000.0, 100000.0]

                data_warp_curved2straight[indexes_straight[:, 0], indexes_straight[:, 1],
                                          indexes_straight[:, 2], 0, :] = -displacements_straight

        if self.straight2curved:
            # One axial slice of the (padded) curved space per iteration.
            for u in sct_progress_bar(range(nz)):
                x, y, z = np.mgrid[0:nx, 0:ny, u:u + 1]
                indexes = np.array(list(zip(x.ravel(), y.ravel(), z.ravel())))
                physical_coordinates = image_centerline_pad.transfo_pix2phys(indexes)
                nearest_indexes_curved = centerline.find_nearest_indexes(physical_coordinates)
                distances_curved = centerline.get_distances_from_planes(
                    physical_coordinates, nearest_indexes_curved)
                lookup = lookup_curved2straight[nearest_indexes_curved]
                indexes_out_distance_curved = np.logical_or(
                    np.logical_or(distances_curved > self.threshold_distance,
                                  distances_curved < -self.threshold_distance),
                    lookup == 0)
                projected_points_curved = centerline.get_projected_coordinates_on_planes(
                    physical_coordinates, nearest_indexes_curved)
                coord_in_planes_curved = centerline.get_in_plans_coordinates(
                    projected_points_curved, nearest_indexes_curved)

                # Integer-array indexing returns a copy, so the in-place updates below do not modify
                # centerline_straight.points.
                coord_curved2straight = centerline_straight.points[lookup]
                coord_curved2straight[:, 0:2] += coord_in_planes_curved[:, 0:2]
                coord_curved2straight[:, 2] += distances_curved

                displacements_curved = coord_curved2straight - physical_coordinates
                # Same Z inversion as for the curved->straight field.
                displacements_curved[:, 2] = -displacements_curved[:, 2]
                displacements_curved[indexes_out_distance_curved] = [100000.0, 100000.0, 100000.0]

                data_warp_straight2curved[indexes[:, 0], indexes[:, 1], indexes[:, 2], 0, :] = \
                    -displacements_curved

        # Creation of the safe zone based on pre-calculated safe boundaries
        coord_bound_curved_inf = image_centerline_pad.transfo_phys2pix([[0, 0, bound_curved[0]]])
        coord_bound_curved_sup = image_centerline_pad.transfo_phys2pix([[0, 0, bound_curved[1]]])
        coord_bound_straight_inf = image_centerline_straight.transfo_phys2pix([[0, 0, bound_straight[0]]])
        coord_bound_straight_sup = image_centerline_straight.transfo_phys2pix([[0, 0, bound_straight[1]]])

        if radius_safe > 0:
            data_warp_curved2straight[:, :, 0:coord_bound_straight_inf[0][2], 0, :] = 100000.0
            data_warp_curved2straight[:, :, coord_bound_straight_sup[0][2]:, 0, :] = 100000.0
            data_warp_straight2curved[:, :, 0:coord_bound_curved_inf[0][2], 0, :] = 100000.0
            data_warp_straight2curved[:, :, coord_bound_curved_sup[0][2]:, 0, :] = 100000.0

        # Generate warp files as a warping fields
        hdr_warp_s.set_intent('vector', (), '')
        hdr_warp_s.set_data_dtype('float32')
        hdr_warp.set_intent('vector', (), '')
        hdr_warp.set_data_dtype('float32')
        if self.curved2straight:
            img = Nifti1Image(data_warp_curved2straight, None, hdr_warp_s)
            save(img, 'tmp.curve2straight.nii.gz')
            logger.info('Warping field generated: tmp.curve2straight.nii.gz')

        if self.straight2curved:
            img = Nifti1Image(data_warp_straight2curved, None, hdr_warp)
            save(img, 'tmp.straight2curve.nii.gz')
            logger.info('Warping field generated: tmp.straight2curve.nii.gz')

        image_centerline_straight.save(fname_ref)
        if self.curved2straight:
            logger.info('Apply transformation to input image...')
            sct.run([
                'isct_antsApplyTransforms', '-d', '3', '-r', fname_ref, '-i', 'data.nii', '-o',
                'tmp.anat_rigid_warp.nii.gz', '-t', 'tmp.curve2straight.nii.gz', '-n', 'BSpline[3]'
            ], is_sct_binary=True, verbose=verbose)

        if self.accuracy_results:
            time_accuracy_results = time.time()
            # compute the error between the straightened centerline/segmentation and the central vertical line.
            # Ideally, the error should be zero.
            # NOTE(review): this step reads tmp.curve2straight.nii.gz, which is only written when
            # self.curved2straight is true -- confirm callers never combine accuracy_results with
            # curved2straight disabled.
            # Apply deformation to input image
            logger.info('Apply transformation to centerline image...')
            sct.run([
                'isct_antsApplyTransforms', '-d', '3', '-r', fname_ref, '-i', 'centerline.nii.gz', '-o',
                'tmp.centerline_straight.nii.gz', '-t', 'tmp.curve2straight.nii.gz', '-n', 'NearestNeighbor'
            ], is_sct_binary=True, verbose=verbose)
            file_centerline_straight = Image('tmp.centerline_straight.nii.gz', verbose=verbose)
            nx, ny, nz, nt, px, py, pz, pt = file_centerline_straight.dim
            coordinates_centerline = file_centerline_straight.getNonZeroCoordinates(sorting='z')
            # Value-weighted (x, y) centroid of the straightened centerline for each slice.
            mean_coord = []
            for z in range(coordinates_centerline[0].z, coordinates_centerline[-1].z):
                temp_mean = [coord.value for coord in coordinates_centerline if coord.z == z]
                if temp_mean:
                    mean_value = np.mean(temp_mean)
                    mean_coord.append(
                        np.mean([[coord.x * coord.value / mean_value, coord.y * coord.value / mean_value]
                                 for coord in coordinates_centerline if coord.z == z], axis=0))

            # compute error between the straightened centerline and the straight line.
            x0 = file_centerline_straight.data.shape[0] / 2.0
            y0 = file_centerline_straight.data.shape[1] / 2.0
            count_mean = 0
            if number_of_points >= 10:
                # we don't include the four extrema because they are usually messy.
                mean_c = mean_coord[2:-2]
            else:
                mean_c = mean_coord
            for coord_z in mean_c:
                if not np.isnan(np.sum(coord_z)):
                    dist = ((x0 - coord_z[0]) * px) ** 2 + ((y0 - coord_z[1]) * py) ** 2
                    self.mse_straightening += dist
                    dist = np.sqrt(dist)
                    if dist > self.max_distance_straightening:
                        self.max_distance_straightening = dist
                    count_mean += 1
            if count_mean > 0:
                self.mse_straightening = np.sqrt(self.mse_straightening / float(count_mean))
            else:
                # No slice contributed a valid centroid: the error is undefined. Report NaN instead of
                # failing on a division by zero.
                logger.warning('No valid centerline point found to estimate the straightening accuracy.')
                self.mse_straightening = np.nan

            self.elapsed_time_accuracy = time.time() - time_accuracy_results
    finally:
        # come back to the caller's working directory, whether or not the processing succeeded
        os.chdir(curdir)

    # Generate output file (in current folder)
    # TODO: do not uncompress the warping field, it is too time consuming!
    logger.info('Generate output files...')
    if self.curved2straight:
        sct.generate_output_file(
            os.path.join(path_tmp, "tmp.curve2straight.nii.gz"),
            os.path.join(self.path_output, "warp_curve2straight.nii.gz"), verbose)
    if self.straight2curved:
        sct.generate_output_file(
            os.path.join(path_tmp, "tmp.straight2curve.nii.gz"),
            os.path.join(self.path_output, "warp_straight2curve.nii.gz"), verbose)

    # create ref_straight.nii.gz file that can be used by other SCT functions that need a straight reference
    # space
    if self.curved2straight:
        sct.copy(os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
                 os.path.join(self.path_output, "straight_ref.nii.gz"))
        # move straightened input file (tmp.anat_rigid_warp.nii.gz only exists in this branch)
        if fname_output == '':
            fname_straight = sct.generate_output_file(
                os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
                os.path.join(self.path_output, file_anat + "_straight" + ext_anat), verbose)
        else:
            fname_straight = sct.generate_output_file(
                os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
                os.path.join(self.path_output, fname_output), verbose)  # straightened anatomic

    # Remove temporary files
    if remove_temp_files:
        logger.info('Remove temporary files...')
        sct.rmtree(path_tmp)

    if self.accuracy_results:
        logger.info('Maximum x-y error: {} mm'.format(self.max_distance_straightening))
        logger.info('Accuracy of straightening (MSE): {} mm'.format(self.mse_straightening))

    # display elapsed time
    self.elapsed_time = int(np.round(time.time() - start_time))

    return fname_straight