예제 #1
0
def load_torch_model(args, use_cuda):
    """Build an affine ``CNNGeometric`` model and load pretrained weights into it.

    Args:
        args: namespace providing ``feature_extraction_cnn`` (forwarded to the
            model constructor) and ``pretrained`` (path of the checkpoint file).
        use_cuda: forwarded to the ``CNNGeometric`` constructor.

    Returns:
        The model with the checkpoint's ``state_dict`` loaded. Before loading,
        every 'vgg' substring in a key is replaced by 'model'.
    """
    net = CNNGeometric(
        use_cuda=use_cuda,
        geometric_model='affine',
        feature_extraction_cnn=args.feature_extraction_cnn,
    )

    # Load trained weights; map_location keeps the deserialized tensors on CPU.
    print('Loading trained model weights...')
    ckpt = torch.load(args.pretrained, map_location=lambda storage, loc: storage)
    renamed = OrderedDict()
    for key, tensor in ckpt['state_dict'].items():
        renamed[key.replace('vgg', 'model')] = tensor
    ckpt['state_dict'] = renamed
    net.load_state_dict(ckpt['state_dict'])
    return net
예제 #2
0
def _load_remapped_weights(model, params_path, feat_ext):
    """Load the checkpoint at ``params_path`` into ``model`` after renaming its keys.

    Every occurrence of the feature-extractor name ``feat_ext`` in a state-dict
    key is replaced by 'model'. A 'resnet101' extractor is matched through the
    shorter prefix 'resnet'.
    """
    prefix = 'resnet' if feat_ext == 'resnet101' else feat_ext
    # map_location keeps the deserialized tensors on CPU.
    checkpoint = torch.load(params_path, map_location=lambda storage, loc: storage)
    state_dict = OrderedDict(
        (k.replace(prefix, 'model'), v) for k, v in checkpoint['state_dict'].items())
    model.load_state_dict(state_dict)


def load_model(aff_params_path = '', aff_feat_ext = 'wormbrain_1', aff_feat_reg = 'simpler',
               tps_params_path = '', tps_feat_ext = 'wormbrain_1', tps_feat_reg = 'simpler'):
    """Create the affine and/or thin-plate-spline CNNGeometric models and load their weights.

    Each model (affine and TPS) is assumed to have been trained separately.
    A model is only created when its weights path is non-empty. Each model's
    feature-extraction architecture and feature-regression variant must match
    the ones used at training time; both default to 'wormbrain_1' / 'simpler'.

    Args:
        aff_params_path: checkpoint path of the affine model ('' to skip it).
        aff_feat_ext: feature-extraction CNN name of the affine model.
        aff_feat_reg: feature-regression variant of the affine model.
        tps_params_path: checkpoint path of the TPS model ('' to skip it).
        tps_feat_ext: feature-extraction CNN name of the TPS model.
        tps_feat_reg: feature-regression variant of the TPS model.

    Returns:
        - the affine model, if only affine weights were given;
        - the TPS model, if only TPS weights were given;
        - the tuple ``(model_aff, model_tps)``, if both were given;
        - ``None``, if neither was given.
    """
    use_cuda = torch.cuda.is_available()
    # Only create a model for which weights have been provided.
    do_aff = aff_params_path != ''
    do_tps = tps_params_path != ''
    if not do_aff and not do_tps:
        print("No weights found. Models not created, exiting.")
        return None

    print("Creating CNN model.")
    model_aff = model_tps = None
    if do_aff:
        model_aff = CNNGeometric(output_dim=6, use_cuda=use_cuda,
                                 feature_extraction_cnn=aff_feat_ext,
                                 feature_regression=aff_feat_reg)
    if do_tps:
        model_tps = CNNGeometric(output_dim=18, use_cuda=use_cuda,
                                 feature_extraction_cnn=tps_feat_ext,
                                 feature_regression=tps_feat_reg)
    print("Loading trained model weights.")

    if do_aff:
        _load_remapped_weights(model_aff, aff_params_path, aff_feat_ext)
        print('Weights for Affine model loaded.')

    if do_tps:
        # Each model's keys are remapped with its OWN extractor name.
        _load_remapped_weights(model_tps, tps_params_path, tps_feat_ext)
        print('Weights for Thin-Plate Spline model loaded.')

    print('Returning model(s).')

    if do_aff and do_tps:
        return model_aff, model_tps
    return model_aff if do_aff else model_tps
예제 #3
0
# Script fragment: build the affine and/or TPS CNNGeometric models, load their
# checkpoints, then set up the image-pair dataset, loader and point transformer.
# NOTE(review): `args`, `do_aff`, `do_tps` and `use_cuda` are defined earlier in
# the original script (not visible here).
dataset_path = args.path
dataset_pairs_file = args.pairs
# Create model
print('Creating CNN model...')
if do_aff:
    model_aff = CNNGeometric(use_cuda=use_cuda, geometric_model='affine')
if do_tps:
    model_tps = CNNGeometric(use_cuda=use_cuda, geometric_model='tps')

