예제 #1
0
def test_zip_strict():
    """Compare built-in ``zip`` and ``zip_strict`` on unequal-length inputs.

    Built-in ``zip`` is expected to iterate without complaint when the two
    lists differ in length, whereas ``zip_strict`` is expected to raise
    ``ValueError``. With equal-length inputs ``zip_strict`` must iterate
    cleanly.
    """
    shorter = [0, 1]
    longer = [1, 2, 3]

    # Plain zip: the length mismatch goes unnoticed, no exception expected.
    for _left, _right in zip(shorter, longer):
        pass

    # zip_strict: the same mismatch must surface as a ValueError.
    with pytest.raises(ValueError):
        for _left, _right in zip_strict(shorter, longer):
            pass

    # Trim the longer list to the shorter one's length; with matching
    # lengths zip_strict must not raise.
    trimmed = longer[:len(shorter)]
    for _left, _right in zip_strict(shorter, trimmed):
        pass
def params_should_differ(params, other_params):
    """Assert that every pair of corresponding parameters differs.

    The two iterables are walked in lockstep with ``zip_strict`` (expected
    to raise ``ValueError`` when their lengths differ). The assertion fails
    on the first pair for which ``th.allclose`` reports the values as close.

    :param params: iterable of parameters (presumably torch tensors, since
        they are passed to ``th.allclose`` -- confirm against callers)
    :param other_params: iterable of parameters to compare against
    """
    pairs = zip_strict(params, other_params)
    for left, right in pairs:
        assert not th.allclose(left, right)
def params_should_match(params, other_params):
    """Assert that every pair of corresponding parameters is close.

    The two iterables are walked in lockstep with ``zip_strict`` (expected
    to raise ``ValueError`` when their lengths differ). The assertion fails
    on the first pair for which ``th.allclose`` reports a mismatch.

    :param params: iterable of parameters (presumably torch tensors, since
        they are passed to ``th.allclose`` -- confirm against callers)
    :param other_params: iterable of parameters to compare against
    """
    pairs = zip_strict(params, other_params)
    for left, right in pairs:
        assert th.allclose(left, right)