def test_rejects_no_strings():
    """#SPC-asts.tst-no_strings

    Expect ``collect_vars`` to raise ``MacroError`` for a class in which
    none of the annotations is a string.
    """
    class DummyClass:
        var1: int
        var2: List

    expected_message = "annotations must be strings"
    with raises(MacroError, match=expected_message):
        collect_vars(DummyClass)
def test_accepts_comparison():
    """#SPC-asts.tst-comparison

    Expect the first item collected from a class annotated with a chained
    comparison string to parse into a node whose type is exactly ``Compare``.
    """
    class DummyClass:
        var: "0 < var < 1"

    collected = collect_vars(DummyClass)
    first_item = collected[0]
    parsed_node = parse(first_item)
    assert type(parsed_node) is Compare
def test_rejects_malformed_annotation(annotation):
    """#SPC-asts.tst-not_comparison

    Expect ``parse`` to raise ``MacroError`` for the first item collected
    from a class whose only annotation is the formatted ``annotation`` value.

    ``annotation`` is supplied from outside this block (presumably a
    generated string that is not a valid comparison -- confirm against the
    test's decorator or fixture).
    """
    class DummyClass:
        var: f"{annotation}"

    collected = collect_vars(DummyClass)
    # Index inside the `raises` block so lookup and parsing happen together
    with raises(MacroError):
        parse(collected[0])
def test_accepts_mixed_annotations():
    """#SPC-asts.tst-mixed_strings

    Expect ``collect_vars`` to return exactly one item for a class that
    mixes one string annotation with one non-string annotation.
    """
    class DummyClass:
        var1: "arbitrary string"
        var2: int

    collected_count = len(collect_vars(DummyClass))
    assert collected_count == 1
def test_rejects_equal_endpoints(endpoint):
    """#SPC-asts.tst-equal

    Expect ``extract_endpoints`` to raise ``MacroError`` ("must be less
    than") when the lower and upper endpoints of the range are the same
    finite number.

    ``endpoint`` is supplied from outside this block (presumably a
    generated number -- confirm against the test's decorator).
    """
    # `inf` and `nan` endpoints are expected to fail with a different
    # message ("is not a valid range endpoint", see test_rejects_inf_nan),
    # so both must be filtered out or the `match` below cannot succeed.
    assume(not isinf(endpoint))
    assume(not isnan(endpoint))

    class DummyClass:
        var: f"{endpoint} < var < {endpoint}"

    items = collect_vars(DummyClass)
    comp_node = parse(items[0])
    with raises(MacroError, match="must be less than"):
        extract_endpoints(comp_node)
def test_rejects_inf_nan(endpoints):
    """#SPC-asts.tst-rejects_inf_nan

    Expect ``extract_endpoints`` to raise ``MacroError`` ("is not a valid
    range endpoint") when at least one endpoint of the range is `inf` or
    `nan`.

    ``endpoints`` is read at positions 0 and 1; its values are supplied from
    outside this block (presumably a generated pair -- confirm).
    """
    # Only keep inputs where at least one endpoint is `inf` or `nan`
    has_bad_endpoint = any(isnan(value) or isinf(value) for value in endpoints)
    assume(has_bad_endpoint)
    lower, upper = endpoints[0], endpoints[1]

    class DummyClass:
        var: f"{lower} < var < {upper}"

    collected = collect_vars(DummyClass)
    compare_node = parse(collected[0])
    expected_message = "is not a valid range endpoint"
    with raises(MacroError, match=expected_message):
        extract_endpoints(compare_node)
def test_rejects_out_of_order_endpoints(endpoints):
    """#SPC-asts.tst-order

    Expect ``extract_endpoints`` to raise ``MacroError`` ("must be less
    than") when the lower endpoint of the range is greater than the upper
    endpoint.

    ``endpoints`` is read at positions 0 and 1; its values are supplied from
    outside this block (presumably a generated pair -- confirm).
    """
    lower = endpoints[1]  # note that this is backwards!
    upper = endpoints[0]
    # Require a strictly reversed pair rather than merely unequal values:
    # if the input ever arrived with endpoints[1] < endpoints[0], the swap
    # above would produce a valid range and no error would be raised.
    assume(lower > upper)
    # `inf` and `nan` are expected to fail with a different message, so
    # they are excluded from this case.
    for e in endpoints:
        assume(not isinf(e))
        assume(not isnan(e))

    class DummyClass:
        var: f"{lower} < var < {upper}"

    items = collect_vars(DummyClass)
    comp_node = parse(items[0])
    with raises(MacroError, match="must be less than"):
        extract_endpoints(comp_node)
def test_accepts_valid_int_endpoints(endpoints):
    """#SPC-asts.tst-valid_ints

    Expect ``extract_endpoints`` to return the same lower and upper values
    that were formatted into the range annotation, for a finite and strictly
    increasing pair of endpoints.

    ``endpoints`` is read at positions 0 and 1; its values are supplied from
    outside this block (presumably a generated pair -- confirm).
    """
    lower = endpoints[0]
    upper = endpoints[1]
    # Require a strictly increasing pair rather than merely unequal values:
    # a reversed pair is expected to be rejected (see
    # test_rejects_out_of_order_endpoints), so it must not reach the
    # success-path assertions below.
    assume(lower < upper)
    # `inf` and `nan` are not valid endpoints and are excluded here.
    for e in endpoints:
        assume(not isinf(e))
        assume(not isnan(e))

    class DummyClass:
        var: f"{lower} < var < {upper}"

    items = collect_vars(DummyClass)
    comp_node = parse(items[0])
    ext_lower, ext_upper = extract_endpoints(comp_node)
    assert lower == ext_lower
    assert upper == ext_upper