def check_scatter_nd(data, indices, shape, out):
    """Verify the shape-based scatter_nd kernels against the expected array ``out``.

    One compute is registered per target family ("generic", "gpu", "cpu").
    Each calls the matching ``scatter_nd(x, y, shape)`` and is paired with
    ``topi.generic.schedule_extern``. ``tvm.topi.testing.dispatch`` selects
    the entry for ``target``. The selected pair is then handed to
    ``tvm.topi.testing.compare_numpy_tvm`` together with ``[data, indices]``,
    ``out``, ``target`` and ``ctx``.

    NOTE(review): ``target`` and ``ctx`` are not parameters. They are free
    names resolved from the enclosing scope (presumably the surrounding
    test); confirm.
    """
    # The same extern schedule serves every target family.
    schedule = topi.generic.schedule_extern
    implementations = {
        "generic": (lambda x, y: topi.scatter_nd(x, y, shape), schedule),
        "gpu": (lambda x, y: topi.cuda.scatter_nd(x, y, shape), schedule),
        "cpu": (lambda x, y: topi.x86.scatter_nd(x, y, shape), schedule),
    }
    compute_fn, schedule_fn = tvm.topi.testing.dispatch(target, implementations)
    inputs = [data, indices]
    tvm.topi.testing.compare_numpy_tvm(
        inputs, out, target, ctx, compute_fn, schedule_fn
    )
def check_scatter_nd(data, indices, updates, out, mode="add"):
    """Verify the scatter_nd kernels (with ``updates`` and ``mode``) against ``out``.

    One compute is registered per target family ("generic", "gpu", "cpu").
    Each calls the matching ``scatter_nd(x, y, z, mode)`` and is paired with
    ``topi.generic.schedule_extern``. ``tvm.topi.testing.dispatch`` selects
    the entry for ``target``. The selected pair is then handed to
    ``tvm.topi.testing.compare_numpy_tvm`` together with
    ``[data, indices, updates]``, ``out``, ``target`` and ``dev``.

    NOTE(review): ``target`` and ``dev`` are not parameters. They are free
    names resolved from the enclosing scope (presumably the surrounding
    test); confirm.
    """
    # The same extern schedule serves every target family.
    schedule = topi.generic.schedule_extern
    implementations = {
        "generic": (lambda x, y, z: topi.scatter_nd(x, y, z, mode), schedule),
        "gpu": (lambda x, y, z: topi.cuda.scatter_nd(x, y, z, mode), schedule),
        "cpu": (lambda x, y, z: topi.x86.scatter_nd(x, y, z, mode), schedule),
    }
    compute_fn, schedule_fn = tvm.topi.testing.dispatch(target, implementations)
    inputs = [data, indices, updates]
    tvm.topi.testing.compare_numpy_tvm(
        inputs, out, target, dev, compute_fn, schedule_fn
    )
def compute_scatter_nd(attrs, inputs, output_type):
    """Relay compute for scatter_nd.

    Forwards the first two entries of ``inputs`` and ``attrs.out_shape`` to
    ``topi.scatter_nd``. The resulting tensor is returned as a one-element
    list. ``output_type`` is accepted for the compute-function signature but
    is not used here.
    """
    data, indices = inputs[0], inputs[1]
    result = topi.scatter_nd(data, indices, attrs.out_shape)
    return [result]