def _test_score(
    metric: AccumulationMetric,
    batch: Dict[str, torch.Tensor],
    true_values: Dict[str, float],
) -> None:
    """Feed one batch into ``metric`` and compare its key-value output to expectations.

    Args:
        metric: metric under test. It is reset for exactly one batch, and the
            sample count is taken from ``len(batch["embeddings"])``.
        batch: keyword arguments forwarded verbatim to ``metric.update``.
        true_values: expected entries of ``metric.compute_key_value()``. Each key
            must be present in the output and compare equal with ``==``.
    """
    sample_count = len(batch["embeddings"])
    metric.reset(num_batches=1, num_samples=sample_count)
    metric.update(**batch)
    computed = metric.compute_key_value()
    for name in true_values:
        expected = true_values[name]
        assert name in computed
        assert computed[name] == expected
def test_accumulation_reset(generate_batched_data):  # noqa: WPS442
    """Verify AccumulationMetric re-accumulates identical data after repeated resets.

    For every generated case, the same metric instance goes through several
    reset -> update-all-batches cycles. After each cycle, its storage must match
    the expected per-field values exactly.
    """
    n_cycles = 5
    for case in generate_batched_data:
        fields_names, num_batches, num_samples, batches, true_values = case
        accumulator = AccumulationMetric(accumulative_fields=fields_names)
        for _ in range(n_cycles):
            accumulator.reset(num_batches=num_batches, num_samples=num_samples)
            for batch in batches:
                accumulator.update(**batch)
            # Checked inside the loop: a reset that leaks state from the previous
            # cycle must be caught on that cycle, not only on the last one.
            for field_name in true_values:
                expected = true_values[field_name]
                assert (expected == accumulator.storage[field_name]).all()
def test_accumulation(generate_batched_data) -> None:  # noqa: WPS442
    """Verify AccumulationMetric collects every batch of a single loader run.

    For every generated case, a fresh metric is reset once and then receives all
    batches. Its storage must then equal the expected per-field values exactly.
    """
    for case in generate_batched_data:
        fields_names, num_batches, num_samples, batches, true_values = case
        accumulator = AccumulationMetric(accumulative_fields=fields_names)
        accumulator.reset(num_batches=num_batches, num_samples=num_samples)
        for batch in batches:
            accumulator.update(**batch)
        for field_name in true_values:
            expected = true_values[field_name]
            assert (expected == accumulator.storage[field_name]).all()
def test_accumulation_dtype() -> None:
    """Check that AccumulationMetric stores each field with its original dtype, shape and values.

    A single batch with integer, boolean and float32 fields is accumulated, and
    each stored field is compared with its input tensor.
    """
    batch_size = 10
    batch = {
        "field_int": torch.randint(low=0, high=5, size=(batch_size, 5)),
        "field_bool": torch.randint(low=0, high=2, size=(batch_size, 10), dtype=torch.bool),
        "field_float32": torch.rand(size=(batch_size, 4), dtype=torch.float32),
    }
    metric = AccumulationMetric(accumulative_fields=list(batch.keys()))
    metric.reset(num_samples=batch_size, num_batches=1)
    metric.update(**batch)
    for key, expected in batch.items():
        stored = metric.storage[key]
        # dtype and shape are checked before values: ``==`` type-promotes and
        # broadcasts, so an element-wise match alone cannot detect a wrong dtype
        # or a wrongly shaped storage buffer.
        assert stored.dtype == expected.dtype
        assert stored.shape == expected.shape
        assert (expected == stored).all()