def testQuantizationInvalid(self):
  """Checks toco_convert rejects QUANTIZED_UINT8 conversion without stats.

  Builds a float graph whose output passes through a fake-quant op. It then
  calls toco_convert with inference_type=QUANTIZED_UINT8 but without
  quantized_input_stats, twice:
    1. with the default inference_input_type;
    2. with inference_input_type=FLOAT.
  Both calls are expected to raise ValueError with the same message.
  """
  # The two failure cases below expect an identical error message.
  expected_message = (
      "std_dev and mean must be defined when inference_type or "
      "inference_input_type is QUANTIZED_UINT8 or INT8.")
  with ops.Graph().as_default():
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    out_tensor = array_ops.fake_quant_with_min_max_args(
        in_tensor + in_tensor, min=0., max=1.)
    # Use the session as a context manager so it is closed when the test
    # finishes instead of being left alive until garbage collection.
    with session.Session() as sess:
      with self.assertRaises(ValueError) as error:
        convert.toco_convert(
            sess.graph_def, [in_tensor], [out_tensor],
            inference_type=lite_constants.QUANTIZED_UINT8)
      self.assertEqual(expected_message, str(error.exception))

      # Explicitly requesting float input still requires the stats.
      with self.assertRaises(ValueError) as error:
        convert.toco_convert(
            sess.graph_def, [in_tensor], [out_tensor],
            inference_type=lite_constants.QUANTIZED_UINT8,
            inference_input_type=lite_constants.FLOAT)
      self.assertEqual(expected_message, str(error.exception))
def testQuantizationInvalid(self):
  """Checks toco_convert rejects QUANTIZED_UINT8 conversion without stats.

  Builds a float graph whose output passes through a fake-quant op. It then
  calls toco_convert with inference_type=QUANTIZED_UINT8 but without
  quantized_input_stats. The call is expected to raise ValueError with a
  specific message.
  """
  in_tensor = array_ops.placeholder(
      shape=[1, 16, 16, 3], dtype=dtypes.float32)
  out_tensor = array_ops.fake_quant_with_min_max_args(
      in_tensor + in_tensor, min=0., max=1.)
  # Use the session as a context manager so it is closed when the test
  # finishes instead of being left alive until garbage collection.
  with session.Session() as sess:
    with self.assertRaises(ValueError) as error:
      convert.toco_convert(
          sess.graph_def, [in_tensor], [out_tensor],
          inference_type=lite_constants.QUANTIZED_UINT8)
    self.assertEqual(
        "std_dev and mean must be defined when inference_input_type is "
        "QUANTIZED_UINT8.", str(error.exception))
def testBasic(self):
  """Converts a simple float graph (in_tensor + in_tensor) with defaults.

  Only checks that toco_convert returns a truthy result. The result is
  presumably the serialized TFLite model bytes; confirm against convert.
  """
  in_tensor = array_ops.placeholder(
      shape=[1, 16, 16, 3], dtype=dtypes.float32)
  out_tensor = in_tensor + in_tensor
  # Use the session as a context manager so it is closed when the test
  # finishes instead of being left alive until garbage collection.
  with session.Session() as sess:
    # Try running on valid graph
    tflite_model = convert.toco_convert(
        sess.graph_def, [in_tensor], [out_tensor])
  self.assertTrue(tflite_model)
def testQuantization(self):
  """Converts a fake-quantized graph to QUANTIZED_UINT8 with input stats.

  The graph is the same as in testQuantizationInvalid. Here,
  quantized_input_stats is supplied (one tuple for the single input), so
  conversion is expected to succeed and return a truthy result.
  """
  in_tensor = array_ops.placeholder(
      shape=[1, 16, 16, 3], dtype=dtypes.float32)
  out_tensor = array_ops.fake_quant_with_min_max_args(
      in_tensor + in_tensor, min=0., max=1.)
  # Use the session as a context manager so it is closed when the test
  # finishes instead of being left alive until garbage collection.
  with session.Session() as sess:
    # NOTE(review): the tuple is presumably (mean, std_dev) for in_tensor;
    # confirm against toco_convert's quantized_input_stats contract.
    tflite_model = convert.toco_convert(
        sess.graph_def, [in_tensor], [out_tensor],
        inference_type=lite_constants.QUANTIZED_UINT8,
        quantized_input_stats=[(0., 1.)])
  self.assertTrue(tflite_model)