예제 #1
0
 def __init__(self, model_path, use_cuda=True, num_class=751):
     """Load a re-ID ``Net`` checkpoint and set up crop preprocessing.

     Args:
         model_path: Checkpoint file; ``torch.load`` must return a mapping
             whose ``'net_dict'`` entry is the network's state dict.
         use_cuda: Run on ``"cuda"`` when ``torch.cuda.is_available()``;
             otherwise (or when False) run on ``"cpu"``.
         num_class: Forwarded to ``Net`` as ``num_classes``. Presumably it
             must match the class count the checkpoint was trained with,
             because ``load_state_dict`` is strict here — confirm against Net.
     """
     self.net = Net(reid=True, num_classes=num_class)
     self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
     # Load tensors onto the CPU first; the whole net is moved to
     # self.device below, so a GPU-saved checkpoint also works on CPU hosts.
     state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)['net_dict']
     self.net.load_state_dict(state_dict)
     logger = logging.getLogger("root.tracker")
     logger.info("Loading weights from {}... Done!".format(model_path))
     self.net.to(self.device)
     # This object is only used for inference: switch layers such as
     # BatchNorm/Dropout (if Net has any) to evaluation mode so a crop's
     # feature does not depend on the other crops in its batch.
     self.net.eval()
     # cv2.resize takes dsize as (width, height): crops become 64 wide x 128 tall.
     self.size = (64, 128)
     # ImageNet mean/std normalization applied after HWC -> CHW conversion.
     self.norm = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
     ])
예제 #2
0
class Extractor(object):
    """Turn a list of image crops into re-ID appearance feature vectors.

    Wraps a ``Net(reid=True)`` loaded from a checkpoint, the preprocessing
    (scale, resize, normalize) and batched inference.
    """

    def __init__(self, model_path, use_cuda=True):
        """Load weights and set up preprocessing.

        Args:
            model_path: Checkpoint file; ``torch.load`` must return a mapping
                whose ``'net_dict'`` entry is the network's state dict.
            use_cuda: Run on ``"cuda"`` when ``torch.cuda.is_available()``;
                otherwise (or when False) run on ``"cpu"``.
        """
        self.net = Net(reid=True)
        self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
        # Load tensors onto the CPU first; the whole net is moved to
        # self.device below, so a GPU-saved checkpoint also works on CPU hosts.
        state_dict = torch.load(
            model_path, map_location=lambda storage, loc: storage)['net_dict']
        self.net.load_state_dict(state_dict)
        print("Loading weights from {}... Done!".format(model_path))
        self.net.to(self.device)
        # This object is only used for inference: switch layers such as
        # BatchNorm/Dropout (if Net has any) to evaluation mode so a crop's
        # feature does not depend on the other crops in its batch.
        self.net.eval()
        # cv2.resize takes dsize as (width, height): crops become 64 wide x 128 tall.
        self.size = (64, 128)
        # ImageNet mean/std normalization applied after HWC -> CHW conversion.
        self.norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def _preprocess(self, im_crops):
        """Convert image crops into one normalized float batch tensor.

        Steps:
            1. convert to float32 and scale from [0, 255] to [0, 1],
            2. resize to ``self.size`` (64x128, as the Market1501 dataset did),
            3. convert to a CHW torch tensor and normalize with ImageNet stats,
            4. stack the per-crop tensors into one batch.

        Args:
            im_crops: Sequence of HxWxC image arrays (presumably uint8 BGR/RGB
                crops in 0..255 — confirm with callers).

        Returns:
            Float tensor of shape (N, C, 128, 64), N = ``len(im_crops)``.

        Raises:
            RuntimeError: from ``torch.stack`` when ``im_crops`` is empty.
        """
        def _resize(im, size):
            return cv2.resize(im.astype(np.float32) / 255., size)

        # torch.stack == torch.cat of per-item unsqueeze(0) along dim 0.
        im_batch = torch.stack(
            [self.norm(_resize(im, self.size)) for im in im_crops],
            dim=0).float()
        return im_batch

    def __call__(self, im_crops):
        """Return one feature row per crop as a NumPy array (computed without autograd)."""
        im_batch = self._preprocess(im_crops)
        with torch.no_grad():
            im_batch = im_batch.to(self.device)
            features = self.net(im_batch)
        return features.cpu().numpy()
예제 #3
0
    torchvision.transforms.Resize((128, 64)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
])
# Evaluation loaders: a fixed (unshuffled) order keeps runs deterministic.
queryloader = torch.utils.data.DataLoader(
    torchvision.datasets.ImageFolder(query_dir, transform=transform),
    batch_size=64,
    shuffle=False)
galleryloader = torch.utils.data.DataLoader(
    torchvision.datasets.ImageFolder(gallery_dir, transform=transform),
    batch_size=64,
    shuffle=False)

# net definition
net = Net(reid=True)
ckpt_path = "./checkpoint/ckpt.t7"
# Explicit check rather than `assert`, which is stripped under `python -O`.
if not os.path.isfile(ckpt_path):
    raise FileNotFoundError("Error: no checkpoint file found!")
