def compose_file_to_file_reg(source_txt, target_txt, output_txt, sever_switcher=('', '')):
    """Write every source/target combination as a registration pair row.

    Each output row is ``[source_img, target_img, source_label, target_label]``.
    Rows are ordered by source first, then by target, so every source row is
    paired with every target row.

    :param source_txt: txt file whose rows start with ``image_path, label_path``
    :param target_txt: txt file whose rows start with ``image_path, label_path``
    :param output_txt: destination txt file
    :param sever_switcher: arguments unpacked into ``str.replace`` and applied
        to every path in the output (default is a no-op)
    :return: None
    """
    sources = read_txt_into_list(source_txt)
    targets = read_txt_into_list(target_txt)
    pairs = [
        [field.replace(*sever_switcher)
         for field in (src[0], tgt[0], src[1], tgt[1])]
        for src in sources
        for tgt in targets
    ]
    write_list_into_txt(output_txt, pairs)
def compute_warped_image_label(img_label_txt_pth, phi_pth, phi_type, saving_pth):
    """Warp each image/label pair with every matching deformation file and save the results.

    For each row of ``img_label_txt_pth``, the files found by ``glob(phi_pth/phi_type)``
    whose file name starts with the image's file name are loaded and stacked into one
    batch. That batch warps the image (``spline_order=1``) and the label
    (``spline_order=0``). The outputs are saved to ``saving_pth`` with the source image
    as reference. Each output is named after its deformation file, with ``"_phi"``
    removed and ``"_warped"`` or ``"_label"`` appended.

    Images with no matching deformation file are skipped. Previously they made
    ``np.stack`` raise on an empty list.

    :param img_label_txt_pth: txt file whose rows start with ``image_path, label_path``
    :param phi_pth: folder holding the deformation files
    :param phi_type: glob pattern, relative to ``phi_pth``, selecting the deformation files
    :param saving_pth: output folder for warped images and labels
    :return: None
    """
    img_label_pth_list = read_txt_into_list(img_label_txt_pth)
    phi_pth_list = glob(os.path.join(phi_pth, phi_type))

    def read_array(pth):
        return sitk.GetArrayFromImage(sitk.ReadImage(pth))

    # Read one image/label pair per iteration instead of loading the whole
    # list into memory up front.
    for row in img_label_pth_list:
        img_pth, label_pth = row[0], row[1]
        fname = get_file_name(img_pth)
        # Deformation files are matched to the image by file-name prefix.
        phi_sub_list = [pth for pth in phi_pth_list if get_file_name(pth).startswith(fname)]
        num_aug = len(phi_sub_list)
        if num_aug == 0:
            # Nothing to warp; np.stack would raise on an empty list.
            continue

        # One copy of the image/label per deformation file: shape (num_aug, 1, *spatial).
        img = torch.Tensor(read_array(img_pth)[None][None]).repeat(num_aug, 1, 1, 1, 1)
        label = torch.Tensor(read_array(label_pth)[None][None]).repeat(num_aug, 1, 1, 1, 1)

        phi = np.stack([read_array(pth) for pth in phi_sub_list], 0)
        # Reverse axes 1..3 and move the last axis to position 1
        # (presumably the displacement-component axis -- confirm).
        phi = torch.Tensor(np.transpose(phi, (0, 4, 3, 2, 1)))

        sz = np.array(img.shape[2:])
        spacing = 1. / (sz - 1)
        phi, _ = resample_image(phi, spacing, [1, 3] + list(img.shape[2:]))

        warped_img = compute_warped_image_multiNC(img, phi, spacing, spline_order=1, zero_boundary=True)
        warped_label = compute_warped_image_multiNC(label, phi, spacing, spline_order=0, zero_boundary=True)

        base_names = [get_file_name(pth).replace("_phi", "") for pth in phi_sub_list]
        save_image_with_given_reference(warped_img, [img_pth] * num_aug, saving_pth,
                                        [name + '_warped' for name in base_names])
        save_image_with_given_reference(warped_label, [img_pth] * num_aug, saving_pth,
                                        [name + '_label' for name in base_names])
