def test_metric_collection(tmpdir):
    """Check MetricCollection structure, update/compute/reset behaviour and pickling.

    Builds a collection from a ``DummyMetricSum`` and a ``DummyMetricDiff`` and
    verifies that each lifecycle step (init, update, compute, reset) is applied
    to every metric in the collection.
    """
    m1 = DummyMetricSum()
    m2 = DummyMetricDiff()

    metric_collection = MetricCollection([m1, m2])

    # Test correct dict structure: entries are keyed by class name and are the
    # very same objects that were passed in. Use identity (`is`) rather than
    # `==`, because metric objects may overload `__eq__` to return a (truthy)
    # composed metric, which would make an `==` assertion vacuous.
    assert len(metric_collection) == 2
    assert metric_collection['DummyMetricSum'] is m1
    assert metric_collection['DummyMetricDiff'] is m2

    # Test correct initialization
    for name, metric in metric_collection.items():
        assert metric.x == 0, f'Metric {name} not initialized correctly'

    # Test every metric gets updated; abs() because the "diff" metric
    # accumulates with the opposite sign to the "sum" metric.
    metric_collection.update(5)
    for name, metric in metric_collection.items():
        assert metric.x.abs() == 5, f'Metric {name} not updated correctly'

    # Test compute on each metric: +5 followed by -5 should net to zero.
    metric_collection.update(-5)
    metric_vals = metric_collection.compute()
    assert len(metric_vals) == 2
    for name, metric_val in metric_vals.items():
        assert metric_val == 0, f'Metric {name}.compute not called correctly'

    # Test that everything is reset
    for name, metric in metric_collection.items():
        assert metric.x == 0, f'Metric {name} not reset correctly'

    # Test picklable
    metric_pickled = pickle.dumps(metric_collection)
    metric_loaded = pickle.loads(metric_pickled)

    assert isinstance(metric_loaded, MetricCollection)
def test_device_and_dtype_transfer_metriccollection(tmpdir):
    """Check that device moves and dtype casts propagate to every metric's state.

    The CUDA transfer is only exercised when a GPU is available, so the
    dtype-cast checks (``double``/``half``) still run on CPU-only machines
    instead of the whole test erroring out on ``.to(device='cuda')``.
    """
    m1 = DummyMetricSum()
    m2 = DummyMetricDiff()

    metric_collection = MetricCollection([m1, m2])
    for metric in metric_collection.values():
        assert metric.x.is_cuda is False
        assert metric.x.dtype == torch.float32

    # Device transfer requires a GPU; skip just this part when none is present.
    if torch.cuda.is_available():
        metric_collection = metric_collection.to(device='cuda')
        for metric in metric_collection.values():
            assert metric.x.is_cuda

    metric_collection = metric_collection.double()
    for metric in metric_collection.values():
        assert metric.x.dtype == torch.float64

    metric_collection = metric_collection.half()
    for metric in metric_collection.values():
        assert metric.x.dtype == torch.float16