Example #1
0
def test_faster_rcnn_mobile_param_groups():
    """A Faster R-CNN on the ``mobilenet`` backbone exposes 6 param groups.

    The backbone is built with ``pretrained=False`` and the model with
    ``num_classes=6``.
    """
    mobilenet_backbone = models.torchvision.faster_rcnn.backbones.mobilenet(
        pretrained=False
    )
    detector = faster_rcnn.model(num_classes=6, backbone=mobilenet_backbone)

    groups = detector.param_groups()
    assert len(groups) == 6
Example #2
0
def test_e2e_detect(samples_source, fridge_class_map, model_name,
                    param_groups_len):
    """Run end-to-end detection on one fridge image and expect no boxes.

    Args:
        samples_source: Base path that the sample image path is joined onto
            with ``/``.
        fridge_class_map: Class map forwarded to ``end2end_detect``.
        model_name: Name of a backbone factory looked up on
            ``models.torchvision.faster_rcnn.backbones``.
        param_groups_len: Accepted but not used by this test. It presumably
            comes from a shared ``parametrize`` list, so it must stay in the
            signature. TODO confirm.
    """
    image_file = samples_source / "fridge/odFridgeObjects/images/10.jpg"
    transforms = tfms.A.Adapter([A.Resize(384, 384), A.Normalize()])

    make_backbone = getattr(models.torchvision.faster_rcnn.backbones, model_name)
    resolved_backbone = make_backbone(pretrained=False)
    detector = faster_rcnn.model(num_classes=4, backbone=resolved_backbone)

    prediction = faster_rcnn.end2end_detect(
        image_file,
        transforms,
        detector,
        fridge_class_map,
        detection_threshold=1,
    )
    # The test expects zero boxes when detection_threshold is 1. len() is used
    # rather than truthiness in case bboxes is an array-like type.
    found_boxes = prediction["detection"]["bboxes"]
    assert len(found_boxes) == 0
Example #3
0
def fridge_faster_rcnn_model() -> nn.Module:
    """Build a 5-class Faster R-CNN on a ResNet-18 FPN backbone.

    The backbone comes from ``faster_rcnn.backbones.resnet_fpn.resnet18``
    and is built with ``pretrained=False``.
    """
    # NOTE(review): another ``fridge_faster_rcnn_model`` is defined later in
    # this file and shadows this one at import time.
    return faster_rcnn.model(
        num_classes=5,
        backbone=faster_rcnn.backbones.resnet_fpn.resnet18(pretrained=False),
    )
Example #4
0
def test_faster_rcnn_default_param_groups():
    """A 4-class Faster R-CNN built without an explicit backbone has 8 param groups."""
    detector = faster_rcnn.model(num_classes=4)

    expected_group_count = 8
    assert len(detector.param_groups()) == expected_group_count
def fridge_faster_rcnn_model() -> nn.Module:
    """Build a 5-class Faster R-CNN on the ``resnet18_fpn`` backbone.

    The backbone comes from ``models.torchvision.faster_rcnn.backbones`` and
    is built with ``pretrained=False``.
    """
    # NOTE(review): this redefines an earlier ``fridge_faster_rcnn_model`` in
    # this file, so this definition is the one that wins at import time.
    backbone_factories = models.torchvision.faster_rcnn.backbones
    resnet_backbone = backbone_factories.resnet18_fpn(pretrained=False)
    return faster_rcnn.model(backbone=resnet_backbone, num_classes=5)
Example #6
0
def test_faster_rcnn_fpn_backbones(model_name, param_groups_len):
    """Check the param-group count for a Faster R-CNN on a named backbone.

    Args:
        model_name: Name of a backbone factory looked up on
            ``models.torchvision.faster_rcnn.backbones``.
        param_groups_len: Expected number of entries in
            ``model.param_groups()``.
    """
    factory = getattr(models.torchvision.faster_rcnn.backbones, model_name)
    chosen_backbone = factory(pretrained=False)

    detector = faster_rcnn.model(num_classes=4, backbone=chosen_backbone)
    actual_group_count = len(detector.param_groups())
    assert actual_group_count == param_groups_len