def test_perplexity(self):
    """Check that perplexity equals the exponential of the NLL loss.

    Both metrics are fed the same (output, target) pairs from
    ``self.outputs`` / ``self.targets`` (presumably populated in the test
    fixture's setUp — not visible here). The accumulated perplexity must
    then match ``math.exp`` of the accumulated NLL loss.
    """
    nll_metric, ppl_metric = NLLLoss(), Perplexity()
    metrics = (nll_metric, ppl_metric)

    # Feed every pair to NLL first, then to perplexity, same as a
    # side-by-side evaluation.
    for batch_output, batch_target in zip(self.outputs, self.targets):
        for metric in metrics:
            metric.eval_step(batch_output, batch_target)

    expected_ppl = math.exp(nll_metric.get_loss())
    self.assertAlmostEqual(ppl_metric.get_loss(), expected_ppl)
def test_perplexity_init(self):
    """Check that a new Perplexity exposes its class-level naming constants.

    ``name`` must equal ``Perplexity._NAME`` and ``log_name`` must equal
    ``Perplexity._SHORTNAME``.
    """
    metric = Perplexity()
    # Checked in order; the loop stops at the first failing assertion.
    for attribute, expected in (
        ("name", Perplexity._NAME),
        ("log_name", Perplexity._SHORTNAME),
    ):
        self.assertEqual(getattr(metric, attribute), expected)