def test_mask_rcnn_mobile_param_groups():
    """Check that a Mask R-CNN using the mobilenet backbone has 6 param groups."""
    backbone_factory = models.torchvision.mask_rcnn.backbones.mobilenet
    mobilenet_backbone = backbone_factory(pretrained=False)
    net = mask_rcnn.model(num_classes=6, backbone=mobilenet_backbone)

    assert len(net.param_groups()) == 6
def test_mantis_mask_rcnn_predict_dl(sample_dataset, pretrained_state_dict):
    """Predict through an inference dataloader and validate the predictions.

    Builds a 91-class model, loads ``pretrained_state_dict`` into it, runs
    ``mask_rcnn.predict_dl`` over ``sample_dataset`` in batches of 2 and hands
    the result to ``_test_preds``.
    """
    net = mask_rcnn.model(num_classes=91)
    net.load_state_dict(pretrained_state_dict)

    loader = mask_rcnn.infer_dl(dataset=sample_dataset, batch_size=2)
    # NOTE(review): another test in this file unpacks the result of
    # `predict_dl` as `(samples, preds)`; confirm which return shape is
    # current before relying on this being the predictions alone.
    predictions = mask_rcnn.predict_dl(model=net, infer_dl=loader, show_pbar=False)

    _test_preds(predictions)
def test_mantis_mask_rcnn_predict(sample_dataset, pretrained_state_dict):
    """Predict on a single pre-built inference batch and validate the result.

    NOTE(review): a second function with this exact name is defined later in
    this file. If both live in the same module, the later definition rebinds
    the name and this body is never collected by pytest -- confirm which
    variant should be kept.
    """
    net = mask_rcnn.model(num_classes=91)
    net.load_state_dict(pretrained_state_dict)

    # The second element of the returned pair is not needed by this test.
    infer_batch, _samples = mask_rcnn.build_infer_batch(dataset=sample_dataset)
    predictions = mask_rcnn.predict(model=net, batch=infer_batch)

    _test_preds(predictions)
def test_mantis_mask_rcnn_predict_dl_threshold(sample_dataset, pretrained_state_dict):
    """With ``detection_threshold=1.0`` the first prediction must have no labels."""
    net = mask_rcnn.model(num_classes=91)
    net.load_state_dict(pretrained_state_dict)
    loader = mask_rcnn.infer_dl(dataset=sample_dataset, batch_size=2)

    # Only the predictions are inspected; the first returned value is unused.
    _samples, predictions = mask_rcnn.predict_dl(
        model=net,
        infer_dl=loader,
        show_pbar=False,
        detection_threshold=1.0,
    )

    first_prediction = predictions[0]
    assert len(first_prediction["labels"]) == 0
def test_mantis_mask_rcnn_predict(sample_dataset, pretrained_state_dict):
    """Predict directly from a dataset and validate the predictions.

    NOTE(review): an earlier function in this file has this exact name. If
    both live in the same module, this later definition replaces the earlier
    one -- confirm which variant should be kept.
    """
    net = mask_rcnn.model(num_classes=91)
    net.load_state_dict(pretrained_state_dict)

    predictions = mask_rcnn.predict(model=net, dataset=sample_dataset)

    _test_preds(predictions)
def test_mask_rcnn_fpn_backbones(model_name, param_groups_len):
    """Check the number of param groups for the backbone named ``model_name``.

    The backbone factory is looked up by name on
    ``models.torchvision.mask_rcnn.backbones`` and called with
    ``pretrained=False``.

    NOTE(review): ``model_name`` and ``param_groups_len`` are presumably
    supplied by a parametrization that is not visible here -- confirm.
    """
    backbones_module = models.torchvision.mask_rcnn.backbones
    make_backbone = getattr(backbones_module, model_name)
    net = mask_rcnn.model(num_classes=4, backbone=make_backbone(pretrained=False))

    group_count = len(net.param_groups())
    assert group_count == param_groups_len
def test_mask_rcnn_default_param_groups():
    """Check that a model built without an explicit backbone has 8 param groups."""
    net = mask_rcnn.model(num_classes=4)

    assert len(net.param_groups()) == 8