# Example No. 1
# 0
def test_load_state_dict(tmpdir):
    """Test that metric states can be loaded with a state dict.

    The source metric is marked ``persistent(True)`` (presumably so that its
    states are included in ``state_dict()``), updated once, and its state dict
    is loaded into a freshly constructed metric. The *loaded* metric must then
    compute the same value as the source.
    """
    metric = DummyMetricSum()
    metric.persistent(True)
    metric.update(5)
    loaded_metric = DummyMetricSum()
    loaded_metric.load_state_dict(metric.state_dict())
    # Check the metric that received the state dict; checking ``metric``
    # itself would pass even if loading did nothing.
    assert loaded_metric.compute() == 5
# Example No. 2
# 0
def test_collection_add_metrics():
    """Metrics added after construction are updated and computed with the rest.

    Covers adding via a dict (explicit key ``'m1_'``) and via a bare metric
    instance (looked up afterwards by its class name).
    """
    base_sum = DummyMetricSum()
    diff = DummyMetricDiff()

    collection = MetricCollection([base_sum])
    collection.add_metrics({'m1_': DummyMetricSum()})
    collection.add_metrics(diff)

    collection.update(5)
    out = collection.compute()
    # chained comparison == (sum == m1_) and (m1_ == 5)
    assert out['DummyMetricSum'] == out['m1_'] == 5
    assert out['DummyMetricDiff'] == -5
# Example No. 3
# 0
def test_metric_collection_args_kwargs(tmpdir):
    """Check that args and kwargs are routed correctly inside a collection.

    Both ``update`` and ``forward`` (calling the collection) are exercised.
    Positional arguments go to every metric. Keyword arguments only reach the
    metrics they match: with ``x=10, y=20`` the sum metric ends at ``10``
    and the diff metric at ``-20``.
    """
    sum_metric = DummyMetricSum()
    diff_metric = DummyMetricDiff()

    collection = MetricCollection([sum_metric, diff_metric])

    def expect_states(sum_x, diff_x):
        # compare the ``x`` state of each member against the expected values
        assert collection['DummyMetricSum'].x == sum_x
        assert collection['DummyMetricDiff'].x == diff_x

    # positional args are forwarded to all metrics
    collection.update(5)
    expect_states(5, -5)
    collection.reset()
    _ = collection(5)
    expect_states(5, -5)
    collection.reset()

    # kwargs only reach the metrics they match
    collection.update(x=10, y=20)
    expect_states(10, -20)
    collection.reset()
    _ = collection(x=10, y=20)
    expect_states(10, -20)
# Example No. 4
# 0
def test_metric_collection_same_order():
    """Key order of a collection must not follow the input dict's insertion order."""
    sum_metric, diff_metric = DummyMetricSum(), DummyMetricDiff()
    forward_order = MetricCollection({"a": sum_metric, "b": diff_metric})
    reverse_order = MetricCollection({"b": diff_metric, "a": sum_metric})
    for left_key, right_key in zip(forward_order.keys(), reverse_order.keys()):
        assert left_key == right_key
# Example No. 5
# 0
def test_metric_collection(tmpdir):
    """End-to-end behaviour of a ``MetricCollection`` built from a list.

    Covers: keys derived from class names, initial state, ``update`` fan-out
    to every member, ``compute`` results, member state after ``compute``, and
    picklability of the whole collection.
    """
    m1 = DummyMetricSum()
    m2 = DummyMetricDiff()

    metric_collection = MetricCollection([m1, m2])

    # Test correct dict structure. Compare by identity: metrics overload
    # operators (see the ``+`` composition elsewhere in this file), so ``==``
    # may return a truthy compositional metric and make the assertion vacuous.
    assert len(metric_collection) == 2
    assert metric_collection['DummyMetricSum'] is m1
    assert metric_collection['DummyMetricDiff'] is m2

    # Test correct initialization
    for name, metric in metric_collection.items():
        assert metric.x == 0, f'Metric {name} not initialized correctly'

    # Test every metric gets updated (sum -> +5, diff -> -5)
    metric_collection.update(5)
    for name, metric in metric_collection.items():
        assert metric.x.abs() == 5, f'Metric {name} not updated correctly'

    # Test compute on each metric
    metric_collection.update(-5)
    metric_vals = metric_collection.compute()
    assert len(metric_vals) == 2
    for name, metric_val in metric_vals.items():
        assert metric_val == 0, f'Metric {name}.compute not called correctly'

    # Test that everything is reset.
    # NOTE(review): after update(5) and update(-5) both states are already 0,
    # so this loop cannot distinguish a reset from no reset.
    for name, metric in metric_collection.items():
        assert metric.x == 0, f'Metric {name} not reset correctly'

    # Test pickable
    metric_pickled = pickle.dumps(metric_collection)
    metric_loaded = pickle.loads(metric_pickled)
    assert isinstance(metric_loaded, MetricCollection)
