def test_and_save_alignment(batch, batch_index, model_output, out_dir):
    """Draw a polygon on image A and its affine-warped copy on image B, then save both.

    Only the first sample of ``batch`` is used. Its ``vertices_a`` polygon is
    drawn on the un-normalized ``image_a``. The same vertices, mapped through
    the 2x3 affine matrix in ``model_output[0]``, are drawn on ``image_b``.
    The two images are concatenated horizontally and written to
    ``<out_dir>/drawn_<batch_index>.png``.

    Args:
        batch: mapping with ``'image_a'`` and ``'image_b'`` (normalized image
            tensors; ``[0]`` selects one (C, H, W) sample) and ``'vertices_a'``
            (``[0]`` selects one (N, 2) vertex array or tensor). Vertices are
            multiplied by image width/height to get pixels, so they are
            presumably in [0, 1] unit coordinates (confirm with the caller).
        batch_index: value used in the output file name.
        model_output: per-sample affine parameters; element 0 must reshape to (2, 3).
        out_dir: directory in which the PNG is written.
    """
    # NOTE(review): normalize_image output is presumably RGB, while cv2.imwrite
    # expects BGR, so the saved file may have red/blue swapped. Confirm.

    def to_uint8_hwc(img):
        # Un-normalize one (C, H, W) tensor into a contiguous (H, W, C) uint8
        # array. Clipping to [0, 1] first keeps values slightly outside that
        # range from wrapping around in the uint8 cast. detach/cpu make this
        # work for GPU tensors and tensors that require grad.
        denorm = normalize_image(img.unsqueeze(0), forward=False).squeeze(0)
        hwc = denorm.permute(1, 2, 0).detach().cpu().numpy()
        return np.ascontiguousarray((np.clip(hwc, 0.0, 1.0) * 255).astype(np.uint8))

    drawn_a_image = to_uint8_hwc(batch['image_a'][0])
    drawn_b_image = to_uint8_hwc(batch['image_b'][0])

    vertices = batch['vertices_a'][0]
    if hasattr(vertices, 'cpu'):  # torch tensor, possibly on the GPU
        vertices = vertices.detach().cpu().numpy()
    vertices = np.asarray(vertices, dtype=np.float64)
    # Homogeneous coordinates so the 2x3 affine matrix can be applied directly.
    to_warp_pts = np.hstack([vertices, np.ones((len(vertices), 1))])

    transform = model_output[0].reshape(2, 3).detach().cpu().numpy()
    warped_pts = transform.dot(to_warp_pts.T).T

    # Scale x by width and y by height; astype truncates toward zero like int().
    src_h, src_w = drawn_a_image.shape[:2]
    out_h, out_w = drawn_b_image.shape[:2]
    original_pts = (to_warp_pts[:, :2] * [src_w, src_h]).astype(np.int32).reshape((-1, 1, 2))
    to_draw_pts = (warped_pts[:, :2] * [out_w, out_h]).astype(np.int32).reshape((-1, 1, 2))

    # Closed polygons, colour (0, 0, 255), thickness 7.
    cv2.polylines(drawn_a_image, [original_pts], True, (0, 0, 255), 7)
    cv2.polylines(drawn_b_image, [to_draw_pts], True, (0, 0, 255), 7)

    concat_img = np.concatenate([drawn_a_image, drawn_b_image], axis=1)
    out_path = os.path.join(out_dir, 'drawn_{}.png'.format(batch_index))
    cv2.imwrite(out_path, concat_img)
