def test_rejects_no_strings():
    """Reject a class in which none of the annotations are strings.

    Every field here is annotated with a real type (``int``, ``List``), so
    ``collect_vars`` must raise ``MacroError`` about non-string annotations.

    #SPC-asts.tst-no_strings
    """
    expected_message = "annotations must be strings"

    class DummyClass:
        var1: int
        var2: List

    with raises(MacroError, match=expected_message):
        collect_vars(DummyClass)
def test_accepts_comparison():
    """Accept a chained-comparison annotation and parse it into a ``Compare`` node.

    #SPC-asts.tst-comparison
    """
    class DummyClass:
        var: "0 < var < 1"

    items = collect_vars(DummyClass)
    comp_node = parse(items[0])
    # isinstance is the idiomatic type check; an exact ``type(...) is``
    # comparison would needlessly reject any subclass of Compare.
    assert isinstance(comp_node, Compare)
def test_rejects_malformed_annotation(annotation):
    """Reject an annotation string that is not a range comparison.

    The supplied ``annotation`` is interpolated into a string annotation.
    ``collect_vars`` is expected to accept it, but ``parse`` must raise
    ``MacroError``.

    #SPC-asts.tst-not_comparison
    """
    class DummyClass:
        var: f"{annotation}"

    collected = collect_vars(DummyClass)
    with raises(MacroError):
        parse(collected[0])
def test_accepts_mixed_annotations():
    """Collect only the string annotations when strings and types are mixed.

    #SPC-asts.tst-mixed_strings
    """
    class DummyClass:
        var1: "arbitrary string"
        var2: int

    collected = collect_vars(DummyClass)
    # Only ``var1`` carries a string annotation, so exactly one item is expected.
    assert len(collected) == 1
def test_rejects_equal_endpoints(endpoint):
    """Reject a range whose lower and upper endpoints are equal.

    For ``x < var < x``, ``extract_endpoints`` must raise ``MacroError``
    with a "must be less than" message.

    #SPC-asts.tst-equal
    """
    # Non-finite endpoints are rejected with a different message
    # ("is not a valid range endpoint"), which would not match below.
    # Filter out nan as well as inf, like the other endpoint tests do.
    assume(not isinf(endpoint))
    assume(not isnan(endpoint))

    class DummyClass:
        var: f"{endpoint} < var < {endpoint}"

    items = collect_vars(DummyClass)
    comp_node = parse(items[0])
    with raises(MacroError, match="must be less than"):
        extract_endpoints(comp_node)
def test_rejects_inf_nan(endpoints):
    """Reject a range in which at least one endpoint is infinite or NaN.

    #SPC-asts.tst-rejects_inf_nan
    """
    # Only keep examples where some endpoint is non-finite.
    assume(any(isnan(value) or isinf(value) for value in endpoints))
    lower, upper = endpoints[0], endpoints[1]

    class DummyClass:
        var: f"{lower} < var < {upper}"

    comp_node = parse(collect_vars(DummyClass)[0])
    with raises(MacroError, match="is not a valid range endpoint"):
        extract_endpoints(comp_node)
def test_rejects_out_of_order_endpoints(endpoints):
    """Reject a range whose endpoints are written in descending order.

    The two endpoints are deliberately swapped, so the annotation reads
    ``larger < var < smaller``. ``extract_endpoints`` must then raise
    ``MacroError`` with a "must be less than" message.

    #SPC-asts.tst-order
    """
    lower = endpoints[1]  # note that this is backwards!
    upper = endpoints[0]
    # Non-finite values are rejected with a different message; skip them.
    for e in endpoints:
        assume(not isinf(e))
        assume(not isnan(e))
    # Require a strictly descending pair instead of merely a distinct one,
    # so the test does not silently depend on ``endpoints`` arriving sorted.
    assume(lower > upper)

    class DummyClass:
        var: f"{lower} < var < {upper}"

    items = collect_vars(DummyClass)
    comp_node = parse(items[0])
    with raises(MacroError, match="must be less than"):
        extract_endpoints(comp_node)
def test_accepts_valid_int_endpoints(endpoints):
    """Accept a finite, ascending range and recover both endpoints exactly.

    #SPC-asts.tst-valid_ints
    """
    lower = endpoints[0]
    upper = endpoints[1]
    # Non-finite values are rejected by extract_endpoints; skip them.
    for e in endpoints:
        assume(not isinf(e))
        assume(not isnan(e))
    # Require a strictly ascending pair instead of merely a distinct one,
    # so the test does not silently depend on ``endpoints`` arriving sorted.
    assume(lower < upper)

    class DummyClass:
        var: f"{lower} < var < {upper}"

    items = collect_vars(DummyClass)
    comp_node = parse(items[0])
    ext_lower, ext_upper = extract_endpoints(comp_node)
    assert lower == ext_lower
    assert upper == ext_upper