コード例 #1
0
ファイル: test_model.py プロジェクト: potipot/icevision
def test_mask_rcnn_mobile_param_groups():
    backbone = models.torchvision.mask_rcnn.backbones.mobilenet(
        pretrained=False)
    model = mask_rcnn.model(num_classes=6, backbone=backbone)

    param_groups = model.param_groups()
    assert len(param_groups) == 6
コード例 #2
0
ファイル: test_predict.py プロジェクト: sasakits/icevision
def test_mantis_mask_rcnn_predict_dl(sample_dataset, pretrained_state_dict):
    model = mask_rcnn.model(num_classes=91)
    model.load_state_dict(pretrained_state_dict)

    infer_dl = mask_rcnn.infer_dl(dataset=sample_dataset, batch_size=2)
    preds = mask_rcnn.predict_dl(model=model, infer_dl=infer_dl, show_pbar=False)
    _test_preds(preds)
コード例 #3
0
def test_mantis_mask_rcnn_predict(sample_dataset, pretrained_state_dict):
    model = mask_rcnn.model(num_classes=91)
    model.load_state_dict(pretrained_state_dict)

    batch, samples = mask_rcnn.build_infer_batch(dataset=sample_dataset)
    preds = mask_rcnn.predict(model=model, batch=batch)

    _test_preds(preds)
コード例 #4
0
def test_mantis_mask_rcnn_predict_dl_threshold(sample_dataset, pretrained_state_dict):
    model = mask_rcnn.model(num_classes=91)
    model.load_state_dict(pretrained_state_dict)

    infer_dl = mask_rcnn.infer_dl(dataset=sample_dataset, batch_size=2)
    samples, preds = mask_rcnn.predict_dl(
        model=model,
        infer_dl=infer_dl,
        show_pbar=False,
        detection_threshold=1.0,
    )

    assert len(preds[0]["labels"]) == 0
コード例 #5
0
ファイル: test_predict.py プロジェクト: kedarisetti/icevision
def test_mantis_mask_rcnn_predict(sample_dataset, pretrained_state_dict):
    model = mask_rcnn.model(num_classes=91)
    model.load_state_dict(pretrained_state_dict)

    preds = mask_rcnn.predict(model=model, dataset=sample_dataset)
    _test_preds(preds)
コード例 #6
0
def test_mask_rcnn_fpn_backbones(model_name, param_groups_len):
    backbone_fn = getattr(models.torchvision.mask_rcnn.backbones, model_name)
    backbone = backbone_fn(pretrained=False)

    model = mask_rcnn.model(num_classes=4, backbone=backbone)
    assert len(model.param_groups()) == param_groups_len
コード例 #7
0
ファイル: test_model.py プロジェクト: potipot/icevision
def test_mask_rcnn_default_param_groups():
    model = mask_rcnn.model(num_classes=4)

    param_groups = model.param_groups()
    assert len(param_groups) == 8