def load_network(self):
    """Load saved model weights into ``self.eval_net``.

    Expands ``~`` in ``self.weights``. If the resulting path is an existing
    file, the checkpoint is read with ``torch.load`` and its
    ``'model_state_dict'`` entry is handed to ``load_state_dict`` together
    with ``self.eval_net``. Otherwise a message is printed and the process
    is terminated.

    Raises:
        SystemExit: with exit code -1 when the weights file does not exist.
    """
    weights_filename = osp.expanduser(self.weights)

    # Bail out early when there is nothing to load.
    if not osp.isfile(weights_filename):
        # print(...) with one argument is valid on both Python 2 and 3.
        print('Could not load weights from {:s}'.format(weights_filename))
        sys.exit(-1)

    # Returning the storage unchanged makes torch.load keep all tensors on
    # the CPU, so a checkpoint saved on a GPU also loads without one.
    loc_func = lambda storage, loc: storage
    checkpoint = torch.load(weights_filename, map_location=loc_func)
    load_state_dict(self.eval_net, checkpoint['model_state_dict'])
    print('Loaded weights from {:s}'.format(weights_filename))
示例#2
0
文件: eval.py 项目: zjudzl/geomapnet
# Wrap the pose network in MapNet when the model name contains 'mapnet' or
# args.pose_graph is set; otherwise evaluate the pose network directly.
if ('mapnet' in args.model) or args.pose_graph:
    model = MapNet(mapnet=posenet)
else:
    model = posenet
# NOTE(review): presumably a torch.nn.Module, so this switches to inference
# mode (dropout / batch-norm behaviour) -- confirm against the model class.
model.eval()

# loss functions
# Translation error: norm of the difference between prediction and target.
t_criterion = lambda t_pred, t_gt: np.linalg.norm(t_pred - t_gt)
# Rotation error: delegated to quaternion_angular_error.
q_criterion = quaternion_angular_error

# load weights
weights_filename = osp.expanduser(args.weights)
if osp.isfile(weights_filename):
    # Returning the storage unchanged makes torch.load keep all tensors on
    # the CPU, so a checkpoint saved on a GPU also loads without one.
    loc_func = lambda storage, loc: storage
    checkpoint = torch.load(weights_filename, map_location=loc_func)
    load_state_dict(model, checkpoint['model_state_dict'])
    # print(...) with one argument is valid on both Python 2 and 3.
    print('Loaded weights from {:s}'.format(weights_filename))
else:
    print('Could not load weights from {:s}'.format(weights_filename))
    sys.exit(-1)

# Per-scene normalization statistics live next to the dataset.
data_dir = osp.join('..', 'data', args.dataset)
stats_filename = osp.join(data_dir, args.scene, 'stats.txt')
stats = np.loadtxt(stats_filename)
# transformer
# Row 0 of stats is used as the mean; row 1 is square-rooted to obtain the
# std, so it presumably stores the variance -- verify against stats.txt.
data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=stats[0], std=np.sqrt(stats[1]))
])
# Targets arrive as numpy arrays and are converted to float tensors.
target_transform = transforms.Lambda(lambda x: torch.from_numpy(x).float())
示例#3
0
    calc_vos_safe,
)
from dataset_loaders.composite import MF
import argparse
import os
import os.path as osp
import sys
import numpy as np
import matplotlib

# Check for a display via the DISPLAY environment variable. When none is set,
# select matplotlib's file-rendering "Agg" backend; this is done here, ahead
# of the pyplot import that follows, so the backend choice takes effect.
DISPLAY = "DISPLAY" in os.environ
if not DISPLAY:
    matplotlib.use("Agg")
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import configparser
import torch.cuda
from torch.utils.data import DataLoader
from torchvision import transforms, models
import cPickle

# Build a PoseNet on top of a ResNet-34 feature extractor (pretrained=False
# for both) and restore its parameters from a local checkpoint file.
dropout = 0.5
feature_extractor = models.resnet34(pretrained=False)
model = PoseNet(feature_extractor, droprate=dropout, pretrained=False)
weights_filename = "./weights.pth"

# Returning the storage unchanged makes torch.load keep all tensors on the
# CPU, so a checkpoint saved on a GPU also loads without one.
loc_func = lambda storage, loc: storage
checkpoint = torch.load(weights_filename, map_location=loc_func)
load_state_dict(model, checkpoint["model_state_dict"])
# print(...) with one argument is valid on both Python 2 and 3.
print("Loaded weights from {:s}".format(weights_filename))