Esempio n. 1
0
def test_meta_schedule_integration_apply_history_best():
    """Check that ``ApplyHistoryBest.query`` hands back the module stored in the database.

    A workload and one tuning record for ``MockModule`` are committed to an
    in-memory database. The context is then queried with ``MockModule`` as the
    dispatched module, and the returned module must be structurally equal to
    the module of the committed workload.
    """

    class DummyDatabase(PyDatabase):
        """List-backed ``PyDatabase`` subclass used only inside this test."""

        def __init__(self):
            super().__init__()
            self.records = []  # committed tuning records, in commit order
            self.workload_reg = []  # registered workloads, in commit order

        def has_workload(self, mod: IRModule) -> bool:
            """Return True when a structurally equal module is already registered."""
            return any(
                tvm.ir.structural_equal(workload.mod, mod)
                for workload in self.workload_reg
            )

        def commit_tuning_record(self, record: TuningRecord) -> None:
            """Append ``record`` to the list of stored tuning records."""
            self.records.append(record)

        def commit_workload(self, mod: IRModule) -> Workload:
            """Return the registered workload for ``mod``, registering a new one if absent."""
            for workload in self.workload_reg:
                if tvm.ir.structural_equal(workload.mod, mod):
                    return workload
            workload = Workload(mod)
            self.workload_reg.append(workload)
            return workload

        def get_top_k(self, workload: Workload,
                      top_k: int) -> List[TuningRecord]:
            """Return at most ``top_k`` records of ``workload``, lowest mean ``run_secs`` first."""
            # Filter before sorting so the mean is only computed for records
            # that belong to the requested workload; the sort is stable, so
            # ties keep their commit order.
            matching = [rec for rec in self.records if rec.workload == workload]
            matching.sort(key=lambda rec: sum(rec.run_secs) / len(rec.run_secs))
            return matching[:int(top_k)]

        def __len__(self) -> int:
            """Return the number of stored tuning records."""
            return len(self.records)

        def print_results(self) -> None:
            """Print every stored tuning record, one per line."""
            print("\n".join(str(r) for r in self.records))

    mod, _, _, _ = get_network(
        name="resnet-18",
        batch_size=1,
        layout="NHWC",
        dtype="float32",
    )
    database = DummyDatabase()
    env = ApplyHistoryBest(database)
    target = Target("llvm")
    # Seed the database with a single record for MockModule so the query below
    # has a history entry to return.
    workload = database.commit_workload(MockModule)
    database.commit_tuning_record(
        TuningRecord(Schedule(MockModule).trace, [1.0], workload, target, []))
    mod = env.query(task_name="mock-task",
                    mod=mod,
                    target=target,
                    dispatched=[MockModule])
    assert tvm.ir.structural_equal(mod, workload.mod)
Esempio n. 2
0
def test_meta_schedule_integration_task_extraction_query():
    """Query a fresh ``TaskExtraction`` context once and pass its ``tasks`` to ``_check_mock_task``."""
    network_config = {
        "name": "resnet-18",
        "batch_size": 1,
        "layout": "NHWC",
        "dtype": "float32",
    }
    relay_mod, _, _, _ = get_network(**network_config)
    extractor = TaskExtraction()
    extractor.query(
        task_name="mock-task",
        mod=relay_mod,
        dispatched=[MockModule],
    )
    _check_mock_task(extractor.tasks, relay_mod)
Esempio n. 3
0
def test_meta_schedule_integration_extract_from_resnet():
    """Extract tasks from the "resnet-18" network for the ``llvm`` target and expect 30 of them."""
    network = get_network(
        name="resnet-18",
        batch_size=1,
        layout="NHWC",
        dtype="float32",
    )
    relay_mod, relay_params, _, _ = network
    tasks = ms.integration.extract_task_from_relay(
        relay_mod,
        target="llvm",
        params=relay_params,
    )
    num_tasks = len(tasks)
    assert num_tasks == 30
Esempio n. 4
0
def test_meta_schedule_integration_query_inside_with_scope():
    """Issue a query through ``MetaScheduleContext`` while a ``TaskExtraction`` context is active.

    After the ``with`` block exits, the context's ``tasks`` are passed to
    ``_check_mock_task`` together with the network module.
    """
    relay_mod, _, _, _ = get_network(
        name="resnet-18",
        batch_size=1,
        layout="NHWC",
        dtype="float32",
    )
    extractor = TaskExtraction()
    query_kwargs = {
        "task_name": "mock-task",
        "mod": relay_mod,
        "dispatched": [MockModule],
    }
    with extractor:
        MetaScheduleContext.query_inside_with_scope(**query_kwargs)
    _check_mock_task(extractor.tasks, relay_mod)
Esempio n. 5
0
def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool):
    """Load the requested network and run ``verify_meta_schedule_with_tensorrt`` on it with ``mode="vm"``.

    Parameters
    ----------
    model_name : str
        Network name forwarded to ``get_network``.
    batch_size : int
        Batch size forwarded to ``get_network``.
    use_meta_sched : bool
        Forwarded unchanged to the verification helper.
    use_trt : bool
        Forwarded unchanged to the verification helper.
    """
    relay_mod, weights, in_shape, _out_shape = get_network(
        name=model_name,
        batch_size=batch_size,
    )
    verify_options = {
        "use_meta_sched": use_meta_sched,
        "use_trt": use_trt,
        "mode": "vm",
    }
    verify_meta_schedule_with_tensorrt(relay_mod, weights, in_shape, **verify_options)