Пример #1
0
    def test_inference_masked_lm_long_input(self):
        """Check MLM logits shape and a reference slice for a 4096-token input."""
        model = YosoForMaskedLM.from_pretrained("uw-madison/yoso-4096")
        # One batch covering every position id up to the 4096-token maximum.
        input_ids = torch.arange(4096).unsqueeze(0)

        with torch.no_grad():
            logits = model(input_ids)[0]

        vocab_size = 50265
        self.assertEqual(logits.shape, torch.Size((1, 4096, vocab_size)))

        # Reference logits: first 3 sequence positions x first 3 vocab entries.
        reference = torch.tensor(
            [[[-2.3914, -4.3742, -5.0956],
              [-4.0988, -4.2384, -7.0406],
              [-3.1427, -3.7192, -6.6800]]]
        )
        self.assertTrue(torch.allclose(logits[:, :3, :3], reference, atol=1e-4))
Пример #2
0
    def test_inference_masked_lm(self):
        """Check MLM logits shape and a reference slice for a short input."""
        model = YosoForMaskedLM.from_pretrained("uw-madison/yoso-4096")
        # A tiny 6-token batch of consecutive ids.
        input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])

        with torch.no_grad():
            logits = model(input_ids)[0]

        vocab_size = 50265
        self.assertEqual(logits.shape, torch.Size((1, 6, vocab_size)))

        # Reference logits: first 3 sequence positions x first 3 vocab entries.
        reference = torch.tensor(
            [[[-2.1313, -3.7285, -2.2407],
              [-2.7047, -3.3314, -2.6408],
              [0.0629, -2.5166, -0.3356]]]
        )
        self.assertTrue(torch.allclose(logits[:, :3, :3], reference, atol=1e-4))
Пример #3
0
 def create_and_check_for_masked_lm(self, config, input_ids, token_type_ids,
                                    input_mask, sequence_labels,
                                    token_labels, choice_labels):
     """Build a YosoForMaskedLM from ``config`` and verify its logits shape.

     Only ``config``, ``input_ids``, ``token_type_ids``, ``input_mask`` and
     ``token_labels`` are consumed; the remaining label arguments exist to
     keep the signature uniform with the sibling ``create_and_check_*``
     helpers.
     """
     mlm_model = YosoForMaskedLM(config=config)
     mlm_model.to(torch_device)
     mlm_model.eval()
     outputs = mlm_model(input_ids,
                         token_type_ids=token_type_ids,
                         attention_mask=input_mask,
                         labels=token_labels)
     # Logits must be (batch, seq_len, vocab) regardless of the label inputs.
     expected = (self.batch_size, self.seq_length, self.vocab_size)
     self.parent.assertEqual(outputs.logits.shape, expected)
Пример #4
0
def convert_yoso_checkpoint(checkpoint_path, yoso_config_file,
                            pytorch_dump_path):
    """Convert an original YOSO checkpoint into a Hugging Face model dump.

    Args:
        checkpoint_path: Path to the original checkpoint file; its
            ``"model_state_dict"`` entry holds the raw weights.
        yoso_config_file: JSON file parsed into a ``YosoConfig``.
        pytorch_dump_path: Directory where the converted
            ``YosoForMaskedLM`` is saved via ``save_pretrained``.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only run this on
    # trusted checkpoints (consider weights_only=True on newer torch).
    orig_state_dict = torch.load(checkpoint_path,
                                 map_location="cpu")["model_state_dict"]
    config = YosoConfig.from_json_file(yoso_config_file)
    model = YosoForMaskedLM(config)

    new_state_dict = convert_checkpoint_helper(config.max_position_embeddings,
                                               orig_state_dict)

    # load_state_dict returns a missing/unexpected-keys report; print it so a
    # partial load is visible to the operator.
    print(model.load_state_dict(new_state_dict))
    model.eval()
    model.save_pretrained(pytorch_dump_path)

    # Fixed typo in the user-facing message: "successfuly" -> "successfully".
    print(
        f"Checkpoint successfully converted. Model saved at {pytorch_dump_path}"
    )