def test_multi_equal():
    """Opposing inequality pairs on x and on x - z*y are solved as equalities.

    Without deskewing, x is pinned to the single value 6 and three relations
    remain; with deskewing, x is replaced by the constant 6 while y and z map
    to themselves.
    """
    x = te.var("x")
    y = te.var("y")
    z = te.var("z")
    constraints = [
        tvm.tir.LE(x, 6),
        tvm.tir.GE(x, 6),
        tvm.tir.GE(x - z * y, 0),
        tvm.tir.LE(x - z * y, 0),
    ]
    unknowns = [x, y, z]

    result = arith.solve_linear_inequalities(constraints, unknowns)
    x_range = result.ranges[x]
    assert x_range.min == 6
    assert x_range.extent == 1

    rels = result.relations
    assert len(rels) == 3
    assert ir.structural_equal(rels[0], x == z * y)
    # The last two relations are both of the form `<expr> <= 0`.
    for rel in (rels[1], rels[2]):
        assert isinstance(rel, tvm.tir.LE)
        assert rel.b == 0

    # They encode (z*y - 6) <= 0 and (6 - z*y) <= 0, in either order, so their
    # left-hand sides cancel out and one of them is exactly z*y - 6.
    lhs_first = rels[1].a
    lhs_second = rels[2].a
    analyzer = tvm.arith.Analyzer()
    assert analyzer.simplify(lhs_first + lhs_second) == 0
    assert ir.structural_equal(lhs_first, (z * y - 6)) or ir.structural_equal(
        lhs_second, (z * y - 6)
    )

    deskewed = arith.solve_linear_inequalities(constraints, unknowns, deskew_range=True)
    forward = deskewed.src_to_dst
    assert forward[y] == y
    assert forward[z] == z
    assert forward[x] == 6
def test_global_var_supply_from_none():
    """A default-constructed GlobalVarSupply returns reserved globals by name.

    `unique_global_for` must hand back the reserved GlobalVar, while
    `fresh_global` must produce a different one.
    """
    supply = GlobalVarSupply()
    reserved = GlobalVar("test")
    supply.reserve_global(reserved)

    looked_up = supply.unique_global_for("test")
    assert structural_equal(looked_up, reserved)

    minted = supply.fresh_global("test")
    assert not structural_equal(minted, reserved)
def test_unique_solution():
    """A full-rank 2x2 linear system has the unique solution x = 15, y = 5."""
    x, y = te.var("x"), te.var("y")
    equations = [
        tvm.tir.EQ(x + y, 20),
        tvm.tir.EQ(x - y, 10),
    ]
    result = arith.solve_linear_equations(equations, [x, y])

    # No free variables remain in the transformed system.
    assert list(result.dst.variables) == []

    forward = result.src_to_dst
    assert ir.structural_equal(forward[x], 15)
    assert ir.structural_equal(forward[y], 5)
def test_global_var_supply_from_name_supply():
    """A GlobalVarSupply built on a prefixed NameSupply still tracks reservations.

    Calling `unique_global_for` with a second argument of False returns the
    reserved GlobalVar, while the default call yields a different one
    (presumably because the "prefix" is applied — confirm in NameSupply docs).
    """
    names = NameSupply("prefix")
    supply = GlobalVarSupply(names)
    reserved = GlobalVar("test")
    supply.reserve_global(reserved)

    assert structural_equal(supply.unique_global_for("test", False), reserved)
    assert not structural_equal(supply.unique_global_for("test"), reserved)
def test_low_rank():
    """An under-determined system leaves exactly one free variable `n0`.

    From x + y + z = 15 and x + y = 10: z = 5, and x, y are expressed
    through the single remaining variable.
    """
    x, y, z = te.var("x"), te.var("y"), te.var("z")
    equations = [
        tvm.tir.EQ(x + y + z, 15),
        tvm.tir.EQ(x + y, 10),
    ]
    no_ranges = {}
    result = arith.solve_linear_equations(equations, [x, y, z], no_ranges)

    (n0,) = result.dst.variables
    forward = result.src_to_dst
    assert ir.structural_equal(forward[x], n0 + 10)
    assert ir.structural_equal(forward[y], -n0)
    assert ir.structural_equal(forward[z], 5)
