def test_keypoints_rcnn_model():
    """Check param-group count and keypoint-head width of keypoint R-CNN models.

    Builds one model with the default backbone (1 keypoint) and one with an
    explicit pretrained ``resnet18_fpn`` backbone (10 keypoints). For each,
    asserts the length of ``model.param_groups()`` and the ``out_channels`` of
    ``roi_heads.keypoint_predictor.kps_score_lowres``.
    """
    # Default backbone, single keypoint.
    model = keypoint_rcnn.model(num_keypoints=1)
    param_groups = model.param_groups()
    assert len(param_groups) == 8
    assert model.roi_heads.keypoint_predictor.kps_score_lowres.out_channels == 1

    # Explicit backbone, ten keypoints.
    backbone = models.torchvision.keypoint_rcnn.backbones.resnet18_fpn(
        pretrained=True
    )
    model = keypoint_rcnn.model(backbone=backbone, num_keypoints=10)
    # Recompute for this model. Reusing the list from the first model would
    # leave the second model's param groups unchecked.
    # NOTE(review): expects 8 groups for resnet18_fpn; confirm it matches the
    # value used for this backbone in test_keypoint_rcnn_fpn_backbones.
    param_groups = model.param_groups()
    assert len(param_groups) == 8
    assert model.roi_heads.keypoint_predictor.kps_score_lowres.out_channels == 10
def test_keypoints_rcnn_show_results(ochuman_ds, monkeypatch):
    """Smoke-test ``keypoint_rcnn.show_results`` on one validation sample.

    ``plt.show`` is replaced with a no-op via ``monkeypatch`` so the test does
    not open a plotting window.
    """
    monkeypatch.setattr(plt, "show", lambda: None)
    # Only the validation split is used; the training split is discarded.
    _, valid_ds = ochuman_ds
    model = keypoint_rcnn.model(num_keypoints=19)
    keypoint_rcnn.show_results(
        model=model, dataset=valid_ds, num_samples=1, ncols=1
    )
def test_predict_keypoints_rcnn_train(ochuman_ds):
    """Run inference with an untrained keypoint R-CNN and sanity-check output.

    Despite the ``_train`` suffix, nothing is trained here. The test builds an
    inference dataloader over the validation split (batch size 2), predicts,
    and checks that:
    - the loader has exactly one batch;
    - there are two predictions;
    - per-detection keypoint scores, keypoints, and boxes have matching counts.
    """
    _, valid_ds = ochuman_ds
    net = keypoint_rcnn.model(num_keypoints=19)
    loader = keypoint_rcnn.infer_dl(valid_ds, batch_size=2)
    predictions = keypoint_rcnn.predict_from_dl(model=net, infer_dl=loader)
    first_pred = predictions[0].pred

    assert len(loader) == 1
    assert len(predictions) == 2

    detection = first_pred.detection
    # keypoints_scores is materialised with list() before len(), as in the
    # original; its container type is not visible here.
    assert len(list(detection.keypoints_scores)) == len(detection.keypoints)
    assert len(detection.bboxes) == len(detection.keypoints)
def test_lightining_keypoints_rcnn_train(ochuman_keypoints_dls, light_model_cls):
    """Smoke-test one epoch of PyTorch Lightning training for keypoint R-CNN.

    Wraps a 19-keypoint model in ``light_model_cls`` and fits it for a single
    epoch. Sanity-validation steps, logging, and checkpointing are turned off.
    """
    train_dl, valid_dl = ochuman_keypoints_dls
    light_model = light_model_cls(keypoint_rcnn.model(num_keypoints=19))

    # NOTE(review): `weights_summary` and `checkpoint_callback` were deprecated
    # in later PyTorch Lightning releases (replaced by `enable_model_summary` /
    # `enable_checkpointing`). Confirm against the project's pinned version.
    trainer_options = dict(
        max_epochs=1,
        weights_summary=None,
        num_sanity_val_steps=0,
        logger=False,
        checkpoint_callback=False,
    )
    pl.Trainer(**trainer_options).fit(light_model, train_dl, valid_dl)
def test_fastai_keypoints_rcnn_train(ochuman_keypoints_dls):
    """Smoke-test fastai fine-tuning of a 19-keypoint R-CNN for one epoch."""
    model = keypoint_rcnn.model(num_keypoints=19)
    # The learner receives the dataloaders fixture as a plain list.
    dataloaders = list(ochuman_keypoints_dls)
    learner = keypoint_rcnn.fastai.learner(dls=dataloaders, model=model)
    learner.fine_tune(1, 1e-4)
def test_keypoint_rcnn_fpn_backbones(model_name, param_groups_len):
    """Check ``param_groups()`` length for a keypoint R-CNN built on a named backbone.

    ``model_name`` is looked up as an attribute of
    ``models.torchvision.keypoint_rcnn.backbones`` and instantiated with
    ``pretrained=False``. The resulting model's ``param_groups()`` must have
    ``param_groups_len`` entries.
    """
    # NOTE(review): model_name / param_groups_len are presumably supplied by a
    # pytest.mark.parametrize decorator that is not visible here. Confirm it
    # still precedes this function.
    backbone_registry = models.torchvision.keypoint_rcnn.backbones
    make_backbone = getattr(backbone_registry, model_name)
    model = keypoint_rcnn.model(
        num_keypoints=2,
        num_classes=4,
        backbone=make_backbone(pretrained=False),
    )
    assert len(model.param_groups()) == param_groups_len