コード例 #1
0
def check_grads(f, args, order, atol=None, rtol=None, eps=None):
  """Check forward- and reverse-mode derivatives of ``f`` at ``args``.

  Runs ``jtu.check_jvp`` with ``partial(api.jvp, f)`` and ``jtu.check_vjp``
  with ``partial(api.vjp, f)``.

  Args:
    f: function to differentiate.
    args: tuple of positional arguments at which derivatives are checked.
    order: requested derivative order. Currently unused: only first-order
      checks are performed (see TODO below).
    atol: absolute tolerance. ``None`` selects a default that depends on
      ``FLAGS.jax_enable_x64``; an explicit ``0`` is honored.
    rtol: relative tolerance, with the same defaulting rule as ``atol``.
    eps: step size forwarded to ``check_jvp``/``check_vjp`` (presumably the
      finite-difference step). ``None`` or ``0`` selects the default.
  """
  # TODO(mattjj,dougalm): add higher-order check
  # Looser default tolerance when 64-bit mode is disabled.
  default_tol = 1e-6 if FLAGS.jax_enable_x64 else 1e-2
  # Compare with None rather than relying on truthiness, so an explicit
  # tolerance of 0 is not silently replaced by the default.
  atol = default_tol if atol is None else atol
  rtol = default_tol if rtol is None else rtol
  # A zero step would be unusable, so 0 also falls back to the default here.
  eps = eps or default_tol
  jtu.check_jvp(f, partial(api.jvp, f), args, atol, rtol, eps)
  jtu.check_vjp(f, partial(api.vjp, f), args, atol, rtol, eps)
コード例 #2
0
    def testQr(self, shape, dtype, full_matrices, rng_factory):
        """Checks np.linalg.qr against onp.linalg.qr, looping over batch dims.

        Verifies that Q @ R reconstructs the input, that the leading ``k``
        columns of Q agree with NumPy's up to a per-column unit-modulus factor,
        that Q has orthonormal columns, and (reduced mode, m >= n) the JVP.
        """
        rng = rng_factory()
        _skip_if_unsupported_type(dtype)
        if (np.issubdtype(dtype, onp.complexfloating)
                and (jtu.device_under_test() == "tpu" or jax.lib.version <=
                     (0, 1, 27))):
            raise unittest.SkipTest("No complex QR implementation")
        m, n = shape[-2:]

        if full_matrices:
            mode, k = "complete", m
        else:
            mode, k = "reduced", min(m, n)

        a = rng(shape, dtype)
        lq, lr = np.linalg.qr(a, mode=mode)

        # onp.linalg.qr doesn't support batch dimensions. But it seems like an
        # inevitable extension so we support it in our version. Only the
        # reference Q factor is compared below, so NumPy's R is discarded.
        nq = onp.zeros(shape[:-2] + (m, k), dtype)
        for index in onp.ndindex(*shape[:-2]):
            nq[index] = onp.linalg.qr(a[index], mode=mode)[0]

        max_rank = max(m, n)

        # Norm, adjusted for dimension and type.
        def norm(x):
            nrm = onp.linalg.norm(x, axis=(-2, -1))
            return nrm / (max_rank * np.finfo(dtype).eps)

        def compare_orthogonal(q1, q2):
            # Each column of Q is unique only up to a unit-modulus factor (a
            # sign for real dtypes, a phase for complex ones). Estimate that
            # factor per column as the normalized sum of element-wise ratios
            # q1 / q2, then divide it out of q1. If q1 = q2 * p, multiplying
            # would give q2 * p**2, which is only correct for p = +/-1.
            # Rebinding (instead of an in-place update) leaves the caller's
            # array, a view into `nq`, untouched.
            sum_of_ratios = onp.sum(onp.divide(q1, q2), axis=-2, keepdims=True)
            phases = onp.divide(sum_of_ratios, onp.abs(sum_of_ratios))
            q1 = q1 / phases
            self.assertTrue(onp.all(norm(q1 - q2) < 30))

        # Check a ~= qr
        self.assertTrue(onp.all(norm(a - onp.matmul(lq, lr)) < 30))

        # Compare the first 'k' vectors of Q; the remainder form an arbitrary
        # orthonormal basis for the null space.
        compare_orthogonal(nq[..., :k], lq[..., :k])

        # Check that q is close to unitary.
        self.assertTrue(
            onp.all(norm(onp.eye(k) - onp.matmul(onp.conj(T(lq)), lq)) < 5))

        if not full_matrices and m >= n:
            jtu.check_jvp(np.linalg.qr,
                          partial(jvp, np.linalg.qr), (a, ),
                          atol=3e-3)