def test_empty_var_to_solve():
    """Without variables to solve for, the system passes through unchanged.

    Both variable maps and the source variables/ranges are empty, the source
    relations are the input equations, and source and destination systems
    are structurally equal.
    """
    x, y = te.var("x"), te.var("y")
    equations = [
        tvm.tir.EQ(x + y, 20),
        tvm.tir.EQ(x - y, 10),
    ]
    result = arith.solve_linear_equations(equations)

    assert len(result.src_to_dst) == 0
    assert len(result.dst_to_src) == 0

    source = result.src
    assert len(source.variables) == 0
    assert len(source.ranges) == 0
    assert ir.structural_equal(source.relations, equations)
    assert ir.structural_equal(source, result.dst)
def test_global_var_supply_from_ir_mod():
    """A GlobalVarSupply seeded from an IRModule knows the module's globals.

    A fresh global minted under the same name must differ from the module's
    "test" global, and only `unique_global_for("test", False)` returns it.
    """
    lhs = relay.var("x")
    rhs = relay.var("y")
    module = tvm.IRModule()
    existing = GlobalVar("test")
    module[existing] = relay.Function([lhs, rhs], relay.add(lhs, rhs))

    supply = GlobalVarSupply(module)
    # Minted before the lookups below.
    minted = supply.fresh_global("test", False)

    assert structural_equal(supply.unique_global_for("test", False), existing)
    assert not structural_equal(supply.unique_global_for("test"), existing)
    assert not structural_equal(minted, existing)
def check(dim, axis, nstep):
    """Check SimplifyInference rewrites stacked batch_norm + dropout into `simple_bn`.

    Builds `nstep` rounds of ``batch_norm(y + 1)`` followed by ``dropout`` on
    an input of rank `dim` whose every axis has length 10. It then runs
    ``SimplifyInference`` and asserts the result is structurally equal (with
    free variables mapped) to `nstep` rounds of ``simple_bn(y + 1)``.

    NOTE(review): `dtype` and `simple_bn` are not defined here; they presumably
    come from an enclosing test scope — confirm.
    """
    eps = 0.01
    data_type = rly.TensorType(tuple(10 for _axis in range(dim)), dtype)
    param_type = rly.TensorType((10,), dtype)

    data = rly.var("x", data_type)
    beta = rly.var("beta", param_type)
    gamma = rly.var("gamma", param_type)
    moving_var = rly.var("moving_var", param_type)
    moving_mean = rly.var("moving_mean", param_type)

    actual = data  # graph handed to SimplifyInference
    expected = data  # hand-built reference graph
    for _step in range(nstep):
        actual, _mean, _var = rly.nn.batch_norm(
            actual + rly.const(1, dtype),
            gamma,
            beta,
            moving_mean,
            moving_var,
            epsilon=eps,
            axis=axis,
        )
        actual = rly.nn.dropout(actual)
        expected = simple_bn(
            expected + rly.const(1, dtype),
            gamma,
            beta,
            moving_mean,
            moving_var,
            epsilon=eps,
            axis=axis,
            shape=data_type.shape,
        )

    module = IRModule.from_expr(actual)
    simplify_pass = SimplifyInference()
    module = simplify_pass(module)
    assert structural_equal(module["main"].body, expected, map_free_vars=True)
