Exemplo n.º 1
0
def test_apply_steps_with_layout_rewrite():
    """Check buffer shapes returned with layout rewrite disabled and enabled.

    With ``layout_rewrite=False`` the test expects the first two dimensions
    of the second returned buffer to be 512 and 512. With
    ``layout_rewrite=True`` it expects the first five dimensions to be
    4, 8, 4, 4 and 512.
    """
    compute_dag, state = get_tiled_matmul()

    # (layout_rewrite flag, expected leading dimensions of bufs[1])
    cases = (
        (False, (512, 512)),
        (True, (4, 8, 4, 4, 512)),
    )
    for rewrite, expected_dims in cases:
        _, buffers = compute_dag.apply_steps_from_state(state, layout_rewrite=rewrite)
        shape = buffers[1].shape
        # Only the listed leading axes are checked, one comparison per axis.
        for axis, extent in enumerate(expected_dims):
            assert shape[axis] == extent
def test_measure_local_builder_runner():
    """Build and run a single measure input locally and expect no errors.

    Returns immediately, without checking anything, when the "llvm"
    runtime is not enabled.
    """
    if not tvm.runtime.enabled("llvm"):
        return

    compute_dag, init_state = get_tiled_matmul()
    target = tvm.target.create("llvm")
    search_task = auto_scheduler.SearchTask(compute_dag, "test", target)
    measure_input = auto_scheduler.MeasureInput(search_task, init_state)

    builder = auto_scheduler.LocalBuilder()
    runner = auto_scheduler.LocalRunner(timeout=60)

    # Build step: error_no == 0 is the success value asserted here.
    build_results = builder.build([measure_input])
    assert build_results[0].error_no == 0

    # Run step consumes the build results produced above.
    measure_results = runner.run([measure_input], build_results)
    assert measure_results[0].error_no == 0
Exemplo n.º 3
0
def test_measure_local_builder_runner(enable_cpu_cache_flush=False):
    """Build and run a single measure input locally and expect no errors.

    Parameters
    ----------
    enable_cpu_cache_flush : bool
        Forwarded unchanged to ``auto_scheduler.LocalRunner``.

    Returns immediately, without checking anything, when
    ``tvm.testing.device_enabled("llvm")`` is false.
    """
    if not tvm.testing.device_enabled("llvm"):
        return

    compute_dag, init_state = get_tiled_matmul()
    target = tvm.target.Target("llvm")
    search_task = auto_scheduler.SearchTask(compute_dag, "test", target)
    measure_input = auto_scheduler.MeasureInput(search_task, init_state)

    builder = auto_scheduler.LocalBuilder()
    runner = auto_scheduler.LocalRunner(
        timeout=60,
        enable_cpu_cache_flush=enable_cpu_cache_flush,
    )

    # Build step: error_no == 0 is the success value asserted here.
    build_results = builder.build([measure_input])
    assert build_results[0].error_no == 0

    # Run step consumes the build results produced above.
    measure_results = runner.run([measure_input], build_results)
    assert measure_results[0].error_no == 0
Exemplo n.º 4
0
def test_apply_steps_with_layout_rewrite():
    """Check buffer shapes returned for each layout rewrite option used here.

    Expected leading dimensions of the second returned buffer:

    * default call (no ``layout_rewrite`` argument): 512, 512
    * ``LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED``: 4, 8, 4, 4, 512
    * ``LayoutRewriteOption.INSERT_TRANSFORM_STAGE``: 512, 512
    """
    compute_dag, state = get_tiled_matmul()
    rewrite_option = auto_scheduler.LayoutRewriteOption

    def check_leading_dims(buffers, expected_dims):
        # Compare only the listed leading axes of buffers[1], in order.
        shape = buffers[1].shape
        for axis, extent in enumerate(expected_dims):
            assert shape[axis] == extent

    # Default behaviour: the keyword is omitted entirely.
    _, buffers = compute_dag.apply_steps_from_state(state)
    check_leading_dims(buffers, (512, 512))

    _, buffers = compute_dag.apply_steps_from_state(
        state, layout_rewrite=rewrite_option.REWRITE_FOR_PRE_TRANSFORMED
    )
    check_leading_dims(buffers, (4, 8, 4, 4, 512))

    _, buffers = compute_dag.apply_steps_from_state(
        state, layout_rewrite=rewrite_option.INSERT_TRANSFORM_STAGE
    )
    check_leading_dims(buffers, (512, 512))
