コード例 #1
0
  def test_return1(self):
    """With return_value=True, vector_grad also returns the function value.

    The call yields ``(grad, value)``: the gradient w.r.t. the first argument
    and the un-differentiated output of ``fn``.
    """
    def fn(x, y):
      return x ** 2 + y ** 2 + 10

    _x = bm.ones(5)
    _y = bm.ones(5)

    grad_and_value = bm.vector_grad(fn, return_value=True)
    g, value = grad_and_value(_x, _y)
    pprint(g)
    pprint(value)
    self.assertTrue(bm.array_equal(g, 2 * _x))
    self.assertTrue(bm.array_equal(value, _x ** 2 + _y ** 2 + 10))
コード例 #2
0
  def test_jacrev2(self):
    """bm.jacrev must agree with jax.jacrev for a function with two outputs.

    Two implementations of the function are checked: one building its outputs
    with ``jnp.asarray`` and one with ``bm.asarray``. Each is called both with
    raw JAX arrays (``.value``) and with ``bm`` arrays.
    """
    print()

    def f2_jnp(x, y):
      first = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1]])
      second = jnp.asarray([4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])
      return first, second

    def f2_bm(x, y):
      first = bm.asarray([x[0] * y[0], 5 * x[2] * y[1]])
      second = bm.asarray([4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])
      return first, second

    # Reference Jacobians from plain JAX.
    expected = jax.jacrev(f2_jnp)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
    pprint(expected)

    for fn in (f2_jnp, f2_bm):
      # First pass unwraps to raw JAX arrays, second pass keeps bm arrays.
      for unwrap in (True, False):
        x_arg = bm.array([1., 2., 3.])
        y_arg = bm.array([10., 5.])
        if unwrap:
          x_arg, y_arg = x_arg.value, y_arg.value
        actual = bm.jacrev(fn)(x_arg, y_arg)
        pprint(actual)
        assert bm.array_equal(actual[0], expected[0])
        assert bm.array_equal(actual[1], expected[1])
コード例 #3
0
  def test_aux1(self):
    """With has_aux=True, only the first output is differentiated.

    The call yields ``(grad, aux)``, where ``aux`` is the second output of the
    function, returned unchanged.
    """
    def fn(x, y):
      primary = x ** 2 + y ** 2 + 10
      auxiliary = x ** 3 + y ** 3 - 10
      return primary, auxiliary

    _x = bm.ones(5)
    _y = bm.ones(5)

    grad_fn = bm.vector_grad(fn, has_aux=True)
    g, aux = grad_fn(_x, _y)
    pprint(g)
    pprint(aux)
    expected_aux = _x ** 3 + _y ** 3 - 10
    self.assertTrue(bm.array_equal(g, 2 * _x))
    self.assertTrue(bm.array_equal(aux, expected_aux))
コード例 #4
0
  def test_return_aux1(self):
    """vector_grad with both has_aux=True and return_value=True.

    The call yields ``(grad, value, aux)``. These are the gradient of the first
    output, the first output itself, and the second (auxiliary) output.
    """
    def fn(x, y):
      return x ** 2 + y ** 2 + 10, x ** 3 + y ** 3 - 10

    _x = bm.ones(5)
    _y = bm.ones(5)

    transformed = bm.vector_grad(fn, has_aux=True, return_value=True)
    g, value, aux = transformed(_x, _y)
    for label, item in (('grad', g), ('value', value), ('aux', aux)):
      print(label, item)
    self.assertTrue(bm.array_equal(g, 2 * _x))
    self.assertTrue(bm.array_equal(value, _x ** 2 + _y ** 2 + 10))
    self.assertTrue(bm.array_equal(aux, _x ** 3 + _y ** 3 - 10))
