def od_detection_keypoint_learner(tiny_od_detection_keypoint_dataset):
    """Build a keypoint-detection learner and train it for a single epoch.

    The model is created via ``get_pretrained_keypointrcnn`` with deliberately
    small input sizes and reduced RPN proposal counts so the training run
    stays cheap. Evaluation is skipped during ``fit``.

    Args:
        tiny_od_detection_keypoint_dataset: dataset exposing ``labels`` and
            ``keypoint_meta["labels"]``; it is also handed to the learner.

    Returns:
        The ``DetectionLearner`` after one epoch of training.
    """
    data = tiny_od_detection_keypoint_dataset

    # One extra class on top of the dataset labels -- presumably the
    # background class expected by the detector; confirm against the helper.
    class_count = len(data.labels) + 1
    keypoint_count = len(data.keypoint_meta["labels"])

    keypoint_model = get_pretrained_keypointrcnn(
        num_classes=class_count,
        num_keypoints=keypoint_count,
        # Small image bounds and proposal counts keep the run fast.
        min_size=100,
        max_size=200,
        rpn_pre_nms_top_n_train=500,
        rpn_pre_nms_top_n_test=250,
        rpn_post_nms_top_n_train=500,
        rpn_post_nms_top_n_test=250,
    )

    trained = DetectionLearner(data, model=keypoint_model)
    trained.fit(1, skip_evaluation=True)
    return trained
def test_get_pretrained_keypointrcnn():
    """Check that ``get_pretrained_keypointrcnn`` returns a ``KeypointRCNN``.

    For now this only verifies the type of the returned model, built with
    2 classes and 6 keypoints.
    """
    model = get_pretrained_keypointrcnn(2, 6)
    # isinstance is the idiomatic type check (also accepts subclasses),
    # rather than comparing type() with ==.
    assert isinstance(model, KeypointRCNN), (
        f"expected KeypointRCNN, got {type(model).__name__}"
    )