def test_beam_search(self):
    """Check the first decoded caption against pinned tokens per batch size.

    Builds a ``TestDecoderModel`` from the ``butd`` model config, puts it in
    eval mode, and runs it on batches of 1, 2, 8 and 16 synthetic samples.
    For each batch, the first caption (with leading/trailing zeros trimmed)
    must equal the token list recorded for that batch size.

    NOTE(review): the pinned tokens presumably rely on an RNG seed set
    outside this method (e.g. in ``setUp``) -- confirm before reordering
    anything that draws random numbers.
    """

    def _make_sample():
        # One synthetic sample; fields are assigned in a fixed order and the
        # random features are drawn at call time.
        item = Sample()
        item.dataset_name = "coco"
        item.dataset_type = "test"
        item.image_feature_0 = torch.randn(100, 2048)
        item.answers = torch.zeros((5, 10), dtype=torch.long)
        return item

    vocab = text_utils.VocabFromText(self.VOCAB_EXAMPLE_SENTENCES)
    decoder = TestDecoderModel(self.config.model_config.butd, vocab)
    decoder.build()
    decoder.eval()

    # Maps batch size -> tokens expected for the first caption of the batch.
    # Iteration follows insertion order: 1, 2, 8, 16.
    expected_by_batch_size = {
        1: [1.0, 23.0, 1.0, 24.0, 29.0, 37.0, 40.0, 17.0, 29.0, 2.0],
        2: [1.0, 0.0, 8.0, 1.0, 28.0, 25.0, 2.0],
        8: [1.0, 34.0, 1.0, 13.0, 1.0, 2.0],
        16: [1.0, 25.0, 18.0, 2.0],
    }

    for batch_size, expected in expected_by_batch_size.items():
        batch = SampleList([_make_sample() for _ in range(batch_size)])
        captions = decoder(batch)["captions"]
        # Only the first caption is checked; zeros at either end are dropped.
        first_caption = np.trim_zeros(captions[0].tolist())
        self.assertEqual(first_caption, expected)
def test_nucleus_sampling(self):
    """Check the caption decoded for a single synthetic sample.

    Builds a ``TestDecoderModel`` from the ``butd`` model config, runs it in
    eval mode on a one-sample batch, and compares the first caption to a
    pinned token list chosen by the installed PyTorch version.

    NOTE(review): the pinned tokens presumably rely on an RNG seed and a
    nucleus-sampling config set outside this method -- confirm.
    """
    vocab = text_utils.VocabFromText(self.VOCAB_EXAMPLE_SENTENCES)
    decoder = TestDecoderModel(self.config.model_config.butd, vocab)
    decoder.build()
    decoder.eval()

    item = Sample()
    item.dataset_name = "coco"
    item.dataset_type = "test"
    item.image_feature_0 = torch.randn(100, 2048)
    item.answers = torch.zeros((5, 10), dtype=torch.long)

    captions = decoder(SampleList([item]))["captions"]

    # These are the expected tokens for sum_threshold = 0.5.
    # Because of a bug fix in https://github.com/pytorch/pytorch/pull/47386,
    # torch.Tensor.multinomial generates a different random sequence on
    # newer PyTorch releases, so two reference outputs are kept.
    # TODO: Remove this hack after OSS uses later version of PyTorch.
    # NOTE(review): if LegacyVersion comes from `packaging.version`, it was
    # removed in packaging 22.0 -- confirm the pinned dependency.
    uses_fixed_multinomial = LegacyVersion(torch.__version__) > LegacyVersion(
        "1.7.1"
    )
    expected_tokens = (
        [1.0, 23.0, 38.0, 30.0, 5.0, 11.0, 2.0]
        if uses_fixed_multinomial
        else [1.0, 29.0, 11.0, 11.0, 39.0, 10.0, 31.0, 4.0, 19.0, 39.0, 2.0]
    )

    self.assertEqual(captions[0].tolist(), expected_tokens)