# Example No. 6
# 0
def test_reset_compute():
    """After ``reset`` the metric computes its default value (0) again."""
    summed = DummyMetricSum()
    assert summed.x == 0
    summed.update(tensor(5))
    assert summed.compute() == 5
    summed.reset()
    assert summed.compute() == 0
# Example No. 7
# 0
def test_metric_forward_cache_reset():
    """Calling ``reset`` clears the value cached by the most recent forward."""
    dummy = DummyMetricSum()
    dummy(2.0)
    assert dummy._forward_cache == 2.0
    dummy.reset()
    assert dummy._forward_cache is None
# Example No. 8
# 0
def _test_ddp_compositional_tensor(rank, worldsize):
    """Per-process body (presumably run once per rank) for a composed metric.

    Each process adds ``1`` to a sum of two cloned metrics whose ``x`` state
    is reduced with ``torch.sum``. The synced result is therefore expected to
    be ``2 * worldsize``.
    """
    setup_ddp(rank, worldsize)
    template = DummyMetricSum()
    template._reductions = {"x": torch.sum}
    composed = template.clone() + template.clone()
    composed.update(tensor(1))
    assert composed.compute() == 2 * worldsize
# Example No. 9
# 0
def test_reset_compute():
    """``reset`` restores the default state; the expected post-reset value
    depends on the installed Lightning version."""
    summed = DummyMetricSum()
    assert summed.x == 0
    summed.update(tensor(5))
    assert summed.compute() == 5
    summed.reset()
    # Without Lightning, or with Lightning >= 1.3, compute after reset gives 0;
    # older Lightning apparently keeps returning the previously computed 5.
    modern_behaviour = not _LIGHTNING_AVAILABLE or _LIGHTNING_GREATER_EQUAL_1_3
    expected = 0 if modern_behaviour else 5
    assert summed.compute() == expected
# Example No. 10
# 0
def test_metric_collection_prefix_postfix_args(prefix, postfix):
    """Test that the prefix/postfix args alter the keywords in the output.

    Checks ``forward``, ``compute``, and ``clone`` with a new prefix and then
    a new postfix. It also checks that ``keys``/``items`` with
    ``keep_base=True`` return the un-prefixed names, using the same container
    type as ``keep_base=False``.
    """
    m1 = DummyMetricSum()
    m2 = DummyMetricDiff()
    names = ['DummyMetricSum', 'DummyMetricDiff']
    names = [prefix + n if prefix is not None else n for n in names]
    names = [n + postfix if postfix is not None else n for n in names]

    metric_collection = MetricCollection([m1, m2],
                                         prefix=prefix,
                                         postfix=postfix)

    # test forward
    out = metric_collection(5)
    for name in names:
        assert name in out, 'prefix or postfix argument not working as intended with forward method'

    # test compute
    out = metric_collection.compute()
    for name in names:
        assert name in out, 'prefix or postfix argument not working as intended with compute method'

    # test clone
    new_metric_collection = metric_collection.clone(prefix='new_prefix_')
    out = new_metric_collection(5)
    names = [n[len(prefix):] if prefix is not None else n
             for n in names]  # strip away old prefix
    for name in names:
        assert f"new_prefix_{name}" in out, 'prefix argument not working as intended with clone method'

    for k, _ in new_metric_collection.items():
        assert 'new_prefix_' in k

    for k in new_metric_collection.keys():
        assert 'new_prefix_' in k

    for k, _ in new_metric_collection.items(keep_base=True):
        assert 'new_prefix_' not in k

    for k in new_metric_collection.keys(keep_base=True):
        assert 'new_prefix_' not in k

    # Identity comparison of types: no E721 suppression needed.
    assert type(new_metric_collection.keys(keep_base=True)) is type(
        new_metric_collection.keys(keep_base=False))
    assert type(new_metric_collection.items(keep_base=True)) is type(
        new_metric_collection.items(keep_base=False))

    new_metric_collection = new_metric_collection.clone(postfix='_new_postfix')
    out = new_metric_collection(5)
    # Strip away the old postfix. Test truthiness, not ``is not None``:
    # for ``postfix == ''`` the slice ``n[:-0]`` would be the empty string.
    names = [n[:-len(postfix)] if postfix else n
             for n in names]
    for name in names:
        assert f"new_prefix_{name}_new_postfix" in out, 'postfix argument not working as intended with clone method'
