Ejemplo n.º 1
0
def verify_take(src_shape, indices_src, axis=None):
    """Compile ``sym.take`` for every test target and compare with ``np.take``.

    Parameters
    ----------
    src_shape : tuple of int
        Shape of the data input ``a``. It is filled with ``0, 1, 2, ...``
        as float32 and reshaped to this shape.
    indices_src : array-like
        Indices to gather; converted to an int32 NumPy array.
    axis : int, optional
        Axis to take along. When None, both ``sym.take`` and ``np.take`` are
        called without an ``axis`` argument.

    Raises
    ------
    AssertionError
        If the compiled result differs from the NumPy reference on any target.
    """
    src_dtype = "float32"
    indices_dtype = "int32"
    indices_src = np.array(indices_src, dtype=indices_dtype)
    a = sym.Variable("a")
    indices = sym.Variable("indices")
    if axis is None:
        out_sym = sym.take(a, indices)
    else:
        out_sym = sym.take(a, indices, axis=axis)

    # The input data and NumPy reference do not depend on the target,
    # so compute them once instead of once per target.
    # int() keeps the element count integral (np.prod(()) is 1.0).
    a_src = np.arange(int(np.prod(src_shape)), dtype=src_dtype).reshape(src_shape)
    if axis is None:
        out_np = np.take(a_src, indices_src)
    else:
        out_np = np.take(a_src, indices_src, axis=axis)

    for target, ctx in ctx_list():
        # set input
        shape_dict = {"a": src_shape, "indices": indices_src.shape}
        type_dict = {"a": src_dtype, "indices": indices_dtype}
        graph, lib, _ = nnvm.compiler.build(out_sym,
                                            target,
                                            shape=shape_dict,
                                            dtype=type_dict)
        m = graph_runtime.create(graph, lib, ctx)
        m.run(a=a_src, indices=indices_src)
        # Store the runtime result under its own name so that `out_sym`
        # is still the graph when the next target is compiled.
        out_tvm = m.get_output(0, tvm.nd.empty(out_np.shape, dtype=src_dtype))
        np.testing.assert_allclose(out_tvm.asnumpy(), out_np, atol=1e-5, rtol=1e-5)
Ejemplo n.º 2
0
def verify_take(src_shape, indices_src, axis=None):
    """Check ``sym.take`` against ``np.take`` through ``check_function``.

    The data input ``a`` is a float32 tensor of shape ``src_shape`` holding
    ``0, 1, 2, ...``. ``indices_src`` is converted to an int32 array. ``axis``
    is forwarded unchanged to both ``sym.take`` and ``np.take``.
    """
    data_dtype = "float32"
    index_dtype = "int32"
    index_values = np.array(indices_src, dtype=index_dtype)

    data_var = sym.Variable("a", shape=src_shape)
    index_var = sym.Variable("indices", shape=index_values.shape)
    taken = sym.take(data_var, index_var, axis=axis)

    # NumPy reference; its argument names match the symbol's input names.
    def numpy_reference(a, indices):
        return np.take(a, indices=indices, axis=axis)

    element_count = np.prod(src_shape)
    data_values = np.arange(element_count, dtype=data_dtype).reshape(src_shape)

    input_dtypes = {'a': data_dtype, 'indices': index_dtype}
    input_values = {'a': data_values, 'indices': index_values}
    check_function(taken, numpy_reference,
                   dtype=input_dtypes,
                   values=input_values)
Ejemplo n.º 3
0
def verify_take(src_shape, indices_src, axis=None):
    """Verify that ``sym.take`` agrees with NumPy's ``np.take``.

    Builds a two-input graph (``a`` of shape ``src_shape``, ``indices`` shaped
    like ``indices_src``). It then hands the graph, a NumPy reference function
    and concrete inputs to ``check_function``. ``a`` is float32 filled with
    ``0, 1, 2, ...``; indices are int32.
    """
    dtypes = {'a': "float32", 'indices': "int32"}
    idx = np.array(indices_src, dtype=dtypes['indices'])

    # Arguments are evaluated left to right, so "a" is declared before "indices".
    graph_out = sym.take(sym.Variable("a", shape=src_shape),
                         sym.Variable("indices", shape=idx.shape),
                         axis=axis)

    def forward(a, indices):
        return np.take(a, indices, axis=axis)

    data = np.arange(np.prod(src_shape), dtype=dtypes['a'])
    data = data.reshape(src_shape)

    check_function(graph_out, forward,
                   dtype=dtypes,
                   values={'a': data, 'indices': idx})