def _test_export_to_tflite(self, model, validation_data, threshold=0.0, atol=1e-04):
    """Exports `model` to TFLite and sanity-checks the exported file.

    The helper performs, in order:
      1. Exports the model in `ExportFormat.TFLITE` into the test temp dir.
      2. Asserts that `model.tflite` exists there and is non-empty.
      3. Evaluates the exported file on `validation_data` and asserts that
         the reported `'final_f1'` is at least `threshold`.
      4. Feeds random integer inputs to both the exported file and
         `model.model` through `test_util.is_same_output` and asserts that
         the helper reports a match.

    Args:
      model: Object under test; must provide `export`, `evaluate_tflite`,
        `model_spec` and `model`.
      validation_data: Passed unchanged to `model.evaluate_tflite`.
      threshold: Lower bound for the `'final_f1'` metric.
      atol: Forwarded to `test_util.is_same_output` (presumably the absolute
        tolerance of the comparison -- confirm against that helper).
    """
    tflite_path = os.path.join(self.get_temp_dir(), 'model.tflite')
    model.export(self.get_temp_dir(), export_format=ExportFormat.TFLITE)

    # The export must have produced a non-empty file.
    self.assertTrue(os.path.isfile(tflite_path))
    self.assertGreater(os.path.getsize(tflite_path), 0)

    metrics = model.evaluate_tflite(tflite_path, validation_data)
    self.assertGreaterEqual(metrics['final_f1'], threshold)

    spec = model.model_spec
    vocab_size = len(spec.tokenizer.vocab)
    shape = (1, spec.seq_len)

    def random_ids(upper):
        # int32 array of `shape` with values drawn from [0, upper).
        return np.random.randint(low=0, high=upper, size=shape, dtype=np.int32)

    # Draw order is kept as (word ids, mask, type ids) so the sequence of
    # values pulled from NumPy's global RNG is unchanged.
    random_inputs = (
        random_ids(vocab_size),  # input_word_ids
        random_ids(2),           # input_mask
        random_ids(2),           # input_type_ids
    )

    outputs_match = test_util.is_same_output(
        tflite_path, model.model, random_inputs, model.model_spec, atol=atol)
    self.assertTrue(outputs_match)
def _test_export_to_tflite(self, model, threshold=1.0, atol=1e-04, expected_json_file=None):
    """Exports `model` to TFLite and verifies the exported artifacts.

    The helper performs, in order:
      1. Exports the model in `ExportFormat.TFLITE` into the test temp dir
         with `quantization_config=None`; the metadata JSON export is
         requested only when `expected_json_file` is given.
      2. Asserts that `model.tflite` exists there and is non-empty.
      3. Evaluates the exported file on `self.test_data` and asserts that
         the reported `'accuracy'` is at least `threshold`.
      4. Builds random integer inputs matching the model spec type and
         asserts that `test_util.is_same_output` reports a match between
         the exported file and `model.model`.
      5. If `expected_json_file` is given, asserts that `model.json` exists,
         is non-empty, and compares equal (via `filecmp.cmp`) to the
         expected file resolved by `test_util.get_test_data_path`.

    Args:
      model: Object under test; must provide `export`, `evaluate_tflite`,
        `model_spec` and `model`.
      threshold: Lower bound for the `'accuracy'` metric.
      atol: Forwarded to `test_util.is_same_output` (presumably the absolute
        tolerance of the comparison -- confirm against that helper).
      expected_json_file: Optional name of the expected metadata JSON file;
        `None` skips the metadata checks.

    Raises:
      ValueError: If `model.model_spec` is neither an
        `AverageWordVecModelSpec` nor a `BertClassifierModelSpec`.
    """
    check_metadata_json = expected_json_file is not None

    tflite_path = os.path.join(self.get_temp_dir(), 'model.tflite')
    model.export(
        self.get_temp_dir(),
        export_format=ExportFormat.TFLITE,
        quantization_config=None,
        export_metadata_json_file=check_metadata_json)

    # The export must have produced a non-empty file.
    self.assertTrue(tf.io.gfile.exists(tflite_path))
    self.assertGreater(os.path.getsize(tflite_path), 0)

    eval_result = model.evaluate_tflite(tflite_path, self.test_data)
    self.assertGreaterEqual(eval_result['accuracy'], threshold)

    spec = model.model_spec
    is_word_vec = isinstance(spec, text_spec.AverageWordVecModelSpec)
    # Reject unknown spec types before touching any spec attribute.
    if not is_word_vec and not isinstance(spec, text_spec.BertClassifierModelSpec):
        raise ValueError('Unsupported model_spec type: %s' % str(type(spec)))

    def random_ids(upper):
        # int32 array of shape (1, seq_len) with values drawn from [0, upper).
        return np.random.randint(
            low=0, high=upper, size=(1, spec.seq_len), dtype=np.int32)

    if is_word_vec:
        random_inputs = random_ids(len(spec.vocab))
    else:
        # Draw order is kept as (word ids, mask, type ids).
        random_inputs = (
            random_ids(len(spec.tokenizer.vocab)),  # input_word_ids
            random_ids(2),                          # input_mask
            random_ids(2),                          # input_type_ids
        )

    outputs_match = test_util.is_same_output(
        tflite_path, model.model, random_inputs, spec, atol=atol)
    self.assertTrue(outputs_match)

    if not check_metadata_json:
        return

    json_path = os.path.join(self.get_temp_dir(), 'model.json')
    self.assertTrue(os.path.isfile(json_path))
    self.assertGreater(os.path.getsize(json_path), 0)

    expected_path = test_util.get_test_data_path(expected_json_file)
    self.assertTrue(filecmp.cmp(json_path, expected_path))
