예제 #1
0
class IndexingTest(jtu.TestCase):
  """Tests for Numpy indexing translation rules."""

  @parameterized.named_parameters(jtu.cases_from_list({
      "testcase_name": "{}_inshape={}_indexer={}".format(
          name, jtu.format_shape_dtype_string( shape, dtype), indexer),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer
  } for name, index_specs in STATIC_INDEXING_TESTS
    for shape, indexer in index_specs
    for dtype in all_dtypes
    for rng_factory in [jtu.rand_default]))
  def testStaticIndexing(self, shape, dtype, rng_factory, indexer):
    """Indexes a random array with a closed-over `indexer`.

    The NumPy path indexes the array directly; the other path first wraps it
    with `jnp.asarray`. Both are handed to `_CheckAgainstNumpy` and the
    wrapped path to `_CompileAndCheck`.
    """
    # TODO(rohanj): Revisit passing in self.rng() to this to customize further.
    # This would need updating lax_numpy_test as well.
    sample = rng_factory()

    def args_maker():
      return [sample(shape, dtype)]

    def onp_fun(x):
      return x[indexer]

    def jnp_fun(x):
      return jnp.asarray(x)[indexer]

    self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(
        jnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True)

  def _ReplaceSlicesWithTuples(self, idx):
    """Splits an indexer into slice-free values plus a function that rebuilds it.

    Args:
      idx: a slice, a tuple/list of indexers (handled recursively), or any
        other object (returned unchanged).

    Returns:
      A pair `(unpacked, pack)`. For a slice, `unpacked` is the
      `(start, stop, step)` triple passed through `subvals` with 0 at the
      positions that were None, and `pack` puts None back at those positions
      (again via `subvals`) before constructing the slice. For a non-empty
      tuple/list, `unpacked` is a tuple of the elements' unpacked values and
      `pack` rebuilds a container of the original type. Otherwise `pack` is
      the identity.
    """
    if isinstance(idx, slice):
      fields = idx.start, idx.stop, idx.step
      # Remember which fields were None so the packer can restore them.
      missing = [pos for pos, field in enumerate(fields) if field is None]
      unpacked = subvals(fields, zip(missing, itertools.repeat(0)))

      def pack_slice(out):
        restored = subvals(out, zip(missing, itertools.repeat(None)))
        return slice(*restored)

      return unpacked, pack_slice

    if isinstance(idx, (tuple, list)) and idx:
      container = type(idx)
      unpacked, packers = zip(*map(self._ReplaceSlicesWithTuples, idx))

      def pack_sequence(elts):
        return container(pack(elt) for pack, elt in zip(packers, elts))

      return unpacked, pack_sequence

    return idx, lambda x: x

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_indexer={}"
       .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer}
      for name, index_specs in [
          ("OneSliceIndex",
           [IndexSpec(shape=(5,), indexer=slice(1, 3)),
            IndexSpec(shape=(5, 4), indexer=slice(1, 3))]),
          ("TwoSliceIndices",
           [IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))),
            IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2)))]),
          ("NonUnitStrides", [
              IndexSpec(shape=(3,), indexer=slice(None, None, -1)),
              IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)),
              IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2))
          ]),
          ("OnlyStartOrStopDynamic", [
              IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))),
              IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None)))
          ]),
      ]
      for shape, indexer in index_specs
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default])
  def testDynamicIndexingWithSlices(self, shape, dtype, rng_factory, indexer):
    """Indexes with slices whose fields are passed in as function arguments.

    The indexer is unpacked with `_ReplaceSlicesWithTuples`, supplied through
    `args_maker`, and re-packed inside the function under test.
    """
    sample = rng_factory()
    flat_indexer, rebuild_indexer = self._ReplaceSlicesWithTuples(indexer)

    def onp_fun(x, unpacked_indexer):
      return x[rebuild_indexer(unpacked_indexer)]

    def jnp_fun(x, idx):
      return onp_fun(jnp.asarray(x), idx)

    def args_maker():
      return [sample(shape, dtype), flat_indexer]

    self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(
        jnp_fun,
        args_maker,
        check_dtypes=True,
        check_eval_on_shapes=False,
        check_incomplete_shape=True)

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_indexer={}"
       .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer}
      for name, index_specs in [
          ("OneIntIndex",
           [IndexSpec(shape=(3,), indexer=1),
            IndexSpec(shape=(3, 3), indexer=0),
            IndexSpec(shape=(3, 4, 5), indexer=2),
            IndexSpec(shape=(3,), indexer=-1),
            IndexSpec(shape=(3,), indexer=-2)]),
          ("TwoIntIndices",
           [IndexSpec(shape=(3, 3), indexer=(2, 1)),
            IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
            IndexSpec(shape=(3, 4, 5), indexer=(-1, 2))]),
          ("ThreeIntIndices",
           [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
      ]
      for shape, indexer in index_specs
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default])
  def testDynamicIndexingWithIntegers(self, shape, dtype, rng_factory, indexer):
    """Indexes with integer indices that are passed in as function arguments.

    The indexer goes through `_ReplaceSlicesWithTuples`, is supplied through
    `args_maker`, and is re-packed inside the function under test.
    """
    # TODO(rohanj): Revisit passing in self.rng() to this to customize further.
    # This would need updating lax_numpy_test as well.
    sample = rng_factory()
    flat_indexer, rebuild_indexer = self._ReplaceSlicesWithTuples(indexer)

    def onp_fun(x, unpacked_indexer):
      return x[rebuild_indexer(unpacked_indexer)]

    def jnp_fun(x, idx):
      return onp_fun(jnp.asarray(x), idx)

    def args_maker():
      return [sample(shape, dtype), flat_indexer]

    self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(
        jnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True)

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_indexer={}"
       .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer}
      for name, index_specs in ADVANCED_INDEXING_TESTS
      for shape, indexer in index_specs
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default])
  def testAdvancedIntegerIndexing(self, shape, dtype, rng_factory, indexer):
    """Indexes a random array with `indexer` supplied as a function argument."""
    sample = rng_factory()

    def args_maker():
      return [sample(shape, dtype), indexer]

    def onp_fun(x, idx):
      return x[idx]

    def jnp_fun(x, idx):
      return onp_fun(jnp.asarray(x), idx)

    self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(
        jnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True)

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_indexer={}"
       .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer}
      for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS
      for shape, indexer in index_specs
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default])
  def testMixedAdvancedIntegerIndexing(self, shape, dtype, rng_factory, indexer):
    """Indexes with a mix of ndarray and non-ndarray indexer elements.

    Only the ndarray elements are passed as arguments; every other element is
    replaced by an empty-tuple placeholder and substituted back (via `subvals`)
    inside the function under test.
    """
    sample = rng_factory()

    # One pass builds both the placeholder list and the (position, element)
    # pairs needed to restore the non-ndarray entries.
    placeholders = []
    substitutes = []
    for pos, elt in enumerate(indexer):
      if isinstance(elt, onp.ndarray):
        placeholders.append(elt)
      else:
        placeholders.append(())
        substitutes.append((pos, elt))

    def args_maker():
      return [sample(shape, dtype), placeholders]

    def np_fun(x, indexer_with_dummies):
      restored = subvals(indexer_with_dummies, substitutes)
      return x[type(indexer)(restored)]

    def jnp_fun(x, idx):
      return np_fun(jnp.asarray(x), idx)

    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(
        jnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True)

  def testAdvancedIndexingManually(self):
    """Compares three hand-written indexing expressions with and without `npe.jit`."""
    x = onp.random.RandomState(0).randn(3, 4, 5)
    index_array = onp.array([0, 2, -1, 0])

    indexing_fns = (
        lambda x, index_array: x[..., index_array, :],
        lambda x, index_array: x[..., index_array, :, index_array, None],
        lambda x, index_array: x[index_array, ..., index_array[:, None], None],
    )

    for op in indexing_fns:
      cop = npe.jit(op)
      plain_result = op(x, index_array)
      jitted_result = cop(x, index_array)
      self.assertAllClose(plain_result, jitted_result, check_dtypes=True)

  # Note that __iter__ is not currently allowed in graph mode, so this test only
  # iterates over an eager tensor.
  def testUnpacking(self):
    """Unpacks a length-3 array into three names and sums them."""

    def sum_of_three(x):
      first, second, third = x
      return first + second + third

    from_numpy = sum_of_three(onp.arange(3))
    from_jnp = sum_of_three(jnp.arange(3))

    self.assertAllClose(from_numpy, from_jnp, check_dtypes=True)

  def testBooleanIndexingArray1D(self):
    """Indexes a length-3 array with a boolean ndarray mask."""
    mask = onp.array([True, True, False])
    wrapped = jnp.asarray(onp.arange(3))
    self.assertAllClose(
        wrapped[mask], onp.arange(3)[mask], check_dtypes=False)

  def testBooleanIndexingList1D(self):
    """Indexes a length-3 array with a Python list of booleans."""
    mask = [True, True, False]
    wrapped = jnp.asarray(onp.arange(3))
    self.assertAllClose(
        wrapped[mask], onp.arange(3)[mask], check_dtypes=False)

  def testBooleanIndexingArray2DBroadcast(self):
    """Indexes a (4, 2) array with a length-4 boolean ndarray mask."""
    row_mask = onp.array([True, True, False, True])
    table = onp.arange(8).reshape(4, 2)
    self.assertAllClose(
        jnp.asarray(table)[row_mask], table[row_mask], check_dtypes=False)

  def testBooleanIndexingList2DBroadcast(self):
    """Indexes a (4, 2) array with a length-4 Python list of booleans."""
    row_mask = [True, True, False, True]
    table = onp.arange(8).reshape(4, 2)
    self.assertAllClose(
        jnp.asarray(table)[row_mask], table[row_mask], check_dtypes=False)

  def testBooleanIndexingArray2D(self):
    """Indexes a (4, 2) array with a boolean mask of the same shape."""
    table = onp.arange(8).reshape(4, 2)
    mask = onp.array([
        [True, False],
        [False, True],
        [False, False],
        [True, True],
    ])
    self.assertAllClose(
        jnp.asarray(table)[mask], table[mask], check_dtypes=False)

  def testBooleanIndexingDynamicShape(self):
    """Indexes a length-3 zero array with a boolean ndarray mask.

    `ans` is the result from the `jnp.asarray` path and `expected` is the
    plain NumPy result, matching the convention of the other boolean-indexing
    tests in this class.
    """
    x = onp.zeros(3)
    i = onp.array([True, True, False])
    ans = jnp.asarray(x)[i]
    expected = x[i]
    self.assertAllClose(ans, expected, check_dtypes=True)

  def testIssue187(self):
    """Indexes a (5, 5) array with two integer lists, eagerly and under `npe.jit`."""

    def pick(x):
      return x[[0, 2, 4], [0, 2, 4]]

    pick(jnp.ones((5, 5)))  # doesn't crash

    grid = onp.arange(25).reshape((5, 5))
    ans = npe.jit(pick)(grid)
    expected = pick(grid)
    self.assertAllClose(ans, expected, check_dtypes=False)

  # TODO(agarwal): Fix this use case.
  @jtu.disable
  def testIndexingEmptyDimension(self):
    """Checks the IndexError raised when indexing into an axis of size 0."""
    # Issue 2671: XLA error when indexing into dimension of size 0
    x = jnp.ones((2, 0))
    # The following work, even on axis 1 of size 0
    _ = x[0, :] + x[0, None] + x[0, 1:] + x[0, 1:3:2]

    numpy_message = "index .* is out of bounds for axis .* with size 0"
    jax_message = "index is out of bounds for axis .* with size 0"

    with self.assertRaisesRegex(IndexError, numpy_message):
      _ = onp.ones((2, 0))[0, 0]  # The numpy error
    with self.assertRaisesRegex(IndexError, jax_message):
      _ = x[0, 0]  # JAX indexing
    with self.assertRaisesRegex(IndexError, jax_message):
      npe.jit(lambda i: x[0, i])(0)  # JAX indexing under jit

  def testBooleanIndexingWithEmptyResult(self):
    """Indexes a one-element array with an all-False mask."""
    # based on a TensorFlow Probability test that started failing after #1623
    values = jnp.array([-1])
    all_false = jnp.array([False])
    ans = values[all_false]  # doesn't crash

    expected = onp.array([-1])[onp.array([False])]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testFloatIndexingError(self):
    """Checks that float indices raise IndexError with the expected message."""
    error_regex = "only integers, slices.*are valid indices"

    def expect_index_error():
      return self.assertRaisesRegex(IndexError, error_regex)

    # Verify onp behavior
    with expect_index_error():
      _ = onp.zeros((2, 2))[(0, 0.)]
    # Test jnp
    with expect_index_error():
      jnp.zeros(2)[0.]
    with expect_index_error():
      jnp.zeros((2, 2))[(0, 0.)]
    # Test with jit
    with expect_index_error():
      npe.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.))

  def testIndexOutOfBounds(self):  # https://github.com/google/jax/issues/2245
    """Compares a length-5 array with its `[:10]` slice."""
    ones = jnp.ones(5)
    overlong_slice = ones[:10]
    self.assertAllClose(ones, overlong_slice, check_dtypes=True)
