Exemplo n.º 1
0
def test_predict(tmpdir, head):
    """Fit an ObjectDetector briefly, then check that prediction honours ``predict_kwargs``.

    The first prediction pass must return at least one bounding box for the
    first sample of the first batch. After setting ``detection_threshold`` to 2
    via ``predict_kwargs``, the same sample must return no boxes.
    """
    detector = ObjectDetector(num_classes=2, head=head, pretrained=False)
    dataset = DummyDetectionDataset((128, 128, 3), 1, 2, 10)

    transform = IceVisionInputTransform()
    batch_size = 2

    # fast_dev_run=True makes Lightning run only a single batch, keeping the test quick.
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    train_loader = detector.process_train_dataset(
        dataset,
        batch_size,
        num_workers=0,
        pin_memory=False,
        input_transform=transform,
    )
    trainer.fit(detector, train_loader)

    predict_loader = detector.process_predict_dataset(
        dataset,
        batch_size,
        input_transform=transform,
    )

    def first_sample_boxes():
        # Run a full prediction pass and return the boxes of batch 0, sample 0.
        batches = trainer.predict(detector, predict_loader, output="preds")
        return batches[0][0]["bboxes"]

    assert len(first_sample_boxes()) > 0

    # NOTE(review): presumably detection scores lie in [0, 1], so a threshold
    # of 2 should discard every box — confirm against the detector's head.
    detector.predict_kwargs = {"detection_threshold": 2}
    assert len(first_sample_boxes()) == 0
Exemplo n.º 2
0
def test_init():
    """Smoke-test a freshly constructed ObjectDetector on one predict batch.

    Builds a 2-class detector in eval mode and feeds it the input of the first
    batch from ``process_predict_dataset``. It then checks that ``forward``
    returns one dict per sample, each with ``bboxes``, ``labels`` and
    ``scores`` keys.
    """
    detector = ObjectDetector(num_classes=2)
    detector.eval()

    batch_size = 2
    dataset = DummyDetectionDataset((128, 128, 3), 1, 2, 10)
    loader = detector.process_predict_dataset(dataset, batch_size=batch_size)
    first_batch = next(iter(loader))

    results = detector.forward(first_batch[DefaultDataKeys.INPUT])

    assert len(results) == batch_size
    expected_keys = ("bboxes", "labels", "scores")
    for result in results:
        assert isinstance(result, dict)
        for key in expected_keys:
            assert key in result