def test_multi_equal():
    """Solve a system in which x is pinned to 6 and forced equal to z * y.

    Checks the range and relations of the plain solution (the two leftover
    relations may come back in either order) and the variable mapping of the
    ``deskew_range=True`` solution.
    """
    x, y, z = te.var("x"), te.var("y"), te.var("z")
    variables = [x, y, z]
    problem = [
        tvm.tir.LE(x, 6),
        tvm.tir.GE(x, 6),
        tvm.tir.GE(x - z * y, 0),
        tvm.tir.LE(x - z * y, 0),
    ]

    sol = arith.solve_linear_inequalities(problem, variables)
    x_range = sol.ranges[x]
    assert x_range.min == 6
    assert x_range.extent == 1
    rels = sol.relations
    assert len(rels) == 3
    assert ir.structural_equal(rels[0], x == z * y)

    first, second = rels[1], rels[2]
    assert isinstance(first, tvm.tir.LE)
    assert first.b == 0
    assert isinstance(second, tvm.tir.LE)
    assert second.b == 0
    # Expected: (z*y - 6) <= 0 and (6 - z*y) <= 0 in either order, so the
    # left-hand sides cancel and one of them is exactly z*y - 6.
    analyzer = tvm.arith.Analyzer()
    assert analyzer.simplify(first.a + second.a) == 0
    assert any(ir.structural_equal(rel.a, (z * y - 6)) for rel in (first, second))

    deskewed = arith.solve_linear_inequalities(problem, variables, deskew_range=True)
    mapping = deskewed.src_to_dst
    assert mapping[y] == y
    assert mapping[z] == z
    assert mapping[x] == 6
# Example #2
def test_global_var_supply_from_none():
    """A default-constructed GlobalVarSupply returns a reserved global on
    lookup and hands out a different one from ``fresh_global``."""
    supply = GlobalVarSupply()
    reserved = GlobalVar("test")
    supply.reserve_global(reserved)

    looked_up = supply.unique_global_for("test")
    assert structural_equal(looked_up, reserved)
    fresh = supply.fresh_global("test")
    assert not structural_equal(fresh, reserved)
def test_unique_solution():
    """x + y = 20 and x - y = 10 have the single solution x = 15, y = 5."""
    x, y = te.var("x"), te.var("y")
    equations = [
        tvm.tir.EQ(x + y, 20),
        tvm.tir.EQ(x - y, 10),
    ]

    solution = arith.solve_linear_equations(equations, [x, y])
    # A fully determined system leaves no free variables in the result.
    assert not list(solution.dst.variables)
    src_to_dst = solution.src_to_dst
    assert ir.structural_equal(src_to_dst[x], 15)
    assert ir.structural_equal(src_to_dst[y], 5)
# Example #4
def test_global_var_supply_from_name_supply():
    """GlobalVarSupply backed by a NameSupply constructed with a prefix."""
    supply = GlobalVarSupply(NameSupply("prefix"))
    reserved = GlobalVar("test")
    supply.reserve_global(reserved)

    # With the second argument False, the lookup resolves to the reserved global ...
    unprefixed = supply.unique_global_for("test", False)
    assert structural_equal(unprefixed, reserved)
    # ... while the default lookup yields a different global.
    prefixed = supply.unique_global_for("test")
    assert not structural_equal(prefixed, reserved)
def test_low_rank():
    """Two equations in three unknowns leave one free variable (n0)."""
    x, y, z = te.var("x"), te.var("y"), te.var("z")
    equations = [
        tvm.tir.EQ(x + y + z, 15),
        tvm.tir.EQ(x + y, 10),
    ]

    solution = arith.solve_linear_equations(equations, [x, y, z], {})
    (n0,) = solution.dst.variables
    mapping = solution.src_to_dst
    assert ir.structural_equal(mapping[x], n0 + 10)
    assert ir.structural_equal(mapping[y], -n0)
    assert ir.structural_equal(mapping[z], 5)
def test_empty_var_to_solve():
    """With no variables to solve for, the equations pass through unchanged."""
    x, y = te.var("x"), te.var("y")
    equations = [
        tvm.tir.EQ(x + y, 20),
        tvm.tir.EQ(x - y, 10),
    ]
    solution = arith.solve_linear_equations(equations)
    src = solution.src
    # No variable mapping in either direction ...
    assert len(solution.src_to_dst) == 0
    assert len(solution.dst_to_src) == 0
    # ... no variables or ranges in the source system ...
    assert len(src.variables) == 0
    assert len(src.ranges) == 0
    # ... and the relations are kept as-is, so src and dst coincide.
    assert ir.structural_equal(src.relations, equations)
    assert ir.structural_equal(src, solution.dst)
