def test_detection_learner_init_model(od_detection_dataset):
    """Tests detection learner with model settings.

    A Faster R-CNN built with non-default ``min_size``/``max_size`` is passed
    to ``DetectionLearner``. The learner must keep that exact model object
    rather than constructing its own.
    """
    classes = len(od_detection_dataset.labels)
    model = get_pretrained_fasterrcnn(
        num_classes=classes, min_size=600, max_size=2000
    )
    learner = DetectionLearner(od_detection_dataset, model=model)
    assert isinstance(learner, DetectionLearner)
    # Identity, not equality: the learner must hold the very instance we gave
    # it. This also subsumes the old `!= get_pretrained_fasterrcnn(classes)`
    # check, which compared against a brand-new module (default object
    # comparison is identity-based) and therefore could never fail, while
    # building an extra pretrained model for nothing.
    assert learner.model is model
# Example 2

def od_detection_learner(od_detection_dataset):
    """Build a small Faster R-CNN detection learner and fit it for one epoch.

    Image sizes and RPN proposal counts are set well below typical values,
    presumably to keep this training run cheap in tests.
    """
    # One class more than the dataset labels; presumably the background class (confirm).
    n_classes = len(od_detection_dataset.labels) + 1
    rpn_settings = dict(
        rpn_pre_nms_top_n_train=500,
        rpn_pre_nms_top_n_test=250,
        rpn_post_nms_top_n_train=500,
        rpn_post_nms_top_n_test=250,
    )
    detector = get_pretrained_fasterrcnn(
        num_classes=n_classes, min_size=100, max_size=200, **rpn_settings
    )
    trained = DetectionLearner(od_detection_dataset, model=detector)
    trained.fit(1)
    return trained
def test_get_pretrained_fasterrcnn():
    """Check that `get_pretrained_fasterrcnn` returns a `FasterRCNN` model.

    For now this only verifies the returned type, not the model's settings.
    """
    model = get_pretrained_fasterrcnn(4)
    assert isinstance(model, FasterRCNN)