def _test(p, r, average, n_epochs):
    """Check a distributed ``Fbeta(beta=2.5)`` run against ``fbeta_score``.

    Random labels and scores are generated for ``world_size * 60 * 16``
    samples; each engine iteration feeds one 16-sample slice taken from the
    region of the data that starts at ``rank * 60 * 16``. The value stored
    under ``"f2.5"`` in the engine metrics is then compared with the score
    computed over the *full* ``y_true`` / ``y_preds`` tensors.

    Args:
        p: Not referenced in this body.
        r: Not referenced in this body.
        average: Forwarded to ``Fbeta``; when truthy the reference score uses
            ``average='macro'``, otherwise ``average=None``.
        n_epochs: Passed to ``engine.run`` as ``max_epochs``.

    NOTE(review): the comparison against the full-data reference presumably
    relies on the metric reducing across processes and on every process
    generating the same random tensors (seeded outside this function) --
    confirm against the caller.
    """
    n_iters = 60
    batch_size = 16
    n_classes = 7

    # Number of samples consumed by a single process over one epoch.
    per_rank = n_iters * batch_size
    y_true = torch.randint(0, n_classes, size=(per_rank * dist.get_world_size(), )).to(device)
    y_preds = torch.rand(per_rank * dist.get_world_size(), n_classes).to(device)

    def update(engine, i):
        # Slice bounds of batch ``i`` inside this rank's region of the data.
        start = i * batch_size + rank * per_rank
        stop = (i + 1) * batch_size + rank * per_rank
        return y_preds[start:stop, :], y_true[start:stop]

    engine = Engine(update)

    fbeta = Fbeta(beta=2.5, average=average, device=device)
    fbeta.attach(engine, "f2.5")

    engine.run(data=list(range(n_iters)), max_epochs=n_epochs)

    assert "f2.5" in engine.state.metrics
    res = engine.state.metrics['f2.5']
    if isinstance(res, torch.Tensor):
        res = res.cpu().numpy()

    # Reference value: predicted class is the arg-max over the score columns.
    expected_labels = y_true.cpu().numpy()
    predicted_labels = torch.argmax(y_preds, dim=1).cpu().numpy()
    sk_average = 'macro' if average else None
    true_res = fbeta_score(expected_labels, predicted_labels, beta=2.5, average=sk_average)

    assert pytest.approx(res) == true_res
def _test(p, r, average, output_transform):
    """Run ``Fbeta(beta=2.0)`` through an ``Engine`` and compare with ``fbeta_score``.

    Builds 100 samples over 10 classes with a fixed NumPy seed: labels cycle
    through ``0..9`` and each score row starts as small noise in ``[0, 0.2)``.
    For every row one random draw decides whether the true class gets score
    ``1.0`` or a randomly chosen class gets ``0.7``. The data is fed to the
    engine in 10 batches of 10 and the resulting ``"f2"`` metric is compared
    with the score computed on the arg-max predictions.

    Args:
        p: Passed to ``Fbeta`` as ``precision``.
        r: Passed to ``Fbeta`` as ``recall``.
        average: Forwarded to ``Fbeta``; when truthy the reference score uses
            ``average="macro"``, otherwise ``average=None``.
        output_transform: Forwarded to ``Fbeta``. When it is not ``None`` the
            update function returns a dict with keys ``"y_pred"`` and ``"y"``
            instead of a ``(y_pred, y)`` tuple.
    """
    np.random.seed(1)

    n_iters = 10
    batch_size = 10
    n_classes = 10
    n_samples = n_iters * batch_size

    y_true = np.arange(0, n_samples) % n_classes
    y_pred = 0.2 * np.random.rand(n_samples, n_classes)

    for row, label in enumerate(y_true):
        if np.random.rand() > 0.4:
            # Confident score on the true class.
            y_pred[row, label] = 1.0
            continue
        # Otherwise boost a random class (which may or may not be the true one).
        y_pred[row, np.random.randint(0, n_classes)] = 0.7

    # One (labels, scores) pair per engine iteration.
    batches = zip(
        y_true.reshape(n_iters, batch_size),
        y_pred.reshape(n_iters, batch_size, n_classes),
    )

    def update_fn(engine, batch):
        labels, scores = next(batches)
        if output_transform is None:
            return torch.from_numpy(scores), torch.from_numpy(labels)
        return {
            "y_pred": torch.from_numpy(scores),
            "y": torch.from_numpy(labels),
        }

    evaluator = Engine(update_fn)

    f2 = Fbeta(
        beta=2.0,
        average=average,
        precision=p,
        recall=r,
        output_transform=output_transform,
    )
    f2.attach(evaluator, "f2")

    state = evaluator.run(list(range(n_iters)), max_epochs=1)

    sk_average = "macro" if average else None
    f2_true = fbeta_score(y_true, np.argmax(y_pred, axis=-1), average=sk_average, beta=2.0)

    computed = state.metrics["f2"]
    if isinstance(computed, torch.Tensor):
        # Tensor result: compare element-wise.
        np.testing.assert_allclose(f2_true, computed.numpy())
    else:
        assert f2_true == pytest.approx(computed), "{} vs {}".format(f2_true, computed)