class TfNpIndexingTest(jtu.TestCase):
  """Tests for Numpy indexing translation rules.

  Most tests index a NumPy array and the corresponding `jnp` array with the same
  indexer and compare the results via `_CheckAgainstNumpy`, then compile the
  `jnp` function via `_CompileAndCheck`.

  This class must not be named `IndexingTest`: another class with that name is
  defined later in this module, and rebinding the name would hide every test in
  this class from the test loader.
  """

  @parameterized.named_parameters(jtu.cases_from_list({
      "testcase_name": "{}_inshape={}_indexer={}".format(
          name, jtu.format_shape_dtype_string(shape, dtype), indexer),
      "shape": shape,
      "dtype": dtype,
      "rng_factory": rng_factory,
      "indexer": indexer
  } for name, index_specs in STATIC_INDEXING_TESTS
    for shape, indexer in index_specs
    for dtype in all_dtypes
    for rng_factory in [jtu.rand_default]))
  def testStaticIndexing(self, shape, dtype, rng_factory, indexer):
    """Indexing with a constant indexer matches NumPy and compiles."""
    # TODO(rohanj): Revisit passing in self.rng() to this to customize further.
    # This would need updating lax_numpy_test as well.
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype)]
    onp_fun = lambda x: x[indexer]
    jnp_fun = lambda x: jnp.asarray(x)[indexer]
    self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True,
                          check_incomplete_shape=True)

  def _ReplaceSlicesWithTuples(self, idx):
    """Helper method to replace slices with tuples for dynamic indexing args.

    Returns a pair `(unpacked, pack)`. `unpacked` mirrors `idx` but with every
    `slice` replaced by its `(start, stop, step)` triple, where `None` fields
    are substituted with 0. `pack` maps a value shaped like `unpacked` back to
    an indexer of the original structure, restoring `None` in the positions
    that held it. Non-slice, non-sequence leaves (and empty sequences) pass
    through unchanged.

    NOTE(review): relies on `subvals` (defined elsewhere) returning a copy of
    its first argument with the given `(position, value)` pairs substituted.
    """
    if isinstance(idx, slice):
      triple = idx.start, idx.stop, idx.step
      # Positions of the None fields, so `pack` can put None back there.
      isnone = [i for i, elt in enumerate(triple) if elt is None]
      zeros = itertools.repeat(0)
      nones = itertools.repeat(None)
      out = subvals(triple, zip(isnone, zeros))
      return out, lambda out: slice(*subvals(out, zip(isnone, nones)))
    elif isinstance(idx, (tuple, list)) and idx:
      t = type(idx)
      elts, packs = zip(*map(self._ReplaceSlicesWithTuples, idx))
      return elts, lambda elts: t((pack(i) for pack, i in zip(packs, elts)))
    else:
      return idx, lambda x: x

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_indexer={}"
       .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory,
       "indexer": indexer}
      for name, index_specs in [
          ("OneSliceIndex",
           [IndexSpec(shape=(5,), indexer=slice(1, 3)),
            IndexSpec(shape=(5, 4), indexer=slice(1, 3))]),
          ("TwoSliceIndices",
           [IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))),
            IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2)))]),
          ("NonUnitStrides", [
              IndexSpec(shape=(3,), indexer=slice(None, None, -1)),
              IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)),
              IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2))
          ]),
          ("OnlyStartOrStopDynamic", [
              IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))),
              IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None)))
          ]),
      ]
      for shape, indexer in index_specs
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default])
  def testDynamicIndexingWithSlices(self, shape, dtype, rng_factory, indexer):
    """Slice bounds passed as function arguments still index like NumPy."""
    rng = rng_factory()
    unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

    def onp_fun(x, unpacked_indexer):
      indexer = pack_indexer(unpacked_indexer)
      return x[indexer]

    jnp_fun = lambda x, idx: onp_fun(jnp.asarray(x), idx)
    args_maker = lambda: [rng(shape, dtype), unpacked_indexer]
    self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True,
                          check_eval_on_shapes=False,
                          check_incomplete_shape=True)

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_indexer={}"
       .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory,
       "indexer": indexer}
      for name, index_specs in [
          ("OneIntIndex",
           [IndexSpec(shape=(3,), indexer=1),
            IndexSpec(shape=(3, 3), indexer=0),
            IndexSpec(shape=(3, 4, 5), indexer=2),
            IndexSpec(shape=(3,), indexer=-1),
            IndexSpec(shape=(3,), indexer=-2)]),
          ("TwoIntIndices",
           [IndexSpec(shape=(3, 3), indexer=(2, 1)),
            IndexSpec(shape=(3, 4, 5), indexer=(1, 2)),
            IndexSpec(shape=(3, 4, 5), indexer=(-1, 2))]),
          ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]),
      ]
      for shape, indexer in index_specs
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default])
  def testDynamicIndexingWithIntegers(self, shape, dtype, rng_factory, indexer):
    """Integer indices passed as function arguments still index like NumPy."""
    # TODO(rohanj): Revisit passing in self.rng() to this to customize further.
    # This would need updating lax_numpy_test as well.
    rng = rng_factory()
    unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

    def onp_fun(x, unpacked_indexer):
      indexer = pack_indexer(unpacked_indexer)
      return x[indexer]

    jnp_fun = lambda x, idx: onp_fun(jnp.asarray(x), idx)
    args_maker = lambda: [rng(shape, dtype), unpacked_indexer]
    self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True,
                          check_incomplete_shape=True)

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_indexer={}"
       .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory,
       "indexer": indexer}
      for name, index_specs in ADVANCED_INDEXING_TESTS
      for shape, indexer in index_specs
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default])
  def testAdvancedIntegerIndexing(self, shape, dtype, rng_factory, indexer):
    """Advanced (array) indexers passed as arguments match NumPy."""
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype), indexer]
    onp_fun = lambda x, idx: x[idx]
    jnp_fun = lambda x, idx: onp_fun(jnp.asarray(x), idx)
    self._CheckAgainstNumpy(onp_fun, jnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True,
                          check_incomplete_shape=True)

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_indexer={}"
       .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory,
       "indexer": indexer}
      for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS
      for shape, indexer in index_specs
      for dtype in all_dtypes
      for rng_factory in [jtu.rand_default])
  def testMixedAdvancedIntegerIndexing(self, shape, dtype, rng_factory,
                                       indexer):
    """Indexers mixing ndarrays with other components match NumPy."""
    rng = rng_factory()
    # Only the ndarray components are passed as function arguments; every other
    # component is replaced by a `()` placeholder here and substituted back
    # (from `substitutes`) inside the function.
    indexer_with_dummies = [e if isinstance(e, onp.ndarray) else ()
                            for e in indexer]
    substitutes = [(i, e) for i, e in enumerate(indexer)
                   if not isinstance(e, onp.ndarray)]
    args_maker = lambda: [rng(shape, dtype), indexer_with_dummies]

    def np_fun(x, indexer_with_dummies):
      idx = type(indexer)(subvals(indexer_with_dummies, substitutes))
      return x[idx]

    jnp_fun = lambda x, idx: np_fun(jnp.asarray(x), idx)
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True,
                          check_incomplete_shape=True)

  def testAdvancedIndexingManually(self):
    """Ellipsis/None/array indexers give the same result with and without jit."""
    x = onp.random.RandomState(0).randn(3, 4, 5)
    index_array = onp.array([0, 2, -1, 0])

    op = lambda x, index_array: x[..., index_array, :]
    cop = npe.jit(op)

    a1 = op(x, index_array)
    a2 = cop(x, index_array)

    self.assertAllClose(a1, a2, check_dtypes=True)

    op = lambda x, index_array: x[..., index_array, :, index_array, None]
    cop = npe.jit(op)

    a1 = op(x, index_array)
    a2 = cop(x, index_array)

    self.assertAllClose(a1, a2, check_dtypes=True)

    op = lambda x, index_array: x[index_array, ..., index_array[:, None], None]
    cop = npe.jit(op)

    a1 = op(x, index_array)
    a2 = cop(x, index_array)

    self.assertAllClose(a1, a2, check_dtypes=True)

  # Note that we don't currently allow __iter__ in graph mode. So this test only
  # iterates over eager tensor.
  def testUnpacking(self):

    def foo(x):
      a, b, c = x
      return a + b + c

    a1 = foo(onp.arange(3))
    a2 = foo(jnp.arange(3))

    self.assertAllClose(a1, a2, check_dtypes=True)

  def testBooleanIndexingArray1D(self):
    idx = onp.array([True, True, False])
    x = jnp.asarray(onp.arange(3))
    ans = x[idx]
    expected = onp.arange(3)[idx]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testBooleanIndexingList1D(self):
    idx = [True, True, False]
    x = jnp.asarray(onp.arange(3))
    ans = x[idx]
    expected = onp.arange(3)[idx]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testBooleanIndexingArray2DBroadcast(self):
    idx = onp.array([True, True, False, True])
    x = onp.arange(8).reshape(4, 2)
    ans = jnp.asarray(x)[idx]
    expected = x[idx]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testBooleanIndexingList2DBroadcast(self):
    idx = [True, True, False, True]
    x = onp.arange(8).reshape(4, 2)
    ans = jnp.asarray(x)[idx]
    expected = x[idx]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testBooleanIndexingArray2D(self):
    idx = onp.array([[True, False],
                     [False, True],
                     [False, False],
                     [True, True]])
    x = onp.arange(8).reshape(4, 2)
    ans = jnp.asarray(x)[idx]
    expected = x[idx]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testBooleanIndexingDynamicShape(self):
    x = onp.zeros(3)
    i = onp.array([True, True, False])
    # `ans` is the jnp result and `expected` the NumPy reference, matching the
    # convention of the other boolean-indexing tests.
    ans = jnp.asarray(x)[i]
    expected = x[i]
    self.assertAllClose(ans, expected, check_dtypes=True)

  def testIssue187(self):
    x = jnp.ones((5, 5))
    x[[0, 2, 4], [0, 2, 4]]  # doesn't crash

    x = onp.arange(25).reshape((5, 5))
    ans = npe.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(x)
    expected = x[[0, 2, 4], [0, 2, 4]]
    self.assertAllClose(ans, expected, check_dtypes=False)

  # TODO(agarwal): Fix this use case.
  @jtu.disable
  def testIndexingEmptyDimension(self):
    # Issue 2671: XLA error when indexing into dimension of size 0
    x = jnp.ones((2, 0))
    # The following work, even on axis 1 of size 0
    _ = x[0, :] + x[0, None] + x[0, 1:] + x[0, 1:3:2]

    with self.assertRaisesRegex(
        IndexError, "index .* is out of bounds for axis .* with size 0"):
      _ = onp.ones((2, 0))[0, 0]  # The numpy error
    with self.assertRaisesRegex(
        IndexError, "index is out of bounds for axis .* with size 0"):
      _ = x[0, 0]  # JAX indexing
    with self.assertRaisesRegex(
        IndexError, "index is out of bounds for axis .* with size 0"):
      npe.jit(lambda i: x[0, i])(0)  # JAX indexing under jit

  def testBooleanIndexingWithEmptyResult(self):
    # based on a TensorFlow Probability test that started failing after #1623
    x = jnp.array([-1])
    mask = jnp.array([False])
    ans = x[mask]  # doesn't crash

    expected = onp.array([-1])[onp.array([False])]
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testFloatIndexingError(self):
    error_regex = "only integers, slices.*are valid indices"
    # Verify onp behavior
    with self.assertRaisesRegex(IndexError, error_regex):
      _ = onp.zeros((2, 2))[(0, 0.)]
    # Test jnp
    with self.assertRaisesRegex(IndexError, error_regex):
      jnp.zeros(2)[0.]
    with self.assertRaisesRegex(IndexError, error_regex):
      jnp.zeros((2, 2))[(0, 0.)]
    # Test with jit
    with self.assertRaisesRegex(IndexError, error_regex):
      npe.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.))

  def testIndexOutOfBounds(self):  # https://github.com/google/jax/issues/2245
    array = jnp.ones(5)
    self.assertAllClose(array, array[:10], check_dtypes=True)
