Ejemplo n.º 1
0
  def test_jarrett_jvps2(self):
    """api.jarrett(f) must agree with f in value and VJP for a 2-argument f."""
    def f1(x, y):
      return np.sin(x) * np.cos(y) * np.sin(x) * np.cos(y)

    f2 = api.jarrett(f1)

    # TODO(mattjj): doesn't work for (3., onp.array([4., 5.]))
    cases = [(3., 4.), (onp.array([5., 6.]), onp.array([7., 8.]))]
    for x, y in cases:
      self.assertAllClose(f1(x, y), f2(x, y), check_dtypes=True)

      # y is reused as the output cotangent; in every case here x and y share
      # a shape, so the output shape matches y.
      pullbacks = [api.vjp(fn, x, y)[1] for fn in (f1, f2)]
      self.assertAllClose(pullbacks[0](y), pullbacks[1](y), check_dtypes=True)
Ejemplo n.º 2
0
    def test_jarrett_jvps(self):
        """api.jarrett must preserve values and VJPs of a nested sin chain."""
        def f1(x):
            return np.sin(np.sin(np.sin(x)))

        f2 = api.jarrett(f1)

        inputs = (3., onp.array([2., 3., 4.]))
        for x in inputs:
            self.assertAllClose(f1(x), f2(x), check_dtypes=True)

            # The input doubles as the cotangent: f1 is elementwise, so its
            # output has the same shape as x.
            _, pullback_plain = api.vjp(f1, x)
            _, pullback_jarrett = api.vjp(f2, x)
            self.assertAllClose(pullback_plain(x), pullback_jarrett(x),
                                check_dtypes=True)
Ejemplo n.º 3
0
  def test_remat_scan(self):
    """Check api.remat applied to a lax.scan body.

    Values and gradients must match the non-remat version, and the single scan
    equation of the linearized and VJP programs must mention ``cos``.
    """
    to_scan = lambda c, x: (np.sin(c), None)
    xs = onp.arange(3.)

    def f_noremat(x):
      final_carry, _ = lax.scan(to_scan, x, xs)
      return final_carry

    def f_yesremat(x):
      final_carry, _ = lax.scan(api.remat(to_scan), x, xs)
      return final_carry

    # Primal values first, then gradients: remat must not change either.
    for transform in (lambda fn: fn, api.grad):
      ans = transform(f_yesremat)(4.)
      expected = transform(f_noremat)(4.)
      self.assertAllClose(ans, expected, check_dtypes=False)

    # Both the linearized (forward) and pulled-back (reverse) programs should
    # consist of exactly one scan equation whose body contains ' cos '.
    for differentiate in (api.linearize, api.vjp):
      linear_fn = differentiate(f_yesremat, 4.)[1]
      jaxpr = api.make_jaxpr(linear_fn)(1.)
      scan_eqn, = jaxpr.jaxpr.eqns
      self.assertIn(' cos ', str(scan_eqn.params['jaxpr']))
Ejemplo n.º 4
0
  def test_coo_matvec_ad(self, shape, dtype, bshape):
    """Compare JVPs/VJPs of coo_matvec against a dense matmul reference.

    Derivatives are checked with respect to both the dense vector operand and
    the stored nonzero values (``data``) of the COO matrix.
    """
    tol = {np.float32: 1E-6, np.float64: 1E-13, np.complex64: 1E-6, np.complex128: 1E-13}

    rng = rand_sparse(self.rng(), post=jnp.array)
    rng_b = jtu.rand_default(self.rng())

    M = rng(shape, dtype)
    data, row, col = sparse_ops.coo_fromdense(M, nnz=(M != 0).sum())
    x = rng_b(bshape, dtype)
    xdot = rng_b(bshape, dtype)

    # Forward-mode with respect to the vector
    f_dense = lambda x: M @ x
    f_sparse = lambda x: sparse_ops.coo_matvec(data, row, col, x, shape=M.shape)
    v_sparse, t_sparse = api.jvp(f_sparse, [x], [xdot])
    v_dense, t_dense = api.jvp(f_dense, [x], [xdot])
    self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
    self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

    # Reverse-mode with respect to the vector. Each pullback is fed its own
    # primal output as the cotangent.
    primals_dense, vjp_dense = api.vjp(f_dense, x)
    primals_sparse, vjp_sparse = api.vjp(f_sparse, x)
    out_dense, = vjp_dense(primals_dense)
    out_sparse, = vjp_sparse(primals_sparse)
    # Compare the whole primal output vector, not only its first entry.
    self.assertAllClose(primals_dense, primals_sparse, atol=tol, rtol=tol)
    self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)

    # Forward-mode with respect to nonzero elements of the matrix
    f_sparse = lambda data: sparse_ops.coo_matvec(data, row, col, x, shape=M.shape)
    f_dense = lambda data: sparse_ops.coo_todense(data, row, col, shape=M.shape) @ x
    # Fresh values/tangents with the same length as the stored nonzeros, so the
    # existing row/col indices still apply.
    data = rng((len(data),), data.dtype)
    data_dot = rng((len(data),), data.dtype)
    v_sparse, t_sparse = api.jvp(f_sparse, [data], [data_dot])
    v_dense, t_dense = api.jvp(f_dense, [data], [data_dot])

    self.assertAllClose(v_sparse, v_dense, atol=tol, rtol=tol)
    self.assertAllClose(t_sparse, t_dense, atol=tol, rtol=tol)

    # Reverse-mode with respect to nonzero elements of the matrix
    primals_dense, vjp_dense = api.vjp(f_dense, data)
    primals_sparse, vjp_sparse = api.vjp(f_sparse, data)
    out_dense, = vjp_dense(primals_dense)
    out_sparse, = vjp_sparse(primals_sparse)
    # Compare the whole primal output vector, not only its first entry.
    self.assertAllClose(primals_dense, primals_sparse, atol=tol, rtol=tol)
    self.assertAllClose(out_dense, out_sparse, atol=tol, rtol=tol)