def check_solution(solution, vranges=None):
    """Check that `solution` is a bijective transformation.

    The check runs in both directions: from `solution.src` to `solution.dst`
    using (`src_to_dst`, `dst_to_src`), and back with the maps swapped. In each
    direction, `check_bruteforce` is asked to confirm two things whenever the
    starting system's relations hold:

    * mapping every variable forward and then back yields the same variable;
    * the target system's relations and variable ranges hold after
      substitution.

    If the destination system is the single relation ``False`` (not solvable),
    the check is skipped.

    Parameters
    ----------
    solution :
        Solver result exposing ``src``, ``dst``, ``src_to_dst`` and
        ``dst_to_src`` (presumably from ``arith.solve_linear_equations`` /
        ``solve_linear_inequalities`` — confirm against callers).
    vranges : dict, optional
        Extra variable ranges used by the brute-force check. They are merged
        with the ranges of the starting system, whose entries take precedence.
        Defaults to no extra ranges.
    """
    # Avoid a shared mutable default; the dict is only read (and copied) below.
    if vranges is None:
        vranges = {}

    def _check_forward(constraints1, constraints2, varmap, backvarmap):
        ana = tvm.arith.Analyzer()
        all_vranges = vranges.copy()
        all_vranges.update({v: r for v, r in constraints1.ranges.items()})

        # Injectivity: every variable must survive a forward-then-back round trip.
        cond_on_vars = tir.const(1, 'bool')
        for v in constraints1.variables:
            v_back = ana.simplify(tir.stmt_functor.substitute(varmap[v], backvarmap))
            cond_on_vars = te.all(cond_on_vars, v == v_back)

        # The new relations must hold whenever the old relations hold.
        cond_subst = tir.stmt_functor.substitute(
            te.all(tir.const(1, 'bool'), *constraints2.relations), backvarmap
        )
        # The target system's variable ranges are constraints too.
        for v in constraints2.variables:
            if v in constraints2.ranges:
                r = constraints2.ranges[v]
                range_cond = te.all(v >= r.min, v < r.min + r.extent)
                range_cond = tir.stmt_functor.substitute(range_cond, backvarmap)
                cond_subst = te.all(cond_subst, range_cond)
        cond_subst = ana.simplify(cond_subst)

        check_bruteforce(
            te.all(cond_subst, cond_on_vars),
            all_vranges,
            cond=te.all(tir.const(1, 'bool'), *constraints1.relations),
        )

    rels = solution.dst.relations
    if len(rels) == 1 and ir.structural_equal(rels[0], False):
        # Not solvable: nothing to check.
        return
    _check_forward(solution.src, solution.dst, solution.src_to_dst, solution.dst_to_src)
    _check_forward(solution.dst, solution.src, solution.dst_to_src, solution.src_to_dst)
def verify_trace_roundtrip(
    sch: Schedule,
    mod: Union[PrimFunc, IRModule],
    *,
    debug_mask: Union[str, int] = "all",
    text_format: Union[str, Sequence[str]] = ("python", "json"),
) -> Schedule:
    """Serialize a traced schedule to text, replay it on a fresh schedule, and
    verify that scheduling is reproduced.

    Parameters
    ----------
    sch : tir.Schedule
        The traced TensorIR schedule to be verified
    mod : Union[PrimFunc, IRModule]
        The IRModule or PrimFunc to construct the fresh new schedule
    debug_mask : Union[str, int]
        Do extra correctness checking after the class creation and each time
        after calling the Replace method.
        Possible choices of `debug_mask`:
        1) "all" - Turn on all the checks
        2) "none" - Turn off all the checks
        3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask
    text_format : Union[str, Sequence[str]]
        The text format or formats whose round-trip behavior should be
        validated. Supported formats are "json" and "python".
        - A single string validates round-trips through that format only.
        - A sequence validates each format in order.
        Defaults to ("python", "json").

    Returns
    -------
    new_sch : tir.Schedule
        The freshly replayed schedule. When several formats are given, this is
        the schedule produced by the last one.

    Raises
    ------
    ValueError
        If `text_format` is an empty sequence.
    AssertionError
        If `text_format` names an unsupported format, if `sch` has no trace,
        or if the replayed schedule or its printed trace differs from the
        original.
    """
    if not isinstance(text_format, str):
        formats = list(text_format)
        if not formats:
            # Otherwise there would be no replayed schedule to return.
            raise ValueError("text_format must name at least one format")
        for opt in formats:
            new_sch = verify_trace_roundtrip(sch, mod, debug_mask=debug_mask, text_format=opt)
        return new_sch

    trace = sch.trace
    assert trace is not None
    # Validate before paying for a fresh Schedule construction.
    assert text_format in ("json", "python"), f"Unknown text format: {text_format}"

    # Step 1. Perform a round-trip through the text-format
    new_sch = Schedule(mod=mod, debug_mask=debug_mask)
    if text_format == "json":
        json_obj = trace.as_json()
        Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch)
    else:  # text_format == "python"
        py_trace = "\n".join(trace.as_python())
        # Replay the printed trace with tvm.tir's namespace as globals and the
        # fresh schedule bound to the name `sch`.
        exec(py_trace, tvm.tir.__dict__, {"sch": new_sch})  # pylint: disable=exec-used

    # Step 2. Verify that the round-trip produced the same scheduling
    assert structural_equal(new_sch.mod, sch.mod)

    # Step 3. Check the consistency of the text format between the old and new traces
    py_repr = "\n".join(trace.as_python())
    new_py_repr = "\n".join(new_sch.trace.as_python())
    assert py_repr == new_py_repr

    # Step 4. Return the new schedule in case it could be useful
    return new_sch