コード例 #3
0
ファイル: linalg_test.py プロジェクト: zhangfeilong/jax
    def testQr(self, shape, dtype, full_matrices, rng):
        """Checks np.linalg.qr against onp.linalg.qr, looping over batch dims.

        Verifies that Q @ R reconstructs the input, that the leading ``k``
        columns of Q agree with NumPy's up to a per-column unit-modulus factor,
        that Q has orthonormal columns, and (reduced mode, m >= n) the JVP.
        """
        m, n = shape[-2:]

        if full_matrices:
            mode, k = "complete", m
        else:
            mode, k = "reduced", min(m, n)

        a = rng(shape, dtype)
        lq, lr = np.linalg.qr(a, mode=mode)

        # onp.linalg.qr doesn't support broadcasting. But it seems like an
        # inevitable extension so we support it in our version. Only the
        # reference Q factor is compared below, so NumPy's R is discarded.
        nq = onp.zeros(shape[:-2] + (m, k), dtype)
        for index in onp.ndindex(*shape[:-2]):
            nq[index] = onp.linalg.qr(a[index], mode=mode)[0]

        max_rank = max(m, n)

        # Norm, adjusted for dimension and type.
        def norm(x):
            nrm = onp.linalg.norm(x, axis=(-2, -1))
            return nrm / (max_rank * onp.finfo(dtype).eps)

        def compare_orthogonal(q1, q2):
            # Each column of Q is unique only up to a unit-modulus factor (a
            # sign for real dtypes, a phase for complex ones). Estimate that
            # factor per column as the normalized sum of element-wise ratios
            # q1 / q2, then divide it out of q1. If q1 = q2 * p, multiplying
            # would give q2 * p**2, which is only correct for p = +/-1.
            # Rebinding (instead of an in-place update) leaves the caller's
            # array, a view into `nq`, untouched.
            sum_of_ratios = onp.sum(onp.divide(q1, q2), axis=-2, keepdims=True)
            phases = onp.divide(sum_of_ratios, onp.abs(sum_of_ratios))
            q1 = q1 / phases
            self.assertTrue(onp.all(norm(q1 - q2) < 30))

        # Check a ~= qr
        self.assertTrue(onp.all(norm(a - onp.matmul(lq, lr)) < 30))

        # Compare the first 'k' vectors of Q; the remainder form an arbitrary
        # orthonormal basis for the null space.
        compare_orthogonal(nq[..., :k], lq[..., :k])

        # Check that q is close to unitary. The conjugate transpose makes this
        # valid for complex dtypes; conj is a no-op for real ones.
        self.assertTrue(
            onp.all(norm(onp.eye(k) - onp.matmul(onp.conj(T(lq)), lq)) < 5))

        if not full_matrices and m >= n:
            jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a, ))
コード例 #4
0
  def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, rng_factory):
    """Checks np.linalg.svd on (possibly batched) inputs of shape b + (m, n).

    Covers reconstruction from the factors and orthonormality of the singular
    vectors (or agreement of the singular values with NumPy when compute_uv
    is False), compilation, and the JVP of the reduced decomposition.
    """
    rng = rng_factory()
    _skip_if_unsupported_type(dtype)
    if b != () and jax.lib.version <= (0, 1, 28):
      raise unittest.SkipTest("Batched SVD requires jaxlib 0.1.29")

    def args_maker():
      return [rng(b + (m, n), dtype)]

    # Error metric: Frobenius norm over the two trailing (matrix) axes,
    # expressed in units of max(m, n) * machine epsilon for `dtype`.
    unit = max(m, n) * onp.finfo(dtype).eps

    def norm(x):
      return onp.linalg.norm(x, axis=(-2, -1)) / unit

    a, = args_maker()
    out = np.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)

    if not compute_uv:
      # Only singular values are returned; compare them with NumPy's.
      self.assertTrue(onp.allclose(onp.linalg.svd(a, compute_uv=False), onp.asarray(out), atol=1e-4, rtol=1e-4))
    else:
      u, s, vh = out
      # Singular values broadcast along the row axis to scale u's columns.
      s_row = s[..., None, :]

      # Rebuild `a`. With full_matrices, only the first min(m, n) columns of u
      # (m >= n) or rows of vh (m < n) pair with a singular value.
      if not full_matrices:
        residual, bound = a - onp.matmul(s_row * u, vh), 300
      elif m < n:
        residual, bound = a - onp.matmul(s_row * u, vh[..., :min(m, n), :]), 50
      else:
        residual, bound = a - onp.matmul(s_row * u[..., :, :min(m, n)], vh), 350
      self.assertTrue(onp.all(norm(residual) < bound))

      # u must have orthonormal columns; vh must have orthonormal columns
      # when m >= n and orthonormal rows otherwise.
      self.assertTrue(onp.all(norm(onp.eye(u.shape[-1]) - onp.matmul(onp.conj(T(u)), u)) < 10))
      if m < n:
        self.assertTrue(onp.all(norm(onp.eye(vh.shape[-2]) - onp.matmul(vh, onp.conj(T(vh)))) < 20))
      else:
        self.assertTrue(onp.all(norm(onp.eye(vh.shape[-1]) - onp.matmul(onp.conj(T(vh)), vh)) < 10))

    svd_under_test = partial(np.linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv)
    self._CompileAndCheck(svd_under_test, args_maker, check_dtypes=True)
    if not full_matrices:
      svd = partial(np.linalg.svd, full_matrices=False)
      jtu.check_jvp(svd, partial(jvp, svd), (a,), rtol=1e-2, atol=1e-1)
