def build_network(self):
    """Build the pose-regression network and return it in eval mode.

    A ResNet-34 feature extractor (constructed with ``pretrained=False``) is
    wrapped in a PoseNet head whose dropout rate is ``self.dropout``. If the
    model name ``self.model`` contains the substring ``'mapnet'``, the PoseNet
    is additionally wrapped in a MapNet.

    Returns:
        The value of ``.eval()`` called on the constructed network.
    """
    posenet = PoseNet(models.resnet34(pretrained=False),
                      droprate=self.dropout,
                      pretrained=False)
    network = MapNet(mapnet=posenet) if 'mapnet' in self.model else posenet
    return network.eval()
示例#2
0
    # NOTE(review): the statement enclosing this line is outside the snippet.
    # srq is read from the 'gamma' config option; the mapnet loss branch below
    # passes it to MapNetCriterion as srq.
    srq = section.getfloat('gamma')

section = settings['training']
seed = section.getint('seed')

# model
# ResNet-34 feature extractor built with pretrained=True, wrapped in PoseNet
# (NaN filtering disabled via filter_nans=False).
feature_extractor = models.resnet34(pretrained=True)
posenet = PoseNet(feature_extractor,
                  droprate=dropout,
                  pretrained=True,
                  filter_nans=False)

# Select the architecture from args.model: exactly 'posenet' uses the bare
# PoseNet; any name containing 'mapnet' wraps the PoseNet in a MapNet.
if args.model == 'posenet':
    model = posenet
elif 'mapnet' in args.model:
    model = MapNet(mapnet=posenet)
else:
    # Report the rejected name so a typo in the model argument is obvious.
    raise NotImplementedError(
        'Model {:s} is not implemented'.format(args.model))

# loss function
# The training criterion receives the config-derived weights plus
# learn_beta=True (and learn_gamma=True for MapNet). Presumably this makes
# those weights trainable -- TODO confirm in the criterion classes.
if args.model == 'posenet':
    train_criterion = PoseNetCriterion(sax=sax, saq=saq, learn_beta=True)
    # Validation uses the criterion's default constructor arguments.
    val_criterion = PoseNetCriterion()
elif args.model.find('mapnet') >= 0:
    # Any model name containing 'mapnet' selects MapNetCriterion, which takes
    # the extra srx/srq arguments in addition to sax/saq.
    kwargs = dict(sax=sax,
                  saq=saq,
                  srx=srx,
                  srq=srq,
                  learn_beta=True,
                  learn_gamma=True)
    train_criterion = MapNetCriterion(**kwargs)
    # NOTE(review): no val_criterion is assigned in this branch within the
    # visible code; confirm it is set before it is used.
示例#3
0
文件: eval.py 项目: zjudzl/geomapnet
    # Evaluation options read from the config `section`. Its assignment and
    # the enclosing statement are outside this snippet.
    skip = section.getint('skip')
    real = section.getboolean('real')
    variable_skip = section.getboolean('variable_skip')
    # True only when evaluating on the 'RobotCar' dataset.
    fc_vos = args.dataset == 'RobotCar'
    if args.pose_graph:
        vo_lib = section.get('vo_lib')
        # The option names suggest absolute (s_abs_*) and relative (s_rel_*)
        # translation/rotation weights. The second argument is presumably the
        # configparser fallback used when the option is absent (1 and 20).
        # Confirm that `section` is a configparser SectionProxy.
        sax = section.getfloat('s_abs_trans', 1)
        saq = section.getfloat('s_abs_rot', 1)
        srx = section.getfloat('s_rel_trans', 20)
        srq = section.getfloat('s_rel_rot', 20)

# model
# ResNet-34 feature extractor (pretrained=False) wrapped in PoseNet. The result
# is wrapped in MapNet when the model name contains 'mapnet' or pose-graph
# evaluation is requested, and the network is then switched to eval mode.
feature_extractor = models.resnet34(pretrained=False)
posenet = PoseNet(feature_extractor, droprate=dropout, pretrained=False)
model = (MapNet(mapnet=posenet)
         if 'mapnet' in args.model or args.pose_graph
         else posenet)
