예제 #1
0
    def test_keras_model_to_saved_model(self):
        """Save a model to ``.h5``, compile it from env settings, check output.

        ``_save_model`` writes a model into a temporary ``.h5`` file. Then
        ``compiler.compile_from_env()`` runs inside ``_use_env`` with a TF
        serving configuration (model ``foobar``, version ``4``) whose
        ``EXPORT_PATH`` is a temporary directory. The test asserts the exact
        sorted contents of that directory and of ``foobar/``, ``foobar/4/``
        and ``foobar/4/variables/``.
        """
        def listing(root, *parts):
            # Sorted entry names of the directory ``root/parts...``.
            return sorted(os.listdir(os.path.join(root, *parts)))

        # (path components below the export root, expected sorted entries)
        expected_layout = [
            ((), ['foobar', 'foobar_4.zip']),
            (('foobar',), ['4', 'config.pbtxt']),
            (('foobar', '4'), ['saved_model.pb', 'variables']),
            (('foobar', '4', 'variables'),
             ['variables.data-00000-of-00001', 'variables.index']),
        ]

        with NamedTemporaryFile(suffix='.h5') as model_file, \
                TemporaryDirectory() as target_dir:
            _save_model(path=model_file.name)

            # NOTE(review): presumably _use_env exposes these as environment
            # variables only for the duration of the block -- confirm.
            env = {
                'SERVING_TYPE': 'tf',
                'MODEL_NAME': 'foobar',
                'VERSION': '4',
                'MAX_BATCH_SIZE': '7',
                'H5_PATH': model_file.name,
                'INPUT_SIGNATURES': 'x',
                'OUTPUT_SIGNATURES': 'y',
                'EXPORT_PATH': target_dir,
            }
            with _use_env(env):
                compiler.compile_from_env()

            for parts, expected in expected_layout:
                self.assertEqual(listing(target_dir, *parts), expected)
예제 #2
0
    def test_keras_model_to_saved_model(self):
        """Compile ``.h5`` models saved with ``keras`` and ``tf.keras``.

        For each of ``keras`` and ``tf.keras``, inside its own ``subTest``:
        save a model with ``_save_model`` into a temporary ``.h5`` file, run
        ``compiler.compile_from_env()`` under ``_use_env`` with a TF serving
        configuration (model ``foobar``, version ``4``) exporting into a
        temporary directory, and assert the exact sorted contents of the
        export root, ``foobar/``, ``foobar/4/`` and ``foobar/4/variables/``.
        """
        def listing(root, *parts):
            # Sorted entry names of the directory ``root/parts...``.
            return sorted(os.listdir(os.path.join(root, *parts)))

        # (path components below the export root, expected sorted entries)
        expected_layout = [
            ((), ['foobar', 'foobar_4.zip']),
            (('foobar',), ['4', 'config.pbtxt']),
            (('foobar', '4'), ['saved_model.pb', 'variables']),
            (('foobar', '4', 'variables'),
             ['variables.data-00000-of-00001', 'variables.index']),
        ]

        for k in [keras, tf.keras]:
            # subTest is entered first (and exited last) so a failure for one
            # Keras flavour is recorded without skipping the other.
            with self.subTest(k=k):
                with NamedTemporaryFile(suffix='.h5') as model_file, \
                        TemporaryDirectory() as target_dir:
                    _save_model(k=k, path=model_file.name)

                    # NOTE(review): keys here are lowercase; presumably
                    # _use_env / compile_from_env accept this spelling --
                    # confirm against their definitions.
                    env = {
                        'serving_type': 'tf',
                        'model_name': 'foobar',
                        'version': '4',
                        'max_batch_size': '7',
                        'h5_path': model_file.name,
                        'input_signatures': 'x',
                        'output_signatures': 'y',
                        'export_path': target_dir,
                    }
                    with _use_env(env):
                        compiler.compile_from_env()

                    for parts, expected in expected_layout:
                        self.assertEqual(listing(target_dir, *parts),
                                         expected)
예제 #3
0
 def test_invalid_source(self):
     """Expect ``compiler.compile_from_env()`` to raise ``ValueError``.

     No configuration is supplied by this test, so the call sees whatever
     environment the runner has. NOTE(review): presumably no model source
     is configured there -- confirm the test cannot be affected by
     externally set variables.
     """
     self.assertRaises(ValueError, compiler.compile_from_env)