Esempio n. 1
0
    def verify_gather(data, axis, indices, ref_res):
        """Check ``relay.gather`` against a precomputed reference result.

        A Relay function computing ``gather(x, axis, y)`` is built from
        float32 ``data`` and int32 ``indices``. It is then evaluated with both
        the "graph" and "debug" executors on every (target, ctx) pair yielded
        by ``tvm.testing.enabled_targets()``. Every output must match
        ``ref_res`` within a relative tolerance of 1e-5.

        Args:
            data: array-like input values; converted to float32.
            axis: axis forwarded unchanged to ``relay.gather``.
            indices: array-like gather indices; converted to int32.
            ref_res: array-like expected output of the gather.
        """
        data_np = np.asarray(data, dtype='float32')
        indices_np = np.asarray(indices, dtype='int32')
        expected = np.asarray(ref_res)

        # Typed Relay placeholders whose shapes mirror the concrete inputs.
        data_var = relay.var("x", relay.TensorType(data_np.shape, "float32"))
        indices_var = relay.var(
            "y", relay.TensorType(indices_np.shape, "int32"))
        gather_fn = relay.Function(
            [data_var, indices_var],
            relay.gather(data_var, axis, indices_var),
        )

        executor_kinds = ("graph", "debug")
        for target, ctx in tvm.testing.enabled_targets():
            for kind in executor_kinds:
                executor = relay.create_executor(kind, ctx=ctx, target=target)
                actual = executor.evaluate(gather_fn)(data_np, indices_np)
                tvm.testing.assert_allclose(
                    actual.asnumpy(), expected, rtol=1e-5)
Esempio n. 2
0
def relay_gather(c, data, axis, indices):
    """Lower a gather operation to ``relay.gather``.

    Args:
        c: object exposing ``ref``, used here to obtain the Relay expressions
            for ``data`` and ``indices`` (presumably the lowering/compilation
            context — confirm against callers).
        data: node whose Relay expression (via ``c.ref``) is gathered from.
        axis: node that must satisfy ``is_constant(int)``; its ``.value`` is
            passed as the gather axis.
        indices: node whose Relay expression (via ``c.ref``) supplies the
            gather indices.

    Returns:
        The expression produced by ``relay.gather``.
    """
    # relay.gather takes a static axis, so the axis node must be a constant
    # int (checked only when assertions are enabled).
    assert axis.is_constant(int)
    # Evaluation order matches the original call: data, axis, indices.
    data_expr = c.ref(data)
    axis_value = axis.value
    indices_expr = c.ref(indices)
    return relay.gather(data_expr, axis_value, indices_expr)