def train(epoch, model, loss_fn, optimizer, dataloader, pair_generation_tnf,
          use_cuda=True, log_interval=50, logger=None):
    """Run one training epoch and return the mean per-batch loss.

    Every ``log_interval`` batches (including batch 0), a progress line is
    emitted through ``logger.info``, or ``print`` when no logger is given. At
    the same points, a (source | target | affine-warped source) strip of the
    first sample is sent to the module-level ``tb_logger``.

    Args:
        epoch: epoch index, used in log messages and the TensorBoard step.
        model: network mapping a transformed batch to ``theta``.
        loss_fn: callable ``loss_fn(theta, theta_GT)``.
        optimizer: optimizer stepping ``model``'s parameters.
        dataloader: iterable of raw batches supporting ``len()``.
        pair_generation_tnf: callable turning a raw batch into a dict with
            'source_image', 'target_image' and 'theta_GT'.
        use_cuda: unused; kept for interface compatibility.
        log_interval: how often (in batches) to log.
        logger: optional logger; ``print`` is used when it is None.

    Returns:
        The average loss over the epoch, or 0.0 for an empty dataloader.
    """
    model.train()
    num_batches = len(dataloader)
    log = print if logger is None else logger.info
    train_loss = 0.0

    for batch_idx, batch in enumerate(dataloader):
        optimizer.zero_grad()
        tnf_batch = pair_generation_tnf(batch)
        theta = model(tnf_batch)
        loss = loss_fn(theta, tnf_batch['theta_GT'])
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        train_loss += loss_value

        if batch_idx % log_interval == 0:
            log('Train Epoch: {} [{}/{} ({:.0f}%)]\t\tLoss: {:.6f}'.format(
                epoch, batch_idx, num_batches,
                100. * batch_idx / num_batches, loss_value))

            src_img = tnf_batch['source_image'][0].unsqueeze(0)
            tgt_img = tnf_batch['target_image'][0].unsqueeze(0)
            # Detach theta: this warp is for visualization only and must not
            # build an autograd graph.
            warped_image_aff = affTnf(src_img, theta[0].detach().view(-1, 2, 3))
            panels = [normalize_image(t, forward=False).detach().cpu().numpy()
                      for t in (src_img, tgt_img, warped_image_aff)]
            # Concatenate along axis 3 (width for NCHW tensors).
            img_cat = np.concatenate(panels, axis=3)
            tb_logger.add_images('img', img_cat, epoch * num_batches + batch_idx)

    if num_batches:
        train_loss /= num_batches
    print('Train set: Average loss: {:.4f}'.format(train_loss))
    return train_loss
def process(self, source_img_path, target_img_path):
    """
    Main process of model: warp the source image onto the target image.

    Input:
        source_img_path: path to image you want to transform or np.array (RGB)
        target_img_path: path to target image you want to transform or np.array (RGB)
    Output:
        affine_image (np.ndarray same shape with target): Affine Transformation result
        affine_tps_image (np.ndarray same shape with target): Affine TPS transformation result

    Both outputs are clipped to [0, 255] before the uint8 cast, so
    out-of-range values saturate instead of wrapping around.
    """
    src_img, _src_shape = self.load_image(source_img_path)
    target_img, target_shape = self.load_image(target_img_path)
    batch = {'source_image': src_img, 'target_image': target_img}

    # Resizes network-resolution outputs back to the target image's size.
    resizeTgt = GeometricTnf(out_h=target_shape[0], out_w=target_shape[1],
                             use_cuda=self.use_cuda)

    def to_uint8_image(warped):
        # Resize, un-normalize, and convert 1xCxHxW -> (H, W, C) numpy.
        image = normalize_image(resizeTgt(warped), forward=False)
        image = image.data.squeeze(0).transpose(0, 1).transpose(1, 2).cpu().numpy()
        # NOTE(review): there is no *255 scaling here, so this assumes the
        # un-normalized values are already in 0..255. Values in [0, 1] would
        # truncate to 0/1. Confirm against load_image.
        return np.clip(image, 0, 255).astype(np.uint8)

    self.model.eval()
    with torch.no_grad():
        theta_aff, theta_aff_tps = self.model(batch)
        warped_image_aff = self.affTnf(batch['source_image'], theta_aff.view(-1, 2, 3))
        warped_image_aff_tps = affTpsTnf(batch['source_image'], theta_aff, theta_aff_tps)
        warped_image_aff_np = to_uint8_image(warped_image_aff)
        warped_image_aff_tps_np = to_uint8_image(warped_image_aff_tps)

    return warped_image_aff_np, warped_image_aff_tps_np
