def omnistaging_disabler() -> None:
  global axis_index

  psum_p.bind = partial(core.Primitive.bind, psum_p)  # type: ignore
  psum_p.def_impl(partial(pxla.apply_parallel_primitive, psum_p))  # type: ignore
  pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args)  # type: ignore

  def _axis_index_bind(*, axis_name):
    dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env
    frame = dynamic_axis_env[axis_name]
    sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame)+1]
    nreps = dynamic_axis_env.nreps
    trace = frame.pmap_trace

    out_aval = ShapedArray((), np.int32)
    out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
    eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
                            dict(nreps=nreps, sizes=sizes, axis_name=axis_name),
                            source_info_util.current())
    out_tracer.recipe = eqn

    return out_tracer

  def _axis_index_translation_rule(c, nreps, sizes, axis_name):
    div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32))
    mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
    unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
    return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))

  axis_index_p.def_custom_bind(_axis_index_bind)
  axis_index_p.def_abstract_eval(
      lambda *args, **params: ShapedArray((), np.int32))
  xla.translations[axis_index_p] = _axis_index_translation_rule
def _all_to_all_translation_rule(c, x, *, split_axis, concat_axis, axis_name,
                                 axis_index_groups, axis_env, platform):
  # Workaround for AllToAll not being implemented on CPU.
  replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
  if len(replica_groups[0]) == 1:
    return x
  elif platform != 'tpu':
    warnings.warn("all_to_all (and pswapaxes) are only implemented properly for TPUs. All other "
                  "backends emulate it using a very slow and memory intensive algorithm, so expect "
                  "significant slowdowns.")
    lowering = xla.lower_fun(_all_to_all_via_all_gather, multiple_results=False,
                             parallel=True)
    return lowering(c, x,
                    split_axis=split_axis, concat_axis=concat_axis,
                    axis_name=axis_name, axis_index_groups=axis_index_groups,
                    axis_env=axis_env, platform=platform)
  else:
    split_count = len(replica_groups[0])
    if not all(split_count == len(g) for g in replica_groups):
      raise ValueError('Replica groups must be equally sized')
    replica_groups_protos = xc.make_replica_groups(replica_groups)
    if concat_axis == split_axis:
      return xops.AllToAll(x, split_axis, concat_axis, split_count,
                           replica_groups_protos)
    else:
      if concat_axis < split_axis:
        split_axis += 1
      elif split_axis < concat_axis:
        concat_axis += 1
      x = xla.lower_fun(partial(lax.expand_dims, dimensions=(concat_axis,)),
                        multiple_results=False)(c, x)
      x = xops.AllToAll(x, split_axis, concat_axis, split_count,
                        replica_groups_protos)
      x = xla.lower_fun(partial(lax.squeeze, dimensions=(split_axis,)),
                        multiple_results=False)(c, x)
      return x
def testSumPool(self):
  W = jnp.array(np.random.randn(3, 3, 1, 5), dtype=np.float32)
  X = jnp.array(np.random.randn(10, 5, 5, 1), dtype=np.float32)

  def f(params, x):
    one = (1, 1)
    dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
    y = lax.conv_general_dilated(
        x, params, one, 'SAME', one, one, dimension_numbers)
    y = lax.reduce_window(
        y, 0.0, lax.add, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME')
    return y
  grad_loss = grad(lambda params, x: jnp.mean(f(params, x) ** 2))

  # Test forward prop.
  per_example = vmap(partial(f, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
  per_example = jnp.reshape(per_example, (10, 5, 5, 5))
  per_example_direct = f(W, X)
  self.assertAllClose(per_example, per_example_direct)

  # Test gradients.
  per_example = vmap(partial(grad_loss, W))(jnp.reshape(X, (10, 1, 5, 5, 1)))
  per_example_direct = []
  for i in range(10):
    g = grad_loss(W, jnp.reshape(X[i], (1, 5, 5, 1)))
    per_example_direct += [jnp.reshape(g, (1,) + g.shape)]
  per_example_direct = jnp.concatenate(per_example_direct, axis=0)
  self.assertAllClose(per_example, per_example_direct, rtol=3e-2, atol=1e-3)
def value_and_jacfwd_f(*args, **kwargs):
  f = lu.wrap_init(fun, kwargs)
  f_partial, dyn_args = argnums_partial(f, argnums, args,
                                        require_static_args_hashable=False)
  tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args)
  pushfwd = partial(_jvp, f_partial, dyn_args)
  y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
  tree_map(partial(_check_output_dtype_jacfwd, holomorphic), y)
  example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
  return y, tree_map(partial(_jacfwd_unravel, example_args), y, jac)
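# A minimal illustration of the pattern above using only the public JAX API
# (not this internal wrapper): the forward-mode Jacobian is a vmap of jvp over
# the standard basis of the input, with the pushed-forward columns stacked on
# the last axis. The function `f` below is a made-up example.
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * x

x = jnp.arange(1.0, 4.0)
basis = jnp.eye(x.size)
jac = jax.vmap(lambda v: jax.jvp(f, (x,), (v,))[1], out_axes=-1)(basis)
assert jnp.allclose(jac, jax.jacfwd(f)(x))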
def testSort(self):
  v = np.arange(12)[::-1].reshape(3, 4)

  sv = vmap(partial(lax.sort, dimension=0), (0,))(v)
  self.assertAllClose(sv, v[:, ::-1])

  sv = vmap(partial(lax.sort, dimension=-1), (0,))(v)
  self.assertAllClose(sv, v[:, ::-1])

  sv = vmap(partial(lax.sort, dimension=0), (1,))(v)
  self.assertAllClose(sv, v[::-1, :].T)

  sv = vmap(partial(lax.sort, dimension=0), (1,), 1)(v)
  self.assertAllClose(sv, v[::-1, :])
def triangular_solve_jvp_rule_a(
    g_a, ans, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
  m, n = b.shape[-2:]
  k = 1 if unit_diagonal else 0
  g_a = jnp.tril(g_a, k=-k) if lower else jnp.triu(g_a, k=k)
  g_a = lax.neg(g_a)
  g_a = jnp.swapaxes(g_a, -1, -2) if transpose_a else g_a
  g_a = jnp.conj(g_a) if conjugate_a else g_a
  dot = partial(lax.dot if g_a.ndim == 2 else lax.batch_matmul,
                precision=lax.Precision.HIGHEST)

  def a_inverse(rhs):
    return triangular_solve(a, rhs, left_side, lower, transpose_a,
                            conjugate_a, unit_diagonal)

  # triangular_solve is about the same cost as matrix multiplication (~n^2
  # FLOPs for matrix/vector inputs). Order these operations in whichever order
  # is cheaper.
  if left_side:
    assert g_a.shape[-2:] == a.shape[-2:] == (m, m) and ans.shape[-2:] == (m, n)
    if m > n:
      return a_inverse(dot(g_a, ans))  # A^{-1} (∂A X)
    else:
      return dot(a_inverse(g_a), ans)  # (A^{-1} ∂A) X
  else:
    assert g_a.shape[-2:] == a.shape[-2:] == (n, n) and ans.shape[-2:] == (m, n)
    if m < n:
      return a_inverse(dot(ans, g_a))  # (X ∂A) A^{-1}
    else:
      return dot(ans, a_inverse(g_a))  # X (∂A A^{-1})
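# Illustration only (not part of the module): the identity this rule encodes
# for the lower-triangular, left-side case is d(A^{-1} B) = -A^{-1} (dA) A^{-1} B,
# checked here against the JVP of jax.scipy.linalg.solve_triangular on a
# made-up, well-conditioned example.
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.linalg import solve_triangular

rng = np.random.RandomState(0)
a = jnp.array(np.tril(rng.randn(3, 3)) + 3 * np.eye(3), dtype=jnp.float32)
b = jnp.array(rng.randn(3, 2), dtype=jnp.float32)
da = jnp.array(np.tril(rng.randn(3, 3)), dtype=jnp.float32)

x, dx = jax.jvp(lambda a: solve_triangular(a, b, lower=True), (a,), (da,))
assert jnp.allclose(dx, -solve_triangular(a, da @ x, lower=True), atol=1e-4)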
def ppermute(x, axis_name, perm):
  """Perform a collective permutation according to the permutation ``perm``.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  This function is an analog of the CollectivePermute XLA HLO.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    perm: list of pairs of ints, representing
      ``(source_index, destination_index)`` pairs that encode how the mapped
      axis named ``axis_name`` should be shuffled. The integer values are
      treated as indices into the mapped axis ``axis_name``. Any two pairs
      should not have the same source index or the same destination index. For
      each index of the axis ``axis_name`` that does not correspond to a
      destination index in ``perm``, the corresponding values in the result are
      filled with zeros of the appropriate type.

  Returns:
    Array(s) with the same shape as ``x`` with slices along the axis
    ``axis_name`` gathered from ``x`` according to the permutation ``perm``.
  """
  return tree_util.tree_map(
      partial(ppermute_p.bind, axis_name=axis_name, perm=tuple(perm)), x)
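# Hedged usage sketch (not part of the module): a ring shift under pmap, where
# each device sends its value to the next index along the mapped axis 'i'.
# Assumes jax.device_count() local devices; with one device it is the identity.
from functools import partial
import jax
import jax.numpy as jnp

n = jax.device_count()
perm = [(i, (i + 1) % n) for i in range(n)]

@partial(jax.pmap, axis_name='i')
def ring_shift(x):
  return jax.lax.ppermute(x, axis_name='i', perm=perm)

print(ring_shift(jnp.arange(float(n))))  # device i receives the value from i-1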
def testRandom(self):
  seeds = vmap(random.PRNGKey)(np.arange(10))
  ans = vmap(partial(random.normal, shape=(3, 2)))(seeds)
  expected = np.stack([random.normal(random.PRNGKey(seed), (3, 2))
                       for seed in np.arange(10)])
  self.assertAllClose(ans, expected, check_dtypes=False)
  assert len(np.unique(ans)) == 10 * 3 * 2
def testPerExampleGradients(self):
  def predict(params, inputs):
    for W, b in params:
      outputs = jnp.dot(W, inputs) + b
      inputs = jnp.tanh(outputs)
    return outputs

  def loss(params, data):
    inputs, targets = data
    predictions = predict(params, inputs)
    return jnp.sum((predictions - targets)**2)

  batch_size = 5
  layer_sizes = [3, 2, 4]

  R = np.random.RandomState(0).randn
  params = [(R(m, n), R(m))
            for m, n in zip(layer_sizes[1:], layer_sizes[:-1])]

  input_batch = R(5, 3)
  target_batch = R(5, 4)
  batch = (input_batch, target_batch)

  ans = vmap(partial(grad(loss), params))(batch)

  for ans_pair, param_pair in zip(ans, params):
    dW, db = ans_pair
    W, b = param_pair
    self.assertEqual(dW.shape, (batch_size,) + W.shape)
    self.assertEqual(db.shape, (batch_size,) + b.shape)
def eigh_jvp_rule(primals, tangents, lower):
  # Derivative for eigh in the simplest case of distinct eigenvalues.
  # This is classic nondegenerate perturbation theory, but also see
  # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
  # The general solution treating the case of degenerate eigenvalues is
  # considerably more complicated. Ambitious readers may refer to the general
  # methods below or refer to degenerate perturbation theory in physics.
  # https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
  # https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf
  a, = primals
  a_dot, = tangents

  v, w_real = eigh_p.bind(symmetrize(a), lower=lower)

  # for complex numbers we need eigenvalues to be full dtype of v, a:
  w = w_real.astype(a.dtype)
  eye_n = jnp.eye(a.shape[-1], dtype=a.dtype)
  # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs.
  Fmat = jnp.reciprocal(eye_n + w[..., jnp.newaxis, :] - w[..., jnp.newaxis]) - eye_n
  # eigh impl doesn't support batch dims, but future-proof the grad.
  dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul,
                precision=lax.Precision.HIGHEST)
  vdag_adot_v = dot(dot(_H(v), a_dot), v)
  dv = dot(v, jnp.multiply(Fmat, vdag_adot_v))
  dw = jnp.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
  return (v, w_real), (dv, dw)
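# Illustration only (not part of the module): for distinct eigenvalues the
# tangents are dw_i = v_i^H (dA) v_i and dv_i = sum_{j != i} v_j (v_j^H dA v_i)
# / (w_i - w_j), which is what Fmat and vdag_adot_v compute above. A quick
# check of the eigenvalue part against the public API on made-up data:
import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.RandomState(0)
a = rng.randn(4, 4).astype(np.float32); a = (a + a.T) / 2
da = rng.randn(4, 4).astype(np.float32); da = (da + da.T) / 2

(w, v), (dw, dv) = jax.jvp(jnp.linalg.eigh, (a,), (da,))
assert jnp.allclose(dw, jnp.diag(v.T @ da @ v), atol=1e-3)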
def _all_gather_via_psum(x, *, all_gather_dimension, axis_name,
                         axis_index_groups, axis_size):
  index = axis_index(axis_name)
  if axis_index_groups is not None:
    indices = np.array(axis_index_groups).flatten()
    axis_index_to_group_index = indices.argsort() % len(axis_index_groups[0])
    index = lax_numpy.array(axis_index_to_group_index)[index]
  outs = tree_util.tree_map(
      partial(_expand, all_gather_dimension, axis_size, index), x)
  return psum(outs, axis_name, axis_index_groups=axis_index_groups)
def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums,
                             slice_sizes):
  rng = jtu.rand_default(self.rng())
  fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
  operand = rng(shape, dtype)
  ans = vmap(fun, (None, axis))(operand, idxs)
  expected = np.stack([fun(operand, idxs[(slice(None),) * axis + (i,)])
                       for i in range(idxs.shape[axis])])
  self.assertAllClose(ans, expected, check_dtypes=False)
def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums,
                                 slice_sizes):
  rng = jtu.rand_default(self.rng())
  fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
  gfun = grad(lambda x, idx: jnp.sum(jnp.sin(fun(x, idx))))
  operand = rng(shape, dtype)
  ans = vmap(gfun, (axis, None))(operand, idxs)
  expected = np.stack([gfun(operand[(slice(None),) * axis + (i,)], idxs)
                       for i in range(operand.shape[axis])])
  self.assertAllClose(ans, expected, check_dtypes=False)
def test_check_jaxpr_scan_correct(self):
  def f(c, x):
    b = jnp.cos(jnp.sum(jnp.sin(x)) + jnp.sum(jnp.cos(c)))
    c = jnp.sin(c * b)
    return c, b

  xs = jnp.ones((5, 3))
  c = jnp.ones(4)
  jaxpr = make_jaxpr(partial(lax.scan, f))(c, xs).jaxpr
  core.check_jaxpr(jaxpr)
def test_vjp(self, f, args):
  jtu.check_vjp(f, partial(vjp, f), args,
                rtol={np.float32: 3e-1, np.float64: 1e-5},
                atol={np.float32: 1e-2, np.float64: 1e-5})
def testSortKeyVal(self):
  k = np.arange(12)[::-1].reshape(3, 4)
  v = np.random.RandomState(0).permutation(12).reshape(3, 4)

  sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 0))(k, v)
  self.assertAllClose(sk, k[:, ::-1])
  self.assertAllClose(sv, v[:, ::-1])

  sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 1), 1)(k, v)
  self.assertAllClose(sk, k[::-1, :])
  self.assertAllClose(sv, v[::-1, :])

  sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 1))(k, v.T)
  self.assertAllClose(sk, k[:, ::-1])
  self.assertAllClose(sv, v[:, ::-1])

  sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 0))(k.T, v)
  self.assertAllClose(sk, k[:, ::-1])
  self.assertAllClose(sv, v[:, ::-1])

  sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (None, 0))(k[0], v)
  self.assertAllClose(sk, np.broadcast_to(k[0, ::-1], (3, 4)))
  self.assertAllClose(sv, v[:, ::-1])

  sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, None))(k.T, v[0])
  self.assertAllClose(sk, k[:, ::-1])
  self.assertAllClose(sv, np.broadcast_to(v[0, ::-1], (3, 4)))
def value_and_jacrev_f(*args, **kwargs):
  f = lu.wrap_init(fun, kwargs)
  f_partial, dyn_args = argnums_partial(f, argnums, args,
                                        require_static_args_hashable=False)
  tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
  if not has_aux:
    y, pullback = _vjp(f_partial, *dyn_args)
  else:
    y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
  tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
  jac = vmap(pullback)(_std_basis(y))
  jac = jac[0] if isinstance(argnums, int) else jac
  example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
  jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)
  if not has_aux:
    return y, tree_transpose(tree_structure(example_args), tree_structure(y),
                             jac_tree)
  else:
    return (y, aux), tree_transpose(tree_structure(example_args),
                                    tree_structure(y), jac_tree)
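# A minimal illustration of the pattern above using only the public JAX API
# (not this internal wrapper): the reverse-mode Jacobian is a vmap of the vjp
# pullback over the standard basis of the output, one Jacobian row per basis
# vector. The function `f` below is a made-up example.
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * x

x = jnp.arange(1.0, 4.0)
y, pullback = jax.vjp(f, x)
jac, = jax.vmap(pullback)(jnp.eye(y.size))
assert jnp.allclose(jac, jax.jacrev(f)(x))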
def _contains_query(vals, query):
  if isinstance(query, tuple):
    return map(partial(_contains_query, vals), query)

  if jnp.isnan(query):
    if jnp.any(jnp.isnan(vals)):
      raise FoundValue('NaN')
  elif jnp.isinf(query):
    if jnp.any(jnp.isinf(vals)):
      raise FoundValue('Found Inf')
  elif jnp.isscalar(query):
    if jnp.any(vals == query):
      raise FoundValue(str(query))
  else:
    raise ValueError('Malformed Query: {}'.format(query))
def test_reference_cycles(self):
  gc.collect()

  def f(x):
    return x.sum()

  fn = partial(linearize, f)
  params = jnp.zeros([])

  debug = gc.get_debug()
  try:
    fn(params)
    gc.set_debug(gc.DEBUG_SAVEALL)
    self.assertEqual(gc.collect(), 0, msg=str(gc.garbage))
  finally:
    gc.set_debug(debug)
def tree_update(i, grad_tree, opt_state):
  states_flat, tree, subtrees = opt_state
  grad_flat, tree2 = tree_flatten(grad_tree)
  if tree2 != tree:
    msg = ("optimizer update function was passed a gradient tree that did "
           "not match the parameter tree structure with which it was "
           "initialized: parameter tree {} and grad tree {}.")
    raise TypeError(msg.format(tree, tree2))
  states = map(tree_unflatten, subtrees, states_flat)
  new_states = map(partial(update, i), grad_flat, states)
  new_states_flat, subtrees2 = unzip2(map(tree_flatten, new_states))
  for subtree, subtree2 in zip(subtrees, subtrees2):
    if subtree2 != subtree:
      msg = ("optimizer update function produced an output structure that "
             "did not match its input structure: input {} and output {}.")
      raise TypeError(msg.format(subtree, subtree2))
  return OptimizerState(new_states_flat, tree, subtrees)
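# Hedged usage sketch of the surrounding optimizers API (the
# init_fun/update_fun/get_params triple); the import path assumes
# jax.example_libraries.optimizers, which older releases exposed as
# jax.experimental.optimizers. The loss and parameter pytree are made up.
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

def loss(params, x):
  w, b = params
  return jnp.sum((x @ w + b) ** 2)

opt_init, opt_update, get_params = optimizers.sgd(step_size=1e-2)
params = (jnp.ones((3, 2)), jnp.zeros(2))  # arbitrary parameter pytree
opt_state = opt_init(params)

x = jnp.ones((5, 3))
grads = jax.grad(loss)(get_params(opt_state), x)
opt_state = opt_update(0, grads, opt_state)  # tree_update handles the pytree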
def testConvGeneralDilatedBatchNotMajor(self):
  W = jnp.array(np.random.randn(3, 3, 1, 4), dtype=np.float32)
  x = jnp.array(np.random.randn(3, 5, 7, 5, 1), dtype=np.float32)

  def f(params, x):
    one = (1, 1)
    dimension_numbers = ('HNWC', 'HWIO', 'HWNC')
    y = lax.conv_general_dilated(
        x, params, one, 'SAME', one, one, dimension_numbers)
    return y

  per_example = vmap(partial(f, W))(x)
  per_example = jnp.reshape(jnp.transpose(per_example, (1, 2, 0, 3, 4)),
                            (5, 5, 21, 4))
  per_example_direct = f(W, jnp.reshape(jnp.transpose(x, (1, 0, 2, 3, 4)),
                                        (5, 21, 5, 1)))
  self.assertAllClose(per_example, per_example_direct)
def testNpMaximumPerExampleGrad(self):
  R = np.random.RandomState(0).randn
  x = R(10, 5)
  W = R(5, 5)

  fun = lambda W, x: jnp.sum(jnp.maximum(jnp.dot(x, W), 0.0) ** 2)

  ans = vmap(partial(grad(fun), W))(x)

  W_t = jnp.transpose(W)
  for i in range(10):
    x_ex = x[i:i + 1]

    expected_ans = 2.0 * jnp.dot(
        jnp.maximum(jnp.dot(W_t, jnp.transpose(x_ex)), 0.0), x_ex)
    expected_ans = jnp.transpose(expected_ans)

    self.assertAllClose(
        ans[i], expected_ans, check_dtypes=False,
        atol={np.float32: 5e-2} if jtu.device_under_test() == "tpu" else None)
def _solve(a, b):
  _check_solve_shapes(a, b)

  # Broadcast leading dimensions of b to the shape of a, as is required by
  # custom_linear_solve.
  out_shape = tuple(d_a if d_b == 1 else d_b
                    for d_a, d_b in zip(a.shape[:-1] + (1,), b.shape))
  b = jnp.broadcast_to(b, out_shape)

  # With custom_linear_solve, we can reuse the same factorization when
  # computing sensitivities. This is considerably faster.
  lu_, _, permutation = lu(lax.stop_gradient(a))
  custom_solve = partial(
      lax.custom_linear_solve,
      lambda x: _matvec_multiply(a, x),
      solve=lambda _, x: lu_solve(lu_, permutation, x, trans=0),
      transpose_solve=lambda _, x: lu_solve(lu_, permutation, x, trans=1))
  if a.ndim == b.ndim + 1:
    # b.shape == [..., m]
    return custom_solve(b)
  else:
    # b.shape == [..., m, k]
    return api.vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
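# Illustration only (not part of the module): the transpose_solve above can
# reuse the same LU factors because the sensitivity of solve(a, b) with respect
# to b is a solve against a^T; a quick check with the public API on made-up data.
import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.RandomState(0)
a = jnp.array(rng.randn(3, 3) + 3 * np.eye(3), dtype=jnp.float32)
b = jnp.array(rng.randn(3), dtype=jnp.float32)
g = jnp.array(rng.randn(3), dtype=jnp.float32)

_, vjp_b = jax.vjp(lambda b: jnp.linalg.solve(a, b), b)
assert jnp.allclose(vjp_b(g)[0], jnp.linalg.solve(a.T, g), atol=1e-4)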
    thunk_name = 'fwd_jaxpr_thunk'
    params['bwd'] = callback_subtrace(params['bwd'], main)
  else:
    raise NotImplementedError(primitive)
  thunk = params.pop(thunk_name)

  @pe._memoize
  def new_thunk():
    thunk_jaxpr = core.ClosedJaxpr(*thunk())
    closed_jaxpr = callback_jaxpr(thunk_jaxpr, main.callback, main.strip_calls)
    return closed_jaxpr.jaxpr, closed_jaxpr.literals

  params[thunk_name] = new_thunk
  new_fun_jaxpr, new_consts = new_closed_jaxpr.jaxpr, new_closed_jaxpr.literals
  closed_fun_jaxpr = core.ClosedJaxpr(
      pe.convert_constvars_jaxpr(new_fun_jaxpr), ())
  new_num_consts = len(new_consts) + num_consts
  out = primitive.bind(*it.chain(new_consts, vals), fun_jaxpr=closed_fun_jaxpr,
                       num_consts=new_num_consts, **params)
  return safe_map(trace.pure, out)

custom_callback_rules[cd.custom_jvp_call_jaxpr_p] = partial(
    _custom_derivative_call_jaxpr_callback_rule, cd.custom_jvp_call_jaxpr_p)
custom_callback_rules[cd.custom_vjp_call_jaxpr_p] = partial(
    _custom_derivative_call_jaxpr_callback_rule, cd.custom_vjp_call_jaxpr_p)
  if jnp.issubdtype(dtype, np.complexfloating):
    nan = xb.constant(c, np.array(np.nan * (1. + 1j), dtype=dtype))
  else:
    nan = xb.constant(c, np.array(np.nan, dtype=dtype))
  return xops.Broadcast(nan, shape.dimensions())

def _cholesky_cpu_gpu_translation_rule(potrf_impl, c, operand):
  shape = c.get_shape(operand)
  batch_dims = shape.dimensions()[:-2]
  result, info = potrf_impl(c, operand, lower=True)
  ok = xops.Eq(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
  return _broadcasting_select(c,
                              xops.Reshape(ok, batch_dims + (1, 1)), result,
                              _nan_like(c, result))

xla.backend_specific_translations['cpu'][cholesky_p] = partial(
    _cholesky_cpu_gpu_translation_rule, lapack.potrf)

if cusolver is not None:
  xla.backend_specific_translations['gpu'][cholesky_p] = partial(
      _cholesky_cpu_gpu_translation_rule, cusolver.potrf)

if rocsolver is not None:
  xla.backend_specific_translations['gpu'][cholesky_p] = partial(
      _cholesky_cpu_gpu_translation_rule, rocsolver.potrf)

# Asymmetric eigendecomposition

def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors):
  return (
      xla.apply_primitive(eig_p, operand,
                          compute_left_eigenvectors=compute_left_eigenvectors,
                          compute_right_eigenvectors=compute_right_eigenvectors))
def test_jvp_linearized(self, f, args):
  jtu.check_jvp(f, partial(jvp_unlinearized, f), args,
                rtol={np.float32: 3e-2})
def test_jvp(self, f, args):
  jtu.check_jvp(f, partial(jvp, f), args, rtol={np.float32: 3e-2})
  return xops.Tuple(c, outs)

def _psum_transpose_rule(cts, *args, axis_name, axis_index_groups):
  nonzero_out_cts, treedef = tree_util.tree_flatten(cts)
  nonzero_in_cts = psum_p.bind(*nonzero_out_cts, axis_name=axis_name,
                               axis_index_groups=axis_index_groups)
  return tree_util.tree_unflatten(treedef, nonzero_in_cts)

psum_p = core.Primitive('psum')
psum_p.multiple_results = True
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
pxla.soft_pmap_rules[psum_p] = \
    partial(_allreduce_soft_pmap_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = \
    partial(_allreduce_translation_rule, lax.add_p)  # type: ignore
ad.deflinear2(psum_p, _psum_transpose_rule)
pxla.multi_host_supported_collectives.add(psum_p)
batching.primitive_batchers[psum_p] = partial(_collective_batcher, psum_p)
batching.collective_rules[psum_p] = \
    partial(_batched_reduction_collective,
            psum_p,
            lambda v, d: v.sum(d),
            lambda v, axis_size: axis_size * v)

# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
# tracing time.
@psum_p.def_custom_bind
    CallSpec(fun_with_two_calls, (R(3, 2),)),
    CallSpec(fun_with_call_closure, (R(3, 2),)),
    CallSpec(fun_call_jitted, (R(1,),)),
    CallSpec(fun_with_nested_calls, (R(),)),
    CallSpec(fun_with_nested_calls, (R(3, 2),)),
    CallSpec(fun_with_nested_calls_2, (R(1, 2),)),
]

def jvp_unlinearized(f, primals, tangents):
  out, jvp = linearize(f, *primals)
  return out, jvp(*tangents)

test_specs = []
for ts in test_specs_base:
  test_specs.append(ts)
  test_specs.append(CallSpec(partial(jvp, ts.fun), (ts.args, ts.args)))
  test_specs.append(CallSpec(jit(ts.fun), ts.args))
  test_specs.append(CallSpec(jit(jit(ts.fun)), ts.args))
  test_specs.append(CallSpec(partial(jvp_unlinearized, ts.fun),
                             (ts.args, ts.args)))

def fwd_deriv(f):
  def df(x):
    return jvp(f, (x,), (1.0,))[1]

  return df

class CoreTest(jtu.JaxTestCase):
def _allgather(x, dim, size, index, axis_name, axis_index_groups=None):
  outs = tree_util.tree_map(partial(_expand, dim, size, index), x)
  return psum(outs, axis_name, axis_index_groups=axis_index_groups)
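# Hedged usage sketch (not part of the module): the psum collective that
# _allgather builds on, used under pmap to normalize a value across the mapped
# axis 'i'. Assumes at least one local device.
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.pmap, axis_name='i')
def normalize(x):
  return x / jax.lax.psum(x, axis_name='i')

n = jax.device_count()
print(normalize(jnp.arange(1.0, n + 1.0)))  # entries sum to 1 across devices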