def do_segmentation_eval(args, segmentation_file_list):
    """
    set running env and run the task

    Creates ``args.task_output_path``, builds the data/task managers with
    ``init_test_env``, and overrides the gpu id and, if given, the model path.
    It then saves the forced test settings and runs the task from
    ``cur_task_setting.json`` (plus ``cur_data_setting.json`` when a data
    manager exists).

    :param args: the parsed arguments; reads task_output_path, setting_folder_path,
        file_txt_path (may be None), gpu_id and model_path
    :param segmentation_file_list: list of segmentation file list, [image_list, label_list]
    :return: None
    :raises AssertionError: if ``args.model_path`` is given but is not a file
    """
    task_output_path = args.task_output_path
    os.makedirs(task_output_path, exist_ok=True)
    setting_folder_path = args.setting_folder_path
    file_txt_path = args.file_txt_path

    # An optional "file_name_list.txt" may sit next to "file_path_list.txt".
    # file_txt_path is None when images were passed directly on the command line.
    fname_list = None
    if file_txt_path is not None:
        fname_txt_path = file_txt_path.replace("file_path_list.txt", "file_name_list.txt")
        if os.path.isfile(fname_txt_path):
            fname_list = read_txt_into_list(fname_txt_path)

    dm, tsm = init_test_env(setting_folder_path, task_output_path,
                            segmentation_file_list, fname_list)
    tsm.task_par['tsk_set']['gpu_ids'] = args.gpu_id

    model_path = args.model_path
    if model_path is not None:
        # str.format, not format_map: format_map expects a mapping and would
        # raise ValueError instead of showing this message.
        assert os.path.isfile(model_path), "the model {} not exist".format(model_path)
        tsm.task_par['tsk_set']['model_path'] = model_path

    force_test_setting(dm, tsm, task_output_path)
    dm_json_path = os.path.join(
        task_output_path, 'cur_data_setting.json') if dm is not None else None
    tsm_json_path = os.path.join(task_output_path, 'cur_task_setting.json')
    run_one_task(tsm_json_path, dm_json_path)
def random_sample_for_oai_inter_txt(txt_path, num_patient, num_pair_per_patient, output_path, mod, switcher):
    """Sample inter-subject registration pairs from a pair txt file.

    Process:

    1. Shuffle the rows of ``txt_path`` (``[src_img, tgt_img, src_label, tgt_label]``).
    2. Walk the shuffled rows, keeping up to ``int(num_patient / 2)`` distinct rows
       whose source path contains ``mod``.
    3. Split each kept row into two ``[img, label]`` entries.
    4. Pair every entry with ``num_pair_per_patient`` other entries drawn at random
       without replacement.

    The previous loop nested ``while`` inside ``for`` with a fixed index. It
    either duplicated the first row or, when that row lacked ``mod``, looped
    forever. Fewer matching rows than requested now simply yields fewer samples.

    :param txt_path: input pair txt file
    :param num_patient: number of ``[img, label]`` entries to collect (two per kept row)
    :param num_pair_per_patient: number of targets drawn for each entry
    :param output_path: destination txt file
    :param mod: substring that a row's source path must contain to be kept
    :param switcher: arguments unpacked into ``str.replace`` on every output path
    :return: None
    :raises ValueError: from ``random.sample`` if fewer than
        ``num_pair_per_patient + 1`` entries were collected
    """
    pair_list = read_txt_into_list(txt_path)
    num_per_m = int(num_patient / 2)
    random.shuffle(pair_list)

    sampled_list = []
    num_s = 0
    for pair in pair_list:
        if num_s >= num_per_m:
            break
        if mod not in pair[0]:
            continue
        sampled_list.append([pair[0], pair[2]])
        sampled_list.append([pair[1], pair[3]])
        num_s += 1

    aug_pair_list = []
    for i, sampled_source in enumerate(sampled_list):
        # Any other collected entry may serve as a target.
        candidates = [k for k in range(len(sampled_list)) if k != i]
        for j in random.sample(candidates, num_pair_per_patient):
            sampled_target = sampled_list[j]
            aug_pair_list.append([sampled_source[0], sampled_target[0],
                                  sampled_source[1], sampled_target[1]])

    aug_pair_list = [[pth.replace(*switcher) for pth in pths] for pths in aug_pair_list]
    write_list_into_txt(output_path, aug_pair_list)
