def reg(args, state, fix_image_filename, moving_image_filename, iterations,
        show_text=False):
    """Register one image pair and save per-step R2NN visualisations.

    NOTE: ``reg`` is defined several times in this module; only the last
    definition is bound after import, so this version is shadowed unless the
    later ones are renamed or removed.

    For every recurrent step ``idx`` reported by ``eval_rnn`` these PNGs are
    written to ``args.o``:

    * ``displacement_sum_<idx>.png`` -- magnitude of ``displacement_pixel[idx]``
      with the step's ellipse overlaid,
    * ``displacement_local_<idx>.png`` -- magnitude of
      ``single_displacement[idx]``,
    * ``warped_loacl_input_<idx>.png`` -- ``warped_local_image[idx]``,
    * ``warped_loacl_output_<idx>.png`` -- ``warped_local_image[idx + 1]``
      with the step's ellipse overlaid,
    * ``diff_image_<idx>.png`` -- a constant (all-ones) canvas that carries
      the title text when ``show_text`` is set,

    followed by ``fixed_image.png``.

    Args:
        args: command-line namespace; reads ``gpu_id`` (negative selects the
            CPU) and ``o`` (output directory, created if missing).
        state: checkpoint dict with ``'args'`` (training namespace) and
            ``'model'`` (model state dict).
        fix_image_filename: path of the fixed image (read with SimpleITK).
        moving_image_filename: path of the moving image.
        iterations: passed unchanged as the first argument of ``eval_rnn``.
        show_text: if True, annotate the figures with step labels. This was
            previously a hard-coded local, and the default is unchanged.

    Raises:
        ValueError: if the checkpoint's model type or image loss is not
            supported.
    """
    args_state = state['args']
    image_size = [256, 256]
    gpu_id = args.gpu_id
    # "cuda:-1" is not a valid device string, so a negative id means CPU.
    device = th.device("cuda:" + str(gpu_id) if gpu_id >= 0 else "cpu")
    if gpu_id >= 0:
        th.cuda.set_device(gpu_id)

    if args_state.model == "R2NN":
        model = gru.GRU_Registration(image_size, 2, args=args_state, device=device)
    else:
        raise ValueError('model type {0} is not known'.format(args_state.model))
    model.eval()
    model.load_state_dict(state['model'])
    print("model loaded")
    if gpu_id >= 0:
        with th.cuda.device(gpu_id):
            model.cuda()

    if args_state.image_loss == "MSE":
        image_loss = il.MSE()
    else:
        # Fail early instead of hitting an undefined ``image_loss`` below.
        raise ValueError(
            'image loss {0} is not supported'.format(args_state.image_loss))

    grid = compute_grid([image_size[0], image_size[1]], device=device)

    os.makedirs(args.o, exist_ok=True)

    def load_normalized(filename):
        # Read as float32, add two leading singleton dims, move to ``device``,
        # standardise to zero mean / unit std and clamp to [-2, 2].
        image = sitk.ReadImage(filename, sitk.sitkFloat32)
        image = th.tensor(sitk.GetArrayFromImage(image)).squeeze().unsqueeze_(0).unsqueeze_(0)
        image = image.to(device=device)
        image = image - th.mean(image)
        image = image / th.std(image)
        image.clamp_(-2, 2)
        return image

    def show(image, **imshow_kwargs):
        # Draw ``image`` with every axis decoration hidden; return its axes.
        shown = plt.imshow(image, **imshow_kwargs)
        shown.axes.get_xaxis().set_visible(False)
        shown.axes.get_yaxis().set_visible(False)
        plt.axis('off')
        return shown.axes

    def add_ellipse(ax, pos, sigma, angle):
        # White, unfilled outline of one local transformation.
        ax.add_patch(Ellipse(pos, width=sigma[0], height=sigma[1], angle=angle,
                             edgecolor='white', facecolor='none', linewidth=2))

    def save(filename):
        # Save the current figure, tightly cropped, into args.o and close it.
        plt.savefig(os.path.join(args.o, filename), bbox_inches='tight', pad_inches=0)
        plt.close()

    def magnitude(field):
        # Euclidean norm over the two channels of field[0].
        return th.sqrt(field[0, 0, ...].pow(2) + field[0, 1, ...].pow(2))

    fixed_image = load_normalized(fix_image_filename)
    moving_image = load_normalized(moving_image_filename)

    (_, _, _, displacement_param, displacement_pixel,
     single_displacement, warped_local_image) = eval_rnn(
        iterations, model, fixed_image, moving_image, image_loss, grid)

    text_white = {'color': 'w', 'fontsize': 18}
    text_red = {'color': 'r', 'fontsize': 18}

    for idx, param in enumerate(displacement_param):
        # Ellipse geometry in pixel units (scaled by 255). The conversions
        # suggest param[0] holds the two shape sizes, param[2] the centre in
        # normalised [-1, 1] coordinates and param[3] an angle in radians.
        # This is inferred; confirm against GRU_Registration.
        sigma = 2 * param[0].squeeze().cpu().numpy() * 255
        pos = ((param[2].squeeze().cpu().numpy() + 1) / 2) * 255
        angle = -(param[3].cpu().numpy() * 180.0) / np.pi

        # Accumulated displacement up to this step, with its ellipse.
        ax = show(magnitude(displacement_pixel[idx]).cpu().squeeze().numpy(),
                  cmap='jet', vmax=0.08, vmin=0)
        add_ellipse(ax, pos, sigma, angle)
        if show_text:
            plt.text(8, 250, r"transformation sum: $t=" + str(idx) + "$", text_white)
        save("displacement_sum_" + str(idx) + ".png")

        # Displacement contributed by this step alone.
        show(magnitude(single_displacement[idx]).cpu().squeeze().numpy(),
             cmap='jet', vmax=0.03, vmin=0)
        if show_text:
            plt.text(32, 15, r"network output: $t=" + str(idx) + "$", text_white)
            plt.text(8, 250, r"transformation local: $t=" + str(idx) + "$", text_white)
        save("displacement_local_" + str(idx) + ".png")

        # Image fed into this step.
        show(warped_local_image[idx].cpu().squeeze().numpy(), cmap='gray')
        if show_text:
            if idx == 0:
                plt.text(16, 250, r"moving image", text_red)
            else:
                plt.text(16, 250, r"warped image: $t=" + str(idx - 1) + "$", text_red)
            plt.text(3, 15, r"input: $t=" + str(idx) + "$", text_red)
        save("warped_loacl_input_" + str(idx) + ".png")

        # Image produced by this step, with its ellipse.
        ax = show(warped_local_image[idx + 1].cpu().squeeze().numpy(), cmap='gray')
        add_ellipse(ax, pos, sigma, angle)
        if show_text:
            plt.text(16, 250, r"warped image: $t=" + str(idx) + "$", text_red)
        save("warped_loacl_output_" + str(idx) + ".png")

        # Constant all-ones canvas. np.ones_like replaces the former
        # ``.data.fill_(1)``, which overwrote warped_local_image[idx] in
        # place whenever that tensor already lived on the CPU.
        blank = np.ones_like(warped_local_image[idx].detach().cpu().squeeze().numpy())
        show(blank, cmap='gray')
        if show_text:
            plt.text(128, 150,
                     "Recurrent Registration\n" "Neural Networks for\n"
                     "Deformable Image\n" "Registration",
                     text_white, ha='center', wrap=True)
        save("diff_image_" + str(idx) + ".png")

    show(fixed_image.cpu().squeeze().numpy(), cmap='gray')
    if show_text:
        plt.text(64, 250, r"fixed image", text_red)
        plt.text(180, 15, r"network ", text_red)
    # save() also closes this last figure, which was previously left open.
    save("fixed_image.png")