コード例 #5
0
ファイル: linalg_test.py プロジェクト: woerwin/jax
  def testSVD(self, m, n, dtype, full_matrices, compute_uv, rng):
    """Checks np.linalg.svd on a single (m, n) matrix against NumPy.

    Covers reconstruction from the factors and orthonormality of the singular
    vectors (or the singular values alone when compute_uv is False),
    compilation, and the JVP of the reduced decomposition.
    """
    _skip_if_unsupported_type(dtype)

    def args_maker():
      return [rng((m, n), dtype)]

    # Error metric: Frobenius norm over the matrix axes, expressed in units
    # of max(m, n) * machine epsilon for `dtype`.
    unit = max(m, n) * onp.finfo(dtype).eps

    def norm(x):
      return onp.linalg.norm(x, axis=(-2, -1)) / unit

    a, = args_maker()
    out = np.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)

    if not compute_uv:
      # Only singular values are returned; compare them with NumPy's.
      self.assertTrue(onp.allclose(onp.linalg.svd(a, compute_uv=False), onp.asarray(out), atol=1e-4, rtol=1e-4))
    else:
      u, s, vh = out
      # Rebuild `a`. With full_matrices, only the first min(m, n) columns of u
      # (m >= n) or rows of vh (m < n) pair with a singular value.
      if not full_matrices:
        rebuilt = onp.matmul(s * u, vh)
      elif m < n:
        rebuilt = onp.matmul(s * u, vh[:min(m, n), :])
      else:
        rebuilt = onp.matmul(s * u[:, :min(m, n)], vh)
      self.assertTrue(onp.all(norm(a - rebuilt) < 50))

      # u must have orthonormal columns; vh must have orthonormal columns
      # when m >= n and orthonormal rows otherwise.
      self.assertTrue(onp.all(norm(onp.eye(u.shape[1]) - onp.matmul(onp.conj(T(u)), u)) < 10))
      if m < n:
        self.assertTrue(onp.all(norm(onp.eye(vh.shape[0]) - onp.matmul(vh, onp.conj(T(vh)))) < 20))
      else:
        self.assertTrue(onp.all(norm(onp.eye(vh.shape[1]) - onp.matmul(onp.conj(T(vh)), vh)) < 10))

    self._CompileAndCheck(partial(np.linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv),
                          args_maker, check_dtypes=True)
    if full_matrices:
      return
    svd = partial(np.linalg.svd, full_matrices=False)
    jvp_atol = 1e-1 if FLAGS.jax_enable_x64 else jtu.ATOL
    jtu.check_jvp(svd, partial(jvp, svd), (a,), atol=jvp_atol)
コード例 #6
0
 def test_jvp_linearized(self, f, args):
     """Runs jtu.check_jvp on ``f`` with ``partial(jvp_unlinearized, f)``."""
     jtu.check_jvp(f, partial(jvp_unlinearized, f), args)
コード例 #7
0
 def test_jvp(self, f, args):
     """Runs jtu.check_jvp on ``f`` with ``partial(jvp, f)``."""
     jtu.check_jvp(f, partial(jvp, f), args)
コード例 #8
0
ファイル: core_test.py プロジェクト: tomhennigan/jax
 def test_jvp_linearized(self, f, args):
   """Runs jtu.check_jvp on ``f`` with ``jvp_unlinearized`` as the JVP."""
   jvp_fn = partial(jvp_unlinearized, f)
   # Relative-tolerance override applied to float32 values only.
   float32_rtol = {onp.float32: 3e-2}
   jtu.check_jvp(f, jvp_fn, args, rtol=float32_rtol)