def preprocess_image_high_res_back(image):
    """Convert an (H, W, C) image array into a normalized 1xCxHxW tensor.

    Pixel values are scaled by 1/255 before normalization. Unlike the other
    preprocess helpers, no resizing is performed.
    """
    chw_batch = np.expand_dims(image.transpose((2, 0, 1)), 0)
    scaled = chw_batch.astype(np.float32) / 255.0
    tensor = Variable(torch.Tensor(scaled), requires_grad=False)
    return normalize_image(tensor)
def preprocess_image(image):
    """Convert an (H, W, C) image array into a resized, normalized 1xCxHxW tensor.

    Pixel values are scaled by 1/255, then resized with the module-level
    ``resize`` transform (bilinear sampling with an identity affine
    transform), then normalized.
    """
    chw_batch = np.expand_dims(image.transpose((2, 0, 1)), 0)
    tensor = torch.Tensor(chw_batch.astype(np.float32) / 255.0)
    resized = resize(Variable(tensor, requires_grad=False))
    return normalize_image(resized)
def preprocess_image(image, means, stds):
    """Prepare an (H, W, C) image array for the warping networks.

    Pixel values are scaled by 1/255 and resized with the module-level
    ``resizeCNN`` transform (bilinear sampling with an identity affine
    transform). The result is then normalized with the given per-channel
    ``means`` and ``stds``.
    Returns a 1xCxHxW tensor.
    """
    chw_batch = np.expand_dims(image.transpose((2, 0, 1)), 0)
    as_tensor = torch.Tensor(chw_batch.astype(np.float32) / 255.0)
    resized = resizeCNN(Variable(as_tensor, requires_grad=False))
    return normalize_image(resized, mean=means, std=stds)
def im_show_1(image, title, rows, cols, index):
    """
    Show a normalized (C, H, W) image tensor in subplot ``index`` of a
    ``rows`` x ``cols`` grid on the current figure.

    The tensor is un-normalized, moved to the CPU as an (H, W, C) array, and
    clipped to [0, 1] for display. It is detached first, so tensors that are
    part of an autograd graph (e.g. network outputs) can be shown too.

    Returns:
        The matplotlib Axes that was drawn on.
    """
    denorm = normalize_image(image, forward=False)
    pixels = denorm.detach().permute(1, 2, 0).cpu().numpy()
    ax = plt.subplot(rows, cols, index)
    ax.set_title(title)
    ax.imshow(pixels.clip(0, 1))
    return ax
def preprocess_image_high_res(image):
    """Convert an (H, W, C) image array into a normalized 1xCxHxW tensor.

    The image is resized to ``2 * half_out_size`` on each side
    (``half_out_size`` is a module-level setting) using bilinear sampling
    with an identity affine transform on the CPU. Pixel values are scaled by
    1/255 before resizing.
    """
    side = half_out_size * 2
    resizer = GeometricTnf(out_h=side, out_w=side, use_cuda=False)
    chw_batch = np.expand_dims(image.transpose((2, 0, 1)), 0)
    tensor = torch.Tensor(chw_batch.astype(np.float32) / 255.0)
    resized = resizer(Variable(tensor, requires_grad=False))
    return normalize_image(resized)
