def test_bert(self, uri, is_tf2):
    """Runs the shared BERT checks against a spec built from `uri`.

    Constructs a `BertClassifierModelSpec` for the given model `uri` with
    `distribution_strategy='off'` and `seq_len=3`, then hands that one spec
    first to the example-to-feature conversion check and then to the
    classifier-run check.

    Args:
      uri: Model location, forwarded positionally to the spec constructor.
      is_tf2: Forwarded to the spec constructor as the `is_tf2` keyword.
    """
    spec_kwargs = dict(is_tf2=is_tf2, distribution_strategy='off', seq_len=3)
    bert_spec = ms.BertClassifierModelSpec(uri, **spec_kwargs)

    # Both checks share the same spec instance, in this order.
    self._test_convert_examples_to_features(bert_spec)
    self._test_run_classifier(bert_spec)
    def test_bert_model(self):
        """Trains a non-trainable BERT classifier on `self.text_dir`.

        Loads the text folder with a `seq_len=2`, `trainable=False` spec and
        stores a 90%/10% train/test split on the instance as `train_data` /
        `test_data`. It then trains for one epoch (batch size 1, shuffled)
        and calls `self._test_accuracy` with 0.5. That value is presumably
        the minimum acceptable accuracy; confirm against the helper.
        """
        bert_spec = ms.BertClassifierModelSpec(seq_len=2, trainable=False)
        loader = text_dataloader.TextClassifierDataLoader
        dataset = loader.from_folder(self.text_dir, model_spec=bert_spec)

        # 90% of the examples go to training, the remaining 10% to testing.
        train_split, test_split = dataset.split(0.9)
        self.train_data, self.test_data = train_split, test_split

        trained = text_classifier.create(
            self.train_data,
            model_spec=bert_spec,
            epochs=1,
            batch_size=1,
            shuffle=True,
        )
        self._test_accuracy(trained, 0.5)
示例#3
0
    def test_bert_model(self):
        """Trains a non-trainable BERT classifier on the tiny text fixture.

        Loads `self.tiny_text_dir` with a `seq_len=2`, `trainable=False`
        spec and stores a 50%/50% train/test split on the instance. It
        trains for one epoch (batch size 1, shuffled) and then runs three
        checks:
          * the accuracy check,
          * the TFLite export check against `bert_classifier_metadata.json`,
          * the untrained-model check on the same spec.
        """
        bert_spec = ms.BertClassifierModelSpec(seq_len=2, trainable=False)
        dataset = text_dataloader.TextClassifierDataLoader.from_folder(
            self.tiny_text_dir, model_spec=bert_spec)

        # Half of the examples train the model; the other half evaluate it.
        self.train_data, self.test_data = dataset.split(0.5)

        classifier = text_classifier.create(
            self.train_data, model_spec=bert_spec, epochs=1, batch_size=1,
            shuffle=True)

        # NOTE(review): both thresholds are 0.0, which reads like a smoke test
        # on the tiny fixture rather than a quality bar — confirm intent.
        self._test_accuracy(classifier, 0.0)
        self._test_export_to_tflite(
            classifier, threshold=0.0,
            expected_json_file='bert_classifier_metadata.json')
        self._test_model_without_training(bert_spec)