def _test_export_to_tflite(self, model, threshold=0.0):
    """Exports `model` to TFLite and sanity-checks the exported file.

    The helper performs, in order:
      1. Exports the model in `ExportFormat.TFLITE` into the test temp dir.
      2. Asserts that `model.tflite` exists there and is non-empty, so that
         a failed export is reported directly instead of as an error raised
         from inside `evaluate_tflite`.
      3. Evaluates the exported file on `self.test_data` and asserts that
         the reported `'accuracy'` is at least `threshold`.
      4. Feeds one random float32 input to both the exported file and
         `model.model` through `test_util.is_same_output` and asserts that
         the helper reports a match.

    Args:
      model: Object under test; must provide `export`, `evaluate_tflite`,
        `model_spec` and `model`.
      threshold: Lower bound for the `'accuracy'` metric.
    """
    tflite_output_file = os.path.join(self.get_temp_dir(), 'model.tflite')
    model.export(self.get_temp_dir(), export_format=ExportFormat.TFLITE)

    # Fail fast with a clear message if the export produced nothing.
    self.assertTrue(os.path.isfile(tflite_output_file))
    self.assertGreater(os.path.getsize(tflite_output_file), 0)

    result = model.evaluate_tflite(tflite_output_file, self.test_data)
    self.assertGreaterEqual(result['accuracy'], threshold)

    spec = model.model_spec
    # `list(...)` lets `input_image_shape` be a list or a tuple; the bare
    # `[1] + shape + [3]` form raises TypeError for tuples.
    # Presumably the shape is [height, width] with 1 = batch and 3 = colour
    # channels -- confirm against the spec.
    input_shape = [1] + list(spec.input_image_shape) + [3]
    random_input = np.random.uniform(size=input_shape).astype(np.float32)

    self.assertTrue(
        test_util.is_same_output(
            tflite_output_file, model.model, random_input, spec))
def _test_tflite(self, keras_model, tflite_model_file, input_dim, max_input_value=1000, atol=1e-04):
    """Asserts that a TFLite file and a Keras model agree on one random input.

    NumPy's global RNG is seeded with 0 so the same input is generated on
    every call. A float32 array of shape `(1, input_dim)` with values drawn
    uniformly from `[0, max_input_value)` is then passed, together with both
    models, to `test_util.is_same_output`.

    Args:
      keras_model: Model passed to `test_util.is_same_output` as the
        reference.
      tflite_model_file: TFLite model file passed to
        `test_util.is_same_output`.
      input_dim: Size of the second axis of the random input.
      max_input_value: Exclusive upper bound of the random values.
      atol: Forwarded to `test_util.is_same_output` (presumably the absolute
        tolerance of the comparison -- confirm against that helper).
    """
    # Seeding the global RNG is deliberate: it makes the input reproducible.
    np.random.seed(0)

    sample = np.random.uniform(
        low=0, high=max_input_value, size=(1, input_dim))
    sample = sample.astype(np.float32)

    outputs_match = test_util.is_same_output(
        tflite_model_file, keras_model, sample, atol=atol)
    self.assertTrue(outputs_match)
def _test_export_to_tflite(self, model, threshold=1.0):
    """Exports `model` to TFLite and sanity-checks the exported file.

    The helper performs, in order:
      1. Exports the model in `ExportFormat.TFLITE` into the test temp dir.
      2. Asserts that `model.tflite` exists there and is non-empty.
      3. Evaluates the exported file on `self.test_data` and asserts that
         the reported `'accuracy'` is at least `threshold`.
      4. Builds random integer inputs matching the model spec type and
         asserts that `test_util.is_same_output` reports a match between
         the exported file and `model.model`.

    Args:
      model: Object under test; must provide `export`, `evaluate_tflite`,
        `model_spec` and `model`.
      threshold: Lower bound for the `'accuracy'` metric.

    Raises:
      ValueError: If `model.model_spec` is neither an
        `AverageWordVecModelSpec` nor a `BertClassifierModelSpec`.
    """
    tflite_path = os.path.join(self.get_temp_dir(), 'model.tflite')
    model.export(self.get_temp_dir(), export_format=ExportFormat.TFLITE)

    # The export must have produced a non-empty file.
    self.assertTrue(tf.io.gfile.exists(tflite_path))
    self.assertGreater(os.path.getsize(tflite_path), 0)

    eval_result = model.evaluate_tflite(tflite_path, self.test_data)
    self.assertGreaterEqual(eval_result['accuracy'], threshold)

    spec = model.model_spec
    is_word_vec = isinstance(spec, ms.AverageWordVecModelSpec)
    # Reject unknown spec types before touching any spec attribute.
    if not is_word_vec and not isinstance(spec, ms.BertClassifierModelSpec):
        raise ValueError('Unsupported model_spec type: %s' % str(type(spec)))

    def random_ids(upper):
        # int32 array of shape (1, seq_len) with values drawn from [0, upper).
        return np.random.randint(
            low=0, high=upper, size=(1, spec.seq_len), dtype=np.int32)

    if is_word_vec:
        random_inputs = random_ids(len(spec.vocab))
    else:
        # Draw order is kept as (word ids, mask, type ids).
        random_inputs = (
            random_ids(len(spec.tokenizer.vocab)),  # input_word_ids
            random_ids(2),                          # input_mask
            random_ids(2),                          # input_type_ids
        )

    outputs_match = test_util.is_same_output(
        tflite_path, model.model, random_inputs, spec)
    self.assertTrue(outputs_match)