class TfNpIndexedUpdateTest(jtu.TestCase):
  """Tests for indexed updates (`UpdateOps`) and dynamic-update-slice helpers.

  Update results from `UpdateOps.tfnp_fn` are compared against the NumPy
  reference `UpdateOps.np_fn`, and the tfnp function is compiled via
  `_CompileAndCheck`.

  This class must not be named `IndexedUpdateTest`: another class with that
  name is defined later in this module, and rebinding the name would hide every
  test in this class from the test loader.
  """

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {  # pylint: disable=g-complex-comprehension
              "testcase_name": "_{}_{}_{}_{}".format(
                  jtu.format_shape_dtype_string(shape, dtype), indexer,
                  jtu.format_shape_dtype_string(update_shape, update_dtype),
                  op.name),
              "shape": shape,
              "dtype": dtype,
              "rng_factory": rng_factory,
              "indexer": indexer,
              "update_shape": update_shape,
              "update_dtype": update_dtype,
              "op": op
          }
          for name, index_specs in STATIC_INDEXING_TESTS
          for shape, indexer in index_specs
          for op in UpdateOps
          for dtype in (
              all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
          for update_shape in _broadcastable_shapes(
              _update_shape(shape, indexer))
          for update_dtype in all_dtypes
          for rng_factory in [jtu.rand_default]))
  def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
                         rng_factory, indexer, op):
    """Updates at constant indexers match the NumPy reference."""
    rng = rng_factory()
    args_maker = lambda: [
        rng(shape, dtype), rng(update_shape, update_dtype)
    ]
    np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
    tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y)
    self._CheckAgainstNumpy(np_fn, tfnp_fn, args_maker)
    # TODO(wangpeng): When indexer is slice(_, 8, -1), XLA throws error "Missing
    # xla_context 0-th output from". Investigate.
    check_xla = (
        not has_non_trivial_stride(indexer) and  # b/123559667
        not (isinstance(indexer, slice) and indexer.stop == 8 and
             indexer.step == -1))
    self._CompileAndCheck(tfnp_fn, args_maker, check_incomplete_shape=True,
                          check_experimental_compile=check_xla,
                          check_xla_forced_compile=check_xla)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {  # pylint: disable=g-complex-comprehension
              "testcase_name": "_{}_{}_{}_{}".format(
                  jtu.format_shape_dtype_string(shape, dtype), indexer,
                  jtu.format_shape_dtype_string(update_shape, update_dtype),
                  op.name),
              "shape": shape,
              "dtype": dtype,
              "rng_factory": rng_factory,
              "indexer": indexer,
              "update_shape": update_shape,
              "update_dtype": update_dtype,
              "op": op
          }
          for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS
          for shape, indexer in index_specs
          for op in UpdateOps
          for dtype in (
              all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
          for update_shape in _broadcastable_shapes(
              _update_shape(shape, indexer))
          for update_dtype in all_dtypes
          for rng_factory in [jtu.rand_default]))
  def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                           rng_factory, indexer, op):
    """Updates at advanced (array) indexers match the NumPy reference."""
    rng = rng_factory()
    args_maker = lambda: [
        rng(shape, dtype), rng(update_shape, update_dtype)
    ]
    np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
    tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y)
    self._CheckAgainstNumpy(np_fn, tfnp_fn, args_maker)
    self._CompileAndCheck(tfnp_fn, args_maker, check_incomplete_shape=True)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {  # pylint: disable=g-complex-comprehension
              "testcase_name": "_{}_{}_{}_{}".format(
                  jtu.format_shape_dtype_string(shape, dtype), indexer,
                  jtu.format_shape_dtype_string(update_shape, update_dtype),
                  op.name),
              "shape": shape,
              "dtype": dtype,
              "rng_factory": rng_factory,
              "indexer": indexer,
              "update_shape": update_shape,
              "update_dtype": update_dtype,
              "op": op
          }
          for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS
          for shape, indexer in index_specs
          for op in UpdateOps
          for dtype in (
              all_dtypes if op == UpdateOps.UPDATE else default_dtypes)
          for update_shape in _broadcastable_shapes(
              _update_shape(shape, indexer))
          for update_dtype in all_dtypes
          for rng_factory in [jtu.rand_default]))
  def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                                rng_factory, indexer, op):
    """Updates at mixed basic/advanced indexers match the NumPy reference."""
    rng = rng_factory()
    args_maker = lambda: [
        rng(shape, dtype), rng(update_shape, update_dtype)
    ]
    np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
    tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y)
    self._CheckAgainstNumpy(np_fn, tfnp_fn, args_maker)
    check_xla = not has_non_trivial_stride(indexer)  # b/123559667
    self._CompileAndCheck(tfnp_fn, args_maker, check_incomplete_shape=True,
                          check_experimental_compile=check_xla,
                          check_xla_forced_compile=check_xla)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {  # pylint: disable=g-complex-comprehension
              "testcase_name": "_{}_{}_{}_{}".format(
                  jtu.format_shape_dtype_string(shape, dtype), indexer,
                  jtu.format_shape_dtype_string(update_shape, update_dtype),
                  op.name),
              "shape": shape,
              "dtype": dtype,
              "rng_factory": rng_factory,
              "indexer": indexer,
              "update_shape": update_shape,
              "update_dtype": update_dtype,
              "op": op
          }
          for name, index_specs in STATIC_INDEXING_TESTS
          for shape, indexer in index_specs
          for op in [UpdateOps.ADD, UpdateOps.UPDATE]
          for dtype in float_dtypes
          for update_shape in _broadcastable_shapes(
              _update_shape(shape, indexer))
          for update_dtype in float_dtypes
          for rng_factory in [jtu.rand_default]))
  def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
                              rng_factory, indexer, op):
    """Gradients of ADD/UPDATE at constant indexers pass `check_grads`."""
    rng = rng_factory()
    tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y)
    x = rng(shape, dtype)
    y = rng(update_shape, update_dtype)
    self.check_grads(tfnp_fn, (x, y), rtol=1e-3, atol=1e-3, delta=1.)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_shape={}_start_indices={}_update_shape={}".format(  # pylint: disable=g-complex-comprehension
              jtu.format_shape_dtype_string(shape, dtype),
              start_indices, update_shape),
          "shape": shape,
          "dtype": dtype,
          "start_indices": start_indices,
          "update_shape": update_shape,
          "rng_factory": rng_factory
      } for shape, start_indices, update_shape in [
          [(3,), (1,), (1,)],
          [(5, 3), (1, 1), (3, 1)],
          [(5, 3), (1, -2), (3, 1)],
          [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
          [(), (), ()],
      ] for dtype in default_dtypes for rng_factory in [jtu.rand_default]))
  def testDynamicUpdateSlice(self, shape, dtype, start_indices, update_shape,
                             rng_factory):
    """`npe.dynamic_update_slice` compiles and checks for each case."""
    rng = rng_factory()

    def args_maker():
      return [
          rng(shape, dtype),
          rng(update_shape, dtype),
          onp.array(start_indices)
      ]

    # update's shape must be fully known.
    # TODO(wangpeng): Support turning off check_incomplete_shape for individual
    # arguments.
    self._CompileAndCheck(npe.dynamic_update_slice, args_maker,
                          check_incomplete_shape=False)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_shape={}_start_indices={}_update_shape={}".format(  # pylint: disable=g-complex-comprehension
              jtu.format_shape_dtype_string(shape, dtype),
              start_indices, update_shape),
          "shape": shape,
          "dtype": dtype,
          "start_indices": start_indices,
          "update_shape": update_shape,
          "rng_factory": rng_factory
      } for shape, start_indices, update_shape in [
          [(3,), (1,), (1,)],
          [(5, 3), (1, 1), (3, 1)],
          [(5, 3), (1, -2), (3, 1)],
          [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
          [(), (), ()],
      ] for dtype in default_dtypes for rng_factory in [jtu.rand_default]))
  def testDynamicUpdateSliceAgainstNumpy(self, shape, dtype, start_indices,
                                         update_shape, rng_factory):
    """`npe.dynamic_update_slice` matches `dynamic_update_slice_reference`."""
    rng = rng_factory()

    def args_maker():
      return [
          rng(shape, dtype),
          rng(update_shape, dtype),
          onp.array(start_indices)
      ]

    self._CheckAgainstNumpy(dynamic_update_slice_reference,
                            npe.dynamic_update_slice, args_maker)

  def testDynamicUpdateSliceInDim(self):
    """Updating rows 2:5 along axis 0 matches NumPy slice assignment."""
    rng = jtu.rand_default()
    x = rng((6, 7), onp.int32)
    y = rng((3, 7), onp.int32)
    z = x.copy()
    z[2:5] = y
    self.assertAllClose(npe.dynamic_update_slice_in_dim(x, y, 2, 0), z,
                        check_dtypes=True)
