def test_slice_replace_simplification(check_dump):
    """Overwrite every component of a three-element ``Value`` with ``1``.

    Components 0, 1 and 2 are replaced in that order. The resulting
    expression is expected to dump as a single ``Constant <1.0, 1.0, 1.0>``
    node.
    """
    vec = Value(0, 1, 2)
    for index in range(3):
        vec = vec.replace(index, 1)
    check_dump(vec, 'Constant <1.0, 1.0, 1.0>')
def test_replace_slice():
    """Exercise ``Value.replace`` with an integer index and with slices.

    Covers replacements that keep the length, grow the vector (two items for
    one) and shrink it (an empty tuple deletes the sliced components).
    """
    vec = Value(1, 2, 3)

    # Integer index: overwrite a single component.
    vec = vec.replace(1, 0)
    assert all(vec == (1, 0, 3))

    # Open-ended slice with a scalar: every component from index 1 onward.
    vec = vec.replace(slice(1, None), -1)
    assert all(vec == (1, -1, -1))

    # slice(1) is slice(None, 1), i.e. just the first component; putting two
    # items in its place grows the vector to four components.
    vec = vec.replace(slice(1), (2, 3))
    assert all(vec == (2, 3, -1, -1))

    # An empty replacement over [0:-1] drops all but the last component.
    vec = vec.replace(slice(0, -1), ())
    assert all(vec == -1)

    # NOTE(review): the comparison below presumably yields an empty result,
    # so all() is vacuously true and this only checks that replacing the full
    # slice with () and comparing to () does not raise. Consider also
    # asserting the resulting length.
    emptied = vec.replace(slice(None, None), ())
    assert all(emptied == ())
def test_slice_replace_simplification2(check_dump):
    """Add 3 to each component of ``Value(0, 1, 2)`` via per-index replaces.

    Components are updated in the order 1, 0, 2. After ``simplify`` the
    expected dump is a ``Concat`` of three sums. Each sum adds
    ``Constant <3.0>`` to a one-element slice ([0:1], [1:2], [2:3]) of the
    original ``Value``.
    """
    vec = Value(0, 1, 2)
    for index in (1, 0, 2):
        vec = vec.replace(index, vec[index] + 3)
    # NOTE(review): "(&1)" / "(*1)" presumably mark the first occurrence of
    # the shared Value node and later references to it. Confirm against
    # check_dump's output format. The literal's line breaks/indentation also
    # appear to have been collapsed; verify it against a real dump.
    expected = """ Concat <3.0, 4.0, 5.0>: + <3.0>: [0:1] <0.0>: Value <0.0, 1.0, 2.0> (&1) Constant <3.0> + <4.0>: [1:2] <1.0>: Value <0.0, 1.0, 2.0> (*1) Constant <3.0> + <5.0>: [2:3] <2.0>: Value <0.0, 1.0, 2.0> (*1) Constant <3.0> """
    check_dump(simplify(vec), expected)
def test_elementwise_manipulation_depth():
    """Check expression depth after ten per-component updates.

    Step ``n`` (0..9) subtracts ``n`` from component ``n % 3``, cycling over
    all three components. The result is checked with ``check_depth(..., 4)``
    (presumably a depth bound; confirm check_depth's semantics).
    """
    expr = Value(0, 0, 0)
    for step in range(10):
        slot = step % 3
        expr = expr.replace(slot, expr[slot] - step)
    check_depth(expr, 4)
def test_slicing_depth():
    """Check expression depth after ten updates of the same component.

    Component 1 is decremented ten times by read-modify-replace. The result
    is checked with ``check_depth(..., 4)`` (presumably a depth bound;
    confirm check_depth's semantics).
    """
    expr = Value(0, 0, 0)
    for _ in range(10):
        expr = expr.replace(1, expr[1] - 1)
    check_depth(expr, 4)