예제 #2
0
class IndexedUpdateTest(jtu.TestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {  # pylint: disable=g-complex-comprehension
                "testcase_name":
                "_{}_{}_{}_{}".format(
                    jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng_factory":
                rng_factory,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op
            } for name, index_specs in STATIC_INDEXING_TESTS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in (
                all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in all_dtypes
            for rng_factory in [jtu.rand_default]))
    def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
                           rng_factory, indexer, op):
        """Applies update `op` at a static `indexer` via both UpdateOps paths.

        `UpdateOps.np_fn` and `UpdateOps.tfnp_fn` are compared through
        `_CheckAgainstNumpy`; the latter is also passed to `_CompileAndCheck`.
        """
        sample = rng_factory()

        def args_maker():
            return [sample(shape, dtype), sample(update_shape, update_dtype)]

        def np_fn(x, y):
            return UpdateOps.np_fn(op, indexer, x, y)

        def tfnp_fn(x, y):
            return UpdateOps.tfnp_fn(op, indexer, x, y)

        self._CheckAgainstNumpy(np_fn, tfnp_fn, args_maker)

        # TODO(wangpeng): When indexer is slice(_, 8, -1), XLA throws error "Missing
        # xla_context 0-th output from". Investigate.
        hits_xla_context_error = (
            isinstance(indexer, slice) and indexer.stop == 8
            and indexer.step == -1)
        # Non-trivial strides are excluded as well: b/123559667
        check_xla = not (
            has_non_trivial_stride(indexer) or hits_xla_context_error)

        self._CompileAndCheck(
            tfnp_fn,
            args_maker,
            check_incomplete_shape=True,
            check_experimental_compile=check_xla,
            check_xla_forced_compile=check_xla)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {  # pylint: disable=g-complex-comprehension
                "testcase_name":
                "_{}_{}_{}_{}".format(
                    jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng_factory":
                rng_factory,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op
            } for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in (
                all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in all_dtypes
            for rng_factory in [jtu.rand_default]))
    def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                             rng_factory, indexer, op):
        """Applies update `op` at `indexer` via both UpdateOps paths."""
        sample = rng_factory()

        def args_maker():
            return [sample(shape, dtype), sample(update_shape, update_dtype)]

        def np_fn(x, y):
            return UpdateOps.np_fn(op, indexer, x, y)

        def tfnp_fn(x, y):
            return UpdateOps.tfnp_fn(op, indexer, x, y)

        self._CheckAgainstNumpy(np_fn, tfnp_fn, args_maker)
        self._CompileAndCheck(
            tfnp_fn, args_maker, check_incomplete_shape=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {  # pylint: disable=g-complex-comprehension
                "testcase_name":
                "_{}_{}_{}_{}".format(
                    jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng_factory":
                rng_factory,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op
            } for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in (
                all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in all_dtypes
            for rng_factory in [jtu.rand_default]))
    def testMixedAdvancedIndexing(self, shape, dtype, update_shape,
                                  update_dtype, rng_factory, indexer, op):
        """Applies update `op` at `indexer` via both UpdateOps paths.

        The two XLA-related checks of `_CompileAndCheck` are switched off when
        `has_non_trivial_stride(indexer)` is true.
        """
        sample = rng_factory()

        def args_maker():
            return [sample(shape, dtype), sample(update_shape, update_dtype)]

        def np_fn(x, y):
            return UpdateOps.np_fn(op, indexer, x, y)

        def tfnp_fn(x, y):
            return UpdateOps.tfnp_fn(op, indexer, x, y)

        self._CheckAgainstNumpy(np_fn, tfnp_fn, args_maker)

        check_xla = not has_non_trivial_stride(indexer)  # b/123559667
        self._CompileAndCheck(
            tfnp_fn,
            args_maker,
            check_incomplete_shape=True,
            check_experimental_compile=check_xla,
            check_xla_forced_compile=check_xla)

    @parameterized.named_parameters(
        jtu.cases_from_list({  # pylint: disable=g-complex-comprehension
            "testcase_name":
            "_{}_{}_{}_{}".format(
                jtu.format_shape_dtype_string(shape, dtype), indexer,
                jtu.format_shape_dtype_string(update_shape, update_dtype),
                op.name),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng_factory":
            rng_factory,
            "indexer":
            indexer,
            "update_shape":
            update_shape,
            "update_dtype":
            update_dtype,
            "op":
            op
        } for name, index_specs in STATIC_INDEXING_TESTS
                            for shape, indexer in index_specs
                            for op in [UpdateOps.ADD, UpdateOps.UPDATE]
                            for dtype in float_dtypes
                            for update_shape in _broadcastable_shapes(
                                _update_shape(shape, indexer))
                            for update_dtype in float_dtypes
                            for rng_factory in [jtu.rand_default]))
    def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
                                rng_factory, indexer, op):
        """Runs `self.check_grads` on `UpdateOps.tfnp_fn` at random inputs."""
        sample = rng_factory()

        def tfnp_fn(x, y):
            return UpdateOps.tfnp_fn(op, indexer, x, y)

        operand = sample(shape, dtype)
        update = sample(update_shape, update_dtype)
        self.check_grads(
            tfnp_fn, (operand, update), rtol=1e-3, atol=1e-3, delta=1.)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_start_indices={}_update_shape={}".format(  # pylint: disable=g-complex-comprehension
                jtu.format_shape_dtype_string(shape, dtype), start_indices,
                update_shape),
            "shape":
            shape,
            "dtype":
            dtype,
            "start_indices":
            start_indices,
            "update_shape":
            update_shape,
            "rng_factory":
            rng_factory
        } for shape, start_indices, update_shape in [
            [(3, ), (1, ), (1, )],
            [(5, 3), (1, 1), (3, 1)],
            [(5, 3), (1, -2), (3, 1)],
            [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
            [(), (), ()],
        ] for dtype in default_dtypes for rng_factory in [jtu.rand_default]))
    def testDynamicUpdateSlice(self, shape, dtype, start_indices, update_shape,
                               rng_factory):
        """Passes `npe.dynamic_update_slice` to `_CompileAndCheck`."""
        sample = rng_factory()

        def args_maker():
            operand = sample(shape, dtype)
            update = sample(update_shape, dtype)
            return [operand, update, onp.array(start_indices)]

        # update's shape must be fully known.
        # TODO(wangpeng): Support turning off check_incomplete_shape for individual
        #   arguments.
        self._CompileAndCheck(
            npe.dynamic_update_slice, args_maker, check_incomplete_shape=False)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_start_indices={}_update_shape={}".format(  # pylint: disable=g-complex-comprehension
                jtu.format_shape_dtype_string(shape, dtype), start_indices,
                update_shape),
            "shape":
            shape,
            "dtype":
            dtype,
            "start_indices":
            start_indices,
            "update_shape":
            update_shape,
            "rng_factory":
            rng_factory
        } for shape, start_indices, update_shape in [
            [(3, ), (1, ), (1, )],
            [(5, 3), (1, 1), (3, 1)],
            [(5, 3), (1, -2), (3, 1)],
            [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
            [(), (), ()],
        ] for dtype in default_dtypes for rng_factory in [jtu.rand_default]))
    def testDynamicUpdateSliceAgainstNumpy(self, shape, dtype, start_indices,
                                           update_shape, rng_factory):
        """Compares `npe.dynamic_update_slice` with `dynamic_update_slice_reference`."""
        sample = rng_factory()

        def args_maker():
            operand = sample(shape, dtype)
            update = sample(update_shape, dtype)
            return [operand, update, onp.array(start_indices)]

        self._CheckAgainstNumpy(
            dynamic_update_slice_reference, npe.dynamic_update_slice,
            args_maker)

    def testDynamicUpdateSliceInDim(self):
        """Compares `npe.dynamic_update_slice_in_dim` with a NumPy slice assignment."""
        sample = jtu.rand_default()
        operand = sample((6, 7), onp.int32)
        update = sample((3, 7), onp.int32)

        # Reference result: overwrite rows 2..4 of a copy of the operand.
        expected = operand.copy()
        expected[2:5] = update

        actual = npe.dynamic_update_slice_in_dim(operand, update, 2, 0)
        self.assertAllClose(actual, expected, check_dtypes=True)