def reg(args, state, fix_image_filename, moving_image_filename, iterations):
    """Register one image pair and plot the R2NN transformation parameters
    over the recurrent steps.

    NOTE: ``reg`` is defined several times in this module; only the last
    definition is bound after import, so this version is shadowed unless the
    later ones are renamed or removed.

    Three line plots are each exported with matplotlib2tikz and then shown
    interactively:

    * shape sizes -> ``/tmp/shape.tex``,
    * weights -> ``/tmp/weight.tex``,
    * rotation angles in degrees -> ``/tmp/test.tex``.

    Args:
        args: command-line namespace; reads ``gpu_id`` (negative selects the
            CPU) and ``o`` (output directory, created if missing).
        state: checkpoint dict with ``'args'`` (training namespace) and
            ``'model'`` (model state dict).
        fix_image_filename: path of the fixed image (read with SimpleITK).
        moving_image_filename: path of the moving image.
        iterations: passed unchanged as the first argument of ``eval_rnn``.

    Raises:
        ValueError: if the checkpoint's model type or image loss is not
            supported.
    """
    args_state = state['args']
    image_size = [256, 256]
    gpu_id = args.gpu_id
    # "cuda:-1" is not a valid device string, so a negative id means CPU.
    device = th.device("cuda:" + str(gpu_id) if gpu_id >= 0 else "cpu")
    if gpu_id >= 0:
        th.cuda.set_device(gpu_id)

    if args_state.model == "R2NN":
        model = gru.GRU_Registration(image_size, 2, args=args_state, device=device)
    else:
        raise ValueError('model type {0} is not known'.format(
            args_state.model))
    model.eval()
    model.load_state_dict(state['model'])
    print("model loaded")
    if gpu_id >= 0:
        with th.cuda.device(gpu_id):
            model.cuda()

    if args_state.image_loss == "MSE":
        image_loss = il.MSE()
    else:
        # Fail early instead of hitting an undefined ``image_loss`` below.
        raise ValueError(
            'image loss {0} is not supported'.format(args_state.image_loss))

    grid = compute_grid([image_size[0], image_size[1]], device=device)

    os.makedirs(args.o, exist_ok=True)

    def load_normalized(filename):
        # Read as float32, add two leading singleton dims, move to ``device``,
        # standardise to zero mean / unit std and clamp to [-2, 2].
        image = sitk.ReadImage(filename, sitk.sitkFloat32)
        image = th.tensor(sitk.GetArrayFromImage(
            image)).squeeze().unsqueeze_(0).unsqueeze_(0)
        image = image.to(device=device)
        image = image - th.mean(image)
        image = image / th.std(image)
        image.clamp_(-2, 2)
        return image

    fixed_image = load_normalized(fix_image_filename)
    moving_image = load_normalized(moving_image_filename)

    displacement_param = eval_rnn(
        iterations, model, fixed_image, moving_image, image_loss, grid)[3]

    shape_x, shape_y = [], []
    weight_x, weight_y = [], []
    angle = []
    for param in displacement_param:
        shape_x.append(param[0][0].cpu().numpy())
        shape_y.append(param[0][1].cpu().numpy())
        weight_x.append(param[1][0].cpu().numpy())
        weight_y.append(param[1][1].cpu().numpy())
        # Negated and converted to degrees, assuming param[3] is in radians.
        angle.append(-(param[3].cpu().numpy() * 180.0) / np.pi)

    # NOTE(review): shape_y is labelled sigma_x and shape_x sigma_y. This may
    # be a deliberate row/column vs x/y swap; confirm before changing.
    # Raw strings keep "\s" literal without an invalid-escape warning.
    plt.plot(shape_y, label=r'shape size $\sigma_x$')
    plt.plot(shape_x, label=r'shape size $\sigma_y$')
    plt.xlabel("Time steps $t$")
    plt.ylabel(r"Shape size $\sigma_t$ of the local transformation")
    plt.legend(bbox_to_anchor=(0.5, 0.9), loc=2, borderaxespad=0.)
    matplotlib2tikz.save("/tmp/shape.tex")
    plt.show()

    plt.plot(weight_x, label='weight $v_x$')
    plt.plot(weight_y, label='weight $v_y$')
    plt.xlabel("Time steps $t$")
    plt.ylabel("Weights $v_t$ of the local transformations ")
    plt.legend(bbox_to_anchor=(0.5, 0.9), loc=2, borderaxespad=0.)
    matplotlib2tikz.save("/tmp/weight.tex")
    plt.show()

    plt.plot(angle)
    plt.xlabel("Time steps $t$")
    # The y label was previously a copy of the x label.
    plt.ylabel("Rotation angle of the local transformations (degrees)")
    plt.grid(True)
    # Export before show(): once the window is closed the figure is
    # discarded, so exporting afterwards wrote an empty plot.
    matplotlib2tikz.save("/tmp/test.tex")
    plt.show()
