예제 #1
0
def test_add_update_params():
    """Check a traced ``F.add_update`` call around an eager call with ``beta``.

    A traced function that calls ``F.add_update`` with no extra keyword
    arguments is invoked once, then ``F.add_update`` is called directly on
    the same buffer with ``beta=0.1``, and finally the traced function is
    invoked again. The last result must be close to the initial values
    plus one.

    NOTE(review): presumably this guards against the eager ``beta=0.1``
    call leaking into the traced function's parameters -- confirm intent.
    """
    shape = (2, 3)
    initial = np.random.random(shape).astype(np.float32)
    target = Buffer(initial)

    @jit.trace
    def traced_update(delta):
        return F.add_update(target, delta)

    # First invocation of the traced function, with an all-zero delta.
    traced_update(np.zeros(shape, dtype=np.float32))

    # Direct call on the same buffer with a non-default beta and an
    # all-zero delta; its return value is intentionally unused.
    zero_delta = Buffer(np.zeros(shape, dtype=np.float32))
    F.add_update(target, zero_delta, beta=0.1)

    # Second invocation of the traced function, now with an all-ones delta.
    outcome = traced_update(np.ones(shape, dtype=np.float32))
    assertTensorClose(outcome, initial + 1)
예제 #2
0
def test_add_update():
    """Check ``F.add_update`` with a scalar delta and with explicit factors.

    First, a random float32 buffer is updated twice with a delta of ``1``;
    the returned value must be close to the initial values plus the number
    of updates applied so far. Second, a call with ``alpha``, ``beta`` and
    ``bias`` supplied must be close to
    ``dest * alpha + delta * beta + bias`` computed with NumPy.
    """
    initial = np.random.random((2, 3)).astype(np.float32)
    buf = Buffer(initial)

    # Each call adds 1 on top of the previous result.
    for applied in (1, 2):
        updated = F.add_update(buf, 1)
        assertTensorClose(updated.numpy(), initial + applied)

    dest_values = np.ones((2, 2), dtype=np.float32)
    delta_values = dest_values * 0.5
    dest = tensor(dest_values)
    delta = tensor(delta_values)
    # alpha is passed as a tensor while beta and bias are plain floats.
    result = F.add_update(dest, delta, alpha=tensor(0.9), beta=0.1, bias=0.1)
    expected = dest_values * 0.9 + delta_values * 0.1 + 0.1
    assertTensorClose(result.numpy(), expected)