예제 #1
0
def test_make_ggnvp():
    """Check make_ggnvp against _make_explicit_ggnvp with the default g.

    Two maps are exercised at one random point and direction: a plain
    matrix product and a tanh applied to that product.
    """
    # Random draws stay in the order matrix, point, direction.
    matrix = npr.randn(5, 4)
    point = npr.randn(4)
    direction = npr.randn(4)

    def linear(z):
        return np.dot(matrix, z)

    def squashed(z):
        return np.tanh(np.dot(matrix, z))

    for candidate in (linear, squashed):
        actual = make_ggnvp(candidate)(point)(direction)
        reference = _make_explicit_ggnvp(candidate)(point)(direction)
        check_equivalent(actual, reference)
예제 #2
0
def test_make_ggnvp():
    """Verify that make_ggnvp and _make_explicit_ggnvp agree (default g).

    The comparison is run first for a linear map and then for the same map
    passed through tanh, using one randomly drawn input and vector.
    """
    weights = npr.randn(5, 4)
    x0 = npr.randn(4)
    vec = npr.randn(4)

    def affine(inp):
        return np.dot(weights, inp)

    got_affine = make_ggnvp(affine)(x0)(vec)
    want_affine = _make_explicit_ggnvp(affine)(x0)(vec)
    check_equivalent(got_affine, want_affine)

    def affine_tanh(inp):
        return np.tanh(np.dot(weights, inp))

    got_tanh = make_ggnvp(affine_tanh)(x0)(vec)
    want_tanh = _make_explicit_ggnvp(affine_tanh)(x0)(vec)
    check_equivalent(got_tanh, want_tanh)
예제 #3
0
def test_make_ggnvp_nondefault_g():
    """Check make_ggnvp against _make_explicit_ggnvp with an explicit g.

    g is sum(2*y**2 + y**4); it is passed as the second argument to both
    constructors. A linear map and its tanh-wrapped variant are compared at
    one random point and direction.
    """
    # Random draws stay in the order matrix, point, direction.
    matrix = npr.randn(5, 4)
    point = npr.randn(4)
    direction = npr.randn(4)

    def outer(y):
        return np.sum(2.*y**2 + y**4)

    def linear(z):
        return np.dot(matrix, z)

    def squashed(z):
        return np.tanh(np.dot(matrix, z))

    for candidate in (linear, squashed):
        actual = make_ggnvp(candidate, outer)(point)(direction)
        reference = _make_explicit_ggnvp(candidate, outer)(point)(direction)
        check_equivalent(actual, reference)
예제 #4
0
def test_make_ggnvp_nondefault_g():
    """Verify make_ggnvp matches _make_explicit_ggnvp for a supplied g.

    Both constructors receive the same g (sum of 2*y**2 + y**4) and are
    evaluated at one random input and vector, first for a linear map and
    then for tanh of that map.
    """
    weights = npr.randn(5, 4)
    x0 = npr.randn(4)
    vec = npr.randn(4)

    def quartic(y):
        return np.sum(2. * y**2 + y**4)

    def affine(inp):
        return np.dot(weights, inp)

    got_affine = make_ggnvp(affine, quartic)(x0)(vec)
    want_affine = _make_explicit_ggnvp(affine, quartic)(x0)(vec)
    check_equivalent(got_affine, want_affine)

    def affine_tanh(inp):
        return np.tanh(np.dot(weights, inp))

    got_tanh = make_ggnvp(affine_tanh, quartic)(x0)(vec)
    want_tanh = _make_explicit_ggnvp(affine_tanh, quartic)(x0)(vec)
    check_equivalent(got_tanh, want_tanh)
예제 #5
0
def test_make_ggnvp_broadcasting():
  """Compare a batched make_ggnvp call with stacked per-row explicit results.

  The explicit reference is evaluated once per (row of x, row of v) pair and
  stacked; make_ggnvp is then called a single time on the full 10x4 arrays.
  """
  weights = npr.randn(4, 5)
  points = npr.randn(10, 4)
  directions = npr.randn(10, 4)

  def row_map(z):
    return np.tanh(np.dot(z, weights))

  # The reference is built first, one row pair at a time.
  per_row = []
  for point, direction in zip(points, directions):
    per_row.append(_make_explicit_ggnvp(row_map)(point)(direction))
  stacked_reference = np.stack(per_row)

  batched = make_ggnvp(row_map)(points)(directions)
  check_equivalent(stacked_reference, batched)
예제 #6
0
def test_make_ggnvp_broadcasting():
    """Check that make_ggnvp on stacked inputs matches row-wise explicit results.

    _make_explicit_ggnvp is applied to each pair of rows from the two 10x4
    random arrays and the outputs are stacked; the stack is compared with a
    single make_ggnvp call on the whole arrays.
    """
    mat = npr.randn(4, 5)
    xs = npr.randn(10, 4)
    vs = npr.randn(10, 4)

    def batched_fun(inp):
        return np.tanh(np.dot(inp, mat))

    row_results = [
        _make_explicit_ggnvp(batched_fun)(x_row)(v_row)
        for x_row, v_row in zip(xs, vs)
    ]
    expected = np.stack(row_results)
    observed = make_ggnvp(batched_fun)(xs)(vs)
    check_equivalent(expected, observed)