コード例 #5
0
  def test_jacfwd_aux1(self):
    """Check bm.jacfwd on a bp.Base whose __call__ returns ``(output, aux)``.

    Reference Jacobians come from jax.jacfwd applied to the equivalent pure
    function ``f1(x, y)``, where ``x`` plays the role of ``Test.x``.
    """
    def f1(x, y):
      r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])
      return r

    _x = bm.array([1., 2., 3.])
    _y = bm.array([10., 5.])

    class Test(bp.Base):
      def __init__(self):
        super(Test, self).__init__()
        self.x = bm.array([1., 2., 3.])

      def __call__(self, y):
        a = self.x[0] * y[0]
        b = 5 * self.x[2] * y[1]
        c = 4 * self.x[1] ** 2 - 2 * self.x[2]
        d = self.x[2] * jnp.sin(self.x[0])
        r = jnp.asarray([a, b, c, d])
        # Second element is auxiliary data, not part of the differentiated output.
        return r, (c, d)

    # Jacobian w.r.t. the variable only: f1's argument 0 corresponds to Test.x.
    _jr = jax.jacfwd(f1)(_x, _y)
    t = Test()
    # __call__ returns (r, aux), so has_aux=True is required. Without it the
    # whole (r, (c, d)) tuple would be treated as the output to differentiate,
    # and the result could not be compared against _jr.
    br, _ = bm.jacfwd(t, grad_vars=t.x, has_aux=True)(_y)
    self.assertTrue((br == _jr).all())

    # Jacobians w.r.t. both the variable (Test.x) and the call argument (y).
    t = Test()
    _jr_both = jax.jacfwd(f1, argnums=(0, 1))(_x, _y)
    _aux = t(_y)[1]
    (var_grads, arg_grads), aux = bm.jacfwd(t, grad_vars=t.x, argnums=0, has_aux=True)(_y)
    print(var_grads)
    print(arg_grads)
    self.assertTrue((var_grads == _jr_both[0]).all())
    self.assertTrue((arg_grads == _jr_both[1]).all())
    self.assertTrue(bm.array_equal(aux, _aux))
コード例 #6
0
  def test2(self):
    """vector_grad w.r.t. positional arguments selected by ``argnums``.

    ``_x`` and ``_y`` hold different values. A gradient taken w.r.t. the wrong
    argument, or returned in the wrong order, therefore fails the assertions.
    With identical inputs, ``2 * _x == 2 * _y`` and such a bug would pass.
    """
    def f(x, y):
      dx = x ** 2 + y ** 2 + 10
      return dx

    _x = bm.arange(1., 6.)   # [1, 2, 3, 4, 5]
    _y = bm.arange(6., 11.)  # [6, 7, 8, 9, 10]

    # An int argnums yields a single gradient...
    g = bm.vector_grad(f, argnums=0)(_x, _y)
    pprint(g)
    self.assertTrue(bm.array_equal(g, 2 * _x))

    # ...while a tuple argnums yields a sequence of gradients, even of length one.
    g = bm.vector_grad(f, argnums=(0,))(_x, _y)
    self.assertTrue(bm.array_equal(g[0], 2 * _x))

    g = bm.vector_grad(f, argnums=(0, 1))(_x, _y)
    pprint(g)
    self.assertTrue(bm.array_equal(g[0], 2 * _x))
    self.assertTrue(bm.array_equal(g[1], 2 * _y))
コード例 #7
0
 def test_fix_type(self):
   """Run ExampleDS under each runner class, with and without JIT.

   Row ``i`` of the monitored ``o`` is expected to be ``[i + 1, i + 1]``.
   This presumably reflects the constant input of 1. applied to ``o`` at each
   step; confirm against ExampleDS.
   """
   duration = 10.
   dt = 0.1
   # Both values are loop-invariant, so compute them once.
   length = int(duration / dt)
   expected = bm.repeat(bm.arange(length) + 1, 2).reshape((length, 2))

   # Same order as nested loops: jit outer, runner class inner.
   combos = [(use_jit, runner_cls)
             for use_jit in (True, False)
             for runner_cls in (bp.ReportRunner, bp.StructRunner)]
   for use_jit, runner_cls in combos:
     model = ExampleDS()
     runner = runner_cls(model, inputs=('o', 1.), monitors=['o'],
                         dyn_vars=model.vars(), jit=use_jit, dt=dt)
     runner(duration)
     assert bm.array_equal(runner.mon.o, expected)