def test(args, state, image_size=None):
    """Evaluate a trained R2NN on every slice below ``args.test_path``.

    Directory layout read below:
    ``<test_path>/<patient>/<examination>/<slice>/``. Each slice directory
    holds ``*.dcm`` images and a ``landmarks/`` folder with
    ``landmarks_<name>.vtk`` files. ``<name>`` is the image filename without
    its last four characters (e.g. the ``.dcm`` suffix).

    A first pass picks one fixed image per slice with
    ``get_fixe_image_filename``. A second pass registers every ``.dcm`` image
    of the slice to that fixed image and writes below ``args.o``:

    * ``image_data/<patient>/<exa>/<slice>/``: fixed, moving and warped
      images, the displacement field and fixed, moving and warped landmarks
      (``.vtk``),
    * ``error_<patient>.csv``: one row per slice (slice name, fixed image
      path, image loss of each registered image); cleared at start,
    * ``tre_<patient>.csv``: same layout with the TRE of each image;
      cleared at start,
    * ``error_all_patients.csv``: one row per patient with the mean loss;
      only appended to, never cleared.

    Args:
        args: namespace with ``gpu_id`` (negative selects the CPU),
            ``test_path`` and ``o``.
        state: checkpoint dict with ``'args'`` (training namespace) and
            ``'model'`` (model state dict).
        image_size: network input size; ``None`` means ``[256, 256]``. This
            replaces the former mutable default list.

    Raises:
        ValueError: if the checkpoint's model type or image loss is not
            supported.
    """
    if image_size is None:
        image_size = [256, 256]
    args_state = state['args']
    gpu_id = args.gpu_id
    # "cuda:-1" is not a valid device string, so a negative id means CPU.
    device = th.device("cuda:" + str(gpu_id) if gpu_id >= 0 else "cpu")
    if gpu_id >= 0:
        th.cuda.set_device(gpu_id)

    def load_normalized(filename):
        # Read as float32, add two leading singleton dims, move to ``device``,
        # standardise to zero mean / unit std and clamp to [-2, 2].
        image = sitk.ReadImage(filename, sitk.sitkFloat32)
        image = th.tensor(sitk.GetArrayFromImage(image)).squeeze().unsqueeze_(0).unsqueeze_(0)
        image = image.to(device=device)
        image = image - th.mean(image)
        image = image / th.std(image)
        image.clamp_(-2, 2)
        return image

    def append_csv_row(filename, row):
        # newline='' is what the csv module requires for files it writes.
        with open(os.path.join(args.o, filename), 'a', newline='') as csv_file:
            csv.writer(csv_file, delimiter=',').writerow(row)

    patients = sorted(os.listdir(args.test_path))

    # First pass: choose the fixed ("mean") image of every slice. The list is
    # consumed in the same patient/examination/slice order in the second pass.
    mean_image_filenames = []
    for patient in patients:
        examinations = sorted(os.listdir(os.path.join(args.test_path, patient)))
        for exa in examinations:
            slices = sorted(os.listdir(os.path.join(args.test_path, patient, exa)))
            for image_slice in slices:
                slice_path = os.path.join(args.test_path, patient, exa, image_slice)
                images = [f for f in sorted(os.listdir(slice_path))
                          if os.path.isfile(os.path.join(slice_path, f))]
                mean_image_filenames.append(get_fixe_image_filename(slice_path, images))
    print(len(mean_image_filenames))
    print(mean_image_filenames)

    if args_state.model == "R2NN":
        model = gru.GRU_Registration(image_size, 2, args=args_state, device=device)
    else:
        raise ValueError('model type {0} is not known'.format(args_state.model))
    model.eval()
    model.load_state_dict(state['model'])
    print("model loaded")
    if gpu_id >= 0:
        with th.cuda.device(gpu_id):
            model.cuda()

    if args_state.image_loss == "MSE":
        image_loss = il.MSE()
    else:
        # Fail early instead of hitting an undefined ``image_loss`` below.
        raise ValueError(
            'image loss {0} is not supported'.format(args_state.image_loss))

    grid = compute_grid([image_size[0], image_size[1]], device=device)

    # Only "R2NN" gets past the model check above, so the UNET branch is
    # currently unreachable; it is kept for when that model is supported.
    if args_state.model == "R2NN":
        evaluate_net = eval_rnn
    elif args_state.model == "UNET":
        evaluate_net = eval_feed_forward

    out_path_image_data = os.path.join(args.o, "image_data")
    os.makedirs(out_path_image_data, exist_ok=True)

    # Start each run with fresh per-patient CSV files.
    for patient in patients:
        for prefix in ("error_", "tre_"):
            csv_path = os.path.join(args.o, prefix + patient + ".csv")
            if os.path.exists(csv_path):
                os.remove(csv_path)

    # Geometry given to the ITK displacement image: unit spacing, zero origin.
    image_spacing = [1, 1]
    image_origin = [0, 0]

    slice_index_global = 0
    for patient in patients:
        examinations = sorted(os.listdir(os.path.join(args.test_path, patient)))
        image_loss_examination = 0
        for exa in examinations:
            slices = sorted(os.listdir(os.path.join(args.test_path, patient, exa)))
            image_loss_slices = 0
            for image_slice in slices:
                slice_path = os.path.join(args.test_path, patient, exa, image_slice)
                image_filenames = [f for f in sorted(os.listdir(slice_path))
                                   if f.endswith(".dcm")]
                output_path = os.path.join(out_path_image_data, patient, exa, image_slice)
                os.makedirs(output_path, exist_ok=True)

                fix_name = mean_image_filenames[slice_index_global]
                fix_image_filename = os.path.join(slice_path, fix_name)
                # fix_image_filename already contains slice_path. Joining it
                # with slice_path again (as before) doubled the directory
                # whenever args.test_path was relative.
                fixed_image = load_normalized(fix_image_filename)
                fixed_image_points = Points.read(os.path.join(
                    slice_path, "landmarks", "landmarks_" + fix_name[:-4] + ".vtk"))
                sitk.WriteImage(sitk.GetImageFromArray(fixed_image.detach().cpu().squeeze().numpy()),
                                os.path.join(output_path, "fixed_" + fix_name[:-4] + ".vtk"))

                image_loss_images = 0
                image_loss_images_csv = [image_slice, fix_image_filename]
                tre_slice = [image_slice, fix_image_filename]
                for image_filename in image_filenames:
                    stem = image_filename[:-4]
                    moving_image = load_normalized(os.path.join(slice_path, image_filename))
                    moving_image_points = Points.read(os.path.join(
                        slice_path, "landmarks", "landmarks_" + stem + ".vtk"))

                    start = time.time()
                    image_loss_f, warped_image, displacement = evaluate_net(
                        args_state, model, fixed_image, moving_image, image_loss, grid)
                    stop = time.time()

                    # Convert to an ITK vector image. The steps are: flip
                    # dim 2, move the channel axis last, drop singleton dims
                    # and rescale each component by (size - 1) / 2 of the
                    # matching spatial axis. The rescale is presumably from
                    # normalised [-1, 1] units to pixels.
                    # NOTE(review): for a (1, 2, H, W) tensor, flip(2)
                    # reverses a spatial axis, not the channel order.
                    # Confirm this is intended.
                    displacement = displacement.flip(2)
                    displacement = displacement.transpose(1, 2).transpose(2, 3)
                    displacement = displacement.squeeze().to(dtype=th.float64, device='cpu')
                    for dim in range(displacement.shape[-1]):
                        displacement[..., dim] = float(displacement.shape[-dim - 2] - 1) \
                            * displacement[..., dim] / 2.0
                    itk_displacement = sitk.GetImageFromArray(displacement.numpy(), isVector=True)
                    itk_displacement.SetSpacing(image_spacing)
                    itk_displacement.SetOrigin(image_origin)

                    moving_points_transformed = Points.transform(moving_image_points, itk_displacement)
                    tre = Points.TRE(moving_points_transformed, fixed_image_points)
                    tre_slice.append(tre)
                    print("Time", stop - start)
                    image_loss_images_csv.append(image_loss_f)

                    Points.write(os.path.join(output_path, "warped_points_" + stem + ".vtk"),
                                 moving_points_transformed)
                    Points.write(os.path.join(output_path, "moving_points_" + stem + ".vtk"),
                                 moving_image_points)
                    Points.write(os.path.join(output_path, "fixed_points_" + stem + ".vtk"),
                                 fixed_image_points)
                    image_loss_images += image_loss_f

                    sitk.WriteImage(sitk.GetImageFromArray(warped_image.detach().cpu().squeeze().numpy()),
                                    os.path.join(output_path, "warped_" + stem + ".vtk"))
                    sitk.WriteImage(sitk.GetImageFromArray(moving_image.detach().cpu().squeeze().numpy()),
                                    os.path.join(output_path, "moving_" + stem + ".vtk"))
                    sitk.WriteImage(sitk.GetImageFromArray(
                        displacement.detach().cpu().squeeze().numpy(), isVector=True),
                        os.path.join(output_path, "displacement_" + stem + ".vtk"))

                slice_index_global += 1
                append_csv_row("error_" + patient + ".csv", image_loss_images_csv)
                append_csv_row("tre_" + patient + ".csv", tre_slice)

                # Guard the averages: empty directories used to raise
                # ZeroDivisionError.
                if image_filenames:
                    image_loss_images /= len(image_filenames)
                image_loss_slices += image_loss_images
            if slices:
                image_loss_slices /= len(slices)
            image_loss_examination += image_loss_slices
        if examinations:
            image_loss_examination /= len(examinations)
        append_csv_row("error_all_patients.csv",
                       [patient, examinations, image_loss_examination])