def __call__(self, fname, fname2):
    """Warp the 'demo' view of image ``fname`` onto image ``fname2``.

    Both files are read with ``io.imread`` and passed through ``self.affTnf``
    with no transform parameters. ``fname`` is additionally passed through
    ``self.affTnf_demo`` and ``self.affTnf_origin``. The affine model
    predicts theta from (source, target), and that theta warps the demo
    image. The un-normalized result is saved to 'result.jpg' and returned as
    an (H, W, C) numpy array.

    Raises:
        RuntimeError: if ``self.do_aff`` is false. The result comes only from
            the affine model; without it the code previously crashed later
            with an UnboundLocalError.
    """
    if not self.do_aff:
        raise RuntimeError('The affine model is required (do_aff is False).')

    def load_as_batch(path):
        # Channel-last image file -> float32 1xCxHxW tensor (values not rescaled).
        pixels = io.imread(path)
        pixels = np.expand_dims(pixels.transpose((2, 0, 1)), 0)
        return Variable(torch.Tensor(pixels.astype(np.float32)), requires_grad=False)

    image_var = load_as_batch(fname)
    image_A = self.affTnf(image_var).data.squeeze(0)
    image_A_demo = self.affTnf_demo(image_var).data.squeeze(0)
    image_A_origin = self.affTnf_origin(image_var).data.squeeze(0)
    image_B = self.affTnf(load_as_batch(fname2)).data.squeeze(0)

    sample = {'source_image': image_A, 'target_image': image_B,
              'demo': image_A_demo, 'origin_image': image_A_origin}
    sample = self.transform(sample)
    batch = BatchTensorToVars(use_cuda=self.use_cuda)(sample)
    # Restore the batch dimension on every entry.
    for key in ('source_image', 'target_image', 'origin_image', 'demo'):
        batch[key] = torch.unsqueeze(batch[key], 0)

    self.model_aff.eval()
    with torch.no_grad():  # inference only: no autograd graph needed
        theta_aff = self.model_aff(batch)
        warped = self.affTnf_demo(batch['demo'], theta_aff.view(-1, 2, 3))

    warped = normalize_image(warped, forward=False)
    warped_image_aff_demo = warped.data.squeeze(0).transpose(0, 1).transpose(1, 2).cpu().numpy()
    print("Done")
    imsave('result.jpg', warped_image_aff_demo)
    return warped_image_aff_demo