def _find_match_sketch_id(
    mod: IRModule,
    sketches: List[Schedule],
    expected_mod: IRModule,
    expected_decision: List[Tuple[str, List[int]]],
    *,
    debug_mask="all",
) -> Optional[int]:
    """Return the index of the first sketch that reproduces `expected_mod`.

    For each sketch, the "Sample*" instructions of its trace are walked in
    order and paired with `expected_decision`. The next pending
    ``(kind_name, decision)`` entry is attached to the next Sample instruction
    whose kind name equals ``kind_name``. Sample instructions of any other kind
    get no decision.

    A sketch is a candidate only if every expected entry was attached. Its
    instructions plus the attached decisions are replayed on a fresh
    ``Schedule(mod)`` via ``apply_to_schedule(..., remove_postproc=True)``.
    If the result is structurally equal to `expected_mod`, the trace
    round-trip is verified and the sketch index is returned.

    Returns
    -------
    Optional[int]
        Index of the first matching sketch, or None when no sketch matches.

    Raises
    ------
    AssertionError
        If a sketch still has a Sample instruction after all expected entries
        were consumed, or if `verify_trace_roundtrip` fails.
        NOTE(review): the former aborts the whole search instead of skipping
        that sketch — confirm this is intended.
    """
    num_expected = len(expected_decision)
    for sketch_id, sketch in enumerate(sketches):
        cursor = 0
        decisions = {}
        for inst in sketch.trace.insts:
            kind_name = inst.kind.name
            if kind_name.startswith("Sample"):
                assert cursor < num_expected
                if kind_name == expected_decision[cursor][0]:
                    decisions[inst] = expected_decision[cursor][1]
                    cursor += 1
        if len(decisions) != num_expected:
            # Not every expected decision found a matching Sample instruction.
            continue

        candidate = Schedule(mod, debug_mask=debug_mask)
        replay = Trace(insts=sketch.trace.insts, decisions=decisions)
        replay.apply_to_schedule(candidate, remove_postproc=True)
        if not structural_equal(candidate.mod, expected_mod):
            continue

        verify_trace_roundtrip(sch=candidate, mod=mod, debug_mask=debug_mask)
        return sketch_id
    return None
