コード例 #1
0
 def test_custom_load_tf_weights(self):
     """Check the loading info reported by ``from_pretrained``.

     Loads ``"jplu/tiny-tf-bert-random"`` into
     ``TFBertForTokenClassification`` with ``output_loading_info=True``
     and verifies that:

     * the sorted unexpected keys are exactly ``mlm___cls`` and
       ``nsp___cls``;
     * the part of every missing key before its first underscore is
       either ``dropout`` or ``classifier``.
     """
     # Only the loading report is inspected; the returned model is unused.
     _, loading_info = TFBertForTokenClassification.from_pretrained(
         "jplu/tiny-tf-bert-random", output_loading_info=True
     )
     self.assertEqual(
         sorted(loading_info["unexpected_keys"]),
         ["mlm___cls", "nsp___cls"],
     )
     for missing_key in loading_info["missing_keys"]:
         # assertIn names the offending prefix on failure, whereas
         # assertTrue(x in y) only reports "False is not true".
         self.assertIn(missing_key.split("_")[0], ["dropout", "classifier"])
コード例 #2
0
 def create_and_check_bert_for_token_classification(
     self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
 ):
     """Run a token-classification model and check the shape of its logits.

     Copies ``self.num_labels`` onto ``config``, builds a
     ``TFBertForTokenClassification`` from it, calls the model with
     ``input_ids``, ``attention_mask`` (from ``input_mask``) and
     ``token_type_ids``, and asserts through ``self.parent`` that
     ``result.logits.shape`` equals
     ``(self.batch_size, self.seq_length, self.num_labels)``.

     ``sequence_labels``, ``token_labels`` and ``choice_labels`` are
     accepted but not used by this check.
     """
     config.num_labels = self.num_labels
     token_classifier = TFBertForTokenClassification(config=config)
     model_inputs = dict(
         input_ids=input_ids,
         attention_mask=input_mask,
         token_type_ids=token_type_ids,
     )
     result = token_classifier(model_inputs)
     actual_shape = result.logits.shape
     expected_shape = (self.batch_size, self.seq_length, self.num_labels)
     self.parent.assertEqual(actual_shape, expected_shape)
コード例 #3
0
 def create_and_check_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
     """Run a token-classification model and check the shape of its logits.

     Copies ``self.num_labels`` onto ``config``, builds a
     ``TFBertForTokenClassification`` from it, calls the model with
     ``input_ids``, ``attention_mask`` (from ``input_mask``) and
     ``token_type_ids``, and asserts through ``self.parent`` that the
     logits shape, as a list, equals
     ``[self.batch_size, self.seq_length, self.num_labels]``.

     ``sequence_labels``, ``token_labels`` and ``choice_labels`` are
     accepted but not used by this check.
     """
     config.num_labels = self.num_labels
     token_classifier = TFBertForTokenClassification(config=config)
     model_inputs = dict(
         input_ids=input_ids,
         attention_mask=input_mask,
         token_type_ids=token_type_ids,
     )
     # Single-element unpacking: raises unless the model output yields
     # exactly one item.
     [logits] = token_classifier(model_inputs)
     logits_shape = list(logits.numpy().shape)
     expected_shape = [self.batch_size, self.seq_length, self.num_labels]
     self.parent.assertListEqual(logits_shape, expected_shape)