def test_return1(self):
    """Exercise ``bm.vector_grad`` with ``return_value=True``.

    The transformed function is called on two vectors of ones and must
    yield a gradient equal to ``2 * x`` together with the value of the
    function itself.
    """

    def f(x, y):
        return x ** 2 + y ** 2 + 10

    xs = bm.ones(5)
    ys = bm.ones(5)

    grad_fn = bm.vector_grad(f, return_value=True)
    grad, value = grad_fn(xs, ys)
    pprint(grad)
    pprint(value)

    self.assertTrue(bm.array_equal(grad, 2 * xs))
    self.assertTrue(bm.array_equal(value, xs ** 2 + ys ** 2 + 10))
def test_jacrev2(self):
    """Compare ``bm.jacrev`` with ``jax.jacrev`` on a two-output function.

    The reference Jacobian is computed once with plain JAX.  ``bm.jacrev``
    is then checked against it four times: the function builds its outputs
    first with ``jnp.asarray`` and then with ``bm.asarray``, and each
    variant is called with both the ``.value`` of ``bm`` arrays and the
    ``bm`` arrays themselves.
    """
    print()

    def f2(x, y):
        first = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1]])
        second = jnp.asarray([4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])
        return first, second

    # Reference result from plain JAX.
    reference = jax.jacrev(f2)(jnp.array([1., 2., 3.]), jnp.array([10., 5.]))
    pprint(reference)

    def check(result):
        # Print the result, then compare both outputs with the reference.
        pprint(result)
        assert bm.array_equal(result[0], reference[0])
        assert bm.array_equal(result[1], reference[1])

    check(bm.jacrev(f2)(bm.array([1., 2., 3.]).value,
                        bm.array([10., 5.]).value))
    check(bm.jacrev(f2)(bm.array([1., 2., 3.]), bm.array([10., 5.])))

    # Same function, but the outputs are now assembled with ``bm.asarray``.
    def f2(x, y):
        first = bm.asarray([x[0] * y[0], 5 * x[2] * y[1]])
        second = bm.asarray([4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])
        return first, second

    check(bm.jacrev(f2)(bm.array([1., 2., 3.]).value,
                        bm.array([10., 5.]).value))
    check(bm.jacrev(f2)(bm.array([1., 2., 3.]), bm.array([10., 5.])))
def test_aux1(self):
    """Exercise ``bm.vector_grad`` with ``has_aux=True``.

    The function returns a pair; the test asserts that the gradient equals
    ``2 * x`` and that the second element of the pair comes back as the
    auxiliary output.
    """

    def f(x, y):
        primary = x ** 2 + y ** 2 + 10
        extra = x ** 3 + y ** 3 - 10
        return primary, extra

    xs = bm.ones(5)
    ys = bm.ones(5)

    grad_fn = bm.vector_grad(f, has_aux=True)
    grad, aux = grad_fn(xs, ys)
    pprint(grad)
    pprint(aux)

    self.assertTrue(bm.array_equal(grad, 2 * xs))
    self.assertTrue(bm.array_equal(aux, xs ** 3 + ys ** 3 - 10))
def test_return_aux1(self):
    """Exercise ``bm.vector_grad`` with ``has_aux=True`` and ``return_value=True``.

    The call must produce a ``(grad, value, aux)`` triple; each element is
    printed and then compared with its closed-form expectation.
    """

    def f(x, y):
        primary = x ** 2 + y ** 2 + 10
        extra = x ** 3 + y ** 3 - 10
        return primary, extra

    xs = bm.ones(5)
    ys = bm.ones(5)

    grad_fn = bm.vector_grad(f, has_aux=True, return_value=True)
    grad, value, aux = grad_fn(xs, ys)
    print('grad', grad)
    print('value', value)
    print('aux', aux)

    self.assertTrue(bm.array_equal(grad, 2 * xs))
    self.assertTrue(bm.array_equal(value, xs ** 2 + ys ** 2 + 10))
    self.assertTrue(bm.array_equal(aux, xs ** 3 + ys ** 3 - 10))
