# Instantiate point transformer (warps keypoint coordinates rather than images)
pt = PointTnf(use_cuda=use_cuda)

# Instantiate image transformers: one thin-plate-spline warper and one affine warper
tpsTnf = GeometricTnf(geometric_model='tps', use_cuda=use_cuda)
affTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda)

for i, batch in enumerate(dataloader):
    # get random batch of size 1
    batch = batchTensorToVars(batch)

    source_im_size = batch['source_im_size']
    target_im_size = batch['target_im_size']

    if do_aff:
        model_aff.eval()
    if do_tps:
        model_tps.eval()

    # Evaluate models
    if do_aff:
        theta_aff = model_aff(batch)
        warped_image_aff = affTnf(batch['source_image'],
                                  theta_aff.view(-1, 2, 3))

    if do_tps:
        theta_tps = model_tps(batch)
        warped_image_tps = tpsTnf(batch['source_image'], theta_tps)

    if do_aff and do_tps:
        theta_aff_tps = model_tps({
示例#2
0
        if batch_idx % log_interval == 0:
            print(mode.capitalize() +
                  ' Epoch: {} [{}/{} ({:.0f}%)]\t\tLoss: {:.6f}'.format(
                      epoch, batch_idx, len(dataloader), 100. * batch_idx /
                      len(dataloader), loss_np))
    epoch_loss /= len(dataloader)
    print(mode.capitalize() + ' set: Average loss: {:.4f}'.format(epoch_loss))
    return epoch_loss


# Per-epoch loss history: entry [epoch - 1] stores the value returned by
# process_epoch for that epoch (filled in by the training loop below).
train_loss = np.zeros(args.num_epochs)
test_loss = np.zeros(args.num_epochs)

print('Starting training...')

# NOTE(review): the epoch loop below calls model.train() / model.eval()
# itself every epoch, so this call appears redundant — confirm before removing.
model.eval()

for epoch in range(1, args.num_epochs + 1):
    model.train()
    train_loss[epoch - 1] = process_epoch('train',
                                          epoch,
                                          model,
                                          loss,
                                          optimizer,
                                          dataloader,
                                          pair_generation_tnf,
                                          log_interval=100)
    model.eval()
    test_loss[epoch - 1] = process_epoch('test',
                                         epoch,
                                         model,
示例#3
0
class CNNGeometricMatcher:
    """Decide whether two images match using CNN geometric alignment.

    Loads a pretrained affine model and/or a TPS model, estimates the
    transformation between the source and target images of a batch, extracts
    mutually matched keypoints from the models' correlation maps and measures
    how well the estimated transformation reprojects those keypoints.

    Args:
        use_extracted_features: If True, the affine model is built as a
            ``CNNGeometricRegression`` and the TPS stage is disabled.
        geometric_affine_model: Checkpoint for the affine model (anything
            ``torch.load`` accepts), or None to skip the affine stage.
        geometric_tps_model: Checkpoint for the TPS model, or None to skip
            the TPS stage.
        arch: Feature-extractor architecture name forwarded to the models.
        featext_weights: Feature-extractor weights forwarded to the models.
        min_mutual_keypoints: Pairs with fewer mutual keypoints than this are
            reported as not matched without computing a reprojection error.
        min_reprojection_error: Despite its name, this is the *maximum*
            reprojection error for which a pair is reported as matched.
    """

    def __init__(self,
                 use_extracted_features=False,
                 geometric_affine_model=None,
                 geometric_tps_model=None,
                 arch='resnet18',
                 featext_weights=None,
                 min_mutual_keypoints=6,
                 min_reprojection_error=200):
        self.min_mutual_keypoints = min_mutual_keypoints
        self.min_reprojection_error = min_reprojection_error
        self.__do_affine = geometric_affine_model is not None
        # The TPS stage is only available when not using extracted features.
        self.__do_tps = (not use_extracted_features
                         and geometric_tps_model is not None)
        # NOTE(review): this class mixes the module-level names ``_use_cuda``
        # and ``use_cuda`` (neither is a parameter). Both must exist at module
        # scope; confirm they agree, or unify them.
        self.__affTnf = GeometricTnf(geometric_model='affine',
                                     use_cuda=_use_cuda)

        if self.__do_affine:
            # map_location keeps every tensor on the CPU regardless of the
            # device it was saved from.
            checkpoint = torch.load(geometric_affine_model,
                                    map_location=lambda storage, loc: storage)
            print('Loading CNN Affine Geometric Model')
            if use_extracted_features:
                self.model_affine = CNNGeometricRegression(
                    use_cuda=use_cuda,
                    geometric_model='affine',
                    arch=arch,
                    featext_weights=featext_weights)
                # The checkpoint may contain keys this model does not have.
                self._load_matching_weights(self.model_affine,
                                            checkpoint['state_dict'])
            else:
                self.model_affine = CNNGeometric(
                    use_cuda=_use_cuda,
                    geometric_model='affine',
                    arch=arch,
                    featext_weights=featext_weights)
                self.model_affine.load_state_dict(checkpoint['state_dict'])

            self.model_affine.eval()

        if self.__do_tps:
            self.model_tps = CNNGeometric(use_cuda=use_cuda,
                                          geometric_model='tps',
                                          arch=arch,
                                          featext_weights=featext_weights)
            checkpoint = torch.load(geometric_tps_model,
                                    map_location=lambda storage, loc: storage)
            print('Loading CNN TPS Geometric Model')
            self._load_matching_weights(self.model_tps,
                                        checkpoint['state_dict'])
            self.model_tps.eval()

        self.pt = PointTnf(use_cuda=_use_cuda)

    @staticmethod
    def _load_matching_weights(model, state_dict):
        """Load the entries of ``state_dict`` whose keys exist in ``model``.

        Keys the model does not have are ignored; model parameters that are
        absent from ``state_dict`` keep their current values.
        """
        model_dict = model.state_dict()
        model_dict.update(
            {k: v for k, v in state_dict.items() if k in model_dict})
        model.load_state_dict(model_dict)

    def run(self, batch):
        """Align one image pair and decide whether it matches.

        Args:
            batch: Dict with keys 'source_image', 'target_image',
                'source_im_size' and 'target_im_size'. Only the first
                sample's sizes are used, so this appears to assume a batch
                of size 1.

        Returns:
            Tuple ``(reprojection_error, matched, num_mutual_keypoints)``.
            When fewer than ``min_mutual_keypoints`` mutual keypoints are
            found, ``reprojection_error`` is -1 and ``matched`` is False.

        Raises:
            RuntimeError: If neither an affine nor a TPS model was loaded
                (otherwise there would be no correlation maps to use).
        """
        if not (self.__do_affine or self.__do_tps):
            raise RuntimeError(
                'CNNGeometricMatcher.run requires an affine or a TPS model')

        if self.__do_affine:
            theta_aff, correlationAB, correlationBA = self.model_affine(batch)

        # When both stages run, the TPS model's correlation maps overwrite
        # the affine ones and are the ones used for keypoint matching.
        if self.__do_tps:
            if self.__do_affine:
                # Run TPS on the source image pre-warped by the affine estimate.
                warped_image_aff = self.__affTnf(batch['source_image'],
                                                 theta_aff.view(-1, 2, 3))
                theta_affine_tps, correlationAB, correlationBA = self.model_tps(
                    {
                        'source_image': warped_image_aff,
                        'target_image': batch['target_image']
                    })
            else:
                theta_tps, correlationAB, correlationBA = self.model_tps(batch)

        keypoints_A, keypoints_B = find_mutual_matached_keypoints(
            correlationAB, correlationBA)
        num_mutual_keypoints = keypoints_A.shape[0]

        if num_mutual_keypoints < self.min_mutual_keypoints:
            return -1, False, num_mutual_keypoints

        source_im_size = batch['source_im_size']
        target_im_size = batch['target_im_size']

        # .cpu() is a no-op for CPU tensors and required before .numpy()
        # when the tensors live on the GPU.
        source_im_shape_np = source_im_size.data.cpu().numpy()
        target_im_shape_np = target_im_size.data.cpu().numpy()

        # Shape of a single correlation map (leading batch dimension dropped);
        # read from the size instead of copying the whole map to host memory.
        tensor_shape = tuple(correlationAB.size()[1:])

        # NOTE(review): the first two size entries are passed swapped —
        # presumably (h, w) -> (w, h); confirm against tensorPointstoPixels.
        im_keypointsA = tensorPointstoPixels(
            keypoints_A,
            tensor_size=tensor_shape,
            im_size=(source_im_shape_np[0][1], source_im_shape_np[0][0]))
        im_keypointsB = tensorPointstoPixels(
            keypoints_B,
            tensor_size=tensor_shape,
            im_size=(target_im_shape_np[0][1], target_im_shape_np[0][0]))

        # NOTE(review): these are created on the CPU; if the size tensors are
        # on the GPU the point utilities below may see mixed devices.
        torch_keypointsA_var = Variable(
            torch.Tensor(im_keypointsA.reshape(1, 2, -1).astype(np.float32)))
        torch_keypointsB_var = Variable(
            torch.Tensor(im_keypointsB.reshape(1, 2, -1).astype(np.float32)))

        target_points_norm = PointsToUnitCoords(torch_keypointsB_var,
                                                target_im_size)

        # Map the target keypoints into the source frame with the estimated
        # transformation(s).
        if self.__do_affine and self.__do_tps:
            # Applied as TPS first, then affine (reverse of the image warping
            # order above).
            warped_points_aff_tps_norm = self.pt.tpsPointTnf(
                theta_affine_tps, target_points_norm)
            warped_points_norm = self.pt.affPointTnf(
                theta_aff, warped_points_aff_tps_norm)
        elif self.__do_affine:
            warped_points_norm = self.pt.affPointTnf(theta_aff,
                                                     target_points_norm)
        else:
            warped_points_norm = self.pt.tpsPointTnf(theta_tps,
                                                     target_points_norm)

        warped_points_px = PointsToPixelCoords(warped_points_norm,
                                               source_im_size)
        reprojection_error = compute_reprojection_error(
            torch_keypointsA_var, warped_points_px)
        matched = reprojection_error <= self.min_reprojection_error
        return reprojection_error, matched, num_mutual_keypoints
示例#4
0
    batch = batchTensorToVars(batch)
    print("batch", batch, batch['source_image'])
    source_im_size = batch['source_im_size']
    target_im_size = batch['target_im_size']

    source_points = batch['source_points']
    target_points = batch['target_points']

    # warp points with estimated transformations
    target_points_norm = PointsToUnitCoords(target_points, target_im_size)

    if do_aff:
        pass
        #model_aff.eval()
    if do_tps:
        model_tps.eval()

    # Evaluate models
    if do_aff:
        theta_aff = model_aff(batch)
        print("theta_aff", theta_aff)
        warped_image_aff = affTnf(batch['source_image'],
                                  theta_aff.view(-1, 2, 3))

    if do_tps:
        theta_tps = model_tps(batch)
        warped_image_tps = tpsTnf(batch['source_image'], theta_tps)

    if do_aff and do_tps:
        theta_aff_tps = model_tps({
            'source_image': warped_image_aff,
示例#5
0
for i, batch in enumerate(dataloader):
    # get random batch of size 1
    batch = batchTensorToVars(batch)
    
    source_im_size = batch['source_im_size']
    target_im_size = batch['target_im_size']

    source_points = batch['source_points']
    target_points = batch['target_points']
    
    # warp points with estimated transformations
    target_points_norm = PointsToUnitCoords(target_points,target_im_size)
    
    if do_aff:
        model_aff.eval()
        
    # Evaluate models
    if do_aff:
        theta_aff=model_aff(batch)
        warped_image_aff = affTnf(batch['source_image'],theta_aff.view(-1,2,3))
        print(theta_aff)
        print(theta_aff.view(-1,2,3))


    # Un-normalize images and convert to numpy
    source_image = normalize_image(batch['source_image'],forward=False)
    source_image = source_image.data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy()
    target_image = normalize_image(batch['target_image'],forward=False)
    target_image = target_image.data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy()