def test_nucleus_sampling(self):
    """Compare the caption tokens emitted by ``TestDecoderModel`` to fixed values.

    The model is built from ``self.config.model_attributes.butd`` with a
    vocabulary derived from ``VOCAB_EXAMPLE_SENTENCES``, moved to CUDA, put in
    eval mode and run on a single sample carrying random image features.
    The first entry of the ``"captions"`` output must equal the hard-coded
    token ids below.

    The test is skipped when no CUDA device is available, because the model
    is explicitly moved to ``"cuda"``.
    """
    # Without a GPU, ``model.to("cuda")`` would raise; report a skip instead
    # of an error so CPU-only runs stay green.
    if not torch.cuda.is_available():
        self.skipTest("test_nucleus_sampling requires a CUDA device")

    vocab = text_utils.VocabFromText(self.VOCAB_EXAMPLE_SENTENCES)

    model_config = self.config.model_attributes.butd
    model = TestDecoderModel(model_config, vocab)
    model.build()
    model.to("cuda")
    model.eval()

    sample = Sample()
    sample.dataset_name = "coco"
    sample.dataset_type = "test"
    # NOTE(review): the input is random, so the expected tokens presumably
    # rely on the RNG being seeded elsewhere (e.g. in setUp) - confirm.
    sample.image_feature_0 = torch.randn(100, 2048)
    sample.answers = torch.zeros((5, 10), dtype=torch.long)
    sample_list = SampleList([sample])

    tokens = model(sample_list)["captions"]

    # these are expected tokens for sum_threshold = 0.5
    # (written as ints; they compare equal to the former float literals)
    expected_tokens = [
        1,
        2914,
        5921,
        2204,
        5055,
        9224,
        4512,
        182,
        3649,
        6409,
        2,
    ]

    self.assertEqual(tokens[0].tolist(), expected_tokens)
def test_vocab_from_text(self):
    """Exercise ``text_utils.VocabFromText`` on ``VOCAB_EXAMPLE_SENTENCES``.

    Checks, in order:
      * the default vocabulary: size, unk index, first token and a handful
        of ``itos`` / ``stoi`` entries;
      * the resulting size for ``min_count`` of 10 and 11;
      * the resulting size with ``only_unk_extra=True``;
      * the resulting size for several ``remove`` / ``keep`` combinations.
    """
    sentences = self.VOCAB_EXAMPLE_SENTENCES

    # Default construction.
    full_vocab = text_utils.VocabFromText(sentences)
    self.assertEqual(full_vocab.get_size(), 41)
    self.assertEqual(len(full_vocab), 41)
    self.assertEqual(full_vocab.get_unk_index(), 1)
    self.assertEqual(full_vocab.itos[0], full_vocab.DEFAULT_TOKENS[0])

    # index -> token lookups
    expected_itos = (
        (34, "that"),
        (31, "cube"),
        (25, "cyan"),
        (20, "the"),
        (10, "than"),
    )
    for index, word in expected_itos:
        self.assertEqual(full_vocab.itos[index], word)

    # token -> index lookups
    expected_stoi = (
        ("sphere", 30),
        ("shape", 22),
    )
    for word, index in expected_stoi:
        self.assertEqual(full_vocab.stoi[word], index)

    # min_count=10: five entries remain and the last one is "the".
    frequent_vocab = text_utils.VocabFromText(sentences, min_count=10)
    self.assertEqual(frequent_vocab.get_size(), 5)
    self.assertEqual(frequent_vocab.itos[frequent_vocab.get_size() - 1], "the")

    # min_count=11: one entry fewer.
    rarer_vocab = text_utils.VocabFromText(sentences, min_count=11)
    self.assertEqual(rarer_vocab.get_size(), 4)

    # only_unk_extra=True with min_count=11: just "<unk>" is left.
    unk_only_vocab = text_utils.VocabFromText(
        sentences, min_count=11, only_unk_extra=True
    )
    self.assertEqual(unk_only_vocab.get_size(), 1)
    self.assertEqual(unk_only_vocab.itos[unk_only_vocab.get_size() - 1], "<unk>")

    # Removing tokens shrinks the vocabulary.
    no_semicolon_vocab = text_utils.VocabFromText(
        sentences, min_count=1, remove=[";"]
    )
    self.assertEqual(no_semicolon_vocab.get_size(), 40)

    no_punct_vocab = text_utils.VocabFromText(
        sentences, min_count=1, remove=[";", ",", "?"]
    )
    self.assertEqual(no_punct_vocab.get_size(), 38)

    # Here ``remove`` is passed as a plain string rather than a list.
    keep_question_vocab = text_utils.VocabFromText(
        sentences, min_count=1, keep=["?"], remove=";"
    )
    self.assertEqual(keep_question_vocab.get_size(), 40)