def test_jacfwd_aux1(self):
    """Compare ``bm.jacfwd`` on a ``bp.Base`` object with ``jax.jacfwd``.

    A pure function ``f1(x, y)`` provides the JAX reference.  The ``Test``
    object holds ``x`` as an attribute and takes ``y`` as its argument,
    returning the same four-element result plus a ``(c, d)`` tuple.  The
    test first checks the Jacobian with respect to ``grad_vars`` alone,
    then checks variable gradients, argument gradients and the auxiliary
    output when ``argnums=0`` and ``has_aux=True`` are given.
    """

    def f1(x, y):
        return jnp.asarray([x[0] * y[0],
                            5 * x[2] * y[1],
                            4 * x[1] ** 2 - 2 * x[2],
                            x[2] * jnp.sin(x[0])])

    x_in = bm.array([1., 2., 3.])
    y_in = bm.array([10., 5.])

    class Test(bp.Base):
        def __init__(self):
            super().__init__()
            self.x = bm.array([1., 2., 3.])

        def __call__(self, y):
            a = self.x[0] * y[0]
            b = 5 * self.x[2] * y[1]
            c = 4 * self.x[1] ** 2 - 2 * self.x[2]
            d = self.x[2] * jnp.sin(self.x[0])
            stacked = jnp.asarray([a, b, c, d])
            return stacked, (c, d)

    # Jacobian with respect to the object's variable only.
    reference = jax.jacfwd(f1)(x_in, y_in)
    t = Test()
    result = bm.jacfwd(t, grad_vars=t.x)(y_in)
    self.assertTrue((result == reference).all())

    # Jacobians with respect to both the variable and the argument, plus aux.
    t = Test()
    reference = jax.jacfwd(f1, argnums=(0, 1))(x_in, y_in)
    expected_aux = t(y_in)[1]
    jac_fn = bm.jacfwd(t, grad_vars=t.x, argnums=0, has_aux=True)
    (var_grads, arg_grads), aux = jac_fn(y_in)
    print(var_grads)
    print(arg_grads)
    self.assertTrue((var_grads == reference[0]).all())
    self.assertTrue((arg_grads == reference[1]).all())
    self.assertTrue(bm.array_equal(aux, expected_aux))
def test2(self):
    """Exercise the ``argnums`` forms accepted by ``bm.vector_grad``.

    Three forms are checked: an integer ``argnums`` (result compared
    directly), a one-element tuple (result indexed at ``0``), and a
    two-element tuple (one gradient per argument).
    """

    def f(x, y):
        return x ** 2 + y ** 2 + 10

    xs = bm.ones(5)
    ys = bm.ones(5)

    # Integer argnums.
    grad_single = bm.vector_grad(f, argnums=0)(xs, ys)
    pprint(grad_single)
    self.assertTrue(bm.array_equal(grad_single, 2 * xs))

    # One-element tuple argnums.
    grad_tuple = bm.vector_grad(f, argnums=(0,))(xs, ys)
    self.assertTrue(bm.array_equal(grad_tuple[0], 2 * xs))

    # Two-element tuple argnums.
    grad_pair = bm.vector_grad(f, argnums=(0, 1))(xs, ys)
    pprint(grad_pair)
    self.assertTrue(bm.array_equal(grad_pair[0], 2 * xs))
    self.assertTrue(bm.array_equal(grad_pair[1], 2 * ys))
def test_fix_type(self):
    """Run ``ExampleDS`` under each runner class, with and without JIT.

    For every combination of ``jit`` flag and runner class, a fresh
    ``ExampleDS`` is run for ``duration`` with input ``('o', 1.)`` and the
    monitored ``o`` is compared with the rows ``[k, k]`` for
    ``k = 1 .. length``.
    """
    duration = 10.
    dt = 0.1
    length = int(duration / dt)  # number of monitored steps
    runner_classes = (bp.ReportRunner, bp.StructRunner)

    for jit in (True, False):
        for run_method in runner_classes:
            ds = ExampleDS()
            runner = run_method(ds,
                                inputs=('o', 1.),
                                monitors=['o'],
                                dyn_vars=ds.vars(),
                                jit=jit,
                                dt=dt)
            runner(duration)
            expected = bm.repeat(bm.arange(length) + 1, 2).reshape((length, 2))
            assert bm.array_equal(runner.mon.o, expected)
