def plot_mixing_prob_and_start_end(gp, solver, traj_opt=None):
    # Plot the mixing probability of the original GP over a grid.
    Xnew, xx, yy = create_grid(gp.X, N=961)
    mixing_probs = jax.vmap(
        single_mogpe_mixing_probability,
        (0, None, None, None, None, None, None),
    )(Xnew, gp.Z, gp.kernel, gp.mean_func, gp.q_mu, False, gp.q_sqrt)

    fig, ax = plt.subplots(1, 1)
    plot_contourf(
        fig,
        ax,
        xx,
        yy,
        mixing_probs[:, 0:1],
        label=r"$\Pr(\alpha=1 | \mathbf{x})$",
    )
    ax.set_xlabel("$x$")
    ax.set_ylabel("$y$")
    plot_omitted_data(fig, ax, color="k")
    plot_start_and_end_pos(fig, ax, solver)
    plot_traj(
        fig,
        ax,
        solver.state_guesses,
        color=color_init,
        label="Initial trajectory",
    )
    if traj_opt is not None:
        plot_traj(
            fig, ax, traj_opt, color=color_opt, label="Optimised trajectory"
        )
    ax.legend()
    return fig, ax
def select_all_tensors(all_tensors, indices): return jax.vmap(select_tensor)(all_tensors, indices)
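# A minimal usage sketch of select_all_tensors. `select_tensor` is a
# hypothetical stand-in here (assumed to pick one row of a single tensor);
# the real definition lives elsewhere in the original codebase.
import jax
import jax.numpy as jnp

def select_tensor(tensor, index):
    return tensor[index]

all_tensors = jnp.arange(24.0).reshape(4, 3, 2)  # batch of 4 tensors, each (3, 2)
indices = jnp.array([0, 2, 1, 0])                # one row index per tensor
selected = select_all_tensors(all_tensors, indices)  # shape (4, 2)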
def loss(r_surf, nn, sg, weight, p):
    # `r` (the coil position function) and `theta` (the angular grid) are
    # defined in the enclosing scope.
    l = r(p, theta)
    dl = l[:, :-1, :] - l[:, 1:, :]
    return np.sum(pmap_quadratic_flux(r_surf, nn, sg, dl, l[:, :-1, :])) + weight * np.sum(dl)

#######################################################################
# JAX/Python Function Transformations
#######################################################################

# functional programming (Python)
objective_function = partial(loss, r_surf, nn, sg, 0.1)
# vectorize (JAX)
biot_savart_surface = vmap(vmap(biot_savart, (0, None, None), 0), (1, None, None), 1)
# automatic differentiation (JAX)
grad_func = grad(objective_function)  # d output / d input
# jit-compile (JAX)
jit_grad_func = jit(grad_func)
# SPMD parallelization (JAX)
pmap_quadratic_flux = pmap(quadratic_flux, in_axes=(0, 0, 0, None, None))

#######################################################################
# Optimization
#######################################################################
print("loss is {}".format(objective_function(p)))
def plot_svgp_jacobian_mean(gp, solver, traj_opt=None):
    params = {
        "text.usetex": True,
        "text.latex.preamble": [
            "\\usepackage{amssymb}",
            "\\usepackage{amsmath}",
        ],
    }
    plt.rcParams.update(params)
    Xnew, xx, yy = create_grid(gp.X, N=961)
    mu, var = gp_predict(
        Xnew,
        gp.Z,
        kernels=gp.kernel,
        mean_funcs=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )

    def gp_jacobian_all(x):
        if len(x.shape) == 1:
            x = x.reshape(1, -1)
        return gp_jacobian(
            x,
            gp.Z,
            gp.kernel,
            gp.mean_func,
            f=gp.q_mu,
            q_sqrt=gp.q_sqrt,
            full_cov=False,
        )

    mu_j, var_j = jax.vmap(gp_jacobian_all, in_axes=0)(Xnew)
    print("GP Jacobian mean/var shapes:")
    print(mu_j.shape)
    print(var_j.shape)

    fig, axs = plot_mean_and_var(
        xx,
        yy,
        mu,
        var,
        llabel=r"$\mathbb{E}[h^{(1)}]$",
        rlabel=r"$\mathbb{V}[h^{(1)}]$",
    )
    for ax in axs:
        ax.quiver(Xnew[:, 0], Xnew[:, 1], mu_j[:, 0], mu_j[:, 1], color="k")
    ax = axs[-1]  # annotate the last axis (as the original leaked loop variable did)
    fig, ax = plot_start_and_end_pos(fig, ax, solver)
    plot_omitted_data(fig, ax, color="k")
    plot_traj(
        fig,
        ax,
        solver.state_guesses,
        color=color_init,
        label="Initial trajectory",
    )
    if traj_opt is not None:
        plot_traj(
            fig,
            ax,
            traj_opt,
            color=color_opt,
            label="Optimised trajectory",
        )
    axs[0].legend()
    return fig, axs
def dist_logloss(dist_class, fixed_params, opt_params, data): dist = dist_class.from_params(fixed_params, opt_params, traceable=True) if data.size == 1: return -dist.logpdf(data) scores = vmap(dist.logpdf)(data) return -np.sum(scores)
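# Hedged sketch of the dist_logloss pattern with a hypothetical Normal
# class; `from_params` and the fixed/opt params split mirror the API
# assumed above, not a real library interface.
import jax.numpy as np
from jax import vmap

class Normal:
    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    @classmethod
    def from_params(cls, fixed_params, opt_params, traceable=True):
        return cls(opt_params["loc"], fixed_params["scale"])

    def logpdf(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z ** 2 - np.log(self.scale) - 0.5 * np.log(2 * np.pi)

data = np.array([0.3, -1.2, 0.7])
print(dist_logloss(Normal, {"scale": 1.0}, {"loc": 0.0}, data))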
def BW(phim, phiw, fm, fw, phi, f):
    a = np.moveaxis(vmap(partial(_BW, Sbc=phi))(phim, phiw), 1, 0)
    b = np.moveaxis(vmap(partial(_BW, Sbc=f))(fm, fw), 1, 0)
    result = dplex.deinsum('ij,ij->ij', a, b)
    return result
def func_type1(S, A, is_training): # custom haiku function: s,a -> q(s,a) value = hk.Sequential([...]) X = jax.vmap(jnp.kron)(S, A) # or jnp.concatenate((S, A), axis=-1) or whatever you like return value(X) # output shape: (batch_size,)
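# Hedged usage sketch: a concrete variant of func_type1 with the elided
# Sequential filled in by an arbitrary MLP, wired up with hk.transform.
# The layer sizes and the jnp.ravel head are illustrative choices.
import haiku as hk
import jax
import jax.numpy as jnp

def func_type1_concrete(S, A, is_training):
    value = hk.Sequential([hk.Linear(64), jax.nn.relu, hk.Linear(1), jnp.ravel])
    X = jax.vmap(jnp.kron)(S, A)
    return value(X)  # output shape: (batch_size,)

func = hk.transform(func_type1_concrete)
S, A = jnp.ones((8, 4)), jnp.ones((8, 2))
params = func.init(jax.random.PRNGKey(0), S, A, is_training=True)
q_values = func.apply(params, None, S, A, is_training=True)  # shape (8,)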
def inbetween(x): return 1 + vmap(innermost)(x)
def _flip_state_batch_default_impl(hilb, key, states, indxs, scalar_rule): keys = jax.random.split(key, states.shape[0]) res = jax.vmap(scalar_rule, in_axes=(None, 0, 0, 0), out_axes=0)( hilb, keys, states, indxs ) return res
def testIssue1789(self): def f(x): return random.gamma(random.PRNGKey(0), x) grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2))
def _random_state_batch_default_impl(hilb, key, size, dtype, scalar_rule):
    keys = jax.random.split(key, size)
    res = jax.vmap(scalar_rule, in_axes=(None, 0, None), out_axes=0)(hilb, keys, dtype)
    return res
def custom_layer_cau_batch(vals, dims, *, num_consts, in_tree, out_tree, kwargs, **params): """Batching rule for layer_cau primitive to handle custom layers.""" if all(dim is batching.not_mapped for dim in dims): return layer_cau_p.bind(*vals, num_consts=num_consts, in_tree=in_tree, out_tree=out_tree, kwargs=kwargs, **params) orig_vals, orig_dims = vals, dims vals, dims = vals[num_consts:], dims[num_consts:] args = tree_util.tree_unflatten(in_tree, vals) dims_ = [not_mapped if dim is None else dim for dim in dims] layer, args = args[0], args[1:] if hasattr(layer, '_call_and_update_batched'): num_params = len(tree_util.tree_leaves(layer)) layer_dims, arg_dims = dims_[:num_params], dims_[num_params:] if kwargs['has_rng']: rng, args = args[0], args[1:] rng_dim, arg_dims = arg_dims[0], arg_dims[1:] mapping_over_layer = all(layer_dim is not not_mapped for layer_dim in layer_dims) mapping_over_args = all(arg_dim is not not_mapped for arg_dim in arg_dims) assert mapping_over_layer or mapping_over_args, (layer_dims, arg_dims) if not mapping_over_layer and mapping_over_args: if kwargs['has_rng']: if rng_dim is not not_mapped: arg_dims = tuple(None if dim is not_mapped else dim for dim in arg_dims) map_fun = jax.vmap( lambda layer, rng, *args: _layer_cau_batched( layer, rng, *args, # pylint: disable=unnecessary-lambda, g-long-lambda **kwargs), in_axes=(None, rng_dim) + (None, ) * len(arg_dims)) else: map_fun = lambda layer, *args: _layer_cau_batched( layer, *args, # pylint: disable=unnecessary-lambda, g-long-lambda **kwargs) vals_out, update_out = map_fun(layer, rng, *args) else: vals_out, update_out = _layer_cau_batched( layer, *args, **kwargs) vals_out = tree_util.tree_leaves(vals_out) update_out = tree_util.tree_leaves(update_out) assert all(dim == 0 for dim in arg_dims) # Assume dimensions out are consistent dims_out = (0, ) * len(vals_out) dims_update = (None, ) * len(update_out) assert len(vals_out) == len(dims_out) assert len(update_out) == len(dims_update) return vals_out + update_out, dims_out + dims_update batched, out_dims = primitive.batch_fun( lu.wrap_init( layer_cau_p.impl, dict(params, num_consts=num_consts, in_tree=in_tree, out_tree=out_tree, kwargs=kwargs)), orig_dims) return batched.call_wrapped(*orig_vals), out_dims()
y_at_t__backwards = solution_and_adjoint_variable_at_t[:, 0, :] adjoint_variable_at_t = solution_and_adjoint_variable_at_t[:, 1, :] J_entire_trajectory__reverse_classical_solve = loss_function_entire_trajectory( y_at_t__backwards, parameters_guess) # The initial condition was not dependent on the parameters d_u0__d_theta = jnp.zeros((2, 4)) # We still have to do an integration over the time horizon dynamic_sensitivity_jacobian = lambda t, y, params: jnp.array( jax.jacobian(model, argnums=2)(t, y, *params)).T # The jit is not really advantageous, because we are only calling the function once vectorized_dynamic_sensitivity_jacobian = jax.jit( jax.vmap(dynamic_sensitivity_jacobian, in_axes=(0, 1, None), out_axes=2)) del_f__del_theta__at_t = vectorized_dynamic_sensitivity_jacobian( t_discrete, y_at_t, parameters_guess) adjoint_variable_matmul_del_f__del_theta_at_t = jnp.einsum( "iN,ijN->jN", adjoint_variable_at_t, del_f__del_theta__at_t) d_J__d_theta__at_end__adjoint = ( adjoint_variable_at_t[:, -1].T @ d_u0__d_theta + jnp.zeros((1, 4)) + integrate.trapezoid(adjoint_variable_matmul_del_f__del_theta_at_t, t_discrete, axis=-1)) time_adjoint__at_end = time.time_ns() - time_adjoint__at_end
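# Hedged sketch of the in_axes/out_axes pattern used above: map over the
# time axis (axis 0 of t, axis 1 of y) and stack each (2, 4) Jacobian
# along a new trailing axis 2. `point_fn` is an illustrative stand-in.
import jax
import jax.numpy as jnp

def point_fn(t, y, params):
    return jnp.outer(y, params) * t        # (2, 4) at a single time point

ts = jnp.linspace(0.0, 1.0, 7)             # (N,)
ys = jnp.ones((2, 7))                      # state stored as (state_dim, N)
ps = jnp.ones(4)
batched = jax.vmap(point_fn, in_axes=(0, 1, None), out_axes=2)
print(batched(ts, ys, ps).shape)           # (2, 4, 7)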
import jax.numpy as jnp
from jax import grad, vmap

def relu(x):
    # Standard ReLU, defined here so the fragment is self-contained.
    return jnp.maximum(0, x)

def forward(params, t, X):
    inputs = jnp.concatenate((t, X), 0)  # M x D+1
    activations = inputs
    for w, b in params[:-1]:
        outputs = jnp.dot(activations, w) + b
        activations = relu(outputs)
    final_w, final_b = params[-1]
    u = jnp.dot(activations, final_w) + final_b
    return jnp.reshape(u, ())  # grad needs a scalar output

vforward = vmap(forward, in_axes=(None, 0, 0))

def grad_forward(params, t, X):
    gradu = grad(forward, argnums=2)  # gradient w.r.t. X only, not params or t
    Du = gradu(params, t, X)
    return Du

vgrad_forward = vmap(grad_forward, in_axes=(None, 0, 0))
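# Hedged usage sketch for the forward/grad wrappers above: a tiny
# 3-layer network over a batch of M = 8 collocation points in D = 2.
# The layer sizes and random data are illustrative.
from jax import random

key = random.PRNGKey(0)
sizes = [3, 16, 16, 1]  # input dim is D + 1 = 3
params = []
for m, n in zip(sizes[:-1], sizes[1:]):
    key, w_key, b_key = random.split(key, 3)
    params.append((random.normal(w_key, (m, n)) / jnp.sqrt(m),
                   random.normal(b_key, (n,))))

key, t_key, x_key = random.split(key, 3)
t = random.normal(t_key, (8, 1))  # M x 1
X = random.normal(x_key, (8, 2))  # M x D
u = vforward(params, t, X)        # per-point scalar outputs, shape (8,)
Du = vgrad_forward(params, t, X)  # gradients w.r.t. X, shape (8, 2)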
def map_product( metric_or_displacement: DisplacementOrMetricFn ) -> DisplacementOrMetricFn: """Vectorizes a metric or displacement function over all pairs.""" return vmap(vmap(metric_or_displacement, (0, None), 0), (None, 0), 0)
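# Hedged usage sketch: map_product turns a pairwise metric into an
# all-pairs distance matrix. `euclidean` is an illustrative metric,
# not jax-md's displacement API.
import jax.numpy as jnp

def euclidean(Ra, Rb):
    return jnp.sqrt(jnp.sum((Ra - Rb) ** 2))

pairwise_distances = map_product(euclidean)
R = jnp.arange(30.0).reshape(10, 3)  # 10 points in 3D
D = pairwise_distances(R, R)         # shape (10, 10)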
def fori_collect(lower, upper, body_fun, init_val, transform=identity,
                 progbar=True, return_last_val=False, collection_size=None,
                 **progbar_opts):
    """
    This looping construct works like :func:`~jax.lax.fori_loop` but with the additional
    effect of collecting values from the loop body. In addition, this allows for
    post-processing of these samples via `transform`, and progress bar updates.
    Note that `progbar=False` will be faster, especially when collecting a
    lot of samples. Refer to example usage in :func:`~numpyro.infer.mcmc.hmc`.

    :param int lower: the index to start the collective work. In other words,
        we will skip collecting the first `lower` values.
    :param int upper: number of times to run the loop body.
    :param body_fun: a callable that takes a collection of
        `np.ndarray` and returns a collection with the same shape and
        `dtype`.
    :param init_val: initial value to pass as argument to `body_fun`. Can
        be any Python collection type containing `np.ndarray` objects.
    :param transform: a callable to post-process the values returned by `body_fun`.
    :param progbar: whether to post progress bar updates.
    :param bool return_last_val: If `True`, the last value is also returned.
        This has the same type as `init_val`.
    :param int collection_size: Size of the returned collection. If not
        specified, the size will be ``upper - lower``. If the size is larger
        than ``upper - lower``, only the top ``upper - lower`` entries will be
        non-zero.
    :param `**progbar_opts`: optional additional progress bar arguments. A
        `diagnostics_fn` can be supplied which when passed the current value
        from `body_fun` returns a string that is used to update the progress
        bar postfix. Also a `progbar_desc` keyword argument can be supplied
        which is used to label the progress bar.
    :return: collection with the same type as `init_val` with values
        collected along the leading axis of `np.ndarray` objects.
    """
    assert lower <= upper
    collection_size = upper - lower if collection_size is None else collection_size
    assert collection_size >= upper - lower
    init_val_flat, unravel_fn = ravel_pytree(transform(init_val))

    @cached_by(fori_collect, body_fun, transform)
    def _body_fn(i, vals):
        val, collection, lower_idx = vals
        val = body_fun(val)
        i = np.where(i >= lower_idx, i - lower_idx, 0)
        collection = ops.index_update(collection, i, ravel_pytree(transform(val))[0])
        return val, collection, lower_idx

    collection = np.zeros((collection_size,) + init_val_flat.shape)
    if not progbar:
        last_val, collection, _ = fori_loop(0, upper, _body_fn, (init_val, collection, lower))
    else:
        diagnostics_fn = progbar_opts.pop('diagnostics_fn', None)
        progbar_desc = progbar_opts.pop('progbar_desc', lambda x: '')
        vals = (init_val, collection, device_put(lower))
        if upper == 0:
            # special case, only compiling
            jit(_body_fn)(0, vals)
        else:
            with tqdm.trange(upper) as t:
                for i in t:
                    vals = jit(_body_fn)(i, vals)
                    t.set_description(progbar_desc(i), refresh=False)
                    if diagnostics_fn:
                        t.set_postfix_str(diagnostics_fn(vals[0]), refresh=False)
        last_val, collection, _ = vals

    unravel_collection = vmap(unravel_fn)(collection)
    return (unravel_collection, last_val) if return_last_val else unravel_collection
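# Hedged usage sketch (assumes the surrounding numpyro utilities are in
# scope): collect the running value of a counter, skipping the first two
# iterations, without a progress bar.
collected = fori_collect(2, 5, lambda x: x + 1.0, np.zeros(()), progbar=False)
# `collected` has leading axis 3: the loop values after iterations 2, 3, 4.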
def map_bond( metric_or_displacement: DisplacementOrMetricFn ) -> DisplacementOrMetricFn: """Vectorizes a metric or displacement function over bonds.""" return vmap(metric_or_displacement, (0, 0), 0)
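# Hedged usage sketch: map_bond evaluates a metric on an explicit list of
# bonded pairs. `euclidean` is an illustrative metric, as above.
import jax.numpy as jnp

def euclidean(Ra, Rb):
    return jnp.sqrt(jnp.sum((Ra - Rb) ** 2))

bond_lengths = map_bond(euclidean)
R = jnp.arange(12.0).reshape(4, 3)           # 4 particle positions
bonds = jnp.array([[0, 1], [1, 2], [2, 3]])  # bonded index pairs
lengths = bond_lengths(R[bonds[:, 0]], R[bonds[:, 1]])  # shape (3,)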
def logistic_mixture_logpdf(params, data): # params are assumed to be normalized if data.size == 1: return logistic_mixture_logpdf1(params, data) scores = vmap(partial(logistic_mixture_logpdf1, params))(data) return np.sum(scores)
def test_nd_payloads(self): cf = checkify.checkify(lambda x, i: x[i], errors=checkify.index_checks) errs, _ = jax.vmap(cf)(jnp.ones((3, 2)), jnp.array([5, 0, 100])) self.assertIsNotNone(errs.get()) self.assertIn("index 5", errs.get()) self.assertIn("index 100", errs.get())
def _vmapped_projection(self, supports, weights, target_support): return jax.vmap(rainbow_agent.project_distribution, in_axes=(0, 0, None))(supports, weights, target_support)
def update_chains(state, rng_key): keys = jax.random.split(rng_key, self.num_chains) new_states, info = jax.vmap(self.kernel, in_axes=(0, 0))(keys, state) return new_states, info
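# Hedged sketch of update_chains outside its class: `kernel` and
# `num_chains` stand in for the assumed self.kernel / self.num_chains.
# The random-walk kernel is a toy illustration.
import jax
import jax.numpy as jnp

num_chains = 4

def kernel(rng_key, state):
    proposal = state + jax.random.normal(rng_key, state.shape)
    return proposal, {"mean_abs_step": jnp.abs(proposal - state).mean()}

def update_chains(state, rng_key):
    keys = jax.random.split(rng_key, num_chains)
    new_states, info = jax.vmap(kernel, in_axes=(0, 0))(keys, state)
    return new_states, info

states = jnp.zeros((num_chains, 2))
states, info = update_chains(states, jax.random.PRNGKey(0))  # states: (4, 2)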
def phase(_theta, _rho):
    result = vmap(_phase)(_theta, _rho)
    return result
def relu(x: Array) -> Array:
    return np.maximum(0, x)

@jit
def predict(params: List[Tuple[Array, Array]], image: Array) -> Array:
    activations = image
    for w, b in params[:-1]:
        out = np.dot(w, activations) + b
        activations = relu(out)
    w, b = params[-1]
    logits = np.dot(w, activations) + b
    return logits - jax.scipy.special.logsumexp(logits)

batch_predict = vmap(predict, in_axes=(None, 0))

def loss(params: List[Tuple[Array, Array]], images: Array, targets: Array) -> float:
    preds = batch_predict(params, images)
    return -np.sum(preds * targets)

@jit
def update(params: List[Tuple[Array, Array]], x: Array, y: Array) -> List[Tuple[Array, Array]]:
    grads = grad(loss)(params, x, y)
    return [
        (w - step_size * dw, b - step_size * db)
        for (w, b), (dw, db) in zip(params, grads)
    ]
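# Hedged usage sketch: random init for a 784-512-10 MLP and one update
# step. `step_size`, the layer sizes, and the random batch are
# illustrative assumptions, not part of the original snippet.
from jax import random

step_size = 1e-3
layer_sizes = [784, 512, 10]
key = random.PRNGKey(0)
params = []
for m, n in zip(layer_sizes[:-1], layer_sizes[1:]):
    key, w_key, b_key = random.split(key, 3)
    params.append((0.01 * random.normal(w_key, (n, m)),
                   0.01 * random.normal(b_key, (n,))))

images = random.normal(key, (32, 784))
targets = np.eye(10)[random.randint(key, (32,), 0, 10)]
params = update(params, images, targets)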
def mass_matrix_inv_mul(self, q: jnp.ndarray, v: jnp.ndarray, **kwargs) -> jnp.ndarray: """Computes the product of the inverse mass matrix with a vector.""" if self.kinetic_func_form in ("separable_net", "dep_net"): raise ValueError( "It is not possible to compute `M^-1 p` when using a " "network for the kinetic energy.") if self.kinetic_func_form in ("pure_quad", "embed_quad"): return v if self.kinetic_func_form == "matrix_diag_quad": if self.parametrize_mass_matrix: m_diag_log = hk.get_parameter( "MassMatrixDiagLog", shape=[self.system_dim], init=hk.initializers.Constant(0.0)) m_inv_diag = 1.0 / (jnp.exp(m_diag_log) + self.mass_eps) else: m_inv_diag_log = hk.get_parameter( "InvMassMatrixDiagLog", shape=[self.system_dim], init=hk.initializers.Constant(0.0)) m_inv_diag = jnp.exp(m_inv_diag_log) + self.mass_eps return m_inv_diag * v if self.kinetic_func_form == "matrix_quad": if self.parametrize_mass_matrix: m_triu = hk.get_parameter( "MassMatrixU", shape=[self.system_dim, self.system_dim], init=hk.initializers.Identity()) m_triu = jnp.triu(m_triu) m = jnp.matmul(m_triu.T, m_triu) m = m + self.mass_eps * jnp.eye(self.system_dim) solve = jnp.linalg.solve for _ in range(v.ndim + 1 - m.ndim): solve = jax.vmap(solve, in_axes=(None, 0)) return solve(m, v) else: m_inv_triu = hk.get_parameter( "InvMassMatrixU", shape=[self.system_dim, self.system_dim], init=hk.initializers.Identity()) m_inv_triu = jnp.triu(m_inv_triu) m_inv = jnp.matmul(m_inv_triu.T, m_inv_triu) m_inv = m_inv + self.mass_eps * jnp.eye(self.system_dim) return self.feature_matrix_vector(m_inv, v) if self.kinetic_func_form in ("matrix_dep_diag_quad", "matrix_dep_diag_embed_quad"): if self.parametrize_mass_matrix: m_diag_log = self.mass_matrix_net(q, **kwargs) m_inv_diag = 1.0 / (jnp.exp(m_diag_log) + self.mass_eps) else: m_inv_diag_log = self.mass_matrix_net(q, **kwargs) m_inv_diag = jnp.exp(m_inv_diag_log) + self.mass_eps return m_inv_diag * v if self.kinetic_func_form in ("matrix_dep_quad", "matrix_dep_embed_quad"): if self.parametrize_mass_matrix: m_triu = self.mass_matrix_net(q, **kwargs) m_triu = utils.triu_matrix_from_v(m_triu, self.system_dim) m = jnp.matmul(jnp.swapaxes(m_triu, -2, -1), m_triu) m = m + self.mass_eps * jnp.eye(self.system_dim) return jnp.linalg.solve(m, v) else: m_inv_triu = self.mass_matrix_net(q, **kwargs) m_inv_triu = utils.triu_matrix_from_v(m_inv_triu, self.system_dim) m_inv = jnp.matmul(jnp.swapaxes(m_inv_triu, -2, -1), m_inv_triu) m_inv = m_inv + self.mass_eps * jnp.eye(self.system_dim) return self.feature_matrix_vector(m_inv, v) raise NotImplementedError()
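# Hedged sketch of the solve-wrapping trick used in the "matrix_quad"
# branch above: lift jnp.linalg.solve over extra leading batch axes of
# `v` while broadcasting a single mass matrix `m`.
import jax
import jax.numpy as jnp

m = 2.0 * jnp.eye(3)
v = jnp.ones((5, 4, 3))               # two batch axes over a 3-vector
solve = jnp.linalg.solve
for _ in range(v.ndim + 1 - m.ndim):  # 3 + 1 - 2 = 2 vmap wraps
    solve = jax.vmap(solve, in_axes=(None, 0))
print(solve(m, v).shape)              # (5, 4, 3)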
def plot_aleatoric_var_vs_time(gp, solver, traj_init, traj_opt=None):
    params = {
        "text.usetex": True,
        "text.latex.preamble": [
            "\\usepackage{amssymb}",
            "\\usepackage{amsmath}",
        ],
    }
    plt.rcParams.update(params)

    mixing_probs_init = jax.vmap(
        single_mogpe_mixing_probability,
        (0, None, None, None, None, None, None),
    )(
        traj_init[:, 0:2],
        gp.Z,
        gp.kernel,
        gp.mean_func,
        gp.q_mu,
        False,
        gp.q_sqrt,
    )
    if mixing_probs_init.shape[-1] == 1:
        mixing_probs_init = np.concatenate(
            [mixing_probs_init, 1 - mixing_probs_init], -1
        )
    if traj_opt is not None:
        mixing_probs_opt = jax.vmap(
            single_mogpe_mixing_probability,
            (0, None, None, None, None, None, None),
        )(
            traj_opt[:, 0:2],
            gp.Z,
            gp.kernel,
            gp.mean_func,
            gp.q_mu,
            False,
            gp.q_sqrt,
        )
        if mixing_probs_opt.shape[-1] == 1:
            mixing_probs_opt = np.concatenate(
                [mixing_probs_opt, 1 - mixing_probs_opt], -1
            )

    noise_vars = np.array(gp.noise_vars).reshape(-1, 1)
    var_init = mixing_probs_init @ noise_vars
    var_opt = 0
    if traj_opt is not None:
        var_opt = mixing_probs_opt @ noise_vars

    fig, ax = plt.subplots(1, 1, figsize=(6.4, 2.8))
    ax.set_xlabel("Time $t$")
    ax.set_ylabel(r"$\sum_{k=1}^K\Pr(\alpha=k|\mathbf{x}) (\sigma^{(k)})^2$")
    ax.plot(
        solver.times, var_init, color=color_init, label="Initial trajectory"
    )
    if traj_opt is not None:
        ax.plot(
            solver.times,
            var_opt,
            color=color_opt,
            label="Optimised trajectory",
        )
    ax.legend()

    print("Sum aleatoric var init = ", np.sum(var_init))
    print("Sum aleatoric var opt = ", np.sum(var_opt))
    return fig, ax
def compute_ssim(img0,
                 img1,
                 max_val,
                 filter_size=11,
                 filter_sigma=1.5,
                 k1=0.01,
                 k2=0.03,
                 return_map=False):
  """Computes SSIM from two images.

  This function was modeled after tf.image.ssim, and should produce comparable
  output.

  Args:
    img0: array. An image of size [..., width, height, num_channels].
    img1: array. An image of size [..., width, height, num_channels].
    max_val: float > 0. The maximum magnitude that `img0` or `img1` can have.
    filter_size: int >= 1. Window size.
    filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering.
    k1: float > 0. One of the SSIM dampening parameters.
    k2: float > 0. One of the SSIM dampening parameters.
    return_map: Bool. If True, the per-pixel SSIM "map" is returned.

  Returns:
    Each image's mean SSIM, or a tensor of individual values if `return_map`.
  """
  # Construct a 1D Gaussian blur filter.
  hw = filter_size // 2
  shift = (2 * hw - filter_size + 1) / 2
  f_i = ((jnp.arange(filter_size) - hw + shift) / filter_sigma)**2
  filt = jnp.exp(-0.5 * f_i)
  filt /= jnp.sum(filt)

  # Blur in x and y (faster than the 2D convolution).
  filt_fn1 = lambda z: jsp.signal.convolve2d(z, filt[:, None], mode="valid")
  filt_fn2 = lambda z: jsp.signal.convolve2d(z, filt[None, :], mode="valid")

  # Vmap the blurs to the tensor size, and then compose them.
  num_dims = len(img0.shape)
  map_axes = tuple(list(range(num_dims - 3)) + [num_dims - 1])
  for d in map_axes:
    filt_fn1 = jax.vmap(filt_fn1, in_axes=d, out_axes=d)
    filt_fn2 = jax.vmap(filt_fn2, in_axes=d, out_axes=d)
  filt_fn = lambda z: filt_fn1(filt_fn2(z))

  mu0 = filt_fn(img0)
  mu1 = filt_fn(img1)
  mu00 = mu0 * mu0
  mu11 = mu1 * mu1
  mu01 = mu0 * mu1
  sigma00 = filt_fn(img0**2) - mu00
  sigma11 = filt_fn(img1**2) - mu11
  sigma01 = filt_fn(img0 * img1) - mu01

  # Clip the variances and covariances to valid values.
  # Variance must be non-negative:
  sigma00 = jnp.maximum(0., sigma00)
  sigma11 = jnp.maximum(0., sigma11)
  sigma01 = jnp.sign(sigma01) * jnp.minimum(
      jnp.sqrt(sigma00 * sigma11), jnp.abs(sigma01))

  c1 = (k1 * max_val)**2
  c2 = (k2 * max_val)**2
  numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
  denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
  ssim_map = numer / denom
  ssim = jnp.mean(ssim_map, list(range(num_dims - 3, num_dims)))
  return ssim_map if return_map else ssim
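# Hedged usage sketch: SSIM between two random images scaled to [0, 1]
# (assumes jsp = jax.scipy is imported alongside compute_ssim).
import jax
import jax.numpy as jnp

key0, key1 = jax.random.split(jax.random.PRNGKey(0))
img0 = jax.random.uniform(key0, (64, 64, 3))
img1 = jax.random.uniform(key1, (64, 64, 3))
print(compute_ssim(img0, img1, max_val=1.0))  # scalar mean SSIM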
def map_neighbor(
    metric_or_displacement: DisplacementOrMetricFn
) -> DisplacementOrMetricFn:
    """Vectorizes a metric or displacement function over neighborhoods."""
    # Enclosing wrapper restored so `metric_or_displacement` is bound,
    # following the map_product/map_bond pattern above.
    def wrapped_fn(Ra, Rb, **kwargs):
        return vmap(vmap(metric_or_displacement, (None, 0)))(-Ra, -Rb, **kwargs)
    return wrapped_fn
import jax.numpy as np
from jax import grad, jit, vmap
from functools import partial

def predict(params, inputs):
    for W, b in params:
        outputs = np.dot(inputs, W) + b
        inputs = np.tanh(outputs)
    return outputs

def logprob_fun(params, inputs, targets):
    preds = predict(params, inputs)
    return np.sum((preds - targets)**2)

grad_fun = jit(grad(logprob_fun))  # compiled gradient evaluation function
perex_grads = jit(lambda params, inputs, targets:  # fast per-example gradients
                  vmap(partial(grad_fun, params))(inputs, targets))
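# Equivalent and arguably clearer: a single vmap with in_axes, holding
# params fixed (None) and mapping over the batch axes of inputs/targets.
perex_grads_alt = jit(vmap(grad_fun, in_axes=(None, 0, 0)))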
def objective_vectorized(X):  # X is (N, 2)
    f = vmap(objective)(X)
    return f