def test_gather_nd_grad():
    """Run ``check_grad`` on a Relay function that wraps ``relay.gather_nd``.

    The function takes a (2, 3) float64 ``data`` tensor and a (2, 4) int64
    ``indices`` tensor. Only ``data_np`` is passed as ``test_inputs`` --
    presumably so that the gradient is checked with respect to the data
    and not the integer indices (confirm against ``check_grad``).
    """
    data_type = relay.TensorType((2, 3), "float64")
    indices_type = relay.TensorType((2, 4), "int64")
    data = relay.var("data", data_type)
    indices = relay.var("indices", indices_type)

    gathered = relay.gather_nd(data, indices)
    fwd = relay.Function([data, indices], gathered)

    data_np = np.random.rand(2, 3).astype("float64")
    # Row 0 holds the indices for axis 0 of data, row 1 those for axis 1.
    indices_np = np.array(
        [
            [0, 1, 1, 0],
            [0, 1, 0, 0],
        ],
        dtype="int64",
    )
    check_grad(fwd, inputs=[data_np, indices_np], test_inputs=[data_np])
Beispiel #2
0
 def before():
     """Build a Relay function applying gather_nd to ``x`` concatenated with itself.

     ``x`` has an int64-typed shape of (10, 1); it is concatenated with
     itself along the last axis and then gathered with the constant int64
     indices ``[[0, 1], [1, 0]]``. The returned function's parameters are
     the free variables of the resulting expression.
     """
     dims = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
     x = relay.var("x", shape=dims)
     doubled = relay.concatenate([x, x], axis=-1)
     gather_indices = relay.expr.const([[0, 1], [1, 0]], dtype="int64")
     result = relay.gather_nd(doubled, indices=gather_indices)
     params = relay.analysis.free_vars(result)
     return relay.Function(params, result)
Beispiel #3
0
    def verify_gather_nd(xshape, yshape, y_data):
        """Compare ``relay.gather_nd`` against a NumPy reference on every target.

        Parameters
        ----------
        xshape : tuple of int
            Shape of the float32 data tensor ``x``.
        yshape : tuple of int
            Shape of the int32 indices tensor ``y``.
        y_data : sequence
            Index values passed to the compiled function as ``y``. Its
            leading axis is unpacked so that each entry indexes one axis of
            the data when the NumPy reference is computed.
        """
        x = relay.var("x", relay.TensorType(xshape, "float32"))
        y = relay.var("y", relay.TensorType(yshape, "int32"))
        z = relay.gather_nd(x, y)

        func = relay.Function([x, y], z)
        x_data = np.random.uniform(size=xshape).astype("float32")
        # Index with a tuple so every row of y_data addresses one axis of
        # x_data. A bare nested list is interpreted by current NumPy as a
        # single integer-array index along axis 0, which yields a reference
        # of the wrong shape for gather_nd semantics.
        ref_res = x_data[tuple(y_data)]

        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(func)(x_data, y_data)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
Beispiel #4
0
    def verify_gather_nd(xshape, yshape, y_data):
        """Check ``relay.gather_nd`` output against NumPy for each target and executor.

        Parameters
        ----------
        xshape : tuple of int
            Shape of the float32 data tensor ``x``.
        yshape : tuple of int
            Shape of the int32 indices tensor ``y``.
        y_data : sequence
            Index values supplied as ``y`` at run time. For the NumPy
            reference, its leading axis is unpacked into one index per data
            axis.
        """
        x = relay.var("x", relay.TensorType(xshape, "float32"))
        y = relay.var("y", relay.TensorType(yshape, "int32"))
        z = relay.gather_nd(x, y)

        func = relay.Function([x, y], z)
        x_data = np.random.uniform(size=xshape).astype("float32")
        # Convert to a tuple before indexing: NumPy no longer treats a nested
        # list as a tuple of per-axis indices, so x_data[y_data] would index
        # only along axis 0 and give an incorrectly shaped reference.
        ref_res = x_data[tuple(y_data)]

        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(func)(x_data, y_data)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
Beispiel #5
0
    def expected():
        """Build a Relay function that calls an inner "Primitive" function.

        The inner function takes ``p0`` (int64-typed shape (10, 1)) and
        ``p1`` (int64 tensor, shape (2, 2)), concatenates ``p0`` with itself
        along the last axis and applies ``gather_nd`` with ``p1`` as indices.
        The outer function takes ``x`` and calls the inner one with ``x`` and
        the constant indices ``[[0, 1], [1, 0]]``.
        """
        data_shape = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
        index_shape = (tvm.tir.const(2, "int64"), tvm.tir.const(2, "int64"))

        # Inner function, tagged with the "Primitive" attribute.
        p0 = relay.var("p0", shape=data_shape)
        p1 = relay.var("p1", shape=index_shape, dtype="int64")
        doubled = relay.concatenate([p0, p0], axis=-1)
        inner_body = relay.gather_nd(doubled, indices=p1)
        inner = relay.Function([p0, p1], inner_body).with_attr(
            "Primitive", tvm.tir.IntImm("int32", 1)
        )

        # Outer function: apply the inner function to x and the constant indices.
        x = relay.var("x", shape=data_shape)
        index_const = relay.const([[0, 1], [1, 0]], dtype="int64")
        call = relay.Call(inner, [x, index_const])
        return relay.Function([x], call)