Ejemplo n.º 5
0
 def test_vjp_mismatched_arguments(self):
   """The vjp pullback must reject cotangents of the wrong structure or dtype."""
   _, pullback = api.vjp(lambda x, y: x * y, onp.float32(3), onp.float32(4))
   # A 2-tuple cotangent does not match the single scalar output.
   with self.assertRaisesRegex(
       TypeError,
       "Tree structure of cotangent input.*does not match"):
     pullback((onp.float32(7), onp.float32(100)))
   # A float16 scalar has the right structure but the wrong dtype.
   with self.assertRaisesRegex(
       TypeError,
       "Type of cotangent input to vjp pullback.*does not match type"):
     pullback(onp.float16(42))
Ejemplo n.º 6
0
  def testAllGatherVjp(self):
    """Under a named vmap axis, the test expects all_gather's VJP to sum cotangents."""
    gather = lambda x: lax.all_gather(x, axis_name='i')

    rng = np.random.RandomState(1)
    x = rng.randn(3, 4)
    y_bar = rng.randn(3, 3, 4)

    def pull_back(x_i, y_bar_i):
      return vjp(gather, x_i)[1](y_bar_i)

    x_bar, = vmap(pull_back, axis_name='i')(x, y_bar)
    self.assertAllClose(x_bar, np.sum(y_bar, axis=0))
Ejemplo n.º 7
0
    def dzdt(delta):
        """Directional derivative of ``f(., x1)`` along the pullback of ``delta``.

        ``delta`` is pulled back through ``p -> f(p, x2)`` at ``params`` to get
        a direction in parameter space; then ``t -> f(params + t * direction, x1)``
        is differentiated at ``t = 0`` with a unit tangent.
        """
        pullback = vjp(lambda p: f(p, x2), params)[1]
        direction, = pullback(delta)

        def z(t):
            step = tree_map(lambda leaf: t * leaf, direction)
            # NOTE(review): tree_multimap is deprecated in newer JAX releases in
            # favor of tree_map with several trees.
            return f(tree_multimap(np.add, params, step), x1)

        return jvp(z, (0.0, ), (1.0, ))[1]
Ejemplo n.º 8
0
def _rfft_transpose(t, fft_lengths):
    """Transpose rule for RFFT: map the output cotangent ``t`` to the input.

    The transpose of RFFT can't be expressed only in terms of irfft. Instead of
    manually building up larger twiddle matrices (which would increase the
    asymptotic complexity and is also rather complicated), we rely on JAX to
    transpose a naive RFFT implementation.

    Args:
      t: cotangent of the RFFT output; its trailing ``len(fft_lengths)`` axes
        are the transformed axes.
      fft_lengths: sequence of lengths of the transformed axes of the real
        input (a list or tuple).

    Returns:
      The cotangent with respect to the real RFFT input.
    """
    fft_lengths = tuple(fft_lengths)
    dummy_shape = t.shape[:-len(fft_lengths)] + fft_lengths
    # The dummy primal must be real with the precision that matches t
    # (complex64 -> float32, complex128 -> float64). Hard-coding float64 can
    # disagree with t's precision (e.g. a complex64 t when x64 is enabled),
    # making the pullback see a cotangent of the wrong type.
    real_dtype = onp.finfo(t.dtype).dtype
    dummy_primals = lax.full_like(t, 0.0, real_dtype, dummy_shape)
    _, vjpfun = vjp(partial(_naive_rfft, fft_lengths=fft_lengths),
                    dummy_primals)
    result, = vjpfun(t)
    return result