target_points = batch['target_points'] # warp points with estimated transformations target_points_norm = PointsToUnitCoords(target_points, target_im_size) model.eval() # Evaluate model theta_aff, theta_aff_tps = model(batch) warped_image_aff = affTnf(batch['source_image'], theta_aff.view(-1, 2, 3)) warped_image_aff_tps = affTpsTnf(batch['source_image'], theta_aff, theta_aff_tps) # Un-normalize images and convert to numpy source_image = normalize_image(batch['source_image'], forward=False) source_image = source_image.data.squeeze(0).transpose(0, 1).transpose( 1, 2).cpu().numpy() target_image = normalize_image(batch['target_image'], forward=False) target_image = target_image.data.squeeze(0).transpose(0, 1).transpose( 1, 2).cpu().numpy() warped_image_aff = normalize_image(warped_image_aff, forward=False) warped_image_aff = warped_image_aff.data.squeeze(0).transpose( 0, 1).transpose(1, 2).cpu().numpy() warped_image_aff_tps = normalize_image(warped_image_aff_tps, forward=False) warped_image_aff_tps = warped_image_aff_tps.data.squeeze(0).transpose( 0, 1).transpose(1, 2).cpu().numpy() # check if display is available
def log_images(tb_writer, batch, tnf_matrices, counter, tag=None, n_max=1):
    """
    Fn to log image batches: draws each sample's polygon on image A and its
    affine-warped copy on image B, and logs the two side by side.

    :param tb_writer: Summary Writer (anything with ``add_images(tag, images, step)``)
    :param batch: Batch of samples with 'image_a', 'image_b' (normalized
        (C, H, W) image tensors per sample) and 'vertices_a' (per-sample
        vertices; they are multiplied by image width/height, so presumably
        in unit coordinates)
    :param tnf_matrices: Batch of transformations to apply (each reshapeable to 2x3)
    :param counter: Epoch index (step passed to the writer)
    :param tag: Default None, if a string is specified tags the log with it as a prefix
    :param n_max: Maximum numbers of images per batch to display
    :raises ValueError: if ``tag`` is truthy but not a string
    :return: None
    """
    # The log name does not depend on the sample, so resolve (and validate) it once.
    if not tag:
        log_name = 'A warp on B'
    elif isinstance(tag, str):
        log_name = '{}\tA warp on B'.format(tag)
    else:
        raise ValueError("Unexpected type for 'tag', must be of type string.")

    def to_canvas(img):
        # Un-normalize a (C, H, W) tensor into a contiguous float64 (H, W, C)
        # array of integer values 0..255 that cv2 can draw on. Clipping to
        # [0, 1] stops wrap-around in the uint8 quantization. detach/cpu make
        # GPU tensors work.
        denorm = normalize_image(img.unsqueeze(0), forward=False).squeeze(0)
        hwc = denorm.permute(1, 2, 0).detach().cpu().numpy()
        quantized = (np.clip(hwc, 0.0, 1.0) * 255).astype(np.uint8)
        return np.ascontiguousarray(quantized, dtype=np.float64)

    samples = zip(batch['image_a'], batch['image_b'], batch['vertices_a'], tnf_matrices)
    for idx, (img_a, img_b, vertices, aff_matrix) in enumerate(samples):
        if idx >= n_max:
            break

        drawn_a_img = to_canvas(img_a)
        drawn_b_img = to_canvas(img_b)
        transform = aff_matrix.cpu().detach().reshape([2, 3]).numpy()

        if hasattr(vertices, 'cpu'):  # torch tensor, possibly on the GPU
            vertices = vertices.detach().cpu()
        vertices = np.array(vertices)
        # Homogeneous coordinates so the 2x3 affine matrix can be applied.
        to_warp_pts = np.hstack([vertices, np.ones(shape=(len(vertices), 1))])
        warped_pts = transform.dot(to_warp_pts.T).T

        # Scale x by width and y by height; astype truncates toward zero like int().
        src_h, src_w = drawn_a_img.shape[:2]
        out_h, out_w = drawn_b_img.shape[:2]
        original_pts = (to_warp_pts[:, :2] * [src_w, src_h]).astype(np.int32).reshape((-1, 1, 2))
        to_draw_pts = (warped_pts[:, :2] * [out_w, out_h]).astype(np.int32).reshape((-1, 1, 2))

        # draw warped points over template image
        cv2.polylines(drawn_b_img, [to_draw_pts], True, (0, 0, 255), 7)
        cv2.polylines(drawn_a_img, [original_pts], True, (0, 0, 255), 7)

        # concatenate A and drawn B along the width axis, back in [0, 1]
        concat_img = cat([Tensor(drawn_a_img).double() / 255,
                          Tensor(drawn_b_img).double() / 255], 1)
        tb_writer.add_images(log_name, concat_img.permute(2, 0, 1).unsqueeze(0), counter)
# Fragment: the enclosing definition is not visible here.
# Append the row [0, 0, 1] to each 2x3 backward estimate to make it a 3x3
# matrix, invert it, and keep the first 6 entries (the inverted 2x3 affine).
# NOTE(review): '.to('cuda')' is hard-coded even though later code checks
# use_cuda. This will fail on CPU-only runs; confirm.
theta_aff_inv = torch.cat((theta_aff_inv, (torch.Tensor(
    [0, 0, 1]).to('cuda').unsqueeze(0).unsqueeze(1).expand(batch_size, 1, 3))), 1)
theta_aff_2 = theta_aff_inv.inverse().contiguous().view(-1, 9)[:, :6]
# Ensemble: average the forward estimate with the inverted backward estimate.
theta_aff_ensemble = (theta_aff + theta_aff_2) / 2  # Ensemble
### Process result
warped_image_aff = affTnf(Im2Tensor(source_image), theta_aff_ensemble.view(-1, 2, 3))
# (C, H, W) -> (H, W, C) after dropping the batch dim. No un-normalization is
# applied here; presumably Im2Tensor yields raw pixel values (confirm).
result_aff_np = warped_image_aff.squeeze(0).transpose(0, 1).transpose(
    1, 2).cpu().detach().numpy()
