def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)):
    """Solve a randomly generated linear system and validate the result.

    Creates ``num_vars`` variables with ``te.var`` and ``num_formulas``
    relations between two random affine expressions of those variables.
    The system is passed to ``arith.solve_linear_equations`` and the
    returned transformation is checked with
    ``testing.check_int_constraints_trans_consistency``.

    Parameters
    ----------
    num_vars : int
        Number of variables to create (named ``x0``, ``x1``, ...).
    num_formulas : int
        Number of relations to generate.
    coef : tuple of int
        Inclusive ``(low, high)`` limits handed to ``random.randint`` for
        every coefficient and constant term.
    bounds : tuple of int
        Each variable gets ``Range(bounds[0], bounds[1] + 1)``.
    """
    coef_lo, coef_hi = coef[0], coef[1]
    variables = [te.var("x{}".format(idx)) for idx in range(num_vars)]
    inequality_ops = [tvm.tir.LE, tvm.tir.LT, tvm.tir.GE, tvm.tir.GT]

    def _random_affine():
        # Draw one coefficient per variable first, then the constant term.
        expr = sum(var * random.randint(coef_lo, coef_hi) for var in variables)
        return expr + random.randint(coef_lo, coef_hi)

    relations = []
    for _ in range(num_formulas):
        lhs = _random_affine()
        rhs = _random_affine()
        # Mostly equalities; the remainder makes sure inequalities are
        # handled correctly as well.
        make_relation = (
            tvm.tir.EQ
            if random.random() < 0.7
            else random.choice(inequality_ops)
        )
        relations.append(make_relation(lhs, rhs))

    vranges = {
        var: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for var in variables
    }

    full_solution = arith.solve_linear_equations(relations, variables, vranges)
    testing.check_int_constraints_trans_consistency(full_solution)

    # Leaving some trailing variables as parameters should also be ok.
    for num_params in (1, 2):
        if len(variables) <= num_params:
            continue
        to_solve = variables[:-num_params]
        partial_solution = arith.solve_linear_equations(
            relations, to_solve, vranges
        )
        param_ranges = {var: vranges[var] for var in variables[-num_params:]}
        testing.check_int_constraints_trans_consistency(
            partial_solution, param_ranges
        )
def test_infer_range():
    """Solve ``x + y == 0`` with ranges on both variables.

    Checks the mapping of ``x`` and ``y`` onto the single new variable,
    the range recorded for that variable, and the one extra inequality
    left in the destination system.
    """
    x, y = te.var("x"), te.var("y")
    var_ranges = {
        x: tvm.ir.Range.from_min_extent(-5, 10),
        y: tvm.ir.Range.from_min_extent(0, 10),
    }
    equations = [tvm.tir.EQ(x + y, 0)]

    solution = arith.solve_linear_equations(equations, [x, y], var_ranges)
    dst = solution.dst

    # Exactly one free variable must remain after solving.
    (n0,) = dst.variables
    assert ir.structural_equal(solution.src_to_dst[x], n0)
    assert ir.structural_equal(solution.src_to_dst[y], -n0)

    # inferred from y's range
    n0_range = dst.ranges[n0]
    assert ir.structural_equal(n0_range.min, -9)
    assert ir.structural_equal(n0_range.extent, 10)

    # additional inequality is added into the system for x
    (ineq,) = dst.relations
    assert isinstance(ineq, tvm.tir.LE)
    assert ir.structural_equal(ineq.a, -5)
    assert ir.structural_equal(ineq.b, n0)
def test_unique_solution():
    """Solve a 2x2 system and expect constants for both variables.

    ``x + y == 20`` and ``x - y == 10`` should leave no destination
    variables and map ``x`` to 15 and ``y`` to 5.
    """
    x, y = te.var("x"), te.var("y")
    equations = [
        tvm.tir.EQ(x + y, 20),
        tvm.tir.EQ(x - y, 10),
    ]

    solution = arith.solve_linear_equations(equations, [x, y])

    # No free variable should remain.
    assert list(solution.dst.variables) == []
    mapping = solution.src_to_dst
    assert ir.structural_equal(mapping[x], 15)
    assert ir.structural_equal(mapping[y], 5)
def test_low_rank():
    """Solve two equations in three unknowns with an empty range map.

    Expects one remaining destination variable ``n0`` with
    ``x -> n0 + 10``, ``y -> -n0`` and ``z -> 5``.
    """
    x, y, z = te.var("x"), te.var("y"), te.var("z")
    ranges = {}
    equations = [
        tvm.tir.EQ(x + y + z, 15),
        tvm.tir.EQ(x + y, 10),
    ]

    solution = arith.solve_linear_equations(equations, [x, y, z], ranges)

    # Exactly one free variable must remain after solving.
    (n0,) = solution.dst.variables
    mapping = solution.src_to_dst
    assert ir.structural_equal(mapping[x], n0 + 10)
    assert ir.structural_equal(mapping[y], -n0)
    assert ir.structural_equal(mapping[z], 5)
def test_empty_var_to_solve():
    """Call the solver with equations only and no variables to solve for.

    Expects empty variable mappings, a source system with no variables or
    ranges whose relations equal the input, and a destination system
    structurally equal to the source.
    """
    x, y = te.var("x"), te.var("y")
    equations = [
        tvm.tir.EQ(x + y, 20),
        tvm.tir.EQ(x - y, 10),
    ]

    solution = arith.solve_linear_equations(equations)
    src = solution.src

    # Neither direction of the variable mapping should have entries.
    assert len(solution.src_to_dst) == 0
    assert len(solution.dst_to_src) == 0
    assert len(src.variables) == 0
    assert len(src.ranges) == 0
    assert ir.structural_equal(src.relations, equations)
    assert ir.structural_equal(src, solution.dst)
def test_ill_formed():
    """Solve a contradictory system and expect an unsatisfiable result.

    ``x + y == 0``, ``x - y == 0`` and ``x == 5`` cannot all hold. The
    destination system should have no variables and a single relation
    structurally equal to ``False``, with both mappings empty.
    """
    x, y = te.var("x"), te.var("y")
    equations = [
        tvm.tir.EQ(x + y, 0),
        tvm.tir.EQ(x - y, 0),
        tvm.tir.EQ(x, 5),
    ]

    solution = arith.solve_linear_equations(equations, [x, y], {})
    dst = solution.dst

    assert list(dst.variables) == []
    # The only remaining relation must be the constant False.
    (rel,) = dst.relations
    assert ir.structural_equal(rel, False)
    assert len(solution.src_to_dst) == 0
    assert len(solution.dst_to_src) == 0