def test_detection_learner_init_model(od_detection_dataset):
    """ Tests detection learner with model settings.

    Builds a Faster R-CNN model with custom ``min_size``/``max_size``, passes
    it to ``DetectionLearner`` and checks that the learner holds that exact
    model instance.
    """
    # Same class count as the ``od_detection_learner`` fixture (labels + 1);
    # presumably the extra slot is the background class torchvision expects.
    classes = len(od_detection_dataset.labels) + 1
    model = get_pretrained_fasterrcnn(
        num_classes=classes, min_size=600, max_size=2000
    )
    learner = DetectionLearner(od_detection_dataset, model=model)
    assert isinstance(learner, DetectionLearner)
    # Identity check: the learner must keep the very object that was passed in.
    assert learner.model is model
    # A freshly built model is a different object, so it must not be the one held.
    assert learner.model is not get_pretrained_fasterrcnn(classes)
def od_detection_learner(od_detection_dataset):
    """ Provides a DetectionLearner that has already been fit for one epoch.

    The underlying Faster R-CNN uses small image sizes (100-200) and reduced
    RPN proposal counts, presumably to keep the single training epoch cheap.
    """
    n_classes = len(od_detection_dataset.labels) + 1
    rpn_kwargs = dict(
        rpn_pre_nms_top_n_train=500,
        rpn_pre_nms_top_n_test=250,
        rpn_post_nms_top_n_train=500,
        rpn_post_nms_top_n_test=250,
    )
    fasterrcnn = get_pretrained_fasterrcnn(
        num_classes=n_classes, min_size=100, max_size=200, **rpn_kwargs
    )
    detection_learner = DetectionLearner(od_detection_dataset, model=fasterrcnn)
    detection_learner.fit(1)
    return detection_learner
def test_get_pretrained_fasterrcnn():
    """ Checks that `get_pretrained_fasterrcnn` returns a `FasterRCNN` model.

    Only the return type is verified for now.
    """
    model = get_pretrained_fasterrcnn(4)
    assert isinstance(model, FasterRCNN)