Exemplo n.º 1
0
    def assert_single_loss_equals_batch_loss(self,
                                             model,
                                             pad_src_to_multiple=1):
        """
        Check that per-sentence losses add up to the batched loss (masked target).

        The first five source sentences are cut to the shortest source length,
        closed with ``Vocab.ES`` and extended with ``Vocab.ES`` until their
        length is a multiple of ``pad_src_to_multiple``. The first five target
        sentences are sorted longest-first and are not truncated: for the
        batched computation they are padded with ``Vocab.ES`` up to the longest
        one and accompanied by a mask holding 1.0 at every padded position.

        Args:
            model: forwarded as ``model`` to ``MLELoss().calc_loss``.
            pad_src_to_multiple: the truncated source sentences are extended
                until their length is divisible by this value.
        """
        batch_size = 5

        # Source side: equal-length word lists, each ending in the ES symbol.
        src_batch = self.src_data[:batch_size]
        shortest_src = min(s.sent_len() for s in src_batch)
        src_word_lists = []
        for src_sentence in src_batch:
            words = src_sentence.words[:shortest_src]
            words[shortest_src - 1] = Vocab.ES
            while len(words) % pad_src_to_multiple != 0:
                words.append(Vocab.ES)
            src_word_lists.append(words)
        src_wrapped = [
            sent.SimpleSentence(words=words) for words in src_word_lists
        ]

        # Target side: longest sentence first, lengths left untouched.
        trg_sorted = sorted(self.trg_data[:batch_size],
                            key=lambda t: t.sent_len(),
                            reverse=True)
        trg_lengths = [t.sent_len() for t in trg_sorted]
        longest_trg = max(trg_lengths)

        # Mask entries are 1.0 exactly where a row is padding.
        pad_flags = np.zeros([batch_size, longest_trg])
        for row in range(batch_size):
            pad_flags[row, trg_lengths[row]:] = 1.0
        trg_masks = Mask(pad_flags)

        trg_padded_words = [
            [tok for tok in t] + [Vocab.ES] * (longest_trg - length)
            for t, length in zip(trg_sorted, trg_lengths)
        ]
        trg_wrapped = [
            sent.SimpleSentence(words=words) for words in trg_padded_words
        ]

        # Sentence-by-sentence pass uses the unpadded targets, each on a
        # fresh computation graph.
        summed_single_loss = 0.0
        for idx in range(batch_size):
            dy.renew_cg()
            loss_expr = MLELoss().calc_loss(model=model,
                                            src=src_wrapped[idx],
                                            trg=trg_sorted[idx])
            sentence_loss, _ = loss_expr.compute()
            summed_single_loss += sentence_loss.value()

        # Batched pass over the padded targets plus mask.
        dy.renew_cg()
        batch_expr = MLELoss().calc_loss(
            model=model,
            src=mark_as_batch(src_wrapped),
            trg=mark_as_batch(trg_wrapped, trg_masks))
        batched_loss, _ = batch_expr.compute()
        batch_total = np.sum(batched_loss.value())

        self.assertAlmostEqual(summed_single_loss, batch_total, places=4)
Exemplo n.º 2
0
    def assert_single_loss_equals_batch_loss(self,
                                             model,
                                             pad_src_to_multiple=1):
        """
        Check that per-sentence losses add up to the batched loss (no masking).

        The first five source sentences and the first five target sentences
        are each cut to the shortest length on their side and closed with
        ``Vocab.ES``, so every sentence in a batch has the same length and no
        mask is passed. Source sentences are additionally extended with
        ``Vocab.ES`` until their length is a multiple of
        ``pad_src_to_multiple``.

        Args:
            model: forwarded as ``model`` to ``MLELoss().calc_loss``.
            pad_src_to_multiple: the truncated source sentences are extended
                until their length is divisible by this value.
        """
        batch_size = 5

        # Source side: truncate, terminate with ES, pad to the multiple.
        src_batch = self.src_data[:batch_size]
        shortest_src = min(s.sent_len() for s in src_batch)
        src_word_lists = []
        for src_sentence in src_batch:
            words = src_sentence.words[:shortest_src]
            words[shortest_src - 1] = Vocab.ES
            while len(words) % pad_src_to_multiple != 0:
                words.append(Vocab.ES)
            src_word_lists.append(words)

        # Target side: truncate and terminate with ES; no padding needed.
        trg_batch = self.trg_data[:batch_size]
        shortest_trg = min(t.sent_len() for t in trg_batch)
        trg_word_lists = []
        for trg_sentence in trg_batch:
            words = trg_sentence.words[:shortest_trg]
            words[shortest_trg - 1] = Vocab.ES
            trg_word_lists.append(words)

        src_wrapped = [
            sent.SimpleSentence(words=words) for words in src_word_lists
        ]
        trg_wrapped = [
            sent.SimpleSentence(words=words) for words in trg_word_lists
        ]

        # Sentence-by-sentence pass, each on a fresh computation graph.
        summed_single_loss = 0.0
        for idx in range(batch_size):
            dy.renew_cg()
            loss_expr = MLELoss().calc_loss(model=model,
                                            src=src_wrapped[idx],
                                            trg=trg_wrapped[idx])
            sentence_loss, _ = loss_expr.compute()
            summed_single_loss += sentence_loss.value()

        # Batched pass over the very same sentences.
        dy.renew_cg()
        batch_expr = MLELoss().calc_loss(
            model=model,
            src=mark_as_batch(src_wrapped),
            trg=mark_as_batch(trg_wrapped))
        batched_loss, _ = batch_expr.compute()
        batch_total = np.sum(batched_loss.value())

        self.assertAlmostEqual(summed_single_loss, batch_total, places=4)