def test_bert_model(self, version):
    """Run the BERT QA create/evaluate/export checks for one data version.

    Builds a `BertQAModelSpec` with `trainable=False` and
    `predict_batch_size=1`, loads data for `version` via `_get_data`,
    creates a model with a single epoch and batch size 1, and then runs
    the f1-score, vocab-export, TFLite-export and saved-model-export
    helper checks on it.

    Args:
        version: Forwarded unchanged to `_get_data` to select the data
            (presumably a SQuAD version string -- confirm with caller).
    """
    spec = ms.BertQAModelSpec(trainable=False, predict_batch_size=1)
    train_set, validation_set = _get_data(spec, version)

    qa_model = question_answer.create(
        train_set,
        model_spec=spec,
        epochs=1,
        batch_size=1,
    )

    # 0.0 is handed to the f1 helper as-is; presumably a minimum score
    # threshold -- confirm against `_test_f1_score`.
    self._test_f1_score(qa_model, validation_set, 0.0)
    self._test_export_vocab(qa_model)
    self._test_export_to_tflite(qa_model, validation_set)
    self._test_export_to_saved_model(qa_model)
def test_bert_model(self):
    """Run the BERT QA create/evaluate/export checks on version '1.1' data.

    Builds a `BertQAModelSpec` with `trainable=False` and
    `predict_batch_size=1`, creates a model with a single epoch and batch
    size 1, and then runs the f1-score, vocab-export, TFLite-export
    (passing `expected_json_file='bert_qa_metadata.json'`) and
    saved-model-export helper checks on it.
    """
    # Only SQuAD 1.1 is exercised here because this test takes too long
    # to run otherwise.
    squad_version = '1.1'

    spec = ms.BertQAModelSpec(trainable=False, predict_batch_size=1)
    train_set, validation_set = _get_data(spec, squad_version)

    qa_model = question_answer.create(
        train_set,
        model_spec=spec,
        epochs=1,
        batch_size=1,
    )

    # 0.0 is handed to the f1 helper as-is; presumably a minimum score
    # threshold -- confirm against `_test_f1_score`.
    self._test_f1_score(qa_model, validation_set, 0.0)
    self._test_export_vocab(qa_model)
    self._test_export_to_tflite(
        qa_model,
        validation_set,
        expected_json_file='bert_qa_metadata.json',
    )
    self._test_export_to_saved_model(qa_model)
def test_bert_model_v1_incompatible(self):
    """Check that constructing `BertQAModelSpec` raises a version error.

    Expects `ms.BertQAModelSpec(trainable=False)` to raise `ValueError`
    with a message matching 'Incompatible versions'.
    """
    expected_message = 'Incompatible versions'
    with self.assertRaisesRegex(ValueError, expected_message):
        # Construction alone must raise; the instance is never used.
        ms.BertQAModelSpec(trainable=False)