コード例 #1
0
def test_init():
    """Check that an eval-mode ObjectDetector returns one prediction per image.

    Builds a ``DummyDetectionDataset``, batches it through a ``DataLoader`` with
    the module's ``collate_fn`` and runs the first batch through the model. Each
    prediction must expose at least the ``boxes``, ``labels`` and ``scores`` keys.
    """
    detector = ObjectDetector(num_classes=2)
    detector.eval()

    batch_size = 2
    # NOTE(review): (3, 224, 224) looks like a channels-first image shape; the
    # meaning of the trailing 1, 2, 10 arguments is defined by
    # DummyDetectionDataset — confirm there.
    dataset = DummyDetectionDataset((3, 224, 224), 1, 2, 10)
    loader = DataLoader(dataset, collate_fn=collate_fn, batch_size=batch_size)
    first_batch = next(iter(loader))
    images = first_batch[DefaultDataKeys.INPUT]

    predictions = detector(images)

    # One prediction per input image in the batch.
    assert len(predictions) == batch_size
    # Subset check: extra keys in the prediction are allowed.
    assert predictions[0].keys() >= {"boxes", "labels", "scores"}
コード例 #2
0
def test_init():
    """Check that ``ObjectDetector.forward`` yields a result dict per image.

    The dummy dataset is turned into a loader with the model's own
    ``process_predict_dataset`` helper. The first batch's inputs are fed to
    ``forward``. Every result must be a ``dict`` holding ``bboxes``, ``labels``
    and ``scores``.
    """
    detector = ObjectDetector(num_classes=2)
    detector.eval()

    batch_size = 2
    # NOTE(review): (128, 128, 3) looks like a channels-last image shape here,
    # unlike other variants of this test — confirm against DummyDetectionDataset.
    dataset = DummyDetectionDataset((128, 128, 3), 1, 2, 10)
    loader = detector.process_predict_dataset(dataset, batch_size=batch_size)
    first_batch = next(iter(loader))

    predictions = detector.forward(first_batch[DefaultDataKeys.INPUT])

    assert len(predictions) == batch_size
    for prediction in predictions:
        assert isinstance(prediction, dict)
    # Each required key is checked across the whole batch before moving on.
    for required_key in ("bboxes", "labels", "scores"):
        assert all(required_key in prediction for prediction in predictions)
コード例 #3
0
def test_jit(tmpdir):
    """Check that ObjectDetector survives a TorchScript script/save/load round trip.

    The scripted model is written to ``tmpdir``, loaded back, and run on a single
    random image. The detections must expose ``boxes``, ``labels`` and ``scores``.
    """
    checkpoint_path = os.path.join(tmpdir, "test.pt")

    detector = ObjectDetector(2)
    detector.eval()

    # Scripting rather than tracing: torch.jit.trace doesn't work with
    # torchvision RCNN.
    scripted = torch.jit.script(detector)
    torch.jit.save(scripted, checkpoint_path)
    restored = torch.jit.load(checkpoint_path)

    raw_output = restored([torch.rand(3, 32, 32)])

    # torchvision RCNN always returns a (Losses, Detections) tuple in scripting,
    # so only the second element holds the detections.
    detections = raw_output[1]

    assert {"boxes", "labels", "scores"} <= detections[0].keys()