# Example #7
def test_global_var_supply_from_ir_mod():
    """GlobalVarSupply seeded from an IRModule that already binds "test"."""
    x, y = relay.var("x"), relay.var("y")
    mod = tvm.IRModule()
    existing = GlobalVar("test")
    mod[existing] = relay.Function([x, y], relay.add(x, y))
    supply = GlobalVarSupply(mod)

    fresh = supply.fresh_global("test", False)

    assert structural_equal(supply.unique_global_for("test", False), existing)
    assert not structural_equal(supply.unique_global_for("test"), existing)
    assert not structural_equal(fresh, existing)
    def check(dim, axis, nstep):
        """Compare SimplifyInference's output on a batch_norm chain with a reference.

        Builds ``nstep`` chained steps of (add 1 -> nn.batch_norm -> nn.dropout)
        on a rank-``dim`` input whose every dimension is 10, runs
        ``SimplifyInference`` on the resulting module, and asserts its body is
        structurally equal (free vars mapped) to the same chain built with
        ``simple_bn`` and no dropout.

        NOTE(review): ``dtype``, ``simple_bn`` and ``rly`` (presumably an alias
        for relay) are free names here; they are presumably bound by an
        enclosing test function that is not visible in this chunk. As written,
        this def is nested inside ``test_global_var_supply_from_ir_mod`` and is
        never called in the visible code -- confirm against the original file.
        """
        eps = 0.01
        # Input type: `dim` dimensions, each of size 10.
        ttype1 = rly.TensorType(tuple(10 for i in range(dim)), dtype)
        # 1-D type of length 10, used for gamma/beta and the moving statistics.
        ttype2 = rly.TensorType((10,), dtype)
        x = rly.var("x", ttype1)
        beta = rly.var("beta", ttype2)
        gamma = rly.var("gamma", ttype2)
        moving_var = rly.var("moving_var", ttype2)
        moving_mean = rly.var("moving_mean", ttype2)
        # y1: graph to simplify; y2: expected result after simplification.
        y1, y2 = x, x

        for _ in range(nstep):
            # batch_norm yields three outputs; only the first feeds the chain.
            y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, dtype),
                gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
            y1 = rly.nn.dropout(y1)
            # The reference chain has no dropout step.
            y2 = simple_bn(y2 + rly.const(1, dtype),
                           gamma, beta, moving_mean, moving_var,
                           epsilon=eps, axis=axis, shape=ttype1.shape)

        mod = IRModule.from_expr(y1)
        simplify = SimplifyInference()
        mod = simplify(mod)
        y1 = mod["main"].body

        # map_free_vars=True lets the free vars of the two expressions be matched.
        assert structural_equal(y1, y2, map_free_vars=True)
def check_solution(solution, vranges=None):
    """Check that ``solution`` is a bijective transformation.

    The check runs in both directions: from ``solution.src`` to
    ``solution.dst`` (via ``src_to_dst`` / ``dst_to_src``), then the reverse.
    A solution whose destination relations are the single constant ``False``
    (an unsolvable system) is skipped.

    Parameters
    ----------
    solution :
        Solver result exposing ``src``, ``dst``, ``src_to_dst`` and
        ``dst_to_src``.
    vranges : dict, optional
        Extra variable ranges merged into each direction's brute-force domain.
        It is never modified. ``None`` (the default) means no extra ranges.
    """
    # None instead of a shared mutable ``{}`` default argument.
    base_vranges = {} if vranges is None else vranges

    def _check_forward(constraints1, constraints2, varmap, backvarmap):
        ana = tvm.arith.Analyzer()
        all_vranges = dict(base_vranges)
        all_vranges.update(constraints1.ranges.items())

        # Check that the transformation is injective: mapping each variable
        # forward and then back again must give the variable itself.
        cond_on_vars = tir.const(1, 'bool')
        for v in constraints1.variables:
            v_back = ana.simplify(tir.stmt_functor.substitute(varmap[v], backvarmap))
            cond_on_vars = te.all(cond_on_vars, v == v_back)
        # The new relations, rewritten in the old variables, must hold whenever
        # the old relations hold. The leading constant True presumably keeps
        # te.all non-empty when there are no relations.
        cond_subst = tir.stmt_functor.substitute(
            te.all(tir.const(1, 'bool'), *constraints2.relations), backvarmap)
        # The new variables' ranges are constraints of the new system as well.
        for v in constraints2.variables:
            if v in constraints2.ranges:
                r = constraints2.ranges[v]
                range_cond = te.all(v >= r.min, v < r.min + r.extent)
                range_cond = tir.stmt_functor.substitute(range_cond, backvarmap)
                cond_subst = te.all(cond_subst, range_cond)
        cond_subst = ana.simplify(cond_subst)
        # NOTE(review): check_bruteforce is defined elsewhere; presumably it
        # asserts the expression over every point of all_vranges satisfying cond.
        check_bruteforce(te.all(cond_subst, cond_on_vars), all_vranges,
                         cond=te.all(tir.const(1, 'bool'), *constraints1.relations))

    rels = solution.dst.relations
    if len(rels) == 1 and ir.structural_equal(rels[0], False):
        # not solvable, skip
        return
    _check_forward(solution.src, solution.dst,
                   solution.src_to_dst, solution.dst_to_src)
    _check_forward(solution.dst, solution.src,
                   solution.dst_to_src, solution.src_to_dst)
