def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)):
        """Solve a random linear system and check the solution's consistency.

        Parameters
        ----------
        num_vars : int
            Number of variables ``x0 .. x{num_vars - 1}`` to create.
        num_formulas : int
            Number of random relations to generate.
        coef : tuple of int
            Inclusive (low, high) limits handed to ``random.randint`` for
            every coefficient and every constant term.
        bounds : tuple of int
            (low, high) limits of each variable; the range given to the
            solver is built as ``Range(low, high + 1)``.
        """
        coef_lo, coef_hi = coef[0], coef[1]
        inequality_ops = [tvm.tir.LE, tvm.tir.LT, tvm.tir.GE, tvm.tir.GT]
        variables = [te.var("x" + str(idx)) for idx in range(num_vars)]

        def _random_affine():
            # One coefficient is drawn per variable (in order), then the
            # constant term, so the sequence of random draws is fixed.
            linear = sum(var * random.randint(coef_lo, coef_hi) for var in variables)
            return linear + random.randint(coef_lo, coef_hi)

        relations = []
        for _ in range(num_formulas):
            lhs = _random_affine()
            rhs = _random_affine()
            if random.random() >= 0.7:
                # we also make sure it can correctly handle inequalities
                op = random.choice(inequality_ops)
            else:
                op = tvm.tir.EQ
            relations.append(op(lhs, rhs))

        range_begin, range_end = bounds[0], bounds[1] + 1
        vranges = {}
        for var in variables:
            vranges[var] = tvm.ir.expr.Range(range_begin, range_end)

        # First solve for every variable.
        solution = arith.solve_linear_equations(relations, variables, vranges)
        testing.check_int_constraints_trans_consistency(solution)

        # leaving some variables as parameters should also be ok
        for num_params in (1, 2):
            if len(variables) <= num_params:
                continue
            solved_vars = variables[:-num_params]
            param_vars = variables[-num_params:]
            solution = arith.solve_linear_equations(relations, solved_vars, vranges)
            param_ranges = {var: vranges[var] for var in param_vars}
            testing.check_int_constraints_trans_consistency(solution, param_ranges)
def test_infer_range():
    """Solve ``x + y == 0`` with ranged variables and check the inferred range."""
    x = te.var("x")
    y = te.var("y")
    var_ranges = {
        x: tvm.ir.Range.from_min_extent(-5, 10),
        y: tvm.ir.Range.from_min_extent(0, 10),
    }
    equations = [tvm.tir.EQ(x + y, 0)]

    solution = arith.solve_linear_equations(equations, [x, y], var_ranges)

    # Exactly one free variable is expected in the transformed system.
    (n0,) = solution.dst.variables
    assert ir.structural_equal(solution.src_to_dst[x], n0)
    assert ir.structural_equal(solution.src_to_dst[y], -n0)

    # inferred from y's range
    n0_range = solution.dst.ranges[n0]
    assert ir.structural_equal(n0_range.min, -9)
    assert ir.structural_equal(n0_range.extent, 10)

    # additional inequality is added into the system for x
    (extra_relation,) = solution.dst.relations
    assert isinstance(extra_relation, tvm.tir.LE)
    assert ir.structural_equal(extra_relation.a, -5)
    assert ir.structural_equal(extra_relation.b, n0)
def test_unique_solution():
    """A full-rank 2x2 system is expected to resolve both variables to constants."""
    x = te.var("x")
    y = te.var("y")
    equations = [
        tvm.tir.EQ(x + y, 20),
        tvm.tir.EQ(x - y, 10),
    ]

    solution = arith.solve_linear_equations(equations, [x, y])

    # No free variable should remain in the transformed system.
    assert list(solution.dst.variables) == []
    mapping = solution.src_to_dst
    assert ir.structural_equal(mapping[x], 15)
    assert ir.structural_equal(mapping[y], 5)
def test_low_rank():
    """Two equations over three variables are expected to leave one free variable."""
    x = te.var("x")
    y = te.var("y")
    z = te.var("z")
    equations = [
        tvm.tir.EQ(x + y + z, 15),
        tvm.tir.EQ(x + y, 10),
    ]
    ranges = {}

    solution = arith.solve_linear_equations(equations, [x, y, z], ranges)

    (n0,) = solution.dst.variables
    mapping = solution.src_to_dst
    assert ir.structural_equal(mapping[x], n0 + 10)
    assert ir.structural_equal(mapping[y], -n0)
    assert ir.structural_equal(mapping[z], 5)
def test_empty_var_to_solve():
    """With no variables to solve for, the system is expected to come back unchanged."""
    x = te.var("x")
    y = te.var("y")
    equations = [tvm.tir.EQ(x + y, 20), tvm.tir.EQ(x - y, 10)]

    # Only the relations are passed; no variables and no ranges.
    solution = arith.solve_linear_equations(equations)

    # Both variable mappings are empty.
    assert len(solution.src_to_dst) == 0
    assert len(solution.dst_to_src) == 0

    # The source system carries the relations only.
    assert len(solution.src.variables) == 0
    assert len(solution.src.ranges) == 0
    assert ir.structural_equal(solution.src.relations, equations)

    # Source and destination systems are structurally the same.
    assert ir.structural_equal(solution.src, solution.dst)
def test_ill_formed():
    """A contradictory system is expected to yield a single ``False`` relation."""
    x = te.var("x")
    y = te.var("y")
    # x + y == 0 and x - y == 0 cannot both hold together with x == 5.
    equations = [
        tvm.tir.EQ(x + y, 0),
        tvm.tir.EQ(x - y, 0),
        tvm.tir.EQ(x, 5),
    ]

    solution = arith.solve_linear_equations(equations, [x, y], {})

    assert list(solution.dst.variables) == []
    (only_relation,) = solution.dst.relations
    assert ir.structural_equal(only_relation, False)
    # No variable mapping is produced in either direction.
    assert len(solution.src_to_dst) == 0
    assert len(solution.dst_to_src) == 0