def __init__(self, fname_label, fname_output=None, fname_ref=None, cross_radius=5, dilate=False,
             coordinates=None, verbose=1, vertebral_levels=None, value=None, msg=""):
    """
    Collection of processes that deal with label creation/modification.

    The input label image and the optional reference image are loaded right away; every other
    argument is stored as-is on the instance for the processing methods.

    :param fname_label: path of the input label image.
    :param fname_output: output path, or list of output paths. A list holding exactly one path is
        unwrapped to that path; longer lists are kept as lists.
    :param fname_ref: optional path of a reference image (``self.image_ref`` is None when omitted).
    :param cross_radius: stored as ``self.cross_radius``.
    :param dilate: stored as ``self.dilate``.  # TODO: remove dilate (does not seem to be used)
    :param coordinates: stored as ``self.coordinates``.
    :param verbose: verbosity; also forwarded to the Image constructor.
    :param vertebral_levels: stored as ``self.vertebral_levels``.
    :param value: stored as ``self.value``.
    :param msg: string. message to display to the user.
    """
    self.image_input = Image(fname_label, verbose=verbose)
    self.image_ref = Image(fname_ref, verbose=verbose) if fname_ref is not None else None

    # A one-element list collapses to its single entry; anything else is kept unchanged.
    is_single_path_list = isinstance(fname_output, list) and len(fname_output) == 1
    self.fname_output = fname_output[0] if is_single_path_list else fname_output

    self.cross_radius = cross_radius
    self.vertebral_levels = vertebral_levels
    self.dilate = dilate
    self.coordinates = coordinates
    self.verbose = verbose
    self.value = value
    self.msg = msg
    self.output_image = None
def detect_c2c3_from_file(fname_im, fname_seg, contrast, fname_c2c3=None, verbose=1):
    """
    Detect the posterior edge of the C2-C3 disc and write it as a label file.

    :param fname_im: path of the anatomical image.
    :param fname_seg: path of the spinal cord segmentation of that image.
    :param contrast: image contrast, forwarded unchanged to ``detect_c2c3``.
    :param fname_c2c3: output path. Defaults to 'label_c2c3.nii.gz' in the directory of the input image.
    :param verbose: verbosity forwarded to ``detect_c2c3``.
    :return: path of the written C2-C3 label file.
    """
    logger.info('Load data...')
    img, seg = Image(fname_im), Image(fname_seg)

    # detect_c2c3 works on a copy; the loaded image is still needed below for its path
    label_c2c3 = detect_c2c3(img.copy(), seg, contrast, verbose=verbose)

    logger.info('Generate output file...')
    out_path = fname_c2c3
    if out_path is None:
        out_path = os.path.join(os.path.dirname(img.absolutepath), "label_c2c3.nii.gz")
    label_c2c3.save(out_path)
    return out_path
def check_labels(fname_landmarks, label_type='body'):
    """
    Make sure input labels are consistent.

    Checks (each failure is reported through ``sct.printv(..., 1, 'error')``):
    - for ``label_type == 'body'``: the file holds one or two labels;
    - every label value is an integer;
    - no two labels share the same value.

    Parameters
    ----------
    fname_landmarks: file name of input labels
    label_type: 'body', 'disc'

    Returns
    -------
    labels: the non-zero coordinates of the label image, sorted by value
    """
    sct.printv('\nCheck input labels...')

    # open label file
    image_label = Image(fname_landmarks)
    # -> all labels must be different
    labels = image_label.getNonZeroCoordinates(sorting='value')

    # 'body' needs one or two labels; zero labels is invalid too, as the error message states
    if label_type == 'body' and not 1 <= len(labels) <= 2:
        sct.printv('ERROR: Label file has ' + str(len(labels)) + ' label(s). It must contain one or two labels.', 1, 'error')

    # check if labels are integer
    for label in labels:
        if int(label.value) != label.value:
            sct.printv('ERROR: Label should be integer.', 1, 'error')

    # check if there are duplicates in label values (set size vs list size: O(n) instead of O(n^2) list.count)
    list_values = [label.value for label in labels]
    if len(set(list_values)) != len(list_values):
        sct.printv('ERROR: Found two labels with same value.', 1, 'error')

    return labels
def multicomponent_merge(fname_list):
    """
    Stack several images into one multi-component (vector) image.

    Each input file becomes one component along the 5th axis of the output. Inputs with 2, 3 or 4
    axes are placed into the first four axes (missing axes have size 1). Geometry/header come from
    the first image.

    :param fname_list: non-empty list of image file names sharing the same spatial grid.
    :return: Image (not saved) with float32 data of shape (x, y, z, t, len(fname_list)), vector
        intent, and path suffixed with '_multicomponent'.
    """
    from numpy import zeros
    # WARNING: output multicomponent is not optimal yet, some issues may be related to the use of this function

    im_0 = Image(fname_list[0])
    new_shape = list(im_0.data.shape)
    # Pad up to 4 axes so the component axis is always the 5th one. Only 3D data used to be padded,
    # so 2D inputs crashed on the `[:, :, 0, 0, i]` assignment below.
    while len(new_shape) < 4:
        new_shape.append(1)
    new_shape.append(len(fname_list))
    # Allocate directly in the output type instead of float64 followed by a float32 copy.
    data_out = zeros(tuple(new_shape), dtype='float32')

    for i, fname in enumerate(fname_list):
        dat = Image(fname).data
        if len(dat.shape) == 2:
            data_out[:, :, 0, 0, i] = dat.astype('float32')
        elif len(dat.shape) == 3:
            data_out[:, :, :, 0, i] = dat.astype('float32')
        elif len(dat.shape) == 4:
            data_out[:, :, :, :, i] = dat.astype('float32')

    im_out = im_0.copy()
    im_out.data = data_out
    im_out.hdr.set_intent('vector', (), '')
    im_out.absolutepath = sct.add_suffix(im_out.absolutepath, '_multicomponent')
    return im_out
def remove_label(self, symmetry=False):
    """
    Compare two label images and remove any labels in input image that are not in reference image.

    The symmetry option enables to remove labels from reference image that are not in input image.
    In that case the filtered reference image is saved to ``self.fname_output[1]`` and
    ``self.fname_output`` is reduced to its first entry.

    Requires a reference image (``fname_ref`` given at construction).

    :param symmetry: bool. Also filter the reference image (see above).
    :return: new Image on the input grid, containing only the kept input labels (values rounded to int).
    """
    image_output = msct_image.zeros_like(self.image_input)

    result_coord_input, result_coord_ref = self.remove_label_coord(
        self.image_input.getNonZeroCoordinates(coordValue=True),
        self.image_ref.getNonZeroCoordinates(coordValue=True),
        symmetry)

    for coord in result_coord_input:
        image_output.data[int(coord.x), int(coord.y), int(coord.z)] = int(np.round(coord.value))

    if symmetry:
        # Start from an empty image, like the input side above. Copying the reference image (as before)
        # kept every original label, so the labels discarded by remove_label_coord were never removed.
        image_output_ref = msct_image.zeros_like(self.image_ref)
        for coord in result_coord_ref:
            image_output_ref.data[int(coord.x), int(coord.y), int(coord.z)] = int(np.round(coord.value))
        image_output_ref.absolutepath = self.fname_output[1]
        # The first positional argument of Image.save is the output path, so the storage type
        # has to be given by keyword; otherwise the file would be written to a path named 'minimize_int'.
        image_output_ref.save(dtype='minimize_int')

        self.fname_output = self.fname_output[0]

    return image_output
def label_lesion(self):
    """
    Label the connected regions of the lesion mask and record their label IDs.

    Reads ``self.fname_mask``, writes the labelled image to ``self.fname_label`` and stores the
    non-zero label values in ``self.measure_pd['label']``.
    """
    printv('\nLabel connected regions of the masked image...', self.verbose, 'normal')
    mask = Image(self.fname_mask)
    labelled = mask.copy()
    # `label` is presumably skimage.measure.label — TODO confirm which neighbourhood connectivity=2 means here
    labelled.data = label(mask.data, connectivity=2)
    labelled.save(self.fname_label)

    lesion_ids = [value for value in np.unique(labelled.data) if value]
    self.measure_pd['label'] = lesion_ids
    printv('Lesion count = ' + str(len(self.measure_pd['label'])), self.verbose, 'info')
def extract_slices(self):
    """
    Load the image and its segmentation, re-orient them if needed, and cache their axial slices.

    Fills ``self.dct_im_seg['im']`` and ``self.dct_im_seg['seg']`` with lists of 2D arrays,
    one per index along the third axis of the (possibly re-oriented) image.
    """
    img = Image(self.param.fname_im)
    seg = Image(self.param.fname_seg)
    # bring both volumes to the extraction orientation (e.g. RPI) when they differ
    if self.orientation_im != self.orientation_extraction:
        for volume in (img, seg):
            volume.change_orientation(self.orientation_extraction)

    n_slices = img.dim[2]
    im_slices = [img.data[:, :, iz] for iz in range(n_slices)]
    seg_slices = [seg.data[:, :, iz] for iz in range(n_slices)]
    self.dct_im_seg['im'] = im_slices
    self.dct_im_seg['seg'] = seg_slices
def measure(self):
    """
    Measure statistics of every lesion of the labelled lesion image ``self.fname_label``.

    For each non-zero label value: volume, length and diameter (``_measure_volume``,
    ``_measure_length``, ``_measure_diameter``, rows selected through ``self.measure_pd``).
    If ``self.path_template`` is set: per-lesion and total lesion distribution, using the
    vertebral level image ``self.path_levels`` and the atlas ROI files of ``self.atlas_roi_lst``.
    If ``self.fname_ref`` is set: mean/std of that image within each labelled lesion.
    """
    im_lesion = Image(self.fname_label)
    im_lesion_data = im_lesion.data
    p_lst = im_lesion.dim[4:7] # voxel size (px, py, pz)

    label_lst = [l for l in np.unique(im_lesion_data) if l] # lesion label IDs list

    if self.path_template is not None:
        if os.path.isfile(self.path_levels):
            img_vert = Image(self.path_levels)
            im_vert_data = img_vert.data
            self.vert_lst = [v for v in np.unique(im_vert_data) if v] # list of vertebral levels available in the input image

        else:
            im_vert_data = None
            # NOTE(review): presumably printv(type='error') aborts; otherwise the code below runs with im_vert=None
            printv('ERROR: the file ' + self.path_levels + ' does not exist. Please make sure the template was correctly registered and warped (sct_register_to_template or sct_register_multimodal and sct_warp_template)', type='error')

        # In order to open atlas images only one time
        atlas_data_dct = {} # dict containing the np.array of the registrated atlas
        for fname_atlas_roi in self.atlas_roi_lst:
            # tract ID = integer after the last '_' of the file name, e.g. '..._05.nii.gz' -> 5
            tract_id = int(fname_atlas_roi.split('_')[-1].split('.nii.gz')[0])
            img_cur = Image(fname_atlas_roi)
            img_cur_copy = img_cur.copy()
            atlas_data_dct[tract_id] = img_cur_copy.data
            del img_cur

    # one row per slice (axis 2), one column per lesion
    self.volumes = np.zeros((im_lesion.dim[2], len(label_lst)))

    # iteration across each lesion to measure statistics
    for lesion_label in label_lst:
        # boolean mask of the current lesion only
        im_lesion_data_cur = np.copy(im_lesion_data == lesion_label)
        printv('\nMeasures on lesion #' + str(lesion_label) + '...', self.verbose, 'normal')

        label_idx = self.measure_pd[self.measure_pd.label == lesion_label].index
        self._measure_volume(im_lesion_data_cur, p_lst, label_idx)
        self._measure_length(im_lesion_data_cur, p_lst, label_idx)
        self._measure_diameter(im_lesion_data_cur, p_lst, label_idx)

        # compute lesion distribution for each lesion
        if self.path_template is not None:
            self._measure_eachLesion_distribution(lesion_id=lesion_label, atlas_data=atlas_data_dct, im_vert=im_vert_data, im_lesion=im_lesion_data_cur, p_lst=p_lst)

    if self.path_template is not None:
        # compute total lesion distribution (all lesions merged into one boolean mask)
        self._measure_totLesion_distribution(im_lesion=np.copy(im_lesion_data > 0), atlas_data=atlas_data_dct, im_vert=im_vert_data, p_lst=p_lst)

    if self.fname_ref is not None:
        # Compute mean and std value in each labeled lesion
        self._measure_within_im(im_lesion=im_lesion_data, im_ref=Image(self.fname_ref).data, label_lst=label_lst)
def __next__(self):
    """
    Return the next frame of the sequence.

    For ``k = self.iteration``, the frame is a new Image built from ``self`` whose data is scaled by
    ``k / self.num_of_frames`` and whose ``file_name`` becomes 'tmp.<file_name>_<k>'.

    :return: tuple (frame, counter), where counter is ``self.iteration`` AFTER the increment.
    :raises StopIteration: once ``self.iteration`` exceeds ``self.num_of_frames``.
    """
    if self.iteration > self.num_of_frames:
        raise StopIteration()

    frame_idx = self.iteration
    frame = Image(self)
    sct.printv("Iteration #" + str(frame_idx))
    frame.data *= float(frame_idx) / float(self.num_of_frames)
    frame.file_name = "tmp." + frame.file_name + "_" + str(frame_idx)
    self.iteration = frame_idx + 1
    # NOTE(review): the second element is the incremented counter, not the index of this frame
    return frame, self.iteration
def get_im_from_list(self, data):
    """
    Wrap ``data`` into an Image with adjusted in-plane resolution and orientation.

    Sets pixdim[1] and pixdim[2] to ``self.param_data.axial_res``, writes the image to
    'im_to_orient.nii.gz' in the current directory, then applies msct_image.change_orientation
    with 'IRP' and then with 'PIL' (inverse=True).

    :param data: anything accepted by the Image constructor (presumably a numpy array — TODO confirm).
    :return: the re-oriented Image.
    """
    img = Image(data)
    # set in-plane pixel dimensions
    for axis in (1, 2):
        img.hdr.structarr['pixdim'][axis] = self.param_data.axial_res
    # NOTE(review): side effect — writes a file into the current working directory
    img.save('im_to_orient.nii.gz')
    # TODO explain this quirk
    img = msct_image.change_orientation(img, 'IRP')
    return msct_image.change_orientation(img, 'PIL', inverse=True)
def resample_image(fname, suffix='_resampled.nii.gz', binary=False, npx=0.3, npy=0.3, thr=0.0, interpolation='spline'):
    """
    Resample an image in-plane to (npx, npy) with sct_resample, keeping the through-plane resolution.

    The image is processed in RPI orientation. Its in-plane pixel size is compared to (npx, npy) with
    2-decimal precision; if it differs, the image is resampled (2D data is temporarily made 3D) and
    the result is brought back to the native orientation.

    :param fname: name of the image file to be resampled
    :param suffix: suffix added to the original fname after resampling
    :param binary: boolean, image is binary or not (forces nearest-neighbour interpolation)
    :param npx: new pixel size in the x direction
    :param npy: new pixel size in the y direction
    :param thr: if the image is binary, it will be thresholded at thr (default=0) after the resampling
    :param interpolation: type of interpolation used for the resampling
    :return: file name after resampling (or original fname if it was already in the correct resolution)
    """
    fname_input = fname
    im_in = Image(fname)
    orientation = im_in.orientation
    if orientation != 'RPI':
        # Image.change_orientation works in place and returns the image. The previous call passed the
        # image itself as the orientation argument, so every non-RPI input failed here.
        fname = im_in.change_orientation('RPI', generate_path=True).save().absolutepath

    # dimensions of the RPI image, i.e. of the image that is actually resampled
    nx, ny, nz, nt, px, py, pz, pt = im_in.dim

    if np.round(px, 2) == np.round(npx, 2) and np.round(py, 2) == np.round(npy, 2):
        sct.printv('Image resolution already ' + str(npx) + 'x' + str(npy) + 'x' + str(pz))
        # Nothing to resample: return the untouched input file, as documented.
        return fname_input

    name_resample = sct.extract_fname(fname)[1] + suffix
    if binary:
        interpolation = 'nn'

    fname_to_resample = fname
    if nz == 1:
        # when data is 2d: we convert it to a 3d image in order to avoid nipy problem of conversion nifti-->nipy with 2d data.
        # Write it to a new file: concatenating onto `fname` itself overwrote the user's input when it was already RPI.
        fname_to_resample = sct.add_suffix(fname, '_3d')
        sct.run(['sct_image', '-i', ','.join([fname, fname]), '-concat', 'z', '-o', fname_to_resample])

    sct.run(['sct_resample', '-i', fname_to_resample, '-mm', str(npx) + 'x' + str(npy) + 'x' + str(pz), '-o', name_resample, '-x', interpolation])

    if nz == 1:  # when input data was 2d: re-convert data 3d-->2d
        sct.run(['sct_image', '-i', name_resample, '-split', 'z'])
        im_split = Image(name_resample.split('.nii.gz')[0] + '_Z0000.nii.gz')
        im_split.save(name_resample)

    if binary:
        sct.run(['sct_maths', '-i', name_resample, '-bin', str(thr), '-o', name_resample])

    if orientation != 'RPI':
        name_resample = Image(name_resample) \
            .change_orientation(orientation, generate_path=True) \
            .save() \
            .absolutepath

    return name_resample
def test_get_centerline_optic():
    """
    Test get_centerline with algo_fitting='optic' on the t2 testing data.

    NaN and inf values are injected into the image before extraction. The resulting centerline is
    compared to the one fitted (bspline) on the ground-truth cord segmentation: the norm of the
    difference between their sorted coordinates must be below 3.5.
    """
    fname_t2 = os.path.join(sct.__sct_dir__, 'sct_testing_data/t2/t2.nii.gz')  # install: sct_download_data -d sct_testing_data
    img_t2 = Image(fname_t2)
    # Add non-numerical values at the top corner of the image for testing purpose
    img_t2.change_type('float32')
    img_t2.data[0, 0, 0] = np.nan
    img_t2.data[1, 0, 0] = np.inf
    img_out, _, _ = get_centerline(img_t2, algo_fitting='optic', contrast='t2', minmax=False, verbose=VERBOSE)
    # Open ground truth segmentation and compare
    fname_t2_seg = os.path.join(sct.__sct_dir__, 'sct_testing_data/t2/t2_seg.nii.gz')
    img_seg_out, _, _ = get_centerline(Image(fname_t2_seg), algo_fitting='bspline', minmax=False, verbose=VERBOSE)
    assert np.linalg.norm(find_and_sort_coord(img_seg_out) - find_and_sort_coord(img_out)) < 3.5
def compute_texture(self):
    """
    Compute GLCM texture metrics voxel-wise inside the segmentation and save one image per metric.

    A voxel of an axial slice of ``self.dct_im_seg`` is processed only if its whole
    (2*distance+1) x (2*distance+1) window lies inside the slice and inside the segmentation.
    For such a voxel, a GLCM is computed for every angle of ``self.param_glcm.angle`` and every metric
    of ``self.metric_lst`` is evaluated on it. Each metric image is saved in the current directory as
    '<input file name>_<metric><ext>', and that name is stored in ``self.fname_metric_lst[metric]``.
    """

    offset = int(self.param_glcm.distance)

    sct.printv('\nCompute texture metrics...', self.param.verbose, 'normal')

    # open image and re-orient it to RPI if needed
    im_tmp = Image(self.param.fname_im)
    if self.orientation_im != self.orientation_extraction:
        im_tmp.change_orientation(self.orientation_extraction)

    # one float64 output image per metric, zero-initialized on the (re-oriented) input grid
    dct_metric = {}
    for m in self.metric_lst:
        im_2save = msct_image.zeros_like(im_tmp, dtype='float64')
        dct_metric[m] = im_2save
        # dct_metric[m] = Image(self.fname_metric_lst[m])

    with tqdm.tqdm() as pbar:
        for im_z, seg_z, zz in zip(self.dct_im_seg['im'], self.dct_im_seg['seg'], range(len(self.dct_im_seg['im']))):
            for xx in range(im_z.shape[0]):
                for yy in range(im_z.shape[1]):
                    # skip voxels outside the segmentation
                    if not seg_z[xx, yy]:
                        continue
                    if xx < offset or yy < offset:
                        continue
                    if xx > (im_z.shape[0] - offset - 1) or yy > (im_z.shape[1] - offset - 1):
                        continue  # to check if the whole glcm_window is in the axial_slice
                    if False in np.unique(seg_z[xx - offset: xx + offset + 1, yy - offset: yy + offset + 1]):
                        continue  # to check if the whole glcm_window is in the mask of the axial_slice

                    glcm_window = im_z[xx - offset: xx + offset + 1, yy - offset: yy + offset + 1]
                    # NOTE(review): cast to uint8 — intensities outside 0-255 are not preserved; confirm input range
                    glcm_window = glcm_window.astype(np.uint8)

                    dct_glcm = {}
                    for a in self.param_glcm.angle.split(','):  # compute the GLCM for self.param_glcm.distance and for each self.param_glcm.angle
                        dct_glcm[a] = greycomatrix(glcm_window,
                                                   [self.param_glcm.distance], [np.radians(int(a))],
                                                   symmetric=self.param_glcm.symmetric,
                                                   normed=self.param_glcm.normed)

                    # metric names are split on '_': field 0 is the GLCM property, field 2 the angle key
                    for m in self.metric_lst:  # compute the GLCM property (m.split('_')[0]) of the voxel xx,yy,zz
                        dct_metric[m].data[xx, yy, zz] = greycoprops(dct_glcm[m.split('_')[2]], m.split('_')[0])[0][0]

            pbar.set_postfix(pos="{}/{}".format(zz, len(self.dct_im_seg["im"])))
            pbar.update(1)

    for m in self.metric_lst:
        fname_out = sct.add_suffix(''.join(sct.extract_fname(self.param.fname_im)[1:]), '_' + m)
        dct_metric[m].save(fname_out)
        self.fname_metric_lst[m] = fname_out
def project_labels_on_spinalcord(fname_label, fname_seg):
    """
    Project labels orthogonally on the spinal cord centerline.

    For each label (coordinates averaged per label value), the closest point (Euclidean distance in
    physical space) of the bspline-fitted centerline of the segmentation receives the label value.

    :param fname_label: file name of labels
    :param fname_seg: file name of cord segmentation (could also be of centerline)
    :return: file name of projected labels (label file name + '_projected'); the image lies in the space
        of the segmentation and is saved in the segmentation's native orientation
    """
    # build output name
    fname_label_projected = sct.add_suffix(fname_label, "_projected")
    # open labels and segmentation
    im_label = Image(fname_label).change_orientation("RPI")
    im_seg = Image(fname_seg)
    native_orient = im_seg.orientation
    im_seg.change_orientation("RPI")

    # smooth centerline and return fitted coordinates in voxel space
    _, arr_ctl, _ = get_centerline(im_seg, algo_fitting='bspline')
    x_centerline_fit, y_centerline_fit, z_centerline = arr_ctl
    # convert pixel into physical coordinates: one row (x, y, z) per centerline point
    centerline_phys = np.array(
        [im_seg.transfo_pix2phys([[x, y, z]])[0]
         for x, y, z in zip(x_centerline_fit, y_centerline_fit, z_centerline)])

    # get center of mass of label
    labels = im_label.getCoordinatesAveragedByValue()

    # initialize image of projected labels. Note that we use the space of the seg (not label).
    im_label_projected = msct_image.zeros_like(im_seg, dtype=np.uint8)

    for label in labels:
        # convert pixel into physical coordinates for the label
        label_phys = np.asarray(im_label.transfo_pix2phys([[label.x, label.y, label.z]])[0])
        # distance between the label and every centerline point, vectorized over the points
        distance_centerline = np.linalg.norm(centerline_phys - label_phys, axis=1)
        min_phys = centerline_phys[np.argmin(distance_centerline)]
        # back to voxel space; round + int so the result is always a valid array index
        minx, miny, minz = (int(np.round(v)) for v in im_seg.transfo_phys2pix([min_phys.tolist()])[0])
        # NOTE(review): output is uint8, so label values above 255 cannot be stored as-is
        im_label_projected.data[minx, miny, minz] = label.value

    # re-orient projected labels to native orientation and save
    im_label_projected.change_orientation(native_orient).save(fname_label_projected)
    return fname_label_projected
def test_segment():
    """
    Check deepseg_lesion.segment_3d on a synthetic volume.

    Every axial slice holds a ring of intensity 2400 ("CSF") around a disk of intensity 500 ("SC");
    a block of intensity 1000 is the fake lesion. The lesion must be detected, and nothing outside
    the 1000-valued voxels may be segmented.
    """
    contrast_test = 't2'
    model_path = os.path.join(sct.__sct_dir__, 'data', 'deepseg_lesion_models', '{}_lesion.h5'.format(contrast_test))

    # create fake data: the same axial pattern is broadcast to every slice
    xx, yy = np.mgrid[:48, :48]
    circle = (xx - 24) ** 2 + (yy - 24) ** 2
    axial_pattern = np.logical_and(circle < 400, circle >= 200) * 2400  # CSF
    axial_pattern = axial_pattern + (circle < 200) * 500  # SC
    data = np.zeros((48, 48, 96))
    data += axial_pattern[:, :, np.newaxis]
    data[16:22, 16:22, 64:90] = 1000  # fake lesion

    affine = np.eye(4)
    nii = nib.nifti1.Nifti1Image(data, affine)
    img = Image(data, hdr=nii.header, dim=nii.header.get_data_shape())

    seg = deepseg_lesion.segment_3d(model_path, contrast_test, img.copy())

    assert np.any(seg.data[16:22, 16:22, 64:90])  # check if lesion detected
    assert not np.any(seg.data[img.data != 1000])  # check if no FP
def heatmap2optic(fname_heatmap, lambda_value, fname_out, z_max, algo='dpdt'):
    """
    Run OptiC on the heatmap computed by CNN_1.

    Runs ``isct_spine_detect`` on the heatmap, converts the resulting '<prefix>_ctr.hdr' centerline to
    ``fname_out`` and, when ``z_max`` is given, zeroes every slice from index ``z_max`` upward.

    :param fname_heatmap: NIfTI heatmap file. Its name without the '.nii' extension is used as the
        OptiC input/output prefix.
    :param lambda_value: value passed to OptiC's ``-lambda`` option.
    :param fname_out: output centerline file.
    :param z_max: first slice index (axis 2) to zero out, or None to keep the whole centerline.
    :param algo: value passed to OptiC's ``-ctype`` option.
    """
    import nibabel as nib
    # NOTE: this changes the environment of the whole process (and of child processes started afterwards)
    os.environ["FSLOUTPUTTYPE"] = "NIFTI_PAIR"

    # strip only the final '.nii' / '.nii.gz' extension, not the first '.nii' found anywhere in the path
    optic_input = fname_heatmap.rsplit('.nii', 1)[0]

    # Argument list (as in the other sct.run calls) so paths with spaces/quotes are forwarded verbatim
    cmd_optic = ['isct_spine_detect',
                 '-ctype={}'.format(algo),
                 '-lambda={}'.format(str(lambda_value)),
                 'NONE', optic_input, optic_input]
    sct.run(cmd_optic, verbose=1)

    optic_hdr_filename = optic_input + '_ctr.hdr'
    img = nib.load(optic_hdr_filename)
    nib.save(img, fname_out)

    # crop the centerline if z_max < data.shape[2] and -brain == 1
    if z_max is not None:
        sct.printv('Cropping brain section.')
        ctr_nii = Image(fname_out)
        ctr_nii.data[:, :, z_max:] = 0
        ctr_nii.save()
def register_landmarks(fname_src, fname_dest, dof, fname_affine='affine.txt', verbose=1, path_qc=None):
    """
    Register two NIFTI volumes containing landmarks

    Landmarks are averaged per label value and paired in value order, so both images must contain
    the same number of labels. The transformation estimated by getRigidTransformFromLandmarks
    (restricted to ``dof``) is written as an ITK "Insight Transform File" (AffineTransform_double_3_3).

    :param fname_src: fname of source landmarks
    :param fname_dest: fname of destination landmarks
    :param dof: degree of freedom. Separate with "_". Example: Tx_Ty_Tz_Rx_Ry_Sz
    :param fname_affine: output affine transformation
    :param verbose: 0, 1, 2
    :param path_qc: forwarded to getRigidTransformFromLandmarks
    :return: None
    :raises Exception: if the numbers of source and destination landmarks differ
    """
    from spinalcordtoolbox.image import Image

    def _to_itk_world(im, coords):
        # convert NIFTI to ITK world coordinate: physical coordinates with x and y negated
        points = []
        for coord in coords:
            point = im.transfo_pix2phys([[coord.x, coord.y, coord.z]])
            points.append([-point[0][0], -point[0][1], point[0][2]])
        return points

    # open src and dest labels (landmarks are sorted by value)
    im_src = Image(fname_src)
    coord_src = im_src.getCoordinatesAveragedByValue()
    im_dest = Image(fname_dest)
    coord_dest = im_dest.getCoordinatesAveragedByValue()

    # Reorganize landmarks
    points_src = _to_itk_world(im_src, coord_src)
    points_dest = _to_itk_world(im_dest, coord_dest)

    # display
    sct.printv('Labels src: ' + str(points_src), verbose)
    sct.printv('Labels dest: ' + str(points_dest), verbose)
    sct.printv('Degrees of freedom (dof): ' + dof, verbose)

    if len(coord_src) != len(coord_dest):
        raise Exception('Error: number of source and destination landmarks are not the same, so landmarks cannot be paired.')

    # estimate transformation
    # NOTE(review): an earlier version swapped points_src/points_dest here, because ITK uses inverted
    # transformation matrices (src->dest is defined in dest). The current call keeps src, dest order —
    # confirm this is the intended convention.
    (rotation_matrix, translation_array, _, points_moving_barycenter) = \
        getRigidTransformFromLandmarks(points_src, points_dest, constraints=dof, verbose=verbose, path_qc=path_qc)

    # writing rigid transformation file
    # N.B. x and y dimensions have a negative sign to ensure compatibility between Python and ITK transfo
    parameters = [rotation_matrix[i, j] for i in range(3) for j in range(3)] + \
                 [translation_array[0, k] for k in range(3)]
    with open(fname_affine, 'w') as text_file:
        text_file.write("#Insight Transform File V1.0\n")
        text_file.write("#Transform 0\n")
        text_file.write("Transform: AffineTransform_double_3_3\n")
        text_file.write("Parameters: " + " ".join("%.9f" % value for value in parameters) + "\n")
        text_file.write("FixedParameters: %.9f %.9f %.9f\n" % (points_moving_barycenter[0],
                                                             points_moving_barycenter[1],
                                                             points_moving_barycenter[2]))
def test_integrity(param_test):
    """
    Test integrity of function

    Compares the PMJ label (voxel value 50) of the output file with the one of the ground truth.
    Their Euclidean distance in physical space is reported in ``param_test.output``, and the test
    fails (status 99) when that distance exceeds ``param_test.dist_threshold``.
    """
    def first_label_phys(im):
        # physical coordinates of the first voxel holding the PMJ label value (50)
        vox = np.where(im.data == 50)
        return im.transfo_pix2phys([[vox[0][0], vox[1][0], vox[2][0]]])[0]

    # output written by the tested function: <input>_pmj in the output folder
    file_pmj = os.path.join(param_test.path_output, sct.add_suffix(param_test.file_input, '_pmj'))
    im_pmj = Image(file_pmj)  # prediction
    im_pmj_manual = Image(param_test.fname_gt)  # ground truth

    coord_true = first_label_phys(im_pmj_manual)
    coord_pred = first_label_phys(im_pmj)
    distance_detection = math.sqrt(sum((t - p) ** 2 for t, p in zip(coord_true, coord_pred)))

    param_test.output += 'Computed distance: ' + str(distance_detection)
    param_test.output += 'Distance threshold (if computed Distance higher: fail): ' + str(param_test.dist_threshold)

    if distance_detection > param_test.dist_threshold:
        param_test.status = 99
        param_test.output += '--> FAILED'
    else:
        param_test.output += '--> PASSED'

    # update Panda structure
    param_test.results['distance_detection'] = distance_detection
    return param_test
def dummy_centerline(size_arr=(9, 9, 9), subsampling=1, dilate_ctl=0, hasnan=False, zeroslice=None, outlier=None,
                     orientation='RPI', debug=False):
    """
    Create a dummy Image centerline of small size. Return the full and sub-sampled version along z.
    :param size_arr: tuple: (nx, ny, nz)
    :param subsampling: int >=1. Subsampling factor along z. 1: no subsampling. 2: centerline defined every other z.
    :param dilate_ctl: Dilation of centerline. E.g., if dilate_ctl=1, result will be a square of 3x3 per slice.
                         if dilate_ctl=0, result will be a single pixel per slice.
    :param hasnan: Bool: put nan and inf at voxels (0, 0, 0) and (1, 0, 0) of the full image
                   (the sub-sampled image is not affected).
    :param zeroslice: list int: zero all slices listed in this param (None: no slice zeroed)
    :param outlier: list int: replace the current point with an outlier at the corner of the image for the slices
                    listed (None: no outlier)
    :param orientation: orientation applied to both returned images
    :param debug: Bool: Write temp files
    :return: img (full centerline image), img_sub (only slices 0, subsampling, 2*subsampling, ... kept),
             arr_ctl (uint16 array of shape (3, nz): x, y, z voxel coordinates of the non-dilated centerline,
             not affected by zeroslice/outlier)
    """
    from numpy import poly1d, polyfit
    nx, ny, nz = size_arr
    # define array based on a polynomial function, within X-Z plane, located at y=ny/4, based on the following points:
    x = np.array([round(nx/4.), round(nx/2.), round(3*nx/4.)])
    z = np.array([0, round(nz/2.), nz-1])
    # NOTE(review): a degree-3 fit through 3 points is under-determined (numpy may warn); kept to preserve the shape
    p = poly1d(polyfit(z, x, deg=3))
    data = np.zeros((nx, ny, nz))
    # `int` instead of the `np.int` alias, which was removed in NumPy 1.24
    arr_ctl = np.array([p(range(nz)).astype(int),
                        [round(ny / 4.)] * nz,
                        range(nz)], dtype=np.uint16)
    # Signed copies: adding a negative offset to a uint16 array raises under NumPy 2 casting rules
    ctl_x, ctl_y, ctl_z = (arr_ctl[i].astype(int) for i in range(3))
    # Loop across dilation of centerline. E.g., if dilate_ctl=1, result will be a square of 3x3 per slice.
    offsets = range(-dilate_ctl, dilate_ctl + 1)
    for dx, dy in itertools.product(offsets, offsets):
        data[(ctl_x + dx).tolist(), (ctl_y + dy).tolist(), ctl_z.tolist()] = 1
    # Zero specified slices (the former `is not []` test was always True; indexing with [] is a no-op)
    if zeroslice is not None:
        data[:, :, zeroslice] = 0
    # Add outlier
    if outlier is not None:
        # First, zero all the slice
        data[:, :, outlier] = 0
        # Then, add point in the corner
        data[0, 0, outlier] = 1
    # Create image with default orientation LPI
    affine = np.eye(4)
    nii = nib.nifti1.Nifti1Image(data, affine)
    img = Image(data, hdr=nii.header, dim=nii.header.get_data_shape())
    # subsample data: keep every `subsampling`-th slice, zero elsewhere
    img_sub = img.copy()
    img_sub.data = np.zeros((nx, ny, nz))
    img_sub.data[..., ::subsampling] = data[..., ::subsampling]
    # Add non-numerical values at the top corner of the image
    if hasnan:
        img.data[0, 0, 0] = np.nan
        img.data[1, 0, 0] = np.inf
    # Update orientation
    img.change_orientation(orientation)
    img_sub.change_orientation(orientation)
    if debug:
        img_sub.save('tmp_dummy_seg_'+datetime.now().strftime("%Y%m%d%H%M%S%f")+'.nii.gz')
    return img, img_sub, arr_ctl
def visualize_warp(fname_warp, fname_grid=None, step=3, rm_tmp=True):
    """
    Apply a warping field to a regular grid image to visualize the deformation.

    :param fname_warp: warping field file.
    :param fname_grid: grid image to warp. If None, a grid is generated in a temporary folder on the voxel
        grid of the warp: one line every ``step`` voxels in x and y, repeated on every z slice. The grid is
        then resampled by factor 3x3x1 (nearest neighbour).
    :param step: grid spacing in voxels (only used when ``fname_grid`` is None).
    :param rm_tmp: remove the temporary folder (only created when ``fname_grid`` is None).
    :return: None. The warped grid is written next to ``fname_warp`` as '<grid name>_<warp name><warp ext>'.
    """
    tmp_dir = None  # only created when a grid must be generated
    if fname_grid is None:
        from numpy import zeros
        tmp_dir = sct.tmp_create()
        im_warp = Image(fname_warp)
        # spatial size read from the image itself (previously parsed from FSL `fslhd` output)
        nx, ny, nz = im_warp.data.shape[:3]

        # one grid tile: last row and last column set to 1
        sq = zeros((step, step))
        sq[step - 1] = 1
        sq[:, step - 1] = 1
        dat = zeros((nx, ny, nz))
        # paste the tile on every complete step x step block, for all z slices at once
        for i in range(0, nx - step + 1, step):
            for j in range(0, ny - step + 1, step):
                dat[i:i + step, j:j + step, :] = sq[:, :, None]

        # Absolute paths inside tmp_dir instead of os.chdir: the cwd can no longer be left changed on error
        fname_grid = os.path.join(tmp_dir, 'grid_' + str(step) + '.nii.gz')
        im_grid = Image(param=dat)
        im_grid.hdr = im_warp.hdr
        im_grid.absolutepath = fname_grid
        im_grid.save()
        fname_grid_resample = sct.add_suffix(fname_grid, '_resample')
        sct.run(['sct_resample', '-i', fname_grid, '-f', '3x3x1', '-x', 'nn', '-o', fname_grid_resample])
        fname_grid = fname_grid_resample

    path_warp, file_warp, ext_warp = sct.extract_fname(fname_warp)
    grid_warped = os.path.join(path_warp, sct.extract_fname(fname_grid)[1] + '_' + file_warp + ext_warp)
    sct.run(['sct_apply_transfo', '-i', fname_grid, '-d', fname_grid, '-w', fname_warp, '-o', grid_warped])
    # tmp_dir is None when the caller provided the grid (used to raise NameError with rm_tmp=True)
    if rm_tmp and tmp_dir is not None:
        sct.rmtree(tmp_dir)
def test_crop_image_around_centerline():
    """
    Check deepseg_sc.crop_image_around_centerline on random data.

    The output must have shape (crop_size, crop_size, nz). Its first slice must equal the crop of the
    input computed with deepseg_sc._find_crop_start_end around the centerline voxel of slice 0.
    """
    input_shape = (100, 100, 100)
    crop_size = 20

    # seeded local generator: reproducible input without touching numpy's global random state
    data = np.random.RandomState(0).rand(*input_shape)
    affine = np.eye(4)
    nii = nib.nifti1.Nifti1Image(data, affine)
    img = Image(data, hdr=nii.header, dim=nii.header.get_data_shape())

    ctr, _, _ = dummy_centerline(size_arr=input_shape)

    _, _, _, img_out = deepseg_sc.crop_image_around_centerline(im_in=img.copy(), ctr_in=ctr.copy(), crop_size=crop_size)

    img_in_z0 = img.data[:, :, 0]
    ctr_z0 = np.where(ctr.data[:, :, 0])
    x_ctr_z0, y_ctr_z0 = ctr_z0[0][0], ctr_z0[1][0]
    x_start, x_end = deepseg_sc._find_crop_start_end(x_ctr_z0, crop_size, img.dim[0])
    y_start, y_end = deepseg_sc._find_crop_start_end(y_ctr_z0, crop_size, img.dim[1])
    img_in_z0_crop = img_in_z0[x_start:x_end, y_start:y_end]

    assert img_out.data.shape == (crop_size, crop_size, input_shape[2])
    assert np.allclose(img_in_z0_crop, img_out.data[:, :, 0])
def create_mask(param):
    """
    Create a mask around a centerline, a point, given coordinates, or the center of the FOV.

    ``param.process[0]`` is the method: 'coord' (value 'XxY'), 'point' (label image), 'center' (no value)
    or 'centerline' (centerline image); ``param.process[1]`` holds the value. Work is done in RPI
    orientation inside a temporary folder. The mask is built slice by slice with ``create_mask2d`` and
    saved to ``param.fname_out`` with the orientation and header of ``param.fname_data``.

    :param param: parameter object (fname_data, fname_out, file_prefix, process, shape, size, verbose,
        remove_temp_files, ...).
    """
    # parse argument for method
    method_type = param.process[0]
    # check method val
    if not method_type == 'center':
        method_val = param.process[1]

    # check existence of input files
    if method_type == 'centerline':
        check_file_exist(method_val, param.verbose)

    # Extract path/file/extension
    path_data, file_data, ext_data = extract_fname(param.fname_data)

    # Get output folder and file name
    if param.fname_out == '':
        param.fname_out = os.path.abspath(param.file_prefix + file_data + ext_data)
    # Resolve now: relative paths would otherwise point inside the temporary folder after chdir
    fname_out_abs = os.path.abspath(param.fname_out)

    path_tmp = tmp_create(basename="create_mask")

    printv('\nOrientation:', param.verbose)
    im_input = Image(param.fname_data)
    orientation_input = im_input.orientation
    # header of the input, read before chdir (fname_data may be a relative path)
    header_input = im_input.header
    printv(' ' + orientation_input, param.verbose)

    # copy input data to tmp folder and re-orient to RPI
    Image(param.fname_data).change_orientation("RPI").save(os.path.join(path_tmp, "data_RPI.nii"))
    if method_type == 'centerline':
        Image(method_val).change_orientation("RPI").save(os.path.join(path_tmp, "centerline_RPI.nii"))
    if method_type == 'point':
        Image(method_val).change_orientation("RPI").save(os.path.join(path_tmp, "point_RPI.nii"))

    # go to tmp folder; always come back, even if a step below fails
    curdir = os.getcwd()
    os.chdir(path_tmp)
    try:
        # Get dimensions of data
        im_data = Image('data_RPI.nii')
        nx, ny, nz, nt, px, py, pz, pt = im_data.dim
        printv('\nDimensions:', param.verbose)
        printv(im_data.dim, param.verbose)
        # in case user input 4d data
        if nt != 1:
            printv('WARNING in ' + os.path.basename(__file__) + ': Input image is 4d but output mask will be 3D from first time slice.', param.verbose, 'warning')
            # extract first volume to have 3d reference
            # NOTE(review): empty_like() presumably drops the voxel values, so only the geometry is kept — confirm intended
            nii = empty_like(Image('data_RPI.nii'))
            data3d = nii.data[:, :, :, 0]
            nii.data = data3d
            nii.save('data_RPI.nii')

        if method_type == 'coord':
            # parse to get coordinate
            coord = list(map(int, method_val.split('x')))

        if method_type == 'point':
            # extract coordinate of point
            printv('\nExtract coordinate of point...', param.verbose)
            coord = Image("point_RPI.nii").getNonZeroCoordinates()

        if method_type == 'center':
            # set coordinate at center of FOV
            coord = np.round(float(nx) / 2), np.round(float(ny) / 2)

        if method_type == 'centerline':
            # get name of centerline from user argument
            fname_centerline = 'centerline_RPI.nii'
        else:
            # generate volume with line along Z at coordinates 'coord'
            printv('\nCreate line...', param.verbose)
            fname_centerline = create_line(param, 'data_RPI.nii', coord, nz)

        # create mask
        printv('\nCreate mask...', param.verbose)
        centerline = nibabel.load(fname_centerline)  # open centerline
        # `.header` / `.dataobj` replace get_header() / get_data(), which recent nibabel releases removed
        hdr = centerline.header  # get header
        hdr.set_data_dtype('uint8')  # set imagetype to uint8
        data_centerline = np.asanyarray(centerline.dataobj)  # get centerline
        # if data is 2D, reshape with empty third dimension
        if len(data_centerline.shape) == 2:
            data_centerline = data_centerline.reshape(data_centerline.shape + (1,))
        # set: O(1) membership tests in the per-slice loops below
        z_centerline_not_null = {iz for iz in range(nz) if data_centerline[:, :, iz].any()}
        # get center of mass of the centerline
        cx = [0] * nz
        cy = [0] * nz
        for iz in range(nz):
            if iz in z_centerline_not_null:
                cx[iz], cy[iz] = ndimage.center_of_mass(np.array(data_centerline[:, :, iz]))
        # create 2d masks, one file per slice
        file_mask = 'data_mask'
        for iz in range(nz):
            if iz not in z_centerline_not_null:
                # write an empty nifty volume (this centerline slice is all zeros)
                slice_data = data_centerline[:, :, iz]
            else:
                center = np.array([cx[iz], cy[iz]])
                slice_data = create_mask2d(param, center, param.shape, param.size, im_data=im_data)
            # Write NIFTI volumes
            img = nibabel.Nifti1Image(slice_data, None, hdr)
            nibabel.save(img, (file_mask + str(iz) + '.nii'))

        fname_list = [file_mask + str(iz) + '.nii' for iz in range(nz)]
        im_list = [Image(fname) for fname in fname_list]
        im_out = concat_data(im_list, dim=2).save('mask_RPI.nii.gz')

        im_out.change_orientation(orientation_input)
        im_out.header = header_input
        im_out.save(fname_out_abs)
    finally:
        # come back
        os.chdir(curdir)

    # Remove temporary files
    if param.remove_temp_files == 1:
        printv('\nRemove temporary files...', param.verbose)
        rmtree(path_tmp)

    display_viewer_syntax([param.fname_data, param.fname_out], colormaps=['gray', 'red'], opacities=['', '0.5'])
def fmri_moco(param):
    """
    Motion-correct a 4D fMRI series using group-averaged volumes.

    All intermediate and output files use relative names (e.g. ``fmri.nii``, ``fmri_moco.nii``), so this
    function expects the current working directory to be the folder that holds ``fmri.nii``.

    Steps:
      1. Split the input along T; every volume is re-saved as .nii.gz (#2149).
      2. Partition the volume indices into consecutive groups of ``param.group_size`` (a trailing, smaller
         group holds the remainder) and average each group (suffix ``_mean_<group>``).
      3. Estimate and apply motion correction on the 4D stack of group means (``fmri_averaged_groups.nii``).
      4. If ``group_size == 1`` the corrected stack of means already is the corrected series and is renamed
         to ``fmri_moco.nii``. Otherwise, each group's warping fields are copied for every volume of that
         group into ``mat_final/`` and applied to the full series.
      5. Copy the header of ``fmri.nii`` onto ``fmri_moco.nii`` and write its temporal mean to
         ``fmri_moco_mean.nii``.

    :param param: parameter object. Reads ``fname_data``, ``verbose``, ``group_size`` and ``num_target``;
        sets ``is_sagittal`` and forces ``group_size`` to 1 for sagittal data. The same object is then reused
        (and mutated: ``file_data``, ``file_target``, ``path_out``, ``todo``, ``mat_moco``) as the parameter
        object passed to ``moco.moco``.
    :return: None
    """
    file_data = "fmri.nii"
    mat_final = 'mat_final/'
    ext_mat = 'Warp.nii.gz'  # warping field

    def _mean_fname(label):
        # File name of the averaged volume of group `label`, forced to .nii.gz (see #2149).
        # Shared by the averaging step, the merge step and the moco targets so the three always agree.
        fname_mean = sct.add_suffix(file_data, '_mean_' + str(label))
        if fname_mean.endswith(".nii"):
            fname_mean += ".gz"  # #2149
        return fname_mean

    # Get dimensions of data
    sct.printv('\nGet dimensions of data...', param.verbose)
    im_data = Image(param.fname_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    sct.printv(' ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt), param.verbose)

    # Get orientation: the third axis tells whether slices are sagittal (L/R) or axial (I/S)
    sct.printv('\nData orientation: ' + im_data.orientation, param.verbose)
    if im_data.orientation[2] in 'LR':
        param.is_sagittal = True
        sct.printv(' Treated as sagittal')
    elif im_data.orientation[2] in 'IS':
        param.is_sagittal = False
        sct.printv(' Treated as axial')
    else:
        param.is_sagittal = False
        sct.printv('WARNING: Orientation seems to be neither axial nor sagittal.')

    # Adjust group size in case of sagittal scan
    if param.is_sagittal and param.group_size != 1:
        sct.printv('For sagittal data group_size should be one for more robustness. Forcing group_size=1.', 1,
                   'warning')
        param.group_size = 1

    # Split into T dimension
    sct.printv('\nSplit along T dimension...', param.verbose)
    im_data_split_list = split_data(im_data, 3)
    for im in im_data_split_list:
        x_dirname, x_basename, _ = sct.extract_fname(im.absolutepath)
        # Make further steps slurp the data to avoid too many open files (#2149)
        im.absolutepath = os.path.join(x_dirname, x_basename + ".nii.gz")
        im.save()

    # Consecutive groups of `group_size` volume indices; the last group holds the remaining nt % group_size
    # volumes (if any).
    index_fmri = list(range(nt))
    group_indexes = [index_fmri[start:start + param.group_size] for start in range(0, nt, param.group_size)]
    nb_groups = len(group_indexes)

    # Merge the volumes of each group, then average them
    for iGroup in tqdm(range(nb_groups), unit='iter', unit_scale=False, desc="Merge within groups", ascii=True,
                       ncols=80):
        file_data_merge_i = sct.add_suffix(file_data, '_' + str(iGroup))
        im_fmri_list = [im_data_split_list[index_t] for index_t in group_indexes[iGroup]]
        concat_data(im_fmri_list, 3, squeeze_data=True).save(file_data_merge_i)

        file_data_mean = _mean_fname(iGroup)
        if param.group_size == 1:
            # copy to new file name instead of averaging (faster)
            # note: this is a bandage. Ideally we should skip this entire for loop if g=1
            convert(file_data_merge_i, file_data_mean)
        else:
            # Average Images
            sct.run(['sct_maths', '-i', file_data_merge_i, '-o', file_data_mean, '-mean', 't'], verbose=0)

    # Merge groups means. The output 4D volume will be used for motion correction.
    sct.printv('\nMerging volumes...', param.verbose)
    file_data_groups_means_merge = 'fmri_averaged_groups.nii'
    im_mean_list = [Image(_mean_fname(iGroup)) for iGroup in range(nb_groups)]
    concat_data(im_mean_list, 3).save(file_data_groups_means_merge)

    # Estimate moco
    sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
    sct.printv(' Estimating motion...', param.verbose)
    sct.printv('-------------------------------------------------------------------------------', param.verbose)
    param_moco = param  # NB: alias, not a copy -- `param` itself is modified below
    param_moco.file_data = 'fmri_averaged_groups.nii'
    param_moco.file_target = _mean_fname(param.num_target)
    param_moco.path_out = ''
    param_moco.todo = 'estimate_and_apply'
    param_moco.mat_moco = 'mat_groups'
    file_mat = moco.moco(param_moco)

    if param.group_size == 1:
        # if flag g=1, it means that all images have already been corrected, so we just need to rename the file
        sct.mv('fmri_averaged_groups_moco.nii', 'fmri_moco.nii')
    else:
        # create final mat folder
        sct.create_folder(mat_final)

        # Copy registration matrices: the single transformation estimated for a group is reused for every
        # volume belonging to that group.
        sct.printv('\nCopy transformations...', param.verbose)
        for iGroup in range(nb_groups):
            # all file_mat_z of this t-group (loop-invariant across the volumes of the group)
            list_file_mat_z = file_mat[:, iGroup]
            for index_t in group_indexes[iGroup]:
                for file_mat_z in list_file_mat_z:
                    # copy 'mat_groups/mat.ZXXXXTYYYYWarp.nii.gz' --> 'mat_final/mat.ZXXXXT<index_t>Warp.nii.gz'
                    # ([11:20] strips the 'mat_groups/' prefix and keeps 'mat.ZXXXX')
                    sct.copy(file_mat_z + ext_mat,
                             mat_final + file_mat_z[11:20] + 'T' + str(index_t).zfill(4) + ext_mat)

        # Apply moco on all fmri data
        sct.printv('\n-------------------------------------------------------------------------------',
                   param.verbose)
        sct.printv(' Apply moco', param.verbose)
        sct.printv('-------------------------------------------------------------------------------', param.verbose)
        param_moco.file_data = 'fmri.nii'
        param_moco.file_target = _mean_fname(0)
        param_moco.path_out = ''
        param_moco.mat_moco = mat_final
        param_moco.todo = 'apply'
        moco.moco(param_moco)

    # copy geometric information from header
    # NB: this is required because WarpImageMultiTransform in 2D mode wrongly sets pixdim(3) to "1".
    im_fmri = Image('fmri.nii')
    im_fmri_moco = Image('fmri_moco.nii')
    im_fmri_moco.header = im_fmri.header
    im_fmri_moco.save()

    # Average volumes
    sct.printv('\nAveraging data...', param.verbose)
    sct_maths.main(args=['-i', 'fmri_moco.nii', '-o', 'fmri_moco_mean.nii', '-mean', 't', '-v', '0'])
def straighten(self):
    """
    Straighten the spinal cord and generate the curved<->straight warping fields.

    Steps (everything is done in physical space):
    1. open the input image and the centerline image;
    2. extract a bspline fitting of the centerline, and its derivatives;
    3. compute the length of the centerline;
    4. compute and generate the straight space;
    5. compute the transformations for each voxel of one space (done with matrices --> ~x300 speed-up):
        a. determine which centerline plane the voxel belongs to;
        b. compute the position of the voxel in that plane (X and Y distance from the centerline, along the
           plane);
        c. find the corresponding centerline point in the other space;
        d. find the corresponding position of the voxel in the corresponding plane;
    6. generate the warping fields for each transformation;
    7. write the warping fields and apply them.

    Step 5.b: how to find the corresponding plane?
        The centerline plane of a voxel is the plane of the nearest centerline point. The signed distance
        between the voxel and that plane is then checked: when it exceeds ``self.threshold_distance`` (or when
        the lookup table maps the point to index 0), the displacement is set to 100000.0, i.e. the voxel has
        no valid correspondence.

    Step 5.d: how to match centerline points between both images?
        Both centerlines have the same length, so points are mapped via their position along the curve. With
        the same number of points spread uniformly along the spinal cord (e.g. 1000), the correspondence is
        straightforward. When disc labels are provided, points are matched by vertebral level instead.

    Side effects: works inside a temporary folder created with ``sct.tmp_create`` (the current directory is
    changed while working and restored before writing the outputs); writes ``warp_curve2straight.nii.gz`` and/or
    ``warp_straight2curve.nii.gz`` (depending on ``self.curved2straight`` / ``self.straight2curved``) and, for
    curved->straight, ``straight_ref.nii.gz`` and the straightened image into ``self.path_output``. Sets
    ``self.elapsed_time`` and, when ``self.accuracy_results`` is set, ``self.mse_straightening``,
    ``self.max_distance_straightening`` and ``self.elapsed_time_accuracy``.

    :return: path of the straightened input image (as returned by ``sct.generate_output_file``), or None when
        ``self.curved2straight`` is False (no straightened image is produced in that case).
    """
    # Initialization
    fname_anat = self.input_filename
    fname_centerline = self.centerline_filename
    fname_output = self.output_filename
    remove_temp_files = self.remove_temp_files
    verbose = self.verbose
    interpolation_warp = self.interpolation_warp
    algo_fitting = self.algo_fitting

    # start timer
    start_time = time.time()

    # Extract path/file/extension
    path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)

    path_tmp = sct.tmp_create(basename="straighten_spinalcord", verbose=verbose)

    # Copying input data to tmp folder
    sct.printv('\nCopy files to tmp folder...', verbose)
    Image(fname_anat).save(os.path.join(path_tmp, "data.nii"))
    Image(fname_centerline).save(os.path.join(path_tmp, "centerline.nii.gz"))

    if self.use_straight_reference:
        Image(self.centerline_reference_filename).save(os.path.join(path_tmp, "centerline_ref.nii.gz"))
    if self.discs_input_filename != '':
        Image(self.discs_input_filename).save(os.path.join(path_tmp, "labels_input.nii.gz"))
    if self.discs_ref_filename != '':
        Image(self.discs_ref_filename).save(os.path.join(path_tmp, "labels_ref.nii.gz"))

    # go to tmp folder: all relative file names below live in path_tmp until os.chdir(curdir)
    curdir = os.getcwd()
    os.chdir(path_tmp)

    # Change orientation of the input centerline into RPI
    image_centerline = Image("centerline.nii.gz").change_orientation("RPI").save("centerline_rpi.nii.gz",
                                                                               mutable=True)

    # Get dimension
    nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim

    # speed_factor != 1 resamples the centerline to a voxel size scaled by speed_factor in x, y and z
    if self.speed_factor != 1.0:
        intermediate_resampling = True
        px_r, py_r, pz_r = px * self.speed_factor, py * self.speed_factor, pz * self.speed_factor
    else:
        intermediate_resampling = False

    if intermediate_resampling:
        sct.mv('centerline_rpi.nii.gz', 'centerline_rpi_native.nii.gz')
        pz_native = pz
        # TODO: remove system call
        sct.run(['sct_resample', '-i', 'centerline_rpi_native.nii.gz', '-mm',
                 str(px_r) + 'x' + str(py_r) + 'x' + str(pz_r), '-o', 'centerline_rpi.nii.gz'])
        image_centerline = Image('centerline_rpi.nii.gz')
        nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim

    # Clip centerline values to [0, 1]
    if np.min(image_centerline.data) < 0 or np.max(image_centerline.data) > 1:
        image_centerline.data[image_centerline.data < 0] = 0
        image_centerline.data[image_centerline.data > 1] = 1
        image_centerline.save()

    # 2. extract bspline fitting of the centerline, and its derivatives
    img_ctl = Image('centerline_rpi.nii.gz')
    centerline = _get_centerline(img_ctl, algo_fitting, self.degree, verbose)
    number_of_points = centerline.number_of_points

    # ==========================================================================================
    logger.info('Create the straight space and the safe zone')
    # 3. compute length of centerline
    # compute the length of the spinal cord based on fitted centerline and size of centerline in z direction

    # Computation of the safe zone.
    # The safe zone is defined as the length of the spinal cord for which an axial segmentation will be complete.
    # The safe length (to remove) is computed using the safe radius (given as parameter) and the angle of the
    # last centerline point with the inferior-superior direction. Formula: Ls = Rs * sin(angle)
    # Calculate Ls for both edges and remove appropriate number of centerline points
    radius_safe = 0.0  # mm

    # inferior edge
    u = centerline.derivatives[0]
    v = np.array([0, 0, -1])

    angle_inferior = np.arctan2(np.linalg.norm(np.cross(u, v)), np.dot(u, v))
    length_safe_inferior = radius_safe * np.sin(angle_inferior)

    # superior edge
    u = centerline.derivatives[-1]
    v = np.array([0, 0, 1])
    angle_superior = np.arctan2(np.linalg.norm(np.cross(u, v)), np.dot(u, v))
    length_safe_superior = radius_safe * np.sin(angle_superior)

    # remove points
    inferior_bound = bisect.bisect(centerline.progressive_length, length_safe_inferior) - 1
    superior_bound = centerline.number_of_points - bisect.bisect(centerline.progressive_length_inverse,
                                                                 length_safe_superior)

    z_centerline = centerline.points[:, 2]
    length_centerline = centerline.length
    size_z_centerline = z_centerline[-1] - z_centerline[0]

    # compute the size factor between initial centerline and straight bended centerline
    factor_curved_straight = length_centerline / size_z_centerline
    middle_slice = (z_centerline[0] + z_centerline[-1]) / 2.0

    bound_curved = [z_centerline[inferior_bound], z_centerline[superior_bound]]
    bound_straight = [(z_centerline[inferior_bound] - middle_slice) * factor_curved_straight + middle_slice,
                      (z_centerline[superior_bound] - middle_slice) * factor_curved_straight + middle_slice]

    logger.info('Length of spinal cord: {}'.format(length_centerline))
    logger.info('Size of spinal cord in z direction: {}'.format(size_z_centerline))
    logger.info('Ratio length/size: {}'.format(factor_curved_straight))
    logger.info('Safe zone boundaries (curved space): {}'.format(bound_curved))
    logger.info('Safe zone boundaries (straight space): {}'.format(bound_straight))

    # 4. compute and generate straight space
    # points along curved centerline are already regularly spaced.
    # calculate position of points along straight centerline

    # Create straight NIFTI volumes.
    # ==========================================================================================
    # TODO: maybe this if case is not needed?
    if self.use_straight_reference:
        # The straight space is given by a user-provided reference centerline
        image_centerline_pad = Image('centerline_rpi.nii.gz')
        nx, ny, nz, nt, px, py, pz, pt = image_centerline_pad.dim

        fname_ref = 'centerline_ref_rpi.nii.gz'
        image_centerline_straight = Image('centerline_ref.nii.gz') \
            .change_orientation("RPI") \
            .save(fname_ref, mutable=True)
        centerline_straight = _get_centerline(image_centerline_straight, algo_fitting, self.degree, verbose)
        nx_s, ny_s, nz_s, nt_s, px_s, py_s, pz_s, pt_s = image_centerline_straight.dim

        # Prepare warping fields headers
        hdr_warp = image_centerline_pad.hdr.copy()
        hdr_warp.set_data_dtype('float32')
        hdr_warp_s = image_centerline_straight.hdr.copy()
        hdr_warp_s.set_data_dtype('float32')

        if self.discs_input_filename != "" and self.discs_ref_filename != "":
            # Attach vertebral levels (from disc labels, in physical coordinates) to both centerlines
            discs_input_image = Image('labels_input.nii.gz')
            coord = discs_input_image.getNonZeroCoordinates(sorting='z', reverse_coord=True)
            coord_physical = []
            for c in coord:
                c_p = discs_input_image.transfo_pix2phys([[c.x, c.y, c.z]]).tolist()[0]
                c_p.append(c.value)
                coord_physical.append(c_p)
            centerline.compute_vertebral_distribution(coord_physical)
            centerline.save_centerline(image=discs_input_image, fname_output='discs_input_image.nii.gz')

            discs_ref_image = Image('labels_ref.nii.gz')
            coord = discs_ref_image.getNonZeroCoordinates(sorting='z', reverse_coord=True)
            coord_physical = []
            for c in coord:
                c_p = discs_ref_image.transfo_pix2phys([[c.x, c.y, c.z]]).tolist()[0]
                c_p.append(c.value)
                coord_physical.append(c_p)
            centerline_straight.compute_vertebral_distribution(coord_physical)
            centerline_straight.save_centerline(image=discs_ref_image, fname_output='discs_ref_image.nii.gz')

    else:
        # The straight space is built from the input: pad along z to fit the (longer) straightened cord
        logger.info('Pad input volume to account for spinal cord length...')

        start_point, end_point = bound_straight[0], bound_straight[1]
        offset_z = 0

        # if the destination image is resampled, we still create the straight reference space with the native
        # resolution.
        # TODO: Maybe this if case is not needed?
        if intermediate_resampling:
            padding_z = int(np.ceil(1.5 * ((length_centerline - size_z_centerline) / 2.0) / pz_native))
            sct.run(['sct_image', '-i', 'centerline_rpi_native.nii.gz', '-o', 'tmp.centerline_pad_native.nii.gz',
                     '-pad', '0,0,' + str(padding_z)])
            image_centerline_pad = Image('centerline_rpi_native.nii.gz')
            nx, ny, nz, nt, px, py, pz, pt = image_centerline_pad.dim
            start_point_coord_native = image_centerline_pad.transfo_phys2pix([[0, 0, start_point]])[0]
            end_point_coord_native = image_centerline_pad.transfo_phys2pix([[0, 0, end_point]])[0]
            straight_size_x = int(self.xy_size / px)
            straight_size_y = int(self.xy_size / py)
            warp_space_x = [int(np.round(nx / 2)) - straight_size_x, int(np.round(nx / 2)) + straight_size_x]
            warp_space_y = [int(np.round(ny / 2)) - straight_size_y, int(np.round(ny / 2)) + straight_size_y]
            if warp_space_x[0] < 0:
                warp_space_x[1] += warp_space_x[0] - 2
                warp_space_x[0] = 0
            if warp_space_y[0] < 0:
                warp_space_y[1] += warp_space_y[0] - 2
                warp_space_y[0] = 0

            spec = dict((
                (0, warp_space_x),
                (1, warp_space_y),
                (2, (0, end_point_coord_native[2] - start_point_coord_native[2])),
            ))
            msct_image.spatial_crop(Image("tmp.centerline_pad_native.nii.gz"), spec).save(
                "tmp.centerline_pad_crop_native.nii.gz")

            fname_ref = 'tmp.centerline_pad_crop_native.nii.gz'
            offset_z = 4
        else:
            fname_ref = 'tmp.centerline_pad_crop.nii.gz'

        nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim
        padding_z = int(np.ceil(1.5 * ((length_centerline - size_z_centerline) / 2.0) / pz)) + offset_z
        image_centerline_pad = pad_image(image_centerline, pad_z_i=padding_z, pad_z_f=padding_z)
        nx, ny, nz = image_centerline_pad.data.shape
        hdr_warp = image_centerline_pad.hdr.copy()
        hdr_warp.set_data_dtype('float32')
        start_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, start_point]])[0]
        end_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, end_point]])[0]

        # In-plane extent of the straight space: +/- xy_size (mm) around the image center, clamped to the image
        straight_size_x = int(self.xy_size / px)
        straight_size_y = int(self.xy_size / py)
        warp_space_x = [int(np.round(nx / 2)) - straight_size_x, int(np.round(nx / 2)) + straight_size_x]
        warp_space_y = [int(np.round(ny / 2)) - straight_size_y, int(np.round(ny / 2)) + straight_size_y]

        if warp_space_x[0] < 0:
            warp_space_x[1] += warp_space_x[0] - 2
            warp_space_x[0] = 0
        if warp_space_x[1] >= nx:
            warp_space_x[1] = nx - 1
        if warp_space_y[0] < 0:
            warp_space_y[1] += warp_space_y[0] - 2
            warp_space_y[0] = 0
        if warp_space_y[1] >= ny:
            warp_space_y[1] = ny - 1

        spec = dict((
            (0, warp_space_x),
            (1, warp_space_y),
            (2, (0, end_point_coord[2] - start_point_coord[2] + offset_z)),
        ))
        image_centerline_straight = msct_image.spatial_crop(image_centerline_pad, spec)

        nx_s, ny_s, nz_s, nt_s, px_s, py_s, pz_s, pt_s = image_centerline_straight.dim
        hdr_warp_s = image_centerline_straight.hdr.copy()
        hdr_warp_s.set_data_dtype('float32')

        if self.template_orientation == 1:
            raise NotImplementedError()

        start_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, start_point]])[0]
        end_point_coord = image_centerline_pad.transfo_phys2pix([[0, 0, end_point]])[0]

        number_of_voxel = nx * ny * nz
        logger.debug('Number of voxels: {}'.format(number_of_voxel))

        time_centerlines = time.time()

        # Straight centerline: a vertical line through the in-plane center of the straight space, with the same
        # number of points as the curved centerline, and unit derivative along z.
        coord_straight = np.empty((number_of_points, 3))
        coord_straight[..., 0] = int(np.round(nx_s / 2))
        coord_straight[..., 1] = int(np.round(ny_s / 2))
        coord_straight[..., 2] = np.linspace(0, end_point_coord[2] - start_point_coord[2], number_of_points)
        coord_phys_straight = image_centerline_straight.transfo_pix2phys(coord_straight)

        derivs_straight = np.empty((number_of_points, 3))
        derivs_straight[..., 0] = derivs_straight[..., 1] = 0
        derivs_straight[..., 2] = 1
        dx_straight, dy_straight, dz_straight = derivs_straight.T
        centerline_straight = Centerline(coord_phys_straight[:, 0], coord_phys_straight[:, 1],
                                         coord_phys_straight[:, 2],
                                         dx_straight, dy_straight, dz_straight)

        time_centerlines = time.time() - time_centerlines
        logger.info('Time to generate centerline: {} ms'.format(np.round(time_centerlines * 1000.0)))

    if verbose == 2:
        # TODO: use OO
        import matplotlib.pyplot as plt
        from datetime import datetime
        curved_points = centerline.progressive_length
        straight_points = centerline_straight.progressive_length
        range_points = np.linspace(0, 1, number_of_points)
        dist_curved = np.zeros(number_of_points)
        dist_straight = np.zeros(number_of_points)
        for i in range(1, number_of_points):
            dist_curved[i] = dist_curved[i - 1] + curved_points[i - 1] / centerline.length
            dist_straight[i] = dist_straight[i - 1] + straight_points[i - 1] / centerline_straight.length
        plt.plot(range_points, dist_curved)
        plt.plot(range_points, dist_straight)
        plt.grid(True)
        plt.savefig('fig_straighten_' + datetime.now().strftime("%y%m%d%H%M%S%f") + '.png')
        plt.close()

    # alignment_mode = 'length'
    alignment_mode = 'levels'

    # Look-up table: index of a curved centerline point -> index of the matching straight centerline point.
    # Without disc labels this is the identity mapping. Index 0 is used as "no valid correspondence".
    lookup_curved2straight = list(range(centerline.number_of_points))
    if self.discs_input_filename != "":
        # create look-up table curved to straight
        for index in range(centerline.number_of_points):
            disc_label = centerline.l_points[index]
            if alignment_mode == 'length':
                relative_position = centerline.dist_points[index]
            else:
                relative_position = centerline.dist_points_rel[index]
            idx_closest = centerline_straight.get_closest_to_absolute_position(disc_label, relative_position,
                                                                               backup_index=index,
                                                                               backup_centerline=centerline_straight,
                                                                               mode=alignment_mode)
            if idx_closest is not None:
                lookup_curved2straight[index] = idx_closest
            else:
                lookup_curved2straight[index] = 0
    # Invalidate (set to 0) the leading and trailing runs of identical consecutive entries
    for p in range(0, len(lookup_curved2straight) // 2):
        if lookup_curved2straight[p] == lookup_curved2straight[p + 1]:
            lookup_curved2straight[p] = 0
        else:
            break
    for p in range(len(lookup_curved2straight) - 1, len(lookup_curved2straight) // 2, -1):
        if lookup_curved2straight[p] == lookup_curved2straight[p - 1]:
            lookup_curved2straight[p] = 0
        else:
            break
    lookup_curved2straight = np.array(lookup_curved2straight)

    # Same look-up table in the other direction: straight point index -> curved point index
    lookup_straight2curved = list(range(centerline_straight.number_of_points))
    if self.discs_input_filename != "":
        for index in range(centerline_straight.number_of_points):
            disc_label = centerline_straight.l_points[index]
            if alignment_mode == 'length':
                relative_position = centerline_straight.dist_points[index]
            else:
                relative_position = centerline_straight.dist_points_rel[index]
            idx_closest = centerline.get_closest_to_absolute_position(disc_label, relative_position,
                                                                      backup_index=index,
                                                                      backup_centerline=centerline_straight,
                                                                      mode=alignment_mode)
            if idx_closest is not None:
                lookup_straight2curved[index] = idx_closest
    for p in range(0, len(lookup_straight2curved) // 2):
        if lookup_straight2curved[p] == lookup_straight2curved[p + 1]:
            lookup_straight2curved[p] = 0
        else:
            break
    for p in range(len(lookup_straight2curved) - 1, len(lookup_straight2curved) // 2, -1):
        if lookup_straight2curved[p] == lookup_straight2curved[p - 1]:
            lookup_straight2curved[p] = 0
        else:
            break
    lookup_straight2curved = np.array(lookup_straight2curved)

    # Create volumes containing curved and straight warping fields
    data_warp_curved2straight = np.zeros((nx_s, ny_s, nz_s, 1, 3))
    data_warp_straight2curved = np.zeros((nx, ny, nz, 1, 3))

    # 5. compute transformations
    # Curved and straight images and the same dimensions, so we compute both warping fields at the same time.
    # b. determine which plane of spinal cord centreline it is included
    # sct.printv(nx * ny * nz, nx_s * ny_s * nz_s)

    if self.curved2straight:
        # One straight-space slice at a time
        for u in tqdm(range(nz_s)):
            x_s, y_s, z_s = np.mgrid[0:nx_s, 0:ny_s, u:u + 1]
            indexes_straight = np.array(list(zip(x_s.ravel(), y_s.ravel(), z_s.ravel())))
            physical_coordinates_straight = image_centerline_straight.transfo_pix2phys(indexes_straight)
            nearest_indexes_straight = centerline_straight.find_nearest_indexes(physical_coordinates_straight)
            distances_straight = centerline_straight.get_distances_from_planes(physical_coordinates_straight,
                                                                               nearest_indexes_straight)
            lookup = lookup_straight2curved[nearest_indexes_straight]
            # voxels too far from their plane, or without a valid correspondence (lookup == 0)
            indexes_out_distance_straight = np.logical_or(
                np.logical_or(distances_straight > self.threshold_distance,
                              distances_straight < -self.threshold_distance), lookup == 0)
            projected_points_straight = centerline_straight.get_projected_coordinates_on_planes(
                physical_coordinates_straight, nearest_indexes_straight)
            coord_in_planes_straight = centerline_straight.get_in_plans_coordinates(projected_points_straight,
                                                                                    nearest_indexes_straight)

            coord_straight2curved = centerline.get_inverse_plans_coordinates(coord_in_planes_straight, lookup)
            displacements_straight = coord_straight2curved - physical_coordinates_straight
            # Invert Z coordinate as ITK & ANTs physical coordinate system is LPS- (RAI+)
            # while ours is LPI-
            # Refs: https://sourceforge.net/p/advants/discussion/840261/thread/2a1e9307/#fb5a
            #  https://www.slicer.org/wiki/Coordinate_systems
            displacements_straight[:, 2] = -displacements_straight[:, 2]
            # out-of-range displacement marks voxels without a valid correspondence
            displacements_straight[indexes_out_distance_straight] = [100000.0, 100000.0, 100000.0]

            data_warp_curved2straight[indexes_straight[:, 0], indexes_straight[:, 1], indexes_straight[:, 2], 0, :]\
                = -displacements_straight

    if self.straight2curved:
        # One curved-space slice at a time
        for u in tqdm(range(nz)):
            x, y, z = np.mgrid[0:nx, 0:ny, u:u + 1]
            indexes = np.array(list(zip(x.ravel(), y.ravel(), z.ravel())))
            physical_coordinates = image_centerline_pad.transfo_pix2phys(indexes)
            nearest_indexes_curved = centerline.find_nearest_indexes(physical_coordinates)
            distances_curved = centerline.get_distances_from_planes(physical_coordinates,
                                                                    nearest_indexes_curved)
            lookup = lookup_curved2straight[nearest_indexes_curved]
            indexes_out_distance_curved = np.logical_or(
                np.logical_or(distances_curved > self.threshold_distance,
                              distances_curved < -self.threshold_distance), lookup == 0)
            projected_points_curved = centerline.get_projected_coordinates_on_planes(physical_coordinates,
                                                                                     nearest_indexes_curved)
            coord_in_planes_curved = centerline.get_in_plans_coordinates(projected_points_curved,
                                                                         nearest_indexes_curved)

            coord_curved2straight = centerline_straight.points[lookup]
            coord_curved2straight[:, 0:2] += coord_in_planes_curved[:, 0:2]
            coord_curved2straight[:, 2] += distances_curved

            displacements_curved = coord_curved2straight - physical_coordinates

            # Invert Z coordinate (ITK/ANTs LPS convention, see curved->straight above)
            displacements_curved[:, 2] = -displacements_curved[:, 2]
            displacements_curved[indexes_out_distance_curved] = [100000.0, 100000.0, 100000.0]

            data_warp_straight2curved[indexes[:, 0], indexes[:, 1], indexes[:, 2], 0, :] = -displacements_curved

    # Creation of the safe zone based on pre-calculated safe boundaries
    coord_bound_curved_inf, coord_bound_curved_sup = image_centerline_pad.transfo_phys2pix(
        [[0, 0, bound_curved[0]]]), image_centerline_pad.transfo_phys2pix([[0, 0, bound_curved[1]]])
    coord_bound_straight_inf, coord_bound_straight_sup = image_centerline_straight.transfo_phys2pix(
        [[0, 0, bound_straight[0]]]), image_centerline_straight.transfo_phys2pix([[0, 0, bound_straight[1]]])

    if radius_safe > 0:
        data_warp_curved2straight[:, :, 0:coord_bound_straight_inf[0][2], 0, :] = 100000.0
        data_warp_curved2straight[:, :, coord_bound_straight_sup[0][2]:, 0, :] = 100000.0
        data_warp_straight2curved[:, :, 0:coord_bound_curved_inf[0][2], 0, :] = 100000.0
        data_warp_straight2curved[:, :, coord_bound_curved_sup[0][2]:, 0, :] = 100000.0

    # Generate warp files as a warping fields
    hdr_warp_s.set_intent('vector', (), '')
    hdr_warp_s.set_data_dtype('float32')
    hdr_warp.set_intent('vector', (), '')
    hdr_warp.set_data_dtype('float32')
    if self.curved2straight:
        img = Nifti1Image(data_warp_curved2straight, None, hdr_warp_s)
        save(img, 'tmp.curve2straight.nii.gz')
        logger.info('Warping field generated: tmp.curve2straight.nii.gz')

    if self.straight2curved:
        img = Nifti1Image(data_warp_straight2curved, None, hdr_warp)
        save(img, 'tmp.straight2curve.nii.gz')
        logger.info('Warping field generated: tmp.straight2curve.nii.gz')

    image_centerline_straight.save(fname_ref)
    if self.curved2straight:
        logger.info('Apply transformation to input image...')
        sct.run(['isct_antsApplyTransforms',
                 '-d', '3',
                 '-r', fname_ref,
                 '-i', 'data.nii',
                 '-o', 'tmp.anat_rigid_warp.nii.gz',
                 '-t', 'tmp.curve2straight.nii.gz',
                 '-n', 'BSpline[3]'],
                is_sct_binary=True,
                verbose=verbose)

    if self.accuracy_results:
        time_accuracy_results = time.time()
        # compute the error between the straightened centerline/segmentation and the central vertical line.
        # Ideally, the error should be zero.
        # Apply deformation to input image
        logger.info('Apply transformation to centerline image...')
        sct.run(['isct_antsApplyTransforms',
                 '-d', '3',
                 '-r', fname_ref,
                 '-i', 'centerline.nii.gz',
                 '-o', 'tmp.centerline_straight.nii.gz',
                 '-t', 'tmp.curve2straight.nii.gz',
                 '-n', 'NearestNeighbor'],
                is_sct_binary=True,
                verbose=verbose)
        file_centerline_straight = Image('tmp.centerline_straight.nii.gz', verbose=verbose)
        nx, ny, nz, nt, px, py, pz, pt = file_centerline_straight.dim
        coordinates_centerline = file_centerline_straight.getNonZeroCoordinates(sorting='z')
        # value-weighted mean (x, y) position of the straightened centerline, per z slice
        mean_coord = []
        for z in range(coordinates_centerline[0].z, coordinates_centerline[-1].z):
            temp_mean = [coord.value for coord in coordinates_centerline if coord.z == z]
            if temp_mean:
                mean_value = np.mean(temp_mean)
                mean_coord.append(
                    np.mean([[coord.x * coord.value / mean_value, coord.y * coord.value / mean_value]
                             for coord in coordinates_centerline if coord.z == z], axis=0))

        # compute error between the straightened centerline and the straight line.
        x0 = file_centerline_straight.data.shape[0] / 2.0
        y0 = file_centerline_straight.data.shape[1] / 2.0
        count_mean = 0
        if number_of_points >= 10:
            mean_c = mean_coord[2:-2]  # we don't include the four extrema because there are usually messy.
        else:
            mean_c = mean_coord
        for coord_z in mean_c:
            if not np.isnan(np.sum(coord_z)):
                dist = ((x0 - coord_z[0]) * px) ** 2 + ((y0 - coord_z[1]) * py) ** 2
                self.mse_straightening += dist
                dist = np.sqrt(dist)
                if dist > self.max_distance_straightening:
                    self.max_distance_straightening = dist
                count_mean += 1
        if count_mean > 0:
            self.mse_straightening = np.sqrt(self.mse_straightening / float(count_mean))
        else:
            # No usable slice (e.g. too few slices left once the extrema are removed): the error is undefined.
            # Report NaN instead of crashing with a ZeroDivisionError after all the work above.
            logger.warning('Could not compute the accuracy of straightening: no valid centerline slice.')
            self.mse_straightening = np.nan

        self.elapsed_time_accuracy = time.time() - time_accuracy_results

    os.chdir(curdir)

    # Generate output file (in current folder)
    # TODO: do not uncompress the warping field, it is too time consuming!
    logger.info('Generate output files...')
    if self.curved2straight:
        sct.generate_output_file(os.path.join(path_tmp, "tmp.curve2straight.nii.gz"),
                                 os.path.join(self.path_output, "warp_curve2straight.nii.gz"), verbose)
    if self.straight2curved:
        sct.generate_output_file(os.path.join(path_tmp, "tmp.straight2curve.nii.gz"),
                                 os.path.join(self.path_output, "warp_straight2curve.nii.gz"), verbose)

    # Only produced for curved->straight; stays None otherwise (previously an UnboundLocalError at return).
    fname_straight = None

    # create ref_straight.nii.gz file that can be used by other SCT functions that need a straight reference space
    if self.curved2straight:
        sct.copy(os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
                 os.path.join(self.path_output, "straight_ref.nii.gz"))
        # move straightened input file
        if fname_output == '':
            fname_straight = sct.generate_output_file(os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
                                                      os.path.join(self.path_output,
                                                                   file_anat + "_straight" + ext_anat), verbose)
        else:
            fname_straight = sct.generate_output_file(os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"),
                                                      os.path.join(self.path_output, fname_output),
                                                      verbose)  # straightened anatomic

    # Remove temporary files
    if remove_temp_files:
        logger.info('Remove temporary files...')
        sct.rmtree(path_tmp)

    if self.accuracy_results:
        logger.info('Maximum x-y error: {} mm'.format(self.max_distance_straightening))
        logger.info('Accuracy of straightening (MSE): {} mm'.format(self.mse_straightening))

    # display elapsed time
    self.elapsed_time = int(np.round(time.time() - start_time))

    return fname_straight
def _orient(self, fname, orientation):
    """
    Reorient an image file and write the result back to the same path.

    :param fname: path of the image to reorient; it is overwritten by the reoriented image.
    :param orientation: target orientation code passed to ``Image.change_orientation`` (e.g. "RPI").
    :return: the value returned by ``Image.save`` for the reoriented image.
    """
    image = Image(fname)
    reoriented = image.change_orientation(orientation)
    return reoriented.save(fname, mutable=True)
def register(src, dest, paramreg, param, i_step_str):
    """
    Estimate the transformation of one registration step and rename its outputs.

    The algorithm is selected from ``paramreg.steps[i_step_str]``:
      - algo 'slicereg': isct_antsSliceRegularizedRegistration on z-cropped images;
      - ANTs algos (keys of ``ants_registration_params``, case-insensitive) with slicewise='0': isct_antsRegistration 3D;
      - ANTs algos with slicewise='1': msct_register.register_slicewise;
      - 'centermass', 'centermassrot', 'columnwise': msct_register.register_slicewise (mask ignored);
      - type 'label' (whatever the algo): msct_register_landmarks.register_landmarks.

    :param src: file name of the source image (relative to the current directory).
    :param dest: file name of the destination image (relative to the current directory).
    :param paramreg: multi-step parameters; only ``paramreg.steps[i_step_str]`` is used. Its ``shrink`` field is
        overwritten with '1' for slice-wise ANTs registration.
    :param param: global parameters (reads verbose, fname_mask, padding, path_qc, remove_temp_files).
    :param i_step_str: index of the step, as a string.
    :return: (warp_forward, warp_inverse) file names, renamed to ``warp_forward_<step>.*`` / ``warp_inverse_<step>.*``.
        For affine (.mat) and landmark (.txt) outputs, the inverse is the forward file name prefixed with '-'.
    """
    # initiate default parameters of antsRegistration transformation
    ants_registration_params = {
        'rigid': '',
        'affine': '',
        'compositeaffine': '',
        'similarity': '',
        'translation': '',
        'bspline': ',10',
        'gaussiandisplacementfield': ',3,0',
        'bsplinedisplacementfield': ',5,10',
        'syn': ',3,0',
        'bsplinesyn': ',1,3'
    }
    # algos for which the ANTs 3D branch names its output '<prefix>0GenericAffine.mat' (no inverse warp file)
    linear_algos = ['rigid', 'affine', 'translation']
    output = ''  # default output if problem

    step = paramreg.steps[i_step_str]
    # The ANTs dispatch below matches ants_registration_params with algo.lower(); the output-file naming and the
    # checks on those files must use the same case-insensitive comparison, otherwise e.g. 'Rigid' is dispatched to
    # ANTs 3D but then expected to produce '0Warp.nii.gz'.
    algo_lower = step.algo.lower()

    # display arguments
    sct.printv('Registration parameters:', param.verbose)
    sct.printv(' type ........... ' + step.type, param.verbose)
    sct.printv(' algo ........... ' + step.algo, param.verbose)
    sct.printv(' slicewise ...... ' + step.slicewise, param.verbose)
    sct.printv(' metric ......... ' + step.metric, param.verbose)
    sct.printv(' iter ........... ' + step.iter, param.verbose)
    sct.printv(' smooth ......... ' + step.smooth, param.verbose)
    sct.printv(' laplacian ...... ' + step.laplacian, param.verbose)
    sct.printv(' shrink ......... ' + step.shrink, param.verbose)
    sct.printv(' gradStep ....... ' + step.gradStep, param.verbose)
    sct.printv(' deformation .... ' + step.deformation, param.verbose)
    sct.printv(' init ........... ' + step.init, param.verbose)
    sct.printv(' poly ........... ' + step.poly, param.verbose)
    sct.printv(' dof ............ ' + step.dof, param.verbose)
    sct.printv(' smoothWarpXY ... ' + step.smoothWarpXY, param.verbose)

    # set metricSize
    if step.metric == 'MI':
        metricSize = '32'  # corresponds to number of bins
    else:
        metricSize = '4'  # corresponds to radius (for CC, MeanSquares...)

    # set masking
    if param.fname_mask:
        fname_mask = 'mask.nii.gz'
        masking = ['-x', 'mask.nii.gz']
    else:
        fname_mask = ''
        masking = []

    if step.algo == 'slicereg':
        # check if user used type=label
        if step.type == 'label':
            sct.printv('\nERROR: this algo is not compatible with type=label. Please use type=im or type=seg', 1,
                       'error')
        else:
            # Find the min (and max) z-slice index below which (and above which) slices only have voxels below a given
            # threshold.
            list_fname = [src, dest]
            if not masking == []:
                list_fname.append(fname_mask)
            zmin_global, zmax_global = 0, 99999  # this is assuming that typical image has less slice than 99999
            for fname in list_fname:
                im = Image(fname)
                zmin, zmax = msct_image.find_zmin_zmax(im, threshold=0.1)
                if zmin > zmin_global:
                    zmin_global = zmin
                if zmax < zmax_global:
                    zmax_global = zmax
            # crop images (see issue #293)
            src_crop = sct.add_suffix(src, '_crop')
            msct_image.spatial_crop(Image(src), {2: (zmin_global, zmax_global)}).save(src_crop)
            dest_crop = sct.add_suffix(dest, '_crop')
            msct_image.spatial_crop(Image(dest), {2: (zmin_global, zmax_global)}).save(dest_crop)
            # update variables
            src = src_crop
            dest = dest_crop
            src_regstep = sct.add_suffix(src, '_regStep' + i_step_str)
            # estimate transfo
            # TODO fixup isct_ants* parsers
            cmd = [
                'isct_antsSliceRegularizedRegistration',
                '-t', 'Translation[' + step.gradStep + ']',
                '-m', step.metric + '[' + dest + ',' + src + ',1,' + metricSize + ',Regular,0.2]',
                '-p', step.poly,
                '-i', step.iter,
                '-f', step.shrink,
                '-s', step.smooth,
                '-v', '1',  # verbose (verbose=2 does not exist, so we force it to 1)
                # here the warp name is stage10 because antsSliceReg add "Warp"
                '-o', '[step' + i_step_str + ',' + src_regstep + ']',
            ] + masking
            warp_forward_out = 'step' + i_step_str + 'Warp.nii.gz'
            warp_inverse_out = 'step' + i_step_str + 'InverseWarp.nii.gz'
            # run command
            status, output = sct.run(cmd, param.verbose)

    # ANTS 3d
    elif algo_lower in ants_registration_params and step.slicewise == '0':
        # make sure type!=label. If type==label, this will be addressed later in the code.
        if not step.type == 'label':
            # Pad the destination image (because ants doesn't deform the extremities)
            # N.B. no need to pad if iter = 0
            if not step.iter == '0':
                dest_pad = sct.add_suffix(dest, '_pad')
                sct.run(['sct_image', '-i', dest, '-o', dest_pad, '-pad', '0,0,' + str(param.padding)])
                dest = dest_pad
            # apply Laplacian filter
            if not step.laplacian == '0':
                sct.printv('\nApply Laplacian filter', param.verbose)
                sct.run(['sct_maths', '-i', src, '-laplacian', step.laplacian + ',' + step.laplacian + ',0',
                         '-o', sct.add_suffix(src, '_laplacian')])
                sct.run(['sct_maths', '-i', dest, '-laplacian', step.laplacian + ',' + step.laplacian + ',0',
                         '-o', sct.add_suffix(dest, '_laplacian')])
                src = sct.add_suffix(src, '_laplacian')
                dest = sct.add_suffix(dest, '_laplacian')
            # Estimate transformation
            sct.printv('\nEstimate transformation', param.verbose)
            src_regstep = sct.add_suffix(src, '_regStep' + i_step_str)
            # TODO fixup isct_ants* parsers
            cmd = [
                'isct_antsRegistration',
                '--dimensionality', '3',
                '--transform', step.algo + '[' + step.gradStep + ants_registration_params[algo_lower] + ']',
                '--metric', step.metric + '[' + dest + ',' + src + ',1,' + metricSize + ']',
                '--convergence', step.iter,
                '--shrink-factors', step.shrink,
                '--smoothing-sigmas', step.smooth + 'mm',
                '--restrict-deformation', step.deformation,
                '--output', '[step' + i_step_str + ',' + src_regstep + ']',
                '--interpolation', 'BSpline[3]',
                '--verbose', '1',
            ] + masking
            # add init translation
            if not step.init == '':
                init_dict = {'geometric': '0', 'centermass': '1', 'origin': '2'}
                cmd += ['-r', '[' + dest + ',' + src + ',' + init_dict[step.init] + ']']
            # run command
            status, output = sct.run(cmd, param.verbose)
            # get appropriate file name for transformation (case-insensitive, consistent with the dispatch above)
            if algo_lower in linear_algos:
                warp_forward_out = 'step' + i_step_str + '0GenericAffine.mat'
                warp_inverse_out = '-step' + i_step_str + '0GenericAffine.mat'
            else:
                warp_forward_out = 'step' + i_step_str + '0Warp.nii.gz'
                warp_inverse_out = 'step' + i_step_str + '0InverseWarp.nii.gz'

    # ANTS 2d
    elif algo_lower in ants_registration_params and step.slicewise == '1':
        # make sure type!=label. If type==label, this will be addressed later in the code.
        if not step.type == 'label':
            from msct_register import register_slicewise
            # if shrink!=1, force it to be 1 (otherwise, it generates a wrong 3d warping field). TODO: fix that!
            if not step.shrink == '1':
                sct.printv('\nWARNING: when using slicewise with SyN or BSplineSyN, shrink factor needs to be one. '
                           'Forcing shrink=1.', 1, 'warning')
                step.shrink = '1'
            warp_forward_out = 'step' + i_step_str + 'Warp.nii.gz'
            warp_inverse_out = 'step' + i_step_str + 'InverseWarp.nii.gz'
            register_slicewise(src, dest, paramreg=step, fname_mask=fname_mask,
                               warp_forward_out=warp_forward_out, warp_inverse_out=warp_inverse_out,
                               ants_registration_params=ants_registration_params, path_qc=param.path_qc,
                               remove_temp_files=param.remove_temp_files, verbose=param.verbose)

    # slice-wise transfo
    elif step.algo in ['centermass', 'centermassrot', 'columnwise']:
        # if type=im, sends warning
        if step.type == 'im':
            sct.printv('\nWARNING: algo ' + step.algo + ' should be used with type=seg.\n', 1, 'warning')
        # if type=label, exit with error
        elif step.type == 'label':
            sct.printv('\nERROR: this algo is not compatible with type=label. Please use type=im or type=seg', 1,
                       'error')
        # check if user provided a mask-- if so, inform it will be ignored
        if not fname_mask == '':
            sct.printv('\nWARNING: algo ' + step.algo + ' will ignore the provided mask.\n', 1, 'warning')
        # smooth data
        if not step.smooth == '0':
            sct.printv('\nSmooth data', param.verbose)
            sct.run(['sct_maths', '-i', src, '-smooth', step.smooth + ',' + step.smooth + ',0',
                     '-o', sct.add_suffix(src, '_smooth')])
            sct.run(['sct_maths', '-i', dest, '-smooth', step.smooth + ',' + step.smooth + ',0',
                     '-o', sct.add_suffix(dest, '_smooth')])
            src = sct.add_suffix(src, '_smooth')
            dest = sct.add_suffix(dest, '_smooth')
        from msct_register import register_slicewise
        warp_forward_out = 'step' + i_step_str + 'Warp.nii.gz'
        warp_inverse_out = 'step' + i_step_str + 'InverseWarp.nii.gz'
        register_slicewise(src, dest, paramreg=step, fname_mask=fname_mask,
                           warp_forward_out=warp_forward_out, warp_inverse_out=warp_inverse_out,
                           ants_registration_params=ants_registration_params, path_qc=param.path_qc,
                           remove_temp_files=param.remove_temp_files, verbose=param.verbose)
    else:
        sct.printv('\nERROR: algo ' + step.algo + ' does not exist. Exit program\n', 1, 'error')

    # landmark-based registration
    if step.type in ['label']:
        # check if user specified ilabel and dlabel
        # TODO
        warp_forward_out = 'step' + i_step_str + '0GenericAffine.txt'
        warp_inverse_out = '-step' + i_step_str + '0GenericAffine.txt'
        from msct_register_landmarks import register_landmarks
        register_landmarks(src, dest, step.dof, fname_affine=warp_forward_out, verbose=param.verbose,
                           path_qc=param.path_qc)

    # True when the ANTs 3D branch produced an affine .mat (its "inverse" is the '-'-prefixed forward file)
    affine_output = algo_lower in linear_algos and step.slicewise == '0'

    if not os.path.isfile(warp_forward_out):
        # no forward warping field for rigid and affine
        sct.printv('\nERROR: file ' + warp_forward_out + ' doesn\'t exist (or is not a file).\n' + output +
                   '\nERROR: ANTs failed. Exit program.\n', 1, 'error')
    elif not os.path.isfile(warp_inverse_out) and not affine_output and step.type not in ['label']:
        # no inverse warping field for rigid and affine
        sct.printv('\nERROR: file ' + warp_inverse_out + ' doesn\'t exist (or is not a file).\n' + output +
                   '\nERROR: ANTs failed. Exit program.\n', 1, 'error')
    else:
        # rename warping fields
        if step.type in ['label']:
            # label-based registration outputs a .txt affine (see register_landmarks above). Tested first so that a
            # linear algo name combined with type=label does not rename this .txt file to .mat.
            warp_forward = 'warp_forward_' + i_step_str + '.txt'
            os.rename(warp_forward_out, warp_forward)
            warp_inverse = '-warp_forward_' + i_step_str + '.txt'
        elif affine_output:
            # if ANTs is used with affine/rigid --> outputs .mat file
            warp_forward = 'warp_forward_' + i_step_str + '.mat'
            os.rename(warp_forward_out, warp_forward)
            warp_inverse = '-warp_forward_' + i_step_str + '.mat'
        else:
            warp_forward = 'warp_forward_' + i_step_str + '.nii.gz'
            warp_inverse = 'warp_inverse_' + i_step_str + '.nii.gz'
            os.rename(warp_forward_out, warp_forward)
            os.rename(warp_inverse_out, warp_inverse)

    return warp_forward, warp_inverse
def main(args=None):
    """
    Register a source image (-i) onto a destination image (-d) and write the results.

    The registration steps of ``paramreg`` are run one after the other with ``register()``. The resulting warping
    fields are concatenated and applied. The function writes the registered source and the forward warping field.
    It also writes the registered destination and the inverse warping field, unless an initial forward warp is given
    without its inverse (-initwarp without -initwarpinv).

    :param args: list of command-line arguments; defaults to ``sys.argv[1:]``.
    """
    if args is None:
        args = sys.argv[1:]

    # initialize parameters
    param = Param()

    # Initialization
    fname_output = ''
    path_out = ''
    fname_src_seg = ''
    fname_dest_seg = ''
    fname_src_label = ''
    fname_dest_label = ''
    generate_warpinv = 1

    start_time = time.time()

    # get default registration parameters
    step0 = Paramreg(step='0', type='im', algo='syn', metric='MI', iter='0', shrink='1', smooth='0', gradStep='0.5',
                     slicewise='0', dof='Tx_Ty_Tz_Rx_Ry_Rz')  # only used to put src into dest space
    step1 = Paramreg(step='1', type='im')
    paramreg = ParamregMultiStep([step0, step1])

    parser = get_parser(paramreg=paramreg)

    arguments = parser.parse(args)

    # get arguments
    fname_src = arguments['-i']
    fname_dest = arguments['-d']
    if '-iseg' in arguments:
        fname_src_seg = arguments['-iseg']
    if '-dseg' in arguments:
        fname_dest_seg = arguments['-dseg']
    if '-ilabel' in arguments:
        fname_src_label = arguments['-ilabel']
    if '-dlabel' in arguments:
        fname_dest_label = arguments['-dlabel']
    if '-o' in arguments:
        fname_output = arguments['-o']
    if '-ofolder' in arguments:
        path_out = arguments['-ofolder']
    if '-owarp' in arguments:
        fname_output_warp = arguments['-owarp']
    else:
        fname_output_warp = ''
    if '-initwarp' in arguments:
        fname_initwarp = os.path.abspath(arguments['-initwarp'])
    else:
        fname_initwarp = ''
    if '-initwarpinv' in arguments:
        fname_initwarpinv = os.path.abspath(arguments['-initwarpinv'])
    else:
        fname_initwarpinv = ''
    if '-m' in arguments:
        fname_mask = arguments['-m']
    else:
        fname_mask = ''
    padding = arguments['-z']
    # user-defined registration steps; empty when -param is not given (replaces a fragile `locals()` lookup)
    paramreg_user = arguments['-param'] if "-param" in arguments else []
    # update registration parameters
    for paramStep in paramreg_user:
        paramreg.addStep(paramStep)

    identity = int(arguments['-identity'])
    interp = arguments['-x']
    remove_temp_files = int(arguments['-r'])
    verbose = int(arguments['-v'])

    sct.printv('\nInput parameters:')
    sct.printv(' Source .............. ' + fname_src)
    sct.printv(' Destination ......... ' + fname_dest)
    sct.printv(' Init transfo ........ ' + fname_initwarp)
    sct.printv(' Mask ................ ' + fname_mask)
    sct.printv(' Output name ......... ' + fname_output)
    sct.printv(' Remove temp files ... ' + str(remove_temp_files))
    sct.printv(' Verbose ............. ' + str(verbose))

    # update param
    param.verbose = verbose
    param.padding = padding
    param.fname_mask = fname_mask
    param.remove_temp_files = remove_temp_files

    # Get if input is 3D
    sct.printv('\nCheck if input data are 3D...', verbose)
    sct.check_if_3d(fname_src)
    sct.check_if_3d(fname_dest)

    # Check if user selected type=seg, but did not input segmentation data
    if any('type=seg' in step_str for step_str in paramreg_user):
        if fname_src_seg == '' or fname_dest_seg == '':
            sct.printv('\nERROR: if you select type=seg you must specify -iseg and -dseg flags.\n', 1, 'error')

    # Extract path, file and extension
    _, file_src, ext_src = sct.extract_fname(fname_src)
    _, file_dest, ext_dest = sct.extract_fname(fname_dest)

    # check if source and destination images have the same name (related to issue #373)
    # If so, change names to avoid conflict of result files and warns the user
    suffix_src, suffix_dest = '_reg', '_reg'
    if file_src == file_dest:
        suffix_src, suffix_dest = '_src_reg', '_dest_reg'

    # define output folder and file name
    if fname_output == '':
        path_out = '' if not path_out else path_out  # output in user's current directory
        file_out = file_src + suffix_src
        file_out_inv = file_dest + suffix_dest
        ext_out = ext_src
    else:
        path, file_out, ext_out = sct.extract_fname(fname_output)
        path_out = path if not path_out else path_out
        file_out_inv = file_out + '_inv'

    # create temporary folder
    path_tmp = sct.tmp_create()

    sct.printv('\nCopying input data to tmp folder and convert to nii...', verbose)
    Image(fname_src).save(os.path.join(path_tmp, "src.nii"))
    Image(fname_dest).save(os.path.join(path_tmp, "dest.nii"))

    if fname_src_seg:
        Image(fname_src_seg).save(os.path.join(path_tmp, "src_seg.nii"))
        Image(fname_dest_seg).save(os.path.join(path_tmp, "dest_seg.nii"))

    if fname_src_label:
        Image(fname_src_label).save(os.path.join(path_tmp, "src_label.nii"))
        Image(fname_dest_label).save(os.path.join(path_tmp, "dest_label.nii"))

    if fname_mask != '':
        Image(fname_mask).save(os.path.join(path_tmp, "mask.nii.gz"))

    # go to tmp folder
    curdir = os.getcwd()
    os.chdir(path_tmp)

    # reorient destination to RPI
    Image('dest.nii').change_orientation("RPI").save('dest_RPI.nii')
    if fname_dest_seg:
        Image('dest_seg.nii').change_orientation("RPI").save('dest_seg_RPI.nii')
    if fname_dest_label:
        Image('dest_label.nii').change_orientation("RPI").save('dest_label_RPI.nii')

    if identity:
        # overwrite paramreg and only do one identity transformation
        step0 = Paramreg(step='0', type='im', algo='syn', metric='MI', iter='0', shrink='1', smooth='0',
                         gradStep='0.5')
        paramreg = ParamregMultiStep([step0])

    # TODO: Check if necessary to put source into destination space using header (no estimation) as step=0

    # initialize list of warping fields
    warp_forward = []
    warp_inverse = []

    # initial warping is specified, update list of warping fields and skip step=0
    if fname_initwarp:
        sct.printv('\nSkip step=0 and replace with initial transformations: ', param.verbose)
        sct.printv(' ' + fname_initwarp, param.verbose)
        warp_forward = [fname_initwarp]
        start_step = 1
        if fname_initwarpinv:
            warp_inverse = [fname_initwarpinv]
        else:
            sct.printv('\nWARNING: No initial inverse warping field was specified, therefore the inverse warping field '
                       'will NOT be generated.', param.verbose, 'warning')
            generate_warpinv = 0
    else:
        start_step = 0

    # loop across registration steps
    for i_step in range(start_step, len(paramreg.steps)):
        sct.printv('\n--\nESTIMATE TRANSFORMATION FOR STEP #' + str(i_step), param.verbose)
        step_type = paramreg.steps[str(i_step)].type
        # identify which is the src and dest
        if step_type == 'im':
            src = 'src.nii'
            dest = 'dest_RPI.nii'
            interp_step = 'spline'
        elif step_type == 'seg':
            src = 'src_seg.nii'
            dest = 'dest_seg_RPI.nii'
            interp_step = 'nn'
        elif step_type == 'label':
            src = 'src_label.nii'
            dest = 'dest_label_RPI.nii'
            interp_step = 'nn'
        else:
            sct.printv('ERROR: Wrong image type.', 1, 'error')
        # if step>0, apply warp_forward_concat to the src image to be used
        if i_step > 0:
            sct.printv('\nApply transformation from previous step', param.verbose)
            sct.run(['sct_apply_transfo', '-i', src, '-d', dest, '-w', ','.join(warp_forward),
                     '-o', sct.add_suffix(src, '_reg'), '-x', interp_step], verbose)
            src = sct.add_suffix(src, '_reg')
        # register src --> dest
        warp_forward_out, warp_inverse_out = register(src, dest, paramreg, param, str(i_step))
        warp_forward.append(warp_forward_out)
        warp_inverse.insert(0, warp_inverse_out)

    # Concatenate transformations
    sct.printv('\nConcatenate transformations...', verbose)
    sct.run(['sct_concat_transfo', '-w', ','.join(warp_forward), '-d', 'dest.nii', '-o', 'warp_src2dest.nii.gz'],
            verbose)
    if generate_warpinv:
        # Without the initial inverse warp, the inverse chain is incomplete (and empty with -identity -initwarp).
        # Its results would be discarded anyway, so it is only built when it is going to be output.
        sct.run(['sct_concat_transfo', '-w', ','.join(warp_inverse), '-d', 'src.nii', '-o', 'warp_dest2src.nii.gz'],
                verbose)

    # Apply warping field to src data
    sct.printv('\nApply transfo source --> dest...', verbose)
    sct.run(['sct_apply_transfo', '-i', 'src.nii', '-o', 'src_reg.nii', '-d', 'dest.nii',
             '-w', 'warp_src2dest.nii.gz', '-x', interp], verbose)
    if generate_warpinv:
        sct.printv('\nApply transfo dest --> source...', verbose)
        sct.run(['sct_apply_transfo', '-i', 'dest.nii', '-o', 'dest_reg.nii', '-d', 'src.nii',
                 '-w', 'warp_dest2src.nii.gz', '-x', interp], verbose)

    # come back
    os.chdir(curdir)

    # Generate output files
    sct.printv('\nGenerate output files...', verbose)
    # generate: src_reg
    fname_src2dest = sct.generate_output_file(os.path.join(path_tmp, "src_reg.nii"),
                                              os.path.join(path_out, file_out + ext_out), verbose)

    # generate: forward warping field
    if fname_output_warp == '':
        fname_output_warp = os.path.join(path_out, 'warp_' + file_src + '2' + file_dest + '.nii.gz')
    sct.generate_output_file(os.path.join(path_tmp, "warp_src2dest.nii.gz"), fname_output_warp, verbose)

    if generate_warpinv:
        # generate: dest_reg
        fname_dest2src = sct.generate_output_file(os.path.join(path_tmp, "dest_reg.nii"),
                                                  os.path.join(path_out, file_out_inv + ext_dest), verbose)
        # generate: inverse warping field
        sct.generate_output_file(os.path.join(path_tmp, "warp_dest2src.nii.gz"),
                                 os.path.join(path_out, 'warp_' + file_dest + '2' + file_src + '.nii.gz'), verbose)

    # Delete temporary files
    if remove_temp_files:
        sct.printv('\nRemove temporary files...', verbose)
        sct.rmtree(path_tmp, verbose=verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's', verbose)

    if generate_warpinv:
        sct.display_viewer_syntax([fname_src, fname_dest2src], verbose=verbose)
    sct.display_viewer_syntax([fname_dest, fname_src2dest], verbose=verbose)
def find_centerline(algo, image_fname, contrast_type, brain_bool, folder_output, remove_temp_files,
                    centerline_fname):
    """
    Compute the spinal cord centerline of an image. Assumes RPI orientation.

    :param algo: 'svm' (OptiC fitting on the image), 'cnn' (OptiC fitting on a CNN heatmap), 'viewer' (centerline
        fitted on labels obtained from the viewer) or 'file' (centerline read from ``centerline_fname``). Any other
        value logs an error and exits.
    :param image_fname: file name of the input image.
    :param contrast_type: contrast name; for 'cnn' it must be one of 't2', 't2s', 't1', 'dwi'.
    :param brain_bool: forwarded to ``heatmap`` (algo 'cnn' only).
    :param folder_output: not used by this function.
    :param remove_temp_files: not used by this function.
    :param centerline_fname: centerline file name, used when algo == 'file'.
    :return: tuple ("dummy_file_name", centerline Image, labels Image or None). The labels are only returned
        for algo == 'viewer'.
    """
    im = Image(image_fname)
    ctl_absolute_path = add_suffix(im.absolutepath, "_ctr")

    # isct_spine_detect requires nz > 1: a single-slice image is duplicated along z, and split back at the end
    is_2d = im.dim[2] == 1
    if is_2d:
        im = concat_data([im, im], dim=2)
        im.hdr['dim'][3] = 2  # Needs to be change manually since dim not updated during concat_data

    im_labels = None

    # TODO: maybe change 'svm' for 'optic', because this is how we call it in sct_get_centerline
    if algo == 'svm':
        # run optic on a heatmap computed by a trained SVM+HoG algorithm
        # TODO: replace with get_centerline(method=optic)
        im_ctl, _, _, _ = get_centerline(im, ParamCenterline(algo_fitting='optic', contrast=contrast_type))

    elif algo == 'cnn':
        # CNN settings per contrast: patch size, training mean/std, number of features and of dilation layers
        cnn_settings = {
            't2': {'size': (80, 80), 'mean': 51.1417, 'std': 57.4408, 'features': 16, 'dilation_layers': 2},
            't2s': {'size': (80, 80), 'mean': 68.8591, 'std': 71.4659, 'features': 8, 'dilation_layers': 3},
            't1': {'size': (80, 80), 'mean': 55.7359, 'std': 64.3149, 'features': 24, 'dilation_layers': 3},
            'dwi': {'size': (80, 80), 'mean': 55.744, 'std': 45.003, 'features': 8, 'dilation_layers': 2},
        }

        # load model
        ctr_model_fname = sct_dir_local_path('data', 'deepseg_sc_models', '{}_ctr.h5'.format(contrast_type))
        settings = cnn_settings[contrast_type]
        patch_height, patch_width = settings['size']
        ctr_model = nn_architecture_ctr(
            height=patch_height,
            width=patch_width,
            channels=1,
            classes=1,
            features=settings['features'],
            depth=2,
            temperature=1.0,
            padding='same',
            batchnorm=True,
            dropout=0.0,
            dilation_layers=settings['dilation_layers'])
        ctr_model.load_weights(ctr_model_fname)

        # compute the heatmap, then fit the centerline on it
        im_heatmap, z_max = heatmap(
            im=im,
            model=ctr_model,
            patch_shape=settings['size'],
            mean_train=settings['mean'],
            std_train=settings['std'],
            brain_bool=brain_bool)
        im_ctl, _, _, _ = get_centerline(im_heatmap, ParamCenterline(algo_fitting='optic', contrast=contrast_type))

        # everything from z_max upwards is zeroed ('Cropping brain section.')
        if z_max is not None:
            logger.info('Cropping brain section.')
            im_ctl.data[:, :, z_max:] = 0

    elif algo == 'viewer':
        im_labels = _call_viewer_centerline(im)
        im_ctl, _, _, _ = get_centerline(im_labels, param=ParamCenterline())

    elif algo == 'file':
        im_ctl = Image(centerline_fname)
        im_ctl.change_orientation('RPI')

    else:
        logger.error('The parameter "-centerline" is incorrect. Please try again.')
        sys.exit(1)

    # TODO: for some reason, when algo == 'file', the absolutepath is changed to None out of the method find_centerline
    im_ctl.absolutepath = ctl_absolute_path

    if is_2d:
        im_ctl = split_data(im_ctl, dim=2)[0]

    # TODO: remove unecessary return params
    return "dummy_file_name", im_ctl, im_labels
def aggregate_per_slice_or_level(metric, mask=None, slices=None, levels=None, perslice=None, perlevel=False,
                                 vert_level=None, group_funcs=(('MEAN', func_wa),), map_clusters=None):
    """
    The aggregation will be performed along the last dimension of 'metric' ndarray.

    :param metric: Class Metric(): data to aggregate.
    :param mask: Class Metric(): mask to use for aggregating the data. Optional.
    :param slices: List[int]: Slices to aggregate metric from. If None or empty, select all slices.
    :param levels: List[int]: Vertebral levels to aggregate metric from. It has priority over "slices".
        None or empty: not used.
    :param Bool perslice: Aggregate per slice (True) or across slices (False). If None, it defaults to True when
        neither slices nor levels are given, False otherwise.
    :param Bool perlevel: Aggregate per level (True) or across levels (False). Has priority over "perslice".
    :param vert_level: Vertebral level. Could be either an Image or a file name.
    :param tuple group_funcs: Name and function to apply on metric. Example: (('MEAN', func_wa),)). Note, the function
        has special requirements in terms of i/o. See the definition to func_wa and use it as a template.
    :param map_clusters: list of list of int: See func_map()
    :return: dict mapping each slice group (tuple of slice indices) to a dict with:
        - 'VertLevel': tuple of vertebral levels, or None when aggregating by slices;
        - 'Label' and 'Size [vox]' (sum of the mask over the group), when a mask is given;
        - 'FUNC(METRIC_LABEL)' for each function of group_funcs: the result, None if the mask is empty over the
          group or the result is NaN, or the exception message (str) if the aggregation raised.
    """
    # If user neither specified slices nor levels, set perslice=True, otherwise, the output will likely contain nan
    # because in many cases the segmentation does not span the whole I-S dimension.
    if perslice is None:
        perslice = not slices and not levels

    # if slices is empty, select all available slices from the metric
    ndim = metric.data.ndim
    if not slices:
        slices = range(metric.data.shape[ndim - 1])

    # aggregation based on levels
    if levels:
        im_vert_level = Image(vert_level).change_orientation('RPI')
        # slicegroups = [(0, 1, 2), (3, 4, 5), (6, 7, 8)]
        slicegroups = [tuple(get_slices_from_vertebral_levels(im_vert_level, level)) for level in levels]
        if perlevel:
            # vertgroups = [(2,), (3,), (4,)]
            vertgroups = [(level,) for level in levels]
        elif perslice:
            # slicegroups = [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,)]
            slicegroups = [(i_slice,) for group in slicegroups for i_slice in group]  # reduce to individual tuple
            # vertgroups = [(2,), (2,), (2,), (3,), (3,), (3,), (4,), (4,), (4,)]
            vertgroups = [(get_vertebral_level_from_slice(im_vert_level, group[0]),) for group in slicegroups]
        # output aggregate metric across levels
        else:
            # slicegroups = [(0, 1, 2, 3, 4, 5, 6, 7, 8)]
            slicegroups = [tuple(i_slice for group in slicegroups for i_slice in group)]  # flatten into single tuple
            # vertgroups = [(2, 3, 4)]
            vertgroups = [tuple(levels)]
    # aggregation based on slices
    else:
        vertgroups = None
        if perslice:
            # slicegroups = [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,)]
            slicegroups = [(i_slice,) for i_slice in slices]
        else:
            # slicegroups = [(0, 1, 2, 3, 4, 5, 6, 7, 8)]
            slicegroups = [tuple(slices)]

    agg_metric = {slicegroup: {} for slicegroup in slicegroups}

    # loop across slice group
    for slicegroup in slicegroups:
        # add level info
        if vertgroups is None:
            agg_metric[slicegroup]['VertLevel'] = None
        else:
            agg_metric[slicegroup]['VertLevel'] = vertgroups[slicegroups.index(slicegroup)]
        # Loop across functions (e.g.: MEAN, STD)
        for (name, func) in group_funcs:
            try:
                # selection is done in the last dimension. Indexing with a tuple of indices is advanced indexing,
                # so data_slicegroup is a copy and zeroing non-finite values below does not modify metric.data
                data_slicegroup = metric.data[..., slicegroup]
                if mask is not None:
                    mask_slicegroup = mask.data[..., slicegroup, :]
                    agg_metric[slicegroup]['Label'] = mask.label
                    # Add volume fraction
                    agg_metric[slicegroup]['Size [vox]'] = np.sum(mask_slicegroup.flatten())
                else:
                    mask_slicegroup = np.ones(data_slicegroup.shape)
                # Ignore nonfinite values
                i_nonfinite = np.where(~np.isfinite(data_slicegroup))
                data_slicegroup[i_nonfinite] = 0.
                # TODO: the lines below could probably be done more elegantly
                if mask_slicegroup.ndim == data_slicegroup.ndim + 1:
                    # mask has an extra last dimension: zero the non-finite voxels in each of its volumes
                    arr_tmp_concat = []
                    for i in range(mask_slicegroup.shape[-1]):
                        arr_tmp = np.reshape(mask_slicegroup[..., i], data_slicegroup.shape)
                        arr_tmp[i_nonfinite] = 0.
                        arr_tmp_concat.append(np.expand_dims(arr_tmp, axis=(mask_slicegroup.ndim - 1)))
                    mask_slicegroup = np.concatenate(arr_tmp_concat, axis=(mask_slicegroup.ndim - 1))
                else:
                    mask_slicegroup[i_nonfinite] = 0.
                # Make sure the number of pixels to extract metrics is not null
                if mask_slicegroup.sum() == 0:
                    result = None
                else:
                    # Run estimation
                    result, _ = func(data_slicegroup, mask_slicegroup, map_clusters)
                    # check if nan
                    if np.isnan(result):
                        result = None
                # here we create a field with name: FUNC(METRIC_NAME). Example: MEAN(CSA)
                agg_metric[slicegroup]['{}({})'.format(name, metric.label)] = result
            except Exception as e:
                # best-effort: record the error message in the output instead of aborting the whole aggregation
                logging.warning(e)
                agg_metric[slicegroup]['{}({})'.format(name, metric.label)] = str(e)
    return agg_metric
def moco(param):
    """
    Main function that performs motion correction.

    The input data are split along T, and also along Z when ``param.is_sagittal``. Each resulting volume is
    registered to the (Z-split) target with ``register()``. Transformations that failed are replaced by the closest
    successful one in time.

    :param param: parameter object. Fields read here: file_data, file_target, mat_moco, todo, suffix, verbose, poly,
        smooth, gradStep, metric, sampling, fname_mask, is_sagittal, iterAvg, interp.
    :return: tuple (file_mat, im_out):
        - file_mat: object ndarray of shape (nz, nt) if sagittal, (1, nt) otherwise. It holds the transformation file
          prefixes ``<mat_moco>/mat.Z####T####``.
        - im_out: the motion-corrected Image, or None when ``param.todo == 'estimate'`` (no corrected data written).
    """
    # retrieve parameters
    file_data = param.file_data
    file_target = param.file_target
    folder_mat = param.mat_moco  # output folder of mat file
    todo = param.todo
    suffix = param.suffix
    verbose = param.verbose

    # other parameters
    file_mask = 'mask.nii'

    sct.printv('\nInput parameters:', param.verbose)
    sct.printv(' Input file ............ ' + file_data, param.verbose)
    sct.printv(' Reference file ........ ' + file_target, param.verbose)
    sct.printv(' Polynomial degree ..... ' + param.poly, param.verbose)
    sct.printv(' Smoothing kernel ...... ' + param.smooth, param.verbose)
    sct.printv(' Gradient step ......... ' + param.gradStep, param.verbose)
    sct.printv(' Metric ................ ' + param.metric, param.verbose)
    sct.printv(' Sampling .............. ' + param.sampling, param.verbose)
    sct.printv(' Todo .................. ' + todo, param.verbose)
    sct.printv(' Mask ................. ' + param.fname_mask, param.verbose)
    sct.printv(' Output mat folder ..... ' + folder_mat, param.verbose)

    try:
        os.makedirs(folder_mat)
    except FileExistsError:
        pass

    # Get size of data
    sct.printv('\nData dimensions:', verbose)
    im_data = Image(param.file_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    sct.printv((' ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt)), verbose)

    # copy file_target to a temporary file
    sct.printv('\nCopy file_target to a temporary file...', verbose)
    file_target = "target.nii.gz"
    convert(param.file_target, file_target, verbose=0)

    # Check if user specified a mask
    if not param.fname_mask == '':
        # Check if this mask is soft (i.e., non-binary, such as a Gaussian mask)
        im_mask = Image(param.fname_mask)
        if not np.array_equal(im_mask.data, im_mask.data.astype(bool)):
            # If it is a soft mask, multiply the target by the soft mask.
            im = Image(file_target)
            im_masked = im.copy()
            im_masked.data = im.data * im_mask.data
            im_masked.save(verbose=0)  # silence warning about file overwritting

    # If scan is sagittal, split src and target along Z (slice)
    if param.is_sagittal:
        dim_sag = 2  # TODO: find it
        # z-split data (time series)
        im_z_list = split_data(im_data, dim=dim_sag, squeeze_data=False)
        file_data_splitZ = []
        for im_z in im_z_list:
            im_z.save(verbose=0)
            file_data_splitZ.append(im_z.absolutepath)
        # z-split target
        im_targetz_list = split_data(Image(file_target), dim=dim_sag, squeeze_data=False)
        file_target_splitZ = []
        for im_targetz in im_targetz_list:
            im_targetz.save(verbose=0)
            file_target_splitZ.append(im_targetz.absolutepath)
        # z-split mask (if exists)
        if not param.fname_mask == '':
            # write the mask to file_mask first (as in the axial branch): nothing else in this function creates it
            convert(param.fname_mask, file_mask, squeeze_data=False, verbose=0)
            im_maskz_list = split_data(Image(file_mask), dim=dim_sag, squeeze_data=False)
            for im_maskz in im_maskz_list:
                im_maskz.save(verbose=0)
        # initialize file list for output matrices
        file_mat = np.empty((nz, nt), dtype=object)

    # axial orientation
    else:
        file_data_splitZ = [file_data]  # TODO: make it absolute like above
        file_target_splitZ = [file_target]  # TODO: make it absolute like above
        # initialize file list for output matrices
        file_mat = np.empty((1, nt), dtype=object)

        # deal with mask
        if not param.fname_mask == '':
            convert(param.fname_mask, file_mask, squeeze_data=False, verbose=0)
            im_maskz_list = [Image(file_mask)]  # use a list with single element

    # Loop across file list, where each file is either a 2D volume (if sagittal) or a 3D volume (otherwise)
    file_data_splitZ_moco = []
    im_out = None  # stays None when todo == 'estimate' (no corrected data is written)
    sct.printv('\nRegister. Loop across Z (note: there is only one Z if orientation is axial)')
    for iz, file in enumerate(file_data_splitZ):
        # Split data along T dimension
        im_z = Image(file)
        list_im_zt = split_data(im_z, dim=3)
        file_data_splitZ_splitT = []
        for im_zt in list_im_zt:
            im_zt.save(verbose=0)
            file_data_splitZ_splitT.append(im_zt.absolutepath)

        # Motion correction: initialization
        file_data_splitZ_splitT_moco = []
        failed_transfo = [0] * nt

        # Motion correction: Loop across T
        for it in sct_progress_bar(range(nt), unit='iter', unit_scale=False,
                                   desc="Z=" + str(iz) + "/" + str(len(file_data_splitZ) - 1),
                                   ascii=False, ncols=80):
            # create indices and display stuff
            file_mat[iz][it] = os.path.join(folder_mat, "mat.Z") + str(iz).zfill(4) + 'T' + str(it).zfill(4)
            file_data_splitZ_splitT_moco.append(sct.add_suffix(file_data_splitZ_splitT[it], '_moco'))
            # deal with masking (except in the 'apply' case, where masking is irrelevant)
            input_mask = None
            if not param.fname_mask == '' and not param.todo == 'apply':
                # Check if mask is binary
                if np.array_equal(im_maskz_list[iz].data, im_maskz_list[iz].data.astype(bool)):
                    # If it is, pass this mask into register() to be used
                    input_mask = im_maskz_list[iz]
                else:
                    # If not, do not pass this mask into register() because ANTs cannot handle non-binary masks.
                    # Instead, multiply the input data by the Gaussian mask.
                    im = Image(file_data_splitZ_splitT[it])
                    im_masked = im.copy()
                    im_masked.data = im.data * im_maskz_list[iz].data
                    im_masked.save(verbose=0)  # silence warning about file overwritting

            # run 3D registration
            failed_transfo[it] = register(param, file_data_splitZ_splitT[it], file_target_splitZ[iz],
                                          file_mat[iz][it], file_data_splitZ_splitT_moco[it], im_mask=input_mask)

            # average registered volume with target image
            # N.B. use weighted averaging: (target * nb_it + moco) / (nb_it + 1)
            if param.iterAvg and it < 10 and failed_transfo[it] == 0 and not param.todo == 'apply':
                im_targetz = Image(file_target_splitZ[iz])
                data_targetz = im_targetz.data
                data_mocoz = Image(file_data_splitZ_splitT_moco[it]).data
                data_targetz = (data_targetz * (it + 1) + data_mocoz) / (it + 2)
                im_targetz.data = data_targetz
                im_targetz.save(verbose=0)

        # Replace failed transformation with the closest good one
        fT = [i for i, j in enumerate(failed_transfo) if j == 1]
        gT = [i for i, j in enumerate(failed_transfo) if j == 0]
        for it_failed in fT:
            if not gT:
                # exit program if no transformation exists.
                sct.printv('\nERROR in ' + os.path.basename(__file__) +
                           ': No good transformation exist. Exit program.\n', verbose, 'error')
                sys.exit(2)
            # closest good time point (first one in case of a tie, gT being sorted)
            it_good = min(gT, key=lambda it_ok: abs(it_ok - it_failed))
            sct.printv(' transfo #' + str(it_failed) + ' --> use transfo #' + str(it_good), verbose)
            # copy transformation
            sct.copy(file_mat[iz][it_good] + 'Warp.nii.gz', file_mat[iz][it_failed] + 'Warp.nii.gz')
            # apply transformation. The reference is this Z-chunk's target, as in register() above; in axial mode
            # it is file_target itself, in sagittal mode the full target would give the output the wrong geometry.
            sct_apply_transfo.main(args=['-i', file_data_splitZ_splitT[it_failed],
                                         '-d', file_target_splitZ[iz],
                                         '-w', file_mat[iz][it_failed] + 'Warp.nii.gz',
                                         '-o', file_data_splitZ_splitT_moco[it_failed],
                                         '-x', param.interp])

        # Merge data along T
        file_data_splitZ_moco.append(sct.add_suffix(file, suffix))
        if todo != 'estimate':
            im_out = concat_data(file_data_splitZ_splitT_moco, 3)
            im_out.absolutepath = file_data_splitZ_moco[iz]
            im_out.save(verbose=0)

    # If sagittal, merge along Z (only possible when the per-Z corrected files were written above)
    if param.is_sagittal and todo != 'estimate':
        # TODO: im_out.dim is incorrect: Z value is one
        im_out = concat_data(file_data_splitZ_moco, 2)
        dirname, basename, ext = sct.extract_fname(file_data)
        path_out = os.path.join(dirname, basename + suffix + ext)
        im_out.absolutepath = path_out
        im_out.save(verbose=0)

    return file_mat, im_out
def register(param, file_src, file_dest, file_mat, file_out, im_mask=None):
    """
    Register (or apply a registration to) one volume of a moco run.

    In 'estimate' / 'estimate_and_apply' mode, the transformation is estimated with ANTs:
      - sagittal data (3rd orientation axis in 'LR'): ``isct_antsRegistration`` in 2D with an affine transform;
      - otherwise: ``isct_antsSliceRegularizedRegistration`` (slice-wise Tx/Ty, regularized along Z).
    In 'apply' mode, the existing transformation ``file_mat + param.suffix_mat`` is applied with
    ``sct_apply_transfo``.

    :param param: moco parameter object. Fields read here: metric, todo, sampling, gradStep, iter, smooth, poly,
                  interp, verbose, suffix_mat.
    :param file_src: path of the image to register.
    :param file_dest: path of the destination (target) image.
    :param file_mat: output prefix of the estimated transformation ('estimate*' modes), or prefix of the
                     transformation to apply ('apply' mode).
    :param file_out: path of the registered output image.
    :param im_mask: Image of mask, could be 2D or 3D. Ignored in 'apply' mode.
    :return: 1 if no output image was produced (failed transformation), 0 otherwise.
    """
    # TODO: deal with mask

    # initialization
    failed_transfo = 0  # by default, failed matrix is 0 (i.e., no failure)
    do_registration = True

    # get metric radius (if MeanSquares, CC) or nb bins (if MI)
    metric_radius = '16' if param.metric == 'MI' else '4'
    file_out_concat = file_out

    im_data = Image(file_src)  # TODO: pass argument to use antsReg instead of opening Image each time
    is_sagittal = im_data.orientation[2] in 'LR'

    # register file_src to file_dest
    if param.todo == 'estimate' or param.todo == 'estimate_and_apply':
        # If orientation is sagittal, use antsRegistration in 2D mode
        # Note: the parameter --restrict-deformation is irrelevant with affine transfo

        if param.sampling == 'None':
            # 'None' sampling means 'fully dense' sampling
            # see https://github.com/ANTsX/ANTs/wiki/antsRegistration-reproducibility-issues
            sampling = param.sampling
        else:
            # param.sampling should be a float in [0,1], and means the
            # samplingPercentage that chooses a subset of points to
            # estimate from. We always use 'Regular' (evenly-spaced)
            # mode, though antsRegistration offers 'Random' as well.
            # Be aware: even 'Regular' is not fully deterministic:
            # > Regular includes a random perturbation on the grid sampling
            # - https://github.com/ANTsX/ANTs/issues/976#issuecomment-602313884
            sampling = 'Regular,' + param.sampling

        metric_arg = param.metric + '[' + file_dest + ',' + file_src + ',1,' + metric_radius + ',' + sampling + ']'
        output_arg = '[' + file_mat + ',' + file_out_concat + ']'

        if is_sagittal:
            # 2D mode
            cmd = ['isct_antsRegistration',
                   '-d', '2',
                   '--transform', 'Affine[%s]' % param.gradStep,
                   '--metric', metric_arg,
                   '--convergence', param.iter,
                   '--shrink-factors', '1',
                   '--smoothing-sigmas', param.smooth,
                   '--verbose', '1',
                   '--output', output_arg]
            cmd += sct.get_interpolation('isct_antsRegistration', param.interp)
            if im_mask is not None:
                # if user specified a mask, make sure there are non-null voxels in the image before running the
                # registration
                if np.count_nonzero(im_mask.data):
                    cmd += ['--masks', im_mask.absolutepath]
                else:
                    # Mask only contains zeros. Copying the image instead of estimating registration.
                    sct.copy(file_src, file_out_concat, verbose=0)
                    do_registration = False
                    # TODO: create affine mat file with identity, in case used by -g 2
        else:
            # 3D mode
            cmd = ['isct_antsSliceRegularizedRegistration',
                   '--polydegree', param.poly,
                   '--transform', 'Translation[%s]' % param.gradStep,
                   '--metric', metric_arg,
                   '--iterations', param.iter,
                   '--shrinkFactors', '1',
                   '--smoothingSigmas', param.smooth,
                   '--verbose', '1',
                   '--output', output_arg]
            cmd += sct.get_interpolation('isct_antsSliceRegularizedRegistration', param.interp)
            if im_mask is not None:
                cmd += ['--mask', im_mask.absolutepath]

        # run command
        if do_registration:
            # reducing the number of CPU used for moco (see issue #201 and #2642)
            env = {**os.environ, **{"ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS": "1"}}
            run_proc(cmd, verbose=1 if param.verbose == 2 else 0, env=env, is_sct_binary=True)

    elif param.todo == 'apply':
        sct_apply_transfo.main(args=['-i', file_src,
                                     '-d', file_dest,
                                     '-w', file_mat + param.suffix_mat,
                                     '-o', file_out_concat,
                                     '-x', param.interp,
                                     '-v', '0'])

    # check if output file exists
    # Note (from JCA): In the past, i've tried to catch non-zero output from ANTs function (via the 'status' variable),
    # but in some OSs, the function can fail while outputing zero. So as a pragmatic approach, I decided to go with
    # the "output file checking" approach, which is 100% sensitive.
    if not os.path.isfile(file_out_concat):
        sct.printv('WARNING in ' + os.path.basename(__file__) + ': No output. Maybe related to improper calculation of '
                   'mutual information. Either the mask you provided is '
                   'too small, or the subject moved a lot. If you see too '
                   'many messages like this try with a bigger mask. '
                   'Using previous transformation for this volume (if it '
                   'exists).', param.verbose, 'warning')
        failed_transfo = 1

    # If sagittal, copy header (because ANTs screws it) and add singleton in 3rd dimension (for z-concatenation).
    # Skipped when registration failed: there is no output file to read, and the caller is expected to fall back
    # on a neighbouring transformation (see the returned failure flag).
    if is_sagittal and do_registration and not failed_transfo:
        im_out = Image(file_out_concat)
        im_out.header = im_data.header
        im_out.data = np.expand_dims(im_out.data, 2)
        im_out.save(file_out, verbose=0)

    # return status of failure
    return failed_transfo
def moco_wrapper(param):
    """
    Wrapper that performs motion correction.

    Steps: copy the input (and mask) to a temporary folder, split the data along T, build groups of
    ``param.group_size`` volumes and average them. Then estimate moco on the b=0 volumes (diffusion only) and across
    groups, apply the resulting transformations to every volume, optionally extract the motion parameters, and
    write the outputs to ``param.path_out``.

    :param param: ParamMoco class (fields such as is_sagittal, suffix_mat, fname_mask and group_size are updated
                  in place)
    :return: None
    """
    file_data = 'data.nii'  # corresponds to the full input data (e.g. dmri or fmri)
    file_data_dirname, file_data_basename, file_data_ext = sct.extract_fname(file_data)
    file_b0 = 'b0.nii'
    file_datasub = 'datasub.nii'  # corresponds to the full input data minus the b=0 scans (if param.is_diffusion=True)
    file_datasubgroup = 'datasub-groups.nii'  # concatenation of the average of each file_datasub
    file_mask = 'mask.nii'
    file_moco_params_csv = 'moco_params.tsv'
    file_moco_params_x = 'moco_params_x.nii.gz'
    file_moco_params_y = 'moco_params_y.nii.gz'
    ext_data = '.nii.gz'  # workaround "too many open files" by slurping the data
    # TODO: check if .nii can be used
    mat_final = 'mat_final/'
    # ext_mat = 'Warp.nii.gz'  # warping field

    # Start timer
    start_time = time.time()

    sct.printv('\nInput parameters:', param.verbose)
    sct.printv(' Input file ............ ' + param.fname_data, param.verbose)
    sct.printv(' Group size ............ {}'.format(param.group_size), param.verbose)

    # Get full path
    # param.fname_data = os.path.abspath(param.fname_data)
    # param.fname_bvecs = os.path.abspath(param.fname_bvecs)
    # if param.fname_bvals != '':
    #     param.fname_bvals = os.path.abspath(param.fname_bvals)

    # Extract path, file and extension
    # path_data, file_data, ext_data = sct.extract_fname(param.fname_data)
    # path_mask, file_mask, ext_mask = sct.extract_fname(param.fname_mask)

    path_tmp = sct.tmp_create(basename="moco", verbose=param.verbose)

    # Copying input data to tmp folder
    sct.printv('\nCopying input data to tmp folder and convert to nii...', param.verbose)
    convert(param.fname_data, os.path.join(path_tmp, file_data))
    if param.fname_mask != '':
        convert(param.fname_mask, os.path.join(path_tmp, file_mask), verbose=param.verbose)
        # Update field in param (because used later in another function, and param class will be passed)
        param.fname_mask = file_mask

    # Build absolute output path and go to tmp folder
    curdir = os.getcwd()
    path_out_abs = os.path.abspath(param.path_out)
    os.chdir(path_tmp)

    # Get dimensions of data
    sct.printv('\nGet dimensions of data...', param.verbose)
    im_data = Image(file_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    sct.printv(' ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), param.verbose)

    # Get orientation
    sct.printv('\nData orientation: ' + im_data.orientation, param.verbose)
    if im_data.orientation[2] in 'LR':
        param.is_sagittal = True
        sct.printv(' Treated as sagittal')
    elif im_data.orientation[2] in 'IS':
        param.is_sagittal = False
        sct.printv(' Treated as axial')
    else:
        param.is_sagittal = False
        sct.printv('WARNING: Orientation seems to be neither axial nor sagittal. Treated as axial.')

    sct.printv("\nSet suffix of transformation file name, which depends on the orientation:")
    if param.is_sagittal:
        param.suffix_mat = '0GenericAffine.mat'
        sct.printv("Orientation is sagittal, suffix is '{}'. The image is split across the R-L direction, and the "
                   "estimated transformation is a 2D affine transfo.".format(param.suffix_mat))
    else:
        param.suffix_mat = 'Warp.nii.gz'
        sct.printv("Orientation is axial, suffix is '{}'. The estimated transformation is a 3D warping field, which is "
                   "composed of a stack of 2D Tx-Ty transformations".format(param.suffix_mat))

    # Adjust group size in case of sagittal scan
    if param.is_sagittal and param.group_size != 1:
        sct.printv('For sagittal data group_size should be one for more robustness. Forcing group_size=1.', 1,
                   'warning')
        param.group_size = 1

    if param.is_diffusion:
        # Identify b=0 and DWI images
        index_b0, index_dwi, nb_b0, nb_dwi = \
            sct_dmri_separate_b0_and_dwi.identify_b0(param.fname_bvecs, param.fname_bvals, param.bval_min,
                                                     param.verbose)

        # check if dmri and bvecs are the same size
        if not nb_b0 + nb_dwi == nt:
            sct.printv('\nERROR in ' + os.path.basename(__file__) + ': Size of data (' + str(nt) + ') and size of '
                       'bvecs (' + str(nb_b0 + nb_dwi) + ') are not the same. Check your bvecs file.\n', 1, 'error')
            sys.exit(2)

    # ==================================================================================================================
    # Prepare data (mean/groups...)
    # ==================================================================================================================

    # Split into T dimension
    sct.printv('\nSplit along T dimension...', param.verbose)
    im_data_split_list = split_data(im_data, 3)
    for im in im_data_split_list:
        x_dirname, x_basename, x_ext = sct.extract_fname(im.absolutepath)
        im.absolutepath = os.path.join(x_dirname, x_basename + ".nii.gz")
        im.save()

    if param.is_diffusion:
        # Merge and average b=0 images
        sct.printv('\nMerge and average b=0 data...', param.verbose)
        im_b0_list = []
        for it in range(nb_b0):
            im_b0_list.append(im_data_split_list[index_b0[it]])
        im_b0 = concat_data(im_b0_list, 3).save(file_b0, verbose=0)
        # Average across time
        im_b0.mean(dim=3).save(sct.add_suffix(file_b0, '_mean'))

        n_moco = nb_dwi  # set number of data to perform moco on (using grouping)
        index_moco = index_dwi

    # If not a diffusion scan, we will motion-correct all volumes
    else:
        n_moco = nt
        index_moco = list(range(0, nt))

    nb_groups = int(math.floor(n_moco / param.group_size))

    # Generate groups indexes
    group_indexes = []
    for iGroup in range(nb_groups):
        group_indexes.append(index_moco[(iGroup * param.group_size):((iGroup + 1) * param.group_size)])

    # add the remaining images to a new last group (in case the total number of image is not divisible by group_size)
    nb_remaining = n_moco % param.group_size  # number of remaining images
    if nb_remaining > 0:
        nb_groups += 1
        group_indexes.append(index_moco[len(index_moco) - nb_remaining:len(index_moco)])

    _, file_dwi_basename, file_dwi_ext = sct.extract_fname(file_datasub)
    # Group data
    list_file_group = []
    for iGroup in sct_progress_bar(range(nb_groups), unit='iter', unit_scale=False, desc="Merge within groups",
                                   ascii=False, ncols=80):
        # get index
        index_moco_i = group_indexes[iGroup]
        n_moco_i = len(index_moco_i)
        # concatenate images across time, within this group
        file_dwi_merge_i = os.path.join(file_dwi_basename + '_' + str(iGroup) + ext_data)
        im_dwi_list = []
        for it in range(n_moco_i):
            im_dwi_list.append(im_data_split_list[index_moco_i[it]])
        im_dwi_out = concat_data(im_dwi_list, 3).save(file_dwi_merge_i, verbose=0)
        # Average across time
        list_file_group.append(os.path.join(file_dwi_basename + '_' + str(iGroup) + '_mean' + ext_data))
        im_dwi_out.mean(dim=3).save(list_file_group[-1])

    # Merge across groups
    sct.printv('\nMerge across groups...', param.verbose)
    # file_dwi_groups_means_merge = 'dwi_averaged_groups'
    im_dw_list = []
    for iGroup in range(nb_groups):
        im_dw_list.append(list_file_group[iGroup])
    concat_data(im_dw_list, 3).save(file_datasubgroup, verbose=0)

    # Cleanup
    del im, im_data_split_list

    # ==================================================================================================================
    # Estimate moco
    # ==================================================================================================================

    # Initialize another class instance that will be passed on to the moco() function
    param_moco = deepcopy(param)

    if param.is_diffusion:
        # Estimate moco on b0 groups
        sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
        sct.printv(' Estimating motion on b=0 images...', param.verbose)
        sct.printv('-------------------------------------------------------------------------------', param.verbose)
        param_moco.file_data = 'b0.nii'
        # Identify target image
        if index_moco[0] != 0:
            # If first DWI is not the first volume (most common), then there is a least one b=0 image before. In that
            # case select it as the target image for registration of all b=0
            param_moco.file_target = os.path.join(
                file_data_dirname,
                file_data_basename + '_T' + str(index_b0[index_moco[0] - 1]).zfill(4) + ext_data)
        else:
            # If first DWI is the first volume, then the target b=0 is the first b=0 from the index_b0.
            param_moco.file_target = os.path.join(
                file_data_dirname,
                file_data_basename + '_T' + str(index_b0[0]).zfill(4) + ext_data)
        # Run moco
        param_moco.path_out = ''
        param_moco.todo = 'estimate_and_apply'
        param_moco.mat_moco = 'mat_b0groups'
        file_mat_b0, _ = moco(param_moco)

    # Estimate moco across groups
    sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
    sct.printv(' Estimating motion across groups...', param.verbose)
    sct.printv('-------------------------------------------------------------------------------', param.verbose)
    param_moco.file_data = file_datasubgroup
    param_moco.file_target = list_file_group[0]  # target is the first volume (closest to the first b=0 if DWI scan)
    param_moco.path_out = ''
    param_moco.todo = 'estimate_and_apply'
    param_moco.mat_moco = 'mat_groups'
    file_mat_datasub_group, _ = moco(param_moco)

    # Spline Regularization along T
    if param.spline_fitting:
        # TODO: fix this scenario (haven't touched that code for a while-- it is probably buggy)
        raise NotImplementedError()
        # spline(mat_final, nt, nz, param.verbose, np.array(index_b0), param.plot_graph)

    # ==================================================================================================================
    # Apply moco
    # ==================================================================================================================

    # If group_size>1, assign transformation to each individual ungrouped 3d volume
    if param.group_size > 1:
        file_mat_datasub = []
        for iz in range(len(file_mat_datasub_group)):
            # duplicate by factor group_size the transformation file for each it
            #  example: [mat.Z0000T0001Warp.nii] --> [mat.Z0000T0001Warp.nii, mat.Z0000T0001Warp.nii] for group_size=2
            file_mat_datasub.append(
                functools.reduce(operator.iconcat, [[i] * param.group_size for i in file_mat_datasub_group[iz]], []))
    else:
        file_mat_datasub = file_mat_datasub_group

    # Copy transformations to mat_final folder and rename them appropriately
    copy_mat_files(nt, file_mat_datasub, index_moco, mat_final, param)
    if param.is_diffusion:
        copy_mat_files(nt, file_mat_b0, index_b0, mat_final, param)

    # Apply moco on all dmri data
    sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
    sct.printv(' Apply moco', param.verbose)
    sct.printv('-------------------------------------------------------------------------------', param.verbose)
    param_moco.file_data = file_data
    param_moco.file_target = list_file_group[0]  # reference for reslicing into proper coordinate system
    param_moco.path_out = ''  # TODO not used in moco()
    param_moco.mat_moco = mat_final
    param_moco.todo = 'apply'
    file_mat_data, im_moco = moco(param_moco)

    # copy geometric information from header
    # NB: this is required because WarpImageMultiTransform in 2D mode wrongly sets pixdim(3) to "1".
    im_moco.header = im_data.header
    im_moco.save(verbose=0)

    # Average across time
    if param.is_diffusion:
        # generate b0_moco_mean and dwi_moco_mean
        args = ['-i', im_moco.absolutepath, '-bvec', param.fname_bvecs, '-a', '1', '-v', '0']
        if not param.fname_bvals == '':
            # if bvals file is provided
            args += ['-bval', param.fname_bvals]
        fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean = sct_dmri_separate_b0_and_dwi.main(args=args)
    else:
        fname_moco_mean = sct.add_suffix(im_moco.absolutepath, '_mean')
        im_moco.mean(dim=3).save(fname_moco_mean)

    # Extract and output the motion parameters (doesn't work for sagittal orientation)
    sct.printv('Extract motion parameters...')
    if param.output_motion_param:
        if param.is_sagittal:
            sct.printv('Motion parameters cannot be generated for sagittal images.', 1, 'warning')
        else:
            files_warp_X, files_warp_Y = [], []
            moco_param = []
            for fname_warp in file_mat_data[0]:
                # Cropping the image to keep only one voxel in the XY plane
                im_warp = Image(fname_warp + param.suffix_mat)
                im_warp.data = np.expand_dims(np.expand_dims(im_warp.data[0, 0, :, :, :], axis=0), axis=0)

                # These three lines allow to generate one file instead of two, containing X, Y and Z moco parameters
                # fname_warp_crop = fname_warp + '_crop_' + ext_mat
                # files_warp.append(fname_warp_crop)
                # im_warp.save(fname_warp_crop)

                # Separating the three components and saving X and Y only (Z is equal to 0 by default).
                im_warp_XYZ = multicomponent_split(im_warp)

                fname_warp_crop_X = fname_warp + '_crop_X_' + param.suffix_mat
                im_warp_XYZ[0].save(fname_warp_crop_X)
                files_warp_X.append(fname_warp_crop_X)

                fname_warp_crop_Y = fname_warp + '_crop_Y_' + param.suffix_mat
                im_warp_XYZ[1].save(fname_warp_crop_Y)
                files_warp_Y.append(fname_warp_crop_Y)

                # Calculating the slice-wise average moco estimate to provide a QC file
                moco_param.append([np.mean(np.ravel(im_warp_XYZ[0].data)), np.mean(np.ravel(im_warp_XYZ[1].data))])

            # These two lines allow to generate one file instead of two, containing X, Y and Z moco parameters
            # im_warp_concat = concat_data(files_warp, dim=3)
            # im_warp_concat.save('fmri_moco_params.nii')

            # Concatenating the moco parameters into a time series for X and Y components.
            im_warp_concat = concat_data(files_warp_X, dim=3)
            im_warp_concat.save(file_moco_params_x)

            im_warp_concat = concat_data(files_warp_Y, dim=3)
            im_warp_concat.save(file_moco_params_y)

            # Writing a TSV file with the slicewise average estimate of the moco parameters. Useful for QC
            with open(file_moco_params_csv, 'wt') as out_file:
                tsv_writer = csv.writer(out_file, delimiter='\t')
                tsv_writer.writerow(['X', 'Y'])
                for mocop in moco_param:
                    tsv_writer.writerow([mocop[0], mocop[1]])

    # Generate output files
    sct.printv('\nGenerate output files...', param.verbose)
    fname_moco = os.path.join(path_out_abs, sct.add_suffix(os.path.basename(param.fname_data), param.suffix))
    sct.generate_output_file(im_moco.absolutepath, fname_moco)
    # generate b0_moco_mean and dwi_moco_mean
    if param.is_diffusion:
        sct.generate_output_file(fname_b0_mean, sct.add_suffix(fname_moco, '_b0_mean'))
        sct.generate_output_file(fname_dwi_mean, sct.add_suffix(fname_moco, '_dwi_mean'))
    else:
        sct.generate_output_file(fname_moco_mean, sct.add_suffix(fname_moco, '_mean'))
    # relative paths below are resolved against path_tmp (still the current directory at this point)
    if os.path.exists(file_moco_params_csv):
        sct.generate_output_file(file_moco_params_x, os.path.join(path_out_abs, file_moco_params_x),
                                 squeeze_data=False)
        sct.generate_output_file(file_moco_params_y, os.path.join(path_out_abs, file_moco_params_y),
                                 squeeze_data=False)
        sct.generate_output_file(file_moco_params_csv, os.path.join(path_out_abs, file_moco_params_csv))

    # come back to working directory before deleting the temporary folder: removing the current working directory
    # fails on some platforms (e.g. Windows)
    os.chdir(curdir)

    # Delete temporary files
    if param.remove_temp_files == 1:
        sct.printv('\nDelete temporary files...', param.verbose)
        sct.rmtree(path_tmp, verbose=param.verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's', param.verbose)

    sct.display_viewer_syntax(
        [os.path.join(param.path_out, sct.add_suffix(os.path.basename(param.fname_data), param.suffix)),
         param.fname_data],
        mode='ortho,ortho')
def main(args=None):
    """
    Separate the b=0 and DWI volumes of a diffusion dataset (command-line entry point).

    The b=0 volumes and the DWI volumes are each concatenated into one file. When '-a' is set, each file is also
    averaged across time.

    :param args: list of command-line arguments; defaults to ``sys.argv[1:]`` when empty/None.
    :return: tuple (fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean) of absolute output paths. The *_mean
             files are only generated when averaging was requested.
    """
    if not args:
        args = sys.argv[1:]

    # initialize parameters
    param = Param()

    # parse arguments
    parser = get_parser()
    arguments = parser.parse(args)

    fname_data = arguments['-i']
    fname_bvecs = arguments['-bvec']
    average = arguments['-a']
    verbose = int(arguments.get('-v'))
    sct.init_sct(log_level=verbose, update=True)  # Update log level
    remove_temp_files = int(arguments['-r'])
    path_out = arguments['-ofolder']

    if '-bval' in arguments:
        fname_bvals = arguments['-bval']
    else:
        fname_bvals = ''
    if '-bvalmin' in arguments:
        param.bval_min = arguments['-bvalmin']

    # Initialization
    start_time = time.time()

    sct.printv('\nInput parameters:', verbose)
    sct.printv(' input file ............' + fname_data, verbose)
    sct.printv(' bvecs file ............' + fname_bvecs, verbose)
    sct.printv(' bvals file ............' + fname_bvals, verbose)
    sct.printv(' average ...............' + str(average), verbose)

    # Get full path
    fname_data = os.path.abspath(fname_data)
    fname_bvecs = os.path.abspath(fname_bvecs)
    if fname_bvals:
        fname_bvals = os.path.abspath(fname_bvals)

    # Extract path, file and extension
    path_data, file_data, ext_data = sct.extract_fname(fname_data)

    # create temporary folder
    path_tmp = sct.tmp_create(basename="dmri_separate", verbose=verbose)

    # copy files into tmp folder and convert to nifti
    sct.printv('\nCopy files into temporary folder...', verbose)
    ext = '.nii'
    dmri_name = 'dmri'
    b0_name = file_data + '_b0'
    b0_mean_name = b0_name + '_mean'
    dwi_name = file_data + '_dwi'
    dwi_mean_name = dwi_name + '_mean'

    if not convert(fname_data, os.path.join(path_tmp, dmri_name + ext)):
        sct.printv('ERROR in convert.', 1, 'error')
    sct.copy(fname_bvecs, os.path.join(path_tmp, "bvecs"), verbose=verbose)

    # go to tmp folder
    curdir = os.getcwd()
    os.chdir(path_tmp)

    # Get size of data
    im_dmri = Image(dmri_name + ext)
    sct.printv('\nGet dimensions data...', verbose)
    nx, ny, nz, nt, px, py, pz, pt = im_dmri.dim
    sct.printv('.. ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz) + ' x ' + str(nt), verbose)

    # Identify b=0 and DWI images
    index_b0, index_dwi, nb_b0, nb_dwi = identify_b0(fname_bvecs, fname_bvals, param.bval_min, verbose)

    # Split into T dimension
    sct.printv('\nSplit along T dimension...', verbose)
    im_dmri_split_list = split_data(im_dmri, 3)
    for im_d in im_dmri_split_list:
        im_d.save()

    # Merge b=0 images
    sct.printv('\nMerge b=0...', verbose)
    from sct_image import concat_data
    # names of the per-volume files written by the split above
    b0_files = [dmri_name + '_T' + str(index_b0[it]).zfill(4) + ext for it in range(nb_b0)]
    concat_data(b0_files, 3).save(b0_name + ext)

    # Average b=0 images
    if average:
        sct.printv('\nAverage b=0...', verbose)
        sct.run(['sct_maths', '-i', b0_name + ext, '-o', b0_mean_name + ext, '-mean', 't'], verbose)

    # Merge DWI
    dwi_files = [dmri_name + '_T' + str(index_dwi[it]).zfill(4) + ext for it in range(nb_dwi)]
    concat_data(dwi_files, 3).save(dwi_name + ext)

    # Average DWI images
    if average:
        sct.printv('\nAverage DWI...', verbose)
        sct.run(['sct_maths', '-i', dwi_name + ext, '-o', dwi_mean_name + ext, '-mean', 't'], verbose)

    # come back
    os.chdir(curdir)

    # Generate output files
    fname_b0 = os.path.abspath(os.path.join(path_out, b0_name + ext_data))
    fname_dwi = os.path.abspath(os.path.join(path_out, dwi_name + ext_data))
    fname_b0_mean = os.path.abspath(os.path.join(path_out, b0_mean_name + ext_data))
    fname_dwi_mean = os.path.abspath(os.path.join(path_out, dwi_mean_name + ext_data))
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(os.path.join(path_tmp, b0_name + ext), fname_b0, verbose=verbose)
    sct.generate_output_file(os.path.join(path_tmp, dwi_name + ext), fname_dwi, verbose=verbose)
    if average:
        sct.generate_output_file(os.path.join(path_tmp, b0_mean_name + ext), fname_b0_mean, verbose=verbose)
        sct.generate_output_file(os.path.join(path_tmp, dwi_mean_name + ext), fname_dwi_mean, verbose=verbose)

    # Remove temporary files
    if remove_temp_files == 1:
        sct.printv('\nRemove temporary files...', verbose)
        sct.rmtree(path_tmp, verbose=verbose)

    # display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's', verbose)

    return fname_b0, fname_b0_mean, fname_dwi, fname_dwi_mean
def compute_texture(self):
    """
    Compute GLCM texture metrics voxel by voxel inside the segmentation, and save one image per metric.

    For each axial slice of ``self.dct_im_seg``, a voxel is processed only if its whole
    (2*distance+1) x (2*distance+1) window lies inside the slice and inside the segmentation. For each processed
    voxel, a grey-level co-occurrence matrix is computed for every angle listed in ``self.param_glcm.angle``
    (comma-separated, in degrees). Each metric of ``self.metric_lst`` is then evaluated on the matrix of its angle.

    Metric names are parsed on '_': field 0 is the property passed to ``greycoprops`` and field 2 is the angle key.
    The resulting images are saved in the current working directory, under a name derived from
    ``self.param.fname_im``. Their file names are stored in ``self.fname_metric_lst``.
    """
    offset = int(self.param_glcm.distance)  # half-width of the GLCM window (assumed >= 0)

    printv('\nCompute texture metrics...', self.param.verbose, 'normal')

    # open image and re-orient it to RPI if needed
    im_tmp = Image(self.param.fname_im)
    if self.orientation_im != self.orientation_extraction:
        im_tmp.change_orientation(self.orientation_extraction)

    # one float64 output image per metric, initialised with zeros
    dct_metric = {m: zeros_like(im_tmp, dtype='float64') for m in self.metric_lst}

    # Loop-invariant parsing, done once instead of once per voxel:
    # (angle key, angle in radians) for each requested angle...
    angles = [(a, np.radians(int(a))) for a in self.param_glcm.angle.split(',')]
    # ...and (GLCM property, angle key) for each metric
    metric_props = {m: (m.split('_')[0], m.split('_')[2]) for m in self.metric_lst}
    n_slices = len(self.dct_im_seg['im'])

    with sct_progress_bar() as pbar:
        for zz, (im_z, seg_z) in enumerate(zip(self.dct_im_seg['im'], self.dct_im_seg['seg'])):
            # only visit voxels whose whole GLCM window fits inside the axial slice
            for xx in range(offset, im_z.shape[0] - offset):
                for yy in range(offset, im_z.shape[1] - offset):
                    if not seg_z[xx, yy]:
                        continue
                    # skip the voxel unless the whole GLCM window is in the mask of the axial slice
                    if not np.all(seg_z[xx - offset:xx + offset + 1, yy - offset:yy + offset + 1]):
                        continue

                    glcm_window = im_z[xx - offset:xx + offset + 1, yy - offset:yy + offset + 1].astype(np.uint8)

                    # compute the GLCM for self.param_glcm.distance and for each self.param_glcm.angle
                    dct_glcm = {
                        a: greycomatrix(glcm_window, [self.param_glcm.distance], [rad],
                                        symmetric=self.param_glcm.symmetric, normed=self.param_glcm.normed)
                        for a, rad in angles
                    }

                    # compute the GLCM property of the voxel xx,yy,zz for each metric
                    for m, (prop, angle) in metric_props.items():
                        dct_metric[m].data[xx, yy, zz] = greycoprops(dct_glcm[angle], prop)[0][0]

            pbar.set_postfix(pos="{}/{}".format(zz, n_slices))
            pbar.update(1)

    for m in self.metric_lst:
        fname_out = add_suffix("".join(extract_fname(self.param.fname_im)[1:]), '_' + m)
        dct_metric[m].save(fname_out)
        self.fname_metric_lst[m] = fname_out
def resample_image(fname, suffix='_resampled.nii.gz', binary=False, npx=0.3, npy=0.3, thr=0.0,
                   interpolation='spline'):
    """
    Resampling function: add a padding, resample, crop the padding
    :param fname: name of the image file to be resampled
    :param suffix: suffix added to the original fname after resampling
    :param binary: boolean, image is binary or not
    :param npx: new pixel size in the x direction
    :param npy: new pixel size in the y direction
    :param thr: if the image is binary, it will be thresholded at thr (default=0) after the resampling
    :param interpolation: type of interpolation used for the resampling
    :return: file name after resampling (or original fname if it was already in the correct resolution)
    """
    im_in = Image(fname)
    orientation = im_in.orientation
    if orientation != 'RPI':
        # Image.change_orientation takes the target orientation as its first argument (same call form as used
        # below to restore the original orientation).
        fname = im_in.change_orientation('RPI', generate_path=True).save().absolutepath

    nx, ny, nz, nt, px, py, pz, pt = im_in.dim

    if np.round(px, 2) != np.round(npx, 2) or np.round(py, 2) != np.round(npy, 2):
        name_resample = sct.extract_fname(fname)[1] + suffix
        if binary:
            interpolation = 'nn'

        if nz == 1:
            # when data is 2d: we convert it to a 3d image in order to avoid conversion problem with 2d data
            # TODO: check if this above problem is still present (now that we are using nibabel instead of nipy)
            sct.run(['sct_image', '-i', ','.join([fname, fname]), '-concat', 'z', '-o', fname])

        sct.run(['sct_resample', '-i', fname, '-mm', str(npx) + 'x' + str(npy) + 'x' + str(pz), '-o',
                 name_resample, '-x', interpolation])

        if nz == 1:
            # when input data was 2d: re-convert data 3d-->2d
            sct.run(['sct_image', '-i', name_resample, '-split', 'z'])
            im_split = Image(name_resample.split('.nii.gz')[0] + '_Z0000.nii.gz')
            im_split.save(name_resample)

        if binary:
            sct.run(['sct_maths', '-i', name_resample, '-bin', str(thr), '-o', name_resample])

        if orientation != 'RPI':
            name_resample = Image(name_resample) \
                .change_orientation(orientation, generate_path=True) \
                .save() \
                .absolutepath

        return name_resample
    else:
        if orientation != 'RPI':
            # NOTE(review): the image is saved back in its original orientation, but under a name suffixed "_RPI".
            fname = sct.add_suffix(fname, "_RPI")
            msct_image.change_orientation(im_in, orientation).save(fname)

        sct.printv('Image resolution already ' + str(npx) + 'x' + str(npy) + 'x' + str(pz))
        return fname
def crop_with_gui(self):
    """
    Interactively crop the input volume along z.

    The input is converted to NIfTI in a temporary folder and reoriented to RPI. The slice at the middle of the
    first (x) axis is displayed, and the user clicks the two z-limits of the cropping field. The volume is then
    cropped between these limits and written to the current directory as '<input name>_crop<ext>'.

    Reads self.input_filename, self.rm_tmp_files and self.verbose.
    Exits the program (status 2) if the data is 4D or if the user did not click exactly two points.
    """
    import matplotlib.pyplot as plt

    # Initialization
    fname_data = self.input_filename
    suffix_out = '_crop'
    remove_temp_files = self.rm_tmp_files
    verbose = self.verbose

    # Check file existence
    sct.printv('\nCheck file existence...', verbose)
    sct.check_file_exist(fname_data, verbose)

    # Get dimensions of data
    sct.printv('\nGet dimensions of data...', verbose)
    nx, ny, nz, nt, px, py, pz, pt = Image(fname_data).dim
    sct.printv('.. ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), verbose)
    # check if 4D data
    if not nt == 1:
        sct.printv('\nERROR in ' + os.path.basename(__file__) + ': Data should be 3D.\n', 1, 'error')
        sys.exit(2)

    sct.printv('\nCheck parameters:')
    sct.printv(' data ................... ' + fname_data)

    # Extract path/file/extension
    path_data, file_data, ext_data = sct.extract_fname(fname_data)
    path_out, file_out, ext_out = '', file_data + suffix_out, ext_data

    path_tmp = sct.tmp_create() + "/"

    # copy files into tmp folder
    from sct_convert import convert
    sct.printv('\nCopying input data to tmp folder and convert to nii...', verbose)
    convert(fname_data, os.path.join(path_tmp, "data.nii"))

    # go to tmp folder; the try/finally guarantees we come back even on sys.exit() or an error
    curdir = os.getcwd()
    os.chdir(path_tmp)
    try:
        # change orientation
        sct.printv('\nChange orientation to RPI...', verbose)
        Image('data.nii').change_orientation("RPI").save('data_rpi.nii')

        # get image of medial slab (displayed directly: no need for the removed scipy.misc.imsave + JPEG round-trip)
        sct.printv('\nGet image of medial slab...', verbose)
        image_array = Image('data_rpi.nii').data
        nx, ny, nz = image_array.shape
        slice_medial = image_array[nx // 2, :, :]

        # Display the image: rows are z, columns are y
        sct.printv('\nDisplay image and get cropping region...', verbose)
        fig = plt.figure()
        ax = fig.add_subplot(111)
        implot = ax.imshow(slice_medial.T)
        implot.set_cmap('gray')
        plt.gca().invert_yaxis()
        # mouse callback
        ax.set_title('Left click on the top and bottom of your cropping field.\n Right click to remove last point.\n Close window when your done.')
        line, = ax.plot([], [], 'ro')  # empty line
        cropping_coordinates = LineBuilder(line)
        plt.show()

        # check if user clicked two times
        if len(cropping_coordinates.xs) != 2:
            sct.printv('\nERROR: You have to select two points. Exit program.\n', 1, 'error')
            sys.exit(2)

        # convert the clicked y (= z) coordinates to integer indices, lowest first
        zcrop = sorted(int(i) for i in cropping_coordinates.ys)

        # crop image
        sct.printv('\nCrop image...', verbose)
        nii = Image('data_rpi.nii')
        nii.data = nii.data[:, :, zcrop[0]:zcrop[1]]
        nii.absolutepath = 'data_rpi_crop.nii'
        nii.save()
    finally:
        # come back
        os.chdir(curdir)

    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(os.path.join(path_tmp, "data_rpi_crop.nii"), os.path.join(path_out, file_out + ext_out))

    # Remove temporary files
    if remove_temp_files == 1:
        sct.printv('\nRemove temporary files...')
        sct.rmtree(path_tmp)

    sct.display_viewer_syntax(files=[os.path.join(path_out, file_out + ext_out)])
def main():
    """
    Command-line entry point: segment the spinal cord with the deep-learning model.

    Parses the command line, validates option combinations, runs deep_segmentation_spinalcord(), saves the
    segmentation as '<input name>_seg<ext>' in the output folder, optionally generates a QC report, and prints
    the viewer syntax.
    """
    parser = get_parser()
    args = parser.parse_args(args=None if sys.argv[1:] else ['--help'])

    fname_image = os.path.abspath(args.i)
    contrast_type = args.c

    ctr_algo = args.centerline

    if args.brain is None:
        # Default: brain assumed in the FOV for t1/t2 only. A single expression avoids leaving brain_bool
        # unbound for any other contrast.
        brain_bool = contrast_type in ['t1', 't2']
    else:
        brain_bool = bool(args.brain)

    if bool(args.brain) and ctr_algo == 'svm':
        printv('Please only use the flag "-brain 1" with "-centerline cnn".', 1, 'warning')
        sys.exit(1)

    kernel_size = args.kernel
    if kernel_size == '3d' and contrast_type == 'dwi':
        kernel_size = '2d'
        printv('3D kernel model for dwi contrast is not available. 2D kernel model is used instead.',
               type="warning")

    if ctr_algo == 'file' and args.file_centerline is None:
        printv('Please use the flag -file_centerline to indicate the centerline filename.', 1, 'warning')
        sys.exit(1)

    # a user-provided centerline file always forces the 'file' centerline method
    if args.file_centerline is not None:
        manual_centerline_fname = args.file_centerline
        ctr_algo = 'file'
    else:
        manual_centerline_fname = None

    threshold = args.thr
    if threshold is not None:
        if threshold > 1.0 or (threshold < 0.0 and threshold != -1.0):
            raise SyntaxError("Threshold should be between 0 and 1, or equal to -1 (no threshold)")

    remove_temp_files = args.r
    verbose = args.v
    init_sct(log_level=verbose, update=True)  # Update log level

    path_qc = args.qc
    qc_dataset = args.qc_dataset
    qc_subject = args.qc_subject
    output_folder = args.ofolder

    # check if input image is 2D or 3D
    check_dim(fname_image, dim_lst=[2, 3])

    # Segment image
    im_image = Image(fname_image)
    # note: below we pass im_image.copy() otherwise the field absolutepath becomes None after execution of this function
    im_seg, im_image_RPI_upsamp, im_seg_RPI_upsamp = \
        deep_segmentation_spinalcord(im_image.copy(), contrast_type, ctr_algo=ctr_algo,
                                     ctr_file=manual_centerline_fname, brain_bool=brain_bool,
                                     kernel_size=kernel_size, threshold_seg=threshold,
                                     remove_temp_files=remove_temp_files, verbose=verbose)

    # Save segmentation
    _, file_image, ext_image = extract_fname(fname_image)
    fname_seg = os.path.abspath(os.path.join(output_folder, file_image + '_seg' + ext_image))
    im_seg.save(fname_seg)

    # Generate QC report
    if path_qc is not None:
        generate_qc(fname_image, fname_seg=fname_seg, args=sys.argv[1:], path_qc=os.path.abspath(path_qc),
                    dataset=qc_dataset, subject=qc_subject, process='sct_deepseg_sc')
    display_viewer_syntax([fname_image, fname_seg], colormaps=['gray', 'red'], opacities=['', '0.7'])
def compute_shape(segmentation, angle_correction=True, param_centerline=None, verbose=1):
    """
    Compute morphometric measures of the spinal cord in the transverse (axial) plane from the segmentation.
    The segmentation could be binary or weighted for partial volume [0,1].

    The segmentation is reoriented to RPI and resampled to an isotropic in-plane resolution (the smallest of
    px/py). Each slice between the first and last non-empty slices is then analysed. With angle correction,
    the slice is rescaled by the cosine of its angles with the centerline before its properties are computed.

    :param segmentation: input segmentation. Could be either an Image or a file name.
    :param angle_correction: if True, correct each slice for the angle between the centerline and the slice normal.
    :param param_centerline: see centerline.core.ParamCenterline()
    :param verbose:
    :return metrics: Dict of class Metric(). If a metric cannot be calculated, its value will be nan.
    :return fit_results: class centerline.core.FitResults() (None if angle_correction is False)
    :raises ValueError: if the segmentation contains no non-zero voxel (min/max of an empty sequence).
    """
    # List of properties to output (in the right order)
    property_list = ['area',
                     'angle_AP',
                     'angle_RL',
                     'diameter_AP',
                     'diameter_RL',
                     'eccentricity',
                     'orientation',
                     'solidity',
                     'length'
                     ]

    im_seg = Image(segmentation).change_orientation('RPI')

    # Getting image dimensions. x, y and z respectively correspond to RL, PA and IS.
    nx, ny, nz, nt, px, py, pz, pt = im_seg.dim
    pr = min([px, py])
    # Resample to isotropic resolution in the axial plane. Use the minimum pixel dimension as target dimension.
    im_segr = resample_nib(im_seg, new_size=[pr, pr, pz], new_size_type='mm', interpolation='linear')

    # Update dimensions from resampled image.
    nx, ny, nz, nt, px, py, pz, pt = im_segr.dim

    # Extract min and max index in Z direction
    data_seg = im_segr.data
    _, _, Z = (data_seg > 0).nonzero()
    min_z_index, max_z_index = min(Z), max(Z)

    # Initialize dictionary of property_list, with 1d array of nan (default value if no property for a given slice).
    shape_properties = {key: np.full(nz, np.nan, dtype=np.double) for key in property_list}

    fit_results = None

    if angle_correction:
        # compute the spinal cord centerline based on the spinal cord segmentation
        # here, param_centerline.minmax needs to be False because we need to retrieve the total number of input slices
        _, arr_ctl, arr_ctl_der, fit_results = get_centerline(im_segr, param=param_centerline, verbose=verbose)

    # Loop across z and compute shape analysis
    for iz in tqdm(range(min_z_index, max_z_index + 1), unit='iter', unit_scale=False, desc="Compute shape analysis",
                   ascii=True, ncols=80):
        # Extract 2D patch
        current_patch = im_segr.data[:, :, iz]
        if angle_correction:
            # Extract tangent vector to the centerline (i.e. its derivative)
            tangent_vect = np.array([arr_ctl_der[0][iz - min_z_index] * px,
                                     arr_ctl_der[1][iz - min_z_index] * py,
                                     pz])
            # Normalize vector by its L2 norm
            tangent_vect = tangent_vect / np.linalg.norm(tangent_vect)
            # Compute the angle about AP axis between the centerline and the normal vector to the slice
            # (math.atan2 instead of the np.math alias, which is removed in NumPy 2)
            v0 = [tangent_vect[0], tangent_vect[2]]
            v1 = [0, 1]
            angle_AP_rad = math.atan2(np.linalg.det([v0, v1]), np.dot(v0, v1))
            # Compute the angle about RL axis between the centerline and the normal vector to the slice
            v0 = [tangent_vect[1], tangent_vect[2]]
            v1 = [0, 1]
            angle_RL_rad = math.atan2(np.linalg.det([v0, v1]), np.dot(v0, v1))
            # Apply affine transformation to account for the angle between the centerline and the normal to the patch
            tform = transform.AffineTransform(scale=(np.cos(angle_RL_rad), np.cos(angle_AP_rad)))
            # Convert to float64, to avoid problems in image indexation causing issues when applying transform.warp
            current_patch = current_patch.astype(np.float64)
            # TODO: make sure pattern does not go extend outside of image border
            current_patch_scaled = transform.warp(current_patch,
                                                  tform.inverse,
                                                  output_shape=current_patch.shape,
                                                  order=1,
                                                  )
        else:
            current_patch_scaled = current_patch
            angle_AP_rad, angle_RL_rad = 0.0, 0.0
        # compute shape properties on 2D patch
        shape_property = _properties2d(current_patch_scaled, [px, py])
        if shape_property is not None:
            # Add custom fields
            shape_property['angle_AP'] = angle_AP_rad * 180.0 / math.pi
            shape_property['angle_RL'] = angle_RL_rad * 180.0 / math.pi
            # slice thickness along the centerline
            shape_property['length'] = pz / (np.cos(angle_AP_rad) * np.cos(angle_RL_rad))
            # Loop across properties and assign values for function output
            for property_name in property_list:
                shape_properties[property_name][iz] = shape_property[property_name]
        else:
            logging.warning('\nNo properties for slice: {}'.format(iz))

    metrics = {}
    for key, value in shape_properties.items():
        # Making sure all entries added to metrics have results.
        # NOTE: `value == []` on an ndarray raises on shape mismatch since NumPy 1.25, so test the size instead.
        if value.size > 0:
            metrics[key] = Metric(data=np.array(value), label=key)

    return metrics, fit_results
def detect_centerline(img, contrast, verbose=1):
    """Detect spinal cord centerline using OptiC.

    The image is copied, cleaned (NaN/Inf set to 0), and rescaled to the full uint16 range. It is then
    reoriented to RPI and passed to the isct_spine_detect binary inside a temporary folder. The resulting
    centerline is reoriented back to the input orientation.

    :param img: input Image() object. Its data array is not written to directly (assumes Image.copy() copies
        the data array — TODO confirm).
    :param contrast: str: The type of contrast. Will define the path to Optic model.
    :param verbose: if >= 2, the temporary folder is kept.
    :returns: Image(): Output centerline
    """
    # Fetch path to Optic model based on contrast
    optic_models_path = sct_dir_local_path('data', 'optic_models', '{}_model'.format(contrast))

    logger.debug('Detecting the spinal cord using OptiC')
    img_orientation = img.orientation

    temp_folder = TempFolder()
    temp_folder.chdir()
    try:
        # convert image data type to uint16, as required by opencv (backend in OptiC).
        # Work on a copy only, so that the caller's image is not modified.
        img_int16 = img.copy()

        # Replace non-numeric values (NaN and +/-Inf) by zero
        img_data = img_int16.data
        img_data[~np.isfinite(img_data)] = 0

        # rescale intensity
        min_out = np.iinfo('uint16').min
        max_out = np.iinfo('uint16').max
        # compute the range in float to avoid integer overflow for integer input dtypes
        min_in = float(np.nanmin(img_data))
        max_in = float(np.nanmax(img_data))
        if max_in > min_in:
            data_rescaled = img_data.astype('float') * (max_out - min_out) / (max_in - min_in)
            img_int16.data = data_rescaled - (data_rescaled.min() - min_out)
        else:
            # constant image: the original formula would divide by zero
            img_int16.data = np.full(img_data.shape, min_out, dtype='float')

        # change data type
        img_int16.change_type(np.uint16)

        # reorient the input image to RPI + convert to .nii
        img_int16.change_orientation('RPI')
        file_img = 'img_rpi_uint16'
        img_int16.save(file_img + '.nii')

        # call the OptiC method to generate the spinal cord centerline
        optic_input = file_img
        optic_filename = file_img + '_optic'

        os.environ["FSLOUTPUTTYPE"] = "NIFTI_PAIR"
        cmd_optic = [
            'isct_spine_detect',
            '-ctype=dpdt',
            '-lambda=1',
            optic_models_path,
            optic_input,
            optic_filename,
        ]
        # TODO: output coordinates, for each slice, in continuous (not discrete) values.

        run_proc(cmd_optic, is_sct_binary=True, verbose=0)

        # convert .img and .hdr files to .nii.gz
        img_ctl = Image(file_img + '_optic_ctr.hdr')
        img_ctl.change_orientation(img_orientation)
    finally:
        # return to initial folder, even if OptiC or I/O failed
        temp_folder.chdir_undo()

    if verbose < 2:
        logger.info("Remove temporary files...")
        temp_folder.cleanup()

    return img_ctl
def main(args=None):
    """
    Label vertebral levels on a spinal cord segmentation.

    All processing is done in a temporary folder:
    1. Straighten the spinal cord, or reuse cached warping fields found in the current directory.
    2. Resample to 0.5 mm isotropic and straighten the segmentation.
    3. Either label from a user-provided disc file, or initialize the labeling and detect vertebral levels on
       the straight cord. The initialization comes from -initz/-initcenter/-initfile, -initlabel, or automatic
       C2-C3 detection.
    4. Un-straighten and clean the labeling, then label the discs.
    5. Write the labeled segmentation, the labeled discs, and the straightening warping fields to the output folder.

    :param args: list of command-line arguments. If None or empty, sys.argv is used.
    """
    # initializations
    initz = ''
    initcenter = ''
    fname_initlabel = ''
    file_labelz = 'labelz.nii.gz'
    param = Param()

    # check user arguments
    parser = get_parser()
    if args:
        arguments = parser.parse_args(args)
    else:
        arguments = parser.parse_args(args=None if sys.argv[1:] else ['--help'])
    fname_in = os.path.abspath(arguments.i)
    fname_seg = os.path.abspath(arguments.s)
    contrast = arguments.c
    path_template = os.path.abspath(arguments.t)
    scale_dist = arguments.scale_dist
    path_output = arguments.ofolder
    param.path_qc = arguments.qc
    if arguments.discfile is not None:
        fname_disc = os.path.abspath(arguments.discfile)
    else:
        fname_disc = None
    if arguments.initz is not None:
        initz = arguments.initz
    if arguments.initcenter is not None:
        initcenter = arguments.initcenter
    # if user provided text file, parse and overwrite arguments
    if arguments.initfile is not None:
        # context manager: the file handle is closed (and the builtin name `file` is no longer shadowed)
        with open(arguments.initfile, 'r') as f_initfile:
            initfile = ' ' + f_initfile.read().replace('\n', '')
        arg_initfile = initfile.split(' ')
        for idx_arg, arg in enumerate(arg_initfile):
            if arg == '-initz':
                initz = [int(x) for x in arg_initfile[idx_arg + 1].split(',')]
            if arg == '-initcenter':
                initcenter = int(arg_initfile[idx_arg + 1])
    if arguments.initlabel is not None:
        # get absolute path of label
        fname_initlabel = os.path.abspath(arguments.initlabel)
    if arguments.param is not None:
        param.update(arguments.param[0])
    verbose = int(arguments.v)
    init_sct(log_level=verbose, update=True)  # Update log level
    remove_temp_files = arguments.r
    denoise = arguments.denoise
    laplacian = arguments.laplacian

    path_tmp = sct.tmp_create(basename="label_vertebrae", verbose=verbose)

    # Copying input data to tmp folder
    sct.printv('\nCopying input data to tmp folder...', verbose)
    Image(fname_in).save(os.path.join(path_tmp, "data.nii"))
    Image(fname_seg).save(os.path.join(path_tmp, "segmentation.nii"))

    # Go go temp folder
    curdir = os.getcwd()
    os.chdir(path_tmp)

    # Straighten spinal cord
    sct.printv('\nStraighten spinal cord...', verbose)
    # check if warp_curve2straight and warp_straight2curve already exist (i.e. no need to do it another time)
    cache_sig = sct.cache_signature(input_files=[fname_in, fname_seg])
    cachefile = os.path.join(curdir, "straightening.cache")
    cached_files = ["warp_curve2straight.nii.gz", "warp_straight2curve.nii.gz", "straight_ref.nii.gz"]
    if sct.cache_valid(cachefile, cache_sig) and all(os.path.isfile(os.path.join(curdir, f)) for f in cached_files):
        # if they exist, copy them into current folder
        sct.printv('Reusing existing warping field which seems to be valid', verbose, 'warning')
        for f in cached_files:
            sct.copy(os.path.join(curdir, f), f)
        # apply straightening
        s, o = run_proc(['sct_apply_transfo',
                         '-i', 'data.nii',
                         '-w', 'warp_curve2straight.nii.gz',
                         '-d', 'straight_ref.nii.gz',
                         '-o', 'data_straight.nii'])
    else:
        sct_straighten_spinalcord.main(args=[
            '-i', 'data.nii',
            '-s', 'segmentation.nii',
            '-r', str(remove_temp_files),
            '-v', str(verbose),
        ])
        sct.cache_save(cachefile, cache_sig)

    # resample to 0.5mm isotropic to match template resolution
    sct.printv('\nResample to 0.5mm isotropic...', verbose)
    s, o = run_proc(['sct_resample',
                     '-i', 'data_straight.nii',
                     '-mm', '0.5x0.5x0.5',
                     '-x', 'linear',
                     '-o', 'data_straightr.nii'], verbose=verbose)

    # Apply straightening to segmentation
    # N.B. Output is RPI
    sct.printv('\nApply straightening to segmentation...', verbose)
    run_proc('isct_antsApplyTransforms -d 3 -i %s -r %s -t %s -o %s -n %s' %
             ('segmentation.nii',
              'data_straightr.nii',
              'warp_curve2straight.nii.gz',
              'segmentation_straight.nii',
              'Linear'),
             verbose=verbose,
             is_sct_binary=True,
             )
    # Threshold segmentation at 0.5
    run_proc(['sct_maths', '-i', 'segmentation_straight.nii', '-thr', '0.5', '-o', 'segmentation_straight.nii'],
             verbose)

    # If disc label file is provided, label vertebrae using that file instead of automatically
    if fname_disc:
        # Apply straightening to disc-label
        sct.printv('\nApply straightening to disc labels...', verbose)
        run_proc('isct_antsApplyTransforms -d 3 -i %s -r %s -t %s -o %s -n %s' %
                 (fname_disc,
                  'data_straightr.nii',
                  'warp_curve2straight.nii.gz',
                  'labeldisc_straight.nii.gz',
                  'NearestNeighbor'),
                 verbose=verbose,
                 is_sct_binary=True,
                 )
        label_vert('segmentation_straight.nii', 'labeldisc_straight.nii.gz', verbose=1)

    else:
        # create label to identify disc
        sct.printv('\nCreate label to identify disc...', verbose)
        fname_labelz = os.path.join(path_tmp, file_labelz)
        if initz or initcenter:
            if initcenter:
                # find z centered in FOV
                nii = Image('segmentation.nii').change_orientation("RPI")
                nx, ny, nz, nt, px, py, pz, pt = nii.dim  # Get dimensions
                z_center = int(np.round(nz / 2))  # get z_center
                initz = [z_center, initcenter]

            im_label = create_labels_along_segmentation(Image('segmentation.nii'), [(initz[0], initz[1])])
            im_label.data = dilate(im_label.data, 3, 'ball')
            im_label.save(fname_labelz)
        elif fname_initlabel:
            Image(fname_initlabel).save(fname_labelz)
        else:
            # automatically finds C2-C3 disc
            im_data = Image('data.nii')
            im_seg = Image('segmentation.nii')
            if not remove_temp_files:  # because verbose is here also used for keeping temp files
                verbose_detect_c2c3 = 2
            else:
                verbose_detect_c2c3 = 0
            im_label_c2c3 = detect_c2c3(im_data, im_seg, contrast, verbose=verbose_detect_c2c3)
            ind_label = np.where(im_label_c2c3.data)
            if not np.size(ind_label) == 0:
                im_label_c2c3.data[ind_label] = 3
            else:
                sct.printv('Automatic C2-C3 detection failed. Please provide manual label with sct_label_utils', 1,
                           'error')
                # non-zero exit status: this is a failure (a bare sys.exit() reported success)
                sys.exit(1)
            im_label_c2c3.save(fname_labelz)

        # dilate label so it is not lost when applying warping
        dilate(Image(fname_labelz), 3, 'ball').save(fname_labelz)

        # Apply straightening to z-label
        sct.printv('\nAnd apply straightening to label...', verbose)
        run_proc('isct_antsApplyTransforms -d 3 -i %s -r %s -t %s -o %s -n %s' %
                 (file_labelz,
                  'data_straightr.nii',
                  'warp_curve2straight.nii.gz',
                  'labelz_straight.nii.gz',
                  'NearestNeighbor'),
                 verbose=verbose,
                 is_sct_binary=True,
                 )
        # get z value and disk value to initialize labeling
        sct.printv('\nGet z and disc values from straight label...', verbose)
        init_disc = get_z_and_disc_values_from_label('labelz_straight.nii.gz')
        sct.printv('.. ' + str(init_disc), verbose)

        # denoise data
        if denoise:
            sct.printv('\nDenoise data...', verbose)
            run_proc(['sct_maths', '-i', 'data_straightr.nii', '-denoise', 'h=0.05', '-o', 'data_straightr.nii'],
                     verbose)

        # apply laplacian filtering
        if laplacian:
            sct.printv('\nApply Laplacian filter...', verbose)
            run_proc(['sct_maths', '-i', 'data_straightr.nii', '-laplacian', '1', '-o', 'data_straightr.nii'],
                     verbose)

        # detect vertebral levels on straight spinal cord
        init_disc[1] = init_disc[1] - 1
        vertebral_detection('data_straightr.nii', 'segmentation_straight.nii', contrast, param, init_disc=init_disc,
                            verbose=verbose, path_template=path_template, path_output=path_output,
                            scale_dist=scale_dist)

    # un-straighten labeled spinal cord
    sct.printv('\nUn-straighten labeling...', verbose)
    run_proc('isct_antsApplyTransforms -d 3 -i %s -r %s -t %s -o %s -n %s' %
             ('segmentation_straight_labeled.nii',
              'segmentation.nii',
              'warp_straight2curve.nii.gz',
              'segmentation_labeled.nii',
              'NearestNeighbor'),
             verbose=verbose,
             is_sct_binary=True,
             )

    # Clean labeled segmentation
    sct.printv('\nClean labeled segmentation (correct interpolation errors)...', verbose)
    clean_labeled_segmentation('segmentation_labeled.nii', 'segmentation.nii', 'segmentation_labeled.nii')

    # label discs
    sct.printv('\nLabel discs...', verbose)
    label_discs('segmentation_labeled.nii', verbose=verbose)

    # come back
    os.chdir(curdir)

    # Generate output files
    path_seg, file_seg, ext_seg = sct.extract_fname(fname_seg)
    fname_seg_labeled = os.path.join(path_output, file_seg + '_labeled' + ext_seg)
    sct.printv('\nGenerate output files...', verbose)
    sct.generate_output_file(os.path.join(path_tmp, "segmentation_labeled.nii"), fname_seg_labeled)
    sct.generate_output_file(os.path.join(path_tmp, "segmentation_labeled_disc.nii"),
                             os.path.join(path_output, file_seg + '_labeled_discs' + ext_seg))
    # copy straightening files in case subsequent SCT functions need them
    for fname_warp in ["warp_curve2straight.nii.gz", "warp_straight2curve.nii.gz", "straight_ref.nii.gz"]:
        sct.generate_output_file(os.path.join(path_tmp, fname_warp), os.path.join(path_output, fname_warp),
                                 verbose=verbose)

    # Remove temporary files
    if remove_temp_files == 1:
        sct.printv('\nRemove temporary files...', verbose)
        sct.rmtree(path_tmp)

    # Generate QC report
    if param.path_qc is not None:
        path_qc = os.path.abspath(arguments.qc)
        qc_dataset = arguments.qc_dataset
        qc_subject = arguments.qc_subject
        labeled_seg_file = os.path.join(path_output, file_seg + '_labeled' + ext_seg)
        # record the actual command line: `args` is None when called from the CLI
        generate_qc(fname_in, fname_seg=labeled_seg_file, args=args if args else sys.argv[1:],
                    path_qc=os.path.abspath(path_qc), dataset=qc_dataset, subject=qc_subject,
                    process='sct_label_vertebrae')

    sct.display_viewer_syntax([fname_in, fname_seg_labeled], colormaps=['', 'subcortical'], opacities=['1', '0.5'])
def straighten(self): """ Straighten spinal cord. Steps: (everything is done in physical space) 1. open input image and centreline image 2. extract bspline fitting of the centreline, and its derivatives 3. compute length of centerline 4. compute and generate straight space 5. compute transformations for each voxel of one space: (done using matrices --> improves speed by a factor x300) a. determine which plane of spinal cord centreline it is included b. compute the position of the voxel in the plane (X and Y distance from centreline, along the plane) c. find the correspondant centreline point in the other space d. find the correspondance of the voxel in the corresponding plane 6. generate warping fields for each transformations 7. write warping fields and apply them step 5.b: how to find the corresponding plane? The centerline plane corresponding to a voxel correspond to the nearest point of the centerline. However, we need to compute the distance between the voxel position and the plane to be sure it is part of the plane and not too distant. If it is more far than a threshold, warping value should be 0. step 5.d: how to make the correspondance between centerline point in both images? Both centerline have the same lenght. Therefore, we can map centerline point via their position along the curve. If we use the same number of points uniformely along the spinal cord (1000 for example), the correspondance is straight-forward. 
:return: """ # Initialization fname_anat = self.input_filename fname_centerline = self.centerline_filename fname_output = self.output_filename remove_temp_files = self.remove_temp_files verbose = self.verbose interpolation_warp = self.interpolation_warp # TODO: remove this # start timer start_time = time.time() # Extract path/file/extension path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat) path_tmp = sct.tmp_create(basename="straighten_spinalcord", verbose=verbose) # Copying input data to tmp folder sct.printv('\nCopy files to tmp folder...', verbose) Image(fname_anat).save(os.path.join(path_tmp, "data.nii")) Image(fname_centerline).save( os.path.join(path_tmp, "centerline.nii.gz")) if self.use_straight_reference: Image(self.centerline_reference_filename).save( os.path.join(path_tmp, "centerline_ref.nii.gz")) if self.discs_input_filename != '': Image(self.discs_input_filename).save( os.path.join(path_tmp, "labels_input.nii.gz")) if self.discs_ref_filename != '': Image(self.discs_ref_filename).save( os.path.join(path_tmp, "labels_ref.nii.gz")) # go to tmp folder curdir = os.getcwd() os.chdir(path_tmp) # Change orientation of the input centerline into RPI image_centerline = Image("centerline.nii.gz").change_orientation( "RPI").save("centerline_rpi.nii.gz", mutable=True) # Get dimension nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim if self.speed_factor != 1.0: intermediate_resampling = True px_r, py_r, pz_r = px * self.speed_factor, py * self.speed_factor, pz * self.speed_factor else: intermediate_resampling = False if intermediate_resampling: sct.mv('centerline_rpi.nii.gz', 'centerline_rpi_native.nii.gz') pz_native = pz # TODO: remove system call run_proc([ 'sct_resample', '-i', 'centerline_rpi_native.nii.gz', '-mm', str(px_r) + 'x' + str(py_r) + 'x' + str(pz_r), '-o', 'centerline_rpi.nii.gz' ]) image_centerline = Image('centerline_rpi.nii.gz') nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim if np.min(image_centerline.data) < 0 or np.max( 
image_centerline.data) > 1: image_centerline.data[image_centerline.data < 0] = 0 image_centerline.data[image_centerline.data > 1] = 1 image_centerline.save() # 2. extract bspline fitting of the centerline, and its derivatives img_ctl = Image('centerline_rpi.nii.gz') centerline = _get_centerline(img_ctl, self.param_centerline, verbose) number_of_points = centerline.number_of_points # ========================================================================================== logger.info('Create the straight space and the safe zone') # 3. compute length of centerline # compute the length of the spinal cord based on fitted centerline and size of centerline in z direction # Computation of the safe zone. # The safe zone is defined as the length of the spinal cord for which an axial segmentation will be complete # The safe length (to remove) is computed using the safe radius (given as parameter) and the angle of the # last centerline point with the inferior-superior direction. Formula: Ls = Rs * sin(angle) # Calculate Ls for both edges and remove appropriate number of centerline points radius_safe = 0.0 # mm # inferior edge u = centerline.derivatives[0] v = np.array([0, 0, -1]) angle_inferior = np.arctan2(np.linalg.norm(np.cross(u, v)), np.dot(u, v)) length_safe_inferior = radius_safe * np.sin(angle_inferior) # superior edge u = centerline.derivatives[-1] v = np.array([0, 0, 1]) angle_superior = np.arctan2(np.linalg.norm(np.cross(u, v)), np.dot(u, v)) length_safe_superior = radius_safe * np.sin(angle_superior) # remove points inferior_bound = bisect.bisect(centerline.progressive_length, length_safe_inferior) - 1 superior_bound = centerline.number_of_points - bisect.bisect( centerline.progressive_length_inverse, length_safe_superior) z_centerline = centerline.points[:, 2] length_centerline = centerline.length size_z_centerline = z_centerline[-1] - z_centerline[0] # compute the size factor between initial centerline and straight bended centerline factor_curved_straight = 
length_centerline / size_z_centerline middle_slice = (z_centerline[0] + z_centerline[-1]) / 2.0 bound_curved = [ z_centerline[inferior_bound], z_centerline[superior_bound] ] bound_straight = [(z_centerline[inferior_bound] - middle_slice) * factor_curved_straight + middle_slice, (z_centerline[superior_bound] - middle_slice) * factor_curved_straight + middle_slice] logger.info('Length of spinal cord: {}'.format(length_centerline)) logger.info( 'Size of spinal cord in z direction: {}'.format(size_z_centerline)) logger.info('Ratio length/size: {}'.format(factor_curved_straight)) logger.info( 'Safe zone boundaries (curved space): {}'.format(bound_curved)) logger.info( 'Safe zone boundaries (straight space): {}'.format(bound_straight)) # 4. compute and generate straight space # points along curved centerline are already regularly spaced. # calculate position of points along straight centerline # Create straight NIFTI volumes. # ========================================================================================== # TODO: maybe this if case is not needed? 
if self.use_straight_reference: image_centerline_pad = Image('centerline_rpi.nii.gz') nx, ny, nz, nt, px, py, pz, pt = image_centerline_pad.dim fname_ref = 'centerline_ref_rpi.nii.gz' image_centerline_straight = Image('centerline_ref.nii.gz') \ .change_orientation("RPI") \ .save(fname_ref, mutable=True) centerline_straight = _get_centerline(image_centerline_straight, self.param_centerline, verbose) nx_s, ny_s, nz_s, nt_s, px_s, py_s, pz_s, pt_s = image_centerline_straight.dim # Prepare warping fields headers hdr_warp = image_centerline_pad.hdr.copy() hdr_warp.set_data_dtype('float32') hdr_warp_s = image_centerline_straight.hdr.copy() hdr_warp_s.set_data_dtype('float32') if self.discs_input_filename != "" and self.discs_ref_filename != "": discs_input_image = Image('labels_input.nii.gz') coord = discs_input_image.getNonZeroCoordinates( sorting='z', reverse_coord=True) coord_physical = [] for c in coord: c_p = discs_input_image.transfo_pix2phys([[c.x, c.y, c.z] ]).tolist()[0] c_p.append(c.value) coord_physical.append(c_p) centerline.compute_vertebral_distribution(coord_physical) centerline.save_centerline( image=discs_input_image, fname_output='discs_input_image.nii.gz') discs_ref_image = Image('labels_ref.nii.gz') coord = discs_ref_image.getNonZeroCoordinates( sorting='z', reverse_coord=True) coord_physical = [] for c in coord: c_p = discs_ref_image.transfo_pix2phys([[c.x, c.y, c.z]]).tolist()[0] c_p.append(c.value) coord_physical.append(c_p) centerline_straight.compute_vertebral_distribution( coord_physical) centerline_straight.save_centerline( image=discs_ref_image, fname_output='discs_ref_image.nii.gz') else: logger.info( 'Pad input volume to account for spinal cord length...') start_point, end_point = bound_straight[0], bound_straight[1] offset_z = 0 # if the destination image is resampled, we still create the straight reference space with the native # resolution. # TODO: Maybe this if case is not needed? 
if intermediate_resampling: padding_z = int( np.ceil(1.5 * ((length_centerline - size_z_centerline) / 2.0) / pz_native)) run_proc([ 'sct_image', '-i', 'centerline_rpi_native.nii.gz', '-o', 'tmp.centerline_pad_native.nii.gz', '-pad', '0,0,' + str(padding_z) ]) image_centerline_pad = Image('centerline_rpi_native.nii.gz') nx, ny, nz, nt, px, py, pz, pt = image_centerline_pad.dim start_point_coord_native = image_centerline_pad.transfo_phys2pix( [[0, 0, start_point]])[0] end_point_coord_native = image_centerline_pad.transfo_phys2pix( [[0, 0, end_point]])[0] straight_size_x = int(self.xy_size / px) straight_size_y = int(self.xy_size / py) warp_space_x = [ int(np.round(nx / 2)) - straight_size_x, int(np.round(nx / 2)) + straight_size_x ] warp_space_y = [ int(np.round(ny / 2)) - straight_size_y, int(np.round(ny / 2)) + straight_size_y ] if warp_space_x[0] < 0: warp_space_x[1] += warp_space_x[0] - 2 warp_space_x[0] = 0 if warp_space_y[0] < 0: warp_space_y[1] += warp_space_y[0] - 2 warp_space_y[0] = 0 spec = dict(( (0, warp_space_x), (1, warp_space_y), (2, (0, end_point_coord_native[2] - start_point_coord_native[2])), )) msct_image.spatial_crop( Image("tmp.centerline_pad_native.nii.gz"), spec).save("tmp.centerline_pad_crop_native.nii.gz") fname_ref = 'tmp.centerline_pad_crop_native.nii.gz' offset_z = 4 else: fname_ref = 'tmp.centerline_pad_crop.nii.gz' nx, ny, nz, nt, px, py, pz, pt = image_centerline.dim padding_z = int( np.ceil(1.5 * ((length_centerline - size_z_centerline) / 2.0) / pz)) + offset_z image_centerline_pad = pad_image(image_centerline, pad_z_i=padding_z, pad_z_f=padding_z) nx, ny, nz = image_centerline_pad.data.shape hdr_warp = image_centerline_pad.hdr.copy() hdr_warp.set_data_dtype('float32') start_point_coord = image_centerline_pad.transfo_phys2pix( [[0, 0, start_point]])[0] end_point_coord = image_centerline_pad.transfo_phys2pix( [[0, 0, end_point]])[0] straight_size_x = int(self.xy_size / px) straight_size_y = int(self.xy_size / py) warp_space_x = [ 
int(np.round(nx / 2)) - straight_size_x, int(np.round(nx / 2)) + straight_size_x ] warp_space_y = [ int(np.round(ny / 2)) - straight_size_y, int(np.round(ny / 2)) + straight_size_y ] if warp_space_x[0] < 0: warp_space_x[1] += warp_space_x[0] - 2 warp_space_x[0] = 0 if warp_space_x[1] >= nx: warp_space_x[1] = nx - 1 if warp_space_y[0] < 0: warp_space_y[1] += warp_space_y[0] - 2 warp_space_y[0] = 0 if warp_space_y[1] >= ny: warp_space_y[1] = ny - 1 spec = dict(( (0, warp_space_x), (1, warp_space_y), (2, (0, end_point_coord[2] - start_point_coord[2] + offset_z)), )) image_centerline_straight = msct_image.spatial_crop( image_centerline_pad, spec) nx_s, ny_s, nz_s, nt_s, px_s, py_s, pz_s, pt_s = image_centerline_straight.dim hdr_warp_s = image_centerline_straight.hdr.copy() hdr_warp_s.set_data_dtype('float32') if self.template_orientation == 1: raise NotImplementedError() start_point_coord = image_centerline_pad.transfo_phys2pix( [[0, 0, start_point]])[0] end_point_coord = image_centerline_pad.transfo_phys2pix( [[0, 0, end_point]])[0] number_of_voxel = nx * ny * nz logger.debug('Number of voxels: {}'.format(number_of_voxel)) time_centerlines = time.time() coord_straight = np.empty((number_of_points, 3)) coord_straight[..., 0] = int(np.round(nx_s / 2)) coord_straight[..., 1] = int(np.round(ny_s / 2)) coord_straight[..., 2] = np.linspace( 0, end_point_coord[2] - start_point_coord[2], number_of_points) coord_phys_straight = image_centerline_straight.transfo_pix2phys( coord_straight) derivs_straight = np.empty((number_of_points, 3)) derivs_straight[..., 0] = derivs_straight[..., 1] = 0 derivs_straight[..., 2] = 1 dx_straight, dy_straight, dz_straight = derivs_straight.T centerline_straight = Centerline(coord_phys_straight[:, 0], coord_phys_straight[:, 1], coord_phys_straight[:, 2], dx_straight, dy_straight, dz_straight) time_centerlines = time.time() - time_centerlines logger.info('Time to generate centerline: {} ms'.format( np.round(time_centerlines * 1000.0))) if verbose 
== 2: # TODO: use OO import matplotlib.pyplot as plt from datetime import datetime curved_points = centerline.progressive_length straight_points = centerline_straight.progressive_length range_points = np.linspace(0, 1, number_of_points) dist_curved = np.zeros(number_of_points) dist_straight = np.zeros(number_of_points) for i in range(1, number_of_points): dist_curved[i] = dist_curved[ i - 1] + curved_points[i - 1] / centerline.length dist_straight[i] = dist_straight[i - 1] + straight_points[ i - 1] / centerline_straight.length plt.plot(range_points, dist_curved) plt.plot(range_points, dist_straight) plt.grid(True) plt.savefig('fig_straighten_' + datetime.now().strftime("%y%m%d%H%M%S%f") + '.png') plt.close() # alignment_mode = 'length' alignment_mode = 'levels' lookup_curved2straight = list(range(centerline.number_of_points)) if self.discs_input_filename != "": # create look-up table curved to straight for index in range(centerline.number_of_points): disc_label = centerline.l_points[index] if alignment_mode == 'length': relative_position = centerline.dist_points[index] else: relative_position = centerline.dist_points_rel[index] idx_closest = centerline_straight.get_closest_to_absolute_position( disc_label, relative_position, backup_index=index, backup_centerline=centerline_straight, mode=alignment_mode) if idx_closest is not None: lookup_curved2straight[index] = idx_closest else: lookup_curved2straight[index] = 0 for p in range(0, len(lookup_curved2straight) // 2): if lookup_curved2straight[p] == lookup_curved2straight[p + 1]: lookup_curved2straight[p] = 0 else: break for p in range( len(lookup_curved2straight) - 1, len(lookup_curved2straight) // 2, -1): if lookup_curved2straight[p] == lookup_curved2straight[p - 1]: lookup_curved2straight[p] = 0 else: break lookup_curved2straight = np.array(lookup_curved2straight) lookup_straight2curved = list( range(centerline_straight.number_of_points)) if self.discs_input_filename != "": for index in 
range(centerline_straight.number_of_points): disc_label = centerline_straight.l_points[index] if alignment_mode == 'length': relative_position = centerline_straight.dist_points[index] else: relative_position = centerline_straight.dist_points_rel[ index] idx_closest = centerline.get_closest_to_absolute_position( disc_label, relative_position, backup_index=index, backup_centerline=centerline_straight, mode=alignment_mode) if idx_closest is not None: lookup_straight2curved[index] = idx_closest for p in range(0, len(lookup_straight2curved) // 2): if lookup_straight2curved[p] == lookup_straight2curved[p + 1]: lookup_straight2curved[p] = 0 else: break for p in range( len(lookup_straight2curved) - 1, len(lookup_straight2curved) // 2, -1): if lookup_straight2curved[p] == lookup_straight2curved[p - 1]: lookup_straight2curved[p] = 0 else: break lookup_straight2curved = np.array(lookup_straight2curved) # Create volumes containing curved and straight warping fields data_warp_curved2straight = np.zeros((nx_s, ny_s, nz_s, 1, 3)) data_warp_straight2curved = np.zeros((nx, ny, nz, 1, 3)) # 5. compute transformations # Curved and straight images and the same dimensions, so we compute both warping fields at the same time. # b. 
determine which plane of spinal cord centreline it is included # sct.printv(nx * ny * nz, nx_s * ny_s * nz_s) if self.curved2straight: for u in sct_progress_bar(range(nz_s)): x_s, y_s, z_s = np.mgrid[0:nx_s, 0:ny_s, u:u + 1] indexes_straight = np.array( list(zip(x_s.ravel(), y_s.ravel(), z_s.ravel()))) physical_coordinates_straight = image_centerline_straight.transfo_pix2phys( indexes_straight) nearest_indexes_straight = centerline_straight.find_nearest_indexes( physical_coordinates_straight) distances_straight = centerline_straight.get_distances_from_planes( physical_coordinates_straight, nearest_indexes_straight) lookup = lookup_straight2curved[nearest_indexes_straight] indexes_out_distance_straight = np.logical_or( np.logical_or( distances_straight > self.threshold_distance, distances_straight < -self.threshold_distance), lookup == 0) projected_points_straight = centerline_straight.get_projected_coordinates_on_planes( physical_coordinates_straight, nearest_indexes_straight) coord_in_planes_straight = centerline_straight.get_in_plans_coordinates( projected_points_straight, nearest_indexes_straight) coord_straight2curved = centerline.get_inverse_plans_coordinates( coord_in_planes_straight, lookup) displacements_straight = coord_straight2curved - physical_coordinates_straight # Invert Z coordinate as ITK & ANTs physical coordinate system is LPS- (RAI+) # while ours is LPI- # Refs: https://sourceforge.net/p/advants/discussion/840261/thread/2a1e9307/#fb5a # https://www.slicer.org/wiki/Coordinate_systems displacements_straight[:, 2] = -displacements_straight[:, 2] displacements_straight[indexes_out_distance_straight] = [ 100000.0, 100000.0, 100000.0 ] data_warp_curved2straight[indexes_straight[:, 0], indexes_straight[:, 1], indexes_straight[:, 2], 0, :]\ = -displacements_straight if self.straight2curved: for u in sct_progress_bar(range(nz)): x, y, z = np.mgrid[0:nx, 0:ny, u:u + 1] indexes = np.array(list(zip(x.ravel(), y.ravel(), z.ravel()))) physical_coordinates = 
image_centerline_pad.transfo_pix2phys( indexes) nearest_indexes_curved = centerline.find_nearest_indexes( physical_coordinates) distances_curved = centerline.get_distances_from_planes( physical_coordinates, nearest_indexes_curved) lookup = lookup_curved2straight[nearest_indexes_curved] indexes_out_distance_curved = np.logical_or( np.logical_or(distances_curved > self.threshold_distance, distances_curved < -self.threshold_distance), lookup == 0) projected_points_curved = centerline.get_projected_coordinates_on_planes( physical_coordinates, nearest_indexes_curved) coord_in_planes_curved = centerline.get_in_plans_coordinates( projected_points_curved, nearest_indexes_curved) coord_curved2straight = centerline_straight.points[lookup] coord_curved2straight[:, 0:2] += coord_in_planes_curved[:, 0:2] coord_curved2straight[:, 2] += distances_curved displacements_curved = coord_curved2straight - physical_coordinates displacements_curved[:, 2] = -displacements_curved[:, 2] displacements_curved[indexes_out_distance_curved] = [ 100000.0, 100000.0, 100000.0 ] data_warp_straight2curved[indexes[:, 0], indexes[:, 1], indexes[:, 2], 0, :] = -displacements_curved # Creation of the safe zone based on pre-calculated safe boundaries coord_bound_curved_inf, coord_bound_curved_sup = image_centerline_pad.transfo_phys2pix( [[0, 0, bound_curved[0]]]), image_centerline_pad.transfo_phys2pix( [[0, 0, bound_curved[1]]]) coord_bound_straight_inf, coord_bound_straight_sup = image_centerline_straight.transfo_phys2pix( [[0, 0, bound_straight[0]]]), image_centerline_straight.transfo_phys2pix( [[0, 0, bound_straight[1]]]) if radius_safe > 0: data_warp_curved2straight[:, :, 0:coord_bound_straight_inf[0][2], 0, :] = 100000.0 data_warp_curved2straight[:, :, coord_bound_straight_sup[0][2]:, 0, :] = 100000.0 data_warp_straight2curved[:, :, 0:coord_bound_curved_inf[0][2], 0, :] = 100000.0 data_warp_straight2curved[:, :, coord_bound_curved_sup[0][2]:, 0, :] = 100000.0 # Generate warp files as a warping fields 
hdr_warp_s.set_intent('vector', (), '') hdr_warp_s.set_data_dtype('float32') hdr_warp.set_intent('vector', (), '') hdr_warp.set_data_dtype('float32') if self.curved2straight: img = Nifti1Image(data_warp_curved2straight, None, hdr_warp_s) save(img, 'tmp.curve2straight.nii.gz') logger.info('Warping field generated: tmp.curve2straight.nii.gz') if self.straight2curved: img = Nifti1Image(data_warp_straight2curved, None, hdr_warp) save(img, 'tmp.straight2curve.nii.gz') logger.info('Warping field generated: tmp.straight2curve.nii.gz') image_centerline_straight.save(fname_ref) if self.curved2straight: logger.info('Apply transformation to input image...') run_proc([ 'isct_antsApplyTransforms', '-d', '3', '-r', fname_ref, '-i', 'data.nii', '-o', 'tmp.anat_rigid_warp.nii.gz', '-t', 'tmp.curve2straight.nii.gz', '-n', 'BSpline[3]' ], is_sct_binary=True, verbose=verbose) if self.accuracy_results: time_accuracy_results = time.time() # compute the error between the straightened centerline/segmentation and the central vertical line. # Ideally, the error should be zero. 
# Apply deformation to input image logger.info('Apply transformation to centerline image...') run_proc([ 'isct_antsApplyTransforms', '-d', '3', '-r', fname_ref, '-i', 'centerline.nii.gz', '-o', 'tmp.centerline_straight.nii.gz', '-t', 'tmp.curve2straight.nii.gz', '-n', 'NearestNeighbor' ], is_sct_binary=True, verbose=verbose) file_centerline_straight = Image('tmp.centerline_straight.nii.gz', verbose=verbose) nx, ny, nz, nt, px, py, pz, pt = file_centerline_straight.dim coordinates_centerline = file_centerline_straight.getNonZeroCoordinates( sorting='z') mean_coord = [] for z in range(coordinates_centerline[0].z, coordinates_centerline[-1].z): temp_mean = [ coord.value for coord in coordinates_centerline if coord.z == z ] if temp_mean: mean_value = np.mean(temp_mean) mean_coord.append( np.mean([[ coord.x * coord.value / mean_value, coord.y * coord.value / mean_value ] for coord in coordinates_centerline if coord.z == z], axis=0)) # compute error between the straightened centerline and the straight line. x0 = file_centerline_straight.data.shape[0] / 2.0 y0 = file_centerline_straight.data.shape[1] / 2.0 count_mean = 0 if number_of_points >= 10: mean_c = mean_coord[ 2: -2] # we don't include the four extrema because there are usually messy. else: mean_c = mean_coord for coord_z in mean_c: if not np.isnan(np.sum(coord_z)): dist = ((x0 - coord_z[0]) * px)**2 + ( (y0 - coord_z[1]) * py)**2 self.mse_straightening += dist dist = np.sqrt(dist) if dist > self.max_distance_straightening: self.max_distance_straightening = dist count_mean += 1 self.mse_straightening = np.sqrt(self.mse_straightening / float(count_mean)) self.elapsed_time_accuracy = time.time() - time_accuracy_results os.chdir(curdir) # Generate output file (in current folder) # TODO: do not uncompress the warping field, it is too time consuming! 
logger.info('Generate output files...') if self.curved2straight: sct.generate_output_file( os.path.join(path_tmp, "tmp.curve2straight.nii.gz"), os.path.join(self.path_output, "warp_curve2straight.nii.gz"), verbose) if self.straight2curved: sct.generate_output_file( os.path.join(path_tmp, "tmp.straight2curve.nii.gz"), os.path.join(self.path_output, "warp_straight2curve.nii.gz"), verbose) # create ref_straight.nii.gz file that can be used by other SCT functions that need a straight reference space if self.curved2straight: sct.copy(os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"), os.path.join(self.path_output, "straight_ref.nii.gz")) # move straightened input file if fname_output == '': fname_straight = sct.generate_output_file( os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"), os.path.join(self.path_output, file_anat + "_straight" + ext_anat), verbose) else: fname_straight = sct.generate_output_file( os.path.join(path_tmp, "tmp.anat_rigid_warp.nii.gz"), os.path.join(self.path_output, fname_output), verbose) # straightened anatomic # Remove temporary files if remove_temp_files: logger.info('Remove temporary files...') sct.rmtree(path_tmp) if self.accuracy_results: logger.info('Maximum x-y error: {} mm'.format( self.max_distance_straightening)) logger.info('Accuracy of straightening (MSE): {} mm'.format( self.mse_straightening)) # display elapsed time self.elapsed_time = int(np.round(time.time() - start_time)) return fname_straight
def main(argv=None):
    """
    Command-line entry point computing the signal-to-noise ratio (SNR) of an image.

    Two methods are handled (``-method``):
      - 'mult': voxel-wise SNR computed as mean / STD across the selected volumes.
        The SNR map (suffix '_SNR-<method>') and the mask of voxels whose STD is
        above ``param.almost_zero`` are written next to the input. When a mask is
        provided, the SNR inside that mask is also computed.
      - 'diff': SNR inside the mask, computed from the mean of two volumes and the
        STD of their difference (requires ``-m`` and exactly two volumes).

    When a mask is provided, the ROI SNR is displayed at the end.

    :param argv: list of command-line arguments; None lets the parser read sys.argv.
    """
    args = get_parser().parse_args(argv)
    verbose = args.v
    set_global_loglevel(verbose=verbose)

    # Default params
    param = Param()

    # Parser info; '' is used as the "not provided" marker for optional arguments
    fname_data = args.i
    fname_mask = '' if args.m is None else args.m
    method = args.method
    index_vol_user = '' if args.vol is None else args.vol

    # The 'diff' method needs a mask
    if method == 'diff' and not fname_mask:
        printv('You need to provide a mask with -method diff. Exit.', 1, type='error')

    # Load data (and mask, if any), both reoriented to RPI
    im_data = Image(fname_data).change_orientation('RPI')
    data = im_data.data
    if fname_mask:
        mask = Image(fname_mask).change_orientation('RPI').data

    # Volumes to use: user selection, otherwise every index along the 4th axis
    index_vol = parse_num_list(index_vol_user) if index_vol_user else range(data.shape[3])

    if method == 'diff' and len(index_vol) != 2:
        printv(
            'Method "diff" should be used with exactly two volumes (specify with flag "-vol").', 1, 'error')

    # NB: "time" is assumed to be the 4th dimension of "data"
    if method == 'mult':
        selected = data[:, :, :, index_vol]
        mean_map = np.mean(selected, axis=3)
        std_map = np.std(selected, axis=3, ddof=1)

        # Only voxels whose STD exceeds ~0 get an SNR value (avoids dividing by zero)
        nonzero = np.where(std_map > param.almost_zero)
        snr_map = np.zeros_like(mean_map)
        snr_map[nonzero] = mean_map[nonzero] / std_map[nonzero]

        # Save the SNR map
        im_snr = empty_like(im_data)
        im_snr.data = snr_map
        im_snr.save(add_suffix(fname_data, '_SNR-' + method), dtype=np.float32)

        # Save the binary mask of voxels with non-zero STD
        stdnonzero_map = np.zeros_like(mean_map)
        stdnonzero_map[nonzero] = 1
        im_stdnonzero = empty_like(im_data)
        im_stdnonzero.data = stdnonzero_map
        im_stdnonzero.save(add_suffix(fname_data, '_mask-STD-nonzero' + method), dtype=np.float32)

        # SNR in the ROI: mask-weighted mean divided by mask-weighted STD
        if fname_mask:
            roi_weights = mask[nonzero]
            mean_in_roi = np.average(mean_map[nonzero], weights=roi_weights)
            std_in_roi = np.average(std_map[nonzero], weights=roi_weights)
            snr_roi = mean_in_roi / std_in_roi

    elif method == 'diff':
        data_2vol = np.take(data, index_vol, axis=3)
        # Signal: mask-weighted mean of the average of both volumes
        mean_in_roi = np.average(np.mean(data_2vol, axis=3), weights=mask)
        # Noise: mask-weighted STD of the difference between both volumes
        _, std_in_roi = weighted_avg_and_std(data_2vol[:, :, :, 1] - data_2vol[:, :, :, 0], mask)
        # Factor 2 / sqrt(2) (i.e. sqrt(2)), see eq. 7 in Dietrich et al.
        snr_roi = (2 / np.sqrt(2)) * mean_in_roi / std_in_roi

    # Display result
    if fname_mask:
        printv('\nSNR_' + method + ' = ' + str(snr_roi) + '\n', type='info')
def concat_data(fname_in_list, dim, pixdim=None, squeeze_data=False):
    """
    Concatenate images along a given dimension.

    :param fname_in_list: list of Images or image filenames. The first element provides the output
        header (via ``empty_like``) and the base of the output path (suffix '_concat').
    :param dim: dimension along which to concatenate: 0, 1, 2, 3. Inputs with fewer than ``dim + 1``
        dimensions first get a singleton axis inserted at position ``dim``.
    :param pixdim: pixel resolution to write into the output header (``hdr['pixdim']``).
    :param squeeze_data: bool: if True and the concatenated axis ``dim`` has length 1, remove that axis.
    :return im_out: concatenated image (not saved to disk).
    :raises ValueError: if ``fname_in_list`` is empty.
    """
    # WARNING: calling concat_data in python instead of in command line causes a non understood issue
    # (results are different with both options)
    if len(fname_in_list) == 0:
        raise ValueError("concat_data: 'fname_in_list' is empty, nothing to concatenate.")

    # Arrays are concatenated in chunks of `chunk_size` images to avoid memory issues with many inputs.
    chunk_size = 100
    chunks = []   # already-concatenated chunks
    pending = []  # arrays of the chunk currently being filled
    for i, fname in enumerate(fname_in_list):
        if i != 0 and i % chunk_size == 0:
            chunks.append(np.concatenate(pending, axis=dim))
            pending = []
        im = Image(fname)
        dat = im.data
        # if image shape is smaller than asked dim, then expand dim
        if len(dat.shape) <= dim:
            dat = np.expand_dims(dat, dim)
        pending.append(dat)
        del im, dat

    if chunks:
        chunks.append(np.concatenate(pending, axis=dim))
        data_concat = np.concatenate(chunks, axis=dim)
    else:
        data_concat = np.concatenate(pending, axis=dim)

    # Build the output image from the first input
    im_out = msct_image.empty_like(Image(fname_in_list[0]))
    if isinstance(fname_in_list[0], str):
        im_out.absolutepath = sct.add_suffix(fname_in_list[0], '_concat')
    elif fname_in_list[0].absolutepath:
        im_out.absolutepath = sct.add_suffix(fname_in_list[0].absolutepath, '_concat')
    if pixdim is not None:
        im_out.hdr['pixdim'] = pixdim

    if squeeze_data and data_concat.shape[dim] == 1:
        # remove the concatenation axis since it is a singleton
        data_concat = np.squeeze(data_concat, axis=dim)
    im_out.data = data_concat
    return im_out
def _preprocess_segment(fname_t2, fname_t2_seg, contrast_test, dim_3=False):
    """
    Prepare an image / segmentation pair for deepseg_sc.

    Steps: reorient both to RPI, detect the centerline (SVM), resample the segmentation to
    0.5 x 0.5 mm in-plane (native z resolution), crop image and segmentation around the
    centerline, and normalize the image intensity. All intermediate files live in a temporary
    folder, which is always removed (and the working directory restored) before returning,
    including when an error occurs.

    :param fname_t2: path to the input image (relative paths are resolved against the current
        working directory at call time).
    :param fname_t2_seg: path to the corresponding segmentation.
    :param contrast_test: contrast type passed to ``deepseg_sc.find_centerline``.
    :param dim_3: if True (3D kernels), the cropped image and segmentation are saved and
        resampled again to ``new_resolution``.
    :return: tuple (img, gt) of Images.
    """
    # Resolve inputs before changing directory; relative paths would otherwise be looked up
    # inside the temporary folder.
    fname_t2 = os.path.abspath(fname_t2)
    fname_t2_seg = os.path.abspath(fname_t2_seg)

    tmp_folder = sct.TempFolder()
    tmp_folder_path = tmp_folder.get_path()
    tmp_folder.chdir()
    try:
        img = Image(fname_t2)
        gt = Image(fname_t2_seg)

        # Reorient image and segmentation to RPI and save them in the temporary folder
        fname_t2_RPI, fname_t2_seg_RPI = 'img_RPI.nii.gz', 'seg_RPI.nii.gz'
        img.change_orientation('RPI').save(fname_t2_RPI)
        gt.change_orientation('RPI').save(fname_t2_seg_RPI)
        input_resolution = gt.dim[4:7]  # (px, py, pz)
        del img, gt

        fname_res, fname_ctr = deepseg_sc.find_centerline(algo='svm',
                                                          image_fname=fname_t2_RPI,
                                                          contrast_type=contrast_test,
                                                          brain_bool=False,
                                                          folder_output=tmp_folder_path,
                                                          remove_temp_files=1,
                                                          centerline_fname=None)

        # Resample the segmentation to 0.5 x 0.5 mm in-plane, keeping the native z resolution
        fname_t2_seg_RPI_res = 'seg_RPI_res.nii.gz'
        new_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
        resample_file(fname_t2_seg_RPI, fname_t2_seg_RPI_res, new_resolution, 'mm', 'linear', verbose=0)

        # Crop image and segmentation around the detected centerline
        img, ctr, gt = Image(fname_res), Image(fname_ctr), Image(fname_t2_seg_RPI_res)
        _, _, _, img = deepseg_sc.crop_image_around_centerline(im_in=img, ctr_in=ctr, crop_size=64)
        _, _, _, gt = deepseg_sc.crop_image_around_centerline(im_in=gt, ctr_in=ctr, crop_size=64)
        del ctr

        img = deepseg_sc.apply_intensity_normalization(im_in=img)

        if dim_3:  # If 3D kernels
            fname_t2_RPI_res_crop, fname_t2_seg_RPI_res_crop = 'img_RPI_res_crop.nii.gz', 'seg_RPI_res_crop.nii.gz'
            img.save(fname_t2_RPI_res_crop)
            gt.save(fname_t2_seg_RPI_res_crop)
            del img, gt

            fname_t2_RPI_res_crop_res = 'img_RPI_res_crop_res.nii.gz'
            fname_t2_seg_RPI_res_crop_res = 'seg_RPI_res_crop_res.nii.gz'
            resample_file(fname_t2_RPI_res_crop, fname_t2_RPI_res_crop_res, new_resolution, 'mm', 'linear', verbose=0)
            resample_file(fname_t2_seg_RPI_res_crop, fname_t2_seg_RPI_res_crop_res, new_resolution, 'mm', 'linear', verbose=0)
            img, gt = Image(fname_t2_RPI_res_crop_res), Image(fname_t2_seg_RPI_res_crop_res)
    finally:
        # Always restore the working directory and remove the temporary folder, even on failure
        tmp_folder.chdir_undo()
        tmp_folder.cleanup()

    return img, gt
sct.init_sct(log_level=param.verbose, update=True) # Update log level tmp_dir = sct.tmp_create() im1_name = "im1.nii.gz" sct.copy(input_fname, os.path.join(tmp_dir, im1_name)) if input_second_fname != '': im2_name = 'im2.nii.gz' sct.copy(input_second_fname, os.path.join(tmp_dir, im2_name)) else: im2_name = None curdir = os.getcwd() os.chdir(tmp_dir) # now = time.time() input_im1 = Image(resample_image(im1_name, binary=True, thr=0.5, npx=resample_to, npy=resample_to)) input_im1.absolutepath = os.path.basename(input_fname) if im2_name is not None: input_im2 = Image(resample_image(im2_name, binary=True, thr=0.5, npx=resample_to, npy=resample_to)) input_im2.absolutepath = os.path.basename(input_second_fname) else: input_im2 = None computation = ComputeDistances(input_im1, im2=input_im2, param=param) # TODO change back the orientatin of the thinned image if param.thinning: computation.thinning1.thinned_image.save(os.path.join(curdir, sct.add_suffix(os.path.basename(input_fname), '_thinned'))) if im2_name is not None: computation.thinning2.thinned_image.save(os.path.join(curdir, sct.add_suffix(os.path.basename(input_second_fname), '_thinned')))
class ProcessLabels(object): def __init__(self, fname_label, fname_output=None, fname_ref=None, cross_radius=5, dilate=False, coordinates=None, verbose=1, vertebral_levels=None, value=None, msg=""): """ Collection of processes that deal with label creation/modification. :param fname_label: :param fname_output: :param fname_ref: :param cross_radius: :param dilate: # TODO: remove dilate (does not seem to be used) :param coordinates: :param verbose: :param vertebral_levels: :param value: :param msg: string. message to display to the user. """ self.image_input = Image(fname_label, verbose=verbose) self.image_ref = None if fname_ref is not None: self.image_ref = Image(fname_ref, verbose=verbose) if isinstance(fname_output, list): if len(fname_output) == 1: self.fname_output = fname_output[0] else: self.fname_output = fname_output else: self.fname_output = fname_output self.cross_radius = cross_radius self.vertebral_levels = vertebral_levels self.dilate = dilate self.coordinates = coordinates self.verbose = verbose self.value = value self.msg = msg self.output_image = None def process(self, type_process): # for some processes, change orientation of input image to RPI change_orientation = False if type_process in ['vert-body', 'vert-disc', 'vert-continuous']: # get orientation of input image input_orientation = self.image_input.orientation # change orientation self.image_input.change_orientation('RPI') change_orientation = True if type_process == 'add': self.output_image = self.add(self.value) if type_process == 'plan': self.output_image = self.plan(self.cross_radius, 100, 5) if type_process == 'plan_ref': self.output_image = self.plan_ref() if type_process == 'increment': self.output_image = self.increment_z_inverse() if type_process == 'disks': self.output_image = self.labelize_from_disks() if type_process == 'MSE': self.MSE() self.fname_output = None if type_process == 'remove-reference': self.output_image = self.remove_label() if type_process == 'remove-symm': 
self.output_image = self.remove_label(symmetry=True) if type_process == 'create': self.output_image = self.create_label() if type_process == 'create-add': self.output_image = self.create_label(add=True) if type_process == 'create-seg': self.output_image = self.create_label_along_segmentation() if type_process == 'display-voxel': self.display_voxel() self.fname_output = None if type_process == 'diff': self.diff() self.fname_output = None if type_process == 'dist-inter': # second argument is in pixel distance self.distance_interlabels(5) self.fname_output = None if type_process == 'cubic-to-point': self.output_image = self.cubic_to_point() if type_process == 'vert-body': self.output_image = self.label_vertebrae(self.vertebral_levels) if type_process == 'vert-continuous': self.output_image = self.continuous_vertebral_levels() if type_process == 'create-viewer': self.output_image = self.launch_sagittal_viewer(self.value) if type_process in ['remove', 'keep']: self.output_image = self.remove_or_keep_labels(self.value, action=type_process) # TODO: do not save here. Create another function save() for that if self.fname_output is not None: if change_orientation: self.output_image.change_orientation(input_orientation) self.output_image.absolutepath = self.fname_output if type_process == 'vert-continuous': self.output_image.save(dtype='float32') elif type_process != 'plan_ref': self.output_image.save(dtype='minimize_int') else: self.output_image.save() return self.output_image def add(self, value): """ This function add a specified value to all non-zero voxels. 
""" image_output = self.image_input.copy() # image_output.data *= 0 coordinates_input = self.image_input.getNonZeroCoordinates() # for all points with non-zeros neighbors, force the neighbors to 0 for i, coord in enumerate(coordinates_input): image_output.data[int(coord.x), int(coord.y), int(coord.z)] = image_output.data[int(coord.x), int(coord.y), int(coord.z)] + float(value) return image_output def create_label(self, add=False): """ Create an image with labels listed by the user. This method works only if the user inserted correct coordinates. self.coordinates is a list of coordinates (class in msct_types). a Coordinate contains x, y, z and value. If only one label is to be added, coordinates must be completed with '[]' examples: For one label: object_define=ProcessLabels( fname_label, coordinates=[coordi]) where coordi is a 'Coordinate' object from msct_types For two labels: object_define=ProcessLabels( fname_label, coordinates=[coordi1, coordi2]) where coordi1 and coordi2 are 'Coordinate' objects from msct_types """ image_output = self.image_input.copy() if add else msct_image.zeros_like(self.image_input) # loop across labels for i, coord in enumerate(self.coordinates): if len(image_output.data.shape) == 3: image_output.data[int(coord.x), int(coord.y), int(coord.z)] = coord.value elif len(image_output.data.shape) == 2: assert str(coord.z) == '0', "ERROR: 2D coordinates should have a Z value of 0. Z coordinate is :" + str(coord.z) image_output.data[int(coord.x), int(coord.y)] = coord.value else: sct.printv('ERROR: Data should be 2D or 3D. Current shape is: ' + str(image_output.data.shape), 1, 'error') # display info sct.printv('Label #' + str(i) + ': ' + str(coord.x) + ',' + str(coord.y) + ',' + str(coord.z) + ' --> ' + str(coord.value), 1) return image_output def create_label_along_segmentation(self): """ Create an image with labels defined along the spinal cord segmentation (or centerline). 
Input image does **not** need to be RPI (re-orientation is done within this function). Example: object_define=ProcessLabels(fname_segmentation, coordinates=[coord_1, coord_2, coord_i]), where coord_i='z,value'. If z=-1, then use z=nz/2 (i.e. center of FOV in superior-inferior direction) Returns """ # reorient input image to RPI im_rpi = self.image_input.copy().change_orientation('RPI') im_output_rpi = zeros_like(im_rpi) # loop across labels for ilabel, coord in enumerate(self.coordinates): # split coord string list_coord = coord.split(',') # convert to int() and assign to variable z, value = [int(i) for i in list_coord] # update z based on native image orientation (z should represent superior-inferior axis) coord = Coordinate([z, z, z]) # since we don't know which dimension corresponds to the superior-inferior # axis, we put z in all dimensions (we don't care about x and y here) _, _, z_rpi = coord.permute(self.image_input, 'RPI') # if z=-1, replace with nz/2 if z == -1: z_rpi = int(np.round(im_output_rpi.dim[2] / 2.0)) # get center of mass of segmentation at given z x, y = ndimage.measurements.center_of_mass(np.array(im_rpi.data[:, :, z_rpi])) # round values to make indices x, y = int(np.round(x)), int(np.round(y)) # display info sct.printv('Label #' + str(ilabel) + ': ' + str(x) + ',' + str(y) + ',' + str(z_rpi) + ' --> ' + str(value), 1) if len(im_output_rpi.data.shape) == 3: im_output_rpi.data[x, y, z_rpi] = value elif len(im_output_rpi.data.shape) == 2: assert str(z) == '0', "ERROR: 2D coordinates should have a Z value of 0. Z coordinate is :" + str(z) im_output_rpi.data[x, y] = value # change orientation back to native return im_output_rpi.change_orientation(self.image_input.orientation) def plan(self, width, offset=0, gap=1): """ Create a plane of thickness="width" and changes its value with an offset and a gap between labels. 
""" image_output = msct_image.zeros_like(self.image_input) coordinates_input = self.image_input.getNonZeroCoordinates() # for all points with non-zeros neighbors, force the neighbors to 0 for coord in coordinates_input: image_output.data[:, :, int(coord.z) - width:int(coord.z) + width] = offset + gap * coord.value return image_output def plan_ref(self): """ Generate a plane in the reference space for each label present in the input image """ image_output = msct_image.zeros_like(Image(self.image_ref)) image_input_neg = msct_image.zeros_like(Image(self.image_input)) image_input_pos = msct_image.zeros_like(Image(self.image_input)) X, Y, Z = (self.image_input.data < 0).nonzero() for i in range(len(X)): image_input_neg.data[X[i], Y[i], Z[i]] = -self.image_input.data[X[i], Y[i], Z[i]] # in order to apply getNonZeroCoordinates X_pos, Y_pos, Z_pos = (self.image_input.data > 0).nonzero() for i in range(len(X_pos)): image_input_pos.data[X_pos[i], Y_pos[i], Z_pos[i]] = self.image_input.data[X_pos[i], Y_pos[i], Z_pos[i]] coordinates_input_neg = image_input_neg.getNonZeroCoordinates() coordinates_input_pos = image_input_pos.getNonZeroCoordinates() image_output.change_type('float32') for coord in coordinates_input_neg: image_output.data[:, :, int(coord.z)] = -coord.value # PB: takes the int value of coord.value for coord in coordinates_input_pos: image_output.data[:, :, int(coord.z)] = coord.value return image_output def cubic_to_point(self): """ Calculate the center of mass of each group of labels and returns a file of same size with only a label by group at the center of mass of this group. It is to be used after applying homothetic warping field to a label file as the labels will be dilated. Be careful: this algorithm computes the center of mass of voxels with same value, if two groups of voxels with the same value are present but separated in space, this algorithm will compute the center of mass of the two groups together. :return: image_output """ # 0. 
Initialization of output image output_image = msct_image.zeros_like(self.image_input) # 1. Extraction of coordinates from all non-null voxels in the image. Coordinates are sorted by value. coordinates = self.image_input.getNonZeroCoordinates(sorting='value') # 2. Separate all coordinates into groups by value groups = dict() for coord in coordinates: if coord.value in groups: groups[coord.value].append(coord) else: groups[coord.value] = [coord] # 3. Compute the center of mass of each group of voxels and write them into the output image for value, list_coord in groups.items(): center_of_mass = sum(list_coord) / float(len(list_coord)) sct.printv("Value = " + str(center_of_mass.value) + " : (" + str(center_of_mass.x) + ", " + str(center_of_mass.y) + ", " + str(center_of_mass.z) + ") --> ( " + str(np.round(center_of_mass.x)) + ", " + str(np.round(center_of_mass.y)) + ", " + str(np.round(center_of_mass.z)) + ")", verbose=self.verbose) output_image.data[int(np.round(center_of_mass.x)), int(np.round(center_of_mass.y)), int(np.round(center_of_mass.z))] = center_of_mass.value return output_image def increment_z_inverse(self): """ Take all non-zero values, sort them along the inverse z direction, and attributes the values 1, 2, 3, etc. This function assuming RPI orientation. """ image_output = msct_image.zeros_like(self.image_input) coordinates_input = self.image_input.getNonZeroCoordinates(sorting='z', reverse_coord=True) # for all points with non-zeros neighbors, force the neighbors to 0 for i, coord in enumerate(coordinates_input): image_output.data[int(coord.x), int(coord.y), int(coord.z)] = i + 1 return image_output def labelize_from_disks(self): """ Create an image with regions labelized depending on values from reference. Typically, user inputs a segmentation image, and labels with disks position, and this function produces a segmentation image with vertebral levels labelized. 
Labels are assumed to be non-zero and incremented from top to bottom, assuming a RPI orientation """ image_output = msct_image.zeros_like(self.image_input) coordinates_input = self.image_input.getNonZeroCoordinates() coordinates_ref = self.image_ref.getNonZeroCoordinates(sorting='value') # for all points in input, find the value that has to be set up, depending on the vertebral level for i, coord in enumerate(coordinates_input): for j in range(0, len(coordinates_ref) - 1): if coordinates_ref[j + 1].z < coord.z <= coordinates_ref[j].z: image_output.data[int(coord.x), int(coord.y), int(coord.z)] = coordinates_ref[j].value return image_output def label_vertebrae(self, levels_user=None): """ Find the center of mass of vertebral levels specified by the user. :return: image_output: Image with labels. """ # get center of mass of each vertebral level image_cubic2point = self.cubic_to_point() # get list of coordinates for each label list_coordinates = image_cubic2point.getNonZeroCoordinates(sorting='value') # if user did not specify levels, include all: if levels_user[0] == 0: levels_user = [int(i.value) for i in list_coordinates] # loop across labels and remove those that are not listed by the user for i_label in range(len(list_coordinates)): # check if this level is NOT in levels_user if not levels_user.count(int(list_coordinates[i_label].value)): # if not, set value to zero image_cubic2point.data[int(list_coordinates[i_label].x), int(list_coordinates[i_label].y), int(list_coordinates[i_label].z)] = 0 # list all labels return image_cubic2point def MSE(self, threshold_mse=0): """ Compute the Mean Square Distance Error between two sets of labels (input and ref). Moreover, a warning is generated for each label mismatch. If the MSE is above the threshold provided (by default = 0mm), a log is reported with the filenames considered here. 
""" coordinates_input = self.image_input.getNonZeroCoordinates() coordinates_ref = self.image_ref.getNonZeroCoordinates() # check if all the labels in both the images match if len(coordinates_input) != len(coordinates_ref): sct.printv('ERROR: labels mismatch', 1, 'warning') for coord in coordinates_input: if np.round(coord.value) not in [np.round(coord_ref.value) for coord_ref in coordinates_ref]: sct.printv('ERROR: labels mismatch', 1, 'warning') for coord_ref in coordinates_ref: if np.round(coord_ref.value) not in [np.round(coord.value) for coord in coordinates_input]: sct.printv('ERROR: labels mismatch', 1, 'warning') result = 0.0 for coord in coordinates_input: for coord_ref in coordinates_ref: if np.round(coord_ref.value) == np.round(coord.value): result += (coord_ref.z - coord.z) ** 2 break result = np.sqrt(result / len(coordinates_input)) sct.printv('MSE error in Z direction = ' + str(result) + ' mm') if result > threshold_mse: parent, stem, ext = sct.extract_fname(self.image_input.absolutepath) fname_report = os.path.join(parent, 'error_log_{}.txt'.format(stem)) with open(fname_report, 'w') as f: f.write('The labels error (MSE) between {} and {} is: {}\n'.format( os.path.relpath(self.image_input.absolutepath, os.path.dirname(fname_report)), os.path.relpath(self.image_ref.absolutepath, os.path.dirname(fname_report)), result)) return result @staticmethod def remove_label_coord(coord_input, coord_ref, symmetry=False): """ coord_input and coord_ref should be sets of CoordinateValue in order to improve speed of intersection :param coord_input: set of CoordinateValue :param coord_ref: set of CoordinateValue :param symmetry: boolean, :return: intersection of CoordinateValue: list """ from msct_types import CoordinateValue if isinstance(coord_input[0], CoordinateValue) and isinstance(coord_ref[0], CoordinateValue) and symmetry: coord_intersection = list(set(coord_input).intersection(set(coord_ref))) result_coord_input = [coord for coord in coord_input if coord in 
coord_intersection] result_coord_ref = [coord for coord in coord_ref if coord in coord_intersection] else: result_coord_ref = coord_ref result_coord_input = [coord for coord in coord_input if list(filter(lambda x: x.value == coord.value, coord_ref))] if symmetry: result_coord_ref = [coord for coord in coord_ref if list(filter(lambda x: x.value == coord.value, result_coord_input))] return result_coord_input, result_coord_ref def remove_label(self, symmetry=False): """ Compare two label images and remove any labels in input image that are not in reference image. The symmetry option enables to remove labels from reference image that are not in input image """ # image_output = Image(self.image_input.dim, orientation=self.image_input.orientation, hdr=self.image_input.hdr, verbose=self.verbose) image_output = msct_image.zeros_like(self.image_input) result_coord_input, result_coord_ref = self.remove_label_coord(self.image_input.getNonZeroCoordinates(coordValue=True), self.image_ref.getNonZeroCoordinates(coordValue=True), symmetry) for coord in result_coord_input: image_output.data[int(coord.x), int(coord.y), int(coord.z)] = int(np.round(coord.value)) if symmetry: # image_output_ref = Image(self.image_ref.dim, orientation=self.image_ref.orientation, hdr=self.image_ref.hdr, verbose=self.verbose) image_output_ref = Image(self.image_ref, verbose=self.verbose) for coord in result_coord_ref: image_output_ref.data[int(coord.x), int(coord.y), int(coord.z)] = int(np.round(coord.value)) image_output_ref.absolutepath = self.fname_output[1] image_output_ref.save('minimize_int') self.fname_output = self.fname_output[0] return image_output def display_voxel(self): """ Display all the labels that are contained in the input image. The image is suppose to be RPI to display voxels. 
But works also for other orientations """ coordinates_input = self.image_input.getNonZeroCoordinates(sorting='value') self.useful_notation = '' for coord in coordinates_input: sct.printv('Position=(' + str(coord.x) + ',' + str(coord.y) + ',' + str(coord.z) + ') -- Value= ' + str(coord.value), verbose=self.verbose) if self.useful_notation: self.useful_notation = self.useful_notation + ':' self.useful_notation += str(coord) sct.printv('All labels (useful syntax):', verbose=self.verbose) sct.printv(self.useful_notation, verbose=self.verbose) return coordinates_input def get_physical_coordinates(self): """ This function returns the coordinates of the labels in the physical referential system. :return: a list of CoordinateValue, in the physical (scanner) space """ coord = self.image_input.getNonZeroCoordinates(sorting='value') phys_coord = [] for c in coord: # convert pixelar coordinates to physical coordinates c_p = self.image_input.transfo_pix2phys([[c.x, c.y, c.z]])[0] phys_coord.append(CoordinateValue([c_p[0], c_p[1], c_p[2], c.value])) return phys_coord def get_coordinates_in_destination(self, im_dest, type='discrete'): """ This function calculate the position of labels in the pixelar space of a destination image :param im_dest: Object Image :param type: 'discrete' or 'continuous' :return: a list of CoordinateValue, in the pixelar (image) space of the destination image """ phys_coord = self.get_physical_coordinates() dest_coord = [] for c in phys_coord: if type is 'discrete': c_p = im_dest.transfo_phys2pix([[c.x, c.y, c.y]])[0] elif type is 'continuous': c_p = im_dest.transfo_phys2pix([[c.x, c.y, c.y]], real=False)[0] else: raise ValueError("The value of 'type' should either be 'discrete' or 'continuous'.") dest_coord.append(CoordinateValue([c_p[0], c_p[1], c_p[2], c.value])) return dest_coord def diff(self): """ Detect any label mismatch between input image and reference image """ coordinates_input = self.image_input.getNonZeroCoordinates() coordinates_ref = 
self.image_ref.getNonZeroCoordinates() sct.printv("Label in input image that are not in reference image:") for coord in coordinates_input: isIn = False for coord_ref in coordinates_ref: if coord.value == coord_ref.value: isIn = True break if not isIn: sct.printv(coord.value) sct.printv("Label in ref image that are not in input image:") for coord_ref in coordinates_ref: isIn = False for coord in coordinates_input: if coord.value == coord_ref.value: isIn = True break if not isIn: sct.printv(coord_ref.value) def distance_interlabels(self, max_dist): """ Calculate the distances between each label in the input image. If a distance is larger than max_dist, a warning message is displayed. """ coordinates_input = self.image_input.getNonZeroCoordinates() # for all points with non-zeros neighbors, force the neighbors to 0 for i in range(0, len(coordinates_input) - 1): dist = np.sqrt((coordinates_input[i].x - coordinates_input[i + 1].x)**2 + (coordinates_input[i].y - coordinates_input[i + 1].y)**2 + (coordinates_input[i].z - coordinates_input[i + 1].z)**2) if dist < max_dist: sct.printv('Warning: the distance between label ' + str(i) + '[' + str(coordinates_input[i].x) + ',' + str(coordinates_input[i].y) + ',' + str( coordinates_input[i].z) + ']=' + str(coordinates_input[i].value) + ' and label ' + str(i + 1) + '[' + str( coordinates_input[i + 1].x) + ',' + str(coordinates_input[i + 1].y) + ',' + str(coordinates_input[i + 1].z) + ']=' + str( coordinates_input[i + 1].value) + ' is larger than ' + str(max_dist) + '. Distance=' + str(dist)) def continuous_vertebral_levels(self): """ This function transforms the vertebral levels file from the template into a continuous file. Instead of having integer representing the vertebral level on each slice, a continuous value that represents the position of the slice in the vertebral level coordinate system. 
The image must be RPI :return: """ im_input = Image(self.image_input, self.verbose) im_output = msct_image.zeros_like(self.image_input) # 1. extract vertebral levels from input image # a. extract centerline # b. for each slice, extract corresponding level nx, ny, nz, nt, px, py, pz, pt = im_input.dim from spinalcordtoolbox.centerline.core import get_centerline _, arr_ctl, _ = get_centerline(self.image_input, algo_fitting='bspline') x_centerline_fit, y_centerline_fit, z_centerline = arr_ctl value_centerline = np.array( [im_input.data[int(x_centerline_fit[it]), int(y_centerline_fit[it]), int(z_centerline[it])] for it in range(len(z_centerline))]) # 2. compute distance for each vertebral level --> Di for i being the vertebral levels vertebral_levels = {} for slice_image, level in enumerate(value_centerline): if level not in vertebral_levels: vertebral_levels[level] = slice_image length_levels = {} for level in vertebral_levels: indexes_slice = np.where(value_centerline == level) length_levels[level] = np.sum([np.sqrt(((x_centerline_fit[indexes_slice[0][index_slice + 1]] - x_centerline_fit[indexes_slice[0][index_slice]]) * px)**2 + ((y_centerline_fit[indexes_slice[0][index_slice + 1]] - y_centerline_fit[indexes_slice[0][index_slice]]) * py)**2 + ((z_centerline[indexes_slice[0][index_slice + 1]] - z_centerline[indexes_slice[0][index_slice]]) * pz)**2) for index_slice in range(len(indexes_slice[0]) - 1)]) # 2. for each slice: # a. identify corresponding vertebral level --> i # b. calculate distance of slice from upper vertebral level --> d # c. 
compute relative distance in the vertebral level coordinate system --> d/Di continuous_values = {} for it, iz in enumerate(z_centerline): level = value_centerline[it] indexes_slice = np.where(value_centerline == level) indexes_slice = indexes_slice[0][indexes_slice[0] >= it] distance_from_level = np.sum([np.sqrt(((x_centerline_fit[indexes_slice[index_slice + 1]] - x_centerline_fit[indexes_slice[index_slice]]) * px * px) ** 2 + ((y_centerline_fit[indexes_slice[index_slice + 1]] - y_centerline_fit[indexes_slice[index_slice]]) * py * py) ** 2 + ((z_centerline[indexes_slice[index_slice + 1]] - z_centerline[indexes_slice[index_slice]]) * pz * pz) ** 2) for index_slice in range(len(indexes_slice) - 1)]) continuous_values[iz] = level + 2.0 * distance_from_level / float(length_levels[level]) # 3. saving data # for each slice, get all non-zero pixels and replace with continuous values coordinates_input = self.image_input.getNonZeroCoordinates() im_output.change_type(np.float32) # for all points in input, find the value that has to be set up, depending on the vertebral level for i, coord in enumerate(coordinates_input): im_output.data[int(coord.x), int(coord.y), int(coord.z)] = continuous_values[coord.z] return im_output def launch_sagittal_viewer(self, labels): from spinalcordtoolbox.gui import base from spinalcordtoolbox.gui.sagittal import launch_sagittal_dialog params = base.AnatomicalParams() params.vertebraes = labels params.input_file_name = self.image_input.absolutepath params.output_file_name = self.fname_output params.subtitle = self.msg output = msct_image.zeros_like(self.image_input) output.absolutepath = self.fname_output launch_sagittal_dialog(self.image_input, output, params) return output def remove_or_keep_labels(self, labels, action): """ Create or remove labels from self.image_input :param list(int): Labels to keep or remove :param str: 'remove': remove specified labels (i.e. 
set to zero), 'keep': keep specified labels and remove the others """ if action == 'keep': image_output = msct_image.zeros_like(self.image_input) elif action == 'remove': image_output = self.image_input.copy() coordinates_input = self.image_input.getNonZeroCoordinates() for labelNumber in labels: isInLabels = False for coord in coordinates_input: if labelNumber == coord.value: new_coord = coord isInLabels = True if isInLabels: if action == 'keep': image_output.data[int(new_coord.x), int(new_coord.y), int(new_coord.z)] = new_coord.value elif action == 'remove': image_output.data[int(new_coord.x), int(new_coord.y), int(new_coord.z)] = 0.0 else: sct.printv("WARNING: Label " + str(float(labelNumber)) + " not found in input image.", type='warning') return image_output
def main(args=None):
    """
    Main function: parse the command line, run the requested image operation and write the output file(s).

    :param args: list of command-line arguments. If None, sys.argv[1:] is used (['--help'] when it is empty).
    :return: None
    """
    # initializations
    output_type = None
    dim_list = ['x', 'y', 'z', 't']

    # Get parser args
    if args is None:
        args = None if sys.argv[1:] else ['--help']
    parser = get_parser()
    arguments = parser.parse_args(args=args)
    fname_in = arguments.i
    n_in = len(fname_in)
    verbose = arguments.v
    sct.init_sct(log_level=verbose, update=True)  # Update log level

    if arguments.o is not None:
        fname_out = arguments.o
    else:
        fname_out = None

    # run command
    if arguments.concat is not None:
        dim = arguments.concat
        assert dim in dim_list
        dim = dim_list.index(dim)
        im_out = [concat_data(fname_in, dim)]  # TODO: adapt to fname_in

    elif arguments.copy_header is not None:
        im_in = Image(fname_in[0])
        im_dest = Image(arguments.copy_header)
        im_dest_new = im_in.copy()
        im_dest_new.data = im_dest.data.copy()
        # im_dest.header = im_in.header
        im_dest_new.absolutepath = im_dest.absolutepath
        im_out = [im_dest_new]
        fname_out = arguments.copy_header

    elif arguments.display_warp:
        im_in = fname_in[0]
        visualize_warp(im_in, fname_grid=None, step=3, rm_tmp=True)
        im_out = None

    elif arguments.getorient:
        im_in = Image(fname_in[0])
        orient = im_in.orientation
        im_out = None

    elif arguments.keep_vol is not None:
        index_vol = [int(vol) for vol in arguments.keep_vol.split(',')]
        im_in = Image(fname_in[0])
        im_out = [remove_vol(im_in, index_vol, todo='keep')]

    elif arguments.mcs:
        im_in = Image(fname_in[0])
        # parser.error() prints the usage and exits; the legacy `parser.usage.generate(...)` does not exist on an
        # argparse parser and raised AttributeError.
        if n_in != 1:
            parser.error('ERROR: -mcs need only one input')
        if len(im_in.data.shape) != 5:
            parser.error('ERROR: -mcs input need to be a multi-component image')
        im_out = multicomponent_split(im_in)

    elif arguments.omc:
        im_ref = Image(fname_in[0])
        for fname in fname_in:
            im = Image(fname)
            if im.data.shape != im_ref.data.shape:
                parser.error('ERROR: -omc inputs need to have all the same shapes')
            del im
        im_out = [multicomponent_merge(fname_in)]  # TODO: adapt to fname_in

    elif arguments.pad is not None:
        im_in = Image(fname_in[0])
        ndims = len(im_in.data.shape)
        if ndims != 3:
            sct.printv('ERROR: you need to specify a 3D input file.', 1, 'error')
            return

        pad_arguments = arguments.pad.split(',')
        if len(pad_arguments) != 3:
            sct.printv('ERROR: you need to specify 3 padding values.', 1, 'error')
            return

        padx, pady, padz = (int(pad) for pad in pad_arguments)
        im_out = [pad_image(im_in, pad_x_i=padx, pad_x_f=padx, pad_y_i=pady, pad_y_f=pady, pad_z_i=padz, pad_z_f=padz)]

    elif arguments.pad_asym is not None:
        im_in = Image(fname_in[0])
        ndims = len(im_in.data.shape)
        if ndims != 3:
            sct.printv('ERROR: you need to specify a 3D input file.', 1, 'error')
            return

        pad_arguments = arguments.pad_asym.split(',')
        if len(pad_arguments) != 6:
            sct.printv('ERROR: you need to specify 6 padding values.', 1, 'error')
            return

        padxi, padxf, padyi, padyf, padzi, padzf = (int(pad) for pad in pad_arguments)
        im_out = [pad_image(im_in, pad_x_i=padxi, pad_x_f=padxf, pad_y_i=padyi, pad_y_f=padyf, pad_z_i=padzi, pad_z_f=padzf)]

    elif arguments.remove_vol is not None:
        index_vol = [int(vol) for vol in arguments.remove_vol.split(',')]
        im_in = Image(fname_in[0])
        im_out = [remove_vol(im_in, index_vol, todo='remove')]

    elif arguments.setorient is not None:
        sct.printv(fname_in[0])
        im_in = Image(fname_in[0])
        im_out = [msct_image.change_orientation(im_in, arguments.setorient)]

    elif arguments.setorient_data is not None:
        im_in = Image(fname_in[0])
        im_out = [msct_image.change_orientation(im_in, arguments.setorient_data, data_only=True)]

    elif arguments.split is not None:
        dim = arguments.split
        assert dim in dim_list
        im_in = Image(fname_in[0])
        dim = dim_list.index(dim)
        im_out = split_data(im_in, dim)

    elif arguments.type is not None:
        output_type = arguments.type
        im_in = Image(fname_in[0])
        im_out = [im_in]  # TODO: adapt to fname_in

    else:
        im_out = None
        parser.error('ERROR: you need to specify an operation to do on the input image')

    # in case fname_out is not defined, use first element of input file name list
    if fname_out is None:
        fname_out = fname_in[0]

    # Write output
    if im_out is not None:
        sct.printv('Generate output files...', verbose)
        # if only one output (a split is written below with indexed names, even when it yields a single image;
        # `'-split' in arguments` was always False on an argparse Namespace)
        if len(im_out) == 1 and arguments.split is None:
            im_out[0].save(fname_out, dtype=output_type, verbose=verbose)
            sct.display_viewer_syntax([fname_out], verbose=verbose)
        if arguments.mcs:
            # use input file name and add _X, _Y _Z. Keep the same extension
            l_fname_out = []
            for i_dim in range(3):
                l_fname_out.append(sct.add_suffix(fname_out or fname_in[0], '_' + dim_list[i_dim].upper()))
                im_out[i_dim].save(l_fname_out[i_dim], verbose=verbose)
            # show the files actually written, not the (unwritten) base name
            sct.display_viewer_syntax(l_fname_out)
        if arguments.split is not None:
            # use input file name and add _"DIM+NUMBER". Keep the same extension
            l_fname_out = []
            for i, im in enumerate(im_out):
                l_fname_out.append(sct.add_suffix(fname_out or fname_in[0], '_' + dim_list[dim].upper() + str(i).zfill(4)))
                im.save(l_fname_out[i])
            sct.display_viewer_syntax(l_fname_out)

    elif arguments.getorient:
        sct.printv(orient)

    elif arguments.display_warp:
        sct.printv('Warping grid generated.', verbose, 'info')
tmp_dir = sct.tmp_create() im1_name = "im1.nii.gz" sct.copy(input_fname, os.path.join(tmp_dir, im1_name)) if input_second_fname != '': im2_name = 'im2.nii.gz' sct.copy(input_second_fname, os.path.join(tmp_dir, im2_name)) else: im2_name = None curdir = os.getcwd() os.chdir(tmp_dir) # now = time.time() input_im1 = Image( resample_image(im1_name, binary=True, thr=0.5, npx=resample_to, npy=resample_to)) input_im1.absolutepath = os.path.basename(input_fname) if im2_name is not None: input_im2 = Image( resample_image(im2_name, binary=True, thr=0.5, npx=resample_to, npy=resample_to)) input_im2.absolutepath = os.path.basename(input_second_fname) else: input_im2 = None computation = ComputeDistances(input_im1, im2=input_im2, param=param)
def dmri_moco(param):
    """
    Run motion correction on diffusion MRI data located in the current working directory.

    Reads 'dmri.nii' and 'bvecs.txt' from the current directory. The steps are:
    1. Split the 4D data along T.
    2. Merge and average the b=0 volumes.
    3. Build groups of param.group_size DW volumes, with the remainder appended as a last group, and average
       each group.
    4. Estimate motion on the b=0 volumes and on the DW group averages.
    5. Copy the resulting matrices into 'mat_final/'.
    6. Optionally regularize them along T (spline) and combine eddy matrices.
    7. Apply them to the full dataset and restore the original header.

    NOTE: `param_moco = param` is an alias, not a copy, so the attributes file_data, file_target, path_out, todo
    and mat_moco of the caller's `param` are overwritten. With param.run_eddy, mat_2_combine and mat_final are
    also set on it.

    :param param: parameter object. Attributes read here: verbose, fname_bvals, bval_min, group_size, otsu,
        spline_fitting, plot_graph, run_eddy, suffix.
    :return: absolute path of the motion-corrected file '<dmri basename><param.suffix>.nii'.
    """

    file_data = 'dmri.nii'
    file_data_dirname, file_data_basename, file_data_ext = sct.extract_fname(file_data)
    file_b0 = 'b0.nii'
    file_dwi = 'dwi.nii'
    ext_data = '.nii.gz' # workaround "too many open files" by slurping the data
    mat_final = 'mat_final/'
    file_dwi_group = 'dwi_averaged_groups.nii'
    ext_mat = 'Warp.nii.gz'  # warping field

    # Get dimensions of data
    sct.printv('\nGet dimensions of data...', param.verbose)
    im_data = Image(file_data)
    nx, ny, nz, nt, px, py, pz, pt = im_data.dim
    sct.printv(' ' + str(nx) + ' x ' + str(ny) + ' x ' + str(nz), param.verbose)

    # Identify b=0 and DWI images
    index_b0, index_dwi, nb_b0, nb_dwi = sct_dmri_separate_b0_and_dwi.identify_b0('bvecs.txt', param.fname_bvals, param.bval_min, param.verbose)

    # check if dmri and bvecs are the same size
    if not nb_b0 + nb_dwi == nt:
        sct.printv('\nERROR in ' + os.path.basename(__file__) + ': Size of data (' + str(nt) + ') and size of bvecs (' + str(nb_b0 + nb_dwi) + ') are not the same. Check your bvecs file.\n', 1, 'error')
        sys.exit(2)

    # Prepare NIFTI (mean/groups...)
    #===================================================================================================================
    # Split into T dimension
    sct.printv('\nSplit along T dimension...', param.verbose)
    im_data_split_list = split_data(im_data, 3)
    # Save each volume as .nii.gz. The target paths built further down assume the names
    # '<basename>_T<4-digit index>.nii.gz' — presumably the naming used by split_data; TODO confirm.
    for im in im_data_split_list:
        x_dirname, x_basename, x_ext = sct.extract_fname(im.absolutepath)
        im.absolutepath = os.path.join(x_dirname, x_basename + ".nii.gz")
        im.save()

    # Merge b=0 images
    sct.printv('\nMerge b=0...', param.verbose)
    im_b0_list = []
    for it in range(nb_b0):
        im_b0_list.append(im_data_split_list[index_b0[it]])
    im_b0_out = concat_data(im_b0_list, 3).save(file_b0)
    sct.printv((' File created: ' + file_b0), param.verbose)

    # Average b=0 images
    sct.printv('\nAverage b=0...', param.verbose)
    file_b0_mean = sct.add_suffix(file_b0, '_mean')
    sct.run(['sct_maths', '-i', file_b0, '-o', file_b0_mean, '-mean', 't'], param.verbose)

    # Number of complete DWI groups of param.group_size volumes
    nb_groups = int(math.floor(nb_dwi / param.group_size))

    # Generate groups indexes: consecutive slices of index_dwi
    group_indexes = []
    for iGroup in range(nb_groups):
        group_indexes.append(index_dwi[(iGroup * param.group_size):((iGroup + 1) * param.group_size)])

    # add the remaining images as one extra (smaller) DWI group at the end
    nb_remaining = nb_dwi%param.group_size  # number of remaining images
    if nb_remaining > 0:
        nb_groups += 1
        group_indexes.append(index_dwi[len(index_dwi) - nb_remaining:len(index_dwi)])

    file_dwi_dirname, file_dwi_basename, file_dwi_ext = sct.extract_fname(file_dwi)
    # DWI groups: merge the volumes of each group, then average them along T
    file_dwi_mean = []
    for iGroup in tqdm(range(nb_groups), unit='iter', unit_scale=False, desc="Merge within groups", ascii=True, ncols=80):
        # get index
        index_dwi_i = group_indexes[iGroup]
        nb_dwi_i = len(index_dwi_i)
        # Merge DW Images
        file_dwi_merge_i = os.path.join(file_dwi_dirname, file_dwi_basename + '_' + str(iGroup) + ext_data)
        im_dwi_list = []
        for it in range(nb_dwi_i):
            im_dwi_list.append(im_data_split_list[index_dwi_i[it]])
        im_dwi_out = concat_data(im_dwi_list, 3).save(file_dwi_merge_i)
        # Average DW Images
        file_dwi_mean_i = os.path.join(file_dwi_dirname, file_dwi_basename + '_mean_' + str(iGroup) + ext_data)
        file_dwi_mean.append(file_dwi_mean_i)
        sct.run(["sct_maths", "-i", file_dwi_merge_i, "-o", file_dwi_mean[iGroup], "-mean", "t"], 0)

    # Merge DWI groups means into one 4D file (one volume per group)
    sct.printv('\nMerging DW files...', param.verbose)
    # file_dwi_groups_means_merge = 'dwi_averaged_groups'
    im_dw_list = []
    for iGroup in range(nb_groups):
        im_dw_list.append(file_dwi_mean[iGroup])
    im_dw_out = concat_data(im_dw_list, 3).save(file_dwi_group)

    # Average DW Images
    # TODO: USEFULL ??? (this output is not read again in this function)
    # NOTE(review): the output name is 'dwi_averaged_groups.nii_mean.nii.gz' because the suffix is appended after
    # the extension.
    sct.printv('\nAveraging all DW images...', param.verbose)
    sct.run(["sct_maths", "-i", file_dwi_group, "-o", file_dwi_group + '_mean' + ext_data, "-mean", "t"], param.verbose)

    # segment dwi images using otsu algorithm
    if param.otsu:
        sct.printv('\nSegment group DWI using OTSU algorithm...', param.verbose)
        # import module
        otsu = importlib.import_module('sct_otsu')
        # get class from module
        param_otsu = otsu.param()  #getattr(otsu, param)
        param_otsu.fname_data = file_dwi_group
        param_otsu.threshold = param.otsu
        param_otsu.file_suffix = '_seg'
        # run otsu
        otsu.otsu(param_otsu)
        # NOTE(review): this yields 'dwi_averaged_groups.nii_seg.nii'; confirm it matches the file written by
        # sct_otsu with file_suffix '_seg'.
        file_dwi_group = file_dwi_group + '_seg.nii'

    # START MOCO
    #===================================================================================================================
    # Estimate moco on b0 groups
    sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
    sct.printv(' Estimating motion on b=0 images...', param.verbose)
    sct.printv('-------------------------------------------------------------------------------', param.verbose)
    param_moco = param  # alias: the caller's param object is modified below
    param_moco.file_data = 'b0.nii'
    # identify target image
    if index_dwi[0] != 0:
        # If first DWI is not the first volume (most common), then there is a least one b=0 image before. In that case
        # select it as the target image for registration of all b=0
        param_moco.file_target = os.path.join(file_data_dirname, file_data_basename + '_T' + str(index_b0[index_dwi[0] - 1]).zfill(4) + ext_data)
    else:
        # If first DWI is the first volume, then the target b=0 is the first b=0 from the index_b0.
        param_moco.file_target = os.path.join(file_data_dirname, file_data_basename + '_T' + str(index_b0[0]).zfill(4) + ext_data)
    param_moco.path_out = ''
    param_moco.todo = 'estimate'
    param_moco.mat_moco = 'mat_b0groups'
    file_mat_b0 = moco.moco(param_moco)

    # Estimate moco on dwi groups
    sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
    sct.printv(' Estimating motion on DW images...', param.verbose)
    sct.printv('-------------------------------------------------------------------------------', param.verbose)
    param_moco.file_data = file_dwi_group
    param_moco.file_target = file_dwi_mean[0]  # target is the first DW image (closest to the first b=0)
    param_moco.path_out = ''
    param_moco.todo = 'estimate_and_apply'
    param_moco.mat_moco = 'mat_dwigroups'
    file_mat_dwi = moco.moco(param_moco)

    # create final mat folder
    sct.create_folder(mat_final)

    # Copy b=0 registration matrices: the matrix of the it-th b=0 goes to the slot of its original volume index
    # TODO: use file_mat_b0 and file_mat_dwi instead of the hardcoding below
    # NOTE(review): the hard-coded 'Z0000' assumes a single matrix along z per volume — confirm against moco.moco.
    sct.printv('\nCopy b=0 registration matrices...', param.verbose)
    for it in range(nb_b0):
        sct.copy('mat_b0groups/' + 'mat.Z0000T' + str(it).zfill(4) + ext_mat, mat_final + 'mat.Z0000T' + str(index_b0[it]).zfill(4) + ext_mat)

    # Copy DWI registration matrices: every volume of a group receives the matrix estimated for that group
    sct.printv('\nCopy DWI registration matrices...', param.verbose)
    for iGroup in range(nb_groups):
        for dwi in range(len(group_indexes[iGroup])):  # we cannot use enumerate because group_indexes has 2 dim.
            sct.copy('mat_dwigroups/' + 'mat.Z0000T' + str(iGroup).zfill(4) + ext_mat, mat_final + 'mat.Z0000T' + str(group_indexes[iGroup][dwi]).zfill(4) + ext_mat)

    # Spline Regularization along T
    if param.spline_fitting:
        moco.spline(mat_final, nt, nz, param.verbose, np.array(index_b0), param.plot_graph)

    # combine Eddy Matrices
    if param.run_eddy:
        param.mat_2_combine = 'mat_eddy'
        param.mat_final = mat_final
        moco.combine_matrix(param)

    # Apply moco on all dmri data
    sct.printv('\n-------------------------------------------------------------------------------', param.verbose)
    sct.printv(' Apply moco', param.verbose)
    sct.printv('-------------------------------------------------------------------------------', param.verbose)
    param_moco.file_data = file_data
    param_moco.file_target = os.path.join(file_dwi_dirname, file_dwi_basename + '_mean_' + str(0) + ext_data)  # reference for reslicing into proper coordinate system
    param_moco.path_out = ''
    param_moco.mat_moco = mat_final
    param_moco.todo = 'apply'
    moco.moco(param_moco)

    # copy geometric information from header
    # NB: this is required because WarpImageMultiTransform in 2D mode wrongly sets pixdim(3) to "1".
    im_dmri = Image(file_data)

    fname_data_moco = os.path.join(file_data_dirname, file_data_basename + param.suffix + '.nii')
    im_dmri_moco = Image(fname_data_moco)
    im_dmri_moco.header = im_dmri.header
    im_dmri_moco.save()

    return os.path.abspath(fname_data_moco)
def deep_segmentation_MSlesion(fname_image, contrast_type, output_folder, ctr_algo='svm', ctr_file=None, brain_bool=True, remove_temp_files=1, verbose=1):
    """
    Segment MS lesions with a 3D convolutional model, working inside a temporary folder.

    Pipeline:
    1. Reorient to RPI and resample in-plane to 0.5 mm, keeping the slice thickness.
    2. Find the spinal cord centerline and crop around it (crop size 48).
    3. Normalize intensities, resample to 0.5 mm isotropic and segment with '<contrast_type>_lesion.h5'.
    4. Undo the resampling, cropping and reorientation, then binarize at 0.1.

    :param fname_image: path of the input image.
    :param contrast_type: contrast of the image. It selects the model file; the part before the first '_' is used
        as contrast for the centerline detection.
    :param output_folder: folder into which the lesion segmentation is copied.
    :param ctr_algo: centerline algorithm passed to find_centerline. With 'manual', ctr_file is copied and used.
    :param ctr_file: centerline file, only used when ctr_algo == 'manual'.
    :param brain_bool: whether the brain is present in the image (passed to find_centerline).
    :param remove_temp_files: if truthy, the temporary folder is deleted at the end.
    :param verbose: currently unused.
    :return: path of the lesion segmentation in output_folder.
    """
    path_script = os.path.dirname(__file__)
    path_sct = os.path.dirname(path_script)

    # create temporary folder with intermediate results
    sct.log.info("\nCreating temporary folder...")
    file_fname = os.path.basename(fname_image)
    tmp_folder = sct.TempFolder()
    tmp_folder_path = tmp_folder.get_path()
    fname_image_tmp = tmp_folder.copy_from(fname_image)
    if ctr_algo == 'manual':  # if the ctr_file is provided
        tmp_folder.copy_from(ctr_file)
        file_ctr = os.path.basename(ctr_file)
    else:
        file_ctr = None

    tmp_folder.chdir()
    try:
        # orientation of the image, should be RPI
        sct.log.info("\nReorient the image to RPI, if necessary...")
        fname_orient = sct.add_suffix(file_fname, '_RPI')
        im_2orient = Image(file_fname)
        original_orientation = im_2orient.orientation
        if original_orientation != 'RPI':
            im_orient = msct_image.change_orientation(im_2orient, 'RPI').save(fname_orient)
        else:
            im_orient = im_2orient
            sct.copy(fname_image_tmp, fname_orient)

        # resampling RPI image (in-plane to 0.5 mm; the slice thickness is kept)
        sct.log.info("\nResample the image to 0.5 mm isotropic resolution...")
        fname_res = sct.add_suffix(fname_orient, '_resampled')
        im_2res = im_orient
        input_resolution = im_2res.dim[4:7]
        new_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
        spinalcordtoolbox.resample.nipy_resample.resample_file(fname_orient, fname_res, new_resolution,
                                                               'mm', 'linear', verbose=0)

        # find the spinal cord centerline - execute OptiC binary
        sct.log.info("\nFinding the spinal cord centerline...")
        contrast_type_ctr = contrast_type.split('_')[0]
        centerline_filename = find_centerline(algo=ctr_algo,
                                              image_fname=fname_res,
                                              path_sct=path_sct,
                                              contrast_type=contrast_type_ctr,
                                              brain_bool=brain_bool,
                                              folder_output=tmp_folder_path,
                                              remove_temp_files=remove_temp_files,
                                              centerline_fname=file_ctr)

        # crop image around the spinal cord centerline
        sct.log.info("\nCropping the image around the spinal cord...")
        fname_crop = sct.add_suffix(fname_res, '_crop')
        crop_size = 48
        X_CROP_LST, Y_CROP_LST = crop_image_around_centerline(filename_in=fname_res,
                                                              filename_ctr=centerline_filename,
                                                              filename_out=fname_crop,
                                                              crop_size=crop_size)

        # normalize the intensity of the images
        sct.log.info("\nNormalizing the intensity...")
        fname_norm = sct.add_suffix(fname_crop, '_norm')
        apply_intensity_normalization(img_path=fname_crop, fname_out=fname_norm, contrast=contrast_type)

        # resample to 0.5mm isotropic
        fname_res3d = sct.add_suffix(fname_norm, '_resampled3d')
        spinalcordtoolbox.resample.nipy_resample.resample_file(fname_norm, fname_res3d, '0.5x0.5x0.5',
                                                               'mm', 'linear', verbose=0)

        # segment data using 3D convolutions
        sct.log.info("\nSegmenting the MS lesions using deep learning on 3D patches...")
        segmentation_model_fname = os.path.join(path_sct, 'data', 'deepseg_lesion_models',
                                                '{}_lesion.h5'.format(contrast_type))
        fname_seg_crop_res = sct.add_suffix(fname_res3d, '_lesionseg')
        segment_3d(model_fname=segmentation_model_fname,
                   contrast_type=contrast_type,
                   fname_in=fname_res3d,
                   fname_out=fname_seg_crop_res)

        # resample to the initial pz resolution
        fname_seg_res2d = sct.add_suffix(fname_seg_crop_res, '_resampled2d')
        initial_2d_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
        spinalcordtoolbox.resample.nipy_resample.resample_file(fname_seg_crop_res, fname_seg_res2d,
                                                               initial_2d_resolution, 'mm', 'linear', verbose=0)
        seg_crop_data = Image(fname_seg_res2d).data

        # reconstruct the segmentation from the crop data
        sct.log.info("\nReassembling the image...")
        fname_seg_res_RPI = sct.add_suffix(file_fname, '_res_RPI_seg')
        uncrop_image(fname_ref=fname_res,
                     fname_out=fname_seg_res_RPI,
                     data_crop=seg_crop_data,
                     x_crop_lst=X_CROP_LST,
                     y_crop_lst=Y_CROP_LST)

        # resample to initial resolution
        sct.log.info("\nResampling the segmentation to the original image resolution...")
        fname_seg_RPI = sct.add_suffix(file_fname, '_RPI_lesionseg')
        initial_resolution = 'x'.join([str(input_resolution[0]), str(input_resolution[1]), str(input_resolution[2])])
        spinalcordtoolbox.resample.nipy_resample.resample_file(fname_seg_res_RPI, fname_seg_RPI, initial_resolution,
                                                               'mm', 'linear', verbose=0)

        # binarize the resampled image to remove interpolation effects
        sct.log.info("\nBinarizing the segmentation to avoid interpolation effects...")
        thr = '0.1'
        sct.run(['sct_maths', '-i', fname_seg_RPI, '-bin', thr, '-o', fname_seg_RPI], verbose=0)

        # reorient to initial orientation
        sct.log.info("\nReorienting the segmentation to the original image orientation...")
        fname_seg = sct.add_suffix(file_fname, '_lesionseg')
        if original_orientation != 'RPI':
            Image(fname_seg_RPI).change_orientation(original_orientation).save(fname_seg)
        else:
            sct.copy(fname_seg_RPI, fname_seg)
    finally:
        # Always return to the original working directory, even if a step above fails; otherwise the caller
        # would be left inside the temporary folder.
        tmp_folder.chdir_undo()

    # copy image from temporary folder into output folder
    sct.copy(os.path.join(tmp_folder_path, fname_seg), output_folder)

    # remove temporary files
    if remove_temp_files:
        sct.log.info("\nRemove temporary files...")
        tmp_folder.cleanup()

    return os.path.join(output_folder, fname_seg)
def deep_segmentation_MSlesion(im_image, contrast_type, ctr_algo='svm', ctr_file=None, brain_bool=True,
                               remove_temp_files=1, verbose=1):
    """
    Segment lesions from MRI data.

    All intermediate files are written inside a temporary folder, and the input Image object is not modified.

    :param im_image: Image() object containing the lesions to segment
    :param contrast_type: Constrast of the image. Need to use one supported by the CNN models.
    :param ctr_algo: Algo to find the centerline. See sct_get_centerline
    :param ctr_file: Centerline or segmentation (optional)
    :param brain_bool: If brain if present or not in the image.
    :param remove_temp_files: if truthy, delete the temporary folder once the segmentation has been computed.
    :param verbose: verbosity level. When 2, the resampled centerline is also returned.
    :return: tuple (seg, labels_viewer, ctr):
        - seg: binarized lesion segmentation (Image), reoriented to the orientation of im_image
        - labels_viewer: viewer centerline labels (Image), or None when ctr_algo != 'viewer'
        - ctr: resampled centerline (Image), or None when verbose != 2
    """
    # create temporary folder with intermediate results
    tmp_folder = sct.TempFolder(verbose=verbose)
    tmp_folder_path = tmp_folder.get_path()
    if ctr_algo == 'file':  # if the ctr_file is provided
        tmp_folder.copy_from(ctr_file)
        file_ctr = os.path.basename(ctr_file)
    else:
        file_ctr = None
    tmp_folder.chdir()
    try:
        # orientation of the image, should be RPI
        logger.info("\nReorient the image to RPI, if necessary...")
        fname_in = im_image.absolutepath
        # Output names derived from the input must stay inside the temporary folder (the current directory):
        # keep only the basename, otherwise the intermediate files end up next to the input image.
        fname_base = os.path.basename(fname_in) if fname_in else 'image_in.nii'
        original_orientation = im_image.orientation
        fname_orient = 'image_in_RPI.nii'
        # reorient a copy so that the caller's image object is left untouched
        im_rpi = im_image.copy().change_orientation('RPI')
        im_rpi.save(fname_orient)
        input_resolution = im_rpi.dim[4:7]
        del im_rpi

        # find the spinal cord centerline - execute OptiC binary
        logger.info("\nFinding the spinal cord centerline...")
        contrast_type_ctr = contrast_type.split('_')[0]
        fname_res, centerline_filename = find_centerline(algo=ctr_algo,
                                                         image_fname=fname_orient,
                                                         contrast_type=contrast_type_ctr,
                                                         brain_bool=brain_bool,
                                                         folder_output=tmp_folder_path,
                                                         remove_temp_files=remove_temp_files,
                                                         centerline_fname=file_ctr)
        im_nii, ctr_nii = Image(fname_res), Image(centerline_filename)

        # crop image around the spinal cord centerline
        logger.info("\nCropping the image around the spinal cord...")
        crop_size = 48
        X_CROP_LST, Y_CROP_LST, Z_CROP_LST, im_crop_nii = crop_image_around_centerline(im_in=im_nii,
                                                                                       ctr_in=ctr_nii,
                                                                                       crop_size=crop_size)
        del ctr_nii

        # normalize the intensity of the images
        logger.info("Normalizing the intensity...")
        im_norm_in = apply_intensity_normalization(img=im_crop_nii, contrast=contrast_type)
        del im_crop_nii

        # resample to 0.5mm isotropic
        fname_norm = sct.add_suffix(fname_orient, '_norm')
        im_norm_in.save(fname_norm)
        fname_res3d = sct.add_suffix(fname_norm, '_resampled3d')
        resampling.resample_file(fname_norm, fname_res3d, '0.5x0.5x0.5', 'mm', 'linear', verbose=0)

        # segment data using 3D convolutions
        logger.info("\nSegmenting the MS lesions using deep learning on 3D patches...")
        segmentation_model_fname = os.path.join(sct.__sct_dir__, 'data', 'deepseg_lesion_models',
                                                '{}_lesion.h5'.format(contrast_type))
        fname_seg_crop_res = sct.add_suffix(fname_res3d, '_lesionseg')
        im_res3d = Image(fname_res3d)
        seg_im = segment_3d(model_fname=segmentation_model_fname, contrast_type=contrast_type, im=im_res3d.copy())
        seg_im.save(fname_seg_crop_res)
        del im_res3d, seg_im

        # resample to the initial pz resolution
        fname_seg_res2d = sct.add_suffix(fname_seg_crop_res, '_resampled2d')
        initial_2d_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
        resampling.resample_file(fname_seg_crop_res, fname_seg_res2d, initial_2d_resolution, 'mm', 'linear',
                                 verbose=0)
        seg_crop = Image(fname_seg_res2d)

        # reconstruct the segmentation from the crop data
        logger.info("\nReassembling the image...")
        seg_uncrop_nii = uncrop_image(ref_in=im_nii, data_crop=seg_crop.copy().data, x_crop_lst=X_CROP_LST,
                                      y_crop_lst=Y_CROP_LST, z_crop_lst=Z_CROP_LST)
        fname_seg_res_RPI = sct.add_suffix(fname_base, '_res_RPI_seg')
        seg_uncrop_nii.save(fname_seg_res_RPI)
        del seg_crop

        # resample to initial resolution
        logger.info("Resampling the segmentation to the original image resolution...")
        initial_resolution = 'x'.join([str(input_resolution[0]), str(input_resolution[1]),
                                       str(input_resolution[2])])
        fname_seg_RPI = sct.add_suffix(fname_base, '_RPI_seg')
        resampling.resample_file(fname_seg_res_RPI, fname_seg_RPI, initial_resolution, 'mm', 'linear', verbose=0)
        seg_initres_nii = Image(fname_seg_RPI)

        if ctr_algo == 'viewer':  # resample and reorient the viewer labels
            fname_res_labels = sct.add_suffix(fname_orient, '_labels-centerline')
            resampling.resample_file(fname_res_labels, fname_res_labels, initial_resolution, 'mm', 'linear',
                                     verbose=0)
            im_image_res_labels_downsamp = Image(fname_res_labels).change_orientation(original_orientation)
        else:
            im_image_res_labels_downsamp = None

        if verbose == 2:
            fname_res_ctr = sct.add_suffix(fname_orient, '_ctr')
            resampling.resample_file(fname_res_ctr, fname_res_ctr, initial_resolution, 'mm', 'linear', verbose=0)
            im_image_res_ctr_downsamp = Image(fname_res_ctr).change_orientation(original_orientation)
        else:
            im_image_res_ctr_downsamp = None

        # binarize the resampled image to remove interpolation effects
        logger.info("\nBinarizing the segmentation to avoid interpolation effects...")
        thr = 0.1
        seg_initres_nii.data[np.where(seg_initres_nii.data >= thr)] = 1
        seg_initres_nii.data[np.where(seg_initres_nii.data < thr)] = 0

        # reorient to initial orientation
        logger.info("\nReorienting the segmentation to the original image orientation...")
    finally:
        # always go back to the original working directory, even if one of the steps above failed
        tmp_folder.chdir_undo()

    # remove temporary files
    if remove_temp_files:
        logger.info("\nRemove temporary files...")
        tmp_folder.cleanup()

    # reorient to initial orientation
    return (seg_initres_nii.change_orientation(original_orientation),
            im_image_res_labels_downsamp,
            im_image_res_ctr_downsamp)
def dummy_segmentation(size_arr=(256, 256, 256), pixdim=(1, 1, 1), dtype=np.float64, orientation='LPI',
                       shape='rectangle', angle_RL=0, angle_AP=0, angle_IS=0, radius_RL=5.0, radius_AP=3.0,
                       zeroslice=None, debug=False):
    """Create a dummy Image with a ellipse or ones running from top to bottom in the 3rd dimension, and rotate the
    image to make sure that compute_csa and compute_shape properly estimate the centerline angle.

    :param size_arr: tuple: (nx, ny, nz)
    :param pixdim: tuple: (px, py, pz)
    :param dtype: Numpy dtype. NOTE(review): this parameter is currently not used by the function.
    :param orientation: Orientation of the image. Default: LPI
    :param shape: {'rectangle', 'ellipse'}. Any other value produces an all-zero image.
    :param angle_RL: int: angle around RL axis (in deg)
    :param angle_AP: int: angle around AP axis (in deg)
    :param angle_IS: int: angle around IS axis (in deg)
    :param radius_RL: float: 1st radius. With a, b = 50.0, 30.0 (in mm), theoretical CSA of ellipse is 4712.4
    :param radius_AP: float: 2nd radius
    :param zeroslice: list int: zero all slices listed in this param. None or empty: no slice is zeroed.
    :param debug: Write temp files for debug
    :return: img: Image object
    """
    # Initialization
    padding = 15  # Padding size (isotropic) to avoid edge effect during rotation
    # Create a 3d array, with dimensions corresponding to x: RL, y: AP, z: IS
    nx, ny, nz = [int(size_arr[i] * pixdim[i]) for i in range(3)]
    data = np.zeros((nx, ny, nz))
    xx, yy = np.mgrid[:nx, :ny]
    # The cross-section is identical for every slice: compute it once and broadcast it along z
    if shape == 'rectangle':
        # theoretical CSA: (a*2+1)(b*2+1)
        section = (abs(xx - nx / 2) <= radius_RL) & (abs(yy - ny / 2) <= radius_AP)
        data[:, :, :] = section[:, :, np.newaxis]
    elif shape == 'ellipse':
        section = (((xx - nx / 2) / radius_RL) ** 2 + ((yy - ny / 2) / radius_AP) ** 2) <= 1
        data[:, :, :] = section[:, :, np.newaxis]

    # Pad to avoid edge effect during rotation
    data = np.pad(data, padding, 'reflect')

    # ROTATION ABOUT IS AXIS
    # rotate (in deg), and re-grid using linear interpolation
    data_rotIS = rotate(data, angle_IS, resize=False, center=None, order=1, mode='constant', cval=0, clip=False,
                        preserve_range=False)

    # ROTATION ABOUT RL AXIS
    # Swap x-z axes (to make a rotation within y-z plane, because rotate will apply rotation on the first 2 dims)
    data_rotIS_swap = data_rotIS.swapaxes(0, 2)
    # rotate (in deg), and re-grid using linear interpolation
    data_rotIS_swap_rotRL = rotate(data_rotIS_swap, angle_RL, resize=False, center=None, order=1, mode='constant',
                                   cval=0, clip=False, preserve_range=False)
    # swap back
    data_rotIS_rotRL = data_rotIS_swap_rotRL.swapaxes(0, 2)

    # ROTATION ABOUT AP AXIS
    # Swap y-z axes (to make a rotation within x-z plane)
    data_rotIS_rotRL_swap = data_rotIS_rotRL.swapaxes(1, 2)
    # rotate (in deg), and re-grid using linear interpolation
    data_rotIS_rotRL_swap_rotAP = rotate(data_rotIS_rotRL_swap, angle_AP, resize=False, center=None, order=1,
                                         mode='constant', cval=0, clip=False, preserve_range=False)
    # swap back
    data_rot = data_rotIS_rotRL_swap_rotAP.swapaxes(1, 2)

    # Crop image (to remove padding)
    data_rot_crop = data_rot[padding:nx + padding, padding:ny + padding, padding:nz + padding]

    # Zero specified slices (the original `is not []` identity test was always True)
    if zeroslice:
        data_rot_crop[:, :, zeroslice] = 0

    # Create nibabel object with an identity affine (1 mm isotropic voxels)
    xform = np.eye(4)
    nii = nib.nifti1.Nifti1Image(data_rot_crop.astype('float32'), xform)

    # resample to desired resolution
    nii_r = resample_nib(nii, new_size=pixdim, new_size_type='mm', interpolation='linear')

    # Create Image object. Default orientation is LPI.
    # For debugging add .save() at the end of the command below
    img = Image(nii_r.get_data(), hdr=nii_r.header, dim=nii_r.header.get_data_shape())
    # Update orientation
    img.change_orientation(orientation)
    if debug:
        img.save('tmp_dummy_seg_' + datetime.now().strftime("%Y%m%d%H%M%S%f") + '.nii.gz')
    return img
def heatmap(filename_in, filename_out, model, patch_shape, mean_train, std_train, brain_bool=True):
    """Compute the heatmap with CNN_1 representing the SC localization.

    Axial slices (axis 2) are processed one at a time. When the spinal cord (SC) was found on the previous slice,
    the prediction is first made on a patch centred on its centre of mass. Otherwise the whole slice is scanned
    with ``scan_slice``. Predictions below 0.5 are zeroed and the result is converted to a distance map.

    :param filename_in: input image file name.
    :param filename_out: output heatmap file name.
    :param model: trained model exposing ``predict``.
    :param patch_shape: (x, y) size of the patches fed to the model.
    :param mean_train: intensity mean used to normalize the patches.
    :param std_train: intensity standard deviation used to normalize the patches.
    :param brain_bool: if True, stop after more than 10 consecutive whole-slice scans (brain section detected).
    :return: highest slice index with a non-zero heatmap value, or None if it is the last slice of the volume.
        Exits the process (sys.exit(1)) if the SC is not detected at all.
    """
    im = Image(filename_in)
    data_im = im.data.astype(np.float32)
    im_out = msct_image.change_type(im, "uint8")
    del im
    data = np.zeros(im_out.data.shape)

    x_shape, y_shape = data_im.shape[:2]
    # number of patches tiling the slice: rounded up along x (data is padded below),
    # rounded down along y (data is cropped below). Builtin int: np.int was removed in NumPy 1.24.
    x_shape_block = int(np.ceil(x_shape * 1.0 / patch_shape[0]))
    y_shape_block = int(y_shape * 1.0 / patch_shape[1])
    x_pad = int(x_shape_block * patch_shape[0] - x_shape)
    if y_shape > patch_shape[1]:
        y_crop = y_shape - y_shape_block * patch_shape[1]
        # slightly crop the input data in the P-A direction so that data_im.shape[1] % patch_shape[1] == 0
        data_im = data_im[:, :y_shape - y_crop, :]
        # coordinates of the blocks to scan during the detection, in the cross-sectional plane
        coord_lst = [[x_dim * patch_shape[0], y_dim * patch_shape[1],
                      (x_dim + 1) * patch_shape[0], (y_dim + 1) * patch_shape[1]]
                     for y_dim in range(y_shape_block) for x_dim in range(x_shape_block)]
    else:
        data_im = np.pad(data_im, ((0, 0), (0, patch_shape[1] - y_shape), (0, 0)), 'constant')
        coord_lst = [[x_dim * patch_shape[0], 0, (x_dim + 1) * patch_shape[0], patch_shape[1]]
                     for x_dim in range(x_shape_block)]
    # pad the input data in the R-L direction
    data_im = np.pad(data_im, ((0, x_pad), (0, 0), (0, 0)), 'constant')
    # scale intensities between 0 and 255
    data_im = scale_intensity(data_im)

    x_CoM, y_CoM = None, None
    z_sc_notDetected_cmpt = 0
    for zz in range(data_im.shape[2]):
        # if SC was detected at zz-1, we will start doing the detection on the block centered around the
        # previously computed center of mass (CoM)
        if x_CoM is not None:
            z_sc_notDetected_cmpt = 0  # SC detected, cmpt set to zero
            x_0, x_1 = _find_crop_start_end(x_CoM, patch_shape[0], data_im.shape[0])
            y_0, y_1 = _find_crop_start_end(y_CoM, patch_shape[1], data_im.shape[1])
            block = data_im[x_0:x_1, y_0:y_1, zz]
            block_nn = np.expand_dims(np.expand_dims(block, 0), -1)
            block_nn_norm = _normalize_data(block_nn, mean_train, std_train)
            block_pred = model.predict(block_nn_norm, batch_size=BATCH_SIZE)

            # coordinates manipulation due to the above padding and cropping
            if x_1 > data.shape[0]:
                x_end = data.shape[0]
                x_1 = data.shape[0]
                x_0 = data.shape[0] - patch_shape[0] if data.shape[0] > patch_shape[0] else 0
            else:
                x_end = patch_shape[0]
            if y_1 > data.shape[1]:
                y_end = data.shape[1]
                y_1 = data.shape[1]
                y_0 = data.shape[1] - patch_shape[1] if data.shape[1] > patch_shape[1] else 0
            else:
                y_end = patch_shape[1]

            data[x_0:x_1, y_0:y_1, zz] = block_pred[0, :x_end, :y_end, 0]

            # computation of the new center of mass
            if np.max(data[:, :, zz]) > 0.5:
                z_slice_out_bin = data[:, :, zz] > 0.5  # if the SC was detection
                x_CoM, y_CoM = center_of_mass(z_slice_out_bin)
                x_CoM, y_CoM = int(x_CoM), int(y_CoM)
            else:
                x_CoM, y_CoM = None, None

        # if the SC was not detected at zz-1 or on the patch centered around CoM in slice zz,
        # the entire cross-sectional slice is scanned
        if x_CoM is None:
            z_slice, x_CoM, y_CoM, coord_lst = scan_slice(data_im[:, :, zz], model, mean_train, std_train,
                                                          coord_lst, patch_shape, data.shape[:2])
            data[:, :, zz] = z_slice

        z_sc_notDetected_cmpt += 1
        # if the SC has not been detected on 10 consecutive z_slices, we stop the SC investigation
        if z_sc_notDetected_cmpt > 10 and brain_bool:
            sct.printv('Brain section detected.')
            break

        # distance transform to deal with the harsh edges of the prediction boundaries (Dice)
        data[:, :, zz][np.where(data[:, :, zz] < 0.5)] = 0
        data[:, :, zz] = distance_transform_edt(data[:, :, zz])

    if not np.any(data):
        sct.log.error(
            '\nSpinal cord was not detected using "-centerline cnn". Please try another "-centerline" method.\n')
        sys.exit(1)

    im_out.data = data
    im_out.save(filename_out)
    del im_out

    # z_max is used to reject brain sections
    z_max = np.max(np.where(data)[2])
    if z_max == data.shape[2] - 1:
        return None
    else:
        return z_max
def _preprocess_segment(fname_t2, fname_t2_seg, contrast_test, dim_3=False):
    """Preprocess an image and its ground-truth segmentation the way the deepseg_sc pipeline does.

    The steps are:
    - reorient both images to RPI;
    - detect the centerline with the 'svm' algorithm;
    - resample the ground truth to 0.5 x 0.5 x pz mm;
    - crop both images around the centerline (crop_size=64);
    - normalize the image intensity;
    - for 3D kernels (``dim_3``), resample the cropped images again.

    :param fname_t2: image file name.
    :param fname_t2_seg: ground-truth segmentation file name.
    :param contrast_test: contrast passed to the centerline detection.
    :param dim_3: True when the model uses 3D kernels.
    :return: (img, gt) Image objects. The temporary folder they were produced in is removed before returning.
    """
    # Load the inputs before changing directory so that relative input paths keep working
    img = Image(fname_t2)
    gt = Image(fname_t2_seg)

    tmp_folder = sct.TempFolder()
    tmp_folder_path = tmp_folder.get_path()
    tmp_folder.chdir()
    try:
        fname_t2_RPI, fname_t2_seg_RPI = 'img_RPI.nii.gz', 'seg_RPI.nii.gz'
        img.change_orientation('RPI').save(fname_t2_RPI)
        gt.change_orientation('RPI').save(fname_t2_seg_RPI)
        input_resolution = gt.dim[4:7]
        del img, gt

        fname_res, fname_ctr = deepseg_sc.find_centerline(algo='svm',
                                                          image_fname=fname_t2_RPI,
                                                          contrast_type=contrast_test,
                                                          brain_bool=False,
                                                          folder_output=tmp_folder_path,
                                                          remove_temp_files=1,
                                                          centerline_fname=None)

        fname_t2_seg_RPI_res = 'seg_RPI_res.nii.gz'
        new_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
        resample_file(fname_t2_seg_RPI, fname_t2_seg_RPI_res, new_resolution, 'mm', 'linear', verbose=0)

        img, ctr, gt = Image(fname_res), Image(fname_ctr), Image(fname_t2_seg_RPI_res)
        _, _, _, img = deepseg_sc.crop_image_around_centerline(im_in=img, ctr_in=ctr, crop_size=64)
        _, _, _, gt = deepseg_sc.crop_image_around_centerline(im_in=gt, ctr_in=ctr, crop_size=64)
        del ctr

        img = deepseg_sc.apply_intensity_normalization(im_in=img)

        if dim_3:  # If 3D kernels
            fname_t2_RPI_res_crop, fname_t2_seg_RPI_res_crop = 'img_RPI_res_crop.nii.gz', 'seg_RPI_res_crop.nii.gz'
            img.save(fname_t2_RPI_res_crop)
            gt.save(fname_t2_seg_RPI_res_crop)
            del img, gt

            fname_t2_RPI_res_crop_res = 'img_RPI_res_crop_res.nii.gz'
            fname_t2_seg_RPI_res_crop_res = 'seg_RPI_res_crop_res.nii.gz'
            resample_file(fname_t2_RPI_res_crop, fname_t2_RPI_res_crop_res, new_resolution, 'mm', 'linear',
                          verbose=0)
            resample_file(fname_t2_seg_RPI_res_crop, fname_t2_seg_RPI_res_crop_res, new_resolution, 'mm', 'linear',
                          verbose=0)
            img, gt = Image(fname_t2_RPI_res_crop_res), Image(fname_t2_seg_RPI_res_crop_res)
    finally:
        # restore the working directory and remove the intermediate files even if a step failed
        tmp_folder.chdir_undo()
        tmp_folder.cleanup()

    return img, gt
def find_centerline(algo, image_fname, path_sct, contrast_type, brain_bool, folder_output, remove_temp_files,
                    centerline_fname):
    """Compute the spinal cord centerline of an image and return the centerline file name.

    :param algo: 'svm' (OptiC on an SVM+HoG heatmap), 'cnn' (OptiC on a CNN heatmap), 'viewer' (manual labelling
        in a viewer) or 'manual' (reorient/resample ``centerline_fname``). Any other value exits the process.
    :param image_fname: input image file name. 2D images (nz == 1) are duplicated along z for processing.
    :param path_sct: SCT installation folder, used to locate the trained models.
    :param contrast_type: contrast used to select the models; for 'cnn' one of 't2', 't2s', 't1', 'dwi'.
    :param brain_bool: for 'cnn', whether the heatmap's z_max is passed to OptiC to reject brain sections.
    :param folder_output: output folder passed to OptiC ('svm').
    :param remove_temp_files: passed to OptiC ('svm').
    :param centerline_fname: manual centerline file name ('manual' only).
    :return: centerline file name.
    """
    if Image(image_fname).dim[2] == 1:  # isct_spine_detect requires nz > 1
        from sct_image import concat_data
        im_concat = concat_data([image_fname, image_fname], dim=2)
        im_concat.save(sct.add_suffix(image_fname, '_concat'))
        image_fname = sct.add_suffix(image_fname, '_concat')
        bool_2d = True
    else:
        bool_2d = False

    if algo == 'svm':
        # run optic on a heatmap computed by a trained SVM+HoG algorithm
        optic_models_fname = os.path.join(path_sct, 'data', 'optic_models', '{}_model'.format(contrast_type))
        _, centerline_filename = optic.detect_centerline(image_fname=image_fname,
                                                         contrast_type=contrast_type,
                                                         optic_models_path=optic_models_fname,
                                                         folder_output=folder_output,
                                                         remove_temp_files=remove_temp_files,
                                                         output_roi=False,
                                                         verbose=0)

    elif algo == 'cnn':
        # CNN parameters
        dct_patch_ctr = {'t2': {'size': (80, 80), 'mean': 51.1417, 'std': 57.4408},
                         't2s': {'size': (80, 80), 'mean': 68.8591, 'std': 71.4659},
                         't1': {'size': (80, 80), 'mean': 55.7359, 'std': 64.3149},
                         'dwi': {'size': (80, 80), 'mean': 55.744, 'std': 45.003}}
        dct_params_ctr = {'t2': {'features': 16, 'dilation_layers': 2},
                          't2s': {'features': 8, 'dilation_layers': 3},
                          't1': {'features': 24, 'dilation_layers': 3},
                          'dwi': {'features': 8, 'dilation_layers': 2}}

        # load model
        ctr_model_fname = os.path.join(path_sct, 'data', 'deepseg_sc_models', '{}_ctr.h5'.format(contrast_type))
        patch_ctr = dct_patch_ctr[contrast_type]
        params_ctr = dct_params_ctr[contrast_type]
        ctr_model = nn_architecture_ctr(height=patch_ctr['size'][0],
                                        width=patch_ctr['size'][1],
                                        channels=1,
                                        classes=1,
                                        features=params_ctr['features'],
                                        depth=2,
                                        temperature=1.0,
                                        padding='same',
                                        batchnorm=True,
                                        dropout=0.0,
                                        dilation_layers=params_ctr['dilation_layers'])
        ctr_model.load_weights(ctr_model_fname)

        # compute the heatmap
        fname_heatmap = sct.add_suffix(image_fname, "_heatmap")
        img_filename = ''.join(sct.extract_fname(fname_heatmap)[:2])
        fname_heatmap_nii = img_filename + '.nii'
        z_max = heatmap(filename_in=image_fname,
                        filename_out=fname_heatmap_nii,
                        model=ctr_model,
                        patch_shape=patch_ctr['size'],
                        mean_train=patch_ctr['mean'],
                        std_train=patch_ctr['std'],
                        brain_bool=brain_bool)

        # run optic on the heatmap
        centerline_filename = sct.add_suffix(fname_heatmap, "_ctr")
        heatmap2optic(fname_heatmap=fname_heatmap_nii,
                      lambda_value=7 if contrast_type == 't2s' else 1,
                      fname_out=centerline_filename,
                      z_max=z_max if brain_bool else None)

    elif algo == 'viewer':
        # the centerline file name is the one returned by extract_centerline
        fname_labels_viewer = _call_viewer_centerline(fname_in=image_fname)
        centerline_filename = extract_centerline(fname_labels_viewer, remove_temp_files=True,
                                                 algo_fitting='nurbs', nurbs_pts_number=8000)

    elif algo == 'manual':
        centerline_filename = sct.add_suffix(image_fname, "_ctr")
        image_manual_centerline = Image(centerline_fname)
        # Re-orient and Re-sample the manual centerline
        image_centerline_reoriented = msct_image.change_orientation(image_manual_centerline,
                                                                    'RPI').save(centerline_filename)
        input_resolution = image_centerline_reoriented.dim[4:7]
        new_resolution = 'x'.join(['0.5', '0.5', str(input_resolution[2])])
        spinalcordtoolbox.resample.nipy_resample.resample_file(centerline_filename, centerline_filename,
                                                               new_resolution, 'mm', 'linear', verbose=0)

    else:
        sct.log.error('The parameter "-centerline" is incorrect. Please try again.')
        sys.exit(1)

    if bool_2d:
        # keep only the first of the two duplicated slices
        from sct_image import split_data
        im_split_lst = split_data(Image(centerline_filename), dim=2)
        im_split_lst[0].save(centerline_filename)

    return centerline_filename
def segment_3d(model_fname, contrast_type, fname_in, fname_out):
    """Perform segmentation with 3D convolutions.

    The volume is processed in consecutive slabs along axis 2, each containing the number of slices of the
    contrast's patch size. The last slab is zero-padded when the number of slices is not a multiple of that
    size. Empty slabs are skipped. Each predicted slab is thresholded at 0.5 and post-processed slice by slice.
    The binary result (uint8) is written to ``fname_out``.

    :param model_fname: trained 3D model file.
    :param contrast_type: one of 't2', 't2s', 't1'.
    :param fname_in: input image file name.
    :param fname_out: output segmentation file name.
    """
    from spinalcordtoolbox.deepseg_sc.cnn_models_3d import load_trained_model

    patch_params_by_contrast = {
        't2': {'size': (64, 64, 48), 'mean': 65.8562, 'std': 59.7999},
        't2s': {'size': (96, 96, 48), 'mean': 87.0212, 'std': 64.425},
        't1': {'size': (64, 64, 48), 'mean': 88.5001, 'std': 66.275},
    }

    # load 3d model
    seg_model = load_trained_model(model_fname)

    im = Image(fname_in)
    out = msct_image.zeros_like(im, dtype=np.uint8)

    # segment the spinal cord, one slab of slices at a time
    params = patch_params_by_contrast[contrast_type]
    slab_depth = params['size'][2]
    n_slices = im.data.shape[2]
    slab_starts = list(range(0, n_slices, slab_depth))
    for z_start in slab_starts:
        is_last_slab = z_start == slab_starts[-1]
        if is_last_slab:
            # deal with instances where the number of slices is not a multiple of the slab depth
            n_valid = n_slices - z_start
            slab = np.zeros(params['size'])
            slab[:, :, :n_valid] = im.data[:, :, z_start:]
        else:
            n_valid = slab_depth
            slab = im.data[:, :, z_start:z_start + slab_depth]

        # an empty slab could occur after a brain detection: leave the output at zero there
        if not np.any(slab):
            continue

        slab_norm = _normalize_data(slab, params['mean'], params['std'])
        proba = seg_model.predict(np.expand_dims(np.expand_dims(slab_norm, 0), 0), batch_size=BATCH_SIZE)
        seg_bin = (proba > 0.5).astype(int)[0, 0, :, :, :]

        # slice-wise post-processing, each slice guided by the centre of mass of the previous one
        x_com, y_com = None, None
        for iz in range(slab_depth):
            slice_pp = post_processing_slice_wise(seg_bin[:, :, iz], x_com, y_com)
            seg_bin[:, :, iz] = slice_pp
            x_com, y_com = center_of_mass(slice_pp)
            x_com, y_com = np.round(x_com), np.round(y_com)

        if is_last_slab:
            out.data[:, :, z_start:] = seg_bin[:, :, :n_valid]
        else:
            out.data[:, :, z_start:z_start + slab_depth] = seg_bin

    out.save(fname_out)
def merge_images(list_fname_src, fname_dest, list_fname_warp, param):
    """
    Merge multiple source images onto destination space. All images are warped to the destination space and then
    added. To deal with overlap during merging (e.g. one voxel in destination image is shared with two input images),
    the resulting voxel is divided by the sum of the partial volume of each image. For example, if src(x,y,z)=1 is
    mapped to dest(i,j,k) with a partial volume of 0.5 (because destination voxel is bigger), then its value after
    linear interpolation will be 0.5. To account for partial volume, the resulting voxel will be:
    dest(i,j,k) = 0.5*0.5/0.5 = 0.5.
    Now, if two voxels overlap in the destination space, let's say: src(x,y,z)=1 and src2'(x',y',z')=1, then the
    resulting value will be: dest(i,j,k) = (0.5*0.5 + 0.5*0.5) / (0.5+0.5) = 0.5. So this function acts like a
    weighted average operator, only in destination voxels that share multiple source voxels.

    Voxels that receive no partial volume at all end up at 0.

    Parameters
    ----------
    list_fname_src: list of source image file names
    fname_dest: destination image file name (defines the output space)
    list_fname_warp: warping field (src -> dest) for each source, in the same order as list_fname_src
    param: object providing interp, verbose, almost_zero, fname_out and rm_tmp

    Returns
    -------
    None. The merged image is written to param.fname_out.
    """
    import os

    # create temporary folder; all intermediate files are written inside it
    path_tmp = tmp_create()

    # get dimensions of destination file
    nii_dest = Image(fname_dest)
    shape_dest = [nii_dest.dim[0], nii_dest.dim[1], nii_dest.dim[2]]

    # running sums of (value * partial volume) and of partial volume, accumulated across source images
    sum_weighted = np.zeros(shape_dest)
    sum_partial_volume = np.zeros(shape_dest)

    # loop across files
    for i_file, fname_src in enumerate(list_fname_src):
        fname_warp = list_fname_warp[i_file]
        fname_src_template = os.path.join(path_tmp, 'src_' + str(i_file) + '_template.nii.gz')
        fname_src_bin = os.path.join(path_tmp, 'src_' + str(i_file) + 'native_bin.nii.gz')
        fname_src_pv = os.path.join(path_tmp, 'src_' + str(i_file) + '_template_partialVolume.nii.gz')

        # apply transformation src --> dest
        sct_apply_transfo.main(argv=[
            '-i', fname_src,
            '-d', fname_dest,
            '-w', fname_warp,
            '-x', param.interp,
            '-o', fname_src_template,
            '-v', str(param.verbose)])

        # create binary mask from input file by assigning one to all non-null voxels
        sct_maths.main(argv=[
            '-i', fname_src,
            '-bin', str(param.almost_zero),
            '-o', fname_src_bin])

        # apply transformation to binary mask to compute partial volume
        sct_apply_transfo.main(argv=[
            '-i', fname_src_bin,
            '-d', fname_dest,
            '-w', fname_warp,
            '-x', param.interp,
            '-o', fname_src_pv])

        # open data (as float64 to avoid integer overflow in the product)
        data_src = np.asarray(Image(fname_src_template).data, dtype=np.float64)
        partial_volume = np.asarray(Image(fname_src_pv).data, dtype=np.float64)
        sum_weighted += data_src * partial_volume
        sum_partial_volume += partial_volume

    # merge files using partial volume information (and convert nan resulting from division by zero to zeros)
    with np.errstate(divide='ignore', invalid='ignore'):
        data_merge = np.divide(sum_weighted, sum_partial_volume)
    data_merge = np.nan_to_num(data_merge)

    # write result in file
    nii_dest.data = data_merge
    nii_dest.save(param.fname_out)

    # remove temporary folder
    if param.rm_tmp:
        rmtree(path_tmp)
def _mosaic_axial_slice(fname_im, folder):
    """Return the mid axial slice of fname_im, cropped (30 x 30) around the cord segmentation found in folder."""
    fname_seg = glob.glob(os.path.join(folder, 'sub' + seg_string))[0]
    # workaround to save some time
    im_rpi = Image(fname_im).change_orientation('RPI')
    seg_rpi = Image(fname_seg).change_orientation('RPI')
    z_mid = int(float(im_rpi.dim[2]) // 2)
    hdr_im = nib.nifti1.Nifti1Image(im_rpi.data[:, :, z_mid], affine).header
    hdr_seg = nib.nifti1.Nifti1Image(seg_rpi.data[:, :, z_mid], affine).header
    im_mid = Image(im_rpi.data[:, :, z_mid], hdr=hdr_im, dim=hdr_im.get_data_shape())
    seg_mid = Image(seg_rpi.data[:, :, z_mid], hdr=hdr_seg, dim=hdr_seg.get_data_shape())
    qc_axial = qcslice.Axial([im_mid, seg_mid])
    centers_x, centers_y = qc_axial.get_center()  # find seg center of mass
    slice_mid = qc_axial.get_slice(qc_axial._images[0].data, 0)  # get the mid slice
    # crop image around SC seg
    return qc_axial.crop(slice_mid, int(centers_x[0]), int(centers_y[0]), 30, 30)


def _mosaic_sagittal_slice(fname_im):
    """Return the mid sagittal slice of fname_im (RSP orientation, made isotropic in the sagittal plane)."""
    im_rsp = Image(fname_im).change_orientation('RSP')
    if not np.isclose(im_rsp.dim[5], im_rsp.dim[6]):  # in case data is anisotropic
        im_rsp = resample_nib(im_rsp.copy(),
                              new_size=[im_rsp.dim[4], im_rsp.dim[5], im_rsp.dim[5]],
                              new_size_type='mm')
    return im_rsp.data[int(im_rsp.dim[0] // 2), :, :]


def main():
    """Collect the mid slice of every matching image under i_folder and save them as one or more mosaics."""
    slices = []
    for folder, _, _ in os.walk(i_folder):
        # prefix 'sub': to prevent from fetching warp files
        for fname_im in glob.glob(os.path.join(folder, 'sub' + im_string)):
            print('\nLoading: ' + fname_im)
            if plane == 'ax':
                slice_mid = _mosaic_axial_slice(fname_im, folder)
            else:
                slice_mid = _mosaic_sagittal_slice(fname_im)

            # histogram equalization using CLAHE, then scale intensities of all subjects to a common range
            slice_cur = scale_intensity(equalized(slice_mid, winsize))
            # resize all slices with the shape of the first loaded slice
            if slices:
                slice_cur = resize(slice_cur, slices[0].shape, anti_aliasing=True)
            slices.append(slice_cur)

    # create a new Image object containing the samples to display
    stack = np.stack(slices, axis=-1)
    hdr_stack = nib.nifti1.Nifti1Image(stack, affine).header
    img = Image(stack, hdr=hdr_stack, dim=hdr_stack.get_data_shape())

    n_slices = img.data.shape[2]
    per_mosaic = nb_column * nb_row
    n_mosaics = int(np.ceil(float(n_slices) / per_mosaic))
    stem, ext = os.path.splitext(o_fname)
    for i_mosaic in range(n_mosaics):
        fname_out = o_fname if n_mosaics == 1 else stem + '_' + str(i_mosaic).zfill(3) + ext
        print('\nCreating: ' + fname_out)

        # create mosaic
        start = i_mosaic * per_mosaic
        stop = min(start + per_mosaic, n_slices)
        mosaic = get_mosaic(img.data[:, :, start:stop], nb_column, nb_row)

        # save mosaic
        plt.figure()
        plt.subplot(1, 1, 1)
        plt.axis("off")
        plt.imshow(mosaic, interpolation='bilinear', cmap='gray', aspect='equal')
        plt.savefig(fname_out, dpi=300, bbox_inches='tight', pad_inches=0)
        plt.close()
def main(args=None):
    """
    Smooth the spinal cord along its centerline.

    The image is straightened using the centerline, smoothed with sct_maths, then warped back to the curved space.
    Voxels that are zero after the back-warp are replaced by the input values. Straightening warping fields are
    cached in the current directory so that they can be reused on the next run.

    :param args: list of command-line arguments. If None, sys.argv[1:] is used.
    """
    # Initialization
    param = Param()
    start_time = time.time()

    parser = get_parser()
    # honor the explicit argument list (previously ignored in favor of sys.argv)
    if args is None:
        args = sys.argv[1:]
    arguments = parser.parse(args)
    fname_anat = arguments['-i']
    fname_centerline = arguments['-s']
    if '-smooth' in arguments:
        sigma = arguments['-smooth']
    if '-param' in arguments:
        param.update(arguments['-param'])
    if '-r' in arguments:
        remove_temp_files = int(arguments['-r'])
    verbose = int(arguments.get('-v'))
    sct.init_sct(log_level=verbose, update=True)  # Update log level

    # Display arguments
    sct.printv('\nCheck input arguments...')
    sct.printv(' Volume to smooth .................. ' + fname_anat)
    sct.printv(' Centerline ........................ ' + fname_centerline)
    sct.printv(' Sigma (mm) ........................ ' + str(sigma))
    sct.printv(' Verbose ........................... ' + str(verbose))

    # Check that input is 3D:
    from spinalcordtoolbox.image import Image
    nx, ny, nz, nt, px, py, pz, pt = Image(fname_anat).dim
    dim = 4  # by default, will be adjusted later
    if nt == 1:
        dim = 3
    if nz == 1:
        dim = 2
    if dim == 4:
        sct.printv('WARNING: the input image is 4D, please split your image to 3D before smoothing spinalcord using :\n'
                   'sct_image -i ' + fname_anat + ' -split t -o ' + fname_anat, verbose, 'warning')
        sct.printv('4D images not supported, aborting ...', verbose, 'error')

    # Extract path/file/extension
    path_anat, file_anat, ext_anat = sct.extract_fname(fname_anat)
    path_centerline, file_centerline, ext_centerline = sct.extract_fname(fname_centerline)

    path_tmp = sct.tmp_create(basename="smooth_spinalcord", verbose=verbose)

    # Copying input data to tmp folder
    sct.printv('\nCopying input data to tmp folder and convert to nii...', verbose)
    sct.copy(fname_anat, os.path.join(path_tmp, "anat" + ext_anat))
    sct.copy(fname_centerline, os.path.join(path_tmp, "centerline" + ext_centerline))

    # go to tmp folder
    curdir = os.getcwd()
    os.chdir(path_tmp)
    try:
        # convert to nii format
        convert('anat' + ext_anat, 'anat.nii')
        convert('centerline' + ext_centerline, 'centerline.nii')

        # Change orientation of the input image into RPI
        sct.printv('\nOrient input volume to RPI orientation...')
        fname_anat_rpi = msct_image.Image("anat.nii") \
            .change_orientation("RPI", generate_path=True) \
            .save() \
            .absolutepath

        # Change orientation of the input image into RPI
        sct.printv('\nOrient centerline to RPI orientation...')
        fname_centerline_rpi = msct_image.Image("centerline.nii") \
            .change_orientation("RPI", generate_path=True) \
            .save() \
            .absolutepath

        # Straighten the spinal cord
        # straighten segmentation
        sct.printv('\nStraighten the spinal cord using centerline/segmentation...', verbose)
        cache_sig = sct.cache_signature(input_files=[fname_anat_rpi, fname_centerline_rpi],
                                        input_params={"x": "spline"})
        cachefile = os.path.join(curdir, "straightening.cache")
        cached_files = ['warp_curve2straight.nii.gz', 'warp_straight2curve.nii.gz', 'straight_ref.nii.gz']
        if sct.cache_valid(cachefile, cache_sig) and \
                all(os.path.isfile(os.path.join(curdir, fname)) for fname in cached_files):
            # if they exist, copy them into current folder
            sct.printv('Reusing existing warping field which seems to be valid', verbose, 'warning')
            for fname in cached_files:
                sct.copy(os.path.join(curdir, fname), fname)
            # apply straightening
            sct.run(['sct_apply_transfo', '-i', fname_anat_rpi, '-w', 'warp_curve2straight.nii.gz', '-d',
                     'straight_ref.nii.gz', '-o', 'anat_rpi_straight.nii', '-x', 'spline'], verbose)
        else:
            sct.run(['sct_straighten_spinalcord', '-i', fname_anat_rpi, '-o', 'anat_rpi_straight.nii', '-s',
                     fname_centerline_rpi, '-x', 'spline', '-param', 'algo_fitting='+param.algo_fitting], verbose)
            sct.cache_save(cachefile, cache_sig)
            # move warping fields (and the straight reference required by the cache check above) locally,
            # to use caching next time
            for fname in cached_files:
                sct.copy(fname, os.path.join(curdir, fname))

        # Smooth the straightened image along z
        sct.printv('\nSmooth the straightened image...')
        sigma_smooth = ",".join([str(i) for i in sigma])
        sct_maths.main(args=['-i', 'anat_rpi_straight.nii', '-smooth', sigma_smooth, '-o',
                             'anat_rpi_straight_smooth.nii', '-v', '0'])

        # Apply the reversed warping field to get back the curved spinal cord
        sct.printv('\nApply the reversed warping field to get back the curved spinal cord...')
        sct.run(['sct_apply_transfo', '-i', 'anat_rpi_straight_smooth.nii', '-o',
                 'anat_rpi_straight_smooth_curved.nii', '-d', 'anat.nii', '-w', 'warp_straight2curve.nii.gz',
                 '-x', 'spline'], verbose)

        # replace zeroed voxels by original image (issue #937)
        sct.printv('\nReplace zeroed voxels by original image...', verbose)
        nii_smooth = Image('anat_rpi_straight_smooth_curved.nii')
        data_smooth = nii_smooth.data
        data_input = Image('anat.nii').data
        indzero = np.where(data_smooth == 0)
        data_smooth[indzero] = data_input[indzero]
        nii_smooth.data = data_smooth
        nii_smooth.save('anat_rpi_straight_smooth_curved_nonzero.nii')
    finally:
        # come back, even if one of the steps above failed
        os.chdir(curdir)

    # Generate output file
    sct.printv('\nGenerate output file...')
    sct.generate_output_file(os.path.join(path_tmp, "anat_rpi_straight_smooth_curved_nonzero.nii"),
                             file_anat + '_smooth' + ext_anat)

    # Remove temporary files
    if remove_temp_files == 1:
        sct.printv('\nRemove temporary files...')
        sct.rmtree(path_tmp)

    # Display elapsed time
    elapsed_time = time.time() - start_time
    sct.printv('\nFinished! Elapsed time: ' + str(int(np.round(elapsed_time))) + 's\n')

    sct.display_viewer_syntax([file_anat, file_anat + '_smooth'], verbose=verbose)
def main(argv=None):
    """
    Command-line entry point: segment MS lesions in the input image.

    Parses the command-line arguments, runs ``deep_segmentation_MSlesion`` on the
    input image and saves the outputs in the output folder (``-ofolder``):

    - the lesion segmentation, always (``<name>_lesionseg<ext>``);
    - the centerline labels picked in the viewer, when ``-centerline viewer`` is used
      (``<name>_labels-centerline<ext>``);
    - the centerline, when ``-v 2`` is used (``<name>_centerline<ext>``).

    :param argv: list of command-line arguments. When None or empty, ``--help`` is
        parsed instead, so the usage is displayed.
    """
    parser = get_parser()
    arguments = parser.parse_args(argv if argv else ['--help'])
    verbose = arguments.v
    set_global_loglevel(verbose=verbose)

    fname_image = arguments.i
    contrast_type = arguments.c
    ctr_algo = arguments.centerline
    output_folder = arguments.ofolder
    remove_temp_files = arguments.r

    # Whether the image is assumed to contain brain sections. An explicit -brain
    # value always wins. Otherwise the default depends on the contrast: brain
    # sections are assumed present, except for 't2s' and 't2_ax'.
    # (bool(None) is False, so the default must be decided explicitly here,
    # otherwise the contrast-dependent default could never be True.)
    if arguments.brain is None:
        brain_bool = contrast_type not in ['t2s', 't2_ax']
    else:
        brain_bool = bool(arguments.brain)

    if ctr_algo == 'file' and arguments.file_centerline is None:
        printv(
            'Please use the flag -file_centerline to indicate the centerline filename.',
            1, 'error')
        sys.exit(1)

    # Supplying a centerline file forces the 'file' centerline algorithm,
    # whatever -centerline says.
    if arguments.file_centerline is not None:
        manual_centerline_fname = arguments.file_centerline
        ctr_algo = 'file'
    else:
        manual_centerline_fname = None

    algo_config_stg = '\nMethod:'
    algo_config_stg += '\n\tCenterline algorithm: ' + str(ctr_algo)
    algo_config_stg += '\n\tAssumes brain section included in the image: ' + str(
        brain_bool) + '\n'
    printv(algo_config_stg)

    # Segment image (imports kept local, as in the original)
    from spinalcordtoolbox.image import Image
    from spinalcordtoolbox.deepseg_lesion.core import deep_segmentation_MSlesion

    im_image = Image(fname_image)
    im_seg, im_labels_viewer, im_ctr = deep_segmentation_MSlesion(
        im_image, contrast_type,
        ctr_algo=ctr_algo,
        ctr_file=manual_centerline_fname,
        brain_bool=brain_bool,
        remove_temp_files=remove_temp_files,
        verbose=verbose)

    # Output files are named after the input image: <output_folder>/<name><suffix><ext>
    fname_parts = extract_fname(fname_image)
    image_name, image_ext = fname_parts[1], fname_parts[2]

    def _output_path(suffix):
        """Return the absolute path of an output file named after the input image."""
        return os.path.abspath(
            os.path.join(output_folder, image_name + suffix + image_ext))

    # Save segmentation
    fname_seg = _output_path('_lesionseg')
    im_seg.save(fname_seg)

    if ctr_algo == 'viewer':
        # Save the centerline labels picked by the user in the viewer
        fname_labels = _output_path('_labels-centerline')
        im_labels_viewer.save(fname_labels)

    if verbose == 2:
        # Save the centerline (debug verbosity only)
        fname_ctr = _output_path('_centerline')
        im_ctr.save(fname_ctr)

    display_viewer_syntax([fname_image, fname_seg], colormaps=['gray', 'red'], opacities=['', '0.7'])