# Example #10
def verify_trace_roundtrip(
    sch: Schedule,
    mod: Union[PrimFunc, IRModule],
    *,
    debug_mask: Union[str, int] = "all",
    text_format: Union[str, Sequence[str]] = ("python", "json"),
) -> Schedule:
    """Round-trip a traced schedule through text format(s) and verify it.

    The trace of ``sch`` is serialized (to JSON and/or Python source), replayed
    onto a fresh schedule built from ``mod``, and the replay is checked to
    reproduce both the scheduled module and the Python rendering of the trace.

    Parameters
    ----------
    sch : tir.Schedule
        The traced TensorIR schedule to be verified
    mod : Union[PrimFunc, IRModule]
        The IRModule or PrimFunc to construct the fresh new schedule
    debug_mask : Union[str, int]
        Do extra correctness checking after the class creation and each time
        after calling the Replace method.
        Possible choices of `debug_mask`:
        1) "all" - Turn on all the checks
        2) "none" - Turn off all the checks
        3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask
    text_format: Union[str, Sequence[str]]
        The text format or formats whose round-trip behavior should be
        validated. If a single string, validate round-trips through that
        format only; it must be ``"json"`` or ``"python"``. If a sequence,
        validate each listed format in turn. Defaults to both ``"python"``
        and ``"json"``.

    Returns
    -------
    new_sch : tir.Schedule
        The freshly replayed schedule; for several formats, the one produced
        by the last format.

    Raises
    ------
    AssertionError
        If a format is unknown or any round-trip check fails.
    ValueError
        If ``text_format`` is an empty sequence.
    """
    # Iterate over the formats here instead of recursing through the
    # module-level name, which may be rebound to a variant without
    # ``text_format``.
    formats = [text_format] if isinstance(text_format, str) else list(text_format)
    if not formats:
        raise ValueError("text_format must name at least one text format")

    trace = sch.trace
    assert trace is not None
    py_repr = "\n".join(trace.as_python())

    new_sch = None
    for fmt in formats:
        # Step 1. Perform a round-trip through the text-format
        new_sch = Schedule(mod=mod, debug_mask=debug_mask)
        if fmt == "json":
            Trace.apply_json_to_schedule(json_obj=trace.as_json(), sch=new_sch)
        elif fmt == "python":
            # The generated script runs with tvm.tir's namespace as globals and
            # ``sch`` bound to the fresh schedule; it is produced by the trace
            # itself, not by external input.
            exec(py_repr, tvm.tir.__dict__, {"sch": new_sch})  # pylint: disable=exec-used
        else:
            # Raised explicitly (not via ``assert``) so it survives ``python -O``.
            raise AssertionError(f"Unknown text format: {fmt}")

        # Step 2. Verify that the round-trip produced the same scheduling
        assert structural_equal(new_sch.mod, sch.mod)

        # Step 3. Check the consistency of the text format between the old and new traces
        new_py_repr = "\n".join(new_sch.trace.as_python())
        assert py_repr == new_py_repr

    # Step 4. Return the new schedule in case it could be useful
    return new_sch
