Exemplo n.º 1
0
def od_detection_mask_learner(od_detection_mask_dataset):
    """Return a mask detection learner that has been trained for one epoch.

    Builds a pretrained Mask R-CNN model whose class count is the number of
    dataset labels plus one (presumably a background class — TODO confirm).
    The model uses small image bounds (min 100, max 200) and reduced RPN
    proposal counts. It is wrapped in a ``DetectionLearner`` and fitted for
    a single epoch.

    Args:
        od_detection_mask_dataset: dataset exposing a ``labels`` sequence;
            also passed unchanged to ``DetectionLearner``.

    Returns:
        The ``DetectionLearner`` after ``fit(1)`` has run.
    """
    # RPN proposal limits shared by the train and test phases.
    rpn_settings = dict(
        rpn_pre_nms_top_n_train=500,
        rpn_pre_nms_top_n_test=250,
        rpn_post_nms_top_n_train=500,
        rpn_post_nms_top_n_test=250,
    )
    class_count = len(od_detection_mask_dataset.labels) + 1

    mask_model = get_pretrained_maskrcnn(
        num_classes=class_count,
        min_size=100,
        max_size=200,
        **rpn_settings,
    )

    trained = DetectionLearner(od_detection_mask_dataset, model=mask_model)
    trained.fit(1)
    return trained
Exemplo n.º 2
0
def od_detection_keypoint_learner(tiny_od_detection_keypoint_dataset):
    """Return a keypoint detection learner that has been trained for one epoch.

    Builds a pretrained Keypoint R-CNN model. Its class count is the number
    of dataset labels plus one (presumably a background class — TODO
    confirm). Its keypoint count is the length of the dataset's
    ``keypoint_meta["labels"]``. The model uses small image bounds (min 100,
    max 200) and reduced RPN proposal counts. It is wrapped in a
    ``DetectionLearner`` and fitted for one epoch with evaluation skipped.

    Args:
        tiny_od_detection_keypoint_dataset: dataset exposing ``labels`` and a
            ``keypoint_meta`` mapping with a ``"labels"`` entry; also passed
            unchanged to ``DetectionLearner``.

    Returns:
        The ``DetectionLearner`` after ``fit(1, skip_evaluation=True)`` has run.
    """
    dataset = tiny_od_detection_keypoint_dataset

    # RPN proposal limits shared by the train and test phases.
    rpn_settings = dict(
        rpn_pre_nms_top_n_train=500,
        rpn_pre_nms_top_n_test=250,
        rpn_post_nms_top_n_train=500,
        rpn_post_nms_top_n_test=250,
    )
    class_count = len(dataset.labels) + 1
    keypoint_count = len(dataset.keypoint_meta["labels"])

    keypoint_model = get_pretrained_keypointrcnn(
        num_classes=class_count,
        num_keypoints=keypoint_count,
        min_size=100,
        max_size=200,
        **rpn_settings,
    )

    trained = DetectionLearner(dataset, model=keypoint_model)
    # Only one epoch is run, and the evaluation pass is explicitly skipped.
    trained.fit(1, skip_evaluation=True)
    return trained