def transfer_txt_file_to_altas_txt_file(txt_path, atlas_path, output_txt, atlas_label_path, sever_switcher=("", "")):
    """Build image-to-atlas and atlas-to-image pair rows from an image/label txt file.

    Every path read from ``txt_path`` is first rewritten with
    ``str.replace(*sever_switcher)``. The atlas paths are not rewritten. The
    output contains, for each input row:

    * ``[img, atlas, label, atlas_label]``, for all rows first, then
    * ``[atlas, img, atlas_label, label]``, for all rows.

    Label columns are kept in the output.

    :return: None
    """
    rows = [[pth.replace(*sever_switcher) for pth in pths]
            for pths in read_txt_into_list(txt_path)]
    to_atlas, from_atlas = [], []
    for row in rows:
        img, label = row[0], row[1]
        to_atlas.append([img, atlas_path, label, atlas_label_path])
        from_atlas.append([atlas_path, img, atlas_label_path, label])
    write_list_into_txt(output_txt, to_atlas + from_atlas)
def get_pair_txt_for_oai_reg_net(train_txt_path, warped_folder, warped_type, num_train, output_txt):
    """Sample training pairs per patient from original and warped images.

    Files are grouped by the first ``"_"``-separated token of their file name
    (the patient id). Both paths of each training pair are added
    ``extra_weight`` times, so originals are drawn more often than warped
    files. Every warped file in ``warped_folder`` matching ``warped_type`` is
    added once.

    For each patient, ``int(num_train / num_patients)`` pairs are then drawn
    by rejection sampling. A pair is rejected when both files share a name or
    share the second ``"_"`` token.

    :param train_txt_path: txt file with training pair rows
    :param warped_folder: folder holding the warped files
    :param warped_type: glob pattern, relative to ``warped_folder``
    :param num_train: total number of pairs to produce (split evenly across patients)
    :param output_txt: destination txt file
    :return: None
    :raises ValueError: if a patient needs samples but has no two files with
        different second name tokens. Rejection sampling would otherwise loop forever.
    """
    train_pair_list = read_txt_into_list(train_txt_path)
    warped_file_list = glob(os.path.join(warped_folder, warped_type))

    def patient_id(pth):
        return get_file_name(pth).split("_")[0]

    def second_token(pth):
        return get_file_name(pth).split("_")[1]

    name_set = {patient_id(pair[0]) for pair in train_pair_list}
    name_file_dict = {name: [] for name in name_set}
    extra_weight = 2
    for pair in train_pair_list:
        files = name_file_dict[patient_id(pair[0])]
        for _ in range(extra_weight):
            files.append(pair[0])
            files.append(pair[1])
    for file in warped_file_list:
        name_file_dict[patient_id(file)].append(file)

    num_per_patient = int(num_train / len(name_set))
    train_list = []
    for name, files in name_file_dict.items():
        # A valid pair exists only if at least two distinct second tokens are present.
        if num_per_patient > 0 and len({second_token(f) for f in files}) < 2:
            raise ValueError(
                "patient {} has no two files with different name tokens; "
                "cannot sample valid pairs".format(name))
        num_sample = 0
        while num_sample < num_per_patient:
            pair = random.sample(files, 2)
            if get_file_name(pair[0]) == get_file_name(pair[1]) or second_token(pair[0]) == second_token(pair[1]):
                continue
            train_list.append(pair)
            num_sample += 1
    write_list_into_txt(output_txt, train_list)
def split_txt(input_txt, num_split, output_folder):
    """Split the rows of ``input_txt`` into ``num_split`` consecutive chunk files.

    Chunks are written as ``p0.txt``, ``p1.txt``, ... in ``output_folder``
    (created if missing). ``np.array_split`` is used so the row count need
    not be divisible by ``num_split``; ``np.split`` raised in that case. When
    the rows do not divide evenly, the earlier chunks get one extra row.
    Evenly divisible inputs produce the same chunks as before.

    :param input_txt: txt file to split
    :param num_split: number of chunks
    :param output_folder: destination folder
    :return: None
    """
    os.makedirs(output_folder, exist_ok=True)
    pairs = read_txt_into_list(input_txt)
    index_chunks = np.array_split(np.arange(len(pairs)), num_split)
    for i, indices in enumerate(index_chunks):
        split = [pairs[ind] for ind in indices]
        write_list_into_txt(os.path.join(output_folder, 'p{}.txt'.format(i)), split)
def get_test_file_for_brainstorm_color(test_path, transfer_path, output_txt):
    """Map each test image to its transferred (warped) counterpart, keeping its label.

    For every row ``[image, label, ...]`` of ``test_path``, this writes
    ``[transfer_path/atlas_image_<name>_test_iter_0_warped.nii.gz, label]``,
    where ``<name>`` is ``get_file_name(image)``.
    Example file name: atlas_image_9023193_image_test_iter_0_warped.nii.gz

    :return: None
    """
    rows = read_txt_into_list(test_path)
    mapped = []
    for row in rows:
        warped_name = "atlas_image_" + get_file_name(row[0]) + "_test_iter_0_warped.nii.gz"
        mapped.append([os.path.join(transfer_path, warped_name), row[1]])
    write_list_into_txt(output_txt, mapped)
