示例#1
0
    def test_nucleus_sampling(self):
        """Check the caption tokens decoded by the test decoder model.

        Builds a ``TestDecoderModel`` from the ``butd`` model attributes and a
        vocabulary derived from ``VOCAB_EXAMPLE_SENTENCES``, runs it in eval
        mode on a single sample with random image features, and compares the
        first generated caption against a hard-coded token sequence.

        The test is skipped when no CUDA device is available, because the
        model is explicitly moved to ``"cuda"``.
        """
        # The model is placed on the GPU below; without this guard a CPU-only
        # host reports an error instead of a skipped test.
        if not torch.cuda.is_available():
            self.skipTest("test_nucleus_sampling requires a CUDA device")

        vocab = text_utils.VocabFromText(self.VOCAB_EXAMPLE_SENTENCES)

        model_config = self.config.model_attributes.butd
        model = TestDecoderModel(model_config, vocab)
        model.build()
        model.to("cuda")
        model.eval()

        # NOTE(review): the features are random, so the expected tokens below
        # are only reproducible if the RNG is seeded before this test runs
        # (presumably in setUp, which is not visible here) -- confirm.
        sample = Sample()
        sample.dataset_name = "coco"
        sample.dataset_type = "test"
        sample.image_feature_0 = torch.randn(100, 2048)
        sample.answers = torch.zeros((5, 10), dtype=torch.long)
        sample_list = SampleList([sample])

        tokens = model(sample_list)["captions"]

        # these are expected tokens for sum_threshold = 0.5
        expected_tokens = [
            1.0000e+00, 2.9140e+03, 5.9210e+03, 2.2040e+03, 5.0550e+03,
            9.2240e+03, 4.5120e+03, 1.8200e+02, 3.6490e+03, 6.4090e+03,
            2.0000e+00
        ]

        self.assertEqual(tokens[0].tolist(), expected_tokens)
示例#2
0
    def test_vocab_from_text(self):
        """Exercise ``VocabFromText`` with several argument combinations.

        Covers the default construction (size, unknown-token index and a few
        index/token lookups in both directions) and then checks the resulting
        vocabulary size for different ``min_count``, ``only_unk_extra``,
        ``keep`` and ``remove`` settings.
        """
        # Default arguments: overall size and the unknown-token index.
        default_vocab = text_utils.VocabFromText(self.VOCAB_EXAMPLE_SENTENCES)
        self.assertEqual(default_vocab.get_size(), 41)
        self.assertEqual(len(default_vocab), 41)
        self.assertEqual(default_vocab.get_unk_index(), 1)

        # Index -> token lookups.
        self.assertEqual(default_vocab.itos[0], default_vocab.DEFAULT_TOKENS[0])
        expected_itos = (
            (34, "that"),
            (31, "cube"),
            (25, "cyan"),
            (20, "the"),
            (10, "than"),
        )
        for index, token in expected_itos:
            self.assertEqual(default_vocab.itos[index], token)

        # Token -> index lookups.
        expected_stoi = (("sphere", 30), ("shape", 22))
        for token, index in expected_stoi:
            self.assertEqual(default_vocab.stoi[token], index)

        # Raising min_count shrinks the vocabulary.
        frequent_vocab = text_utils.VocabFromText(
            self.VOCAB_EXAMPLE_SENTENCES, min_count=10
        )
        self.assertEqual(frequent_vocab.get_size(), 5)
        self.assertEqual(
            frequent_vocab.itos[frequent_vocab.get_size() - 1], "the"
        )

        stricter_vocab = text_utils.VocabFromText(
            self.VOCAB_EXAMPLE_SENTENCES, min_count=11
        )
        self.assertEqual(stricter_vocab.get_size(), 4)

        # With only_unk_extra the last (and only) entry is the unknown token.
        unk_only_vocab = text_utils.VocabFromText(
            self.VOCAB_EXAMPLE_SENTENCES, min_count=11, only_unk_extra=True
        )
        self.assertEqual(unk_only_vocab.get_size(), 1)
        self.assertEqual(
            unk_only_vocab.itos[unk_only_vocab.get_size() - 1], "<unk>"
        )

        # Vocabulary sizes for different remove / keep settings.
        no_semicolon_vocab = text_utils.VocabFromText(
            self.VOCAB_EXAMPLE_SENTENCES, min_count=1, remove=[";"]
        )
        self.assertEqual(no_semicolon_vocab.get_size(), 40)

        no_punctuation_vocab = text_utils.VocabFromText(
            self.VOCAB_EXAMPLE_SENTENCES, min_count=1, remove=[";", ",", "?"]
        )
        self.assertEqual(no_punctuation_vocab.get_size(), 38)

        # Here remove is passed as a plain string rather than a list.
        keep_question_vocab = text_utils.VocabFromText(
            self.VOCAB_EXAMPLE_SENTENCES, min_count=1, keep=["?"], remove=";"
        )
        self.assertEqual(keep_question_vocab.get_size(), 40)