# Example #11
def _find_match_sketch_id(
    mod: IRModule,
    sketches: List[Schedule],
    expected_mod: IRModule,
    expected_decision: List[Tuple[str, List[int]]],
    *,
    debug_mask="all",
) -> Optional[int]:
    """Return the index of the first sketch that reproduces ``expected_mod``.

    The ``Sample*`` instructions of each sketch's trace are matched, in order,
    against ``expected_decision`` (pairs of instruction-kind name and decision).
    A sketch is a candidate only when every expected decision was assigned.
    The candidate's trace is then replayed with those decisions on a fresh
    schedule of ``mod``. On a structural match with ``expected_mod``, the
    schedule's trace round-trip is verified and the sketch index is returned.
    Returns None when no sketch matches.
    """
    num_expected = len(expected_decision)
    for sketch_id, sketch in enumerate(sketches):
        trace = sketch.trace
        decisions = {}
        cursor = 0
        sample_insts = (inst for inst in trace.insts if inst.kind.name.startswith("Sample"))
        for inst in sample_insts:
            assert cursor < num_expected
            expected = expected_decision[cursor]
            if inst.kind.name == expected[0]:
                decisions[inst] = expected[1]
                cursor += 1
        if len(decisions) != num_expected:
            continue

        candidate = Schedule(mod, debug_mask=debug_mask)
        replay = Trace(insts=trace.insts, decisions=decisions)
        replay.apply_to_schedule(candidate, remove_postproc=True)
        if not structural_equal(candidate.mod, expected_mod):
            continue
        verify_trace_roundtrip(sch=candidate, mod=mod, debug_mask=debug_mask)
        return sketch_id
    return None
def test_infer_range():
    """For x + y == 0 with bounded x and y, the free variable's range is
    inferred from y and x's bound is kept as an extra relation."""
    x, y = te.var("x"), te.var("y")
    var_ranges = {
        x: tvm.ir.Range.from_min_extent(-5, 10),
        y: tvm.ir.Range.from_min_extent(0, 10),
    }

    solution = arith.solve_linear_equations([tvm.tir.EQ(x + y, 0)], [x, y], var_ranges)
    (n0,) = solution.dst.variables
    assert ir.structural_equal(solution.src_to_dst[x], n0)
    assert ir.structural_equal(solution.src_to_dst[y], -n0)

    # y = -n0 with 0 <= y < 10 gives -9 <= n0 <= 0.
    n0_range = solution.dst.ranges[n0]
    assert ir.structural_equal(n0_range.min, -9)
    assert ir.structural_equal(n0_range.extent, 10)

    # The bound on x survives as the single extra relation -5 <= n0.
    (ineq,) = solution.dst.relations
    assert isinstance(ineq, tvm.tir.LE)
    assert ir.structural_equal(ineq.a, -5)
    assert ir.structural_equal(ineq.b, n0)
# Example #13
def test_multi_equal():
    """x is pinned to 6 and tied to z * y; relations are checked in exact order."""
    x, y, z = te.var("x"), te.var("y"), te.var("z")
    problem = [
        tvm.tir.LE(x, 6),
        tvm.tir.GE(x, 6),
        tvm.tir.GE(x - z * y, 0),
        tvm.tir.LE(x - z * y, 0),
    ]
    variables = [x, y, z]

    solution = arith.solve_linear_inequalities(problem, variables)
    x_range = solution.ranges[x]
    assert x_range.min == 6
    assert x_range.extent == 1
    relations = solution.relations
    assert len(relations) == 3
    expected = [x == z * y, z * y - 6 <= 0, 6 - z * y <= 0]
    for actual, want in zip(relations, expected):
        assert ir.structural_equal(actual, want)

    deskewed = arith.solve_linear_inequalities(problem, variables, deskew_range=True)
    for var, target in ((y, y), (z, z), (x, 6)):
        assert deskewed.src_to_dst[var] == target
def test_ill_formed():
    """The inconsistent system x + y = 0, x - y = 0, x = 5 is reported unsolvable."""
    x, y = te.var("x"), te.var("y")
    equations = [
        tvm.tir.EQ(x + y, 0),
        tvm.tir.EQ(x - y, 0),
        tvm.tir.EQ(x, 5),
    ]

    solution = arith.solve_linear_equations(equations, [x, y], {})
    dst = solution.dst
    assert not list(dst.variables)
    # The only remaining relation is the constant False.
    (only_rel,) = dst.relations
    assert ir.structural_equal(only_rel, False)
    # No variable mapping exists in either direction.
    for mapping in (solution.src_to_dst, solution.dst_to_src):
        assert len(mapping) == 0