# Example No. 11
# 0
def test_device_and_dtype_transfer(tmpdir):
    """Metric states follow ``.to(device)``, ``.double()`` and ``.half()``.

    The CPU defaults are always checked. The CUDA/dtype part is skipped when
    no GPU is available, instead of crashing on ``.to('cuda')``.
    """
    metric = DummyMetricSum()
    assert metric.x.is_cuda is False
    assert metric.x.dtype == torch.float32

    if not torch.cuda.is_available():
        pytest.skip("Test requires GPU support")

    metric = metric.to(device='cuda')
    assert metric.x.is_cuda

    metric = metric.double()
    assert metric.x.dtype == torch.float64

    metric = metric.half()
    assert metric.x.dtype == torch.float16
# Example No. 12
# 0
def test_constant_memory(device, requires_grad):
    """Checks that when updating a metric the memory does not increase.

    Repeated ``update`` calls and repeated ``forward`` calls are each checked
    against a baseline taken after one warm-up call, allowing 5% measurement
    noise.
    """
    if not torch.cuda.is_available() and device == "cuda":
        pytest.skip("Test requires GPU support")

    def get_memory_usage():
        if device == "cpu":
            # first field of psutil's memory_info (rss), converted to GiB
            return psutil.Process(os.getpid()).memory_info()[0] / 2.0**30
        # bytes currently allocated by tensors on the CUDA device
        return torch.cuda.memory_allocated()

    def assert_constant_memory(step, value):
        """Call ``step(value)`` repeatedly; memory must stay below the baseline."""
        # we allow for 5% fluctuation due to measuring
        base_memory_level = 1.05 * get_memory_usage()
        for _ in range(10):
            step(value)
            memory = get_memory_usage()
            assert base_memory_level >= memory, "memory increased above base level"

    x = torch.randn(10, requires_grad=requires_grad, device=device)

    # try update method
    metric = DummyMetricSum().to(device)
    metric.update(x.sum())
    assert_constant_memory(metric.update, x.sum())

    # try forward method: the loop must also go through forward; looping on
    # ``update`` here would only repeat the first check
    metric = DummyMetricSum().to(device)
    metric(x.sum())
    assert_constant_memory(metric, x.sum())
# Example No. 13
# 0
def test_pickle(tmpdir):
    """A metric survives a pickle and a cloudpickle round trip with its state.

    Distributed (DDP) pickling is not covered here.
    """
    original = DummyMetricSum()
    original.update(1)

    restored = pickle.loads(pickle.dumps(original))
    assert restored.compute() == 1
    # the restored copy keeps accumulating from the pickled state
    restored.update(5)
    assert restored.compute() == 6

    restored = cloudpickle.loads(cloudpickle.dumps(original))
    assert restored.compute() == 1
# Example No. 14
# 0
def test_metric_collection_wrong_input(tmpdir):
    """Check that errors and warnings are raised on wrong input.

    ``match`` patterns are regular expressions (applied with ``re.search``),
    hence raw strings with ``.*`` wildcards and an escaped literal dot.
    """
    dms = DummyMetricSum()

    # Not all input are metrics (list)
    with pytest.raises(ValueError):
        _ = MetricCollection([dms, 5])

    # Not all input are metrics (dict)
    with pytest.raises(ValueError):
        _ = MetricCollection({'metric1': dms, 'metric2': 5})

    # Same metric passed in multiple times
    with pytest.raises(ValueError, match=r'Encountered two metrics both named .*'):
        _ = MetricCollection([dms, dms])

    # Extra positional argument that is not a `Metric` (here a list): it is
    # ignored with a warning rather than raising
    with pytest.warns(Warning, match=r' which are not `Metric` so they will be ignored\.'):
        _ = MetricCollection(dms, [dms])
# Example No. 15
# 0
def test_metric_collection_wrong_input(tmpdir):
    """Check that errors are raised on wrong input.

    ``match`` patterns are regular expressions (applied with ``re.search``),
    hence raw strings with ``.*`` wildcards and escaped literal dots.
    """
    m1 = DummyMetricSum()

    # Not all input are metrics (list)
    with pytest.raises(ValueError):
        _ = MetricCollection([m1, 5])

    # Not all input are metrics (dict)
    with pytest.raises(ValueError):
        _ = MetricCollection({'metric1': m1, 'metric2': 5})

    # Same metric passed in multiple times
    with pytest.raises(ValueError, match=r'Encountered two metrics both named .*'):
        _ = MetricCollection([m1, m1])

    # Not a list or dict passed in
    with pytest.raises(ValueError, match=r'Unknown input to MetricCollection\.'):
        _ = MetricCollection(m1)
