Ejemplo n.º 1
0
    def _train_and_test(self,
                        model_path,
                        output_model_path,
                        training_datapoints,
                        test_datapoints,
                        keep_classes=False):
        """Imprint a model, save it, reload it and check its classifications.

        Args:
          model_path: Base model handed to ImprintingEngine.
          output_model_path: File the imprinted model is saved to and then
            reloaded from with ClassificationEngine.
          training_datapoints: Sized iterable of dicts with 'image_names' (file
            names under the 'imprinting' test-data directory) and 'label_id';
            each entry results in one engine.train() call.
          test_datapoints: Iterable of dicts with 'image_name', 'label_id' and
            'score', each forwarded to self._classify_image.
          keep_classes: Forwarded to ImprintingEngine. When False, the saved
            model's single output tensor must have one entry per training
            datapoint.
        """
        trainer = ImprintingEngine(model_path, keep_classes)
        input_shape = self._get_image_shape(model_path)
        imprinting_dir = test_utils.test_data_path('imprinting')

        # Training phase: one train() call per datapoint, in the given order.
        for datapoint in training_datapoints:
            images = test_utils.prepare_images(datapoint['image_names'],
                                               imprinting_dir, input_shape)
            trainer.train(images, datapoint['label_id'])
        trainer.save_model(output_model_path)

        # Evaluation phase: reload the saved file as a plain classifier.
        classifier = ClassificationEngine(output_model_path)
        # The imprinting engine is no longer needed once the model is saved.
        del trainer
        self.assertEqual(1, classifier.get_num_of_output_tensors())
        if not keep_classes:
            # Only newly imprinted classes exist, so the output width must
            # match the number of training datapoints.
            expected_width = len(training_datapoints)
            self.assertEqual(expected_width,
                             classifier.get_output_tensor_size(0))
        for datapoint in test_datapoints:
            expectation = (datapoint['image_name'], datapoint['label_id'],
                           datapoint['score'])
            self._classify_image(classifier, imprinting_dir, *expectation)
Ejemplo n.º 2
0
    def test_train_all(self):
        """Imprint every model in _MODEL_LIST via train_all() and verify it.

        For each model: trains three classes (cat, dog, hot dog) with a single
        train_all() call, saves the result to a temporary .tflite file, reloads
        it with ClassificationEngine, checks for one output tensor whose size
        equals the number of trained classes, then classifies one held-out
        image per class.
        """
        data_dir = test_utils.test_data_path('imprinting')
        # Training images per class. The expected label ids below follow the
        # position of each class in this list (cat=0, dog=1, hot_dog=2).
        train_set = [['cat_train_0.bmp'], ['dog_train_0.bmp'],
                     ['hotdog_train_0.bmp', 'hotdog_train_1.bmp']]
        label_to_id_map = {'cat': 0, 'dog': 1, 'hot_dog': 2}
        # (test image, class label, expected score) passed to _classify_image.
        test_cases = [
            ('cat_test_0.bmp', 'cat', 0.99),
            ('dog_test_0.bmp', 'dog', 0.99),
            ('hotdog_test_0.bmp', 'hot_dog', 0.99),
        ]

        for model_path in self._MODEL_LIST:
            # Tag each subtest with its model so a failure names the model.
            with self.subTest(model_path=model_path), test_utils.TemporaryFile(
                    suffix='.tflite') as output_model_path:
                engine = ImprintingEngine(model_path, keep_classes=False)
                image_shape = self._get_image_shape(model_path)

                # Train: all classes in a single train_all() call.
                train_input = [
                    test_utils.prepare_images(image_list, data_dir,
                                              image_shape)
                    for image_list in train_set
                ]
                engine.train_all(train_input)
                engine.save_model(output_model_path.name)

                # Test: rebinding `engine` releases the imprinting engine.
                engine = ClassificationEngine(output_model_path.name)
                self.assertEqual(1, engine.get_num_of_output_tensors())
                self.assertEqual(len(train_set),
                                 engine.get_output_tensor_size(0))
                for image_name, label, score in test_cases:
                    self._classify_image(engine, data_dir, image_name,
                                         label_to_id_map[label], score)