io.imsave('results/aff.jpg', result_aff_np)
"""2nd Affine"""
# Preprocess source_image_2: resize and normalize the first affine result so
# it can be fed back into the model for a second pass.
source_image_2 = normalize_image(resize(warped_image_aff.cpu()))
if use_cuda:
    source_image_2 = source_image_2.cuda()
theta_aff_aff, theta_aff_aff_inv = model({
    'source_image': source_image_2,
    'target_image': batch['target_image']
})
# Calculate theta_aff_aff_2: invert the second pass's backward estimate the
# same way as above (same hard-coded 'cuda' caveat).
batch_size = theta_aff_aff.size(0)
theta_aff_aff_inv = theta_aff_aff_inv.view(-1, 2, 3)
theta_aff_aff_inv = torch.cat((theta_aff_aff_inv, (torch.Tensor(
    [0, 0, 1]).to('cuda').unsqueeze(0).unsqueeze(1).expand(batch_size, 1, 3))), 1)
theta_aff_aff_2 = theta_aff_aff_inv.inverse().contiguous().view(-1, 9)[:, :6]
# Evaluate models if do_aff: theta_aff=model_aff(batch) warped_image_aff = affTnf(batch['source_image'],theta_aff.view(-1,2,3)) if do_tps: theta_tps=model_tps(batch) warped_image_tps = tpsTnf(batch['source_image'],theta_tps) if do_aff and do_tps: theta_aff_tps=model_tps({'source_image': warped_image_aff, 'target_image': batch['target_image']}) warped_image_aff_tps = tpsTnf(warped_image_aff,theta_aff_tps) # Un-normalize images and convert to numpy if do_aff: warped_image_aff_np = normalize_image(resizeTgt(warped_image_aff),forward=False).data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy() if do_tps: warped_image_tps_np = normalize_image(resizeTgt(warped_image_tps),forward=False).data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy() if do_aff and do_tps: warped_image_aff_tps_np = normalize_image(resizeTgt(warped_image_aff_tps),forward=False).data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy() N_subplots = 2+int(do_aff)+int(do_tps)+int(do_aff and do_tps) fig, axs = plt.subplots(1,N_subplots) axs[0].imshow(source_image) axs[0].set_title('src') axs[1].imshow(target_image) axs[1].set_title('tgt') subplot_idx = 2 if do_aff:
def save_warped(source_name, target_name, savename, model_aff, model_tps, demo=False):
    """
    Align a source image to a target image, and save (or display) the result.

    The affine model is applied first and the TPS model on top of it. Both
    models are required: if either is passed as '' the function prints a
    message and returns without doing anything.

    Args:
        source_name: path of the image to warp.
        target_name: path of the reference image.
        savename: output path for the clipped affine+TPS result (non-demo mode).
        model_aff: affine network, or '' if unavailable.
        model_tps: TPS network, or '' if unavailable.
        demo: if true, show source, target, aff, tps and aff+tps side by side
            instead of saving.

    Returns:
        None.
    """
    do_aff, do_tps = not model_aff == '', not model_tps == ''
    if not (do_aff and do_tps):
        print("No model found. Exiting.")
        return

    import torch  # local import: only needed for torch.no_grad()

    source_image = io.imread(source_name)
    target_image = io.imread(target_name)

    # Resize and normalize with the module-level means/stds (ImageNet by
    # default; other values have caused unexplained issues).
    source_image_var = preprocess_image(source_image, means=means, stds=stds)
    target_image_var = preprocess_image(target_image, means=means, stds=stds)
    if use_cuda:
        source_image_var = source_image_var.cuda()
        target_image_var = target_image_var.cuda()
    batch = {'source_image': source_image_var, 'target_image': target_image_var}

    # Resizes network-resolution outputs back to the target image's dimensions.
    resizeTgt = GeometricTnf(out_h=target_image.shape[0], out_w=target_image.shape[1],
                             use_cuda=use_cuda)

    def to_numpy(warped):
        # Resize, un-normalize, and convert 1xCxHxW -> (H, W, C) numpy.
        return normalize_image(resizeTgt(warped), forward=False).data.squeeze(0) \
            .transpose(0, 1).transpose(1, 2).cpu().numpy()

    model_aff.eval()
    model_tps.eval()
    with torch.no_grad():  # inference only: no autograd graph needed
        theta_aff = model_aff(batch)
        warped_image_aff = affTnf(batch['source_image'], theta_aff.view(-1, 2, 3))
        theta_aff_tps = model_tps({'source_image': warped_image_aff,
                                   'target_image': batch['target_image']})
        warped_image_aff_tps = tpsTnf(warped_image_aff, theta_aff_tps)
        warped_image_aff_tps_np = to_numpy(warped_image_aff_tps)

        if not demo:
            plt.imsave(savename, np.clip(warped_image_aff_tps_np, 0, 1))
            return

        # The standalone affine and TPS results are only needed for display.
        warped_image_aff_np = to_numpy(warped_image_aff)
        theta_tps = model_tps(batch)
        warped_image_tps_np = to_numpy(tpsTnf(batch['source_image'], theta_tps))

    panels = [(source_image, 'src'), (target_image, 'tgt'),
              (warped_image_aff_np, 'aff'), (warped_image_tps_np, 'tps'),
              (warped_image_aff_tps_np, 'aff+tps')]
    fig, axs = plt.subplots(1, len(panels))
    for ax, (img, title) in zip(axs, panels):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')
    fig.set_dpi(330)
    plt.show()