예제 #3
0
class IndexedUpdateTest(jtu.TestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "{}_inshape={}_indexer={}_update={}_sugared={}_op={}".format(
                    name, jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    sugared, op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng_factory":
                rng_factory,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op,
                "sugared":
                sugared
            } for name, index_specs in STATIC_INDEXING_TESTS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in (
                all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in (
                    [dtype] if op == UpdateOps.ADD else all_dtypes)
            for sugared in [True, False]
            for rng_factory in [jtu.rand_default]))
    @jtu.disable
    def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
                           rng_factory, indexer, sugared, op):
        """Applies update `op` at a static `indexer` and compares with `UpdateOps.onp_fn`.

        `sugared` selects `UpdateOps.sugar_fn` over `UpdateOps.jax_fn` as the
        function under test.
        """
        sample = rng_factory()

        def args_maker():
            return [sample(shape, dtype), sample(update_shape, update_dtype)]

        def onp_fn(x, y):
            return UpdateOps.onp_fn(op, indexer, x, y)

        def jax_fn(x, y):
            if sugared:
                return UpdateOps.sugar_fn(op, indexer, x, y)
            return UpdateOps.jax_fn(op, indexer, x, y)

        self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
        self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "{}_inshape={}_indexer={}_update={}_sugared={}_op={}".format(
                    name, jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    sugared, op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng_factory":
                rng_factory,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op,
                "sugared":
                sugared
            } for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in (
                all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in (
                    [dtype] if op == UpdateOps.ADD else all_dtypes)
            for sugared in [True, False]
            for rng_factory in [jtu.rand_default]))
    @jtu.disable
    def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                             rng_factory, indexer, sugared, op):
        """Applies update `op` at `indexer` and compares with `UpdateOps.onp_fn`.

        `sugared` selects `UpdateOps.sugar_fn` over `UpdateOps.jax_fn` as the
        function under test.
        """
        sample = rng_factory()

        def args_maker():
            return [sample(shape, dtype), sample(update_shape, update_dtype)]

        def onp_fn(x, y):
            return UpdateOps.onp_fn(op, indexer, x, y)

        def jax_fn(x, y):
            if sugared:
                return UpdateOps.sugar_fn(op, indexer, x, y)
            return UpdateOps.jax_fn(op, indexer, x, y)

        self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
        self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                # `sugared` is part of the name so that the True and False
                # variants of an otherwise identical case get distinct test
                # names.
                "testcase_name":
                "{}_inshape={}_indexer={}_update={}_sugared={}_op={}".format(
                    name, jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    sugared, op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng_factory":
                rng_factory,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op,
                "sugared":
                sugared
            } for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in (
                all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in (
                    [dtype] if op == UpdateOps.ADD else all_dtypes)
            for sugared in [True, False]
            for rng_factory in [jtu.rand_default]))
    @jtu.disable
    def testMixedAdvancedIndexing(self, shape, dtype, update_shape,
                                  update_dtype, rng_factory, indexer, sugared,
                                  op):
        """Applies update `op` at `indexer` and compares with `UpdateOps.onp_fn`.

        `sugared` selects `UpdateOps.sugar_fn` over `UpdateOps.jax_fn` as the
        function under test.
        """
        rng = rng_factory()
        args_maker = lambda: [
            rng(shape, dtype),
            rng(update_shape, update_dtype)
        ]
        onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y)
        if sugared:
            jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
        else:
            jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
        self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
        self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "{}_inshape={}_indexer={}_update={}_op={}".format(
                    name, jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng_factory":
                rng_factory,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op
            } for name, index_specs in STATIC_INDEXING_TESTS
            for shape, indexer in index_specs
            for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE]
            for dtype in float_dtypes
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in (
                    [dtype] if op == UpdateOps.ADD else float_dtypes)
            for rng_factory in [jtu.rand_default]))
    @jtu.skip_on_devices("tpu")  # TODO(mattjj,phawkins): tpu issues
    @jtu.disable
    def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
                                rng_factory, indexer, op):
        """Runs `check_grads` (order 2) on `UpdateOps.jax_fn` at random inputs."""
        sample = rng_factory()

        def jax_fn(x, y):
            return UpdateOps.jax_fn(op, indexer, x, y)

        operand = sample(shape, dtype)
        update = sample(update_shape, update_dtype)
        check_grads(
            jax_fn, (operand, update), 2, rtol=1e-3, atol=1e-3, eps=1.)

    @jtu.disable
    def testSegmentSumBehavior(self):
        """Uses `ops.index_add` with repeated indices to accumulate per-segment sums."""
        # testAdvancedIndexing compares against NumPy, and as a result doesn't check
        # repeated indices. This test is just a simple manual check, based on
        # https://www.tensorflow.org/api_docs/python/tf/math/segment_sum
        data = onp.array([5, 1, 7, 2, 3, 4, 1, 3])
        segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3])

        # One accumulator slot per segment id.
        accumulator = onp.zeros(onp.max(segment_ids) + 1)
        ans = ops.index_add(accumulator, segment_ids, data)

        expected = onp.array([13, 2, 7, 4])
        self.assertAllClose(ans, expected, check_dtypes=False)

    @jtu.disable
    def testSegmentSum(self):
        """Calls `ops.segment_sum` with and without an explicit `num_segments`."""
        data = onp.array([5, 1, 7, 2, 3, 4, 1, 3])
        segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3])
        expected = onp.array([13, 2, 7, 4])

        # test with explicit num_segments
        with_count = ops.segment_sum(data, segment_ids, num_segments=4)
        self.assertAllClose(with_count, expected, check_dtypes=False)

        # test without explicit num_segments
        without_count = ops.segment_sum(data, segment_ids)
        self.assertAllClose(without_count, expected, check_dtypes=False)

    @jtu.disable
    def testIndexDtypeError(self):
        """Checks that a strided `.at[...].set` emits no warning."""
        # https://github.com/google/jax/issues/2795
        jnp.array(1)  # get rid of startup warning
        with warnings.catch_warnings(record=True) as caught:
            # Turn any warning raised below into an exception.
            warnings.simplefilter("error")
            zeros = jnp.zeros(5)
            zeros.at[::2].set(1)
            self.assertLen(caught, 0)
