Exemplo n.º 1
0
    def test_inference_nezha_model(self):
        """Check the pretrained ``sijunhe/nezha-cn-base`` checkpoint on a fixed input.

        Runs one forward pass over six token ids, asserts the first output has
        shape (1, 6, 768), and compares a 3x3 window of it (positions 1-3,
        features 1-3) against hard-coded reference values with ``atol=1e-4``.
        """
        checkpoint = "sijunhe/nezha-cn-base"
        nezha = NezhaModel.from_pretrained(checkpoint)

        token_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
        # A 0 in the attention mask excludes that position (Hugging Face
        # convention), so the first token is masked out here.
        mask = torch.tensor([[0, 1, 1, 1, 1, 1]])

        with torch.no_grad():
            outputs = nezha(token_ids, attention_mask=mask)
        hidden = outputs[0]

        self.assertEqual(hidden.shape, torch.Size((1, 6, 768)))

        reference = torch.tensor(
            [[
                [0.0685, 0.2441, 0.1102],
                [0.0600, 0.1906, 0.1349],
                [0.0221, 0.0819, 0.0586],
            ]]
        )
        window = hidden[:, 1:4, 1:4]
        self.assertTrue(torch.allclose(window, reference, atol=1e-4))
Exemplo n.º 2
0
 def create_and_check_model(self, config, input_ids, token_type_ids,
                            input_mask, sequence_labels, token_labels,
                            choice_labels):
     """Build a ``NezhaModel`` from ``config`` and check its output shapes.

     The model is called three ways: with ``attention_mask`` and
     ``token_type_ids``, with ``token_type_ids`` only, and with ``input_ids``
     only. For every call variant, ``last_hidden_state`` must have shape
     (batch_size, seq_length, hidden_size) and ``pooler_output`` must have
     shape (batch_size, hidden_size). Failures are reported through
     ``self.parent.assertEqual``.

     ``sequence_labels``, ``token_labels`` and ``choice_labels`` are unused
     here. They are accepted so this checker shares a signature with its
     sibling ``create_and_check_*`` helpers.
     """
     model = NezhaModel(config=config)
     model.to(torch_device)
     model.eval()

     expected_hidden_shape = (self.batch_size, self.seq_length,
                              self.hidden_size)
     expected_pooled_shape = (self.batch_size, self.hidden_size)

     # Previously only the last call's output was checked; the first two
     # results were overwritten unverified. Check every call variant, in the
     # same order as before.
     call_variants = (
         {"attention_mask": input_mask, "token_type_ids": token_type_ids},
         {"token_type_ids": token_type_ids},
         {},
     )
     for extra_kwargs in call_variants:
         result = model(input_ids, **extra_kwargs)
         self.parent.assertEqual(result.last_hidden_state.shape,
                                 expected_hidden_shape)
         self.parent.assertEqual(result.pooler_output.shape,
                                 expected_pooled_shape)
Exemplo n.º 3
0
 def test_model_from_pretrained(self):
     """Smoke-test loading the first NEZHA checkpoint in the archive list."""
     # Slicing (rather than indexing [0]) makes this a no-op when the list is
     # empty instead of raising IndexError.
     first_checkpoints = NEZHA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]
     for checkpoint_name in first_checkpoints:
         loaded = NezhaModel.from_pretrained(checkpoint_name)
         self.assertIsNotNone(loaded)