def _get_slices_dim_list(pc_folder, file_list_txt):
    center_slices_x = []
    center_slices_y = []
    center_slices_z = []
    data_folder = DataFolder(pc_folder, file_list_txt)
    # show_pc_number is assumed to be a module-level constant giving the
    # number of scans (e.g. principal-component maps) to load.
    for file_idx in range(show_pc_number):
        data_folder.print_idx(file_idx)
        file_path = data_folder.get_file_path(file_idx)
        scan = ScanWrapper(file_path)
        slice_x, slice_y, slice_z = scan.get_center_slices()
        center_slices_x.append(slice_x)
        center_slices_y.append(slice_y)
        center_slices_z.append(slice_z)
    return center_slices_x, center_slices_y, center_slices_z
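# A minimal usage sketch for _get_slices_dim_list, not part of the original
# module. The folder, file list, and output paths below are hypothetical, and
# it assumes each center slice returned by ScanWrapper.get_center_slices() is
# a 2D array that matplotlib can display.
def _demo_plot_center_slices(pc_folder='./pc_maps', file_list_txt='./file_list.txt'):
    import matplotlib.pyplot as plt

    slices_x, slices_y, slices_z = _get_slices_dim_list(pc_folder, file_list_txt)
    for idx, (sx, sy, sz) in enumerate(zip(slices_x, slices_y, slices_z)):
        fig, axes = plt.subplots(1, 3)
        for ax, slice_2d, axis_name in zip(axes, (sx, sy, sz), ('x', 'y', 'z')):
            ax.imshow(slice_2d, cmap='gray')
            ax.set_title(f'Scan {idx}, center slice ({axis_name})')
            ax.axis('off')
        fig.savefig(f'./center_slices_{idx}.png')
        plt.close(fig)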
class GetLossBetweenFolder(AbstractParallelRoutine):
    def __init__(self, config, in_folder_1, in_folder_2, file_list_txt):
        super().__init__(config, in_folder_1, file_list_txt)
        self._in_data_folder_2 = DataFolder(in_folder_2, file_list_txt)
        self._nrmse_diff = []

    def get_nrmse(self):
        return self._nrmse_diff

    def print_file_list(self):
        file_list = self._in_data_folder.get_data_file_list()
        for idx in range(len(file_list)):
            print(f'The {idx}th file is {file_list[idx]}')

    def _run_single_scan(self, idx):
        in_file_1_path = self._in_data_folder.get_file_path(idx)
        in_file_2_path = self._in_data_folder_2.get_file_path(idx)
        in_img_1 = ScanWrapper(in_file_1_path).get_data()
        in_img_2 = ScanWrapper(in_file_2_path).get_data()
        nrmse = compare_nrmse(np.abs(in_img_1), np.abs(in_img_2))
        self._nrmse_diff.append(nrmse)
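# A minimal usage sketch for GetLossBetweenFolder, not part of the original
# module. The config dict, folder paths, and file list are hypothetical, and
# it drives _run_single_scan directly in a loop rather than going through the
# scheduling of AbstractParallelRoutine, which is defined elsewhere in the repo.
def _demo_nrmse_between_folders():
    config = {'num_processes': 1}
    loss_runner = GetLossBetweenFolder(config,
                                       './recon_folder_1',
                                       './recon_folder_2',
                                       './file_list.txt')
    loss_runner.print_file_list()
    num_files = len(loss_runner._in_data_folder.get_data_file_list())
    for idx in range(num_files):
        loss_runner._run_single_scan(idx)
    nrmse_list = loss_runner.get_nrmse()
    print(f'Mean NRMSE over {len(nrmse_list)} scan pairs: {np.mean(nrmse_list):.6f}')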
class AverageScans:
    def __init__(self, config, in_folder=None, data_file_txt=None, in_data_folder_obj=None):
        self._data_folder = None
        if in_data_folder_obj is None:
            self._data_folder = DataFolder(in_folder, data_file_txt)
        else:
            self._data_folder = in_data_folder_obj
        self._standard_ref = ScanWrapper(self._data_folder.get_first_path())
        self._num_processes = config['num_processes']

    def get_average_image_union(self, save_path):
        im_shape = self._get_std_shape()
        average_union = np.zeros(im_shape)
        average_union.fill(np.nan)
        non_null_mask_count_image = np.zeros(im_shape)
        chunk_list = self._data_folder.get_chunks_list(self._num_processes)
        pool = Pool(processes=self._num_processes)

        print('Average in union')
        print('Step.1 Summation')
        image_average_union_result_list = [
            pool.apply_async(self._sum_images_union, (file_idx_chunk,))
            for file_idx_chunk in chunk_list
        ]
        for thread_idx in range(len(image_average_union_result_list)):
            result = image_average_union_result_list[thread_idx]
            result.wait()
            print(f'Thread with idx {thread_idx} / {len(image_average_union_result_list)} is completed')
            print('Adding to averaged_image...')
            averaged_image_chunk = result.get()
            average_union = self._add_image_union(average_union, averaged_image_chunk)
            print('Done.')

        print('Step.2 Non-nan counter')
        non_null_mask_count_result = [
            pool.apply_async(self._sum_non_null_count, (file_idx_chunk,))
            for file_idx_chunk in chunk_list
        ]
        for thread_idx in range(len(non_null_mask_count_result)):
            result = non_null_mask_count_result[thread_idx]
            result.wait()
            print(f'Thread with idx {thread_idx} / {len(non_null_mask_count_result)} is completed')
            print('Adding to averaged_image...')
            averaged_image_chunk = result.get()
            non_null_mask_count_image = np.add(non_null_mask_count_image, averaged_image_chunk)
            print('Done.')

        # Voxel-wise mean over the union of non-NaN regions; voxels never
        # covered by any scan keep their NaN value.
        average_union = np.divide(average_union,
                                  non_null_mask_count_image,
                                  out=average_union,
                                  where=non_null_mask_count_image > 0)
        self._standard_ref.save_scan_same_space(save_path, average_union)
        print('Done.')

    def _sum_images_union(self, chunk_list):
        print('Sum images, union non-null region. Loading images...')
        im_shape = self._get_std_shape()
        sum_image = np.zeros(im_shape)
        sum_image.fill(np.nan)
        for id_file in chunk_list:
            file_path = self._data_folder.get_file_path(id_file)
            self._data_folder.print_idx(id_file)
            im = nib.load(file_path)
            im_data = im.get_data()
            sum_image = self._add_image_union(sum_image, im_data)
        return sum_image

    def _sum_non_null_count(self, chunk_list):
        print('Count non-null per voxel. Loading images...')
        im_shape = self._get_std_shape()
        sum_image = np.zeros(im_shape)
        for id_file in chunk_list:
            file_path = self._data_folder.get_file_path(id_file)
            self._data_folder.print_idx(id_file)
            im = nib.load(file_path)
            im_data = im.get_data()
            sum_image = np.add(sum_image, 1, out=sum_image,
                               where=np.logical_not(np.isnan(im_data)))
        return sum_image

    def _get_std_shape(self):
        return self._standard_ref.get_data().shape

    @staticmethod
    def _add_image_inter(image1, image2):
        # Voxel-wise sum over the intersection of the non-NaN regions.
        return np.add(image1, image2,
                      out=np.full_like(image1, np.nan),
                      where=np.logical_not(np.logical_or(np.isnan(image1), np.isnan(image2))))

    @staticmethod
    def _add_image_union(image1, image2):
        # Voxel-wise sum over the union of the non-NaN regions.
        add_image = np.full_like(image1, np.nan)
        add_image[np.logical_not(np.logical_and(np.isnan(image1), np.isnan(image2)))] = 0
        add_image = np.add(add_image, image1, out=add_image,
                           where=np.logical_not(np.isnan(image1)))
        add_image = np.add(add_image, image2, out=add_image,
                           where=np.logical_not(np.isnan(image2)))
        return add_image

    @staticmethod
    def sum_non_null_count(file_list, in_folder):
        print('Count non-null per voxel. Loading images...')
        im_temp = nib.load(os.path.join(in_folder, file_list[0]))
        im_temp_data = im_temp.get_data()
        sum_image = np.zeros_like(im_temp_data)
        for id_file in range(len(file_list)):
            file_name = file_list[id_file]
            print('%s (%d/%d)' % (file_name, id_file, len(file_list)))
            file_path = os.path.join(in_folder, file_name)
            im = nib.load(file_path)
            im_data = im.get_data()
            sum_image = np.add(sum_image, 1, out=sum_image,
                               where=np.logical_not(np.isnan(im_data)))
        return sum_image
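# A minimal usage sketch for AverageScans, not part of the original module.
# The input folder, file list, and output path are hypothetical.
# get_average_image_union writes the voxel-wise mean over the union of non-NaN
# regions to save_path, in the space of the first scan in the folder.
def _demo_average_scans():
    config = {'num_processes': 8}
    averager = AverageScans(config,
                            in_folder='./registered_scans',
                            data_file_txt='./file_list.txt')
    averager.get_average_image_union('./average_union.nii.gz')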
class ScanFolderConcatBatchReader(AbstractParallelRoutine):
    def __init__(self, config, in_ori_folder, in_jac_folder, batch_size, file_list_txt=None):
        super().__init__(config, in_ori_folder, file_list_txt)
        self._in_jac_folder = DataFolder(in_jac_folder, file_list_txt)
        self._ref_ori = ScanWrapper(self._in_data_folder.get_file_path(0))
        self._ref_jac = ScanWrapper(self._in_jac_folder.get_file_path(0))
        self._chunk_list = self._in_data_folder.get_chunks_list_batch_size(batch_size)
        self._data_matrix = []
        self._cur_idx = 0

    def read_data(self, idx_batch):
        self._reset_cur_idx()
        print(f'Reading scans from folder {self._in_data_folder.get_folder()}', flush=True)
        tic = time.perf_counter()
        cur_batch = self._chunk_list[idx_batch]
        self._init_data_matrix(len(cur_batch))
        self.run_non_parallel(cur_batch)
        toc = time.perf_counter()
        print(f'Done. {toc - tic:0.4f} (s)', flush=True)

    def num_batch(self):
        return len(self._chunk_list)

    def get_data_matrix(self):
        return self._data_matrix

    def save_flat_data(self, data_array, idx, out_folder):
        out_path_ori = os.path.join(out_folder, f'pc_ori_{idx}.nii.gz')
        out_path_jac = os.path.join(out_folder, f'pc_jac_{idx}.nii.gz')
        ori_data_flat = data_array[:self._ref_ori.get_number_voxel()]
        jac_data_flat = data_array[self._ref_ori.get_number_voxel():]
        self._ref_ori.save_scan_flat_img(ori_data_flat, out_path_ori)
        self._ref_jac.save_scan_flat_img(jac_data_flat, out_path_jac)

    def get_ref(self):
        return self._ref_ori

    def _run_single_scan(self, idx):
        in_ori_data = ScanWrapper(self._in_data_folder.get_file_path(idx)).get_data()
        in_jac_data = ScanWrapper(self._in_jac_folder.get_file_path(idx)).get_data()
        self._data_matrix[self._cur_idx, :self._ref_ori.get_number_voxel()] = convert_3d_2_flat(in_ori_data)
        self._data_matrix[self._cur_idx, self._ref_ori.get_number_voxel():] = convert_3d_2_flat(in_jac_data)
        self._cur_idx += 1

    def _init_data_matrix(self, num_sample):
        num_features = self._get_number_of_voxel()
        del self._data_matrix
        self._data_matrix = np.zeros((num_sample, num_features))

    def _get_number_of_voxel(self):
        return self._get_number_of_voxel_ori() + self._get_number_of_voxel_jac()

    def _get_number_of_voxel_ori(self):
        return self._ref_ori.get_number_voxel()

    def _get_number_of_voxel_jac(self):
        return self._ref_jac.get_number_voxel()

    def _reset_cur_idx(self):
        self._cur_idx = 0
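# A minimal usage sketch for ScanFolderConcatBatchReader, not part of the
# original module. The config dict, folders, file list, and batch size are
# hypothetical. Each batch is read into a matrix of shape
# (num_sample, num_voxel_ori + num_voxel_jac), e.g. as input to an
# incremental decomposition fitted batch by batch.
def _demo_concat_batch_reader():
    config = {'num_processes': 1}
    reader = ScanFolderConcatBatchReader(config,
                                         in_ori_folder='./ori_scans',
                                         in_jac_folder='./jac_maps',
                                         batch_size=10,
                                         file_list_txt='./file_list.txt')
    for idx_batch in range(reader.num_batch()):
        reader.read_data(idx_batch)
        batch_matrix = reader.get_data_matrix()
        print(f'Batch {idx_batch}: matrix shape {batch_matrix.shape}')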