target_points = batch['target_points'] # warp points with estimated transformations target_points_norm = PointsToUnitCoords(target_points,target_im_size) model.eval() # Evaluate model theta_aff,theta_aff_tps=model(batch) warped_image_aff = affTnf(batch['source_image'],theta_aff.view(-1,2,3)) warped_image_aff_tps = affTpsTnf(batch['source_image'],theta_aff, theta_aff_tps) # Un-normalize images and convert to numpy source_image = normalize_image(batch['source_image'],forward=False) source_image = source_image.data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy() target_image = normalize_image(batch['target_image'],forward=False) target_image = target_image.data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy() warped_image_aff = normalize_image(warped_image_aff,forward=False) warped_image_aff = warped_image_aff.data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy() warped_image_aff_tps = normalize_image(warped_image_aff_tps,forward=False) warped_image_aff_tps = warped_image_aff_tps.data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy() # check if display is available exit_val = os.system('python -c "import matplotlib.pyplot as plt;plt.figure()" > /dev/null 2>&1') display_avail = exit_val==0 if display_avail:
def runCnn(model_cache, source_image_path, target_image_path, region01,
           region00, region10, region09):
    """Register a source image and four region images onto a target image.

    Pipeline:
      1. The affine model is estimated on thresholded masks of the source and
         target images, then run a second time on the warped mask to refine
         the first estimate.
      2. The TPS model is estimated on the affine-aligned low-resolution source.
      3. Both transforms are applied to high-resolution versions of the source
         image and of every region image.

    Args:
        model_cache: tuple ``(model_aff, model_tps, do_aff, do_tps, use_cuda)``.
        source_image_path: path of the source image. The mask step reduces
            over axis -1, so a channel-last (H, W, C) image is assumed.
        target_image_path: path of the target image. A 2-D (grayscale, e.g.
            MRI) image is replicated into 3 channels; an image that already
            has a channel axis is used as-is.
        region01, region00, region10, region09: channel-last image arrays,
            warped with the same transforms as the source image.

    Returns:
        ``(image, region01, region00, region10, region09)``: un-normalized
        (H, W, C) numpy arrays after affine + TPS warping at high resolution.
        Negative values are zeroed in ``image`` only.

    Raises:
        ValueError: if ``model_cache`` does not enable both the affine and the
            TPS model; both are needed to produce the returned arrays.
    """
    import torch  # local import: only needed for torch.no_grad()

    model_aff, model_tps, do_aff, do_tps, use_cuda = model_cache
    if not (do_aff and do_tps):
        raise ValueError(
            'runCnn needs both the affine and the TPS model '
            '(do_aff={}, do_tps={})'.format(do_aff, do_tps))

    affTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda)
    affTnf_high_res = GeometricTnf_high_res(geometric_model='affine', use_cuda=use_cuda)
    tpsTnf_high_res = GeometricTnf_high_res(geometric_model='tps', use_cuda=use_cuda)

    def to_device(tensor):
        return tensor.cuda() if use_cuda else tensor

    def threshold_mask(image):
        # Pixels with any channel > 5 become 255 in all channels; others keep their value.
        mask = np.copy(image)
        mask[np.any(mask > 5, axis=-1)] = 255
        return mask

    def to_numpy(tensor):
        # Un-normalize and convert 1xCxHxW -> (H, W, C) numpy.
        return normalize_image(tensor, forward=False).data.squeeze(0) \
            .transpose(0, 1).transpose(1, 2).cpu().numpy()

    source_image = io.imread(source_image_path)
    target_image = io.imread(target_image_path)
    if target_image.ndim == 2:
        # Copy the single-channel (MRI) target into 3 integer channels.
        target_image = np.repeat(target_image[:, :, np.newaxis], 3, axis=2).astype(int)

    source_mask_var = to_device(preprocess_image(threshold_mask(source_image)))
    target_mask_var = to_device(preprocess_image(threshold_mask(target_image)))
    source_var = to_device(preprocess_image(source_image))
    target_var = to_device(preprocess_image(target_image))
    source_var_high_res = to_device(preprocess_image_high_res(source_image))
    regions_high_res = [to_device(preprocess_image_high_res(region))
                        for region in (region01, region00, region10, region09)]

    model_aff.eval()
    model_tps.eval()
    with torch.no_grad():  # inference only: no autograd graph needed
        # First affine pass: estimated from the masks only.
        theta_aff = model_aff({'source_image': source_mask_var,
                               'target_image': target_mask_var}).view(-1, 2, 3)
        warped_high_res = affTnf_high_res(source_var_high_res, theta_aff)
        warped_low_res = affTnf(source_var, theta_aff)
        warped_regions = [affTnf_high_res(r, theta_aff) for r in regions_high_res]
        warped_mask = affTnf(source_mask_var, theta_aff)

        # Second affine pass: re-register the warped mask and apply the
        # refinement on top of the first warp.
        theta_aff = model_aff({'source_image': warped_mask,
                               'target_image': target_mask_var}).view(-1, 2, 3)
        warped_high_res = affTnf_high_res(warped_high_res, theta_aff)
        warped_low_res = affTnf(warped_low_res, theta_aff)
        warped_regions = [affTnf_high_res(r, theta_aff) for r in warped_regions]

        # TPS: estimated on the low-res affine result, applied at high resolution.
        theta_aff_tps = model_tps({'source_image': warped_low_res,
                                   'target_image': target_var})
        warped_high_res = tpsTnf_high_res(warped_high_res, theta_aff_tps)
        warped_regions = [tpsTnf_high_res(r, theta_aff_tps) for r in warped_regions]

    image_np = to_numpy(warped_high_res)
    image_np[image_np < 0] = 0
    # NOTE(review): region outputs are not clamped at 0, unlike the image.
    # Kept as before; confirm whether that is intended.
    region01_np, region00_np, region10_np, region09_np = [to_numpy(r) for r in warped_regions]
    return image_np, region01_np, region00_np, region10_np, region09_np