示例#1
0
    def _test_export_to_tflite(self,
                               model,
                               validation_data,
                               threshold=0.0,
                               atol=1e-04):
        """Exports `model` to TFLite and sanity-checks the exported file.

        The export must leave a non-empty `model.tflite` in the test temp
        directory, `model.evaluate_tflite` must report a `final_f1` of at
        least `threshold`, and `test_util.is_same_output` must accept the
        exported file against `model.model` on randomly drawn inputs.

        Args:
            model: Object under test; must provide `export`,
                `evaluate_tflite`, `model_spec` and `model`.
            validation_data: Data handed to `model.evaluate_tflite`.
            threshold: Lower bound for the `final_f1` metric.
            atol: Forwarded to `test_util.is_same_output` (presumably an
                absolute tolerance -- confirm against that helper).
        """
        tflite_path = os.path.join(self.get_temp_dir(), 'model.tflite')
        model.export(self.get_temp_dir(), export_format=ExportFormat.TFLITE)

        # The export has to leave a real, non-empty file behind.
        self.assertTrue(os.path.isfile(tflite_path))
        self.assertGreater(os.path.getsize(tflite_path), 0)

        metrics = model.evaluate_tflite(tflite_path, validation_data)
        self.assertGreaterEqual(metrics['final_f1'], threshold)

        spec = model.model_spec
        # Exclusive upper bounds for (word ids, mask, type ids); the inputs
        # are drawn in exactly this order.
        upper_bounds = (len(spec.tokenizer.vocab), 2, 2)
        random_inputs = tuple(
            np.random.randint(low=0,
                              high=bound,
                              size=(1, spec.seq_len),
                              dtype=np.int32)
            for bound in upper_bounds)

        outputs_match = test_util.is_same_output(tflite_path,
                                                 model.model,
                                                 random_inputs,
                                                 model.model_spec,
                                                 atol=atol)
        self.assertTrue(outputs_match)
    def _test_export_to_tflite(self,
                               model,
                               threshold=1.0,
                               atol=1e-04,
                               expected_json_file=None):
        """Exports `model` to TFLite and verifies the exported artifacts.

        Checks, in order: the exported `model.tflite` exists and is
        non-empty; `model.evaluate_tflite` on `self.test_data` reports an
        `accuracy` of at least `threshold`; `test_util.is_same_output`
        accepts the exported file against `model.model` on random inputs;
        and, when `expected_json_file` is given, the exported `model.json`
        is non-empty and compares equal to that expected file.

        Args:
            model: Object under test; must provide `export`,
                `evaluate_tflite`, `model_spec` and `model`.
            threshold: Lower bound for the `accuracy` metric.
            atol: Forwarded to `test_util.is_same_output` (presumably an
                absolute tolerance -- confirm against that helper).
            expected_json_file: Optional name resolved through
                `test_util.get_test_data_path`. When not None, the export
                is asked to write a metadata JSON file, which is then
                compared against the resolved file.

        Raises:
            ValueError: If `model.model_spec` is neither an
                `AverageWordVecModelSpec` nor a `BertClassifierModelSpec`.
        """
        want_metadata_json = expected_json_file is not None

        tflite_path = os.path.join(self.get_temp_dir(), 'model.tflite')
        model.export(self.get_temp_dir(),
                     export_format=ExportFormat.TFLITE,
                     quantization_config=None,
                     export_metadata_json_file=want_metadata_json)

        self.assertTrue(tf.io.gfile.exists(tflite_path))
        self.assertGreater(os.path.getsize(tflite_path), 0)

        metrics = model.evaluate_tflite(tflite_path, self.test_data)
        self.assertGreaterEqual(metrics['accuracy'], threshold)

        spec = model.model_spec
        supported_specs = (text_spec.AverageWordVecModelSpec,
                           text_spec.BertClassifierModelSpec)
        if not isinstance(spec, supported_specs):
            raise ValueError('Unsupported model_spec type: %s' %
                             str(type(spec)))

        def draw_ids(upper_bound):
            # One int32 row of spec.seq_len values in [0, upper_bound).
            return np.random.randint(low=0,
                                     high=upper_bound,
                                     size=(1, spec.seq_len),
                                     dtype=np.int32)

        if isinstance(spec, text_spec.AverageWordVecModelSpec):
            # A single array of ids is used for this spec.
            random_inputs = draw_ids(len(spec.vocab))
        else:
            # Drawn in order: word ids, input mask, type ids.
            random_inputs = (draw_ids(len(spec.tokenizer.vocab)),
                             draw_ids(2),
                             draw_ids(2))

        outputs_match = test_util.is_same_output(tflite_path,
                                                 model.model,
                                                 random_inputs,
                                                 spec,
                                                 atol=atol)
        self.assertTrue(outputs_match)

        if not want_metadata_json:
            return

        # Metadata JSON was requested: it must exist, be non-empty and
        # match the expected file.
        json_path = os.path.join(self.get_temp_dir(), 'model.json')
        self.assertTrue(os.path.isfile(json_path))
        self.assertGreater(os.path.getsize(json_path), 0)
        golden_path = test_util.get_test_data_path(expected_json_file)
        self.assertTrue(filecmp.cmp(json_path, golden_path))
