Example #1
def test_relay_reshape_legalize(ifm_shape, new_shape):

    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
    reshape = relay.op.reshape(ifm, new_shape)
    func = relay.Function([ifm], reshape)
    mod = tvm.IRModule()
    mod["main"] = func
    mod = relay.transform.InferType()(mod)

    reshape_pattern_table = [
        (
            ethosu.ReshapeParams.composite_name,
            ethosu.reshape_pattern(),
            lambda pat: ethosu.ReshapeParams(pat).is_valid(),
        ),
    ]

    mod = partition_ethosu_by_table(mod, reshape_pattern_table)
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        legalize.ReshapeRewriter(), mod["tvmgen_default_ethos_u_main_0"])
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        legalize.NoOpRewriter(), mod["tvmgen_default_ethos_u_main_0"])
    mod = relay.transform.InferType()(mod)

    ext_func = mod["tvmgen_default_ethos_u_main_0"]

    identity = ext_func.body
    assert identity.op.name == "contrib.ethosu.identity"

    # check that the reshape is still there
    reshape = identity.args[0]
    assert reshape.op.name == "reshape"

    # check that identity's output shape matches reshape's output shape
    assert tuple(identity.checked_type.shape) == new_shape
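For context, partition_ethosu_by_table is a test helper. Below is a minimal sketch of what a table-driven partition helper like it is expected to do, built from TVM's standard BYOC passes; the function name and wiring here are an assumption, not the helper's actual source.

from tvm import relay

# Hedged sketch: fuse matched patterns, mark them for the external target,
# and split them into a separate "tvmgen_default_ethos_u_main_0"-style function.
def partition_by_table(mod, pattern_table, target="ethos-u"):
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.MergeComposite(pattern_table)(mod)
    mod = relay.transform.AnnotateTarget(target)(mod)
    mod = relay.transform.MergeCompilerRegions()(mod)
    mod = relay.transform.PartitionGraph()(mod)
    return mod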
Example #2
def transform_module(self, mod: tvm.ir.IRModule,
                     ctx: tvm.ir.transform.PassContext) -> tvm.ir.IRModule:
    for global_var, func in mod.functions.items():
        func = rewrite(PartitionedSplitRewriter(), func)
        func = rewrite(SplitRewriter(), func)
        mod.update_func(global_var, func)
    return mod
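Snippets like this (and Example #3 below) are transform_module hooks of module passes. A hedged sketch of how such a class is typically registered and run, using TVM's module_pass decorator on a class that implements transform_module; the pass name is illustrative:

import tvm
from tvm.relay.dataflow_pattern import rewrite

# Hedged sketch: decorating the class turns it into a module pass; calling the
# instance on an IRModule invokes transform_module. The rewriters are the
# Ethos-U legalization rewriters used in the examples on this page.
@tvm.ir.transform.module_pass(opt_level=1, name="LegalizeSplit")
class LegalizeSplit:
    def transform_module(self, mod, ctx):
        for global_var, func in mod.functions.items():
            func = rewrite(PartitionedSplitRewriter(), func)
            func = rewrite(SplitRewriter(), func)
            mod.update_func(global_var, func)
        return mod

# Usage: mod = LegalizeSplit()(mod)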
Example #3
def transform_module(
    self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
    for global_var, func in mod.functions.items():
        func = rewrite(DepthwiseConv2DRewriter(), func)
        mod.update_func(global_var, func)
    return mod
Example #4
def test_tflite_tanh_legalize():
    dtype = "int8"
    ifm_shape = (1, 241, 132, 7)

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def tanh_func(self, x):
                op = tf.math.tanh(x)
                return op

        model = Model()
        concrete_func = model.tanh_func.get_concrete_function(
            tf.TensorSpec(ifm_shape, dtype=tf.float32))

        # Convert the model
        def representative_dataset():
            for _ in range(100):
                data = np.random.rand(*tuple(ifm_shape))
                yield [data.astype(np.float32)]

        converter = tf.lite.TFLiteConverter.from_concrete_functions(
            [concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8
        ]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        tflite_model = converter.convert()
        return tflite_model

    tflite_graph = create_tflite_graph()
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

    mod, params = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"input": ifm_shape},
        dtype_dict={"input": dtype},
    )

    mod = ethosu.partition_for_ethosu(mod, params)
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        legalize.TanhRewriter(), mod["tvmgen_default_ethos_u_main_0"])
    mod = relay.transform.InferType()(mod)

    func_body = mod["tvmgen_default_ethos_u_main_0"].body
    assert func_body.op.name == "contrib.ethosu.identity"
    assert func_body.attrs.activation == "TANH"
    assert tuple(func_body.args[0].checked_type.shape) == ifm_shape
    assert tuple(func_body.args[1].checked_type.shape) == (256,)
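The second argument asserted to have shape (256,) is the activation lookup table: for int8 data the tanh curve is baked into one entry per possible input value. A hedged sketch of how such a table can be built; the quantization parameters here are illustrative and the exact TVM construction may differ.

import numpy as np

# Hedged sketch: dequantize every int8 value, apply tanh in float, requantize.
def tanh_lut(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
    qmin, qmax = -128, 127
    inputs = np.arange(qmin, qmax + 1, dtype=np.int32)    # all 256 int8 values
    real = (inputs - ifm_zp) * ifm_scale                  # dequantize
    activated = np.tanh(real)                             # float activation
    quantized = np.round(activated / ofm_scale) + ofm_zp  # requantize
    return np.clip(quantized, qmin, qmax).astype(np.int8)

lut = tanh_lut(ifm_scale=0.02, ifm_zp=0, ofm_scale=1.0 / 128, ofm_zp=0)
assert lut.shape == (256,)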
Example #5
def test_relay_strided_slice_legalize(ifm_shape, begin, end):

    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
    strided_slice = relay.op.strided_slice(ifm, begin, end)
    func = relay.Function([ifm], strided_slice)
    mod = tvm.IRModule()
    mod["main"] = func
    mod = relay.transform.InferType()(mod)

    strided_slice_pattern_table = [
        (
            ethosu.StridedSliceParams.composite_name,
            ethosu.strided_slice_pattern(),
            lambda pat: ethosu.StridedSliceParams(pat).is_valid(),
        ),
    ]

    mod = partition_ethosu_by_table(mod, strided_slice_pattern_table)
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        legalize.StridedSliceRewriter(), mod["tvmgen_default_ethos_u_main_0"]
    )
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        legalize.NoOpRewriter(), mod["tvmgen_default_ethos_u_main_0"]
    )
    mod = relay.transform.InferType()(mod)

    ext_func = mod["tvmgen_default_ethos_u_main_0"]

    identity = ext_func.body
    assert identity.op.name == "contrib.ethosu.identity"

    # check that the strided_slice is still there
    strided_slice = identity.args[0]
    assert strided_slice.op.name == "strided_slice"

    # check that identity's output shape matches strided slice's output shape
    slice_shape = [a - b for a, b in zip(end, begin)]
    assert list(identity.checked_type.shape) == slice_shape
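The end - begin arithmetic in the final check relies on the implicit stride of 1. With explicit strides, standard strided_slice semantics (for non-negative begin/end and positive strides) use a ceiling division, illustrated below.

import math

# With stride 1, as in this test, each output axis is simply end - begin;
# the general rule is ceil((end - begin) / stride).
def slice_output_shape(begin, end, strides=None):
    strides = strides or [1] * len(begin)
    return [math.ceil((e - b) / s) for b, e, s in zip(begin, end, strides)]

assert slice_output_shape([0, 1], [2, 4]) == [2, 3]
assert slice_output_shape([0], [5], [2]) == [3]  # elements 0, 2, 4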
Example #6
def test_binary_add_from_constant_scalar():
    dtype = "uint8"
    ifm_shape = (1, 4, 4, 8)

    def create_graph():
        inp = relay.var("input", shape=ifm_shape, dtype=dtype)
        scalar = relay.const(np.ones((1, 1, 1, 1), dtype=dtype), dtype=dtype)
        add = relay.qnn.op.add(
            inp,
            scalar,
            relay.const(1.0, dtype="float32"),  # lhs_scale
            relay.const(0, dtype="int32"),  # lhs_zero_point
            relay.const(1.0, dtype="float32"),  # rhs_scale
            relay.const(0, dtype="int32"),  # rhs_zero_point
            relay.const(1.0, dtype="float32"),  # output_scale
            relay.const(0, dtype="int32"),  # output_zero_point
        )
        func = relay.Function(relay.analysis.free_vars(add), add)
        return tvm.IRModule.from_expr(func)

    def verify(ext_func):
        op = ext_func.body
        assert list(op.args[0].checked_type.shape) == [1, 4, 4, 8]
        assert list(op.args[1].checked_type.shape) == [1, 1, 1, 1]
        assert op.args[0].checked_type.dtype == "uint8"
        assert list(op.checked_type.shape) == [1, 4, 4, 8]
        assert op.checked_type.dtype == "uint8"
        assert op.attrs.operator_type == "ADD"

    rewriter = legalize.AddRewriter()
    pattern_table = [
        (
            ethosu.AddParams.composite_name,
            ethosu.qnn_add_pattern(),
            lambda pat: ethosu.AddParams(pat).is_valid(),
        ),
    ]

    mod = create_graph()
    mod = partition_ethosu_by_table(mod, pattern_table)

    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        rewriter, mod["tvmgen_default_ethos_u_main_0"]
    )
    verify(mod["tvmgen_default_ethos_u_main_0"])
Example #7
def test_ethosu_left_shift_binary_elemwise_legalize(ifm_shape, ifm2_shape, reversed_operands):
    dtype = "int32"
    operator_type = "SHL"

    def create_graph():
        input1 = relay.var("x1", shape=ifm_shape, dtype=dtype)
        input2 = relay.var("x2", shape=ifm2_shape, dtype=dtype)
        c1 = relay.left_shift(input1, input2)
        f = relay.Function([input1, input2], c1)
        mod = tvm.IRModule()
        mod["main"] = f
        return mod

    def verify(ext_func):
        out_shape = ifm2_shape if reversed_operands else ifm_shape
        shapes = [ifm_shape, ifm2_shape]
        ifm_index, ifm2_index = (1, 0) if reversed_operands else (0, 1)
        op = ext_func.body
        assert list(op.args[0].checked_type.shape) == shapes[ifm_index]
        assert list(op.args[1].checked_type.shape) == shapes[ifm2_index]
        assert op.args[0].checked_type.dtype == dtype
        assert list(op.checked_type.shape) == out_shape
        assert op.checked_type.dtype == dtype
        assert op.attrs.operator_type == operator_type
        assert op.attrs.reversed_operands == reversed_operands
        assert str(op.attrs.activation) == "NONE"

    rewriter = legalize.ShlRewriter()
    pattern_table = [
        (
            ethosu.ShlParams.composite_name,
            ethosu.shl_pattern(),
            lambda pat: ethosu.ShlParams(pat).is_valid(),
        ),
    ]

    mod = create_graph()
    mod = partition_ethosu_by_table(mod, pattern_table)

    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        rewriter, mod["tvmgen_default_ethos_u_main_0"]
    )
    verify(mod["tvmgen_default_ethos_u_main_0"])
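The reversed_operands attribute checked in verify records that the pattern matched with its inputs swapped relative to the NPU op's canonical operand order. A hedged illustration of what a consumer of the attribute has to do for a non-commutative operator like SHL:

# Hedged illustration: swap back before applying the operator so the original
# Relay semantics are preserved.
def apply_shl(args, reversed_operands):
    a, b = (args[1], args[0]) if reversed_operands else args
    return a << b

assert apply_shl((8, 2), False) == 32   # 8 << 2
assert apply_shl((8, 2), True) == 512   # operands stored swapped: 2 << 8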
Example #8
    def transform_npu_function(self, _,
                               func: relay.Function) -> relay.Function:
        """This is the method that replaces the operations with hardware/codegen supported
        operations.
        """
        rewriters = [
            PartitionedSplitRewriter(),
            SplitRewriter(),
            Conv2DRewriter(),
            Conv2DTransposeRewriter(),
            DepthwiseConv2DRewriter(),
            FullyConnectedRewriter(),
            MaxPoolingRewriter(),
            AvgPoolingRewriter(),
            AddRewriter(),
            SubRewriter(),
            MulRewriter(),
            MinRewriter(),
            MaxRewriter(),
            ShlRewriter(),
            AbsRewriter(),
            TanhRewriter(),
            LeakyReLURewriter(),
            MeanRewriter(),
            ConcatRewriter(),
            SigmoidRewriter(),
            RequantizeRewriter(),
            Resize2dRewriter(),
            ExpandDimsRewriter(),
            SqueezeRewriter(),
            ReshapeRewriter(),
            StridedSliceRewriter(),
            NoOpRewriter(),
        ]
        for rewriter in rewriters:
            func = rewrite(rewriter, func)

        return func
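A hedged sketch of how a per-function hook like transform_npu_function is typically driven over a whole module; the "Compiler" attribute filter is an assumption based on how BYOC partitioning marks external functions, and legalize_npu_functions is a hypothetical driver name.

import tvm

# Hedged sketch: only partitioned external functions carry the composite
# patterns the rewriters above target, so the driver filters on the
# "Compiler" attribute attached during partitioning.
def legalize_npu_functions(mod: tvm.ir.IRModule, legalizer) -> tvm.ir.IRModule:
    for global_var, func in mod.functions.items():
        if func.attrs and "Compiler" in func.attrs and func.attrs["Compiler"] == "ethos-u":
            mod.update_func(global_var,
                            legalizer.transform_npu_function(global_var, func))
    return mod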
Example #9
def test_tflite_binary_elemwise_legalize(
    operator_type,
    ifm_shape,
    ifm2_shape,
    reversed_operands,
    activation_function,
):
    dtype = "int8"

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def tf_function(self, x, y):
                if operator_type == "ADD":
                    op = tf.math.add(x, y)
                elif operator_type == "SUB":
                    op = tf.math.subtract(x, y)
                elif operator_type == "MUL":
                    op = tf.math.multiply(x, y)
                elif operator_type == "MIN":
                    op = tf.math.minimum(x, y)
                elif operator_type == "MAX":
                    op = tf.math.maximum(x, y)
                if activation_function == "RELU":
                    op = tf.nn.relu(op)
                return op

        model = Model()
        concrete_func = model.tf_function.get_concrete_function(
            tf.TensorSpec(ifm_shape, dtype=tf.float32), tf.TensorSpec(ifm2_shape, dtype=tf.float32)
        )

        # Convert the model
        def representative_dataset():
            for _ in range(100):
                data = np.random.rand(*tuple(ifm_shape))
                data2 = np.random.rand(*tuple(ifm2_shape)) * 2
                yield [data.astype(np.float32), data2.astype(np.float32)]

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        tflite_model = converter.convert()
        return tflite_model

    def verify(ext_func):
        out_shape = ifm2_shape if reversed_operands else ifm_shape
        shapes = [ifm_shape, ifm2_shape]
        ifm_index, ifm2_index = (1, 0) if reversed_operands else (0, 1)
        op = ext_func.body
        assert list(op.args[0].checked_type.shape) == shapes[ifm_index]
        assert list(op.args[1].checked_type.shape) == shapes[ifm2_index]
        assert op.args[0].checked_type.dtype == dtype
        assert list(op.checked_type.shape) == out_shape
        assert op.checked_type.dtype == dtype
        assert op.attrs.operator_type == operator_type
        assert op.attrs.reversed_operands == reversed_operands
        if activation_function == "RELU":
            assert str(op.attrs.activation) == "CLIP"

    if operator_type == "ADD":
        rewriter = legalize.AddRewriter()
        pattern_table = [
            (
                ethosu.AddParams.composite_name,
                ethosu.qnn_add_pattern(),
                lambda pat: ethosu.AddParams(pat).is_valid(),
            ),
        ]
    elif operator_type == "SUB":
        rewriter = legalize.SubRewriter()
        pattern_table = [
            (
                ethosu.SubParams.composite_name,
                ethosu.qnn_subtract_pattern(),
                lambda pat: ethosu.SubParams(pat).is_valid(),
            ),
        ]
    elif operator_type == "MUL":
        rewriter = legalize.MulRewriter()
        pattern_table = [
            (
                ethosu.MulParams.composite_name,
                ethosu.qnn_mul_pattern(),
                lambda pat: ethosu.MulParams(pat).is_valid(),
            ),
        ]
    elif operator_type == "MIN":
        rewriter = legalize.MinRewriter()
        pattern_table = [
            (
                ethosu.MinParams.composite_name,
                ethosu.minimum_pattern(),
                lambda pat: ethosu.MinParams(pat).is_valid(),
            ),
        ]
    elif operator_type == "MAX":
        rewriter = legalize.MaxRewriter()
        pattern_table = [
            (
                ethosu.MaxParams.composite_name,
                ethosu.maximum_pattern(),
                lambda pat: ethosu.MaxParams(pat).is_valid(),
            ),
        ]

    tflite_graph = create_tflite_graph()
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

    mod, _ = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"x": ifm_shape, "y": ifm2_shape},
        dtype_dict={"x": dtype, "y": dtype},
    )
    mod = partition_ethosu_by_table(mod, pattern_table)

    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        rewriter, mod["tvmgen_default_ethos_u_main_0"]
    )
    verify(mod["tvmgen_default_ethos_u_main_0"])
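The if/elif ladder above pairs each operator with its rewriter, params class, and pattern in the same shape five times; it could equally be table-driven. A sketch using only names already present in the test:

# Table-driven alternative to the if/elif ladder; make_rewriter_and_table is
# a hypothetical helper name.
BINARY_LEGALIZERS = {
    "ADD": (legalize.AddRewriter, ethosu.AddParams, ethosu.qnn_add_pattern),
    "SUB": (legalize.SubRewriter, ethosu.SubParams, ethosu.qnn_subtract_pattern),
    "MUL": (legalize.MulRewriter, ethosu.MulParams, ethosu.qnn_mul_pattern),
    "MIN": (legalize.MinRewriter, ethosu.MinParams, ethosu.minimum_pattern),
    "MAX": (legalize.MaxRewriter, ethosu.MaxParams, ethosu.maximum_pattern),
}

def make_rewriter_and_table(operator_type):
    rewriter_cls, params_cls, pattern = BINARY_LEGALIZERS[operator_type]
    table = [(params_cls.composite_name, pattern(),
              lambda pat: params_cls(pat).is_valid())]
    return rewriter_cls(), table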
Example #10
def test_tflite_pool2d_legalize(
    ifm_shape, pooling_type, strides, pool_shape, activation_function, padding
):
    dtype = "int8"

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def tf_function(self, x):
                if pooling_type == "MAX":
                    op = tf.nn.max_pool(x, pool_shape, strides, padding)
                elif pooling_type == "AVG":
                    op = tf.nn.avg_pool(x, pool_shape, strides, padding)
                if activation_function == "RELU":
                    op = tf.nn.relu(op)
                return op

        model = Model()
        concrete_func = model.tf_function.get_concrete_function(
            tf.TensorSpec(ifm_shape, dtype=tf.float32)
        )

        # Convert the model
        def representative_dataset():
            for _ in range(100):
                data = np.random.rand(*tuple(ifm_shape))
                yield [data.astype(np.float32)]

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        tflite_model = converter.convert()
        return tflite_model

    def verify(ext_func):
        ofm_shape = infra.compute_ofm_shape(ifm_shape, padding, pool_shape, strides)
        op = ext_func.body
        assert list(op.args[0].checked_type.shape) == ifm_shape
        assert op.args[0].checked_type.dtype == dtype
        assert list(op.checked_type.shape) == ofm_shape
        assert op.checked_type.dtype == dtype
        assert op.attrs.pooling_type == pooling_type
        assert list(op.attrs.strides) == strides
        assert list(op.attrs.padding) == infra.compute_padding_shape(
            ifm_shape, ofm_shape, padding, pool_shape, strides
        )
        assert list(op.attrs.pool_shape) == pool_shape
        assert op.attrs.ofm_channels == ifm_shape[3]
        if activation_function == "RELU":
            assert str(op.attrs.activation) == "CLIP"

    if pooling_type == "MAX":
        rewriter = legalize.MaxPoolingRewriter()
        pattern_table = [
            (
                ethosu.MaxPool2DParams.composite_name,
                ethosu.qnn_maxpool2d_pattern(),
                lambda pat: ethosu.MaxPool2DParams(pat).is_valid(),
            ),
        ]
    elif pooling_type == "AVG":
        rewriter = legalize.AvgPoolingRewriter()
        pattern_table = [
            (
                ethosu.AvgPool2DParams.composite_name,
                ethosu.qnn_avgpool2d_pattern(),
                lambda pat: ethosu.AvgPool2DParams(pat).is_valid(),
            ),
        ]

    tflite_graph = create_tflite_graph()
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

    mod, _ = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"x": ifm_shape},
        dtype_dict={"x": dtype},
    )
    mod = partition_ethosu_by_table(mod, pattern_table)

    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        rewriter, mod["tvmgen_default_ethos_u_main_0"]
    )
    verify(mod["tvmgen_default_ethos_u_main_0"])
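infra.compute_ofm_shape is a test helper; for NHWC pooling it presumably implements TensorFlow's standard SAME/VALID output-size rules, sketched here as an assumption of its behavior:

import math

# Hedged sketch of the NHWC pooling output-shape rule the infra helper is
# assumed to implement (standard TensorFlow SAME/VALID semantics).
def compute_ofm_shape(ifm_shape, padding, pool_shape, strides):
    n, h, w, c = ifm_shape
    if padding == "SAME":
        oh, ow = math.ceil(h / strides[0]), math.ceil(w / strides[1])
    else:  # "VALID"
        oh = math.ceil((h - pool_shape[0] + 1) / strides[0])
        ow = math.ceil((w - pool_shape[1] + 1) / strides[1])
    return [n, oh, ow, c]

assert compute_ofm_shape([1, 8, 8, 4], "SAME", [2, 2], [2, 2]) == [1, 4, 4, 4]
assert compute_ofm_shape([1, 8, 8, 4], "VALID", [3, 3], [1, 1]) == [1, 6, 6, 4]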
Example #11
def test_tflite_depthwise_conv_2d_legalize(
    ifm_shape, kernel_shape, padding, strides, dilation, activation
):
    dtype = "int8"

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def depthwise_conv2d(self, x):
                weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
                weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
                # The input strides to the TensorFlow API needs to be of shape 1x4
                tf_strides = [1, strides[0], strides[1], 1]
                op = tf.nn.depthwise_conv2d(
                    x, weight, strides=tf_strides, padding=padding, dilations=dilation
                )
                if activation:
                    op = tf.nn.relu(op)
                return op

        model = Model()
        concrete_func = model.depthwise_conv2d.get_concrete_function(
            tf.TensorSpec(ifm_shape, dtype=tf.float32)
        )

        # Convert the model
        def representative_dataset():
            for _ in range(100):
                data = np.random.rand(*tuple(ifm_shape))
                yield [data.astype(np.float32)]

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        tflite_model = converter.convert()
        return tflite_model

    def verify(ext_func):
        op = ext_func.body
        ofm_channels = op.attrs.ofm_channels

        # check IFM
        ifm = op.args[0].checked_type
        assert list(ifm.shape) == list(ifm_shape)
        assert str(ifm.dtype) == dtype
        assert ifm.shape[3] == ofm_channels

        # check OFM
        ofm = op.checked_type
        expected_ofm_shape = infra.compute_ofm_shape(
            ifm_shape, padding, kernel_shape, strides, dilation
        )
        assert list(ofm.shape) == list(expected_ofm_shape)
        assert str(ofm.dtype) == dtype
        assert ofm.shape[3] == ofm_channels

        # check weights
        weights_ohwi = op.args[1].data.asnumpy()
        assert str(weights_ohwi.dtype) == dtype
        assert weights_ohwi.shape[0] == ofm_channels
        assert weights_ohwi.shape[1] == kernel_shape[0]
        assert weights_ohwi.shape[2] == kernel_shape[1]
        assert weights_ohwi.shape[3] == 1  # only depth multiplier 1 is supported

        # Check that scale_bias matches weight tensor
        assert list(op.args[2].checked_type.shape)[0] == ofm_channels

        expected_padding = infra.compute_padding_shape(
            ifm_shape, expected_ofm_shape, padding, kernel_shape, strides, dilation
        )
        assert list(op.attrs.padding) == list(expected_padding)
        assert op.attrs.ofm_channels == ofm_channels
        assert list(op.attrs.strides) == list(strides)
        assert list(op.attrs.dilation) == list(dilation)
        if activation == "RELU":
            assert str(op.attrs.activation) == "CLIP"

    depthwise_pattern_table = [
        (
            ethosu.QnnDepthwiseConv2DParams.composite_name,
            ethosu.qnn_depthwise_conv2d_pattern(),
            lambda pat: ethosu.QnnDepthwiseConv2DParams(pat).is_valid(),
        )
    ]

    tflite_graph = create_tflite_graph()
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

    mod, params = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"input": ifm_shape},
        dtype_dict={"input": dtype},
    )

    mod["main"] = bind_params_by_name(mod["main"], params)
    mod = partition_ethosu_by_table(mod, depthwise_pattern_table)

    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        legalize.DepthwiseConv2DRewriter(), mod["tvmgen_default_ethos_u_main_0"]
    )
    verify(mod["tvmgen_default_ethos_u_main_0"])
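verify reads the weights in OHWI order, while tf.nn.depthwise_conv2d takes an HWIM filter. For depth multiplier 1 the relationship is a single transpose, illustrated below; the actual legalization pass also quantizes the values, which this sketch omits.

import numpy as np

# HWIM depthwise filter: (kh, kw, channels, multiplier=1) ...
weights_hwim = np.zeros((3, 3, 8, 1), dtype=np.int8)
# ... rearranged to OHWI: each input channel becomes one output channel.
weights_ohwi = np.transpose(weights_hwim, (2, 0, 1, 3))
assert weights_ohwi.shape == (8, 3, 3, 1)  # matches the shape checks above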
Example #12
def test_tflite_unary_elemwise_legalize(
    operator_type,
    ifm_shape,
):
    dtype = "int8"

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def abs_func(self, x):
                if operator_type == "ABS":
                    op = tf.math.abs(x)
                return op

        model = Model()

        # Save the model
        concrete_func = model.abs_func.get_concrete_function(
            tf.TensorSpec(ifm_shape, dtype=tf.float32))

        # Convert the model
        def representative_dataset():
            for _ in range(100):
                data = np.random.rand(*tuple(ifm_shape))
                yield [data.astype(np.float32)]

        converter = tf.lite.TFLiteConverter.from_concrete_functions(
            [concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8
        ]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        tflite_model = converter.convert()
        return tflite_model

    def verify(ext_func):
        out_shape = ifm_shape
        func_body = ext_func.body

        # If we legalized the unary elementwise op into 4D
        if func_body.op.name == "reshape":
            reshape = func_body
            unary = func_body.args[0]
            reshape2 = unary.args[0]

            # Check the input to the reshape
            reshape2_in_shape = list(reshape2.args[0].checked_type.shape)
            assert reshape2_in_shape == ifm_shape

            # Check that the unary elementwise operator is 4D after reshape
            assert len(unary.checked_type.shape) == 4
            assert unary.args[0].checked_type.dtype == dtype

            # Check that the output of the graph has the same shape as input
            reshape_out_shape = list(reshape.checked_type.shape)
            assert reshape_out_shape == ifm_shape
            assert unary.attrs.operator_type == operator_type

        else:
            unary = func_body

            # Check the IFM
            assert list(unary.args[0].checked_type.shape) == ifm_shape
            assert unary.args[0].checked_type.dtype == dtype

            # Check the OFM
            assert list(unary.checked_type.shape) == out_shape
            assert unary.checked_type.dtype == dtype

            # operator type check
            assert unary.attrs.operator_type == operator_type

    if operator_type == "ABS":
        rewriter = legalize.AbsRewriter()
        pattern_table = [
            (
                ethosu.AbsParams.composite_name,
                ethosu.abs_pattern(),
                lambda pat: ethosu.AbsParams(pat).is_valid(),
            ),
        ]

    tflite_graph = create_tflite_graph()
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
    mod, _ = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"input": ifm_shape},
        dtype_dict={"input": dtype},
    )
    mod = partition_ethosu_by_table(mod, pattern_table)
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        rewriter, mod["tvmgen_default_ethos_u_main_0"])
    verify(mod["tvmgen_default_ethos_u_main_0"])
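The first branch of verify covers inputs that are not already 4D: the NPU elementwise op works on 4D NHWC tensors, so the legalizer brackets it with reshapes (reshape to 4D, apply the op, reshape back). A hedged sketch of the shape-padding convention; the real pass may place the original dimensions differently.

# Hedged sketch: left-pad the shape with ones so a rank-<4 tensor becomes a
# 4D NHWC-like tensor; a second reshape restores the original shape afterwards.
def pad_shape_to_4d(shape):
    return [1] * (4 - len(shape)) + list(shape)

assert pad_shape_to_4d([24]) == [1, 1, 1, 24]
assert pad_shape_to_4d([6, 4]) == [1, 1, 6, 4]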
Example #13
def test_mean(ifm_shape, axis, keep_dims, use_same_quantization):
    dtype = "int8"

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def tf_function(self, x):
                op = tf.math.reduce_mean(x, axis=axis, keepdims=keep_dims)
                return op

        model = Model()
        concrete_func = model.tf_function.get_concrete_function(
            tf.TensorSpec(ifm_shape, dtype=tf.float32))

        # Convert the model
        def representative_dataset():
            for _ in range(100):
                data = np.random.rand(*tuple(ifm_shape))
                yield [data.astype(np.float32)]

        converter = tf.lite.TFLiteConverter.from_concrete_functions(
            [concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8
        ]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        tflite_model = converter.convert()
        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model, 0)

        mod, _ = relay.frontend.from_tflite(
            tflite_model,
            shape_dict={"input": ifm_shape},
            dtype_dict={"input": dtype},
        )
        return mod

    def create_relay_graph_with_same_quantization():
        ifm = relay.var("input", shape=ifm_shape, dtype=dtype)
        cast = relay.cast(ifm, dtype="int32")
        mean = relay.mean(cast, axis=axis, keepdims=keep_dims)
        requantize = relay.qnn.op.requantize(
            mean,
            input_scale=relay.const(1.0, dtype="float32"),
            input_zero_point=relay.const(0, dtype="int32"),
            output_scale=relay.const(1.0, dtype="float32"),
            output_zero_point=relay.const(0, dtype="int32"),
        )

        func = relay.Function(relay.analysis.free_vars(requantize), requantize)
        mod = tvm.IRModule.from_expr(func)
        return mod

    def verify(ext_func):
        out_var = ext_func.body

        next_op = out_var
        mul_op = None
        pooling_op = None
        depthwise_op = None
        if (isinstance(next_op, relay.expr.Call)
                and isinstance(next_op.op, tvm.ir.op.Op)
                and next_op.op.name == "reshape"):
            next_op = next_op.args[0]
        if util.is_named_ethosu_op(next_op, "binary_elementwise"):
            mul_op = next_op
            next_op = next_op.args[0]
        if util.is_named_ethosu_op(next_op, "pooling"):
            pooling_op = next_op
            next_op = next_op.args[0]
        if util.is_named_ethosu_op(next_op, "depthwise_conv2d"):
            depthwise_op = next_op
            next_op = next_op.args[0]
        while (isinstance(next_op, relay.expr.Call)
               and isinstance(next_op.op, tvm.ir.op.Op)
               and next_op.op.name == "reshape"):
            next_op = next_op.args[0]
        in_var = next_op

        def calculate_expected_output_shape():
            for i in range(len(ifm_shape)):
                if i in axis:
                    if keep_dims:
                        yield 1
                else:
                    yield ifm_shape[i]

        out_shape = tuple(calculate_expected_output_shape())

        # check IFM
        assert tuple(in_var.checked_type.shape) == ifm_shape
        assert in_var.checked_type.dtype == dtype

        # check OFM
        assert tuple(out_var.checked_type.shape) == out_shape
        assert out_var.checked_type.dtype == dtype

        # check expected legalization case
        if axis in [(1, 2), (2, 1), (0, 1),
                    (1, 0)] and keep_dims and dtype == "int8":
            assert depthwise_op and mul_op
            assert mul_op.attrs.operator_type == "MUL"
        elif pooling_op:
            attrs = pooling_op.attrs
            assert (attrs.ifm_scale == attrs.ofm_scale
                    and attrs.ifm_zero_point == attrs.ofm_zero_point)
        else:
            assert depthwise_op
            assert not mul_op

    rewriter = legalize.MeanRewriter()
    pattern_table = [
        (
            ethosu.MeanParams.composite_name,
            ethosu.mean_pattern(),
            lambda pat: ethosu.MeanParams(pat).is_valid(),
        ),
    ]

    mod = (create_relay_graph_with_same_quantization()
           if use_same_quantization else create_tflite_graph())
    mod = partition_ethosu_by_table(mod, pattern_table)
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        rewriter, mod["tvmgen_default_ethos_u_main_0"])
    verify(mod["tvmgen_default_ethos_u_main_0"])