Example #1
    def test_doesnt_predict_eos_if_shorter_than_min_len(self):
        # batch 0 will always predict EOS. The other batches will predict
        # non-eos scores.
        for batch_sz in [1, 3]:
            n_words = 100
            _non_eos_idxs = [47]
            valid_score_dist = torch.log_softmax(torch.tensor(
                [6., 5.]), dim=0)
            min_length = 5
            eos_idx = 2
            lengths = torch.randint(0, 30, (batch_sz,))
            samp = GreedySearch(
                0, 1, 2, batch_sz, min_length,
                False, set(), False, 30, 1., 1)
            samp.initialize(torch.zeros(1), lengths)
            all_attns = []
            for i in range(min_length + 4):
                word_probs = torch.full(
                    (batch_sz, n_words), -float('inf'))
                # "best" prediction is eos - that should be blocked
                word_probs[0, eos_idx] = valid_score_dist[0]
                # include at least one prediction OTHER than EOS
                # that is greater than -1e20
                word_probs[0, _non_eos_idxs[0]] = valid_score_dist[1]
                word_probs[1:, _non_eos_idxs[0] + i] = 0

                attns = torch.randn(1, batch_sz, 53)
                all_attns.append(attns)
                samp.advance(word_probs, attns)
                if i < min_length:
                    self.assertTrue(
                        samp.topk_scores[0].allclose(valid_score_dist[1]))
                    self.assertTrue(
                        samp.topk_scores[1:].eq(0).all())
                elif i == min_length:
                    # now batch 0 has ended and no others have
                    self.assertTrue(samp.is_finished[0, :].eq(1).all())
                    self.assertFalse(samp.is_finished[1:, :].eq(1).any())
                else:  # i > min_length
                    break
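
The behavior this test pins down is EOS masking: until min_length steps have elapsed, the search forces the EOS column to a huge negative score so it can never win the argmax. A minimal sketch of that mechanism, assuming a standalone form of OpenNMT-py's DecodeStrategy.ensure_min_length (the free-function signature here is illustrative):

import torch

def ensure_min_length(log_probs, step, min_length, eos_idx):
    # Keep EOS effectively at -inf (but finite, so downstream math
    # stays well-defined) until min_length tokens have been produced.
    if step <= min_length:
        log_probs[:, eos_idx] = -1e20
    return log_probs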
Example #2
    def test_returns_correct_scores_deterministic(self):
        for batch_sz in [1, 13]:
            for temp in [1., 3.]:
                n_words = 100
                _non_eos_idxs = [47, 51, 13, 88, 99]
                valid_score_dist_1 = torch.log_softmax(torch.tensor(
                    [6., 5., 4., 3., 2., 1.]), dim=0)
                valid_score_dist_2 = torch.log_softmax(torch.tensor(
                    [6., 1.]), dim=0)
                eos_idx = 2
                lengths = torch.randint(0, 30, (batch_sz,))
                samp = GreedySearch(
                    0, 1, 2, batch_sz, 0,
                    False, set(), False, 30, temp, 1)
                samp.initialize(torch.zeros(1), lengths)
                # initial step
                i = 0
                word_probs = torch.full(
                    (batch_sz, n_words), -float('inf'))
                # batch 0 dies on step 0
                word_probs[0, eos_idx] = valid_score_dist_1[0]
                # include at least one prediction OTHER than EOS
                # that is greater than -1e20
                word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
                word_probs[1:, _non_eos_idxs[0] + i] = 0

                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)
                self.assertTrue(samp.is_finished[0].eq(1).all())
                samp.update_finished()
                self.assertEqual(
                    samp.scores[0], [valid_score_dist_1[0] / temp])
                if batch_sz == 1:
                    self.assertTrue(samp.done)
                    continue
                else:
                    self.assertFalse(samp.done)

                # step 2
                i = 1
                word_probs = torch.full(
                    (batch_sz - 1, n_words), -float('inf'))
                # (old) batch 8 dies on step 1
                word_probs[7, eos_idx] = valid_score_dist_2[0]
                word_probs[0:7, _non_eos_idxs[:2]] = valid_score_dist_2
                word_probs[8:, _non_eos_idxs[:2]] = valid_score_dist_2

                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)

                self.assertTrue(samp.is_finished[7].eq(1).all())
                samp.update_finished()
                self.assertEqual(
                    samp.scores[8], [valid_score_dist_2[0] / temp])

                # step 3
                i = 2
                word_probs = torch.full(
                    (batch_sz - 2, n_words), -float('inf'))
                # everything dies
                word_probs[:, eos_idx] = 0

                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)

                self.assertTrue(samp.is_finished.eq(1).all())
                samp.update_finished()
                for b in range(batch_sz):
                    if b != 0 and b != 8:
                        self.assertEqual(samp.scores[b], [0])
                self.assertTrue(samp.done)
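
The expected scores above follow from how the deterministic (keep_topk=1) pick applies the sampling temperature: the argmax logit is taken, and the returned score is that logit divided by temp. A sketch of the rule, assuming it mirrors the greedy branch of OpenNMT-py's sample_with_temperature:

import torch

def greedy_pick(log_probs, sampling_temp):
    # Deterministic case: take the single best logit per row, then
    # scale the returned score by the sampling temperature.
    topk_scores, topk_ids = log_probs.topk(1, dim=-1)
    if sampling_temp > 0:
        topk_scores = topk_scores / sampling_temp
    return topk_ids, topk_scores

This is why the test asserts samp.scores[0] == [valid_score_dist_1[0] / temp] rather than the raw log-probability.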
Example #3
    def test_returns_correct_scores_non_deterministic_beams(self):
        beam_size = 10
        for batch_sz in [1, 13]:
            for temp in [1., 3.]:
                n_words = 100
                _non_eos_idxs = [47, 51, 13, 88, 99]
                valid_score_dist_1 = torch.log_softmax(torch.tensor(
                    [6., 5., 4., 3., 2., 1.]), dim=0)
                valid_score_dist_2 = torch.log_softmax(torch.tensor(
                    [6., 1.]), dim=0)
                eos_idx = 2
                lengths = torch.randint(0, 30, (batch_sz,))
                samp = GreedySearch(0, 1, 2, 3,
                                    batch_sz, GlobalScorerStub(), 0, False,
                                    set(), False, 30, temp, 50, 0, beam_size,
                                    False)
                samp.initialize(torch.zeros((1, 1)), lengths)
                # initial step
                # finish one beam
                i = 0
                for _ in range(100):
                    word_probs = torch.full((batch_sz * beam_size, n_words),
                                            -float('inf'))

                    word_probs[beam_size - 2, eos_idx] = valid_score_dist_1[0]
                    # include at least one prediction OTHER than EOS
                    # that is greater than -1e20
                    word_probs[beam_size - 2,
                               _non_eos_idxs] = valid_score_dist_1[1:]
                    word_probs[beam_size - 2 + 1:, _non_eos_idxs[0] + i] = 0
                    word_probs[:beam_size - 2, _non_eos_idxs[0] + i] = 0

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished[beam_size - 2].eq(1).all():
                        self.assertFalse(
                            samp.is_finished[:beam_size - 2].eq(1).any())
                        self.assertFalse(
                            samp.is_finished[beam_size - 2 + 1:].eq(1).any())
                        break
                else:
                    self.fail("Batch 0 never ended (very unlikely but maybe "
                              "due to stochasticisty. If so, please increase "
                              "the range of the for-loop.")
                samp.update_finished()
                self.assertEqual([samp.topk_scores[beam_size - 2]],
                                 [valid_score_dist_1[0] / temp])

                # step 2
                # finish example in last batch
                i = 1
                for _ in range(100):
                    word_probs = torch.full(
                        (batch_sz * beam_size - 1, n_words), -float('inf'))
                    # (old) batch 8 dies on step 1
                    word_probs[(batch_sz - 1) * beam_size + 7,
                               eos_idx] = valid_score_dist_2[0]
                    word_probs[:(batch_sz - 1) * beam_size + 7,
                               _non_eos_idxs[:2]] = valid_score_dist_2
                    word_probs[(batch_sz - 1) * beam_size + 8:,
                               _non_eos_idxs[:2]] = valid_score_dist_2

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if (samp.is_finished[(batch_sz - 1) * beam_size +
                                         7].eq(1).all()):
                        break
                else:
                    self.fail("Batch 8 never ended (very unlikely but maybe "
                              "due to stochasticisty. If so, please increase "
                              "the range of the for-loop.")

                samp.update_finished()
                self.assertEqual([
                    score for score, _, _ in samp.hypotheses[batch_sz - 1][-1:]
                ], [valid_score_dist_2[0] / temp])

                # step 3
                i = 2
                for _ in range(250):
                    word_probs = torch.full((samp.alive_seq.shape[0], n_words),
                                            -float('inf'))
                    # everything dies
                    word_probs[:, eos_idx] = 0

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished.any():
                        samp.update_finished()
                    if samp.is_finished.eq(1).all():
                        break
                else:
                    self.fail("All batches never ended (very unlikely but "
                              "maybe due to stochasticisty. If so, please "
                              "increase the range of the for-loop.")

                self.assertTrue(samp.done)
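
GlobalScorerStub is used but not defined in this snippet. A plausible stand-in consistent with how the test exercises it (no length or coverage penalty, so finished hypotheses keep their raw scores) could look like the following; this is an assumption, not necessarily the repo's actual helper:

import torch

class GlobalScorerStub(object):
    alpha = 0
    beta = 0

    def __init__(self):
        # No-op penalties: dividing by 1. leaves scores unchanged.
        self.length_penalty = lambda step, alpha: 1.
        self.cov_penalty = lambda cov, beta: torch.zeros(
            (1, cov.shape[-2]), device=cov.device, dtype=torch.float)
        self.has_cov_pen = False
        self.has_len_pen = False

    def update_global_state(self, beam):
        pass

    def score(self, beam, scores):
        return scores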
Example #4
    def test_returns_correct_scores_non_deterministic(self):
        for batch_sz in [1, 13]:
            for temp in [1., 3.]:
                n_words = 100
                _non_eos_idxs = [47, 51, 13, 88, 99]
                valid_score_dist_1 = torch.log_softmax(torch.tensor(
                    [6., 5., 4., 3., 2., 1.]), dim=0)
                valid_score_dist_2 = torch.log_softmax(torch.tensor(
                    [6., 1.]), dim=0)
                eos_idx = 2
                lengths = torch.randint(0, 30, (batch_sz,))
                samp = GreedySearch(
                    0, 1, 2, batch_sz, 0,
                    False, set(), False, 30, temp, 2)
                samp.initialize(torch.zeros(1), lengths)
                # initial step
                i = 0
                for _ in range(100):
                    word_probs = torch.full(
                        (batch_sz, n_words), -float('inf'))
                    # batch 0 dies on step 0
                    word_probs[0, eos_idx] = valid_score_dist_1[0]
                    # include at least one prediction OTHER than EOS
                    # that is greater than -1e20
                    word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
                    word_probs[1:, _non_eos_idxs[0] + i] = 0

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished[0].eq(1).all():
                        break
                else:
                    self.fail("Batch 0 never ended (very unlikely but maybe "
                              "due to stochasticisty. If so, please increase "
                              "the range of the for-loop.")
                samp.update_finished()
                self.assertEqual(
                    samp.scores[0], [valid_score_dist_1[0] / temp])
                if batch_sz == 1:
                    self.assertTrue(samp.done)
                    continue
                else:
                    self.assertFalse(samp.done)

                # step 2
                i = 1
                for _ in range(100):
                    word_probs = torch.full(
                        (batch_sz - 1, n_words), -float('inf'))
                    # (old) batch 8 dies on step 1
                    word_probs[7, eos_idx] = valid_score_dist_2[0]
                    word_probs[0:7, _non_eos_idxs[:2]] = valid_score_dist_2
                    word_probs[8:, _non_eos_idxs[:2]] = valid_score_dist_2

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished[7].eq(1).all():
                        break
                else:
                    self.fail("Batch 8 never ended (very unlikely but maybe "
                              "due to stochasticisty. If so, please increase "
                              "the range of the for-loop.")

                samp.update_finished()
                self.assertEqual(
                    samp.scores[8], [valid_score_dist_2[0] / temp])

                # step 3
                i = 2
                for _ in range(250):
                    word_probs = torch.full(
                        (samp.alive_seq.shape[0], n_words), -float('inf'))
                    # everything dies
                    word_probs[:, eos_idx] = 0

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished.any():
                        samp.update_finished()
                    if samp.is_finished.eq(1).all():
                        break
                else:
                    self.fail("All batches never ended (very unlikely but "
                              "maybe due to stochasticisty. If so, please "
                              "increase the range of the for-loop.")

                for b in range(batch_sz):
                    if b != 0 and b != 8:
                        self.assertEqual(samp.scores[b], [0])
                self.assertTrue(samp.done)
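
Unlike the deterministic variant, this test samples with keep_topk=2, so EOS is only chosen with high probability rather than with certainty; that is why every step sits in a bounded retry loop that ends in self.fail(). A sketch of the top-k sampling rule being exercised, assuming it follows OpenNMT-py's sample_with_temperature:

import torch

def sample_topk(log_probs, sampling_temp, keep_topk):
    logits = log_probs / sampling_temp
    # Mask everything below the keep_topk-th best logit, then draw
    # one token per row from what remains.
    top_values, _ = logits.topk(keep_topk, dim=1)
    kth_best = top_values[:, -1].view(-1, 1)
    logits = logits.masked_fill(logits < kth_best, -10000)
    dist = torch.distributions.Multinomial(logits=logits, total_count=1)
    topk_ids = torch.argmax(dist.sample(), dim=1, keepdim=True)
    topk_scores = logits.gather(dim=1, index=topk_ids)
    return topk_ids, topk_scores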
Example #5
    def _sample_batch_with_strategy(self, batch,
                                    decode_strategy: GreedySearch):
        use_src_map = False
        parallel_paths = decode_strategy.parallel_paths

        src, enc_states, memory_bank, src_lengths = self._run_encoder(
            batch=batch)
        self.model.decoder.init_state(src, memory_bank, enc_states)
        results = {
            "scores": None,
            "predictions": None,
            "attention": None,
        }
        src_map = batch.src_map if use_src_map else None
        (fn_map_state, memory_bank, memory_lengths,
         src_map) = decode_strategy.initialize(memory_bank, src_lengths,
                                               src_map)
        if fn_map_state is not None:
            self.model.decoder.map_state(fn_map_state)

        for step in range(decode_strategy.max_length):
            decoder_input = decode_strategy.current_predictions.view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank,
                batch,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=decode_strategy.batch_offset)

            decode_strategy.advance(log_probs, attn)
            any_finished = decode_strategy.is_finished.any()
            if any_finished:
                decode_strategy.update_finished()
                if decode_strategy.done:
                    break

            select_indices = decode_strategy.select_indices

            if any_finished:
                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(
                        x.index_select(1, select_indices) for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(1, select_indices)

                memory_lengths = memory_lengths.index_select(0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

            if parallel_paths > 1 or any_finished:
                self.model.decoder.map_state(
                    lambda state, dim: state.index_select(dim, select_indices))

        results['scores'] = decode_strategy.scores
        results['predictions'] = decode_strategy.predictions
        # predictions is a list whose length equals the batch size
        results['attention'] = decode_strategy.attention
        return results
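
A hypothetical driver for the loop above; translator, batch, and batch_size are assumptions, and the constructor arguments mirror the single-path GreedySearch signature used in Examples #1, #2, and #4:

strategy = GreedySearch(
    0, 1, 2, batch_size, 0,          # pad, bos, eos, batch size, min_length
    False, set(), False, 30, 1., 1)  # no blocking, max_length=30, temp=1, keep_topk=1
results = translator._sample_batch_with_strategy(batch, strategy)
for score, pred in zip(results['scores'], results['predictions']):
    print(score, pred)  # one entry per example in the batch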