Example #1
def test_and_save_alignment(batch, batch_index, model_output, out_dir):
    # load the template image and the batch to align
    img_a = batch['image_a'][0]
    img_b = batch['image_b'][0]

    denorm_a_img = np.array(normalize_image(img_a.unsqueeze(0),
                                            forward=False).squeeze(0).permute(1, 2, 0) * 255,
                            dtype=np.uint8)

    denorm_b_img = np.array(normalize_image(img_b.unsqueeze(0),
                                            forward=False).squeeze(0).permute(1, 2, 0) * 255,
                            dtype=np.uint8)

    # get vertices to warp
    vertices = batch['vertices_a'][0]
    # add ones to warping points
    to_warp_pts = np.hstack([vertices,
                             np.ones(shape=(len(vertices), 1))])

    aff_matrix = model_output[0].reshape(2, 3)

    transform = aff_matrix.detach().numpy()

    # warp points through affine matrix
    warped_pts = transform.dot(to_warp_pts.T).T

    # denormalize warped points
    out_img_y, out_img_x = denorm_b_img.shape[:2]
    src_img_y, src_img_x = denorm_a_img.shape[:2]

    original_pts = np.array([[int(point[0]*src_img_x), int(point[1]*src_img_y)] for point in to_warp_pts],
                            np.int32).reshape((-1, 1, 2))
    to_draw_pts = np.array([[int(point[0]*out_img_x), int(point[1]*out_img_y)] for point in warped_pts],
                           np.int32).reshape((-1, 1, 2))

    drawn_b_image = np.ones(denorm_b_img.shape) * denorm_b_img
    drawn_a_image = np.ones(denorm_a_img.shape) * denorm_a_img

    # draw warped points over template image
    cv2.polylines(drawn_b_image,  [to_draw_pts],
                  True, (0, 0, 255), 7)
    cv2.polylines(drawn_a_image,  [original_pts],
                  True, (0, 0, 255), 7)

    # concatenate A and drawn B and save image
    concat_img = np.concatenate([drawn_a_image, drawn_b_image], axis=1)

    out_path = os.path.join(out_dir, 'drawn_{}.png'.format(batch_index))

    cv2.imwrite(out_path, concat_img)

    return
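For context, a minimal sketch of how this helper might be driven from an evaluation loop; run_alignment_eval, model, and val_loader are illustrative names and are not part of the example above.

# Hypothetical driver loop for test_and_save_alignment (all names below are illustrative).
import os
import torch

def run_alignment_eval(model, val_loader, out_dir):
    os.makedirs(out_dir, exist_ok=True)     # make sure the output folder exists
    model.eval()
    with torch.no_grad():                   # no gradients needed for visualisation
        for batch_index, batch in enumerate(val_loader):
            model_output = model(batch)     # assumed to return per-sample affine parameters
            test_and_save_alignment(batch, batch_index, model_output, out_dir)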
Example #2
def train(epoch,
          model,
          loss_fn,
          optimizer,
          dataloader,
          pair_generation_tnf,
          use_cuda=True,
          log_interval=50,
          logger=None):
    model.train()
    train_loss = 0
    B = len(dataloader)
    for batch_idx, batch in enumerate(dataloader):
        optimizer.zero_grad()
        tnf_batch = pair_generation_tnf(batch)
        theta = model(tnf_batch)
        loss = loss_fn(theta, tnf_batch['theta_GT'])
        loss.backward()
        optimizer.step()
        train_loss += loss.data.cpu().numpy().item()
        if batch_idx % log_interval == 0:
            logger.info(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\t\tLoss: {:.6f}'.format(
                    epoch, batch_idx, len(dataloader),
                    100. * batch_idx / len(dataloader), loss.item()))
            src_img = tnf_batch['source_image'][0].unsqueeze(0)
            tgt_img = tnf_batch['target_image'][0].unsqueeze(0)
            #resizeTgt = GeometricTnf(out_h=tgt_img.shape[2], out_w=tgt_img.shape[3], use_cuda = True)
            warped_image_aff = affTnf(src_img, theta[0].view(-1, 2, 3))
            warped_image_aff_np = normalize_image(warped_image_aff,
                                                  forward=False)

            src_img = normalize_image(src_img,
                                      forward=False).detach().cpu().numpy()
            tgt_img = normalize_image(tgt_img,
                                      forward=False).detach().cpu().numpy()
            warped_image_aff_np = warped_image_aff_np.detach().cpu().numpy()

            img_cat = np.concatenate((src_img, tgt_img, warped_image_aff_np),
                                     axis=3)
            info = {'img': img_cat}
            for tag, images in info.items():
                tb_logger.add_images(tag, images, epoch * B + batch_idx)
            #print('Train Epoch: {} [{}/{} ({:.0f}%)]\t\tLoss: {:.6f}'.format(
            #   epoch, batch_idx , len(dataloader),
            #   100. * batch_idx / len(dataloader), loss.item()))
    train_loss /= len(dataloader)
    print('Train set: Average loss: {:.4f}'.format(train_loss))
    return train_loss
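A hedged sketch of an outer loop that could call train; the Adam optimizer and every constructed object here are assumptions, and train itself also relies on module-level affTnf and tb_logger objects not shown in this snippet.

# Illustrative outer training loop (all constructed objects are placeholders).
import logging
import torch.optim as optim

logger = logging.getLogger(__name__)

def fit(model, loss_fn, dataloader, pair_generation_tnf, num_epochs=20, lr=1e-3):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    best_loss = float('inf')
    for epoch in range(1, num_epochs + 1):
        epoch_loss = train(epoch, model, loss_fn, optimizer, dataloader,
                           pair_generation_tnf, use_cuda=True,
                           log_interval=50, logger=logger)
        best_loss = min(best_loss, epoch_loss)   # keep track of the best average loss
    return best_loss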
Example #3
    def process(self, source_img_path, target_img_path):
        """ Run the model's main alignment process.

        Input:
            source_img_path: path to the image to transform, or an np.array (RGB)
            target_img_path: path to the target image, or an np.array (RGB)

        Output:
            affine_image (np.ndarray, same shape as target): affine transformation result
            affine_tps_image (np.ndarray, same shape as target): affine + TPS transformation result
        """
        # Load data
        src_img, src_shape = self.load_image(source_img_path)
        target_img, target_shape = self.load_image(target_img_path)

        batch = {'source_image': src_img, 'target_image': target_img}

        self.model.eval()

        with torch.no_grad():
            theta_aff, theta_aff_tps = self.model(batch)
        resizeTgt = GeometricTnf(out_h=target_shape[0],
                                 out_w=target_shape[1],
                                 use_cuda=self.use_cuda)

        warped_image_aff = self.affTnf(batch['source_image'],
                                       theta_aff.view(-1, 2, 3))
        warped_image_aff_tps = affTpsTnf(batch['source_image'], theta_aff,
                                         theta_aff_tps)

        warped_image_aff_np = normalize_image(
            resizeTgt(warped_image_aff),
            forward=False).data.squeeze(0).transpose(0, 1).transpose(
                1, 2).cpu().numpy()
        warped_image_aff_tps_np = normalize_image(
            resizeTgt(warped_image_aff_tps),
            forward=False).data.squeeze(0).transpose(0, 1).transpose(
                1, 2).cpu().numpy()

        warped_image_aff_tps_np = warped_image_aff_tps_np.astype(np.uint8)
        warped_image_aff_np = warped_image_aff_np.astype(np.uint8)

        return warped_image_aff_np, warped_image_aff_tps_np
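A possible way to call process on a pair of image files, assuming aligner is an instance of the class this method belongs to; the align_and_save wrapper and the output file names are illustrative.

# Hypothetical usage of process (wrapper name and paths are illustrative).
import cv2

def align_and_save(aligner, source_path, target_path):
    affine_np, affine_tps_np = aligner.process(source_path, target_path)
    # The results are RGB uint8 arrays the size of the target image;
    # convert to BGR before writing with OpenCV.
    cv2.imwrite('aligned_affine.png', cv2.cvtColor(affine_np, cv2.COLOR_RGB2BGR))
    cv2.imwrite('aligned_affine_tps.png', cv2.cvtColor(affine_tps_np, cv2.COLOR_RGB2BGR))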
Example #4
def preprocess_image_high_res_back(image):
    # convert to torch Variable
    image = np.expand_dims(image.transpose((2, 0, 1)), 0)
    image = torch.Tensor(image.astype(np.float32) / 255.0)
    image_var = Variable(image, requires_grad=False)

    # Normalize image
    image_var = normalize_image(image_var)

    return image_var
Example #5
def preprocess_image(image):
    # convert to torch Variable
    image = np.expand_dims(image.transpose((2, 0, 1)), 0)
    image = torch.Tensor(image.astype(np.float32) / 255.0)
    image_var = Variable(image, requires_grad=False)

    # Resize image using bilinear sampling with identity affine tnf
    image_var = resize(image_var)

    # Normalize image
    image_var = normalize_image(image_var)

    return image_var
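A short sketch of feeding an image file through preprocess_image; the load_and_preprocess wrapper and the CUDA flag are illustrative additions.

# Illustrative preprocessing helper (wrapper name is a placeholder).
from skimage import io

def load_and_preprocess(path, use_cuda=False):
    image = io.imread(path)              # H x W x 3 RGB uint8 array
    image_var = preprocess_image(image)  # 1 x 3 x out_h x out_w normalized tensor
    return image_var.cuda() if use_cuda else image_var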
Example #6
def preprocess_image(image,means,stds):
    """
    Preprocesses the image for warping
    """
    # convert to torch Variable
    image = np.expand_dims(image.transpose((2,0,1)),0)
    image = torch.Tensor(image.astype(np.float32)/255.0)
    image_var = Variable(image,requires_grad=False)
    # Resize image using bilinear sampling with identity affine tnf
    image_var = resizeCNN(image_var)
    # Normalize image
    image_var = normalize_image(image_var,mean=means,std=stds)
    return image_var
Example #7
def im_show_1(image, title, rows, cols, index):
    """ Show image (transfer tensor to numpy first) """
    # image = image.permute(1, 2, 0).cpu().numpy()
    # mean = np.array([0.485, 0.456, 0.406])
    # std = np.array([0.229, 0.224, 0.225])
    # image = std * image + mean
    image = normalize_image(image, forward=False)
    image = image.permute(1, 2, 0).cpu().numpy()
    ax = plt.subplot(rows, cols, index)
    ax.set_title(title)
    ax.imshow(image.clip(0, 1))

    return ax
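One way im_show_1 might be used to lay out a source/target/warped triplet in a single figure; show_triplet and its tensor arguments are illustrative.

# Illustrative figure layout (the image tensors are assumed to be normalized 3 x H x W tensors).
import matplotlib.pyplot as plt

def show_triplet(source_tensor, target_tensor, warped_tensor):
    plt.figure(figsize=(9, 3))
    im_show_1(source_tensor, 'src', 1, 3, 1)
    im_show_1(target_tensor, 'tgt', 1, 3, 2)
    im_show_1(warped_tensor, 'warped', 1, 3, 3)
    plt.show()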
Example #8
def preprocess_image_high_res(image):
    resizeCNN = GeometricTnf(out_h=half_out_size * 2,
                             out_w=half_out_size * 2,
                             use_cuda=False)

    # convert to torch Variable
    image = np.expand_dims(image.transpose((2, 0, 1)), 0)
    image = torch.Tensor(image.astype(np.float32) / 255.0)
    image_var = Variable(image, requires_grad=False)

    # Resize image using bilinear sampling with identity affine tnf
    image_var = resizeCNN(image_var)

    # Normalize image
    image_var = normalize_image(image_var)

    return image_var
Example #9
    def __call__(self, fname, fname2):

        image = io.imread(fname)
        image = np.expand_dims(image.transpose((2, 0, 1)), 0)
        image = torch.Tensor(image.astype(np.float32))
        image_var = Variable(image, requires_grad=False)
        image_A = self.affTnf(image_var).data.squeeze(0)
        image_A_demo = self.affTnf_demo(image_var).data.squeeze(0)
        image_A_origin = self.affTnf_origin(image_var).data.squeeze(0)

        image2 = io.imread(fname2)
        image2 = np.expand_dims(image2.transpose((2, 0, 1)), 0)
        image2 = torch.Tensor(image2.astype(np.float32))
        image_var2 = Variable(image2, requires_grad=False)
        image_B = self.affTnf(image_var2).data.squeeze(0)

        sample = {'source_image': image_A, 'target_image': image_B, 'demo': image_A_demo, 'origin_image': image_A_origin}

        sample = self.transform(sample)

        batchTensorToVars = BatchTensorToVars(use_cuda=self.use_cuda)

        batch = batchTensorToVars(sample)
        batch['source_image'] = torch.unsqueeze(batch['source_image'],0)
        batch['target_image'] = torch.unsqueeze(batch['target_image'],0)
        batch['origin_image'] = torch.unsqueeze(batch['origin_image'],0)
        batch['demo'] = torch.unsqueeze(batch['demo'],0)

        if self.do_aff:
            self.model_aff.eval()

        # Evaluate models
        if self.do_aff:
            theta_aff = self.model_aff(batch)
            warped_image_aff_demo = self.affTnf_demo(batch['demo'], theta_aff.view(-1, 2, 3))

        if self.do_aff:
            warped_image_aff_demo = normalize_image(warped_image_aff_demo, forward=False)
            warped_image_aff_demo = warped_image_aff_demo.data.squeeze(0).transpose(0, 1).transpose(1, 2).cpu().numpy()

        print("Done")
        imsave('result.jpg', warped_image_aff_demo)

        return warped_image_aff_demo
Example #10
    target_points = batch['target_points']

    # warp points with estimated transformations
    target_points_norm = PointsToUnitCoords(target_points, target_im_size)

    model.eval()

    # Evaluate model
    theta_aff, theta_aff_tps = model(batch)

    warped_image_aff = affTnf(batch['source_image'], theta_aff.view(-1, 2, 3))
    warped_image_aff_tps = affTpsTnf(batch['source_image'], theta_aff,
                                     theta_aff_tps)

    # Un-normalize images and convert to numpy
    source_image = normalize_image(batch['source_image'], forward=False)
    source_image = source_image.data.squeeze(0).transpose(0, 1).transpose(
        1, 2).cpu().numpy()
    target_image = normalize_image(batch['target_image'], forward=False)
    target_image = target_image.data.squeeze(0).transpose(0, 1).transpose(
        1, 2).cpu().numpy()

    warped_image_aff = normalize_image(warped_image_aff, forward=False)
    warped_image_aff = warped_image_aff.data.squeeze(0).transpose(
        0, 1).transpose(1, 2).cpu().numpy()

    warped_image_aff_tps = normalize_image(warped_image_aff_tps, forward=False)
    warped_image_aff_tps = warped_image_aff_tps.data.squeeze(0).transpose(
        0, 1).transpose(1, 2).cpu().numpy()

    # check if display is available
Example #11
def log_images(tb_writer, batch, tnf_matrices, counter, tag=None, n_max=1):
    """
    Fn to log image batches

    :param tb_writer: Summary Writer
    :param batch: Batch of samples
    :param tnf_matrices: Batch of transformations to apply
    :param counter: Epoch index
    :param tag: Default None, if a string is specified tags the log
     with it as a prefix
    :param n_max: Maximum numbers of images per batch to display
    :return: None
    """

    images = zip(batch['image_a'], batch['image_b'], batch['vertices_a'], tnf_matrices)

    for idx, (img_a, img_b, vertices, aff_matrix) in enumerate(images):
        if idx < n_max:

            denorm_a_img = normalize_image(img_a.unsqueeze(0),
                                           forward=False)
            denorm_b_img = normalize_image(img_b.unsqueeze(0),
                                           forward=False)
            transform = aff_matrix.cpu().detach().reshape([2, 3]).numpy()

            # get vertices to warp
            vertices = np.array(vertices)
            # add ones to warping points
            to_warp_pts = np.hstack([vertices,
                                     np.ones(shape=(len(vertices), 1))])

            # warp points through affine matrix
            warped_pts = transform.dot(to_warp_pts.T).T

            # denormalize warped points
            out_img_y, out_img_x = denorm_b_img.shape[2:]
            src_img_y, src_img_x = denorm_a_img.shape[2:]

            original_pts = np.array([[int(point[0] * src_img_x), int(point[1] * src_img_y)] for point in to_warp_pts],
                                    np.int32).reshape((-1, 1, 2))
            to_draw_pts = np.array([[int(point[0] * out_img_x), int(point[1] * out_img_y)] for point in warped_pts],
                                   np.int32).reshape((-1, 1, 2))

            drawn_a_img = np.moveaxis(denorm_a_img.squeeze().numpy(), 0, 2)
            drawn_b_img = np.moveaxis(denorm_b_img.squeeze().numpy(), 0, 2)

            drawn_a_img = np.ones(drawn_a_img.shape) * np.array(drawn_a_img * 255,
                                                                dtype=np.uint8)
            drawn_b_img = np.ones(drawn_b_img.shape) * np.array(drawn_b_img * 255,
                                                                dtype=np.uint8)

            # draw warped points over template image
            cv2.polylines(drawn_b_img, [to_draw_pts],
                          True, (0, 0, 255), 7)
            cv2.polylines(drawn_a_img, [original_pts],
                          True, (0, 0, 255), 7)

            # concatenate A and drawn B
            concat_img = cat([Tensor(drawn_a_img).double() / 255,
                              Tensor(drawn_b_img).double() / 255], 1)

            if not tag:
                log_name = 'A warp on B'
            elif isinstance(tag, str):
                log_name = '{}\tA warp on B'.format(tag)
            else:
                raise ValueError("Unexpected type for 'tag', must be of type string.")

            # log image
            tb_writer.add_images(log_name,
                                 concat_img.permute(2, 0, 1).unsqueeze(0),
                                 counter)

        else:
            break
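A hedged sketch of calling log_images from a validation step; the SummaryWriter log directory, the log_validation_batch wrapper, and the shape of theta_batch are assumptions.

# Illustrative logging call (wrapper name and log directory are placeholders).
from torch.utils.tensorboard import SummaryWriter

def log_validation_batch(batch, theta_batch, epoch):
    # theta_batch is assumed to be a (B, 6) tensor of predicted affine parameters.
    tb_writer = SummaryWriter(log_dir='runs/alignment')
    log_images(tb_writer, batch, theta_batch, counter=epoch, tag='val', n_max=2)
    tb_writer.close()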
Example #12
theta_aff_inv = torch.cat((theta_aff_inv, (torch.Tensor(
    [0, 0, 1]).to('cuda').unsqueeze(0).unsqueeze(1).expand(batch_size, 1, 3))),
                          1)
theta_aff_2 = theta_aff_inv.inverse().contiguous().view(-1, 9)[:, :6]

theta_aff_ensemble = (theta_aff + theta_aff_2) / 2  # Ensemble

### Process result
warped_image_aff = affTnf(Im2Tensor(source_image),
                          theta_aff_ensemble.view(-1, 2, 3))
result_aff_np = warped_image_aff.squeeze(0).transpose(0, 1).transpose(
    1, 2).cpu().detach().numpy()
io.imsave('results/aff.jpg', result_aff_np)
"""2nd Affine"""
# Preprocess source_image_2
source_image_2 = normalize_image(resize(warped_image_aff.cpu()))
if use_cuda:
    source_image_2 = source_image_2.cuda()
theta_aff_aff, theta_aff_aff_inv = model({
    'source_image': source_image_2,
    'target_image': batch['target_image']
})

# Calculate theta_aff_2
batch_size = theta_aff_aff.size(0)
theta_aff_aff_inv = theta_aff_aff_inv.view(-1, 2, 3)
theta_aff_aff_inv = torch.cat((theta_aff_aff_inv, (torch.Tensor(
    [0, 0, 1]).to('cuda').unsqueeze(0).unsqueeze(1).expand(batch_size, 1, 3))),
                              1)
theta_aff_aff_2 = theta_aff_aff_inv.inverse().contiguous().view(-1, 9)[:, :6]
Example #13
# Evaluate models
if do_aff:
    theta_aff=model_aff(batch)
    warped_image_aff = affTnf(batch['source_image'],theta_aff.view(-1,2,3))

if do_tps:
    theta_tps=model_tps(batch)
    warped_image_tps = tpsTnf(batch['source_image'],theta_tps)

if do_aff and do_tps:
    theta_aff_tps=model_tps({'source_image': warped_image_aff, 'target_image': batch['target_image']})        
    warped_image_aff_tps = tpsTnf(warped_image_aff,theta_aff_tps)

# Un-normalize images and convert to numpy
if do_aff:
    warped_image_aff_np = normalize_image(resizeTgt(warped_image_aff),forward=False).data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy()

if do_tps:
    warped_image_tps_np = normalize_image(resizeTgt(warped_image_tps),forward=False).data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy()

if do_aff and do_tps:
    warped_image_aff_tps_np = normalize_image(resizeTgt(warped_image_aff_tps),forward=False).data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy()

N_subplots = 2+int(do_aff)+int(do_tps)+int(do_aff and do_tps)
fig, axs = plt.subplots(1,N_subplots)
axs[0].imshow(source_image)
axs[0].set_title('src')
axs[1].imshow(target_image)
axs[1].set_title('tgt')
subplot_idx = 2
if do_aff:
Example #14
def save_warped(source_name, target_name, savename, model_aff, model_tps, demo=False):
    """
    Aligns a source image to a target image, and saves it.
    """
    do_aff, do_tps = model_aff != '', model_tps != ''
    if not (do_aff and do_tps):
        print("No model found. Exiting.")
        return
    source_image = io.imread(source_name)
    target_image = io.imread(target_name)
    # Use the preprocess method declared above to resize and normalize
    # with the given means and stds. By default these are the ImageNet
    # means/stds; other values can cause unexplained artifacts.
    source_image_var = preprocess_image(source_image,means=means, stds=stds)
    target_image_var = preprocess_image(target_image,means=means, stds=stds)
    
    if use_cuda:
        source_image_var = source_image_var.cuda()
        target_image_var = target_image_var.cuda()
    # Create a "batch" (i.e. a source/target pair) for the models below
    batch = {'source_image': source_image_var, 'target_image': target_image_var}
    # resizeTgt resizes a given input to the target image's dimensions
    resizeTgt = GeometricTnf(out_h=target_image.shape[0], out_w=target_image.shape[1], use_cuda=use_cuda)
    # Set the models to eval mode
    if do_aff:
        model_aff.eval()
    if do_tps:
        model_tps.eval()      
    # Evaluate models and get the thetas
    if do_aff:
        theta_aff=model_aff(batch)
        warped_image_aff = affTnf(batch['source_image'],theta_aff.view(-1,2,3))
    
    if do_tps:
        theta_tps=model_tps(batch)
        warped_image_tps = tpsTnf(batch['source_image'],theta_tps)
    
    if do_aff and do_tps:
        theta_aff_tps=model_tps({'source_image': warped_image_aff, 'target_image': batch['target_image']})        
        warped_image_aff_tps = tpsTnf(warped_image_aff,theta_aff_tps)
    if do_aff:
        warped_image_aff_np = normalize_image(resizeTgt(warped_image_aff),forward=False).data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy()
    
    if do_tps:
        warped_image_tps_np = normalize_image(resizeTgt(warped_image_tps),forward=False).data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy()
    
    if do_aff and do_tps:
        warped_image_aff_tps_np = normalize_image(resizeTgt(warped_image_aff_tps),forward=False).data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy()
        x = np.clip(warped_image_aff_tps_np,0,1)

    if not demo:
        plt.imsave(savename, x)
    else:
        N_subplots = 2+int(do_aff)+int(do_tps)+int(do_aff and do_tps)
        fig, axs = plt.subplots(1,N_subplots)
        axs[0].imshow(source_image)
        axs[0].set_title('src')
        axs[1].imshow(target_image)
        axs[1].set_title('tgt')
        subplot_idx = 2
        if do_aff:
            axs[subplot_idx].imshow(warped_image_aff_np)
            axs[subplot_idx].set_title('aff')
            subplot_idx +=1 
        if do_tps:
            axs[subplot_idx].imshow(warped_image_tps_np)
            axs[subplot_idx].set_title('tps')
            subplot_idx +=1 
        if do_aff and do_tps:
            axs[subplot_idx].imshow(warped_image_aff_tps_np)
            axs[subplot_idx].set_title('aff+tps')
        
        for i in range(N_subplots):
            axs[i].axis('off')
        
        fig.set_dpi(330)
        plt.show()
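A small sketch of driving save_warped over several image pairs with pre-loaded models; batch_save_warped and the structure of pairs are illustrative.

# Illustrative batch driver (wrapper name and pair list are placeholders).
def batch_save_warped(pairs, model_aff, model_tps):
    # pairs is assumed to be a list of (source_path, target_path, output_path) tuples.
    for source_name, target_name, savename in pairs:
        save_warped(source_name, target_name, savename,
                    model_aff=model_aff, model_tps=model_tps, demo=False)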
Example #15
    target_points = batch['target_points']
    
    # warp points with estimated transformations
    target_points_norm = PointsToUnitCoords(target_points,target_im_size)
    
    model.eval()
    
    # Evaluate model
    theta_aff,theta_aff_tps=model(batch)

    warped_image_aff = affTnf(batch['source_image'],theta_aff.view(-1,2,3))
    warped_image_aff_tps = affTpsTnf(batch['source_image'],theta_aff, theta_aff_tps)


    # Un-normalize images and convert to numpy
    source_image = normalize_image(batch['source_image'],forward=False)
    source_image = source_image.data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy()
    target_image = normalize_image(batch['target_image'],forward=False)
    target_image = target_image.data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy()
    
    warped_image_aff = normalize_image(warped_image_aff,forward=False)
    warped_image_aff = warped_image_aff.data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy()

    warped_image_aff_tps = normalize_image(warped_image_aff_tps,forward=False)
    warped_image_aff_tps = warped_image_aff_tps.data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy()

    # check if display is available
    exit_val = os.system('python -c "import matplotlib.pyplot as plt;plt.figure()"  > /dev/null 2>&1')
    display_avail = exit_val==0

    if display_avail:
Example #16
def runCnn(model_cache, source_image_path, target_image_path, region01,
           region00, region10, region09):
    model_aff, model_tps, do_aff, do_tps, use_cuda = model_cache

    tpsTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
    affTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda)

    tpsTnf_high_res = GeometricTnf_high_res(geometric_model='tps',
                                            use_cuda=use_cuda)
    affTnf_high_res = GeometricTnf_high_res(geometric_model='affine',
                                            use_cuda=use_cuda)

    source_image = io.imread(source_image_path)
    target_image = io.imread(target_image_path)

    # copy MRI image to 3 channels
    target_image3d = np.zeros(
        (target_image.shape[0], target_image.shape[1], 3), dtype=int)
    target_image3d[:, :, 0] = target_image
    target_image3d[:, :, 1] = target_image
    target_image3d[:, :, 2] = target_image
    target_image = np.copy(target_image3d)

    #### begin new code, affine registration using the masks only
    source_image_mask = np.copy(source_image)
    source_image_mask[np.any(source_image_mask > 5, axis=-1)] = 255
    target_image_mask = np.copy(target_image)
    target_image_mask[np.any(target_image_mask > 5, axis=-1)] = 255
    source_image_mask_var = preprocess_image(source_image_mask)
    target_image_mask_var = preprocess_image(target_image_mask)

    if use_cuda:
        source_image_mask_var = source_image_mask_var.cuda()
        target_image_mask_var = target_image_mask_var.cuda()
    batch_mask = {
        'source_image': source_image_mask_var,
        'target_image': target_image_mask_var
    }
    #### end new code

    source_image_var = preprocess_image(source_image)
    target_image_var = preprocess_image(target_image)
    region01_image_var = preprocess_image(region01)
    region00_image_var = preprocess_image(region00)
    region10_image_var = preprocess_image(region10)
    region09_image_var = preprocess_image(region09)

    source_image_var_high_res = preprocess_image_high_res(source_image)
    target_image_var_high_res = preprocess_image_high_res(target_image)
    region01_image_var_high_res = preprocess_image_high_res(region01)
    region00_image_var_high_res = preprocess_image_high_res(region00)
    region10_image_var_high_res = preprocess_image_high_res(region10)
    region09_image_var_high_res = preprocess_image_high_res(region09)

    if use_cuda:
        source_image_var = source_image_var.cuda()
        target_image_var = target_image_var.cuda()
        region01_image_var = region01_image_var.cuda()
        region00_image_var = region00_image_var.cuda()
        region10_image_var = region10_image_var.cuda()
        region09_image_var = region09_image_var.cuda()
        source_image_var_high_res = source_image_var_high_res.cuda()
        target_image_var_high_res = target_image_var_high_res.cuda()
        region01_image_var_high_res = region01_image_var_high_res.cuda()
        region00_image_var_high_res = region00_image_var_high_res.cuda()
        region10_image_var_high_res = region10_image_var_high_res.cuda()
        region09_image_var_high_res = region09_image_var_high_res.cuda()

    batch = {
        'source_image': source_image_var,
        'target_image': target_image_var
    }
    batch_high_res = {
        'source_image': source_image_var_high_res,
        'target_image': target_image_var_high_res
    }

    if do_aff:
        model_aff.eval()
    if do_tps:
        model_tps.eval()

    # Evaluate models
    if do_aff:
        #theta_aff=model_aff(batch)
        #### affine registration using the masks only
        theta_aff = model_aff(batch_mask)
        warped_image_aff_high_res = affTnf_high_res(
            batch_high_res['source_image'], theta_aff.view(-1, 2, 3))
        warped_image_aff = affTnf(batch['source_image'],
                                  theta_aff.view(-1, 2, 3))
        warped_region01_aff_high_res = affTnf_high_res(
            region01_image_var_high_res, theta_aff.view(-1, 2, 3))
        warped_region00_aff_high_res = affTnf_high_res(
            region00_image_var_high_res, theta_aff.view(-1, 2, 3))
        warped_region10_aff_high_res = affTnf_high_res(
            region10_image_var_high_res, theta_aff.view(-1, 2, 3))
        warped_region09_aff_high_res = affTnf_high_res(
            region09_image_var_high_res, theta_aff.view(-1, 2, 3))

        ###>>>>>>>>>>>> do affine registration one more time<<<<<<<<<<<<
        warped_mask_aff = affTnf(source_image_mask_var,
                                 theta_aff.view(-1, 2, 3))
        theta_aff = model_aff({
            'source_image': warped_mask_aff,
            'target_image': target_image_mask_var
        })
        warped_image_aff_high_res = affTnf_high_res(warped_image_aff_high_res,
                                                    theta_aff.view(-1, 2, 3))
        warped_image_aff = affTnf(warped_image_aff, theta_aff.view(-1, 2, 3))
        warped_region01_aff_high_res = affTnf_high_res(
            warped_region01_aff_high_res, theta_aff.view(-1, 2, 3))
        warped_region00_aff_high_res = affTnf_high_res(
            warped_region00_aff_high_res, theta_aff.view(-1, 2, 3))
        warped_region10_aff_high_res = affTnf_high_res(
            warped_region10_aff_high_res, theta_aff.view(-1, 2, 3))
        warped_region09_aff_high_res = affTnf_high_res(
            warped_region09_aff_high_res, theta_aff.view(-1, 2, 3))
        ###>>>>>>>>>>>> do affine registration one more time<<<<<<<<<<<<

    if do_aff and do_tps:
        theta_aff_tps = model_tps({
            'source_image': warped_image_aff,
            'target_image': batch['target_image']
        })
        warped_image_aff_tps_high_res = tpsTnf_high_res(
            warped_image_aff_high_res, theta_aff_tps)
        warped_region01_aff_tps_high_res = tpsTnf_high_res(
            warped_region01_aff_high_res, theta_aff_tps)
        warped_region00_aff_tps_high_res = tpsTnf_high_res(
            warped_region00_aff_high_res, theta_aff_tps)
        warped_region10_aff_tps_high_res = tpsTnf_high_res(
            warped_region10_aff_high_res, theta_aff_tps)
        warped_region09_aff_tps_high_res = tpsTnf_high_res(
            warped_region09_aff_high_res, theta_aff_tps)

    # Un-normalize images and convert to numpy
    if do_aff:
        warped_image_aff_np_high_res = normalize_image(
            warped_image_aff_high_res,
            forward=False).data.squeeze(0).transpose(0, 1).transpose(
                1, 2).cpu().numpy()
        warped_region01_aff_np_high_res = normalize_image(
            warped_region01_aff_high_res,
            forward=False).data.squeeze(0).transpose(0, 1).transpose(
                1, 2).cpu().numpy()
        warped_region00_aff_np_high_res = normalize_image(
            warped_region00_aff_high_res,
            forward=False).data.squeeze(0).transpose(0, 1).transpose(
                1, 2).cpu().numpy()
        warped_region10_aff_np_high_res = normalize_image(
            warped_region10_aff_high_res,
            forward=False).data.squeeze(0).transpose(0, 1).transpose(
                1, 2).cpu().numpy()
        warped_region09_aff_np_high_res = normalize_image(
            warped_region09_aff_high_res,
            forward=False).data.squeeze(0).transpose(0, 1).transpose(
                1, 2).cpu().numpy()

    if do_aff and do_tps:
        warped_image_aff_tps_np_high_res = normalize_image(
            warped_image_aff_tps_high_res,
            forward=False).data.squeeze(0).transpose(0, 1).transpose(
                1, 2).cpu().numpy()
        warped_region01_aff_tps_np_high_res = normalize_image(
            warped_region01_aff_tps_high_res,
            forward=False).data.squeeze(0).transpose(0, 1).transpose(
                1, 2).cpu().numpy()
        warped_region00_aff_tps_np_high_res = normalize_image(
            warped_region00_aff_tps_high_res,
            forward=False).data.squeeze(0).transpose(0, 1).transpose(
                1, 2).cpu().numpy()
        warped_region10_aff_tps_np_high_res = normalize_image(
            warped_region10_aff_tps_high_res,
            forward=False).data.squeeze(0).transpose(0, 1).transpose(
                1, 2).cpu().numpy()
        warped_region09_aff_tps_np_high_res = normalize_image(
            warped_region09_aff_tps_high_res,
            forward=False).data.squeeze(0).transpose(0, 1).transpose(
                1, 2).cpu().numpy()

    warped_image_aff_np_high_res[warped_image_aff_np_high_res < 0] = 0
    warped_image_aff_tps_np_high_res[warped_image_aff_tps_np_high_res < 0] = 0

    return (warped_image_aff_tps_np_high_res,
            warped_region01_aff_tps_np_high_res,
            warped_region00_aff_tps_np_high_res,
            warped_region10_aff_tps_np_high_res,
            warped_region09_aff_tps_np_high_res)
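A hedged sketch of consuming the high-resolution outputs of runCnn and writing the registered image to disk; the wrapper name, output path, and the assumption that the de-normalized values lie roughly in [0, 1] are illustrative.

# Illustrative post-processing of runCnn outputs (names and paths are placeholders).
import numpy as np
from skimage import io

def save_registered_histology(model_cache, source_path, target_path,
                              region01, region00, region10, region09,
                              out_path='registered_histology.png'):
    warped, r01, r00, r10, r09 = runCnn(model_cache, source_path, target_path,
                                        region01, region00, region10, region09)
    # Clip to [0, 1] and scale to 8-bit before saving.
    io.imsave(out_path, (np.clip(warped, 0, 1) * 255).astype(np.uint8))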