def test_infer_range():
    """Source-variable ranges bound the free variable and add a residual relation.

    From x + y = 0: x maps to the free variable n0 and y to -n0. y's range
    [0, 10) bounds n0 to [-9, 1). x's lower bound -5 remains as ``-5 <= n0``.
    """
    x, y = te.var("x"), te.var("y")
    bounds = {
        x: tvm.ir.Range.from_min_extent(-5, 10),
        y: tvm.ir.Range.from_min_extent(0, 10),
    }
    equations = [tvm.tir.EQ(x + y, 0)]
    result = arith.solve_linear_equations(equations, [x, y], bounds)

    (n0,) = result.dst.variables
    forward = result.src_to_dst
    assert ir.structural_equal(forward[x], n0)
    assert ir.structural_equal(forward[y], -n0)

    # Inferred from y's range.
    n0_range = result.dst.ranges[n0]
    assert ir.structural_equal(n0_range.min, -9)
    assert ir.structural_equal(n0_range.extent, 10)

    # x's own lower bound survives as an extra inequality on n0.
    (ineq,) = result.dst.relations
    assert isinstance(ineq, tvm.tir.LE)
    assert ir.structural_equal(ineq.a, -5)
    assert ir.structural_equal(ineq.b, n0)
def test_multi_equal():
    """Opposing inequality pairs collapse into x == 6 and x == z*y.

    Without deskewing, x's range is the single value 6, and the relations are
    ``x == z*y``, ``z*y - 6 <= 0`` and ``6 - z*y <= 0`` in that order. With
    deskewing, x becomes the constant 6 and y, z map to themselves.
    """
    x = te.var("x")
    y = te.var("y")
    z = te.var("z")
    constraints = [
        tvm.tir.LE(x, 6),
        tvm.tir.GE(x, 6),
        tvm.tir.GE(x - z * y, 0),
        tvm.tir.LE(x - z * y, 0),
    ]
    unknowns = [x, y, z]

    result = arith.solve_linear_inequalities(constraints, unknowns)
    x_range = result.ranges[x]
    assert x_range.min == 6
    assert x_range.extent == 1

    rels = result.relations
    assert len(rels) == 3
    expected_rels = (x == z * y, z * y - 6 <= 0, 6 - z * y <= 0)
    for idx, want in enumerate(expected_rels):
        assert ir.structural_equal(rels[idx], want)

    deskewed = arith.solve_linear_inequalities(constraints, unknowns, deskew_range=True)
    forward = deskewed.src_to_dst
    assert forward[y] == y
    assert forward[z] == z
    assert forward[x] == 6
def test_ill_formed():
    """An inconsistent system (x + y = 0, x - y = 0, x = 5) has no solution.

    The destination system is the single relation False, with no variables
    and empty variable maps.
    """
    x, y = te.var("x"), te.var("y")
    equations = [
        tvm.tir.EQ(x + y, 0),
        tvm.tir.EQ(x - y, 0),
        tvm.tir.EQ(x, 5),
    ]
    result = arith.solve_linear_equations(equations, [x, y], {})

    target = result.dst
    assert list(target.variables) == []
    (rel,) = target.relations
    assert ir.structural_equal(rel, False)

    assert len(result.src_to_dst) == 0
    assert len(result.dst_to_src) == 0
def test_no_solution():
    """Contradictory inequalities yield an unsatisfiable result.

    ``-x - 4 <= -5x + 2`` gives x <= 1, and ``4x + 5 <= 5x`` gives x >= 5, so
    there is no solution. This is checked both with and without range
    deskewing.
    """
    x = te.var("x0")
    vranges = {x: tvm.ir.Range.from_min_extent(-20, 41)}
    problem = [-x - 4 <= -5 * x + 2, x * 4 + 5 <= x * 5]

    # Deskewed: the destination system collapses to the constant False.
    deskewed = arith.solve_linear_inequalities(problem, [x], vranges, deskew_range=True)
    assert list(deskewed.dst.variables) == []
    (rel,) = deskewed.dst.relations
    assert ir.structural_equal(rel, False)
    assert len(deskewed.src_to_dst) == 0
    assert len(deskewed.dst_to_src) == 0

    # Plain: no variables or ranges, and a single falsy relation.
    plain = arith.solve_linear_inequalities(problem, [x], vranges)
    assert len(plain.variables) == 0
    assert len(plain.ranges) == 0
    (rel,) = plain.relations
    assert not rel
