Example #1
0
def test_aggregation_perplexity():
    """Check that chunked aggregation of Perplexity matches a one-shot reference.

    Random logits and labels are fed to ``Perplexity.aggregate`` in chunks of
    10 samples. The finalized value is then compared against exp of the
    cross-entropy computed over the whole batch in a single call.
    """
    # A local generator makes the inputs reproducible without reseeding the
    # global RNG that other tests may rely on.
    generator = torch.Generator().manual_seed(0)
    metric_state = {}
    metric = Perplexity()
    y_pred = torch.randn(size=(100, 50), generator=generator)
    y_true = torch.randint(0, 50, (100, ), generator=generator)
    for yp, yt in zip(torch.split(y_pred, 10), torch.split(y_true, 10)):
        metric.aggregate(metric_state, yp, yt)
    # CrossEntropyLoss averages over the batch by default (reduction="mean"),
    # so its result is already the mean loss; no extra .mean() is needed.
    gt_value = torch.exp(torch.nn.CrossEntropyLoss()(y_pred, y_true)).item()
    assert metric.finalize(metric_state) == approx(gt_value, NUMERIC_PRECISION)
Example #2
0
def test_perplexity():
    """Test Perplexity on a tiny, hand-checkable batch of two samples.

    The expected constant equals exp(mean cross-entropy) of these logits
    against the labels, i.e. torch.exp(F.cross_entropy(logits, targets)).
    """
    logits = torch.tensor([[0.2, 0.8], [0.9, 0.1]])
    targets = torch.tensor([0, 0])
    expected = 2.022418975830078
    metric_test_case(logits, targets, Perplexity(), expected)