def test_interpolation_ramps():
    """Interpolation(0, x, t) must equal t * x, both raw and after simplify()."""
    param = Value(0)
    for slope in range(-2, 5):
        ramp = Interpolation(0, slope, param)
        # Sweep t from -1.0 to 1.9, including values outside [0, 1].
        for step in range(-10, 20):
            param.set(step / 10)
            assert ramp == simplify(ramp) == param * slope
def test_slice_replace_simplification(check_dump):
    """Replacing every element of a Value with 1 dumps as a single Constant."""
    vec = Value(0, 1, 2)
    for position in range(3):
        vec = vec.replace(position, 1)
    check_dump(vec, 'Constant <1.0, 1.0, 1.0>')
def test_interpolation_chain_simplification(chain_length):
    """A chain of nested Interpolations sharing one parameter reduces once t is fixed.

    ``chain_length`` is supplied externally (fixture or parametrize, not visible
    here). It controls how many Interpolation layers are stacked on top of the
    initial ``Interpolation(1, 0, t)``.
    """
    t = Value(0)
    exp = Interpolation(1, 0, t)
    for i in range(chain_length):
        exp = Interpolation(exp, i + 1, t)
    # reduce_to_const presumably asserts that exp simplifies to a Constant once
    # the body has run. TODO: confirm against its definition.
    with reduce_to_const(exp):
        t.fix()
def test_value_setting():
    """Value.set() updates a scalar.

    Passing two components to a one-component Value raises ValueError.
    """
    scalar = Value(4)
    assert scalar == 4

    scalar.set(6)
    assert scalar == 6

    # Arity mismatch: the Value was built with a single component.
    with pytest.raises(ValueError):
        scalar.set(6, 7)
def test_value_fix_1():
    """A fixed Value rejects set(), keeps its value, and simplifies to a Constant."""
    frozen = Value(6)
    frozen.fix()
    assert frozen == 6

    # set() on a fixed Value must fail and leave the value untouched.
    with pytest.raises(ValueError):
        frozen.set(4)
    assert frozen == 6

    assert simplify(frozen) == 6
    assert isinstance(simplify(frozen), Constant)
def test_box():
    """A Box mirrors its wrapped value, exposes its name, and can be re-pointed."""
    label = 'Boxed variable'
    inner = Value(0, 0, 0)
    boxed = Box(label, inner)

    assert all(boxed == (0, 0, 0))
    assert boxed.pretty_name == label

    # Changing the wrapped Value through fix(...) must show through the Box.
    inner.fix(3, 3, 3)
    assert all(boxed == (3, 3, 3))

    # Assigning a new Value to .value replaces what the Box reports.
    boxed.value = Value(6, 6, 6)
    assert all(boxed == (6, 6, 6))
def test_concat_simplification():
    """A Concat of three Values reduces progressively as its parts are fixed."""
    first, second, third = Value(1), Value(2), Value(3)
    with reduce_to_const(Concat(first, second, third)) as exp:
        first.fix()
        second.fix()
        # With the first two parts fixed, exp is expected to expose exactly two
        # children (presumably the merged fixed prefix plus the still-free third).
        assert len(exp.children) == 2
        third.fix()
def test_reduce_simplification(cls):
    """A reduction node built by ``cls`` from three Values reduces as they are fixed.

    ``cls`` is supplied externally (fixture or parametrize, not visible here) and
    is called with a list of operands.
    """
    first, second, third = Value(1), Value(2), Value(3)
    with reduce_to_const(cls([first, second, third])) as exp:
        first.fix()
        second.fix()
        # Two fixed operands plus one free operand: two children expected.
        assert len(exp.children) == 2
        third.fix()
def test_interpolation():
    """Interpolation blends two 4-vectors linearly in t and agrees with simplify()."""
    start = Value(0, 1, 5, 1)
    stop = Value(10, 1, 0, 2)
    t = Value(0)
    exp = Interpolation(start, stop, t)

    # (new t, or None to keep the current t) -> expected interpolated vector
    checkpoints = (
        (None, (0, 1, 5, 1)),
        (1, (10, 1, 0, 2)),
        (0.5, (5, 1, 2.5, 1.5)),
    )
    for position, expected in checkpoints:
        if position is not None:
            t.set(position)
        assert all(exp == simplify(exp))
        assert all(exp == expected)
def test_slice_replace_simplification2(check_dump):
    """Adding 3 to each element via replace() simplifies to a Concat of sums."""
    vec = Value(0, 1, 2)
    # Update order kept as written: element 1, then 0, then 2.
    for position in (1, 0, 2):
        vec = vec.replace(position, vec[position] + 3)
    expected_dump = """ Concat <3.0, 4.0, 5.0>: + <3.0>: [0:1] <0.0>: Value <0.0, 1.0, 2.0> (&1) Constant <3.0> + <4.0>: [1:2] <1.0>: Value <0.0, 1.0, 2.0> (*1) Constant <3.0> + <5.0>: [2:3] <2.0>: Value <0.0, 1.0, 2.0> (*1) Constant <3.0> """
    check_dump(simplify(vec), expected_dump)
def test_interpolation_const_to_const_simplification(check_dump, range_4):
    """Interpolating between equal endpoints reduces to a constant once they are fixed.

    ``range_4`` is supplied externally (presumably 0..3; TODO confirm) and is
    used as a 2-bit mask:
      * bit 0 makes the start a ``Value``;
      * bit 1 makes the end a ``Value``;
      * otherwise the endpoint is a ``Constant``.
    This covers all four Value/Constant combinations. The previous
    ``range_4 % 1`` was always 0, so the start was never a Value, and
    ``range_4 % 2`` duplicated bit 0.

    ``check_dump`` is unused but kept so the fixture signature is unchanged.
    """
    start_is_value = bool(range_4 & 1)
    end_is_value = bool(range_4 & 2)
    val1 = Value(1, 2, 3) if start_is_value else Constant(1, 2, 3)
    val2 = Value(1, 2, 3) if end_is_value else Constant(1, 2, 3)
    # t is left free: with equal endpoints the result is independent of t.
    exp = Interpolation(val1, val2, Value(0.5))
    with reduce_to_const(exp):
        if start_is_value:
            val1.fix()
        if end_is_value:
            val2.fix()