def reg(args, state, fix_image_filename, moving_image_filename, iterations):
    """Register one image pair and save snapshots of the displacement field.

    This definition shadows the earlier ``reg`` definitions in this module.

    PNGs written to ``args.o``:

    * ``disp_2.png``, ``disp_4.png``, ``disp_8.png``: magnitude of
      ``displacement_pixel[k - 1]`` with the ellipses of the first ``k``
      local transformations,
    * ``disp_25.png``: magnitude of the final displacement with all
      ellipses,
    * ``fixed_image.png``, ``moving_image.png``, ``warped_image.png``,
    * ``displacement.png``: final displacement magnitude without ellipses.

    Args:
        args: command-line namespace; reads ``gpu_id`` (negative selects the
            CPU) and ``o`` (output directory, created if missing).
        state: checkpoint dict with ``'args'`` (training namespace) and
            ``'model'`` (model state dict).
        fix_image_filename: path of the fixed image (read with SimpleITK).
        moving_image_filename: path of the moving image.
        iterations: passed unchanged as the first argument of ``eval_rnn``.

    Raises:
        ValueError: if the checkpoint's model type or image loss is not
            supported.
    """
    args_state = state['args']
    image_size = [256, 256]
    gpu_id = args.gpu_id
    # "cuda:-1" is not a valid device string, so a negative id means CPU.
    device = th.device("cuda:" + str(gpu_id) if gpu_id >= 0 else "cpu")
    if gpu_id >= 0:
        th.cuda.set_device(gpu_id)

    if args_state.model == "R2NN":
        model = gru.GRU_Registration(image_size, 2, args=args_state, device=device)
    else:
        raise ValueError('model type {0} is not known'.format(
            args_state.model))
    model.eval()
    model.load_state_dict(state['model'])
    print("model loaded")
    if gpu_id >= 0:
        with th.cuda.device(gpu_id):
            model.cuda()

    if args_state.image_loss == "MSE":
        image_loss = il.MSE()
    else:
        # Fail early instead of hitting an undefined ``image_loss`` below.
        raise ValueError(
            'image loss {0} is not supported'.format(args_state.image_loss))

    grid = compute_grid([image_size[0], image_size[1]], device=device)

    os.makedirs(args.o, exist_ok=True)

    def load_normalized(filename):
        # Read as float32, add two leading singleton dims, move to ``device``,
        # standardise to zero mean / unit std and clamp to [-2, 2].
        image = sitk.ReadImage(filename, sitk.sitkFloat32)
        image = th.tensor(sitk.GetArrayFromImage(
            image)).squeeze().unsqueeze_(0).unsqueeze_(0)
        image = image.to(device=device)
        image = image - th.mean(image)
        image = image / th.std(image)
        image.clamp_(-2, 2)
        return image

    fixed_image = load_normalized(fix_image_filename)
    moving_image = load_normalized(moving_image_filename)

    _, warped_image, displacement, displacement_param, displacement_pixel = eval_rnn(
        iterations, model, fixed_image, moving_image, image_loss, grid)

    def magnitude(field):
        # Euclidean norm over the two channels of field[0].
        return th.sqrt(field[0, 0, ...].pow(2) + field[0, 1, ...].pow(2))

    def show(image, **imshow_kwargs):
        # Draw ``image`` with every axis decoration hidden; return its axes.
        shown = plt.imshow(image, **imshow_kwargs)
        shown.axes.get_xaxis().set_visible(False)
        shown.axes.get_yaxis().set_visible(False)
        plt.axis('off')
        return shown.axes

    def save(filename):
        # Save the current figure, tightly cropped, into args.o and close it.
        plt.savefig(os.path.join(args.o, filename), bbox_inches='tight', pad_inches=0)
        plt.close()

    def save_magnitude(mag, filename, n_ellipses):
        # Plot ``mag`` on a fixed [0, 0.08] jet scale and overlay the ellipses
        # of the first ``n_ellipses`` local transformations (None = all).
        ax = show(mag.cpu().squeeze().numpy(), cmap='jet', vmax=0.08, vmin=0)
        for idx, param in enumerate(displacement_param):
            if n_ellipses is not None and idx >= n_ellipses:
                break
            # Pixel-unit geometry (scaled by 255); the conversions suggest
            # param[2] is a centre in [-1, 1] and param[3] an angle in
            # radians. Confirm against GRU_Registration.
            sigma = 2 * param[0].squeeze().cpu().numpy() * 255
            pos = ((param[2].squeeze().cpu().numpy() + 1) / 2) * 255
            angle = -(param[3].cpu().numpy() * 180.0) / np.pi
            ax.add_patch(
                Ellipse(pos, width=sigma[0], height=sigma[1], angle=angle,
                        edgecolor='white', facecolor='none', linewidth=2))
        save(filename)

    # Accumulated displacement after k recurrent steps, with the k ellipses
    # that produced it.
    for steps in (2, 4, 8):
        save_magnitude(magnitude(displacement_pixel[steps - 1]),
                       "disp_" + str(steps) + ".png", steps)

    # Final displacement with every ellipse. The file name presumably refers
    # to the number of recurrent steps used; confirm.
    final_magnitude = magnitude(displacement)
    save_magnitude(final_magnitude, "disp_25.png", None)

    for image, filename in ((fixed_image, "fixed_image.png"),
                            (moving_image, "moving_image.png"),
                            (warped_image, "warped_image.png")):
        show(image.cpu().squeeze().numpy(), cmap='gray')
        save(filename)

    # Final displacement magnitude without any ellipses.
    save_magnitude(final_magnitude, "displacement.png", 0)
