Exemplo n.º 1
0
    def test_can_detect(self):
        """Check that the importer detects output written by the converter.

        A small in-memory extractor is exported with
        ``TfDetectionApiConverter`` into a temporary directory. The extractor
        yields one 'train' item carrying one ``Bbox``.
        ``TfDetectionApiImporter.detect`` is then expected to return a truthy
        value for that directory.
        """
        class TestExtractor(Extractor):
            def __iter__(self):
                # One item: an array of ones with shape (16, 16, 3) plus a
                # single Bbox annotation with label index 2.
                sample = DatasetItem(
                    id=1,
                    subset='train',
                    image=np.ones((16, 16, 3)),
                    annotations=[Bbox(0, 4, 4, 8, label=2)],
                )
                return iter([sample])

            def categories(self):
                # Ten labels named label_0 .. label_9, so label index 2 is valid.
                labels = LabelCategories()
                for idx in range(10):
                    labels.add('label_' + str(idx))
                return {AnnotationType.label: labels}

        with TestDir() as test_dir:
            # Write the dummy dataset, then ask the importer whether it
            # recognises the directory.
            TfDetectionApiConverter()(TestExtractor(), save_dir=test_dir)
            self.assertTrue(TfDetectionApiImporter.detect(test_dir))
Exemplo n.º 2
0
 def test_can_detect(self):
     """Assert that the importer detects the DUMMY_DATASET_DIR directory."""
     is_detected = TfDetectionApiImporter.detect(DUMMY_DATASET_DIR)
     self.assertTrue(is_detected)