def test_index_get():
    """Indexing and slicing a Value: bounds, negative indices, lengths, live views."""
    vec = Value(1, 2, 3)

    # Positive and negative integer indices.
    assert vec[1] == 2
    assert vec[2] == 3
    assert vec[-2] == 2
    assert vec[-3] == 1

    # Non-empty slices.
    assert all(vec[:-1] == (1, 2))
    assert vec[-1:] == (3, )

    # Out-of-range indices raise IndexError.
    for out_of_range in (3, -80):
        with pytest.raises(IndexError):
            vec[out_of_range]

    # Empty slices.
    assert len(vec[1:1]) == 0
    assert len(vec[2:1]) == 0

    # Non-integer, non-slice keys raise TypeError.
    with pytest.raises(TypeError):
        vec[None]

    # Indexing results stay linked to the source Value after set().
    head = vec[0]
    tail = vec[-1]
    front_pair = vec[:2]
    assert head == 1
    assert tail == 3
    assert all(front_pair == (1, 2))
    vec.set(2, 3, 4)
    assert head == 2
    assert tail == 4
    assert all(front_pair == (2, 3))

    # Lengths of single-index and slice results.
    expected_lengths = (
        (1, 1),
        (slice(None, 1), 1),
        (slice(1, None), 2),
        (slice(None, None), 3),
        (slice(None, -1), 2),
    )
    for key, expected_len in expected_lengths:
        assert len(vec[key]) == expected_len
def test_replace_slice():
    """Value.replace() with integer and slice keys can shrink, grow, or empty a vector."""
    val = Value(1, 2, 3)

    # Integer key: replace a single element.
    val = val.replace(1, 0)
    assert all(val == (1, 0, 3))

    # Slice key with a scalar: every element in the slice gets the scalar.
    val = val.replace(slice(1, None), -1)
    assert all(val == (1, -1, -1))

    # A longer replacement grows the vector (slice(1) is [:1]).
    val = val.replace(slice(1), (2, 3))
    assert all(val == (2, 3, -1, -1))

    # An empty replacement removes the slice.
    val = val.replace(slice(0, -1), ())
    assert all(val == -1)

    # Replacing everything with () must yield an empty vector. all() over an
    # empty comparison is vacuously True, so check the length explicitly too.
    emptied = val.replace(slice(None, None), ())
    assert len(emptied) == 0
    assert all(emptied == ())
def test_complex_concat_simplification(check_dump):
    """A Concat mixing Values and plain numbers folds into one Constant once fixed."""
    scalar = Value(1)
    pair = Value(4, 5)
    cat = Concat(scalar, 2, 3, pair)
    assert all(cat == (1, 2, 3, 4, 5))

    # Before fixing, the adjacent plain numbers 2 and 3 show up as one Constant child.
    unfixed_dump = """ Concat <1.0, 2.0, 3.0, 4.0, 5.0>: Value <1.0> Constant <2.0, 3.0> Value <4.0, 5.0> """
    check_dump(cat, unfixed_dump)

    scalar.fix()
    pair.fix()
    check_dump(simplify(cat), 'Constant <1.0, 2.0, 3.0, 4.0, 5.0>')
def test_elementwise_manipulation_depth():
    """Ten rounds of per-element replace() must pass check_depth(..., 4).

    check_depth presumably bounds the expression tree depth; TODO confirm.
    """
    exp = Value(0, 0, 0)
    for step in range(10):
        slot = step % 3  # cycle through elements 0, 1, 2
        exp = exp.replace(slot, exp[slot] - step)
    check_depth(exp, 4)
def test_slicing_depth():
    """Repeatedly decrementing one element must pass check_depth(..., 4).

    check_depth presumably bounds the expression tree depth; TODO confirm.
    """
    exp = Value(0, 0, 0)
    for _ in range(10):
        exp = exp.replace(1, exp[1] - 1)
    check_depth(exp, 4)
def test_neg_simplification():
    """Neg of a Value is expected to reduce to a constant once the Value is fixed."""
    source = Value(1, 2, 3, 4)
    negated = Neg(source)
    with reduce_to_const(negated):
        source.fix()
def test_slice_simplification(start, end):
    """A Slice of a Value is expected to reduce to a constant once the Value is fixed.

    ``start`` and ``end`` are supplied externally (parametrize/fixture, not
    visible here) and become the slice bounds.
    """
    source = Value(1, 2, 3, 4)
    window = slice(start, end)
    with reduce_to_const(Slice(source, window)):
        source.fix()
def test_concat_of_slice_simplification_0(i, check_dump):
    """Indexing a Concat of a fixed Value's elements simplifies to that element.

    ``i`` is supplied externally and selects which element is checked.
    """
    source = Value(0, 1, 2)
    pieces = Concat(source[0], source[1], source[2])
    source.fix()
    element = simplify(pieces[i])
    check_dump(element, "Constant <%s.0>" % i)
def test_concat():
    """Concat reflects later changes to the Values it was built from."""
    head = Value(3)
    cat = Concat(head, 2, Value(5, 4))
    assert all(cat == (3, 2, 5, 4))

    # Updating the source Value must propagate into the concatenation.
    head.set(8)
    assert all(cat == (8, 2, 5, 4))