Exemplo n.º 5
0
def test_apply_steps_with_layout_rewrite():
    """Check buffer shapes returned for each layout rewrite mode used here.

    Expected leading dimensions of the second returned buffer:

    * default call (no ``layout_rewrite`` argument): 512, 512
    * ``ComputeDAG.RewriteForPreTransformed``: 4, 8, 4, 4, 512
    * ``ComputeDAG.InsertTransformStage``: 512, 512
    """
    compute_dag, state = get_tiled_matmul()
    dag_cls = auto_scheduler.compute_dag.ComputeDAG

    def check_leading_dims(buffers, expected_dims):
        # Compare only the listed leading axes of buffers[1], in order.
        shape = buffers[1].shape
        for axis, extent in enumerate(expected_dims):
            assert shape[axis] == extent

    # Default behaviour: the keyword is omitted entirely.
    _, buffers = compute_dag.apply_steps_from_state(state)
    check_leading_dims(buffers, (512, 512))

    _, buffers = compute_dag.apply_steps_from_state(
        state, layout_rewrite=dag_cls.RewriteForPreTransformed
    )
    check_leading_dims(buffers, (4, 8, 4, 4, 512))

    _, buffers = compute_dag.apply_steps_from_state(
        state, layout_rewrite=dag_cls.InsertTransformStage
    )
    check_leading_dims(buffers, (512, 512))
def test_record():
    """Save one measure record to a temp file and read it back.

    The test asserts that exactly one input is read back, that the
    bound-inferred state read from the file equals the bound-inferred
    original state, and that this state is not equal to the DAG's
    initial state. Returns early when the "llvm" runtime is not enabled.
    """
    compute_dag, state = get_tiled_matmul()

    if not tvm.runtime.enabled("llvm"):
        return

    llvm_target = tvm.target.create("llvm")
    search_task = auto_scheduler.SearchTask(compute_dag, "test", llvm_target)

    measure_input = auto_scheduler.measure.MeasureInput(search_task, state)
    measure_result = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)

    with tempfile.NamedTemporaryFile() as record_file:
        record_path = record_file.name
        auto_scheduler.save_records(record_path, [measure_input], [measure_result])

        reader = auto_scheduler.RecordReader(record_path)
        loaded_inputs, _ = reader.read_lines()
        assert len(loaded_inputs) == 1

        original_bound = compute_dag.infer_bound_from_state(state)
        loaded_bound = compute_dag.infer_bound_from_state(loaded_inputs[0].state)

        assert original_bound == loaded_bound
        # Keep the `not (a == b)` form so the same equality operator is used.
        assert not (original_bound == compute_dag.get_init_state())
def test_infer_bound():
    """Smoke test: bound inference on the tiled matmul state must not raise."""
    compute_dag, state = get_tiled_matmul()
    # The result is intentionally not inspected; only successful completion matters.
    compute_dag.infer_bound_from_state(state)
def test_apply_steps():
    """Smoke test: print, apply and lower the tiled matmul state.

    Nothing is asserted; the test passes when none of the calls raise.
    """
    compute_dag, state = get_tiled_matmul()
    compute_dag.print_python_code_from_state(state)
    schedule, args = compute_dag.apply_steps_from_state(state)
    tvm.lower(schedule, args, simple_mode=True)
Exemplo n.º 9
0
def test_estimate_flop():
    """Check that the DAG's ``flop_ct`` is within 0.5 of ``2 * 512**3``."""
    compute_dag, _ = get_tiled_matmul()
    expected_flop = 2 * 512**3
    assert abs(compute_dag.flop_ct - expected_flop) < 0.5