def test3(self):
    """Exercise ``bm.vector_grad`` on a function returning two outputs.

    No ``has_aux`` is given, and the asserted gradient is
    ``2 * v + 3 * v ** 2`` for each differentiated argument ``v``.  The
    check is repeated for an integer ``argnums``, a one-element tuple and a
    two-element tuple.
    """

    def f(x, y):
        first = x ** 2 + y ** 2 + 10
        second = x ** 3 + y ** 3 - 10
        return first, second

    xs = bm.ones(5)
    ys = bm.ones(5)

    # Integer argnums.
    grad_single = bm.vector_grad(f, argnums=0)(xs, ys)
    self.assertTrue(bm.array_equal(grad_single, 2 * xs + 3 * xs ** 2))

    # One-element tuple argnums.
    grad_tuple = bm.vector_grad(f, argnums=(0,))(xs, ys)
    self.assertTrue(bm.array_equal(grad_tuple[0], 2 * xs + 3 * xs ** 2))

    # Two-element tuple argnums.
    grad_pair = bm.vector_grad(f, argnums=(0, 1))(xs, ys)
    self.assertTrue(bm.array_equal(grad_pair[0], 2 * xs + 3 * xs ** 2))
    self.assertTrue(bm.array_equal(grad_pair[1], 2 * ys + 3 * ys ** 2))
def test1(self):
    """Exercise the ``grad_vars`` forms accepted by ``bm.vector_grad``.

    The target is a ``bp.Base`` object whose call takes no required
    arguments and combines its ``x`` and ``y`` attributes.  ``grad_vars``
    is given as a single variable, a one-element tuple and a two-element
    tuple.
    """

    class Test(bp.Base):
        def __init__(self):
            super().__init__()
            self.x = bm.ones(5)
            self.y = bm.ones(5)

        def __call__(self, *args, **kwargs):
            return self.x ** 2 + self.y ** 2 + 10

    t = Test()

    # Single variable.
    grad_single = bm.vector_grad(t, grad_vars=t.x)()
    self.assertTrue(bm.array_equal(grad_single, 2 * t.x))

    # One-element tuple of variables.
    grad_tuple = bm.vector_grad(t, grad_vars=(t.x,))()
    self.assertTrue(bm.array_equal(grad_tuple[0], 2 * t.x))

    # Two-element tuple of variables.
    grad_pair = bm.vector_grad(t, grad_vars=(t.x, t.y))()
    self.assertTrue(bm.array_equal(grad_pair[0], 2 * t.x))
    self.assertTrue(bm.array_equal(grad_pair[1], 2 * t.y))
def test_syn2post_mean(self):
    """Check ``bm.syn2post_mean`` on five values spread over three segments.

    The values ``0..4`` with segment ids ``[0, 0, 1, 1, 2]`` are expected
    to give ``[0.5, 2.5, 4.]``.
    """
    values = bm.arange(5)
    ids = bm.array([0, 0, 1, 1, 2])

    result = bm.syn2post_mean(values, ids, 3)
    expected = bm.asarray([0.5, 2.5, 4.])
    self.assertTrue(bm.array_equal(result, expected))
def test_syn2post_prod(self):
    """Check ``bm.syn2post_prod`` on five values spread over three segments.

    The values ``0..4`` with segment ids ``[0, 0, 1, 1, 2]`` are expected
    to give ``[0, 6, 4]``.
    """
    values = bm.arange(5)
    ids = bm.array([0, 0, 1, 1, 2])

    result = bm.syn2post_prod(values, ids, 3)
    expected = bm.asarray([0, 6, 4])
    self.assertTrue(bm.array_equal(result, expected))