示例#1
0
def test_dump_volatile():
    """Check that a tensor closed over by a traced function is dumped as a constant.

    ``f`` multiplies its input by the closure tensor ``p`` and is traced with
    ``symbolic=True, capture_as_const=True``. The test first verifies the traced
    result and its stability across repeated calls. It then dumps the graph to an
    in-memory buffer (``optimize_for_inference=False``) and reloads it. Finally it
    asserts that the operator producing the second input of the output's owner
    operator (presumably the captured ``p``) is an ``ImmutableTensor``.
    """
    p = tensor([2])

    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        return x * p

    x = tensor([3])
    y = f(x).numpy()
    # Pin the first traced result to the known product (3 * 2). Without this,
    # the consistency loop below would happily compare wrong results with
    # each other.
    np.testing.assert_equal(y, [6])

    # Repeated traced calls must reproduce the first result.
    for _ in range(3):
        np.testing.assert_equal(f(x).numpy(), y)

    file = io.BytesIO()
    f.dump(file, optimize_for_inference=False)
    file.seek(0)
    # Only the output vars are inspected; the loaded graph handle is unused.
    _, _, outputs = G.load_graph(file)
    (out,) = outputs
    producer_type = cgtools.get_owner_opr_type(
        cgtools.get_owner_opr_inputs(out)[1])
    assert producer_type == "ImmutableTensor", producer_type
示例#2
0
def test_dump_volatile():
    """Check that a raw tensor closed over by a traced function is dumped as a constant.

    ``f`` applies an ``Elemwise`` ``MUL`` op to its input and the closure tensor
    ``p``. It is traced with ``symbolic=True, capture_as_const=True``. The
    reference output comes from the undecorated function (``f.__wrapped__``) and
    must equal every traced call. The graph is then dumped to an in-memory buffer
    (``optimize_for_inference=False``) and reloaded. The test asserts that the
    operator producing the second input of the output's owner operator
    (presumably the captured ``p``) is an ``ImmutableTensor``.
    """
    p = as_raw_tensor([2])

    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        op = ops.Elemwise(Elemwise.Mode.MUL)
        (y, ) = apply(op, x, p)
        return y

    x = as_raw_tensor([3]).numpy()
    # Reference result from the wrapped (undecorated) function.
    y = f.__wrapped__(as_raw_tensor(x)).numpy()
    # Pin the reference to the known product (3 * 2) so the comparison
    # below is not merely self-consistent.
    np.testing.assert_equal(y, [6])

    # Repeated traced calls must match the reference result.
    for _ in range(3):
        np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)

    file = io.BytesIO()
    f.dump(file, optimize_for_inference=False)
    file.seek(0)
    # Only the output vars are inspected; the loaded graph handle is unused.
    _, _, outputs = G.load_graph(file)
    (out,) = outputs
    producer_type = cgtools.get_owner_opr_type(
        cgtools.get_owner_opr_inputs(out)[1])
    assert producer_type == "ImmutableTensor", producer_type