class IndexedUpdateTest(jtu.TestCase):
  """Indexed-update tests comparing `UpdateOps.onp_fn` with the jax/sugar fns.

  Every test in this class is currently decorated with `@jtu.disable`.
  """

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "{}_inshape={}_indexer={}_update={}_sugared={}_op={}".format(
                  name, jtu.format_shape_dtype_string(shape, dtype), indexer,
                  jtu.format_shape_dtype_string(update_shape, update_dtype),
                  sugared, op.name),
          "shape": shape,
          "dtype": dtype,
          "rng_factory": rng_factory,
          "indexer": indexer,
          "update_shape": update_shape,
          "update_dtype": update_dtype,
          "op": op,
          "sugared": sugared
      }
                          for name, index_specs in STATIC_INDEXING_TESTS
                          for shape, indexer in index_specs
                          for op in UpdateOps
                          for dtype in (all_dtypes if op == UpdateOps.UPDATE
                                        else default_dtypes)
                          for update_shape in _broadcastable_shapes(
                              _update_shape(shape, indexer))
                          for update_dtype in ([dtype] if op == UpdateOps.ADD
                                               else all_dtypes)
                          for sugared in [True, False]
                          for rng_factory in [jtu.rand_default]))
  @jtu.disable
  def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
                         rng_factory, indexer, sugared, op):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
    onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y)
    if sugared:
      jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
    else:
      jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
    self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
    self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "{}_inshape={}_indexer={}_update={}_sugared={}_op={}".format(
                  name, jtu.format_shape_dtype_string(shape, dtype), indexer,
                  jtu.format_shape_dtype_string(update_shape, update_dtype),
                  sugared, op.name),
          "shape": shape,
          "dtype": dtype,
          "rng_factory": rng_factory,
          "indexer": indexer,
          "update_shape": update_shape,
          "update_dtype": update_dtype,
          "op": op,
          "sugared": sugared
      }
                          for name, index_specs in
                          ADVANCED_INDEXING_TESTS_NO_REPEATS
                          for shape, indexer in index_specs
                          for op in UpdateOps
                          for dtype in (all_dtypes if op == UpdateOps.UPDATE
                                        else default_dtypes)
                          for update_shape in _broadcastable_shapes(
                              _update_shape(shape, indexer))
                          for update_dtype in ([dtype] if op == UpdateOps.ADD
                                               else all_dtypes)
                          for sugared in [True, False]
                          for rng_factory in [jtu.rand_default]))
  @jtu.disable
  def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                           rng_factory, indexer, sugared, op):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
    onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y)
    if sugared:
      jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
    else:
      jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
    self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
    self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)

  # `sugared` is part of the test name: the comprehension yields one case per
  # `sugared` value, and without it the two variants would share a name, which
  # `parameterized.named_parameters` rejects when the class is created.
  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "{}_inshape={}_indexer={}_update={}_sugared={}_op={}".format(
                  name, jtu.format_shape_dtype_string(shape, dtype), indexer,
                  jtu.format_shape_dtype_string(update_shape, update_dtype),
                  sugared, op.name),
          "shape": shape,
          "dtype": dtype,
          "rng_factory": rng_factory,
          "indexer": indexer,
          "update_shape": update_shape,
          "update_dtype": update_dtype,
          "op": op,
          "sugared": sugared
      }
                          for name, index_specs in
                          MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS
                          for shape, indexer in index_specs
                          for op in UpdateOps
                          for dtype in (all_dtypes if op == UpdateOps.UPDATE
                                        else default_dtypes)
                          for update_shape in _broadcastable_shapes(
                              _update_shape(shape, indexer))
                          for update_dtype in ([dtype] if op == UpdateOps.ADD
                                               else all_dtypes)
                          for sugared in [True, False]
                          for rng_factory in [jtu.rand_default]))
  @jtu.disable
  def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                                rng_factory, indexer, sugared, op):
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
    onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y)
    if sugared:
      jax_fn = lambda x, y: UpdateOps.sugar_fn(op, indexer, x, y)
    else:
      jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
    self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True)
    self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format(
              name, jtu.format_shape_dtype_string(shape, dtype), indexer,
              jtu.format_shape_dtype_string(update_shape, update_dtype),
              op.name),
          "shape": shape,
          "dtype": dtype,
          "rng_factory": rng_factory,
          "indexer": indexer,
          "update_shape": update_shape,
          "update_dtype": update_dtype,
          "op": op
      }
                          for name, index_specs in STATIC_INDEXING_TESTS
                          for shape, indexer in index_specs
                          for op in [UpdateOps.ADD, UpdateOps.MUL,
                                     UpdateOps.UPDATE]
                          for dtype in float_dtypes
                          for update_shape in _broadcastable_shapes(
                              _update_shape(shape, indexer))
                          for update_dtype in ([dtype] if op == UpdateOps.ADD
                                               else float_dtypes)
                          for rng_factory in [jtu.rand_default]))
  @jtu.skip_on_devices("tpu")  # TODO(mattjj,phawkins): tpu issues
  @jtu.disable
  def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
                              rng_factory, indexer, op):
    rng = rng_factory()
    jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y)
    x = rng(shape, dtype)
    y = rng(update_shape, update_dtype)
    check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.)

  @jtu.disable
  def testSegmentSumBehavior(self):
    # testAdvancedIndexing compares against NumPy, and as a result doesn't check
    # repeated indices. This test is just a simple manual check, based on
    # https://www.tensorflow.org/api_docs/python/tf/math/segment_sum
    data = onp.array([5, 1, 7, 2, 3, 4, 1, 3])
    segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3])

    ans = ops.index_add(onp.zeros(onp.max(segment_ids) + 1), segment_ids, data)
    expected = onp.array([13, 2, 7, 4])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @jtu.disable
  def testSegmentSum(self):
    data = onp.array([5, 1, 7, 2, 3, 4, 1, 3])
    segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3])

    # test with explicit num_segments
    ans = ops.segment_sum(data, segment_ids, num_segments=4)
    expected = onp.array([13, 2, 7, 4])
    self.assertAllClose(ans, expected, check_dtypes=False)

    # test without explicit num_segments
    ans = ops.segment_sum(data, segment_ids)
    expected = onp.array([13, 2, 7, 4])
    self.assertAllClose(ans, expected, check_dtypes=False)

  @jtu.disable
  def testIndexDtypeError(self):
    # https://github.com/google/jax/issues/2795
    jnp.array(1)  # get rid of startup warning
    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("error")
      jnp.zeros(5).at[::2].set(1)
      self.assertLen(w, 0)
