def test_custom_load_tf_weights(self):
    """Check which weights are skipped/missing when loading a checkpoint into TFBertForTokenClassification.

    Loads the "jplu/tiny-tf-bert-random" checkpoint with ``output_loading_info=True`` and asserts:

    * the checkpoint weights not used by the model are exactly ``mlm___cls`` and ``nsp___cls``;
    * every model weight not found in the checkpoint belongs to a layer whose name prefix
      (text before the first ``"_"``) is ``dropout`` or ``classifier``.
    """
    model, output_loading_info = TFBertForTokenClassification.from_pretrained(
        "jplu/tiny-tf-bert-random", output_loading_info=True
    )
    # Presumably these are pretraining heads stored in the checkpoint that a
    # token-classification model has no use for — TODO confirm against the checkpoint.
    self.assertEqual(sorted(output_loading_info["unexpected_keys"]), ["mlm___cls", "nsp___cls"])
    for layer in output_loading_info["missing_keys"]:
        # assertIn (rather than assertTrue(x in ...)) reports the offending prefix on failure.
        self.assertIn(layer.split("_")[0], ["dropout", "classifier"])
def create_and_check_bert_for_token_classification(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Build a TFBertForTokenClassification and check the shape of its logits.

    Sets ``config.num_labels`` to ``self.num_labels`` (mutating the passed config), runs the model
    on a feature dict built from ``input_ids``, ``input_mask`` and ``token_type_ids``, and asserts
    that ``logits`` has shape ``(batch_size, seq_length, num_labels)`` via ``self.parent``.
    ``sequence_labels``, ``token_labels`` and ``choice_labels`` are accepted but not used here.
    """
    config.num_labels = self.num_labels
    token_classifier = TFBertForTokenClassification(config=config)

    feed = dict(
        input_ids=input_ids,
        attention_mask=input_mask,
        token_type_ids=token_type_ids,
    )
    output = token_classifier(feed)

    expected_shape = (self.batch_size, self.seq_length, self.num_labels)
    self.parent.assertEqual(output.logits.shape, expected_shape)
def create_and_check_bert_for_token_classification(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Build a TFBertForTokenClassification and check the shape of its logits.

    Sets ``config.num_labels`` to ``self.num_labels`` (mutating the passed config), runs the model
    on a feature dict built from ``input_ids``, ``input_mask`` and ``token_type_ids``, converts the
    logits to a NumPy array and asserts its shape is ``[batch_size, seq_length, num_labels]`` via
    ``self.parent``. ``sequence_labels``, ``token_labels`` and ``choice_labels`` are unused here.

    NOTE(review): a method with this exact name is defined earlier in the file; if both live in the
    same class, this later definition is the one Python keeps.
    """
    config.num_labels = self.num_labels
    model = TFBertForTokenClassification(config=config)
    inputs = {
        "input_ids": input_ids,
        "attention_mask": input_mask,
        "token_type_ids": token_type_ids,
    }
    outputs = model(inputs)
    # Index positionally instead of tuple-unpacking (``logits, = outputs``): unpacking
    # only works for a plain 1-tuple return, whereas ``[0]`` yields the logits both for
    # tuple returns and for ModelOutput returns (which support integer indexing).
    logits = outputs[0]
    result = {
        "logits": logits.numpy(),
    }
    self.parent.assertListEqual(
        list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels]
    )