コード例 #1
0
def test_function_taking_adt_ref_tuple():
    """Pass ADT, reference, and tuple values through the prelude identity.

    ``prelude.id`` is evaluated with the "debug" executor and applied to an
    empty-list constructor, a one-element list constructor, a reference
    value, and a ten-field tuple. Each result is compared with its input.
    """
    num_tuple_fields = 10

    def random_tensor():
        # A new (1, 10) float32 array wrapped with nd.array on every call.
        return nd.array(np.random.rand(1, 10).astype("float32"))

    module = tvm.IRModule()
    prelude = relay.prelude.Prelude(module)
    executor = create_executor("debug", module)
    # get_type("List") unpacks into three items; the first is unused here.
    _, cons, nil = prelude.mod.get_type("List")

    # Inputs are built in a fixed order: nil, cons, reference, tuple.
    empty_list = ConstructorValue(nil.tag, [], nil)
    single_list = ConstructorValue(cons.tag, [random_tensor(), empty_list], cons)
    boxed = RefValue(random_tensor())
    packed = container.tuple_object(
        [random_tensor() for _ in range(num_tuple_fields)]
    )

    identity = executor.evaluate(prelude.id)

    # Nil constructor: same tag, no fields.
    nil_out = identity(empty_list)
    assert nil_out.tag == empty_list.tag
    assert len(nil_out.fields) == 0

    # Cons constructor: same tag and arity, matching head, nil tail.
    cons_out = identity(single_list)
    assert cons_out.tag == single_list.tag
    assert len(cons_out.fields) == len(single_list.fields)
    tvm.testing.assert_allclose(
        cons_out.fields[0].asnumpy(), single_list.fields[0].asnumpy()
    )
    tail = cons_out.fields[1]
    assert isinstance(tail, ConstructorValue)
    assert tail.tag == nil.tag
    assert len(tail.fields) == 0

    # Reference: the wrapped tensor must match.
    ref_out = identity(boxed)
    tvm.testing.assert_allclose(ref_out.value.asnumpy(), boxed.value.asnumpy())

    # Tuple: compare every field by index.
    tuple_out = identity(packed)
    for idx in range(num_tuple_fields):
        tvm.testing.assert_allclose(
            tuple_out[idx].asnumpy(), packed[idx].asnumpy()
        )
コード例 #2
0
def test_function_taking_adt_ref_tuple():
    """Pass ADT, reference, and tuple values through the prelude identity.

    ``prelude.id`` is evaluated with the "debug" executor and applied to a
    nil constructor, a cons constructor holding one tensor, a reference
    value, and a ten-field tuple. Each result is compared with its input.
    """
    num_tuple_fields = 10

    def random_tensor():
        # A new (1, 10) float32 array wrapped in a TensorValue on every call.
        return TensorValue(np.random.rand(1, 10).astype('float32'))

    module = relay.Module()
    prelude = relay.prelude.Prelude(module)
    executor = create_executor("debug", module)

    # Inputs are built in a fixed order: nil, cons, reference, tuple.
    empty_list = ConstructorValue(prelude.nil.tag, [], prelude.nil, [])
    single_list = ConstructorValue(
        prelude.cons.tag,
        [random_tensor(), empty_list],
        prelude.cons,
        [relay.TensorType((1, 10), 'float32')],
    )
    boxed = RefValue(random_tensor())
    tuple_fields = [random_tensor() for _ in range(num_tuple_fields)]
    packed = TupleValue(*tuple_fields)

    identity = executor.evaluate(prelude.id)

    # Nil constructor: same tag, no fields.
    nil_out = identity(empty_list)
    assert nil_out.tag == empty_list.tag
    assert len(nil_out.fields) == 0

    # Cons constructor: same tag and arity, matching head, nil tail.
    cons_out = identity(single_list)
    assert cons_out.tag == single_list.tag
    assert len(cons_out.fields) == len(single_list.fields)
    tvm.testing.assert_allclose(
        cons_out.fields[0].asnumpy(), single_list.fields[0].asnumpy()
    )
    tail = cons_out.fields[1]
    assert isinstance(tail, ConstructorValue)
    assert tail.tag == prelude.nil.tag
    assert len(tail.fields) == 0

    # Reference: the wrapped tensor must match.
    ref_out = identity(boxed)
    tvm.testing.assert_allclose(ref_out.value.asnumpy(), boxed.value.asnumpy())

    # Tuple: compare every field by index.
    tuple_out = identity(packed)
    for idx in range(num_tuple_fields):
        tvm.testing.assert_allclose(
            tuple_out.fields[idx].asnumpy(), packed.fields[idx].asnumpy()
        )