def test2(self):
    """Check vector_grad of x**2 + y**2 + 10 for int and tuple ``argnums``."""
    def f(x, y):
        return x ** 2 + y ** 2 + 10

    xs = bm.ones(5)
    ys = bm.ones(5)

    # Integer argnums: the result is compared directly against 2 * x.
    grad = bm.vector_grad(f, argnums=0)(xs, ys)
    pprint(grad)
    self.assertTrue(bm.array_equal(grad, 2 * xs))

    # Tuple argnums: the result is indexed per requested argument.
    grads = bm.vector_grad(f, argnums=(0,))(xs, ys)
    self.assertTrue(bm.array_equal(grads[0], 2 * xs))

    grads = bm.vector_grad(f, argnums=(0, 1))(xs, ys)
    pprint(grads)
    for idx, arg in enumerate((xs, ys)):
        self.assertTrue(bm.array_equal(grads[idx], 2 * arg))
def test3(self):
    """Check vector_grad of a function returning two outputs.

    The expected gradient per argument is 2*v + 3*v**2, i.e. the
    derivatives of both returned expressions added together.
    """
    def f(x, y):
        first = x ** 2 + y ** 2 + 10
        second = x ** 3 + y ** 3 - 10
        return first, second

    def expected(v):
        return 2 * v + 3 * v ** 2

    xs = bm.ones(5)
    ys = bm.ones(5)

    grad = bm.vector_grad(f, argnums=0)(xs, ys)
    self.assertTrue(bm.array_equal(grad, expected(xs)))

    grads = bm.vector_grad(f, argnums=(0,))(xs, ys)
    self.assertTrue(bm.array_equal(grads[0], expected(xs)))

    grads = bm.vector_grad(f, argnums=(0, 1))(xs, ys)
    self.assertTrue(bm.array_equal(grads[0], expected(xs)))
    self.assertTrue(bm.array_equal(grads[1], expected(ys)))
def test1(self):
    """Check vector_grad of a callable object w.r.t. attributes in ``grad_vars``."""
    class Test(bp.Base):
        def __init__(self):
            super().__init__()
            self.x = bm.ones(5)
            self.y = bm.ones(5)

        def __call__(self, *args, **kwargs):
            return self.x ** 2 + self.y ** 2 + 10

    model = Test()

    # A single variable (not wrapped in a tuple).
    grad = bm.vector_grad(model, grad_vars=model.x)()
    self.assertTrue(bm.array_equal(grad, 2 * model.x))

    # A one-element tuple: the result is indexed.
    grads = bm.vector_grad(model, grad_vars=(model.x,))()
    self.assertTrue(bm.array_equal(grads[0], 2 * model.x))

    grads = bm.vector_grad(model, grad_vars=(model.x, model.y))()
    for idx, var in enumerate((model.x, model.y)):
        self.assertTrue(bm.array_equal(grads[idx], 2 * var))
def test_return1(self):
    """With ``return_value=True`` vector_grad yields (gradient, value)."""
    def f(x, y):
        return x ** 2 + y ** 2 + 10

    xs = bm.ones(5)
    ys = bm.ones(5)
    grad, value = bm.vector_grad(f, return_value=True)(xs, ys)
    pprint(grad)
    pprint(value)

    self.assertTrue(bm.array_equal(grad, 2 * xs))
    self.assertTrue(bm.array_equal(value, xs ** 2 + ys ** 2 + 10))
def test_aux1(self):
    """With ``has_aux=True`` only the first output is differentiated.

    The second output comes back unchanged as the auxiliary value.
    """
    def f(x, y):
        primary = x ** 2 + y ** 2 + 10
        auxiliary = x ** 3 + y ** 3 - 10
        return primary, auxiliary

    xs = bm.ones(5)
    ys = bm.ones(5)
    grad, aux = bm.vector_grad(f, has_aux=True)(xs, ys)
    pprint(grad)
    pprint(aux)

    self.assertTrue(bm.array_equal(grad, 2 * xs))
    self.assertTrue(bm.array_equal(aux, xs ** 3 + ys ** 3 - 10))