model.eval()

# loss functions
def t_criterion(t_pred, t_gt):
    """Return the translation error: ``np.linalg.norm`` of ``t_pred - t_gt``.

    For 1-D inputs this is the Euclidean distance between the two vectors.
    """
    return np.linalg.norm(t_pred - t_gt)
# Rotation error is delegated to quaternion_angular_error, which is imported or
# defined outside this snippet. Presumably it returns the angle between two
# quaternions -- confirm its units.
q_criterion = quaternion_angular_error

# load weights
# NOTE(review): no branch for a missing weights file is visible here; if the
# file does not exist, `model` keeps its initial weights. Confirm this is handled.
weights_filename = osp.expanduser(args.weights)
if osp.isfile(weights_filename):
    # Return each storage unchanged, i.e. load tensors onto the CPU
    # regardless of the device they were saved from (torch.load idiom).
    loc_func = lambda storage, loc: storage
    checkpoint = torch.load(weights_filename, map_location=loc_func)
    load_state_dict(model, checkpoint['model_state_dict'])
    # Single-argument print(...) behaves identically on Python 2 and 3.
    print('Loaded weights from {:s}'.format(weights_filename))
示例#4
0
文件: eval.py 项目: devyhia/geomapnet
    # Evaluation options read from the config `section`. Its assignment and
    # the enclosing statement are outside this snippet.
    skip = section.getint('skip')
    real = section.getboolean('real')
    variable_skip = section.getboolean('variable_skip')
    # True only when evaluating on the 'RobotCar' dataset.
    fc_vos = args.dataset == 'RobotCar'
    if args.pose_graph:
        vo_lib = section.get('vo_lib')
        # The option names suggest absolute (s_abs_*) and relative (s_rel_*)
        # translation/rotation weights. The second argument is presumably the
        # configparser fallback used when the option is absent (1 and 20).
        # Confirm that `section` is a configparser SectionProxy.
        sax = section.getfloat('s_abs_trans', 1)
        saq = section.getfloat('s_abs_rot', 1)
        srx = section.getfloat('s_rel_trans', 20)
        srq = section.getfloat('s_rel_rot', 20)

# model
# ResNet-34 feature extractor (pretrained=False) wrapped in PoseNet. The result
# is wrapped in MapNet when the model name contains 'mapnet' or pose-graph
# evaluation is requested, and the network is then switched to eval mode.
feature_extractor = models.resnet34(pretrained=False)
posenet = PoseNet(feature_extractor, droprate=dropout, pretrained=False)
model = (MapNet(mapnet=posenet)
         if 'mapnet' in args.model or args.pose_graph
         else posenet)
model.eval()

# loss functions
def t_criterion(t_pred, t_gt):
    """Return the translation error: ``np.linalg.norm`` of ``t_pred - t_gt``.

    For 1-D inputs this is the Euclidean distance between the two vectors.
    """
    return np.linalg.norm(t_pred - t_gt)
# Rotation error is delegated to quaternion_angular_error, which is imported or
# defined outside this snippet. Presumably it returns the angle between two
# quaternions -- confirm its units.
q_criterion = quaternion_angular_error

# load weights
# NOTE(review): no branch for a missing weights file is visible here; if the
# file does not exist, `model` keeps its initial weights. Confirm this is handled.
weights_filename = osp.expanduser(args.weights)
if osp.isfile(weights_filename):
    # Return each storage unchanged, i.e. load tensors onto the CPU
    # regardless of the device they were saved from (torch.load idiom).
    loc_func = lambda storage, loc: storage
    checkpoint = torch.load(weights_filename, map_location=loc_func)
    load_state_dict(model, checkpoint['model_state_dict'])
    # Single-argument print(...) behaves identically on Python 2 and 3.
    print('Loaded weights from {:s}'.format(weights_filename))