Ejemplo n.º 1
0
def test_correctness_layout_rewrite_insert_transform_stage():
    """Check that the InsertTransformStage layout rewrite preserves results.

    A 128x128x128 matmul task is tuned for two measurement trials and the
    best record is loaded back from the log file.  The recorded state is then
    lowered twice: once with ``layout_rewrite=ComputeDAG.InsertTransformStage``
    and once with the default options.  Both builds are run on the same random
    inputs and the first three buffers are compared with
    ``atol = rtol = 1e-3``.

    Raises:
        AssertionError: if any compared buffer differs between the two builds.
    """
    N = 128
    target = tvm.target.Target("llvm")
    task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N),
                                      target)
    dag = task.compute_dag

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        search_policy = auto_scheduler.SketchPolicy(task)

        measure_ctx = auto_scheduler.LocalRPCMeasureContext()
        try:
            tuning_options = auto_scheduler.TuningOptions(
                num_measure_trials=2,
                runner=measure_ctx.runner,
                verbose=1,
                measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
            )
            auto_scheduler.auto_schedule(task, search_policy, tuning_options)
            inp, _ = auto_scheduler.load_best(log_file, task.workload_key,
                                              target)

            # Same tuned state, lowered with and without the layout rewrite.
            s, bufs = dag.apply_steps_from_state(
                inp.state,
                layout_rewrite=auto_scheduler.compute_dag.ComputeDAG.
                InsertTransformStage)
            s_ref, bufs_ref = dag.apply_steps_from_state(inp.state)

            # One set of random host arrays feeds both builds; this relies on
            # both builds taking identically shaped buffers.
            np_args = [
                np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype)
                for x in bufs
            ]

            func = tvm.build(s, bufs, target=target)
            func_ref = tvm.build(s_ref, bufs_ref, target=target)

            ctx = tvm.context(str(target))
            ctx_ref = tvm.cpu()

            args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
            args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args]
            ctx.sync()

            func(*args)
            func_ref(*args_ref)
            ctx.sync()

            # Compare every buffer the two functions were called with.
            for idx in (0, 1, 2):
                tvm.testing.assert_allclose(args[idx].asnumpy(),
                                            args_ref[idx].asnumpy(),
                                            atol=1e-3,
                                            rtol=1e-3)
        finally:
            # Drop the measure context on every exit path, including failed
            # assertions, so it does not outlive the test.
            # NOTE(review): the context is expected to own local RPC
            # tracker/server processes - confirm against the TVM docs.
            del measure_ctx
Ejemplo n.º 2
0
def test_layout_rewrite_correctness():
    """Check that enabling layout rewrite does not change the computed result.

    A 128x128x128 matmul task is tuned for two measurement trials and the
    best record is loaded back from the log file.  The recorded state is
    lowered with ``layout_rewrite=True`` and with ``layout_rewrite=False``.
    Random inputs are generated for the rewritten buffers; the reference
    build receives the same data, with the weight array folded back into a
    2-D ``(red_dim, out_dim)`` array when the rewritten weight has six or
    more dimensions.  After both functions run, buffers 0 and 2 are read
    back from the device and compared with ``atol = rtol = 1e-3``.

    Raises:
        AssertionError: if the compared buffers differ between the two builds.
    """
    N = 128
    target = tvm.target.Target("llvm")
    task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N),
                                      target)
    dag = task.compute_dag

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        search_policy = auto_scheduler.SketchPolicy(task)

        tuning_options = auto_scheduler.TuningOptions(
            num_measure_trials=2,
            runner="local",
            verbose=1,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        auto_scheduler.auto_schedule(task, search_policy, tuning_options)
        inp, _ = auto_scheduler.load_best(log_file, task.workload_key, target)
        s, bufs = dag.apply_steps_from_state(inp.state, layout_rewrite=True)
        s_ref, bufs_ref = dag.apply_steps_from_state(inp.state,
                                                     layout_rewrite=False)
        np_args = [
            np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype)
            for x in bufs
        ]
        np_args_ref = [np.array(x) for x in np_args]

        weight = np_args_ref[1]
        # infer shape for the rewritten layout
        if len(weight.shape) >= 6:
            # For cpu tile structure SSRSRS
            base = len(weight.shape) - 6
            red_dim = weight.shape[2 + base] * weight.shape[4 + base]
            out_dim = weight.shape[3 + base] * weight.shape[5 + base]
            for i in range(base + 2):
                out_dim *= weight.shape[i]
            # Move the two reduction axes to the front, keep the remaining
            # axes in their original order, then collapse to two dimensions.
            new_order = ([
                2 + base,
                4 + base,
            ] + list(range(base + 2)) + [
                3 + base,
                5 + base,
            ])
            np_args_ref[1] = np_args_ref[1].transpose(new_order)
            np_args_ref[1] = np_args_ref[1].reshape((red_dim, out_dim))

        func = tvm.build(s, bufs, target=target)
        func_ref = tvm.build(s_ref, bufs_ref, target=target)

        ctx = tvm.context(str(target))
        ctx_ref = tvm.cpu()

        args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
        args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args_ref]
        ctx.sync()

        func(*args)
        func_ref(*args_ref)
        ctx.sync()

        # Compare the device buffers the functions actually ran on.  The host
        # arrays in np_args / np_args_ref are never written by func/func_ref,
        # so asserting on them would pass no matter what was computed.
        # Buffer 1 is skipped because its layout differs between the builds.
        # NOTE(review): index 2 is presumably the matmul output - confirm
        # against matmul_auto_scheduler_test.
        for idx in (0, 2):
            np.testing.assert_allclose(args[idx].asnumpy(),
                                       args_ref[idx].asnumpy(),
                                       atol=1e-3,
                                       rtol=1e-3)
