Example #1
0
 def test_TFBertForSequenceClassification(self):
     """Check ONNX-converted TFBertForSequenceClassification matches Keras output.

     Builds a randomly-initialized model from a default BertConfig, runs a
     Keras prediction as the reference, converts the model with keras2onnx,
     and asserts that ONNX Runtime reproduces the reference prediction.
     """
     from transformers import BertConfig, TFBertForSequenceClassification
     keras.backend.clear_session()
     tokenizer_file = 'bert_bert-base-uncased.pickle'
     tok = self._get_tokenzier(tokenizer_file)
     _, keras_inputs, ort_inputs = self._prepare_inputs(tok)
     cfg = BertConfig()
     bert_model = TFBertForSequenceClassification(cfg)
     reference = bert_model.predict(keras_inputs)
     converted = keras2onnx.convert_keras(bert_model, bert_model.name)
     self.assertTrue(run_onnx_runtime(converted.graph.name, converted, ort_inputs, reference, self.model_files))
Example #2
0
# Build a 3-class BERT sequence classifier with randomly initialized weights.
# NOTE(review): model_type is normally a fixed class attribute on BertConfig,
# not a constructor knob — presumably intended to record the base architecture;
# verify it actually has an effect.
config = BertConfig(num_labels=3, return_dict=True, model_type='bert-base-uncased')

model = TFBertForSequenceClassification(config=config)

# save_model (defined elsewhere in the file): True -> train from scratch and
# checkpoint via cp_callback; False -> restore the latest saved checkpoint.
if save_model:
    optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
    # model.compute_loss is the loss provided by transformers for this head.
    model.compile(optimizer=optimizer, loss=model.compute_loss, metrics=['accuracy'])

    model.fit(
        train_dataset[0],  # assumes element 0 is the tokenized input dict — TODO confirm against how train_dataset is built
        np.array(y_list),  # integer class labels, aligned with train_dataset[0]
        epochs=5,
        batch_size=BATCH_SIZE,
        callbacks=[cp_callback]  # checkpointing callback defined elsewhere
        )
else:
    # Skip training: load weights from the most recent checkpoint on disk.
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    model.load_weights(latest)

# Raw classification logits for the validation inputs.
preds = model.predict(val_dataset[0])["logits"]

# Class probabilities. NOTE(review): computed but never used below — argmax is
# taken over the raw logits, which yields the same classes since softmax is
# monotonic per row.
preds_proba = tf.keras.backend.softmax(preds, axis=1)

# Predicted class index per validation example.
classes = np.argmax(preds, axis=-1)

# Per-class precision/recall/F1 against the held-out labels (y_val defined elsewhere).
score = classification_report(y_val, classes, digits=3)
print(score)

# Wall-clock runtime; `start` is captured earlier in the file.
total = time.time()  - start
print(f"Done in: {total}")