コード例 #1
0
    def test_inference_masked_lm_long_input(self):
        """Run the yoso-4096 masked-LM checkpoint on a 4096-token input.

        The test checks two things about the first element of the model output:
        its shape is (1, 4096, 50265), and its top-left 3x3 corner matches
        hard-coded reference values within an absolute tolerance of 1e-4.
        """
        seq_len = 4096
        vocab_size = 50265
        model = YosoForMaskedLM.from_pretrained("uw-madison/yoso-4096")
        # A single batch row holding token ids 0, 1, ..., seq_len - 1.
        token_ids = torch.arange(seq_len).unsqueeze(0)

        # Gradients are not needed when only comparing output values.
        with torch.no_grad():
            scores = model(token_ids)[0]

        self.assertEqual(scores.shape, torch.Size((1, seq_len, vocab_size)))

        # Reference values for scores[:, :3, :3].
        reference_corner = torch.tensor(
            [
                [
                    [-2.3914, -4.3742, -5.0956],
                    [-4.0988, -4.2384, -7.0406],
                    [-3.1427, -3.7192, -6.6800],
                ]
            ]
        )
        observed_corner = scores[:, :3, :3]
        self.assertTrue(torch.allclose(observed_corner, reference_corner, atol=1e-4))
コード例 #2
0
    def test_inference_masked_lm(self):
        """Run the yoso-4096 masked-LM checkpoint on a short 6-token input.

        The test checks two things about the first element of the model output:
        its shape is (1, 6, 50265), and its top-left 3x3 corner matches
        hard-coded reference values within an absolute tolerance of 1e-4.
        """
        vocab_size = 50265
        model = YosoForMaskedLM.from_pretrained("uw-madison/yoso-4096")
        # A single batch row holding token ids 0 through 5.
        token_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
        seq_len = token_ids.shape[1]

        # Gradients are not needed when only comparing output values.
        with torch.no_grad():
            scores = model(token_ids)[0]

        self.assertEqual(scores.shape, torch.Size((1, seq_len, vocab_size)))

        # Reference values for scores[:, :3, :3].
        reference_corner = torch.tensor(
            [
                [
                    [-2.1313, -3.7285, -2.2407],
                    [-2.7047, -3.3314, -2.6408],
                    [0.0629, -2.5166, -0.3356],
                ]
            ]
        )
        observed_corner = scores[:, :3, :3]
        self.assertTrue(torch.allclose(observed_corner, reference_corner, atol=1e-4))