def verify_full(shape, dtype, fill_value):
    """Check ``topi.full`` and ``topi.full_like`` against ``numpy.full``.

    Parameters
    ----------
    shape : tuple of int
        Shape of the tensors to create.
    dtype : str
        Data type string of the tensors (e.g. "float32").
    fill_value : int or float
        Constant every element of the result must equal.
    """
    A = te.placeholder(shape, dtype=dtype, name="A")
    B = topi.full_like(A, fill_value=fill_value)
    C = topi.full(shape=shape, dtype=dtype, fill_value=fill_value)
    s1 = te.create_schedule([B.op])
    s2 = te.create_schedule([C.op])

    # BUGFIX: the memoized reference must be keyed on (shape, dtype, fill_value).
    # With a zero-argument function the on-disk cache has a constant key, so the
    # reference array computed for the FIRST parameter combination would be
    # returned for every later call of verify_full with different parameters.
    @memoize("topi.tests.test_topi_full")
    def get_ref_data(shape, dtype, fill_value):
        return np.full(shape, fill_value, dtype)

    np_nd = get_ref_data(shape, dtype, fill_value)

    def check_device(device):
        # Skip targets whose runtime was not compiled into this TVM build.
        if not tvm.runtime.enabled(device):
            print("Skip because %s is not enabled" % device)
            return

        ctx = tvm.context(device, 0)
        out = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx)
        f = tvm.build(s1, [A, B], device, name="full_like")
        f(tvm.nd.array(np.zeros(shape, dtype), ctx), out)
        tvm.testing.assert_allclose(out.asnumpy(), np_nd, rtol=1e-5)

        # full takes no input tensor: the kernel writes fill_value straight
        # into `out`, so the result must again match the numpy reference.
        f = tvm.build(s2, [C], device, name="full")
        f(out)
        tvm.testing.assert_allclose(out.asnumpy(), np_nd, rtol=1e-5)

    for device in ["llvm"]:
        check_device(device)
# Beispiel #2
 def check_device(device, ctx):
     """Lower a small int8 conv2d combined with full/greater_equal/where/add
     on *device* to exercise constant-tensor lowering through a real schedule.
     """
     with tvm.target.Target(device):
         print("Running on target: %s" % device)
         # Look up the conv2d-NCHW compute/schedule pair registered for this target.
         fcompute, fschedule = tvm.topi.testing.get_conv2d_nchw_implement(device)
         data = te.placeholder((2, 1, 2, 4), "int8", "data")
         w = te.placeholder((3, 1, 2, 2), "int8", "w")
         conv1 = fcompute(data, w, 1, 0, 1, "int32")
         # Constant tensors matching the conv output shape (2, 3, 1, 3).
         zeros = topi.full((2, 3, 1, 3), "int32", tvm.tir.const(0, dtype="int32"))
         one = topi.full((2, 3, 1, 3), "int32", tvm.tir.const(1, dtype="int32"))
         two = topi.full((2, 3, 1, 3), "int32", tvm.tir.const(2, dtype="int32"))
         gt = topi.greater_equal(conv1, zeros)
         where = topi.where(gt, one, two)
         add = topi.add(conv1, where)
         s = fschedule([add])
         # NOTE(review): `backend` is not defined in this function — presumably
         # bound in the enclosing scope; `ctx` is accepted but unused here.
         tvm.build(s, [data, w, add], target=backend)
# Beispiel #3
 def check_device(device):
     """Lower a small int8 conv2d combined with full/greater_equal/where/add
     on *device*, skipping targets whose runtime context is unavailable.
     """
     with tvm.target.create(device):
         ctx = tvm.context(device, 0)
         # Bail out early when the device runtime is not present in this build.
         if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         # Look up the conv2d-NCHW compute/schedule pair registered for this target.
         fcompute, fschedule = tvm.topi.testing.get_conv2d_nchw_implement(device)
         data = te.placeholder((2, 1, 2, 4), "int8", "data")
         w = te.placeholder((3, 1, 2, 2), "int8", "w")
         conv1 = fcompute(data, w, 1, 0, 1, "int32")
         # Constant tensors matching the conv output shape (2, 3, 1, 3).
         zeros = topi.full((2, 3, 1, 3), "int32", tvm.tir.const(0, dtype="int32"))
         one = topi.full((2, 3, 1, 3), "int32", tvm.tir.const(1, dtype="int32"))
         two = topi.full((2, 3, 1, 3), "int32", tvm.tir.const(2, dtype="int32"))
         gt = topi.greater_equal(conv1, zeros)
         where = topi.where(gt, one, two)
         add = topi.add(conv1, where)
         s = fschedule([add])
         # NOTE(review): `backend` is not defined in this function — presumably
         # bound in the enclosing scope; confirm against the full file.
         tvm.build(s, [data, w, add], target=backend)
# Beispiel #4
def zeros_compute(attrs, inputs, output_type):
    """Compute registration for the ``zeros`` op.

    Emits a single tensor of 0.0 with the inferred output shape and dtype;
    the op takes no tensor inputs.
    """
    assert not inputs
    filled = topi.full(output_type.shape, output_type.dtype, 0.0)
    return [filled]
# Beispiel #5
 def ensure_tensor(tensor):
     """Return *tensor* unchanged unless it is rank-0, in which case a
     one-element int64 tensor filled with 1 is substituted.
     """
     rank = len(tensor.shape)
     if rank:
         return tensor
     # Rank-0 (scalar): promote to shape (1,) so downstream code always
     # sees at least one dimension.
     return topi.full((1,), "int64", 1)
# Beispiel #6
def ones_compute(attrs, inputs, output_type):
    """Compute registration for the ``ones`` op.

    Emits a single tensor of 1.0 with the inferred output shape and dtype.
    Exactly one input is expected; its contents are not read here.
    """
    assert len(inputs) == 1
    filled = topi.full(output_type.shape, output_type.dtype, 1.0)
    return [filled]