Ejemplo n.º 9
0
def _transpose_function(linear_fun, primals):
  """Transpose a linear function."""
  # TODO(shoyer): can we use something more direct than the vjp machinery?
  # It's particularly awkward that we need the second argument to give
  # particular values of the primals, which are entirely arbitrary.
  _, vjp_fun = api.vjp(linear_fun, primals)

  def transposed_fun(x):
    (y,) = vjp_fun(x)
    return y

  return transposed_fun
Ejemplo n.º 10
0
  def testPdotVjp(self):
    """The VJP of pdot over a named axis must match dense matmul gradients."""
    def f(x, y):
      return lax.pdot(x, y, 'i')

    rng = np.random.RandomState(1)
    x = rng.randn(3, 4)
    y = rng.randn(4, 5)
    z_bar = rng.randn(3, 5)

    def pullback(x_blk, y_blk, ct):
      return vjp(f, x_blk, y_blk)[1](ct)

    # x is mapped over its columns and y over its rows (the contracted size-4
    # dimension), while z_bar is broadcast unmapped to every instance.
    batched = vmap(pullback, axis_name='i', in_axes=(1, 0, None),
                   out_axes=(1, 0))
    x_bar, y_bar = batched(x, y, z_bar)
    self.assertAllClose(x_bar, jnp.dot(z_bar, y.T))
    self.assertAllClose(y_bar, jnp.dot(x.T, z_bar))
Ejemplo n.º 11
0
 def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
                                         dimension_numbers, rng_factory):
   """Check dot_general gradients and that Precision.HIGHEST survives AD."""
   rng = rng_factory(self.rng())
   lhs = rng(lhs_shape, dtype)
   rhs = rng(rhs_shape, dtype)
   dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers,
                         precision=lax.Precision.HIGHEST)
   check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"])
   # check that precision config is preserved in the pullback's jaxpr.
   result, pullback = api.vjp(dot_general, lhs, rhs)
   gresult = lax.zeros_like_array(result)
   s = str(api.make_jaxpr(pullback)(gresult))
   # assertIn (unlike a bare assert) still runs under `python -O` and shows
   # the jaxpr text when the check fails.
   self.assertIn("precision=HIGHEST", s)
Ejemplo n.º 12
0
 def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory):
   """Check lax.dot gradients and that Precision.HIGHEST survives AD."""
   rng = rng_factory(self.rng())
   tol = {onp.float16: 1e-1, onp.float32: 1e-4}
   lhs = rng(lhs_shape, dtype)
   rhs = rng(rhs_shape, dtype)
   dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
   check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"],
                        atol=tol, rtol=tol)
   # check that precision config is preserved in the pullback's jaxpr.
   result, pullback = api.vjp(dot, lhs, rhs)
   gresult = lax.zeros_like_array(result)
   s = str(api.make_jaxpr(pullback)(gresult))
   # assertIn (unlike a bare assert) still runs under `python -O` and shows
   # the jaxpr text when the check fails.
   self.assertIn("precision=HIGHEST", s)
Ejemplo n.º 13
0
def ravel_pytree(pytree):
    """Ravel (i.e. flatten) a pytree of arrays down to a 1D array.

    Args:
      pytree: a pytree to ravel.

    Returns:
      A pair where the first element is a 1D array representing the flattened
      and concatenated leaf values, and the second element is a callable for
      unflattening a 1D vector of the same length back to a pytree of the same
      structure as the input ``pytree``.
    """
    leaves, treedef = tree_flatten(pytree)
    # The pullback of the leaf-list ravel is used as the leaf-list unravel;
    # this presumes _ravel_list is linear (a concatenation) -- see its definition.
    flat, unravel_list = vjp(_ravel_list, *leaves)

    def unravel_pytree(flat):
        return tree_unflatten(treedef, unravel_list(flat))

    return flat, unravel_pytree