コード例 #8
0
  def test3(self):
    """vector_grad of a tuple-valued function without ``has_aux``.

    Both outputs contribute to the gradient: the expected values are the sum
    of the derivatives of ``dx`` and ``dy``. ``_x`` and ``_y`` differ, so a
    gradient taken w.r.t. the wrong argument, or returned in the wrong order,
    fails instead of passing by coincidence.
    """
    def f(x, y):
      dx = x ** 2 + y ** 2 + 10
      dy = x ** 3 + y ** 3 - 10
      return dx, dy

    _x = bm.arange(1., 6.)   # [1, 2, 3, 4, 5]
    _y = bm.arange(6., 11.)  # [6, 7, 8, 9, 10]

    g = bm.vector_grad(f, argnums=0)(_x, _y)
    self.assertTrue(bm.array_equal(g, 2 * _x + 3 * _x ** 2))

    g = bm.vector_grad(f, argnums=(0,))(_x, _y)
    self.assertTrue(bm.array_equal(g[0], 2 * _x + 3 * _x ** 2))

    g = bm.vector_grad(f, argnums=(0, 1))(_x, _y)
    self.assertTrue(bm.array_equal(g[0], 2 * _x + 3 * _x ** 2))
    self.assertTrue(bm.array_equal(g[1], 2 * _y + 3 * _y ** 2))
コード例 #9
0
  def test1(self):
    """vector_grad of a bp.Base object w.r.t. the variables in ``grad_vars``.

    ``x`` and ``y`` hold different values, so a gradient taken w.r.t. the
    wrong variable, or returned in the wrong order, fails the assertions.
    With identical values, ``2 * t.x == 2 * t.y`` would hide such a bug.
    """
    class Test(bp.Base):
      def __init__(self):
        super(Test, self).__init__()
        self.x = bm.arange(1., 6.)   # [1, 2, 3, 4, 5]
        self.y = bm.arange(6., 11.)  # [6, 7, 8, 9, 10]

      def __call__(self, *args, **kwargs):
        return self.x ** 2 + self.y ** 2 + 10

    t = Test()

    # A single variable yields a single gradient...
    g = bm.vector_grad(t, grad_vars=t.x)()
    self.assertTrue(bm.array_equal(g, 2 * t.x))

    # ...a tuple of variables yields a sequence of gradients in the same order.
    g = bm.vector_grad(t, grad_vars=(t.x,))()
    self.assertTrue(bm.array_equal(g[0], 2 * t.x))

    g = bm.vector_grad(t, grad_vars=(t.x, t.y))()
    self.assertTrue(bm.array_equal(g[0], 2 * t.x))
    self.assertTrue(bm.array_equal(g[1], 2 * t.y))
コード例 #10
0
ファイル: test_oprators.py プロジェクト: PKU-NIP-Lab/BrainPy
 def test_syn2post_mean(self):
     """syn2post_mean averages ``data`` entries that share a segment id.

     Segments {0, 1}, {2, 3} and {4} give means 0.5, 2.5 and 4. The third
     argument, 3, matches the number of segments in the expected result.
     """
     values = bm.arange(5)
     ids = bm.array([0, 0, 1, 1, 2])
     result = bm.syn2post_mean(values, ids, 3)
     expected = bm.asarray([0.5, 2.5, 4.])
     self.assertTrue(bm.array_equal(result, expected))
コード例 #11
0
ファイル: test_oprators.py プロジェクト: PKU-NIP-Lab/BrainPy
 def test_syn2post_prod(self):
     """syn2post_prod multiplies ``data`` entries that share a segment id.

     Segments {0, 1}, {2, 3} and {4} give products 0, 6 and 4. The third
     argument, 3, matches the number of segments in the expected result.
     """
     values = bm.arange(5)
     ids = bm.array([0, 0, 1, 1, 2])
     result = bm.syn2post_prod(values, ids, 3)
     expected = bm.asarray([0, 6, 4])
     self.assertTrue(bm.array_equal(result, expected))