def test_relay_reshape_legalize(ifm_shape, new_shape):
    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
    reshape = relay.op.reshape(ifm, new_shape)
    func = relay.Function([ifm], reshape)
    mod = tvm.IRModule()
    mod["main"] = func
    mod = relay.transform.InferType()(mod)

    reshape_pattern_table = [
        (
            ethosu.ReshapeParams.composite_name,
            ethosu.reshape_pattern(),
            lambda pat: ethosu.ReshapeParams(pat).is_valid(),
        ),
    ]

    mod = partition_ethosu_by_table(mod, reshape_pattern_table)
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        legalize.ReshapeRewriter(), mod["tvmgen_default_ethos_u_main_0"]
    )
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        legalize.NoOpRewriter(), mod["tvmgen_default_ethos_u_main_0"]
    )
    mod = relay.transform.InferType()(mod)

    ext_func = mod["tvmgen_default_ethos_u_main_0"]
    identity = ext_func.body
    assert identity.op.name == "contrib.ethosu.identity"

    # check that the reshape is still there
    reshape = identity.args[0]
    assert reshape.op.name == "reshape"

    # check that identity's output shape matches reshape's output shape
    assert tuple(identity.checked_type.shape) == new_shape
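
# These tests rely on a `partition_ethosu_by_table` helper that partitions the
# module using only the pattern table under test, so that operators without
# full codegen support are not offloaded by accident. Below is a minimal
# sketch of such a helper, assuming the standard Relay partitioning passes and
# the backend's `preprocess` module (requires `from tvm import relay` and
# `from tvm.relay.backend.contrib.ethosu import preprocess`); the in-tree
# helper may differ in detail.
def partition_ethosu_by_table(mod, pattern_table):
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.MergeComposite(pattern_table)(mod)
    mod = relay.transform.AnnotateTarget("ethos-u")(mod)
    mod = relay.transform.MergeCompilerRegions()(mod)
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.PartitionGraph()(mod)
    mod = relay.transform.InferType()(mod)
    # Legalize the I/O of the partitioned external functions for the codegen.
    mod = preprocess.preprocess_ext_io()(mod)
    return mod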
def transform_module(
    self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
    for global_var, func in mod.functions.items():
        func = rewrite(PartitionedSplitRewriter(), func)
        func = rewrite(SplitRewriter(), func)
        mod.update_func(global_var, func)
    return mod
def transform_module(
    self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
    for global_var, func in mod.functions.items():
        func = rewrite(DepthwiseConv2DRewriter(), func)
        mod.update_func(global_var, func)
    return mod
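
# The two `transform_module` methods above belong to pass classes; in TVM such
# a class is typically registered as a module pass with the
# `tvm.ir.transform.module_pass` decorator. A minimal sketch with an
# illustrative class name, assuming `import tvm` and
# `from tvm.relay.dataflow_pattern import rewrite` at module scope:
@tvm.ir.transform.module_pass(opt_level=1)
class LegalizeSplitSketch:  # illustrative name, not the in-tree class
    def transform_module(self, mod, ctx):
        for global_var, func in mod.functions.items():
            func = rewrite(SplitRewriter(), func)
            mod.update_func(global_var, func)
        return mod


# Usage: the decorated class instantiates to a module-to-module pass, e.g.
# `mod = LegalizeSplitSketch()(mod)`.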
def test_tflite_tanh_legalize():
    dtype = "int8"
    ifm_shape = (1, 241, 132, 7)

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def tanh_func(self, x):
                op = tf.math.tanh(x)
                return op

        model = Model()
        concrete_func = model.tanh_func.get_concrete_function(
            tf.TensorSpec(ifm_shape, dtype=tf.float32)
        )

        # Convert the model
        def representative_dataset():
            for _ in range(100):
                data = np.random.rand(*tuple(ifm_shape))
                yield [data.astype(np.float32)]

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        tflite_model = converter.convert()
        return tflite_model

    tflite_graph = create_tflite_graph()
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

    mod, params = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"input": ifm_shape},
        dtype_dict={"input": dtype},
    )

    mod = ethosu.partition_for_ethosu(mod, params)
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        legalize.TanhRewriter(), mod["tvmgen_default_ethos_u_main_0"]
    )
    mod = relay.transform.InferType()(mod)

    func_body = mod["tvmgen_default_ethos_u_main_0"].body
    assert func_body.op.name == "contrib.ethosu.identity"
    assert func_body.attrs.activation == "TANH"
    assert tuple(func_body.args[0].checked_type.shape) == ifm_shape
    # the second argument is the 256-entry lookup table used for int8 tanh
    assert tuple(func_body.args[1].checked_type.shape) == (256,)
def test_relay_strided_slice_legalize(ifm_shape, begin, end):
    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
    strided_slice = relay.op.strided_slice(ifm, begin, end)
    func = relay.Function([ifm], strided_slice)
    mod = tvm.IRModule()
    mod["main"] = func
    mod = relay.transform.InferType()(mod)

    strided_slice_pattern_table = [
        (
            ethosu.StridedSliceParams.composite_name,
            ethosu.strided_slice_pattern(),
            lambda pat: ethosu.StridedSliceParams(pat).is_valid(),
        ),
    ]

    mod = partition_ethosu_by_table(mod, strided_slice_pattern_table)
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        legalize.StridedSliceRewriter(), mod["tvmgen_default_ethos_u_main_0"]
    )
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        legalize.NoOpRewriter(), mod["tvmgen_default_ethos_u_main_0"]
    )
    mod = relay.transform.InferType()(mod)

    ext_func = mod["tvmgen_default_ethos_u_main_0"]
    identity = ext_func.body
    assert identity.op.name == "contrib.ethosu.identity"

    # check that the strided_slice is still there
    strided_slice = identity.args[0]
    assert strided_slice.op.name == "strided_slice"

    # check that identity's output shape matches strided slice's output shape
    slice_shape = [a - b for a, b in zip(end, begin)]
    assert list(identity.checked_type.shape) == slice_shape
def test_binary_add_from_constant_scalar():
    dtype = "uint8"
    ifm_shape = (1, 4, 4, 8)

    def create_graph():
        inp = relay.var("input", shape=ifm_shape, dtype=dtype)
        scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype)
        add = relay.qnn.op.add(
            inp,
            scalar,
            relay.const(1.0, dtype="float32"),
            relay.const(0, dtype="int32"),
            relay.const(1.0, dtype="float32"),
            relay.const(0, dtype="int32"),
            relay.const(1.0, dtype="float32"),
            relay.const(0, dtype="int32"),
        )
        func = relay.Function(relay.analysis.free_vars(add), add)
        return tvm.IRModule.from_expr(func)

    def verify(ext_func):
        op = ext_func.body
        assert list(op.args[0].checked_type.shape) == [1, 4, 4, 8]
        assert list(op.args[1].checked_type.shape) == [1, 1, 1, 1]
        assert op.args[0].checked_type.dtype == "uint8"
        assert list(op.checked_type.shape) == [1, 4, 4, 8]
        assert op.checked_type.dtype == "uint8"
        assert op.attrs.operator_type == "ADD"

    rewriter = legalize.AddRewriter()
    pattern_table = [
        (
            ethosu.AddParams.composite_name,
            ethosu.qnn_add_pattern(),
            lambda pat: ethosu.AddParams(pat).is_valid(),
        ),
    ]

    mod = create_graph()
    mod = partition_ethosu_by_table(mod, pattern_table)
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        rewriter, mod["tvmgen_default_ethos_u_main_0"]
    )
    verify(mod["tvmgen_default_ethos_u_main_0"])
def test_ethosu_left_shift_binary_elemwise_legalize(ifm_shape, ifm2_shape, reversed_operands):
    dtype = "int32"
    operator_type = "SHL"

    def create_graph():
        input1 = relay.var("x1", shape=ifm_shape, dtype=dtype)
        input2 = relay.var("x2", shape=ifm2_shape, dtype=dtype)
        c1 = relay.left_shift(input1, input2)
        f = relay.Function([input1, input2], c1)
        mod = tvm.IRModule()
        mod["main"] = f
        return mod

    def verify(ext_func):
        out_shape = ifm2_shape if reversed_operands else ifm_shape
        shapes = [ifm_shape, ifm2_shape]
        ifm_index, ifm2_index = (1, 0) if reversed_operands else (0, 1)
        op = ext_func.body
        assert list(op.args[0].checked_type.shape) == shapes[ifm_index]
        assert list(op.args[1].checked_type.shape) == shapes[ifm2_index]
        assert op.args[0].checked_type.dtype == dtype
        assert list(op.checked_type.shape) == out_shape
        assert op.checked_type.dtype == dtype
        assert op.attrs.operator_type == operator_type
        assert op.attrs.reversed_operands == reversed_operands
        assert str(op.attrs.activation) == "NONE"

    rewriter = legalize.ShlRewriter()
    pattern_table = [
        (
            ethosu.ShlParams.composite_name,
            ethosu.shl_pattern(),
            lambda pat: ethosu.ShlParams(pat).is_valid(),
        ),
    ]

    mod = create_graph()
    mod = partition_ethosu_by_table(mod, pattern_table)
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        rewriter, mod["tvmgen_default_ethos_u_main_0"]
    )
    verify(mod["tvmgen_default_ethos_u_main_0"])
def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
    """This is the method that replaces the operations with hardware/codegen
    supported operations.
    """
    rewriters = [
        PartitionedSplitRewriter(),
        SplitRewriter(),
        Conv2DRewriter(),
        Conv2DTransposeRewriter(),
        DepthwiseConv2DRewriter(),
        FullyConnectedRewriter(),
        MaxPoolingRewriter(),
        AvgPoolingRewriter(),
        AddRewriter(),
        SubRewriter(),
        MulRewriter(),
        MinRewriter(),
        MaxRewriter(),
        ShlRewriter(),
        AbsRewriter(),
        TanhRewriter(),
        LeakyReLURewriter(),
        MeanRewriter(),
        ConcatRewriter(),
        SigmoidRewriter(),
        RequantizeRewriter(),
        Resize2dRewriter(),
        ExpandDimsRewriter(),
        SqueezeRewriter(),
        ReshapeRewriter(),
        StridedSliceRewriter(),
        NoOpRewriter(),
    ]
    for rewriter in rewriters:
        func = rewrite(rewriter, func)
    return func
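
# In the Ethos(TM)-U backend, `transform_npu_function` belongs to the
# `LegalizeEthosU` pass class, which is wrapped so that the rewrites only run
# on the partitioned NPU functions. The sketch below shows how the full
# legalization pipeline is typically applied to an already-partitioned module;
# the helper name `run_legalize_pipeline` is illustrative, not in-tree.
def run_legalize_pipeline(mod):
    """Sketch: legalize a module already partitioned by partition_for_ethosu."""
    from tvm import relay
    from tvm.relay.backend.contrib.ethosu import legalize

    mod = relay.transform.InferType()(mod)
    # Instantiating the decorated pass class yields a module-to-module pass.
    mod = legalize.LegalizeEthosU()(mod)
    mod = relay.transform.InferType()(mod)
    return mod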
def test_tflite_binary_elemwise_legalize(
    operator_type,
    ifm_shape,
    ifm2_shape,
    reversed_operands,
    activation_function,
):
    dtype = "int8"

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def tf_function(self, x, y):
                if operator_type == "ADD":
                    op = tf.math.add(x, y)
                elif operator_type == "SUB":
                    op = tf.math.subtract(x, y)
                elif operator_type == "MUL":
                    op = tf.math.multiply(x, y)
                elif operator_type == "MIN":
                    op = tf.math.minimum(x, y)
                elif operator_type == "MAX":
                    op = tf.math.maximum(x, y)
                if activation_function == "RELU":
                    op = tf.nn.relu(op)
                return op

        model = Model()
        concrete_func = model.tf_function.get_concrete_function(
            tf.TensorSpec(ifm_shape, dtype=tf.float32),
            tf.TensorSpec(ifm2_shape, dtype=tf.float32),
        )

        # Convert the model
        def representative_dataset():
            for _ in range(100):
                data = np.random.rand(*tuple(ifm_shape))
                data2 = np.random.rand(*tuple(ifm2_shape)) * 2
                yield [data.astype(np.float32), data2.astype(np.float32)]

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        tflite_model = converter.convert()
        return tflite_model

    def verify(ext_func):
        out_shape = ifm2_shape if reversed_operands else ifm_shape
        shapes = [ifm_shape, ifm2_shape]
        ifm_index, ifm2_index = (1, 0) if reversed_operands else (0, 1)
        op = ext_func.body
        assert list(op.args[0].checked_type.shape) == shapes[ifm_index]
        assert list(op.args[1].checked_type.shape) == shapes[ifm2_index]
        assert op.args[0].checked_type.dtype == dtype
        assert list(op.checked_type.shape) == out_shape
        assert op.checked_type.dtype == dtype
        assert op.attrs.operator_type == operator_type
        assert op.attrs.reversed_operands == reversed_operands
        if activation_function == "RELU":
            assert str(op.attrs.activation) == "CLIP"

    if operator_type == "ADD":
        rewriter = legalize.AddRewriter()
        pattern_table = [
            (
                ethosu.AddParams.composite_name,
                ethosu.qnn_add_pattern(),
                lambda pat: ethosu.AddParams(pat).is_valid(),
            ),
        ]
    elif operator_type == "SUB":
        rewriter = legalize.SubRewriter()
        pattern_table = [
            (
                ethosu.SubParams.composite_name,
                ethosu.qnn_subtract_pattern(),
                lambda pat: ethosu.SubParams(pat).is_valid(),
            ),
        ]
    elif operator_type == "MUL":
        rewriter = legalize.MulRewriter()
        pattern_table = [
            (
                ethosu.MulParams.composite_name,
                ethosu.qnn_mul_pattern(),
                lambda pat: ethosu.MulParams(pat).is_valid(),
            ),
        ]
    elif operator_type == "MIN":
        rewriter = legalize.MinRewriter()
        pattern_table = [
            (
                ethosu.MinParams.composite_name,
                ethosu.minimum_pattern(),
                lambda pat: ethosu.MinParams(pat).is_valid(),
            ),
        ]
    elif operator_type == "MAX":
        rewriter = legalize.MaxRewriter()
        pattern_table = [
            (
                ethosu.MaxParams.composite_name,
                ethosu.maximum_pattern(),
                lambda pat: ethosu.MaxParams(pat).is_valid(),
            ),
        ]

    tflite_graph = create_tflite_graph()
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

    mod, _ = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"x": ifm_shape, "y": ifm2_shape},
        dtype_dict={"x": dtype, "y": dtype},
    )
    mod = partition_ethosu_by_table(mod, pattern_table)
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        rewriter, mod["tvmgen_default_ethos_u_main_0"]
    )
    verify(mod["tvmgen_default_ethos_u_main_0"])
def test_tflite_pool2d_legalize(
    ifm_shape, pooling_type, strides, pool_shape, activation_function, padding
):
    dtype = "int8"

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def tf_function(self, x):
                if pooling_type == "MAX":
                    op = tf.nn.max_pool(x, pool_shape, strides, padding)
                elif pooling_type == "AVG":
                    op = tf.nn.avg_pool(x, pool_shape, strides, padding)
                if activation_function == "RELU":
                    op = tf.nn.relu(op)
                return op

        model = Model()
        concrete_func = model.tf_function.get_concrete_function(
            tf.TensorSpec(ifm_shape, dtype=tf.float32)
        )

        # Convert the model
        def representative_dataset():
            for _ in range(100):
                data = np.random.rand(*tuple(ifm_shape))
                yield [data.astype(np.float32)]

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        tflite_model = converter.convert()
        return tflite_model

    def verify(ext_func):
        ofm_shape = infra.compute_ofm_shape(ifm_shape, padding, pool_shape, strides)
        op = ext_func.body
        assert list(op.args[0].checked_type.shape) == ifm_shape
        assert op.args[0].checked_type.dtype == dtype
        assert list(op.checked_type.shape) == ofm_shape
        assert op.checked_type.dtype == dtype
        assert op.attrs.pooling_type == pooling_type
        assert list(op.attrs.strides) == strides
        assert list(op.attrs.padding) == infra.compute_padding_shape(
            ifm_shape, ofm_shape, padding, pool_shape, strides
        )
        assert list(op.attrs.pool_shape) == pool_shape
        assert op.attrs.ofm_channels == ifm_shape[3]
        if activation_function == "RELU":
            assert str(op.attrs.activation) == "CLIP"

    if pooling_type == "MAX":
        rewriter = legalize.MaxPoolingRewriter()
        pattern_table = [
            (
                ethosu.MaxPool2DParams.composite_name,
                ethosu.qnn_maxpool2d_pattern(),
                lambda pat: ethosu.MaxPool2DParams(pat).is_valid(),
            ),
        ]
    elif pooling_type == "AVG":
        rewriter = legalize.AvgPoolingRewriter()
        pattern_table = [
            (
                ethosu.AvgPool2DParams.composite_name,
                ethosu.qnn_avgpool2d_pattern(),
                lambda pat: ethosu.AvgPool2DParams(pat).is_valid(),
            ),
        ]

    tflite_graph = create_tflite_graph()
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

    mod, _ = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"x": ifm_shape},
        dtype_dict={"x": dtype},
    )
    mod = partition_ethosu_by_table(mod, pattern_table)
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        rewriter, mod["tvmgen_default_ethos_u_main_0"]
    )
    verify(mod["tvmgen_default_ethos_u_main_0"])
def test_tflite_depthwise_conv_2d_legalize(
    ifm_shape, kernel_shape, padding, strides, dilation, activation
):
    dtype = "int8"

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def depthwise_conv2d(self, x):
                weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
                weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
                # The input strides to the TensorFlow API need to be of shape 1x4
                tf_strides = [1, strides[0], strides[1], 1]
                op = tf.nn.depthwise_conv2d(
                    x, weight, strides=tf_strides, padding=padding, dilations=dilation
                )
                if activation:
                    op = tf.nn.relu(op)
                return op

        model = Model()
        concrete_func = model.depthwise_conv2d.get_concrete_function(
            tf.TensorSpec(ifm_shape, dtype=tf.float32)
        )

        # Convert the model
        def representative_dataset():
            for _ in range(100):
                data = np.random.rand(*tuple(ifm_shape))
                yield [data.astype(np.float32)]

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        tflite_model = converter.convert()
        return tflite_model

    def verify(ext_func):
        op = ext_func.body
        ofm_channels = op.attrs.ofm_channels

        # check IFM
        ifm = op.args[0].checked_type
        assert list(ifm.shape) == list(ifm_shape)
        assert str(ifm.dtype) == dtype
        assert ifm.shape[3] == ofm_channels

        # check OFM
        ofm = op.checked_type
        expected_ofm_shape = infra.compute_ofm_shape(
            ifm_shape, padding, kernel_shape, strides, dilation
        )
        assert list(ofm.shape) == list(expected_ofm_shape)
        assert str(ofm.dtype) == dtype
        assert ofm.shape[3] == ofm_channels

        # check weights
        weights_ohwi = op.args[1].data.asnumpy()
        assert str(weights_ohwi.dtype) == dtype
        assert weights_ohwi.shape[0] == ofm_channels
        assert weights_ohwi.shape[1] == kernel_shape[0]
        assert weights_ohwi.shape[2] == kernel_shape[1]
        assert weights_ohwi.shape[3] == 1  # only depth multiplier 1 is supported

        # Check that scale_bias matches weight tensor
        assert list(op.args[2].checked_type.shape)[0] == ofm_channels

        expected_padding = infra.compute_padding_shape(
            ifm_shape, expected_ofm_shape, padding, kernel_shape, strides, dilation
        )
        assert list(op.attrs.padding) == list(expected_padding)
        assert op.attrs.ofm_channels == ofm_channels
        assert list(op.attrs.strides) == list(strides)
        assert list(op.attrs.dilation) == list(dilation)
        if activation == "RELU":
            assert str(op.attrs.activation) == "CLIP"

    depthwise_pattern_table = [
        (
            ethosu.QnnDepthwiseConv2DParams.composite_name,
            ethosu.qnn_depthwise_conv2d_pattern(),
            lambda pat: ethosu.QnnDepthwiseConv2DParams(pat).is_valid(),
        )
    ]

    tflite_graph = create_tflite_graph()
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

    mod, params = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"input": ifm_shape},
        dtype_dict={"input": dtype},
    )

    mod["main"] = bind_params_by_name(mod["main"], params)
    mod = partition_ethosu_by_table(mod, depthwise_pattern_table)
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        legalize.DepthwiseConv2DRewriter(), mod["tvmgen_default_ethos_u_main_0"]
    )
    verify(mod["tvmgen_default_ethos_u_main_0"])
def test_tflite_unary_elemwise_legalize(
    operator_type,
    ifm_shape,
):
    dtype = "int8"

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def abs_func(self, x):
                if operator_type == "ABS":
                    op = tf.math.abs(x)
                return op

        model = Model()

        # Save the model
        concrete_func = model.abs_func.get_concrete_function(
            tf.TensorSpec(ifm_shape, dtype=tf.float32)
        )

        # Convert the model
        def representative_dataset():
            for _ in range(100):
                data = np.random.rand(*tuple(ifm_shape))
                yield [data.astype(np.float32)]

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        tflite_model = converter.convert()
        return tflite_model

    def verify(ext_func):
        out_shape = ifm_shape
        func_body = ext_func.body

        # If we legalized the unary elementwise op into 4D
        if func_body.op.name == "reshape":
            reshape = func_body
            unary = func_body.args[0]
            reshape2 = unary.args[0]

            # Check the input to the reshape
            reshape2_in_shape = [i for i in reshape2.args[0].checked_type.shape]
            assert reshape2_in_shape == ifm_shape

            # Check that the unary elementwise operator is 4D after reshape
            assert len(unary.checked_type.shape) == 4
            assert unary.args[0].checked_type.dtype == dtype

            # Check that the output of the graph has the same shape as input
            reshape_out_shape = [i for i in reshape.checked_type.shape]
            assert reshape_out_shape == ifm_shape
            assert unary.attrs.operator_type == operator_type
        else:
            unary = func_body

            # Check the IFM
            assert list(unary.args[0].checked_type.shape) == ifm_shape
            assert unary.args[0].checked_type.dtype == dtype

            # Check the OFM
            assert list(unary.checked_type.shape) == out_shape
            assert unary.checked_type.dtype == dtype

            # operator type check
            assert unary.attrs.operator_type == operator_type

    if operator_type == "ABS":
        rewriter = legalize.AbsRewriter()
        pattern_table = [
            (
                ethosu.AbsParams.composite_name,
                ethosu.abs_pattern(),
                lambda pat: ethosu.AbsParams(pat).is_valid(),
            ),
        ]

    tflite_graph = create_tflite_graph()
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
    mod, _ = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"input": ifm_shape},
        dtype_dict={"input": dtype},
    )
    mod = partition_ethosu_by_table(mod, pattern_table)
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        rewriter, mod["tvmgen_default_ethos_u_main_0"]
    )
    verify(mod["tvmgen_default_ethos_u_main_0"])
def test_mean(ifm_shape, axis, keep_dims, use_same_quantization):
    dtype = "int8"

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def tf_function(self, x):
                op = tf.math.reduce_mean(x, axis=axis, keepdims=keep_dims)
                return op

        model = Model()
        concrete_func = model.tf_function.get_concrete_function(
            tf.TensorSpec(ifm_shape, dtype=tf.float32)
        )

        # Convert the model
        def representative_dataset():
            for _ in range(100):
                data = np.random.rand(*tuple(ifm_shape))
                yield [data.astype(np.float32)]

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        tflite_model = converter.convert()

        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model, 0)
        mod, _ = relay.frontend.from_tflite(
            tflite_model,
            shape_dict={"input": ifm_shape},
            dtype_dict={"input": dtype},
        )
        return mod

    def create_relay_graph_with_same_quantization():
        ifm = relay.var("input", shape=ifm_shape, dtype=dtype)
        cast = relay.cast(ifm, dtype="int32")
        mean = relay.mean(cast, axis=axis, keepdims=keep_dims)
        requantize = relay.qnn.op.requantize(
            mean,
            input_scale=relay.const(1.0, dtype="float32"),
            input_zero_point=relay.const(0, dtype="int32"),
            output_scale=relay.const(1.0, dtype="float32"),
            output_zero_point=relay.const(0, dtype="int32"),
        )

        func = relay.Function(relay.analysis.free_vars(requantize), requantize)
        mod = tvm.IRModule.from_expr(func)
        return mod

    def verify(ext_func):
        out_var = ext_func.body

        # Walk from the output towards the input, recording which legalized
        # operators appear on the way.
        next_op = out_var
        mul_op = None
        pooling_op = None
        depthwise_op = None
        if (
            isinstance(next_op, relay.expr.Call)
            and isinstance(next_op.op, tvm.ir.op.Op)
            and next_op.op.name == "reshape"
        ):
            next_op = next_op.args[0]
        if util.is_named_ethosu_op(next_op, "binary_elementwise"):
            mul_op = next_op
            next_op = next_op.args[0]
        if util.is_named_ethosu_op(next_op, "pooling"):
            pooling_op = next_op
            next_op = next_op.args[0]
        if util.is_named_ethosu_op(next_op, "depthwise_conv2d"):
            depthwise_op = next_op
            next_op = next_op.args[0]
        while (
            isinstance(next_op, relay.expr.Call)
            and isinstance(next_op.op, tvm.ir.op.Op)
            and next_op.op.name == "reshape"
        ):
            next_op = next_op.args[0]
        in_var = next_op

        def calculate_expected_output_shape():
            # Reduced axes yield 1 only when keep_dims is set; all other axes
            # keep their original extent.
            for i in range(len(ifm_shape)):
                if i in axis:
                    if keep_dims:
                        yield 1
                else:
                    yield ifm_shape[i]

        out_shape = tuple(calculate_expected_output_shape())

        # check IFM
        assert tuple(in_var.checked_type.shape) == ifm_shape
        assert in_var.checked_type.dtype == dtype

        # check OFM
        assert tuple(out_var.checked_type.shape) == out_shape
        assert out_var.checked_type.dtype == dtype

        # check expected legalization case
        if axis in [(1, 2), (2, 1), (0, 1), (1, 0)] and keep_dims and dtype == "int8":
            assert depthwise_op and mul_op
            assert mul_op.attrs.operator_type == "MUL"
        elif pooling_op:
            attrs = pooling_op.attrs
            assert (
                attrs.ifm_scale == attrs.ofm_scale
                and attrs.ifm_zero_point == attrs.ofm_zero_point
            )
        else:
            assert depthwise_op
            assert not mul_op

    rewriter = legalize.MeanRewriter()
    pattern_table = [
        (
            ethosu.MeanParams.composite_name,
            ethosu.mean_pattern(),
            lambda pat: ethosu.MeanParams(pat).is_valid(),
        ),
    ]

    mod = (
        create_relay_graph_with_same_quantization()
        if use_same_quantization
        else create_tflite_graph()
    )
    mod = partition_ethosu_by_table(mod, pattern_table)
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        rewriter, mod["tvmgen_default_ethos_u_main_0"]
    )
    verify(mod["tvmgen_default_ethos_u_main_0"])