class BertNER(nn.Module):
    """BERT encoder + linear classifier head for token-level NER tagging.

    Args:
        vocab_size: size of the NER label set (output dimension of the
            classifier head). Required despite the ``None`` default —
            ``nn.Linear`` will raise if it is omitted.
        device: device string the model and inputs are moved to.
        training: when True, ``forward`` runs BERT in train mode (dropout
            active) with gradients; otherwise inference runs under
            ``torch.no_grad()``.
    """

    def __init__(self, vocab_size=None, device='cpu', training=False):
        super().__init__()
        # BUG FIX: the original built `BertModel(BertConfig(30522, ...))` and
        # then called `.from_pretrained(...)` on that instance.
        # `from_pretrained` is a classmethod returning a *new* model, so the
        # hand-built config (with the *uncased* vocab size 30522, mismatched
        # against the cased checkpoint) was constructed and thrown away.
        # Load the pretrained cased model directly instead.
        self.bert = BertModel.from_pretrained('bert-base-cased').to(device)
        # 768 = hidden size of bert-base; maps each token embedding to a label.
        self.classifier = nn.Linear(768, vocab_size)
        self.device = device
        # NOTE: this overwrites nn.Module's own `training` flag with the
        # constructor argument; `forward` branches on it below.
        self.training = training
        self.bert.eval()  # default to inference mode (dropout off)

    def forward(self, x):
        """Run a batch of token ids through BERT and the classifier.

        Args:
            x: tensor of input token ids — presumably (batch, seq_len);
                TODO confirm against callers.

        Returns:
            (logits, preds): per-token label logits and their argmax ids.
        """
        x = x.to(self.device)
        if self.training:
            self.bert.train()
            # `layers_out` is a list of per-layer hidden states (old
            # pytorch-pretrained-bert API); take the last encoder layer.
            layers_out, _ = self.bert(x)
            last_layer = layers_out[-1]
            # BUG FIX: restore eval mode so a subsequent inference-path call
            # does not silently run with dropout still enabled.
            self.bert.eval()
        else:
            with torch.no_grad():
                layers_out, _ = self.bert(x)
                last_layer = layers_out[-1]
        logits = self.classifier(last_layer)
        preds = logits.argmax(-1)
        return logits, preds
def test_convert_onnx(self):
    """Export the BERT model to an ONNX file and report completion.

    Builds a fresh ``BertModel`` from the JSON config at
    ``BERT_CONFIG_PATH``, switches it to eval mode, and exports it to
    ``self.model_onnx_path`` using ``self.org_dummy_input`` to trace.
    """
    model = BertModel(BertConfig.from_json_file(BERT_CONFIG_PATH))
    model.eval()  # idiomatic equivalent of train(False); disables dropout for export
    # torch.onnx.export returns None — the original bound it to an unused
    # local `output`, which is dropped here.
    torch.onnx.export(
        model,
        self.org_dummy_input,
        self.model_onnx_path,
        verbose=True,
        operator_export_type=OPERATOR_EXPORT_TYPE,
        input_names=['input_ids', 'token_type_ids', 'attention_mask'])
    print("Export of torch_model.onnx complete!")