Esempio n. 1
0
 def test_bert_model(self, version):
     """Train a frozen BERT QA model briefly, then check evaluation and exports.

     Args:
       version: Dataset version string forwarded to `_get_data`.
     """
     spec = ms.BertQAModelSpec(trainable=False, predict_batch_size=1)
     train_data, validation_data = _get_data(spec, version)
     # Minimal training settings: a single epoch with batch size 1.
     model = question_answer.create(
         train_data, model_spec=spec, epochs=1, batch_size=1)
     # NOTE(review): 0.0 is presumably a lower bound on F1, so this mostly
     # checks that evaluation runs; confirm against `_test_f1_score`.
     self._test_f1_score(model, validation_data, 0.0)
     self._test_export_vocab(model)
     self._test_export_to_tflite(model, validation_data)
     self._test_export_to_saved_model(model)
 def test_bert_model(self):
   """Train a frozen BERT QA model on SQuAD 1.1 and check evaluation and exports."""
   # Only SQuAD 1.1 is exercised because running other versions takes too long.
   spec = ms.BertQAModelSpec(trainable=False, predict_batch_size=1)
   train_data, validation_data = _get_data(spec, '1.1')
   model = question_answer.create(train_data,
                                  model_spec=spec,
                                  epochs=1,
                                  batch_size=1)
   self._test_f1_score(model, validation_data, 0.0)
   self._test_export_vocab(model)
   # The TFLite export is checked against a golden metadata JSON file.
   self._test_export_to_tflite(model,
                               validation_data,
                               expected_json_file='bert_qa_metadata.json')
   self._test_export_to_saved_model(model)
Esempio n. 3
0
 def test_bert_model_v1_incompatible(self):
     """Check that building a BertQAModelSpec raises 'Incompatible versions'."""
     # NOTE(review): the name suggests this runs only in a TF1 setting; any
     # guard/decorator enforcing that is not visible here — confirm.
     # Callable form: assertRaisesRegex invokes the constructor itself.
     self.assertRaisesRegex(ValueError, 'Incompatible versions',
                            ms.BertQAModelSpec, trainable=False)