Example #1
 def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
                                lhs_contracting, rhs_contracting, bdims):
     rng = jtu.rand_small(self.rng())
     dimension_numbers = ((lhs_contracting, rhs_contracting), ([], []))
     dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
     self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape),
                         (dtype, dtype), rng)
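
All of these examples share the same setup: jtu.rand_small builds a sampler from a NumPy random source, and the sampler is then called with a shape and dtype to draw small-magnitude test arrays. A minimal standalone sketch of that pattern (assuming an older JAX release in which jax.test_util is the public test-utility module, as in the snippets here):

import numpy as np
from jax import test_util as jtu

# rand_small wraps a NumPy random source and returns a sampler;
# calling the sampler with (shape, dtype) yields an array of
# small-magnitude values, keeping products and gradients well behaved.
rng = jtu.rand_small(np.random.RandomState(0))
x = rng((3, 4), np.float32)  # small random float32 array of shape (3, 4)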
Example #2
    def test_bcoo_rdot_general_contract_and_batch(self, lhs_shape, rhs_shape,
                                                  dtype, dimension_numbers,
                                                  n_batch, n_dense):
        rng = jtu.rand_small(self.rng())
        rng_sparse = rand_sparse(self.rng())

        def args_maker():
            lhs = rng(lhs_shape, dtype)
            rhs = rng_sparse(rhs_shape, dtype)
            data, indices = sparse_ops.bcoo_fromdense(rhs,
                                                      n_batch=n_batch,
                                                      n_dense=n_dense)
            return data, indices, lhs, rhs

        def f_dense(data, indices, lhs, rhs):
            return lax.dot_general(lhs,
                                   rhs,
                                   dimension_numbers=dimension_numbers)

        def f_sparse(data, indices, lhs, rhs):
            return sparse_ops.bcoo_rdot_general(
                lhs,
                data,
                indices,
                rhs_shape=rhs.shape,
                dimension_numbers=dimension_numbers)

        self._CheckAgainstNumpy(f_dense, f_sparse, args_maker)
        self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker)
Example #3
    def test_bcoo_dot_general_ad(self, lhs_shape, rhs_shape, dtype,
                                 dimension_numbers, n_batch, n_dense):
        rng = jtu.rand_small(self.rng())
        rng_sparse = rand_sparse(self.rng())

        X = rng_sparse(lhs_shape, dtype)
        data, indices = sparse_ops.bcoo_fromdense(X,
                                                  n_batch=n_batch,
                                                  n_dense=n_dense)
        Y = rng(rhs_shape, dtype)

        def f_dense(Y):
            return lax.dot_general(X, Y, dimension_numbers=dimension_numbers)

        def f_sparse(Y):
            return sparse_ops.bcoo_dot_general(
                data,
                indices,
                Y,
                lhs_shape=X.shape,
                dimension_numbers=dimension_numbers)

        jf_dense = jax.jacfwd(f_dense)(Y)
        jr_dense = jax.jacrev(f_dense)(Y)
        jf_sparse = jax.jacfwd(f_sparse)(Y)
        jr_sparse = jax.jacrev(f_sparse)(Y)

        tol = {}
        if jtu.device_under_test() == "tpu":
            tol = {np.float32: 5E-3}

        self.assertAllClose(jf_dense, jf_sparse, rtol=tol)
        self.assertAllClose(jr_dense, jr_sparse, rtol=tol)
        self.assertAllClose(jf_sparse, jr_sparse, rtol=tol)
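
The AD check above relies on forward-mode (jacfwd) and reverse-mode (jacrev) Jacobians agreeing for the same function. A dense-only sketch of that consistency check, with a placeholder function standing in for dot_general (names and shapes here are illustrative only):

import jax
import jax.numpy as jnp
import numpy as np

def g(y):
    # simple smooth function used only to illustrate the check
    return jnp.sin(y).sum(axis=0)

y = jnp.ones((3, 2), jnp.float32)
jf = jax.jacfwd(g)(y)  # forward-mode Jacobian
jr = jax.jacrev(g)(y)  # reverse-mode Jacobian, same shape
np.testing.assert_allclose(jf, jr, rtol=1e-6)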
Example #4
    def test_bcoo_dot_general_partial_batch(self, lhs_shape, rhs_shape, dtype,
                                            dimension_numbers, n_batch,
                                            n_dense):
        rng = jtu.rand_small(self.rng())
        rng_sparse = rand_sparse(self.rng())

        X = rng_sparse(lhs_shape, dtype)
        data, indices = sparse_ops.bcoo_fromdense(X,
                                                  n_batch=n_batch,
                                                  n_dense=n_dense)
        Y = rng(rhs_shape, dtype)

        def f_dense(X, Y):
            return lax.dot_general(X, Y, dimension_numbers=dimension_numbers)

        def f_sparse(data, indices, Y):
            return sparse_ops.bcoo_dot_general(
                data,
                indices,
                Y,
                lhs_shape=X.shape,
                dimension_numbers=dimension_numbers)

        for data, indices in itertools.product([data, data[:1]],
                                               [indices, indices[:1]]):
            X = sparse_ops.bcoo_todense(data, indices, shape=X.shape)
            self.assertAllClose(f_dense(X, Y), f_sparse(data, indices, Y))
Example #5
 def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding):
   rng = jtu.rand_small(self.rng())
   lhs = rng(lhs_shape, dtype)
   rhs = rng(rhs_shape, dtype)
   conv = partial(lax.conv, window_strides=strides, padding=padding,
                  precision=lax.Precision.HIGHEST)
   check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                        atol=1e-2, rtol=1e-2)
Example #6
def test_custom_jvp():
    @custom_jvp
    def f(x):
        return x**2

    # custom_jvp rules map (primals, tangents) to (primal_out, tangent_out).
    f.defjvp(lambda primals, tangents: (primals[0]**2,
                                        2 * primals[0] * tangents[0]))
    rng = jtu.rand_small(np.random)
    tu.check_lazy_fun(f, rng((1,), 'float32'))
Example #7
  def testPadGrad(self, shape, dtype, pads):
    rng = jtu.rand_small(self.rng())
    operand = rng(shape, dtype)
    pad = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
    check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.)

    operand = rng(shape, dtype)
    padding_value = np.array(0., dtype)
    pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads)
    check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.)
Example #8
    def testReduceWindow(self, op, init_val, dtype, shape, dims, strides,
                         padding, base_dilation, window_dilation):
        rng = jtu.rand_small(self.rng())
        init_val = np.asarray(init_val, dtype=dtype)

        def fun(operand):
            return lax.reduce_window(operand, init_val, op, dims, strides,
                                     padding, base_dilation, window_dilation)

        for bdims in all_bdims(shape):
            self._CheckBatching(fun, 3, bdims, (shape, ), (dtype, ), rng)
Example #9
    def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                       dimension_numbers, bdims):
        rng = jtu.rand_small(self.rng())
        dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
        self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape),
                            (dtype, dtype), rng)

        # Checks that batching didn't introduce any transposes or broadcasts.
        jaxpr = api.make_jaxpr(dot)(np.zeros(lhs_shape, dtype),
                                    np.zeros(rhs_shape, dtype))
        for eqn in jtu.iter_eqns(jaxpr.jaxpr):
            self.assertNotIn(eqn.primitive.name, ["transpose", "broadcast"])
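
The _CheckBatching helper used above essentially compares jax.vmap of a function against applying it example by example and stacking the results. A rough sketch of that idea (the helper below is illustrative, not JAX's actual implementation):

import jax
import jax.numpy as jnp
import numpy as np

def check_batching(fun, batched_args):
    # vmap over the leading axis should match a manual Python loop
    vmapped = jax.vmap(fun)(*batched_args)
    looped = jnp.stack([fun(*[a[i] for a in batched_args])
                        for i in range(batched_args[0].shape[0])])
    np.testing.assert_allclose(vmapped, looped, rtol=1e-6)

check_batching(jnp.dot, (jnp.ones((5, 2, 3)), jnp.ones((5, 3, 4))))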
Example #10
 def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
                                         dimension_numbers):
   rng = jtu.rand_small(self.rng())
   lhs = rng(lhs_shape, dtype)
   rhs = rng(rhs_shape, dtype)
   dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers,
                         precision=lax.Precision.HIGHEST)
   check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"])
   # check that precision config is preserved
   result, pullback = api.vjp(dot_general, lhs, rhs)
   gresult = lax.zeros_like_array(result)
   s = str(api.make_jaxpr(pullback)(gresult))
   assert "precision=HIGHEST" in s
Example #11
  def testSelectAndScatterAdd(self, dtype, padding, shape, dims, strides):
    rng = jtu.rand_small(self.rng())

    pads = lax.padtype_to_pads(shape, dims, strides, padding)

    def fun(operand, cotangents):
      return lax._select_and_scatter_add(operand, cotangents, lax.ge_p, dims,
                                         strides, pads)
    ones = (1,) * len(shape)
    cotangent_shape = api.eval_shape(
      lambda x: lax._select_and_gather_add(x, x, lax.ge_p, dims, strides,
                                           pads, ones, ones),
      np.ones(shape, dtype)).shape

    for bdims in all_bdims(cotangent_shape, shape):
      self._CheckBatching(fun, 3, bdims, (cotangent_shape, shape),
                          (dtype, dtype), rng)
Example #12
    def testSelectAndGatherAdd(self, dtype, padding):
        rng = jtu.rand_small(self.rng())
        all_configs = itertools.chain(
            itertools.product([(4, 6)], [(2, 1), (1, 2)], [(1, 1), (2, 1),
                                                           (1, 2)]),
            itertools.product([(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
                              [(1, 2, 2, 1), (1, 1, 1, 1)]))

        def fun(operand, tangents):
            pads = lax.padtype_to_pads(operand.shape, dims, strides, padding)
            ones = (1, ) * len(operand.shape)
            return lax._select_and_gather_add(operand, tangents, lax.ge_p,
                                              dims, strides, pads, ones, ones)

        for shape, dims, strides in all_configs:
            for bdims in all_bdims(shape, shape):
                self._CheckBatching(fun, 3, bdims, (shape, shape),
                                    (dtype, dtype), rng)
Example #13
    def testSelectAndGatherAdd(self, dtype, padding):
        if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
            raise SkipTest(
                "bfloat16 _select_and_gather_add doesn't work on tpu")
        rng = jtu.rand_small(self.rng())
        all_configs = itertools.chain(
            itertools.product([(4, 6)], [(2, 1), (1, 2)], [(1, 1), (2, 1),
                                                           (1, 2)]),
            itertools.product([(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
                              [(1, 2, 2, 1), (1, 1, 1, 1)]))

        def fun(operand, tangents):
            pads = lax.padtype_to_pads(operand.shape, dims, strides, padding)
            ones = (1, ) * len(operand.shape)
            return lax._select_and_gather_add(operand, tangents, lax.ge_p,
                                              dims, strides, pads, ones, ones)

        for shape, dims, strides in all_configs:
            for bdims in all_bdims(shape, shape):
                self._CheckBatching(fun, 3, bdims, (shape, shape),
                                    (dtype, dtype), rng)
Example #14
  def test_bcoo_dot_general_contract_only(self, lhs_shape, rhs_shape, dtype,
                                          lhs_contracting, rhs_contracting, n_dense):
    rng = jtu.rand_small(self.rng())
    rng_sparse = rand_sparse(self.rng())
    def args_maker():
      lhs = rng_sparse(lhs_shape, dtype)
      rhs = rng(rhs_shape, dtype)
      data, indices = sparse.bcoo_fromdense(lhs, n_dense=n_dense)
      return data, indices, lhs, rhs
    dimension_numbers = ((lhs_contracting, rhs_contracting), ([], []))

    def f_dense(data, indices, lhs, rhs):
      return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)

    def f_sparse(data, indices, lhs, rhs):
      return sparse.bcoo_dot_general(data, indices, rhs,
                                         lhs_shape=lhs.shape,
                                         dimension_numbers=dimension_numbers)

    self._CheckAgainstNumpy(f_dense, f_sparse, args_maker)
    self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker)
Example #15
 def testReduce(self, op, init_val, shape, dtype, dims, bdims):
     rng = jtu.rand_small(self.rng())
     init_val = np.asarray(init_val, dtype=dtype)
     fun = lambda operand: lax.reduce(operand, init_val, op, dims)
     self._CheckBatching(fun, 5, bdims, (shape, ), (dtype, ), rng)
Example #16
def test_jit():
    rng = jtu.rand_small(np.random)
    tu.check_lazy_fun(jit(lambda x: x * 2), rng((1, ), int))
Example #17
 def testPad(self, shape, dtype, pads, bdims):
     rng = jtu.rand_small(self.rng())
     fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
     self._CheckBatching(fun, 5, bdims, (shape, ), (dtype, ), rng)
Example #18
def test_jit_freevar():
    rng = jtu.rand_small(np.random)
    tu.check_lazy_fun(lambda x, y: jit(lambda x: x * y)(x), rng((1, ), int),
                      rng((1, ), int))
Example #19
    op_record("logical_xor", 2, all_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("maximum", 2, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("minimum", 2, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("multiply", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("not_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), ["rev"]),
    op_record("power", 2, inexact_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("cos", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("tan", 1, number_dtypes, all_shapes, jtu.rand_uniform(-1.5, 1.5),
              ["rev"]),
    op_record("sinh", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("cosh", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("tanh", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("arcsin", 1, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arccos", 1, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arctan", 1, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arctan2", 2, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arcsinh", 1, number_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arccosh", 1, number_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arctanh", 1, number_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
]

JAX_COMPOUND_OP_RECORDS = [
    op_record("divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),
    op_record("exp2", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_positive(), [],
              test_name="expm1_large"),
    op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive(), []),
    op_record("floor_divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),
Example #20
    op_record("logical_or", 2, default_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("logical_xor", 2, default_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("maximum", 2, default_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("minimum", 2, default_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("multiply", 2, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("negative", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("not_equal", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), ["rev"]),
    op_record("power", 2, float_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("subtract", 2, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("sin", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("cos", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("tan", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("sinh", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("cosh", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("tanh", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("arcsin", 1, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arccos", 1, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arctan", 1, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arctan2", 2, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arcsinh", 1, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arccosh", 1, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arctanh", 1, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
]

JAX_COMPOUND_OP_RECORDS = [
    op_record("divide", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),
    op_record("exp2", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("expm1", 1, numeric_dtypes, all_shapes, jtu.rand_positive(), [],
              test_name="expm1_large"),
    op_record("expm1", 1, numeric_dtypes, all_shapes, jtu.rand_small_positive(), []),
    op_record("floor_divide", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),
Example #21
def op_record(name, nargs, dtypes, rng, test_grad, test_name=None):
    test_name = test_name or name
    return OpRecord(name, nargs, dtypes, rng, test_grad, test_name)


JAX_SPECIAL_FUNCTION_RECORDS = [
    # TODO: digamma has no JVP implemented.
    op_record("digamma", 1, float_dtypes, jtu.rand_positive(), False),
    op_record("erf", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("erfc", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("erfinv", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("expit", 1, float_dtypes, jtu.rand_small_positive(), True),
    # TODO: gammaln has slightly high error.
    op_record("gammaln", 1, float_dtypes, jtu.rand_positive(), False),
    op_record("logit", 1, float_dtypes, jtu.rand_small_positive(), False),
    op_record("log_ndtr", 1, float_dtypes, jtu.rand_small(), True),
    op_record("ndtri", 1, float_dtypes, jtu.rand_uniform(0., 1.), True),
    op_record("ndtr", 1, float_dtypes, jtu.rand_default(), True),
]

CombosWithReplacement = itertools.combinations_with_replacement


class LaxBackedScipyTests(jtu.JaxTestCase):
    """Tests for LAX-backed Scipy implementation."""
    def _GetArgsMaker(self, rng, shapes, dtypes):
        return lambda: [
            rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)
        ]

    @parameterized.named_parameters(