def test_from_json(self):
    """Verify that Config.from_json maps every JSON key onto its Config field.

    Each case pairs keyword overrides for the expected Config with the
    JSON-style dict handed to Config.from_json. The expected Config always
    starts from ``input_names=[]`` and ``input_formats=[]``, and the overrides
    are layered on top. Several cases pass ``{}`` to pin the values that
    from_json must produce when a key is absent (e.g. ``optimization=False``,
    ``supported_types=None``, ``supported_ops=None``).
    """
    cases = [
        # input_names / input_formats
        ({}, {}),
        ({'input_names': ['x', 'y']}, {'input_names': ['x', 'y']}),
        ({'input_names': ['x', 'y'], 'input_formats': ['channels_first']},
         {'input_names': ['x', 'y'], 'input_formats': ['channels_first']}),
        ({'input_names': ['x', 'y'],
          'input_formats': ['channels_first', 'channels_last']},
         {'input_names': ['x', 'y'],
          'input_formats': ['channels_first', 'channels_last']}),
        # optimization / supported_types
        ({'optimization': False, 'supported_types': None}, {}),
        ({'optimization': False, 'supported_types': None},
         {'optimization': False}),
        ({'optimization': True, 'supported_types': None},
         {'optimization': True}),
        ({'optimization': False, 'supported_types': [tf.float16]},
         {'supported_types': ['float16']}),
        ({'optimization': False, 'supported_types': [tf.float16, tf.float32]},
         {'supported_types': ['float16', 'float32']}),
        ({'optimization': True, 'supported_types': [tf.float16, tf.float32]},
         {'optimization': True, 'supported_types': ['float16', 'float32']}),
        # supported_ops: op-set names are resolved to tf.lite.OpsSet members
        ({'supported_ops': None}, {}),
        ({'supported_ops': [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]},
         {'supported_ops': ['TFLITE_BUILTINS_INT8']}),
        ({'supported_ops': [tf.lite.OpsSet.SELECT_TF_OPS,
                            tf.lite.OpsSet.TFLITE_BUILTINS_INT8]},
         {'supported_ops': ['SELECT_TF_OPS', 'TFLITE_BUILTINS_INT8']}),
        # inference_input_type / inference_output_type: a single dtype name each
        ({'inference_input_type': None}, {}),
        ({'inference_input_type': tf.float32},
         {'inference_input_type': 'float32'}),
        ({'inference_output_type': None}, {}),
        ({'inference_output_type': tf.float32},
         {'inference_output_type': 'float32'}),
    ]
    for overrides, payload in cases:
        # Build fresh base lists per case so no list object is shared between
        # cases.
        expected_kwargs = {'input_names': [], 'input_formats': []}
        expected_kwargs.update(overrides)
        self.assertEqual(Config(**expected_kwargs), Config.from_json(payload))
def test_invalid_data_type(self):
    """Reject names in 'supported_types' that are not valid dtype names.

    'foobar' is an arbitrary unknown name. 'as_dtype' appears to be a name
    that exists on the tf module without being a dtype. Presumably this case
    guards against resolving dtype names by plain attribute lookup on tf;
    TODO confirm against Config.from_json.
    """
    for bad_name in ('foobar', 'as_dtype'):
        with self.assertRaises(ValueError):
            Config.from_json({'supported_types': [bad_name]})