def _get_max_identity(dtype):
  if dtypes.issubdtype(dtype, np.inexact):
    return np.array(-np.inf, dtype)
  elif dtypes.issubdtype(dtype, np.integer):
    return np.array(dtypes.iinfo(dtype).min, dtype)
  elif dtypes.issubdtype(dtype, np.bool_):
    return np.array(False, np.bool_)
def _get_min_identity(dtype):
  if dtypes.issubdtype(dtype, np.inexact):
    return np.array(np.inf, dtype)
  elif dtypes.issubdtype(dtype, np.integer):
    return np.array(dtypes.iinfo(dtype).max, dtype)
  elif dtypes.issubdtype(dtype, np.bool_):
    return np.array(True, np.bool_)
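# Editor's illustrative sketch (not part of the original sources): the values
# returned by _get_max_identity / _get_min_identity are intended to be neutral
# elements for max / min reductions, i.e. max(x, identity) == x and
# min(x, identity) == x for every representable x of the given dtype. A
# minimal self-contained check using plain NumPy, mirroring the logic above:
import numpy as np

def _identity_check_sketch():
  for dtype in [np.float32, np.int32, np.uint8]:
    if np.issubdtype(dtype, np.inexact):
      max_id, min_id = np.array(-np.inf, dtype), np.array(np.inf, dtype)
    else:
      info = np.iinfo(dtype)
      max_id, min_id = np.array(info.min, dtype), np.array(info.max, dtype)
    x = np.arange(5, dtype=dtype)
    assert np.all(np.maximum(x, max_id) == x)
    assert np.all(np.minimum(x, min_id) == x)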
def testIsSubdtype(self):
  for t in scalar_types:
    self.assertTrue(dtypes.issubdtype(t, t))
    self.assertTrue(dtypes.issubdtype(np.dtype(t).type, t))
    self.assertTrue(dtypes.issubdtype(t, np.dtype(t).type))
    if t != jnp.bfloat16:
      for category in [np.generic, jnp.inexact, jnp.integer, jnp.signedinteger,
                       jnp.unsignedinteger, jnp.floating, jnp.complexfloating]:
        self.assertEqual(dtypes.issubdtype(t, category),
                         np.issubdtype(np.dtype(t).type, category))
        self.assertEqual(dtypes.issubdtype(t, category),
                         np.issubdtype(np.dtype(t).type, category))
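# Editor's illustrative sketch (assumes the public jax.dtypes / jax.numpy APIs
# exercised in the test above): bfloat16 is excluded from the equality check
# in testIsSubdtype because NumPy's type lattice does not know about bfloat16,
# so np.issubdtype and dtypes.issubdtype would disagree for that type.
import numpy as np
import jax.numpy as jnp
from jax import dtypes

def _bfloat16_subdtype_sketch():
  # JAX's issubdtype treats bfloat16 as a floating (hence inexact) type.
  assert dtypes.issubdtype(jnp.bfloat16, np.floating)
  assert dtypes.issubdtype(jnp.bfloat16, np.inexact)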
def _check_input_dtype_jacfwd(holomorphic, x):
  _check_arg(x)
  aval = core.get_aval(x)
  if holomorphic:
    if not (dtypes.issubdtype(aval.dtype, np.complexfloating) and
            not dtypes.issubdtype(aval.dtype, np.floating)):
      raise TypeError(
          "jacfwd with holomorphic=True requires inputs with complex dtype, "
          f"but got {aval.dtype.name}.")
  elif not dtypes.issubdtype(aval.dtype, np.floating):
    raise TypeError(
        "jacfwd requires real-valued inputs (input dtype that is "
        f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
        "For holomorphic differentiation, pass holomorphic=True. "
        "For differentiation of non-holomorphic functions involving complex "
        "inputs or integer inputs, use jax.jvp directly.")
def closure_convert(fun, in_tree, in_avals):
  if config.omnistaging_enabled:
    wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
  else:
    in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
    wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    with core.initial_style_staging():  # type: ignore
      jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
          wrapped_fun, in_pvals, instantiate=True, stage_out=False)  # type: ignore
  out_tree = out_tree()

  # We only want to closure convert for constants with respect to which we're
  # differentiating. As a proxy for that, we hoist consts with float dtype.
  # TODO(mattjj): revise this approach
  is_float = lambda c: dtypes.issubdtype(dtypes.dtype(c), jnp.inexact)
  (closure_consts, hoisted_consts), merge = partition_list(is_float, consts)
  num_consts = len(hoisted_consts)

  def converted_fun(y, t, *hconsts_args):
    hoisted_consts, args = split_list(hconsts_args, [num_consts])
    consts = merge(closure_consts, hoisted_consts)
    all_args, in_tree2 = tree_flatten((y, t, *args))
    assert in_tree == in_tree2
    out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
    return tree_unflatten(out_tree, out_flat)

  return converted_fun, hoisted_consts
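# Editor's sketch of the helper semantics closure_convert relies on. The names
# partition_list and split_list come from the surrounding JAX utilities; the
# implementations below are assumptions for illustration only, not the
# library's actual code. partition_list splits a list by a predicate and hands
# back a merge function that restores the original ordering; split_list peels
# off prefixes of the given sizes.
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")

def partition_list_sketch(pred: Callable[[T], bool], xs: Sequence[T]):
  which = [bool(pred(x)) for x in xs]
  falses = [x for x, w in zip(xs, which) if not w]
  trues = [x for x, w in zip(xs, which) if w]
  def merge(new_falses: Sequence[T], new_trues: Sequence[T]) -> List[T]:
    fs, ts = iter(new_falses), iter(new_trues)
    return [next(ts) if w else next(fs) for w in which]
  return (falses, trues), merge

def split_list_sketch(xs: Sequence[T], sizes: Sequence[int]) -> List[List[T]]:
  out, start = [], 0
  for size in sizes:
    out.append(list(xs[start:start + size]))
    start += size
  out.append(list(xs[start:]))
  return out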
def _check_output_dtype_jacfwd(holomorphic, x):
  aval = core.get_aval(x)
  if holomorphic:
    if not dtypes.issubdtype(aval.dtype, np.complexfloating):
      raise TypeError(
          "jacfwd with holomorphic=True requires outputs with complex dtype, "
          f"but got {aval.dtype.name}.")
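# Editor's usage illustration (assumes the public jax.jacfwd API that these
# checks back): a complex input is rejected unless holomorphic=True is passed,
# in which case both the input and the output must be complex.
import jax
import jax.numpy as jnp

def _jacfwd_dtype_example():
  f = lambda z: z * z
  z0 = jnp.complex64(1.0 + 2.0j)
  try:
    jax.jacfwd(f)(z0)  # TypeError: complex input with holomorphic=False
  except TypeError as e:
    print("rejected as expected:", e)
  dfdz = jax.jacfwd(f, holomorphic=True)(z0)
  print(dfdz)  # d(z*z)/dz = 2z = 2 + 4j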
def _psum_translation_rule(c, *args, replica_groups=None, platform=None):
  if platform in ("cpu", "tpu"):
    return _notuple_psum_translation_rule(c, *args,
                                          replica_groups=replica_groups)

  # XLA's tuple all-reduce doesn't support different dtypes in the same
  # allreduce. Instead, we perform one all-reduce for each argument input type.
  args_by_type = collections.defaultdict(lambda: ([], []))
  for i, arg in enumerate(args):
    indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()]
    indices.append(i)
    dtype_args.append(arg)

  # The outputs, in the original argument order.
  out = [None] * len(args)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
    is_complex = dtypes.issubdtype(dtype, onp.complexfloating)
    n = len(dtype_args)
    if is_complex:
      dtype_args = ([xops.Real(x) for x in dtype_args] +
                    [xops.Imag(x) for x in dtype_args])
    scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
    computation = xla.primitive_subcomputation(lax.add_p, scalar, scalar)
    all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation,
                                replica_groups_protos, None, None)
    if is_complex:
      xs = [xops.Complex(xops.GetTupleElement(all_reduce, i),
                         xops.GetTupleElement(all_reduce, n + i))
            for i in range(n)]
    else:
      xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)]
    for i, x in zip(indices, xs):
      out[i] = x
  return xops.Tuple(c, out)
def _translate(val):
  psum = partial(_allreduce_translation_rule, lax.add_p, c,
                 replica_groups=replica_groups)
  dtype = c.get_shape(val).numpy_dtype()
  if dtypes.issubdtype(dtype, onp.complexfloating):
    return xops.Complex(psum(xops.Real(val)), psum(xops.Imag(val)))
  else:
    return psum(val)
def _psum_translation_rule(c, val, replica_groups, backend=None):
  psum = partial(_allreduce_translation_rule, lax.add_p, c,
                 replica_groups=replica_groups, backend=backend)
  dtype = c.GetShape(val).numpy_dtype()
  if dtypes.issubdtype(dtype, onp.complexfloating):
    return c.Complex(psum(c.Real(val)), psum(c.Imag(val)))
  else:
    return psum(val)
def _translate(val):
  psum = partial(_allreduce_translation_rule, lax.add_p, c,
                 axis_name=axis_name, axis_env=axis_env,
                 axis_index_groups=axis_index_groups, platform=platform)
  dtype = c.get_shape(val).numpy_dtype()
  if dtypes.issubdtype(dtype, np.complexfloating):
    return xops.Complex(psum(xops.Real(val)), psum(xops.Imag(val)))
  else:
    return psum(val)
def test_div(self, harness: primitive_harness.Harness):
  dividend, divisor = harness.dyn_args_maker(self.rng())
  prim = harness.params["prim"]
  if dtypes.issubdtype(dividend.dtype, np.integer):
    if (prim is lax.div_p and
        np.any(divisor == np.array(0, dtype=divisor.dtype))):
      raise unittest.SkipTest(
          "Divisor contains a 0, and TF returns an error value in compiled "
          "mode instead of failing like in eager and graph mode for dtype "
          f"{divisor.dtype}")
  self.ConvertAndCompare(harness.dyn_fun, dividend, divisor)
def _allreduce_translation_rule(prim, c, *args, axis_name, axis_index_groups,
                                axis_env, platform):
  if platform in ("cpu", "tpu"):
    return _notuple_allreduce_translation_rule(
        prim, c, *args, axis_name=axis_name,
        axis_index_groups=axis_index_groups, axis_env=axis_env,
        platform=platform)

  # XLA's tuple all-reduce doesn't support different dtypes in the same
  # allreduce. Instead, we perform one all-reduce for each argument input type.
  args_by_type = collections.defaultdict(lambda: ([], []))
  for i, arg in enumerate(args):
    indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()]
    indices.append(i)
    dtype_args.append(arg)

  # The outputs, in the original argument order.
  out = [None] * len(args)
  replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
    is_complex = dtypes.issubdtype(dtype, np.complexfloating)
    n = len(dtype_args)
    if is_complex and prim is lax.add_p:
      # TODO(b/141575627): we handle complex-dtype sum-reduction directly as a
      # special case because it's not currently handled by XLA:GPU
      dtype_args = ([xops.Real(x) for x in dtype_args] +
                    [xops.Imag(x) for x in dtype_args])
    scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
    computation = xla.primitive_subcomputation(prim, scalar, scalar)
    all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation,
                                replica_groups_protos, None, None)
    if is_complex and prim is lax.add_p:
      xs = [xops.Complex(xops.GetTupleElement(all_reduce, i),
                         xops.GetTupleElement(all_reduce, n + i))
            for i in range(n)]
    else:
      xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)]
    for i, x in zip(indices, xs):
      out[i] = x
  return xops.Tuple(c, out)
def _notuple_allreduce_translation_rule(prim, c, *args, axis_name, axis_env,
                                        axis_index_groups, platform):
  def all_reduce(x):
    replica_groups_protos = xc.make_replica_groups(
        _replica_groups(axis_env, axis_name, axis_index_groups))
    scalar = ShapedArray((), c.get_shape(x).numpy_dtype())
    computation = xla.primitive_subcomputation(prim, scalar, scalar)
    return xops.AllReduce(x, computation, replica_groups_protos, None, None)

  if prim is not lax.add_p:
    outs = [all_reduce(x) for x in args]
  else:
    # TODO(b/141575627): we handle complex-dtype sum-reduction directly as a
    # special case because it's not currently handled by XLA:GPU
    outs = [xops.Complex(all_reduce(xops.Real(x)), all_reduce(xops.Imag(x)))
            if dtypes.issubdtype(c.get_shape(x).numpy_dtype(),
                                 np.complexfloating)
            else all_reduce(x) for x in args]
  return xops.Tuple(c, outs)
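# Editor's illustration of the decomposition used by the all-reduce rules
# above: a sum over complex values can be assembled from separate sums over
# the real and imaginary parts, since sum(a_k + i*b_k) = sum(a_k) + i*sum(b_k).
# A minimal NumPy check of that identity (illustration only):
import numpy as np

def _complex_sum_decomposition_check():
  rng = np.random.RandomState(0)
  xs = rng.randn(8) + 1j * rng.randn(8)
  direct = xs.sum()
  decomposed = xs.real.sum() + 1j * xs.imag.sum()
  assert np.allclose(direct, decomposed)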
class LaxVmapTest(jtu.JaxTestCase): def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng, rtol=None, atol=None): batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes) args = [ rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes) ] args_slice = args_slicer(args, bdims) ans = api.vmap(op, bdims)(*args) expected = onp.stack([op(*args_slice(i)) for i in range(bdim_size)]) self.assertAllClose(ans, expected, rtol=rtol, atol=atol) @parameterized.named_parameters( itertools.chain.from_iterable( jtu.cases_from_list( { "testcase_name": "{}_bdims={}".format( jtu.format_test_name_suffix( rec.op, shapes, itertools.repeat(dtype)), bdims), "op_name": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype, "bdims": bdims, "tol": rec.tol } for shape_group in compatible_shapes for shapes in CombosWithReplacement(shape_group, rec.nargs) for bdims in all_bdims(*shapes) for dtype in rec.dtypes) for rec in LAX_OPS)) def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol): rng = rng_factory(self.rng()) op = getattr(lax, op_name) self._CheckBatching(op, 10, bdims, shapes, [dtype] * len(shapes), rng, atol=tol, rtol=tol) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_" "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}" "_lhs_bdim={}_rhs_bdim={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums), feature_group_count, batch_group_count, lhs_bdim, rhs_bdim), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "strides": strides, "padding": padding, "lhs_dil": lhs_dil, "rhs_dil": rhs_dil, "rng_factory": rng_factory, "dimension_numbers": dim_nums, "perms": perms, "lhs_bdim": lhs_bdim, "rhs_bdim": rhs_bdim, "feature_group_count": feature_group_count, "batch_group_count": batch_group_count, } for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)]) for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in [ ( (b * batch_group_count, i * feature_group_count, 6, 7), # lhs_shape (j * batch_group_count * feature_group_count, i, 1, 2), # rhs_shape [(1, 1), (1, 2), (2, 1)], # strides [((0, 0), (0, 0)), ((1, 0), (0, 1)), ((0, -1), (0, 0))], # pads [(1, 1), (2, 1)], # lhs_dils [(1, 1), (2, 2)]) # rhs_dils for b, i, j in itertools.product([1, 2], repeat=3) ] for strides in all_strides for rhs_dil in rhs_dils for lhs_dil in lhs_dils for dtype in [onp.float32] for padding in all_pads for dim_nums, perms in [(("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])), (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])), (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))] for lhs_bdim in itertools.chain( [cast(Optional[int], None)], range(len(lhs_shape) + 1)) for rhs_bdim in itertools.chain( [cast(Optional[int], None)], range(len(rhs_shape) + 1)) if (lhs_bdim, rhs_bdim) != (None, None) for rng_factory in [jtu.rand_default])) def testConvGeneralDilatedBatching(self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil, dimension_numbers, perms, feature_group_count, batch_group_count, lhs_bdim, rhs_bdim, rng_factory): rng = rng_factory(self.rng()) tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3 # permute shapes to match dim_spec, scale by feature_group_count lhs_perm, rhs_perm = perms lhs_shape = list(onp.take(lhs_shape, lhs_perm)) rhs_shape = list(onp.take(rhs_shape, rhs_perm)) conv = 
partial(lax.conv_general_dilated, window_strides=strides, padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil, dimension_numbers=dimension_numbers, feature_group_count=feature_group_count, batch_group_count=batch_group_count, precision=lax.Precision.HIGHEST) self._CheckBatching(conv, 5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape), (dtype, dtype), rng, rtol=tol, atol=tol) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format( shape, from_dtype, to_dtype, bdims), "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype, "bdims": bdims, "rng_factory": rng_factory } for from_dtype, to_dtype in itertools.product( [onp.float32, onp.int32, "float32", "int32"], repeat=2) for shape in [(2, 3)] for bdims in all_bdims(shape) for rng_factory in [jtu.rand_default])) def testConvertElementType(self, shape, from_dtype, to_dtype, bdims, rng_factory): rng = rng_factory(self.rng()) op = lambda x: lax.convert_element_type(x, to_dtype) self._CheckBatching(op, 10, bdims, (shape, ), (from_dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format( shape, from_dtype, to_dtype, bdims), "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype, "bdims": bdims, "rng_factory": rng_factory } for from_dtype, to_dtype in itertools.product( [onp.float32, onp.int32, "float32", "int32"], repeat=2) for shape in [(2, 3)] for bdims in all_bdims(shape) for rng_factory in [jtu.rand_default])) def testBitcastElementType(self, shape, from_dtype, to_dtype, bdims, rng_factory): rng = rng_factory(self.rng()) op = lambda x: lax.bitcast_convert_type(x, to_dtype) self._CheckBatching(op, 10, bdims, (shape, ), (from_dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_min_shape={}_operand_shape={}_max_shape={}_bdims={}".format( jtu.format_shape_dtype_string(min_shape, dtype), jtu.format_shape_dtype_string(operand_shape, dtype), jtu.format_shape_dtype_string(max_shape, dtype), bdims), "min_shape": min_shape, "operand_shape": operand_shape, "max_shape": max_shape, "dtype": dtype, "bdims": bdims, "rng_factory": rng_factory } for min_shape, operand_shape, max_shape in [ [(), (2, 3), ()], [(2, 3), (2, 3), ()], [(), (2, 3), (2, 3)], [(2, 3), (2, 3), (2, 3)], ] for dtype in default_dtypes for bdims in all_bdims(min_shape, operand_shape, max_shape) for rng_factory in [jtu.rand_default])) def testClamp(self, min_shape, operand_shape, max_shape, dtype, bdims, rng_factory): rng = rng_factory(self.rng()) raise SkipTest( "batching rule for clamp not implemented") # TODO(mattj) shapes = [min_shape, operand_shape, max_shape] self._CheckBatching(lax.clamp, 10, bdims, shapes, [dtype] * 3, rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_lhs_shape={}_rhs_shape={}_bdims={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), bdims), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "bdims": bdims, "rng_factory": rng_factory } for lhs_shape in [(3, ), (4, 3)] for rhs_shape in [(3, ), (3, 6)] for bdims in all_bdims(lhs_shape, rhs_shape) for dtype in default_dtypes for rng_factory in [jtu.rand_default])) def testDot(self, lhs_shape, rhs_shape, dtype, bdims, rng_factory): rng = rng_factory(self.rng()) op = partial(lax.dot, precision=lax.Precision.HIGHEST) self._CheckBatching(op, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype), rng, rtol={ onp.float16: 
5e-2, onp.float64: 5e-14 }) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}_bdims={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), lhs_contracting, rhs_contracting, bdims), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "lhs_contracting": lhs_contracting, "rhs_contracting": rhs_contracting, "bdims": bdims, "rng_factory": rng_factory } for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [ [(3, 5), (2, 5), [1], [1]], [(5, 3), (5, 2), [0], [0]], [(5, 3, 2), (5, 2, 4), [0], [0]], [(5, 3, 2), (5, 2, 4), [0, 2], [0, 1]], [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]], [(3, 2), (2, 4), [1], [0]], ] for bdims in all_bdims(lhs_shape, rhs_shape) for dtype in default_dtypes for rng_factory in [jtu.rand_small])) def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype, lhs_contracting, rhs_contracting, bdims, rng_factory): rng = rng_factory(self.rng()) dimension_numbers = ((lhs_contracting, rhs_contracting), ([], [])) dot = partial(lax.dot_general, dimension_numbers=dimension_numbers) self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_bdims={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), dimension_numbers, bdims), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "dimension_numbers": dimension_numbers, "bdims": bdims, "rng_factory": rng_factory } for lhs_shape, rhs_shape, dimension_numbers in [ ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))), ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))), ] for bdims in all_bdims(lhs_shape, rhs_shape) for dtype in default_dtypes for rng_factory in [jtu.rand_small])) def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype, dimension_numbers, bdims, rng_factory): rng = rng_factory(self.rng()) dot = partial(lax.dot_general, dimension_numbers=dimension_numbers) self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_dtype={}_broadcast_sizes={}_bdims={}".format( shape, onp.dtype(dtype).name, broadcast_sizes, bdims), "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes, "bdims": bdims, "rng_factory": rng_factory } for shape in [(), (2, 3)] for dtype in default_dtypes for broadcast_sizes in [(), (2, ), (1, 2)] for bdims in all_bdims(shape) for rng_factory in [jtu.rand_default])) def testBroadcast(self, shape, dtype, broadcast_sizes, bdims, rng_factory): rng = rng_factory(self.rng()) op = lambda x: lax.broadcast(x, broadcast_sizes) self._CheckBatching(op, 5, bdims, (shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_inshape={}_outshape={}_bcdims={}_bdims={}".format( jtu.format_shape_dtype_string(inshape, dtype), outshape, broadcast_dimensions, bdims), "inshape": inshape, "dtype": dtype, "outshape": outshape, "dimensions": broadcast_dimensions, "bdims": bdims, "rng_factory": rng_factory } for inshape, outshape, broadcast_dimensions in [ ([2], [2, 2], [0]), ([2], [2, 2], [1]), ([2], [2, 3], [0]), ([], [2, 3], []), ] for dtype in default_dtypes for bdims in all_bdims(inshape) for rng_factory in [jtu.rand_default])) def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims, 
rng_factory): rng = rng_factory(self.rng()) raise SkipTest("this test has failures in some cases") # TODO(mattjj) op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions) self._CheckBatching(op, 5, bdims, (inshape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_inshape={}_dimensions={}_bdims={}".format( jtu.format_shape_dtype_string(arg_shape, onp.float32), dimensions, bdims), "arg_shape": arg_shape, "dimensions": dimensions, "bdims": bdims, "rng_factory": rng_factory } for arg_shape, dimensions in [ [(1, ), (0, )], [(1, ), (-1, )], [(2, 1, 4), (1, )], [(2, 1, 4), (-2, )], [(2, 1, 3, 1), (1, )], [(2, 1, 3, 1), (1, 3)], [(2, 1, 3, 1), (3, )], [(2, 1, 3, 1), (1, -1)], ] for bdims in all_bdims(arg_shape) for rng_factory in [jtu.rand_default])) def testSqueeze(self, arg_shape, dimensions, bdims, rng_factory): dtype = onp.float32 rng = rng_factory(self.rng()) op = lambda x: lax.squeeze(x, dimensions) self._CheckBatching(op, 10, bdims, (arg_shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_inshape={}_outshape={}_dims={}_bdims={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), jtu.format_shape_dtype_string(out_shape, dtype), dimensions, bdims), "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, "dimensions": dimensions, "bdims": bdims, "rng_factory": rng_factory } for dtype in default_dtypes for arg_shape, dimensions, out_shape in [[(3, 4), None, ( 12, )], [(2, 1, 4), None, (8, )], [(2, 2, 4), None, ( 2, 8)], [(2, 2, 4), (0, 1, 2), (2, 8)], [(2, 2, 4), (1, 0, 2), ( 8, 2)], [(2, 2, 4), (2, 1, 0), (4, 2, 2)]] for bdims in all_bdims(arg_shape) for rng_factory in [jtu.rand_default])) def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims, rng_factory): rng = rng_factory(self.rng()) op = lambda x: lax.reshape(x, out_shape, dimensions=dimensions) self._CheckBatching(op, 10, bdims, (arg_shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_inshape={}_pads={}_bdims={}".format( jtu.format_shape_dtype_string(shape, dtype), pads, bdims), "shape": shape, "dtype": dtype, "pads": pads, "rng_factory": jtu.rand_small, "bdims": bdims } for shape in [(2, 3)] for bdims in all_bdims(shape) for dtype in default_dtypes for pads in [[(1, 2, 1), (0, 1, 0)]])) def testPad(self, shape, dtype, pads, bdims, rng_factory): rng = rng_factory(self.rng()) fun = lambda operand: lax.pad(operand, onp.array(0, dtype), pads) self._CheckBatching(fun, 5, bdims, (shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_predshape={}_argshapes={}_bdims={}".format( jtu.format_shape_dtype_string(pred_shape, onp.bool_), jtu.format_shape_dtype_string(arg_shape, arg_dtype), bdims), "pred_shape": pred_shape, "arg_shape": arg_shape, "arg_dtype": arg_dtype, "bdims": bdims, "rng_factory": rng_factory } for arg_shape in [(), (3, ), (2, 3)] for pred_shape in ([(), arg_shape] if arg_shape else [()]) for bdims in all_bdims(pred_shape, arg_shape, arg_shape) for arg_dtype in default_dtypes for rng_factory in [jtu.rand_default])) def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims, rng_factory): rng = rng_factory(self.rng()) op = lambda c, x, y: lax.select(c < 0, x, y) self._CheckBatching(op, 5, bdims, ( pred_shape, arg_shape, arg_shape, ), (onp.bool_, arg_dtype, arg_dtype), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_start_indices={}_limit_indices={}_strides={}_bdims={}". 
format(jtu.format_shape_dtype_string(shape, dtype), start_indices, limit_indices, strides, bdims), "shape": shape, "dtype": dtype, "starts": start_indices, "limits": limit_indices, "strides": strides, "bdims": bdims, "rng_factory": rng_factory } for shape, start_indices, limit_indices, strides in [ [(3, ), (1, ), (2, ), None], [(7, ), (4, ), (7, ), None], [(5, ), (1, ), (5, ), (2, )], [(8, ), (1, ), (6, ), (2, )], [(5, 3), (1, 1), (3, 2), None], [(5, 3), (1, 1), (3, 1), None], [(7, 5, 3), (4, 0, 1), (7, 1, 3), None], [(5, 3), (1, 1), (2, 1), (1, 1)], [(5, 3), (1, 1), (5, 3), (2, 1)], ] for bdims in all_bdims(shape) for dtype in default_dtypes for rng_factory in [jtu.rand_default])) def testSlice(self, shape, dtype, starts, limits, strides, bdims, rng_factory): rng = rng_factory(self.rng()) op = lambda x: lax.slice(x, starts, limits, strides) self._CheckBatching(op, 5, bdims, (shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_perm={}_bdims={}".format( jtu.format_shape_dtype_string(shape, dtype), perm, bdims), "shape": shape, "dtype": dtype, "perm": perm, "bdims": bdims, "rng_factory": rng_factory } for shape, perm in [ [(3, 4), (1, 0)], [(3, 4), (0, 1)], [(3, 4, 5), (2, 1, 0)], [(3, 4, 5), (1, 0, 2)], ] for bdims in all_bdims(shape) for dtype in default_dtypes for rng_factory in [jtu.rand_default])) def testTranspose(self, shape, dtype, perm, bdims, rng_factory): rng = rng_factory(self.rng()) op = lambda x: lax.transpose(x, perm) self._CheckBatching(op, 5, bdims, (shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_op={}_inshape={}_reducedims={}_initval={}_bdims={}".format( op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims, init_val, bdims), "op": op, "init_val": init_val, "shape": shape, "dtype": dtype, "dims": dims, "bdims": bdims, "rng_factory": rng_factory } for init_val, op, dtypes in [ (0, lax.add, default_dtypes), (1, lax.mul, default_dtypes), (0, lax.max, all_dtypes), # non-monoidal (-onp.inf, lax.max, float_dtypes), (dtypes.iinfo(onp.int32).min, lax.max, [onp.int32]), (dtypes.iinfo(onp.int64).min, lax.max, [onp.int64]), (dtypes.iinfo(onp.uint32).min, lax.max, [onp.uint32]), (dtypes.iinfo(onp.uint64).min, lax.max, [onp.uint64]), (onp.inf, lax.min, float_dtypes), (dtypes.iinfo(onp.int32).max, lax.min, [onp.int32]), (dtypes.iinfo(onp.int64).max, lax.min, [onp.int64]), (dtypes.iinfo(onp.uint32).max, lax.min, [onp.uint32]), (dtypes.iinfo(onp.uint64).max, lax.min, [onp.uint64]), ] for dtype in dtypes for shape, dims in [[(3, 4, 5), ( 0, )], [(3, 4, 5), (1, 2)], [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]] for bdims in all_bdims(shape) for rng_factory in [jtu.rand_small])) def testReduce(self, op, init_val, shape, dtype, dims, bdims, rng_factory): rng = rng_factory(self.rng()) init_val = onp.asarray(init_val, dtype=dtype) fun = lambda operand: lax.reduce(operand, init_val, op, dims) self._CheckBatching(fun, 5, bdims, (shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_op={}_dtype={}_padding={}".format(op.__name__, onp.dtype(dtype).name, padding), "op": op, "init_val": init_val, "dtype": dtype, "padding": padding, "rng_factory": rng_factory } for init_val, op, dtypes in [ (0, lax.add, [onp.float32]), (-onp.inf, lax.max, [onp.float32]), (onp.inf, lax.min, [onp.float32]), ] for dtype in dtypes for padding in ["VALID", "SAME"] for rng_factory in [jtu.rand_small])) def testReduceWindow(self, op, init_val, dtype, padding, 
rng_factory): rng = rng_factory(self.rng()) init_val = onp.asarray(init_val, dtype=dtype) all_configs = itertools.chain( itertools.product([(4, 6)], [(2, 1), (1, 2)], [(1, 1), (2, 1), (1, 2)]), itertools.product([(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)], [(1, 2, 2, 1), (1, 1, 1, 1)])) def fun(operand): return lax.reduce_window(operand, init_val, op, dims, strides, padding) for shape, dims, strides in all_configs: for bdims in all_bdims(shape): self._CheckBatching(fun, 3, bdims, (shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_op={}_shape={}_axis={}_bdims={}".format( op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis, bdims), "op": op, "shape": shape, "dtype": dtype, "bdims": bdims, "axis": axis, "rng_factory": rng_factory } for op, types in [ (lax.cumsum, [onp.float32, onp.float64]), (lax.cumprod, [onp.float32, onp.float64]), ] for dtype in types for shape in [[10], [3, 4, 5]] for axis in range(len(shape)) for bdims in all_bdims(shape) for rng_factory in [ jtu.rand_default if dtypes.issubdtype(dtype, onp.integer ) else jtu.rand_small ])) def testCumulativeReduce(self, op, shape, dtype, axis, bdims, rng_factory): rng = rng_factory(self.rng()) self._CheckBatching(partial(op, axis=axis), 7, bdims, (shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_dtype={}_padding={}".format(onp.dtype(dtype).name, padding), "dtype": dtype, "padding": padding, "rng_factory": rng_factory } for dtype in float_dtypes for padding in ["VALID", "SAME"] for rng_factory in [jtu.rand_small])) @jtu.skip_on_flag("jax_skip_slow_tests", True) @jtu.ignore_warning(message="Using reduced precision for gradient.*") def testSelectAndGatherAdd(self, dtype, padding, rng_factory): if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16: raise SkipTest( "bfloat16 _select_and_gather_add doesn't work on tpu") rng = rng_factory(self.rng()) all_configs = itertools.chain( itertools.product([(4, 6)], [(2, 1), (1, 2)], [(1, 1), (2, 1), (1, 2)]), itertools.product([(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)], [(1, 2, 2, 1), (1, 1, 1, 1)])) def fun(operand, tangents): return lax._select_and_gather_add(operand, tangents, lax.ge_p, dims, strides, padding) for shape, dims, strides in all_configs: for bdims in all_bdims(shape, shape): self._CheckBatching(fun, 3, bdims, (shape, shape), (dtype, dtype), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_bdims={}_fft_ndims={}".format(shape, bdims, fft_ndims), "shape": shape, "bdims": bdims, "fft_ndims": fft_ndims, "rng_factory": rng_factory } for shape in [(5, ), (3, 4, 5), (2, 3, 4, 5)] for bdims in all_bdims(shape) for fft_ndims in range(0, min(3, len(shape)) + 1) for rng_factory in [jtu.rand_default])) @jtu.skip_on_devices("tpu") # TODO(b/137993701): unimplemented cases. 
def testFft(self, fft_ndims, shape, bdims, rng_factory): rng = rng_factory(self.rng()) ndims = len(shape) axes = range(ndims - fft_ndims, ndims) fft_lengths = [shape[axis] for axis in axes] op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths) self._CheckBatching(op, 5, bdims, [shape], [onp.complex64], rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}_bdims={}".format( jtu.format_shape_dtype_string(shape, dtype), idxs, dnums, slice_sizes, bdims), "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "bdims": bdims } for dtype in all_dtypes for shape, idxs, dnums, slice_sizes in [ ((5, ), onp.array([[0], [2]]), lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, ), start_index_map=(0, )), (1, )), ((10, ), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(), start_index_map=(0, )), (2, )), (( 10, 5, ), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, )), (1, 3)), ((10, 5), onp.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(offset_dims=(1, ), collapsed_slice_dims=(0, ), start_index_map=(0, 1)), (1, 3)), ] for bdims in all_bdims(shape, idxs.shape))) def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims): fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) self._CheckBatching(fun, 5, bdims, [shape, idxs.shape], [dtype, idxs.dtype], jtu.rand_default(self.rng())) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_shape={}_idxs={}_update={}_dnums={}_bdims={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), idxs, update_shape, dnums, bdims), "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, "update_shape": update_shape, "dnums": dnums, "bdims": bdims } for dtype in float_dtypes for arg_shape, idxs, update_shape, dnums in [ ((5, ), onp.array([[0], [2]]), (2, ), lax.ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, ), scatter_dims_to_operand_dims=( 0, ))), ((10, ), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(update_window_dims=(1, ), inserted_window_dims=(), scatter_dims_to_operand_dims=( 0, ))), (( 10, 5, ), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(update_window_dims=(1, ), inserted_window_dims=(0, ), scatter_dims_to_operand_dims=( 0, ))), ] for bdims in all_bdims(arg_shape, idxs.shape, update_shape))) def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims): fun = partial(lax.scatter_add, dimension_numbers=dnums) self._CheckBatching(fun, 5, bdims, [arg_shape, idxs.shape, update_shape], [dtype, idxs.dtype, dtype], jtu.rand_default(self.rng()), rtol={onp.float16: 5e-3}) def testShapeUsesBuiltinInt(self): x = lax.iota(onp.int32, 3) + 1 self.assertIsInstance(x.shape[0], int) # not np.int64 def testBroadcastShapesReturnsPythonInts(self): shape1, shape2 = (1, 2, 3), (2, 3) out_shape = lax.broadcast_shapes(shape1, shape2) self.assertTrue(all(type(s) is int for s in out_shape)) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_shape={}_k={}_bdims={}".format( jtu.format_shape_dtype_string(shape, dtype), k, bdims), "shape": shape, "dtype": dtype, "k": k, "bdims": bdims, "rng_factory": rng_factory } for shape in [(4, ), (3, 5, 3)] for k in [1, 3] for bdims in all_bdims(shape) # TODO(b/155170120): test with repeats once the XLA:CPU stable top_k bug is fixed: # The top_k indices 
for integer arrays with identical entries won't match between # vmap'd version and manual reference, so only test unique integer arrays for int_dtypes. # Note also that we chose 3 * 5 * 3 * 5 such that it fits in the range of # values a bfloat16 can represent exactly to avoid ties. for dtype, rng_factory in itertools.chain( unsafe_zip(float_dtypes + int_dtypes, itertools.repeat(jtu.rand_unique_int)))) ) def testTopK(self, shape, dtype, k, bdims, rng_factory): rng = rng_factory(self.rng()) # _CheckBatching doesn't work with tuple outputs, so test outputs separately. op1 = lambda x: lax.top_k(x, k=k)[0] self._CheckBatching(op1, 5, bdims, (shape, ), (dtype, ), rng) op2 = lambda x: lax.top_k(x, k=k)[1] self._CheckBatching(op2, 5, bdims, (shape, ), (dtype, ), rng) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_dimension={}_arity={}_bdims={}_isstable={}".format( jtu.format_shape_dtype_string(shape, onp.float32), dimension, arity, bdims, is_stable), "shape": shape, "dimension": dimension, "arity": arity, "bdims": bdims, "is_stable": is_stable } for shape in [(2, 3)] for dimension in [0, 1] for arity in range(3) for bdims in all_bdims(*((shape, ) * arity)) for is_stable in [False, True])) def testSort(self, shape, dimension, arity, bdims, is_stable): rng = jtu.rand_default(self.rng()) if arity == 1: fun = partial(lax.sort, dimension=dimension) self._CheckBatching(fun, 5, bdims, (shape, ) * arity, (onp.float32, ) * arity, rng) else: for i in range(arity): fun = lambda *args, i=i: lax.sort( args, dimension=dimension, is_stable=is_stable)[i] self._CheckBatching(fun, 5, bdims, (shape, ) * arity, (onp.float32, ) * arity, rng)
def categorize(prim: core.Primitive, *args, **kwargs) \ -> List[Limitation]: """ Given a primitive and a set of parameters one would like to pass to it, categorize identifies the potential limitations the call would encounter when converted to TF through jax2tf. Args: prim: the primitive to call. args: the arguments to pass to prim. kwargs: the keyword arguments to pass to prim. Returns: A list of limitations """ limitations: List[Limitation] = [] all_devices = ["CPU", "GPU", "TPU"] def _report_failure(error_type: str, msg: str, affected_dtype: Optional[NpDType] = None, devs: Sequence[str] = all_devices) -> None: affected_dtypes = ( tuple([affected_dtype]) if affected_dtype is not None else tuple()) limitations.append(Limitation(prim.name, error_type, msg, affected_dtypes, tuple(devs))) def tf_unimpl(np_dtype: Optional[NpDType] = None, additional_msg: Optional[str] = None, devs: Sequence[str] = all_devices) -> None: msg = "Primitive is unimplemented in TF" if additional_msg: msg += '; ' + additional_msg _report_failure(CATEGORY_MISSING_TF_SUPPORT, msg, np_dtype, devs=devs) def tf_possible_incorrect(np_dtype: Optional[NpDType] = None, msg: str = "", devs: Sequence[str] = all_devices) -> None: _report_failure(CATEGORY_POSSIBLE_INCORRECT_RESULTS, msg, np_dtype, devs=devs) def _to_np_dtype(dtype) -> NpDType: try: dtype = to_jax_dtype(dtype) except: pass return np.dtype(dtype) if args and args[0] is not core.unit: np_dtype = _to_np_dtype(args[0].dtype) else: np_dtype = None if prim is lax.regularized_incomplete_beta_p: if np_dtype in [np.float16, dtypes.bfloat16]: tf_unimpl(np_dtype) if prim in [lax.reduce_min_p, lax.reduce_max_p]: if np_dtype in [np.complex64, np.complex128]: tf_unimpl(np_dtype) if prim in [lax.min_p, lax.max_p, lax.reduce_window_min_p, lax.reduce_window_max_p]: if np_dtype in [np.bool_, np.int8, np.uint16, np.uint32, np.uint64, np.complex64, np.complex128]: tf_unimpl(np_dtype) if prim is lax.div_p: if np_dtype in [np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16]: tf_unimpl(np_dtype) elif dtypes.issubdtype(np_dtype, np.integer): tf_unimpl(np_dtype, additional_msg=("integer division fails if the " "divisor contains a 0")) if prim is lax.rem_p: if np_dtype in [np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16, np.float16]: tf_unimpl(np_dtype) elif dtypes.issubdtype(np_dtype, np.integer): tf_unimpl(np_dtype, additional_msg=("integer division fails if the " "divisor contains a 0")) if prim is lax.atan2_p and np_dtype in [np.float16, dtypes.bfloat16]: # b/158006398: TF kernels are missing for 'rem' and 'atan2' tf_unimpl(np_dtype) if prim is lax.nextafter_p: if np_dtype in [np.float16, dtypes.bfloat16]: tf_unimpl(np_dtype) if prim is lax.linalg.cholesky_p: if np_dtype in [np.complex64, np.complex128]: # See https://github.com/google/jax/pull/3775#issuecomment-659407824; # experimental_compile=True breaks for complex types. tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled " "mode (experimental_compile=True))")) if prim is lax.linalg.qr_p: if np_dtype in [np.complex64, np.complex128]: # See https://github.com/google/jax/pull/3775#issuecomment-659407824; # experimental_compile=True breaks for complex types. 
tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled " "mode (experimental_compile=True))")) if prim is lax.linalg.eig_p: tf_unimpl(additional_msg=("this is a problem only in compiled mode " "(experimental_compile=True))")) compute_left_eigenvectors = kwargs['compute_left_eigenvectors'] compute_right_eigenvectors = kwargs['compute_right_eigenvectors'] if compute_left_eigenvectors and compute_right_eigenvectors: tf_unimpl(additional_msg=("it is not possible to request both left and " "right eigenvectors for now")) if prim is lax.linalg.eigh_p: if np_dtype in [np.complex64, np.complex128]: # See https://github.com/google/jax/pull/3775#issuecomment-659407824; # experimental_compile=True breaks for complex types. tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled " "mode (experimental_compile=True))")) if prim is lax.linalg.lu_p: if np_dtype == np.complex64: tf_unimpl(np_dtype, devs=["TPU"]) if prim is lax.linalg.triangular_solve_p: if np_dtype in [dtypes.bfloat16, np.float16]: tf_unimpl(np_dtype) if prim is lax.linalg.svd_p: if np_dtype in [dtypes.bfloat16]: # TODO: SVD on TPU for bfloat16 seems to work for JAX but fails for TF tf_unimpl(np_dtype, devs=["TPU"]) elif np_dtype in [np.complex64, np.complex128]: # TODO: on CPU and GPU "No registered 'Svd' OpKernel for XLA_CPU_JIT # devices". Works on JAX because JAX uses a custom implementation. # There exists a XlaSvd operation that could replace tf.linalg.svd in # these cases but complex numbers support is not implemented in XLA yet, # and the API of XlaSvd is different than the one in JAX/TF, which also # limits its useability (e.g. no full_matrices argument, …). additional_msg = ("this works on JAX because JAX uses a custom " "implementation") tf_unimpl(np_dtype, additional_msg=additional_msg, devs=["CPU", "GPU"]) if prim is lax.select_and_scatter_add_p: if np_dtype in [np.uint64, np.uint32, np.uint16]: tf_unimpl(np_dtype) if prim is lax.select_and_gather_add_p: # TODO: the conversion is only supported for float16/float32 on CPU/GPU, # and float16 on TPU. This is because we do not implement a precision # reduction in the case where packing 2 n-bit values together results in # more than the maximum number of bits allowed on the platform (64 on # CPU/GPU, 32 on TPU). This could be fixed by implementing a variadic # reduce_window in tfxla, or we can require the user to reduce the # precision of their arrays manually based on the platform they run on. devices_and_max_bits = [ (["CPU", "GPU"], 64) , (["TPU"], 32) ] for devs, max_bits in devices_and_max_bits: if dtypes.finfo(np_dtype).bits * 2 > max_bits: # TODO: getting an exception "XLA encountered an HLO for which this # rewriting is not implemented" tf_unimpl(np_dtype, devs=devs) if prim in [lax.add_p, lax.reduce_window_sum_p]: if np_dtype in [np.uint16, np.uint32, np.uint64]: # TODO(bchetioui): tf.math.add is not defined for the above types. tf_unimpl(np_dtype) if prim is lax.mul_p: if np_dtype in [np.uint32, np.uint64]: # TODO(bchetioui): tf.math.multiply is not defined for the above types. 
tf_unimpl(np_dtype) if prim is lax.sort_p: if np_dtype in [np.complex64, np.complex128]: tf_unimpl(np_dtype) if np_dtype == np.bool_ and len(args) == 2: tf_unimpl(np_dtype, additional_msg=( "sorting 2 arrays where the first one is an array of booleans is not " "supported for XlaSort")) if kwargs["is_stable"]: tf_unimpl(additional_msg="stable sort not implemented for XlaSort") if kwargs["dimension"] != len(np.shape(args[0])) - 1: tf_unimpl(additional_msg="only sorting on last dimension is supported " "for XlaSort") if len(args) > 2: tf_unimpl(additional_msg=( "sorting more than 2 arrays is not supported for XlaSort")) if prim is lax.population_count_p: if np_dtype in [np.uint32, np.uint64]: tf_unimpl(np_dtype) if prim is lax.clamp_p: if np_dtype in [np.int8, np.uint16, np.uint32, np.uint64]: tf_unimpl(np_dtype) # Testing with matmul (TODO: comment out and test without matmul) if prim is lax.dot_general_p: np_dtype = _to_np_dtype(args[0].dtype) if np_dtype in [np.bool, np.uint8, np.uint16, np.uint32, np.uint64, np.int8]: tf_unimpl(np_dtype) elif np_dtype == np.int16: # TODO(bchetioui): the path using 'einsum' is not compatible with int16 # arguments on CPU/GPU, while the one using 'matmul' is (but not in # compiled mode). tf_unimpl(np_dtype, additional_msg=("only cases representable as 2D " "matrix multiplication can be " "converted properly"), devs=['CPU', 'GPU']) tf_unimpl(np_dtype, devs=['TPU']) elif np_dtype in [np.int16, np.int64]: devs = ['CPU'] if np_dtype == np.int16 else ['CPU', 'GPU'] tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled " "mode (experimental_compile=True))"), devs=devs) if prim is lax.conv_general_dilated_p: batch_group_count = kwargs['batch_group_count'] if batch_group_count != 1: tf_unimpl(additional_msg="batch_group_count != 1 unsupported") if np_dtype in [np.complex64, np.complex128]: tf_unimpl(np_dtype, additional_msg="likely bug in the HLO -> LLVM IR " "lowering of XlaConv") if prim in [lax.acosh_p, lax.asinh_p, lax.atanh_p, lax.bessel_i0e_p, lax.bessel_i1e_p, lax.digamma_p, lax.erf_p, lax.erf_inv_p, lax.erfc_p, lax.lgamma_p, lax.round_p, lax.rsqrt_p]: if np_dtype == dtypes.bfloat16: tf_unimpl(np_dtype, devs=["CPU", "GPU"]) if prim is lax.convert_element_type_p: if np_dtype == dtypes.bfloat16: tf_unimpl(np_dtype, devs=["CPU", "GPU"]) if prim in [lax.sinh_p, lax.cosh_p, lax.atanh_p, lax.asinh_p, lax.acosh_p, lax.erf_inv_p]: if np_dtype == np.float16: # b/158006398: float16 support missing from the kernel of the above # operations. tf_unimpl(np_dtype) if prim in [lax.le_p, lax.lt_p, lax.ge_p, lax.gt_p]: if np_dtype in [np.bool_, np.uint16, np.uint32, np.uint64]: tf_unimpl(np_dtype) if prim is lax.fft_p: if np_dtype in [np.float64, np.complex128]: tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled " "mode (experimental_compile=True))")) if prim is lax.top_k_p: if np_dtype in [np.float64, np.int64, np.uint64]: tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled " "mode (experimental_compile=True))")) return limitations
class LaxAutodiffTest(jtu.JaxTestCase): @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix( rec.name, shapes, itertools.repeat(dtype)), "op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype, "order": rec.order, "tol": rec.tol} for shape_group in compatible_shapes for shapes in CombosWithReplacement(shape_group, rec.nargs) for dtype in rec.dtypes) for rec in LAX_GRAD_OPS)) def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol): rng = rng_factory(self.rng()) if jtu.device_under_test() == "tpu" and op is lax.pow: raise SkipTest("pow grad imprecise on tpu") tol = jtu.join_tolerance(1e-1, tol) if num_float_bits(dtype) == 32 else tol args = tuple(rng(shape, dtype) for shape in shapes) check_grads(op, args, order, ["fwd", "rev"], tol, tol) @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value), "op": rec.op, "special_value": special_value, "tol": rec.tol} for special_value in rec.values) for rec in LAX_GRAD_SPECIAL_VALUE_TESTS)) def testOpGradSpecialValue(self, op, special_value, tol): check_grads(op, (special_value,), 2, ["fwd", "rev"], rtol=tol, atol=tol) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_from_dtype={}_to_dtype={}".format( jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)), "from_dtype": from_dtype, "to_dtype": to_dtype, "rng_factory": rng_factory} for from_dtype, to_dtype in itertools.product( float_dtypes + complex_dtypes, repeat=2) for rng_factory in [jtu.rand_default])) def testConvertElementTypeGrad(self, from_dtype, to_dtype, rng_factory): rng = rng_factory(self.rng()) tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance), jtu.tolerance(from_dtype, jtu.default_gradient_tolerance)) args = (rng((2, 3), from_dtype),) convert_element_type = lambda x: lax.convert_element_type(x, to_dtype) convert_element_type = jtu.ignore_warning(category=onp.ComplexWarning)( convert_element_type) check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format( jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng_factory": rng_factory} for shape in [(), (2, 3)] for dtype in grad_float_dtypes for rng_factory in [jtu.rand_default])) def testClampGrad(self, shape, dtype, rng_factory): rng = rng_factory(self.rng()) operand = rng(shape, dtype) low = operand - dtype(10) high = operand + dtype(10) # Avoids points near the boundary where the gradient may be inaccurate. 
check_grads(lax.clamp, (operand, low, high), 2, ["fwd", "rev"], eps=1e-2) check_grads(lax.clamp, (low, operand, high), 2, ["fwd", "rev"], eps=1e-2) check_grads(lax.clamp, (low, high, operand), 2, ["fwd", "rev"], eps=1e-2) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format( dim, ",".join(str(d) for d in base_shape), onp.dtype(dtype).name, num_arrs), "dim": dim, "base_shape": base_shape, "dtype": dtype, "num_arrs": num_arrs, "rng_factory": rng_factory} for num_arrs in [3] for dtype in float_dtypes for base_shape in [(4,), (3, 4), (2, 3, 4)] for dim in range(len(base_shape)) for rng_factory in [jtu.rand_default])) def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs, rng_factory): rng = rng_factory(self.rng()) shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:] for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))] operands = tuple(rng(shape, dtype) for shape in shapes) concatenate = lambda *args: lax.concatenate(args, dim) check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "strides": strides, "padding": padding, "rng_factory": rng_factory,} for lhs_shape, rhs_shape, all_strides in itertools.chain( [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)]) for b, i, j in itertools.product([2, 3], repeat=3)], [((4, 2, 1), (3, 2, 1), [(1,)])]) for strides in all_strides for dtype in float_dtypes for padding in ["VALID", "SAME"] for rng_factory in [jtu.rand_small])) def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, rng_factory): rng = rng_factory(self.rng()) lhs = rng(lhs_shape, dtype) rhs = rng(rhs_shape, dtype) conv = partial(lax.conv, window_strides=strides, padding=padding, precision=lax.Precision.HIGHEST) check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"], atol=1e-2, rtol=1e-2) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_" "rhs_dilation={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, lhs_dil, rhs_dil), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "strides": strides, "padding": padding, "lhs_dil": lhs_dil, "rhs_dil": rhs_dil, "rng_factory": rng_factory} for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in itertools.chain( [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)], [((0, 0), (0, 0)), ((-1, 0), (0, -1)), ((1, 0), (0, 1))], [(1, 1), (2, 1)], [(1, 1)]) for b, i, j in itertools.product([2, 3], repeat=3)], [((4, 2, 1), (3, 2, 1), [(1,)], [((1, 1),), ((0, 0),)], [(1,), (2,)], [(1,), (2,)])]) for strides in all_strides for rhs_dil in rhs_dils for lhs_dil in lhs_dils for dtype in float_dtypes for padding in all_pads for rng_factory in [jtu.rand_small])) def testConvWithGeneralPaddingGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil, rng_factory): rng = rng_factory(self.rng()) lhs = rng(lhs_shape, dtype) rhs = rng(rhs_shape, dtype) conv = partial(lax.conv_with_general_padding, window_strides=strides, padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil, precision=lax.Precision.HIGHEST) 
check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"], atol=1e-2, rtol=1e-2) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_" "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums), feature_group_count, batch_group_count), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "strides": strides, "padding": padding, "lhs_dil": lhs_dil, "rhs_dil": rhs_dil, "rng_factory": rng_factory, "dimension_numbers": dim_nums, "perms": perms, "feature_group_count": feature_group_count, "batch_group_count": batch_group_count} for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)]) for lhs_shapes, rhs_shape, all_strides, lhs_dils, rhs_dils in [ ([(b * batch_group_count, i * feature_group_count, 6, 7), (b * batch_group_count, i * feature_group_count, 0, 4)], # lhs_shape (j * batch_group_count * feature_group_count, i, 1, 2), # rhs_shape [(1, 1), (1, 2), (2, 1)], # strides [(1, 1), (2, 1)], # lhs_dils [(1, 1), (2, 2)]) # rhs_dils for b, i, j in itertools.product([1, 2], repeat=3)] for lhs_shape in lhs_shapes for strides in all_strides for rhs_dil in rhs_dils for lhs_dil in lhs_dils for dtype in grad_float_dtypes for padding in ([((0, 0), (0, 0)), ((1, 0), (0, 1))] + ([((0, -1), (0, 0))] if lhs_shape[2] != 0 else [])) for dim_nums, perms in [ (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])), (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])), (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))] for rng_factory in [jtu.rand_default] )) def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil, dimension_numbers, perms, feature_group_count, batch_group_count, rng_factory): if dtype == onp.float16: raise SkipTest("float16 numerical issues") # TODO(mattjj): resolve rng = rng_factory(self.rng()) tol = {dtypes.bfloat16: 1e-0, onp.float16: 5e-1, onp.float32: 2e-4} # permute shapes to match dim_spec, scale by feature_group_count lhs_perm, rhs_perm = perms lhs_shape = list(onp.take(lhs_shape, lhs_perm)) rhs_shape = list(onp.take(rhs_shape, rhs_perm)) lhs = rng(lhs_shape, dtype) rhs = rng(rhs_shape, dtype) conv = partial(lax.conv_general_dilated, window_strides=strides, padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil, dimension_numbers=dimension_numbers, feature_group_count=feature_group_count, batch_group_count=batch_group_count, precision=lax.Precision.HIGHEST) check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"], atol=tol, rtol=tol) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype)), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "rng_factory": jtu.rand_default} for lhs_shape in [(2,), (3, 2)] for rhs_shape in [(2,), (2, 4)] for dtype in float_dtypes)) def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory): rng = rng_factory(self.rng()) tol = {onp.float16: 1e-1, onp.float32: 1e-4} lhs = rng(lhs_shape, dtype) rhs = rng(rhs_shape, dtype) dot = partial(lax.dot, precision=lax.Precision.HIGHEST) check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"], atol=tol, rtol=tol) # check that precision config is preserved result, pullback = 
api.vjp(dot, lhs, rhs) gresult = lax.zeros_like_array(result) s = str(api.make_jaxpr(pullback)(gresult)) assert "precision=HIGHEST" in s @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_dimension_numbers={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), dimension_numbers), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "dimension_numbers": dimension_numbers, "rng_factory": jtu.rand_small} for lhs_shape, rhs_shape, dimension_numbers in [ ((3, 2), (2, 4), (([1], [0]), ([], []))), ((3, 5), (2, 5), (([1], [1]), ([], []))), ((5, 3), (5, 2), (([0], [0]), ([], []))), ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))), ] for dtype in float_dtypes)) def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype, dimension_numbers, rng_factory): rng = rng_factory(self.rng()) lhs = rng(lhs_shape, dtype) rhs = rng(rhs_shape, dtype) dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers, precision=lax.Precision.HIGHEST) check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"]) # check that precision config is preserved result, pullback = api.vjp(dot_general, lhs, rhs) gresult = lax.zeros_like_array(result) s = str(api.make_jaxpr(pullback)(gresult)) assert "precision=HIGHEST" in s @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format( shape, onp.dtype(dtype).name, broadcast_sizes), "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes, "rng_factory": rng_factory} for shape in [(), (2, 3)] for dtype in float_dtypes for broadcast_sizes in [(), (2,), (1, 2)] for rng_factory in [jtu.rand_default])) def testBroadcastGrad(self, shape, dtype, broadcast_sizes, rng_factory): rng = rng_factory(self.rng()) args = (rng(shape, dtype),) broadcast = lambda x: lax.broadcast(x, broadcast_sizes) check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format( jtu.format_shape_dtype_string(inshape, dtype), outshape, broadcast_dimensions), "inshape": inshape, "dtype": dtype, "outshape": outshape, "dimensions": broadcast_dimensions, "rng_factory": rng_factory} for inshape, outshape, broadcast_dimensions in [ ([2], [2, 2], [0]), ([2], [2, 2], [1]), ([2], [2, 3], [0]), ([], [2, 3], []), ] for dtype in float_dtypes for rng_factory in [jtu.rand_default])) def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions, rng_factory): rng = rng_factory(self.rng()) operand = rng(inshape, dtype) broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions) check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.) 
@parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outshape={}_perm={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), jtu.format_shape_dtype_string(out_shape, dtype), permutation), "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, "rng_factory": rng_factory, "permutation": permutation} for dtype in float_dtypes for arg_shape, out_shape, permutation in [ [(3, 4), (12,), None], [(2, 1, 4), (8,), None], [(2, 2, 4), (2, 8), None], [(3, 4), (12,), (0, 1)], [(3, 4), (12,), (1, 0)], [(2, 1, 4), (8,), (0, 2, 1)], [(2, 1, 4), (8,), (2, 0, 1)], [(2, 2, 4), (2, 8), (0, 2, 1)], [(2, 2, 4), (2, 8), (2, 0, 1)], ] for rng_factory in [jtu.rand_default])) def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype, rng_factory): rng = rng_factory(self.rng()) operand = rng(arg_shape, dtype) reshape = lambda x: lax.reshape(x, out_shape, permutation) check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_pads={}" .format(jtu.format_shape_dtype_string(shape, dtype), pads), "shape": shape, "dtype": dtype, "pads": pads, "rng_factory": jtu.rand_small} for shape in [(2, 3)] for dtype in float_dtypes for pads in [[(1, 2, 1), (0, 1, 0)], [(-1, 0, 0), (-1, 0, 2)]])) def testPadGrad(self, shape, dtype, pads, rng_factory): rng = rng_factory(self.rng()) operand = rng(shape, dtype) pad = lambda operand: lax.pad(operand, onp.array(0, dtype), pads) check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.) operand = rng(shape, dtype) padding_value = onp.array(0., dtype) pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads) check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.) def testReverseGrad(self): rev = lambda operand: lax.rev(operand, dimensions) dimensions = [0] check_grads(rev, (onp.array([3., 2., 1.]),), 2) dimensions = [0, 1] check_grads(rev, (onp.array([[6., 5., 4.], [3., 2., 1.]]),), 2, rtol={onp.float32: 3e-3}) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_predshape={}_argshapes={}".format( jtu.format_shape_dtype_string(pred_shape, onp.bool_), jtu.format_shape_dtype_string(arg_shape, dtype)), "pred_shape": pred_shape, "arg_shape": arg_shape, "dtype": dtype, "rng_factory": rng_factory} for arg_shape in [(), (3,), (2, 3)] for pred_shape in ([(), arg_shape] if arg_shape else [()]) for dtype in float_dtypes for rng_factory in [jtu.rand_default])) def testSelectGrad(self, pred_shape, arg_shape, dtype, rng_factory): rng = rng_factory(self.rng()) pred = rng(pred_shape, onp.bool_) on_true = rng(arg_shape, dtype) on_false = rng(arg_shape, dtype) select = lambda on_true, on_false: lax.select(pred, on_true, on_false) check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], eps=1.) 
@parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_start_indices={}_limit_indices={}_strides={}".format( jtu.format_shape_dtype_string(shape, dtype), start_indices, limit_indices, strides), "shape": shape, "dtype": dtype, "starts": start_indices, "limits": limit_indices, "strides": strides, "rng_factory": rng_factory} for shape, start_indices, limit_indices, strides in [ [(3,), (1,), (2,), None], [(7,), (4,), (7,), None], [(5,), (1,), (5,), (2,)], [(8,), (1,), (6,), (2,)], [(5, 3), (1, 1), (3, 2), None], [(5, 3), (1, 1), (3, 1), None], [(7, 5, 3), (4, 0, 1), (7, 1, 3), None], [(5, 3), (1, 1), (2, 1), (1, 1)], [(5, 3), (1, 1), (5, 3), (2, 1)], ] for dtype in float_dtypes for rng_factory in [jtu.rand_default])) def testSliceGrad(self, shape, dtype, starts, limits, strides, rng_factory): rng = rng_factory(self.rng()) operand = rng(shape, dtype) slice = lambda x: lax.slice(x, starts, limits, strides) check_grads(slice, (operand,), 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format( jtu.format_shape_dtype_string(shape, dtype), start_indices, size_indices), "shape": shape, "dtype": dtype, "start_indices": start_indices, "size_indices": size_indices, "rng_factory": rng_factory} for shape, start_indices, size_indices in [ [(3,), (1,), (1,)], [(5, 3), (1, 1), (3, 1)], [(7, 5, 3), (4, 1, 0), (2, 0, 1)], ] for dtype in float_dtypes for rng_factory in [jtu.rand_default])) def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices, rng_factory): rng = rng_factory(self.rng()) operand = rng(shape, dtype) dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices) check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format( jtu.format_shape_dtype_string(shape, dtype), start_indices, update_shape), "shape": shape, "dtype": dtype, "start_indices": start_indices, "update_shape": update_shape, "rng_factory": rng_factory} for shape, start_indices, update_shape in [ [(3,), (1,), (1,)], [(5, 3), (1, 1), (3, 1)], [(7, 5, 3), (4, 1, 0), (2, 0, 1)], ] for dtype in float_dtypes for rng_factory in [jtu.rand_default])) def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices, update_shape, rng_factory): rng = rng_factory(self.rng()) operand = rng(shape, dtype) update = rng(update_shape, dtype) start_indices = onp.array(start_indices) dus = lambda x, y: lax.dynamic_update_slice(x, y, start_indices) check_grads(dus, (operand, update), 2, ["fwd", "rev"], eps=1.) dus = lambda x: lax.dynamic_update_slice(x, update, start_indices) check_grads(dus, (operand,), 2, ["fwd", "rev"], eps=1.) dus = lambda y: lax.dynamic_update_slice(operand, y, start_indices) check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_perm={}".format( jtu.format_shape_dtype_string(shape, dtype), perm), "shape": shape, "dtype": dtype, "perm": perm, "rng_factory": rng_factory} for shape, perm in [ [(3, 4), (1, 0)], [(3, 4), (0, 1)], [(3, 4, 5), (2, 1, 0)], [(3, 4, 5), (1, 0, 2)], ] for dtype in float_dtypes for rng_factory in [jtu.rand_default])) def testTransposeGrad(self, shape, dtype, perm, rng_factory): rng = rng_factory(self.rng()) operand = rng(shape, dtype) transpose = lambda x: lax.transpose(x, perm) check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.) 
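# Small sketch of the gradient structure testDynamicUpdateSliceGrad above
# checks: cotangents for the overwritten window flow to `update`, while the
# operand gets zero cotangent there. Values are illustrative assumptions.
import jax
import jax.numpy as jnp
from jax import lax

operand = jnp.zeros((5,), jnp.float32)
update = jnp.ones((2,), jnp.float32)
start_indices = (1,)

f = lambda x, y: jnp.sum(lax.dynamic_update_slice(x, y, start_indices))
dx, dy = jax.grad(f, argnums=(0, 1))(operand, update)
print(dx)  # [1. 0. 0. 1. 1.] -- positions 1 and 2 were overwritten
print(dy)  # [1. 1.]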
@parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_inshape={}_reducedims={}" .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims), "op": op, "init_val": init_val, "shape": shape, "dtype": dtype, "dims": dims, "rng_factory": rng_factory} for init_val, op, dtypes, rng_factory in [ (0, lax.add, inexact_dtypes, jtu.rand_default), (-onp.inf, lax.max, grad_inexact_dtypes, jtu.rand_unique_int), (onp.inf, lax.min, grad_inexact_dtypes, jtu.rand_unique_int), (1, lax.mul, grad_float_dtypes, partial(jtu.rand_default, scale=1)), ] for dtype in dtypes for shape, dims in [ [(3, 4, 5), ()], [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)], [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)], [(3, 1), (1,)], ])) def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory): rng = rng_factory(self.rng()) if jtu.device_under_test() == "tpu" and op is lax.mul: raise SkipTest("unimplemented case") tol = {dtypes.bfloat16: 2e-1, onp.float16: 1e-1, onp.float32: 1e-1, onp.float64: 1e-3, onp.complex64: 1e-1} operand = rng(shape, dtype) init_val = onp.asarray(init_val, dtype=dtype) reduce = lambda operand: lax.reduce(operand, init_val, op, dims) eps = (1.0 if dtypes.finfo(dtype).bits == 16 and op is lax.add else 1e-1 if dtype == dtypes.bfloat16 else 1e-2 if dtypes.finfo(dtype).bits == 32 else None) check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_dtype={}_padding={}" .format(op.__name__, onp.dtype(dtype).name, padding), "op": op, "init_val": init_val, "dtype": dtype, "padding": padding, "rng_factory": rng_factory} for init_val, op, dtypes, rng_factory in [ (0, lax.add, grad_float_dtypes, jtu.rand_small), (-onp.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int), (onp.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int), ] for dtype in dtypes for padding in ["VALID", "SAME"])) @jtu.skip_on_flag("jax_skip_slow_tests", True) @jtu.ignore_warning(category=UserWarning, message="Using reduced precision for gradient.*") def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory): rng = rng_factory(self.rng()) init_val = onp.asarray(init_val, dtype=dtype) # We need this conditional and the corresponding loop logic to be in the # test method, rather than at the parameterized test level, because it # depends on FLAGS for the device under test. # TODO(b/31565929): enable when fixed. if jtu.device_under_test() == "tpu" and op is not lax.add: all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))] # TODO(b/73062247): need variadic reduce-window for better precision. gradient_order = 1 else: all_configs = itertools.chain( itertools.product( [(4, 6)], # shapes [(2, 1), (1, 2)], # window_dimensions [(1, 1), (2, 1), (1, 2)] # strides ), itertools.product( [(3, 2, 4, 6)], # shapes [(1, 1, 2, 1), (2, 1, 2, 1)], # window_dimensions [(1, 2, 2, 1), (1, 1, 1, 1)]), # strides ) gradient_order = 3 def fun(operand): return lax.reduce_window(operand, init_val, op, dims, strides, padding) for shape, dims, strides in all_configs: operand = rng(shape, dtype) if op is lax.add: eps = 1. 
tol = None else: # this test can fail if there are duplicates in operand self.assertEqual(onp.unique(operand).size, operand.size, msg="test requires operand elements to be unique.") eps = 1e-2 tol = {onp.float16: 1e-1, onp.float32: 6e-2, onp.float64: 6e-2} check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol, eps) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_op={}_shape={}_axis={}" .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis), "op": op, "shape": shape, "dtype": dtype, "axis": axis, "rng_factory": rng_factory} for op, types in [ (lax.cumsum, [onp.float32, onp.float64]), (lax.cumprod, [onp.float32, onp.float64]), ] for dtype in types for shape in [[10], [3, 4, 5]] for axis in range(len(shape)) for rng_factory in [ jtu.rand_default if dtypes.issubdtype(dtype, onp.integer) else jtu.rand_small])) def testCumulativeReduceGrad(self, op, shape, dtype, axis, rng_factory): rng = rng_factory(self.rng()) check_grads(partial(op, axis=axis), (rng(shape, dtype),), order=2) # TODO(b/205052657): enable more tests when supported @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_axis={}_isstable={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, is_stable), "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "axis": axis, "is_stable": is_stable} for dtype in [onp.float32] for shape in [(5,), (5, 7)] for axis in [len(shape) - 1] for is_stable in [False, True] for rng_factory in [jtu.rand_default])) def testSortGrad(self, shape, dtype, axis, is_stable, rng_factory): rng = rng_factory(self.rng()) operand = rng(shape, dtype) sort = lambda x: lax.sort(x, dimension=axis, is_stable=is_stable) check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2) # TODO(b/205052657): enable more tests when supported @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_keyshape={}_valshape={}_axis={}_isstable={}".format( jtu.format_shape_dtype_string(shape, key_dtype), jtu.format_shape_dtype_string(shape, val_dtype), axis, is_stable), "rng_factory": rng_factory, "shape": shape, "key_dtype": key_dtype, "val_dtype": val_dtype, "axis": axis, "is_stable": is_stable} for key_dtype in [onp.float32] for val_dtype in [onp.float32] for shape in [(3,), (5, 3)] for axis in [len(shape) - 1] for is_stable in [False, True] for rng_factory in [jtu.rand_default])) def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, is_stable, rng_factory): rng = rng_factory(self.rng()) # This test relies on the property that wherever keys are tied, values are # too, since we don't guarantee the same ordering of values with equal keys. # To avoid that case, we generate unique keys (globally in the key array). 
def args_maker(): flat_keys = onp.arange(onp.prod(shape, dtype=int), dtype=key_dtype) keys = self.rng().permutation(flat_keys).reshape(shape) values = rng(shape, val_dtype) return keys, values keys, values = args_maker() fun = lambda keys, values: lax.sort_key_val(keys, values, axis, is_stable) check_grads(fun, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_k={}".format( jtu.format_shape_dtype_string(shape, dtype), k), "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k} for dtype in [onp.float32,] for shape in [(4,), (5, 5), (2, 1, 4)] for k in [1, 3] for rng_factory in [jtu.rand_default])) def testTopKGrad(self, shape, dtype, k, rng_factory): flat_values = onp.arange(onp.prod(shape, dtype=int), dtype=dtype) values = self.rng().permutation(flat_values).reshape(shape) fun = lambda vs: lax.top_k(vs, k=k)[0] check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_axes={}".format( jtu.format_shape_dtype_string(shape, dtype), idxs, axes), "shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes, "rng_factory": rng_factory} for dtype in float_dtypes for shape, idxs, axes in [ [(3, 4, 5), (onp.array([0, 2, 1]),), (0,)], [(3, 4, 5), (onp.array([-1, -2]),), (0,)], [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)], [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)], ] for rng_factory in [jtu.rand_default])) def testIndexTakeGrad(self, shape, dtype, idxs, axes, rng_factory): rng = rng_factory(self.rng()) src = rng(shape, dtype) index_take = lambda src: lax.index_take(src, idxs, axes) check_grads(index_take, (src,), 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), idxs, dnums, slice_sizes), "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory} for dtype in grad_float_dtypes for shape, idxs, dnums, slice_sizes, max_idx in [ ((5,), onp.array([[0], [2]]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)), (1,), 5), ((10,), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)), (2,), 9), ((10, 5,), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)), (1, 3), 3), ] for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)] for rng_factory in [jtu.rand_default])) def testGatherGrad(self, shape, dtype, idxs, dnums, slice_sizes, rng_factory, rng_idx_factory): rng = rng_factory(self.rng()) rng_idx = rng_idx_factory(self.rng()) idxs = rng_idx(idxs.shape, idxs.dtype) gather = lambda x: lax.gather(x, idxs, dimension_numbers=dnums, slice_sizes=slice_sizes) x = rng(shape, dtype) check_grads(gather, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) 
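# Two standalone sketches for the tests above (illustrative values, not part
# of the suite).
#
# First, the monoid reductions behind testReduceGrad: lax.reduce pairs an
# identity init value with a binary op such as lax.add or lax.max, and the
# gradient of a max reduction flows only to the winning elements.
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

x = jnp.array([[1., 5., 2.],
               [4., 0., 3.]])
print(lax.reduce(x, jnp.float32(0), lax.add, (0, 1)))       # 15.0
print(lax.reduce(x, jnp.float32(-jnp.inf), lax.max, (1,)))  # [5. 4.]
print(jax.grad(lambda x: jnp.sum(
    lax.reduce(x, jnp.float32(-jnp.inf), lax.max, (1,))))(x))
# [[0. 1. 0.]
#  [1. 0. 0.]]

# Second, the (10, 5) gather configuration from testGatherGrad: each index row
# selects a (1, 3) slice of the operand; the length-3 window becomes output
# axis 1 and the collapsed size-1 row axis disappears.
operand = jnp.arange(50., dtype=jnp.float32).reshape(10, 5)
idxs = np.array([[0], [2], [1]], dtype=np.int32)
dnums = lax.GatherDimensionNumbers(offset_dims=(1,),
                                   collapsed_slice_dims=(0,),
                                   start_index_map=(0,))
out = lax.gather(operand, idxs, dimension_numbers=dnums, slice_sizes=(1, 3))
print(out.shape)  # (3, 3)
print(out[1])     # [10. 11. 12.] -- the first three entries of operand row 2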
@parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), idxs, update_shape, dnums), "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory} for dtype in grad_float_dtypes for arg_shape, idxs, update_shape, dnums, max_idx in [ ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)), 4), ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,)), 9), ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)), 3), ] for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)] for rng_factory in [jtu.rand_default])) def testScatterAddGrad(self, arg_shape, dtype, idxs, update_shape, dnums, rng_factory, rng_idx_factory): rng = rng_factory(self.rng()) rng_idx = rng_idx_factory(self.rng()) idxs = rng_idx(idxs.shape, idxs.dtype) scatter_add = lambda x, y: lax.scatter_add(x, idxs, y, dimension_numbers=dnums) x = rng(arg_shape, dtype) y = rng(update_shape, dtype) check_grads(scatter_add, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), idxs, update_shape, dnums), "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory} for dtype in grad_float_dtypes for arg_shape, idxs, update_shape, dnums, max_idx in [ ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)), 4), ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,)), 9), ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)), 3), ] # Scatters with conflicting indices are not deterministic on GPU, so we # use indices that do not collide. for rng_idx_factory in [partial(jtu.rand_unique_int, high=max_idx)] for rng_factory in [jtu.rand_default])) def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums, rng_factory, rng_idx_factory): rng = rng_factory(self.rng()) rng_idx = rng_idx_factory(self.rng()) idxs = rng_idx(idxs.shape, idxs.dtype) scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums) x = rng(arg_shape, dtype) y = rng(update_shape, dtype) check_grads(scatter, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) def testScatterGradSymbolicZeroUpdate(self): # https://github.com/google/jax/issues/1901 def f(x): n = x.shape[0] y = onp.arange(n, dtype=x.dtype) return jax.ops.index_update(x, onp.diag_indices(n), y) rng = jtu.rand_default(self.rng()) check_grads(f, (rng((5, 5), onp.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) 
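# Standalone sketch of the scatter-add behaviour testScatterAddGrad above
# differentiates: duplicate indices accumulate, and the gradient with respect
# to the updates of a summed result is all ones. Values are illustrative.
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

operand = jnp.zeros((5,), jnp.float32)
idxs = np.array([[1], [1], [3]], dtype=np.int32)
updates = jnp.array([10., 20., 5.])
dnums = lax.ScatterDimensionNumbers(update_window_dims=(),
                                    inserted_window_dims=(0,),
                                    scatter_dims_to_operand_dims=(0,))
print(lax.scatter_add(operand, idxs, updates, dimension_numbers=dnums))
# [ 0. 30.  0.  5.  0.]
g = jax.grad(lambda u: jnp.sum(
    lax.scatter_add(operand, idxs, u, dimension_numbers=dnums)))(updates)
print(g)  # [1. 1. 1.]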
@parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), idxs, update_shape, dnums), "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory} for dtype in grad_float_dtypes for arg_shape, idxs, update_shape, dnums in [ ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), ] for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))] for rng_factory in [jtu.rand_default])) def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums, rng_factory, rng_idx_factory): rng = rng_factory(self.rng()) rng_idx = rng_idx_factory(self.rng()) idxs = rng_idx(idxs.shape, idxs.dtype) scatter_max = lambda x, y: lax.scatter_max(x, idxs, y, dnums) x = rng(arg_shape, dtype) y = rng(update_shape, dtype) check_grads(scatter_max, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), idxs, update_shape, dnums), "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs, "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory} for dtype in grad_float_dtypes for arg_shape, idxs, update_shape, dnums in [ ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))), ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers( update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))), ] for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))] for rng_factory in [jtu.rand_default])) def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums, rng_factory, rng_idx_factory): rng = rng_factory(self.rng()) rng_idx = rng_idx_factory(self.rng()) idxs = rng_idx(idxs.shape, idxs.dtype) scatter_min = lambda x, y: lax.scatter_min(x, idxs, y, dnums) x = rng(arg_shape, dtype) y = rng(update_shape, dtype) check_grads(scatter_min, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2) def testStopGradient(self): def f(x): return lax.sin(x) * lax.cos(lax.stop_gradient(x)) def f2(x, y): return lax.sin(x) * lax.cos(y) x = 3.14 ans = api.grad(f)(x) expected = api.grad(f2)(x, x) self.assertAllClose(ans, expected) ans = api.grad(api.grad(f))(x) expected = api.grad(api.grad(f2))(x, x) self.assertAllClose(ans, expected) ans = api.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.) 
expected = onp.array(0.0) self.assertAllClose(ans, expected, check_dtypes=False) with core.skipping_checks(): with self.assertRaises(TypeError): lax.stop_gradient(lambda x: x) # TODO(mattjj): make this a more systematic test def testRemainder(self): rng = onp.random.RandomState(0) x = rng.uniform(-0.9, 9, size=(3, 4)) y = rng.uniform(0.7, 1.9, size=(3, 1)) assert not set(onp.unique(x)) & set(onp.unique(y)) tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3 check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol) rng = onp.random.RandomState(0) x = rng.uniform(-0.9, 9, size=(1, 4)) y = rng.uniform(0.7, 1.9, size=(3, 4)) assert not set(onp.unique(x)) & set(onp.unique(y)) tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3 check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol) def testHigherOrderGradientOfReciprocal(self): # Regression test for https://github.com/google/jax/issues/3136 def inv(x): # N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x) return 1 / x grad_fn = jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(inv)))))) self.assertAllClose(onp.float32(0.0439453125), grad_fn(onp.float32(4.)))
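# The constant asserted in testHigherOrderGradientOfReciprocal above follows
# from d^n/dx^n (1/x) = (-1)^n n! / x**(n + 1): for n = 6 and x = 4 this is
# 720 / 4**7.
from math import factorial
print(factorial(6) / 4 ** 7)  # 0.0439453125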
def testCumulativeReduceGrad(self, op, shape, dtype, axis, reverse):
  rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
                 else jtu.rand_small)
  rng = rng_factory(self.rng())
  check_grads(partial(op, axis=axis, reverse=reverse), (rng(shape, dtype),),
              order=2)
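# Sketch of the cumulative-sum gradient this test exercises (assuming a
# lax.cumsum that accepts the reverse flag, as used above): the VJP of cumsum
# is a reversed cumsum of the incoming cotangent.
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([1., 2., 3., 4.])
_, vjp_fn = jax.vjp(lambda x: lax.cumsum(x, axis=0), x)
(ct,) = vjp_fn(jnp.ones(4))
print(ct)                                             # [4. 3. 2. 1.]
print(lax.cumsum(jnp.ones(4), axis=0, reverse=True))  # the same values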
def testCumulativeReduce(self, op, shape, dtype, axis, bdims, reverse):
  rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
                 else jtu.rand_small)
  rng = rng_factory(self.rng())
  self._CheckBatching(partial(op, axis=axis, reverse=reverse), 7, bdims,
                      (shape,), (dtype,), rng)
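# Conceptual sketch of the property _CheckBatching verifies for these ops:
# vmapping a cumulative reduction over a batch axis matches applying it to
# each row separately. (_CheckBatching itself is a harness helper; this is
# just the underlying idea.)
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12., dtype=jnp.float32).reshape(3, 4)
batched = jax.vmap(lambda row: lax.cumsum(row, axis=0, reverse=True))(x)
looped = jnp.stack([lax.cumsum(row, axis=0, reverse=True) for row in x])
print(jnp.allclose(batched, looped))  # True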
def is_float(c):
  return dtypes.issubdtype(dtypes.dtype(c), jnp.inexact)
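# Quick illustration of the jnp.inexact test is_float performs: it is true for
# float, complex, and bfloat16 values and false for integers and bools. This
# sketch uses the public jnp.result_type in place of the internal
# dtypes.dtype helper.
import jax.numpy as jnp
from jax import dtypes

for c in [1.0, jnp.bfloat16(1), 1j, 1, True]:
  print(type(c).__name__, dtypes.issubdtype(jnp.result_type(c), jnp.inexact))
# float True, bfloat16 True, complex True, int False, bool False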