Beispiel #1
0
    def verify_sort(shape, axis, is_ascend, is_dyn=False, in_dtype="float32"):
        """Check ``relay.sort`` against a NumPy reference on every enabled target.

        Builds a Relay function that sorts one input tensor. It evaluates that
        function with each applicable executor kind on every (target, ctx) pair
        from ``tvm.testing.enabled_targets()``, and compares the output with a
        NumPy-computed reference.

        Parameters
        ----------
        shape : tuple of int
            Shape of the random input tensor.
        axis : int
            Axis along which to sort.
        is_ascend : bool
            Sort ascending if True, descending otherwise.
        is_dyn : bool, optional
            If True, declare every input dimension as ``relay.Any()`` and run
            the "vm" and "debug" executors. Otherwise use the static shape with
            the "graph" and "debug" executors.
        in_dtype : str, optional
            Element dtype of the input tensor (default ``"float32"``).
        """
        if is_dyn:
            x = relay.var(
                "x", relay.TensorType([relay.Any()] * len(shape), in_dtype))
        else:
            x = relay.var("x", relay.TensorType(shape, in_dtype))
        z = relay.sort(x, axis=axis, is_ascend=is_ascend)
        func = relay.Function([x], z)

        if in_dtype.startswith(("int", "uint")):
            # uniform() draws from [0, 1). Casting that to an integer dtype
            # yields all zeros, which would make the comparison vacuous.
            x_data = np.random.randint(0, 100, size=shape).astype(in_dtype)
        else:
            x_data = np.random.uniform(size=shape).astype(in_dtype)

        ref_res = np.sort(x_data, axis=axis)
        if not is_ascend:
            # Reverse the ascending result rather than negating the data.
            # Negation wraps around for unsigned dtypes and would produce a
            # wrong reference.
            ref_res = np.flip(ref_res, axis=axis)

        backends = ["vm", "debug"] if is_dyn else ["graph", "debug"]
        for target, ctx in tvm.testing.enabled_targets():
            for kind in backends:
                # Build a fresh module for each executor so that no state is
                # shared between executor instances.
                mod = tvm.ir.IRModule.from_expr(func)
                intrp = relay.create_executor(kind,
                                              mod=mod,
                                              ctx=ctx,
                                              target=target)
                op_res = intrp.evaluate()(x_data)
                tvm.testing.assert_allclose(op_res.asnumpy(),
                                            ref_res,
                                            rtol=1e-5)
Beispiel #2
0
    def verify_sort(shape, axis, is_ascend, is_dyn=False, in_dtype="float32"):
        """Check ``relay.sort`` against a NumPy reference on every enabled target.

        Builds a Relay function that sorts one input tensor. It evaluates that
        function on every (target, device) pair from
        ``tvm.testing.enabled_targets()``, and compares the output with a
        NumPy-computed reference.

        Parameters
        ----------
        shape : tuple of int
            Shape of the random input tensor.
        axis : int
            Axis along which to sort.
        is_ascend : bool
            Sort ascending if True, descending otherwise.
        is_dyn : bool, optional
            If True, declare every input dimension as ``relay.Any()`` and run
            the "vm" executor. Otherwise use the static shape with the
            "graph" executor.
        in_dtype : str, optional
            Element dtype of the input tensor (default ``"float32"``).
        """
        if is_dyn:
            x = relay.var(
                "x", relay.TensorType([relay.Any()] * len(shape), in_dtype))
        else:
            x = relay.var("x", relay.TensorType(shape, in_dtype))
        z = relay.sort(x, axis=axis, is_ascend=is_ascend)
        func = relay.Function([x], z)

        if in_dtype.startswith(("int", "uint")):
            # uniform() draws from [0, 1). Casting that to an integer dtype
            # yields all zeros, which would make the comparison vacuous.
            x_data = np.random.randint(0, 100, size=shape).astype(in_dtype)
        else:
            x_data = np.random.uniform(size=shape).astype(in_dtype)

        ref_res = np.sort(x_data, axis=axis)
        if not is_ascend:
            # Reverse the ascending result rather than negating the data.
            # Negation wraps around for unsigned dtypes and would produce a
            # wrong reference.
            ref_res = np.flip(ref_res, axis=axis)

        backend = "vm" if is_dyn else "graph"
        for target, dev in tvm.testing.enabled_targets():
            mod = tvm.ir.IRModule.from_expr(func)
            op_res = relay.create_executor(backend,
                                           mod=mod,
                                           device=dev,
                                           target=target).evaluate()(x_data)
            tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)