def random_sample_from_txt(txt_path, num, output_path, switcher):
    """Randomly sample pair rows and add their reversed direction.

    When ``num > 0``, ``num`` rows are drawn without replacement. Each drawn
    row ``[s, t, ls, lt]`` contributes two rows: itself, then
    ``[t, s, lt, ls]``. Otherwise all rows are kept unchanged. Every path is
    finally rewritten with ``str.replace(*switcher)``.

    :return: None
    """
    pair_list = read_txt_into_list(txt_path)
    if num > 0:
        sampled_list = []
        for row in random.sample(pair_list, num):
            forward = [row[0], row[1], row[2], row[3]]
            backward = [row[1], row[0], row[3], row[2]]
            sampled_list += [forward, backward]
    else:
        sampled_list = pair_list
    rewritten = [[pth.replace(*switcher) for pth in row] for row in sampled_list]
    write_list_into_txt(output_path, rewritten)
# Script fragment: `dataset` and `divided_ratio` are defined outside this view.
dataset.set_divided_ratio(divided_ratio)
# NOTE(review): presumably the target volume size after resizing -- confirm against the dataset class.
dataset.img_after_resize = (200,240,200)
dataset.prepare_data()

# Build test-time augmentation pairs: test images paired with training images.
from easyreg.aug_utils import gen_post_aug_pair_list
# Hard-coded, machine-specific paths.
train_file_path = "/playpen-raid1/zyshen/data/brain_35/corrected/train/file_path_list.txt"
test_file_path = "/playpen-raid1/zyshen/data/brain_35/corrected/test/file_path_list.txt"
train_name_path = "/playpen-raid1/zyshen/data/brain_35/corrected/train/file_name_list.txt"
test_name_path = "/playpen-raid1/zyshen/data/brain_35/corrected/test/file_name_list.txt"
# NOTE(review): the two output paths are not used in the visible part; presumably written later.
output_file_path = "/playpen-raid1/zyshen/data/brain_35/corrected/test_aug_path_list.txt"
output_name_path = "/playpen-raid1/zyshen/data/brain_35/corrected/test_aug_name_list.txt"
train_path_list = read_txt_into_list(train_file_path)
test_path_list = read_txt_into_list(test_file_path)
train_name_list = read_txt_into_list(train_name_path)
test_name_list = read_txt_into_list(test_name_path)
# Test rows are indexed as [image, label].
test_img_path_list = [path[0] for path in test_path_list]
test_label_path_list = [path[1] for path in test_path_list]
# Training rows may be [image, label] lists or bare image paths without labels.
if isinstance(train_path_list[0],list):
    train_img_path_list = [path[0] for path in train_path_list]
    train_label_path_list = [path[1] for path in train_path_list]
else:
    train_img_path_list = train_path_list
    train_label_path_list = None
# NOTE(review): pair_num_limit=-1 presumably means "no overall limit" and per_num_limit caps
# pairs per test image -- confirm in easyreg.aug_utils.gen_post_aug_pair_list.
img_pair_list, pair_name_list = gen_post_aug_pair_list(test_img_path_list,train_img_path_list, test_fname_list=test_name_list,train_fname_list=train_name_list, test_label_path_list=test_label_path_list,train_label_path_list=train_label_path_list, pair_num_limit=-1, per_num_limit=5)
# Drop the first element of every pair-name entry.
pair_name_list = [pair_name[1:] for pair_name in pair_name_list]
# Script fragment: resolve registration inputs from either a pair txt file or explicit lists.
print(args)
pair_txt_path = args.pair_txt_path
pair_name_txt_path = args.pair_name_txt_path
source_list = args.source_list
target_list = args.target_list
lsource_list = args.lsource_list
ltarget_list = args.ltarget_list
pair_name_list = args.pair_name_list
# Exactly one input mode must be used: pair_txt_path XOR source/target lists.
assert pair_txt_path is not None or source_list is not None, "either pair_txt_path or source/target_list should be provided"
assert pair_txt_path is None or source_list is None, " pair_txt_path and source/target_list cannot be both provided"
if pair_txt_path is not None:
    # Overrides the (None) list arguments with the contents of the pair file.
    source_list, target_list, lsource_list, ltarget_list = loading_img_list_from_files(
        pair_txt_path)