# Load trained weights (map_location keeps the deserialized tensors on CPU;
# load_state_dict is strict, so checkpoint keys must match the model exactly).
print('Loading trained model weights...')
if do_aff:
    checkpoint = torch.load(args.model_aff,
                            map_location=lambda storage, loc: storage)
    model_aff.load_state_dict(checkpoint['state_dict'])
if do_tps:
    checkpoint = torch.load(args.model_tps,
                            map_location=lambda storage, loc: storage)
    model_tps.load_state_dict(checkpoint['state_dict'])

# Dataset and dataloader: pairs listed in the CSV file, images under dataset_path,
# one randomly shuffled pair per batch.
dataset = PlacesDataset(csv_file=dataset_pairs_file,
                        training_image_path=dataset_path,
                        transform=NormalizeImageDict(
                            ['source_image', 'target_image']))
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)
batchTensorToVars = BatchTensorToVars(use_cuda=use_cuda)

# Instantiate point transformer
pt = PointTnf(use_cuda=use_cuda)
예제 #4
0
parser.add_argument('--model', type=str, default='trained_models/best_checkpoint_resnet18_adam_pose_mse_loss.pth.tar', help='Trained pose model filename')
parser.add_argument('--path', type=str, default='/home/develop/Work/Datasets/', help='Path to PF dataset')
parser.add_argument('--pairs', type=str, default='/home/develop/Work/Datasets/gardens_pairs_path_samples_sift_RANSAC_12kps.csv', help='Path to CSV file listing the image pairs')
args = parser.parse_args()

# NOTE(review): `parser` and `use_cuda` are created earlier in the script
# (not visible here).
dataset_path = args.path
dataset_pairs_file = args.pairs

# Create a CNNGeometric model for the 'pose' geometric model on a resnet18 backbone.
print('Creating CNN model...')
model = CNNGeometric(use_cuda=use_cuda, geometric_model='pose', arch='resnet18')

print('Load CNN Weights ...')
# map_location keeps the checkpoint tensors on CPU while loading.
checkpoint = torch.load(args.model, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])

# Dataset and dataloader: one randomly shuffled image pair per batch.
dataset = PlacesDataset(csv_file=dataset_pairs_file,
                        training_image_path=dataset_path,
                        transform=NormalizeImageDict(['source_image', 'target_image']))
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)
batchTensorToVars = BatchTensorToVars(use_cuda=use_cuda)

for source_im_path, target_im_path, batch in dataloader:
    # Wrap the (size-1) batch with BatchTensorToVars.
    batch = batchTensorToVars(batch)

    source_im_size = batch['source_im_size']
    target_im_size = batch['target_im_size']
