def test_aggregation_perplexity():
    """Check that batch-wise aggregation of Perplexity matches the full-data value.

    Random scores for 100 samples over 50 classes are fed to the metric in
    batches of 10 through ``aggregate``; the value returned by ``finalize``
    must agree with ``exp`` of the cross-entropy loss computed over all
    samples at once.
    """
    num_samples = 100
    num_classes = 50
    batch_size = 10

    y_pred = torch.randn(num_samples, num_classes)
    y_true = torch.randint(low=0, high=num_classes, size=(num_samples,))

    metric = Perplexity()
    metric_state = {}

    # Accumulate the metric state one batch at a time.
    pred_batches = torch.split(y_pred, batch_size)
    true_batches = torch.split(y_true, batch_size)
    for pred_batch, true_batch in zip(pred_batches, true_batches):
        metric.aggregate(metric_state, pred_batch, true_batch)

    # Reference value computed in a single pass over the whole data set.
    loss_fn = torch.nn.CrossEntropyLoss()
    full_loss = loss_fn(y_pred, y_true).mean()
    gt_value = torch.exp(full_loss).item()

    aggregated_value = metric.finalize(metric_state)
    assert aggregated_value == approx(gt_value, NUMERIC_PRECISION)
def test_perplexity():
    """Check Perplexity on a fixed two-sample, two-class example."""
    scores = torch.tensor([[0.2, 0.8], [0.9, 0.1]])
    targets = torch.tensor([0, 0])
    perplexity_metric = Perplexity()
    # Reference value; it equals exp of the mean cross-entropy when
    # ``scores`` are read as logits against ``targets``.
    expected_value = 2.022418975830078
    metric_test_case(scores, targets, perplexity_metric, expected_value)