コード例 #1
0
def compute_argsort(attrs, inputs, _, target):
    """Build the argsort computation for the first input.

    Reads ``axis``, ``is_ascend`` and ``dtype`` from ``attrs`` and forwards
    them as keyword arguments to ``topi.argsort`` applied to ``inputs[0]``.
    The third positional argument and ``target`` are not used here.

    Returns a single-element list holding the ``topi.argsort`` result.
    """
    # Resolve the attribute values first, in the same order as they are
    # declared: axis, sort direction, then output dtype.
    options = {
        "axis": get_const_int(attrs.axis),
        "is_ascend": bool(get_const_int(attrs.is_ascend)),
        "dtype": attrs.dtype,
    }
    return [topi.argsort(inputs[0], **options)]
コード例 #2
0
ファイル: _algorithm.py プロジェクト: bddppq/tvm
def compute_argsort(attrs, inputs, _, target):
    """Build the argsort computation for the first input.

    ``axis`` and ``is_ascend`` are read from ``attrs`` as constant integers
    (the latter converted to ``bool``), and ``dtype`` is converted to ``str``.
    ``topi.argsort`` is called with ``inputs[0]``, ``None`` as the second
    positional argument, and ``flag=False``. The third positional argument
    and ``target`` are not used here.

    Returns a single-element list holding the ``topi.argsort`` result.
    """
    sort_axis = get_const_int(attrs.axis)
    ascending = bool(get_const_int(attrs.is_ascend))
    out_dtype = str(attrs.dtype)
    sorted_indices = topi.argsort(
        inputs[0],
        None,
        axis=sort_axis,
        is_ascend=ascending,
        dtype=out_dtype,
        flag=False,
    )
    return [sorted_indices]
コード例 #3
0
    def check_device(device):
        """Build and run argsort on ``device`` and compare with ``np_indices``.

        Prints a notice and returns early when the context for ``device``
        does not exist. Otherwise the compute and schedule are created under
        the target scope, the function is built and executed, and the output
        is checked against ``np_indices`` cast to ``data_dtype``.
        """
        dev_ctx = tvm.context(device, 0)
        if not dev_ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)

        # Declare the compute and obtain its schedule inside the target scope.
        with tvm.target.create(device):
            sorted_idx = topi.argsort(data, axis=axis, is_ascend=is_ascend)
            sched = topi.generic.schedule_argsort(sorted_idx)

        # Input buffer plus a zero-initialised output buffer on the device.
        dev_input = tvm.nd.array(np_data, dev_ctx)
        dev_output = tvm.nd.array(np.zeros(dshape, dtype=data_dtype), dev_ctx)

        func = tvm.build(sched, [data, sorted_idx], device)
        func(dev_input, dev_output)

        # NOTE(review): rtol=1e0 is a very loose tolerance for comparing
        # indices — confirm it is intentional.
        tvm.testing.assert_allclose(
            dev_output.asnumpy(), np_indices.astype(data_dtype), rtol=1e0
        )
コード例 #4
0
    def check_device(device):
        """Build and run argsort with a valid-count input on ``device``.

        Prints a notice and returns early when the context for ``device``
        does not exist. Otherwise the compute and schedule are created under
        the target scope, the function is built with ``data``,
        ``valid_count`` and the output as arguments, executed, and the
        output is checked against ``np_result`` cast to ``"float32"``.
        """
        dev_ctx = tvm.context(device, 0)
        if not dev_ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)

        # Declare the compute and obtain its schedule inside the target scope.
        with tvm.target.create(device):
            sorted_idx = argsort(
                data, valid_count, axis=-1, is_ascend=False, flag=False
            )
            sched = topi.generic.schedule_argsort(sorted_idx)

        # Device-side inputs plus a zero-initialised float32 output buffer.
        dev_input = tvm.nd.array(np_data, dev_ctx)
        dev_valid_count = tvm.nd.array(np_valid_count, dev_ctx)
        dev_output = tvm.nd.array(np.zeros(dshape, dtype="float32"), dev_ctx)

        func = tvm.build(sched, [data, valid_count, sorted_idx], device)
        func(dev_input, dev_valid_count, dev_output)

        # NOTE(review): rtol=1e0 is a very loose tolerance for comparing
        # indices — confirm it is intentional.
        tvm.testing.assert_allclose(
            dev_output.asnumpy(), np_result.astype("float32"), rtol=1e0
        )