def build_network(self):
    """Build the pose network and return it switched to evaluation mode.

    A ResNet-34 feature extractor is created without pretrained weights and
    wrapped in a ``PoseNet`` that uses ``self.dropout`` as its drop rate.
    When ``self.model`` contains the substring ``'mapnet'`` the ``PoseNet``
    is wrapped in a ``MapNet``; otherwise the ``PoseNet`` is used directly.

    Returns:
        Whatever ``model.eval()`` returns -- presumably the model itself,
        as with ``torch.nn.Module.eval``; confirm against PoseNet/MapNet.
    """
    # model
    feature_extractor = models.resnet34(pretrained=False)
    posenet = PoseNet(feature_extractor, droprate=self.dropout,
                      pretrained=False)
    # Membership test instead of ``str.find(...) >= 0``: same result for
    # strings, and it is the documented way to check for a substring.
    if 'mapnet' in self.model:
        model = MapNet(mapnet=posenet)
    else:
        model = posenet
    return model.eval()
# fc_vos is enabled only when the selected dataset is RobotCar
fc_vos = args.dataset == 'RobotCar'
if args.pose_graph:
    # Pose-graph settings read from the config section.
    # NOTE(review): the second argument to getfloat is presumably the
    # fallback used when the key is missing (configparser-style section)
    # -- confirm against how `section` is created.
    vo_lib = section.get('vo_lib')
    sax = section.getfloat('s_abs_trans', 1)
    saq = section.getfloat('s_abs_rot', 1)
    srx = section.getfloat('s_rel_trans', 20)
    srq = section.getfloat('s_rel_rot', 20)

# model
feature_extractor = models.resnet34(pretrained=False)
posenet = PoseNet(feature_extractor, droprate=dropout, pretrained=False)
# MapNet wrapper is used for any 'mapnet' model name and whenever pose-graph
# mode is requested; otherwise the bare PoseNet is evaluated.
if 'mapnet' in args.model or args.pose_graph:
    model = MapNet(mapnet=posenet)
else:
    model = posenet
model.eval()


# loss functions
def t_criterion(t_pred, t_gt):
    """Return the norm of the difference between the two translations."""
    return np.linalg.norm(t_pred - t_gt)


q_criterion = quaternion_angular_error

# load weights
weights_filename = osp.expanduser(args.weights)
if osp.isfile(weights_filename):
    def loc_func(storage, loc):
        """Return the storage unchanged, ignoring the saved location tag."""
        return storage

    # NOTE(review): torch.load unpickles the file, so only checkpoints from
    # a trusted source should be passed via args.weights.
    checkpoint = torch.load(weights_filename, map_location=loc_func)
    load_state_dict(model, checkpoint['model_state_dict'])
    # Single-argument parenthesised print behaves the same on Python 2 and 3.
    print('Loaded weights from {:s}'.format(weights_filename))
else:
    print('Could not load weights from {:s}'.format(weights_filename))
    sys.exit(-1)