コード例 #9
0
ファイル: core_test.py プロジェクト: tomhennigan/jax
 def test_jvp(self, f, args):
   """Runs jtu.check_jvp on ``f`` with ``partial(jvp, f)`` as the JVP."""
   jvp_fn = partial(jvp, f)
   # Relative-tolerance override applied to float32 values only.
   float32_rtol = {onp.float32: 3e-2}
   jtu.check_jvp(f, jvp_fn, args, rtol=float32_rtol)
コード例 #10
0
ファイル: test_jvp.py プロジェクト: duvenaud/jaxde
import pdb

# Parameters for the test function.
# TODO: Test more functions.
D = 4  # length of the initial state y0
t0, t1 = 0.1, 0.11  # start and end times handed to odeint
y0 = np.linspace(0.1, 0.9, D)  # initial state
fargs = (0.1, 0.2)  # extra arguments (arg1, arg2) for f


def f(y, t, args):
    """Test dynamics passed to ``odeint`` as ``func``.

    Args:
      y: state (array-like).
      t: time (scalar).
      args: pair ``(arg1, arg2)`` of extra parameters.

    Returns:
      ``-sqrt(t) - sin(dot(y, arg1)) - mean((y + arg2) ** 2)``.
    """
    # Unpack explicitly: tuple parameters in a signature were removed in
    # Python 3 (PEP 3113) and are a SyntaxError there. Callers still pass
    # the pair as the third positional argument.
    arg1, arg2 = args
    return -np.sqrt(t) - np.sin(np.dot(y, arg1)) - np.mean((y + arg2)**2)


def test_odeint_jvp():
    """Check ``jvp_odeint`` with ``check_jvp`` for ``odeint`` from t0 to t1."""
    def odeint2(y0, t0, t1, fargs):
        return odeint(y0,
                      np.array([t0, t1]),
                      fargs,
                      func=f,
                      atol=1e-8,
                      rtol=1e-8)

    def odeint2_jvp(primals, tangents):
        # Unpack explicitly: tuple parameters in a signature are a
        # SyntaxError in Python 3 (PEP 3113).
        y0, t0, t1, fargs = primals
        tan_y, tan_t0, tan_t1, tan_fargs = tangents
        return jvp_odeint((y0, np.array([t0, t1]), fargs),
                          (tan_y, np.array([tan_t0, tan_t1]), tan_fargs),
                          func=f)

    check_jvp(odeint2, odeint2_jvp, (y0, t0, t1, fargs))
コード例 #11
0
    y0 = np.linspace(0.1, 0.9, D)
    arg = np.zeros((0, ))

    def f(y, t, args):
        return -np.sqrt(t) - y

    @custom_transforms
    def onearg_odeint(y0):
        return odeint(f, y0, np.array([t0, t1]), atol=1e-8, rtol=1e-8)[1]

    def onearg_jvp(primals, tangents):
        # Unpack explicitly: tuple parameters in a signature are a
        # SyntaxError in Python 3 (PEP 3113).
        y0, arg = primals
        tangent_all, = tangents
        return jvp_odeint(tangent_all, f, y0, t0, t1, arg)

    ad.defjvp(onearg_odeint.primitive, onearg_jvp)

    check_jvp(onearg_odeint, onearg_jvp, (y0, arg))


def test_odeint_jvp_all():
    D = 10
    t0 = 0.1
    t1 = 0.2
    y0 = np.linspace(0.1, 0.9, D)
    fargs = (0.1, 0.2)

    def f(y, t, arg1, arg2):
        return -np.sqrt(t) - y + arg1 - np.mean((y + arg2)**2)

    @custom_transforms
    def twoarg_odeint(y0, args):
        return odeint(f,
コード例 #12
0
from jax.config import config
# Enable JAX 64-bit mode at import time, before the imports below.
# NOTE(review): presumably so the SDE integrals and check_jvp run in float64
# and the tight atol/rtol used below hold; confirm.
config.update("jax_enable_x64", True)

from jax.test_util import check_jvp

from jaxsde import ito_integrate, jvp_ito_integrate
from test_sdeint import make_example_sde


def test_ito_int_jvp():
    """Check ``jvp_ito_integrate`` with ``check_jvp`` for ``ito_integrate``.

    Forward mode only; only the initial state ``y0`` is perturbed.
    """
    f, g, b, y0, ts, dt = make_example_sde()

    def onearg_int(y0):
        return ito_integrate(f, g, y0, ts, b, dt)

    def onearg_int_jvp(primals, tangents):
        # Unpack explicitly: tuple parameters in a signature are a
        # SyntaxError in Python 3 (PEP 3113).
        y0, = primals
        tan_y, = tangents
        return jvp_ito_integrate(tan_y, y0, f, g, ts, b, dt=dt, args=())

    check_jvp(onearg_int, onearg_int_jvp, (y0, ), atol=1e-3, rtol=1e-3)