示例#1
0
    def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr,
                 node_map: tvm.ir.container.Map) -> tvm.relay.Expr:
        """Rebuild the matched expression as a plain ``strided_slice`` call.

        The slice attributes (begin, end, strides, axes and slice mode) are
        read from ``ethosu_patterns.StridedSliceParams``, which is constructed
        from the body of ``post.op``.

        Parameters
        ----------
        pre : tvm.relay.Expr
            The expression before rewriting. Not used by this callback.
        post : tvm.relay.Expr
            The expression after rewriting. Its first argument becomes the
            input of the new ``strided_slice``.
        node_map : tvm.ir.container.Map
            Pattern-to-expression map. Not used by this callback.

        Returns
        -------
        tvm.relay.Expr
            The newly created ``strided_slice`` expression.
        """
        ifm = post.args[0]
        slice_params = ethosu_patterns.StridedSliceParams(post.op.body)
        return relay.op.strided_slice(
            ifm,
            slice_params.begin,
            slice_params.end,
            strides=slice_params.strides,
            axes=slice_params.axes,
            slice_mode=slice_params.slice_mode,
        )
示例#2
0
    def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr,
                 node_map: tvm.ir.container.Map) -> tvm.relay.Expr:
        """Rebuild the matched expression as a ``strided_slice`` with unit strides.

        Begin, end, axes and slice mode are read from
        ``ethosu_patterns.StridedSliceParams``, which is constructed from the
        body of ``post.op``. The strides are deliberately not taken from the
        parsed parameters; see the TODO in the body.

        Parameters
        ----------
        pre : tvm.relay.Expr
            The expression before rewriting. Not used by this callback.
        post : tvm.relay.Expr
            The expression after rewriting. Its first argument becomes the
            input of the new ``strided_slice``.
        node_map : tvm.ir.container.Map
            Pattern-to-expression map. Not used by this callback.

        Returns
        -------
        tvm.relay.Expr
            The newly created ``strided_slice`` expression.
        """
        ifm = post.args[0]

        # TODO(lhutton1) Compilation fails, for a reason not yet understood, when
        # the strides have 4 dimensions. params.strides can come back as
        # [1, 1, 1, 1], so it is not used here. Only strides of 1 are supported
        # anyway, hence the hardcoded value for now.
        unit_strides = [1]

        slice_params = ethosu_patterns.StridedSliceParams(post.op.body)
        return relay.op.strided_slice(
            ifm,
            slice_params.begin,
            slice_params.end,
            strides=unit_strides,
            axes=slice_params.axes,
            slice_mode=slice_params.slice_mode,
        )
示例#3
0
def test_relay_strided_slice_legalize(ifm_shape, begin, end):
    """Check legalization of a partitioned Relay ``strided_slice``.

    Builds an int8 ``strided_slice`` over an input of ``ifm_shape``, partitions
    it with the strided-slice pattern table, then applies
    ``StridedSliceRewriter`` followed by ``NoOpRewriter`` to the external
    function. The result must be a ``contrib.ethosu.identity`` call wrapping a
    ``strided_slice`` whose output shape is ``end - begin`` per dimension.
    """
    ext_func_name = "tvmgen_default_ethos_u_main_0"

    # Build and type-check the module under test.
    data = relay.var("ifm", shape=ifm_shape, dtype="int8")
    sliced = relay.op.strided_slice(data, begin, end)
    module = tvm.IRModule()
    module["main"] = relay.Function([data], sliced)
    module = relay.transform.InferType()(module)

    def _is_valid_slice(pat):
        return ethosu.StridedSliceParams(pat).is_valid()

    pattern_table = [
        (
            ethosu.StridedSliceParams.composite_name,
            ethosu.strided_slice_pattern(),
            _is_valid_slice,
        ),
    ]
    module = partition_ethosu_by_table(module, pattern_table)

    # Apply the rewriters one after another; each is instantiated right before use.
    for rewriter_cls in (legalize.StridedSliceRewriter, legalize.NoOpRewriter):
        module[ext_func_name] = dataflow_pattern.rewrite(rewriter_cls(), module[ext_func_name])
    module = relay.transform.InferType()(module)

    identity_call = module[ext_func_name].body
    assert identity_call.op.name == "contrib.ethosu.identity"

    # The strided_slice must still be present as the identity's input.
    slice_call = identity_call.args[0]
    assert slice_call.op.name == "strided_slice"

    # The identity's output shape must equal the shape produced by the slice.
    expected_shape = [stop - start for start, stop in zip(begin, end)]
    assert list(identity_call.checked_type.shape) == expected_shape