Ejemplo n.º 1
0
    def testHessian(self):
        # test based on code from sindhwani@google
        def fun(x, t):
            return np.sum(np.power(np.maximum(x, 0.0), 2)) + t

        x = onp.array([-1., -0.5, 0., 0.5, 1.0])

        ans = hessian(lambda x: fun(x, 0.0))(x)
        expected = onp.array([[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.],
                              [0., 0., 0.5, 0., 0.], [0., 0., 0., 2., 0.],
                              [0., 0., 0., 0., 2.]])
        self.assertAllClose(ans, expected, check_dtypes=False)
Ejemplo n.º 2
0
  def testIssue387(self):
    # https://github.com/google/jax/issues/387
    R = np.random.RandomState(0).rand(100, 2)

    def dist_sq(R):
      dR = R[:, jnp.newaxis, :] - R[jnp.newaxis, :, :]
      zero = jnp.zeros_like(dR)
      dR = dR - jnp.where(jnp.abs(dR) < 0.5, zero, 0.5 * jnp.sign(dR))
      return jnp.sum(dR ** 2, axis=2)

    @jit
    def f(R):
      _ = dist_sq(R)
      return jnp.sum(R ** 2)

    _ = hessian(f)(R)  # don't crash on UnshapedArray