print('Loading from checkpoint/ckpt.t7')
# The checkpoint may have been saved from a GPU; load it onto the CPU (where
# `net` still lives) so this also works on CPU-only hosts. `net` is moved to
# `device` below.
checkpoint = torch.load(ckpt_path, map_location="cpu")
net_dict = checkpoint['net_dict']
# strict=False tolerates missing/unexpected keys; report them rather than
# silently evaluating a partially-loaded network.
incompatible = net.load_state_dict(net_dict, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print('Warning: missing keys {}, unexpected keys {}'.format(
        incompatible.missing_keys, incompatible.unexpected_keys))
net.eval()
net.to(device)

# compute features
query_features = torch.tensor([]).float()
query_labels = torch.tensor([]).long()
gallery_features = torch.tensor([]).float()
gallery_labels = torch.tensor([]).long()
예제 #4
0
def export_net_onnx(
        model_path='/home/liyongjing/Egolee_2021/programs/deep_sort_pytorch-master'
                   '/deep_sort/deep/checkpoint/ckpt.t7'):
    """Export the re-ID ``Net`` checkpoint to ONNX, simplify it, and sanity-check it.

    Writes, next to ``model_path`` (``<stem>`` = path without extension):
      * ``<stem>.onnx``           — raw ``torch.onnx.export`` output,
      * ``<stem>_sim.onnx``       — the ``onnxsim``-simplified model,
      * ``<stem>_sim_shape.onnx`` — the simplified model with inferred shapes.
    The simplified model is run with onnxruntime on the same random input as
    PyTorch, and the maximum absolute output difference is printed.

    Export is best-effort: any exception is caught and reported via ``print``
    rather than raised.

    Args:
        model_path: Checkpoint whose ``'net_dict'`` entry is the state dict.
            Defaults to the path that used to be hard-coded here.
    """
    model = Net(reid=True)
    state_dict = torch.load(
        model_path, map_location=lambda storage, loc: storage)['net_dict']
    model.load_state_dict(state_dict)

    img_size = (64, 128)  # (width, height)
    batch_size = 2
    # Dummy NCHW input: (batch, 3, height=128, width=64).
    img = torch.randn(batch_size, 3, *img_size[::-1])
    model.eval()
    with torch.no_grad():
        y = model(img)  # dry run; also the PyTorch reference output

    try:
        print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
        stem = osp.splitext(model_path)[0]
        f = stem + '.onnx'
        torch.onnx.export(
            model,
            img,
            f,
            verbose=True,
            opset_version=9,
            input_names=['images'],
            output_names=['output'])

        # Checks
        onnx_model = onnx.load(f)  # load onnx model
        onnx.checker.check_model(onnx_model)  # check onnx model
        print('====ONNX export success, saved as %s' % f)

        # Simplify the ONNX graph (optional dependency, imported lazily).
        from onnxsim import simplify
        model_simp, check = simplify(onnx_model)
        # Explicit raise instead of `assert` (stripped under -O); reported by
        # the `except` below exactly like the assertion was.
        if not check:
            raise RuntimeError("Simplified ONNX model could not be validated")

        # Build derived names from the stem rather than str.replace, which
        # would also rewrite any '.onnx' occurring in directory names.
        f2 = stem + '_sim.onnx'
        onnx.save(model_simp, f2)
        print('====ONNX SIM export success, saved as %s' % f2)

        # Compare the PyTorch output y with the onnxruntime output y_onnx.
        import onnxruntime as rt
        # Older exporters list initializers among graph inputs; the real feed
        # inputs are the inputs that are not initializers.
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        if len(net_feed_input) != 1:
            raise RuntimeError(
                'expected exactly one feed input, got %s' % net_feed_input)
        # onnxruntime >= 1.9 GPU builds require providers to be explicit.
        sess = rt.InferenceSession(f2, providers=['CPUExecutionProvider'])
        y_onnx = sess.run(None, {net_feed_input[0]: img.detach().numpy()})[0]

        y_numpy = y.detach().numpy()
        print(y_numpy.shape)
        print(y_onnx.shape)
        print(y_numpy[0, 0:20])

        # Absolute difference: a signed max would hide cases where the ONNX
        # output exceeds PyTorch's.
        max_diff = np.max(np.abs(y_numpy - y_onnx))
        print('max diff {}'.format(max_diff))

        from onnx import shape_inference
        f3 = stem + '_sim_shape.onnx'
        onnx.save(shape_inference.infer_shapes(model_simp), f3)
        print('====ONNX shape inference export success, saved as %s' % f3)

    except Exception as e:
        print('ONNX export failure: %s' % e)
예제 #5
0
    torchvision.transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
])
trainloader = torch.utils.data.DataLoader(
    torchvision.datasets.ImageFolder(train_dir, transform=transform_train),
    batch_size=64,
    shuffle=True)
testloader = torch.utils.data.DataLoader(
    torchvision.datasets.ImageFolder(test_dir, transform=transform_test),
    batch_size=64,
    shuffle=True)
# ImageFolder assigns one class per sub-directory of train_dir.
num_classes = len(trainloader.dataset.classes)

# net definition
start_epoch = 0
net = Net(num_classes=num_classes)
if args.resume:
    ckpt_path = "./checkpoint/ckpt.t7"
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError("Error: no checkpoint file found!")
    print('Loading from checkpoint/ckpt.t7')
    # The checkpoint may have been saved from a GPU; load it onto the CPU
    # (where `net` still lives) so resuming also works on CPU-only hosts.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    net_dict = checkpoint['net_dict']
    net.load_state_dict(net_dict)
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
net.to(device)

# loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(),