def callback(
    self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
    """Rebuild a plain Relay reshape from the matched call.

    Parameters
    ----------
    pre : tvm.relay.Expr
        Expression before rewriting (not used here).
    post : tvm.relay.Expr
        Expression after rewriting. Its first argument is used as the reshape
        input and ``post.op.body`` is handed to ``ReshapeParams`` to obtain
        the target shape.
    node_map : tvm.ir.container.Map
        Pattern-node to expression mapping (not used here).

    Returns
    -------
    tvm.relay.Expr
        ``relay.op.reshape`` applied to the first argument of ``post`` with
        the shape extracted by ``ReshapeParams``.
    """
    data = post.args[0]
    # The target shape is read from the body of the function being called.
    target_shape = ethosu_patterns.ReshapeParams(post.op.body).new_shape
    return relay.op.reshape(data, newshape=target_shape)
def test_relay_reshape_legalize(ifm_shape, new_shape):
    """Check that a legalized reshape becomes an identity wrapping a reshape.

    Builds a module containing a single int8 ``reshape``, partitions it with
    a one-entry reshape pattern table, runs ``ReshapeRewriter`` followed by
    ``NoOpRewriter`` on the partitioned function, and then asserts that:

    * the function body is a ``contrib.ethosu.identity`` call,
    * the argument of that call is a ``reshape``,
    * the inferred output shape of the identity equals ``new_shape``.
    """
    ext_func_name = "tvmgen_default_ethos_u_main_0"

    # Single-operator graph: ifm -> reshape.
    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
    func = relay.Function([ifm], relay.op.reshape(ifm, new_shape))
    mod = tvm.IRModule()
    mod["main"] = func
    mod = relay.transform.InferType()(mod)

    reshape_entry = (
        ethosu.ReshapeParams.composite_name,
        ethosu.reshape_pattern(),
        lambda pat: ethosu.ReshapeParams(pat).is_valid(),
    )
    mod = partition_ethosu_by_table(mod, [reshape_entry])

    # Apply the rewriters one after another, storing the result back into the
    # module after each step so the next rewriter sees the updated function.
    for rewriter_cls in (legalize.ReshapeRewriter, legalize.NoOpRewriter):
        mod[ext_func_name] = dataflow_pattern.rewrite(rewriter_cls(), mod[ext_func_name])
    mod = relay.transform.InferType()(mod)

    identity = mod[ext_func_name].body
    assert identity.op.name == "contrib.ethosu.identity"

    # The reshape must still be present as the identity's input.
    inner_reshape = identity.args[0]
    assert inner_reshape.op.name == "reshape"

    # The identity's output shape must match the requested reshape shape.
    assert tuple(identity.checked_type.shape) == new_shape