예제 #4
0
class IndexingTest(jtu.TestCase):
    """Tests for Numpy indexing translation rules."""
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "{}_inshape={}_indexer={}".format(
                name, jtu.format_shape_dtype_string(shape, dtype), indexer),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng_factory":
            rng_factory,
            "indexer":
            indexer
        } for name, index_specs in STATIC_INDEXING_TESTS
                            for shape, indexer in index_specs
                            for dtype in all_dtypes
                            for rng_factory in [jtu.rand_default]))
    def testStaticIndexing(self, shape, dtype, rng_factory, indexer):
        """Runs `_CompileAndCheck` on `x[indexer]` for a random input array."""
        # TODO(rohanj): Revisit passing in self.rng() to this to customize further.
        # This would need updating lax_numpy_test as well.
        sample = rng_factory()

        def index_fn(x):
            return x[indexer]

        def make_args():
            return [sample(shape, dtype)]

        self._CompileAndCheck(index_fn, make_args, check_dtypes=True)

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng_factory":
        rng_factory,
        "indexer":
        indexer
    } for name, index_specs in STATIC_INDEXING_GRAD_TESTS
                                    for shape, indexer in index_specs
                                    for dtype in float_dtypes
                                    for rng_factory in [jtu.rand_default])
    @jtu.disable
    def testStaticIndexingGrads(self, shape, dtype, rng_factory, indexer):
        """Runs `check_grads` on `x[indexer]**2` for a random input array."""
        sample = rng_factory()
        # Only 32-bit float dtypes get an explicit tolerance.
        if jnp.finfo(dtype).bits == 32:
            tol = 1e-2
        else:
            tol = None
        operand = sample(shape, dtype)

        def squared_index(x):
            return x[indexer]**2

        check_grads(squared_index, (operand, ), 2, tol, tol, tol)

    def _ReplaceSlicesWithTuples(self, idx):
        """Splits an indexer into slice-free values plus a rebuilding function.

        Args:
          idx: the indexer to decompose. Slices and non-empty tuples/lists are
            decomposed; any other value is passed through.

        Returns:
          A pair `(unpacked, pack)` such that `pack(unpacked)` rebuilds the
          indexer. For a slice, `unpacked` is the result of `subvals` on the
          (start, stop, step) triple with 0 paired to every position that held
          None, and `pack` pairs None with those same positions before
          constructing the slice. For a non-empty tuple or list, every element
          is decomposed recursively and `pack` rebuilds a container of the
          original type. Otherwise `unpacked` is `idx` itself and `pack` is
          the identity.
        """
        if isinstance(idx, slice):
            fields = (idx.start, idx.stop, idx.step)
            none_positions = [
                pos for pos, value in enumerate(fields) if value is None
            ]
            unpacked = subvals(fields,
                               zip(none_positions, itertools.repeat(0)))

            def pack_slice(values):
                restored = subvals(
                    values, zip(none_positions, itertools.repeat(None)))
                return slice(*restored)

            return unpacked, pack_slice

        if isinstance(idx, (tuple, list)) and idx:
            container = type(idx)
            decomposed = [self._ReplaceSlicesWithTuples(elt) for elt in idx]
            unpacked, packers = zip(*decomposed)

            def pack_sequence(values):
                return container(
                    (packer(value) for packer, value in zip(packers, values)))

            return unpacked, pack_sequence

        return idx, lambda x: x

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng_factory":
        rng_factory,
        "indexer":
        indexer
    } for name, index_specs in [
        ("OneSliceIndex", [
            IndexSpec(shape=(5, ), indexer=slice(1, 3)),
            IndexSpec(shape=(5, 4), indexer=slice(1, 3))
        ]),
        ("TwoSliceIndices", [
            IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))),
            IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2)))
        ]),
        ("NonUnitStrides", [
            IndexSpec(shape=(3, ), indexer=slice(None, None, -1)),
            IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)),
            IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2))
        ]),
        ("OnlyStartOrStopDynamic", [
            IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))),
            IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None)))
        ]),
    ] for shape, indexer in index_specs for dtype in all_dtypes
                                    for rng_factory in [jtu.rand_default])
    @jtu.disable
    def testDynamicIndexingWithSlicesErrors(self, shape, dtype, rng_factory,
                                            indexer):
        """Expects IndexError when slice fields are passed into a jitted function.

        The slice indexer is decomposed by `_ReplaceSlicesWithTuples`, passed
        as an argument to a function wrapped with `npe.jit`, and rebuilt
        inside it.
        """
        sample = rng_factory()
        unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

        def index_with(x, unpacked):
            return x[pack_indexer(unpacked)]

        jitted = npe.jit(index_with)

        def call_jitted():
            return jitted(sample(shape, dtype), unpacked_indexer)

        self.assertRaises(IndexError, call_jitted)

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng_factory":
        rng_factory,
        "indexer":
        indexer
    } for name, index_specs in [
        ("OneIntIndex", [
            IndexSpec(shape=(3, ), indexer=1),
            IndexSpec(shape=(3, 3), indexer=0),
            IndexSpec(shape=(3, 4, 5), indexer=2),
            IndexSpec(shape=(3, ), indexer=-1),
            IndexSpec(shape=(3, ), indexer=-2)
        ]),
        ("TwoIntIndices", [
            IndexSpec(shape=(3, 3), indexer=(2, 1)),
            IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
            IndexSpec(shape=(3, 4, 5), indexer=(-1, 2))
        ]),
        ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
    ] for shape, indexer in index_specs for dtype in all_dtypes
                                    for rng_factory in [jtu.rand_default])
    def testDynamicIndexingWithIntegers(self, shape, dtype, rng_factory,
                                        indexer):
        """Runs `_CompileAndCheck` on indexing whose indices are passed as args."""
        # TODO(rohanj): Revisit passing in self.rng() to this to customize further.
        # This would need updating lax_numpy_test as well.
        sample = rng_factory()
        unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

        def index_with(x, unpacked):
            return x[pack_indexer(unpacked)]

        def make_args():
            return [sample(shape, dtype), unpacked_indexer]

        self._CompileAndCheck(index_with, make_args, check_dtypes=True)

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng_factory":
        rng_factory,
        "indexer":
        indexer
    } for name, index_specs in [
        ("OneIntIndex", [
            IndexSpec(shape=(3, ), indexer=1),
            IndexSpec(shape=(3, 3), indexer=0),
            IndexSpec(shape=(3, 4, 5), indexer=2),
            IndexSpec(shape=(3, ), indexer=-1),
            IndexSpec(shape=(3, ), indexer=-2),
        ]),
        ("TwoIntIndices", [
            IndexSpec(shape=(3, 3), indexer=(2, 1)),
            IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
            IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)),
        ]),
        ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
    ] for shape, indexer in index_specs for dtype in float_dtypes
                                    for rng_factory in [jtu.rand_default])
    @jtu.disable
    def testDynamicIndexingWithIntegersGrads(self, shape, dtype, rng_factory,
                                             indexer):
        """Runs `check_grads` on jitted indexing with integer indices.

        The unpacked indexer is bound as the first argument with `partial`, so
        `check_grads` only receives the array as an argument.
        """
        # Call the factory with no arguments, as every other test in this
        # class does (the TODO in testStaticIndexing notes that self.rng() is
        # not passed in yet); this test previously passed self.rng().
        rng = rng_factory()
        # Only 32-bit float dtypes get an explicit tolerance.
        tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None
        unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

        @npe.jit
        def fun(unpacked_indexer, x):
            indexer = pack_indexer(unpacked_indexer)
            return x[indexer]

        arr = rng(shape, dtype)
        check_grads(partial(fun, unpacked_indexer), (arr, ), 2, tol, tol, tol)

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng_factory":
        rng_factory,
        "indexer":
        indexer
    } for name, index_specs in ADVANCED_INDEXING_TESTS
                                    for shape, indexer in index_specs
                                    for dtype in all_dtypes
                                    for rng_factory in [jtu.rand_default])
    @jtu.disable
    def testAdvancedIntegerIndexing(self, shape, dtype, rng_factory, indexer):
        """Runs `_CompileAndCheck` on `jnp.asarray(x)[idx]`, with idx as an arg."""
        sample = rng_factory()

        def index_fn(x, idx):
            return jnp.asarray(x)[idx]

        def make_args():
            return [sample(shape, dtype), indexer]

        self._CompileAndCheck(index_fn, make_args, check_dtypes=True)

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng_factory":
        rng_factory,
        "indexer":
        indexer
    } for name, index_specs in [
        ("One1DIntArrayIndex", [
            IndexSpec(shape=(3, ), indexer=onp.array([0, 1])),
            IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 1])),
            IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 0, 1])),
            IndexSpec(shape=(3, ), indexer=onp.array([-1, 1])),
            IndexSpec(shape=(3, ), indexer=onp.array([-2, -1])),
        ]),
        ("One2DIntArrayIndex", [
            IndexSpec(shape=(3, ), indexer=onp.array([[0, 0]])),
            IndexSpec(shape=(3,
                             3), indexer=onp.array([[1, 2, 1], [0, 1, -1]])),
            IndexSpec(shape=(3, 4, 5),
                      indexer=onp.array([[0, 2, 0, 1], [-1, -2, 1, 0]])),
        ]),
        ("Two1DIntArrayIndicesNoBroadcasting", [
            IndexSpec(shape=(3, 3),
                      indexer=[onp.array([0, 1]),
                               onp.array([1, 2])]),
            IndexSpec(
                shape=(3,
                       4, 5),
                indexer=[onp.array([0, 2, 0, 1]),
                         onp.array([-1, 0, -1, 2])]),
        ]),
        ("Two1DIntArrayIndicesWithBroadcasting", [
            IndexSpec(shape=(3, 3),
                      indexer=[onp.array([[0, 1]]),
                               onp.array([1, 2])]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[
                          onp.array([[0, 2, 0, 1]]),
                          onp.array([-1, 0, -1, 2])
                      ]),
        ]),
        ("ListOfPythonInts", [
            IndexSpec(shape=(3, ), indexer=[0, 1, 0]),
            IndexSpec(shape=(3, 4, 5), indexer=[0, -1]),
        ]),
        ("ListOfListsOfPythonInts", [
            IndexSpec(shape=(3, 4, 5), indexer=[[0, 1]]),
            IndexSpec(shape=(3, 4, 5), indexer=[[[0], [-1]], [[2, 3, 0, 3]]]),
        ]),
        ("ListOfPythonIntsAndIntArrays", [
            IndexSpec(shape=(3, 4, 5), indexer=[0, onp.array([0, 1])]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[0, 1, onp.array([[2, 3, 0, 3]])]),
        ]),
        ("ListOfListsOfPythonIntsAndIntArrays", [
            IndexSpec(shape=(3, 4, 5), indexer=[[0, 1], onp.array([0])]),
            IndexSpec(shape=(3, 4, 5),
                      indexer=[[[0], [-1]],
                               onp.array([[2, 3, 0, 3]])]),
        ]),
    ] for shape, indexer in index_specs for dtype in float_dtypes
                                    for rng_factory in [jtu.rand_default])
    @jtu.disable
    def testAdvancedIntegerIndexingGrads(self, shape, dtype, rng_factory,
                                         indexer):
        """Runs `check_grads` on `jnp.asarray(x)[indexer]` for a random array."""
        sample = rng_factory()
        # Only 32-bit float dtypes get an explicit tolerance.
        if jnp.finfo(dtype).bits == 32:
            tol = 1e-2
        else:
            tol = None
        operand = sample(shape, dtype)

        def index_fn(x):
            return jnp.asarray(x)[indexer]

        check_grads(index_fn, (operand, ), 2, tol, tol, eps=1.)

    @parameterized.named_parameters({
        "testcase_name":
        "{}_inshape={}_indexer={}".format(
            name, jtu.format_shape_dtype_string(shape, dtype), indexer),
        "shape":
        shape,
        "dtype":
        dtype,
        "rng_factory":
        rng_factory,
        "indexer":
        indexer
    } for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS
                                    for shape, indexer in index_specs
                                    for dtype in all_dtypes
                                    for rng_factory in [jtu.rand_default])
    @jtu.disable
    def testMixedAdvancedIntegerIndexing(self, shape, dtype, rng_factory,
                                         indexer):
        """Runs `_CompileAndCheck` on an indexer mixing arrays with other entries.

        Only the ndarray entries of the indexer are passed as arguments; every
        other entry is replaced by an empty-tuple placeholder in the argument
        and put back with `subvals` inside the checked function.
        """
        sample = rng_factory()

        placeholders = []
        substitutes = []
        for position, entry in enumerate(indexer):
            if isinstance(entry, onp.ndarray):
                placeholders.append(entry)
            else:
                placeholders.append(())
                substitutes.append((position, entry))

        def make_args():
            return [sample(shape, dtype), placeholders]

        def index_fn(x, indexer_with_dummies):
            rebuilt = subvals(indexer_with_dummies, substitutes)
            return jnp.asarray(x)[type(indexer)(rebuilt)]

        self._CompileAndCheck(index_fn, make_args, check_dtypes=True)

    @jtu.disable
    def testAdvancedIndexingManually(self):
        """Compares three hand-written index expressions with and without jit."""
        x = onp.random.RandomState(0).randn(3, 4, 5)
        index_array = onp.array([0, 2, -1, 0])

        index_fns = (
            lambda arr, idx: arr[..., idx, :],
            lambda arr, idx: arr[..., idx, :, idx, None],
            lambda arr, idx: arr[idx, ..., idx[:, None], None],
        )

        for op in index_fns:
            cop = npe.jit(op)
            direct = op(x, index_array)
            compiled = cop(x, index_array)
            self.assertAllClose(direct, compiled, check_dtypes=True)

    @jtu.disable
    def testUnpacking(self):
        """Checks that unpacking a length-3 array gives the same sum under jit."""
        def add_three(triple):
            first, second, third = triple
            return first + second + third

        jitted = npe.jit(add_three)

        direct = add_three(onp.arange(3))
        compiled = jitted(onp.arange(3))

        self.assertAllClose(direct, compiled, check_dtypes=True)

    @jtu.disable
    def testBooleanIndexingArray1D(self):
        """Boolean ndarray mask on a 1-D `device_put` array matches NumPy."""
        mask = onp.array([True, True, False])
        on_device = api.device_put(onp.arange(3))
        got = on_device[mask]
        want = onp.arange(3)[mask]
        self.assertAllClose(got, want, check_dtypes=False)

    @jtu.disable
    def testBooleanIndexingList1D(self):
        """Python-list boolean mask on a 1-D `device_put` array matches NumPy."""
        mask = [True, True, False]
        on_device = api.device_put(onp.arange(3))
        got = on_device[mask]
        want = onp.arange(3)[mask]
        self.assertAllClose(got, want, check_dtypes=False)

    @jtu.disable
    def testBooleanIndexingArray2DBroadcast(self):
        """1-D boolean ndarray mask on a (4, 2) array matches NumPy."""
        row_mask = onp.array([True, True, False, True])
        host = onp.arange(8).reshape(4, 2)
        got = api.device_put(host)[row_mask]
        want = host[row_mask]
        self.assertAllClose(got, want, check_dtypes=False)

    @jtu.disable
    def testBooleanIndexingList2DBroadcast(self):
        """1-D Python-list boolean mask on a (4, 2) array matches NumPy."""
        row_mask = [True, True, False, True]
        host = onp.arange(8).reshape(4, 2)
        got = api.device_put(host)[row_mask]
        want = host[row_mask]
        self.assertAllClose(got, want, check_dtypes=False)

    @jtu.disable
    def testBooleanIndexingArray2D(self):
        """(4, 2) boolean ndarray mask on a (4, 2) array matches NumPy."""
        mask = onp.array([
            [True, False],
            [False, True],
            [False, False],
            [True, True],
        ])
        host = onp.arange(8).reshape(4, 2)
        got = api.device_put(host)[mask]
        want = host[mask]
        self.assertAllClose(got, want, check_dtypes=False)

    @jtu.disable
    def testBooleanIndexingDynamicShapeError(self):
        """Expects IndexError when a boolean mask is an argument of a jitted fn."""
        operand = onp.zeros(3)
        mask = onp.array([True, True, False])
        jitted = npe.jit(lambda x, i: x[i])

        def call_jitted():
            return jitted(operand, mask)

        self.assertRaises(IndexError, call_jitted)

    @jtu.disable
    def testIssue187(self):
        """Indexing with two Python lists works eagerly and matches NumPy under jit."""
        def take_diagonal(arr):
            return arr[[0, 2, 4], [0, 2, 4]]

        take_diagonal(jnp.ones((5, 5)))  # doesn't crash

        host = onp.arange(25).reshape((5, 5))
        got = npe.jit(take_diagonal)(host)
        want = take_diagonal(host)
        self.assertAllClose(got, want, check_dtypes=False)

    @jtu.disable
    def testJVPOfGradOfIndexing(self):
        """Checks `api.jvp` of `api.grad` of an indexing sum.

        Should return a value, even though we didn't pass a symbolic zero as
        the index tangent.
        """
        operand = jnp.ones((3, 4), jnp.float32)
        index = jnp.ones((3, ), jnp.int32)

        def summed_rows(x, i):
            return jnp.sum(x[i])

        index_tangent = onp.zeros_like(index)
        primals, tangents = api.jvp(api.grad(summed_rows), (operand, index),
                                    (operand, index_tangent))

        column = onp.array([0, 3, 0], dtype=onp.float32)[:, None]
        want_primals = onp.broadcast_to(column, (3, 4))
        self.assertAllClose(want_primals, primals, check_dtypes=True)
        self.assertAllClose(onp.zeros_like(operand), tangents,
                            check_dtypes=True)

    @jtu.disable
    def testTrivialGatherIsntGenerated(self):
        """The jaxpr of `x[:, None]` has one equation and no 'gather' in its text.

        See https://github.com/google/jax/issues/1621
        """
        add_axis = lambda x: x[:, None]
        traced = api.make_jaxpr(add_axis)(onp.arange(4))
        num_eqns = len(traced.jaxpr.eqns)
        self.assertEqual(num_eqns, 1)
        self.assertNotIn('gather', str(traced))

    @jtu.disable
    def testIndexingEmptyDimension(self):
        """Checks indexing into a (2, 0) array, including out-of-bounds errors.

        Issue 2671: XLA error when indexing into dimension of size 0.
        """
        empty = jnp.ones((2, 0))
        # The following work, even on axis 1 of size 0
        _ = empty[0, :] + empty[0, None] + empty[0, 1:] + empty[0, 1:3:2]

        numpy_message = "index .* is out of bounds for axis .* with size 0"
        jax_message = "index is out of bounds for axis .* with size 0"

        # The numpy error
        with self.assertRaisesRegex(IndexError, numpy_message):
            _ = onp.ones((2, 0))[0, 0]
        # JAX indexing
        with self.assertRaisesRegex(IndexError, jax_message):
            _ = empty[0, 0]
        # JAX indexing under jit
        with self.assertRaisesRegex(IndexError, jax_message):
            npe.jit(lambda i: empty[0, i])(0)

    @jtu.disable
    def testBooleanIndexingWithEmptyResult(self):
        """An all-False boolean mask matches NumPy's result.

        Based on a TensorFlow Probability test that started failing after #1622.
        """
        values = jnp.array([-1])
        all_false = jnp.array([False])
        got = values[all_false]  # doesn't crash

        host_values = onp.array([-1])
        host_mask = onp.array([False])
        want = host_values[host_mask]
        self.assertAllClose(got, want, check_dtypes=False)

    @jtu.disable
    def testFloatIndexingError(self):
        """Expects TypeError with a fixed message when an index is a float.

        Covers plain indexing, a float inside a tuple index, the same tuple
        passed as an argument to a function wrapped with `npe.jit`, and the
        `ops.index_add` / `ops.index_update` entry points.
        """
        BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type"
        with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
            jnp.zeros(2)[0.]
        # This tuple-index case used to be asserted twice, verbatim; a single
        # copy checks the same thing.
        with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
            jnp.zeros((2, 2))[(0, 0.)]
        with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
            npe.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.))
        with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
            ops.index_add(jnp.zeros(2), 0., 1.)
        with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
            ops.index_update(jnp.zeros(2), 0., 1.)

    @jtu.disable
    def testIndexOutOfBounds(self):
        """Slicing a length-5 array with `[:10]` yields the array itself.

        See https://github.com/google/jax/issues/2245
        """
        ones = jnp.ones(5)
        overlong_slice = ones[:10]
        self.assertAllClose(ones, overlong_slice, check_dtypes=True)
