Beispiel #1
0
def od_detection_keypoint_learner(tiny_od_detection_keypoint_dataset):
    """Build a keypoint detection learner and train it for exactly one epoch.

    A Keypoint R-CNN is created with ``get_pretrained_keypointrcnn`` using
    ``min_size=100`` / ``max_size=200`` and RPN proposal limits of 500
    (train) / 250 (test), before and after NMS. It is wrapped in a
    ``DetectionLearner`` and fit for one epoch with evaluation skipped.

    Args:
        tiny_od_detection_keypoint_dataset: dataset object exposing
            ``labels`` and ``keypoint_meta["labels"]``.

    Returns:
        The ``DetectionLearner`` after its single training epoch.
    """
    dataset = tiny_od_detection_keypoint_dataset

    # One class more than the dataset's labels (presumably the background
    # class — confirm against get_pretrained_keypointrcnn).
    class_count = len(dataset.labels) + 1
    keypoint_count = len(dataset.keypoint_meta["labels"])

    # RPN proposal limits, grouped so the model call below stays readable.
    rpn_limits = dict(
        rpn_pre_nms_top_n_train=500,
        rpn_pre_nms_top_n_test=250,
        rpn_post_nms_top_n_train=500,
        rpn_post_nms_top_n_test=250,
    )

    keypoint_model = get_pretrained_keypointrcnn(
        num_classes=class_count,
        num_keypoints=keypoint_count,
        min_size=100,
        max_size=200,
        **rpn_limits,
    )

    trained_learner = DetectionLearner(dataset, model=keypoint_model)
    trained_learner.fit(1, skip_evaluation=True)
    return trained_learner
Beispiel #2
0
def test_get_pretrained_keypointrcnn():
    """Check that `get_pretrained_keypointrcnn` returns a `KeypointRCNN` model.

    For now this only verifies the type of the returned object.
    """
    # Positional args are presumably (num_classes=2, num_keypoints=6);
    # NOTE(review): confirm against the helper's signature.
    model = get_pretrained_keypointrcnn(2, 6)
    # Use isinstance rather than `type(x) == T` (flagged by pycodestyle/ruff
    # E721); the message reports the actual type on failure.
    assert isinstance(model, KeypointRCNN), (
        f"expected KeypointRCNN, got {type(model).__name__}"
    )