def test_return_aux1(self):
    """Combine ``has_aux=True`` and ``return_value=True``.

    The call should yield (gradient, value, aux).
    """
    def f(x, y):
        primary = x ** 2 + y ** 2 + 10
        auxiliary = x ** 3 + y ** 3 - 10
        return primary, auxiliary

    xs = bm.ones(5)
    ys = bm.ones(5)
    grad, value, aux = bm.vector_grad(f, has_aux=True, return_value=True)(xs, ys)
    for label, item in (('grad', grad), ('value', value), ('aux', aux)):
        print(label, item)

    self.assertTrue(bm.array_equal(grad, 2 * xs))
    self.assertTrue(bm.array_equal(value, xs ** 2 + ys ** 2 + 10))
    self.assertTrue(bm.array_equal(aux, xs ** 3 + ys ** 3 - 10))
def _build_integrator(self, eq):
    """Build exponential-Euler integral functions for ``eq``.

    A ``joint_eq.JointEq`` is flattened recursively, so the returned list
    holds one ``(integral, vars, pars)`` entry per sub-equation. Each
    ``integral(*args, dt=..., **kwargs)`` advances ``args[0]`` by one step::

        x_next = x + dt * phi(dt * J) * f(x),   phi(z) = (e^z - 1) / z

    Here ``J`` is the element-wise derivative of ``eq`` w.r.t. its first
    argument (from ``math.vector_grad``). ``f(x)`` is the value of ``eq``.

    Raises:
      errors.DiffEqError: if a derivative function declares more than one
        variable.
    """
    if isinstance(eq, joint_eq.JointEq):
        results = []
        for sub_eq in eq.eqs:
            results.extend(self._build_integrator(sub_eq))
        return results

    variables, pars, _ = utils.get_args(eq)

    # checking
    if len(variables) != 1:
        raise errors.DiffEqError(
            f'{self.__class__} only supports numerical integration '
            f'for one variable once, while we got {variables} in {eq}. '
            f'Please split your multiple variables into multiple '
            f'derivative functions.')

    # With return_value=True the wrapped function returns (grad, value).
    # grad is the linear coefficient J; value is the derivative f(x) itself.
    grad_and_value = math.vector_grad(eq, argnums=0, dyn_vars=self.dyn_var,
                                      return_value=True)

    # integration function
    def integral(*args, **kwargs):
        # Explicit check instead of ``assert``, which is stripped under -O.
        if not args:
            raise ValueError('integral() requires the state variable as '
                             'its first positional argument.')
        dt = kwargs.pop('dt', math.get_dt())
        linear, derivative = grad_and_value(*args, **kwargs)
        z = dt * linear
        # phi(z) = (e^z - 1) / z with the limit phi(0) = 1.
        # - expm1 avoids cancellation: in float32, exp(z) - 1 rounds to 0 for
        #   tiny |z|, which would zero the update.
        # - Testing z (not just `linear`) also covers dt == 0 and underflow.
        # - safe_z keeps the discarded branch free of 0/0, so no NaN leaks
        #   into values or gradients.
        is_zero = z == 0.
        safe_z = math.where(is_zero, math.ones_like(z), z)
        phi = math.where(is_zero, math.ones_like(z),
                         math.expm1(safe_z) / safe_z)
        return args[0] + dt * phi * derivative

    return [
        (integral, variables, pars),
    ]
def F_vmap_dfxdx(self):
    """Return the cached jit-compiled, vmapped derivative of ``F_fx``.

    The derivative is taken w.r.t. ``F_fx``'s first argument. It is built
    on first access and stored in ``self.analyzed_results`` under
    ``C.F_vmap_dfxdx``.
    """
    cache = self.analyzed_results
    if C.F_vmap_dfxdx in cache:
        return cache[C.F_vmap_dfxdx]
    batched_dfxdx = bm.vmap(bm.vector_grad(self.F_fx, argnums=0))
    cache[C.F_vmap_dfxdx] = bm.jit(batched_dfxdx, device=self.jit_device)
    return cache[C.F_vmap_dfxdx]
def test1(self):
    """vector_grad of 3 * x ** 2 w.r.t. x must equal 6 * x element-wise.

    The original version only printed the gradient and could never fail.
    """
    def f(x):
        return 3 * x ** 2

    _x = bm.ones(10)
    g = bm.vector_grad(f, argnums=0)(_x)
    pprint(g)
    self.assertTrue(bm.array_equal(g, 6 * _x))