Ejemplo n.º 3
0
def test_correctness_layout_rewrite_rewrite_for_preTransformed():
    """Check that REWRITE_FOR_PRE_TRANSFORMED layout rewrite preserves results.

    A 128x128x128 matmul search task is tuned for two measurement trials and
    the best record is loaded back from the log file.  The recorded state is
    lowered with ``LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED`` and with
    the default options.  Random inputs are generated for the rewritten
    buffers; the reference build receives the same data, with the weight
    array folded back into a 2-D ``(red_dim, out_dim)`` array when the
    rewritten weight has six or more dimensions.  Buffers 0 and 2 are compared
    after execution with ``atol = rtol = 1e-3``.

    Raises:
        AssertionError: if the compared buffers differ between the two builds.
    """
    N = 128
    target = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(func=matmul_auto_scheduler_test, args=(N, N, N), target=target)
    dag = task.compute_dag

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name

        search_policy = auto_scheduler.SketchPolicy(task)

        measure_ctx = auto_scheduler.LocalRPCMeasureContext()
        try:
            tuning_options = auto_scheduler.TuningOptions(
                num_measure_trials=2,
                runner=measure_ctx.runner,
                verbose=2,
                measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
            )
            task.tune(tuning_options, search_policy=search_policy)
            inp, _ = auto_scheduler.load_best_record(log_file, task.workload_key, target)
            s, bufs = dag.apply_steps_from_state(
                inp.state,
                layout_rewrite=auto_scheduler.LayoutRewriteOption.REWRITE_FOR_PRE_TRANSFORMED,
            )
            s_ref, bufs_ref = dag.apply_steps_from_state(inp.state)
            np_args = [
                np.random.randn(*topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs
            ]
            np_args_ref = [np.array(x) for x in np_args]

            weight = np_args_ref[1]
            # infer shape for the rewritten layout
            if len(weight.shape) >= 6:
                # For cpu tile structure SSRSRS
                base = len(weight.shape) - 6
                red_dim = weight.shape[2 + base] * weight.shape[4 + base]
                out_dim = weight.shape[3 + base] * weight.shape[5 + base]
                for i in range(base + 2):
                    out_dim *= weight.shape[i]
                # Move the two reduction axes to the front, keep the remaining
                # axes in their original order, then collapse to two dimensions.
                new_order = (
                    [
                        2 + base,
                        4 + base,
                    ]
                    + list(range(base + 2))
                    + [
                        3 + base,
                        5 + base,
                    ]
                )
                np_args_ref[1] = np_args_ref[1].transpose(new_order)
                np_args_ref[1] = np_args_ref[1].reshape((red_dim, out_dim))

            func = tvm.build(s, bufs, target=target)
            func_ref = tvm.build(s_ref, bufs_ref, target=target)

            ctx = tvm.context(str(target))
            ctx_ref = tvm.cpu()

            args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
            args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args_ref]
            ctx.sync()

            func(*args)
            func_ref(*args_ref)
            ctx.sync()

            # Buffer 1 is skipped because its layout differs between the builds.
            for idx in (0, 2):
                tvm.testing.assert_allclose(
                    args[idx].asnumpy(), args_ref[idx].asnumpy(), atol=1e-3, rtol=1e-3
                )
        finally:
            # Drop the measure context on every exit path, including failed
            # assertions, so it does not outlive the test.
            # NOTE(review): the context is expected to own local RPC
            # tracker/server processes - confirm against the TVM docs.
            del measure_ctx