예제 #5
0
class IndexedUpdateTest(jtu.TestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {  # pylint: disable=g-complex-comprehension
                "testcase_name":
                "_{}_{}_{}_{}".format(
                    jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng_factory":
                rng_factory,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op
            } for name, index_specs in STATIC_INDEXING_TESTS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in (
                all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in all_dtypes
            for rng_factory in [jtu.rand_default]))
    def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
                           rng_factory, indexer, op):
        """Checks an indexed update op at a static indexer against NumPy.

        `UpdateOps.tfnp_fn` is compared with `UpdateOps.np_fn` on random
        operand and update arrays, then run through `_CompileAndCheck`.
        """
        sample = rng_factory()

        def make_args():
            operand = sample(shape, dtype)
            updates = sample(update_shape, update_dtype)
            return [operand, updates]

        def np_fn(x, y):
            return UpdateOps.np_fn(op, indexer, x, y)

        def tfnp_fn(x, y):
            return UpdateOps.tfnp_fn(op, indexer, x, y)

        self._CheckAgainstNumpy(np_fn, tfnp_fn, make_args)
        self._CompileAndCheck(tfnp_fn, make_args, check_incomplete_shape=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {  # pylint: disable=g-complex-comprehension
                "testcase_name":
                "_{}_{}_{}_{}".format(
                    jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng_factory":
                rng_factory,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op
            } for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in (
                all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in all_dtypes
            for rng_factory in [jtu.rand_default]))
    def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                             rng_factory, indexer, op):
        """Checks an indexed update op at an advanced indexer against NumPy.

        `UpdateOps.tfnp_fn` is compared with `UpdateOps.np_fn` on random
        operand and update arrays, then run through `_CompileAndCheck`.
        """
        sample = rng_factory()

        def make_args():
            operand = sample(shape, dtype)
            updates = sample(update_shape, update_dtype)
            return [operand, updates]

        def np_fn(x, y):
            return UpdateOps.np_fn(op, indexer, x, y)

        def tfnp_fn(x, y):
            return UpdateOps.tfnp_fn(op, indexer, x, y)

        self._CheckAgainstNumpy(np_fn, tfnp_fn, make_args)
        self._CompileAndCheck(tfnp_fn, make_args, check_incomplete_shape=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {  # pylint: disable=g-complex-comprehension
                "testcase_name":
                "_{}_{}_{}_{}".format(
                    jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng_factory":
                rng_factory,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op
            } for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in (
                all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in all_dtypes
            for rng_factory in [jtu.rand_default]))
    def testMixedAdvancedIndexing(self, shape, dtype, update_shape,
                                  update_dtype, rng_factory, indexer, op):
        """Checks an indexed update op at a mixed indexer against NumPy.

        `UpdateOps.tfnp_fn` is compared with `UpdateOps.np_fn` on random
        operand and update arrays, then run through `_CompileAndCheck`.
        """
        sample = rng_factory()

        def make_args():
            operand = sample(shape, dtype)
            updates = sample(update_shape, update_dtype)
            return [operand, updates]

        def np_fn(x, y):
            return UpdateOps.np_fn(op, indexer, x, y)

        def tfnp_fn(x, y):
            return UpdateOps.tfnp_fn(op, indexer, x, y)

        self._CheckAgainstNumpy(np_fn, tfnp_fn, make_args)
        self._CompileAndCheck(tfnp_fn, make_args, check_incomplete_shape=True)

    @parameterized.named_parameters(
        jtu.cases_from_list({  # pylint: disable=g-complex-comprehension
            "testcase_name":
            "_{}_{}_{}_{}".format(
                jtu.format_shape_dtype_string(shape, dtype), indexer,
                jtu.format_shape_dtype_string(update_shape, update_dtype),
                op.name),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng_factory":
            rng_factory,
            "indexer":
            indexer,
            "update_shape":
            update_shape,
            "update_dtype":
            update_dtype,
            "op":
            op
        } for name, index_specs in STATIC_INDEXING_TESTS
                            for shape, indexer in index_specs
                            for op in UpdateOps for dtype in float_dtypes
                            for update_shape in _broadcastable_shapes(
                                _update_shape(shape, indexer))
                            for update_dtype in float_dtypes
                            for rng_factory in [jtu.rand_default]))
    def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
                                rng_factory, indexer, op):
        """Runs `self.check_grads` on an indexed update op at a static indexer.

        A random operand and a random update array are drawn, and both are
        passed as the arguments of `UpdateOps.tfnp_fn`.
        """
        sample = rng_factory()
        operand = sample(shape, dtype)
        updates = sample(update_shape, update_dtype)

        def tfnp_fn(x, y):
            return UpdateOps.tfnp_fn(op, indexer, x, y)

        self.check_grads(tfnp_fn, (operand, updates),
                         rtol=1e-3, atol=1e-3, delta=1.)
