def callback(
    self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
    """Re-emit the matched strided slice as a standalone ``strided_slice`` op.

    The slice attributes (begin, end, axes, slice_mode) are read from
    ``post.op.body`` via ``ethosu_patterns.StridedSliceParams``. They are
    re-applied to the call's first argument.

    Args:
        pre: The matched expression before rewriting (unused here).
        post: The matched call after its inputs were rewritten.
            ``post.args[0]`` is the tensor being sliced and ``post.op.body``
            is the expression the slice attributes are read from.
        node_map: Pattern-node to matched-expression map (unused here).

    Returns:
        A ``relay.op.strided_slice`` call on ``post.args[0]``.
    """
    slice_input = post.args[0]

    # TODO(lhutton1) For an unknown reason compilation will fail for strides of 4
    # dimensions, so we cannot use params.strides as this will sometimes give
    # strides as [1, 1, 1, 1]. Since we only support strides of 1, hardcoding this
    # value for now. A single-element [1] is equivalent, because relay treats
    # missing per-axis stride entries as 1.
    strides = [1]

    params = ethosu_patterns.StridedSliceParams(post.op.body)
    strided_slice = relay.op.strided_slice(
        slice_input,
        params.begin,
        params.end,
        strides=strides,
        axes=params.axes,
        slice_mode=params.slice_mode,
    )
    return strided_slice
def callback(
    self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
    """Rebuild the matched slice as a plain ``relay.op.strided_slice`` call.

    Begin, end, axes and slice mode come from
    ``ethosu_patterns.StridedSliceParams`` built over ``post.op.body``.
    Strides are pinned to ``[1]`` (see the TODO below).

    Args:
        pre: The matched expression before rewriting (unused here).
        post: The matched call after its inputs were rewritten. Its first
            argument is the tensor that gets sliced.
        node_map: Pattern-node to matched-expression map (unused here).

    Returns:
        A new ``relay.op.strided_slice`` expression over ``post.args[0]``.
    """
    sliced_tensor = post.args[0]

    # TODO(lhutton1) For an unknown reason compilation will fail for strides of 4
    # dimensions, so params.strides cannot be forwarded: it will sometimes be
    # [1, 1, 1, 1]. Only strides of 1 are supported, so [1] is hardcoded for now.
    unit_strides = [1]

    slice_attrs = ethosu_patterns.StridedSliceParams(post.op.body)
    return relay.op.strided_slice(
        sliced_tensor,
        slice_attrs.begin,
        slice_attrs.end,
        strides=unit_strides,
        axes=slice_attrs.axes,
        slice_mode=slice_attrs.slice_mode,
    )
def test_relay_strided_slice_legalize(ifm_shape, begin, end):
    """Check that legalizing a partitioned strided_slice keeps it under an identity.

    The test builds a one-op ``strided_slice`` graph and partitions it with the
    Ethos-U strided-slice pattern. It then applies ``StridedSliceRewriter``
    followed by ``NoOpRewriter`` to the external function. Finally it checks
    that the result is an ``contrib.ethosu.identity`` whose input is the
    original slice and whose output shape is ``end - begin`` per axis.
    """
    ext_name = "tvmgen_default_ethos_u_main_0"

    ifm_var = relay.var("ifm", shape=ifm_shape, dtype="int8")
    main_func = relay.Function([ifm_var], relay.op.strided_slice(ifm_var, begin, end))
    mod = tvm.IRModule()
    mod["main"] = main_func
    mod = relay.transform.InferType()(mod)

    def _slice_is_supported(pat):
        return ethosu.StridedSliceParams(pat).is_valid()

    pattern_table = [
        (
            ethosu.StridedSliceParams.composite_name,
            ethosu.strided_slice_pattern(),
            _slice_is_supported,
        ),
    ]
    mod = partition_ethosu_by_table(mod, pattern_table)

    # Apply the two rewriters one after the other to the external function.
    # Each rewriter is instantiated right before it is used.
    for rewriter_cls in (legalize.StridedSliceRewriter, legalize.NoOpRewriter):
        mod[ext_name] = dataflow_pattern.rewrite(rewriter_cls(), mod[ext_name])
    mod = relay.transform.InferType()(mod)

    identity = mod[ext_name].body
    assert identity.op.name == "contrib.ethosu.identity"

    # The strided slice must still be present as the identity's input.
    inner_slice = identity.args[0]
    assert inner_slice.op.name == "strided_slice"

    # The identity's output shape must equal the sliced extent on every axis.
    expected_shape = [stop - start for start, stop in zip(begin, end)]
    assert list(identity.checked_type.shape) == expected_shape