class IndexingTest(jtu.TestCase): """Tests for Numpy indexing translation rules.""" @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "{}_inshape={}_indexer={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer } for name, index_specs in STATIC_INDEXING_TESTS for shape, indexer in index_specs for dtype in all_dtypes for rng_factory in [jtu.rand_default])) def testStaticIndexing(self, shape, dtype, rng_factory, indexer): # TODO(rohanj): Revisit passing in self.rng() to this to customize further. # This would need updating lax_numpy_test as well. rng = rng_factory() args_maker = lambda: [rng(shape, dtype)] fun = lambda x: x[indexer] self._CompileAndCheck(fun, args_maker, check_dtypes=True) @parameterized.named_parameters({ "testcase_name": "{}_inshape={}_indexer={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer } for name, index_specs in STATIC_INDEXING_GRAD_TESTS for shape, indexer in index_specs for dtype in float_dtypes for rng_factory in [jtu.rand_default]) @jtu.disable def testStaticIndexingGrads(self, shape, dtype, rng_factory, indexer): rng = rng_factory() tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None arg = rng(shape, dtype) fun = lambda x: x[indexer]**2 check_grads(fun, (arg, ), 2, tol, tol, tol) def _ReplaceSlicesWithTuples(self, idx): """Helper method to replace slices with tuples for dynamic indexing args.""" if isinstance(idx, slice): triple = idx.start, idx.stop, idx.step isnone = [i for i, elt in enumerate(triple) if elt is None] zeros = itertools.repeat(0) nones = itertools.repeat(None) out = subvals(triple, zip(isnone, zeros)) return out, lambda out: slice(*subvals(out, zip(isnone, nones))) elif isinstance(idx, (tuple, list)) and idx: t = type(idx) elts, packs = zip(*map(self._ReplaceSlicesWithTuples, idx)) return elts, lambda 
elts: t( (pack(i) for pack, i in zip(packs, elts))) else: return idx, lambda x: x @parameterized.named_parameters({ "testcase_name": "{}_inshape={}_indexer={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer } for name, index_specs in [ ("OneSliceIndex", [ IndexSpec(shape=(5, ), indexer=slice(1, 3)), IndexSpec(shape=(5, 4), indexer=slice(1, 3)) ]), ("TwoSliceIndices", [ IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))), IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2))) ]), ("NonUnitStrides", [ IndexSpec(shape=(3, ), indexer=slice(None, None, -1)), IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)), IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2)) ]), ("OnlyStartOrStopDynamic", [ IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))), IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None))) ]), ] for shape, indexer in index_specs for dtype in all_dtypes for rng_factory in [jtu.rand_default]) @jtu.disable def testDynamicIndexingWithSlicesErrors(self, shape, dtype, rng_factory, indexer): rng = rng_factory() unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) @npe.jit def fun(x, unpacked_indexer): indexer = pack_indexer(unpacked_indexer) return x[indexer] args_maker = lambda: [rng(shape, dtype), unpacked_indexer] self.assertRaises(IndexError, lambda: fun(*args_maker())) @parameterized.named_parameters({ "testcase_name": "{}_inshape={}_indexer={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer } for name, index_specs in [ ("OneIntIndex", [ IndexSpec(shape=(3, ), indexer=1), IndexSpec(shape=(3, 3), indexer=0), IndexSpec(shape=(3, 4, 5), indexer=2), IndexSpec(shape=(3, ), indexer=-1), IndexSpec(shape=(3, ), indexer=-2) ]), ("TwoIntIndices", [ IndexSpec(shape=(3, 3), indexer=(2, 1)), IndexSpec(shape=(3, 4, 5), 
indexer=(1, 2)), IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)) ]), ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), ] for shape, indexer in index_specs for dtype in all_dtypes for rng_factory in [jtu.rand_default]) def testDynamicIndexingWithIntegers(self, shape, dtype, rng_factory, indexer): # TODO(rohanj): Revisit passing in self.rng() to this to customize further. # This would need updating lax_numpy_test as well. rng = rng_factory() unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) def fun(x, unpacked_indexer): indexer = pack_indexer(unpacked_indexer) return x[indexer] args_maker = lambda: [rng(shape, dtype), unpacked_indexer] self._CompileAndCheck(fun, args_maker, check_dtypes=True) @parameterized.named_parameters({ "testcase_name": "{}_inshape={}_indexer={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer } for name, index_specs in [ ("OneIntIndex", [ IndexSpec(shape=(3, ), indexer=1), IndexSpec(shape=(3, 3), indexer=0), IndexSpec(shape=(3, 4, 5), indexer=2), IndexSpec(shape=(3, ), indexer=-1), IndexSpec(shape=(3, ), indexer=-2), ]), ("TwoIntIndices", [ IndexSpec(shape=(3, 3), indexer=(2, 1)), IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)), ]), ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), ] for shape, indexer in index_specs for dtype in float_dtypes for rng_factory in [jtu.rand_default]) @jtu.disable def testDynamicIndexingWithIntegersGrads(self, shape, dtype, rng_factory, indexer): rng = rng_factory(self.rng()) tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) @npe.jit def fun(unpacked_indexer, x): indexer = pack_indexer(unpacked_indexer) return x[indexer] arr = rng(shape, dtype) check_grads(partial(fun, unpacked_indexer), (arr, ), 2, tol, tol, tol) @parameterized.named_parameters({ 
"testcase_name": "{}_inshape={}_indexer={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer } for name, index_specs in ADVANCED_INDEXING_TESTS for shape, indexer in index_specs for dtype in all_dtypes for rng_factory in [jtu.rand_default]) @jtu.disable def testAdvancedIntegerIndexing(self, shape, dtype, rng_factory, indexer): rng = rng_factory() args_maker = lambda: [rng(shape, dtype), indexer] fun = lambda x, idx: jnp.asarray(x)[idx] self._CompileAndCheck(fun, args_maker, check_dtypes=True) @parameterized.named_parameters({ "testcase_name": "{}_inshape={}_indexer={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer } for name, index_specs in [ ("One1DIntArrayIndex", [ IndexSpec(shape=(3, ), indexer=onp.array([0, 1])), IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 1])), IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 0, 1])), IndexSpec(shape=(3, ), indexer=onp.array([-1, 1])), IndexSpec(shape=(3, ), indexer=onp.array([-2, -1])), ]), ("One2DIntArrayIndex", [ IndexSpec(shape=(3, ), indexer=onp.array([[0, 0]])), IndexSpec(shape=(3, 3), indexer=onp.array([[1, 2, 1], [0, 1, -1]])), IndexSpec(shape=(3, 4, 5), indexer=onp.array([[0, 2, 0, 1], [-1, -2, 1, 0]])), ]), ("Two1DIntArrayIndicesNoBroadcasting", [ IndexSpec(shape=(3, 3), indexer=[onp.array([0, 1]), onp.array([1, 2])]), IndexSpec( shape=(3, 4, 5), indexer=[onp.array([0, 2, 0, 1]), onp.array([-1, 0, -1, 2])]), ]), ("Two1DIntArrayIndicesWithBroadcasting", [ IndexSpec(shape=(3, 3), indexer=[onp.array([[0, 1]]), onp.array([1, 2])]), IndexSpec(shape=(3, 4, 5), indexer=[ onp.array([[0, 2, 0, 1]]), onp.array([-1, 0, -1, 2]) ]), ]), ("ListOfPythonInts", [ IndexSpec(shape=(3, ), indexer=[0, 1, 0]), IndexSpec(shape=(3, 4, 5), indexer=[0, -1]), ]), ("ListOfListsOfPythonInts", [ IndexSpec(shape=(3, 4, 5), indexer=[[0, 1]]), 
IndexSpec(shape=(3, 4, 5), indexer=[[[0], [-1]], [[2, 3, 0, 3]]]), ]), ("ListOfPythonIntsAndIntArrays", [ IndexSpec(shape=(3, 4, 5), indexer=[0, onp.array([0, 1])]), IndexSpec(shape=(3, 4, 5), indexer=[0, 1, onp.array([[2, 3, 0, 3]])]), ]), ("ListOfListsOfPythonIntsAndIntArrays", [ IndexSpec(shape=(3, 4, 5), indexer=[[0, 1], onp.array([0])]), IndexSpec(shape=(3, 4, 5), indexer=[[[0], [-1]], onp.array([[2, 3, 0, 3]])]), ]), ] for shape, indexer in index_specs for dtype in float_dtypes for rng_factory in [jtu.rand_default]) @jtu.disable def testAdvancedIntegerIndexingGrads(self, shape, dtype, rng_factory, indexer): rng = rng_factory() tol = 1e-2 if jnp.finfo(dtype).bits == 32 else None arg = rng(shape, dtype) fun = lambda x: jnp.asarray(x)[indexer] check_grads(fun, (arg, ), 2, tol, tol, eps=1.) @parameterized.named_parameters({ "testcase_name": "{}_inshape={}_indexer={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer } for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS for shape, indexer in index_specs for dtype in all_dtypes for rng_factory in [jtu.rand_default]) @jtu.disable def testMixedAdvancedIntegerIndexing(self, shape, dtype, rng_factory, indexer): rng = rng_factory() indexer_with_dummies = [ e if isinstance(e, onp.ndarray) else () for e in indexer ] substitutes = [(i, e) for i, e in enumerate(indexer) if not isinstance(e, onp.ndarray)] args_maker = lambda: [rng(shape, dtype), indexer_with_dummies] def fun(x, indexer_with_dummies): idx = type(indexer)(subvals(indexer_with_dummies, substitutes)) return jnp.asarray(x)[idx] self._CompileAndCheck(fun, args_maker, check_dtypes=True) @jtu.disable def testAdvancedIndexingManually(self): x = onp.random.RandomState(0).randn(3, 4, 5) index_array = onp.array([0, 2, -1, 0]) op = lambda x, index_array: x[..., index_array, :] cop = npe.jit(op) a1 = op(x, index_array) a2 = cop(x, index_array) self.assertAllClose(a1, a2, 
check_dtypes=True) op = lambda x, index_array: x[..., index_array, :, index_array, None] cop = npe.jit(op) a1 = op(x, index_array) a2 = cop(x, index_array) self.assertAllClose(a1, a2, check_dtypes=True) op = lambda x, index_array: x[index_array, ..., index_array[:, None], None] cop = npe.jit(op) a1 = op(x, index_array) a2 = cop(x, index_array) self.assertAllClose(a1, a2, check_dtypes=True) @jtu.disable def testUnpacking(self): def foo(x): a, b, c = x return a + b + c cfoo = npe.jit(foo) a1 = foo(onp.arange(3)) a2 = cfoo(onp.arange(3)) self.assertAllClose(a1, a2, check_dtypes=True) @jtu.disable def testBooleanIndexingArray1D(self): idx = onp.array([True, True, False]) x = api.device_put(onp.arange(3)) ans = x[idx] expected = onp.arange(3)[idx] self.assertAllClose(ans, expected, check_dtypes=False) @jtu.disable def testBooleanIndexingList1D(self): idx = [True, True, False] x = api.device_put(onp.arange(3)) ans = x[idx] expected = onp.arange(3)[idx] self.assertAllClose(ans, expected, check_dtypes=False) @jtu.disable def testBooleanIndexingArray2DBroadcast(self): idx = onp.array([True, True, False, True]) x = onp.arange(8).reshape(4, 2) ans = api.device_put(x)[idx] expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) @jtu.disable def testBooleanIndexingList2DBroadcast(self): idx = [True, True, False, True] x = onp.arange(8).reshape(4, 2) ans = api.device_put(x)[idx] expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) @jtu.disable def testBooleanIndexingArray2D(self): idx = onp.array([[True, False], [False, True], [False, False], [True, True]]) x = onp.arange(8).reshape(4, 2) ans = api.device_put(x)[idx] expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) @jtu.disable def testBooleanIndexingDynamicShapeError(self): x = onp.zeros(3) i = onp.array([True, True, False]) self.assertRaises(IndexError, lambda: npe.jit(lambda x, i: x[i])(x, i)) @jtu.disable def testIssue187(self): x = jnp.ones((5, 5)) x[[0, 2, 4], 
[0, 2, 4]] # doesn't crash x = onp.arange(25).reshape((5, 5)) ans = npe.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(x) expected = x[[0, 2, 4], [0, 2, 4]] self.assertAllClose(ans, expected, check_dtypes=False) @jtu.disable def testJVPOfGradOfIndexing(self): # Should return a value, even though we didn't pass a symbolic zero as the # index tangent. x = jnp.ones((3, 4), jnp.float32) i = jnp.ones((3, ), jnp.int32) f = lambda x, i: jnp.sum(x[i]) primals, tangents = api.jvp(api.grad(f), (x, i), (x, onp.zeros_like(i))) expected = onp.broadcast_to( onp.array([0, 3, 0], dtype=onp.float32)[:, None], (3, 4)) self.assertAllClose(expected, primals, check_dtypes=True) self.assertAllClose(onp.zeros_like(x), tangents, check_dtypes=True) @jtu.disable def testTrivialGatherIsntGenerated(self): # https://github.com/google/jax/issues/1621 jaxpr = api.make_jaxpr(lambda x: x[:, None])(onp.arange(4)) self.assertEqual(len(jaxpr.jaxpr.eqns), 1) self.assertNotIn('gather', str(jaxpr)) @jtu.disable def testIndexingEmptyDimension(self): # Issue 2671: XLA error when indexing into dimension of size 0 x = jnp.ones((2, 0)) # The following work, even on axis 1 of size 0 _ = x[0, :] + x[0, None] + x[0, 1:] + x[0, 1:3:2] with self.assertRaisesRegex( IndexError, "index .* is out of bounds for axis .* with size 0"): _ = onp.ones((2, 0))[0, 0] # The numpy error with self.assertRaisesRegex( IndexError, "index is out of bounds for axis .* with size 0"): _ = x[0, 0] # JAX indexing with self.assertRaisesRegex( IndexError, "index is out of bounds for axis .* with size 0"): npe.jit(lambda i: x[0, i])(0) # JAX indexing under jit @jtu.disable def testBooleanIndexingWithEmptyResult(self): # based on a TensorFlow Probability test that started failing after #1622 x = jnp.array([-1]) mask = jnp.array([False]) ans = x[mask] # doesn't crash expected = onp.array([-1])[onp.array([False])] self.assertAllClose(ans, expected, check_dtypes=False) @jtu.disable def testFloatIndexingError(self): BAD_INDEX_TYPE_ERROR = "Indexer must 
have integer or boolean type, got indexer with type" with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): jnp.zeros(2)[0.] with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): jnp.zeros((2, 2))[(0, 0.)] with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): jnp.zeros((2, 2))[(0, 0.)] with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): npe.jit(lambda idx: jnp.zeros((2, 2))[idx])((0, 0.)) with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): ops.index_add(jnp.zeros(2), 0., 1.) with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): ops.index_update(jnp.zeros(2), 0., 1.) @jtu.disable def testIndexOutOfBounds( self): # https://github.com/google/jax/issues/2245 array = jnp.ones(5) self.assertAllClose(array, array[:10], check_dtypes=True)
# NOTE(review): A stale duplicate of `IndexedUpdateTest` was removed from this
# spot. A later module-level class reuses the same name and rebinds it, so this
# copy was never visible to the test loader and none of its tests ever ran
# (pylint reports this pattern as `function-redefined`). The removed copy also
# did not apply the b/123559667 stride-based XLA exclusions, and it took
# gradients through every `UpdateOps` member instead of only ADD/UPDATE.
class IndexedUpdateTest(jtu.TestCase):
  """Compares tf-numpy indexed-update ops (`UpdateOps`) against NumPy."""

  # The two helpers below run while the class body executes (they feed the
  # `named_parameters` decorators), so they are plain functions rather than
  # methods. Both are deleted at the end of the class body.

  def _indexed_update_cases(  # pylint: disable=no-self-argument
      index_tests, ops, dtypes_for_op, update_dtypes):
    """Yields one named-parameter dict per indexed-update test case.

    Args:
      index_tests: iterable of `(group_name, index_specs)` pairs; each spec
        unpacks to `(shape, indexer)`. The group name is not used.
      ops: iterable of `UpdateOps` members to exercise.
      dtypes_for_op: callable returning the operand dtypes to try for an op.
      update_dtypes: dtypes to try for the update operand.

    Yields:
      Dicts holding `testcase_name` plus the test method's keyword arguments.
    """
    for _, index_specs in index_tests:
      for shape, indexer in index_specs:
        for op in ops:
          for dtype in dtypes_for_op(op):
            # Re-evaluated for every dtype; do not hoist unless
            # `_broadcastable_shapes` is known to return a reusable sequence.
            for update_shape in _broadcastable_shapes(
                _update_shape(shape, indexer)):
              for update_dtype in update_dtypes:
                for rng_factory in [jtu.rand_default]:
                  yield {
                      "testcase_name": "_{}_{}_{}_{}".format(
                          jtu.format_shape_dtype_string(shape, dtype),
                          indexer,
                          jtu.format_shape_dtype_string(update_shape,
                                                        update_dtype),
                          op.name),
                      "shape": shape,
                      "dtype": dtype,
                      "rng_factory": rng_factory,
                      "indexer": indexer,
                      "update_shape": update_shape,
                      "update_dtype": update_dtype,
                      "op": op,
                  }

  def _operand_dtypes(op):  # pylint: disable=no-self-argument
    """Every dtype for plain UPDATE; only `default_dtypes` for other ops."""
    return all_dtypes if op == UpdateOps.UPDATE else default_dtypes

  def _CheckIndexedUpdate(self, shape, dtype, update_shape, update_dtype,
                          rng_factory, indexer, op, check_xla=None):
    """Checks one update op against NumPy, then compile-checks tf-numpy.

    Args:
      shape: shape of the first operand passed to the update functions.
      dtype: dtype of the first operand.
      update_shape: shape of the second (update) operand.
      update_dtype: dtype of the second (update) operand.
      rng_factory: zero-argument callable returning an `rng(shape, dtype)`
        array generator.
      indexer: index forwarded to `UpdateOps.np_fn` / `UpdateOps.tfnp_fn`.
      op: the `UpdateOps` member to apply.
      check_xla: when None, `_CompileAndCheck` keeps its own defaults for the
        XLA flags; otherwise the value is forwarded as both
        `check_experimental_compile` and `check_xla_forced_compile`.
    """
    rng = rng_factory()
    args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)]
    np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y)
    tfnp_fn = lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y)
    self._CheckAgainstNumpy(np_fn, tfnp_fn, args_maker)
    compile_kwargs = dict(check_incomplete_shape=True)
    if check_xla is not None:
      compile_kwargs.update(check_experimental_compile=check_xla,
                            check_xla_forced_compile=check_xla)
    self._CompileAndCheck(tfnp_fn, args_maker, **compile_kwargs)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          _indexed_update_cases(STATIC_INDEXING_TESTS, UpdateOps,
                                _operand_dtypes, all_dtypes)))
  def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
                         rng_factory, indexer, op):
    # b/123559667: XLA compilation is skipped for non-trivial strides.
    # TODO(wangpeng): When indexer is slice(_, 8, -1), XLA throws error
    # "Missing xla_context 0-th output from". Investigate.
    check_xla = not has_non_trivial_stride(indexer)
    if check_xla and isinstance(indexer, slice):
      check_xla = not (indexer.stop == 8 and indexer.step == -1)
    self._CheckIndexedUpdate(shape, dtype, update_shape, update_dtype,
                             rng_factory, indexer, op, check_xla=check_xla)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          _indexed_update_cases(ADVANCED_INDEXING_TESTS_NO_REPEATS, UpdateOps,
                                _operand_dtypes, all_dtypes)))
  def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                           rng_factory, indexer, op):
    self._CheckIndexedUpdate(shape, dtype, update_shape, update_dtype,
                             rng_factory, indexer, op)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          _indexed_update_cases(MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS,
                                UpdateOps, _operand_dtypes, all_dtypes)))
  def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
                                rng_factory, indexer, op):
    # b/123559667: XLA compilation is skipped for non-trivial strides.
    self._CheckIndexedUpdate(shape, dtype, update_shape, update_dtype,
                             rng_factory, indexer, op,
                             check_xla=not has_non_trivial_stride(indexer))

  @parameterized.named_parameters(
      jtu.cases_from_list(
          _indexed_update_cases(STATIC_INDEXING_TESTS,
                                [UpdateOps.ADD, UpdateOps.UPDATE],
                                lambda _: float_dtypes, float_dtypes)))
  def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
                              rng_factory, indexer, op):
    rng = rng_factory()
    operands = (rng(shape, dtype), rng(update_shape, update_dtype))
    self.check_grads(lambda x, y: UpdateOps.tfnp_fn(op, indexer, x, y),
                     operands, rtol=1e-3, atol=1e-3, delta=1.)

  # The case-building helpers were only needed by the decorators above.
  del _indexed_update_cases, _operand_dtypes