def test_compile_simple_fp16(self):
        """Compare model size with and without float16 optimization.

        The same SavedModel is compiled twice: once with the default config
        and once with ``optimization=True`` and ``supported_types=[float16]``.
        Both compilations use identical ``input_formats`` so that the size
        comparison reflects only the float16 optimization settings.
        """
        with TemporaryDirectory() as model_dir:
            _save_saved_model_file(model_dir)
            compiled_1 = compiler.compile_source(
                source=SavedModelFile(model_path=model_dir),
                config=Config(input_formats=[]))
            # Keep input_formats identical to compiled_1 so the only
            # difference between the two compilations is the fp16 setting.
            compiled_2 = compiler.compile_source(
                source=SavedModelFile(model_path=model_dir),
                config=Config(optimization=True,
                              supported_types=[tf.float16],
                              input_formats=[]))

            # NOTE(review): this expects the float16-optimized model to be
            # *larger* than the baseline; presumably conversion overhead
            # outweighs the weight savings for this small test model --
            # confirm this direction is intended.
            self.assertLess(len(compiled_1.tflite_model),
                            len(compiled_2.tflite_model))
    def test_compile_select_tf_op(self):
        """Compiling with the SELECT_TF_OPS op set produces a bytes model."""
        with TemporaryDirectory() as saved_dir:
            _save_saved_model_file(saved_dir)

            model_source = SavedModelFile(model_path=saved_dir)
            select_ops_config = Config(
                supported_ops=[tf.lite.OpsSet.SELECT_TF_OPS],
                input_formats=['channels_first'])
            result = compiler.compile_source(source=model_source,
                                             config=select_ops_config)

            self.assertIsInstance(result.tflite_model, bytes)
    def test_compile_simple(self):
        """A default compile with two input formats produces a bytes model."""
        with TemporaryDirectory() as saved_dir:
            _save_saved_model_file(saved_dir)

            model_source = SavedModelFile(model_path=saved_dir)
            two_format_config = Config(
                input_formats=['channels_first', 'channels_last'])
            result = compiler.compile_source(source=model_source,
                                             config=two_format_config)

            self.assertIsInstance(result.tflite_model, bytes)
    def test_from_json(self):
        """``Config.from_json`` builds the expected ``Config`` for each payload.

        Each case pairs the keyword arguments of the expected ``Config`` with
        the JSON-style dict passed to ``from_json``. Cases run in order, and
        within each case the expected ``Config`` is built before parsing.
        """

        def expect(**overrides):
            # Fresh empty lists for every case; overrides may replace them.
            return {'input_names': [], 'input_formats': [], **overrides}

        cases = [
            (expect(), {}),
            (expect(input_names=['x', 'y']), {'input_names': ['x', 'y']}),
            (expect(input_names=['x', 'y'], input_formats=['channels_first']),
             {'input_names': ['x', 'y'], 'input_formats': ['channels_first']}),
            (expect(input_names=['x', 'y'],
                    input_formats=['channels_first', 'channels_last']),
             {'input_names': ['x', 'y'],
              'input_formats': ['channels_first', 'channels_last']}),
            # optimization / supported_types
            (expect(optimization=False, supported_types=None), {}),
            (expect(optimization=False, supported_types=None),
             {'optimization': False}),
            (expect(optimization=True, supported_types=None),
             {'optimization': True}),
            (expect(optimization=False, supported_types=[tf.float16]),
             {'supported_types': ['float16']}),
            (expect(optimization=False,
                    supported_types=[tf.float16, tf.float32]),
             {'supported_types': ['float16', 'float32']}),
            (expect(optimization=True,
                    supported_types=[tf.float16, tf.float32]),
             {'optimization': True,
              'supported_types': ['float16', 'float32']}),
            # supported_ops
            (expect(supported_ops=None), {}),
            (expect(supported_ops=[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]),
             {'supported_ops': ['TFLITE_BUILTINS_INT8']}),
            (expect(supported_ops=[tf.lite.OpsSet.SELECT_TF_OPS,
                                   tf.lite.OpsSet.TFLITE_BUILTINS_INT8]),
             {'supported_ops': ['SELECT_TF_OPS', 'TFLITE_BUILTINS_INT8']}),
            # inference input / output types
            (expect(inference_input_type=None), {}),
            (expect(inference_input_type=tf.float32),
             {'inference_input_type': 'float32'}),
            (expect(inference_output_type=None), {}),
            (expect(inference_output_type=tf.float32),
             {'inference_output_type': 'float32'}),
        ]

        for expected_kwargs, payload in cases:
            self.assertEqual(Config(**expected_kwargs),
                             Config.from_json(payload))
    def test_invalid_data_type(self):
        """Unrecognized ``supported_types`` names make ``from_json`` raise."""
        # 'as_dtype' is presumably included to check that names are not
        # resolved as arbitrary tf attributes -- confirm against Config.
        for bad_type_name in ('foobar', 'as_dtype'):
            with self.assertRaises(ValueError):
                Config.from_json({'supported_types': [bad_type_name]})
    def test_from_env(self):
        """``Config.from_env`` builds the expected ``Config`` for each env map.

        Each case pairs the keyword arguments of the expected ``Config`` with
        the environment-style dict (comma-separated string values) passed to
        ``from_env``. Cases run in order, and within each case the expected
        ``Config`` is built before parsing.
        """

        def expect(**overrides):
            # Fresh empty lists for every case; overrides may replace them.
            return {'input_names': [], 'input_formats': [], **overrides}

        cases = [
            (expect(), {}),
            (expect(input_names=['x', 'y']), {'INPUT_NAMES': 'x,y'}),
            (expect(input_names=['x', 'y'], input_formats=['channels_first']),
             {'INPUT_NAMES': 'x,y', 'INPUT_FORMATS': 'channels_first'}),
            (expect(input_names=['x', 'y'],
                    input_formats=['channels_first', 'channels_last']),
             {'INPUT_NAMES': 'x,y',
              'INPUT_FORMATS': 'channels_first,channels_last'}),
            # optimization / supported_types
            (expect(optimization=False, supported_types=None), {}),
            (expect(optimization=False, supported_types=None),
             {'OPTIMIZATION': '0'}),
            (expect(optimization=True, supported_types=None),
             {'OPTIMIZATION': '1'}),
            (expect(optimization=False, supported_types=[tf.float16]),
             {'SUPPORTED_TYPES': 'float16'}),
            (expect(optimization=False,
                    supported_types=[tf.float16, tf.float32]),
             {'SUPPORTED_TYPES': 'float16,float32'}),
            (expect(optimization=True,
                    supported_types=[tf.float16, tf.float32]),
             {'OPTIMIZATION': '1', 'SUPPORTED_TYPES': 'float16,float32'}),
            # supported_ops
            (expect(supported_ops=None), {}),
            (expect(supported_ops=[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]),
             {'SUPPORTED_OPS': 'TFLITE_BUILTINS_INT8'}),
            (expect(supported_ops=[tf.lite.OpsSet.SELECT_TF_OPS,
                                   tf.lite.OpsSet.TFLITE_BUILTINS_INT8]),
             {'SUPPORTED_OPS': 'SELECT_TF_OPS,TFLITE_BUILTINS_INT8'}),
            # inference input / output types
            (expect(inference_input_type=None), {}),
            (expect(inference_input_type=tf.float32),
             {'INFERENCE_INPUT_TYPE': 'float32'}),
            (expect(inference_output_type=None), {}),
            (expect(inference_output_type=tf.float32),
             {'INFERENCE_OUTPUT_TYPE': 'float32'}),
        ]

        for expected_kwargs, env in cases:
            self.assertEqual(Config(**expected_kwargs),
                             Config.from_env(env))