if pair_name_txt_path is not None:
    pair_name_list = read_txt_into_list(pair_name_txt_path)
# Consistency checks: image lists and (optional) label lists must align element-wise.
if source_list is not None:
    assert len(source_list) == len(
        target_list
    ), "the source and target list should be the same length"
if lsource_list is not None:
    assert len(lsource_list) == len(
        source_list
    ), "the lsource and source list should be the same length"
    assert len(lsource_list) == len(
        ltarget_list
    ), " the lsource and ltarget list should be the same length"
# Column-oriented: [sources, targets, source labels, target labels].
registration_pair_list = [
    source_list, target_list, lsource_list, ltarget_list
]
# Script fragment: tail of the segmentation-eval argument parser and entry point.
parser.add_argument('-m', "--model_path", required=False, default=None, help='the path of trained model')
parser.add_argument('-g', "--gpu_id", required=False, type=int, default=0, help='gpu_id to use')
args = parser.parse_args()
print(args)
file_txt_path = args.file_txt_path
image_list = args.image_list
limage_list = args.limage_list
image_label_list = []
# Exactly one input mode must be used: file_txt_path XOR image_list.
assert file_txt_path is not None or image_list is not None, "either file_txt_path or source/target_list should be provided"
assert file_txt_path is None or image_list is None, " file_txt_path and source/target_list cannot be both provided"
if file_txt_path is not None:
    image_label_list = read_txt_into_list(file_txt_path)
# NOTE(review): if image_list is given without limage_list, image_label_list stays [] and the
# images are silently ignored -- confirm this is intended.
if limage_list is not None:
    assert len(image_list) == len(
        limage_list
    ), "the image_list and limage_list should be the same length"
    # Column-oriented [image_list, label_list]; read_txt_into_list above presumably returns
    # one row per image instead -- confirm init_test_env accepts both layouts.
    image_label_list = [image_list, limage_list]
do_segmentation_eval(args, image_label_list)
def get_file_txt_from_pair_txt(txt_path, output_path):
    """Flatten pair rows ``[s, t, ls, lt]`` into ``[img, label]`` rows.

    All source rows ``[s, ls]`` are written first, followed by all target rows
    ``[t, lt]``.

    :return: None
    """
    pairs = read_txt_into_list(txt_path)
    source_rows = [[row[0], row[2]] for row in pairs]
    target_rows = [[row[1], row[3]] for row in pairs]
    write_list_into_txt(output_path, source_rows + target_rows)
def remove_label_info(pair_path_txt, output_txt):
    """Keep only the first two columns (the two image paths) of every pair row.

    :return: None
    """
    stripped = []
    for row in read_txt_into_list(pair_path_txt):
        stripped.append([row[0], row[1]])
    write_list_into_txt(output_txt, stripped)
"--gpu_id", required=False, type=int, default=0, help='gpu_id to use')
# Script fragment (starts mid-call): segmentation-eval entry point that routes explicit
# image/label arguments through a temporary txt file.
args = parser.parse_args()
print(args)
file_txt_path = args.file_txt_path
image_list = args.image_list
limage_list = args.limage_list
image_label_list = []
# Exactly one input mode must be used: file_txt_path XOR image_list.
assert file_txt_path is not None or image_list is not None, "either file_txt_path or source/target_list should be provided"
assert file_txt_path is None or image_list is None, " file_txt_path and source/target_list cannot be both provided"
if file_txt_path is not None:
    image_label_list = read_txt_into_list(file_txt_path)
if limage_list is not None:
    assert len(image_list) == len(
        limage_list
    ), "the image_list and limage_list should be the same length"
    # Writes the pair into 'file_path_list.txt' in the current working directory and points
    # args.file_txt_path at it, presumably so do_segmentation_eval gets a real txt path.
    # NOTE(review): only image_list[0] / limage_list[0] are written; any further images are
    # dropped -- confirm single-image use is intended.
    with open('file_path_list.txt', 'w+') as f:
        f.write('{}\t{}'.format(image_list[0], limage_list[0]))
    args.file_txt_path = 'file_path_list.txt'
    image_label_list = read_txt_into_list('file_path_list.txt')
    args.image_list = None
    args.limage_list = None
# NOTE(review): image_list given without limage_list leaves image_label_list == [].
do_segmentation_eval(args, image_label_list)