Example #1
    def first_step(self, beam, expected_beam_scores, expected_len_pen):
        # no EOS's yet
        assert beam.is_finished.sum() == 0
        scores_1 = torch.log_softmax(torch.tensor(
            [[0, 0,  0, .3,   0, .51, .2, 0],
             [0, 0, 1.5,  0,   0,   0,  0, 0],
             [0, 0,  0,  0, .49, .48,  0, 0],
             [0, 0, 0, .2, .2, .2, .2, .2],
             [0, 0, 0, .2, .2, .2, .2, .2]]
        ), dim=1)
        scores_1 = scores_1.repeat(self.BATCH_SZ, 1)

        beam.advance(deepcopy(scores_1), self.random_attn())

        new_scores = scores_1 + expected_beam_scores.view(-1).unsqueeze(1)
        expected_beam_scores, unreduced_preds = new_scores\
            .view(self.BATCH_SZ, self.BEAM_SZ * self.N_WORDS)\
            .topk(self.BEAM_SZ, -1)
        expected_bptr_1 = unreduced_preds / self.N_WORDS
        # [5, 3, 2, 6, 0], so beam 2 predicts EOS!
        expected_preds_1 = unreduced_preds - expected_bptr_1 * self.N_WORDS
        self.assertTrue(beam.topk_log_probs.allclose(expected_beam_scores))
        self.assertTrue(beam.topk_scores.allclose(
            expected_beam_scores / expected_len_pen))
        self.assertTrue(beam.topk_ids.equal(expected_preds_1))
        self.assertTrue(beam.current_backptr.equal(expected_bptr_1))
        self.assertEqual(beam.is_finished.sum(), self.BATCH_SZ)
        self.assertTrue(beam.is_finished[:, 2].all())  # beam 2 finished
        beam.update_finished()
        self.assertFalse(beam.top_beam_finished.any())
        self.assertFalse(beam.done)
        return expected_beam_scores
Example #2
    def first_step(self, beam, expected_beam_scores, expected_len_pen):
        # no EOS's yet
        assert len(beam.finished) == 0
        scores_1 = torch.log_softmax(torch.tensor(
            [[0, 0,  0, .3,   0, .51, .2, 0],
             [0, 0, 1.5,  0,   0,   0,  0, 0],
             [0, 0,  0,  0, .49, .48,  0, 0],
             [0, 0, 0, .2, .2, .2, .2, .2],
             [0, 0, 0, .2, .2, .2, .2, .2]]
        ), dim=1)

        beam.advance(scores_1, torch.randn(self.BEAM_SZ, self.INP_SEQ_LEN))

        new_scores = scores_1 + expected_beam_scores.t()
        expected_beam_scores, unreduced_preds = new_scores.view(-1).topk(
            self.BEAM_SZ, 0, True, True)
        expected_bptr_1 = unreduced_preds / self.N_WORDS
        # [5, 3, 2, 6, 0], so beam 2 predicts EOS!
        expected_preds_1 = unreduced_preds - expected_bptr_1 * self.N_WORDS

        self.assertTrue(beam.scores.allclose(expected_beam_scores))
        self.assertTrue(beam.next_ys[-1].equal(expected_preds_1))
        self.assertTrue(beam.prev_ks[-1].equal(expected_bptr_1))
        self.assertEqual(len(beam.finished), 1)
        self.assertEqual(beam.finished[0][2], 2)  # beam 2 finished
        self.assertEqual(beam.finished[0][1], 2)  # finished on second step
        self.assertEqual(beam.finished[0][0],  # finished with correct score
                         expected_beam_scores[2] / expected_len_pen)
        self.assertFalse(beam.eos_top)
        self.assertFalse(beam.done)
        return expected_beam_scores
Example #3
    def second_step(self, beam, expected_beam_scores, expected_len_pen):
        # assumes beam 2 finished on last step
        scores_2 = torch.log_softmax(torch.tensor(
            [[0, 0,  0, .3,   0, .51, .2, 0],
             [0, 0, 0,  0,   0,   0,  0, 0],
             [0, 0,  0,  0, 5000, .48,  0, 0],  # beam 2 shouldn't continue
             [0, 0, 50, .2, .2, .2, .2, .2],  # beam 3 -> beam 0 should die
             [0, 0, 0, .2, .2, .2, .2, .2]]
        ), dim=1)

        beam.advance(scores_2, torch.randn(self.BEAM_SZ, self.INP_SEQ_LEN))

        new_scores = scores_2 + expected_beam_scores.unsqueeze(1)
        new_scores[2] = self.DEAD_SCORE  # ended beam 2 shouldn't continue
        expected_beam_scores, unreduced_preds = new_scores.view(-1).topk(
            self.BEAM_SZ, 0, True, True)
        expected_bptr_2 = unreduced_preds / self.N_WORDS
        # [2, 5, 3, 6, 0], so beam 0 predicts EOS!
        expected_preds_2 = unreduced_preds - expected_bptr_2 * self.N_WORDS
        # [-2.4879, -3.8910, -4.1010, -4.2010, -4.4010]
        self.assertTrue(beam.scores.allclose(expected_beam_scores))
        self.assertTrue(beam.next_ys[-1].equal(expected_preds_2))
        self.assertTrue(beam.prev_ks[-1].equal(expected_bptr_2))
        self.assertEqual(len(beam.finished), 2)
        # new beam 0 finished
        self.assertEqual(beam.finished[1][2], 0)
        # new beam 0 is old beam 3
        self.assertEqual(expected_bptr_2[0], 3)
        self.assertEqual(beam.finished[1][1], 3)  # finished on third step
        self.assertEqual(beam.finished[1][0],  # finished with correct score
                         expected_beam_scores[0] / expected_len_pen)
        self.assertTrue(beam.eos_top)
        self.assertFalse(beam.done)
        return expected_beam_scores
Example #4
    def third_step(self, beam, expected_beam_scores, expected_len_pen):
        # assumes beam 0 finished on last step
        scores_3 = torch.log_softmax(torch.tensor(
            [[0, 0,  5000, 0,   5000, .51, .2, 0],  # beam 0 shouldn't cont
             [0, 0, 0,  0,   0,   0,  0, 0],
             [0, 0,  0,  0, 0, 5000,  0, 0],
             [0, 0, 0, .2, .2, .2, .2, .2],
             [0, 0, 50, 0, .2, .2, .2, .2]]  # beam 4 -> beam 1 should die
        ), dim=1)

        beam.advance(scores_3, torch.randn(self.BEAM_SZ, self.INP_SEQ_LEN))

        new_scores = scores_3 + expected_beam_scores.unsqueeze(1)
        new_scores[0] = self.DEAD_SCORE  # ended beam 2 shouldn't continue
        expected_beam_scores, unreduced_preds = new_scores.view(-1).topk(
            self.BEAM_SZ, 0, True, True)
        expected_bptr_3 = unreduced_preds / self.N_WORDS
        # [5, 2, 6, 1, 0], so beam 1 predicts EOS!
        expected_preds_3 = unreduced_preds - expected_bptr_3 * self.N_WORDS
        self.assertTrue(beam.scores.allclose(expected_beam_scores))
        self.assertTrue(beam.next_ys[-1].equal(expected_preds_3))
        self.assertTrue(beam.prev_ks[-1].equal(expected_bptr_3))
        self.assertEqual(len(beam.finished), 3)
        # new beam 1 finished
        self.assertEqual(beam.finished[2][2], 1)
        # new beam 1 is old beam 4
        self.assertEqual(expected_bptr_3[1], 4)
        self.assertEqual(beam.finished[2][1], 4)  # finished on fourth step
        self.assertEqual(beam.finished[2][0],  # finished with correct score
                         expected_beam_scores[1] / expected_len_pen)
        self.assertTrue(beam.eos_top)
        self.assertTrue(beam.done)
        return expected_beam_scores
Example #5
    def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
        # this is also a test that when block_ngram_repeat=0,
        # repeating is acceptable
        beam_sz = 5
        batch_sz = 3
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            torch.device("cpu"), GlobalScorerStub(),
            min_length, 30, False, 0, set(),
            torch.randint(0, 30, (batch_sz,)), False, 0.)
        for i in range(min_length + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full(
                (batch_sz * beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score
            else:
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score

            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
                self.assertFalse(beam.done)
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertTrue(beam.is_finished[:, 0].all())
                beam.update_finished()
                self.assertTrue(beam.done)
Example #6
    def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
        # this is also a test that when block_ngram_repeat=0,
        # repeating is acceptable
        beam_sz = 5
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        beam = Beam(beam_sz, 0, 1, eos_idx, n_best=2,
                    exclusion_tokens=set(),
                    min_length=min_length,
                    global_scorer=GlobalScorerStub(),
                    block_ngram_repeat=0)
        for i in range(min_length + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full((beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx, j] = score
            else:
                word_probs[0, eos_idx] = valid_score_dist[0]
                word_probs[1, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx, j] = score

            attns = torch.randn(beam_sz)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertEqual(beam.finished[0][1], beam.min_length + 1)
                self.assertEqual(beam.finished[0][2], 1)
                self.assertFalse(beam.done)
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertEqual(beam.finished[1][1], beam.min_length + 2)
                self.assertEqual(beam.finished[1][2], 0)
                self.assertTrue(beam.done)
Example #7
 def init_step(self, beam):
     # init_preds: [4, 3, 5, 6, 7] - no EOS's
     init_scores = torch.log_softmax(torch.tensor(
         [[0, 0, 0, 4, 5, 3, 2, 1]], dtype=torch.float), dim=1)
     expected_beam_scores, expected_preds_0 = init_scores.topk(self.BEAM_SZ)
     beam.advance(init_scores, torch.randn(self.BEAM_SZ, self.INP_SEQ_LEN))
     self.assertTrue(beam.scores.allclose(expected_beam_scores))
     self.assertTrue(beam.next_ys[-1].equal(expected_preds_0[0]))
     self.assertFalse(beam.eos_top)
     self.assertFalse(beam.done)
     return expected_beam_scores
Example #8
    def test_doesnt_predict_eos_if_shorter_than_min_len(self):
        # beam 0 will always predict EOS. The other beams will predict
        # non-eos scores.
        for batch_sz in [1, 3]:
            beam_sz = 5
            n_words = 100
            _non_eos_idxs = [47, 51, 13, 88, 99]
            valid_score_dist = torch.log_softmax(torch.tensor(
                [6., 5., 4., 3., 2., 1.]), dim=0)
            min_length = 5
            eos_idx = 2
            lengths = torch.randint(0, 30, (batch_sz,))
            beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
                              torch.device("cpu"), GlobalScorerStub(),
                              min_length, 30, False, 0, set(),
                              lengths, False, 0.)
            all_attns = []
            for i in range(min_length + 4):
                # non-interesting beams are going to get dummy values
                word_probs = torch.full(
                    (batch_sz * beam_sz, n_words), -float('inf'))
                if i == 0:
                    # "best" prediction is eos - that should be blocked
                    word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                    # include at least beam_sz predictions OTHER than EOS
                    # that are greater than -1e20
                    for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                        word_probs[0::beam_sz, j] = score
                else:
                    # predict eos in beam 0
                    word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                    # provide beam_sz other good predictions
                    for k, (j, score) in enumerate(
                            zip(_non_eos_idxs, valid_score_dist[1:])):
                        beam_idx = min(beam_sz-1, k)
                        word_probs[beam_idx::beam_sz, j] = score

                attns = torch.randn(1, batch_sz * beam_sz, 53)
                all_attns.append(attns)
                beam.advance(word_probs, attns)
                if i < min_length:
                    expected_score_dist = \
                        (i+1) * valid_score_dist[1:].unsqueeze(0)
                    self.assertTrue(
                        beam.topk_log_probs.allclose(
                            expected_score_dist))
                elif i == min_length:
                    # now the top beam has ended and no others have
                    self.assertTrue(beam.is_finished[:, 0].eq(1).all())
                    self.assertTrue(beam.is_finished[:, 1:].eq(0).all())
                else:  # i > min_length
                    # not of interest, but want to make sure it keeps running
                    # since only beam 0 terminates and n_best = 2
                    pass
Example #9
    def test_doesnt_predict_eos_if_shorter_than_min_len(self):
        # beam 0 will always predict EOS. The other beams will predict
        # non-eos scores.
        # this is also a test that when block_ngram_repeat=0,
        # repeating is acceptable
        beam_sz = 5
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        beam = Beam(beam_sz, 0, 1, eos_idx, n_best=2,
                    exclusion_tokens=set(),
                    min_length=min_length,
                    global_scorer=GlobalScorerStub(),
                    block_ngram_repeat=0)
        for i in range(min_length + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full((beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0, j] = score
            else:
                # predict eos in beam 0
                word_probs[0, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx, j] = score

            attns = torch.randn(beam_sz)
            beam.advance(word_probs, attns)
            if i < min_length:
                expected_score_dist = (i+1) * valid_score_dist[1:]
                self.assertTrue(beam.scores.allclose(expected_score_dist))
            elif i == min_length:
                # now the top beam has ended and no others have
                # first beam finished had length beam.min_length
                self.assertEqual(beam.finished[0][1], beam.min_length + 1)
                # first beam finished was 0
                self.assertEqual(beam.finished[0][2], 0)
            else:  # i > min_length
                # not of interest, but want to make sure it keeps running
                # since only beam 0 terminates and n_best = 2
                pass
Example #10
 def init_step(self, beam, expected_len_pen):
     # init_preds: [4, 3, 5, 6, 7] - no EOS's
     init_scores = torch.log_softmax(torch.tensor(
         [[0, 0, 0, 4, 5, 3, 2, 1]], dtype=torch.float), dim=1)
     init_scores = deepcopy(init_scores.repeat(
         self.BATCH_SZ * self.BEAM_SZ, 1))
     new_scores = init_scores + beam.topk_log_probs.view(-1).unsqueeze(1)
     expected_beam_scores, expected_preds_0 = new_scores \
         .view(self.BATCH_SZ, self.BEAM_SZ * self.N_WORDS) \
         .topk(self.BEAM_SZ, dim=-1)
     beam.advance(deepcopy(init_scores), self.random_attn())
     self.assertTrue(beam.topk_log_probs.allclose(expected_beam_scores))
     self.assertTrue(beam.topk_ids.equal(expected_preds_0))
     self.assertFalse(beam.is_finished.any())
     self.assertFalse(beam.done)
     return expected_beam_scores
Example #11
    def test_doesnt_predict_eos_if_shorter_than_min_len(self):
        # batch 0 will always predict EOS. The other batches will predict
        # non-eos scores.
        for batch_sz in [1, 3]:
            n_words = 100
            _non_eos_idxs = [47]
            valid_score_dist = torch.log_softmax(torch.tensor(
                [6., 5.]), dim=0)
            min_length = 5
            eos_idx = 2
            lengths = torch.randint(0, 30, (batch_sz,))
            samp = RandomSampling(
                0, 1, 2, batch_sz, torch.device("cpu"), min_length,
                False, set(), False, 30, 1., 1, lengths)
            all_attns = []
            for i in range(min_length + 4):
                word_probs = torch.full(
                    (batch_sz, n_words), -float('inf'))
                # "best" prediction is eos - that should be blocked
                word_probs[0, eos_idx] = valid_score_dist[0]
                # include at least one prediction OTHER than EOS
                # that is greater than -1e20
                word_probs[0, _non_eos_idxs[0]] = valid_score_dist[1]
                word_probs[1:, _non_eos_idxs[0] + i] = 0

                attns = torch.randn(1, batch_sz, 53)
                all_attns.append(attns)
                samp.advance(word_probs, attns)
                if i < min_length:
                    self.assertTrue(
                        samp.topk_scores[0].allclose(valid_score_dist[1]))
                    self.assertTrue(
                        samp.topk_scores[1:].eq(0).all())
                elif i == min_length:
                    # now batch 0 has ended and no others have
                    self.assertTrue(samp.is_finished[0, :].eq(1).all())
                    self.assertTrue(samp.is_finished[1:, 1:].eq(0).all())
                else:  # i > min_length
                    break
Example #12
    def second_step(self, beam, expected_beam_scores, expected_len_pen):
        # assumes beam 2 finished on last step
        scores_2 = torch.log_softmax(torch.tensor(
            [[0, 0,  0, .3,   0, .51, .2, 0],
             [0, 0, 0,  0,   0,   0,  0, 0],
             [0, 0,  0,  0, 5000, .48,  0, 0],  # beam 2 shouldn't continue
             [0, 0, 50, .2, .2, .2, .2, .2],  # beam 3 -> beam 0 should die
             [0, 0, 0, .2, .2, .2, .2, .2]]
        ), dim=1)
        scores_2 = scores_2.repeat(self.BATCH_SZ, 1)

        beam.advance(deepcopy(scores_2), self.random_attn())

        # ended beam 2 shouldn't continue
        expected_beam_scores[:, 2::self.BEAM_SZ] = self.DEAD_SCORE
        new_scores = scores_2 + expected_beam_scores.view(-1).unsqueeze(1)
        expected_beam_scores, unreduced_preds = new_scores\
            .view(self.BATCH_SZ, self.BEAM_SZ * self.N_WORDS)\
            .topk(self.BEAM_SZ, -1)
        expected_bptr_2 = unreduced_preds / self.N_WORDS
        # [2, 5, 3, 6, 0] repeat self.BATCH_SZ, so beam 0 predicts EOS!
        expected_preds_2 = unreduced_preds - expected_bptr_2 * self.N_WORDS
        # [-2.4879, -3.8910, -4.1010, -4.2010, -4.4010] repeat self.BATCH_SZ
        self.assertTrue(beam.topk_log_probs.allclose(expected_beam_scores))
        self.assertTrue(beam.topk_scores.allclose(
            expected_beam_scores / expected_len_pen))
        self.assertTrue(beam.topk_ids.equal(expected_preds_2))
        self.assertTrue(beam.current_backptr.equal(expected_bptr_2))
        # another beam is finished in all batches
        self.assertEqual(beam.is_finished.sum(), self.BATCH_SZ)
        # new beam 0 finished
        self.assertTrue(beam.is_finished[:, 0].all())
        # new beam 0 is old beam 3
        self.assertTrue(expected_bptr_2[:, 0].eq(3).all())
        beam.update_finished()
        self.assertTrue(beam.top_beam_finished.all())
        self.assertFalse(beam.done)
        return expected_beam_scores
Example #13
    def third_step(self, beam, expected_beam_scores, expected_len_pen):
        # assumes beam 0 finished on last step
        scores_3 = torch.log_softmax(torch.tensor(
            [[0, 0,  5000, 0,   5000, .51, .2, 0],  # beam 0 shouldn't cont
             [0, 0, 0,  0,   0,   0,  0, 0],
             [0, 0,  0,  0, 0, 5000,  0, 0],
             [0, 0, 0, .2, .2, .2, .2, .2],
             [0, 0, 50, 0, .2, .2, .2, .2]]  # beam 4 -> beam 1 should die
        ), dim=1)
        scores_3 = scores_3.repeat(self.BATCH_SZ, 1)

        beam.advance(deepcopy(scores_3), self.random_attn())

        expected_beam_scores[:, 0::self.BEAM_SZ] = self.DEAD_SCORE
        new_scores = scores_3 + expected_beam_scores.view(-1).unsqueeze(1)
        expected_beam_scores, unreduced_preds = new_scores\
            .view(self.BATCH_SZ, self.BEAM_SZ * self.N_WORDS)\
            .topk(self.BEAM_SZ, -1)
        expected_bptr_3 = unreduced_preds / self.N_WORDS
        # [5, 2, 6, 1, 0] repeat self.BATCH_SZ, so beam 1 predicts EOS!
        expected_preds_3 = unreduced_preds - expected_bptr_3 * self.N_WORDS
        self.assertTrue(beam.topk_log_probs.allclose(
            expected_beam_scores))
        self.assertTrue(beam.topk_scores.allclose(
            expected_beam_scores / expected_len_pen))
        self.assertTrue(beam.topk_ids.equal(expected_preds_3))
        self.assertTrue(beam.current_backptr.equal(expected_bptr_3))
        self.assertEqual(beam.is_finished.sum(), self.BATCH_SZ)
        # new beam 1 finished
        self.assertTrue(beam.is_finished[:, 1].all())
        # new beam 1 is old beam 4
        self.assertTrue(expected_bptr_3[:, 1].eq(4).all())
        beam.update_finished()
        self.assertTrue(beam.top_beam_finished.all())
        self.assertTrue(beam.done)
        return expected_beam_scores
Example #14
 def forward(self, embeds):
     x = F.elu(self.fc1(embeds))
     x = F.elu(self.fc2(x))
     logits = torch.log_softmax(x, 1)
     return logits
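A head that returns log-probabilities like the one above is typically paired with the negative log-likelihood loss. A minimal sketch with hypothetical shapes (batch of 4, 10 classes), independent of the model above:

import torch
import torch.nn.functional as F

log_probs = torch.log_softmax(torch.randn(4, 10), dim=1)  # stand-in for the model's output
targets = torch.randint(0, 10, (4,))
loss = F.nll_loss(log_probs, targets)  # equivalent to F.cross_entropy on the raw logits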
Example #15
 def forward(self, x):
     out = torch.relu(self.fc1(x))
     out = self.fc2(self.dropout(out))
     out = torch.log_softmax(out, dim=-1)
     return out
Example #16
def getPredictions(y_pred):
	y_pred_softmax = torch.log_softmax(y_pred, dim = 1)
	_, y_pred_tags = torch.max(y_pred_softmax, dim = 1)
	return y_pred_tags
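Because log_softmax is a strictly monotonic transform of each row, the tags returned by getPredictions coincide with a plain argmax over the raw logits. A small sketch with hypothetical logits:

import torch

y_pred = torch.randn(4, 6)  # hypothetical raw logits: 4 samples, 6 classes
assert torch.equal(getPredictions(y_pred), torch.argmax(y_pred, dim=1))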
Example #17
    def test_returns_correct_scores_deterministic(self):
        for batch_sz in [1, 13]:
            for temp in [1., 3.]:
                n_words = 100
                _non_eos_idxs = [47, 51, 13, 88, 99]
                valid_score_dist_1 = torch.log_softmax(torch.tensor(
                    [6., 5., 4., 3., 2., 1.]), dim=0)
                valid_score_dist_2 = torch.log_softmax(torch.tensor(
                    [6., 1.]), dim=0)
                eos_idx = 2
                lengths = torch.randint(0, 30, (batch_sz,))
                samp = RandomSampling(
                    0, 1, 2, batch_sz, torch.device("cpu"), 0,
                    False, set(), False, 30, temp, 1, lengths)

                # initial step
                i = 0
                word_probs = torch.full(
                    (batch_sz, n_words), -float('inf'))
                # batch 0 dies on step 0
                word_probs[0, eos_idx] = valid_score_dist_1[0]
                # include at least one prediction OTHER than EOS
                # that is greater than -1e20
                word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
                word_probs[1:, _non_eos_idxs[0] + i] = 0

                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)
                self.assertTrue(samp.is_finished[0].eq(1).all())
                samp.update_finished()
                self.assertEqual(
                    samp.scores[0], [valid_score_dist_1[0] / temp])
                if batch_sz == 1:
                    self.assertTrue(samp.done)
                    continue
                else:
                    self.assertFalse(samp.done)

                # step 2
                i = 1
                word_probs = torch.full(
                    (batch_sz - 1, n_words), -float('inf'))
                # (old) batch 8 dies on step 1
                word_probs[7, eos_idx] = valid_score_dist_2[0]
                word_probs[0:7, _non_eos_idxs[:2]] = valid_score_dist_2
                word_probs[8:, _non_eos_idxs[:2]] = valid_score_dist_2

                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)

                self.assertTrue(samp.is_finished[7].eq(1).all())
                samp.update_finished()
                self.assertEqual(
                    samp.scores[8], [valid_score_dist_2[0] / temp])

                # step 3
                i = 2
                word_probs = torch.full(
                    (batch_sz - 2, n_words), -float('inf'))
                # everything dies
                word_probs[:, eos_idx] = 0

                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)

                self.assertTrue(samp.is_finished.eq(1).all())
                samp.update_finished()
                for b in range(batch_sz):
                    if b != 0 and b != 8:
                        self.assertEqual(samp.scores[b], [0])
                self.assertTrue(samp.done)
Example #18
    def forward(self, x, y=None, y_bar=None, add_bias=True):
        # Stick x into h for cleaner for loops without flow control
        h = x
        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h)
        # Apply global sum pooling as in SN-GAN
        h = torch.sum(self.activation(h), [2, 3])
        # Get initial class-unconditional output
        out = self.linear(h)

        out_mi = None
        out_c = None
        tP = None
        tQ = None
        tP_bar = None
        tQ_bar = None
        if self.Projection:
            out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
        if self.AC:
            out_c = self.linear_c(h)
        if self.TAC:
            out_mi = self.linear_mi(h)
        if self.TP:
            cP = self.embed_cP(y) if add_bias else 0.
            out_P = self.linear_P(h)
            if self.use_softmax:
                logP = torch.log_softmax(out_P, dim=1)
                tP = logP[range(y.size(0)), y].view(y.size(0), 1) + cP
            else:
                tP = torch.sum(self.embed_vP(y) * h, 1,
                               keepdim=True) + out_P + cP
            if y_bar is not None:
                cP_bar = self.embed_cP(y_bar) if add_bias else 0.
                if self.use_softmax:
                    tP_bar = logP[range(y.size(0)), y_bar].view(y.size(0),
                                                                1) + cP_bar
                else:
                    tP_bar = torch.sum(self.embed_vP(y_bar) * h,
                                       1,
                                       keepdim=True) + out_P + cP_bar
        if self.TQ:
            cQ = self.embed_cQ(y) if add_bias else 0.
            out_Q = self.linear_Q(h)
            if self.use_softmax:
                logQ = torch.log_softmax(out_Q, dim=1)
                tQ = logQ[range(y.size(0)), y].view(y.size(0), 1) + cQ
            else:
                tQ = torch.sum(self.embed_vQ(y) * h, 1,
                               keepdim=True) + out_Q + cQ
            if y_bar is not None:
                cQ_bar = self.embed_cQ(y_bar) if add_bias else 0.
                if self.use_softmax:
                    tQ_bar = logQ[range(y.size(0)), y_bar].view(y.size(0),
                                                                1) + cQ_bar
                else:
                    tQ_bar = torch.sum(self.embed_vQ(y_bar) * h,
                                       1,
                                       keepdim=True) + out_Q + cQ_bar

        return out, out_mi, out_c, tP, tP_bar, tQ, tQ_bar
Example #19
def CustomKLDiv(logits, labels, T, dim = 1):
    logits = torch.log_softmax(logits/T, dim=dim)
    labels = torch.softmax(labels/T, dim=dim)
    kldiv = nn.KLDivLoss()(logits,labels)
    return kldiv
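A minimal usage sketch for the helper above, with hypothetical student/teacher logits and temperature (assuming nn is torch.nn, as in the snippet's source module):

import torch

student_logits = torch.randn(8, 10)  # hypothetical student outputs
teacher_logits = torch.randn(8, 10)  # hypothetical teacher outputs
loss = CustomKLDiv(student_logits, teacher_logits, T=4.0, dim=1)

Note that nn.KLDivLoss() defaults to element-wise 'mean' reduction; distillation setups following Hinton et al. often use reduction='batchmean' and scale the result by T**2, which this helper leaves to the caller.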
Example #20
    def test_beam_returns_attn_with_correct_length(self):
        beam_sz = 5
        batch_sz = 3
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        inp_lens = torch.randint(1, 30, (batch_sz,))
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            GlobalScorerStub(),
            min_length, 30, True, 0, set(),
            False, 0.)
        device_init = torch.zeros(1, 1)
        _, _, inp_lens, _ = beam.initialize(device_init, inp_lens)
        # inp_lens is tiled in initialize, reassign to make attn match
        for i in range(min_length + 2):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full(
                (batch_sz * beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz - 1, k)
                    word_probs[beam_idx::beam_sz, j] = score
            else:
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz - 1, k)
                    word_probs[beam_idx::beam_sz, j] = score

            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
                # no top beams are finished yet
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
                self.assertFalse(beam.done)
                # no top beams are finished yet
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertTrue(beam.is_finished[:, 0].all())
                beam.update_finished()
                self.assertTrue(beam.done)
                # top beam is finished now so there are attentions
                for b in range(batch_sz):
                    # two beams are finished in each batch
                    self.assertEqual(len(beam.attention[b]), 2)
                    for k in range(2):
                        # second dim is cut down to the non-padded src length
                        self.assertEqual(beam.attention[b][k].shape[-1],
                                         inp_lens[b])
                    # first dim is equal to the time of death
                    # (beam 0 died at current step - adjust for SOS)
                    self.assertEqual(beam.attention[b][0].shape[0], i + 1)
                    # (beam 1 died at last step - adjust for SOS)
                    self.assertEqual(beam.attention[b][1].shape[0], i)
                # behavior gets weird when beam is already done so just stop
                break
Example #21
 def forward(self, batch, message_vector):
     logits = self.linear_message(message_vector)
     return torch.log_softmax(logits, dim=-1)
Example #22
    def forward(self, message, _):
        x = self.message_inp(message)
        x = self.fc(x)

        return torch.log_softmax(x, dim=1)
Example #23
 def log_prob(self, x):
     x = self._pad(x)
     log_prob_x = self.component_distribution.log_prob(x)  # [S, B, k]
     log_mix_prob = torch.log_softmax(self.mixture_distribution.logits,
                                      dim=-1)  # [B, k]
     return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1)  # [S, B]
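The pattern above (per-component log-probabilities plus log mixture weights, reduced with logsumexp) is the same marginal log-likelihood that torch.distributions.MixtureSameFamily computes. A sanity-check sketch with a hypothetical 1-D Gaussian mixture:

import torch
from torch import distributions as D

mix = D.Categorical(logits=torch.randn(3))            # 3 mixture weights
comp = D.Normal(torch.randn(3), torch.rand(3) + 0.1)  # 3 Gaussian components
gmm = D.MixtureSameFamily(mix, comp)

x = torch.randn(5)
manual = torch.logsumexp(
    comp.log_prob(x.unsqueeze(-1)) + torch.log_softmax(mix.logits, dim=-1),
    dim=-1)
assert torch.allclose(gmm.log_prob(x), manual)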
Example #24
    def test_returns_correct_scores_non_deterministic(self):
        for batch_sz in [1, 13]:
            for temp in [1., 3.]:
                n_words = 100
                _non_eos_idxs = [47, 51, 13, 88, 99]
                valid_score_dist_1 = torch.log_softmax(torch.tensor(
                    [6., 5., 4., 3., 2., 1.]), dim=0)
                valid_score_dist_2 = torch.log_softmax(torch.tensor(
                    [6., 1.]), dim=0)
                eos_idx = 2
                lengths = torch.randint(0, 30, (batch_sz,))
                samp = RandomSampling(
                    0, 1, 2, batch_sz, torch.device("cpu"), 0,
                    False, set(), False, 30, temp, 2, lengths)

                # initial step
                i = 0
                for _ in range(100):
                    word_probs = torch.full(
                        (batch_sz, n_words), -float('inf'))
                    # batch 0 dies on step 0
                    word_probs[0, eos_idx] = valid_score_dist_1[0]
                    # include at least one prediction OTHER than EOS
                    # that is greater than -1e20
                    word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
                    word_probs[1:, _non_eos_idxs[0] + i] = 0

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished[0].eq(1).all():
                        break
                else:
                    self.fail("Batch 0 never ended (very unlikely but maybe "
                              "due to stochasticisty. If so, please increase "
                              "the range of the for-loop.")
                samp.update_finished()
                self.assertEqual(
                    samp.scores[0], [valid_score_dist_1[0] / temp])
                if batch_sz == 1:
                    self.assertTrue(samp.done)
                    continue
                else:
                    self.assertFalse(samp.done)

                # step 2
                i = 1
                for _ in range(100):
                    word_probs = torch.full(
                        (batch_sz - 1, n_words), -float('inf'))
                    # (old) batch 8 dies on step 1
                    word_probs[7, eos_idx] = valid_score_dist_2[0]
                    word_probs[0:7, _non_eos_idxs[:2]] = valid_score_dist_2
                    word_probs[8:, _non_eos_idxs[:2]] = valid_score_dist_2

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished[7].eq(1).all():
                        break
                else:
                    self.fail("Batch 8 never ended (very unlikely but maybe "
                              "due to stochasticisty. If so, please increase "
                              "the range of the for-loop.")

                samp.update_finished()
                self.assertEqual(
                    samp.scores[8], [valid_score_dist_2[0] / temp])

                # step 3
                i = 2
                for _ in range(250):
                    word_probs = torch.full(
                        (samp.alive_seq.shape[0], n_words), -float('inf'))
                    # everything dies
                    word_probs[:, eos_idx] = 0

                    attns = torch.randn(1, batch_sz, 53)
                    samp.advance(word_probs, attns)
                    if samp.is_finished.any():
                        samp.update_finished()
                    if samp.is_finished.eq(1).all():
                        break
                else:
                    self.fail("All batches never ended (very unlikely but "
                              "maybe due to stochasticisty. If so, please "
                              "increase the range of the for-loop.")

                for b in range(batch_sz):
                    if b != 0 and b != 8:
                        self.assertEqual(samp.scores[b], [0])
                self.assertTrue(samp.done)
Example #25
    def test_beam_returns_attn_with_correct_length(self):
        beam_sz = 5
        batch_sz = 3
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        inp_lens = torch.randint(1, 30, (batch_sz,))
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            torch.device("cpu"), GlobalScorerStub(),
            min_length, 30, True, 0, set(),
            inp_lens, False, 0.)
        for i in range(min_length + 2):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full(
                (batch_sz * beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score
            else:
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx::beam_sz, j] = score

            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
                # no top beams are finished yet
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
                self.assertFalse(beam.done)
                # no top beams are finished yet
                for b in range(batch_sz):
                    self.assertEqual(beam.attention[b], [])
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertTrue(beam.is_finished[:, 0].all())
                beam.update_finished()
                self.assertTrue(beam.done)
                # top beam is finished now so there are attentions
                for b in range(batch_sz):
                    # two beams are finished in each batch
                    self.assertEqual(len(beam.attention[b]), 2)
                    for k in range(2):
                        # second dim is cut down to the non-padded src length
                        self.assertEqual(beam.attention[b][k].shape[-1],
                                         inp_lens[b])
                    # first dim is equal to the time of death
                    # (beam 0 died at current step - adjust for SOS)
                    self.assertEqual(beam.attention[b][0].shape[0], i+1)
                    # (beam 1 died at last step - adjust for SOS)
                    self.assertEqual(beam.attention[b][1].shape[0], i)
                # behavior gets weird when beam is already done so just stop
                break
Example #26
def cross_entropy_loss(y_pred, y_label):
    if y_pred.size() == y_label.size():
        return torch.mean(-torch.sum(torch.log_softmax(y_pred, dim=-1) * y_label, dim=-1))
    else:
        return torch.nn.CrossEntropyLoss()(y_pred, y_label.long())
Example #27
def manual_CE(predictions, labels):
    loss = -torch.mean(
        torch.sum(labels * torch.log_softmax(predictions, dim=1), dim=1))
    return loss
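For one-hot labels, the helper above reduces to the standard cross-entropy on integer class targets. A quick sanity check with hypothetical shapes (5 samples, 7 classes):

import torch
import torch.nn.functional as F

preds = torch.randn(5, 7)                         # hypothetical logits
targets = torch.randint(0, 7, (5,))
one_hot = F.one_hot(targets, num_classes=7).float()
assert torch.allclose(manual_CE(preds, one_hot),
                      F.cross_entropy(preds, targets))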
Example #28
def train_mpl(teacher_model,
              student_model,
              labeled_dl,
              unlabeled_dl,
              batch_size,
              dataset,
              num_epochs=10,
              learning_rate=1e-3,
              uda_threshold=1.0,
              weight_u=1,
              n_student_steps=1,
              t_optimizer=None,
              s_optimizer=None,
              version='v1',
              model_save_dir=None):
    # Setup wandb
    config = wandb.config
    config.update({
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "uda_threshold": uda_threshold,
        "weight_u": weight_u,
        "n_student_steps": n_student_steps,
    })

    teacher_model.train()
    student_model.train()

    num_labeled = len(labeled_dl)
    num_unlabeled = len(unlabeled_dl)

    num_iter = num_labeled

    # Setup definitions
    if not t_optimizer:
        t_optimizer = torch.optim.Adam(teacher_model.parameters(),
                                       lr=learning_rate,
                                       weight_decay=1e-5)
        t_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            t_optimizer, num_epochs * (num_iter // batch_size))
    if not s_optimizer:
        s_optimizer = torch.optim.Adam(student_model.parameters(),
                                       lr=learning_rate,
                                       weight_decay=1e-5)
        s_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            s_optimizer, num_epochs * (num_iter // batch_size))

    # Iterate for num_epochs
    global_step = 0
    total_batches = num_iter
    for epoch in range(num_epochs):
        batch_num = 0
        running_teacher_loss = 0.0
        running_student_loss = 0.0
        labeled_iter = iter(labeled_dl)
        unlabeled_iter = iter(unlabeled_dl)

        for i in range(num_iter):
            # We get one unlabeled image and one labeled image batch
            image_l, label = next(labeled_iter)
            images_u, _ = next(unlabeled_iter)

            image_u, image_u_aug = images_u

            # 0) resize image to what PyTorch wants and convert to device
            if dataset == 'fashion_mnist':
                image_l = image_l.view(-1, 1, 28, 28).to(device)
                image_u = image_u.view(-1, 1, 28, 28).to(device)
            elif dataset == 'imagenet':
                image_l = image_l.view(-1, 3, 224, 224).to(device)
                image_u = image_u.view(-1, 3, 224, 224).to(device)
            elif dataset == 'cifar10':
                image_l = image_l.view(-1, 3, 32, 32).to(device)
                image_u = image_u.view(-1, 3, 32, 32).to(device)

            label = label.type(torch.LongTensor).to(device)

            # 1) pass labeled image through teacher and save the loss for future backprop
            t_logits = teacher_model(image_l)
            t_l_loss = F.cross_entropy(t_logits, label)

            # 2) pass labeled image through student and save the loss
            with torch.no_grad():  # we don't want to update student
                s_logits_l = student_model(image_l)
            s_l_loss_1 = F.cross_entropy(
                s_logits_l, label)  # cross_entropy internally takes the average

            # 3) generate pseudo labels from teacher
            for _ in range(n_student_steps):
                mpl_image_u = teacher_model(image_u)
                soft_mpl_image_u = torch.softmax(
                    mpl_image_u.detach(), dim=-1
                )  # don't propagate gradients into teacher so use .detach()

                # 4) pass unlabeled through student, calculate gradients, and optimize student
                s_logits_u = student_model(image_u)
                s_mpl_loss = F.binary_cross_entropy_with_logits(
                    s_logits_u, soft_mpl_image_u.detach())
                s_mpl_loss.backward(
                )  # calculate gradients for student network
                s_optimizer.step(
                )  # step in the direction of gradients for student network
                s_optimizer.zero_grad()
            # We will clear out gradients at the end

            # 5) pass labeled data through updated student and save the loss
            with torch.no_grad():  # we don't want to update student
                s_logits_l_updated = student_model(image_l)
            s_l_loss_2 = F.cross_entropy(
                s_logits_l_updated,
                label)  # cross_entropy internally takes the average

            # more details about the mpl loss: https://github.com/google-research/google-research/issues/534
            # NOTE: I'm using soft labels (requires BCE not just CE) instead of hard labels to train so there may be some differences with the reference code
            # the difference between the losses is an approximation of the dot product via a taylor expansion
            dot_product = s_l_loss_2 - s_l_loss_1
            # with hard labels, use log softmax trick from REINFORCE to compute gradients which we then scale with the dot product
            # http://stillbreeze.github.io/REINFORCE-vs-Reparameterization-trick/
            # with soft labels, I have no idea how we can propagate the gradients into the teacher so I use hard labels here
            max_probs, hard_pseudo_label = torch.max(mpl_image_u.detach(),
                                                     dim=-1)
            t_mpl_loss = dot_product * F.cross_entropy(mpl_image_u,
                                                       hard_pseudo_label)

            # 6) calculate unsupervised distribution alignment (UDA) loss of teacher on unlabeled images
            t_logits_image_u_aug = teacher_model(image_u_aug)
            uda_loss_mask = (max_probs >= uda_threshold).float()
            t_uda_loss = torch.mean(-(soft_mpl_image_u * torch.log_softmax(
                t_logits_image_u_aug, dim=-1)).sum(dim=-1) * uda_loss_mask)

            # 7) calculate teacher loss and optimize teacher
            t_u_loss = t_mpl_loss + t_uda_loss
            t_loss = t_l_loss + (weight_u * t_u_loss)
            t_loss.backward()
            # DEBUG: print(teacher_model.encoder.gate[0].weight.grad) # verify that the teacher has a gradient
            t_optimizer.step()
            # We will clear out gradients at the end

            t_scheduler.step()
            s_scheduler.step()

            # 8) clear out gradients of teacher and student
            teacher_model.zero_grad()
            student_model.zero_grad()

            # 9) display current training information and update batch information
            global_step += 1
            batch_num += 1
            running_teacher_loss += t_loss.item()
            running_student_loss += s_l_loss_2.item()
            if global_step % 100 == 0:
                print(
                    'Epoch:{} Batch:{}/{} Teacher Loss:{:.4f} Student Loss:{:.4f}'
                    .format(epoch + 1, batch_num, total_batches, t_loss.item(),
                            s_l_loss_2.item()))
                wandb.log({
                    'batch_teacher_loss': t_loss.item(),
                    'batch_student_loss': s_l_loss_2.item(),
                    't_lr': t_optimizer.param_groups[0]['lr'],
                    's_lr': s_optimizer.param_groups[0]['lr']
                })

        # display information for each epoch
        print('Epoch:{} Teacher Loss:{:.4f} Student Loss:{:.4f}'.format(
            epoch + 1, running_teacher_loss / num_iter,
            running_student_loss / num_iter))
        wandb.log({
            'epoch': epoch + 1,
            'teacher_loss': running_teacher_loss / num_iter,
            'student_loss': running_student_loss / num_iter
        })
        if model_save_dir is not None:
            checkpoint = {
                't_optimizer': t_optimizer.state_dict(),
                's_optimizer': s_optimizer.state_dict(),
                'teacher_model': teacher_model.state_dict(),
                'student_model': student_model.state_dict()
            }
            os.makedirs(model_save_dir, exist_ok=True)
            filename = f"{epoch+1}.pt"
            torch.save(checkpoint, os.path.join(model_save_dir, filename))
Example #29
def test_step(datax, datay, Ns, Nc, Nq):
    Qx, Qy = protonet(datax, datay, Ns, Nc, Nq, np.unique(datay))
    pred = torch.log_softmax(Qx, dim=-1)
    loss = F.nll_loss(pred, Qy)
    acc = torch.mean((torch.argmax(pred, 1) == Qy).float())
    return loss, acc
Example #30
 def forward(self, seq):
     ret = torch.log_softmax(self.fc(seq), dim=-1)
     return ret
Example #31
    def _record_artifacts(self, info, _config):
        epoch = info['epoch']
        artifact_storage_interval = _config['model_debug'][
            'artifact_storage_interval']
        results_dir = log_dir_path('results')

        if epoch % artifact_storage_interval == 0:

            # Data
            with torch.no_grad():

                self.model.eval()

                test_data = next(iter(self.test_dataloader))
                m_data = test_data[1]
                t_data = torch.nn.functional.one_hot(test_data[0], num_classes=10).float()

                if self.model.use_cuda:
                    m_data = m_data.cuda()
                    t_data = t_data.cuda()

                # Generate modalities
                mod_recons, nx_recons = self.model.generate([m_data, t_data])
                m_out, t_out = mod_recons[0], mod_recons[1]
                nx_m_m_out, nx_m_t_out = nx_recons[0][0], nx_recons[0][1]
                nx_t_m_out, nx_t_t_out = nx_recons[1][0], nx_recons[1][1]

                # Mnist Recon
                m_out_comp = torch.cat([m_data.view(-1, 1, 28, 28).cpu(), m_out.view(-1, 1, 28, 28).cpu()])
                nx_m_m_comp = torch.cat([m_data.view(-1, 1, 28, 28).cpu(), nx_m_m_out.view(-1, 1, 28, 28).cpu()])
                nx_t_m_comp = torch.cat([m_data.view(-1, 1, 28, 28).cpu(), nx_t_m_out.view(-1, 1, 28, 28).cpu()])

                # Text Recon
                t_res = np.argmax(torch.log_softmax(t_out, dim=-1).cpu().numpy(), axis=1).tolist()
                t_res_str = ''
                for i, item in enumerate(t_res):
                    t_res_str += str(item) + " "

                nx_m_t_res = np.argmax(torch.log_softmax(nx_m_t_out, dim=-1).cpu().numpy(), axis=1).tolist()
                nx_m_t_res_str = ''
                for i, item in enumerate(nx_m_t_res):
                    nx_m_t_res_str += str(item) + " "

                nx_t_t_res = np.argmax(torch.log_softmax(nx_t_t_out, dim=-1).cpu().numpy(), axis=1).tolist()
                nx_t_t_res_str = ''
                for i, item in enumerate(nx_t_t_res):
                    nx_t_t_res_str += str(item) + " "

            # Save data

            torchvision.utils.save_image(torchvision.utils.make_grid(m_out_comp,
                                                                     padding=5,
                                                                     pad_value=.5,
                                                                     nrow=m_data.size(0)),
                                         os.path.join(results_dir, 'm_comp_e' + str(epoch) + '.png'))
            ex.add_artifact(os.path.join(results_dir, "m_comp_e" + str(epoch) + '.png'),
                            name="image_recon_e" + str(epoch) + '.png')


            torchvision.utils.save_image(torchvision.utils.make_grid(nx_m_m_comp,
                                                                     padding=5,
                                                                     pad_value=.5,
                                                                     nrow=m_data.size(0)),
                                         os.path.join(results_dir, 'nx_m_m_comp_e' + str(epoch) + '.png'))
            ex.add_artifact(os.path.join(results_dir, "nx_m_m_comp_e" + str(epoch) + '.png'),
                            name="image_nexus_image_recon_e" + str(epoch) + '.png')


            torchvision.utils.save_image(torchvision.utils.make_grid(nx_t_m_comp,
                                                                     padding=5,
                                                                     pad_value=.5,
                                                                     nrow=m_data.size(0)),
                                         os.path.join(results_dir, 'nx_t_m_comp_e' + str(epoch) + '.png'))
            ex.add_artifact(os.path.join(results_dir, "nx_t_m_comp_e" + str(epoch) + '.png'),
                            name="symbol_nexus_image_recon_e" + str(epoch) + '.png')


            with open(os.path.join(results_dir, 't_res_str_e' + str(epoch) + '.txt'), "w") as symbol_file:
                print(t_res_str, file=symbol_file)
            ex.add_artifact(os.path.join(results_dir, "t_res_str_e" + str(epoch) + '.txt'),
                            name="symbol_recon_e" + str(epoch) + '.txt')

            with open(os.path.join(results_dir, 'nx_t_t_res_str_e' + str(epoch) + '.txt'), "w") as symbol_file:
                print(nx_t_t_res_str, file=symbol_file)
            ex.add_artifact(os.path.join(results_dir, "nx_t_t_res_str_e" + str(epoch) + '.txt'),
                            name="symbol_nexus_symbol_recon_e" + str(epoch) + '.txt')

            with open(os.path.join(results_dir, 'nx_m_t_res_str_e' + str(epoch) + '.txt'), "w") as symbol_file:
                print(nx_m_t_res_str, file=symbol_file)
            ex.add_artifact(os.path.join(results_dir, "nx_m_t_res_str_e" + str(epoch) + '.txt'),
                            name="image_nexus_symbol_recon_e" + str(epoch) + '.txt')
Beispiel #32
0
 def get_logits(self, x):
     return torch.log_softmax(self.forward(x), dim=-1)
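A quick usage sketch for a head like the one above; `TinyHead`, its dimensions, and the input batch are assumptions for illustration, not part of the original snippet. Calling `torch.log_softmax` directly is numerically safer than composing `torch.log(torch.softmax(...))`.

import torch
import torch.nn as nn

class TinyHead(nn.Module):
    # hypothetical module mirroring the get_logits pattern above
    def __init__(self, in_dim=16, n_classes=4):
        super().__init__()
        self.fc = nn.Linear(in_dim, n_classes)

    def forward(self, x):
        return self.fc(x)

    def get_logits(self, x):
        return torch.log_softmax(self.forward(x), dim=-1)

head = TinyHead()
log_probs = head.get_logits(torch.randn(8, 16))      # (8, 4) log-probabilities
assert torch.allclose(log_probs.exp().sum(-1), torch.ones(8))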
Beispiel #33
0
 def updateOutput(self, input):
     self.output = torch.log_softmax(
         input,
         self._get_dim(input)
     )
     return self.output
Beispiel #34
0
def logprobs(D, h, e):
    #computing y
    return torch.log_softmax(torch.add(D @ h, e), dim=0)
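A minimal usage sketch, assuming `D` is a (vocab, hidden) matrix, `h` a hidden-state vector, and `e` a bias vector; `dim=0` then normalizes over the vocabulary axis. The shapes are guesses for illustration, not taken from the original code.

import torch

D = torch.randn(10, 4)   # hypothetical (vocab, hidden) matrix
h = torch.randn(4)       # hidden state
e = torch.randn(10)      # output bias
y = torch.log_softmax(D @ h + e, dim=0)
assert torch.isclose(y.exp().sum(), torch.tensor(1.0))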
Beispiel #35
0
    def default_beam_search(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Beam search implementation.

        Modified from https://arxiv.org/pdf/1211.3711.pdf

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypothesis.

        """
        beam = min(self.beam_size, self.vocab_size)
        beam_k = min(beam, (self.vocab_size - 1))

        dec_state = self.decoder.init_state(1)

        kept_hyps = [
            Hypothesis(score=0.0, yseq=[self.blank_id], dec_state=dec_state)
        ]
        cache = {}
        cache_lm = {}

        for enc_out_t in enc_out:
            hyps = kept_hyps
            kept_hyps = []

            if self.token_list is not None:
                logging.debug("\n" + "\n".join([
                    "hypo: " +
                    "".join([self.token_list[x] for x in hyp.yseq[1:]]) +
                    f", score: {round(float(hyp.score), 2)}" for hyp in sorted(
                        hyps, key=lambda x: x.score, reverse=True)
                ]))

            while True:
                max_hyp = max(hyps, key=lambda x: x.score)
                hyps.remove(max_hyp)

                dec_out, state, lm_tokens = self.decoder.score(max_hyp, cache)

                logp = torch.log_softmax(
                    self.joint_network(enc_out_t, dec_out),
                    dim=-1,
                )
                top_k = logp[1:].topk(beam_k, dim=-1)

                kept_hyps.append(
                    Hypothesis(
                        score=(max_hyp.score + float(logp[0:1])),
                        yseq=max_hyp.yseq[:],
                        dec_state=max_hyp.dec_state,
                        lm_state=max_hyp.lm_state,
                    ))

                if self.use_lm:
                    if tuple(max_hyp.yseq) not in cache_lm:
                        lm_scores, lm_state = self.lm.score(
                            torch.tensor(
                                [self.sos] + max_hyp.yseq[1:],
                                dtype=torch.long,
                                device=self.decoder.device,
                            ),
                            max_hyp.lm_state,
                            None,
                        )
                        cache_lm[tuple(max_hyp.yseq)] = (lm_scores, lm_state)
                    else:
                        lm_scores, lm_state = cache_lm[tuple(max_hyp.yseq)]
                else:
                    lm_state = max_hyp.lm_state

                for logp, k in zip(*top_k):
                    score = max_hyp.score + float(logp)

                    if self.use_lm:
                        score += self.lm_weight * lm_scores[k + 1]

                    hyps.append(
                        Hypothesis(
                            score=score,
                            yseq=max_hyp.yseq[:] + [int(k + 1)],
                            dec_state=state,
                            lm_state=lm_state,
                        ))

                hyps_max = float(max(hyps, key=lambda x: x.score).score)
                kept_most_prob = sorted(
                    [hyp for hyp in kept_hyps if hyp.score > hyps_max],
                    key=lambda x: x.score,
                )
                if len(kept_most_prob) >= beam:
                    kept_hyps = kept_most_prob
                    break

        return self.sort_nbest(kept_hyps)
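The inner loop above keeps expanding the best pending hypothesis until at least `beam` blank-terminated hypotheses already outscore every pending one. A standalone illustration of that stopping test with plain numbers (the scores are made up, not from any model):

beam = 2
kept_scores = [-0.7, -1.2, -3.0]     # hypothetical scores of hypotheses already in kept_hyps
pending_scores = [-1.5, -2.1]        # hypothetical scores of hypotheses still in hyps
best_pending = max(pending_scores)
survivors = [s for s in kept_scores if s > best_pending]
stop = len(survivors) >= beam        # True: -0.7 and -1.2 both beat -1.5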
Beispiel #36
0
 def forward(self, x, return_logit=False):
     if return_logit:
         logit = self.proj(x)
         return logit
     else:
         return torch.log_softmax(self.proj(x), dim=-1)
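The two branches above correspond to the two standard loss entry points: `F.cross_entropy` takes raw logits, while `F.nll_loss` takes log-probabilities. A small check of that equivalence; the projection layer and batch here are hypothetical stand-ins for `self.proj`:

import torch
import torch.nn.functional as F

proj = torch.nn.Linear(8, 5)
x = torch.randn(3, 8)
y = torch.randint(0, 5, (3,))
logit = proj(x)
assert torch.isclose(F.cross_entropy(logit, y),
                     F.nll_loss(torch.log_softmax(logit, dim=-1), y))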
Beispiel #37
0
    def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Time synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypothesis.

        """
        beam = min(self.beam_size, self.vocab_size)

        beam_state = self.decoder.init_state(beam)

        B = [
            Hypothesis(
                yseq=[self.blank_id],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]
        cache = {}

        if self.use_lm:
            B[0].lm_state = self.lm.zero_state()

        for enc_out_t in enc_out:
            A = []
            C = B

            enc_out_t = enc_out_t.unsqueeze(0)

            for v in range(self.max_sym_exp):
                D = []

                beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    C,
                    beam_state,
                    cache,
                    self.use_lm,
                )

                beam_logp = torch.log_softmax(
                    self.joint_network(enc_out_t, beam_dec_out),
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

                seq_A = [h.yseq for h in A]

                for i, hyp in enumerate(C):
                    if hyp.yseq not in seq_A:
                        A.append(
                            Hypothesis(
                                score=(hyp.score + float(beam_logp[i, 0])),
                                yseq=hyp.yseq[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                            ))
                    else:
                        dict_pos = seq_A.index(hyp.yseq)

                        A[dict_pos].score = np.logaddexp(
                            A[dict_pos].score,
                            (hyp.score + float(beam_logp[i, 0])))

                if v < (self.max_sym_exp - 1):
                    if self.use_lm:
                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
                            beam_lm_tokens, [c.lm_state for c in C], None)

                    for i, hyp in enumerate(C):
                        for logp, k in zip(beam_topk[0][i],
                                           beam_topk[1][i] + 1):
                            new_hyp = Hypothesis(
                                score=(hyp.score + float(logp)),
                                yseq=(hyp.yseq + [int(k)]),
                                dec_state=self.decoder.select_state(
                                    beam_state, i),
                                lm_state=hyp.lm_state,
                            )

                            if self.use_lm:
                                new_hyp.score += self.lm_weight * beam_lm_scores[
                                    i, k]
                                new_hyp.lm_state = beam_lm_states[i]

                            D.append(new_hyp)

                C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]

            B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]

        return self.sort_nbest(B)
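When the same label sequence is reached twice within a time step, the snippet merges the two path scores in log space with `np.logaddexp`. A tiny standalone check of what that merge computes (the two scores are arbitrary):

import numpy as np

a, b = np.log(0.2), np.log(0.05)       # two hypothetical path scores (log-probabilities)
merged = np.logaddexp(a, b)            # log(0.2 + 0.05), computed stably
assert np.isclose(np.exp(merged), 0.25)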
Beispiel #38
0
    def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
        # this is also a test that when block_ngram_repeat=0,
        # repeating is acceptable
        beam_sz = 5
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]),
                                             dim=0)
        min_length = 5
        eos_idx = 2
        # beam includes start token in cur_len count.
        # Add one to its min_length to compensate
        beam = Beam(beam_sz,
                    0,
                    1,
                    eos_idx,
                    n_best=2,
                    exclusion_tokens=set(),
                    min_length=min_length,
                    global_scorer=GlobalScorerStub(),
                    block_ngram_repeat=0)
        for i in range(min_length + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full((beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz - 1, k)
                    word_probs[beam_idx, j] = score
            else:
                word_probs[0, eos_idx] = valid_score_dist[0]
                word_probs[1, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz - 1, k)
                    word_probs[beam_idx, j] = score

            attns = torch.randn(beam_sz)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertEqual(beam.finished[0][1], beam.min_length + 1)
                self.assertEqual(beam.finished[0][2], 1)
                self.assertFalse(beam.done)
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertEqual(beam.finished[1][1], beam.min_length + 2)
                self.assertEqual(beam.finished[1][2], 0)
                self.assertTrue(beam.done)
Beispiel #39
0
    def align_length_sync_decoding(self,
                                   enc_out: torch.Tensor) -> List[Hypothesis]:
        """Alignment-length synchronous beam search implementation.

        Based on https://ieeexplore.ieee.org/document/9053040

        Args:
            h: Encoder output sequences. (T, D)

        Returns:
            nbest_hyps: N-best hypothesis.

        """
        beam = min(self.beam_size, self.vocab_size)

        t_max = int(enc_out.size(0))
        u_max = min(self.u_max, (t_max - 1))

        beam_state = self.decoder.init_state(beam)

        B = [
            Hypothesis(
                yseq=[self.blank_id],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]
        final = []
        cache = {}

        if self.use_lm:
            B[0].lm_state = self.lm.zero_state()

        for i in range(t_max + u_max):
            A = []

            B_ = []
            B_enc_out = []
            for hyp in B:
                u = len(hyp.yseq) - 1
                t = i - u

                if t > (t_max - 1):
                    continue

                B_.append(hyp)
                B_enc_out.append((t, enc_out[t]))

            if B_:
                beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    B_,
                    beam_state,
                    cache,
                    self.use_lm,
                )

                beam_enc_out = torch.stack([x[1] for x in B_enc_out])

                beam_logp = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out),
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

                if self.use_lm:
                    beam_lm_scores, beam_lm_states = self.lm.batch_score(
                        beam_lm_tokens,
                        [b.lm_state for b in B_],
                        None,
                    )

                for i, hyp in enumerate(B_):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(beam_logp[i, 0])),
                        yseq=hyp.yseq[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                    )

                    A.append(new_hyp)

                    if B_enc_out[i][0] == (t_max - 1):
                        final.append(new_hyp)

                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq[:] + [int(k)]),
                            dec_state=self.decoder.select_state(beam_state, i),
                            lm_state=hyp.lm_state,
                        )

                        if self.use_lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[i,
                                                                             k]
                            new_hyp.lm_state = beam_lm_states[i]

                        A.append(new_hyp)

                B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
                B = recombine_hyps(B)

        if final:
            return self.sort_nbest(final)
        else:
            return B
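In the alignment-length synchronous loop above, every step index i is split between a time index t and a label count u with i = t + u, so hypotheses of different lengths read different encoder frames at the same step. A tiny illustration with made-up hypothesis lengths:

i = 5
for yseq_len in (1, 3, 6):             # hypothetical hypothesis lengths (blank token included)
    u = yseq_len - 1
    t = i - u
    print(yseq_len, "->", (t, u))      # prints (5, 0), (3, 2), (0, 5)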
Beispiel #40
0
 def _loss(i):
     # log-probabilities over the span positions
     y = torch.log_softmax(pred[i], dim=-1)
     # keep only the log-probability of the gold index via a one-hot mask
     y = torch.nn.functional.one_hot(span[:, i], self.slen) * y
     # negative log-likelihood averaged over the batch
     return -y.sum(dim=-1).mean()
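After the fixes above, the helper is just the negative log-likelihood of the gold index. A minimal standalone check against `F.nll_loss`; the tensors here are hypothetical, not the snippet's `pred` and `span`:

import torch
import torch.nn.functional as F

logits = torch.randn(3, 7)                     # hypothetical (batch, slen) scores
gold = torch.randint(0, 7, (3,))               # hypothetical gold indices
logp = torch.log_softmax(logits, dim=-1)
manual = -(F.one_hot(gold, 7) * logp).sum(dim=-1).mean()
assert torch.isclose(manual, F.nll_loss(logp, gold))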
Beispiel #41
0
    def nsc_beam_search(self,
                        enc_out: torch.Tensor) -> List[ExtendedHypothesis]:
        """N-step constrained beam search implementation.

        Based on/Modified from https://arxiv.org/pdf/2002.03577.pdf.
        Please reference ESPnet (b-flo, PR #2444) for any usage outside ESPnet
        until further modifications.

        Args:
            enc_out: Encoder output sequence. (T, D_enc)

        Returns:
            nbest_hyps: N-best hypothesis.

        """
        beam = min(self.beam_size, self.vocab_size)
        beam_k = min(beam, (self.vocab_size - 1))

        beam_state = self.decoder.init_state(beam)

        init_tokens = [
            ExtendedHypothesis(
                yseq=[self.blank_id],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]

        cache = {}

        beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
            init_tokens,
            beam_state,
            cache,
            self.use_lm,
        )

        state = self.decoder.select_state(beam_state, 0)

        if self.use_lm:
            beam_lm_scores, beam_lm_states = self.lm.batch_score(
                beam_lm_tokens,
                [i.lm_state for i in init_tokens],
                None,
            )
            lm_state = beam_lm_states[0]
            lm_scores = beam_lm_scores[0]
        else:
            lm_state = None
            lm_scores = None

        kept_hyps = [
            ExtendedHypothesis(
                yseq=[self.blank_id],
                score=0.0,
                dec_state=state,
                dec_out=[beam_dec_out[0]],
                lm_state=lm_state,
                lm_scores=lm_scores,
            )
        ]

        for enc_out_t in enc_out:
            hyps = self.prefix_search(
                sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True),
                enc_out_t,
            )
            kept_hyps = []

            beam_enc_out = enc_out_t.unsqueeze(0)

            S = []
            V = []
            for n in range(self.nstep):
                beam_dec_out = torch.stack([hyp.dec_out[-1] for hyp in hyps])

                beam_logp = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out),
                    dim=-1,
                )
                beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1)

                for i, hyp in enumerate(hyps):
                    S.append(
                        ExtendedHypothesis(
                            yseq=hyp.yseq[:],
                            score=hyp.score + float(beam_logp[i, 0:1]),
                            dec_out=hyp.dec_out[:],
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                            lm_scores=hyp.lm_scores,
                        ))

                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                        score = hyp.score + float(logp)

                        if self.use_lm:
                            score += self.lm_weight * float(hyp.lm_scores[k])

                        V.append(
                            ExtendedHypothesis(
                                yseq=hyp.yseq[:] + [int(k)],
                                score=score,
                                dec_out=hyp.dec_out[:],
                                dec_state=hyp.dec_state,
                                lm_state=hyp.lm_state,
                                lm_scores=hyp.lm_scores,
                            ))

                V.sort(key=lambda x: x.score, reverse=True)
                V = subtract(V, hyps)[:beam]

                beam_state = self.decoder.create_batch_states(
                    beam_state,
                    [v.dec_state for v in V],
                    [v.yseq for v in V],
                )
                beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    V,
                    beam_state,
                    cache,
                    self.use_lm,
                )

                if self.use_lm:
                    beam_lm_scores, beam_lm_states = self.lm.batch_score(
                        beam_lm_tokens, [v.lm_state for v in V], None)

                if n < (self.nstep - 1):
                    for i, v in enumerate(V):
                        v.dec_out.append(beam_dec_out[i])

                        v.dec_state = self.decoder.select_state(beam_state, i)

                        if self.use_lm:
                            v.lm_state = beam_lm_states[i]
                            v.lm_scores = beam_lm_scores[i]

                    hyps = V[:]
                else:
                    beam_logp = torch.log_softmax(
                        self.joint_network(beam_enc_out, beam_dec_out),
                        dim=-1,
                    )

                    for i, v in enumerate(V):
                        if self.nstep != 1:
                            v.score += float(beam_logp[i, 0])

                        v.dec_out.append(beam_dec_out[i])

                        v.dec_state = self.decoder.select_state(beam_state, i)

                        if self.use_lm:
                            v.lm_state = beam_lm_states[i]
                            v.lm_scores = beam_lm_scores[i]

            kept_hyps = sorted((S + V), key=lambda x: x.score,
                               reverse=True)[:beam]

        return self.sort_nbest(kept_hyps)
Beispiel #42
0
def gumbel_softmax(logits, temp):
    logprobs = torch.log_softmax(logits, dim=1)
    y = (logprobs + sample_gumbel(logits.size())) / temp
    return F.softmax(y, dim=1)
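The snippet depends on a `sample_gumbel` helper and on `torch.nn.functional as F` from outside the excerpt. A self-contained sketch using the usual inverse-transform Gumbel sampler; the helper shown is an assumption, not necessarily the original author's version:

import torch
import torch.nn.functional as F

def sample_gumbel(shape, eps=1e-20):
    # Gumbel(0, 1) noise via inverse transform sampling
    u = torch.rand(shape)
    return -torch.log(-torch.log(u + eps) + eps)

def gumbel_softmax(logits, temp):
    logprobs = torch.log_softmax(logits, dim=1)
    y = (logprobs + sample_gumbel(logits.size())) / temp
    return F.softmax(y, dim=1)

soft_sample = gumbel_softmax(torch.randn(2, 5), temp=0.5)
assert torch.allclose(soft_sample.sum(dim=1), torch.ones(2))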
Beispiel #43
0
    def modified_adaptive_expansion_search(
            self, enc_out: torch.Tensor) -> List[ExtendedHypothesis]:
        """It's the modified Adaptive Expansion Search (mAES) implementation.

        Based on/modified from https://ieeexplore.ieee.org/document/9250505 and NSC.

        Args:
            enc_out: Encoder output sequence. (T, D_enc)

        Returns:
            nbest_hyps: N-best hypothesis.

        """
        beam = min(self.beam_size, self.vocab_size)
        beam_state = self.decoder.init_state(beam)

        init_tokens = [
            ExtendedHypothesis(
                yseq=[self.blank_id],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
            )
        ]

        cache = {}

        beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
            init_tokens,
            beam_state,
            cache,
            self.use_lm,
        )

        state = self.decoder.select_state(beam_state, 0)

        if self.use_lm:
            beam_lm_scores, beam_lm_states = self.lm.batch_score(
                beam_lm_tokens, [i.lm_state for i in init_tokens], None)

            lm_state = beam_lm_states[0]
            lm_scores = beam_lm_scores[0]
        else:
            lm_state = None
            lm_scores = None

        kept_hyps = [
            ExtendedHypothesis(
                yseq=[self.blank_id],
                score=0.0,
                dec_state=state,
                dec_out=[beam_dec_out[0]],
                lm_state=lm_state,
                lm_scores=lm_scores,
            )
        ]

        for enc_out_t in enc_out:
            hyps = self.prefix_search(
                sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True),
                enc_out_t,
            )
            kept_hyps = []

            beam_enc_out = enc_out_t.unsqueeze(0)

            list_b = []
            duplication_check = [hyp.yseq for hyp in hyps]

            for n in range(self.nstep):
                beam_dec_out = torch.stack([h.dec_out[-1] for h in hyps])

                beam_logp, beam_idx = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out),
                    dim=-1,
                ).topk(self.max_candidates, dim=-1)

                k_expansions = select_k_expansions(
                    hyps,
                    beam_idx,
                    beam_logp,
                    self.expansion_gamma,
                )

                list_exp = []
                for i, hyp in enumerate(hyps):
                    for k, new_score in k_expansions[i]:
                        new_hyp = ExtendedHypothesis(
                            yseq=hyp.yseq[:],
                            score=new_score,
                            dec_out=hyp.dec_out[:],
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                            lm_scores=hyp.lm_scores,
                        )

                        if k == 0:
                            list_b.append(new_hyp)
                        else:
                            if new_hyp.yseq + [int(k)
                                               ] not in duplication_check:
                                new_hyp.yseq.append(int(k))

                                if self.use_lm:
                                    new_hyp.score += self.lm_weight * float(
                                        hyp.lm_scores[k])

                                list_exp.append(new_hyp)

                if not list_exp:
                    kept_hyps = sorted(list_b,
                                       key=lambda x: x.score,
                                       reverse=True)[:beam]

                    break
                else:
                    beam_state = self.decoder.create_batch_states(
                        beam_state,
                        [hyp.dec_state for hyp in list_exp],
                        [hyp.yseq for hyp in list_exp],
                    )

                    beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                        list_exp,
                        beam_state,
                        cache,
                        self.use_lm,
                    )

                    if self.use_lm:
                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
                            beam_lm_tokens, [k.lm_state for k in list_exp],
                            None)

                    if n < (self.nstep - 1):
                        for i, hyp in enumerate(list_exp):
                            hyp.dec_out.append(beam_dec_out[i])
                            hyp.dec_state = self.decoder.select_state(
                                beam_state, i)

                            if self.use_lm:
                                hyp.lm_state = beam_lm_states[i]
                                hyp.lm_scores = beam_lm_scores[i]

                        hyps = list_exp[:]
                    else:
                        beam_logp = torch.log_softmax(
                            self.joint_network(beam_enc_out, beam_dec_out),
                            dim=-1,
                        )

                        for i, hyp in enumerate(list_exp):
                            hyp.score += float(beam_logp[i, 0])

                            hyp.dec_out.append(beam_dec_out[i])
                            hyp.dec_state = self.decoder.select_state(
                                beam_state, i)

                            if self.use_lm:
                                hyp.lm_state = beam_lm_states[i]
                                hyp.lm_scores = beam_lm_scores[i]

                        kept_hyps = sorted(list_b + list_exp,
                                           key=lambda x: x.score,
                                           reverse=True)[:beam]

        return self.sort_nbest(kept_hyps)
Beispiel #44
0
    def __call__(self, outputs_all, targets_all):

        outputs = outputs_all
        targets = targets_all[0]
        pos = targets_all[1]
        dm = targets_all[2]

        self.num_classes = targets.shape[1]

        log_prob = torch.log_softmax(outputs, dim=1)

        if (self.KLD_weight > 0):
            loss = self.kldiv_loss(log_prob, targets) * self.KLD_weight
        else:
            loss = 0.0

        if (self.debug_mode):
            print('KL D loss is ', loss)

        if self.dice_weight > 0:
            eps = 1e-15
            dice_loss = 0
            for cls in range(self.num_classes):
                if (self.class_weights[cls] > 0):
                    target = targets[:, cls].float()
                    output = log_prob[:, cls].exp()

                    # for every class, compute dice for every sample
                    numerator = 2. * torch.sum(output * target, dim=(1, 2))
                    denominator = torch.sum(torch.square(output) +
                                            torch.square(target),
                                            dim=(1, 2))

                    cls_loss = torch.log(
                        torch.mean(numerator / (denominator + eps)))
                    cls_loss *= self.dice_weight

                    if (self.debug_mode):
                        print('   class loss is ', -cls_loss)

                    if (self.class_weights is not None):
                        dice_loss -= cls_loss * self.class_weights[cls]
                    else:
                        dice_loss -= cls_loss

            if (self.debug_mode):
                print('log dice_loss is ', dice_loss)

            loss += dice_loss

        if (self.l2_dist_weight > 0):
            v = 0
            for cls in range(1, self.num_classes):
                if (self.class_weights[cls] > 0):
                    output = log_prob[:, cls].exp()
                    dm_cls = dm[:, cls - 1].float()

                    #print(dm_cls.shape)
                    sv = torch.sum(output * dm_cls, dim=(1, 2))
                    #print(sv, torch.mean(sv))
                    v -= torch.log(torch.mean(sv)) * self.class_weights[cls]

                    if (self.debug_mode):
                        print('l2 dist loss is ', v)

            loss += v * self.l2_dist_weight

        return loss
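A note on the KL term above: `nn.KLDivLoss` expects log-probabilities as its input and probabilities as its target, which is why `log_prob` is passed rather than raw outputs. A minimal standalone check; the shapes and the 'batchmean' reduction are illustrative, not the class's actual configuration:

import torch
import torch.nn as nn

kld = nn.KLDivLoss(reduction='batchmean')
log_prob = torch.log_softmax(torch.randn(4, 3, 8, 8), dim=1)     # input: log-probabilities
target = torch.softmax(torch.randn(4, 3, 8, 8), dim=1)           # target: probabilities
assert kld(log_prob, target).item() >= 0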
Beispiel #45
0
    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))
                        ) if args.tqdm else (lambda x: x)

        best_valid = 0.
        optim_steps = 0
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, target, iou_question, iou_answer, sem_question_words, sem_answer_words, bboxes_words,)\
                 in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                # DEBUG: print pointer (set batch size to 1)
                # print(dset.id2datum[ques_id[0]]['sent'])
                # print(dset.id2datum[ques_id[0]]['label'])
                # q_pointer = dset.id2datum[ques_id[0]]['pointer']['question']
                # for w_index in q_pointer:
                #     print(w_index)

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                iou_question, iou_answer = iou_question.cuda(), iou_answer.cuda()
                sem_question_words, sem_answer_words, bboxes_words = (
                    sem_question_words.cuda(), sem_answer_words.cuda(),
                    bboxes_words.cuda())
                logit, iou_target, iou_score = self.model(
                    feats, boxes, sent, iou_question, iou_answer,
                    sem_question_words, sem_answer_words, bboxes_words)
                assert logit.dim() == target.dim() == 2
                if args.mce_loss:
                    max_value, target = target.max(1)
                    loss = self.mce_loss(logit, target) * logit.size(1)
                else:
                    loss = self.bce_loss(logit, target)
                    loss = loss * logit.size(1)
                #print('CE', loss.item())

                if args.answer_loss == 'glove':
                    gold_glove = (self.labelans2glove.unsqueeze(0) *
                                  target.unsqueeze(-1)).sum(1)
                    #gold_ans = self.train_tuple.dataset.label2ans[target.argmax(dim=1)[0]]
                    #print('gold:', gold_ans)
                    pred_glove = (
                        self.labelans2glove.unsqueeze(0) *
                        torch.softmax(logit, dim=1).unsqueeze(-1)).sum(1)
                    #pred_ans = self.train_tuple.dataset.label2ans[logit.argmax(dim=1)[0]]
                    #print('pred:', pred_ans)
                    sim_answer = self.cosineSim(gold_glove, pred_glove).mean()
                    loss += -10 * sim_answer
                    #print('Similarity', sim_answer)
                    #input(' ')

                if optim_steps % 1000 == 0:
                    self.writerTbrd.add_scalar('vqa_loss_train', loss.item(),
                                               optim_steps)

                # task_pointer = 'KLDiv'
                ALPHA = args.alpha_pointer

                def iou_preprocess(iou, obj_conf=None):
                    TRESHOLD = 0.1
                    TOPK = 3
                    # norm_iou = np.exp(iou) / np.sum(np.exp(iou), axis=0)  #iou / (iou.sum() + 1e-9)
                    # f_iou = norm_iou * (iou.sum() >= TRESHOLD)
                    sorted_values = torch.sort(iou, descending=True, dim=-1)[0]
                    t_top = sorted_values[:, :, TOPK - 1]
                    iou_topk = iou.masked_fill(iou < t_top.unsqueeze(-1), -1e9)
                    f_iou = torch.softmax(iou_topk, dim=-1)
                    treshold_mask = (iou_topk.clamp(min=.0).sum(-1) >=
                                     TRESHOLD).float()
                    if args.task_pointer == 'KLDiv':
                        return f_iou, treshold_mask
                    elif args.task_pointer == 'Triplet':
                        # Remove top10 most similar objects
                        t_bot = sorted_values[:, :, 10]
                        iou_botk = (iou < t_bot.unsqueeze(-1)).float()
                        # Take topk most confident objects
                        conf_top = torch.sort(obj_conf.unsqueeze(1) * iou_botk,
                                              descending=True,
                                              dim=-1)[0][:, :, TOPK - 1]
                        conf_mask = obj_conf.unsqueeze(1).expand(
                            -1, iou.size(1), -1) >= conf_top.unsqueeze(-1)
                        neg_score = iou_botk * conf_mask.float()
                        return f_iou, treshold_mask, neg_score

                if args.task_pointer == 'KLDiv':
                    iou_target_preprocess, treshold_mask = iou_preprocess(
                        iou_target)
                    loss_pointer_fct = KLDivLoss(reduction='none')
                    iou_pred = torch.log_softmax(iou_score, dim=-1)
                    matching_loss = loss_pointer_fct(
                        input=iou_pred, target=iou_target_preprocess)
                    matching_loss = ALPHA * (matching_loss.sum(-1) *
                                             treshold_mask).sum() / (
                                                 (treshold_mask).sum() + 1e-9)
                    if optim_steps % 1000 == 0:
                        self.writerTbrd.add_scalar('pointer_loss_train',
                                                   matching_loss.item(),
                                                   optim_steps)
                    loss += matching_loss

                # ? by Corentin: Matching loss
                # def iou_preprocess(iou):
                #     TRESHOLD = 0.1
                #     TOPK = 1
                #     # norm_iou = np.exp(iou) / np.sum(np.exp(iou), axis=0)  #iou / (iou.sum() + 1e-9)
                #     # f_iou = norm_iou * (iou.sum() >= TRESHOLD)
                #     t = torch.sort(iou, descending=True, dim=-1)[0][:, :, TOPK-1]
                #     iou_topk = iou.masked_fill(iou < t.unsqueeze(-1), -1e9)
                #     f_iou = torch.softmax(iou_topk, dim=-1)
                #     treshold_mask = (iou_topk.clamp(min=.0).sum(-1) >= TRESHOLD).float()
                #     return f_iou, treshold_mask
                # # discard iou_target when total iou is under treshold
                # # it includes unsupervised datum
                # iou_target_preprocess, treshold_mask = iou_preprocess(iou_target)
                # iou_pred = torch.log_softmax(iou_pred, dim=-1)
                # # KL loss
                # matching_loss = []
                # matching_loss = self.KL_loss(input=iou_pred, target=iou_target_preprocess)
                # matching_loss = (matching_loss.sum(-1) * treshold_mask).sum() / treshold_mask.sum()
                # if optim_steps % 1000 == 0:
                #     self.writerTbrd.add_scalar('pointer_loss_train', matching_loss.item(), optim_steps)
                # ALPHA = 5.0
                # loss += ALPHA * matching_loss
                # ? **************************

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()
                optim_steps += 1

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans

                # if self.valid_tuple is not None and optim_steps % 1152 == 0:  # Do Validation
                #     valid_score = self.evaluate(eval_tuple)
                #     fastepoch = int(optim_steps / 1152)
                #     print("fastEpoch %d: Valid %0.2f\n" % (fastepoch, valid_score * 100.,))

            log_str = "\nEpoch %d: Train %0.2f\n" % (
                epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                self.writerTbrd.add_scalar('vqa_acc_valid', valid_score, epoch)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")
Beispiel #46
0
    def default_beam_search(self, enc_out: torch.Tensor) -> List[Hypothesis]:
        """Beam search implementation.

        Modified from https://arxiv.org/pdf/1211.3711.pdf

        Args:
            enc_out: Encoder output sequence. (T, D)

        Returns:
            nbest_hyps: N-best hypothesis.

        """
        beam = min(self.beam_size, self.vocab_size)
        beam_k = min(beam, (self.vocab_size - 1))

        dec_state = self.decoder.init_state(1)

        kept_hyps = [Hypothesis(score=0.0, yseq=[self.blank_id], dec_state=dec_state)]
        cache = {}

        for enc_out_t in enc_out:
            hyps = kept_hyps
            kept_hyps = []

            while True:
                max_hyp = max(hyps, key=lambda x: x.score)
                hyps.remove(max_hyp)

                dec_out, state, lm_tokens = self.decoder.score(max_hyp, cache)

                logp = torch.log_softmax(
                    self.joint_network(
                        enc_out_t, dec_out, quantization=self.quantization
                    )
                    / self.softmax_temperature,
                    dim=-1,
                )
                top_k = logp[1:].topk(beam_k, dim=-1)

                kept_hyps.append(
                    Hypothesis(
                        score=(max_hyp.score + float(logp[0:1])),
                        yseq=max_hyp.yseq[:],
                        dec_state=max_hyp.dec_state,
                        lm_state=max_hyp.lm_state,
                    )
                )

                if self.use_lm:
                    lm_state, lm_scores = self.lm.predict(max_hyp.lm_state, lm_tokens)
                else:
                    lm_state = max_hyp.lm_state

                for logp, k in zip(*top_k):
                    score = max_hyp.score + float(logp)

                    if self.use_lm:
                        score += self.lm_weight * lm_scores[0][k + 1]

                    hyps.append(
                        Hypothesis(
                            score=score,
                            yseq=max_hyp.yseq[:] + [int(k + 1)],
                            dec_state=state,
                            lm_state=lm_state,
                        )
                    )

                hyps_max = float(max(hyps, key=lambda x: x.score).score)
                kept_most_prob = sorted(
                    [hyp for hyp in kept_hyps if hyp.score > hyps_max],
                    key=lambda x: x.score,
                )
                if len(kept_most_prob) >= beam:
                    kept_hyps = kept_most_prob
                    break

        return self.sort_nbest(kept_hyps)
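This variant divides the joint-network output by `softmax_temperature` before `log_softmax`. A tiny illustration of what that temperature does; the values are arbitrary:

import torch

logits = torch.tensor([2.0, 1.0, 0.0])
sharp = torch.log_softmax(logits / 0.5, dim=-1).exp()   # T < 1: peakier distribution
flat = torch.log_softmax(logits / 2.0, dim=-1).exp()    # T > 1: flatter distribution
assert sharp[0] > flat[0]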
Beispiel #47
0
 def get_knn_probmass(tgts, dists, knn_tgts):
     tgts = torch.from_numpy(tgts).long().view(-1)
     dists = torch.from_numpy(dists).float().squeeze(-1)
     probs = torch.log_softmax(dists, dim=-1)
     mass = torch.exp(probs)
     return probs, mass
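A quick sanity check of the relationship used above: exponentiating `log_softmax` recovers `softmax`, so `mass` is a proper distribution over the last dimension. Toy tensors rather than the actual kNN distances:

import torch

dists = torch.randn(5, 8)
probs = torch.log_softmax(dists, dim=-1)
mass = torch.exp(probs)
assert torch.allclose(mass, torch.softmax(dists, dim=-1))
assert torch.allclose(mass.sum(-1), torch.ones(5))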