# Example No. 16
# 0
def test_device_and_dtype_transfer(tmpdir):
    """Metric states and ``metric.device`` follow device moves and ``set_dtype``.

    ``set_dtype`` must also survive ``reset``. The CPU defaults are always
    checked; the CUDA part is skipped when no GPU is available.
    """
    metric = DummyMetricSum()
    assert metric.x.is_cuda is False
    assert metric.device == torch.device("cpu")
    assert metric.x.dtype == torch.float32

    if not torch.cuda.is_available():
        pytest.skip("Test requires GPU support")

    metric = metric.to(device="cuda")
    assert metric.x.is_cuda
    assert metric.device == torch.device("cuda", index=0)

    metric.set_dtype(torch.double)
    assert metric.x.dtype == torch.float64
    metric.reset()
    assert metric.x.dtype == torch.float64

    metric.set_dtype(torch.half)
    assert metric.x.dtype == torch.float16
    metric.reset()
    assert metric.x.dtype == torch.float16
# Example No. 17
# 0
def test_warning_on_compute_before_update():
    """Test that a warning is raised if ``compute`` is called before ``update``.

    ``pytest.warns(None)`` is deprecated (an error in pytest 8). To assert
    "no warnings" we record everything with ``warnings.catch_warnings``.
    """
    import warnings

    metric = DummyMetricSum()

    # make sure everything is fine with forward
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        metric(1)
    assert not record

    metric.reset()

    with pytest.warns(UserWarning, match=r'The ``compute`` method of metric .*'):
        val = metric.compute()
    assert val == 0.0

    # after update things should be fine
    metric.update(2.0)
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        val = metric.compute()
    assert not record
    assert val == 2.0
# Example No. 18
# 0
def test_device_and_dtype_transfer_metriccollection(tmpdir):
    """Device and dtype moves on a collection propagate to every member.

    The CPU defaults are always checked; the CUDA part is skipped when no GPU
    is available, instead of crashing on ``.to('cuda')``.
    """
    m1 = DummyMetricSum()
    m2 = DummyMetricDiff()

    metric_collection = MetricCollection([m1, m2])
    for _, metric in metric_collection.items():
        assert metric.x.is_cuda is False
        assert metric.x.dtype == torch.float32

    if not torch.cuda.is_available():
        pytest.skip("Test requires GPU support")

    metric_collection = metric_collection.to(device='cuda')
    for _, metric in metric_collection.items():
        assert metric.x.is_cuda

    metric_collection = metric_collection.double()
    for _, metric in metric_collection.items():
        assert metric.x.dtype == torch.float64

    metric_collection = metric_collection.half()
    for _, metric in metric_collection.items():
        assert metric.x.dtype == torch.float16
# Example No. 19
# 0
def test_warning_on_compute_before_update():
    """Test that a warning is raised if the user calls compute before update.

    ``pytest.warns(None)`` is deprecated (an error in pytest 8). To assert
    "no warnings" we record everything with ``warnings.catch_warnings``.
    """
    import warnings

    metric = DummyMetricSum()

    # make sure everything is fine with forward
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        metric(1)
    assert not record

    metric.reset()

    with pytest.warns(UserWarning, match=r"The ``compute`` method of metric .*"):
        val = metric.compute()
    assert val == 0.0

    # after update things should be fine
    metric.update(2.0)
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        val = metric.compute()
    assert not record
    assert val == 2.0
# Example No. 20
# 0
def test_metric_collection_prefix_arg(tmpdir):
    """The ``prefix`` argument is prepended to every output key.

    Checked for ``forward``, ``compute`` and a ``clone`` with a new prefix.
    """
    sum_metric = DummyMetricSum()
    diff_metric = DummyMetricDiff()
    base_names = ['DummyMetricSum', 'DummyMetricDiff']

    collection = MetricCollection([sum_metric, diff_metric], prefix='prefix_')

    def require_prefixed(output, expected_prefix, message):
        # every base name must appear in ``output`` with the given prefix
        for base in base_names:
            assert f"{expected_prefix}{base}" in output, message

    # test forward
    require_prefixed(collection(5), "prefix_",
                     'prefix argument not working as intended with forward method')

    # test compute
    require_prefixed(collection.compute(), "prefix_",
                     'prefix argument not working as intended with compute method')

    # test clone
    cloned = collection.clone(prefix='new_prefix_')
    require_prefixed(cloned(5), "new_prefix_",
                     'prefix argument not working as intended with clone method')
# Example No. 21
# 0
def test_metric_scripts():
    """Both dummy metric classes can be compiled with ``torch.jit.script``."""
    for metric_cls in (DummyMetric, DummyMetricSum):
        torch.jit.script(metric_cls())
# Example No. 22
# 0
def test_metric_forward_cache_reset():
    """``reset`` empties the cache populated by a forward call."""
    summed = DummyMetricSum()
    summed(2.0)
    assert summed._forward_cache == 2.0
    summed.reset()
    assert summed._forward_cache is None
# Example No. 23
# 0
def test_metric_scripts():
    """Test that metrics are scriptable with TorchScript."""
    scriptable = (DummyMetric, DummyMetricSum)
    for cls in scriptable:
        torch.jit.script(cls())