Ejemplo n.º 14
0
  def test_coo_todense_ad(self, shape, dtype):
    """Check forward- and reverse-mode derivatives of coo_todense w.r.t. data."""
    rng = rand_sparse(self.rng(), post=jnp.array)
    M = rng(shape, dtype)
    data, row, col = sparse_ops.coo_fromdense(M, nnz=(M != 0).sum())

    def to_dense(values):
      return sparse_ops.coo_todense(values, row, col, shape=M.shape)

    # Forward-mode: a tangent of ones must land exactly on the stored positions.
    out, out_dot = api.jvp(to_dense, [data], [jnp.ones_like(data)])
    self.assertArraysEqual(out, to_dense(data))
    self.assertArraysEqual(out_dot, jnp.zeros_like(M).at[row, col].set(1))

    # Reverse-mode: pulling back the dense output must recover `data`.
    out, pullback = api.vjp(to_dense, data)
    data_bar, = pullback(out)
    self.assertArraysEqual(out, to_dense(data))
    self.assertArraysEqual(data_bar, data)
Ejemplo n.º 15
0
  def test_coo_fromdense_ad(self, shape, dtype):
    """Check forward- and reverse-mode derivatives of coo_fromdense."""
    rng = rand_sparse(self.rng(), post=jnp.array)
    M = rng(shape, dtype)
    nnz = (M != 0).sum()

    def to_coo(dense):
      return sparse_ops.coo_fromdense(dense, nnz=nnz)

    expected = to_coo(M)

    # Forward-mode
    primals, tangents = api.jvp(to_coo, [M], [jnp.ones_like(M)])
    for k in range(3):
      self.assertArraysEqual(primals[k], expected[k])
    self.assertArraysEqual(tangents[0], jnp.ones(nnz, dtype=dtype))
    # The row/col outputs are expected to carry float0 tangents.
    for k in (1, 2):
      self.assertEqual(tangents[k].dtype, dtypes.float0)

    # Reverse-mode
    primals, pullback = api.vjp(to_coo, M)
    M_bar, = pullback(primals)
    for k in range(3):
      self.assertArraysEqual(primals[k], expected[k])
    self.assertArraysEqual(M_bar, M)
Ejemplo n.º 16
0
 def jacbwd(f, x):
   """Reverse-mode Jacobian of ``f`` at ``x``, shaped ``shape(y) + shape(x)``.

   The pullback is vmapped over a standard basis of the output space; each
   basis vector yields one Jacobian row, shaped like ``x``.
   """
   y, pullback = vjp(f, x)
   std_basis = np.eye(np.size(y)).reshape((-1,) + np.shape(y))
   # The mapped axis enumerates output elements, so it must come first for the
   # reshape below to produce J[y_index, x_index]. Placing it at ndim(y)
   # scrambled the Jacobian for non-square f and fails when ndim(y) > ndim(x).
   jac_flat, = vmap(pullback, out_axes=0)(std_basis)
   return jac_flat.reshape(np.shape(y) + np.shape(x))
Ejemplo n.º 17
0
 def delta_vjp(delta):
     """Pull ``delta`` back through ``p -> f(p, x2)`` evaluated at ``params``."""
     _, pullback = vjp(lambda p: f(p, x2), params)
     return pullback(delta)
Ejemplo n.º 18
0
 def delta_vjp(delta):
     """Pull ``delta`` back through ``f2`` evaluated at ``params``."""
     _, pullback = vjp(f2, params)
     return pullback(delta)
Ejemplo n.º 19
0
 def jacbwd(f, x):
     """Reverse-mode Jacobian of ``f`` at ``x``, shaped ``shape(y) + shape(x)``.

     Uses the legacy ``vmap(fun, *args, out_bdim=...)`` calling convention.
     """
     y, pullback = vjp(f, x)
     std_basis = onp.eye(onp.size(y)).reshape((-1, ) + onp.shape(y))
     # The mapped axis enumerates output elements, so it must come first for
     # the reshape below to produce J[y_index, x_index]; putting it at ndim(y)
     # scrambles the Jacobian for non-square f.
     jac_flat, = vmap(pullback, std_basis, out_bdim=0)
     return jac_flat.reshape(onp.shape(y) + onp.shape(x))
Ejemplo n.º 20
0
def ravel_pytree(pytree):
  """Flatten a pytree of arrays into a single 1D array.

  Returns ``(flat, unravel_pytree)``, where ``unravel_pytree`` rebuilds a
  pytree with ``pytree``'s structure from a flat vector, using the pullback
  of ``ravel_list`` to split it back into leaves.
  """
  leaves, treedef = tree_flatten(pytree)
  flat, unravel_list = vjp(ravel_list, *leaves)

  def unravel_pytree(flat):
    return tree_unflatten(treedef, unravel_list(flat))

  return flat, unravel_pytree