def verify_gather(data, axis, indices, ref_res):
    """Check ``relay.gather`` against a precomputed reference result.

    The input tensor is cast to float32 and the indices to int32 before being
    fed to a two-argument Relay function ``(x, y) -> gather(x, axis, y)``.
    The function is then run on every target reported by
    ``tvm.testing.enabled_targets()``, once with the "graph" executor and once
    with the "debug" executor. Each output is compared to ``ref_res``
    (relative tolerance 1e-5).

    Args:
        data: array-like input tensor; converted with dtype 'float32'.
        axis: axis argument passed unchanged to ``relay.gather``.
        indices: array-like index tensor; converted with dtype 'int32'.
        ref_res: array-like expected output.
    """
    np_data = np.asarray(data, dtype='float32')
    np_indices = np.asarray(indices, dtype='int32')
    expected = np.asarray(ref_res)

    # Typed placeholders whose shapes match the concrete inputs.
    data_var = relay.var("x", relay.TensorType(np_data.shape, "float32"))
    indices_var = relay.var("y", relay.TensorType(np_indices.shape, "int32"))
    gathered = relay.gather(data_var, axis, indices_var)
    gather_fn = relay.Function([data_var, indices_var], gathered)

    executor_kinds = ["graph", "debug"]
    for target, ctx in tvm.testing.enabled_targets():
        for kind in executor_kinds:
            # A fresh executor per (target, kind) pair, mirroring the original.
            executor = relay.create_executor(kind, ctx=ctx, target=target)
            result = executor.evaluate(gather_fn)(np_data, np_indices)
            tvm.testing.assert_allclose(result.asnumpy(), expected, rtol=1e-5)
def relay_gather(c, data, axis, indices):
    """Lower a gather node to a ``relay.gather`` expression.

    Args:
        c: converter context; its ``ref`` method maps ``data`` and
            ``indices`` to the Relay expressions used as operands.
        data: node holding the tensor to gather from.
        axis: node whose ``value`` becomes the gather axis. It must satisfy
            ``axis.is_constant(int)``; this is enforced with ``assert``.
        indices: node holding the gather indices.

    Returns:
        The Relay expression ``relay.gather(ref(data), axis.value, ref(indices))``.
    """
    assert axis.is_constant(int)
    # Evaluate operands in the same order as the original single expression:
    # data reference, axis value, then indices reference.
    data_expr = c.ref(data)
    gather_axis = axis.value
    indices_expr = c.ref(indices)
    return relay.gather(data_expr, gather_axis, indices_expr)