Ejemplo n.º 1
0
  def test2(self):
    def f(x, y):
      dx = x ** 2 + y ** 2 + 10
      return dx

    _x = bm.ones(5)
    _y = bm.ones(5)

    g = bm.vector_grad(f, argnums=0)(_x, _y)
    pprint(g)
    self.assertTrue(bm.array_equal(g, 2 * _x))

    g = bm.vector_grad(f, argnums=(0,))(_x, _y)
    self.assertTrue(bm.array_equal(g[0], 2 * _x))

    g = bm.vector_grad(f, argnums=(0, 1))(_x, _y)
    pprint(g)
    self.assertTrue(bm.array_equal(g[0], 2 * _x))
    self.assertTrue(bm.array_equal(g[1], 2 * _y))
Ejemplo n.º 2
0
  def test3(self):
    def f(x, y):
      dx = x ** 2 + y ** 2 + 10
      dy = x ** 3 + y ** 3 - 10
      return dx, dy

    _x = bm.ones(5)
    _y = bm.ones(5)

    g = bm.vector_grad(f, argnums=0)(_x, _y)
    # pprint(g)
    self.assertTrue(bm.array_equal(g, 2 * _x + 3 * _x ** 2))

    g = bm.vector_grad(f, argnums=(0,))(_x, _y)
    self.assertTrue(bm.array_equal(g[0], 2 * _x + 3 * _x ** 2))

    g = bm.vector_grad(f, argnums=(0, 1))(_x, _y)
    # pprint(g)
    self.assertTrue(bm.array_equal(g[0], 2 * _x + 3 * _x ** 2))
    self.assertTrue(bm.array_equal(g[1], 2 * _y + 3 * _y ** 2))
Ejemplo n.º 3
0
  def test1(self):
    class Test(bp.Base):
      def __init__(self):
        super(Test, self).__init__()
        self.x = bm.ones(5)
        self.y = bm.ones(5)

      def __call__(self, *args, **kwargs):
        return self.x ** 2 + self.y ** 2 + 10

    t = Test()

    g = bm.vector_grad(t, grad_vars=t.x)()
    self.assertTrue(bm.array_equal(g, 2 * t.x))

    g = bm.vector_grad(t, grad_vars=(t.x,))()
    self.assertTrue(bm.array_equal(g[0], 2 * t.x))

    g = bm.vector_grad(t, grad_vars=(t.x, t.y))()
    self.assertTrue(bm.array_equal(g[0], 2 * t.x))
    self.assertTrue(bm.array_equal(g[1], 2 * t.y))
Ejemplo n.º 4
0
  def test_return1(self):
    def f(x, y):
      dx = x ** 2 + y ** 2 + 10
      return dx

    _x = bm.ones(5)
    _y = bm.ones(5)

    g, value = bm.vector_grad(f, return_value=True)(_x, _y)
    pprint(g, )
    pprint(value)
    self.assertTrue(bm.array_equal(g, 2 * _x))
    self.assertTrue(bm.array_equal(value, _x ** 2 + _y ** 2 + 10))
Ejemplo n.º 5
0
  def test_aux1(self):
    def f(x, y):
      dx = x ** 2 + y ** 2 + 10
      dy = x ** 3 + y ** 3 - 10
      return dx, dy

    _x = bm.ones(5)
    _y = bm.ones(5)

    g, aux = bm.vector_grad(f, has_aux=True)(_x, _y)
    pprint(g, )
    pprint(aux)
    self.assertTrue(bm.array_equal(g, 2 * _x))
    self.assertTrue(bm.array_equal(aux, _x ** 3 + _y ** 3 - 10))
Ejemplo n.º 6
0
  def test_return_aux1(self):
    def f(x, y):
      dx = x ** 2 + y ** 2 + 10
      dy = x ** 3 + y ** 3 - 10
      return dx, dy

    _x = bm.ones(5)
    _y = bm.ones(5)

    g, value, aux = bm.vector_grad(f, has_aux=True, return_value=True)(_x, _y)
    print('grad', g)
    print('value', value)
    print('aux', aux)
    self.assertTrue(bm.array_equal(g, 2 * _x))
    self.assertTrue(bm.array_equal(value, _x ** 2 + _y ** 2 + 10))
    self.assertTrue(bm.array_equal(aux, _x ** 3 + _y ** 3 - 10))
Ejemplo n.º 7
0
    def _build_integrator(self, eq):
        if isinstance(eq, joint_eq.JointEq):
            results = []
            for sub_eq in eq.eqs:
                results.extend(self._build_integrator(sub_eq))
            return results

        else:
            vars, pars, _ = utils.get_args(eq)

            # checking
            if len(vars) != 1:
                raise errors.DiffEqError(
                    f'{self.__class__} only supports numerical integration '
                    f'for one variable once, while we got {vars} in {eq}. '
                    f'Please split your multiple variables into multiple '
                    f'derivative functions.')

            # gradient function
            value_and_grad = math.vector_grad(eq,
                                              argnums=0,
                                              dyn_vars=self.dyn_var,
                                              return_value=True)

            # integration function
            def integral(*args, **kwargs):
                assert len(args) > 0
                dt = kwargs.pop('dt', math.get_dt())
                linear, derivative = value_and_grad(*args, **kwargs)
                phi = math.where(linear == 0., math.ones_like(linear),
                                 (math.exp(dt * linear) - 1) / (dt * linear))
                return args[0] + dt * phi * derivative

            return [
                (integral, vars, pars),
            ]
Ejemplo n.º 8
0
 def F_vmap_dfxdx(self):
   if C.F_vmap_dfxdx not in self.analyzed_results:
     f = bm.jit(bm.vmap(bm.vector_grad(self.F_fx, argnums=0)), device=self.jit_device)
     self.analyzed_results[C.F_vmap_dfxdx] = f
   return self.analyzed_results[C.F_vmap_dfxdx]
Ejemplo n.º 9
0
 def test1(self):
   f = lambda x: 3 * x ** 2
   _x = bm.ones(10)
   pprint(bm.vector_grad(f, argnums=0)(_x))