def test_mantis_mask_rcnn_predict_dl(sample_dataset, pretrained_state_dict):
    """Run batched inference with pretrained weights and validate the predictions.

    ``mask_rcnn.predict_dl`` yields a ``(samples, preds)`` pair (the threshold
    test in this module unpacks it the same way). Assigning the result straight
    to ``preds`` would hand the whole tuple to ``_test_preds``, so ``preds[0]``
    would be the samples rather than the first prediction.
    """
    # 91 classes, presumably to match the COCO-pretrained weights supplied by
    # the ``pretrained_state_dict`` fixture -- confirm against the fixture.
    model = mask_rcnn.model(num_classes=91)
    model.load_state_dict(pretrained_state_dict)

    infer_dl = mask_rcnn.infer_dl(dataset=sample_dataset, batch_size=2)
    # Only the predictions are validated here; the samples are not needed.
    _samples, preds = mask_rcnn.predict_dl(
        model=model, infer_dl=infer_dl, show_pbar=False
    )

    _test_preds(preds)
def test_mantis_mask_rcnn_predict_dl_threshold(sample_dataset, pretrained_state_dict):
    """Check that ``detection_threshold=1.0`` leaves the first prediction with no labels.

    Presumably detection scores never exceed 1.0, so a threshold of 1.0 should
    filter out every detection -- NOTE(review): confirm the comparison used by
    ``mask_rcnn.predict_dl``.
    """
    net = mask_rcnn.model(num_classes=91)
    net.load_state_dict(pretrained_state_dict)
    loader = mask_rcnn.infer_dl(dataset=sample_dataset, batch_size=2)

    # predict_dl yields (samples, preds); only the predictions matter here.
    _, predictions = mask_rcnn.predict_dl(
        model=net,
        infer_dl=loader,
        show_pbar=False,
        detection_threshold=1.0,
    )

    first_prediction = predictions[0]
    assert len(first_prediction["labels"]) == 0