def verify_trace_roundtrip(
    sch: Schedule,
    mod: Union[PrimFunc, IRModule],
    *,
    debug_mask: Union[str, int] = "all",
) -> Schedule:
    """Check that a traced schedule can be reproduced from its JSON trace.

    The trace of `sch` is exported as JSON and replayed onto a brand-new
    schedule built from `mod`. The replayed module must be structurally equal
    to `sch.mod`, and both traces must print to identical Python text.

    Parameters
    ----------
    sch : tir.Schedule
        The traced TensorIR schedule to be verified
    mod : Union[PrimFunc, IRModule]
        The IRModule or PrimFunc to construct the fresh new schedule
    debug_mask : Union[str, int]
        Do extra correctness checking after the class creation and each time
        after calling the Replace method.
        Possible choices of `debug_mask`:
        1) "all" - Turn on all the checks
        2) "none" - Turn off all the checks
        3) An integer - Turn on checks according to the bitmasks provided in ScheduleDebugMask

    Returns
    -------
    tir.Schedule
        The replayed schedule, in case the caller wants to inspect it further.
    """
    original_trace = sch.trace
    assert original_trace is not None
    serialized = original_trace.as_json()

    # Replay the serialized trace on a fresh schedule and compare the IR.
    replayed = Schedule(mod=mod, debug_mask=debug_mask)
    Trace.apply_json_to_schedule(json_obj=serialized, sch=replayed)
    assert structural_equal(replayed.mod, sch.mod)

    # The replayed trace must print exactly like the original one.
    original_text = "\n".join(original_trace.as_python())
    replayed_text = "\n".join(replayed.trace.as_python())
    assert original_text == replayed_text

    return replayed
def test_dual_variable():
    """Solve a two-variable inequality system three ways.

    1. As flat bound conditions.
    2. As per-variable ranges.
    3. As deskewed ranges that start at zero.
    """
    x, y = te.var("x"), te.var("y")
    variables = [x, y]
    ranges = {
        x: tvm.ir.Range(-100, 100),
        y: tvm.ir.Range(0, 10),
    }
    problem = [
        tvm.tir.LE(x + y, 20),
        tvm.tir.GE(x - y, 10),
    ]

    # 1) Solution as conditions.
    conditions = arith._ffi_api.SolveInequalitiesAsCondition(variables, ranges, problem)
    expected_conditions = (x >= (y + 10), x <= (20 - y), y >= 0, y <= 5)
    for idx, want in enumerate(expected_conditions):
        assert ir.structural_equal(conditions[idx], want)

    # 2) Solution as ranges: 0 <= y <= 5 and y + 10 <= x <= 20 - y.
    solved = arith.solve_linear_inequalities(problem, variables, ranges)
    y_range = solved.ranges[y]
    assert y_range.min == 0
    assert y_range.extent == 6
    x_range = solved.ranges[x]
    assert ir.structural_equal(x_range.min, y + 10)
    # Width of x's interval is (20 - y) - (y + 10) + 1 = 11 - 2y, largest at y == 0.
    assert x_range.extent == 11

    # 3) Deskewed: both ranges are shifted to start from zero.
    deskewed = arith.solve_linear_inequalities(problem, variables, ranges, deskew_range=True)
    x_new, y_new = deskewed.dst.variables
    (rel,) = deskewed.dst.relations
    assert ir.structural_equal(rel, (y_new * 2) + x_new <= 10)

    new_ranges = deskewed.dst.ranges
    for new_var, extent in ((x_new, 11), (y_new, 6)):
        assert ir.structural_equal(new_ranges[new_var].min, 0)
        assert ir.structural_equal(new_ranges[new_var].extent, extent)

    assert ir.structural_equal(deskewed.src_to_dst[x], x_new + (y_new + 10))
    assert ir.structural_equal(deskewed.src_to_dst[y], y_new)
    assert ir.structural_equal(deskewed.dst_to_src[x_new], x - y - 10)
    assert ir.structural_equal(deskewed.dst_to_src[y_new], y)