def test_meta_schedule_database_reload():
    """Tuning records committed to a JSONDatabase can be read back by a fresh instance.

    Three records for the same workload are committed to a temporary database.
    A new ``JSONDatabase`` is then opened on the same workload / tuning-record
    files, and its top-2 query must yield the records whose run times are
    ``[1.0, 2.0, 3.0]`` and ``[4.0, 5.0, 6.0]``.
    """
    mod: IRModule = Matmul
    with tempfile.TemporaryDirectory() as tmpdir:
        database = _create_tmp_database(tmpdir)
        token = database.commit_workload(mod)
        trace = _create_schedule(mod, _schedule_matmul).trace
        # The records differ only in their measured run times.
        records = [
            TuningRecord(
                trace,
                run_secs,
                token,
                tvm.target.Target("llvm"),
                ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
            )
            for run_secs in ([7.0, 8.0, 9.0], [1.0, 2.0, 3.0], [4.0, 5.0, 6.0])
        ]
        for record in records:
            database.commit_tuning_record(record)
        # Re-open the same on-disk files through a brand-new database object.
        new_database = JSONDatabase(
            path_workload=database.path_workload,
            path_tuning_record=database.path_tuning_record,
        )
        token = new_database.commit_workload(mod)
        ret = new_database.get_top_k(token, 2)
        assert len(ret) == 2
        # The two returned records must be records[1] and records[2]; either order
        # is accepted.
        # NOTE(review): if get_top_k is guaranteed to sort by mean run time
        # ascending, the order is deterministic and could be pinned — confirm.
        try:
            _equal_record(ret[0], records[1])
            _equal_record(ret[1], records[2])
        except AssertionError:
            _equal_record(ret[0], records[2])
            _equal_record(ret[1], records[1])
def test_meta_schedule_tuning_record_round_trip():
    """A TuningRecord serialized with ``as_json`` deserializes to an equal record."""
    mod: IRModule = Matmul
    with tempfile.TemporaryDirectory() as tmpdir:
        db = _create_tmp_database(tmpdir)
        workload_token = db.commit_workload(mod)
        sch_trace = _create_schedule(mod, _schedule_matmul).trace
        original = TuningRecord(
            sch_trace,
            [1.5, 2.5, 1.8],
            workload_token,
            tvm.target.Target("llvm"),
            ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
        )
        db.commit_tuning_record(original)
        # Serialize, then rebuild the record against the same workload token.
        serialized = original.as_json()
        restored = TuningRecord.from_json(serialized, workload_token)
        _equal_record(original, restored)
def test_meta_schedule_database_add_entry():
    """After committing a single tuning record, it is the only one the database returns."""
    mod: IRModule = Matmul
    with tempfile.TemporaryDirectory() as tmpdir:
        db = _create_tmp_database(tmpdir)
        workload_token = db.commit_workload(mod)
        sch_trace = _create_schedule(mod, _schedule_matmul).trace
        expected = TuningRecord(
            sch_trace,
            [1.5, 2.5, 1.8],
            workload_token,
            tvm.target.Target("llvm"),
            ArgInfo.from_prim_func(func=mod["main"]),  # pylint: disable=unsubscriptable-object
        )
        db.commit_tuning_record(expected)
        assert len(db) == 1
        # Asking for up to 3 must still yield exactly one record.
        top_records = db.get_top_k(workload_token, 3)
        [actual] = top_records
        _equal_record(actual, expected)
def test_meta_schedule_arg_info_from_prim_func():
    """ArgInfo.from_prim_func describes the dtype and shape of each Matmul buffer."""
    expected = [
        'TensorInfo("float32", [128, 256])',
        'TensorInfo("float32", [256, 512])',
        'TensorInfo("float32", [128, 512])',
    ]
    actual = [str(info) for info in ArgInfo.from_prim_func(Matmul)]
    assert actual == expected