示例#1
0
def test_tuple_passing():
    """Pass a tuple-typed argument into a compiled Relay function.

    Builds ``fn(x) = x[0]`` where ``x`` is annotated as a 2-tuple of int64
    scalars, registers it in a module as ``fn``, and evaluates it on the LLVM
    CPU target twice: once with a plain Python tuple and once with a
    ``TupleValue``. Each call must return the tuple's first element.
    """
    x = relay.var('x',
                  type_annotation=relay.ty.TupleType([
                      relay.ty.TensorType((), 'int64'),
                      relay.ty.TensorType((), 'int64')
                  ]))

    fn = relay.Function([x], relay.expr.TupleGetItem(x, 0))
    mod = relay.Module({})
    gv = relay.GlobalVar('fn')
    mod[gv] = fn
    mod.entry_func = gv
    mod = relay.transform.InferType()(mod)

    ctx = tvm.cpu()
    target = tvm.target.create('llvm')
    # Named `executor` rather than `exec` to avoid shadowing the builtin.
    executor = relay.create_executor(mod=mod, ctx=ctx, target=target)
    f = executor.evaluate(gv)
    # First use a Python tuple.
    out = f((10, 8))
    tvm.testing.assert_allclose(out.asnumpy(), np.array(10))
    # Second use a tuple value.
    value_tuple = TupleValue(TensorValue(np.array(11)),
                             TensorValue(np.array(12)))
    out = f(value_tuple)
    tvm.testing.assert_allclose(out.asnumpy(), np.array(11))
示例#2
0
def test_function_taking_adt_ref_tuple():
    """Check that the prelude identity function round-trips runtime values.

    Sends an ADT ``nil`` value, an ADT ``cons`` value, a ``RefValue`` and a
    ``TupleValue`` through ``prelude.id`` on the debug interpreter. Each
    result must match its input structurally (tags, field counts) and
    numerically (tensor contents).
    """
    mod = relay.Module()
    prelude = relay.prelude.Prelude(mod)
    intrp = create_executor("debug", mod)

    def _rand_tensor():
        # Fresh random float32 tensor of shape (1, 10); the same shape is
        # used as the cons cell's type argument below.
        return TensorValue(np.random.rand(1, 10).astype('float32'))

    nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil, [])
    cons_value = ConstructorValue(
        prelude.cons.tag,
        [_rand_tensor(), nil_value],
        prelude.cons, [relay.TensorType((1, 10), 'float32')])

    ref_value = RefValue(_rand_tensor())
    num_tuple_fields = 10
    tuple_value = TupleValue(*[_rand_tensor() for _ in range(num_tuple_fields)])

    id_func = intrp.evaluate(prelude.id)

    res_nil = id_func(nil_value)
    assert res_nil.tag == nil_value.tag
    assert len(res_nil.fields) == 0

    res_cons = id_func(cons_value)
    assert res_cons.tag == cons_value.tag
    assert len(res_cons.fields) == len(cons_value.fields)
    tvm.testing.assert_allclose(res_cons.fields[0].asnumpy(),
                                cons_value.fields[0].asnumpy())
    assert isinstance(res_cons.fields[1], ConstructorValue)
    assert res_cons.fields[1].tag == prelude.nil.tag
    assert len(res_cons.fields[1].fields) == 0

    res_ref = id_func(ref_value)
    tvm.testing.assert_allclose(res_ref.value.asnumpy(),
                                ref_value.value.asnumpy())

    res_tuple = id_func(tuple_value)
    # Check arity first so a result with extra or missing fields fails
    # clearly instead of relying on a hard-coded loop bound.
    assert len(res_tuple.fields) == len(tuple_value.fields)
    for got, want in zip(res_tuple.fields, tuple_value.fields):
        tvm.testing.assert_allclose(got.asnumpy(), want.asnumpy())
示例#3
0
def test_tuple_value():
    """Indexing a TupleValue of scalar Values yields each scalar in order."""
    expected = (1, 2, 3)
    tv = TupleValue(*[Value.from_scalar(scalar) for scalar in expected])
    for position, scalar in enumerate(expected):
        np.testing.assert_allclose(tv[position].asnumpy(), scalar)
示例#4
0
def test_tuple_value():
    """Indexing a TupleValue of Relay constants yields each constant in order."""
    scalars = [1, 2, 3]
    tv = TupleValue(*[relay.const(item) for item in scalars])
    for idx, item in enumerate(scalars):
        np.testing.assert_allclose(tv[idx].data.asnumpy(), item)