def assert_single_loss_equals_batch_loss(self, model, pad_src_to_multiple=1):
    """
    Tests whether single loss equals batch loss.
    Here we don't truncate the target side and use masking.

    The first ``batch_size`` (src, trg) pairs are sorted jointly by target
    length, longest first.

    Source side: every source is truncated to the length of the shortest
    source, and its last kept token is overwritten with ``Vocab.ES``. It is
    then right-padded with ``Vocab.ES`` until its length is a multiple of
    ``pad_src_to_multiple``. All sources end up the same length, so no source
    mask is built.

    Target side: targets keep their full length. For the batched call they
    are right-padded with ``Vocab.ES`` up to the longest target, and the
    padded positions are set to 1.0 in the target ``Mask``.

    The per-sentence losses (each computed in a fresh graph on the unpadded
    target) are summed. This sum must equal the sum of the batched loss to
    4 decimal places.

    Args:
      model: the model passed to ``MLELoss().calc_loss``.
      pad_src_to_multiple: source lengths are padded up to a multiple of this
        value (1 means no extra padding).
    """
    batch_size = 5
    # Sort (src, trg) pairs jointly by target length, longest first. Sorting
    # the targets alone would silently pair each source with an unrelated
    # target.
    pairs = sorted(zip(self.src_data[:batch_size], self.trg_data[:batch_size]),
                   key=lambda pair: pair[1].sent_len(), reverse=True)
    src_sents = [src for src, _ in pairs]
    trg_sents = [trg for _, trg in pairs]

    # Source side: equal-length sentences, so the batch needs no source mask.
    src_min = min(s.sent_len() for s in src_sents)
    src_sents_trunc = []
    for s in src_sents:
        # NOTE(review): assumes s.words is a list, so slicing copies and the
        # dataset sentence is not modified - confirm.
        words = s.words[:src_min]
        words[src_min - 1] = Vocab.ES
        while len(words) % pad_src_to_multiple != 0:
            words.append(Vocab.ES)
        src_sents_trunc.append(sent.SimpleSentence(words=words))

    # Target side: pad to the longest target and mark padded positions.
    trg_max = max(t.sent_len() for t in trg_sents)
    np_arr = np.zeros([batch_size, trg_max])
    for i, t in enumerate(trg_sents):
        np_arr[i, t.sent_len():] = 1.0  # 1.0 marks padded positions
    trg_masks = Mask(np_arr)
    trg_sents_padded = [
        sent.SimpleSentence(words=list(t) + [Vocab.ES] * (trg_max - t.sent_len()))
        for t in trg_sents
    ]

    # Per-sentence losses, each in a fresh computation graph, on unpadded targets.
    single_loss = 0.0
    for src, trg in zip(src_sents_trunc, trg_sents):
        dy.renew_cg()
        train_loss, _ = MLELoss().calc_loss(model=model, src=src, trg=trg).compute()
        single_loss += train_loss.value()

    dy.renew_cg()
    batched_loss, _ = MLELoss().calc_loss(
        model=model,
        src=mark_as_batch(src_sents_trunc),
        trg=mark_as_batch(trg_sents_padded, trg_masks)).compute()
    self.assertAlmostEqual(single_loss, np.sum(batched_loss.value()), places=4)
def assert_single_loss_equals_batch_loss(self, model, pad_src_to_multiple=1):
    """
    Tests whether single loss equals batch loss.
    Truncating src / trg sents to same length so no masking is necessary

    Both the first ``batch_size`` source sentences and the first
    ``batch_size`` target sentences are handled the same way:
      - each is cut to the length of the shortest sentence on its side;
      - its last kept token is overwritten with ``Vocab.ES``;
      - sources are also right-padded with ``Vocab.ES`` until their length is
        a multiple of ``pad_src_to_multiple``.

    Every sentence on a side then has the same length, so the batch is built
    without masks. The per-sentence losses (each computed in a fresh graph)
    are summed, and this sum must equal the sum of the batched loss to
    4 decimal places.

    Args:
      model: the model handed to ``MLELoss().calc_loss``.
      pad_src_to_multiple: source lengths are padded up to a multiple of this
        value (1 means no extra padding).
    """
    batch_size = 5

    def cut_and_terminate(sentences, multiple=1):
        # Cut every sentence to the shortest length on this side, end it with
        # ES, then pad with ES until the length is divisible by `multiple`.
        shortest = min(s.sent_len() for s in sentences)
        word_lists = []
        for s in sentences:
            words = s.words[:shortest]
            words[shortest - 1] = Vocab.ES
            while len(words) % multiple:
                words.append(Vocab.ES)
            word_lists.append(words)
        return word_lists

    src_words = cut_and_terminate(self.src_data[:batch_size], pad_src_to_multiple)
    trg_words = cut_and_terminate(self.trg_data[:batch_size])
    src_batch = [sent.SimpleSentence(words=w) for w in src_words]
    trg_batch = [sent.SimpleSentence(words=w) for w in trg_words]

    def sentence_loss(idx):
        # Loss of one (src, trg) pair, computed in a fresh computation graph.
        dy.renew_cg()
        loss_expr, _ = MLELoss().calc_loss(
            model=model,
            src=src_batch[idx],
            trg=trg_batch[idx]).compute()
        return loss_expr.value()

    single_loss = 0.0
    for idx in range(batch_size):
        single_loss += sentence_loss(idx)

    dy.renew_cg()
    batched_loss, _ = MLELoss().calc_loss(
        model=model,
        src=mark_as_batch(src_batch),
        trg=mark_as_batch(trg_batch)).compute()
    self.assertAlmostEqual(single_loss, np.sum(batched_loss.value()), places=4)