Example #1
0
def test_keypoints_rcnn_model():
    """Build keypoint R-CNN models and check param groups and predictor width.

    Two models are checked: one built without an explicit backbone, and one
    built on a ``resnet18_fpn`` backbone. For each, the number of parameter
    groups and the output channels of the low-resolution keypoint score
    layer (which should equal ``num_keypoints``) are asserted.
    """
    model = keypoint_rcnn.model(num_keypoints=1)
    param_groups = model.param_groups()

    assert len(param_groups) == 8
    assert model.roi_heads.keypoint_predictor.kps_score_lowres.out_channels == 1

    backbone = models.torchvision.keypoint_rcnn.backbones.resnet18_fpn(
        pretrained=True)
    model = keypoint_rcnn.model(backbone=backbone, num_keypoints=10)
    # Recompute for the new model: reusing the first model's groups would
    # make the assertion below vacuous (it would never inspect this model).
    param_groups = model.param_groups()

    assert len(param_groups) == 8
    assert model.roi_heads.keypoint_predictor.kps_score_lowres.out_channels == 10
Example #2
0
def test_keypoints_rcnn_show_results(ochuman_ds, monkeypatch):
    """Smoke-test ``keypoint_rcnn.show_results`` on one validation sample.

    ``plt.show`` is replaced with a no-op for the duration of the test, so
    the call renders without opening a plotting window.
    """
    monkeypatch.setattr(plt, "show", lambda: None)
    # Only the validation split is needed here.
    _, valid_ds = ochuman_ds
    model = keypoint_rcnn.model(num_keypoints=19)

    keypoint_rcnn.show_results(model=model,
                               dataset=valid_ds,
                               num_samples=1,
                               ncols=1)
Example #3
0
def test_predict_keypoints_rcnn_train(ochuman_ds):
    """Run inference over the validation split and sanity-check the output.

    Despite the ``_train`` suffix, this only exercises prediction. It builds
    an inference dataloader, predicts with an untrained 19-keypoint model,
    and checks that the first prediction's boxes, keypoints and keypoint
    scores have matching lengths.
    """
    _, valid_split = ochuman_ds
    net = keypoint_rcnn.model(num_keypoints=19)

    loader = keypoint_rcnn.infer_dl(valid_split, batch_size=2)
    predictions = keypoint_rcnn.predict_from_dl(model=net, infer_dl=loader)
    first_pred = predictions[0].pred

    # The asserts below expect 2 predictions from a single batch of size 2.
    assert len(loader) == 1
    assert len(predictions) == 2

    det = first_pred.detection
    assert len(list(det.keypoints_scores)) == len(det.keypoints)
    assert len(det.bboxes) == len(det.keypoints)
Example #4
0
def test_lightining_keypoints_rcnn_train(ochuman_keypoints_dls,
                                         light_model_cls):
    """Fit a Lightning-wrapped keypoint R-CNN for a single epoch.

    The model summary, sanity validation steps, logging and checkpointing
    are all switched off to keep the run minimal.
    """
    train_loader, valid_loader = ochuman_keypoints_dls
    lightning_module = light_model_cls(keypoint_rcnn.model(num_keypoints=19))

    # NOTE(review): `weights_summary` and `checkpoint_callback` were
    # deprecated in later PyTorch Lightning releases — confirm against the
    # pinned version.
    trainer_options = dict(
        max_epochs=1,
        weights_summary=None,
        num_sanity_val_steps=0,
        logger=False,
        checkpoint_callback=False,
    )
    trainer = pl.Trainer(**trainer_options)
    trainer.fit(lightning_module, train_loader, valid_loader)
Example #5
0
def test_fastai_keypoints_rcnn_train(ochuman_keypoints_dls):
    """Fine-tune a fastai learner wrapping a 19-keypoint R-CNN for one epoch."""
    net = keypoint_rcnn.model(num_keypoints=19)
    # The learner receives the fixture's dataloaders as a plain list.
    dataloaders = list(ochuman_keypoints_dls)
    learner = keypoint_rcnn.fastai.learner(dls=dataloaders, model=net)

    learner.fine_tune(1, 1e-4)
Example #6
0
def test_keypoint_rcnn_fpn_backbones(model_name, param_groups_len):
    """Check the param-group count of a keypoint R-CNN on a named backbone.

    ``model_name`` is looked up as an attribute of
    ``models.torchvision.keypoint_rcnn.backbones`` and called with
    ``pretrained=False``. ``param_groups_len`` is the expected number of
    parameter groups. Both are presumably supplied by a
    ``pytest.mark.parametrize`` decorator not visible here — confirm.
    """
    backbone_factories = models.torchvision.keypoint_rcnn.backbones
    make_backbone = getattr(backbone_factories, model_name)
    built = keypoint_rcnn.model(
        num_keypoints=2,
        num_classes=4,
        backbone=make_backbone(pretrained=False),
    )
    assert len(built.param_groups()) == param_groups_len