Exemplo n.º 1
0
def test_init():
    """Smoke-test a freshly built two-class ``ObjectDetector`` in eval mode.

    One batch from ``DummyDetectionDataset`` is pushed through the model. The
    output must contain one entry per sample in the batch, and the first entry
    must expose at least the ``boxes``, ``labels`` and ``scores`` keys.
    """
    batch_size = 2

    detector = ObjectDetector(num_classes=2)
    detector.eval()

    dataset = DummyDetectionDataset((3, 224, 224), 1, 2, 10)
    loader = DataLoader(dataset, collate_fn=collate_fn, batch_size=batch_size)
    batch = next(iter(loader))
    images = batch[DefaultDataKeys.INPUT]

    predictions = detector(images)

    assert len(predictions) == batch_size
    # Subset check: the prediction dict may carry extra keys beyond these.
    expected_keys = {"boxes", "labels", "scores"}
    assert expected_keys.issubset(predictions[0].keys())
Exemplo n.º 2
0
def test_init():
    """Check that a two-class ``ObjectDetector`` yields one prediction dict per sample.

    The prediction dataloader is obtained from ``process_predict_dataset``. Every
    element of the ``forward`` output must be a ``dict`` containing the
    ``bboxes``, ``labels`` and ``scores`` keys.
    """
    batch_size = 2

    detector = ObjectDetector(num_classes=2)
    detector.eval()

    dataset = DummyDetectionDataset((128, 128, 3), 1, 2, 10)
    loader = detector.process_predict_dataset(dataset, batch_size=batch_size)
    first_batch = next(iter(loader))

    results = detector.forward(first_batch[DefaultDataKeys.INPUT])

    assert len(results) == batch_size
    # Verify every result is a dict before probing it for keys.
    assert all(isinstance(result, dict) for result in results)
    for required_key in ("bboxes", "labels", "scores"):
        assert all(required_key in result for result in results)
Exemplo n.º 3
0
def test_jit(tmpdir):
    """Round-trip a scripted ``ObjectDetector`` through ``torch.jit.save``/``load``.

    The reloaded module is run on a single random 3x32x32 image. Under scripting,
    torchvision RCNN returns a ``(losses, detections)`` tuple, so the detections
    are read from index 1. The first detection must carry at least the ``boxes``,
    ``labels`` and ``scores`` keys.
    """
    checkpoint_path = os.path.join(tmpdir, "test.pt")

    detector = ObjectDetector(2)
    detector.eval()

    # Scripting is required here: torch.jit.trace doesn't work with torchvision RCNN.
    scripted_detector = torch.jit.script(detector)
    torch.jit.save(scripted_detector, checkpoint_path)
    restored_detector = torch.jit.load(checkpoint_path)

    losses_and_detections = restored_detector([torch.rand(3, 32, 32)])
    detections = losses_and_detections[1]

    expected_keys = {"boxes", "labels", "scores"}
    assert expected_keys.issubset(detections[0].keys())