示例#3
0
  def _test_export_to_tflite(self, model, threshold=0.0):
    """Exports `model` to TFLite, then checks accuracy and output parity.

    `model.evaluate_tflite` on `self.test_data` must report an `accuracy` of
    at least `threshold`, and `test_util.is_same_output` must accept the
    exported file against `model.model` on one random float32 input.

    Args:
      model: Object under test; must provide `export`, `evaluate_tflite`,
        `model_spec` and `model`.
      threshold: Lower bound for the `accuracy` metric.
    """
    tflite_path = os.path.join(self.get_temp_dir(), 'model.tflite')
    model.export(self.get_temp_dir(), export_format=ExportFormat.TFLITE)

    metrics = model.evaluate_tflite(tflite_path, self.test_data)
    self.assertGreaterEqual(metrics['accuracy'], threshold)

    # The spec's image shape is wrapped with a leading 1 and a trailing 3
    # (presumably batch size and colour channels -- confirm).
    input_shape = [1] + model.model_spec.input_image_shape + [3]
    random_input = np.random.uniform(size=input_shape).astype(np.float32)

    outputs_match = test_util.is_same_output(
        tflite_path, model.model, random_input, model.model_spec)
    self.assertTrue(outputs_match)
示例#4
0
  def _test_tflite(self,
                   keras_model,
                   tflite_model_file,
                   input_dim,
                   max_input_value=1000,
                   atol=1e-04):
    """Asserts `test_util.is_same_output` accepts the TFLite file and model.

    A single float32 input of shape `(1, input_dim)` is drawn uniformly from
    `[0, max_input_value)` after reseeding NumPy's global generator with 0.

    Args:
      keras_model: Model handed to `test_util.is_same_output`.
      tflite_model_file: TFLite model file handed to the same helper.
      input_dim: Number of values in the random input row.
      max_input_value: Upper bound of the uniform distribution.
      atol: Forwarded to `test_util.is_same_output` (presumably an absolute
        tolerance -- confirm against that helper).
    """
    # Reseed so the same input is drawn on every run.
    np.random.seed(0)
    sample = np.random.uniform(
        low=0, high=max_input_value, size=(1, input_dim))
    random_input = sample.astype(np.float32)

    outputs_match = test_util.is_same_output(
        tflite_model_file, keras_model, random_input, atol=atol)
    self.assertTrue(outputs_match)
示例#5
0
    def _test_export_to_tflite(self, model, threshold=1.0):
        """Exports `model` to TFLite and verifies the exported file.

        Checks, in order: the exported `model.tflite` exists and is
        non-empty; `model.evaluate_tflite` on `self.test_data` reports an
        `accuracy` of at least `threshold`; and `test_util.is_same_output`
        accepts the exported file against `model.model` on random inputs.

        Args:
            model: Object under test; must provide `export`,
                `evaluate_tflite`, `model_spec` and `model`.
            threshold: Lower bound for the `accuracy` metric.

        Raises:
            ValueError: If `model.model_spec` is neither an
                `AverageWordVecModelSpec` nor a `BertClassifierModelSpec`.
        """
        tflite_path = os.path.join(self.get_temp_dir(), 'model.tflite')
        model.export(self.get_temp_dir(), export_format=ExportFormat.TFLITE)

        self.assertTrue(tf.io.gfile.exists(tflite_path))
        self.assertGreater(os.path.getsize(tflite_path), 0)

        metrics = model.evaluate_tflite(tflite_path, self.test_data)
        self.assertGreaterEqual(metrics['accuracy'], threshold)

        spec = model.model_spec
        supported_specs = (ms.AverageWordVecModelSpec,
                           ms.BertClassifierModelSpec)
        if not isinstance(spec, supported_specs):
            raise ValueError('Unsupported model_spec type: %s' %
                             str(type(spec)))

        def draw_ids(upper_bound):
            # One int32 row of spec.seq_len values in [0, upper_bound).
            return np.random.randint(low=0,
                                     high=upper_bound,
                                     size=(1, spec.seq_len),
                                     dtype=np.int32)

        if isinstance(spec, ms.AverageWordVecModelSpec):
            # A single array of ids is used for this spec.
            random_inputs = draw_ids(len(spec.vocab))
        else:
            # Drawn in order: word ids, input mask, type ids.
            random_inputs = (draw_ids(len(spec.tokenizer.vocab)),
                             draw_ids(2),
                             draw_ids(2))

        outputs_match = test_util.is_same_output(tflite_path, model.model,
                                                 random_inputs, spec)
        self.assertTrue(outputs_match)