Exemplo n.º 1
0
 def create_and_check_mobilebert_for_sequence_classification(
     self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
 ):
     """Build a MobileBERT sequence classifier and check the shape of its logits.

     Overwrites ``config.num_labels`` with ``self.num_labels`` and instantiates
     the model. The model is moved to ``torch_device`` and put in eval mode.
     One forward pass is run with ``sequence_labels`` supplied as labels. The
     returned ``logits`` must have shape ``(self.batch_size, self.num_labels)``.
     The loss is computed by passing labels but is not inspected here.

     ``token_labels`` and ``choice_labels`` are unused. Presumably they are
     kept so every ``create_and_check_*`` helper shares one signature; confirm
     against the caller.
     """
     config.num_labels = self.num_labels

     classifier = MobileBertForSequenceClassification(config)
     classifier.to(torch_device)
     classifier.eval()

     outputs = classifier(
         input_ids,
         attention_mask=input_mask,
         token_type_ids=token_type_ids,
         labels=sequence_labels,
     )

     expected_shape = (self.batch_size, self.num_labels)
     self.parent.assertEqual(outputs.logits.shape, expected_shape)
 def create_and_check_mobilebert_for_sequence_classification(
         self, config, input_ids, token_type_ids, input_mask,
         sequence_labels, token_labels, choice_labels):
     """Build a MobileBERT sequence classifier and check its loss and logits.

     Overwrites ``config.num_labels`` with ``self.num_labels`` and instantiates
     the model. The model is moved to ``torch_device`` and put in eval mode.
     One forward pass is run with ``sequence_labels`` as labels. The check then
     asserts that the logits have size ``[self.batch_size, self.num_labels]``
     and passes the loss to ``self.check_loss_output``.

     ``token_labels`` and ``choice_labels`` are unused. Presumably they are
     kept so every ``create_and_check_*`` helper shares one signature; confirm
     against the caller.
     """
     config.num_labels = self.num_labels
     model = MobileBertForSequenceClassification(config)
     model.to(torch_device)
     model.eval()
     outputs = model(input_ids,
                     attention_mask=input_mask,
                     token_type_ids=token_type_ids,
                     labels=sequence_labels)
     # With labels supplied, the first two outputs are (loss, logits).
     # Slicing instead of unpacking the whole return value still works when
     # the model appends extra outputs (e.g. hidden states / attentions).
     # NOTE(review): positional slicing should also work on transformers'
     # ModelOutput objects, where plain unpacking would yield key strings.
     # Confirm against the installed transformers version.
     loss, logits = outputs[:2]
     result = {
         "loss": loss,
         "logits": logits,
     }
     self.parent.assertListEqual(list(result["logits"].size()),
                                 [self.batch_size, self.num_labels])
     self.check_loss_output(result)
Exemplo n.º 3
0

def str2bool(value):
    """Convert a command-line string to a bool, for use as argparse ``type=``.

    ``type=bool`` does not work for this: argparse would call
    ``bool("False")``, which is True for every non-empty string. That made
    ``--do_lower_case False`` silently keep lower-casing enabled.

    Args:
        value: The raw argument string, or an already-boolean default.

    Returns:
        True for "yes"/"true"/"t"/"y"/"1" and False for "no"/"false"/"f"/"n"/"0".
        Matching is case-insensitive and ignores surrounding whitespace. A bool
        input is returned unchanged.

    Raises:
        argparse.ArgumentTypeError: For any other string. Argparse turns this
            into a clean usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("yes", "true", "t", "y", "1"):
        return True
    if normalized in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


def build_arg_parser():
    """Return the command-line parser for the ONNX export script."""
    parser = argparse.ArgumentParser(
        description='Export bert onnx model',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Required: without it the script would only fail later, inside
    # from_pretrained(None), with a far less helpful error.
    parser.add_argument(
        '--input_dir',
        type=str,
        required=True,
        help='input_dir of bert model, must contain config.json')
    parser.add_argument('--task_name',
                        type=str,
                        choices=["mrpc", "mnli"],
                        help='tasks names of bert model')
    parser.add_argument('--max_len',
                        type=int,
                        default=128,
                        help='Maximum length of the sentence pairs')
    # str2bool instead of bool so that "--do_lower_case False" really
    # disables lower-casing. The non-string default True is not passed
    # through the type converter.
    parser.add_argument('--do_lower_case',
                        type=str2bool,
                        default=True,
                        help='whether lower the tokenizer')
    parser.add_argument('--output_model',
                        type=str,
                        default='bert.onnx',
                        help='path to exported model file')
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    model = MobileBertForSequenceClassification.from_pretrained(args.input_dir)
    export_onnx_model(args, model, args.output_model)