def train_sync(args):
    """Train the R2NN registration network on ``args.training_path``.

    The loop is ``while True`` over the data loader and has no stopping
    criterion. Periodic work:

    * every ``args.eval_interval`` steps the model is evaluated with
      ``Evaluater``,
    * every ``args.save_model`` steps a checkpoint is written to
      ``<args.o>/state_agent_sync.pt``,
    * losses go to Visdom every step, example images every 250 steps.

    Each step unrolls the recurrent network for up to ``args.rnn_iter``
    iterations. Each iteration adds the network output to a displacement
    field and re-warps the moving image. The loss combines:

    * the mean per-iteration image loss or, with ``args.use_diff_loss``, the
      negated mean loss improvement,
    * a TV regulariser on the displacement,
    * an optional entropy term, which is subtracted.

    With ``args.model_state`` set, model and optimizer state and the counters
    are restored from that checkpoint.

    Args:
        args: namespace of training options (see the attributes read below).

    Raises:
        ValueError: for ``args.rnn_iter < 1``, an unknown ``args.optimizer``
            or an unsupported ``args.image_loss``.
    """
    if args.rnn_iter < 1:
        # The loss is averaged over the number of iterations run.
        raise ValueError('rnn_iter must be at least 1, got {0}'.format(args.rnn_iter))

    continue_optimization = False
    eval_iteration = 0
    if args.model_state != "":
        # th.load unpickles the file; only load trusted checkpoints.
        state = th.load(args.model_state, map_location='cpu')
        continue_optimization = True
        eval_iteration = state['eval_counter']

    gpu_id = args.gpu_ids[0]
    th.manual_seed(args.seed)
    np.random.seed(args.seed)
    # "cuda:-1" is not a valid device string, so a negative id means CPU.
    device = th.device("cuda:" + str(gpu_id) if gpu_id >= 0 else "cpu")
    viz = vis.Visdom(port=args.port)

    data_manager = dm.DataManager(args.training_path, normalize_std=args.normalize_std,
                                  random_sampling=args.random_img_pair)
    image_size = data_manager.image_size()
    training_generator = data.DataLoader(data_manager,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.nb_workers,
                                         pin_memory=True)

    model = gru.GRU_Registration(image_size, 2, device=device, args=args)
    if continue_optimization:
        model.load_state_dict(state['model'])
    model.train()
    n_parameters = sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad)
    print("number of parameters model", n_parameters)

    evaluater = Evaluater(args, image_size, eval_iteration=eval_iteration)

    if gpu_id >= 0:
        with th.cuda.device(gpu_id):
            model.cuda()
        th.cuda.manual_seed(args.seed)
        th.cuda.set_device(gpu_id)

    if args.optimizer == 'RMSprop':
        optimizer = th.optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optimizer == 'Adam':
        optimizer = th.optim.Adam(model.parameters(), lr=args.lr, amsgrad=args.amsgrad)
    elif args.optimizer == 'Rprop':
        optimizer = th.optim.Rprop(model.parameters(), lr=args.lr)
    else:
        raise ValueError('optimizer {0} is not known'.format(args.optimizer))
    if continue_optimization:
        # This was ``optimizer.load_state_dict = state['optimizer']``. That
        # replaced the method instead of restoring the saved state.
        optimizer.load_state_dict(state['optimizer'])

    if args.image_loss == "MSE":
        image_loss = il.MSE()
    else:
        # Fail early instead of hitting an undefined ``image_loss`` below.
        raise ValueError('image loss {0} is not supported'.format(args.image_loss))

    regulariser = dl.IsotropicTVRegulariser([1.0, 1.0])
    grid = compute_grid([image_size[0], image_size[1]]).to(device)

    train_counter = state['train_counter'] if continue_optimization else 0
    loss_plot = None
    print("start optimization")

    # The difference loss accumulates improvements, which should be
    # maximised, so its sign is flipped in the objective.
    scale = -1 if args.use_diff_loss else 1

    def checkerboard(image_a, image_b):
        # 20x20-tile SimpleITK checkerboard of two images, returned with two
        # leading singleton dims for the Visdom image grid.
        board = sitk.CheckerBoard(sitk.GetImageFromArray(image_a.squeeze().numpy()),
                                  sitk.GetImageFromArray(image_b.squeeze().numpy()),
                                  [20, 20])
        return th.Tensor(sitk.GetArrayFromImage(board)).unsqueeze(0).unsqueeze(0)

    while True:
        for fixed_image, moving_image in training_generator:
            fixed_image = fixed_image.to(device)
            moving_image = moving_image.to(device)

            if train_counter % args.eval_interval == 0:
                print("Start evaluation")
                evaluater.evaluation(model)

            image_loss_epoch = 0
            model.reset()
            model.zero_grad()

            warped_image = moving_image
            # Use the real batch size: the DataLoader keeps a final batch
            # that can be smaller than args.batch_size.
            displacement = th.zeros(fixed_image.shape[0], 2, image_size[0], image_size[1],
                                    device=fixed_image.device, dtype=fixed_image.dtype)
            displacement_trans = displacement.transpose(1, 2).transpose(2, 3) + grid
            if args.entropy_regularizer_weight > 0:
                shapes = th.zeros(1, 1, image_size[0], image_size[1],
                                  device=fixed_image.device, dtype=fixed_image.dtype)
                single_entropy = 0

            loss_start, _ = image_loss(displacement_trans, fixed_image, warped_image)
            # Skip pairs that are already well aligned. This also skips the
            # train_counter increment at the end of the step.
            if args.early_stopping > 0 and loss_start.item() < args.early_stopping:
                continue

            start = time.time()
            for j in range(args.rnn_iter):
                net_input = th.cat((fixed_image, warped_image), dim=1)
                net_ouput = model(net_input)
                displacement = displacement + net_ouput[0]
                if args.entropy_regularizer_weight > 0:
                    f_x = net_ouput[1] / (th.sum(net_ouput[1]) + 1e-5) + 1e-5
                    shapes = shapes + f_x
                    single_entropy = single_entropy + compute_entropy(f_x)

                displacement_trans = displacement.transpose(1, 2).transpose(2, 3) + grid
                warped_image = F.grid_sample(moving_image, displacement_trans)

                loss_, _ = image_loss(displacement_trans, fixed_image, warped_image)
                if args.use_diff_loss:
                    image_loss_epoch = image_loss_epoch + (loss_start - loss_)
                    loss_start = loss_
                else:
                    image_loss_epoch = image_loss_epoch + loss_

                if args.early_stopping > 0 and loss_.item() < args.early_stopping:
                    break
                if args.stop_on_reverse:
                    if loss_.item() <= loss_start.item():
                        loss_start = loss_
                    else:
                        break
            n_steps = j + 1  # recurrent iterations actually run

            displacement_loss = args.reg_weight * regulariser(displacement)
            loss = scale * image_loss_epoch / n_steps + displacement_loss
            if args.entropy_regularizer_weight > 0:
                entropy_loss = (compute_entropy(shapes / n_steps) + single_entropy / n_steps) \
                    * args.entropy_regularizer_weight
                loss = loss - entropy_loss
                entropy_loss_value = entropy_loss.data.item()
            else:
                entropy_loss_value = 0

            optimizer.zero_grad()
            loss.backward()
            if args.clip_gradients:
                th.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
            end = time.time()

            if train_counter % args.save_model == 0:
                state = {
                    'train_counter': train_counter,
                    'eval_counter': evaluater.eval_iterations,
                    'args': args,
                    'agent_id': -1,
                    'optimizer': optimizer.state_dict(),
                    'model': model.state_dict()
                }
                th.save(state, os.path.join(args.o, "state_agent_sync.pt"))

            image_loss_value = image_loss_epoch.item() / n_steps
            displacement_loss_value = displacement_loss.item()
            print("iter ", train_counter, "image loss ", image_loss_value,
                  "displacement loss ", displacement_loss_value,
                  "loss ", loss.item(), "time", end - start)

            # One Visdom point per curve: image, displacement, entropy loss.
            loss_value_ = np.column_stack(np.array([
                image_loss_value, displacement_loss_value, entropy_loss_value]))
            x_value = np.column_stack(np.ones(3) * train_counter)
            if loss_plot is None:
                opts = dict(title=("loss_value"), width=1000, height=500, showlegend=True)
                loss_plot = viz.line(X=x_value, Y=loss_value_, opts=opts)
            else:
                loss_plot = viz.line(X=x_value, Y=loss_value_, win=loss_plot, update='append')

            if train_counter % 250 == 0:
                fixed_image_vis = imfilter.normalize_image(fixed_image[0, ...]).cpu().unsqueeze(0)
                moving_image_vis = imfilter.normalize_image(moving_image[0, ...]).cpu().unsqueeze(0)
                displacement_vis = imfilter.normalize_image(
                    displacement[0, ...]).cpu().unsqueeze(0).detach()
                warped_image_vis = imfilter.normalize_image(
                    warped_image[0, ...]).cpu().unsqueeze(0).detach()

                checkerboard_image_vis_nor_reg = checkerboard(moving_image_vis, fixed_image_vis)
                checkerboard_image_vis = checkerboard(warped_image_vis, fixed_image_vis)

                image_stack = th.cat(
                    (fixed_image_vis, moving_image_vis,
                     displacement_vis[:, 0, ...].unsqueeze(1),
                     displacement_vis[:, 1, ...].unsqueeze(1),
                     warped_image_vis, checkerboard_image_vis,
                     checkerboard_image_vis_nor_reg), dim=0)
                viz.images(image_stack, opts=dict(title="results"), win=2)

            train_counter += 1