コード例 #1
0
ファイル: legalize.py プロジェクト: saudet/tvm
 def callback(
     self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
 ) -> tvm.relay.Expr:
     """Return a plain ``relay.op.reshape`` call built from the rewritten match.

     The reshape is applied to the first argument of ``post``; the target
     shape is the ``new_shape`` exposed by ``ReshapeParams`` constructed from
     ``post.op.body``. ``pre`` and ``node_map`` are not used here.
     """
     ifm = post.args[0]
     params = ethosu_patterns.ReshapeParams(post.op.body)
     target_shape = params.new_shape
     return relay.op.reshape(ifm, newshape=target_shape)
コード例 #2
0
def test_relay_reshape_legalize(ifm_shape, new_shape):
    """Check the result of legalizing a partitioned relay reshape.

    A module holding a single int8 reshape is partitioned with a reshape-only
    pattern table, then ``ReshapeRewriter`` and ``NoOpRewriter`` are applied to
    the partitioned external function. The test asserts that the function body
    is a ``contrib.ethosu.identity`` call whose first argument is a ``reshape``
    call, and that the identity's inferred shape equals ``new_shape``.
    """
    ext_func_name = "tvmgen_default_ethos_u_main_0"

    # Build and type-check a module whose main function is just the reshape.
    data = relay.var("ifm", shape=ifm_shape, dtype="int8")
    mod = tvm.IRModule()
    mod["main"] = relay.Function([data], relay.op.reshape(data, new_shape))
    mod = relay.transform.InferType()(mod)

    pattern_table = [
        (
            ethosu.ReshapeParams.composite_name,
            ethosu.reshape_pattern(),
            lambda pat: ethosu.ReshapeParams(pat).is_valid(),
        ),
    ]
    mod = partition_ethosu_by_table(mod, pattern_table)

    # Apply the rewriters one after another; each is instantiated right
    # before its own pass over the external function.
    for rewriter_cls in (legalize.ReshapeRewriter, legalize.NoOpRewriter):
        mod[ext_func_name] = dataflow_pattern.rewrite(rewriter_cls(), mod[ext_func_name])
    mod = relay.transform.InferType()(mod)

    identity_call = mod[ext_func_name].body
    assert identity_call.op.name == "contrib.ethosu.identity"

    # check that the reshape is still there
    reshape_call = identity_call.args[0]
    assert reshape_call.op.name == "reshape"

    # check that identity's output shape matches reshape's output shape
    assert tuple(identity_call.checked_type.shape) == new_shape