def test_no_solution():
    """The inequalities imply x <= 1 and x >= 5, so no solution exists, both
    with and without range deskewing."""
    x = te.var("x0")
    vranges = {x: tvm.ir.Range.from_min_extent(-20, 41)}
    problem = [-x - 4 <= -5 * x + 2, x * 4 + 5 <= x * 5]

    deskewed = arith.solve_linear_inequalities(problem, [x], vranges, deskew_range=True)
    assert not list(deskewed.dst.variables)
    (rel,) = deskewed.dst.relations
    assert ir.structural_equal(rel, False)
    assert len(deskewed.src_to_dst) == 0
    assert len(deskewed.dst_to_src) == 0

    plain = arith.solve_linear_inequalities(problem, [x], vranges)
    assert len(plain.variables) == 0
    assert len(plain.ranges) == 0
    (rel,) = plain.relations
    assert not rel
# Example #16
def verify_trace_roundtrip(
    sch: Schedule,
    mod: Union[PrimFunc, IRModule],
    *,
    debug_mask: Union[str, int] = "all",
) -> Schedule:
    """Replay the JSON-serialized trace of ``sch`` on a fresh schedule and
    check that the scheduling is reproduced.

    Parameters
    ----------
    sch : tir.Schedule
        The traced TensorIR schedule to be verified
    mod : Union[PrimFunc, IRModule]
        The IRModule or PrimFunc to construct the fresh new schedule
    debug_mask : Union[str, int]
        Do extra correctness checking after the class creation and each time
        after calling the Replace method.
        Possible choices of `debug_mask`:
        1) "all" - Turn on all the checks
        2) "none" - Turn off all the checks
        3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask

    Returns
    -------
    tir.Schedule
        The replayed schedule, in case the caller wants to inspect it.
    """
    original_trace = sch.trace
    assert original_trace is not None
    serialized = original_trace.as_json()

    # Replaying the JSON onto a fresh schedule must give the same module.
    replayed = Schedule(mod=mod, debug_mask=debug_mask)
    Trace.apply_json_to_schedule(json_obj=serialized, sch=replayed)
    assert structural_equal(replayed.mod, sch.mod)

    # Both traces must also render to identical Python text.
    def _render(trace_obj):
        return "\n".join(trace_obj.as_python())

    assert _render(original_trace) == _render(replayed.trace)
    return replayed
# Example #17
def test_dual_variable():
    """Solve x + y <= 20, x - y >= 10 over bounded x and y, checking the result
    as conditions, as ranges, and as a deskewed (zero-based) system."""
    x, y = te.var("x"), te.var("y")
    variables = [x, y]
    ranges = {
        x: tvm.ir.Range(-100, 100),
        y: tvm.ir.Range(0, 10),
    }
    problem = [
        tvm.tir.LE(x + y, 20),
        tvm.tir.GE(x - y, 10),
    ]

    # Result expressed as a list of conditions.
    conditions = arith._ffi_api.SolveInequalitiesAsCondition(variables, ranges, problem)
    expected_conditions = [x >= (y + 10), x <= (20 - y), y >= 0, y <= 5]
    for idx, want in enumerate(expected_conditions):
        assert ir.structural_equal(conditions[idx], want)

    # Result expressed as ranges: 0 <= y <= 5 and y + 10 <= x <= 20 - y.
    solved = arith.solve_linear_inequalities(problem, variables, ranges)
    y_range = solved.ranges[y]
    assert y_range.min == 0
    assert y_range.extent == 6
    x_range = solved.ranges[x]
    assert ir.structural_equal(x_range.min, y + 10)
    # The inclusive range [y + 10, 20 - y] has 11 - 2*y points; the maximum is at y = 0.
    assert x_range.extent == 11

    # Deskewed: every solved range is shifted to start at zero.
    deskewed = arith.solve_linear_inequalities(problem,
                                               variables,
                                               ranges,
                                               deskew_range=True)
    x_new, y_new = deskewed.dst.variables
    (rel,) = deskewed.dst.relations
    assert ir.structural_equal(rel, (y_new * 2) + x_new <= 10)
    for var, lo, ext in ((x_new, 0, 11), (y_new, 0, 6)):
        var_range = deskewed.dst.ranges[var]
        assert ir.structural_equal(var_range.min, lo)
        assert ir.structural_equal(var_range.extent, ext)
    fwd, back = deskewed.src_to_dst, deskewed.dst_to_src
    assert ir.structural_equal(fwd[x], x_new + (y_new + 10))
    assert ir.structural_equal(fwd[y], y_new)
    assert ir.structural_equal(back[x_new], x - y - 10)
    assert ir.structural_equal(back[y_new], y)