예제 #6
0
class IndexedUpdateTest(jtu.TestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {  # pylint: disable=g-complex-comprehension
                "testcase_name":
                "_{}_{}_{}_{}".format(
                    jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng_factory":
                rng_factory,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op
            } for name, index_specs in STATIC_INDEXING_TESTS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in (
                all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in all_dtypes
            for rng_factory in [jtu.rand_default]))
    def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
                           rng_factory, indexer, op):
        # Wraps `UpdateOps.np_fn` and `UpdateOps.tfnp_fn` for the given `op` and
        # `indexer`, hands both to `_CheckAgainstNumpy`, then passes the tfnp
        # wrapper to `_CompileAndCheck`.
        sample = rng_factory()

        def make_args():
            # Operand first, update second: one fresh pair per call.
            return [sample(shape, dtype), sample(update_shape, update_dtype)]

        def reference_update(x, y):
            return UpdateOps.np_fn(op, indexer, x, y)

        def tfnp_update(x, y):
            return UpdateOps.tfnp_fn(op, indexer, x, y)

        self._CheckAgainstNumpy(reference_update, tfnp_update, make_args)

        # The two XLA-related flags are switched off for:
        #   * indexers with a non-trivial stride (b/123559667);
        #   * slice(_, 8, -1) indexers.
        # TODO(wangpeng): When indexer is slice(_, 8, -1), XLA throws error
        # "Missing xla_context 0-th output from". Investigate.
        check_xla = not (
            has_non_trivial_stride(indexer)
            or (isinstance(indexer, slice) and indexer.stop == 8
                and indexer.step == -1))
        self._CompileAndCheck(
            tfnp_update,
            make_args,
            check_incomplete_shape=True,
            check_experimental_compile=check_xla,
            check_xla_forced_compile=check_xla)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {  # pylint: disable=g-complex-comprehension
                "testcase_name":
                "_{}_{}_{}_{}".format(
                    jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng_factory":
                rng_factory,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op
            } for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in (
                all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in all_dtypes
            for rng_factory in [jtu.rand_default]))
    def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                             rng_factory, indexer, op):
        # Wraps `UpdateOps.np_fn` and `UpdateOps.tfnp_fn` for the given `op` and
        # `indexer`, hands both to `_CheckAgainstNumpy`, then passes the tfnp
        # wrapper to `_CompileAndCheck`.
        sample = rng_factory()

        def make_args():
            # Operand first, update second: one fresh pair per call.
            return [sample(shape, dtype), sample(update_shape, update_dtype)]

        def reference_update(x, y):
            return UpdateOps.np_fn(op, indexer, x, y)

        def tfnp_update(x, y):
            return UpdateOps.tfnp_fn(op, indexer, x, y)

        self._CheckAgainstNumpy(reference_update, tfnp_update, make_args)
        self._CompileAndCheck(tfnp_update, make_args,
                              check_incomplete_shape=True)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {  # pylint: disable=g-complex-comprehension
                "testcase_name":
                "_{}_{}_{}_{}".format(
                    jtu.format_shape_dtype_string(shape, dtype), indexer,
                    jtu.format_shape_dtype_string(update_shape, update_dtype),
                    op.name),
                "shape":
                shape,
                "dtype":
                dtype,
                "rng_factory":
                rng_factory,
                "indexer":
                indexer,
                "update_shape":
                update_shape,
                "update_dtype":
                update_dtype,
                "op":
                op
            } for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS
            for shape, indexer in index_specs for op in UpdateOps
            for dtype in (
                all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)) for update_dtype in all_dtypes
            for rng_factory in [jtu.rand_default]))
    def testMixedAdvancedIndexing(self, shape, dtype, update_shape,
                                  update_dtype, rng_factory, indexer, op):
        # Wraps `UpdateOps.np_fn` and `UpdateOps.tfnp_fn` for the given `op` and
        # `indexer`, hands both to `_CheckAgainstNumpy`, then passes the tfnp
        # wrapper to `_CompileAndCheck`.
        sample = rng_factory()

        def make_args():
            # Operand first, update second: one fresh pair per call.
            return [sample(shape, dtype), sample(update_shape, update_dtype)]

        def reference_update(x, y):
            return UpdateOps.np_fn(op, indexer, x, y)

        def tfnp_update(x, y):
            return UpdateOps.tfnp_fn(op, indexer, x, y)

        self._CheckAgainstNumpy(reference_update, tfnp_update, make_args)

        # Indexers with a non-trivial stride switch off the
        # `check_experimental_compile` and `check_xla_forced_compile` flags
        # (b/123559667).
        strided = has_non_trivial_stride(indexer)
        self._CompileAndCheck(
            tfnp_update,
            make_args,
            check_incomplete_shape=True,
            check_experimental_compile=not strided,
            check_xla_forced_compile=not strided)

    @parameterized.named_parameters(
        jtu.cases_from_list({  # pylint: disable=g-complex-comprehension
            "testcase_name":
            "_{}_{}_{}_{}".format(
                jtu.format_shape_dtype_string(shape, dtype), indexer,
                jtu.format_shape_dtype_string(update_shape, update_dtype),
                op.name),
            "shape":
            shape,
            "dtype":
            dtype,
            "rng_factory":
            rng_factory,
            "indexer":
            indexer,
            "update_shape":
            update_shape,
            "update_dtype":
            update_dtype,
            "op":
            op
        } for name, index_specs in STATIC_INDEXING_TESTS
                            for shape, indexer in index_specs
                            for op in [UpdateOps.ADD, UpdateOps.UPDATE]
                            for dtype in float_dtypes
                            for update_shape in _broadcastable_shapes(
                                _update_shape(shape, indexer))
                            for update_dtype in float_dtypes
                            for rng_factory in [jtu.rand_default]))
    def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
                                rng_factory, indexer, op):
        # Hands `self.check_grads` the indexed-update function for `op` at
        # `indexer`, together with one sampled (operand, update) pair.
        def apply_update(x, y):
            return UpdateOps.tfnp_fn(op, indexer, x, y)

        sample = rng_factory()
        # The operand is drawn before the update so the generator is consumed
        # in a fixed order.
        operand = sample(shape, dtype)
        update = sample(update_shape, update_dtype)
        self.check_grads(apply_update, (operand, update),
                         rtol=1e-3,
                         atol=1e-3,
                         delta=1.)