예제 #5
0
class CNNGeometricMatcher:
    """Decide whether two images match using CNNGeometric affine and/or TPS models.

    The models' correlation maps are used to find mutually matched keypoints.
    The target keypoints are warped with the estimated transformation(s) into
    source pixel coordinates and compared with the source keypoints. A pair is
    reported as matched when the reprojection error does not exceed
    ``min_reprojection_error``.

    Args:
        use_extracted_features: if True, the affine model is a
            ``CNNGeometricRegression`` loaded non-strictly, and the TPS model
            is disabled.
        geometric_affine_model: checkpoint path of the affine model, or None
            to skip it.
        geometric_tps_model: checkpoint path of the TPS model, or None to
            skip it.
        arch: backbone architecture name forwarded to the model constructors.
        featext_weights: forwarded to the model constructors.
        min_mutual_keypoints: if fewer mutual keypoints are found, the pair is
            reported as not matched, without computing a reprojection error.
        min_reprojection_error: despite its name, the MAXIMUM reprojection
            error for which a pair still counts as matched.
    """

    def __init__(self,
                 use_extracted_features=False,
                 geometric_affine_model=None,
                 geometric_tps_model=None,
                 arch='resnet18',
                 featext_weights=None,
                 min_mutual_keypoints=6,
                 min_reprojection_error=200):
        # NOTE(review): this class mixes the module-level names `_use_cuda` and
        # `use_cuda`. Neither is defined in the visible code; confirm that both
        # exist and hold the same value.
        self.min_mutual_keypoints = min_mutual_keypoints
        self.min_reprojection_error = min_reprojection_error
        self.__do_affine = geometric_affine_model is not None
        # The TPS model is never used together with pre-extracted features.
        self.__do_tps = not use_extracted_features and geometric_tps_model is not None
        self.__affTnf = GeometricTnf(geometric_model='affine',
                                     use_cuda=_use_cuda)
        if self.__do_affine:
            checkpoint = torch.load(geometric_affine_model,
                                    map_location=lambda storage, loc: storage)
            print('Loading CNN Affine Geometric Model')
            if use_extracted_features:
                self.model_affine = CNNGeometricRegression(
                    use_cuda=use_cuda,
                    geometric_model='affine',
                    arch=arch,
                    featext_weights=featext_weights)
                # Only the checkpoint entries this model also has are copied.
                self._load_matching_state_dict(self.model_affine,
                                               checkpoint['state_dict'])
            else:
                self.model_affine = CNNGeometric(
                    use_cuda=_use_cuda,
                    geometric_model='affine',
                    arch=arch,
                    featext_weights=featext_weights)
                self.model_affine.load_state_dict(checkpoint['state_dict'])

            self.model_affine.eval()

        if self.__do_tps:
            self.model_tps = CNNGeometric(use_cuda=use_cuda,
                                          geometric_model='tps',
                                          arch=arch,
                                          featext_weights=featext_weights)
            checkpoint = torch.load(geometric_tps_model,
                                    map_location=lambda storage, loc: storage)
            print('Loading CNN TPS Geometric Model')
            self._load_matching_state_dict(self.model_tps,
                                           checkpoint['state_dict'])
            self.model_tps.eval()

        self.pt = PointTnf(use_cuda=_use_cuda)

    @staticmethod
    def _load_matching_state_dict(model, state_dict):
        """Copy into ``model`` only the ``state_dict`` entries whose keys the model has.

        Checkpoint keys unknown to the model are ignored. Model entries missing
        from the checkpoint keep their current values.
        """
        model_dict = model.state_dict()
        model_dict.update(
            {k: v for k, v in state_dict.items() if k in model_dict})
        model.load_state_dict(model_dict)

    def run(self, batch):
        """Estimate the transformation(s) for ``batch`` and score the match.

        Args:
            batch: dict read for the 'source_image', 'target_image',
                'source_im_size' and 'target_im_size' entries.

        Returns:
            Tuple ``(reprojection_error, matched, num_mutual_keypoints)``.
            When fewer than ``min_mutual_keypoints`` mutual keypoints are
            found, ``reprojection_error`` is -1 and ``matched`` is False.

        Raises:
            RuntimeError: if neither an affine nor a TPS model was configured.
        """
        if not (self.__do_affine or self.__do_tps):
            # Without a model there are no correlation maps to match on.
            raise RuntimeError(
                'CNNGeometricMatcher.run needs an affine or a TPS model')

        if self.__do_affine:
            theta_aff, correlationAB, correlationBA = self.model_affine(batch)

        if self.__do_tps:
            if self.__do_affine:
                # Run TPS on the affinely warped source image. Its correlation
                # maps replace the affine ones for keypoint matching.
                warped_image_aff = self.__affTnf(batch['source_image'],
                                                 theta_aff.view(-1, 2, 3))
                theta_affine_tps, correlationAB, correlationBA = self.model_tps(
                    {
                        'source_image': warped_image_aff,
                        'target_image': batch['target_image']
                    })
            else:
                theta_tps, correlationAB, correlationBA = self.model_tps(batch)

        keypoints_A, keypoints_B = find_mutual_matached_keypoints(
            correlationAB, correlationBA)
        num_mutual_keypoints = keypoints_A.shape[0]

        if num_mutual_keypoints < self.min_mutual_keypoints:
            return -1, False, num_mutual_keypoints

        source_im_size = batch['source_im_size']
        target_im_size = batch['target_im_size']

        # .cpu() is a no-op for CPU tensors and lets .numpy() work on GPU ones.
        source_im_shape_np = source_im_size.data.cpu().numpy()
        target_im_shape_np = target_im_size.data.cpu().numpy()

        # Per-sample shape of the correlation tensor (every dim except the
        # first). This avoids copying the whole tensor to the host.
        tensor_shape = tuple(correlationAB.size()[1:])

        # im_size is passed as (size[0][1], size[0][0]), i.e. the first row's
        # two entries in swapped order.
        im_keypointsA = tensorPointstoPixels(
            keypoints_A,
            tensor_size=tensor_shape,
            im_size=(source_im_shape_np[0][1], source_im_shape_np[0][0]))
        im_keypointsB = tensorPointstoPixels(
            keypoints_B,
            tensor_size=tensor_shape,
            im_size=(target_im_shape_np[0][1], target_im_shape_np[0][0]))

        torch_keypointsA_var = Variable(
            torch.Tensor(im_keypointsA.reshape(1, 2, -1).astype(np.float32)))
        torch_keypointsB_var = Variable(
            torch.Tensor(im_keypointsB.reshape(1, 2, -1).astype(np.float32)))

        target_points_norm = PointsToUnitCoords(torch_keypointsB_var,
                                                target_im_size)

        # Warp the normalized target keypoints. With both models, the TPS
        # warp is applied first, then the affine one.
        if self.__do_affine and self.__do_tps:
            warped_points_aff_tps_norm = self.pt.tpsPointTnf(
                theta_affine_tps, target_points_norm)
            warped_points_norm = self.pt.affPointTnf(
                theta_aff, warped_points_aff_tps_norm)
        elif self.__do_affine:
            warped_points_norm = self.pt.affPointTnf(
                theta_aff, target_points_norm)
        else:
            warped_points_norm = self.pt.tpsPointTnf(
                theta_tps, target_points_norm)

        warped_points_aff = PointsToPixelCoords(warped_points_norm,
                                                source_im_size)
        reprojection_error = compute_reprojection_error(
            torch_keypointsA_var, warped_points_aff)
        matched = reprojection_error <= self.min_reprojection_error
        return reprojection_error, matched, num_mutual_keypoints